Skip to content

Commit

Permalink
Adding stub creation tooling
Browse files Browse the repository at this point in the history
  • Loading branch information
Miauwkeru committed Feb 27, 2024
1 parent 0ef089a commit 56cb0d9
Show file tree
Hide file tree
Showing 5 changed files with 139 additions and 1 deletion.
16 changes: 16 additions & 0 deletions dissect/cstruct/cstruct.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import ctypes as _ctypes
import io
import struct
import sys
import types
Expand Down Expand Up @@ -407,6 +408,21 @@ def _make_union(
) -> type[Structure]:
return self._make_struct(name, fields, align=align, anonymous=anonymous, base=Union)

def to_stub(self, name: str = ""):
output_data = io.StringIO()

for const, value in self.consts.items():
output_data.write(f"{const}: {type(value).__name__}=...\n")

for name, type_def in self.typedefs.items():
if not isinstance(type_def, str):
output_data.write(type_def.to_stub(name))
output_data.write("\n")

output_value = output_data.getvalue()
output_data.close()
return output_value


def ctypes(structure: Structure) -> _ctypes.Structure:
"""Create ctypes structures from cstruct structures."""
Expand Down
24 changes: 24 additions & 0 deletions dissect/cstruct/tools/stubify.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
# Searches and creates a stub of a cstruct definitions
from argparse import ArgumentParser
from importlib import import_module
from pathlib import Path


def stubify_file(path: Path):
...


def main():
parser = ArgumentParser("stubify")
parser.add_argument("path", type=Path, required=True)
args = parser.parse_args()

file_path: Path = args.path

for file in file_path.glob("*.py"):
if file.is_file() and ".py" in file.suffixes:
stubify_file(file)


if __name__ == "__main__":
main()
7 changes: 7 additions & 0 deletions dissect/cstruct/types/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,13 @@ def _write_0(cls, stream: BinaryIO, array: list[BaseType]) -> int:
"""
return cls._write_array(stream, array + [cls()])

def to_stub(cls, name: str = "") -> str:
output_str = ""
if bases := getattr(cls, "__bases__", None):
output_str = bases[0].__name__

return f"{name}: {output_str}"


class _overload:
"""Descriptor to use on the ``write`` and ``dumps`` methods on cstruct types.
Expand Down
15 changes: 14 additions & 1 deletion dissect/cstruct/types/structure.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from contextlib import contextmanager
from functools import lru_cache
from operator import attrgetter
from textwrap import dedent
from textwrap import dedent, indent
from types import FunctionType
from typing import Any, BinaryIO, Callable, ContextManager, Optional

Expand Down Expand Up @@ -363,6 +363,19 @@ def commit(cls) -> None:
for key, value in classdict.items():
setattr(cls, key, value)

def to_stub(cls, name: str = ""):
with io.StringIO() as data:
data.write(f"class {cls.__name__}:\n")
call_args = ["self"]
for field in cls.__fields__:
if not getattr(field.type, "__anonymous__", False):
type_info = f"{field.name}{field.type.to_stub()}"
call_args.append(f"{type_info}=...")
data.write(indent(f"{type_info}\n", prefix=" " * 4))
call = ", ".join(call_args)
data.write(indent(f"def __call__({call}): ...", prefix=" " * 4))
return data.getvalue()


class Structure(BaseType, metaclass=StructureMetaType):
"""Base class for cstruct structure type classes."""
Expand Down
78 changes: 78 additions & 0 deletions tests/test_stub.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
import textwrap
from pathlib import Path

import pytest

from dissect.cstruct import cstruct


@pytest.mark.parametrize(
"definition, name, expected_stub",
[
(
"""
struct Test {
int a;
int b;
}
""",
"Test",
"""
class Test:
a: int
b: int
def __call__(self, a: int=..., b: int=...): ...
""",
),
(
"""
struct Test {
int a[];
}
""",
"Test",
"""
class Test:
a: Array
def __call__(self, a: Array=...): ...
""",
),
(
"""
#define a 1
#define b b"data"
#define c "test"
""",
None,
"""
a: int=...
b: bytes=...
c: str=...
""",
),
(
"""
struct Test{
union {
int a;
int b;
}
}
""",
"Test",
"""""",
),
],
ids=["standard structure", "array", "definitions", "unions"],
)
def test_to_stub(definition: str, name: str, expected_stub: str):
structure = cstruct()
structure.load(definition)

if name:
generated_stub = getattr(structure, name).to_stub()
else:
generated_stub = structure.to_stub()
expected_stub = textwrap.dedent(expected_stub).strip()

assert expected_stub in generated_stub

0 comments on commit 56cb0d9

Please sign in to comment.