diff --git a/pennylane_lightning/core/src/simulators/lightning_qubit/StateVectorLQubitManaged.hpp b/pennylane_lightning/core/src/simulators/lightning_qubit/StateVectorLQubitManaged.hpp index 7107a0425a..544641b972 100644 --- a/pennylane_lightning/core/src/simulators/lightning_qubit/StateVectorLQubitManaged.hpp +++ b/pennylane_lightning/core/src/simulators/lightning_qubit/StateVectorLQubitManaged.hpp @@ -22,7 +22,6 @@ #include // fill #include #include -#include #include "BitUtil.hpp" // log2PerfectPower, isPerfectPowerOf2 #include "CPUMemoryModel.hpp" // bestCPUMemoryModel @@ -189,9 +188,10 @@ class StateVectorLQubitManaged final * @param num_qubits The number of qubits. */ void updateNumQubits(std::size_t num_qubits) { - BaseType::num_qubits_ = num_qubits; - BaseType::setKernels(num_qubits, BaseType::threading_, BaseType::memory_model_); - data_.resize(exp2(num_qubits)); + BaseType::num_qubits_ = num_qubits; + BaseType::setKernels(num_qubits, BaseType::threading_, + BaseType::memory_model_); + data_.resize(exp2(num_qubits)); } AlignedAllocator allocator() const { diff --git a/pennylane_lightning/core/src/simulators/lightning_qubit/StateVectorLQubitRaw.hpp b/pennylane_lightning/core/src/simulators/lightning_qubit/StateVectorLQubitRaw.hpp index 3d94126f5c..c420b8abb9 100644 --- a/pennylane_lightning/core/src/simulators/lightning_qubit/StateVectorLQubitRaw.hpp +++ b/pennylane_lightning/core/src/simulators/lightning_qubit/StateVectorLQubitRaw.hpp @@ -22,7 +22,6 @@ #include #include #include -#include #include "BitUtil.hpp" // log2PerfectPower, isPerfectPowerOf2 #include "CPUMemoryModel.hpp" // getMemoryModel @@ -145,9 +144,10 @@ class StateVectorLQubitRaw final * @param num_qubits The number of qubits. */ void updateNumQubits(std::size_t num_qubits) { - BaseType::num_qubits_ = num_qubits; - length_ = exp2(num_qubits); - BaseType::setKernels(num_qubits, BaseType::threading_, BaseType::memory_model_); + BaseType::num_qubits_ = num_qubits; + length_ = exp2(num_qubits); + BaseType::setKernels(num_qubits, BaseType::threading_, + BaseType::memory_model_); } }; } // namespace Pennylane::LightningQubit diff --git a/pennylane_lightning/core/src/simulators/lightning_qubit/bindings/LQubitBindings.hpp b/pennylane_lightning/core/src/simulators/lightning_qubit/bindings/LQubitBindings.hpp index a744de3ce1..c1fb36c88c 100644 --- a/pennylane_lightning/core/src/simulators/lightning_qubit/bindings/LQubitBindings.hpp +++ b/pennylane_lightning/core/src/simulators/lightning_qubit/bindings/LQubitBindings.hpp @@ -183,9 +183,10 @@ void registerBackendClassSpecificBindings(PyClass &pyclass) { .def(py::init([](std::size_t num_qubits) { return new StateVectorT(num_qubits); })) - .def("updateNumQubits", [](StateVectorT &sv, std::size_t num_qubits) { - sv.updateNumQubits(num_qubits); - }) + .def("updateNumQubits", + [](StateVectorT &sv, std::size_t num_qubits) { + sv.updateNumQubits(num_qubits); + }) .def("resetStateVector", &StateVectorT::resetStateVector) .def( "setBasisState", diff --git a/pennylane_lightning/core/src/simulators/lightning_qubit/tests/Test_StateVectorLQubitManaged.cpp b/pennylane_lightning/core/src/simulators/lightning_qubit/tests/Test_StateVectorLQubitManaged.cpp index b424e7e3cc..14e3e5c5a5 100644 --- a/pennylane_lightning/core/src/simulators/lightning_qubit/tests/Test_StateVectorLQubitManaged.cpp +++ b/pennylane_lightning/core/src/simulators/lightning_qubit/tests/Test_StateVectorLQubitManaged.cpp @@ -165,3 +165,27 @@ TEMPLATE_TEST_CASE("StateVectorLQubitManaged::SetStateVector", REQUIRE(sv.getDataVector() == approx(expected_state)); } } + +TEMPLATE_TEST_CASE("StateVectorLQubitManaged::updateNumQubits", + "[StateVectorLQubitManaged]", float, double) { + using PrecisionT = TestType; + + SECTION("StateVectorLQubitManaged {std::size_t}") { + const std::size_t original_num_qubits = 3; + const std::size_t updated_num_qubits = 5; + StateVectorLQubitManaged sv(original_num_qubits); + REQUIRE(sv.getNumQubits() == 3); + REQUIRE(sv.getLength() == 8); + REQUIRE(sv.getDataVector().size() == 8); + + sv.updateNumQubits(updated_num_qubits); + REQUIRE(sv.getNumQubits() == 5); + REQUIRE(sv.getLength() == 32); + REQUIRE(sv.getDataVector().size() == 32); + + sv.updateNumQubits(original_num_qubits); + REQUIRE(sv.getNumQubits() == 3); + REQUIRE(sv.getLength() == 8); + REQUIRE(sv.getDataVector().size() == 8); + } +} diff --git a/pennylane_lightning/lightning_qubit/_state_vector.py b/pennylane_lightning/lightning_qubit/_state_vector.py index fefcb8e0f0..7c1f525e22 100644 --- a/pennylane_lightning/lightning_qubit/_state_vector.py +++ b/pennylane_lightning/lightning_qubit/_state_vector.py @@ -88,7 +88,6 @@ def _update_num_qubits(self, new_num_wires: int): self._num_wires = new_num_wires self._qubit_state.updateNumQubits(new_num_wires) self.reset_state() - def _apply_state_vector(self, state, device_wires: Wires): """Initialize the internal state vector in a specified state. diff --git a/pennylane_lightning/lightning_qubit/lightning_qubit.py b/pennylane_lightning/lightning_qubit/lightning_qubit.py index cf50138380..25f1fc8cca 100644 --- a/pennylane_lightning/lightning_qubit/lightning_qubit.py +++ b/pennylane_lightning/lightning_qubit/lightning_qubit.py @@ -229,8 +229,8 @@ class LightningQubit(LightningBase): def __init__( # pylint: disable=too-many-arguments self, - *, wires: Union[int, List] = None, + *, c_dtype: Union[np.complex128, np.complex64] = np.complex128, shots: Union[int, List] = None, batch_obs: bool = False, @@ -288,10 +288,10 @@ def __init__( # pylint: disable=too-many-arguments } # Creating the state vector - #self._statevector = self.LightningStateVector(num_wires=len(self.wires), dtype=self._c_dtype) if self.wires else None + # self._statevector = self.LightningStateVector(num_wires=len(self.wires), dtype=self._c_dtype) if self.wires else None self._statevector = None self._c_dtype = c_dtype - + @property def name(self): """The name of the device.""" @@ -350,7 +350,9 @@ def preprocess(self, execution_config: ExecutionConfig = DefaultExecutionConfig) program.add_transform(validate_measurements, name=self.name) program.add_transform(validate_observables, accepted_observables, name=self.name) - program.add_transform(validate_device_wires, self.wires, name=self.name) #TODO: NEED to change? + program.add_transform( + validate_device_wires, self.wires, name=self.name + ) # TODO: NEED to change? program.add_transform( mid_circuit_measurements, device=self, mcm_config=exec_config.mcm_config ) @@ -392,15 +394,19 @@ def execute( for circuit in circuits: if self._statevector is None: if self.wires: - self._statevector = self.LightningStateVector(num_wires = len(self.wires), dtype=self._c_dtype) + self._statevector = self.LightningStateVector( + num_wires=len(self.wires), dtype=self._c_dtype + ) else: - self._statevector = self.LightningStateVector(num_wires = circuit.num_wires, dtype=self._c_dtype) + self._statevector = self.LightningStateVector( + num_wires=circuit.num_wires, dtype=self._c_dtype + ) circuit = circuit.map_to_standard_wires() else: if not self.wires: self._statevector._update_num_qubits(circuit.num_wires) circuit = circuit.map_to_standard_wires() - + if self._wire_map is not None: [circuit], _ = qml.map_wires(circuit, self._wire_map) results.append( diff --git a/tests/lightning_qubit/test_state_vector_class.py b/tests/lightning_qubit/test_state_vector_class.py index 28927a0b55..a7aa8b6023 100644 --- a/tests/lightning_qubit/test_state_vector_class.py +++ b/tests/lightning_qubit/test_state_vector_class.py @@ -483,3 +483,41 @@ def test_get_final_state(tol, operation, input, expected_output, par): assert np.allclose(final_state.state, np.array(expected_output), atol=tol, rtol=0) assert final_state.state.dtype == final_state.dtype assert final_state == state_vector + + +def test_dynamically_allocate_qubit(tol, operation, input, expected_output, par): + """Tests that applying an operation yields the expected output state for two wire + operations that have parameters.""" + if device_name != "lightning.qubit": + pytest.skip("Only Lightning Qubit allows dynamic qubit allocation", allow_module_level=True) + + wires = 2 + state_vector = LightningStateVector(wires) + tape = QuantumScript( + [qml.StatePrep(np.array(input), Wires(range(wires))), operation(*par, Wires(range(wires)))] + ) + final_state = state_vector.get_final_state(tape) + + assert np.allclose(final_state.state, np.array(expected_output), atol=tol, rtol=0) + assert final_state.state.dtype == final_state.dtype + assert final_state == state_vector + + +@pytest.mark.parametrize("num_wires", range(2, 5)) +@pytest.mark.parametrize("dtype", [np.complex64, np.complex128]) +def test_update_num_qubit(num_wires, dtype): + """ """ + if device_name != "lightning.qubit": + pytest.skip("Only Lightning Qubit allows dynamic qubit allocation") + + state_vector = LightningStateVector(num_wires, dtype=dtype) + + state_vector._update_num_qubits(num_wires + 2) + expected_output = np.zeros(2 ** (num_wires + 2), dtype=dtype) + expected_output[0] = 1 + assert np.allclose(state_vector.state, expected_output) + + state_vector._update_num_qubits(num_wires - 1) + expected_output = np.zeros(2 ** (num_wires - 1), dtype=dtype) + expected_output[0] = 1 + assert np.allclose(state_vector.state, expected_output) diff --git a/tests/test_apply.py b/tests/test_apply.py index 927058d015..9639f54d4b 100644 --- a/tests/test_apply.py +++ b/tests/test_apply.py @@ -739,6 +739,38 @@ def test_args(self): with pytest.raises(TypeError, match="missing 1 required positional argument: 'wires'"): qml.device(device_name) + @pytest.mark.skipif( + device_name != "lightning.qubit", + reason="Only Lightning Qubit support dynamic qubit allocation", + ) + def test_dynamic_allocate_qubit(self, qubit_device, tol): + """Test the dynamic allocation of qubits in Lightning devices""" + + dev = qubit_device(None) + dev_default = qml.device("default.qubit") + + def circuit1(): + qml.Identity(wires=2) + qml.Identity(wires=1) + qml.Hadamard(wires=0) + return qml.state() + + def circuit2(): + qml.Identity(wires=3) + qml.Identity(wires=1) + qml.Hadamard(wires=0) + return qml.state() + + def circuit3(): + qml.Identity(wires=3) + qml.Hadamard(wires=0) + return qml.state() + + for circuit in [circuit1, circuit2, circuit3]: + results = qml.qnode(dev)(circuit)() + expected = qml.qnode(dev_default)(circuit)() + assert np.allclose(results, expected, atol=tol, rtol=0) + @pytest.mark.skipif( device_name == "lightning.tensor", reason="lightning.tensor requires num_wires > 1" )