From 7e36f8361363adc1c96bced796e13613d3c78902 Mon Sep 17 00:00:00 2001 From: Anthony Date: Mon, 16 Dec 2024 15:29:39 -0500 Subject: [PATCH] most working --- mrmustard/lab_dev/states/number.py | 4 +- mrmustard/physics/ansatz/array_ansatz.py | 17 +++++++- tests/test_training/test_opt.py | 53 +++++++++++++----------- 3 files changed, 45 insertions(+), 29 deletions(-) diff --git a/mrmustard/lab_dev/states/number.py b/mrmustard/lab_dev/states/number.py index 9f55c72b5..032627486 100644 --- a/mrmustard/lab_dev/states/number.py +++ b/mrmustard/lab_dev/states/number.py @@ -74,9 +74,7 @@ def __init__( self._add_parameter(make_parameter(False, cs, "cutoffs", (None, None))) self._representation = self.from_ansatz( modes=modes, - ansatz=ArrayAnsatz.from_function( - fock_state, n=self.n.value, cutoffs=self.cutoffs.value - ), + ansatz=ArrayAnsatz.from_function(fock_state, n=self.n, cutoffs=self.cutoffs), ).representation self.short_name = [str(int(n)) for n in self.n.value] for i, cutoff in enumerate(self.cutoffs.value): diff --git a/mrmustard/physics/ansatz/array_ansatz.py b/mrmustard/physics/ansatz/array_ansatz.py index d792957ef..4f1308b80 100644 --- a/mrmustard/physics/ansatz/array_ansatz.py +++ b/mrmustard/physics/ansatz/array_ansatz.py @@ -27,6 +27,7 @@ from IPython.display import display from mrmustard import math, widgets +from mrmustard.math.parameters import Variable from mrmustard.utils.typing import Batch, Scalar, Tensor, Vector from .base import Ansatz @@ -203,8 +204,20 @@ def trace(self, idx_z: tuple[int, ...], idx_zconj: tuple[int, ...]) -> ArrayAnsa return ArrayAnsatz([trace] if trace.shape == () else trace, batched=True) def _generate_ansatz(self): - if self._array is None: - self.array = [self._fn(**self._kwargs)] + names = list(self._kwargs.keys()) + vars = list(self._kwargs.values()) + + params = {} + param_types = [] + for name, param in zip(names, vars): + try: + params[name] = param.value + param_types.append(type(param)) + except AttributeError: + params[name] = param + + if self._array is None or Variable in param_types: + self.array = [self._fn(**params)] def _ipython_display_(self): if widgets.IN_INTERACTIVE_SHELL or (w := widgets.fock(self)) is None: diff --git a/tests/test_training/test_opt.py b/tests/test_training/test_opt.py index 9883bff93..c14497de3 100644 --- a/tests/test_training/test_opt.py +++ b/tests/test_training/test_opt.py @@ -531,38 +531,43 @@ def cost_fn(): assert np.allclose(bsgate.theta.value, 0.1, atol=0.01) assert np.allclose(bsgate.phi.value, 0.2, atol=0.01) - # def test_squeezing_grad_from_fock(self): - # """Test that the gradient of a squeezing gate is computed from the fock representation.""" - # skip_np() + def test_squeezing_grad_from_fock(self): + """Test that the gradient of a squeezing gate is computed from the fock representation.""" + skip_np() - # squeezing = Sgate((0,), r=1, r_trainable=True) + squeezing = Sgate((0,), r=1.0, r_trainable=True) - # def cost_fn(): - # return -(Number((0,), 2) >> squeezing >> Vacuum((0,)).dual) + def cost_fn(): + return -(Number((0,), 2) >> squeezing >> Vacuum((0,)).dual) - # opt = Optimizer(euclidean_lr=0.05) - # opt.minimize(cost_fn, by_optimizing=[squeezing], max_steps=100) + opt = Optimizer(euclidean_lr=0.05) + opt.minimize(cost_fn, by_optimizing=[squeezing], max_steps=100) - # def test_displacement_grad_from_fock(self): - # """Test that the gradient of a displacement gate is computed from the fock representation.""" - # skip_np() + def test_displacement_grad_from_fock(self): + """Test that the gradient of a displacement gate is computed from the fock representation.""" + skip_np() - # disp = Dgate(x=1.0, y=1.0, x_trainable=True, y_trainable=True) + disp = Dgate((0,), x=1.0, y=1.0, x_trainable=True, y_trainable=True) - # def cost_fn(): - # return -(Fock(2) >> disp << Vacuum(1)) + def cost_fn(): + return -(Number((0,), 2) >> disp >> Vacuum((0,)).dual) - # opt = Optimizer(euclidean_lr=0.05) - # opt.minimize(cost_fn, by_optimizing=[disp], max_steps=100) + opt = Optimizer(euclidean_lr=0.05) + opt.minimize(cost_fn, by_optimizing=[disp], max_steps=100) - # def test_bsgate_grad_from_fock(self): - # """Test that the gradient of a beamsplitter gate is computed from the fock representation.""" - # skip_np() + def test_bsgate_grad_from_fock(self): + """Test that the gradient of a beamsplitter gate is computed from the fock representation.""" + skip_np() - # sq = SqueezedVacuum(r=1.0, r_trainable=True) + sq = SqueezedVacuum((0,), r=1.0, r_trainable=True) - # def cost_fn(): - # return -((sq & Fock(1)) >> BSgate(0.5) << (Vacuum(1) & Fock(1))) + def cost_fn(): + return -( + sq + >> Number((1,), 1) + >> BSgate((0, 1), 0.5) + >> (Vacuum((0,)) >> Number((1,), 1)).dual + ) - # opt = Optimizer(euclidean_lr=0.05) - # opt.minimize(cost_fn, by_optimizing=[sq], max_steps=100) + opt = Optimizer(euclidean_lr=0.05) + opt.minimize(cost_fn, by_optimizing=[sq], max_steps=100)