Skip to content

Commit

Permalink
function for TM
Browse files Browse the repository at this point in the history
  • Loading branch information
dominikandreasseitz committed Apr 15, 2024
1 parent 1552f37 commit be6e3dd
Show file tree
Hide file tree
Showing 4 changed files with 38 additions and 14 deletions.
5 changes: 2 additions & 3 deletions docs/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -272,6 +272,7 @@ class TotalMagnetization:
def __call__(self, state: Array, values: dict) -> Array:
return reduce(add, [apply_gate(state, pauli, values) for pauli in self.paulis])

total_magnetization = lambda n_qubits, state: reduce(add, [apply_gate(state, [Z(i) for i in range(n_qubits)], {})])

@dataclass
class Circuit:
Expand All @@ -283,16 +284,14 @@ class Circuit:
RX("y", i) for i in range(self.n_qubits // 2, self.n_qubits)
]
self.ansatz, self.param_names = ansatz_w_params(self.n_qubits, self.n_layers)
self.observable = TotalMagnetization(self.n_qubits)

def __call__(self, param_vals: Array, x: Array, y: Array) -> Array:
state = zero_state(self.n_qubits)
param_dict = {name: val for name, val in zip(self.param_names, param_vals)}
out_state = apply_gate(
state, self.feature_map + self.ansatz, {**param_dict, **{"x": x, "y": y}}
)
projected_state = self.observable(state, param_dict)
return jnp.real(inner(out_state, projected_state))
return inner(out_state, total_magnetization(self.n_qubits, out_state)).real

@property
def n_vparams(self) -> int:
Expand Down
19 changes: 19 additions & 0 deletions horqrux/apply.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,24 @@ def apply_operator(
return jnp.moveaxis(a=state, source=new_state_dims, destination=state_dims)


def group_by_index(gates: Iterable[Primitive]) -> Iterable[Primitive]:
sorted_gates = []
gate_batch = []
for gate in gates:
if not is_controlled(gate.control):
gate_batch.append(gate)
else:
if len(gate_batch) > 0:
gate_batch.sort(key=lambda g: g.target)
sorted_gates += gate_batch
gate_batch = []
sorted_gates.append(gate)
if len(gate_batch) > 0:
gate_batch.sort(key=lambda g: g.target)
sorted_gates += gate_batch
return sorted_gates


def merge_operators(
operators: tuple[Array, ...], targets: tuple[int, ...], controls: tuple[int, ...]
) -> tuple[tuple[Array, ...], tuple[int, ...], tuple[int, ...]]:
Expand Down Expand Up @@ -106,6 +124,7 @@ def apply_gate(
operator_fn = getattr(gate, op_type)
operator, target, control = (operator_fn(values),), gate.target, gate.control
else:
gate = group_by_index(gate)
operator = tuple(getattr(g, op_type)(values) for g in gate)
target = reduce(add, [g.target for g in gate])
control = reduce(add, [g.control for g in gate])
Expand Down
8 changes: 4 additions & 4 deletions horqrux/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,11 +134,11 @@ def uniform_state(
return state.reshape([2] * n_qubits)


def is_controlled(qs: Tuple[int | None, ...] | int | None) -> bool:
if isinstance(qs, int):
def is_controlled(qubit_support: Tuple[int | None, ...] | int | None) -> bool:
if isinstance(qubit_support, int):
return True
elif isinstance(qs, tuple):
return any(is_controlled(q) for q in qs)
elif isinstance(qubit_support, tuple):
return any(is_controlled(q) for q in qubit_support)
return False


Expand Down
20 changes: 13 additions & 7 deletions tests/test_gates.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import pytest
from jax import Array

from horqrux.apply import apply_gate, apply_operator
from horqrux.apply import apply_gate, apply_operator, group_by_index, merge_operators
from horqrux.parametric import PHASE, RX, RY, RZ
from horqrux.primitive import NOT, SWAP, H, I, S, T, X, Y, Z
from horqrux.utils import equivalent_state, product_state, random_state
Expand Down Expand Up @@ -105,10 +105,16 @@ def test_swap_gate(inputs: tuple[str, str, Array]) -> None:


def test_merge_gates() -> None:
gates = [RX("theta", 0), RY("lambda", 0)]
state = product_state("00")
out_state = apply_gate(
state,
gates,
{"theta": np.random.uniform(0.1, 2 * np.pi), "lambda": np.random.uniform(0.1, 2 * np.pi)},
gates = [RX("a", 0), RZ("b", 1), RY("c", 0)]
gates = group_by_index(gates)
values = {
"a": np.random.uniform(0.1, 2 * np.pi),
"b": np.random.uniform(0.1, 2 * np.pi),
"c": np.random.uniform(0.1, 2 * np.pi),
}
op, trgt, ctrl = merge_operators(
tuple(g.unitary(values) for g in gates),
tuple(g.target for g in gates),
tuple(g.control for g in gates),
)
assert len(op) == 2

0 comments on commit be6e3dd

Please sign in to comment.