Skip to content

Commit

Permalink
[Feature] Finite shots (#23)
Browse files Browse the repository at this point in the history
* [Feature] Finite shots

* improved version

* stopping now haha

* comment

* Finite shot implementation for multiple qubits

* union instead of | for gatesequence definition

---------

Co-authored-by: Atiyo Ghosh <[email protected]>
  • Loading branch information
dominikandreasseitz and atiyo authored Jul 29, 2024
1 parent 92490b8 commit 2824f4b
Show file tree
Hide file tree
Showing 6 changed files with 191 additions and 2 deletions.
1 change: 1 addition & 0 deletions horqrux/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

from .api import expectation
from .apply import apply_gate, apply_operator
from .parametric import PHASE, RX, RY, RZ
from .primitive import NOT, SWAP, H, I, S, T, X, Y, Z
Expand Down
97 changes: 97 additions & 0 deletions horqrux/api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
from __future__ import annotations

from collections import Counter
from typing import Any, Optional

import jax
import jax.numpy as jnp
from jax import Array
from jax.experimental import checkify

from horqrux.adjoint import adjoint_expectation
from horqrux.apply import apply_gate
from horqrux.primitive import GateSequence, Primitive
from horqrux.shots import finite_shots_fwd
from horqrux.utils import DiffMode, ForwardMode, OperationType, inner


def run(
circuit: GateSequence,
state: Array,
values: dict[str, float] = dict(),
) -> Array:
return apply_gate(state, circuit, values)


def sample(
state: Array,
gates: GateSequence,
values: dict[str, float] = dict(),
n_shots: int = 1000,
) -> Counter:
if n_shots < 1:
raise ValueError("You can only call sample with n_shots>0.")

wf = apply_gate(state, gates, values)
probs = jnp.abs(jnp.float_power(wf, 2.0)).ravel()
key = jax.random.PRNGKey(0)
n_qubits = len(state.shape)
# JAX handles pseudo random number generation by tracking an explicit state via a random key
# For more details, see https://jax.readthedocs.io/en/latest/random-numbers.html
samples = jax.vmap(
lambda subkey: jax.random.choice(key=subkey, a=jnp.arange(0, 2**n_qubits), p=probs)
)(jax.random.split(key, n_shots))

return Counter(
{
format(k, "0{}b".format(n_qubits)): count.item()
for k, count in enumerate(jnp.bincount(samples))
if count > 0
}
)


def ad_expectation(
state: Array, gates: GateSequence, observable: GateSequence, values: dict[str, float]
) -> Array:
"""
Run 'state' through a sequence of 'gates' given parameters 'values'
and compute the expectation given an observable.
"""
out_state = apply_gate(state, gates, values, OperationType.UNITARY)
projected_state = apply_gate(out_state, observable, values, OperationType.UNITARY)
return inner(out_state, projected_state).real


def expectation(
state: Array,
gates: GateSequence,
observable: GateSequence,
values: dict[str, float],
diff_mode: DiffMode = DiffMode.AD,
forward_mode: ForwardMode = ForwardMode.EXACT,
n_shots: Optional[int] = None,
key: Any = jax.random.PRNGKey(0),
) -> Array:
"""
Run 'state' through a sequence of 'gates' given parameters 'values'
and compute the expectation given an observable.
"""
if diff_mode == DiffMode.AD:
return ad_expectation(state, gates, observable, values)
elif diff_mode == DiffMode.ADJOINT:
return adjoint_expectation(state, gates, observable, values)
elif diff_mode == DiffMode.GPSR:
checkify.check(
forward_mode == ForwardMode.SHOTS, "Finite shots and GPSR must be used together"
)
checkify.check(
isinstance(observable, Primitive),
"Finite Shots only supports a single Primitive as an observable",
)
checkify.check(
type(n_shots) is int,
"Number of shots must be an integer for finite shots.",
)
# Type checking is disabled because mypy doesn't parse checkify.check.
return finite_shots_fwd(state, gates, observable, values, n_shots=n_shots, key=key) # type: ignore
7 changes: 5 additions & 2 deletions horqrux/primitive.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import Any, Iterable, Tuple
from typing import Any, Iterable, Tuple, Union

import numpy as np
from jax import Array
Expand Down Expand Up @@ -72,6 +72,9 @@ def __repr__(self) -> str:
return self.name + f"(target={self.target[0]}, control={self.control[0]})"


GateSequence = Union[Primitive, Iterable[Primitive]]


def I(target: TargetQubits, control: ControlQubits = (None,)) -> Primitive:
"""Identity / I gate. This function returns an instance of 'Primitive' and does *not* apply the gate.
By providing tuple of ints to 'control', it turns into a controlled gate.
Expand Down Expand Up @@ -190,7 +193,7 @@ def T(target: TargetQubits, control: ControlQubits = (None,)) -> Primitive:
return Primitive("T", target, control)


## Multi (target) qubit gates
# Multi (target) qubit gates


def SWAP(target: TargetQubits, control: ControlQubits = (None,)) -> Primitive:
Expand Down
54 changes: 54 additions & 0 deletions horqrux/shots.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
from __future__ import annotations

from functools import reduce
from typing import Any

import jax
import jax.numpy as jnp
from jax import Array
from jax.experimental import checkify

from horqrux.apply import apply_gate
from horqrux.primitive import GateSequence, Primitive
from horqrux.utils import none_like


def observable_to_matrix(observable: Primitive, n_qubits: int) -> Array:
"""For finite shot sampling we need to calculate the eigenvalues/vectors of
an observable. This helper function takes an observable and system size
(n_qubits) and returns the overall action of the observable on the whole
system.
LIMITATION: currently only works for observables which are not controlled.
"""
checkify.check(
observable.control == observable.parse_idx(none_like(observable.target)),
"Controlled gates cannot be promoted from observables to operations on the whole state vector",
)
unitary = observable.unitary()
target = observable.target[0][0]
identity = jnp.eye(2, dtype=unitary.dtype)
ops = [identity for _ in range(n_qubits)]
ops[target] = unitary
return reduce(lambda x, y: jnp.kron(x, y), ops[1:], ops[0])


def finite_shots_fwd(
state: Array,
gates: GateSequence,
observable: Primitive,
values: dict[str, float],
n_shots: int = 100,
key: Any = jax.random.PRNGKey(0),
) -> Array:
"""
Run 'state' through a sequence of 'gates' given parameters 'values'
and compute the expectation given an observable.
"""
state = apply_gate(state, gates, values)
n_qubits = len(state.shape)
mat_obs = observable_to_matrix(observable, n_qubits)
eigvals, eigvecs = jnp.linalg.eigh(mat_obs)
inner_prod = jnp.matmul(jnp.conjugate(eigvecs.T), state.flatten())
probs = jnp.abs(inner_prod) ** 2
return jax.random.choice(key=key, a=eigvals, p=probs, shape=(n_shots,)).mean()
11 changes: 11 additions & 0 deletions horqrux/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,17 @@ class OperationType(StrEnum):
JACOBIAN = "jacobian"


class DiffMode(StrEnum):
AD = "ad"
ADJOINT = "adjoint"
GPSR = "gpsr"


class ForwardMode(StrEnum):
EXACT = "exact"
SHOTS = "shots"


def _dagger(operator: Array) -> Array:
return jnp.conjugate(operator.T)

Expand Down
23 changes: 23 additions & 0 deletions tests/test_shots.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
from __future__ import annotations

import jax.numpy as jnp

from horqrux import expectation, random_state
from horqrux.parametric import RX
from horqrux.primitive import Z

N_QUBITS = 4
SHOTS_ATOL = 0.01
N_SHOTS = 10000


def test_shots() -> None:
ops = [RX("theta", 0)]
observable = Z(2)
values = {p: jnp.ones(1).item() for p in ["theta"]}
state = random_state(N_QUBITS)

exp_exact = expectation(state, ops, observable, values, "ad")
exp_shots = expectation(state, ops, observable, values, "gpsr", "shots", n_shots=N_SHOTS)

assert jnp.isclose(exp_exact, exp_shots, atol=SHOTS_ATOL)

0 comments on commit 2824f4b

Please sign in to comment.