Skip to content

Commit

Permalink
Add JVP implementation of parameter shift rule (#27)
Browse files Browse the repository at this point in the history
* add jvp implementation of parameter shift rule

* add type annotations to jvp functions. apply black.
  • Loading branch information
atiyo authored Sep 9, 2024
1 parent 2824f4b commit 77084a5
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 6 deletions.
40 changes: 38 additions & 2 deletions horqrux/shots.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
from __future__ import annotations

from functools import reduce
from functools import partial, reduce
from typing import Any

import jax
import jax.numpy as jnp
from jax import Array
from jax import Array, random
from jax.experimental import checkify

from horqrux.apply import apply_gate
Expand Down Expand Up @@ -33,6 +33,7 @@ def observable_to_matrix(observable: Primitive, n_qubits: int) -> Array:
return reduce(lambda x, y: jnp.kron(x, y), ops[1:], ops[0])


@partial(jax.custom_jvp, nondiff_argnums=(0, 1, 2, 4, 5))
def finite_shots_fwd(
state: Array,
gates: GateSequence,
Expand All @@ -52,3 +53,38 @@ def finite_shots_fwd(
inner_prod = jnp.matmul(jnp.conjugate(eigvecs.T), state.flatten())
probs = jnp.abs(inner_prod) ** 2
return jax.random.choice(key=key, a=eigvals, p=probs, shape=(n_shots,)).mean()


@finite_shots_fwd.defjvp
def finite_shots_jvp(
state: Array,
gates: GateSequence,
observable: Primitive,
n_shots: int,
key: Array,
primals: tuple[dict[str, float]],
tangents: tuple[dict[str, float]],
) -> Array:
values = primals[0]
tangent_dict = tangents[0]

# TODO: compute spectral gap through the generator which is associated with
# a param name.
spectral_gap = 2.0
shift = jnp.pi / 2

def jvp_component(param_name: str, key: Array) -> Array:
up_key, down_key = random.split(key)
up_val = values.copy()
up_val[param_name] = up_val[param_name] + shift
f_up = finite_shots_fwd(state, gates, observable, up_val, n_shots, up_key)
down_val = values.copy()
down_val[param_name] = down_val[param_name] - shift
f_down = finite_shots_fwd(state, gates, observable, down_val, n_shots, down_key)
grad = spectral_gap * (f_up - f_down) / (4.0 * jnp.sin(spectral_gap * shift / 2.0))
return grad * tangent_dict[param_name]

params_with_keys = zip(values.keys(), random.split(key, len(values)))
fwd = finite_shots_fwd(state, gates, observable, values, n_shots, key)
jvp = sum(jvp_component(param, key) for param, key in params_with_keys)
return fwd, jvp.reshape(fwd.shape)
25 changes: 21 additions & 4 deletions tests/test_shots.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import jax
import jax.numpy as jnp

from horqrux import expectation, random_state
Expand All @@ -13,11 +14,27 @@

def test_shots() -> None:
ops = [RX("theta", 0)]
observable = Z(2)
values = {p: jnp.ones(1).item() for p in ["theta"]}
observable = Z(0)
state = random_state(N_QUBITS)
x = jnp.pi * 0.5

exp_exact = expectation(state, ops, observable, values, "ad")
exp_shots = expectation(state, ops, observable, values, "gpsr", "shots", n_shots=N_SHOTS)
def exact(x):
values = {"theta": x}
return expectation(state, ops, observable, values, "ad")

def shots(x):
values = {"theta": x}
return expectation(state, ops, observable, values, "gpsr", "shots", n_shots=N_SHOTS)

exp_exact = exact(x)
exp_shots = exact(x)

assert jnp.isclose(exp_exact, exp_shots, atol=SHOTS_ATOL)

d_exact = jax.grad(exact)
d_shots = jax.grad(shots)

grad_backprop = d_exact(x)
grad_shots = d_shots(x)

assert jnp.isclose(grad_backprop, grad_shots, atol=SHOTS_ATOL)

0 comments on commit 77084a5

Please sign in to comment.