SRI Immunization Model

The Code

Here's the code for the model relying only on numpy, networkx, holoviews and matplotlib in the background.

In [1]:
import collections
import itertools
import math

import numpy as np
import numpy.random as rnd
import networkx as nx

import param
import holoviews as hv

class SRI_Model(param.Parameterized):
    """
    Implementation of the SRI epidemiology model
    using NetworkX and HoloViews for visualization.
    This code has been adapted from Simon Dobson's
    epidemic-spreading code.

    In addition to his basic parameters I've added
    additional states to the model, a node may be
    in one of the following states:

      * Susceptible: Can catch the disease from a connected node.
      * Vaccinated: Immune to infection.
      * Infected: Has the disease and may pass it on to any connected node.
      * Recovered: Immune to infection.
      * Dead: Edges are removed from graph.
    """

    # NOTE(review): node attributes are accessed through ``Graph.node``,
    # which was removed in NetworkX 2.4; switch to ``Graph.nodes`` if this
    # needs to run against a newer NetworkX release.

    network = param.ClassSelector(class_=nx.Graph, default=None, doc="""
        A custom NetworkX graph, instead of the default Erdos-Renyi graph.""")
    visualize = param.Boolean(default=True, doc="""
        Whether to compute layout of network for visualization.""")
    N = param.Integer(default=1000, doc="""
        Number of nodes to simulate.""")
    mean_connections = param.Number(default=10, doc="""
        Mean number of connections to make to other nodes.""")
    pSick = param.Number(default=0.01, doc="""
        Probability of a node to be initialized in sick state.""", bounds=(0, 1))

    pVaccinated = param.Number(default=0.1, bounds=(0, 1), doc="""
        Probability of a node to be initialized in vaccinated state.""")
    pInfect = param.Number(default=0.3, doc="""
        Probability of infection on each time step.""", bounds=(0, 1))
    pRecover = param.Number(default=0.05, doc="""
        Probability of recovering if infected on each timestep.""", bounds=(0, 1))
    pDeath = param.Number(default=0.1, doc="""
        Probability of death if infected on each timestep.""", bounds=(0, 1))

    # Single-letter codes stored in each node's 'state' attribute; they
    # match the keys of the label table used in ``stats``.
    SPREADING_SUSCEPTIBLE = 'S'
    SPREADING_VACCINATED = 'V'
    SPREADING_INFECTED = 'I'
    SPREADING_RECOVERED = 'R'
    DEAD = 'D'

    def __init__(self, **params):
        """
        Build (or adopt) the graph, seed the initial node states and,
        if ``visualize`` is set, compute a spring layout for plotting.
        """
        super(SRI_Model, self).__init__(**params)
        if self.network is None:
            # Edge probability mean_connections/N gives each node
            # ``mean_connections`` neighbours on average.
            self.g = nx.erdos_renyi_graph(self.N, float(self.mean_connections)/self.N)
        else:
            self.g = self.network
        self.vaccinated, self.infected = self.spreading_init()
        self.model = self.spreading_make_sir_model()
        # Order must match the five histogram bins and the state labels
        # used in ``animate``.
        self.color_mapping = [self.SPREADING_SUSCEPTIBLE,
                              self.SPREADING_VACCINATED,
                              self.SPREADING_INFECTED,
                              self.SPREADING_RECOVERED, self.DEAD]
        if self.visualize:
            k = 2/(math.sqrt(self.g.order()))
            self.pos = hv.Graph.from_networkx(self.g, nx.spring_layout, iterations=50, k=k)

    def spreading_init(self):
        """
        Initialise the network with vaccinated, susceptible and infected states.

        Returns a tuple ``(vaccinated, infected)``: the number of
        vaccinated nodes and the list of initially infected nodes.
        """
        vaccinated, infected = 0, []
        for i in self.g.node:
            self.g.node[i]['transmissions'] = 0
            if rnd.random() <= self.pVaccinated:
                self.g.node[i]['state'] = self.SPREADING_VACCINATED
                vaccinated += 1
            elif rnd.random() <= self.pSick:
                self.g.node[i]['state'] = self.SPREADING_INFECTED
                infected.append(i)
            else:
                self.g.node[i]['state'] = self.SPREADING_SUSCEPTIBLE
        return vaccinated, infected

    def spreading_make_sir_model(self):
        """Return an SIR model function for given infection and recovery probabilities."""
        # model (local rule) function
        def model(g, i):
            if g.node[i]['state'] == self.SPREADING_INFECTED:
                # infect susceptible neighbours with probability pInfect
                for m in g.neighbors(i):
                    if g.node[m]['state'] == self.SPREADING_SUSCEPTIBLE:
                        if rnd.random() <= self.pInfect:
                            g.node[m]['state'] = self.SPREADING_INFECTED
                            # Record the new case so ``stats`` counts it.
                            self.infected.append(m)
                            g.node[i]['transmissions'] += 1

                # recover with probability pRecover, otherwise die with
                # probability pDeath
                if rnd.random() <= self.pRecover:
                    g.node[i]['state'] = self.SPREADING_RECOVERED
                elif rnd.random() <= self.pDeath:
                    # Materialise the incident edges first: the graph must
                    # not be mutated while its edges are being iterated.
                    edges = [edge for edge in g.edges() if i in edge]
                    g.node[i]['state'] = self.DEAD
                    g.remove_edges_from(edges)

        return model

    def step(self):
        """Run a single step of the model over the graph."""
        for i in self.g.node:
            self.model(self.g, i)

    def run(self, steps):
        """
        Run the network for the specified number of time steps
        """
        for _ in range(steps):
            self.step()

    def stats(self):
        """
        Return an ItemTable with statistics on the network data.

        The table holds the per-state node counts, the mean number of
        transmissions per infected node ($R_0$), and death and infection
        rates relative to all ``N`` nodes and to the unvaccinated nodes.
        """
        state_labels = hv.OrderedDict([('S', 'Susceptible'), ('V', 'Vaccinated'), ('I', 'Infected'),
                                       ('R', 'Recovered'), ('D', 'Dead')])
        # Set for O(1) membership tests and de-duplicated counting.
        infected_nodes = set(self.infected)
        counts = collections.Counter()
        transmissions = []
        for n in self.g.nodes():
            state = state_labels[self.g.node[n]['state']]
            counts[state] += 1
            if n in infected_nodes:
                transmissions.append(self.g.node[n]['transmissions'])
        data = hv.OrderedDict([(label, counts[label])
                               for label in state_labels.values()])
        infected = len(infected_nodes)
        unvaccinated = float(self.N-self.vaccinated)
        data['$R_0$'] = np.mean(transmissions) if transmissions else 0
        data['Death rate DR'] = np.divide(float(data['Dead']), self.N)
        data['Infection rate IR'] = np.divide(float(infected), self.N)
        if unvaccinated:
            unvaccinated_dr = data['Dead']/unvaccinated
            unvaccinated_ir = infected/unvaccinated
        else:
            # Everyone is vaccinated: avoid dividing by zero.
            unvaccinated_dr = 0
            unvaccinated_ir = 0
        data['Unvaccinated DR'] = unvaccinated_dr
        data['Unvaccinated IR'] = unvaccinated_ir
        return hv.ItemTable(data)

    def animate(self, steps):
        """
        Run the network for the specified number of steps accumulating animations
        of the network nodes and edges changing states and curves tracking the
        spread of the disease.

        Returns the network HoloMap laid out next to a HoloMap of the
        per-state count curves. Raises an Exception if the model was
        created with ``visualize=False`` (no layout is available).
        """
        if not self.visualize:
            raise Exception("Enable visualize option to get compute network visulizations.")

        # Declare HoloMap for network animation and counts array
        network_hmap = hv.HoloMap(kdims='Time')
        sird = np.zeros((steps, 5))
        state_labels = ['Susceptible', 'Vaccinated', 'Infected', 'Recovered', 'Dead']

        # Text annotation
        nlabel = hv.Text(0.9, 0.05, 'N=%d' % self.N)

        for i in range(steps):
            # Get states and count data
            states = [self.g.node[n]['state'] for n in self.g.nodes()]
            state_ints = [self.color_mapping.index(v) for v in states]
            state_array = np.array(state_ints, ndmin=2).T
            (sird[i, :], _) = np.histogram(state_array, bins=list(range(6)))

            # Create the graph Element from the current edges and the
            # laid-out nodes annotated with their state
            nodes = self.pos.nodes.clone(datatype=['dictionary'])
            nodes = nodes.add_dimension('State', 0, states, True)
            graph = self.pos.clone((list(self.g.edges()), nodes))
            # Create overlay and accumulate in network HoloMap
            network_hmap[i] = (graph * nlabel).relabel(group='Network', label='SRI')
            # Advance the simulation so the next frame shows the next state
            self.step()

        # Create Overlay of Curves
        curves = hv.NdOverlay({label: hv.Curve(list(zip(range(steps), sird[:, i])),
                                               'Time', 'Count')
                               for i, label in enumerate(state_labels)},
                              kdims=[hv.Dimension('State', values=state_labels)])
        # Animate VLine on top of Curves
        distribution = hv.HoloMap({i: (curves * hv.VLine(i)).relabel(group='Counts', label='SRI')
                                   for i in range(steps)}, kdims='Time')
        return network_hmap + distribution

