diff --git a/README.md b/README.md index 3590687..2d3f427 100644 --- a/README.md +++ b/README.md @@ -1,46 +1,43 @@ -# horqrux +[![Linting / Tests/ Documentation](https://github.com/pasqal-io/horqrux/actions/workflows/run-tests-and-mypy.yml/badge.svg)](https://github.com/pasqal-io/horqrux/actions/workflows/run-tests-and-mypy.yml) +[![License](https://img.shields.io/badge/License-Apache_2.0-blue.svg)](https://opensource.org/licenses/Apache-2.0) +[![Pypi](https://badge.fury.io/py/horqrux.svg)](https://pypi.org/project/horqrux/) -**horqrux** is a [JAX](https://jax.readthedocs.io/en/latest/)-based state vector simulator designed for quantum machine learning. -It acts as a backend for [`Qadence`](https://github.com/pasqal-io/qadence), a digital-analog quantum programming interface. +`horqrux` is a [JAX](https://jax.readthedocs.io/en/latest/)-based state vector simulator designed for quantum machine learning and acts as a backend for [`Qadence`](https://github.com/pasqal-io/qadence), a digital-analog quantum programming interface. ## Installation -`horqrux` (CPU-only) can be installed from PyPI with `pip` as follows: +To install the CPU-only version, simply use `pip`: ```bash pip install horqrux ``` -If you want to install the GPU version, simply do: +If you intend to use GPU: ```bash pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_releases.html ``` -[![Linting / Tests/ Documentation](https://github.com/pasqal-io/horqrux/actions/workflows/run-tests-and-mypy.yml/badge.svg)](https://github.com/pasqal-io/horqrux/actions/workflows/run-tests-and-mypy.yml) -[![License](https://img.shields.io/badge/License-Apache_2.0-blue.svg)](https://opensource.org/licenses/Apache-2.0) -[![Pypi](https://badge.fury.io/py/horqrux.svg)](https://pypi.org/project/horqrux/) +## Getting started +`horqrux` adopts a minimalistic and functional interface however the [docs](https://pasqal-io.github.io/horqrux/latest/) provide a comprehensive A-Z guide ranging from how to apply simple primitive and parametric gates, to using [adjoint differentiation](https://arxiv.org/abs/2009.02823) to fit a nonlinear function and implementing [DQC](https://arxiv.org/abs/2011.10395) to solve a partial differential equation. +## Contributing -## Install from source +To learn how to contribute, please visit the [CONTRIBUTING](docs/CONTRIBUTING.md) page. -We recommend to use the [`hatch`](https://hatch.pypa.io/latest/) environment manager to install `horqrux` from source: +When developing within `horqrux`, you can either use the python environment manager [`hatch`](https://hatch.pypa.io/latest/): ```bash -python -m pip install hatch +pip install hatch -# get into a shell with all the dependencies -python -m hatch shell +# enter a shell with containing all the dependencies +hatch shell # run a command within the virtual environment with all the dependencies -python -m hatch run python my_script.py +hatch run python my_script.py ``` -Please note that `hatch` will not combine nicely with other environment managers such Conda. If you want to use Conda, install `horqrux` from source using `pip`: +When using any other environment manager like `venv` or `conda`, simply do: ```bash -# within the Conda environment -python -m pip install -e . +# within the virtual environment +pip install -e . ``` - -## Contributing - -Please refer to [CONTRIBUTING](docs/CONTRIBUTING.md) to learn how to contribute to `horqrux`. diff --git a/docs/index.md b/docs/index.md index 900e029..cb5fc71 100644 --- a/docs/index.md +++ b/docs/index.md @@ -11,7 +11,7 @@ choice and install it normally with `pip`: pip install horqrux ``` -## Gates +## Digital operations `horqrux` implements a large selection of both primitive and parametric single to n-qubit, digital quantum gates. @@ -68,10 +68,34 @@ param_value = 1 / 4 * jnp.pi new_state = apply_gate(state, RX(param_value, target_qubit, control_qubit)) ``` +## Analog Operations + +`horqrux` also allows for global state evolution via the `HamiltonianEvolution` operation. +Note that it expects a hamiltonian and a time evolution parameter passed as `numpy` or `jax.numpy` arrays. To build arbitrary Pauli hamiltonians, we recommend using [Qadence](https://github.com/pasqal-io/qadence/blob/main/examples/backends/low_level/horqrux_analog.py). + +```python exec="on" source="material-block" +from jax.numpy import pi, array, diag, kron, cdouble +from horqrux.analog import HamiltonianEvolution +from horqrux.apply import apply_gate +from horqrux.utils import uniform_state + +sigmaz = diag(array([1.0, -1.0], dtype=cdouble)) +Hbase = kron(sigmaz, sigmaz) + +Hamiltonian = kron(Hbase, Hbase) +n_qubits = 4 +t_evo = pi / 4 +hamevo = HamiltonianEvolution(tuple([i for i in range(n_qubits)])) +psi = uniform_state(n_qubits) +psi_star = apply_gate(psi, hamevo, {"hamiltonian": Hamiltonian, "time_evolution": t_evo}) +``` + +## Fitting a nonlinear function using adjoint differentiation + We can now build a fully differentiable variational circuit by simply defining a sequence of gates and a set of initial parameter values we want to optimize. -Horqrux provides an implementation of [adjoint differentiation](https://arxiv.org/abs/2009.02823), -which we can use to fit a function using a simple circuit class wrapper. +`horqrux` provides an implementation of [adjoint differentiation](https://arxiv.org/abs/2009.02823), +which we can use to fit a function using a simple `Circuit` class. ```python exec="on" source="material-block" html="1" from __future__ import annotations @@ -87,7 +111,7 @@ from typing import Any, Callable from uuid import uuid4 from horqrux.adjoint import adjoint_expectation -from horqrux.abstract import Primitive +from horqrux.primitive import Primitive from horqrux import Z, RX, RY, NOT, zero_state, apply_gate @@ -121,18 +145,16 @@ class Circuit: def __post_init__(self) -> None: # We will use a featuremap of RX rotations to encode some classical data - self.feature_map: list[Primitive] = [RX('phi', i) for i in range(n_qubits)] + self.feature_map: list[Primitive] = [RX('phi', i) for i in range(self.n_qubits)] self.ansatz, self.param_names = ansatz_w_params(self.n_qubits, self.n_layers) self.observable: list[Primitive] = [Z(0)] @partial(vmap, in_axes=(None, None, 0)) - def forward(self, param_values: Array, x: Array) -> Array: + def __call__(self, param_values: Array, x: Array) -> Array: state = zero_state(self.n_qubits) param_dict = {name: val for name, val in zip(self.param_names, param_values)} return adjoint_expectation(state, self.feature_map + self.ansatz, self.observable, {**param_dict, **{'phi': x}}) - def __call__(self, param_values: Array, x: Array) -> Array: - return self.forward(param_values, x) @property def n_vparams(self) -> int: @@ -154,15 +176,15 @@ def loss_fn(param_vals: Array, x: Array, y: Array) -> Array: return jnp.mean(optax.l2_loss(y_pred, y)) -def optimize_step(params: dict[str, Array], opt_state: Array, grads: dict[str, Array]) -> tuple: +def optimize_step(param_vals: Array, opt_state: Array, grads: Array) -> tuple: updates, opt_state = optimizer.update(grads, opt_state) - params = optax.apply_updates(params, updates) - return params, opt_state + param_vals = optax.apply_updates(param_vals, updates) + return param_vals, opt_state @jit -def train_step(i: int, inputs: tuple +def train_step(i: int, paramvals_w_optstate: tuple ) -> tuple: - param_vals, opt_state = inputs + param_vals, opt_state = paramvals_w_optstate loss, grads = value_and_grad(loss_fn)(param_vals, x, y) param_vals, opt_state = optimize_step(param_vals, opt_state, grads) return param_vals, opt_state @@ -188,3 +210,188 @@ def fig_to_html(fig: Figure) -> str: # markdown-exec: hide # from docs import docutils # markdown-exec: hide print(fig_to_html(plt.gcf())) # markdown-exec: hide ``` +## Fitting a partial differential equation using DQC + +Finally, we show how [DQC](https://arxiv.org/abs/2011.10395) can be implemented in `horqrux` and solve a partial differential equation. + +```python exec="on" source="material-block" html="1" +from __future__ import annotations + +from dataclasses import dataclass +from functools import reduce +from itertools import product +from operator import add +from uuid import uuid4 + +import jax +import jax.numpy as jnp +import matplotlib.pyplot as plt +import numpy as np +import optax +from jax import Array, jit, value_and_grad, vmap +from numpy.random import uniform + +from horqrux import NOT, RX, RY, Z, apply_gate, zero_state +from horqrux.primitive import Primitive +from horqrux.utils import inner + +LEARNING_RATE = 0.01 +N_QUBITS = 4 +DEPTH = 3 +VARIABLES = ("x", "y") +X_POS = 0 +Y_POS = 1 +N_POINTS = 150 +N_EPOCHS = 1000 + + +def ansatz_w_params(n_qubits: int, n_layers: int) -> tuple[list, list]: + all_ops = [] + param_names = [] + rots_fns = [RX, RY, RX] + for _ in range(n_layers): + for i in range(n_qubits): + ops = [ + fn(str(uuid4()), qubit) + for fn, qubit in zip(rots_fns, [i for _ in range(len(rots_fns))]) + ] + param_names += [op.param for op in ops] + ops += [NOT((i + 1) % n_qubits, i % n_qubits) for i in range(n_qubits)] + all_ops += ops + + return all_ops, param_names + + +@dataclass +class TotalMagnetization: + n_qubits: int + + def __post_init__(self) -> None: + self.paulis = [Z(i) for i in range(self.n_qubits)] + + def __call__(self, state: Array, values: dict) -> Array: + return reduce(add, [apply_gate(state, pauli, values) for pauli in self.paulis]) + + +@dataclass +class Circuit: + n_qubits: int + n_layers: int + + def __post_init__(self) -> None: + self.feature_map: list[Primitive] = [RX("x", i) for i in range(self.n_qubits // 2)] + [ + RX("y", i) for i in range(self.n_qubits // 2, self.n_qubits) + ] + self.ansatz, self.param_names = ansatz_w_params(self.n_qubits, self.n_layers) + self.observable = TotalMagnetization(self.n_qubits) + + def __call__(self, param_vals: Array, x: Array, y: Array) -> Array: + state = zero_state(self.n_qubits) + param_dict = {name: val for name, val in zip(self.param_names, param_vals)} + out_state = apply_gate( + state, self.feature_map + self.ansatz, {**param_dict, **{"x": x, "y": y}} + ) + projected_state = self.observable(state, param_dict) + return jnp.real(inner(out_state, projected_state)) + + @property + def n_vparams(self) -> int: + return len(self.param_names) + + +circ = Circuit(N_QUBITS, DEPTH) +# Create random initial values for the parameters +key = jax.random.PRNGKey(42) +param_vals = jax.random.uniform(key, shape=(circ.n_vparams,)) + +optimizer = optax.adam(learning_rate=0.01) +opt_state = optimizer.init(param_vals) + + +def exp_fn(param_vals: Array, x: Array, y: Array) -> Array: + return circ(param_vals, x, y) + + +def loss_fn(param_vals: Array, x: Array, y: Array) -> Array: + def pde_loss(x: float, y: float) -> Array: + l_b, r_b, t_b, b_b = list( + map( + lambda xy: exp_fn(param_vals, *xy), + [ + [jnp.zeros((1, 1)), y], # u(0,y)=0 + [jnp.ones((1, 1)), y], # u(L,y)=0 + [x, jnp.ones((1, 1))], # u(x,H)=0 + [x, jnp.zeros((1, 1))], # u(x,0)=f(x) + ], + ) + ) + b_b -= jnp.sin(jnp.pi * x) + hessian = jax.hessian(lambda xy: exp_fn(param_vals, xy[0], xy[1]))( + jnp.concatenate( + [ + x.reshape( + 1, + ), + y.reshape( + 1, + ), + ] + ) + ) + interior = hessian[X_POS][X_POS] + hessian[Y_POS][Y_POS] # uxx+uyy=0 + return reduce(add, list(map(lambda term: jnp.power(term, 2), [l_b, r_b, t_b, b_b, interior]))) + + return jnp.mean(vmap(pde_loss, in_axes=(0, 0))(x, y)) + + +def optimize_step(param_vals: Array, opt_state: Array, grads: dict[str, Array]) -> tuple: + updates, opt_state = optimizer.update(grads, opt_state, param_vals) + param_vals = optax.apply_updates(param_vals, updates) + return param_vals, opt_state + + +# collocation points sampling and training +def sample_points(n_in: int, n_p: int) -> Array: + return uniform(0, 1.0, (n_in, n_p)) + + +@jit +def train_step(i: int, paramvals_w_optstate: tuple) -> tuple: + param_vals, opt_state = paramvals_w_optstate + x, y = sample_points(2, N_POINTS) + loss, grads = value_and_grad(loss_fn)(param_vals, x, y) + return optimize_step(param_vals, opt_state, grads) + + +param_vals, opt_state = jax.lax.fori_loop(0, N_EPOCHS, train_step, (param_vals, opt_state)) +# compare the solution to known ground truth +single_domain = jnp.linspace(0, 1, num=N_POINTS) +domain = jnp.array(list(product(single_domain, single_domain))) +# analytical solution +analytic_sol = ( + (np.exp(-np.pi * domain[:, 0]) * np.sin(np.pi * domain[:, 1])).reshape(N_POINTS, N_POINTS).T +) +# DQC solution + +dqc_sol = vmap(lambda domain: exp_fn(param_vals, domain[0], domain[1]), in_axes=(0,))(domain).reshape( + N_POINTS, N_POINTS +) +# # plot results +fig, ax = plt.subplots(1, 2, figsize=(7, 7)) +ax[0].imshow(analytic_sol, cmap="turbo") +ax[0].set_xlabel("x") +ax[0].set_ylabel("y") +ax[0].set_title("Analytical solution u(x,y)") +ax[1].imshow(dqc_sol, cmap="turbo") +ax[1].set_xlabel("x") +ax[1].set_ylabel("y") +ax[1].set_title("DQC solution u(x,y)") +from io import StringIO # markdown-exec: hide +from matplotlib.figure import Figure # markdown-exec: hide +def fig_to_html(fig: Figure) -> str: # markdown-exec: hide + buffer = StringIO() # markdown-exec: hide + fig.savefig(buffer, format="svg") # markdown-exec: hide + return buffer.getvalue() # markdown-exec: hide +# from docs import docutils # markdown-exec: hide +print(fig_to_html(plt.gcf())) # markdown-exec: hide +``` diff --git a/horqrux/abstract.py b/horqrux/abstract.py deleted file mode 100644 index a53486a..0000000 --- a/horqrux/abstract.py +++ /dev/null @@ -1,129 +0,0 @@ -from __future__ import annotations - -from dataclasses import dataclass -from typing import Any, Iterable, Tuple - -import numpy as np -from jax import Array -from jax.tree_util import register_pytree_node_class - -from .matrices import OPERATIONS_DICT -from .utils import ( - ControlQubits, - QubitSupport, - TargetQubits, - _dagger, - _jacobian, - _unitary, - is_controlled, - none_like, -) - - -@register_pytree_node_class -@dataclass -class Primitive: - """Primitive gate class which stores information about generators target and control qubits - of a particular quantum operator.""" - - generator_name: str - target: QubitSupport - control: QubitSupport - - @staticmethod - def parse_idx( - idx: Tuple, - ) -> Tuple: - if isinstance(idx, (int, np.int64)): - return ((idx,),) - elif isinstance(idx, tuple): - return (idx,) - else: - return (idx.astype(int),) - - def __post_init__(self) -> None: - self.target = Primitive.parse_idx(self.target) - if self.control is None: - self.control = none_like(self.target) - else: - self.control = Primitive.parse_idx(self.control) - - def __iter__(self) -> Iterable: - return iter((self.generator_name, self.target, self.control)) - - def tree_flatten(self) -> Tuple[Tuple, Tuple[str, TargetQubits, ControlQubits]]: - children = () - aux_data = (self.generator_name, self.target[0], self.control[0]) - return (children, aux_data) - - @classmethod - def tree_unflatten(cls, aux_data: Any, children: Any) -> Any: - return cls(*children, *aux_data) - - def unitary(self, values: dict[str, float] = dict()) -> Array: - return OPERATIONS_DICT[self.generator_name] - - def dagger(self, values: dict[str, float] = dict()) -> Array: - return _dagger(self.unitary(values)) - - @property - def name(self) -> str: - return "C" + self.generator_name if is_controlled(self.control) else self.generator_name - - def __repr__(self) -> str: - return self.name + f"(target={self.target[0]}, control={self.control[0]})" - - -@register_pytree_node_class -@dataclass -class Parametric(Primitive): - """Extension of the Primitive class adding the option to pass a parameter.""" - - generator_name: str - target: QubitSupport - control: QubitSupport - param: str | float = "" - - def __post_init__(self) -> None: - super().__post_init__() - - def parse_dict(values: dict[str, float] = dict()) -> float: - return values[self.param] # type: ignore[index] - - def parse_val(values: dict[str, float] = dict()) -> float: - return self.param # type: ignore[return-value] - - self.parse_values = parse_dict if isinstance(self.param, str) else parse_val - - def tree_flatten(self) -> Tuple[Tuple, Tuple[str, Tuple, Tuple, str | float]]: # type: ignore[override] - children = () - aux_data = ( - self.generator_name, - self.target[0], - self.control[0], - self.param, - ) - return (children, aux_data) - - def __iter__(self) -> Iterable: - return iter((self.generator_name, self.target, self.control, self.param)) - - @classmethod - def tree_unflatten(cls, aux_data: Any, children: Any) -> Any: - return cls(*children, *aux_data) - - def unitary(self, values: dict[str, float] = dict()) -> Array: - return _unitary(OPERATIONS_DICT[self.generator_name], self.parse_values(values)) - - def jacobian(self, values: dict[str, float] = dict()) -> Array: - return _jacobian(OPERATIONS_DICT[self.generator_name], self.parse_values(values)) - - @property - def name(self) -> str: - base_name = "R" + self.generator_name - return "C" + base_name if is_controlled(self.control) else base_name - - def __repr__(self) -> str: - return ( - self.name + f"(target={self.target[0]}, control={self.control[0]}, param={self.param})" - ) diff --git a/horqrux/adjoint.py b/horqrux/adjoint.py index e7c785c..1ca6913 100644 --- a/horqrux/adjoint.py +++ b/horqrux/adjoint.py @@ -5,8 +5,9 @@ from jax import Array, custom_vjp from jax.numpy import real as jnpreal -from horqrux.abstract import Parametric, Primitive from horqrux.apply import apply_gate +from horqrux.parametric import Parametric +from horqrux.primitive import Primitive from horqrux.utils import OperationType, inner diff --git a/horqrux/analog.py b/horqrux/analog.py index 28c4f3c..c8acb4e 100644 --- a/horqrux/analog.py +++ b/horqrux/analog.py @@ -6,7 +6,7 @@ from jax.scipy.linalg import expm from jax.tree_util import register_pytree_node_class -from .abstract import Primitive, QubitSupport +from .primitive import Primitive, QubitSupport @register_pytree_node_class diff --git a/horqrux/apply.py b/horqrux/apply.py index 6395922..77d2d10 100644 --- a/horqrux/apply.py +++ b/horqrux/apply.py @@ -8,7 +8,7 @@ import numpy as np from jax import Array -from horqrux.abstract import Primitive +from horqrux.primitive import Primitive from .utils import OperationType, State, _controlled, is_controlled diff --git a/horqrux/parametric.py b/horqrux/parametric.py index 406cf12..bd5d488 100644 --- a/horqrux/parametric.py +++ b/horqrux/parametric.py @@ -1,10 +1,77 @@ from __future__ import annotations +from dataclasses import dataclass +from typing import Any, Iterable, Tuple + import jax.numpy as jnp from jax import Array +from jax.tree_util import register_pytree_node_class + +from .matrices import OPERATIONS_DICT +from .primitive import Primitive +from .utils import ( + ControlQubits, + QubitSupport, + TargetQubits, + _jacobian, + _unitary, + is_controlled, +) + + +@register_pytree_node_class +@dataclass +class Parametric(Primitive): + """Extension of the Primitive class adding the option to pass a parameter.""" + + generator_name: str + target: QubitSupport + control: QubitSupport + param: str | float = "" + + def __post_init__(self) -> None: + super().__post_init__() + + def parse_dict(values: dict[str, float] = dict()) -> float: + return values[self.param] # type: ignore[index] + + def parse_val(values: dict[str, float] = dict()) -> float: + return self.param # type: ignore[return-value] + + self.parse_values = parse_dict if isinstance(self.param, str) else parse_val + + def tree_flatten(self) -> Tuple[Tuple, Tuple[str, Tuple, Tuple, str | float]]: # type: ignore[override] + children = () + aux_data = ( + self.generator_name, + self.target[0], + self.control[0], + self.param, + ) + return (children, aux_data) + + def __iter__(self) -> Iterable: + return iter((self.generator_name, self.target, self.control, self.param)) + + @classmethod + def tree_unflatten(cls, aux_data: Any, children: Any) -> Any: + return cls(*children, *aux_data) + + def unitary(self, values: dict[str, float] = dict()) -> Array: + return _unitary(OPERATIONS_DICT[self.generator_name], self.parse_values(values)) + + def jacobian(self, values: dict[str, float] = dict()) -> Array: + return _jacobian(OPERATIONS_DICT[self.generator_name], self.parse_values(values)) + + @property + def name(self) -> str: + base_name = "R" + self.generator_name + return "C" + base_name if is_controlled(self.control) else base_name -from .abstract import Parametric -from .utils import ControlQubits, TargetQubits, is_controlled + def __repr__(self) -> str: + return ( + self.name + f"(target={self.target[0]}, control={self.control[0]}, param={self.param})" + ) def RX(param: float | str, target: TargetQubits, control: ControlQubits = (None,)) -> Parametric: diff --git a/horqrux/primitive.py b/horqrux/primitive.py index 59c0ee0..9790603 100644 --- a/horqrux/primitive.py +++ b/horqrux/primitive.py @@ -1,9 +1,75 @@ from __future__ import annotations -from .abstract import Primitive -from .utils import ControlQubits, TargetQubits - -# Single qubit gates +from dataclasses import dataclass +from typing import Any, Iterable, Tuple + +import numpy as np +from jax import Array +from jax.tree_util import register_pytree_node_class + +from .matrices import OPERATIONS_DICT +from .utils import ( + ControlQubits, + QubitSupport, + TargetQubits, + _dagger, + is_controlled, + none_like, +) + + +@register_pytree_node_class +@dataclass +class Primitive: + """Primitive gate class which stores information about generators target and control qubits + of a particular quantum operator.""" + + generator_name: str + target: QubitSupport + control: QubitSupport + + @staticmethod + def parse_idx( + idx: Tuple, + ) -> Tuple: + if isinstance(idx, (int, np.int64)): + return ((idx,),) + elif isinstance(idx, tuple): + return (idx,) + else: + return (idx.astype(int),) + + def __post_init__(self) -> None: + self.target = Primitive.parse_idx(self.target) + if self.control is None: + self.control = none_like(self.target) + else: + self.control = Primitive.parse_idx(self.control) + + def __iter__(self) -> Iterable: + return iter((self.generator_name, self.target, self.control)) + + def tree_flatten(self) -> Tuple[Tuple, Tuple[str, TargetQubits, ControlQubits]]: + children = () + aux_data = (self.generator_name, self.target[0], self.control[0]) + return (children, aux_data) + + @classmethod + def tree_unflatten(cls, aux_data: Any, children: Any) -> Any: + return cls(*children, *aux_data) + + def unitary(self, values: dict[str, float] = dict()) -> Array: + return OPERATIONS_DICT[self.generator_name] + + def dagger(self, values: dict[str, float] = dict()) -> Array: + return _dagger(self.unitary(values)) + + @property + def name(self) -> str: + return "C" + self.generator_name if is_controlled(self.control) else self.generator_name + + def __repr__(self) -> str: + return self.name + f"(target={self.target[0]}, control={self.control[0]})" def I(target: TargetQubits, control: ControlQubits = (None,)) -> Primitive: diff --git a/pyproject.toml b/pyproject.toml index dd5e050..a326cd6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -12,7 +12,7 @@ authors = [ requires-python = ">=3.8,<3.13" license = {text = "Apache 2.0"} -version = "0.6.0" +version = "0.6.1" classifiers=[ "License :: Other/Proprietary License",