From 015ea611bbd6f54457ba46d9e93bd9f158044578 Mon Sep 17 00:00:00 2001 From: Rohan Nolan Lasrado <54094843+lasradorohan@users.noreply.github.com> Date: Tue, 21 Jan 2025 15:05:13 +0530 Subject: [PATCH 1/5] Fixed incorrect use of walrus operator Incorrect use of the walrus operator caused `argnums` to be assigned a boolean instead of the value of rhs value. --- frontend/catalyst/api_extensions/differentiation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/frontend/catalyst/api_extensions/differentiation.py b/frontend/catalyst/api_extensions/differentiation.py index 7f66b38348..66f701e017 100644 --- a/frontend/catalyst/api_extensions/differentiation.py +++ b/frontend/catalyst/api_extensions/differentiation.py @@ -696,7 +696,7 @@ def __call__(self, *args, **kwargs): results, in_arg_tree, out_tree, grad_params, len(jaxpr.out_avals) ) else: - if argnums := self.grad_params.argnums is None: + if (argnums := self.grad_params.argnums) is None: argnums = 0 if self.grad_params.scalar_out: if self.grad_params.with_value: From 17519f809d4a376032a5debc194e63c2e390ce63 Mon Sep 17 00:00:00 2001 From: Rohan Nolan Lasrado <54094843+lasradorohan@users.noreply.github.com> Date: Thu, 23 Jan 2025 01:41:40 +0530 Subject: [PATCH 2/5] Added test for argnums parameter This tests the issue https://github.com/PennyLaneAI/catalyst/issues/1477 --- frontend/test/pytest/test_gradient.py | 28 +++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/frontend/test/pytest/test_gradient.py b/frontend/test/pytest/test_gradient.py index d07237e510..7516ad771b 100644 --- a/frontend/test/pytest/test_gradient.py +++ b/frontend/test/pytest/test_gradient.py @@ -1780,5 +1780,33 @@ def fn(x): assert np.allclose(res_pattern_partial, expected) +@pytest.mark.parametrize("argnums", [0, 1, (0, 1)]) +@pytest.mark.parametrize("transform_qjit", [False, True]) +def test_grad_argnums(argnums, transform_qjit): + @qml.qnode(device=qml.device("lightning.qubit", wires=4), interface="jax") + def circuit(inputs, weights): + qml.AngleEmbedding(features=inputs, wires=range(4), rotation="X") + for i in range(1, 4): + qml.CRX(weights[i - 1], wires=[i, 0]) + return qml.expval(qml.PauliZ(wires=0)) + + if transform_qjit: + circuit = qjit(circuit) + + weights = jnp.array([3.0326467, 0.98860157, 1.9887222]) + inputs = jnp.array([0.9653214, 0.31468165, 0.63302994]) + + def compare_structure_and_value(o1, o2): + return tree_structure(o1) == tree_structure(o2) and tree_all(tree_map(lambda x, y: jnp.allclose(x, y), o1, o2)) + + result = grad(circuit, argnums=argnums)(weights, inputs) + expected = jax.grad(circuit, argnums=argnums)(weights, inputs) + assert compare_structure_and_value(result, expected) + + _, result = value_and_grad(circuit, argnums=argnums)(weights, inputs) + _, expected = jax.value_and_grad(circuit, argnums=argnums)(weights, inputs) + assert compare_structure_and_value(result, expected) + + if __name__ == "__main__": pytest.main(["-x", __file__]) From 7f4fd4149c3391cb73288bac998b8be761a93ec6 Mon Sep 17 00:00:00 2001 From: Rohan Nolan Lasrado <54094843+lasradorohan@users.noreply.github.com> Date: Thu, 23 Jan 2025 01:47:04 +0530 Subject: [PATCH 3/5] Added test for argnums parameter Tests https://github.com/PennyLaneAI/catalyst/issues/1477 --- frontend/test/pytest/test_gradient.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/frontend/test/pytest/test_gradient.py b/frontend/test/pytest/test_gradient.py index 7516ad771b..613fd5c8a5 100644 --- a/frontend/test/pytest/test_gradient.py +++ b/frontend/test/pytest/test_gradient.py @@ -19,7 +19,7 @@ import pennylane as qml import pytest from jax import numpy as jnp -from jax.tree_util import tree_flatten +from jax.tree_util import tree_flatten, tree_map, tree_all, tree_structure import catalyst.utils.calculate_grad_shape as infer from catalyst import ( @@ -1783,6 +1783,7 @@ def fn(x): @pytest.mark.parametrize("argnums", [0, 1, (0, 1)]) @pytest.mark.parametrize("transform_qjit", [False, True]) def test_grad_argnums(argnums, transform_qjit): + """Tests https://github.com/PennyLaneAI/catalyst/issues/1477""" @qml.qnode(device=qml.device("lightning.qubit", wires=4), interface="jax") def circuit(inputs, weights): qml.AngleEmbedding(features=inputs, wires=range(4), rotation="X") @@ -1797,7 +1798,8 @@ def circuit(inputs, weights): inputs = jnp.array([0.9653214, 0.31468165, 0.63302994]) def compare_structure_and_value(o1, o2): - return tree_structure(o1) == tree_structure(o2) and tree_all(tree_map(lambda x, y: jnp.allclose(x, y), o1, o2)) + return tree_structure(o1) == tree_structure(o2) and \ + tree_all(tree_map(lambda x, y: jnp.allclose(x, y), o1, o2)) result = grad(circuit, argnums=argnums)(weights, inputs) expected = jax.grad(circuit, argnums=argnums)(weights, inputs) From e1533e1e2b72cbcf5e7fc24c8544621562d9e6d9 Mon Sep 17 00:00:00 2001 From: Rohan Nolan Lasrado <54094843+lasradorohan@users.noreply.github.com> Date: Thu, 23 Jan 2025 01:48:50 +0530 Subject: [PATCH 4/5] Added test for argnums parameter Tests https://github.com/PennyLaneAI/catalyst/issues/1477 --- frontend/test/pytest/test_gradient.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/frontend/test/pytest/test_gradient.py b/frontend/test/pytest/test_gradient.py index 613fd5c8a5..3500130f2a 100644 --- a/frontend/test/pytest/test_gradient.py +++ b/frontend/test/pytest/test_gradient.py @@ -1799,7 +1799,7 @@ def circuit(inputs, weights): def compare_structure_and_value(o1, o2): return tree_structure(o1) == tree_structure(o2) and \ - tree_all(tree_map(lambda x, y: jnp.allclose(x, y), o1, o2)) + tree_all(tree_map(jnp.allclose, o1, o2)) result = grad(circuit, argnums=argnums)(weights, inputs) expected = jax.grad(circuit, argnums=argnums)(weights, inputs) From 08a1595978fc33baa7c0ca08f049b3ee2452defa Mon Sep 17 00:00:00 2001 From: Rohan Nolan Lasrado <54094843+lasradorohan@users.noreply.github.com> Date: Thu, 23 Jan 2025 01:59:44 +0530 Subject: [PATCH 5/5] Added release note for bug fix #1478 --- doc/releases/changelog-dev.md | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index cf4d9016a1..235ac08466 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -14,6 +14,9 @@