diff --git a/include/NZSL/Ast/ExpressionType.hpp b/include/NZSL/Ast/ExpressionType.hpp index 3812e33..de9a968 100644 --- a/include/NZSL/Ast/ExpressionType.hpp +++ b/include/NZSL/Ast/ExpressionType.hpp @@ -241,7 +241,7 @@ namespace nzsl::Ast std::string name; std::string tag; std::vector members; - bool isConditional = false; + unsigned int conditionIndex = 0; }; inline bool IsAliasType(const ExpressionType& type); diff --git a/include/NZSL/Ast/SanitizeVisitor.hpp b/include/NZSL/Ast/SanitizeVisitor.hpp index 2f84381..1205c9e 100644 --- a/include/NZSL/Ast/SanitizeVisitor.hpp +++ b/include/NZSL/Ast/SanitizeVisitor.hpp @@ -227,7 +227,7 @@ namespace nzsl::Ast { std::size_t index; IdentifierCategory category; - bool isConditional = false; + unsigned int conditionalIndex = 0; }; struct Identifier diff --git a/src/NZSL/Ast/SanitizeVisitor.cpp b/src/NZSL/Ast/SanitizeVisitor.cpp index 7de3f64..d389c38 100644 --- a/src/NZSL/Ast/SanitizeVisitor.cpp +++ b/src/NZSL/Ast/SanitizeVisitor.cpp @@ -185,7 +185,7 @@ namespace nzsl::Ast struct UsedExternalData { - bool isConditional; + unsigned int conditionalStatementIndex; }; static constexpr std::size_t ModuleIdSentinel = std::numeric_limits::max(); @@ -212,8 +212,9 @@ namespace nzsl::Ast Options options; FunctionData* currentFunction = nullptr; bool allowUnknownIdentifiers = false; - bool inConditionalStatement = false; bool inLoop = false; + unsigned int currentConditionalIndex = 0; + unsigned int nextConditionalIndex = 1; }; ModulePtr SanitizeVisitor::Sanitize(const Module& module, const Options& options, std::string* error) @@ -323,7 +324,12 @@ namespace nzsl::Ast const auto& env = *m_context->modules[moduleIndex].environment; identifierData = FindIdentifier(env, node.identifiers.front().identifier); if (identifierData) + { + if (m_context->options.partialSanitization && identifierData->conditionalIndex != m_context->currentConditionalIndex) + return Cloner::Clone(node); + return HandleIdentifier(identifierData, node.identifiers.front().sourceLocation); + } } } @@ -447,7 +453,7 @@ namespace nzsl::Ast if (!fieldPtr) { - if (s->isConditional) + if (s->conditionIndex != m_context->currentConditionalIndex) return Cloner::Clone(node); //< unresolved throw CompilerUnknownFieldError{ indexedExpr->sourceLocation, identifierEntry.identifier }; @@ -1150,6 +1156,9 @@ namespace nzsl::Ast if (identifierData->category == IdentifierCategory::Unresolved) return Cloner::Clone(node); + if (m_context->options.partialSanitization && identifierData->conditionalIndex != m_context->currentConditionalIndex) + return Cloner::Clone(node); + return HandleIdentifier(identifierData, node.sourceLocation); } @@ -1354,9 +1363,9 @@ namespace nzsl::Ast if (!conditionValue.has_value()) { - bool wasInConditionalStatement = m_context->inConditionalStatement; - m_context->inConditionalStatement = true; - Nz::CallOnExit restoreCond([=] { m_context->inConditionalStatement = wasInConditionalStatement; }); + unsigned int prevCondStatementIndex = m_context->currentConditionalIndex; + m_context->currentConditionalIndex = m_context->nextConditionalIndex++; + Nz::CallOnExit restoreCond([=] { m_context->currentConditionalIndex = prevCondStatementIndex; }); // Unresolvable condition auto condStatement = ShaderBuilder::ConditionalStatement(std::move(cloneCondition), Cloner::Clone(*node.statement)); @@ -1447,7 +1456,7 @@ namespace nzsl::Ast std::uint64_t bindingKey = BuildBindingKey(bindingSet, bindingIndex + i); if (auto it = m_context->usedBindingIndexes.find(bindingKey); it != m_context->usedBindingIndexes.end()) { - if (!it->second.isConditional || !usedBindingData.isConditional) + if (it->second.conditionalStatementIndex == m_context->currentConditionalIndex || usedBindingData.conditionalStatementIndex == m_context->currentConditionalIndex) throw CompilerExtBindingAlreadyUsedError{ sourceLoc, bindingSet, bindingIndex }; } @@ -1462,11 +1471,11 @@ namespace nzsl::Ast auto& extVar = clone->externalVars[i]; Context::UsedExternalData usedBindingData; - usedBindingData.isConditional = m_context->inConditionalStatement; + usedBindingData.conditionalStatementIndex = m_context->currentConditionalIndex; if (auto it = m_context->declaredExternalVar.find(extVar.name); it != m_context->declaredExternalVar.end()) { - if (!it->second.isConditional || !usedBindingData.isConditional) + if (it->second.conditionalStatementIndex == m_context->currentConditionalIndex || usedBindingData.conditionalStatementIndex == m_context->currentConditionalIndex) throw CompilerExtAlreadyDeclaredError{ extVar.sourceLocation, extVar.name }; } @@ -1586,7 +1595,7 @@ namespace nzsl::Ast bindingIndex++; Context::UsedExternalData usedBindingData; - usedBindingData.isConditional = m_context->inConditionalStatement; + usedBindingData.conditionalStatementIndex = m_context->currentConditionalIndex; extVar.bindingIndex = bindingIndex; RegisterBinding(arraySize, bindingSet, bindingIndex, usedBindingData, extVar.sourceLocation); @@ -1912,7 +1921,7 @@ namespace nzsl::Ast } } - clone->description.isConditional = m_context->inConditionalStatement; + clone->description.conditionIndex = m_context->currentConditionalIndex; clone->structIndex = RegisterStruct(clone->description.name, &clone->description, clone->structIndex, clone->sourceLocation); SanitizeIdentifier(clone->description.name, IdentifierScope::Struct); @@ -3547,7 +3556,7 @@ namespace nzsl::Ast bool unresolved = false; if (const IdentifierData* identifierData = FindIdentifier(name)) { - if (!m_context->inConditionalStatement || !identifierData->isConditional) + if (identifierData->conditionalIndex == m_context->currentConditionalIndex) throw CompilerIdentifierAlreadyUsedError{ sourceLocation, name }; else unresolved = true; @@ -3571,7 +3580,7 @@ namespace nzsl::Ast { aliasIndex, IdentifierCategory::Alias, - m_context->inConditionalStatement + m_context->currentConditionalIndex } }); } @@ -3602,7 +3611,7 @@ namespace nzsl::Ast { constantIndex, IdentifierCategory::Constant, - m_context->inConditionalStatement + m_context->currentConditionalIndex } }); @@ -3654,7 +3663,7 @@ namespace nzsl::Ast { functionIndex, IdentifierCategory::Function, - m_context->inConditionalStatement + m_context->currentConditionalIndex } }); @@ -3673,7 +3682,7 @@ namespace nzsl::Ast { intrinsicIndex, IdentifierCategory::Intrinsic, - m_context->inConditionalStatement + m_context->currentConditionalIndex } }); @@ -3692,7 +3701,7 @@ namespace nzsl::Ast { moduleIndex, IdentifierCategory::Module, - m_context->inConditionalStatement + m_context->currentConditionalIndex } }); @@ -3706,7 +3715,7 @@ namespace nzsl::Ast { std::numeric_limits::max(), IdentifierCategory::ReservedName, - m_context->inConditionalStatement + m_context->currentConditionalIndex } }); } @@ -3716,7 +3725,7 @@ namespace nzsl::Ast bool unresolved = false; if (const IdentifierData* identifierData = FindIdentifier(name)) { - if (!m_context->inConditionalStatement || !identifierData->isConditional) + if (identifierData->conditionalIndex == m_context->currentConditionalIndex) throw CompilerIdentifierAlreadyUsedError{ sourceLocation, name }; else unresolved = true; @@ -3740,7 +3749,7 @@ namespace nzsl::Ast { structIndex, IdentifierCategory::Struct, - m_context->inConditionalStatement + m_context->currentConditionalIndex } }); } @@ -3771,7 +3780,7 @@ namespace nzsl::Ast { typeIndex, IdentifierCategory::Type, - m_context->inConditionalStatement + m_context->currentConditionalIndex } }); @@ -3805,7 +3814,7 @@ namespace nzsl::Ast { typeIndex, IdentifierCategory::Type, - m_context->inConditionalStatement + m_context->currentConditionalIndex } }); @@ -3819,7 +3828,7 @@ namespace nzsl::Ast { std::numeric_limits::max(), IdentifierCategory::Unresolved, - m_context->inConditionalStatement + m_context->currentConditionalIndex } }); } @@ -3832,7 +3841,8 @@ namespace nzsl::Ast // Allow variable shadowing if (identifier->category != IdentifierCategory::Variable) throw CompilerIdentifierAlreadyUsedError{ sourceLocation, name }; - else if (identifier->isConditional && m_context->inConditionalStatement) + + if (identifier->conditionalIndex != m_context->currentConditionalIndex) unresolved = true; //< right variable isn't know from this point } @@ -3854,7 +3864,7 @@ namespace nzsl::Ast { varIndex, IdentifierCategory::Variable, - m_context->inConditionalStatement + m_context->currentConditionalIndex } }); } @@ -4102,7 +4112,10 @@ namespace nzsl::Ast const ExpressionType* exprType = GetExpressionType(*node.expression); if (!exprType) + { + RegisterUnresolved(node.name); return ValidationResult::Unresolved; + } const ExpressionType& resolvedType = ResolveAlias(*exprType); diff --git a/tests/src/Tests/ModuleTests.cpp b/tests/src/Tests/ModuleTests.cpp index 8818ac9..8c10a37 100644 --- a/tests/src/Tests/ModuleTests.cpp +++ b/tests/src/Tests/ModuleTests.cpp @@ -825,4 +825,141 @@ OpStore OpReturn OpFunctionEnd)"); } + + WHEN("Testing forward vs deferred based on option") + { + // Test a bugfix where an unresolved identifier (identifier imported from an unknown module when precompiling) was being resolved in + + std::string_view gbufferOutput = R"( +[nzsl_version("1.0")] +module DeferredShading.GBuffer; + +[export] +struct GBufferOutput +{ + [location(0)] albedo: vec4[f32], + [location(1)] normal: vec4[f32], +} +)"; + + std::string_view nzslSource = R"( +[nzsl_version("1.0")] +module; + +import GBufferOutput from DeferredShading.GBuffer; + +option ForwardPass: bool = true; + +[cond(ForwardPass)] +struct FragOut +{ + [location(0)] color: vec4[f32] +} + +[cond(!ForwardPass)] +alias FragOut = GBufferOutput; + +[entry(frag)] +fn FragMain() -> FragOut +{ + let color = vec4[f32](1.0, 0.0, 0.0, 1.0); + + const if (ForwardPass) + { + let output: FragOut; + output.color = color; + + return output; + } + else + { + let normal = vec3[f32](0.0, 1.0, 0.0); + + let output: FragOut; + output.albedo = color; + output.normal = vec4[f32](normal, 1.0); + + return output; + } +} +)"; + + nzsl::Ast::ModulePtr shaderModule = nzsl::Parse(nzslSource); + + auto directoryModuleResolver = std::make_shared(); + RegisterModule(directoryModuleResolver, gbufferOutput); + + nzsl::Ast::SanitizeVisitor::Options options; + options.partialSanitization = true; + + REQUIRE_NOTHROW(shaderModule = nzsl::Ast::Sanitize(*shaderModule, options)); + + options.moduleResolver = directoryModuleResolver; + options.partialSanitization = false; + options.removeOptionDeclaration = true; + + WHEN("Trying ForwardPass=true") + { + options.optionValues[nzsl::Ast::HashOption("ForwardPass")] = true; + + REQUIRE_NOTHROW(shaderModule = nzsl::Ast::Sanitize(*shaderModule, options)); + + ExpectNZSL(*shaderModule, R"( +struct FragOut +{ + [location(0)] color: vec4[f32] +} + +[entry(frag)] +fn FragMain() -> FragOut +{ + let color: vec4[f32] = vec4[f32](1.0, 0.0, 0.0, 1.0); + { + let output: FragOut; + output.color = color; + return output; + } + +} +)"); + } + + + WHEN("Trying ForwardPass=false") + { + options.optionValues[nzsl::Ast::HashOption("ForwardPass")] = false; + + REQUIRE_NOTHROW(shaderModule = nzsl::Ast::Sanitize(*shaderModule, options)); + + ExpectNZSL(*shaderModule, R"( +[nzsl_version("1.0")] +module _DeferredShading_GBuffer +{ + struct GBufferOutput + { + [location(0)] albedo: vec4[f32], + [location(1)] normal: vec4[f32] + } + +} +alias GBufferOutput = _DeferredShading_GBuffer.GBufferOutput; + +alias FragOut = GBufferOutput; + +[entry(frag)] +fn FragMain() -> FragOut +{ + let color: vec4[f32] = vec4[f32](1.0, 0.0, 0.0, 1.0); + { + let normal: vec3[f32] = vec3[f32](0.0, 1.0, 0.0); + let output: FragOut; + output.albedo = color; + output.normal = vec4[f32](normal, 1.0); + return output; + } + +} +)"); + } + } } diff --git a/tests/src/Tests/ShaderUtils.cpp b/tests/src/Tests/ShaderUtils.cpp index 207b6a7..42d1964 100644 --- a/tests/src/Tests/ShaderUtils.cpp +++ b/tests/src/Tests/ShaderUtils.cpp @@ -294,7 +294,6 @@ void ExpectGLSL(nzsl::ShaderStageType stageType, const nzsl::Ast::Module& shader void ExpectGLSL(const nzsl::Ast::Module& shaderModule, std::string_view expectedOutput, const nzsl::ShaderWriter::States& options, const nzsl::GlslWriter::Environment& env, bool testShaderCompilation) { - // Retrieve entry-point to get shader type std::optional entryShaderStage;