From b4711b4153652ce4e33d9b4b149fe3f5b5825a27 Mon Sep 17 00:00:00 2001 From: Filippo Miatto Date: Tue, 21 Jan 2025 21:57:13 -0800 Subject: [PATCH] better docstring, fixed index --- mrmustard/physics/mm_einsum.py | 34 ++++++++++++++++++++++++++-------- 1 file changed, 26 insertions(+), 8 deletions(-) diff --git a/mrmustard/physics/mm_einsum.py b/mrmustard/physics/mm_einsum.py index 8bf7ab7ba..2c378224c 100644 --- a/mrmustard/physics/mm_einsum.py +++ b/mrmustard/physics/mm_einsum.py @@ -20,9 +20,25 @@ def mm_einsum(*args: list[CircuitComponent | list[int]]): - """ - Assumes args = [cc1, lst1, cc2, lst2, ..., ccN, lstN, lstOut] - like np.einsum without the string. + r"""Performs tensor contractions between multiple circuit components using their indices. + + This function is analogous to numpy's einsum but specialized for MrMustard's circuit components. + It automatically determines the optimal contraction order and handles both continuous-variable (CV) + and Fock-space representations. + + Args: + *args: Alternating sequence of CircuitComponent objects and their corresponding index lists, + followed by a final output index list. The format should be: + [component1, indices1, component2, indices2, ..., componentN, indicesN, output_indices] + + Returns: + CircuitComponent: The resulting circuit component after performing all contractions. + + Notes: + - The function automatically determines the optimal contraction order to minimize computational cost + - Handles mixed CV and Fock-space representations + - Index values are arbitrary integers, but must be consistent across the expression + - The contraction behavior is similar to np.einsum but without requiring the equation string """ indices = list(args[1::2]) representations = args[:-1:2] @@ -30,12 +46,13 @@ def mm_einsum(*args: list[CircuitComponent | list[int]]): sizes = dict() for rep, idx in zip(representations, indices): - for i, wire in enumerate(rep.wires): - sizes[i] = rep.ansatz.array.shape[i + 1] if wire.repr == ReprEnum.FOCK else 0 + for j, (i, wire) in enumerate(zip(idx, rep.wires)): + # i+1 because the first index is the batch dimension + sizes[i] = rep.ansatz.array.shape[j + 1] if wire.repr == ReprEnum.FOCK else 0 - path = optimal(inputs=[frozenset(idx) for idx in indices], fock_size_dict=sizes) + contraction_order = optimal(inputs=[frozenset(idx) for idx in indices], fock_size_dict=sizes) - for a, b in path: + for a, b in contraction_order: common = list(set(indices[a]) & set(indices[b])) remaining = [i for i in indices[a] + indices[b] if i not in common] idx_a = [indices[a].index(i) for i in common] @@ -43,7 +60,8 @@ def mm_einsum(*args: list[CircuitComponent | list[int]]): ansatze.append(ansatze[a].contract(ansatze[b], idx_a, idx_b)) indices.append(remaining) - return ansatze[-1] + perm = [indices[-1].index(i) for i in args[-1]] + return ansatze[-1].reorder(perm) def _CV_flops(nA: int, nB: int, m: int) -> int: