[blip-2] Fix dtype mismatch when kept in fp32 #37068

Merged
merged 4 commits on Mar 28, 2025
29 changes: 24 additions & 5 deletions src/transformers/models/blip_2/modeling_blip_2.py
@@ -1238,6 +1238,9 @@ def forward(
embeddings += position_embeddings

if query_embeds is not None:
# `query_embeds` is kept in fp32 when used with the Q-Former
if query_embeds.dtype != embeddings.dtype:
query_embeds = query_embeds.to(embeddings.dtype)
embeddings = torch.cat((query_embeds, embeddings), dim=1)
else:
embeddings = query_embeds
@@ -1385,6 +1388,10 @@ def forward(
# If a 2D or 3D attention mask is provided for the cross-attention
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
if encoder_hidden_states is not None:
# The Q-Former and latent query tokens are kept in fp32; cast `encoder_hidden_states` if it is not already fp32
if encoder_hidden_states.dtype != query_embeds.dtype:
encoder_hidden_states = encoder_hidden_states.to(query_embeds.dtype)

if isinstance(encoder_hidden_states, list):
encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states[0].size()
else:
@@ -1447,7 +1454,7 @@ def forward(
class Blip2Model(Blip2PreTrainedModel):
config_class = Blip2Config
main_input_name = "pixel_values"
_keep_in_fp32_modules = ["query_tokens"]
_keep_in_fp32_modules = ["query_tokens", "qformer"]

def __init__(self, config: Blip2Config):
super().__init__(config)
@@ -1728,6 +1735,10 @@ def forward(
)
query_output = query_outputs[0]

# The Q-Former is kept in fp32; downcast the output back if needed
if query_output.dtype != image_embeds.dtype:
query_output = query_output.to(image_embeds.dtype)

# step 3: use the language model, conditioned on the query outputs and the prompt
language_model_inputs = self.language_projection(query_output)
language_model_attention_mask = torch.ones(
@@ -1799,7 +1810,7 @@ def forward(
)
class Blip2TextModelWithProjection(Blip2PreTrainedModel):
supports_gradient_checkpointing = False
_keep_in_fp32_modules = ["query_tokens"]
_keep_in_fp32_modules = ["query_tokens", "qformer"]

def __init__(self, config: Blip2Config):
super().__init__(config)
@@ -1898,7 +1909,7 @@ def forward(
)
class Blip2VisionModelWithProjection(Blip2PreTrainedModel):
main_input_name = "pixel_values"
_keep_in_fp32_modules = ["query_tokens"]
_keep_in_fp32_modules = ["query_tokens", "qformer"]

def __init__(self, config: Blip2Config):
super().__init__(config)
@@ -2019,7 +2030,7 @@ class Blip2ForConditionalGeneration(Blip2PreTrainedModel, GenerationMixin):
_supports_cache_class = True
_supports_static_cache = True
_supports_quantized_cache = False # not all LM backbones support it (e.g. T5)
_keep_in_fp32_modules = ["query_tokens"]
_keep_in_fp32_modules = ["query_tokens", "qformer"]

def __init__(self, config: Blip2Config):
super().__init__(config)
@@ -2192,6 +2203,10 @@ def forward(
)
query_output = query_outputs[0]

# The Q-Former is kept in fp32; downcast the output back if needed
if query_output.dtype != image_embeds.dtype:
query_output = query_output.to(image_embeds.dtype)

# step 3: use the language model, conditioned on the query outputs and the prompt
language_model_inputs = self.language_projection(query_output)
language_model_attention_mask = torch.ones(
@@ -2313,6 +2328,10 @@ def generate(
)
query_output = query_outputs.last_hidden_state

# The Q-Former is kept in fp32; downcast the output back if needed
if query_output.dtype != image_embeds.dtype:
query_output = query_output.to(image_embeds.dtype)

language_model_inputs = self.language_projection(query_output)
language_attention_mask = torch.ones(
language_model_inputs.size()[:-1], dtype=torch.long, device=language_model_inputs.device
@@ -2372,7 +2391,7 @@ def generate(
)
class Blip2ForImageTextRetrieval(Blip2PreTrainedModel):
main_input_name = "pixel_values"
_keep_in_fp32_modules = ["query_tokens"]
_keep_in_fp32_modules = ["query_tokens", "qformer"]

def __init__(self, config: Blip2Config):
super().__init__(config)
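
A minimal runnable sketch of the casting pattern these changes introduce, not the actual BLIP-2 modeling code: `TinyQFormer` is a hypothetical stand-in module and the tensor shapes are arbitrary. The in-module upcast mirrors the `encoder_hidden_states` cast added inside the Q-Former, and the final downcast mirrors the `query_output.to(image_embeds.dtype)` lines in the diff above.

import torch
import torch.nn as nn

# Stand-in for the real Q-Former (assumption: the real module is far larger).
class TinyQFormer(nn.Module):
    def __init__(self, dim: int = 8):
        super().__init__()
        self.cross = nn.Linear(2 * dim, dim)

    def forward(self, query_embeds, encoder_hidden_states):
        # Mirrors the cast added by this PR: fp16 vision features are lifted
        # to the fp32 query dtype before being mixed with the queries.
        if encoder_hidden_states.dtype != query_embeds.dtype:
            encoder_hidden_states = encoder_hidden_states.to(query_embeds.dtype)
        pooled = encoder_hidden_states.mean(dim=1, keepdim=True).expand_as(query_embeds)
        return self.cross(torch.cat([query_embeds, pooled], dim=-1))

query_tokens = nn.Parameter(torch.zeros(1, 4, 8))          # kept in fp32
qformer = TinyQFormer()                                      # kept in fp32
image_embeds = torch.randn(2, 16, 8, dtype=torch.float16)   # fp16 vision output

query_output = qformer(query_tokens.expand(image_embeds.shape[0], -1, -1), image_embeds)

# Downcast back to the vision dtype before an fp16 language projection,
# like the `query_output.to(image_embeds.dtype)` lines in the diff.
if query_output.dtype != image_embeds.dtype:
    query_output = query_output.to(image_embeds.dtype)
assert query_output.dtype == torch.float16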
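
A short usage sketch of the intended effect of adding "qformer" to `_keep_in_fp32_modules`; the checkpoint name is only an example, and the printed dtypes are the expected values, not captured output.

import torch
from transformers import Blip2ForConditionalGeneration

model = Blip2ForConditionalGeneration.from_pretrained(
    "Salesforce/blip2-opt-2.7b", torch_dtype=torch.float16
)

# `query_tokens` and the Q-Former are excluded from the fp16 conversion...
print(model.query_tokens.dtype)                        # expected: torch.float32
print(next(model.qformer.parameters()).dtype)          # expected: torch.float32
# ...while the language model is loaded in half precision.
print(next(model.language_model.parameters()).dtype)   # expected: torch.float16

The in-model casts shown in the diff then reconcile the fp32 Q-Former outputs with the fp16 language model at the module boundaries.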