Skip to content

Commit

Permalink
Add InvocationContext.functional(). (apple#618)
Browse files Browse the repository at this point in the history
  • Loading branch information
apghml authored Aug 2, 2024
1 parent 339591e commit 20db3b0
Show file tree
Hide file tree
Showing 5 changed files with 305 additions and 26 deletions.
121 changes: 99 additions & 22 deletions axlearn/common/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ def do_foo(self, ...):
import contextlib
import copy
import dataclasses
import functools
import hashlib
import inspect
import os.path
Expand All @@ -50,11 +51,18 @@ def do_foo(self, ...):
from absl import logging
from typing_extensions import Protocol

from axlearn.common import traceback_util
from axlearn.common import struct, traceback_util
from axlearn.common.config import REQUIRED, Configurable, Required, RequiredFieldValue, config_class
from axlearn.common.summary import Summary
from axlearn.common.traceback_util import annotate_stack, no_stack_summary
from axlearn.common.utils import Nested, NestedTensor, Tensor, partial_with_fn_metadata, prune_tree
from axlearn.common.utils import (
Nested,
NestedTensor,
Tensor,
partial_with_fn_metadata,
prune_tree,
raise_for_cycles,
)


def _generate_seed_from_name(name: str) -> np.int64:
Expand Down Expand Up @@ -178,7 +186,7 @@ def __add__(self, other: T) -> T:


# TODO(markblee): Link to docs on invocation contexts.
@dataclass
@functools.partial(struct.dataclass, frozen=False)
class InvocationContext: # pylint: disable=too-many-instance-attributes
"""The invocation context for `Module.__call__()`.
Expand All @@ -193,13 +201,13 @@ class InvocationContext: # pylint: disable=too-many-instance-attributes
output_collection: See `OutputCollection`.
"""

name: str
parent: Optional["InvocationContext"]
module: Optional["Module"]
state: NestedTensor
is_training: bool
prng_key: Optional[Tensor]
output_collection: OutputCollection
name: str = struct.field(pytree_node=False)
parent: Optional["InvocationContext"] = struct.field(pytree_node=True)
module: Optional["Module"] = struct.field(pytree_node=False)
state: NestedTensor = struct.field(pytree_node=True)
is_training: bool = struct.field(pytree_node=False)
prng_key: Optional[Tensor] = struct.field(pytree_node=True)
output_collection: OutputCollection = struct.field(pytree_node=True)

def path(self):
if self.parent is None:
Expand Down Expand Up @@ -334,6 +342,24 @@ def get_state_updates(self):
def get_module_outputs(self):
return self.output_collection.module_outputs

def functional(self, method_fn: Callable) -> "_Functional":
    """Wraps `method_fn` so that it can be invoked as a pure function of this context.

    The returned callable behaves like `method_fn`, with two differences:
    - it runs inside this context rather than whatever context is current at call time;
    - instead of mutating the context it runs in, it returns an OutputCollection alongside
      the method output.

    Neither this context nor the arguments passed to `method_fn` are modified by the call.

    Args:
        method_fn: The function to call.

    Returns:
        The callable described above.
    """
    # require_parent=False: this context's parent is not required to be the current context
    # when the wrapper is eventually invoked.
    wrapped = _Functional(context=self, method_fn=method_fn, require_parent=False)
    return wrapped


@dataclass
class ContextStack(threading.local):
Expand Down Expand Up @@ -389,10 +415,10 @@ def current_context() -> Optional[InvocationContext]:


@contextlib.contextmanager
def set_current_context(context: InvocationContext):
def set_current_context(context: InvocationContext, *, require_parent: bool = True):
if _global_context_stack.stack:
cur_context = _global_context_stack.stack[-1]
if context.parent is not cur_context:
if context.parent is not cur_context and require_parent:
raise ValueError(
f"context ({context.path()})'s parent "
f"must match the current context ({cur_context.path()}). "
Expand Down Expand Up @@ -800,6 +826,53 @@ def nullary():
return nullary


@functools.partial(struct.dataclass, frozen=False)
class _Functional:
    """A pure functional call to `method_fn`.

    Instances are callable: calling one runs `method_fn` inside a copy of `context` and returns
    the method outputs together with the output collection of that copy.
    """

    # The function to call.
    method_fn: Callable = struct.field(pytree_node=False)
    # The context to call method_fn in.
    # This will be copied to prevent method_fn from mutating the original.
    context: InvocationContext = struct.field(pytree_node=True)
    # Whether to require that context.parent is current_context().
    require_parent: bool = struct.field(pytree_node=False)

    def __call__(self, *args, **kwargs) -> Tuple[Any, OutputCollection]:
        """Invokes method_fn in a pure functional fashion.

        The invocation will not depend on external inputs or have any side effects. The results
        only depend on the given inputs. All outputs are reflected in the return value.

        Args:
            *args: Positional arguments to method_fn.
            **kwargs: Keyword arguments to method_fn.

        Returns:
            (method_outputs, output_collection), where
            - method_outputs are the return value of the method.
            - output_collection is an OutputCollection containing summaries and state updates.

        Raises:
            ValueError: If there are circular references in args, kwargs, or context.
        """
        # Not every callable defines __qualname__/__name__ (e.g. functools.partial objects),
        # so fall back to repr() rather than raising AttributeError just to build a log line.
        call = (
            getattr(self.method_fn, "__qualname__", None)
            or getattr(self.method_fn, "__name__", None)
            or repr(self.method_fn)
        )
        # The number of %s placeholders must match the number of arguments that follow.
        logging.vlog(1, "functional: %s %s(*%s, **%s)", call, self.method_fn, args, kwargs)

        # Some badly behaved tests call F() with an InvocationContext.state that contains
        # circular references.
        # This results in a cryptic error that doesn't make the root cause obvious.
        # So we raise a clearer error explicitly.
        raise_for_cycles(dict(context=self.context, args=args, kwargs=kwargs))
        # Copy to prevent method_fn from mutating the original: mapping the identity function
        # over the tree rebuilds the pytree containers while passing the leaves through.
        context, args, kwargs = jax.tree_util.tree_map(lambda x: x, (self.context, args, kwargs))

        with set_current_context(context, require_parent=self.require_parent):
            # pylint: disable-next=not-an-iterable,not-a-mapping,not-callable
            method_outputs = self.method_fn(*args, **kwargs)
        return method_outputs, context.output_collection


def functional(
module: Module,
prng_key: Optional[Tensor],
Expand Down Expand Up @@ -832,6 +905,9 @@ def functional(
(method_outputs, output_collection), where
- method_outputs are the return value of the method.
- output_collection is an OutputCollection containing summaries and state updates.
Raises:
ValueError: If there are circular references in args, kwargs, or context.
"""
context = InvocationContext(
name="root",
Expand All @@ -843,19 +919,20 @@ def functional(
prng_key=prng_key,
)

args = []
kwargs = {}
if isinstance(inputs, dict):
kwargs = inputs
else:
args = inputs
method_fn = getattr(module, method)
logging.vlog(1, "functional: %s.%s %s(%s)", module, method, method_fn, inputs)
with set_current_context(context):
if isinstance(inputs, dict):
input_args, input_kwargs = [], inputs
else:
input_args, input_kwargs = inputs, {}
method_outputs = method_fn(*input_args, **input_kwargs)

for output_collection_type in drop_output_collections:
getattr(context.output_collection, output_collection_type).clear()
fn = _Functional(context=context, method_fn=method_fn, require_parent=True)
method_outputs, output_collection = fn(*args, **kwargs)

return method_outputs, context.output_collection
for output_collection_type in drop_output_collections:
getattr(output_collection, output_collection_type).clear()
return method_outputs, output_collection


def scan_in_context(
Expand Down
55 changes: 55 additions & 0 deletions axlearn/common/module_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

"""Tests for module.py."""
# pylint: disable=protected-access
# type: ignore[attribute-error]
import contextlib
import threading
from typing import List, Optional, Union
Expand All @@ -14,6 +15,7 @@

from axlearn.common import summary, test_utils
from axlearn.common.config import REQUIRED, Required, config_class
from axlearn.common.layers import Linear
from axlearn.common.metrics import WeightedScalar
from axlearn.common.module import (
InvocationContext,
Expand Down Expand Up @@ -165,6 +167,24 @@ def test_context_stack(self):
with set_current_context(context2):
pass

def test_nested_context(self):
    """Tests calling `set_current_context(..., require_parent=False)`."""
    root_module = new_test_module("test1")
    child_module = new_test_module("test2")
    root_context = InvocationContext(
        name="context1",
        parent=None,  # root context
        module=root_module,
        state={"x": 1},
        is_training=True,
        prng_key=jax.random.PRNGKey(123),
        output_collection=new_output_collection(),
    )
    detached_context = root_context.add_child("context2", module=child_module, state={"x": 2})
    # Clear the parent link so that the child context no longer points at the root context.
    detached_context.parent = None
    with set_current_context(detached_context, require_parent=False) as ctx:
        self.assertEqual(ctx.parent, None)

def test_context_stack_mutlithread(self):
module1 = new_test_module("root")
module1._add_child("child1", TestModule.default_config())
Expand Down Expand Up @@ -263,6 +283,41 @@ def test_add_summary_validation(
stack.enter_context(self.assertRaises(ValueError))
ctx.add_summary("summary", value)

def test_functional(self):
    """Tests the `_Functional` class and `InvocationContext.functional()`."""
    layer_cfg = Linear.default_config().set(input_dim=5, output_dim=5)
    with test_utils.bind_layer(layer_cfg) as layer:

        def fn(x: Tensor, y: Tensor) -> Tensor:
            # Records a state update in whichever context is current when fn runs.
            current_context().add_state_update("my_state", y)
            return layer(x)

        args = [jnp.ones(5)]
        kwargs = {"y": jnp.zeros(3)}
        old_ctx = current_context()
        new_fn = current_context().functional(fn)
        result, output_collection = new_fn(*args, **kwargs)
        # The functional call must leave the current context, and its outputs, untouched.
        self.assertIs(current_context(), old_ctx)
        self.assertNestedEqual(current_context().output_collection, new_output_collection())
        # The below line would cause an output conflict error if we had called fn() instead of
        # new_fn() on the line above. But it doesn't since new_fn() restores the context to its
        # original state after the call.
        result2 = fn(*args, **kwargs)
        self.assertNestedEqual(result, result2)
        self.assertNestedEqual(output_collection, current_context().output_collection)

def test_functional_with_method_call(self):
    """Demonstrates `InvocationContext.functional()` applied to a module method.

    Same idea as wrapping an anonymous function, except that the wrapped callable is the bound
    method `layer.forward`.
    """
    layer_cfg = Linear.default_config().set(input_dim=5, output_dim=5)
    with test_utils.bind_layer(layer_cfg) as layer:
        args = [jnp.ones(5)]
        old_ctx = current_context()
        new_fn = current_context().functional(layer.forward)
        result, _ = new_fn(*args)
        # The functional call must not replace the current context.
        self.assertIs(current_context(), old_ctx)
        # Calling the method directly must produce the same output as the functional wrapper.
        result2 = layer.forward(*args)
        self.assertNestedEqual(result, result2)


class NestedModule(Module):
"""A nested module."""
Expand Down
83 changes: 83 additions & 0 deletions axlearn/common/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
import numpy as np
from absl import logging
from jax import numpy as jnp
from jax._src.tree_util import KeyEntry, KeyPath
from jax.experimental import maps, mesh_utils, multihost_utils
from jax.sharding import PartitionSpec
from jax.tree_util import register_pytree_node_class
Expand Down Expand Up @@ -1357,3 +1358,85 @@ def thread_stack_traces() -> Sequence[Sequence[str]]:
lines.append(f">>> {line.rstrip()}")
grouped_lines.append(lines)
return grouped_lines


def pytree_children(node: Any) -> Sequence[Tuple[KeyEntry, Any]]:
    """Returns the (key, value) pairs for the immediate children of a pytree `node`.

    The returned children match those returned by
    `jax.tree_util.default_registry.flatten_one_level()`.

    Reference: jax._src.tree_util.generate_key_paths()

    Example:
        ```
        assert pytree_children(dict(a=[1,2])) == [(DictKey('a'), [1,2])]
        ```

    Args:
        node: The pytree node whose direct children are requested.

    Returns:
        A sequence of (key, child) pairs. Empty if `flatten_one_level()` returns None for `node`.
    """
    # Prefer a handler registered with key paths for this exact type, if there is one.
    # pylint: disable-next=protected-access
    handler = jax._src.tree_util._registry_with_keypaths.get(type(node))
    if handler:
        children_with_keys, _ = handler.flatten_with_keys(node)
        return children_with_keys

    one_level = jax.tree_util.default_registry.flatten_one_level(node)
    if one_level is None:
        return []

    # Handle namedtuple as a special case, based on heuristic.
    looks_like_namedtuple = (
        isinstance(node, tuple) and hasattr(node, "_fields") and one_level[1] == type(node)
    )
    if not looks_like_namedtuple:
        # No key information available: fall back to positional keys.
        return [
            (jax.tree_util.FlattenedIndexKey(index), child)
            for index, child in enumerate(one_level[0])
        ]
    return [
        (jax.tree_util.GetAttrKey(field), getattr(node, field)) for field in node._fields
    ]


def find_cycles(tree: Nested) -> dict[str, KeyPath]:
    """Find a cycle in pytree `tree` if one exists.

    This function finds a descendant which has reference equality with one of its own
    ancestors, if one exists.

    Args:
        tree: The tree to find cycles in.

    Returns:
        If no cycle is found, an empty dict.
        If a cycle is found a dict with keys:
        * descendant: The KeyPath to the descendant.
        * ancestor: The KeyPath to the ancestor.
    """

    def _find_cycles(
        node: Nested, *, key_path: KeyPath, depth_by_id: dict[int, int]
    ) -> dict[str, KeyPath]:
        # DFS, checking whether the path from the root to `node` already contains `node`.
        # `depth_by_id` maps id() of every node on the current root-to-node path to its depth,
        # so the ancestor check is a constant-time dict lookup instead of a list scan.
        uid = id(node)
        ancestor_depth = depth_by_id.get(uid)
        if ancestor_depth is not None:
            # The ancestor sits at depth `ancestor_depth`, so its key path is that prefix.
            return dict(descendant=key_path[:], ancestor=key_path[:ancestor_depth])
        depth_by_id[uid] = len(key_path)
        for key, child in pytree_children(node):
            key_path.append(key)
            result = _find_cycles(child, key_path=key_path, depth_by_id=depth_by_id)
            key_path.pop()
            if result:
                return result
        # `node` is no longer on the current path once all of its children are explored.
        del depth_by_id[uid]
        return {}

    return _find_cycles(tree, key_path=[], depth_by_id={})


def raise_for_cycles(tree: Any):
    """Raises an informative error if `tree` contains cycles.

    Args:
        tree: The pytree to check.

    Raises:
        ValueError: If `find_cycles(tree)` reports a cycle. The message includes the key paths
            of the descendant and of the ancestor it refers to.
    """
    cycles = find_cycles(tree)
    if not cycles:
        # Nothing to report.
        return
    raise ValueError(
        "Circular reference in args, kwargs, or context.\n"
        "Descendant refers to ancestor.\n"
        f"Descendant KeyPath: {cycles['descendant']}.\n"
        f"Ancestor KeyPath: {cycles['ancestor']}."
    )
Loading

0 comments on commit 20db3b0

Please sign in to comment.