Skip to content

Commit

Permalink
better docstring, fixed index
Browse files Browse the repository at this point in the history
  • Loading branch information
ziofil committed Jan 22, 2025
1 parent 5bf10ce commit b4711b4
Showing 1 changed file with 26 additions and 8 deletions.
34 changes: 26 additions & 8 deletions mrmustard/physics/mm_einsum.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,30 +20,48 @@


def mm_einsum(*args: list[CircuitComponent | list[int]]):
"""
Assumes args = [cc1, lst1, cc2, lst2, ..., ccN, lstN, lstOut]
like np.einsum without the string.
r"""Performs tensor contractions between multiple circuit components using their indices.
This function is analogous to numpy's einsum but specialized for MrMustard's circuit components.
It automatically determines the optimal contraction order and handles both continuous-variable (CV)
and Fock-space representations.
Args:
*args: Alternating sequence of CircuitComponent objects and their corresponding index lists,
followed by a final output index list. The format should be:
[component1, indices1, component2, indices2, ..., componentN, indicesN, output_indices]
Returns:
CircuitComponent: The resulting circuit component after performing all contractions.
Notes:
- The function automatically determines the optimal contraction order to minimize computational cost
- Handles mixed CV and Fock-space representations
- Index values are arbitrary integers, but must be consistent across the expression
- The contraction behavior is similar to np.einsum but without requiring the equation string
"""
indices = list(args[1::2])
representations = args[:-1:2]
ansatze = [r.ansatz for r in representations]

sizes = dict()
for rep, idx in zip(representations, indices):
for i, wire in enumerate(rep.wires):
sizes[i] = rep.ansatz.array.shape[i + 1] if wire.repr == ReprEnum.FOCK else 0
for j, (i, wire) in enumerate(zip(idx, rep.wires)):
# i+1 because the first index is the batch dimension
sizes[i] = rep.ansatz.array.shape[j + 1] if wire.repr == ReprEnum.FOCK else 0

path = optimal(inputs=[frozenset(idx) for idx in indices], fock_size_dict=sizes)
contraction_order = optimal(inputs=[frozenset(idx) for idx in indices], fock_size_dict=sizes)

for a, b in path:
for a, b in contraction_order:
common = list(set(indices[a]) & set(indices[b]))
remaining = [i for i in indices[a] + indices[b] if i not in common]
idx_a = [indices[a].index(i) for i in common]
idx_b = [indices[b].index(i) for i in common]
ansatze.append(ansatze[a].contract(ansatze[b], idx_a, idx_b))
indices.append(remaining)

return ansatze[-1]
perm = [indices[-1].index(i) for i in args[-1]]
return ansatze[-1].reorder(perm)


def _CV_flops(nA: int, nB: int, m: int) -> int:
Expand Down

0 comments on commit b4711b4

Please sign in to comment.