This page explains how to estimate parameters in an ODE model using Python.

For this example, an SIR model will be fit to data from an influenza outbreak at an English boarding school in 1978. Of the 763 boys at risk, 512 became infected. An article on the outbreak appeared in the British Medical Journal; the exact data used in this example are taken from Dr. Brian Reich's fitting tutorial.

Unlike in the MATLAB version of this tutorial, functions in Python must be defined before they are called. In a MATLAB script, by contrast, all local functions must go at the end of the file.
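The snippet that follows shows what this ordering rule means in practice. The function names total_population and run_model are hypothetical and exist only for this illustration. A top-level call fails if it runs before the function's def statement has executed. A function body, however, may refer to a function defined further down the file, because the name is looked up only when the body actually runs.

```python
# A top-level call runs as soon as Python reaches it, so calling a function
# whose `def` statement has not executed yet raises NameError.
try:
    total_population(760, 3, 0)
    called_too_early = False
except NameError:
    called_too_early = True


def run_model():
    # Inside a function body, names are looked up only when the body runs,
    # so referring to a function defined later in the file is fine here.
    return total_population(760, 3, 0)


def total_population(S, I, R):
    """Hypothetical helper used only for this illustration."""
    return S + I + R


print(called_too_early)  # True
print(run_model())       # 763
```

In short, the safe pattern is to put all def statements first and the code that calls them afterwards.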

Mild disclaimer: I am learning Python myself, so this tutorial may not include “best practices” but it’s a good starting point 😊.



The SIR Model

The Susceptible-Infected-Recovered (SIR) model is described by the following equations (where a dot denotes differentiation with respect to time t):

$$ \begin{align*}\dot{S} &= -\frac{\beta SI}{N},\\ \dot{I} &= \frac{\beta SI}{N} - \gamma I,\\ \dot{R} &= \gamma I.\end{align*} $$

The total population is denoted $N(t) = S(t) + I(t) + R(t)$. The transmission (effective contact) rate per day is denoted by $\beta$, and the recovery rate per day is denoted by $\gamma$. Note that $1/\gamma$ gives the average time (in days) that an infected individual takes to recover.
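Two short derivations support the statements above. First, adding the three equations shows that the total population does not change over time. $N$ can therefore be treated as a constant, which is 763 for this outbreak under the model's assumption of a closed population. Second, suppose recovery happens at a constant per-capita rate $\gamma$. Then the time an individual spends infected is exponentially distributed with rate $\gamma$, and its mean is $1/\gamma$:

$$ \begin{align*}\dot{N} &= \dot{S} + \dot{I} + \dot{R} = -\frac{\beta SI}{N} + \left(\frac{\beta SI}{N} - \gamma I\right) + \gamma I = 0,\\ \mathbb{E}[\text{time infected}] &= \int_0^\infty t\,\gamma e^{-\gamma t}\,dt = \frac{1}{\gamma}.\end{align*} $$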

Housekeeping

Here, we load the relevant Python libraries and get our dataset ready.

import numpy as np
import matplotlib.pyplot as plt
from scipy import integrate, optimize

ydata = [3, 8, 26, 76, 225, 298, 258, 233, 189, 128, 68, 29, 14, 4]  # daily case counts
xdata = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14]               # day of each observation

ydata = np.array(ydata, dtype=float)
xdata = np.array(xdata, dtype=float)

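Note that we convert xdata and ydata from Python lists to NumPy arrays of floats. This way, arithmetic on them (for example, subtracting model output from the data) is carried out element by element.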
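To see why the conversion matters, compare how a plain list and a NumPy array respond to the same arithmetic. This illustration uses only the first four observations. The behaviour shown is standard Python and NumPy, not anything specific to this tutorial.

```python
import numpy as np

ydata_list = [3, 8, 26, 76]                      # first four observations as a plain list
ydata_array = np.array(ydata_list, dtype=float)  # the same values as a float array

# `*` on a list repeats it; on an array it scales every element.
print(ydata_list * 2)   # [3, 8, 26, 76, 3, 8, 26, 76]
print(ydata_array * 2)

# Subtracting one list from another is not defined, but array subtraction is
# element-wise, which is what computing residuals (data minus model) needs.
try:
    ydata_list - ydata_list
except TypeError as err:
    print("TypeError:", err)
print(ydata_array - ydata_array)
```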

Functions

We need two functions: one that defines the ODE system, and one that numerically integrates that system for given values of the parameters.
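What follows is a minimal sketch of how the two pieces might fit together, under stated assumptions. It uses SciPy's integrate.odeint as the solver. The initial conditions are 760 susceptible, 3 infected and 0 recovered boys on day 1, where 3 is the first data point. The wrapper name integrate_sir and the parameter values 1.5 and 0.5 are hypothetical, chosen for illustration only. The SIR function is repeated so the block runs on its own.

```python
import numpy as np
from scipy import integrate

# Assumed initial state on day 1 (the first observation): S, I, R.
population = 763.0
I0 = 3.0
y0 = (population - I0, I0, 0.0)


def SIR(y, t, beta, gamma):
    """Right-hand side of the SIR system, in the (y, t, *params) order odeint uses."""
    S, I, R = y
    N = S + I + R
    return (-beta * S * I / N,
            beta * S * I / N - gamma * I,
            gamma * I)


def integrate_sir(t, beta, gamma):
    """Hypothetical wrapper: solve the SIR system at the times in t and
    return only I(t), the compartment that is compared with the data."""
    solution = integrate.odeint(SIR, y0, t, args=(beta, gamma))  # shape (len(t), 3)
    return solution[:, 1]


t = np.arange(1, 15, dtype=float)    # days 1..14; odeint applies y0 at t[0]
I_model = integrate_sir(t, 1.5, 0.5) # illustrative parameter values, not fitted ones
print(np.round(I_model, 1))
```

The wrapper takes the independent variable first and the parameters after it, in the form f(x, *params). This is the form scipy.optimize.curve_fit expects for a model function, which makes it a natural fit for least-squares parameter estimation.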

The model

def SIR(y, t, beta, gamma):
    """ODE system for fitting.

    y: vector of state variables (S, I, R)
    t: independent variable (time); not used in the equations,
       but kept in the signature for the ODE solver
    beta, gamma: parameters to be fit
    """
    S, I, R = y
    N = S + I + R

    Sdot = -beta*S*I/N
    Idot = beta*S*I/N - gamma*I
    Rdot = gamma*I

    return Sdot, Idot, Rdot
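A quick sanity check on a right-hand-side function like this is to evaluate it at a single state. The three rates should sum to zero, because the three equations add up to zero. The state used here is an illustrative assumption: 760 susceptible, 3 infected and 0 recovered, with β = 1.5 and γ = 0.5, which are not fitted values. The function is repeated so the block is self-contained.

```python
import math


def SIR(y, t, beta, gamma):
    """Right-hand side of the SIR system (same equations as above)."""
    S, I, R = y
    N = S + I + R
    Sdot = -beta * S * I / N
    Idot = beta * S * I / N - gamma * I
    Rdot = gamma * I
    return Sdot, Idot, Rdot


# Illustrative state and parameters (assumptions, not fitted values).
Sdot, Idot, Rdot = SIR((760.0, 3.0, 0.0), 0.0, 1.5, 0.5)
print(Sdot, Idot, Rdot)

# The rates cancel because the total population N is conserved.
print(math.isclose(Sdot + Idot + Rdot, 0.0, abs_tol=1e-12))  # True
```

With these values, Idot is positive, so the number of infected boys is initially growing.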