Minimal working eval
Summary:

Test Plan:
EntilZha committed Feb 14, 2025
1 parent a3e0647 commit 655eca6
Showing 4 changed files with 55 additions and 27 deletions.
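Net effect of the commit: the eval entry point becomes runnable on its own. `dump_dir` and `ckpt_dir` turn optional on `EvalArgs`, validation data is read through `ArrowFileIterator`, and `main()` parses arguments with `parse_args_to_pydantic_model`. A minimal sketch of driving the new entry point programmatically; it assumes nothing beyond the imports and calls visible in the diff, and the checkpoint paths are hypothetical:

```python
from bytelatent.args import EvalArgs
from bytelatent.eval import launch_eval

# Hypothetical checkpoint locations; launch_eval resolves paths through
# get_fs, so any fsspec-compatible path should work here.
eval_args = EvalArgs(
    ckpt_dir="/checkpoints/blt_run/0000001000",
    dump_dir="/checkpoints/blt_run/0000001000/evals",
)
launch_eval(eval_args)
```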
4 changes: 2 additions & 2 deletions bytelatent/args.py
@@ -248,8 +248,8 @@ class ValidationArgs(BaseModel):

 class EvalArgs(BaseModel):
     model_config = ConfigDict(extra="forbid")
-    dump_dir: str
-    ckpt_dir: str
+    dump_dir: str | None = None
+    ckpt_dir: str | None = None
     metric_log_dir: str | None = None
     generator: PackedCausalTransformerGeneratorArgs = (
         PackedCausalTransformerGeneratorArgs()
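In practice the relaxed types let an `EvalArgs` be built incrementally. A small sketch (assuming, as the visible fields suggest, that every other `EvalArgs` field also carries a default):

```python
from pydantic import ValidationError

from bytelatent.args import EvalArgs

# dump_dir / ckpt_dir can now be omitted at construction time...
args = EvalArgs()
assert args.dump_dir is None and args.ckpt_dir is None
args.ckpt_dir = "/checkpoints/blt_run/0000001000"  # hypothetical path

# ...while extra="forbid" still rejects unknown or misspelled fields.
try:
    EvalArgs(dump_dirr="/tmp/evals")  # deliberate typo
except ValidationError as err:
    print(f"rejected: {err.error_count()} error(s)")
```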
68 changes: 47 additions & 21 deletions bytelatent/eval.py
@@ -11,10 +11,16 @@
 from lm_eval.api.instance import Instance
 from lm_eval.api.model import LM

-from bytelatent.args import EvalArgs, ValidationArgs
+from bytelatent.args import (
+    EvalArgs,
+    TrainArgs,
+    ValidationArgs,
+    find_and_sanitize_chunks,
+)
 from bytelatent.checkpoint import CONSOLIDATE_FOLDER, consolidate_checkpoints
 from bytelatent.config_parser import parse_args_to_pydantic_model
 from bytelatent.data.file_util import get_fs
+from bytelatent.data.iterators.arrow_iterator import ArrowFileIterator
 from bytelatent.distributed import (
     DistributedArgs,
     dist_mean_dict,
@@ -113,36 +119,52 @@ def loglikelihood_rolling(self, requests: list[Instance]) -> list[float]:
         return results


-def eval_on_val(generator, val_args: ValidationArgs, train_cfg):
-    srcs = {}
+def eval_on_val(generator, val_args: ValidationArgs, train_cfg: TrainArgs):
+    srcs = []
     for src in val_args.sources:
         path = os.path.join(val_args.root_dir, src)
-        srcs[path] = 1.0
+        srcs.append(path)

     for src in train_cfg.data.sources:
         path = os.path.join(train_cfg.data.root_dir, src)
-        srcs[path] = 1.0
-
-    multi_state = init_choice_state(
-        "", srcs, 0, get_global_rank(), get_world_size(), "*.val.jsonl"
-    )
-    path_to_iter = setup_sources(multi_state)
+        srcs.append(path)
+
+    path_to_iter = {}
+    for path in srcs:
+        chunks = find_and_sanitize_chunks(
+            path,
+            world_size=1,
+            file_pattern="*.val.jsonl",
+            s3_profile=train_cfg.data.s3_profile,
+        )
+        assert (
+            len(chunks) == 1
+        ), f"There should be only 1 chunk per validation file, but found: {chunks}"
+        chunk = chunks[0]
+        iterator = ArrowFileIterator(
+            dataset_files=[chunk],
+            file_path=None,
+            preprocess_dir=None,
+            entropy_model_name=None,
+            worker_id=0,
+            num_workers=1,
+            arrow_batch_size=train_cfg.data.arrow_batch_size,
+            s3_profile=train_cfg.data.s3_profile,
+            file_format="json",
+        )
+        path_to_iter[path] = iterator

     max_gen_len = generator.max_gen_len
     # We temporarily lower max gen len
     generator.max_gen_len = 1

     all_val_metrics = {}
     for src in path_to_iter:
-        jsonl_iterator = path_to_iter[src]
+        example_iterator = path_to_iter[src].create_iter()
         texts = []
         logger.info(f"Running validation on {src}...")
-        for step, (content, state) in enumerate(jsonl_iterator):
-            if state["current_iter"] > 0 or (
-                val_args.max_steps is not None and step >= val_args.max_steps
-            ):
-                break
-            content_key = "text" if ("text" in content) else "content"
-            texts.append(content[content_key])
+        for step, example in enumerate(example_iterator):
+            texts.append(example.text)

         _, loglikelihood, _ = generator.generate(texts)
@@ -187,7 +209,7 @@ def launch_eval(eval_args: EvalArgs):
     else:
         consolidate_path = os.path.join(eval_args.ckpt_dir, CONSOLIDATE_FOLDER)
         if not fs.exists(consolidate_path) and get_global_rank() == 0:
-            consolidate_path = consolidate_checkpoints(eval_args.ckpt_dir)
+            consolidate_path = consolidate_checkpoints(fs, eval_args.ckpt_dir)

     fs.mkdirs(eval_args.dump_dir, exist_ok=True)
     with fs.open(os.path.join(eval_args.dump_dir, "config.yaml"), "w") as f:
@@ -206,10 +228,13 @@ def launch_eval(eval_args: EvalArgs):

     wrap = EvalHarnessLM(generator)
     # Redo
-    results = simple_evaluate(wrap, eval_args.harness.model_dump())
+    # results = simple_evaluate(wrap, **eval_args.harness.model_dump())
+    results = {"results": []}

     val_results = None
     if eval_args.validation:
         val_results = eval_on_val(generator, eval_args.validation, train_cfg)
+
     if get_global_rank() == 0:
         with fs.open(os.path.join(eval_args.dump_dir, "results.json"), "w") as f:
             f.write(json.dumps(results))
@@ -218,6 +243,7 @@ def launch_eval(eval_args: EvalArgs):
         with fs.open(os.path.join(eval_args.dump_dir, "validation.json"), "w") as f:
             f.write(json.dumps(val_results))
         logger.info(f"All validation results: {val_results}")
+
     if eval_args.metric_log_dir and get_global_rank() == 0:
         metric_log_path = os.path.join(eval_args.metric_log_dir, "metrics.eval.jsonl")

@@ -247,7 +273,7 @@ def launch_eval(eval_args: EvalArgs):


 def main():
-    eval_args = parse_args(EvalArgs)
+    eval_args = parse_args_to_pydantic_model(EvalArgs)
     launch_eval(eval_args)

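The validation rewrite above boils down to one pattern: each `*.val.jsonl` source resolves to a single chunk, wrapped in an `ArrowFileIterator` whose `create_iter()` yields examples carrying a `.text` attribute. A standalone sketch using only the constructor arguments shown in the diff; the file path and batch size are hypothetical:

```python
from bytelatent.data.iterators.arrow_iterator import ArrowFileIterator

# Read one validation file as raw JSON: no preprocessing directory and
# no entropy model, with a single local worker.
iterator = ArrowFileIterator(
    dataset_files=["/data/val/wikipedia.val.jsonl"],  # hypothetical file
    file_path=None,
    preprocess_dir=None,
    entropy_model_name=None,
    worker_id=0,
    num_workers=1,
    arrow_batch_size=100,  # eval_on_val takes this from train_cfg.data
    s3_profile=None,
    file_format="json",
)
# The new loop in eval_on_val relies only on the .text attribute.
texts = [example.text for example in iterator.create_iter()]
```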
6 changes: 3 additions & 3 deletions bytelatent/generate.py
@@ -387,8 +387,7 @@ def load_consolidated_model_and_tokenizer(
 ):
     train_args_path = os.path.join(consolidated_path, "params.json")
     fs = get_fs(train_args_path)
-    with fs.open(train_args_path) as f:
-        train_args = TrainArgs.model_validate_json(f.read())
+    train_args = TrainArgs.model_validate_json(fs.read_text(train_args_path))

     if train_args.train_entropy_model:
         model_args = train_args.entropy_model
@@ -401,7 +400,8 @@ def load_consolidated_model_and_tokenizer(
         train_args.distributed.model_dtype
     ]
     tokenizer = train_args.data.tokenizer_args.build()
-    st_dict = torch.load(consolidated_path / CONSOLIDATE_NAME, weights_only=True)
+    with fs.open(os.path.join(consolidated_path, CONSOLIDATE_NAME)) as f:
+        st_dict = torch.load(f, weights_only=True)
     model.load_state_dict(st_dict["model"])
     model = model.cuda().eval()
     for param in model.parameters():
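Both generate.py changes route file access through the fsspec filesystem returned by `get_fs`, so the loader no longer assumes a local `Path` and the same code handles local or remote (e.g. S3) checkpoints. A sketch of the resulting pattern; the bucket path is hypothetical and the weights filename stands in for `CONSOLIDATE_NAME`:

```python
import os

import torch

from bytelatent.data.file_util import get_fs

consolidated_path = "s3://my-bucket/blt_run/consolidated"  # hypothetical
fs = get_fs(consolidated_path)

# Text config: fs.read_text replaces an explicit open()/read() pair.
params = fs.read_text(os.path.join(consolidated_path, "params.json"))

# Binary weights: torch.load accepts the file-like object from fs.open,
# so the checkpoint streams without a local copy.
with fs.open(os.path.join(consolidated_path, "consolidated.pth")) as f:
    st_dict = torch.load(f, weights_only=True)
```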
4 changes: 3 additions & 1 deletion bytelatent/train.py
@@ -241,7 +241,9 @@ def set_preemption_flag(signum, frame):
     preemption_flag["flag"] = True


-def every_n_steps(train_state, freq, acc_step=None, acc_freq=None):
+def every_n_steps(train_state, freq: int, acc_step=None, acc_freq=None):
+    if freq < 0:
+        return False
     test = train_state.step % freq == 0
     if acc_step is not None:
         test = test and (train_state.acc_step == acc_step)
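The new guard turns a negative `freq` into a "never fire" sentinel. That is worth spelling out: in Python `step % -1 == 0` for every step, so without the guard a `freq` of -1 would fire on every iteration rather than disabling the action. A small sketch, with a `SimpleNamespace` standing in for the train state:

```python
from types import SimpleNamespace

from bytelatent.train import every_n_steps

# Stub exposing only the fields every_n_steps reads.
state = SimpleNamespace(step=100, acc_step=0)

assert every_n_steps(state, freq=10) is True   # fires every 10 steps
assert every_n_steps(state, freq=-1) is False  # negative freq: disabled
assert 100 % -1 == 0                           # why the guard is needed
```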
