Skip to content

Commit

Permalink
fix single dispatch methods
Browse files Browse the repository at this point in the history
  • Loading branch information
Charles MOUSSA committed Dec 18, 2024
1 parent 32ee665 commit 8c6c881
Show file tree
Hide file tree
Showing 5 changed files with 177 additions and 117 deletions.
114 changes: 47 additions & 67 deletions horqrux/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,15 @@
from horqrux.apply import apply_gate
from horqrux.primitive import GateSequence, Primitive
from horqrux.shots import finite_shots_fwd, observable_to_matrix
from horqrux.utils import DensityMatrix, DiffMode, ForwardMode, OperationType, inner
from horqrux.utils import (
DensityMatrix,
DiffMode,
ForwardMode,
OperationType,
get_probas,
inner,
sample_from_probs,
)


def run(
Expand All @@ -24,91 +32,61 @@ def run(
return apply_gate(state, circuit, values)


def sample_from_probs(probs: Array, n_qubits: int, n_shots: int) -> Counter:
key = jax.random.PRNGKey(0)

# JAX handles pseudo random number generation by tracking an explicit state via a random key
# For more details, see https://jax.readthedocs.io/en/latest/random-numbers.html
samples = jax.vmap(
lambda subkey: jax.random.choice(key=subkey, a=jnp.arange(0, 2**n_qubits), p=probs)
)(jax.random.split(key, n_shots))

return Counter(
{
format(k, "0{}b".format(n_qubits)): count.item()
for k, count in enumerate(jnp.bincount(samples))
if count > 0
}
)


@singledispatch
def sample(
state: Array,
gates: GateSequence,
values: dict[str, float] = dict(),
n_shots: int = 1000,
) -> Counter:
raise NotImplementedError("sample method is not implemented")


@sample.register
def _(
state: Array,
state: Array | DensityMatrix,
gates: GateSequence,
values: dict[str, float] = dict(),
n_shots: int = 1000,
) -> Counter:
if n_shots < 1:
raise ValueError("You can only call sample with n_shots>0.")

output_circuit = apply_gate(state, gates, values)
n_qubits = len(state.shape)
probs = jnp.abs(jnp.float_power(output_circuit, 2.0)).ravel()
return sample_from_probs(probs, n_qubits, n_shots)

if isinstance(output_circuit, DensityMatrix):
n_qubits = len(output_circuit.array.shape) // 2
d = 2**n_qubits
output_circuit.array = output_circuit.array.reshape((d, d))
else:
n_qubits = len(output_circuit.array.shape)

@sample.register
def _(
state: DensityMatrix,
gates: GateSequence,
values: dict[str, float] = dict(),
n_shots: int = 1000,
) -> Counter:
if n_shots < 1:
raise ValueError("You can only call sample with n_shots>0.")

output_circuit = apply_gate(state, gates, values)
n_qubits = len(state.array.shape) // 2
d = 2**n_qubits
probs = jnp.diagonal(output_circuit.array.reshape((d, d))).real
probs = get_probas(output_circuit)
return sample_from_probs(probs, n_qubits, n_shots)


@singledispatch
def __ad_expectation_single_observable(
state: Array | DensityMatrix,
gates: GateSequence,
output_state: Array,
observable: Primitive,
values: dict[str, float],
) -> Array:
"""
Run 'state' through a sequence of 'gates' given parameters 'values'
and compute the expectation given an observable.
"""
out_state = apply_gate(state, gates, values, OperationType.UNITARY)
raise NotImplementedError("__ad_expectation_single_observable is not implemented")

if not isinstance(out_state, DensityMatrix):
projected_state = apply_gate(
out_state,
observable,
values,
OperationType.UNITARY,
)
return inner(out_state, projected_state).real
n_qubits = len(out_state.array.shape) // 2

@__ad_expectation_single_observable.register
def _(
state: Array,
observable: Primitive,
values: dict[str, float],
) -> Array:
projected_state = apply_gate(
state,
observable,
values,
OperationType.UNITARY,
)
return inner(state, projected_state).real


@__ad_expectation_single_observable.register
def _(
state: DensityMatrix,
observable: Primitive,
values: dict[str, float],
) -> Array:
n_qubits = len(state.array.shape) // 2
mat_obs = observable_to_matrix(observable, n_qubits)
d = 2**n_qubits
prod = jnp.matmul(mat_obs, out_state.array.reshape((d, d)))
prod = jnp.matmul(mat_obs, state.array.reshape((d, d)))
return jnp.trace(prod, axis1=-2, axis2=-1).real


Expand All @@ -123,7 +101,9 @@ def ad_expectation(
and compute the expectation given an observable.
"""
outputs = [
__ad_expectation_single_observable(state, gates, observable, values)
__ad_expectation_single_observable(
apply_gate(state, gates, values, OperationType.UNITARY), observable, values
)
for observable in observables
]
return jnp.stack(outputs)
Expand Down
114 changes: 72 additions & 42 deletions horqrux/apply.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,20 @@ def apply_operator(
target: tuple[int, ...],
control: tuple[int | None, ...],
) -> Array:
"""Apply an operator on a state or density matrix.
Args:
state (Array): Array to operate on.
operator (Array): Array to contract over 'state'.
target (tuple[int, ...]): tuple of target qubits on which to apply the 'operator' to.
control (tuple[int | None, ...]): tuple of control qubits.
Raises:
NotImplementedError: If not implemented for given types.
Returns:
Array: The output of the application of the operator.
"""
raise NotImplementedError("apply_operator is not implemented")


Expand All @@ -51,12 +65,11 @@ def _(
dimension 'i' of 'state'. To restore the former order of dimensions, the affected dimensions
are moved to their original positions and the state is returned.
Arguments:
state: Array to operate on.
operator: Array to contract over 'state'.
target: tuple of target qubits on which to apply the 'operator' to.
control: tuple of control qubits.
is_density: Whether the state is provided as a density matrix.
Args:
state (Array): Array to operate on.
operator (Array): Array to contract over 'state'.
target (tuple[int, ...]): tuple of target qubits on which to apply the 'operator' to.
control (tuple[int | None, ...]): tuple of control qubits.
Returns:
Array after applying 'operator'.
Expand Down Expand Up @@ -85,12 +98,11 @@ def _(
of shape [2 for _ in range(2 * n_qubits)] for a given set of target and control qubits.
In case of a controlled operation, the 'operator' array will be embedded into a controlled array.
Arguments:
state: Density matrix to operate on.
operator: Array to contract over 'state'.
target: tuple of target qubits on which to apply the 'operator' to.
control: tuple of control qubits.
is_density: Whether the state is provided as a density matrix.
Args:
state (Array): Array to operate on.
operator (Array): Array to contract over 'state'.
target (tuple[int, ...]): tuple of target qubits on which to apply the 'operator' to.
control (tuple[int | None, ...]): tuple of control qubits.
Returns:
Density matrix after applying 'operator'.
Expand All @@ -106,8 +118,10 @@ def _(
new_state_dims = tuple(range(len(state_dims)))

# Apply operator to density matrix: ρ' = O ρ O†
support_perm = state_dims + tuple(set(tuple(range(state.array.ndim // 2))) - set(state_dims))
out_state = permute_basis(state.array, support_perm, False)
out_state = state.array
support_perm = state_dims + tuple(set(tuple(range(out_state.ndim // 2))) - set(state_dims))

out_state = permute_basis(out_state, support_perm, False)
out_state = jnp.tensordot(a=operator, b=out_state, axes=(op_out_dims, new_state_dims))

out_state = _dagger(out_state)
Expand All @@ -120,7 +134,7 @@ def _(

def apply_kraus_operator(
kraus: Array,
state: Array,
array: Array,
target: tuple[int, ...],
) -> Array:
"""Apply K \\rho K^\\dagger.
Expand All @@ -138,23 +152,53 @@ def apply_kraus_operator(
kraus = kraus.reshape(tuple(2 for _ in np.arange(n_qubits)))
op_dims = tuple(np.arange(kraus.ndim // 2, kraus.ndim, dtype=int))

state = jnp.tensordot(a=kraus, b=state, axes=(op_dims, state_dims))
array = jnp.tensordot(a=kraus, b=array, axes=(op_dims, state_dims))
new_state_dims = tuple(i for i in range(len(state_dims)))
state = jnp.moveaxis(a=state, source=new_state_dims, destination=state_dims)
array = jnp.moveaxis(a=array, source=new_state_dims, destination=state_dims)

array = jnp.tensordot(a=kraus, b=_dagger(array), axes=(op_dims, state_dims))
array = _dagger(array)

return array


def apply_kraus_sum(
kraus_ops: NoiseProtocol,
array: Array,
target: tuple[int, ...],
) -> DensityMatrix:
"""Apply the following evolution as a sum of Kraus operators:
.. math::
S(\\rho) = \\sum_i K_i \\rho K_i^\\dagger
state = jnp.tensordot(a=kraus, b=_dagger(state), axes=(op_dims, state_dims))
state = _dagger(state)
Args:
noise (NoiseProtocol): Noise containing the K_i
state (Array): Input array
target (tuple[int, ...]): Qubits the operator is defined on.
return state
Returns:
DensityMatrix: Output density matrix.
"""

apply_one_kraus = jax.vmap(
partial(
apply_kraus_operator,
array=array,
target=target,
)
)
kraus_evol = apply_one_kraus(kraus_ops)
output_dm = jnp.sum(kraus_evol, 0)
return DensityMatrix(output_dm)


def apply_operator_with_noise(
state: Array | DensityMatrix,
state: DensityMatrix,
operator: Array,
target: tuple[int, ...],
control: tuple[int | None, ...],
noise: NoiseProtocol,
) -> Array:
) -> Array | DensityMatrix:
"""Evolves the input state and applies a noisy quantum channel
on the evolved state :math:`\rho`.
Expand All @@ -177,15 +221,7 @@ def apply_operator_with_noise(
return state_gate
else:
kraus_ops = jnp.stack(tuple(reduce(add, tuple(n.kraus for n in noise))))
apply_one_kraus = jax.vmap(
partial(
apply_kraus_operator,
state=state_gate.array if isinstance(state_gate, DensityMatrix) else state_gate,
target=target,
)
)
kraus_evol = apply_one_kraus(kraus_ops)
output_dm = jnp.sum(kraus_evol, 0)
output_dm = apply_kraus_sum(kraus_ops, state_gate.array, target)
return output_dm


Expand Down Expand Up @@ -272,7 +308,6 @@ def _(
op_type: The type of operation to perform: Unitary, Dagger or Jacobian.
group_gates: Group gates together which are acting on the same qubit.
merge_ops: Attempt to merge operators acting on the same qubit.
is_density: If True, state is provided as a density matrix.
Returns:
Array or density matrix after applying 'gate'.
Expand Down Expand Up @@ -301,16 +336,14 @@ def _(
output_state = reduce(
lambda state, gate: apply_operator_with_noise(state, *gate),
zip(operator, target, control, noise),
state.array,
state,
)
output_state = DensityMatrix(output_state)
else:
output_state = reduce(
lambda state, gate: apply_operator_with_noise(state, *gate),
zip(operator, target, control, noise),
lambda state, gate: apply_operator(state, *gate),
zip(operator, target, control),
state,
)

return output_state


Expand All @@ -331,7 +364,6 @@ def _(
op_type: The type of operation to perform: Unitary, Dagger or Jacobian.
group_gates: Group gates together which are acting on the same qubit.
merge_ops: Attempt to merge operators acting on the same qubit.
is_density: If True, state is provided as a density matrix.
Returns:
Array or density matrix after applying 'gate'.
Expand All @@ -351,11 +383,9 @@ def _(
if merge_ops:
operator, target, control = merge_operators(operator, target, control)
noise = [g.noise for g in gate]

output_state = reduce(
lambda state, gate: apply_operator_with_noise(state, *gate),
zip(operator, target, control, noise),
state.array,
state,
)

return DensityMatrix(output_state)
return output_state
Loading

0 comments on commit 8c6c881

Please sign in to comment.