From 94c136657ada13b307bf5dc8381bf256c2c6b422 Mon Sep 17 00:00:00 2001 From: Filippo Miatto Date: Tue, 21 Jan 2025 22:18:07 -0800 Subject: [PATCH 1/7] added mm_einsum.py file --- mrmustard/physics/mm_einsum.py | 270 +++++++++++++++++++++++++++++++++ 1 file changed, 270 insertions(+) create mode 100644 mrmustard/physics/mm_einsum.py diff --git a/mrmustard/physics/mm_einsum.py b/mrmustard/physics/mm_einsum.py new file mode 100644 index 000000000..2c378224c --- /dev/null +++ b/mrmustard/physics/mm_einsum.py @@ -0,0 +1,270 @@ +# Copyright 2024 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import numpy as np +import itertools +from mrmustard.lab_dev import CircuitComponent +from mrmustard.physics.wires import ReprEnum + + +def mm_einsum(*args: list[CircuitComponent | list[int]]): + r"""Performs tensor contractions between multiple circuit components using their indices. + + This function is analogous to numpy's einsum but specialized for MrMustard's circuit components. + It automatically determines the optimal contraction order and handles both continuous-variable (CV) + and Fock-space representations. + + Args: + *args: Alternating sequence of CircuitComponent objects and their corresponding index lists, + followed by a final output index list. The format should be: + [component1, indices1, component2, indices2, ..., componentN, indicesN, output_indices] + + Returns: + CircuitComponent: The resulting circuit component after performing all contractions. + + Notes: + - The function automatically determines the optimal contraction order to minimize computational cost + - Handles mixed CV and Fock-space representations + - Index values are arbitrary integers, but must be consistent across the expression + - The contraction behavior is similar to np.einsum but without requiring the equation string + """ + indices = list(args[1::2]) + representations = args[:-1:2] + ansatze = [r.ansatz for r in representations] + + sizes = dict() + for rep, idx in zip(representations, indices): + for j, (i, wire) in enumerate(zip(idx, rep.wires)): + # i+1 because the first index is the batch dimension + sizes[i] = rep.ansatz.array.shape[j + 1] if wire.repr == ReprEnum.FOCK else 0 + + contraction_order = optimal(inputs=[frozenset(idx) for idx in indices], fock_size_dict=sizes) + + for a, b in contraction_order: + common = list(set(indices[a]) & set(indices[b])) + remaining = [i for i in indices[a] + indices[b] if i not in common] + idx_a = [indices[a].index(i) for i in common] + idx_b = [indices[b].index(i) for i in common] + ansatze.append(ansatze[a].contract(ansatze[b], idx_a, idx_b)) + indices.append(remaining) + + perm = [indices[-1].index(i) for i in args[-1]] + return ansatze[-1].reorder(perm) + + +def _CV_flops(nA: int, nB: int, m: int) -> int: + r"""Calculate the cost of contracting two tensors with CV indices. + Args: + nA: Number of CV indices in the first tensor + nB: Number of CV indices in the second tensor + m: Number of CV indices involved in the contraction + """ + cost = ( + m * m * m # M inverse + + (m + 1) * m * nA # left matmul + + (m + 1) * m * nB # right matmul + + (m + 1) * m # addition + + m * m * m + ) # determinant of M + return cost + + +def _fock_flops( + fock_contracted_shape: tuple[int, ...], fock_remaining_shape: tuple[int, ...] +) -> int: + r"""Calculate the cost of contracting two tensors with Fock indices. + Args: + fock_contracted_shape: shape of the indices that participate in the contraction + fock_remaining_shape: shape of the indices that do not + """ + if len(fock_contracted_shape) > 0: + return np.prod(fock_contracted_shape) * np.prod(fock_remaining_shape) + else: + return 0 + + +def new_indices_and_flops( + idx1: frozenset[int], idx2: frozenset[int], fock_size_dict: dict[int, int] +) -> tuple[frozenset[int], int]: + r"""Calculate the cost of contracting two tensors with mixed CV and Fock indices. + + This function computes both the surviving indices and the computational cost (in FLOPS) + of contracting two tensors that contain a mixture of continuous-variable (CV) and + Fock-space indices. + + Args: + idx1: Set of indices for the first tensor. CV indices are integers not present + in fock_size_dict. + idx2: Set of indices for the second tensor. CV indices are integers not present + in fock_size_dict. + fock_size_dict: Dict mapping Fock index labels to their dimensions. Any index + not in this dict is treated as a CV index. + + Returns: + tuple[frozenset[int], int]: A tuple containing: + - frozenset of indices that survive the contraction + - total computational cost in FLOPS (including CV operations, + Fock contractions, and potential decompositions) + + Example: + >>> idx1 = frozenset({0, 1}) # 0 is CV, 1 is Fock + >>> idx2 = frozenset({1, 2}) # 2 is Fock + >>> fock_size_dict = {1: 2, 2: 3} + >>> new_indices_and_flops(idx1, idx2, fock_size_dict) + (frozenset({0, 2}), 9) + """ + + # Calculate index sets for contraction + contracted_indices = idx1 & idx2 # Indices that get contracted away + remaining_indices = idx1 ^ idx2 # Indices that remain after contraction + all_fock_indices = set(fock_size_dict.keys()) + + # Count CV and get Fock shapes + num_cv_contracted = len(contracted_indices - all_fock_indices) + fock_contracted_shape = [fock_size_dict[idx] for idx in contracted_indices & all_fock_indices] + fock_remaining_shape = [fock_size_dict[idx] for idx in remaining_indices & all_fock_indices] + + # Calculate flops + cv_flops = _CV_flops( + nA=len(idx1) - num_cv_contracted, nB=len(idx2) - num_cv_contracted, m=num_cv_contracted + ) + + fock_flops = _fock_flops(fock_contracted_shape, fock_remaining_shape) + + # Try decomposing the remaining indices + new_indices, decomp_flops = attempt_decomposition(remaining_indices, fock_size_dict) + + # flops for evaluating the ansatz with the remaining indices (measures ansatz complexity) + eval_flops = np.prod([fock_size_dict[idx] for idx in new_indices if idx in fock_size_dict]) + + total_flops = int(cv_flops + fock_flops + decomp_flops + eval_flops) + return new_indices, total_flops + + +def attempt_decomposition( + indices: set[int], fock_size_dict: dict[int, int] +) -> tuple[set[int], int]: + r"""Attempt to reduce the number of indices by combining Fock indices, + which is possible if there is only one CV index and multiple Fock indices. + (This is Kasper's decompose method). + + Args: + indices: Set of indices to potentially decompose + fock_size_dict: Dictionary mapping indices to their dimensions + + Returns: + tuple[frozenset[int], int]: A tuple containing: + - frozenset of decomposed indices + - computational cost of decomposition in FLOPS + """ + fock_indices_shape = [fock_size_dict[idx] for idx in indices if idx in fock_size_dict] + cv_indices = [idx for idx in indices if idx not in fock_size_dict] + + if len(cv_indices) == 1 and len(fock_indices_shape) > 1: + new_index = max(fock_size_dict) + 1 # Create new index with size = sum of Fock index sizes + decomposed_indices = {cv_indices[0], new_index} + fock_size_dict[new_index] = sum(fock_indices_shape) + decomp_flops = np.prod(fock_indices_shape) + return frozenset(decomposed_indices), decomp_flops + return frozenset(indices), 0 + + +def optimal( + inputs: list[frozenset[int]], + fock_size_dict: dict[int, int], + info: bool = False, +) -> list[tuple[int, int]]: + r"""Find the optimal contraction path for a mixed CV-Fock tensor network. + + This function performs an exhaustive search over all possible contraction orders + for a tensor network containing both continuous-variable (CV) and Fock-space tensors. + It uses a depth-first recursive strategy to find the sequence of pairwise contractions + that minimizes the total computational cost (FLOPS). + + CV indices are represented by integers not present in fock_size_dict, while Fock + indices must be keys in fock_size_dict. The algorithm caches intermediate results, + skips outer products (contractions between tensors with no shared indices), and + prunes the search when partial paths exceed the current best cost. + + Args: + inputs: List of index sets representing tensor indices + fock_size_dict: Mapping from Fock index labels to dimensions + info: If True, prints cache size diagnostics + + Returns: + tuple[tuple[int, int], ...]: The optimal contraction path as a sequence of pairs. + Each pair (i, j) indicates that tensors at positions i and j should be + contracted together. The resulting tensor is placed at position len(inputs). + + Example: + >>> inputs = [frozenset({0, 1}), frozenset({1, 2}), frozenset({2, 3})] + >>> fock_size_dict = {1: 2, 2: 2} # indices 0 and 3 are CV indices + >>> optimal(inputs, fock_size_dict) + ((0, 1), (2, 3)) + + Reference: + Based on the optimal path finder in opt_einsum: + https://github.com/dgasmith/opt_einsum/blob/master/opt_einsum/paths.py + """ + best_flops: int = float("inf") + best_path: tuple[tuple[int, int], ...] = () + result_cache: dict[tuple[frozenset[int], frozenset[int]], tuple[frozenset[int], int]] = {} + + def _optimal_iterate(path, remaining, inputs, flops): + nonlocal best_flops + nonlocal best_path + + if len(remaining) == 1: + best_flops = flops + best_path = path + return + + # check all remaining paths + for i, j in itertools.combinations(remaining, 2): + if i > j: + i, j = j, i + + # skip outer products + if not inputs[i] & inputs[j]: + continue + + key = (inputs[i], inputs[j]) + try: + new_indices, flops_ij = result_cache[key] + except KeyError: + new_indices, flops_ij = result_cache[key] = new_indices_and_flops( + *key, fock_size_dict + ) + + # sieve based on current best flops + new_flops = flops + flops_ij + if new_flops >= best_flops: + continue + + # add contraction and recurse into all remaining + _optimal_iterate( + path=path + ((i, j),), + inputs=inputs + (new_indices,), + remaining=remaining - {i, j} | {len(inputs)}, + flops=new_flops, + ) + + _optimal_iterate( + path=(), inputs=tuple(map(frozenset, inputs)), remaining=set(range(len(inputs))), flops=0 + ) + + if info: + print("len(fock_size_dict)", len(fock_size_dict), "len(result_cache)", len(result_cache)) + return best_path From 3b20398e0cab8baf1d6d5f366783ad200142bf5e Mon Sep 17 00:00:00 2001 From: Filippo Miatto Date: Tue, 21 Jan 2025 22:19:15 -0800 Subject: [PATCH 2/7] added tests --- tests/test_physics/test_mm_einsum.py | 42 ++++++++++++++++++++++++++++ 1 file changed, 42 insertions(+) create mode 100644 tests/test_physics/test_mm_einsum.py diff --git a/tests/test_physics/test_mm_einsum.py b/tests/test_physics/test_mm_einsum.py new file mode 100644 index 000000000..282bc0814 --- /dev/null +++ b/tests/test_physics/test_mm_einsum.py @@ -0,0 +1,42 @@ +# Copyright 2025 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for the mm_einsum function.""" + +from mrmustard.physics.mm_einsum import mm_einsum +from mrmustard.lab_dev import Number, BSgate, QuadratureEigenstate +from mrmustard.physics.wires import Wires +import numpy as np + + +def test_mm_einsum(): + n = 100 + input1 = Number([0], n).to_fock().dm() + input2 = Number([1], n).to_fock().dm() + + bs_p = (BSgate([0, 1], np.pi / 4) >> QuadratureEigenstate([1], 0, np.pi / 2).dual).to_fock( + (2 * n + 1, n + 1, n + 1) + ) + + part1 = input1 @ bs_p.adjoint + part1.representation._wires = Wires(modes_out_bra={0, 1}, modes_out_ket={0}) + + part2 = input2 @ bs_p + part2.representation._wires = Wires(modes_in_bra={1}, modes_out_ket={0}, modes_in_ket={0}) + + expected = (part1 >> part2).representation.ansatz.array + result = mm_einsum( + input1, [0, 1], input2, [2, 3], bs_p.adjoint, [4, 0, 2], bs_p, [5, 1, 3], [4, 5] + ) + assert np.allclose(expected, result.array) From 7925097cf0e174bbc0bcae84ca35036097edb1d6 Mon Sep 17 00:00:00 2001 From: Filippo Miatto Date: Tue, 21 Jan 2025 22:19:19 -0800 Subject: [PATCH 3/7] year --- mrmustard/physics/mm_einsum.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mrmustard/physics/mm_einsum.py b/mrmustard/physics/mm_einsum.py index 2c378224c..1ab46f14f 100644 --- a/mrmustard/physics/mm_einsum.py +++ b/mrmustard/physics/mm_einsum.py @@ -1,4 +1,4 @@ -# Copyright 2024 Xanadu Quantum Technologies Inc. +# Copyright 2025 Xanadu Quantum Technologies Inc. # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. From d514dc79d0371b0695d1d041bde6d052b138141c Mon Sep 17 00:00:00 2001 From: Filippo Miatto Date: Tue, 21 Jan 2025 22:25:27 -0800 Subject: [PATCH 4/7] codefactor --- mrmustard/physics/mm_einsum.py | 5 +++-- tests/test_physics/test_mm_einsum.py | 4 +++- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/mrmustard/physics/mm_einsum.py b/mrmustard/physics/mm_einsum.py index 1ab46f14f..3828d8206 100644 --- a/mrmustard/physics/mm_einsum.py +++ b/mrmustard/physics/mm_einsum.py @@ -12,9 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. +"""Implementation of the mm_einsum function.""" -import numpy as np import itertools +import numpy as np from mrmustard.lab_dev import CircuitComponent from mrmustard.physics.wires import ReprEnum @@ -44,7 +45,7 @@ def mm_einsum(*args: list[CircuitComponent | list[int]]): representations = args[:-1:2] ansatze = [r.ansatz for r in representations] - sizes = dict() + sizes = {} for rep, idx in zip(representations, indices): for j, (i, wire) in enumerate(zip(idx, rep.wires)): # i+1 because the first index is the batch dimension diff --git a/tests/test_physics/test_mm_einsum.py b/tests/test_physics/test_mm_einsum.py index 282bc0814..cef83194f 100644 --- a/tests/test_physics/test_mm_einsum.py +++ b/tests/test_physics/test_mm_einsum.py @@ -14,13 +14,15 @@ """Tests for the mm_einsum function.""" +import numpy as np from mrmustard.physics.mm_einsum import mm_einsum from mrmustard.lab_dev import Number, BSgate, QuadratureEigenstate from mrmustard.physics.wires import Wires -import numpy as np def test_mm_einsum(): + r"""tests that the mm_einsum function returns the correct result compared to a manual contraction. + This specific example cannot be done with the usual >> and @ operators.""" n = 100 input1 = Number([0], n).to_fock().dm() input2 = Number([1], n).to_fock().dm() From 26e3851e1412cf677ef7b8d0c4f2203593f1f17b Mon Sep 17 00:00:00 2001 From: Filippo Miatto Date: Tue, 21 Jan 2025 22:41:11 -0800 Subject: [PATCH 5/7] fix typing --- mrmustard/physics/mm_einsum.py | 1 + 1 file changed, 1 insertion(+) diff --git a/mrmustard/physics/mm_einsum.py b/mrmustard/physics/mm_einsum.py index 3828d8206..29e476aff 100644 --- a/mrmustard/physics/mm_einsum.py +++ b/mrmustard/physics/mm_einsum.py @@ -14,6 +14,7 @@ """Implementation of the mm_einsum function.""" +from __future__ import annotations import itertools import numpy as np from mrmustard.lab_dev import CircuitComponent From d73ea2f9f2c791d93fbe1927e40d65c15928c91e Mon Sep 17 00:00:00 2001 From: Filippo Miatto Date: Fri, 24 Jan 2025 10:42:58 -0800 Subject: [PATCH 6/7] fixed wires repr and new test --- mrmustard/physics/wires.py | 95 ++++++++++++++++++++++---------- tests/test_physics/test_wires.py | 16 +++++- 2 files changed, 81 insertions(+), 30 deletions(-) diff --git a/mrmustard/physics/wires.py b/mrmustard/physics/wires.py index ddbb22063..224bca977 100644 --- a/mrmustard/physics/wires.py +++ b/mrmustard/physics/wires.py @@ -20,7 +20,7 @@ class LegibleEnum(Enum): - """Enum class that provides a more legible string representation.""" + r"""Enum class that provides a more legible string representation.""" def __str__(self) -> str: return self.name @@ -30,7 +30,7 @@ def __repr__(self) -> str: class ReprEnum(LegibleEnum): - """Enumeration of possible representations for quantum states and operations.""" + r"""Enumeration of possible representations for quantum states and operations.""" UNSPECIFIED = auto() BARGMANN = auto() @@ -41,7 +41,7 @@ class ReprEnum(LegibleEnum): class WiresType(LegibleEnum): - """Enumeration of possible wire types in quantum circuits.""" + r"""Enumeration of possible wire types in quantum circuits.""" DM_LIKE = auto() # only output ket and bra on same modes KET_LIKE = auto() # only output ket @@ -54,7 +54,7 @@ class WiresType(LegibleEnum): @dataclass class QuantumWire: - """ + r""" Represents a quantum wire in a circuit. Args: @@ -96,7 +96,7 @@ def __eq__(self, other: QuantumWire) -> bool: ) def copy(self, new_id: bool = False) -> QuantumWire: - """Create a copy of the quantum wire. + r"""Create a copy of the quantum wire. Args: new_id (bool): If True, generates a new ID for the copy. Defaults to False. @@ -115,7 +115,7 @@ def copy(self, new_id: bool = False) -> QuantumWire: ) def _order(self) -> int: - """ + r""" Artificial ordering for sorting quantum wires. Order achieved is by bra/ket, then out/in, then mode. """ @@ -124,7 +124,7 @@ def _order(self) -> int: @dataclass class ClassicalWire: - """ + r""" Represents a classical wire in a circuit. Args: @@ -159,7 +159,7 @@ def __eq__(self, other: ClassicalWire) -> bool: return self.mode == other.mode and self.is_out == other.is_out and self.repr == other.repr def copy(self, new_id: bool = False) -> ClassicalWire: - """Returns a copy of the classical wire.""" + r"""Returns a copy of the classical wire.""" return ClassicalWire( mode=self.mode, is_out=self.is_out, @@ -170,13 +170,16 @@ def copy(self, new_id: bool = False) -> ClassicalWire: ) def _order(self) -> int: - """ + r""" Artificial ordering for sorting classical wires. Order is by out/in, then mode. Classical wires always come after quantum wires. """ return 1000_000 + self.mode + 10_000 * (1 - 2 * self.is_out) +## MARK: WIRES + + class Wires: # pylint: disable=too-many-public-methods r""" A class with wire functionality for tensor network applications. @@ -371,43 +374,77 @@ def from_wires( return w def copy(self, new_ids: bool = False) -> Wires: - """Returns a deep copy of this Wires object.""" + r"""Returns a deep copy of this Wires object.""" return Wires.from_wires( quantum={q.copy(new_ids) for q in self.quantum_wires}, classical={c.copy(new_ids) for c in self.classical_wires}, ) - ###### NEW WIRES ###### + def _transform( + self, + quantum_transform: Callable[[QuantumWire], QuantumWire], + classical_transform: Callable[[ClassicalWire], ClassicalWire] = lambda x: x, + reindex: bool = False, + ) -> Wires: + r""" + Applies a transformation function to each wire, preserving their data + (mode, representation, etc.) unless deliberately changed. + + Args: + quantum_transform: A function that takes a QuantumWire and returns a new QuantumWire. + classical_transform: A function that takes a ClassicalWire and returns a new ClassicalWire. + reindex: Whether to reindex the wires. + + Returns: + A new Wires object with the transformed wires. + """ + new_quantum = {quantum_transform(q) for q in self.quantum_wires} + new_classical = {classical_transform(c) for c in self.classical_wires} + new_wires = Wires.from_wires(quantum=new_quantum, classical=new_classical, copy=False) + if reindex: + new_wires._reindex() + return new_wires + + ## MARK: NEW WIRES @cached_property def adjoint(self) -> Wires: r""" New ``Wires`` object with the adjoint quantum wires (ket becomes bra and vice versa). """ - return Wires( - modes_out_bra=self.output.ket.modes, - modes_in_bra=self.input.ket.modes, - modes_out_ket=self.output.bra.modes, - modes_in_ket=self.input.bra.modes, - classical_out=self.output.classical.modes, - classical_in=self.input.classical.modes, - ) + + def _adjoint_transform(wire: QuantumWire) -> QuantumWire: + new_wire = wire.copy() + new_wire.is_ket = not wire.is_ket + return new_wire + + return self._transform(quantum_transform=_adjoint_transform, reindex=True) @cached_property def dual(self) -> Wires: r""" New ``Wires`` object with dual quantum and classical wires (input becomes output and vice versa). """ - return Wires( - modes_out_bra=self.input.bra.modes, - modes_in_bra=self.output.bra.modes, - modes_out_ket=self.input.ket.modes, - modes_in_ket=self.output.ket.modes, - classical_out=self.input.classical.modes, - classical_in=self.output.classical.modes, + + def _dual_transform(wire: QuantumWire | ClassicalWire) -> QuantumWire | ClassicalWire: + new_wire = wire.copy() + new_wire.is_out = not wire.is_out + return new_wire + + return self._transform( + quantum_transform=_dual_transform, classical_transform=_dual_transform, reindex=True ) - ###### SUBSETS OF WIRES ###### + ## MARK: SUBSETS OF WIRES + + def wire(self, mode: int, is_ket: bool, is_out: bool) -> QuantumWire | ClassicalWire: + r""" + Returns the wire with the given mode, is_ket, and is_out. + """ + return next( + (w for w in self.wires if w.mode == mode and w.is_ket == is_ket and w.is_out == is_out), + None, + ) def __getitem__(self, modes: tuple[int, ...] | int) -> Wires: r""" @@ -473,7 +510,7 @@ def output(self) -> Wires: classical={c for c in self.classical_wires if c.is_out}, ) - ###### PROPERTIES ###### + ## MARK: PROPERTIES @property def modes(self) -> set[int]: @@ -517,7 +554,7 @@ def wires(self) -> list[QuantumWire | ClassicalWire]: """ return sorted({*self.quantum_wires, *self.classical_wires}, key=lambda s: s._order()) - ###### METHODS ###### + ## MARK: METHODS def overlap(self, other: Wires) -> tuple[set[int], set[int]]: r""" diff --git a/tests/test_physics/test_wires.py b/tests/test_physics/test_wires.py index 0e9ba7f78..1a635dbcd 100644 --- a/tests/test_physics/test_wires.py +++ b/tests/test_physics/test_wires.py @@ -22,7 +22,7 @@ from ipywidgets import HTML from mrmustard.lab_dev.states import QuadratureEigenstate -from mrmustard.physics.wires import Wires +from mrmustard.physics.wires import Wires, ReprEnum from ..conftest import skip_np @@ -151,6 +151,20 @@ def test_matmul_error(self): with pytest.raises(ValueError): u @ v # pylint: disable=pointless-statement + def test_wire_properties_adjoint(self): + w = Wires(modes_out_bra={0}, modes_out_ket={0}) + w.wire(0, is_ket=True, is_out=True).repr = ReprEnum.FOCK + + w_adj = w.adjoint + assert w_adj.wire(0, is_ket=False, is_out=True).repr == ReprEnum.FOCK + + def test_wire_properties_dual(self): + w = Wires(modes_out_bra={0}, modes_in_bra={0}) + w.wire(0, is_ket=False, is_out=True).repr = ReprEnum.FOCK + + w_d = w.dual + assert w_d.wire(0, is_ket=False, is_out=False).repr == ReprEnum.FOCK + class TestWiresDisplay: """Test the wires _ipython_display_ functionality.""" From f39c94801d2cb2dd82dd32132684075ee8cccf7d Mon Sep 17 00:00:00 2001 From: Filippo Miatto Date: Fri, 24 Jan 2025 10:43:14 -0800 Subject: [PATCH 7/7] make format --- tests/test_training/test_callbacks.py | 2 +- ...po Miatto's conflicted copy 2025-01-24).py | 592 ++++++++++++++++++ ...po Miatto's conflicted copy 2025-01-24).py | 268 ++++++++ 3 files changed, 861 insertions(+), 1 deletion(-) create mode 100644 tests/test_training/test_opt_lab_dev (Filippo Miatto's conflicted copy 2025-01-24).py create mode 100644 tests/test_training/test_trainer (Filippo Miatto's conflicted copy 2025-01-24).py diff --git a/tests/test_training/test_callbacks.py b/tests/test_training/test_callbacks.py index 02dbf6fb7..c1bfd9b27 100644 --- a/tests/test_training/test_callbacks.py +++ b/tests/test_training/test_callbacks.py @@ -18,7 +18,7 @@ import tensorflow as tf from mrmustard import math, settings -from mrmustard.lab_dev import Circuit, BSgate, S2gate, Vacuum +from mrmustard.lab_dev import BSgate, Circuit, S2gate, Vacuum from mrmustard.training import Optimizer, TensorboardCallback from ..conftest import skip_np diff --git a/tests/test_training/test_opt_lab_dev (Filippo Miatto's conflicted copy 2025-01-24).py b/tests/test_training/test_opt_lab_dev (Filippo Miatto's conflicted copy 2025-01-24).py new file mode 100644 index 000000000..f83f18396 --- /dev/null +++ b/tests/test_training/test_opt_lab_dev (Filippo Miatto's conflicted copy 2025-01-24).py @@ -0,0 +1,592 @@ +# Copyright 2022 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for the Optimizer class""" + +import numpy as np +import tensorflow as tf +from hypothesis import given +from hypothesis import strategies as st +from thewalrus.symplectic import two_mode_squeezing + +from mrmustard import math, settings +from mrmustard.lab_dev import ( + BSgate, + Circuit, + Dgate, + DisplacedSqueezed, + Ggate, + GKet, + Interferometer, + Number, + RealInterferometer, + Rgate, + S2gate, + Sgate, + SqueezedVacuum, + TwoModeSqueezedVacuum, + Vacuum, +) +from mrmustard.math.parameters import Variable, update_euclidean +from mrmustard.physics.gaussian import number_means, von_neumann_entropy +from mrmustard.training import Optimizer +from mrmustard.training.callbacks import Callback + +from ..conftest import skip_np + + +class TestOptimizer: + r""" + Tests for the ``Optimizer`` class. + """ + + @given(n=st.integers(0, 3)) + def test_S2gate_coincidence_prob(self, n): + """Testing the optimal probability of obtaining |n,n> from a two mode squeezed vacuum""" + skip_np() + + settings.SEED = 40 + rng = tf.random.get_global_generator() + rng.reset_from_seed(settings.SEED) + + S = TwoModeSqueezedVacuum( + (0, 1), r=abs(settings.rng.normal(loc=1.0, scale=0.1)), r_trainable=True + ) + + def cost_fn(): + return -math.abs(S.fock_array((n + 1, n + 1))[n, n]) ** 2 + + def cb(optimizer, cost, trainables, **kwargs): # pylint: disable=unused-argument + return { + "cost": cost, + "lr": optimizer.learning_rate[update_euclidean], + "num_trainables": len(trainables), + } + + opt = Optimizer(euclidean_lr=0.01) + opt.minimize(cost_fn, by_optimizing=[S], max_steps=300, callbacks=cb) + + expected = 1 / (n + 1) * (n / (n + 1)) ** n + assert np.allclose(-cost_fn(), expected, atol=1e-5) + + cb_result = opt.callback_history.get("cb") + assert {res["num_trainables"] for res in cb_result} == {1} + assert {res["lr"] for res in cb_result} == {0.01} + assert [res["cost"] for res in cb_result] == opt.opt_history[1:] + + @given(i=st.integers(1, 5), k=st.integers(1, 5)) + def test_hong_ou_mandel_optimizer(self, i, k): + """Finding the optimal beamsplitter transmission to get Hong-Ou-Mandel dip + This generalizes the single photon Hong-Ou-Mandel effect to the many photon setting + see Eq. 20 of https://journals.aps.org/prresearch/pdf/10.1103/PhysRevResearch.3.043065 + which lacks a square root in the right hand side. + """ + skip_np() + + settings.SEED = 42 + rng = tf.random.get_global_generator() + rng.reset_from_seed(settings.SEED) + + r = np.arcsinh(1.0) + cutoff = 1 + i + k + + state = TwoModeSqueezedVacuum((0, 1), r=r, phi_trainable=True) + state2 = TwoModeSqueezedVacuum((2, 3), r=r, phi_trainable=True) + bs = BSgate( + (1, 2), + theta=np.arccos(np.sqrt(k / (i + k))) + 0.1 * settings.rng.normal(), + phi=settings.rng.normal(), + theta_trainable=True, + phi_trainable=True, + ) + circ = Circuit([state, state2, bs]) + + def cost_fn(): + return math.abs(circ.contract().fock_array((cutoff,) * 4)[i, 1, i + k - 1, k]) ** 2 + + opt = Optimizer(euclidean_lr=0.01) + opt.minimize( + cost_fn, + by_optimizing=[circ], + max_steps=300, + callbacks=[Callback(tag="null_cb", steps_per_call=3)], + ) + assert np.allclose(np.cos(bs.parameters.theta.value) ** 2, k / (i + k), atol=1e-2) + assert "null_cb" in opt.callback_history + assert len(opt.callback_history["null_cb"]) == (len(opt.opt_history) - 1) // 3 + + def test_learning_two_mode_squeezing(self): + """Finding the optimal beamsplitter transmission to make a pair of single photons""" + skip_np() + + settings.SEED = 42 + rng = tf.random.get_global_generator() + rng.reset_from_seed(settings.SEED) + + state_in = Vacuum((0, 1)) + s_gate = Sgate( + (0, 1), + r=abs(settings.rng.normal(size=2)), + phi=settings.rng.normal(size=2), + r_trainable=True, + phi_trainable=True, + ) + bs_gate = BSgate( + (0, 1), + theta=settings.rng.normal(), + phi=settings.rng.normal(), + theta_trainable=True, + phi_trainable=True, + ) + circ = Circuit([state_in, s_gate, bs_gate]) + + def cost_fn(): + amps = circ.contract().fock_array((2, 2)) + return -math.abs(amps[1, 1]) ** 2 + math.abs(amps[0, 1]) ** 2 + + opt = Optimizer(euclidean_lr=0.05) + + opt.minimize(cost_fn, by_optimizing=[circ], max_steps=300) + assert np.allclose(-cost_fn(), 0.25, atol=1e-5) + + def test_learning_two_mode_Ggate(self): + """Finding the optimal Ggate to make a pair of single photons""" + skip_np() + + settings.SEED = 42 + rng = tf.random.get_global_generator() + rng.reset_from_seed(settings.SEED) + + G = GKet((0, 1), symplectic_trainable=True) + + def cost_fn(): + amps = G.fock_array((2, 2)) + return -math.abs(amps[1, 1]) ** 2 + math.abs(amps[0, 1]) ** 2 + + opt = Optimizer(symplectic_lr=0.5, euclidean_lr=0.01) + + opt.minimize(cost_fn, by_optimizing=[G], max_steps=500) + assert np.allclose(-cost_fn(), 0.25, atol=1e-4) + + def test_learning_two_mode_Interferometer(self): + """Finding the optimal Interferometer to make a pair of single photons""" + skip_np() + + settings.SEED = 4 + rng = tf.random.get_global_generator() + rng.reset_from_seed(settings.SEED) + + state_in = Vacuum((0, 1)) + s_gate = Sgate( + (0, 1), + r=settings.rng.normal(size=2) ** 2, + phi=settings.rng.normal(size=2), + r_trainable=True, + phi_trainable=True, + ) + interferometer = Interferometer((0, 1), unitary_trainable=True) + circ = Circuit([state_in, s_gate, interferometer]) + + def cost_fn(): + amps = circ.contract().fock_array((2, 2)) + return -math.abs(amps[1, 1]) ** 2 + math.abs(amps[0, 1]) ** 2 + + opt = Optimizer(unitary_lr=0.5, euclidean_lr=0.01) + + opt.minimize(cost_fn, by_optimizing=[circ], max_steps=1000) + assert np.allclose(-cost_fn(), 0.25, atol=1e-5) + + def test_learning_two_mode_RealInterferometer(self): + """Finding the optimal Interferometer to make a pair of single photons""" + skip_np() + + settings.SEED = 2 + rng = tf.random.get_global_generator() + rng.reset_from_seed(settings.SEED) + + state_in = Vacuum((0, 1)) + s_gate = Sgate( + (0, 1), + r=settings.rng.normal(size=2) ** 2, + phi=settings.rng.normal(size=2), + r_trainable=True, + phi_trainable=True, + ) + r_inter = RealInterferometer((0, 1), orthogonal_trainable=True) + + circ = Circuit([state_in, s_gate, r_inter]) + + def cost_fn(): + amps = circ.contract().fock_array((2, 2)) + return -math.abs(amps[1, 1]) ** 2 + math.abs(amps[0, 1]) ** 2 + + opt = Optimizer(orthogonal_lr=0.5, euclidean_lr=0.01) + + opt.minimize(cost_fn, by_optimizing=[circ], max_steps=1000) + assert np.allclose(-cost_fn(), 0.25, atol=1e-5) + + def test_learning_four_mode_Interferometer(self): + """Finding the optimal Interferometer to make a NOON state with N=2""" + skip_np() + + settings.SEED = 4 + rng = tf.random.get_global_generator() + rng.reset_from_seed(settings.SEED) + + solution_U = np.array( + [ + [ + -0.47541806 + 0.00045878j, + -0.41513474 - 0.27218387j, + -0.11065812 - 0.39556922j, + -0.29912017 + 0.51900235j, + ], + [ + -0.05246398 + 0.5209089j, + -0.29650069 - 0.40653082j, + 0.57434638 - 0.04417284j, + 0.28230532 - 0.24738672j, + ], + [ + 0.28437557 + 0.08773767j, + 0.18377764 - 0.66496587j, + -0.5874942 - 0.19866946j, + 0.2010813 - 0.10210844j, + ], + [ + -0.63173183 - 0.11057324j, + -0.03468292 + 0.15245454j, + -0.25390362 - 0.2244298j, + 0.18706333 - 0.64375049j, + ], + ] + ) + perturbed = ( + Interferometer((0, 1, 2, 3), unitary=solution_U) + >> BSgate((0, 1), settings.rng.normal(scale=0.01)) + >> BSgate((2, 3), settings.rng.normal(scale=0.01)) + >> BSgate((1, 2), settings.rng.normal(scale=0.01)) + >> BSgate((0, 3), settings.rng.normal(scale=0.01)) + ) + X = perturbed.symplectic[0] + perturbed_U = X[:4, :4] + 1j * X[4:, :4] + + state_in = Vacuum((0, 1, 2, 3)) + s_gate = Sgate( + (0, 1, 2, 3), + r=settings.rng.normal(loc=np.arcsinh(1.0), scale=0.01, size=4), + r_trainable=True, + ) + interferometer = Interferometer((0, 1, 2, 3), unitary=perturbed_U, unitary_trainable=True) + + circ = Circuit([state_in, s_gate, interferometer]) + + def cost_fn(): + amps = circ.contract().fock_array((3, 3, 3, 3)) + return -math.abs((amps[1, 1, 2, 0] + amps[1, 1, 0, 2]) / np.sqrt(2)) ** 2 + + opt = Optimizer(unitary_lr=0.05) + opt.minimize(cost_fn, by_optimizing=[circ], max_steps=200) + assert np.allclose(-cost_fn(), 0.0625, atol=1e-5) + + def test_learning_four_mode_RealInterferometer(self): + """Finding the optimal Interferometer to make a NOON state with N=2""" + skip_np() + + settings.SEED = 6 + rng = tf.random.get_global_generator() + rng.reset_from_seed(settings.SEED) + + solution_O = np.array( + [ + [0.5, -0.5, 0.5, 0.5], + [-0.5, -0.5, -0.5, 0.5], + [0.5, 0.5, -0.5, 0.5], + [0.5, -0.5, -0.5, -0.5], + ] + ) + solution_S = (np.arcsinh(1.0), np.array([0.0, np.pi / 2, -np.pi, -np.pi / 2])) + pertubed = ( + RealInterferometer((0, 1, 2, 3), orthogonal=solution_O) + >> BSgate((0, 1), settings.rng.normal(scale=0.01)) + >> BSgate((2, 3), settings.rng.normal(scale=0.01)) + >> BSgate((1, 2), settings.rng.normal(scale=0.01)) + >> BSgate((0, 3), settings.rng.normal(scale=0.01)) + ) + perturbed_O = pertubed.symplectic[0][:4, :4] + + state_in = Vacuum((0, 1, 2, 3)) + s_gate = Sgate( + (0, 1, 2, 3), + r=solution_S[0] + settings.rng.normal(scale=0.01, size=4), + phi=solution_S[1] + settings.rng.normal(scale=0.01, size=4), + r_trainable=True, + phi_trainable=True, + ) + r_inter = RealInterferometer( + (0, 1, 2, 3), orthogonal=perturbed_O, orthogonal_trainable=True + ) + + circ = Circuit([state_in, s_gate, r_inter]) + + def cost_fn(): + amps = circ.contract().fock_array((2, 2, 3, 3)) + return -math.abs((amps[1, 1, 0, 2] + amps[1, 1, 2, 0]) / np.sqrt(2)) ** 2 + + opt = Optimizer() + + opt.minimize(cost_fn, by_optimizing=[circ], max_steps=200) + assert np.allclose(-cost_fn(), 0.0625, atol=1e-5) + + def test_squeezing_hong_ou_mandel_optimizer(self): + """Finding the optimal squeezing parameter to get Hong-Ou-Mandel dip in time + see https://www.pnas.org/content/117/52/33107/tab-article-info + """ + skip_np() + + settings.SEED = 42 + rng = tf.random.get_global_generator() + rng.reset_from_seed(settings.SEED) + + r = np.arcsinh(1.0) + + state_in = Vacuum((0, 1, 2, 3)) + S_01 = S2gate((0, 1), r=r, phi=0.0, phi_trainable=True) + S_23 = S2gate((2, 3), r=r, phi=0.0, phi_trainable=True) + S_12 = S2gate( + (1, 2), r=1.0, phi=settings.rng.normal(), r_trainable=True, phi_trainable=True + ) + + circ = Circuit([state_in, S_01, S_23, S_12]) + + def cost_fn(): + return math.abs(circ.contract().fock_array((2, 2, 2, 2))[1, 1, 1, 1]) ** 2 + + opt = Optimizer(euclidean_lr=0.001) + opt.minimize(cost_fn, by_optimizing=[circ], max_steps=300) + assert np.allclose(np.sinh(S_12.parameters.r.value) ** 2, 1, atol=1e-2) + + def test_parameter_passthrough(self): + """Same as the test above, but with param passthrough""" + skip_np() + + settings.SEED = 42 + rng = tf.random.get_global_generator() + rng.reset_from_seed(settings.SEED) + + r = np.arcsinh(1.0) + r_var = Variable(r, "r", (0.0, None)) + phi_var = Variable(settings.rng.normal(), "phi", (None, None)) + + state_in = Vacuum((0, 1, 2, 3)) + s2_gate0 = S2gate((0, 1), r=r, phi=0.0, phi_trainable=True) + s2_gate1 = S2gate((2, 3), r=r, phi=0.0, phi_trainable=True) + s2_gate2 = S2gate((1, 2), r=r_var, phi=phi_var) + + circ = Circuit([state_in, s2_gate0, s2_gate1, s2_gate2]) + + def cost_fn(): + return math.abs(circ.contract().fock_array((2, 2, 2, 2))[1, 1, 1, 1]) ** 2 + + opt = Optimizer(euclidean_lr=0.001) + opt.minimize(cost_fn, by_optimizing=[r_var, phi_var], max_steps=300) + assert np.allclose(np.sinh(r_var.value) ** 2, 1, atol=1e-2) + + def test_making_thermal_state_as_one_half_two_mode_squeezed_vacuum(self): + """Optimizes a Ggate on two modes so as to prepare a state with the same entropy + and mean photon number as a thermal state""" + skip_np() + + settings.SEED = 42 + rng = tf.random.get_global_generator() + rng.reset_from_seed(settings.SEED) + + def thermal_entropy(nbar): + return -(nbar * np.log((nbar) / (1 + nbar)) - np.log(1 + nbar)) + + nbar = 1.4 + S_init = two_mode_squeezing(np.arcsinh(1.0), 0.0) + S = thermal_entropy(nbar) + + G = Ggate((0, 1), symplectic=S_init, symplectic_trainable=True) + + def cost_fn(): + state = Vacuum((0, 1)) >> G + + state0 = state[0] + state1 = state[1] + + cov0, mean0, _ = [x[0] for x in state0.phase_space(s=0)] + cov1, mean1, _ = [x[0] for x in state1.phase_space(s=0)] + + num_mean0 = number_means(cov0, mean0)[0] + num_mean1 = number_means(cov1, mean1)[0] + + entropy = von_neumann_entropy(cov0) + return (num_mean0 - nbar) ** 2 + (entropy - S) ** 2 + (num_mean1 - nbar) ** 2 + + opt = Optimizer(symplectic_lr=0.1) + opt.minimize(cost_fn, by_optimizing=[G], max_steps=50) + S = math.asnumpy(G.parameters.symplectic.value) + cov = S @ S.T + assert np.allclose(cov, two_mode_squeezing(2 * np.arcsinh(np.sqrt(nbar)), 0.0)) + + def test_opt_backend_param(self): + """Test the optimization of a backend parameter defined outside a gate.""" + skip_np() + + # rotated displaced squeezed state + settings.SEED = 42 + rng = tf.random.get_global_generator() + rng.reset_from_seed(settings.SEED) + + rotation_angle = np.pi / 2 + target_state = SqueezedVacuum((0,), r=1.0, phi=rotation_angle) + + # angle of rotation gate + r_angle = math.new_variable(0, bounds=(0, np.pi), name="r_angle") + # trainable squeezing + S = Sgate((0,), r=0.1, phi=0, r_trainable=True, phi_trainable=False) + + def cost_fn_sympl(): + state_out = Vacuum((0,)) >> S >> Rgate((0,), theta=r_angle) + return 1 - math.abs((state_out >> target_state.dual) ** 2) + + opt = Optimizer(symplectic_lr=0.1, euclidean_lr=0.05) + opt.minimize(cost_fn_sympl, by_optimizing=[S, r_angle]) + + assert np.allclose(math.asnumpy(r_angle), rotation_angle / 2, atol=1e-4) + + def test_dgate_optimization(self): + """Test that Dgate is optimized correctly.""" + skip_np() + + settings.SEED = 24 + rng = tf.random.get_global_generator() + rng.reset_from_seed(settings.SEED) + + dgate = Dgate((0,), x_trainable=True, y_trainable=True) + target_state = DisplacedSqueezed((0,), r=0.0, x=0.1, y=0.2).fock_array((40,)) + + def cost_fn(): + state_out = Vacuum((0,)) >> dgate + return -math.abs(math.sum(math.conj(state_out.fock_array((40,))) * target_state)) ** 2 + + opt = Optimizer() + opt.minimize(cost_fn, by_optimizing=[dgate]) + + assert np.allclose(dgate.parameters.x.value, 0.1, atol=0.01) + assert np.allclose(dgate.parameters.y.value, 0.2, atol=0.01) + + def test_sgate_optimization(self): + """Test that Sgate is optimized correctly.""" + skip_np() + + settings.SEED = 25 + rng = tf.random.get_global_generator() + rng.reset_from_seed(settings.SEED) + + sgate = Sgate((0,), r=0.2, phi=0.1, r_trainable=True, phi_trainable=True) + target_state = SqueezedVacuum((0,), r=0.1, phi=0.2).fock_array((40,)) + + def cost_fn(): + state_out = Vacuum((0,)) >> sgate + + return -math.abs(math.sum(math.conj(state_out.fock_array((40,))) * target_state)) ** 2 + + opt = Optimizer() + opt.minimize(cost_fn, by_optimizing=[sgate]) + + assert np.allclose(sgate.parameters.r.value, 0.1, atol=0.01) + assert np.allclose(sgate.parameters.phi.value, 0.2, atol=0.01) + + def test_bsgate_optimization(self): + """Test that Sgate is optimized correctly.""" + skip_np() + + settings.SEED = 25 + rng = tf.random.get_global_generator() + rng.reset_from_seed(settings.SEED) + + G = GKet((0, 1)) + + bsgate = BSgate((0, 1), 0.05, 0.1, theta_trainable=True, phi_trainable=True) + target_state = (G >> BSgate((0, 1), 0.1, 0.2)).fock_array((40, 40)) + + def cost_fn(): + state_out = G >> bsgate + + return ( + -math.abs(math.sum(math.conj(state_out.fock_array((40, 40))) * target_state)) ** 2 + ) + + opt = Optimizer() + opt.minimize(cost_fn, by_optimizing=[bsgate]) + + assert np.allclose(bsgate.parameters.theta.value, 0.1, atol=0.01) + assert np.allclose(bsgate.parameters.phi.value, 0.2, atol=0.01) + + def test_squeezing_grad_from_fock(self): + """Test that the gradient of a squeezing gate is computed from the fock representation.""" + skip_np() + + squeezing = Sgate((0,), r=1.0, r_trainable=True) + og_r = math.asnumpy(squeezing.parameters.r.value) + + def cost_fn(): + return -((Number((0,), 2) >> squeezing >> Vacuum((0,)).dual) ** 2) + + opt = Optimizer(euclidean_lr=0.05) + opt.minimize(cost_fn, by_optimizing=[squeezing], max_steps=100) + + assert squeezing.parameters.r.value != og_r + + def test_displacement_grad_from_fock(self): + """Test that the gradient of a displacement gate is computed from the fock representation.""" + skip_np() + + disp = Dgate((0,), x=1.0, y=0.5, x_trainable=True, y_trainable=True) + og_x = math.asnumpy(disp.parameters.x.value) + og_y = math.asnumpy(disp.parameters.y.value) + + def cost_fn(): + return -((Number((0,), 2) >> disp >> Vacuum((0,)).dual) ** 2) + + opt = Optimizer(euclidean_lr=0.05) + opt.minimize(cost_fn, by_optimizing=[disp], max_steps=100) + assert og_x != disp.parameters.x.value + assert og_y != disp.parameters.y.value + + def test_bsgate_grad_from_fock(self): + """Test that the gradient of a beamsplitter gate is computed from the fock representation.""" + skip_np() + + sq = SqueezedVacuum((0,), r=1.0, r_trainable=True) + og_r = math.asnumpy(sq.parameters.r.value) + + def cost_fn(): + return -( + ( + sq + >> Number((1,), 1) + >> BSgate((0, 1), 0.5) + >> (Vacuum((0,)) >> Number((1,), 1)).dual + ) + ** 2 + ) + + opt = Optimizer(euclidean_lr=0.05) + opt.minimize(cost_fn, by_optimizing=[sq], max_steps=100) + + assert og_r != sq.parameters.r.value diff --git a/tests/test_training/test_trainer (Filippo Miatto's conflicted copy 2025-01-24).py b/tests/test_training/test_trainer (Filippo Miatto's conflicted copy 2025-01-24).py new file mode 100644 index 000000000..76902327f --- /dev/null +++ b/tests/test_training/test_trainer (Filippo Miatto's conflicted copy 2025-01-24).py @@ -0,0 +1,268 @@ +# Copyright 2022 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# pylint: disable=import-outside-toplevel + +"""Tests for the ray-based trainer.""" + +import sys +from time import sleep + +import numpy as np +import pytest + +try: + import ray + + ray_available = True + + NUM_CPUS = 1 + ray.init(num_cpus=NUM_CPUS) +except ImportError: + ray_available = False + +from mrmustard.lab_dev import Dgate, Ggate, GKet, Vacuum +from mrmustard.training import Optimizer +from mrmustard.training.trainer import map_trainer, train_device, update_pop + +from ..conftest import skip_np + + +def wrappers(): + """Dummy wrappers tested.""" + + def make_circ(x=0.0, return_type=None): + from mrmustard import math + + math.change_backend("tensorflow") + + circ = Ggate((0,), symplectic_trainable=True) >> Dgate( + (0,), x=x, x_trainable=True, y_trainable=True + ) + return ( + [circ] if return_type == "list" else {"circ": circ} if return_type == "dict" else circ + ) + + def cost_fn(circ=make_circ(0.1), y_targ=0.0): + from mrmustard import math + + math.change_backend("tensorflow") + + target = GKet((0,)) >> Dgate((0,), -0.1, y_targ) + s = Vacuum((0,)) >> circ + return -math.abs((s >> target.dual) ** 2) + + return make_circ, cost_fn + + +@pytest.mark.skipif(not ray_available, reason="ray is not available") +class TestTrainer: + """Class containinf ray-related tests.""" + + @pytest.mark.parametrize( + "tasks", + [5, [{"y_targ": 0.1}, {"y_targ": -0.2}], {"c0": {}, "c1": {"y_targ": 0.07}}], + ) + @pytest.mark.parametrize("seed", [None, 42]) + def test_circ_cost(self, tasks, seed): # pylint: disable=redefined-outer-name + """Test distributed cost calculations.""" + skip_np() + + has_seed = isinstance(seed, int) + _, cost_fn = wrappers() + results = map_trainer( + cost_fn=cost_fn, + tasks=tasks, + num_cpus=NUM_CPUS, + **({"SEED": seed} if has_seed else {}), + ) + + if isinstance(tasks, dict): + assert set(results.keys()) == set(tasks.keys()) + results = list(results.values()) + assert all(r["optimizer"] is None for r in results) + assert all(r["device"] == [] for r in results) + if has_seed and isinstance(tasks, int): + assert len(set(r["cost"] for r in results)) == 1 + else: + assert ( + len(set(r["cost"] for r in results)) + >= (tasks if isinstance(tasks, int) else len(tasks)) - 1 + ) + + @pytest.mark.parametrize( + "tasks", + [[{"x": 0.1}, {"y_targ": 0.2}], {"c0": {}, "c1": {"euclidean_lr": 0.02}}], + ) + @pytest.mark.parametrize( + "return_type", + [None, "dict"], + ) + def test_circ_optimize(self, tasks, return_type): # pylint: disable=redefined-outer-name + """Test distributed optimizations.""" + skip_np() + + max_steps = 15 + make_circ, cost_fn = wrappers() + results = map_trainer( + cost_fn=cost_fn, + device_factory=make_circ, + tasks=tasks, + max_steps=max_steps, + symplectic_lr=0.05, + return_type=return_type, + num_cpus=NUM_CPUS, + ) + + if isinstance(tasks, dict): + assert set(results.keys()) == set(tasks.keys()) + results = list(results.values()) + assert ( + len(set(r["cost"] for r in results)) + >= (tasks if isinstance(tasks, int) else len(tasks)) - 1 + ) + assert all(isinstance(r["optimizer"], Optimizer) for r in results) + assert all((r["optimizer"].opt_history) for r in results) + + # Check if optimization history is actually decreasing. + opt_history = np.array(results[0]["optimizer"].opt_history) + assert len(opt_history) == max_steps + 1 + assert opt_history[0] - opt_history[-1] > 1e-6 + assert (np.diff(opt_history) < 0).sum() >= 3 + + @pytest.mark.parametrize( + "metric_fns", + [ + lambda c: (Vacuum((0,)) >> c >> c >> c).fock_array((5,)), + ], + ) + def test_circ_optimize_metrics(self, metric_fns): # pylint: disable=redefined-outer-name + """Tests custom metric functions on final circuits.""" + skip_np() + + make_circ, cost_fn = wrappers() + + tasks = { + "my-job": {"x": 0.1, "euclidean_lr": 0.01, "max_steps": 100}, + "my-other-job": {"x": -0.7, "euclidean_lr": 0.1, "max_steps": 20}, + } + + results = map_trainer( + cost_fn=cost_fn, + device_factory=make_circ, + tasks=tasks, + y_targ=0.35, + symplectic_lr=0.05, + metric_fns=metric_fns, + return_list=True, + num_cpus=NUM_CPUS, + ) + + assert set(results.keys()) == set(tasks.keys()) + results = list(results.values()) + assert all( + ("metrics" in r or set(metric_fns.keys()).issubset(set(r.keys()))) for r in results + ) + assert ( + len(set(r["cost"] for r in results)) + >= (tasks if isinstance(tasks, int) else len(tasks)) - 1 + ) + assert all(isinstance(r["optimizer"], Optimizer) for r in results) + assert all((r["optimizer"].opt_history) for r in results) + + # Check if optimization history is actually decreasing. + opt_history = np.array(results[0]["optimizer"].opt_history) + assert opt_history[1] - opt_history[-1] > 1e-6 + + def test_update_pop(self): + """Test for coverage.""" + skip_np() + + d = {"a": 3, "b": "foo"} + kwargs = {"b": "bar", "c": 22} + d1, kwargs = update_pop(d, **kwargs) + assert d1["b"] == "bar" + assert len(kwargs) == 1 + + def test_no_ray(self, monkeypatch): + """Tests ray import error""" + skip_np() + + monkeypatch.setitem(sys.modules, "ray", None) + with pytest.raises(ImportError, match="Failed to import `ray`"): + _ = map_trainer( + tasks=2, + num_cpus=NUM_CPUS, + ) + + def test_invalid_tasks(self): + """Tests unexpected tasks arg""" + skip_np() + + with pytest.raises( + ValueError, match="`tasks` is expected to be of type int, list, or dict." + ): + _ = map_trainer( + tasks=2.3, + num_cpus=NUM_CPUS, + ) + + def test_warn_unused_kwargs(self): # pylint: disable=redefined-outer-name + """Test warning of unused kwargs""" + skip_np() + + _, cost_fn = wrappers() + with pytest.warns(UserWarning, match="Unused kwargs:"): + results = train_device( + cost_fn=cost_fn, + foo="bar", + ) + assert len(results) >= 4 + assert isinstance(results["cost"], float) + + def test_no_pbar(self): # pylint: disable=redefined-outer-name + """Test turning off pregress bar""" + skip_np() + + _, cost_fn = wrappers() + results = map_trainer( + cost_fn=cost_fn, + tasks=2, + pbar=False, + num_cpus=NUM_CPUS, + ) + assert len(results) == 2 + + @pytest.mark.parametrize("tasks", [2, {"c0": {}, "c1": {"y_targ": -0.7}}]) + def test_unblock(self, tasks): # pylint: disable=redefined-outer-name + """Test unblock async mode""" + skip_np() + + _, cost_fn = wrappers() + result_getter = map_trainer( + cost_fn=cost_fn, + tasks=tasks, + unblock=True, + num_cpus=NUM_CPUS, + ) + assert callable(result_getter) + + sleep(0.2) + results = result_getter() + if len(results) <= (tasks if isinstance(tasks, int) else len(tasks)): + # safer on slower machines + sleep(1) + results = result_getter() + + assert len(results) == (tasks if isinstance(tasks, int) else len(tasks))