Skip to content

Commit

Permalink
[Feature] Extend expectations to take multiple observables. (#28)
Browse files Browse the repository at this point in the history
* Extend expectations to take multiple observables.

* consistent naming of observable(s)

* consistent naming of observable(s) (actually, this time)
  • Loading branch information
atiyo authored Sep 17, 2024
1 parent 77084a5 commit e69435a
Show file tree
Hide file tree
Showing 6 changed files with 101 additions and 33 deletions.
2 changes: 1 addition & 1 deletion docs/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ class DQC(Circuit):
@partial(vmap, in_axes=(None, None, 0))
def __call__(self, param_values: Array, x: Array) -> Array:
param_dict = {name: val for name, val in zip(self.param_names, param_values)}
return adjoint_expectation(self.state, self.feature_map + self.ansatz, self.observable, {**param_dict, **{'phi': x}})
return adjoint_expectation(self.state, self.feature_map + self.ansatz, self.observable, {**param_dict, **{'phi': x}})[0]


circ = DQC(n_qubits=n_qubits, feature_map=[RX('phi', i) for i in range(n_qubits)], ansatz=hea(n_qubits, n_layers))
Expand Down
29 changes: 23 additions & 6 deletions horqrux/adjoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,11 @@
from typing import Tuple

from jax import Array, custom_vjp
from jax import numpy as jnp

from horqrux.apply import apply_gate
from horqrux.parametric import Parametric
from horqrux.primitive import Primitive
from horqrux.primitive import GateSequence, Primitive
from horqrux.utils import OperationType, inner


Expand All @@ -23,21 +24,35 @@ def expectation(


@custom_vjp
def __adjoint_expectation_single_observable(
state: Array, gates: list[Primitive], observable: Primitive, values: dict[str, float]
) -> Array:
return expectation(state, gates, [observable], values)


def adjoint_expectation(
state: Array, gates: list[Primitive], observable: list[Primitive], values: dict[str, float]
state: Array, gates: GateSequence, observables: list[Primitive], values: dict[str, float]
) -> Array:
return expectation(state, gates, observable, values)
"""
Run 'state' through a sequence of 'gates' given parameters 'values'
and compute the expectation given an observable.
"""
outputs = [
__adjoint_expectation_single_observable(state, gates, observable, values)
for observable in observables
]
return jnp.stack(outputs)


def adjoint_expectation_fwd(
def adjoint_expectation_single_observable_fwd(
state: Array, gates: list[Primitive], observable: list[Primitive], values: dict[str, float]
) -> Tuple[Array, Tuple[Array, Array, list[Primitive], dict[str, float]]]:
out_state = apply_gate(state, gates, values, OperationType.UNITARY)
projected_state = apply_gate(out_state, observable, values, OperationType.UNITARY)
return inner(out_state, projected_state).real, (out_state, projected_state, gates, values)


def adjoint_expectation_bwd(
def adjoint_expectation_single_observable_bwd(
res: Tuple[Array, Array, list[Primitive], dict[str, float]], tangent: Array
) -> tuple:
"""Implementation of Algorithm 1 of https://arxiv.org/abs/2009.02823
Expand All @@ -56,4 +71,6 @@ def adjoint_expectation_bwd(
return (None, None, None, grads)


adjoint_expectation.defvjp(adjoint_expectation_fwd, adjoint_expectation_bwd)
__adjoint_expectation_single_observable.defvjp(
adjoint_expectation_single_observable_fwd, adjoint_expectation_single_observable_bwd
)
31 changes: 21 additions & 10 deletions horqrux/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,8 @@ def sample(
)


def ad_expectation(
state: Array, gates: GateSequence, observable: GateSequence, values: dict[str, float]
def __ad_expectation_single_observable(
state: Array, gates: GateSequence, observable: Primitive, values: dict[str, float]
) -> Array:
"""
Run 'state' through a sequence of 'gates' given parameters 'values'
Expand All @@ -63,10 +63,24 @@ def ad_expectation(
return inner(out_state, projected_state).real


def ad_expectation(
state: Array, gates: GateSequence, observables: list[Primitive], values: dict[str, float]
) -> Array:
"""
Run 'state' through a sequence of 'gates' given parameters 'values'
and compute the expectation given an observable.
"""
outputs = [
__ad_expectation_single_observable(state, gates, observable, values)
for observable in observables
]
return jnp.stack(outputs)


def expectation(
state: Array,
gates: GateSequence,
observable: GateSequence,
observables: list[Primitive],
values: dict[str, float],
diff_mode: DiffMode = DiffMode.AD,
forward_mode: ForwardMode = ForwardMode.EXACT,
Expand All @@ -78,20 +92,17 @@ def expectation(
and compute the expectation given an observable.
"""
if diff_mode == DiffMode.AD:
return ad_expectation(state, gates, observable, values)
return ad_expectation(state, gates, observables, values)
elif diff_mode == DiffMode.ADJOINT:
return adjoint_expectation(state, gates, observable, values)
return adjoint_expectation(state, gates, observables, values)
elif diff_mode == DiffMode.GPSR:
checkify.check(
forward_mode == ForwardMode.SHOTS, "Finite shots and GPSR must be used together"
)
checkify.check(
isinstance(observable, Primitive),
"Finite Shots only supports a single Primitive as an observable",
)
checkify.check(
type(n_shots) is int,
"Number of shots must be an integer for finite shots.",
)
# Type checking is disabled because mypy doesn't parse checkify.check.
return finite_shots_fwd(state, gates, observable, values, n_shots=n_shots, key=key) # type: ignore
# type: ignore
return finite_shots_fwd(state, gates, observables, values, n_shots=n_shots, key=key)
47 changes: 43 additions & 4 deletions horqrux/shots.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def observable_to_matrix(observable: Primitive, n_qubits: int) -> Array:
def finite_shots_fwd(
state: Array,
gates: GateSequence,
observable: Primitive,
observables: list[Primitive],
values: dict[str, float],
n_shots: int = 100,
key: Any = jax.random.PRNGKey(0),
Expand All @@ -48,11 +48,50 @@ def finite_shots_fwd(
"""
state = apply_gate(state, gates, values)
n_qubits = len(state.shape)
mat_obs = observable_to_matrix(observable, n_qubits)
eigvals, eigvecs = jnp.linalg.eigh(mat_obs)
mat_obs = [observable_to_matrix(observable, n_qubits) for observable in observables]
eigs = [jnp.linalg.eigh(mat) for mat in mat_obs]
eigvecs, eigvals = align_eigenvectors(eigs)
inner_prod = jnp.matmul(jnp.conjugate(eigvecs.T), state.flatten())
probs = jnp.abs(inner_prod) ** 2
return jax.random.choice(key=key, a=eigvals, p=probs, shape=(n_shots,)).mean()
return jax.random.choice(key=key, a=eigvals, p=probs, shape=(n_shots,)).mean(axis=0)


def align_eigenvectors(eigs: list[tuple[Array, Array]]) -> tuple[Array, Array]:
"""
Given a list of eigenvalue eigenvector matrix tuples in the form of
[(eigenvalue, eigenvector)...], this function aligns all the eigenvector
matrices so that they are identical, and also rearranges the corresponding
eigenvalues.
This is primarily used as a utility function to help sample multiple
correlated observables when using finite shots.
Given two permuted eigenvector matrices, A and B, we wish to find a permutation
matrix P such that A P = B. This function calculates such a permutation
matrix and uses it to align each eigenvector matrix to the first eigenvector
matrix of eigs.
"""
eigenvalues = []
eigs_copy = eigs.copy()
eigenvalue, eigenvector_matrix = eigs_copy.pop(0)
eigenvalues.append(eigenvalue)
# TODO: laxify this loop
for mat in eigs_copy:
inv = jnp.linalg.inv(mat[1])
P = (inv @ eigenvector_matrix).real > 0.5
checkify.check(
validate_permutation_matrix(P),
"Did not calculate valid permutation matrix",
)
eigenvalues.append(mat[0] @ P)
return eigenvector_matrix, jnp.stack(eigenvalues, axis=1)


def validate_permutation_matrix(P: Array) -> Array:
rows = P.sum(axis=0)
columns = P.sum(axis=1)
ones = jnp.ones(P.shape[0], dtype=rows.dtype)
return ((ones == rows) & (ones == columns)).min()


@finite_shots_fwd.defjvp
Expand Down
7 changes: 4 additions & 3 deletions tests/test_adjoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@
from jax import Array, grad

from horqrux import random_state
from horqrux.adjoint import adjoint_expectation, expectation
from horqrux.adjoint import adjoint_expectation
from horqrux.api import expectation
from horqrux.parametric import PHASE, RX, RY, RZ
from horqrux.primitive import NOT, H, I, S, T, X, Y, Z

Expand All @@ -26,10 +27,10 @@ def test_gradcheck() -> None:
state = random_state(MAX_QUBITS)

def adjoint_expfn(values) -> Array:
return adjoint_expectation(state, ops, observable, values)
return adjoint_expectation(state, ops, observable, values)[0]

def ad_expfn(values) -> Array:
return expectation(state, ops, observable, values)
return expectation(state, ops, observable, values)[0]

grads_adjoint = grad(adjoint_expfn)(values)
grad_ad = grad(ad_expfn)(values)
Expand Down
18 changes: 9 additions & 9 deletions tests/test_shots.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,32 +7,32 @@
from horqrux.parametric import RX
from horqrux.primitive import Z

N_QUBITS = 4
N_QUBITS = 2
SHOTS_ATOL = 0.01
N_SHOTS = 10000
N_SHOTS = 100_000


def test_shots() -> None:
ops = [RX("theta", 0)]
observable = Z(0)
observables = [Z(0), Z(1)]
state = random_state(N_QUBITS)
x = jnp.pi * 0.5

def exact(x):
values = {"theta": x}
return expectation(state, ops, observable, values, "ad")
return expectation(state, ops, observables, values, "ad")

def shots(x):
values = {"theta": x}
return expectation(state, ops, observable, values, "gpsr", "shots", n_shots=N_SHOTS)
return expectation(state, ops, observables, values, "gpsr", "shots", n_shots=N_SHOTS)

exp_exact = exact(x)
exp_shots = exact(x)
exp_shots = shots(x)

assert jnp.isclose(exp_exact, exp_shots, atol=SHOTS_ATOL)
assert jnp.allclose(exp_exact, exp_shots, atol=SHOTS_ATOL)

d_exact = jax.grad(exact)
d_shots = jax.grad(shots)
d_exact = jax.grad(lambda x: exact(x).sum())
d_shots = jax.grad(lambda x: shots(x).sum())

grad_backprop = d_exact(x)
grad_shots = d_shots(x)
Expand Down

0 comments on commit e69435a

Please sign in to comment.