Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use regex for parsing step_dir #739

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions axlearn/common/checkpointer.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import enum
import json
import os.path
import re
import tempfile
import threading
import time
Expand Down Expand Up @@ -81,8 +82,10 @@ class CheckpointValidationType(str, enum.Enum):


def parse_step_from_dir(step_dir: str) -> int:
# TODO(markblee): use regex.
return int(step_dir[-STEP_NUM_DIGITS:])
step = re.findall(r'(\d{' + str(STEP_NUM_DIGITS) + r'})', step_dir)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks! I do think we should be restrictive about the patterns that we accept, specifically that it ends with STEP_NUM_DIGITS digits. Should we update to a pattern that matches that and also add a test case?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We run into an issue quite often where we'll have some checkpoint suffixed (i.e. checkpoints/step_00001000-ext) where the only way to run this with axlearn is to copy this directory to step_00001000. How would you suggest we incorporate a pattern like this? Happy to add test cases to this PR as well.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll probably need to understand more about your use-case and why we need to suffix checkpoints. Can you open an internal PR for discussion?

if len(step) < 1:
raise ValueError(f"Could not find step in '{step_dir}'")
return int(step[-1])


def check_state_structure(
Expand Down