This repository has been archived by the owner on Aug 7, 2024. It is now read-only.
Expected trailing dimension of mat1 to be divisible by 16 but got mat1 shape #279
Labels: documentation (Improvements or additions to documentation)

Comments
Agreed, we should at the very least document this upfront and look into whether a good padding implementation is viable.
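For anyone hitting this before a built-in fix lands, the padding itself is a one-liner. Below is a hypothetical sketch (not the library's implementation; `pad_trailing_dim` is a made-up name) of zero-padding a tensor's last dimension up to a multiple of 16:

```python
import torch
import torch.nn.functional as F

def pad_trailing_dim(t: torch.Tensor, multiple: int = 16) -> torch.Tensor:
    """Zero-pad the last dimension of t up to the next multiple of `multiple`."""
    remainder = t.shape[-1] % multiple
    if remainder == 0:
        return t
    # F.pad takes (left, right) padding amounts for the last dimension.
    return F.pad(t, (0, multiple - remainder))

x = torch.randn(4, 10)
print(pad_trailing_dim(x).shape)  # torch.Size([4, 16])
```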
I made this small util a very long time ago; let me retry with your example.
With the changes in that PR, and using SGD:

```python
import torch
from torch.utils.data import DataLoader
from transformers import AutoModelForCausalLM, AutoTokenizer
# Load model and tokenizer
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf").to("cuda")
# Convert all torch.nn.Linear modules to Float8DynamicLinear
from float8_experimental.float8_linear_utils import swap_linear_with_float8_linear
from float8_experimental.float8_dynamic_linear import Float8DynamicLinear
import float8_experimental
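# pad_inner_dim pads the fp8 matmul operands so their inner dimension is a
# multiple of 16, which is the constraint behind the error in this issue's title.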
float8_experimental.config.pad_inner_dim = True
swap_linear_with_float8_linear(model, Float8DynamicLinear)
# Wrap model with Fully Sharded Data Parallel (FSDP)
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
import os
os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = '12355'
os.environ['WORLD_SIZE'] = '1'
os.environ['RANK'] = '0'
dist.init_process_group(backend='nccl', init_method='env://')
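# The env vars above configure a single-process group (WORLD_SIZE=1, RANK=0)
# so the optional FSDP wrap below can be exercised locally.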
# model = FSDP(model, use_orig_params=True)
# optionally compile the model
# model = torch.compile(model)
# Prepare your dataset and dataloader (customize this part as needed)
class TextDataset(torch.utils.data.Dataset):
    def __init__(self, texts, tokenizer):
        self.encodings = tokenizer(texts, return_tensors='pt', padding=True, truncation=True, max_length=512)

    def __getitem__(self, idx):
        return {key: val[idx] for key, val in self.encodings.items()}

    def __len__(self):
        return len(self.encodings.input_ids)
# Example text data
texts = ["Example text input 1.", "Example text input 2.", "Example text input 3."]
dataset = TextDataset(texts, tokenizer)
dataloader = DataLoader(dataset, batch_size=2)
# Set up the optimizer
# optimizer = torch.optim.Adam(model.parameters(), lr=5e-5)
optimizer = torch.optim.SGD(model.parameters(), lr=5e-4)
# Training loop
model.train()
for epoch in range(3):  # Loop over the dataset multiple times
    for i, batch in enumerate(dataloader):
        inputs = {k: v.to(model.device) for k, v in batch.items()}

        # Forward pass
        outputs = model(**inputs, labels=inputs['input_ids'])
        loss = outputs.loss

        # Backward and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        print(f'Epoch {epoch + 1}, Step {i + 1}, Loss: {loss.item()}')
# Save the fine-tuned model
model.save_pretrained("./fine_tuned_model")
print("Training complete!") All works, that being said we have a fuse "scaled_mm" things in flight and need to add a few more tests |
I wrote a toy training loop to get something going with fp8 and ran into this padding-related issue. I managed to solve it by replacing a single line in my code with

```python
texts = ["Example text input 1 bla bla bla bla bla bla bla bla bla.", "Example text input 2.", "Example text input 3."]
```

but it took me about 10 minutes to hunt down. I figure this is some performance-related assert for tensor cores, in which case padding feels like it makes sense. After that, I now have a functioning hello-world example with the loss going down.
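If the dimension tripping the assert depends on the tokenized sequence length (which the text-lengthening fix above suggests), a less hacky workaround may be to ask the tokenizer to pad to a multiple of 16. This is a sketch using the standard `pad_to_multiple_of` argument of Hugging Face tokenizers, assuming the `tokenizer` and `texts` from the script above; it has not been verified against this exact repro:

```python
# Pad each batch to a sequence length that is a multiple of 16 instead of
# padding the raw text. Assumes `tokenizer` and `texts` from the script above.
encodings = tokenizer(
    texts,
    return_tensors='pt',
    padding=True,
    truncation=True,
    max_length=512,
    pad_to_multiple_of=16,
)
print(encodings['input_ids'].shape)  # sequence length is now a multiple of 16
```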