Description
Describe the solution you'd like
Hi,
I'm working on Cell ACDC.
My personal computer runs Windows and has an AMD GPU. In order to run Cellpose on this GPU, I looked into DirectML, which uses DirectX to run PyTorch. Since DirectML doesn't support sparse tensors, I had to do some light patching of PyTorch. I'll provide the code in the additional context below, where I move every operation and tensor that isn't supported over to the CPU.
Although I had my doubts about the performance, it still seems a LOT faster than just running on the CPU. The only problem is that one needs to use Python 3.11.
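For context, a minimal sketch of what DirectML exposes (assuming the torch-directml package is installed from pip; the tensor below is just a throwaway sanity check, not part of Cellpose):

import torch
import torch_directml

dml = torch_directml.device()   # the DirectML device backed by the AMD GPU
x = torch.randn(2, 3).to(dml)   # dense tensors and ops work fine on this device
print(x.device, (x @ x.T).sum().item())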
Describe alternatives you've considered
Waiting 5-10x longer while segmenting on CPU
Additional context
Code for making Cellpose use the DirectML device:
def setup_custom_device(model, device):
    # Point the model (and its sub-models) at the given device and disable MKL-DNN
    model.gpu = True
    model.device = device
    model.mkldnn = False
    # models.Cellpose wraps a CellposeModel in `cp`
    if hasattr(model, 'cp'):
        model.cp.gpu = True
        model.cp.device = device
        model.cp.mkldnn = False
        if hasattr(model.cp, 'net'):
            model.cp.net.to(device)
            model.cp.net.mkldnn = False
    # CellposeModel exposes the network directly
    if hasattr(model, 'net'):
        model.net.to(device)
        model.net.mkldnn = False
    # size model used for automatic diameter estimation
    if hasattr(model, 'sz'):
        model.sz.device = device

def setup_directML(model):
    print('Using DirectML GPU for Cellpose model inference')
    import torch_directml
    directml_device = torch_directml.device()
    setup_custom_device(model, directml_device)
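For reference, this is roughly how the setup is applied (the model construction below is only an illustrative sketch using the standard Cellpose v3 API, not the exact call Cell ACDC makes):

from cellpose import models

model = models.Cellpose(gpu=False, model_type='cyto3')  # build on the CPU first
setup_directML(model)  # then re-point the networks at the DirectML device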
Code for patching the sparse tensor constructors so they fall back to the CPU:
def fix_sparse_directML(verbose=True):
    """DirectML does not support sparse tensors, so fall back to the CPU when a sparse op fails."""
    import torch
    import functools
    import warnings

    def fallback_to_cpu_on_sparse_error(func, verbose=True):
        @functools.wraps(func)  # keep the wrapped function's name and docstring
        def wrapper(*args, **kwargs):
            device_arg = kwargs.get('device', None)
            # Ensure indices are int64 if args[0] looks like an indices tensor
            if len(args) >= 1 and isinstance(args[0], torch.Tensor):
                if args[0].dtype != torch.int64:
                    args = (args[0].to(dtype=torch.int64),) + args[1:]
            try:  # try the op as-is and move the result to the DirectML device
                result = func(*args, **kwargs)
                if device_arg is not None and str(device_arg).lower() == "dml":
                    try:
                        result = result.to("dml")
                    except RuntimeError as e:
                        if verbose:
                            warnings.warn(f"Sparse op failed on DirectML, falling back to CPU: {e}")
                        kwargs['device'] = torch.device("cpu")
                        return func(*args, **kwargs)
                return result
            except RuntimeError as e:  # if running on DML fails, fall back to the CPU
                if "sparse" in str(e).lower() or "not implemented" in str(e).lower():
                    if verbose:
                        warnings.warn(f"Sparse op failed on DirectML, falling back to CPU: {e}")
                    kwargs['device'] = torch.device("cpu")
                    # Re-apply the indices dtype correction before retrying on the CPU
                    if len(args) >= 1 and isinstance(args[0], torch.Tensor):
                        if args[0].dtype != torch.int64:
                            args = (args[0].to(dtype=torch.int64),) + args[1:]
                    return func(*args, **kwargs)
                else:
                    raise
        return wrapper

    # --- Patch sparse tensor constructors ---
    # High-level API
    torch.sparse_coo_tensor = fallback_to_cpu_on_sparse_error(torch.sparse_coo_tensor, verbose=verbose)
    # Low-level API
    if hasattr(torch._C, "_sparse_coo_tensor_unsafe"):
        torch._C._sparse_coo_tensor_unsafe = fallback_to_cpu_on_sparse_error(
            torch._C._sparse_coo_tensor_unsafe, verbose=verbose
        )
    if hasattr(torch._C, "_sparse_coo_tensor_with_dims_and_tensors"):
        torch._C._sparse_coo_tensor_with_dims_and_tensors = fallback_to_cpu_on_sparse_error(
            torch._C._sparse_coo_tensor_with_dims_and_tensors, verbose=verbose
        )
    if hasattr(torch.sparse, 'SparseTensor'):
        torch.sparse.SparseTensor = fallback_to_cpu_on_sparse_error(torch.sparse.SparseTensor, verbose=verbose)

    # Emit each fallback warning only once instead of on every call
    warnings.filterwarnings("once", message="Sparse op failed on DirectML*")
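Putting the two pieces together, a segmentation run would look roughly like this (the image and eval arguments are placeholders for whatever Cell ACDC passes, and the Cellpose v3 API is assumed; the important part is calling fix_sparse_directML before any model call so the patched constructors are in place):

import numpy as np
from cellpose import models

fix_sparse_directML(verbose=True)  # patch the sparse tensor constructors first
model = models.Cellpose(gpu=False, model_type='cyto3')
setup_directML(model)  # move the networks to the DirectML device

img = np.random.rand(512, 512)  # placeholder image
masks, flows, styles, diams = model.eval(img, diameter=None, channels=[0, 0])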