Skip to content

Commit

Permalink
list to List
Browse files Browse the repository at this point in the history
  • Loading branch information
grantexley committed Nov 6, 2024
1 parent 53bf08b commit ef8612e
Showing 1 changed file with 9 additions and 9 deletions.
18 changes: 9 additions & 9 deletions fggs/sum_product.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,8 +172,8 @@ def J(fgg: FGG, x: MultiTensor, inputs: MultiTensor, semiring: Semiring,
#casts to from optional to satisfy mypy
prefix_product: PatternedTensor = cast(PatternedTensor, prefix_products[i - 1])
suffix_product: PatternedTensor = cast(PatternedTensor, suffix_products[i])
prefix_output_node: list[Node] = cast(list[Node], prefix_output_nodes[i - 1])
suffix_output_node: list[Node] = cast(list[Node], suffix_output_nodes[i])
prefix_output_node: List[Node] = cast(List[Node], prefix_output_nodes[i - 1])
suffix_output_node: List[Node] = cast(List[Node], suffix_output_nodes[i])
product = einsum([prefix_product, suffix_product],
[prefix_output_node, suffix_output_node],
rule.rhs.ext + edge.nodes, semiring)
Expand All @@ -184,7 +184,7 @@ def J(fgg: FGG, x: MultiTensor, inputs: MultiTensor, semiring: Semiring,
J_inputs.add_single((n, edge.label), product)
return Jx

def compute_products(fgg: FGG, inputses: Iterable[MultiTensor], edges: list[Edge], rule_ext: list[Node], semiring: Semiring, direction: str) -> Tuple[list[Optional[PatternedTensor]], list[Optional[list[Node]]]]:
def compute_products(fgg: FGG, inputses: Iterable[MultiTensor], edges: List[Edge], rule_ext: List[Node], semiring: Semiring, direction: str) -> Tuple[List[Optional[PatternedTensor]], List[Optional[List[Node]]]]:
if direction == 'forward':
future_nodes_slice: Iterable[Edge] = reversed(edges[1:])
edge_loop_slice: Iterable[Edge] = edges[:-1]
Expand All @@ -206,10 +206,10 @@ def compute_products(fgg: FGG, inputses: Iterable[MultiTensor], edges: list[Edge
future_nodes_lookup.append(future_nodes.copy())
future_nodes_lookup.reverse()

products: list[Optional[PatternedTensor]] = []
output_nodes: list[Optional[list[Node]]] = []
products: List[Optional[PatternedTensor]] = []
output_nodes: List[Optional[List[Node]]] = []
previous_weight: PatternedTensor = PatternedTensor.eye(1, semiring=semiring)
previous_output_nodes: list[Node] = []
previous_output_nodes: List[Node] = []
for i, edge in enumerate(edge_loop_slice):
for inputs in inputses:
if edge.label in inputs:
Expand All @@ -233,9 +233,9 @@ def compute_products(fgg: FGG, inputses: Iterable[MultiTensor], edges: list[Edge

return products, output_nodes

def multiply_next_edge(fgg: FGG, previous_weight: PatternedTensor, previous_nodes: list[Node],
current_weight: PatternedTensor, current_nodes: list[Node],
ext: Iterable[Node], semiring: Semiring) -> Tuple[PatternedTensor, list[Node]]:
def multiply_next_edge(fgg: FGG, previous_weight: PatternedTensor, previous_nodes: List[Node],
current_weight: PatternedTensor, current_nodes: List[Node],
ext: Iterable[Node], semiring: Semiring) -> Tuple[PatternedTensor, List[Node]]:

indexing: List[Sequence[Node]] = [previous_nodes, current_nodes]
tensors: List[PatternedTensor] = [previous_weight, current_weight]
Expand Down

0 comments on commit ef8612e

Please sign in to comment.