@GZGavinZhao
Contributor

@GZGavinZhao GZGavinZhao commented Nov 17, 2025

Proposed changes

Support compiling and running CK on RDNA1 GPUs (including the new gfx10-1-generic target). This just involves:

  1. Relaxing some #if defined(__gfx10x__) guards to include defined(__gfx101__)
  2. Small CMake changes to fix a link error in which ckProfiler is "unable to find -ldevice_quantization", since device quantization instances seem to be incompatible with or unavailable for RDNA1 GPUs.
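
For readers unfamiliar with the guard pattern in item 1, here is a minimal sketch of what "relaxing" an architecture guard looks like. The `SKETCH_*` macros, `kHasGfx10Path` and `arch_family` are hypothetical names used only for illustration; they are not CK's actual identifiers, and the real guards in CK may be structured differently. The per-target macros (`__gfx1010__` and so on) are the ones the compiler defines when building device code for that target.

```cpp
#include <cassert>
#include <cstring>

// Hypothetical family macros built from the compiler's per-target macros.
// CK's own family macro names may differ from these.
#if defined(__gfx1010__) || defined(__gfx1011__) || defined(__gfx1012__) || \
    defined(__gfx1013__)
#define SKETCH_GFX101 1 // RDNA1 family
#endif
#if defined(__gfx1030__)
#define SKETCH_GFX103 1 // one RDNA2 target, enough for the sketch
#endif

// Before relaxing, the guard would read: #if defined(SKETCH_GFX103)
// After relaxing, RDNA1 is accepted as well:
#if defined(SKETCH_GFX103) || defined(SKETCH_GFX101)
constexpr bool kHasGfx10Path = true;
#else
constexpr bool kHasGfx10Path = false;
#endif

// Reports which family this translation unit was compiled for.
inline const char* arch_family()
{
#if defined(SKETCH_GFX101)
    return "gfx101";
#elif defined(SKETCH_GFX103)
    return "gfx103";
#else
    return "other"; // host compilation or any other target
#endif
}
```

Compiled for the host, none of the target macros are defined, so `arch_family()` returns `"other"` and `kHasGfx10Path` is `false`; the guarded path only becomes active in a device compilation for one of the listed targets.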

Fixes #2411.

Fixes #3185.

Test plan

Built CK on an AWS g4ad.xlarge instance (gfx1011) and ran tests as follows:

cmake -B build -S . -GNinja \
        -DCMAKE_PREFIX_PATH=/opt/rocm \
        -DCMAKE_CXX_COMPILER=/opt/rocm/llvm/bin/clang++ \
        -DCMAKE_BUILD_TYPE=RelWithDebInfo \
        -DGPU_TARGETS=gfx1011 \
        -DBUILD_TESTING=ON
ctest --test-dir build -j1 --output-on-failure

109/110 tests reliably pass. test_batchnorm_infer_rank_4 fails (<1% of values are wrong) due to a compiler bug documented at the end.

Checklist

Please put an x into the boxes that apply. You can also fill these out after creating the PR. If you're not sure, please don't hesitate to ask.

  • I have added tests relevant to the introduced functionality, and the unit tests are passing locally
  • I have added the test to REGRESSION_TESTS list defined at the top of CMakeLists.txt in tests/CMakeLists.txt, IF the test takes more than 30 seconds to run.
  • I have added inline documentation which enables the maintainers with understanding the motivation
  • I have removed the stale documentation which is no longer relevant after this pull request
  • (If this change is user-facing) I have added release notes which provide the end users with a brief summary of the improvement from this pull request
  • I have run clang-format on all changed files
  • Any dependent changes have been merged

Discussion

I don't understand the hardware capabilities of RDNA1 or which capabilities CK requires, so I hope a CK maintainer can check whether RDNA1 GPUs are indeed incompatible with device quantization and whether there is a more robust solution to the linker error.

With the code unmodified, test_batchnorm_infer_rank_4 fails with the output below:

Test output
Start testing: Nov 17 04:41 UTC
----------------------------------------------------------
93/110 Testing: test_batchnorm_infer_rank_4
93/110 Test: test_batchnorm_infer_rank_4
Command: "/mnt/work/composable_kernel/build/bin/test_batchnorm_infer_rank_4"
Directory: /mnt/work/composable_kernel/build/test/batchnorm
"test_batchnorm_infer_rank_4" start time: Nov 17 04:41 UTC
Output:
----------------------------------------------------------
[==========] Running 8 tests from 4 test suites.
[----------] Global test environment set-up.
[----------] 2 tests from TestBatchNormInferRank4/0, where TypeParam = std::tuple<_Float16,_Float16,float,_Float16,_Float16,float>
[ RUN      ] TestBatchNormInferRank4/0.nhwc
found 14 instances
found 14 instances
found 14 instances
found 14 instances
found 14 instances
found 14 instances
found 14 instances
found 14 instances
[       OK ] TestBatchNormInferRank4/0.nhwc (7597 ms)
[ RUN      ] TestBatchNormInferRank4/0.nchw
found 14 instances
found 14 instances
found 14 instances
found 14 instances
found 14 instances
found 14 instances
found 14 instances
found 14 instances
[       OK ] TestBatchNormInferRank4/0.nchw (4023 ms)
[----------] 2 tests from TestBatchNormInferRank4/0 (11620 ms total)

[----------] 2 tests from TestBatchNormInferRank4/1, where TypeParam = std::tuple<float,float,float,float,float,float>
[ RUN      ] TestBatchNormInferRank4/1.nhwc
found 14 instances
found 14 instances
found 14 instances
found 14 instances
found 14 instances
found 14 instances
found 14 instances
found 14 instances
[       OK ] TestBatchNormInferRank4/1.nhwc (2442 ms)
[ RUN      ] TestBatchNormInferRank4/1.nchw
found 14 instances
found 14 instances
found 14 instances
found 14 instances
found 14 instances
found 14 instances
found 14 instances
found 14 instances
[       OK ] TestBatchNormInferRank4/1.nchw (1606 ms)
[----------] 2 tests from TestBatchNormInferRank4/1 (4049 ms total)

[----------] 2 tests from TestBatchNormInferRank4/2, where TypeParam = std::tuple<unsigned short,unsigned short,float,unsigned short,unsigned short,float>
[ RUN      ] TestBatchNormInferRank4/2.nhwc
found 14 instances
found 14 instances
found 14 instances
found 14 instances
found 14 instances
found 14 instances
found 14 instances
found 14 instances
[       OK ] TestBatchNormInferRank4/2.nhwc (2150 ms)
[ RUN      ] TestBatchNormInferRank4/2.nchw
found 14 instances
found 14 instances
found 14 instances
found 14 instances
found 14 instances
found 14 instances
found 14 instances
found 14 instances
[       OK ] TestBatchNormInferRank4/2.nchw (1415 ms)
[----------] 2 tests from TestBatchNormInferRank4/2 (3566 ms total)

[----------] 2 tests from TestBatchNormInferRank4/3, where TypeParam = std::tuple<double,double,double,double,double,double>
[ RUN      ] TestBatchNormInferRank4/3.nhwc
found 9 instances
y results        out[3833888] != ref[3833888]: 0.2905493 != -0.001812637
y results        out[3833892] != ref[3833892]: 0.4788376 != 0.08403782
y results        out[3833896] != ref[3833896]: 0.8666859 != 1.079449
y results        out[3833900] != ref[3833900]: -0.9379467 != 0.3999299
y results        out[3837984] != ref[3837984]: 0.2905493 != -0.2596534
max err: 7.314777e+38, number of errors: 94, 0.00149409% wrong values
/mnt/work/composable_kernel/test/batchnorm/batchnorm_infer_rank_4.cpp:75: Failure
Value of: pass
  Actual: false
Expected: true

found 9 instances
y results        out[262272] != ref[262272]: -0.9631336 != -0.6550316
y results        out[262276] != ref[262276]: 0.1491528 != 0.4125393
y results        out[262280] != ref[262280]: 0.5467949 != 0.7510509
y results        out[262284] != ref[262284]: 0.9366151 != 0.7876599
y results        out[264320] != ref[264320]: -0.9631336 != -1.087958
max err: 7.909358e+38, number of errors: 188, 0.00298818% wrong values
/mnt/work/composable_kernel/test/batchnorm/batchnorm_infer_rank_4.cpp:75: Failure
Value of: pass
  Actual: false
Expected: true

found 9 instances
found 9 instances
found 9 instances
found 9 instances
y results        out[858112] != ref[858112]: -0.0125578 != -0.6763946
y results        out[858116] != ref[858116]: 0.8256116 != 0.2654074
y results        out[858120] != ref[858120]: -0.6301264 != -0.715658
y results        out[858124] != ref[858124]: -0.9638686 != -0.7612199
y results        out[858240] != ref[858240]: -0.0125578 != -0.04071561
max err: 9.775917e+38, number of errors: 31, 0.00295639% wrong values
/mnt/work/composable_kernel/test/batchnorm/batchnorm_infer_rank_4.cpp:75: Failure
Value of: pass
  Actual: false
Expected: true

[  FAILED  ] TestBatchNormInferRank4/3.nhwc, where TypeParam = std::tuple<double,double,double,double,double,double> (1305 ms)
[ RUN      ] TestBatchNormInferRank4/3.nchw
found 9 instances
found 9 instances
y results        out[2392288] != ref[2392288]: 0.3779397 != 0.4924658
y results        out[2392292] != ref[2392292]: 0.3779397 != 1.439335
y results        out[2392296] != ref[2392296]: 0.3779397 != 0.5320557
y results        out[2392300] != ref[2392300]: 7.661843e+38 != 2.629553
y results        out[2396384] != ref[2396384]: 0.5149692 != 0.3845829
max err: 7.661843e+38, number of errors: 117, 0.001859665% wrong values
/mnt/work/composable_kernel/test/batchnorm/batchnorm_infer_rank_4.cpp:75: Failure
Value of: pass
  Actual: false
Expected: true

found 9 instances
y results        out[3408192] != ref[3408192]: -0.71014 != -0.6969489
y results        out[3408196] != ref[3408196]: -0.71014 != -0.8006292
y results        out[3408200] != ref[3408200]: -0.71014 != -0.6480811
y results        out[3408204] != ref[3408204]: -0.71014 != -0.6471095
y results        out[3410240] != ref[3410240]: 0.6120615 != 0.458294
max err: 6.493006e+38, number of errors: 155, 0.002463659% wrong values
/mnt/work/composable_kernel/test/batchnorm/batchnorm_infer_rank_4.cpp:75: Failure
Value of: pass
  Actual: false
Expected: true

found 9 instances
found 9 instances
found 9 instances
found 9 instances
[  FAILED  ] TestBatchNormInferRank4/3.nchw, where TypeParam = std::tuple<double,double,double,double,double,double> (1426 ms)
[----------] 2 tests from TestBatchNormInferRank4/3 (2732 ms total)

[----------] Global test environment tear-down
[==========] 8 tests from 4 test suites ran. (21968 ms total)
[  PASSED  ] 6 tests.
[  FAILED  ] 2 tests, listed below:
[  FAILED  ] TestBatchNormInferRank4/3.nhwc, where TypeParam = std::tuple<double,double,double,double,double,double>
[  FAILED  ] TestBatchNormInferRank4/3.nchw, where TypeParam = std::tuple<double,double,double,double,double,double>

 2 FAILED TESTS
<end of output>
Test time =  22.03 sec
----------------------------------------------------------
Test Failed.
"test_batchnorm_infer_rank_4" end time: Nov 17 04:42 UTC
"test_batchnorm_infer_rank_4" time elapsed: 00:00:22
----------------------------------------------------------

End testing: Nov 17 04:42 UTC

SMOKE_TEST =  22.03 sec*proc
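
To put the "% wrong values" figures in the log above in perspective, here is a rough sketch of how such a mismatch report could be computed. This is not CK's verification code: `MismatchReport`, `count_mismatches`, `percent_wrong` and the tolerance parameters are all hypothetical. The element counts used in the note after the block are inferred from the logged percentages and are not stated anywhere in the log.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical summary of a device-vs-reference comparison.
struct MismatchReport
{
    std::size_t num_errors; // elements outside tolerance
    double max_err;         // largest absolute error among them
};

// Percentage of mismatching elements, as in "0.00149409% wrong values".
inline double percent_wrong(std::size_t num_errors, std::size_t total)
{
    return total == 0 ? 0.0 : 100.0 * static_cast<double>(num_errors) / total;
}

// Counts elements where |out - ref| exceeds atol + rtol * |ref|.
// Assumes out and ref have the same length.
inline MismatchReport count_mismatches(const std::vector<double>& out,
                                       const std::vector<double>& ref,
                                       double atol,
                                       double rtol)
{
    MismatchReport report{0, 0.0};
    for(std::size_t i = 0; i < ref.size(); ++i)
    {
        const double err = std::abs(out[i] - ref[i]);
        if(err > atol + rtol * std::abs(ref[i]))
        {
            ++report.num_errors;
            report.max_err = std::max(report.max_err, err);
        }
    }
    return report;
}
```

As a sanity check on the log, 94 errors over an assumed 6,291,456 elements gives about 0.001494%, and 31 errors over an assumed 1,048,576 elements gives about 0.002956%, which line up with the percentages reported for the first and third failing instances. Only a handful of elements are wrong in each run, but the reported max err values on the order of 1e+38 show that some of them are wildly off rather than slightly outside tolerance.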

However, after applying the patch below, which adds only debug output, the test succeeds no matter how many times I run it:

Logging patch
diff --git i/include/ck/tensor_operation/gpu/device/impl/device_elementwise_dynamic_vector_dims_impl.hpp w/include/ck/tensor_operation/gpu/device/impl/device_elementwise_dynamic_vector_dims_impl.hpp
index bdc6dc998..df7c451bb 100644
--- i/include/ck/tensor_operation/gpu/device/impl/device_elementwise_dynamic_vector_dims_impl.hpp
+++ w/include/ck/tensor_operation/gpu/device/impl/device_elementwise_dynamic_vector_dims_impl.hpp
@@ -282,6 +282,9 @@ struct DeviceElementwiseImpl
             const bool in_out_same_vector_dim = GetLowestStrideDim(arg.inStridesArray_[I0]) ==
                                                 GetLowestStrideDim(arg.outStridesArray_[I0]);
 
+            printf("[CK_DEBUG] DeviceElementwiseImpl::Run: M0=%ld, M1=%ld, grid_size=%ld, BlockSize=%d, same_vector_dim=%d\n",
+                   static_cast<long>(M0), static_cast<long>(M1), static_cast<long>(grid_size), BlockSize, static_cast<int>(in_out_same_vector_dim));
+
             const auto kernel = in_out_same_vector_dim
                                     ? kernel_elementwise<GridwiseElementwiseOpSameInOutVectorDim,
                                                          InGridDescTuple,
diff --git i/include/ck/tensor_operation/gpu/grid/gridwise_elementwise_2d.hpp w/include/ck/tensor_operation/gpu/grid/gridwise_elementwise_2d.hpp
index 839a68a97..588c39f3f 100644
--- i/include/ck/tensor_operation/gpu/grid/gridwise_elementwise_2d.hpp
+++ w/include/ck/tensor_operation/gpu/grid/gridwise_elementwise_2d.hpp
@@ -33,6 +33,22 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
                        const Block2TileMap block_2_tile_map,
                        const ElementwiseOperation elementwise_op)
 {
+    if(threadIdx.x == 0 && blockIdx.x == 0)
+    {
+        printf("[CK_DEBUG] kernel_elementwise executing on block 0, thread 0, arch=%d\n",
+#if defined(__gfx1010__)
+               1010
+#elif defined(__gfx1011__)
+               1011
+#elif defined(__gfx1012__)
+               1012
+#elif defined(__gfx1030__)
+               1030
+#else
+               0
+#endif
+        );
+    }
     GridwiseElementwiseFunctor::Run(in_grid_desc_tuple,
                                     out_grid_desc_tuple,
                                     p_in_global_tuple,
diff --git i/profiler/include/profiler/profile_batchnorm_infer_impl.hpp w/profiler/include/profiler/profile_batchnorm_infer_impl.hpp
index 5ae150f26..7a544971f 100644
--- i/profiler/include/profiler/profile_batchnorm_infer_impl.hpp
+++ w/profiler/include/profiler/profile_batchnorm_infer_impl.hpp
@@ -179,6 +179,8 @@ bool profile_batchnorm_infer_impl(int do_verification,
         ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
             DeviceOp>::GetInstances();
 
+    std::cout << "[CK_DEBUG] XDataType size=" << sizeof(XDataType)
+              << ", YDataType size=" << sizeof(YDataType) << std::endl;
     std::cout << "found " << instance_ptrs.size() << " instances" << std::endl;
 
     std::string best_instance_name;
@@ -254,14 +256,19 @@ bool profile_batchnorm_infer_impl(int do_verification,
         if(inst_ptr->IsSupportedArgument(argument_ptr.get()))
         {
             num_kernel++;
+            std::cout << "[CK_DEBUG] Instance " << num_kernel << " SUPPORTED: "
+                      << inst_ptr->GetTypeString() << std::endl;
             if((instance_index != -1) && (instance_index + 1 != num_kernel))
             {
                 // skip test if instance_index is specified
+                std::cout << "[CK_DEBUG] Skipping instance " << num_kernel
+                          << " (index filter: " << instance_index << ")" << std::endl;
                 continue;
             }
         }
         else
         {
+            std::cout << "[CK_DEBUG] Instance UNSUPPORTED: " << inst_ptr->GetTypeString() << std::endl;
             if(time_kernel)
             {
                 std::cout << inst_ptr->GetTypeString()

UPDATE: this is a compiler bug, see comment

@GZGavinZhao
Contributor Author

GZGavinZhao commented Nov 17, 2025

Oh this issue is a compiler optimization bug that seems to only happen with FP64:

diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_elementwise_2d.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_elementwise_2d.hpp
index 839a68a978..142f084a67 100644
--- a/include/ck/tensor_operation/gpu/grid/gridwise_elementwise_2d.hpp
+++ b/include/ck/tensor_operation/gpu/grid/gridwise_elementwise_2d.hpp
@@ -33,6 +33,16 @@
                        const Block2TileMap block_2_tile_map,
                        const ElementwiseOperation elementwise_op)
 {
+#if defined(__gfx101__)
+    // Workaround for gfx101x FP64 compiler bug: prevent incorrect optimization
+    // The compiler appears to misoptimize FP64 elementwise kernels without this barrier
+    if(threadIdx.x == 0 && blockIdx.x == 0)
+    {
+        // Use volatile pointer dereference to force memory dependency
+        auto ptr = p_in_global_tuple[Number<0>{}];
+        asm volatile("" : "+v"(ptr) :: "memory");
+    }
+#endif
     GridwiseElementwiseFunctor::Run(in_grid_desc_tuple,
                                     out_grid_desc_tuple,
                                     p_in_global_tuple,

This fixes the test. Unfortunately I'm unable to create a minimal working example.
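
For context on what the empty `asm volatile` in the diff does, here is a host-side analogue as a minimal sketch. `opaque_touch` and `sum_with_barrier` are hypothetical helpers, not part of CK. The sketch uses the general-register constraint `"+r"` so that it compiles with GCC or Clang on a CPU, whereas the patch uses `"+v"`, which is the AMDGPU vector-register constraint. The statement emits no instructions; it only tells the compiler that the pointer may have been read and modified and that memory may have changed, which limits what the optimizer can assume across that point.

```cpp
#include <cassert>

// Hypothetical helper: an empty inline-asm statement that the compiler must
// treat as reading and possibly modifying `ptr`, and as touching memory.
// No instructions are emitted for it.
template <typename T>
inline void opaque_touch(T*& ptr)
{
#if defined(__GNUC__) || defined(__clang__)
    asm volatile("" : "+r"(ptr) : : "memory");
#else
    (void)ptr; // other compilers: no barrier in this sketch
#endif
}

// Sums n doubles after passing the pointer through the barrier. The result is
// the same as without the barrier; only the optimizer's assumptions change.
inline double sum_with_barrier(const double* data, int n)
{
    const double* p = data;
    opaque_touch(p);
    double sum = 0.0;
    for(int i = 0; i < n; ++i)
    {
        sum += p[i];
    }
    return sum;
}
```

A barrier like this does not change the program's meaning, which is consistent with the observation that unrelated debug output also made the failure disappear: both perturb the generated code rather than fix the computation.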

@illsilin
Collaborator

Thanks, I appreciate the effort to add support on gfx101x. Our main issue with gfx101x is that we don't have the hardware to keep running and testing on a regular basis. Hence no official support. But I would be open to merging this.

@GZGavinZhao
Contributor Author

Thank you! I completely understand that gfx101x has no official support, so this patch was written in a way such that it shouldn't affect other architectures.

@illsilin
Collaborator

OK, I've made sure this passes through the CI and made inquiries about the hardware. Sounds like we may have something we could use. So I think we can merge this.

@illsilin illsilin self-assigned this Nov 19, 2025
@illsilin illsilin merged commit 07314ac into ROCm:develop Nov 20, 2025
23 of 25 checks passed
@GZGavinZhao
Contributor Author

Thank you!

AviralGoelAMD pushed a commit that referenced this pull request Nov 28, 2025
* Allow compilation for RDNA1 (__gfx101__)

Signed-off-by: Gavin Zhao <git@gzgz.dev>

* More RDNA1 changes

Signed-off-by: Gavin Zhao <git@gzgz.dev>

* Even more RDNA1 changes

Signed-off-by: Gavin Zhao <git@gzgz.dev>

* cmake: skip build quantization for unsupported arches

* add gfx10-1-generic support as well

* add gfx1013 and complete gfx10-1-generic

* fix clang format

* enable DL kernels on gfx101x

---------

Signed-off-by: Gavin Zhao <git@gzgz.dev>
Co-authored-by: illsilin_amdeng <Illia.Silin@amd.com>
Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com>

Development

Successfully merging this pull request may close these issues.

[Issue]: Cannot compile version 7.1.0 for gfx1012, multiple errors in build log
