Skip to content

Commit

Permalink
refactor J
Browse files Browse the repository at this point in the history
  • Loading branch information
grantexley committed Oct 3, 2024
1 parent af0fab5 commit bbdcdc5
Showing 1 changed file with 75 additions and 8 deletions.
83 changes: 75 additions & 8 deletions fggs/sum_product.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ def J(fgg: FGG, x: MultiTensor, inputs: MultiTensor, semiring: Semiring,
for i, edge in enumerate(edges[:-1]):
#duplicate = ##print_duplicate('J', rule.rhs.ext, edge.nodes)
ext = rule.rhs.ext + edge.nodes
tau_edge = sum_product_edges(fgg, rule.rhs.nodes(), {edge}, ext, x, inputs, semiring=semiring)
tau_edge = single_spe(fgg, rule.rhs.nodes(), edge, ext, x, inputs, semiring=semiring)

if tau_edge is None:
prefix_products.extend( [None] * (len(edges) - 1 - i))
Expand All @@ -178,7 +178,7 @@ def J(fgg: FGG, x: MultiTensor, inputs: MultiTensor, semiring: Semiring,
for i, edge in enumerate(reversed(edges[1:])):
#duplicate = #print_duplicate('J', rule.rhs.ext, edge.nodes)
ext = rule.rhs.ext + edge.nodes
tau_edge = sum_product_edges(fgg, rule.rhs.nodes(), {edge}, ext, x, inputs, semiring=semiring)
tau_edge = single_spe(fgg, rule.rhs.nodes(), edge, ext, x, inputs, semiring=semiring)

if tau_edge is None:
suffix_products.extend( [None] * (len(edges) - 1 - i))
Expand All @@ -194,12 +194,6 @@ def J(fgg: FGG, x: MultiTensor, inputs: MultiTensor, semiring: Semiring,
if not suffix_products or not prefix_products:
continue

if all([x is None for x in suffix_products]):
continue

if all([x is None for x in prefix_products]):
continue

for i, edge in enumerate(edges):
if i == 0:
product = suffix_products[0]
Expand All @@ -208,6 +202,7 @@ def J(fgg: FGG, x: MultiTensor, inputs: MultiTensor, semiring: Semiring,
else:
if prefix_products[i - 1] is None or suffix_products[i] is None:
continue

product = semiring.mul(prefix_products[i - 1], suffix_products[i])
if product is not None:
# if duplicate:
Expand Down Expand Up @@ -268,6 +263,78 @@ def J_log(fgg: FGG, x: MultiTensor, inputs: MultiTensor, semiring: Semiring,
assert False
return Jx

def single_spe(fgg: FGG, nodes: Iterable[Node], edge: Edge, ext: Sequence[Node], *inputses: MultiTensor, semiring: Semiring) -> Optional[PatternedTensor]:
"""Compute the sum-product of a single edge.
Parameters:
- fgg
- nodes: all the nodes, even disconnected ones
- edge: the edges whose factors are multiplied together
- ext: the nodes whose values are not summed over
- inputses: dicts of sum-products of nonterminals that have
already been computed. Later elements of inputses override
earlier elements.
Return: the tensor of sum-products, or None if zero
"""

connected: Set[Node] = set()
indexing: List[Sequence[Node]] = []
tensors: List[PatternedTensor] = []

# Derivatives can sometimes produce duplicate external nodes.
# Rename them apart and add identity factors between them.
ext_orig = ext
ext = []
duplicate = False
for n in ext_orig:
if n in ext:
ncopy = Node(n.label)
ext.append(ncopy)
connected.update([n, ncopy])
indexing.append([n, ncopy])
nsize = fgg.domains[n.label.name].size()
tensors.append(PatternedTensor.eye(nsize,semiring))
#duplicate = True # Uncomment this line for debugging messages
else:
ext.append(n)

connected.update(edge.nodes)
indexing.append(edge.nodes)
for inputs in reversed(inputses):
if edge.label in inputs:
tensors.append(inputs[edge.label])
break
else:
# One argument to einsum will be the zero tensor, so just return zero
return None

# If an external node has no edges, einsum will complain, so remove it.
outputs = [node for node in ext if node in connected]


assert(all(tensor.physical.dtype == semiring.dtype for tensor in tensors))
out = einsum(tensors, indexing, outputs, semiring)
if duplicate:
print('einsum produced', out.physical.size(), 'for', out.size(), file=stderr)

# Restore any external nodes that were removed.
if out.ndim < len(ext):
eshape = fgg.shape(ext)
vshape = [s if n in connected else 1 for n, s in zip(ext, eshape)]
out = out.view(*vshape).expand(*eshape)


# Multiply in any disconnected internal nodes.
multiplier = 1
for n in nodes:
if n not in connected and n not in ext:
multiplier *= fgg.domains[n.label.name].size()
if multiplier != 1:
out = semiring.mul(out, PatternedTensor.from_int(multiplier, semiring))
assert(out.physical.dtype == semiring.dtype)
return out


def sum_product_edges(fgg: FGG, nodes: Iterable[Node], edges: Iterable[Edge], ext: Sequence[Node], *inputses: MultiTensor, semiring: Semiring) -> Optional[PatternedTensor]:
"""Compute the sum-product of a set of edges.
Expand Down

0 comments on commit bbdcdc5

Please sign in to comment.