-
Notifications
You must be signed in to change notification settings - Fork 262
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fix union validator to correctly identify inherited classes in discriminated unions #1613
base: main
Are you sure you want to change the base?
Conversation
…minated unions Update the `validate_smart` function to identify the correct object class in a discriminated union by using the class with the greatest number of matching fields and the greatest percentage of matching fields. * Modify `validate_smart` function in `src/validators/union.rs` to calculate the number and percentage of matching fields for each class and select the best match. * Add test cases in `tests/validators/test_union.py` to verify the updated `validate_smart` function for discriminated unions and ensure correct class selection based on the number and percentage of matching fields. --- For more details, open the [Copilot Workspace session](https://copilot-workspace.githubnext.com/pydantic/pydantic-core/tree/main?shareId=XXXX-XXXX-XXXX-XXXX).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This seems like a reasonable heuristic to me when picking what the best match should be, thanks.
cc @dmontagu we've talked about this kind of idea long ago, this seems like a reasonable small iteration towards a good result 👍
@@ -113,7 +113,7 @@ impl UnionValidator { | |||
let strict = state.strict_or(self.strict); | |||
let mut errors = MaybeErrors::new(self.custom_error.as_ref()); | |||
|
|||
let mut best_match: Option<(Py<PyAny>, Exactness, Option<usize>)> = None; | |||
let mut best_match: Option<(Py<PyAny>, Exactness, Option<usize>, usize, f64)> = None; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think now that we have so many fields here, we should create a struct
for these so we can give them names.
@@ -141,25 +141,24 @@ impl UnionValidator { | |||
|
|||
let new_exactness = state.exactness.unwrap_or(Exactness::Lax); | |||
let new_fields_set_count = state.fields_set_count; | |||
let new_fields_set_percentage = new_fields_set_count.map(|count| { | |||
let total_fields = input.as_dict().map_or(0, |dict| dict.len()); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This doesn't depend on the validator so can be done once at the start of the function (i.e. above the self.choices
loop.)
@@ -141,25 +141,24 @@ impl UnionValidator { | |||
|
|||
let new_exactness = state.exactness.unwrap_or(Exactness::Lax); | |||
let new_fields_set_count = state.fields_set_count; | |||
let new_fields_set_percentage = new_fields_set_count.map(|count| { | |||
let total_fields = input.as_dict().map_or(0, |dict| dict.len()); | |||
count as f64 / total_fields as f64 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
To avoid a divide by zero here I think that we should keep total_fields
as Option<usize>
. For example consider an int
or str
input, these have no fields.
Update the
validate_smart
function to identify the correct object class in a discriminated union by using the class with the greatest number of matching fields and the greatest percentage of matching fields.validate_smart
function insrc/validators/union.rs
to calculate the number and percentage of matching fields for each class and select the best match.tests/validators/test_union.py
to verify the updatedvalidate_smart
function for discriminated unions and ensure correct class selection based on the number and percentage of matching fields.For more details, open the Copilot Workspace session.