diff --git a/axlearn/common/module.py b/axlearn/common/module.py index 7d216d52..fb1ac4f0 100644 --- a/axlearn/common/module.py +++ b/axlearn/common/module.py @@ -37,6 +37,7 @@ def do_foo(self, ...): import contextlib import copy import dataclasses +import functools import hashlib import inspect import os.path @@ -50,11 +51,18 @@ def do_foo(self, ...): from absl import logging from typing_extensions import Protocol -from axlearn.common import traceback_util +from axlearn.common import struct, traceback_util from axlearn.common.config import REQUIRED, Configurable, Required, RequiredFieldValue, config_class from axlearn.common.summary import Summary from axlearn.common.traceback_util import annotate_stack, no_stack_summary -from axlearn.common.utils import Nested, NestedTensor, Tensor, partial_with_fn_metadata, prune_tree +from axlearn.common.utils import ( + Nested, + NestedTensor, + Tensor, + partial_with_fn_metadata, + prune_tree, + raise_for_cycles, +) def _generate_seed_from_name(name: str) -> np.int64: @@ -178,7 +186,7 @@ def __add__(self, other: T) -> T: # TODO(markblee): Link to docs on invocation contexts. -@dataclass +@functools.partial(struct.dataclass, frozen=False) class InvocationContext: # pylint: disable=too-many-instance-attributes """The invocation context for `Module.__call__()`. @@ -193,13 +201,13 @@ class InvocationContext: # pylint: disable=too-many-instance-attributes output_collection: See `OutputCollection`. """ - name: str - parent: Optional["InvocationContext"] - module: Optional["Module"] - state: NestedTensor - is_training: bool - prng_key: Optional[Tensor] - output_collection: OutputCollection + name: str = struct.field(pytree_node=False) + parent: Optional["InvocationContext"] = struct.field(pytree_node=True) + module: Optional["Module"] = struct.field(pytree_node=False) + state: NestedTensor = struct.field(pytree_node=True) + is_training: bool = struct.field(pytree_node=False) + prng_key: Optional[Tensor] = struct.field(pytree_node=True) + output_collection: OutputCollection = struct.field(pytree_node=True) def path(self): if self.parent is None: @@ -334,6 +342,24 @@ def get_state_updates(self): def get_module_outputs(self): return self.output_collection.module_outputs + def functional(self, method_fn: Callable) -> "_Functional": + """Transforms `method_fn` (with this context) into a pure functional Callable. + + The returned Callable will have the same behavior as `method_fn`, except that it runs + inside this context instead of the current context and returns + an OutputCollection in addition to the method output instead of mutating the context it + runs it. + + This context and the arguments to `method_fn` are not modified by the call. + + Args: + method_fn: The function to call. + + Returns: + The callable described above. + """ + return _Functional(method_fn=method_fn, context=self, require_parent=False) + @dataclass class ContextStack(threading.local): @@ -389,10 +415,10 @@ def current_context() -> Optional[InvocationContext]: @contextlib.contextmanager -def set_current_context(context: InvocationContext): +def set_current_context(context: InvocationContext, *, require_parent: bool = True): if _global_context_stack.stack: cur_context = _global_context_stack.stack[-1] - if context.parent is not cur_context: + if context.parent is not cur_context and require_parent: raise ValueError( f"context ({context.path()})'s parent " f"must match the current context ({cur_context.path()}). " @@ -800,6 +826,53 @@ def nullary(): return nullary +@functools.partial(struct.dataclass, frozen=False) +class _Functional: + """A pure functional call to `method_fn`.""" + + # The function to call. + method_fn: Callable = struct.field(pytree_node=False) + # The context to call method_fn in. + # This will be copied to prevent method_fn from mutating the original. + context: InvocationContext = struct.field(pytree_node=True) + # Whether to require that context.parent is current_context(). + require_parent: bool = struct.field(pytree_node=False) + + def __call__(self, *args, **kwargs) -> Tuple[Any, OutputCollection]: + """Invokes method_fn in a pure functional fashion. + + The invocation will not depend on external inputs or have any side effects. The results only + depend on the given inputs. All outputs are reflected in the return value. + + Args: + *args: Positional arguments to method_fn. + **kwargs: Keyword arguments to method_fn. + + Returns: + (method_outputs, output_collection), where + - method_outputs are the return value of the method. + - output_collection is an OutputCollection containing summaries and state updates. + + Raises: + ValueError: If there are circular references in args, kwargs, or context. + """ + call = getattr(self.method_fn, "__qualname__", None) or getattr(self.method_fn, "__name__") + logging.vlog(1, "functional: %s.%s %s(*%s, **%s)", call, self.method_fn, args, kwargs) + + # Copy to prevent method_fn from mutating the original. + # Some badly behaved tests call F() with an InvocationContext.state that contains + # circular references. + # This results in a cryptic error that doesn't make the root cause obvious. + # So we raise a clearer error explicitly. + raise_for_cycles(dict(context=self.context, args=args, kwargs=kwargs)) + context, args, kwargs = jax.tree_util.tree_map(lambda x: x, (self.context, args, kwargs)) + + with set_current_context(context, require_parent=self.require_parent): + # pylint: disable-next=not-an-iterable,not-a-mapping,not-callable + method_outputs = self.method_fn(*args, **kwargs) + return method_outputs, context.output_collection + + def functional( module: Module, prng_key: Optional[Tensor], @@ -832,6 +905,9 @@ def functional( (method_outputs, output_collection), where - method_outputs are the return value of the method. - output_collection is an OutputCollection containing summaries and state updates. + + Raises: + ValueError: If there are circular references in args, kwargs, or context. """ context = InvocationContext( name="root", @@ -843,19 +919,20 @@ def functional( prng_key=prng_key, ) + args = [] + kwargs = {} + if isinstance(inputs, dict): + kwargs = inputs + else: + args = inputs method_fn = getattr(module, method) - logging.vlog(1, "functional: %s.%s %s(%s)", module, method, method_fn, inputs) - with set_current_context(context): - if isinstance(inputs, dict): - input_args, input_kwargs = [], inputs - else: - input_args, input_kwargs = inputs, {} - method_outputs = method_fn(*input_args, **input_kwargs) - for output_collection_type in drop_output_collections: - getattr(context.output_collection, output_collection_type).clear() + fn = _Functional(context=context, method_fn=method_fn, require_parent=True) + method_outputs, output_collection = fn(*args, **kwargs) - return method_outputs, context.output_collection + for output_collection_type in drop_output_collections: + getattr(output_collection, output_collection_type).clear() + return method_outputs, output_collection def scan_in_context( diff --git a/axlearn/common/module_test.py b/axlearn/common/module_test.py index a05383c0..ffde4ab3 100644 --- a/axlearn/common/module_test.py +++ b/axlearn/common/module_test.py @@ -2,6 +2,7 @@ """Tests for module.py.""" # pylint: disable=protected-access +# type: ignore[attribute-error] import contextlib import threading from typing import List, Optional, Union @@ -14,6 +15,7 @@ from axlearn.common import summary, test_utils from axlearn.common.config import REQUIRED, Required, config_class +from axlearn.common.layers import Linear from axlearn.common.metrics import WeightedScalar from axlearn.common.module import ( InvocationContext, @@ -165,6 +167,24 @@ def test_context_stack(self): with set_current_context(context2): pass + def test_nested_context(self): + """Test calling `set_current_context(..., require_parent=False).""" + module1 = new_test_module("test1") + module2 = new_test_module("test2") + context1 = InvocationContext( + name="context1", + parent=None, # root context + module=module1, + is_training=True, + prng_key=jax.random.PRNGKey(123), + state={"x": 1}, + output_collection=new_output_collection(), + ) + context2 = context1.add_child("context2", module=module2, state={"x": 2}) + context2.parent = None + with set_current_context(context2, require_parent=False) as ctx: + self.assertEqual(ctx.parent, None) + def test_context_stack_mutlithread(self): module1 = new_test_module("root") module1._add_child("child1", TestModule.default_config()) @@ -263,6 +283,41 @@ def test_add_summary_validation( stack.enter_context(self.assertRaises(ValueError)) ctx.add_summary("summary", value) + def test_functional(self): + """Tests the `Functional` class and `InvocationContext.functional()`.""" + with test_utils.bind_layer(Linear.default_config().set(input_dim=5, output_dim=5)) as layer: + + def fn(x: Tensor, y: Tensor) -> Tensor: + current_context().add_state_update("my_state", y) + return layer(x) + + args = [jnp.ones(5)] + kwargs = dict(y=jnp.zeros(3)) + new_fn = current_context().functional(fn) + old_ctx = current_context() + result, output_collection = new_fn(*args, **kwargs) + self.assertIs(current_context(), old_ctx) + self.assertNestedEqual(current_context().output_collection, new_output_collection()) + # The below line would cause an output conflict error if we had called fn() instead of + # new_fn() on the line above. But it doesn't since new_fn() restores the context to its + # original state after the call. + result2 = fn(*args, **kwargs) + self.assertNestedEqual(result, result2) + self.assertNestedEqual(output_collection, current_context().output_collection) + + def test_functional_with_method_call(self): + """Demonstrates usage of `InvocationContext.functional()` with a module method instead of an + anonymous function. + """ + with test_utils.bind_layer(Linear.default_config().set(input_dim=5, output_dim=5)) as layer: + args = [jnp.ones(5)] + new_fn = current_context().functional(layer.forward) + old_ctx = current_context() + result, _ = new_fn(*args) + self.assertIs(current_context(), old_ctx) + result2 = layer.forward(*args) + self.assertNestedEqual(result, result2) + class NestedModule(Module): """A nested module.""" diff --git a/axlearn/common/utils.py b/axlearn/common/utils.py index 6160b8fe..93cd7f7c 100644 --- a/axlearn/common/utils.py +++ b/axlearn/common/utils.py @@ -39,6 +39,7 @@ import numpy as np from absl import logging from jax import numpy as jnp +from jax._src.tree_util import KeyEntry, KeyPath from jax.experimental import maps, mesh_utils, multihost_utils from jax.sharding import PartitionSpec from jax.tree_util import register_pytree_node_class @@ -1357,3 +1358,85 @@ def thread_stack_traces() -> Sequence[Sequence[str]]: lines.append(f">>> {line.rstrip()}") grouped_lines.append(lines) return grouped_lines + + +def pytree_children(node: Any) -> Sequence[Tuple[KeyEntry, Any]]: + """Generate the (key, value) pairs for the immediate children of a pytree `node`. + + The returned children match those returned by + `jax.tree_util.default_registry.flatten_one_level()`. + + Reference: jax._src.tree_util.generate_key_paths() + + Example: + ``` + assert pytree_children(dict(a=[1,2])) == [(DictKey('a'), [1,2])] + ``` + """ + # pylint: disable-next=protected-access + registry_with_keypaths = jax._src.tree_util._registry_with_keypaths + + key_handler = registry_with_keypaths.get(type(node)) + if key_handler: + key_children, _ = key_handler.flatten_with_keys(node) + return key_children + + flat = jax.tree_util.default_registry.flatten_one_level(node) + if flat is None: + return [] + + if isinstance(node, tuple) and hasattr(node, "_fields") and flat[1] == type(node): + # Handle namedtuple as a special case, based on heuristic. + return [(jax.tree_util.GetAttrKey(s), getattr(node, s)) for s in node._fields] + return [(jax.tree_util.FlattenedIndexKey(i), c) for i, c in enumerate(flat[0])] + + +def find_cycles(tree: Nested) -> dict[str, KeyPath]: + """Find a cycle in pytree `tree` if one exists. + + This function finds a descendant which has reference equality with one of its own + ancestors, if one exists. + + Args: + tree: The tree to find cycles in. + + Returns: + If no cycle is found, an empty dict. + If a cycle is found a dict with keys: + * descendant: The KeyPath to the descendant. + * ancestor: The KeyPath to the ancestor. + """ + + def _find_cycles(tree: Nested, *, key_path: KeyPath, seen: list[int]) -> dict[str, KeyPath]: + # DFS and check if path to root contains repeats. + # This is quadratic time in the depth of the tree but could be made linear + # time with a small amount of additional implementation complexity. + uid = id(tree) + if uid in seen: + result = dict(descendant=key_path[:], ancestor=key_path[: seen.index(uid)]) + return result + seen.append(uid) + items = pytree_children(tree) + for key, child in items: + key_path.append(key) + result = _find_cycles(child, key_path=key_path, seen=seen) + key_path.pop() + if result: + return result + seen.pop() + return {} + + return _find_cycles(tree, key_path=[], seen=[]) + + +def raise_for_cycles(tree: Any): + """Raise an informative error message if `tree` contains cycles.""" + + cycles = find_cycles(tree) + if cycles: + raise ValueError( + "Circular reference in args, kwargs, or context.\n" + "Descendant refers to ancestor.\n" + f"Descendant KeyPath: {cycles['descendant']}.\n" + f"Ancestor KeyPath: {cycles['ancestor']}." + ) diff --git a/axlearn/common/utils_test.py b/axlearn/common/utils_test.py index 3424c1e5..99f30daa 100644 --- a/axlearn/common/utils_test.py +++ b/axlearn/common/utils_test.py @@ -57,6 +57,7 @@ create_device_mesh, dispatch_input_batch, expand_vdicts, + find_cycles, flatten_items, get_data_dir, get_recursively, @@ -64,6 +65,7 @@ input_partition_spec, match_regex_rules, prune_tree, + pytree_children, runtime_checks, set_data_dir, set_recursively, @@ -251,6 +253,68 @@ def test_as_tensor(self): ), ) + def test_pytree_children(self): + # DictKey + original_tree = dict(a=3, b=2, c=dict(d=1)) + tree = original_tree + self.assertSequenceEqual( + pytree_children(tree), [(jax.tree_util.DictKey(k), v) for k, v in original_tree.items()] + ) + + # SequenceKey + tree = tuple(tree.values()) + self.assertSequenceEqual( + pytree_children(tree), + [(jax.tree_util.SequenceKey(k), v) for k, v in enumerate(original_tree.values())], + ) + + # GetAttrKey with NamedTuple + class TestNamedTuple(NamedTuple): + a: int + b: int + c: dict + + tree = TestNamedTuple(**original_tree) + self.assertSequenceEqual( + pytree_children(tree), + [(jax.tree_util.GetAttrKey(k), v) for k, v in original_tree.items()], + ) + + # FlattenedIndexKey + @dataclasses.dataclass + class TestUnstructured: + a: int + b: int + c: dict + + jax.tree_util.register_pytree_node( + TestUnstructured, + flatten_func=lambda x: ((x.a, x.b, x.c), None), + unflatten_func=lambda x, _: TestUnstructured(*x), + ) + tree = TestUnstructured(**original_tree) + self.assertSequenceEqual( + pytree_children(tree), + [(jax.tree_util.FlattenedIndexKey(k), v) for k, v in enumerate(original_tree.values())], + ) + + # No children + self.assertSequenceEqual(pytree_children([]), []) + + # No children + self.assertSequenceEqual(pytree_children(3), []) + + def test_find_cycles(self): + x = {} + y = dict(a=x, b=x, c=x) + self.assertFalse(find_cycles(y)) + + y = dict(a=dict(b={}), c=dict(d={})) + y["c"]["d"]["e"] = y["c"] + ancestor = [jax.tree_util.DictKey("c")] + descendant = ancestor + [jax.tree_util.DictKey("d"), jax.tree_util.DictKey("e")] + self.assertEqual(find_cycles(y), dict(ancestor=ancestor, descendant=descendant)) + def assertNumpyArrayEqual(self, a, b): self.assertIsInstance(a, np.ndarray) self.assertIsInstance(b, np.ndarray) diff --git a/axlearn/vision/coca_test.py b/axlearn/vision/coca_test.py index 55e0b871..02d13cbe 100644 --- a/axlearn/vision/coca_test.py +++ b/axlearn/vision/coca_test.py @@ -166,7 +166,7 @@ def test_coca_textual_encoder(self, act_fn): # Parameters required by CLIP textual encoder. params["output_proj"] = params["contrastive_output_proj"] params["output_norm"] = params["contrastive_output_norm"] - params["text_encoder"]["encoder"] = params["text_encoder"] + params["text_encoder"]["encoder"] = dict(params["text_encoder"]) coca_outputs, _ = F( coca_textual_encoder, @@ -274,9 +274,9 @@ def _compare_against_clip_model( clip_params["textual_encoder"]["output_norm"] = coca_params["textual_encoder"][ "contrastive_output_norm" ] - clip_params["textual_encoder"]["text_encoder"]["encoder"] = coca_params["textual_encoder"][ - "text_encoder" - ] + clip_params["textual_encoder"]["text_encoder"]["encoder"] = dict( + coca_params["textual_encoder"]["text_encoder"] + ) clip_params["fusion_network"] = coca_params["contrastive_fusion_network"] coca_outputs, _ = F(