Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feature: support LinearCombination as observable #246

Merged
merged 3 commits into from
May 3, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion src/braket/pennylane_plugin/ahs_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,15 @@
from pennylane._version import __version__
from pennylane.measurements import MeasurementProcess, SampleMeasurement
from pennylane.ops import CompositeOp, Hamiltonian

try:
from pennylane.ops import LinearCombination
except (AttributeError, ImportError):

class LinearCombination:
pass
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Rather than doing conditional imports, I'd say it's fine to wait until 0.36 is released and import normally



from pennylane.pulse import ParametrizedEvolution
from pennylane.pulse.hardware_hamiltonian import HardwareHamiltonian, HardwarePulse

Expand Down Expand Up @@ -308,7 +317,7 @@ def _validate_measurement_basis(self, observable):
if isinstance(observable, CompositeOp):
for op in observable.operands:
self._validate_measurement_basis(op)
elif isinstance(observable, Hamiltonian):
elif isinstance(observable, (Hamiltonian, LinearCombination)):
for op in observable.ops:
self._validate_measurement_basis(op)

Expand Down
13 changes: 10 additions & 3 deletions src/braket/pennylane_plugin/braket_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,9 +65,16 @@
Variance,
)
from pennylane.operation import Operation
from pennylane.ops.qubit.hamiltonian import Hamiltonian
from pennylane.tape import QuantumTape

try:
from pennylane.ops import LinearCombination
except (AttributeError, ImportError):

class LinearCombination:
pass


from braket.pennylane_plugin.translation import (
get_adjoint_gradient_result_type,
supported_operations,
Expand Down Expand Up @@ -166,7 +173,7 @@ def observables(self) -> frozenset[str]:
# This needs to be here bc expectation(ax+by)== a*expectation(x)+b*expectation(y)
# is only true when shots=0
if not self.shots:
return base_observables.union({"Hamiltonian"})
return base_observables.union({"Hamiltonian", "LinearCombination"})
return base_observables

@property
Expand Down Expand Up @@ -227,7 +234,7 @@ def _apply_gradient_result_type(self, circuit, braket_circuit):
f"Braket can only compute gradients for circuits with a single expectation"
f" observable, not a {pl_measurements.return_type} observable."
)
if isinstance(pl_observable, Hamiltonian):
if isinstance(pl_observable, (qml.ops.Hamiltonian, LinearCombination)):
targets = [self.map_wires(op.wires) for op in pl_observable.ops]
else:
targets = self.map_wires(pl_observable.wires).tolist()
Expand Down
28 changes: 24 additions & 4 deletions src/braket/pennylane_plugin/translation.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,15 @@
from pennylane.measurements import MeasurementProcess, ObservableReturnTypes
from pennylane.operation import Observable, Operation
from pennylane.ops import Adjoint

try:
from pennylane.ops import LinearCombination
except (AttributeError, ImportError):

class LinearCombination:
pass


from pennylane.pulse import ParametrizedEvolution

from braket.pennylane_plugin.ops import (
Expand Down Expand Up @@ -558,7 +567,7 @@ def translate_result_type(
return DensityMatrix(targets)
raise NotImplementedError(f"Unsupported return type: {return_type}")

if isinstance(measurement.obs, qml.Hamiltonian):
if isinstance(measurement.obs, (qml.ops.Hamiltonian, LinearCombination)):
if return_type is ObservableReturnTypes.Expectation:
return tuple(
Expectation(_translate_observable(term), term.wires) for term in measurement.obs.ops
Expand All @@ -581,8 +590,9 @@ def _translate_observable(observable):
raise qml.DeviceError(f"Unsupported observable: {type(observable)}")


@_translate_observable.register
def _(H: qml.Hamiltonian):
@_translate_observable.register(qml.ops.Hamiltonian)
@_translate_observable.register(LinearCombination)
def _(H: Union[qml.ops.Hamiltonian, LinearCombination]):
# terms is structured like [C, O] where C is a tuple of all the coefficients, and O is
# a tuple of all the corresponding observable terms (X, Y, Z, H, etc or a tensor product
# of them)
Expand Down Expand Up @@ -651,6 +661,16 @@ def _(t: qml.ops.Prod):
return reduce(lambda x, y: x @ y, [_translate_observable(factor) for factor in t.operands])


@_translate_observable.register
def _(t: qml.ops.SProd):
return t.scalar * _translate_observable(t.base)


@_translate_observable.register
def _(t: qml.ops.Sum):
return reduce(lambda x, y: x + y, [_translate_observable(operator) for operator in t.operands])


def translate_result(
braket_result: GateModelQuantumTaskResult,
measurement: MeasurementProcess,
Expand Down Expand Up @@ -688,7 +708,7 @@ def translate_result(
for i in sorted(key_indices)
]
translated = translate_result_type(measurement, targets, supported_result_types)
if isinstance(observable, qml.Hamiltonian):
if isinstance(observable, (qml.ops.Hamiltonian, LinearCombination)):
coeffs, _ = observable.terms()
return sum(
coeff * braket_result.get_value_by_result_type(result_type)
Expand Down
2 changes: 2 additions & 0 deletions test/unit_tests/test_translation.py
Original file line number Diff line number Diff line change
Expand Up @@ -846,6 +846,8 @@ def _result_meta() -> dict:
),
(1.25 * observables.H(), 1.25 * qml.Hadamard(wires=0)),
(observables.X() @ observables.Y(), qml.ops.Prod(qml.PauliX(0), qml.PauliY(1))),
(observables.X() + observables.Y(), qml.ops.Sum(qml.PauliX(0), qml.PauliY(1))),
(observables.X(), qml.ops.SProd(scalar=4, base=qml.PauliX(0))),
],
)
def test_translate_hamiltonian_observable(expected_braket_H, pl_H):
Expand Down
Loading