Merge pull request #190 from diprism/issue-168
In Jacobian, re-use partial einsums (closes #168)
davidweichiang authored Dec 5, 2024
2 parents 112dad7 + 46f671e commit a7f6872
Showing 5 changed files with 263 additions and 142 deletions.
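
Background for the change: Newton's method needs the Jacobian of F, and the old J recomputes, for each rule and each of its m edges, the sum-product of all the other edges from scratch, which is quadratic in m. The new code instead builds prefix and suffix products over each rule's edges once, so each leave-one-out product costs a single extra einsum. A minimal sketch of the underlying trick, with plain torch tensors and elementwise products standing in for the repository's PatternedTensor and indexed einsum:

    import torch

    def leave_one_out(factors):
        # prefix[i] = product of factors[:i+1]; suffix[i] = product of factors[i:]
        m = len(factors)
        prefix, suffix = [None] * m, [None] * m
        acc = torch.tensor(1.0)
        for i in range(m):
            acc = acc * factors[i]
            prefix[i] = acc
        acc = torch.tensor(1.0)
        for i in reversed(range(m)):
            acc = acc * factors[i]
            suffix[i] = acc
        # product of all factors except factors[i], in O(m) total multiplies
        return [(prefix[i - 1] if i > 0 else torch.tensor(1.0)) *
                (suffix[i + 1] if i + 1 < m else torch.tensor(1.0))
                for i in range(m)]
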
3 changes: 2 additions & 1 deletion bin/sum_product.py
@@ -41,6 +41,7 @@ def tensor_to_dict_string(is_pretty, fgg, node, t):
ap.add_argument('-t', dest='trace', action='store_true', help='print out all intermediate sum-products')
ap.add_argument('-d', dest='double', action='store_true', help='use double-precision floating-point')
ap.add_argument('-p', dest='pretty', action='store_true', help='pretty print weights together with the value names')
ap.add_argument('-j', dest='j_precompute', action='store_true', help='precompute products while solving the Jacobian (faster for rules with many edges)')
ap.add_argument('--tikz', metavar='<file>', dest='tikz', default=None, help='convert the input JSON to tikz and write the output to the given file')

args = ap.parse_args()
@@ -109,7 +110,7 @@ def tensor_to_dict_string(is_pretty, fgg, node, t):
for w in fgg.factors.values():
w.weights.requires_grad_()

zs = fggs.sum_products(fgg, method=args.method, tol=args.tol, kmax=args.kmax)
zs = fggs.sum_products(fgg, method=args.method, tol=args.tol, kmax=args.kmax, j_precompute=args.j_precompute)
z = zs[fgg.start]

if args.trace:
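
The new switch composes with the existing ones; a hypothetical invocation (every argument other than -j is outside this diff and shown only as a placeholder):

    python bin/sum_product.py -j <other options> <input>
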
4 changes: 2 additions & 2 deletions fggs/indices.py
@@ -540,8 +540,8 @@ def reshape_or_view(f: Callable[[Tensor, List[int]], Tensor],
assert(all(goal >= 0 for goal in s))
assert(numel == reduce(mul, s, 1))
subst : Subst = {}
vaxes = tuple(unitAxis if goal == 1 else PhysicalAxis(goal)
for goal in s)
vaxes: Tuple[Axis, ...] = tuple(unitAxis if goal == 1 else PhysicalAxis(goal)
for goal in s)
if not productAxis(vaxes).unify(productAxis(self.vaxes), subst):
raise RuntimeError(f'Cannot reshape_or_view {self} to {s}')
paxes = tuple(cast(PhysicalAxis, k) for e in self.paxes
186 changes: 149 additions & 37 deletions fggs/sum_product.py
@@ -1,6 +1,5 @@
from __future__ import annotations
__all__ = ['sum_product', 'sum_products']

from fggs.fggs import FGG, HRG, HRGRule, EdgeLabel, Edge, Node
from fggs.factors import FiniteFactor
from fggs.semirings import *
@@ -18,7 +17,7 @@
from torch import Tensor
from fggs.typing import TensorLikeT
from fggs.indices import Nonphysical, PatternedTensor, einsum, stack

from typing import cast
Function = Callable[[MultiTensor], MultiTensor]

def _formatwarning(message, category, filename=None, lineno=None, file=None, line=None):
@@ -149,6 +148,102 @@ def J(fgg: FGG, x: MultiTensor, inputs: MultiTensor, semiring: Semiring,
assert False
return Jx

def J_precompute_products(fgg: FGG, x: MultiTensor, inputs: MultiTensor, semiring: Semiring,
J_inputs: Optional[MultiTensor] = None) -> MultiTensor:
"""The Jacobian of F. New version that takes out the double loop over edges"""
Jx = MultiTensor(x.shapes+x.shapes, semiring)
for n in x.shapes[0]:
for rule in fgg.rules(n):
edges = list(rule.rhs.edges())
if len(edges) == 1:
suffix_products = [sum_product_edges(fgg, rule.rhs.nodes(), set(), rule.rhs.ext + edges[0].nodes, x, inputs, semiring=semiring)]
else:
prefix_products, prefix_output_nodes = compute_products(fgg, (x, inputs), edges, rule, semiring)
suffix_products, suffix_output_nodes = compute_products(fgg, (x, inputs), list(reversed(edges)), rule, semiring)
suffix_products.reverse()
suffix_output_nodes.reverse()
for i, edge in enumerate(edges):
if i == 0:
product = suffix_products[0]
elif i == len(edges) - 1:
product = prefix_products[-1]
else:
if prefix_products[i - 1] is None or suffix_products[i] is None: continue
                    # casts from Optional to satisfy mypy
prefix_product: PatternedTensor = cast(PatternedTensor, prefix_products[i - 1])
suffix_product: PatternedTensor = cast(PatternedTensor, suffix_products[i])
prefix_output_node: List[Node] = cast(List[Node], prefix_output_nodes[i - 1])
suffix_output_node: List[Node] = cast(List[Node], suffix_output_nodes[i])
product = einsum([prefix_product, suffix_product],
[prefix_output_node, suffix_output_node],
[n for n in rule.rhs.ext + edge.nodes if n in prefix_output_node or n in suffix_output_node],
semiring)
if product is not None:
# add back removed disconnected external nodes
ext = list(rule.rhs.ext) + list(edge.nodes)
if product.ndim < len(ext):
connected = set()
for e in edges:
if i != 0 and i != len(edges) -1 and e is edge: continue
connected.update(e.nodes)
eshape = fgg.shape(ext)
vshape = [s if n in connected else 1 for n, s in zip(ext, eshape)]
product = product.view(*vshape).expand(*eshape)
if edge.label in Jx.shapes[1]:
Jx.add_single((n, edge.label), product)
elif J_inputs is not None and edge.label in J_inputs.shapes[1]:
J_inputs.add_single((n, edge.label), product)
return Jx
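
The index bookkeeping above can be checked with plain lists (an illustrative sketch, not repository code): prefix_products[i - 1] covers edges[:i] and, after the double reversal, suffix_products[i] covers edges[i + 1:], so together they give the leave-one-out product for edges[i].

    edges = ['a', 'b', 'c', 'd']
    m = len(edges)
    prefix = [edges[:i + 1] for i in range(m - 1)]  # prefix[i] ~ edges[0..i]
    suffix = [edges[i + 1:] for i in range(m - 1)]  # suffix[i] ~ edges[i+1..]
    for i in range(m):
        left = prefix[i - 1] if i > 0 else []
        right = suffix[i] if i < m - 1 else []
        assert left + right == edges[:i] + edges[i + 1:]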

def compute_products(fgg: FGG, inputses: Iterable[MultiTensor], edges: List[Edge], rule: HRGRule, semiring: Semiring) -> Tuple[List[Optional[PatternedTensor]], List[Optional[List[Node]]]]:
future_nodes_lookup = []
future_nodes = []
seen_nodes = set()
for edge in reversed(edges[1:]):
for node in edge.nodes:
if node not in seen_nodes:
future_nodes.append(node)
seen_nodes.add(node)

future_nodes_lookup.append(future_nodes.copy())
future_nodes_lookup.reverse()

products: List[Optional[PatternedTensor]] = []
output_nodes: List[Optional[List[Node]]] = []
previous_weight: PatternedTensor = PatternedTensor.eye(1, semiring=semiring)
previous_output_nodes: List[Node] = []
for i, edge in enumerate(edges[:-1]):
weight = get_weight(edge, inputses)
if weight is None:
products.extend( [None] * (len(edges) - 1 - i))
output_nodes.extend( [None] * (len(edges) - 1 - i))
break
out, out_nodes = multiply_next_edge(fgg, previous_weight, previous_output_nodes, weight, list(edge.nodes),
list(rule.rhs.ext) + future_nodes_lookup[i], list(rule.rhs.nodes()), semiring)
products.append(out)
output_nodes.append(out_nodes)
previous_weight, previous_output_nodes = out, out_nodes

return products, output_nodes
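
Here future_nodes_lookup[i] collects the nodes that still occur in edges[i + 1:], i.e. the axes the running product must keep alive. A standalone sketch with strings as stand-ins for Node objects (hypothetical data, not repository code):

    def future_nodes(edges):
        lookup, seen, acc = [], set(), []
        for edge in reversed(edges[1:]):
            for node in edge:
                if node not in seen:
                    acc.append(node)
                    seen.add(node)
            lookup.append(acc.copy())
        lookup.reverse()
        return lookup

    print(future_nodes([['x', 'y'], ['y', 'z'], ['z', 'w']]))
    # [['z', 'w', 'y'], ['z', 'w']] -- nodes needed after positions 0 and 1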

def multiply_next_edge(fgg: FGG, previous_weight: PatternedTensor, previous_nodes: List[Node],
current_weight: PatternedTensor, current_nodes: List[Node],
ext: Iterable[Node], rule_rhs_nodes: List[Node], semiring: Semiring) -> Tuple[PatternedTensor, List[Node]]:

indexing: List[Sequence[Node]] = [previous_nodes, current_nodes]
tensors: List[PatternedTensor] = [previous_weight, current_weight]
connected: Set[Node] = set(previous_nodes + current_nodes)

ext, duplicate = rename_duplicate_nodes(fgg, ext, tensors, indexing, connected, semiring)
output_nodes = [n for n in ext if n in connected]

assert(all(tensor.physical.dtype == semiring.dtype for tensor in tensors))
out = einsum(tensors, indexing, output_nodes, semiring)
if duplicate:
print('einsum produced', out.physical.size(), 'for', out.size(), file=stderr)

out = multiply_in_disconnected_internals(out, rule_rhs_nodes, connected, ext, semiring, fgg)
return out, output_nodes
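
multiply_next_edge is one step of a chained contraction: shared nodes are summed out, and only external or still-needed nodes survive as output axes. A rough analogy with ordinary tensors (torch.einsum standing in for fggs.indices.einsum, subscript letters for Node objects):

    import torch

    prev = torch.rand(2, 3)  # running product over nodes (a, b)
    cur = torch.rand(3, 4)   # next edge's weight over nodes (b, c)
    # b is neither external nor used by later edges, so it is summed out:
    out = torch.einsum('ab,bc->ac', prev, cur)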

def log_softmax(a: TensorLikeT, dim: int) -> TensorLikeT:
# If a has infinite elements, log_softmax would return all nans.
@@ -195,6 +290,42 @@ def J_log(fgg: FGG, x: MultiTensor, inputs: MultiTensor, semiring: Semiring,
return Jx


def multiply_in_disconnected_internals(out: PatternedTensor, nodes: Iterable[Node], connected: Set[Node], ext: Iterable[Node], semiring: Semiring, fgg: FGG) -> PatternedTensor:
"""Multiply in any disconnected internal nodes"""
multiplier = 1
for n in nodes:
if n not in connected and n not in ext:
multiplier *= fgg.domains[n.label.name].size()
if multiplier != 1:
return semiring.mul(out, PatternedTensor.from_int(multiplier, semiring))
return out
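
For example, if a rule's right-hand side has two internal nodes with domains of sizes 5 and 3 that appear in no edge, each just sums a constant over its whole domain, so the sum-product is scaled by 5 * 3 = 15.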

def rename_duplicate_nodes(fgg: FGG, ext: Iterable[Node], tensors: List[PatternedTensor], indexing: List[Sequence[Node]], connected: Set[Node], semiring: Semiring):
""" Derivatives can sometimes produce duplicate external nodes.
Rename them apart and add identity factors between them."""
ext_orig = ext
ext = []
debug_duplicate = False
for n in ext_orig:
if n in ext:
ncopy = Node(n.label)
ext.append(ncopy)
connected.update([n, ncopy])
indexing.append([n, ncopy])
nsize = fgg.domains[n.label.name].size()
tensors.append(PatternedTensor.eye(nsize,semiring))
#debug_duplicate = True # Uncomment this line for debugging messages
else:
ext.append(n)

return ext, debug_duplicate
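
The identity-factor trick can be seen with ordinary tensors (an analogy only, with torch.einsum in place of the repository's einsum): to expose a repeated index as two distinct output axes, contract with an identity matrix instead of repeating the label in the output.

    import torch

    t = torch.rand(3)
    # 'i,ij->ij' places t on the diagonal of a 3x3 result, which is what a
    # repeated output index like 'i->ii' would mean if einsum allowed it:
    out = torch.einsum('i,ij->ij', t, torch.eye(3))
    assert torch.equal(out, torch.diag(t))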

def get_weight(edge, inputses):
for inputs in inputses:
if edge.label in inputs:
return inputs[edge.label]
return None

def sum_product_edges(fgg: FGG, nodes: Iterable[Node], edges: Iterable[Edge], ext: Sequence[Node], *inputses: MultiTensor, semiring: Semiring) -> Optional[PatternedTensor]:
"""Compute the sum-product of a set of edges.
Expand All @@ -209,38 +340,19 @@ def sum_product_edges(fgg: FGG, nodes: Iterable[Node], edges: Iterable[Edge], ex
Return: the tensor of sum-products, or None if zero
"""

connected: Set[Node] = set()
indexing: List[Sequence[Node]] = []
tensors: List[PatternedTensor] = []

# Derivatives can sometimes produce duplicate external nodes.
# Rename them apart and add identity factors between them.
ext_orig = ext
ext = []
duplicate = False
for n in ext_orig:
if n in ext:
ncopy = Node(n.label)
ext.append(ncopy)
connected.update([n, ncopy])
indexing.append([n, ncopy])
nsize = fgg.domains[n.label.name].size()
tensors.append(PatternedTensor.eye(nsize,semiring))
#duplicate = True # Uncomment this line for debugging messages
else:
ext.append(n)
ext, duplicate = rename_duplicate_nodes(fgg, ext, tensors, indexing, connected, semiring)

for edge in edges:
connected.update(edge.nodes)
indexing.append(edge.nodes)
for inputs in reversed(inputses):
if edge.label in inputs:
tensors.append(inputs[edge.label])
break
weight = get_weight(edge, inputses)
if weight is not None:
tensors.append(weight)
else:
# One argument to einsum will be the zero tensor, so just return zero
return None
            return None  # one argument to einsum will be the zero tensor, so just return zero

# If an external node has no edges, einsum will complain, so remove it.
outputs = [node for node in ext if node in connected]
Expand All @@ -256,13 +368,8 @@ def sum_product_edges(fgg: FGG, nodes: Iterable[Node], edges: Iterable[Edge], ex
vshape = [s if n in connected else 1 for n, s in zip(ext, eshape)]
out = out.view(*vshape).expand(*eshape)

# Multiply in any disconnected internal nodes.
multiplier = 1
for n in nodes:
if n not in connected and n not in ext:
multiplier *= fgg.domains[n.label.name].size()
if multiplier != 1:
out = semiring.mul(out, PatternedTensor.from_int(multiplier, semiring))
out = multiply_in_disconnected_internals(out, nodes, connected, ext, semiring, fgg)

assert(out.physical.dtype == semiring.dtype)
return out

@@ -343,14 +450,14 @@ def forward(ctx, # type: ignore
Optional[Tensor] # rest tuple components (same length as out_labels)
], ...]:
ctx.fgg = fgg
method, semiring = opts['method'], opts['semiring']
method, semiring, j_precompute = opts['method'], opts['semiring'], opts.get('j_precompute', False)
ctx.opts = opts
ctx.in_labels = in_labels
ctx.out_labels = out_labels
ctx.save_for_backward(*in_values)


inputs: MultiTensor = {label: nonphysical.reincarnate(physical) for (label, nonphysical), physical in zip(in_labels, in_values)} # type: ignore

if method == 'linear':
out = linear(fgg, inputs, out_labels, semiring)
else:
@@ -360,7 +467,8 @@ def forward(ctx, # type: ignore
x0, tol=opts['tol'], kmax=opts['kmax'])
elif method == 'newton':
newton(lambda x: F(fgg, x, inputs, semiring),
lambda x: J(fgg, x, inputs, semiring),
lambda x: J_precompute_products(fgg, x, inputs, semiring)
if j_precompute else J(fgg, x, inputs, semiring),
x0, tol=opts['tol'], kmax=opts['kmax'])
elif method == 'one-step':
x0.copy_(F(fgg, x0, inputs, semiring))
@@ -390,7 +498,10 @@ def backward(ctx, grad_nonphysicals, *grad_out):
FGGMultiShape(ctx.fgg, (el for el, _ in ctx.in_labels))),
real_semiring)
if isinstance(semiring, RealSemiring):
jf = J(ctx.fgg, ctx.out_values, inputs, semiring, jf_inputs)
if ctx.opts.get('j_precompute', False):
jf = J_precompute_products(ctx.fgg, ctx.out_values, inputs, semiring, jf_inputs)
else:
jf = J(ctx.fgg, ctx.out_values, inputs, semiring, jf_inputs)
elif isinstance(semiring, LogSemiring):
jf = J_log(ctx.fgg, ctx.out_values, inputs, semiring, jf_inputs)
else:
@@ -422,6 +533,7 @@ def sum_products(fgg: FGG, **opts) -> Dict[EdgeLabel, Tensor]:
opts.setdefault('semiring', RealSemiring())
opts.setdefault('tol', 1e-5) # with float32, 1e-6 can fail
opts.setdefault('kmax', 1000) # for fixed-point, 100 is too low

if isinstance(opts['semiring'], BoolSemiring):
opts['tol'] = 0
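
End to end, the new option is exercised as in bin/sum_product.py above; a minimal sketch (assuming an FGG value fgg built elsewhere):

    import fggs

    # semiring, tol, and kmax fall back to the defaults set above
    zs = fggs.sum_products(fgg, method='newton', j_precompute=True)
    z = zs[fgg.start]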
