Skip to content

Commit

Permalink
Revert "[dynamo][pytree][1/N] make CXX pytree traceable: tree_iter
Browse files Browse the repository at this point in the history
…/ `tree_leaves` (pytorch#137397)"

This reverts commit 07850bb.

Reverted pytorch#137397 on behalf of https://github.com/atalman due to Failing internal test ([comment](pytorch#137397 (comment)))
  • Loading branch information
pytorchmergebot committed Dec 2, 2024
1 parent eb7deb2 commit 9012e7a
Show file tree
Hide file tree
Showing 7 changed files with 57 additions and 139 deletions.
57 changes: 32 additions & 25 deletions test/dynamo/test_misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
import torch._dynamo.testing
import torch._inductor.test_case
import torch.onnx.operators
import torch.utils._pytree as python_pytree
import torch.utils._pytree as pytree
import torch.utils.cpp_extension
from torch import Tensor
from torch._C import FileCheck
Expand Down Expand Up @@ -89,11 +89,9 @@
from torch.testing._internal.logging_utils import logs_to_string


HAS_OPTREE = python_pytree._cxx_pytree_exists
HAS_OPTREE = importlib.util.find_spec("optree")
if HAS_OPTREE:
import torch.utils._cxx_pytree as cxx_pytree
else:
cxx_pytree = None
import optree

MyTuple = collections.namedtuple("MyTuple", ["a", "b", "ab"])
T = typing.TypeVar("T")
Expand Down Expand Up @@ -295,9 +293,9 @@ def fn(x):

@unittest.skipIf(not HAS_OPTREE, "missing optree package")
def test_optree_graph_break_message(self):
import optree

@torch.compile(backend="eager")
@torch.compile(
backend="eager",
)
def fn(x):
d = {"a": 1}
optree.tree_flatten(d)
Expand Down Expand Up @@ -8678,9 +8676,9 @@ def fn():

def test_tracing_py_tree(self):
def fn(xs):
flat_xs, spec = python_pytree.tree_flatten(xs)
flat_xs, spec = pytree.tree_flatten(xs)
res = [x.clone() for x in flat_xs]
return python_pytree.tree_unflatten(res, spec)
return pytree.tree_unflatten(res, spec)

xs = [torch.tensor(i) for i in range(3)]

Expand All @@ -8690,10 +8688,12 @@ def fn(xs):
self.assertEqual(counter.op_count, 3)

def test_tracing_nested_py_tree(self):
import torch.utils._pytree as pytree

def fn(xs):
flat_xs, spec = python_pytree.tree_flatten(xs)
flat_xs, spec = pytree.tree_flatten(xs)
res = [x.clone() for x in flat_xs]
return python_pytree.tree_unflatten(res, spec)
return pytree.tree_unflatten(res, spec)

xs = [torch.tensor(i) for i in range(3)]
xsl = [xs, xs, xs, xs]
Expand All @@ -8706,10 +8706,12 @@ def fn(xs):
self.assertEqual(counter.op_count, 12)

def test_tracing_nested_py_tree_tuples(self):
import torch.utils._pytree as pytree

def fn(xs):
flat_xs, spec = python_pytree.tree_flatten(xs)
flat_xs, spec = pytree.tree_flatten(xs)
res = [x.clone() for x in flat_xs]
return python_pytree.tree_unflatten(res, spec)
return pytree.tree_unflatten(res, spec)

xs = [torch.tensor(i) for i in range(3)]
xsl = (xs, xs, xs, xs)
Expand All @@ -8722,10 +8724,12 @@ def fn(xs):
self.assertEqual(counter.op_count, 12)

def test_tracing_nested_py_tree_dicts(self):
import torch.utils._pytree as pytree

def fn(xs):
flat_xs, spec = python_pytree.tree_flatten(xs)
flat_xs, spec = pytree.tree_flatten(xs)
res = [x.clone() for x in flat_xs]
return python_pytree.tree_unflatten(res, spec)
return pytree.tree_unflatten(res, spec)

xs = [torch.tensor(i) for i in range(3)]
xsl = {
Expand Down Expand Up @@ -8758,10 +8762,12 @@ def fn(x):
self.assertEqual(counter.op_count, 2)

def test_tracing_nested_py_tree_mixed_all(self):
import torch.utils._pytree as pytree

def fn(xs):
flat_xs, spec = python_pytree.tree_flatten(xs)
flat_xs, spec = pytree.tree_flatten(xs)
res = [x.clone() for x in flat_xs]
return python_pytree.tree_unflatten(res, spec)
return pytree.tree_unflatten(res, spec)

xs = [torch.tensor(i) for i in range(3)]
xsa = (xs, xs)
Expand Down Expand Up @@ -8806,12 +8812,13 @@ def fn(x):
self.assertEqual(cnt.frame_count, 2)

def test_tracing_py_tree_tensor_subclass(self):
import torch.utils._pytree as pytree
from torch.testing._internal.two_tensor import TwoTensor
from torch.utils.checkpoint import checkpoint

def fn(xs):
nested_xs = [[xs]]
flat_xs, spec = python_pytree.tree_flatten(xs)
flat_xs, spec = pytree.tree_flatten(xs)
return flat_xs[0].clone()

# use checkpoint to trigger a "sourceless" tensor subclass
Expand All @@ -8826,11 +8833,13 @@ def checkpoint_fn(xs):
self.assertEqual(counter.op_count, 2)

def test_tracing_tree_map_only(self):
import torch.utils._pytree as pytree

def fn(xs):
def mapper(x):
return x.clone()

y = python_pytree.tree_map_only(torch.Tensor, mapper, xs)
y = pytree.tree_map_only(torch.Tensor, mapper, xs)
return y

xs = [torch.tensor(i) for i in range(3)] + ["hi"]
Expand Down Expand Up @@ -10184,9 +10193,7 @@ def fn(x, y):
self.assertEqual(actual, expected)

def test_pytree_tree_leaves(self):
implemtations = [("python", python_pytree)]
if cxx_pytree is not None:
implemtations.append(("cxx", cxx_pytree))
implemtations = [("python", pytree)]

for name, module in implemtations:
with self.subTest(f"pytree implement: {name}"):
Expand Down Expand Up @@ -10218,7 +10225,7 @@ def fn(x):
self.assertEqual(actual, expected)

def test_pytree_tree_flatten_unflatten(self):
implemtations = [("python", python_pytree)]
implemtations = [("python", pytree)]

for name, module in implemtations:
with self.subTest(f"pytree implement: {name}"):
Expand Down Expand Up @@ -10267,7 +10274,7 @@ def fn(x, y):
self.assertEqual(actual, expected)

def test_pytree_tree_map(self):
implemtations = [("python", python_pytree)]
implemtations = [("python", pytree)]

for name, module in implemtations:
with self.subTest(f"pytree implement: {name}"):
Expand Down
7 changes: 3 additions & 4 deletions torch/_dynamo/guards.py
Original file line number Diff line number Diff line change
Expand Up @@ -2080,11 +2080,10 @@ def _set_guard_export_info(self, guard, code_list, provided_guarded_object=None)
obj_ref = None
# Not necessary to have weakref for Enum type, but there is a bug that
# makes hasattr(guarded_object.__class__, "__weakref__") return True.
supports_weakref = (
getattr(guarded_object.__class__, "__weakrefoffset__", 0) != 0
)
# See D64140537 for why we are checking for tuple.
if supports_weakref and not isinstance(guarded_object, (enum.Enum, tuple)):
if hasattr(guarded_object.__class__, "__weakref__") and not isinstance(
guarded_object, (enum.Enum, tuple)
):
obj_ref = weakref.ref(guarded_object)

guard.set_export_info(
Expand Down
1 change: 0 additions & 1 deletion torch/_dynamo/polyfills/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
itertools as itertools,
operator as operator,
os as os,
pytree as pytree,
sys as sys,
)

Expand Down
1 change: 0 additions & 1 deletion torch/_dynamo/polyfills/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
"itertools",
"operator",
"os",
"pytree",
"sys",
)
POLYFILLED_MODULES: Tuple["ModuleType", ...] = tuple(
Expand Down
89 changes: 0 additions & 89 deletions torch/_dynamo/polyfills/pytree.py

This file was deleted.

1 change: 0 additions & 1 deletion torch/_dynamo/trace_rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -3310,7 +3310,6 @@ def _module_dir(m: types.ModuleType):
"torch.testing",
"torch.utils._content_store",
"torch.utils._contextlib",
"torch.utils._cxx_pytree",
"torch.utils._device",
"torch.utils._foreach_utils",
"torch.utils._python_dispatch",
Expand Down
40 changes: 22 additions & 18 deletions torch/utils/_cxx_pytree.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,10 @@
from typing_extensions import deprecated

import optree
from optree import PyTreeSpec as TreeSpec # direct import for type annotations
from optree import PyTreeSpec # direct import for type annotations

import torch.utils._pytree as python_pytree
from torch.utils._pytree import KeyEntry as KeyEntry
import torch.utils._pytree as _pytree
from torch.utils._pytree import KeyEntry


__all__ = [
Expand Down Expand Up @@ -79,6 +79,7 @@

Context = Any
PyTree = Any
TreeSpec = PyTreeSpec
FlattenFunc = Callable[[PyTree], Tuple[List[Any], Context]]
UnflattenFunc = Callable[[Iterable[Any], Context], PyTree]
OpTreeUnflattenFunc = Callable[[Context, Iterable[Any]], PyTree]
Expand Down Expand Up @@ -150,7 +151,9 @@ def register_pytree_node(
from_dumpable_context=from_dumpable_context,
)

python_pytree._private_register_pytree_node(
from . import _pytree as python

python._private_register_pytree_node(
cls,
flatten_fn,
unflatten_fn,
Expand Down Expand Up @@ -868,19 +871,24 @@ def treespec_dumps(treespec: TreeSpec, protocol: Optional[int] = None) -> str:
f"treespec_dumps(spec): Expected `spec` to be instance of "
f"TreeSpec but got item of type {type(treespec)}."
)
from ._pytree import (
tree_structure as _tree_structure,
treespec_dumps as _treespec_dumps,
)

dummy_tree = tree_unflatten([0] * treespec.num_leaves, treespec)
orig_treespec = python_pytree.tree_structure(dummy_tree)
return python_pytree.treespec_dumps(orig_treespec, protocol=protocol)
orig_treespec = _tree_structure(tree_unflatten([0] * treespec.num_leaves, treespec))
return _treespec_dumps(orig_treespec, protocol=protocol)


def treespec_loads(serialized: str) -> TreeSpec:
"""Deserialize a treespec from a JSON string."""
orig_treespec = python_pytree.treespec_loads(serialized)
dummy_tree = python_pytree.tree_unflatten(
[0] * orig_treespec.num_leaves,
orig_treespec,
from ._pytree import (
tree_unflatten as _tree_unflatten,
treespec_loads as _treespec_loads,
)

orig_treespec = _treespec_loads(serialized)
dummy_tree = _tree_unflatten([0] * orig_treespec.num_leaves, orig_treespec)
treespec = tree_structure(dummy_tree)
return treespec

Expand Down Expand Up @@ -994,10 +1002,6 @@ def key_get(obj: Any, kp: KeyPath) -> Any:
raise NotImplementedError("KeyPaths are not yet supported in cxx_pytree.")


with python_pytree._NODE_REGISTRY_LOCK:
python_pytree._cxx_pytree_imported = True
args, kwargs = (), {} # type: ignore[var-annotated]
for args, kwargs in python_pytree._cxx_pytree_pending_imports:
_private_register_pytree_node(*args, **kwargs)
python_pytree._cxx_pytree_pending_imports.clear()
del args, kwargs
_pytree._cxx_pytree_imported = True
for args, kwargs in _pytree._cxx_pytree_pending_imports:
_private_register_pytree_node(*args, **kwargs)

0 comments on commit 9012e7a

Please sign in to comment.