diff --git a/.gitignore b/.gitignore index e417b247..b067ceb0 100644 --- a/.gitignore +++ b/.gitignore @@ -129,6 +129,7 @@ ENV/ env.bak/ venv.bak/ .pixi/ +.vscode/ # Spyder project settings .spyderproject diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index c50d47f5..5ce0de1b 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -11,12 +11,17 @@ repos: hooks: # Run the linter. - id: ruff - args: [ --fix ] + args: [--fix] # Run the formatter. - id: ruff-format - repo: https://github.com/pre-commit/mirrors-mypy - rev: v1.9.0 + rev: v1.8.0 hooks: - id: mypy - args: [--strict, --install-types, --non-interactive] - additional_dependencies: [sympy, attrs, pytest, click, dask] + args: [ + --strict, + --install-types, + --allow-subclassing-any, + --non-interactive, + ] + additional_dependencies: [sympy, attrs, pytest, click, dask] \ No newline at end of file diff --git a/src/dewret/__main__.py b/src/dewret/__main__.py index 92d947d2..e93e4273 100644 --- a/src/dewret/__main__.py +++ b/src/dewret/__main__.py @@ -133,7 +133,7 @@ def render( if output == "-": @contextmanager - def _opener(key: str, _: str) -> Generator[IO[Any], None, None]: + def _opener(key: str, mode: str) -> Generator[IO[Any], None, None]: print(" ------ ", key, " ------ ") yield sys.stdout print() diff --git a/src/dewret/core.py b/src/dewret/core.py index eecf3033..d10fb1fd 100644 --- a/src/dewret/core.py +++ b/src/dewret/core.py @@ -48,7 +48,8 @@ BasicType = str | float | bool | bytes | int | None RawType = BasicType | list["RawType"] | dict[str, "RawType"] FirmType = RawType | list["FirmType"] | dict[str, "FirmType"] | tuple["FirmType", ...] -ExprType = (FirmType | Basic | list["ExprType"] | dict[str, "ExprType"] | tuple["ExprType", ...]) # type: ignore +# Basic is from Sympy, which does not have type annotations, so ExprType cannot pass mypy +ExprType = (FirmType | Basic | list["ExprType"] | dict[str, "ExprType"] | tuple["ExprType", ...]) # type: ignore # fmt: skip U = TypeVar("U") T = TypeVar("T") diff --git a/src/dewret/renderers/cwl.py b/src/dewret/renderers/cwl.py index 18c086f4..81706d21 100644 --- a/src/dewret/renderers/cwl.py +++ b/src/dewret/renderers/cwl.py @@ -284,12 +284,15 @@ def from_step(cls, step: BaseStep) -> "StepDefinition": Args: step: step to convert. """ + out: list[str] | dict[str, "CommandInputSchema"] + if attrs_has(step.return_type) or (is_dataclass(step.return_type) and isclass(step.return_type)): + out = to_output_schema("out", step.return_type)["fields"] + else: + out = ["out"] return cls( name=step.name, run=step.task.name, - out=(to_output_schema("out", step.return_type)["fields"]) - if attrs_has(step.return_type) or is_dataclass(step.return_type) - else ["out"], + out=out, in_={ key: ( ReferenceDefinition.from_reference(param) @@ -463,13 +466,14 @@ def to_output_schema( for field in attrs_fields(typ) } elif is_dataclass(typ): - fields = { - str(field.name): cast( - CommandInputSchema, to_output_schema(field.name, field.type) - ) - for field in dataclass_fields(typ) - } - + fields = {} + for field in dataclass_fields(typ): + if isinstance(field.type, type) and issubclass(field.type, RawType | AttrsInstance | DataclassProtocol): + fields[str(field.name)] = cast( + CommandInputSchema, to_output_schema(field.name, field.type) + ) + else: + raise TypeError("Types of fields in results must also be valid result-types themselves (string-defined types not currently allowed)") if fields: output = CommandOutputSchema( type="record", diff --git a/src/dewret/workflow.py b/src/dewret/workflow.py index a37e6d18..2e00f6af 100644 --- a/src/dewret/workflow.py +++ b/src/dewret/workflow.py @@ -928,7 +928,11 @@ def find_field(self: FieldableProtocol, field: str | int, fallback_type: type | type_hints = get_type_hints(parent_type, localns={parent_type.__name__: parent_type}, include_extras=True) field_type = type_hints.get(field) if field_type is None: - field_type = next(iter(filter(lambda fld: fld.name == field, dataclass_fields(parent_type)))).type + dataclass_field_type = next(iter(filter(lambda fld: fld.name == field, dataclass_fields(parent_type)))).type + if isinstance(dataclass_field_type, str): + # TODO: we could ask Python to resolve the str expression for us + raise TypeError("Dataclass fields must be provided as types directly, not str") + field_type = dataclass_field_type except StopIteration: raise AttributeError(f"Dataclass {parent_type} does not have field {field}") from None elif attr_has(parent_type):