diff --git a/fggs/sum_product.py b/fggs/sum_product.py index 6a1ebbf..f643c6f 100644 --- a/fggs/sum_product.py +++ b/fggs/sum_product.py @@ -124,6 +124,17 @@ def print_duplicate(loc: str, xs: Sequence[Node], ys: Sequence[Node]) -> bool: else: return False +def get_weight(fgg: FGG, inputses: MultiTensor, edge: Edge) -> Optional[PatternedTensor]: + """ the weight of an edge is the sum-product of the edge's label if nonterminal, + or the weight of its factor if terminal + """ + if edge.label.is_terminal: + return fgg.factors[edge.label.name].weights + else: + for inputs in inputses: + if edge.label in inputs: + return inputs[edge.label] + return None def _J(fgg: FGG, x: MultiTensor, inputs: MultiTensor, semiring: Semiring, J_inputs: Optional[MultiTensor] = None) -> MultiTensor: @@ -131,7 +142,7 @@ def _J(fgg: FGG, x: MultiTensor, inputs: MultiTensor, semiring: Semiring, Jx = MultiTensor(x.shapes+x.shapes, semiring) for n in x.shapes[0]: for rule in fgg.rules(n): - for edge in rule.rhs.edges(): + for i, edge in enumerate(rule.rhs.edges()): if edge.label not in Jx.shapes[1] and J_inputs is None: continue duplicate = print_duplicate('J', rule.rhs.ext, edge.nodes) @@ -147,7 +158,6 @@ def _J(fgg: FGG, x: MultiTensor, inputs: MultiTensor, semiring: Semiring, J_inputs.add_single((n, edge.label), tau_edge) else: assert False - return Jx def J(fgg: FGG, x: MultiTensor, inputs: MultiTensor, semiring: Semiring, @@ -156,44 +166,77 @@ def J(fgg: FGG, x: MultiTensor, inputs: MultiTensor, semiring: Semiring, Jx = MultiTensor(x.shapes+x.shapes, semiring) for n in x.shapes[0]: for rule in fgg.rules(n): - edges = list(rule.rhs.edges()) - - prefix_products = [] - for i, edge in enumerate(edges[:-1]): - #duplicate = ##print_duplicate('J', rule.rhs.ext, edge.nodes) - ext = rule.rhs.ext + edge.nodes - tau_edge = single_spe(fgg, rule.rhs.nodes(), edge, ext, x, inputs, semiring=semiring) - - if tau_edge is None: - prefix_products.extend( [None] * (len(edges) - 1 - i)) - break - - if len(prefix_products) == 0: - prefix_products.append(tau_edge) - else: - prefix_products.append(semiring.mul(prefix_products[-1], tau_edge)) + if len(edges) == 1: + suffix_products = [sum_product_edges(fgg, rule.rhs.nodes(), set(), rule.rhs.ext + edges[0].nodes, x, inputs, semiring=semiring)] + else: + #precompute future nodes + future_nodes_lookup = [] + future_nodes = set() + for edge in reversed(edges[1:]): + future_nodes.update(edge.nodes) + future_nodes_lookup.insert(0, future_nodes) + + prefix_products = [] + prefix_output_nodes = [] + for i, edge in enumerate(edges[:-1]): + + weight = get_weight(fgg, (inputs, x), edge) + + if weight is None: + prefix_products.extend( [None] * (len(edges) - 1 - i)) + prefix_output_nodes.extend( [None] * (len(edges) - 1 - i)) + break + + if i == 0: + previous_weight = PatternedTensor.eye(1, semiring=semiring) + previous_output_nodes = [] + else: + previous_weight = prefix_products[-1] + previous_output_nodes = prefix_output_nodes[-1] + + #output_nodes = list(rule.rhs.ext) + list(future_nodes_lookup[i]) + output_nodes = [node for node in (list(rule.rhs.ext) + list(future_nodes_lookup[i])) if node in edge.nodes or node in previous_output_nodes] + out = einsum([previous_weight, weight], [previous_output_nodes, edge.nodes], output_nodes, semiring) + + prefix_products.append(out) + prefix_output_nodes.append(output_nodes) + + #precompute past nodes + past_nodes_lookup = [] + past_nodes = set() + for edge in edges[:-1]: + past_nodes.update(edge.nodes) + past_nodes_lookup.insert(0, past_nodes) + + suffix_products = [] + suffix_output_nodes = [] + for i, edge in enumerate(reversed(edges[1:])): + + weight = get_weight(fgg, (inputs, x), edge) - suffix_products = [] - for i, edge in enumerate(reversed(edges[1:])): - #duplicate = #print_duplicate('J', rule.rhs.ext, edge.nodes) - ext = rule.rhs.ext + edge.nodes - tau_edge = single_spe(fgg, rule.rhs.nodes(), edge, ext, x, inputs, semiring=semiring) - - if tau_edge is None: - suffix_products.extend( [None] * (len(edges) - 1 - i)) - break + if weight is None: + suffix_products.extend( [None] * (len(edges) - 1 - i)) + suffix_output_nodes.extend( [None] * (len(edges) - 1 - i)) + break - if len(suffix_products) == 0: - suffix_products.append(tau_edge) - else: - suffix_products.append(semiring.mul(suffix_products[-1], tau_edge)) + if i == 0: + previous_weight = PatternedTensor.eye(1, semiring=semiring) + previous_output_nodes = [] + else: + previous_weight = suffix_products[-1] + previous_output_nodes = suffix_output_nodes[-1] - suffix_products.reverse() - - if not suffix_products or not prefix_products: - continue - + #output_nodes = list(rule.rhs.ext) + list(past_nodes_lookup[i]) + output_nodes = [node for node in (list(rule.rhs.ext) + list(future_nodes_lookup[i])) if node in edge.nodes or node in previous_output_nodes] + out = einsum([previous_weight, weight], [previous_output_nodes, edge.nodes], output_nodes, semiring) + + suffix_products.append(out) + suffix_output_nodes.append(output_nodes) + + suffix_output_nodes.reverse() + suffix_products.reverse() + for i, edge in enumerate(edges): if i == 0: product = suffix_products[0] @@ -201,21 +244,16 @@ def J(fgg: FGG, x: MultiTensor, inputs: MultiTensor, semiring: Semiring, product = prefix_products[-1] else: if prefix_products[i - 1] is None or suffix_products[i] is None: - continue - - product = semiring.mul(prefix_products[i - 1], suffix_products[i]) + continue + product = einsum([prefix_products[i-1], suffix_products[i]], + [prefix_output_nodes[i-1], suffix_output_nodes[i]], + rule.rhs.ext + edge.nodes, semiring) if product is not None: - # if duplicate: - # #print('sum_product_edges produced', product.physical.size(), 'for', product.size(), file=stderr) if edge.label in Jx.shapes[1]: Jx.add_single((n, edge.label), product) elif J_inputs is not None and edge.label in J_inputs.shapes[1]: J_inputs.add_single((n, edge.label), product) - else: - continue - assert False - return Jx @@ -263,79 +301,6 @@ def J_log(fgg: FGG, x: MultiTensor, inputs: MultiTensor, semiring: Semiring, assert False return Jx -def single_spe(fgg: FGG, nodes: Iterable[Node], edge: Edge, ext: Sequence[Node], *inputses: MultiTensor, semiring: Semiring) -> Optional[PatternedTensor]: - """Compute the sum-product of a single edge. - - Parameters: - - fgg - - nodes: all the nodes, even disconnected ones - - edge: the edges whose factors are multiplied together - - ext: the nodes whose values are not summed over - - inputses: dicts of sum-products of nonterminals that have - already been computed. Later elements of inputses override - earlier elements. - - Return: the tensor of sum-products, or None if zero - """ - - connected: Set[Node] = set() - indexing: List[Sequence[Node]] = [] - tensors: List[PatternedTensor] = [] - - # Derivatives can sometimes produce duplicate external nodes. - # Rename them apart and add identity factors between them. - ext_orig = ext - ext = [] - duplicate = False - for n in ext_orig: - if n in ext: - ncopy = Node(n.label) - ext.append(ncopy) - connected.update([n, ncopy]) - indexing.append([n, ncopy]) - nsize = fgg.domains[n.label.name].size() - tensors.append(PatternedTensor.eye(nsize,semiring)) - #duplicate = True # Uncomment this line for debugging messages - else: - ext.append(n) - - connected.update(edge.nodes) - indexing.append(edge.nodes) - for inputs in reversed(inputses): - if edge.label in inputs: - tensors.append(inputs[edge.label]) - break - else: - # One argument to einsum will be the zero tensor, so just return zero - return None - - # If an external node has no edges, einsum will complain, so remove it. - outputs = [node for node in ext if node in connected] - - - assert(all(tensor.physical.dtype == semiring.dtype for tensor in tensors)) - out = einsum(tensors, indexing, outputs, semiring) - if duplicate: - print('einsum produced', out.physical.size(), 'for', out.size(), file=stderr) - - # Restore any external nodes that were removed. - if out.ndim < len(ext): - eshape = fgg.shape(ext) - vshape = [s if n in connected else 1 for n, s in zip(ext, eshape)] - out = out.view(*vshape).expand(*eshape) - - - # Multiply in any disconnected internal nodes. - multiplier = 1 - for n in nodes: - if n not in connected and n not in ext: - multiplier *= fgg.domains[n.label.name].size() - if multiplier != 1: - out = semiring.mul(out, PatternedTensor.from_int(multiplier, semiring)) - assert(out.physical.dtype == semiring.dtype) - return out - - def sum_product_edges(fgg: FGG, nodes: Iterable[Node], edges: Iterable[Edge], ext: Sequence[Node], *inputses: MultiTensor, semiring: Semiring) -> Optional[PatternedTensor]: """Compute the sum-product of a set of edges. @@ -371,13 +336,14 @@ def sum_product_edges(fgg: FGG, nodes: Iterable[Node], edges: Iterable[Edge], ex #duplicate = True # Uncomment this line for debugging messages else: ext.append(n) - + for edge in edges: connected.update(edge.nodes) indexing.append(edge.nodes) for inputs in reversed(inputses): if edge.label in inputs: tensors.append(inputs[edge.label]) + #print(f'EDGE: {edge.label.name}, TENSOR: {tensors[-1]}') break else: # One argument to einsum will be the zero tensor, so just return zero @@ -385,9 +351,9 @@ def sum_product_edges(fgg: FGG, nodes: Iterable[Node], edges: Iterable[Edge], ex # If an external node has no edges, einsum will complain, so remove it. outputs = [node for node in ext if node in connected] - assert(all(tensor.physical.dtype == semiring.dtype for tensor in tensors)) + #print(f'EINSUM CALL: \ntensors: {tensors}\n\nindexing: {indexing}\n\noutputs: {outputs}\n\n') out = einsum(tensors, indexing, outputs, semiring) if duplicate: print('einsum produced', out.physical.size(), 'for', out.size(), file=stderr)