
Commit

don't inherit trainlm
ahmeda14960 committed Nov 12, 2024
1 parent a3a6db3 commit 93250b4
Showing 1 changed file with 32 additions and 8 deletions.
40 changes: 32 additions & 8 deletions src/levanter/main/sft.py
@@ -1,8 +1,8 @@
import logging
import os
from dataclasses import dataclass
from dataclasses import dataclass, field
from enum import Enum
from typing import List, Optional
from typing import List, Optional, Union

import jax.random as jrandom
import transformers
@@ -15,10 +15,18 @@
from levanter import callbacks
from levanter.compat.hf_checkpoints import HFCheckpointConverter, HFCompatConfig, save_hf_checkpoint_callback
from levanter.data import PermutationDataset
from levanter.data.text import ChatSFTDatasetConfig, EpochDataset, mk_chat_sft_dataset, mk_supervised_dataset
from levanter.data.text import (
    ChatSFTDatasetConfig,
    EpochDataset,
    LMSupervisedDatasetConfig,
    mk_chat_sft_dataset,
    mk_supervised_dataset,
)
from levanter.main.train_lm import TrainLmConfig
from levanter.models.lm_model import LmHeadModel, compute_next_token_loss
from levanter.trainer import Trainer
from levanter.models.llama import LlamaConfig
from levanter.models.lm_model import LmConfig, LmHeadModel, compute_next_token_loss
from levanter.optim import AdamConfig, OptimizerConfig
from levanter.trainer import Trainer, TrainerConfig


logger = logging.getLogger(__name__)
Expand All @@ -40,22 +48,38 @@ class DatasetType(str, Enum):
@dataclass
class SFTConfig(TrainLmConfig):
    # inherit most of the config from TrainLmConfig
    max_tune_length: int = 2048
    trainer: TrainerConfig = field(default_factory=TrainerConfig)
    model: LmConfig = field(default_factory=LlamaConfig)
    optimizer: OptimizerConfig = field(default_factory=AdamConfig)
    supervised_data: Optional[LMSupervisedDatasetConfig] = None

    # config related to continued pretraining
    initialize_from_hf: Union[bool, str] = False
    hf_save_path: Optional[str] = None
    hf_upload: Optional[str] = None
    hf_save_steps: int = 10000

    max_seq_len: int = 2048
    model_name_or_path: str = "meta-llama/Llama-2-7b-hf"
    tokenizer: str = "meta-llama/Llama-2-7b-hf"

    # Add dataset type and chat-specific fields
    dataset_type: DatasetType = DatasetType.HUGGINGFACE
    dataset_type: DatasetType = DatasetType.CHAT_JSONL
    chat_train_urls: Optional[List[str]] = None
    messages_field: str = "messages"
    input_role: str = "user"
    output_role: str = "assistant"

    data_seed: Optional[int] = None  # if provided, will override the data seed from the trainer

    # if provided, will initialize from this checkpoint, used for llama style data mixture
    epoch: int = 0


def train(config: SFTConfig):
    tokenizer = transformers.AutoTokenizer.from_pretrained(
        config.tokenizer,
        model_max_length=config.max_tune_length,
        model_max_length=config.max_seq_len,
        padding_side="right",
        trust_remote_code=True,
    )
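As a minimal usage sketch (not part of this commit), the now self-contained SFTConfig can be instantiated with its new explicit defaults roughly as follows; the chat URL is a hypothetical placeholder, field names are taken from the diff above, and actually running train() would additionally require dataset and accelerator setup.

# Hypothetical sketch: SFTConfig now carries its own trainer/model/optimizer
# defaults rather than inheriting them from TrainLmConfig.
from levanter.main.sft import DatasetType, SFTConfig, train

config = SFTConfig(
    model_name_or_path="meta-llama/Llama-2-7b-hf",
    tokenizer="meta-llama/Llama-2-7b-hf",
    max_seq_len=2048,
    dataset_type=DatasetType.CHAT_JSONL,
    chat_train_urls=["gs://my-bucket/chat/train-*.jsonl"],  # hypothetical path
    messages_field="messages",
    input_role="user",
    output_role="assistant",
)

train(config)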
