Merge pull request #947 from AI-Hypercomputer:rdyro-dpo
PiperOrigin-RevId: 706041818
maxtext authors committed Dec 14, 2024
2 parents e4e0a4f + e8fe1ac commit 1e39608
Showing 10 changed files with 419 additions and 91 deletions.
18 changes: 8 additions & 10 deletions MaxText/checkpointing.py
@@ -146,6 +146,7 @@ def load_state_if_possible(
abstract_unboxed_pre_state: train_state.TrainState,
enable_single_replica_ckpt_restoring: Optional[bool] = False,
dataset_type: Optional[str] = "tfds",
step: int = -1, # -1 means latest
):
"""Loads TrainState as possible from the inputs.
@@ -171,12 +172,9 @@
if checkpoint_manager is not None:
max_logging.log("checkpoint manager exists so trying to load this run's existing checkpoint")

latest_step = checkpoint_manager.latest_step()
if latest_step is not None:
max_logging.log(
f"restoring from this run's directory latest step \
{latest_step}"
)
step = checkpoint_manager.latest_step() if step < 0 else step
if step is not None:
max_logging.log(f"restoring from this run's directory step {step}")

def map_to_pspec(data):
pspec = data.sharding.spec
@@ -210,19 +208,19 @@ def map_to_pspec(data):
if isinstance(checkpoint_manager, emergency_checkpoint_manager.CheckpointManager):
return (
checkpoint_manager.restore(
latest_step,
step,
args=ocp.args.PyTreeRestore(item=abstract_unboxed_pre_state, restore_args=restore_args),
),
None,
)
if (
dataset_type == "grain"
and data_iterator is not None
and (checkpoint_manager.directory / str(latest_step) / "iter").exists()
and (checkpoint_manager.directory / str(step) / "iter").exists()
):
return (
checkpoint_manager.restore(
latest_step,
step,
args=ocp.args.Composite(
items=ocp.args.PyTreeRestore(
item=abstract_unboxed_pre_state,
Expand All @@ -236,7 +234,7 @@ def map_to_pspec(data):
else:
return (
checkpoint_manager.restore(
latest_step,
step,
args=ocp.args.Composite(
items=ocp.args.PyTreeRestore(
item=abstract_unboxed_pre_state,
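The new `step` argument makes the restore step explicit while preserving the old default. A minimal sketch of the resolution logic, assuming an Orbax-style `CheckpointManager`; the wrapper function below is hypothetical and only restates the two lines added above:

```python
from typing import Optional

import orbax.checkpoint as ocp


def resolve_restore_step(checkpoint_manager: ocp.CheckpointManager, step: int = -1) -> Optional[int]:
  # step = -1 (the default) keeps the previous behaviour: restore the latest checkpoint.
  # step >= 0 restores that specific checkpoint instead.
  step = checkpoint_manager.latest_step() if step < 0 else step
  return step  # may still be None if the run has no checkpoints yet
```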
10 changes: 8 additions & 2 deletions MaxText/configs/base.yml
@@ -310,8 +310,14 @@ per_device_batch_size: 12.0
expansion_factor_real_data: -1 # if -1 then all hosts will load real data, else total_hosts//expansion_factor_real_data will pull data from GCS.
eval_per_device_batch_size: 0.0
max_corpus_chars: 10_000_000
train_data_column: 'text'
eval_data_column: 'text'
train_data_columns: ['text'] # for DPO dataset containing "chosen" and "rejected"
eval_data_columns: ['text'] # for DPO dataset containing "chosen" and "rejected"

# direct preference optimization (DPO)
use_dpo: False
dpo_label_smoothing: 0.0
dpo_beta: 0.1

# dataset_type must be synthetic, hf, grain, tfds
# details in: https://github.com/google/maxtext/blob/main/getting_started/Data_Input_Pipeline.md
dataset_type: tfds
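For reference, `dpo_beta` and `dpo_label_smoothing` are the two hyperparameters of the (conservative) DPO objective over preference pairs with prompt x, chosen completion y_w, and rejected completion y_l. This is the standard formulation these knobs usually parameterize, not a quote of the MaxText loss code (which lives in files not expanded on this page):

$$
\mathcal{L}_{\mathrm{DPO}} = -\,\mathbb{E}_{(x,\,y_w,\,y_l)}\left[(1-\epsilon)\,\log\sigma\!\big(\beta\,\Delta\big) + \epsilon\,\log\sigma\!\big(-\beta\,\Delta\big)\right],
\qquad
\Delta = \log\frac{\pi_\theta(y_w\mid x)}{\pi_{\mathrm{ref}}(y_w\mid x)} - \log\frac{\pi_\theta(y_l\mid x)}{\pi_{\mathrm{ref}}(y_l\mid x)},
$$

where $\beta$ is `dpo_beta` (the implicit-reward scale) and $\epsilon$ is `dpo_label_smoothing`; $\epsilon = 0$ recovers plain DPO.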
31 changes: 31 additions & 0 deletions MaxText/configs/dpo.yml
@@ -0,0 +1,31 @@
base_config: "base.yml"

use_dpo: true
train_data_columns: ['chosen', 'rejected']
eval_data_columns: ['chosen', 'rejected']
base_output_directory: 'gs://maxtext-external/logs'

per_device_batch_size: 2.0
steps: 10
max_target_length: 512
eval_interval: 5 # test eval once, in the middle of 10 training steps
eval_steps: 2

# TFDS Pipeline ----------------------
dataset_type: tfds
dataset_path: 'gs://maxtext-dataset/dpo/anthropic_rlhf'
dataset_name: 'tfds:1.0.0'
eval_dataset_name: 'tfds:1.0.0'
eval_split: 'test'

# HF Pipeline -------------------------
hf_eval_split: 'test'

gradient_clipping_threshold: 10.0
learning_rate: 5.0e-7
dpo_label_smoothing: 0.0
dpo_beta: 0.1

enable_goodput_recording: false
monitor_goodput: false
enable_checkpointing: true
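As a reference sketch only, the loss that `dpo_beta: 0.1` and `dpo_label_smoothing: 0.0` parameterize (see the formula above) can be written in a few lines of JAX, given per-example summed log-probabilities from the policy and a frozen reference model. This is not the MaxText implementation, whose training-loop changes are in files not expanded on this page:

```python
import jax.numpy as jnp
from jax.nn import log_sigmoid


def dpo_loss(policy_chosen_lp, policy_rejected_lp, ref_chosen_lp, ref_rejected_lp,
             beta=0.1, label_smoothing=0.0):
  """Conservative DPO loss; label_smoothing=0.0 recovers the standard DPO objective."""
  # Implicit reward margin between the chosen and the rejected completion.
  logits = (policy_chosen_lp - ref_chosen_lp) - (policy_rejected_lp - ref_rejected_lp)
  losses = -((1.0 - label_smoothing) * log_sigmoid(beta * logits)
             + label_smoothing * log_sigmoid(-beta * logits))
  return jnp.mean(losses)
```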
33 changes: 20 additions & 13 deletions MaxText/input_pipeline/_grain_data_processing.py
@@ -17,6 +17,7 @@
"""Input pipeline using Grain."""

import glob
from pathlib import Path

import ml_collections
import jax
@@ -30,7 +31,7 @@

def get_datasets(data_file_pattern):
"""Load dataset from array_record files for using with grain"""
data_files = glob.glob(data_file_pattern)
data_files = glob.glob(str(Path(data_file_pattern).expanduser().resolve()))
dataset = grain.ArrayRecordDataSource(data_files)
return dataset

@@ -44,7 +45,7 @@ def preprocessing_pipeline(
grain_worker_count: int,
dataloading_host_index,
dataloading_host_count,
data_column,
data_columns,
shuffle: bool = False,
data_shuffle_seed=0,
tokenize=True,
@@ -54,34 +55,38 @@
packing=True,
shift=True,
drop_remainder=False,
use_dpo: bool = False,
):
"""Use grain to pre-process the dataset and return iterators"""
assert global_batch_size % global_mesh.size == 0, "Batch size should be divisible number of global devices."

operations = []
operations.append(_input_pipeline_utils.ParseFeatures(data_column, tokenize))
operations.append(_input_pipeline_utils.NormalizeFeatures(data_column, tokenize))
operations.append(_input_pipeline_utils.ParseFeatures(data_columns, tokenize))
if not use_dpo:
assert len(data_columns) == 1
operations.append(_input_pipeline_utils.InputsTargetsFeatures(data_columns[0]))
data_columns = ("inputs", "targets")
operations.append(_input_pipeline_utils.NormalizeFeatures(data_columns, tokenize))

if tokenize:
operations.append(
_grain_tokenizer.TokenizeAndTrim(["inputs", "targets"], max_target_length, tokenizer_path, add_bos, add_eos)
)
operations.append(_grain_tokenizer.TokenizeAndTrim(data_columns, max_target_length, tokenizer_path, add_bos, add_eos))

# Pack and Batch examples.
if packing:
if packing and not use_dpo:
length_struct = {col: max_target_length for col in data_columns}
operations.append(
grain.experimental.PackAndBatchOperation(
batch_size=global_batch_size // jax.process_count(),
length_struct={"inputs": max_target_length, "targets": max_target_length},
length_struct=length_struct,
)
)
operations.append(_input_pipeline_utils.ReformatPacking())
operations.append(_input_pipeline_utils.ReformatPacking(data_columns))
else:
operations.append(_input_pipeline_utils.PadToMaxLength(max_target_length))
operations.append(grain.Batch(batch_size=global_batch_size // jax.process_count(), drop_remainder=drop_remainder))

# Shift inputs for teacher-forced training
if shift:
if shift and not use_dpo:
operations.append(_input_pipeline_utils.ShiftData(axis=1))

index_sampler = grain.IndexSampler(
@@ -123,12 +128,13 @@ def make_grain_train_iterator(
grain_worker_count=config.grain_worker_count,
dataloading_host_index=process_indices.index(jax.process_index()),
dataloading_host_count=len(process_indices),
data_column=config.train_data_column,
data_columns=config.train_data_columns,
shuffle=config.enable_data_shuffling,
data_shuffle_seed=config.data_shuffle_seed,
tokenize=config.tokenize_train_data,
add_bos=config.add_bos,
add_eos=config.add_eos,
use_dpo=config.use_dpo,
)
return train_iter

@@ -149,11 +155,12 @@ def make_grain_eval_iterator(
grain_worker_count=config.grain_worker_count,
dataloading_host_index=process_indices.index(jax.process_index()),
dataloading_host_count=len(process_indices),
data_column=config.eval_data_column,
data_columns=config.eval_data_columns,
shuffle=False,
data_shuffle_seed=config.data_shuffle_seed,
tokenize=config.tokenize_eval_data,
add_bos=config.add_bos,
add_eos=config.add_eos,
use_dpo=config.use_dpo,
)
return eval_iter
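In DPO mode the pipeline intentionally skips packing and shifting: the 'chosen' and 'rejected' sequences must stay paired row for row, and the preference loss consumes whole sequences rather than next-token-shifted targets. Purely as an illustration (the exact feature keys emitted by `PadToMaxLength` are not shown on this page), the two modes yield batches shaped roughly like this:

```python
import numpy as np

per_host_batch, max_target_length = 4, 512  # illustrative sizes

# use_dpo=True: padded, unshifted preference pairs, one pair per row.
dpo_batch = {
    "chosen": np.zeros((per_host_batch, max_target_length), np.int32),
    "rejected": np.zeros((per_host_batch, max_target_length), np.int32),
}

# use_dpo=False: packed and shifted "inputs"/"targets" features produced by
# PackAndBatchOperation + ReformatPacking + ShiftData.
lm_batch = {
    "inputs": np.zeros((per_host_batch, max_target_length), np.int32),
    "targets": np.zeros((per_host_batch, max_target_length), np.int32),
}
```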
36 changes: 23 additions & 13 deletions MaxText/input_pipeline/_hf_data_processing.py
@@ -15,12 +15,14 @@
"""

"""Input pipeline using Huggingface datasets."""
import functools

import ml_collections
import jax
import datasets
import transformers
import grain.python as grain
import numpy as np

from input_pipeline import _input_pipeline_utils
import multihost_dataloading
@@ -31,7 +33,7 @@ def preprocessing_pipeline(
dataloading_host_count,
global_mesh,
dataset,
data_column_name,
data_column_names,
tokenize,
tokenizer_path,
hf_access_token,
@@ -46,6 +48,7 @@
num_threads=1,
drop_remainder=False,
generate_padding_example=False,
use_dpo=None,
):
"""pipeline for preprocessing HF dataset"""

@@ -67,11 +70,9 @@
dataset = dataset.map(
_input_pipeline_utils.tokenization,
batched=True,
fn_kwargs={"hf_tokenizer": tokenizer, "max_length": max_target_length - 1, "column_name": data_column_name},
fn_kwargs={"hf_tokenizer": tokenizer, "max_length": max_target_length - 1, "column_names": data_column_names},
)
dataset = dataset.select_columns(["input_ids"]).rename_column("input_ids", data_column_name)
else:
dataset = dataset.select_columns([data_column_name])
dataset = dataset.select_columns(data_column_names)

dataset = _input_pipeline_utils.HFDataSource(
dataset,
@@ -80,24 +81,31 @@
num_threads,
generate_padding_example,
max_target_length,
data_column_name,
data_column_names,
)
operations = []
operations.append(_input_pipeline_utils.HFNormalizeFeatures(data_column_name))
if not use_dpo:
assert len(data_column_names) == 1
operations.append(_input_pipeline_utils.HFNormalizeFeatures(data_column_names[0]))
data_column_names = ("inputs", "targets")
else:
lists2array = lambda x: jax.tree.map(np.asarray, x, is_leaf=lambda x: isinstance(x, (list, tuple)))
operations.append(grain.MapOperation(lists2array))

if packing:
if packing and not use_dpo:
length_struct = {col: max_target_length for col in data_column_names}
operations.append(
grain.experimental.PackAndBatchOperation(
batch_size=global_batch_size // jax.process_count(),
length_struct={"inputs": max_target_length, "targets": max_target_length},
length_struct=length_struct,
)
)
operations.append(_input_pipeline_utils.ReformatPacking())
operations.append(_input_pipeline_utils.ReformatPacking(data_column_names))
else:
operations.append(_input_pipeline_utils.PadToMaxLength(max_target_length))
operations.append(grain.Batch(batch_size=global_batch_size // jax.process_count(), drop_remainder=drop_remainder))

if shift:
if shift and not use_dpo:
operations.append(_input_pipeline_utils.ShiftData(axis=1))

# Since HuggingFace IterableDataset does not support access through index
@@ -147,7 +155,7 @@ def make_hf_train_iterator(
dataloading_host_count=len(process_indices_train),
global_mesh=global_mesh,
dataset=train_ds,
data_column_name=config.train_data_column,
data_column_names=config.train_data_columns,
tokenize=config.tokenize_train_data,
tokenizer_path=config.tokenizer_path,
hf_access_token=config.hf_access_token,
Expand All @@ -158,6 +166,7 @@ def make_hf_train_iterator(
add_bos=config.add_bos,
add_eos=config.add_eos,
generate_padding_example=True,
use_dpo=config.use_dpo,
)
return train_iter

@@ -185,7 +194,7 @@ def make_hf_eval_iterator(
dataloading_host_count=len(process_indices_eval),
global_mesh=global_mesh,
dataset=eval_ds,
data_column_name=config.eval_data_column,
data_column_names=config.eval_data_columns,
tokenize=config.tokenize_eval_data,
tokenizer_path=config.tokenizer_path,
hf_access_token=config.hf_access_token,
Expand All @@ -196,5 +205,6 @@ def make_hf_eval_iterator(
add_bos=config.add_bos,
add_eos=config.add_eos,
generate_padding_example=eval_generate_padding_example,
use_dpo=config.use_dpo,
)
return eval_iter
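The `lists2array` helper in the DPO branch above is just a tree-mapped `np.asarray`: `jax.tree.map` with an `is_leaf` that stops at Python lists/tuples turns each ragged tokenized column into a NumPy array before batching. Restated as a self-contained example (the token ids below are made up):

```python
import numpy as np
import jax

# Stop descending at lists/tuples so each tokenized column becomes one array.
lists2array = lambda x: jax.tree.map(np.asarray, x, is_leaf=lambda x: isinstance(x, (list, tuple)))

example = {"chosen": [101, 2023, 102], "rejected": [101, 2033, 102]}
print(lists2array(example))
# {'chosen': array([ 101, 2023,  102]), 'rejected': array([ 101, 2033,  102])}
```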