Dequantization Utils Library #2521

Merged: 3 commits, Nov 19, 2022
175 changes: 175 additions & 0 deletions csrc/includes/dequantization_utils.h
@@ -0,0 +1,175 @@
/*
Copyright 2022 The Microsoft DeepSpeed Team
*/

#include "conversion_utils.h"
#include "ds_kernel_utils.h"
#include "quantization.h"
#include "quantization_utils.h"

namespace cg = cooperative_groups;

#pragma once

namespace dequantize {
using Type = quantize::Type;

template <Type qType, int numBits>
using Params = quantize::Params<qType, numBits>;

constexpr int granularity = quantize::granularity;
using PackedInt4 = quantize::PackedInt4;

constexpr int h_per_chunk = granularity / sizeof(__half);
constexpr int h2_per_chunk = granularity / sizeof(__half2);
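// With the 16-byte quantize::granularity, h_per_chunk is 8 and h2_per_chunk is 4.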

/*
Device function that reads quantized data from global memory, dequantizes
it, and stores it to global memory.
Template Arguments :
numBits - Number of bits in quantized element. int: 4, 8
qType - Type of quantization to perform. Type::Symmetric or Type::Asymmetric
unroll - Number of load steps to unroll internally. int
threads - Number of threads performing the dequantization. int
Function arguments:
global_output - __half pointer in global memory
data - Quantized data in global memory
global_params - Quantization parameters in global memory
elems_per_group - Number of elements in each quantization group
total_elems - Tensor size (note: does not need to be a multiple of elems_per_group)
*/
template <int numBits, Type qType, int unroll, int threads>
DS_D_INLINE void to_global(__half* global_output,
const int8_t* data,
const float* global_params,
const int elems_per_group,
const int total_elems);

/*
Device function that dequantizes quantized input data into 16 bytes (one chunk) of
__half output.
Template Arguments :
numBits - Number of bits in quantized element. int : 8 or 4
qType - Type of quantization to perform. Type::Symmetric or Type::Asymmetric
Function Arguments :
local_output - Local array to store dequantized data. __half* or __half2*
data - Pointer to quantized input data. int8_t*
q_params - Parameters for the dequantization. Params<qType, numBits>
*/
template <int numBits, Type qType>
DS_D_INLINE void chunk(__half2* local_output, const int8_t* data, Params<qType, numBits> q_params);

template <int numBits, Type qType>
DS_D_INLINE void chunk(__half* local_output, const int8_t* data, Params<qType, numBits> q_params);

/**************** Implementations ******************/

template <int numBits, Type qType>
DS_D_INLINE void chunk(__half* local_output, const int8_t* data, Params<qType, numBits> q_params)
{
constexpr int32_t num_elems_packed = 8 / numBits;
constexpr int32_t iters = h_per_chunk / num_elems_packed;

#pragma unroll
for (int i = 0; i < iters; i++) {
if constexpr (num_elems_packed == 1) {
local_output[i] = q_params.dequantize(data[i]);
} else {
auto accessible_data = *(PackedInt4*)(&data[i]);
local_output[2 * i] = q_params.dequantize(accessible_data.low);
local_output[2 * i + 1] = q_params.dequantize(accessible_data.high);
}
}
}
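/*
Example: with numBits = 4, num_elems_packed == 2, so each input byte is
reinterpreted as a PackedInt4 whose low and high nibbles dequantize into two
adjacent __half outputs; with numBits = 8 each byte maps to one output element.
*/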

template <int numBits, Type qType>
DS_D_INLINE void chunk(__half2* local_output, const int8_t* data, Params<qType, numBits> q_params)
{
__half* local_output_cast = reinterpret_cast<__half*>(local_output);
chunk<numBits>(local_output_cast, data, q_params);
}

template <int numBits, Type qType, int unroll, int threads>
DS_D_INLINE void _to_global(__half* global_output,
const int8_t* data,
const float* global_params,
const int elems_per_group,
const int total_elems)
{
cg::thread_block tb = cg::this_thread_block();
cg::thread_block_tile<hw_warp_size> warp = cg::tiled_partition<hw_warp_size>(tb);

// Load constants
// TODO(cmikeh2): Refactor into functions?
constexpr int load_granularity = granularity * numBits / 16;
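// Equal to numBits bytes of quantized input per 16-byte output chunk (8 for 8-bit, 4 for 4-bit).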
constexpr int load_step_stride = load_granularity * threads;
constexpr int load_block_stride = load_step_stride * unroll;

// Store constants
constexpr int store_step_stride = h_per_chunk * threads;
constexpr int store_block_stride = store_step_stride * unroll;

// Load offsets
const int load_block_offset = tb.group_index().x * load_block_stride;
// Note: we can use `load_granularity` since the dtype is `int8_t`.
const int load_thread_offset = tb.thread_index().x * load_granularity;
const int8_t* load_base = data + load_block_offset + load_thread_offset;

// Store offsets
const int store_block_offset = tb.group_index().x * store_block_stride;
const int store_thread_offset = tb.thread_index().x * h_per_chunk;
const int elem_id_base = store_block_offset + store_thread_offset;

int8_t local_load_buffer[load_granularity * unroll];
__half local_dequant_buffer[h_per_chunk * unroll];

/*
Note: Splitting this loop in half gave about 3-5% performance increase for reasons that aren't
totally clear to me, so this is a deliberately weird code structure.
*/
#pragma unroll
for (int i = 0; i < unroll; i++) {
const int elem_id_iter = elem_id_base + i * store_step_stride;

if (elem_id_iter < total_elems) {
mem_access::load_global<load_granularity>(local_load_buffer + i * load_granularity,
load_base + i * load_step_stride);
}
}

#pragma unroll
for (int i = 0; i < unroll; i++) {
const int elem_id_iter = elem_id_base + i * store_step_stride;
if (elem_id_iter < total_elems) {
// TODO(cmikeh2): Can we amortize this division? Perform it once on the first iteration
// and use indexing math to do division-free interpolation of the successive groups?
const int group_index = elem_id_iter / elems_per_group;
Params<qType, numBits> q_params(global_params, group_index);

chunk<numBits, qType>(local_dequant_buffer + i * h_per_chunk,
local_load_buffer + i * load_granularity,
q_params);
mem_access::store_global<granularity>(global_output + elem_id_iter,
local_dequant_buffer + i * h_per_chunk);
}
}
}

template <int numBits, Type qType, int unroll, int threads>
DS_D_INLINE void to_global(__half* global_output,
const int8_t* data,
const float* global_params,
const int elems_per_group,
const int total_elems)
{
if constexpr (numBits == 4 || numBits == 8) {
_to_global<numBits, qType, unroll, threads>(
global_output, data, global_params, elems_per_group, total_elems);
} else if constexpr (numBits == 3) {
// TODO(cmikeh2): Need this implementation
assert(false);
} else {
assert(false);
}
}

} // namespace dequantize
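
For reference, the arithmetic that chunk() and to_global() perform can be checked
against a plain host-side loop. The sketch below is hypothetical and not part of
this PR; it assumes the symmetric layout in which q_params[g] holds the
dequantization scale for group g (i.e., the reciprocal written by Params::store()).

// Hypothetical host-side reference for 8-bit symmetric dequantization.
#include <cstdint>
#include <vector>

std::vector<float> dequantize_ref(const std::vector<int8_t>& q_data,
                                  const std::vector<float>& q_params,
                                  int elems_per_group)
{
    std::vector<float> out(q_data.size());
    for (size_t i = 0; i < q_data.size(); i++) {
        const int group = static_cast<int>(i) / elems_per_group;
        // Mirrors Params<Type::Symmetric, 8>::dequantize: value * scale.
        out[i] = static_cast<float>(q_data[i]) * q_params[group];
    }
    return out;
}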
9 changes: 9 additions & 0 deletions csrc/includes/quantization.h
@@ -25,6 +25,15 @@ void launch_quant(int8_t* output_data,
int elems_per_group,
cudaStream_t stream);

void launch_dequantize_kernel(__half* dequant_data,
const int8_t* q_data,
const float* q_params,
quantize::Type q_type,
int num_bits,
int elems_per_group,
int total_elems,
cudaStream_t stream);

template <typename T>
void launch_fake_quantize_kernel(T* vals,
int total_count,
44 changes: 41 additions & 3 deletions csrc/includes/quantization_utils.h
@@ -1,4 +1,8 @@
#include <cstdio>
/*
Copyright 2022 The Microsoft DeepSpeed Team
*/

#include <cassert>
#include "conversion_utils.h"
#include "ds_kernel_utils.h"
#include "memory_access_utils.h"
@@ -33,7 +37,12 @@ class Params {
*/
DS_D_INLINE int8_t quantize(__half val);

DS_D_INLINE __half dequantize(int8_t val);

DS_D_INLINE void store(float* params, int group_index);

// Initialize from memory
DS_D_INLINE Params(const float* params, int group_index);
};

template <int numBits>
@@ -61,11 +70,22 @@ class Params<Type::Symmetric, numBits> {
return (int8_t)data_i32;
}

DS_D_INLINE __half dequantize(int8_t val)
{
const float val_deq_f = conversion::to<float>(val) * scale;
return conversion::to<__half>(val_deq_f);
}

DS_D_INLINE void store(float* params, int group_index)
{
const float store_scale = 1 / scale;
mem_access::store_global<sizeof(float)>(params + group_index, &store_scale);
}

DS_D_INLINE Params(const float* params, int group_index)
{
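// Note: store() above writes the reciprocal of the quantization scale, so the
// value loaded here is directly the dequantization scale used by dequantize().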
mem_access::load_global<sizeof(float)>(&scale, params + group_index);
}
};

template <int numBits>
@@ -84,10 +104,14 @@ class Params<Type::IntegerSymmetric, numBits> {
return (int8_t)data_i32;
}

DS_D_INLINE __half dequantize(int8_t val) { assert(false); return __half(); }  // not supported

DS_D_INLINE void store(float* params, int group_index)
{
mem_access::store_global<sizeof(float)>(params + group_index, &scale);
}

DS_D_INLINE Params(const float* params, int group_index) { assert(false); }
};

template <int numBits>
@@ -117,12 +141,26 @@ class Params<Type::Asymmetric, numBits> {
return (int8_t)data_i32;
}

DS_D_INLINE __half dequantize(int8_t val)
{
const float val_deq_f = conversion::to<float>(val) * scale + offset;
return conversion::to<__half>(val_deq_f);
}

DS_D_INLINE void store(float* params, int group_index)
{
// Codegen should turn this into stg.64
const float store_scale = 1 / scale;
mem_access::store_global<sizeof(float)>(params + 2 * group_index, &store_scale);
mem_access::store_global<sizeof(float)>(params + 2 * group_index + 1, &offset);
}

DS_D_INLINE Params(const float* params, int group_index)
{
// Codegen should turn this into ldg.64
mem_access::load_global<sizeof(float)>(&scale, params + 2 * group_index);
mem_access::load_global<sizeof(float)>(&offset, params + 2 * group_index + 1);
}
};

/*
@@ -293,7 +331,7 @@ Template Arguments :
numBits - Number of bits in quantized element. int : 8 or 4
qType - Type of quantization to perform. Type::Symmetric or Type::Asymmetric
Function Arguments :
-local_output - Pointer to shared memory to store quantized data. int8_t*
+local_output - Pointer to local memory to store quantized data. int8_t*
data - Pointer to input data. __half*
Params - Parameters for quantization. Params<qType, numBits>
*/
@@ -306,7 +344,7 @@ Template Arguments :
numBits - Number of bits in quantized element. int : 8 or 4
qType - Type of quantization to perform. Type::Symmetric or Type::Asymmetric
Function Arguments :
-local_output - Pointer to shared memory to store quantized data. int8_t*
+local_output - Pointer to local memory to store quantized data. int8_t*
data - Pointer to input data. __half2*
Params - Parameters for quantization. Params<qType, numBits>
*/
2 changes: 1 addition & 1 deletion csrc/includes/reduction_utils.h
@@ -263,7 +263,7 @@ DS_D_INLINE __half init<ROpType::Min>()
}

template <>
-__half init<ROpType::Max>()
+DS_D_INLINE __half init<ROpType::Max>()
{
constexpr __half_raw neg_inf = {0xFC00};
return __half(neg_inf);
52 changes: 52 additions & 0 deletions csrc/quantization/dequantize.cu
@@ -0,0 +1,52 @@
/*
Copyright 2022 The Microsoft DeepSpeed Team
*/

#include "dequantization_utils.h"
#include "memory_access_utils.h"

namespace cg = cooperative_groups;

template <int numBits, dequantize::Type qType, int unroll, int threads>
__global__ void dequantize_kernel(__half* __restrict__ dequant_data,
const int8_t* __restrict__ q_data,
const float* __restrict__ q_params,
int elems_per_group,
int total_elems)
{
dequantize::to_global<numBits, qType, unroll, threads>(
dequant_data, q_data, q_params, elems_per_group, total_elems);
}

#define LAUNCH_DEQUANT_KERNEL(num_bits, q_type) \
dequantize_kernel<num_bits, q_type, unroll, threads><<<grid, block, 0, stream>>>( \
dequant_data, q_data, q_params, elems_per_group, total_elems);

void launch_dequantize_kernel(__half* dequant_data,
const int8_t* q_data,
const float* q_params,
quantize::Type q_type,
int num_bits,
int elems_per_group,
int total_elems,
cudaStream_t stream)
{
constexpr int unroll = 8;
constexpr int threads = 512;
constexpr int elems_per_block = unroll * threads * dequantize::h_per_chunk;
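// With h_per_chunk = 8, each block covers 8 * 512 * 8 = 32768 elements.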

const dim3 block(threads);
const dim3 grid((total_elems + elems_per_block - 1) / elems_per_block);

// TODO(cmikeh2): It may make sense to tune unroll, there is perf benefit for large
// problem sizes with this large unroll value.
if (num_bits == 8 && q_type == quantize::Type::Symmetric) {
LAUNCH_DEQUANT_KERNEL(8, quantize::Type::Symmetric);
} else if (num_bits == 8 && q_type == quantize::Type::Asymmetric) {
LAUNCH_DEQUANT_KERNEL(8, quantize::Type::Asymmetric);
} else if (num_bits == 4 && q_type == quantize::Type::Symmetric) {
LAUNCH_DEQUANT_KERNEL(4, quantize::Type::Symmetric);
} else if (num_bits == 4 && q_type == quantize::Type::Asymmetric) {
LAUNCH_DEQUANT_KERNEL(4, quantize::Type::Asymmetric);
}
}
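
A minimal host-side usage sketch for the new launcher (hypothetical call site;
buffer names and sizes are illustrative and error checking is omitted):

// Hypothetical call site for launch_dequantize_kernel.
const int total_elems = 1 << 20;
const int elems_per_group = 512;
const int num_groups = total_elems / elems_per_group;

__half* dequant_data;  // output: total_elems dequantized values
int8_t* q_data;        // input: 8-bit quantized tensor
float* q_params;       // input: one scale per group (symmetric)

cudaMalloc(&dequant_data, total_elems * sizeof(__half));
cudaMalloc(&q_data, total_elems * sizeof(int8_t));
cudaMalloc(&q_params, num_groups * sizeof(float));

cudaStream_t stream;
cudaStreamCreate(&stream);

launch_dequantize_kernel(dequant_data,
                         q_data,
                         q_params,
                         quantize::Type::Symmetric,
                         /*num_bits=*/8,
                         elems_per_group,
                         total_elems,
                         stream);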