1212// ===----------------------------------------------------------------------===//
1313
1414#include " mlir/Dialect/GPU/Transforms/Passes.h"
15- #include " llvm/Support/Debug.h"
15+ #include " llvm/ADT/StringRef.h"
16+ #include " llvm/Support/FileSystem.h"
17+ #include " llvm/Support/FileUtilities.h"
18+ #include " llvm/Support/MemoryBuffer.h"
19+ #include " llvm/Support/Process.h"
20+ #include " llvm/Support/Program.h"
21+ #include " llvm/Support/WithColor.h"
22+ #include " llvm/Support/raw_ostream.h"
1623
1724#if MLIR_GPU_TO_CUBIN_PASS_ENABLE
1825#include " mlir/Pass/Pass.h"
@@ -36,6 +43,106 @@ static void emitCudaError(const llvm::Twine &expr, const char *buffer,
3643 .concat (" ]" ));
3744}
3845
46+ static constexpr char kPtxasCompilerName [] = " ptxas" ;
47+
48+ // / Compiles the given generated PTX code with the given ptxas compiler.
49+ static FailureOr<std::string>
50+ compileWithPtxas (StringRef smCapability, StringRef ptxasParams,
51+ StringRef ptxSource, bool dumpPtx, std::string *message) {
52+ // Step 0. Find ptxas compiler
53+ std::optional<std::string> ptxasCompiler =
54+ llvm::sys::Process::FindInEnvPath (" PATH" , kPtxasCompilerName );
55+ if (!ptxasCompiler.has_value ())
56+ return failure ();
57+
58+ // Step 1. Create temporary files: ptx source file, log file and cubin file
59+ llvm::SmallString<64 > ptxSourceFile, stdinFile, stdoutFile, stderrFile;
60+ llvm::sys::fs::createTemporaryFile (" mlir-ptx" , " " , ptxSourceFile);
61+ llvm::sys::fs::createTemporaryFile (" ptxas-stdin" , " " , stdinFile);
62+ llvm::sys::fs::createTemporaryFile (" ptxas-stdout" , " " , stdoutFile);
63+ llvm::sys::fs::createTemporaryFile (" ptxas-stderr" , " " , stderrFile);
64+ std::string cubinFile = std::string (ptxSourceFile) + " .cubin" ;
65+ llvm::FileRemover stdinRemover (stdinFile.c_str ());
66+ llvm::FileRemover stdoutRemover (stdoutFile.c_str ());
67+ llvm::FileRemover stderrRemover (stderrFile.c_str ());
68+ llvm::FileRemover binRemover (cubinFile.c_str ());
69+ llvm::FileRemover srcRemover (ptxSourceFile.c_str ());
70+
71+ // Step 2. Write the generated PTX into a file, so we can pass it to ptxas
72+ // compiler
73+ std::error_code ec;
74+ llvm::raw_fd_ostream fPtxSource (ptxSourceFile, ec);
75+ fPtxSource << ptxSource;
76+ fPtxSource .close ();
77+ if (fPtxSource .has_error ()) {
78+ *message = std::string (
79+ " Could not write the generated ptx into a temporary file\n " );
80+ return failure ();
81+ }
82+
83+ // Step 3. Build the ptxas command line
84+ std::vector<StringRef> argVector{StringRef (" ptxas" ), StringRef (" -arch" ),
85+ smCapability, StringRef (ptxSourceFile),
86+ StringRef (" -o" ), StringRef (cubinFile)};
87+ #ifdef _WIN32
88+ auto tokenize = llvm::cl::TokenizeWindowsCommandLine;
89+ #else
90+ auto tokenize = llvm::cl::TokenizeGNUCommandLine;
91+ #endif // _WIN32
92+ llvm::BumpPtrAllocator scratchAllocator;
93+ llvm::StringSaver stringSaver (scratchAllocator);
94+ SmallVector<const char *> rawArgs;
95+ tokenize (ptxasParams, stringSaver, rawArgs, /* MarkEOLs=*/ false );
96+ for (const auto *rawArg : rawArgs)
97+ argVector.emplace_back (rawArg);
98+
99+ std::optional<StringRef> redirects[] = {
100+ stdinFile.str (),
101+ stdoutFile.str (),
102+ stderrFile.str (),
103+ };
104+
105+ // Step 4. Invoke ptxas
106+ if (llvm::sys::ExecuteAndWait (ptxasCompiler.value (),
107+ llvm::ArrayRef<llvm::StringRef>(argVector),
108+ /* Env=*/ std::nullopt ,
109+ /* Redirects=*/ redirects,
110+ /* SecondsToWait=*/ 0 ,
111+ /* MemoryLimit=*/ 0 ,
112+ /* ErrMsg=*/ message)) {
113+ if (message->empty ()) {
114+ llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> maybeErrorlog =
115+ llvm::MemoryBuffer::getFile (stderrFile);
116+ *message = std::string (" Invoking ptxas is failed, see the file: " );
117+ if (maybeErrorlog)
118+ *message += maybeErrorlog->get ()->getBuffer ().str ();
119+ }
120+ stderrRemover.releaseFile ();
121+ return failure ();
122+ }
123+
124+ // Step 5. The output of ptxas if verbose flag is set. This is useful
125+ // because it shows local memory usage, register usage, and etc.
126+ if (dumpPtx) {
127+ llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> maybeFlog =
128+ llvm::MemoryBuffer::getFile (stderrFile);
129+ if (maybeFlog) {
130+ llvm::WithColor::note () << maybeFlog->get ()->getBuffer ().str ();
131+ }
132+ }
133+
134+ // Step 6. Read the cubin file, and return. It will eventually be written
135+ // into executable.
136+ llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> maybeFcubin =
137+ llvm::MemoryBuffer::getFile (cubinFile);
138+ if (!maybeFcubin) {
139+ *message = std::string (" Could not read cubin file \n " );
140+ return failure ();
141+ }
142+
143+ return std::string (maybeFcubin->get ()->getBuffer ());
144+ }
145+
39146#define RETURN_ON_CUDA_ERROR (expr ) \
40147 do { \
41148 if (auto status = (expr)) { \
@@ -54,11 +161,13 @@ class SerializeToCubinPass
54161
55162 SerializeToCubinPass (StringRef triple = " nvptx64-nvidia-cuda" ,
56163 StringRef chip = " sm_35" , StringRef features = " +ptx60" ,
57- int optLevel = 2 , bool dumpPtx = false );
164+ int optLevel = 2 , bool dumpPtx = false ,
165+ bool usePtxas = true , StringRef ptxasParams = {});
58166
59167 StringRef getArgument () const override { return " gpu-to-cubin" ; }
60168 StringRef getDescription () const override {
61- return " Lower GPU kernel function to CUBIN binary annotations" ;
169+ return " Lower GPU kernel function to CUBIN binary "
170+ " annotations" ;
62171 }
63172
64173private:
@@ -80,9 +189,10 @@ llvm::once_flag SerializeToCubinPass::initializeBackendOnce;
80189
81190SerializeToCubinPass::SerializeToCubinPass (StringRef triple, StringRef chip,
82191 StringRef features, int optLevel,
83- bool dumpPtx) {
84- // No matter how this pass is constructed, ensure that the NVPTX backend
85- // is initialized exactly once.
192+ bool dumpPtx, bool usePtxas,
193+ StringRef ptxasParams) {
194+ // No matter how this pass is constructed, ensure that
195+ // the NVPTX backend is initialized exactly once.
86196 llvm::call_once (initializeBackendOnce, []() {
87197 // Initialize LLVM NVPTX backend.
88198 LLVMInitializeNVPTXTarget ();
@@ -94,7 +204,9 @@ SerializeToCubinPass::SerializeToCubinPass(StringRef triple, StringRef chip,
94204 maybeSetOption (this ->triple , triple);
95205 maybeSetOption (this ->chip , chip);
96206 maybeSetOption (this ->features , features);
207+ maybeSetOption (this ->ptxasParams , ptxasParams);
97208 this ->dumpPtx = dumpPtx;
209+ this ->usePtxas = usePtxas;
98210 if (this ->optLevel .getNumOccurrences () == 0 )
99211 this ->optLevel .setValue (optLevel);
100212}
@@ -112,7 +224,8 @@ SerializeToCubinPass::serializeISA(const std::string &isa) {
112224
113225 RETURN_ON_CUDA_ERROR (cuInit (0 ));
114226
115- // Linking requires a device context.
227+ // Linking requires a device
228+ // context.
116229 CUdevice device;
117230 RETURN_ON_CUDA_ERROR (cuDeviceGet (&device, 0 ));
118231 CUcontext context;
@@ -131,9 +244,24 @@ SerializeToCubinPass::serializeISA(const std::string &isa) {
131244
132245 auto kernelName = getOperation ().getName ().str ();
133246 if (dumpPtx) {
134- llvm::dbgs () << " Kernel Name : [" << kernelName << " ]\n " ;
135- llvm::dbgs () << isa << " \n " ;
247+ llvm::errs () << " // Kernel Name : [" << kernelName << " ]\n " ;
248+ llvm::errs () << isa << " \n " ;
136249 }
250+
251+ if (usePtxas) {
252+ // Try to compile it with ptxas first.
253+ std::string message;
254+ FailureOr<std::string> maybeCubinImage =
255+ compileWithPtxas (this ->chip , ptxasParams, isa, dumpPtx, &message);
256+ if (succeeded (maybeCubinImage)) {
257+ return std::make_unique<std::vector<char >>(
258+ maybeCubinImage.value ().begin (), maybeCubinImage.value ().end ());
259+ }
260+ emitError (loc) << message;
261+ return {};
262+ }
263+
264+ // Fallback to JIT compilation if ptxas fails.
137265 RETURN_ON_CUDA_ERROR (cuLinkAddData (
138266 linkState, CUjitInputType::CU_JIT_INPUT_PTX,
139267 const_cast <void *>(static_cast <const void *>(isa.c_str ())), isa.length (),
@@ -150,7 +278,7 @@ SerializeToCubinPass::serializeISA(const std::string &isa) {
150278 auto result =
151279 std::make_unique<std::vector<char >>(cubinAsChar, cubinAsChar + cubinSize);
152280
153- // This will also destroy the cubin data.
281+ // This will also destroy the cubin data.
154282 RETURN_ON_CUDA_ERROR (cuLinkDestroy (linkState));
155283 RETURN_ON_CUDA_ERROR (cuCtxDestroy (context));
156284
@@ -159,17 +287,22 @@ SerializeToCubinPass::serializeISA(const std::string &isa) {
159287
160288// Register pass to serialize GPU kernel functions to a CUBIN binary annotation.
161289void mlir::registerGpuSerializeToCubinPass () {
162- PassRegistration<SerializeToCubinPass> registerSerializeToCubin (
163- [] { return std::make_unique<SerializeToCubinPass>(); });
290+ PassRegistration<SerializeToCubinPass> registerSerializeToCubin ([] {
291+ // Initialize LLVM NVPTX backend.
292+ LLVMInitializeNVPTXTarget ();
293+ LLVMInitializeNVPTXTargetInfo ();
294+ LLVMInitializeNVPTXTargetMC ();
295+ LLVMInitializeNVPTXAsmPrinter ();
296+
297+ return std::make_unique<SerializeToCubinPass>();
298+ });
164299}
165300
166- std::unique_ptr<Pass> mlir::createGpuSerializeToCubinPass (StringRef triple,
167- StringRef arch,
168- StringRef features,
169- int optLevel,
170- bool dumpPtx) {
171- return std::make_unique<SerializeToCubinPass>(triple, arch, features,
172- optLevel, dumpPtx);
301+ std::unique_ptr<Pass> mlir::createGpuSerializeToCubinPass (
302+ const gpu::SerializationToCubinOptions &options) {
303+ return std::make_unique<SerializeToCubinPass>(
304+ options.triple , options.chip , options.features , options.optLevel ,
305+ options.dumpPtx , options.usePtxas , options.ptxasParams );
173306}
174307
175308#else // MLIR_GPU_TO_CUBIN_PASS_ENABLE
0 commit comments