Skip to content

Commit

Permalink
added code review changes
Browse files Browse the repository at this point in the history
  • Loading branch information
ZohebShaikh committed Aug 12, 2024
1 parent 379b1e6 commit 688265e
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 8 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ commands =
pre-commit: pre-commit run --all-files {posargs}
type-checking: mypy src tests {posargs}
tests: pytest --cov=scanspec --cov-report term --cov-report xml:cov.xml {posargs}
docs: sphinx-{posargs:build -EW --keep-going} -T docs build/html
docs: sphinx-{posargs:build -E --keep-going} -T docs build/html
"""

[tool.ruff]
Expand Down
15 changes: 8 additions & 7 deletions src/scanspec/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import numpy as np
from pydantic import BaseModel, ConfigDict, Field, GetCoreSchemaHandler
from pydantic.dataclasses import is_pydantic_dataclass, rebuild_dataclass
from pydantic_core import CoreSchema
from pydantic_core.core_schema import tagged_union_schema

__all__ = [
Expand Down Expand Up @@ -102,7 +103,7 @@ def calculate(self) -> int:
)
Args:
cls: The superclass of the union, Expression in the above example
super_cls: The superclass of the union, Expression in the above example
discriminator: The discriminator that will be inserted into the
serialized documents for type determination. Defaults to "type".
Expand Down Expand Up @@ -166,14 +167,14 @@ def __init__(self, base_class: type, discriminator: str):
# Classes and their field names that refer to this tagged union
self._discriminator = discriminator
# The members of the tagged union, i.e. subclasses of the baseclass
self._members: list[type] = []
self._subclasses: list[type] = []
self._references: set[type | Callable] = set()

def add_member(self, cls: type):
if cls in self._members:
if cls in self._subclasses:
return
self._members.append(cls)
for member in self._members:
self._subclasses.append(cls)
for member in self._subclasses:
if member is not cls:
_TaggedUnion._rebuild(member)
for ref in self._references:
Expand All @@ -191,9 +192,9 @@ def _rebuild(cls_or_func: type | Callable):
if issubclass(cls_or_func, BaseModel):
cls_or_func.model_rebuild(force=True)

def schema(self, handler):
def schema(self, handler: GetCoreSchemaHandler) -> CoreSchema:
return tagged_union_schema(
make_schema(tuple(self._members), handler),
make_schema(tuple(self._subclasses), handler),
discriminator=self._discriminator,
ref=self._base_class.__name__,
)
Expand Down

0 comments on commit 688265e

Please sign in to comment.