Make data processing optional in run_training() #220

Merged · 3 commits · Oct 7, 2024

52 changes: 52 additions & 0 deletions README.md
@@ -280,3 +280,55 @@ run_training(
    torchrun_args=torchrun_args,
    training_args=training_args,
)

```

## Example training with separate data pre-processing

If the machines in the example above have shared storage, users can pre-process the training dataset a single time so that it can then be distributed to each machine by making the following updates.

```python
from instructlab.training import (
    run_training,
    TorchrunArgs,
    TrainingArgs,
    DeepSpeedOptions,
    DataProcessArgs,
    data_process as dp
)

training_args = TrainingArgs(
    # define data-specific arguments
    model_path = "ibm-granite/granite-7b-base",
    data_path = "path/to/dataset.jsonl",
    ckpt_output_dir = "data/saved_checkpoints",
    data_output_dir = "data/outputs",

    # define model-training parameters
    max_seq_len = 4096,
    max_batch_len = 60000,
    num_epochs = 10,
    effective_batch_size = 3840,
    save_samples = 250000,
    learning_rate = 2e-6,
    warmup_steps = 800,
    is_padding_free = True, # set this to true when using Granite-based models
    random_seed = 42,
    process_data = True,
)
...

data_process_args = DataProcessArgs(
    data_output_path = training_args.data_output_dir,
    model_path = training_args.model_path,
    data_path = training_args.data_path,
    max_seq_len = training_args.max_seq_len,
    chat_tmpl_path = training_args.chat_tmpl_path
)

dp.main(data_process_args)
run_training(
    torch_args=torchrun_args,
    train_args=training_args,
)
```
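
For context only (this paragraph and sketch are editorial, not part of the PR's diff): on the machines that reuse the shared, already-processed output, the new `process_data` flag can be set to `False` so that `run_training()` skips its internal `dp.main()` call. The sketch below assumes the `TorchrunArgs` fields used in the project's first README example; all concrete values are placeholders.

```python
# Hedged sketch: a node that reuses pre-processed data from shared storage.
from instructlab.training import run_training, TorchrunArgs, TrainingArgs

torchrun_args = TorchrunArgs(
    nnodes = 2,                        # placeholder multi-node setup
    node_rank = 1,                     # this machine did not run dp.main() itself
    nproc_per_node = 8,
    rdzv_id = 123,
    rdzv_endpoint = "10.0.0.1:29500",  # placeholder rendezvous endpoint
)

training_args = TrainingArgs(
    model_path = "ibm-granite/granite-7b-base",
    data_path = "path/to/dataset.jsonl",
    ckpt_output_dir = "data/saved_checkpoints",
    data_output_dir = "data/outputs",  # shared storage already holding the processed data
    max_seq_len = 4096,
    max_batch_len = 60000,
    num_epochs = 10,
    effective_batch_size = 3840,
    save_samples = 250000,
    learning_rate = 2e-6,
    warmup_steps = 800,
    is_padding_free = True,
    random_seed = 42,
    process_data = False,  # skip the dp.main() call inside run_training()
)

run_training(
    torch_args=torchrun_args,
    train_args=training_args,
)
```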
3 changes: 3 additions & 0 deletions src/instructlab/training/config.py
@@ -199,3 +199,6 @@ class TrainingArgs(BaseModel):
    # https://github.com/instructlab/training/issues/28
    # quantize_dtype: QuantizeDataType = QuantizeDataType.NONE
    lora: LoraOptions | None = None

    # This field defines whether or not data processing will occur inside of `run_training()`
    process_data: Optional[bool] = True
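
As a quick illustration of the new field's semantics (a toy Pydantic model, not the project's class): an `Optional[bool]` field defaulting to `True` means existing callers that never set it keep today's behavior, while new callers can opt out explicitly.

```python
from typing import Optional

from pydantic import BaseModel


class ToyArgs(BaseModel):
    # Mirrors the new TrainingArgs field: defaults to True so behavior is
    # unchanged for callers that never set it.
    process_data: Optional[bool] = True


print(ToyArgs().process_data)                    # True  -> run_training() still processes data
print(ToyArgs(process_data=False).process_data)  # False -> run_training() skips dp.main()
```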
2 changes: 2 additions & 0 deletions src/instructlab/training/data_process.py
@@ -221,6 +221,8 @@ def get_masked_and_orig_text(sample):


def main(args: DataProcessArgs):
    if not os.path.exists(args.data_output_path):
        os.makedirs(args.data_output_path, exist_ok=True)
    print("\033[92m data arguments are:\033[0m")
    print("\033[36m" + args.model_dump_json() + "\033[0m")
    NUM_PROC = args.num_cpu_procs
32 changes: 15 additions & 17 deletions src/instructlab/training/main_ds.py
@@ -645,24 +645,22 @@ def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs) -> None:
f"the `max_batch_len` cannot be less than `max_seq_len`: {train_args.max_batch_len=} < {train_args.max_seq_len=}"
)

-    # process the training data
-    if not os.path.exists(train_args.data_output_dir):
-        os.makedirs(train_args.data_output_dir, exist_ok=True)
-    dp.main(
-        DataProcessArgs(
-            # XXX(osilkin): make a decision here, either:
-            # 1. the CLI is fully responsible for managing where the data is written
-            # 2. we never cache it and simply write it to a tmp file every time.
-            #
-            # An important reason for why #1 would be preferable is in the case of OpenShift/SELinux
-            # where the user has a defined place for new temporary data to be written.
-            data_output_path=train_args.data_output_dir,
-            model_path=train_args.model_path,
-            data_path=train_args.data_path,
-            max_seq_len=train_args.max_seq_len,
-            chat_tmpl_path=train_args.chat_tmpl_path,
-        )
-    )
+    if train_args.process_data:
+        dp.main(
+            DataProcessArgs(
+                # XXX(osilkin): make a decision here, either:
+                # 1. the CLI is fully responsible for managing where the data is written
+                # 2. we never cache it and simply write it to a tmp file every time.
+                #
+                # An important reason for why #1 would be preferable is in the case of OpenShift/SELinux
+                # where the user has a defined place for new temporary data to be written.
+                data_output_path=train_args.data_output_dir,
+                model_path=train_args.model_path,
+                data_path=train_args.data_path,
+                max_seq_len=train_args.max_seq_len,
+                chat_tmpl_path=train_args.chat_tmpl_path,
+            )
+        )

if not os.path.exists(train_args.ckpt_output_dir):
os.makedirs(train_args.ckpt_output_dir, exist_ok=True)