diff --git a/test/export/test_export.py b/test/export/test_export.py index 78c56878f76fe..fb5f55d0389a5 100755 --- a/test/export/test_export.py +++ b/test/export/test_export.py @@ -8157,6 +8157,7 @@ def forward(self, x): export(f, (inputs,), dynamic_shapes=dynamic_shapes) @testing.expectedFailureRetraceabilityNonStrict + @testing.expectedFailureCppSerDes # dynamic shape serialization def test_disable_forced_specializations_ok(self): # check that we don't force specialization, and defer to runtime asserts # with allow_complex_guards_as_runtime_asserts=True to successfully export @@ -9210,7 +9211,9 @@ def forward(self, input1: torch.Tensor): inps = (torch.randn(1, 224, 768, device="cpu"),) export(Foo(), inps) - @testing.expectedFailureCppSerDes + @testing.expectedFailureSerDer # TODO(pianpwk): PowByNatural valuerange deserialization + @testing.expectedFailureCppSerDes # TODO(pianpwk): PowByNatural valuerange deserialization + @testing.expectedFailureSerDerNonStrict @testing.expectedFailureRetraceabilityNonStrict def test_dim_dynamic(self): dynamic = Dim.DYNAMIC diff --git a/torch/_export/serde/schema.py b/torch/_export/serde/schema.py index dbd5b401d7401..025cd6b0f3010 100644 --- a/torch/_export/serde/schema.py +++ b/torch/_export/serde/schema.py @@ -8,7 +8,7 @@ from torch._export.serde.union import _Union # NOTE: Please update this value if any modifications are made to the schema -SCHEMA_VERSION = (8, 2) +SCHEMA_VERSION = (8, 1) TREESPEC_VERSION = 1 @@ -62,42 +62,13 @@ class SymExprHint(_Union): as_bool: bool -# A leaf node in a SymExprNode, containing a bool/float/int/sympy.Symbol. -@dataclass(repr=False) -class SymBase(_Union): - as_bool: bool - as_float: float - as_int: int - as_symbol: str - - -# Represents an AST node in a sympy.Expr. -# If not a leaf node, "target" is a string representing the operator, -# and "args" is a list of child SymExprNodes. -# If a leaf node, "target" is None, "args" is empty, and "base" is a SymBase -# representing the leaf value. -@dataclass(repr=False) -class SymExprNode: - args: List["SymExprNode"] = field(default_factory=list) - target: Optional[str] = None - base: Optional[SymBase] = None - - # This is for storing the symbolic expressions behind symints/symfloats/symbools -# The deprecated "expr_str" field is easier to explain; we could store +# For example, we can get something like # SymExpr(expr_str="s0 + s1", hint=SymExprHint(as_int=4) -# for an expression where s0 == s1 == 2. -# We're moving away from expr_str for roundtrippability, and now deserialize into -# the "expr_ast" field, which is a tree representation of the expression, -# containing a root SymExprNode. -# While we're deprecating this, we'll store an empty string in "expr_str" for now, -# and support deserialization for both "expr_str" and "expr_ast" fields. -# We'll only serialize to "expr_ast". -# TODO(pianpwk): implement upgrader & delete. +# if we also have the hint that s0 and s1 are both 2. @dataclass class SymExpr: expr_str: str - expr_ast: Optional[SymExprNode] = None hint: Optional[SymExprHint] = None diff --git a/torch/_export/serde/schema.yaml b/torch/_export/serde/schema.yaml index ba0324a8c92a2..57de5d6fb689c 100644 --- a/torch/_export/serde/schema.yaml +++ b/torch/_export/serde/schema.yaml @@ -1,5 +1,5 @@ # @generated by update_schema.py -# checksum<<74ae9f550efb42873fc58f07990cefe4da35d38c673beeacbe1e20808d9a6962>> +# checksum<<8e27d48014d4ec1c773aef056c0c20b61bead54be8338e95d3347d3422472b9a>> Argument: kind: union fields: @@ -346,17 +346,6 @@ SchemaVersion: type: int minor: type: int -SymBase: - kind: union - fields: - as_bool: - type: bool - as_float: - type: float - as_int: - type: int - as_symbol: - type: str SymBool: kind: union fields: @@ -376,9 +365,6 @@ SymExpr: fields: expr_str: type: str - expr_ast: - type: Optional[SymExprNode] - default: None hint: type: Optional[SymExprHint] default: None @@ -391,18 +377,6 @@ SymExprHint: type: float as_bool: type: bool -SymExprNode: - kind: struct - fields: - args: - type: List[SymExprNode] - default: '[]' - target: - type: Optional[str] - default: None - base: - type: Optional[SymBase] - default: None SymInt: kind: union fields: @@ -463,5 +437,5 @@ UserOutputSpec: type: Argument SCHEMA_VERSION: - 8 -- 2 +- 1 TREESPEC_VERSION: 1 diff --git a/torch/_export/serde/schema_check.py b/torch/_export/serde/schema_check.py index 2705c5ec0c042..b71c02829ceba 100644 --- a/torch/_export/serde/schema_check.py +++ b/torch/_export/serde/schema_check.py @@ -5,7 +5,7 @@ import re import typing from enum import IntEnum -from typing import Any, Dict, ForwardRef, List, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Tuple, Union from torch._export.serde import schema from torch._export.serde.union import _Union @@ -71,8 +71,6 @@ def dump_type(t) -> Tuple[str, str]: ) elif t == (): return "()", "" - elif isinstance(t, ForwardRef): - return t.__forward_arg__, f"ForwardRef<{t.__forward_arg__}>" else: raise AssertionError(f"Type {t} is not supported in export schema.") diff --git a/torch/_export/serde/serialize.py b/torch/_export/serde/serialize.py index 68408ea19a5a0..6e25f8f89d1ef 100644 --- a/torch/_export/serde/serialize.py +++ b/torch/_export/serde/serialize.py @@ -12,7 +12,6 @@ import math import operator import re -import sys import typing import traceback @@ -82,12 +81,10 @@ RangeConstraint, ScalarType, SCHEMA_VERSION, - SymBase, SymBool, SymBoolArgument, SymExpr, SymExprHint, - SymExprNode, SymInt, SymIntArgument, TensorArgument, @@ -228,50 +225,17 @@ def deserialize_device(d: Device) -> torch.device: return torch.device(type=d.type, index=d.index) -def serialize_sym_expr(s: Union[torch.SymInt, sympy.Symbol]) -> SymExprNode: - base_types = ( - sympy.Symbol, - sympy.logic.boolalg.BooleanAtom, - sympy.Integer, - sympy.Float, - ) - expr = ( - s.node.expr - if isinstance(s, (torch.SymBool, torch.SymFloat, torch.SymInt)) - else s - ) - if isinstance(expr, base_types): - if isinstance(expr, sympy.Symbol): - base = SymBase.create(as_symbol=str(expr)) - elif isinstance(expr, sympy.logic.boolalg.Boolean): - base = SymBase.create(as_bool=bool(expr)) - elif isinstance(expr, sympy.Integer): - base = SymBase.create(as_int=int(expr)) - elif isinstance(expr, sympy.Float): - base = SymBase.create(as_float=float(expr)) - else: - raise SerializeError("how did we get here?") - return SymExprNode(target=None, base=base, args=[]) - else: - return SymExprNode( - target=type(expr).__module__ + "." + type(expr).__qualname__, - base=None, - args=[serialize_sym_expr(_s) for _s in expr.args], - ) - - def serialize_sym_int(s: Union[int, torch.SymInt]) -> SymInt: if isinstance(s, (torch.SymInt, sympy.Symbol, int)): if symbolic_shapes.is_concrete_int(s): return SymInt.create(as_int=int(s)) else: assert isinstance(s, (torch.SymInt, sympy.Symbol)) - sym_expr = serialize_sym_expr(s) if s.node.hint is None: - return SymInt.create(as_expr=SymExpr(expr_str="", expr_ast=sym_expr)) # type: ignore[arg-type] + return SymInt.create(as_expr=SymExpr(str(s))) else: return SymInt.create( - as_expr=SymExpr(expr_str="", expr_ast=sym_expr, hint=SymExprHint.create(as_int=s.node.hint)) # type: ignore[arg-type] + as_expr=SymExpr(str(s), hint=SymExprHint.create(as_int=s.node.hint)) ) else: raise SerializeError( @@ -284,7 +248,7 @@ def serialize_sym_bool(s: Union[bool, torch.SymBool]) -> SymBool: if symbolic_shapes.is_concrete_bool(s): return SymBool.create(as_bool=bool(s)) else: - return SymBool.create(as_expr=SymExpr(expr_str="", expr_ast=serialize_sym_expr(s))) # type: ignore[arg-type] + return SymBool.create(as_expr=SymExpr(expr_str=str(s))) else: raise SerializeError( f"SymBool should be either symbol or bool, got `{s}` of type `{type(s)}`" @@ -1523,120 +1487,6 @@ def deserialize_operator(self, serialized_target: str): target = getattr(target, name) return target - def deserialize_sym_expr( - self, s: SymExprNode, hint: Optional[int] = None, force_int: Optional[bool] = True - ) -> Union[sympy.Expr]: - if (base := s.base) is not None: - if base.type == "as_bool": - assert isinstance(base.value, bool) - return base.value - elif base.type == "as_float": - assert isinstance(base.value, float) - return base.value - elif base.type == "as_int": - assert isinstance(base.value, int) - return base.value - else: - assert base.type == "as_symbol" - assert isinstance(base.value, str) - if base.value in self.symbol_name_to_symbol: - s = self.symbol_name_to_symbol[base.value] - else: - s = sympy.Symbol(base.value, integer=force_int) - self.symbol_name_to_symbol[base.value] = s - assert isinstance(s, sympy.Symbol) - if ( - hint is not None - and s not in self.shape_env.var_to_val - ): - self.shape_env.add_var_to_val(s, hint) - if vr := self.symbol_name_to_range.get(base.value): - self.shape_env.constrain_symbol_range( - s, - compiler_min=vr.lower, # type: ignore[arg-type] - compiler_max=vr.upper, # type: ignore[arg-type] - ) - return s - else: - if s.target is None: - raise SerializeError( - f"Expected SymExprNode.target to not be None, but found None for {s}" - ) - try: - *_module, _qualname = s.target.split(".") - node_cls = getattr(sys.modules[".".join(_module)], _qualname) - except Exception as exc: - raise SerializeError( - f"Error importing SymExprNode.target with type: {s.target}" - ) from exc - expr = node_cls(*[self.deserialize_sym_expr(arg, force_int=force_int) for arg in s.args]) - expr_str = str(expr) - if expr_str not in self.symbol_name_to_symbol: - self.symbol_name_to_symbol[expr_str] = expr - if vr := self.symbol_name_to_range.get(expr_str): - self.shape_env.constrain_symbol_range( - expr, - compiler_min=vr.lower, # type: ignore[arg-type] - compiler_max=vr.upper, # type: ignore[arg-type] - ) - return expr - - def __deprecated_do_not_use_deserialize_symint_expr_str(self, s: str, hint: Optional[int] = None) -> sympy.Expr: - """ - Old logic for deserializing sympy.Exprs stored as strings, where we would call sympy.sympify(), which does not - provide roundtrippability guarantees. We've switched to storing ASTs instead, but are keeping this alive for BC. - TODO(pianpwk): implement upgrader & delete this, along with "expr_str" field. - """ - if s in self.symbol_name_to_symbol: - sym = self.symbol_name_to_symbol[s] - else: - sym = sympy.sympify( - s, - locals={**self.sympy_functions, **self.symbol_name_to_symbol}, - ) - # NOTE(avik): Assumptions on symbols are not explicitly serialized. - # This seems dangerous: it might cause unknown differences in shape env behavior - # on deserialization? Probably deserves a follow-up. - - # Here we force symbols corresponding to SymInts to be at least integers. - # Otherwise some expressions that the shape env would otherwise evaluate to False, - # e.g., 2*s = 9, can have rational solutions, e.g., 9/2. - # TODO: This is HIGHLY SUSPICIOUS ezyang(May 2024) - sym = sym.subs( - {s: sympy.Symbol(s.name, integer=True) for s in sym.free_symbols} - ) - # We need to check if the symbol has already been allocated, - # self.symbol_name_to_symbol is not enough because the - # integer-ification of symbols can induce simplification; - # e.g., (2**s0 + 1) // 2 --> s0 when we know s0 is integral - if isinstance(sym, sympy.Symbol) and sym not in self.shape_env.var_to_val: - self.symbol_name_to_symbol[s] = sym - if hint is not None: - self.shape_env.add_var_to_val(sym, hint) - - if vr := self.symbol_name_to_range.get(s): - self.shape_env.constrain_symbol_range( - sym, - compiler_min=vr.lower, # type: ignore[arg-type] - compiler_max=vr.upper, # type: ignore[arg-type] - ) - else: - # Placeholders, in particular, can have shapes as symbolic expressions. - # We need to populate the shape env with the range constraints of their - # free symbols, otherwise evaluating such expressions will error. - self.symbol_name_to_symbol[s] = sym - free_symbols = sym.free_symbols - for s in free_symbols: - if s.name not in self.symbol_name_to_symbol: # type: ignore[attr-defined] - self.symbol_name_to_symbol[s.name] = s # type: ignore[attr-defined, assignment] - if vr := self.symbol_name_to_range.get(s.name): # type: ignore[attr-defined] - self.shape_env.constrain_symbol_range( - s, - compiler_min=vr.lower, # type: ignore[arg-type] - compiler_max=vr.upper, # type: ignore[arg-type] - ) - return sym - def deserialize_sym_int(self, s: SymInt) -> Union[int, torch.SymInt]: val = s.value if s.type == "as_expr": @@ -1645,10 +1495,56 @@ def deserialize_sym_int(self, s: SymInt) -> Union[int, torch.SymInt]: else: assert val.hint.type == "as_int" hint = val.hint.value - if val.expr_str != "": - sym = self.__deprecated_do_not_use_deserialize_symint_expr_str(val.expr_str, hint=hint) + + if val.expr_str in self.symbol_name_to_symbol: + sym = self.symbol_name_to_symbol[val.expr_str] else: - sym = self.deserialize_sym_expr(val.expr_ast, hint=hint, force_int=True) + sym = sympy.sympify( + val.expr_str, + locals={**self.sympy_functions, **self.symbol_name_to_symbol}, + ) + # NOTE(avik): Assumptions on symbols are not explicitly serialized. + # This seems dangerous: it might cause unknown differences in shape env behavior + # on deserialization? Probably deserves a follow-up. + + # Here we force symbols corresponding to SymInts to be at least integers. + # Otherwise some expressions that the shape env would otherwise evaluate to False, + # e.g., 2*s = 9, can have rational solutions, e.g., 9/2. + # TODO: This is HIGHLY SUSPICIOUS ezyang(May 2024) + sym = sym.subs( + {s: sympy.Symbol(s.name, integer=True) for s in sym.free_symbols} + ) + # We need to check if the symbol has already been allocated, + # self.symbol_name_to_symbol is not enough because the + # integer-ification of symbols can induce simplification; + # e.g., (2**s0 + 1) // 2 --> s0 when we know s0 is integral + if isinstance(sym, sympy.Symbol) and sym not in self.shape_env.var_to_val: + self.symbol_name_to_symbol[val.expr_str] = sym + if hint is not None: + self.shape_env.add_var_to_val(sym, hint) + + if vr := self.symbol_name_to_range.get(val.expr_str): + self.shape_env.constrain_symbol_range( + sym, + compiler_min=vr.lower, # type: ignore[arg-type] + compiler_max=vr.upper, # type: ignore[arg-type] + ) + else: + # Placeholders, in particular, can have shapes as symbolic expressions. + # We need to populate the shape env with the range constraints of their + # free symbols, otherwise evaluating such expressions will error. + self.symbol_name_to_symbol[val.expr_str] = sym + free_symbols = sym.free_symbols + for s in free_symbols: + if s.name not in self.symbol_name_to_symbol: + self.symbol_name_to_symbol[s.name] = s # type: ignore[assignment] + if vr := self.symbol_name_to_range.get(s.name): + self.shape_env.constrain_symbol_range( + s, + compiler_min=vr.lower, # type: ignore[arg-type] + compiler_max=vr.upper, # type: ignore[arg-type] + ) + return self.shape_env.create_symintnode(sym, hint=hint) elif s.type == "as_int": assert type(val) is int @@ -1658,29 +1554,19 @@ def deserialize_sym_int(self, s: SymInt) -> Union[int, torch.SymInt]: f"SymInt has invalid field type {s.type} with value {s.value}" ) - def __deprecated_do_not_use_deserialize_symbool_expr_str(self, s: str) -> sympy.Expr: - """ - See docstring in __deprecated_do_not_use_deserialize_symint_expr_str. - TODO(pianpwk): implement upgrader & delete this, along with "expr_str" field. - """ - # first we sympify this just to access any untracked symbols - expr = sympy.sympify(s) - for sym in expr.free_symbols: - if ( - not isinstance(sym, sympy.Number) - and str(sym) not in self.symbol_name_to_symbol - ): - self.deserialize_sym_int(SymInt.create(as_expr=SymExpr(expr_str=str(sym)))) - # then we sympify again using locals to correctly reify with the constructed symbols - return sympy.sympify(s, locals=self.symbol_name_to_symbol) - def deserialize_sym_bool(self, s: SymBool) -> Union[bool, torch.SymBool]: val = s.value if s.type == "as_expr": - if val.expr_str != "": - expr = self.__deprecated_do_not_use_deserialize_symbool_expr_str(val.expr_str) - else: - expr = self.deserialize_sym_expr(val.expr_ast, force_int=True) + # first we sympify this just to access any untracked symbols + expr = sympy.sympify(val.expr_str) + for sym in expr.free_symbols: + if ( + not isinstance(sym, sympy.Number) + and str(sym) not in self.symbol_name_to_symbol + ): + self.deserialize_sym_int(SymInt.create(as_expr=SymExpr(str(sym)))) + # then we sympify again using locals to correctly reify with the constructed symbols + expr = sympy.sympify(val.expr_str, locals=self.symbol_name_to_symbol) return self.shape_env.create_symboolnode(expr) elif s.type == "as_bool": assert isinstance(val, bool) diff --git a/torch/csrc/utils/generated_serialization_types.h b/torch/csrc/utils/generated_serialization_types.h index 21cab4cc92b42..934a73022ba1f 100644 --- a/torch/csrc/utils/generated_serialization_types.h +++ b/torch/csrc/utils/generated_serialization_types.h @@ -1,5 +1,5 @@ // @generated by update_schema.py -// checksum<<74ae9f550efb42873fc58f07990cefe4da35d38c673beeacbe1e20808d9a6962>> +// checksum<<8e27d48014d4ec1c773aef056c0c20b61bead54be8338e95d3347d3422472b9a>> // clang-format off #pragma once @@ -114,12 +114,10 @@ class OutputSpec; class OutputTokenSpec; class RangeConstraint; class SchemaVersion; -class SymBase; class SymBool; class SymBoolArgument; class SymExpr; class SymExprHint; -class SymExprNode; class SymInt; class SymIntArgument; class TensorArgument; @@ -251,112 +249,9 @@ class SymExprHint { } }; -class SymBase { - struct Void {}; - - public: - enum class Tag { - AS_BOOL, AS_FLOAT, AS_INT, AS_SYMBOL - }; - - private: - std::variant variant_; - Tag tag_; - - public: - Tag tag() const { - return tag_; - } - - const bool& get_as_bool() const { - return std::get<1>(variant_); - } - - const double& get_as_float() const { - return std::get<2>(variant_); - } - - const int64_t& get_as_int() const { - return std::get<3>(variant_); - } - - const std::string& get_as_symbol() const { - return std::get<4>(variant_); - } - - friend void to_json(nlohmann::json& nlohmann_json_j, const SymBase& nlohmann_json_t) { - - if (nlohmann_json_t.tag_ == Tag::AS_BOOL) { - nlohmann_json_j["as_bool"] = nlohmann_json_t.get_as_bool(); - return; - } - if (nlohmann_json_t.tag_ == Tag::AS_FLOAT) { - nlohmann_json_j["as_float"] = nlohmann_json_t.get_as_float(); - return; - } - if (nlohmann_json_t.tag_ == Tag::AS_INT) { - nlohmann_json_j["as_int"] = nlohmann_json_t.get_as_int(); - return; - } - if (nlohmann_json_t.tag_ == Tag::AS_SYMBOL) { - nlohmann_json_j["as_symbol"] = nlohmann_json_t.get_as_symbol(); - return; - } - } - - friend void from_json(const nlohmann::json& nlohmann_json_j, SymBase& nlohmann_json_t) { - - if (nlohmann_json_j.contains("as_bool")) { - nlohmann_json_t.variant_.emplace<1>(nlohmann_json_j.at("as_bool").template get()); - nlohmann_json_t.tag_ = Tag::AS_BOOL; - return; - } - if (nlohmann_json_j.contains("as_float")) { - nlohmann_json_t.variant_.emplace<2>(nlohmann_json_j.at("as_float").template get()); - nlohmann_json_t.tag_ = Tag::AS_FLOAT; - return; - } - if (nlohmann_json_j.contains("as_int")) { - nlohmann_json_t.variant_.emplace<3>(nlohmann_json_j.at("as_int").template get()); - nlohmann_json_t.tag_ = Tag::AS_INT; - return; - } - if (nlohmann_json_j.contains("as_symbol")) { - nlohmann_json_t.variant_.emplace<4>(nlohmann_json_j.at("as_symbol").template get()); - nlohmann_json_t.tag_ = Tag::AS_SYMBOL; - return; - } - } -}; - -class SymExprNode { - private: - std::vector> args = {}; - std::optional target = std::nullopt; - std::optional base = std::nullopt; - - public: - - const std::vector>& get_args() const { - return args; - } - - const std::optional& get_target() const { - return target; - } - - const std::optional& get_base() const { - return base; - } - - friend void to_json(nlohmann::json& nlohmann_json_j, const SymExprNode& nlohmann_json_t); - friend void from_json(const nlohmann::json& nlohmann_json_j, SymExprNode& nlohmann_json_t); -}; - class SymExpr { private: std::string expr_str; - std::optional expr_ast = std::nullopt; std::optional hint = std::nullopt; public: @@ -365,10 +260,6 @@ class SymExpr { return expr_str; } - const std::optional& get_expr_ast() const { - return expr_ast; - } - const std::optional& get_hint() const { return hint; } @@ -2214,30 +2105,15 @@ inline void from_json(const nlohmann::json& nlohmann_json_j, SchemaVersion& nloh inline void to_json(nlohmann::json& nlohmann_json_j, const SymExpr& nlohmann_json_t) { nlohmann_json_j["expr_str"] = nlohmann_json_t.expr_str; - nlohmann_json_j["expr_ast"] = nlohmann_json_t.expr_ast; nlohmann_json_j["hint"] = nlohmann_json_t.hint; } inline void from_json(const nlohmann::json& nlohmann_json_j, SymExpr& nlohmann_json_t) { SymExpr nlohmann_json_default_obj; nlohmann_json_t.expr_str = nlohmann_json_j.value("expr_str", nlohmann_json_default_obj.expr_str); - nlohmann_json_t.expr_ast = nlohmann_json_j.value("expr_ast", nlohmann_json_default_obj.expr_ast); nlohmann_json_t.hint = nlohmann_json_j.value("hint", nlohmann_json_default_obj.hint); } -inline void to_json(nlohmann::json& nlohmann_json_j, const SymExprNode& nlohmann_json_t) { - nlohmann_json_j["args"] = nlohmann_json_t.args; - nlohmann_json_j["target"] = nlohmann_json_t.target; - nlohmann_json_j["base"] = nlohmann_json_t.base; -} - -inline void from_json(const nlohmann::json& nlohmann_json_j, SymExprNode& nlohmann_json_t) { - SymExprNode nlohmann_json_default_obj; - nlohmann_json_t.args = nlohmann_json_j.value("args", nlohmann_json_default_obj.args); - nlohmann_json_t.target = nlohmann_json_j.value("target", nlohmann_json_default_obj.target); - nlohmann_json_t.base = nlohmann_json_j.value("base", nlohmann_json_default_obj.base); -} - inline void to_json(nlohmann::json& nlohmann_json_j, const TensorArgument& nlohmann_json_t) { nlohmann_json_j["name"] = nlohmann_json_t.name; }