
Commit

pre-commit formatting
Signed-off-by: Michael Clifford <[email protected]>
MichaelClifford committed Sep 23, 2024
1 parent 631d48c commit 9c1e00c
Showing 4 changed files with 11 additions and 7 deletions.
README.md: 2 changes (1 addition, 1 deletion)
@@ -247,7 +247,7 @@ from instructlab.training import (
     DataProcessArgs,
     data_process as dp
 )
-
+
 ...
 
 data_process_args = DataProcessArgs(
src/instructlab/training/__init__.py: 8 changes (6 additions, 2 deletions)
@@ -22,9 +22,13 @@
 
 
 # defer import of main_ds
-def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs, process_data: bool = True) -> None:
+def run_training(
+    torch_args: TorchrunArgs, train_args: TrainingArgs, process_data: bool = True
+) -> None:
     """Wrapper around the main training job that calls torchrun."""
     # Local
     from .main_ds import run_training
 
-    return run_training(torch_args=torch_args, train_args=train_args, process_data=process_data)
+    return run_training(
+        torch_args=torch_args, train_args=train_args, process_data=process_data
+    )
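
Since this wrapper is the package's public entry point, a minimal usage sketch may help; it assumes TorchrunArgs and TrainingArgs are importable from the package root (they appear unqualified in the signature above), and every field name and value shown is an assumption for illustration, not taken from this commit:

from instructlab.training import TorchrunArgs, TrainingArgs, run_training

# torchrun-style launch settings; field names assumed to mirror torchrun flags
torch_args = TorchrunArgs(
    nnodes=1,
    nproc_per_node=1,
    node_rank=0,
    rdzv_id=123,
    rdzv_endpoint="127.0.0.1:12345",
)
# hypothetical training configuration; consult TrainingArgs for the real fields
train_args = TrainingArgs(
    model_path="path/to/model",
    data_path="path/to/dataset.jsonl",
    ckpt_output_dir="checkpoints",
    data_output_dir="data/outputs",
    max_seq_len=4096,
    max_batch_len=60000,  # must be >= max_seq_len (validated in main_ds.py below)
    num_epochs=1,
)
# defers `from .main_ds import run_training` until call time, keeping
# `import instructlab.training` cheap
run_training(torch_args=torch_args, train_args=train_args, process_data=True)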
src/instructlab/training/data_process.py: 2 changes (0 additions, 2 deletions)
@@ -175,10 +175,8 @@ def get_masked_and_orig_text(sample):
 
 
 def main(args: DataProcessArgs):
-
     if not os.path.exists(args.data_output_path):
         os.makedirs(args.data_output_path, exist_ok=True)
-
     print("\033[92m data arguments are:\033[0m")
     print("\033[36m" + args.model_dump_json() + "\033[0m")
     NUM_PROC = args.num_cpu_procs
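
For context, main above is the data-processing entry point that run_training invokes when process_data=True (see the main_ds.py hunk below). A minimal sketch of calling it directly; data_path is a hypothetical input field, while data_output_path and num_cpu_procs appear in the diff above:

from instructlab.training import DataProcessArgs, data_process as dp

dp.main(
    DataProcessArgs(
        data_path="dataset.jsonl",        # hypothetical input field
        data_output_path="data/outputs",  # created by main() if missing
        num_cpu_procs=8,                  # bound to NUM_PROC inside main()
    )
)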
src/instructlab/training/main_ds.py: 6 changes (4 additions, 2 deletions)
@@ -602,7 +602,9 @@ def main(args):
 
 
 # public API
-def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs, process_data: bool = True) -> None:
+def run_training(
+    torch_args: TorchrunArgs, train_args: TrainingArgs, process_data: bool = True
+) -> None:
     """
     Wrapper around the main training job that calls torchrun.
     """
@@ -611,7 +613,7 @@ def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs, process_data: bool = True) -> None:
         raise ValueError(
             f"the `max_batch_len` cannot be less than `max_seq_len`: {train_args.max_batch_len=} < {train_args.max_seq_len=}"
         )
-
+
     if process_data:
         dp.main(
             DataProcessArgs(
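
The ValueError in the second hunk rejects configurations where a single maximum-length sample could never fit into one packed batch. Behaviorally the guard reduces to this standalone check (values hypothetical):

max_seq_len = 4096    # longest permitted sample, in tokens
max_batch_len = 2048  # token budget for one packed batch
if max_batch_len < max_seq_len:
    raise ValueError(
        f"the `max_batch_len` cannot be less than `max_seq_len`: {max_batch_len=} < {max_seq_len=}"
    )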
