- 
                Notifications
    
You must be signed in to change notification settings  - Fork 4.5k
 
[Inference/Feat] Add convert_fp8 op for fp8 test in the future #5706
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
          
     Merged
      
      
            Courtesy-Xs
  merged 2 commits into
  hpcaitech:feature/colossal-infer
from
Courtesy-Xs:feat_quant_kv_cache_step6
  
      
      
   
  May 10, 2024 
      
    
  
     Merged
                    Changes from all commits
      Commits
    
    
            Show all changes
          
          
            2 commits
          
        
        Select commit
          Hold shift + click to select a range
      
      
    File filter
Filter by extension
Conversations
          Failed to load comments.   
        
        
          
      Loading
        
  Jump to
        
          Jump to file
        
      
      
          Failed to load files.   
        
        
          
      Loading
        
  Diff view
Diff view
There are no files selected for viewing
  
    
      This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
      Learn more about bidirectional Unicode characters
    
  
  
    
              | Original file line number | Diff line number | Diff line change | 
|---|---|---|
| @@ -0,0 +1,127 @@ | ||
| #include <torch/extension.h> | ||
| #include <ATen/cuda/Exceptions.h> | ||
| #include <ATen/cuda/CUDAContext.h> | ||
| 
     | 
||
| #include <cmath> | ||
| 
     | 
||
| #include "common/micros.h" | ||
| #include "utils/vec_copy.h" | ||
| #include "funcs/cast_functor.h" | ||
| 
     | 
||
| 
     | 
||
| using colossalAI::cuda::utils::copy; | ||
| using colossalAI::cuda::utils::get_vec_size; | ||
| using colossalAI::funcs::CastFunctor; | ||
| 
     | 
||
| template <typename InT, typename OutT, int VecSize> | ||
| __global__ void convert_fp8_kernel(const InT* ins_data, OutT* outs_data, int numel, int tail) | ||
| { | ||
| int64_t idx = static_cast<int64_t>(threadIdx.x) + static_cast<int64_t>(blockIdx.x) * static_cast<int64_t>(blockDim.x); | ||
| const int64_t grid_size = blockDim.x * gridDim.x; | ||
| if(idx > numel + tail) { | ||
| return; | ||
| } | ||
| 
     | 
||
| for(int64_t i = idx; i < numel; i += grid_size) { | ||
| copy<InT, OutT, VecSize>(ins_data + i * VecSize, outs_data + i * VecSize); | ||
| } | ||
| // Tail process | ||
| if(threadIdx.x == 0) | ||
| { | ||
| for(int i = 0; i < tail; ++i) | ||
| { | ||
| outs_data[i + numel * VecSize] = CastFunctor<InT, OutT>()(ins_data[i + numel * VecSize]); | ||
| } | ||
| } | ||
| } | ||
| 
     | 
||
| template <typename InT, typename OutT> | ||
| void apply_convert_fp8(torch::Tensor& input, torch::Tensor& output) | ||
| { | ||
| const int kVecSize = get_vec_size<InT>(input); | ||
| const int kNumel = torch::numel(input); | ||
| 
     | 
||
| const int kVecNumel = (kNumel >> static_cast<int>(std::log2(kVecSize))); | ||
| const int kTail = kNumel & (kVecSize - 1); | ||
| int grid_size = kVecNumel ? (kVecNumel + 255) / 256 : 1; | ||
| 
     | 
||
| const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); | ||
| 
     | 
||
| dim3 grid(grid_size); | ||
| dim3 block(256); | ||
| 
     | 
||
| #define _(VEC_SIZE) \ | ||
| convert_fp8_kernel<InT, OutT, VEC_SIZE> \ | ||
| <<<grid, block, 0, stream>>> \ | ||
| (reinterpret_cast<const InT*>(input.data_ptr()), \ | ||
| reinterpret_cast<OutT*>(output.data_ptr()), \ | ||
| kVecNumel, \ | ||
| kTail) | ||
| 
     | 
||
| switch (kVecSize) | ||
| { | ||
| case 1: | ||
| _(1); | ||
| break; | ||
| case 2: | ||
| _(2); | ||
| break; | ||
| case 4: | ||
| _(4); | ||
| break; | ||
| } | ||
| #undef _ | ||
| AT_CUDA_CHECK(cudaGetLastError()); | ||
| } | ||
| 
     | 
