diff --git a/horqrux/__init__.py b/horqrux/__init__.py index 1b46fcf..f23cf20 100644 --- a/horqrux/__init__.py +++ b/horqrux/__init__.py @@ -1,5 +1,6 @@ from __future__ import annotations +from .api import expectation from .apply import apply_gate, apply_operator from .parametric import PHASE, RX, RY, RZ from .primitive import NOT, SWAP, H, I, S, T, X, Y, Z diff --git a/horqrux/api.py b/horqrux/api.py new file mode 100644 index 0000000..eb65751 --- /dev/null +++ b/horqrux/api.py @@ -0,0 +1,97 @@ +from __future__ import annotations + +from collections import Counter +from typing import Any, Optional + +import jax +import jax.numpy as jnp +from jax import Array +from jax.experimental import checkify + +from horqrux.adjoint import adjoint_expectation +from horqrux.apply import apply_gate +from horqrux.primitive import GateSequence, Primitive +from horqrux.shots import finite_shots_fwd +from horqrux.utils import DiffMode, ForwardMode, OperationType, inner + + +def run( + circuit: GateSequence, + state: Array, + values: dict[str, float] = dict(), +) -> Array: + return apply_gate(state, circuit, values) + + +def sample( + state: Array, + gates: GateSequence, + values: dict[str, float] = dict(), + n_shots: int = 1000, +) -> Counter: + if n_shots < 1: + raise ValueError("You can only call sample with n_shots>0.") + + wf = apply_gate(state, gates, values) + probs = jnp.abs(jnp.float_power(wf, 2.0)).ravel() + key = jax.random.PRNGKey(0) + n_qubits = len(state.shape) + # JAX handles pseudo random number generation by tracking an explicit state via a random key + # For more details, see https://jax.readthedocs.io/en/latest/random-numbers.html + samples = jax.vmap( + lambda subkey: jax.random.choice(key=subkey, a=jnp.arange(0, 2**n_qubits), p=probs) + )(jax.random.split(key, n_shots)) + + return Counter( + { + format(k, "0{}b".format(n_qubits)): count.item() + for k, count in enumerate(jnp.bincount(samples)) + if count > 0 + } + ) + + +def ad_expectation( + state: Array, gates: GateSequence, observable: GateSequence, values: dict[str, float] +) -> Array: + """ + Run 'state' through a sequence of 'gates' given parameters 'values' + and compute the expectation given an observable. + """ + out_state = apply_gate(state, gates, values, OperationType.UNITARY) + projected_state = apply_gate(out_state, observable, values, OperationType.UNITARY) + return inner(out_state, projected_state).real + + +def expectation( + state: Array, + gates: GateSequence, + observable: GateSequence, + values: dict[str, float], + diff_mode: DiffMode = DiffMode.AD, + forward_mode: ForwardMode = ForwardMode.EXACT, + n_shots: Optional[int] = None, + key: Any = jax.random.PRNGKey(0), +) -> Array: + """ + Run 'state' through a sequence of 'gates' given parameters 'values' + and compute the expectation given an observable. + """ + if diff_mode == DiffMode.AD: + return ad_expectation(state, gates, observable, values) + elif diff_mode == DiffMode.ADJOINT: + return adjoint_expectation(state, gates, observable, values) + elif diff_mode == DiffMode.GPSR: + checkify.check( + forward_mode == ForwardMode.SHOTS, "Finite shots and GPSR must be used together" + ) + checkify.check( + isinstance(observable, Primitive), + "Finite Shots only supports a single Primitive as an observable", + ) + checkify.check( + type(n_shots) is int, + "Number of shots must be an integer for finite shots.", + ) + # Type checking is disabled because mypy doesn't parse checkify.check. + return finite_shots_fwd(state, gates, observable, values, n_shots=n_shots, key=key) # type: ignore diff --git a/horqrux/primitive.py b/horqrux/primitive.py index 9790603..2aac3a4 100644 --- a/horqrux/primitive.py +++ b/horqrux/primitive.py @@ -1,7 +1,7 @@ from __future__ import annotations from dataclasses import dataclass -from typing import Any, Iterable, Tuple +from typing import Any, Iterable, Tuple, Union import numpy as np from jax import Array @@ -72,6 +72,9 @@ def __repr__(self) -> str: return self.name + f"(target={self.target[0]}, control={self.control[0]})" +GateSequence = Union[Primitive, Iterable[Primitive]] + + def I(target: TargetQubits, control: ControlQubits = (None,)) -> Primitive: """Identity / I gate. This function returns an instance of 'Primitive' and does *not* apply the gate. By providing tuple of ints to 'control', it turns into a controlled gate. @@ -190,7 +193,7 @@ def T(target: TargetQubits, control: ControlQubits = (None,)) -> Primitive: return Primitive("T", target, control) -## Multi (target) qubit gates +# Multi (target) qubit gates def SWAP(target: TargetQubits, control: ControlQubits = (None,)) -> Primitive: diff --git a/horqrux/shots.py b/horqrux/shots.py new file mode 100644 index 0000000..b7068f9 --- /dev/null +++ b/horqrux/shots.py @@ -0,0 +1,54 @@ +from __future__ import annotations + +from functools import reduce +from typing import Any + +import jax +import jax.numpy as jnp +from jax import Array +from jax.experimental import checkify + +from horqrux.apply import apply_gate +from horqrux.primitive import GateSequence, Primitive +from horqrux.utils import none_like + + +def observable_to_matrix(observable: Primitive, n_qubits: int) -> Array: + """For finite shot sampling we need to calculate the eigenvalues/vectors of + an observable. This helper function takes an observable and system size + (n_qubits) and returns the overall action of the observable on the whole + system. + + LIMITATION: currently only works for observables which are not controlled. + """ + checkify.check( + observable.control == observable.parse_idx(none_like(observable.target)), + "Controlled gates cannot be promoted from observables to operations on the whole state vector", + ) + unitary = observable.unitary() + target = observable.target[0][0] + identity = jnp.eye(2, dtype=unitary.dtype) + ops = [identity for _ in range(n_qubits)] + ops[target] = unitary + return reduce(lambda x, y: jnp.kron(x, y), ops[1:], ops[0]) + + +def finite_shots_fwd( + state: Array, + gates: GateSequence, + observable: Primitive, + values: dict[str, float], + n_shots: int = 100, + key: Any = jax.random.PRNGKey(0), +) -> Array: + """ + Run 'state' through a sequence of 'gates' given parameters 'values' + and compute the expectation given an observable. + """ + state = apply_gate(state, gates, values) + n_qubits = len(state.shape) + mat_obs = observable_to_matrix(observable, n_qubits) + eigvals, eigvecs = jnp.linalg.eigh(mat_obs) + inner_prod = jnp.matmul(jnp.conjugate(eigvecs.T), state.flatten()) + probs = jnp.abs(inner_prod) ** 2 + return jax.random.choice(key=key, a=eigvals, p=probs, shape=(n_shots,)).mean() diff --git a/horqrux/utils.py b/horqrux/utils.py index 4b1074e..2641530 100644 --- a/horqrux/utils.py +++ b/horqrux/utils.py @@ -34,6 +34,17 @@ class OperationType(StrEnum): JACOBIAN = "jacobian" +class DiffMode(StrEnum): + AD = "ad" + ADJOINT = "adjoint" + GPSR = "gpsr" + + +class ForwardMode(StrEnum): + EXACT = "exact" + SHOTS = "shots" + + def _dagger(operator: Array) -> Array: return jnp.conjugate(operator.T) diff --git a/tests/test_shots.py b/tests/test_shots.py new file mode 100644 index 0000000..1410899 --- /dev/null +++ b/tests/test_shots.py @@ -0,0 +1,23 @@ +from __future__ import annotations + +import jax.numpy as jnp + +from horqrux import expectation, random_state +from horqrux.parametric import RX +from horqrux.primitive import Z + +N_QUBITS = 4 +SHOTS_ATOL = 0.01 +N_SHOTS = 10000 + + +def test_shots() -> None: + ops = [RX("theta", 0)] + observable = Z(2) + values = {p: jnp.ones(1).item() for p in ["theta"]} + state = random_state(N_QUBITS) + + exp_exact = expectation(state, ops, observable, values, "ad") + exp_shots = expectation(state, ops, observable, values, "gpsr", "shots", n_shots=N_SHOTS) + + assert jnp.isclose(exp_exact, exp_shots, atol=SHOTS_ATOL)