Skip to content

Commit 1b18e96

Browse files
committed
Fix cuda flag with clang-repl
1 parent ecbd2d5 commit 1b18e96

File tree

3 files changed

+37
-17
lines changed

3 files changed

+37
-17
lines changed

clang/include/clang/Interpreter/Interpreter.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,8 @@ class Interpreter {
129129
public:
130130
virtual ~Interpreter();
131131
static llvm::Expected<std::unique_ptr<Interpreter>>
132-
create(std::unique_ptr<CompilerInstance> CI);
132+
create(std::unique_ptr<CompilerInstance> CI,
133+
std::unique_ptr<CompilerInstance> DeviceCI = nullptr);
133134
static llvm::Expected<std::unique_ptr<Interpreter>>
134135
createWithCUDA(std::unique_ptr<CompilerInstance> CI,
135136
std::unique_ptr<CompilerInstance> DCI);

clang/lib/Interpreter/DeviceOffload.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,14 +34,15 @@ IncrementalCUDADeviceParser::IncrementalCUDADeviceParser(
3434
TargetOpts(HostInstance.getTargetOpts()) {
3535
if (Err)
3636
return;
37-
DeviceCI = std::move(DeviceInstance);
3837
StringRef Arch = TargetOpts.CPU;
3938
if (!Arch.starts_with("sm_") || Arch.substr(3).getAsInteger(10, SMVersion)) {
39+
DeviceInstance.release();
4040
Err = llvm::joinErrors(std::move(Err), llvm::make_error<llvm::StringError>(
4141
"Invalid CUDA architecture",
4242
llvm::inconvertibleErrorCode()));
4343
return;
4444
}
45+
DeviceCI = std::move(DeviceInstance);
4546
}
4647

4748
llvm::Expected<TranslationUnitDecl *>

clang/lib/Interpreter/Interpreter.cpp

Lines changed: 33 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -451,13 +451,44 @@ const char *const Runtimes = R"(
451451
)";
452452

453453
llvm::Expected<std::unique_ptr<Interpreter>>
454-
Interpreter::create(std::unique_ptr<CompilerInstance> CI) {
454+
Interpreter::create(std::unique_ptr<CompilerInstance> CI,
455+
std::unique_ptr<CompilerInstance> DeviceCI) {
455456
llvm::Error Err = llvm::Error::success();
456457
auto Interp =
457458
std::unique_ptr<Interpreter>(new Interpreter(std::move(CI), Err));
458459
if (Err)
459460
return std::move(Err);
460461

462+
if (DeviceCI) {
463+
// auto DeviceLLVMCtx = std::make_unique<llvm::LLVMContext>();
464+
// auto DeviceTSCtx =
465+
// std::make_unique<llvm::orc::ThreadSafeContext>(std::move(DeviceLLVMCtx));
466+
467+
// llvm::Error DeviceErr = llvm::Error::success();
468+
// llvm::ErrorAsOutParameter EAO(&DeviceErr);
469+
470+
// auto DeviceAct = std::make_unique<IncrementalAction>(
471+
// *DeviceCI, *DeviceTSCtx->getContext(), DeviceErr, *Interp);
472+
473+
// if (DeviceErr)
474+
// return std::move(DeviceErr);
475+
476+
// DeviceCI->ExecuteAction(*DeviceAct);
477+
DeviceCI->ExecuteAction(*Interp->Act);
478+
479+
llvm::IntrusiveRefCntPtr<llvm::vfs::InMemoryFileSystem> IMVFS =
480+
std::make_unique<llvm::vfs::InMemoryFileSystem>();
481+
482+
auto DeviceParser = std::make_unique<IncrementalCUDADeviceParser>(
483+
std::move(DeviceCI), *Interp->getCompilerInstance(), IMVFS, Err,
484+
Interp->PTUs);
485+
486+
if (Err)
487+
return std::move(Err);
488+
489+
Interp->DeviceParser = std::move(DeviceParser);
490+
}
491+
461492
// Add runtime code and set a marker to hide it from user code. Undo will not
462493
// go through that.
463494
auto PTU = Interp->Parse(Runtimes);
@@ -481,20 +512,7 @@ Interpreter::createWithCUDA(std::unique_ptr<CompilerInstance> CI,
481512
OverlayVFS->pushOverlay(IMVFS);
482513
CI->createFileManager(OverlayVFS);
483514

484-
auto Interp = Interpreter::create(std::move(CI));
485-
if (auto E = Interp.takeError())
486-
return std::move(E);
487-
488-
llvm::Error Err = llvm::Error::success();
489-
auto DeviceParser = std::make_unique<IncrementalCUDADeviceParser>(
490-
std::move(DCI), *(*Interp)->getCompilerInstance(), IMVFS, Err,
491-
(*Interp)->PTUs);
492-
if (Err)
493-
return std::move(Err);
494-
495-
(*Interp)->DeviceParser = std::move(DeviceParser);
496-
497-
return Interp;
515+
return Interpreter::create(std::move(CI), std::move(DCI));
498516
}
499517

500518
const CompilerInstance *Interpreter::getCompilerInstance() const {

0 commit comments

Comments
 (0)