Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Flatten observable before getting targets #287

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
install_requires=[
"amazon-braket-sdk>=1.87.0",
"autoray>=0.6.11",
"pennylane>=0.34.0,<0.40",
"pennylane>=0.34.0",
],
entry_points={
"pennylane.plugins": [
Expand All @@ -53,7 +53,7 @@
},
extras_require={
"test": [
"autoray<0.7.0", # autoray.tensorflow_diag no longer works
"autoray<0.7.0", # autoray.tensorflow_diag no longer works
"docutils>=0.19",
"flaky",
"pre-commit",
Expand Down
3 changes: 2 additions & 1 deletion src/braket/pennylane_plugin/braket_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@
from braket.device_schema import DeviceActionType
from braket.devices import Device, LocalSimulator
from braket.pennylane_plugin.translation import (
flatten_observable,
get_adjoint_gradient_result_type,
supported_observables,
supported_operations,
Expand Down Expand Up @@ -281,7 +282,7 @@ def _apply_gradient_result_type(self, circuit, braket_circuit):
f" observable, not {len(circuit.observables)} observables."
)
pl_measurements = circuit.measurements[0]
pl_observable = pl_measurements.obs
pl_observable = flatten_observable(pl_measurements.obs)
if pl_measurements.return_type != Expectation:
raise ValueError(
f"Braket can only compute gradients for circuits with a single expectation"
Expand Down
8 changes: 4 additions & 4 deletions src/braket/pennylane_plugin/translation.py
Original file line number Diff line number Diff line change
Expand Up @@ -546,7 +546,7 @@ def get_adjoint_gradient_result_type(
if "AdjointGradient" not in supported_result_types:
raise NotImplementedError("Unsupported return type: AdjointGradient")

braket_observable = _translate_observable(_flatten_observable(observable))
braket_observable = _translate_observable(observable)
braket_observable = (
braket_observable.item() if hasattr(braket_observable, "item") else braket_observable
)
Expand Down Expand Up @@ -590,7 +590,7 @@ def translate_result_type( # noqa: C901
return tuple(Sample(observables.Z(target)) for target in targets or measurement.wires)
raise NotImplementedError(f"Unsupported return type: {return_type}")

observable = _flatten_observable(observable)
observable = flatten_observable(observable)

if isinstance(observable, qml.ops.LinearCombination):
if return_type is ObservableReturnTypes.Expectation:
Expand All @@ -608,7 +608,7 @@ def translate_result_type( # noqa: C901
raise NotImplementedError(f"Unsupported return type: {return_type}")


def _flatten_observable(observable):
def flatten_observable(observable):
if isinstance(observable, (qml.ops.CompositeOp, qml.ops.SProd)):
simplified = qml.ops.LinearCombination(*observable.terms()).simplify()
coeffs, _ = simplified.terms()
Expand Down Expand Up @@ -735,7 +735,7 @@ def translate_result(
return dict(braket_result.measurement_counts)

translated = translate_result_type(measurement, targets, supported_result_types)
observable = _flatten_observable(observable)
observable = flatten_observable(observable)
if isinstance(observable, qml.ops.LinearCombination):
coeffs, _ = observable.terms()
return sum(
Expand Down
8 changes: 4 additions & 4 deletions test/integ_tests/test_apply.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def test_qubit_state_vector(self, init_state, device, tol):

@qml.qnode(dev)
def circuit():
qml.QubitStateVector.compute_decomposition(state, wires=[0])
qml.StatePrep.compute_decomposition(state, wires=[0])
return qml.probs(wires=range(1))

assert np.allclose(circuit(), np.abs(state) ** 2, **tol)
Expand Down Expand Up @@ -177,15 +177,15 @@ def test_qubit_channel(self, init_state, dm_device, kraus, tol):
def assert_op_and_inverse(op, dev, state, wires, tol, op_args):
@qml.qnode(dev)
def circuit():
qml.QubitStateVector.compute_decomposition(state, wires=wires)
qml.StatePrep.compute_decomposition(state, wires=wires)
op(*op_args, wires=wires)
return qml.probs(wires=wires)

assert np.allclose(circuit(), np.abs(op.compute_matrix(*op_args) @ state) ** 2, **tol)

@qml.qnode(dev)
def circuit_inv():
qml.QubitStateVector.compute_decomposition(state, wires=wires)
qml.StatePrep.compute_decomposition(state, wires=wires)
qml.adjoint(op(*op_args, wires=wires))
return qml.probs(wires=wires)

Expand All @@ -197,7 +197,7 @@ def circuit_inv():
def assert_noise_op(op, dev, state, wires, tol, op_args):
@qml.qnode(dev)
def circuit():
qml.QubitStateVector.compute_decomposition(state, wires=wires)
qml.StatePrep.compute_decomposition(state, wires=wires)
op(*op_args, wires=wires)
return qml.probs(wires=wires)

Expand Down
17 changes: 16 additions & 1 deletion test/unit_tests/test_braket_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -569,7 +569,6 @@ def test_execute_with_gradient_no_op_math(
result_types,
expected_pl_result,
):

task = Mock()
type(task).id = PropertyMock(return_value="task_arn")
task.state.return_value = "COMPLETED"
Expand Down Expand Up @@ -1628,6 +1627,22 @@ def test_supported_ops_set(monkeypatch):
assert dev.operations == test_ops


def test_simplification():
"""Test that the Projector observable is correctly supported."""
wires = 5
dev = BraketLocalQubitDevice(wires=wires)

obs = qml.ops.LinearCombination([1.0, 2.0], [qml.X(0) @ qml.I(1), qml.Y(0) @ qml.X(1)])

@qml.qnode(dev)
def circuit(x):
qml.RX(x, 0)
return qml.expval(obs)

phi = np.array(1.5, requires_grad=True)
assert np.isclose(circuit(phi), 0)


def test_projection():
"""Test that the Projector observable is correctly supported."""
wires = 2
Expand Down
Loading