diff --git a/docs/index.md b/docs/index.md index 99d842c..1081a62 100644 --- a/docs/index.md +++ b/docs/index.md @@ -110,11 +110,11 @@ from operator import add from typing import Any, Callable from uuid import uuid4 -from horqrux.adjoint import adjoint_expectation -from horqrux.circuit import Circuit, hea +from horqrux.circuit import Circuit, hea, expectation from horqrux.primitive import Primitive from horqrux.parametric import Parametric from horqrux import Z, RX, RY, NOT, zero_state, apply_gate +from horqrux.utils import DiffMode n_qubits = 5 @@ -137,7 +137,7 @@ class DQC(Circuit): @partial(vmap, in_axes=(None, None, 0)) def __call__(self, param_values: Array, x: Array) -> Array: param_dict = {name: val for name, val in zip(self.param_names, param_values)} - return adjoint_expectation(self.state, self.feature_map + self.ansatz, self.observable, {**param_dict, **{'phi': x}}) + return expectation(self.state, self.feature_map + self.ansatz, self.observable, {**param_dict, **{'phi': x}}, DiffMode.ADJOINT) circ = DQC(n_qubits=n_qubits, feature_map=[RX('phi', i) for i in range(n_qubits)], ansatz=hea(n_qubits, n_layers)) diff --git a/horqrux/__init__.py b/horqrux/__init__.py index 1b46fcf..be34ae8 100644 --- a/horqrux/__init__.py +++ b/horqrux/__init__.py @@ -1,9 +1,11 @@ from __future__ import annotations from .apply import apply_gate, apply_operator +from .circuit import Circuit, expectation from .parametric import PHASE, RX, RY, RZ from .primitive import NOT, SWAP, H, I, S, T, X, Y, Z from .utils import ( + DiffMode, equivalent_state, hilbert_reshape, overlap, diff --git a/horqrux/adjoint.py b/horqrux/adjoint.py index d7fa777..fa35b78 100644 --- a/horqrux/adjoint.py +++ b/horqrux/adjoint.py @@ -10,7 +10,7 @@ from horqrux.utils import OperationType, inner -def expectation( +def ad_expectation( state: Array, gates: list[Primitive], observable: list[Primitive], values: dict[str, float] ) -> Array: """ @@ -26,7 +26,7 @@ def expectation( def adjoint_expectation( state: Array, gates: list[Primitive], observable: list[Primitive], values: dict[str, float] ) -> Array: - return expectation(state, gates, observable, values) + return ad_expectation(state, gates, observable, values) def adjoint_expectation_fwd( diff --git a/horqrux/circuit.py b/horqrux/circuit.py index ad47d2a..388b8ba 100644 --- a/horqrux/circuit.py +++ b/horqrux/circuit.py @@ -7,10 +7,11 @@ from jax import Array from jax.tree_util import register_pytree_node_class +from horqrux.adjoint import ad_expectation, adjoint_expectation from horqrux.apply import apply_gate from horqrux.parametric import RX, RY, Parametric from horqrux.primitive import NOT, Primitive -from horqrux.utils import zero_state +from horqrux.utils import DiffMode, zero_state @register_pytree_node_class @@ -66,3 +67,20 @@ def hea(n_qubits: int, n_layers: int, rot_fns: list[Callable] = [RX, RY, RX]) -> gates += ops return gates + + +def expectation( + state: Array, + gates: list[Primitive], + observable: list[Primitive], + values: dict[str, float], + diff_mode: DiffMode | str = DiffMode.AD, +) -> Array: + """ + Run 'state' through a sequence of 'gates' given parameters 'values' + and compute the expectation given an observable. + """ + if diff_mode == DiffMode.AD: + return ad_expectation(state, gates, observable, values) + else: + return adjoint_expectation(state, gates, observable, values) diff --git a/horqrux/utils.py b/horqrux/utils.py index 6a1c03f..eda4246 100644 --- a/horqrux/utils.py +++ b/horqrux/utils.py @@ -34,6 +34,20 @@ class OperationType(StrEnum): JACOBIAN = "jacobian" +class DiffMode(StrEnum): + """ + Which Differentiation method to use. + + Options: Automatic Differentiation - Using the autograd engine of JAX. + Adjoint Differentiation - An implementation of "Efficient calculation of gradients + in classical simulations of variational quantum algorithms", + Jones & Gacon, 2020 + """ + + AD = "ad" + ADJOINT = "adjoint" + + def _dagger(operator: Array) -> Array: return jnp.conjugate(operator.T) diff --git a/tests/test_adjoint.py b/tests/test_adjoint.py index 4647e86..ce5c704 100644 --- a/tests/test_adjoint.py +++ b/tests/test_adjoint.py @@ -5,9 +5,10 @@ from jax import Array, grad from horqrux import random_state -from horqrux.adjoint import adjoint_expectation, expectation +from horqrux.circuit import expectation from horqrux.parametric import PHASE, RX, RY, RZ from horqrux.primitive import NOT, H, I, S, T, X, Y, Z +from horqrux.utils import DiffMode MAX_QUBITS = 7 PARAMETRIC_GATES = (RX, RY, RZ, PHASE) @@ -25,13 +26,10 @@ def test_gradcheck() -> None: } state = random_state(MAX_QUBITS) - def adjoint_expfn(values) -> Array: - return adjoint_expectation(state, ops, observable, values) + def exp_fn(values: dict, diff_mode: DiffMode) -> Array: + return expectation(state, ops, observable, values, diff_mode) - def ad_expfn(values) -> Array: - return expectation(state, ops, observable, values) - - grads_adjoint = grad(adjoint_expfn)(values) - grad_ad = grad(ad_expfn)(values) + grads_adjoint = grad(exp_fn)(values, "adjoint") + grad_ad = grad(exp_fn)(values, "ad") for param, ad_grad in grad_ad.items(): assert jnp.isclose(grads_adjoint[param], ad_grad, atol=0.09)