Skip to content

Commit

Permalink
Improved tests
Browse files Browse the repository at this point in the history
  • Loading branch information
dominikandreasseitz committed Jan 23, 2024
1 parent e6c11b7 commit abd2f55
Show file tree
Hide file tree
Showing 5 changed files with 105 additions and 55 deletions.
6 changes: 4 additions & 2 deletions horqrux/abstract.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import Any, Iterable, Tuple
from typing import Any, Callable, Iterable, Tuple

import numpy as np
from jax import Array
Expand Down Expand Up @@ -79,7 +79,9 @@ def __post_init__(self) -> None:
def parse_dict(values: dict[str, float] = {}) -> float:
return values[self.param]

self.parse_values = parse_dict if isinstance(self.param, str) else lambda x: self.param
self.parse_values: Callable[[dict[str, float]], float] = (
parse_dict if isinstance(self.param, str) else lambda x: self.param
)

def tree_flatten(self) -> Tuple[Tuple, Tuple[str, QubitSupport, QubitSupport, str]]:
children = ()
Expand Down
39 changes: 18 additions & 21 deletions horqrux/apply.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,52 +10,49 @@

from horqrux.abstract import Operator

from .utils import ControlQubits, State, TargetQubits, _controlled, is_controlled
from .utils import State, _controlled, is_controlled


