Add a WalkCoreSchema #1099
Codecov Report
Attention: Patch coverage is
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #1099      +/-   ##
==========================================
+ Coverage   89.70%   89.83%   +0.12%
==========================================
  Files         106      107       +1
  Lines       16364    16982     +618
  Branches       35       35
==========================================
+ Hits        14680    15255     +575
- Misses       1677     1720      +43
  Partials        7        7
CodSpeed Performance Report
Merging #1099 will degrade performances by 18.81%
I like the predicate filter idea. Hopefully it'll lead to a smaller implementation that's also easier to maintain if we introduce new schema types in the future. You could even make the predicates user-suppliable in Python, which may help with adopting the implementation (moving more predicates to Rust as needed).
That's an interesting idea. Something like:

VisitPredicate = Callable[[CoreSchema], bool]

that gets run from Rust. Then we write any predicates we need in Python. For example:

from dataclasses import dataclass
from typing import Callable

from pydantic_core import CoreSchema


@dataclass
class CombinedPredicate:
    call: Callable[[CoreSchema], bool]

    def __call__(self, schema):
        return self.call(schema)


class CombinablePredicate:
    def __or__(self, other):
        # Combine two predicates; the result matches if either one matches.
        return CombinedPredicate(lambda s: self(s) or other(s))


class HasRef(CombinablePredicate):
    def __call__(self, schema: CoreSchema) -> bool:
        return bool(schema.get('ref', False))

And once those are stabilized we can move them to Rust. Is that what you had in mind?
@davidhewitt I implemented the filter API as discussed above.
6159bc6 to 6f3abc4
@davidhewitt I benchmarked this and it's coming out no faster than our existing Python version (which calls a Python function at every level in addition to doing the traversal in Python), even when there is no filter (so it never calls into Python):

import timeit
from typing import Any, Callable

from pydantic._internal._core_utils import walk_core_schema
from pydantic_core import CoreSchema, WalkCoreSchema
from pydantic_core import core_schema as cs


def plain_ser_func(x: Any) -> str:
    return 'abc'


def wrap_ser_func(x: Any, handler: cs.SerializerFunctionWrapHandler) -> Any:
    return handler(x)


def no_info_val_func(x: Any) -> Any:
    return x


def no_info_wrap_val_func(x: Any, handler: cs.ValidatorFunctionWrapHandler) -> Any:
    return handler(x)


class NamedClass:
    pass


schema = cs.union_schema(
[
cs.any_schema(serialization=cs.plain_serializer_function_ser_schema(plain_ser_func)),
cs.none_schema(serialization=cs.plain_serializer_function_ser_schema(plain_ser_func)),
cs.bool_schema(serialization=cs.simple_ser_schema('bool')),
cs.int_schema(serialization=cs.simple_ser_schema('int')),
cs.float_schema(serialization=cs.simple_ser_schema('float')),
cs.decimal_schema(serialization=cs.plain_serializer_function_ser_schema(plain_ser_func)),
cs.str_schema(serialization=cs.simple_ser_schema('str')),
cs.bytes_schema(serialization=cs.simple_ser_schema('bytes')),
cs.date_schema(serialization=cs.simple_ser_schema('date')),
cs.time_schema(serialization=cs.simple_ser_schema('time')),
cs.datetime_schema(serialization=cs.simple_ser_schema('datetime')),
cs.timedelta_schema(serialization=cs.simple_ser_schema('timedelta')),
cs.literal_schema(
expected=[1, 2, 3],
serialization=cs.plain_serializer_function_ser_schema(plain_ser_func),
),
cs.is_instance_schema(
cls=NamedClass,
serialization=cs.plain_serializer_function_ser_schema(plain_ser_func),
),
cs.is_subclass_schema(
cls=NamedClass,
serialization=cs.plain_serializer_function_ser_schema(plain_ser_func),
),
cs.callable_schema(
serialization=cs.plain_serializer_function_ser_schema(plain_ser_func),
),
cs.list_schema(
cs.int_schema(serialization=cs.simple_ser_schema('int')),
serialization=cs.plain_serializer_function_ser_schema(plain_ser_func),
),
cs.tuple_positional_schema(
[cs.int_schema(serialization=cs.simple_ser_schema('int'))],
extras_schema=cs.int_schema(serialization=cs.simple_ser_schema('int')),
serialization=cs.plain_serializer_function_ser_schema(plain_ser_func),
),
cs.tuple_variable_schema(
cs.int_schema(serialization=cs.simple_ser_schema('int')),
serialization=cs.plain_serializer_function_ser_schema(plain_ser_func),
),
cs.set_schema(
cs.int_schema(serialization=cs.simple_ser_schema('int')),
serialization=cs.plain_serializer_function_ser_schema(plain_ser_func),
),
cs.frozenset_schema(
cs.int_schema(serialization=cs.simple_ser_schema('int')),
serialization=cs.plain_serializer_function_ser_schema(plain_ser_func),
),
cs.generator_schema(
cs.int_schema(serialization=cs.simple_ser_schema('int')),
serialization=cs.plain_serializer_function_ser_schema(plain_ser_func),
),
cs.dict_schema(
cs.int_schema(serialization=cs.simple_ser_schema('int')),
cs.int_schema(serialization=cs.simple_ser_schema('int')),
serialization=cs.plain_serializer_function_ser_schema(plain_ser_func),
),
cs.no_info_after_validator_function(
no_info_val_func,
cs.int_schema(),
serialization=cs.plain_serializer_function_ser_schema(plain_ser_func),
),
cs.no_info_before_validator_function(
no_info_val_func,
cs.int_schema(),
serialization=cs.plain_serializer_function_ser_schema(plain_ser_func),
),
cs.no_info_wrap_validator_function(
no_info_wrap_val_func,
cs.int_schema(),
serialization=cs.plain_serializer_function_ser_schema(plain_ser_func),
),
cs.no_info_plain_validator_function(
no_info_val_func,
serialization=cs.plain_serializer_function_ser_schema(plain_ser_func),
),
cs.with_default_schema(
cs.int_schema(),
default=1,
serialization=cs.plain_serializer_function_ser_schema(plain_ser_func),
),
cs.nullable_schema(
cs.int_schema(),
serialization=cs.plain_serializer_function_ser_schema(plain_ser_func),
),
cs.union_schema(
[
cs.int_schema(),
cs.str_schema(),
],
serialization=cs.plain_serializer_function_ser_schema(plain_ser_func),
),
cs.tagged_union_schema(
{
'a': cs.int_schema(),
'b': cs.str_schema(),
},
'type',
serialization=cs.plain_serializer_function_ser_schema(plain_ser_func),
),
cs.chain_schema(
[
cs.int_schema(),
cs.str_schema(),
],
serialization=cs.plain_serializer_function_ser_schema(plain_ser_func),
),
cs.lax_or_strict_schema(
cs.int_schema(),
cs.str_schema(),
serialization=cs.plain_serializer_function_ser_schema(plain_ser_func),
),
cs.json_or_python_schema(
cs.int_schema(),
cs.str_schema(),
serialization=cs.plain_serializer_function_ser_schema(plain_ser_func),
),
cs.typed_dict_schema(
{'a': cs.typed_dict_field(cs.int_schema())},
computed_fields=[
cs.computed_field(
'b',
cs.int_schema(),
)
],
extras_schema=cs.int_schema(serialization=cs.simple_ser_schema('int')),
serialization=cs.plain_serializer_function_ser_schema(plain_ser_func),
),
cs.model_schema(
NamedClass,
cs.model_fields_schema(
{'a': cs.model_field(cs.int_schema())},
extras_schema=cs.int_schema(serialization=cs.simple_ser_schema('int')),
computed_fields=[
cs.computed_field(
'b',
cs.int_schema(),
)
],
),
),
cs.dataclass_schema(
NamedClass,
cs.dataclass_args_schema(
'Model',
[cs.dataclass_field('a', cs.int_schema())],
computed_fields=[
cs.computed_field(
'b',
cs.int_schema(),
)
],
),
['a'],
),
cs.call_schema(
cs.arguments_schema(
[cs.arguments_parameter('x', cs.int_schema())],
serialization=cs.plain_serializer_function_ser_schema(plain_ser_func),
),
no_info_val_func,
serialization=cs.plain_serializer_function_ser_schema(plain_ser_func),
),
cs.custom_error_schema(
cs.int_schema(),
custom_error_type='CustomError',
serialization=cs.plain_serializer_function_ser_schema(plain_ser_func),
),
cs.json_schema(
cs.int_schema(),
serialization=cs.plain_serializer_function_ser_schema(plain_ser_func),
),
cs.url_schema(
serialization=cs.plain_serializer_function_ser_schema(plain_ser_func),
),
cs.multi_host_url_schema(
serialization=cs.plain_serializer_function_ser_schema(plain_ser_func),
),
cs.definitions_schema(
cs.int_schema(),
[
cs.int_schema(ref='#/definitions/int'),
],
),
cs.definition_reference_schema(
'#/definitions/int',
serialization=cs.plain_serializer_function_ser_schema(plain_ser_func),
),
cs.uuid_schema(
serialization=cs.plain_serializer_function_ser_schema(plain_ser_func),
),
]
)


def walk_core() -> None:
    # Rust walker from this PR, with no filters, so it never calls into Python.
    WalkCoreSchema().walk(schema)


Recurse = Callable[[cs.CoreSchema, 'Walk'], cs.CoreSchema]
Walk = Callable[[cs.CoreSchema, Recurse], cs.CoreSchema]


def visit_pydantic(schema: cs.CoreSchema, recurse: Recurse) -> CoreSchema:
    return recurse(schema, visit_pydantic)


def walk_pydantic() -> None:
    # Existing pure-Python walker from pydantic.
    walk_core_schema(schema, visit_pydantic)


print(timeit.timeit(walk_core, number=1000))
print(timeit.timeit(walk_pydantic, number=1000))
Let's revisit this once #1085 gets merged, which might improve performance significantly.
Closing as stale, given that we closed #615 above |
@samuelcolvin @dmontagu we discussed this previously and I shot it down because simply reimplementing what we have in pydantic in Rust would not be much faster (aside from the speedup of calling CPython APIs from Rust): for every 2-3 key accesses we do in Rust (which would be faster), we'd be calling into Python and back, so the absolute change may not be very large and there's the FFI slowdown to contend with.
I was thinking about it some more, and I think that if we change the API we have in pydantic to this one, we can get a much larger speedup. Essentially, instead of having a single callback for all schemas, I'm using a different callback for each schema type. This serves as a sort of "filter" to minimize calls into Python. Out of the ~3 "walks" we do in pydantic this covers two:
(there are some others for discriminated unions and such, I haven't looked into those)
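
To illustrate the "different callback per schema type" idea, here is a minimal pure-Python sketch. It is not the Rust implementation from this PR: `children` and `walk_by_type` are hypothetical names, schemas are plain dicts, and child discovery is simplified (it only looks at direct dict values and list items that carry a 'type' key). The point is only that the walker invokes a callback when one is registered for the node's type, so most nodes never cross into user code.

```python
from typing import Any, Callable, Dict, List

Schema = Dict[str, Any]
Callback = Callable[[Schema], None]


def children(schema: Schema) -> List[Schema]:
    """Return nested schemas: dict values or list items that have a 'type' key.

    Simplified assumption: real core schemas also nest schemas in other
    shapes (e.g. dicts of fields), which this sketch ignores.
    """
    found: List[Schema] = []
    for value in schema.values():
        candidates = value if isinstance(value, list) else [value]
        for item in candidates:
            if isinstance(item, dict) and 'type' in item:
                found.append(item)
    return found


def walk_by_type(schema: Schema, callbacks: Dict[str, Callback]) -> int:
    """Visit every node, but call back only for types with a registered callback.

    Returns the number of callback invocations.
    """
    calls = 0
    callback = callbacks.get(schema['type'])
    if callback is not None:
        callback(schema)
        calls += 1
    for child in children(schema):
        calls += walk_by_type(child, callbacks)
    return calls


example = {
    'type': 'union',
    'choices': [
        {'type': 'int'},
        {'type': 'list', 'items_schema': {'type': 'int'}},
        {'type': 'str'},
    ],
}
seen: List[Schema] = []
calls = walk_by_type(example, {'int': seen.append})
print(calls)  # 2 callback invocations for a tree of 5 nodes
```

With a single catch-all callback the same walk would call back five times, once per node, which is the overhead the per-type filter is meant to avoid.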
However, this does not cover the case where we need to visit every schema:
https://github.com/pydantic/pydantic/blob/667cd3776ee40e06018d0b7ff477c6cd0199b098/pydantic/_internal/_core_utils.py#L449-L450
For that last case I see a couple of options:
- A visit_all_schemas callback that slows things down significantly but allows visiting all schemas (and hence collecting all refs).
- A visit_schema_with_ref callback that gets called for any schema with a ref. This seems somewhat reasonable, but it may be a bit too "specialized" of an implementation for our current use case. That is, it's a bandaid solution to a poor API.
- A predicate-based API, e.g. Walk(visit=[if_schema_has_key("ref")(callback), if_schema_has_type("int")(callback), (if_schema_has_type("int") & if_schema_has_key("ref"))(callback)]). This may also let us get rid of the dozens of constructor arguments this implementation currently has.
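
A rough pure-Python sketch of what the predicate-based option could look like. The names `if_schema_has_key`, `if_schema_has_type` and `Walk(visit=[...])` come from the description above; everything else (`Predicate`, `Visitor`, the dict-based schemas, and the simplified child traversal) is an assumption made for illustration and not the API implemented in this PR.

```python
from dataclasses import dataclass
from typing import Any, Callable, Dict, List

Schema = Dict[str, Any]


@dataclass(frozen=True)
class Visitor:
    """A predicate paired with the callback to run when it matches."""
    predicate: 'Predicate'
    callback: Callable[[Schema], None]


@dataclass(frozen=True)
class Predicate:
    test: Callable[[Schema], bool]

    def __and__(self, other: 'Predicate') -> 'Predicate':
        return Predicate(lambda s: self.test(s) and other.test(s))

    def __or__(self, other: 'Predicate') -> 'Predicate':
        return Predicate(lambda s: self.test(s) or other.test(s))

    def __call__(self, callback: Callable[[Schema], None]) -> Visitor:
        # predicate(callback) binds the callback to this filter.
        return Visitor(self, callback)


def if_schema_has_key(key: str) -> Predicate:
    return Predicate(lambda s: key in s)


def if_schema_has_type(type_name: str) -> Predicate:
    return Predicate(lambda s: s.get('type') == type_name)


class Walk:
    def __init__(self, visit: List[Visitor]) -> None:
        self.visit = visit

    def walk(self, schema: Schema) -> None:
        # Run every visitor whose predicate matches this node.
        for visitor in self.visit:
            if visitor.predicate.test(schema):
                visitor.callback(schema)
        # Simplified traversal: recurse into dict values or list items
        # that look like schemas (i.e. have a 'type' key).
        for value in list(schema.values()):
            items = value if isinstance(value, list) else [value]
            for item in items:
                if isinstance(item, dict) and 'type' in item:
                    self.walk(item)


refs: List[str] = []
int_refs: List[str] = []
walker = Walk(visit=[
    if_schema_has_key('ref')(lambda s: refs.append(s['ref'])),
    (if_schema_has_type('int') & if_schema_has_key('ref'))(
        lambda s: int_refs.append(s['ref'])
    ),
])
walker.walk({
    'type': 'list',
    'ref': 'outer',
    'items_schema': {'type': 'int', 'ref': 'inner'},
})
print(refs, int_refs)
```

In this sketch the predicates are still Python callables, so every node costs at least one call into Python. The speedup discussed above would presumably require the common predicates (key and type checks) to be evaluated on the Rust side, with only matching nodes handed to the Python callback.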