@@ -165,6 +165,8 @@ measured_fp32_conv_gflops(double time_ms, size_t n, size_t c, size_t hi,
165165#define IGEMM_CONFIG_FILE " igemm_gtc.config"
166166#endif
167167
// Fallback value used when the run-only-kernel environment variable is not
// set. The per-kernel loops in main only filter by kernel name when the
// configured string differs from this value.
168+ #define IGEMM_RUN_ONLY_KERNEL_DEFAULT " off"
169+
// Fallback passed to env_get_int for the warmup setting in main.
168170#define WARMUP 3
// Fallback passed to env_get_int for the repeat setting in main.
169171#define REPEAT 8
// Fallback passed to env_get_int for the sclk_mhz setting in main.
// NOTE(review): presumably a shader clock in MHz -- confirm against the GFLOPS math.
170172#define SCLK_MHZ 1283
@@ -214,14 +216,14 @@ struct distribution_t<float>{
214216};
215217
216218template <typename Dst_T, typename Src_T>
217- void block_wise_rand_generator (Dst_T *p, int tid, int block_size, int total_size, Src_T min, Src_T max, Src_T scale)
219+ void block_wise_rand_generator (Dst_T *p, int tid, int block_size, size_t total_size, Src_T min, Src_T max, Src_T scale)
218220{
219221 std::mt19937 rng (std::chrono::system_clock::now ()
220222 .time_since_epoch ()
221223 .count () +
222224 std::hash<std::thread::id>()(std::this_thread::get_id ()));
223225 distribution_t <Src_T> distribution (min,max);
224- for (int i = tid; i < total_size; i += block_size) {
226+ for (size_t i = tid; i < total_size; i += block_size) {
225227 p[i] = static_cast <Dst_T>(scale * distribution (rng));
226228 }
227229}
@@ -342,6 +344,7 @@ void dump_arg(const args_t *arg) {
342344int main (int argc, char **argv) {
343345 char *hsaco = env_get_str (" IGEMM_HSACO" , IGEMM_HSACO);
344346 char *config_file = env_get_str (" IGEMM_CONFIG_FILE" , IGEMM_CONFIG_FILE);
347+ std::string run_only_kernel = env_get_str (" IGEMM_RUN_ONLY_KERNEL" , IGEMM_RUN_ONLY_KERNEL_DEFAULT);
345348 int warmup = env_get_int (" IGEMM_WARMUP" , WARMUP);
346349 int repeat = env_get_int (" IGEMM_REPEAT" , REPEAT);
347350 int sclk_mhz = env_get_int (" IGEMM_SCLK_MHZ" , SCLK_MHZ);
@@ -457,8 +460,8 @@ int main(int argc, char **argv) {
457460 gen_rand_vector<float , float >(host_input, static_cast <size_t >(n) * c * hi * wi, 0.0 , 1.0 );
458461 gen_rand_vector<float , float >(host_weight, static_cast <size_t >(k) * c * y * x, -0.5 , 0.5 );
459462
460- // gen_rand_vector<float, int>(host_input, n * c * hi * wi, 1, 1);
461- // gen_rand_vector<float, int>(host_weight, k * c * y * x, 1, 1);
463+ // gen_rand_vector<float, int>(host_input, static_cast<size_t>(n) * c * hi * wi, 1, 1);
464+ // gen_rand_vector<float, int>(host_weight, static_cast<size_t>(k) * c * y * x, 1, 1);
462465
463466#ifdef USE_GPU_NAIVE_CONV
464467 HIP_CALL (hipMemcpy (device_input, host_input,
@@ -506,6 +509,9 @@ int main(int argc, char **argv) {
506509 double nrms = get_fwd_nrms ();
507510 for (int i = 0 ; i < tunables.size (); i++) {
508511 igemm_gtc_tunable_t *tunable = &tunables[i];
512+ if (run_only_kernel != IGEMM_RUN_ONLY_KERNEL_DEFAULT)
513+ if (run_only_kernel != conv_fwd_driver.get_kernel_name (tunable))
514+ continue ;
509515
510516 printf (" [fwd:%2d] %s, " , i, conv_fwd_driver.get_kernel_name (tunable).c_str ());
511517 fflush (stdout);
@@ -569,8 +575,8 @@ int main(int argc, char **argv) {
569575 gen_rand_vector<float , float >(host_output, static_cast <size_t >(n) * k * ho * wo, 0.0 , 1.0 );
570576 gen_rand_vector<float , float >(host_weight, static_cast <size_t >(k) * c * y * x, -0.5 , 0.5 );
571577 gen_rand_vector<float , float >(host_input, static_cast <size_t >(n) * c * hi * wi, 999999 ., 9999999 .); // manually input value to a very large number
572- // gen_rand_vector<float, int>(host_output, n * k * ho * wo,1, 1);
573- // gen_rand_vector<float, int>(host_weight, k * c * y * x, 1, 1);
578+ // gen_rand_vector<float, int>(host_output, static_cast<size_t>(n) * k * ho * wo,1, 1);
579+ // gen_rand_vector<float, int>(host_weight, static_cast<size_t>(k) * c * y * x, 1, 1);
574580#ifdef USE_GPU_NAIVE_CONV
575581 HIP_CALL (hipMemcpy (device_output, host_output,
576582 static_cast <size_t >(n) * k * ho * wo * sizeof (float ), hipMemcpyHostToDevice));
@@ -618,6 +624,9 @@ int main(int argc, char **argv) {
618624 double nrms = get_bwd_nrms ();
619625 for (int i = 0 ; i < tunables.size (); i++) {
620626 igemm_gtc_tunable_t *tunable = &tunables[i];
627+ if (run_only_kernel != IGEMM_RUN_ONLY_KERNEL_DEFAULT)
628+ if (run_only_kernel != conv_bwd_driver.get_kernel_name (tunable))
629+ continue ;
621630
622631 printf (" [bwd:%2d] %s, " , i, conv_bwd_driver.get_kernel_name (tunable).c_str ());
623632 fflush (stdout);
@@ -680,8 +689,8 @@ int main(int argc, char **argv) {
680689 // gen rand
681690 gen_rand_vector<float , float >(host_input, static_cast <size_t >(n) * c * hi * wi, 0.0 , 1.0 );
682691 gen_rand_vector<float , float >(host_output, static_cast <size_t >(n) * k * ho * wo, -0.5 , 0.5 );
683- // gen_rand_vector<float, int>(host_input, n * k * hi * wi, -5, 5);
684- // gen_rand_vector<float, int>(host_output, n * k * ho * wo, 1, 1);
692+ // gen_rand_vector<float, int>(host_input, static_cast<size_t>(n) * k * hi * wi, -5, 5);
693+ // gen_rand_vector<float, int>(host_output, static_cast<size_t>(n) * k * ho * wo, 1, 1);
685694#ifdef USE_GPU_NAIVE_CONV
686695 HIP_CALL (hipMemcpy (device_input, host_input,
687696 static_cast <size_t >(n) * c * hi * wi * sizeof (float ), hipMemcpyHostToDevice));
@@ -763,13 +772,16 @@ int main(int argc, char **argv) {
763772
764773 for (int i = 0 ; i < tunables.size (); i++) {
765774 igemm_gtc_tunable_t *tunable = &tunables[i];
775+ if (run_only_kernel != IGEMM_RUN_ONLY_KERNEL_DEFAULT)
776+ if (run_only_kernel != conv_wrw_driver.get_kernel_name (tunable))
777+ continue ;
766778
767779 printf (" [wrw:%2d] %s, " , i, conv_wrw_driver.get_kernel_name (tunable).c_str ());
768780 fflush (stdout);
769781
770782 if (need_verify)
771783 HIP_CALL (hipMemset (device_weight, 0 ,
772- k * c * y * x * sizeof (float )));
784+ static_cast < size_t >(k) * c * y * x * sizeof (float )));
773785 result_t result =
774786 conv_wrw_driver.run (&conv_args, tunable, module , device_input,
775787 device_weight, device_output, warmup, repeat);
0 commit comments