Skip to content

Commit

Permalink
docs and refac
Browse files Browse the repository at this point in the history
  • Loading branch information
dominikandreasseitz committed Apr 15, 2024
1 parent 816d007 commit 1552f37
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 5 deletions.
12 changes: 8 additions & 4 deletions horqrux/adjoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from typing import Tuple

from jax import Array, custom_vjp
from jax.numpy import real as jnpreal

from horqrux.apply import apply_gate
from horqrux.parametric import Parametric
Expand All @@ -14,15 +13,20 @@
def expectation(
state: Array, gates: list[Primitive], observable: list[Primitive], values: dict[str, float]
) -> Array:
"""
Run 'state' through a sequence of 'gates' given parameters 'values' and compute the expectation given an observable.
"""
out_state = apply_gate(state, gates, values, OperationType.UNITARY)
projected_state = apply_gate(out_state, observable, values, OperationType.UNITARY)
return jnpreal(inner(out_state, projected_state))
return inner(out_state, projected_state).real


@custom_vjp
def adjoint_expectation(
state: Array, gates: list[Primitive], observable: list[Primitive], values: dict[str, float]
) -> Array:
"""Custom vector-jacobian product to compute gradients
in O(P) time using O(1) state vectors via Algorithm 1 in https://arxiv.org/abs/2009.02823."""
return expectation(state, gates, observable, values)


Expand All @@ -31,7 +35,7 @@ def adjoint_expectation_fwd(
) -> Tuple[Array, Tuple[Array, Array, list[Primitive], dict[str, float]]]:
out_state = apply_gate(state, gates, values, OperationType.UNITARY)
projected_state = apply_gate(out_state, observable, values, OperationType.UNITARY)
return jnpreal(inner(out_state, projected_state)), (out_state, projected_state, gates, values)
return inner(out_state, projected_state).real, (out_state, projected_state, gates, values)


def adjoint_expectation_bwd(
Expand All @@ -43,7 +47,7 @@ def adjoint_expectation_bwd(
out_state = apply_gate(out_state, gate, values, OperationType.DAGGER)
if isinstance(gate, Parametric):
mu = apply_gate(out_state, gate, values, OperationType.JACOBIAN)
grads[gate.param] = tangent * 2 * jnpreal(inner(mu, projected_state))
grads[gate.param] = tangent * 2 * inner(mu, projected_state).real
projected_state = apply_gate(projected_state, gate, values, OperationType.DAGGER)
return (None, None, None, grads)

Expand Down
11 changes: 11 additions & 0 deletions horqrux/apply.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,17 @@ def apply_operator(
def merge_operators(
operators: tuple[Array, ...], targets: tuple[int, ...], controls: tuple[int, ...]
) -> tuple[tuple[Array, ...], tuple[int, ...], tuple[int, ...]]:
"""
If possible, merge several gates acting on the same qubits into a single tensordot operation.
Arguments:
operators: The arrays representing the unitaries to be merged.
targets: The corresponding target qubits.
controls: The corresponding control qubits.
Returns:
A tuple of merged operators, targets and controls.
"""
if len(operators) < 2:
return operators, targets, controls
operators, targets, controls = operators[::-1], targets[::-1], controls[::-1]
Expand Down
2 changes: 1 addition & 1 deletion horqrux/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ def inner(state: Array, projection: Array) -> Array:


def overlap(state: Array, projection: Array) -> Array:
return jnp.real(jnp.power(inner(state, projection), 2))
return jnp.power(inner(state, projection), 2).real


def uniform_state(
Expand Down

0 comments on commit 1552f37

Please sign in to comment.