-
Notifications
You must be signed in to change notification settings - Fork 35
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
project dataclass - warn on extra kwargs (#474)
* override domain project init * add domain extra kwarg test * parametrize test and test warning * clean up
- Loading branch information
Showing
2 changed files
with
95 additions
and
9 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,18 +1,80 @@ | ||
from __future__ import annotations | ||
import datetime | ||
import logging | ||
import uuid | ||
from dataclasses import dataclass | ||
from typing import TYPE_CHECKING, Optional | ||
|
||
from dataclasses import dataclass, field | ||
from datetime import datetime | ||
from typing import Optional | ||
if TYPE_CHECKING: | ||
from rubicon_ml.domain.utils import TrainingMetadata | ||
|
||
from rubicon_ml.domain.utils import TrainingMetadata, uuid | ||
LOGGER = logging.getLogger() | ||
|
||
|
||
@dataclass | ||
@dataclass(init=False) | ||
class Project: | ||
"""A domain-level project. | ||
Parameters | ||
---------- | ||
name : str | ||
The project's name. | ||
created_at : datetime, optional | ||
The date and time the project was created. Defaults to `None` and uses | ||
`datetime.datetime.now` to generate a UTC timestamp. `created_at` should be | ||
left as `None` to allow for automatic generation. | ||
description : str, optional | ||
A description of the project. Defaults to `None`. | ||
github_url : str, optional | ||
The URL of the GitHub repository associated with this project. Defaults to | ||
`None`. | ||
id : str, optional | ||
The project's unique identifier. Defaults to `None` and uses `uuid.uuid4` | ||
to generate a unique ID. `id` should be left as `None` to allow for automatic | ||
generation. | ||
training_metadata : rubicon_ml.domain.utils.TrainingMetadata, optional | ||
Additional metadata pertaining to any data this project was trained on. | ||
Defaults to `None`. | ||
""" | ||
|
||
name: str | ||
|
||
id: str = field(default_factory=uuid.uuid4) | ||
created_at: Optional[datetime.datetime] = None | ||
description: Optional[str] = None | ||
github_url: Optional[str] = None | ||
training_metadata: Optional[TrainingMetadata] = None | ||
created_at: datetime = field(default_factory=datetime.utcnow) | ||
id: Optional[str] = None | ||
training_metadata: Optional["TrainingMetadata"] = None | ||
|
||
def __init__( | ||
self, | ||
name: str, | ||
created_at: Optional[datetime.datetime] = None, | ||
description: Optional[str] = None, | ||
github_url: Optional[str] = None, | ||
id: Optional[str] = None, | ||
training_metadata: Optional["TrainingMetadata"] = None, | ||
**kwargs, | ||
): | ||
"""Initialize this domain project.""" | ||
|
||
self.name = name | ||
|
||
self.created_at = created_at | ||
self.description = description | ||
self.github_url = github_url | ||
self.id = id | ||
self.training_metadata = training_metadata | ||
|
||
if self.created_at is None: | ||
try: # `datetime.UTC` added & `datetime.utcnow` deprecated in Python 3.11 | ||
self.created_at = datetime.datetime.now(datetime.UTC) | ||
except AttributeError: | ||
self.created_at = datetime.datetime.utcnow() | ||
|
||
if self.id is None: | ||
self.id = str(uuid.uuid4()) | ||
|
||
if kwargs: # replaces `dataclass` behavior of erroring on unexpected kwargs | ||
LOGGER.warning( | ||
f"{self.__class__.__name__}.__init__() got an unexpected keyword " | ||
f"argument(s): `{'`, `'.join([key for key in kwargs])}`" | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
from unittest import mock | ||
|
||
import pytest | ||
|
||
from rubicon_ml.domain.project import Project | ||
|
||
|
||
@pytest.mark.parametrize( | ||
["domain_cls", "required_kwargs"], | ||
[(Project, {"name": "test_domain_extra_kwargs"})], | ||
) | ||
def test_domain_extra_kwargs(domain_cls, required_kwargs): | ||
with mock.patch( | ||
f"rubicon_ml.domain.{domain_cls.__name__.lower()}.LOGGER.warning" | ||
) as mock_logger_warning: | ||
domain = domain_cls(extra="extra", **required_kwargs) | ||
|
||
mock_logger_warning.assert_called_once_with( | ||
f"{domain_cls.__name__}.__init__() got an unexpected keyword argument(s): `extra`", | ||
) | ||
|
||
assert "extra" not in domain.__dict__ | ||
for key, value in required_kwargs.items(): | ||
assert getattr(domain, key) == value |