From d158717e9de825a738612e8edc63c9ae87a233ea Mon Sep 17 00:00:00 2001 From: Zihao Ye Date: Tue, 24 Dec 2024 23:54:54 -0800 Subject: [PATCH] bugfix: casting int array to int32 for rope input arguments (#697) This avoids potential bugs when users pass a LongTensor to the rope APIs. Also removes some files that are no longer used in the current codebase. Fixes a bug in AOT mode where `apply_rope_pos_ids_cos_sin_cache` was not registered in pybind. --- csrc/dispatch_type_code.h | 192 -------------------------------------- csrc/dispatch_utils.h | 71 -------------- csrc/flashinfer_ops.cu | 2 + flashinfer/rope.py | 7 ++ 4 files changed, 9 insertions(+), 263 deletions(-) delete mode 100644 csrc/dispatch_type_code.h delete mode 100644 csrc/dispatch_utils.h diff --git a/csrc/dispatch_type_code.h b/csrc/dispatch_type_code.h deleted file mode 100644 index 4f717b95..00000000 --- a/csrc/dispatch_type_code.h +++ /dev/null @@ -1,192 +0,0 @@ -/* - * Copyright (c) 2023 by FlashInfer team. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -#pragma once -#include -#include -#include -#include - -#include - -using namespace flashinfer; - -enum class TypeCode { - kFloat64 = 0, - kFloat32 = 1, - kFloat16 = 2, - kBFloat16 = 3, - kFloat8_e4m3fn = 4, - kFloat8_e5m2 = 5, - kInt64 = 100, - kUInt64 = 101, - kInt32 = 102, - kUInt32 = 103, - kInt16 = 104, - kUInt16 = 105, - kInt8 = 106, - kUInt8 = 107, -}; - -#ifdef FLASHINFER_ENABLE_BF16 -#define DISPATCH_TYPE_CODE_TO_CTYPE_FP16(type_code, c_type, ...) \ - [&]() -> bool { \ - switch (TypeCode(type_code)) { \ - case TypeCode::kFloat16: { \ - using c_type = nv_half; \ - return __VA_ARGS__(); \ - } \ - case TypeCode::kBFloat16: { \ - using c_type = nv_bfloat16; \ - return __VA_ARGS__(); \ - } \ - default: \ - std::ostringstream oss; \ - oss << __PRETTY_FUNCTION__ << " failed to dispatch type code " << type_code; \ - FLASHINFER_ERROR(oss.str()); \ - return false; \ - } \ - }() -#else -#define DISPATCH_TYPE_CODE_TO_CTYPE_FP16(type_code, c_type, ...) \ - [&]() -> bool { \ - switch (TypeCode(type_code)) { \ - case TypeCode::kFloat16: { \ - using c_type = nv_half; \ - return __VA_ARGS__(); \ - } \ - default: \ - std::ostringstream oss; \ - oss << __PRETTY_FUNCTION__ << " failed to dispatch type code " << type_code; \ - FLASHINFER_ERROR(oss.str()); \ - return false; \ - } \ - }() -#endif - -#ifdef FLASHINFER_ENABLE_FP8 -#define DISPATCH_TYPE_CODE_TO_CTYPE_FP8(type_code, c_type, ...) \ - [&]() -> bool { \ - switch (TypeCode(type_code)) { \ - case TypeCode::kFloat8_e4m3fn: { \ - using c_type = __nv_fp8_e4m3; \ - return __VA_ARGS__(); \ - } \ - case TypeCode::kFloat8_e5m2: { \ - using c_type = __nv_fp8_e5m2; \ - return __VA_ARGS__(); \ - } \ - default: \ - std::ostringstream oss; \ - oss << __PRETTY_FUNCTION__ << " failed to dispatch type code " << type_code; \ - FLASHINFER_ERROR(oss.str()); \ - return false; \ - } \ - }() -#else -#define DISPATCH_TYPE_CODE_TO_CTYPE_FP8(type_code, c_type, ...) 
\ - [&]() -> bool { \ - std::ostringstream oss; \ - oss << __PRETTY_FUNCTION__ << " failed to dispatch type code " << type_code; \ - FLASHINFER_ERROR(oss.str()); \ - return false; \ - }() -#endif - -#if defined(FLASHINFER_ENABLE_BF16) && defined(FLASHINFER_ENABLE_FP8) -#define DISPATCH_TYPE_CODE_TO_CTYPE(type_code, c_type, ...) \ - [&]() -> bool { \ - switch (TypeCode(type_code)) { \ - case TypeCode::kFloat16: { \ - using c_type = nv_half; \ - return __VA_ARGS__(); \ - } \ - case TypeCode::kBFloat16: { \ - using c_type = nv_bfloat16; \ - return __VA_ARGS__(); \ - } \ - case TypeCode::kFloat8_e4m3fn: { \ - using c_type = __nv_fp8_e4m3; \ - return __VA_ARGS__(); \ - } \ - case TypeCode::kFloat8_e5m2: { \ - using c_type = __nv_fp8_e5m2; \ - return __VA_ARGS__(); \ - } \ - default: \ - std::ostringstream oss; \ - oss << __PRETTY_FUNCTION__ << " failed to dispatch type code " << type_code; \ - FLASHINFER_ERROR(oss.str()); \ - return false; \ - } \ - }() -#elif defined(FLASHINFER_ENABLE_BF16) -#define DISPATCH_TYPE_CODE_TO_CTYPE(type_code, c_type, ...) \ - [&]() -> bool { \ - switch (TypeCode(type_code)) { \ - case TypeCode::kFloat16: { \ - using c_type = nv_half; \ - return __VA_ARGS__(); \ - } \ - case TypeCode::kBFloat16: { \ - using c_type = nv_bfloat16; \ - return __VA_ARGS__(); \ - } \ - default: \ - std::ostringstream oss; \ - oss << __PRETTY_FUNCTION__ << " failed to dispatch type code " << type_code; \ - FLASHINFER_ERROR(oss.str()); \ - return false; \ - } \ - }() -#elif defined(FLASHINFER_ENABLE_FP8) -#define DISPATCH_TYPE_CODE_TO_CTYPE(type_code, c_type, ...) 
\ - [&]() -> bool { \ - switch (TypeCode(type_code)) { \ - case TypeCode::kFloat16: { \ - using c_type = nv_half; \ - return __VA_ARGS__(); \ - } \ - case TypeCode::kFloat8_e4m3fn: { \ - using c_type = __nv_fp8_e4m3; \ - return __VA_ARGS__(); \ - } \ - case TypeCode::kFloat8_e5m2: { \ - using c_type = __nv_fp8_e5m2; \ - return __VA_ARGS__(); \ - } \ - default: \ - std::ostringstream oss; \ - oss << __PRETTY_FUNCTION__ << " failed to dispatch type code " << type_code; \ - FLASHINFER_ERROR(oss.str()); \ - return false; \ - } \ - }() -#else -#define DISPATCH_TYPE_CODE_TO_CTYPE(type_code, c_type, ...) \ - [&]() -> bool { \ - switch (TypeCode(type_code)) { \ - case TypeCode::kFloat16: { \ - using c_type = nv_half; \ - return __VA_ARGS__(); \ - } \ - default: \ - std::ostringstream oss; \ - oss << __PRETTY_FUNCTION__ << " failed to dispatch type code " << type_code; \ - FLASHINFER_ERROR(oss.str()); \ - return false; \ - } \ - }() -#endif diff --git a/csrc/dispatch_utils.h b/csrc/dispatch_utils.h deleted file mode 100644 index 7885cd72..00000000 --- a/csrc/dispatch_utils.h +++ /dev/null @@ -1,71 +0,0 @@ -/* - * Copyright (c) 2023 by FlashInfer team. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#pragma once -#include -#include -#include -#include - -#include -#include - -#include "dispatch_type_code.h" -#include "generated/dispatch.inc" - -using namespace flashinfer; - -#define _DISPATCH_SWITCH(var_name, cond, ...) 
\ - [&]() -> bool { \ - switch (cond) { \ - __VA_ARGS__ \ - default: \ - std::ostringstream oss; \ - oss << __PRETTY_FUNCTION__ << " failed to dispatch " var_name " " << int(cond); \ - FLASHINFER_ERROR(oss.str()); \ - return false; \ - } \ - }() - -#define _DISPATCH_CASE(case_expr, case_var, ...) \ - case case_expr: { \ - constexpr auto case_var = case_expr; \ - return __VA_ARGS__(); \ - } - -#define DISPATCH_head_dim(expr, const_expr, ...) \ - _DISPATCH_SWITCH("head_dim", expr, _DISPATCH_CASES_head_dim(const_expr, __VA_ARGS__)) - -#define DISPATCH_pos_encoding_mode(expr, const_expr, ...) \ - _DISPATCH_SWITCH("positional encoding mode", expr, \ - _DISPATCH_CASES_pos_encoding_mode(const_expr, __VA_ARGS__)) - -#define DISPATCH_allow_fp16_qk_reduction(expr, const_expr, ...) \ - _DISPATCH_SWITCH("allow_fp16_qk_reduction", expr, \ - _DISPATCH_CASES_allow_fp16_qk_reduction(const_expr, __VA_ARGS__)) - -#define DISPATCH_mask_mode(expr, const_expr, ...) \ - _DISPATCH_SWITCH("mask_mode", expr, _DISPATCH_CASES_mask_mode(const_expr, __VA_ARGS__)) - -#define DISPATCH_BOOL(use_logits_soft_cap, USE_LOGITS_SOFT_CAP, ...) 
\ - [&]() -> bool { \ - if (use_logits_soft_cap) { \ - constexpr bool USE_LOGITS_SOFT_CAP = true; \ - return __VA_ARGS__(); \ - } else { \ - constexpr bool USE_LOGITS_SOFT_CAP = false; \ - return __VA_ARGS__(); \ - } \ - }() diff --git a/csrc/flashinfer_ops.cu b/csrc/flashinfer_ops.cu index 3d84bfc4..75c78e25 100644 --- a/csrc/flashinfer_ops.cu +++ b/csrc/flashinfer_ops.cu @@ -254,6 +254,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("apply_rope_pos_ids", &apply_rope_pos_ids, "Apply RoPE with positional ids"); m.def("apply_llama31_rope_pos_ids", &apply_llama31_rope_pos_ids, "Apply Llama 3.1 style RoPE with positional ids"); + m.def("apply_rope_pos_ids_cos_sin_cache", &apply_rope_pos_ids_cos_sin_cache, + "Apply RoPE with positional ids and cosine/sine cache"); // sampling m.def("sampling_from_probs", &sampling_from_probs, "Sample from probabilities"); diff --git a/flashinfer/rope.py b/flashinfer/rope.py index 17266b4e..8c781b29 100644 --- a/flashinfer/rope.py +++ b/flashinfer/rope.py @@ -56,6 +56,8 @@ def _apply_rope( rope_theta: float, ) -> None: with q.device as device: + indptr = indptr.int() + offsets = offsets.int() get_rope_module().apply_rope( q, k, @@ -104,6 +106,8 @@ def _apply_llama31_rope( old_context_len: float, ) -> None: with q.device as device: + indptr = indptr.int() + offsets = offsets.int() get_rope_module().apply_llama31_rope( q, k, @@ -154,6 +158,7 @@ def _apply_rope_pos_ids( rope_theta: float, ) -> None: with q.device as device: + pos_ids = pos_ids.int() get_rope_module().apply_rope_pos_ids( q, k, @@ -197,6 +202,7 @@ def _apply_rope_pos_ids_cos_sin_cache( interleave: bool, ) -> None: with q.device as device: + pos_ids = pos_ids.int() get_rope_module().apply_rope_pos_ids_cos_sin_cache( q, k, @@ -242,6 +248,7 @@ def _apply_llama31_rope_pos_ids( old_context_len: float, ) -> None: with q.device as device: + pos_ids = pos_ids.int() get_rope_module().apply_llama31_rope_pos_ids( q, k,