Commit 83f9314

fix: cast input pixels to appropriate dtype for image_to_text pipelines (#24947)
* fix: cast input pixels to appropriate dtype for image_to_text tasks
* fix: add casting to pixel inputs of additional models after running copy checks
1 parent 1c7e5e2 commit 83f9314
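
Why the cast is needed: when one of these vision towers is loaded in half precision (for example via torch_dtype=torch.float16 in an image-to-text pipeline), the image processor still hands it float32 pixel values, and the Conv2d patch embedding rejects the dtype mismatch. Below is a minimal sketch of the failure and of the one-line fix this commit applies; the module is a stand-in rather than code from the diff, and it uses bfloat16 so the sketch runs on CPU (float16 on GPU behaves the same way).

import torch
from torch import nn

# Stand-in for the patch embedding of a CLIP-family vision tower, held in reduced precision.
patch_embedding = nn.Conv2d(3, 768, kernel_size=32, stride=32).to(torch.bfloat16)

# Image processors emit float32 pixel values by default.
pixel_values = torch.rand(1, 3, 224, 224)

# Before this commit, the call below raised a RuntimeError about mismatched
# input and weight dtypes:
#   patch_embedding(pixel_values)

# The fix: cast the pixels to the weight dtype before the convolution.
target_dtype = patch_embedding.weight.dtype
patch_embeds = patch_embedding(pixel_values.to(dtype=target_dtype))
print(patch_embeds.dtype)  # torch.bfloat16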

9 files changed: +15 -9 lines changed

src/transformers/models/altclip/modeling_altclip.py

Lines changed: 2 additions & 1 deletion

@@ -1022,7 +1022,8 @@ def __init__(self, config: AltCLIPVisionConfig):
 
     def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
         batch_size = pixel_values.shape[0]
-        patch_embeds = self.patch_embedding(pixel_values)  # shape = [*, width, grid, grid]
+        target_dtype = self.patch_embedding.weight.dtype
+        patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype))  # shape = [*, width, grid, grid]
         patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
 
         class_embeds = self.class_embedding.expand(batch_size, 1, -1)

src/transformers/models/blip/modeling_blip.py

Lines changed: 1 addition & 1 deletion

@@ -246,7 +246,7 @@ def __init__(self, config: BlipVisionConfig):
     def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
         batch_size = pixel_values.shape[0]
         target_dtype = self.patch_embedding.weight.dtype
-        patch_embeds = self.patch_embedding(pixel_values)  # shape = [*, width, grid, grid]
+        patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype))  # shape = [*, width, grid, grid]
         patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
 
         class_embeds = self.class_embedding.expand(batch_size, 1, -1).to(target_dtype)
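
For context, this is the kind of call the change unblocks, sketched with an illustrative BLIP checkpoint and image URL (neither is named in the commit); it assumes a CUDA device with float16 support.

import torch
from transformers import pipeline

# Half-precision image-to-text pipeline; the processor still yields float32 pixel_values.
captioner = pipeline(
    "image-to-text",
    model="Salesforce/blip-image-captioning-base",  # illustrative checkpoint
    torch_dtype=torch.float16,
    device=0,
)

# With this commit, the BLIP vision embeddings cast the pixels to the fp16
# patch-embedding weights before the convolution instead of raising a dtype mismatch.
print(captioner("https://huggingface.co/datasets/Narsil/image_dummy/raw/main/parrots.png"))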

src/transformers/models/blip_2/modeling_blip_2.py

Lines changed: 1 addition & 1 deletion

@@ -109,7 +109,7 @@ def __init__(self, config: Blip2VisionConfig):
     def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
         batch_size = pixel_values.shape[0]
         target_dtype = self.patch_embedding.weight.dtype
-        patch_embeds = self.patch_embedding(pixel_values)  # shape = [*, width, grid, grid]
+        patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype))  # shape = [*, width, grid, grid]
         patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
 
         class_embeds = self.class_embedding.expand(batch_size, 1, -1).to(target_dtype)

src/transformers/models/bridgetower/modeling_bridgetower.py

Lines changed: 2 additions & 1 deletion

@@ -284,7 +284,8 @@ def __init__(self, config: BridgeTowerVisionConfig):
 
     def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
         batch_size = pixel_values.shape[0]
-        patch_embeds = self.patch_embedding(pixel_values)  # shape = [*, width, grid, grid]
+        target_dtype = self.patch_embedding.weight.dtype
+        patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype))  # shape = [*, width, grid, grid]
         patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
 
         class_embeds = self.class_embedding.expand(batch_size, 1, -1)

src/transformers/models/chinese_clip/modeling_chinese_clip.py

Lines changed: 2 additions & 1 deletion

@@ -196,7 +196,8 @@ def __init__(self, config: ChineseCLIPVisionConfig):
 
     def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
         batch_size = pixel_values.shape[0]
-        patch_embeds = self.patch_embedding(pixel_values)  # shape = [*, width, grid, grid]
+        target_dtype = self.patch_embedding.weight.dtype
+        patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype))  # shape = [*, width, grid, grid]
         patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
 
         class_embeds = self.class_embedding.expand(batch_size, 1, -1)

src/transformers/models/clip/modeling_clip.py

Lines changed: 2 additions & 1 deletion

@@ -192,7 +192,8 @@ def __init__(self, config: CLIPVisionConfig):
 
     def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
         batch_size = pixel_values.shape[0]
-        patch_embeds = self.patch_embedding(pixel_values)  # shape = [*, width, grid, grid]
+        target_dtype = self.patch_embedding.weight.dtype
+        patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype))  # shape = [*, width, grid, grid]
         patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
 
         class_embeds = self.class_embedding.expand(batch_size, 1, -1)

src/transformers/models/git/modeling_git.py

Lines changed: 2 additions & 1 deletion

@@ -628,7 +628,8 @@ def __init__(self, config: GitVisionConfig):
 
     def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
         batch_size = pixel_values.shape[0]
-        patch_embeds = self.patch_embedding(pixel_values)  # shape = [*, width, grid, grid]
+        target_dtype = self.patch_embedding.weight.dtype
+        patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype))  # shape = [*, width, grid, grid]
         patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
 
         class_embeds = self.class_embedding.expand(batch_size, 1, -1)

src/transformers/models/instructblip/modeling_instructblip.py

Lines changed: 1 addition & 1 deletion

@@ -110,7 +110,7 @@ def __init__(self, config: InstructBlipVisionConfig):
     def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
         batch_size = pixel_values.shape[0]
         target_dtype = self.patch_embedding.weight.dtype
-        patch_embeds = self.patch_embedding(pixel_values)  # shape = [*, width, grid, grid]
+        patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype))  # shape = [*, width, grid, grid]
         patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
 
         class_embeds = self.class_embedding.expand(batch_size, 1, -1).to(target_dtype)

src/transformers/models/x_clip/modeling_x_clip.py

Lines changed: 2 additions & 1 deletion

@@ -143,7 +143,8 @@ def __init__(self, config: XCLIPVisionConfig):
 
     def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
         batch_size = pixel_values.shape[0]
-        patch_embeds = self.patch_embedding(pixel_values)  # shape = [*, width, grid, grid]
+        target_dtype = self.patch_embedding.weight.dtype
+        patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype))  # shape = [*, width, grid, grid]
         patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
 
         class_embeds = self.class_embedding.expand(batch_size, 1, -1)
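
A quick way to sanity-check the shared pattern end to end, sketched with a tiny randomly initialized CLIP vision tower (not a released checkpoint) kept in bfloat16 so it runs on CPU:

import torch
from transformers import CLIPVisionConfig, CLIPVisionModel

# Tiny config so the check is cheap; the sizes are arbitrary.
config = CLIPVisionConfig(
    hidden_size=32,
    intermediate_size=64,
    num_hidden_layers=2,
    num_attention_heads=2,
    image_size=32,
    patch_size=8,
)
model = CLIPVisionModel(config).to(torch.bfloat16).eval()

# float32 pixels, as an image processor would produce them.
pixel_values = torch.rand(1, 3, 32, 32)

with torch.no_grad():
    out = model(pixel_values=pixel_values)

print(out.last_hidden_state.dtype)  # torch.bfloat16, with no dtype-mismatch error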
