Skip to content

Commit

Permalink
multiply disconnected internal nodes
Browse files Browse the repository at this point in the history
  • Loading branch information
grantexley committed Nov 7, 2024
1 parent c8d1d27 commit dcacdb3
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions fggs/sum_product.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,10 +258,10 @@ def multiply_next_edge(fgg: FGG, previous_weight: PatternedTensor, previous_node

assert(all(tensor.physical.dtype == semiring.dtype for tensor in tensors))
out = einsum(tensors, indexing, output_nodes, semiring)
# if out.ndim < len(ext):
# eshape = fgg.shape(ext)
# vshape = [s if n in connected else 1 for n, s in zip(ext, eshape)]
# out = out.view(*vshape).expand(*eshape)
if out.ndim < len(output_nodes):
eshape = fgg.shape(output_nodes)
vshape = [s if n in connected else 1 for n, s in zip(output_nodes, eshape)]
out = out.view(*vshape).expand(*eshape)

# Multiply in any disconnected internal nodes.
multiplier = 1
Expand Down

0 comments on commit dcacdb3

Please sign in to comment.