Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

depr: remove legacy opmath #283

Merged
merged 4 commits into from
Dec 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 1 addition & 5 deletions src/braket/pennylane_plugin/ahs_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,10 @@
from typing import Optional, Union

import numpy as np
import pennylane as qml
from pennylane._version import __version__
from pennylane.devices import QubitDevice
from pennylane.measurements import MeasurementProcess, SampleMeasurement
from pennylane.ops import CompositeOp, Hamiltonian
from pennylane.ops import CompositeOp
from pennylane.pulse import ParametrizedEvolution
from pennylane.pulse.hardware_hamiltonian import HardwareHamiltonian, HardwarePulse

Expand Down Expand Up @@ -312,9 +311,6 @@ def _validate_measurement_basis(self, observable):
if isinstance(observable, CompositeOp):
for op in observable.operands:
self._validate_measurement_basis(op)
elif isinstance(observable, (Hamiltonian, qml.Hamiltonian)):
for op in observable.ops:
self._validate_measurement_basis(op)

elif not observable.has_diagonalizing_gates:
raise RuntimeError(
Expand Down
4 changes: 2 additions & 2 deletions src/braket/pennylane_plugin/braket_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@
Variance,
)
from pennylane.operation import Operation
from pennylane.ops import Hamiltonian, Sum
from pennylane.ops import Sum
from pennylane.tape import QuantumTape

from braket.aws import (
Expand Down Expand Up @@ -287,7 +287,7 @@ def _apply_gradient_result_type(self, circuit, braket_circuit):
f"Braket can only compute gradients for circuits with a single expectation"
f" observable, not a {pl_measurements.return_type} observable."
)
if isinstance(pl_observable, (Hamiltonian, Sum)):
if isinstance(pl_observable, Sum):
targets = [self.map_wires(op.wires) for op in pl_observable.terms()[1]]
else:
targets = self.map_wires(pl_observable.wires).tolist()
Expand Down
13 changes: 4 additions & 9 deletions src/braket/pennylane_plugin/translation.py
Original file line number Diff line number Diff line change
Expand Up @@ -534,7 +534,7 @@ def supported_observables(device: Device, shots: int) -> frozenset[str]:
*[_BRAKET_TO_PENNYLANE_OBSERVABLES[braket_obs] for braket_obs in braket_observables],
)
supported |= {"Prod", "SProd"}
return supported if shots else supported | {"Sum", "Hamiltonian", "LinearCombination"}
return supported if shots else supported | {"Sum", "LinearCombination"}


def get_adjoint_gradient_result_type(
Expand Down Expand Up @@ -568,7 +568,7 @@ def translate_result_type( # noqa: C901

Returns:
Union[ResultType, tuple[ResultType]]: The Braket result type corresponding to
the given observable; if the observable type has multiple terms, for example a Hamiltonian,
the given observable; if the observable type has multiple terms, for example a Sum,
then this will return a result type for each term.
"""
return_type = measurement.return_type
Expand All @@ -595,7 +595,7 @@ def translate_result_type( # noqa: C901
if isinstance(observable, qml.ops.LinearCombination):
if return_type is ObservableReturnTypes.Expectation:
return tuple(Expectation(_translate_observable(op)) for op in observable.terms()[1])
raise NotImplementedError(f"Return type {return_type} unsupported for Hamiltonian")
raise NotImplementedError(f"Return type {return_type} unsupported for LinearCombination")

braket_observable = _translate_observable(observable)
if return_type is ObservableReturnTypes.Expectation:
Expand All @@ -609,7 +609,7 @@ def translate_result_type( # noqa: C901


def _flatten_observable(observable):
if isinstance(observable, (qml.ops.Hamiltonian, qml.ops.CompositeOp, qml.ops.SProd)):
if isinstance(observable, (qml.ops.CompositeOp, qml.ops.SProd)):
simplified = qml.ops.LinearCombination(*observable.terms()).simplify()
coeffs, _ = simplified.terms()
if len(coeffs) > 1 or coeffs[0] != 1:
Expand Down Expand Up @@ -669,11 +669,6 @@ def _(obs: qml.Projector):
return observables.Hermitian(obs.matrix(), targets=wires)


@_translate_observable.register
def _(t: qml.operation.Tensor):
return reduce(lambda x, y: x @ y, [_translate_observable(factor) for factor in t.obs])


Comment on lines -672 to -676
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For my understanding, why do we remove this, is this now internally handled?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Tensor is an operator class that has been marked as legacy for a long time, and has already been deprecated. Now we are completing its removal. So qml.operation.Tensor just won't exist anymore, and doesn't need any handling.

@_translate_observable.register
def _(t: qml.ops.Prod):
return reduce(lambda x, y: x @ y, [_translate_observable(factor) for factor in t.operands])
Expand Down
19 changes: 0 additions & 19 deletions test/unit_tests/test_ahs_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -467,25 +467,6 @@ def test_check_validity_valid_circuit(self, H, params):

dev.check_validity(ops, obs)

@pytest.mark.parametrize("H, params", HAMILTONIANS_AND_PARAMS)
def test_check_validity_valid_circuit_no_op_math(self, H, params):
"""Tests that check_validity() doesn't raise any errors when the operations and
observables are valid."""
qml.operation.disable_new_opmath()
ops = [ParametrizedEvolution(H, params, [0, 1.5])]
obs = [
qml.PauliZ(0),
qml.expval(qml.PauliZ(0)),
qml.var(qml.Identity(0)),
qml.sample(qml.PauliZ(0)),
qml.prod(qml.PauliZ(0), qml.Identity(1)),
qml.Hamiltonian([2, 3], [qml.PauliZ(0), qml.PauliZ(1)]),
qml.counts(),
]
dev = qml.device("braket.local.ahs", wires=3)

dev.check_validity(ops, obs)

@pytest.mark.parametrize("H, params", HAMILTONIANS_AND_PARAMS)
def test_check_validity_raises_error_for_state_based_measurement(self, H, params):
"""Tests that requesting a measurement other than a sample-based
Expand Down
5 changes: 2 additions & 3 deletions test/unit_tests/test_braket_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -569,7 +569,6 @@ def test_execute_with_gradient_no_op_math(
result_types,
expected_pl_result,
):
qml.operation.disable_new_opmath()

task = Mock()
type(task).id = PropertyMock(return_value="task_arn")
Expand Down Expand Up @@ -830,8 +829,8 @@ def test_pl_to_braket_circuit_hamiltonian():


def test_pl_to_braket_circuit_hamiltonian_tensor_product_terms():
"""Tests that a PennyLane circuit is correctly converted into a Braket circuit"""
"""when the Hamiltonian has multiple tensor product ops"""
"""Tests that a PennyLane circuit is correctly converted into a Braket circuit
when the Hamiltonian has multiple tensor product ops"""
dev = _aws_device(wires=2, foo="bar")

with QuantumTape() as tape:
Expand Down
2 changes: 1 addition & 1 deletion test/unit_tests/test_translation.py
Original file line number Diff line number Diff line change
Expand Up @@ -749,7 +749,7 @@ def test_translate_result_type_hamiltonian_unsupported_return(return_type):
with Hamiltonian observable and non-Expectation return type"""
obs = qml.Hamiltonian((2, 3), (qml.PauliX(wires=0), qml.PauliY(wires=1)))
tape = qml.tape.QuantumTape(measurements=[_braket_to_pl_result_types[return_type](obs)])
with pytest.raises(NotImplementedError, match="unsupported for Hamiltonian"):
with pytest.raises(NotImplementedError, match="unsupported for LinearCombination"):
translate_result_type(tape.measurements[0], [0], frozenset())


Expand Down