Commit b483dda

Fix Metal accuracy problem caused by <dtype>3 vectors usage
Taking float3 as an example: using the float3 data type to load data concurrently into a dense array shared by all threads in a Metal threadgroup can lead to a data race between threads. float3 has a size and alignment of 16 bytes, while the kernel assumes it copies 12 bytes to arbitrary, unaligned locations. Using the packed_float3 data type solves the issue.
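
To illustrate the size mismatch described in the commit message, here is a minimal host-side C++ sketch, not Metal code. `Float3Like`, `PackedFloat3Like`, and `FillDense` are hypothetical names introduced for this illustration only. The sketch assumes that a store of the padded type touches all 16 bytes and that the padding lane holds an arbitrary value (`-1` here); the actual behaviour on a GPU may differ, so treat this as one possible way the race can show up.

```cpp
#include <cassert>
#include <cstddef>
#include <cstring>
#include <vector>

// Hypothetical stand-in for Metal float3: three floats padded to 16 bytes.
struct alignas(16) Float3Like {
  float x, y, z;
};
// Hypothetical stand-in for Metal packed_float3: exactly 12 bytes.
struct PackedFloat3Like {
  float x, y, z;
};

static_assert(sizeof(Float3Like) == 16, "padded to four floats");
static_assert(sizeof(PackedFloat3Like) == 12, "three floats, no padding");
static_assert(alignof(PackedFloat3Like) == alignof(float), "per-element alignment");

// Simulates `num_threads` threads, each storing one 3-element vector into a
// dense float buffer at float offset 3 * tid. `store_bytes` is the width of a
// single store: 16 for the padded type, 12 for the packed one. Stores are
// replayed in descending thread order, which is one possible interleaving of
// the concurrent writes.
std::vector<float> FillDense(std::size_t num_threads, std::size_t store_bytes) {
  assert(store_bytes == 12 || store_bytes == 16);
  // One spare float so a 16-byte store by the last thread stays in bounds.
  std::vector<float> dense(3 * num_threads + 1, 0.0f);
  for (std::size_t i = num_threads; i-- > 0;) {
    const float v = static_cast<float>(i + 1);
    // The fourth lane models the padding lane (assumed arbitrary, -1 here).
    const float lanes[4] = {v, v, v, -1.0f};
    std::memcpy(reinterpret_cast<char*>(dense.data()) + 12 * i, lanes, store_bytes);
  }
  return dense;
}
```

With two simulated threads, `FillDense(2, 16)` leaves the padding value from thread 0 in slot 3, where thread 1 had written its first element, while `FillDense(2, 12)` keeps every element intact.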
1 parent f38ae65 commit b483dda

File tree

1 file changed

+11
-0
lines changed

src/target/source/codegen_metal.cc

Lines changed: 11 additions & 0 deletions
@@ -178,6 +178,17 @@ void CodeGenMetal::PrintType(DataType t, std::ostream& os) {  // NOLINT(*)
   }
   bool fail = false;
   if (t.is_float()) {
+    // Need to care about sizes and alignment of half3/float3 because tir representation might not
+    // be aware of Metal half3/float3 details and can treat them as just three elements,
+    // while sizes and alignments of half3/float3 are one element more (half3: 8 bytes,
+    // float3: 16 bytes).
+    // Example of problematic pattern: filling of threadgroup packed array using float3 elements
+    // by threads concurrently can lead to data race and wrong data in threadgroup shared array.
+    // packed_(half3/float3) are exactly datatypes dealing with 3 elements and per-element
+    // alignment
+    if (lanes == 3) {
+      os << "packed_";
+    }
     switch (t.bits()) {
       case 16:
         os << "half";
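
The effect of the added branch on the emitted type name can be sketched as a standalone function. `MetalFloatTypeName` is a hypothetical helper, not part of `CodeGenMetal`. The 32-bit case and the lane suffix handling are assumptions about code outside the hunk shown above, and the real `PrintType` writes to the stream and tracks a `fail` flag rather than returning a string.

```cpp
#include <sstream>
#include <string>

// Hypothetical, simplified stand-in for the float branch of PrintType.
// Returns the Metal type name for a float vector with the given bit width and
// lane count, or an empty string for combinations this sketch does not handle.
std::string MetalFloatTypeName(int bits, int lanes) {
  std::ostringstream os;
  // Mirrors the commit: 3-lane vectors use the packed_ variants, which have
  // per-element alignment and no padding lane.
  if (lanes == 3) {
    os << "packed_";
  }
  switch (bits) {
    case 16:
      os << "half";
      break;
    case 32:  // Assumed: not visible in the hunk above.
      os << "float";
      break;
    default:
      return "";
  }
  if (lanes == 1) {
    return os.str();
  }
  if (lanes >= 2 && lanes <= 4) {  // Assumed lane suffix handling.
    os << lanes;
    return os.str();
  }
  return "";
}
```

Under these assumptions, `MetalFloatTypeName(32, 3)` yields `packed_float3` and `MetalFloatTypeName(16, 3)` yields `packed_half3`, while 2-lane and 4-lane vectors keep their unprefixed names.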
