From 9c1e00ca6801b4fcd8f2d62d976709f67d8b6df3 Mon Sep 17 00:00:00 2001 From: Michael Clifford Date: Mon, 23 Sep 2024 17:10:23 -0400 Subject: [PATCH] pre-commit formatting Signed-off-by: Michael Clifford --- README.md | 2 +- src/instructlab/training/__init__.py | 8 ++++++-- src/instructlab/training/data_process.py | 2 -- src/instructlab/training/main_ds.py | 6 ++++-- 4 files changed, 11 insertions(+), 7 deletions(-) diff --git a/README.md b/README.md index b97cee33..5d9b8831 100644 --- a/README.md +++ b/README.md @@ -247,7 +247,7 @@ from instructlab.training import ( DataProcessArgs, data_process as dp ) - + ... data_process_args = DataProcessArgs( diff --git a/src/instructlab/training/__init__.py b/src/instructlab/training/__init__.py index b066366d..91a701ab 100644 --- a/src/instructlab/training/__init__.py +++ b/src/instructlab/training/__init__.py @@ -22,9 +22,13 @@ # defer import of main_ds -def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs, process_data: bool = True) -> None: +def run_training( + torch_args: TorchrunArgs, train_args: TrainingArgs, process_data: bool = True +) -> None: """Wrapper around the main training job that calls torchrun.""" # Local from .main_ds import run_training - return run_training(torch_args=torch_args, train_args=train_args, process_data=process_data) + return run_training( + torch_args=torch_args, train_args=train_args, process_data=process_data + ) diff --git a/src/instructlab/training/data_process.py b/src/instructlab/training/data_process.py index 9a61a404..77d32909 100644 --- a/src/instructlab/training/data_process.py +++ b/src/instructlab/training/data_process.py @@ -175,10 +175,8 @@ def get_masked_and_orig_text(sample): def main(args: DataProcessArgs): - if not os.path.exists(args.data_output_path): os.makedirs(args.data_output_path, exist_ok=True) - print("\033[92m data arguments are:\033[0m") print("\033[36m" + args.model_dump_json() + "\033[0m") NUM_PROC = args.num_cpu_procs diff --git a/src/instructlab/training/main_ds.py b/src/instructlab/training/main_ds.py index d368e620..71824f74 100644 --- a/src/instructlab/training/main_ds.py +++ b/src/instructlab/training/main_ds.py @@ -602,7 +602,9 @@ def main(args): # public API -def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs, process_data: bool = True) -> None: +def run_training( + torch_args: TorchrunArgs, train_args: TrainingArgs, process_data: bool = True +) -> None: """ Wrapper around the main training job that calls torchrun. """ @@ -611,7 +613,7 @@ def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs, process_dat raise ValueError( f"the `max_batch_len` cannot be less than `max_seq_len`: {train_args.max_batch_len=} < {train_args.max_seq_len=}" ) - + if process_data: dp.main( DataProcessArgs(