From f62f67ef7c6710aac613665c02f848ac66d1adb7 Mon Sep 17 00:00:00 2001 From: Fokko Driesprong Date: Fri, 21 Feb 2025 17:11:52 +0100 Subject: [PATCH] Raise explicit error when join columns cannot be found (#1698) --- pyiceberg/table/__init__.py | 3 +++ tests/table/test_upsert.py | 50 +++++++++++++++++++++++++++++-------- 2 files changed, 43 insertions(+), 10 deletions(-) diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py index 8ff299ce6a..d7d29a55ca 100644 --- a/pyiceberg/table/__init__.py +++ b/pyiceberg/table/__init__.py @@ -1159,6 +1159,9 @@ def upsert( else: raise ValueError(f"Field-ID could not be found: {join_cols}") + if len(join_cols) == 0: + raise ValueError("Join columns could not be found, please set identifier-field-ids or pass in explicitly.") + if not when_matched_update_all and not when_not_matched_insert_all: raise ValueError("no upsert options selected...exiting") diff --git a/tests/table/test_upsert.py b/tests/table/test_upsert.py index c97015e650..617b7fb501 100644 --- a/tests/table/test_upsert.py +++ b/tests/table/test_upsert.py @@ -25,6 +25,7 @@ from pyiceberg.exceptions import NoSuchTableError from pyiceberg.expressions import And, EqualTo, Reference from pyiceberg.expressions.literals import LongLiteral +from pyiceberg.io.pyarrow import schema_to_pyarrow from pyiceberg.schema import Schema from pyiceberg.table import UpsertResult from pyiceberg.table.upsert_util import create_match_filter @@ -328,7 +329,7 @@ def test_upsert_with_identifier_fields(catalog: Catalog) -> None: schema = Schema( NestedField(1, "city", StringType(), required=True), - NestedField(2, "inhabitants", IntegerType(), required=True), + NestedField(2, "population", IntegerType(), required=True), # Mark City as the identifier field, also known as the primary-key identifier_field_ids=[1], ) @@ -338,17 +339,17 @@ def test_upsert_with_identifier_fields(catalog: Catalog) -> None: arrow_schema = pa.schema( [ pa.field("city", pa.string(), nullable=False), - pa.field("inhabitants", pa.int32(), nullable=False), + pa.field("population", pa.int32(), nullable=False), ] ) # Write some data df = pa.Table.from_pylist( [ - {"city": "Amsterdam", "inhabitants": 921402}, - {"city": "San Francisco", "inhabitants": 808988}, - {"city": "Drachten", "inhabitants": 45019}, - {"city": "Paris", "inhabitants": 2103000}, + {"city": "Amsterdam", "population": 921402}, + {"city": "San Francisco", "population": 808988}, + {"city": "Drachten", "population": 45019}, + {"city": "Paris", "population": 2103000}, ], schema=arrow_schema, ) @@ -356,12 +357,12 @@ def test_upsert_with_identifier_fields(catalog: Catalog) -> None: df = pa.Table.from_pylist( [ - # Will be updated, the inhabitants has been updated - {"city": "Drachten", "inhabitants": 45505}, + # Will be updated, the population has been updated + {"city": "Drachten", "population": 45505}, # New row, will be inserted - {"city": "Berlin", "inhabitants": 3432000}, + {"city": "Berlin", "population": 3432000}, # Ignored, already exists in the table - {"city": "Paris", "inhabitants": 2103000}, + {"city": "Paris", "population": 2103000}, ], schema=arrow_schema, ) @@ -388,3 +389,32 @@ def test_create_match_filter_single_condition() -> None: EqualTo(term=Reference(name="order_id"), literal=LongLiteral(101)), EqualTo(term=Reference(name="order_line_id"), literal=LongLiteral(1)), ) + + +def test_upsert_without_identifier_fields(catalog: Catalog) -> None: + identifier = "default.test_upsert_without_identifier_fields" + _drop_table(catalog, identifier) + + schema = Schema( + NestedField(1, "city", StringType(), required=True), + NestedField(2, "population", IntegerType(), required=True), + # No identifier field :o + identifier_field_ids=[], + ) + + tbl = catalog.create_table(identifier, schema=schema) + # Write some data + df = pa.Table.from_pylist( + [ + {"city": "Amsterdam", "population": 921402}, + {"city": "San Francisco", "population": 808988}, + {"city": "Drachten", "population": 45019}, + {"city": "Paris", "population": 2103000}, + ], + schema=schema_to_pyarrow(schema), + ) + + with pytest.raises( + ValueError, match="Join columns could not be found, please set identifier-field-ids or pass in explicitly." + ): + tbl.upsert(df)