7 changes: 7 additions & 0 deletions src/transformers/feature_extraction_utils.py
@@ -69,16 +69,22 @@ class BatchFeature(UserDict):
initialization.
skip_tensor_conversion (`list[str]` or `set[str]`, *optional*):
List or set of keys that should NOT be converted to tensors, even when `tensor_type` is specified.
device (`str` or `torch.device`, *optional*):
The device to place tensors on. When specified, all tensor values will be moved to this device.
Note: This is a simple tensor movement operation, not GPU-accelerated processing.
"""

def __init__(
self,
data: Optional[dict[str, Any]] = None,
tensor_type: Union[None, str, TensorType] = None,
skip_tensor_conversion: Optional[Union[list[str], set[str]]] = None,
device: Optional[Union[str, "torch.device"]] = None,
):
super().__init__(data)
self.convert_to_tensors(tensor_type=tensor_type, skip_tensor_conversion=skip_tensor_conversion)
if device is not None:
self.to(device)
Comment on lines +86 to +87
IMO, instead of doing this, it needs to be passed to convert_to_tensors. Then, when tensor_type is 'pt', we can use torch.tensor(value, device=device).
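
A minimal sketch of that suggestion (illustrative only; the real convert_to_tensors also handles numpy and the other return types):

# Hypothetical sketch of the suggested approach, not the actual implementation.
def convert_to_tensors(self, tensor_type=None, skip_tensor_conversion=None, device=None):
    if tensor_type != "pt":
        return self
    import torch

    skip = set(skip_tensor_conversion or ())
    for key, value in self.items():
        if key not in skip:
            # Build the tensor directly on the target device instead of
            # converting on CPU first and moving it with .to(device).
            self[key] = torch.as_tensor(value, device=device)
    return self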


def __getitem__(self, item: str) -> Any:
"""
@@ -663,3 +669,4 @@ def register_for_auto_class(cls, auto_class="AutoFeatureExtractor"):
FeatureExtractionMixin.push_to_hub.__doc__ = FeatureExtractionMixin.push_to_hub.__doc__.format(
object="feature extractor", object_class="AutoFeatureExtractor", object_files="feature extractor file"
)

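For reference, the new parameter can be exercised directly; a minimal sketch, assuming a CUDA device is available:

from transformers import BatchFeature

data = {"input_ids": [[1, 2, 3, 4]], "attention_mask": [[1, 1, 1, 1]]}
# Values are converted to torch tensors first, then moved via .to("cuda").
features = BatchFeature(data, tensor_type="pt", device="cuda")
assert all(v.device.type == "cuda" for v in features.values())
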
12 changes: 10 additions & 2 deletions src/transformers/processing_utils.py
@@ -447,6 +447,7 @@ class CustomProcessorKwargs(ProcessingKwargs, total=False):
audio_kwargs: AudioKwargs = {
**AudioKwargs.__annotations__,
}
device: Annotated[Optional[Union[str, "torch.device"]], device_validator()]


class TokenizerChatTemplateKwargs(TypedDict, total=False):
@@ -625,9 +626,13 @@ def __call__(

- `'pt'`: Return PyTorch `torch.Tensor` objects.
- `'np'`: Return NumPy `np.ndarray` objects.
device (`str` or `torch.device`, *optional*):
The device to place tensors on. When specified, all output tensors will be moved to this device.
                Note: This is a simple tensor movement operation, not the GPU-accelerated processing performed by fast image processors.

Returns:
[`BatchFeature`]: A [`BatchFeature`] object with processed inputs in a dict format.
[`BatchFeature`]: A [`BatchFeature`] object with processed inputs in a dict format. All tensor outputs
will be on the same device when `device` is specified.
"""
if "audios" in kwargs and audio is None:
raise ValueError("You passed keyword argument `audios` which is deprecated. Please use `audio` instead.")
@@ -655,7 +660,9 @@ def __call__(
attribute_output = attribute(input_data, **kwargs[input_kwargs])
outputs.update(attribute_output)

return BatchFeature(outputs)
# Extract device parameter if present
device = kwargs.get("device")
return BatchFeature(outputs, device=device)
Comment on lines +663 to +665
This will not really work with structured inputs. It needs to be in the common kwargs, which self._merge_kwargs will structure properly; then we can reuse it from the merged inputs.
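
A rough sketch of that common-kwargs route (the wiring is illustrative, not the final implementation; it assumes `device` is declared alongside `return_tensors` in the common kwargs):

# Hypothetical sketch: route `device` through common_kwargs via _merge_kwargs.
def __call__(self, images=None, text=None, **kwargs):
    output_kwargs = self._merge_kwargs(CustomProcessorKwargs, **kwargs)
    # `device` then travels with the other shared options (e.g. return_tensors)
    # instead of being read straight out of **kwargs.
    device = output_kwargs["common_kwargs"].pop("device", None)
    outputs = {}
    # ... call each attribute (tokenizer, image processor, ...) with its
    # structured kwargs and update `outputs` ...
    return BatchFeature(outputs, device=device)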


def check_argument_for_proper_class(self, argument_name, argument):
"""
@@ -1857,3 +1864,4 @@ def _check_special_mm_tokens(self, text: list[str], text_inputs: "BatchFeature",
ProcessorMixin.push_to_hub.__doc__ = ProcessorMixin.push_to_hub.__doc__.format(
object="processor", object_class="AutoProcessor", object_files="processor files"
)

195 changes: 195 additions & 0 deletions tests/test_batchfeature_device_consistency.py
@@ -0,0 +1,195 @@
"""
Test script to verify the new BatchFeature device consistency implementation.
This test verifies that the device parameter works correctly in BatchFeature.__init__.
"""

import requests
import torch
import transformers
from PIL import Image
from torchvision import transforms

def test_batchfeature_device_parameter():
"""Test BatchFeature device parameter directly"""
print("Testing BatchFeature device parameter...")

# Skip test if CUDA is not available
if not torch.cuda.is_available():
print("CUDA not available, skipping device parameter test")
return True

try:
# Create some test data
data = {
"input_ids": torch.tensor([[1, 2, 3, 4]]),
"attention_mask": torch.tensor([[1, 1, 1, 1]]),
"pixel_values": torch.randn(1, 3, 224, 224)
}

print("Original tensors devices:")
for key, value in data.items():
print(f" {key}: {value.device}")

# Create BatchFeature with device parameter
batch_feature = transformers.feature_extraction_utils.BatchFeature(
data=data,
tensor_type="pt",
device="cuda"
)

# Check if all tensors are on CUDA
print("\nAfter BatchFeature with device='cuda':")
all_on_cuda = True
for key, value in batch_feature.items():
if isinstance(value, torch.Tensor):
print(f" {key}: {value.device}")
if value.device.type != "cuda":
all_on_cuda = False
print(f" ERROR: {key} is not on CUDA")

if all_on_cuda:
print("✅ SUCCESS: All tensors moved to CUDA!")
return True
else:
print("❌ FAILURE: Not all tensors are on CUDA!")
return False

except Exception as e:
print(f"Error during BatchFeature device test: {e}")
return False

def test_oneformer_with_new_implementation():
"""Test OneFormer processor with new BatchFeature device implementation"""
print("\nTesting OneFormer processor with new device implementation...")

# Skip test if CUDA is not available
if not torch.cuda.is_available():
print("CUDA not available, skipping OneFormer device test")
return True

try:
# Setup processor
        image_processor = transformers.OneFormerImageProcessorFast()
        processor = transformers.OneFormerProcessor(
            image_processor=image_processor,
            tokenizer=transformers.AutoTokenizer.from_pretrained("openai/clip-vit-base-patch32"),
        )

# Load test image
url = "https://huggingface.co/datasets/hf-internal-testing/fixtures_ade20k/resolve/main/ADE_val_00000001.jpg"
image = Image.open(requests.get(url, stream=True).raw)

# Convert image to tensor and move to CUDA
to_tensor_transform = transforms.ToTensor()
image = to_tensor_transform(image).to("cuda")

print(f"Input image device: {image.device}")

# Process with explicit device parameter
inputs = processor(image, ["semantic"], return_tensors="pt", device="cuda")

# Check device consistency
print("Checking output devices:")
all_on_same_device = True
reference_device = None

for key, value in inputs.items():
if isinstance(value, torch.Tensor):
print(f" {key}: {value.device}")
if reference_device is None:
reference_device = value.device
elif value.device != reference_device:
all_on_same_device = False
print(f" ERROR: {key} is on {value.device} but expected {reference_device}")
elif isinstance(value, (list, tuple)) and len(value) > 0 and isinstance(value[0], torch.Tensor):
device = value[0].device
print(f" {key}[0]: {device}")
if reference_device is None:
reference_device = device
elif device != reference_device:
all_on_same_device = False
print(f" ERROR: {key}[0] is on {device} but expected {reference_device}")

        if all_on_same_device and reference_device is not None and reference_device.type == "cuda":
print("✅ SUCCESS: All tensors are on CUDA and consistent!")
return True
else:
print("❌ FAILURE: Device consistency issue!")
return False

except Exception as e:
print(f"Error during OneFormer test: {e}")
return False

def test_cpu_to_cuda_movement():
"""Test moving CPU tensors to CUDA using device parameter"""
print("\nTesting CPU to CUDA movement...")

if not torch.cuda.is_available():
print("CUDA not available, skipping CPU to CUDA test")
return True

try:
# Setup processor
        image_processor = transformers.OneFormerImageProcessorFast()
        processor = transformers.OneFormerProcessor(
            image_processor=image_processor,
            tokenizer=transformers.AutoTokenizer.from_pretrained("openai/clip-vit-base-patch32"),
        )

# Load test image (keep on CPU)
url = "https://huggingface.co/datasets/hf-internal-testing/fixtures_ade20k/resolve/main/ADE_val_00000001.jpg"
image = Image.open(requests.get(url, stream=True).raw)

print("Input image device: CPU (PIL Image)")

# Process with device parameter to move to CUDA
inputs = processor(image, ["semantic"], return_tensors="pt", device="cuda")

# Check if all outputs are on CUDA
print("Checking if all outputs moved to CUDA:")
all_on_cuda = True

for key, value in inputs.items():
if isinstance(value, torch.Tensor):
print(f" {key}: {value.device}")
if value.device.type != "cuda":
all_on_cuda = False
print(f" ERROR: {key} is not on CUDA")
elif isinstance(value, (list, tuple)) and len(value) > 0 and isinstance(value[0], torch.Tensor):
device = value[0].device
print(f" {key}[0]: {device}")
if device.type != "cuda":
all_on_cuda = False
print(f" ERROR: {key}[0] is not on CUDA")

if all_on_cuda:
print("✅ SUCCESS: All tensors moved to CUDA as requested!")
return True
else:
print("❌ FAILURE: Not all tensors are on CUDA!")
return False

except Exception as e:
print(f"Error during CPU to CUDA test: {e}")
return False

if __name__ == "__main__":
print("Testing new BatchFeature device consistency implementation...")
print("=" * 70)

# Run tests
test1_passed = test_batchfeature_device_parameter()
test2_passed = test_oneformer_with_new_implementation()
test3_passed = test_cpu_to_cuda_movement()

print("\n" + "=" * 70)
print("Test Summary:")
print(f"BatchFeature device parameter: {'✅ PASSED' if test1_passed else '❌ FAILED'}")
print(f"OneFormer device consistency: {'✅ PASSED' if test2_passed else '❌ FAILED'}")
print(f"CPU to CUDA movement: {'✅ PASSED' if test3_passed else '❌ FAILED'}")

if test1_passed and test2_passed and test3_passed:
print("\n🎉 All tests passed! New BatchFeature device implementation is working correctly.")
else:
print("\n⚠️ Some tests failed. Please check the implementation.")