
Commit bbdd76c

q10 authored and facebook-github-bot committed
Break up fbgemm_cuda_utils.cuh, pt 8 (#2807)
Summary:
X-link: facebookresearch/FBGEMM#14
Pull Request resolved: #2807

- Break up `fbgemm_cuda_utils.cuh`, pt 8

Reviewed By: spcyppt
Differential Revision: D59412344
fbshipit-source-id: d9acf70a666316e8b0d28726bb147502769313b1
1 parent 24e6f96 commit bbdd76c

File tree

11 files changed: +363 −338 lines changed


fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cu

Lines changed: 2 additions & 1 deletion
@@ -17,7 +17,8 @@
 #include <algorithm>
 
 #include <fbgemm_gpu/sparse_ops_utils.h>
-#include "fbgemm_gpu/fbgemm_cuda_utils.cuh"
+#include "fbgemm_gpu/utils/cuda_prelude.cuh"
+#include "fbgemm_gpu/utils/stochastic_rounding.cuh"
 
 #if !( \
     defined(USE_ROCM) || \
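
Note on the new include: `quantize.cu` now pulls in `stochastic_rounding.cuh` directly, since the rounding helpers it relies on dispatch between round-to-nearest and stochastic rounding. The sketch below is a host-side illustration of that distinction only; the function names are hypothetical and this is not the FBGEMM device API.

```cuda
// Hypothetical illustration (not FBGEMM code): nearest vs. stochastic
// rounding of a float onto a fixed-point grid with step `scale`.
#include <cmath>
#include <cstdio>
#include <random>

// Round-to-nearest: deterministic, always picks the closer grid point.
float round_nearest(float x, float scale) {
  return std::round(x / scale) * scale;
}

// Stochastic rounding: round up with probability equal to the fractional
// distance to the upper grid point, so the expected result equals x.
float round_stochastic(float x, float scale, std::mt19937& rng) {
  std::uniform_real_distribution<float> uniform(0.0f, 1.0f);
  const float q = x / scale;
  const float lo = std::floor(q);
  const float frac = q - lo;
  return (uniform(rng) < frac ? lo + 1.0f : lo) * scale;
}

int main() {
  std::mt19937 rng(42);
  const float x = 0.30f, scale = 0.25f; // grid: ..., 0.25, 0.50, ...
  double acc = 0.0;
  for (int i = 0; i < 100000; ++i) acc += round_stochastic(x, scale, rng);
  std::printf("nearest = %.4f, mean(stochastic) = %.4f\n",
              round_nearest(x, scale), acc / 100000.0);
  return 0;
}
```

Averaged over many trials the stochastic result converges to the unrounded input, which is the usual motivation for stochastic rounding when writing into low-precision storage.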

fbgemm_gpu/include/fbgemm_gpu/embedding_forward_template_helpers.cuh

Lines changed: 1 addition & 0 deletions
@@ -35,6 +35,7 @@
 #include "fbgemm_gpu/fbgemm_cuda_utils.cuh"
 #include "fbgemm_gpu/fbgemm_tensor_accessor.h"
 #include "fbgemm_gpu/sparse_ops_utils.h"
+#include "fbgemm_gpu/utils/find_qparams.cuh"
 #include "fbgemm_gpu/utils/fixed_divisor.cuh"
 #include "fbgemm_gpu/utils/vec4.cuh"
 #include "fbgemm_gpu/utils/vec4acc.cuh"

fbgemm_gpu/include/fbgemm_gpu/fbgemm_cuda_utils.cuh

Lines changed: 2 additions & 329 deletions
@@ -9,336 +9,9 @@
 #pragma once
 
 #include <ATen/ATen.h>
-#include <ATen/AccumulateType.h>
-#include <cuda_runtime.h>
-#include <curand_kernel.h>
 #include <ATen/cuda/CUDAGraphsUtils.cuh>
 
 #include "fbgemm_gpu/utils/cuda_prelude.cuh"
-#include "fbgemm_gpu/utils/float.cuh"
-#include "fbgemm_gpu/utils/types.h"
+#include "fbgemm_gpu/utils/stochastic_rounding.cuh"
 #include "fbgemm_gpu/utils/vec4.cuh"
-#include "fbgemm_gpu/utils/vec4_rounding.cuh"
-
-namespace fbgemm_gpu {
-
-////////////////////////////////////////////////////////////////////////////////
-// Qparams
-////////////////////////////////////////////////////////////////////////////////
-
-template <typename dst_t, typename src_t>
-DEVICE_INLINE void quantize_store(
-    dst_t* output,
-    const Vec4T<src_t>& value,
-    StochasticRoundingRNGState* state,
-    const float2 qparams) {
-  if (!state) {
-    nearest_rounding_vector<dst_t, src_t>(output, value, qparams);
-  } else {
-    stochastic_rounding_vector<dst_t, src_t>(output, value, *state, qparams);
-  }
-}
-
-template <typename dst_t, typename src_t>
-DEVICE_INLINE Vec4T<dst_t> dequantize_load(
-    const src_t* value,
-    const float2 /* unused */) {
-  return Vec4T<dst_t>(value);
-}
-
-template <>
-DEVICE_INLINE Vec4T<float> dequantize_load(
-    const uint8_t* value,
-    const float2 qparams) {
-  Vec4T<float> out;
-  out.acc.x = value[0] * qparams.x + qparams.y;
-  out.acc.y = value[1] * qparams.x + qparams.y;
-  out.acc.z = value[2] * qparams.x + qparams.y;
-  out.acc.w = value[3] * qparams.x + qparams.y;
-  return out;
-}
-
-template <>
-DEVICE_INLINE Vec4T<at::Half> dequantize_load(
-    const uint8_t* value,
-    const float2 qparams) {
-  Vec4T<at::Half> out;
-  out.acc.x = value[0] * qparams.x + qparams.y;
-  out.acc.y = value[1] * qparams.x + qparams.y;
-  out.acc.z = value[2] * qparams.x + qparams.y;
-  out.acc.w = value[3] * qparams.x + qparams.y;
-  return out;
-}
-
-template <typename emb_t>
-DEVICE_INLINE float2 load_qparams_from_row(emb_t* qparam_ptr) {
-  float2 qparams;
-  float* qparams_fp_ptr = reinterpret_cast<float*>(qparam_ptr);
-  qparams.x = qparams_fp_ptr[0];
-  qparams.y = qparams_fp_ptr[1];
-  return qparams;
-}
-
-template <typename emb_t>
-DEVICE_INLINE void store_qparams_to_row(emb_t* ptr, float2 qparams) {
-  CUDA_KERNEL_ASSERT(false); // Only int8 embeddding should call this
-}
-
-template <>
-DEVICE_INLINE void store_qparams_to_row(uint8_t* ptr, float2 qparams) {
-  auto ptr_as_uint = reinterpret_cast<uintptr_t>(ptr);
-  if (ptr_as_uint % 8 == 0) {
-    *reinterpret_cast<float2*>(ptr) = qparams;
-  } else if (ptr_as_uint % 4 == 0) {
-    auto* ptr_float = reinterpret_cast<float*>(ptr);
-    auto* qparam_ptr = reinterpret_cast<const float*>(&qparams.x);
-#pragma unroll
-    for (int i = 0; i < 2; ++i) {
-      ptr_float[i] = qparam_ptr[i];
-    }
-  } else if (ptr_as_uint % 2 == 0) {
-    auto* ptr_16bit = reinterpret_cast<uint16_t*>(ptr);
-    auto* qparam_ptr = reinterpret_cast<const uint16_t*>(&qparams.x);
-#pragma unroll
-    for (int i = 0; i < 4; ++i) {
-      ptr_16bit[i] = qparam_ptr[i];
-    }
-  } else {
-    auto* qparam_ptr = reinterpret_cast<const uint8_t*>(&qparams.x);
-#pragma unroll
-    for (int i = 0; i < 8; ++i) {
-      ptr[i] = qparam_ptr[i];
-    }
-  }
-}
-
-// Min a register value across all warp threads
-template <typename T, int ReduceWidth = kWarpSize>
-DEVICE_INLINE T warp_reduce_min(T val) {
-#pragma unroll
-  for (int mask = ReduceWidth / 2; mask > 0; mask >>= 1) {
-    val = std::min(val, shfl_xor(val, mask));
-  }
-  return val;
-}
-
-// Max a register value across all warp threads
-template <typename T, int ReduceWidth = kWarpSize>
-DEVICE_INLINE T warp_reduce_max(T val) {
-#pragma unroll
-  for (int mask = ReduceWidth / 2; mask > 0; mask >>= 1) {
-    val = std::max(val, shfl_xor(val, mask));
-  }
-  return val;
-}
-
-template <typename scalar_t>
-DEVICE_INLINE float2 warp_find_qparams(scalar_t local_min, scalar_t local_max) {
-  float2 qparams;
-  local_min = warp_reduce_min<scalar_t>(local_min);
-  local_max = warp_reduce_max<scalar_t>(local_max);
-  if (threadIdx.x == 0) {
-    qparams.x = (local_max - local_min) / 255.0f;
-    qparams.y = local_min;
-  }
-  qparams.x = shfl_sync(qparams.x, 0);
-  qparams.y = shfl_sync(qparams.y, 0);
-  return qparams;
-}
-
-////////////////////////////////////////////////////////////////////////////////
-// Weight Row
-////////////////////////////////////////////////////////////////////////////////
-
-template <typename emb_t, typename cache_t, typename dst_t>
-// TODO: pass in dimension info and calculate qparams for rowwise integer
-// quantization
-struct WeightRow {
-  // Constructor for no stochastic rounding
-  DEVICE_INLINE WeightRow(emb_t* row, cache_t* cache_row, int dim)
-      : row_(row),
-        cache_row_(cache_row),
-        dim_(dim),
-        stoc_rounding_state_(nullptr) {}
-
-  // Constructor for stochastic rounding
-  DEVICE_INLINE WeightRow(
-      emb_t* row,
-      cache_t* cache_row,
-      int dim,
-      StochasticRoundingRNGState* stoc_rounding_state,
-      const at::PhiloxCudaState* stochastic_rounding_philox_args,
-      const uint64_t salt_value)
-      : row_(row), cache_row_(cache_row), dim_(dim) {
-    // Set the internal stoc_rounding_state_
-    stoc_rounding_state_ = stoc_rounding_state;
-
-    if constexpr (!std::is_same_v<emb_t, float>) {
-      if (stoc_rounding_state != nullptr) {
-        const auto stochastic_rounding_seeds =
-            at::cuda::philox::unpack(*stochastic_rounding_philox_args);
-
-        stochastic_rounding_init(
-            std::get<0>(stochastic_rounding_seeds) ^
-                std::get<1>(stochastic_rounding_seeds),
-            // The salt value should be different for every *run* and every
-            // *thread*.
-            salt_value,
-            stoc_rounding_state);
-      }
-    }
-  }
-
-  emb_t* row_;
-  cache_t* cache_row_;
-  int dim_;
-  StochasticRoundingRNGState* stoc_rounding_state_;
-
-  // Load from cache if resident; else load from embedding
-  DEVICE_INLINE Vec4T<dst_t> load(const int32_t d, const float2 qparams) const {
-    if (cache_row_) {
-      return dequantize_load<dst_t, cache_t>(cache_row_ + d, qparams);
-    } else {
-      return dequantize_load<dst_t, emb_t>(row_ + d, qparams);
-    }
-  }
-
-  // Write back weight (high precision) to cache if resident; else write to
-  // embedding assume dst_t is higher precision than cache_t and emb_t
-  DEVICE_INLINE void
-  store(const Vec4T<dst_t>& v, const int32_t d, const float2 qparams) {
-    if (cache_row_) {
-      quantize_store(cache_row_ + d, v, stoc_rounding_state_, qparams);
-    } else {
-      quantize_store(row_ + d, v, stoc_rounding_state_, qparams);
-    }
-  }
-
-  // Copy vector from src_vec to dst_vec (both are float)
-  DEVICE_INLINE void same_type_vector_copy(
-      float* dst_vec,
-      const float* src_vec) {
-    *reinterpret_cast<float4*>(dst_vec) =
-        *reinterpret_cast<const float4*>(src_vec);
-  }
-
-  // Copy vector from src_vec to dst_vec (both are at::Half)
-  DEVICE_INLINE void same_type_vector_copy(
-      at::Half* dst_vec,
-      const at::Half* src_vec) {
-    *reinterpret_cast<float2*>(dst_vec) =
-        *reinterpret_cast<const float2*>(src_vec);
-  }
-
-  // Evict cached row into embedding row (high prec -> low prec)
-  DEVICE_INLINE void evict_cache(const int32_t d, const float2 qparams) {
-    if constexpr (std::is_same_v<emb_t, cache_t>) {
-      // No conversion required when emb_t and cache_t are the same type
-      same_type_vector_copy(
-          reinterpret_cast<cache_t*>(row_ + d),
-          reinterpret_cast<const cache_t*>(cache_row_ + d));
-    } else {
-      // Does 2-step conversion: cache_t -> FP32 -> weight_t
-      const auto cache_slice = load(d, qparams);
-      quantize_store(row_ + d, cache_slice, stoc_rounding_state_, qparams);
-    }
-  }
-
-  DEVICE_INLINE void store_qparams(const float2 qparams) {
-    store_qparams_to_row(row_ + dim_, qparams);
-  }
-
-  DEVICE_INLINE float2 load_qparams() const {
-    if constexpr (std::is_same_v<emb_t, uint8_t>) {
-      return load_qparams_from_row<emb_t>(row_ + dim_);
-    } else {
-      return make_float2(0.0f, 0.0f);
-    }
-  }
-
-  DEVICE_INLINE void warp_copy_to_cache(
-      cache_t* dst_row,
-      const int32_t dim_length,
-      const int32_t num_lanes,
-      const int32_t lane_id) {
-    if constexpr (std::is_same_v<emb_t, cache_t>) {
-      // No conversion required when emb_t and cache_t are the same type
-      for (int32_t d = lane_id * 4; d < dim_length; d += num_lanes * 4) {
-        same_type_vector_copy(
-            dst_row + d, reinterpret_cast<const cache_t*>(row_ + d));
-      }
-    } else {
-      // Load quantization params from embedding row
-      const auto qparams = load_qparams();
-
-      // Copy over for each warp-sized slice of Vec4's
-      // Does 2-step conversion: weight_t -> FP32 -> cache_t
-      for (int32_t d = lane_id * 4; d < dim_length; d += num_lanes * 4) {
-        const auto slice = load(d, qparams);
-        quantize_store(dst_row + d, slice, stoc_rounding_state_, qparams);
-      }
-    }
-  }
-
-  DEVICE_INLINE void warp_evict_cache(
-      const int32_t dim_length,
-      const int32_t num_lanes,
-      const int32_t lane_id) {
-    float2 qparams;
-
-    if constexpr (std::is_same_v<emb_t, uint8_t>) {
-      auto local_min = std::numeric_limits<at::acc_type<cache_t, true>>::max();
-      auto local_max =
-          std::numeric_limits<at::acc_type<cache_t, true>>::lowest();
-
-      // Compute the qparams from the cache row (not embedding row) weights
-      for (int32_t d = lane_id; d * 4 < dim_length; d += num_lanes) {
-        const auto cache_slice = load(d * 4, qparams); // qparams not used
-        local_max = max(local_max, cache_slice.vmax());
-        local_min = min(local_min, cache_slice.vmin());
-      }
-
-      // Compute the max and min across the warps
-      qparams = warp_find_qparams(local_min, local_max);
-
-      if (lane_id == 0) {
-        // Store the qparams into the embedding row
-        store_qparams(qparams);
-      }
-    }
-
-    for (int32_t d = lane_id * 4; d < dim_length; d += num_lanes * 4) {
-      // Evict the slice into the embedding row
-      evict_cache(d, qparams);
-    }
-  }
-};
-
-template <typename emb_t, typename cache_t, typename dst_t, bool uses_cache>
-struct WeightRowAccessor {
-  const emb_t* row_;
-  const cache_t* cache_row_;
-  const int dim_;
-
-  DEVICE_INLINE
-  WeightRowAccessor(const emb_t* row, const cache_t* cache_row, const int dim)
-      : row_(row), cache_row_(cache_row), dim_(dim) {}
-
-  DEVICE_INLINE Vec4T<dst_t> load(const int32_t d, const float2 qparams) const {
-    if constexpr (uses_cache) {
-      return dequantize_load<dst_t, cache_t>(cache_row_ + d, qparams);
-    } else {
-      return dequantize_load<dst_t, emb_t>(row_ + d, qparams);
-    }
-  }
-
-  DEVICE_INLINE float2 load_qparams() const {
-    if constexpr (std::is_same_v<emb_t, uint8_t>) {
-      return load_qparams_from_row<emb_t>(row_ + dim_);
-    } else {
-      return make_float2(0.0f, 0.0f);
-    }
-  }
-};
-
-} // namespace fbgemm_gpu
+#include "fbgemm_gpu/utils/weight_row.cuh"

fbgemm_gpu/include/fbgemm_gpu/utils/cuda_prelude.cuh

Lines changed: 5 additions & 4 deletions
@@ -10,6 +10,7 @@
 
 #include <ATen/ATen.h>
 #include <cuda.h>
+#include <ATen/cuda/CUDAGraphsUtils.cuh>
 #if !( \
     defined(USE_ROCM) || \
     ((defined(CUDA_VERSION) && CUDA_VERSION < 11000) || \
@@ -20,10 +21,6 @@
 #endif
 #include <cuda_fp16.h>
 
-#if !defined(USE_ROCM) && defined(CUDA_VERSION) && CUDA_VERSION >= 9000
-#define FBGEMM_USE_SUBWARP_SHUFFLE
-#endif
-
 namespace {
 
 int get_device_sm_cnt_() {
@@ -36,6 +33,10 @@ int get_device_sm_cnt_() {
 
 namespace fbgemm_gpu {
 
+#if !defined(USE_ROCM) && defined(CUDA_VERSION) && CUDA_VERSION >= 9000
+#define FBGEMM_USE_SUBWARP_SHUFFLE
+#endif
+
 #define DEVICE_INLINE __device__ inline __attribute__((always_inline))
 
 #define CUDA_DEVICE_GUARD(TENSOR) \
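
`FBGEMM_USE_SUBWARP_SHUFFLE` is only relocated here, not redefined. As a hedged, compile-only illustration of what such a guard typically gates, and not FBGEMM's actual `shfl_xor` implementation: with the CUDA 9+ sync shuffle intrinsics, a reduction can confine the exchange to a sub-warp group by passing a width argument.

```cuda
// Illustration only (not FBGEMM's shuffle wrappers): consume the feature
// macro at compile time to choose between a sub-warp-width shuffle and a
// plain full-warp shuffle.
template <int ReduceWidth>
__device__ inline float xor_shuffle(float val, int lane_mask) {
#ifdef FBGEMM_USE_SUBWARP_SHUFFLE
  // The exchange is confined to each ReduceWidth-lane group of the warp.
  return __shfl_xor_sync(0xffffffffu, val, lane_mask, ReduceWidth);
#else
  // No sub-warp support assumed: exchange across the full warp.
  return __shfl_xor_sync(0xffffffffu, val, lane_mask);
#endif
}
```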
