Skip to content

Commit

Permalink
Merge branch 'master' into direct_MPO
Browse files Browse the repository at this point in the history
  • Loading branch information
LuisAlfredoNu authored Jan 27, 2025
2 parents 22cea29 + e1572f5 commit 7218334
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 51 deletions.
11 changes: 10 additions & 1 deletion .github/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,20 @@
* Update Github CI to use Ubuntu 24 and remove `libopenblas-base` package.
[(#1041)](https://github.com/PennyLaneAI/pennylane-lightning/pull/1041)

* Updates the `eval_jaxpr` method to handle the new signatures for the `cond`, `while`, and
`for` primitives.
[(#1051)](https://github.com/PennyLaneAI/pennylane-lightning/pull/1051)

### Contributors

This release contains contributions from (in alphabetical order):

Yushao Chen, Amintor Dusko, Joseph Lee, Andrija Paurevic, Shuli Shu
Yushao Chen,
Amintor Dusko,
Christina Lee,
Joseph Lee,
Andrija Paurevic,
Shuli Shu

---

Expand Down
56 changes: 6 additions & 50 deletions pennylane_lightning/core/lightning_interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,19 +14,11 @@
"""
This module contains a class for executing plxpr using default qubit tools.
"""
from copy import copy

import jax
from pennylane.capture import disable, enable
from pennylane.capture.base_interpreter import PlxprInterpreter
from pennylane.capture.primitives import (
adjoint_transform_prim,
cond_prim,
ctrl_transform_prim,
for_loop_prim,
measure_prim,
while_loop_prim,
)
from pennylane.capture.base_interpreter import FlattenedHigherOrderPrimitives, PlxprInterpreter
from pennylane.capture.primitives import adjoint_transform_prim, ctrl_transform_prim, measure_prim
from pennylane.measurements import MidMeasureMP, Shots

from ._measurements_base import LightningBaseMeasurements
Expand Down Expand Up @@ -120,6 +112,10 @@ def interpret_measurement(self, measurement):
enable()


# pylint: disable=protected-access
LightningInterpreter._primitive_registrations.update(FlattenedHigherOrderPrimitives)


@LightningInterpreter.register_primitive(measure_prim)
def _(self, *invals, reset, postselect):
mp = MidMeasureMP(invals, reset=reset, postselect=postselect)
Expand All @@ -140,43 +136,3 @@ def _(self, *invals, jaxpr, n_consts, lazy=True):
def _(self, *invals, n_control, jaxpr, control_values, work_wires, n_consts):
# TODO: requires jaxpr -> list of ops first
raise NotImplementedError


# pylint: disable=too-many-arguments
@LightningInterpreter.register_primitive(for_loop_prim)
def _(self, start, stop, step, *invals, jaxpr_body_fn, consts_slice, args_slice):
consts = invals[consts_slice]
init_state = invals[args_slice]

res = init_state
for i in range(start, stop, step):
res = copy(self).eval(jaxpr_body_fn, consts, i, *res)

return res


# pylint: disable=too-many-arguments
@LightningInterpreter.register_primitive(while_loop_prim)
def _(self, *invals, jaxpr_body_fn, jaxpr_cond_fn, body_slice, cond_slice, args_slice):
consts_body = invals[body_slice]
consts_cond = invals[cond_slice]
init_state = invals[args_slice]

fn_res = init_state
while copy(self).eval(jaxpr_cond_fn, consts_cond, *fn_res)[0]:
fn_res = copy(self).eval(jaxpr_body_fn, consts_body, *fn_res)

return fn_res


@LightningInterpreter.register_primitive(cond_prim)
def _(self, *invals, jaxpr_branches, consts_slices, args_slice):
n_branches = len(jaxpr_branches)
conditions = invals[:n_branches]
args = invals[args_slice]

for pred, jaxpr, const_slice in zip(conditions, jaxpr_branches, consts_slices):
consts = invals[const_slice]
if pred and jaxpr is not None:
return copy(self).eval(jaxpr, consts, *args)
return ()

0 comments on commit 7218334

Please sign in to comment.