Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add a tp -i for interactive workflows #48

Merged
merged 1 commit into from
Jan 22, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion .github/workflows/cpu_test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -30,4 +30,6 @@ jobs:
# TODO(https://github.com/AI-Hypercomputer/torchprime/issues/14): Remove and burn the token.
HF_TOKEN: ${{ secrets.HF_TOKEN }}
run: |
pytest
export PJRT_DEVICE=CPU
export JAX_PLATFORMS=cpu
pytest -v
6 changes: 6 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,12 @@ How to run tests:
pytest
```

How to run some of the tests, and re-run them whenever you change a file:

```sh
tp -i test ... # replace `...` with the test files or directories to run
```

How to format:

```sh
Expand Down
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ dev = [
"click~=8.1.8",
"toml~=0.10.2",
"dataclasses-json~=0.6.7",
"watchdog~=6.0.0",
"pathspec~=0.12.1",
"xpk@git+https://github.com/AI-Hypercomputer/xpk"
]

Expand Down
148 changes: 145 additions & 3 deletions torchprime/launcher/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,23 @@
tp is a CLI for common torchprime workflows.
"""

import json
import json # noqa: I001
import os
import subprocess
import sys
import time
import threading
from dataclasses import dataclass
from datetime import datetime
from pathlib import Path

import click
from pathspec import PathSpec
from pathspec.patterns import GitWildMatchPattern # type: ignore
import toml
from dataclasses_json import dataclass_json
from watchdog.events import FileSystemEventHandler
from watchdog.observers import Observer


@dataclass_json
Expand All @@ -25,12 +32,26 @@ class Config:
artifact_dir: str


def interactive(f):
    """Decorate a click command callback so it honors the group-level ``-i`` flag.

    The returned callback takes the click context as its first argument (via
    ``click.pass_context``) and hands ``f`` to ``run_with_watcher``, which
    reads ``ctx.obj["interactive"]`` to decide between calling ``f`` once and
    starting the file watcher.

    Args:
        f: The command callback to wrap.

    Returns:
        A context-receiving callback carrying the name and docstring of ``f``.
    """
    # Function-scope import keeps this block self-contained.
    import functools

    # `functools.wraps` copies __name__ *and* __doc__ from `f`. click builds a
    # command's help text from the callback's docstring, so copying only
    # __name__ left the wrapped commands without any `--help` description.
    @click.pass_context
    @functools.wraps(f)
    def wrapper(ctx, *args, **kwargs):
        return run_with_watcher(ctx)(f)(*args, **kwargs)

    return wrapper


@click.group()
@click.option(
    "-i",
    "--interactive",
    is_flag=True,
    default=False,
    help="Re-run the command whenever a file in the project directory changes.",
)
@click.pass_context
def cli(ctx, interactive):
    """
    tp is a CLI for common torchprime workflows.
    """
    # Store the flag on the shared context object so that sub-commands can read
    # it back through `ctx.obj["interactive"]`.
    ctx.ensure_object(dict)
    ctx.obj["interactive"] = interactive


@cli.command()
Expand Down Expand Up @@ -142,6 +163,7 @@ def create_and_activate_gcloud(gcloud_config_name, config: Config):
)
)
@click.argument("args", nargs=-1, type=click.UNPROCESSED)
@interactive
def run(args):
"""
Runs the provided SPMD training command as an xpk job on a GKE cluster.
Expand Down Expand Up @@ -196,6 +218,24 @@ def run(args):
subprocess.run(xpk_command, check=True)


@cli.command(context_settings={"ignore_unknown_options": True})
@click.argument("args", nargs=-1, type=click.UNPROCESSED)
@interactive
def test(args):
    """
    Runs unit tests in torchprime by forwarding arguments to pytest.
    """
    ensure_command("pytest")
    # Hand every extra CLI argument to pytest untouched and mirror its exit
    # status: a non-zero return code becomes this process's exit code.
    completed = subprocess.run(["pytest", *args], check=False)
    if completed.returncode != 0:
        sys.exit(completed.returncode)


def forward_env(name: str) -> list[str]:
if name in os.environ:
return ["--env", f"{name}={os.environ[name]}"]
Expand Down Expand Up @@ -242,5 +282,107 @@ def ensure_command(name: str):
) from err


