forked from xxyQwQ/ComfyBench
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path main.py
134 lines (124 loc) · 4.71 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
import os
import sys
import time
import json
import argparse
from utils import console
from utils.comfy import execute_prompt
from agent.cot import CoTPipeline
from agent.rag import RAGPipeline
from agent.cotsc import CoTSCPipeline
from agent.comfy import ComfyPipeline
from agent.fewshot import FewShotPipeline
from agent.zeroshot import ZeroShotPipeline
from agent.variant.comfy_no_adapt import ComfyNoAdaptPipeline
from agent.variant.comfy_no_refine import ComfyNoRefinePipeline
from agent.variant.comfy_no_combine import ComfyNoCombinePipeline
from agent.variant.comfy_no_retrieve import ComfyNoRetrievePipeline
from agent.variant.rag_json_representation import RAGJsonRepresentationPipeline
from agent.variant.rag_list_representation import RAGListRepresentationPipeline
# Maps every supported --agent_name value to a factory that builds the
# matching pipeline from the parsed command-line arguments. The keys double as
# the argparse `choices`, so an unknown agent name is rejected at parse time
# instead of later surfacing as an unbound `pipeline` variable.
_PIPELINE_BUILDERS = {
    'zeroshot': lambda args: ZeroShotPipeline(),
    'fewshot': lambda args: FewShotPipeline(
        num_examples=args.num_examples
    ),
    'cot': lambda args: CoTPipeline(
        num_examples=args.num_examples
    ),
    'cotsc': lambda args: CoTSCPipeline(
        num_examples=args.num_examples,
        num_trajectories=args.num_trajectories
    ),
    'rag': lambda args: RAGPipeline(
        num_references=args.num_references
    ),
    'comfy': lambda args: ComfyPipeline(
        num_references=args.num_references,
        step_limitation=args.step_limitation,
        debug_limitation=args.debug_limitation
    ),
    'rag_json_representation': lambda args: RAGJsonRepresentationPipeline(
        num_references=args.num_references
    ),
    'rag_list_representation': lambda args: RAGListRepresentationPipeline(
        num_references=args.num_references
    ),
    'comfy_no_combine': lambda args: ComfyNoCombinePipeline(
        num_references=args.num_references,
        step_limitation=args.step_limitation
    ),
    'comfy_no_adapt': lambda args: ComfyNoAdaptPipeline(
        num_references=args.num_references,
        step_limitation=args.step_limitation
    ),
    'comfy_no_retrieve': lambda args: ComfyNoRetrievePipeline(
        num_references=args.num_references,
        step_limitation=args.step_limitation
    ),
    'comfy_no_refine': lambda args: ComfyNoRefinePipeline(
        num_references=args.num_references,
        step_limitation=args.step_limitation
    ),
}


def _parse_args(argv=None):
    """Parse the command-line options for a single pipeline run.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        The parsed ``argparse.Namespace``. ``--agent_name`` must be one of the
        keys of ``_PIPELINE_BUILDERS``; otherwise argparse prints a usage
        error and exits with status 2.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--instruction', type=str, required=True)
    parser.add_argument(
        '--agent_name', type=str, required=True,
        choices=list(_PIPELINE_BUILDERS),
    )
    parser.add_argument('--save_path', type=str, default=None)
    parser.add_argument('--num_examples', type=int, default=3)
    parser.add_argument('--num_trajectories', type=int, default=3)
    parser.add_argument('--num_references', type=int, default=5)
    parser.add_argument('--step_limitation', type=int, default=5)
    parser.add_argument('--debug_limitation', type=int, default=1)
    return parser.parse_args(argv)


def _print_status(message):
    """Print the 'Program Status' banner, then *message*, then a blank line."""
    print(' Program Status '.center(80, '-'))
    print(message)
    print()


def main():
    """Generate a workflow for ``--instruction`` with the chosen agent, then execute it.

    Artifacts go to ``--save_path``, which defaults to a timestamped
    directory under ``./cache/pipeline_record``:

    * ``instruction.txt`` holds the instruction text.
    * ``logging.txt`` is the path handed to ``console.Logger``.
    * ``workflow.json`` holds the generated prompt.
    * One file per entry returned by ``execute_prompt``.

    A failure while generating or executing the workflow is reported via the
    status banner, and the function returns early.
    """
    args = _parse_args()

    if args.save_path is None:
        timestamp = time.strftime('%Y-%m-%d-%H-%M-%S')
        args.save_path = f'./cache/pipeline_record/{timestamp}'
    os.makedirs(args.save_path, exist_ok=True)
    # Explicit UTF-8 so non-ASCII instructions do not depend on the locale.
    with open(f'{args.save_path}/instruction.txt', 'w', encoding='utf-8') as file:
        file.write(args.instruction)
    # From here on, print() output is routed through console.Logger, which is
    # given logging.txt (presumably it tees to that file; not visible here).
    sys.stdout = console.Logger(f'{args.save_path}/logging.txt')

    # create pipeline; the agent name was already validated by argparse
    _print_status('creating pipeline...')
    pipeline = _PIPELINE_BUILDERS[args.agent_name](args)

    # generate workflow
    _print_status('generating workflow...')
    try:
        prompt = pipeline(args.instruction)
        with open(f'{args.save_path}/workflow.json', 'w', encoding='utf-8') as file:
            json.dump(prompt, file, indent=4)
    except Exception as error:
        _print_status(f'failed to generate workflow: {error}')
        return

    # execute workflow
    _print_status('executing workflow...')
    try:
        _, outputs = execute_prompt(prompt)
        for name, data in outputs.items():
            # NOTE(review): `name` is used verbatim as a path component;
            # assumes execute_prompt returns bare filenames — confirm.
            with open(f'{args.save_path}/{name}', 'wb') as file:
                file.write(data)
    except Exception as error:
        _print_status(f'failed to execute workflow: {error}')


if __name__ == '__main__':
    main()