pre-commit formatting
Signed-off-by: Michael Clifford <[email protected]>
MichaelClifford committed Oct 1, 2024
1 parent cb4450d commit 736f8cf
Showing 4 changed files with 12 additions and 7 deletions.
3 changes: 2 additions & 1 deletion README.md
@@ -415,6 +415,7 @@ run_training(
train_args=training_args
)
```

If the machines above have shared storage, users can preprocess the training dataset a single time so that it can then be distributed to each machine with the following update:

```python
@@ -426,7 +427,7 @@ from instructlab.training (
DataProcessArgs,
data_process as dp
)

...

data_process_args = DataProcessArgs(
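The diff view collapses the rest of that snippet. As a rough sketch of the shared-storage workflow described above, the preprocessing step might be run once from any machine that can reach the shared filesystem; the paths, values, and field names here are assumptions for illustration rather than values taken from this repository:

```python
from instructlab.training import (
    DataProcessArgs,
    data_process as dp,
)

# Hypothetical shared-storage paths; only data_output_path is referenced
# directly in this commit's diff, the other field names are assumed.
data_process_args = DataProcessArgs(
    data_output_path="/shared/data/processed",  # processed dataset, visible to all nodes
    data_path="/shared/data/train.jsonl",       # raw training dataset
    model_path="/shared/models/base-model",     # model whose tokenizer/template is used
    max_seq_len=4096,
)

# Run the preprocessing a single time; every machine then reads the result.
dp.main(data_process_args)
```

Each machine can then launch training against the already-processed data and skip per-node preprocessing via the `process_data` flag shown in the wrappers below.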
8 changes: 6 additions & 2 deletions src/instructlab/training/__init__.py
@@ -28,9 +28,13 @@


# defer import of main_ds
def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs, process_data: bool = True) -> None:
def run_training(
torch_args: TorchrunArgs, train_args: TrainingArgs, process_data: bool = True
) -> None:
"""Wrapper around the main training job that calls torchrun."""
# Local
from .main_ds import run_training

return run_training(torch_args=torch_args, train_args=train_args, process_data=process_data)
return run_training(
torch_args=torch_args, train_args=train_args, process_data=process_data
)
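For illustration only, a caller might combine this wrapper with a dataset that was already processed (for example via the shared-storage step above). Every `TorchrunArgs`/`TrainingArgs` value below is a placeholder, and the exact set of required fields depends on the installed version of `instructlab.training`:

```python
from instructlab.training import TorchrunArgs, TrainingArgs, run_training

# Placeholder distributed-run settings for a two-node job (node 0 shown).
torch_args = TorchrunArgs(
    nnodes=2,
    nproc_per_node=8,
    node_rank=0,
    rdzv_id=123,
    rdzv_endpoint="node0:54321",
)

# Placeholder training settings; note max_batch_len >= max_seq_len,
# which the check in main_ds.run_training (below) enforces.
training_args = TrainingArgs(
    model_path="/shared/models/base-model",
    data_path="/shared/data/train.jsonl",
    data_output_dir="/shared/data/processed",
    ckpt_output_dir="/shared/checkpoints",
    max_seq_len=4096,
    max_batch_len=60000,
    num_epochs=1,
    effective_batch_size=128,
    learning_rate=1e-5,
    warmup_steps=25,
    save_samples=0,
)

# Skip preprocessing because the shared processed dataset already exists.
run_training(torch_args=torch_args, train_args=training_args, process_data=False)
```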
2 changes: 0 additions & 2 deletions src/instructlab/training/data_process.py
@@ -221,10 +221,8 @@ def get_masked_and_orig_text(sample):


def main(args: DataProcessArgs):

if not os.path.exists(args.data_output_path):
os.makedirs(args.data_output_path, exist_ok=True)

print("\033[92m data arguments are:\033[0m")
print("\033[36m" + args.model_dump_json() + "\033[0m")
NUM_PROC = args.num_cpu_procs
6 changes: 4 additions & 2 deletions src/instructlab/training/main_ds.py
@@ -618,7 +618,9 @@ def main(args):


# public API
def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs, process_data: bool = True) -> None:
def run_training(
torch_args: TorchrunArgs, train_args: TrainingArgs, process_data: bool = True
) -> None:
"""
Wrapper around the main training job that calls torchrun.
"""
@@ -627,7 +629,7 @@ def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs, process_dat
raise ValueError(
f"the `max_batch_len` cannot be less than `max_seq_len`: {train_args.max_batch_len=} < {train_args.max_seq_len=}"
)

if process_data:
dp.main(
DataProcessArgs(
Expand Down
