Skip to content

Commit

Permalink
feat: add phase RX gate (#243)
Browse files Browse the repository at this point in the history
  • Loading branch information
AbeCoull authored Apr 19, 2024
1 parent 87353b7 commit 0e5e472
Show file tree
Hide file tree
Showing 7 changed files with 68 additions and 3 deletions.
1 change: 1 addition & 0 deletions doc/devices/braket_local.rst
Original file line number Diff line number Diff line change
Expand Up @@ -69,4 +69,5 @@ from :mod:`braket.pennylane_plugin.ops <.ops>`:
braket.pennylane_plugin.GPi
braket.pennylane_plugin.GPi2
braket.pennylane_plugin.MS
braket.pennylane_plugin.PRx

1 change: 1 addition & 0 deletions doc/devices/braket_remote.rst
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ from :mod:`braket.pennylane_plugin.ops <.ops>`:
braket.pennylane_plugin.GPi
braket.pennylane_plugin.GPi2
braket.pennylane_plugin.MS
braket.pennylane_plugin.PRx

Pulse Programming
~~~~~~~~~~~~~~~~~
Expand Down
1 change: 1 addition & 0 deletions src/braket/pennylane_plugin/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
CPhaseShift10,
GPi,
GPi2,
PRx,
)

from ._version import __version__ # noqa: F401
51 changes: 51 additions & 0 deletions src/braket/pennylane_plugin/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,57 @@ def adjoint(self):
return CPhaseShift10(-phi, wires=self.wires)


class PRx(Operation):
r"""Phase Rx gate.
Unitary matrix:
.. math:: \mathtt{PRx}(\theta,\phi) = \begin{bmatrix}
\cos{(\theta / 2)} & -i e^{-i \phi} \sin{(\theta / 2)} \\
-i e^{i \phi} \sin{(\theta / 2)} & \cos{(\theta / 2)}
\end{bmatrix}.
**Details**
* Number of wires: 1
* Number of parameters: 2
Args:
theta (Union[FreeParameterExpression, float]): The first angle of the gate in
radians or expression representation.
phi (Union[FreeParameterExpression, float]): The second angle of the gate in
radians or expression representation.
"""

num_params = 2
num_wires = 1
grad_method = "F"

def __init__(self, theta, phi, wires, id=None):
super().__init__(theta, phi, wires=wires, id=id)

@staticmethod
def compute_matrix(theta, phi):
theta = _cast_to_tf(theta)
phi = _cast_to_tf(phi)
return np.array(
[
[
np.cos(theta / 2),
-1j * np.exp(-1j * phi) * np.sin(theta / 2),
],
[
-1j * np.exp(1j * phi) * np.sin(theta / 2),
np.cos(theta / 2),
],
]
)

def adjoint(self):
(theta, phi) = self.parameters
return PRx(-theta, phi, wires=self.wires)


class PSWAP(Operation):
r""" PSWAP(phi, wires)
Expand Down
9 changes: 9 additions & 0 deletions src/braket/pennylane_plugin/translation.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
CPhaseShift10,
GPi,
GPi2,
PRx,
)

_BRAKET_TO_PENNYLANE_OPERATIONS = {
Expand Down Expand Up @@ -88,6 +89,7 @@
"yy": "IsingYY",
"zz": "IsingZZ",
"ecr": "ECR",
"prx": "PRx",
"gpi": "GPi",
"gpi2": "GPi2",
"ms": "AAMS",
Expand Down Expand Up @@ -399,6 +401,13 @@ def _(zz: qml.IsingZZ, parameters, device=None):
return gates.ZZ(phi)


@_translate_operation.register
def _(_prx: PRx, parameters, device=None):
theta = parameters[0]
phi = parameters[1]
return gates.PRx(theta, phi)


@_translate_operation.register
def _(_gpi: GPi, parameters, device=None):
phi = parameters[0]
Expand Down
3 changes: 2 additions & 1 deletion test/unit_tests/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from numpy import float64

from braket.pennylane_plugin import PSWAP, CPhaseShift00, CPhaseShift01, CPhaseShift10
from braket.pennylane_plugin.ops import AAMS, MS, GPi, GPi2
from braket.pennylane_plugin.ops import AAMS, MS, GPi, GPi2, PRx

gates_1q_parametrized = [
(GPi, gates.GPi),
Expand All @@ -42,6 +42,7 @@

gates_2q_2p_parametrized = [
(MS, gates.MS),
(PRx, gates.PRx),
]

gates_2q_3p_parametrized = [
Expand Down
5 changes: 3 additions & 2 deletions test/unit_tests/test_translation.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@
CPhaseShift01,
CPhaseShift10,
)
from braket.pennylane_plugin.ops import AAMS, MS, GPi, GPi2
from braket.pennylane_plugin.ops import AAMS, MS, GPi, GPi2, PRx
from braket.pennylane_plugin.translation import (
_BRAKET_TO_PENNYLANE_OPERATIONS,
_translate_observable,
Expand Down Expand Up @@ -140,6 +140,7 @@ def _aws_device(
(GPi, gates.GPi, [0], [2]),
(GPi2, gates.GPi2, [0], [2]),
(MS, gates.MS, [0, 1], [2, 3]),
(PRx, gates.PRx, [0], [2, 3]),
(AAMS, gates.MS, [0, 1], [2, 3, 0.5]),
(qml.ECR, gates.ECR, [0, 1], []),
(qml.ISWAP, gates.ISwap, [0, 1], []),
Expand Down Expand Up @@ -339,7 +340,7 @@ def test_translate_operation(pl_cls, braket_cls, qubits, params):
pl_op = pl_cls(*params, wires=qubits)
braket_gate = braket_cls(*params)
assert translate_operation(pl_op) == braket_gate
if isinstance(pl_op, (GPi, GPi2, MS, AAMS)):
if isinstance(pl_op, (GPi, GPi2, MS, AAMS, PRx)):
translated_back = _braket_to_pl[
re.match("^[a-z0-2]+", braket_gate.to_ir(qubits, ir_type=IRType.OPENQASM)).group(0)
]
Expand Down

0 comments on commit 0e5e472

Please sign in to comment.