Skip to content

Commit

Permalink
Add extra logic for checking unions against other unions. Style/checks
Browse files Browse the repository at this point in the history
Union to union checks make a full iteration and take compatibility into account
Add more tests for fw / backward compat on unions
  • Loading branch information
Tincu Gabriel committed Jun 4, 2020
1 parent 4921f95 commit 6d20218
Show file tree
Hide file tree
Showing 2 changed files with 90 additions and 27 deletions.
73 changes: 46 additions & 27 deletions karapace/compatibility.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,36 +216,55 @@ def check_simple_value(self, source, target):
def extract_schema_if_union(self, source, target):
source_union = isinstance(source, (avro.schema.UnionSchema, tuple))
target_union = isinstance(target, (avro.schema.UnionSchema, tuple))

found = False
# Nothing to do here as neither is an union value
if not source_union and not target_union:
return source, target
yield source, target

# Unions and union compatibility with non-union types requires special handling so go through them here.
if source_union and target_union:
# First schema in source that matches target will be used
for source_schema in self.get_schema_field(source):
for target_schema in self.get_schema_field(target):
if source_schema.type == target_schema.type:
return source_schema, target_schema
raise IncompatibleSchema("Matching schema in union not found")

if source_union and not target_union:
elif source_union and target_union:
target_idx_found = set()
source_idx_found = set()
source_schema_fields = self.get_schema_field(source)
target_schema_fields = self.get_schema_field(target)
for i, source_schema in enumerate(source_schema_fields):
for j, target_schema in enumerate(target_schema_fields):
# some types are unhashable
if source_schema.type == target_schema.type and j not in target_idx_found and i not in source_idx_found:
target_idx_found.add(j)
source_idx_found.add(i)
yield source_schema, target_schema
if len(target_idx_found) < len(target_schema_fields) and len(source_idx_found) < len(source_schema_fields):
# sets overlap only
raise IncompatibleSchema("Union types are incompatible")
if len(target_idx_found) < len(target_schema_fields) and self._checking_for in {"FORWARD", "FULL"}:
raise IncompatibleSchema("Previous union contains more types")
if len(source_idx_found) < len(source_schema_fields) and self._checking_for in {"BACKWARD", "FULL"}:
raise IncompatibleSchema("Previous union contains less types")

elif source_union and not target_union:
for schema in self.get_schema_field(source):
if schema.type == target.type:
if self._checking_for in {"BACKWARD", "FULL"}:
raise IncompatibleSchema("Incompatible union for source: {} and target: {}".format(source, target))
return schema, target
raise IncompatibleSchema("Matching schema in union not found")
yield schema, target
found = True
break
if not found:
raise IncompatibleSchema("Matching schema in union not found")

if not source_union and target_union:
elif not source_union and target_union:
for schema in self.get_schema_field(target):
if schema.type == source.type:
if self._checking_for in {"FORWARD", "FULL"}:
raise IncompatibleSchema("Incompatible union for source: {} and target: {}".format(source, target))
return source, schema
raise IncompatibleSchema("Matching schema in union not found")
return None, None
yield source, schema
found = True
break
if not found:
raise IncompatibleSchema("Matching schema in union not found")
else:
yield None, None

def iterate_over_record_source_fields(self, source, target):
for source_field in source.fields:
Expand Down Expand Up @@ -275,11 +294,11 @@ def iterate_over_record_source_fields(self, source, target):
break

# Simple presentation form for Union fields. Extract the correct schemas already here.
source_field_value, target_field_value = self.extract_schema_if_union(
for source_field_value, target_field_value in self.extract_schema_if_union(
source_field_value, target_field_value
)
self.log.info("Recursing source with: source: %s target: %s", source_field_value, target_field_value)
self.check_compatibility(source_field_value, target_field_value)
):
self.log.info("Recursing source with: source: %s target: %s", source_field_value, target_field_value)
self.check_compatibility(source_field_value, target_field_value)
else:
if not self.check_type_promotion(source_field.type, target_field.type):
raise IncompatibleSchema(
Expand Down Expand Up @@ -319,11 +338,11 @@ def iterate_over_record_target_fields(self, source, target):
self.check_simple_value(source_field_value, target_field_value)
break

source_field_value, target_field_value = self.extract_schema_if_union(
for source_field_value, target_field_value in self.extract_schema_if_union(
source_field_value, target_field_value
)
self.log.info("Recursing target with: source: %s target: %s", source_field_value, target_field_value)
self.check_compatibility(source_field_value, target_field_value)
):
self.log.info("Recursing target with: source: %s target: %s", source_field_value, target_field_value)
self.check_compatibility(source_field_value, target_field_value)
else:
found = True
self.log.info("source_field is: %s, target_field: %s added", source_field, target_field)
Expand Down Expand Up @@ -364,5 +383,5 @@ def check_compatibility(self, source, target):
if not same_type and not (source_union or target_union):
raise IncompatibleSchema("source {} and target {} different types".format(source, target))

source, target = self.extract_schema_if_union(source, target)
self.check_fields(source, target)
for source_f, target_f in self.extract_schema_if_union(source, target):
self.check_fields(source_f, target_f)
44 changes: 44 additions & 0 deletions tests/test_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,49 @@ async def enum_schema_compatibility_checks(c, compatibility):
assert res["fields"][0]["type"]["symbols"] == ["second"]


async def union_to_union_check(c):
subject = os.urandom(16).hex()
res = await c.put(f"config/{subject}", json={"compatibility": "BACKWARD"})
assert res.status == 200
init_schema = {"name": "init", "type": "record", "fields": [{"name": "inner", "type": ["string", "int"]}]}
evolved = {"name": "init", "type": "record", "fields": [{"name": "inner", "type": ["null", "string"]}]}
evolved_compatible = {
"name": "init",
"type": "record",
"fields": [{
"name": "inner",
"type": [
"int", "string", {
"type": "record",
"name": "foobar_fields",
"fields": [{
"name": "foo",
"type": "string"
}]
}
]
}]
}
res = await c.post(f"subjects/{subject}/versions", json={"schema": jsonlib.dumps(init_schema)})
assert res.status == 200
assert "id" in res.json()
res = await c.post(f"subjects/{subject}/versions", json={"schema": jsonlib.dumps(evolved)})
assert res.status == 409
res = await c.post(f"subjects/{subject}/versions", json={"schema": jsonlib.dumps(evolved_compatible)})
assert res.status == 200
# fw compat check
subject = os.urandom(16).hex()
res = await c.put(f"config/{subject}", json={"compatibility": "FORWARD"})
assert res.status == 200
res = await c.post(f"subjects/{subject}/versions", json={"schema": jsonlib.dumps(evolved_compatible)})
assert res.status == 200
assert "id" in res.json()
res = await c.post(f"subjects/{subject}/versions", json={"schema": jsonlib.dumps(evolved)})
assert res.status == 409
res = await c.post(f"subjects/{subject}/versions", json={"schema": jsonlib.dumps(init_schema)})
assert res.status == 200


async def record_union_schema_compatibility_checks(c):
subject = os.urandom(16).hex()
res = await c.put(f"config/{subject}", json={"compatibility": "BACKWARD"})
Expand Down Expand Up @@ -1472,6 +1515,7 @@ async def check_common_endpoints(c):

async def run_schema_tests(c):
await schema_checks(c)
await union_to_union_check(c)
await check_type_compatibility(c)
await compatibility_endpoint_checks(c)
await record_schema_compatibility_checks(c)
Expand Down

0 comments on commit 6d20218

Please sign in to comment.