Skip to content

Commit 197df27

Browse files
FindHaometa-codesync[bot]
authored andcommitted
Normalize CUDA Device to cuda:0 (#177)
Summary: fix #176 ## Changes This PR adds device normalization logic to ensure all tensors are created on `cuda:0`, regardless of the device specified in the JSON configuration. ## Modifications ### 1. `load_tensor()` function - Added device normalization before loading tensor from file - Any CUDA device string (e.g., `cuda`, `cuda:1`, `cuda:2`) is now mapped to `cuda:0` ### 2. `_create_base_tensor()` function - Added device normalization before creating new tensors - Applies to all tensor types: floating-point, integer, complex, and unsigned integer ## Impact - **Prevents multi-GPU issues**: All tensors are now consistently created on the same device - **Backward compatible**: Non-CUDA devices (e.g., `cpu`) remain unchanged - **Minimal code change**: Only two strategic locations modified for maximum coverage ## Testing This change affects the template file used for generating reproducers. All generated test cases will now place tensors on `cuda:0` by default. Pull Request resolved: #177 Reviewed By: wychi Differential Revision: D85061992 Pulled By: FindHao fbshipit-source-id: 594018cc2a8917a762328257950fdf7227ad7503
1 parent 4fa7d4e commit 197df27

File tree

1 file changed

+7
-0
lines changed

1 file changed

+7
-0
lines changed

tritonparse/reproducer/templates/example.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,10 @@ def load_tensor(tensor_file_path: Union[str, Path], device: str = None) -> torch
6767
RuntimeError: If the tensor cannot be loaded
6868
ValueError: If the computed hash doesn't match the filename hash
6969
"""
70+
# Normalize cuda device to cuda:0
71+
if device is not None and isinstance(device, str) and device.startswith("cuda"):
72+
device = "cuda:0"
73+
7074
blob_path = Path(tensor_file_path)
7175

7276
if not blob_path.exists():
@@ -210,6 +214,9 @@ def _create_base_tensor(arg_info) -> torch.Tensor:
210214

211215
shape = arg_info.get("shape", [])
212216
device = arg_info.get("device", "cpu")
217+
# Normalize cuda device to cuda:0
218+
if isinstance(device, str) and device.startswith("cuda"):
219+
device = "cuda:0"
213220

214221
# Extract statistical information if available
215222
mean = arg_info.get("mean")

0 commit comments

Comments
 (0)