Skip to content

Commit

Permalink
add debug option to show interpreter progress (#1680)
Browse files Browse the repository at this point in the history
  • Loading branch information
t-vi authored Jan 27, 2025
1 parent 6862a8e commit 98f286c
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 4 deletions.
28 changes: 27 additions & 1 deletion thunder/core/jit_ext.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
import contextvars
from contextlib import contextmanager
import dis
import sys
import time
import warnings
from types import (
BuiltinMethodType,
Expand Down Expand Up @@ -160,6 +162,7 @@ def __init__(
sharp_edges: SHARP_EDGES_OPTIONS,
executor_lookasides,
ad_hoc_executor,
show_interpreter_progress: bool = False,
):
self._sharp_edges: SHARP_EDGES_OPTIONS = sharp_edges
self._prologue_trace = prologue_trace
Expand All @@ -169,6 +172,26 @@ def __init__(
self._proxy_swapmap: dict[Variable, Proxy] = {}
self._executor_lookasides: dict[Callable, Callable] = executor_lookasides
self._ad_hoc_executor = ad_hoc_executor
self._show_interpreter_progress = show_interpreter_progress
self._last_printed_progress = -1
self._progress_interval = 10

def show_progress_if_verbose(self):
if not self._show_interpreter_progress:
return
t = time.perf_counter()
if t > self._last_printed_progress + self._progress_interval:
if self._last_printed_progress == -1:
print("Begin interpretation", file=sys.stderr)
num_bsyms = len(self._computation_trace.bound_symbols)
print(f"\rcaptured {num_bsyms} operations", end="", file=sys.stderr)
self._last_printed_progress = t

def end_progress(self):
if not self._show_interpreter_progress or self._last_printed_progress == -1:
return
self.show_progress_if_verbose()
print("\nFinished interpretation", file=sys.stderr)

@property
def ad_hoc_executor(self):
Expand Down Expand Up @@ -1296,6 +1319,7 @@ def should_register_for_prologue(pr, _toplevel=True):
def _general_jit_wrap_callback(value):
ctx: JitCtx = get_jit_ctx()

ctx.show_progress_if_verbose()
uvalue = value.value
# for modules, rewrite m.__dict__["key"] to m.key
if (
Expand Down Expand Up @@ -1864,6 +1888,7 @@ def update_tags(proxy_swapmap: dict[Variable, Proxy]) -> None:
DebugOptions.register_option(
"record_interpreter_history", bool, False, "record interpreter history (use thunder.last_interpreter_log to access)"
)
DebugOptions.register_option("show_interpreter_progress", bool, False, "show progress while running the interpreter")


def build_value_from_wrapped(wrapped_v):
Expand Down Expand Up @@ -1900,7 +1925,6 @@ def build_value_from_wrapped(wrapped_v):
#
# NOTE: The `class` packagename1_MyContainer will present in `import_ctx` and passed to the compiled function.
# This is taken care of by function `to_printable`.

kwargs = {}
for k, uattr in v.__dict__.items():
if k in wrapped_v.attribute_wrappers:
Expand Down Expand Up @@ -1973,6 +1997,7 @@ def thunder_general_jit(
sharp_edges=sharp_edges,
executor_lookasides=executor_lookasides,
ad_hoc_executor=ad_hoc_executor,
show_interpreter_progress=compile_data.debug_options.show_interpreter_progress,
)
jfn = interpret(
fn,
Expand All @@ -1997,6 +2022,7 @@ def thunder_general_jit(
res = build_value_from_wrapped(result)
prims.python_return(res)

ctx.end_progress()
pro_to_comp, pro_to_comp_set, computation_intermediates = get_computation_inputs_and_intermediates(
computation_trace
)
Expand Down
5 changes: 2 additions & 3 deletions thunder/tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -3097,11 +3097,12 @@ def test_debug_options():
import dill

initial_state = dill.dumps(dict(DebugOptions.__dict__))
print(DebugOptions.__dict__)
DebugOptions.register_option("test_option", bool, False, "Test Option")

assert "Test Option" in DebugOptions.__doc__

do = DebugOptions()
assert do.test_option is False
do = DebugOptions(test_option=True)
assert do.test_option is True

Expand All @@ -3114,8 +3115,6 @@ def test_debug_options():
del DebugOptions.test_option

DebugOptions._set_docstring()

print(DebugOptions.__dict__)
assert dill.dumps(dict(DebugOptions.__dict__)) == initial_state


Expand Down

0 comments on commit 98f286c

Please sign in to comment.