Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Mypy Errors Triggered in Local Development #58

Open
wants to merge 4 commits into
base: release/0.10.0
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,7 @@ ENV/
env.bak/
venv.bak/
.pixi/
.vscode/

# Spyder project settings
.spyderproject
Expand Down
13 changes: 9 additions & 4 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,17 @@ repos:
hooks:
# Run the linter.
- id: ruff
args: [ --fix ]
args: [--fix]
# Run the formatter.
- id: ruff-format
- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.9.0
rev: v1.8.0
hooks:
- id: mypy
args: [--strict, --install-types, --non-interactive]
additional_dependencies: [sympy, attrs, pytest, click, dask]
args: [
--strict,
--install-types,
--allow-subclassing-any,
--non-interactive,
]
additional_dependencies: [sympy, attrs, pytest, click, dask]
2 changes: 1 addition & 1 deletion src/dewret/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ def render(
if output == "-":

@contextmanager
def _opener(key: str, _: str) -> Generator[IO[Any], None, None]:
def _opener(key: str, mode: str) -> Generator[IO[Any], None, None]:
print(" ------ ", key, " ------ ")
yield sys.stdout
print()
Expand Down
3 changes: 2 additions & 1 deletion src/dewret/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,8 @@
BasicType = str | float | bool | bytes | int | None
RawType = BasicType | list["RawType"] | dict[str, "RawType"]
FirmType = RawType | list["FirmType"] | dict[str, "FirmType"] | tuple["FirmType", ...]
ExprType = (FirmType | Basic | list["ExprType"] | dict[str, "ExprType"] | tuple["ExprType", ...]) # type: ignore
elleryames marked this conversation as resolved.
Show resolved Hide resolved
# Basic is from Sympy, which does not have type annotations, so ExprType cannot pass mypy
ExprType = (FirmType | Basic | list["ExprType"] | dict[str, "ExprType"] | tuple["ExprType", ...]) # type: ignore # fmt: skip

U = TypeVar("U")
T = TypeVar("T")
Expand Down
24 changes: 14 additions & 10 deletions src/dewret/renderers/cwl.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,12 +284,15 @@ def from_step(cls, step: BaseStep) -> "StepDefinition":
Args:
step: step to convert.
"""
out: list[str] | dict[str, "CommandInputSchema"]
if attrs_has(step.return_type) or (is_dataclass(step.return_type) and isclass(step.return_type)):
out = to_output_schema("out", step.return_type)["fields"]
else:
out = ["out"]
return cls(
name=step.name,
run=step.task.name,
out=(to_output_schema("out", step.return_type)["fields"])
if attrs_has(step.return_type) or is_dataclass(step.return_type)
else ["out"],
out=out,
in_={
key: (
ReferenceDefinition.from_reference(param)
Expand Down Expand Up @@ -463,13 +466,14 @@ def to_output_schema(
for field in attrs_fields(typ)
}
elif is_dataclass(typ):
fields = {
str(field.name): cast(
CommandInputSchema, to_output_schema(field.name, field.type)
)
for field in dataclass_fields(typ)
}

fields = {}
for field in dataclass_fields(typ):
if isinstance(field.type, type) and issubclass(field.type, RawType | AttrsInstance | DataclassProtocol):
fields[str(field.name)] = cast(
CommandInputSchema, to_output_schema(field.name, field.type)
)
else:
raise TypeError("Types of fields in results must also be valid result-types themselves (string-defined types not currently allowed)")
if fields:
output = CommandOutputSchema(
type="record",
Expand Down
6 changes: 5 additions & 1 deletion src/dewret/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -928,7 +928,11 @@ def find_field(self: FieldableProtocol, field: str | int, fallback_type: type |
type_hints = get_type_hints(parent_type, localns={parent_type.__name__: parent_type}, include_extras=True)
field_type = type_hints.get(field)
if field_type is None:
field_type = next(iter(filter(lambda fld: fld.name == field, dataclass_fields(parent_type)))).type
dataclass_field_type = next(iter(filter(lambda fld: fld.name == field, dataclass_fields(parent_type)))).type
if isinstance(dataclass_field_type, str):
# TODO: we could ask Python to resolve the str expression for us
raise TypeError("Dataclass fields must be provided as types directly, not str")
field_type = dataclass_field_type
except StopIteration:
raise AttributeError(f"Dataclass {parent_type} does not have field {field}") from None
elif attr_has(parent_type):
Expand Down
Loading