This repository was archived by the owner on Aug 7, 2024. It is now read-only.

Expected trailing dimension of mat1 to be divisible by 16 but got mat1 shape #279

Closed
@msaroufim

Description


I wrote a toy training loop to get something going with fp8 and ran into this padding-related issue. I managed to solve it by replacing a single line in my code with texts = ["Example text input 1 bla bla bla bla bla bla bla bla bla.", "Example text input 2.", "Example text input 3."], but it took me about 10 minutes to hunt down. I figure this is a performance-related assert for tensor cores, in which case requiring padding makes sense.

After that change, I have a functioning hello-world example with the loss going down:

Epoch 1, Step 1, Loss: 8.910361289978027
Epoch 1, Step 2, Loss: 4.616391658782959
Epoch 2, Step 1, Loss: 2.377967119216919
Epoch 2, Step 2, Loss: 1.4298633337020874
Epoch 3, Step 1, Loss: 1.5666098594665527
Epoch 3, Step 2, Loss: 0.8038766384124756
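A rough way to see why the longer first text helps: `TextDataset` tokenizes all three texts together with `padding=True`, so every example is padded to one shared length `seq_len`, and `DataLoader(batch_size=2)` then yields one batch of 2 sequences and one batch of 1. The token counts per batch are therefore `2 * seq_len` and `seq_len`, and both are multiples of 16 only when `seq_len` itself is. The sketch below just does that arithmetic; `batch_token_counts` and `all_divisible` are hypothetical helper names, not library functions. The 7-token length for the short texts is an assumption chosen because it matches the `32000x14` in the error, and the 16-token length for the longer text is likewise assumed rather than measured.

```python
def batch_token_counts(num_examples: int, batch_size: int, seq_len: int) -> list[int]:
    """Tokens per batch when every example is padded to the same seq_len.

    Mirrors a DataLoader with drop_last=False: the final batch may be smaller.
    """
    counts = []
    for start in range(0, num_examples, batch_size):
        rows = min(batch_size, num_examples - start)
        counts.append(rows * seq_len)
    return counts


def all_divisible(counts: list[int], multiple: int = 16) -> bool:
    """True if every batch's token count is a multiple of `multiple`."""
    return all(c % multiple == 0 for c in counts)


# Short texts, assumed 7 tokens each (including BOS)
print(batch_token_counts(3, 2, 7))                  # [14, 7]
print(all_divisible(batch_token_counts(3, 2, 7)))   # False

# Longer first text, assumed to push the shared padded length to 16
print(batch_token_counts(3, 2, 16))                 # [32, 16]
print(all_divisible(batch_token_counts(3, 2, 16)))  # True
```

Note that the smaller last batch matters too: a padded length of 8 would make the first batch 16 tokens but leave the last batch at 8.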

Error

(ao) [marksaroufim@devgpu003.cco3 ~/float8_experimental/test (main)]$ HF_TOKEN="<redacted>" python fp8.py 
Loading checkpoint shards: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:01<00:00,  1.32it/s]
[rank0]: Traceback (most recent call last):
[rank0]:   File "/home/marksaroufim/float8_experimental/test/fp8.py", line 62, in <module>
[rank0]:     loss.backward()
[rank0]:   File "/home/marksaroufim/anaconda3/envs/ao/lib/python3.10/site-packages/torch/_tensor.py", line 521, in backward
[rank0]:     torch.autograd.backward(
[rank0]:   File "/home/marksaroufim/anaconda3/envs/ao/lib/python3.10/site-packages/torch/autograd/__init__.py", line 289, in backward
[rank0]:     _engine_run_backward(
[rank0]:   File "/home/marksaroufim/anaconda3/envs/ao/lib/python3.10/site-packages/torch/autograd/graph.py", line 768, in _engine_run_backward
[rank0]:     return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
[rank0]:   File "/home/marksaroufim/float8_experimental/float8_experimental/float8_tensor.py", line 297, in __torch_dispatch__
[rank0]:     return FLOAT8_OPS_TABLE[func](func, args, kwargs)
[rank0]:   File "/home/marksaroufim/float8_experimental/float8_experimental/float8_ops.py", line 151, in float8_mm
[rank0]:     tensor_out, amax = addmm_float8_unwrapped(
[rank0]:   File "/home/marksaroufim/float8_experimental/float8_experimental/float8_python_api.py", line 55, in addmm_float8_unwrapped
[rank0]:     output, output_amax = torch._scaled_mm(
[rank0]: RuntimeError: Expected trailing dimension of mat1 to be divisible by 16 but got mat1 shape: (32000x14.
[rank0]:[W612 11:03:16.562113367 ProcessGroupNCCL.cpp:1158] Warning: WARNING: process group has NOT been destroyed before we destruct ProcessGroupNCCL. On normal program exit, the application should call destroy_process_group to ensure that any pending NCCL operations have finished in this process. In rare cases this process can exit before this point and block the progress of another member of the process group. This constraint has always been present,  but this warning has only been added since PyTorch 2.4 (function operator())
(ao) [marksaroufim@devgpu003.cco3 ~/float8_experimental/test (main)]$ 
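The shape in the message is consistent with the backward pass of the final `lm_head` projection (hidden size 4096 to a 32000-entry vocabulary for Llama-2-7b). In the forward pass the matmul is `x @ W.T` with `x` of shape `(tokens, 4096)`, so the trailing dimension of `mat1` is 4096. In the backward pass the weight gradient is `grad_out.T @ x`, where `grad_out.T` has shape `(32000, tokens)`, so the token count (14 here) lands in the trailing dimension. The sketch below models only the single check quoted in the error; `check_mat1` and `lm_head_mat1_shapes` are hypothetical illustrations, not the float8_experimental or `torch._scaled_mm` code.

```python
HIDDEN = 4096  # Llama-2-7b hidden size
VOCAB = 32000  # Llama-2 vocabulary size (out_features of lm_head)


def check_mat1(mat1_shape: tuple[int, int]) -> None:
    """Model only the check named in the error: mat1's trailing dim must be a multiple of 16."""
    rows, cols = mat1_shape
    if cols % 16 != 0:
        raise RuntimeError(f"mat1 shape ({rows}x{cols}): trailing dimension not divisible by 16")


def lm_head_mat1_shapes(tokens: int) -> dict[str, tuple[int, int]]:
    """mat1 shape of each lm_head matmul for a flattened batch of `tokens` rows."""
    return {
        "forward: x @ W.T": (tokens, HIDDEN),
        "grad_input: grad_out @ W": (tokens, VOCAB),
        "grad_weight: grad_out.T @ x": (VOCAB, tokens),
    }


for name, shape in lm_head_mat1_shapes(tokens=14).items():
    try:
        check_mat1(shape)
        print(f"{name} -> ok")
    except RuntimeError as err:
        print(f"{name} -> {err}")
```

Running it reports the forward and input-gradient matmuls as ok and only the weight-gradient matmul as failing, which matches the forward pass completing and the crash surfacing in `loss.backward()`.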

Code

import torch
from torch.utils.data import DataLoader
from transformers import AutoModelForCausalLM, AutoTokenizer

# Load model and tokenizer
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf").to("cuda:7")

# Convert all torch.nn.Linear modules to Float8DynamicLinear
from float8_experimental.float8_linear_utils import swap_linear_with_float8_linear
from float8_experimental.float8_dynamic_linear import Float8DynamicLinear
swap_linear_with_float8_linear(model, Float8DynamicLinear)

# Initialize a single-process group for FSDP (the FSDP wrap itself is commented out below)
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
import os
os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = '12355'
os.environ['WORLD_SIZE'] = '1'
os.environ['RANK'] = '0'

dist.init_process_group(backend='nccl', init_method='env://')

# model = FSDP(model, use_orig_params=True)

# optionally compile the model
# model = torch.compile(model)

# Prepare your dataset and dataloader (customize this part as needed)
class TextDataset(torch.utils.data.Dataset):
    def __init__(self, texts, tokenizer):
        self.encodings = tokenizer(texts, return_tensors='pt', padding=True, truncation=True, max_length=512)

    def __getitem__(self, idx):
        return {key: val[idx] for key, val in self.encodings.items()}

    def __len__(self):
        return len(self.encodings.input_ids)

# Example text data
texts = ["Example text input 1.", "Example text input 2.", "Example text input 3."]
dataset = TextDataset(texts, tokenizer)
dataloader = DataLoader(dataset, batch_size=2)

# Set up the optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=5e-5)

# Training loop
model.train()
for epoch in range(3):  # Loop over the dataset multiple times
    for i, batch in enumerate(dataloader):
        inputs = {k: v.to(model.device) for k, v in batch.items()}
        
        # Forward pass
        outputs = model(**inputs, labels=inputs['input_ids'])
        loss = outputs.loss
        
        # Backward and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        print(f'Epoch {epoch + 1}, Step {i + 1}, Loss: {loss.item()}')

# Save the fine-tuned model
model.save_pretrained("./fine_tuned_model")

# Tear down the process group so NCCL does not warn at exit
dist.destroy_process_group()

print("Training complete!")
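Rather than tuning the example text until the lengths happen to line up, the padded length can be rounded up directly. Hugging Face tokenizers accept `pad_to_multiple_of=16` alongside `padding=True`; adding it to the `tokenizer(...)` call in `TextDataset.__init__` should make every sequence length, and therefore every batch's token count, a multiple of 16. The pure-Python sketch below shows the same idea on lists of token ids, for example inside a custom `collate_fn`. `round_up` and `pad_to_multiple` are hypothetical helpers; they right-pad with the given pad id and put zeros in the attention mask.

```python
def round_up(n: int, multiple: int = 16) -> int:
    """Smallest multiple of `multiple` that is >= n."""
    return -(-n // multiple) * multiple


def pad_to_multiple(
    sequences: list[list[int]], pad_id: int, multiple: int = 16
) -> tuple[list[list[int]], list[list[int]]]:
    """Right-pad token id lists to a shared length that is a multiple of `multiple`.

    Returns (input_ids, attention_mask); padded positions get mask 0.
    """
    target = round_up(max(len(s) for s in sequences), multiple)
    input_ids = [s + [pad_id] * (target - len(s)) for s in sequences]
    attention_mask = [[1] * len(s) + [0] * (target - len(s)) for s in sequences]
    return input_ids, attention_mask


# Two made-up 7-token sequences; the ids and pad_id are illustrative only
ids, mask = pad_to_multiple([[1, 5, 6, 7, 8, 9, 10], [1, 11, 12, 13, 14, 15, 16]], pad_id=2)
print(len(ids[0]), len(ids) * len(ids[0]))  # 16 32
```

Padding every sequence (rather than the whole batch) to a multiple of 16 keeps the token count divisible by 16 for any batch size, including a smaller final batch, at the cost of some wasted compute on pad tokens.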

Metadata


Assignees

No one assigned

Labels

documentation (Improvements or additions to documentation)
