Enable attention selection for wav2vec2 (#1757)
Signed-off-by: Urszula Golowicz <urszula.golowicz@intel.com>
ugolowic authored Feb 21, 2025
1 parent c5a715c commit c899e42
Showing 6 changed files with 51 additions and 9 deletions.
6 changes: 4 additions & 2 deletions examples/audio-classification/README.md
@@ -58,7 +58,8 @@ python run_audio_classification.py \
--throughput_warmup_steps 3 \
--sdp_on_bf16 \
--bf16 \
--trust_remote_code True
--trust_remote_code True \
--attn_implementation sdpa
```

On a single HPU, this script should run in ~13 minutes and yield an accuracy of **97.96%**.
@@ -98,7 +99,8 @@ PT_HPU_LAZY_MODE=0 python ../gaudi_spawn.py \
--bf16 \
--trust_remote_code True \
--torch_compile \
--torch_compile_backend hpu_backend
--torch_compile_backend hpu_backend \
--attn_implementation sdpa
```

On 8 HPUs, this script should run in ~12 minutes and yield an accuracy of **80.49%**.
4 changes: 3 additions & 1 deletion examples/audio-classification/run_audio_classification.py
@@ -196,6 +196,8 @@ class ModelArguments:
)

def __post_init__(self):
if self.use_flash_attention:
os.environ["USE_FLASH_ATTENTION"] = "1"
if self.flash_attention_recompute:
assert self.use_flash_attention, "flash_attention_recompute is set, but use_flash_attention is not"
os.environ["FLASH_ATTENTION_RECOMPUTE"] = "1"
@@ -389,7 +391,7 @@ def compute_metrics(eval_pred):
revision=model_args.model_revision,
token=model_args.token,
trust_remote_code=model_args.trust_remote_code,
attn_implementation="sdpa" if model_args.use_flash_attention else "eager",
attn_implementation=training_args.attn_implementation,
)
model = AutoModelForAudioClassification.from_pretrained(
model_args.model_name_or_path,
10 changes: 5 additions & 5 deletions examples/speech-recognition/README.md
@@ -89,7 +89,7 @@ python run_speech_recognition_ctc.py \
--bf16 \
--use_hpu_graphs_for_training \
--use_hpu_graphs_for_inference \
--sdp_on_bf16
--attn_implementation sdpa
```

On a single HPU, this script should run in *ca.* 6 hours and yield a CTC loss of **0.059** and a word error rate of **0.0423**.
@@ -132,7 +132,7 @@ python ../gaudi_spawn.py \
--sdp_on_bf16 \
--use_hpu_graphs_for_training \
--use_hpu_graphs_for_inference \
--sdp_on_bf16
--attn_implementation sdpa
```

On 8 HPUs, this script should run in *ca.* 49 minutes and yield a CTC loss of **0.0613** and a word error rate of **0.0458**.
@@ -181,7 +181,8 @@ python ../gaudi_spawn.py \
--gaudi_config_name Habana/wav2vec2 \
--throughput_warmup_steps 3 \
--deepspeed ../../tests/configs/deepspeed_zero_2.json \
--sdp_on_bf16
--sdp_on_bf16 \
--attn_implementation sdpa
```
[The documentation](https://huggingface.co/docs/optimum/habana/usage_guides/deepspeed) provides more information about how to use DeepSpeed within Optimum Habana.
@@ -214,8 +215,7 @@ python run_speech_recognition_ctc.py \
--gaudi_config_name="Habana/wav2vec2" \
--sdp_on_bf16 \
--bf16 \
--use_hpu_graphs_for_inference \
--sdp_on_bf16
--use_hpu_graphs_for_inference
```
## Sequence to Sequence

28 changes: 28 additions & 0 deletions examples/speech-recognition/run_speech_recognition_ctc.py
@@ -152,6 +152,33 @@ class ModelArguments:
"useful to downsample the output length."
},
)
use_flash_attention: bool = field(
default=False, metadata={"help": "Whether to use Habana flash attention for fine-tuning"}
)
flash_attention_recompute: bool = field(
default=False,
metadata={
"help": "Whether to enable recompute in Habana flash attention for fine-tuning."
" It is applicable only when use_flash_attention is True."
},
)
flash_attention_fast_softmax: bool = field(
default=False,
metadata={
"help": "Whether to use fast softmax for Habana flash attention."
" It is applicable only when use_flash_attention is True."
},
)

def __post_init__(self):
if self.use_flash_attention:
os.environ["USE_FLASH_ATTENTION"] = "1"
if self.flash_attention_recompute:
assert self.use_flash_attention, "flash_attention_recompute is set, but use_flash_attention is not"
os.environ["FLASH_ATTENTION_RECOMPUTE"] = "1"
if self.flash_attention_fast_softmax:
assert self.use_flash_attention, "flash_attention_fast_softmax is set, but use_flash_attention is not"
os.environ["FLASH_ATTENTION_FAST_SOFTMAX"] = "1"


@dataclass
@@ -535,6 +562,7 @@ def remove_special_characters(batch):
cache_dir=model_args.cache_dir,
token=data_args.token,
trust_remote_code=data_args.trust_remote_code,
attn_implementation=training_args.attn_implementation,
)

# 4. Next, if no tokenizer file is defined,
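The three new `ModelArguments` flags above do not touch the model config directly; `__post_init__` exports them as environment variables that the Gaudi wav2vec2 attention module reads when it is built. Below is a minimal, self-contained sketch of that flow; `FlashAttentionArgs` is a hypothetical stand-in for the much larger `ModelArguments` dataclass in `run_speech_recognition_ctc.py`.

```python
import os
from dataclasses import dataclass, field

from transformers import HfArgumentParser


@dataclass
class FlashAttentionArgs:
    # Hypothetical stand-in mirroring only the flash-attention fields added above.
    use_flash_attention: bool = field(default=False)
    flash_attention_recompute: bool = field(default=False)
    flash_attention_fast_softmax: bool = field(default=False)

    def __post_init__(self):
        # Same pattern as the diff: each enabled flag becomes an environment variable.
        if self.use_flash_attention:
            os.environ["USE_FLASH_ATTENTION"] = "1"
        if self.flash_attention_recompute:
            assert self.use_flash_attention, "flash_attention_recompute requires use_flash_attention"
            os.environ["FLASH_ATTENTION_RECOMPUTE"] = "1"
        if self.flash_attention_fast_softmax:
            assert self.use_flash_attention, "flash_attention_fast_softmax requires use_flash_attention"
            os.environ["FLASH_ATTENTION_FAST_SOFTMAX"] = "1"


# Parsing the new CLI flags populates the environment before the model is constructed.
(model_args,) = HfArgumentParser(FlashAttentionArgs).parse_args_into_dataclasses(
    args=["--use_flash_attention", "--flash_attention_recompute"]
)
print(os.environ.get("USE_FLASH_ATTENTION"), os.environ.get("FLASH_ATTENTION_RECOMPUTE"))  # 1 1
```

The assertions mirror the script's guards: recompute and fast-softmax only make sense when flash attention itself is enabled.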
@@ -497,6 +497,7 @@ def __init__(
is_causal,
config,
)
self.use_flash_attention = True if os.getenv("USE_FLASH_ATTENTION") == "1" else False
self.flash_attention_fast_softmax = True if os.getenv("FLASH_ATTENTION_FAST_SOFTMAX") == "1" else False
self.flash_attention_recompute = True if os.getenv("FLASH_ATTENTION_RECOMPUTE") == "1" else False

@@ -581,7 +582,7 @@ def forward(
# The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case tgt_len == 1.
is_causal = True if self.is_causal and attention_mask is None and tgt_len > 1 else False

if FusedSDPA:
if self.use_flash_attention and FusedSDPA:
if tgt_len == 1:
# next token
softmax_mode = True if os.getenv("QUANT_CONFIG", "") else False
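The last hunk above is the behavioural change: previously the fused kernel was taken whenever `FusedSDPA` imported successfully, whereas now it is opt-in via the `USE_FLASH_ATTENTION` variable set by `--use_flash_attention`. A rough sketch of the gating, assuming the import guard matches the one optimum-habana commonly uses for the Habana fused kernel:

```python
import os

# Assumed import guard (the exact guard in the wav2vec2 modeling file may differ):
# FusedSDPA resolves to None on machines without the Habana kernel package.
try:
    from habana_frameworks.torch.hpex.kernels import FusedSDPA
except ImportError:
    FusedSDPA = None

# Mirrors the attribute set in __init__ above: the flag is read from the environment.
use_flash_attention = os.getenv("USE_FLASH_ATTENTION") == "1"

if use_flash_attention and FusedSDPA:
    # New behaviour: the fused SDPA kernel runs only when explicitly requested,
    # with recompute/fast-softmax modes picked from their own environment flags.
    attention_path = "fused_sdpa"
else:
    # Old behaviour reached this branch only when FusedSDPA could not be imported.
    attention_path = "non-fused fallback"

print(attention_path)
```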
9 changes: 9 additions & 0 deletions optimum/habana/transformers/training_args.py
@@ -336,6 +336,15 @@ class GaudiTrainingArguments(TrainingArguments):
},
)

# Use this to override default attn_implementation in transformers
attn_implementation: Optional[str] = field(
default="eager",
metadata={
"help": "choose whether to use scale dot product attention (SDPA) or not.",
"choices": ["eager", "sdpa"],
},
)

sdp_on_bf16: bool = field(
default=False,
metadata={"help": "Allow pyTorch to use reduced precision in the SDPA math backend"},
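With the new `attn_implementation` training argument (`eager` by default, `sdpa` as the alternative), the example scripts no longer derive the attention backend from `use_flash_attention`; they forward the value straight into the model config, as the `run_audio_classification.py` and `run_speech_recognition_ctc.py` hunks above show. A minimal sketch of that pass-through outside the Trainer, with an illustrative checkpoint name:

```python
from transformers import AutoConfig, AutoModelForAudioClassification

attn_implementation = "sdpa"  # what --attn_implementation sdpa resolves to; "eager" is the default

# The example scripts pass the value to the config, which transformers then uses
# to select the attention implementation when the model is instantiated.
config = AutoConfig.from_pretrained(
    "facebook/wav2vec2-base",  # illustrative checkpoint, not taken from the diff
    attn_implementation=attn_implementation,
)
model = AutoModelForAudioClassification.from_pretrained(
    "facebook/wav2vec2-base",
    config=config,
)
print(model.config._attn_implementation)  # "sdpa" when the installed torch build supports SDPA
```

Habana flash attention (FusedSDPA) remains a separate switch: it is enabled through `--use_flash_attention` and the environment variables shown earlier, not through this argument.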
