diff --git a/examples/network-hhmodel/README.md b/examples/network-hhmodel/README.md index d338d09e..d09af73a 100644 --- a/examples/network-hhmodel/README.md +++ b/examples/network-hhmodel/README.md @@ -1,29 +1,31 @@ # Example: Agent-based SEIR model with contact networks -This example demonstrates the use of the `network` module of `ixa`. +This example demonstrates a network model in `ixa`. -There are three CSV files: +There are three data files: - `Households.csv` represents 500 households of size 1-12. Individuals have age category and sex properties. Within the model, these individuals are placed in a densely connected network. -- `AgeUnder5Edges.csv` contains the edges connecting those aged under 5 +- `AgeUnder5Edges.csv` contains the edges connecting those aged under 5. - `Age5to17Edges.csv` contains the edges connecting those aged 5-17. -In `network.rs`, three corresponding edge types are created using the -`define_edge_type!` macro and the networks are formed by adding edges to the -context using `add_edge_bidi`. +The parameter `sar` is the secondary attack rate within each household, used to compute the within-household transmission rate. The parameter `relative_rate` is the ratio of the transmission rates between versus within households; it should be less than one. -In `seir.rs`, a SEIR model is implemented with different betas by network edge -type. Edge queries (`get_matching_edges`) allow us to identify the neighbors of -the infected individuals and consider whether they become exposed. +The simulation runs via: -`loader.rs` reads in the `Household.csv` file and `parameters.rs` sets up global -properties for the SEIR model. +- `parameters.rs` sets up global properties for the model. +- `loader.rs` reads in the `Household.csv` file and instantiates the people in it. +- `network.rs` forms a dense network of household contacts, then reads in the other + contact files and instantiates those network edges. The edges are tracked as + model entities. This module also selects which individuals have effective contact + during each time period. +- `seir.rs` manages transmission, infections, and disease trajectories. +- `incidence_report.rs` sets up a report with information on who became infected + by whom during the simulation and saves the information to a csv in an `\output` + folder. -`incidence_report.rs` sets up a report with information on who became infected -by whom during the simulation and saves the information to a csv in an `\output` -folder. +Note that the relative rate of transmission between households (relative to within households) is a property of the network edges. For technical reasons, ixa properties must implement `Eq`, which Rust floats do not. This example manually implements equality logic; future ixa versions may have other solutions. ## How to run the model diff --git a/examples/network-hhmodel/config.json b/examples/network-hhmodel/config.json index b3b0deee..ea756525 100644 --- a/examples/network-hhmodel/config.json +++ b/examples/network-hhmodel/config.json @@ -5,7 +5,7 @@ "sar": 0.15, "shape": 15.0, "infection_duration": 5.0, - "between_hh_transmission_reduction": 3.0, + "relative_rate": 0.333, "output_dir": "examples/network-hhmodel/output", "data_dir": "examples/network-hhmodel/data" } diff --git a/examples/network-hhmodel/incidence_report.rs b/examples/network-hhmodel/incidence_report.rs index 874d9232..b7d80472 100644 --- a/examples/network-hhmodel/incidence_report.rs +++ b/examples/network-hhmodel/incidence_report.rs @@ -124,7 +124,7 @@ mod test { sar: 1.0, shape: 15.0, infection_duration: 5.0, - between_hh_transmission_reduction: 1.0, + relative_rate: 1.0, data_dir: output_dir.to_str().unwrap().to_string(), output_dir: output_dir.to_str().unwrap().to_string(), }; @@ -145,8 +145,8 @@ mod test { .set_global_property_value(Parameters, parameters.clone()) .unwrap(); - let people = loader::init(&mut context); - network::init(&mut context, &people); + loader::init(&mut context); + network::init(&mut context, 1.0); incidence_report::init(&mut context).unwrap(); context.subscribe_to_event( @@ -158,7 +158,7 @@ mod test { let to_infect: Vec = vec![context.sample_entity(MainRng, Person).unwrap()]; #[allow(clippy::vec_init_then_push)] - seir::init(&mut context, &to_infect); + seir::init(&mut context, &to_infect, 1.0); context.execute(); } diff --git a/examples/network-hhmodel/loader.rs b/examples/network-hhmodel/loader.rs index 378d5fc3..f9fd31d1 100644 --- a/examples/network-hhmodel/loader.rs +++ b/examples/network-hhmodel/loader.rs @@ -5,7 +5,7 @@ use ixa::impl_property; use ixa::prelude::*; use serde::{Deserialize, Serialize}; -use crate::{example_dir, Person, PersonId}; +use crate::{example_dir, Person}; #[derive(Serialize, Deserialize, Copy, Clone, PartialEq, Eq, Debug, Hash)] pub struct Id(pub u16); @@ -39,45 +39,32 @@ struct PeopleRecord { household_id: HouseholdId, } -fn create_person_from_record(context: &mut Context, record: &PeopleRecord) -> PersonId { - context - .add_entity(with!( - Person, - record.id, - record.age_group, - record.sex, - record.household_id - )) - .unwrap() -} - pub fn open_csv(file_name: &str) -> Reader { let current_dir = example_dir(); let file_path = current_dir.join(file_name); csv::Reader::from_path(file_path).unwrap() } -pub fn init(context: &mut Context) -> Vec { +pub fn init(context: &mut Context) { // Load csv and deserialize records let mut reader = open_csv("Households.csv"); - let mut people = Vec::new(); for result in reader.deserialize() { let record: PeopleRecord = result.expect("Failed to parse record"); - people.push(create_person_from_record(context, &record)); + context + .add_entity(with!( + Person, + record.id, + record.age_group, + record.sex, + record.household_id + )) + .unwrap(); } - - context.index_property::(); - context.index_property::(); - - people } #[cfg(test)] mod tests { - use ixa::context::Context; - use ixa::random::ContextRandomExt; - use super::*; const EXPECTED_ROWS: usize = 1606; @@ -85,52 +72,52 @@ mod tests { #[test] fn test_init_expected_rows() { let mut context = Context::new(); - context.init_random(42); init(&mut context); assert_eq!(context.get_entity_count::(), EXPECTED_ROWS); } + // Check there is exactly one matching entity + fn assert_exists1( + context: &Context, + id: Id, + age_group: AgeGroup, + sex: Sex, + hh_id: HouseholdId, + ) { + assert_eq!( + context.query_entity_count(with!(Person, id, age_group, sex, hh_id)), + 1 + ); + } + #[test] fn test_some_people_load_correctly() { let mut context = Context::new(); - context.init_random(42); - - let people = init(&mut context); + init(&mut context); - let person = people[0]; - assert!(context.match_entity( - person, - with!( - Person, - Id(676), - AgeGroup::Age18to64, - Sex::Female, - HouseholdId(1) - ) - )); - - let person = people[246]; - assert!(context.match_entity( - person, - with!( - Person, - Id(213), - AgeGroup::AgeUnder5, - Sex::Female, - HouseholdId(162) - ) - )); - - let person = people[1591]; - assert!(context.match_entity( - person, - with!( - Person, - Id(1591), - AgeGroup::Age65Plus, - Sex::Male, - HouseholdId(496) - ) - )); + // e.g., the person with data id 676 should be 18-64, female, in household 1 + assert_exists1( + &context, + Id(676), + AgeGroup::Age18to64, + Sex::Female, + HouseholdId(1), + ); + + assert_exists1( + &context, + Id(213), + AgeGroup::AgeUnder5, + Sex::Female, + HouseholdId(162), + ); + + assert_exists1( + &context, + Id(1591), + AgeGroup::Age65Plus, + Sex::Male, + HouseholdId(496), + ); } } diff --git a/examples/network-hhmodel/main.rs b/examples/network-hhmodel/main.rs index fee887dc..98556154 100644 --- a/examples/network-hhmodel/main.rs +++ b/examples/network-hhmodel/main.rs @@ -7,6 +7,8 @@ mod parameters; mod seir; use std::path::PathBuf; +use parameters::Parameters; + define_entity!(Person); define_rng!(MainRng); @@ -27,14 +29,19 @@ fn initialize(context: &mut Context) { context.init_random(1); // Load people from csv and set up some base properties - let people = loader::init(context); + loader::init(context); // Load parameters from json let file_path = example_dir().join("config.json"); context.load_global_properties(&file_path).unwrap(); + let parameters = context + .get_global_property_value(Parameters) + .unwrap() + .clone(); + // Load network - network::init(context, &people); + network::init(context, parameters.relative_rate); // Initialize incidence report incidence_report::init(context).unwrap(); @@ -43,6 +50,5 @@ fn initialize(context: &mut Context) { let to_infect: Vec = vec![context.sample_entity(MainRng, Person).unwrap()]; #[allow(clippy::vec_init_then_push)] - seir::init(context, &to_infect); - context.execute(); + seir::init(context, &to_infect, 1.0); } diff --git a/examples/network-hhmodel/network.rs b/examples/network-hhmodel/network.rs index e9bd7061..3be2456d 100644 --- a/examples/network-hhmodel/network.rs +++ b/examples/network-hhmodel/network.rs @@ -1,111 +1,147 @@ -use ixa::network::edge::EdgeType; +use std::hash::Hash; + use ixa::prelude::*; use ixa::{HashSet, HashSetExt}; -use serde::Deserialize; +use rand_distr::Bernoulli; use crate::loader::{open_csv, HouseholdId, Id}; +use crate::parameters::Parameters; use crate::{Person, PersonId}; -define_edge_type!(struct Household, Person); -define_edge_type!(struct AgeUnder5, Person); -define_edge_type!(struct Age5to17, Person); +define_entity!(Edge); +define_property!(struct RelativeRate(f64), Edge, impl_eq_hash = both); +define_property!(struct Node1(PersonId), Edge); +define_property!(struct Node2(PersonId), Edge); + +define_rng!(NetworkRng); + +fn add_bidi_edge(context: &mut Context, p1: PersonId, p2: PersonId, relative_rate: f64) { + let rr = RelativeRate(relative_rate); -#[derive(Deserialize, Debug)] -struct EdgeRecord { - v1: u16, - v2: u16, + context + .add_entity(with!(Edge, Node1(p1), Node2(p2), rr)) + .unwrap(); + context + .add_entity(with!(Edge, Node2(p1), Node1(p2), rr)) + .unwrap(); } -fn create_household_networks(context: &mut Context, people: &[PersonId]) { +fn create_household_networks(context: &mut Context) { let mut households = HashSet::new(); + let people: Vec = context.query(with!(Person)).into_iter().collect(); + + // for every person, check what household they are in for person_id in people { - let household_id: HouseholdId = context.get_property(*person_id); + let household_id: HouseholdId = context.get_property(person_id); + // if we haven't seen this household before, find all its member, + // and connect them in a dense network if households.insert(household_id) { - let mut members: Vec = Vec::new(); - context.with_query_results(with!(Person, household_id), &mut |results| { - members = results.to_owned_vec() - }); - // create a dense network - while let Some(person) = members.pop() { - for other_person in &members { - context - .add_edge_bidi(person, *other_person, 1.0, Household) - .unwrap(); + let members: Vec = context + .query(with!(Person, household_id)) + .into_iter() + .collect(); + + for i in 0..(members.len() - 1) { + for j in (i + 1)..(members.len()) { + // by definition, edge rates are measured relative to within-household rates + add_bidi_edge(context, members[i], members[j], 1.0); } } } } } -fn load_edge_list>(context: &mut Context, file_name: &str, inner: ET) { +// Assert there is only one person with data ID `id`, then get their entity ID +fn get_entity_id_by_data_id(context: &mut Context, id: u16) -> PersonId { + let v: Vec = context.query(with!(Person, Id(id))).into_iter().collect(); + assert_eq!(v.len(), 1); + v[0] +} + +fn load_edge_list(context: &mut Context, file_name: &str, relative_rate: f64) { let mut reader = open_csv(file_name); for result in reader.deserialize() { - let record: EdgeRecord = result.expect("Failed to parse edge"); - let mut p1_vec = Vec::new(); - context.with_query_results(with!(Person, Id(record.v1)), &mut |people| { - p1_vec = people.to_owned_vec() - }); - assert_eq!(p1_vec.len(), 1); - let p1 = p1_vec[0]; - let mut p2_vec = Vec::new(); - context.with_query_results(with!(Person, Id(record.v2)), &mut |people| { - p2_vec = people.to_owned_vec() - }); - assert_eq!(p2_vec.len(), 1); - let p2 = p2_vec[0]; - context.add_edge_bidi(p1, p2, 1.0, inner.clone()).unwrap(); + let record: (u16, u16) = result.expect("Failed to parse edge"); + let p1 = get_entity_id_by_data_id(context, record.0); + let p2 = get_entity_id_by_data_id(context, record.1); + add_bidi_edge(context, p1, p2, relative_rate); } } -pub fn init(context: &mut Context, people: &[PersonId]) { - // Create dense household networks - create_household_networks(context, people); +// Assuming that time moves in steps of duration, what is the per-step probability of transmission? +fn sar_to_prob(sar: f64, infectious_period: f64, duration: f64) -> f64 { + 1.0 - (1.0 - sar).powf(duration / infectious_period) +} - // Add U5 edges from csv - load_edge_list(context, "AgeUnder5Edges.csv", AgeUnder5); +/// Get all the effective contacts a person will have over a certain duration +pub fn get_contacts(context: &Context, person_id: PersonId, duration: f64) -> HashSet { + let parameters = context + .get_global_property_value(Parameters) + .unwrap() + .clone(); - // Add U18 edges from csv - load_edge_list(context, "Age5to17Edges.csv", Age5to17); + // Base probability of transmission during the duration. + let base_prob = duration * sar_to_prob(parameters.sar, parameters.incubation_period, duration); + assert!(base_prob <= 1.0); + + // Find all the people this person has edges to. Those people are contacts in this + // duration, with a certain probability + let mut contacts = HashSet::new(); + for edge in context.query(with!(Edge, Node1(person_id))) { + let RelativeRate(relative_rate) = context.get_property(edge); + let Node2(person2) = context.get_property(edge); + if context.sample_distr( + NetworkRng, + Bernoulli::new(base_prob * relative_rate).unwrap(), + ) { + contacts.insert(person2); + } + } + + contacts +} + +// `rr`: relative rate of transmission between (vs. within) households +pub fn init(context: &mut Context, relative_rate: f64) { + // Create dense household networks + create_household_networks(context); + // Add other edges from csv's with lower transmission rate + load_edge_list(context, "AgeUnder5Edges.csv", relative_rate); + load_edge_list(context, "Age5to17Edges.csv", relative_rate); } #[cfg(test)] mod tests { - use super::*; use crate::{loader, network}; - const N_SIZE_12: usize = 1; - const N_SIZE_11: usize = 1; - const N_SIZE_3: usize = 122; - - #[test] - fn test_expected_12_member_household() { + // Assert that person with `id` has `n` contacts (i.e., edges going from + // them, and also edges going to them) + fn assert_has_n_contacts(id: u16, n: usize) { let mut context = Context::new(); - context.init_random(42); - let people = loader::init(&mut context); - network::init(&mut context, &people); - let deg11 = context.find_entities_by_degree::(11); - assert_eq!(deg11.len(), 12 * N_SIZE_12); + loader::init(&mut context); + network::init(&mut context, 1.0); + + let eid = get_entity_id_by_data_id(&mut context, id); + + let n_to = context.query_entity_count(with!(Edge, Node1(eid))); + let n_from = context.query_entity_count(with!(Edge, Node2(eid))); + assert_eq!(n_to, n); + assert_eq!(n_from, n); } #[test] - fn test_expected_11_member_household() { - let mut context = Context::new(); - context.init_random(42); - let people = loader::init(&mut context); - network::init(&mut context, &people); - let deg10 = context.find_entities_by_degree::(10); - assert_eq!(deg10.len(), 11 * N_SIZE_11); + fn test_person_826() { + // Person 826 is in a household of 5 with no other contacts. + // There should be 4 edges going from them, and 4 going to them. + assert_has_n_contacts(826, 4); } #[test] - fn test_expected_3_member_household() { - let mut context = Context::new(); - context.init_random(42); - let people = loader::init(&mut context); - network::init(&mut context, &people); - let deg10 = context.find_entities_by_degree::(2); - assert_eq!(deg10.len(), 3 * N_SIZE_3); + fn test_person_243() { + // Person 243 is in a household of size 6 (i.e., 5 hh contacts) + // and has 4 other contacts + assert_has_n_contacts(243, 5 + 4); } } diff --git a/examples/network-hhmodel/parameters.rs b/examples/network-hhmodel/parameters.rs index a32c4bce..8e0b43ef 100644 --- a/examples/network-hhmodel/parameters.rs +++ b/examples/network-hhmodel/parameters.rs @@ -11,7 +11,7 @@ pub struct ParametersValues { pub sar: f64, pub shape: f64, pub infection_duration: f64, - pub between_hh_transmission_reduction: f64, + pub relative_rate: f64, pub output_dir: String, pub data_dir: String, } diff --git a/examples/network-hhmodel/seir.rs b/examples/network-hhmodel/seir.rs index 80aa7b80..3609187b 100644 --- a/examples/network-hhmodel/seir.rs +++ b/examples/network-hhmodel/seir.rs @@ -1,11 +1,10 @@ use ixa::log::info; -use ixa::network::edge::EdgeType; use ixa::prelude::*; -use ixa::{impl_property, ExecutionPhase}; -use rand_distr::{Bernoulli, Gamma}; +use ixa::{impl_property, ExecutionPhase, HashSet, HashSetExt}; +use rand_distr::Gamma; use serde::{Deserialize, Serialize}; -use crate::network::{Age5to17, AgeUnder5, Household}; +use crate::network::get_contacts; use crate::parameters::Parameters; use crate::{Person, PersonId}; @@ -27,37 +26,20 @@ define_property!( default_const = InfectedBy(None) ); -fn sar_to_beta(sar: f64, infectious_period: f64) -> f64 { - 1.0 - (1.0 - sar).powf(1.0 / infectious_period) -} - fn calculate_waiting_time(context: &Context, shape: f64, mean_period: f64) -> f64 { let d = Gamma::new(shape, mean_period / shape).unwrap(); context.sample_distr(SeirRng, d) } -fn expose_network>(context: &mut Context, beta: f64) { - let infectious_people = context - .query(with!(Person, DiseaseStatus::I)) - .to_owned_vec(); - - for infectious in infectious_people { - let edges = context.get_matching_edges::(infectious, |context, edge| { - context.match_entity(edge.neighbor, with!(Person, DiseaseStatus::S)) - }); - - for e in edges { - if context.sample_distr(SeirRng, Bernoulli::new(beta).unwrap()) { - context.set_property(e.neighbor, DiseaseStatus::E); - info!( - "Person {} exposed person {} at time {}.", - infectious, - e.neighbor, - context.get_current_time() - ); - context.set_property(e.neighbor, InfectedBy(Some(infectious))); - } - } +fn expose(context: &mut Context, infector: PersonId, infectee: PersonId) { + let infectee_status: DiseaseStatus = context.get_property(infectee); + if infectee_status == DiseaseStatus::S { + info!( + "{infector:?} exposed {infectee:?} at time {}.", + context.get_current_time() + ); + context.set_property(infectee, DiseaseStatus::E); + context.set_property(infectee, InfectedBy(Some(infector))); } } @@ -68,15 +50,15 @@ fn schedule_waiting_event( mean_period: f64, new_status: DiseaseStatus, ) { - let ct = context.get_current_time(); - let waiting_time = calculate_waiting_time(context, shape, mean_period); + let t = context.get_current_time() + calculate_waiting_time(context, shape, mean_period); - context.add_plan(ct + waiting_time, move |context| { + context.add_plan(t, move |context| { + trace!("{person_id:?} changed to disease state {new_status:?} at t={t:?}"); context.set_property(person_id, new_status); }); } -fn schedule_infection(context: &mut Context, person_id: PersonId) { +fn schedule_infectiousness(context: &mut Context, person_id: PersonId) { let parameters = context .get_global_property_value(Parameters) .unwrap() @@ -106,41 +88,29 @@ fn schedule_recovery(context: &mut Context, person_id: PersonId) { ); } -pub fn init(context: &mut Context, initial_infections: &Vec) { +pub fn init(context: &mut Context, initial_infections: &Vec, period: f64) { context.add_periodic_plan_with_phase( - 1.0, - |context| { - let parameters = context - .get_global_property_value(Parameters) - .unwrap() - .clone(); - - // infect the networks - expose_network::( - context, - sar_to_beta(parameters.sar, parameters.incubation_period), - ); - expose_network::( - context, - sar_to_beta( - parameters.sar / parameters.between_hh_transmission_reduction, - parameters.incubation_period, - ), - ); - expose_network::( - context, - sar_to_beta( - parameters.sar / parameters.between_hh_transmission_reduction, - parameters.incubation_period, - ), - ); + period, + move |context| { + // get all infector-infectee pairs + let mut pairs = HashSet::new(); + for infector in context.query(with!(Person, DiseaseStatus::I)) { + for infectee in get_contacts(context, infector, period) { + pairs.insert((infector, infectee)); + } + } + + // do the exposures + for (infector, infectee) in pairs { + expose(context, infector, infectee) + } }, ExecutionPhase::Normal, ); context.subscribe_to_event( move |context, event: PropertyChangeEvent| match event.current { - DiseaseStatus::E => schedule_infection(context, event.entity_id), + DiseaseStatus::E => schedule_infectiousness(context, event.entity_id), DiseaseStatus::I => schedule_recovery(context, event.entity_id), _ => (), }, @@ -165,20 +135,16 @@ mod tests { #[test] fn test_disease_status() { let mut context = Context::new(); - context.init_random(42); + loader::init(&mut context); - let people = loader::init(&mut context); - - // set sar and between_hh_transmission_reduction to 1.0 so that - // beta is 1.0 let parameters = ParametersValues { incubation_period: 8.0, infectious_period: 27.0, sar: 1.0, shape: 15.0, infection_duration: 5.0, - between_hh_transmission_reduction: 1.0, + relative_rate: 0.5, data_dir: "examples/network-hhmodel/tests".to_owned(), output_dir: "examples/network-hhmodel/tests".to_owned(), }; @@ -186,31 +152,27 @@ mod tests { .set_global_property_value(Parameters, parameters) .unwrap(); - network::init(&mut context, &people); - - let mut to_infect = Vec::::new(); - context.with_query_results(with!(Person, Id(71)), &mut |people| { - to_infect.extend(people); - }); + network::init(&mut context, 1.0); - init(&mut context, &to_infect); + let to_infect = context.query(with!(Person, Id(71))).into_iter().collect(); + init(&mut context, &to_infect, 1.0); context.execute(); assert_eq!( - context.query_entity_count::(with!(Person, DiseaseStatus::S)), + context.query_entity_count(with!(Person, DiseaseStatus::S)), 399 ); assert_eq!( - context.query_entity_count::(with!(Person, DiseaseStatus::E)), + context.query_entity_count(with!(Person, DiseaseStatus::E)), 0 ); assert_eq!( - context.query_entity_count::(with!(Person, DiseaseStatus::I)), + context.query_entity_count(with!(Person, DiseaseStatus::I)), 0 ); assert_eq!( - context.query_entity_count::(with!(Person, DiseaseStatus::R)), + context.query_entity_count(with!(Person, DiseaseStatus::R)), 1207 ); }