From 5fd2dad5202d4ef5fefd4528f6e34013049319cf Mon Sep 17 00:00:00 2001 From: seitzdom Date: Thu, 5 Dec 2024 19:33:11 +0400 Subject: [PATCH] [Feature] Add sample, Use single precision by default (#16) * [Feature] Use single precision by default * refac expectation like pyq * Add sample, increase atol * rework circ * Remove spurious import. * Lint. * change definition of a quantum circuit * more docstr on circuit * lint * rename to fparams and vparams --------- Co-authored-by: Roland Guichard --- docs/index.md | 46 ++++++++------- horqrux/__init__.py | 2 + horqrux/_misc.py | 5 +- horqrux/adjoint.py | 19 ++---- horqrux/apply.py | 4 +- horqrux/circuit.py | 134 +++++++++++++++++++++++++++++++++++------- horqrux/parametric.py | 4 +- horqrux/primitive.py | 2 +- horqrux/utils.py | 7 +++ tests/test_adjoint.py | 17 +++--- tests/test_analog.py | 7 ++- 11 files changed, 173 insertions(+), 74 deletions(-) diff --git a/docs/index.md b/docs/index.md index f9ace3f..87c805e 100644 --- a/docs/index.md +++ b/docs/index.md @@ -110,11 +110,11 @@ from operator import add from typing import Any, Callable from uuid import uuid4 -from horqrux.adjoint import adjoint_expectation -from horqrux.circuit import Circuit, hea +from horqrux.circuit import QuantumCircuit, hea, expectation from horqrux.primitive import Primitive from horqrux.parametric import Parametric from horqrux import Z, RX, RY, NOT, zero_state, apply_gate +from horqrux.utils import DiffMode n_qubits = 5 @@ -128,19 +128,21 @@ fn = lambda x, degree: .05 * reduce(add, (jnp.cos(i*x) + jnp.sin(i*x) for i in r x = jnp.linspace(0, 10, 100) y = fn(x, 5) - -class DQC(Circuit): +@dataclass +class DQC(QuantumCircuit): def __post_init__(self) -> None: self.observable: list[Primitive] = [Z(0)] self.state = zero_state(self.n_qubits) @partial(vmap, in_axes=(None, None, 0)) def __call__(self, param_values: Array, x: Array) -> Array: - param_dict = {name: val for name, val in zip(self.param_names, param_values)} - return adjoint_expectation(self.state, self.feature_map + self.ansatz, self.observable, {**param_dict, **{'phi': x}})[0] - + param_dict = {name: val for name, val in zip(self.vparams, param_values)} + return expectation(self.state, self.operations, self.observable, {**param_dict, **{'phi': x}}, DiffMode.ADJOINT) -circ = DQC(n_qubits=n_qubits, feature_map=[RX('phi', i) for i in range(n_qubits)], ansatz=hea(n_qubits, n_layers)) +feature_map = [RX('phi', i) for i in range(n_qubits)] +fm_names = [f.param for f in feature_map] +ansatz = hea(n_qubits, n_layers) +circ = DQC(n_qubits=n_qubits, operations=feature_map + ansatz, fparams=fm_names) # Create random initial values for the parameters key = jax.random.PRNGKey(42) param_vals = jax.random.uniform(key, shape=(circ.n_vparams,)) @@ -212,20 +214,20 @@ from jax import Array, jit, value_and_grad, vmap from numpy.random import uniform from horqrux.apply import group_by_index -from horqrux.circuit import Circuit, hea +from horqrux.circuit import QuantumCircuit, hea from horqrux import NOT, RX, RY, Z, apply_gate, zero_state from horqrux.primitive import Primitive from horqrux.parametric import Parametric from horqrux.utils import inner -LEARNING_RATE = 0.01 +LEARNING_RATE = 0.15 N_QUBITS = 4 DEPTH = 3 VARIABLES = ("x", "y") NUM_VARIABLES = len(VARIABLES) X_POS, Y_POS = [i for i in range(NUM_VARIABLES)] -BATCH_SIZE = 150 -N_EPOCHS = 1000 +BATCH_SIZE = 500 +N_EPOCHS = 500 def total_magnetization(n_qubits:int) -> Callable: paulis = [Z(i) for i in range(n_qubits)] @@ -237,26 +239,28 @@ def total_magnetization(n_qubits:int) -> Callable: return inner(out_state, projected_state).real return _total_magnetization - -class DQC(Circuit): +@dataclass +class DQC(QuantumCircuit): def __post_init__(self) -> None: - self.ansatz = group_by_index(self.ansatz) + self.operations = group_by_index(self.operations) self.observable = total_magnetization(self.n_qubits) self.state = zero_state(self.n_qubits) - def __call__(self, param_vals: Array, x: Array, y: Array) -> Array: - param_dict = {name: val for name, val in zip(self.param_names, param_vals)} + + def __call__(self, values: dict[str, Array], x: Array, y: Array) -> Array: + param_dict = {name: val for name, val in zip(self.vparams, values)} out_state = apply_gate( - self.state, self.feature_map + self.ansatz, {**param_dict, **{"x": x, "y": y}} + self.state, self.operations, {**param_dict, **{"f_x": x, "f_y": y}} ) return self.observable(out_state, {}) -fm = [RX("x", i) for i in range(N_QUBITS // 2)] + [ - RX("y", i) for i in range(N_QUBITS // 2, N_QUBITS) +fm = [RX("f_x", i) for i in range(N_QUBITS // 2)] + [ + RX("f_y", i) for i in range(N_QUBITS // 2, N_QUBITS) ] +fm_circuit_parameters = [f.param for f in fm] ansatz = hea(N_QUBITS, DEPTH) -circ = DQC(N_QUBITS, fm, ansatz) +circ = DQC(N_QUBITS, fm + ansatz, fm_circuit_parameters) # Create random initial values for the parameters key = jax.random.PRNGKey(42) param_vals = jax.random.uniform(key, shape=(circ.n_vparams,)) diff --git a/horqrux/__init__.py b/horqrux/__init__.py index f23cf20..513c031 100644 --- a/horqrux/__init__.py +++ b/horqrux/__init__.py @@ -2,9 +2,11 @@ from .api import expectation from .apply import apply_gate, apply_operator +from .circuit import QuantumCircuit, sample from .parametric import PHASE, RX, RY, RZ from .primitive import NOT, SWAP, H, I, S, T, X, Y, Z from .utils import ( + DiffMode, equivalent_state, hilbert_reshape, overlap, diff --git a/horqrux/_misc.py b/horqrux/_misc.py index 894c3c3..e7e58ff 100644 --- a/horqrux/_misc.py +++ b/horqrux/_misc.py @@ -1,9 +1,12 @@ +from __future__ import annotations + import jax import jax.numpy as jnp from jax._src.typing import DType + def default_complex_dtype() -> DType: if jax.config.jax_enable_x64: return jnp.complex128 else: - return jnp.complex64 \ No newline at end of file + return jnp.complex64 diff --git a/horqrux/adjoint.py b/horqrux/adjoint.py index f8f076f..de7f925 100644 --- a/horqrux/adjoint.py +++ b/horqrux/adjoint.py @@ -3,7 +3,6 @@ from typing import Tuple from jax import Array, custom_vjp -from jax import numpy as jnp from horqrux.apply import apply_gate from horqrux.parametric import Parametric @@ -11,15 +10,15 @@ from horqrux.utils import OperationType, inner -def expectation( - state: Array, gates: list[Primitive], observable: list[Primitive], values: dict[str, float] +def ad_expectation( + state: Array, gates: list[Primitive], observables: list[Primitive], values: dict[str, float] ) -> Array: """ Run 'state' through a sequence of 'gates' given parameters 'values' and compute the expectation given an observable. """ out_state = apply_gate(state, gates, values, OperationType.UNITARY) - projected_state = apply_gate(out_state, observable, values, OperationType.UNITARY) + projected_state = apply_gate(out_state, observables, values, OperationType.UNITARY) return inner(out_state, projected_state).real @@ -27,21 +26,13 @@ def expectation( def __adjoint_expectation_single_observable( state: Array, gates: list[Primitive], observable: Primitive, values: dict[str, float] ) -> Array: - return expectation(state, gates, [observable], values) + return ad_expectation(state, gates, [observable], values) def adjoint_expectation( state: Array, gates: GateSequence, observables: list[Primitive], values: dict[str, float] ) -> Array: - """ - Run 'state' through a sequence of 'gates' given parameters 'values' - and compute the expectation given an observable. - """ - outputs = [ - __adjoint_expectation_single_observable(state, gates, observable, values) - for observable in observables - ] - return jnp.stack(outputs) + return ad_expectation(state, gates, observables, values) # type: ignore[arg-type] def adjoint_expectation_single_observable_fwd( diff --git a/horqrux/apply.py b/horqrux/apply.py index 5bd054c..f82cf9f 100644 --- a/horqrux/apply.py +++ b/horqrux/apply.py @@ -43,8 +43,8 @@ def apply_operator( if is_controlled(control): operator = _controlled(operator, len(control)) state_dims = (*control, *target) # type: ignore[arg-type] - n_qubits = int(np.log2(operator.size)) - operator = operator.reshape(tuple(2 for _ in np.arange(n_qubits))) + n_qubits = int(np.log2(operator.shape[1])) + operator = operator.reshape(tuple(2 for _ in np.arange(2 * n_qubits))) op_dims = tuple(np.arange(operator.ndim // 2, operator.ndim, dtype=int)) state = jnp.tensordot(a=operator, b=state, axes=(op_dims, state_dims)) new_state_dims = tuple(i for i in range(len(state_dims))) diff --git a/horqrux/circuit.py b/horqrux/circuit.py index ad47d2a..ae4bb03 100644 --- a/horqrux/circuit.py +++ b/horqrux/circuit.py @@ -1,47 +1,77 @@ from __future__ import annotations -from dataclasses import dataclass +from collections import Counter +from dataclasses import dataclass, field from typing import Any, Callable from uuid import uuid4 +import jax +import jax.numpy as jnp from jax import Array from jax.tree_util import register_pytree_node_class +from horqrux.adjoint import ad_expectation, adjoint_expectation from horqrux.apply import apply_gate from horqrux.parametric import RX, RY, Parametric from horqrux.primitive import NOT, Primitive -from horqrux.utils import zero_state +from horqrux.utils import DiffMode, zero_state @register_pytree_node_class @dataclass -class Circuit: - """A minimalistic circuit class to store a sequence of gates.""" +class QuantumCircuit: + """A minimalistic circuit class to store a sequence of gates. - n_qubits: int - feature_map: list[Primitive] - ansatz: list[Primitive] + Attributes: + n_qubits (int): Number of qubits. + operations (list[Primitive]): Operations defining the circuit. + fparams (list[str]): List of parameters that are considered + non trainable, used for passing fixed input data to a quantum circuit. + The corresponding operations compose the `feature map`. + """ - def __post_init__(self) -> None: - self.state = zero_state(self.n_qubits) + n_qubits: int + operations: list[Primitive] + fparams: list[str] = field(default_factory=list) - def __call__(self, param_values: Array) -> Array: - return apply_gate( - self.state, - self.feature_map + self.ansatz, - {name: val for name, val in zip(self.param_names, param_values)}, - ) + def __call__(self, state: Array, values: dict[str, Array]) -> Array: + if state is None: + state = zero_state(self.n_qubits) + return apply_gate(state, self.operations, values) @property def param_names(self) -> list[str]: - return [str(op.param) for op in self.ansatz if isinstance(op, Parametric)] + """List of parameters of the circuit. + Composed of variational and feature map parameters. + + Returns: + list[str]: Names of parameters. + """ + return [str(op.param) for op in self.operations if isinstance(op, Parametric)] + + @property + def vparams(self) -> list[str]: + """List of variational parameters of the circuit. + + Returns: + list[str]: Names of variational parameters. + """ + return [name for name in self.param_names if name not in self.fparams] @property def n_vparams(self) -> int: - return len(self.param_names) + """Number of variational parameters. + + Returns: + int: Number of variational parameters. + """ + return len(self.param_names) - len(self.fparams) def tree_flatten(self) -> tuple: - children = (self.feature_map, self.ansatz) + children = ( + self.operations, + self.fparams, + ) aux_data = (self.n_qubits,) return (aux_data, children) @@ -50,15 +80,32 @@ def tree_unflatten(cls, aux_data: Any, children: Any) -> Any: return cls(*aux_data, *children) -def hea(n_qubits: int, n_layers: int, rot_fns: list[Callable] = [RX, RY, RX]) -> list[Primitive]: +def hea( + n_qubits: int, + n_layers: int, + rot_fns: list[Callable] = [RX, RY, RX], + variational_param_prefix: str = "v_", +) -> list[Primitive]: """Hardware-efficient ansatz; A helper function to generate a sequence of rotations followed - by a global entangling operation.""" + by a global entangling operation. + + Args: + n_qubits (int): Number of qubits. + n_layers (int): Number of layers + rot_fns (list[Callable], optional): A list of rotations applied on one qubit. + Defaults to [RX, RY, RX]. + variational_param_prefix (str, optional): Prefix for the name of variational parameters. + Defaults to "v_". Names suffix are randomly generated strings with uuid4. + + Returns: + list[Primitive]: List of gates composing the ansatz. + """ gates = [] param_names = [] for _ in range(n_layers): for i in range(n_qubits): ops = [ - fn(str(uuid4()), qubit) + fn(variational_param_prefix + str(uuid4()), qubit) for fn, qubit in zip(rot_fns, [i for _ in range(len(rot_fns))]) ] param_names += [op.param for op in ops] @@ -66,3 +113,48 @@ def hea(n_qubits: int, n_layers: int, rot_fns: list[Callable] = [RX, RY, RX]) -> gates += ops return gates + + +def expectation( + state: Array, + gates: list[Primitive], + observable: list[Primitive], + values: dict[str, float], + diff_mode: DiffMode | str = DiffMode.AD, +) -> Array: + """ + Run 'state' through a sequence of 'gates' given parameters 'values' + and compute the expectation given an observable. + """ + if diff_mode == DiffMode.AD: + return ad_expectation(state, gates, observable, values) + else: + return adjoint_expectation(state, gates, observable, values) + + +def sample( + state: Array, + gates: list[Primitive], + values: dict[str, float] = dict(), + n_shots: int = 1000, +) -> Counter: + if n_shots < 1: + raise ValueError("You can only call sample with n_shots>0.") + + wf = apply_gate(state, gates, values) + probs = jnp.abs(jnp.float_power(wf, 2.0)).ravel() + key = jax.random.PRNGKey(0) + n_qubits = len(state.shape) + # JAX handles pseudo random number generation by tracking an explicit state via a random key + # For more details, see https://jax.readthedocs.io/en/latest/random-numbers.html + samples = jax.vmap( + lambda subkey: jax.random.choice(key=subkey, a=jnp.arange(0, 2**n_qubits), p=probs) + )(jax.random.split(key, n_shots)) + + return Counter( + { + format(k, "0{}b".format(n_qubits)): count.item() + for k, count in enumerate(jnp.bincount(samples)) + if count > 0 + } + ) diff --git a/horqrux/parametric.py b/horqrux/parametric.py index 49320d8..80b925b 100644 --- a/horqrux/parametric.py +++ b/horqrux/parametric.py @@ -72,9 +72,7 @@ def name(self) -> str: return "C" + base_name if is_controlled(self.control) else base_name def __repr__(self) -> str: - return ( - self.name + f"(target={self.target[0]}, control={self.control[0]}, param={self.param})" - ) + return self.name + f"(target={self.target}, control={self.control}, param={self.param})" def RX(param: float | str, target: TargetQubits, control: ControlQubits = (None,)) -> Parametric: diff --git a/horqrux/primitive.py b/horqrux/primitive.py index 2aac3a4..7156117 100644 --- a/horqrux/primitive.py +++ b/horqrux/primitive.py @@ -69,7 +69,7 @@ def name(self) -> str: return "C" + self.generator_name if is_controlled(self.control) else self.generator_name def __repr__(self) -> str: - return self.name + f"(target={self.target[0]}, control={self.control[0]})" + return self.name + f"(target={self.target}, control={self.control})" GateSequence = Union[Primitive, Iterable[Primitive]] diff --git a/horqrux/utils.py b/horqrux/utils.py index 6587cdd..2cc82f4 100644 --- a/horqrux/utils.py +++ b/horqrux/utils.py @@ -39,9 +39,16 @@ class OperationType(StrEnum): class DiffMode(StrEnum): + """Differentiation mode.""" + AD = "ad" + """Automatic Differentiation - Using the autograd engine of JAX.""" ADJOINT = "adjoint" + """Adjoint Differentiation - An implementation of "Efficient calculation of gradients + in classical simulations of variational quantum algorithms", + Jones & Gacon, 2020.""" GPSR = "gpsr" + """Generalized parameter shift rule.""" class ForwardMode(StrEnum): diff --git a/tests/test_adjoint.py b/tests/test_adjoint.py index 80fa144..27f0164 100644 --- a/tests/test_adjoint.py +++ b/tests/test_adjoint.py @@ -5,10 +5,10 @@ from jax import Array, grad from horqrux import random_state -from horqrux.adjoint import adjoint_expectation -from horqrux.api import expectation +from horqrux.circuit import expectation from horqrux.parametric import PHASE, RX, RY, RZ from horqrux.primitive import NOT, H, I, S, T, X, Y, Z +from horqrux.utils import DiffMode MAX_QUBITS = 7 PARAMETRIC_GATES = (RX, RY, RZ, PHASE) @@ -26,13 +26,10 @@ def test_gradcheck() -> None: } state = random_state(MAX_QUBITS) - def adjoint_expfn(values) -> Array: - return adjoint_expectation(state, ops, observable, values)[0] + def exp_fn(values: dict, diff_mode: DiffMode = "ad") -> Array: + return expectation(state, ops, observable, values, diff_mode) - def ad_expfn(values) -> Array: - return expectation(state, ops, observable, values)[0] - - grads_adjoint = grad(adjoint_expfn)(values) - grad_ad = grad(ad_expfn)(values) + grads_adjoint = grad(exp_fn)(values, "adjoint") + grad_ad = grad(exp_fn)(values) for param, ad_grad in grad_ad.items(): - assert jnp.isclose(grads_adjoint[param], ad_grad, atol=0.09) + assert jnp.isclose(grads_adjoint[param], ad_grad, atol=1.0e-3) diff --git a/tests/test_analog.py b/tests/test_analog.py index c9b7d95..bf24dc0 100644 --- a/tests/test_analog.py +++ b/tests/test_analog.py @@ -34,7 +34,12 @@ def Hamiltonian_general(n_qubits: int = 2, batch_size: int = 1) -> jnp.array: return H_batch -@pytest.mark.parametrize("n_qubits, batch_size", [(2, 1), (4, 2)]) +@pytest.mark.parametrize( + "n_qubits, batch_size", + [ + (2, 1), + ], +) def test_hamevo_general(n_qubits: int, batch_size: int) -> None: H = Hamiltonian_general(n_qubits, batch_size) t_evo = np.random.uniform(0, 1, (batch_size, 1))