|
13 | 13 | #include "clang/Driver/Job.h"
|
14 | 14 | #include "llvm/ADT/StringSwitch.h"
|
15 | 15 | #include "llvm/TargetParser/Triple.h"
|
| 16 | +#include <regex> |
16 | 17 |
|
17 | 18 | using namespace clang::driver;
|
18 | 19 | using namespace clang::driver::tools;
|
@@ -173,6 +174,39 @@ bool isLegalValidatorVersion(StringRef ValVersionStr, const Driver &D) {
|
173 | 174 | return true;
|
174 | 175 | }
|
175 | 176 |
|
| 177 | +std::string getSpirvExtArg(ArrayRef<std::string> SpvExtensionArgs) { |
| 178 | + if (SpvExtensionArgs.empty()) { |
| 179 | + return "-spirv-ext=all"; |
| 180 | + } |
| 181 | + |
| 182 | + std::string LlvmOption = |
| 183 | + (Twine("-spirv-ext=+") + SpvExtensionArgs.front()).str(); |
| 184 | + SpvExtensionArgs = SpvExtensionArgs.slice(1); |
| 185 | + for (auto Extension : SpvExtensionArgs) { |
| 186 | + LlvmOption = (Twine(LlvmOption) + ",+" + Extension).str(); |
| 187 | + } |
| 188 | + return LlvmOption; |
| 189 | +} |
| 190 | + |
| 191 | +bool isValidSPIRVExtensionName(const std::string &str) { |
| 192 | + std::regex pattern("SPV_[a-zA-Z0-9_]+"); |
| 193 | + return std::regex_match(str, pattern); |
| 194 | +} |
| 195 | + |
| 196 | +// SPIRV extension names are of the form `SPV_[a-zA-Z0-9_]+`. We want to |
| 197 | +// disallow obviously invalid names to avoid issues when parsing `spirv-ext`. |
| 198 | +bool checkExtensionArgsAreValid(ArrayRef<std::string> SpvExtensionArgs, |
| 199 | + const Driver &Driver) { |
| 200 | + bool AllValid = true; |
| 201 | + for (auto Extension : SpvExtensionArgs) { |
| 202 | + if (!isValidSPIRVExtensionName(Extension)) { |
| 203 | + Driver.Diag(diag::err_drv_invalid_value) |
| 204 | + << "-fspv_extension" << Extension; |
| 205 | + AllValid = false; |
| 206 | + } |
| 207 | + } |
| 208 | + return AllValid; |
| 209 | +} |
176 | 210 | } // namespace
|
177 | 211 |
|
178 | 212 | void tools::hlsl::Validator::ConstructJob(Compilation &C, const JobAction &JA,
|
@@ -301,6 +335,17 @@ HLSLToolChain::TranslateArgs(const DerivedArgList &Args, StringRef BoundArch,
|
301 | 335 | DAL->append(A);
|
302 | 336 | }
|
303 | 337 |
|
| 338 | + if (getArch() == llvm::Triple::spirv) { |
| 339 | + std::vector<std::string> SpvExtensionArgs = |
| 340 | + Args.getAllArgValues(options::OPT_fspv_extension_EQ); |
| 341 | + if (checkExtensionArgsAreValid(SpvExtensionArgs, getDriver())) { |
| 342 | + std::string LlvmOption = getSpirvExtArg(SpvExtensionArgs); |
| 343 | + DAL->AddSeparateArg(nullptr, Opts.getOption(options::OPT_mllvm), |
| 344 | + LlvmOption); |
| 345 | + } |
| 346 | + Args.claimAllArgs(options::OPT_fspv_extension_EQ); |
| 347 | + } |
| 348 | + |
304 | 349 | if (!DAL->hasArg(options::OPT_O_Group)) {
|
305 | 350 | DAL->AddJoinedArg(nullptr, Opts.getOption(options::OPT_O), "3");
|
306 | 351 | }
|
|
0 commit comments