Skip to content

[FEATURE] Support for AMD GPUs (and others) on Windows #1169

@Teranis

Description

@Teranis

Describe the solution you'd like
Hi,
I'm working on Cell ACDC.
My personal computer runs Windows and has an AMD GPU. In order to run Cellpose on this GPU, I looked into DirectML, which uses DirectX to run PyTorch. Since DirectML doesn't support sparse tensors, I had to make some slight modifications to PyTorch. I'll provide the code in the additional context below, in which I move every unsupported operation and tensor to the CPU.
Although I had my doubts about the performance, it still seems A LOT faster than just running it on the CPU. The only problem is that one needs to use Python 3.11.
Describe alternatives you've considered
Waiting 5-10x longer while segmenting on CPU

Additional context
Code for making Cellpose use DirectML:

def setup_custom_device(model, device):
    """Retarget a Cellpose model object (and its sub-models) to ``device``.

    Marks ``model`` as running on a GPU at ``device`` with MKL-DNN disabled.
    If ``model`` has a ``cp`` sub-model, that sub-model gets the same flags,
    and its ``net`` (if any) is moved. ``model.net`` (if any) is then moved.
    Finally, ``model.sz.device`` is updated when ``sz`` exists.

    Parameters
    ----------
    model : object
        A Cellpose model instance. The optional ``cp``, ``net`` and ``sz``
        attributes are only touched when present.
    device : object
        Anything accepted by ``net.to(...)``, e.g. the object returned by
        ``torch_directml.device()``.
    """
    def _flag_as_gpu(target):
        # NOTE(review): mkldnn is presumably turned off so Cellpose skips its
        # CPU/MKL-DNN code path once the model lives on a custom device.
        target.gpu = True
        target.device = device
        target.mkldnn = False

    def _move_network(owner):
        if not hasattr(owner, 'net'):
            return
        network = owner.net
        network.to(device)
        network.mkldnn = False

    _flag_as_gpu(model)

    if hasattr(model, 'cp'):
        inner = model.cp
        _flag_as_gpu(inner)
        _move_network(inner)

    _move_network(model)

    if hasattr(model, 'sz'):
        model.sz.device = device

def setup_directML(model):
    """Run ``model`` on the DirectML device provided by ``torch_directml``.

    Prints a notice, imports ``torch_directml`` lazily (so the package is
    only required when DirectML is actually requested), and hands the
    device from ``torch_directml.device()`` to ``setup_custom_device``.
    """
    print('Using DirectML GPU for Cellpose model inference')

    # Deferred import: torch_directml is an optional, Windows-only dependency.
    import torch_directml

    setup_custom_device(model, torch_directml.device())

Code for fixing sparse tensor implementation:

def fix_sparse_directML(verbose=True):
    """Patch PyTorch's sparse COO constructors to fall back to the CPU.

    DirectML does not support sparse tensors, so building one on a DirectML
    device raises a ``RuntimeError``. This wraps each sparse constructor
    present in the installed ``torch`` so that such failures are retried on
    the CPU instead of propagating. Retries are identified by an error
    message mentioning "sparse" or "not implemented". The retry's result is
    a CPU tensor.

    Each wrapped constructor also casts a tensor first positional argument
    (the COO ``indices``) to ``int64`` before calling through.

    The patch is process-global: it replaces attributes on ``torch`` /
    ``torch._C`` / ``torch.sparse`` for every caller. Calling this function
    again is safe, because already-patched constructors are not re-wrapped.

    Parameters
    ----------
    verbose : bool
        If True, emit a warning whenever a call falls back to the CPU. Each
        distinct warning message is shown only once.
    """
    import functools
    import warnings

    import torch

    # Attribute set on our wrappers so repeated calls don't stack wrappers.
    marker = '_directml_sparse_fallback'

    def as_int64_indices(args):
        # Sparse COO indices must be int64; for the wrapped constructors the
        # first positional argument (when it is a tensor) is the indices.
        if args and isinstance(args[0], torch.Tensor) and args[0].dtype != torch.int64:
            return (args[0].to(dtype=torch.int64),) + tuple(args[1:])
        return args

    def to_cpu(value):
        return value.cpu() if isinstance(value, torch.Tensor) else value

    def fallback_to_cpu_on_sparse_error(func):
        if getattr(func, marker, False):
            return func  # already patched by a previous call

        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            args = as_int64_indices(args)
            try:
                return func(*args, **kwargs)
            except RuntimeError as e:
                message = str(e).lower()
                if 'sparse' not in message and 'not implemented' not in message:
                    raise
                if verbose:
                    warnings.warn(f"Sparse op failed on DirectML, falling back to CPU: {e}")
                # Move tensor inputs off the DirectML device too; changing
                # only `device=` would leave indices/values on the failing
                # backend.
                cpu_args = tuple(to_cpu(a) for a in args)
                cpu_kwargs = {k: to_cpu(v) for k, v in kwargs.items()}
                cpu_kwargs['device'] = torch.device("cpu")
                return func(*cpu_args, **cpu_kwargs)

        setattr(wrapper, marker, True)
        return wrapper

    # --- Patch sparse tensor constructors ---

    # High-level API
    torch.sparse_coo_tensor = fallback_to_cpu_on_sparse_error(torch.sparse_coo_tensor)

    # Low-level API (names guarded: they may not exist in every torch build)
    for name in ("_sparse_coo_tensor_unsafe", "_sparse_coo_tensor_with_dims_and_tensors"):
        if hasattr(torch._C, name):
            setattr(torch._C, name, fallback_to_cpu_on_sparse_error(getattr(torch._C, name)))

    # NOTE(review): presumably a legacy constructor name; only patched if present.
    if hasattr(torch.sparse, 'SparseTensor'):
        torch.sparse.SparseTensor = fallback_to_cpu_on_sparse_error(torch.sparse.SparseTensor)

    # Show each distinct fallback warning only once instead of on every call.
    warnings.filterwarnings("once", message=r"Sparse op failed on DirectML")

Metadata

Metadata

Assignees

Labels

enhancement — New feature or request

Type

No type

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests

Issue actions