Added how to use commit0 for sampling during STAR training #105

Draft · wants to merge 18 commits into base: main

Changes from all commits
3 changes: 2 additions & 1 deletion .gitignore
@@ -167,4 +167,5 @@ config.yml
 hydra_outputs/
 .commit0*
 .agent*
-docs/analysis*.md
+docs/analysis*.md
+wandb/
44 changes: 44 additions & 0 deletions commit0/cli.py
@@ -2,6 +2,7 @@
 from pathlib import Path
 from typing import Union, List
 from typing_extensions import Annotated
+import commit0.harness.batch_run_pytest_ids
 import commit0.harness.run_pytest_ids
 import commit0.harness.get_pytest_ids
 import commit0.harness.build
@@ -300,6 +301,49 @@ def test(
     )


+@commit0_app.command()
+def batch_test(
+    test_ids: str = typer.Argument(
+        None,
+        help='All ways pytest supports to run and select tests. Please provide a single string. Example: "test_mod.py", "testing/", "test_mod.py::test_func", "-k \'MyClass and not method\'"',
+    ),
+    backend: str = typer.Option("modal", help="Backend to use for testing"),
+    timeout: int = typer.Option(1800, help="Timeout for tests in seconds"),
+    num_cpus: int = typer.Option(1, help="Number of CPUs to use"),
+    reference: Annotated[
+        bool, typer.Option("--reference", help="Test the reference commit")
+    ] = False,
+    coverage: Annotated[
+        bool, typer.Option("--coverage", help="Whether to get coverage information")
+    ] = False,
+    rebuild: bool = typer.Option(
+        False, "--rebuild", help="Whether to rebuild an image"
+    ),
+    commit0_config_file: str = typer.Option(
+        ".commit0.yaml",
+        help="Path to the commit0 dot file, where the setup config is stored",
+    ),
+    verbose: int = typer.Option(
+        1,
+        "--verbose",
+        "-v",
+        help="Set this to 2 for more logging information",
+        count=True,
+    ),
+) -> None:
+    """Run tests on Commit0 repositories in batch mode."""
+    commit0.harness.batch_run_pytest_ids.main(
+        test_ids,
+        reference,
+        coverage,
+        backend,
+        timeout,
+        num_cpus,
+        rebuild,
+        verbose,
+    )
+
+
 @commit0_app.command()
 def evaluate(
     branch: Union[str, None] = typer.Option(
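The new batch_test subcommand mirrors the existing test command but dispatches to the batch harness (commit0.harness.batch_run_pytest_ids). Assuming Typer's default kebab-case command naming, a hypothetical invocation would look like:

    commit0 batch-test "tests/" --reference --backend modal --timeout 1800 --num-cpus 8

The options accepted by batch_run_pytest_ids.main are not shown in this diff, so treat the flags above as illustrative only.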
8 changes: 3 additions & 5 deletions commit0/harness/run_pytest_ids.py
@@ -68,14 +68,12 @@ def main(
             or "bigcodebench" in dataset_name
             or "codecontests" in dataset_name
         ):
-            repo_name = example["instance_id"]
+            repo_name = str(example["instance_id"])
             dataset_type = "simple"
         else:
             repo_name = example["repo"].split("/")[-1]
             dataset_type = "commit0"
-        if repo_name in os.path.basename(repo_or_repo_dir) or repo_or_repo_dir.endswith(
-            repo_name
-        ):
+        if repo_name == os.path.basename(repo_or_repo_dir):
             spec = make_spec(example, dataset_type)
             break
     assert spec is not None, "No spec available"
@@ -174,7 +172,7 @@ def main(
         prompt = example["prompt"] if "prompt" in example.keys() else ""
         matches = extract_code_blocks(solution)
         if len(matches) > 0:
-            solution = "\n\n".join(matches)
+            solution = matches[0]
         else:
             solution = prompt + "\n\n" + solution
         patch = solution + "\n\n" + example["test"]
2 changes: 1 addition & 1 deletion commit0/harness/utils.py
@@ -239,7 +239,7 @@ def extract_code_blocks(text: str) -> List[str]:
     from the text.

     """
-    pattern = r"```python\n(.*?)```"
+    pattern = r'(?s)```(?:python|py)?(.*?)```'
     matches = re.finditer(pattern, text, re.DOTALL)
     return [match.group(1).strip() for match in matches]

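Compared with the old pattern, which only matched blocks opened with ```python, the new one also accepts ```py and untagged ``` fences. A minimal sketch (not part of the PR) of the same finditer/strip logic used by extract_code_blocks:

import re

pattern = r'(?s)```(?:python|py)?(.*?)```'
text = 'intro\n```py\nprint("hi")\n```\nprose\n```\nx = 1\n```'
blocks = [m.group(1).strip() for m in re.finditer(pattern, text, re.DOTALL)]
print(blocks)  # ['print("hi")', 'x = 1']

Together with the run_pytest_ids.py change above, only the first extracted block (matches[0]) is now used as the solution rather than a concatenation of all blocks.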
45 changes: 45 additions & 0 deletions examples/star/inference.py
@@ -0,0 +1,45 @@
from typing import Dict, List
from datasets import Dataset
from vllm import LLM, SamplingParams
from examples.star.utils import generate_prompt, cleanup


def generate_predictions(
    model_name: str, dataset: Dataset, temperature: float = 1.0, n: int = 1
) -> List[List[str]]:
    """Generate predictions for a given dataset using a specified language model and
    sampling parameters. The function constructs a chat prompt from each example,
    samples n completions per prompt, and returns the generated texts.

    Args:
    ----
        model_name (str): Name of the model to use for generation.
        dataset (Dataset): The Dataset object.
        temperature (float, optional): Temperature setting for the model's
            sampling strategy. Default is 1.0.
        n (int, optional): Number of sampling runs per prompt. Default is 1.

    Returns:
    -------
        predictions (List[List[str]]): Predictions on the dataset.

    """
    sampling_params = SamplingParams(n=n, temperature=temperature, max_tokens=512)
    llm = LLM(model=model_name)

    prompts: List[List[Dict]] = []
    for example in dataset:
        prompt = example["prompt"]
        test = example["test"]
        prompt = generate_prompt(prompt, test)
        prompts.append([{"role": "user", "content": prompt}])

    outputs = llm.chat(prompts, sampling_params)

    results: List[List[str]] = []
    for output in outputs:
        generated_texts = [one.text for one in output.outputs]
        results.append(generated_texts)
    cleanup(llm, vllm=True)
    return results
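A usage sketch for generate_predictions (hypothetical; the "prompt" and "test" fields mirror what the loop above reads, and the model name is the one used in run.sh):

from datasets import Dataset
from examples.star.inference import generate_predictions

toy = Dataset.from_list(
    [{"prompt": "def add(a, b):", "test": "assert add(1, 2) == 3"}]
)
samples = generate_predictions(
    "meta-llama/Llama-3.1-8B-Instruct", toy, temperature=1.0, n=4
)
# one inner list per example, with n sampled completions each
assert len(samples) == len(toy) and len(samples[0]) == 4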
15 changes: 15 additions & 0 deletions examples/star/run.sh
@@ -0,0 +1,15 @@
python examples/star/star.py \
--model_name_or_path meta-llama/Llama-3.1-8B-Instruct \
--dataset_name commit0/mbpp \
-n 100 \
--output_dir outputs \
--low_cpu_mem_usage \
--with_tracking \
--report_to wandb \
--iteration 5 \
--learning_rate 1e-6 \
--per_device_train_batch_size 1 \
--gradient_accumulation_steps 8 \
--max_workers 64 \
--temperature 1

96 changes: 96 additions & 0 deletions examples/star/star.py
@@ -0,0 +1,96 @@
"""Main STaR Loop"""

from copy import deepcopy
from datasets import Dataset, DatasetDict, load_dataset
from examples.star.inference import generate_predictions
from examples.star.train import train
from examples.star.utils import (
execute_tests,
format_solution,
generate_prompt,
parse_args,
)


def main() -> None:
args = parse_args()
ds = load_dataset(args.dataset_name, args.dataset_config_name)
assert "train" in ds
# format the dataset for training and evaluation
for split in ds:
texts = []
if split == "train":
continue
for example in ds[split]:
canonical_solution = f"```python\n{example['canonical_solution']}\n```"
text = [
{
"role": "user",
"message": generate_prompt(example["prompt"], example["test"]),
},
{
"role": "assistant",
"message": format_solution(canonical_solution, example["prompt"]),
},
]
texts.append(text)
ds[split] = ds[split].add_column(name="text", column=texts)

model_name = args.model_name_or_path
output_dir = deepcopy(args.output_dir)
for i in range(args.iteration):
# sample
all_samples = generate_predictions(
model_name, ds["train"], args.temperature, args.n
)
ds["train"].add_column(name="sample", column=all_samples).to_json(
f"{output_dir}/data/samples-iter{i}.json"
)
assert len(ds["train"]) == len(all_samples)

# verify and construct the training set
all_traces, all_execution_results = execute_tests(
ds["train"], all_samples, max_workers=args.max_workers
)
passed_examples = []
for example, execution_results, samples in zip(
ds["train"], all_execution_results, all_samples
):
for execution_result, sample in zip(execution_results, samples):
# pytest exit code: https://docs.pytest.org/en/stable/reference/exit-codes.html
if execution_result == 0:
example["text"] = [
{
"role": "user",
"message": generate_prompt(
example["prompt"], example["test"]
),
},
{
"role": "assistant",
"message": format_solution(sample, example["prompt"]),
},
]
passed_examples.append(example)
break
raw_datasets = DatasetDict(
{
"train": Dataset.from_list(passed_examples),
"validation": ds["validation"],
}
)
raw_datasets["train"].to_json(
f"{output_dir}/data/verified-samples-iter{i}.json"
)

# train
args.output_dir = f"{output_dir}/models-iter{i}"
train(raw_datasets, model_name, args)
model_name = args.output_dir


if __name__ == "__main__":
main()


__all__ = []
44 changes: 44 additions & 0 deletions examples/star/test.py
@@ -0,0 +1,44 @@
"""Get test accuracy"""

from datasets import load_dataset
from examples.star.inference import generate_predictions
from examples.star.utils import (
execute_tests,
generate_prompt,
parse_args,
)


def main() -> None:
args = parse_args()
ds = load_dataset(args.dataset_name, args.dataset_config_name)['test']
model_name = args.model_name_or_path

# sample
all_samples = generate_predictions(
model_name, ds, args.temperature, args.n
)
ds.add_column(name="sample", column=all_samples).to_json(
f"{args.output_dir}/data/{model_name.split('/')[-1]}-test-samples.json"
)
assert len(ds) == len(all_samples)

# verify and construct the training set
all_traces, all_execution_results = execute_tests(
ds, all_samples, max_workers=args.max_workers
)
passed = 0
for example, execution_results, samples in zip(
ds, all_execution_results, all_samples
):
for execution_result, sample in zip(execution_results, samples):
# pytest exit code: https://docs.pytest.org/en/stable/reference/exit-codes.html
if execution_result == 0:
passed += 1
print(f"passed: {passed/len(ds)}")

if __name__ == "__main__":
main()


__all__ = []
8 changes: 8 additions & 0 deletions examples/star/test.sh
@@ -0,0 +1,8 @@
python examples/star/test.py \
--model_name_or_path $1 \
--dataset_name commit0/mbpp \
-n $2 \
--output_dir outputs \
--max_workers 64 \
--temperature 0
