Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feature: support LinearCombination as observable #246

Merged
merged 3 commits into from
May 3, 2024
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
feature: support LinearCombination as observable
ashlhans committed Apr 30, 2024

Verified

This commit was created on GitHub.com and signed with GitHub’s verified signature.
commit 05ac87cd4b39bd91ceb1509bb2bd13a7b7077074
11 changes: 10 additions & 1 deletion src/braket/pennylane_plugin/ahs_device.py
Original file line number Diff line number Diff line change
@@ -44,6 +44,15 @@
from pennylane._version import __version__
from pennylane.measurements import MeasurementProcess, SampleMeasurement
from pennylane.ops import CompositeOp, Hamiltonian

try:
from pennylane.ops import LinearCombination
except (AttributeError, ImportError):

class LinearCombination:
pass
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Rather than doing conditional imports, I'd say it's fine to wait until 0.36 is released and import normally



from pennylane.pulse import ParametrizedEvolution
from pennylane.pulse.hardware_hamiltonian import HardwareHamiltonian, HardwarePulse

@@ -308,7 +317,7 @@ def _validate_measurement_basis(self, observable):
if isinstance(observable, CompositeOp):
for op in observable.operands:
self._validate_measurement_basis(op)
elif isinstance(observable, Hamiltonian):
elif isinstance(observable, (Hamiltonian, LinearCombination)):
for op in observable.ops:
self._validate_measurement_basis(op)

13 changes: 10 additions & 3 deletions src/braket/pennylane_plugin/braket_device.py
Original file line number Diff line number Diff line change
@@ -65,9 +65,16 @@
Variance,
)
from pennylane.operation import Operation
from pennylane.ops.qubit.hamiltonian import Hamiltonian
from pennylane.tape import QuantumTape

try:
from pennylane.ops import LinearCombination
except (AttributeError, ImportError):

class LinearCombination:
pass


from braket.pennylane_plugin.translation import (
get_adjoint_gradient_result_type,
supported_operations,
@@ -166,7 +173,7 @@ def observables(self) -> frozenset[str]:
# This needs to be here bc expectation(ax+by)== a*expectation(x)+b*expectation(y)
# is only true when shots=0
if not self.shots:
return base_observables.union({"Hamiltonian"})
return base_observables.union({"Hamiltonian", "LinearCombination"})
return base_observables

@property
@@ -227,7 +234,7 @@ def _apply_gradient_result_type(self, circuit, braket_circuit):
f"Braket can only compute gradients for circuits with a single expectation"
f" observable, not a {pl_measurements.return_type} observable."
)
if isinstance(pl_observable, Hamiltonian):
if isinstance(pl_observable, (qml.ops.Hamiltonian, LinearCombination)):
targets = [self.map_wires(op.wires) for op in pl_observable.ops]
else:
targets = self.map_wires(pl_observable.wires).tolist()
28 changes: 24 additions & 4 deletions src/braket/pennylane_plugin/translation.py
Original file line number Diff line number Diff line change
@@ -34,6 +34,15 @@
from pennylane.measurements import MeasurementProcess, ObservableReturnTypes
from pennylane.operation import Observable, Operation
from pennylane.ops import Adjoint

try:
from pennylane.ops import LinearCombination
except (AttributeError, ImportError):

class LinearCombination:
pass


from pennylane.pulse import ParametrizedEvolution

from braket.pennylane_plugin.ops import (
@@ -558,7 +567,7 @@ def translate_result_type(
return DensityMatrix(targets)
raise NotImplementedError(f"Unsupported return type: {return_type}")

if isinstance(measurement.obs, qml.Hamiltonian):
if isinstance(measurement.obs, (qml.ops.Hamiltonian, LinearCombination)):
if return_type is ObservableReturnTypes.Expectation:
return tuple(
Expectation(_translate_observable(term), term.wires) for term in measurement.obs.ops
@@ -581,8 +590,9 @@ def _translate_observable(observable):
raise qml.DeviceError(f"Unsupported observable: {type(observable)}")


@_translate_observable.register
def _(H: qml.Hamiltonian):
@_translate_observable.register(qml.ops.Hamiltonian)
@_translate_observable.register(LinearCombination)
def _(H: Union[qml.ops.Hamiltonian, LinearCombination]):
# terms is structured like [C, O] where C is a tuple of all the coefficients, and O is
# a tuple of all the corresponding observable terms (X, Y, Z, H, etc or a tensor product
# of them)
@@ -651,6 +661,16 @@ def _(t: qml.ops.Prod):
return reduce(lambda x, y: x @ y, [_translate_observable(factor) for factor in t.operands])


@_translate_observable.register
def _(t: qml.ops.SProd):
return t.scalar * _translate_observable(t.base)


@_translate_observable.register
def _(t: qml.ops.Sum):
return reduce(lambda x, y: x + y, [_translate_observable(operator) for operator in t.operands])


def translate_result(
braket_result: GateModelQuantumTaskResult,
measurement: MeasurementProcess,
@@ -688,7 +708,7 @@ def translate_result(
for i in sorted(key_indices)
]
translated = translate_result_type(measurement, targets, supported_result_types)
if isinstance(observable, qml.Hamiltonian):
if isinstance(observable, (qml.ops.Hamiltonian, LinearCombination)):
coeffs, _ = observable.terms()
return sum(
coeff * braket_result.get_value_by_result_type(result_type)
2 changes: 2 additions & 0 deletions test/unit_tests/test_translation.py
Original file line number Diff line number Diff line change
@@ -846,6 +846,8 @@ def _result_meta() -> dict:
),
(1.25 * observables.H(), 1.25 * qml.Hadamard(wires=0)),
(observables.X() @ observables.Y(), qml.ops.Prod(qml.PauliX(0), qml.PauliY(1))),
(observables.X() + observables.Y(), qml.ops.Sum(qml.PauliX(0), qml.PauliY(1))),
(observables.X(), qml.ops.SProd(scalar=4, base=qml.PauliX(0))),
],
)
def test_translate_hamiltonian_observable(expected_braket_H, pl_H):