Skip to content

Commit

Permalink
added einsum multiplication
Browse files Browse the repository at this point in the history
  • Loading branch information
grantexley committed Oct 10, 2024
1 parent bbdcdc5 commit 580905e
Showing 1 changed file with 86 additions and 120 deletions.
206 changes: 86 additions & 120 deletions fggs/sum_product.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,14 +124,25 @@ def print_duplicate(loc: str, xs: Sequence[Node], ys: Sequence[Node]) -> bool:
else:
return False

def get_weight(fgg: FGG, inputses: MultiTensor, edge: Edge) -> Optional[PatternedTensor]:
""" the weight of an edge is the sum-product of the edge's label if nonterminal,
or the weight of its factor if terminal
"""
if edge.label.is_terminal:
return fgg.factors[edge.label.name].weights
else:
for inputs in inputses:
if edge.label in inputs:
return inputs[edge.label]
return None

def _J(fgg: FGG, x: MultiTensor, inputs: MultiTensor, semiring: Semiring,
J_inputs: Optional[MultiTensor] = None) -> MultiTensor:
"""The Jacobian of F."""
Jx = MultiTensor(x.shapes+x.shapes, semiring)
for n in x.shapes[0]:
for rule in fgg.rules(n):
for edge in rule.rhs.edges():
for i, edge in enumerate(rule.rhs.edges()):
if edge.label not in Jx.shapes[1] and J_inputs is None:
continue
duplicate = print_duplicate('J', rule.rhs.ext, edge.nodes)
Expand All @@ -147,7 +158,6 @@ def _J(fgg: FGG, x: MultiTensor, inputs: MultiTensor, semiring: Semiring,
J_inputs.add_single((n, edge.label), tau_edge)
else:
assert False

return Jx

def J(fgg: FGG, x: MultiTensor, inputs: MultiTensor, semiring: Semiring,
Expand All @@ -156,66 +166,94 @@ def J(fgg: FGG, x: MultiTensor, inputs: MultiTensor, semiring: Semiring,
Jx = MultiTensor(x.shapes+x.shapes, semiring)
for n in x.shapes[0]:
for rule in fgg.rules(n):

edges = list(rule.rhs.edges())

prefix_products = []
for i, edge in enumerate(edges[:-1]):
#duplicate = ##print_duplicate('J', rule.rhs.ext, edge.nodes)
ext = rule.rhs.ext + edge.nodes
tau_edge = single_spe(fgg, rule.rhs.nodes(), edge, ext, x, inputs, semiring=semiring)

if tau_edge is None:
prefix_products.extend( [None] * (len(edges) - 1 - i))
break

if len(prefix_products) == 0:
prefix_products.append(tau_edge)
else:
prefix_products.append(semiring.mul(prefix_products[-1], tau_edge))
if len(edges) == 1:
suffix_products = [sum_product_edges(fgg, rule.rhs.nodes(), set(), rule.rhs.ext + edges[0].nodes, x, inputs, semiring=semiring)]
else:
#precompute future nodes
future_nodes_lookup = []
future_nodes = set()
for edge in reversed(edges[1:]):
future_nodes.update(edge.nodes)
future_nodes_lookup.insert(0, future_nodes)

prefix_products = []
prefix_output_nodes = []
for i, edge in enumerate(edges[:-1]):

weight = get_weight(fgg, (inputs, x), edge)

if weight is None:
prefix_products.extend( [None] * (len(edges) - 1 - i))
prefix_output_nodes.extend( [None] * (len(edges) - 1 - i))
break

if i == 0:
previous_weight = PatternedTensor.eye(1, semiring=semiring)
previous_output_nodes = []
else:
previous_weight = prefix_products[-1]
previous_output_nodes = prefix_output_nodes[-1]

#output_nodes = list(rule.rhs.ext) + list(future_nodes_lookup[i])
output_nodes = [node for node in (list(rule.rhs.ext) + list(future_nodes_lookup[i])) if node in edge.nodes or node in previous_output_nodes]
out = einsum([previous_weight, weight], [previous_output_nodes, edge.nodes], output_nodes, semiring)

prefix_products.append(out)
prefix_output_nodes.append(output_nodes)

#precompute past nodes
past_nodes_lookup = []
past_nodes = set()
for edge in edges[:-1]:
past_nodes.update(edge.nodes)
past_nodes_lookup.insert(0, past_nodes)

suffix_products = []
suffix_output_nodes = []
for i, edge in enumerate(reversed(edges[1:])):

weight = get_weight(fgg, (inputs, x), edge)

suffix_products = []
for i, edge in enumerate(reversed(edges[1:])):
#duplicate = #print_duplicate('J', rule.rhs.ext, edge.nodes)
ext = rule.rhs.ext + edge.nodes
tau_edge = single_spe(fgg, rule.rhs.nodes(), edge, ext, x, inputs, semiring=semiring)

if tau_edge is None:
suffix_products.extend( [None] * (len(edges) - 1 - i))
break
if weight is None:
suffix_products.extend( [None] * (len(edges) - 1 - i))
suffix_output_nodes.extend( [None] * (len(edges) - 1 - i))
break

if len(suffix_products) == 0:
suffix_products.append(tau_edge)
else:
suffix_products.append(semiring.mul(suffix_products[-1], tau_edge))
if i == 0:
previous_weight = PatternedTensor.eye(1, semiring=semiring)
previous_output_nodes = []
else:
previous_weight = suffix_products[-1]
previous_output_nodes = suffix_output_nodes[-1]

suffix_products.reverse()

if not suffix_products or not prefix_products:
continue

#output_nodes = list(rule.rhs.ext) + list(past_nodes_lookup[i])
output_nodes = [node for node in (list(rule.rhs.ext) + list(future_nodes_lookup[i])) if node in edge.nodes or node in previous_output_nodes]
out = einsum([previous_weight, weight], [previous_output_nodes, edge.nodes], output_nodes, semiring)

suffix_products.append(out)
suffix_output_nodes.append(output_nodes)

suffix_output_nodes.reverse()
suffix_products.reverse()

for i, edge in enumerate(edges):
if i == 0:
product = suffix_products[0]
elif i == len(edges) - 1:
product = prefix_products[-1]
else:
if prefix_products[i - 1] is None or suffix_products[i] is None:
continue

product = semiring.mul(prefix_products[i - 1], suffix_products[i])
continue
product = einsum([prefix_products[i-1], suffix_products[i]],
[prefix_output_nodes[i-1], suffix_output_nodes[i]],
rule.rhs.ext + edge.nodes, semiring)
if product is not None:
# if duplicate:
# #print('sum_product_edges produced', product.physical.size(), 'for', product.size(), file=stderr)
if edge.label in Jx.shapes[1]:
Jx.add_single((n, edge.label), product)
elif J_inputs is not None and edge.label in J_inputs.shapes[1]:
J_inputs.add_single((n, edge.label), product)
else:
continue
assert False


return Jx


Expand Down Expand Up @@ -263,79 +301,6 @@ def J_log(fgg: FGG, x: MultiTensor, inputs: MultiTensor, semiring: Semiring,
assert False
return Jx

def single_spe(fgg: FGG, nodes: Iterable[Node], edge: Edge, ext: Sequence[Node], *inputses: MultiTensor, semiring: Semiring) -> Optional[PatternedTensor]:
"""Compute the sum-product of a single edge.
Parameters:
- fgg
- nodes: all the nodes, even disconnected ones
- edge: the edges whose factors are multiplied together
- ext: the nodes whose values are not summed over
- inputses: dicts of sum-products of nonterminals that have
already been computed. Later elements of inputses override
earlier elements.
Return: the tensor of sum-products, or None if zero
"""

connected: Set[Node] = set()
indexing: List[Sequence[Node]] = []
tensors: List[PatternedTensor] = []

# Derivatives can sometimes produce duplicate external nodes.
# Rename them apart and add identity factors between them.
ext_orig = ext
ext = []
duplicate = False
for n in ext_orig:
if n in ext:
ncopy = Node(n.label)
ext.append(ncopy)
connected.update([n, ncopy])
indexing.append([n, ncopy])
nsize = fgg.domains[n.label.name].size()
tensors.append(PatternedTensor.eye(nsize,semiring))
#duplicate = True # Uncomment this line for debugging messages
else:
ext.append(n)

connected.update(edge.nodes)
indexing.append(edge.nodes)
for inputs in reversed(inputses):
if edge.label in inputs:
tensors.append(inputs[edge.label])
break
else:
# One argument to einsum will be the zero tensor, so just return zero
return None

# If an external node has no edges, einsum will complain, so remove it.
outputs = [node for node in ext if node in connected]


assert(all(tensor.physical.dtype == semiring.dtype for tensor in tensors))
out = einsum(tensors, indexing, outputs, semiring)
if duplicate:
print('einsum produced', out.physical.size(), 'for', out.size(), file=stderr)

# Restore any external nodes that were removed.
if out.ndim < len(ext):
eshape = fgg.shape(ext)
vshape = [s if n in connected else 1 for n, s in zip(ext, eshape)]
out = out.view(*vshape).expand(*eshape)


# Multiply in any disconnected internal nodes.
multiplier = 1
for n in nodes:
if n not in connected and n not in ext:
multiplier *= fgg.domains[n.label.name].size()
if multiplier != 1:
out = semiring.mul(out, PatternedTensor.from_int(multiplier, semiring))
assert(out.physical.dtype == semiring.dtype)
return out


def sum_product_edges(fgg: FGG, nodes: Iterable[Node], edges: Iterable[Edge], ext: Sequence[Node], *inputses: MultiTensor, semiring: Semiring) -> Optional[PatternedTensor]:
"""Compute the sum-product of a set of edges.
Expand Down Expand Up @@ -371,23 +336,24 @@ def sum_product_edges(fgg: FGG, nodes: Iterable[Node], edges: Iterable[Edge], ex
#duplicate = True # Uncomment this line for debugging messages
else:
ext.append(n)

for edge in edges:
connected.update(edge.nodes)
indexing.append(edge.nodes)
for inputs in reversed(inputses):
if edge.label in inputs:
tensors.append(inputs[edge.label])
#print(f'EDGE: {edge.label.name}, TENSOR: {tensors[-1]}')
break
else:
# One argument to einsum will be the zero tensor, so just return zero
return None

# If an external node has no edges, einsum will complain, so remove it.
outputs = [node for node in ext if node in connected]


assert(all(tensor.physical.dtype == semiring.dtype for tensor in tensors))
#print(f'EINSUM CALL: \ntensors: {tensors}\n\nindexing: {indexing}\n\noutputs: {outputs}\n\n')
out = einsum(tensors, indexing, outputs, semiring)
if duplicate:
print('einsum produced', out.physical.size(), 'for', out.size(), file=stderr)
Expand Down

0 comments on commit 580905e

Please sign in to comment.