Skip to content

Commit

Permalink
some asserts
Browse files Browse the repository at this point in the history
  • Loading branch information
apchytr committed Dec 18, 2024
1 parent 73d7da8 commit d37c04a
Showing 1 changed file with 11 additions and 1 deletion.
12 changes: 11 additions & 1 deletion tests/test_training/test_opt.py
Original file line number Diff line number Diff line change
Expand Up @@ -534,30 +534,38 @@ def test_squeezing_grad_from_fock(self):
skip_np()

squeezing = Sgate((0,), r=1.0, r_trainable=True)
og_r = math.asnumpy(squeezing.r.value)

def cost_fn():
return -(Number((0,), 2) >> squeezing >> Vacuum((0,)).dual)

opt = Optimizer(euclidean_lr=0.05)
opt.minimize(cost_fn, by_optimizing=[squeezing], max_steps=100)

assert squeezing.r.value != og_r

def test_displacement_grad_from_fock(self):
"""Test that the gradient of a displacement gate is computed from the fock representation."""
skip_np()

disp = Dgate((0,), x=1.0, y=1.0, x_trainable=True, y_trainable=True)
disp = Dgate((0,), x=1.0, y=0.5, x_trainable=True, y_trainable=True)
og_x = math.asnumpy(disp.x.value)
og_y = math.asnumpy(disp.y.value)

def cost_fn():
return -(Number((0,), 2) >> disp >> Vacuum((0,)).dual)

opt = Optimizer(euclidean_lr=0.05)
opt.minimize(cost_fn, by_optimizing=[disp], max_steps=100)
assert og_x != disp.x.value
assert og_y != disp.y.value

def test_bsgate_grad_from_fock(self):
"""Test that the gradient of a beamsplitter gate is computed from the fock representation."""
skip_np()

sq = SqueezedVacuum((0,), r=1.0, r_trainable=True)
og_r = math.asnumpy(sq.r.value)

def cost_fn():
return -(
Expand All @@ -569,3 +577,5 @@ def cost_fn():

opt = Optimizer(euclidean_lr=0.05)
opt.minimize(cost_fn, by_optimizing=[sq], max_steps=100)

assert og_r != sq.r.value

0 comments on commit d37c04a

Please sign in to comment.