diff --git a/docs/index.md b/docs/index.md index cb5fc71..9a24fdf 100644 --- a/docs/index.md +++ b/docs/index.md @@ -272,6 +272,7 @@ class TotalMagnetization: def __call__(self, state: Array, values: dict) -> Array: return reduce(add, [apply_gate(state, pauli, values) for pauli in self.paulis]) +total_magnetization = lambda n_qubits, state: reduce(add, [apply_gate(state, [Z(i) for i in range(n_qubits)], {})]) @dataclass class Circuit: @@ -283,7 +284,6 @@ class Circuit: RX("y", i) for i in range(self.n_qubits // 2, self.n_qubits) ] self.ansatz, self.param_names = ansatz_w_params(self.n_qubits, self.n_layers) - self.observable = TotalMagnetization(self.n_qubits) def __call__(self, param_vals: Array, x: Array, y: Array) -> Array: state = zero_state(self.n_qubits) @@ -291,8 +291,7 @@ class Circuit: out_state = apply_gate( state, self.feature_map + self.ansatz, {**param_dict, **{"x": x, "y": y}} ) - projected_state = self.observable(state, param_dict) - return jnp.real(inner(out_state, projected_state)) + return inner(out_state, total_magnetization(self.n_qubits, out_state)).real @property def n_vparams(self) -> int: diff --git a/horqrux/apply.py b/horqrux/apply.py index 796bd59..43f8494 100644 --- a/horqrux/apply.py +++ b/horqrux/apply.py @@ -51,6 +51,24 @@ def apply_operator( return jnp.moveaxis(a=state, source=new_state_dims, destination=state_dims) +def group_by_index(gates: Iterable[Primitive]) -> Iterable[Primitive]: + sorted_gates = [] + gate_batch = [] + for gate in gates: + if not is_controlled(gate.control): + gate_batch.append(gate) + else: + if len(gate_batch) > 0: + gate_batch.sort(key=lambda g: g.target) + sorted_gates += gate_batch + gate_batch = [] + sorted_gates.append(gate) + if len(gate_batch) > 0: + gate_batch.sort(key=lambda g: g.target) + sorted_gates += gate_batch + return sorted_gates + + def merge_operators( operators: tuple[Array, ...], targets: tuple[int, ...], controls: tuple[int, ...] ) -> tuple[tuple[Array, ...], tuple[int, ...], tuple[int, ...]]: @@ -106,6 +124,7 @@ def apply_gate( operator_fn = getattr(gate, op_type) operator, target, control = (operator_fn(values),), gate.target, gate.control else: + gate = group_by_index(gate) operator = tuple(getattr(g, op_type)(values) for g in gate) target = reduce(add, [g.target for g in gate]) control = reduce(add, [g.control for g in gate]) diff --git a/horqrux/utils.py b/horqrux/utils.py index cf8ce3e..4b1074e 100644 --- a/horqrux/utils.py +++ b/horqrux/utils.py @@ -134,11 +134,11 @@ def uniform_state( return state.reshape([2] * n_qubits) -def is_controlled(qs: Tuple[int | None, ...] | int | None) -> bool: - if isinstance(qs, int): +def is_controlled(qubit_support: Tuple[int | None, ...] | int | None) -> bool: + if isinstance(qubit_support, int): return True - elif isinstance(qs, tuple): - return any(is_controlled(q) for q in qs) + elif isinstance(qubit_support, tuple): + return any(is_controlled(q) for q in qubit_support) return False diff --git a/tests/test_gates.py b/tests/test_gates.py index 63acb97..2802759 100644 --- a/tests/test_gates.py +++ b/tests/test_gates.py @@ -7,7 +7,7 @@ import pytest from jax import Array -from horqrux.apply import apply_gate, apply_operator +from horqrux.apply import apply_gate, apply_operator, group_by_index, merge_operators from horqrux.parametric import PHASE, RX, RY, RZ from horqrux.primitive import NOT, SWAP, H, I, S, T, X, Y, Z from horqrux.utils import equivalent_state, product_state, random_state @@ -105,10 +105,16 @@ def test_swap_gate(inputs: tuple[str, str, Array]) -> None: def test_merge_gates() -> None: - gates = [RX("theta", 0), RY("lambda", 0)] - state = product_state("00") - out_state = apply_gate( - state, - gates, - {"theta": np.random.uniform(0.1, 2 * np.pi), "lambda": np.random.uniform(0.1, 2 * np.pi)}, + gates = [RX("a", 0), RZ("b", 1), RY("c", 0)] + gates = group_by_index(gates) + values = { + "a": np.random.uniform(0.1, 2 * np.pi), + "b": np.random.uniform(0.1, 2 * np.pi), + "c": np.random.uniform(0.1, 2 * np.pi), + } + op, trgt, ctrl = merge_operators( + tuple(g.unitary(values) for g in gates), + tuple(g.target for g in gates), + tuple(g.control for g in gates), ) + assert len(op) == 2