Skip to content

Commit

Permalink
[Inference] support llama3 a8w8c8_fp8 inference and cutlass_fp8_gemm (#…
Browse files Browse the repository at this point in the history
…8953)

* fp8

* check

* check

* check

* check

* cutlass fp8

* fp8 chech

* check

* support int8 kv cache

* a8w8c8_fp8

* Copyright check

---------

Co-authored-by: ckl117 <ck117@163.com>
  • Loading branch information
ckl117 and ckl117 authored Sep 2, 2024
1 parent a12781f commit a275ab7
Show file tree
Hide file tree
Showing 36 changed files with 6,952 additions and 59 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "fp8_common.h"

#include "fp8_gemm_fused/dual_gemm_scale_bias_swiglu_16_32_64_stages3.h"
#include "fp8_gemm_fused/dual_gemm_scale_bias_swiglu_16_64_64_stages4.h"
#include "fp8_gemm_fused/dual_gemm_scale_bias_swiglu_64_64_64_stages3.h"

template <typename InputType, typename BiasType, typename OutType>
bool dispatch_dual_gemm_scale_bias_swiglu(DualGemmEpilogueAllParams params) {
if(params.M<=32){
return dual_gemm_scale_bias_swiglu_16_32_64_stages3<InputType, BiasType, OutType>(params);
} else if(params.M>32 && params.M<=64) {
return dual_gemm_scale_bias_swiglu_16_64_64_stages4<InputType, BiasType, OutType>(params);
} else {
return dual_gemm_scale_bias_swiglu_64_64_64_stages4<InputType, BiasType, OutType>(params);
}
}

544 changes: 544 additions & 0 deletions csrc/gpu/cutlass_kernels/fp8_gemm_fused/dual_gemm/device/dual_gemm.h

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights
*reserved. SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice,
*this list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
*ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
*LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
*CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
*SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
*INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
*CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
*ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
*POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Defines common types used for all DualGemm operators.
*/
#pragma once

namespace cutlass {
namespace gemm {

enum class DualGemmMode { kGemm, kBatched, kInvalid };

} // namespace gemm
} // namespace cutlass
Loading

0 comments on commit a275ab7

Please sign in to comment.