||
| void convert_fp8(torch::Tensor& input, torch::Tensor& output) | ||
| { | ||
| TORCH_CHECK(input.scalar_type() == at::ScalarType::Byte || output.scalar_type() == at::ScalarType::Byte, "Data type of Input or Output should be torch.uint8 for convert_fp8!"); | ||
| TORCH_CHECK(input.scalar_type() != output.scalar_type(), "Data type of input and output are the same!"); | ||
| TORCH_CHECK(input.scalar_type() == at::ScalarType::Byte || | ||
| input.scalar_type() == at::ScalarType::Float || | ||
| input.scalar_type() == at::ScalarType::Half || | ||
| input.scalar_type() == at::ScalarType::BFloat16, "Unsupported dtype of input!"); | ||
| TORCH_CHECK(output.scalar_type() == at::ScalarType::Byte || | ||
| output.scalar_type() == at::ScalarType::Float || | ||
| output.scalar_type() == at::ScalarType::Half || | ||
| output.scalar_type() == at::ScalarType::BFloat16, "Unsupported dtype of output!"); | ||
| TORCH_CHECK(input.sizes() == output.sizes(), "Shape of input and output should be the same!"); | ||
| 
     | 
||
| #define _(InT, OutT) \ | ||
| apply_convert_fp8<InT, OutT>(input, output) | ||
| 
     | 
||
| 
     | 
||
| if(input.scalar_type() == at::ScalarType::Byte) | ||
| { | ||
| if(output.scalar_type() == at::ScalarType::Float) | ||
| { | ||
| _(uint8_t, float); | ||
| } | ||
| else if(output.scalar_type() == at::ScalarType::Half) | ||
| { | ||
| _(uint8_t, half); | ||
| } | ||
| else if(output.scalar_type() == at::ScalarType::BFloat16) | ||
| { | ||
| _(uint8_t, __nv_bfloat16); | ||
| } | ||
| } | ||
| else | ||
| { | ||
| if(input.scalar_type() == at::ScalarType::Float) | ||
| { | ||
| _(float, uint8_t); | ||
| } | ||
| else if(input.scalar_type() == at::ScalarType::Half) | ||
| { | ||
| _(half, uint8_t); | ||
| } | ||
| else if(input.scalar_type() == at::ScalarType::BFloat16) | ||
| { | ||
| _(__nv_bfloat16, uint8_t); | ||
| } | ||
| } | ||
| 
     | 
||
| #undef _ | ||
| } | ||
  
    
      This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
      Learn more about bidirectional Unicode characters
    
  
  
    
              
  
    
      This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
      Learn more about bidirectional Unicode characters
    
  
  
    
              
  
    
      This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
      Learn more about bidirectional Unicode characters
    
  
  
    
              
  
    
      This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
      Learn more about bidirectional Unicode characters
    
  
  
    
              | Original file line number | Diff line number | Diff line change | 
|---|---|---|
| @@ -0,0 +1,57 @@ | ||
| import random | ||
| 
     | 
||
| import pytest | ||
| import torch | ||
| 
     | 
||
| from colossalai.kernel.kernel_loader import InferenceOpsLoader | ||
| from colossalai.utils import get_current_device | ||
| 
     | 
||
| inference_ops = InferenceOpsLoader().load() | ||
| 
     | 
||
| DTYPES = [torch.half, torch.bfloat16, torch.float] | ||
| NUM_TOKENS = [42] # Arbitrary values for testing | ||
| NUM_LAYERS = [1] # Arbitrary values for testing | ||
| NUM_HEADS = [8] # Arbitrary values for testing | ||
| HEAD_SIZES = [64, 80, 96, 112, 128, 256] | ||
| BLOCK_SIZES = [8, 16, 32] | ||
| 
     | 
||
| 
     | 
||
| @pytest.mark.skipif(True, reason="FP8 conversion still needs improvement, now we skip it's relative test!") | ||
| @pytest.mark.parametrize("num_heads", [8]) | ||
| @pytest.mark.parametrize("head_size", [64, 80, 96, 112, 128, 256]) | ||
| @pytest.mark.parametrize("block_size", [8, 16, 32]) | ||
| @pytest.mark.parametrize("num_blocks", [1024, 10000]) | ||
| @pytest.mark.parametrize("dtype", [torch.half, torch.bfloat16, torch.float]) | ||
| @pytest.mark.parametrize("seed", [0]) | ||
| @torch.inference_mode() | ||
| def test_fp8_conversion( | ||
| num_heads: int, | ||
| head_size: int, | ||
| block_size: int, | ||
| num_blocks: int, | ||
| dtype: torch.dtype, | ||
| seed: int, | ||
| ) -> None: | ||
| random.seed(seed) | ||
| torch.random.manual_seed(seed) | ||
| torch.cuda.manual_seed(seed) | ||
| 
     | 
||
| device = get_current_device() | ||
| 
     | 
||
| low = -224.0 | ||
| high = 224.0 | ||
| shape = (num_blocks, num_heads, head_size, block_size) | ||
| cache = torch.empty(shape, dtype=dtype, device=device) | ||
| cache.uniform_(low, high) | ||
| 
     | 
||
| cache_fp8 = torch.empty_like(cache, dtype=torch.uint8) | ||
| inference_ops.convert_fp8(cache, cache_fp8) | ||
| 
     | 
||
| converted_cache = torch.empty_like(cache) | ||
| inference_ops.convert_fp8(cache_fp8, converted_cache) | ||
| 
     | 
||
| assert torch.allclose(cache, converted_cache, atol=0.001, rtol=0.1) | ||
| 
     | 
||
| 
     | 
||
| if __name__ == "__main__": | ||
| test_fp8_conversion(8, 64, 8, 1024, torch.half, 0) | 
  Add this suggestion to a batch that can be applied as a single commit.
  This suggestion is invalid because no changes were made to the code.
  Suggestions cannot be applied while the pull request is closed.
  Suggestions cannot be applied while viewing a subset of changes.
  Only one suggestion per line can be applied in a batch.
  Add this suggestion to a batch that can be applied as a single commit.
  Applying suggestions on deleted lines is not supported.
  You must change the existing code in this line in order to create a valid suggestion.
  Outdated suggestions cannot be applied.
  This suggestion has been applied or marked resolved.
  Suggestions cannot be applied from pending reviews.
  Suggestions cannot be applied on multi-line comments.
  Suggestions cannot be applied while the pull request is queued to merge.
  Suggestion cannot be applied right now. Please check back later.
  
    
  
    
Uh oh!
There was an error while loading. Please reload this page.