Skip to content

Commit

Permalink
Merge pull request #11 from yasuhito/feature/add-program-and-device-t…
Browse files Browse the repository at this point in the history
…ype-registry

feat: add program and device type registry
  • Loading branch information
snuffkin authored Dec 16, 2024
2 parents 540a970 + 732172f commit d4ef9c9
Show file tree
Hide file tree
Showing 8 changed files with 380 additions and 24 deletions.
40 changes: 40 additions & 0 deletions src/tranqu/device_type_manager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
from typing import Any

from .tranqu_error import TranquError


class DeviceLibNotFoundError(TranquError):
"""Error raised when device library cannot be detected."""


class DeviceTypeManager:
"""Class that manages mapping between device types and library identifiers."""

def __init__(self) -> None:
self._type_registry: dict[type, str] = {}

def register_type(self, device_lib: str, device_type: type) -> None:
"""Register a device type and its library identifier.
Args:
device_lib (str): Library identifier (e.g., "qiskit", "oqtopus")
device_type (Type): Device type class to register
"""
self._type_registry[device_type] = device_lib

def detect_lib(self, device: Any) -> str | None: # noqa: ANN401
"""Detect library based on device type.
Args:
device (Any): Device to inspect
Returns:
str | None: Library identifier for registered device type, None otherwise
"""
for device_type, lib in self._type_registry.items():
if isinstance(device, device_type):
return lib

return None
34 changes: 34 additions & 0 deletions src/tranqu/program_type_manager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
from typing import Any


class ProgramTypeManager:
"""Class that manages the mapping between program types and library identifiers."""

def __init__(self) -> None:
self._type_registry: dict[type, str] = {}

def register_type(self, program_lib: str, program_type: type) -> None:
"""Register a program type and its library identifier.
Args:
program_lib (str): Library identifier (e.g., "qiskit", "tket")
program_type (Type): Program type class to register
"""
self._type_registry[program_type] = program_lib

def detect_lib(self, program: Any) -> str | None: # noqa: ANN401
"""Detect the library based on the program type.
Args:
program (Any): Program to inspect
Returns:
str | None: Library identifier for registered program type, None otherwise
"""
for program_type, lib in self._type_registry.items():
if isinstance(program, program_type):
return lib

return None
64 changes: 59 additions & 5 deletions src/tranqu/tranqu.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,11 +77,16 @@

from typing import TYPE_CHECKING, Any

from pytket import Circuit as TketCircuit
from qiskit import QuantumCircuit as QiskitCircuit # type: ignore[import-untyped]
from qiskit.providers import BackendV2 # type: ignore[import-untyped]

