Skip to content

Commit

Permalink
Merge pull request #23 from yasuhito/refactor/improve-dispatch-readab…
Browse files Browse the repository at this point in the history
…ility

refactor: improve TranspilerDispatcher flow by extracting helper methods
  • Loading branch information
snuffkin authored Dec 26, 2024
2 parents da6d883 + 0e81156 commit f18069c
Show file tree
Hide file tree
Showing 6 changed files with 231 additions and 136 deletions.
8 changes: 2 additions & 6 deletions src/tranqu/device_type_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,6 @@
from .tranqu_error import TranquError


class DeviceLibNotFoundError(TranquError):
"""Error raised when device library cannot be detected."""


class DeviceLibraryAlreadyRegisteredError(TranquError):
"""Raised when attempting to register a device library that already exists."""

Expand Down Expand Up @@ -42,8 +38,8 @@ def register_type(

self._type_registry[device_type] = device_lib

def detect_lib(self, device: Any) -> str | None: # noqa: ANN401
"""Detect library based on device type.
def resolve_lib(self, device: Any) -> str | None: # noqa: ANN401
"""Resolve library based on device type.
Args:
device (Any): Device to inspect
Expand Down
4 changes: 2 additions & 2 deletions src/tranqu/program_type_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,8 @@ def register_type(

self._type_registry[program_type] = program_lib

def detect_lib(self, program: Any) -> str | None: # noqa: ANN401
"""Detect the library identifier for a given program instance.
def resolve_lib(self, program: Any) -> str | None: # noqa: ANN401
"""Resolve the library identifier for a given program instance.
Args:
program (Any): Program instance to inspect
Expand Down
194 changes: 115 additions & 79 deletions src/tranqu/transpiler_dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@ class TranspilerDispatcherError(TranquError):
"""Base class for errors related to the transpiler dispatcher."""


class ProgramLibNotFoundError(TranquError):
"""Error when program library cannot be detected."""
class ProgramLibResolutionError(TranspilerDispatcherError):
"""Error raised when program library cannot be resolved."""


class ProgramNotSpecifiedError(TranspilerDispatcherError):
Expand Down Expand Up @@ -103,50 +103,22 @@ def dispatch( # noqa: PLR0913 PLR0917
Raises:
ProgramNotSpecifiedError: Raised when no program is specified.
ProgramLibNotFoundError: Raised when program library cannot be detected.
TranspilerLibNotSpecifiedError: Raised when no transpiler library
is specified.
DeviceNotSpecifiedError: Raised when a device library is specified
but no device is specified.
"""
if program is None:
msg = "No program specified. Please specify a valid quantum circuit."
raise ProgramNotSpecifiedError(msg)
if transpiler_lib is None:
transpiler_lib = self._transpiler_manager.get_default_transpiler_lib()
if transpiler_lib is None:
msg = (
"No transpiler library specified."
" Please specify a transpiler to use."
)
raise TranspilerLibNotSpecifiedError(msg)

detected_program_lib = (
self._detect_program_lib(program) if program_lib is None else program_lib
)
if detected_program_lib is None:
msg = (
"Could not detect program library. Please either "
"specify program_lib or register the program type "
"using register_program_type()."
)
raise ProgramLibNotFoundError(msg)

detected_device_lib = (
self._detect_device_lib(device) if device_lib is None else device_lib
)
if detected_device_lib is not None and device is None:
msg = "Device library is specified but no device is specified."
raise DeviceNotSpecifiedError(msg)

transpiler = self._transpiler_manager.fetch_transpiler(transpiler_lib)
selected_transpiler_lib = self._select_transpiler_lib(transpiler_lib)
resolved_program_lib = self._resolve_program_lib(program, program_lib)
resolved_device_lib = self._resolve_device_lib(device, device_lib)
transpiler = self._transpiler_manager.fetch_transpiler(selected_transpiler_lib)

converted_program = self._convert_program(
program, detected_program_lib, transpiler_lib
program, from_lib=resolved_program_lib, to_lib=selected_transpiler_lib
)
converted_device = self._convert_device(
device, detected_device_lib, transpiler_lib
device, from_lib=resolved_device_lib, to_lib=selected_transpiler_lib
)

transpile_result = transpiler.transpile(
Expand All @@ -157,62 +129,135 @@ def dispatch( # noqa: PLR0913 PLR0917

transpile_result.transpiled_program = self._convert_program(
transpile_result.transpiled_program,
transpiler_lib,
detected_program_lib,
from_lib=selected_transpiler_lib,
to_lib=resolved_program_lib,
)

return transpile_result

def _detect_program_lib(self, program: Any) -> str | None: # noqa: ANN401
return self._program_type_manager.detect_lib(program)
def _select_transpiler_lib(self, transpiler_lib: str | None) -> str:
selected_lib = transpiler_lib

def _detect_device_lib(self, device: Any) -> str | None: # noqa: ANN401
return self._device_type_manager.detect_lib(device)
if selected_lib is None:
selected_lib = self._transpiler_manager.get_default_transpiler_lib()

def _convert_program(self, program: Any, from_lib: str, to_lib: Any) -> Any: # noqa: ANN401
if self._program_converter_manager.has_converter(from_lib, to_lib):
return self._program_converter_manager.fetch_converter(
from_lib,
to_lib,
).convert(program)
if selected_lib is None:
msg = "No transpiler library specified. Please specify a transpiler to use."
raise TranspilerLibNotSpecifiedError(msg)

can_convert_to_qiskit = self._program_converter_manager.has_converter(
from_lib,
"qiskit",
)
can_convert_to_target = self._program_converter_manager.has_converter(
"qiskit",
to_lib,
)
if not (can_convert_to_qiskit and can_convert_to_target):
return selected_lib

def _resolve_program_lib(self, program: Any, program_lib: str | None) -> str: # noqa: ANN401
if program_lib is None:
resolved_lib = self._program_type_manager.resolve_lib(program)
else:
resolved_lib = program_lib

if resolved_lib is None:
msg = (
"Could not resolve program library. Please either "
"specify program_lib or register the program type "
"using register_program_type()."
)
raise ProgramLibResolutionError(msg)

return resolved_lib

def _resolve_device_lib(
self,
device: Any | None, # noqa: ANN401
device_lib: str | None,
) -> str | None:
if device is None and device_lib is not None:
msg = "Device library is specified but no device is specified."
raise DeviceNotSpecifiedError(msg)

if device_lib is None:
resolved_lib = self._device_type_manager.resolve_lib(device)
else:
resolved_lib = device_lib

return resolved_lib

def _convert_program(self, program: Any, *, from_lib: str, to_lib: str) -> Any: # noqa: ANN401
if self._can_convert_program_directly(from_lib=from_lib, to_lib=to_lib):
direct_converter = self._program_converter_manager.fetch_converter(
from_lib=from_lib,
to_lib=to_lib,
)
return direct_converter.convert(program)

if not self._can_convert_program_via_qiskit(from_lib=from_lib, to_lib=to_lib):
msg = (
f"No ProgramConverter path found to convert from {from_lib} to {to_lib}"
)
raise ProgramConversionPathNotFoundError(msg)

return self._program_converter_manager.fetch_converter(
"qiskit",
to_lib,
).convert(
self._program_converter_manager.fetch_converter(from_lib, "qiskit").convert(
program,
),
to_qiskit_converter = self._program_converter_manager.fetch_converter(
from_lib, "qiskit"
)
from_qiskit_converter = self._program_converter_manager.fetch_converter(
"qiskit", to_lib
)

qiskit_program = to_qiskit_converter.convert(program)
return from_qiskit_converter.convert(qiskit_program)

def _can_convert_program_directly(self, *, from_lib: str, to_lib: str) -> bool:
return self._program_converter_manager.has_converter(
from_lib=from_lib, to_lib=to_lib
)

def _can_convert_program_via_qiskit(self, *, from_lib: str, to_lib: str) -> bool:
can_convert_to_qiskit = self._program_converter_manager.has_converter(
from_lib=from_lib,
to_lib="qiskit",
)
can_convert_to_target = self._program_converter_manager.has_converter(
from_lib="qiskit",
to_lib=to_lib,
)
return can_convert_to_qiskit and can_convert_to_target

def _convert_device(
self,
device: Any | None, # noqa: ANN401
*,
from_lib: str | None,
to_lib: Any, # noqa: ANN401
to_lib: str,
) -> Any | None: # noqa: ANN401
if device is None or from_lib is None:
if device is None:
return None

if from_lib is None:
return device

if self._device_converter_manager.has_converter(from_lib, to_lib):
return self._device_converter_manager.fetch_converter(
if self._can_convert_device_directly(from_lib=from_lib, to_lib=to_lib):
direct_converter = self._device_converter_manager.fetch_converter(
from_lib,
to_lib,
).convert(device)
)
return direct_converter.convert(device)

if not self._can_convert_device_via_qiskit(from_lib=from_lib, to_lib=to_lib):
msg = (
f"No DeviceConverter path found to convert from {from_lib} to {to_lib}"
)
raise DeviceConversionPathNotFoundError(msg)

to_qiskit_converter = self._device_converter_manager.fetch_converter(
from_lib, "qiskit"
)
from_qiskit_converter = self._device_converter_manager.fetch_converter(
"qiskit", to_lib
)
qiskit_device = to_qiskit_converter.convert(device)
return from_qiskit_converter.convert(qiskit_device)

def _can_convert_device_directly(self, *, from_lib: str, to_lib: str) -> bool:
return self._device_converter_manager.has_converter(from_lib, to_lib)

def _can_convert_device_via_qiskit(self, *, from_lib: str, to_lib: str) -> bool:
can_convert_to_qiskit = self._device_converter_manager.has_converter(
from_lib,
"qiskit",
Expand All @@ -221,14 +266,5 @@ def _convert_device(
"qiskit",
to_lib,
)
if not (can_convert_to_qiskit and can_convert_to_target):
msg = (
f"No DeviceConverter path found to convert from {from_lib} to {to_lib}"
)
raise DeviceConversionPathNotFoundError(msg)

return self._device_converter_manager.fetch_converter("qiskit", to_lib).convert(
self._device_converter_manager.fetch_converter(from_lib, "qiskit").convert(
device,
),
)
return can_convert_to_qiskit and can_convert_to_target
61 changes: 36 additions & 25 deletions tests/tranqu/test_device_type_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,49 +10,58 @@ class DummyDevice:
"""Dummy device class for testing"""


class TestDeviceTypeManager:
def setup_method(self):
self.manager = DeviceTypeManager()
@pytest.fixture
def device_manager() -> DeviceTypeManager:
return DeviceTypeManager()


def test_register_type(self):
self.manager.register_type("dummy", DummyDevice)
class TestDeviceTypeManager:
def test_register_type(self, device_manager: DeviceTypeManager) -> None:
device_manager.register_type("dummy", DummyDevice)
device = DummyDevice()

result = self.manager.detect_lib(device)
result = device_manager.resolve_lib(device)

assert result == "dummy"

def test_detect_lib_returns_none_for_unregistered_type(self):
def test_resolve_lib_returns_none_for_unregistered_type(
self, device_manager: DeviceTypeManager
) -> None:
device = DummyDevice()

result = self.manager.detect_lib(device)
result = device_manager.resolve_lib(device)

assert result is None

def test_detect_lib_with_multiple_registrations(self):
def test_resolve_lib_with_multiple_registrations(
self, device_manager: DeviceTypeManager
) -> None:
class AnotherDummyDevice:
pass

self.manager.register_type("dummy1", DummyDevice)
self.manager.register_type("dummy2", AnotherDummyDevice)
device_manager.register_type("dummy1", DummyDevice)
device_manager.register_type("dummy2", AnotherDummyDevice)

device1 = DummyDevice()
device2 = AnotherDummyDevice()

assert self.manager.detect_lib(device1) == "dummy1"
assert self.manager.detect_lib(device2) == "dummy2"
assert device_manager.resolve_lib(device1) == "dummy1"
assert device_manager.resolve_lib(device2) == "dummy2"

def test_register_type_multiple_times(self):
self.manager.register_type("dummy", DummyDevice)
self.manager.register_type("another_dummy", DummyDevice)
def test_register_type_multiple_times(
self, device_manager: DeviceTypeManager
) -> None:
device_manager.register_type("dummy", DummyDevice)
device_manager.register_type("another_dummy", DummyDevice)

device = DummyDevice()

# The last registered library identifier is returned
assert self.manager.detect_lib(device) == "another_dummy"
assert device_manager.resolve_lib(device) == "another_dummy"

def test_register_type_raises_error_when_library_already_registered(self):
self.manager.register_type("dummy", DummyDevice)
def test_register_type_raises_error_when_library_already_registered(
self, device_manager: DeviceTypeManager
) -> None:
device_manager.register_type("dummy", DummyDevice)

with pytest.raises(
DeviceLibraryAlreadyRegisteredError,
Expand All @@ -61,12 +70,14 @@ def test_register_type_raises_error_when_library_already_registered(self):
"Use allow_override=True to force registration."
),
):
self.manager.register_type("dummy", DummyDevice)
device_manager.register_type("dummy", DummyDevice)

def test_register_type_with_allow_override(self):
self.manager.register_type("dummy", DummyDevice)
def test_register_type_with_allow_override(
self, device_manager: DeviceTypeManager
) -> None:
device_manager.register_type("dummy", DummyDevice)

self.manager.register_type("dummy", DummyDevice, allow_override=True)
device_manager.register_type("dummy", DummyDevice, allow_override=True)

device = DummyDevice()
assert self.manager.detect_lib(device) == "dummy"
assert device_manager.resolve_lib(device) == "dummy"
Loading

0 comments on commit f18069c

Please sign in to comment.