Skip to content

Commit

Permalink
SpirV: Add missing matrix decoration for array of matrices
Browse files Browse the repository at this point in the history
  • Loading branch information
SirLynix committed Jul 26, 2024
1 parent 95b85a7 commit e028a84
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 41 deletions.
19 changes: 8 additions & 11 deletions src/NZSL/SpirV/SpirvConstantCache.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1368,20 +1368,17 @@ namespace nzsl
if (debugLevel >= DebugLevel::Minimal)
debugInfos.Append(SpirvOp::OpMemberName, resultId, memberIndex, member.name);

std::uint32_t offset = member.offset.value();
AnyType* type = &member.type->type;
while (std::holds_alternative<Array>(*type))
type = &std::get<Array>(*type).elementType->type;

std::visit([&](auto&& arg)
if (std::holds_alternative<Matrix>(*type))
{
using T = std::decay_t<decltype(arg)>;

if constexpr (std::is_same_v<T, Matrix>)
{
annotations.Append(SpirvOp::OpMemberDecorate, resultId, memberIndex, SpirvDecoration::ColMajor);
annotations.Append(SpirvOp::OpMemberDecorate, resultId, memberIndex, SpirvDecoration::MatrixStride, 16);
}
}, member.type->type);
annotations.Append(SpirvOp::OpMemberDecorate, resultId, memberIndex, SpirvDecoration::ColMajor);
annotations.Append(SpirvOp::OpMemberDecorate, resultId, memberIndex, SpirvDecoration::MatrixStride, 16);
}

annotations.Append(SpirvOp::OpMemberDecorate, resultId, memberIndex, SpirvDecoration::Offset, offset);
annotations.Append(SpirvOp::OpMemberDecorate, resultId, memberIndex, SpirvDecoration::Offset, member.offset.value());
}
}
}
78 changes: 48 additions & 30 deletions tests/src/Tests/ExternalTests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,8 @@ module;
struct Data
{
[tag("Values")]
values: array[f32, 47]
values: array[f32, 47],
matrices: array[mat4[f32], 3]
}
external
Expand All @@ -193,7 +194,7 @@ external
[entry(frag)]
fn main()
{
let value = data.values[42];
let value = data.values[42] * data.matrices[1];
}
)";

Expand All @@ -206,19 +207,21 @@ uniform _nzslBindingdata
{
// member tag: Values
float values[47];
mat4 matrices[3];
} data;
void main()
{
float value = data.values[42];
mat4 value = data.values[42] * data.matrices[1];
}
)");

ExpectNZSL(*shaderModule, R"(
[tag("DataStruct")]
struct Data
{
[tag("Values")] values: array[f32, 47]
[tag("Values")] values: array[f32, 47],
matrices: array[mat4[f32], 3]
}
external
Expand All @@ -229,41 +232,56 @@ external
[entry(frag)]
fn main()
{
let value: f32 = data.values[42];
let value: mat4[f32] = data.values[42] * data.matrices[1];
})");

ExpectSPIRV(*shaderModule, R"(
OpSource SourceLanguage(Unknown) 100
OpName %5 "Data"
OpMemberName %5 0 "values"
OpName %7 "data"
OpName %15 "main"
OpDecorate %7 Decoration(Binding) 0
OpDecorate %7 Decoration(DescriptorSet) 0
OpName %9 "Data"
OpMemberName %9 0 "values"
OpMemberName %9 1 "matrices"
OpName %11 "data"
OpName %21 "main"
OpDecorate %11 Decoration(Binding) 0
OpDecorate %11 Decoration(DescriptorSet) 0
OpDecorate %4 Decoration(ArrayStride) 16
OpDecorate %5 Decoration(Block)
OpMemberDecorate %5 0 Decoration(Offset) 0
OpDecorate %8 Decoration(ArrayStride) 64
OpDecorate %9 Decoration(Block)
OpMemberDecorate %9 0 Decoration(Offset) 0
OpMemberDecorate %9 1 Decoration(ColMajor)
OpMemberDecorate %9 1 Decoration(MatrixStride) 16
OpMemberDecorate %9 1 Decoration(Offset) 752
%1 = OpTypeFloat 32
%2 = OpTypeInt 32 0
%3 = OpConstant %2 u32(47)
%4 = OpTypeArray %1 %3
%5 = OpTypeStruct %4
%6 = OpTypePointer StorageClass(Uniform) %5
%8 = OpTypeVoid
%9 = OpTypeFunction %8
%10 = OpTypeInt 32 1
%11 = OpConstant %10 i32(0)
%12 = OpTypeArray %1 %3
%13 = OpConstant %10 i32(42)
%14 = OpTypePointer StorageClass(Function) %1
%18 = OpTypePointer StorageClass(Uniform) %1
%7 = OpVariable %6 StorageClass(Uniform)
%15 = OpFunction %8 FunctionControl(0) %9
%16 = OpLabel
%17 = OpVariable %14 StorageClass(Function)
%19 = OpAccessChain %18 %7 %11 %13
%20 = OpLoad %1 %19
OpStore %17 %20
%5 = OpTypeVector %1 4
%6 = OpTypeMatrix %5 4
%7 = OpConstant %2 u32(3)
%8 = OpTypeArray %6 %7
%9 = OpTypeStruct %4 %8
%10 = OpTypePointer StorageClass(Uniform) %9
%12 = OpTypeVoid
%13 = OpTypeFunction %12
%14 = OpTypeInt 32 1
%15 = OpConstant %14 i32(0)
%16 = OpTypeArray %1 %3
%17 = OpConstant %14 i32(42)
%18 = OpConstant %14 i32(1)
%19 = OpTypeArray %6 %7
%20 = OpTypePointer StorageClass(Function) %6
%24 = OpTypePointer StorageClass(Uniform) %1
%27 = OpTypePointer StorageClass(Uniform) %6
%11 = OpVariable %10 StorageClass(Uniform)
%21 = OpFunction %12 FunctionControl(0) %13
%22 = OpLabel
%23 = OpVariable %20 StorageClass(Function)
%25 = OpAccessChain %24 %11 %15 %17
%26 = OpLoad %1 %25
%28 = OpAccessChain %27 %11 %18 %18
%29 = OpLoad %6 %28
%30 = OpMatrixTimesScalar %6 %29 %26
OpStore %23 %30
OpReturn
OpFunctionEnd)", {}, {}, true);
}
Expand Down

0 comments on commit e028a84

Please sign in to comment.