diff --git a/config/gpt2_small_fast_supervised.yaml b/config/gpt2_small_fast_eval.yaml similarity index 65% rename from config/gpt2_small_fast_supervised.yaml rename to config/gpt2_small_fast_eval.yaml index 93675366d..14638db1b 100644 --- a/config/gpt2_small_fast_supervised.yaml +++ b/config/gpt2_small_fast_eval.yaml @@ -13,12 +13,17 @@ data: tokenizer: gpt2 cache_dir: "gs://levanter-data/tokenized/data_mix" supervised_data: - validation_urls: - - "gs://marin-us-central2/benchmarks/mmlu/mmlu-*-dev-evaluation.jsonl.gz" - - "gs://marin-us-central2/benchmarks/mmlu/mmlu-*-validation-evaluation.jsonl.gz" - cache_dir: "gs://marin-us-central2/benchmarks/tokenized-gpt2/mmlu/" - input_field: "input" - output_field: "output" + mmlu: + validation_urls: + - "gs://marin-us-central2/evaluation/mmlu-eval-subject-2eb39e/cais/*-validation-evaluation.jsonl.gz" + cache_dir: "gs://levanter-data/tokenized-gpt2/mmlu/" + tags: [ "e"] + arc_easy: + validation_urls: + - "gs://marin-us-central2/evaluation/arc-easy-b39e70/allenai/ai2_arc-ARC-Easy-validation-evaluation.jsonl.gz" + cache_dir: "gs://levanter-data/tokenized-gpt2/arc_easy/" + tags: [ "arc", "e"] + model: type: gpt2 hidden_dim: 768 diff --git a/examples/sft/sft.py b/examples/sft/sft.py index 152781b0b..173c79212 100644 --- a/examples/sft/sft.py +++ b/examples/sft/sft.py @@ -15,7 +15,7 @@ from levanter import callbacks from levanter.compat.hf_checkpoints import HFCheckpointConverter, HFCompatConfig, save_hf_checkpoint_callback from levanter.data import PermutationDataset -from levanter.data.text import ChatSFTDatasetConfig, EpochDataset, mk_chat_sft_dataset, mk_supervised_dataset +from levanter.data.text import ChatUrlDataSourceConfig, EpochDataset, mk_chat_sft_dataset, mk_supervised_dataset from levanter.main.train_lm import TrainLmConfig from levanter.models.lm_model import LmHeadModel, compute_next_token_loss from levanter.trainer import Trainer @@ -100,7 +100,7 @@ def train(config: SFTConfig): if config.dataset_type == DatasetType.CHAT_JSONL: assert config.chat_train_urls is not None assert config.supervised_data is not None - chat_config = ChatSFTDatasetConfig( + chat_config = ChatUrlDataSourceConfig( cache_dir=config.supervised_data.cache_dir, train_urls=config.chat_train_urls, # No validation in this config messages_field=config.messages_field, @@ -110,7 +110,7 @@ def train(config: SFTConfig): train_dataset = mk_chat_sft_dataset(chat_config, tokenizer, model_config.Pos) else: assert config.supervised_data is not None - train_dataset = mk_supervised_dataset(config.supervised_data, tokenizer, model_config.Pos) + train_dataset = mk_supervised_dataset(config.supervised_data, "train", tokenizer, model_config.Pos) logger.info("Supervised dataset created") train_dataset = PermutationDataset(train_dataset, data_key) diff --git a/pyproject.toml b/pyproject.toml index b10358d07..abca1405d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,7 +28,7 @@ dependencies = [ "transformers>=4.41.2", "optax>=0.1.9", "wandb>=0.17.8", - "draccus>=0.8.0", + "draccus>=0.9.3", "pyarrow>=11.0.0", "zstandard>=0.20.0", "datasets>=3.1.0,<4.0", diff --git a/src/levanter/compat/hf_checkpoints.py b/src/levanter/compat/hf_checkpoints.py index ce267041c..7a116acae 100644 --- a/src/levanter/compat/hf_checkpoints.py +++ b/src/levanter/compat/hf_checkpoints.py @@ -39,6 +39,7 @@ from levanter.trainer import StepInfo from levanter.utils import jax_utils from levanter.utils.cloud_utils import temp_dir_before_upload +from levanter.utils.hf_utils import HfTokenizer from levanter.utils.jax_utils import best_effort_sharding, local_cpu_mesh, use_cpu_device from levanter.utils.py_utils import dataclass_with_default_init, logical_cpu_memory_size @@ -872,7 +873,7 @@ def cb(step: StepInfo): def arbitrary_load_from_hf( model_name_or_path, from_pretrained_lambda, revision=None, local_cache_dir=None, trust_remote_code=True -) -> Union[PreTrainedTokenizerBase | ProcessorMixin]: +) -> Union[HfTokenizer | ProcessorMixin]: is_url_like = urlparse(model_name_or_path).scheme != "" if is_url_like: if revision is not None: @@ -889,9 +890,7 @@ def arbitrary_load_from_hf( return from_pretrained_lambda(model_name_or_path, revision=revision, trust_remote_code=trust_remote_code) -def load_tokenizer( - model_name_or_path, revision=None, local_cache_dir=None, trust_remote_code=True -) -> PreTrainedTokenizerBase: +def load_tokenizer(model_name_or_path, revision=None, local_cache_dir=None, trust_remote_code=True) -> HfTokenizer: """Like AutoTokenizer.from_pretrained, but works with gs:// paths or anything on fsspec""" return arbitrary_load_from_hf( model_name_or_path, diff --git a/src/levanter/data/_preprocessor.py b/src/levanter/data/_preprocessor.py index 77b91617f..dd6578667 100644 --- a/src/levanter/data/_preprocessor.py +++ b/src/levanter/data/_preprocessor.py @@ -197,7 +197,7 @@ def __call__(self, batch): match transform: case _MapTransform(fn=fn): - batch = map(fn, batch) + batch = [fn(x) for x in batch] case _BatchMapTransform(fn=fn): batch = fn(batch) is_soa_form = isinstance(batch, dict) or isinstance(batch, pa.RecordBatch) diff --git a/src/levanter/data/sharded_datasource.py b/src/levanter/data/sharded_datasource.py index cca3156b8..9dca9b618 100644 --- a/src/levanter/data/sharded_datasource.py +++ b/src/levanter/data/sharded_datasource.py @@ -184,31 +184,6 @@ def gcs_glob(pattern: str) -> list[str]: return matching_urls -def datasource_from_chat_jsonl( - urls: Sequence[str], messages_field: str = "messages", input_role: str = "user", output_role: str = "assistant" -) -> "ShardedDataSource[dict]": - """Creates a ShardedDataSource from JSONL files containing chat messages. - - Args: - urls: Sequence of URLs or glob patterns pointing to JSONL files - messages_field: Field name containing the messages in each JSON object - input_role: Role identifier for input messages - output_role: Role identifier for output messages - - Returns: - ShardedDataSource configured for chat data - """ - # Expand any glob patterns in the URLs - expanded_urls = [] - for url in urls: - if any(c in url for c in "*?[]"): - expanded_urls.extend(gcs_glob(url)) - else: - expanded_urls.append(url) - - return ChatJsonlDataSource(expanded_urls, messages_field, input_role, output_role) - - def datasource_from_hf(id: str, *, split, **kwargs) -> ShardedDataSource[dict]: """ Create a ShardedDataset from a HuggingFace dataset. Arguments are passed to load_dataset. @@ -288,14 +263,49 @@ class TextUrlDataSource(ShardedDataSource[str]): def __init__(self, urls, text_key="text"): self.urls = urls - self._shard_name_to_url_mapping = _mk_shard_name_mapping(urls) self.text_key = text_key + self.base_ds = UrlDataSource(urls, columns=[text_key]) @property def shard_names(self) -> Sequence[str]: - return list(self._shard_name_to_url_mapping.keys()) + return self.base_ds.shard_names def open_shard_at_row(self, shard_name: str, row: int) -> Iterator[str]: + url = self.base_ds._shard_name_to_url_mapping[shard_name] + i = 0 + compression = "infer" + if url.endswith(".zstd"): # hacky way to detect zstd + compression = "zstd" + + format = _sniff_format_for_dataset(url) + + # special case for txt files + if format == ".txt": + with fsspec.open(url, "r", compression=compression) as f: + for line in f: + if i >= row: + yield line + i += 1 + else: + for doc in self.base_ds.open_shard_at_row(shard_name, row): + yield doc[self.text_key] + + +class UrlDataSource(ShardedDataSource[dict]): + """ + Dataset for various dict-like formats. + """ + + def __init__(self, urls, columns=None): + self.urls = urls + self._shard_name_to_url_mapping = _mk_shard_name_mapping(urls) + self.columns = columns + + @property + def shard_names(self) -> Sequence[str]: + return list(self._shard_name_to_url_mapping.keys()) + + def open_shard_at_row(self, shard_name: str, row: int) -> Iterator[dict]: url = self._shard_name_to_url_mapping[shard_name] i = 0 compression = "infer" @@ -310,19 +320,18 @@ def open_shard_at_row(self, shard_name: str, row: int) -> Iterator[str]: # which is not nothing, but not ideal. for line in f: if i >= row: - yield json.loads(line)[self.text_key] - i += 1 - case ".txt": - with fsspec.open(url, "r", compression=compression) as f: - for line in f: - if i >= row: - yield line + obj = json.loads(line) + if self.columns: + yield {col: obj[col] for col in self.columns} i += 1 case ".json": with fsspec.open(url, "r", compression=compression) as f: data = json.load(f) for doc in data[row:]: - yield doc[self.text_key] + if self.columns: + yield {col: doc[col] for col in self.columns} + else: + yield doc case ".parquet": with fsspec.open(url, "rb", compression=compression) as f: parquet_file = pq.ParquetFile(f) @@ -347,11 +356,11 @@ def open_shard_at_row(self, shard_name: str, row: int) -> Iterator[str]: # Read from the starting row group onwards for rg_idx in range(row_group_index, parquet_file.num_row_groups): - table = parquet_file.read_row_group(rg_idx, columns=[self.text_key]) + table = parquet_file.read_row_group(rg_idx, columns=self.columns) if rg_idx == row_group_index: table = table.slice(start_row_in_group) for record in table.to_pylist(): - yield record[self.text_key] + yield record case _: raise ValueError(f"Unknown format {format}") @@ -531,32 +540,6 @@ def open_shard_at_row(self, shard_name: str, row: int) -> Iterator[dict]: return iter(data[row:]) -class ChatJsonlDataSource(JsonlDataSource): - """DataSource that reads JSONL files containing OpenAI chat format messages.""" - - def __init__(self, urls: Sequence[str], messages_field: str, input_role: str, output_role: str): - super().__init__(urls) - self.messages_field = messages_field - self.input_role = input_role - self.output_role = output_role - - def open_shard_at_row(self, shard_name: str, row: int) -> Iterator[dict]: - url = self._shard_name_to_url_mapping[shard_name] - i = 0 - with fsspec.open(url, "r", compression="infer") as f: - for line in f: - if i >= row: - data = json.loads(line) - messages = data[self.messages_field] - - # Extract input/output from messages - input_msg = next(m["content"] for m in messages if m["role"] == self.input_role) - output_msg = next(m["content"] for m in messages if m["role"] == self.output_role) - - yield {"input": input_msg, "output": output_msg} - i += 1 - - class ParquetDataSource(ShardedDataSource[dict]): def __init__(self, urls): self.urls = urls @@ -650,7 +633,8 @@ def shard_names(self) -> Sequence[str]: return self.source.shard_names def open_shard_at_row(self, shard_name: str, row: int) -> Iterator[T]: - return map(self.fn, self.source.open_shard_at_row(shard_name, row)) + for doc in self.source.open_shard_at_row(shard_name, row): + yield self.fn(doc) class _BatchMappedShardedDataSource(ShardedDataSource[T], _TransformedDataset): diff --git a/src/levanter/data/text.py b/src/levanter/data/text.py index 4cc000e59..f7764a8b2 100644 --- a/src/levanter/data/text.py +++ b/src/levanter/data/text.py @@ -3,21 +3,23 @@ import copy import dataclasses import functools +import json import logging import os from dataclasses import dataclass from functools import cached_property from itertools import chain -from typing import Any, Dict, Iterator, List, Mapping, Optional, Sequence, Tuple, TypeVar, Union +from typing import Any, Dict, Iterator, List, Mapping, Optional, Protocol, Sequence, Tuple, TypeAlias, TypeVar, Union import datasets import equinox as eqx +import fsspec import jax import numpy as np import regex import tensorstore as ts from draccus import field -from jax._src.random import PRNGKey +from jax.random import PRNGKey from jaxtyping import PRNGKeyArray from tokenizers import normalizers @@ -35,9 +37,8 @@ from levanter.store.cache import CacheOptions, TreeCache from levanter.store.jagged_array import JaggedArrayStore from levanter.store.tree_store import TreeStore -from levanter.utils import fsspec_utils from levanter.utils.fsspec_utils import expand_glob -from levanter.utils.hf_utils import num_cpus_used_by_tokenizer +from levanter.utils.hf_utils import HfTokenizer, num_cpus_used_by_tokenizer silence_transformer_nag() # noqa @@ -46,7 +47,14 @@ from levanter.compat.hf_checkpoints import load_tokenizer # noqa from levanter.data._preprocessor import BatchProcessor, U, dict_from_record_batch # noqa from levanter.data.metrics_monitor import LoggerMetricsMonitor, LoggingMetricsMonitor, MetricsMonitor # noqa -from levanter.data.sharded_datasource import ShardedDataSource, TextUrlDataSource, WrappedHFDataSource # noqa +from levanter.data.sharded_datasource import ( # noqa + JsonlDataSource, + ShardedDataSource, + TextUrlDataSource, + UrlDataSource, + WrappedHFDataSource, + gcs_glob, +) from levanter.shapes import NamedShapeSpec, ShapeSpec # noqa from levanter.store.cache import build_or_load_cache # noqa from levanter.utils.jax_utils import key_iterator, local_cpu_mesh, use_cpu_device # noqa @@ -328,9 +336,18 @@ def __call__(self, batch: Sequence[str]) -> list[dict]: needs_merge = [] if self.padding is not False: - encoding = self.tokenizer(batch, return_attention_mask=self.return_attention_mask, verbose=False, padding=self.padding, max_length=self.max_length, truncation=True) # type: ignore + encoding = self.tokenizer( + batch, + return_attention_mask=self.return_attention_mask, + verbose=False, + padding=self.padding, + max_length=self.max_length, + truncation=True, + ) # type: ignore else: - encoding = self.tokenizer(batch, return_attention_mask=self.return_attention_mask, verbose=False) # type: ignore + encoding = self.tokenizer( + batch, return_attention_mask=self.return_attention_mask, verbose=False + ) # type: ignore if needs_merge: new_encoding = self._merge_split_encodings(batch, encoding, needs_merge) @@ -611,7 +628,7 @@ class LMTaskConfig(abc.ABC): If you want to shuffle in eras, set this to the era length""" @cached_property - def the_tokenizer(self) -> PreTrainedTokenizerBase: + def the_tokenizer(self) -> HfTokenizer: if self.tokenizer == "passthrough": return PassthroughTokenizer(self.vocab_size) else: @@ -648,6 +665,10 @@ def tagged_eval_sets( return [(eval_sets[name], tags[name]) for name in eval_sets] +CANONICAL_INPUT_FIELD = "prompt" +CANONICAL_OUTPUT_FIELD = "response" + + @dataclass class LMSupervisedDatasetConfig: """Config for supervised fine-tuning datasets""" @@ -662,15 +683,68 @@ class LMSupervisedDatasetConfig: validation_urls: List[str] = field(default_factory=list) # paths to jsonl/json files # Field names in the data - input_field: str = "prompt" # name of the input field - output_field: str = "response" # name of output field + input_field: str = CANONICAL_INPUT_FIELD # name of the input field + output_field: str = CANONICAL_OUTPUT_FIELD # name of output field # Optional metadata tags: Optional[List[str]] = None - name: Optional[str] = None -def preprocess_supervised_example( +class SupervisedSourceConfigBase(Protocol): + def get_shard_source(self, split: str) -> Optional[ShardedDataSource[dict]]: + raise NotImplementedError + + input_field: str + output_field: str + tags: Optional[List[str]] + cache_dir: str + + +@dataclass(frozen=True) +class SupervisedHfSourceConfig(SupervisedSourceConfigBase): + cache_dir: str + id: str + name: str | None = None + + streaming: bool = True + + input_field: str = CANONICAL_INPUT_FIELD + output_field: str = CANONICAL_OUTPUT_FIELD + tags: Optional[List[str]] = None + + def get_shard_source(self, split: str) -> Optional[ShardedDataSource[dict]]: + return WrappedHFDataSource(self.id, split=split, name=self.name, streaming=self.streaming).map( + lambda x: {CANONICAL_INPUT_FIELD: x[self.input_field], CANONICAL_OUTPUT_FIELD: x[self.output_field]} + ) + + +@dataclass(frozen=True) +class SupervisedUrlSourceConfig(SupervisedSourceConfigBase): + cache_dir: str + train_urls: list[str] = dataclasses.field(default_factory=list) + validation_urls: list[str] = dataclasses.field(default_factory=list) + + input_field: str = CANONICAL_INPUT_FIELD + output_field: str = CANONICAL_OUTPUT_FIELD + tags: Optional[List[str]] = None + + def get_shard_source(self, split: str) -> Optional[ShardedDataSource[dict]]: + urls = self.train_urls if split == "train" else self.validation_urls + if not urls: + return None + + urls = [globbed for url in urls for globbed in expand_glob(url)] + + source = UrlDataSource(urls, columns=[self.input_field, self.output_field]) + return source.map( + lambda x: {CANONICAL_INPUT_FIELD: x[self.input_field], CANONICAL_OUTPUT_FIELD: x[self.output_field]} + ) + + +SupervisedSourceConfig: TypeAlias = Union[SupervisedHfSourceConfig, SupervisedUrlSourceConfig] + + +def _preprocess_supervised_example( batch, tokenizer: PreTrainedTokenizerBase, input_field: str, output_field: str ) -> dict: sources = [example[input_field] for example in batch] @@ -722,28 +796,69 @@ def _prepare_supervised_example(ex: dict, tokenizer: PreTrainedTokenizerBase, Po return lm_ex -def mk_supervised_dataset(config: LMSupervisedDatasetConfig, tokenizer: PreTrainedTokenizerBase, Pos: hax.Axis): - import levanter.data +def mk_supervised_datasets( + sources: Mapping[str, SupervisedSourceConfigBase] | SupervisedSourceConfigBase, + split: str, + tokenizer: PreTrainedTokenizerBase, + Pos: hax.Axis, +) -> dict[str, tuple[AsyncDataset[LmExample], Sequence[str]]]: + """ + Create supervised datasets from a set of sources. + + Returns: + A dictionary of dataset names to tuples of the dataset and the tags associated with the dataset. + """ + out: dict[str, tuple[AsyncDataset[LmExample], Sequence[str]]] = {} + + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + + if isinstance(sources, Mapping): + for name, config in sources.items(): + source = config.get_shard_source(split) + if source is None: + continue + + ds = _cache_supervised_set( + source, config.cache_dir, tokenizer, Pos, config.input_field, config.output_field + ) + + if config.tags is None: + tags = [name] + else: + tags = config.tags + [name] - # Choose data source based on config - if config.hf_dataset_name is not None: - # Using HF dataset - dataset = levanter.data.datasource_from_hf(config.hf_dataset_name, split=config.hf_dataset_split) + out[name] = (ds, tags) else: - # Using local files - validation_urls = [url for url_pat in config.validation_urls for url in fsspec_utils.expand_glob(url_pat)] - if not validation_urls: - raise ValueError("Must specify either hf_dataset_name or validation_urls") - dataset = levanter.data.datasource_from_jsonl(validation_urls) + source = sources.get_shard_source(split) # type: ignore + if source is not None: + ds = _cache_supervised_set( + source, sources.cache_dir, tokenizer, Pos, sources.input_field, sources.output_field + ) + tags = sources.tags or [] + if isinstance(sources, SupervisedHfSourceConfig): + name = sources.id + if sources.name is not None: + name = f"{name}/{sources.name}" + + tags = [name] + tags + else: + name = "default" + out[name] = (ds, tags) + + return out + + +def mk_supervised_dataset( + config: SupervisedSourceConfigBase, split: str, tokenizer: HfTokenizer, Pos: hax.Axis +) -> AsyncDataset[LmExample]: - input_field = config.input_field - output_field = config.output_field + source = config.get_shard_source(split) output_exemplar = {"input_ids": np.zeros((0,), dtype=np.int32), "sources_len": np.zeros((0,), dtype=np.int32)} - # Use the same preprocessing as before - dataset = dataset.map_batches( - lambda ex: preprocess_supervised_example(ex, tokenizer, input_field, output_field), + dataset = source.map_batches( # type: ignore + lambda ex: _preprocess_supervised_example(ex, tokenizer, config.input_field, config.output_field), batch_size=128, num_cpus=num_cpus_used_by_tokenizer(tokenizer), output_exemplar=output_exemplar, @@ -757,19 +872,36 @@ def mk_supervised_dataset(config: LMSupervisedDatasetConfig, tokenizer: PreTrain return cached_dataset.map(lambda ex: _prepare_supervised_example(ex, tokenizer, Pos)) -@dataclass -class ChatSFTDatasetConfig(LMSupervisedDatasetConfig): +def _cache_supervised_set(source, cache_dir, tokenizer, Pos, input_field, output_field): + """ + Cache a supervised dataset into input_ids and sources_len. + """ + output_exemplar = {"input_ids": np.zeros((0,), dtype=np.int32), "sources_len": np.zeros((0,), dtype=np.int32)} + dataset = source.map_batches( + lambda ex: _preprocess_supervised_example(ex, tokenizer, input_field, output_field), + batch_size=128, + num_cpus=num_cpus_used_by_tokenizer(tokenizer), + output_exemplar=output_exemplar, + ) + cached_dataset: AsyncDataset[dict] = dataset.build_or_load_cache(cache_dir, await_finished=True) + ds = cached_dataset.map(lambda ex: _prepare_supervised_example(ex, tokenizer, Pos)) + return ds + + +@dataclass(frozen=True) +class ChatUrlDataSourceConfig: """Config for loading JSONL files in OpenAI chat format for supervised fine-tuning.""" + cache_dir: str + train_urls: List[str] = field(default_factory=list) + validation_urls: List[str] = field(default_factory=list) + # Chat format specific fields messages_field: str = "messages" input_role: str = "user" output_role: str = "assistant" - train_urls: List[str] = field(default_factory=list) # Add this line def get_shard_source(self, split: str) -> Optional[ShardedDataSource[dict]]: - import levanter.data - """Gets ShardedDataSource for either training or validation data.""" urls = self.validation_urls if split == "validation" else self.train_urls @@ -777,7 +909,7 @@ def get_shard_source(self, split: str) -> Optional[ShardedDataSource[dict]]: return None # Use the datasource_from_chat_jsonl function from sharded_datasource - return levanter.data.sharded_datasource.datasource_from_chat_jsonl( + return datasource_from_chat_jsonl( urls, messages_field=self.messages_field, input_role=self.input_role, output_role=self.output_role ) @@ -808,7 +940,7 @@ def preprocess_chat_example(batch, tokenizer: PreTrainedTokenizerBase) -> dict: def mk_chat_sft_dataset( - config: ChatSFTDatasetConfig, tokenizer: PreTrainedTokenizerBase, Pos: hax.Axis + config: ChatUrlDataSourceConfig, tokenizer: PreTrainedTokenizerBase, Pos: hax.Axis ) -> AsyncDataset[LmExample]: """Creates a dataset from JSONL files containing chat format data for SFT.""" source = config.get_shard_source("train") @@ -1117,3 +1249,55 @@ def build_caches( @property def sources(self) -> Mapping[str, LMDatasetSourceConfig]: return self.configs + + +def datasource_from_chat_jsonl( + urls: Sequence[str], messages_field: str = "messages", input_role: str = "user", output_role: str = "assistant" +) -> "ShardedDataSource[dict]": + """Creates a ShardedDataSource from JSONL files containing chat messages. + + Args: + urls: Sequence of URLs or glob patterns pointing to JSONL files + messages_field: Field name containing the messages in each JSON object + input_role: Role identifier for input messages + output_role: Role identifier for output messages + + Returns: + ShardedDataSource configured for chat data + """ + # Expand any glob patterns in the URLs + expanded_urls = [] + for url in urls: + if any(c in url for c in "*?[]"): + expanded_urls.extend(gcs_glob(url)) + else: + expanded_urls.append(url) + + return ChatJsonlDataSource(expanded_urls, messages_field, input_role, output_role) + + +# TODO: switch to actual multi-turn +class ChatJsonlDataSource(JsonlDataSource): + """DataSource that reads JSONL files containing OpenAI chat format messages.""" + + def __init__(self, urls: Sequence[str], messages_field: str, input_role: str, output_role: str): + super().__init__(urls) + self.messages_field = messages_field + self.input_role = input_role + self.output_role = output_role + + def open_shard_at_row(self, shard_name: str, row: int) -> Iterator[dict]: + url = self._shard_name_to_url_mapping[shard_name] + i = 0 + with fsspec.open(url, "r", compression="infer") as f: + for line in f: + if i >= row: + data = json.loads(line) + messages = data[self.messages_field] + + # Extract input/output from messages + input_msg = next(m["content"] for m in messages if m["role"] == self.input_role) + output_msg = next(m["content"] for m in messages if m["role"] == self.output_role) + + yield {"input": input_msg, "output": output_msg} + i += 1 diff --git a/src/levanter/eval.py b/src/levanter/eval.py index 16342be4d..9fe9ab0d7 100644 --- a/src/levanter/eval.py +++ b/src/levanter/eval.py @@ -199,7 +199,6 @@ def eval_callback(step: StepInfo): log_dict = { # log micro average as just "loss" _join_prefix(prefix, "loss"): result.micro_avg_loss, - _join_prefix(prefix, "macro_loss"): result.macro_avg_loss, _join_prefix(prefix, "loading_time"): result.total_eval_loading_time, _join_prefix(prefix, "total_time"): time_fn(), } @@ -207,6 +206,8 @@ def eval_callback(step: StepInfo): logger.info(f"{prefix} loss: {result.micro_avg_loss:.3f}") has_tags = len(evaluator.dataset.tag_to_index) > 1 # 1 tag means there's no difference between micro and macro if has_tags: + log_dict[_join_prefix(prefix, "macro_loss")] = result.macro_avg_loss + for tag, loss in result.tag_macro_losses.items(): # don't log leaf tag macro losses because it doesn't mean anything different than micro loss if tag in evaluator.dataset.tag_to_index: diff --git a/src/levanter/main/train_lm.py b/src/levanter/main/train_lm.py index f2ad3e7ce..99165c017 100644 --- a/src/levanter/main/train_lm.py +++ b/src/levanter/main/train_lm.py @@ -16,7 +16,13 @@ from levanter import callbacks from levanter.checkpoint import EpochCheckpointer, load_checkpoint from levanter.compat.hf_checkpoints import HFCompatConfig, save_hf_checkpoint_callback -from levanter.data.text import CausalLmDataset, LMDatasetConfig, LMMixtureDatasetConfig, LMSupervisedDatasetConfig +from levanter.data.text import ( + CausalLmDataset, + LMDatasetConfig, + LMMixtureDatasetConfig, + SupervisedSourceConfig, + mk_supervised_datasets, +) from levanter.models.gpt2 import Gpt2Config from levanter.models.lm_model import LmConfig, compute_next_token_loss from levanter.optim import AdamConfig, OptimizerConfig @@ -30,7 +36,7 @@ @dataclass class TrainLmConfig: data: Union[LMDatasetConfig, LMMixtureDatasetConfig] = field(default_factory=LMDatasetConfig) - supervised_data: Optional[LMSupervisedDatasetConfig] = None + supervised_data: Optional[SupervisedSourceConfig | dict[str, SupervisedSourceConfig]] = None trainer: TrainerConfig = field(default_factory=TrainerConfig) model: LmConfig = field(default_factory=Gpt2Config) optimizer: OptimizerConfig = field(default_factory=AdamConfig) @@ -208,12 +214,14 @@ def main(config: TrainLmConfig): trainer.add_hook(cb, every=config.trainer.steps_per_eval) if config.supervised_data is not None: - logger.info("Using supervised data") - supervised_eval = [(levanter.data.text.mk_supervised_dataset(config.supervised_data, tokenizer, Pos), "")] - # TODO Add tags + logger.info("Using supervised data for evals") + supervised_eval = mk_supervised_datasets(config.supervised_data, "validation", tokenizer, Pos) + + evals = list(supervised_eval.values()) + cb = levanter.eval.cb_tagged_lm_evaluate( EvalBatch, - supervised_eval, + evals, tokenizer, trainer.device_mesh, compute_axis_mapping, diff --git a/src/levanter/utils/hf_utils.py b/src/levanter/utils/hf_utils.py index 922de4830..41e4488d4 100644 --- a/src/levanter/utils/hf_utils.py +++ b/src/levanter/utils/hf_utils.py @@ -13,6 +13,11 @@ _HF_TOKENIZER_OFF_VALUES = {"off", "false", "f", "no", "n", "0"} HfTokenizer: TypeAlias = PreTrainedTokenizerFast | PreTrainedTokenizer +""" +Type alias for a Hugging Face tokenizer. This is a union of the two tokenizer types. +While there is PreTrainedTokenizerBase, it doesn't have all methods that are implemented in both +PreTrainedTokenizer and PreTrainedTokenizerFast. grumble grumble. +""" def num_cpus_used_by_tokenizer(tokenizer) -> int: diff --git a/tests/test_supervised.py b/tests/test_supervised.py index 54c99a102..23f9e240c 100644 --- a/tests/test_supervised.py +++ b/tests/test_supervised.py @@ -4,7 +4,7 @@ import haliax from haliax import Axis -from levanter.data.text import _prepare_supervised_example, preprocess_supervised_example +from levanter.data.text import _prepare_supervised_example, _preprocess_supervised_example def test_supervised_eval(): @@ -19,7 +19,7 @@ def test_supervised_eval(): if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token - output = preprocess_supervised_example(examples, tokenizer, "input", "output") + output = _preprocess_supervised_example(examples, tokenizer, "input", "output") assert len(output["input_ids"][0]) == output["sources_len"][0] + 1 ex = {