forked from microsoft/rStar
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.py
98 lines (85 loc) · 3.34 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
# Adapted from https://github.com/MARIO-Math-Reasoning/Super_MARIO
from __future__ import annotations
import os
import json
import torch
import argparse
from tqdm import tqdm
from datetime import datetime
from omegaconf import OmegaConf
from rstar_deepthink.agents import BS, MCTS
from rstar_deepthink.solver import Solver
from rstar_deepthink.config import BaseConfig
# Cap torch's intra-op CPU thread pool so inference does not oversubscribe the host.
torch.set_num_threads(12)
# Silence the HuggingFace tokenizers parallelism machinery before any forks
# happen (presumably to avoid the usual fork-after-parallelism warning — confirm).
os.environ["TOKENIZERS_PARALLELISM"] = "false"
def load_qaf(filename: str) -> list:
    """Load question-and-answer records from a ``.json`` or ``.jsonl`` file.

    Args:
        filename: Path to the data file; must end in ``.json`` or ``.jsonl``.

    Returns:
        A list of record dicts. For ``.json`` files whose top-level object
        wraps the records under an ``"example"`` key, that inner list is
        returned instead of the wrapper.

    Raises:
        ValueError: If the filename has an unrecognized extension.
    """
    if filename.endswith(".json"):
        with open(filename, "r") as f:
            data = json.load(f)
        # Some dumps nest the actual records under an "example" key.
        if "example" in data:
            data = data["example"]
    elif filename.endswith(".jsonl"):
        with open(filename, "r") as f:
            # Stream line-by-line instead of readlines(): avoids holding the
            # raw text and the parsed records in memory simultaneously.
            data = [json.loads(line) for line in f]
    else:
        # Bug fix: the f-string previously had no placeholder, so the
        # offending filename was never included in the error message.
        raise ValueError(f"Unrecognized file format: {filename}")
    return data
def batch(iterable, n=-1):
    """Yield successive slices of *iterable*, each at most *n* items long.

    A non-positive *n* means "one chunk containing everything". The input
    must support ``len()`` and slicing (list, str, tuple, ...).
    """
    total = len(iterable)
    size = total if n <= 0 else n
    # Slicing past the end clamps automatically, so no explicit min() needed.
    for start in range(0, total, size):
        yield iterable[start:start + size]
def parse_args():
    """Parse command-line options for a solver run.

    Returns:
        argparse.Namespace with ``custom_cfg``, ``qaf``, ``model_dir``,
        ``reward_model_dir`` and ``save_in_model`` (all strings).
    """
    # Use a distinct name for the parser: the original shadowed `args`
    # for both the parser and the parsed namespace.
    parser = argparse.ArgumentParser()
    parser.add_argument('--custom_cfg', type=str, default="config/sft_eval_mcts.yaml")
    # Typo fix: help text previously read "quesuion and answer file".
    parser.add_argument("--qaf", type=str, default="", help="question and answer file")
    parser.add_argument('--model_dir', type=str, default="")
    parser.add_argument('--reward_model_dir', type=str, default="")
    parser.add_argument('--save_in_model', type=str, default="")
    return parser.parse_args()
if __name__ == '__main__':
    args = parse_args()

    # Build the run configuration: structured defaults from BaseConfig,
    # overlaid with the YAML file (if given), then CLI overrides on top.
    config = OmegaConf.structured(BaseConfig)
    if args.custom_cfg:
        custom_config = OmegaConf.load(args.custom_cfg)
        config = OmegaConf.merge(config, custom_config)
    # Round-trip through YAML with resolve=True to materialize any
    # interpolations eagerly before the config is used.
    config = OmegaConf.create(OmegaConf.to_yaml(config, resolve=True))
    if args.model_dir:
        config.model_dir = args.model_dir
    if args.reward_model_dir:
        config.reward_model_dir = args.reward_model_dir
    print(config)

    # Model directory basename tags the output file; a reward model adds a
    # second ".<name>" suffix below.
    llm_version = os.path.basename(config.model_dir.rstrip("/"))
    data = load_qaf(args.qaf)
    solver = Solver(config=config)

    # init agent: select the search-agent class by configured mode.
    if config.mode == "mcts":
        agent = MCTS
    elif config.mode == "bs":
        agent = BS
    else:
        raise NotImplementedError
    if args.reward_model_dir:
        llm_version += "." + args.reward_model_dir.split("/")[-1]

    # Default output path is derived from the input file, mode, model tag and
    # a timestamp; --save_in_model replaces it wholesale.
    saved_jsonl_file = f"{args.qaf}.{config.mode}.{llm_version}.{datetime.now().strftime('%Y%m%d%H%M%S')}.jsonl"
    if args.save_in_model:
        saved_jsonl_file = args.save_in_model + '.jsonl'
    # NOTE(review): indentation was lost in this copy — the makedirs pair may
    # originally have been nested under the if above; placed at this level it
    # safely covers both output paths. Confirm against upstream.
    saved_jsonl_file_dir = os.path.dirname(saved_jsonl_file)
    os.makedirs(saved_jsonl_file_dir, exist_ok=True)

    # Append mode: results accumulate across runs into the same file.
    with open(saved_jsonl_file, "a+", encoding='utf-8') as writer:
        for cur_data in tqdm(batch(data, config.batch_size), desc="Main Processing"):
            # One agent per question in the batch; ground truth is stringified
            # for uniform comparison downstream.
            agents = [agent(config=config, question=d["question"], ground_truth=str(d["answer"]))
                      for d in cur_data]
            # solver.solve returns a mapping keyed by question text.
            jsonlines = solver.solve(agents, saved_jsonl_file, cur_data)
            for d in cur_data:
                question = d["question"]
                d["rstar"] = jsonlines[question]
                writer.write(json.dumps(d, ensure_ascii=False) + '\n')
            # Flush after each batch so partial results survive interruption.
            writer.flush()