diff --git a/src/braket/circuits/circuit.py b/src/braket/circuits/circuit.py index a9844b22f..e507ebcdf 100644 --- a/src/braket/circuits/circuit.py +++ b/src/braket/circuits/circuit.py @@ -110,6 +110,7 @@ def __init__(self, addable: AddableTypes = None, *args, **kwargs): self._moments: Moments = Moments() self._result_types: List[ResultType] = [] self._qubit_observable_mapping: Dict[Union[int, Circuit._ALL_QUBITS], Observable] = {} + self._qubit_target_mapping: Dict[int, List[int]] = {} if addable is not None: self.add(addable, *args, **kwargs) @@ -274,13 +275,23 @@ def _add_to_qubit_observable_mapping(self, result_type: ResultType) -> None: for target in targets: current_observable = all_qubits_observable or self._qubit_observable_mapping.get(target) + current_target = self._qubit_target_mapping.get(target) if current_observable and current_observable != observable: raise ValueError( f"Existing result type for observable {current_observable} for target {target}" f" conflicts with observable {observable} for new result type" ) + if result_type.target: + # The only way this can happen is if the observables (acting on multiple target + # qubits) and target qubits are the same, but the new target is the wrong order; + if current_target and current_target != targets: + raise ValueError( + f"Target order {current_target} of existing result type with observable" + f" {current_observable} conflicts with order {targets} of new result type" + ) self._qubit_observable_mapping[target] = observable + self._qubit_target_mapping[target] = targets if not result_type.target: self._qubit_observable_mapping[Circuit._ALL_QUBITS] = observable diff --git a/test/unit_tests/braket/circuits/test_circuit.py b/test/unit_tests/braket/circuits/test_circuit.py index 23974e542..8952bfae6 100644 --- a/test/unit_tests/braket/circuits/test_circuit.py +++ b/test/unit_tests/braket/circuits/test_circuit.py @@ -193,6 +193,15 @@ def test_add_result_type_observable_no_conflict_state_vector_obs_return_value(): assert circ.result_types == expected +@pytest.mark.xfail(raises=ValueError) +def test_add_result_type_same_observable_wrong_target_order(): + Circuit().add_result_type( + ResultType.Expectation(observable=Observable.Y() @ Observable.X(), target=[0, 1]) + ).add_result_type( + ResultType.Variance(observable=Observable.Y() @ Observable.X(), target=[1, 0]) + ) + + @pytest.mark.xfail(raises=TypeError) def test_add_result_type_with_target_and_mapping(prob): Circuit().add_result_type(prob, target=[10], target_mapping={0: 10})