diff --git a/CMakeLists.txt b/CMakeLists.txt index 5b0d0ba904c32..923ed084ffd9e 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -203,6 +203,8 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") FetchContent_MakeAvailable(cutlass) list(APPEND VLLM_EXT_SRC + "csrc/mamba/mamba_ssm/selective_scan_fwd.cu" + "csrc/mamba/causal_conv1d/causal_conv1d.cu" "csrc/quantization/aqlm/gemm_kernels.cu" "csrc/quantization/awq/gemm_kernels.cu" "csrc/quantization/marlin/dense/marlin_cuda_kernel.cu" diff --git a/Dockerfile b/Dockerfile index 36fcc2f83e9fb..9bae9a12c0eb2 100644 --- a/Dockerfile +++ b/Dockerfile @@ -42,9 +42,6 @@ COPY requirements-cuda.txt requirements-cuda.txt RUN --mount=type=cache,target=/root/.cache/pip \ python3 -m pip install -r requirements-cuda.txt -COPY requirements-mamba.txt requirements-mamba.txt -RUN python3 -m pip install packaging -RUN python3 -m pip install -r requirements-mamba.txt # cuda arch list used by torch # can be useful for both `dev` and `test` @@ -127,22 +124,6 @@ RUN --mount=type=cache,target=/root/.cache/pip \ python3 -m pip install -r requirements-dev.txt #################### DEV IMAGE #################### -#################### MAMBA Build IMAGE #################### -FROM dev as mamba-builder -# max jobs used for build -ARG max_jobs=2 -ENV MAX_JOBS=${max_jobs} - -WORKDIR /usr/src/mamba - -COPY requirements-mamba.txt requirements-mamba.txt - -# Download the wheel or build it if a pre-compiled release doesn't exist -RUN pip --verbose wheel -r requirements-mamba.txt \ - --no-build-isolation --no-deps --no-cache-dir - -#################### MAMBA Build IMAGE #################### - #################### vLLM installation IMAGE #################### # image with vLLM installed FROM nvidia/cuda:${CUDA_VERSION}-base-ubuntu20.04 AS vllm-base @@ -179,10 +160,6 @@ RUN --mount=type=bind,from=build,src=/workspace/dist,target=/vllm-workspace/dist --mount=type=cache,target=/root/.cache/pip \ python3 -m pip install dist/*.whl --verbose -RUN --mount=type=bind,from=mamba-builder,src=/usr/src/mamba,target=/usr/src/mamba \ - --mount=type=cache,target=/root/.cache/pip \ - python3 -m pip install /usr/src/mamba/*.whl --no-cache-dir - RUN --mount=type=cache,target=/root/.cache/pip \ . /etc/environment && \ python3 -m pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.1.4/flashinfer-0.1.4+cu121torch2.4-cp${PYTHON_VERSION_STR}-cp${PYTHON_VERSION_STR}-linux_x86_64.whl diff --git a/csrc/mamba/causal_conv1d/causal_conv1d.cu b/csrc/mamba/causal_conv1d/causal_conv1d.cu new file mode 100644 index 0000000000000..88a64a8ece585 --- /dev/null +++ b/csrc/mamba/causal_conv1d/causal_conv1d.cu @@ -0,0 +1,700 @@ +// clang-format off +// adapted from https://github.com/Dao-AILab/causal-conv1d/blob/main/csrc/causal_conv1d_fwd.cu +// and https://github.com/Dao-AILab/causal-conv1d/blob/main/csrc/causal_conv1d_update.cu +#include +#include +#include + +#include "causal_conv1d.h" +#include +#include +#include // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK + +#include +#include + +#include "static_switch.h" + + + +#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")") + +#define DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16(ITYPE, NAME, ...) \ + if (ITYPE == at::ScalarType::Half) { \ + using input_t = at::Half; \ + using weight_t = at::Half; \ + __VA_ARGS__(); \ + } else if (ITYPE == at::ScalarType::BFloat16) { \ + using input_t = at::BFloat16; \ + using weight_t = at::BFloat16; \ + __VA_ARGS__(); \ + } else if (ITYPE == at::ScalarType::Float) { \ + using input_t = float; \ + using weight_t = float; \ + __VA_ARGS__(); \ + } else { \ + AT_ERROR(#NAME, " not implemented for input type '", toString(ITYPE), "'"); \ + } + + +template +void causal_conv1d_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream); +template +void causal_conv1d_channellast_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream); + +template +void causal_conv1d_update_cuda(ConvParamsBase ¶ms, cudaStream_t stream); + +void set_conv_params_fwd(ConvParamsBase ¶ms, + // sizes + const size_t batch, + const size_t dim, + const size_t seqlen, + const size_t width, + // device pointers + const at::Tensor x, + const at::Tensor weight, + const at::Tensor out, + void* bias_ptr, + bool silu_activation) { + + // Reset the parameters + memset(¶ms, 0, sizeof(params)); + + params.batch = batch; + params.dim = dim; + params.seqlen = seqlen; + params.width = width; + + params.silu_activation = silu_activation; + + // Set the pointers and strides. + params.x_ptr = x.data_ptr(); + params.weight_ptr = weight.data_ptr(); + params.bias_ptr = bias_ptr; + params.out_ptr = out.data_ptr(); + // All stride are in elements, not bytes. + params.x_batch_stride = x.stride(0); + params.x_c_stride = x.stride(1); + params.x_l_stride = x.stride(-1); + params.weight_c_stride = weight.stride(0); + params.weight_width_stride = weight.stride(1); + params.out_batch_stride = out.stride(0); + params.out_c_stride = out.stride(1); + params.out_l_stride = out.stride(-1); +} + + +at::Tensor +causal_conv1d_fwd(const at::Tensor &x, const at::Tensor &weight, + const c10::optional &bias_, + const c10::optional &seq_idx_, + const c10::optional &initial_states_, + const c10::optional &final_states_out_, + bool silu_activation) { + auto input_type = x.scalar_type(); + auto weight_type = weight.scalar_type(); + TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16); + TORCH_CHECK(weight_type == at::ScalarType::Float || weight_type == at::ScalarType::Half || weight_type == at::ScalarType::BFloat16); + + TORCH_CHECK(x.is_cuda()); + TORCH_CHECK(weight.is_cuda()); + + const auto sizes = x.sizes(); + const int batch_size = sizes[0]; + const int dim = sizes[1]; + const int seqlen = sizes[2]; + const int width = weight.size(-1); + + CHECK_SHAPE(x, batch_size, dim, seqlen); + CHECK_SHAPE(weight, dim, width); + + TORCH_CHECK(x.stride(2) == 1 || x.stride(1) == 1); + const bool is_channel_last = x.stride(1) == 1 && x.stride(2) > 1; + + if (is_channel_last) { + TORCH_CHECK(dim % 8 == 0, "causal_conv1d only supports channel dimension divisible by 8 for now"); + TORCH_CHECK(x.stride(2) % 8 == 0 and x.stride(0) % 8 == 0, "causal_conv1d with channel last layout requires strides (x.stride(0) and x.stride(2)) to be multiples of 8"); + } + TORCH_CHECK(width >= 2 && width <= 4, "causal_conv1d only supports width between 2 and 4"); + + if (bias_.has_value()) { + auto bias = bias_.value(); + TORCH_CHECK(bias.scalar_type() == weight_type); + TORCH_CHECK(bias.is_cuda()); + TORCH_CHECK(bias.stride(-1) == 1); + CHECK_SHAPE(bias, dim); + } + + if (seq_idx_.has_value()) { + TORCH_CHECK(is_channel_last, "seq_idx is only supported for channel last layout"); + auto seq_idx = seq_idx_.value(); + TORCH_CHECK(seq_idx.scalar_type() == torch::kInt32); + TORCH_CHECK(seq_idx.is_cuda()); + TORCH_CHECK(seq_idx.is_contiguous()); + CHECK_SHAPE(seq_idx, batch_size, seqlen); + } + + at::Tensor out = torch::empty_like(x); + + ConvParamsBase params; + set_conv_params_fwd(params, batch_size, dim, seqlen, width, x, weight, out, + bias_.has_value() ? bias_.value().data_ptr() : nullptr, + silu_activation); + + if (seq_idx_.has_value()) { + params.seq_idx_ptr = seq_idx_.value().data_ptr(); + } else { + params.seq_idx_ptr = nullptr; + } + + if (initial_states_.has_value()) { + TORCH_CHECK(is_channel_last, "initial_states is only supported for channel last layout"); + auto initial_states = initial_states_.value(); + TORCH_CHECK(initial_states.scalar_type() == input_type); + TORCH_CHECK(initial_states.is_cuda()); + CHECK_SHAPE(initial_states, batch_size, dim, width - 1); + TORCH_CHECK(initial_states.stride(1) == 1); + params.initial_states_ptr = initial_states.data_ptr(); + params.initial_states_batch_stride = initial_states.stride(0); + params.initial_states_c_stride = initial_states.stride(1); + params.initial_states_l_stride = initial_states.stride(2); + } else { + params.initial_states_ptr = nullptr; + } + + if (final_states_out_.has_value()) { + TORCH_CHECK(is_channel_last, "final_states is only supported for channel last layout"); + auto final_states = final_states_out_.value(); + TORCH_CHECK(final_states.scalar_type() == input_type); + TORCH_CHECK(final_states.is_cuda()); + CHECK_SHAPE(final_states, batch_size, dim, width - 1); + TORCH_CHECK(final_states.stride(1) == 1); + params.final_states_ptr = final_states.data_ptr(); + params.final_states_batch_stride = final_states.stride(0); + params.final_states_c_stride = final_states.stride(1); + params.final_states_l_stride = final_states.stride(2); + } else { + params.final_states_ptr = nullptr; + } + + // Otherwise the kernel will be launched from cuda:0 device + // Cast to char to avoid compiler warning about narrowing + at::cuda::CUDAGuard device_guard{(char)x.get_device()}; + auto stream = at::cuda::getCurrentCUDAStream().stream(); + DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16(x.scalar_type(), "causal_conv1d_fwd", [&] { + if (!is_channel_last) { + causal_conv1d_fwd_cuda(params, stream); + } else { + causal_conv1d_channellast_fwd_cuda(params, stream); + } + }); + return out; +} + + +at::Tensor +causal_conv1d_update(const at::Tensor &x, + const at::Tensor &conv_state, + const at::Tensor &weight, + const c10::optional &bias_, + bool silu_activation) { + auto input_type = x.scalar_type(); + auto weight_type = weight.scalar_type(); + TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16); + TORCH_CHECK(weight_type == at::ScalarType::Float || weight_type == at::ScalarType::Half || weight_type == at::ScalarType::BFloat16); + TORCH_CHECK(weight_type == input_type, "weight type must equal to input type, other variations are disabled due to binary size limitations"); + TORCH_CHECK(conv_state.scalar_type() == input_type); + + TORCH_CHECK(x.is_cuda()); + TORCH_CHECK(conv_state.is_cuda()); + TORCH_CHECK(weight.is_cuda()); + + const auto sizes = x.sizes(); + const int batch_size = sizes[0]; + const int dim = sizes[1]; + const int width = weight.size(-1); + + CHECK_SHAPE(x, batch_size, dim); + CHECK_SHAPE(conv_state, batch_size, dim, width); + CHECK_SHAPE(weight, dim, width); + + TORCH_CHECK(width >= 2 && width <= 4, "causal_conv1d only supports width between 2 and 4"); + + if (bias_.has_value()) { + auto bias = bias_.value(); + TORCH_CHECK(bias.scalar_type() == weight_type); + TORCH_CHECK(bias.is_cuda()); + TORCH_CHECK(bias.stride(-1) == 1); + CHECK_SHAPE(bias, dim); + } + + at::Tensor out = torch::empty_like(x); + + ConvParamsBase params; + set_conv_params_fwd(params, batch_size, dim, /*seqlen=*/1, width, x, weight, out, + bias_.has_value() ? bias_.value().data_ptr() : nullptr, + silu_activation); + params.conv_state_ptr = conv_state.data_ptr(); + // All stride are in elements, not bytes. + params.conv_state_batch_stride = conv_state.stride(0); + params.conv_state_c_stride = conv_state.stride(1); + params.conv_state_l_stride = conv_state.stride(2); + + // Otherwise the kernel will be launched from cuda:0 device + // Cast to char to avoid compiler warning about narrowing + at::cuda::CUDAGuard device_guard{(char)x.get_device()}; + auto stream = at::cuda::getCurrentCUDAStream().stream(); + DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16(x.scalar_type(), "causal_conv1d_update", [&] { + causal_conv1d_update_cuda(params, stream); + }); + return out; +} + +template +struct Causal_conv1d_fwd_kernel_traits { + using input_t = input_t_; + using weight_t = weight_t_; + static constexpr int kNThreads = kNThreads_; + static constexpr int kWidth = kWidth_; + static constexpr int kNBytes = sizeof(input_t); + static_assert(kNBytes == 2 || kNBytes == 4); + static constexpr int kNElts = kNBytes == 4 ? 4 : 8; + static_assert(kWidth <= kNElts); + static constexpr bool kIsVecLoad = kIsVecLoad_; + using vec_t = typename BytesToType::Type; + using BlockLoadT = cub::BlockLoad; + using BlockLoadVecT = cub::BlockLoad; + using BlockStoreT = cub::BlockStore; + using BlockStoreVecT = cub::BlockStore; + static constexpr int kSmemIOSize = kIsVecLoad + ? 0 + : custom_max({sizeof(typename BlockLoadT::TempStorage), sizeof(typename BlockStoreT::TempStorage)}); + static constexpr int kSmemExchangeSize = kNThreads * kNBytes * kNElts; + static constexpr int kSmemSize = kSmemIOSize + kSmemExchangeSize; +}; + +template +__global__ __launch_bounds__(Ktraits::kNThreads) +void causal_conv1d_fwd_kernel(ConvParamsBase params) { + constexpr int kWidth = Ktraits::kWidth; + constexpr int kNThreads = Ktraits::kNThreads; + constexpr int kNElts = Ktraits::kNElts; + static constexpr bool kIsVecLoad = Ktraits::kIsVecLoad; + using input_t = typename Ktraits::input_t; + using vec_t = typename Ktraits::vec_t; + using weight_t = typename Ktraits::weight_t; + + // Shared memory. + extern __shared__ char smem_[]; + auto& smem_load = reinterpret_cast(smem_); + auto& smem_load_vec = reinterpret_cast(smem_); + auto& smem_store = reinterpret_cast(smem_); + auto& smem_store_vec = reinterpret_cast(smem_); + vec_t *smem_exchange = reinterpret_cast(smem_ + Ktraits::kSmemIOSize); + + const int tidx = threadIdx.x; + const int batch_id = blockIdx.x; + const int channel_id = blockIdx.y; + input_t *x = reinterpret_cast(params.x_ptr) + batch_id * params.x_batch_stride + + channel_id * params.x_c_stride; + weight_t *weight = reinterpret_cast(params.weight_ptr) + channel_id * params.weight_c_stride; + input_t *out = reinterpret_cast(params.out_ptr) + batch_id * params.out_batch_stride + + channel_id * params.out_c_stride; + float bias_val = params.bias_ptr == nullptr ? 0.f : float(reinterpret_cast(params.bias_ptr)[channel_id]); + + // Thread 0 will load the last elements of the previous chunk, so we initialize those to 0. + if (tidx == 0) { + input_t zeros[kNElts] = {0}; + smem_exchange[kNThreads - 1] = reinterpret_cast(zeros)[0]; + } + + float weight_vals[kWidth]; + #pragma unroll + for (int i = 0; i < kWidth; ++i) { weight_vals[i] = float(weight[i * params.weight_width_stride]); } + + constexpr int kChunkSize = kNThreads * kNElts; + const int n_chunks = (params.seqlen + kChunkSize - 1) / kChunkSize; + for (int chunk = 0; chunk < n_chunks; ++chunk) { + input_t x_vals_load[2 * kNElts] = {0}; + if constexpr(kIsVecLoad) { + typename Ktraits::BlockLoadVecT(smem_load_vec).Load(reinterpret_cast(x), *reinterpret_cast(&x_vals_load[kNElts]), (params.seqlen - chunk * kChunkSize) / kNElts); + } else { + __syncthreads(); + typename Ktraits::BlockLoadT(smem_load).Load(x, *reinterpret_cast(&x_vals_load[kNElts]), params.seqlen - chunk * kChunkSize); + } + x += kChunkSize; + __syncthreads(); + // Thread kNThreads - 1 don't write yet, so that thread 0 can read + // the last elements of the previous chunk. + if (tidx < kNThreads - 1) { smem_exchange[tidx] = reinterpret_cast(x_vals_load)[1]; } + __syncthreads(); + reinterpret_cast(x_vals_load)[0] = smem_exchange[tidx > 0 ? tidx - 1 : kNThreads - 1]; + __syncthreads(); + // Now thread kNThreads - 1 can write the last elements of the current chunk. + if (tidx == kNThreads - 1) { smem_exchange[tidx] = reinterpret_cast(x_vals_load)[1]; } + + float x_vals[2 * kNElts]; + #pragma unroll + for (int i = 0; i < 2 * kNElts; ++i) { x_vals[i] = float(x_vals_load[i]); } + + float out_vals[kNElts]; + #pragma unroll + for (int i = 0; i < kNElts; ++i) { + out_vals[i] = bias_val; + #pragma unroll + for (int w = 0; w < kWidth; ++w) { + out_vals[i] += weight_vals[w] * x_vals[kNElts + i - (kWidth - w - 1)]; + } + } + + if (params.silu_activation) { + #pragma unroll + for (int i = 0; i < kNElts; ++i) { + out_vals[i] = out_vals[i] / (1 + expf(-out_vals[i])); + } + } + + input_t out_vals_store[kNElts]; + #pragma unroll + for (int i = 0; i < kNElts; ++i) { out_vals_store[i] = out_vals[i]; } + if constexpr(kIsVecLoad) { + typename Ktraits::BlockStoreVecT(smem_store_vec).Store(reinterpret_cast(out), reinterpret_cast(out_vals_store), (params.seqlen - chunk * kChunkSize) / kNElts); + } else { + typename Ktraits::BlockStoreT(smem_store).Store(out, out_vals_store, params.seqlen - chunk * kChunkSize); + } + out += kChunkSize; + } +} + + +template +void causal_conv1d_fwd_launch(ConvParamsBase ¶ms, cudaStream_t stream) { + static constexpr int kNElts = sizeof(input_t) == 4 ? 4 : 8; + BOOL_SWITCH(params.seqlen % kNElts == 0, kIsVecLoad, [&] { + using Ktraits = Causal_conv1d_fwd_kernel_traits; + constexpr int kSmemSize = Ktraits::kSmemSize; + dim3 grid(params.batch, params.dim); + + auto kernel = &causal_conv1d_fwd_kernel; + + if (kSmemSize >= 48 * 1024) { + #ifndef USE_ROCM + C10_CUDA_CHECK(cudaFuncSetAttribute( + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize)); + #else + // There is a slight signature discrepancy in HIP and CUDA "FuncSetAttribute" function. + C10_CUDA_CHECK(cudaFuncSetAttribute( + (void *) kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize)); + std::cerr << "Warning (causal_conv1d fwd launch): attempting to set maxDynamicSharedMemorySize on an AMD GPU which is currently a non-op (in ROCm versions <= 6.1). This might lead to undefined behavior. \n" << std::endl; + #endif + } + kernel<<>>(params); + + C10_CUDA_KERNEL_LAUNCH_CHECK(); + }); +} + +template +void causal_conv1d_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream) { + if (params.width == 2) { + causal_conv1d_fwd_launch<128, 2, input_t, weight_t>(params, stream); + } else if (params.width == 3) { + causal_conv1d_fwd_launch<128, 3, input_t, weight_t>(params, stream); + } else if (params.width == 4) { + causal_conv1d_fwd_launch<128, 4, input_t, weight_t>(params, stream); + } +} + +template +struct Causal_conv1d_channellast_fwd_kernel_traits { + // The cache line is 128 bytes, and we try to read 16 bytes per thread. + // So we have 8 threads per "row", so 32 or 64 elements in the channel dimension. + // That leaves 4 columns per warp, and so 16 columns per block (assuming each block has 128 + // threads). Each each load is 16 x 32|64 elements in the L x C dimensions. + using input_t = input_t_; + using weight_t = weight_t_; + static constexpr int kNThreads = kNThreads_; + static_assert(kNThreads % 32 == 0); + static constexpr int kNWarps = kNThreads / 32; + static constexpr int kWidth = kWidth_; + static constexpr int kChunkSizeL = kChunkSizeL_; + static constexpr int kNBytes = sizeof(input_t); + static_assert(kNBytes == 2 || kNBytes == 4); + static constexpr int kNElts = kNBytes == 4 ? 4 : 8; + static constexpr int kNEltsPerRow = 128 / kNBytes; + static constexpr int kNThreadsPerRow = kNEltsPerRow / kNElts; // Always 8 for now + static_assert(kNThreadsPerRow * kNBytes * kNElts == 128); + static constexpr int kNColsPerWarp = 32 / kNThreadsPerRow; // Always 4 for now + static_assert(kNColsPerWarp * kNThreadsPerRow == 32); + static constexpr int kNColsPerLoad = kNColsPerWarp * kNWarps; + static constexpr int kNLoads = kChunkSizeL / kNColsPerLoad; + static_assert(kNLoads * kNColsPerLoad == kChunkSizeL); + static constexpr bool kIsVecLoad = kIsVecLoad_; + using vec_t = typename BytesToType::Type; + // using BlockLoadT = cub::BlockLoad; + // using BlockStoreT = cub::BlockStore; + // static constexpr int kSmemSize = std::max({sizeof(typename BlockLoadT::TempStorage), + // sizeof(typename BlockStoreT::TempStorage)}); + // static constexpr int kSmemSize = kChunkSizeL * kNEltsPerRow * kNBytes; +}; + +template +__global__ __launch_bounds__(Ktraits::kNThreads) +void causal_conv1d_channellast_fwd_kernel(ConvParamsBase params) { + constexpr int kWidth = Ktraits::kWidth; + constexpr int kNThreads = Ktraits::kNThreads; + constexpr int kNElts = Ktraits::kNElts; + constexpr int kNThreadsPerC = Ktraits::kNThreadsPerRow; + constexpr int kLPerLoad = Ktraits::kNColsPerLoad; + constexpr int kChunkSizeL = Ktraits::kChunkSizeL; + constexpr int kChunkSizeC = Ktraits::kNEltsPerRow; + using input_t = typename Ktraits::input_t; + using vec_t = typename Ktraits::vec_t; + using weight_t = typename Ktraits::weight_t; + + // Shared memory. + __shared__ input_t x_smem[kWidth - 1 + kChunkSizeL][kChunkSizeC + kNElts]; + + const int batch_id = blockIdx.x; + const int chunk_l_id = blockIdx.y; + const int chunk_c_id = blockIdx.z; + const int tid = threadIdx.x; + const int l_idx = tid / kNThreadsPerC; + const int c_idx = tid % kNThreadsPerC; + input_t *x = reinterpret_cast(params.x_ptr) + batch_id * params.x_batch_stride + + (chunk_l_id * kChunkSizeL + l_idx) * params.x_l_stride + chunk_c_id * kChunkSizeC + c_idx * kNElts; + weight_t *weight = reinterpret_cast(params.weight_ptr) + + chunk_c_id * kChunkSizeC * params.weight_c_stride; + input_t *out = reinterpret_cast(params.out_ptr) + batch_id * params.out_batch_stride + + (chunk_l_id * kChunkSizeL + l_idx) * params.out_l_stride + chunk_c_id * kChunkSizeC + c_idx * kNElts; + int *seq_idx = !kHasSeqIdx ? nullptr : reinterpret_cast(params.seq_idx_ptr) + + batch_id * params.seqlen + chunk_l_id * kChunkSizeL; + input_t *initial_states = params.initial_states_ptr == nullptr || chunk_l_id > 0 ? nullptr + : reinterpret_cast(params.initial_states_ptr) + batch_id * params.initial_states_batch_stride + l_idx * params.initial_states_l_stride + chunk_c_id * kChunkSizeC + c_idx * kNElts; + // The last L-chunk will also have enough info to write to final states, since it also contain a few x values + // from the previous L-chunk. + input_t *final_states = params.final_states_ptr == nullptr || chunk_l_id < gridDim.y - 1 ? nullptr + : reinterpret_cast(params.final_states_ptr) + batch_id * params.final_states_batch_stride + l_idx * params.final_states_l_stride + chunk_c_id * kChunkSizeC + c_idx * kNElts; + + #pragma unroll + for (int l = 0; l < Ktraits::kNLoads; ++l) { + input_t x_vals_load[kNElts] = {0}; + if (chunk_l_id * kChunkSizeL + l * kLPerLoad + l_idx < params.seqlen + && chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) { + reinterpret_cast(x_vals_load)[0] = *reinterpret_cast(x + l * kLPerLoad * params.x_l_stride); + } + reinterpret_cast(x_smem[kWidth - 1 + l * kLPerLoad + l_idx])[c_idx] = reinterpret_cast(x_vals_load)[0]; + } + // Load the elements from the previous chunk that are needed for convolution. + if (l_idx < kWidth - 1) { + input_t x_vals_load[kNElts] = {0}; + if (chunk_l_id * kChunkSizeL + l_idx - (kWidth - 1) >= 0 + && chunk_l_id * kChunkSizeL + l_idx - (kWidth - 1) < params.seqlen + && chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) { + reinterpret_cast(x_vals_load)[0] = *reinterpret_cast(x - (kWidth - 1) * params.x_l_stride); + } else if (initial_states != nullptr + && chunk_l_id * kChunkSizeL + l_idx - (kWidth - 1) < 0 + && chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) { + reinterpret_cast(x_vals_load)[0] = *reinterpret_cast(initial_states); + } + reinterpret_cast(x_smem[l_idx])[c_idx] = reinterpret_cast(x_vals_load)[0]; + } + + __syncthreads(); + + if (final_states != nullptr + && l_idx < kWidth - 1 + && chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) { + // x_smem[0] contains element at index chunk_l_id * kChunkSizeL - (kWidth - 1) + // So last few elements (index params.seqlen - kWidth + 1 + l_idx) are stored in x_smem[params.seqlen - kWidth + 1 + l_idx - (chunk_l_id * kChunkSizeL - kWidth + 1)][c_idx] + *reinterpret_cast(final_states) = reinterpret_cast(x_smem[params.seqlen + l_idx - chunk_l_id * kChunkSizeL])[c_idx]; + } + + constexpr int kLPerThread = constexpr_min(kChunkSizeL * kChunkSizeC / kNThreads, kChunkSizeL); + static_assert(kLPerThread * kNThreads == kChunkSizeL * kChunkSizeC); + constexpr int kNThreadsPerRow = kChunkSizeL / kLPerThread; + static_assert(kNThreadsPerRow * kLPerThread == kChunkSizeL); + // kChunkSizeL, kLPerThread, kNThreadsPerRow should be powers of 2 for simplicity + static_assert((kChunkSizeL & (kChunkSizeL - 1)) == 0); + static_assert((kLPerThread & (kLPerThread - 1)) == 0); + static_assert((kNThreadsPerRow & (kNThreadsPerRow - 1)) == 0); + static_assert(kNThreadsPerRow <= 32); + + const int row_idx = tid / kNThreadsPerRow; + const int col_idx = tid % kNThreadsPerRow; + + float bias_val = params.bias_ptr == nullptr || chunk_c_id * kChunkSizeC + row_idx >= params.dim ? 0.f : float(reinterpret_cast(params.bias_ptr)[chunk_c_id * kChunkSizeC + row_idx]); + float weight_vals[kWidth] = {0}; + if (chunk_c_id * kChunkSizeC + row_idx < params.dim) { + #pragma unroll + for (int w = 0; w < kWidth; ++w) { + weight_vals[w] = weight[row_idx * params.weight_c_stride + w * params.weight_width_stride]; + } + } + float x_vals[kWidth - 1 + kLPerThread]; + #pragma unroll + for (int i = 0; i < kWidth - 1 + kLPerThread; ++i) { + x_vals[i] = float(x_smem[col_idx * kLPerThread + i][row_idx]); + } + int seq_idx_thread[kWidth - 1 + kLPerThread]; + if constexpr (kHasSeqIdx) { + #pragma unroll + for (int i = 0; i < kWidth - 1 + kLPerThread; ++i) { + seq_idx_thread[i] = chunk_l_id * kChunkSizeL + col_idx * kLPerThread + i - (kWidth - 1) >= 0 ? seq_idx[col_idx * kLPerThread + i - (kWidth - 1)] : -1; + } + } + + float out_vals[kLPerThread]; + #pragma unroll + for (int i = 0; i < kLPerThread; ++i) { + out_vals[i] = bias_val; + const int seq_idx_cur = !kHasSeqIdx ? 0 : seq_idx_thread[i + kWidth - 1]; + #pragma unroll + for (int w = 0; w < kWidth; ++w) { + if constexpr (!kHasSeqIdx) { + out_vals[i] += weight_vals[w] * x_vals[i + w]; + } else { + out_vals[i] += seq_idx_thread[i + w] == seq_idx_cur ? weight_vals[w] * x_vals[i + w] : 0.f; + } + } + if (params.silu_activation) {out_vals[i] = out_vals[i] / (1 + expf(-out_vals[i])); } + } + + __syncthreads(); + #pragma unroll + for (int i = 0; i < kLPerThread; ++i) { x_smem[col_idx * kLPerThread + i][row_idx] = out_vals[i]; } + __syncthreads(); + + #pragma unroll + for (int l = 0; l < Ktraits::kNLoads; ++l) { + input_t out_vals_store[kNElts]; + reinterpret_cast(out_vals_store)[0] = reinterpret_cast(x_smem[l * kLPerLoad + l_idx])[c_idx]; + if (chunk_l_id * kChunkSizeL + l * kLPerLoad + l_idx < params.seqlen + && chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) { + *reinterpret_cast(out + l * kLPerLoad * params.out_l_stride) = reinterpret_cast(out_vals_store)[0]; + } + } + +} + +template +void causal_conv1d_channellast_fwd_launch(ConvParamsBase ¶ms, cudaStream_t stream) { + BOOL_SWITCH(params.seq_idx_ptr != nullptr, kHasSeqIdx, [&] { + using Ktraits = Causal_conv1d_channellast_fwd_kernel_traits; + // constexpr int kSmemSize = Ktraits::kSmemSize; + constexpr int kChunkSizeL = Ktraits::kChunkSizeL; + constexpr int kChunkSizeC = Ktraits::kNEltsPerRow; + const int n_chunks_L = (params.seqlen + kChunkSizeL - 1) / kChunkSizeL; + const int n_chunks_C = (params.dim + kChunkSizeC - 1) / kChunkSizeC; + dim3 grid(params.batch, n_chunks_L, n_chunks_C); + dim3 block(Ktraits::kNThreads); + auto kernel = &causal_conv1d_channellast_fwd_kernel; + // if (kSmemSize >= 48 * 1024) { + // C10_CUDA_CHECK(cudaFuncSetAttribute( + // kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize)); + // } + // kernel<<>>(params); + kernel<<>>(params); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + }); +} + +template +void causal_conv1d_channellast_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream) { + if (params.width == 2) { + causal_conv1d_channellast_fwd_launch<128, 2, input_t, weight_t>(params, stream); + } else if (params.width == 3) { + causal_conv1d_channellast_fwd_launch<128, 3, input_t, weight_t>(params, stream); + } else if (params.width == 4) { + causal_conv1d_channellast_fwd_launch<128, 4, input_t, weight_t>(params, stream); + } +} + +template void causal_conv1d_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream); +template void causal_conv1d_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream); +template void causal_conv1d_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream); + +template void causal_conv1d_channellast_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream); +template void causal_conv1d_channellast_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream); +template void causal_conv1d_channellast_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream); +/////// + + + + +template +struct Causal_conv1d_update_kernel_traits { + using input_t = input_t_; + using weight_t = weight_t_; + static constexpr int kNThreads = kNThreads_; + static constexpr int kWidth = kWidth_; + static constexpr int kNBytes = sizeof(input_t); + static_assert(kNBytes == 2 || kNBytes == 4); +}; + +template +__global__ __launch_bounds__(Ktraits::kNThreads) +void causal_conv1d_update_kernel(ConvParamsBase params) { + constexpr int kWidth = Ktraits::kWidth; + constexpr int kNThreads = Ktraits::kNThreads; + using input_t = typename Ktraits::input_t; + using weight_t = typename Ktraits::weight_t; + + const int tidx = threadIdx.x; + const int batch_id = blockIdx.x; + const int channel_id = blockIdx.y * kNThreads + tidx; + input_t *x = reinterpret_cast(params.x_ptr) + batch_id * params.x_batch_stride + + channel_id * params.x_c_stride; + input_t *conv_state = reinterpret_cast(params.conv_state_ptr) + batch_id * params.conv_state_batch_stride + + channel_id * params.conv_state_c_stride; + weight_t *weight = reinterpret_cast(params.weight_ptr) + channel_id * params.weight_c_stride; + input_t *out = reinterpret_cast(params.out_ptr) + batch_id * params.out_batch_stride + + channel_id * params.out_c_stride; + float bias_val = params.bias_ptr == nullptr || channel_id >= params.dim ? 0.f : float(reinterpret_cast(params.bias_ptr)[channel_id]); + + float weight_vals[kWidth] = {0}; + if (channel_id < params.dim) { + #pragma unroll + for (int i = 0; i < kWidth; ++i) { weight_vals[i] = float(weight[i * params.weight_width_stride]); } + } + + float x_vals[kWidth] = {0}; + if (channel_id < params.dim) { + #pragma unroll + for (int i = 0; i < kWidth - 1; ++i) { x_vals[i] = float(conv_state[(i + 1) * params.conv_state_l_stride]); } + x_vals[kWidth - 1] = float(x[0]); + #pragma unroll + for (int i = 0; i < kWidth; ++i) { conv_state[i * params.conv_state_l_stride] = input_t(x_vals[i]); } + } + + float out_val = bias_val; + #pragma unroll + for (int i = 0; i < kWidth; ++i) { out_val += weight_vals[i] * x_vals[i]; } + if (params.silu_activation) { out_val = out_val / (1 + expf(-out_val)); } + if (channel_id < params.dim) { out[0] = input_t(out_val); } +} + +template +void causal_conv1d_update_launch(ConvParamsBase ¶ms, cudaStream_t stream) { + using Ktraits = Causal_conv1d_update_kernel_traits; + dim3 grid(params.batch, (params.dim + kNThreads - 1) / kNThreads); + auto kernel = &causal_conv1d_update_kernel; + kernel<<>>(params); + C10_CUDA_KERNEL_LAUNCH_CHECK(); +} + +template +void causal_conv1d_update_cuda(ConvParamsBase ¶ms, cudaStream_t stream) { + if (params.width == 2) { + causal_conv1d_update_launch<64, 2, input_t, weight_t>(params, stream); + } else if (params.width == 3) { + causal_conv1d_update_launch<64, 3, input_t, weight_t>(params, stream); + } else if (params.width == 4) { + causal_conv1d_update_launch<64, 4, input_t, weight_t>(params, stream); + } +} + +template void causal_conv1d_update_cuda(ConvParamsBase ¶ms, cudaStream_t stream); +template void causal_conv1d_update_cuda(ConvParamsBase ¶ms, cudaStream_t stream); +template void causal_conv1d_update_cuda(ConvParamsBase ¶ms, cudaStream_t stream); diff --git a/csrc/mamba/causal_conv1d/causal_conv1d.h b/csrc/mamba/causal_conv1d/causal_conv1d.h new file mode 100644 index 0000000000000..bb25314c8bbbd --- /dev/null +++ b/csrc/mamba/causal_conv1d/causal_conv1d.h @@ -0,0 +1,144 @@ +/****************************************************************************** + * Copyright (c) 2024, Tri Dao. + ******************************************************************************/ +// clang-format off +// adapted from https://github.com/Dao-AILab/causal-conv1d/blob/main/csrc/causal_conv1d.h +#pragma once + +#include +#include +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct ConvParamsBase { + using index_t = uint32_t; + + int batch, dim, seqlen, width; + bool silu_activation; + + index_t x_batch_stride; + index_t x_c_stride; + index_t x_l_stride; + index_t weight_c_stride; + index_t weight_width_stride; + index_t out_batch_stride; + index_t out_c_stride; + index_t out_l_stride; + + index_t conv_state_batch_stride; + index_t conv_state_c_stride; + index_t conv_state_l_stride; + + // Common data pointers. + void *__restrict__ x_ptr; + void *__restrict__ weight_ptr; + void *__restrict__ bias_ptr; + void *__restrict__ out_ptr; + + void *__restrict__ conv_state_ptr; + + void *__restrict__ seq_idx_ptr; + + // No __restrict__ since initial_states could be the same as final_states. + void * initial_states_ptr; + index_t initial_states_batch_stride; + index_t initial_states_l_stride; + index_t initial_states_c_stride; + + void * final_states_ptr; + index_t final_states_batch_stride; + index_t final_states_l_stride; + index_t final_states_c_stride; +}; + + +#ifndef USE_ROCM + #include + + template + __device__ inline T shuffle_xor(T val, int offset) { + return __shfl_xor_sync(uint32_t(-1), val, offset); + } + + constexpr size_t custom_max(std::initializer_list ilist) + { + return std::max(ilist); + } + + template + constexpr T constexpr_min(T a, T b) { + return std::min(a, b); + } + +#else + #include + + template + __device__ inline T shuffle_xor(T val, int offset) { + return __shfl_xor(val, offset); + } + constexpr size_t custom_max(std::initializer_list ilist) + { + return *std::max_element(ilist.begin(), ilist.end()); + } + + template + constexpr T constexpr_min(T a, T b) { + return a < b ? a : b; + } +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template struct BytesToType {}; + +template<> struct BytesToType<16> { + using Type = uint4; + static_assert(sizeof(Type) == 16); +}; + +template<> struct BytesToType<8> { + using Type = uint64_t; + static_assert(sizeof(Type) == 8); +}; + +template<> struct BytesToType<4> { + using Type = uint32_t; + static_assert(sizeof(Type) == 4); +}; + +template<> struct BytesToType<2> { + using Type = uint16_t; + static_assert(sizeof(Type) == 2); +}; + +template<> struct BytesToType<1> { + using Type = uint8_t; + static_assert(sizeof(Type) == 1); +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct SumOp { +__device__ inline T operator()(T const & x, T const & y) { return x + y; } +}; + +template +struct Allreduce { + static_assert(THREADS == 32 || THREADS == 16 || THREADS == 8 || THREADS == 4); + template + static __device__ inline T run(T x, Operator &op) { + constexpr int OFFSET = THREADS / 2; + x = op(x, __shfl_xor_sync(uint32_t(-1), x, OFFSET)); + return Allreduce::run(x, op); + } +}; + +template<> +struct Allreduce<2> { +template +static __device__ inline T run(T x, Operator &op) { + x = op(x, __shfl_xor_sync(uint32_t(-1), x, 1)); + return x; +} +}; diff --git a/csrc/mamba/causal_conv1d/static_switch.h b/csrc/mamba/causal_conv1d/static_switch.h new file mode 100644 index 0000000000000..ef74bf447f840 --- /dev/null +++ b/csrc/mamba/causal_conv1d/static_switch.h @@ -0,0 +1,28 @@ +// Inspired by +// https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h +// and https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h +// clang-format off +// adapted from https://github.com/Dao-AILab/causal-conv1d/blob/main/csrc/static_switch.h + +#pragma once + +/// @param COND - a boolean expression to switch by +/// @param CONST_NAME - a name given for the constexpr bool variable. +/// @param ... - code to execute for true and false +/// +/// Usage: +/// ``` +/// BOOL_SWITCH(flag, BoolConst, [&] { +/// some_function(...); +/// }); +/// ``` +#define BOOL_SWITCH(COND, CONST_NAME, ...) \ + [&] { \ + if (COND) { \ + static constexpr bool CONST_NAME = true; \ + return __VA_ARGS__(); \ + } else { \ + static constexpr bool CONST_NAME = false; \ + return __VA_ARGS__(); \ + } \ + }() diff --git a/csrc/mamba/mamba_ssm/selective_scan.h b/csrc/mamba/mamba_ssm/selective_scan.h new file mode 100644 index 0000000000000..0070c92f6cd0f --- /dev/null +++ b/csrc/mamba/mamba_ssm/selective_scan.h @@ -0,0 +1,276 @@ +/****************************************************************************** + * Copyright (c) 2023, Tri Dao. + ******************************************************************************/ +// clang-format off +// adapted from https://github.com/state-spaces/mamba/blob/main/csrc/selective_scan/selective_scan.h + +#pragma once + +#ifndef USE_ROCM + #include +#else + #include +#endif +#include +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct SSMParamsBase { + using index_t = uint32_t; + + int batch, dim, seqlen, dstate, n_groups, n_chunks; + int dim_ngroups_ratio; + bool is_variable_B; + bool is_variable_C; + + bool delta_softplus; + + index_t A_d_stride; + index_t A_dstate_stride; + index_t B_batch_stride; + index_t B_d_stride; + index_t B_dstate_stride; + index_t B_group_stride; + index_t C_batch_stride; + index_t C_d_stride; + index_t C_dstate_stride; + index_t C_group_stride; + index_t u_batch_stride; + index_t u_d_stride; + index_t delta_batch_stride; + index_t delta_d_stride; + index_t z_batch_stride; + index_t z_d_stride; + index_t out_batch_stride; + index_t out_d_stride; + index_t out_z_batch_stride; + index_t out_z_d_stride; + + // Common data pointers. + void *__restrict__ A_ptr; + void *__restrict__ B_ptr; + void *__restrict__ C_ptr; + void *__restrict__ D_ptr; + void *__restrict__ u_ptr; + void *__restrict__ delta_ptr; + void *__restrict__ delta_bias_ptr; + void *__restrict__ out_ptr; + void *__restrict__ x_ptr; + void *__restrict__ z_ptr; + void *__restrict__ out_z_ptr; + void *__restrict__ index_ptr; +}; + + + + +#ifndef USE_ROCM + + constexpr size_t custom_max(std::initializer_list ilist) + { + return std::max(ilist); + } + + template + constexpr T constexpr_min(T a, T b) { + return std::min(a, b); + } + +#else + constexpr size_t custom_max(std::initializer_list ilist) + { + return *std::max_element(ilist.begin(), ilist.end()); + } + + template + constexpr T constexpr_min(T a, T b) { + return a < b ? a : b; + } +#endif + + +#define MAX_DSTATE 256 + + +inline __device__ float2 operator+(const float2 & a, const float2 & b){ + return {a.x + b.x, a.y + b.y}; +} + +inline __device__ float3 operator+(const float3 &a, const float3 &b) { + return {a.x + b.x, a.y + b.y, a.z + b.z}; +} + +inline __device__ float4 operator+(const float4 & a, const float4 & b){ + return {a.x + b.x, a.y + b.y, a.z + b.z, a.w + b.w}; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template struct BytesToType {}; + +template<> struct BytesToType<16> { + using Type = uint4; + static_assert(sizeof(Type) == 16); +}; + +template<> struct BytesToType<8> { + using Type = uint64_t; + static_assert(sizeof(Type) == 8); +}; + +template<> struct BytesToType<4> { + using Type = uint32_t; + static_assert(sizeof(Type) == 4); +}; + +template<> struct BytesToType<2> { + using Type = uint16_t; + static_assert(sizeof(Type) == 2); +}; + +template<> struct BytesToType<1> { + using Type = uint8_t; + static_assert(sizeof(Type) == 1); +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Converter{ + static inline __device__ void to_float(const scalar_t (&src)[N], float (&dst)[N]) { + #pragma unroll + for (int i = 0; i < N; ++i) { dst[i] = src[i]; } + } +}; + +template +struct Converter{ + static inline __device__ void to_float(const at::Half (&src)[N], float (&dst)[N]) { + static_assert(N % 2 == 0); + auto &src2 = reinterpret_cast(src); + auto &dst2 = reinterpret_cast(dst); + #pragma unroll + for (int i = 0; i < N / 2; ++i) { dst2[i] = __half22float2(src2[i]); } + } +}; + +#if __CUDA_ARCH__ >= 800 +template +struct Converter{ + static inline __device__ void to_float(const at::BFloat16 (&src)[N], float (&dst)[N]) { + static_assert(N % 2 == 0); + auto &src2 = reinterpret_cast(src); + auto &dst2 = reinterpret_cast(dst); + #pragma unroll + for (int i = 0; i < N / 2; ++i) { dst2[i] = __bfloat1622float2(src2[i]); } + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +template struct SSMScanOp; + +template<> +struct SSMScanOp { + __device__ __forceinline__ float2 operator()(const float2 &ab0, const float2 &ab1) const { + return make_float2(ab1.x * ab0.x, ab1.x * ab0.y + ab1.y); + } +}; + +// A stateful callback functor that maintains a running prefix to be applied +// during consecutive scan operations. +template struct SSMScanPrefixCallbackOp { + using scan_t = std::conditional_t, float2, float4>; + scan_t running_prefix; + // Constructor + __device__ SSMScanPrefixCallbackOp(scan_t running_prefix_) : running_prefix(running_prefix_) {} + // Callback operator to be entered by the first warp of threads in the block. + // Thread-0 is responsible for returning a value for seeding the block-wide scan. + __device__ scan_t operator()(scan_t block_aggregate) { + scan_t old_prefix = running_prefix; + running_prefix = SSMScanOp()(running_prefix, block_aggregate); + return old_prefix; + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void load_input(typename Ktraits::input_t *u, + typename Ktraits::input_t (&u_vals)[Ktraits::kNItems], + typename Ktraits::BlockLoadT::TempStorage &smem_load, + int seqlen) { + if constexpr (Ktraits::kIsEvenLen) { + auto& smem_load_vec = reinterpret_cast(smem_load); + using vec_t = typename Ktraits::vec_t; + typename Ktraits::BlockLoadVecT(smem_load_vec).Load( + reinterpret_cast(u), + reinterpret_cast(u_vals) + #ifdef USE_ROCM + , Ktraits::kNThreads * Ktraits::kNLoads + #endif + + ); + } else { + typename Ktraits::BlockLoadT(smem_load).Load(u, u_vals, seqlen, 0.f); + } +} + +template +inline __device__ void load_index(int *u, + int (&u_vals)[Ktraits::kNItems], + typename Ktraits::BlockLoadIndexT::TempStorage &smem_load_index, + int seqlen) { + if constexpr (Ktraits::kIsEvenLen) { + auto& smem_load_index_vec = reinterpret_cast(smem_load_index); + Ktraits::BlockLoadIndexVecT(smem_load_index_vec).Load( + reinterpret_cast(u), + reinterpret_cast(u_vals) + ); + } else { + Ktraits::BlockLoadIndexT(smem_load_index).Load(u, u_vals, seqlen, 0); + } +} + +template +inline __device__ void load_weight(typename Ktraits::input_t *Bvar, + typename Ktraits::weight_t (&B_vals)[Ktraits::kNItems], + typename Ktraits::BlockLoadWeightT::TempStorage &smem_load_weight, + int seqlen) { + constexpr int kNItems = Ktraits::kNItems; + typename Ktraits::input_t B_vals_load[kNItems]; + if constexpr (Ktraits::kIsEvenLen) { + auto& smem_load_weight_vec = reinterpret_cast(smem_load_weight); + using vec_t = typename Ktraits::vec_t; + typename Ktraits::BlockLoadWeightVecT(smem_load_weight_vec).Load( + reinterpret_cast(Bvar), + reinterpret_cast(B_vals_load) + ); + } else { + typename Ktraits::BlockLoadWeightT(smem_load_weight).Load(Bvar, B_vals_load, seqlen, 0.f); + } + // #pragma unroll + // for (int i = 0; i < kNItems; ++i) { B_vals[i] = B_vals_load[i]; } + Converter::to_float(B_vals_load, B_vals); +} + +template +inline __device__ void store_output(typename Ktraits::input_t *out, + const float (&out_vals)[Ktraits::kNItems], + typename Ktraits::BlockStoreT::TempStorage &smem_store, + int seqlen) { + typename Ktraits::input_t write_vals[Ktraits::kNItems]; + #pragma unroll + for (int i = 0; i < Ktraits::kNItems; ++i) { write_vals[i] = out_vals[i]; } + if constexpr (Ktraits::kIsEvenLen) { + auto& smem_store_vec = reinterpret_cast(smem_store); + using vec_t = typename Ktraits::vec_t; + typename Ktraits::BlockStoreVecT(smem_store_vec).Store( + reinterpret_cast(out), + reinterpret_cast(write_vals) + ); + } else { + typename Ktraits::BlockStoreT(smem_store).Store(out, write_vals, seqlen); + } +} diff --git a/csrc/mamba/mamba_ssm/selective_scan_fwd.cu b/csrc/mamba/mamba_ssm/selective_scan_fwd.cu new file mode 100644 index 0000000000000..df968dda92adc --- /dev/null +++ b/csrc/mamba/mamba_ssm/selective_scan_fwd.cu @@ -0,0 +1,593 @@ +// clang-format off +// adapted from https://github.com/state-spaces/mamba/blob/main/csrc/selective_scan/selective_scan_fwd_kernel.cuh +#include +#include +#include +#include "selective_scan.h" + +#include +#include +#include // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK + +#ifndef USE_ROCM + #include + #include + #include +#else + #include + namespace cub = hipcub; +#endif + +#include "selective_scan.h" +#include "static_switch.h" + +template +struct Selective_Scan_fwd_kernel_traits { + static_assert(kNItems_ % 4 == 0); + using input_t = input_t_; + using weight_t = weight_t_; + static constexpr int kNThreads = kNThreads_; + // Setting MinBlocksPerMP to be 3 (instead of 2) for 128 threads improves occupancy. + static constexpr int kMinBlocks = kNThreads < 128 ? 5 : 3; + static constexpr int kNItems = kNItems_; + static constexpr int kNRows = kNRows_; + static constexpr int kNBytes = sizeof(input_t); + static_assert(kNBytes == 2 || kNBytes == 4); + static constexpr int kNElts = kNBytes == 4 ? 4 : constexpr_min(8, kNItems); + static_assert(kNItems % kNElts == 0); + static constexpr int kNLoads = kNItems / kNElts; + static constexpr bool kIsEvenLen = kIsEvenLen_; + static constexpr bool kIsVariableB = kIsVariableB_; + static constexpr bool kIsVariableC = kIsVariableC_; + static constexpr bool kHasZ = kHasZ_; + static constexpr bool kUseIndex = kUseIndex_; + + static constexpr bool kDirectIO = kIsEvenLen && kNLoads == 1; + static constexpr int kNLoadsIndex = kNItems / 4; + using vec_t = typename BytesToType::Type; + using scan_t = float2; + using BlockLoadT = cub::BlockLoad; + using BlockLoadVecT = cub::BlockLoad; + using BlockLoadIndexT = cub::BlockLoad; + using BlockLoadIndexVecT = cub::BlockLoad; + using BlockLoadWeightT = cub::BlockLoad; + using BlockLoadWeightVecT = cub::BlockLoad; + using BlockStoreT = cub::BlockStore; + using BlockStoreVecT = cub::BlockStore; + // using BlockScanT = cub::BlockScan; + // using BlockScanT = cub::BlockScan; + using BlockScanT = cub::BlockScan; + static constexpr int kSmemIOSize = custom_max({sizeof(typename BlockLoadT::TempStorage), + sizeof(typename BlockLoadVecT::TempStorage), + sizeof(typename BlockLoadIndexT::TempStorage), + sizeof(typename BlockLoadIndexVecT::TempStorage), + (int(kIsVariableB) + int(kIsVariableC)) * sizeof(typename BlockLoadWeightT::TempStorage), + (int(kIsVariableB) + int(kIsVariableC)) * sizeof(typename BlockLoadWeightVecT::TempStorage), + sizeof(typename BlockStoreT::TempStorage), + sizeof(typename BlockStoreVecT::TempStorage)}); + static constexpr int kSmemSize = kSmemIOSize + sizeof(typename BlockScanT::TempStorage); +}; + +template +__global__ __launch_bounds__(Ktraits::kNThreads, Ktraits::kMinBlocks) +void selective_scan_fwd_kernel(SSMParamsBase params) { + constexpr bool kIsVariableB = Ktraits::kIsVariableB; + constexpr bool kIsVariableC = Ktraits::kIsVariableC; + constexpr bool kHasZ = Ktraits::kHasZ; + constexpr bool kUseIndex = Ktraits::kUseIndex; + constexpr int kNThreads = Ktraits::kNThreads; + constexpr int kNItems = Ktraits::kNItems; + constexpr int kNRows = Ktraits::kNRows; + constexpr bool kDirectIO = Ktraits::kDirectIO; + using input_t = typename Ktraits::input_t; + using weight_t = typename Ktraits::weight_t; + using scan_t = typename Ktraits::scan_t; + + // Shared memory. + extern __shared__ char smem_[]; + // cast to lvalue reference of expected type + // char *smem_loadstorescan = smem_ + 2 * MAX_DSTATE * sizeof(weight_t); + // auto& smem_load = reinterpret_cast(smem_ + 2 * MAX_DSTATE * sizeof(weight_t)); + // auto& smem_load = reinterpret_cast(smem_loadstorescan); + auto& smem_load = reinterpret_cast(smem_); + auto& smem_load_weight = reinterpret_cast(smem_); + auto& smem_load_index = reinterpret_cast(smem_); + auto& smem_load_weight1 = *reinterpret_cast(smem_ + sizeof(typename Ktraits::BlockLoadWeightT::TempStorage)); + auto& smem_store = reinterpret_cast(smem_); + auto& smem_scan = *reinterpret_cast(smem_ + Ktraits::kSmemIOSize); + // weight_t *smem_a = reinterpret_cast(smem_ + smem_loadstorescan_size); + // weight_t *smem_bc = reinterpret_cast(smem_a + MAX_DSTATE); + scan_t *smem_running_prefix = reinterpret_cast(smem_ + Ktraits::kSmemSize); + + const int batch_id = blockIdx.x; + const int dim_id = blockIdx.y; + const int group_id = dim_id / (params.dim_ngroups_ratio); + input_t *u = reinterpret_cast(params.u_ptr) + batch_id * params.u_batch_stride + + dim_id * kNRows * params.u_d_stride; + input_t *delta = reinterpret_cast(params.delta_ptr) + batch_id * params.delta_batch_stride + + dim_id * kNRows * params.delta_d_stride; + weight_t *A = reinterpret_cast(params.A_ptr) + dim_id * kNRows * params.A_d_stride; + weight_t *B = reinterpret_cast(params.B_ptr) + dim_id * kNRows * params.B_d_stride; + input_t *Bvar = reinterpret_cast(params.B_ptr) + batch_id * params.B_batch_stride + group_id * params.B_group_stride; + weight_t *C = reinterpret_cast(params.C_ptr) + dim_id * kNRows * params.C_d_stride; + input_t *Cvar = reinterpret_cast(params.C_ptr) + batch_id * params.C_batch_stride + group_id * params.C_group_stride; + scan_t *x = reinterpret_cast(params.x_ptr) + (batch_id * params.dim + dim_id * kNRows) * params.n_chunks * params.dstate; + int *index = !kUseIndex ? nullptr :reinterpret_cast(params.index_ptr) + batch_id * params.seqlen; + + float D_val[kNRows] = {0}; + if (params.D_ptr != nullptr) { + #pragma unroll + for (int r = 0; r < kNRows; ++r) { + D_val[r] = reinterpret_cast(params.D_ptr)[dim_id * kNRows + r]; + } + } + float delta_bias[kNRows] = {0}; + if (params.delta_bias_ptr != nullptr) { + #pragma unroll + for (int r = 0; r < kNRows; ++r) { + delta_bias[r] = reinterpret_cast(params.delta_bias_ptr)[dim_id * kNRows + r]; + } + } + + + // for (int state_idx = threadIdx.x; state_idx < params.dstate; state_idx += blockDim.x) { + // smem_a[state_idx] = A[state_idx * params.A_dstate_stride]; + // smem_bc[state_idx] = B[state_idx * params.B_dstate_stride] * C[state_idx * params.C_dstate_stride]; + // } + + constexpr int kChunkSize = kNThreads * kNItems; + for (int chunk = 0; chunk < params.n_chunks; ++chunk) { + input_t u_vals[kNRows][kNItems], delta_vals_load[kNRows][kNItems]; + int index_vals_load[kNRows][kNItems]; + + __syncthreads(); + #pragma unroll + for (int r = 0; r < kNRows; ++r) { + if constexpr (!kDirectIO) { + if (r > 0) { __syncthreads(); } + } + load_input(u + r * params.u_d_stride, u_vals[r], smem_load, params.seqlen - chunk * kChunkSize); + if constexpr (!kDirectIO) { __syncthreads(); } + load_input(delta + r * params.delta_d_stride, delta_vals_load[r], smem_load, params.seqlen - chunk * kChunkSize); + if constexpr (kUseIndex) { + load_index(index + r * params.delta_d_stride, index_vals_load[r], smem_load_index, params.seqlen - chunk * kChunkSize); + } + } + if constexpr (kUseIndex) { + index += kChunkSize; + } + u += kChunkSize; + delta += kChunkSize; + + float delta_vals[kNRows][kNItems], delta_u_vals[kNRows][kNItems], out_vals[kNRows][kNItems]; + #pragma unroll + for (int r = 0; r < kNRows; ++r) { + #pragma unroll + for (int i = 0; i < kNItems; ++i) { + float u_val = float(u_vals[r][i]); + delta_vals[r][i] = float(delta_vals_load[r][i]) + delta_bias[r]; + if (params.delta_softplus) { + delta_vals[r][i] = delta_vals[r][i] <= 20.f ? log1pf(expf(delta_vals[r][i])) : delta_vals[r][i]; + } + delta_u_vals[r][i] = delta_vals[r][i] * u_val; + out_vals[r][i] = D_val[r] * u_val; + } + } + + __syncthreads(); + for (int state_idx = 0; state_idx < params.dstate; ++state_idx) { + weight_t A_val[kNRows]; + #pragma unroll + for (int r = 0; r < kNRows; ++r) { + A_val[r] = A[state_idx * params.A_dstate_stride + r * params.A_d_stride]; + // Multiply the real part of A with LOG2E so we can use exp2f instead of expf. + constexpr float kLog2e = M_LOG2E; + A_val[r] *= kLog2e; + } + // This variable holds B * C if both B and C are constant across seqlen. If only B varies + // across seqlen, this holds C. If only C varies across seqlen, this holds B. + // If both B and C vary, this is unused. + weight_t BC_val[kNRows]; + weight_t B_vals[kNItems], C_vals[kNItems]; + if constexpr (kIsVariableB) { + load_weight(Bvar + state_idx * params.B_dstate_stride, B_vals, + smem_load_weight, (params.seqlen - chunk * kChunkSize) * (1)); + if constexpr (!kIsVariableC) { + #pragma unroll + for (int r = 0; r < kNRows; ++r) { + BC_val[r] = C[state_idx * params.C_dstate_stride + r * params.C_d_stride]; + } + } + } + if constexpr (kIsVariableC) { + auto &smem_load_weight_C = !kIsVariableB ? smem_load_weight : smem_load_weight1; + load_weight(Cvar + state_idx * params.C_dstate_stride, C_vals, + smem_load_weight_C, (params.seqlen - chunk * kChunkSize) * (1 )); + if constexpr (!kIsVariableB) { + #pragma unroll + for (int r = 0; r < kNRows; ++r) { + BC_val[r] = B[state_idx * params.B_dstate_stride + r * params.B_d_stride]; + } + } + } + if constexpr (!kIsVariableB && !kIsVariableC) { + #pragma unroll + for (int r = 0; r < kNRows; ++r) { + BC_val[r] = B[state_idx * params.B_dstate_stride + r * params.B_d_stride] * C[state_idx * params.C_dstate_stride + r * params.C_d_stride]; + } + } + + #pragma unroll + for (int r = 0; r < kNRows; ++r) { + if (r > 0) { __syncthreads(); } // Scan could be using the same smem + scan_t thread_data[kNItems]; + #pragma unroll + for (int i = 0; i < kNItems; ++i) { + thread_data[i] = make_float2(exp2f(delta_vals[r][i] * A_val[r]), + !kIsVariableB ? delta_u_vals[r][i] : B_vals[i] * delta_u_vals[r][i]); + + // Reset A bar for cumulative sequences (Real) + if constexpr (kUseIndex) { + if (index_vals_load[r][i] == 0) { + thread_data[i].x = 0.f; + } + } + + if constexpr (!Ktraits::kIsEvenLen) { // So that the last state is correct + if (threadIdx.x * kNItems + i >= params.seqlen - chunk * kChunkSize) { + thread_data[i] = make_float2(1.f, 0.f); + } + } + } + // Initialize running total + scan_t running_prefix; + // If we use WARP_SCAN then all lane 0 of all warps (not just thread 0) needs to read + running_prefix = chunk == 0 ? x[(r * params.n_chunks) * params.dstate + state_idx] : ( threadIdx.x % 32 == 0 ? smem_running_prefix[state_idx + r * MAX_DSTATE] : make_float2(1.f, 0.f)); + // running_prefix = chunk > 0 && threadIdx.x == 0 ? smem_running_prefix[state_idx] : make_float2(1.f, 0.f); + SSMScanPrefixCallbackOp prefix_op(running_prefix); + typename Ktraits::BlockScanT(smem_scan).InclusiveScan( + thread_data, thread_data, SSMScanOp(), prefix_op + ); + // There's a syncthreads in the scan op, so we don't need to sync here. + // Unless there's only 1 warp, but then it's the same thread (0) reading and writing. + if (threadIdx.x == 0) { + smem_running_prefix[state_idx] = prefix_op.running_prefix; + x[(r * params.n_chunks + chunk) * params.dstate + state_idx] = prefix_op.running_prefix; + } + #pragma unroll + for (int i = 0; i < kNItems; ++i) { + const weight_t C_val = !kIsVariableC + ? BC_val[r] + : (!kIsVariableB ? BC_val[r] * C_vals[i] : C_vals[i]); + out_vals[r][i] += thread_data[i].y * C_val; + } + } + } + + input_t *out = reinterpret_cast(params.out_ptr) + batch_id * params.out_batch_stride + + dim_id * kNRows * params.out_d_stride + chunk * kChunkSize; + __syncthreads(); + #pragma unroll + for (int r = 0; r < kNRows; ++r) { + if constexpr (!kDirectIO) { + if (r > 0) { __syncthreads(); } + } + store_output(out + r * params.out_d_stride, out_vals[r], smem_store, params.seqlen - chunk * kChunkSize); + } + + if constexpr (kHasZ) { + input_t *z = reinterpret_cast(params.z_ptr) + batch_id * params.z_batch_stride + + dim_id * kNRows * params.z_d_stride + chunk * kChunkSize; + input_t *out_z = reinterpret_cast(params.out_z_ptr) + batch_id * params.out_z_batch_stride + + dim_id * kNRows * params.out_z_d_stride + chunk * kChunkSize; + #pragma unroll + for (int r = 0; r < kNRows; ++r) { + input_t z_vals[kNItems]; + __syncthreads(); + load_input(z + r * params.z_d_stride, z_vals, smem_load, params.seqlen - chunk * kChunkSize); + #pragma unroll + for (int i = 0; i < kNItems; ++i) { + float z_val = z_vals[i]; + out_vals[r][i] *= z_val / (1 + expf(-z_val)); + } + __syncthreads(); + store_output(out_z + r * params.out_z_d_stride, out_vals[r], smem_store, params.seqlen - chunk * kChunkSize); + } + } + + Bvar += kChunkSize * 1; + Cvar += kChunkSize * 1; + } +} + +template +void selective_scan_fwd_launch(SSMParamsBase ¶ms, cudaStream_t stream) { + // Only kNRows == 1 is tested for now, which ofc doesn't differ from previously when we had each block + // processing 1 row. + constexpr int kNRows = 1; + // kIsVariableB, kIsVariableC and kHasZ are all set to True to reduce binary size + constexpr bool kIsVariableB = true; + constexpr bool kIsVariableC = true; + constexpr bool kHasZ = true; + BOOL_SWITCH(params.seqlen % (kNThreads * kNItems) == 0, kIsEvenLen, [&] { + BOOL_SWITCH(params.index_ptr != nullptr , kUseIndex, [&] { + using Ktraits = Selective_Scan_fwd_kernel_traits; + constexpr int kSmemSize = Ktraits::kSmemSize + kNRows * MAX_DSTATE * sizeof(typename Ktraits::scan_t); + dim3 grid(params.batch, params.dim / kNRows); + auto kernel = &selective_scan_fwd_kernel; + if (kSmemSize >= 48 * 1024) { + C10_CUDA_CHECK(cudaFuncSetAttribute( + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize)); + } + kernel<<>>(params); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + }); + }); +} + +template +void selective_scan_fwd_cuda(SSMParamsBase ¶ms, cudaStream_t stream) { + + #ifndef USE_ROCM + if (params.seqlen <= 128) { + selective_scan_fwd_launch<32, 4, input_t, weight_t>(params, stream); + } else if (params.seqlen <= 256) { + selective_scan_fwd_launch<32, 8, input_t, weight_t>(params, stream); + } else if (params.seqlen <= 512) { + selective_scan_fwd_launch<32, 16, input_t, weight_t>(params, stream); + } else if (params.seqlen <= 1024) { + selective_scan_fwd_launch<64, 16, input_t, weight_t>(params, stream); + } else { + selective_scan_fwd_launch<128, 16, input_t, weight_t>(params, stream); + } + #else + if (params.seqlen <= 256) { + selective_scan_fwd_launch<64, 4, input_t, weight_t>(params, stream); + } else if (params.seqlen <= 512) { + selective_scan_fwd_launch<64, 8, input_t, weight_t>(params, stream); + } else if (params.seqlen <= 1024) { + selective_scan_fwd_launch<64, 16, input_t, weight_t>(params, stream); + } else { + selective_scan_fwd_launch<128, 16, input_t, weight_t>(params, stream); + } + #endif +} + +template void selective_scan_fwd_cuda(SSMParamsBase ¶ms, cudaStream_t stream); +template void selective_scan_fwd_cuda(SSMParamsBase ¶ms, cudaStream_t stream); +template void selective_scan_fwd_cuda(SSMParamsBase ¶ms, cudaStream_t stream); + +#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")") + +#define DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16(ITYPE, NAME, ...) \ + if (ITYPE == at::ScalarType::Half) { \ + using input_t = at::Half; \ + using weight_t = float; \ + __VA_ARGS__(); \ + } else if (ITYPE == at::ScalarType::BFloat16) { \ + using input_t = at::BFloat16; \ + using weight_t = float; \ + __VA_ARGS__(); \ + } else if (ITYPE == at::ScalarType::Float) { \ + using input_t = float; \ + using weight_t = float; \ + __VA_ARGS__(); \ + } else { \ + AT_ERROR(#NAME, " not implemented for input type '", toString(ITYPE), "'"); \ + } + + +template +void selective_scan_fwd_cuda(SSMParamsBase ¶ms, cudaStream_t stream); + +void set_ssm_params_fwd(SSMParamsBase ¶ms, + // sizes + const size_t batch, + const size_t dim, + const size_t seqlen, + const size_t dstate, + const size_t n_groups, + const size_t n_chunks, + const bool is_variable_B, + const bool is_variable_C, + // device pointers + const torch::Tensor u, + const torch::Tensor delta, + const torch::Tensor A, + const torch::Tensor B, + const torch::Tensor C, + const torch::Tensor out, + const torch::Tensor z, + const torch::Tensor out_z, + void* D_ptr, + void* delta_bias_ptr, + void* x_ptr, + bool has_z, + bool delta_softplus, + void* index_ptr) { + + // Reset the parameters + memset(¶ms, 0, sizeof(params)); + + params.batch = batch; + params.dim = dim; + params.seqlen = seqlen; + params.dstate = dstate; + params.n_groups = n_groups; + params.n_chunks = n_chunks; + params.dim_ngroups_ratio = dim / n_groups; + + params.delta_softplus = delta_softplus; + + params.is_variable_B = is_variable_B; + params.is_variable_C = is_variable_C; + + // Set the pointers and strides. + params.u_ptr = u.data_ptr(); + params.delta_ptr = delta.data_ptr(); + params.A_ptr = A.data_ptr(); + params.B_ptr = B.data_ptr(); + params.C_ptr = C.data_ptr(); + params.D_ptr = D_ptr; + params.delta_bias_ptr = delta_bias_ptr; + params.out_ptr = out.data_ptr(); + params.x_ptr = x_ptr; + params.z_ptr = has_z ? z.data_ptr() : nullptr; + params.out_z_ptr = has_z ? out_z.data_ptr() : nullptr; + + params.index_ptr = index_ptr; + + // All stride are in elements, not bytes. + params.A_d_stride = A.stride(0); + params.A_dstate_stride = A.stride(1); + if (!is_variable_B) { + params.B_d_stride = B.stride(0); + } else { + params.B_batch_stride = B.stride(0); + params.B_group_stride = B.stride(1); + } + params.B_dstate_stride = !is_variable_B ? B.stride(1) : B.stride(2); + if (!is_variable_C) { + params.C_d_stride = C.stride(0); + } else { + params.C_batch_stride = C.stride(0); + params.C_group_stride = C.stride(1); + } + params.C_dstate_stride = !is_variable_C ? C.stride(1) : C.stride(2); + params.u_batch_stride = u.stride(0); + params.u_d_stride = u.stride(1); + params.delta_batch_stride = delta.stride(0); + params.delta_d_stride = delta.stride(1); + if (has_z) { + params.z_batch_stride = z.stride(0); + params.z_d_stride = z.stride(1); + params.out_z_batch_stride = out_z.stride(0); + params.out_z_d_stride = out_z.stride(1); + } + params.out_batch_stride = out.stride(0); + params.out_d_stride = out.stride(1); +} + +std::vector +selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta, + const torch::Tensor &A, const torch::Tensor &B, const torch::Tensor &C, + const c10::optional &D_, + const c10::optional &z_, + const c10::optional &delta_bias_, + bool delta_softplus, + const c10::optional &index_, + const c10::optional &x) { + auto input_type = u.scalar_type(); + auto weight_type = A.scalar_type(); + TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16); + TORCH_CHECK(weight_type == at::ScalarType::Float); + + const bool is_variable_B = B.dim() >= 3; + const bool is_variable_C = C.dim() >= 3; + + TORCH_CHECK(delta.scalar_type() == input_type); + TORCH_CHECK(B.scalar_type() == (!is_variable_B ? weight_type : input_type)); + TORCH_CHECK(C.scalar_type() == (!is_variable_C ? weight_type : input_type)); + + TORCH_CHECK(u.is_cuda()); + TORCH_CHECK(delta.is_cuda()); + TORCH_CHECK(A.is_cuda()); + TORCH_CHECK(B.is_cuda()); + TORCH_CHECK(C.is_cuda()); + + TORCH_CHECK(u.stride(-1) == 1 || u.size(-1) == 1); + TORCH_CHECK(delta.stride(-1) == 1 || delta.size(-1) == 1); + + const auto sizes = u.sizes(); + const int batch_size = sizes[0]; + const int dim = sizes[1]; + const int seqlen = sizes[2]; + const int dstate = A.size(1); + const int n_groups = is_variable_B ? B.size(1) : 1; + + TORCH_CHECK(dstate <= 256, "selective_scan only supports state dimension <= 256"); + + CHECK_SHAPE(u, batch_size, dim, seqlen); + CHECK_SHAPE(delta, batch_size, dim, seqlen); + CHECK_SHAPE(A, dim, dstate); + TORCH_CHECK(is_variable_B, "is_variable_B = False is disabled in favor of reduced binary size") + CHECK_SHAPE(B, batch_size, n_groups, dstate, seqlen ); + TORCH_CHECK(B.stride(-1) == 1 || B.size(-1) == 1); + + TORCH_CHECK(is_variable_C, "is_variable_C = False is disabled in favor of reduced binary size") + CHECK_SHAPE(C, batch_size, n_groups, dstate, seqlen); + TORCH_CHECK(C.stride(-1) == 1 || C.size(-1) == 1); + + if (D_.has_value()) { + auto D = D_.value(); + TORCH_CHECK(D.scalar_type() == at::ScalarType::Float); + TORCH_CHECK(D.is_cuda()); + TORCH_CHECK(D.stride(-1) == 1 || D.size(-1) == 1); + CHECK_SHAPE(D, dim); + } + + if (delta_bias_.has_value()) { + auto delta_bias = delta_bias_.value(); + TORCH_CHECK(delta_bias.scalar_type() == at::ScalarType::Float); + TORCH_CHECK(delta_bias.is_cuda()); + TORCH_CHECK(delta_bias.stride(-1) == 1 || delta_bias.size(-1) == 1); + CHECK_SHAPE(delta_bias, dim); + } + if (index_.has_value()) { + auto index = index_.value(); + TORCH_CHECK(index.scalar_type() == at::ScalarType::Int); + TORCH_CHECK(index.is_cuda()); + CHECK_SHAPE(index, batch_size, seqlen); + } + + at::Tensor z, out_z; + const bool has_z = z_.has_value(); + TORCH_CHECK(has_z, "has_z = False is disabled in favor of reduced binary size") + z = z_.value(); + TORCH_CHECK(z.scalar_type() == input_type); + TORCH_CHECK(z.is_cuda()); + TORCH_CHECK(z.stride(-1) == 1 || z.size(-1) == 1); + CHECK_SHAPE(z, batch_size, dim, seqlen); + out_z = torch::empty_like(z); + + const int n_chunks = (seqlen + 2048 - 1) / 2048; + // const int n_chunks = (seqlen + 1024 - 1) / 1024; + // at::Tensor out = torch::empty_like(u); + // Right now u has BHL layout and delta has HBL layout, and we want out to have HBL layout + at::Tensor out = torch::empty_like(delta); + if (x.has_value()){ + auto _x = x.value(); + TORCH_CHECK(_x.scalar_type() == weight_type); + TORCH_CHECK(_x.is_cuda()); + TORCH_CHECK(_x.stride(-1) == 1); + CHECK_SHAPE(_x, batch_size, dim, n_chunks, dstate * 2); + } + + SSMParamsBase params; + set_ssm_params_fwd(params, batch_size, dim, seqlen, dstate, n_groups, n_chunks, is_variable_B, is_variable_C, + u, delta, A, B, C, out, z, out_z, + D_.has_value() ? D_.value().data_ptr() : nullptr, + delta_bias_.has_value() ? delta_bias_.value().data_ptr() : nullptr, + x.value().data_ptr(), + has_z, + delta_softplus, + index_.has_value() ? index_.value().data_ptr() : nullptr); + + // Otherwise the kernel will be launched from cuda:0 device + // Cast to char to avoid compiler warning about narrowing + at::cuda::CUDAGuard device_guard{(char)u.get_device()}; + auto stream = at::cuda::getCurrentCUDAStream().stream(); + DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16(u.scalar_type(), "selective_scan_fwd", [&] { + selective_scan_fwd_cuda(params, stream); + }); + std::vector result = {out, x.value()}; + if (has_z) { result.push_back(out_z); } + return result; +} + diff --git a/csrc/mamba/mamba_ssm/static_switch.h b/csrc/mamba/mamba_ssm/static_switch.h new file mode 100644 index 0000000000000..840cb2374a2f0 --- /dev/null +++ b/csrc/mamba/mamba_ssm/static_switch.h @@ -0,0 +1,28 @@ +// Inspired by +// https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h +// and https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h + +// clang-format off +// adapted from https://github.com/state-spaces/mamba/blob/main/csrc/selective_scan/static_switch.h +#pragma once + +/// @param COND - a boolean expression to switch by +/// @param CONST_NAME - a name given for the constexpr bool variable. +/// @param ... - code to execute for true and false +/// +/// Usage: +/// ``` +/// BOOL_SWITCH(flag, BoolConst, [&] { +/// some_function(...); +/// }); +/// ``` +#define BOOL_SWITCH(COND, CONST_NAME, ...) \ + [&] { \ + if (COND) { \ + constexpr bool CONST_NAME = true; \ + return __VA_ARGS__(); \ + } else { \ + constexpr bool CONST_NAME = false; \ + return __VA_ARGS__(); \ + } \ + }() diff --git a/csrc/ops.h b/csrc/ops.h index 6bf0cff232528..8d24545de898d 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -195,6 +195,28 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, torch::Tensor experts_ids, torch::Tensor num_tokens_post_pad); +std::vector selective_scan_fwd( + const torch::Tensor& u, const torch::Tensor& delta, const torch::Tensor& A, + const torch::Tensor& B, const torch::Tensor& C, + const c10::optional& D_, + const c10::optional& z_, + const c10::optional& delta_bias_, bool delta_softplus, + const c10::optional& index_, + const c10::optional& x); + +at::Tensor causal_conv1d_update(const at::Tensor& x, + const at::Tensor& conv_state, + const at::Tensor& weight, + const c10::optional& bias_, + bool silu_activation); + +at::Tensor causal_conv1d_fwd(const at::Tensor& x, const at::Tensor& weight, + const c10::optional& bias_, + const c10::optional& seq_idx_, + const c10::optional& initial_states_, + const c10::optional& final_states_out_, + bool silu_activation); + #ifndef USE_ROCM using fptr_t = int64_t; fptr_t init_custom_ar(torch::Tensor& meta, torch::Tensor& rank_data, diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index 6d1f53b75f4e2..7783acd741f5f 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -202,6 +202,31 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ops.def("cutlass_scaled_mm_supports_fp8", &cutlass_scaled_mm_supports_fp8); ops.impl("cutlass_scaled_mm_supports_fp8", torch::kCUDA, &cutlass_scaled_mm_supports_fp8); + // Mamba selective scan kernel + ops.def( + "selective_scan_fwd(Tensor! u, Tensor! delta," + "Tensor! A, Tensor! B, Tensor! C," + "Tensor? D_, Tensor? z_, Tensor? delta_bias_," + "bool delta_softplus," + "Tensor? index_, Tensor? x) -> Tensor[]"); + ops.impl("selective_scan_fwd", torch::kCUDA, &selective_scan_fwd); + + ops.def( + "causal_conv1d_update(Tensor! x," + "Tensor! conv_state," + "Tensor! weight," + "Tensor? bias_," + "bool silu_activation) -> Tensor"); + ops.impl("causal_conv1d_update", torch::kCUDA, &causal_conv1d_update); + + ops.def( + "causal_conv1d_fwd(Tensor! x, Tensor! weight," + "Tensor? bias_," + "Tensor? seq_idx_," + "Tensor? initial_states_," + "Tensor? final_states_out_," + "bool silu_activation) -> Tensor"); + ops.impl("causal_conv1d_fwd", torch::kCUDA, &causal_conv1d_fwd); #endif // Quantized GEMM for GPTQ. diff --git a/requirements-mamba.txt b/requirements-mamba.txt deleted file mode 100644 index 1838e87d063da..0000000000000 --- a/requirements-mamba.txt +++ /dev/null @@ -1,3 +0,0 @@ -# Mamba dependencies -mamba-ssm>=1.2.2 -causal-conv1d>=1.2.0 diff --git a/requirements-test.txt b/requirements-test.txt index cdbc3e50cc9ec..46eb05fc31099 100644 --- a/requirements-test.txt +++ b/requirements-test.txt @@ -11,7 +11,7 @@ pytest-shard # testing utils awscli -einops # required for MPT and qwen-vl +einops # required for MPT, qwen-vl and Mamba httpx peft requests diff --git a/tests/kernels/test_causal_conv1d.py b/tests/kernels/test_causal_conv1d.py new file mode 100644 index 0000000000000..7bf338b36953a --- /dev/null +++ b/tests/kernels/test_causal_conv1d.py @@ -0,0 +1,205 @@ +from typing import Optional + +import pytest +import torch +import torch.nn.functional as F +from einops import rearrange + +from vllm.model_executor.layers.mamba.ops.causal_conv1d import ( + causal_conv1d_fn, causal_conv1d_update) + + +def causal_conv1d_ref( + x: torch.Tensor, + weight: torch.Tensor, + bias: Optional[torch.Tensor] = None, + initial_states: Optional[torch.Tensor] = None, + return_final_states: bool = False, + final_states_out: Optional[torch.Tensor] = None, + activation: Optional[str] = "silu", +): + """ + x: (batch, dim, seqlen) + weight: (dim, width) + bias: (dim,) + initial_states: (batch, dim, width - 1) + final_states_out: (batch, dim, width - 1) + + out: (batch, dim, seqlen) + """ + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + dtype_in = x.dtype + x = x.to(weight.dtype) + seqlen = x.shape[-1] + dim, width = weight.shape + if initial_states is None: + out = F.conv1d(x, + weight.unsqueeze(1), + bias, + padding=width - 1, + groups=dim) + else: + x = torch.cat([initial_states, x], dim=-1) + out = F.conv1d(x, weight.unsqueeze(1), bias, padding=0, groups=dim) + out = out[..., :seqlen] + if return_final_states: + final_states = F.pad(x, (width - 1 - x.shape[-1], 0)).to( + dtype_in) # (batch, dim, width - 1) + if final_states_out is not None: + final_states_out.copy_(final_states) + else: + final_states_out = final_states + out = (out if activation is None else F.silu(out)).to(dtype=dtype_in) + return (out, None) if not return_final_states else (out, final_states_out) + + +def causal_conv1d_update_ref(x: torch.Tensor, + conv_state: torch.Tensor, + weight: torch.Tensor, + bias: Optional[torch.Tensor] = None, + activation: Optional[str] = None): + """ + x: (batch, dim) + conv_state: (batch, dim, width) + weight: (dim, width) + bias: (dim,) + + out: (batch, dim) + """ + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + dtype_in = x.dtype + batch, dim = x.shape + width = weight.shape[1] + assert conv_state.shape == (batch, dim, width) + assert weight.shape == (dim, width) + conv_state.copy_(torch.roll(conv_state, shifts=-1, + dims=-1)) # Update state (B D W) + conv_state[:, :, -1] = x + out = torch.sum(conv_state * weight, dim=-1) # (B D) + if bias is not None: + out += bias + return (out if activation is None else F.silu(out)).to(dtype=dtype_in) + + +@pytest.mark.parametrize("return_final_states", [False, True]) +@pytest.mark.parametrize("has_initial_states", [False, True]) +@pytest.mark.parametrize("channel_last", [False, True]) +@pytest.mark.parametrize("itype", [torch.bfloat16]) +@pytest.mark.parametrize("silu_activation", [False, True]) +@pytest.mark.parametrize("has_bias", [False, True]) +@pytest.mark.parametrize("width", [4]) +@pytest.mark.parametrize("seqlen", [128, 512, 4096]) +@pytest.mark.parametrize('dim', [64, 4096 + 32]) +@pytest.mark.parametrize('batch', [1, 2]) +def test_causal_conv1d(batch, dim, seqlen, width, has_bias, silu_activation, + itype, channel_last, has_initial_states, + return_final_states): + if not channel_last and (has_initial_states or return_final_states): + pytest.skip( + "Only channel_last support initial_states or return_final_states") + device = "cuda" + rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3) + if itype == torch.bfloat16: + rtol, atol = 1e-2, 5e-2 + # set seed + torch.random.manual_seed(0) + if not channel_last: + x = torch.randn(batch, + 4096 + dim + 64, + seqlen, + device=device, + dtype=itype)[:, 4096:4096 + dim, :] + else: + x = rearrange( + torch.randn(batch, + seqlen, + 4096 + dim + 64, + device=device, + dtype=itype)[:, :, 4096:4096 + dim], "b s d -> b d s") + weight = torch.randn(dim, width, device=device, dtype=itype) + bias = torch.randn(dim, device=device, dtype=itype) if has_bias else None + if has_initial_states: + initial_states = torch.randn(batch, + width - 1, + dim, + device=device, + dtype=itype).transpose(1, 2) + else: + initial_states = None + x_ref = x.detach().clone() + weight_ref = weight.detach().clone() + bias_ref = bias.detach().clone() if bias is not None else None + initial_states_ref = initial_states.detach().clone( + ) if initial_states is not None else None + activation = None if not silu_activation else "silu" + out, final_states = causal_conv1d_fn( + x, + weight, + bias, + initial_states=initial_states, + return_final_states=return_final_states, + activation=activation) + out_ref, final_states_ref = causal_conv1d_ref( + x_ref, + weight_ref, + bias_ref, + initial_states=initial_states_ref, + return_final_states=return_final_states, + activation=activation) + if return_final_states: + assert final_states is not None and final_states_ref is not None + assert torch.allclose(final_states, + final_states_ref, + rtol=rtol, + atol=atol) + + assert torch.allclose(out, out_ref, rtol=rtol, atol=atol) + + if return_final_states: + out += F.sigmoid(final_states).sum(dim=-1, keepdim=True) + out_ref += F.sigmoid(final_states_ref).sum(dim=-1, keepdim=True) + + +@pytest.mark.parametrize("itype", [torch.bfloat16]) +@pytest.mark.parametrize("silu_activation", [False, True]) +@pytest.mark.parametrize("has_bias", [False, True]) +@pytest.mark.parametrize("width", [2, 3, 4]) +@pytest.mark.parametrize("dim", [2048, 2048 + 16, 4096]) +@pytest.mark.parametrize("batch", [1, 2]) +def test_causal_conv1d_update(batch, dim, width, has_bias, silu_activation, + itype): + device = "cuda" + rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3) + if itype == torch.bfloat16: + rtol, atol = 1e-2, 5e-2 + # set seed + torch.random.manual_seed(0) + batch = 2 + x = torch.randn(batch, dim, device=device, dtype=itype) + conv_state = torch.randn(batch, dim, width, device=device, dtype=itype) + weight = torch.randn(dim, + width, + device=device, + dtype=itype, + requires_grad=True) + if has_bias: + bias = torch.randn(dim, device=device, dtype=itype, requires_grad=True) + else: + bias = None + conv_state_ref = conv_state.detach().clone() + activation = None if not silu_activation else "silu" + out = causal_conv1d_update(x, + conv_state, + weight, + bias, + activation=activation) + out_ref = causal_conv1d_update_ref(x, + conv_state_ref, + weight, + bias, + activation=activation) + + assert torch.equal(conv_state, conv_state_ref) + assert torch.allclose(out, out_ref, rtol=rtol, atol=atol) diff --git a/tests/kernels/test_mamba_ssm.py b/tests/kernels/test_mamba_ssm.py new file mode 100644 index 0000000000000..d3cb0a8656a02 --- /dev/null +++ b/tests/kernels/test_mamba_ssm.py @@ -0,0 +1,324 @@ +import pytest +import torch +import torch.nn.functional as F +from einops import rearrange, repeat + +from vllm.model_executor.layers.mamba.ops.mamba_ssm import ( + selective_scan_fn, selective_state_update) + + +def selective_state_update_ref(state, + x, + dt, + A, + B, + C, + D=None, + z=None, + dt_bias=None, + dt_softplus=False): + """ + Argument: + state: (batch, dim, dstate) or (batch, nheads, dim, dstate) + x: (batch, dim) or (batch, nheads, dim) + dt: (batch, dim) or (batch, nheads, dim) + A: (dim, dstate) or (nheads, dim, dstate) + B: (batch, dstate) or (batch, ngroups, dstate) + C: (batch, dstate) or (batch, ngroups, dstate) + D: (dim,) or (nheads, dim) + z: (batch, dim) or (batch, nheads, dim) + dt_bias: (dim,) or (nheads, dim) + Return: + out: (batch, dim) or (batch, nheads, dim) + """ + has_heads = state.dim() > 3 + if state.dim() == 3: + state = state.unsqueeze(1) + if x.dim() == 2: + x = x.unsqueeze(1) + if dt.dim() == 2: + dt = dt.unsqueeze(1) + if A.dim() == 2: + A = A.unsqueeze(0) + if B.dim() == 2: + B = B.unsqueeze(1) + if C.dim() == 2: + C = C.unsqueeze(1) + if D is not None and D.dim() == 1: + D = D.unsqueeze(0) + if z is not None and z.dim() == 2: + z = z.unsqueeze(1) + if dt_bias is not None and dt_bias.dim() == 1: + dt_bias = dt_bias.unsqueeze(0) + batch, nheads, dim, dstate = state.shape + assert x.shape == (batch, nheads, dim) + assert dt.shape == x.shape + assert A.shape == (nheads, dim, dstate) + ngroups = B.shape[1] + assert nheads % ngroups == 0, "nheads must be divisible by ngroups" + assert B.shape == (batch, ngroups, dstate) + assert C.shape == B.shape + if D is not None: + assert D.shape == (nheads, dim) + if z is not None: + assert z.shape == x.shape + if dt_bias is not None: + assert dt_bias.shape == (nheads, dim) + dt = dt + dt_bias + dt = F.softplus(dt) if dt_softplus else dt + dA = torch.exp(rearrange(dt, "b h d -> b h d 1") * + A) # (batch, nheads, dim, dstate) + B = repeat(B, "b g n -> b (g h) n", + h=nheads // ngroups) # (batch, nheads, dstate) + C = repeat(C, "b g n -> b (g h) n", + h=nheads // ngroups) # (batch, nheads, dstate) + dB = rearrange(dt, "b h d -> b h d 1") * rearrange( + B, "b h n -> b h 1 n") # (batch, nheads, dim, dstate) + state.copy_(state * dA + + dB * rearrange(x, "b h d -> b h d 1")) # (batch, dim, dstate + out = torch.einsum("bhdn,bhn->bhd", state.to(C.dtype), C) + if D is not None: + out += (x * D).to(out.dtype) + out = (out if z is None else out * F.silu(z)).to(x.dtype) + if not has_heads: + out = out.squeeze(1) + return out + + +def selective_scan_ref(u, + delta, + A, + B, + C, + D=None, + z=None, + delta_bias=None, + delta_softplus=False, + return_last_state=False, + position_indices=None, + prev_state=None): + """ + u: r(B D L) + delta: r(B D L) + A: c(D N) or r(D N) + B: c(D N) or r(B N L) or r(B N 2L) or r(B G N L) or (B G N L) + C: c(D N) or r(B N L) or r(B N 2L) or r(B G N L) or (B G N L) + D: r(D) + z: r(B D L) + delta_bias: r(D), fp32 + prev_state: r(B D N), fp32 + + out: r(B D L) + last_state (optional): r(B D dstate) or c(B D dstate) + """ + dtype_in = u.dtype + u = u.float() + delta = delta.float() + if delta_bias is not None: + delta = delta + delta_bias[..., None].float() + if delta_softplus: + delta = F.softplus(delta) + batch, dim, dstate = u.shape[0], A.shape[0], A.shape[1] + is_variable_B = B.dim() >= 3 + is_variable_C = C.dim() >= 3 + B = B.float() + C = C.float() + x = A.new_zeros((batch, dim, dstate)) if prev_state is None else prev_state + ys = [] + deltaA = torch.exp(torch.einsum('bdl,dn->bdln', delta, A)) + if not is_variable_B: + deltaB_u = torch.einsum('bdl,dn,bdl->bdln', delta, B, u) + else: + if B.dim() == 3: + deltaB_u = torch.einsum('bdl,bnl,bdl->bdln', delta, B, u) + else: + B = repeat(B, "B G N L -> B (G H) N L", H=dim // B.shape[1]) + deltaB_u = torch.einsum('bdl,bdnl,bdl->bdln', delta, B, u) + if is_variable_C and C.dim() == 4: + C = repeat(C, "B G N L -> B (G H) N L", H=dim // C.shape[1]) + last_state = None + for i in range(u.shape[2]): + if position_indices is not None and position_indices[0, i] == 0: + x = deltaB_u[:, :, i] + else: + x = deltaA[:, :, i] * x + deltaB_u[:, :, i] + if not is_variable_C: + y = torch.einsum('bdn,dn->bd', x, C) + else: + if C.dim() == 3: + y = torch.einsum('bdn,bn->bd', x, C[:, :, i]) + else: + y = torch.einsum('bdn,bdn->bd', x, C[:, :, :, i]) + if i == u.shape[2] - 1: + last_state = x + ys.append(y) + y = torch.stack(ys, dim=2) # (batch dim L) + out = y if D is None else y + u * rearrange(D, "d -> d 1") + if z is not None: + out = out * F.silu(z) + out = out.to(dtype=dtype_in) + return out if not return_last_state else (out, last_state) + + +@pytest.mark.parametrize('wtype', [torch.float32]) +@pytest.mark.parametrize('itype', [torch.float32]) +@pytest.mark.parametrize('seqlen', [128, 256, 512, 1024, 2048, 4096]) +@pytest.mark.parametrize("return_last_state", [True]) +@pytest.mark.parametrize('has_delta_bias', [True]) +@pytest.mark.parametrize('delta_softplus', [True]) +@pytest.mark.parametrize('has_z', [True]) +@pytest.mark.parametrize('has_D', [True]) +@pytest.mark.parametrize("varBC_groups", [1, 2]) +@pytest.mark.parametrize("is_variable_C", [True]) +@pytest.mark.parametrize("is_variable_B", [True]) +@pytest.mark.parametrize("scan_chunks", [1, 2, 3]) +def test_selective_scan(is_variable_B, is_variable_C, varBC_groups, has_D, + has_z, has_delta_bias, delta_softplus, + return_last_state, seqlen, itype, wtype, scan_chunks): + if varBC_groups > 1 and (not is_variable_B or not is_variable_C): + pytest.skip() # This config is not applicable + device = 'cuda' + rtol, atol = (6e-4, 2e-3) if itype == torch.float32 else (3e-3, 5e-3) + if itype == torch.bfloat16: + rtol, atol = 3e-2, 5e-2 + rtolw, atolw = (1e-3, 1e-3) + if has_z: # If we have z, the errors on the weights seem higher + rtolw = max(rtolw, rtol) + atolw = max(atolw, atol) + # set seed + torch.random.manual_seed(0) + batch_size = 2 + dim = 4 + dstate = 8 + A = (-0.5 * torch.rand(dim, dstate, device=device, dtype=wtype)) + if not is_variable_B: + B_shape = [dim, dstate] + elif varBC_groups == 1: + B_shape = [batch_size, dstate, seqlen] + else: + B_shape = [batch_size, varBC_groups, dstate, seqlen] + B = torch.randn(B_shape, + device=device, + dtype=wtype if not is_variable_B else itype) + if not is_variable_C: + C_shape = [dim, dstate] + elif varBC_groups == 1: + C_shape = [batch_size, dstate, seqlen] + else: + C_shape = [batch_size, varBC_groups, dstate, seqlen] + C = torch.randn(C_shape, + device=device, + dtype=wtype if not is_variable_C else itype) + D = torch.randn(dim, device=device, dtype=torch.float32) if has_D else None + z = torch.randn(batch_size, dim, seqlen, device=device, + dtype=itype) if has_z else None + delta_bias = (0.5 * torch.rand(dim, device=device, dtype=torch.float32) + ) if has_delta_bias else None + u = torch.randn(batch_size, dim, seqlen, device=device, dtype=itype) + delta = (0.5 * + torch.rand(batch_size, dim, seqlen, device=device, dtype=itype)) + state = None + state_ref = None + out = None + out_ref = None + outs = [] + for c in range(scan_chunks): + chunked_prompt_len = seqlen // scan_chunks + chunk_start = chunked_prompt_len * c + chunk_end = chunked_prompt_len * (c + 1) + if c == scan_chunks - 1: + chunk_end = seqlen + _B = B + if is_variable_B: + _B = B[..., chunk_start:chunk_end] + _C = C + if is_variable_B: + _C = C[..., chunk_start:chunk_end] + _z = z + if has_z: + assert z is not None + _z = z[..., chunk_start:chunk_end] + out, *rest = selective_scan_fn(u[..., chunk_start:chunk_end], + delta[..., chunk_start:chunk_end], + A, + _B, + _C, + D, + z=_z, + delta_bias=delta_bias, + delta_softplus=delta_softplus, + return_last_state=return_last_state, + prev_state=state if c > 0 else None) + outs.append(out) + if return_last_state: + state = rest[0] + if len(outs) > 1: + out = torch.cat(outs, dim=-1) + out_ref, *rest = selective_scan_ref(u, + delta, + A, + B, + C, + D, + z=z, + delta_bias=delta_bias, + delta_softplus=delta_softplus, + return_last_state=return_last_state) + if return_last_state: + state_ref = rest[0] + + assert out is not None and out_ref is not None + assert torch.allclose(out, out_ref, rtol=rtol, atol=atol) + if return_last_state: + assert state is not None and state_ref is not None + assert torch.allclose(state, state_ref, rtol=rtol, atol=atol) + + +@pytest.mark.parametrize("itype", + [torch.float32, torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("has_z", [False, True]) +@pytest.mark.parametrize("dstate", [16, 32, 64]) +@pytest.mark.parametrize("dim", [2048, 2048 + 16, 4096]) +def test_selective_state_update(dim, dstate, has_z, itype): + device = "cuda" + rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (5e-3, 1e-2) + if itype == torch.bfloat16: + rtol, atol = 1e-2, 5e-2 + if torch.version.hip: + atol *= 2 + # set seed + torch.random.manual_seed(0) + batch_size = 1 + state = torch.randn(batch_size, dim, dstate, dtype=itype, device=device) + x = torch.randn(batch_size, dim, device=device, dtype=itype) + dt = torch.randn(batch_size, dim, device=device, dtype=itype) + dt_bias = torch.rand(dim, device=device) - 4.0 + A = -torch.rand(dim, dstate, device=device) - 1.0 + B = torch.randn(batch_size, dstate, device=device) + C = torch.randn(batch_size, dstate, device=device) + D = torch.randn(dim, device=device) + z = torch.randn_like(x) if has_z else None + state_ref = state.detach().clone() + out = selective_state_update(state, + x, + dt, + A, + B, + C, + D=D, + z=z, + dt_bias=dt_bias, + dt_softplus=True) + out_ref = selective_state_update_ref(state_ref, + x, + dt, + A, + B, + C, + D=D, + z=z, + dt_bias=dt_bias, + dt_softplus=True) + + assert torch.allclose(state, state_ref, rtol=rtol, atol=atol) + assert torch.allclose(out, out_ref, rtol=rtol, atol=atol) diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index e5e7bb6963973..fe254732e7309 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -500,6 +500,36 @@ def ggml_mul_mat_a8( return torch.ops._C.ggml_mul_mat_a8(W, X, quant_type, row) +# mamba +def causal_conv1d_fwd(x: torch.Tensor, weight: torch.Tensor, + bias_: Optional[torch.Tensor], + seq_idx_: Optional[torch.Tensor], + initial_states_: Optional[torch.Tensor], + final_states_out_: Optional[torch.Tensor], + silu_activation: bool) -> torch.Tensor: + return torch.ops._C.causal_conv1d_fwd(x, weight, bias_, seq_idx_, + initial_states_, final_states_out_, + silu_activation) + + +def causal_conv1d_update(x: torch.Tensor, conv_state: torch.Tensor, + weight: torch.Tensor, bias_: Optional[torch.Tensor], + silu_activation: bool) -> torch.Tensor: + return torch.ops._C.causal_conv1d_update(x, conv_state, weight, bias_, + silu_activation) + + +def selective_scan_fwd(u: torch.Tensor, delta: torch.Tensor, A: torch.Tensor, + B: torch.Tensor, C: torch.Tensor, + D_: Optional[torch.Tensor], z_: Optional[torch.Tensor], + delta_bias_: Optional[torch.Tensor], + delta_softplus: bool, index_: Optional[torch.Tensor], + x: Optional[torch.Tensor]) -> List[torch.Tensor]: + return torch.ops._C.selective_scan_fwd(u, delta, A, B, C, D_, z_, + delta_bias_, delta_softplus, index_, + x) + + # moe def moe_align_block_size(topk_ids: torch.Tensor, num_experts: int, block_size: int, sorted_token_ids: torch.Tensor, diff --git a/vllm/model_executor/layers/mamba/__init__.py b/vllm/model_executor/layers/mamba/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/vllm/model_executor/layers/mamba/ops/__init__.py b/vllm/model_executor/layers/mamba/ops/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/vllm/model_executor/layers/mamba/ops/causal_conv1d.py b/vllm/model_executor/layers/mamba/ops/causal_conv1d.py new file mode 100644 index 0000000000000..413c8bc227ae8 --- /dev/null +++ b/vllm/model_executor/layers/mamba/ops/causal_conv1d.py @@ -0,0 +1,86 @@ +# Copyright (c) 2024, Tri Dao. + +from typing import Optional + +import torch + +from vllm import _custom_ops as ops + + +def causal_conv1d_fn( + x: torch.Tensor, + weight: torch.Tensor, + bias: Optional[torch.Tensor] = None, + seq_idx: Optional[torch.Tensor] = None, + initial_states: Optional[torch.Tensor] = None, + return_final_states: bool = False, + final_states_out=None, + activation: str = "silu", +): + """ + x: (batch, dim, seqlen) + weight: (dim, width) + bias: (dim,) + seq_idx: (batch, seqlen) + initial_states: (batch, dim, width - 1) + final_states_out: (batch, dim, width - 1), to be written to + activation: either None or "silu" or "swish" + + out: (batch, dim, seqlen) + """ + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + if x.stride(2) != 1 and x.stride(1) != 1: + x = x.contiguous() + bias = bias.contiguous() if bias is not None else None + if seq_idx is not None: + assert (initial_states is + None), "initial_states must be None if seq_idx is not None" + assert (not return_final_states + ), "If seq_idx is not None, we don't return final_states_out" + seq_idx = seq_idx.contiguous() if seq_idx is not None else None + if initial_states is not None and (initial_states.stride(2) != 1 + and initial_states.stride(1) != 1): + initial_states = initial_states.contiguous() + if return_final_states: + assert ( + x.stride(1) == 1 + ), "Only channel-last layout support returning final_states_out" + if final_states_out is not None: + assert (final_states_out.stride(2) == 1 + or final_states_out.stride(1) == 1) + else: + batch, dim, seqlen = x.shape + width = weight.shape[1] + final_states_out = torch.empty(batch, + width - 1, + dim, + device=x.device, + dtype=x.dtype).transpose(1, 2) + else: + final_states_out = None + + out = ops.causal_conv1d_fwd(x, weight, bias, seq_idx, initial_states, + final_states_out, activation + in ["silu", "swish"]) + return (out, None) if not return_final_states else (out, final_states_out) + + +def causal_conv1d_update(x: torch.Tensor, + conv_state: torch.Tensor, + weight: torch.Tensor, + bias: Optional[torch.Tensor] = None, + activation: Optional[str] = None): + """ + x: (batch, dim) + conv_state: (batch, dim, width) + weight: (dim, width) + bias: (dim,) + + out: (batch, dim) + """ + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + activation_bool = activation in ["silu", "swish"] + return ops.causal_conv1d_update(x, conv_state, weight, bias, + activation_bool) diff --git a/vllm/model_executor/layers/mamba/ops/mamba_ssm.py b/vllm/model_executor/layers/mamba/ops/mamba_ssm.py new file mode 100644 index 0000000000000..869c69214caf2 --- /dev/null +++ b/vllm/model_executor/layers/mamba/ops/mamba_ssm.py @@ -0,0 +1,346 @@ +# Copyright (c) 2024, Tri Dao, Albert Gu. + +import torch +import triton +import triton.language as tl +from packaging import version + +from vllm import _custom_ops as ops + +TRITON3 = version.parse(triton.__version__) >= version.parse("3.0.0") + +if TRITON3: + + @triton.jit + def softplus(dt): + dt = tl.where(dt <= 20.0, tl.math.log(tl.math.exp(dt) + 1), dt) + return dt +else: + + @triton.jit + def softplus(dt): + dt = tl.where(dt <= 20.0, tl.math.log1p(tl.exp(dt)), dt) + return dt + + +@triton.heuristics( + {"HAS_DT_BIAS": lambda args: args["dt_bias_ptr"] is not None}) +@triton.heuristics({"HAS_D": lambda args: args["D_ptr"] is not None}) +@triton.heuristics({"HAS_Z": lambda args: args["z_ptr"] is not None}) +@triton.heuristics( + {"BLOCK_SIZE_DSTATE": lambda args: triton.next_power_of_2(args["dstate"])}) +@triton.jit +def _selective_scan_update_kernel( + # Pointers to matrices + state_ptr, + x_ptr, + dt_ptr, + dt_bias_ptr, + A_ptr, + B_ptr, + C_ptr, + D_ptr, + z_ptr, + out_ptr, + # Matrix dimensions + batch, + nheads, + dim, + dstate, + nheads_ngroups_ratio, + # Strides + stride_state_batch, + stride_state_head, + stride_state_dim, + stride_state_dstate, + stride_x_batch, + stride_x_head, + stride_x_dim, + stride_dt_batch, + stride_dt_head, + stride_dt_dim, + stride_dt_bias_head, + stride_dt_bias_dim, + stride_A_head, + stride_A_dim, + stride_A_dstate, + stride_B_batch, + stride_B_group, + stride_B_dstate, + stride_C_batch, + stride_C_group, + stride_C_dstate, + stride_D_head, + stride_D_dim, + stride_z_batch, + stride_z_head, + stride_z_dim, + stride_out_batch, + stride_out_head, + stride_out_dim, + # Meta-parameters + DT_SOFTPLUS: tl.constexpr, + TIE_HDIM: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + HAS_DT_BIAS: tl.constexpr, + HAS_D: tl.constexpr, + HAS_Z: tl.constexpr, + BLOCK_SIZE_DSTATE: tl.constexpr, +): + pid_m = tl.program_id(axis=0) + pid_b = tl.program_id(axis=1) + pid_h = tl.program_id(axis=2) + state_ptr += pid_b * stride_state_batch + pid_h * stride_state_head + x_ptr += pid_b * stride_x_batch + pid_h * stride_x_head + dt_ptr += pid_b * stride_dt_batch + pid_h * stride_dt_head + if HAS_DT_BIAS: + dt_bias_ptr += pid_h * stride_dt_bias_head + A_ptr += pid_h * stride_A_head + B_ptr += pid_b * stride_B_batch + (pid_h // + nheads_ngroups_ratio) * stride_B_group + C_ptr += pid_b * stride_C_batch + (pid_h // + nheads_ngroups_ratio) * stride_C_group + if HAS_Z: + z_ptr += pid_b * stride_z_batch + pid_h * stride_z_head + out_ptr += pid_b * stride_out_batch + pid_h * stride_out_head + + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = tl.arange(0, BLOCK_SIZE_DSTATE) + state_ptrs = state_ptr + (offs_m[:, None] * stride_state_dim + + offs_n[None, :] * stride_state_dstate) + x_ptrs = x_ptr + offs_m * stride_x_dim + dt_ptrs = dt_ptr + offs_m * stride_dt_dim + if HAS_DT_BIAS: + dt_bias_ptrs = dt_bias_ptr + offs_m * stride_dt_bias_dim + if HAS_D: + D_ptr += pid_h * stride_D_head + A_ptrs = A_ptr + (offs_m[:, None] * stride_A_dim + + offs_n[None, :] * stride_A_dstate) + B_ptrs = B_ptr + offs_n * stride_B_dstate + C_ptrs = C_ptr + offs_n * stride_C_dstate + if HAS_D: + D_ptrs = D_ptr + offs_m * stride_D_dim + if HAS_Z: + z_ptrs = z_ptr + offs_m * stride_z_dim + out_ptrs = out_ptr + offs_m * stride_out_dim + + state = tl.load(state_ptrs, + mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate), + other=0.0) + x = tl.load(x_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) + if not TIE_HDIM: + dt = tl.load(dt_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) + if HAS_DT_BIAS: + dt += tl.load(dt_bias_ptrs, mask=offs_m < dim, + other=0.0).to(tl.float32) + if DT_SOFTPLUS: + dt = softplus(dt) + A = tl.load(A_ptrs, + mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate), + other=0.0).to(tl.float32) + dA = tl.exp(A * dt[:, None]) + else: + dt = tl.load(dt_ptr).to(tl.float32) + if HAS_DT_BIAS: + dt += tl.load(dt_bias_ptr).to(tl.float32) + if DT_SOFTPLUS: + dt = softplus(dt) + A = tl.load(A_ptr).to(tl.float32) + dA = tl.exp(A * dt) # scalar, not a matrix + + B = tl.load(B_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32) + C = tl.load(C_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32) + if HAS_D: + D = tl.load(D_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) + if HAS_Z: + z = tl.load(z_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) + + dB = B[None, :] * dt[:, None] if not TIE_HDIM else B * dt + state = state * dA + dB * x[:, None] + tl.store(state_ptrs, + state, + mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate)) + out = tl.sum(state * C[None, :], axis=1) + if HAS_D: + out += x * D + if HAS_Z: + out *= z * tl.sigmoid(z) + tl.store(out_ptrs, out, mask=offs_m < dim) + + +def selective_state_update(state, + x, + dt, + A, + B, + C, + D=None, + z=None, + dt_bias=None, + dt_softplus=False): + """ + Argument: + state: (batch, dim, dstate) or (batch, nheads, dim, dstate) + x: (batch, dim) or (batch, nheads, dim) + dt: (batch, dim) or (batch, nheads, dim) + A: (dim, dstate) or (nheads, dim, dstate) + B: (batch, dstate) or (batch, ngroups, dstate) + C: (batch, dstate) or (batch, ngroups, dstate) + D: (dim,) or (nheads, dim) + z: (batch, dim) or (batch, nheads, dim) + dt_bias: (dim,) or (nheads, dim) + Return: + out: (batch, dim) or (batch, nheads, dim) + """ + has_heads = state.dim() > 3 + if state.dim() == 3: + state = state.unsqueeze(1) + if x.dim() == 2: + x = x.unsqueeze(1) + if dt.dim() == 2: + dt = dt.unsqueeze(1) + if A.dim() == 2: + A = A.unsqueeze(0) + if B.dim() == 2: + B = B.unsqueeze(1) + if C.dim() == 2: + C = C.unsqueeze(1) + if D is not None and D.dim() == 1: + D = D.unsqueeze(0) + if z is not None and z.dim() == 2: + z = z.unsqueeze(1) + if dt_bias is not None and dt_bias.dim() == 1: + dt_bias = dt_bias.unsqueeze(0) + batch, nheads, dim, dstate = state.shape + assert x.shape == (batch, nheads, dim) + assert dt.shape == x.shape + assert A.shape == (nheads, dim, dstate) + ngroups = B.shape[1] + assert nheads % ngroups == 0, "nheads must be divisible by ngroups" + assert B.shape == (batch, ngroups, dstate) + assert C.shape == B.shape + if D is not None: + assert D.shape == (nheads, dim) + if z is not None: + assert z.shape == x.shape + if dt_bias is not None: + assert dt_bias.shape == (nheads, dim) + out = torch.empty_like(x) + grid = lambda META: (triton.cdiv(dim, META['BLOCK_SIZE_M']), batch, nheads) + z_strides = ((z.stride(0), z.stride(1), z.stride(2)) if z is not None else + (0, 0, 0)) + # We don't want autotune since it will overwrite the state + # We instead tune by hand. + BLOCK_SIZE_M, num_warps = ((32, 4) if dstate <= 16 else + ((16, 4) if dstate <= 32 else + ((8, 4) if dstate <= 64 else + ((4, 4) if dstate <= 128 else ((4, 8)))))) + tie_hdim = A.stride(-1) == 0 and A.stride(-2) == 0 and dt.stride( + -1) == 0 and dt_bias.stride(-1) == 0 + with torch.cuda.device(x.device.index): + _selective_scan_update_kernel[grid]( + state, + x, + dt, + dt_bias, + A, + B, + C, + D, + z, + out, + batch, + nheads, + dim, + dstate, + nheads // ngroups, + state.stride(0), + state.stride(1), + state.stride(2), + state.stride(3), + x.stride(0), + x.stride(1), + x.stride(2), + dt.stride(0), + dt.stride(1), + dt.stride(2), + *(dt_bias.stride(0), + dt_bias.stride(1)) if dt_bias is not None else 0, + A.stride(0), + A.stride(1), + A.stride(2), + B.stride(0), + B.stride(1), + B.stride(2), + C.stride(0), + C.stride(1), + C.stride(2), + *(D.stride(0), D.stride(1)) if D is not None else 0, + z_strides[0], + z_strides[1], + z_strides[2], + out.stride(0), + out.stride(1), + out.stride(2), + dt_softplus, + tie_hdim, + BLOCK_SIZE_M, + num_warps=num_warps, + ) + if not has_heads: + out = out.squeeze(1) + return out + + +def selective_scan_fn(u, + delta, + A, + B, + C, + D=None, + z=None, + delta_bias=None, + delta_softplus=False, + return_last_state=False, + position_indices=None, + prev_state=None): + """if return_last_state is True, returns (out, last_state) + last_state has shape (batch, dim, dstate). + """ + if u.stride(-1) != 1: + u = u.contiguous() + if delta.stride(-1) != 1: + delta = delta.contiguous() + if D is not None: + D = D.contiguous() + if B.stride(-1) != 1: + B = B.contiguous() + if C.stride(-1) != 1: + C = C.contiguous() + if z is not None and z.stride(-1) != 1: + z = z.contiguous() + if B.dim() == 3: + B = B.unsqueeze(1) + if C.dim() == 3: + C = C.unsqueeze(1) + n_chunks = int((u.shape[-1] + 2048 - 1) / 2048) + x = torch.zeros(( + u.shape[0], + u.shape[1], + n_chunks, + int(A.shape[1] * 2), + ), + device=u.device, + dtype=torch.float32, + requires_grad=False) + x[:, :, 0, 0::2] = 1 + if prev_state is not None: + x[:, :, 0, 1::2].copy_(prev_state) + out, x, *rest = ops.selective_scan_fwd(u, delta, A, B, C, D, z, delta_bias, + delta_softplus, position_indices, x) + last_state = x[:, :, -1, 1::2] # (batch, dim, dstate) + if z is None: + return out if not return_last_state else (out, last_state) + else: + out_z = rest[0] + return out_z if not return_last_state else (out_z, last_state) diff --git a/vllm/model_executor/models/jamba.py b/vllm/model_executor/models/jamba.py index caeda4e42d8a0..ac3b59f95f7e0 100644 --- a/vllm/model_executor/models/jamba.py +++ b/vllm/model_executor/models/jamba.py @@ -4,9 +4,6 @@ from typing import Dict, Iterable, List, Optional, Tuple import torch -from causal_conv1d import causal_conv1d_fn, causal_conv1d_update -from mamba_ssm.ops.selective_scan_interface import selective_scan_fn -from mamba_ssm.ops.triton.selective_state_update import selective_state_update from torch import nn from torch.nn.parameter import Parameter from transformers import JambaConfig @@ -24,6 +21,10 @@ ReplicatedLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.mamba.ops.causal_conv1d import ( + causal_conv1d_fn, causal_conv1d_update) +from vllm.model_executor.layers.mamba.ops.mamba_ssm import ( + selective_scan_fn, selective_state_update) from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) from vllm.model_executor.layers.sampler import Sampler @@ -161,7 +162,7 @@ def mamba_forward(self, (self.conv_kernel_size - hidden_states.shape[-1], 0)) cache_params.conv_state.copy_(conv_states) - hidden_states = causal_conv1d_fn( + hidden_states, _ = causal_conv1d_fn( hidden_states, conv_weights, self.conv1d.bias,