Skip to content

Commit

Permalink
removed foward/backward options
Browse files Browse the repository at this point in the history
  • Loading branch information
grantexley committed Dec 5, 2024
1 parent 305a72b commit 46f671e
Showing 1 changed file with 10 additions and 20 deletions.
30 changes: 10 additions & 20 deletions fggs/sum_product.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,8 +158,10 @@ def J_precompute_products(fgg: FGG, x: MultiTensor, inputs: MultiTensor, semirin
if len(edges) == 1:
suffix_products = [sum_product_edges(fgg, rule.rhs.nodes(), set(), rule.rhs.ext + edges[0].nodes, x, inputs, semiring=semiring)]
else:
prefix_products, prefix_output_nodes = compute_products(fgg, (x, inputs), edges, rule, semiring, "forward")
suffix_products, suffix_output_nodes = compute_products(fgg, (x, inputs), edges, rule, semiring, "backward")
prefix_products, prefix_output_nodes = compute_products(fgg, (x, inputs), edges, rule, semiring)
suffix_products, suffix_output_nodes = compute_products(fgg, (x, inputs), list(reversed(edges)), rule, semiring)
suffix_products.reverse()
suffix_output_nodes.reverse()
for i, edge in enumerate(edges):
if i == 0:
product = suffix_products[0]
Expand Down Expand Up @@ -193,33 +195,24 @@ def J_precompute_products(fgg: FGG, x: MultiTensor, inputs: MultiTensor, semirin
J_inputs.add_single((n, edge.label), product)
return Jx

def compute_products(fgg: FGG, inputses: Iterable[MultiTensor], edges: List[Edge], rule: HRGRule, semiring: Semiring, direction: str) -> Tuple[List[Optional[PatternedTensor]], List[Optional[List[Node]]]]:
if direction == 'forward':
future_nodes_slice: Iterable[Edge] = reversed(edges[1:])
edge_loop_slice: Iterable[Edge] = edges[:-1]
elif direction == 'backward':
future_nodes_slice = edges[:-1]
edge_loop_slice = reversed(edges[1:])
else:
raise ValueError("Direction must be 'forward' or 'backward'.")

def compute_products(fgg: FGG, inputses: Iterable[MultiTensor], edges: List[Edge], rule: HRGRule, semiring: Semiring) -> Tuple[List[Optional[PatternedTensor]], List[Optional[List[Node]]]]:
future_nodes_lookup = []
future_nodes = []
seen_nodes = set()
for edge in future_nodes_slice:
for edge in reversed(edges[1:]):
for node in edge.nodes:
if node not in seen_nodes:
future_nodes.append(node)
seen_nodes.add(node)
future_nodes_lookup.append(future_nodes.copy())

future_nodes_lookup.append(future_nodes.copy())
future_nodes_lookup.reverse()

products: List[Optional[PatternedTensor]] = []
output_nodes: List[Optional[List[Node]]] = []
previous_weight: PatternedTensor = PatternedTensor.eye(1, semiring=semiring)
previous_output_nodes: List[Node] = []
for i, edge in enumerate(edge_loop_slice):
for i, edge in enumerate(edges[:-1]):
weight = get_weight(edge, inputses)
if weight is None:
products.extend( [None] * (len(edges) - 1 - i))
Expand All @@ -230,10 +223,7 @@ def compute_products(fgg: FGG, inputses: Iterable[MultiTensor], edges: List[Edge
products.append(out)
output_nodes.append(out_nodes)
previous_weight, previous_output_nodes = out, out_nodes

if direction == 'backward':
output_nodes.reverse()
products.reverse()

return products, output_nodes

def multiply_next_edge(fgg: FGG, previous_weight: PatternedTensor, previous_nodes: List[Node],
Expand Down

0 comments on commit 46f671e

Please sign in to comment.