System Info
- transformers version: 4.35.0
- Platform: Linux-5.14.0-284.25.1.el9_2.x86_64-x86_64-with-glibc2.34
- Python version: 3.11.5
- Huggingface_hub version: 0.17.3
- Safetensors version: 0.4.0
- Accelerate version: 0.23.0
- Accelerate config: not found
- PyTorch version (GPU?): 2.1.0+cu121 (True)
- Tensorflow version (GPU?): not installed (NA)
- Flax version (CPU?/GPU?/TPU?): not installed (NA)
- Jax version: not installed
- JaxLib version: not installed
- Using GPU in script?: yes
- Using distributed or parallel set-up in script?: trainer's default
Who can help?
Information
- The official example scripts
- My own modified scripts
Tasks
- An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
- My own task or dataset (give details below)
Reproduction
I was trying to fine-tune Whisper small with Flash Attention 2 on a private dataset. I followed the post here for most of the code. Here are the changes I made:
import configparser

import pandas as pd
import torch
from datasets import Audio, load_dataset
from transformers import (
    AutoProcessor,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
    WhisperFeatureExtractor,
    WhisperForConditionalGeneration,
    WhisperTokenizer,
)

model_card = "openai/whisper-small"
model_name = model_card.split("/")[-1]

config = configparser.ConfigParser()
config.read("config.ini")
tran_df = pd.read_csv("../total_df.csv")

processor = AutoProcessor.from_pretrained(model_card)
tokenizer = WhisperTokenizer.from_pretrained(model_card)
feature_extractor = WhisperFeatureExtractor.from_pretrained(model_card)

# load 1% of the audio folder dataset and split it for train/eval
temo_dt = load_dataset(
    "audiofolder", data_dir=config["DATA"]["dataset"], split="train[:1%]")
temo_dt = temo_dt.train_test_split(test_size=0.3)
temo_dt = temo_dt.cast_column("audio", Audio(sampling_rate=16000))

# load the model in float16 with Flash Attention 2 enabled
model = WhisperForConditionalGeneration.from_pretrained(
    model_card, use_flash_attention_2=True, torch_dtype=torch.float16)
model.config.forced_decoder_ids = processor.get_decoder_prompt_ids(
    language="english", task="transcribe")
model.config.suppress_tokens = []

# custom collator from the referenced post (sketched below)
data_collator = DataCollatorSpeechSeq2SeqWithPadding(processor=processor)
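For reference, DataCollatorSpeechSeq2SeqWithPadding is the custom collator from the post; reproduced from memory, it looks roughly like this:

from dataclasses import dataclass
from typing import Any

import torch

@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    processor: Any

    def __call__(self, features):
        # pad the log-mel inputs and the tokenized labels separately,
        # since they have different lengths and padding methods
        input_features = [{"input_features": f["input_features"]} for f in features]
        batch = self.processor.feature_extractor.pad(
            input_features, return_tensors="pt")

        label_features = [{"input_ids": f["labels"]} for f in features]
        labels_batch = self.processor.tokenizer.pad(
            label_features, return_tensors="pt")

        # replace padding with -100 so it is ignored by the loss
        labels = labels_batch["input_ids"].masked_fill(
            labels_batch.attention_mask.ne(1), -100)

        # drop the bos token if it was added during tokenization;
        # the model prepends it again during training
        if (labels[:, 0] == self.processor.tokenizer.bos_token_id).all().cpu().item():
            labels = labels[:, 1:]

        batch["labels"] = labels
        return batch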
# training process
training_args = Seq2SeqTrainingArguments(
    output_dir=f"../{model_name}",
    per_device_train_batch_size=4,
    gradient_accumulation_steps=16,
    learning_rate=1e-5,
    warmup_steps=500,
    max_steps=6000,
    # speed up
    gradient_checkpointing=True,
    evaluation_strategy="steps",
    per_device_eval_batch_size=16,
    predict_with_generate=True,
    generation_max_length=225,
    save_steps=1000,
    eval_steps=1000,
    logging_steps=25,
    report_to="none",
    load_best_model_at_end=True,
    metric_for_best_model="wer",
    greater_is_better=False,
    auto_find_batch_size=True,
    torch_compile=True,
)
trainer = Seq2SeqTrainer(
    args=training_args,
    model=model,
    train_dataset=temo_dt["train"],
    eval_dataset=temo_dt["test"],
    data_collator=data_collator,
    compute_metrics=compute_metrics_wer,  # WER metric function from the post
    tokenizer=processor.feature_extractor,
)
trainer.train()
It gave me this error: RuntimeError: Input type (torch.cuda.FloatTensor) and weight type (torch.cuda.HalfTensor) should be the same.
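My reading of the mismatch (not confirmed) is that torch_dtype=torch.float16 puts the weights in half precision while the feature extractor still emits float32 features. A quick hypothetical check, assuming the raw audio column is still present at this point:

# the weights are float16 because of torch_dtype above,
# while the extracted log-mel features default to float32
print(next(model.parameters()).dtype)  # torch.float16
sample = temo_dt["train"][0]["audio"]
feats = feature_extractor(
    sample["array"], sampling_rate=16000, return_tensors="pt").input_features
print(feats.dtype)  # torch.float32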
So I tried to convert temo_dt to half-precision tensors with the following code:
format = {'type': 'torch', 'format_kwargs': {'dtype': torch.float16}}
temo_dt.set_format(**format)
But it returned this error: RuntimeError: Expected tensor for argument #1 'indices' to have one of the following scalar types: Long, Int; but got torch.cuda.FloatTensor instead (while checking arguments for embedding).
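I suspect (though I have not verified) that set_format casts every column to float16, including the integer label ids, which would explain why the embedding lookup now receives float indices. A narrower cast, sketched here with a hypothetical wrapper around the collator, would touch only the float inputs:

import torch

class HalfInputCollator:
    # hypothetical wrapper: cast only the float inputs, keep label ids as Long
    def __init__(self, inner):
        self.inner = inner

    def __call__(self, features):
        batch = self.inner(features)
        batch["input_features"] = batch["input_features"].to(torch.float16)
        return batch

data_collator = HalfInputCollator(
    DataCollatorSpeechSeq2SeqWithPadding(processor=processor))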
Interestingly, I can fine-tune the Whisper small model perfectly well without Flash Attention 2 using the code above. Is there anything I missed?
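For reference, an alternative I have seen suggested but have not verified here: keep the weights in float32 and let the Trainer run mixed precision with fp16=True instead of loading the model in torch.float16, so autocast hands half-precision activations to the flash-attention kernels:

# assumption, not a confirmed fix: load in float32 and enable
# mixed-precision training via the Trainer instead of torch_dtype
model = WhisperForConditionalGeneration.from_pretrained(
    model_card, use_flash_attention_2=True)

training_args = Seq2SeqTrainingArguments(
    output_dir=f"../{model_name}",
    fp16=True,  # autocast handles the float16 boundary during training
    per_device_train_batch_size=4,
)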
Expected behavior
Fine-tuning Whisper should proceed as expected with use_flash_attention_2=True.