-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathrun_txt_gen.py
executable file
·131 lines (120 loc) · 4.12 KB
/
run_txt_gen.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
#!/usr/bin/env python3
import argparse
import random
import re
from typing import List, Optional

from openai import OpenAI
from tqdm import tqdm

from text.gen import make_chatgpt_query
from text.utils import post_process_sentences
from utils.files import append_sentences_to_file, read_file
# Matches a run of digits; used to read and to replace the requested item count.
NUMBER_PATTERN = re.compile(r"\d+")
# Placeholder inside a query that is replaced by each word of the word list.
MASK_TOKEN = "[MASK]"
# Stop retrying a word after this many consecutive responses without any
# usable sentence, so an unproductive model cannot stall the script forever.
MAX_EMPTY_RESPONSES = 3


def build_parser() -> argparse.ArgumentParser:
    """Create the command-line parser for the generation script."""
    parser = argparse.ArgumentParser(
        description="Sentence/word generation using ChatGPT"
    )
    parser.add_argument(
        "--input_file", type=str, default=None, help="Input file with words"
    )
    parser.add_argument(
        "--num",
        type=int,
        default=None,
        help="Number of sentences or words to generate",
    )
    parser.add_argument(
        "--context",
        type=str,
        default="radiologia médica",
        help="Context of the generated sentences",
    )
    parser.add_argument(
        "--query",
        type=str,
        default=None,
        help="A query to OpenAI's ChatGPT; the first number detected in the query will be replaced by the number of sentences to generate",
    )
    parser.add_argument(
        "--return_type",
        type=str,
        default="frases",
        help="Type of data to generate (default: frases)",
    )
    parser.add_argument(
        "--api_key",
        type=str,
        default=None,
        help="OpenAI API key",
    )
    parser.add_argument(
        "--model",
        type=str,
        default="gpt-4o-mini",
        help="ChatGPT model to use",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=451,
        help="Random seed (default: 451)",
    )
    parser.add_argument(
        "--output",
        type=str,
        default=None,
        help="Output file to write generated sentences",
    )
    return parser


def build_default_query(context: str, num: int, return_type: str) -> str:
    """Return the built-in prompt asking for *num* items of *return_type*.

    For ``return_type == "frases"`` the prompt contains the ``[MASK]``
    placeholder; for any other return type it does not.
    """
    if return_type == "frases":
        return (
            f"Você é um médico ditando o laudo de um paciente. "
            f"No contexto de {context}, gere {num} {return_type} "
            f"contendo o termo '[MASK]', separadas por nova linha."
        )
    return (
        f"No contexto de {context}, gere {num} {return_type} "
        f"separadas por nova linha."
    )


def find_requested_count(query: str, fallback: Optional[int] = None) -> Optional[int]:
    """Return the first integer that appears in *query*, else *fallback*."""
    match = NUMBER_PATTERN.search(query)
    return int(match.group()) if match else fallback


def set_requested_count(query: str, count: int) -> str:
    """Replace only the first number in *query* with *count*.

    Later numbers (for example digits that are part of a term) are left
    untouched. A query without any number is returned unchanged.
    """
    return NUMBER_PATTERN.sub(str(count), query, count=1)


def fill_mask(query: str, word: str) -> str:
    """Replace every ``[MASK]`` placeholder in *query* with *word*.

    Plain string replacement is used so that backslashes in *word* are
    inserted literally instead of being read as regex escapes.
    """
    return query.replace(MASK_TOKEN, word)


def strip_leading_number(sentence: str) -> str:
    """Drop a leading list-number token (e.g. ``"1. "``) from *sentence*.

    If the sentence starts with a digit, everything up to and including the
    first space is removed. Sentences that are empty, do not start with a
    digit, or contain no space are returned unchanged.
    """
    if sentence[:1].isdigit():
        _, separator, remainder = sentence.partition(" ")
        if separator:
            return remainder
    return sentence


def clean_response(sentences: List[str]) -> List[str]:
    """Return the non-blank entries of *sentences* without list numbering."""
    return [strip_leading_number(s) for s in sentences if s.strip()]


def read_words_interactively() -> List[str]:
    """Prompt for words on stdin until an empty line is entered."""
    return list(iter(lambda: input("Enter a word (or press Enter to finish): "), ""))


def main(argv: Optional[List[str]] = None) -> None:
    """Generate sentences/words with ChatGPT and print (optionally save) them.

    Args:
        argv: Command-line arguments; ``None`` means ``sys.argv[1:]``.

    The requested count comes from ``--num`` when the built-in query is used,
    or from the first number in ``--query`` (falling back to ``--num``) when
    a custom query is given. The script exits with a usage error when no
    count can be determined.
    """
    parser = build_parser()
    args = parser.parse_args(argv)
    random.seed(args.seed)

    custom_query = args.query
    if custom_query is None:
        if args.num is None:
            parser.error("--num is required when --query is not given")
        total = args.num
        reference_query = build_default_query(
            args.context, total, args.return_type
        )
    else:
        total = find_requested_count(custom_query, fallback=args.num)
        if total is None:
            parser.error(
                "--query contains no number; add one to the query or pass --num"
            )
        reference_query = custom_query

    def render_query(count: int, word: str) -> str:
        # The count is written into the un-substituted template first, so
        # digits inside the word or the context are never mistaken for it.
        if custom_query is None:
            template = build_default_query(args.context, count, args.return_type)
        else:
            template = set_requested_count(custom_query, count)
        return fill_mask(template, word)

    if args.input_file:
        wordlist = read_file(args.input_file)
    elif args.return_type == "frases" and MASK_TOKEN in reference_query:
        wordlist = read_words_interactively()
    else:
        # No placeholder to fill: run the query once with an empty word.
        wordlist = [""]

    response_sentences: List[str] = []
    # NOTE(review): with api_key=None the client presumably falls back to its
    # own default (e.g. the OPENAI_API_KEY environment variable) - confirm
    # against the openai package documentation.
    openai_client = OpenAI(api_key=args.api_key)

    for word in tqdm(wordlist):
        word = word.strip()
        remaining = total
        empty_streak = 0
        while remaining > 0:
            query = render_query(remaining, word)
            print(f"\nNumber of sentences left: {remaining}")
            print(f"Querying OpenAI's {args.model} with '{query}'...")
            query_response = make_chatgpt_query(
                openai_client,
                query,
                return_type=args.return_type,
                model=args.model,
            )
            print(query_response)

            sentences = clean_response(query_response)
            if not sentences:
                empty_streak += 1
                if empty_streak >= MAX_EMPTY_RESPONSES:
                    print(
                        f"No usable sentences after {empty_streak} attempts; "
                        f"skipping '{word}'."
                    )
                    break
                continue

            empty_streak = 0
            response_sentences.extend(sentences)
            remaining -= len(sentences)
        print()

    generated_sentences = post_process_sentences(response_sentences, modify=True)

    print("\nFinal results:")
    print("-------------------")
    for sentence in generated_sentences:
        print(sentence)
    print(f"\nTotal: {len(generated_sentences)} sentences")
    print("-------------------\n")

    if args.output:
        print(f"Appending generated sentences to {args.output}...")
        append_sentences_to_file(args.output, generated_sentences)


if __name__ == "__main__":
    main()