Skip to content

Commit

Permalink
training works
Browse files Browse the repository at this point in the history
  • Loading branch information
wenting-zhao committed Dec 8, 2024
1 parent f6b2a71 commit 68b724b
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 24 deletions.
31 changes: 23 additions & 8 deletions examples/star/star.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,34 @@
"""Main STaR Loop"""

from datasets import Dataset, load_dataset
from datasets import Dataset, DatasetDict, load_dataset
from inference import generate_predictions
from train import train
from utils import execute_tests, parse_args
from utils import execute_tests, format_solution, generate_prompt, parse_args


def main():
args = parse_args()
ds = load_dataset(args.dataset_name)
ds = load_dataset(args.dataset_name, args.dataset_config_name)
assert "train" in ds
# format the dataset for training and evaluation
for split in ds:
texts = []
if split == "train": continue
for example in ds[split]:
canonical_solution = f"```python\n{example['canonical_solution']}\n```"
text = [{"role": "user", "message": generate_prompt(example["prompt"], example["test"])}, {"role": "assistant", "message": format_solution(canonical_solution, example["prompt"])}]
texts.append(text)
print(text)
ds[split] = ds[split].add_column(name="text", column=texts)
ds["train"] = ds["train"].select(range(10))

# sample
all_samples = generate_predictions(
args.model_name_or_path, ds["train"], args.temperature, args.n
)
assert len(ds["train"]) == len(all_samples)

# verify and construct the training set
all_traces, all_execution_results = execute_tests(ds["train"], all_samples)
passed_examples = []
for example, execution_results, samples in zip(
Expand All @@ -22,13 +37,13 @@ def main():
for execution_result, sample in zip(execution_results, samples):
# pytest exit code: https://docs.pytest.org/en/stable/reference/exit-codes.html
if execution_result == 0:
example["prediction"] = sample
example["text"] = [{"role": "user", "message": generate_prompt(example["prompt"], example["test"])}, {"role": "assistant", "message": format_solution(sample, example["prompt"])}]
passed_examples.append(example)
break
new_ds = Dataset.from_list(passed_examples)
new_ds.to_json("star_training.json")
print(len(passed_examples) / len(ds["train"]))
train(args)
raw_datasets = DatasetDict({"train": Dataset.from_list(passed_examples), "validation": ds["validation"]})

# train
train(raw_datasets, args.model_name_or_path, args)


if __name__ == "__main__":
Expand Down
18 changes: 2 additions & 16 deletions examples/star/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@

import datasets
import torch
from accelerate import Accelerator, DistributedType
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import set_seed
from datasets import load_dataset
Expand Down Expand Up @@ -234,10 +234,6 @@ def tokenize_function(examples):
)
)

# On TPU, the tie weights in our model have been disconnected, so we need to restore the ties.
if accelerator.distributed_type == DistributedType.TPU:
model.tie_weights()

# We need to recalculate our total training steps as the size of the training dataloader may have changed.
num_update_steps_per_epoch = math.ceil(
len(train_dataloader) / args.gradient_accumulation_steps
Expand Down Expand Up @@ -291,17 +287,7 @@ def tokenize_function(examples):
model.train()
if args.with_tracking:
total_loss = 0
if (
args.resume_from_checkpoint
and epoch == starting_epoch
and resume_step is not None
):
# We skip the first `n` batches in the dataloader when resuming from a checkpoint
active_dataloader = accelerator.skip_first_batches(
train_dataloader, resume_step
)
else:
active_dataloader = train_dataloader
active_dataloader = train_dataloader
for step, batch in enumerate(active_dataloader):
with accelerator.accumulate(model):
batch["labels"] = batch["input_ids"].clone().detach()
Expand Down
11 changes: 11 additions & 0 deletions examples/star/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from tqdm import tqdm
from typing import List, Tuple
from transformers import MODEL_MAPPING, SchedulerType
from commit0.harness.utils import extract_code_blocks


def execute_tests(
Expand Down Expand Up @@ -98,6 +99,16 @@ def generate_prompt(prompt: str, test: str) -> str:
"""


def format_solution(text, prompt):
    """Build the solution string for one example from a raw completion.

    ``text`` is passed to ``extract_code_blocks``. When that call returns at
    least one match, the first match is returned wrapped in a ```python
    fence and ``prompt`` is not used. When it returns nothing, the result is
    ``prompt``, a blank line, and then ``text`` unchanged.
    """
    code_blocks = extract_code_blocks(text)
    if len(code_blocks) == 0:
        # Nothing was extracted: fall back to the prompt followed by the raw text.
        return prompt + "\n\n" + text
    # Only the first extracted block is kept; any later ones are ignored.
    return f"```python\n{code_blocks[0]}\n```"


def parse_args():
parser = argparse.ArgumentParser(
description="Finetune a transformers model on a causal language modeling task"
Expand Down

0 comments on commit 68b724b

Please sign in to comment.