Skip to content

Commit

Permalink
Add unity normalization constant to unfixed leaf nodes
Browse files Browse the repository at this point in the history
  • Loading branch information
hyanwong committed Aug 22, 2022
1 parent de1c751 commit 974038d
Showing 1 changed file with 9 additions and 3 deletions.
12 changes: 9 additions & 3 deletions tsdate/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -649,9 +649,7 @@ def inside_pass(self, *, normalize=True, cache_inside=False, progress=None):
)
# It is possible that a simple node is non-fixed, in which case we want to
# provide an inside array that reflects the prior distribution
nonfixed_samples = np.intersect1d(
self.priors.nonfixed_node_ids(), self.ts.samples()
)
nonfixed_samples = np.intersect1d(inside.nonfixed_node_ids(), self.ts.samples())
for u in nonfixed_samples:
# this is in the same probability space as the prior, so we should be
# OK just to copy the prior values straight in (but we should check they
Expand All @@ -663,6 +661,8 @@ def inside_pass(self, *, normalize=True, cache_inside=False, progress=None):
(self.ts.num_edges, self.lik.grid_size), self.lik.identity_constant
)
norm = np.full(self.ts.num_nodes, np.nan)
to_visit = np.zeros(self.ts.num_nodes, dtype=bool)
to_visit[inside.nonfixed_node_ids()] = True
# Iterate through the nodes via groupby on parent node
for parent, edges in tqdm(
self.edges_by_parent_asc(),
Expand Down Expand Up @@ -707,6 +707,12 @@ def inside_pass(self, *, normalize=True, cache_inside=False, progress=None):
g_i[edge.id] = edge_lik
norm[parent] = np.max(val) if normalize else 1
inside[parent] = self.lik.reduce(val, norm[parent])
to_visit[parent] = False

# There may be nodes that are not parents but are also not fixed (e.g.
# undated sample nodes). These need an identity normalization constant
for unfixed_unvisited in np.where(to_visit)[0]:
norm[unfixed_unvisited] = 1

if cache_inside:
self.g_i = self.lik.reduce(g_i, norm[self.ts.tables.edges.child, None])
Expand Down

0 comments on commit 974038d

Please sign in to comment.