From 90d3b0fb1803fed7edab91e63c013d2fb83a525a Mon Sep 17 00:00:00 2001
From: ANIKET SHIVAM <3268307+ANIKET-SHIVAM@users.noreply.github.com>
Date: Tue, 26 Sep 2023 14:24:26 -0700
Subject: [PATCH] CUTLASS 3.2.1 (#1113)

* Updates for 3.2.1 release.
* Minor fix in gemm op profiler for raster order.
* Add scheduler mapping for raster order in the kernels.
---
 CHANGELOG.md | 12 +- CMakeLists.txt | 59 +- CUDA.cmake | 4 +- README.md | 15 +- cmake/NvidiaCutlassConfig.cmake | 9 +- .../08_turing_tensorop_gemm/CMakeLists.txt | 1 - .../turing_tensorop_gemm.cu | 5 +- .../turing_tensorop_conv2dfprop.cu | 6 - examples/12_gemm_bias_relu/CMakeLists.txt | 1 - .../fused_two_convs_s8_sm75_rf.cu | 5 - .../fused_two_convs_s8_sm75_shmem.cu | 8 - .../fused_two_gemms_s8_sm75_rf.cu | 4 - .../fused_two_gemms_s8_sm75_shmem.cu | 6 - .../threadblock/grouped_threadblock_swizzle.h | 28 - examples/24_gemm_grouped/CMakeLists.txt | 3 +- examples/40_cutlass_py/README.md | 27 +- examples/40_cutlass_py/conv2d.py | 8 +- examples/40_cutlass_py/customizable/README.md | 25 - examples/40_cutlass_py/customizable/conv2d.py | 7 +- examples/40_cutlass_py/customizable/gemm.py | 135 +-- .../customizable/gemm_grouped.py | 7 +- examples/40_cutlass_py/gemm.py | 6 +- examples/40_cutlass_py/gemm_grouped.py | 6 +- .../ir_gen/gen_device.py | 8 - ...ampere_gemm_universal_streamk_broadcast.cu | 140 ++- .../52_hopper_gather_scatter_fusion.cu | 3 +- .../gather_gemm.hpp | 2 +- .../54_hopper_fp8_warp_specialized_gemm.cu | 2 +- examples/python/00_basic_gemm.ipynb | 4 +- examples/python/01_epilogue.ipynb | 4 +- .../02_pytorch_extension_grouped_gemm.ipynb | 2 - examples/python/04_epilogue_visitor.ipynb | 221 ++++ examples/python/README.md | 4 + include/cute/algorithm/axpby.hpp | 9 +- include/cute/algorithm/gemm.hpp | 6 +- include/cute/algorithm/tuple_algorithms.hpp | 75 +- include/cute/arch/copy_sm90_desc.hpp | 2 +- include/cute/arch/mma_sm80.hpp | 17 +- include/cute/arch/mma_sm90.hpp | 7 +- include/cute/arch/mma_sm90_desc.hpp | 14 +- include/cute/atom/copy_traits_sm90_tma.hpp | 474 +++++--- include/cute/atom/mma_atom.hpp | 2 +- include/cute/atom/mma_traits_sm75.hpp | 4 +- include/cute/config.hpp | 2 +- include/cute/container/bit_field.hpp | 10 +- include/cute/container/tuple.hpp | 38 +- include/cute/int_tuple.hpp | 113 +- include/cute/layout.hpp | 65 +- include/cute/numeric/arithmetic_tuple.hpp | 6 +- include/cute/numeric/complex.hpp | 144 +-- include/cute/numeric/integral_constant.hpp | 59 +- include/cute/numeric/integral_ratio.hpp | 175 +++ include/cute/numeric/math.hpp | 36 +- include/cute/pointer.hpp | 36 +- include/cute/stride.hpp | 3 + include/cute/swizzle.hpp | 171 ++- include/cute/swizzle_layout.hpp | 2 + include/cute/swizzle_ptr.hpp | 9 +- include/cute/util/print.hpp | 66 +- include/cutlass/arch/mma_sm75.h | 104 +- include/cutlass/arch/mma_sm80.h | 17 +- include/cutlass/array.h | 16 +- include/cutlass/array_subbyte.h | 9 + include/cutlass/barrier.h | 2 +- include/cutlass/bfloat16.h | 11 + include/cutlass/complex.h | 56 +- include/cutlass/conv/conv2d_problem_size.h | 20 +- include/cutlass/conv/conv3d_problem_size.h | 52 +- include/cutlass/conv/convolution.h | 12 +- .../cutlass/conv/kernel/direct_convolution.h | 2 +- .../conv/kernel/implicit_gemm_convolution.h | 2 +- .../conv/threadblock/threadblock_swizzle.h | 12 +- include/cutlass/coord.h | 10 + include/cutlass/core_io.h | 11 +- include/cutlass/cutlass.h | 10 + include/cutlass/detail/helper_macros.hpp | 14 + include/cutlass/detail/layout.hpp | 17 + .../collective/builders/sm90_builder.inl | 92 
+- .../collective/collective_builder.hpp | 1 + .../cutlass/epilogue/collective/detail.hpp | 4 +- .../sm90_epilogue_tma_warpspecialized.hpp | 129 ++- include/cutlass/epilogue/fusion/callbacks.hpp | 2 + .../cutlass/epilogue/fusion/operations.hpp | 8 +- .../sm90_callbacks_tma_warpspecialized.hpp | 88 +- ...90_visitor_compute_tma_warpspecialized.hpp | 121 ++ .../sm90_visitor_load_tma_warpspecialized.hpp | 60 +- ...sm90_visitor_store_tma_warpspecialized.hpp | 44 +- .../sm90_visitor_tma_warpspecialized.hpp | 104 +- include/cutlass/epilogue/thread/activation.h | 367 ++---- .../thread/linear_combination_generic.h | 63 +- .../epilogue_with_visitor_callbacks.h | 495 ++++++++ .../threadblock/fusion/visitor_2x.hpp | 433 +++++++ .../threadblock/fusion/visitor_compute.hpp | 109 ++ .../threadblock/fusion/visitor_load.hpp | 559 +++++++++ .../threadblock/fusion/visitor_store.hpp | 781 +++++++++++++ .../epilogue/threadblock/fusion/visitors.hpp | 25 +- .../predicated_tile_iterator_params.h | 10 + include/cutlass/fast_math.h | 60 +- include/cutlass/float8.h | 16 + include/cutlass/functional.h | 11 +- .../collective/builders/sm90_gmma_builder.inl | 2 +- .../sm90_mma_tma_gmma_rs_warpspecialized.hpp | 96 +- .../cutlass/gemm/device/gemm_universal_base.h | 14 - include/cutlass/gemm/gemm.h | 31 +- .../cutlass/gemm/gemm_enumerated_types.h | 76 +- .../default_gemm_universal_with_visitor.h | 157 +++ .../kernel/gemm_grouped_problem_visitor.h | 1 - include/cutlass/gemm/kernel/gemm_universal.h | 16 +- .../gemm/kernel/gemm_universal_with_visitor.h | 321 ++++++ .../gemm_universal_with_visitor_streamk.h | 892 +++++++++++++++ .../gemm/kernel/params_universal_base.h | 73 +- include/cutlass/gemm/kernel/sm70_gemm.hpp | 2 +- include/cutlass/gemm/kernel/sm90_gemm_tma.hpp | 4 +- .../kernel/sm90_gemm_tma_warpspecialized.hpp | 4 +- ...0_gemm_tma_warpspecialized_cooperative.hpp | 19 +- ...sm90_gemm_tma_warpspecialized_pingpong.hpp | 5 +- .../gemm/kernel/sm90_tile_scheduler.hpp | 227 +--- .../kernel/sm90_tile_scheduler_stream_k.hpp | 525 ++------- .../gemm/kernel/tile_scheduler_params.h | 1005 +++++++++++++++++ .../gemm/threadblock/threadblock_swizzle.h | 72 +- .../threadblock/threadblock_swizzle_streamk.h | 12 +- .../cutlass/gemm_coord.hpp | 59 +- include/cutlass/half.h | 12 +- include/cutlass/integer_subbyte.h | 14 + .../cutlass/kernel_hardware_info.h | 77 +- include/cutlass/kernel_hardware_info.hpp | 47 +- include/cutlass/layout/matrix.h | 14 + include/cutlass/layout/pitch_linear.h | 11 + include/cutlass/layout/vector.h | 1 + include/cutlass/numeric_size.h | 93 ++ include/cutlass/numeric_types.h | 45 +- include/cutlass/pipeline/sm90_pipeline.hpp | 55 +- include/cutlass/platform/platform.h | 25 +- include/cutlass/subbyte_reference.h | 7 +- .../collective/sm90_wgmma_transpose.hpp | 2 +- .../predicated_tile_access_iterator_params.h | 12 + include/cutlass/uint128.h | 9 +- .../cutlass/{workspace.hpp => workspace.h} | 11 + .../building_in_windows_with_visual_studio.md | 93 ++ .../building_with_clang_as_host_compiler.md | 53 + media/docs/code_organization.md | 9 +- media/docs/profiler.md | 2 + python/README.md | 71 +- python/cutlass/__init__.py | 37 +- python/cutlass/backend/__init__.py | 5 - python/cutlass/backend/arguments.py | 74 +- python/cutlass/backend/c_types.py | 252 ++++- python/cutlass/backend/compiler.py | 30 +- python/cutlass/backend/conv2d_operation.py | 355 +++--- python/cutlass/backend/epilogue.py | 890 +++------------ .../cutlass/backend/{test => evt}/__init__.py | 8 +- .../cutlass/backend/evt/backend/__init__.py | 36 + 
.../backend/evt/backend/emitter_base.py | 158 +++ .../backend/evt/backend/sm80_emitter.py | 47 + .../cutlass/backend/evt/backend/sm80_nodes.py | 258 +++++ .../backend/evt/backend/sm90_emitter.py | 98 ++ .../cutlass/backend/evt/backend/sm90_nodes.py | 351 ++++++ python/cutlass/backend/evt/epilogue.py | 165 +++ .../cutlass/backend/evt/frontend/__init__.py | 33 + .../backend/evt/frontend/frontend_base.py | 262 +++++ .../backend/evt/frontend/python_ast.py | 184 +++ python/cutlass/backend/evt/ir/__init__.py | 53 + .../cutlass/backend/evt/ir/compute_nodes.py | 91 ++ python/cutlass/backend/evt/ir/dag_ir.py | 235 ++++ .../backend/evt/ir/layout_algorithm.py | 324 ++++++ python/cutlass/backend/evt/ir/layout_nodes.py | 336 ++++++ python/cutlass/backend/evt/ir/load_nodes.py | 294 +++++ python/cutlass/backend/evt/ir/node.py | 292 +++++ python/cutlass/backend/evt/ir/store_nodes.py | 276 +++++ python/cutlass/backend/evt/ir/tensor.py | 130 +++ python/cutlass/backend/evt/passes/__init__.py | 42 + .../backend/evt/passes/graph_drawer.py | 158 +++ .../backend/evt/passes/pass_argument_type.py | 116 ++ .../backend/evt/passes/pass_dag_2_tree.py | 147 +++ .../backend/evt/passes/pass_fix_element_d.py | 64 ++ .../backend/evt/passes/pass_get_impl.py | 89 ++ .../evt/passes/pass_layout_elimination.py | 217 ++++ .../backend/evt/passes/pass_manager.py | 163 +++ .../evt/passes/pass_no_op_elimination.py | 53 + .../backend/evt/passes/pass_preprocess_red.py | 98 ++ .../evt/passes/pass_shape_type_propagation.py | 59 + .../evt/passes/smem_size_calculator.py | 200 ++++ python/cutlass/backend/frontend.py | 22 +- python/cutlass/backend/gemm_operation.py | 597 +++++----- python/cutlass/backend/library.py | 529 ++------- python/cutlass/backend/parser.py | 877 -------------- python/cutlass/backend/reduction_operation.py | 97 +- python/cutlass/backend/test/conv2d_testbed.py | 807 ------------- .../backend/test/gemm_grouped_testbed.py | 276 ----- python/cutlass/backend/test/gemm_testbed.py | 765 ------------- python/cutlass/backend/test/profiler.py | 69 -- python/cutlass/backend/utils/__init__.py | 1 - python/cutlass/backend/utils/datatypes.py | 71 +- .../cutlass/backend/utils/reference_model.py | 317 ------ python/cutlass/backend/utils/software.py | 6 +- python/cutlass/cpp/cutlass_bindings.cpp | 182 --- .../cpp/include/conv/conv_problem_size.h | 102 -- python/cutlass/cpp/include/conv/convolution.h | 91 -- python/cutlass/cpp/include/conv/host.h | 54 - .../epilogue/epilogue_visitor_generic.h | 222 ---- .../epilogue/epilogue_visitor_op/unary_ops.h | 233 ---- .../visitor_op_accumulator.h | 148 --- .../epilogue_visitor_op/visitor_op_binary.h | 245 ---- .../visitor_op_column_broadcast.h | 250 ---- .../visitor_op_column_reduction.h | 341 ------ .../visitor_op_linear_combination.h | 266 ----- .../visitor_op_row_broadcast.h | 258 ----- .../visitor_op_row_reduction.h | 319 ------ .../visitor_op_tensor_input.h | 188 --- .../visitor_op_tensor_output.h | 240 ---- .../epilogue_visitor_op/visitor_op_unary.h | 226 ---- .../epilogue_visitor_with_layernorm.h | 480 -------- python/cutlass/cpp/include/gemm/gemm.h | 77 -- .../gemm/gemm_universal_with_visitor.h | 638 ----------- python/cutlass/cpp/include/layout/matrix.h | 87 -- python/cutlass/cpp/include/layout/tensor.h | 74 -- python/cutlass/cpp/include/swizzling.h | 169 --- python/cutlass/cpp/include/tensor_coord.h | 78 -- python/cutlass/cpp/include/tensor_ref_view.h | 102 -- python/cutlass/cpp/include/types.h | 146 --- python/cutlass/cpp/library.h | 32 - python/cutlass/cpp/test/conv/conv_problems.h | 
54 - python/cutlass/cpp/test/conv/convolution.h | 49 - python/cutlass/cpp/test/conv/host.h | 181 --- python/cutlass/cpp/test/gemm/gemm.h | 45 - python/cutlass/cpp/test/gemm/host.h | 431 ------- python/cutlass/emit/common.py | 18 +- python/cutlass/emit/pytorch.py | 96 +- python/cutlass/epilogue/__init__.py | 53 + python/cutlass/{ => epilogue}/epilogue.py | 53 +- .../tensor_ref.py => epilogue/evt_ops.py} | 68 +- python/cutlass/library_defaults.py | 50 +- python/cutlass/op/__init__.py | 2 +- python/cutlass/op/conv.py | 421 ++++--- python/cutlass/op/gemm.py | 137 ++- python/cutlass/op/gemm_grouped.py | 25 +- python/cutlass/op/op.py | 96 +- python/cutlass/profiler/__init__.py | 37 + python/cutlass/profiler/event_profiler.py | 185 +++ python/cutlass/shape.py | 184 +++ python/cutlass/swizzle.py | 23 +- python/cutlass/utils/__init__.py | 2 +- python/cutlass/utils/check.py | 10 +- python/cutlass/utils/datatypes.py | 122 +- python/cutlass_library/__init__.py | 49 + .../cutlass_library}/conv2d_operation.py | 59 +- .../cutlass_library}/conv3d_operation.py | 49 +- .../cutlass_library}/gemm_operation.py | 36 +- .../cutlass_library}/generator.py | 132 ++- .../cutlass_library}/library.py | 87 +- python/cutlass_library/manifest.py | 683 +++++++++++ .../cutlass_library}/rank_2k_operation.py | 65 +- .../cutlass_library}/rank_k_operation.py | 61 +- .../cutlass_library}/symm_operation.py | 65 +- .../cutlass_library}/trmm_operation.py | 71 +- python/pycute/__init__.py | 36 + python/pycute/int_tuple.py | 230 ++++ python/pycute/layout.py | 358 ++++++ python/pycute/swizzle.py | 129 +++ python/pycute/typing.py | 42 + python/setup.py | 110 +- .../setup_library.py | 23 +- python/setup_pycute.py | 46 + test/CMakeLists.txt | 5 +- test/python/backend/conv/__init__.py | 0 ...nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py | 233 ---- ...nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py | 209 ---- ...m_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py | 130 --- ...hwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py | 127 --- ...nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py | 196 ---- ...nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py | 220 ---- ...nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py | 341 ------ ...nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py | 86 -- ...m_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py | 128 --- ...hwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py | 139 --- ...nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py | 285 ----- ...nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py | 129 --- ...nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py | 274 ----- ...m_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py | 128 --- ...hwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py | 139 --- test/python/backend/gemm/__init__.py | 0 test/python/backend/gemm/gemm_bf16_sm80.py | 128 --- test/python/backend/gemm/gemm_bf16_sm90.py | 138 --- test/python/backend/gemm/gemm_f16_sm80.py | 479 -------- test/python/backend/gemm/gemm_f16_sm90.py | 182 --- test/python/backend/gemm/gemm_f32_sm80.py | 178 --- test/python/backend/gemm/gemm_f64_sm80.py | 134 --- test/python/backend/gemm/gemm_f64_sm90.py | 124 -- test/python/backend/gemm/gemm_grouped_sm80.py | 235 ---- test/python/backend/gemm/gemm_s8_sm80.py | 261 ----- test/python/backend/gemm/gemm_s8_sm90.py | 154 --- test/python/conv2d/conv2d_test_utils.py | 508 --------- .../cutlass/conv2d/conv2d_problem_sizes.py | 660 +++++++++++ .../{ => cutlass}/conv2d/conv2d_sm80.py | 62 +- .../cutlass/conv2d/conv2d_test_utils.py | 425 +++++++ .../conv => cutlass/conv2d}/run_all_tests.py | 14 +- test/python/{ => cutlass}/emit/pytorch.py | 43 +- 
.../python/cutlass/evt/evt_compute_sm80_90.py | 100 ++ test/python/cutlass/evt/evt_layout_sm80_90.py | 173 +++ test/python/cutlass/evt/evt_load_sm80_90.py | 142 +++ test/python/cutlass/evt/evt_mixed_sm80_90.py | 274 +++++ test/python/cutlass/evt/evt_store_sm80_90.py | 155 +++ .../{conv2d => cutlass/evt}/run_all_tests.py | 4 +- test/python/cutlass/evt/utils/evt_testbed.py | 230 ++++ .../python/{ => cutlass}/gemm/gemm_batched.py | 13 +- .../{ => cutlass}/gemm/gemm_f16_sm80.py | 6 +- .../{ => cutlass}/gemm/gemm_f16_sm90.py | 6 +- .../{ => cutlass}/gemm/gemm_f32_sm80.py | 6 +- .../{ => cutlass}/gemm/gemm_f64_sm80.py | 6 +- .../{ => cutlass}/gemm/gemm_f64_sm90.py | 6 +- .../python/{ => cutlass}/gemm/gemm_s8_sm80.py | 6 +- .../python/{ => cutlass}/gemm/gemm_s8_sm90.py | 6 +- test/python/cutlass/gemm/gemm_testbed.py | 387 +++++++ .../{ => cutlass}/gemm/run_all_tests.py | 4 +- .../python/cutlass/gemm}/utils.py | 120 +- .../interface/conv2d_interface.py | 1 - .../python/cutlass/interface/evt_interface.py | 245 ++++ .../{ => cutlass}/interface/gemm_interface.py | 5 +- test/python/{ => cutlass}/interface/utils.py | 6 +- test/python/pycute/run_all_tests.py | 75 ++ test/python/pycute/test_coalesce.py | 95 ++ test/python/pycute/test_complement.py | 92 ++ test/python/pycute/test_composition.py | 204 ++++ test/python/pycute/test_int_tuple.py | 80 ++ test/python/pycute/test_left_inverse.py | 87 ++ test/python/pycute/test_right_inverse.py | 96 ++ test/python/pycute/test_typing.py | 59 + ...wx_s4cxrskx_s4ncxhwx_tensor_op_s32_sm75.cu | 2 - ...4nhwc_s4nhwc_s32nhwc_tensor_op_s32_sm75.cu | 1 - ...wx_s8cxrskx_s8ncxhwx_tensor_op_s32_sm75.cu | 1 - ...8nhwc_s8nhwc_s32nhwc_tensor_op_s32_sm75.cu | 1 - test/unit/core/CMakeLists.txt | 20 + test/unit/core/cpp11.cu | 86 ++ test/unit/cute/core/CMakeLists.txt | 2 +- test/unit/cute/core/constant_arithmetic.cpp | 106 -- .../unit/cute/core/constants.cpp | 37 +- test/unit/cute/core/mixedbits.cpp | 90 +- test/unit/cute/hopper/tma_load.cu | 323 +++--- test/unit/cute/hopper/tma_load_testbed.hpp | 199 ++++ test/unit/cute/hopper/tma_store.cu | 153 +-- test/unit/cute/hopper/tma_store_testbed.hpp | 184 +++ .../gemm_b1t_b1n_s32n_tensor_op_s32_sm80.cu | 11 +- .../gemm_b1t_b1n_s32t_tensor_op_s32_sm80.cu | 5 +- .../gemm_s4n_s4t_s4n_tensor_op_s32_sm75.cu | 2 - .../gemm_s4t_s4n_s32n_tensor_op_s32_sm75.cu | 2 - ...mm_s4t_s4n_s32n_wmma_tensor_op_s32_sm75.cu | 2 - .../gemm_s4t_s4n_s32t_tensor_op_s32_sm75.cu | 2 - ...mm_s4t_s4n_s32t_wmma_tensor_op_s32_sm75.cu | 1 - .../gemm_s4t_s4n_s4n_tensor_op_s32_sm75.cu | 1 - .../gemm_s4t_s4n_s4t_tensor_op_s32_sm75.cu | 3 +- .../gemm_s8n_s8t_s8n_tensor_op_s32_sm75.cu | 2 - .../gemm_s8t_s8n_s32n_tensor_op_s32_sm75.cu | 2 - .../gemm_s8t_s8n_s32t_tensor_op_s32_sm75.cu | 2 - .../gemm_s8t_s8n_s8n_tensor_op_s32_sm75.cu | 2 - .../gemm_s8t_s8n_s8t_tensor_op_s32_sm75.cu | 2 - test/unit/gemm/device/gemm_testbed_3x.hpp | 18 +- test/unit/gemm/device/gemm_testbed_3x_evt.hpp | 2 +- test/unit/gemm/device/sm90_evt_operations.hpp | 53 - ...er_warpspecialized_cooperative_aux_load.cu | 12 +- ...cluster_warpspecialized_cooperative_dag.cu | 8 +- ...rpspecialized_cooperative_row_broadcast.cu | 4 +- ...uster_warpspecialized_pingpong_aux_load.cu | 12 +- ...32_cluster_warpspecialized_pingpong_dag.cu | 8 +- ..._warpspecialized_pingpong_row_broadcast.cu | 4 +- ...sm90_gemm_f8_f8_bf16_tensor_op_fp32_evt.cu | 4 +- ...cluster_warpspecialized_cooperative_evt.cu | 4 +- .../sm90_gemm_f8_f8_f8_tensor_op_fp32_evt.cu | 4 +- .../gemm/threadblock/mma_pipelined_sm75.cu | 2 +- 
.../threadblock/mma_pipelined_wmma_sm75.cu | 2 +- .../threadblock/mma_singlestage_wmma_sm75.cu | 2 +- test/unit/gemm/warp/gemm_sm75.cu | 2 +- test/unit/pipeline/pipeline_async.cu | 6 - ...e_tma_async_warp_specialized_persistent.cu | 1 - tools/library/CMakeLists.txt | 247 ++-- .../library/include/cutlass/library/library.h | 1 + tools/library/include/cutlass/library/types.h | 7 + tools/library/include/cutlass/library/util.h | 7 + tools/library/scripts/__init__.py | 0 tools/library/scripts/manifest.py | 476 -------- tools/library/scripts/rt.py | 796 ------------- tools/library/src/gemm_operation_3x.hpp | 14 + tools/library/src/util.cu | 44 + tools/profiler/CMakeLists.txt | 18 +- .../profiler}/conv2d_operation_profiler.h | 2 + .../profiler}/conv3d_operation_profiler.h | 2 + .../cutlass/profiler}/cublas_helpers.h | 0 .../cutlass/profiler}/cudnn_helpers.h | 0 .../cutlass/profiler}/cutlass_profiler.h | 0 .../{src => include/cutlass/profiler}/debug.h | 0 .../cutlass/profiler}/device_allocation.h | 0 .../cutlass/profiler}/device_context.h | 0 .../cutlass/profiler}/enumerated_types.h | 0 .../profiler}/gemm_operation_profiler.h | 6 +- .../cutlass/profiler}/gpu_timer.h | 0 .../cutlass/profiler}/operation_profiler.h | 0 .../cutlass/profiler}/options.h | 4 + .../cutlass/profiler}/performance_report.h | 0 .../cutlass/profiler}/performance_result.h | 0 .../cutlass/profiler}/problem_space.h | 9 + .../profiler}/rank_2k_operation_profiler.h | 0 .../profiler}/rank_k_operation_profiler.h | 0 .../profiler}/reduction_operation_profiler.h | 0 .../sparse_gemm_operation_profiler.h | 0 .../profiler}/symm_operation_profiler.h | 0 .../profiler}/trmm_operation_profiler.h | 0 .../profiler/src/conv2d_operation_profiler.cu | 5 +- .../profiler/src/conv3d_operation_profiler.cu | 5 +- tools/profiler/src/cublas_helpers.cu | 2 +- tools/profiler/src/cudnn_helpers.cpp | 2 +- tools/profiler/src/cutlass_profiler.cu | 18 +- tools/profiler/src/device_allocation.cu | 2 +- tools/profiler/src/device_context.cu | 2 +- tools/profiler/src/enumerated_types.cpp | 2 +- tools/profiler/src/gemm_operation_profiler.cu | 22 +- tools/profiler/src/gpu_timer.cpp | 2 +- tools/profiler/src/main.cpp | 4 +- tools/profiler/src/operation_profiler.cu | 18 +- tools/profiler/src/options.cu | 15 +- tools/profiler/src/performance_report.cpp | 5 +- tools/profiler/src/performance_result.cu | 4 +- tools/profiler/src/problem_space.cpp | 42 +- .../src/rank_2k_operation_profiler.cu | 6 +- .../profiler/src/rank_k_operation_profiler.cu | 6 +- .../src/sparse_gemm_operation_profiler.cu | 6 +- tools/profiler/src/symm_operation_profiler.cu | 6 +- tools/profiler/src/trmm_operation_profiler.cu | 6 +- .../util/include/cutlass/util/print_error.hpp | 7 +- .../cutlass/util/reference/host/gett.hpp | 4 +- 428 files changed, 22241 insertions(+), 21750 deletions(-) create mode 100644 examples/python/04_epilogue_visitor.ipynb create mode 100644 include/cute/numeric/integral_ratio.hpp create mode 100644 include/cutlass/epilogue/threadblock/epilogue_with_visitor_callbacks.h create mode 100644 include/cutlass/epilogue/threadblock/fusion/visitor_2x.hpp create mode 100644 include/cutlass/epilogue/threadblock/fusion/visitor_compute.hpp create mode 100644 include/cutlass/epilogue/threadblock/fusion/visitor_load.hpp create mode 100644 include/cutlass/epilogue/threadblock/fusion/visitor_store.hpp rename python/cutlass/cpp/include/layout/layout.h => include/cutlass/epilogue/threadblock/fusion/visitors.hpp (81%) rename python/cutlass/cpp/include/arch.h => 
include/cutlass/gemm/gemm_enumerated_types.h (52%) create mode 100644 include/cutlass/gemm/kernel/default_gemm_universal_with_visitor.h create mode 100644 include/cutlass/gemm/kernel/gemm_universal_with_visitor.h create mode 100644 include/cutlass/gemm/kernel/gemm_universal_with_visitor_streamk.h create mode 100644 include/cutlass/gemm/kernel/tile_scheduler_params.h rename python/cutlass/cpp/include/epilogue/epilogue_visitor_op/binary_ops.h => include/cutlass/gemm_coord.hpp (69%) rename python/cutlass/cpp/compiler.h => include/cutlass/kernel_hardware_info.h (57%) create mode 100644 include/cutlass/numeric_size.h rename include/cutlass/{workspace.hpp => workspace.h} (87%) create mode 100644 media/docs/build/building_in_windows_with_visual_studio.md create mode 100644 media/docs/build/building_with_clang_as_host_compiler.md rename python/cutlass/backend/{test => evt}/__init__.py (85%) create mode 100644 python/cutlass/backend/evt/backend/__init__.py create mode 100644 python/cutlass/backend/evt/backend/emitter_base.py create mode 100644 python/cutlass/backend/evt/backend/sm80_emitter.py create mode 100644 python/cutlass/backend/evt/backend/sm80_nodes.py create mode 100644 python/cutlass/backend/evt/backend/sm90_emitter.py create mode 100644 python/cutlass/backend/evt/backend/sm90_nodes.py create mode 100644 python/cutlass/backend/evt/epilogue.py create mode 100644 python/cutlass/backend/evt/frontend/__init__.py create mode 100644 python/cutlass/backend/evt/frontend/frontend_base.py create mode 100644 python/cutlass/backend/evt/frontend/python_ast.py create mode 100644 python/cutlass/backend/evt/ir/__init__.py create mode 100644 python/cutlass/backend/evt/ir/compute_nodes.py create mode 100644 python/cutlass/backend/evt/ir/dag_ir.py create mode 100644 python/cutlass/backend/evt/ir/layout_algorithm.py create mode 100644 python/cutlass/backend/evt/ir/layout_nodes.py create mode 100644 python/cutlass/backend/evt/ir/load_nodes.py create mode 100644 python/cutlass/backend/evt/ir/node.py create mode 100644 python/cutlass/backend/evt/ir/store_nodes.py create mode 100644 python/cutlass/backend/evt/ir/tensor.py create mode 100644 python/cutlass/backend/evt/passes/__init__.py create mode 100644 python/cutlass/backend/evt/passes/graph_drawer.py create mode 100644 python/cutlass/backend/evt/passes/pass_argument_type.py create mode 100644 python/cutlass/backend/evt/passes/pass_dag_2_tree.py create mode 100644 python/cutlass/backend/evt/passes/pass_fix_element_d.py create mode 100644 python/cutlass/backend/evt/passes/pass_get_impl.py create mode 100644 python/cutlass/backend/evt/passes/pass_layout_elimination.py create mode 100644 python/cutlass/backend/evt/passes/pass_manager.py create mode 100644 python/cutlass/backend/evt/passes/pass_no_op_elimination.py create mode 100644 python/cutlass/backend/evt/passes/pass_preprocess_red.py create mode 100644 python/cutlass/backend/evt/passes/pass_shape_type_propagation.py create mode 100644 python/cutlass/backend/evt/passes/smem_size_calculator.py delete mode 100644 python/cutlass/backend/parser.py delete mode 100644 python/cutlass/backend/test/conv2d_testbed.py delete mode 100644 python/cutlass/backend/test/gemm_grouped_testbed.py delete mode 100644 python/cutlass/backend/test/gemm_testbed.py delete mode 100644 python/cutlass/backend/test/profiler.py delete mode 100644 python/cutlass/backend/utils/reference_model.py delete mode 100644 python/cutlass/cpp/cutlass_bindings.cpp delete mode 100644 python/cutlass/cpp/include/conv/conv_problem_size.h delete mode 100644 
python/cutlass/cpp/include/conv/convolution.h delete mode 100644 python/cutlass/cpp/include/conv/host.h delete mode 100644 python/cutlass/cpp/include/epilogue/epilogue_visitor_generic.h delete mode 100644 python/cutlass/cpp/include/epilogue/epilogue_visitor_op/unary_ops.h delete mode 100644 python/cutlass/cpp/include/epilogue/epilogue_visitor_op/visitor_op_accumulator.h delete mode 100644 python/cutlass/cpp/include/epilogue/epilogue_visitor_op/visitor_op_binary.h delete mode 100644 python/cutlass/cpp/include/epilogue/epilogue_visitor_op/visitor_op_column_broadcast.h delete mode 100644 python/cutlass/cpp/include/epilogue/epilogue_visitor_op/visitor_op_column_reduction.h delete mode 100644 python/cutlass/cpp/include/epilogue/epilogue_visitor_op/visitor_op_linear_combination.h delete mode 100644 python/cutlass/cpp/include/epilogue/epilogue_visitor_op/visitor_op_row_broadcast.h delete mode 100644 python/cutlass/cpp/include/epilogue/epilogue_visitor_op/visitor_op_row_reduction.h delete mode 100644 python/cutlass/cpp/include/epilogue/epilogue_visitor_op/visitor_op_tensor_input.h delete mode 100644 python/cutlass/cpp/include/epilogue/epilogue_visitor_op/visitor_op_tensor_output.h delete mode 100644 python/cutlass/cpp/include/epilogue/epilogue_visitor_op/visitor_op_unary.h delete mode 100644 python/cutlass/cpp/include/epilogue/epilogue_visitor_with_layernorm.h delete mode 100644 python/cutlass/cpp/include/gemm/gemm.h delete mode 100644 python/cutlass/cpp/include/gemm/gemm_universal_with_visitor.h delete mode 100644 python/cutlass/cpp/include/layout/matrix.h delete mode 100644 python/cutlass/cpp/include/layout/tensor.h delete mode 100644 python/cutlass/cpp/include/swizzling.h delete mode 100644 python/cutlass/cpp/include/tensor_coord.h delete mode 100644 python/cutlass/cpp/include/tensor_ref_view.h delete mode 100644 python/cutlass/cpp/include/types.h delete mode 100644 python/cutlass/cpp/library.h delete mode 100644 python/cutlass/cpp/test/conv/conv_problems.h delete mode 100644 python/cutlass/cpp/test/conv/convolution.h delete mode 100644 python/cutlass/cpp/test/conv/host.h delete mode 100644 python/cutlass/cpp/test/gemm/gemm.h delete mode 100644 python/cutlass/cpp/test/gemm/host.h create mode 100644 python/cutlass/epilogue/__init__.py rename python/cutlass/{ => epilogue}/epilogue.py (69%) rename python/cutlass/{backend/tensor_ref.py => epilogue/evt_ops.py} (60%) create mode 100644 python/cutlass/profiler/__init__.py create mode 100644 python/cutlass/profiler/event_profiler.py create mode 100644 python/cutlass/shape.py create mode 100644 python/cutlass_library/__init__.py rename {tools/library/scripts => python/cutlass_library}/conv2d_operation.py (89%) rename {tools/library/scripts => python/cutlass_library}/conv3d_operation.py (85%) rename {tools/library/scripts => python/cutlass_library}/gemm_operation.py (96%) rename {tools/library/scripts => python/cutlass_library}/generator.py (98%) rename {tools/library/scripts => python/cutlass_library}/library.py (92%) create mode 100644 python/cutlass_library/manifest.py rename {tools/library/scripts => python/cutlass_library}/rank_2k_operation.py (85%) rename {tools/library/scripts => python/cutlass_library}/rank_k_operation.py (85%) rename {tools/library/scripts => python/cutlass_library}/symm_operation.py (85%) rename {tools/library/scripts => python/cutlass_library}/trmm_operation.py (85%) create mode 100644 python/pycute/__init__.py create mode 100644 python/pycute/int_tuple.py create mode 100644 python/pycute/layout.py create mode 100644 
python/pycute/swizzle.py create mode 100644 python/pycute/typing.py rename test/python/backend/gemm/run_all_tests.py => python/setup_library.py (84%) create mode 100644 python/setup_pycute.py delete mode 100644 test/python/backend/conv/__init__.py delete mode 100644 test/python/backend/conv/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py delete mode 100644 test/python/backend/conv/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py delete mode 100644 test/python/backend/conv/conv2d_dgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py delete mode 100644 test/python/backend/conv/conv2d_dgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py delete mode 100644 test/python/backend/conv/conv2d_fprop_few_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py delete mode 100644 test/python/backend/conv/conv2d_fprop_fixed_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py delete mode 100644 test/python/backend/conv/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py delete mode 100644 test/python/backend/conv/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py delete mode 100644 test/python/backend/conv/conv2d_fprop_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py delete mode 100644 test/python/backend/conv/conv2d_fprop_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py delete mode 100644 test/python/backend/conv/conv2d_strided_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py delete mode 100644 test/python/backend/conv/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py delete mode 100644 test/python/backend/conv/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py delete mode 100644 test/python/backend/conv/conv2d_wgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py delete mode 100644 test/python/backend/conv/conv2d_wgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py delete mode 100644 test/python/backend/gemm/__init__.py delete mode 100644 test/python/backend/gemm/gemm_bf16_sm80.py delete mode 100644 test/python/backend/gemm/gemm_bf16_sm90.py delete mode 100644 test/python/backend/gemm/gemm_f16_sm80.py delete mode 100644 test/python/backend/gemm/gemm_f16_sm90.py delete mode 100644 test/python/backend/gemm/gemm_f32_sm80.py delete mode 100644 test/python/backend/gemm/gemm_f64_sm80.py delete mode 100644 test/python/backend/gemm/gemm_f64_sm90.py delete mode 100644 test/python/backend/gemm/gemm_grouped_sm80.py delete mode 100644 test/python/backend/gemm/gemm_s8_sm80.py delete mode 100644 test/python/backend/gemm/gemm_s8_sm90.py delete mode 100644 test/python/conv2d/conv2d_test_utils.py create mode 100644 test/python/cutlass/conv2d/conv2d_problem_sizes.py rename test/python/{ => cutlass}/conv2d/conv2d_sm80.py (79%) create mode 100644 test/python/cutlass/conv2d/conv2d_test_utils.py rename test/python/{backend/conv => cutlass/conv2d}/run_all_tests.py (85%) rename test/python/{ => cutlass}/emit/pytorch.py (89%) create mode 100644 test/python/cutlass/evt/evt_compute_sm80_90.py create mode 100644 test/python/cutlass/evt/evt_layout_sm80_90.py create mode 100644 test/python/cutlass/evt/evt_load_sm80_90.py create mode 100644 test/python/cutlass/evt/evt_mixed_sm80_90.py create mode 100644 test/python/cutlass/evt/evt_store_sm80_90.py rename test/python/{conv2d => cutlass/evt}/run_all_tests.py (93%) create mode 100644 test/python/cutlass/evt/utils/evt_testbed.py rename test/python/{ => 
cutlass}/gemm/gemm_batched.py (95%) rename test/python/{ => cutlass}/gemm/gemm_f16_sm80.py (99%) rename test/python/{ => cutlass}/gemm/gemm_f16_sm90.py (99%) rename test/python/{ => cutlass}/gemm/gemm_f32_sm80.py (98%) rename test/python/{ => cutlass}/gemm/gemm_f64_sm80.py (98%) rename test/python/{ => cutlass}/gemm/gemm_f64_sm90.py (97%) rename test/python/{ => cutlass}/gemm/gemm_s8_sm80.py (98%) rename test/python/{ => cutlass}/gemm/gemm_s8_sm90.py (98%) create mode 100644 test/python/cutlass/gemm/gemm_testbed.py rename test/python/{ => cutlass}/gemm/run_all_tests.py (93%) rename {python/cutlass/backend/test => test/python/cutlass/gemm}/utils.py (70%) rename test/python/{ => cutlass}/interface/conv2d_interface.py (99%) create mode 100644 test/python/cutlass/interface/evt_interface.py rename test/python/{ => cutlass}/interface/gemm_interface.py (98%) rename test/python/{ => cutlass}/interface/utils.py (91%) create mode 100644 test/python/pycute/run_all_tests.py create mode 100644 test/python/pycute/test_coalesce.py create mode 100644 test/python/pycute/test_complement.py create mode 100644 test/python/pycute/test_composition.py create mode 100644 test/python/pycute/test_int_tuple.py create mode 100644 test/python/pycute/test_left_inverse.py create mode 100644 test/python/pycute/test_right_inverse.py create mode 100644 test/python/pycute/test_typing.py create mode 100644 test/unit/core/cpp11.cu delete mode 100644 test/unit/cute/core/constant_arithmetic.cpp rename python/cutlass/cpp/include/gemm/host.h => test/unit/cute/core/constants.cpp (57%) create mode 100644 test/unit/cute/hopper/tma_load_testbed.hpp create mode 100644 test/unit/cute/hopper/tma_store_testbed.hpp delete mode 100644 tools/library/scripts/__init__.py delete mode 100644 tools/library/scripts/manifest.py delete mode 100644 tools/library/scripts/rt.py rename tools/profiler/{src => include/cutlass/profiler}/conv2d_operation_profiler.h (99%) rename tools/profiler/{src => include/cutlass/profiler}/conv3d_operation_profiler.h (99%) rename tools/profiler/{src => include/cutlass/profiler}/cublas_helpers.h (100%) rename tools/profiler/{src => include/cutlass/profiler}/cudnn_helpers.h (100%) rename tools/profiler/{src => include/cutlass/profiler}/cutlass_profiler.h (100%) rename tools/profiler/{src => include/cutlass/profiler}/debug.h (100%) rename tools/profiler/{src => include/cutlass/profiler}/device_allocation.h (100%) rename tools/profiler/{src => include/cutlass/profiler}/device_context.h (100%) rename tools/profiler/{src => include/cutlass/profiler}/enumerated_types.h (100%) rename tools/profiler/{src => include/cutlass/profiler}/gemm_operation_profiler.h (97%) rename tools/profiler/{src => include/cutlass/profiler}/gpu_timer.h (100%) rename tools/profiler/{src => include/cutlass/profiler}/operation_profiler.h (100%) rename tools/profiler/{src => include/cutlass/profiler}/options.h (97%) rename tools/profiler/{src => include/cutlass/profiler}/performance_report.h (100%) rename tools/profiler/{src => include/cutlass/profiler}/performance_result.h (100%) rename tools/profiler/{src => include/cutlass/profiler}/problem_space.h (98%) rename tools/profiler/{src => include/cutlass/profiler}/rank_2k_operation_profiler.h (100%) rename tools/profiler/{src => include/cutlass/profiler}/rank_k_operation_profiler.h (100%) rename tools/profiler/{src => include/cutlass/profiler}/reduction_operation_profiler.h (100%) rename tools/profiler/{src => include/cutlass/profiler}/sparse_gemm_operation_profiler.h (100%) rename tools/profiler/{src => 
include/cutlass/profiler}/symm_operation_profiler.h (100%) rename tools/profiler/{src => include/cutlass/profiler}/trmm_operation_profiler.h (100%) diff --git a/CHANGELOG.md b/CHANGELOG.md index 039fc805da..7bb701ed3e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,15 @@ # NVIDIA CUTLASS Changelog +## [3.2.1](https://github.com/NVIDIA/cutlass/releases/tag/v3.2.1) (2023-09-22) +* Python support for the SM90 Epilogue Visitor Tree (EVT), on top of the C++ support released in 3.2.0. +* SM80 EVT support in C++ and Python. +* Other SM90 epilogue improvements. +* Splitting the CUTLASS library into smaller units based on operation, arch, and data types. See [1105](https://github.com/NVIDIA/cutlass/discussions/1105) for details. +* Making `tools/library/scripts` packageable: `tools/library/scripts` has moved to `python/cutlass_library`. See the Python [README](/python/README.md) for details. +* SM90 TF32 kernel improvements for all layouts. +* SM90 rasterization direction support in the CUTLASS profiler. +* Improvements to CUTLASS profiler build times. +* Removal of the Python-C++ bindings. ## [3.2.0](https://github.com/NVIDIA/cutlass/releases/tag/v3.2.0) (2023-08-03) @@ -91,7 +101,7 @@ * [Few channels](/include/cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_few_channels.h) specialization for reduced alignment capabilities * [Fixed channels](/include/cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_fixed_channels.h) further specialized when channel count perfectly matches the access vector size * [Unit tests](/test/unit/conv/device/conv2d_fprop_few_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.cu) - * [Python-based instance emitter](/tools/library/scripts/generator.py) in the CUTLASS Library and support in the Profiler + * [Python-based instance emitter](/python/cutlass_library/generator.py) in the CUTLASS Library and support in the Profiler * [BLAS3](https://docs.nvidia.com/cuda/cublas/index.html#cublas-level-3-function-reference) operators accelerated by Tensor Cores * Supported types: f32, cf32, f64, cf64, tf32x3, complex tf32x3 * [HERK](/test/unit/gemm/device/her2k_cf32h_cf32n_tensor_op_fast_f32_sm80.cu) with [emitter](/tools/library/scripts/rank_k_operation.py) diff --git a/CMakeLists.txt b/CMakeLists.txt index 2d4f9cc3a3..b880de0a52 100755 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -40,7 +40,7 @@ endif() message(STATUS "CMake Version: ${CMAKE_VERSION}") set(IMPLICIT_CMAKE_CXX_STANDARD OFF CACHE BOOL "Do not explicitly specify -std=c++11 if set") -project(CUTLASS VERSION 3.2.0 LANGUAGES CXX) +project(CUTLASS VERSION 3.2.1 LANGUAGES CXX) include(${CMAKE_CURRENT_SOURCE_DIR}/CUDA.cmake) if (CUDA_VERSION VERSION_LESS 11.3) @@ -85,6 +85,21 @@ message(STATUS "Default Install Location: ${CMAKE_INSTALL_PREFIX}") set(CUTLASS_TEST_LEVEL "0" CACHE STRING "Level of tests to compile.") # 0 - Sanity, 1 - Release-Quality, 2 - Exhaustive +find_package(Python3 3.5 COMPONENTS Interpreter REQUIRED) + +# Install cutlass_library Python package +execute_process( + WORKING_DIRECTORY ${CUTLASS_DIR}/python + COMMAND ${Python3_EXECUTABLE} ${CUTLASS_DIR}/python/setup_library.py develop --user + RESULT_VARIABLE cutlass_lib_GENERATOR_INSTALL_RESULT + OUTPUT_FILE ${CMAKE_CURRENT_BINARY_DIR}/cutlass_library_installation.log + ERROR_FILE ${CMAKE_CURRENT_BINARY_DIR}/cutlass_library_installation.log +) + +if(NOT cutlass_lib_GENERATOR_INSTALL_RESULT EQUAL 0) + message(FATAL_ERROR "Error installing cutlass_library package. 
See ${CMAKE_CURRENT_BINARY_DIR}/cutlass_library_installation.log") +endif() + ################################################################################ set(CUTLASS_ENABLE_HEADERS_ONLY OFF CACHE BOOL "Enable only the header library") @@ -92,10 +107,16 @@ if(CUTLASS_ENABLE_HEADERS_ONLY) set(CUTLASS_ENABLE_EXAMPLES_INIT OFF) set(CUTLASS_ENABLE_TOOLS_INIT ON) set(CUTLASS_ENABLE_LIBRARY_INIT OFF) + set(CUTLASS_ENABLE_TESTS_INIT OFF) else() set(CUTLASS_ENABLE_EXAMPLES_INIT ON) set(CUTLASS_ENABLE_TOOLS_INIT ON) set(CUTLASS_ENABLE_LIBRARY_INIT ON) + if(${CMAKE_PROJECT_NAME} STREQUAL ${PROJECT_NAME}) + set(CUTLASS_ENABLE_TESTS_INIT ON) + else() + set(CUTLASS_ENABLE_TESTS_INIT OFF) + endif() endif() set(CUTLASS_TEST_UNIT_ENABLE_WARNINGS OFF CACHE BOOL "Enable warnings on waived unit tests.") @@ -104,20 +125,10 @@ set(CUTLASS_ENABLE_EXAMPLES ${CUTLASS_ENABLE_EXAMPLES_INIT} CACHE BOOL "Enable C set(CUTLASS_ENABLE_TOOLS ${CUTLASS_ENABLE_TOOLS_INIT} CACHE BOOL "Enable CUTLASS Tools") set(CUTLASS_ENABLE_LIBRARY ${CUTLASS_ENABLE_LIBRARY_INIT} CACHE BOOL "Enable CUTLASS Library") set(CUTLASS_ENABLE_PROFILER ${CUTLASS_ENABLE_LIBRARY} CACHE BOOL "Enable CUTLASS Profiler") -set(CUTLASS_ENABLE_PERFORMANCE ${CUTLASS_ENABLE_PROFILER} CACHE BOOL "Enable CUTLASS Proformance") - -if(${CMAKE_PROJECT_NAME} STREQUAL ${PROJECT_NAME}) - set(CUTLASS_ENABLE_TESTS_INIT ${CUTLASS_ENABLE_LIBRARY}) -else() - set(CUTLASS_ENABLE_TESTS_INIT OFF) -endif() +set(CUTLASS_ENABLE_PERFORMANCE ${CUTLASS_ENABLE_PROFILER} CACHE BOOL "Enable CUTLASS Performance") set(CUTLASS_ENABLE_TESTS ${CUTLASS_ENABLE_TESTS_INIT} CACHE BOOL "Enable CUTLASS Tests") - -if (CUTLASS_ENABLE_TESTS) - include(${CMAKE_CURRENT_SOURCE_DIR}/cmake/googletest.cmake) -endif() - +set(CUTLASS_ENABLE_GTEST_UNIT_TESTS ${CUTLASS_ENABLE_TESTS} CACHE BOOL "Enable CUTLASS GTest-based Unit Tests") ################################################################################ set(CUTLASS_NVCC_ARCHS_SUPPORTED "") @@ -285,6 +296,8 @@ if (CUTLASS_ENABLE_TENSOR_CORE_MMA) endif() + + if (NOT MSVC AND CUTLASS_NVCC_KEEP) # MSVC flow handles caching already, but for other generators we handle it here. set(CUTLASS_NVCC_KEEP_DIR ${CMAKE_CURRENT_BINARY_DIR}/tmp CACHE PATH "Location to store NVCC scratch files") @@ -395,6 +408,7 @@ endif() # Some tests require this build option in order to link. 
if (MSVC) set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /bigobj") + set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -Xcompiler /bigobj") endif() function(cutlass_apply_cuda_gencode_flags TARGET) @@ -572,11 +586,17 @@ target_include_directories( $ $ $ - $ $ $ ) +# Mark CTK headers as system to supress warnings from them +target_include_directories( + CUTLASS + SYSTEM INTERFACE + $ + ) + install( DIRECTORY ${CUTLASS_INCLUDE_DIR}/ @@ -633,6 +653,11 @@ endif() include(CTest) enable_testing() + +if (CUTLASS_ENABLE_GTEST_UNIT_TESTS) + include(${CMAKE_CURRENT_SOURCE_DIR}/cmake/googletest.cmake) +endif() + if (NOT TARGET test_all) add_custom_target(test_all) endif() @@ -818,7 +843,7 @@ function(cutlass_add_executable_tests NAME TARGET) set(CUTLASS_CTEST_GENERATED_FILES ${CUTLASS_CTEST_GENERATED_FILES};ctest/${TEST_NAME}/CTestTestfile.${TEST_NAME}.cmake CACHE INTERNAL "") - if (CUTLASS_INSTALL_TESTS) + if (CUTLASS_INSTALL_TESTS) file(GENERATE OUTPUT "${TEST_GEN_DIR}/CTestTestfile.${TEST_NAME}.install.cmake" @@ -831,7 +856,7 @@ function(cutlass_add_executable_tests NAME TARGET) RENAME CTestTestfile.${TEST_NAME}.cmake ) - endif() + endif() endfunction() @@ -849,7 +874,9 @@ endif() if (CUTLASS_ENABLE_TESTS) add_subdirectory(test) + if (CUTLASS_ENABLE_GTEST_UNIT_TESTS) add_dependencies(test_all test_unit) + endif() endif() if (CUTLASS_INSTALL_TESTS) diff --git a/CUDA.cmake b/CUDA.cmake index 32bd8a58b4..b9c60bcd0b 100644 --- a/CUDA.cmake +++ b/CUDA.cmake @@ -305,10 +305,10 @@ function(cutlass_add_library NAME) if(CUTLASS_NATIVE_CUDA OR CUDA_COMPILER MATCHES "clang") cutlass_correct_source_file_language_property(${TARGET_SOURCE_ARGS}) - add_library(${NAME} ${TARGET_SOURCE_ARGS}) + add_library(${NAME} ${TARGET_SOURCE_ARGS} "") else() set(CUDA_LINK_LIBRARIES_KEYWORD PRIVATE) - cuda_add_library(${NAME} ${TARGET_SOURCE_ARGS}) + cuda_add_library(${NAME} ${TARGET_SOURCE_ARGS} "") endif() cutlass_apply_standard_compile_options(${NAME}) diff --git a/README.md b/README.md index 7ed86c117f..2d09925798 100644 --- a/README.md +++ b/README.md @@ -43,7 +43,7 @@ In addition to GEMMs, CUTLASS implements high-performance convolution via the im # What's New in CUTLASS 3.2 -CUTLASS 3.2 is an update to CUTLASS adding: +CUTLASS 3.2.0 is an update to CUTLASS adding: - New warp-specialized persistent FP8 GEMM kernel [kernel schedules](/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_cooperative.hpp) and [mainloops](/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8.hpp) targeting Hopper architecture that achieve great performance with TMA, WGMMA, and threadblock clusters. An example showcasing [Hopper warp-specialized FP8 GEMMs](/examples/54_hopper_fp8_warp_specialized_gemm). - New [Epilogue Visitor Tree (EVT)](/examples/49_hopper_gemm_with_collective_builder/49_collective_builder.cu) support for Hopper TMA epilogues. EVTs allows for user-defined customized epilogue fusion patterns without having to write a new epilogue. - [Stream-K](/include/cutlass/gemm/kernel/sm90_tile_scheduler_stream_k.hpp) feature for Hopper. Note that this is only a functional implementation of stream-K, and should not be used for performance comparison. Optimizations are expected in a future release. @@ -53,6 +53,14 @@ CUTLASS 3.2 is an update to CUTLASS adding: - New CUTLASS 2D Convolution Python interface. New [example](/examples/python/03_basic_conv2d.ipynb) here. - Support for Windows (MSVC) builds. 
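The 3.2.1 notes that follow announce Python support for Epilogue Visitor Trees (EVT). As a rough sketch of that workflow, modeled loosely on the `examples/python/04_epilogue_visitor.ipynb` notebook added by this patch (tensor shapes, dtypes, and the exact argument names here are assumptions rather than part of the patch):

```python
import torch
import cutlass

# GEMM plan whose default epilogue will be replaced by a traced visitor tree.
plan = cutlass.op.Gemm(element=torch.float16,
                       layout=cutlass.LayoutType.RowMajor,
                       element_accumulator=torch.float32)

# The epilogue is written as an ordinary Python function over named operands;
# the EVT frontend traces it into a DAG of load/compute/store nodes.
def example_epilogue(accum, alpha, C, beta, bias):
    D = alpha * accum + beta * C + bias  # linear combination fused with a bias add
    return D

# Representative tensors/scalars (shapes and dtypes illustrative) so the tracer
# can propagate shape and type information through the graph.
example_tensors = {
    "accum": torch.empty(512, 256, dtype=torch.float32, device="cuda"),
    "alpha": 1.0,
    "C":     torch.empty(512, 256, dtype=torch.float16, device="cuda"),
    "beta":  0.5,
    "bias":  torch.empty(512, 1, dtype=torch.float16, device="cuda"),
    "D":     torch.empty(512, 256, dtype=torch.float16, device="cuda"),
}

plan.epilogue_visitor = cutlass.epilogue.trace(example_epilogue, example_tensors)
# plan.run(A, B, C, D, visitor_args={...}) would then launch the fused kernel.
```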
+CUTLASS 3.2.1 is an update to CUTLASS adding: +- Python support for the SM90 Epilogue Visitor Tree (EVT), on top of the C++ support released in 3.2.0. +- SM80 EVT support in C++ and Python. +- Splitting the CUTLASS library into smaller units based on operation, arch, and data types. See [1105](https://github.com/NVIDIA/cutlass/discussions/1105) for details. +- Making `tools/library/scripts` packageable: `tools/library/scripts` has moved to `python/cutlass_library`. See the Python [README](/python/README.md) for details. +- SM90 TF32 kernel improvements for all layouts. +- SM90 rasterization direction support in the CUTLASS profiler. +- Improvements to CUTLASS profiler build times. Minimum requirements: @@ -176,7 +184,8 @@ CUTLASS is a header-only template library and does not need to be built to be used by other projects. Client applications should target CUTLASS's `include/` directory in their include paths. -CUTLASS unit tests, examples, and utilities can be build with CMake starting version 3.12. +CUTLASS unit tests, examples, and utilities can be built with CMake. +The minimum version of CMake is given in the [Quickstart guide](media/docs/quickstart.md). Make sure the `CUDACXX` environment variable points to NVCC in the CUDA Toolkit installed on your system. @@ -512,7 +521,7 @@ reference_device: Passed ## More Details on Compiling CUTLASS Kernels and CUTLASS Profiler - Please follow the links for more CMake examples on selectively compiling CUTLASS kernels: - [GEMM CMake Examples](media/docs/quickstart.md#gemm-cmake-examples) - - [Implicit GEMM conovlution CMake Examples](media/docs/quickstart.md#convolution-cmake-examples) + - [Implicit GEMM convolution CMake Examples](media/docs/quickstart.md#convolution-cmake-examples) - [Further details about the CUTLASS Profiler are described here.](media/docs/profiler.md) diff --git a/cmake/NvidiaCutlassConfig.cmake b/cmake/NvidiaCutlassConfig.cmake index 701ecb4af4..56d1c45076 100644 --- a/cmake/NvidiaCutlassConfig.cmake +++ b/cmake/NvidiaCutlassConfig.cmake @@ -2,6 +2,11 @@ get_filename_component(NvidiaCutlass_CMAKE_DIR "${CMAKE_CURRENT_LIST_FILE}" PATH include(CMakeFindDependencyMacro) -if(NOT TARGET nvidia::cutlass::CUTLASS) - include("${NvidiaCutlass_CMAKE_DIR}/NvidiaCutlassTargets.cmake") +if(TARGET nvidia::cutlass::CUTLASS) + return() endif() + +include("${NvidiaCutlass_CMAKE_DIR}/NvidiaCutlassTargets.cmake") + +# For backward compatibility with the old name +add_library(cutlass_lib ALIAS cutlass_library) diff --git a/examples/08_turing_tensorop_gemm/CMakeLists.txt b/examples/08_turing_tensorop_gemm/CMakeLists.txt index e9d659e192..a240bcc97f 100644 --- a/examples/08_turing_tensorop_gemm/CMakeLists.txt +++ b/examples/08_turing_tensorop_gemm/CMakeLists.txt @@ -31,6 +31,5 @@ cutlass_example_add_executable( 08_turing_tensorop_gemm turing_tensorop_gemm.cu - DISABLE_TESTS ON ) diff --git a/examples/08_turing_tensorop_gemm/turing_tensorop_gemm.cu b/examples/08_turing_tensorop_gemm/turing_tensorop_gemm.cu index f627b842a5..c5498adf33 100644 --- a/examples/08_turing_tensorop_gemm/turing_tensorop_gemm.cu +++ b/examples/08_turing_tensorop_gemm/turing_tensorop_gemm.cu @@ -291,8 +291,8 @@ int run() { LayoutInputB, ElementOutput, LayoutOutput, - ElementComputeEpilogue, - ElementComputeEpilogue> + int32_t, + int32_t> gemm_device; // Launch device reference gemm kernel @@ -355,4 +355,3 @@ int main() { return run(); } - diff --git a/examples/09_turing_tensorop_conv2dfprop/turing_tensorop_conv2dfprop.cu b/examples/09_turing_tensorop_conv2dfprop/turing_tensorop_conv2dfprop.cu 
index ade0b97947..6f234410c2 100644 --- a/examples/09_turing_tensorop_conv2dfprop/turing_tensorop_conv2dfprop.cu +++ b/examples/09_turing_tensorop_conv2dfprop/turing_tensorop_conv2dfprop.cu @@ -143,7 +143,6 @@ compare if the output from CUTLASS kernel is same as the reference implicit GEMM #include "cutlass/util/tensor_view_io.h" #include "helper.h" - // The code section below describes datatype for input, output tensors and computation between // elements using ElementAccumulator = int32_t; // Data type of accumulator @@ -675,7 +674,6 @@ Result profile_convolution(Options const &options) { return result; } - ///////////////////////////////////////////////////////////////////////////////////////////////// int main(int argc, char const **args) { @@ -762,11 +760,7 @@ int main(int argc, char const **args) { Result::print_header(std::cout, options) << std::endl; result.print(std::cout, 1, options) << std::endl; } - return 0; } ///////////////////////////////////////////////////////////////////////////////////////////////// - - - diff --git a/examples/12_gemm_bias_relu/CMakeLists.txt b/examples/12_gemm_bias_relu/CMakeLists.txt index abe61be1ce..5d4dac6cf0 100644 --- a/examples/12_gemm_bias_relu/CMakeLists.txt +++ b/examples/12_gemm_bias_relu/CMakeLists.txt @@ -31,6 +31,5 @@ cutlass_example_add_executable( 12_gemm_bias_relu gemm_bias_relu.cu - DISABLE_TESTS ON ) diff --git a/examples/13_two_tensor_op_fusion/fused_two_convs_s8_sm75_rf.cu b/examples/13_two_tensor_op_fusion/fused_two_convs_s8_sm75_rf.cu index 64955f8f83..07b583469e 100644 --- a/examples/13_two_tensor_op_fusion/fused_two_convs_s8_sm75_rf.cu +++ b/examples/13_two_tensor_op_fusion/fused_two_convs_s8_sm75_rf.cu @@ -220,7 +220,6 @@ bool run_fused_conv2d_fprop_optimized_s8_sm75_rf_res() { return pass; } - int main() { std::vectorfuncs = { @@ -229,10 +228,6 @@ int main() { }; return testRun(75, funcs, "conv int8 RF residency"); - } - - //////////////////////////////////////////////////////////////////////////////// - diff --git a/examples/13_two_tensor_op_fusion/fused_two_convs_s8_sm75_shmem.cu b/examples/13_two_tensor_op_fusion/fused_two_convs_s8_sm75_shmem.cu index 7f82518123..9886be0d0f 100644 --- a/examples/13_two_tensor_op_fusion/fused_two_convs_s8_sm75_shmem.cu +++ b/examples/13_two_tensor_op_fusion/fused_two_convs_s8_sm75_shmem.cu @@ -39,7 +39,6 @@ #include "device/b2b_implicit_gemm_convolution.h" #include "b2b_interleaved_conv2d_run.h" #include "test_run.h" - //////////////////////////////////////////////////////////////////////////////// cutlass::conv::Conv2dProblemSize conv2d_s8_sm75_problem_size_0 ( @@ -219,20 +218,13 @@ bool run_fused_conv2d_fprop_optimized_s8_sm75_shmem() { return pass; } - - int main() { - std::vectorfuncs = { &run_nonfused_conv2d_fprop_optimized_s8_sm75, &run_fused_conv2d_fprop_optimized_s8_sm75_shmem }; return testRun(75, funcs, "conv int8 shmem staging"); - } - - //////////////////////////////////////////////////////////////////////////////// - diff --git a/examples/13_two_tensor_op_fusion/fused_two_gemms_s8_sm75_rf.cu b/examples/13_two_tensor_op_fusion/fused_two_gemms_s8_sm75_rf.cu index 565cca7e5c..3872caa22f 100644 --- a/examples/13_two_tensor_op_fusion/fused_two_gemms_s8_sm75_rf.cu +++ b/examples/13_two_tensor_op_fusion/fused_two_gemms_s8_sm75_rf.cu @@ -195,7 +195,6 @@ bool run_fused_gemm_s8_rf_res() { return passed; } - int main() { std::vectorfuncs = { @@ -204,9 +203,6 @@ int main() { }; return testRun(75, funcs, "gemm int8 RF residency"); - - } - 
//////////////////////////////////////////////////////////////////////////////// diff --git a/examples/13_two_tensor_op_fusion/fused_two_gemms_s8_sm75_shmem.cu b/examples/13_two_tensor_op_fusion/fused_two_gemms_s8_sm75_shmem.cu index 8719d74839..d1ab01945d 100644 --- a/examples/13_two_tensor_op_fusion/fused_two_gemms_s8_sm75_shmem.cu +++ b/examples/13_two_tensor_op_fusion/fused_two_gemms_s8_sm75_shmem.cu @@ -43,7 +43,6 @@ #include "device/b2b_gemm.h" #include "b2b_interleaved_gemm_run.h" #include "test_run.h" - //////////////////////////////////////////////////////////////////////////////// cutlass::gemm::GemmCoord gemm_s8_sm75_problem_size_0(128*640, 64, 576); @@ -197,18 +196,13 @@ bool run_fused_gemm_s8_shmem() { return passed; } - int main() { std::vectorfuncs = { &run_nonfused_gemm_s8, &run_fused_gemm_s8_shmem }; - return testRun(75, funcs, "gemm int8 shmem staing"); - - } - //////////////////////////////////////////////////////////////////////////////// diff --git a/examples/13_two_tensor_op_fusion/threadblock/grouped_threadblock_swizzle.h b/examples/13_two_tensor_op_fusion/threadblock/grouped_threadblock_swizzle.h index 42ef4110a8..62efba26a5 100644 --- a/examples/13_two_tensor_op_fusion/threadblock/grouped_threadblock_swizzle.h +++ b/examples/13_two_tensor_op_fusion/threadblock/grouped_threadblock_swizzle.h @@ -90,34 +90,6 @@ struct GroupedThreadblockSwizzle : detail::GroupedThreadblockSwizzleBase { } }; -template < - typename ThreadblockShape, - typename LayoutC, - cutlass::gemm::kernel::GroupScheduleMode GroupScheduleMode_ = cutlass::gemm::kernel::GroupScheduleMode::kDeviceOnly, - int PrefetchTileCount = 128, - int ThreadCount = PrefetchTileCount> -struct GemmGroupedThreadblockSwizzle : GroupedThreadblockSwizzle< - cutlass::gemm::kernel::GemmGroupedProblemVisitor< - ThreadblockShape, - GroupScheduleMode_, - PrefetchTileCount, - ThreadCount, - platform::is_same::value - > - > { - using Base = GroupedThreadblockSwizzle::value>>; - - CUTLASS_HOST_DEVICE - GemmGroupedThreadblockSwizzle(typename Base::ProblemVisitor::Params& params, - typename Base::ProblemVisitor::SharedStorage& shared_storage, - int block_idx) : Base(params, shared_storage, block_idx) {} -}; - template < typename ThreadblockShape, typename LayoutC, diff --git a/examples/24_gemm_grouped/CMakeLists.txt b/examples/24_gemm_grouped/CMakeLists.txt index db3479f4d2..054b96d1ed 100644 --- a/examples/24_gemm_grouped/CMakeLists.txt +++ b/examples/24_gemm_grouped/CMakeLists.txt @@ -31,6 +31,7 @@ cutlass_example_add_executable( 24_gemm_grouped - gemm_grouped.cu + gemm_grouped.cu ) + diff --git a/examples/40_cutlass_py/README.md b/examples/40_cutlass_py/README.md index d33a6d5371..c670e34072 100644 --- a/examples/40_cutlass_py/README.md +++ b/examples/40_cutlass_py/README.md @@ -1,27 +1,4 @@ # PyCUTLASS Examples -**NOTE:** This directory contains examples for PyCUTLASS, a Python library providing low-level -building blocks for emitting CUTLASS C++ kernels. For examples using CUTLASS's Pythonic interface, -see the [examples/python](/examples/python) directory. - -Two types of examples are provided: -* _Basic examples_: minimal examples that illustrate how to set up GEMMs, convolutions, and grouped GEMM operations -* [_Customizable examples_](customizable): examples that allow one to specify a variety of template parameters for the given kernel - -## Setting up the Python interface -Please follow the instructions [here](/python/README.md#installation) to set up the PyCUTLASS. 
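The replacement README text added a few lines below redirects readers to `examples/python`; for orientation, the interface that supersedes these PyCUTLASS examples reduces a basic GEMM to a handful of lines. A minimal sketch in the spirit of `examples/python/00_basic_gemm.ipynb` (problem size, dtypes, and defaults here are illustrative assumptions):

```python
import numpy as np
import cutlass

M, N, K = 256, 256, 64
A = np.random.rand(M, K).astype(np.float32)
B = np.random.rand(K, N).astype(np.float32)
C = np.zeros((M, N), dtype=np.float32)
D = np.zeros((M, N), dtype=np.float32)

# One plan object replaces PyCUTLASS's manual operation/argument plumbing;
# kernel selection and JIT compilation happen behind the scenes.
plan = cutlass.op.Gemm(element=np.float32, layout=cutlass.LayoutType.RowMajor)
plan.run(A, B, C, D)  # D = alpha * (A @ B) + beta * C, with default alpha/beta
```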
- -## Running examples -Each of the basic examples can be run as follows: -```shell -# Run the GEMM example -python gemm.py - -# Run the Conv2d example -python conv2d.py - -# Run the grouped GEMM example -python gemm_grouped.py -``` - -To run the customizable examples, refer to the README in the [customizable](customizable) directory. +This directory contains deprecated examples for PyCUTLASS, a precursor to the CUTLASS Python interface. +For examples of using CUTLASS's actively-maintained Pythonic interface, see the [examples/python](/examples/python) directory. diff --git a/examples/40_cutlass_py/conv2d.py b/examples/40_cutlass_py/conv2d.py index a21f97690c..5e7b8e24e5 100644 --- a/examples/40_cutlass_py/conv2d.py +++ b/examples/40_cutlass_py/conv2d.py @@ -33,10 +33,14 @@ Basic example of using the CUTLASS Python interface to run a 2d convolution """ +import sys +print("This example is deprecated. Please see examples/python for examples of using " + "the CUTLASS Python interface.") +sys.exit(0) + import argparse -import torch import numpy as np -import sys +import torch import cutlass_bindings import cutlass.backend as pycutlass diff --git a/examples/40_cutlass_py/customizable/README.md b/examples/40_cutlass_py/customizable/README.md index cd25c69f3f..e8aeee9e71 100644 --- a/examples/40_cutlass_py/customizable/README.md +++ b/examples/40_cutlass_py/customizable/README.md @@ -165,28 +165,3 @@ Example 7: GELU ```python python gemm.py -i 16 8 16 -ta bfloat16 -tb bfloat16 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 64 128 64 -s 3 -w 2 2 1 -cc 80 -la ColumnMajor -aa 8 -lb ColumnMajor -ab 8 -lc RowMajor -ac 4 -te float32 -ep LinearCombination -sw IdentitySwizzle2 -p 512 256 128 -alpha 0.0 -beta 0.5 -gm GemmSplitKParallel -k 5 -bias -activ gelu ``` -### Epilogue Visitor Tree -Example 1: -```python -python gemm.py -i 16 8 8 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add_fast_bf16 -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la RowMajor -aa 4 -lb ColumnMajor -ab 4 -lc RowMajor -ac 4 -te float32 -ep LinearCombination -epv RowBroadcast -sw IdentitySwizzle1 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Gemm -k 1 -``` -Example 2: -```python -python gemm.py -i 8 8 4 -ta float64 -tb float64 -tc float64 -tacc float64 -m multiply_add -op TensorOp -b 32 32 16 -s 4 -w 2 2 1 -cc 80 -la ColumnMajor -aa 1 -lb RowMajor -ab 1 -lc RowMajor -ac 1 -te float64 -ep LinearCombination -epv ColumnBroadcast -sw IdentitySwizzle1 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Gemm -k 1 -``` -Example 3: -```python -python gemm.py -i 16 8 16 -ta float16 -tb float16 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la ColumnMajor -aa 8 -lb RowMajor -ab 8 -lc RowMajor -ac 4 -te float32 -ep LinearCombination -epv RowReduction -sw IdentitySwizzle4 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Gemm -k 1 -``` -Example 4: -```python -python gemm.py -i 16 8 16 -ta bfloat16 -tb bfloat16 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 64 128 64 -s 3 -w 2 2 1 -cc 80 -la ColumnMajor -aa 8 -lb ColumnMajor -ab 8 -lc RowMajor -ac 4 -te float32 -ep LinearCombination -epv ColumnReduction -sw IdentitySwizzle2 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Gemm -k 1 -``` -Example 5: -```python -python gemm.py -i 16 8 8 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add_fast_bf16 -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la RowMajor -aa 4 -lb ColumnMajor -ab 4 -lc RowMajor -ac 4 -te float32 -ep LinearCombination -epv RowReduction -sw BatchedIdentitySwizzle -p 
512 256 128 -alpha 1.0 -beta 0.5 -gm Batched -k 1 -batch 3 -``` -Example 6: -```python -python gemm.py -i 16 8 8 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add_fast_bf16 -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la RowMajor -aa 4 -lb ColumnMajor -ab 4 -lc RowMajor -ac 4 -te float32 -ep LinearCombination -epv ColumnBroadcast -sw BatchedIdentitySwizzle -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Array -k 1 -batch 3 -``` diff --git a/examples/40_cutlass_py/customizable/conv2d.py b/examples/40_cutlass_py/customizable/conv2d.py index 6fb2494473..01e4133e7c 100644 --- a/examples/40_cutlass_py/customizable/conv2d.py +++ b/examples/40_cutlass_py/customizable/conv2d.py @@ -29,13 +29,18 @@ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. # ################################################################################ + +import sys +print("This example is deprecated. Please see examples/python for examples of using " + "the CUTLASS Python interface.") +sys.exit(0) + import numpy as np import cutlass.backend as pycutlass from cutlass.backend import * from cutlass.backend.utils.device import device_cc from cutlass.backend.conv2d_operation import * from cutlass.backend.utils.reference_model import Conv2dReferenceModule -import sys import torch.nn.functional as F import argparse diff --git a/examples/40_cutlass_py/customizable/gemm.py b/examples/40_cutlass_py/customizable/gemm.py index 745f6aac2b..d98ffe884e 100644 --- a/examples/40_cutlass_py/customizable/gemm.py +++ b/examples/40_cutlass_py/customizable/gemm.py @@ -29,13 +29,18 @@ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. # ################################################################################ + +import sys +print("This example is deprecated. 
Please see examples/python for examples of using " + "the CUTLASS Python interface.") +sys.exit(0) + import numpy as np import cutlass.backend as pycutlass from cutlass.backend import * from cutlass.backend.utils.device import device_cc import cutlass_bindings from bfloat16 import bfloat16 -import sys import argparse @@ -100,8 +105,6 @@ parser.add_argument("-ep", "--epilogue_functor", default="LinearCombination", type=str, choices=['LinearCombination', 'FastLinearCombinationClamp', 'LinearCombinationClamp'], help="This option describes the epilogue part of the kernel") -parser.add_argument("-epv", "--epilogue_visitor", default=None, - type=str, choices=['RowReduction', 'ColumnReduction', 'RowBroadcast', 'ColumnBroadcast'], help="epilogue visitor for more complex epilogues") # swizzling parser.add_argument("-sw", "--swizzling_functor", default="IdentitySwizzle1", type=str, choices=[ "IdentitySwizzle1", "IdentitySwizzle2", "IdentitySwizzle4", "IdentitySwizzle8", "HorizontalSwizzle", "BatchedIdentitySwizzle"], @@ -193,71 +196,10 @@ swizzling_functor = getattr(cutlass_bindings, args.swizzling_functor) -visitor = args.epilogue_visitor is not None - -if args.epilogue_visitor == "ColumnReduction": - class ColumnReduction_(EpilogueVisitTree): - def __call__( - self, accum: 'tensor', c: 'tensor', - alpha: 'scalar', beta: 'scalar'): - # - D = alpha * accum + beta * c - reduction = reduction_op(D, "column", "Add", args.threadblock_shape[0]) - return D, reduction - epilogue_functor = ColumnReduction_( - epilogue_functor, tile_description, math_inst.element_accumulator, - C.alignment, element_epilogue, C.element) - epilogue_functor.initialize() -elif args.epilogue_visitor == "RowReduction": - class RowReduction_(EpilogueVisitTree): - def __call__( - self, accum: 'tensor', c: 'tensor', - alpha: 'scalar', beta: 'scalar'): - # - D = alpha * accum + tanh.numpy(beta * c) - reduction = reduction_op(D, "row", "Add", args.threadblock_shape[1]) - return D, reduction - epilogue_functor = RowReduction_( - epilogue_functor, tile_description, math_inst.element_accumulator, - C.alignment, element_epilogue, C.element) - epilogue_functor.initialize() - -elif args.epilogue_visitor == "RowBroadcast": - class RowBroadcast_(EpilogueVisitTree): - def __call__( - self, accum: 'tensor', c: 'tensor', - vector: 'row', alpha: 'scalar', beta: 'scalar'): - # - T = accum + vector - scale_T = alpha * T - Z = relu.numpy(scale_T + beta * c) - return Z, T - epilogue_functor = RowBroadcast_( - epilogue_functor, tile_description, math_inst.element_accumulator, - C.alignment, element_epilogue, C.element) - epilogue_functor.initialize() -elif args.epilogue_visitor == "ColumnBroadcast": - class ColumnBroadcast_(EpilogueVisitTree): - def __call__( - self, accum: 'tensor', c: 'tensor', - vector: 'column', alpha: 'scalar', beta: 'scalar'): - # - T = accum + vector - scale_T = leaky_relu.numpy(alpha * T, 0.2) - Z = scale_T + beta * c - return Z, T - epilogue_functor = ColumnBroadcast_( - epilogue_functor, tile_description, math_inst.element_accumulator, - C.alignment, element_epilogue, C.element) - epilogue_functor.initialize() -else: - epilogue_functor = epilogue_functor - operation = GemmOperationUniversal( arch=args.compute_capability, tile_description=tile_description, A=A, B=B, C=C, - epilogue_functor=epilogue_functor, swizzling_functor=swizzling_functor, - visitor=visitor + epilogue_functor=epilogue_functor, swizzling_functor=swizzling_functor ) if args.print_cuda: @@ -347,38 +289,7 @@ def __call__( shape=(args.batch * problem_size.m() * 
problem_size.n(),) ).astype(getattr(np, args.element_c)) -if args.epilogue_visitor == "RowReduction": - cta_n = args.threadblock_shape[1] - num_cta_n = (problem_size.n() + cta_n - 1) // cta_n - reduction = np.zeros(shape=(args.batch * problem_size.m() * num_cta_n,), dtype=getattr(np, args.element_c)) - output_op = operation.epilogue_type( - D=tensor_D, alpha=args.alpha, beta=args.beta, c=tensor_C, reduction=reduction, problem_size=[problem_size.m(), problem_size.n()] - ) -elif args.epilogue_visitor == "ColumnReduction": - cta_m = args.threadblock_shape[0] - num_cta_m = (problem_size.m() + cta_m - 1) // cta_m - reduction = np.zeros(shape=(args.batch * problem_size.n() * num_cta_m,), dtype=getattr(np, args.element_c)) - output_op = operation.epilogue_type( - D=tensor_D, alpha=args.alpha, beta=args.beta, c=tensor_C, reduction=reduction, problem_size=[problem_size.m(), problem_size.n()] - ) -elif args.epilogue_visitor == "RowBroadcast": - vector = np.ceil( - np.random.uniform(low=-8.5, high=7.5, size=(args.batch, 1, problem_size.n())) - ).astype(getattr(np, args.element_c)) - tensor_t = np.empty_like(tensor_D) - output_op = operation.epilogue_type( - c=tensor_C, vector=vector, alpha=args.alpha, beta=args.beta, Z=tensor_D, T=tensor_t, problem_size=[problem_size.m(), problem_size.n()] - ) -elif args.epilogue_visitor == "ColumnBroadcast": - vector = np.ceil( - np.random.uniform(low=-8.5, high=7.5, size=(args.batch, problem_size.m(), 1)) - ).astype(getattr(np, args.element_c)) - tensor_t = np.empty_like(tensor_D) - output_op = operation.epilogue_type( - c=tensor_C, vector=vector, alpha=args.alpha, beta=args.beta, Z=tensor_D, T=tensor_t, problem_size=[problem_size.m(), problem_size.n()] - ) -else: - output_op = operation.epilogue_type(*([args.alpha, args.beta] + args.activation_args)) +output_op = operation.epilogue_type(*([args.alpha, args.beta] + args.activation_args)) arguments = GemmArguments( operation=operation, problem_size=problem_size, @@ -411,38 +322,8 @@ def __call__( tensor_D_ref = reference.run( tensor_A, tensor_B, tensor_C, problem_size, args.alpha, args.beta, args.bias, args.batch) -if args.epilogue_visitor in ["RowBroadcast", "ColumnBroadcast"]: - tensor_D_ref = (tensor_D_ref.reshape((args.batch, problem_size.m(), problem_size.n())) + vector).flatten() tensor_D_ref = getattr(pycutlass, args.activation_function).numpy(*([tensor_D_ref,] + args.activation_args)) -if args.epilogue_visitor in ["RowReduction", "ColumnReduction"]: - output_op.sync() - accum_ref = reference.run( - tensor_A, tensor_B, tensor_C, problem_size, 1.0, 0.0, args.bias, args.batch) - tensor_D_ref, reduction_ref = epilogue_functor( - accum_ref.reshape((args.batch, problem_size.m(), problem_size.n())), - tensor_C.reshape((args.batch, problem_size.m(), problem_size.n())), - args.alpha, args.beta - ) - tensor_D_ref = tensor_D_ref.flatten() - reduction_ref = reduction_ref.flatten() - assert np.allclose(reduction_ref, reduction, atol=1e-2) - -elif args.epilogue_visitor in ["RowBroadcast", "ColumnBroadcast"]: - output_op.sync() - accum_ref = reference.run( - tensor_A, tensor_B, tensor_C, problem_size, 1.0, 0.0, args.bias, args.batch) - - tensor_D_ref, tensor_T_ref = epilogue_functor( - accum_ref.reshape((args.batch, problem_size.m(), problem_size.n())), - tensor_C.reshape((args.batch, problem_size.m(), problem_size.n())), - vector, args.alpha, args.beta) - - tensor_D_ref = tensor_D_ref.flatten() - tensor_T_ref = tensor_T_ref.flatten() - - assert np.array_equal(tensor_t, tensor_T_ref) - try: assert np.array_equal(tensor_D, 
tensor_D_ref) except: diff --git a/examples/40_cutlass_py/customizable/gemm_grouped.py b/examples/40_cutlass_py/customizable/gemm_grouped.py index 0cecb328d0..06638b5fed 100644 --- a/examples/40_cutlass_py/customizable/gemm_grouped.py +++ b/examples/40_cutlass_py/customizable/gemm_grouped.py @@ -29,12 +29,17 @@ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. # ################################################################################ + +import sys +print("This example is deprecated. Please see examples/python for examples of using " + "the CUTLASS Python interface.") +sys.exit(0) + import numpy as np import cutlass.backend as pycutlass from cutlass.backend import * from cutlass.backend.utils.device import device_cc import csv -import sys import argparse diff --git a/examples/40_cutlass_py/gemm.py b/examples/40_cutlass_py/gemm.py index 17b5d389bc..88fbd79b22 100644 --- a/examples/40_cutlass_py/gemm.py +++ b/examples/40_cutlass_py/gemm.py @@ -33,9 +33,13 @@ Basic example of using the CUTLASS Python interface to run a GEMM """ +import sys +print("This example is deprecated. Please see examples/python for examples of using " + "the CUTLASS Python interface.") +sys.exit(0) + import argparse import numpy as np -import sys import cutlass_bindings import cutlass.backend as pycutlass diff --git a/examples/40_cutlass_py/gemm_grouped.py b/examples/40_cutlass_py/gemm_grouped.py index 16e25d0c9c..e461ba9db5 100644 --- a/examples/40_cutlass_py/gemm_grouped.py +++ b/examples/40_cutlass_py/gemm_grouped.py @@ -33,9 +33,13 @@ Basic example of using the CUTLASS Python interface to run a grouped GEMM """ +import sys +print("This example is deprecated. Please see examples/python for examples of using " + "the CUTLASS Python interface.") +sys.exit(0) + import argparse import numpy as np -import sys import cutlass_bindings import cutlass.backend as pycutlass diff --git a/examples/44_multi_gemm_ir_and_codegen/ir_gen/gen_device.py b/examples/44_multi_gemm_ir_and_codegen/ir_gen/gen_device.py index d47b886363..c6df88cce4 100644 --- a/examples/44_multi_gemm_ir_and_codegen/ir_gen/gen_device.py +++ b/examples/44_multi_gemm_ir_and_codegen/ir_gen/gen_device.py @@ -434,14 +434,6 @@ def gen_func_run(self): " if (result != cudaSuccess) {\n" + \ " return Status::kErrorInternal;\n" + \ " }\n" + \ - "\n" + \ - " result = cudaFuncSetAttribute(\n" + \ - " Kernel,\n" + \ - " cudaFuncAttributePreferredSharedMemoryCarveout, 100);\n" + \ - "\n" + \ - " if (result != cudaSuccess) {\n" + \ - " return Status::kErrorInternal;\n" + \ - " }\n" + \ " }\n" + \ " cutlass::Kernel<<>>(params_);\n" + \ " result = cudaGetLastError();\n" + \ diff --git a/examples/47_ampere_gemm_universal_streamk/ampere_gemm_universal_streamk_broadcast.cu b/examples/47_ampere_gemm_universal_streamk/ampere_gemm_universal_streamk_broadcast.cu index db2eff51f3..4e7751f740 100644 --- a/examples/47_ampere_gemm_universal_streamk/ampere_gemm_universal_streamk_broadcast.cu +++ b/examples/47_ampere_gemm_universal_streamk/ampere_gemm_universal_streamk_broadcast.cu @@ -83,6 +83,10 @@ #include "cutlass/util/reference/host/tensor_fill.h" #include "cutlass/util/tensor_view_io.h" +#include "cutlass/epilogue/threadblock/fusion/visitors.hpp" +#include "cutlass/gemm/kernel/default_gemm_universal_with_visitor.h" +#include "cutlass/gemm/device/gemm_universal_adapter.h" + #include "helper.h" @@ -120,6 +124,7 @@ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 32>; // Threadb using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; // Warp-level 
tile size (concept: GemmShape) using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; // Instruction-level tile size (concept: GemmShape) constexpr int NumStages = 4; // Number of global->shared pipeline stages used in the GEMM mainloop +constexpr int EVTEpilogueStages = 1; // Number of epilogue stages in EVT // Residual block configuration @@ -166,23 +171,93 @@ using DeviceGemmBasic = cutlass::gemm::device::GemmUniversalWithBroadcast< AlignmentA, AlignmentB>; -// StreamK device GEMM implementation type -using DeviceGemmStreamK = cutlass::gemm::device::GemmUniversalStreamkWithBroadcast< - ElementA, LayoutA, - ElementB, LayoutB, - ElementC, LayoutC, +// StreamK device GEMM implementation type with EVT +using namespace cute; + +using OutputTileThreadMap = cutlass::epilogue::threadblock::OutputTileThreadLayout< + ThreadblockShape, + WarpShape, + ElementC, + AlignmentC, + EVTEpilogueStages +>; + +using Accum = cutlass::epilogue::threadblock::VisitorAccFetch; + +using Bias = cutlass::epilogue::threadblock::VisitorRowBroadcast< + OutputTileThreadMap, ElementC, + cute::Stride<_0, _1, int32_t> // StrideMNL +>; + +using C1 = cutlass::epilogue::threadblock::VisitorAuxLoad< + OutputTileThreadMap, ElementC, + cute::Stride // StrideMNL +>; + +using C2 = cutlass::epilogue::threadblock::VisitorAuxLoad< + OutputTileThreadMap, ElementC, + cute::Stride // StrideMNL +>; + +using Compute0 = cutlass::epilogue::threadblock::VisitorCompute< + cutlass::plus, ElementCompute, ElementCompute, + cutlass::FloatRoundStyle::round_to_nearest +>; + +using EVTCompute0 = cutlass::epilogue::threadblock::Sm80EVT< + Compute0, + Accum, + Bias>; + +using Compute1 = cutlass::epilogue::threadblock::VisitorCompute< + cutlass::plus, ElementCompute, ElementCompute, + cutlass::FloatRoundStyle::round_to_nearest +>; + +using EVTCompute1 = cutlass::epilogue::threadblock::Sm80EVT< + Compute1, + EVTCompute0, + C1>; + +using Compute2 = cutlass::epilogue::threadblock::VisitorCompute< + cutlass::plus, ElementOutput, ElementCompute, + cutlass::FloatRoundStyle::round_to_nearest +>; + +using EVTCompute2 = cutlass::epilogue::threadblock::Sm80EVT< + Compute2, + EVTCompute1, + C2>; + +using D = cutlass::epilogue::threadblock::VisitorAuxStore< + OutputTileThreadMap, ElementOutput, cutlass::FloatRoundStyle::round_to_nearest, + cute::Stride // StrideMNL +>; + +using EVTD = cutlass::epilogue::threadblock::Sm80EVT< + D, + EVTCompute2>; + +using EVTKernelStreamK = + typename cutlass::gemm::kernel::DefaultGemmWithVisitor< + ElementA, LayoutA, cutlass::ComplexTransform::kNone, AlignmentA, + ElementB, LayoutB, cutlass::ComplexTransform::kNone, AlignmentB, + ElementC, LayoutC, AlignmentC, ElementAccumulator, - OperatorClass, - ArchTag, + ElementCompute, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, ThreadblockShape, WarpShape, InstructionShape, - EpilogueOp, + EVTD, cutlass::gemm::threadblock::ThreadblockSwizzleStreamK, NumStages, - AlignmentA, - AlignmentB>; + cutlass::arch::OpMultiplyAdd, + EVTEpilogueStages +>::GemmKernel; +using DeviceGemmStreamK = cutlass::gemm::device::GemmUniversalAdapter; ///////////////////////////////////////////////////////////////////////////////////////////////// /// Testbed utility types @@ -360,36 +435,41 @@ typename DeviceGemmStreamK::Arguments args_from_options( cutlass::HostTensor &tensor_Vector/*, cutlass::HostTensor &tensor_Tensor*/ ) -{ +{ + typename EVTD::Arguments callback_args{ + { + { + { + {}, // Accum + {tensor_Vector.device_data(), ElementC(0), {_0{}, _1{}, int32_t(options.problem_size.n())}}, // 
Bias + {} // Compute0 + }, // EVTCompute0 + {tensor_c1.device_data(), ElementC(0), {options.problem_size.n(), _1{}, options.problem_size.mn().product()}}, // C1 + {} // Compute1 + }, // EVTCompute1 + {tensor_c2.device_data(), ElementC(0), {options.problem_size.n(), _1{}, options.problem_size.mn().product()}}, // C2 + {} // Compute2 + }, // EVTCompute2 + {tensor_d.device_data(), {options.problem_size.n(), _1{}, options.problem_size.mn().product()}}, // D + }; // EVTD + return typename DeviceGemmStreamK::Arguments( cutlass::gemm::GemmUniversalMode::kGemm, // universal mode options.problem_size, // problem_size options.split_k_factor, // batch count / splitk slices - { // epilogue parameters - ElementAccumulator(options.alpha), - ElementAccumulator(options.beta) - }, + callback_args, // argument of EVT callbacks tensor_a.device_data(), // ptr_A tensor_b.device_data(), // ptr_B - tensor_c1.device_data(), // ptr_C1 - tensor_c2.device_data(), // ptr_C2 - tensor_d.device_data(), // ptr_D - tensor_Vector.device_data(), // ptr_Vector - /* tensor_Tensor.device_data(), */nullptr,// ptr_Tensor // We're not storing Tensor + nullptr, // ptr_C (unused) + nullptr, // ptr_D (unused) options.problem_size.mk().product(), // batch_stride_A options.problem_size.nk().product(), // batch_stride_B - options.problem_size.mn().product(), // batch_stride_C1 - options.problem_size.mn().product(), // batch_stride_C2 - options.problem_size.mn().product(), // batch_stride_D - options.problem_size.mn().product(), // batch_stride_Vector - options.problem_size.mn().product(), // batch_stride_Tensor + 0, // batch_stride_C (unused) + 0, // batch_stride_D (unused) tensor_a.layout().stride(0), // stride_a tensor_b.layout().stride(0), // stride_b - tensor_c1.layout().stride(0), // stride_c1 - tensor_c2.layout().stride(0), // stride_c2 - tensor_d.layout().stride(0), // stride_d - /*tensor_Vector.layout().stride(0)*/0, // stride_Vector // Vector stride is always 0 - /*tensor_Tensor.layout().stride(0)*/0, // stride_Tensor // We're not storing Tensor + 0, // stride_c (unused) + 0, // stride_d (unused) options.avail_sms); // avail_sms } diff --git a/examples/52_hopper_gather_scatter_fusion/52_hopper_gather_scatter_fusion.cu b/examples/52_hopper_gather_scatter_fusion/52_hopper_gather_scatter_fusion.cu index 110c6e44b1..c99afc05e6 100644 --- a/examples/52_hopper_gather_scatter_fusion/52_hopper_gather_scatter_fusion.cu +++ b/examples/52_hopper_gather_scatter_fusion/52_hopper_gather_scatter_fusion.cu @@ -526,7 +526,8 @@ struct ExampleRunner // Forward calls via lambda to avoid specifying template arguments auto gather_call = [](auto&&... args){ gather(static_cast(args)...); }; - auto scatter_call = [](auto&&... args){ scatter(static_cast(args)...); }; + // MSVC doesn't count use inside a false "if constexpr" branch. + [[maybe_unused]] auto scatter_call = [](auto&&... 
args){ scatter(static_cast(args)...); }; if constexpr (DoGatherA) { run_gather(gather_call, tensor_a, tensor_a_gathered, arguments.gather_A, problem_size.batch(), stride_A); diff --git a/examples/52_hopper_gather_scatter_fusion/gather_gemm.hpp b/examples/52_hopper_gather_scatter_fusion/gather_gemm.hpp index 458cb19554..579122210a 100644 --- a/examples/52_hopper_gather_scatter_fusion/gather_gemm.hpp +++ b/examples/52_hopper_gather_scatter_fusion/gather_gemm.hpp @@ -58,7 +58,7 @@ class GemmGather // Type Aliases // using ProblemShape = ProblemShape_; - using TileScheduleTag = TileScheduler_; + using TileSchedulerTag = TileScheduler_; using TileScheduler = TileScheduler_; static_assert(rank(ProblemShape{}) == 3 or rank(ProblemShape{}) == 4, "ProblemShape{} should be or "); diff --git a/examples/54_hopper_fp8_warp_specialized_gemm/54_hopper_fp8_warp_specialized_gemm.cu b/examples/54_hopper_fp8_warp_specialized_gemm/54_hopper_fp8_warp_specialized_gemm.cu index f6291c6e7f..080d703454 100644 --- a/examples/54_hopper_fp8_warp_specialized_gemm/54_hopper_fp8_warp_specialized_gemm.cu +++ b/examples/54_hopper_fp8_warp_specialized_gemm/54_hopper_fp8_warp_specialized_gemm.cu @@ -161,7 +161,7 @@ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; using EpilogueOutputOp = typename Gemm::EpilogueOutputOp; using ElementScalar = typename EpilogueOutputOp::ElementScalar; using ElementAmax = typename EpilogueOutputOp::ElementAmax; -using ActivationFunctor = typename EpilogueOutputOp::ActivationFn; +using ActivationFunctor = typename EpilogueOutputOp::ActivationFn; using StrideA = typename Gemm::GemmKernel::StrideA; using StrideB = typename Gemm::GemmKernel::StrideB; diff --git a/examples/python/00_basic_gemm.ipynb b/examples/python/00_basic_gemm.ipynb index 65c1107fe6..6c8222e0de 100644 --- a/examples/python/00_basic_gemm.ipynb +++ b/examples/python/00_basic_gemm.ipynb @@ -7,9 +7,7 @@ "metadata": {}, "source": [ "# Basic example of using the CUTLASS Python interface\n", - "This notebook walks through a basic example of using the CUTLASS Python interface to declare, compile, and run GEMMs.\n", - "\n", - "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/NVIDIA/cutlass/tree/master/examples/00_basic_gemm.ipynb)\n" + "This notebook walks through a basic example of using the CUTLASS Python interface to declare, compile, and run GEMMs.\n" ] }, { diff --git a/examples/python/01_epilogue.ipynb b/examples/python/01_epilogue.ipynb index f7abddd886..13acbffdac 100644 --- a/examples/python/01_epilogue.ipynb +++ b/examples/python/01_epilogue.ipynb @@ -7,9 +7,7 @@ "metadata": {}, "source": [ "# Example of using elementwise activation functions in the CUTLASS Python interface\n", - "This notebook walks through a basic example of using the CUTLASS Python interface to declare, compile, and run GEMMs with different epilogues.\n", - "\n", - "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/NVIDIA/cutlass/tree/master/examples/00_basic_gemm.ipynb)" + "This notebook walks through a basic example of using the CUTLASS Python interface to declare, compile, and run GEMMs with different epilogues.\n" ] }, { diff --git a/examples/python/02_pytorch_extension_grouped_gemm.ipynb b/examples/python/02_pytorch_extension_grouped_gemm.ipynb index b0cdb0edfd..ecd7828044 100644 --- a/examples/python/02_pytorch_extension_grouped_gemm.ipynb +++ b/examples/python/02_pytorch_extension_grouped_gemm.ipynb @@ 
-10,8 +10,6 @@ "This notebook walks through a basic example of using the CUTLASS Python interface to declare\n", "a grouped GEMM kernel and export it as a PyTorch CUDA extension. Note that GEMM and Conv2d can also be exported as PyTorch CUDA extensions. \n", "\n", - "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/NVIDIA/cutlass/tree/master/examples/00_basic_gemm.ipynb)\n", - "\n", "## Background on grouped GEMM\n", "Grouped GEMM enables one to execute a set of GEMMs (each with potentially different sizes and strides)\n", "in a single CUDA kernel. It can be thought of as a generalized version of a pointer-array GEMM,\n", diff --git a/examples/python/04_epilogue_visitor.ipynb b/examples/python/04_epilogue_visitor.ipynb new file mode 100644 index 0000000000..72547d1999 --- /dev/null +++ b/examples/python/04_epilogue_visitor.ipynb @@ -0,0 +1,221 @@ +{ + "cells": [ + { + "attachments": {}, + "cell_type": "markdown", + "id": "5d24a692", + "metadata": {}, + "source": [ + "# Example of using epilogue visitor in the CUTLASS Python interface\n", + "This notebook walks through a basic example of using the CUTLASS Python interface to declare, compile, and run GEMMs with different epilogues through CUTLASS Epilogue Visitor." + ] + }, + { + "cell_type": "markdown", + "id": "3ca993fe", + "metadata": {}, + "source": [ + "We first import various packages needed for the example, construct the input and output tensors that will be used in our example." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "63a70a3c", + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "import cutlass\n", + "from cutlass.epilogue import relu\n", + "from cutlass import Tensor as FakeTensor\n", + "from cutlass.profiler import CUDAEventProfiler\n", + "\n", + "# This controls whether ther C++ GEMM declaration will be printed at each step. Set to `false` to\n", + "# omit this information.\n", + "print_module = True\n", + "\n", + "# The Epilogue Visitor feature currently only works for SM80 and 90\n", + "from cutlass.backend.utils.device import device_cc\n", + "if device_cc() not in [80, 90]:\n", + " import sys\n", + " sys.exit()\n", + "\n", + "m = 16384\n", + "n = m\n", + "k = 512\n", + "\n", + "type_A = torch.float16\n", + "type_B = torch.float16\n", + "type_C = torch.float16\n", + "type_D = torch.float16\n", + "\n", + "torch.manual_seed(2023)\n", + "scope_min = -4\n", + "scope_max = 4\n", + "tensor_A = torch.ceil(torch.empty(size=(m, k), dtype=type_A, device=\"cuda\").uniform_(scope_min, scope_max))\n", + "tensor_B = torch.ceil(torch.empty(size=(k, n), dtype=type_B, device=\"cuda\").uniform_(scope_min, scope_max))\n", + "tensor_C = torch.ceil(torch.empty(size=(m, n), dtype=type_C, device=\"cuda\").uniform_(scope_min, scope_max))\n", + "tensor_D = torch.zeros_like(tensor_C)\n", + "\n", + "plan = cutlass.op.Gemm(element=torch.float16, layout=cutlass.LayoutType.RowMajor, element_accumulator=torch.float32)" + ] + }, + { + "cell_type": "markdown", + "id": "1eb0d95b", + "metadata": {}, + "source": [ + "## Define the epilogue visitor functor\n", + "The epilogue functor can be defined as a simple Python function and a set of example tensors for inputs and outputs. The example below illustrates a complex epilogue under the directed acyclic graph structure (`F` is used twice). The epilogue takes source tensors in different ranks: `alpha`, `beta` are scalars, `bias` is a column vector to broadcast, and `C`, `aux` are matrices. 
It contains various math operations from basic arithmetic operations and built-in callable functions like `relu`. It also accommodates multiple outputs `D` and `F`. Note that there are some restrictions on syntax.\n", + "* Each named variable must be assigned exactly once and defined before it is used.\n", + "* Reserved names: `accum`, `C`, and `D` are reserved for accumulator, tensor_C, and tensor_D.\n", + "* Return values must be a named variable.\n", + "\n", + "The example tensors are provided as a dictionary with tensor names as keys and reference tensors as values. The reference tensors can be `float`, `torch.Tensor`, `numpy.ndarray`, or our `FakeTensor`. They provide the shape and data type information of the inputs and outputs of the epilogue.\n", + "\n", + "The epilogue can be generated simply through `cutlass.evt.trace(<epilogue function>, <example tensors>)`." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8d257833", + "metadata": {}, + "outputs": [], + "source": [ + "# Define epilogue visitor\n", + "def example_epilogue(accum, alpha, C, beta, aux, bias):\n", + " F = alpha * accum + (beta * C + aux)\n", + " E = relu(F + 1) + bias\n", + " D = E + F\n", + " return D, F\n", + "\n", + "# Construct inputs and outputs\n", + "alpha = 0.5\n", + "beta = 0.5\n", + "aux = torch.ceil(torch.empty(size=(m, n), dtype=type_C, device=\"cuda\").uniform_(scope_min, scope_max))\n", + "bias = torch.ceil(torch.empty(size=(m, 1), dtype=type_C, device=\"cuda\").uniform_(scope_min, scope_max))\n", + "tensor_F = torch.zeros_like(tensor_D)\n", + "examples_tensors = {\n", + " \"accum\": FakeTensor(element=torch.float32, shape=(m, n), layout_tag=cutlass.LayoutType.RowMajor),\n", + " \"alpha\": alpha,\n", + " \"C\": tensor_C,\n", + " \"beta\": beta,\n", + " \"aux\": aux,\n", + " \"bias\": bias,\n", + " \"D\": tensor_D,\n", + " \"F\": tensor_F\n", + "}\n", + "\n", + "# Trace the epilogue visitor\n", + "epilogue_visitor = cutlass.epilogue.trace(example_epilogue, examples_tensors)" + ] + }, + { + "cell_type": "markdown", + "id": "54961694", + "metadata": {}, + "source": [ + "## Run a GEMM with the epilogue visitor functor\n", + "The `epilogue_visitor` can be used by setting the plan's `epilogue_visitor` field. The arguments for the epilogue visitor are provided as a `dict` through the `visitor_args` keyword argument." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5fe49443", + "metadata": {}, + "outputs": [], + "source": [ + "visitor_args = {\n", + " \"alpha\": alpha, \"C\": tensor_C, \"beta\": beta, \n", + " \"aux\": aux, \"bias\": bias, \"D\": tensor_D, \"F\": tensor_F\n", + "}\n", + "\n", + "plan.epilogue_visitor = epilogue_visitor\n", + "plan.run(\n", + " tensor_A, tensor_B, tensor_C, tensor_D, \n", + " visitor_args=visitor_args, print_module=print_module)" + ] + }, + { + "cell_type": "markdown", + "id": "455d0a37", + "metadata": {}, + "source": [ + "The epilogue function `example_epilogue` can be used as a reference function. 
We can now verify the results simply with" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e32e7798", + "metadata": {}, + "outputs": [], + "source": [ + "class TorchReference(torch.nn.Module):\n", + " def forward(self, A, B, alpha, C, beta, aux, bias):\n", + " accum = torch.matmul(A, B)\n", + " return example_epilogue(accum, alpha, C, beta, aux, bias)\n", + "\n", + "torch_reference = TorchReference()\n", + "if hasattr(torch, \"compile\"):\n", + " # If the torch.compile feature is available\n", + " torch_reference = torch.compile(torch_reference)\n", + "\n", + "tensor_D_ref, tensor_F_ref = torch_reference(tensor_A, tensor_B, alpha, tensor_C, beta, aux, bias)\n", + "\n", + "assert torch.equal(tensor_D, tensor_D_ref)\n", + "assert torch.equal(tensor_F, tensor_F_ref)" + ] + }, + { + "cell_type": "markdown", + "id": "b69e441f", + "metadata": {}, + "source": [ + "The performance of CUTLASS fused kernel can be profiled with" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8db92150", + "metadata": {}, + "outputs": [], + "source": [ + "warmup_iterations = 10\n", + "profile_iterations = 50\n", + "# Profile CUTLASS fused kernel\n", + "duration = CUDAEventProfiler(\n", + " plan, warmup_iterations, profile_iterations,\n", + " tensor_A, tensor_B, tensor_C, tensor_D, \n", + " visitor_args=visitor_args)()\n", + "\n", + "print(f\"CUTLASS duration: {duration:.2f} ms\")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.10" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/examples/python/README.md b/examples/python/README.md index 2ed80e1939..590f2e24e4 100644 --- a/examples/python/README.md +++ b/examples/python/README.md @@ -16,3 +16,7 @@ * [03_basic_conv2d](/examples/python/03_basic_conv2d.ipynb) Shows how to declare, configure, compile, and run a CUTLASS Conv2d using the Python interface + +* [04_epilogue_visitor](/examples/python/04_epilogue_visitor.ipynb) + + Shows how to fuse elementwise activation functions to GEMMs via the Python Epilogue Visitor interface diff --git a/include/cute/algorithm/axpby.hpp b/include/cute/algorithm/axpby.hpp index a613417d39..a01fb1df14 100644 --- a/include/cute/algorithm/axpby.hpp +++ b/include/cute/algorithm/axpby.hpp @@ -68,7 +68,14 @@ axpby(Alpha const& alpha, Beta const& beta, Tensor & y) { - auto isBetaZero = (beta == Int<0>{}); + auto isBetaZero = [&] () { + if constexpr (is_complex::value) { + return beta.real() == Int<0>{} && beta.imag() == Int<0>{}; + } + else { + return beta == Int<0>{}; + } + } (); CUTE_UNROLL for (int i = 0; i < size(x); ++i) { diff --git a/include/cute/algorithm/gemm.hpp b/include/cute/algorithm/gemm.hpp index 44a0f7d487..4a8e6fdd17 100644 --- a/include/cute/algorithm/gemm.hpp +++ b/include/cute/algorithm/gemm.hpp @@ -218,7 +218,6 @@ gemm(MMA_Atom const& mma, CUTE_STATIC_ASSERT_V(size<0>(A) == size<0>(C)); // AM == CM CUTE_STATIC_ASSERT_V(size<0>(B) == size<1>(C)); // BN == CN CUTE_STATIC_ASSERT_V(size<0>(C) == size<0>(D) && size<1>(C) == size<1>(D)); - gemm(mma, D, // (M,N) make_tensor(A.data(), append<2>(A.layout())), // (M,1) @@ -253,7 +252,7 @@ gemm(MMA_Atom const& mma, CUTE_STATIC_ASSERT_V(size<1>(typename MMA_Atom::LayoutC_TV{}) == Int<1>{}); 
CUTE_STATIC_ASSERT_V(size<1>(typename MMA_Atom::LayoutA_TV{}) == Int<1>{}); CUTE_STATIC_ASSERT_V(size<1>(typename MMA_Atom::LayoutB_TV{}) == Int<1>{}); - + gemm(mma, make_tensor(D.data(), prepend<3>(D.layout())), // (1,M,N) make_tensor(A.data(), prepend<3>(A.layout())), // (1,M,K) @@ -282,7 +281,6 @@ gemm(MMA_Atom const& mma, CUTE_STATIC_ASSERT_V(size<1>(A) == size<1>(C)); // AM == CM CUTE_STATIC_ASSERT_V(size<1>(B) == size<2>(C)); // BN == CN CUTE_STATIC_ASSERT_V(size<0>(C) == size<0>(D) && size<1>(C) == size<1>(D) && size<2>(C) == size<2>(D)); - auto M = size<1>(A); auto N = size<1>(B); // REGISTER .reuse OPTIMIZATIONS @@ -409,7 +407,6 @@ gemm(MMA_Atom const& mma, CUTE_STATIC_ASSERT_V(size<1>(B) == size<2>(C)); // BN == CN CUTE_STATIC_ASSERT_V(size<2>(A) == size<2>(B)); // AK == BK CUTE_STATIC_ASSERT_V(size<0>(C) == size<0>(D) && size<1>(C) == size<1>(D) && size<2>(C) == size<2>(D)); - auto K = size<2>(A); CUTE_UNROLL @@ -454,7 +451,6 @@ gemm(MMA_Atom const& mma, CUTE_STATIC_ASSERT_V(size<1>(typename MMA_Atom::LayoutC_TV{}) == Int<1>{}); CUTE_STATIC_ASSERT_V(size<1>(typename MMA_Atom::LayoutA_TV{}) == Int<1>{}); CUTE_STATIC_ASSERT_V(size<1>(typename MMA_Atom::LayoutB_TV{}) == Int<1>{}); - gemm(mma, make_tensor(D.data(), prepend<3>(D.layout())), // (1,M,N) make_tensor(A.data(), prepend<3>(A.layout())), // (1,M,K) diff --git a/include/cute/algorithm/tuple_algorithms.hpp b/include/cute/algorithm/tuple_algorithms.hpp index d9ae200338..4eeebf8b7a 100644 --- a/include/cute/algorithm/tuple_algorithms.hpp +++ b/include/cute/algorithm/tuple_algorithms.hpp @@ -140,7 +140,11 @@ CUTE_HOST_DEVICE constexpr auto transform_apply(T&& t, F&& f, G&& g) { - return detail::tapply(static_cast(t), f, g, tuple_seq{}); + if constexpr (is_tuple>::value) { + return detail::tapply(static_cast(t), f, g, tuple_seq{}); + } else { + return g(f(static_cast(t))); + } } template @@ -148,7 +152,11 @@ CUTE_HOST_DEVICE constexpr auto transform_apply(T0&& t0, T1&& t1, F&& f, G&& g) { - return detail::tapply(static_cast(t0), static_cast(t1), f, g, tuple_seq{}); + if constexpr (is_tuple>::value) { + return detail::tapply(static_cast(t0), static_cast(t1), f, g, tuple_seq{}); + } else { + return g(f(static_cast(t0), static_cast(t1))); + } } template @@ -156,7 +164,11 @@ CUTE_HOST_DEVICE constexpr auto transform_apply(T0&& t0, T1&& t1, T2&& t2, F&& f, G&& g) { - return detail::tapply(static_cast(t0), static_cast(t1), static_cast(t2), f, g, tuple_seq{}); + if constexpr (is_tuple>::value) { + return detail::tapply(static_cast(t0), static_cast(t1), static_cast(t2), f, g, tuple_seq{}); + } else { + return g(f(static_cast(t0), static_cast(t1), static_cast(t2))); + } } // @@ -306,21 +318,16 @@ transform_leaf(T0 const& t0, T1 const& t1, F&& f) namespace detail { -template -CUTE_HOST_DEVICE constexpr -auto -find_if(T const& t, F&& f, seq<>) -{ - return cute::integral_constant::value>{}; -} - template CUTE_HOST_DEVICE constexpr auto find_if(T const& t, F&& f, seq) { if constexpr (decltype(f(get(t)))::value) { - return cute::integral_constant{}; + return cute::C{}; + } else + if constexpr (sizeof...(Is) == 0) { + return cute::C{}; } else { return find_if(t, f, seq{}); } @@ -338,7 +345,7 @@ find_if(T const& t, F&& f) if constexpr (is_tuple::value) { return detail::find_if(t, f, tuple_seq{}); } else { - return cute::integral_constant{}; + return cute::C{}; } CUTE_GCC_UNREACHABLE; @@ -355,12 +362,12 @@ find(T const& t, X const& x) template CUTE_HOST_DEVICE constexpr auto -none_of(T const& t, F&& f) +any_of(T const& t, F&& f) { if constexpr 
(is_tuple::value) { - return cute::integral_constant::value>{}; + return detail::apply(cute::transform(t, f), [&] (auto const&... a) { return (false_type{} || ... || a); }, tuple_seq{}); } else { - return not f(t); + return f(t); } CUTE_GCC_UNREACHABLE; @@ -372,8 +379,7 @@ auto all_of(T const& t, F&& f) { if constexpr (is_tuple::value) { - auto not_f = [&](auto const& a) { return not f(a); }; - return cute::integral_constant::value>{}; + return detail::apply(t, [&] (auto const&... a) { return (true_type{} && ... && f(a)); }, tuple_seq{}); } else { return f(t); } @@ -384,9 +390,9 @@ all_of(T const& t, F&& f) template CUTE_HOST_DEVICE constexpr auto -any_of(T const& t, F&& f) +none_of(T const& t, F&& f) { - return not none_of(t, f); + return not any_of(t, f); } // @@ -410,6 +416,14 @@ filter_tuple(T0 const& t0, T1 const& t1, F&& f) return transform_apply(t0, t1, f, [](auto const&... a) { return cute::tuple_cat(a...); }); } +template +CUTE_HOST_DEVICE constexpr +auto +filter_tuple(T0 const& t0, T1 const& t1, T2 const& t2, F&& f) +{ + return transform_apply(t0, t1, t2, f, [](auto const&... a) { return cute::tuple_cat(a...); }); +} + // // Fold (Reduce, Accumulate) // (t, v, f) => f(...f(f(v,t_0),t_1),...,t_n) @@ -595,6 +609,13 @@ unwrap(T const& t) // // Flatten a hierarchical tuple to a tuple of depth one. // +// + +template +struct is_flat : true_type {}; + +template +struct is_flat> : bool_constant<(true && ... && (not is_tuple::value))> {}; template CUTE_HOST_DEVICE constexpr @@ -602,7 +623,12 @@ auto flatten_to_tuple(T const& t) { if constexpr (is_tuple::value) { - return filter_tuple(t, [](auto const& a) { return flatten_to_tuple(a); }); + if constexpr (is_flat::value) { + return t; + } else + { + return filter_tuple(t, [](auto const& a) { return flatten_to_tuple(a); }); + } } else { return cute::make_tuple(t); } @@ -616,7 +642,12 @@ auto flatten(T const& t) { if constexpr (is_tuple::value) { - return filter_tuple(t, [](auto const& a) { return flatten_to_tuple(a); }); + if constexpr (is_flat::value) { + return t; + } else + { + return filter_tuple(t, [](auto const& a) { return flatten_to_tuple(a); }); + } } else { return t; } diff --git a/include/cute/arch/copy_sm90_desc.hpp b/include/cute/arch/copy_sm90_desc.hpp index aaef8b4161..d33ed305be 100644 --- a/include/cute/arch/copy_sm90_desc.hpp +++ b/include/cute/arch/copy_sm90_desc.hpp @@ -177,7 +177,7 @@ to_CUtensorMapSwizzle(SmemSwizzleBits const& t) { #if (__CUDACC_VER_MAJOR__ >= 12) && !defined(__CUDACC_RTC__) using TmaDescriptor = CUtensorMap; #else - using TmaDescriptor = struct { char bytes[128]; }; + using TmaDescriptor = struct alignas(64) { char bytes[128]; }; #endif //////////////////////////////////////////////////////////////////////////////////////////////////// /// Initiates a TensorMap Prefetch diff --git a/include/cute/arch/mma_sm80.hpp b/include/cute/arch/mma_sm80.hpp index 6050500a47..8dc5fdcb2c 100644 --- a/include/cute/arch/mma_sm80.hpp +++ b/include/cute/arch/mma_sm80.hpp @@ -37,8 +37,19 @@ // Config #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)) # define CUTE_ARCH_MMA_SM80_ENABLED + +#if (__CUDA_ARCH__ <= 900) +#define CUTE_ARCH_MMA_B1_AND_SM80_ENABLED +#endif + +#if (__CUDA_ARCH__ <= 890) +#define CUTE_ARCH_MMA_B1_XOR_SM80_ENABLED +#endif + #endif + + namespace cute { //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -2044,7 +2055,7 @@ struct SM80_8x8x128_S32U1U1S32_TN_XORPOPC uint32_t const& b0, uint32_t const& c0, uint32_t const& c1) { -#if 
defined(CUTE_ARCH_MMA_SM80_ENABLED) +#if defined(CUTE_ARCH_MMA_B1_XOR_SM80_ENABLED) asm volatile( "mma.sync.aligned.m8n8k128.row.col.s32.b1.b1.s32.xor.popc " "{%0, %1}," @@ -2077,7 +2088,7 @@ struct SM80_16x8x128_S32U1U1S32_TN_XORPOPC uint32_t const& b0, uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) { -#if defined(CUTE_ARCH_MMA_SM80_ENABLED) +#if defined(CUTE_ARCH_MMA_B1_XOR_SM80_ENABLED) asm volatile( "mma.sync.aligned.m16n8k128.row.col.s32.b1.b1.s32.xor.popc " "{%0, %1, %2, %3}," @@ -2110,7 +2121,7 @@ struct SM80_16x8x256_S32U1U1S32_TN_XORPOPC uint32_t const& b0, uint32_t const& b1, uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) { -#if defined(CUTE_ARCH_MMA_SM80_ENABLED) +#if defined(CUTE_ARCH_MMA_B1_XOR_SM80_ENABLED) asm volatile( "mma.sync.aligned.m16n8k256.row.col.s32.b1.b1.s32.xor.popc " "{%0, %1, %2, %3}," diff --git a/include/cute/arch/mma_sm90.hpp b/include/cute/arch/mma_sm90.hpp index 25a98e6cb0..64561fa1f6 100644 --- a/include/cute/arch/mma_sm90.hpp +++ b/include/cute/arch/mma_sm90.hpp @@ -38,6 +38,7 @@ // Config #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) # define CUTE_ARCH_MMA_SM90_ENABLED +# define CUTE_ARCH_MMA_F64_SM90_ENABLED #endif //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -60,7 +61,7 @@ struct SM90_16x8x4_F64F64F64F64_TN double const& b0, double const& c0, double const& c1, double const& c2, double const& c3) { -#if defined(CUTE_ARCH_MMA_SM90_ENABLED) +#if defined(CUTE_ARCH_MMA_F64_SM90_ENABLED) asm volatile( "mma.sync.aligned.m16n8k4.row.col.f64.f64.f64.f64" "{%0, %1, %2, %3}," @@ -93,7 +94,7 @@ struct SM90_16x8x8_F64F64F64F64_TN double const& b0, double const& b1, double const& c0, double const& c1, double const& c2, double const& c3) { -#if defined(CUTE_ARCH_MMA_SM90_ENABLED) +#if defined(CUTE_ARCH_MMA_F64_SM90_ENABLED) asm volatile( "mma.sync.aligned.m16n8k8.row.col.f64.f64.f64.f64" "{%0, %1, %2, %3}," @@ -127,7 +128,7 @@ struct SM90_16x8x16_F64F64F64F64_TN double const& b0, double const& b1, double const& b2, double const& b3, double const& c0, double const& c1, double const& c2, double const& c3) { -#if defined(CUTE_ARCH_MMA_SM90_ENABLED) +#if defined(CUTE_ARCH_MMA_F64_SM90_ENABLED) asm volatile( "mma.sync.aligned.m16n8k16.row.col.f64.f64.f64.f64" "{%0, %1, %2, %3}," diff --git a/include/cute/arch/mma_sm90_desc.hpp b/include/cute/arch/mma_sm90_desc.hpp index ae647eb9ed..4c99b9ef7c 100644 --- a/include/cute/arch/mma_sm90_desc.hpp +++ b/include/cute/arch/mma_sm90_desc.hpp @@ -86,22 +86,22 @@ CUTE_HOST std::ostream& operator<<(std::ostream& os, LayoutType const& t) { union GmmaDescriptor { - CUTE_HOST_DEVICE constexpr + CUTE_HOST_DEVICE constexpr GmmaDescriptor() noexcept : desc_(0) {} - CUTE_HOST_DEVICE constexpr + CUTE_HOST_DEVICE constexpr GmmaDescriptor(uint64_t desc) noexcept : desc_(desc) {} - CUTE_HOST_DEVICE constexpr + CUTE_HOST_DEVICE constexpr GmmaDescriptor(GmmaDescriptor const& t) noexcept : desc_(t.desc_) {} - CUTE_HOST_DEVICE constexpr + CUTE_HOST_DEVICE constexpr GmmaDescriptor(GmmaDescriptor && t) noexcept : desc_(t.desc_) {} - - CUTE_HOST_DEVICE constexpr + + CUTE_HOST_DEVICE constexpr GmmaDescriptor& operator=(GmmaDescriptor const& t) noexcept { desc_ = t.desc_; return *this; } - CUTE_HOST_DEVICE constexpr + CUTE_HOST_DEVICE constexpr GmmaDescriptor& operator=(GmmaDescriptor && t) noexcept { desc_ = t.desc_; return *this; diff --git a/include/cute/atom/copy_traits_sm90_tma.hpp 
b/include/cute/atom/copy_traits_sm90_tma.hpp index 9b91f87ef4..d2617abdd8 100644 --- a/include/cute/atom/copy_traits_sm90_tma.hpp +++ b/include/cute/atom/copy_traits_sm90_tma.hpp @@ -38,9 +38,17 @@ #include #include +#include + namespace cute { +template +struct AuxTmaParams { + using GmemStrides = GmemStrides_; + GmemStrides g_stride_; +}; + ////////////////////////////////////////////////////////////////////////////// ///////////////////////////// TMA_LOAD /////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////// @@ -88,14 +96,14 @@ struct Copy_Traits { static_assert(is_smem::value, "Expected smem dst for SM90_TMA_LOAD"); - traits.copy_unpack_(raw_pointer_cast(dst.data()), src.data().coord_, tuple_seq{}); + traits.copy_unpack_(cute::raw_pointer_cast(dst.data()), src.data().coord_, tuple_seq{}); } }; // The non-executable SM90_TMA_LOAD with tma_desc and no tma_mbar // Use .with(tma_mbar) to construct an executable version -template -struct Copy_Traits +template +struct Copy_Traits { using ThrID = Layout<_1>; @@ -109,7 +117,8 @@ struct Copy_Traits // SM90_TMA_LOAD arguments TmaDescriptor tma_desc_; - GmemStrides g_stride_; + using AuxParams = AuxParams_; + AuxParams aux_params_; // Return TmaDescriptor/TensorMap CUTE_HOST_DEVICE constexpr @@ -133,8 +142,8 @@ struct Copy_Traits CUTE_HOST_DEVICE constexpr auto get_tma_tensor(GShape const& g_shape) const { - static_assert(is_congruent::value); - return make_counting_tensor(make_layout(g_shape, g_stride_)); + static_assert(is_congruent::value); + return make_counting_tensor(make_layout(g_shape, aux_params_.g_stride_)); } // Don't try to execute a copy with SM90_TMA_LOAD before calling .with() @@ -190,12 +199,12 @@ struct Copy_Traits { static_assert(is_smem::value, "Expected smem dst for SM90_TMA_LOAD_MULTICAST"); - traits.copy_unpack_(raw_pointer_cast(dst.data()), src.data().coord_, tuple_seq{}); + traits.copy_unpack_(cute::raw_pointer_cast(dst.data()), src.data().coord_, tuple_seq{}); } }; -template -struct Copy_Traits +template +struct Copy_Traits { using ThrID = Layout<_1>; @@ -209,7 +218,8 @@ struct Copy_Traits // SM90_TMA_LOAD_MULTICAST arguments TmaDescriptor tma_desc_; - GmemStrides g_stride_; + using AuxParams = AuxParams_; + AuxParams aux_params_; // Return TmaDescriptor/TensorMap CUTE_HOST_DEVICE constexpr @@ -230,8 +240,8 @@ struct Copy_Traits CUTE_HOST_DEVICE constexpr auto get_tma_tensor(GShape const& g_shape) const { - static_assert(is_congruent::value); - return make_counting_tensor(make_layout(g_shape, g_stride_)); + static_assert(is_congruent::value); + return make_counting_tensor(make_layout(g_shape, aux_params_.g_stride_)); } // Don't try to execute a copy with SM90_TMA_LOAD_MULTICAST before calling .with() @@ -248,8 +258,8 @@ struct Copy_Traits ////////////////////////////////////////////////////////////////////////////// // The executable SM90_TMA_STORE with tma_desc -template -struct Copy_Traits +template +struct Copy_Traits { using ThrID = Layout<_1>; @@ -263,7 +273,8 @@ struct Copy_Traits // SM90_TMA_STORE arguments TmaDescriptor tma_desc_; - GmemStrides g_stride_; + using AuxParams = AuxParams_; + AuxParams aux_params_; // Return TmaDescriptor/TensorMap CUTE_HOST_DEVICE constexpr @@ -277,8 +288,8 @@ struct Copy_Traits CUTE_HOST_DEVICE constexpr auto get_tma_tensor(GShape const& g_shape) const { - static_assert(is_congruent::value); - return make_counting_tensor(make_layout(g_shape, g_stride_)); + static_assert(is_congruent::value); + return 
make_counting_tensor(make_layout(g_shape, aux_params_.g_stride_)); } template @@ -305,7 +316,7 @@ struct Copy_Traits static_assert(is_smem::value, "Expected smem src for SM90_TMA_STORE"); //static_assert(is_gmem::value, "Expected gmem dst for SM90_TMA_STORE"); // TMA spoofed src tensor - traits.copy_unpack_(raw_pointer_cast(src.data()), dst.data().coord_, tuple_seq{}); + traits.copy_unpack_(cute::raw_pointer_cast(src.data()), dst.data().coord_, tuple_seq{}); } }; @@ -417,9 +428,78 @@ struct Copy_Traits namespace detail { -// Use a smem2gmode map to read through the GMEM tensor -// and construct a TMA Descriptor for the resulting instruction -template (OldLayout) +// s0:d0 _1:d1 => continue +// _1:d0 s1:d1 => replace_back s1:d1 +// s0:d0 s1:s0*d0 => replace_back s0*s1:d0 if s0*s1 <= 256 +// s0:d0 s1:d1 => append s1:d1 +// +// @pre OldShape and OldStride are flat +template +CUTE_HOST_DEVICE constexpr +auto +coalesce_256_impl(OldShape const& old_shape, OldStride const& old_stride, + NewShape const& new_shape, NewStride const& new_stride) +{ + if constexpr (I == rank_v) { + // Base case, we're done + if constexpr (is_constant<1, NewShape>::value) { + return Layout<_1,_0>{}; + } else { + return Layout{new_shape,new_stride}; + } + } else if constexpr (is_constant<1, decltype(get(old_shape))>::value) { + // shape(layout) == _1, skip it and continue + return coalesce_256_impl(old_shape, old_stride, new_shape, new_stride); + } else if constexpr (is_constant<1, NewShape>::value) { + // Replace our shape-1 with anything (Can only happen on input new_shape/new_stride) + return coalesce_256_impl(old_shape, old_stride, get(old_shape), get(old_stride)); + } else if constexpr (is_constant(old_stride) && + get(old_shape) * back(new_shape) <= Int<256>{})>::value) { + // Merge modes because the shapes and strides match and the merge is 256 or less + return coalesce_256_impl(old_shape, old_stride, + replace_back(new_shape, get(old_shape) * back(new_shape)), + new_stride); + } else { + // Can't replace or merge, so append a new mode + return coalesce_256_impl(old_shape, old_stride, + append(new_shape, get(old_shape)), + append(new_stride, get(old_stride))); + } + + CUTE_GCC_UNREACHABLE; +} + +// Combine all the modes that are possible to combine +// Does not respect the profile of the layout, but does preserve total size +template +CUTE_HOST_DEVICE constexpr +auto +coalesce_256(Layout const& layout) +{ + auto flat_shape = flatten(layout.shape()); + auto flat_stride = flatten(layout.stride()); + return coalesce_256_impl<1>(flat_shape, flat_stride, get<0>(flat_shape), get<0>(flat_stride)); +} + +template +CUTE_HOST_DEVICE constexpr +auto +coalesce_256(Tensor const& tensor) +{ + return make_tensor(tensor.data(), coalesce_256(tensor.layout())); +} + + +// Use a smem_inv_h to read through the GMEM tensor +// and construct a TMA Descriptor for the resulting instruction +// At the same time, construct the Tma Tensor's Stride to generate +// the TMA coordinates that the instruction consumes. +// +template CUTE_HOST_RTC @@ -428,63 +508,78 @@ make_tma_copy_desc(Tensor const& gtensor, // The original GM Layout const& smem_inv_h, // smem_idx to hier gmode Swizzle const& swizzle) // Swizzle fn on smem_idx { - using T = typename GEngine::value_type; - - // This is the gmem "vector" that corresponds to the smem vector in memory (smem_box_shape):(gmem_prob_stride) - Tensor tma_gstride = recast(gtensor.compose(smem_inv_h)); - - // If the sizes of smem_inv_h and tma_gstride don't match, then a non-trivial recast was performed. 
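// A minimal standalone sketch of the mode-merging rule that the new coalesce_256 helper above
// implements: adjacent (shape,stride) modes merge whenever the next stride equals shape*stride of
// the current mode and the merged extent stays within the 256-element TMA box limit. The function
// name coalesce_up_to_256 and the plain int representation are illustrative only, not CuTe APIs.
#include <cstdio>
#include <utility>
#include <vector>

// Merge flat (shape, stride) modes left-to-right under the <=256 constraint.
std::vector<std::pair<int,int>> coalesce_up_to_256(std::vector<std::pair<int,int>> const& modes) {
  std::vector<std::pair<int,int>> out;
  for (auto [s, d] : modes) {
    if (s == 1) continue;                                   // size-1 modes contribute nothing
    if (!out.empty() && d == out.back().first * out.back().second
                     && out.back().first * s <= 256) {
      out.back().first *= s;                                // merge: strides line up and extent fits in a box
    } else {
      out.push_back({s, d});                                // otherwise start a new mode
    }
  }
  if (out.empty()) out.push_back({1, 0});                   // degenerate case: everything was size-1
  return out;
}

int main() {
  // Example: (4,8,2):(1,4,64) in elements. The first two modes merge into 32:1 (4*8 <= 256);
  // the last mode cannot merge because its stride 64 != 32*1, so the result is (32,2):(1,64).
  auto result = coalesce_up_to_256({{4,1},{8,4},{2,64}});
  for (auto [s, d] : result) { std::printf("%d:%d ", s, d); }  // prints "32:1 2:64"
  std::printf("\n");
  return 0;
}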
- // In that case, require that the recasted modes all have size-1 so TMA can identity them and skip them. - for_each(zip(flatten(shape(smem_inv_h)), flatten(shape(tma_gstride))), [] (auto s_and_g) { - auto [s,g] = s_and_g; - CUTE_STATIC_ASSERT_V(s == g or g == Int<1>{}, - "A non-trivial recast was performed, but TMA cannot identify which modes to leave out."); - }); + // The smem vector is the same units as gtensor, so compose first and then recast + // tma_val_idx:gmem_strides + Tensor tile_gstride = recast(gtensor.compose(smem_inv_h)); + // Coalesce modes up to size-256 (the maximum TMA box extent in units of TmaInternalType) + // tma_box_shape:gmem_strides + Tensor tma_gstride = coalesce_256(tile_gstride); // Perform the tiling to the gmem vector again, but with indirections to the gtensor modes auto gbasis = make_identity_layout(shape(gtensor)); - auto tma_gbasis_tile_tmp = gbasis.compose(smem_inv_h); - // Instead of the recast (gbasis doesn't have type info), replace the shape with the already-recasted shape and coalesce out any size-1 modes - auto tma_gbasis_tile = coalesce(make_layout(shape(tma_gstride), stride(tma_gbasis_tile_tmp))); + auto tile_gbasis_tmp = gbasis.compose(smem_inv_h); + + // Instead of the recast (gbasis doesn't have type info), replace the shape with the already-recasted shape + // tma_box_shape:gmem_mode + auto tile_gbasis = make_layout(shape(tile_gstride), stride(tile_gbasis_tmp)); // Recast the original tensor for shape inspections - auto glayout_T = recast(gtensor).layout(); + auto gtensor_T = recast(gtensor); - // Find missing bases that don't belong to a size-1 mode of the recast input + // Find missing bases that don't appear in tile_gbasis // NOTE This is essentially ArithmeticTuple complement... - // NOTE in persuit of implementing an ArithmeticTuple logical_divide for smem_inv_h - auto tma_gbasis_full = fold(zip(flatten(shape(glayout_T)), flatten(stride(gbasis))), tma_gbasis_tile, - [](auto tma_g, auto s_and_d) { - auto [s,d] = s_and_d; - auto k = find(stride(tma_g), d); // Find the basis in tma_gstride - if constexpr (decltype(k != rank(tma_g) || is_constant<1, decltype(s)>{})::value) { - // If d was found or s is static-1, then don't append - return tma_g; + // NOTE in pursuit of implementing an ArithmeticTuple logical_divide for smem_inv_h + auto tile_gbasis_remaining_stride = filter_tuple(flatten(shape (gtensor_T)), flatten(stride(gtensor_T)), + flatten(stride(gbasis)), + [&](auto s, auto d, auto e) + { + if constexpr (is_constant<1, decltype(s)>::value || is_constant<0, decltype(d)>::value) { + return cute::tuple<>{}; // If size-1 or stride-0, then don't append } else { - // Else, append the missing basis - return append(tma_g, make_layout(Int<1>{}, d)); + using E = decltype(e); + auto has_e = any_of(stride(tile_gbasis), [] (auto tb) { return tb == E{}; }); + if constexpr (decltype(has_e)::value) { + return cute::tuple<>{}; // If d was found, then don't append + } else { + return cute::tuple(e); // Else, this is missing so append + } } }); + auto tile_gbasis_remaining_rank = rank(tile_gbasis_remaining_stride); + + // "Coalesce" the tile basis into a compatible shape with the tma + auto tma_gbasis_tile = tile_gbasis.compose(make_layout(wrap(shape(tma_gstride)))); - // Group the trailing modes to make this max rank-5 + // Append the remaining basis modes that contribute to the TMA with size-1 + auto tma_gbasis_full = make_layout(tuple_cat(wrap( shape(tma_gbasis_tile)), wrap(repeat(Int<1>{}))), + tuple_cat(wrap(stride(tma_gbasis_tile)), 
wrap(tile_gbasis_remaining_stride))); + + // Group the trailing modes to make this max rank-5 -- TMA rank limitation + // tma_box_shape:gmem_mode auto tma_gbasis = group(tma_gbasis_full); #if 0 - print("gtensor : "); print(gtensor); print("\n"); print("smem_inv_h : "); print(smem_inv_h); print("\n"); + print("gtensor : "); print(gtensor); print("\n"); + print("tile_gstride : "); print(tile_gstride); print("\n"); print("tma_gstride : "); print(tma_gstride); print("\n"); print("gbasis : "); print(gbasis); print("\n"); - print("tma_gb_tile : "); print(tma_gbasis_tile ); print("\n"); + print("tile_gbasis : "); print(tile_gbasis); print("\n"); print("tma_gbasis : "); print(tma_gbasis); print("\n"); #endif + // + // TMA desc creation + // + constexpr int tma_dim = decltype(rank(tma_gbasis))::value; // // TMA gmem desc info // - void* gmem_address = (void*) raw_pointer_cast(gtensor.data()); + void* gmem_address = (void*) raw_pointer_cast(gtensor_T.data()); + auto gmem_layout = gtensor_T.layout(); cute::array gmem_prob_shape = {1,1,1,1,1}; cute::array gmem_prob_stride = {0,0,0,0,0}; @@ -492,12 +587,12 @@ make_tma_copy_desc(Tensor const& gtensor, // The original GM for_each(make_seq{}, [&](auto i) { for_each(stride(tma_gbasis), [&](auto ej) { // Problem stride - uint64_t stride_j = basis_get(ej, stride(glayout_T)) * sizeof(T); + uint64_t stride_j = ceil_div(basis_get(ej, stride(gmem_layout)) * sizeof_bits_v, 8); uint64_t old_stride = gmem_prob_stride[i]; gmem_prob_stride[i] = gcd(gmem_prob_stride[i], stride_j); // Problem shape - uint64_t shape_j = basis_get(ej, shape(glayout_T)); + uint64_t shape_j = basis_get(ej, shape(gmem_layout)); if (gmem_prob_stride[i] != 0) { // Recurrence: g_shape = (s_i - 1) * (d_i / gcd_j d_j) + 1 gmem_prob_shape[i] = (gmem_prob_shape[i]-1) * (old_stride / gmem_prob_stride[i]) @@ -522,8 +617,8 @@ make_tma_copy_desc(Tensor const& gtensor, // The original GM assert(gmem_prob_shape[4] >= (uint64_t(1))); // Size must be min 1 assert(gmem_prob_shape[4] <= (uint64_t(1) << 32)); // Size must be max 2^32 - // TMA descriptor does not store the zeroth stride and assumes it is sizeof(T) == one element. - assert(gmem_prob_stride[0] == sizeof(T) && "Majorness of smem doesn't match majorness of gmem"); + // TMA descriptor does not store the zeroth stride and assumes it is 1 (TmaInternalType element). 
+ assert(gmem_prob_stride[0] == sizeof(TmaInternalType) && "Majorness of smem doesn't match majorness of gmem"); assert((gmem_prob_stride[1]) < (uint64_t(1) << 40)); // Stride must be max 2^40 assert((gmem_prob_stride[1] & 0b1111) == 0); // Stride must be multiple of 16B (128b) @@ -545,14 +640,16 @@ make_tma_copy_desc(Tensor const& gtensor, // The original GM smem_box_shape[i] *= size(tma_gbasis); }); - assert(smem_box_shape[0] >= (uint64_t(1))); // Size must be min 1 - assert(smem_box_shape[0] <= (uint64_t(1) << 8)); // Size must be max 2^8 = 256 - assert(smem_box_shape[0] >= (uint64_t(1))); // Size must be min 1 - assert(smem_box_shape[0] <= (uint64_t(1) << 8)); // Size must be max 2^8 = 256 - assert(smem_box_shape[0] >= (uint64_t(1))); // Size must be min 1 - assert(smem_box_shape[0] <= (uint64_t(1) << 8)); // Size must be max 2^8 = 256 - assert(smem_box_shape[0] >= (uint64_t(1))); // Size must be min 1 - assert(smem_box_shape[0] <= (uint64_t(1) << 8)); // Size must be max 2^8 = 256 + assert(smem_box_shape[0] >= (uint32_t(1))); // Size must be min 1 + assert(smem_box_shape[0] <= (uint32_t(1) << 8)); // Size must be max 2^8 = 256 + assert(smem_box_shape[1] >= (uint32_t(1))); // Size must be min 1 + assert(smem_box_shape[1] <= (uint32_t(1) << 8)); // Size must be max 2^8 = 256 + assert(smem_box_shape[2] >= (uint32_t(1))); // Size must be min 1 + assert(smem_box_shape[2] <= (uint32_t(1) << 8)); // Size must be max 2^8 = 256 + assert(smem_box_shape[3] >= (uint32_t(1))); // Size must be min 1 + assert(smem_box_shape[3] <= (uint32_t(1) << 8)); // Size must be max 2^8 = 256 + assert(smem_box_shape[4] >= (uint32_t(1))); // Size must be min 1 + assert(smem_box_shape[4] <= (uint32_t(1) << 8)); // Size must be max 2^8 = 256 assert(smem_box_stride[0] >= (uint32_t(1))); // Stride must be min 1 assert(smem_box_stride[0] <= (uint32_t(8))); // Stride must be max 2^3 = 8 @@ -565,88 +662,101 @@ make_tma_copy_desc(Tensor const& gtensor, // The original GM assert(smem_box_stride[4] >= (uint32_t(1))); // Stride must be min 1 assert(smem_box_stride[4] <= (uint32_t(8))); // Stride must be max 2^3 = 8 - // - // Construct the descriptor - // - - TmaDescriptor tma_desc = {0}; - - // - // TMA general info - // - -#if (__CUDACC_VER_MAJOR__ >= 12) && !defined(__CUDACC_RTC__) - - CUtensorMapDataType tma_format = TMA::to_CUtensorMapDataType(); - CUtensorMapInterleave tma_interleave = CU_TENSOR_MAP_INTERLEAVE_NONE; - CUtensorMapL2promotion tma_l2Promotion = CU_TENSOR_MAP_L2_PROMOTION_L2_128B; - CUtensorMapFloatOOBfill tma_oobFill = CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE; - - // TMA smem swizzle type - CUtensorMapSwizzle smem_swizzle = TMA::to_CUtensorMapSwizzle(get_tma_swizzle_bits(swizzle)); - CUresult result = cuTensorMapEncodeTiled( - &tma_desc, - tma_format, - tma_dim, - gmem_address, - gmem_prob_shape.data(), - gmem_prob_stride.data() + 1, // gmem_prob_stride[0] implicitly 1 - smem_box_shape.data(), - smem_box_stride.data(), - tma_interleave, - smem_swizzle, - tma_l2Promotion, - tma_oobFill); - - if (result != CUDA_SUCCESS) { - std::cerr << "TMA Desc Addr: " << &tma_desc - << "\nformat " << tma_format - << "\ndim " << tma_dim - << "\ngmem_address " << gmem_address - << "\nglobalDim " << gmem_prob_shape - << "\nglobalStrides " << gmem_prob_stride - << "\nboxDim " << smem_box_shape - << "\nelementStrides " << smem_box_stride - << "\ninterleave " << tma_interleave - << "\nswizzle " << smem_swizzle - << "\nl2Promotion " << tma_l2Promotion - << "\noobFill " << tma_oobFill << std::endl; - std::cerr << "Error: Failed to 
initialize the TMA descriptor " << result << std::endl; - assert(false); - } - -#endif // (__CUDACC_VER_MAJOR__ >= 12) && !defined(__CUDACC_RTC__) + // + // Construct the descriptor + // + + TmaDescriptor tma_desc = {0}; + + // + // TMA general info + // + + #if (__CUDACC_VER_MAJOR__ >= 12) && !defined(__CUDACC_RTC__) + + CUtensorMapDataType tma_format = TMA::to_CUtensorMapDataType(); + CUtensorMapInterleave tma_interleave = CU_TENSOR_MAP_INTERLEAVE_NONE; + CUtensorMapL2promotion tma_l2Promotion = CU_TENSOR_MAP_L2_PROMOTION_L2_128B; + CUtensorMapFloatOOBfill tma_oobFill = CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE; + + // TMA smem swizzle type + CUtensorMapSwizzle smem_swizzle = TMA::to_CUtensorMapSwizzle(get_tma_swizzle_bits(swizzle)); + CUresult result = cuTensorMapEncodeTiled( + &tma_desc, + tma_format, + tma_dim, + gmem_address, + gmem_prob_shape.data(), + gmem_prob_stride.data() + 1, // gmem_prob_stride[0] implicitly 1 + smem_box_shape.data(), + smem_box_stride.data(), + tma_interleave, + smem_swizzle, + tma_l2Promotion, + tma_oobFill); + + if (result != CUDA_SUCCESS) { + std::cerr << "TMA Desc Addr: " << &tma_desc + << "\nformat " << tma_format + << "\ndim " << tma_dim + << "\ngmem_address " << gmem_address + << "\nglobalDim " << gmem_prob_shape + << "\nglobalStrides " << gmem_prob_stride + << "\nboxDim " << smem_box_shape + << "\nelementStrides " << smem_box_stride + << "\ninterleave " << tma_interleave + << "\nswizzle " << smem_swizzle + << "\nl2Promotion " << tma_l2Promotion + << "\noobFill " << tma_oobFill << std::endl; + std::cerr << "Error: Failed to initialize the TMA descriptor " << result << std::endl; + assert(false); + } + + #endif // (__CUDACC_VER_MAJOR__ >= 12) && !defined(__CUDACC_RTC__) // Finally, get the inverse permutation of the E bases for the mocked gmem stride // NOTE This is essentially ArithmeticTuple inverse... 
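For orientation, a hedged host-side sketch of the same driver entry point the hunk above calls, encoding a plain 2-D descriptor. The 128x256 float tensor, the 64x64 box, and the swizzle/L2/OOB settings are invented; the call needs a CUDA 12 toolkit and a TMA-capable driver, so treat it as a sketch rather than a guaranteed-working probe.

// Minimal host-side TMA descriptor encode for a row-major m x n float matrix (values invented).
#include <cstdio>
#include <cuda.h>
#include <cuda_runtime.h>

int main() {
  int m = 128, n = 256;
  float* gmem = nullptr;
  cudaMalloc(&gmem, sizeof(float) * m * n);       // also initializes the driver context

  CUtensorMap desc = {};
  cuuint64_t global_dim[2]    = {cuuint64_t(n), cuuint64_t(m)};   // dim 0 is the contiguous mode
  cuuint64_t global_stride[1] = {cuuint64_t(n) * sizeof(float)};  // byte stride of dim 1; multiple of 16
  cuuint32_t box_dim[2]       = {64, 64};                         // per-dim box size, each <= 256
  cuuint32_t elem_stride[2]   = {1, 1};                           // element strides in [1, 8]

  CUresult res = cuTensorMapEncodeTiled(
      &desc,
      CU_TENSOR_MAP_DATA_TYPE_FLOAT32,
      /*tensorRank=*/2,
      gmem,
      global_dim,
      global_stride,              // rank-1 entries: stride of dim 0 is implicitly one element
      box_dim,
      elem_stride,
      CU_TENSOR_MAP_INTERLEAVE_NONE,
      CU_TENSOR_MAP_SWIZZLE_NONE,
      CU_TENSOR_MAP_L2_PROMOTION_L2_128B,
      CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE);
  std::printf("cuTensorMapEncodeTiled returned %d\n", int(res));

  cudaFree(gmem);
  return 0;
}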
auto gmem_stride_bases = transform_leaf(stride(gbasis), [&](auto ei) { - auto si = basis_get(ei, shape(glayout_T)); - auto di = basis_get(ei, stride(glayout_T)); - auto tma_gbasis_stride = stride(tma_gbasis); - // Find j such that E is in stride(tma_gbasis) - [[maybe_unused]] auto j = find_if(tma_gbasis_stride, [&](auto tma_stride_j) { return any_of(tma_stride_j, [&](auto dj) { return dj == ei; }); }); - // Return the TMA basis this gmode contributes to - if constexpr (is_constant<1, decltype(si)>::value || decltype(j == rank(tma_gbasis_stride))::value) { - return Int<0>{}; // Return arithmetic identity -- no contribution to the TMA - } else - if constexpr (decltype(rank(tma_gbasis_stride) == Int<1>{})::value) { - return E{}; // We know that the scale factor is Int<1>{} + auto si = basis_get(ei, shape(gmem_layout)); + auto di = basis_get(ei, stride(gmem_layout)); + if constexpr (is_constant<1, decltype(si)>::value || is_constant<0, decltype(di)>::value) { + return Int<0>{}; // If size-1 or stride-0, return arithmetic identity -- no contribution to the TMA } else { - return E{} * int32_t(di * sizeof(T) / cute::max(gmem_prob_stride[j], 16)); + auto tma_gbasis_stride = stride(tma_gbasis); + // Find j such that E is in stride(tma_gbasis) + using EI = decltype(ei); + [[maybe_unused]] auto j = find_if(tma_gbasis_stride, [&](auto tma_stride_j) { return any_of(tma_stride_j, [&](auto dj) { return dj == EI{}; }); }); + if constexpr (decltype(j == rank(tma_gbasis_stride))::value) { + return Int<0>{}; // If not-found, return arithmetic identity -- no contribution to the TMA + } else + if constexpr (decltype(j == Int<0>{})::value) { + auto scale = ratio(size(tma_gstride), size(smem_inv_h)) * basis_get(ei, stride(gtensor)); + return E{} * scale; // Return TMA Coord basis -- with a recast scale factor + } else + if constexpr (decltype(rank(tma_gbasis_stride) == Int<1>{})::value) { + return E{}; // Return TMA Coord basis -- known scale of Int<1>{} + } else { + int32_t scale = ceil_div(int32_t(di * sizeof_bits_v / cute::max(gmem_prob_stride[j], 16)), 8); + return E{} * scale; // Return TMA Coord basis -- with a dynamic scale factor + } } }); -#if 0 - print("gmem_stride_bases : "); print(gmem_stride_bases); print("\n"); -#endif + #if 0 + print("tma_gbasis : "); print(gmem_stride_bases); print("\n"); + #endif - return cute::make_tuple(tma_desc, gmem_stride_bases); + using AuxParams = AuxTmaParams; + return cute::make_tuple(tma_desc, AuxParams{gmem_stride_bases}); } // The "logical TMA tid" is a map from the CTA rank to its logical id // within the instruction. It works like a mask or ordering on the // CTAs. For non-multicast TMA, all CTAs should map to 0. For // multicast TMA of size 4, CTAs will be mapped to {0,1,2,3}. 
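The comment above describes the CTA-rank to logical-TMA-tid map. The toy snippet below, assuming the CuTe headers are on the include path and an invented cluster of four CTAs, shows the two layouts that realize it: a stride-0 layout for non-multicast (every CTA maps to 0) and an identity layout for 4-way multicast.

// CTA rank -> logical TMA tid, expressed as plain CuTe layouts (cluster size 4 is invented).
#include <cute/layout.hpp>
#include <cstdio>

int main() {
  using namespace cute;

  auto no_mcast = make_layout(make_shape(Int<4>{}), make_stride(Int<0>{}));  // all CTAs -> tid 0
  auto mcast4   = make_layout(make_shape(Int<4>{}), make_stride(Int<1>{}));  // CTAs -> {0,1,2,3}

  for (int cta = 0; cta < 4; ++cta) {
    std::printf("cta %d -> no_mcast tid %d, mcast tid %d\n",
                cta, int(no_mcast(cta)), int(mcast4(cta)));
  }
  return 0;
}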
-template const& gtensor, // Full GMEM Tensor SLayout const& slayout, // CTA Tile of SMEM Layout const& cta_t_map, // T: CTA thr idx -> logical TMA tid - Layout const& cta_v_map) // V: CTA val idx -> gmem coord + Layout const& cta_v_map) // V: CTA val idx -> gmem mode { // // TMA parameter checking @@ -673,18 +783,19 @@ make_tma_copy_tiled(CopyOp, // // Invert the smem to get the largest contiguous vector in the smem layout + // smem idx -> smem coord auto inv_smem_layout = right_inverse(get_nonswizzle_portion(slayout)); - // trunc_smem_idx -> trunc_smem_coord - // Map from smem idx to a gmem mode + // Compose with the V-Map to convert smem coord (CTA val idx) to gmem mode + // smem idx -> gmem mode auto sidx_to_gmode = coalesce(composition(cta_v_map, inv_smem_layout)); #if 0 - print("g_layout : "); print(gtensor.layout()); print("\n"); + print("g_tensor : "); print(gtensor); print("\n"); print("s_layout : "); print(slayout); print("\n"); print("cta_t_map : "); print(cta_t_map); print("\n"); print("cta_v_map : "); print(cta_v_map); print("\n"); - print("inv_smem : "); print(inv_smem_layout); print("\n"); + print("inv_s_layout : "); print(inv_smem_layout); print("\n"); print("sidx_to_gmode : "); print(sidx_to_gmode); print("\n"); #endif @@ -693,9 +804,11 @@ make_tma_copy_tiled(CopyOp, // // Generate a TupleBasis for the gtensor + // gmem coord -> gmem coord auto glayout_basis = make_identity_layout(shape(gtensor)); // Tile the modes of gtensor with the truncated cta_v_map o inv_smem_layout_trunc + // smem idx -> gmem coord auto tma_layout_full = flatten(composition(glayout_basis, sidx_to_gmode)); // Truncate any incompatibilities -- no starting in the middle of gmodes @@ -704,61 +817,60 @@ make_tma_copy_tiled(CopyOp, return not is_constant<1,decltype(v)>{}; }); static_assert(smem_rank > 0, "Could not find a common tile-gmem vectorization. Does the Tile select out major GMEM modes?"); - // TMA uses a maximum of 5 modes - // If the gtensor has more than 5 modes, we need to reserve the last TMA-mode as a "multimode" - constexpr int smem_tma_rank = cute::min(int(smem_rank), (rank(tma_layout_full) > Int<5>{} ? 
4 : 5)); // Keep only the static-1 basis modes into gmem - auto tma_layout_trunc = take<0,smem_tma_rank>(tma_layout_full); + auto tma_layout_trunc = take<0,smem_rank>(tma_layout_full); - // Split according to the portion each multicast CTA will be responsible for - auto tma_layout_vt = logical_divide(tma_layout_trunc, shape_div(size(tma_layout_trunc), cosize(cta_t_map))); + // Keep only the portion each multicast CTA will be responsible for + auto tma_layout_v = composition(tma_layout_trunc, shape_div(size(tma_layout_trunc), cosize(cta_t_map))); #if 0 print("glayout_basis : "); print(glayout_basis); print("\n"); print("tma_layout_full : "); print(tma_layout_full); print("\n"); print("tma_layout_trunc: "); print(tma_layout_trunc); print("\n"); - print("tma_layout_vt : "); print(tma_layout_vt); print("\n"); + print("tma_layout_v : "); print(tma_layout_v); print("\n"); #endif // - // Construct the TMA Desc and GMEM mode ordering + // Construct the TMA Desc and the strides of the TMA Tensor // - auto [tma_desc, gmem_stride_bases] = detail::make_tma_copy_desc(gtensor, layout<0>(tma_layout_vt), get_swizzle_portion(slayout)); + auto [tma_desc, aux_params] = detail::make_tma_copy_desc(gtensor, + tma_layout_v, + get_swizzle_portion(slayout)); // // Construct the Copy_Traits // using T = typename GEngine::value_type; - constexpr int num_bits_per_tma = decltype(size<0>(tma_layout_vt))::value * sizeof(T) * 8; - using Traits = Copy_Traits, decltype(gmem_stride_bases)>; + constexpr int num_bits_per_tma = decltype(size(tma_layout_trunc))::value * sizeof_bits_v; + using Traits = Copy_Traits, decltype(aux_params)>; + using Atom = Copy_Atom; + + Traits tma_traits{tma_desc, aux_params}; #if 0 - print("num_bits : "); print(NumBitsPerTMA{}); print("\n"); - print("g_stride_bases: "); print(gmem_stride_bases); print("\n"); + print("num_bits_per_tma : "); print(num_bits_per_tma); print("\n"); + print("g_stride_bases : "); print(tma_traits.aux_params_.g_stride_); print("\n"); #endif - Traits tma_traits{tma_desc, gmem_stride_bases}; - // // Construct the TiledCopy // auto cta_tiler = product_each(shape(cta_v_map)); - // (CTA V, CTA T) -> smem_coord - auto layout_vt = composition(inv_smem_layout, make_layout(shape(tma_layout_vt))); - // Scale that up to cover all of the smem_coords - // - // The smem vector might not cover all of the tile, - // so multiply it up to cover the entire tile. - // "T" here (the parallel index) is a CTA index. 
- auto layout_VT = tile_to_shape(layout_vt, make_shape(size(cta_v_map)/size<1>(layout_vt), size<1>(layout_vt))); - // Flip it and change the domain of the T from logical thr to thr_idx - auto layout_TV = make_layout(composition(layout<1>(layout_VT), cta_t_map), layout<0>(layout_VT)); + // CTA V -> smem_coord + auto layout_v = composition(inv_smem_layout, size(tma_layout_trunc)); + auto layout_V = tile_to_shape(make_layout(layout_v), size(cta_v_map)); + // CTA T -> smem idx + auto layout_t = make_layout(cosize(cta_t_map), shape_div(size(tma_layout_trunc), cosize(cta_t_map))); + // CTA TID -> smem coord + auto layout_T = composition(inv_smem_layout, composition(layout_t, cta_t_map)); + // Combine with the T mapping + auto layout_TV = make_layout(layout_T, layout_V); #if 0 print("cta_tiler : "); print(cta_tiler); print("\n"); @@ -766,8 +878,7 @@ make_tma_copy_tiled(CopyOp, print("layout_TV : "); print(layout_TV); print("\n"); #endif - using T = typename GEngine::value_type; - return TiledCopy, decltype(layout_TV), decltype(cta_tiler)>{tma_traits}; + return TiledCopy{tma_traits}; } } // end namespace detail @@ -844,7 +955,8 @@ make_tma_copy_tiled(CopyOp, copy(tma.with(barrier, mcast_mask), tAgA, tAsA); // copy with supporting TMA params */ -template (copy_op, + gtensor, + slayout, + make_layout(cluster_size), + make_identity_layout(cta_tile)); } // Explicit defaulting +template +CUTE_HOST_RTC +auto +make_tma_copy(CopyOp const& copy_op, + Tensor const& gtensor, + SLayout const& slayout, + CTA_Tile const& cta_tile, + Cluster_Size const& cluster_size) +{ + using TmaInternalType = typename GEngine::value_type; + return make_tma_copy(copy_op, + gtensor, + slayout, + cta_tile, + cluster_size); +} + template diff --git a/include/cute/atom/mma_atom.hpp b/include/cute/atom/mma_atom.hpp index 844d653eeb..68bd290e6d 100644 --- a/include/cute/atom/mma_atom.hpp +++ b/include/cute/atom/mma_atom.hpp @@ -155,7 +155,7 @@ struct MMA_Atom> if constexpr (has_dereference::value) { // If the intended FrgTypeA is a view (of the current tensor), forward the whole - static_assert(is_same::value_type>::value, "Expecting ValTypeA type"); + static_assert(is_same, typename remove_cvref_t::value_type>::value, "Expecting ValTypeA type"); return make_tensor(std::forward(atensor)); } else { // Else, the intended FrgTypeA is a value type, construct a new tensor with a fragment layout diff --git a/include/cute/atom/mma_traits_sm75.hpp b/include/cute/atom/mma_traits_sm75.hpp index 405e871fd2..63f834664b 100644 --- a/include/cute/atom/mma_traits_sm75.hpp +++ b/include/cute/atom/mma_traits_sm75.hpp @@ -49,11 +49,11 @@ struct MMA_Traits using Shape_MNK = Shape<_16,_8,_8>; using ThrID = Layout<_32>; using ALayout = Layout,Shape < _2,_2>>, - Stride,Stride<_16,_1>>>; + Stride,Stride<_16,_8>>>; using BLayout = Layout,_2>, Stride,_8>>; using CLayout = Layout,Shape < _2,_2>>, - Stride,Stride<_16,_1>>>; + Stride,Stride<_16,_8>>>; }; /////////////////////////////////////////////////////////////////////////////// diff --git a/include/cute/config.hpp b/include/cute/config.hpp index 4a12f1c584..ba2504cd22 100644 --- a/include/cute/config.hpp +++ b/include/cute/config.hpp @@ -30,7 +30,7 @@ **************************************************************************************************/ #pragma once -#if defined(__CUDA_ARCH__) || defined(_NVHPC_CUDA) || defined(__clang__) +#if defined(__CUDA_ARCH__) || defined(_NVHPC_CUDA) # define CUTE_HOST_DEVICE __forceinline__ __host__ __device__ # define CUTE_DEVICE __forceinline__ __device__ # define 
CUTE_HOST __forceinline__ __host__ diff --git a/include/cute/container/bit_field.hpp b/include/cute/container/bit_field.hpp index 5398e3271e..0cd3e4fe5d 100644 --- a/include/cute/container/bit_field.hpp +++ b/include/cute/container/bit_field.hpp @@ -72,8 +72,16 @@ struct bit_field // Number of bits in data_[idx] used for NumBits if straddling, else 0 static constexpr uint32_t bit_hi = (idx + 1 < N) ? (storage_type_bits - bit_lo) : 0; +private: + // MSVC issues warning C4293 ("shift count negative or too big, undefined behavior") + // if we use NumBits directly in the shift expression, even if the shift occurs + // in the branch of a ternary expression where NumBits is known to be less than + // the number of bits of the value being shifted. + static constexpr uint32_t MollifiedNumBits = NumBits > 63u ? 63u : NumBits; +public: + // NumBits mask - static constexpr value_type mask = (NumBits < 64) ? ((uint64_t(1) << NumBits) - 1) : uint64_t(-1); + static constexpr value_type mask = (NumBits < 64u) ? ((uint64_t(1) << MollifiedNumBits) - 1) : uint64_t(-1); // NumBits mask for BitStart static constexpr storage_type mask_lo = storage_type(mask) << bit_lo; // NumBits mask for leftover bits in data_[idx+1] if straddling, else 0 diff --git a/include/cute/container/tuple.hpp b/include/cute/container/tuple.hpp index 3455a41620..75829f4520 100644 --- a/include/cute/container/tuple.hpp +++ b/include/cute/container/tuple.hpp @@ -76,6 +76,10 @@ namespace detail template ::value> struct EBO; +template +CUTE_HOST_DEVICE constexpr C findt(EBO const&) +{ return {}; } + // Specialization for types T that have no data; // the "static tuple leaf." Valid T here include // integral_constant, Int, @@ -218,6 +222,20 @@ get(tuple&& t) noexcept return detail::getv(static_cast&&>(t)); } +// +// find a type X within a cute::tuple +// Requires X to be unique in tuple +// Returns a static integer +// + +template +CUTE_HOST_DEVICE constexpr +auto +find(tuple const& t) noexcept +{ + return detail::findt(t); +} + // // Custom is_tuple trait simply checks the existence of tuple_size // and assumes std::get(.), std::tuple_element @@ -225,7 +243,7 @@ get(tuple&& t) noexcept namespace detail { template -auto has_tuple_size( T*) -> integral_constant::value>; +auto has_tuple_size( T*) -> bool_constant<(0 <= tuple_size::value)>; auto has_tuple_size(...) -> false_type; } // end namespace detail @@ -347,6 +365,14 @@ tuple_cat(T0 const& t0, T1 const& t1, T2 const& t2, T3 const& t3, T4 const& t4, return cute::make_tuple(get(t0)..., get(t1)..., get(t2)..., get(t3)..., get(t4)...); } +template +struct tuple_cat_static; + +template +struct tuple_cat_static, tuple> { + using type = tuple; +}; + } // end namespace detail CUTE_HOST_DEVICE constexpr @@ -370,9 +396,15 @@ CUTE_HOST_DEVICE constexpr auto tuple_cat(T0 const& t0, T1 const& t1) { - return detail::tuple_cat(t0, t1, + if constexpr (is_static::value && is_static::value && + is_tuple::value && is_tuple::value) { + return typename detail::tuple_cat_static::type{}; + } else + { + return detail::tuple_cat(t0, t1, make_index_sequence::value>{}, make_index_sequence::value>{}); + } } template @@ -416,7 +448,7 @@ CUTE_HOST_DEVICE constexpr auto tuple_cat(T0 const& t0, T1 const& t1, T2 const& t2, T3 const& t3, T4 const& t4, T5 const& t5, Ts const&... 
ts) { - return cute::tuple_cat(cute::tuple_cat(t0,t1,t2,t3,t4), t5, ts...); + return cute::tuple_cat(cute::tuple_cat(t0,t1,t2,t3,t4), cute::tuple_cat(t5, ts...)); } #endif diff --git a/include/cute/int_tuple.hpp b/include/cute/int_tuple.hpp index 7875ac1581..4497034fcf 100644 --- a/include/cute/int_tuple.hpp +++ b/include/cute/int_tuple.hpp @@ -37,29 +37,20 @@ #include #include -namespace cute -{ - -template -using IntTuple = cute::tuple; +/** IntTuple is an integer or a tuple of IntTuples. + * This file holds utilities for working with IntTuples, + * but does not hold a concrete concept or class of IntTuple. + */ -// Construct an IntTuple with all value-elements -template -CUTE_HOST_DEVICE constexpr -IntTuple -make_int_tuple(Ts const&... t) +namespace cute { - return {t...}; -} -// CuTe does not treat integers as tuples. -// For example, is_tuple is false, and tuple_size doesn't compile. -// Nevertheless, CuTe defines rank(Integral) as 1 -// (where "Integral" is a shorthand for either run-time integers -// or CuTe's compile-time integer constants), -// so therefore get<0>(Integral) just returns its input. +// Implementation of get<0>(Integral). +// Even though is_tuple is false and tuple_size doesn't compile, +// CuTe defines rank(Integral) as 1, so it's useful for get<0>(Integral) to return its input template >::value)> -CUTE_HOST_DEVICE constexpr decltype(auto) +CUTE_HOST_DEVICE constexpr +decltype(auto) get(T&& t) noexcept { static_assert(I == 0, "Index out of range"); @@ -67,23 +58,12 @@ get(T&& t) noexcept } // Custom recursive get for anything that implements get(.) (for a single integer I). -template -CUTE_HOST_DEVICE constexpr decltype(auto) -get(Tuple&& t) noexcept -{ - using get_I0_result_t = cute::remove_cvref_t(static_cast(t)))>; - if constexpr (cute::is_integral::value) { - // Help MSVC deduce that the inner get(...) call is not a "local variable or temporary." - // The above if constexpr test repeats the constraint on the above get(T&&) overload. - // get<0, 0, ..., 0>(t) for cute::integral (either one of the built-in integer types like int, - // or one of CuTe's compile-time constant types) t, and for one or more zeros, just returns t. 
- static_assert(I1 == 0, "Index I1 is out of range"); - static_assert(((Is == 0) && ...), "At least one index in Is is out of range"); - return get(static_cast(t)); - } - else { - return get(get(static_cast(t))); - } +template +CUTE_HOST_DEVICE constexpr +decltype(auto) +get(T&& t) noexcept +{ + return get(get(static_cast(t))); } // @@ -347,6 +327,16 @@ ceil_div(IntTupleA const& a, IntTupleB const& b) } /** Division for Shapes + * Case Tuple Tuple: + * Perform shape_div element-wise + * Case Tuple Int: + * Fold the division of b across each element of a + * Example: shape_div((4,5,6),40) -> shape_div((1,5,6),10) -> shape_div((1,1,6),2) -> (1,1,3) + * Case Int Tuple: + * Return shape_div(a, product(b)) + * Case Int Int: + * Enforce the divisibility condition a % b == 0 || b % a == 0 when possible + * Return a / b with rounding away from 0 (that is, 1 or -1 when a < b) */ template CUTE_HOST_DEVICE constexpr @@ -357,39 +347,28 @@ shape_div(IntTupleA const& a, IntTupleB const& b) if constexpr (is_tuple::value) { // tuple tuple static_assert(tuple_size::value == tuple_size::value, "Mismatched ranks"); return transform(a, b, [](auto const& x, auto const& y) { return shape_div(x,y); }); - } else { // tuple int + } else { // tuple int auto const [result, rest] = fold(a, cute::make_tuple(cute::make_tuple(), b), [] (auto const& init, auto const& ai) { return cute::make_tuple(append(get<0>(init), shape_div(ai, get<1>(init))), shape_div(get<1>(init), ai)); }); return result; } - } else { - if constexpr (is_tuple::value) { // int tuple - return shape_div(a, product(b)); - } else { // int int - //assert(a % b == 0 || b % a == 0); - return a / b != 0 ? a / b : signum(a) * signum(b); // divide with rounding away from zero - } + } else + if constexpr (is_tuple::value) { // int tuple + return shape_div(a, product(b)); + } else + if constexpr (is_static::value && is_static::value) { + static_assert(IntTupleA::value % IntTupleB::value == 0 || IntTupleB::value % IntTupleA::value == 0, "Static shape_div failure"); + return C{}; + } else { // int int + //assert(a % b == 0 || b % a == 0); // Wave dynamic assertion + return a / b != 0 ? a / b : signum(a) * signum(b); // Division with rounding away from zero } CUTE_GCC_UNREACHABLE; } -/** Division for Shapes that are static constants - * @pre t % u == 0 || u % t == 0 - * @result if t % u == 0, then t / u - * if u % t == 0, then signum(t) * signum(u) - */ -template -CUTE_HOST_DEVICE constexpr -constant -shape_div(constant const&, constant const&) -{ - static_assert(t % u == 0 || u % t == 0, "Static shape_div failure"); - return {}; -} - /** Minimum for Shapes */ template @@ -581,7 +560,7 @@ make_int_tuple(Indexable const& t, int n, T const& init) /** Fill the dynamic values of a Tuple with values from another Tuple * \code - * auto params = make_int_tuple(6,3,4); + * auto params = make_tuple(6,3,4); * cute::tuple, cute::tuple>, int, Int<2>> result; * fill_int_tuple_from(result, params); // (_1,(6,3,_3),4,_2) * \endcode @@ -893,7 +872,8 @@ increment(Coord& coord, Shape const& shape) struct ForwardCoordIteratorSentinal {}; -// A forward iterator for a coordinate that starts from zero and goes to shape +// A forward iterator for a starting coordinate in a shape's domain, and a shape. +// The starting coordinate may be zero but need not necessarily be. 
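A small sketch of the shape_div cases documented above, assuming the CuTe headers are available; the shapes come straight from the doc comment's own example.

// shape_div: fold a scalar divisor across a shape, and round away from zero for Int / Int.
#include <cute/int_tuple.hpp>
#include <cassert>

int main() {
  using namespace cute;

  // Tuple / Int: the divisor is folded across the modes left to right, as in the doc comment:
  // shape_div((4,5,6), 40) -> shape_div((1,5,6), 10) -> shape_div((1,1,6), 2) -> (1,1,3)
  auto r = shape_div(cute::make_tuple(Int<4>{}, Int<5>{}, Int<6>{}), Int<40>{});
  assert(int(get<0>(r)) == 1 && int(get<1>(r)) == 1 && int(get<2>(r)) == 3);

  // Int / Int with a < b rounds away from zero: 2/8 == 0, so the result is signum(2)*signum(8) == 1.
  assert(int(shape_div(Int<2>{}, Int<8>{})) == 1);

  return 0;
}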
template struct ForwardCoordIterator { @@ -905,7 +885,7 @@ struct ForwardCoordIterator CUTE_HOST_DEVICE constexpr ForwardCoordIterator& operator++() { increment(coord, shape); return *this; } - // Sentinal for the end of the implied range + // Sentinel for the end of the implied range CUTE_HOST_DEVICE constexpr bool operator< (ForwardCoordIteratorSentinal const&) const { return back(coord) < back(shape); } CUTE_HOST_DEVICE constexpr @@ -924,6 +904,15 @@ struct ForwardCoordIterator Shape const& shape; }; +// A forward iterator for a coordinate that starts from a provided coordinate +template +CUTE_HOST_DEVICE constexpr +auto +make_coord_iterator(Coord const& coord, Shape const& shape) +{ + return ForwardCoordIterator{coord,shape}; +} + // A forward iterator for a coordinate that starts from zero template CUTE_HOST_DEVICE constexpr @@ -931,7 +920,7 @@ auto make_coord_iterator(Shape const& shape) { auto coord = repeat_like(shape, int(0)); - return ForwardCoordIterator{coord,shape}; + return make_coord_iterator(coord, shape); } } // end namespace cute diff --git a/include/cute/layout.hpp b/include/cute/layout.hpp index 5b81cfd833..5072f0121e 100644 --- a/include/cute/layout.hpp +++ b/include/cute/layout.hpp @@ -43,16 +43,16 @@ namespace cute // Aliases template -using Shape = IntTuple; +using Shape = cute::tuple; template -using Stride = IntTuple; +using Stride = cute::tuple; template -using Step = IntTuple; +using Step = cute::tuple; template -using Coord = IntTuple; +using Coord = cute::tuple; template CUTE_HOST_DEVICE constexpr @@ -1034,29 +1034,29 @@ complement(Shape const& shape, Stride const& stride, CoSizeHi const& cosize_hi) // Should just be a sort and a fold... // Then we could even handle dynamic strides (but they would destroy all static strides) - auto result = fold(make_seq{}, - cute::make_tuple(shape, stride, cute::make_tuple(), cute::make_tuple(Int<1>{})), - [](auto const& init, auto i) - { - auto curr_stride = cute::min(get<1>(init)); - auto curr_idx = find(get<1>(init), curr_stride); - auto curr_shape = get(get<0>(init)); - - return cute::make_tuple(remove(get<0>(init)), // Remove the curr shape - remove(get<1>(init)), // Remove the curr stride - append(get<2>(init), curr_stride / get<3,i>(init)), // new shape = curr_stride / last_stride - append(get<3>(init), curr_shape * curr_stride)); // new stride = curr_shape * curr_stride - }); + auto [shape_, stride_, result_shape_, result_stride] = + fold(make_seq{}, + cute::make_tuple(shape, stride, cute::make_tuple(), cute::make_tuple(Int<1>{})), + [](auto const& init, auto i) + { + auto [shape, stride, result_shape, result_stride] = init; + auto min_stride = cute::min(stride); + auto min_idx = find(stride, min_stride); + + return cute::make_tuple(remove(shape), // Remove the min_idx from shape + remove(stride), // Remove the min_idx from stride + append(result_shape , min_stride / get(result_stride)), // new shape = min_stride / last_stride + append(result_stride, get(shape) * min_stride)); // new stride = curr_shape * min_stride + }); // Append the last shape mode - auto result_stride = get<3>(result); - auto result_shape = append(get<2>(result), get<1,0>(result) / back(result_stride)); // new shape = curr_stride / last_stride - - // Compute the rest_stride - auto rest_stride = get<0,0>(result) * get<1,0>(result); - //return make_layout(append(result_shape, ceil_div(cosize_hi, rest_stride)), append(result_stride, rest_stride)); - // Jump into coalesce and append (ceil_div(cosize_hi, rest_stride), rest_stride) - return 
detail::bw_coalesce(result_shape, result_stride, ceil_div(cosize_hi, rest_stride), rest_stride); + auto result_shape = append(result_shape_, get<0>(stride_) / get(result_stride)); // new shape = min_stride / last_stride + + // Compute the rest_shape and rest_stride + auto rest_stride = get<0>(shape_) * get<0>(stride_); + auto rest_shape = ceil_div(cosize_hi, rest_stride); + // Jump into coalesce and append (rest_shape, rest_stride) + return detail::bw_coalesce(result_shape, result_stride, rest_shape, rest_stride); } CUTE_GCC_UNREACHABLE; @@ -1608,16 +1608,15 @@ CUTE_HOST_DEVICE constexpr auto recast(Layout const& layout) { - if constexpr (sizeof(NewType) == sizeof(OldType)) { + if constexpr (sizeof_bits::value == sizeof_bits::value) { return layout; - } else if constexpr (sizeof(NewType) > sizeof(OldType)) { - static_assert(sizeof(NewType) % sizeof(OldType) == 0, "NewType must be a multiple of OldType"); - return upcast(layout); - } else if constexpr (sizeof(NewType) < sizeof(OldType)) { - static_assert(sizeof(OldType) % sizeof(NewType) == 0, "NewType must be a divisor of OldType"); - return downcast(layout); + } else if constexpr (sizeof_bits::value > sizeof_bits::value) { + static_assert(sizeof_bits::value % sizeof_bits::value == 0, "NewType must be a multiple of OldType"); + return upcast::value/sizeof_bits::value>(layout); + } else if constexpr (sizeof_bits::value < sizeof_bits::value) { + static_assert(sizeof_bits::value % sizeof_bits::value == 0, "NewType must be a divisor of OldType"); + return downcast::value/sizeof_bits::value>(layout); } - CUTE_GCC_UNREACHABLE; } diff --git a/include/cute/numeric/arithmetic_tuple.hpp b/include/cute/numeric/arithmetic_tuple.hpp index c2c73be7d8..ead3005cc8 100644 --- a/include/cute/numeric/arithmetic_tuple.hpp +++ b/include/cute/numeric/arithmetic_tuple.hpp @@ -387,8 +387,7 @@ abs(ScaledBasis const& e) { } // Multiplication -template ::value)> +template CUTE_HOST_DEVICE constexpr auto operator*(A const& a, ScaledBasis const& e) { @@ -396,8 +395,7 @@ operator*(A const& a, ScaledBasis const& e) { return ScaledBasis{r}; } -template ::value)> +template CUTE_HOST_DEVICE constexpr auto operator*(ScaledBasis const& e, B const& b) { diff --git a/include/cute/numeric/complex.hpp b/include/cute/numeric/complex.hpp index 43e4dd6356..b0d60a05fe 100644 --- a/include/cute/numeric/complex.hpp +++ b/include/cute/numeric/complex.hpp @@ -30,97 +30,19 @@ **************************************************************************************************/ #pragma once -#include - -//#if defined(__CUDA_ARCH__) -//# include -//#else -//# include -//#endif - -// Suppress warnings for code in Thrust headers. - -#if defined(_MSC_VER) - // We check for MSVC first, because MSVC also defines __GNUC__. - // It's common for non-GCC compilers that emulate GCC's behavior - // to define __GNUC__. - // - // thrust/complex.h triggers MSVC's warning on conversion - // from double to float (or const float) ("possible loss of data"). - // MSVC treats this as an error by default (at least with - // CUTLASS's default CMake configuration). -#pragma warning( push ) -#pragma warning( disable : 4244 ) -#elif defined(__GNUC__) - // With GCC + CUDA 11.4, builds show spurious "-Wconversion" - // warnings on line 656 of thrust/detail/type_traits.h. 
-#pragma GCC diagnostic push -#pragma GCC diagnostic ignored "-Wconversion" -#endif - -#if defined(__CUDACC_RTC__) -#include -#else -#include -#endif - -#if defined(_MSC_VER) -#pragma warning( pop ) -#elif defined(__GNUC__) -#pragma GCC diagnostic pop -#endif - #include +#include +#include namespace cute { -//#if defined(__CUDA_ARCH__) -//template -//using complex = cuda::std::complex; -//#else -//template -//using complex = std::complex; -//#endif - -//template -//using complex = thrust::complex; - -#if defined(__CUDACC_RTC__) -using cuda::std::complex; -#else -using thrust::complex; -#endif - -template -CUTE_HOST_DEVICE -T real(complex const& z) { - return z.real(); -} - -template -CUTE_HOST_DEVICE -T imag(complex const& z) { - return z.imag(); -} - -template -CUTE_HOST_DEVICE -complex conj(complex const& z) { - return complex(real(z), -imag(z)); -} - -// cute::conj forwards scalars -template -CUTE_HOST_DEVICE -T conj(T z) { - return z; -} - -//CUTE_HOST_DEVICE constexpr -//float conj(float z) { return z; } -//CUTE_HOST_DEVICE constexpr -//double conj(double z) { return z; } +using cutlass::complex; +using cutlass::is_complex; +using cutlass::RealType; +using cutlass::real; +using cutlass::imag; +using cutlass::conj; /// Fused multiply-add for complex numbers template @@ -131,10 +53,10 @@ fma(complex & d, complex const& b, complex const& c) { - d.real(c.real() + a.real() * b.real()); - d.imag(c.imag() + a.real() * b.imag()); - d.real(d.real() - a.imag() * b.imag()); - d.imag(d.imag() + a.imag() * b.real()); + d.real(fma( a.real(), b.real(), c.real())); + d.imag(fma( a.real(), b.imag(), c.imag())); + d.real(fma(-a.imag(), b.imag(), d.real())); + d.imag(fma( a.imag(), b.real(), d.imag())); } /// Fused multiply-add for triplets @@ -148,46 +70,4 @@ fma(complex const& a, return fma(c, a, b, c); } -/// Used to determine the real-valued underlying type of a numeric type T -template -struct RealType { - using Type = T; -}; - -/// Partial specialization for complex-valued type -template -struct RealType> { - using Type = T; -}; - -////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct is_complex { - static bool const value = false; -}; - -template -struct is_complex> { - static bool const value = true; -}; - -////////////////////////////////////////////////////////////////////////////////////////////////// -// Display utilities - -#if !defined(__CUDACC_RTC__) -template -CUTE_HOST std::ostream& operator<<(std::ostream& os, complex const& z) -{ - T _r = z.real(); - T _i = z.imag(); - - if (bool(_i)) { - return os << _r << "+i" << _i; - } else { - return os << _r; - } -} -#endif - } // end namespace cute diff --git a/include/cute/numeric/integral_constant.hpp b/include/cute/numeric/integral_constant.hpp index bb165111f0..a88892251b 100644 --- a/include/cute/numeric/integral_constant.hpp +++ b/include/cute/numeric/integral_constant.hpp @@ -30,15 +30,14 @@ **************************************************************************************************/ #pragma once -#include - -#include -#include +#include "cute/util/print.hpp" +#include "cute/util/type_traits.hpp" +#include "cute/numeric/math.hpp" namespace cute { -// Short name for fast compilation +// A constant value: short name and type-deduction for fast compilation template struct C { using type = C; @@ -48,29 +47,40 @@ struct C { CUTE_HOST_DEVICE constexpr value_type operator()() const noexcept { return value; } }; +// Deprecate template using constant = C; -template -using 
integral_constant = C; - template using bool_constant = C; using true_type = bool_constant; using false_type = bool_constant; +// A more std:: conforming integral_constant that enforces type but interops with C +template +struct integral_constant : C { + using type = integral_constant; + static constexpr T value = v; + using value_type = T; + // Disambiguate C::operator value_type() + //CUTE_HOST_DEVICE constexpr operator value_type() const noexcept { return value; } + CUTE_HOST_DEVICE constexpr value_type operator()() const noexcept { return value; } +}; + // // Traits // // Use cute::is_std_integral to match built-in integral types (int, int64_t, unsigned, etc) -// Use cute::is_integral to match both built-in integral types AND constant +// Use cute::is_integral to match both built-in integral types AND static integral types. template struct is_integral : bool_constant::value> {}; template -struct is_integral> : true_type {}; +struct is_integral > : true_type {}; +template +struct is_integral> : true_type {}; // is_static detects if an (abstract) value is defined completely by it's type (no members) @@ -80,20 +90,22 @@ struct is_static : bool_constant::value> {}; template constexpr bool is_static_v = is_static::value; -// is_constant detects if a type is a constant and if v is equal to a value +// is_constant detects if a type is a static integral type and if v is equal to a value template struct is_constant : false_type {}; +template +struct is_constant : is_constant {}; +template +struct is_constant : is_constant {}; +template +struct is_constant : is_constant {}; +template +struct is_constant : is_constant {}; template -struct is_constant > : bool_constant {}; -template -struct is_constant const > : bool_constant {}; -template -struct is_constant const&> : bool_constant {}; -template -struct is_constant &> : bool_constant {}; -template -struct is_constant &&> : bool_constant {}; +struct is_constant > : bool_constant {}; +template +struct is_constant> : bool_constant {}; // // Specializations @@ -403,9 +415,10 @@ conditional_return(TrueType const& t, FalseType const& f) { // Display utilities // -template -CUTE_HOST_DEVICE void print(C const&) { - printf("_%d", int(t)); +template +CUTE_HOST_DEVICE void print(C) { + printf("_"); + ::cute::print(Value); } #if !defined(__CUDACC_RTC__) diff --git a/include/cute/numeric/integral_ratio.hpp b/include/cute/numeric/integral_ratio.hpp new file mode 100644 index 0000000000..028ffffd66 --- /dev/null +++ b/include/cute/numeric/integral_ratio.hpp @@ -0,0 +1,175 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include + +#include +#include +#include + +namespace cute +{ + +/** Compile-time rational arithmetic type. + * Like cute::C for std::integral_constant, cute::R for std::ratio has a short name + * for error messages and compile times. + * The static data members @a num and @a den represent the reduced numerator and denominator + * of the rational value. Thus, two cute::R types with different @a n or @a d are distinct types + * even if they represent the same rational value. A cute::R exposes the reduced canonical type + * via its type member. That is, cute::R<3,6>::type is cute::R<1,2> and cute::R<6,3>::type is cute::C<2> + */ +template +class R { + static_assert(d != 0); + static constexpr auto an = abs(n); + static constexpr auto ad = abs(d); + static constexpr auto g = gcd(an, ad); + + public: + static constexpr auto num = signum(n) * signum(d) * an / g; + static constexpr auto den = ad / g; + // RI: den >= 1 && gcd(abs(num),den) == 1 + using type = typename conditional, R>::type; +}; + +template +CUTE_HOST_DEVICE constexpr +typename R::type +ratio(C, C) { + return {}; +} + +template +CUTE_HOST_DEVICE constexpr +typename R::type +operator*(R, R) { + return {}; +} + +template +CUTE_HOST_DEVICE constexpr +typename R::type +operator*(R, C) { + return {}; +} + +template +CUTE_HOST_DEVICE constexpr +typename R::type +operator*(C, R) { + return {}; +} + +// Product with dynamic type needs to produce an integer... +template ::value)> +CUTE_HOST_DEVICE constexpr +auto +operator*(C const& c, R) { + return c * R::num / R::den; +} + +// Product with dynamic type needs to produce an integer... 
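A hedged sketch of the compile-time rational type R introduced above, assuming the CuTe headers are on the include path. It shows the reduction invariant (R<2,4> becomes R<1,2>, and R<6,3> collapses to the integer C<2>) and that multiplying a dynamic integer by a ratio yields an integer again.

// Compile-time rationals: reduced representation and integer-producing products.
#include <cute/numeric/integral_ratio.hpp>
#include <cute/numeric/integral_constant.hpp>
#include <cassert>

int main() {
  using namespace cute;

  auto half = ratio(Int<2>{}, Int<4>{});   // reduced to R<1,2>
  print(half); print("\n");                // prints the reduced numerator/denominator

  auto two = ratio(Int<6>{}, Int<3>{});    // denominator reduces to 1, so this is the integer C<2>
  static_assert(decltype(two)::value == 2, "R<6,3>::type collapses to C<2>");

  // Multiplying a dynamic integer by a ratio produces a plain integer again.
  int n = 6;
  assert(n * half == 3);

  return 0;
}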
+template ::value)> +CUTE_HOST_DEVICE constexpr +auto +operator*(R, C const& c) { + return c * R::num / R::den; +} + +template +CUTE_HOST_DEVICE constexpr +typename R::type +operator+(R, R) { + return {}; +} + +template +CUTE_HOST_DEVICE constexpr +typename R::type +operator+(R, C) { + return {}; +} + +template +CUTE_HOST_DEVICE constexpr +typename R::type +operator+(C, R) { + return {}; +} + +template +CUTE_HOST_DEVICE constexpr +bool_constant::num == R::num && R::den == R::den> +operator==(R, R) { + return {}; +} + +template +CUTE_HOST_DEVICE constexpr +bool_constant::num == c && R::den == 1> +operator==(R, C) { + return {}; +} + +template +CUTE_HOST_DEVICE constexpr +bool_constant::num == c && R::den == 1> +operator==(C, R) { + return {}; +} + +template +CUTE_HOST_DEVICE constexpr +typename R::type +abs(R) { + return {}; +} + +// +// Display utilities +// + +template +CUTE_HOST_DEVICE void print(R) { + print(C{}); print("/"); print(C{}); +} + +#if !defined(__CUDACC_RTC__) +template +CUTE_HOST std::ostream& operator<<(std::ostream& os, R) { + return os << "_" << C{} << "/" << C{}; +} +#endif + +} // end namespace cute diff --git a/include/cute/numeric/math.hpp b/include/cute/numeric/math.hpp index ec46fd79a4..fc717c9310 100644 --- a/include/cute/numeric/math.hpp +++ b/include/cute/numeric/math.hpp @@ -73,11 +73,26 @@ abs(T const& t) { CUTE_GCC_UNREACHABLE; } +// Returns 1 if x > 0, -1 if x < 0, and 0 if x is zero. +template ::value)> +CUTE_HOST_DEVICE constexpr +int +signum(T const& x) { + if constexpr (is_signed::value) { + return (T(0) < x) - (x < T(0)); + } else { + return T(0) < x; + } + + CUTE_GCC_UNREACHABLE; +} + // // C++17 operations // -// Greatest common divisor of two integers +// Greatest common divisor of two positive integers template ::value && is_std_integral::value)> @@ -92,7 +107,7 @@ gcd(T t, U u) { } } -// Least common multiple of two integers +// Least common multiple of two positive integers template ::value && is_std_integral::value)> @@ -280,23 +295,6 @@ shiftr(T x, int s) { return s >= 0 ? (x >> s) : (x << -s); } -// Returns 1 if x > 0, -1 if x < 0, and 0 if x is zero. -template ::value)> -CUTE_HOST_DEVICE constexpr -int -signum(T const& x) { - return T(0) < x; -} - -template ::value)> -CUTE_HOST_DEVICE constexpr -int -signum(T const& x) { - return (T(0) < x) - (x < T(0)); -} - // Safe divide // @pre t % u == 0 // @result t / u diff --git a/include/cute/pointer.hpp b/include/cute/pointer.hpp index 479ad699b5..6c6a738f10 100644 --- a/include/cute/pointer.hpp +++ b/include/cute/pointer.hpp @@ -58,6 +58,19 @@ raw_pointer_cast(T* ptr) { return ptr; } +// +// Extract the physical type from a logical elem type. 
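The signum change earlier in this hunk folds the signed and unsigned overloads into one if-constexpr body. The standalone restatement below mirrors that behavior with std::is_signed in place of CuTe's trait; names and values are illustrative only.

// Standalone restatement of signum: -1/0/+1 for signed types, 0/1 for unsigned types.
#include <cassert>
#include <type_traits>

template <class T>
constexpr int signum(T const& x) {
  if constexpr (std::is_signed<T>::value) {
    return (T(0) < x) - (x < T(0));
  } else {
    return T(0) < x;
  }
}

int main() {
  static_assert(signum(-7) == -1, "negative signed value");
  static_assert(signum(0)  ==  0, "zero");
  static_assert(signum(42) ==  1, "positive signed value");
  static_assert(signum(5u) ==  1, "unsigned values are never negative");
  return 0;
}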
+// +template +struct get_raw_type +{ + using type = T; +}; + +template +using get_raw_type_t = typename get_raw_type::type; + + // // Pointer categories // @@ -79,6 +92,8 @@ template struct device_ptr { using value_type = T; + + static const uint32_t ElementsPerStoredItem = sizeof(T) * 8 / sizeof_bits_v; CUTE_HOST_DEVICE constexpr device_ptr(T* ptr) : ptr_(ptr) {} @@ -91,11 +106,14 @@ struct device_ptr template CUTE_HOST_DEVICE constexpr - T& operator[](Index const& i) const { return ptr_[i]; } + T& operator[](Index const& i) const { + static_assert(sizeof_bits_v >= 8, "Use subbyte_iterator to access the element"); + return ptr_[i]; + } template CUTE_HOST_DEVICE constexpr - DerivedType operator+(Index const& i) const { return {ptr_ + i}; } + DerivedType operator+(Index const& i) const { return {ptr_ + i / ElementsPerStoredItem}; } CUTE_HOST_DEVICE constexpr friend ptrdiff_t operator-(device_ptr const& a, @@ -326,44 +344,44 @@ recast(rmem_ptr const& ptr) { template CUTE_HOST_DEVICE void print(T const* const ptr) { - printf("raw_ptr_%db(%p)", int(8*sizeof(T)), ptr); + printf("raw_ptr_%db(%p)", int(sizeof_bits::value), ptr); } template CUTE_HOST_DEVICE void print(gmem_ptr const& ptr) { - printf("gmem_ptr_%db(%p)", int(8*sizeof(T)), ptr.get()); + printf("gmem_ptr_%db(%p)", int(sizeof_bits::value), ptr.get()); } template CUTE_HOST_DEVICE void print(smem_ptr const& ptr) { - printf("smem_ptr_%db(%p)", int(8*sizeof(T)), ptr.get()); + printf("smem_ptr_%db(%p)", int(sizeof_bits::value), ptr.get()); } template CUTE_HOST_DEVICE void print(rmem_ptr const& ptr) { - printf("rmem_ptr_%db(%p)", int(8*sizeof(T)), ptr.get()); + printf("rmem_ptr_%db(%p)", int(sizeof_bits::value), ptr.get()); } #if !defined(__CUDACC_RTC__) template CUTE_HOST std::ostream& operator<<(std::ostream& os, gmem_ptr const& ptr) { - return os << "gmem_ptr_" << int(8*sizeof(T)) << "b"; + return os << "gmem_ptr_" << int(sizeof_bits::value) << "b"; } template CUTE_HOST std::ostream& operator<<(std::ostream& os, smem_ptr const& ptr) { - return os << "smem_ptr_" << int(8*sizeof(T)) << "b"; + return os << "smem_ptr_" << int(sizeof_bits::value) << "b"; } template CUTE_HOST std::ostream& operator<<(std::ostream& os, rmem_ptr const& ptr) { - return os << "rmem_ptr_" << int(8*sizeof(T)) << "b"; + return os << "rmem_ptr_" << int(sizeof_bits::value) << "b"; } #endif // !defined(__CUDACC_RTC__) diff --git a/include/cute/stride.hpp b/include/cute/stride.hpp index 06d4b97755..d5221339eb 100644 --- a/include/cute/stride.hpp +++ b/include/cute/stride.hpp @@ -75,6 +75,9 @@ crd2idx_itt(CInt const& coord, { if constexpr (sizeof...(Is) == 0) { // Avoid recursion and mod on single/last iter return crd2idx(coord, get(shape), get(stride)); + } else if constexpr (is_constant<0, CInt>::value) { + return crd2idx(_0{}, get(shape), get(stride)) + + (_0{} + ... + crd2idx(_0{}, get(shape), get(stride))); } else { // General case return crd2idx(coord % product(get(shape)), get(shape), get(stride)) + crd2idx_itt(coord / product(get(shape)), shape, stride, seq{}); diff --git a/include/cute/swizzle.hpp b/include/cute/swizzle.hpp index c8d910a03b..39ac311de2 100644 --- a/include/cute/swizzle.hpp +++ b/include/cute/swizzle.hpp @@ -218,41 +218,40 @@ recast(Swizzle const& swizzle) // consumed and which bits are free. Furthermore, it is useful to know whether // each of these bits is known statically or dynamically. -// MixedBits is an integer class where some bits are known statically and some -// bits are known dynamically. 
These sets of bits are disjoint and it is known -// statically which bits are known dynamically. +// MixedBits is an 32-bit unsigned integer class where some bits are known statically +// and some bits are known dynamically. These sets of bits are disjoint and it is +// known statically which bits are known dynamically. // MixedBits can only be manipulated through bitwise operations // Abstract value: StaticInt | (dynamic_int_ & StaticFlags) -template // 0: static, 1: dynamic +template // 0: static, 1: dynamic struct MixedBits { // Representation invariants static_assert(StaticFlags != 0, "Should be at least one dynamic bit in MixedBits."); static_assert((StaticInt & StaticFlags) == 0, "No static/dynamic overlap allowed in MixedBits."); - // assert((dynamic_int_ & ~F) == 0); - DynamicType dynamic_int_; + uint32_t dynamic_int_; + // assert((dynamic_int_ & ~StaticFlags) == 0); CUTE_HOST_DEVICE constexpr operator uint32_t() const noexcept { return StaticInt | dynamic_int_; } }; -template +// Return a value representing (C{} | (d & C)) potentially using MixedBits to track s and f. +// This maker does allow ((s & f) != 0) and enforces the MixedBits invariant before creation. +template CUTE_HOST_DEVICE constexpr auto -make_mixed_bits(constant const&, DynamicType const& d, constant const&) +make_mixed_bits(C, DynamicType const& d, C) { static_assert(is_integral::value); - if constexpr (is_static::value) { - static_assert((s & DynamicType::value & f) == 0, "No static/dynamic overlap allowed."); - return constant{} | (d & constant{}); // Just return a static int - } else if constexpr (f == 0) { - return constant{}; // Just return a static int + constexpr uint32_t new_f = uint32_t(f) & ~uint32_t(s); // StaticBits take precedence, M<0,f>{d} | C{} + if constexpr (new_f == 0 || is_static::value) { + return C{} | (d & C{}); // Just return a static int } else { - return MixedBits{d & f}; // MixedBits + return MixedBits{uint32_t(d) & new_f}; // MixedBits } CUTE_GCC_UNREACHABLE; @@ -263,28 +262,28 @@ make_mixed_bits(constant const&, DynamicType const& d, constant const& // // Equality -template +template CUTE_HOST_DEVICE constexpr auto -operator==(MixedBits const& m, constant const&) +operator==(MixedBits const& m, C) { - return (S0 == (S1 & ~F0)) && (m.dynamic_int_ == (S1 & F0)); + return (S0 == (uint32_t(S1) & ~F0)) && (m.dynamic_int_ == (uint32_t(S1) & F0)); } -template +template CUTE_HOST_DEVICE constexpr auto -operator==(constant const& s, MixedBits const& m) +operator==(C s, MixedBits const& m) { return m == s; } // Bitwise AND -template +template CUTE_HOST_DEVICE constexpr auto -operator&(MixedBits const& m0, MixedBits const& m1) +operator&(MixedBits const& m0, MixedBits const& m1) { // Truth table for (S0,D0,F0) & (S1,D1,F1) -> (S,D,F) // S0D0F0 | 0X0 | 001 | 011 | 1X0 | @@ -294,36 +293,36 @@ operator&(MixedBits const& m0, MixedBits const& m1) // 011 | 0X0 | 001 | 011 | 011 | // 1X0 | 0X0 | 001 | 011 | 1X0 | - return make_mixed_bits(constant{}, + return make_mixed_bits(C{}, //(S0 | m0.dynamic_int_) & (S1 | m1.dynamic_int_), ((S1 & F0) & m0.dynamic_int_) | ((S0 & F1) & m1.dynamic_int_) | (m0.dynamic_int_ & m1.dynamic_int_), - constant{}); + C<(S1 & F0) | (S0 & F1) | (F0 & F1)>{}); } -template +template CUTE_HOST_DEVICE constexpr auto -operator&(MixedBits const& m, constant const&) +operator&(MixedBits const& m, C) { - return make_mixed_bits(constant{}, + return make_mixed_bits(C{}, m.dynamic_int_, - constant{}); + C{}); } -template +template CUTE_HOST_DEVICE constexpr auto -operator&(constant const& 
s, MixedBits const& m) +operator&(C s, MixedBits const& m) { return m & s; } // Bitwise OR -template +template CUTE_HOST_DEVICE constexpr auto -operator|(MixedBits const& m0, MixedBits const& m1) +operator|(MixedBits const& m0, MixedBits const& m1) { // Truth table for (S0,D0,F0) | (S1,D1,F1) -> (S,D,F) // S0D0F0 | 0X0 | 001 | 011 | 1X0 | @@ -333,35 +332,35 @@ operator|(MixedBits const& m0, MixedBits const& m1) // 011 | 011 | 011 | 011 | 1X0 | // 1X0 | 1X0 | 1X0 | 1X0 | 1X0 | - return make_mixed_bits(constant{}, + return make_mixed_bits(C{}, ((~S1 & F0) & m0.dynamic_int_) | ((~S0 & F1) & m1.dynamic_int_), - constant{}); + C<(~S0 & F1) | (~S1 & F0)>{}); } -template +template CUTE_HOST_DEVICE constexpr auto -operator|(MixedBits const& m, constant const&) +operator|(MixedBits const& m, C) { - return make_mixed_bits(constant{}, + return make_mixed_bits(C{}, m.dynamic_int_, - constant{}); + C{}); } -template +template CUTE_HOST_DEVICE constexpr auto -operator|(constant const& s, MixedBits const& m) +operator|(C s, MixedBits const& m) { return m | s; } // Bitwise XOR -template +template CUTE_HOST_DEVICE constexpr auto -operator^(MixedBits const& m0, MixedBits const& m1) +operator^(MixedBits const& m0, MixedBits const& m1) { // Truth table for (S0,D0,F0) ^ (S1,D1,F1) -> (S,D,F) // S0D0F0 | 0X0 | 001 | 011 | 1X0 | @@ -371,53 +370,53 @@ operator^(MixedBits const& m0, MixedBits const& m1) // 011 | 011 | 011 | 001 | 001 | // 1X0 | 1X0 | 011 | 001 | 0X0 | - return make_mixed_bits(constant{}, + return make_mixed_bits(C<(~S0 & S1 & ~F0) | (S0 & ~S1 & ~F1)>{}, (S0 | m0.dynamic_int_) ^ (S1 | m1.dynamic_int_), - constant{}); + C{}); } -template +template CUTE_HOST_DEVICE constexpr auto -operator^(MixedBits const& m, constant const&) +operator^(MixedBits const& m, C) { - return make_mixed_bits(constant{}, - (S0 | m.dynamic_int_) ^ S1, - constant{}); + return make_mixed_bits(C<(~S0 & uint32_t(S1) & ~F0) | (S0 & ~uint32_t(S1))>{}, + (S0 | m.dynamic_int_) ^ uint32_t(S1), + C{}); } -template +template CUTE_HOST_DEVICE constexpr auto -operator^(constant const& s, MixedBits const& m) +operator^(C s, MixedBits const& m) { return m ^ s; } -template +template CUTE_HOST_DEVICE constexpr auto -operator<<(MixedBits const& m, constant const&) +operator<<(MixedBits const& m, C) { - return make_mixed_bits(constant{}, + return make_mixed_bits(C<(S0 << S1)>{}, m.dynamic_int_ << S1, - constant{}); + C<(F0 << S1)>{}); } -template +template CUTE_HOST_DEVICE constexpr auto -operator>>(MixedBits const& m, constant const&) +operator>>(MixedBits const& m, C) { - return make_mixed_bits(constant> S1)>{}, + return make_mixed_bits(C<(S0 >> S1)>{}, m.dynamic_int_ >> S1, - constant> S1)>{}); + C<(F0 >> S1)>{}); } -template +template CUTE_HOST_DEVICE constexpr auto -shiftl(MixedBits const& m, constant const& s) +shiftl(MixedBits const& m, C s) { if constexpr (S1 >= 0) { return m << s; @@ -426,10 +425,10 @@ shiftl(MixedBits const& m, constant const& s) } } -template +template CUTE_HOST_DEVICE constexpr auto -shiftr(MixedBits const& m, constant const& s) +shiftr(MixedBits const& m, C s) { if constexpr (S1 >= 0) { return m >> s; @@ -442,24 +441,24 @@ shiftr(MixedBits const& m, constant const& s) // upcast and downcast // -template +template CUTE_HOST_DEVICE constexpr auto -safe_div(MixedBits const& m, constant const& s) +safe_div(MixedBits const& m, C s) { - static_assert(has_single_bit(S1), "Only divide MixedBits by powers of two."); - return make_mixed_bits(safe_div(constant{}, s), + static_assert(has_single_bit(uint32_t(S1)), "Only 
divide MixedBits by powers of two."); + return make_mixed_bits(safe_div(C{}, s), safe_div(m.dynamic_int_, s), - safe_div(constant{}, s)); + safe_div(C{}, s)); } -template +template CUTE_HOST_DEVICE constexpr auto -upcast(MixedBits const& m) +upcast(MixedBits const& m) { static_assert(has_single_bit(N), "Only divide MixedBits by powers of two."); - return safe_div(m, constant{}); + return safe_div(m, C{}); } template ::value)> @@ -467,18 +466,18 @@ CUTE_HOST_DEVICE constexpr auto upcast(T const& m) { - return safe_div(m, constant{}); + return safe_div(m, C{}); } -template +template CUTE_HOST_DEVICE constexpr auto -downcast(MixedBits const& m) +downcast(MixedBits const& m) { static_assert(has_single_bit(N), "Only scale MixedBits by powers of two."); - return make_mixed_bits(constant{}, + return make_mixed_bits(C{}, m.dynamic_int_ * N, - constant{}); + C{}); } template ::value)> @@ -486,7 +485,7 @@ CUTE_HOST_DEVICE constexpr auto downcast(T const& m) { - return m * constant{}; + return m * C{}; } // @@ -525,17 +524,17 @@ to_mixed_bits(Layout const& layout, Coord const& coord) // Display utilities // -template -CUTE_HOST_DEVICE void print(MixedBits const& m) +template +CUTE_HOST_DEVICE void print(MixedBits const& m) { - printf("M_%u|(%u&%u)=%u", S, uint32_t(m.dynamic_int_), F, uint32_t(m)); + printf("M_%u|(%u&%u)=%u", S, m.dynamic_int_, F, uint32_t(m)); } #if !defined(__CUDACC_RTC__) template -CUTE_HOST std::ostream& operator<<(std::ostream& os, MixedBits const& m) +CUTE_HOST std::ostream& operator<<(std::ostream& os, MixedBits const& m) { - return os << "M_" << S << "|(" << uint32_t(m.dynamic_int_) << "&" << F << ")=" << uint32_t(m); + return os << "M_" << S << "|(" << m.dynamic_int_ << "&" << F << ")=" << uint32_t(m); } template diff --git a/include/cute/swizzle_layout.hpp b/include/cute/swizzle_layout.hpp index a5919716e2..be966d97e7 100644 --- a/include/cute/swizzle_layout.hpp +++ b/include/cute/swizzle_layout.hpp @@ -128,6 +128,7 @@ namespace detail { // Get just the Swizzle part of a composed layout. template +CUTE_HOST_DEVICE constexpr auto get_swizzle_portion(ComposedLayout,Offset,LayoutB>) { @@ -136,6 +137,7 @@ get_swizzle_portion(ComposedLayout,Offset,LayoutB>) // A non-swizzled layout's "Swizzle part" is the identity swizzle. 
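To illustrate the MixedBits machinery revised above, here is a hedged sketch assuming the CuTe headers are on the include path: the static bits and the flagged dynamic bits are tracked separately, yet the value still reads back as a plain uint32_t, and bitwise ops with static masks stay exact.

// MixedBits: static bits 0b1000, dynamic bits masked by 0b0111, observed as an ordinary integer.
#include <cute/swizzle.hpp>
#include <cassert>
#include <cstdint>

int main() {
  using namespace cute;

  int d = 0b0101;                                          // runtime value; only flagged bits are kept
  auto m = make_mixed_bits(C<0b1000>{}, d, C<0b0111>{});   // abstract value: 0b1000 | (d & 0b0111)
  assert(uint32_t(m) == 0b1101);

  // Bitwise ops preserve the static/dynamic split; ANDing with a static mask stays exact.
  auto lo = m & C<0b0011>{};
  assert(uint32_t(lo) == 0b0001);

  return 0;
}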
template +CUTE_HOST_DEVICE constexpr auto get_swizzle_portion(Layout) { diff --git a/include/cute/swizzle_ptr.hpp b/include/cute/swizzle_ptr.hpp index 50bfbfa2dd..fde7454f14 100644 --- a/include/cute/swizzle_ptr.hpp +++ b/include/cute/swizzle_ptr.hpp @@ -70,6 +70,8 @@ struct smem_ptr_swizzle { static_assert(is_empty::value, "Swizzle can't have state."); + static const uint32_t ElementsPerStoredItem = sizeof(T) * 8 / sizeof_bits_v; + CUTE_HOST_DEVICE constexpr T* get() const { @@ -98,6 +100,7 @@ struct smem_ptr_swizzle CUTE_HOST_DEVICE constexpr T& operator[](Int const& i) const { + static_assert(sizeof_bits_v >= 8, "Use subbyte_iterator to access the element"); return *apply_swizzle(get() + i); } @@ -105,7 +108,7 @@ struct smem_ptr_swizzle CUTE_HOST_DEVICE constexpr smem_ptr_swizzle operator+(Int const& i) const { - return {ptr_ + i}; + return {ptr_ + i / ElementsPerStoredItem}; } T* ptr_; @@ -286,14 +289,14 @@ CUTE_HOST_DEVICE void print(smem_ptr_flag_bits const& ptr) template CUTE_HOST_DEVICE void print(smem_ptr_swizzle> const& ptr) { - printf("smem_ptr_S<%d,%d,%d>_%db(%p)", B, M, S, int(8*sizeof(T)), ptr.get()); + printf("smem_ptr_S<%d,%d,%d>_%db(%p)", B, M, S, int(sizeof_bits::value), ptr.get()); } #if !defined(__CUDACC_RTC__) template CUTE_HOST std::ostream& operator<<(std::ostream& os, smem_ptr_swizzle> const&) { - return os << "smem_ptr_S<" << B << "," << M << "," << S << ">_" << int(8*sizeof(T)) << "b"; + return os << "smem_ptr_S<" << B << "," << M << "," << S << ">_" << int(sizeof_bits::value) << "b"; } #endif diff --git a/include/cute/util/print.hpp b/include/cute/util/print.hpp index f585bf3172..31dad07fd0 100644 --- a/include/cute/util/print.hpp +++ b/include/cute/util/print.hpp @@ -33,7 +33,6 @@ #include #include -#include // // CUDA compatible print and printf @@ -119,30 +118,75 @@ get_format(double) { CUTE_HOST_DEVICE void -print(char const& c) { +print(char c) { printf("%c", c); } -template ::value)> CUTE_HOST_DEVICE void -print(T const& a) { - printf("%d", int(a)); +print(signed char a) { + printf("%hhd", a); } -template CUTE_HOST_DEVICE void -print(char const* format, T const&... t) { - printf(format, t...); +print(unsigned char a) { + printf("%hhu", a); +} + +CUTE_HOST_DEVICE +void +print(short a) { + printf("%hd", a); +} + +CUTE_HOST_DEVICE +void +print(unsigned short a) { + printf("%hu", a); +} + +CUTE_HOST_DEVICE +void +print(int a) { + printf("%d", a); +} + +CUTE_HOST_DEVICE +void +print(unsigned int a) { + printf("%u", a); +} + +CUTE_HOST_DEVICE +void +print(long a) { + printf("%ld", a); +} + +CUTE_HOST_DEVICE +void +print(unsigned long a) { + printf("%lu", a); +} + +CUTE_HOST_DEVICE +void +print(long long a) { + printf("%lld", a); +} + +CUTE_HOST_DEVICE +void +print(unsigned long long a) { + printf("%llu", a); } template CUTE_HOST_DEVICE void -print(T const&... t) { - (print(t), ...); +print(char const* format, T const&... 
t) { + printf(format, t...); } CUTE_HOST_DEVICE diff --git a/include/cutlass/arch/mma_sm75.h b/include/cutlass/arch/mma_sm75.h index 4d6c63102c..a08ba333c9 100644 --- a/include/cutlass/arch/mma_sm75.h +++ b/include/cutlass/arch/mma_sm75.h @@ -130,7 +130,7 @@ struct Mma< CUTLASS_UNUSED(b); CUTLASS_UNUSED(c); CUTLASS_UNUSED(d); - assert(0); + CUTLASS_NOT_IMPLEMENTED(); #endif } }; @@ -196,7 +196,7 @@ struct Mma< CUTLASS_UNUSED(b); CUTLASS_UNUSED(c); CUTLASS_UNUSED(d); - assert(0); + CUTLASS_NOT_IMPLEMENTED(); #endif } }; @@ -257,13 +257,12 @@ struct Mma< asm volatile("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0,%1}, {%2}, {%3}, {%4,%5};\n" : "=r"(D[0]), "=r"(D[1]) : "r"(A), "r"(B), "r"(C[0]), "r"(C[1])); - #else CUTLASS_UNUSED(a); CUTLASS_UNUSED(b); CUTLASS_UNUSED(c); CUTLASS_UNUSED(d); - assert(0); + CUTLASS_NOT_IMPLEMENTED(); #endif } }; @@ -318,13 +317,12 @@ struct Mma< asm volatile("mma.sync.aligned.m8n8k16.row.col.s32.u8.s8.s32 {%0,%1}, {%2}, {%3}, {%4,%5};\n" : "=r"(D[0]), "=r"(D[1]) : "r"(A), "r"(B), "r"(C[0]), "r"(C[1])); - #else CUTLASS_UNUSED(a); CUTLASS_UNUSED(b); CUTLASS_UNUSED(c); CUTLASS_UNUSED(d); - assert(0); + CUTLASS_NOT_IMPLEMENTED(); #endif } }; @@ -379,14 +377,12 @@ struct Mma< asm volatile("mma.sync.aligned.m8n8k16.row.col.s8.u8 {%0,%1}, {%2}, {%3}, {%4,%5};\n" : "=r"(D[0]), "=r"(D[1]) : "r"(A), "r"(B), "r"(C[0]), "r"(C[1])); - - #else CUTLASS_UNUSED(a); CUTLASS_UNUSED(b); CUTLASS_UNUSED(c); CUTLASS_UNUSED(d); - assert(0); + CUTLASS_NOT_IMPLEMENTED(); #endif } }; @@ -441,13 +437,12 @@ struct Mma< asm volatile("mma.sync.aligned.m8n8k16.row.col.s32.u8.u8.s32 {%0,%1}, {%2}, {%3}, {%4,%5};\n" : "=r"(D[0]), "=r"(D[1]) : "r"(A), "r"(B), "r"(C[0]), "r"(C[1])); - #else CUTLASS_UNUSED(a); CUTLASS_UNUSED(b); CUTLASS_UNUSED(c); CUTLASS_UNUSED(d); - assert(0); + CUTLASS_NOT_IMPLEMENTED(); #endif } }; @@ -461,7 +456,7 @@ struct Mma< /// Matrix multiply-add operation: S32 = S8 * S8 + S32 template <> struct Mma< - gemm::GemmShape<8,8,16>, + gemm::GemmShape<8, 8, 16>, 32, int8_t, layout::RowMajor, @@ -471,7 +466,7 @@ struct Mma< layout::RowMajor, OpMultiplyAddSaturate> { - using Shape = gemm::GemmShape<8,8,16>; + using Shape = gemm::GemmShape<8, 8, 16>; using ElementA = int8_t; using LayoutA = layout::RowMajor; @@ -508,13 +503,12 @@ struct Mma< asm volatile("mma.sync.aligned.m8n8k16.row.col.satfinite.s32.s8.s8.s32 {%0,%1}, {%2}, {%3}, {%4,%5};\n" : "=r"(D[0]), "=r"(D[1]) : "r"(A), "r"(B), "r"(C[0]), "r"(C[1])); - #else CUTLASS_UNUSED(a); CUTLASS_UNUSED(b); CUTLASS_UNUSED(c); CUTLASS_UNUSED(d); - assert(0); + CUTLASS_NOT_IMPLEMENTED(); #endif } }; @@ -522,7 +516,7 @@ struct Mma< /// Matrix multiply-add operation: S32 = U8 * S8 + S32 template <> struct Mma< - gemm::GemmShape<8,8,16>, + gemm::GemmShape<8, 8, 16>, 32, uint8_t, layout::RowMajor, @@ -532,7 +526,7 @@ struct Mma< layout::RowMajor, OpMultiplyAddSaturate> { - using Shape = gemm::GemmShape<8,8,16>; + using Shape = gemm::GemmShape<8, 8, 16>; using ElementA = uint8_t; using LayoutA = layout::RowMajor; @@ -569,13 +563,12 @@ struct Mma< asm volatile("mma.sync.aligned.m8n8k16.row.col.satfinite.s32.u8.s8.s32 {%0,%1}, {%2}, {%3}, {%4,%5};\n" : "=r"(D[0]), "=r"(D[1]) : "r"(A), "r"(B), "r"(C[0]), "r"(C[1])); - #else CUTLASS_UNUSED(a); CUTLASS_UNUSED(b); CUTLASS_UNUSED(c); CUTLASS_UNUSED(d); - assert(0); + CUTLASS_NOT_IMPLEMENTED(); #endif } }; @@ -583,7 +576,7 @@ struct Mma< /// Matrix multiply-add operation: S32 = S8 * U8 + S32 template <> struct Mma< - gemm::GemmShape<8,8,16>, + gemm::GemmShape<8, 8, 16>, 32, int8_t, 
layout::RowMajor, @@ -593,7 +586,7 @@ struct Mma< layout::RowMajor, OpMultiplyAddSaturate> { - using Shape = gemm::GemmShape<8,8,16>; + using Shape = gemm::GemmShape<8, 8, 16>; using ElementA = int8_t; using LayoutA = layout::RowMajor; @@ -630,13 +623,12 @@ struct Mma< asm volatile("mma.sync.aligned.m8n8k16.row.col.satfinite.s32.s8.u8.s32 {%0,%1}, {%2}, {%3}, {%4,%5};\n" : "=r"(D[0]), "=r"(D[1]) : "r"(A), "r"(B), "r"(C[0]), "r"(C[1])); - #else CUTLASS_UNUSED(a); CUTLASS_UNUSED(b); CUTLASS_UNUSED(c); CUTLASS_UNUSED(d); - assert(0); + CUTLASS_NOT_IMPLEMENTED(); #endif } }; @@ -644,7 +636,7 @@ struct Mma< /// Matrix multiply-add operation: S32 = U8 * U8 + S32 template <> struct Mma< - gemm::GemmShape<8,8,16>, + gemm::GemmShape<8, 8, 16>, 32, uint8_t, layout::RowMajor, @@ -654,7 +646,7 @@ struct Mma< layout::RowMajor, OpMultiplyAddSaturate> { - using Shape = gemm::GemmShape<8,8,16>; + using Shape = gemm::GemmShape<8, 8, 16>; using ElementA = uint8_t; using LayoutA = layout::RowMajor; @@ -691,13 +683,12 @@ struct Mma< asm volatile("mma.sync.aligned.m8n8k16.row.col.satfinite.s32.u8.u8.s32 {%0,%1}, {%2}, {%3}, {%4,%5};\n" : "=r"(D[0]), "=r"(D[1]) : "r"(A), "r"(B), "r"(C[0]), "r"(C[1])); - #else CUTLASS_UNUSED(a); CUTLASS_UNUSED(b); CUTLASS_UNUSED(c); CUTLASS_UNUSED(d); - assert(0); + CUTLASS_NOT_IMPLEMENTED(); #endif } }; @@ -711,7 +702,7 @@ struct Mma< /// Matrix multiply-add operation: S32 = S4 * S4 + S32 template <> struct Mma< - gemm::GemmShape<8,8,32>, + gemm::GemmShape<8, 8, 32>, 32, int4b_t, layout::RowMajor, @@ -721,7 +712,7 @@ struct Mma< layout::RowMajor, OpMultiplyAdd> { - using Shape = gemm::GemmShape<8,8,32>; + using Shape = gemm::GemmShape<8, 8, 32>; using ElementA = int4b_t; using LayoutA = layout::RowMajor; @@ -751,19 +742,19 @@ struct Mma< unsigned const & A = reinterpret_cast(a); unsigned const & B = reinterpret_cast(b); + int const *C = reinterpret_cast(&c); int *D = reinterpret_cast(&d); asm volatile("mma.sync.aligned.m8n8k32.row.col.s32.s4.s4.s32 {%0,%1}, {%2}, {%3}, {%4,%5};\n" : "=r"(D[0]), "=r"(D[1]) : "r"(A), "r"(B), "r"(C[0]), "r"(C[1])); - #else CUTLASS_UNUSED(a); CUTLASS_UNUSED(b); CUTLASS_UNUSED(c); CUTLASS_UNUSED(d); - assert(0); + CUTLASS_NOT_IMPLEMENTED(); #endif } }; @@ -771,7 +762,7 @@ struct Mma< /// Matrix multiply-add operation: S32 = U4 * S4 + S32 template <> struct Mma< - gemm::GemmShape<8,8,32>, + gemm::GemmShape<8, 8, 32>, 32, uint4b_t, layout::RowMajor, @@ -781,7 +772,7 @@ struct Mma< layout::RowMajor, OpMultiplyAdd> { - using Shape = gemm::GemmShape<8,8,32>; + using Shape = gemm::GemmShape<8, 8, 32>; using ElementA = uint4b_t; using LayoutA = layout::RowMajor; @@ -818,13 +809,12 @@ struct Mma< asm volatile("mma.sync.aligned.m8n8k32.row.col.s32.u4.s4.s32 {%0,%1}, {%2}, {%3}, {%4,%5};\n" : "=r"(D[0]), "=r"(D[1]) : "r"(A), "r"(B), "r"(C[0]), "r"(C[1])); - #else CUTLASS_UNUSED(a); CUTLASS_UNUSED(b); CUTLASS_UNUSED(c); CUTLASS_UNUSED(d); - assert(0); + CUTLASS_NOT_IMPLEMENTED(); #endif } }; @@ -832,7 +822,7 @@ struct Mma< /// Matrix multiply-add operation: S32 = S4 * U4 + S32 template <> struct Mma< - gemm::GemmShape<8,8,32>, + gemm::GemmShape<8, 8, 32>, 32, int4b_t, layout::RowMajor, @@ -842,7 +832,7 @@ struct Mma< layout::RowMajor, OpMultiplyAdd> { - using Shape = gemm::GemmShape<8,8,32>; + using Shape = gemm::GemmShape<8, 8, 32>; using ElementA = int4b_t; using LayoutA = layout::RowMajor; @@ -879,13 +869,12 @@ struct Mma< asm volatile("mma.sync.aligned.m8n8k32.row.col.s32.s4.u4.s32 {%0,%1}, {%2}, {%3}, {%4,%5};\n" : "=r"(D[0]), "=r"(D[1]) : "r"(A), "r"(B), 
"r"(C[0]), "r"(C[1])); - #else CUTLASS_UNUSED(a); CUTLASS_UNUSED(b); CUTLASS_UNUSED(c); CUTLASS_UNUSED(d); - assert(0); + CUTLASS_NOT_IMPLEMENTED(); #endif } }; @@ -893,7 +882,7 @@ struct Mma< /// Matrix multiply-add operation: S32 = U4 * U4 + S32 template <> struct Mma< - gemm::GemmShape<8,8,32>, + gemm::GemmShape<8, 8, 32>, 32, uint4b_t, layout::RowMajor, @@ -903,7 +892,7 @@ struct Mma< layout::RowMajor, OpMultiplyAdd> { - using Shape = gemm::GemmShape<8,8,32>; + using Shape = gemm::GemmShape<8, 8, 32>; using ElementA = uint4b_t; using LayoutA = layout::RowMajor; @@ -940,13 +929,12 @@ struct Mma< asm volatile("mma.sync.aligned.m8n8k32.row.col.s32.u4.u4.s32 {%0,%1}, {%2}, {%3}, {%4,%5};\n" : "=r"(D[0]), "=r"(D[1]) : "r"(A), "r"(B), "r"(C[0]), "r"(C[1])); - #else CUTLASS_UNUSED(a); CUTLASS_UNUSED(b); CUTLASS_UNUSED(c); CUTLASS_UNUSED(d); - assert(0); + CUTLASS_NOT_IMPLEMENTED(); #endif } }; @@ -960,7 +948,7 @@ struct Mma< /// Matrix multiply-add operation: S32 = S4 * S4 + S32 template <> struct Mma< - gemm::GemmShape<8,8,32>, + gemm::GemmShape<8, 8, 32>, 32, int4b_t, layout::RowMajor, @@ -970,7 +958,7 @@ struct Mma< layout::RowMajor, OpMultiplyAddSaturate> { - using Shape = gemm::GemmShape<8,8,32>; + using Shape = gemm::GemmShape<8, 8, 32>; using ElementA = int4b_t; using LayoutA = layout::RowMajor; @@ -1007,13 +995,12 @@ struct Mma< asm volatile("mma.sync.aligned.m8n8k32.row.col.satfinite.s32.s4.s4.s32 {%0,%1}, {%2}, {%3}, {%4,%5};\n" : "=r"(D[0]), "=r"(D[1]) : "r"(A), "r"(B), "r"(C[0]), "r"(C[1])); - #else CUTLASS_UNUSED(a); CUTLASS_UNUSED(b); CUTLASS_UNUSED(c); CUTLASS_UNUSED(d); - assert(0); + CUTLASS_NOT_IMPLEMENTED(); #endif } }; @@ -1021,7 +1008,7 @@ struct Mma< /// Matrix multiply-add operation: S32 = U4 * S4 + S32 template <> struct Mma< - gemm::GemmShape<8,8,32>, + gemm::GemmShape<8, 8, 32>, 32, uint4b_t, layout::RowMajor, @@ -1031,7 +1018,7 @@ struct Mma< layout::RowMajor, OpMultiplyAddSaturate> { - using Shape = gemm::GemmShape<8,8,32>; + using Shape = gemm::GemmShape<8, 8, 32>; using ElementA = uint4b_t; using LayoutA = layout::RowMajor; @@ -1068,13 +1055,12 @@ struct Mma< asm volatile("mma.sync.aligned.m8n8k32.row.col.satfinite.s32.u4.s4.s32 {%0,%1}, {%2}, {%3}, {%4,%5};\n" : "=r"(D[0]), "=r"(D[1]) : "r"(A), "r"(B), "r"(C[0]), "r"(C[1])); - #else CUTLASS_UNUSED(a); CUTLASS_UNUSED(b); CUTLASS_UNUSED(c); CUTLASS_UNUSED(d); - assert(0); + CUTLASS_NOT_IMPLEMENTED(); #endif } }; @@ -1082,7 +1068,7 @@ struct Mma< /// Matrix multiply-add operation: S32 = S4 * U4 + S32 template <> struct Mma< - gemm::GemmShape<8,8,32>, + gemm::GemmShape<8, 8, 32>, 32, int4b_t, layout::RowMajor, @@ -1092,7 +1078,7 @@ struct Mma< layout::RowMajor, OpMultiplyAddSaturate> { - using Shape = gemm::GemmShape<8,8,32>; + using Shape = gemm::GemmShape<8, 8, 32>; using ElementA = int4b_t; using LayoutA = layout::RowMajor; @@ -1129,13 +1115,12 @@ struct Mma< asm volatile("mma.sync.aligned.m8n8k32.row.col.satfinite.s32.s4.u4.s32 {%0,%1}, {%2}, {%3}, {%4,%5};\n" : "=r"(D[0]), "=r"(D[1]) : "r"(A), "r"(B), "r"(C[0]), "r"(C[1])); - #else CUTLASS_UNUSED(a); CUTLASS_UNUSED(b); CUTLASS_UNUSED(c); CUTLASS_UNUSED(d); - assert(0); + CUTLASS_NOT_IMPLEMENTED(); #endif } }; @@ -1143,7 +1128,7 @@ struct Mma< /// Matrix multiply-add operation: S32 = U4 * U4 + S32 template <> struct Mma< - gemm::GemmShape<8,8,32>, + gemm::GemmShape<8, 8, 32>, 32, uint4b_t, layout::RowMajor, @@ -1153,7 +1138,7 @@ struct Mma< layout::RowMajor, OpMultiplyAddSaturate> { - using Shape = gemm::GemmShape<8,8,32>; + using Shape = gemm::GemmShape<8, 8, 
32>; using ElementA = uint4b_t; using LayoutA = layout::RowMajor; @@ -1190,13 +1175,12 @@ struct Mma< asm volatile("mma.sync.aligned.m8n8k32.row.col.satfinite.s32.u4.u4.s32 {%0,%1}, {%2}, {%3}, {%4,%5};\n" : "=r"(D[0]), "=r"(D[1]) : "r"(A), "r"(B), "r"(C[0]), "r"(C[1])); - #else CUTLASS_UNUSED(a); CUTLASS_UNUSED(b); CUTLASS_UNUSED(c); CUTLASS_UNUSED(d); - assert(0); + CUTLASS_NOT_IMPLEMENTED(); #endif } }; @@ -1287,7 +1271,7 @@ struct Mma< CUTLASS_UNUSED(b); CUTLASS_UNUSED(c); CUTLASS_UNUSED(d); - assert(0); // WMMA must be supported to issue binary matrix multiply-accumulate instructions. + CUTLASS_NOT_IMPLEMENTED(); // WMMA must be supported to issue binary matrix multiply-accumulate instructions. #endif // defined(CUTLASS_ARCH_WMMA_ENABLED) diff --git a/include/cutlass/arch/mma_sm80.h b/include/cutlass/arch/mma_sm80.h index c01a7b07c4..18543b71a8 100644 --- a/include/cutlass/arch/mma_sm80.h +++ b/include/cutlass/arch/mma_sm80.h @@ -53,7 +53,16 @@ #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)) #define CUTLASS_ARCH_MMA_SM80_ENABLED + +#if (__CUDA_ARCH__ <= 900) +#define CUTLASS_ARCH_MMA_B1_AND_SM80_ENABLED +#endif +#if (__CUDA_ARCH__ <= 890) +#define CUTLASS_ARCH_MMA_B1_XOR_SM80_ENABLED +#endif + #endif + #endif //////////////////////////////////////////////////////////////////////////////// @@ -2084,7 +2093,7 @@ struct Mma< FragmentC const &c ) const { -#if defined(CUTLASS_ARCH_MMA_SM80_ENABLED) +#if defined(CUTLASS_ARCH_MMA_B1_AND_SM80_ENABLED) uint32_t const *A = reinterpret_cast(&a); uint32_t const *B = reinterpret_cast(&b); @@ -2149,7 +2158,7 @@ struct Mma< FragmentC const &c ) const { -#if defined(CUTLASS_ARCH_MMA_SM80_ENABLED) +#if defined(CUTLASS_ARCH_MMA_B1_AND_SM80_ENABLED) uint32_t const *A = reinterpret_cast(&a); uint32_t const *B = reinterpret_cast(&b); @@ -2220,7 +2229,7 @@ struct Mma< FragmentC const &c ) const { -#if defined(CUTLASS_ARCH_MMA_SM80_ENABLED) +#if defined(CUTLASS_ARCH_MMA_B1_XOR_SM80_ENABLED) uint32_t const *A = reinterpret_cast(&a); uint32_t const *B = reinterpret_cast(&b); @@ -2244,7 +2253,7 @@ struct Mma< CUTLASS_UNUSED(d); assert(0); -#endif // defined(CUTLASS_ARCH_MMA_SM80_ENABLED) +#endif // defined(CUTLASS_ARCH_MMA_B1_XOR_SM80_ENABLED) } }; diff --git a/include/cutlass/array.h b/include/cutlass/array.h index 19d16cc251..e5132d9d83 100644 --- a/include/cutlass/array.h +++ b/include/cutlass/array.h @@ -33,10 +33,24 @@ and is safe to use in a union. */ +/* + Note: CUTLASS 3x increases the host compiler requirements to C++17. However, certain + existing integrations of CUTLASS require C++11 host compilers. + + Until this requirement can be lifted, certain headers with this annotation are required + to be remain consistent with C++11 syntax. + + C++11 compatibility is enforced by `cutlass_test_unit_core_cpp11`. +*/ + #pragma once #include "cutlass/cutlass.h" #include "cutlass/functional.h" -#include "cutlass/numeric_types.h" +#include "cutlass/numeric_size.h" +#include "cutlass/half.h" +#include "cutlass/integer_subbyte.h" +#include "cutlass/tfloat32.h" +#include "cutlass/bfloat16.h" #include "cutlass/half.h" namespace cutlass { diff --git a/include/cutlass/array_subbyte.h b/include/cutlass/array_subbyte.h index ac30422408..7ec158b16d 100644 --- a/include/cutlass/array_subbyte.h +++ b/include/cutlass/array_subbyte.h @@ -32,6 +32,15 @@ \brief Statically sized array of elements that accommodates all CUTLASS-supported numeric types and is safe to use in a union. */ +/* + Note: CUTLASS 3x increases the host compiler requirements to C++17. 
However, certain + existing integrations of CUTLASS require C++11 host compilers. + + Until this requirement can be lifted, certain headers with this annotation are required + to be remain consistent with C++11 syntax. + + C++11 compatibility is enforced by `cutlass_test_unit_core_cpp11`. +*/ #pragma once diff --git a/include/cutlass/barrier.h b/include/cutlass/barrier.h index 5250048222..63502571fc 100644 --- a/include/cutlass/barrier.h +++ b/include/cutlass/barrier.h @@ -342,7 +342,7 @@ struct SyncManager { CUTLASS_DEVICE static void wait_lt(uint32_t, void *lock_ptr, int thread_idx, int flag_idx, int count) { - BarrierSync::wait_lt_helper(lock_ptr, thread_idx, flag_idx, count); + BarrierSync::wait_lt(lock_ptr, thread_idx, flag_idx, count); } CUTLASS_DEVICE diff --git a/include/cutlass/bfloat16.h b/include/cutlass/bfloat16.h index b660cd44c6..0c3397d288 100644 --- a/include/cutlass/bfloat16.h +++ b/include/cutlass/bfloat16.h @@ -33,6 +33,17 @@ \brief Defines a proxy class for storing non-standard 16-bit floating point values with 8 bits of exponent and 7 bit of mantissa. */ + +/* + Note: CUTLASS 3x increases the host compiler requirements to C++17. However, certain + existing integrations of CUTLASS require C++11 host compilers. + + Until this requirement can be lifted, certain headers with this annotation are required + to be remain consistent with C++11 syntax. + + C++11 compatibility is enforced by `cutlass_test_unit_core_cpp11`. +*/ + #pragma once #if defined(__CUDACC_RTC__) diff --git a/include/cutlass/complex.h b/include/cutlass/complex.h index ffce5d09b7..729d242ceb 100644 --- a/include/cutlass/complex.h +++ b/include/cutlass/complex.h @@ -300,6 +300,13 @@ class complex CUTLASS_HOST_DEVICE T &imag() { return _imag; } + /// Set the real part of the complex number + CUTLASS_HOST_DEVICE + void real(T real) { _real = real; } + + /// Set the imaginary part of the complex number + CUTLASS_HOST_DEVICE + void imag(T imag) { _imag = imag; } #if !defined(__CUDACC_RTC__) /// Converts to cuFloatComplex @@ -431,34 +438,55 @@ CUTLASS_HOST_DEVICE R norm_accumulate(T const &x, R const & accumulator) { /// Norm accumulate specialized for complex types template CUTLASS_HOST_DEVICE R norm_accumulate(complex const &z, R const &accumulator) { - return accumulator + static_cast(real(z)) * static_cast(real(z)) + + return accumulator + static_cast(real(z)) * static_cast(real(z)) + static_cast(imag(z)) * static_cast(imag(z)); } -/// Returns the complex conjugate CUTLASS_HOST_DEVICE float conj(float const &z) { return z; } -/// Returns the complex conjugate CUTLASS_HOST_DEVICE double conj(double const &z) { return z; } +CUTLASS_HOST_DEVICE half_t conj(half_t const& z) { + return z; +} + +CUTLASS_HOST_DEVICE int32_t conj(int32_t const& z) { + return z; +} + +CUTLASS_HOST_DEVICE uint32_t conj(uint32_t const& z) { + return z; +} + +CUTLASS_HOST_DEVICE int4b_t conj(int4b_t const& z) { + return z; +} + +CUTLASS_HOST_DEVICE uint4b_t conj(uint4b_t const& z) { + return z; +} + +CUTLASS_HOST_DEVICE bfloat16_t conj(bfloat16_t const& z) { + return z; +} + +CUTLASS_HOST_DEVICE uint1b_t conj(uint1b_t const& z) { + return z; +} + +CUTLASS_HOST_DEVICE tfloat32_t conj(tfloat32_t const& z) { + return z; +} + /// Returns the complex conjugate template CUTLASS_HOST_DEVICE complex conj(complex const &z) { return complex(real(z), -imag(z)); } -/// Indentity transform for non-complex types -template -CUTLASS_HOST_DEVICE T conj(T const &z) { - static_assert( !platform::is_same::value && - !platform::is_same::value && - 
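[Editor's note] The complex.h hunk above adds real()/imag() setters and replaces the constrained identity template for conj with explicit per-type overloads. The convention in miniature, using standalone types rather than the CUTLASS ones:

#include <cstdint>

template <class T>
struct complex_sketch {
  T re, im;
  void real(T v) { re = v; }     // setter alongside the existing getter
  void imag(T v) { im = v; }
};

inline float   conj(float x)   { return x; }   // real types: conj is the identity
inline int32_t conj(int32_t x) { return x; }

template <class T>
complex_sketch<T> conj(complex_sketch<T> const& z) {
  return {z.re, T(-z.im)};                      // complex types: negate the imaginary part
}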
!platform::is_same>::value && - !platform::is_same>::value, "May not be a complex data type"); - return z; -} /// Projects the complex number z onto the Riemann sphere template @@ -511,10 +539,10 @@ CUTLASS_HOST_DEVICE complex sin(complex const &z) { return (exp(-z) - exp(z)) * complex(T(0), T(1) / T(2)); } -/// Comparison +/// Comparison template CUTLASS_HOST_DEVICE bool operator<(complex const &lhs, complex const &rhs) { - return true; + return true; } ////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/conv/conv2d_problem_size.h b/include/cutlass/conv/conv2d_problem_size.h index e7d8360fa9..8c29767fb6 100644 --- a/include/cutlass/conv/conv2d_problem_size.h +++ b/include/cutlass/conv/conv2d_problem_size.h @@ -44,13 +44,22 @@ Map tensor sizes (Conv2d -> ImplicitGemm) : implicit_gemm_tensor_[a|b|c]_size(ConvolutionOperator) Map tensor problem sizes (Conv2d -> ImplicitGemm): implicit_gemm_problem_size(ConvolutionOperator) */ +/* + Note: CUTLASS 3x increases the host compiler requirements to C++17. However, certain + existing integrations of CUTLASS require C++11 host compilers. + + Until this requirement can be lifted, certain headers with this annotation are required + to be remain consistent with C++11 syntax. + + C++11 compatibility is enforced by `cutlass_test_unit_core_cpp11`. +*/ #pragma once #include "cutlass/cutlass.h" #include "cutlass/tensor_coord.h" #include "cutlass/fast_math.h" -#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/gemm_enumerated_types.h" #include "cutlass/matrix_coord.h" #include "cutlass/conv/convolution.h" #include "cutlass/functional.h" @@ -80,7 +89,7 @@ struct Conv2dProblemSize { public: CUTLASS_HOST_DEVICE - Conv2dProblemSize(): + Conv2dProblemSize(): N(0), H(0), W(0), C(0), P(0), Q(0), K(0), R(0), S(0), pad_h(0), pad_w(0), stride_h(1), stride_w(1), dilation_h(1), dilation_w(1), mode(Mode::kConvolution), split_k_slices(1), groups(1) { } @@ -125,7 +134,7 @@ struct Conv2dProblemSize { int split_k_slices = 1, int groups = 1 ): - N(N), H(H), W(W), C(C), K(K), R(R), S(S), P(P), Q(Q), + N(N), H(H), W(W), C(C), P(P), Q(Q), K(K), R(R), S(S), pad_h(pad_h), pad_w(pad_w), stride_h(stride_h), stride_w(stride_w), dilation_h(dilation_h), dilation_w(dilation_w), mode(mode), split_k_slices(split_k_slices), groups (groups) { } @@ -145,11 +154,11 @@ struct Conv2dProblemSize { int groups = 1 ): N(input_size.n()), H(input_size.h()), W(input_size.w()), C(input_size.c()), + P(output_size.h()), Q(output_size.w()), K(filter_size.n()), R(filter_size.h()), S(filter_size.w()), pad_h(padding[0]), pad_w(padding[2]), stride_h(stride.row()), stride_w(stride.column()), dilation_h(dilation.row()), dilation_w(dilation.column()), - P(output_size.h()), Q(output_size.w()), mode(mode), split_k_slices(split_k_slices), groups(groups) {} /// Constructs convolution problem size from cutlass Tensor4DCoord and MatrixCoord @@ -188,8 +197,8 @@ struct Conv2dProblemSize { int groups = 1 ): N(input_size.n()), H(input_size.h()), W(input_size.w()), C(input_size.c()), + P(output_size.h()), Q(output_size.w()), K(filter_size.n()), R(filter_size.h()), S(filter_size.w()), - P(output_size.h()), Q(output_size.w()), pad_h(R / 2), pad_w(S / 2), stride_h(1), stride_w(1), dilation_h(1), dilation_w(1), mode(mode), split_k_slices(split_k_slices), groups(groups) {} @@ -486,7 +495,6 @@ int depthwise_gemm_k_iterations( CUTLASS_HOST_DEVICE int implicit_gemm_k_iterations_per_channel( Operator conv_operator, - int threadblock_K, Conv2dProblemSize 
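[Editor's note] The Conv2dProblemSize constructor edits above reorder the mem-initializer lists to follow the member declaration order N, H, W, C, P, Q, K, R, S, which is the order members are actually initialized in. A compressed illustration of why that matters:

struct ProblemSizeSketch {
  int N, H, W, C, P, Q, K, R, S;   // members initialize in this (declaration) order,
                                   // regardless of how the initializer list is written
  ProblemSizeSketch(int N_, int H_, int W_, int C_,
                    int P_, int Q_, int K_, int R_, int S_)
    : N(N_), H(H_), W(W_), C(C_),
      P(P_), Q(Q_),                // listing P, Q before K, R, S now matches reality
      K(K_), R(R_), S(S_) {}
};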
const &problem_size, IteratorAlgorithm algorithm = IteratorAlgorithm::kAnalytic) { diff --git a/include/cutlass/conv/conv3d_problem_size.h b/include/cutlass/conv/conv3d_problem_size.h index 5bef4ffb71..4a2b20704c 100644 --- a/include/cutlass/conv/conv3d_problem_size.h +++ b/include/cutlass/conv/conv3d_problem_size.h @@ -44,6 +44,15 @@ Map tensor sizes (Conv3d -> ImplicitGemm) : implicit_gemm_tensor_[a|b|c]_size(ConvolutionOperator) Map tensor problem sizes (Conv3d -> ImplicitGemm): implicit_gemm_problem_size(ConvolutionOperator) */ +/* + Note: CUTLASS 3x increases the host compiler requirements to C++17. However, certain + existing integrations of CUTLASS require C++11 host compilers. + + Until this requirement can be lifted, certain headers with this annotation are required + to be remain consistent with C++11 syntax. + + C++11 compatibility is enforced by `cutlass_test_unit_core_cpp11`. +*/ #pragma once @@ -80,11 +89,11 @@ struct Conv3dProblemSize : public Conv2dProblemSize { public: CUTLASS_HOST_DEVICE Conv3dProblemSize(): + Conv2dProblemSize(), D(0), T(0), Z(0), pad_d(0), stride_d(1), - dilation_d(1), - Conv2dProblemSize() { } + dilation_d(1) { } /// Constructor for default padding, stride, dilation, and split-K CUTLASS_HOST_DEVICE @@ -102,10 +111,10 @@ struct Conv3dProblemSize : public Conv2dProblemSize { int R, int S, Mode mode - ): + ): + Conv2dProblemSize(N, H, W, C, P, Q, K, R, S, mode), D(D), T(T), Z(Z), - pad_d(T / 2), stride_d(1), dilation_d(1), - Conv2dProblemSize(N, H, W, C, P, Q, K, R, S, mode) { } + pad_d(T / 2), stride_d(1), dilation_d(1) { } /// Constructor CUTLASS_HOST_DEVICE @@ -134,15 +143,15 @@ struct Conv3dProblemSize : public Conv2dProblemSize { Mode mode, int split_k_slices = 1, int groups = 1 - ): - D(D), T(T), Z(Z), - pad_d(pad_d), stride_d(stride_d), dilation_d(dilation_d), + ): Conv2dProblemSize( - N, H, W, C, K, R, S, P, Q, - pad_h, pad_w, - stride_h, stride_w, - dilation_h, dilation_w, - mode, split_k_slices, groups) { } + N, H, W, C, K, R, S, P, Q, + pad_h, pad_w, + stride_h, stride_w, + dilation_h, dilation_w, + mode, split_k_slices, groups), + D(D), T(T), Z(Z), + pad_d(pad_d), stride_d(stride_d), dilation_d(dilation_d) { } /// Constructs convolution problem size from cutlass Tensor5DCoord and Coord3D // set *user-defined* output size and sets Z, P, and Q (include all data members in ctor) @@ -158,8 +167,6 @@ struct Conv3dProblemSize : public Conv2dProblemSize { int split_k_slices = 1, int groups = 1 ): - D(input_size.d()), T(filter_size.d()), Z(output_size.d()), - pad_d(padding[0]), stride_d(stride[0]), dilation_d(dilation[0]), Conv2dProblemSize( {input_size.n(), input_size.h(), input_size.w(), input_size.c()}, {filter_size.n(), filter_size.h(), filter_size.w(), filter_size.c()}, @@ -167,8 +174,9 @@ struct Conv3dProblemSize : public Conv2dProblemSize { {stride[1], stride[2]}, {dilation[1], dilation[2]}, {output_size.n(), output_size.h(), output_size.w(), output_size.c()}, - mode, split_k_slices, groups - ) { } + mode, split_k_slices, groups), + D(input_size.d()), T(filter_size.d()), Z(output_size.d()), + pad_d(padding[0]), stride_d(stride[0]), dilation_d(dilation[0]) { } /// Constructs convolution problem size from cutlass Tensor5DCoord and Coord3D // *computes* output size and sets Z, P and Q (include all data members in ctor) @@ -183,18 +191,18 @@ struct Conv3dProblemSize : public Conv2dProblemSize { int split_k_slices = 1, int groups = 1 ): - D(input_size.d()), T(filter_size.d()), - pad_d(padding[0]), stride_d(stride[0]), dilation_d(dilation[0]), 
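[Editor's note] Likewise, the Conv3dProblemSize constructors now name the Conv2dProblemSize base before the derived members: the base subobject is always constructed first, so the rewritten lists match the actual construction order. A toy version:

struct Base2dSketch {
  int P, Q;
  Base2dSketch(int P_, int Q_) : P(P_), Q(Q_) {}
};

struct Derived3dSketch : Base2dSketch {
  int Z;
  // The base subobject is constructed before any derived member, so writing it
  // first in the initializer list (as the hunks above now do) mirrors reality.
  Derived3dSketch(int P_, int Q_, int Z_) : Base2dSketch(P_, Q_), Z(Z_) {}
};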
Conv2dProblemSize( {input_size.n(), input_size.h(), input_size.w(), input_size.c()}, {filter_size.n(), filter_size.h(), filter_size.w(), filter_size.c()}, {padding[1], padding[1], padding[2], padding[2]}, {stride[1], stride[2]}, {dilation[1], dilation[2]}, - mode, split_k_slices, groups - ) { + mode, split_k_slices, groups), + D(input_size.d()), T(filter_size.d()), + pad_d(padding[0]), stride_d(stride[0]), dilation_d(dilation[0]) + { // set output Z - Z = ((D + pad_d * 2 - T * dilation_d) / stride_d) + 1; + Z = ((D + pad_d * 2 - T * dilation_d) / stride_d) + 1; } /// Equality operator (ignores mode and split_k_slice) diff --git a/include/cutlass/conv/convolution.h b/include/cutlass/conv/convolution.h index 2984901b9d..5b1e4d34c0 100644 --- a/include/cutlass/conv/convolution.h +++ b/include/cutlass/conv/convolution.h @@ -70,13 +70,23 @@ Map elements' data types (ImplicitGemm -> Conv): GemmToConvElementMap Map elements' data types (Conv -> ImplicitGemm): ConvToGemmElementMap */ +/* + Note: CUTLASS 3x increases the host compiler requirements to C++17. However, certain + existing integrations of CUTLASS require C++11 host compilers. + + Until this requirement can be lifted, certain headers with this annotation are required + to be remain consistent with C++11 syntax. + + C++11 compatibility is enforced by `cutlass_test_unit_core_cpp11`. +*/ + #pragma once #include "cutlass/cutlass.h" #include "cutlass/layout/tensor.h" #include "cutlass/tensor_coord.h" #include "cutlass/fast_math.h" -#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/gemm_enumerated_types.h" #include "cutlass/matrix_coord.h" namespace cutlass { diff --git a/include/cutlass/conv/kernel/direct_convolution.h b/include/cutlass/conv/kernel/direct_convolution.h index ef7a920e64..f5cce5939e 100644 --- a/include/cutlass/conv/kernel/direct_convolution.h +++ b/include/cutlass/conv/kernel/direct_convolution.h @@ -142,7 +142,7 @@ struct DirectConvolutionParams { ThreadblockShape::kN); gemm_k_iterations_per_channel = implicit_gemm_k_iterations_per_channel( - kConvolutionalOperator, ThreadblockShape::kK, args.problem_size, kIteratorAlgorithm); + kConvolutionalOperator, args.problem_size, kIteratorAlgorithm); ThreadblockSwizzle threadblock_swizzle; diff --git a/include/cutlass/conv/kernel/implicit_gemm_convolution.h b/include/cutlass/conv/kernel/implicit_gemm_convolution.h index 2669ff7758..79dac6dd83 100644 --- a/include/cutlass/conv/kernel/implicit_gemm_convolution.h +++ b/include/cutlass/conv/kernel/implicit_gemm_convolution.h @@ -250,7 +250,7 @@ struct ImplicitGemmConvolution { ThreadblockShape::kN); gemm_k_iterations_per_channel = implicit_gemm_k_iterations_per_channel( - kConvolutionalOperator, ThreadblockShape::kK, args.problem_size, kIteratorAlgorithm); + kConvolutionalOperator, args.problem_size, kIteratorAlgorithm); ThreadblockSwizzle threadblock_swizzle; diff --git a/include/cutlass/conv/threadblock/threadblock_swizzle.h b/include/cutlass/conv/threadblock/threadblock_swizzle.h index 4b886049d3..726a77c8b7 100644 --- a/include/cutlass/conv/threadblock/threadblock_swizzle.h +++ b/include/cutlass/conv/threadblock/threadblock_swizzle.h @@ -95,11 +95,11 @@ struct StridedDgradHorizontalThreadblockSwizzle : /// Returns the shape of the problem in units of logical tiles /// For ImplicitGemmConvolution Conv2d problem size: conv_operator(NPQK, NHWC, KRSC) CUTLASS_HOST_DEVICE - gemm::GemmCoord get_tiled_shape( + static gemm::GemmCoord get_tiled_shape( cutlass::conv::Operator conv_operator, cutlass::conv::Conv2dProblemSize const 
&problem_size, gemm::GemmCoord tile_size, - int split_k_slices) const { + int split_k_slices) { gemm::GemmCoord implicit_gemm_problem_size = cutlass::conv::implicit_gemm_problem_size(conv_operator, problem_size); @@ -136,11 +136,11 @@ struct StridedDgradIdentityThreadblockSwizzle : /// Returns the shape of the problem in units of logical tiles /// For ImplicitGemmConvolution Conv2d problem size: conv_operator(NPQK, NHWC, KRSC) CUTLASS_HOST_DEVICE - gemm::GemmCoord get_tiled_shape( + static gemm::GemmCoord get_tiled_shape( cutlass::conv::Operator conv_operator, cutlass::conv::Conv2dProblemSize const &problem_size, gemm::GemmCoord tile_size, - int split_k_slices) const { + int split_k_slices) { gemm::GemmCoord implicit_gemm_problem_size = cutlass::conv::implicit_gemm_problem_size(conv_operator, problem_size); @@ -174,10 +174,10 @@ struct DepthwiseDirect2dConvIdentityThreadblockSwizzle /// Returns the shape of the problem in units of logical tiles CUTLASS_HOST_DEVICE - gemm::GemmCoord get_tiled_shape(cutlass::conv::Operator conv_operator, + static gemm::GemmCoord get_tiled_shape(cutlass::conv::Operator conv_operator, cutlass::conv::Conv2dProblemSize const &problem_size, gemm::GemmCoord tile_size, - int split_k_slices) const { + int split_k_slices) { gemm::GemmCoord implicit_gemm_problem_size = cutlass::conv::implicit_gemm_problem_size(conv_operator, problem_size); diff --git a/include/cutlass/coord.h b/include/cutlass/coord.h index 455838533e..50fd51930b 100644 --- a/include/cutlass/coord.h +++ b/include/cutlass/coord.h @@ -32,6 +32,16 @@ \brief A Coord is a coordinate of arbitrary rank into a tensor or matrix */ +/* + Note: CUTLASS 3x increases the host compiler requirements to C++17. However, certain + existing integrations of CUTLASS require C++11 host compilers. + + Until this requirement can be lifted, certain headers with this annotation are required + to be remain consistent with C++11 syntax. + + C++11 compatibility is enforced by `cutlass_test_unit_core_cpp11`. +*/ + #pragma once #if defined(__CUDACC_RTC__) diff --git a/include/cutlass/core_io.h b/include/cutlass/core_io.h index c0a9685076..63617afa25 100644 --- a/include/cutlass/core_io.h +++ b/include/cutlass/core_io.h @@ -31,7 +31,15 @@ /*! \file \brief Helpers for printing cutlass/core objects */ +/* + Note: CUTLASS 3x increases the host compiler requirements to C++17. However, certain + existing integrations of CUTLASS require C++11 host compilers. + Until this requirement can be lifted, certain headers with this annotation are required + to be remain consistent with C++11 syntax. + + C++11 compatibility is enforced by `cutlass_test_unit_core_cpp11`. +*/ #pragma once #include @@ -45,11 +53,10 @@ #include "cutlass/matrix_shape.h" #include "cutlass/layout/pitch_linear.h" #include "cutlass/tensor_view.h" -#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/gemm_enumerated_types.h" #include "cutlass/conv/convolution.h" #include "cutlass/conv/conv2d_problem_size.h" #include "cutlass/conv/conv3d_problem_size.h" - /////////////////////////////////////////////////////////////////////////////////////////////////// /// Output operator for CUDA built-in dim3 type diff --git a/include/cutlass/cutlass.h b/include/cutlass/cutlass.h index bbef6fc2c6..75a46d56cf 100644 --- a/include/cutlass/cutlass.h +++ b/include/cutlass/cutlass.h @@ -33,6 +33,16 @@ \brief Basic include for CUTLASS. */ +/* + Note: CUTLASS 3x increases the host compiler requirements to C++17. However, certain + existing integrations of CUTLASS require C++11 host compilers. 
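[Editor's note] Making get_tiled_shape() static in these threadblock-swizzle structs lets callers size the launch grid without materializing a swizzle object. Roughly, with placeholder types:

struct TiledShapeSketch { int m, n, k; };

struct SwizzleSketch {
  // No object state is needed to size the grid, so the query can be static and
  // callable without constructing the swizzle (as the hunks above now allow).
  static TiledShapeSketch get_tiled_shape(int m, int n, int tile_m, int tile_n, int split_k) {
    return { (m + tile_m - 1) / tile_m, (n + tile_n - 1) / tile_n, split_k };
  }
};
// e.g. auto grid = SwizzleSketch::get_tiled_shape(M, N, 128, 128, /*split_k=*/1);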
+ + Until this requirement can be lifted, certain headers with this annotation are required + to be remain consistent with C++11 syntax. + + C++11 compatibility is enforced by `cutlass_test_unit_core_cpp11`. +*/ + #pragma once #include "cutlass/detail/helper_macros.hpp" diff --git a/include/cutlass/detail/helper_macros.hpp b/include/cutlass/detail/helper_macros.hpp index 0c3a9cd2f4..5e0ea623fa 100644 --- a/include/cutlass/detail/helper_macros.hpp +++ b/include/cutlass/detail/helper_macros.hpp @@ -141,4 +141,18 @@ namespace cutlass { #define CUTLASS_THREAD_LOCAL #endif +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if (201700L <= __cplusplus) +#define CUTLASS_CONSTEXPR_IF_CXX17 constexpr +#define CUTLASS_CXX17_OR_LATER 1 +#else +#define CUTLASS_CONSTEXPR_IF_CXX17 +#define CUTLASS_CXX17_OR_LATER 0 +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + }; // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/detail/layout.hpp b/include/cutlass/detail/layout.hpp index da76f0d655..2defe558c3 100644 --- a/include/cutlass/detail/layout.hpp +++ b/include/cutlass/detail/layout.hpp @@ -239,6 +239,23 @@ check_alignment(Shape const & shape, Stride const & stride) { : get_contiguous_shape(cute::get<1>(shape), cute::get<1>(stride)) % Alignment == 0; } +// Check if tensor shape satisfies a given major alignment + +template +CUTLASS_HOST_DEVICE constexpr +size_t +alignment_for_swizzle(cute::Swizzle) { + static_assert(B >= 0 and M >= 0); + return size_t(1) << size_t(B + M + cute::abs(S)); +} + +template +CUTLASS_HOST_DEVICE constexpr +size_t +alignment_for_swizzle(Layout layout) { + return alignment_for_swizzle(cute::detail::get_swizzle_portion(layout)); +} + //////////////////////////////////////////////////////////////////////////////////////////////////// } // namespace cutlass::detail diff --git a/include/cutlass/epilogue/collective/builders/sm90_builder.inl b/include/cutlass/epilogue/collective/builders/sm90_builder.inl index 072da89969..dec4b9ff6e 100644 --- a/include/cutlass/epilogue/collective/builders/sm90_builder.inl +++ b/include/cutlass/epilogue/collective/builders/sm90_builder.inl @@ -259,6 +259,95 @@ struct Sm90TmaBuilderImpl { >; }; +/////////////////////////////////////////////////////////////////////////////// +// Descriptor classes for defining EVT nodes +// Some of the epilogue visitor nodes require non-intuitive template arguments +// such as CopyOpS2R for AuxLoad node. Traditionaly, these are resolved by the +// builder classes. 
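[Editor's note] The new cutlass::detail::alignment_for_swizzle helper above reduces to the arithmetic 1 << (B + M + |S|) for a cute::Swizzle<B,M,S>. A standalone equivalent for sanity-checking values:

#include <cstddef>

constexpr std::size_t alignment_for_swizzle_sketch(int B, int M, int S) {
  return std::size_t(1) << std::size_t(B + M + (S < 0 ? -S : S));
}

static_assert(alignment_for_swizzle_sketch(0, 4, 3) == 128,  "1 << (0+4+3)");
static_assert(alignment_for_swizzle_sketch(3, 4, 3) == 1024, "1 << (3+4+3)");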
Here we provide a set of descriptor classes that resolve +// these template arguments from more intuitive types such as Stride, Layout + +// Get TileShape, EpilogueTile, Dispatch Policy, StagesC, and STagesD +template< + typename TileShape_MNK, + typename EpilogueTileType, + typename ElementC, + typename ElementD, + typename Schedule +> +struct EpilogueDescriptor { + using TileShape = TileShape_MNK; + using EpilogueTile = + decltype( + detail::sm90_compute_tile_shape_or_override< + ElementD, EpilogueTileType, Schedule + >() + ); + using DispatchPolicy = + decltype( + detail::sm90_get_tma_dispatch_policy< + TileShape_MNK, EpilogueTile, + ElementC, ElementD, Schedule + >() + ); + constexpr static int StagesC = DispatchPolicy::StagesC; + constexpr static int StagesD = DispatchPolicy::StagesD; +}; + +// Get Stride, SmemLayout, and CopyOpS2R for AuxLoad node +template< + typename EpilogueDescriptor, + typename StrideOrLayoutTag, + typename ElementAux +> +struct AuxLoadDescriptor { + constexpr static int Stages = EpilogueDescriptor::StagesC; + using EpilogueTile = typename EpilogueDescriptor::EpilogueTile; + using Element = ElementAux; + using Stride = cutlass::detail::TagToStrideC_t; + using SmemLayoutAtom = + decltype( + detail::sm90_get_epilogue_smem_swizzle_layout_atom< + Stride, ElementAux, typename EpilogueDescriptor::EpilogueTile + >() + ); + using CopyOpS2R = + decltype(detail::sm90_get_smem_load_op_for_source()); +}; + +// Get Stride, SmemLayout, and CopyOpS2R for AuxStore node +template< + typename EpilogueDescriptor, + typename StrideOrLayoutTag, + typename ElementAux +> +struct AuxStoreDescriptor { + constexpr static int Stages = EpilogueDescriptor::StagesD; + using EpilogueTile = typename EpilogueDescriptor::EpilogueTile; + using Element = ElementAux; + using Stride = cutlass::detail::TagToStrideC_t; + using SmemLayoutAtom = + decltype( + detail::sm90_get_epilogue_smem_swizzle_layout_atom< + Stride, ElementAux, typename EpilogueDescriptor::EpilogueTile + >() + ); + using CopyOpR2S = + decltype(detail::sm90_get_smem_store_op_for_accumulator()); +}; + +template< + typename EpilogueDescriptor, + typename ElementVector +> +struct RowBroadcastDescriptor { + constexpr static int Stages = ceil_div( + EpilogueDescriptor::StagesC, + size(shape_div(take<0, 2>(typename EpilogueDescriptor::TileShape{}), typename EpilogueDescriptor::EpilogueTile{})) + ) + 1; + + using Element = ElementVector; +}; + } // namespace detail /////////////////////////////////////////////////////////////////////////////// @@ -426,7 +515,8 @@ private: ElementD, GmemLayoutTagD, AlignmentD, - EpilogueSchedule + EpilogueSchedule, + FusionOperation >; public: diff --git a/include/cutlass/epilogue/collective/collective_builder.hpp b/include/cutlass/epilogue/collective/collective_builder.hpp index 46ad166b2e..02cb795b79 100644 --- a/include/cutlass/epilogue/collective/collective_builder.hpp +++ b/include/cutlass/epilogue/collective/collective_builder.hpp @@ -45,6 +45,7 @@ struct EpilogueTileAuto {}; // Used to let the builder pick the epilogue schedule automatically. 
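[Editor's note] The descriptor structs added above (EpilogueDescriptor, AuxLoadDescriptor, AuxStoreDescriptor, RowBroadcastDescriptor) let EVT node authors derive the non-obvious template parameters, such as stage counts and smem copy ops, from intuitive inputs. The pattern in miniature, with placeholder members:

struct EpilogueDescriptorSketch {
  static constexpr int StagesC = 4;
  static constexpr int StagesD = 2;
  using EpilogueTile = void;            // stands in for a cute::Shape in the real code
};

template <class Desc, class ElementAux>
struct AuxLoadDescriptorSketch {
  static constexpr int Stages = Desc::StagesC;        // aux loads share the C stage count
  using EpilogueTile = typename Desc::EpilogueTile;   // forwarded to resolve the smem layout
  using Element      = ElementAux;
};

static_assert(AuxLoadDescriptorSketch<EpilogueDescriptorSketch, float>::Stages == 4, "");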
// Can be overridden with kernel schedule tags in cutlass/gemm/dispatch_policy.hpp struct EpilogueScheduleAuto {}; +struct EpilogueIm2ColScheduleAuto {}; template < class ArchTag, diff --git a/include/cutlass/epilogue/collective/detail.hpp b/include/cutlass/epilogue/collective/detail.hpp index af77479c77..62d2ef755b 100644 --- a/include/cutlass/epilogue/collective/detail.hpp +++ b/include/cutlass/epilogue/collective/detail.hpp @@ -126,14 +126,14 @@ class Sm90TmaWarpSpecializedAdapter : public EpilogueOp { CUTLASS_HOST_DEVICE static constexpr int get_load_pipe_increment([[maybe_unused]] TileShapeMNK) { - return 0; + return 1; } template CUTLASS_HOST_DEVICE static constexpr int get_store_pipe_increment([[maybe_unused]] TileShapeMNK) { - return 0; + return 1; } CUTLASS_DEVICE diff --git a/include/cutlass/epilogue/collective/sm90_epilogue_tma_warpspecialized.hpp b/include/cutlass/epilogue/collective/sm90_epilogue_tma_warpspecialized.hpp index 5bdfab882f..fe146a8546 100644 --- a/include/cutlass/epilogue/collective/sm90_epilogue_tma_warpspecialized.hpp +++ b/include/cutlass/epilogue/collective/sm90_epilogue_tma_warpspecialized.hpp @@ -40,6 +40,7 @@ #include "cutlass/epilogue/collective/detail.hpp" #include "cutlass/epilogue/thread/scale_type.h" #include "cutlass/epilogue/fusion/callbacks.hpp" +#include "cutlass/detail/layout.hpp" #include "cutlass/trace.h" #include "cute/tensor.hpp" @@ -119,40 +120,52 @@ class CollectiveEpilogue< static_assert(rank(StrideD{}) == 3, "StrideD must be rank-3: [M, N, L]"); private: - using InternalElementC = cute::conditional_t,ElementD,ElementC>; // prevents void ref breakages + using SmemElementC = cute::conditional_t,ElementD,ElementC>; // prevents void ref breakages constexpr static int StagesC = StagesC_; constexpr static int StagesD = StagesD_; + constexpr static bool ReuseSmemC = ReuseSmemC_; constexpr static bool is_source_supported = not cute::is_void_v; - // internal optimization to reuse C shared memory for storing D - using SmemLayoutAtomBitsC = decltype(downcast::value>(SmemLayoutAtomC{})); - using SmemLayoutAtomBitsD = decltype(downcast::value>(SmemLayoutAtomD{})); - constexpr static bool support_smem_reuse = is_source_supported && - sizeof(InternalElementC) == sizeof(ElementD) && - StrideC{} == StrideD{} && - StagesD <= StagesC && - cute::is_same_v; - constexpr static bool ReuseSmemC = DispatchPolicy::ReuseSmemC; - static_assert(not (ReuseSmemC && not support_smem_reuse), "Smem reuse requirements not met"); - constexpr static bool is_m_major_C = detail::is_m_major(); constexpr static bool is_m_major_D = detail::is_m_major(); -public: using SmemLayoutC = decltype(tile_to_shape( SmemLayoutAtomC{}, make_shape(size<0>(EpilogueTile{}), size<1>(EpilogueTile{}), Int{}), cute::conditional_t, Step<_1,_2,_3>>{} )); using SmemLayoutD = decltype(tile_to_shape( SmemLayoutAtomD{}, - make_shape(size<0>(EpilogueTile{}), size<1>(EpilogueTile{}), Int{}), + make_shape(size<0>(EpilogueTile{}), size<1>(EpilogueTile{}), Int{}), cute::conditional_t, Step<_1,_2,_3>>{} )); + constexpr static bool support_smem_reuse = is_source_supported && StagesD <= StagesC + && cosize(take<0,2>(SmemLayoutC{})) == cosize(take<0,2>(SmemLayoutD{})); + static_assert(not (ReuseSmemC && not support_smem_reuse), "Smem reuse requirements not met"); + + constexpr static size_t SmemAlignmentD = cutlass::detail::alignment_for_swizzle(SmemLayoutD{}); + constexpr static size_t SmemAlignmentC = cutlass::detail::alignment_for_swizzle(SmemLayoutC{}); + + struct TensorStorageWithC { + 
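[Editor's note] The detail.hpp fix above (returning 1 instead of 0 from get_load_pipe_increment / get_store_pipe_increment) matters because the caller advances the load and store pipeline state by this amount per tile; an increment of 0 would never move the pipeline index. A simplified picture of that advance, under that reading:

struct PipelineStateSketch { int index = 0; int count = 0; };

inline void advance(PipelineStateSketch& state, int increment, int stages) {
  state.count += increment;              // increment == 0 would pin the pipeline
  state.index  = state.count % stages;   // to a single stage forever
}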
alignas(SmemAlignmentC) array_aligned smem_C; + alignas(SmemAlignmentD) array_aligned smem_D; + + using FusionStorage = typename FusionCallbacks::SharedStorage; + FusionStorage thread; + }; + + struct TensorStorageWithoutC { + alignas(SmemAlignmentD) array_aligned smem_D; + + using FusionStorage = typename FusionCallbacks::SharedStorage; + FusionStorage thread; + }; + +public: // TMA pipeline for loading C using LoadPipeline = cutlass::PipelineTransactionAsync; using LoadPipelineState = cutlass::PipelineState; constexpr static uint32_t TmaTransactionBytes = - size(take<0,2>(SmemLayoutC{})) * static_cast(sizeof(InternalElementC)); + size(take<0,2>(SmemLayoutC{})) * static_cast(sizeof(SmemElementC)); // TMA pipeline for storing D using StorePipeline = cute::conditional_t; struct SharedStorage { - struct TensorStorage : aligned_struct<128> { - cute::conditional_t, - array_aligned> smem_C; - alignas(128) cute::conditional_t, - array_aligned> smem_D; - - using FusionStorage = typename FusionCallbacks::SharedStorage; - alignas(128) FusionStorage thread; - } tensors; + using TensorStorage = + cute::conditional_t; + TensorStorage tensors; using PipelineStorage = typename LoadPipeline::SharedStorage; PipelineStorage pipeline; @@ -192,7 +197,7 @@ class CollectiveEpilogue< struct Params { using TMA_C = decltype(make_tma_copy( CopyOpG2S{}, - make_tensor(static_cast(nullptr), + make_tensor(static_cast(nullptr), repeat_like(StrideC{}, int32_t(0)), StrideC{}), SmemLayoutC{}(_,_,0))); using TMA_D = decltype(make_tma_copy( @@ -316,21 +321,22 @@ class CollectiveEpilogue< int thread_idx, TensorStorage& shared_tensors) { using namespace cute; - using _X = Underscore; // Indexing variables auto [M, N, K, L] = problem_shape_mnkl; auto [m_coord, n_coord, k_coord, l_coord] = tile_coord_mnkl; // Represent the full source tensor, slice to get the tile this CTA is currently responsible for - Tensor mC_mnl = params.tma_load_c.get_tma_tensor(make_shape(M,N,L)); // (M,N,L) - Tensor gC_mnl = local_tile(mC_mnl, tile_shape_MNK, make_coord(_,_,_), Step<_1,_1,_X>{}); // (CTA_M,CTA_N,m,n,l) - Tensor gC = gC_mnl(_,_,m_coord,n_coord,l_coord); // (CTA_M,CTA_N) + Tensor mC = params.tma_load_c.get_tma_tensor(make_shape(M,N,L)); // (M,N,L) + Tensor gC = local_tile(mC, take<0,2>(CtaTileMNK{}), make_coord(m_coord,n_coord,l_coord)); // (CTA_M,CTA_N) // Apply epilogue subtile, get matching smem tensor - auto ptr_sC = make_smem_ptr(shared_tensors.smem_C.data()); + SmemElementC* ptr_sC = reinterpret_cast(shared_tensors.smem_D.data()); + if constexpr (not ReuseSmemC and is_source_supported) { + ptr_sC = shared_tensors.smem_C.data(); + } Tensor gC_epi = local_tile(gC, EpilogueTile{}, _); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N) - Tensor sC_epi = make_tensor(ptr_sC, SmemLayoutC{}); // (EPI_TILE_M,EPI_TILE_N,PIPE_C) + Tensor sC_epi = make_tensor(make_smem_ptr(ptr_sC), SmemLayoutC{}); // (EPI_TILE_M,EPI_TILE_N,PIPE_C) // Prepare the thread(b)lock's (G)mem to (S)mem TMA tiled copy (bGS_) ThrCopy thrblk_g2s = params.tma_load_c.get_slice(Int<0>{}); @@ -420,8 +426,9 @@ class CollectiveEpilogue< int thread_idx, TensorStorage& shared_tensors) { using namespace cute; - using _X = Underscore; using ElementAccumulator = typename AccEngine::value_type; + using ElementCompute_ = typename epilogue::fusion::FusionCallbacksTraits::ElementCompute; + using ElementCompute = cute::conditional_t,ElementAccumulator,ElementCompute_>; static_assert(is_rmem::value, "Accumulator must be RF resident."); static_assert(rank(AccLayout{}) == 3, "Accumulator must be 
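[Editor's note] The epilogue above now selects its shared-memory struct with a conditional_t over ReuseSmemC and, when reuse is enabled, stages the C tile through smem_D instead of a dedicated buffer. The selection logic boiled down to plain arrays:

#include <type_traits>

struct WithC    { float smem_C[1024]; float smem_D[1024]; };
struct WithoutC { float smem_D[1024]; };

template <bool ReuseSmemC>
using TensorStorageSketch = std::conditional_t<ReuseSmemC, WithoutC, WithC>;

template <bool ReuseSmemC, class Storage>
float* source_tile_smem(Storage& s) {
  if constexpr (ReuseSmemC) { return s.smem_D; }  // C is staged through D's buffer
  else                      { return s.smem_C; }  // otherwise use the dedicated C buffer
}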
MMA-partitioned: (MMA,MMA_M,MMA_N)"); @@ -439,16 +446,22 @@ class CollectiveEpilogue< auto epi_tile_n = size<1>(EpilogueTile{}); // Represent the full output tensor, slice to get the tile this CTA is responsible for - Tensor mD_mnl = params.tma_store_d.get_tma_tensor(make_shape(M,N,L)); // (M,N,L) - Tensor gD_mnl = local_tile(mD_mnl, tile_shape_MNK, make_coord(_,_,_), Step<_1,_1,_X>{}); // (CTA_M,CTA_N,m,n,l) - Tensor gD = gD_mnl(_,_,m_coord,n_coord,l_coord); // (CTA_M,CTA_N) - - // Apply epilogue subtiling, construct corresponding pipelined smem tensors - auto ptr_sC = make_smem_ptr(shared_tensors.smem_C.data()); - auto ptr_sD = make_smem_ptr(shared_tensors.smem_D.data()); + Tensor mD = params.tma_store_d.get_tma_tensor(make_shape(M,N,L)); // (M,N,L) + Tensor gD = local_tile(mD, take<0,2>(CtaTileMNK{}), make_coord(m_coord,n_coord,l_coord)); // (CTA_M,CTA_N) + + // Apply epilogue subtiling Tensor gD_epi = local_tile(gD, EpilogueTile{}, _); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N) - Tensor sC_epi = make_tensor(ptr_sC, SmemLayoutC{}); // (EPI_TILE_M,EPI_TILE_N,PIPE_C) - Tensor sD_epi = make_tensor(ptr_sD, SmemLayoutD{}); // (EPI_TILE_M,EPI_TILE_N,PIPE_D) + + // Construct the corresponding pipelined smem tensors + SmemElementC* ptr_sC = reinterpret_cast(shared_tensors.smem_D.data()); + if constexpr (not ReuseSmemC and is_source_supported) { + ptr_sC = shared_tensors.smem_C.data(); + } + ElementD* ptr_sD = shared_tensors.smem_D.data(); + Tensor sC_epi = cute::as_position_independent_swizzle_tensor( + make_tensor(make_smem_ptr(ptr_sC), SmemLayoutC{})); // (EPI_TILE_M,EPI_TILE_N,PIPE_C) + Tensor sD_epi = cute::as_position_independent_swizzle_tensor( + make_tensor(make_smem_ptr(ptr_sD), SmemLayoutD{})); // (EPI_TILE_M,EPI_TILE_N,PIPE_D) // Get the smallest tiled copy we can use to retile the accumulators using CopyAtomC = Copy_Atom; @@ -458,14 +471,11 @@ class CollectiveEpilogue< TiledCopy tiled_r2s = make_tiled_copy_S(Copy_Atom{}, tiled_copy_C_atom); ThrCopy thread_r2s = tiled_r2s.get_slice(thread_idx); Tensor tRS_rAcc = thread_r2s.retile_S(accumulators); // ((R2S,R2S_V),MMA_M,MMA_N) - Tensor tRS_sD = conditional_return( - thread_r2s.partition_D(recast(sC_epi)), // (R2S,R2S_M,R2S_N,PIPE_C) - thread_r2s.partition_D(sD_epi) ); // (R2S,R2S_M,R2S_N,PIPE_D) + Tensor tRS_sD = thread_r2s.partition_D(sD_epi); // (R2S,R2S_M,R2S_N,PIPE_D) - // Allocate register tensors - auto tRS_rD_shape = take<0,3>(shape(thread_r2s.partition_S(sD_epi))); - Tensor tRS_rC = make_tensor(tRS_rD_shape); // (R2S,R2S_M,R2S_N) - Tensor tRS_rD = make_tensor(tRS_rD_shape); // (R2S,R2S_M,R2S_N) + // Allocate D registers + Layout tRS_rD_layout = make_layout(take<0,3>(shape(thread_r2s.partition_S(sD_epi)))); + Tensor tRS_rD = make_tensor(tRS_rD_layout); // (R2S,R2S_M,R2S_N) // Vectorized fragment view constexpr int FragmentSize = DispatchPolicy::FragmentSize; @@ -474,16 +484,23 @@ class CollectiveEpilogue< CUTE_STATIC_ASSERT(size<0>(tRS_rAcc) % FragmentSize == 0, "Fragment size does not vectorize properly"); // (t)hread-partition for (s)mem to (r)egister copy (tSR_) - TiledCopy tiled_s2r = make_tiled_copy_S(Copy_Atom{}, tiled_copy_C_atom); + TiledCopy tiled_s2r = make_tiled_copy_S(Copy_Atom{}, tiled_copy_C_atom); ThrCopy thread_s2r = tiled_s2r.get_slice(thread_idx); - Tensor tSR_sC = thread_s2r.partition_S(sC_epi); // (S2R,S2R_M,S2R_N,PIPE_C) - Tensor tSR_rC = thread_s2r.retile_D(tRS_rC); // (S2R,S2R_M,S2R_N) + Tensor tSR_sC = thread_s2r.partition_S(sC_epi); // (S2R,S2R_M,S2R_N,PIPE_C) + Layout tSR_rC_layout = 
thread_s2r.retile_D(tRS_rD).layout(); // (S2R,S2R_M,S2R_N) + + // Allocate C registers + // If C smem load is a non-vectorized dst(i) = src(i) then we can allocate C registers directly in the compute type + // to eliminate some redundant pack+unpack instruction sequences for sub-word types + constexpr bool IsDirectS2R = cute::is_same_v + && decltype(max_common_vector(tSR_rC_layout, tSR_sC.layout()))::value <= 1; + using RegisterElementC = cute::conditional_t; + Tensor tRS_rC = make_tensor(tRS_rD_layout); // (R2S,R2S_M,R2S_N) + Tensor tSR_rC = thread_s2r.retile_D(tRS_rC); // (S2R,S2R_M,S2R_N) // thread(b)lock-partition for (s)mem to (g)mem copy (bSG_) ThrCopy thrblk_s2g = params.tma_store_d.get_slice(Int<0>{}); - Tensor bSG_sD = conditional_return( - thrblk_s2g.partition_S(recast(sC_epi)), // (S2G,S2G_M,S2G_N,PIPE_C) - thrblk_s2g.partition_S(sD_epi) ); // (S2G,S2G_M,S2G_N,PIPE_D) + Tensor bSG_sD = thrblk_s2g.partition_S(sD_epi); // (S2G,S2G_M,S2G_N,PIPE_D) Tensor bSG_gD = thrblk_s2g.partition_D(gD_epi); // (S2G,S2G_M,S2G_N,EPI_M,EPI_N) CUTE_STATIC_ASSERT(mma_tile_m == epi_tile_m, "EPI_TILE_M must equal MMA_TILE_M"); diff --git a/include/cutlass/epilogue/fusion/callbacks.hpp b/include/cutlass/epilogue/fusion/callbacks.hpp index e9b8f65194..979a5257cf 100644 --- a/include/cutlass/epilogue/fusion/callbacks.hpp +++ b/include/cutlass/epilogue/fusion/callbacks.hpp @@ -62,6 +62,7 @@ struct FusionCallbacksTraits { using Operation = T; using CtaTile_MNK = void; using EpilogueTile_MN = void; + using ElementCompute = void; }; template < @@ -78,6 +79,7 @@ struct FusionCallbacksTraits< using Operation = Operation_; using CtaTile_MNK = CtaTile_MNK_; using EpilogueTile_MN = EpilogueTile_MN_; + using ElementCompute = typename Operation::ElementCompute; }; ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/epilogue/fusion/operations.hpp b/include/cutlass/epilogue/fusion/operations.hpp index 14db464397..848d9a1146 100644 --- a/include/cutlass/epilogue/fusion/operations.hpp +++ b/include/cutlass/epilogue/fusion/operations.hpp @@ -60,7 +60,7 @@ struct FusionOperation { using ElementBias = void; static constexpr int AlignmentBias = 0; static constexpr bool IsPerRowBiasSupported = false; - template using ActivationFn = void; + using ActivationFn = void; static constexpr bool IsEltActSupported = false; using ElementAux = void; @@ -108,8 +108,7 @@ template< > struct LinCombEltAct : LinearCombination { - template - using ActivationFn = ActivationFn_; + using ActivationFn = ActivationFn_; static constexpr bool IsEltActSupported = true; }; @@ -142,8 +141,7 @@ template< struct LinCombPerRowBiasEltAct : LinCombPerRowBias { - template - using ActivationFn = ActivationFn_; + using ActivationFn = ActivationFn_; static constexpr bool IsEltActSupported = true; }; diff --git a/include/cutlass/epilogue/fusion/sm90_callbacks_tma_warpspecialized.hpp b/include/cutlass/epilogue/fusion/sm90_callbacks_tma_warpspecialized.hpp index b2290a40fb..84f75f92ac 100644 --- a/include/cutlass/epilogue/fusion/sm90_callbacks_tma_warpspecialized.hpp +++ b/include/cutlass/epilogue/fusion/sm90_callbacks_tma_warpspecialized.hpp @@ -217,6 +217,9 @@ struct FusionCallbacks< ElementScalar const* alpha_ptr = nullptr; ElementScalar const* beta_ptr = nullptr; + using ActivationArguments = typename Sm90Compute::Arguments; + ActivationArguments activation = ActivationArguments(); + operator typename Impl::Arguments() const { return { // unary op: activation(beta * C + (alpha * acc)) 
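[Editor's note] The IsDirectS2R / RegisterElementC logic above is an optimization: if the smem-to-register copy of C is elementwise anyway (no common vector), C fragments are held directly in the compute type so sub-word inputs are not packed and unpacked twice. The type selection alone, as a sketch:

#include <type_traits>

template <class ElementC, class ElementCompute, bool IsDirectS2R>
using RegisterElementCSketch =
    std::conditional_t<IsDirectS2R, ElementCompute, ElementC>;

// e.g. an elementwise copy of an int8 source with float compute keeps float registers:
static_assert(std::is_same_v<RegisterElementCSketch<signed char, float, true>,  float>, "");
static_assert(std::is_same_v<RegisterElementCSketch<signed char, float, false>, signed char>, "");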
@@ -230,7 +233,7 @@ struct FusionCallbacks< }, // end binary op {} // ternary args : multiply_add }, // end ternary op - {} // unary args: activation + activation // unary args: activation }; // end unary op } }; @@ -258,7 +261,7 @@ using Sm90LinCombPerRowBias = Sm90EVT, // alpha * acc + bias Sm90ScalarBroadcast, // alpha Sm90AccFetch, // acc - Sm90ColBroadcast<0, CtaTileShapeMNK, ElementBias, Stride<_1,_0,_0>, AlignmentBias> // bias + Sm90ColBroadcast<0, CtaTileShapeMNK, ElementBias, Stride<_1,_0,int>, AlignmentBias> // bias > >; @@ -293,7 +296,10 @@ struct FusionCallbacks< ElementScalar beta = ElementScalar(0); ElementScalar const* alpha_ptr = nullptr; ElementScalar const* beta_ptr = nullptr; + + using StrideBias = Stride<_1,_0,int>; ElementBias const* bias_ptr = nullptr; + StrideBias dBias = {}; operator typename Impl::Arguments() const { return @@ -303,7 +309,7 @@ struct FusionCallbacks< { // ternary op : alpha * acc + bias {{alpha}, {alpha_ptr}}, // leaf args : alpha {}, // leaf args : acc - {bias_ptr}, // leaf args : bias + {bias_ptr, ElementBias(0), dBias}, // leaf args : bias {} // ternary args : multiply_add }, // end ternary op {} // ternary args : multiply_add @@ -373,7 +379,13 @@ struct FusionCallbacks< ElementScalar beta = ElementScalar(0); ElementScalar const* alpha_ptr = nullptr; ElementScalar const* beta_ptr = nullptr; + + using StrideBias = Stride<_1,_0,int>; ElementBias const* bias_ptr = nullptr; + StrideBias dBias = {}; + + using ActivationArguments = typename Sm90Compute::Arguments; + ActivationArguments activation = ActivationArguments(); operator typename Impl::Arguments() const { return @@ -384,12 +396,12 @@ struct FusionCallbacks< { // ternary op : alpha * acc + bias {{alpha}, {alpha_ptr}}, // leaf args : alpha {}, // leaf args : acc - {bias_ptr}, // leaf args : bias + {bias_ptr, ElementBias(0), dBias}, // leaf args : bias {} // ternary args : multiply_add }, // end ternary op {} // ternary args : multiply_add }, // end ternary op - {} // unary args : activation + activation // unary args : activation }; // end unary op } }; @@ -461,10 +473,9 @@ struct FusionCallbacks< ElementOutput, ElementCompute, ElementAux, ElementBias, ElementScalar, AlignmentAux, AlignmentBias, RoundStyle > { - using StrideAux = cutlass::gemm::TagToStrideC_t; using Impl = Sm90LinCombPerRowBiasEltActAux< - CtaTileShapeMNK, EpilogueTile, StagesD, StrideAux, SmemLayoutAtom, CopyOpR2S, ActivationFn, + CtaTileShapeMNK, EpilogueTile, StagesD, cutlass::gemm::TagToStrideC_t, SmemLayoutAtom, CopyOpR2S, ActivationFn, ElementOutput, ElementCompute, ElementAux, ElementBias, ElementScalar, AlignmentAux, AlignmentBias, RoundStyle >; using Operation = @@ -478,7 +489,15 @@ struct FusionCallbacks< ElementScalar beta = ElementScalar(0); ElementScalar const* alpha_ptr = nullptr; ElementScalar const* beta_ptr = nullptr; + + using StrideBias = Stride<_1,_0,int>; ElementBias const* bias_ptr = nullptr; + StrideBias dBias = {}; + + using ActivationArguments = typename Sm90Compute::Arguments; + ActivationArguments activation = ActivationArguments(); + + using StrideAux = cutlass::gemm::TagToStrideC_t; ElementAux* aux_ptr = nullptr; StrideAux dAux = {}; @@ -492,14 +511,14 @@ struct FusionCallbacks< { // ternary op : alpha * acc + bias {{alpha}, {alpha_ptr}}, // leaf args : alpha {}, // leaf args : acc - {bias_ptr}, // leaf args : bias + {bias_ptr, ElementBias(0), dBias}, // leaf args : bias {} // ternary args : multiply_add }, // end ternary op {} // ternary args : multiply_add }, // end ternary op {aux_ptr, dAux} // 
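[Editor's note] These FusionCallbacks hunks thread a new activation member (the activation functor's own Arguments) through the visitor tree instead of passing an empty initializer, so activation hyperparameters can be set per launch. A hedged sketch of that plumbing, with placeholder types:

struct ClampArgsSketch { float lo = 0.0f, hi = 6.0f; };   // stands in for ActivationFn<...>::Arguments

struct FusionArgumentsSketch {
  float alpha = 1.0f;
  float beta  = 0.0f;
  ClampArgsSketch activation;   // new: forwarded as the unary op's args rather than {}
};

// e.g. FusionArgumentsSketch args; args.activation.hi = 4.0f;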
unary args : store }, // end unary op - {} // unary args : activation + activation // unary args : activation }; // end unary op } }; @@ -528,7 +547,7 @@ using Sm90PerRowLinCombPerRowBias = Sm90EVT, // alpha * acc + bias Sm90ColBroadcast<0, CtaTileShapeMNK, ElementScalar, Stride<_1,_0,_0>, AlignmentScalar>, // alpha Sm90AccFetch, // acc - Sm90ColBroadcast<0, CtaTileShapeMNK, ElementBias, Stride<_1,_0,_0>, AlignmentBias> // bias + Sm90ColBroadcast<0, CtaTileShapeMNK, ElementBias, Stride<_1,_0,int>, AlignmentBias> // bias > >; @@ -591,7 +610,13 @@ struct FusionCallbacks< ElementScalar beta = ElementScalar(0); ElementScalar const* alpha_ptr = nullptr; ElementScalar const* beta_ptr = nullptr; + + using StrideBias = Stride<_1,_0,int>; ElementBias const* bias_ptr = nullptr; + StrideBias dBias = {}; + + using ActivationArguments = typename Sm90Compute::Arguments; + ActivationArguments activation = ActivationArguments(); operator typename Impl::Arguments() const { return @@ -600,14 +625,14 @@ struct FusionCallbacks< {beta_ptr, beta}, // leaf args : beta {}, // leaf args : C { // ternary op : alpha * acc + bias - {alpha_ptr, alpha}, // leaf args : alpha - {}, // leaf args : acc - {bias_ptr}, // leaf args : bias + {alpha_ptr, alpha}, // leaf args : alpha + {}, // leaf args : acc + {bias_ptr, ElementBias(0), dBias}, // leaf args : bias {} // ternary args : multiply_add }, // end ternary op {} // ternary args : multiply_add }, // end ternary op - {} // unary args : activation + activation // unary args : activation }; // end unary op } }; @@ -650,7 +675,7 @@ using Sm90ScaledLinCombPerRowBias = Sm90EVT, // alpha * acc + bias Sm90ScalarBroadcast, 3>, // scale_a * scale_b * alpha Sm90AccFetch, // acc - Sm90ColBroadcast<0, CtaTileShapeMNK, ElementBias, Stride<_1,_0,_0>, AlignmentBias> // bias + Sm90ColBroadcast<0, CtaTileShapeMNK, ElementBias, Stride<_1,_0,int>, AlignmentBias> // bias > >; @@ -728,7 +753,12 @@ struct FusionCallbacks< ElementScalar const* scale_c_ptr = nullptr; ElementScalar const* scale_d_ptr = nullptr; + using StrideBias = Stride<_1,_0,int>; ElementBias const* bias_ptr = nullptr; + StrideBias dBias = {}; + + using ActivationArguments = typename Sm90Compute::Arguments; + ActivationArguments activation = ActivationArguments(); operator typename Impl::Arguments() const { return @@ -742,14 +772,14 @@ struct FusionCallbacks< { // ternary op : (scale_a * scale_b * alpha) * acc + bias {{scale_a, scale_b, alpha}, {scale_a_ptr, scale_b_ptr, alpha_ptr} - }, // leaf args : (scale_a * scale_b * alpha) - {}, // leaf args : acc - {bias_ptr}, // leaf args : bias + }, // leaf args : (scale_a * scale_b * alpha) + {}, // leaf args : acc + {bias_ptr, ElementBias(0), dBias}, // leaf args : bias {} // ternary args : multiply_add }, // end ternary op {} // ternary args : multiply_add }, // end ternary op - {} // unary args : activation + activation // unary args : activation }, // end unary op {{scale_d}, {scale_d_ptr} @@ -855,10 +885,10 @@ struct FusionCallbacks< ElementOutput, ElementCompute, ElementAux, ElementAmax, ElementBias, ElementScalar, AlignmentAux, AlignmentBias, RoundStyle > { - using StrideAux = cutlass::gemm::TagToStrideC_t; using Impl = Sm90ScaledLinCombPerRowBiasEltActAmaxAux< - CtaTileShapeMNK, EpilogueTile, StagesD, StrideAux, SmemLayoutAtom, CopyOpR2S, ActivationFn, + CtaTileShapeMNK, EpilogueTile, StagesD, cutlass::gemm::TagToStrideC_t, + SmemLayoutAtom, CopyOpR2S, ActivationFn, ElementOutput, ElementCompute, ElementAux, ElementAmax, ElementBias, ElementScalar, AlignmentAux, 
AlignmentBias, RoundStyle >; using Operation = @@ -885,9 +915,17 @@ struct FusionCallbacks< ElementScalar scale_aux = ElementScalar(1); ElementScalar const* scale_aux_ptr = nullptr; + using StrideBias = Stride<_1,_0,int>; ElementBias const* bias_ptr = nullptr; + StrideBias dBias = {}; + + using ActivationArguments = typename Sm90Compute::Arguments; + ActivationArguments activation = ActivationArguments(); + ElementAmax* amax_D_ptr = nullptr; ElementAmax* amax_aux_ptr = nullptr; + + using StrideAux = cutlass::gemm::TagToStrideC_t; ElementAux* aux_ptr = nullptr; StrideAux dAux = {}; @@ -905,9 +943,9 @@ struct FusionCallbacks< { // ternary op : (scale_a * scale_b * alpha) * acc + bias {{scale_a, scale_b, alpha}, {scale_a_ptr, scale_b_ptr, alpha_ptr} - }, // leaf args : (scale_a * scale_b * alpha) - {}, // leaf args : acc - {bias_ptr}, // leaf args : bias + }, // leaf args : (scale_a * scale_b * alpha) + {}, // leaf args : acc + {bias_ptr, ElementBias(0), dBias}, // leaf args : bias {} // ternary args : multiply_add }, // end ternary op {} // ternary args : multiply_add @@ -924,7 +962,7 @@ struct FusionCallbacks< { // unary op : reduce(activation(Z)) { // unary op : activation(Z) {}, // leaf args : Z - {} // unary args : activation + activation // unary args : activation }, // end unary op {amax_D_ptr_} // unary args : reduce }, // end unary op diff --git a/include/cutlass/epilogue/fusion/sm90_visitor_compute_tma_warpspecialized.hpp b/include/cutlass/epilogue/fusion/sm90_visitor_compute_tma_warpspecialized.hpp index 9d3dabd799..0d62a4bdcb 100644 --- a/include/cutlass/epilogue/fusion/sm90_visitor_compute_tma_warpspecialized.hpp +++ b/include/cutlass/epilogue/fusion/sm90_visitor_compute_tma_warpspecialized.hpp @@ -99,7 +99,9 @@ struct Sm90Compute : Sm90VisitorImpl<> { template < bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy + class ProblemShapeMNKL, class TileShapeMNK, + class TileCoordMNKL, class EpilogueTile, class TiledCopy, class SrcTensor @@ -118,6 +120,123 @@ struct Sm90Compute : Sm90VisitorImpl<> { }; +// partial specialization for compute fns that define an Arguments member, e.g. 
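[Editor's note] The bias leaves above now take {bias_ptr, ElementBias(0), dBias} with StrideBias = Stride<_1,_0,int>, i.e. a default value plus a runtime batch stride, so a per-row bias can vary across the batch dimension. The addressing this implies, in plain index form and under that reading:

#include <cstdint>

inline float load_bias_sketch(float const* bias_ptr, int m, int l, std::int64_t batch_stride) {
  // Null pointer -> the default value; otherwise a column-broadcast bias,
  // indexed by row m and batch l via the runtime batch stride.
  return bias_ptr ? bias_ptr[m + l * batch_stride] : 0.0f;
}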
activation hyperparameters +template< + template class ComputeFn, + class ElementOutput, + class ElementCompute, + FloatRoundStyle RoundStyle +> +struct Sm90Compute< + ComputeFn, + ElementOutput, + ElementCompute, + RoundStyle, + cute::void_t::Arguments> +> { + + struct SharedStorage { }; + + using Arguments = typename ComputeFn::Arguments; + + using Params = Arguments; + + template + static constexpr Params + to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { + return args; + } + + CUTLASS_DEVICE bool + is_producer_load_needed() const { + return false; + } + + CUTLASS_DEVICE bool + is_C_load_needed() const { + return false; + } + + CUTLASS_HOST_DEVICE + Sm90Compute() { } + + CUTLASS_HOST_DEVICE + Sm90Compute(Params const& params, SharedStorage const& shared_storage) + : params(params) {} + + Params const params; + + template < + class ProblemShapeMNKL, + class TileShapeMNK, + class TileCoordMNKL, + class EpilogueTile + > + CUTLASS_DEVICE auto + get_producer_load_callbacks( + ProblemShapeMNKL problem_shape_mnkl, + TileShapeMNK tile_shape_mnk, + TileCoordMNKL tile_coord_mnkl, + EpilogueTile epi_tile, + int thread_idx) { + return EmptyProducerLoadCallbacks{}; + } + + struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks { + CUTLASS_DEVICE + ConsumerStoreCallbacks(Params const& params) + : params(params) {} + + Params const& params; + + template + CUTLASS_DEVICE Array + visit(Array const& frg_acc, int epi_v, int epi_m, int epi_n, + Array const&... frg_inputs) { + return transform_apply(cute::make_tuple(frg_inputs...), + [&] (auto&& frg_input) { + using ElementInput = typename cute::remove_cvref_t::Element; + using ConvertInput = NumericArrayConverter; + ConvertInput convert_input{}; + + return convert_input(frg_input); + }, + [&] (auto&&... 
cvt_frg_inputs) { + using ComputeOutput = ComputeFn>; + using ConvertOutput = NumericArrayConverter; + ComputeOutput compute_output{}; + ConvertOutput convert_output{}; + + return convert_output(compute_output(cvt_frg_inputs..., params)); + } + ); + } + + }; + + template < + bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy + class ProblemShapeMNKL, + class TileShapeMNK, + class TileCoordMNKL, + class EpilogueTile, + class TiledCopy, + class SrcTensor + > + CUTLASS_DEVICE auto + get_consumer_store_callbacks( + ProblemShapeMNKL problem_shape_mnkl, + TileShapeMNK tile_shape_mnk, + TileCoordMNKL tile_coord_mnkl, + EpilogueTile epi_tile, + TiledCopy tiled_copy, + int thread_idx, + SrcTensor const& tCrC) { + return ConsumerStoreCallbacks(params); + } + +}; + ///////////////////////////////////////////////////////////////////////////////////////////////// // // Performance Optimized Specializations @@ -215,7 +334,9 @@ struct Sm90TreeVisitor< template < bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy + class ProblemShapeMNKL, class TileShapeMNK, + class TileCoordMNKL, class EpilogueTile, class TiledCopy, class SrcTensor diff --git a/include/cutlass/epilogue/fusion/sm90_visitor_load_tma_warpspecialized.hpp b/include/cutlass/epilogue/fusion/sm90_visitor_load_tma_warpspecialized.hpp index 28559027a7..348a62befe 100644 --- a/include/cutlass/epilogue/fusion/sm90_visitor_load_tma_warpspecialized.hpp +++ b/include/cutlass/epilogue/fusion/sm90_visitor_load_tma_warpspecialized.hpp @@ -71,7 +71,9 @@ struct Sm90AccFetch : Sm90VisitorImpl<> { template < bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy + class ProblemShapeMNKL, class TileShapeMNK, + class TileCoordMNKL, class EpilogueTile, class TiledCopy, class SrcTensor @@ -129,7 +131,9 @@ struct Sm90SrcFetch : Sm90VisitorImpl<> { template < bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy + class ProblemShapeMNKL, class TileShapeMNK, + class TileCoordMNKL, class EpilogueTile, class TiledCopy, class SrcTensor @@ -181,7 +185,8 @@ struct Sm90AuxLoad { cute::conditional_t, Step<_1,_2,_3>>{} )); struct SharedStorage { - alignas(128) array_aligned smem_aux; + alignas(cutlass::detail::alignment_for_swizzle(SmemLayout{})) + array_aligned smem_aux; }; struct Arguments { @@ -222,9 +227,9 @@ struct Sm90AuxLoad { Sm90AuxLoad() { } CUTLASS_HOST_DEVICE - Sm90AuxLoad(Params const& params, SharedStorage& shared_storage) + Sm90AuxLoad(Params const& params, SharedStorage const& shared_storage) : params_ptr(¶ms), - smem_aux(shared_storage.smem_aux.data()) { } + smem_aux(const_cast(shared_storage.smem_aux.data())) { } Params const* params_ptr; Element* smem_aux; @@ -273,7 +278,9 @@ struct Sm90AuxLoad { }; template < - class TileShapeMNK + class ProblemShapeMNKL, + class TileShapeMNK, + class TileCoordMNKL > CUTLASS_DEVICE auto get_producer_load_callbacks( @@ -284,8 +291,9 @@ struct Sm90AuxLoad { int thread_idx) { auto [M, N, K, L] = problem_shape_mnkl; + auto [m, n, k, l] = tile_coord_mnkl; Tensor mAux = params_ptr->tma_load_aux.get_tma_tensor(make_shape(M,N,L)); // (M,N,L) - Tensor gAux = sm90_tensor_to_cta_tile(mAux, tile_shape_mnk, tile_coord_mnkl); // (CTA_M,CTA_N) + Tensor gAux = local_tile(mAux, take<0,2>(tile_shape_mnk), make_coord(m,n,l)); // (CTA_M,CTA_N) Tensor gAux_epi = local_tile(gAux, epi_tile, _); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N) Tensor sAux_epi = make_tensor(make_smem_ptr(smem_aux), 
SmemLayout{}); // (EPI_TILE_M,EPI_TILE_N,PIPE) @@ -339,7 +347,9 @@ struct Sm90AuxLoad { template < bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy + class ProblemShapeMNKL, class TileShapeMNK, + class TileCoordMNKL, class TiledCopy, class SrcTensor > @@ -363,7 +373,8 @@ struct Sm90AuxLoad { make_tiled_copy_S(Copy_Atom{}, tiled_copy), make_tiled_copy_D(Copy_Atom{}, tiled_copy) ); - Tensor sAux_epi = make_tensor(make_smem_ptr(smem_aux), SmemLayout{}); // (EPI_TILE_M,EPI_TILE_N,PIPE) + Tensor sAux_epi = cute::as_position_independent_swizzle_tensor( + make_tensor(make_smem_ptr(smem_aux), SmemLayout{})); // (EPI_TILE_M,EPI_TILE_N,PIPE) auto tSR_sAux = tiled_s2r.get_slice(thread_idx).partition_S(sAux_epi); // (S2R,S2R_M,S2R_N,PIPE) @@ -378,6 +389,7 @@ struct Sm90AuxLoad { ///////////////////////////////////////////////////////////////////////////////////////////////// // Scalar broadcast +// Supports reduction over multiple broadcasts to support fusions such as fp8 scaling factors template< class Element, class StrideMNL = Stride<_0,_0,_0>, @@ -387,7 +399,8 @@ template< struct Sm90ScalarBroadcast { static_assert( (cute::is_same_v>) || // scalar broadcast, e.g. alpha - (cute::is_same_v>)); // batched scalar broadcast, e.g. per-batch alpha + (cute::is_same_v>) || // batched scalar broadcast, e.g. per-batch alpha + (cute::is_same_v>)); struct SharedStorage { }; @@ -419,7 +432,7 @@ struct Sm90ScalarBroadcast { Sm90ScalarBroadcast() { } CUTLASS_HOST_DEVICE - Sm90ScalarBroadcast(Params const& params, SharedStorage& shared_storage) + Sm90ScalarBroadcast(Params const& params, SharedStorage const& shared_storage) : params_ptr(¶ms) { // Get the scalar for non-batched broadcast if constexpr (cute::is_same_v>) { @@ -431,7 +444,9 @@ struct Sm90ScalarBroadcast { Params const* params_ptr; template < + class ProblemShapeMNKL, class TileShapeMNK, + class TileCoordMNKL, class EpilogueTile > CUTLASS_DEVICE auto @@ -442,7 +457,9 @@ struct Sm90ScalarBroadcast { EpilogueTile epi_tile, int thread_idx) { // Get the scalar for batched broadcast - if constexpr (cute::is_same_v>) { + if constexpr ( + cute::is_same_v> || + cute::is_same_v>) { auto [m_coord, n_coord, k_coord, l_coord] = tile_coord_mnkl; update_scalar(l_coord); } @@ -470,7 +487,9 @@ struct Sm90ScalarBroadcast { template < bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy + class ProblemShapeMNKL, class TileShapeMNK, + class TileCoordMNKL, class EpilogueTile, class TiledCopy, class SrcTensor @@ -486,7 +505,9 @@ struct Sm90ScalarBroadcast { SrcTensor const& tCrC) { // Get the scalar for batched broadcast - if constexpr (cute::is_same_v>) { + if constexpr ( + cute::is_same_v> || + cute::is_same_v>) { auto [m_coord, n_coord, k_coord, l_coord] = tile_coord_mnkl; update_scalar(l_coord); } @@ -541,7 +562,7 @@ struct Sm90RowBroadcast { // Accumulator doesn't distribute row elements evenly amongst threads so we must buffer in smem struct SharedStorage { - array_aligned(CtaTileShapeMNK{}) * Stages> smem_row; + alignas(16) array_aligned(CtaTileShapeMNK{}) * Stages> smem_row; }; struct Arguments { @@ -562,9 +583,9 @@ struct Sm90RowBroadcast { Sm90RowBroadcast() { } CUTLASS_HOST_DEVICE - Sm90RowBroadcast(Params const& params, SharedStorage& shared_storage) + Sm90RowBroadcast(Params const& params, SharedStorage const& shared_storage) : params(params), - smem_row(shared_storage.smem_row.data()) { } + smem_row(const_cast(shared_storage.smem_row.data())) { } Params params; Element* 
smem_row; @@ -613,7 +634,9 @@ struct Sm90RowBroadcast { }; template < + class ProblemShapeMNKL, class TileShapeMNK, + class TileCoordMNKL, class EpilogueTile > CUTLASS_DEVICE auto @@ -625,8 +648,9 @@ struct Sm90RowBroadcast { int thread_idx) { auto [M, N, K, L] = problem_shape_mnkl; + auto [m, n, k, l] = tile_coord_mnkl; Tensor mRow = make_tensor(make_gmem_ptr(params.ptr_row), make_shape(M,N,L), params.dRow); - Tensor gRow = sm90_tensor_to_cta_tile(mRow, tile_shape_mnk, tile_coord_mnkl); // (CTA_M,CTA_N) + Tensor gRow = local_tile(mRow, take<0,2>(tile_shape_mnk), make_coord(m,n,l)); // (CTA_M,CTA_N) Tensor sRow = make_tensor(make_smem_ptr(smem_row), // (CTA_M,CTA_N,PIPE) make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{}), Stages), make_stride(_0{},_1{},size<1>(CtaTileShapeMNK{}))); @@ -680,7 +704,9 @@ struct Sm90RowBroadcast { template < bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy + class ProblemShapeMNKL, class TileShapeMNK, + class TileCoordMNKL, class EpilogueTile, class TiledCopy, class SrcTensor @@ -757,13 +783,15 @@ struct Sm90ColBroadcast { Sm90ColBroadcast() { } CUTLASS_HOST_DEVICE - Sm90ColBroadcast(Params const& params, SharedStorage& shared_storage) + Sm90ColBroadcast(Params const& params, SharedStorage const& shared_storage) : params(params) { } Params params; template < + class ProblemShapeMNKL, class TileShapeMNK, + class TileCoordMNKL, class EpilogueTile > CUTLASS_DEVICE auto @@ -819,7 +847,9 @@ struct Sm90ColBroadcast { template < bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy + class ProblemShapeMNKL, class TileShapeMNK, + class TileCoordMNKL, class EpilogueTile, class TiledCopy, class SrcTensor diff --git a/include/cutlass/epilogue/fusion/sm90_visitor_store_tma_warpspecialized.hpp b/include/cutlass/epilogue/fusion/sm90_visitor_store_tma_warpspecialized.hpp index 7da6d09c49..8e1ffb08a2 100644 --- a/include/cutlass/epilogue/fusion/sm90_visitor_store_tma_warpspecialized.hpp +++ b/include/cutlass/epilogue/fusion/sm90_visitor_store_tma_warpspecialized.hpp @@ -83,7 +83,8 @@ struct Sm90AuxStore { cute::conditional_t, Step<_1,_2,_3>>{} )); struct SharedStorage { - alignas(128) array_aligned smem_aux; + alignas(cutlass::detail::alignment_for_swizzle(SmemLayout{})) + array_aligned smem_aux; }; struct Arguments { @@ -125,9 +126,9 @@ struct Sm90AuxStore { Sm90AuxStore() { } CUTLASS_HOST_DEVICE - Sm90AuxStore(Params const& params, SharedStorage& shared_storage) + Sm90AuxStore(Params const& params, SharedStorage const& shared_storage) : params_ptr(¶ms), - smem_aux(shared_storage.smem_aux.data()) { } + smem_aux(const_cast(shared_storage.smem_aux.data())) { } Params const* params_ptr; Element* smem_aux; @@ -143,7 +144,9 @@ struct Sm90AuxStore { } template < - class TileShapeMNK + class ProblemShapeMNKL, + class TileShapeMNK, + class TileCoordMNKL > CUTLASS_DEVICE auto get_producer_load_callbacks( @@ -233,7 +236,9 @@ struct Sm90AuxStore { template < bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy + class ProblemShapeMNKL, class TileShapeMNK, + class TileCoordMNKL, class TiledCopy, class SrcTensor > @@ -248,14 +253,16 @@ struct Sm90AuxStore { SrcTensor const& tCrC) { auto [M, N, K, L] = problem_shape_mnkl; + auto [m, n, k, l] = tile_coord_mnkl; Tensor mAux = params_ptr->tma_store_aux.get_tma_tensor(make_shape(M,N,L)); // (M,N,L) - Tensor gAux = sm90_tensor_to_cta_tile(mAux, tile_shape_mnk, tile_coord_mnkl); // (CTA_M,CTA_N) + Tensor gAux = 
local_tile(mAux, take<0,2>(tile_shape_mnk), make_coord(m,n,l)); // (CTA_M,CTA_N) Tensor tC_gAux = sm90_partition_for_epilogue( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) gAux, epi_tile, tiled_copy, thread_idx); Tensor tC_rAux = make_tensor(take<0,3>(shape(tC_gAux))); // (CPY,CPY_M,CPY_N) - Tensor sAux_epi = make_tensor(make_smem_ptr(smem_aux), SmemLayout{}); // (EPI_TILE_M,EPI_TILE_N,PIPE) + Tensor sAux_epi = cute::as_position_independent_swizzle_tensor( + make_tensor(make_smem_ptr(smem_aux), SmemLayout{})); // (EPI_TILE_M,EPI_TILE_N,PIPE) Tensor gAux_epi = local_tile(gAux, epi_tile, _); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N) auto tiled_r2s = conditional_return( @@ -297,8 +304,8 @@ template < struct Sm90ScalarReduction { static_assert( (cute::is_same_v>) || // scalar reduction, e.g. tensor max element - (cute::is_same_v>)); // batched scalar reduction, e.g. per-batch max element - + (cute::is_same_v>) || // batched scalar reduction, e.g. per-batch max element + (cute::is_same_v>)); struct SharedStorage { }; struct Arguments { @@ -329,13 +336,15 @@ struct Sm90ScalarReduction { Sm90ScalarReduction() { } CUTLASS_HOST_DEVICE - Sm90ScalarReduction(Params const& params, SharedStorage& shared_storage) + Sm90ScalarReduction(Params const& params, SharedStorage const& shared_storage) : params(params) { } Params const params; template < + class ProblemShapeMNKL, class TileShapeMNK, + class TileCoordMNKL, class EpilogueTile > CUTLASS_DEVICE auto @@ -417,7 +426,9 @@ struct Sm90ScalarReduction { template < bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy + class ProblemShapeMNKL, class TileShapeMNK, + class TileCoordMNKL, class EpilogueTile, class TiledCopy, class SrcTensor @@ -502,13 +513,15 @@ struct Sm90RowReduction { Sm90RowReduction() { } CUTLASS_HOST_DEVICE - Sm90RowReduction(Params const& params, SharedStorage& shared_storage) + Sm90RowReduction(Params const& params, SharedStorage const& shared_storage) : params(params) { } Params params; template < + class ProblemShapeMNKL, class TileShapeMNK, + class TileCoordMNKL, class EpilogueTile > CUTLASS_DEVICE auto @@ -619,7 +632,9 @@ struct Sm90RowReduction { template < bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy + class ProblemShapeMNKL, class TileShapeMNK, + class TileCoordMNKL, class EpilogueTile, class TiledCopy, class SrcTensor @@ -707,13 +722,15 @@ struct Sm90ColReduction { Sm90ColReduction() { } CUTLASS_HOST_DEVICE - Sm90ColReduction(Params const& params, SharedStorage& shared_storage) + Sm90ColReduction(Params const& params, SharedStorage const& shared_storage) : params(params) { } Params params; template < + class ProblemShapeMNKL, class TileShapeMNK, + class TileCoordMNKL, class EpilogueTile > CUTLASS_DEVICE auto @@ -765,10 +782,11 @@ struct Sm90ColReduction { Array frg_I = convert_input(frg_input); Tensor tCrCol_mn = tCrCol(_,_,_,epi_m,epi_n); + Tensor tCcCol_mn = tCcCol(_,_,_,epi_m,epi_n); CUTLASS_PRAGMA_UNROLL for (int i = 0; i < FragmentSize; ++i) { - if (elem_less(tCcCol(i), residue_mn)) { + if (elem_less(tCcCol_mn(i), residue_mn)) { ElementCompute& tCrCol_vmn = tCrCol_mn(epi_v * FragmentSize + i); tCrCol_vmn = reduce_input(tCrCol_vmn, frg_I[i]); } @@ -808,7 +826,9 @@ struct Sm90ColReduction { template < bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy + class ProblemShapeMNKL, class TileShapeMNK, + class TileCoordMNKL, class EpilogueTile, class TiledCopy, class SrcTensor diff --git 
a/include/cutlass/epilogue/fusion/sm90_visitor_tma_warpspecialized.hpp b/include/cutlass/epilogue/fusion/sm90_visitor_tma_warpspecialized.hpp index 7750701e18..85b69333d6 100644 --- a/include/cutlass/epilogue/fusion/sm90_visitor_tma_warpspecialized.hpp +++ b/include/cutlass/epilogue/fusion/sm90_visitor_tma_warpspecialized.hpp @@ -51,34 +51,12 @@ using cute::tuple; namespace detail { -// Convenience aliases -using ProblemShapeMNKL = tuple; -using TileCoordMNKL = tuple; - ///////////////////////////////////////////////////////////////////////////////////////////////// // // Partitioning Helpers // ///////////////////////////////////////////////////////////////////////////////////////////////// -template < - class Engine, class LayoutMNL, - class TileShapeMNK -> -CUTLASS_HOST_DEVICE -constexpr auto -sm90_tensor_to_cta_tile( - Tensor mT, // (M,N,L) - TileShapeMNK tile_shape_mnk, // (CTA_M,CTA_N,CTA_K) - TileCoordMNKL tile_coord_mnkl) { - using _X = Underscore; - - auto [m, n, k, l] = tile_coord_mnkl; - Tensor mT_mnl = local_tile(mT, tile_shape_mnk, make_coord(_,_,_), Step<_1,_1,_X>{}); // (CTA_M,CTA_N) - - return mT_mnl(_,_,m,n,l); -} - template < bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy class CtaTileMN, @@ -106,6 +84,7 @@ template < bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy class Engine, class LayoutMNL, class TileShapeMNK, + class TileCoordMNKL, class EpilogueTile, class TiledCopy > @@ -118,7 +97,8 @@ sm90_partition_for_epilogue( EpilogueTile epi_tile, // (EPI_TILE_M,EPI_TILE_N) TiledCopy tiled_copy, int thread_idx) { - Tensor cT = sm90_tensor_to_cta_tile(mT, tile_shape_mnk, tile_coord_mnkl); // (CTA_M,CTA_N) + auto [m, n, k, l] = tile_coord_mnkl; + Tensor cT = local_tile(mT, take<0,2>(tile_shape_mnk), make_coord(m,n,l)); // (CTA_M,CTA_N) Tensor tCcT = sm90_partition_for_epilogue(cT, epi_tile, tiled_copy, thread_idx); // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) @@ -156,7 +136,7 @@ struct Sm90VisitorImplBase { Sm90VisitorImplBase() {} CUTLASS_HOST_DEVICE - Sm90VisitorImplBase(Params const& params, SharedStorage& shared_storage) + Sm90VisitorImplBase(Params const& params, SharedStorage const& shared_storage) : ops(transform_apply(tuple{}, params, shared_storage, [] (auto&& op, auto const& op_params, auto&& op_storage) { using Op = cute::remove_cvref_t; @@ -262,7 +242,9 @@ struct Sm90VisitorImpl : Sm90VisitorImplBase { // Producer load callbacks factory // All operations must redefine this, but most can just dispatch to the base impl template < + class ProblemShapeMNKL, class TileShapeMNK, + class TileCoordMNKL, class EpilogueTile > CUTLASS_DEVICE auto @@ -363,7 +345,9 @@ struct Sm90VisitorImpl : Sm90VisitorImplBase { // All operations must redefine this template < bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy + class ProblemShapeMNKL, class TileShapeMNK, + class TileCoordMNKL, class EpilogueTile, class TiledCopy, class SrcTensor @@ -446,7 +430,9 @@ struct Sm90TreeVisitor : Sm90VisitorImpl { template < bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy + class ProblemShapeMNKL, class TileShapeMNK, + class TileCoordMNKL, class EpilogueTile, class TiledCopy, class SrcTensor @@ -516,7 +502,9 @@ struct Sm90SplitTreeVisitor : Sm90VisitorImpl { template < bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy + class ProblemShapeMNKL, class TileShapeMNK, + class TileCoordMNKL, class 
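// [Editor's illustrative note -- not part of the patch] The removed sm90_tensor_to_cta_tile
// helper above is replaced inline with cute::local_tile over the MN modes only, i.e. for an
// (M,N,L) tensor mT, CTA tile shape (CTA_M,CTA_N,CTA_K), and tile coordinate (m,n,k,l):
//
//   auto [m, n, k, l] = tile_coord_mnkl;
//   Tensor cT = local_tile(mT, take<0,2>(tile_shape_mnk), make_coord(m, n, l));  // (CTA_M,CTA_N)
//
// which selects the (CTA_M,CTA_N) tile at block coordinate (m,n) within batch l, matching what
// the old helper computed via local_tile with a Step<_1,_1,_X> projection.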
EpilogueTile, class TiledCopy, class SrcTensor @@ -651,9 +641,11 @@ namespace detail { template struct Sm90VisitorImplBase { - struct SharedStorage { - typename Op0::SharedStorage op_0; - }; + // Retain tuple for SharedStorage because empty structs have 1B alignment + // tuples use multiple inheritance, avoids this problem + using SharedStorage = tuple< + typename Op0::SharedStorage + >; struct Arguments { typename Op0::Arguments op_0; @@ -675,9 +667,9 @@ struct Sm90VisitorImplBase { Sm90VisitorImplBase() {} CUTLASS_HOST_DEVICE - Sm90VisitorImplBase(Params const& params, SharedStorage& shared_storage) + Sm90VisitorImplBase(Params const& params, SharedStorage const& shared_storage) : ops({ - Op0(params.op_0, shared_storage.op_0) + Op0(params.op_0, get<0>(shared_storage)) }) {} tuple ops; @@ -686,10 +678,10 @@ struct Sm90VisitorImplBase { template struct Sm90VisitorImplBase { - struct SharedStorage { - typename Op0::SharedStorage op_0; - typename Op1::SharedStorage op_1; - }; + using SharedStorage = tuple< + typename Op0::SharedStorage, + typename Op1::SharedStorage + >; struct Arguments { typename Op0::Arguments op_0; @@ -714,10 +706,10 @@ struct Sm90VisitorImplBase { Sm90VisitorImplBase() {} CUTLASS_HOST_DEVICE - Sm90VisitorImplBase(Params const& params, SharedStorage& shared_storage) + Sm90VisitorImplBase(Params const& params, SharedStorage const& shared_storage) : ops({ - Op0(params.op_0, shared_storage.op_0), - Op1(params.op_1, shared_storage.op_1) + Op0(params.op_0, get<0>(shared_storage)), + Op1(params.op_1, get<1>(shared_storage)) }) {} tuple ops; @@ -726,11 +718,11 @@ struct Sm90VisitorImplBase { template struct Sm90VisitorImplBase { - struct SharedStorage { - typename Op0::SharedStorage op_0; - typename Op1::SharedStorage op_1; - typename Op2::SharedStorage op_2; - }; + using SharedStorage = tuple< + typename Op0::SharedStorage, + typename Op1::SharedStorage, + typename Op2::SharedStorage + >; struct Arguments { typename Op0::Arguments op_0; @@ -758,11 +750,11 @@ struct Sm90VisitorImplBase { Sm90VisitorImplBase() {} CUTLASS_HOST_DEVICE - Sm90VisitorImplBase(Params const& params, SharedStorage& shared_storage) + Sm90VisitorImplBase(Params const& params, SharedStorage const& shared_storage) : ops({ - Op0(params.op_0, shared_storage.op_0), - Op1(params.op_1, shared_storage.op_1), - Op2(params.op_2, shared_storage.op_2) + Op0(params.op_0, get<0>(shared_storage)), + Op1(params.op_1, get<1>(shared_storage)), + Op2(params.op_2, get<2>(shared_storage)) }) {} tuple ops; @@ -771,12 +763,12 @@ struct Sm90VisitorImplBase { template struct Sm90VisitorImplBase { - struct SharedStorage { - typename Op0::SharedStorage op_0; - typename Op1::SharedStorage op_1; - typename Op2::SharedStorage op_2; - typename Op3::SharedStorage op_3; - }; + using SharedStorage = tuple< + typename Op0::SharedStorage, + typename Op1::SharedStorage, + typename Op2::SharedStorage, + typename Op3::SharedStorage + >; struct Arguments { typename Op0::Arguments op_0; @@ -807,12 +799,12 @@ struct Sm90VisitorImplBase { Sm90VisitorImplBase() {} CUTLASS_HOST_DEVICE - Sm90VisitorImplBase(Params const& params, SharedStorage& shared_storage) + Sm90VisitorImplBase(Params const& params, SharedStorage const& shared_storage) : ops({ - Op0(params.op_0, shared_storage.op_0), - Op1(params.op_1, shared_storage.op_1), - Op2(params.op_2, shared_storage.op_2), - Op3(params.op_3, shared_storage.op_3) + Op0(params.op_0, get<0>(shared_storage)), + Op1(params.op_1, get<1>(shared_storage)), + Op2(params.op_2, get<2>(shared_storage)), + 
Op3(params.op_3, get<3>(shared_storage)) }) {} tuple ops; diff --git a/include/cutlass/epilogue/thread/activation.h b/include/cutlass/epilogue/thread/activation.h index 526d46b569..221aa0f3cd 100644 --- a/include/cutlass/epilogue/thread/activation.h +++ b/include/cutlass/epilogue/thread/activation.h @@ -49,38 +49,6 @@ namespace cutlass { namespace epilogue { namespace thread { -///////////////////////////////////////////////////////////////////////////////////////////////// -template -struct LinearCombinationGenericParams { - T alpha; ///< scales accumulators - T beta; ///< scales source tensor - T const *alpha_ptr; ///< pointer to accumulator scalar - if not null, loads it from memory - T const *beta_ptr; ///< pointer to source scalar - if not null, loads it from memory - - // - // Methods - // - - CUTLASS_HOST_DEVICE - LinearCombinationGenericParams(): - alpha(T(1)), - beta(T(0)), - alpha_ptr(nullptr), - beta_ptr(nullptr) { } - - CUTLASS_HOST_DEVICE - LinearCombinationGenericParams( - T alpha, - T beta = T(0) - ): alpha(alpha), beta(beta), alpha_ptr(nullptr), beta_ptr(nullptr) { } - - CUTLASS_HOST_DEVICE - LinearCombinationGenericParams( - T const *alpha_ptr, - T const *beta_ptr = nullptr - ): alpha(0), beta(0), alpha_ptr(alpha_ptr), beta_ptr(beta_ptr) { } -}; - ///////////////////////////////////////////////////////////////////////////////////////////////// // Identity operator @@ -92,27 +60,67 @@ struct Identity { T operator()(T value) const { return value; } +}; - using Params = LinearCombinationGenericParams; +template +struct Identity > { + CUTLASS_HOST_DEVICE + Array operator()(Array const &value) const { + return value; + } +}; + +/// Scale operator +template +struct Scale { + struct Arguments { + T scale = T(1); + }; CUTLASS_HOST_DEVICE - T operator()(T const &value, Params const ¶ms_) const { - return this->operator()(value); + T operator()(T const& value, T const& scale) const { + multiplies mul; + return mul(scale, value); + } + + CUTLASS_HOST_DEVICE + T operator()(T const& value, Arguments const& args = Arguments()) const { + return this->operator()(value, args.scale); } }; template -struct Identity > { +struct Scale> { + using Arguments = typename Scale::Arguments; + CUTLASS_HOST_DEVICE - Array operator()(Array const &value) const { - return value; + Array operator()(Array const& values, T const& scale) const { + multiplies> mul; + return mul(scale, values); } - using Params = LinearCombinationGenericParams; + CUTLASS_HOST_DEVICE + Array operator()(Array const& values, Arguments const& args = Arguments()) const { + return this->operator()(values, args.scale); + } +}; + +/// Specialization to compose other activations with a defined unary operator +/// e.g. Scale> +template
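// [Editor's illustrative sketch -- not part of the patch] The Scale functor introduced above
// carries its hyperparameter in a nested Arguments struct so the EVT fusion can forward it at
// runtime. A standalone usage sketch, assuming cutlass/epilogue/thread/activation.h is included:
//
//   cutlass::epilogue::thread::Scale<float> scale_op;
//   float y = scale_op(3.0f, 2.0f);   // scale * value == 6.0f
//
//   cutlass::epilogue::thread::Scale<cutlass::Array<float, 4>> scale_vec;
//   // operates element-wise on the Array, multiplying each lane by the scale
//
// The composing specialization that begins below (e.g. wrapping another activation such as
// ReLu) is intended to apply the wrapped activation first and then the scale, reusing
// Scale<T>::Arguments for its hyperparameter.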