Skip to content

Commit

Permalink
Tune the AG performance for the llama-8b (#21)
Browse files Browse the repository at this point in the history
  • Loading branch information
zheng-ningxin authored Jul 16, 2024
1 parent 322710d commit 85af92c
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 1 deletion.
2 changes: 1 addition & 1 deletion src/all_gather/gemm_v2_ag_kernel.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -182,8 +182,8 @@ struct GemmV2AGKernel_Space : OpSpaceBase<GemmV2AGKernel_Space> {
make_gemm_v2_hparams(Shape<_64, _64, _32>{}, Shape<_16, _8, _16>{}, _StreamkDP{})),
cute::make_tuple(Auto{}),
cute::make_tuple(
Shape<_128, _128, _64>{},
Shape<_128, _128, _32>{},
Shape<_128, _128, _64>{},
Shape<_64, _128, _32>{},
Shape<_64, _128, _64>{},
Shape<_64, _256, _32>{},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,15 @@ static int config_ag_gemm_kernel_sm80_tp4_nnodes1 = []() {
inst.add(make_gemm_meta(make_gemm_dtype_config(_FP16{}(),_FP16{}(),_FP16{}(),_FP16{}()),_Sm80{}(),_AGKernel{}(),_RCR{}(),_GemmV2{}()),make_runtime_config(8192,12288,12288,make_all_gather_runtime_config(4,1,0)),make_gemm_hparams(make_gemm_v2_hparams(cute::make_tuple(64l,64l,32l),cute::make_tuple(16l,8l,16l),_StreamkDP{}()),None{},cute::make_tuple(256l,128l,32l),_GemmStreamK{}(),3,_RasterAlongM{}()));
/// NVLink
inst.add(make_gemm_meta(make_gemm_dtype_config(_BF16{}(),_BF16{}(),_Void{}(),_BF16{}()),_Sm80{}(),_AGKernel{}(),_RRR{}(),_GemmV2{}()),make_runtime_config(8192,12288,12288,make_all_gather_runtime_config(4,1,0)),make_gemm_hparams(make_gemm_v2_hparams(cute::make_tuple(64l,64l,32l),cute::make_tuple(16l,8l,16l),_StreamkDP{}()),None{},cute::make_tuple(128l,256l,32l),_GemmStreamK{}(),3,_RasterAlongM{}()));
inst.add(make_gemm_meta(make_gemm_dtype_config(_BF16{}(),_BF16{}(),_Void{}(),_BF16{}(),_FP32{}()),_Sm80{}(),_AGKernel{}(),_RRR{}(),_GemmV2{}(),make_gemm_v2_meta(false),None{}),make_runtime_config(2048,1792,4096,make_all_gather_runtime_config(4,1,0)),make_gemm_hparams(make_gemm_v2_hparams(cute::make_tuple(64l,64l,32l),cute::make_tuple(16l,8l,16l),_StreamkDP{}()),None{},cute::make_tuple(64l,128l,64l),_GemmStreamK{}(),3,_RasterAlongM{}()));
inst.add(make_gemm_meta(make_gemm_dtype_config(_FP16{}(),_FP16{}(),_Void{}(),_FP16{}(),_FP32{}()),_Sm80{}(),_AGKernel{}(),_RRR{}(),_GemmV2{}(),make_gemm_v2_meta(false),None{}),make_runtime_config(2048,1792,4096,make_all_gather_runtime_config(4,1,0)),make_gemm_hparams(make_gemm_v2_hparams(cute::make_tuple(64l,64l,32l),cute::make_tuple(16l,8l,16l),_StreamkDP{}()),None{},cute::make_tuple(64l,128l,64l),_GemmStreamK{}(),3,_RasterAlongM{}()));
inst.add(make_gemm_meta(make_gemm_dtype_config(_BF16{}(),_BF16{}(),_Void{}(),_BF16{}(),_FP32{}()),_Sm80{}(),_AGKernel{}(),_RCR{}(),_GemmV2{}(),make_gemm_v2_meta(false),None{}),make_runtime_config(2048,1792,4096,make_all_gather_runtime_config(4,1,0)),make_gemm_hparams(make_gemm_v2_hparams(cute::make_tuple(64l,64l,32l),cute::make_tuple(16l,8l,16l),_StreamkDP{}()),None{},cute::make_tuple(64l,128l,64l),_GemmStreamK{}(),3,_RasterAlongM{}()));
inst.add(make_gemm_meta(make_gemm_dtype_config(_FP16{}(),_FP16{}(),_Void{}(),_FP16{}(),_FP32{}()),_Sm80{}(),_AGKernel{}(),_RCR{}(),_GemmV2{}(),make_gemm_v2_meta(false),None{}),make_runtime_config(2048,1792,4096,make_all_gather_runtime_config(4,1,0)),make_gemm_hparams(make_gemm_v2_hparams(cute::make_tuple(64l,64l,32l),cute::make_tuple(16l,8l,16l),_StreamkDP{}()),None{},cute::make_tuple(64l,128l,64l),_GemmStreamK{}(),4,_RasterAlongM{}()));
inst.add(make_gemm_meta(make_gemm_dtype_config(_BF16{}(),_BF16{}(),_Void{}(),_BF16{}(),_FP32{}()),_Sm80{}(),_AGKernel{}(),_RRR{}(),_GemmV2{}(),make_gemm_v2_meta(false),None{}),make_runtime_config(2048,7168,4096,make_all_gather_runtime_config(4,1,0)),make_gemm_hparams(make_gemm_v2_hparams(cute::make_tuple(64l,64l,32l),cute::make_tuple(16l,8l,16l),_StreamkDP{}()),None{},cute::make_tuple(128l,128l,32l),_GemmStreamK{}(),3,_RasterAlongM{}()));
inst.add(make_gemm_meta(make_gemm_dtype_config(_FP16{}(),_FP16{}(),_Void{}(),_FP16{}(),_FP32{}()),_Sm80{}(),_AGKernel{}(),_RRR{}(),_GemmV2{}(),make_gemm_v2_meta(false),None{}),make_runtime_config(2048,7168,4096,make_all_gather_runtime_config(4,1,0)),make_gemm_hparams(make_gemm_v2_hparams(cute::make_tuple(64l,64l,32l),cute::make_tuple(16l,8l,16l),_StreamkDP{}()),None{},cute::make_tuple(128l,128l,32l),_GemmStreamK{}(),3,_RasterAlongM{}()));
inst.add(make_gemm_meta(make_gemm_dtype_config(_BF16{}(),_BF16{}(),_Void{}(),_BF16{}(),_FP32{}()),_Sm80{}(),_AGKernel{}(),_RCR{}(),_GemmV2{}(),make_gemm_v2_meta(false),None{}),make_runtime_config(2048,7168,4096,make_all_gather_runtime_config(4,1,0)),make_gemm_hparams(make_gemm_v2_hparams(cute::make_tuple(64l,64l,32l),cute::make_tuple(16l,8l,16l),_StreamkDP{}()),None{},cute::make_tuple(128l,128l,32l),_GemmStreamK{}(),3,_RasterAlongM{}()));
inst.add(make_gemm_meta(make_gemm_dtype_config(_FP16{}(),_FP16{}(),_Void{}(),_FP16{}(),_FP32{}()),_Sm80{}(),_AGKernel{}(),_RCR{}(),_GemmV2{}(),make_gemm_v2_meta(false),None{}),make_runtime_config(2048,7168,4096,make_all_gather_runtime_config(4,1,0)),make_gemm_hparams(make_gemm_v2_hparams(cute::make_tuple(64l,64l,32l),cute::make_tuple(16l,8l,16l),_StreamkSK{}()),None{},cute::make_tuple(256l,128l,32l),_GemmStreamK{}(),3,_RasterAlongN{}()));

return 0;
}();
}
Expand Down

0 comments on commit 85af92c

Please sign in to comment.