def apply_operator(
state: State,
operator: Array,
target: TargetQubits,
control: ControlQubits,
target: Tuple[int, ...],
control: Tuple[int | None, ...],
) -> State:
"""Applies a single or series of operators to the given state.
The operators are expected to either be an array over whose first axis we can iterate (e.g. [N_gates, 2 x 2])
or if you have a mix of single and multi qubit gates a tuple or list like [O_1, O_2, ...].
This function then sequentially applies this gates, adding control bits
as necessary and returning the state after applying all the gates.
"""Applies a single array corresponding to an operator to a given state
for a given set of target and control qubits.
Args:
state: Input state to operate on.
operator: List of arrays or array of operators to contract over the state.
target: Target indices, Tuple of Tuple of ints.
control: Control indices, Tuple of length target_idex of None or Tuple.
state: State to operate on.
operator: Array to contract over 'state'.
target: Tuple of target qubits on which to apply the 'operator' to.
control: Tuple of control qubits.
Returns:
Array: Changed state.
State after applying 'operator'.
"""
qubits = target
assert isinstance(control, tuple)
qubits: Tuple[int, ...] = target
if is_controlled(control):
operator = _controlled(operator, len(control))
qubits = (*control, *target)
n_qubits = int(np.log2(operator.size))
operator = operator.reshape(tuple(2 for _ in np.arange(n_qubits)))
op_dims = tuple(np.arange(operator.ndim // 2, operator.ndim, dtype=int))
state = jnp.tensordot(a=operator, b=state, axes=(op_dims, qubits))
return jnp.moveaxis(a=state, source=np.arange(len(qubits)), destination=qubits)
new_dims = tuple(i for i in range(len(qubits)))
return jnp.moveaxis(a=state, source=new_dims, destination=qubits)


def apply_gate(
state: State, gate: Operator | Iterable[Operator], values: dict[str, float] = {}
) -> State:
"""Applies a gate to given state. Essentially a simple wrapper around
apply_operator, see that docstring for more info.
"""Applies a gate or a series of gates to a given state.
This function sequentially applies 'gate', adding control bits
as necessary and returning the state after applying all the gates.
Arguments:
state (Array): State to operate on.
gate (Gate): Gate(s) to apply.
state: State to operate on.
gate: Gate(s) to apply.
Returns:
Array: Changed state.
Expand Down
35 changes: 19 additions & 16 deletions horqrux/parametric.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from jax import Array

from .abstract import Parametric
from .utils import ControlQubits, TargetQubits
from .utils import ControlQubits, TargetQubits, is_controlled


def RX(param: float | str, target: TargetQubits, control: ControlQubits = (None,)) -> Parametric:
Expand Down Expand Up @@ -49,6 +49,23 @@ def RZ(param: float | str, target: TargetQubits, control: ControlQubits = (None,
return Parametric("Z", target, control, param=param)


class _PHASE(Parametric):
def unitary(self, values: dict[str, float] = {}) -> Array:
u = jnp.eye(2, 2, dtype=jnp.complex128)
u = u.at[(1, 1)].set(jnp.exp(1.0j * self.parse_values(values)))
return u

def jacobian(self, values: dict[str, float] = {}) -> Array:
jac = jnp.zeros((2, 2), dtype=jnp.complex128)
jac = jac.at[(1, 1)].set(1j * jnp.exp(1.0j * self.parse_values(values)))
return jac

@property
def name(self) -> str:
base_name = "PHASE"
return "C" + base_name if is_controlled(self.control) else base_name


def PHASE(param: float, target: TargetQubits, control: ControlQubits = (None,)) -> Parametric:
"""Phase gate.
Expand All @@ -61,18 +78,4 @@ def PHASE(param: float, target: TargetQubits, control: ControlQubits = (None,))
Parametric: A Parametric gate object.
"""

def unitary(values: dict[str, float] = {}) -> Array:
u = jnp.eye(2, 2, dtype=jnp.complex128)
u = u.at[(1, 1)].set(jnp.exp(1.0j * values[param]))
return u

def jacobian(values: dict[str, float] = {}) -> Array:
jac = jnp.zeros((2, 2), dtype=jnp.complex128)
jac = jac.at[(1, 1)].set(1j * jnp.exp(1.0j * values[param]))
return jac

phase = Parametric("I", target, control, param)
phase.name = "PHASE"
phase.unitary = unitary
phase.jacobian = jacobian
return phase
return _PHASE("I", target, control, param)
3 changes: 2 additions & 1 deletion horqrux/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,8 @@ def uniform_state(
return state.reshape([2] * n_qubits)


def is_controlled(qs: ControlQubits) -> bool:
def is_controlled(qs: Tuple[int | None | Tuple[int, ...], ...]) -> bool:
# FIXME despaghettify
if qs is None:
return False
if isinstance(qs, tuple):
Expand Down
77 changes: 62 additions & 15 deletions tests/test_gates.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,72 @@
from __future__ import annotations

from typing import Callable

import jax.numpy as jnp
import numpy as np
import pytest

from horqrux.apply import apply_gate
from horqrux.parametric import RX, RY, RZ
from horqrux.primitive import NOT, SWAP, H, X, Y, Z
from horqrux.apply import apply_gate, apply_operator
from horqrux.parametric import PHASE, RX, RY, RZ
from horqrux.primitive import NOT, SWAP, H, I, S, T, X, Y, Z
from horqrux.utils import equivalent_state, prepare_state

MAX_QUBITS = 7
PARAMETRIC_GATES = (RX, RY, RZ, PHASE)
PRIMITIVE_GATES = (NOT, H, X, Y, Z, I, S, T)


@pytest.mark.parametrize("gate_fn", PRIMITIVE_GATES)
def test_primitive(gate_fn: Callable) -> None:
target = np.random.randint(0, MAX_QUBITS)
gate = gate_fn(target)
orig_state = prepare_state(MAX_QUBITS)
state = apply_gate(orig_state, gate)
assert jnp.allclose(
apply_operator(state, gate.dagger(), gate.target[0], gate.control[0]), orig_state
)


@pytest.mark.parametrize("gate_fn", PRIMITIVE_GATES)
def test_controlled_primitive(gate_fn: Callable) -> None:
target = np.random.randint(0, MAX_QUBITS)
control = np.random.randint(0, MAX_QUBITS)
while control == target:
control = np.random.randint(1, MAX_QUBITS)
gate = gate_fn(target, control)
orig_state = prepare_state(MAX_QUBITS)
state = apply_gate(orig_state, gate)
assert jnp.allclose(
apply_operator(state, gate.dagger(), gate.target[0], gate.control[0]), orig_state
)


@pytest.mark.parametrize("gate_fn", PARAMETRIC_GATES)
def test_parametric(gate_fn: Callable) -> None:
target = np.random.randint(0, MAX_QUBITS)
gate = gate_fn("theta", target)
values = {"theta": np.random.uniform(0.1, 2 * np.pi)}
orig_state = prepare_state(MAX_QUBITS)
state = apply_gate(orig_state, gate, values)
assert jnp.allclose(
apply_operator(state, gate.dagger(values), gate.target[0], gate.control[0]), orig_state
)


@pytest.mark.parametrize("gate_fn", PARAMETRIC_GATES)
def test_controlled_parametric(gate_fn: Callable) -> None:
target = np.random.randint(0, MAX_QUBITS)
control = np.random.randint(0, MAX_QUBITS)
while control == target:
control = np.random.randint(1, MAX_QUBITS)
gate = gate_fn("theta", target, control)
values = {"theta": np.random.uniform(0.1, 2 * np.pi)}
orig_state = prepare_state(MAX_QUBITS)
state = apply_gate(orig_state, gate, values)
assert jnp.allclose(
apply_operator(state, gate.dagger(values), gate.target[0], gate.control[0]), orig_state
)


@pytest.mark.parametrize(
["init_state", "final_state"],
Expand Down Expand Up @@ -42,15 +101,3 @@ def test_swap_gate(x):
state = prepare_state(len(init_state), init_state)
out_state = apply_gate(state, op)
assert equivalent_state(out_state, expected_state), "Output states not similar."


def test_single_gates():
state = prepare_state(7)
state = apply_gate(state, X(0, 1))
state = apply_gate(state, Y(1, 2))
state = apply_gate(state, Z(2, 4))
state = apply_gate(state, H(3, 5))
state = apply_gate(state, RX(1 / 4 * jnp.pi, 4, 1))
state = apply_gate(state, RY(1 / 3 * jnp.pi, 5, 2))
state = apply_gate(state, RZ(1 / 2 * jnp.pi, 6, 0))
# FIXME

0 comments on commit abd2f55

Please sign in to comment.