diff --git a/src/validators/union.rs b/src/validators/union.rs index bfe744212..531e2ca9d 100644 --- a/src/validators/union.rs +++ b/src/validators/union.rs @@ -113,7 +113,7 @@ impl UnionValidator { let strict = state.strict_or(self.strict); let mut errors = MaybeErrors::new(self.custom_error.as_ref()); - let mut best_match: Option<(Py, Exactness, Option)> = None; + let mut best_match: Option<(Py, Exactness, Option, usize, f64)> = None; for (choice, label) in &self.choices { let state = &mut state.rebind_extra(|extra| { @@ -141,25 +141,24 @@ impl UnionValidator { let new_exactness = state.exactness.unwrap_or(Exactness::Lax); let new_fields_set_count = state.fields_set_count; + let new_fields_set_percentage = new_fields_set_count.map(|count| { + let total_fields = input.as_dict().map_or(0, |dict| dict.len()); + count as f64 / total_fields as f64 + }); - // we use both the exactness and the fields_set_count to determine the best union member match - // if fields_set_count is available for the current best match and the new candidate, we use this - // as the primary metric. If the new fields_set_count is greater, the new candidate is better. - // if the fields_set_count is the same, we use the exactness as a tie breaker to determine the best match. - // if the fields_set_count is not available for either the current best match or the new candidate, - // we use the exactness to determine the best match. let new_success_is_best_match: bool = best_match .as_ref() - .map_or(true, |(_, cur_exactness, cur_fields_set_count)| { - match (*cur_fields_set_count, new_fields_set_count) { + .map_or(true, |(_, cur_exactness, cur_fields_set_count, cur_fields_set_count_value, cur_fields_set_percentage)| { + match (cur_fields_set_count, new_fields_set_count) { (Some(cur), Some(new)) if cur != new => cur < new, + (Some(_), Some(_)) => cur_fields_set_percentage < new_fields_set_percentage, _ => *cur_exactness < new_exactness, } }); if new_success_is_best_match { - best_match = Some((new_success, new_exactness, new_fields_set_count)); + best_match = Some((new_success, new_exactness, new_fields_set_count, new_fields_set_count.unwrap_or(0), new_fields_set_percentage.unwrap_or(0.0))); } } }, @@ -177,7 +176,7 @@ impl UnionValidator { state.exactness = old_exactness; state.fields_set_count = old_fields_set_count; - if let Some((best_match, exactness, fields_set_count)) = best_match { + if let Some((best_match, exactness, fields_set_count, _, _)) = best_match { state.floor_exactness(exactness); if let Some(count) = fields_set_count { state.add_fields_set(count); diff --git a/tests/validators/test_union.py b/tests/validators/test_union.py index 1951fb3de..009317760 100644 --- a/tests/validators/test_union.py +++ b/tests/validators/test_union.py @@ -1333,3 +1333,45 @@ class Model: assert isinstance(validator.validate_python({'x': {'foo': 'foo'}}).x, Foo) assert isinstance(validator.validate_python({'x': {'bar': 'bar'}}).x, Bar) + + +class A: + a: int + + +class B(A): + b: int + + +def test_discriminated_union(): + v = SchemaValidator( + { + 'type': 'union', + 'choices': [ + { + 'type': 'model', + 'cls': A, + 'schema': { + 'type': 'model-fields', + 'fields': { + 'a': {'type': 'model-field', 'schema': {'type': 'int'}}, + }, + }, + }, + { + 'type': 'model', + 'cls': B, + 'schema': { + 'type': 'model-fields', + 'fields': { + 'a': {'type': 'model-field', 'schema': {'type': 'int'}}, + 'b': {'type': 'model-field', 'schema': {'type': 'int'}}, + }, + }, + }, + ], + } + ) + + assert isinstance(v.validate_python({'a': 1}), A) + assert isinstance(v.validate_python({'a': 1, 'b': 2}), B)