# fim.py (forked from loubnabnl/santacoder-finetuning)
from pathlib import Path
import signal
from tree_sitter import Language, Parser, Node
import functools
import random
import hashlib
import numpy as np
from numpy.random import RandomState
from typing import List, Tuple, Any, Optional

# Compile the TypeScript grammar once at import time; the shared library is
# cached under ./build next to this file.
Language.build_library(
    f"{Path(__file__).parent}/build/languages.so",
    [f"{Path(__file__).parent}/tree-sitter-typescript/typescript"]
)
TS_LANGUAGE = Language(
    f"{Path(__file__).parent}/build/languages.so", 'typescript')
PARSER = Parser()
PARSER.set_language(TS_LANGUAGE)


# this is expensive so we cache it
@functools.lru_cache(maxsize=None)
def get_fim_token_ids(tokenizer):
    try:
        _, FIM_PREFIX, FIM_MIDDLE, FIM_SUFFIX, FIM_PAD = tokenizer.special_tokens_map[
            "additional_special_tokens"
        ]
        suffix_tok_id, prefix_tok_id, middle_tok_id, pad_tok_id = (
            tokenizer.vocab[tok] for tok in [FIM_SUFFIX, FIM_PREFIX, FIM_MIDDLE, FIM_PAD]
        )
    except KeyError:
        suffix_tok_id, prefix_tok_id, middle_tok_id, pad_tok_id = None, None, None, None
    return suffix_tok_id, prefix_tok_id, middle_tok_id, pad_tok_id
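

# Minimal usage sketch (not part of the original file): with a tokenizer whose
# "additional_special_tokens" carry the FIM sentinels, the lookup above returns
# the ids that permute() needs. The checkpoint name below is an assumption.
#
#     from transformers import AutoTokenizer
#     tok = AutoTokenizer.from_pretrained("bigcode/santacoder")  # assumed checkpoint
#     suffix_id, prefix_id, middle_id, pad_id = get_fim_token_ids(tok)
#     # a tokenizer without FIM special tokens yields (None, None, None, None)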


def get_prefix_middle_suffix(np_rng: RandomState, sample: bytes, strip_suffix_rate: float) -> Optional[Tuple[Tuple[str, str, str], RandomState]]:
    def is_child_type_annotation(node):
        """Checks if any of the parent nodes is an annotation node."""
        node = node.parent
        while node is not None:
            if node.type in ("type_annotation", "opting_type_annotation", "omitting_type_annotation"):
                return True
            node = node.parent
        return False

    def contains_url(node):
        # an annotation containing "//" is likely part of a URL, so never split on it
        string = sample[node.start_byte:node.end_byte].decode("utf-8")
        return "//" in string

    QUERY = TS_LANGUAGE.query("""
    [
        (type_annotation) @annotation
        (opting_type_annotation) @annotation
        (omitting_type_annotation) @annotation
    ]
    """)
    tree = PARSER.parse(sample)
    # Each capture has a start_byte and end_byte; these are the indices of the
    # type annotation. We want to invert these indices, i.e. get the substrings
    # between the captures (and also the substring before the first capture and
    # the substring after the last capture).
    captures: List[Tuple[Node, str]] = QUERY.captures(tree.root_node)

    # NOTE: these two predicates are currently identical; both filter out
    # annotations nested inside another annotation and annotations that look
    # like URLs.
    def is_splitable(node):
        return not is_child_type_annotation(node) and not contains_url(node)

    def is_capturable(node):
        return not is_child_type_annotation(node) and not contains_url(node)

    captures_no_child: List[Node] = []
    for i, (node, _) in enumerate(captures):
        if is_capturable(node):
            captures_no_child += [node]

    splittable_indices: List[int] = []
    for i, node in enumerate(captures_no_child):
        if is_splitable(node):
            splittable_indices += [i]

    if len(splittable_indices) == 0:
        return None

    # pick one annotation at random: everything before it becomes the prefix,
    # the annotation itself the middle, and everything after it the suffix
    random_pick_i = np_rng.choice(splittable_indices)
    prefix_b: bytes = sample[:captures_no_child[random_pick_i].start_byte]
    middle_b: bytes = sample[captures_no_child[random_pick_i].start_byte:
                             captures_no_child[random_pick_i].end_byte]
    if middle_b.startswith(b":"):
        # keep the ": " in the prefix so the model only has to fill in the type
        prefix_b += b": "
        middle_b = middle_b[1:].lstrip()
    suffix_b: bytes = b""
    if np_rng.binomial(1, strip_suffix_rate):
        # strip the remaining type annotations from the suffix
        l = len(captures_no_child)
        for i in range(random_pick_i, l - 1):
            suffix_b += sample[captures_no_child[i].end_byte:
                               captures_no_child[i + 1].start_byte]
        suffix_b += sample[captures_no_child[l - 1].end_byte:]
    else:  # keep the types in the suffix
        suffix_b = sample[captures_no_child[random_pick_i].end_byte:]
    prefix_str = prefix_b.decode("utf-8")
    middle_str = middle_b.decode("utf-8")
    suffix_str = suffix_b.decode("utf-8")
    return (prefix_str, middle_str, suffix_str), np_rng


# Adapted from https://github.com/bigcode-project/Megatron-LM/blob/6c4bf908df8fd86b4977f54bf5b8bd4b521003d1/megatron/data/gpt_dataset.py
def permute(
    tokenizer,
    sample,
    np_rng,
    suffix_tok_id,
    prefix_tok_id,
    middle_tok_id,
    fim_rate=0.5,
    fim_spm_rate=0.5,
    strip_suffix_rate=0.9,
):
    """
    Take a sample (list of tokens) and, with probability fim_rate, perform a FIM
    transformation on it, using one of two FIM modes: SPM (with probability
    fim_spm_rate) or PSM otherwise.
    """
    if np_rng.binomial(1, fim_rate):
        decoded_bytes: str | bytes = tokenizer.decode(sample)
        if not isinstance(decoded_bytes, bytes):
            decoded_bytes = decoded_bytes.encode("utf-8")
        try:
            def timeout_handler(_, __):
                decoded = decoded_bytes.decode("utf-8")
                h = hashlib.sha1(
                    f"blob {len(decoded.encode())}\0{decoded}".encode()).hexdigest()
                raise Exception(f"Timeout after 10 seconds: {h}")

            def crash_in_cpp_handler(_, __):
                decoded = decoded_bytes.decode("utf-8")
                h = hashlib.sha1(
                    f"blob {len(decoded.encode())}\0{decoded}".encode()).hexdigest()
                raise Exception(f"Crash in C++: {h}")

            # set a timeout of 10 seconds, using signal.alarm
            signal.signal(signal.SIGALRM, timeout_handler)
            # tree-sitter can crash inside its native code; turn SIGSEGV/SIGABRT
            # into Python exceptions so one bad sample doesn't kill the run
            signal.signal(signal.SIGSEGV, crash_in_cpp_handler)
            signal.signal(signal.SIGABRT, crash_in_cpp_handler)
            signal.alarm(10)
            res = get_prefix_middle_suffix(
                np_rng, decoded_bytes, strip_suffix_rate)
            signal.alarm(0)
        except Exception as e:
            print(e)
            print("GOT FAILED SAMPLE:\n", decoded_bytes)
            return None, np_rng
        if res is None:
            print("No type annotations found in sample")
            return None, np_rng
        (prefix_str, middle_str, suffix_str), np_rng = res
        print(f"Was able to split on type: \"{middle_str}\"")
        prefix = np.array(tokenizer.encode(prefix_str))
        middle = np.array(tokenizer.encode(middle_str))
        suffix = np.array(tokenizer.encode(suffix_str))
        if np_rng.binomial(1, fim_spm_rate):
            # SPM (variant 2 from the FIM paper)
            new_sample = np.concatenate(
                [
                    [prefix_tok_id, suffix_tok_id],
                    suffix,
                    [middle_tok_id],
                    prefix,
                    middle,
                ]
            )
        else:
            # PSM
            new_sample = np.concatenate(
                [
                    [prefix_tok_id],
                    prefix,
                    [suffix_tok_id],
                    suffix,
                    [middle_tok_id],
                    middle,
                ]
            )
    else:
        # don't do FIM preprocessing
        new_sample = sample

    return list(new_sample), np_rng
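

# Hedged sketch (not in the original file) of how permute() is typically driven
# from a token-chunking preprocessing loop; `tokenizer` and `token_chunks`
# (an iterable of token-id lists) are assumed to exist in the caller.
#
#     suffix_id, prefix_id, middle_id, _ = get_fim_token_ids(tokenizer)
#     np_rng = np.random.RandomState(seed=0)
#     fim_chunks = []
#     for chunk in token_chunks:
#         permuted, np_rng = permute(
#             tokenizer, chunk, np_rng,
#             suffix_id, prefix_id, middle_id,
#             fim_rate=0.5, fim_spm_rate=0.5, strip_suffix_rate=0.9,
#         )
#         if permuted is not None:
#             fim_chunks.append(permuted)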


if __name__ == "__main__":  # some unit tests
    import os
    rng = np.random.RandomState(seed=int(os.urandom(4).hex(), 16))
    sample = """
interface Foo {
    foo(x: number, y: number): number;
    name: {
        first: string;
        last: {
            name: string;
            age: number;
        };
    };
}
function foo(x: number, y:number):number {
    return x + y;
}
// some unicode to mess things up
// 😀 😃 😄 😁 😆 😅
function foo2(x:number, y: number): number {
    return x + y;
}
interface Bar {
    bar(x: number, y: number): number;
    name: {
        first: string;
        last: string;
    };
}
function url() {
    let url = `https://127.0.0.1:${SUBSTRATE_PORT}`
}
"""
    bytes_sample = bytes(sample, "utf-8")
    print("sample:", sample)
    print("bytes_sample:", bytes_sample)
    print("get_prefix_middle_suffix:")
    res = get_prefix_middle_suffix(rng, bytes_sample, 0.5)
    if res is not None:
        (prefix_str, middle_str, suffix_str), rng = res
        print("prefix_str:", prefix_str)
        print("middle_str:", middle_str)
        print("suffix_str:", suffix_str)