diff --git a/horqrux/adjoint.py b/horqrux/adjoint.py index 1ca6913..221362f 100644 --- a/horqrux/adjoint.py +++ b/horqrux/adjoint.py @@ -3,7 +3,6 @@ from typing import Tuple from jax import Array, custom_vjp -from jax.numpy import real as jnpreal from horqrux.apply import apply_gate from horqrux.parametric import Parametric @@ -14,15 +13,20 @@ def expectation( state: Array, gates: list[Primitive], observable: list[Primitive], values: dict[str, float] ) -> Array: + """ + Run 'state' through a sequence of 'gates' given parameters 'values' and compute the expectation given an observable. + """ out_state = apply_gate(state, gates, values, OperationType.UNITARY) projected_state = apply_gate(out_state, observable, values, OperationType.UNITARY) - return jnpreal(inner(out_state, projected_state)) + return inner(out_state, projected_state).real @custom_vjp def adjoint_expectation( state: Array, gates: list[Primitive], observable: list[Primitive], values: dict[str, float] ) -> Array: + """Custom vector-jacobian product to compute gradients + in O(P) time using O(1) state vectors via Algorithm 1 in https://arxiv.org/abs/2009.02823.""" return expectation(state, gates, observable, values) @@ -31,7 +35,7 @@ def adjoint_expectation_fwd( ) -> Tuple[Array, Tuple[Array, Array, list[Primitive], dict[str, float]]]: out_state = apply_gate(state, gates, values, OperationType.UNITARY) projected_state = apply_gate(out_state, observable, values, OperationType.UNITARY) - return jnpreal(inner(out_state, projected_state)), (out_state, projected_state, gates, values) + return inner(out_state, projected_state).real, (out_state, projected_state, gates, values) def adjoint_expectation_bwd( @@ -43,7 +47,7 @@ def adjoint_expectation_bwd( out_state = apply_gate(out_state, gate, values, OperationType.DAGGER) if isinstance(gate, Parametric): mu = apply_gate(out_state, gate, values, OperationType.JACOBIAN) - grads[gate.param] = tangent * 2 * jnpreal(inner(mu, projected_state)) + grads[gate.param] = tangent * 2 * inner(mu, projected_state).real projected_state = apply_gate(projected_state, gate, values, OperationType.DAGGER) return (None, None, None, grads) diff --git a/horqrux/apply.py b/horqrux/apply.py index 866d3b4..796bd59 100644 --- a/horqrux/apply.py +++ b/horqrux/apply.py @@ -54,6 +54,17 @@ def apply_operator( def merge_operators( operators: tuple[Array, ...], targets: tuple[int, ...], controls: tuple[int, ...] ) -> tuple[tuple[Array, ...], tuple[int, ...], tuple[int, ...]]: + """ + If possible, merge several gates acting on the same qubits into a single tensordot operation. + + Arguments: + operators: The arrays representing the unitaries to be merged. + targets: The corresponding target qubits. + controls: The corresponding control qubits. + Returns: + A tuple of merged operators, targets and controls. + + """ if len(operators) < 2: return operators, targets, controls operators, targets, controls = operators[::-1], targets[::-1], controls[::-1] diff --git a/horqrux/utils.py b/horqrux/utils.py index 659d32c..cf8ce3e 100644 --- a/horqrux/utils.py +++ b/horqrux/utils.py @@ -123,7 +123,7 @@ def inner(state: Array, projection: Array) -> Array: def overlap(state: Array, projection: Array) -> Array: - return jnp.real(jnp.power(inner(state, projection), 2)) + return jnp.power(inner(state, projection), 2).real def uniform_state(