diff --git a/settings/scott/model.yml b/settings/scott/model.yml index 2a578bf9..6243501f 100644 --- a/settings/scott/model.yml +++ b/settings/scott/model.yml @@ -27,12 +27,14 @@ syringe_services: num_slots_start: 0 num_slots_stop: 237 risk: 0.02 + dx_scalar: 1.195 ssp_on: start_time: 4 stop_time: 121 num_slots_start: 237 num_slots_stop: 237 risk: 0.02 + dx_scalar: 1.195 agent_zero: num_partners: 4 diff --git a/tests/params/basic.yml b/tests/params/basic.yml index b0957de7..a16d806b 100644 --- a/tests/params/basic.yml +++ b/tests/params/basic.yml @@ -862,6 +862,7 @@ syringe_services: num_slots_start: 100 num_slots_stop: 100 risk: 0.02 + dx_scalar: 1.0 agent_zero: bond_type: Inj diff --git a/tests/params/simple_integration.yml b/tests/params/simple_integration.yml index 865f237e..51041a65 100644 --- a/tests/params/simple_integration.yml +++ b/tests/params/simple_integration.yml @@ -191,6 +191,7 @@ syringe_services: num_slots_start: 500 num_slots_stop: 500 risk: 0.00 + dx_scalar: 1.0 partner_tracing: prob: 1 diff --git a/titan/model.py b/titan/model.py index e7fa3544..197d7cb1 100644 --- a/titan/model.py +++ b/titan/model.py @@ -70,6 +70,7 @@ def __init__( self.new_prep = AgentSet("new_prep") self.ssp_enrolled_risk = 0.0 + self.ssp_dx = 1.0 self.time = -1 * self.params.model.time.burn_steps # burn is negative time self.id = nanoid.generate(size=8) @@ -113,10 +114,6 @@ def print_stats(self, stat: Dict[str, Dict[str, int]], outdir: str): ), "Graph must be enabled to print network reports" network_outdir = os.path.join(outdir, "network") - if self.params.outputs.network.draw_figures: - self.network_utils.visualize_network( - network_outdir, curtime=self.time, label=f"{self.id}" - ) if self.params.outputs.network.calc_component_stats: ao.print_components( @@ -899,6 +896,7 @@ def update_syringe_services(self): for item in self.params.syringe_services.timeline.values(): if item.start_time <= self.time < item.stop_time: self.ssp_enrolled_risk = item.risk + self.ssp_dx = item.dx_scalar ssp_num_slots = (item.num_slots_stop - item.num_slots_start) / ( item.stop_time - item.start_time @@ -1102,6 +1100,9 @@ def diagnose( sex_type ].hiv.dx.prob + if agent.ssp: + test_prob *= self.ssp_dx + # Rescale based on calibration param test_prob *= self.calibration.test_frequency diff --git a/titan/network.py b/titan/network.py index 653fbc0f..2aed2f10 100644 --- a/titan/network.py +++ b/titan/network.py @@ -6,8 +6,6 @@ import networkx as nx # type: ignore from networkx.drawing.nx_agraph import graphviz_layout # type: ignore -import matplotlib.pyplot as plt # type: ignore -import matplotlib.patches as patches # type: ignore class NetworkGraphUtils: @@ -163,121 +161,3 @@ def get_network_color(self, coloring) -> List[str]: ) return node_color - - def visualize_network( - self, - outdir: str, - coloring: str = "sex_type", - pos=None, - return_layout: bool = False, - node_size: Optional[float] = None, - curtime: int = 0, - infection_label: int = 0, - label: str = "Network", - ): - """ - Visualize the network using the spring layout (default). - - args: - outdir: directory the figure should be saved to - coloring: what attribute to color the nodes by - pos: a graphviz_layout - return_layout: whether to return the layout (if `False`, nothing is returned) - node_size: size of the nodes in the graph - curtime: the current timestep of the model - infection_label: number of infections to list in figure's label - label: identifier for this network - """ - if node_size is None: - node_size = 5000.0 / self.G.number_of_nodes() - - print(("\tPlotting {} colored by {}...").format(label, coloring)) - fig = plt.figure() - ax = fig.add_axes([0, 0, 1, 1]) - fig.clf() - - # build a rectangle in axes coords - left, width = 0.0, 1.0 - bottom, height = 0.0, 1.0 - right = left + width - top = bottom + height - - fig = plt.figure() - ax = fig.add_axes([0, 0, 1, 1]) - - # axes coordinates are 0,0 is bottom left and 1,1 is upper right - p = patches.Rectangle( - (left, bottom), - width, - height, - fill=False, - transform=ax.transAxes, - clip_on=False, - ) - - ax.add_patch(p) - - if not pos: - pos = graphviz_layout(self.G, prog="neato", args="") - - edge_color = "k" - node_shape = "o" - - # node color to by type - node_color = self.get_network_color(coloring) - - # node size indicating node degree - NodeSize = [] - if node_size: - for v in self.G: - NodeSize.append(node_size) - else: - for v in self.G: - NodeSize.append((10 * self.G.degree(v)) ** (1.0)) - - # draw: - nx.draw( - self.G, - pos, - node_size=NodeSize, - node_color=node_color, - node_shape=node_shape, - edge_color=edge_color, - with_labels=False, - linewidths=0.5, - width=0.5, - ) - - textstr = "\n".join( - ( - r"N infection={:.2f}".format( - infection_label, - ), - r"Time={:.2f}".format( - curtime, - ), - ) - ) - - # these are matplotlib.patch.Patch properties - props = dict(boxstyle="round", facecolor="wheat", alpha=0.9) - - # place a text box in upper right in axes coords - ax.text( - right - 0.025, - top - 0.025, - textstr, - horizontalalignment="right", - verticalalignment="top", - transform=ax.transAxes, - bbox=props, - ) - - filename = os.path.join( - outdir, f"{label}_{self.G.number_of_nodes()}_{coloring}_{curtime}.png" - ) - - fig.savefig(filename) - - if return_layout: - return pos diff --git a/titan/params/syringe_services.yml b/titan/params/syringe_services.yml index 8f258ff5..96609398 100644 --- a/titan/params/syringe_services.yml +++ b/titan/params/syringe_services.yml @@ -23,6 +23,10 @@ syringe_services: description: "Risk of unsafe sharing for agents enrolled in the SSP" min: 0.0 max: 1.0 + dx_scalar: + type: float + description: "Diagnosis scalar for HIV+ agents enrolled in ssp" + min: 0.0 default: ssp_default: start_time: 1 @@ -30,3 +34,4 @@ syringe_services: num_slots_start: 0 num_slots_stop: 0 risk: 0.02 + dx_scalar: 1.0