From 688265e1a05612362797a3e6f2a6f9bfa276baca Mon Sep 17 00:00:00 2001 From: Zoheb Shaikh Date: Mon, 12 Aug 2024 11:52:27 +0100 Subject: [PATCH] added code review changes --- pyproject.toml | 2 +- src/scanspec/core.py | 15 ++++++++------- 2 files changed, 9 insertions(+), 8 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 38c13f8f..e8e6bf18 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -106,7 +106,7 @@ commands = pre-commit: pre-commit run --all-files {posargs} type-checking: mypy src tests {posargs} tests: pytest --cov=scanspec --cov-report term --cov-report xml:cov.xml {posargs} - docs: sphinx-{posargs:build -EW --keep-going} -T docs build/html + docs: sphinx-{posargs:build -E --keep-going} -T docs build/html """ [tool.ruff] diff --git a/src/scanspec/core.py b/src/scanspec/core.py index aa01bffd..74e0dffd 100644 --- a/src/scanspec/core.py +++ b/src/scanspec/core.py @@ -15,6 +15,7 @@ import numpy as np from pydantic import BaseModel, ConfigDict, Field, GetCoreSchemaHandler from pydantic.dataclasses import is_pydantic_dataclass, rebuild_dataclass +from pydantic_core import CoreSchema from pydantic_core.core_schema import tagged_union_schema __all__ = [ @@ -102,7 +103,7 @@ def calculate(self) -> int: ) Args: - cls: The superclass of the union, Expression in the above example + super_cls: The superclass of the union, Expression in the above example discriminator: The discriminator that will be inserted into the serialized documents for type determination. Defaults to "type". @@ -166,14 +167,14 @@ def __init__(self, base_class: type, discriminator: str): # Classes and their field names that refer to this tagged union self._discriminator = discriminator # The members of the tagged union, i.e. subclasses of the baseclass - self._members: list[type] = [] + self._subclasses: list[type] = [] self._references: set[type | Callable] = set() def add_member(self, cls: type): - if cls in self._members: + if cls in self._subclasses: return - self._members.append(cls) - for member in self._members: + self._subclasses.append(cls) + for member in self._subclasses: if member is not cls: _TaggedUnion._rebuild(member) for ref in self._references: @@ -191,9 +192,9 @@ def _rebuild(cls_or_func: type | Callable): if issubclass(cls_or_func, BaseModel): cls_or_func.model_rebuild(force=True) - def schema(self, handler): + def schema(self, handler: GetCoreSchemaHandler) -> CoreSchema: return tagged_union_schema( - make_schema(tuple(self._members), handler), + make_schema(tuple(self._subclasses), handler), discriminator=self._discriminator, ref=self._base_class.__name__, )