Skip to content

Commit

Permalink
[Feature] Add sample, Use single precision by default (#16)
Browse files Browse the repository at this point in the history
* [Feature] Use single precision by default

* refac expectation like pyq

* Add sample, increase atol

* rework circ

* Remove spurious import.

* Lint.

* change definition of a quantum circuit

* more docstr on circuit

* lint

* rename to fparams and vparams

---------

Co-authored-by: Roland Guichard <[email protected]>
  • Loading branch information
dominikandreasseitz and RolandMacDoland authored Dec 5, 2024
1 parent 7548885 commit 5fd2dad
Show file tree
Hide file tree
Showing 11 changed files with 173 additions and 74 deletions.
46 changes: 25 additions & 21 deletions docs/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -110,11 +110,11 @@ from operator import add
from typing import Any, Callable
from uuid import uuid4

from horqrux.adjoint import adjoint_expectation
from horqrux.circuit import Circuit, hea
from horqrux.circuit import QuantumCircuit, hea, expectation
from horqrux.primitive import Primitive
from horqrux.parametric import Parametric
from horqrux import Z, RX, RY, NOT, zero_state, apply_gate
from horqrux.utils import DiffMode


n_qubits = 5
Expand All @@ -128,19 +128,21 @@ fn = lambda x, degree: .05 * reduce(add, (jnp.cos(i*x) + jnp.sin(i*x) for i in r
x = jnp.linspace(0, 10, 100)
y = fn(x, 5)


class DQC(Circuit):
@dataclass
class DQC(QuantumCircuit):
def __post_init__(self) -> None:
self.observable: list[Primitive] = [Z(0)]
self.state = zero_state(self.n_qubits)

@partial(vmap, in_axes=(None, None, 0))
def __call__(self, param_values: Array, x: Array) -> Array:
param_dict = {name: val for name, val in zip(self.param_names, param_values)}
return adjoint_expectation(self.state, self.feature_map + self.ansatz, self.observable, {**param_dict, **{'phi': x}})[0]

param_dict = {name: val for name, val in zip(self.vparams, param_values)}
return expectation(self.state, self.operations, self.observable, {**param_dict, **{'phi': x}}, DiffMode.ADJOINT)

circ = DQC(n_qubits=n_qubits, feature_map=[RX('phi', i) for i in range(n_qubits)], ansatz=hea(n_qubits, n_layers))
feature_map = [RX('phi', i) for i in range(n_qubits)]
fm_names = [f.param for f in feature_map]
ansatz = hea(n_qubits, n_layers)
circ = DQC(n_qubits=n_qubits, operations=feature_map + ansatz, fparams=fm_names)
# Create random initial values for the parameters
key = jax.random.PRNGKey(42)
param_vals = jax.random.uniform(key, shape=(circ.n_vparams,))
Expand Down Expand Up @@ -212,20 +214,20 @@ from jax import Array, jit, value_and_grad, vmap
from numpy.random import uniform

from horqrux.apply import group_by_index
from horqrux.circuit import Circuit, hea
from horqrux.circuit import QuantumCircuit, hea
from horqrux import NOT, RX, RY, Z, apply_gate, zero_state
from horqrux.primitive import Primitive
from horqrux.parametric import Parametric
from horqrux.utils import inner

LEARNING_RATE = 0.01
LEARNING_RATE = 0.15
N_QUBITS = 4
DEPTH = 3
VARIABLES = ("x", "y")
NUM_VARIABLES = len(VARIABLES)
X_POS, Y_POS = [i for i in range(NUM_VARIABLES)]
BATCH_SIZE = 150
N_EPOCHS = 1000
BATCH_SIZE = 500
N_EPOCHS = 500

def total_magnetization(n_qubits:int) -> Callable:
paulis = [Z(i) for i in range(n_qubits)]
Expand All @@ -237,26 +239,28 @@ def total_magnetization(n_qubits:int) -> Callable:
return inner(out_state, projected_state).real
return _total_magnetization


class DQC(Circuit):
@dataclass
class DQC(QuantumCircuit):
def __post_init__(self) -> None:
self.ansatz = group_by_index(self.ansatz)
self.operations = group_by_index(self.operations)
self.observable = total_magnetization(self.n_qubits)
self.state = zero_state(self.n_qubits)

def __call__(self, param_vals: Array, x: Array, y: Array) -> Array:
param_dict = {name: val for name, val in zip(self.param_names, param_vals)}

def __call__(self, values: dict[str, Array], x: Array, y: Array) -> Array:
param_dict = {name: val for name, val in zip(self.vparams, values)}
out_state = apply_gate(
self.state, self.feature_map + self.ansatz, {**param_dict, **{"x": x, "y": y}}
self.state, self.operations, {**param_dict, **{"f_x": x, "f_y": y}}
)
return self.observable(out_state, {})


fm = [RX("x", i) for i in range(N_QUBITS // 2)] + [
RX("y", i) for i in range(N_QUBITS // 2, N_QUBITS)
fm = [RX("f_x", i) for i in range(N_QUBITS // 2)] + [
RX("f_y", i) for i in range(N_QUBITS // 2, N_QUBITS)
]
fm_circuit_parameters = [f.param for f in fm]
ansatz = hea(N_QUBITS, DEPTH)
circ = DQC(N_QUBITS, fm, ansatz)
circ = DQC(N_QUBITS, fm + ansatz, fm_circuit_parameters)
# Create random initial values for the parameters
key = jax.random.PRNGKey(42)
param_vals = jax.random.uniform(key, shape=(circ.n_vparams,))
Expand Down
2 changes: 2 additions & 0 deletions horqrux/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,11 @@

from .api import expectation
from .apply import apply_gate, apply_operator
from .circuit import QuantumCircuit, sample
from .parametric import PHASE, RX, RY, RZ
from .primitive import NOT, SWAP, H, I, S, T, X, Y, Z
from .utils import (
DiffMode,
equivalent_state,
hilbert_reshape,
overlap,
Expand Down
5 changes: 4 additions & 1 deletion horqrux/_misc.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
from __future__ import annotations

import jax
import jax.numpy as jnp
from jax._src.typing import DType


def default_complex_dtype() -> DType:
if jax.config.jax_enable_x64:
return jnp.complex128
else:
return jnp.complex64
return jnp.complex64
19 changes: 5 additions & 14 deletions horqrux/adjoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,45 +3,36 @@
from typing import Tuple

from jax import Array, custom_vjp
from jax import numpy as jnp

from horqrux.apply import apply_gate
from horqrux.parametric import Parametric
from horqrux.primitive import GateSequence, Primitive
from horqrux.utils import OperationType, inner


def expectation(
state: Array, gates: list[Primitive], observable: list[Primitive], values: dict[str, float]
def ad_expectation(
state: Array, gates: list[Primitive], observables: list[Primitive], values: dict[str, float]
) -> Array:
"""
Run 'state' through a sequence of 'gates' given parameters 'values'
and compute the expectation given an observable.
"""
out_state = apply_gate(state, gates, values, OperationType.UNITARY)
projected_state = apply_gate(out_state, observable, values, OperationType.UNITARY)
projected_state = apply_gate(out_state, observables, values, OperationType.UNITARY)
return inner(out_state, projected_state).real


@custom_vjp
def __adjoint_expectation_single_observable(
state: Array, gates: list[Primitive], observable: Primitive, values: dict[str, float]
) -> Array:
return expectation(state, gates, [observable], values)
return ad_expectation(state, gates, [observable], values)


def adjoint_expectation(
state: Array, gates: GateSequence, observables: list[Primitive], values: dict[str, float]
) -> Array:
"""
Run 'state' through a sequence of 'gates' given parameters 'values'
and compute the expectation given an observable.
"""
outputs = [
__adjoint_expectation_single_observable(state, gates, observable, values)
for observable in observables
]
return jnp.stack(outputs)
return ad_expectation(state, gates, observables, values) # type: ignore[arg-type]


def adjoint_expectation_single_observable_fwd(
Expand Down
4 changes: 2 additions & 2 deletions horqrux/apply.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,8 @@ def apply_operator(
if is_controlled(control):
operator = _controlled(operator, len(control))
state_dims = (*control, *target) # type: ignore[arg-type]
n_qubits = int(np.log2(operator.size))
operator = operator.reshape(tuple(2 for _ in np.arange(n_qubits)))
n_qubits = int(np.log2(operator.shape[1]))
operator = operator.reshape(tuple(2 for _ in np.arange(2 * n_qubits)))
op_dims = tuple(np.arange(operator.ndim // 2, operator.ndim, dtype=int))
state = jnp.tensordot(a=operator, b=state, axes=(op_dims, state_dims))
new_state_dims = tuple(i for i in range(len(state_dims)))
Expand Down
134 changes: 113 additions & 21 deletions horqrux/circuit.py
Original file line number Diff line number Diff line change
@@ -1,47 +1,77 @@
from __future__ import annotations

from dataclasses import dataclass
from collections import Counter
from dataclasses import dataclass, field
from typing import Any, Callable
from uuid import uuid4

import jax
import jax.numpy as jnp
from jax import Array
from jax.tree_util import register_pytree_node_class

from horqrux.adjoint import ad_expectation, adjoint_expectation
from horqrux.apply import apply_gate
from horqrux.parametric import RX, RY, Parametric
from horqrux.primitive import NOT, Primitive
from horqrux.utils import zero_state
from horqrux.utils import DiffMode, zero_state


@register_pytree_node_class
@dataclass
class Circuit:
"""A minimalistic circuit class to store a sequence of gates."""
class QuantumCircuit:
"""A minimalistic circuit class to store a sequence of gates.
n_qubits: int
feature_map: list[Primitive]
ansatz: list[Primitive]
Attributes:
n_qubits (int): Number of qubits.
operations (list[Primitive]): Operations defining the circuit.
fparams (list[str]): List of parameters that are considered
non trainable, used for passing fixed input data to a quantum circuit.
The corresponding operations compose the `feature map`.
"""

def __post_init__(self) -> None:
self.state = zero_state(self.n_qubits)
n_qubits: int
operations: list[Primitive]
fparams: list[str] = field(default_factory=list)

def __call__(self, param_values: Array) -> Array:
return apply_gate(
self.state,
self.feature_map + self.ansatz,
{name: val for name, val in zip(self.param_names, param_values)},
)
def __call__(self, state: Array, values: dict[str, Array]) -> Array:
if state is None:
state = zero_state(self.n_qubits)
return apply_gate(state, self.operations, values)

@property
def param_names(self) -> list[str]:
return [str(op.param) for op in self.ansatz if isinstance(op, Parametric)]
"""List of parameters of the circuit.
Composed of variational and feature map parameters.
Returns:
list[str]: Names of parameters.
"""
return [str(op.param) for op in self.operations if isinstance(op, Parametric)]

@property
def vparams(self) -> list[str]:
"""List of variational parameters of the circuit.
Returns:
list[str]: Names of variational parameters.
"""
return [name for name in self.param_names if name not in self.fparams]

@property
def n_vparams(self) -> int:
return len(self.param_names)
"""Number of variational parameters.
Returns:
int: Number of variational parameters.
"""
return len(self.param_names) - len(self.fparams)

def tree_flatten(self) -> tuple:
children = (self.feature_map, self.ansatz)
children = (
self.operations,
self.fparams,
)
aux_data = (self.n_qubits,)
return (aux_data, children)

Expand All @@ -50,19 +80,81 @@ def tree_unflatten(cls, aux_data: Any, children: Any) -> Any:
return cls(*aux_data, *children)


def hea(n_qubits: int, n_layers: int, rot_fns: list[Callable] = [RX, RY, RX]) -> list[Primitive]:
def hea(
n_qubits: int,
n_layers: int,
rot_fns: list[Callable] = [RX, RY, RX],
variational_param_prefix: str = "v_",
) -> list[Primitive]:
"""Hardware-efficient ansatz; A helper function to generate a sequence of rotations followed
by a global entangling operation."""
by a global entangling operation.
Args:
n_qubits (int): Number of qubits.
n_layers (int): Number of layers
rot_fns (list[Callable], optional): A list of rotations applied on one qubit.
Defaults to [RX, RY, RX].
variational_param_prefix (str, optional): Prefix for the name of variational parameters.
Defaults to "v_". Names suffix are randomly generated strings with uuid4.
Returns:
list[Primitive]: List of gates composing the ansatz.
"""
gates = []
param_names = []
for _ in range(n_layers):
for i in range(n_qubits):
ops = [
fn(str(uuid4()), qubit)
fn(variational_param_prefix + str(uuid4()), qubit)
for fn, qubit in zip(rot_fns, [i for _ in range(len(rot_fns))])
]
param_names += [op.param for op in ops]
ops += [NOT((i + 1) % n_qubits, i % n_qubits) for i in range(n_qubits)] # type: ignore[arg-type]
gates += ops

return gates


def expectation(
state: Array,
gates: list[Primitive],
observable: list[Primitive],
values: dict[str, float],
diff_mode: DiffMode | str = DiffMode.AD,
) -> Array:
"""
Run 'state' through a sequence of 'gates' given parameters 'values'
and compute the expectation given an observable.
"""
if diff_mode == DiffMode.AD:
return ad_expectation(state, gates, observable, values)
else:
return adjoint_expectation(state, gates, observable, values)


def sample(
state: Array,
gates: list[Primitive],
values: dict[str, float] = dict(),
n_shots: int = 1000,
) -> Counter:
if n_shots < 1:
raise ValueError("You can only call sample with n_shots>0.")

wf = apply_gate(state, gates, values)
probs = jnp.abs(jnp.float_power(wf, 2.0)).ravel()
key = jax.random.PRNGKey(0)
n_qubits = len(state.shape)
# JAX handles pseudo random number generation by tracking an explicit state via a random key
# For more details, see https://jax.readthedocs.io/en/latest/random-numbers.html
samples = jax.vmap(
lambda subkey: jax.random.choice(key=subkey, a=jnp.arange(0, 2**n_qubits), p=probs)
)(jax.random.split(key, n_shots))

return Counter(
{
format(k, "0{}b".format(n_qubits)): count.item()
for k, count in enumerate(jnp.bincount(samples))
if count > 0
}
)
4 changes: 1 addition & 3 deletions horqrux/parametric.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,9 +72,7 @@ def name(self) -> str:
return "C" + base_name if is_controlled(self.control) else base_name

def __repr__(self) -> str:
return (
self.name + f"(target={self.target[0]}, control={self.control[0]}, param={self.param})"
)
return self.name + f"(target={self.target}, control={self.control}, param={self.param})"


def RX(param: float | str, target: TargetQubits, control: ControlQubits = (None,)) -> Parametric:
Expand Down
2 changes: 1 addition & 1 deletion horqrux/primitive.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def name(self) -> str:
return "C" + self.generator_name if is_controlled(self.control) else self.generator_name

def __repr__(self) -> str:
return self.name + f"(target={self.target[0]}, control={self.control[0]})"
return self.name + f"(target={self.target}, control={self.control})"


GateSequence = Union[Primitive, Iterable[Primitive]]
Expand Down
Loading

0 comments on commit 5fd2dad

Please sign in to comment.