30 changes: 21 additions & 9 deletions driver/conv_driver.cpp
@@ -175,6 +175,8 @@ measured_fp32_conv_gflops(double time_ms, size_t n, size_t c, size_t hi,
#define IGEMM_CONFIG_FILE "igemm_gtc.config"
#endif

#define IGEMM_RUN_ONLY_KERNEL_DEFAULT "off"

#define WARMUP 3
#define REPEAT 8
#define SCLK_MHZ 1283
@@ -224,14 +226,14 @@ struct distribution_t<float>{
};

template <typename Dst_T, typename Src_T>
void block_wise_rand_generator(Dst_T *p, int tid, int block_size, int total_size, Src_T min, Src_T max, Src_T scale)
void block_wise_rand_generator(Dst_T *p, int tid, int block_size, size_t total_size, Src_T min, Src_T max, Src_T scale)
{
std::mt19937 rng(std::chrono::system_clock::now()
.time_since_epoch()
.count() +
std::hash<std::thread::id>()(std::this_thread::get_id()));
distribution_t<Src_T> distribution(min,max);
for (int i = tid; i < total_size; i += block_size) {
for (size_t i = tid; i < total_size; i += block_size) {
p[i] = static_cast<Dst_T>(scale * distribution(rng));
}
}
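Note on the signature change above: widening total_size (and the loop index) from int to size_t matters once a tensor holds 2^31 or more elements, where the old int count would overflow. A small, hypothetical illustration of the truncation:

#include <cstddef>
#include <cstdio>

// Illustrative only: an element count that fits in size_t but overflows int.
int main() {
    size_t total_size = static_cast<size_t>(4) * 1024 * 1024 * 1024; // 4294967296 elements
    int truncated = static_cast<int>(total_size);                    // does not fit; wraps (typically to 0)
    printf("size_t: %zu, int: %d\n", total_size, truncated);
    return 0;
}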
@@ -350,6 +352,7 @@ void dump_arg(const args_t *arg) {
int main(int argc, char **argv) {
char *hsaco = env_get_str("IGEMM_HSACO", IGEMM_HSACO);
char *config_file = env_get_str("IGEMM_CONFIG_FILE", IGEMM_CONFIG_FILE);
std::string run_only_kernel = env_get_str("IGEMM_RUN_ONLY_KERNEL", IGEMM_RUN_ONLY_KERNEL_DEFAULT);
int warmup = env_get_int("IGEMM_WARMUP", WARMUP);
int repeat = env_get_int("IGEMM_REPEAT", REPEAT);
int sclk_mhz = env_get_int("IGEMM_SCLK_MHZ", SCLK_MHZ);
@@ -457,8 +460,8 @@ int main(int argc, char **argv) {
gen_rand_vector<float, float>(host_input, static_cast<size_t>(n) * c * hi * wi, 0.0, 1.0);
gen_rand_vector<float, float>(host_weight, static_cast<size_t>(k) * c * y * x, -0.5, 0.5);

//gen_rand_vector<float, int>(host_input, n * c * hi * wi, 1, 1);
//gen_rand_vector<float, int>(host_weight, k * c * y * x, 1, 1);
//gen_rand_vector<float, int>(host_input, static_cast<size_t>(n) * c * hi * wi, 1, 1);
//gen_rand_vector<float, int>(host_weight, static_cast<size_t>(k) * c * y * x, 1, 1);

#ifdef USE_GPU_NAIVE_CONV
HIP_CALL(hipMemcpy(device_input, host_input,
@@ -491,6 +494,9 @@ int main(int argc, char **argv) {
double nrms = get_fwd_nrms();
for (int i = 0; i < tunables.size(); i++) {
igemm_gtc_tunable_t *tunable = &tunables[i];
if(run_only_kernel != IGEMM_RUN_ONLY_KERNEL_DEFAULT)
if(run_only_kernel != conv_fwd_driver.get_kernel_name(tunable))
continue;

printf("[fwd:%2d] %s, ", i, conv_fwd_driver.get_kernel_name(tunable).c_str());
fflush(stdout);
@@ -554,8 +560,8 @@ int main(int argc, char **argv) {
gen_rand_vector<float, float>(host_output, static_cast<size_t>(n) * k * ho * wo, 0.0, 1.0);
gen_rand_vector<float, float>(host_weight, static_cast<size_t>(k) * c * y * x, -0.5, 0.5);
gen_rand_vector<float, float>(host_input, static_cast<size_t>(n) * c * hi * wi, 999999., 9999999.); // manually set input values to a very large number
// gen_rand_vector<float, int>(host_output, n * k * ho * wo,1, 1);
// gen_rand_vector<float, int>(host_weight, k * c * y * x, 1, 1);
// gen_rand_vector<float, int>(host_output, static_cast<size_t>(n) * k * ho * wo,1, 1);
// gen_rand_vector<float, int>(host_weight, static_cast<size_t>(k) * c * y * x, 1, 1);
#ifdef USE_GPU_NAIVE_CONV
HIP_CALL(hipMemcpy(device_output, host_output,
static_cast<size_t>(n) * k * ho * wo * sizeof(float), hipMemcpyHostToDevice));
@@ -588,6 +594,9 @@ int main(int argc, char **argv) {
double nrms = get_bwd_nrms();
for (int i = 0; i < tunables.size(); i++) {
igemm_gtc_tunable_t *tunable = &tunables[i];
if(run_only_kernel != IGEMM_RUN_ONLY_KERNEL_DEFAULT)
if(run_only_kernel != conv_bwd_driver.get_kernel_name(tunable))
continue;

printf("[bwd:%2d] %s, ", i, conv_bwd_driver.get_kernel_name(tunable).c_str());
fflush(stdout);
@@ -650,8 +659,8 @@ int main(int argc, char **argv) {
// gen rand
gen_rand_vector<float, float>(host_input, static_cast<size_t>(n) * c * hi * wi, 0.0, 1.0);
gen_rand_vector<float, float>(host_output, static_cast<size_t>(n) * k * ho * wo, -0.5, 0.5);
//gen_rand_vector<float, int>(host_input, n * k * hi * wi, -5, 5);
//gen_rand_vector<float, int>(host_output, n * k * ho * wo, 1, 1);
//gen_rand_vector<float, int>(host_input, static_cast<size_t>(n) * k * hi * wi, -5, 5);
//gen_rand_vector<float, int>(host_output, static_cast<size_t>(n) * k * ho * wo, 1, 1);
#ifdef USE_GPU_NAIVE_CONV
HIP_CALL(hipMemcpy(device_input, host_input,
static_cast<size_t>(n) * c * hi * wi * sizeof(float), hipMemcpyHostToDevice));
@@ -718,13 +727,16 @@ int main(int argc, char **argv) {

for (int i = 0; i < tunables.size(); i++) {
igemm_gtc_tunable_t *tunable = &tunables[i];
if(run_only_kernel != IGEMM_RUN_ONLY_KERNEL_DEFAULT)
if(run_only_kernel != conv_wrw_driver.get_kernel_name(tunable))
continue;

printf("[wrw:%2d] %s, ", i, conv_wrw_driver.get_kernel_name(tunable).c_str());
fflush(stdout);

if (need_verify)
HIP_CALL(hipMemset(device_weight, 0,
k * c * y * x * sizeof(float)));
static_cast<size_t>(k) * c * y * x * sizeof(float)));
result_t result =
conv_wrw_driver.run(&conv_args, tunable, module, device_input,
device_weight, device_output, warmup, repeat);
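The IGEMM_RUN_ONLY_KERNEL check added above is repeated verbatim in the fwd, bwd, and wrw loops. A minimal sketch of the shared predicate it implements (the helper name is hypothetical; "off" is the IGEMM_RUN_ONLY_KERNEL_DEFAULT sentinel):

#include <string>

// Hypothetical helper equivalent to the nested ifs in each loop: run everything when the
// filter is left at its "off" default, otherwise run only the tunable whose kernel name
// matches the IGEMM_RUN_ONLY_KERNEL environment variable exactly.
static bool skip_tunable(const std::string &run_only_kernel, const std::string &kernel_name) {
    return run_only_kernel != "off" && run_only_kernel != kernel_name;
}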
69 changes: 65 additions & 4 deletions driver/igemm_fwd_gtc_driver.h
@@ -142,6 +142,9 @@ class igemm_fwd_gtc_t {
int wo = conv_out_size(wi, pad_w, dilation_w, x, stride_w);
int group = arg->get_int("group_count");

int splits = split_batch_size(arg, tunable);
n = n/splits; // split batch size here

int gemm_m_per_block = tunable->gemm_m_per_block;
int gemm_n_per_block = tunable->gemm_n_per_block;
int nxe = tunable->nxe;
@@ -157,6 +160,54 @@
return grid_size;
}

// this is to support tensors bigger than 4G; decide how many batch splits are needed
// return the number of splits (0 if a single image already exceeds 4G)
int split_batch_size(const args_t *arg, const igemm_gtc_tunable_t *tunable)
{
int hi = arg->get_int("in_h");
int wi = arg->get_int("in_w");
int n = arg->get_int("batchsize");
int k = arg->get_int("out_channels");
int c = arg->get_int("in_channels");

int stride_h = arg->get_int("conv_stride_h");
int stride_w = arg->get_int("conv_stride_w");
int dilation_h = arg->get_int("dilation_h");
int dilation_w = arg->get_int("dilation_w");
int pad_h = arg->get_int("pad_h");
int pad_w = arg->get_int("pad_w");
int y = arg->get_int("fil_h");
int x = arg->get_int("fil_w");
int ho = conv_out_size(hi, pad_h, dilation_h, y, stride_h);
int wo = conv_out_size(wi, pad_w, dilation_w, x, stride_w);

int data_byte = utility_string_to_data_byte(tunable->precision);
size_t image_size_input = static_cast<size_t>(c) * hi * wi * data_byte;
size_t image_size_output = static_cast<size_t>(k) * ho * wo * data_byte;
size_t size_4g = 0xffffffffUL;
if(image_size_input >= size_4g || image_size_output >= size_4g)
return 0;

size_t image_size = image_size_input >= image_size_output ? image_size_input : image_size_output;
size_t splited_n = size_4g / image_size;

// reduce splited_n (i.e. round up the split count) until both conditions hold:
// 1. splited_n * image_size < size_4g
// 2. n % splited_n == 0
// if(splited_n >= n)
// return 1;
assert(splited_n != 0);
while(splited_n >= 1){
// printf("n:%d, splited_n:%d\n", n, splited_n);
if(n % splited_n == 0)
break;
splited_n--;
}

assert(splited_n * image_size < size_4g && n % splited_n == 0);
return n / splited_n;
}

bool tunable_is_valid(const args_t *arg,
const igemm_gtc_tunable_t *tunable)
{
@@ -180,6 +231,13 @@ class igemm_fwd_gtc_t {

assert(c % group == 0 && k % group == 0);

int splits = split_batch_size(arg, tunable);
if(splits == 0){
printf("image size (c*h*w) is bigger than 4g, which is not supported now\n");
return false;
}
n = n/splits; // split batch size here

int gemm_m_per_block = tunable->gemm_m_per_block;
int gemm_n_per_block = tunable->gemm_n_per_block;
int gemm_k_per_block = tunable->gemm_k_per_block;
@@ -273,13 +331,16 @@ class igemm_fwd_gtc_t {

assert(c % group == 0 && k % group == 0);

int splits = split_batch_size(arg, tunable);
n = n/splits; // split batch size here

int gemm_m_per_block = tunable->gemm_m_per_block;
int gemm_n_per_block = tunable->gemm_n_per_block;
int gemm_k_per_block = tunable->gemm_k_per_block;
int nxe = tunable->nxe;
int nxb = tunable->nxb;
int b = nxe == 0 ? (ho * wo) : ((ho * wo + nxb - 1) / nxb) * nxb; // pad to nxb modulo when nxe != 0

igemm_fwd_gtc_karg_t karg;
size_t karg_size = sizeof(karg);
karg.p_in = p_in;
Expand Down Expand Up @@ -347,7 +408,7 @@ class igemm_fwd_gtc_t {
hipModuleGetFunction(&kernel_func, module, kernel_name.c_str()));

auto launch_fwd = [&]() -> float {
// printf("launch fwd block:%d, grid:%d\n", block_size, grid_size);
// printf("launch fwd block:%d, grid:%dx%d\n", block_size, grid_size, splits);
// dump_fwd_karg(&karg);
void *config[] = {HIP_LAUNCH_PARAM_BUFFER_POINTER, &karg,
HIP_LAUNCH_PARAM_BUFFER_SIZE, &karg_size,
@@ -361,7 +422,7 @@
hipEventCreate(&stop);

// for hipHccModuleLaunchKernel/hipExtModuleLaunchKernel, grid_size is given in units of workitems
HIP_CALL(hipHccModuleLaunchKernel(kernel_func, grid_size * block_size, 1, 1,
HIP_CALL(hipHccModuleLaunchKernel(kernel_func, grid_size * block_size, splits, 1,
block_size, 1, 1, 0, 0, NULL,
(void **)&config, start, stop));

@@ -373,7 +434,7 @@
gpu_timer_t timer(NULL);
timer.start();

HIP_CALL(hipModuleLaunchKernel(kernel_func, grid_size, 1, 1,
HIP_CALL(hipModuleLaunchKernel(kernel_func, grid_size, splits, 1,
block_size, 1, 1, 0, 0, NULL,
(void **)&config));

20 changes: 19 additions & 1 deletion igemm/algo/igemm_fwd_gtc.py
@@ -674,7 +674,8 @@ def __init__(self, mc, outer):
sseq = gpr_sequencer_t()
self.outer = outer
self.s_ka = sym_t('s_ka' , sseq(2))
self.s_bx = sym_t('s_bx' , sseq(2))
self.s_bx = sym_t('s_bx' , sseq(1))
self.s_by = sym_t('s_by' , sseq(1))
self.s_p_in = sym_t('s_p_in' , sseq(4))
self.s_p_wei = sym_t('s_p_wei' , sseq(4))
self.s_p_out = sym_t('s_p_out' , sseq(4))
@@ -1230,6 +1231,7 @@ def get_kernel_code(self):
kernel_code = amdgpu_kernel_code_t({
'enable_sgpr_kernarg_segment_ptr' : 1,
'enable_sgpr_workgroup_id_x' : 1,
'enable_sgpr_workgroup_id_y' : 1,
'enable_vgpr_workitem_id' : 0,
'workgroup_group_segment_byte_size' : self.tunable.lds_total,
'kernarg_segment_byte_size' : self.karg.get_count(),
@@ -1521,6 +1523,22 @@ def emit_kernel_prologue(self):
self._emit(f"s_lshr_b32 s[{s.s_tmp()}], s[{s.s_n()}], {igemm_log2(nb_n0)}")
self._emit(f"s_mul_i32 s[{s.s_out_stride_n0()}], s[{s.s_out_stride_n()}], s[{s.s_tmp()}]")

# calculate batch split and accumulate the base pointer for input/output
self._emit(f"s_mul_i32 s[{s.s_tmp(0)}], s[{s.s_n()}], s[{s.s_in_stride_n()}]")
self._emit(f"s_mul_i32 s[{s.s_tmp(1)}], s[{s.s_n()}], s[{s.s_out_stride_n()}]")
self._emit(f"s_lshl_b32 s[{s.s_tmp(4)}], s[{s.s_tmp(0)}], {igemm_log2(data_byte)}")
self._emit(f"s_lshl_b32 s[{s.s_tmp(5)}], s[{s.s_tmp(1)}], {igemm_log2(data_byte)}")

self._emit(f"s_mul_i32 s[{s.s_tmp()}], s[{s.s_by()}], s[{s.s_tmp(4)}]")
self._emit(f"s_mul_hi_u32 s[{s.s_tmp(1)}], s[{s.s_by()}], s[{s.s_tmp(4)}]")
self._emit(f"s_add_u32 s[{s.s_p_in()}], s[{s.s_p_in()}], s[{s.s_tmp()}]")
self._emit(f"s_addc_u32 s[{s.s_p_in(1)}], s[{s.s_p_in(1)}], s[{s.s_tmp(1)}]")

self._emit(f"s_mul_i32 s[{s.s_tmp()}], s[{s.s_by()}], s[{s.s_tmp(5)}]")
self._emit(f"s_mul_hi_u32 s[{s.s_tmp(1)}], s[{s.s_by()}], s[{s.s_tmp(5)}]")
self._emit(f"s_add_u32 s[{s.s_p_out()}], s[{s.s_p_out()}], s[{s.s_tmp()}]")
self._emit(f"s_addc_u32 s[{s.s_p_out(1)}], s[{s.s_p_out(1)}], s[{s.s_tmp(1)}]")

# early init s_knum in case shifted
if self.tunable.nxe != 0:
self._emit(f"s_mul_i32 s[{s.s_knum()}], s[{s.s_wei_stride_c()}], s[{s.s_c()}]")