Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feature: support LinearCombination as observable #246

Merged
merged 3 commits into from
May 3, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions src/braket/pennylane_plugin/ahs_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,13 +37,14 @@
from typing import Optional, Union

import numpy as np
import pennylane as qml
from braket.ahs.analog_hamiltonian_simulation import AnalogHamiltonianSimulation
from braket.aws import AwsDevice, AwsQuantumTask, AwsSession
from braket.devices import Device, LocalSimulator
from pennylane import QubitDevice
from pennylane._version import __version__
from pennylane.measurements import MeasurementProcess, SampleMeasurement
from pennylane.ops import CompositeOp, Hamiltonian
from pennylane.ops import CompositeOp
ashlhans marked this conversation as resolved.
Show resolved Hide resolved
from pennylane.pulse import ParametrizedEvolution
from pennylane.pulse.hardware_hamiltonian import HardwareHamiltonian, HardwarePulse

Expand Down Expand Up @@ -308,7 +309,7 @@ def _validate_measurement_basis(self, observable):
if isinstance(observable, CompositeOp):
for op in observable.operands:
self._validate_measurement_basis(op)
elif isinstance(observable, Hamiltonian):
elif isinstance(observable, (qml.ops.Hamiltonian, qml.Hamiltonian)):
ashlhans marked this conversation as resolved.
Show resolved Hide resolved
for op in observable.ops:
self._validate_measurement_basis(op)

Expand Down
5 changes: 2 additions & 3 deletions src/braket/pennylane_plugin/braket_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,6 @@
Variance,
)
from pennylane.operation import Operation
from pennylane.ops.qubit.hamiltonian import Hamiltonian
from pennylane.tape import QuantumTape

from braket.pennylane_plugin.translation import (
Expand Down Expand Up @@ -166,7 +165,7 @@ def observables(self) -> frozenset[str]:
# This needs to be here bc expectation(ax+by)== a*expectation(x)+b*expectation(y)
# is only true when shots=0
if not self.shots:
return base_observables.union({"Hamiltonian"})
return base_observables.union({"Hamiltonian", "LinearCombination"})
return base_observables

@property
Expand Down Expand Up @@ -227,7 +226,7 @@ def _apply_gradient_result_type(self, circuit, braket_circuit):
f"Braket can only compute gradients for circuits with a single expectation"
f" observable, not a {pl_measurements.return_type} observable."
)
if isinstance(pl_observable, Hamiltonian):
if isinstance(pl_observable, (qml.ops.Hamiltonian, qml.Hamiltonian)):
targets = [self.map_wires(op.wires) for op in pl_observable.ops]
else:
targets = self.map_wires(pl_observable.wires).tolist()
Expand Down
19 changes: 15 additions & 4 deletions src/braket/pennylane_plugin/translation.py
Original file line number Diff line number Diff line change
Expand Up @@ -558,7 +558,7 @@ def translate_result_type(
return DensityMatrix(targets)
raise NotImplementedError(f"Unsupported return type: {return_type}")

if isinstance(measurement.obs, qml.Hamiltonian):
if isinstance(measurement.obs, (qml.ops.Hamiltonian, qml.Hamiltonian)):
if return_type is ObservableReturnTypes.Expectation:
return tuple(
Expectation(_translate_observable(term), term.wires) for term in measurement.obs.ops
Expand All @@ -581,8 +581,9 @@ def _translate_observable(observable):
raise qml.DeviceError(f"Unsupported observable: {type(observable)}")


@_translate_observable.register
def _(H: qml.Hamiltonian):
@_translate_observable.register(qml.ops.Hamiltonian)
@_translate_observable.register(qml.Hamiltonian)
def _(H: Union[qml.ops.Hamiltonian, qml.Hamiltonian]):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this work?

Suggested change
@_translate_observable.register(qml.ops.Hamiltonian)
@_translate_observable.register(qml.Hamiltonian)
def _(H: Union[qml.ops.Hamiltonian, qml.Hamiltonian]):
@_translate_observable.register
def _(H: Union[qml.ops.Hamiltonian, qml.Hamiltonian]):

Otherwise,

@_translate_observable.register(qml.ops.Hamiltonian)
@_translate_observable.register(qml.Hamiltonian)
def _(H):

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Only since Python 3.11

# terms is structured like [C, O] where C is a tuple of all the coefficients, and O is
# a tuple of all the corresponding observable terms (X, Y, Z, H, etc or a tensor product
# of them)
Expand Down Expand Up @@ -651,6 +652,16 @@ def _(t: qml.ops.Prod):
return reduce(lambda x, y: x @ y, [_translate_observable(factor) for factor in t.operands])


@_translate_observable.register
def _(t: qml.ops.SProd):
return t.scalar * _translate_observable(t.base)


@_translate_observable.register
def _(t: qml.ops.Sum):
return reduce(lambda x, y: x + y, [_translate_observable(operator) for operator in t.operands])


def translate_result(
braket_result: GateModelQuantumTaskResult,
measurement: MeasurementProcess,
Expand Down Expand Up @@ -688,7 +699,7 @@ def translate_result(
for i in sorted(key_indices)
]
translated = translate_result_type(measurement, targets, supported_result_types)
if isinstance(observable, qml.Hamiltonian):
if isinstance(observable, (qml.ops.Hamiltonian, qml.Hamiltonian)):
coeffs, _ = observable.terms()
return sum(
coeff * braket_result.get_value_by_result_type(result_type)
Expand Down
2 changes: 2 additions & 0 deletions test/unit_tests/test_translation.py
Original file line number Diff line number Diff line change
Expand Up @@ -846,6 +846,8 @@ def _result_meta() -> dict:
),
(1.25 * observables.H(), 1.25 * qml.Hadamard(wires=0)),
(observables.X() @ observables.Y(), qml.ops.Prod(qml.PauliX(0), qml.PauliY(1))),
(observables.X() + observables.Y(), qml.ops.Sum(qml.PauliX(0), qml.PauliY(1))),
(observables.X(), qml.ops.SProd(scalar=4, base=qml.PauliX(0))),
],
)
def test_translate_hamiltonian_observable(expected_braket_H, pl_H):
Expand Down
Loading