Skip to content

Commit d89c352

Browse files
Quetzonarchdbudanov-cmplr
authored andcommitted
Dot product bugfix to include more floating point types (#1578)
Switched the visitCallDot check to use isFloatingPointTy for scalar floating point operands. Bugfix for previous change regarding integer dot product. Original commit: KhronosGroup/SPIRV-LLVM-Translator@71e01b5
1 parent 35ff0e4 commit d89c352

File tree

2 files changed

+42
-2
lines changed

2 files changed

+42
-2
lines changed

llvm-spirv/lib/SPIRV/OCLToSPIRV.cpp

+1-2
Original file line numberDiff line numberDiff line change
@@ -323,8 +323,7 @@ void OCLToSPIRVBase::visitCallInst(CallInst &CI) {
323323
return;
324324
}
325325
if (DemangledName == kOCLBuiltinName::Dot &&
326-
(CI.getOperand(0)->getType()->isFloatTy() ||
327-
CI.getOperand(1)->getType()->isDoubleTy())) {
326+
CI.getOperand(0)->getType()->isFloatingPointTy()) {
328327
visitCallDot(&CI);
329328
return;
330329
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
; RUN: llvm-as < %s -o %t.bc
2+
; RUN: llvm-spirv -s %t.bc -o %t.regularized.bc
3+
; RUN: llvm-spirv %t.bc --spirv-ext=+SPV_KHR_integer_dot_product -o %t-spirv.spv
4+
; RUN: spirv-val %t-spirv.spv
5+
; RUN: llvm-dis %t.regularized.bc -o - | FileCheck %s --check-prefix=CHECK-LLVM
6+
; RUN: llvm-spirv %t.bc -spirv-text --spirv-ext=+SPV_KHR_integer_dot_product -o - | FileCheck %s --check-prefix=CHECK-SPIRV
7+
8+
;CHECK-LLVM: fmul
9+
10+
;CHECK-SPIRV: FMul
11+
12+
target datalayout = "e-p:32:32-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-n8:16:32:64"
13+
target triple = "spir"
14+
15+
; Function Attrs: convergent norecurse nounwind
16+
define spir_kernel void @test1(half %ha, half %hb) local_unnamed_addr #0 !kernel_arg_addr_space !3 !kernel_arg_access_qual !4 !kernel_arg_type !5 !kernel_arg_base_type !6 !kernel_arg_type_qual !7 {
17+
entry:
18+
%call = tail call spir_func half @_Z3dotDhDh(half %ha, half %hb) #2
19+
ret void
20+
}
21+
22+
; Function Attrs: convergent
23+
declare spir_func half @_Z3dotDhDh(half, half) local_unnamed_addr #1
24+
25+
attributes #0 = { convergent norecurse nounwind "correctly-rounded-divide-sqrt-fp-math"="false" "disable-tail-calls"="false" "frame-pocharer"="none" "less-precise-fpmad"="false" "min-legal-vector-width"="128" "no-infs-fp-math"="false" "no-jump-tables"="false" "no-nans-fp-math"="false" "no-signed-zeros-fp-math"="false" "no-trapping-math"="true" "stack-protector-buffer-size"="8" "uniform-work-group-size"="false" "unsafe-fp-math"="false" "use-soft-float"="false" }
26+
attributes #1 = { convergent "correctly-rounded-divide-sqrt-fp-math"="false" "disable-tail-calls"="false" "frame-pocharer"="none" "less-precise-fpmad"="false" "no-infs-fp-math"="false" "no-nans-fp-math"="false" "no-signed-zeros-fp-math"="false" "no-trapping-math"="true" "stack-protector-buffer-size"="8" "unsafe-fp-math"="false" "use-soft-float"="false" }
27+
attributes #2 = { convergent nounwind }
28+
29+
!llvm.module.flags = !{!0}
30+
!opencl.ocl.version = !{!1}
31+
!opencl.spir.version = !{!1}
32+
!llvm.ident = !{!2}
33+
34+
!0 = !{i32 1, !"wchar_size", i32 4}
35+
!1 = !{i32 2, i32 0}
36+
!2 = !{!"clang version 11.0.0 (https://github.com/c199914007/llvm.git 8b94769313ca84cb9370b60ed008501ff692cb71)"}
37+
!3 = !{i32 0, i32 0}
38+
!4 = !{!"none", !"none"}
39+
!5 = !{!"half", !"half"}
40+
!6 = !{!"half", !"half"}
41+
!7 = !{!"", !""}

0 commit comments

Comments
 (0)