diff --git a/skore/src/skore/project.py b/skore/src/skore/project.py index 4ce869111..6dc2dadd9 100644 --- a/skore/src/skore/project.py +++ b/skore/src/skore/project.py @@ -2,7 +2,7 @@ import logging from pathlib import Path -from typing import Any, Literal, Optional, Union +from typing import Any, Optional, Union from skore.item import ( CrossValidationItem, @@ -38,12 +38,7 @@ def __init__( self.item_repository = item_repository self.view_repository = view_repository - def put( - self, - key: Union[str, dict[str, Any]], - value: Optional[Any] = None, - on_error: Literal["warn", "raise"] = "warn", - ): + def put(self, key: Union[str, dict[str, Any]], value: Optional[Any] = None): """Add one or more key-value pairs to the Project. If `key` is a string, then `put` adds the single `key`-`value` pair mapping to @@ -52,95 +47,64 @@ def put( the Project. If an item with the same key already exists, its value is replaced by the new one. - If `on_error` is "raise", any error stops the execution. If `on_error` - is "warn" (or anything other than "raise"), a warning is shown instead. + + The dict format is the same as equivalent to running `put` for each individual + key-value pair. In other words, + ```python + project.put({"hello": 1, "goodbye": 2}) + ``` + is equivalent to + ```python + project.put("hello", 1) + project.put("goodbye", 2) + ``` + In particular, this means that if some key-value pair is invalid + (e.g. if a key is not a string, or a value's type is not supported), + then all the key-value pairs up to the first offending key-value pair will + be successfully inserted, *and then* an error will be raised. Parameters ---------- key : str | dict[str, Any] The key to associate with `value` in the Project, or dict of key-value pairs to add to the Project. - value : Any | None + value : Any, optional The value to associate with `key` in the Project. If `key` is a dict, this argument is ignored. - on_error : "warn" or "raise", optional - Upon error (e.g. if the key is not a string), whether to raise an error or - to print a warning. Default is "warn". Raises ------ ProjectPutError - If the key-value pair(s) cannot be saved properly, - and `on_error` is "raise". + If the key-value pair(s) cannot be saved properly. """ if isinstance(key, dict): - self.put_several(key, on_error=on_error) + for key_, value in key.items(): + self.put_one(key_, value) else: - self.put_one(key, value, on_error=on_error) + self.put_one(key, value) - def put_one( - self, key: str, value: Any, on_error: Literal["warn", "raise"] = "warn" - ): + def put_one(self, key: str, value: Any): """Add a key-value pair to the Project. - If `on_error` is "raise", any error stops the execution. If `on_error` - is "warn" (or anything other than "raise"), a warning is shown instead. - Parameters ---------- key : str The key to associate with `value` in the Project. Must be a string. value : Any The value to associate with `key` in the Project. - on_error : {"warn", "raise"}, optional - Upon error (e.g. if the key is not a string), whether to raise an error or - to print a warning. Default is "warn". Raises ------ ProjectPutError - If the key-value pair cannot be saved properly, and `on_error` is "raise". + If the key-value pair cannot be saved properly. """ try: item = object_to_item(value) self.put_item(key, item) except (NotImplementedError, TypeError) as e: - if on_error == "raise": - raise ProjectPutError( - "Key-value pair could not be inserted in the Project" - ) from e - - logger.warning( - "Key-value pair could not be inserted in the Project " - f"due to the following error: {e}" - ) - - def put_several( - self, key_to_value: dict, on_error: Literal["warn", "raise"] = "warn" - ): - """Add several values to the Project. - - If `on_error` is "raise", the first error stops the execution (so the - later key-value pairs will not be inserted). If `on_error` is "warn" (or - anything other than "raise"), errors do not stop the execution, and are - shown as they come as warnings; all the valid key-value pairs are inserted. - - Parameters - ---------- - key_to_value : dict[str, Any] - The key-value pairs to put in the Project. Keys must be strings. - on_error : {"warn", "raise"}, optional - Upon error (e.g. if a key is not a string), whether to raise an error or - to print a warning. Default is "warn". - - Raises - ------ - ProjectPutError - If a key-value pair in `key_to_value` cannot be saved properly, - and `on_error` is "raise". - """ - for key, value in key_to_value.items(): - self.put_one(key, value, on_error=on_error) + raise ProjectPutError( + "Key-value pair could not be inserted in the Project" + ) from e def put_item(self, key: str, item: Item): """Add an Item to the Project.""" diff --git a/skore/tests/unit/test_project.py b/skore/tests/unit/test_project.py index dc5c795ff..9f240db11 100644 --- a/skore/tests/unit/test_project.py +++ b/skore/tests/unit/test_project.py @@ -140,7 +140,7 @@ def test_put_kwargs(in_memory_project): def test_put_wrong_key_type(in_memory_project): with pytest.raises(ProjectPutError): - in_memory_project.put(key=2, value=1, on_error="raise") + in_memory_project.put(key=2, value=1) assert in_memory_project.list_item_keys() == [] @@ -151,13 +151,6 @@ def test_put_twice(in_memory_project): assert in_memory_project.get("key2") == 5 -def test_put_int_key(in_memory_project, caplog): - # Warns that 0 is not a string, but doesn't raise - in_memory_project.put(0, "hello") - assert len(caplog.record_tuples) == 1 - assert in_memory_project.list_item_keys() == [] - - def test_get(in_memory_project): in_memory_project.put("key1", 1) assert in_memory_project.get("key1") == 1 @@ -203,21 +196,9 @@ def test_put_several_happy_path(in_memory_project): assert in_memory_project.list_item_keys() == ["a", "b"] -def test_put_several_canonical(in_memory_project): - """Use `put_several` instead of the `put` alias.""" - in_memory_project.put_several({"a": "foo", "b": "bar"}) - assert in_memory_project.list_item_keys() == ["a", "b"] - - -def test_put_several_some_errors(in_memory_project, caplog): - in_memory_project.put( - { - 0: "hello", - 1: "hello", - 2: "hello", - } - ) - assert len(caplog.record_tuples) == 3 +def test_put_several_some_errors(in_memory_project): + with pytest.raises(ProjectPutError): + in_memory_project.put({0: "hello", 1: "hello", 2: "hello"}) assert in_memory_project.list_item_keys() == [] @@ -229,23 +210,25 @@ def test_put_several_nested(in_memory_project): def test_put_several_error(in_memory_project): """If some key-value pairs are wrong, add all that are valid and print a warning.""" - in_memory_project.put({"a": "foo", "b": (lambda: "unsupported object")}) + with pytest.raises(ProjectPutError): + in_memory_project.put({"a": "foo", "b": (lambda: "unsupported object")}) assert in_memory_project.list_item_keys() == ["a"] def test_put_key_is_a_tuple(in_memory_project): """If key is not a string, warn.""" - in_memory_project.put(("a", "foo"), ("b", "bar")) + with pytest.raises(ProjectPutError): + in_memory_project.put(("a", "foo"), ("b", "bar")) assert in_memory_project.list_item_keys() == [] def test_put_key_is_a_set(in_memory_project): """Cannot use an unhashable type as a key.""" with pytest.raises(ProjectPutError): - in_memory_project.put(set(), "hello", on_error="raise") + in_memory_project.put(set(), "hello") def test_put_wrong_key_and_value_raise(in_memory_project): """When `on_error` is "raise", raise the first error that occurs.""" with pytest.raises(ProjectPutError): - in_memory_project.put(0, (lambda: "unsupported object"), on_error="raise") + in_memory_project.put(0, (lambda: "unsupported object"))