
Using Trainer + a pretrained tokenizer + 4D attention mask is extremely slow #32101

Open
@JackCai1206


System Info

transformers 4.41.0

Who can help?

@ArthurZucker

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

from transformers import LlamaForCausalLM, LlamaConfig, TrainingArguments, Trainer, AutoTokenizer
from datasets import IterableDataset
import numpy as np

model_config = LlamaConfig(
    vocab_size=10,
    hidden_size=384,
    num_hidden_layers=6,
    num_attention_heads=6,
    intermediate_size=1024,
    max_position_embeddings=1024,
)
model = LlamaForCausalLM(model_config)
tokenizer = AutoTokenizer.from_pretrained('facebook/opt-125m')

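def get_data1():
    # Slow case: each example carries a (1, 1024, 1024) float mask,
    # which batching turns into a 4D (batch, 1, seq, seq) attention mask.
    for i in range(10000):
        yield {'input_ids': np.zeros(1024, dtype=int), 'labels': np.zeros(1024, dtype=int), 'attention_mask': np.zeros((1, 1024, 1024), dtype=float)}

def get_data2():
    # Fast case: the usual 1D (seq,) integer mask per example.
    for i in range(10000):
        yield {'input_ids': np.zeros(1024, dtype=int), 'labels': np.zeros(1024, dtype=int), 'attention_mask': np.zeros((1024), dtype=int)}
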
ds_slow = IterableDataset.from_generator(get_data1).with_format('torch')
ds_fast = IterableDataset.from_generator(get_data2).with_format('torch')

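training_args = TrainingArguments(
    max_steps=1,
    output_dir='./out',
    report_to='none',  # the string 'none' disables integrations; None means "all installed" in 4.41
    per_device_train_batch_size=32,
    gradient_accumulation_steps=32,  # one optimizer step = 32 micro-batches
)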
trainer1 = Trainer(model, training_args, train_dataset=ds_slow, tokenizer=tokenizer)
trainer2 = Trainer(model, training_args, train_dataset=ds_fast, tokenizer=tokenizer)

import cProfile
cProfile.run('trainer1.train()', './test_slow.profile')
cProfile.run('trainer2.train()', './test_fast.profile')
import pstats

# compare the two profiles
p1 = pstats.Stats('./test_slow.profile')
p2 = pstats.Stats('./test_fast.profile')
p1.sort_stats('cumtime').print_stats()
         1582200 function calls (1401111 primitive calls) in 340.112 seconds

   Ordered by: cumulative time

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
        1    0.000    0.000  340.112  340.112 {built-in method builtins.exec}
        1    0.000    0.000  340.112  340.112 <string>:1(<module>)
        1    0.000    0.000  340.112  340.112 trainer.py:1784(train)
        1    0.017    0.017  340.112  340.112 trainer.py:1892(_inner_training_loop)
       33    0.001    0.000  326.171    9.884 data_loader.py:663(__iter__)
       33    0.001    0.000  325.473    9.863 data_loader.py:618(_fetch_batches)
 2486/265    0.001    0.000  325.428    1.228 {built-in method builtins.next}
       33    0.001    0.000  325.088    9.851 dataloader.py:625(__next__)
       33    0.725    0.022  325.083    9.851 dataloader.py:672(_next_data)
       33    0.002    0.000  323.988    9.818 fetch.py:24(fetch)
       33    0.000    0.000  320.979    9.727 trainer_utils.py:807(__call__)
       33    0.000    0.000  320.971    9.726 data_collator.py:270(__call__)
       33   16.982    0.515  320.971    9.726 data_collator.py:52(pad_without_fast_tokenizer_warning)
       33    0.005    0.000  303.989    9.212 tokenization_utils_base.py:3209(pad)
     6493  235.747    0.036  235.747    0.036 {built-in method torch.tensor}
      197    0.001    0.000  234.735    1.192 tokenization_utils_base.py:204(__init__)
      197    0.001    0.000  234.732    1.192 tokenization_utils_base.py:681(convert_to_tensors)
       99    0.000    0.000  234.730    2.371 tokenization_utils_base.py:718(as_tensor)
p2.sort_stats('cumtime').print_stats()
        1567440 function calls (1386340 primitive calls) in 16.431 seconds

   Ordered by: cumulative time

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
        1    0.000    0.000   16.431   16.431 {built-in method builtins.exec}
        1    0.000    0.000   16.431   16.431 <string>:1(<module>)
        1    0.000    0.000   16.431   16.431 trainer.py:1784(train)
        1    0.018    0.018   16.431   16.431 trainer.py:1892(_inner_training_loop)
       32    0.003    0.000   14.327    0.448 trainer.py:3212(training_step)
       32    0.001    0.000    8.830    0.276 accelerator.py:2093(backward)
       32    0.000    0.000    8.829    0.276 _tensor.py:433(backward)
       32    0.000    0.000    8.829    0.276 __init__.py:149(backward)
       32    8.827    0.276    8.827    0.276 {method 'run_backward' of 'torch._C._EngineBase' objects}
       33    0.000    0.000    4.546    0.138 memory.py:147(empty_cache)
       33    4.546    0.138    4.546    0.138 {built-in method torch._C._cuda_emptyCache}
 2486/265    0.001    0.000    1.469    0.006 {built-in method builtins.next}
       33    0.001    0.000    1.160    0.035 data_loader.py:663(__iter__)
       33    0.000    0.000    1.145    0.035 data_loader.py:618(_fetch_batches)
       33    0.000    0.000    1.136    0.034 dataloader.py:625(__next__)
       33    0.003    0.000    1.134    0.034 dataloader.py:672(_next_data)
       33    0.002    0.000    1.124    0.034 fetch.py:24(fetch)
       32    0.000    0.000    0.955    0.030 trainer.py:3254(compute_loss)
...
        1    0.000    0.000    0.000    0.000 modeling_utils.py:903(_
...

Expected behavior

Since the profiler output is very long, I have only included the first few lines of each report.
I am training a small Llama model on dummy data. The only difference between the two datasets is that the slow one produces 4D attention masks once batched, a feature added recently in #27539. Both trainers run for a single optimizer step (max_steps=1), which with gradient_accumulation_steps=32 means 32 micro-batches each.
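To put the difference in scale, the sketch below counts the mask elements per batch under the reproduction's settings (per_device_train_batch_size=32, sequence length 1024, gradient_accumulation_steps=32). This is plain arithmetic, not a measurement.

```python
# Settings taken from the reproduction script above.
batch_size = 32   # per_device_train_batch_size
seq_len = 1024    # length of input_ids / labels in every example
grad_accum = 32   # gradient_accumulation_steps -> micro-batches per optimizer step

# Fast dataset: one (seq_len,) mask per example -> (32, 1024) per batch.
elements_2d = batch_size * seq_len
# Slow dataset: one (1, seq_len, seq_len) mask per example -> (32, 1, 1024, 1024) per batch.
elements_4d = batch_size * 1 * seq_len * seq_len

print(f"2D mask elements per batch: {elements_2d:,}")
print(f"4D mask elements per batch: {elements_4d:,}")
print(f"ratio: {elements_4d // elements_2d}x")
print(f"4D mask elements per optimizer step: {elements_4d * grad_accum:,}")
```

Because the per-example 4D mask grows with the square of the sequence length, any per-element Python work done by the collator grows with it too.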

As the profiles show, the slow run takes about 340 s while the fast one takes about 16 s.
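The totals in the two cProfile reports can be turned into a slowdown factor and per-function shares of the slow run; the numbers below are copied directly from the reports and nothing else is assumed.

```python
# Totals (seconds) copied from the two cProfile reports above.
slow_total = 340.112         # trainer1.train(), 4D float masks
fast_total = 16.431          # trainer2.train(), 1D integer masks
slow_collator = 320.971      # data_collator.py:270(__call__), slow run
slow_torch_tensor = 235.747  # {built-in method torch.tensor}, slow run

slowdown = slow_total / fast_total
collator_share = slow_collator / slow_total
torch_tensor_share = slow_torch_tensor / slow_total

print(f"slowdown: {slowdown:.1f}x")
print(f"collator share of slow run: {collator_share:.0%}")
print(f"torch.tensor share of slow run: {torch_tensor_share:.0%}")
```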

The slow trainer is therefore about 20 times slower than the fast one. The problem most likely lies in the default collator, DataCollatorWithPadding (used when a pretrained tokenizer is passed), which calls tokenizer.pad on the 4D attention masks: in the slow profile the collator's __call__ accounts for about 321 s, of which about 236 s is spent in torch.tensor via convert_to_tensors. Removing either 1) the pretrained tokenizer or 2) the 4D attention mask makes the trainer run much faster.
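As far as I can tell, tokenizer.pad converts tensor inputs to Python objects (nested lists) and then rebuilds tensors from them, so every element of the 4D mask briefly becomes a Python float. The sketch below imitates that round trip with NumPy at a reduced size (8 examples of length 256, an assumption chosen so it finishes quickly) and compares it with stacking the arrays directly. `roundtrip_collate` and `stack_collate` are hypothetical helpers for illustration, not transformers APIs.

```python
import time

import numpy as np


def roundtrip_collate(features, key):
    """Imitate a list-based path: arrays -> nested Python lists -> one array."""
    as_lists = [f[key].tolist() for f in features]
    return np.array(as_lists)


def stack_collate(features, key):
    """Stack the per-example arrays directly, never leaving NumPy."""
    return np.stack([f[key] for f in features])


# Reduced sizes for the demo; the reproduction uses batch 32 and length 1024.
batch_size, seq_len = 8, 256
features = [
    {"attention_mask": np.zeros((1, seq_len, seq_len), dtype=np.float32)}
    for _ in range(batch_size)
]

for name, fn in [("round-trip", roundtrip_collate), ("stack", stack_collate)]:
    start = time.perf_counter()
    batch = fn(features, "attention_mask")
    elapsed = time.perf_counter() - start
    print(f"{name:10s} shape={batch.shape} time={elapsed:.4f}s")
```

If the examples are already fixed-length, a likely workaround in the reproduction is to bypass DataCollatorWithPadding by passing `data_collator=default_data_collator` (imported from `transformers`) to `Trainer`, which stacks tensors instead of calling `tokenizer.pad`. This has not been timed against the profiles above.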
