diff --git a/horqrux/shots.py b/horqrux/shots.py index b7068f9..982ecf3 100644 --- a/horqrux/shots.py +++ b/horqrux/shots.py @@ -1,11 +1,11 @@ from __future__ import annotations -from functools import reduce +from functools import partial, reduce from typing import Any import jax import jax.numpy as jnp -from jax import Array +from jax import Array, random from jax.experimental import checkify from horqrux.apply import apply_gate @@ -33,6 +33,7 @@ def observable_to_matrix(observable: Primitive, n_qubits: int) -> Array: return reduce(lambda x, y: jnp.kron(x, y), ops[1:], ops[0]) +@partial(jax.custom_jvp, nondiff_argnums=(0, 1, 2, 4, 5)) def finite_shots_fwd( state: Array, gates: GateSequence, @@ -52,3 +53,38 @@ def finite_shots_fwd( inner_prod = jnp.matmul(jnp.conjugate(eigvecs.T), state.flatten()) probs = jnp.abs(inner_prod) ** 2 return jax.random.choice(key=key, a=eigvals, p=probs, shape=(n_shots,)).mean() + + +@finite_shots_fwd.defjvp +def finite_shots_jvp( + state: Array, + gates: GateSequence, + observable: Primitive, + n_shots: int, + key: Array, + primals: tuple[dict[str, float]], + tangents: tuple[dict[str, float]], +) -> Array: + values = primals[0] + tangent_dict = tangents[0] + + # TODO: compute spectral gap through the generator which is associated with + # a param name. + spectral_gap = 2.0 + shift = jnp.pi / 2 + + def jvp_component(param_name: str, key: Array) -> Array: + up_key, down_key = random.split(key) + up_val = values.copy() + up_val[param_name] = up_val[param_name] + shift + f_up = finite_shots_fwd(state, gates, observable, up_val, n_shots, up_key) + down_val = values.copy() + down_val[param_name] = down_val[param_name] - shift + f_down = finite_shots_fwd(state, gates, observable, down_val, n_shots, down_key) + grad = spectral_gap * (f_up - f_down) / (4.0 * jnp.sin(spectral_gap * shift / 2.0)) + return grad * tangent_dict[param_name] + + params_with_keys = zip(values.keys(), random.split(key, len(values))) + fwd = finite_shots_fwd(state, gates, observable, values, n_shots, key) + jvp = sum(jvp_component(param, key) for param, key in params_with_keys) + return fwd, jvp.reshape(fwd.shape) diff --git a/tests/test_shots.py b/tests/test_shots.py index 1410899..831b3d7 100644 --- a/tests/test_shots.py +++ b/tests/test_shots.py @@ -1,5 +1,6 @@ from __future__ import annotations +import jax import jax.numpy as jnp from horqrux import expectation, random_state @@ -13,11 +14,27 @@ def test_shots() -> None: ops = [RX("theta", 0)] - observable = Z(2) - values = {p: jnp.ones(1).item() for p in ["theta"]} + observable = Z(0) state = random_state(N_QUBITS) + x = jnp.pi * 0.5 - exp_exact = expectation(state, ops, observable, values, "ad") - exp_shots = expectation(state, ops, observable, values, "gpsr", "shots", n_shots=N_SHOTS) + def exact(x): + values = {"theta": x} + return expectation(state, ops, observable, values, "ad") + + def shots(x): + values = {"theta": x} + return expectation(state, ops, observable, values, "gpsr", "shots", n_shots=N_SHOTS) + + exp_exact = exact(x) + exp_shots = exact(x) assert jnp.isclose(exp_exact, exp_shots, atol=SHOTS_ATOL) + + d_exact = jax.grad(exact) + d_shots = jax.grad(shots) + + grad_backprop = d_exact(x) + grad_shots = d_shots(x) + + assert jnp.isclose(grad_backprop, grad_shots, atol=SHOTS_ATOL)