Skip to content

Commit

Permalink
[Frontend] Update SProd observable tracing (#935)
Browse files Browse the repository at this point in the history
**Context:** After change 296316654bd6aabb3ff67eb0bac440fa8d706ee8 in
PennyLane, the SProd operation will return the flattened terms.

**Description of the Change:** This change fixes the tracing of SProd to
trace over all terms returned by `obs.terms()` as opposed to only the
first term.

**Benefits:** Fixes latest/latest/latest

**Possible Drawbacks:** None

**Related GitHub Issues:** 

[sc-69025]
  • Loading branch information
erick-xanadu authored Jul 16, 2024
1 parent bf69c2b commit 1e5c6d6
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 5 deletions.
1 change: 1 addition & 0 deletions .dep-versions
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,4 @@ llvm=cd9a641613eddf25d4b25eaa96b2c393d401d42c
enzyme=v0.0.130

# Always remove custom PL/LQ versions before release.
pennylane=296316654bd6aabb3ff67eb0bac440fa8d706ee8
9 changes: 8 additions & 1 deletion doc/changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,14 @@

* Using float32 in callback functions would not crash in compilation phase anymore,
but rather raise the appropriate type exception to the user.
[(#916)]https://github.com/PennyLaneAI/catalyst/pull/916
[(#916)](https://github.com/PennyLaneAI/catalyst/pull/916)

* Fix tracing of `SProd` operations
[(#935)](https://github.com/PennyLaneAI/catalyst/pull/935)

After some changes in PennyLane, `Sprod.terms()` returns the terms as leaves
instead of a tree. This means that we need to manually trace each term and
finally multiply it with the coefficients to create a Hamiltonian.

<h3>Internal changes</h3>

Expand Down
11 changes: 7 additions & 4 deletions frontend/catalyst/jax_tracer.py
Original file line number Diff line number Diff line change
Expand Up @@ -705,10 +705,13 @@ def trace_observables(
nested_obs = [trace_observables(o, qrp, m_wires)[0] for o in obs]
obs_tracers = hamiltonian_p.bind(jax.numpy.asarray(jnp.ones(len(obs))), *nested_obs)
elif isinstance(obs, qml.ops.op_math.SProd):
terms = obs.terms()
coeffs = jax.numpy.array(terms[0])
nested_obs = trace_observables(terms[1][0], qrp, m_wires)[0]
obs_tracers = hamiltonian_p.bind(coeffs, nested_obs)
coeffs, terms = obs.terms()
coeffs = jax.numpy.array(coeffs)
nested_obs = []
for term in terms:
obs = trace_observables(term, qrp, m_wires)[0]
nested_obs.append(obs)
obs_tracers = hamiltonian_p.bind(coeffs, *nested_obs)
else:
raise NotImplementedError(
f"Observable {obs} (of type {type(obs)}) is not impemented"
Expand Down

0 comments on commit 1e5c6d6

Please sign in to comment.