Revise SFT File #793

Merged · 12 commits · Nov 14, 2024
config/llama_sft_hf_ckpt.yaml (13 additions, 0 deletions)
@@ -0,0 +1,13 @@
+# Model configuration
+model:
+  type: llama
+  seq_len: 2048
+  hidden_dim: 4096
+  intermediate_dim: 11008
+  num_layers: 32
+  num_heads: 32
+  num_kv_heads: 32
+  use_flash_attention: true
+  flash_attention_block_size: 512
+  use_bias: false
+  use_layer_norm_weight: false
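For readers skimming this PR: the model block above maps directly onto fields of levanter's LlamaConfig (the `type: llama` key is the discriminator that selects the model class, not a field itself). A rough Python equivalent, offered only as a sketch and not something this PR adds:

# Sketch only: the same model configuration built in Python instead of YAML.
from levanter.models.llama import LlamaConfig

model_config = LlamaConfig(
    seq_len=2048,
    hidden_dim=4096,
    intermediate_dim=11008,
    num_layers=32,
    num_heads=32,
    num_kv_heads=32,
    use_flash_attention=True,
    flash_attention_block_size=512,
    use_bias=False,
    use_layer_norm_weight=False,
)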
examples/sft/sft.py → src/levanter/main/sft.py (54 additions, 18 deletions)
@@ -1,8 +1,8 @@
 import logging
 import os
-from dataclasses import dataclass
+from dataclasses import dataclass, field
 from enum import Enum
-from typing import List, Optional
+from typing import List, Optional, Union
 
 import jax.random as jrandom
 import transformers
@@ -15,10 +15,17 @@
 from levanter import callbacks
 from levanter.compat.hf_checkpoints import HFCheckpointConverter, HFCompatConfig, save_hf_checkpoint_callback
 from levanter.data import PermutationDataset
-from levanter.data.text import ChatUrlDataSourceConfig, EpochDataset, mk_chat_sft_dataset, mk_supervised_dataset
-from levanter.main.train_lm import TrainLmConfig
-from levanter.models.lm_model import LmHeadModel, compute_next_token_loss
-from levanter.trainer import Trainer
+from levanter.data.text import (
+    ChatUrlDataSourceConfig,
+    EpochDataset,
+    SupervisedSourceConfig,
+    mk_chat_sft_dataset,
+    mk_supervised_dataset,
+)
+from levanter.models.llama import LlamaConfig
+from levanter.models.lm_model import LmConfig, LmHeadModel, compute_next_token_loss
+from levanter.optim import AdamConfig, OptimizerConfig
+from levanter.trainer import Trainer, TrainerConfig
 
 
 logger = logging.getLogger(__name__)
Expand All @@ -38,24 +45,40 @@ class DatasetType(str, Enum):


@dataclass
class SFTConfig(TrainLmConfig):
class SFTConfig:
# inherit most of the config from TrainLmConfig
max_tune_length: int = 2048
trainer: TrainerConfig = field(default_factory=TrainerConfig)
model: LmConfig = field(default_factory=LlamaConfig)
optimizer: OptimizerConfig = field(default_factory=AdamConfig)
supervised_data: Optional[SupervisedSourceConfig | dict[str, SupervisedSourceConfig]] = None

# config related to continued pretraining
initialize_from_hf: Union[bool, str] = False
hf_save_path: Optional[str] = None
hf_upload: Optional[str] = None
hf_save_steps: int = 0

max_seq_len: int = 2048
model_name_or_path: str = "meta-llama/Llama-2-7b-hf"
tokenizer: str = "meta-llama/Llama-2-7b-hf"

# Add dataset type and chat-specific fields
dataset_type: DatasetType = DatasetType.HUGGINGFACE
dataset_type: DatasetType = DatasetType.CHAT_JSONL
chat_train_urls: Optional[List[str]] = None
messages_field: str = "messages"
input_role: str = "user"
output_role: str = "assistant"

data_seed: Optional[int] = None # if provided, will override the data seed from the trainer

# if provided, will initialize from this checkpoint, used for llama style data mixture
epoch: int = 0


def train(config: SFTConfig):
tokenizer = transformers.AutoTokenizer.from_pretrained(
config.tokenizer,
model_max_length=config.max_tune_length,
model_max_length=config.max_seq_len,
padding_side="right",
trust_remote_code=True,
)
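Since SFTConfig is now a standalone dataclass rather than a TrainLmConfig subclass, it can be populated directly. A minimal sketch of that, not part of this PR: the field names come from the dataclass above, the URL and values are placeholders, and a real run would also need supervised_data set so the dataset code can find a cache_dir.

# Hypothetical usage sketch only.
from levanter.main.sft import DatasetType, SFTConfig

config = SFTConfig(
    model_name_or_path="meta-llama/Llama-2-7b-hf",
    tokenizer="meta-llama/Llama-2-7b-hf",
    max_seq_len=2048,
    dataset_type=DatasetType.CHAT_JSONL,
    chat_train_urls=["gs://example-bucket/sft/train-*.jsonl.gz"],  # placeholder URL
    initialize_from_hf=True,  # start from the HF checkpoint named above
)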
@@ -79,7 +102,11 @@ def train(config: SFTConfig):
     elif config.trainer.initialize_from is None:
         raise ValueError("Must specify either --initialize_from_hf or --initialize_from")
     else:
-        converter = None
+        if config.hf_save_steps:
+            converter = HFCheckpointConverter.from_hf(config.model_name_or_path, trust_remote_code=True)
+            converter = converter.replaced(tokenizer=tokenizer)
+        else:
+            converter = None
     model_config = config.model
 
     levanter.initialize(config)
@@ -100,8 +127,16 @@ def train(config: SFTConfig):
     if config.dataset_type == DatasetType.CHAT_JSONL:
        assert config.chat_train_urls is not None
         assert config.supervised_data is not None
+
+        # Get the cache_dir safely
+        cache_dir = (
+            config.supervised_data.cache_dir
+            if not isinstance(config.supervised_data, dict)
+            else next(iter(config.supervised_data.values())).cache_dir
+        )
+
         chat_config = ChatUrlDataSourceConfig(
-            cache_dir=config.supervised_data.cache_dir,
+            cache_dir=cache_dir,
             train_urls=config.chat_train_urls,  # No validation in this config
             messages_field=config.messages_field,
             input_role=config.input_role,
@@ -110,7 +145,13 @@ def train(config: SFTConfig):
         train_dataset = mk_chat_sft_dataset(chat_config, tokenizer, model_config.Pos)
     else:
         assert config.supervised_data is not None
-        train_dataset = mk_supervised_dataset(config.supervised_data, "train", tokenizer, model_config.Pos)
+        if isinstance(config.supervised_data, dict):
+            # TODO: figure out what actually makes sense here
+            # for marin we will just use the url code path
+            config_to_use = next(iter(config.supervised_data.values()))
+        else:
+            config_to_use = config.supervised_data
+        train_dataset = mk_supervised_dataset(config_to_use, "train", tokenizer, model_config.Pos)
     logger.info("Supervised dataset created")
     train_dataset = PermutationDataset(train_dataset, data_key)
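The dict-or-single SupervisedSourceConfig handling now appears twice in this function (once for cache_dir, once for the config passed to mk_supervised_dataset). If a third use ever shows up, a small helper along these lines would factor it out; this is a sketch only, and the helper name is made up.

# Sketch: collapse SupervisedSourceConfig | dict[str, SupervisedSourceConfig] to one config,
# taking the first entry when a dict of named sources is given, mirroring the inline logic above.
from typing import Union

from levanter.data.text import SupervisedSourceConfig


def first_supervised_config(
    supervised_data: Union[SupervisedSourceConfig, dict[str, SupervisedSourceConfig]],
) -> SupervisedSourceConfig:
    if isinstance(supervised_data, dict):
        # With multiple named sources we currently just take the first one (see TODO above).
        return next(iter(supervised_data.values()))
    return supervised_data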

@@ -161,11 +202,6 @@ def train(config: SFTConfig):
 
     loader = trainer.data_loader(train_dataset, trainer.TrainBatch)
 
-    if int(state.step) != 0:
-        logger.info(f"Resuming training from step {state.step}")
-        for i in range(state.step):
-            next(loader)
-
     if config.hf_save_path is not None:
         # bit gross to reach this far into the config, but it's fine
         if config.trainer.checkpointer.append_run_id_to_base_path: