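"""Train a small decoder-only transformer language model on verne.txt.

The corpus is tokenized with tiktoken's GPT-2 BPE encoding and split 90/10 into
train/validation sets. The model is optimized with AdamW; on CUDA the loop also
uses torch.compile and autocast/GradScaler mixed precision. A checkpoint is
saved to checkpoint_dir every eval_every iterations, just before the periodic
train/val loss report.
"""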
import os
from datetime import datetime

import fire
import tiktoken
import torch
from torch.cuda.amp import GradScaler, autocast
from torch.utils.data import DataLoader, random_split
from tqdm import tqdm

from dataset import TransformerDataset
from decoder_transformer import DecoderTransformer


def train(
    iters: int = 1000,
    batch_size: int = 32,
    lr: float = 3e-4,
    # Prefer CUDA, then Apple MPS, then fall back to CPU.
    device: str | torch.device = (
        "cuda"
        if torch.cuda.is_available()
        else ("mps" if torch.backends.mps.is_available() else "cpu")
    ),
    checkpoint_dir: str = "checkpoints",
    eval_every: int = 100,
):
    device = torch.device(device)
    batch_size = int(batch_size)  # defensive cast in case a string or float is passed
    learning_rate = lr
    eval_every = eval_every or iters // 10  # 0/None means "evaluate 10 times per run"
    eval_iters = iters // 10

    encoding = tiktoken.get_encoding("gpt2")
    with open("verne.txt", "r") as f:
        text = f.read()
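
    # Model hyperparameters: a small GPT-style decoder (6 blocks, 6 attention
    # heads, 384-dim embeddings, 256-token context) over the GPT-2 BPE vocabulary.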
    vocab_size = encoding.n_vocab
    embed_size = 384
    context_size = 256
    num_heads = 6
    num_blocks = 6
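
    # Hold out 10% of the samples for validation; both loaders shuffle.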
    dataset = TransformerDataset(text, encoding, context_size)
    train_data, val_data = random_split(
        dataset, [int(len(dataset) * 0.9), len(dataset) - int(len(dataset) * 0.9)]
    )
    train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_data, batch_size=batch_size, shuffle=True)
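
    # Build the model, compile it on CUDA, and set up AdamW plus a GradScaler;
    # autocast and the scaler are only enabled when CUDA is available.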
    start_time = datetime.now().strftime("%Y%m%d_%H%M")
    model = DecoderTransformer(
        num_blocks=num_blocks,
        num_heads=num_heads,
        embed_size=embed_size,
        context_size=context_size,
        vocab_size=vocab_size,
    ).to(device)
    print(f"Loaded model with {sum(p.numel() for p in model.parameters())} parameters")
    model = torch.compile(model) if torch.cuda.is_available() else model
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
    scaler = GradScaler(enabled=torch.cuda.is_available())
    progress_bar = tqdm(range(1, iters + 1))
    os.makedirs(checkpoint_dir, exist_ok=True)

    # Train the model.
    train_loader_iter = iter(train_loader)
    for i in progress_bar:
        try:
            x, y = next(train_loader_iter)
        except StopIteration:
            # The loader is exhausted; start a fresh pass so any value of iters works.
            train_loader_iter = iter(train_loader)
            x, y = next(train_loader_iter)
        model.train()
        optimizer.zero_grad()
        with autocast(enabled=torch.cuda.is_available()):
            _, loss = model(x.to(device), y.to(device))
        if torch.cuda.is_available():
            # Mixed-precision path: scale the loss before backward, then step via the scaler.
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            optimizer.step()
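
        # Every eval_every steps: checkpoint, then report mean train/val loss.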
        if i % eval_every == 0:
            # Save the current model weights (the checkpoint is overwritten each time).
            torch.save(
                model.state_dict(),
                os.path.join(checkpoint_dir, f"{start_time}_model_state.pt"),
            )
            # Evaluate the model on both splits.
            model.eval()
            with torch.no_grad():
                eval_losses = []
                for loader in [train_loader, val_loader]:
                    losses = []
                    # Average over at most eval_iters batches; stop early if the loader runs out.
                    for batch_idx, (x, y) in enumerate(loader):
                        if batch_idx >= eval_iters:
                            break
                        with autocast(enabled=torch.cuda.is_available()):
                            _, loss = model(x.to(device), y.to(device))
                        losses.append(loss.item())
                    eval_losses.append(sum(losses) / len(losses))
                progress_bar.set_postfix_str(
                    f"Train Loss: {eval_losses[0]:.4f}, Val Loss: {eval_losses[1]:.4f}"
                )


# Run the CLI: python-fire turns train() into a command-line entry point.
if __name__ == "__main__":
    fire.Fire(train)
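
# Example invocation (fire maps CLI flags onto train()'s keyword arguments, so
# the flag names mirror the parameters above; values shown are illustrative):
#
#   python train.py --iters 2000 --batch_size 16 --lr 1e-4 --eval_every 200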