from .device_converter import (
DeviceConverter,
DeviceConverterManager,
OqtoqusToQiskitDeviceConverter,
)
from .device_type_manager import DeviceTypeManager
from .program_converter import (
Openqasm3ToQiskitProgramConverter,
Openqasm3ToTketProgramConverter,
Expand All @@ -92,6 +97,7 @@
TketToOpenqasm3ProgramConverter,
TketToQiskitProgramConverter,
)
from .program_type_manager import ProgramTypeManager
from .transpiler import (
QiskitTranspiler,
TranspilerManager,
Expand All @@ -113,16 +119,20 @@ def __init__(self) -> None:
self._program_converter_manager = ProgramConverterManager()
self._device_converter_manager = DeviceConverterManager()
self._transpiler_manager = TranspilerManager()
self._program_type_manager = ProgramTypeManager()
self._device_type_manager = DeviceTypeManager()

self._register_builtin_program_converters()
self._register_builtin_device_converters()
self._register_builtin_transpilers()
self._register_builtin_program_types()
self._register_builtin_device_types()

def transpile( # noqa: PLR0913
self,
program: Any, # noqa: ANN401
program_lib: str,
transpiler_lib: str,
program_lib: str | None = None,
*,
transpiler_options: dict[str, Any] | None = None,
device: Any | None = None, # noqa: ANN401
Expand All @@ -132,23 +142,24 @@ def transpile( # noqa: PLR0913
Args:
program (Any): The program to be transformed.
program_lib (str): The library or format of the program.
transpiler_lib (str): The name of the transpiler to be used.
program_lib (str | None): The library or format of the program. If None,
will attempt to detect based on program type.
transpiler_options (dict[str, Any]): Options passed to the transpiler.
device (Any | None): Information about the device on which
the program will be executed.
device_lib (str | None): Specifies the type of the device.
Returns:
TranspileResult: The result of the transpilation, including
the transpiled program, various statistical information,
and mapping between virtual and physical quantum bits.
TranspileResult: The result of the transpilation.
"""
dispatcher = TranspilerDispatcher(
self._transpiler_manager,
self._program_converter_manager,
self._device_converter_manager,
self._program_type_manager,
self._device_type_manager,
)

return dispatcher.dispatch(
Expand Down Expand Up @@ -240,6 +251,42 @@ def register_device_converter(
converter,
)

def register_program_type(self, program_lib: str, program_type: type) -> None:
"""Register a mapping between a program type and its library identifier.
This method allows automatic detection of the program library based on the
program's type when calling transpile().
Args:
program_lib (str): The identifier for the program library
(e.g., "qiskit", "tket")
program_type (type): The type class to be associated with the library
Examples:
To register Qiskit's QuantumCircuit type:
tranqu.register_program_type("qiskit", QuantumCircuit)
"""
self._program_type_manager.register_type(program_lib, program_type)

def register_device_type(self, device_lib: str, device_type: type) -> None:
"""Register a mapping between a device type and its library identifier.
This method enables automatic detection of the device library based on
the device type when calling transpile().
Args:
device_lib (str): The identifier for the device library
(e.g., "qiskit", "oqtopus")
device_type (type): The type class to be associated with the library
Examples:
To register Qiskit's backend type:
tranqu.register_device_type("qiskit", BackendV2)
"""
self._device_type_manager.register_type(device_lib, device_type)

def _register_builtin_program_converters(self) -> None:
self.register_program_converter(
"openqasm3",
Expand Down Expand Up @@ -291,3 +338,10 @@ def _register_builtin_device_converters(self) -> None:

def _register_builtin_transpilers(self) -> None:
self.register_transpiler("qiskit", QiskitTranspiler())

def _register_builtin_program_types(self) -> None:
self.register_program_type("qiskit", QiskitCircuit)
self.register_program_type("tket", TketCircuit)

def _register_builtin_device_types(self) -> None:
self.register_device_type("qiskit", BackendV2)
64 changes: 48 additions & 16 deletions src/tranqu/transpiler_dispatcher.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from typing import Any

from .device_converter import DeviceConverterManager
from .device_type_manager import DeviceTypeManager
from .program_converter import ProgramConverterManager
from .program_type_manager import ProgramTypeManager
from .tranqu_error import TranquError
from .transpile_result import TranspileResult
from .transpiler import TranspilerManager
Expand All @@ -11,6 +13,10 @@ class TranspilerDispatcherError(TranquError):
"""Base class for errors related to the transpiler dispatcher."""


class ProgramLibNotFoundError(TranquError):
"""Error when program library cannot be detected."""


class ProgramNotSpecifiedError(TranspilerDispatcherError):
"""Error raised when no program is specified."""

Expand Down Expand Up @@ -51,6 +57,10 @@ class TranspilerDispatcher:
quantum programs between different libraries.
device_converter_manager (DeviceConverterManager): Handles conversion of
device specifications between different libraries.
program_type_manager (ProgramTypeManager): Manages detection of program types
and their corresponding libraries.
device_type_manager (DeviceTypeManager): Manages detection of device types
and their corresponding libraries.
"""

Expand All @@ -59,17 +69,21 @@ def __init__(
transpiler_manager: TranspilerManager,
program_converter_manager: ProgramConverterManager,
device_converter_manager: DeviceConverterManager,
program_type_manager: ProgramTypeManager,
device_type_manager: DeviceTypeManager,
) -> None:
self._transpiler_manager = transpiler_manager
self._program_converter_manager = program_converter_manager
self._device_converter_manager = device_converter_manager
self._program_type_manager = program_type_manager
self._device_type_manager = device_type_manager

def dispatch( # noqa: PLR0913 PLR0917
self,
program: Any, # noqa: ANN401
program_lib: str,
program_lib: str | None,
transpiler_lib: str,
transpiler_options: dict | None,
transpiler_options: dict[str, Any] | None,
device: Any | None, # noqa: ANN401
device_lib: str | None,
) -> TranspileResult:
Expand All @@ -89,7 +103,7 @@ def dispatch( # noqa: PLR0913 PLR0917
Raises:
ProgramNotSpecifiedError: Raised when no program is specified.
ProgramLibNotSpecifiedError: Raised when no program library is specified.
ProgramLibNotFoundError: Raised when program library cannot be detected.
TranspilerLibNotSpecifiedError: Raised when no transpiler library
is specified.
DeviceNotSpecifiedError: Raised when a device library is specified
Expand All @@ -99,23 +113,36 @@ def dispatch( # noqa: PLR0913 PLR0917
if program is None:
msg = "No program specified. Please specify a valid quantum circuit."
raise ProgramNotSpecifiedError(msg)
if program_lib is None:
msg = "No program library specified. Please specify a program format "
"('qiskit', 'openqasm3', 'tket', etc.)."
raise ProgramLibNotSpecifiedError(msg)
if transpiler_lib is None:
msg = "No transpiler library specified. Please specify a transpiler to use "
"('qiskit', 'tket', etc.)."
msg = "No transpiler library specified. Please specify a transpiler to use."
raise TranspilerLibNotSpecifiedError(msg)
if device_lib is not None and device is None:
msg = "Device library is specified but no device is specified. "
"Please specify a device."

detected_program_lib = (
self._detect_program_lib(program) if program_lib is None else program_lib
)
if detected_program_lib is None:
msg = (
"Could not detect program library. Please either "
"specify program_lib or register the program type "
"using register_program_type()."
)
raise ProgramLibNotFoundError(msg)

detected_device_lib = (
self._detect_device_lib(device) if device_lib is None else device_lib
)
if detected_device_lib is not None and device is None:
msg = "Device library is specified but no device is specified."
raise DeviceNotSpecifiedError(msg)

transpiler = self._transpiler_manager.fetch_transpiler(transpiler_lib)

converted_program = self._convert_program(program, program_lib, transpiler_lib)
converted_device = self._convert_device(device, device_lib, transpiler_lib)
converted_program = self._convert_program(
program, detected_program_lib, transpiler_lib
)
converted_device = self._convert_device(
device, detected_device_lib, transpiler_lib
)

transpile_result = transpiler.transpile(
converted_program,
Expand All @@ -126,11 +153,16 @@ def dispatch( # noqa: PLR0913 PLR0917
transpile_result.transpiled_program = self._convert_program(
transpile_result.transpiled_program,
transpiler_lib,
program_lib,
detected_program_lib,
)

return transpile_result

def _detect_program_lib(self, program: Any) -> str | None: # noqa: ANN401
return self._program_type_manager.detect_lib(program)

def _detect_device_lib(self, device: Any) -> str | None: # noqa: ANN401
return self._device_type_manager.detect_lib(device)

def _convert_program(self, program: Any, from_lib: str, to_lib: Any) -> Any: # noqa: ANN401
if self._program_converter_manager.has_converter(from_lib, to_lib):
return self._program_converter_manager.fetch_converter(
Expand Down
47 changes: 47 additions & 0 deletions tests/tranqu/test_device_type_manager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
from tranqu.device_type_manager import DeviceTypeManager


class DummyDevice:
"""Dummy device class for testing"""


class TestDeviceTypeManager:
def setup_method(self):
self.manager = DeviceTypeManager()

def test_register_type(self):
self.manager.register_type("dummy", DummyDevice)
device = DummyDevice()

result = self.manager.detect_lib(device)

assert result == "dummy"

def test_detect_lib_returns_none_for_unregistered_type(self):
device = DummyDevice()

result = self.manager.detect_lib(device)

assert result is None

def test_detect_lib_with_multiple_registrations(self):
class AnotherDummyDevice:
pass

self.manager.register_type("dummy1", DummyDevice)
self.manager.register_type("dummy2", AnotherDummyDevice)

device1 = DummyDevice()
device2 = AnotherDummyDevice()

assert self.manager.detect_lib(device1) == "dummy1"
assert self.manager.detect_lib(device2) == "dummy2"

def test_register_type_multiple_times(self):
self.manager.register_type("dummy", DummyDevice)
self.manager.register_type("another_dummy", DummyDevice)

device = DummyDevice()

# The last registered library identifier is returned
assert self.manager.detect_lib(device) == "another_dummy"
Loading

0 comments on commit d4ef9c9

Please sign in to comment.