Skip to content

Commit b8f7642

Browse files
sergey-kozub authored and Google-ML-Automation committed
PR #21822: [XLA:GPU] Add support for SM100a architecture (Blackwell)
Imported from GitHub PR #21822

Created `ShouldUsePtxExtension` helper for the extension suffix (this will also be used for sm120, etc.). CUDA 12.8 was recently released and supports PTX 8.7, but PTX 8.7 is not yet supported by the integrated LLVM (support was added in llvm/llvm-project#124155), so CUDA 12.8 stays associated with PTX 8.6 - this doesn't raise warnings during compilation.

Copybara import of the project:

-- 267cf74 by Sergey Kozub <skozub@nvidia.com>:

Add support for SM100a architecture (Blackwell)

Merging this change closes #21822

COPYBARA_INTEGRATE_REVIEW=#21822 from openxla:devel/sm100a 267cf74
PiperOrigin-RevId: 720806648
1 parent 2b6351b commit b8f7642

File tree

9 files changed, with 36 additions and 17 deletions

xla/service/gpu/llvm_gpu_backend/BUILD

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -99,6 +99,7 @@ cc_library(
9999
"//xla/service/llvm_ir:llvm_command_line_options",
100100
"//xla/stream_executor:device_description",
101101
"//xla/stream_executor:semantic_version",
102+
"//xla/stream_executor/cuda:ptx_compiler_helpers",
102103
"//xla/stream_executor/cuda:subprocess_compilation",
103104
"@com_google_absl//absl/base",
104105
"@com_google_absl//absl/status",

xla/service/gpu/llvm_gpu_backend/nvptx_backend.cc

Lines changed: 10 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -60,6 +60,7 @@ limitations under the License.
6060
#include "xla/service/gpu/llvm_gpu_backend/nvptx_libdevice_path.h"
6161
#include "xla/service/gpu/metrics.h"
6262
#include "xla/service/llvm_ir/llvm_command_line_options.h"
63+
#include "xla/stream_executor/cuda/ptx_compiler_helpers.h"
6364
#include "xla/stream_executor/cuda/subprocess_compilation.h"
6465
#include "xla/stream_executor/device_description.h"
6566
#include "xla/stream_executor/semantic_version.h"
@@ -237,8 +238,8 @@ std::string GetSmName(se::CudaComputeCapability compute_capability) {
237238
int sm_version = 30;
238239
// If the current compute capability isn't known, fallback to the
239240
// most recent version before it.
240-
int supported_versions[] = {90, 89, 87, 86, 80, 75, 72, 70, 62,
241-
61, 60, 53, 52, 50, 37, 35, 32, 30};
241+
int supported_versions[] = {100, 90, 89, 87, 86, 80, 75, 72, 70, 62,
242+
61, 60, 53, 52, 50, 37, 35, 32, 30};
242243
for (int v : supported_versions) {
243244
if (v <= compute_capability_version) {
244245
sm_version = v;
@@ -260,8 +261,9 @@ std::string GetSmName(se::CudaComputeCapability compute_capability) {
260261
// On Hopper, default to sm_90a so that all instructions can be used. But
261262
// only sm_90 is forward compatible, so don't use sm_90a with newer hardware:
262263
// https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#ptx-compatibility
264+
// Similarly for sm_100a (Blackwell).
263265
absl::string_view extension =
264-
(compute_capability.major == 9 && sm_version == 90) ? "a" : "";
266+
stream_executor::ShouldUsePtxExtension(compute_capability) ? "a" : "";
265267
return absl::StrCat("sm_", sm_version, extension);
266268
}
267269

@@ -331,7 +333,7 @@ absl::StatusOr<std::string> CompileToPtx(
331333

332334
namespace {
333335
constexpr stream_executor::SemanticVersion kFallbackPtxVersion{6, 5, 0};
334-
constexpr stream_executor::SemanticVersion kMaxPtxVersion{8, 5, 0};
336+
constexpr stream_executor::SemanticVersion kMaxPtxVersion{8, 6, 0};
335337
} // namespace
336338

337339
stream_executor::SemanticVersion
@@ -354,6 +356,10 @@ DetermineHighestSupportedPtxVersionFromCudaVersion(
354356
if (cuda_version < stream_executor::SemanticVersion{12, 6, 0}) {
355357
return {cuda_version.major() - 4, cuda_version.minor(), 0};
356358
}
359+
// CUDA 12.6 -> PTX 8.5
360+
if (cuda_version < stream_executor::SemanticVersion{12, 7, 0}) {
361+
return {cuda_version.major() - 4, cuda_version.minor() - 1, 0};
362+
}
357363

358364
// Return maximum known PTX version.
359365
return kMaxPtxVersion;

xla/service/gpu/llvm_gpu_backend/nvptx_backend_test.cc

Lines changed: 5 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -27,13 +27,11 @@ namespace {
2727
namespace se = ::stream_executor;
2828

2929
TEST(UtilsTest, TestGetSmName) {
30-
se::CudaComputeCapability cc_hopper(9, 0);
31-
ASSERT_EQ(nvptx::GetSmName(cc_hopper), "sm_90a");
32-
// Do not default to sm90_a after Hopper, because it is not forward
33-
// compatible.
34-
// https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#ptx-compatibility
35-
se::CudaComputeCapability cc_next(10, 0);
36-
ASSERT_EQ(nvptx::GetSmName(cc_next), "sm_90");
30+
ASSERT_EQ(nvptx::GetSmName(se::CudaComputeCapability{9, 0}), "sm_90a");
31+
ASSERT_EQ(nvptx::GetSmName(se::CudaComputeCapability{10, 0}), "sm_100a");
32+
// Do not use the extension for a yet-unknown compute capability.
33+
// https://docs.nvidia.com/cuda/parallel-thread-execution/#release-notes-ptx-release-history
34+
ASSERT_EQ(nvptx::GetSmName(se::CudaComputeCapability{10, 9}), "sm_100");
3735
}
3836

3937
using VersionPair = std::pair<se::SemanticVersion, se::SemanticVersion>;

xla/stream_executor/cuda/driver_compilation_provider.cc

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -165,7 +165,7 @@ absl::StatusOr<Assembly> DriverCompilationProvider::CompileAndLink(
165165
CHECK(info_log_buffer_size() <= kInfoLogBufferSize);
166166
info_log_buffer.resize(info_log_buffer_size());
167167

168-
absl::string_view extension = (cc.major == 9 && cc.minor == 0) ? "a" : "";
168+
absl::string_view extension = ShouldUsePtxExtension(cc) ? "a" : "";
169169
std::string architecture = absl::StrCat("sm_", cc.major, cc.minor, extension);
170170

171171
if (result != CUDA_SUCCESS) {

xla/stream_executor/cuda/nvjitlink_impl.cc

Lines changed: 1 addition & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -80,7 +80,6 @@ static absl::Status ToStatus(nvJitLinkResult status,
8080
} \
8181
} while (false)
8282

83-
8483
static absl::StatusOr<std::string> nvJitLinkGetErrorLog(
8584
nvJitLinkHandle link_handle) {
8685
size_t size{};
@@ -139,7 +138,7 @@ absl::StatusOr<std::vector<uint8_t>> CompileAndLinkUsingLibNvJitLink(
139138
// On Hopper, default to sm_90a so that all instructions can be used. But
140139
// only sm_90 is forward compatible, so don't use sm_90a with newer hardware:
141140
// https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#ptx-compatibility
142-
absl::string_view extension = (cc.major == 9 && cc.minor == 0) ? "a" : "";
141+
absl::string_view extension = ShouldUsePtxExtension(cc) ? "a" : "";
143142
std::string architecture = absl::StrCat("sm_", cc.major, cc.minor, extension);
144143
cli_args.emplace_back(absl::StrCat("-arch=", architecture));
145144

xla/stream_executor/cuda/ptx_compiler_helpers.cc

Lines changed: 6 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -101,4 +101,10 @@ void WarnIfBadPtxasVersion(absl::string_view method,
101101
});
102102
}
103103

104+
// The extension is used for compute capabilities 9.0 and 10.0.
105+
// https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#ptx-compatibility
106+
bool ShouldUsePtxExtension(const CudaComputeCapability& cc) {
107+
return (cc.major == 9 && cc.minor == 0) || (cc.major == 10 && cc.minor == 0);
108+
}
109+
104110
} // namespace stream_executor

xla/stream_executor/cuda/ptx_compiler_helpers.h

Lines changed: 9 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -43,6 +43,15 @@ absl::Status CreateErrorFromPTXASLog(absl::string_view log,
4343
void WarnIfBadPtxasVersion(absl::string_view method,
4444
const CudaComputeCapability& cc,
4545
SemanticVersion compiler_version);
46+
47+
// Determine whether the PTX extension for a compute capability should be used.
48+
//
49+
// Returns true if the argument compute capability has PTX extensions that are
50+
// only valid for that compute capability. For example, "sm_90" only includes
51+
// features that are forward compatible, whereas "sm_90a" (the extension) also
52+
// includes Hopper-specific features, such as WGMMA. We want to use the latter.
53+
bool ShouldUsePtxExtension(const CudaComputeCapability& cc);
54+
4655
} // namespace stream_executor
4756

4857
#endif // XLA_STREAM_EXECUTOR_CUDA_PTX_COMPILER_HELPERS_H_

xla/stream_executor/cuda/ptx_compiler_impl.cc

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -97,7 +97,7 @@ absl::StatusOr<std::vector<uint8_t>> CompileGpuAsmUsingLibNvPtxCompiler(
9797
// On Hopper, default to sm_90a so that all instructions can be used. But
9898
// only sm_90 is forward compatible, so don't use sm_90a with newer hardware:
9999
// https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#ptx-compatibility
100-
absl::string_view extension = (cc.major == 9 && cc.minor == 0) ? "a" : "";
100+
absl::string_view extension = ShouldUsePtxExtension(cc) ? "a" : "";
101101
std::string architecture = absl::StrCat("sm_", cc.major, cc.minor, extension);
102102

103103
options.extra_flags.emplace_back(absl::StrCat("-arch=", architecture));

xla/stream_executor/cuda/subprocess_compilation.cc

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -294,7 +294,7 @@ absl::StatusOr<std::vector<uint8_t>> CompileGpuAsmUsingPtxAs(
294294
// On Hopper, default to sm_90a so that all instructions can be used. But
295295
// only sm_90 is forward compatible, so don't use sm_90a with newer hardware:
296296
// https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#ptx-compatibility
297-
std::string extension = (cc.major == 9 && cc.minor == 0) ? "a" : "";
297+
std::string extension = ShouldUsePtxExtension(cc) ? "a" : "";
298298
std::vector<std::string> ptxas_args = {
299299
std::string{ptxas_path},
300300
ptx_path,
@@ -515,7 +515,7 @@ absl::StatusOr<std::vector<uint8_t>> LinkUsingNvlink(
515515
};
516516
std::vector<std::string> args;
517517
args.push_back(std::string{nvlink_path});
518-
absl::string_view extension = (cc.major == 9 && cc.minor == 0) ? "a" : "";
518+
absl::string_view extension = ShouldUsePtxExtension(cc) ? "a" : "";
519519
args.push_back(absl::StrCat("-arch=sm_", cc.major, cc.minor, extension));
520520
for (int i = 0; i < images.size(); i++) {
521521
args.push_back(temp_files[i]);

0 commit comments

Comments
 (0)