Skip to content

Commit

Permalink
add test and format
Browse files Browse the repository at this point in the history
  • Loading branch information
josephleekl committed Jan 22, 2025
1 parent ee62471 commit 4835b1f
Show file tree
Hide file tree
Showing 8 changed files with 119 additions and 19 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
#include <algorithm> // fill
#include <complex>
#include <vector>
#include <iostream>

#include "BitUtil.hpp" // log2PerfectPower, isPerfectPowerOf2
#include "CPUMemoryModel.hpp" // bestCPUMemoryModel
Expand Down Expand Up @@ -189,9 +188,10 @@ class StateVectorLQubitManaged final
* @param num_qubits The number of qubits.
*/
void updateNumQubits(std::size_t num_qubits) {
BaseType::num_qubits_ = num_qubits;
BaseType::setKernels(num_qubits, BaseType::threading_, BaseType::memory_model_);
data_.resize(exp2(num_qubits));
BaseType::num_qubits_ = num_qubits;
BaseType::setKernels(num_qubits, BaseType::threading_,
BaseType::memory_model_);
data_.resize(exp2(num_qubits));
}

AlignedAllocator<ComplexT> allocator() const {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
#include <stdexcept>
#include <utility>
#include <vector>
#include <iostream>

#include "BitUtil.hpp" // log2PerfectPower, isPerfectPowerOf2
#include "CPUMemoryModel.hpp" // getMemoryModel
Expand Down Expand Up @@ -145,9 +144,10 @@ class StateVectorLQubitRaw final
* @param num_qubits The number of qubits.
*/
void updateNumQubits(std::size_t num_qubits) {
BaseType::num_qubits_ = num_qubits;
length_ = exp2(num_qubits);
BaseType::setKernels(num_qubits, BaseType::threading_, BaseType::memory_model_);
BaseType::num_qubits_ = num_qubits;
length_ = exp2(num_qubits);
BaseType::setKernels(num_qubits, BaseType::threading_,
BaseType::memory_model_);
}
};
} // namespace Pennylane::LightningQubit
Original file line number Diff line number Diff line change
Expand Up @@ -183,9 +183,10 @@ void registerBackendClassSpecificBindings(PyClass &pyclass) {
.def(py::init([](std::size_t num_qubits) {
return new StateVectorT(num_qubits);
}))
.def("updateNumQubits", [](StateVectorT &sv, std::size_t num_qubits) {
sv.updateNumQubits(num_qubits);
})
.def("updateNumQubits",
[](StateVectorT &sv, std::size_t num_qubits) {
sv.updateNumQubits(num_qubits);
})
.def("resetStateVector", &StateVectorT::resetStateVector)
.def(
"setBasisState",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -165,3 +165,27 @@ TEMPLATE_TEST_CASE("StateVectorLQubitManaged::SetStateVector",
REQUIRE(sv.getDataVector() == approx(expected_state));
}
}

TEMPLATE_TEST_CASE("StateVectorLQubitManaged::updateNumQubits",
"[StateVectorLQubitManaged]", float, double) {
using PrecisionT = TestType;

SECTION("StateVectorLQubitManaged<TestType> {std::size_t}") {
const std::size_t original_num_qubits = 3;
const std::size_t updated_num_qubits = 5;
StateVectorLQubitManaged<PrecisionT> sv(original_num_qubits);
REQUIRE(sv.getNumQubits() == 3);
REQUIRE(sv.getLength() == 8);
REQUIRE(sv.getDataVector().size() == 8);

sv.updateNumQubits(updated_num_qubits);
REQUIRE(sv.getNumQubits() == 5);
REQUIRE(sv.getLength() == 32);
REQUIRE(sv.getDataVector().size() == 32);

sv.updateNumQubits(original_num_qubits);
REQUIRE(sv.getNumQubits() == 3);
REQUIRE(sv.getLength() == 8);
REQUIRE(sv.getDataVector().size() == 8);
}
}
1 change: 0 additions & 1 deletion pennylane_lightning/lightning_qubit/_state_vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,6 @@ def _update_num_qubits(self, new_num_wires: int):
self._num_wires = new_num_wires

Check notice on line 88 in pennylane_lightning/lightning_qubit/_state_vector.py

View check run for this annotation

codefactor.io / CodeFactor

pennylane_lightning/lightning_qubit/_state_vector.py#L88

Attribute '_num_wires' defined outside __init__ (attribute-defined-outside-init)
self._qubit_state.updateNumQubits(new_num_wires)
self.reset_state()

Check warning on line 90 in pennylane_lightning/lightning_qubit/_state_vector.py

View check run for this annotation

Codecov / codecov/patch

pennylane_lightning/lightning_qubit/_state_vector.py#L88-L90

Added lines #L88 - L90 were not covered by tests


def _apply_state_vector(self, state, device_wires: Wires):
"""Initialize the internal state vector in a specified state.
Expand Down
20 changes: 13 additions & 7 deletions pennylane_lightning/lightning_qubit/lightning_qubit.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,8 +229,8 @@ class LightningQubit(LightningBase):

def __init__( # pylint: disable=too-many-arguments
self,
*,
wires: Union[int, List] = None,
*,
c_dtype: Union[np.complex128, np.complex64] = np.complex128,
shots: Union[int, List] = None,
batch_obs: bool = False,
Expand Down Expand Up @@ -288,10 +288,10 @@ def __init__( # pylint: disable=too-many-arguments
}

# Creating the state vector
#self._statevector = self.LightningStateVector(num_wires=len(self.wires), dtype=self._c_dtype) if self.wires else None
# self._statevector = self.LightningStateVector(num_wires=len(self.wires), dtype=self._c_dtype) if self.wires else None
self._statevector = None
self._c_dtype = c_dtype

Check warning on line 293 in pennylane_lightning/lightning_qubit/lightning_qubit.py

View check run for this annotation

Codecov / codecov/patch

pennylane_lightning/lightning_qubit/lightning_qubit.py#L292-L293

Added lines #L292 - L293 were not covered by tests

@property
def name(self):
"""The name of the device."""
Expand Down Expand Up @@ -350,7 +350,9 @@ def preprocess(self, execution_config: ExecutionConfig = DefaultExecutionConfig)

program.add_transform(validate_measurements, name=self.name)
program.add_transform(validate_observables, accepted_observables, name=self.name)
program.add_transform(validate_device_wires, self.wires, name=self.name) #TODO: NEED to change?
program.add_transform(

Check warning on line 353 in pennylane_lightning/lightning_qubit/lightning_qubit.py

View check run for this annotation

Codecov / codecov/patch

pennylane_lightning/lightning_qubit/lightning_qubit.py#L353

Added line #L353 was not covered by tests
validate_device_wires, self.wires, name=self.name
) # TODO: NEED to change?
program.add_transform(
mid_circuit_measurements, device=self, mcm_config=exec_config.mcm_config
)
Expand Down Expand Up @@ -392,15 +394,19 @@ def execute(
for circuit in circuits:
if self._statevector is None:
if self.wires:
self._statevector = self.LightningStateVector(num_wires = len(self.wires), dtype=self._c_dtype)
self._statevector = self.LightningStateVector(

Check warning on line 397 in pennylane_lightning/lightning_qubit/lightning_qubit.py

View check run for this annotation

Codecov / codecov/patch

pennylane_lightning/lightning_qubit/lightning_qubit.py#L395-L397

Added lines #L395 - L397 were not covered by tests
num_wires=len(self.wires), dtype=self._c_dtype
)
else:

Check notice on line 400 in pennylane_lightning/lightning_qubit/lightning_qubit.py

View check run for this annotation

codefactor.io / CodeFactor

pennylane_lightning/lightning_qubit/lightning_qubit.py#L400

Access to a protected member _update_num_qubits of a client class (protected-access)
self._statevector = self.LightningStateVector(num_wires = circuit.num_wires, dtype=self._c_dtype)
self._statevector = self.LightningStateVector(

Check warning on line 401 in pennylane_lightning/lightning_qubit/lightning_qubit.py

View check run for this annotation

Codecov / codecov/patch

pennylane_lightning/lightning_qubit/lightning_qubit.py#L401

Added line #L401 was not covered by tests
num_wires=circuit.num_wires, dtype=self._c_dtype
)
circuit = circuit.map_to_standard_wires()

Check warning on line 404 in pennylane_lightning/lightning_qubit/lightning_qubit.py

View check run for this annotation

Codecov / codecov/patch

pennylane_lightning/lightning_qubit/lightning_qubit.py#L404

Added line #L404 was not covered by tests
else:
if not self.wires:
self._statevector._update_num_qubits(circuit.num_wires)
circuit = circuit.map_to_standard_wires()

Check warning on line 408 in pennylane_lightning/lightning_qubit/lightning_qubit.py

View check run for this annotation

Codecov / codecov/patch

pennylane_lightning/lightning_qubit/lightning_qubit.py#L406-L408

Added lines #L406 - L408 were not covered by tests

if self._wire_map is not None:
[circuit], _ = qml.map_wires(circuit, self._wire_map)
results.append(
Expand Down
38 changes: 38 additions & 0 deletions tests/lightning_qubit/test_state_vector_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -483,3 +483,41 @@ def test_get_final_state(tol, operation, input, expected_output, par):
assert np.allclose(final_state.state, np.array(expected_output), atol=tol, rtol=0)
assert final_state.state.dtype == final_state.dtype
assert final_state == state_vector


def test_dynamically_allocate_qubit(tol, operation, input, expected_output, par):
"""Tests that applying an operation yields the expected output state for two wire
operations that have parameters."""
if device_name != "lightning.qubit":
pytest.skip("Only Lightning Qubit allows dynamic qubit allocation", allow_module_level=True)

wires = 2
state_vector = LightningStateVector(wires)
tape = QuantumScript(
[qml.StatePrep(np.array(input), Wires(range(wires))), operation(*par, Wires(range(wires)))]
)
final_state = state_vector.get_final_state(tape)

assert np.allclose(final_state.state, np.array(expected_output), atol=tol, rtol=0)
assert final_state.state.dtype == final_state.dtype
assert final_state == state_vector


@pytest.mark.parametrize("num_wires", range(2, 5))
@pytest.mark.parametrize("dtype", [np.complex64, np.complex128])
def test_update_num_qubit(num_wires, dtype):
""" """
if device_name != "lightning.qubit":
pytest.skip("Only Lightning Qubit allows dynamic qubit allocation")

state_vector = LightningStateVector(num_wires, dtype=dtype)

state_vector._update_num_qubits(num_wires + 2)
expected_output = np.zeros(2 ** (num_wires + 2), dtype=dtype)
expected_output[0] = 1
assert np.allclose(state_vector.state, expected_output)

state_vector._update_num_qubits(num_wires - 1)
expected_output = np.zeros(2 ** (num_wires - 1), dtype=dtype)
expected_output[0] = 1
assert np.allclose(state_vector.state, expected_output)
32 changes: 32 additions & 0 deletions tests/test_apply.py
Original file line number Diff line number Diff line change
Expand Up @@ -739,6 +739,38 @@ def test_args(self):
with pytest.raises(TypeError, match="missing 1 required positional argument: 'wires'"):
qml.device(device_name)

@pytest.mark.skipif(
device_name != "lightning.qubit",
reason="Only Lightning Qubit support dynamic qubit allocation",
)
def test_dynamic_allocate_qubit(self, qubit_device, tol):
"""Test the dynamic allocation of qubits in Lightning devices"""

dev = qubit_device(None)
dev_default = qml.device("default.qubit")

def circuit1():
qml.Identity(wires=2)
qml.Identity(wires=1)
qml.Hadamard(wires=0)
return qml.state()

def circuit2():
qml.Identity(wires=3)
qml.Identity(wires=1)
qml.Hadamard(wires=0)
return qml.state()

def circuit3():
qml.Identity(wires=3)
qml.Hadamard(wires=0)
return qml.state()

for circuit in [circuit1, circuit2, circuit3]:
results = qml.qnode(dev)(circuit)()
expected = qml.qnode(dev_default)(circuit)()
assert np.allclose(results, expected, atol=tol, rtol=0)

@pytest.mark.skipif(
device_name == "lightning.tensor", reason="lightning.tensor requires num_wires > 1"
)
Expand Down

0 comments on commit 4835b1f

Please sign in to comment.