import os
import pickle
from collections import defaultdict

import torch
from torch.utils.data import Dataset


def split_str(line):
    """Split one raw dataset line into (sorted digits string, solution string)."""
    line = line.rstrip('\n')
    digits = sorted(line[1:11].split(', '))  # the four digits inside "[a, b, c, d]"
    digits_str = ' '.join(digits)
    solution_str = line[14:]  # everything after the "[a, b, c, d]: " prompt
    return digits_str, solution_str
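
# A quick doctest-style check of split_str, using the class-level sample below:
#   split_str("[4, 8, 9, 3]: 8 + 4 = 12, 9 + 3 = 12, 12 + 12 = 24 ")
#   -> ('3 4 8 9', '8 + 4 = 12, 9 + 3 = 12, 12 + 12 = 24 ')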


class DatasetOf24Game(Dataset):
    # class-level constants and lookup tables
    one_data_sample = "[4, 8, 9, 3]: 8 + 4 = 12, 9 + 3 = 12, 12 + 12 = 24 "
    one_problem_sample = "[4, 8, 9, 3]: "
    one_result_sample = "8 + 4 = 12, 9 + 3 = 12, 12 + 12 = 24 "
    all_chars = "0123456789[,]:+-*/= "
    itoc = {i: c for i, c in enumerate(all_chars)}  # token index -> character
    ctoi = {c: i for i, c in enumerate(all_chars)}  # character -> token index

    all_data_set = set()
    all_test_data_set = set()
    all_train_data = []
    all_test_data = []
    all_data = []

    # train/test split: the pickle records which digit combinations are held out
    _data_dir = os.path.dirname(os.path.realpath(__file__))
    with open(os.path.join(_data_dir, 'data', 'dataset1_9.pkl'), 'rb') as f:
        res = pickle.load(f)
    my_train_xys, my_test_xys = res['train'], res['test']
    my_test_digits = set(x for x, _ in my_test_xys)

    # digits string x -> set of known solution strings y
    all_train_mapping = defaultdict(set)
    all_test_mapping = defaultdict(set)
    with open(os.path.join(_data_dir, 'data', '24_game_all_data.txt'), 'r') as f:
        for line in f:
            line = line.rstrip('\n')
            all_data_set.add(line)
            all_data.append(line)
            digits_str, solution_str = split_str(line)
            if digits_str in my_test_digits:
                all_test_data.append(line)
                all_test_data_set.add(line)
                all_test_mapping[digits_str].add(solution_str)
            else:
                all_train_data.append(line)
                all_train_mapping[digits_str].add(solution_str)
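
    # After the loop above, all_train_mapping maps a sorted digits string to
    # the set of every known solution for those digits; e.g. (assuming
    # '3 4 8 9' falls in the train split) all_train_mapping['3 4 8 9']
    # contains '8 + 4 = 12, 9 + 3 = 12, 12 + 12 = 24 '.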

    def __init__(self, split, return_tokenized=True):
        self.return_tokenized = return_tokenized
        self.split = split  # 'train', 'test', or 'all'
        self.ixes = []
        if split == 'train':
            self.ixes = DatasetOf24Game.all_train_data
        elif split == 'test':
            self.ixes = DatasetOf24Game.all_test_data
        elif split == 'all':
            self.ixes = DatasetOf24Game.all_data
        else:
            raise ValueError("'split' must be 'all', 'train' or 'test'!")

    @staticmethod
    def get_vocab_size():
        return len(DatasetOf24Game.all_chars)

    @staticmethod
    def get_block_size():
        # the length of one example string
        return len(DatasetOf24Game.one_data_sample)

    def __len__(self):
        return len(self.ixes)
    def __getitem__(self, idx):
        # a data sample: "[4, 8, 9, 3]: 8 + 4 = 12, 9 + 3 = 12, 12 + 12 = 24 "
        s = self.ixes[idx]
        if not self.return_tokenized:
            return s
        dix = [DatasetOf24Game.ctoi[c] for c in s]  # map each character to its token index
        # x is the input to GPT and y holds the expected next-token targets
        x = torch.tensor(dix[:-1], dtype=torch.long)
        y = torch.tensor(dix[1:], dtype=torch.long)  # y is x shifted left by one
        # train only on the output locations: y[i] = dix[i + 1] is the target at
        # input position i, so the first solution character is the target at
        # index len(prompt) - 1; setting earlier targets to -1 masks their loss.
        y[:len(DatasetOf24Game.one_problem_sample) - 1] = -1
        return x, y
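
# A small sketch of what __getitem__ returns in tokenized mode (illustration
# only, not part of the dataset class): for a sample string s of length L,
# x and y are LongTensors of length L - 1, and y[:13] is -1 so the loss
# covers only the solution characters after the 14-char "[4, 8, 9, 3]: " prompt.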


if __name__ == '__main__':
    dataset = DatasetOf24Game('all')
    # print([dataset.ctoi[c] for c in '[4, 8, 9, 3]: 8 + 4 = 12, 9 + 3 = 12, 12 + 12 = 24 '])
    # print('[5, 5, 5, 5]: 5 + 5 = 10, 5 + 10 = 15, 15 + 9 = 24 ' in dataset.all_data_set)

    # raw (untokenized) samples come back as plain strings
    dataset = DatasetOf24Game('train', return_tokenized=False)
    print(type(dataset[515]))

    from torch.utils.data.dataloader import DataLoader
    loader = DataLoader(dataset, batch_size=10)
    for mini_batch in loader:
        print(mini_batch)
        print(type(mini_batch))
        break
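
    # A hedged sketch of the default tokenized path (an addition for
    # illustration): indexing returns the shifted (x, y) tensor pair, and the
    # itoc table decodes the input back to text. Batching tokenized samples
    # with the default collate would further assume every line of
    # 24_game_all_data.txt shares the same padded length.
    dataset = DatasetOf24Game('train')  # return_tokenized=True by default
    x, y = dataset[0]
    print(x.shape, y.shape)  # both have length len(sample) - 1
    print(''.join(DatasetOf24Game.itoc[int(i)] for i in x))  # decoded input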