-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* [Feature] Finite shots * improved version * stopping now haha * comment * Finite shot implementation for multiple qubits * union instead of | for gatesequence definition --------- Co-authored-by: Atiyo Ghosh <[email protected]>
- Loading branch information
1 parent
92490b8
commit 2824f4b
Showing
6 changed files
with
191 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,97 @@ | ||
from __future__ import annotations | ||
|
||
from collections import Counter | ||
from typing import Any, Optional | ||
|
||
import jax | ||
import jax.numpy as jnp | ||
from jax import Array | ||
from jax.experimental import checkify | ||
|
||
from horqrux.adjoint import adjoint_expectation | ||
from horqrux.apply import apply_gate | ||
from horqrux.primitive import GateSequence, Primitive | ||
from horqrux.shots import finite_shots_fwd | ||
from horqrux.utils import DiffMode, ForwardMode, OperationType, inner | ||
|
||
|
||
def run( | ||
circuit: GateSequence, | ||
state: Array, | ||
values: dict[str, float] = dict(), | ||
) -> Array: | ||
return apply_gate(state, circuit, values) | ||
|
||
|
||
def sample( | ||
state: Array, | ||
gates: GateSequence, | ||
values: dict[str, float] = dict(), | ||
n_shots: int = 1000, | ||
) -> Counter: | ||
if n_shots < 1: | ||
raise ValueError("You can only call sample with n_shots>0.") | ||
|
||
wf = apply_gate(state, gates, values) | ||
probs = jnp.abs(jnp.float_power(wf, 2.0)).ravel() | ||
key = jax.random.PRNGKey(0) | ||
n_qubits = len(state.shape) | ||
# JAX handles pseudo random number generation by tracking an explicit state via a random key | ||
# For more details, see https://jax.readthedocs.io/en/latest/random-numbers.html | ||
samples = jax.vmap( | ||
lambda subkey: jax.random.choice(key=subkey, a=jnp.arange(0, 2**n_qubits), p=probs) | ||
)(jax.random.split(key, n_shots)) | ||
|
||
return Counter( | ||
{ | ||
format(k, "0{}b".format(n_qubits)): count.item() | ||
for k, count in enumerate(jnp.bincount(samples)) | ||
if count > 0 | ||
} | ||
) | ||
|
||
|
||
def ad_expectation( | ||
state: Array, gates: GateSequence, observable: GateSequence, values: dict[str, float] | ||
) -> Array: | ||
""" | ||
Run 'state' through a sequence of 'gates' given parameters 'values' | ||
and compute the expectation given an observable. | ||
""" | ||
out_state = apply_gate(state, gates, values, OperationType.UNITARY) | ||
projected_state = apply_gate(out_state, observable, values, OperationType.UNITARY) | ||
return inner(out_state, projected_state).real | ||
|
||
|
||
def expectation( | ||
state: Array, | ||
gates: GateSequence, | ||
observable: GateSequence, | ||
values: dict[str, float], | ||
diff_mode: DiffMode = DiffMode.AD, | ||
forward_mode: ForwardMode = ForwardMode.EXACT, | ||
n_shots: Optional[int] = None, | ||
key: Any = jax.random.PRNGKey(0), | ||
) -> Array: | ||
""" | ||
Run 'state' through a sequence of 'gates' given parameters 'values' | ||
and compute the expectation given an observable. | ||
""" | ||
if diff_mode == DiffMode.AD: | ||
return ad_expectation(state, gates, observable, values) | ||
elif diff_mode == DiffMode.ADJOINT: | ||
return adjoint_expectation(state, gates, observable, values) | ||
elif diff_mode == DiffMode.GPSR: | ||
checkify.check( | ||
forward_mode == ForwardMode.SHOTS, "Finite shots and GPSR must be used together" | ||
) | ||
checkify.check( | ||
isinstance(observable, Primitive), | ||
"Finite Shots only supports a single Primitive as an observable", | ||
) | ||
checkify.check( | ||
type(n_shots) is int, | ||
"Number of shots must be an integer for finite shots.", | ||
) | ||
# Type checking is disabled because mypy doesn't parse checkify.check. | ||
return finite_shots_fwd(state, gates, observable, values, n_shots=n_shots, key=key) # type: ignore |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
from __future__ import annotations | ||
|
||
from functools import reduce | ||
from typing import Any | ||
|
||
import jax | ||
import jax.numpy as jnp | ||
from jax import Array | ||
from jax.experimental import checkify | ||
|
||
from horqrux.apply import apply_gate | ||
from horqrux.primitive import GateSequence, Primitive | ||
from horqrux.utils import none_like | ||
|
||
|
||
def observable_to_matrix(observable: Primitive, n_qubits: int) -> Array: | ||
"""For finite shot sampling we need to calculate the eigenvalues/vectors of | ||
an observable. This helper function takes an observable and system size | ||
(n_qubits) and returns the overall action of the observable on the whole | ||
system. | ||
LIMITATION: currently only works for observables which are not controlled. | ||
""" | ||
checkify.check( | ||
observable.control == observable.parse_idx(none_like(observable.target)), | ||
"Controlled gates cannot be promoted from observables to operations on the whole state vector", | ||
) | ||
unitary = observable.unitary() | ||
target = observable.target[0][0] | ||
identity = jnp.eye(2, dtype=unitary.dtype) | ||
ops = [identity for _ in range(n_qubits)] | ||
ops[target] = unitary | ||
return reduce(lambda x, y: jnp.kron(x, y), ops[1:], ops[0]) | ||
|
||
|
||
def finite_shots_fwd( | ||
state: Array, | ||
gates: GateSequence, | ||
observable: Primitive, | ||
values: dict[str, float], | ||
n_shots: int = 100, | ||
key: Any = jax.random.PRNGKey(0), | ||
) -> Array: | ||
""" | ||
Run 'state' through a sequence of 'gates' given parameters 'values' | ||
and compute the expectation given an observable. | ||
""" | ||
state = apply_gate(state, gates, values) | ||
n_qubits = len(state.shape) | ||
mat_obs = observable_to_matrix(observable, n_qubits) | ||
eigvals, eigvecs = jnp.linalg.eigh(mat_obs) | ||
inner_prod = jnp.matmul(jnp.conjugate(eigvecs.T), state.flatten()) | ||
probs = jnp.abs(inner_prod) ** 2 | ||
return jax.random.choice(key=key, a=eigvals, p=probs, shape=(n_shots,)).mean() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
from __future__ import annotations | ||
|
||
import jax.numpy as jnp | ||
|
||
from horqrux import expectation, random_state | ||
from horqrux.parametric import RX | ||
from horqrux.primitive import Z | ||
|
||
N_QUBITS = 4 | ||
SHOTS_ATOL = 0.01 | ||
N_SHOTS = 10000 | ||
|
||
|
||
def test_shots() -> None: | ||
ops = [RX("theta", 0)] | ||
observable = Z(2) | ||
values = {p: jnp.ones(1).item() for p in ["theta"]} | ||
state = random_state(N_QUBITS) | ||
|
||
exp_exact = expectation(state, ops, observable, values, "ad") | ||
exp_shots = expectation(state, ops, observable, values, "gpsr", "shots", n_shots=N_SHOTS) | ||
|
||
assert jnp.isclose(exp_exact, exp_shots, atol=SHOTS_ATOL) |