Skip to content

Commit

Permalink
Fix array size calculation (#5463)
Browse files Browse the repository at this point in the history
The function that get the number of elements in a composite variable
returns an incorrect values for the arrays. This is fixed, so that it
returns the correct number of elements for arrays where the number of
elements is represented as a 32-bit integer and is known at compile
time.

Fixes #4953
  • Loading branch information
s-perron authored Nov 2, 2023
1 parent eacc969 commit 9e7a1f2
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 6 deletions.
15 changes: 11 additions & 4 deletions source/opt/folding_rules.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2067,7 +2067,8 @@ FoldingRule FMixFeedingExtract() {
}

// Returns the number of elements in the composite type |type|. Returns 0 if
// |type| is a scalar value.
// |type| is a scalar value. Return UINT32_MAX when the size is unknown at
// compile time.
uint32_t GetNumberOfElements(const analysis::Type* type) {
if (auto* vector_type = type->AsVector()) {
return vector_type->element_count();
Expand All @@ -2079,21 +2080,27 @@ uint32_t GetNumberOfElements(const analysis::Type* type) {
return static_cast<uint32_t>(struct_type->element_types().size());
}
if (auto* array_type = type->AsArray()) {
return array_type->length_info().words[0];
if (array_type->length_info().words[0] ==
analysis::Array::LengthInfo::kConstant &&
array_type->length_info().words.size() == 2) {
return array_type->length_info().words[1];
}
return UINT32_MAX;
}
return 0;
}

// Returns a map with the set of values that were inserted into an object by
// the chain of OpCompositeInsertInstruction starting with |inst|.
// The map will map the index to the value inserted at that index.
// The map will map the index to the value inserted at that index. An empty map
// will be returned if the map could not be properly generated.
std::map<uint32_t, uint32_t> GetInsertedValues(Instruction* inst) {
analysis::DefUseManager* def_use_mgr = inst->context()->get_def_use_mgr();
std::map<uint32_t, uint32_t> values_inserted;
Instruction* current_inst = inst;
while (current_inst->opcode() == spv::Op::OpCompositeInsert) {
if (current_inst->NumInOperands() > inst->NumInOperands()) {
// This is the catch the case
// This is to catch the case
// %2 = OpCompositeInsert %m2x2int %v2int_1_0 %m2x2int_undef 0
// %3 = OpCompositeInsert %m2x2int %int_4 %2 0 0
// %4 = OpCompositeInsert %m2x2int %v2int_2_3 %3 1
Expand Down
23 changes: 21 additions & 2 deletions test/opt/fold_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,8 @@ OpName %main "main"
%_ptr_struct_v2int_int_int = OpTypePointer Function %struct_v2int_int_int
%_ptr_v2float = OpTypePointer Function %v2float
%_ptr_v2double = OpTypePointer Function %v2double
%int_2 = OpConstant %int 2
%int_arr_2 = OpTypeArray %int %int_2
%short_0 = OpConstant %short 0
%short_2 = OpConstant %short 2
%short_3 = OpConstant %short 3
Expand All @@ -185,7 +187,6 @@ OpName %main "main"
%103 = OpConstant %int 7 ; Need a def with an numerical id to define id maps.
%int_0 = OpConstant %int 0
%int_1 = OpConstant %int 1
%int_2 = OpConstant %int 2
%int_3 = OpConstant %int 3
%int_4 = OpConstant %int 4
%int_10 = OpConstant %int 10
Expand Down Expand Up @@ -323,6 +324,7 @@ OpName %main "main"
%short_0x4400 = OpConstant %short 0x4400
%ushort_0xBC00 = OpConstant %ushort 0xBC00
%short_0xBC00 = OpConstant %short 0xBC00
%int_arr_2_undef = OpUndef %int_arr_2
)";

return header;
Expand Down Expand Up @@ -7648,7 +7650,24 @@ ::testing::Values(
"%4 = OpCompositeExtract %int %struct_v2int_int_int 3\n" +
"OpReturn\n" +
"OpFunctionEnd",
4, false)
4, false),
// Test case 18: Fold when every element of an array is inserted.
InstructionFoldingCase<bool>(
Header() +
"; CHECK: [[int:%\\w+]] = OpTypeInt 32 1\n" +
"; CHECK: [[int2:%\\w+]] = OpConstant [[int]] 2\n" +
"; CHECK-DAG: [[arr_type:%\\w+]] = OpTypeArray [[int]] [[int2]]\n" +
"; CHECK-DAG: [[int10:%\\w+]] = OpConstant [[int]] 10\n" +
"; CHECK-DAG: [[int1:%\\w+]] = OpConstant [[int]] 1\n" +
"; CHECK: [[construct:%\\w+]] = OpCompositeConstruct [[arr_type]] [[int10]] [[int1]]\n" +
"; CHECK: %5 = OpCopyObject [[arr_type]] [[construct]]\n" +
"%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%4 = OpCompositeInsert %int_arr_2 %int_10 %int_arr_2_undef 0\n" +
"%5 = OpCompositeInsert %int_arr_2 %int_1 %4 1\n" +
"OpReturn\n" +
"OpFunctionEnd",
5, true)
));

INSTANTIATE_TEST_SUITE_P(DotProductMatchingTest, MatchingInstructionFoldingTest,
Expand Down

0 comments on commit 9e7a1f2

Please sign in to comment.