Skip to content

Commit

Permalink
Revert "[export] serialize sympy.Exprs as ASTs instead of strings (py…
Browse files Browse the repository at this point in the history
…torch#140084)"

This reverts commit d869344.

Reverted pytorch#140084 on behalf of https://github.com/izaitsevfb due to reverted internally in D66253238 ([comment](pytorch#140084 (comment)))
  • Loading branch information
pytorchmergebot committed Nov 21, 2024
1 parent da94ab0 commit d3c8f1a
Show file tree
Hide file tree
Showing 6 changed files with 73 additions and 365 deletions.
5 changes: 4 additions & 1 deletion test/export/test_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -8157,6 +8157,7 @@ def forward(self, x):
export(f, (inputs,), dynamic_shapes=dynamic_shapes)

@testing.expectedFailureRetraceabilityNonStrict
@testing.expectedFailureCppSerDes # dynamic shape serialization
def test_disable_forced_specializations_ok(self):
# check that we don't force specialization, and defer to runtime asserts
# with allow_complex_guards_as_runtime_asserts=True to successfully export
Expand Down Expand Up @@ -9210,7 +9211,9 @@ def forward(self, input1: torch.Tensor):
inps = (torch.randn(1, 224, 768, device="cpu"),)
export(Foo(), inps)

@testing.expectedFailureCppSerDes
@testing.expectedFailureSerDer # TODO(pianpwk): PowByNatural valuerange deserialization
@testing.expectedFailureCppSerDes # TODO(pianpwk): PowByNatural valuerange deserialization
@testing.expectedFailureSerDerNonStrict
@testing.expectedFailureRetraceabilityNonStrict
def test_dim_dynamic(self):
dynamic = Dim.DYNAMIC
Expand Down
35 changes: 3 additions & 32 deletions torch/_export/serde/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from torch._export.serde.union import _Union

# NOTE: Please update this value if any modifications are made to the schema
SCHEMA_VERSION = (8, 2)
SCHEMA_VERSION = (8, 1)
TREESPEC_VERSION = 1


Expand Down Expand Up @@ -62,42 +62,13 @@ class SymExprHint(_Union):
as_bool: bool


# A leaf node in a SymExprNode, containing a bool/float/int/sympy.Symbol.
@dataclass(repr=False)
class SymBase(_Union):
as_bool: bool
as_float: float
as_int: int
as_symbol: str


# Represents an AST node in a sympy.Expr.
# If not a leaf node, "target" is a string representing the operator,
# and "args" is a list of child SymExprNodes.
# If a leaf node, "target" is None, "args" is empty, and "base" is a SymBase
# representing the leaf value.
@dataclass(repr=False)
class SymExprNode:
args: List["SymExprNode"] = field(default_factory=list)
target: Optional[str] = None
base: Optional[SymBase] = None


# This is for storing the symbolic expressions behind symints/symfloats/symbools
# The deprecated "expr_str" field is easier to explain; we could store
# For example, we can get something like
# SymExpr(expr_str="s0 + s1", hint=SymExprHint(as_int=4)
# for an expression where s0 == s1 == 2.
# We're moving away from expr_str for roundtrippability, and now deserialize into
# the "expr_ast" field, which is a tree representation of the expression,
# containing a root SymExprNode.
# While we're deprecating this, we'll store an empty string in "expr_str" for now,
# and support deserialization for both "expr_str" and "expr_ast" fields.
# We'll only serialize to "expr_ast".
# TODO(pianpwk): implement upgrader & delete.
# if we also have the hint that s0 and s1 are both 2.
@dataclass
class SymExpr:
expr_str: str
expr_ast: Optional[SymExprNode] = None
hint: Optional[SymExprHint] = None


Expand Down
30 changes: 2 additions & 28 deletions torch/_export/serde/schema.yaml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# @generated by update_schema.py
# checksum<<74ae9f550efb42873fc58f07990cefe4da35d38c673beeacbe1e20808d9a6962>>
# checksum<<8e27d48014d4ec1c773aef056c0c20b61bead54be8338e95d3347d3422472b9a>>
Argument:
kind: union
fields:
Expand Down Expand Up @@ -346,17 +346,6 @@ SchemaVersion:
type: int
minor:
type: int
SymBase:
kind: union
fields:
as_bool:
type: bool
as_float:
type: float
as_int:
type: int
as_symbol:
type: str
SymBool:
kind: union
fields:
Expand All @@ -376,9 +365,6 @@ SymExpr:
fields:
expr_str:
type: str
expr_ast:
type: Optional[SymExprNode]
default: None
hint:
type: Optional[SymExprHint]
default: None
Expand All @@ -391,18 +377,6 @@ SymExprHint:
type: float
as_bool:
type: bool
SymExprNode:
kind: struct
fields:
args:
type: List[SymExprNode]
default: '[]'
target:
type: Optional[str]
default: None
base:
type: Optional[SymBase]
default: None
SymInt:
kind: union
fields:
Expand Down Expand Up @@ -463,5 +437,5 @@ UserOutputSpec:
type: Argument
SCHEMA_VERSION:
- 8
- 2
- 1
TREESPEC_VERSION: 1
4 changes: 1 addition & 3 deletions torch/_export/serde/schema_check.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import re
import typing
from enum import IntEnum
from typing import Any, Dict, ForwardRef, List, Optional, Tuple, Union
from typing import Any, Dict, List, Optional, Tuple, Union

from torch._export.serde import schema
from torch._export.serde.union import _Union
Expand Down Expand Up @@ -71,8 +71,6 @@ def dump_type(t) -> Tuple[str, str]:
)
elif t == ():
return "()", ""
elif isinstance(t, ForwardRef):
return t.__forward_arg__, f"ForwardRef<{t.__forward_arg__}>"
else:
raise AssertionError(f"Type {t} is not supported in export schema.")

Expand Down
Loading

0 comments on commit d3c8f1a

Please sign in to comment.