Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix union validator to correctly identify inherited classes in discriminated unions #1613

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 10 additions & 11 deletions src/validators/union.rs
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ impl UnionValidator {
let strict = state.strict_or(self.strict);
let mut errors = MaybeErrors::new(self.custom_error.as_ref());

let mut best_match: Option<(Py<PyAny>, Exactness, Option<usize>)> = None;
let mut best_match: Option<(Py<PyAny>, Exactness, Option<usize>, usize, f64)> = None;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think now that we have so many fields here, we should create a struct for these so we can give them names.


for (choice, label) in &self.choices {
let state = &mut state.rebind_extra(|extra| {
Expand Down Expand Up @@ -141,25 +141,24 @@ impl UnionValidator {

let new_exactness = state.exactness.unwrap_or(Exactness::Lax);
let new_fields_set_count = state.fields_set_count;
let new_fields_set_percentage = new_fields_set_count.map(|count| {
let total_fields = input.as_dict().map_or(0, |dict| dict.len());
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This doesn't depend on the validator so can be done once at the start of the function (i.e. above the self.choices loop.)

count as f64 / total_fields as f64
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To avoid a divide by zero here I think that we should keep total_fields as Option<usize>. For example consider an int or str input, these have no fields.

});

// we use both the exactness and the fields_set_count to determine the best union member match
// if fields_set_count is available for the current best match and the new candidate, we use this
// as the primary metric. If the new fields_set_count is greater, the new candidate is better.
// if the fields_set_count is the same, we use the exactness as a tie breaker to determine the best match.
// if the fields_set_count is not available for either the current best match or the new candidate,
// we use the exactness to determine the best match.
let new_success_is_best_match: bool =
best_match
.as_ref()
.map_or(true, |(_, cur_exactness, cur_fields_set_count)| {
match (*cur_fields_set_count, new_fields_set_count) {
.map_or(true, |(_, cur_exactness, cur_fields_set_count, cur_fields_set_count_value, cur_fields_set_percentage)| {
match (cur_fields_set_count, new_fields_set_count) {
(Some(cur), Some(new)) if cur != new => cur < new,
(Some(_), Some(_)) => cur_fields_set_percentage < new_fields_set_percentage,
_ => *cur_exactness < new_exactness,
}
});

if new_success_is_best_match {
best_match = Some((new_success, new_exactness, new_fields_set_count));
best_match = Some((new_success, new_exactness, new_fields_set_count, new_fields_set_count.unwrap_or(0), new_fields_set_percentage.unwrap_or(0.0)));
}
}
},
Expand All @@ -177,7 +176,7 @@ impl UnionValidator {
state.exactness = old_exactness;
state.fields_set_count = old_fields_set_count;

if let Some((best_match, exactness, fields_set_count)) = best_match {
if let Some((best_match, exactness, fields_set_count, _, _)) = best_match {
state.floor_exactness(exactness);
if let Some(count) = fields_set_count {
state.add_fields_set(count);
Expand Down
42 changes: 42 additions & 0 deletions tests/validators/test_union.py
Original file line number Diff line number Diff line change
Expand Up @@ -1333,3 +1333,45 @@ class Model:

assert isinstance(validator.validate_python({'x': {'foo': 'foo'}}).x, Foo)
assert isinstance(validator.validate_python({'x': {'bar': 'bar'}}).x, Bar)


class A:
a: int


class B(A):
b: int


def test_discriminated_union():
v = SchemaValidator(
{
'type': 'union',
'choices': [
{
'type': 'model',
'cls': A,
'schema': {
'type': 'model-fields',
'fields': {
'a': {'type': 'model-field', 'schema': {'type': 'int'}},
},
},
},
{
'type': 'model',
'cls': B,
'schema': {
'type': 'model-fields',
'fields': {
'a': {'type': 'model-field', 'schema': {'type': 'int'}},
'b': {'type': 'model-field', 'schema': {'type': 'int'}},
},
},
},
],
}
)

assert isinstance(v.validate_python({'a': 1}), A)
assert isinstance(v.validate_python({'a': 1, 'b': 2}), B)
Loading