From 736f8cf6fb9b8ea66a4b6b85423865d2a8e724e3 Mon Sep 17 00:00:00 2001
From: Michael Clifford
Date: Mon, 23 Sep 2024 17:10:23 -0400
Subject: [PATCH] pre-commit formatting

Signed-off-by: Michael Clifford
---
 README.md                                | 3 ++-
 src/instructlab/training/__init__.py     | 8 ++++++--
 src/instructlab/training/data_process.py | 2 --
 src/instructlab/training/main_ds.py      | 6 ++++--
 4 files changed, 12 insertions(+), 7 deletions(-)

diff --git a/README.md b/README.md
index 1f11efd6..92dc4295 100644
--- a/README.md
+++ b/README.md
@@ -415,6 +415,7 @@ run_training(
     train_args=training_args
 )
 ```
+
 If the machines above have shared storage, users can preprocess the training dataset a single time so that it can then be distributed to each machine with the following update:
 
 ```python
@@ -426,7 +427,7 @@ from instructlab.training import (
     DataProcessArgs,
     data_process as dp
 )
- 
+
 ...
 
 data_process_args = DataProcessArgs(
diff --git a/src/instructlab/training/__init__.py b/src/instructlab/training/__init__.py
index 499625a5..7b4bf5e3 100644
--- a/src/instructlab/training/__init__.py
+++ b/src/instructlab/training/__init__.py
@@ -28,9 +28,13 @@
 
 
 # defer import of main_ds
-def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs, process_data: bool = True) -> None:
+def run_training(
+    torch_args: TorchrunArgs, train_args: TrainingArgs, process_data: bool = True
+) -> None:
     """Wrapper around the main training job that calls torchrun."""
     # Local
     from .main_ds import run_training
 
-    return run_training(torch_args=torch_args, train_args=train_args, process_data=process_data)
+    return run_training(
+        torch_args=torch_args, train_args=train_args, process_data=process_data
+    )
diff --git a/src/instructlab/training/data_process.py b/src/instructlab/training/data_process.py
index 4eb642d2..6c1a20dd 100644
--- a/src/instructlab/training/data_process.py
+++ b/src/instructlab/training/data_process.py
@@ -221,10 +221,8 @@ def get_masked_and_orig_text(sample):
 
 
 def main(args: DataProcessArgs):
-
     if not os.path.exists(args.data_output_path):
         os.makedirs(args.data_output_path, exist_ok=True)
-
     print("\033[92m data arguments are:\033[0m")
     print("\033[36m" + args.model_dump_json() + "\033[0m")
     NUM_PROC = args.num_cpu_procs
diff --git a/src/instructlab/training/main_ds.py b/src/instructlab/training/main_ds.py
index 7d175b54..825d1920 100644
--- a/src/instructlab/training/main_ds.py
+++ b/src/instructlab/training/main_ds.py
@@ -618,7 +618,9 @@ def main(args):
 
 
 # public API
-def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs, process_data: bool = True) -> None:
+def run_training(
+    torch_args: TorchrunArgs, train_args: TrainingArgs, process_data: bool = True
+) -> None:
     """
     Wrapper around the main training job that calls torchrun.
     """
@@ -627,7 +629,7 @@ def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs, process_dat
         raise ValueError(
             f"the `max_batch_len` cannot be less than `max_seq_len`: {train_args.max_batch_len=} < {train_args.max_seq_len=}"
         )
- 
+
     if process_data:
         dp.main(
             DataProcessArgs(
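
Note: the README hunk above documents preprocessing the dataset once on shared storage and then launching training on each machine. The snippet below is a minimal usage sketch of that two-step flow, not part of the patch; it assumes `training_args` and `torchrun_args` were already constructed as in the README example this hunk touches, and the `DataProcessArgs` field names beyond those visible in the diff (`data_output_path`, `num_cpu_procs`) are assumptions about the config schema.

```python
# Hypothetical sketch: preprocess once on shared storage, then train on every
# node without reprocessing, using the reformatted run_training() signature.
from instructlab.training import (
    DataProcessArgs,
    data_process as dp,
    run_training,
)

# Step 1: run exactly once, from a single machine, writing to shared storage.
# Field names mirror the README's data_process_args example; treat any field
# not shown in the diff as an assumption.
dp.main(
    DataProcessArgs(
        data_output_path=training_args.data_output_dir,
        model_path=training_args.model_path,
        data_path=training_args.data_path,
        max_seq_len=training_args.max_seq_len,
        chat_tmpl_path=training_args.chat_tmpl_path,
    )
)

# Step 2: run on every machine, skipping preprocessing via the process_data
# keyword argument whose signature this patch reformats.
run_training(
    torch_args=torchrun_args,
    train_args=training_args,
    process_data=False,
)
```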