
[BUG] parameterizing coefficients of qml.dot fails with qjit #1464

Open
bnishanth16 opened this issue Jan 15, 2025 · 2 comments

Comments

@bnishanth16

I get a TracerBoolConversionError when trying to parameterize the coefficients of qml.dot inside a qjit-compiled function. A minimal example is below.

Expected behavior (without qjit):

import pennylane as qml
import jax.numpy as jnp
from catalyst import qjit

@qml.qnode(qml.device('lightning.qubit', wires=1))
def circuit(x):
    coeffs = [x]
    ops = [qml.PauliX(0)]
    H = qml.dot(coeffs, ops) # H = qml.Hamiltonian(coeffs, ops)
    qml.evolve(H)
    return qml.expval(qml.Z(0))

print(circuit(jnp.array([0.5])))

Output:

>>>  [0.54030231]
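As a sanity check on this expected output: assuming `qml.evolve(H)` applies exp(-iH) (the sign of the exponent does not change ⟨Z⟩ here), the circuit evolves |0⟩ under H = 0.5·X. Because X·X = I, this gives ⟨Z⟩ = cos(2·0.5) = cos(1) ≈ 0.5403, which matches the result above. The NumPy sketch below reproduces that number without PennyLane. The helper name `evolve_expval_z` is hypothetical.

```python
import numpy as np

# Pauli matrices acting on the single wire of the circuit.
X = np.array([[0, 1], [1, 0]], dtype=complex)
Z = np.array([[1, 0], [0, -1]], dtype=complex)


def evolve_expval_z(x: float) -> float:
    """Return <Z> after applying exp(-i * x * X) to |0> (hypothetical helper).

    Because X @ X = I, exp(-i x X) = cos(x) * I - i * sin(x) * X.
    """
    unitary = np.cos(x) * np.eye(2) - 1j * np.sin(x) * X
    psi = unitary @ np.array([1, 0], dtype=complex)
    return float(np.real(np.conj(psi) @ Z @ psi))


print(evolve_expval_z(0.5))  # should agree with cos(1) to floating-point precision
```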

Actual behavior (with qjit):

import pennylane as qml
import jax.numpy as jnp
from catalyst import qjit

@qjit(autograph=True)
@qml.qnode(qml.device('lightning.qubit', wires=1))
def circuit(x):
    coeffs = [x]
    ops = [qml.PauliX(0)]
    H = qml.dot(coeffs, ops) # H = qml.Hamiltonian(coeffs, ops)
    qml.evolve(H)
    return qml.expval(qml.Z(0))

print(circuit(jnp.array([0.5])))

Output:

---------------------------------------------------------------------------
TracerBoolConversionError                 Traceback (most recent call last)
Cell In[7], line 14
     11     qml.evolve(H)
     12     return qml.expval(qml.Z(0))
---> 14 print(circuit(jnp.array([0.5])))

File ~/anaconda3/envs/pennylane40/lib/python3.12/site-packages/pennylane/logging/decorators.py:61, in log_string_debug_func.<locals>.wrapper_entry(*args, **kwargs)
     54     s_caller = "::L".join(
     55         [str(i) for i in inspect.getouterframes(inspect.currentframe(), 2)[1][1:3]]
     56     )
     57     lgr.debug(
     58         f"Calling {f_string} from {s_caller}",
     59         **_debug_log_kwargs,
     60     )
---> 61 return func(*args, **kwargs)

File ~/anaconda3/envs/pennylane40/lib/python3.12/site-packages/catalyst/jit.py:521, in QJIT.__call__(self, *args, **kwargs)
    517         kwargs = {"static_argnums": self.compile_options.static_argnums, **kwargs}
    519     return self.user_function(*args, **kwargs)
--> 521 requires_promotion = self.jit_compile(args, **kwargs)
    523 # If we receive tracers as input, dispatch to the JAX integration.
    524 if any(isinstance(arg, jax.core.Tracer) for arg in tree_flatten(args)[0]):
...
File ~/anaconda3/envs/pennylane40/lib/python3.12/site-packages/pennylane/ops/functions/dot.py:171, in dot(coeffs, ops, pauli, grouping_type, method)
    168 # Convert possible PauliWord and PauliSentence instances to operation
    169 ops = [op.operation() if isinstance(op, (PauliWord, PauliSentence)) else op for op in ops]
--> 171 operands = [op if coeff == 1 else qml.s_prod(coeff, op) for coeff, op in zip(coeffs, ops)]
    172 return (
    173     operands[0]
    174     if len(operands) == 1
    175     else qml.sum(*operands, grouping_type=grouping_type, method=method)
    176 )

File ~/anaconda3/envs/pennylane40/lib/python3.12/site-packages/jax/_src/core.py:712, in Tracer.__bool__(self)
    710 def __bool__(self):
    711   check_bool_conversion(self)
--> 712   return self.aval._bool(self)

File ~/anaconda3/envs/pennylane40/lib/python3.12/site-packages/jax/_src/core.py:1475, in concretization_function_error.<locals>.error(self, arg)
   1474 def error(self, arg):
-> 1475   raise TracerBoolConversionError(arg)

TracerBoolConversionError: Attempted boolean conversion of traced array with shape bool[1]..
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerBoolConversionError
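The last PennyLane frame points at the Python-level `coeff == 1` check in `dot.py`. During tracing, `x` is an abstract tracer, so `coeff == 1` is a traced boolean array (shape `bool[1]` here, because the input is `jnp.array([0.5])`), and the conditional expression cannot turn it into a Python `bool`. The pure-JAX sketch below reproduces the same class of error using `jax.jit` instead of Catalyst. `scale_unless_one` is a hypothetical stand-in for the list comprehension in `qml.dot`, not PennyLane code.

```python
import jax
import jax.numpy as jnp


def scale_unless_one(coeff, value):
    # Hypothetical stand-in for `op if coeff == 1 else qml.s_prod(coeff, op)`:
    # the conditional expression calls bool() on `coeff == 1`.
    return value if coeff == 1 else coeff * value


# Eager call: `coeff == 1` is a concrete size-1 array, so bool() succeeds.
eager_result = scale_unless_one(jnp.array([0.5]), 2.0)

# Traced call: `coeff` is a tracer, so bool(coeff == 1) raises.
try:
    jax.jit(scale_unless_one)(jnp.array([0.5]), 2.0)
    jit_error = None
except jax.errors.TracerBoolConversionError as err:
    jit_error = err

print(eager_result, type(jit_error).__name__)
```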

System information:

Name: PennyLane
Version: 0.40.0
Summary: PennyLane is a cross-platform Python library for quantum computing, quantum machine learning, and quantum chemistry. Train a quantum computer the same way as a neural network.
Home-page: https://github.com/PennyLaneAI/pennylane
Author: 
Author-email: 
License: Apache License 2.0
Location: /home/nishanth/anaconda3/envs/pennylane40/lib/python3.12/site-packages
Requires: appdirs, autograd, autoray, cachetools, diastatic-malt, networkx, numpy, packaging, pennylane-lightning, requests, rustworkx, scipy, tomlkit, typing-extensions
Required-by: PennyLane-Catalyst, PennyLane_Lightning

Platform info:           Linux-4.4.0-19041-Microsoft-x86_64-with-glibc2.35
Python version:          3.12.8
Numpy version:           2.0.2
Scipy version:           1.15.1
Installed devices:
- default.clifford (PennyLane-0.40.0)
- default.gaussian (PennyLane-0.40.0)
- default.mixed (PennyLane-0.40.0)
- default.qubit (PennyLane-0.40.0)
- default.qutrit (PennyLane-0.40.0)
- default.qutrit.mixed (PennyLane-0.40.0)
- default.tensor (PennyLane-0.40.0)
- null.qubit (PennyLane-0.40.0)
- reference.qubit (PennyLane-0.40.0)
- nvidia.custatevec (PennyLane-Catalyst-0.10.0)
- nvidia.cutensornet (PennyLane-Catalyst-0.10.0)
- oqc.cloud (PennyLane-Catalyst-0.10.0)
- softwareq.qpp (PennyLane-Catalyst-0.10.0)
- lightning.qubit (PennyLane_Lightning-0.40.0)

Additional Information:
Replacing qml.dot with qml.Hamiltonian gives the same error.

@josh146
Member

josh146 commented Jan 15, 2025

Thanks for catching this @bnishanth16! It looks like qml.dot() is currently not JIT-compatible when the coefficients of the dot product are dynamic (traced) values. I've shared this with the team internally, and we'll look into what is needed to fix this.
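To illustrate what JIT compatibility means for a check like `coeff == 1`, here is a hedged sketch of one common pattern: take the shortcut only when the coefficient is concrete, and otherwise always emit the scaled form. `scale_if_needed` is hypothetical and written in pure JAX. It is not the fix adopted in PennyLane or Catalyst. It relies on `jax.core.Tracer`, the same type Catalyst checks against in the traceback above.

```python
import jax
import jax.numpy as jnp


def scale_if_needed(coeff, value):
    # Traced coefficient: its value is unknown at trace time, so skip the
    # `== 1` shortcut and always scale (still correct when coeff is 1).
    if isinstance(coeff, jax.core.Tracer):
        return coeff * value
    # Concrete coefficient: the Python-level comparison is safe here.
    return value if coeff == 1 else coeff * value


jitted = jax.jit(scale_if_needed)
print(jitted(jnp.array([0.5]), 2.0))  # traced path, no TracerBoolConversionError
print(scale_if_needed(1, 2.0))        # concrete path, shortcut taken
```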

@dime10
Contributor

dime10 commented Jan 16, 2025

Yes, unfortunately the coeffs have to be constant parameters at the moment. You can work around the problem by telling Catalyst that the x parameter should be static (via static_argnums), but the downside is that the program has to be recompiled whenever x changes:

import pennylane as qml
from catalyst import qjit

@qjit(static_argnums=0)
@qml.qnode(qml.device('lightning.qubit', wires=1))
def circuit(x):
    coeffs = [x]
    ops = [qml.PauliX(0)]
    H = qml.dot(coeffs, ops)  # H = qml.Hamiltonian(coeffs, ops)
    qml.evolve(H)
    return qml.expval(qml.Z(0))

>>> print(circuit(0.5))
0.5403023058681398
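The same trade-off can be seen in plain JAX. This sketch does not use Catalyst and assumes that `jax.jit`'s `static_argnums` behaves analogously to Catalyst's. A static argument stays an ordinary Python value during tracing, so `coeff == 1` is a plain `bool`, but every new value of that argument triggers a fresh trace and compile. The hypothetical `trace_count` counter is a Python side effect that runs only while tracing, so it counts compilations rather than calls. Note also that JAX requires static arguments to be hashable, so a plain float such as `0.5` works where `jnp.array([0.5])` would not.

```python
import jax

trace_count = 0  # hypothetical counter, incremented once per trace


def scale_unless_one(coeff, value):
    global trace_count
    trace_count += 1  # Python side effect: runs at trace time only
    # `coeff` is static, so this comparison is an ordinary Python bool.
    return value if coeff == 1 else coeff * value


scaled = jax.jit(scale_unless_one, static_argnums=0)

scaled(0.5, 2.0)   # first value of coeff: traced and compiled
scaled(0.5, 3.0)   # same static value, same input shape: cached
scaled(0.25, 2.0)  # new static value: traced and compiled again
print(trace_count)
```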
