Skip to content

Commit 8afe37a

Browse files
committed
fwd support split batch (#78)
* fwd support split batch * remove confusing assert * fix several >4G address type
1 parent 801261f commit 8afe37a

File tree

3 files changed

+104
-13
lines changed

3 files changed

+104
-13
lines changed

driver/conv_driver.cpp

Lines changed: 21 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -165,6 +165,8 @@ measured_fp32_conv_gflops(double time_ms, size_t n, size_t c, size_t hi,
165165
#define IGEMM_CONFIG_FILE "igemm_gtc.config"
166166
#endif
167167

168+
#define IGEMM_RUN_ONLY_KERNEL_DEFAULT "off"
169+
168170
#define WARMUP 3
169171
#define REPEAT 8
170172
#define SCLK_MHZ 1283
@@ -214,14 +216,14 @@ struct distribution_t<float>{
214216
};
215217

216218
template <typename Dst_T, typename Src_T>
217-
void block_wise_rand_generator(Dst_T *p, int tid, int block_size, int total_size, Src_T min, Src_T max, Src_T scale)
219+
void block_wise_rand_generator(Dst_T *p, int tid, int block_size, size_t total_size, Src_T min, Src_T max, Src_T scale)
218220
{
219221
std::mt19937 rng(std::chrono::system_clock::now()
220222
.time_since_epoch()
221223
.count() +
222224
std::hash<std::thread::id>()(std::this_thread::get_id()));
223225
distribution_t<Src_T> distribution(min,max);
224-
for (int i = tid; i < total_size; i += block_size) {
226+
for (size_t i = tid; i < total_size; i += block_size) {
225227
p[i] = static_cast<Dst_T>(scale * distribution(rng));
226228
}
227229
}
@@ -342,6 +344,7 @@ void dump_arg(const args_t *arg) {
342344
int main(int argc, char **argv) {
343345
char *hsaco = env_get_str("IGEMM_HSACO", IGEMM_HSACO);
344346
char *config_file = env_get_str("IGEMM_CONFIG_FILE", IGEMM_CONFIG_FILE);
347+
std::string run_only_kernel = env_get_str("IGEMM_RUN_ONLY_KERNEL", IGEMM_RUN_ONLY_KERNEL_DEFAULT);
345348
int warmup = env_get_int("IGEMM_WARMUP", WARMUP);
346349
int repeat = env_get_int("IGEMM_REPEAT", REPEAT);
347350
int sclk_mhz = env_get_int("IGEMM_SCLK_MHZ", SCLK_MHZ);
@@ -457,8 +460,8 @@ int main(int argc, char **argv) {
457460
gen_rand_vector<float, float>(host_input, static_cast<size_t>(n) * c * hi * wi, 0.0, 1.0);
458461
gen_rand_vector<float, float>(host_weight, static_cast<size_t>(k) * c * y * x, -0.5, 0.5);
459462

460-
//gen_rand_vector<float, int>(host_input, n * c * hi * wi, 1, 1);
461-
//gen_rand_vector<float, int>(host_weight, k * c * y * x, 1, 1);
463+
//gen_rand_vector<float, int>(host_input, static_cast<size_t>(n) * c * hi * wi, 1, 1);
464+
//gen_rand_vector<float, int>(host_weight, static_cast<size_t>(k) * c * y * x, 1, 1);
462465

463466
#ifdef USE_GPU_NAIVE_CONV
464467
HIP_CALL(hipMemcpy(device_input, host_input,
@@ -506,6 +509,9 @@ int main(int argc, char **argv) {
506509
double nrms = get_fwd_nrms();
507510
for (int i = 0; i < tunables.size(); i++) {
508511
igemm_gtc_tunable_t *tunable = &tunables[i];
512+
if(run_only_kernel != IGEMM_RUN_ONLY_KERNEL_DEFAULT)
513+
if(run_only_kernel != conv_fwd_driver.get_kernel_name(tunable))
514+
continue;
509515

510516
printf("[fwd:%2d] %s, ", i, conv_fwd_driver.get_kernel_name(tunable).c_str());
511517
fflush(stdout);
@@ -569,8 +575,8 @@ int main(int argc, char **argv) {
569575
gen_rand_vector<float, float>(host_output, static_cast<size_t>(n) * k * ho * wo, 0.0, 1.0);
570576
gen_rand_vector<float, float>(host_weight, static_cast<size_t>(k) * c * y * x, -0.5, 0.5);
571577
gen_rand_vector<float, float>(host_input, static_cast<size_t>(n) * c * hi * wi, 999999., 9999999.); // manually input value to a very large number
572-
// gen_rand_vector<float, int>(host_output, n * k * ho * wo,1, 1);
573-
// gen_rand_vector<float, int>(host_weight, k * c * y * x, 1, 1);
578+
// gen_rand_vector<float, int>(host_output, static_cast<size_t>(n) * k * ho * wo,1, 1);
579+
// gen_rand_vector<float, int>(host_weight, static_cast<size_t>(k) * c * y * x, 1, 1);
574580
#ifdef USE_GPU_NAIVE_CONV
575581
HIP_CALL(hipMemcpy(device_output, host_output,
576582
static_cast<size_t>(n) * k * ho * wo * sizeof(float), hipMemcpyHostToDevice));
@@ -618,6 +624,9 @@ int main(int argc, char **argv) {
618624
double nrms = get_bwd_nrms();
619625
for (int i = 0; i < tunables.size(); i++) {
620626
igemm_gtc_tunable_t *tunable = &tunables[i];
627+
if(run_only_kernel != IGEMM_RUN_ONLY_KERNEL_DEFAULT)
628+
if(run_only_kernel != conv_bwd_driver.get_kernel_name(tunable))
629+
continue;
621630

622631
printf("[bwd:%2d] %s, ", i, conv_bwd_driver.get_kernel_name(tunable).c_str());
623632
fflush(stdout);
@@ -680,8 +689,8 @@ int main(int argc, char **argv) {
680689
// gen rand
681690
gen_rand_vector<float, float>(host_input, static_cast<size_t>(n) * c * hi * wi, 0.0, 1.0);
682691
gen_rand_vector<float, float>(host_output, static_cast<size_t>(n) * k * ho * wo, -0.5, 0.5);
683-
//gen_rand_vector<float, int>(host_input, n * k * hi * wi, -5, 5);
684-
//gen_rand_vector<float, int>(host_output, n * k * ho * wo, 1, 1);
692+
//gen_rand_vector<float, int>(host_input, static_cast<size_t>(n) * k * hi * wi, -5, 5);
693+
//gen_rand_vector<float, int>(host_output, static_cast<size_t>(n) * k * ho * wo, 1, 1);
685694
#ifdef USE_GPU_NAIVE_CONV
686695
HIP_CALL(hipMemcpy(device_input, host_input,
687696
static_cast<size_t>(n) * c * hi * wi * sizeof(float), hipMemcpyHostToDevice));
@@ -763,13 +772,16 @@ int main(int argc, char **argv) {
763772

764773
for (int i = 0; i < tunables.size(); i++) {
765774
igemm_gtc_tunable_t *tunable = &tunables[i];
775+
if(run_only_kernel != IGEMM_RUN_ONLY_KERNEL_DEFAULT)
776+
if(run_only_kernel != conv_wrw_driver.get_kernel_name(tunable))
777+
continue;
766778

767779
printf("[wrw:%2d] %s, ", i, conv_wrw_driver.get_kernel_name(tunable).c_str());
768780
fflush(stdout);
769781

770782
if (need_verify)
771783
HIP_CALL(hipMemset(device_weight, 0,
772-
k * c * y * x * sizeof(float)));
784+
static_cast<size_t>(k) * c * y * x * sizeof(float)));
773785
result_t result =
774786
conv_wrw_driver.run(&conv_args, tunable, module, device_input,
775787
device_weight, device_output, warmup, repeat);

driver/igemm_fwd_gtc_driver.h

Lines changed: 64 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,9 @@ class igemm_fwd_gtc_t {
174174
int wo = conv_out_size(wi, pad_w, dilation_w, x, stride_w);
175175
int group = arg->get_int("group_count");
176176

177+
int splits = split_batch_size(arg, tunable);
178+
n = n/splits; // split batch size here
179+
177180
int gemm_m_per_block = tunable->gemm_m_per_block;
178181
int gemm_n_per_block = tunable->gemm_n_per_block;
179182
int nxe = tunable->nxe;
@@ -201,6 +204,54 @@ class igemm_fwd_gtc_t {
201204
return grid_size;
202205
}
203206

207+
// this is to support big tensor > 4G. need to decide how many splits needed
208+
// return the number of splits
209+
int split_batch_size(const args_t *arg, const igemm_gtc_tunable_t *tunable)
210+
{
211+
int hi = arg->get_int("in_h");
212+
int wi = arg->get_int("in_w");
213+
int n = arg->get_int("batchsize");
214+
int k = arg->get_int("out_channels");
215+
int c = arg->get_int("in_channels");
216+
217+
int stride_h = arg->get_int("conv_stride_h");
218+
int stride_w = arg->get_int("conv_stride_w");
219+
int dilation_h = arg->get_int("dilation_h");
220+
int dilation_w = arg->get_int("dilation_w");
221+
int pad_h = arg->get_int("pad_h");
222+
int pad_w = arg->get_int("pad_w");
223+
int y = arg->get_int("fil_h");
224+
int x = arg->get_int("fil_w");
225+
int ho = conv_out_size(hi, pad_h, dilation_h, y, stride_h);
226+
int wo = conv_out_size(wi, pad_w, dilation_w, x, stride_w);
227+
228+
int data_byte = utility_string_to_data_byte(tunable->precision);
229+
size_t image_size_input = static_cast<size_t>(c) * hi * wi * data_byte;
230+
size_t image_size_output = static_cast<size_t>(k) * ho * wo * data_byte;
231+
size_t size_4g = 0xffffffffUL;
232+
if(image_size_input >= size_4g || image_size_output >= size_4g)
233+
return 0;
234+
235+
size_t image_size = image_size_input >= image_size_output ? image_size_input : image_size_output;
236+
size_t splited_n = size_4g / image_size;
237+
238+
// round up splits, we must match
239+
// 1. splited_n * image_size < size_4g
240+
// 2. n % splited_n == 0
241+
// if(splited_n >= n)
242+
// return 1;
243+
assert(splited_n != 0);
244+
while(splited_n >= 1){
245+
// printf("n:%d, splited_n:%d\n", n, splited_n);
246+
if(n % splited_n == 0)
247+
break;
248+
splited_n--;
249+
}
250+
251+
assert(splited_n * image_size < size_4g && n % splited_n == 0);
252+
return n / splited_n;
253+
}
254+
204255
bool tunable_is_valid(const args_t *arg,
205256
const igemm_gtc_tunable_t *tunable)
206257
{
@@ -224,6 +275,13 @@ class igemm_fwd_gtc_t {
224275

225276
assert(c % group == 0 && k % group == 0);
226277

278+
int splits = split_batch_size(arg, tunable);
279+
if(splits == 0){
280+
printf("image size (c*h*w) is bigger than 4g, which is not supported now\n");
281+
return false;
282+
}
283+
n = n/splits; // split batch size here
284+
227285
int gemm_m_per_block = tunable->gemm_m_per_block;
228286
int gemm_n_per_block = tunable->gemm_n_per_block;
229287
int gemm_k_per_block = tunable->gemm_k_per_block;
@@ -375,6 +433,9 @@ class igemm_fwd_gtc_t {
375433

376434
assert(c % group == 0 && k % group == 0);
377435

436+
int splits = split_batch_size(arg, tunable);
437+
n = n/splits; // split batch size here
438+
378439
int gemm_m_per_block = tunable->gemm_m_per_block;
379440
int gemm_n_per_block = tunable->gemm_n_per_block;
380441
int gemm_k_per_block = tunable->gemm_k_per_block;
@@ -494,7 +555,7 @@ class igemm_fwd_gtc_t {
494555
hipModuleGetFunction(&kernel_func, module, kernel_name.c_str()));
495556

496557
auto launch_fwd = [&]() -> float {
497-
// printf("launch fwd block:%d, grid:%d\n", block_size, grid_size);
558+
// printf("launch fwd block:%d, grid:%dx%d\n", block_size, grid_size, splits);
498559
// dump_fwd_karg(&karg);
499560
void *config[] = {HIP_LAUNCH_PARAM_BUFFER_POINTER, static_cast<void*>(&karg_buffer[0]),
500561
HIP_LAUNCH_PARAM_BUFFER_SIZE, &karg_size,
@@ -508,7 +569,7 @@ class igemm_fwd_gtc_t {
508569
hipEventCreate(&stop);
509570

510571
// for hipHccModuleLaunchKernel/hipExtModuleLaunchKernel, the grid_size is in unit of workitem
511-
HIP_CALL(hipHccModuleLaunchKernel(kernel_func, grid_size * block_size, 1, 1,
572+
HIP_CALL(hipHccModuleLaunchKernel(kernel_func, grid_size * block_size, splits, 1,
512573
block_size, 1, 1, 0, 0, NULL,
513574
(void **)&config, start, stop));
514575

@@ -520,7 +581,7 @@ class igemm_fwd_gtc_t {
520581
gpu_timer_t timer(NULL);
521582
timer.start();
522583

523-
HIP_CALL(hipModuleLaunchKernel(kernel_func, grid_size, 1, 1,
584+
HIP_CALL(hipModuleLaunchKernel(kernel_func, grid_size, splits, 1,
524585
block_size, 1, 1, 0, 0, NULL,
525586
(void **)&config));
526587

igemm/algo/igemm_fwd_gtc.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -674,7 +674,8 @@ def __init__(self, mc, outer):
674674
sseq = gpr_sequencer_t()
675675
self.outer = outer
676676
self.s_ka = sym_t('s_ka' , sseq(2))
677-
self.s_bx = sym_t('s_bx' , sseq(2))
677+
self.s_bx = sym_t('s_bx' , sseq(1))
678+
self.s_by = sym_t('s_by' , sseq(1))
678679
self.s_p_in = sym_t('s_p_in' , sseq(4))
679680
self.s_p_wei = sym_t('s_p_wei' , sseq(4))
680681
self.s_p_out = sym_t('s_p_out' , sseq(4))
@@ -1230,6 +1231,7 @@ def get_kernel_code(self):
12301231
kernel_code = amdgpu_kernel_code_t({
12311232
'enable_sgpr_kernarg_segment_ptr' : 1,
12321233
'enable_sgpr_workgroup_id_x' : 1,
1234+
'enable_sgpr_workgroup_id_y' : 1,
12331235
'enable_vgpr_workitem_id' : 0,
12341236
'workgroup_group_segment_byte_size' : self.tunable.lds_total,
12351237
'kernarg_segment_byte_size' : self.karg.get_count(),
@@ -1521,6 +1523,22 @@ def emit_kernel_prologue(self):
15211523
self._emit(f"s_lshr_b32 s[{s.s_tmp()}], s[{s.s_n()}], {igemm_log2(nb_n0)}")
15221524
self._emit(f"s_mul_i32 s[{s.s_out_stride_n0()}], s[{s.s_out_stride_n()}], s[{s.s_tmp()}]")
15231525

1526+
# calculate batch split and accumulate the base pointer for input/output
1527+
self._emit(f"s_mul_i32 s[{s.s_tmp(0)}], s[{s.s_n()}], s[{s.s_in_stride_n()}]")
1528+
self._emit(f"s_mul_i32 s[{s.s_tmp(1)}], s[{s.s_n()}], s[{s.s_out_stride_n()}]")
1529+
self._emit(f"s_lshl_b32 s[{s.s_tmp(4)}], s[{s.s_tmp(0)}], {igemm_log2(data_byte)}")
1530+
self._emit(f"s_lshl_b32 s[{s.s_tmp(5)}], s[{s.s_tmp(1)}], {igemm_log2(data_byte)}")
1531+
1532+
self._emit(f"s_mul_i32 s[{s.s_tmp()}], s[{s.s_by()}], s[{s.s_tmp(4)}]")
1533+
self._emit(f"s_mul_hi_u32 s[{s.s_tmp(1)}], s[{s.s_by()}], s[{s.s_tmp(4)}]")
1534+
self._emit(f"s_add_u32 s[{s.s_p_in()}], s[{s.s_p_in()}], s[{s.s_tmp()}]")
1535+
self._emit(f"s_addc_u32 s[{s.s_p_in(1)}], s[{s.s_p_in(1)}], s[{s.s_tmp(1)}]")
1536+
1537+
self._emit(f"s_mul_i32 s[{s.s_tmp()}], s[{s.s_by()}], s[{s.s_tmp(5)}]")
1538+
self._emit(f"s_mul_hi_u32 s[{s.s_tmp(1)}], s[{s.s_by()}], s[{s.s_tmp(5)}]")
1539+
self._emit(f"s_add_u32 s[{s.s_p_out()}], s[{s.s_p_out()}], s[{s.s_tmp()}]")
1540+
self._emit(f"s_addc_u32 s[{s.s_p_out(1)}], s[{s.s_p_out(1)}], s[{s.s_tmp(1)}]")
1541+
15241542
# early init s_knum in case shifted
15251543
if self.tunable.nxe != 0:
15261544
self._emit(f"s_mul_i32 s[{s.s_knum()}], s[{s.s_wei_stride_c()}], s[{s.s_c()}]")

0 commit comments

Comments
 (0)