Skip to content

Commit

Permalink
Move to pyright and fix type errors
Browse files Browse the repository at this point in the history
  • Loading branch information
callumforrester committed Aug 14, 2024
1 parent d5cce50 commit efb22fc
Show file tree
Hide file tree
Showing 6 changed files with 46 additions and 31 deletions.
38 changes: 21 additions & 17 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,7 @@ classifiers = [
"Programming Language :: Python :: 3.11",
]
description = "Specify step and flyscan paths in a serializable, efficient and Pythonic way"
dependencies = [
"numpy>=2",
"click>=8.1",
"pydantic>=2.0",
]
dependencies = ["numpy>=2", "click>=8.1", "pydantic>=2.0"]
dynamic = ["version"]
license.file = "LICENSE"
readme = "README.md"
Expand All @@ -33,11 +29,11 @@ dev = [
"scanspec[plotting]",
"scanspec[service]",
"copier",
"mypy",
"myst-parser",
"pipdeptree",
"pre-commit",
"pydata-sphinx-theme>=0.12",
"pyright",
"pytest",
"pytest-cov",
"ruff",
Expand Down Expand Up @@ -65,8 +61,9 @@ name = "Tom Cobb"
[tool.setuptools_scm]
write_to = "src/scanspec/_version.py"

[tool.mypy]
ignore_missing_imports = true # Ignore missing stubs in imported modules
[tool.pyright]
# strict = ["src", "tests"]
reportMissingImports = false # Ignore missing stubs in imported modules

[tool.pytest.ini_options]
# Run pytest with all our checkers, and don't spam us with massive tracebacks on error
Expand Down Expand Up @@ -99,12 +96,12 @@ passenv = *
allowlist_externals =
pytest
pre-commit
mypy
pyright
sphinx-build
sphinx-autobuild
commands =
pre-commit: pre-commit run --all-files {posargs}
type-checking: mypy src tests {posargs}
type-checking: pyright src tests {posargs}
tests: pytest --cov=scanspec --cov-report term --cov-report xml:cov.xml {posargs}
docs: sphinx-{posargs:build -E --keep-going} -T docs build/html
"""
Expand All @@ -115,14 +112,21 @@ line-length = 88

[tool.ruff.lint]
extend-select = [
"B", # flake8-bugbear - https://docs.astral.sh/ruff/rules/#flake8-bugbear-b
"C4", # flake8-comprehensions - https://docs.astral.sh/ruff/rules/#flake8-comprehensions-c4
"E", # pycodestyle errors - https://docs.astral.sh/ruff/rules/#error-e
"F", # pyflakes rules - https://docs.astral.sh/ruff/rules/#pyflakes-f
"W", # pycodestyle warnings - https://docs.astral.sh/ruff/rules/#warning-w
"I", # isort - https://docs.astral.sh/ruff/rules/#isort-i
"UP", # pyupgrade - https://docs.astral.sh/ruff/rules/#pyupgrade-up
"B", # flake8-bugbear - https://docs.astral.sh/ruff/rules/#flake8-bugbear-b
"C4", # flake8-comprehensions - https://docs.astral.sh/ruff/rules/#flake8-comprehensions-c4
"E", # pycodestyle errors - https://docs.astral.sh/ruff/rules/#error-e
"F", # pyflakes rules - https://docs.astral.sh/ruff/rules/#pyflakes-f
"W", # pycodestyle warnings - https://docs.astral.sh/ruff/rules/#warning-w
"I", # isort - https://docs.astral.sh/ruff/rules/#isort-i
"UP", # pyupgrade - https://docs.astral.sh/ruff/rules/#pyupgrade-up
"SLF", # self - https://docs.astral.sh/ruff/settings/#lintflake8-self
]
ignore = [
"B008", # We use function calls in service arguments
]

[tool.ruff.lint.per-file-ignores]
# By default, private member access is allowed in tests
# See https://github.com/DiamondLightSource/python-copier-template/issues/154
# Remove this line to forbid private member access in tests
"tests/**/*" = ["SLF001"]
4 changes: 4 additions & 0 deletions src/scanspec/cli.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import logging
import string
from typing import Callable

import click

Expand All @@ -25,6 +26,9 @@ def cli(ctx, log_level: str):

# if no command is supplied, print the help message
if ctx.invoked_subcommand is None:
# We need to prove that cli has been converted to a command
# by the click decorator to keep pyright happy.
assert isinstance(cli, click.Command)
click.echo(cli.get_help(ctx))


Expand Down
12 changes: 6 additions & 6 deletions src/scanspec/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,11 +35,14 @@

StrictConfig: ConfigDict = {"extra": "forbid"}

C = TypeVar("C")
T = TypeVar("T", type, Callable)


def discriminated_union_of_subclasses(
super_cls: type,
super_cls: type[C],
discriminator: str = "type",
) -> type:
) -> type[C]:
"""Add all subclasses of super_cls to a discriminated union.
For all subclasses of super_cls, add a discriminator field to identify
Expand Down Expand Up @@ -137,9 +140,6 @@ def get_schema_of_union(cls, source_type: Any, handler: GetCoreSchemaHandler):
return super_cls


T = TypeVar("T", type, Callable)


def uses_tagged_union(cls_or_func: T) -> T:
"""
T = TypeVar("T", type, Callable)
Expand Down Expand Up @@ -616,7 +616,7 @@ def consume(self, num: int | None = None) -> Frames[Axis]:

def __len__(self) -> int:
"""Number of frames left in a scan, reduces when `consume` is called."""
return self.end_index - self.index
return int(self.end_index - self.index)


class Midpoints(Generic[Axis]):
Expand Down
18 changes: 12 additions & 6 deletions src/scanspec/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def __init__(self, xs, ys, zs, *args, **kwargs):
# Added here because of https://github.com/matplotlib/matplotlib/issues/21688
def do_3d_projection(self, renderer=None):
xs3d, ys3d, zs3d = self._verts3d
xs, ys, zs = proj3d.proj_transform(xs3d, ys3d, zs3d, self.axes.M)
xs, ys, zs = proj3d.proj_transform(xs3d, ys3d, zs3d, self.axes.M) # type: ignore
self.set_positions((xs[0], ys[0]), (xs[1], ys[1]))

return np.min(zs)
Expand Down Expand Up @@ -109,11 +109,17 @@ def plot_spec(spec: Spec[Any], title: str | None = None):
# Setup axes
if ndims > 2:
plt.figure(figsize=(6, 6))
plt_axes: Axes3D = plt.axes(projection="3d")
plt_axes = plt.axes(projection="3d")
plt_axes.grid(False)
plt_axes.set_zlabel(axes[-3])
plt_axes.set_ylabel(axes[-2])
plt_axes.view_init(elev=15)
if isinstance(plt_axes, Axes3D):
plt_axes.set_zlabel(axes[-3])
plt_axes.set_ylabel(axes[-2])
plt_axes.view_init(elev=15)
else:
raise TypeError(
"Expected matplotlib to create an Axes3D object, "
f"instead got: {plt_axes}"
)
elif ndims == 2:
plt.figure(figsize=(6, 6))
plt_axes = plt.axes()
Expand Down Expand Up @@ -208,7 +214,7 @@ def plot_spec(spec: Spec[Any], title: str | None = None):
_plot_arrow(plt_axes, arrow_arr)
elif splines:
# Plot the starting arrow in the direction of the first point
arrow_arr = [(2 * a[0] - a[1], a[0]) for a in splines[0]]
arrow_arr = [np.array([2 * a[0] - a[1], a[0]]) for a in splines[0]]
_plot_arrow(plt_axes, arrow_arr)
else:
# First point isn't moving, put a right caret marker
Expand Down
3 changes: 2 additions & 1 deletion src/scanspec/sphinxext.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from contextlib import contextmanager

from docutils.statemachine import StringList
from matplotlib.sphinxext import plot_directive

from . import __version__
Expand All @@ -25,7 +26,7 @@ class ExampleSpecDirective(plot_directive.PlotDirective):
"""Runs `plot_spec` on the ``spec`` definied in the content."""

def run(self):
self.content = (
self.content = StringList(
["# Example Spec", "", "from scanspec.plot import plot_spec"]
+ [str(x) for x in self.content]
+ ["plot_spec(spec)"]
Expand Down
2 changes: 1 addition & 1 deletion tests/test_specs.py
Original file line number Diff line number Diff line change
Expand Up @@ -492,7 +492,7 @@ def test_gap_repeat() -> None:

def test_gap_repeat_non_snake() -> None:
# Check that no gap doesn't propogate to dim.gap for non-snaked axis
spec: Spec[Any] = Repeat(3, gap=False) * Line.bounded(x, 11, 19, 1)
spec: Spec[str] = Repeat(3, gap=False) * Line.bounded(x, 11, 19, 1)
dim = spec.frames()
assert len(dim) == 3
assert dim.lower == {x: pytest.approx([11, 11, 11])}
Expand Down

0 comments on commit efb22fc

Please sign in to comment.