From 5e95dcabe1d3d522a8bc5a45990c53d9d4e9f2eb Mon Sep 17 00:00:00 2001 From: Arthur <48595927+ArthurZucker@users.noreply.github.com> Date: Tue, 20 Feb 2024 12:38:59 +0100 Subject: [PATCH] [`cuda kernels`] only compile them when initializing (#29133) * only compile when needed * fix mra as well * fix yoso as well * update * rempve comment * Update src/transformers/models/deformable_detr/modeling_deformable_detr.py * Update src/transformers/models/deformable_detr/modeling_deformable_detr.py * opps * Update src/transformers/models/deta/modeling_deta.py * nit --- .../modeling_deformable_detr.py | 53 +++++++++++++++---- src/transformers/models/deta/modeling_deta.py | 29 +++++----- src/transformers/models/mra/modeling_mra.py | 40 ++++++-------- src/transformers/models/yoso/modeling_yoso.py | 39 ++++++++------ 4 files changed, 93 insertions(+), 68 deletions(-) diff --git a/src/transformers/models/deformable_detr/modeling_deformable_detr.py b/src/transformers/models/deformable_detr/modeling_deformable_detr.py index 89682729c651bd..640c05257cc967 100755 --- a/src/transformers/models/deformable_detr/modeling_deformable_detr.py +++ b/src/transformers/models/deformable_detr/modeling_deformable_detr.py @@ -17,8 +17,10 @@ import copy import math +import os import warnings from dataclasses import dataclass +from pathlib import Path from typing import Dict, List, Optional, Tuple, Union import torch @@ -46,21 +48,42 @@ from ...utils import is_accelerate_available, is_ninja_available, logging from ...utils.backbone_utils import load_backbone from .configuration_deformable_detr import DeformableDetrConfig -from .load_custom import load_cuda_kernels logger = logging.get_logger(__name__) -# Move this to not compile only when importing, this needs to happen later, like in __init__. 
-if is_torch_cuda_available() and is_ninja_available(): - logger.info("Loading custom CUDA kernels...") - try: - MultiScaleDeformableAttention = load_cuda_kernels() - except Exception as e: - logger.warning(f"Could not load the custom kernel for multi-scale deformable attention: {e}") - MultiScaleDeformableAttention = None -else: - MultiScaleDeformableAttention = None +MultiScaleDeformableAttention = None + + +def load_cuda_kernels(): + from torch.utils.cpp_extension import load + + global MultiScaleDeformableAttention + + root = Path(__file__).resolve().parent.parent.parent / "kernels" / "deta" + src_files = [ + root / filename + for filename in [ + "vision.cpp", + os.path.join("cpu", "ms_deform_attn_cpu.cpp"), + os.path.join("cuda", "ms_deform_attn_cuda.cu"), + ] + ] + + MultiScaleDeformableAttention = load( + "MultiScaleDeformableAttention", + src_files, + with_cuda=True, + extra_include_paths=[str(root)], + extra_cflags=["-DWITH_CUDA=1"], + extra_cuda_cflags=[ + "-DCUDA_HAS_FP16=1", + "-D__CUDA_NO_HALF_OPERATORS__", + "-D__CUDA_NO_HALF_CONVERSIONS__", + "-D__CUDA_NO_HALF2_OPERATORS__", + ], + ) + if is_vision_available(): from transformers.image_transforms import center_to_corners_format @@ -590,6 +613,14 @@ class DeformableDetrMultiscaleDeformableAttention(nn.Module): def __init__(self, config: DeformableDetrConfig, num_heads: int, n_points: int): super().__init__() + + kernel_loaded = MultiScaleDeformableAttention is not None + if is_torch_cuda_available() and is_ninja_available() and not kernel_loaded: + try: + load_cuda_kernels() + except Exception as e: + logger.warning(f"Could not load the custom kernel for multi-scale deformable attention: {e}") + if config.d_model % num_heads != 0: raise ValueError( f"embed_dim (d_model) must be divisible by num_heads, but got {config.d_model} and {num_heads}" diff --git a/src/transformers/models/deta/modeling_deta.py b/src/transformers/models/deta/modeling_deta.py index 7e1b014c834eff..5d0b48b45d13ac 100644 --- a/src/transformers/models/deta/modeling_deta.py +++ b/src/transformers/models/deta/modeling_deta.py @@ -50,10 +50,15 @@ logger = logging.get_logger(__name__) +MultiScaleDeformableAttention = None + +# Copied from models.deformable_detr.load_cuda_kernels def load_cuda_kernels(): from torch.utils.cpp_extension import load + global MultiScaleDeformableAttention + root = Path(__file__).resolve().parent.parent.parent / "kernels" / "deta" src_files = [ root / filename @@ -78,22 +83,6 @@ def load_cuda_kernels(): ], ) - import MultiScaleDeformableAttention as MSDA - - return MSDA - - -# Move this to not compile only when importing, this needs to happen later, like in __init__. 
-if is_torch_cuda_available() and is_ninja_available(): - logger.info("Loading custom CUDA kernels...") - try: - MultiScaleDeformableAttention = load_cuda_kernels() - except Exception as e: - logger.warning(f"Could not load the custom kernel for multi-scale deformable attention: {e}") - MultiScaleDeformableAttention = None -else: - MultiScaleDeformableAttention = None - # Copied from transformers.models.deformable_detr.modeling_deformable_detr.MultiScaleDeformableAttentionFunction class MultiScaleDeformableAttentionFunction(Function): @@ -596,6 +585,14 @@ class DetaMultiscaleDeformableAttention(nn.Module): def __init__(self, config: DetaConfig, num_heads: int, n_points: int): super().__init__() + + kernel_loaded = MultiScaleDeformableAttention is not None + if is_torch_cuda_available() and is_ninja_available() and not kernel_loaded: + try: + load_cuda_kernels() + except Exception as e: + logger.warning(f"Could not load the custom kernel for multi-scale deformable attention: {e}") + if config.d_model % num_heads != 0: raise ValueError( f"embed_dim (d_model) must be divisible by num_heads, but got {config.d_model} and {num_heads}" diff --git a/src/transformers/models/mra/modeling_mra.py b/src/transformers/models/mra/modeling_mra.py index d11c2557710846..9915db471ef308 100644 --- a/src/transformers/models/mra/modeling_mra.py +++ b/src/transformers/models/mra/modeling_mra.py @@ -58,9 +58,11 @@ # See all Mra models at https://huggingface.co/models?filter=mra ] +mra_cuda_kernel = None + def load_cuda_kernels(): - global cuda_kernel + global mra_cuda_kernel src_folder = Path(__file__).resolve().parent.parent.parent / "kernels" / "mra" def append_root(files): @@ -68,26 +70,7 @@ def append_root(files): src_files = append_root(["cuda_kernel.cu", "cuda_launch.cu", "torch_extension.cpp"]) - cuda_kernel = load("cuda_kernel", src_files, verbose=True) - - import cuda_kernel - - -cuda_kernel = None - - -if is_torch_cuda_available() and is_ninja_available(): - logger.info("Loading custom CUDA kernels...") - - try: - load_cuda_kernels() - except Exception as e: - logger.warning( - "Failed to load CUDA kernels. Mra requires custom CUDA kernels. 
Please verify that compatible versions of"
-            f" PyTorch and CUDA Toolkit are installed: {e}"
-        )
-else:
-    pass
+    mra_cuda_kernel = load("cuda_kernel", src_files, verbose=True)
 
 
 def sparse_max(sparse_qk_prod, indices, query_num_block, key_num_block):
@@ -112,7 +95,7 @@ def sparse_max(sparse_qk_prod, indices, query_num_block, key_num_block):
     indices = indices.int()
     indices = indices.contiguous()
 
-    max_vals, max_vals_scatter = cuda_kernel.index_max(index_vals, indices, query_num_block, key_num_block)
+    max_vals, max_vals_scatter = mra_cuda_kernel.index_max(index_vals, indices, query_num_block, key_num_block)
     max_vals_scatter = max_vals_scatter.transpose(-1, -2)[:, :, None, :]
 
     return max_vals, max_vals_scatter
@@ -178,7 +161,7 @@ def mm_to_sparse(dense_query, dense_key, indices, block_size=32):
     indices = indices.int()
     indices = indices.contiguous()
 
-    return cuda_kernel.mm_to_sparse(dense_query, dense_key, indices.int())
+    return mra_cuda_kernel.mm_to_sparse(dense_query, dense_key, indices.int())
 
 
 def sparse_dense_mm(sparse_query, indices, dense_key, query_num_block, block_size=32):
@@ -216,7 +199,7 @@ def sparse_dense_mm(sparse_query, indices, dense_key, query_num_block, block_siz
     indices = indices.contiguous()
     dense_key = dense_key.contiguous()
 
-    dense_qk_prod = cuda_kernel.sparse_dense_mm(sparse_query, indices, dense_key, query_num_block)
+    dense_qk_prod = mra_cuda_kernel.sparse_dense_mm(sparse_query, indices, dense_key, query_num_block)
     dense_qk_prod = dense_qk_prod.transpose(-1, -2).reshape(batch_size, query_num_block * block_size, dim)
     return dense_qk_prod
 
@@ -393,7 +376,7 @@ def mra2_attention(
     """
     Use Mra to approximate self-attention.
     """
-    if cuda_kernel is None:
+    if mra_cuda_kernel is None:
         return torch.zeros_like(query).requires_grad_()
 
     batch_size, num_head, seq_len, head_dim = query.size()
@@ -561,6 +544,13 @@ def __init__(self, config, position_embedding_type=None):
                 f"heads ({config.num_attention_heads})"
             )
 
+        kernel_loaded = mra_cuda_kernel is not None
+        if is_torch_cuda_available() and is_ninja_available() and not kernel_loaded:
+            try:
+                load_cuda_kernels()
+            except Exception as e:
+                logger.warning(f"Could not load the custom CUDA kernels for MRA: {e}")
+
         self.num_attention_heads = config.num_attention_heads
         self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
         self.all_head_size = self.num_attention_heads * self.attention_head_size
diff --git a/src/transformers/models/yoso/modeling_yoso.py b/src/transformers/models/yoso/modeling_yoso.py
index 9c0636340d1e7c..ab6fb1c151c0db 100644
--- a/src/transformers/models/yoso/modeling_yoso.py
+++ b/src/transformers/models/yoso/modeling_yoso.py
@@ -35,7 +35,14 @@
 )
 from ...modeling_utils import PreTrainedModel
 from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
-from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging
+from ...utils import (
+    add_code_sample_docstrings,
+    add_start_docstrings,
+    add_start_docstrings_to_model_forward,
+    is_ninja_available,
+    is_torch_cuda_available,
+    logging,
+)
 from .configuration_yoso import YosoConfig
 
 
@@ -49,28 +56,22 @@
 # See all YOSO models at https://huggingface.co/models?filter=yoso
 ]
 
