Skip to content

Commit

Permalink
Fix a bug
Browse files Browse the repository at this point in the history
  • Loading branch information
beginlner authored Jun 2, 2023
1 parent 85b51d6 commit 8e44c0e
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion flash_attn/modules/block.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ def forward(self, hidden_states: Tensor, residual: Optional[Tensor] = None,
before applying the query projection. Useful for e.g., ViT where we only care
about the CLS token in the last layer.
"""
fused_add_norm_fn = (dropout_add_rms_norm if isinstance(self.norm1, RMSNorm)
fused_add_norm_fn = (dropout_add_rms_norm if RMSNorm and isinstance(self.norm1, RMSNorm)
else dropout_add_layer_norm)
if self.prenorm:
if not self.fused_dropout_add_ln:
Expand Down

0 comments on commit 8e44c0e

Please sign in to comment.