-
Notifications
You must be signed in to change notification settings - Fork 41
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fixed incorrect use of walrus operator #1478
base: main
Are you sure you want to change the base?
Conversation
Incorrect use of the walrus operator caused `argnums` to be assigned a boolean instead of the value of rhs value.
nevermind this is a different case here than I thought |
Hi @lasradorohan, thank you so much for uncovering this issue and even providing a fix for it 🚀 🚀 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I've approved the CI to run on your PR to see if anything breaks (it will also check formatting and linting). In order to merge this PR, we'll require the following:
- Normally all bug fixes require one (or more) new tests that would have failed before and now pass. This is to prevent us from regressing into the bug again in the future. If you would like to add a test case (maybe a slimmed down version of how you originally discovered it) you can add a test in this file. Otherwise, let me know and we can add a test to your PR.
- Update the changelog saying what was fixed with this PR in this file, and most importantly, don't forget to add your name to the contributor list at the bottom of the file :)
Codecov ReportAll modified and coverable lines are covered by tests ✅
Additional details and impacted files@@ Coverage Diff @@
## main #1478 +/- ##
=======================================
Coverage 96.70% 96.70%
=======================================
Files 76 76
Lines 8173 8173
Branches 846 846
=======================================
Hits 7904 7904
Misses 215 215
Partials 54 54 ☔ View full report in Codecov by Sentry. |
This tests the issue PennyLaneAI#1477
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice work, happy to approve 💯 (looks like you just need to run black frontend
to satisfy the CI)
@pytest.mark.parametrize("transform_qjit", [False, True]) | ||
def test_grad_argnums(argnums, transform_qjit): | ||
"""Tests https://github.com/PennyLaneAI/catalyst/issues/1477""" | ||
@qml.qnode(device=qml.device("lightning.qubit", wires=4), interface="jax") | ||
def circuit(inputs, weights): | ||
qml.AngleEmbedding(features=inputs, wires=range(4), rotation="X") | ||
for i in range(1, 4): | ||
qml.CRX(weights[i - 1], wires=[i, 0]) | ||
return qml.expval(qml.PauliZ(wires=0)) | ||
|
||
if transform_qjit: | ||
circuit = qjit(circuit) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't think we need to test without qjit, since we would be testing PennyLane in that case (which has its own test suite):
@pytest.mark.parametrize("transform_qjit", [False, True]) | |
def test_grad_argnums(argnums, transform_qjit): | |
"""Tests https://github.com/PennyLaneAI/catalyst/issues/1477""" | |
@qml.qnode(device=qml.device("lightning.qubit", wires=4), interface="jax") | |
def circuit(inputs, weights): | |
qml.AngleEmbedding(features=inputs, wires=range(4), rotation="X") | |
for i in range(1, 4): | |
qml.CRX(weights[i - 1], wires=[i, 0]) | |
return qml.expval(qml.PauliZ(wires=0)) | |
if transform_qjit: | |
circuit = qjit(circuit) | |
def test_grad_argnums(argnums): | |
"""Tests https://github.com/PennyLaneAI/catalyst/issues/1477""" | |
@qjit | |
@qml.qnode(device=qml.device("lightning.qubit", wires=4), interface="jax") | |
def circuit(inputs, weights): | |
qml.AngleEmbedding(features=inputs, wires=range(4), rotation="X") | |
for i in range(1, 4): | |
qml.CRX(weights[i - 1], wires=[i, 0]) | |
return qml.expval(qml.PauliZ(wires=0)) |
expected = jax.grad(circuit, argnums=argnums)(weights, inputs) | ||
assert compare_structure_and_value(result, expected) | ||
|
||
_, result = value_and_grad(circuit, argnums=argnums)(weights, inputs) | ||
_, expected = jax.value_and_grad(circuit, argnums=argnums)(weights, inputs) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ideally we might want to compare against a computation not involving Catalyst at all, which can be achieved like this (applying jax.grad
function is a sort of Catalyst/JAX hybrid):
expected = jax.grad(circuit, argnums=argnums)(weights, inputs) | |
assert compare_structure_and_value(result, expected) | |
_, result = value_and_grad(circuit, argnums=argnums)(weights, inputs) | |
_, expected = jax.value_and_grad(circuit, argnums=argnums)(weights, inputs) | |
expected = jax.grad(circuit.original_function, argnums=argnums)(weights, inputs) | |
assert compare_structure_and_value(result, expected) | |
_, result = value_and_grad(circuit, argnums=argnums)(weights, inputs) | |
_, expected = jax.value_and_grad(circuit.original_function, argnums=argnums)(weights, inputs) |
Incorrect use of the walrus operator caused
argnums
to be assigned a boolean instead of the value of rhs value.Context:
Parameter
argnums
ofcatalyst.grad()
andcatalyst.value_and_grad()
was getting ignored. The gradients returned were of the first parameter regardless ofargnums
.In incorrect use of the walrus operator
:=
was found in the code.Description of the Change:
Brackets were added to ensure correct assignment of value.
Benefits:
argnums
now works as expected.Possible Drawbacks:
Related GitHub Issues:
#1477