|
9 | 9 | #pragma once |
10 | 10 |
|
11 | 11 | #include <ATen/ATen.h> |
12 | | -#include <ATen/AccumulateType.h> |
13 | | -#include <cuda_runtime.h> |
14 | | -#include <curand_kernel.h> |
15 | 12 | #include <ATen/cuda/CUDAGraphsUtils.cuh> |
16 | 13 |
|
17 | 14 | #include "fbgemm_gpu/utils/cuda_prelude.cuh" |
18 | | -#include "fbgemm_gpu/utils/float.cuh" |
19 | | -#include "fbgemm_gpu/utils/types.h" |
| 15 | +#include "fbgemm_gpu/utils/stochastic_rounding.cuh" |
20 | 16 | #include "fbgemm_gpu/utils/vec4.cuh" |
21 | | -#include "fbgemm_gpu/utils/vec4_rounding.cuh" |
22 | | - |
23 | | -namespace fbgemm_gpu { |
24 | | - |
25 | | -//////////////////////////////////////////////////////////////////////////////// |
26 | | -// Qparams |
27 | | -//////////////////////////////////////////////////////////////////////////////// |
28 | | - |
29 | | -template <typename dst_t, typename src_t> |
30 | | -DEVICE_INLINE void quantize_store( |
31 | | - dst_t* output, |
32 | | - const Vec4T<src_t>& value, |
33 | | - StochasticRoundingRNGState* state, |
34 | | - const float2 qparams) { |
35 | | - if (!state) { |
36 | | - nearest_rounding_vector<dst_t, src_t>(output, value, qparams); |
37 | | - } else { |
38 | | - stochastic_rounding_vector<dst_t, src_t>(output, value, *state, qparams); |
39 | | - } |
40 | | -} |
41 | | - |
42 | | -template <typename dst_t, typename src_t> |
43 | | -DEVICE_INLINE Vec4T<dst_t> dequantize_load( |
44 | | - const src_t* value, |
45 | | - const float2 /* unused */) { |
46 | | - return Vec4T<dst_t>(value); |
47 | | -} |
48 | | - |
49 | | -template <> |
50 | | -DEVICE_INLINE Vec4T<float> dequantize_load( |
51 | | - const uint8_t* value, |
52 | | - const float2 qparams) { |
53 | | - Vec4T<float> out; |
54 | | - out.acc.x = value[0] * qparams.x + qparams.y; |
55 | | - out.acc.y = value[1] * qparams.x + qparams.y; |
56 | | - out.acc.z = value[2] * qparams.x + qparams.y; |
57 | | - out.acc.w = value[3] * qparams.x + qparams.y; |
58 | | - return out; |
59 | | -} |
60 | | - |
61 | | -template <> |
62 | | -DEVICE_INLINE Vec4T<at::Half> dequantize_load( |
63 | | - const uint8_t* value, |
64 | | - const float2 qparams) { |
65 | | - Vec4T<at::Half> out; |
66 | | - out.acc.x = value[0] * qparams.x + qparams.y; |
67 | | - out.acc.y = value[1] * qparams.x + qparams.y; |
68 | | - out.acc.z = value[2] * qparams.x + qparams.y; |
69 | | - out.acc.w = value[3] * qparams.x + qparams.y; |
70 | | - return out; |
71 | | -} |
72 | | - |
73 | | -template <typename emb_t> |
74 | | -DEVICE_INLINE float2 load_qparams_from_row(emb_t* qparam_ptr) { |
75 | | - float2 qparams; |
76 | | - float* qparams_fp_ptr = reinterpret_cast<float*>(qparam_ptr); |
77 | | - qparams.x = qparams_fp_ptr[0]; |
78 | | - qparams.y = qparams_fp_ptr[1]; |
79 | | - return qparams; |
80 | | -} |
81 | | - |
82 | | -template <typename emb_t> |
83 | | -DEVICE_INLINE void store_qparams_to_row(emb_t* ptr, float2 qparams) { |
84 | | - CUDA_KERNEL_ASSERT(false); // Only int8 embeddding should call this |
85 | | -} |
86 | | - |
87 | | -template <> |
88 | | -DEVICE_INLINE void store_qparams_to_row(uint8_t* ptr, float2 qparams) { |
89 | | - auto ptr_as_uint = reinterpret_cast<uintptr_t>(ptr); |
90 | | - if (ptr_as_uint % 8 == 0) { |
91 | | - *reinterpret_cast<float2*>(ptr) = qparams; |
92 | | - } else if (ptr_as_uint % 4 == 0) { |
93 | | - auto* ptr_float = reinterpret_cast<float*>(ptr); |
94 | | - auto* qparam_ptr = reinterpret_cast<const float*>(&qparams.x); |
95 | | -#pragma unroll |
96 | | - for (int i = 0; i < 2; ++i) { |
97 | | - ptr_float[i] = qparam_ptr[i]; |
98 | | - } |
99 | | - } else if (ptr_as_uint % 2 == 0) { |
100 | | - auto* ptr_16bit = reinterpret_cast<uint16_t*>(ptr); |
101 | | - auto* qparam_ptr = reinterpret_cast<const uint16_t*>(&qparams.x); |
102 | | -#pragma unroll |
103 | | - for (int i = 0; i < 4; ++i) { |
104 | | - ptr_16bit[i] = qparam_ptr[i]; |
105 | | - } |
106 | | - } else { |
107 | | - auto* qparam_ptr = reinterpret_cast<const uint8_t*>(&qparams.x); |
108 | | -#pragma unroll |
109 | | - for (int i = 0; i < 8; ++i) { |
110 | | - ptr[i] = qparam_ptr[i]; |
111 | | - } |
112 | | - } |
113 | | -} |
114 | | - |
115 | | -// Min a register value across all warp threads |
116 | | -template <typename T, int ReduceWidth = kWarpSize> |
117 | | -DEVICE_INLINE T warp_reduce_min(T val) { |
118 | | -#pragma unroll |
119 | | - for (int mask = ReduceWidth / 2; mask > 0; mask >>= 1) { |
120 | | - val = std::min(val, shfl_xor(val, mask)); |
121 | | - } |
122 | | - return val; |
123 | | -} |
124 | | - |
125 | | -// Max a register value across all warp threads |
126 | | -template <typename T, int ReduceWidth = kWarpSize> |
127 | | -DEVICE_INLINE T warp_reduce_max(T val) { |
128 | | -#pragma unroll |
129 | | - for (int mask = ReduceWidth / 2; mask > 0; mask >>= 1) { |
130 | | - val = std::max(val, shfl_xor(val, mask)); |
131 | | - } |
132 | | - return val; |
133 | | -} |
134 | | - |
135 | | -template <typename scalar_t> |
136 | | -DEVICE_INLINE float2 warp_find_qparams(scalar_t local_min, scalar_t local_max) { |
137 | | - float2 qparams; |
138 | | - local_min = warp_reduce_min<scalar_t>(local_min); |
139 | | - local_max = warp_reduce_max<scalar_t>(local_max); |
140 | | - if (threadIdx.x == 0) { |
141 | | - qparams.x = (local_max - local_min) / 255.0f; |
142 | | - qparams.y = local_min; |
143 | | - } |
144 | | - qparams.x = shfl_sync(qparams.x, 0); |
145 | | - qparams.y = shfl_sync(qparams.y, 0); |
146 | | - return qparams; |
147 | | -} |
148 | | - |
149 | | -//////////////////////////////////////////////////////////////////////////////// |
150 | | -// Weight Row |
151 | | -//////////////////////////////////////////////////////////////////////////////// |
152 | | - |
153 | | -template <typename emb_t, typename cache_t, typename dst_t> |
154 | | -// TODO: pass in dimension info and calculate qparams for rowwise integer |
155 | | -// quantization |
156 | | -struct WeightRow { |
157 | | - // Constructor for no stochastic rounding |
158 | | - DEVICE_INLINE WeightRow(emb_t* row, cache_t* cache_row, int dim) |
159 | | - : row_(row), |
160 | | - cache_row_(cache_row), |
161 | | - dim_(dim), |
162 | | - stoc_rounding_state_(nullptr) {} |
163 | | - |
164 | | - // Constructor for stochastic rounding |
165 | | - DEVICE_INLINE WeightRow( |
166 | | - emb_t* row, |
167 | | - cache_t* cache_row, |
168 | | - int dim, |
169 | | - StochasticRoundingRNGState* stoc_rounding_state, |
170 | | - const at::PhiloxCudaState* stochastic_rounding_philox_args, |
171 | | - const uint64_t salt_value) |
172 | | - : row_(row), cache_row_(cache_row), dim_(dim) { |
173 | | - // Set the internal stoc_rounding_state_ |
174 | | - stoc_rounding_state_ = stoc_rounding_state; |
175 | | - |
176 | | - if constexpr (!std::is_same_v<emb_t, float>) { |
177 | | - if (stoc_rounding_state != nullptr) { |
178 | | - const auto stochastic_rounding_seeds = |
179 | | - at::cuda::philox::unpack(*stochastic_rounding_philox_args); |
180 | | - |
181 | | - stochastic_rounding_init( |
182 | | - std::get<0>(stochastic_rounding_seeds) ^ |
183 | | - std::get<1>(stochastic_rounding_seeds), |
184 | | - // The salt value should be different for every *run* and every |
185 | | - // *thread*. |
186 | | - salt_value, |
187 | | - stoc_rounding_state); |
188 | | - } |
189 | | - } |
190 | | - } |
191 | | - |
192 | | - emb_t* row_; |
193 | | - cache_t* cache_row_; |
194 | | - int dim_; |
195 | | - StochasticRoundingRNGState* stoc_rounding_state_; |
196 | | - |
197 | | - // Load from cache if resident; else load from embedding |
198 | | - DEVICE_INLINE Vec4T<dst_t> load(const int32_t d, const float2 qparams) const { |
199 | | - if (cache_row_) { |
200 | | - return dequantize_load<dst_t, cache_t>(cache_row_ + d, qparams); |
201 | | - } else { |
202 | | - return dequantize_load<dst_t, emb_t>(row_ + d, qparams); |
203 | | - } |
204 | | - } |
205 | | - |
206 | | - // Write back weight (high precision) to cache if resident; else write to |
207 | | - // embedding assume dst_t is higher precision than cache_t and emb_t |
208 | | - DEVICE_INLINE void |
209 | | - store(const Vec4T<dst_t>& v, const int32_t d, const float2 qparams) { |
210 | | - if (cache_row_) { |
211 | | - quantize_store(cache_row_ + d, v, stoc_rounding_state_, qparams); |
212 | | - } else { |
213 | | - quantize_store(row_ + d, v, stoc_rounding_state_, qparams); |
214 | | - } |
215 | | - } |
216 | | - |
217 | | - // Copy vector from src_vec to dst_vec (both are float) |
218 | | - DEVICE_INLINE void same_type_vector_copy( |
219 | | - float* dst_vec, |
220 | | - const float* src_vec) { |
221 | | - *reinterpret_cast<float4*>(dst_vec) = |
222 | | - *reinterpret_cast<const float4*>(src_vec); |
223 | | - } |
224 | | - |
225 | | - // Copy vector from src_vec to dst_vec (both are at::Half) |
226 | | - DEVICE_INLINE void same_type_vector_copy( |
227 | | - at::Half* dst_vec, |
228 | | - const at::Half* src_vec) { |
229 | | - *reinterpret_cast<float2*>(dst_vec) = |
230 | | - *reinterpret_cast<const float2*>(src_vec); |
231 | | - } |
232 | | - |
233 | | - // Evict cached row into embedding row (high prec -> low prec) |
234 | | - DEVICE_INLINE void evict_cache(const int32_t d, const float2 qparams) { |
235 | | - if constexpr (std::is_same_v<emb_t, cache_t>) { |
236 | | - // No conversion required when emb_t and cache_t are the same type |
237 | | - same_type_vector_copy( |
238 | | - reinterpret_cast<cache_t*>(row_ + d), |
239 | | - reinterpret_cast<const cache_t*>(cache_row_ + d)); |
240 | | - } else { |
241 | | - // Does 2-step conversion: cache_t -> FP32 -> weight_t |
242 | | - const auto cache_slice = load(d, qparams); |
243 | | - quantize_store(row_ + d, cache_slice, stoc_rounding_state_, qparams); |
244 | | - } |
245 | | - } |
246 | | - |
247 | | - DEVICE_INLINE void store_qparams(const float2 qparams) { |
248 | | - store_qparams_to_row(row_ + dim_, qparams); |
249 | | - } |
250 | | - |
251 | | - DEVICE_INLINE float2 load_qparams() const { |
252 | | - if constexpr (std::is_same_v<emb_t, uint8_t>) { |
253 | | - return load_qparams_from_row<emb_t>(row_ + dim_); |
254 | | - } else { |
255 | | - return make_float2(0.0f, 0.0f); |
256 | | - } |
257 | | - } |
258 | | - |
259 | | - DEVICE_INLINE void warp_copy_to_cache( |
260 | | - cache_t* dst_row, |
261 | | - const int32_t dim_length, |
262 | | - const int32_t num_lanes, |
263 | | - const int32_t lane_id) { |
264 | | - if constexpr (std::is_same_v<emb_t, cache_t>) { |
265 | | - // No conversion required when emb_t and cache_t are the same type |
266 | | - for (int32_t d = lane_id * 4; d < dim_length; d += num_lanes * 4) { |
267 | | - same_type_vector_copy( |
268 | | - dst_row + d, reinterpret_cast<const cache_t*>(row_ + d)); |
269 | | - } |
270 | | - } else { |
271 | | - // Load quantization params from embedding row |
272 | | - const auto qparams = load_qparams(); |
273 | | - |
274 | | - // Copy over for each warp-sized slice of Vec4's |
275 | | - // Does 2-step conversion: weight_t -> FP32 -> cache_t |
276 | | - for (int32_t d = lane_id * 4; d < dim_length; d += num_lanes * 4) { |
277 | | - const auto slice = load(d, qparams); |
278 | | - quantize_store(dst_row + d, slice, stoc_rounding_state_, qparams); |
279 | | - } |
280 | | - } |
281 | | - } |
282 | | - |
283 | | - DEVICE_INLINE void warp_evict_cache( |
284 | | - const int32_t dim_length, |
285 | | - const int32_t num_lanes, |
286 | | - const int32_t lane_id) { |
287 | | - float2 qparams; |
288 | | - |
289 | | - if constexpr (std::is_same_v<emb_t, uint8_t>) { |
290 | | - auto local_min = std::numeric_limits<at::acc_type<cache_t, true>>::max(); |
291 | | - auto local_max = |
292 | | - std::numeric_limits<at::acc_type<cache_t, true>>::lowest(); |
293 | | - |
294 | | - // Compute the qparams from the cache row (not embedding row) weights |
295 | | - for (int32_t d = lane_id; d * 4 < dim_length; d += num_lanes) { |
296 | | - const auto cache_slice = load(d * 4, qparams); // qparams not used |
297 | | - local_max = max(local_max, cache_slice.vmax()); |
298 | | - local_min = min(local_min, cache_slice.vmin()); |
299 | | - } |
300 | | - |
301 | | - // Compute the max and min across the warps |
302 | | - qparams = warp_find_qparams(local_min, local_max); |
303 | | - |
304 | | - if (lane_id == 0) { |
305 | | - // Store the qparams into the embedding row |
306 | | - store_qparams(qparams); |
307 | | - } |
308 | | - } |
309 | | - |
310 | | - for (int32_t d = lane_id * 4; d < dim_length; d += num_lanes * 4) { |
311 | | - // Evict the slice into the embedding row |
312 | | - evict_cache(d, qparams); |
313 | | - } |
314 | | - } |
315 | | -}; |
316 | | - |
317 | | -template <typename emb_t, typename cache_t, typename dst_t, bool uses_cache> |
318 | | -struct WeightRowAccessor { |
319 | | - const emb_t* row_; |
320 | | - const cache_t* cache_row_; |
321 | | - const int dim_; |
322 | | - |
323 | | - DEVICE_INLINE |
324 | | - WeightRowAccessor(const emb_t* row, const cache_t* cache_row, const int dim) |
325 | | - : row_(row), cache_row_(cache_row), dim_(dim) {} |
326 | | - |
327 | | - DEVICE_INLINE Vec4T<dst_t> load(const int32_t d, const float2 qparams) const { |
328 | | - if constexpr (uses_cache) { |
329 | | - return dequantize_load<dst_t, cache_t>(cache_row_ + d, qparams); |
330 | | - } else { |
331 | | - return dequantize_load<dst_t, emb_t>(row_ + d, qparams); |
332 | | - } |
333 | | - } |
334 | | - |
335 | | - DEVICE_INLINE float2 load_qparams() const { |
336 | | - if constexpr (std::is_same_v<emb_t, uint8_t>) { |
337 | | - return load_qparams_from_row<emb_t>(row_ + dim_); |
338 | | - } else { |
339 | | - return make_float2(0.0f, 0.0f); |
340 | | - } |
341 | | - } |
342 | | -}; |
343 | | - |
344 | | -} // namespace fbgemm_gpu |
| 17 | +#include "fbgemm_gpu/utils/weight_row.cuh" |
0 commit comments