# Optimize the Trompt example to reduce training time by ~30% (#477)
## Summary
Reduces training time by 29% by:
* removing unnecessary stream and D2H synchronisations
* replacing `einsum` with `matmul` (see the equivalence sketch after this list)
* making host-to-device data copies asynchronous
* using TF32 for matmuls
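
The `einsum`-to-`matmul` change is a drop-in replacement: the contraction over the shared channel dimension is exactly a batched matrix multiplication against the transposed column embeddings. A minimal sketch of the equivalence, using made-up shapes rather than the example's real tensors:

```python
import torch

# Hypothetical shapes for illustration only; in TromptConv these tensors are
# the stacked prompt embeddings and stacked column embeddings.
batch_size, num_prompts, num_cols, channels = 32, 128, 54, 16
stacked_e_prompt = torch.randn(batch_size, num_prompts, channels)
stacked_e_column = torch.randn(batch_size, num_cols, channels)

# Before: contract the channel dimension with einsum.
# [batch_size, num_prompts, channels], [batch_size, num_cols, channels]
# -> [batch_size, num_prompts, num_cols]
via_einsum = torch.einsum('ijl,ikl->ijk', stacked_e_prompt, stacked_e_column)

# After (as in this PR): the same contraction written as a batched matmul.
via_matmul = stacked_e_prompt @ stacked_e_column.transpose(1, 2)

assert torch.allclose(via_einsum, via_matmul, atol=1e-5)
```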

## Benchmark results

Benchmarked the change using the `jannis` dataset on a `g6.4xlarge` instance, which has an
L4 GPU. To reproduce, run:
```console
$ python examples/trompt.py --dataset jannis
$ python examples/trompt.py --dataset jannis --compile
```

Average time per training step:
| | master (`A`) | **this PR** (`B`) | `B/A` |
|:-:|:-:|:-:|:-:|
|eager|742.79 ms|**741.22** ms|0.99x|
|compile|262.24 ms|**186.80** ms|0.71x|

Average time per evaluation step:
| | master (`A`) | **this PR** (`B`) | `B/A` |
|:-:|:-:|:-:|:-:|
|eager|303.76 ms|**302.73** ms|1.00x|
|compile|58.75 ms|**56.27** ms|0.96x|



---

Happy holidays! 🎄🧹
akihironitta authored Jan 3, 2025
1 parent 17b5507 commit da158fb
Showing 2 changed files with 56 additions and 33 deletions.
`examples/trompt.py` — 86 changes: 55 additions & 31 deletions
```diff
@@ -14,7 +14,6 @@
 helena : 37.90
 jannis : 72.98
 """
-
 import argparse
 import os.path as osp
 
@@ -27,6 +26,10 @@
 from torch_frame.datasets import TabularBenchmark
 from torch_frame.nn import Trompt
 
+# Use TF32 for faster matrix multiplication on Ampere GPUs.
+# https://dev-discuss.pytorch.org/t/pytorch-and-tensorfloat32/504
+torch.set_float32_matmul_precision('high')
+
 parser = argparse.ArgumentParser()
 parser.add_argument("--dataset", type=str, default="california")
 parser.add_argument("--channels", type=int, default=128)
@@ -64,12 +67,23 @@
 train_tensor_frame = train_dataset.tensor_frame
 val_tensor_frame = val_dataset.tensor_frame
 test_tensor_frame = test_dataset.tensor_frame
-train_loader = DataLoader(train_tensor_frame, batch_size=args.batch_size,
-                          shuffle=True)
-val_loader = DataLoader(val_tensor_frame, batch_size=args.batch_size)
-test_loader = DataLoader(test_tensor_frame, batch_size=args.batch_size)
+train_loader = DataLoader(
+    train_tensor_frame,
+    batch_size=args.batch_size,
+    shuffle=True,
+    pin_memory=True,
+)
+val_loader = DataLoader(
+    val_tensor_frame,
+    batch_size=args.batch_size,
+    pin_memory=True,
+)
+test_loader = DataLoader(
+    test_tensor_frame,
+    batch_size=args.batch_size,
+    pin_memory=True,
+)
 
 # Set up model and optimizer
 model = Trompt(
     channels=args.channels,
     out_channels=dataset.num_classes,
@@ -79,59 +93,69 @@
     col_names_dict=train_tensor_frame.col_names_dict,
 ).to(device)
 model = torch.compile(model, dynamic=True) if args.compile else model
-optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
+optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, fused=True)
 lr_scheduler = ExponentialLR(optimizer, gamma=0.95)
 
 
-def train(epoch: int) -> float:
+def train(epoch: int) -> torch.Tensor:
     model.train()
-    loss_accum = total_count = 0
+    loss_accum = torch.zeros(1, device=device, dtype=torch.float32).squeeze_()
+    total_count = 0
 
-    for tf in tqdm(train_loader, desc=f"Epoch: {epoch}"):
-        tf = tf.to(device)
+    for tf in tqdm(train_loader, desc=f"Epoch {epoch:3d}"):
+        tf = tf.to(device, non_blocking=True)
         # [batch_size, num_layers, num_classes]
         out = model(tf)
-        num_layers = out.size(1)
+        batch_size, num_layers, num_classes = out.size()
         # [batch_size * num_layers, num_classes]
-        pred = out.view(-1, dataset.num_classes)
-        y = tf.y.repeat_interleave(num_layers)
+        pred = out.view(-1, num_classes)
+        y = tf.y.repeat_interleave(
+            num_layers,
+            output_size=num_layers * batch_size,
+        )
         # Layer-wise logit loss
         loss = F.cross_entropy(pred, y)
-        optimizer.zero_grad()
         loss.backward()
-        loss_accum += float(loss) * len(tf.y)
-        total_count += len(tf.y)
         optimizer.step()
+        optimizer.zero_grad()
+
+        total_count += len(tf.y)
+        loss *= len(tf.y)
+        loss_accum += loss
 
+    lr_scheduler.step()
     return loss_accum / total_count
 
 
 @torch.no_grad()
-def test(loader: DataLoader) -> float:
+def test(loader: DataLoader, desc: str) -> torch.Tensor:
     model.eval()
-    accum = total_count = 0
+    accum = torch.zeros(1, device=device, dtype=torch.long).squeeze_()
+    total_count = 0
 
-    for tf in loader:
-        tf = tf.to(device)
+    for tf in tqdm(loader, desc=desc):
+        tf = tf.to(device, non_blocking=True)
         pred = model(tf).mean(dim=1)
         pred_class = pred.argmax(dim=-1)
-        accum += float((tf.y == pred_class).sum())
+        accum += (tf.y == pred_class).sum()
         total_count += len(tf.y)
 
     return accum / total_count
 
 
-best_val_acc = 0
-best_test_acc = 0
+best_val_acc = 0.0
+best_test_acc = 0.0
 for epoch in range(1, args.epochs + 1):
     train_loss = train(epoch)
-    train_acc = test(train_loader)
-    val_acc = test(val_loader)
-    test_acc = test(test_loader)
+    train_acc = test(train_loader, "Eval (train)")
+    val_acc = test(val_loader, "Eval (val)")
     if best_val_acc < val_acc:
         best_val_acc = val_acc
-        best_test_acc = test_acc
-    print(f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}, "
-          f"Val Acc: {val_acc:.4f}, Test Acc: {test_acc:.4f}")
-    lr_scheduler.step()
+        best_test_acc = test(test_loader, "Eval (test)")
+
+    print(f"Train Loss: {train_loss:.4f}, "
+          f"Train Acc: {train_acc:.4f}, "
+          f"Val Acc: {val_acc:.4f}, "
+          f"Test Acc: {best_test_acc:.4f}")
 
 print(f"Best Val Acc: {best_val_acc:.4f}, Best Test Acc: {best_test_acc:.4f}")
```
`torch_frame/nn/conv/trompt_conv.py` — 3 changes: 1 addition & 2 deletions
```diff
@@ -92,8 +92,7 @@ def forward(self, x: Tensor, x_prompt: Tensor) -> Tensor:
         # M_importance
         # [batch_size, num_prompts, channels], [batch_size, num_cols, channels]
         # -> [batch_size, num_prompts, num_cols]
-        m_importance = torch.einsum('ijl,ikl->ijk', stacked_e_prompt,
-                                    stacked_e_column)
+        m_importance = stacked_e_prompt @ stacked_e_column.transpose(1, 2)
         m_importance = F.softmax(m_importance, dim=-1)
         # [batch_size, num_prompts, num_cols, 1]
         m_importance = m_importance.unsqueeze(dim=-1)
```
