Skip to content

Commit

Permalink
Move rescale dtype recasting to match torchvision ToTensor (huggingfa…
Browse files Browse the repository at this point in the history
…ce#25229)

Move dtype recasting to match torchvision ToTensor
  • Loading branch information
amyeroberts authored Aug 1, 2023
1 parent 3170af7 commit d27e4c1
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions src/transformers/image_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,11 +110,12 @@ def rescale(
if not isinstance(image, np.ndarray):
raise ValueError(f"Input image must be of type np.ndarray, got {type(image)}")

image = image.astype(dtype)

rescaled_image = image * scale
if data_format is not None:
rescaled_image = to_channel_dimension_format(rescaled_image, data_format)

rescaled_image = rescaled_image.astype(dtype)

return rescaled_image


Expand Down

0 comments on commit d27e4c1

Please sign in to comment.