diff --git a/dissect/cstruct/tools/__init__.py b/dissect/cstruct/tools/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/dissect/cstruct/tools/stubify.py b/dissect/cstruct/tools/stubify.py new file mode 100644 index 0000000..8268da7 --- /dev/null +++ b/dissect/cstruct/tools/stubify.py @@ -0,0 +1,184 @@ +from __future__ import annotations + +import importlib +import importlib.util +import logging +from argparse import ArgumentParser +from pathlib import Path +from textwrap import indent +from types import FunctionType, ModuleType +from typing import TYPE_CHECKING, Any + +import dissect.cstruct.types as types +from dissect.cstruct import cstruct + +if TYPE_CHECKING: + from collections.abc import Iterable + + +log = logging.getLogger(__name__) + + +def load_module(path: Path, base_path: Path) -> ModuleType | None: + module = None + try: + relative_path = path.relative_to(base_path) + module_tuple = (*relative_path.parent.parts, relative_path.stem) + spec = importlib.util.spec_from_file_location(".".join(module_tuple), path) + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + except Exception as e: + log.exception("Unable to import %s", path) + log.debug("Error while trying to import module %s", path, exc_info=e) + + return module + + +def to_type(_type: type | Any) -> type: + if not isinstance(_type, type): + _type = type(_type) + return _type + + +def stubify_file(path: Path, base_path: Path) -> str: + tmp_module = load_module(path, base_path) + if tmp_module is None: + return "" + + if not hasattr(tmp_module, "cstruct"): + return "" + + all_types = types.__all__.copy() + all_types.sort() + all_types.append("") + + cstruct_types = indent(",\n".join(all_types), prefix=" " * 4) + result = [ + "from __future__ import annotations\n", + "from typing import overload, BinaryIO\n", + "from typing_extensions import TypeAlias\n", + "from dissect.cstruct import cstruct", + f"from dissect.cstruct.types import (\n{cstruct_types})\n", + ] + + prev_entries = len(result) + + for name, variable in tmp_module.__dict__.items(): + if name.startswith("__"): + continue + + if isinstance(variable, ModuleType): + result.append(f"import {name}") + elif isinstance(variable, cstruct): + result.append(stubify_cstruct(variable, name)) + elif isinstance(variable, (bytes, bytearray, str, int, float, dict, list, tuple)): + result.append(f"{name}: {type(variable).__name__}") + elif isinstance(variable, FunctionType): + anno = variable.__annotations__ + _items = list(anno.items())[:-1] + + args = ", ".join(f"{name}: {to_type(_type).__name__}" for (name, _type) in _items) + + return_value = repr(to_type(anno.get("return")).__name__) + signature = f"def {name}({''.join(args)}) -> {return_value}:" + if variable.__doc__: + result.append(signature) + result.append(f' """{variable.__doc__}"""') + result.append(" ...") + else: + result.append(f"{signature} ...") + elif "dissect.cstruct" in variable.__module__: + if hasattr(variable, "cs"): + result.append(f"{name}: {variable.cs.__type_def_name__}.{variable.__name__}") + elif isinstance(variable, type): + result.append(f"from {variable.__module__} import {name}") + else: + result.append(f"{name}: {variable.__class__.__name__}") + + if prev_entries == len(result): + return "" + + # Empty line at the end of the file + result.append("") + return "\n".join(result) + + +def stubify_cstruct(c_structure: cstruct, name: str = "", ignore_type_defs: Iterable[str] | None = None) -> str: + ignore_type_defs = ignore_type_defs or [] + + result = [] + indentation = "" + if name: + result.append(f"class {name}(cstruct):") + indentation = " " * 4 + c_structure.__type_def_name__ = name + + prev_length = len(result) + for const, value in c_structure.consts.items(): + result.append(indent(f"{const}: {type(value).__name__} = ...", prefix=indentation)) + + if type_defs := stubify_typedefs(c_structure, ignore_type_defs, indentation): + result.append(type_defs) + + if prev_length == len(result): + # an empty definition, add elipses + result.append(indent("...", prefix=indentation)) + + return "\n".join(result) + + +def stubify_typedefs(c_structure: cstruct, ignore_type_defs: Iterable[str] | None = None, indentation: str = "") -> str: + ignore_type_defs = ignore_type_defs or [] + + result = [] + for name, type_def in c_structure.typedefs.items(): + if name in ignore_type_defs: + continue + + if isinstance(type_def, types.MetaType) and (text := type_def.to_type_stub(name)): + result.append(indent(text, prefix=indentation)) + + return "\n".join(result) + + +def setup_logger(verbosity: int) -> None: + level = logging.INFO + if verbosity >= 1: + level = logging.DEBUG + + logging.basicConfig(level=level) + + +def main() -> None: + description = """ + Create stub files for cstruct definitions. + + These stub files are in a `.pyi` format and provides type information to cstruct definitions. + """ + + parser = ArgumentParser("stubify", description=description) + parser.add_argument("path", type=Path) + parser.add_argument("-v", "--verbose", action="count", default=0) + args = parser.parse_args() + + setup_logger(args.verbose) + + file_path: Path = args.path + + iterator = file_path.rglob("*.py") + if file_path.is_file(): + iterator = [file_path] + + for file in iterator: + if file.is_file() and file.suffix == ".py": + stub = stubify_file(file, file_path) + if not stub: + continue + + with file.with_suffix(".pyi").open("wt") as output_file: + log.info("Writing stub of file %s to %s", file, output_file.name) + output_file.write(stub) + + +if __name__ == "__main__": + main() diff --git a/dissect/cstruct/types/base.py b/dissect/cstruct/types/base.py index 1fff7d1..96f0e35 100644 --- a/dissect/cstruct/types/base.py +++ b/dissect/cstruct/types/base.py @@ -178,6 +178,22 @@ def _write_0(cls, stream: BinaryIO, array: list[BaseType]) -> int: """ return cls._write_array(stream, [*array, cls.__default__()]) + def _class_stub(cls) -> str: + return f"class {cls.__name__}({cls.__base__.__name__}):" + + def _type_stub(cls, name: str = "", underscore: bool = False) -> str: + cls_name = cls.__name__ + if underscore: + cls_name = f"_{cls_name}" + + if cls.__name__ in cls.cs.typedefs and (cs_name := getattr(cls.cs, "__type_def_name__", "")): + cls_name = f"{cs_name}.{cls_name}" + + return f"{name}: {cls_name}" + + def to_type_stub(cls, name: str) -> str: + return "" + class _overload: """Descriptor to use on the ``write`` and ``dumps`` methods on cstruct types. @@ -244,6 +260,14 @@ def _read(cls, stream: BinaryIO, context: dict[str, Any] | None = None) -> Array return cls.type._read_array(stream, num, context) + def default(cls) -> BaseType: + return type.__call__( + cls, [cls.type.default() for _ in range(0 if cls.dynamic or cls.null_terminated else cls.num_entries)] + ) + + def _type_stub(cls, name: str = "", underscore: bool = False) -> str: + return f"{name}: {cls.__base__.__name__}" + class Array(list, BaseType, metaclass=ArrayMetaType): """Implements a fixed or dynamically sized array type. @@ -270,6 +294,15 @@ def _write(cls, stream: BinaryIO, data: list[Any]) -> int: return cls.type._write_array(stream, data) + @classmethod + def _type_stub(cls, name: str = "", underscore: bool = False) -> str: + cls_name = cls.type.__name__ + + if cls_name in cls.cs.typedefs and (cs_name := getattr(cls.cs, "__type_def_name__", "")): + cls_name = f"{cs_name}.{cls_name}" + + return f"{name}: {cls.__base__.__name__}[{cls_name}]" + def _is_readable_type(value: Any) -> bool: return hasattr(value, "read") diff --git a/dissect/cstruct/types/enum.py b/dissect/cstruct/types/enum.py index 869c9fa..5d231df 100644 --- a/dissect/cstruct/types/enum.py +++ b/dissect/cstruct/types/enum.py @@ -84,6 +84,15 @@ def _write_0(cls, stream: BinaryIO, array: list[BaseType]) -> int: data = [entry.value if isinstance(entry, Enum) else entry for entry in array] return cls._write_array(stream, [*data, cls.type.__default__()]) + def _class_stub(cls) -> str: + return f"class {cls.__name__}({cls.__base__.__name__}, {cls.type.__name__}):" + + def to_type_stub(cls, name: str = "") -> str: + result = [cls._class_stub()] + result.extend(f" {key} = ..." for key in cls.__members__) + + return "\n".join(result) + def _fix_alias_members(cls: type[Enum]) -> None: # Emulate aenum NoAlias behaviour diff --git a/dissect/cstruct/types/packed.py b/dissect/cstruct/types/packed.py index 493e85c..a6b5759 100644 --- a/dissect/cstruct/types/packed.py +++ b/dissect/cstruct/types/packed.py @@ -2,7 +2,7 @@ from functools import lru_cache from struct import Struct -from typing import Any, BinaryIO +from typing import Any, BinaryIO, Generic, TypeVar from dissect.cstruct.types.base import EOF, BaseType @@ -12,17 +12,20 @@ def _struct(endian: str, packchar: str) -> Struct: return Struct(f"{endian}{packchar}") -class Packed(BaseType): +T = TypeVar("T", int, float) + + +class Packed(BaseType, Generic[T]): """Packed type for Python struct (un)packing.""" packchar: str @classmethod - def _read(cls, stream: BinaryIO, context: dict[str, Any] | None = None) -> Packed: + def _read(cls, stream: BinaryIO, context: dict[str, Any] | None = None) -> Packed[T]: return cls._read_array(stream, 1, context)[0] @classmethod - def _read_array(cls, stream: BinaryIO, count: int, context: dict[str, Any] | None = None) -> list[Packed]: + def _read_array(cls, stream: BinaryIO, count: int, context: dict[str, Any] | None = None) -> list[Packed[T]]: if count == EOF: data = stream.read() length = len(data) @@ -39,7 +42,7 @@ def _read_array(cls, stream: BinaryIO, count: int, context: dict[str, Any] | Non return [cls.__new__(cls, value) for value in fmt.unpack(data)] @classmethod - def _read_0(cls, stream: BinaryIO, context: dict[str, Any] | None = None) -> Packed: + def _read_0(cls, stream: BinaryIO, context: dict[str, Any] | None = None) -> Packed[T]: result = [] fmt = _struct(cls.cs.endian, cls.packchar) @@ -57,9 +60,13 @@ def _read_0(cls, stream: BinaryIO, context: dict[str, Any] | None = None) -> Pac return result @classmethod - def _write(cls, stream: BinaryIO, data: Packed) -> int: + def _write(cls, stream: BinaryIO, data: Packed[T]) -> int: return stream.write(_struct(cls.cs.endian, cls.packchar).pack(data)) @classmethod - def _write_array(cls, stream: BinaryIO, data: list[Packed]) -> int: + def _write_array(cls, stream: BinaryIO, data: list[Packed[T]]) -> int: return stream.write(_struct(cls.cs.endian, f"{len(data)}{cls.packchar}").pack(*data)) + + @classmethod + def to_type_stub(cls, name: str) -> str: + return f"{name}: TypeAlias = Packed[{cls.__base__.__name__}]" diff --git a/dissect/cstruct/types/pointer.py b/dissect/cstruct/types/pointer.py index e5c05c3..c7245eb 100644 --- a/dissect/cstruct/types/pointer.py +++ b/dissect/cstruct/types/pointer.py @@ -1,22 +1,24 @@ from __future__ import annotations -from typing import Any, BinaryIO +from typing import Any, BinaryIO, Generic, TypeVar from dissect.cstruct.exceptions import NullPointerDereference from dissect.cstruct.types.base import BaseType, MetaType from dissect.cstruct.types.char import Char from dissect.cstruct.types.void import Void +T = TypeVar("T", bound=MetaType) -class Pointer(int, BaseType): + +class Pointer(int, BaseType, Generic[T]): """Pointer to some other type.""" - type: MetaType + type: T _stream: BinaryIO | None _context: dict[str, Any] | None _value: BaseType - def __new__(cls, value: int, stream: BinaryIO | None, context: dict[str, Any] | None = None) -> Pointer: # noqa: PYI034 + def __new__(cls, value: int, stream: BinaryIO | None, context: dict[str, Any] | None = None) -> Pointer[T]: obj = super().__new__(cls, value) obj._stream = stream obj._context = context @@ -70,15 +72,15 @@ def __default__(cls) -> Pointer: return cls.__new__(cls, cls.cs.pointer.__default__(), None, None) @classmethod - def _read(cls, stream: BinaryIO, context: dict[str, Any] | None = None) -> Pointer: + def _read(cls, stream: BinaryIO, context: dict[str, Any] | None = None) -> Pointer[T]: return cls.__new__(cls, cls.cs.pointer._read(stream, context), stream, context) @classmethod def _write(cls, stream: BinaryIO, data: int) -> int: return cls.cs.pointer._write(stream, data) - def dereference(self) -> Any: - if self == 0 or self._stream is None: + def dereference(self) -> T: + if self == 0: raise NullPointerDereference if self._value is None and not issubclass(self.type, Void): @@ -97,3 +99,7 @@ def dereference(self) -> Any: self._value = value return self._value + + @classmethod + def _type_stub(cls, name: str = "", underscore: bool = False) -> str: + return f"{name}: {cls.__base__.__name__}[{cls.type.__name__}]" diff --git a/dissect/cstruct/types/structure.py b/dissect/cstruct/types/structure.py index 530dd41..bdb1b00 100644 --- a/dissect/cstruct/types/structure.py +++ b/dissect/cstruct/types/structure.py @@ -6,7 +6,8 @@ from functools import lru_cache from itertools import chain from operator import attrgetter -from textwrap import dedent +from textwrap import dedent, indent +from types import FunctionType from typing import TYPE_CHECKING, Any, BinaryIO, Callable from dissect.cstruct.bitbuffer import BitBuffer @@ -38,6 +39,9 @@ def __repr__(self) -> str: bits_str = f" : {self.bits}" if self.bits else "" return f"" + def type_stub(self, underscore: bool = False) -> str: + return self.type._type_stub(self.name, underscore) + class StructureMetaType(MetaType): """Base metaclass for cstruct structure type classes.""" @@ -369,6 +373,28 @@ def commit(cls) -> None: for key, value in classdict.items(): setattr(cls, key, value) + def to_type_stub(cls, name: str = "", underscore: bool = False) -> str: + result = [f"class {'_' if underscore else ''}{cls.__name__}({cls.__base__.__name__}):"] + args = ["self"] + for field_name, field in cls.fields.items(): + _underscore = field_name == field.type.__name__ + already_defined = field.type.__name__ in cls.cs.typedefs + + if isinstance(field.type, StructureMetaType) and not already_defined: + result.append(indent(field.type.to_type_stub(underscore=_underscore), prefix=" " * 4)) + + result.append(f" {field.type_stub(_underscore)}") + + # Ignore field names from anonymous structures/unions + if field_name in cls.lookup: + args.append(f"{field.type_stub(_underscore)} = ...") + + result.append(indent("@overload", prefix=" " * 4)) + result.append(indent(f"def __init__({', '.join(args)}): ...", prefix=" " * 4)) + result.append(indent("@overload", prefix=" " * 4)) + result.append(indent("def __init__(self, fh: bytes | bytearray | BinaryIO, /): ...", prefix=" " * 4)) + return "\n".join(result) + class Structure(BaseType, metaclass=StructureMetaType): """Base class for cstruct structure type classes.""" diff --git a/out/tests/data/stub_file.pyi b/out/tests/data/stub_file.pyi new file mode 100644 index 0000000..2d5c2df --- /dev/null +++ b/out/tests/data/stub_file.pyi @@ -0,0 +1,4 @@ +from _typeshed import Incomplete + +c_def: str +c_structure: Incomplete diff --git a/pyproject.toml b/pyproject.toml index 9d83009..d2898ad 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -31,12 +31,16 @@ homepage = "https://dissect.tools" documentation = "https://docs.dissect.tools/en/latest/projects/dissect.cstruct" repository = "https://github.com/fox-it/dissect.cstruct" +[project.scripts] +cstruct-stubify = "dissect.cstruct.tools.stubify:main" + [tool.ruff] line-length = 120 required-version = ">=0.9.0" [tool.ruff.format] docstring-code-format = true +exclude = ["tests/data/*"] [tool.ruff.lint] select = [ @@ -74,6 +78,7 @@ select = [ "RUF", ] ignore = ["E203", "B904", "UP024", "ANN002", "ANN003", "ANN204", "ANN401", "SIM105", "TRY003"] +exclude = ["tests/data/*"] [tool.ruff.lint.per-file-ignores] "tests/docs/**" = ["INP001"] diff --git a/tests/data/stub_file.py b/tests/data/stub_file.py new file mode 100644 index 0000000..3be9785 --- /dev/null +++ b/tests/data/stub_file.py @@ -0,0 +1,10 @@ +from dissect.cstruct import cstruct + +c_def = """ +struct Test { + uint32 a; + uint32 b; +} +""" + +c_structure = cstruct().load(c_def) diff --git a/tests/data/stub_file.pyi b/tests/data/stub_file.pyi new file mode 100644 index 0000000..a374221 --- /dev/null +++ b/tests/data/stub_file.pyi @@ -0,0 +1,48 @@ +from __future__ import annotations + +from typing import overload, BinaryIO + +from typing_extensions import TypeAlias + +from dissect.cstruct import cstruct +from dissect.cstruct.types import ( + Array, + ArrayMetaType, + BaseType, + Char, + CharArray, + Enum, + Field, + Flag, + Int, + LEB128, + MetaType, + Packed, + Pointer, + Structure, + Union, + Void, + Wchar, + WcharArray, +) + +c_def: str +class c_structure(cstruct): + int8: TypeAlias = Packed[int] + uint8: TypeAlias = Packed[int] + int16: TypeAlias = Packed[int] + uint16: TypeAlias = Packed[int] + int32: TypeAlias = Packed[int] + uint32: TypeAlias = Packed[int] + int64: TypeAlias = Packed[int] + uint64: TypeAlias = Packed[int] + float16: TypeAlias = Packed[float] + float: TypeAlias = Packed[float] + double: TypeAlias = Packed[float] + class Test(Structure): + a: c_structure.uint32 + b: c_structure.uint32 + @overload + def __init__(self, a: c_structure.uint32 = ..., b: c_structure.uint32 = ...): ... + @overload + def __init__(self, fh: bytes | bytearray | BinaryIO, /): ... diff --git a/tests/test_stubify_functions.py b/tests/test_stubify_functions.py new file mode 100644 index 0000000..201e507 --- /dev/null +++ b/tests/test_stubify_functions.py @@ -0,0 +1,194 @@ +import textwrap +from pathlib import Path + +import pytest + +from dissect.cstruct import cstruct +from dissect.cstruct.tools.stubify import stubify_cstruct, stubify_file, stubify_typedefs +from tests.utils import absolute_path + + +@pytest.mark.parametrize( + ("definition", "name", "expected_stub"), + [ + pytest.param( + """ + struct Test { + int a; + int b; + } + """, + "Test", + """ + class Test(Structure): + a: int32 + b: int32 + @overload + def __init__(self, a: int32 = ..., b: int32 = ...): ... + @overload + def __init__(self, fh: bytes | bytearray | BinaryIO, /): ... + """, + id="standard structure", + ), + pytest.param( + """ + struct Test { + int a[]; + } + """, + "Test", + """ + class Test(Structure): + a: Array[int32] + @overload + def __init__(self, a: Array[int32] = ...): ... + @overload + def __init__(self, fh: bytes | bytearray | BinaryIO, /): ... + """, + id="array", + ), + pytest.param( + """ + #define a 1 + #define b b"data" + #define c "test" + """, + None, + """ + a: int = ... + b: bytes = ... + c: str = ... + """, + id="definitions", + ), + pytest.param( + """ + struct Test { + int *a; + } + """, + "Test", + """ + class Test(Structure): + a: Pointer[int32] + @overload + def __init__(self, a: Pointer[int32] = ...): ... + @overload + def __init__(self, fh: bytes | bytearray | BinaryIO, /): ... + """, + id="pointers", + ), + pytest.param( + """ + enum Test { + A = 1, + B = 2, + C = 2 + }; + """, + "Test", + """ + class Test(Enum, uint32): + A = ... + B = ... + C = ... + """, + id="enums", + ), + pytest.param( + """ + flag Test { + A = 0x00001, + B = 0x00002, + C = 0x00004 + }; + """, + "Test", + """ + class Test(Flag, uint32): + A = ... + B = ... + C = ... + """, + id="flags", + ), + pytest.param( + """ + struct Test{ + union { + wchar a[]; + char b[]; + } + } + """, + "Test", + """ + class Test(Structure): + a: WcharArray + b: CharArray + @overload + def __init__(self): ... + @overload + def __init__(self, fh: bytes | bytearray | BinaryIO, /): ... + """, + id="unions", + ), + pytest.param("""""", "", "...", id="empty"), + ], +) +def test_to_type_stub(definition: str, name: str, expected_stub: str) -> None: + structure = cstruct() + ignore_list = list(structure.typedefs.keys()) + structure.load(definition) + + generated_stub = getattr(structure, name).cs if name else structure + expected_stub = textwrap.dedent(expected_stub).strip() + + assert stubify_cstruct(generated_stub, ignore_type_defs=ignore_list).strip() == expected_stub + + +def test_to_type_stub_empty() -> None: + structure = cstruct() + ignore_list = list(structure.typedefs.keys()) + structure.load("") + + assert stubify_cstruct(structure, "test", ignore_type_defs=ignore_list) == "class test(cstruct):\n ..." + + +def test_stubify_file() -> None: + stub_file = absolute_path("data/stub_file.py") + + output = stubify_file(stub_file, stub_file.parent) + + assert output == absolute_path("data/stub_file.pyi").read_text() + + +def test_stubify_file_unknown_file(tmp_path: Path) -> None: + assert stubify_file(tmp_path.joinpath("unknown_file.py"), tmp_path) == "" + + new_file = tmp_path.joinpath("new_file.py") + new_file.touch() + assert stubify_file(new_file, tmp_path) == "" + + +def test_stubify_typedef() -> None: + structure = cstruct() + expected_output = [ + "int8: TypeAlias = Packed[int]", + "uint8: TypeAlias = Packed[int]", + "int16: TypeAlias = Packed[int]", + "uint16: TypeAlias = Packed[int]", + "int32: TypeAlias = Packed[int]", + "uint32: TypeAlias = Packed[int]", + "int64: TypeAlias = Packed[int]", + "uint64: TypeAlias = Packed[int]", + "float16: TypeAlias = Packed[float]", + "float: TypeAlias = Packed[float]", + "double: TypeAlias = Packed[float]", + ] + + assert stubify_typedefs(structure) == "\n".join(expected_output) + assert stubify_typedefs(structure, ["int8"]) == "\n".join(expected_output[1:]) + assert stubify_typedefs(structure, ["int8", "double"]) == "\n".join(expected_output[1:-1]) + assert "float16: TypeAlias = Packed[float]" not in stubify_typedefs(structure, ["float16"]) + assert stubify_typedefs(structure, structure.typedefs.keys()) == "" diff --git a/tests/utils.py b/tests/utils.py index f70fe20..61071b4 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -1,5 +1,6 @@ from __future__ import annotations +from pathlib import Path from typing import TYPE_CHECKING if TYPE_CHECKING: @@ -8,3 +9,7 @@ def verify_compiled(struct: type[Structure], compiled: bool) -> bool: return struct.__compiled__ == compiled + + +def absolute_path(path: str | Path) -> Path: + return Path(__file__).parent.joinpath(path)