forked from vllm-project/vllm
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Kernel/Model] Migrate mamba_ssm and causal_conv1d kernels to vLLM (v…
- Loading branch information
Showing
20 changed files
with
2,815 additions
and
31 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,144 @@ | ||
/****************************************************************************** | ||
* Copyright (c) 2024, Tri Dao. | ||
******************************************************************************/ | ||
// clang-format off | ||
// adapted from https://github.com/Dao-AILab/causal-conv1d/blob/main/csrc/causal_conv1d.h | ||
#pragma once | ||
|
||
#include <cuda_bf16.h> | ||
#include <cuda_fp16.h> | ||
//////////////////////////////////////////////////////////////////////////////////////////////////// | ||
|
||
struct ConvParamsBase { | ||
using index_t = uint32_t; | ||
|
||
int batch, dim, seqlen, width; | ||
bool silu_activation; | ||
|
||
index_t x_batch_stride; | ||
index_t x_c_stride; | ||
index_t x_l_stride; | ||
index_t weight_c_stride; | ||
index_t weight_width_stride; | ||
index_t out_batch_stride; | ||
index_t out_c_stride; | ||
index_t out_l_stride; | ||
|
||
index_t conv_state_batch_stride; | ||
index_t conv_state_c_stride; | ||
index_t conv_state_l_stride; | ||
|
||
// Common data pointers. | ||
void *__restrict__ x_ptr; | ||
void *__restrict__ weight_ptr; | ||
void *__restrict__ bias_ptr; | ||
void *__restrict__ out_ptr; | ||
|
||
void *__restrict__ conv_state_ptr; | ||
|
||
void *__restrict__ seq_idx_ptr; | ||
|
||
// No __restrict__ since initial_states could be the same as final_states. | ||
void * initial_states_ptr; | ||
index_t initial_states_batch_stride; | ||
index_t initial_states_l_stride; | ||
index_t initial_states_c_stride; | ||
|
||
void * final_states_ptr; | ||
index_t final_states_batch_stride; | ||
index_t final_states_l_stride; | ||
index_t final_states_c_stride; | ||
}; | ||
|
||
|
||
#ifndef USE_ROCM | ||
#include <cuda_bf16.h> | ||
|
||
template<typename T> | ||
__device__ inline T shuffle_xor(T val, int offset) { | ||
return __shfl_xor_sync(uint32_t(-1), val, offset); | ||
} | ||
|
||
constexpr size_t custom_max(std::initializer_list<size_t> ilist) | ||
{ | ||
return std::max(ilist); | ||
} | ||
|
||
template<typename T> | ||
constexpr T constexpr_min(T a, T b) { | ||
return std::min(a, b); | ||
} | ||
|
||
#else | ||
#include <hip/hip_bf16.h> | ||
|
||
template<typename T> | ||
__device__ inline T shuffle_xor(T val, int offset) { | ||
return __shfl_xor(val, offset); | ||
} | ||
constexpr size_t custom_max(std::initializer_list<size_t> ilist) | ||
{ | ||
return *std::max_element(ilist.begin(), ilist.end()); | ||
} | ||
|
||
template<typename T> | ||
constexpr T constexpr_min(T a, T b) { | ||
return a < b ? a : b; | ||
} | ||
#endif | ||
|
||
//////////////////////////////////////////////////////////////////////////////////////////////////// | ||
|
||
template<int BYTES> struct BytesToType {}; | ||
|
||
template<> struct BytesToType<16> { | ||
using Type = uint4; | ||
static_assert(sizeof(Type) == 16); | ||
}; | ||
|
||
template<> struct BytesToType<8> { | ||
using Type = uint64_t; | ||
static_assert(sizeof(Type) == 8); | ||
}; | ||
|
||
template<> struct BytesToType<4> { | ||
using Type = uint32_t; | ||
static_assert(sizeof(Type) == 4); | ||
}; | ||
|
||
template<> struct BytesToType<2> { | ||
using Type = uint16_t; | ||
static_assert(sizeof(Type) == 2); | ||
}; | ||
|
||
template<> struct BytesToType<1> { | ||
using Type = uint8_t; | ||
static_assert(sizeof(Type) == 1); | ||
}; | ||
|
||
//////////////////////////////////////////////////////////////////////////////////////////////////// | ||
|
||
template<typename T> | ||
struct SumOp { | ||
__device__ inline T operator()(T const & x, T const & y) { return x + y; } | ||
}; | ||
|
||
template<int THREADS> | ||
struct Allreduce { | ||
static_assert(THREADS == 32 || THREADS == 16 || THREADS == 8 || THREADS == 4); | ||
template<typename T, typename Operator> | ||
static __device__ inline T run(T x, Operator &op) { | ||
constexpr int OFFSET = THREADS / 2; | ||
x = op(x, __shfl_xor_sync(uint32_t(-1), x, OFFSET)); | ||
return Allreduce<OFFSET>::run(x, op); | ||
} | ||
}; | ||
|
||
template<> | ||
struct Allreduce<2> { | ||
template<typename T, typename Operator> | ||
static __device__ inline T run(T x, Operator &op) { | ||
x = op(x, __shfl_xor_sync(uint32_t(-1), x, 1)); | ||
return x; | ||
} | ||
}; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
// Inspired by | ||
// https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h | ||
// and https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h | ||
// clang-format off | ||
// adapted from https://github.com/Dao-AILab/causal-conv1d/blob/main/csrc/static_switch.h | ||
|
||
#pragma once | ||
|
||
/// @param COND - a boolean expression to switch by | ||
/// @param CONST_NAME - a name given for the constexpr bool variable. | ||
/// @param ... - code to execute for true and false | ||
/// | ||
/// Usage: | ||
/// ``` | ||
/// BOOL_SWITCH(flag, BoolConst, [&] { | ||
/// some_function<BoolConst>(...); | ||
/// }); | ||
/// ``` | ||
#define BOOL_SWITCH(COND, CONST_NAME, ...) \ | ||
[&] { \ | ||
if (COND) { \ | ||
static constexpr bool CONST_NAME = true; \ | ||
return __VA_ARGS__(); \ | ||
} else { \ | ||
static constexpr bool CONST_NAME = false; \ | ||
return __VA_ARGS__(); \ | ||
} \ | ||
}() |
Oops, something went wrong.