The style

HoloViews allows us to define various style options in advance on the Store.options object.

In [2]:
# Load the bokeh and matplotlib plotting extensions.
hv.extension('bokeh', 'matplotlib')

# Set colors and style options for the Element types
from holoviews import Store, Options
# NOTE(review): Palette is imported but not used in the visible code.
from holoviews.core.options import Palette
# Options object that all the assignments below configure.
opts = Store.options()

# Defaults per element type; the first argument of Options selects the
# option group ('plot' or 'style') the keywords apply to.
# Graph nodes are coloured by the 'State' dimension.
opts.Graph     = Options('plot', color_index='State')
opts.Graph     = Options('style', node_size=6, edge_line_width=1)
opts.Histogram = Options('plot', show_grid=False)
opts.Overlay   = Options('plot', show_frame=False)
opts.HeatMap   = Options('plot', xrotation=90)
opts.ItemTable = Options('plot', width=900, height=50)

# Overrides for Overlays in the 'Network' and 'Counts' groups.
opts.Overlay.Network = Options('plot', xaxis=None, yaxis=None)
opts.Overlay.Counts  = Options('plot', show_grid=True)

# Both option groups for VLine are set at once via a dict keyed by group.
opts.VLine     = {'style': Options(color='black', line_width=1),
                  'plot':  Options(show_grid=True)}