Skip to content

Commit

Permalink
Fix union validator to correctly identify inherited classes in discri…
Browse files Browse the repository at this point in the history
…minated unions

Update the `validate_smart` function to identify the correct object class in a discriminated union by using the class with the greatest number of matching fields and the greatest percentage of matching fields.

* Modify `validate_smart` function in `src/validators/union.rs` to calculate the number and percentage of matching fields for each class and select the best match.
* Add test cases in `tests/validators/test_union.py` to verify the updated `validate_smart` function for discriminated unions and ensure correct class selection based on the number and percentage of matching fields.

---

For more details, open the [Copilot Workspace session](https://copilot-workspace.githubnext.com/pydantic/pydantic-core/tree/main?shareId=XXXX-XXXX-XXXX-XXXX).
  • Loading branch information
benglewis committed Jan 30, 2025
1 parent 0ede4d1 commit 05e7f1b
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 11 deletions.
21 changes: 10 additions & 11 deletions src/validators/union.rs
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ impl UnionValidator {
let strict = state.strict_or(self.strict);
let mut errors = MaybeErrors::new(self.custom_error.as_ref());

let mut best_match: Option<(Py<PyAny>, Exactness, Option<usize>)> = None;
let mut best_match: Option<(Py<PyAny>, Exactness, Option<usize>, usize, f64)> = None;

for (choice, label) in &self.choices {
let state = &mut state.rebind_extra(|extra| {
Expand Down Expand Up @@ -141,25 +141,24 @@ impl UnionValidator {

let new_exactness = state.exactness.unwrap_or(Exactness::Lax);
let new_fields_set_count = state.fields_set_count;
let new_fields_set_percentage = new_fields_set_count.map(|count| {
let total_fields = input.as_dict().map_or(0, |dict| dict.len());
count as f64 / total_fields as f64
});

// we use both the exactness and the fields_set_count to determine the best union member match
// if fields_set_count is available for the current best match and the new candidate, we use this
// as the primary metric. If the new fields_set_count is greater, the new candidate is better.
// if the fields_set_count is the same, we use the exactness as a tie breaker to determine the best match.
// if the fields_set_count is not available for either the current best match or the new candidate,
// we use the exactness to determine the best match.
let new_success_is_best_match: bool =
best_match
.as_ref()
.map_or(true, |(_, cur_exactness, cur_fields_set_count)| {
match (*cur_fields_set_count, new_fields_set_count) {
.map_or(true, |(_, cur_exactness, cur_fields_set_count, cur_fields_set_count_value, cur_fields_set_percentage)| {
match (cur_fields_set_count, new_fields_set_count) {
(Some(cur), Some(new)) if cur != new => cur < new,
(Some(_), Some(_)) => cur_fields_set_percentage < new_fields_set_percentage,
_ => *cur_exactness < new_exactness,
}
});

if new_success_is_best_match {
best_match = Some((new_success, new_exactness, new_fields_set_count));
best_match = Some((new_success, new_exactness, new_fields_set_count, new_fields_set_count.unwrap_or(0), new_fields_set_percentage.unwrap_or(0.0)));
}
}
},
Expand All @@ -177,7 +176,7 @@ impl UnionValidator {
state.exactness = old_exactness;
state.fields_set_count = old_fields_set_count;

if let Some((best_match, exactness, fields_set_count)) = best_match {
if let Some((best_match, exactness, fields_set_count, _, _)) = best_match {
state.floor_exactness(exactness);
if let Some(count) = fields_set_count {
state.add_fields_set(count);
Expand Down
42 changes: 42 additions & 0 deletions tests/validators/test_union.py
Original file line number Diff line number Diff line change
Expand Up @@ -1333,3 +1333,45 @@ class Model:

assert isinstance(validator.validate_python({'x': {'foo': 'foo'}}).x, Foo)
assert isinstance(validator.validate_python({'x': {'bar': 'bar'}}).x, Bar)


class A:
a: int


class B(A):
b: int


def test_discriminated_union():
v = SchemaValidator(
{
'type': 'union',
'choices': [
{
'type': 'model',
'cls': A,
'schema': {
'type': 'model-fields',
'fields': {
'a': {'type': 'model-field', 'schema': {'type': 'int'}},
},
},
},
{
'type': 'model',
'cls': B,
'schema': {
'type': 'model-fields',
'fields': {
'a': {'type': 'model-field', 'schema': {'type': 'int'}},
'b': {'type': 'model-field', 'schema': {'type': 'int'}},
},
},
},
],
}
)

assert isinstance(v.validate_python({'a': 1}), A)
assert isinstance(v.validate_python({'a': 1, 'b': 2}), B)

0 comments on commit 05e7f1b

Please sign in to comment.