Skip to content

Commit

Permalink
refac expectation like pyq
Browse files Browse the repository at this point in the history
  • Loading branch information
dominikandreasseitz committed Apr 19, 2024
1 parent 87561f1 commit b743496
Show file tree
Hide file tree
Showing 6 changed files with 46 additions and 14 deletions.
6 changes: 3 additions & 3 deletions docs/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -110,11 +110,11 @@ from operator import add
from typing import Any, Callable
from uuid import uuid4

from horqrux.adjoint import adjoint_expectation
from horqrux.circuit import Circuit, hea
from horqrux.circuit import Circuit, hea, expectation
from horqrux.primitive import Primitive
from horqrux.parametric import Parametric
from horqrux import Z, RX, RY, NOT, zero_state, apply_gate
from horqrux.utils import DiffMode


n_qubits = 5
Expand All @@ -137,7 +137,7 @@ class DQC(Circuit):
@partial(vmap, in_axes=(None, None, 0))
def __call__(self, param_values: Array, x: Array) -> Array:
param_dict = {name: val for name, val in zip(self.param_names, param_values)}
return adjoint_expectation(self.state, self.feature_map + self.ansatz, self.observable, {**param_dict, **{'phi': x}})
return expectation(self.state, self.feature_map + self.ansatz, self.observable, {**param_dict, **{'phi': x}}, DiffMode.ADJOINT)


circ = DQC(n_qubits=n_qubits, feature_map=[RX('phi', i) for i in range(n_qubits)], ansatz=hea(n_qubits, n_layers))
Expand Down
2 changes: 2 additions & 0 deletions horqrux/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
from __future__ import annotations

from .apply import apply_gate, apply_operator
from .circuit import Circuit, expectation
from .parametric import PHASE, RX, RY, RZ
from .primitive import NOT, SWAP, H, I, S, T, X, Y, Z
from .utils import (
DiffMode,
equivalent_state,
hilbert_reshape,
overlap,
Expand Down
4 changes: 2 additions & 2 deletions horqrux/adjoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from horqrux.utils import OperationType, inner


def expectation(
def ad_expectation(
state: Array, gates: list[Primitive], observable: list[Primitive], values: dict[str, float]
) -> Array:
"""
Expand All @@ -26,7 +26,7 @@ def expectation(
def adjoint_expectation(
state: Array, gates: list[Primitive], observable: list[Primitive], values: dict[str, float]
) -> Array:
return expectation(state, gates, observable, values)
return ad_expectation(state, gates, observable, values)


def adjoint_expectation_fwd(
Expand Down
20 changes: 19 additions & 1 deletion horqrux/circuit.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,11 @@
from jax import Array
from jax.tree_util import register_pytree_node_class

from horqrux.adjoint import ad_expectation, adjoint_expectation
from horqrux.apply import apply_gate
from horqrux.parametric import RX, RY, Parametric
from horqrux.primitive import NOT, Primitive
from horqrux.utils import zero_state
from horqrux.utils import DiffMode, zero_state


@register_pytree_node_class
Expand Down Expand Up @@ -66,3 +67,20 @@ def hea(n_qubits: int, n_layers: int, rot_fns: list[Callable] = [RX, RY, RX]) ->
gates += ops

return gates


def expectation(
state: Array,
gates: list[Primitive],
observable: list[Primitive],
values: dict[str, float],
diff_mode: DiffMode | str = DiffMode.AD,
) -> Array:
"""
Run 'state' through a sequence of 'gates' given parameters 'values'
and compute the expectation given an observable.
"""
if diff_mode == DiffMode.AD:
return ad_expectation(state, gates, observable, values)
else:
return adjoint_expectation(state, gates, observable, values)
14 changes: 14 additions & 0 deletions horqrux/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,20 @@ class OperationType(StrEnum):
JACOBIAN = "jacobian"


class DiffMode(StrEnum):
"""
Which Differentiation method to use.
Options: Automatic Differentiation - Using the autograd engine of JAX.
Adjoint Differentiation - An implementation of "Efficient calculation of gradients
in classical simulations of variational quantum algorithms",
Jones & Gacon, 2020
"""

AD = "ad"
ADJOINT = "adjoint"


def _dagger(operator: Array) -> Array:
return jnp.conjugate(operator.T)

Expand Down
14 changes: 6 additions & 8 deletions tests/test_adjoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,10 @@
from jax import Array, grad

from horqrux import random_state
from horqrux.adjoint import adjoint_expectation, expectation
from horqrux.circuit import expectation
from horqrux.parametric import PHASE, RX, RY, RZ
from horqrux.primitive import NOT, H, I, S, T, X, Y, Z
from horqrux.utils import DiffMode

MAX_QUBITS = 7
PARAMETRIC_GATES = (RX, RY, RZ, PHASE)
Expand All @@ -25,13 +26,10 @@ def test_gradcheck() -> None:
}
state = random_state(MAX_QUBITS)

def adjoint_expfn(values) -> Array:
return adjoint_expectation(state, ops, observable, values)
def exp_fn(values: dict, diff_mode: DiffMode) -> Array:
return expectation(state, ops, observable, values, diff_mode)

def ad_expfn(values) -> Array:
return expectation(state, ops, observable, values)

grads_adjoint = grad(adjoint_expfn)(values)
grad_ad = grad(ad_expfn)(values)
grads_adjoint = grad(exp_fn)(values, "adjoint")
grad_ad = grad(exp_fn)(values, "ad")
for param, ad_grad in grad_ad.items():
assert jnp.isclose(grads_adjoint[param], ad_grad, atol=0.09)

0 comments on commit b743496

Please sign in to comment.