
Fix int32_t to auto for code around WeightRow #4045

Open · wants to merge 1 commit into base: main
@@ -33,7 +33,7 @@
using namespace fbgemm_gpu;
using Tensor = at::Tensor;

-[[maybe_unused]] static constexpr float kINT8QparamsBytes = 8;
+[[maybe_unused]] static constexpr int32_t kINT8QparamsBytes = 8;

////////////////////////////////////////////////////////////////////////////////
// Kernel Definitions
9 changes: 4 additions & 5 deletions fbgemm_gpu/include/fbgemm_gpu/utils/cuda_prelude.cuh
@@ -81,11 +81,10 @@ static constexpr uint32_t kFullWarpMask = 0xff'ff'ff'ff;

static constexpr float kQParamEps = 1e-8f;

-/* For rowwise int8 quantization, two quantization parameters (qparams)
-will be stored at the end of each row in FP32 formats, appending a total of
-8 bytes to each row.
-*/
-static constexpr float kINT8QparamsBytes = 8;
+// For rowwise int8 quantization, two quantization parameters (qparams) will be
+// stored at the end of each row in FP32 formats, appending a total of 8 bytes
+// to each row.
+static constexpr int32_t kINT8QparamsBytes = 8;

template <typename T>
DEVICE_INLINE T shfl_xor(
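For context on the constant above: a rowwise-int8 embedding row stores its int8 payload followed by two FP32 qparams, so the row stride is the dimension plus 8 bytes, and an integer constant keeps that stride arithmetic in integer types (the old float constant forced a float round trip on every use). The sketch below is illustrative only; the names (QParams, row_stride_bytes, read_qparams) and the scale/zero-point ordering are assumptions, not fbgemm_gpu APIs.

#include <cstdint>
#include <cstring>

// Hypothetical layout helper for a rowwise-int8 row: dim int8 values followed
// by two FP32 qparams, which is where the 8-byte constant comes from.
struct QParams {
  float scale;       // assumed ordering: scale first
  float zero_point;  // then zero point / offset
};

constexpr int32_t kQparamsBytes = 8;  // 2 * sizeof(float)

inline int64_t row_stride_bytes(int64_t dim) {
  // dim int8 elements plus the trailing qparams, all in integer arithmetic.
  return dim + kQparamsBytes;
}

inline QParams read_qparams(const uint8_t* row, int64_t dim) {
  QParams q;
  std::memcpy(&q, row + dim, sizeof(QParams));  // qparams sit right after the payload
  return q;
}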
20 changes: 10 additions & 10 deletions fbgemm_gpu/include/fbgemm_gpu/utils/weight_row.cuh
@@ -214,12 +214,12 @@ struct WeightRow {

DEVICE_INLINE void warp_copy_to_cache(
cache_t* dst_row,
-const int32_t dim_length,
-const int32_t num_lanes,
-const int32_t lane_id) {
+const uint32_t dim_length,
+const uint32_t num_lanes,
+const uint32_t lane_id) {
if constexpr (std::is_same_v<emb_t, cache_t>) {
// No conversion required when emb_t and cache_t are the same type
-for (int32_t d = lane_id * 4; d < dim_length; d += num_lanes * 4) {
+for (auto d = lane_id * 4; d < dim_length; d += num_lanes * 4) {
same_type_vector_copy(
dst_row + d, reinterpret_cast<const cache_t*>(row_ + d));
}
@@ -229,17 +229,17 @@

// Copy over for each warp-sized slice of Vec4's
// Does 2-step conversion: weight_t -> FP32 -> cache_t
-for (int32_t d = lane_id * 4; d < dim_length; d += num_lanes * 4) {
+for (auto d = lane_id * 4; d < dim_length; d += num_lanes * 4) {
const auto slice = load(d, qparams);
quantize_store(dst_row + d, slice, stoc_rounding_state_ptr_, qparams);
}
}
}

DEVICE_INLINE void warp_evict_cache(
-const int32_t dim_length,
-const int32_t num_lanes,
-const int32_t lane_id) {
+const uint32_t dim_length,
+const uint32_t num_lanes,
+const uint32_t lane_id) {
float2 qparams;

if constexpr (std::is_same_v<emb_t, uint8_t>) {
@@ -248,7 +248,7 @@
std::numeric_limits<at::acc_type<cache_t, true>>::lowest();

// Compute the qparams from the cache row (not embedding row) weights
-for (int32_t d = lane_id; d * 4 < dim_length; d += num_lanes) {
+for (auto d = lane_id; d * 4 < dim_length; d += num_lanes) {
const auto cache_slice = load(d * 4, qparams); // qparams not used
local_max = max(local_max, cache_slice.vmax());
local_min = min(local_min, cache_slice.vmin());
@@ -263,7 +263,7 @@
}
}

-for (int32_t d = lane_id * 4; d < dim_length; d += num_lanes * 4) {
+for (auto d = lane_id * 4; d < dim_length; d += num_lanes * 4) {
// Evict the slice into the embedding row
evict_cache(d, qparams);
}
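The loops above stride a row across warp lanes in groups of four elements; with dim_length, num_lanes, and lane_id all uint32_t, `auto d` deduces uint32_t and the loop condition no longer mixes signed and unsigned operands. A minimal standalone sketch of the same access pattern, assuming 16-byte-aligned float rows whose length is a multiple of 4 (not the actual WeightRow code):

__device__ void warp_strided_copy(
    float* dst,
    const float* src,
    const uint32_t dim_length,
    const uint32_t num_lanes,
    const uint32_t lane_id) {
  // Each lane copies every num_lanes-th group of 4 consecutive floats, the
  // same lane-strided pattern as warp_copy_to_cache / warp_evict_cache.
  for (auto d = lane_id * 4; d < dim_length; d += num_lanes * 4) {
    *reinterpret_cast<float4*>(dst + d) =
        *reinterpret_cast<const float4*>(src + d);
  }
}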
14 changes: 7 additions & 7 deletions fbgemm_gpu/src/split_embeddings_cache/lfu_cache_populate.cu
@@ -37,7 +37,7 @@ __global__ __launch_bounds__(kCacheMaxThreads) void lfu_cache_insert_kernel(
bool stochastic_rounding,
at::PhiloxCudaState stochastic_rounding_philox_args) {
const int32_t C = lxu_cache_state.size(0);
-for (int32_t n = blockIdx.x * blockDim.y + threadIdx.y; n < *N_unique;
+for (auto n = blockIdx.x * blockDim.y + threadIdx.y; n < *N_unique;
n += gridDim.x * blockDim.y) {
// check if this warp is responsible for this whole segment.
const bool segment_start =
@@ -64,21 +64,21 @@ __global__ __launch_bounds__(kCacheMaxThreads) void lfu_cache_insert_kernel(

// now, we need to insert the (unique!) values in indices[n:n + SL] into
// our slots.
-const int32_t slot = threadIdx.x;
+const auto slot = threadIdx.x;
const int64_t current_idx = lxu_cache_state[cache_set][slot];
const int64_t current_lfu_cost =
(current_idx != static_cast<int64_t>(kCacheStateInvalid))
? lfu_state[current_idx]
: -1;
int64_t costs[1] = {current_lfu_cost};
-int32_t slots[1] = {slot};
+uint32_t slots[1] = {slot};

-BitonicSort<int64_t, int32_t, 1, Comparator<int64_t>>::sort(costs, slots);
-const int32_t sorted_slot = slots[0];
-const int64_t sorted_lfu_cost = costs[0];
+BitonicSort<int64_t, uint32_t, 1, Comparator<int64_t>>::sort(costs, slots);
+const auto sorted_slot = slots[0];
+const auto sorted_lfu_cost = costs[0];

for (int32_t l = 0; l < min(SL, kWarpSize); ++l) {
-const int32_t insert_slot = shfl_sync(sorted_slot, l);
+const auto insert_slot = shfl_sync(sorted_slot, l);
const int64_t insert_current_lfu_cost = shfl_sync(sorted_lfu_cost, l);
const int64_t insert_idx = cache_set_sorted_indices[n + l];
const int64_t insert_lfu_cost = lfu_state[insert_idx];
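The index changes here are about type deduction: blockIdx, blockDim, threadIdx, and gridDim expose unsigned int members, so an expression like `blockIdx.x * blockDim.y + threadIdx.y` is already unsigned, and `auto n` keeps that type instead of converting it to a signed int32_t. A minimal sketch of the same grid-stride pattern, with an assumed uint32_t row count standing in for the kernel's *N_unique:

__global__ void grid_stride_rows(const uint32_t num_rows, int32_t* out) {
  // One warp row (threadIdx.y) per logical row n, advancing by the number of
  // warp rows in the whole grid on each iteration.
  for (auto n = blockIdx.x * blockDim.y + threadIdx.y; n < num_rows;
       n += gridDim.x * blockDim.y) {
    out[n] = static_cast<int32_t>(n);
  }
}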
14 changes: 7 additions & 7 deletions fbgemm_gpu/src/split_embeddings_cache/lru_cache_populate.cu
@@ -45,7 +45,7 @@ __global__ __launch_bounds__(kMaxThreads) void lru_cache_insert_kernel(
lxu_cache_locking_counter) {
const int32_t C = lxu_cache_state.size(0);
int32_t n_conflict_misses = 0;
-for (int32_t n = blockIdx.x * blockDim.y + threadIdx.y; n < *N_unique;
+for (auto n = blockIdx.x * blockDim.y + threadIdx.y; n < *N_unique;
n += gridDim.x * blockDim.y) {
// check if this warp is responsible for this whole segment.
const bool segment_start =
@@ -70,20 +70,20 @@ __global__ __launch_bounds__(kMaxThreads) void lru_cache_insert_kernel(

// now, we need to insert the (unique!) values in indices[n:n + SL] into
// our slots.
-const int32_t slot = threadIdx.x;
+const auto slot = threadIdx.x;
const int64_t slot_time = lru_state[cache_set][slot];
int64_t costs[1] = {slot_time};
-int32_t slots[1] = {slot};
+uint32_t slots[1] = {slot};

-BitonicSort<int64_t, int32_t, 1, Comparator<int64_t>>::sort(costs, slots);
-const int32_t sorted_slot = slots[0];
-const int64_t sorted_lru_cost = costs[0];
+BitonicSort<int64_t, uint32_t, 1, Comparator<int64_t>>::sort(costs, slots);
+const auto sorted_slot = slots[0];
+const auto sorted_lru_cost = costs[0];
const auto stoc_rounding_salt = kWarpSize *
(blockIdx.x * blockDim.x * blockDim.y + threadIdx.y * blockDim.x +
threadIdx.x);

for (int32_t l = 0; l < min(SL, kWarpSize); ++l) {
-const int32_t insert_slot = shfl_sync(sorted_slot, l);
+const auto insert_slot = shfl_sync(sorted_slot, l);
if (lock_cache_line) {
auto count = lxu_cache_locking_counter[cache_set][insert_slot];
if (count > 0) {
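The slot payload follows the same reasoning: `slot` is now `threadIdx.x` (unsigned int), so the slots[] array and the BitonicSort value type move to uint32_t to match. As a reference point only, the sketch below shows the simplest warp-shuffle pattern involved, a warp-wide argmin over (cost, slot) pairs carrying a uint32_t payload; it is not fbgemm_gpu's BitonicSort, which sorts across the warp so the loop above can pull the l-th cheapest slot from lane l via shfl_sync. It assumes a 32-lane warp.

__device__ void warp_argmin_cost(
    int64_t cost,
    uint32_t slot,
    int64_t& min_cost,
    uint32_t& min_slot) {
  min_cost = cost;
  min_slot = slot;
  // Butterfly reduction: after log2(32) = 5 steps every lane holds the
  // warp-wide minimum cost together with the slot that produced it.
  for (int offset = 16; offset > 0; offset >>= 1) {
    const auto other_cost = __shfl_xor_sync(0xffffffffu, min_cost, offset);
    const auto other_slot = __shfl_xor_sync(0xffffffffu, min_slot, offset);
    if (other_cost < min_cost) {
      min_cost = other_cost;
      min_slot = other_slot;
    }
  }
}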
14 changes: 7 additions & 7 deletions fbgemm_gpu/src/split_embeddings_cache/lxu_cache.cu
@@ -35,7 +35,7 @@ __global__ __launch_bounds__(kMaxThreads) void lxu_cache_flush_kernel(
bool stochastic_rounding,
at::PhiloxCudaState stochastic_rounding_philox_args) {
const int32_t B = lxu_cache_weights.size(0);
-const int32_t b = blockIdx.x * blockDim.y + threadIdx.y;
+const auto b = blockIdx.x * blockDim.y + threadIdx.y;
if (b >= B) {
return;
}
@@ -55,7 +55,7 @@ __global__ __launch_bounds__(kMaxThreads) void lxu_cache_flush_kernel(
if constexpr (std::is_same_v<emb_t, uint8_t>) {
D_emb += kINT8QparamsBytes;
}
-StochasticRoundingRNGState state;

auto weight_row = WeightRow<emb_t, cache_t, at::acc_type<cache_t, true>>(
&weights[weights_offset_current + idx_current * D_emb + 0],
&lxu_cache_weights[b][0],
Expand All @@ -73,7 +73,7 @@ __global__ __launch_bounds__(kMaxThreads) void lxu_cache_flush_kernel(
weight_row.store_qparams(qparams);
}
}
-for (int32_t d = threadIdx.x * 4; d < D_current; d += blockDim.x * 4) {
+for (auto d = threadIdx.x * 4; d < D_current; d += blockDim.x * 4) {
weight_row.evict_cache(d, qparams);
}
}
@@ -175,7 +175,7 @@ __launch_bounds__(kMaxThreads) void lxu_cache_locking_counter_decrement_kernel(
lxu_cache_locking_counter,
pta::PackedTensorAccessor32<int32_t, 2, at::RestrictPtrTraits> count) {
const int32_t C = lxu_cache_locking_counter.size(0);
-for (int32_t i = blockIdx.x * blockDim.y + threadIdx.y; i < C;
+for (auto i = blockIdx.x * blockDim.y + threadIdx.y; i < C;
i += gridDim.x * blockDim.y) {
const auto j = threadIdx.x;
if (count[i][j] > 0) {
@@ -259,7 +259,7 @@ __global__ __launch_bounds__(kMaxThreads) void lxu_cache_lookup_kernel(
const int32_t C = lxu_cache_state.size(0);
const int32_t N =
N_unique == nullptr ? linear_cache_indices.size(0) : *N_unique;
-const int32_t n0 =
+const auto n0 =
blockIdx.x * blockDim.y * blockDim.x + threadIdx.y * blockDim.x;
if (n0 >= N) {
return;
@@ -270,7 +270,7 @@ __global__ __launch_bounds__(kMaxThreads) void lxu_cache_lookup_kernel(
int32_t n_hits = 0;
const auto slot = threadIdx.x;
for (int i = 0; i < blockDim.x; ++i) {
-int32_t n = n0 + i;
+const auto n = n0 + i;
if (n >= N) {
continue;
}
@@ -303,7 +303,7 @@ __global__ __launch_bounds__(kMaxThreads) void lxu_cache_lookup_kernel(
}
}

-const int32_t n = n0 + threadIdx.x;
+const auto n = n0 + threadIdx.x;
if (n < N) {
lxu_cache_locations[n] = cache_location;
}
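When the embedding type is uint8, the flush path above recomputes the row's qparams from the cache weights (warp-wide min/max, then store_qparams) before evicting. As a reference point only, a common rowwise-int8 choice derives the scale from the row's range and uses the minimum as the offset; the exact formula and epsilon handling in fbgemm_gpu (e.g. via kQParamEps) may differ, so treat this as an assumption rather than the library's scheme:

__device__ float2 make_qparams(float row_min, float row_max) {
  float2 qparams;
  qparams.x = (row_max - row_min) / 255.0f;  // scale over the 8-bit range
  qparams.y = row_min;                       // offset / zero point
  return qparams;
}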