diff --git a/horqrux/api.py b/horqrux/api.py index 2eeb464..636e7a7 100644 --- a/horqrux/api.py +++ b/horqrux/api.py @@ -6,7 +6,6 @@ import jax import jax.numpy as jnp from jax import Array -from jax.experimental import checkify from horqrux.adjoint import adjoint_expectation from horqrux.apply import apply_gate @@ -96,13 +95,4 @@ def expectation( elif diff_mode == DiffMode.ADJOINT: return adjoint_expectation(state, gates, observables, values) elif diff_mode == DiffMode.GPSR: - checkify.check( - forward_mode == ForwardMode.SHOTS, "Finite shots and GPSR must be used together" - ) - checkify.check( - type(n_shots) is int, - "Number of shots must be an integer for finite shots.", - ) - # Type checking is disabled because mypy doesn't parse checkify.check. - # type: ignore - return finite_shots_fwd(state, gates, observables, values, n_shots=n_shots, key=key) + return finite_shots_fwd(state=state, gates=gates, observables=observables, values=values, n_shots=n_shots, key=key) diff --git a/horqrux/apply.py b/horqrux/apply.py index 5bd054c..7c893ef 100644 --- a/horqrux/apply.py +++ b/horqrux/apply.py @@ -112,6 +112,8 @@ def apply_gate( op_type: OperationType = OperationType.UNITARY, group_gates: bool = False, # Defaulting to False since this can be performed once before circuit execution merge_ops: bool = True, + shift_up_gates=jnp.array([], dtype=int), + shift_down_gates=jnp.array([], dtype=int), ) -> State: """Wrapper function for 'apply_operator' which applies a gate or a series of gates to a given state. Arguments: @@ -126,13 +128,17 @@ def apply_gate( State after applying 'gate'. """ operator: Tuple[Array, ...] + + def gate_shift(index): + return jnp.pi * ((shift_up_gates == index).sum() - (shift_down_gates == index).sum()) / 2 + if isinstance(gate, Primitive): operator_fn = getattr(gate, op_type) - operator, target, control = (operator_fn(values),), gate.target, gate.control + operator, target, control = (operator_fn(values, gate_shift(0)),), gate.target, gate.control else: if group_gates: gate = group_by_index(gate) - operator = tuple(getattr(g, op_type)(values) for g in gate) + operator = tuple(getattr(g, op_type)(values, gate_shift(i)) for i, g in enumerate(gate)) target = reduce(add, [g.target for g in gate]) control = reduce(add, [g.control for g in gate]) if merge_ops: diff --git a/horqrux/parametric.py b/horqrux/parametric.py index bd5d488..613da6a 100644 --- a/horqrux/parametric.py +++ b/horqrux/parametric.py @@ -57,8 +57,8 @@ def __iter__(self) -> Iterable: def tree_unflatten(cls, aux_data: Any, children: Any) -> Any: return cls(*children, *aux_data) - def unitary(self, values: dict[str, float] = dict()) -> Array: - return _unitary(OPERATIONS_DICT[self.generator_name], self.parse_values(values)) + def unitary(self, values: dict[str, float] = dict(), shift: float = 0.0) -> Array: + return _unitary(OPERATIONS_DICT[self.generator_name], self.parse_values(values) + shift) def jacobian(self, values: dict[str, float] = dict()) -> Array: return _jacobian(OPERATIONS_DICT[self.generator_name], self.parse_values(values)) diff --git a/horqrux/primitive.py b/horqrux/primitive.py index 2aac3a4..b5efc98 100644 --- a/horqrux/primitive.py +++ b/horqrux/primitive.py @@ -58,7 +58,7 @@ def tree_flatten(self) -> Tuple[Tuple, Tuple[str, TargetQubits, ControlQubits]]: def tree_unflatten(cls, aux_data: Any, children: Any) -> Any: return cls(*children, *aux_data) - def unitary(self, values: dict[str, float] = dict()) -> Array: + def unitary(self, values: dict[str, float] = dict(), shift: float = 0.0) -> Array: return OPERATIONS_DICT[self.generator_name] def dagger(self, values: dict[str, float] = dict()) -> Array: diff --git a/horqrux/shots.py b/horqrux/shots.py index 4383100..0505885 100644 --- a/horqrux/shots.py +++ b/horqrux/shots.py @@ -2,15 +2,16 @@ from functools import partial, reduce from typing import Any +from jax.custom_derivatives import SymbolicZero import jax import jax.numpy as jnp from jax import Array, random -from jax.experimental import checkify +from jax import lax from horqrux.apply import apply_gate +from horqrux.parametric import Parametric from horqrux.primitive import GateSequence, Primitive -from horqrux.utils import none_like def observable_to_matrix(observable: Primitive, n_qubits: int) -> Array: @@ -21,10 +22,6 @@ def observable_to_matrix(observable: Primitive, n_qubits: int) -> Array: LIMITATION: currently only works for observables which are not controlled. """ - checkify.check( - observable.control == observable.parse_idx(none_like(observable.target)), - "Controlled gates cannot be promoted from observables to operations on the whole state vector", - ) unitary = observable.unitary() target = observable.target[0][0] identity = jnp.eye(2, dtype=unitary.dtype) @@ -33,7 +30,8 @@ def observable_to_matrix(observable: Primitive, n_qubits: int) -> Array: return reduce(lambda x, y: jnp.kron(x, y), ops[1:], ops[0]) -@partial(jax.custom_jvp, nondiff_argnums=(0, 1, 2, 4, 5)) +# @jax.custom_jvp +@partial(jax.custom_jvp, nondiff_argnums=(4,)) def finite_shots_fwd( state: Array, gates: GateSequence, @@ -41,19 +39,22 @@ def finite_shots_fwd( values: dict[str, float], n_shots: int = 100, key: Any = jax.random.PRNGKey(0), + shift_up_gates=jnp.array([], dtype=int), + shift_down_gates=jnp.array([], dtype=int), ) -> Array: """ Run 'state' through a sequence of 'gates' given parameters 'values' and compute the expectation given an observable. """ - state = apply_gate(state, gates, values) + state = apply_gate(state=state, gate=gates, values=values, shift_up_gates=shift_up_gates, + shift_down_gates=shift_down_gates) n_qubits = len(state.shape) mat_obs = [observable_to_matrix(observable, n_qubits) for observable in observables] eigs = [jnp.linalg.eigh(mat) for mat in mat_obs] eigvecs, eigvals = align_eigenvectors(eigs) inner_prod = jnp.matmul(jnp.conjugate(eigvecs.T), state.flatten()) probs = jnp.abs(inner_prod) ** 2 - return jax.random.choice(key=key, a=eigvals, p=probs, shape=(n_shots,)).mean(axis=0) + return jax.random.choice(key, eigvals, (n_shots,), True, probs).mean(axis=0) def align_eigenvectors(eigs: list[tuple[Array, Array]]) -> tuple[Array, Array]: @@ -79,10 +80,6 @@ def align_eigenvectors(eigs: list[tuple[Array, Array]]) -> tuple[Array, Array]: for mat in eigs_copy: inv = jnp.linalg.inv(mat[1]) P = (inv @ eigenvector_matrix).real > 0.5 - checkify.check( - validate_permutation_matrix(P), - "Did not calculate valid permutation matrix", - ) eigenvalues.append(mat[0] @ P) return eigenvector_matrix, jnp.stack(eigenvalues, axis=1) @@ -94,36 +91,75 @@ def validate_permutation_matrix(P: Array) -> Array: return ((ones == rows) & (ones == columns)).min() -@finite_shots_fwd.defjvp +def get_shifted_gate(gate, values, shift): + current_angle = gate.parse_values(values) + new_angle = current_angle + shift + return Parametric(gate.generator_name, gate.target[0], gate.control[0], new_angle) + + +@partial(finite_shots_fwd.defjvp, symbolic_zeros=True) def finite_shots_jvp( - state: Array, - gates: GateSequence, - observable: Primitive, - n_shots: int, - key: Array, - primals: tuple[dict[str, float]], - tangents: tuple[dict[str, float]], + n_shots, + primals: tuple[list[Primitive], dict[str, float]], + tangents: tuple[list[Primitive], dict[str, float]] ) -> Array: - values = primals[0] - tangent_dict = tangents[0] - - # TODO: compute spectral gap through the generator which is associated with - # a param name. - spectral_gap = 2.0 - shift = jnp.pi / 2 - - def jvp_component(param_name: str, key: Array) -> Array: - up_key, down_key = random.split(key) - up_val = values.copy() - up_val[param_name] = up_val[param_name] + shift - f_up = finite_shots_fwd(state, gates, observable, up_val, n_shots, up_key) - down_val = values.copy() - down_val[param_name] = down_val[param_name] - shift - f_down = finite_shots_fwd(state, gates, observable, down_val, n_shots, down_key) - grad = spectral_gap * (f_up - f_down) / (4.0 * jnp.sin(spectral_gap * shift / 2.0)) - return grad * tangent_dict[param_name] - - params_with_keys = zip(values.keys(), random.split(key, len(values))) - fwd = finite_shots_fwd(state, gates, observable, values, n_shots, key) - jvp = sum(jvp_component(param, key) for param, key in params_with_keys) - return fwd, jvp.reshape(fwd.shape) + + # state: Array, + # gates: GateSequence, + # observables: list[Primitive], + # values: dict[str, float], + # n_shots: int = 100, + # key: Any = jax.random.PRNGKey(0), + # shift_up_gates=jnp.array([], dtype=int), + # shift_down_gates=jnp.array([], dtype=int), + + state, gates, observables, values, key, shift_up_gates, shift_down_gates = primals + fwd = finite_shots_fwd(state, gates, observables, values, n_shots, + key, shift_up_gates, shift_down_gates) + zero = jnp.zeros_like(fwd) + jvp = jnp.zeros_like(fwd) + + gate_tangents = [gate.param if isinstance( + gate, Parametric) else None for gate in tangents[1]] + gate_tangents = [tangents[3][param] if isinstance( + param, str) else param for param in gate_tangents] + gate_tangents = [tangent if not isinstance( + tangent, SymbolicZero) else None for tangent in gate_tangents] + gate_tangents = [tangent if isinstance( + tangent, jax.Array) else None for tangent in gate_tangents] + + def parametric_gradient_at_i(i, primals, n_shots): + state, gates, observables, values, key, shift_up_gates, shift_down_gates = primals + base_key = random.split(key, len(gates))[i] + up_key, down_key = random.split(base_key, 2) + new_shift_up_gates = jnp.append(shift_up_gates, i) + new_shift_down_gates = jnp.append(shift_down_gates, i) + f_up = finite_shots_fwd(state, gates, observables, values, n_shots, + up_key, new_shift_up_gates, shift_down_gates) + f_down = finite_shots_fwd(state, gates, observables, values, n_shots, + down_key, shift_up_gates, new_shift_down_gates) + shift = jnp.pi/2 + spectral_gap = 2.0 + return spectral_gap * (f_up - f_down) / (4.0 * jnp.sin(spectral_gap * shift / 2.0)) + + # def loop_func(i, carry): + # primals, zero, gate_tangents, jvp = carry + # jvp_component = lax.cond(jnp.isnan(gate_tangents[i]), + # lambda *args: zero, + # lambda i, primals, tangent: tangent[i] * + # parametric_gradient_at_i(i, primals), + # i, + # primals, + # gate_tangents, + # ) + # return primals, zero, gate_tangents, jvp + jvp_component + + for i, _ in enumerate(gates): + if gate_tangents[i] is None: + continue + jvp = jvp + gate_tangents[i] * parametric_gradient_at_i(i, primals, n_shots) + + # init_carry = primals, zero, gate_tangents, jvp + # loop_out = lax.fori_loop(0, gate_tangents.shape[0], loop_func, init_carry) + + return fwd, jvp diff --git a/tests/test_shots.py b/tests/test_shots.py index c98062d..a26c9d9 100644 --- a/tests/test_shots.py +++ b/tests/test_shots.py @@ -1,5 +1,7 @@ from __future__ import annotations +import functools + import jax import jax.numpy as jnp @@ -13,28 +15,27 @@ def test_shots() -> None: - ops = [RX("theta", 0)] observables = [Z(0), Z(1)] state = random_state(N_QUBITS) - x = jnp.pi * 0.5 + x = jnp.pi * 0.123 + y = jnp.pi * 0.456 - def exact(x): + def expect(x, y, method): values = {"theta": x} + ops = [RX("theta", 0), RX(0.2, 0), RX(y, 1), RX("theta", 1)] + if method == "shots": + return expectation(state, ops, observables, values, "gpsr", "shots", n_shots=N_SHOTS) return expectation(state, ops, observables, values, "ad") - def shots(x): - values = {"theta": x} - return expectation(state, ops, observables, values, "gpsr", "shots", n_shots=N_SHOTS) - - exp_exact = exact(x) - exp_shots = shots(x) + exp_exact = expect(x, y, "exact") + exp_shots = expect(x, y, "shots") assert jnp.allclose(exp_exact, exp_shots, atol=SHOTS_ATOL) - d_exact = jax.grad(lambda x: exact(x).sum()) - d_shots = jax.grad(lambda x: shots(x).sum()) + d_expect = jax.grad(lambda x, y, z: expect(x, y, z).sum(), argnums=[0, 1]) - grad_backprop = d_exact(x) - grad_shots = d_shots(x) + grad_backprop = jnp.stack(d_expect(x, y, "exact")) + with jax.check_tracer_leaks(): + grad_shots = jnp.stack(d_expect(x, y, "shots")) - assert jnp.isclose(grad_backprop, grad_shots, atol=SHOTS_ATOL) + assert jnp.allclose(grad_backprop, grad_shots, atol=SHOTS_ATOL)