
Fixed incorrect use of walrus operator #1478

Open
wants to merge 6 commits into main

Conversation

lasradorohan

Incorrect use of the walrus operator caused `argnums` to be assigned a boolean instead of the right-hand-side value.

Context:
The `argnums` parameter of `catalyst.grad()` and `catalyst.value_and_grad()` was being ignored: the gradients returned were always with respect to the first parameter, regardless of `argnums`.
An incorrect use of the walrus operator `:=` was found in the code.

Description of the Change:
Parentheses were added to ensure the correct value is assigned.
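The pitfall can be sketched in plain Python (this is an illustrative minimal example, not Catalyst's actual code; `kwargs` and `argnums` here are hypothetical names standing in for the real call site):

```python
# The walrus operator `:=` binds more loosely than `is not None`,
# so without parentheses the comparison is evaluated first and the
# walrus captures its boolean result.
kwargs = {"argnums": 1}

# Buggy form: argnums ends up as the bool True, not the value 1.
if argnums := kwargs.get("argnums") is not None:
    pass
print(argnums)  # True

# Fixed form: parentheses make the walrus capture the value itself.
if (argnums := kwargs.get("argnums")) is not None:
    pass
print(argnums)  # 1
```

This is why the symptom was gradients always taken with respect to the first parameter: a truthy boolean stood in where the actual argument indices were expected.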

Benefits:
argnums now works as expected.

Possible Drawbacks:

Related GitHub Issues:
#1477


Qottmann commented Jan 21, 2025

I don't think this is actually a bug. Also note that pylint complains when you put these brackets around the walrus operator.

I personally prefer it with the brackets for better readability, though, so I wouldn't mind the change. But note that there are several occurrences of this pattern in the codebase, so it wouldn't be just changing this line here.

Never mind, this is a different case here than I thought.

dime10 (Contributor) commented Jan 22, 2025

Hi @lasradorohan, thank you so much for uncovering this issue and even providing a fix for it 🚀 🚀

dime10 (Contributor) left a comment

I've approved the CI to run on your PR to see if anything breaks (it will also check formatting and linting). In order to merge this PR, we'll require the following:

  • Normally all bug fixes require one (or more) new tests that would have failed before and now pass. This is to prevent us from regressing into the bug again in the future. If you would like to add a test case (maybe a slimmed down version of how you originally discovered it) you can add a test in this file. Otherwise, let me know and we can add a test to your PR.
  • Update the changelog saying what was fixed with this PR in this file, and most importantly, don't forget to add your name to the contributor list at the bottom of the file :)


codecov bot commented Jan 22, 2025

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 96.70%. Comparing base (78c7251) to head (015ea61).
Report is 4 commits behind head on main.

Additional details and impacted files
@@           Coverage Diff           @@
##             main    #1478   +/-   ##
=======================================
  Coverage   96.70%   96.70%           
=======================================
  Files          76       76           
  Lines        8173     8173           
  Branches      846      846           
=======================================
  Hits         7904     7904           
  Misses        215      215           
  Partials       54       54           


dime10 (Contributor) left a comment

Nice work, happy to approve 💯 (looks like you just need to run `black frontend` to satisfy the CI)

Comment on lines +1784 to +1795
@pytest.mark.parametrize("transform_qjit", [False, True])
def test_grad_argnums(argnums, transform_qjit):
    """Tests https://github.com/PennyLaneAI/catalyst/issues/1477"""

    @qml.qnode(device=qml.device("lightning.qubit", wires=4), interface="jax")
    def circuit(inputs, weights):
        qml.AngleEmbedding(features=inputs, wires=range(4), rotation="X")
        for i in range(1, 4):
            qml.CRX(weights[i - 1], wires=[i, 0])
        return qml.expval(qml.PauliZ(wires=0))

    if transform_qjit:
        circuit = qjit(circuit)

I don't think we need to test without qjit, since we would be testing PennyLane in that case (which has its own test suite):

Suggested change
Before:

@pytest.mark.parametrize("transform_qjit", [False, True])
def test_grad_argnums(argnums, transform_qjit):
    """Tests https://github.com/PennyLaneAI/catalyst/issues/1477"""

    @qml.qnode(device=qml.device("lightning.qubit", wires=4), interface="jax")
    def circuit(inputs, weights):
        qml.AngleEmbedding(features=inputs, wires=range(4), rotation="X")
        for i in range(1, 4):
            qml.CRX(weights[i - 1], wires=[i, 0])
        return qml.expval(qml.PauliZ(wires=0))

    if transform_qjit:
        circuit = qjit(circuit)

After:

def test_grad_argnums(argnums):
    """Tests https://github.com/PennyLaneAI/catalyst/issues/1477"""

    @qjit
    @qml.qnode(device=qml.device("lightning.qubit", wires=4), interface="jax")
    def circuit(inputs, weights):
        qml.AngleEmbedding(features=inputs, wires=range(4), rotation="X")
        for i in range(1, 4):
            qml.CRX(weights[i - 1], wires=[i, 0])
        return qml.expval(qml.PauliZ(wires=0))

Comment on lines +1805 to +1809
expected = jax.grad(circuit, argnums=argnums)(weights, inputs)
assert compare_structure_and_value(result, expected)

_, result = value_and_grad(circuit, argnums=argnums)(weights, inputs)
_, expected = jax.value_and_grad(circuit, argnums=argnums)(weights, inputs)

Ideally we might want to compare against a computation not involving Catalyst at all, which can be achieved like this (applying the `jax.grad` function to the qjit-compiled circuit is a sort of Catalyst/JAX hybrid):

Suggested change
Before:

expected = jax.grad(circuit, argnums=argnums)(weights, inputs)
assert compare_structure_and_value(result, expected)

_, result = value_and_grad(circuit, argnums=argnums)(weights, inputs)
_, expected = jax.value_and_grad(circuit, argnums=argnums)(weights, inputs)

After:

expected = jax.grad(circuit.original_function, argnums=argnums)(weights, inputs)
assert compare_structure_and_value(result, expected)

_, result = value_and_grad(circuit, argnums=argnums)(weights, inputs)
_, expected = jax.value_and_grad(circuit.original_function, argnums=argnums)(weights, inputs)
