From da158fb5540803a4ff4701d62b9280232472398d Mon Sep 17 00:00:00 2001
From: Akihiro Nitta
Date: Fri, 3 Jan 2025 13:02:21 +0900
Subject: [PATCH] Optimize the `Trompt` example to reduce training time by ~30% (#477)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

## Summary

Reduces training time per step by ~29% (with `torch.compile`; eager mode is unchanged) by:

* removing unnecessary stream and device-to-host (D2H) synchronisations (see the note after the diff)
* replacing `einsum` with `matmul` (see the equivalence sketch below)
* making host-to-device data copies asynchronous with respect to the host (`pin_memory=True` + `non_blocking=True`)
* using TF32 for matmuls
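The `einsum`-to-`matmul` swap is a pure rewrite of the same contraction over the channel dimension, so it should not change numerics (beyond TF32 rounding once that is enabled). A minimal sketch of the equivalence — the shapes here are illustrative stand-ins for the real tensors in `TromptConv.forward`:

```python
import torch

# Illustrative shapes; the real tensors come from TromptConv.forward().
batch_size, num_prompts, num_cols, channels = 8, 16, 10, 32
stacked_e_prompt = torch.randn(batch_size, num_prompts, channels)
stacked_e_column = torch.randn(batch_size, num_cols, channels)

# Before: a generic einsum contraction over the channel dimension ('l').
old = torch.einsum('ijl,ikl->ijk', stacked_e_prompt, stacked_e_column)

# After: the same contraction expressed as a batched matmul, which
# dispatches straight to a bmm kernel.
new = stacked_e_prompt @ stacked_e_column.transpose(1, 2)

assert torch.allclose(old, new, atol=1e-6)
```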
## Benchmark results

Benchmarked the change using the `jannis` dataset on a `g6.4xlarge` instance, which has an NVIDIA L4 GPU. To reproduce, run:

```console
$ python examples/trompt.py --dataset jannis
$ python examples/trompt.py --dataset jannis --compile
```

Average time per training step:

| | master (`A`) | **this PR** (`B`) | `B/A` |
|:-:|:-:|:-:|:-:|
|eager|742.79 ms|**741.22** ms|0.99x|
|compile|262.24 ms|**186.80** ms|0.71x|

Average time per evaluation step:

| | master (`A`) | **this PR** (`B`) | `B/A` |
|:-:|:-:|:-:|:-:|
|eager|303.76 ms|**302.73** ms|1.00x|
|compile|58.75 ms|**56.27** ms|0.96x|

---

Happy holidays! 🎄🧹
---
 examples/trompt.py                 | 86 +++++++++++++++++++-----------
 torch_frame/nn/conv/trompt_conv.py |  3 +-
 2 files changed, 56 insertions(+), 33 deletions(-)

diff --git a/examples/trompt.py b/examples/trompt.py
index 342ab653e..41d53d379 100644
--- a/examples/trompt.py
+++ b/examples/trompt.py
@@ -14,7 +14,6 @@
 helena : 37.90
 jannis : 72.98
 """
-
 import argparse
 import os.path as osp
 
@@ -27,6 +26,10 @@
 from torch_frame.datasets import TabularBenchmark
 from torch_frame.nn import Trompt
 
+# Use TF32 for faster matrix multiplication on Ampere GPUs.
+# https://dev-discuss.pytorch.org/t/pytorch-and-tensorfloat32/504
+torch.set_float32_matmul_precision('high')
+
 parser = argparse.ArgumentParser()
 parser.add_argument("--dataset", type=str, default="california")
 parser.add_argument("--channels", type=int, default=128)
@@ -64,12 +67,23 @@
 train_tensor_frame = train_dataset.tensor_frame
 val_tensor_frame = val_dataset.tensor_frame
 test_tensor_frame = test_dataset.tensor_frame
-train_loader = DataLoader(train_tensor_frame, batch_size=args.batch_size,
-                          shuffle=True)
-val_loader = DataLoader(val_tensor_frame, batch_size=args.batch_size)
-test_loader = DataLoader(test_tensor_frame, batch_size=args.batch_size)
+train_loader = DataLoader(
+    train_tensor_frame,
+    batch_size=args.batch_size,
+    shuffle=True,
+    pin_memory=True,
+)
+val_loader = DataLoader(
+    val_tensor_frame,
+    batch_size=args.batch_size,
+    pin_memory=True,
+)
+test_loader = DataLoader(
+    test_tensor_frame,
+    batch_size=args.batch_size,
+    pin_memory=True,
+)
 
-# Set up model and optimizer
 model = Trompt(
     channels=args.channels,
     out_channels=dataset.num_classes,
@@ -79,59 +93,69 @@
     col_names_dict=train_tensor_frame.col_names_dict,
 ).to(device)
 model = torch.compile(model, dynamic=True) if args.compile else model
-optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
+optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, fused=True)
 lr_scheduler = ExponentialLR(optimizer, gamma=0.95)
 
 
-def train(epoch: int) -> float:
+def train(epoch: int) -> torch.Tensor:
     model.train()
-    loss_accum = total_count = 0
+    loss_accum = torch.zeros(1, device=device, dtype=torch.float32).squeeze_()
+    total_count = 0
 
-    for tf in tqdm(train_loader, desc=f"Epoch: {epoch}"):
-        tf = tf.to(device)
+    for tf in tqdm(train_loader, desc=f"Epoch {epoch:3d}"):
+        tf = tf.to(device, non_blocking=True)
         # [batch_size, num_layers, num_classes]
         out = model(tf)
-        num_layers = out.size(1)
+        batch_size, num_layers, num_classes = out.size()
         # [batch_size * num_layers, num_classes]
-        pred = out.view(-1, dataset.num_classes)
-        y = tf.y.repeat_interleave(num_layers)
+        pred = out.view(-1, num_classes)
+        y = tf.y.repeat_interleave(
+            num_layers,
+            output_size=num_layers * batch_size,
+        )
         # Layer-wise logit loss
         loss = F.cross_entropy(pred, y)
-        optimizer.zero_grad()
         loss.backward()
-        loss_accum += float(loss) * len(tf.y)
-        total_count += len(tf.y)
         optimizer.step()
+        optimizer.zero_grad()
+
+        total_count += len(tf.y)
+        loss *= len(tf.y)
+        loss_accum += loss
+
+    lr_scheduler.step()
     return loss_accum / total_count
 
 
 @torch.no_grad()
-def test(loader: DataLoader) -> float:
+def test(loader: DataLoader, desc: str) -> torch.Tensor:
     model.eval()
-    accum = total_count = 0
+    accum = torch.zeros(1, device=device, dtype=torch.long).squeeze_()
+    total_count = 0
 
-    for tf in loader:
-        tf = tf.to(device)
+    for tf in tqdm(loader, desc=desc):
+        tf = tf.to(device, non_blocking=True)
         pred = model(tf).mean(dim=1)
         pred_class = pred.argmax(dim=-1)
-        accum += float((tf.y == pred_class).sum())
+        accum += (tf.y == pred_class).sum()
         total_count += len(tf.y)
 
     return accum / total_count
 
 
-best_val_acc = 0
-best_test_acc = 0
+best_val_acc = 0.0
+best_test_acc = 0.0
 for epoch in range(1, args.epochs + 1):
     train_loss = train(epoch)
-    train_acc = test(train_loader)
-    val_acc = test(val_loader)
-    test_acc = test(test_loader)
+    train_acc = test(train_loader, "Eval (train)")
+    val_acc = test(val_loader, "Eval (val)")
     if best_val_acc < val_acc:
         best_val_acc = val_acc
-        best_test_acc = test_acc
-    print(f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}, "
-          f"Val Acc: {val_acc:.4f}, Test Acc: {test_acc:.4f}")
-    lr_scheduler.step()
+        best_test_acc = test(test_loader, "Eval (test)")
+
+    print(f"Train Loss: {train_loss:.4f}, "
+          f"Train Acc: {train_acc:.4f}, "
+          f"Val Acc: {val_acc:.4f}, "
+          f"Test Acc: {best_test_acc:.4f}")
 
 print(f"Best Val Acc: {best_val_acc:.4f}, Best Test Acc: {best_test_acc:.4f}")
diff --git a/torch_frame/nn/conv/trompt_conv.py b/torch_frame/nn/conv/trompt_conv.py
index 1627ccc31..a41f1136a 100644
--- a/torch_frame/nn/conv/trompt_conv.py
+++ b/torch_frame/nn/conv/trompt_conv.py
@@ -92,8 +92,7 @@ def forward(self, x: Tensor, x_prompt: Tensor) -> Tensor:
         # M_importance
         # [batch_size, num_prompts, channels], [batch_size, num_cols, channels]
         # -> [batch_size, num_prompts, num_cols]
-        m_importance = torch.einsum('ijl,ikl->ijk', stacked_e_prompt,
-                                    stacked_e_column)
+        m_importance = stacked_e_prompt @ stacked_e_column.transpose(1, 2)
         m_importance = F.softmax(m_importance, dim=-1)
         # [batch_size, num_prompts, num_cols, 1]
         m_importance = m_importance.unsqueeze(dim=-1)
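
A note on the first bullet: `float(loss)` (and likewise `float((tf.y == pred_class).sum())` in `test()`) copies a scalar from device to host, which forces the host to wait for all GPU work queued so far — once per step. Keeping the accumulators on the device defers that sync to a single read per epoch. A minimal sketch of the pattern, with loop bounds and batch size as illustrative stand-ins (the patch builds the accumulator via `torch.zeros(1, ...).squeeze_()`; a 0-dim `torch.zeros(())` is used here for brevity):

```python
import torch

# Illustrative stand-ins for the real training loop.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
batch_size, num_steps = 32, 3

# Before: `loss_accum += float(loss) * batch_size` syncs the host every step.
# After: accumulate on the device and read the result back once per epoch.
loss_accum = torch.zeros((), device=device)
total_count = 0
for _ in range(num_steps):
    loss = torch.rand((), device=device)  # stand-in for F.cross_entropy(...)
    loss_accum += loss * batch_size       # stays on the GPU; no host sync
    total_count += batch_size
epoch_loss = (loss_accum / total_count).item()  # single D2H sync per epoch
```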