+lsh_cumulation = None
+
 
 def load_cuda_kernels():
     global lsh_cumulation
-    try:
-        from torch.utils.cpp_extension import load
+    from torch.utils.cpp_extension import load
 
-        def append_root(files):
-            src_folder = Path(__file__).resolve().parent.parent.parent / "kernels" / "yoso"
-            return [src_folder / file for file in files]
-
-        src_files = append_root(
-            ["fast_lsh_cumulation_torch.cpp", "fast_lsh_cumulation.cu", "fast_lsh_cumulation_cuda.cu"]
-        )
+    def append_root(files):
+        src_folder = Path(__file__).resolve().parent.parent.parent / "kernels" / "yoso"
+        return [src_folder / file for file in files]
 
-        load("fast_lsh_cumulation", src_files, verbose=True)
+    src_files = append_root(["fast_lsh_cumulation_torch.cpp", "fast_lsh_cumulation.cu", "fast_lsh_cumulation_cuda.cu"])
 
-        import fast_lsh_cumulation as lsh_cumulation
+    load("fast_lsh_cumulation", src_files, verbose=True)
 
-        return True
-    except Exception:
-        lsh_cumulation = None
-        return False
+    import fast_lsh_cumulation as lsh_cumulation
 
 
 def to_contiguous(input_tensors):
@@ -305,6 +306,12 @@ def __init__(self, config, position_embedding_type=None):
                 f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
                 f"heads ({config.num_attention_heads})"
             )
+        kernel_loaded = lsh_cumulation is not None
+        if is_torch_cuda_available() and is_ninja_available() and not kernel_loaded:
+            try:
+                load_cuda_kernels()
+            except Exception as e:
+                logger.warning(f"Could not load the custom CUDA kernels for YOSO: {e}")
 
         self.num_attention_heads = config.num_attention_heads
         self.attention_head_size = int(config.hidden_size / config.num_attention_heads)