Skip to content

Commit

Permalink
adding values to observable to matrix
Browse files Browse the repository at this point in the history
  • Loading branch information
Charles MOUSSA committed Dec 18, 2024
1 parent 8c6c881 commit 3a403a9
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 9 deletions.
2 changes: 1 addition & 1 deletion horqrux/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def _(
values: dict[str, float],
) -> Array:
n_qubits = len(state.array.shape) // 2
mat_obs = observable_to_matrix(observable, n_qubits)
mat_obs = observable_to_matrix(observable, n_qubits, values)
d = 2**n_qubits
prod = jnp.matmul(mat_obs, state.array.reshape((d, d)))
return jnp.trace(prod, axis1=-2, axis2=-1).real
Expand Down
15 changes: 11 additions & 4 deletions horqrux/shots.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,11 @@
from horqrux.utils import DensityMatrix, none_like


def observable_to_matrix(observable: Primitive, n_qubits: int) -> Array:
def observable_to_matrix(
observable: Primitive,
n_qubits: int,
values: dict[str, float],
) -> Array:
"""For finite shot sampling we need to calculate the eigenvalues/vectors of
an observable. This helper function takes an observable and system size
(n_qubits) and returns the overall action of the observable on the whole
Expand All @@ -25,7 +29,7 @@ def observable_to_matrix(observable: Primitive, n_qubits: int) -> Array:
observable.control == observable.parse_idx(none_like(observable.target)),
"Controlled gates cannot be promoted from observables to operations on the whole state vector",
)
unitary = observable.unitary()
unitary = observable.unitary(values=values)
target = observable.target[0][0]
identity = jnp.eye(2, dtype=unitary.dtype)
ops = [identity for _ in range(n_qubits)]
Expand All @@ -36,11 +40,12 @@ def observable_to_matrix(observable: Primitive, n_qubits: int) -> Array:
def eigenval_decomposition_sampling(
state: Array,
observables: list[Primitive],
values: dict[str, float],
n_qubits: int,
n_shots: int,
key: Any = jax.random.PRNGKey(0),
) -> Array:
mat_obs = [observable_to_matrix(observable, n_qubits) for observable in observables]
mat_obs = [observable_to_matrix(observable, n_qubits, values) for observable in observables]
eigs = [jnp.linalg.eigh(mat) for mat in mat_obs]
eigvecs, eigvals = align_eigenvectors(eigs)
inner_prod = jnp.matmul(jnp.conjugate(eigvecs.T), state.flatten())
Expand Down Expand Up @@ -69,7 +74,9 @@ def finite_shots_fwd(
else:
output_gates = apply_gate(state, gates, values)
n_qubits = len(state.shape)
return eigenval_decomposition_sampling(output_gates, observables, n_qubits, n_shots, key)
return eigenval_decomposition_sampling(
output_gates, observables, values, n_qubits, n_shots, key
)


def align_eigenvectors(eigs: list[tuple[Array, Array]]) -> tuple[Array, Array]:
Expand Down
8 changes: 4 additions & 4 deletions tests/test_shots.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,11 +54,11 @@ def shots_dm(x):
assert jnp.allclose(exp_exact, exp_exact_dm)

exp_shots = shots(x)
# # FIXME: DM expectation not working
# # exp_shots_dm = shots_dm(x)
# FIXME: DM expectation not working
# exp_shots_dm = shots_dm(x)

# assert jnp.allclose(exp_exact, exp_shots, atol=SHOTS_ATOL)
# # assert jnp.allclose(exp_exact, exp_shots_dm, atol=SHOTS_ATOL)
assert jnp.allclose(exp_exact, exp_shots, atol=SHOTS_ATOL)
# assert jnp.allclose(exp_exact, exp_shots_dm, atol=SHOTS_ATOL)

d_exact = jax.grad(lambda x: exact(x).sum())
d_shots = jax.grad(lambda x: shots(x).sum())
Expand Down

0 comments on commit 3a403a9

Please sign in to comment.