class FileChangeHandler(FileSystemEventHandler):
    """Re-runs the current `tp` command when a non-ignored project file changes.

    A daemon worker thread runs the command once at start-up and then once per
    accepted modification event. Events are rejected when they concern a
    directory, match the given gitignore spec, live under a `.git` directory,
    or arrive within one second of the previously accepted event.

    Modifications reported while the command is still running are dropped, so
    files written by the command itself do not schedule another run.
    """

    def __init__(self, command_context, gitignore_spec):
        """Store the collaborators and start the worker thread.

        Args:
            command_context: The click context of the running command. It is
                only stored; nothing in this class reads it.
            gitignore_spec: Object whose `match_file(relative_path)` tells
                whether a path should be ignored.
        """
        super().__init__()
        self.command_context = command_context
        self.gitignore_spec = gitignore_spec
        self.last_trigger_time = time.time()
        self.last_modified_file = ""
        self.file_modified = threading.Condition()
        # A bare `notify()` is lost if the worker has not reached `wait()` yet,
        # which made the initial run racy. A flag set *before* the worker
        # starts is always observed by `wait_for`.
        self.run_requested = True
        self.run_command_thread = threading.Thread(
            target=self.run_command_thread_fn, daemon=True
        )
        self.run_command_thread.start()

    @staticmethod
    def _strip_interactive_flag(argv):
        """Return `argv` without the group-level `-i` / `--interactive` flag.

        Only options placed before the sub-command name are inspected; tokens
        after it are forwarded untouched, so a `-i` meant for the sub-command
        survives.
        """
        kept = []
        for index, arg in enumerate(argv):
            if not arg.startswith("-"):
                # First non-option token is the sub-command; keep the rest as is.
                kept.extend(argv[index:])
                break
            if arg not in ("-i", "--interactive"):
                kept.append(arg)
        return kept

    def on_modified(self, event):
        """Request a re-run for an accepted file modification event."""
        if event.is_directory:
            return

        # Check if file matches gitignore patterns
        relative_path = os.path.relpath(str(event.src_path), str(get_project_dir()))
        if self.gitignore_spec.match_file(relative_path):
            return

        # Exclude `.git` directory
        if ".git" in relative_path.split(os.sep):
            return

        # Debounce frequent modifications.
        current_time = time.time()
        if current_time - self.last_trigger_time <= 1:
            return
        self.last_trigger_time = current_time

        # Signal the worker thread that a file has been modified.
        with self.file_modified:
            self.last_modified_file = str(event.src_path)
            self.run_requested = True
            self.file_modified.notify()

    def run_command_thread_fn(self):
        """Worker loop: wait for a run request, then execute `tp` without `-i`."""
        import shlex

        # Build the argument list once, without mutating `sys.argv`. Passing a
        # list (no shell) keeps arguments containing spaces or quotes intact.
        command = ["tp", *self._strip_interactive_flag(sys.argv[1:])]
        printable_command = shlex.join(command)

        while True:
            with self.file_modified:
                self.file_modified.wait_for(lambda: self.run_requested)
                self.run_requested = False
                last_modified_file = self.last_modified_file
            if last_modified_file:
                click.echo(
                    f"\nFile {last_modified_file} modified, rerunning command...\n"
                )
            try:
                subprocess.run(command, check=False)
            except OSError as err:
                # Keep the watcher alive if `tp` cannot be launched.
                click.echo(f"Failed to run `{printable_command}`: {err}", err=True)
            click.echo(f"\nDone running `{printable_command}`.\n")
            with self.file_modified:
                # Discard requests raised while the command was running.
                self.run_requested = False


def watch_directory(project_dir, command_context):
    """Watch `project_dir` recursively until interrupted with Ctrl+C.

    Patterns from `<project_dir>/.gitignore` (when that file exists) are
    compiled into a `PathSpec` and handed to a `FileChangeHandler`, which is
    registered on a watchdog `Observer`. The call blocks in a sleep loop; a
    `KeyboardInterrupt` stops the observer and waits for it to finish.
    """
    ignore_file = os.path.join(project_dir, ".gitignore")
    ignore_lines = []
    if os.path.exists(ignore_file):
        with open(ignore_file) as fh:
            ignore_lines = fh.readlines()

    spec = PathSpec.from_lines(GitWildMatchPattern, ignore_lines)

    handler = FileChangeHandler(command_context, spec)
    watcher = Observer()
    watcher.schedule(handler, project_dir, recursive=True)
    watcher.start()

    try:
        # Nothing to do on this thread except stay alive for the observer.
        while True:
            time.sleep(1)
    except KeyboardInterrupt:
        watcher.stop()
    watcher.join()


def run_with_watcher(ctx):
    """Build a decorator that runs a command once or under the file watcher.

    When `ctx.obj["interactive"]` is falsy or missing, the decorated function
    is called directly and its result returned. Otherwise the project
    directory is watched and the function itself is not called here.
    """

    def decorator(f):
        def wrapper(*args, **kwargs):
            # Plain, one-shot invocation.
            if not ctx.obj.get("interactive"):
                return f(*args, **kwargs)

            # Interactive mode: hand control to the watcher until Ctrl+C.
            project_dir = get_project_dir()
            click.echo(
                f"Watching directory {project_dir} for changes. Press Ctrl+C to stop.\n"
            )
            watch_directory(project_dir, ctx)
            return None

        return wrapper

    return decorator


# Invoke the click group when this module is executed as a script.
if __name__ == "__main__":
    cli()
Loading