Skip to content

Commit

Permalink
fixed mypy
Browse files Browse the repository at this point in the history
  • Loading branch information
grantexley committed Nov 6, 2024
1 parent 3ff9936 commit 53bf08b
Showing 1 changed file with 18 additions and 13 deletions.
31 changes: 18 additions & 13 deletions fggs/sum_product.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from torch import Tensor
from fggs.typing import TensorLikeT
from fggs.indices import Nonphysical, PatternedTensor, einsum, stack

from typing import cast
Function = Callable[[MultiTensor], MultiTensor]

def _formatwarning(message, category, filename=None, lineno=None, file=None, line=None):
Expand Down Expand Up @@ -168,21 +168,26 @@ def J(fgg: FGG, x: MultiTensor, inputs: MultiTensor, semiring: Semiring,
elif i == len(edges) - 1:
product = prefix_products[-1]
else:
if prefix_products[i - 1] is None or suffix_products[i] is None: continue
product = einsum([prefix_products[i-1], suffix_products[i]],
[prefix_output_nodes[i-1], suffix_output_nodes[i]],
rule.rhs.ext + edge.nodes, semiring)
if prefix_products[i - 1] is None or suffix_products[i] is None: continue
#casts to from optional to satisfy mypy
prefix_product: PatternedTensor = cast(PatternedTensor, prefix_products[i - 1])
suffix_product: PatternedTensor = cast(PatternedTensor, suffix_products[i])
prefix_output_node: list[Node] = cast(list[Node], prefix_output_nodes[i - 1])
suffix_output_node: list[Node] = cast(list[Node], suffix_output_nodes[i])
product = einsum([prefix_product, suffix_product],
[prefix_output_node, suffix_output_node],
rule.rhs.ext + edge.nodes, semiring)
if product is not None:
if edge.label in Jx.shapes[1]:
Jx.add_single((n, edge.label), product)
elif J_inputs is not None and edge.label in J_inputs.shapes[1]:
J_inputs.add_single((n, edge.label), product)
return Jx

def compute_products(fgg: FGG, inputses: Iterable[MultiTensor], edges: list[Edge], rule_ext: list[Node], semiring: Semiring, direction: str) -> Tuple[list[PatternedTensor], list[list[Node]]]:
def compute_products(fgg: FGG, inputses: Iterable[MultiTensor], edges: list[Edge], rule_ext: list[Node], semiring: Semiring, direction: str) -> Tuple[list[Optional[PatternedTensor]], list[Optional[list[Node]]]]:
if direction == 'forward':
future_nodes_slice = reversed(edges[1:])
edge_loop_slice = edges[:-1]
future_nodes_slice: Iterable[Edge] = reversed(edges[1:])
edge_loop_slice: Iterable[Edge] = edges[:-1]
elif direction == 'backward':
future_nodes_slice = edges[:-1]
edge_loop_slice = reversed(edges[1:])
Expand All @@ -201,10 +206,10 @@ def compute_products(fgg: FGG, inputses: Iterable[MultiTensor], edges: list[Edge
future_nodes_lookup.append(future_nodes.copy())
future_nodes_lookup.reverse()

products = []
output_nodes = []
previous_weight = PatternedTensor.eye(1, semiring=semiring)
previous_output_nodes = []
products: list[Optional[PatternedTensor]] = []
output_nodes: list[Optional[list[Node]]] = []
previous_weight: PatternedTensor = PatternedTensor.eye(1, semiring=semiring)
previous_output_nodes: list[Node] = []
for i, edge in enumerate(edge_loop_slice):
for inputs in inputses:
if edge.label in inputs:
Expand All @@ -230,7 +235,7 @@ def compute_products(fgg: FGG, inputses: Iterable[MultiTensor], edges: list[Edge

def multiply_next_edge(fgg: FGG, previous_weight: PatternedTensor, previous_nodes: list[Node],
current_weight: PatternedTensor, current_nodes: list[Node],
ext: Iterable[Node], semiring: Semiring) -> Tuple[PatternedTensor,Iterable[Node]]:
ext: Iterable[Node], semiring: Semiring) -> Tuple[PatternedTensor, list[Node]]:

indexing: List[Sequence[Node]] = [previous_nodes, current_nodes]
tensors: List[PatternedTensor] = [previous_weight, current_weight]
Expand Down

0 comments on commit 53bf08b

Please sign in to comment.