Skip to content

Commit

Permalink
feature: support shot vector
Browse files Browse the repository at this point in the history
  • Loading branch information
king-p3nguin committed Jun 7, 2024
1 parent 1c64c85 commit 14e6f14
Showing 1 changed file with 21 additions and 4 deletions.
25 changes: 21 additions & 4 deletions src/braket/pennylane_plugin/translation.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import pennylane as qml
from braket.aws import AwsDevice
from braket.circuits import FreeParameter, Gate, ResultType, gates, noises, observables
from braket.circuits.observables import Observable as BraketObservable
from braket.circuits.result_types import (
AdjointGradient,
DensityMatrix,
Expand Down Expand Up @@ -530,7 +531,7 @@ def get_adjoint_gradient_result_type(
return AdjointGradient(observable=braket_observable, target=targets, parameters=parameters)


def translate_result_type(
def translate_result_type( # noqa: C901
measurement: MeasurementProcess, targets: list[int], supported_result_types: frozenset[str]
) -> Union[ResultType, tuple[ResultType, ...]]:
"""Translates a PennyLane ``MeasurementProcess`` into the corresponding Braket ``ResultType``.
Expand All @@ -547,6 +548,7 @@ def translate_result_type(
then this will return a result type for each term.
"""
return_type = measurement.return_type
observable = measurement.obs

if return_type is ObservableReturnTypes.Probability:
return Probability(targets)
Expand All @@ -558,14 +560,21 @@ def translate_result_type(
return DensityMatrix(targets)
raise NotImplementedError(f"Unsupported return type: {return_type}")

if isinstance(measurement.obs, (Hamiltonian, qml.Hamiltonian)):
if isinstance(observable, (Hamiltonian, qml.Hamiltonian)):
if return_type is ObservableReturnTypes.Expectation:
return tuple(
Expectation(_translate_observable(term), term.wires) for term in measurement.obs.ops
Expectation(_translate_observable(term), term.wires) for term in observable.ops
)
raise NotImplementedError(f"Return type {return_type} unsupported for Hamiltonian")

braket_observable = _translate_observable(measurement.obs)
if return_type is ObservableReturnTypes.Sample and observable is None:
if isinstance(measurement, qml.measurements.SampleMeasurement):
return tuple(
Sample(BraketObservable.Z(), target) for target in targets or measurement.wires
)
raise NotImplementedError(f"Unsupported measurement type: {type(measurement)}")

braket_observable = _translate_observable(observable)
if return_type is ObservableReturnTypes.Expectation:
return Expectation(braket_observable, targets)
elif return_type is ObservableReturnTypes.Variance:
Expand Down Expand Up @@ -698,6 +707,14 @@ def translate_result(
ag_result.value["gradient"][f"p_{i}"]
for i in sorted(key_indices)
]

if measurement.return_type is ObservableReturnTypes.Sample and observable is None:
if isinstance(measurement, qml.measurements.SampleMeasurement):
if targets:
return [m[targets] for m in braket_result.measurements]
return braket_result.measurements
raise NotImplementedError(f"Unsupported measurement type: {type(measurement)}")

translated = translate_result_type(measurement, targets, supported_result_types)
if isinstance(observable, (Hamiltonian, qml.Hamiltonian)):
coeffs, _ = observable.terms()
Expand Down

0 comments on commit 14e6f14

Please sign in to comment.