Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixed incorrect use of walrus operator #1478

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@

<h3>Bug fixes 🐛</h3>

* Fixed `argnums` parameter of `grad` and `value_and_grad` being ignored.
[(#1478)](https://github.com/PennyLaneAI/catalyst/pull/1478)

<h3>Internal changes ⚙️</h3>

* `from_plxpr` now uses the `qml.capture.PlxprInterpreter` class for reduced code duplication.
Expand Down Expand Up @@ -48,4 +51,5 @@ This release contains contributions from (in alphabetical order):
Sengthai Heng,
Christina Lee,
Mehrdad Malekmohammadi,
Paul Haochen Wang.
Paul Haochen Wang,
Rohan Nolan Lasrado.
2 changes: 1 addition & 1 deletion frontend/catalyst/api_extensions/differentiation.py
Original file line number Diff line number Diff line change
Expand Up @@ -696,7 +696,7 @@ def __call__(self, *args, **kwargs):
results, in_arg_tree, out_tree, grad_params, len(jaxpr.out_avals)
)
else:
if argnums := self.grad_params.argnums is None:
if (argnums := self.grad_params.argnums) is None:
argnums = 0
if self.grad_params.scalar_out:
if self.grad_params.with_value:
Expand Down
32 changes: 31 additions & 1 deletion frontend/test/pytest/test_gradient.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import pennylane as qml
import pytest
from jax import numpy as jnp
from jax.tree_util import tree_flatten
from jax.tree_util import tree_flatten, tree_map, tree_all, tree_structure

import catalyst.utils.calculate_grad_shape as infer
from catalyst import (
Expand Down Expand Up @@ -1780,5 +1780,35 @@ def fn(x):
assert np.allclose(res_pattern_partial, expected)


@pytest.mark.parametrize("argnums", [0, 1, (0, 1)])
@pytest.mark.parametrize("transform_qjit", [False, True])
def test_grad_argnums(argnums, transform_qjit):
"""Tests https://github.com/PennyLaneAI/catalyst/issues/1477"""
@qml.qnode(device=qml.device("lightning.qubit", wires=4), interface="jax")
def circuit(inputs, weights):
qml.AngleEmbedding(features=inputs, wires=range(4), rotation="X")
for i in range(1, 4):
qml.CRX(weights[i - 1], wires=[i, 0])
return qml.expval(qml.PauliZ(wires=0))

if transform_qjit:
circuit = qjit(circuit)
Comment on lines +1784 to +1795
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think we need to test without qjit, since we would be testing PennyLane in that case (which has its own test suite):

Suggested change
@pytest.mark.parametrize("transform_qjit", [False, True])
def test_grad_argnums(argnums, transform_qjit):
"""Tests https://github.com/PennyLaneAI/catalyst/issues/1477"""
@qml.qnode(device=qml.device("lightning.qubit", wires=4), interface="jax")
def circuit(inputs, weights):
qml.AngleEmbedding(features=inputs, wires=range(4), rotation="X")
for i in range(1, 4):
qml.CRX(weights[i - 1], wires=[i, 0])
return qml.expval(qml.PauliZ(wires=0))
if transform_qjit:
circuit = qjit(circuit)
def test_grad_argnums(argnums):
"""Tests https://github.com/PennyLaneAI/catalyst/issues/1477"""
@qjit
@qml.qnode(device=qml.device("lightning.qubit", wires=4), interface="jax")
def circuit(inputs, weights):
qml.AngleEmbedding(features=inputs, wires=range(4), rotation="X")
for i in range(1, 4):
qml.CRX(weights[i - 1], wires=[i, 0])
return qml.expval(qml.PauliZ(wires=0))


weights = jnp.array([3.0326467, 0.98860157, 1.9887222])
inputs = jnp.array([0.9653214, 0.31468165, 0.63302994])

def compare_structure_and_value(o1, o2):
return tree_structure(o1) == tree_structure(o2) and \
tree_all(tree_map(jnp.allclose, o1, o2))

result = grad(circuit, argnums=argnums)(weights, inputs)
expected = jax.grad(circuit, argnums=argnums)(weights, inputs)
assert compare_structure_and_value(result, expected)

_, result = value_and_grad(circuit, argnums=argnums)(weights, inputs)
_, expected = jax.value_and_grad(circuit, argnums=argnums)(weights, inputs)
Comment on lines +1805 to +1809
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ideally we might want to compare against a computation not involving Catalyst at all, which can be achieved like this (applying jax.grad function is a sort of Catalyst/JAX hybrid):

Suggested change
expected = jax.grad(circuit, argnums=argnums)(weights, inputs)
assert compare_structure_and_value(result, expected)
_, result = value_and_grad(circuit, argnums=argnums)(weights, inputs)
_, expected = jax.value_and_grad(circuit, argnums=argnums)(weights, inputs)
expected = jax.grad(circuit.original_function, argnums=argnums)(weights, inputs)
assert compare_structure_and_value(result, expected)
_, result = value_and_grad(circuit, argnums=argnums)(weights, inputs)
_, expected = jax.value_and_grad(circuit.original_function, argnums=argnums)(weights, inputs)

assert compare_structure_and_value(result, expected)


if __name__ == "__main__":
pytest.main(["-x", __file__])
Loading