Skip to content

Commit

Permalink
separate apply_operator with a density matrix version
Browse files Browse the repository at this point in the history
  • Loading branch information
Charles MOUSSA committed Dec 10, 2024
1 parent 819833f commit bc479ba
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 22 deletions.
53 changes: 42 additions & 11 deletions horqrux/apply.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@ def apply_operator(
operator: Array,
target: Tuple[int, ...],
control: Tuple[int | None, ...],
is_state_densitymat: bool = False,
) -> State:
"""Applies an operator, i.e. a single array of shape [2, 2, ...], on a given state
of shape [2 for _ in range(n_qubits)] for a given set of target and control qubits.
Expand Down Expand Up @@ -56,23 +55,51 @@ def apply_operator(
operator = _controlled(operator, len(control))
state_dims = (*control, *target) # type: ignore[arg-type]
n_qubits_op = int(np.log2(operator.shape[1]))
operator_reshaped = operator.reshape(tuple(2 for _ in np.arange(2 * n_qubits_op)))
op_out_dims = tuple(np.arange(operator_reshaped.ndim // 2, operator_reshaped.ndim, dtype=int))
op_in_dims = tuple(np.arange(0, operator_reshaped.ndim // 2, dtype=int))
operator = operator.reshape(tuple(2 for _ in np.arange(2 * n_qubits_op)))
op_out_dims = tuple(np.arange(operator.ndim // 2, operator.ndim, dtype=int))
# Apply operator
new_state_dims = tuple(i for i in range(len(state_dims)))
if not is_state_densitymat:
# only return O ρ with correctly swaped axis for tensordot
state = jnp.tensordot(a=operator_reshaped, b=state, axes=(op_out_dims, state_dims))
return jnp.moveaxis(a=state, source=new_state_dims, destination=state_dims)
state = jnp.tensordot(a=operator, b=state, axes=(op_out_dims, state_dims))
return jnp.moveaxis(a=state, source=new_state_dims, destination=state_dims)


def apply_operator_dm(
state: State,
operator: Array,
target: Tuple[int, ...],
control: Tuple[int | None, ...],
) -> State:
"""Applies an operator, i.e. a single array of shape [2, 2, ...], on a given density matrix
of shape [2 for _ in range(2 * n_qubits)] for a given set of target and control qubits.
In case of a controlled operation, the 'operator' array will be embedded into a controlled array.
Arguments:
state: Density matrix to operate on.
operator: Array to contract over 'state'.
target: Tuple of target qubits on which to apply the 'operator' to.
control: Tuple of control qubits.
is_state_densitymat: Whether the state is provided as a density matrix.
Returns:
Density matrix after applying 'operator'.
"""
state_dims: Tuple[int, ...] = target
if is_controlled(control):
operator = _controlled(operator, len(control))
state_dims = (*control, *target) # type: ignore[arg-type]
n_qubits_op = int(np.log2(operator.shape[1]))
operator = operator.reshape(tuple(2 for _ in np.arange(2 * n_qubits_op)))
op_out_dims = tuple(np.arange(operator.ndim // 2, operator.ndim, dtype=int))
op_in_dims = tuple(np.arange(0, operator.ndim // 2, dtype=int))
new_state_dims = tuple(i for i in range(len(state_dims)))

# Apply operator to density matrix: ρ' = O ρ O†
support_perm = state_dims + tuple(set(tuple(range(state.ndim // 2))) - set(state_dims))
state = permute_basis(state, support_perm, False)
state = jnp.tensordot(a=operator_reshaped, b=state, axes=(op_out_dims, new_state_dims))
state = jnp.tensordot(a=operator, b=state, axes=(op_out_dims, new_state_dims))

state = _dagger(state)
state = jnp.tensordot(a=operator_reshaped, b=state, axes=(op_out_dims, op_in_dims))
state = jnp.tensordot(a=operator, b=state, axes=(op_out_dims, op_in_dims))
state = _dagger(state)

state = permute_basis(state, support_perm, True)
Expand Down Expand Up @@ -109,7 +136,11 @@ def apply_operator_with_noise(
noise: NoiseProtocol,
is_state_densitymat: bool = False,
) -> State:
state_gate = apply_operator(state, operator, target, control, is_state_densitymat)
state_gate = (
apply_operator(state, operator, target, control)
if not is_state_densitymat
else apply_operator_dm(state, operator, target, control)
)
if len(noise) == 0:
return state_gate
else:
Expand Down
14 changes: 5 additions & 9 deletions tests/test_gates.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import pytest
from jax import Array

from horqrux.apply import apply_gate, apply_operator
from horqrux.apply import apply_gate, apply_operator, apply_operator_dm
from horqrux.parametric import PHASE, RX, RY, RZ
from horqrux.primitive import NOT, SWAP, H, I, S, T, X, Y, Z
from horqrux.utils import density_mat, equivalent_state, product_state, random_state
Expand All @@ -29,12 +29,11 @@ def test_primitive(gate_fn: Callable) -> None:
)

# test density matrix is similar to pure state
dm = apply_operator(
dm = apply_operator_dm(
density_mat(orig_state),
gate.unitary(),
gate.target[0],
gate.control[0],
is_state_densitymat=True,
)
assert jnp.allclose(dm, density_mat(state))

Expand All @@ -53,12 +52,11 @@ def test_controlled_primitive(gate_fn: Callable) -> None:
)

# test density matrix is similar to pure state
dm = apply_operator(
dm = apply_operator_dm(
density_mat(orig_state),
gate.unitary(),
gate.target[0],
gate.control[0],
is_state_densitymat=True,
)
assert jnp.allclose(dm, density_mat(state))

Expand All @@ -75,12 +73,11 @@ def test_parametric(gate_fn: Callable) -> None:
)

# test density matrix is similar to pure state
dm = apply_operator(
dm = apply_operator_dm(
density_mat(orig_state),
gate.unitary(values),
gate.target[0],
gate.control[0],
is_state_densitymat=True,
)
assert jnp.allclose(dm, density_mat(state))

Expand All @@ -100,12 +97,11 @@ def test_controlled_parametric(gate_fn: Callable) -> None:
)

# test density matrix is similar to pure state
dm = apply_operator(
dm = apply_operator_dm(
density_mat(orig_state),
gate.unitary(values),
gate.target[0],
gate.control[0],
is_state_densitymat=True,
)
assert jnp.allclose(dm, density_mat(state))

Expand Down
6 changes: 4 additions & 2 deletions tests/test_shots.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,12 @@ def test_shots() -> None:
def exact(x):
values = {"theta": x}
return expectation(state, ops, observables, values, diff_mode="ad")

def exact_dm(x):
values = {"theta": x}
return expectation(density_mat(state), ops, observables, values, diff_mode="ad", is_state_densitymat=True)
return expectation(
density_mat(state), ops, observables, values, diff_mode="ad", is_state_densitymat=True
)

def shots(x):
values = {"theta": x}
Expand Down

0 comments on commit bc479ba

Please sign in to comment.