Skip to content

[HLSL][DXIL] Implement refract intrinsic #136026

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 9 commits into
base: main
Choose a base branch
from

Conversation

raoanag
Copy link

@raoanag raoanag commented Apr 16, 2025

  • Implement refract using HLSL source in hlsl_intrinsics.h
  • Implement the refract SPIR-V target built-in in clang/include/clang/Basic/BuiltinsSPIRV.td
  • Add sema checks for refract to CheckSPIRVBuiltinFunctionCall in clang/lib/Sema/SemaSPIRV.cpp
  • Add codegen for spv refract to EmitSPIRVBuiltinExpr in CGBuiltin.cpp
  • Add codegen tests to clang/test/CodeGenHLSL/builtins/refract.hlsl
  • Add spv codegen test to clang/test/CodeGenSPIRV/Builtins/refract.c
  • Add sema tests to clang/test/SemaHLSL/BuiltIns/refract-errors.hlsl
  • Add spv sema tests to clang/test/SemaSPIRV/BuiltIns/refract-errors.c
  • Create the int_spv_refract intrinsic in IntrinsicsSPIRV.td
  • In SPIRVInstructionSelector.cpp create the refract lowering and map it to int_spv_refract in SPIRVInstructionSelector::selectIntrinsic.
  • Create SPIR-V backend test case in llvm/test/CodeGen/SPIRV/hlsl-intrinsics/refract.ll
  • Check for what OpenCL support is needed.

Resolves #99153

Copy link

Thank you for submitting a Pull Request (PR) to the LLVM Project!

This PR will be automatically labeled and the relevant teams will be notified.

If you wish to, you can add reviewers by using the "Reviewers" section on this page.

If this is not working for you, it is probably because you do not have write permissions for the repository. In which case you can instead tag reviewers by name in a comment by using @ followed by their GitHub username.

If you have received no comments on your PR for a week, you can request a review by "ping"ing the PR by adding a comment “Ping”. The common courtesy "ping" rate is once a week. Please remember that you are asking for valuable time from other developers.

If you have further questions, they may be answered by the LLVM GitHub User Guide.

You can also ask questions in a comment on this PR, on the LLVM Discord or on the forums.

@raoanag raoanag force-pushed the user/raoanag/refract branch from a1ccf10 to dff5181 Compare April 29, 2025 00:16
@raoanag raoanag marked this pull request as ready for review April 29, 2025 00:21
@llvmbot llvmbot added clang Clang issues not falling into any other category backend:X86 clang:frontend Language frontend issues, e.g. anything involving "Sema" clang:headers Headers provided by Clang, e.g. for intrinsics clang:codegen IR generation bugs: mangling, exceptions, etc. HLSL HLSL Language Support backend:SPIR-V llvm:ir labels Apr 29, 2025
@llvmbot
Copy link
Member

llvmbot commented Apr 29, 2025

@llvm/pr-subscribers-llvm-ir
@llvm/pr-subscribers-hlsl
@llvm/pr-subscribers-backend-spir-v

@llvm/pr-subscribers-clang

Author: None (raoanag)

Changes
  • Implement refract using HLSL source in hlsl_intrinsics.h

  • Implement the refract SPIR-V target built-in in clang/include/clang/Basic/BuiltinsSPIRV.td

  • Add sema checks for refract to CheckSPIRVBuiltinFunctionCall in clang/lib/Sema/SemaSPIRV.cpp

  • Add codegen for spv refract to EmitSPIRVBuiltinExpr in CGBuiltin.cpp

  • Add codegen tests to clang/test/CodeGenHLSL/builtins/refract.hlsl

  • Add spv codegen test to clang/test/CodeGenSPIRV/Builtins/refract.c

  • Add sema tests to clang/test/SemaHLSL/BuiltIns/refract-errors.hlsl

  • Add spv sema tests to clang/test/SemaSPIRV/BuiltIns/refract-errors.c

  • Create the int_spv_refract intrinsic in IntrinsicsSPIRV.td

  • In SPIRVInstructionSelector.cpp create the refract lowering and map it to int_spv_refract in SPIRVInstructionSelector::selectIntrinsic.

  • Create SPIR-V backend test case in llvm/test/CodeGen/SPIRV/hlsl-intrinsics/refract.ll

  • Check for what OpenCL support is needed.

Resolves #99153


Patch is 53.62 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/136026.diff

12 Files Affected:

  • (modified) clang/include/clang/Basic/BuiltinsSPIRV.td (+6)
  • (modified) clang/lib/CodeGen/TargetBuiltins/SPIR.cpp (+15)
  • (modified) clang/lib/Headers/hlsl/hlsl_intrinsic_helpers.h (+21)
  • (modified) clang/lib/Headers/hlsl/hlsl_intrinsics.h (+59)
  • (modified) clang/lib/Sema/SemaSPIRV.cpp (+44-1)
  • (added) clang/test/CodeGenHLSL/builtins/refract.hlsl (+356)
  • (added) clang/test/CodeGenSPIRV/Builtins/refract.c (+34)
  • (added) clang/test/SemaHLSL/BuiltIns/refract-errors.hlsl (+74)
  • (added) clang/test/SemaSPIRV/BuiltIns/refract-errors.c (+28)
  • (modified) llvm/include/llvm/IR/IntrinsicsSPIRV.td (+1)
  • (modified) llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp (+2)
  • (added) llvm/test/CodeGen/SPIRV/hlsl-intrinsics/refract.ll (+37)
diff --git a/clang/include/clang/Basic/BuiltinsSPIRV.td b/clang/include/clang/Basic/BuiltinsSPIRV.td
index 9f76d672cc7ce..c0f652b4f24e4 100644
--- a/clang/include/clang/Basic/BuiltinsSPIRV.td
+++ b/clang/include/clang/Basic/BuiltinsSPIRV.td
@@ -26,6 +26,12 @@ def SPIRVReflect : Builtin {
   let Prototype = "void(...)";
 }
 
+def SPIRVRefract : Builtin {
+  let Spellings = ["__builtin_spirv_refract"];
+  let Attributes = [NoThrow, Const];
+  let Prototype = "void(...)";
+}
+
 def SPIRVSmoothStep : Builtin {
   let Spellings = ["__builtin_spirv_smoothstep"];
   let Attributes = [NoThrow, Const, CustomTypeChecking];
diff --git a/clang/lib/CodeGen/TargetBuiltins/SPIR.cpp b/clang/lib/CodeGen/TargetBuiltins/SPIR.cpp
index 92e2c1c6da68f..5fedb9553699f 100644
--- a/clang/lib/CodeGen/TargetBuiltins/SPIR.cpp
+++ b/clang/lib/CodeGen/TargetBuiltins/SPIR.cpp
@@ -58,6 +58,21 @@ Value *CodeGenFunction::EmitSPIRVBuiltinExpr(unsigned BuiltinID,
         /*ReturnType=*/I->getType(), Intrinsic::spv_reflect,
         ArrayRef<Value *>{I, N}, nullptr, "spv.reflect");
   }
+  case SPIRV::BI__builtin_spirv_refract: {
+    Value *I = EmitScalarExpr(E->getArg(0));
+    Value *N = EmitScalarExpr(E->getArg(1));
+    Value *eta = EmitScalarExpr(E->getArg(2));
+    assert(E->getArg(0)->getType()->hasFloatingRepresentation() &&
+           E->getArg(1)->getType()->hasFloatingRepresentation() &&
+           E->getArg(2)->getType()->hasFloatingRepresentation() &&
+           "refract operands must have a float representation");
+    assert(E->getArg(0)->getType()->isVectorType() &&
+           E->getArg(1)->getType()->isVectorType() &&
+           "refract I and N operands must be a vector");
+    return Builder.CreateIntrinsic(
+        /*ReturnType=*/I->getType(), Intrinsic::spv_refract,
+        ArrayRef<Value *>{I, N, eta}, nullptr, "spv.refract");
+  }
   case SPIRV::BI__builtin_spirv_smoothstep: {
     Value *Min = EmitScalarExpr(E->getArg(0));
     Value *Max = EmitScalarExpr(E->getArg(1));
diff --git a/clang/lib/Headers/hlsl/hlsl_intrinsic_helpers.h b/clang/lib/Headers/hlsl/hlsl_intrinsic_helpers.h
index 3a8a9b6fa2a45..9a320a78453ac 100644
--- a/clang/lib/Headers/hlsl/hlsl_intrinsic_helpers.h
+++ b/clang/lib/Headers/hlsl/hlsl_intrinsic_helpers.h
@@ -71,6 +71,27 @@ constexpr vector<T, L> reflect_vec_impl(vector<T, L> I, vector<T, L> N) {
 #endif
 }
 
+template <typename T> constexpr T refract_impl(T I, T N, T eta) {
+  T k = 1 - eta * eta * (1 - (N * I * N *I));
+  if(k < 0)
+    return 0;
+  else
+    return (eta * I - (eta * N * I + sqrt(k)) * N);
+}
+
+template <typename T, int L>
+constexpr vector<T, L> refract_vec_impl(vector<T, L> I, vector<T, L> N, T eta) {
+#if (__has_builtin(__builtin_spirv_refract))
+  return __builtin_spirv_refract(I, N, eta);
+#else
+  vector<T, L> k = 1 - eta * eta * (1 - dot(N, I) * dot(N, I));
+  if(k < 0)
+    return 0;
+  else
+    return (eta * I - (eta * dot(N, I) + sqrt(k)) * N);
+#endif
+}
+
 template <typename T> constexpr T fmod_impl(T X, T Y) {
 #if !defined(__DIRECTX__)
   return __builtin_elementwise_fmod(X, Y);
diff --git a/clang/lib/Headers/hlsl/hlsl_intrinsics.h b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
index 35ff80052cf43..bb5b770b4141a 100644
--- a/clang/lib/Headers/hlsl/hlsl_intrinsics.h
+++ b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
@@ -386,6 +386,65 @@ reflect(__detail::HLSL_FIXED_VECTOR<float, L> I,
   return __detail::reflect_vec_impl(I, N);
 }
 
+//===----------------------------------------------------------------------===//
+// refract builtin
+//===----------------------------------------------------------------------===//
+
+/// \fn T refract(T I, T N, T eta)
+/// \brief Returns a refraction using an entering ray, \a I, a surface
+/// normal, \a N and refraction index \a eta
+/// \param I The entering ray.
+/// \param N The surface normal.
+/// \param eta The refraction index.
+///
+/// The return value is a floating-point vector that represents the refraction
+/// using the refraction index, \a eta, for the direction of the entering ray, \a I,
+/// off a surface with the normal \a N.
+///
+/// This function calculates the refraction vector using the following formulas:
+/// k = 1.0 - eta * eta * (1.0 - dot(N, I) * dot(N, I))
+/// if k < 0.0 the result is 0.0
+/// otherwise, the result is eta * I - (eta * dot(N, I) + sqrt(k)) * N
+///
+/// I and N must already be normalized in order to achieve the desired result.
+///
+/// I and N must be a scalar or vector whose component type is
+/// floating-point.
+///
+/// eta must be a 16-bit or 32-bit floating-point scalar.
+///
+/// Result type, the type of I, and the type of N must all be the same type.
+
+template <typename T>
+_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
+const inline __detail::enable_if_t<__detail::is_arithmetic<T>::Value &&
+                                       __detail::is_same<half, T>::value,
+                                   T> refract(T I, T N, T eta) {
+  return __detail::refract_impl(I, N, eta);
+}
+
+template <typename T>
+const inline __detail::enable_if_t<
+    __detail::is_arithmetic<T>::Value && __detail::is_same<float, T>::value, T>
+    refract(T I, T N, T eta) {
+  return __detail::refract_impl(I, N, eta);
+}
+
+template <int L>
+_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
+const inline __detail::HLSL_FIXED_VECTOR<half, L> refract(
+    __detail::HLSL_FIXED_VECTOR<half, L> I,
+    __detail::HLSL_FIXED_VECTOR<half, L> N, half eta) {
+  return __detail::refract_vec_impl(I, N, eta);
+}
+
+template <int L>
+const inline __detail::HLSL_FIXED_VECTOR<float, L>
+refract(__detail::HLSL_FIXED_VECTOR<float, L> I,
+        __detail::HLSL_FIXED_VECTOR<float, L> N, float eta) {
+  return __detail::refract_vec_impl(I, N, eta);
+}
+
 //===----------------------------------------------------------------------===//
 // smoothstep builtin
 //===----------------------------------------------------------------------===//
diff --git a/clang/lib/Sema/SemaSPIRV.cpp b/clang/lib/Sema/SemaSPIRV.cpp
index 7131514d53421..2ad2089323cc3 100644
--- a/clang/lib/Sema/SemaSPIRV.cpp
+++ b/clang/lib/Sema/SemaSPIRV.cpp
@@ -69,6 +69,49 @@ bool SemaSPIRV::CheckSPIRVBuiltinFunctionCall(unsigned BuiltinID,
     TheCall->setType(RetTy);
     break;
   }
+  case SPIRV::BI__builtin_spirv_refract: {
+    if (SemaRef.checkArgCount(TheCall, 3))
+      return true;
+
+    ExprResult A = TheCall->getArg(0);
+    QualType ArgTyA = A.get()->getType();
+    auto *VTyA = ArgTyA->getAs<VectorType>();
+    if (VTyA == nullptr) {
+      SemaRef.Diag(A.get()->getBeginLoc(),
+                   diag::err_typecheck_convert_incompatible)
+          << ArgTyA
+          << SemaRef.Context.getVectorType(ArgTyA, 2, VectorKind::Generic) << 1
+          << 0 << 0;
+      return true;
+    }
+
+    ExprResult B = TheCall->getArg(1);
+    QualType ArgTyB = B.get()->getType();
+    auto *VTyB = ArgTyB->getAs<VectorType>();
+    if (VTyB == nullptr) {
+      SemaRef.Diag(B.get()->getBeginLoc(),
+                   diag::err_typecheck_convert_incompatible)
+          << ArgTyB
+          << SemaRef.Context.getVectorType(ArgTyB, 2, VectorKind::Generic) << 1
+          << 0 << 0;
+      return true;
+    }
+
+    ExprResult C = TheCall->getArg(2);
+    QualType ArgTyC = C.get()->getType();
+    if (!ArgTyC->hasFloatingRepresentation()) {
+      SemaRef.Diag(C.get()->getBeginLoc(),
+                   diag::err_builtin_invalid_arg_type)
+          << 3 << /* scalar or vector */ 5 << /* no int */ 0 << /* fp */ 1
+          << ArgTyC;
+      return true;
+    }
+
+    QualType RetTy = ArgTyA;
+    TheCall->setType(RetTy);
+    assert(RetTy == ArgTyA);
+    break;
+  }
   case SPIRV::BI__builtin_spirv_reflect: {
     if (SemaRef.checkArgCount(TheCall, 2))
       return true;
@@ -89,7 +132,7 @@ bool SemaSPIRV::CheckSPIRVBuiltinFunctionCall(unsigned BuiltinID,
     QualType ArgTyB = B.get()->getType();
     auto *VTyB = ArgTyB->getAs<VectorType>();
     if (VTyB == nullptr) {
-      SemaRef.Diag(A.get()->getBeginLoc(),
+      SemaRef.Diag(B.get()->getBeginLoc(),
                    diag::err_typecheck_convert_incompatible)
           << ArgTyB
           << SemaRef.Context.getVectorType(ArgTyB, 2, VectorKind::Generic) << 1
diff --git a/clang/test/CodeGenHLSL/builtins/refract.hlsl b/clang/test/CodeGenHLSL/builtins/refract.hlsl
new file mode 100644
index 0000000000000..a2e160f17b582
--- /dev/null
+++ b/clang/test/CodeGenHLSL/builtins/refract.hlsl
@@ -0,0 +1,356 @@
+// NOTE: Assertions have been autogenerated by utils/update_cc_test_checks.py UTC_ARGS: --version 5
+// RUN: %clang_cc1 -finclude-default-header -triple \
+// RUN:   dxil-pc-shadermodel6.3-library %s -fnative-half-type \
+// RUN:   -emit-llvm -O1 -o - | FileCheck %s
+// RUN: %clang_cc1 -finclude-default-header -triple \
+// RUN:   spirv-unknown-vulkan-compute %s -fnative-half-type \
+// RUN:   -emit-llvm -O1 -o - | FileCheck %s --check-prefix=SPVCHECK
+
+// CHECK-LABEL: define noundef nofpclass(nan inf) half @_Z17test_refract_halfDhDhDh(
+// CHECK-SAME: half noundef nofpclass(nan inf) [[I:%.*]], half noundef nofpclass(nan inf) [[N:%.*]], half noundef nofpclass(nan inf) [[ETA:%.*]]) local_unnamed_addr #[[ATTR0:[0-9]+]] {
+// CHECK-NEXT:  [[ENTRY:.*:]]
+// CHECK-NEXT:    [[MUL_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[ETA]], [[ETA]]
+// CHECK-NEXT:    [[TMP0:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[N]], [[I]]
+// CHECK-NEXT:    [[TMP1:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[TMP0]], [[TMP0]]
+// CHECK-NEXT:    [[SUB1_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn half 0xH3C00, [[TMP1]]
+// CHECK-NEXT:    [[MUL4_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[MUL_I]], [[SUB1_I]]
+// CHECK-NEXT:    [[SUB5_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn half 0xH3C00, [[MUL4_I]]
+// CHECK-NEXT:    [[CMP_I:%.*]] = fcmp reassoc nnan ninf nsz arcp afn olt half [[SUB5_I]], 0xH0000
+// CHECK-NEXT:    br i1 [[CMP_I]], label %_ZN4hlsl8__detail12refract_implIDhEET_S2_S2_S2_.exit, label %if.else.i
+// CHECK:  if.else.i:                                        ; preds = %entry
+// CHECK-NEXT:    [[MUL6_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[ETA]], [[I]]
+// CHECK-NEXT:    [[MUL7_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[N]], [[I]]
+// CHECK-NEXT:    [[MUL8_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[MUL7_I]], [[ETA]]
+// CHECK-NEXT:    [[TMP2:%.*]] = tail call reassoc nnan ninf nsz arcp afn half @llvm.sqrt.f16(half [[SUB5_I]])
+// CHECK-NEXT:    [[ADD_I:%.*]] = fadd reassoc nnan ninf nsz arcp afn half [[TMP2]], [[MUL8_I]]
+// CHECK-NEXT:    [[MUL9_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[ADD_I]], [[N]]
+// CHECK-NEXT:    [[SUB10_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn half [[MUL6_I]], [[MUL9_I]]
+// CHECK-NEXT:    br label %_ZN4hlsl8__detail12refract_implIDhEET_S2_S2_S2_.exit
+// CHECK:  _ZN4hlsl8__detail12refract_implIDhEET_S2_S2_S2_.exit: ; preds = %entry, %if.else.i
+// CHECK-NEXT:    [[RETVAL_0_I:%.*]] = phi nsz half [ [[SUB10_I]], %if.else.i ], [ 0xH0000, %entry ]
+// CHECK-NEXT:    ret half [[RETVAL_0_I]]
+//
+//
+// SPVCHECK-LABEL: define spir_func noundef nofpclass(nan inf) half @_Z17test_refract_halfDhDhDh(
+// SPVCHECK-SAME: half noundef nofpclass(nan inf) [[I:%.*]], half noundef nofpclass(nan inf) [[N:%.*]], half noundef nofpclass(nan inf) [[ETA:%.*]]) local_unnamed_addr #[[ATTR0:[0-9]+]] 
+// SPVCHECK-NEXT:  [[ENTRY:.*:]]
+// SPVCHECK-NEXT:    [[MUL_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[ETA]], [[ETA]]
+// SPVCHECK-NEXT:    [[TMP0:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[N]], [[I]]
+// SPVCHECK-NEXT:    [[TMP1:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[TMP0]], [[TMP0]]
+// SPVCHECK-NEXT:    [[SUB_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn half 0xH3C00, [[TMP1]]
+// SPVCHECK-NEXT:    [[MUL_4_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[MUL_I]], [[SUB_I]]
+// SPVCHECK-NEXT:    [[SUB5_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn half 0xH3C00, [[MUL_4_I]]
+// SPVCHECK-NEXT:    [[CMP_I:%.*]] = fcmp reassoc nnan ninf nsz arcp afn olt half [[SUB5_I]], 0xH0000
+// SPVCHECK-NEXT:    br i1 [[CMP_I]], label %_ZN4hlsl8__detail12refract_implIDhEET_S2_S2_S2_.exit, label %if.else.i
+// SPVCHECK:  if.else.i:                                        ; preds = %entry
+// SPVCHECK-NEXT:    [[MUL_6_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[ETA]], [[I]]
+// SPVCHECK-NEXT:    [[MUL_7_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[N]], [[I]]
+// SPVCHECK-NEXT:    [[MUL_8_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[MUL_7_I]], [[ETA]]
+// SPVCHECK-NEXT:    [[TMP2:%.*]] = tail call reassoc nnan ninf nsz arcp afn half @llvm.sqrt.f16(half [[SUB5_I]])
+// SPVCHECK-NEXT:    [[ADD_I:%.*]] = fadd reassoc nnan ninf nsz arcp afn half [[TMP2]], [[MUL_8_I]]
+// SPVCHECK-NEXT:    [[MUL_9_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[ADD_I]], [[N]]
+// SPVCHECK-NEXT:    [[SUB10_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn half [[MUL_6_I]], [[MUL_9_I]]
+// SPVCHECK-NEXT:    br label %_ZN4hlsl8__detail12refract_implIDhEET_S2_S2_S2_.exit
+// SPVCHECK:  _ZN4hlsl8__detail12refract_implIDhEET_S2_S2_S2_.exit: ; preds = %entry, %if.else.i
+// SPVCHECK-NEXT:    [[RETVAL_0_I:%.*]] = phi nsz half [ [[SUB10_I]], %if.else.i ], [ 0xH0000, %entry ]
+// SPVCHECK-NEXT:    ret half [[RETVAL_0_I]]
+//
+half test_refract_half(half I, half N, half ETA) {
+    return refract(I, N, ETA);
+}
+
+// CHECK-LABEL: define noundef nofpclass(nan inf) <2 x half> @_Z18test_refract_half2Dv2_DhS_S_
+// CHECK-SAME: <2 x half> noundef nofpclass(nan inf) [[I:%.*]], <2 x half> noundef nofpclass(nan inf) [[N:%.*]], <2 x half> noundef nofpclass(nan inf) [[ETA:%.*]]) local_unnamed_addr #[[ATTR0:[0-9]+]] {
+// CHECK-NEXT:  [[ENTRY:.*:]]
+// CHECK-NEXT:    [[CAST_VTRUNC:%.*]] = extractelement <2 x half> [[ETA]], i64 0
+// CHECK-NEXT:    [[MUL_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[CAST_VTRUNC]], [[CAST_VTRUNC]]
+// CHECK-NEXT:    [[HLSL_DOT_I:%.*]] = tail call reassoc nnan ninf nsz arcp afn half @llvm.dx.fdot.v2f16(<2 x half> [[N]], <2 x half> [[I]])
+// CHECK-NEXT:    [[MUL_2_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[HLSL_DOT_I]], [[HLSL_DOT_I]]
+// CHECK-NEXT:    [[SUB_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn half 0xH3C00, [[MUL_2_I]]
+// CHECK-NEXT:    [[MUL_3_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[MUL_I]], [[SUB_I]]
+// CHECK-NEXT:    [[SUB4_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn half 0xH3C00, [[MUL_3_I]]
+// CHECK-NEXT:    [[CAST_VTRUNC_I:%.*]] = fcmp reassoc nnan ninf nsz arcp afn olt half [[SUB4_I]], 0xH0000
+// CHECK:    br i1 [[CAST_VTRUNC_I]], label %_ZN4hlsl8__detail16refract_vec_implIDhLi2EEEDvT0__T_S3_S3_S2_.exit, label %if.else.i
+// CHECK:   if.else.i:                                        ; preds = %entry
+// CHECK-NEXT:    [[SPLAT_SPLATINSERT_I:%.*]] = insertelement <2 x half> poison, half [[SUB4_I]], i64 0
+// CHECK-NEXT:    [[SPLAT_SPLAT_I:%.*]] = shufflevector <2 x half> [[SPLAT_SPLATINSERT_I]], <2 x half> poison, <2 x i32> zeroinitializer
+// CHECK-NEXT:    [[SPLAT_SPLAT6_I:%.*]] = shufflevector <2 x half> [[ETA]], <2 x half> poison, <2 x i32> zeroinitializer
+// CHECK-NEXT:    [[MUL_7_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn <2 x half> [[SPLAT_SPLAT6_I]], [[I]]
+// CHECK-NEXT:    [[MUL_9_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[HLSL_DOT_I]], [[CAST_VTRUNC]]
+// CHECK-NEXT:    [[SPLAT_SPLATINSERT10_I:%.*]] = insertelement <2 x half> poison, half [[MUL_9_I]], i64 0
+// CHECK-NEXT:    [[SPLAT_SPLAT11_I:%.*]] = shufflevector <2 x half> [[SPLAT_SPLATINSERT10_I]], <2 x half> poison, <2 x i32> zeroinitializer
+// CHECK-NEXT:    [[TMP0:%.*]] = tail call reassoc nnan ninf nsz arcp afn <2 x half> @llvm.sqrt.v2f16(<2 x half> [[SPLAT_SPLAT_I]])
+// CHECK-NEXT:    [[ADD_I:%.*]] = fadd reassoc nnan ninf nsz arcp afn <2 x half> [[TMP0]], [[SPLAT_SPLAT11_I]]
+// CHECK-NEXT:    [[MUL_12_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn <2 x half> [[ADD_I]], [[N]]
+// CHECK-NEXT:    [[SUB13_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn <2 x half> [[MUL_7_I]], [[MUL_12_I]]
+// CHECK-NEXT:    br label %_ZN4hlsl8__detail16refract_vec_implIDhLi2EEEDvT0__T_S3_S3_S2_.exit
+// CHECK:  _ZN4hlsl8__detail16refract_vec_implIDhLi2EEEDvT0__T_S3_S3_S2_.exit: ; preds = %entry, %if.else.i
+// CHECK-NEXT:    [[RETVAL_0_I:%.*]] = phi nsz <2 x half> [ %sub13.i, %if.else.i ], [ zeroinitializer, %entry ]
+// CHECK-NEXT:    ret <2 x half> [[RETVAL_0_I]]
+
+//
+// SPVCHECK-LABEL: define spir_func noundef nofpclass(nan inf) <2 x half> @_Z18test_refract_half2Dv2_DhS_S_(
+// SPVCHECK-SAME: <2 x half> noundef nofpclass(nan inf) [[I:%.*]], <2 x half> noundef nofpclass(nan inf) [[N:%.*]], <2 x half> noundef nofpclass(nan inf) [[ETA:%.*]]) local_unnamed_addr #[[ATTR0:[0-9]+]] {
+// SPVCHECK-NEXT:  [[ENTRY:.*:]]
+// SPVCHECK-NEXT:    [[CAST_VTRUNC:%.*]] = extractelement <2 x half> [[ETA]], i64 0
+// SPVCHECK-NEXT:    [[CONV_I:%.*]] = fpext reassoc nnan ninf nsz arcp afn half [[CAST_VTRUNC]] to double
+// SPVCHECK-NEXT:    [[SPV_REFRACT_I:%.*]] = tail call reassoc nnan ninf nsz arcp afn noundef <2 x half> @llvm.spv.refract.v2f16.f64(<2 x half> [[I]], <2 x half> [[N]], double [[CONV_I]])
+// SPVCHECK-NEXT:    ret <2 x half> [[SPV_REFRACT_I]]
+//
+half2 test_refract_half2(half2 I, half2 N, half2 ETA) {
+    return refract(I, N, ETA);
+}
+
+// CHECK-LABEL: define noundef nofpclass(nan inf) <3 x half> @_Z18test_refract_half3Dv3_DhS_Dh(
+// CHECK-SAME: <3 x half> noundef nofpclass(nan inf) [[I:%.*]], <3 x half> noundef nofpclass(nan inf) [[N:%.*]], half noundef nofpclass(nan inf) [[ETA:%.*]]) local_unnamed_addr #[[ATTR0:[0-9]+]] {
+// CHECK-NEXT:  [[ENTRY:.*:]]
+// CHECK-NEXT:    [[MUL_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[ETA]], [[ETA]]
+// CHECK-NEXT:    [[HLSL_DOT_I:%.*]] = tail call reassoc nnan ninf nsz arcp afn half @llvm.dx.fdot.v3f16(<3 x half> [[N]], <3 x half> [[I]])
+// CHECK-NEXT:    [[MUL_2_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[HLSL_DOT_I]], [[HLSL_DOT_I]]
+// CHECK-NEXT:    [[SUB_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn half 0xH3C00, [[MUL_2_I]]
+// CHECK-NEXT:    [[MUL_3_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[MUL_I]], [[SUB_I]]
+// CHECK-NEXT:    [[SUB4_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn half 0xH3C00, [[MUL_3_I]]
+// CHECK-NEXT:    [[CAST_VTRUNC_I:%.*]] = fcmp reassoc nnan ninf nsz arcp afn olt half [[SUB4_I]], 0xH0000
+// CHECK:    br i1 [[CAST_VTRUNC_I]], label %_ZN4hlsl8__detail16refract_vec_implIDhLi3EEEDvT0__T_S3_S3_S2_.exit, label %if.else.i
+// CHECK:   if.else.i:                                        ; preds = %entry
+// CHECK-NEXT:    [[SPLAT_SPLATINSERT_I:%.*]] = insertelement <3 x half> poison, half [[SUB4_I]], i64 0
+// CHECK-NEXT:    [[SPLAT_SPLAT_I:%.*]] = shufflevector <3 x half> [[SPLAT_SPLATINSERT_I]], <3 x half> poison, <3 x i32> zeroinitializer
+// CHECK-NEXT:    [[SPLAT_SPLATINSERT5_I:%.*]] = insertelement <3 x half> poison, half [[ETA]], i64 0
+// CHECK-NEXT:    [[SPLAT_SPLAT6_I:%.*]] = shufflevector <3 x half> [[SPLAT_SPLATINSERT5_I]], <3 x half> poison, <3 x i32> zeroinitializer
+// CHECK-NEXT:    [[MUL_7_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn <3 x half> [[SPLAT_SPLAT6_I]], [[I]]
+// CHECK-NEXT:    [[MUL_9_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[HLSL_DOT_I]], [[ETA]]
+// CHECK-NEXT:    [[SPLAT_SPLATINSERT10_I:%.*]] = insertelement <3 x half> poison, half [[MUL_9_I]], i64 0
+// CHECK-NEXT:    [[SPLAT_SPLAT11_I:%.*]] = shufflevector <3 x half> [[SPLAT_SPLATINSERT10_I]], <3 x half> poison, <3 x i32> zeroinitializer
+// CHECK-NEXT:    [[TMP0:%.*]] = tail call reassoc nnan ninf nsz arcp afn <3 x half> @llvm.sqrt.v3f16(<3 x half> [[SPLAT_SPLAT_I]])
+// CHECK-NEXT:    [[ADD_I:%.*]] = fadd reassoc nnan ninf nsz arcp afn <3 x half> [[TMP0]], [[SPLAT_SPLAT11_I]]
+// CHECK-NEXT:    [[MUL_12_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn <3 x half> [[ADD_I]], [[N]]
+// CHECK-NEXT:   [[SUB13_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn <3 x half> [[MUL_7_I]], [[MUL_12_I]]
+// CHECK:    br label %_ZN4hlsl8__detail16refract_vec_implIDhLi3EEEDvT0__T_S3_S3_S2_.exit
+// CHECK:  _ZN4hlsl8__detail16refract_vec_implIDhLi3EEEDvT0__T_S3_S3_S2_.exit: ; preds = %entry, %if.else.i
+// CHECK-NEXT:    [[RETVAL_0_I:%.*]] = phi nsz <3 x half> [ [[SUB13_I]], %if.else.i ], [ zeroinitializer, %entry ]
+// CHECK-NEXT:    ret <3 x half> [[RETVAL_0_I]]
+//
+// SPVCHECK-LABEL: define spir_func noundef nofp...
[truncated]

@llvmbot
Copy link
Member

llvmbot commented Apr 29, 2025

@llvm/pr-subscribers-clang-codegen

Author: None (raoanag)

Changes
  • Implement refract using HLSL source in hlsl_intrinsics.h

  • Implement the refract SPIR-V target built-in in clang/include/clang/Basic/BuiltinsSPIRV.td

  • Add sema checks for refract to CheckSPIRVBuiltinFunctionCall in clang/lib/Sema/SemaSPIRV.cpp

  • Add codegen for spv refract to EmitSPIRVBuiltinExpr in CGBuiltin.cpp

  • Add codegen tests to clang/test/CodeGenHLSL/builtins/refract.hlsl

  • Add spv codegen test to clang/test/CodeGenSPIRV/Builtins/refract.c

  • Add sema tests to clang/test/SemaHLSL/BuiltIns/refract-errors.hlsl

  • Add spv sema tests to clang/test/SemaSPIRV/BuiltIns/refract-errors.c

  • Create the int_spv_refract intrinsic in IntrinsicsSPIRV.td

  • In SPIRVInstructionSelector.cpp create the refract lowering and map it to int_spv_refract in SPIRVInstructionSelector::selectIntrinsic.

  • Create SPIR-V backend test case in llvm/test/CodeGen/SPIRV/hlsl-intrinsics/refract.ll

  • Check for what OpenCL support is needed.

Resolves #99153


Patch is 53.62 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/136026.diff

12 Files Affected:

  • (modified) clang/include/clang/Basic/BuiltinsSPIRV.td (+6)
  • (modified) clang/lib/CodeGen/TargetBuiltins/SPIR.cpp (+15)
  • (modified) clang/lib/Headers/hlsl/hlsl_intrinsic_helpers.h (+21)
  • (modified) clang/lib/Headers/hlsl/hlsl_intrinsics.h (+59)
  • (modified) clang/lib/Sema/SemaSPIRV.cpp (+44-1)
  • (added) clang/test/CodeGenHLSL/builtins/refract.hlsl (+356)
  • (added) clang/test/CodeGenSPIRV/Builtins/refract.c (+34)
  • (added) clang/test/SemaHLSL/BuiltIns/refract-errors.hlsl (+74)
  • (added) clang/test/SemaSPIRV/BuiltIns/refract-errors.c (+28)
  • (modified) llvm/include/llvm/IR/IntrinsicsSPIRV.td (+1)
  • (modified) llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp (+2)
  • (added) llvm/test/CodeGen/SPIRV/hlsl-intrinsics/refract.ll (+37)
diff --git a/clang/include/clang/Basic/BuiltinsSPIRV.td b/clang/include/clang/Basic/BuiltinsSPIRV.td
index 9f76d672cc7ce..c0f652b4f24e4 100644
--- a/clang/include/clang/Basic/BuiltinsSPIRV.td
+++ b/clang/include/clang/Basic/BuiltinsSPIRV.td
@@ -26,6 +26,12 @@ def SPIRVReflect : Builtin {
   let Prototype = "void(...)";
 }
 
+def SPIRVRefract : Builtin {
+  let Spellings = ["__builtin_spirv_refract"];
+  let Attributes = [NoThrow, Const];
+  let Prototype = "void(...)";
+}
+
 def SPIRVSmoothStep : Builtin {
   let Spellings = ["__builtin_spirv_smoothstep"];
   let Attributes = [NoThrow, Const, CustomTypeChecking];
diff --git a/clang/lib/CodeGen/TargetBuiltins/SPIR.cpp b/clang/lib/CodeGen/TargetBuiltins/SPIR.cpp
index 92e2c1c6da68f..5fedb9553699f 100644
--- a/clang/lib/CodeGen/TargetBuiltins/SPIR.cpp
+++ b/clang/lib/CodeGen/TargetBuiltins/SPIR.cpp
@@ -58,6 +58,21 @@ Value *CodeGenFunction::EmitSPIRVBuiltinExpr(unsigned BuiltinID,
         /*ReturnType=*/I->getType(), Intrinsic::spv_reflect,
         ArrayRef<Value *>{I, N}, nullptr, "spv.reflect");
   }
+  case SPIRV::BI__builtin_spirv_refract: {
+    Value *I = EmitScalarExpr(E->getArg(0));
+    Value *N = EmitScalarExpr(E->getArg(1));
+    Value *eta = EmitScalarExpr(E->getArg(2));
+    assert(E->getArg(0)->getType()->hasFloatingRepresentation() &&
+           E->getArg(1)->getType()->hasFloatingRepresentation() &&
+           E->getArg(2)->getType()->hasFloatingRepresentation() &&
+           "refract operands must have a float representation");
+    assert(E->getArg(0)->getType()->isVectorType() &&
+           E->getArg(1)->getType()->isVectorType() &&
+           "refract I and N operands must be a vector");
+    return Builder.CreateIntrinsic(
+        /*ReturnType=*/I->getType(), Intrinsic::spv_refract,
+        ArrayRef<Value *>{I, N, eta}, nullptr, "spv.refract");
+  }
   case SPIRV::BI__builtin_spirv_smoothstep: {
     Value *Min = EmitScalarExpr(E->getArg(0));
     Value *Max = EmitScalarExpr(E->getArg(1));
diff --git a/clang/lib/Headers/hlsl/hlsl_intrinsic_helpers.h b/clang/lib/Headers/hlsl/hlsl_intrinsic_helpers.h
index 3a8a9b6fa2a45..9a320a78453ac 100644
--- a/clang/lib/Headers/hlsl/hlsl_intrinsic_helpers.h
+++ b/clang/lib/Headers/hlsl/hlsl_intrinsic_helpers.h
@@ -71,6 +71,27 @@ constexpr vector<T, L> reflect_vec_impl(vector<T, L> I, vector<T, L> N) {
 #endif
 }
 
+template <typename T> constexpr T refract_impl(T I, T N, T eta) {
+  T k = 1 - eta * eta * (1 - (N * I * N *I));
+  if(k < 0)
+    return 0;
+  else
+    return (eta * I - (eta * N * I + sqrt(k)) * N);
+}
+
+template <typename T, int L>
+constexpr vector<T, L> refract_vec_impl(vector<T, L> I, vector<T, L> N, T eta) {
+#if (__has_builtin(__builtin_spirv_refract))
+  return __builtin_spirv_refract(I, N, eta);
+#else
+  vector<T, L> k = 1 - eta * eta * (1 - dot(N, I) * dot(N, I));
+  if(k < 0)
+    return 0;
+  else
+    return (eta * I - (eta * dot(N, I) + sqrt(k)) * N);
+#endif
+}
+
 template <typename T> constexpr T fmod_impl(T X, T Y) {
 #if !defined(__DIRECTX__)
   return __builtin_elementwise_fmod(X, Y);
diff --git a/clang/lib/Headers/hlsl/hlsl_intrinsics.h b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
index 35ff80052cf43..bb5b770b4141a 100644
--- a/clang/lib/Headers/hlsl/hlsl_intrinsics.h
+++ b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
@@ -386,6 +386,65 @@ reflect(__detail::HLSL_FIXED_VECTOR<float, L> I,
   return __detail::reflect_vec_impl(I, N);
 }
 
+//===----------------------------------------------------------------------===//
+// refract builtin
+//===----------------------------------------------------------------------===//
+
+/// \fn T refract(T I, T N, T eta)
+/// \brief Returns a refraction using an entering ray, \a I, a surface
+/// normal, \a N and refraction index \a eta
+/// \param I The entering ray.
+/// \param N The surface normal.
+/// \param eta The refraction index.
+///
+/// The return value is a floating-point vector that represents the refraction
+/// using the refraction index, \a eta, for the direction of the entering ray, \a I,
+/// off a surface with the normal \a N.
+///
+/// This function calculates the refraction vector using the following formulas:
+/// k = 1.0 - eta * eta * (1.0 - dot(N, I) * dot(N, I))
+/// if k < 0.0 the result is 0.0
+/// otherwise, the result is eta * I - (eta * dot(N, I) + sqrt(k)) * N
+///
+/// I and N must already be normalized in order to achieve the desired result.
+///
+/// I and N must be a scalar or vector whose component type is
+/// floating-point.
+///
+/// eta must be a 16-bit or 32-bit floating-point scalar.
+///
+/// Result type, the type of I, and the type of N must all be the same type.
+
+template <typename T>
+_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
+const inline __detail::enable_if_t<__detail::is_arithmetic<T>::Value &&
+                                       __detail::is_same<half, T>::value,
+                                   T> refract(T I, T N, T eta) {
+  return __detail::refract_impl(I, N, eta);
+}
+
+template <typename T>
+const inline __detail::enable_if_t<
+    __detail::is_arithmetic<T>::Value && __detail::is_same<float, T>::value, T>
+    refract(T I, T N, T eta) {
+  return __detail::refract_impl(I, N, eta);
+}
+
+template <int L>
+_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
+const inline __detail::HLSL_FIXED_VECTOR<half, L> refract(
+    __detail::HLSL_FIXED_VECTOR<half, L> I,
+    __detail::HLSL_FIXED_VECTOR<half, L> N, half eta) {
+  return __detail::refract_vec_impl(I, N, eta);
+}
+
+template <int L>
+const inline __detail::HLSL_FIXED_VECTOR<float, L>
+refract(__detail::HLSL_FIXED_VECTOR<float, L> I,
+        __detail::HLSL_FIXED_VECTOR<float, L> N, float eta) {
+  return __detail::refract_vec_impl(I, N, eta);
+}
+
 //===----------------------------------------------------------------------===//
 // smoothstep builtin
 //===----------------------------------------------------------------------===//
diff --git a/clang/lib/Sema/SemaSPIRV.cpp b/clang/lib/Sema/SemaSPIRV.cpp
index 7131514d53421..2ad2089323cc3 100644
--- a/clang/lib/Sema/SemaSPIRV.cpp
+++ b/clang/lib/Sema/SemaSPIRV.cpp
@@ -69,6 +69,49 @@ bool SemaSPIRV::CheckSPIRVBuiltinFunctionCall(unsigned BuiltinID,
     TheCall->setType(RetTy);
     break;
   }
+  case SPIRV::BI__builtin_spirv_refract: {
+    if (SemaRef.checkArgCount(TheCall, 3))
+      return true;
+
+    ExprResult A = TheCall->getArg(0);
+    QualType ArgTyA = A.get()->getType();
+    auto *VTyA = ArgTyA->getAs<VectorType>();
+    if (VTyA == nullptr) {
+      SemaRef.Diag(A.get()->getBeginLoc(),
+                   diag::err_typecheck_convert_incompatible)
+          << ArgTyA
+          << SemaRef.Context.getVectorType(ArgTyA, 2, VectorKind::Generic) << 1
+          << 0 << 0;
+      return true;
+    }
+
+    ExprResult B = TheCall->getArg(1);
+    QualType ArgTyB = B.get()->getType();
+    auto *VTyB = ArgTyB->getAs<VectorType>();
+    if (VTyB == nullptr) {
+      SemaRef.Diag(B.get()->getBeginLoc(),
+                   diag::err_typecheck_convert_incompatible)
+          << ArgTyB
+          << SemaRef.Context.getVectorType(ArgTyB, 2, VectorKind::Generic) << 1
+          << 0 << 0;
+      return true;
+    }
+
+    ExprResult C = TheCall->getArg(2);
+    QualType ArgTyC = C.get()->getType();
+    if (!ArgTyC->hasFloatingRepresentation()) {
+      SemaRef.Diag(C.get()->getBeginLoc(),
+                   diag::err_builtin_invalid_arg_type)
+          << 3 << /* scalar or vector */ 5 << /* no int */ 0 << /* fp */ 1
+          << ArgTyC;
+      return true;
+    }
+
+    QualType RetTy = ArgTyA;
+    TheCall->setType(RetTy);
+    assert(RetTy == ArgTyA);
+    break;
+  }
   case SPIRV::BI__builtin_spirv_reflect: {
     if (SemaRef.checkArgCount(TheCall, 2))
       return true;
@@ -89,7 +132,7 @@ bool SemaSPIRV::CheckSPIRVBuiltinFunctionCall(unsigned BuiltinID,
     QualType ArgTyB = B.get()->getType();
     auto *VTyB = ArgTyB->getAs<VectorType>();
     if (VTyB == nullptr) {
-      SemaRef.Diag(A.get()->getBeginLoc(),
+      SemaRef.Diag(B.get()->getBeginLoc(),
                    diag::err_typecheck_convert_incompatible)
           << ArgTyB
           << SemaRef.Context.getVectorType(ArgTyB, 2, VectorKind::Generic) << 1
diff --git a/clang/test/CodeGenHLSL/builtins/refract.hlsl b/clang/test/CodeGenHLSL/builtins/refract.hlsl
new file mode 100644
index 0000000000000..a2e160f17b582
--- /dev/null
+++ b/clang/test/CodeGenHLSL/builtins/refract.hlsl
@@ -0,0 +1,356 @@
+// NOTE: Assertions have been autogenerated by utils/update_cc_test_checks.py UTC_ARGS: --version 5
+// RUN: %clang_cc1 -finclude-default-header -triple \
+// RUN:   dxil-pc-shadermodel6.3-library %s -fnative-half-type \
+// RUN:   -emit-llvm -O1 -o - | FileCheck %s
+// RUN: %clang_cc1 -finclude-default-header -triple \
+// RUN:   spirv-unknown-vulkan-compute %s -fnative-half-type \
+// RUN:   -emit-llvm -O1 -o - | FileCheck %s --check-prefix=SPVCHECK
+
+// CHECK-LABEL: define noundef nofpclass(nan inf) half @_Z17test_refract_halfDhDhDh(
+// CHECK-SAME: half noundef nofpclass(nan inf) [[I:%.*]], half noundef nofpclass(nan inf) [[N:%.*]], half noundef nofpclass(nan inf) [[ETA:%.*]]) local_unnamed_addr #[[ATTR0:[0-9]+]] {
+// CHECK-NEXT:  [[ENTRY:.*:]]
+// CHECK-NEXT:    [[MUL_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[ETA]], [[ETA]]
+// CHECK-NEXT:    [[TMP0:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[N]], [[I]]
+// CHECK-NEXT:    [[TMP1:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[TMP0]], [[TMP0]]
+// CHECK-NEXT:    [[SUB1_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn half 0xH3C00, [[TMP1]]
+// CHECK-NEXT:    [[MUL4_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[MUL_I]], [[SUB1_I]]
+// CHECK-NEXT:    [[SUB5_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn half 0xH3C00, [[MUL4_I]]
+// CHECK-NEXT:    [[CMP_I:%.*]] = fcmp reassoc nnan ninf nsz arcp afn olt half [[SUB5_I]], 0xH0000
+// CHECK-NEXT:    br i1 [[CMP_I]], label %_ZN4hlsl8__detail12refract_implIDhEET_S2_S2_S2_.exit, label %if.else.i
+// CHECK:  if.else.i:                                        ; preds = %entry
+// CHECK-NEXT:    [[MUL6_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[ETA]], [[I]]
+// CHECK-NEXT:    [[MUL7_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[N]], [[I]]
+// CHECK-NEXT:    [[MUL8_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[MUL7_I]], [[ETA]]
+// CHECK-NEXT:    [[TMP2:%.*]] = tail call reassoc nnan ninf nsz arcp afn half @llvm.sqrt.f16(half [[SUB5_I]])
+// CHECK-NEXT:    [[ADD_I:%.*]] = fadd reassoc nnan ninf nsz arcp afn half [[TMP2]], [[MUL8_I]]
+// CHECK-NEXT:    [[MUL9_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[ADD_I]], [[N]]
+// CHECK-NEXT:    [[SUB10_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn half [[MUL6_I]], [[MUL9_I]]
+// CHECK-NEXT:    br label %_ZN4hlsl8__detail12refract_implIDhEET_S2_S2_S2_.exit
+// CHECK:  _ZN4hlsl8__detail12refract_implIDhEET_S2_S2_S2_.exit: ; preds = %entry, %if.else.i
+// CHECK-NEXT:    [[RETVAL_0_I:%.*]] = phi nsz half [ [[SUB10_I]], %if.else.i ], [ 0xH0000, %entry ]
+// CHECK-NEXT:    ret half [[RETVAL_0_I]]
+//
+//
+// SPVCHECK-LABEL: define spir_func noundef nofpclass(nan inf) half @_Z17test_refract_halfDhDhDh(
+// SPVCHECK-SAME: half noundef nofpclass(nan inf) [[I:%.*]], half noundef nofpclass(nan inf) [[N:%.*]], half noundef nofpclass(nan inf) [[ETA:%.*]]) local_unnamed_addr #[[ATTR0:[0-9]+]] 
+// SPVCHECK-NEXT:  [[ENTRY:.*:]]
+// SPVCHECK-NEXT:    [[MUL_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[ETA]], [[ETA]]
+// SPVCHECK-NEXT:    [[TMP0:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[N]], [[I]]
+// SPVCHECK-NEXT:    [[TMP1:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[TMP0]], [[TMP0]]
+// SPVCHECK-NEXT:    [[SUB_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn half 0xH3C00, [[TMP1]]
+// SPVCHECK-NEXT:    [[MUL_4_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[MUL_I]], [[SUB_I]]
+// SPVCHECK-NEXT:    [[SUB5_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn half 0xH3C00, [[MUL_4_I]]
+// SPVCHECK-NEXT:    [[CMP_I:%.*]] = fcmp reassoc nnan ninf nsz arcp afn olt half [[SUB5_I]], 0xH0000
+// SPVCHECK-NEXT:    br i1 [[CMP_I]], label %_ZN4hlsl8__detail12refract_implIDhEET_S2_S2_S2_.exit, label %if.else.i
+// SPVCHECK:  if.else.i:                                        ; preds = %entry
+// SPVCHECK-NEXT:    [[MUL_6_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[ETA]], [[I]]
+// SPVCHECK-NEXT:    [[MUL_7_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[N]], [[I]]
+// SPVCHECK-NEXT:    [[MUL_8_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[MUL_7_I]], [[ETA]]
+// SPVCHECK-NEXT:    [[TMP2:%.*]] = tail call reassoc nnan ninf nsz arcp afn half @llvm.sqrt.f16(half [[SUB5_I]])
+// SPVCHECK-NEXT:    [[ADD_I:%.*]] = fadd reassoc nnan ninf nsz arcp afn half [[TMP2]], [[MUL_8_I]]
+// SPVCHECK-NEXT:    [[MUL_9_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[ADD_I]], [[N]]
+// SPVCHECK-NEXT:    [[SUB10_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn half [[MUL_6_I]], [[MUL_9_I]]
+// SPVCHECK-NEXT:    br label %_ZN4hlsl8__detail12refract_implIDhEET_S2_S2_S2_.exit
+// SPVCHECK:  _ZN4hlsl8__detail12refract_implIDhEET_S2_S2_S2_.exit: ; preds = %entry, %if.else.i
+// SPVCHECK-NEXT:    [[RETVAL_0_I:%.*]] = phi nsz half [ [[SUB10_I]], %if.else.i ], [ 0xH0000, %entry ]
+// SPVCHECK-NEXT:    ret half [[RETVAL_0_I]]
+//
+half test_refract_half(half I, half N, half ETA) {
+    return refract(I, N, ETA);
+}
+
+// CHECK-LABEL: define noundef nofpclass(nan inf) <2 x half> @_Z18test_refract_half2Dv2_DhS_S_
+// CHECK-SAME: <2 x half> noundef nofpclass(nan inf) [[I:%.*]], <2 x half> noundef nofpclass(nan inf) [[N:%.*]], <2 x half> noundef nofpclass(nan inf) [[ETA:%.*]]) local_unnamed_addr #[[ATTR0:[0-9]+]] {
+// CHECK-NEXT:  [[ENTRY:.*:]]
+// CHECK-NEXT:    [[CAST_VTRUNC:%.*]] = extractelement <2 x half> [[ETA]], i64 0
+// CHECK-NEXT:    [[MUL_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[CAST_VTRUNC]], [[CAST_VTRUNC]]
+// CHECK-NEXT:    [[HLSL_DOT_I:%.*]] = tail call reassoc nnan ninf nsz arcp afn half @llvm.dx.fdot.v2f16(<2 x half> [[N]], <2 x half> [[I]])
+// CHECK-NEXT:    [[MUL_2_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[HLSL_DOT_I]], [[HLSL_DOT_I]]
+// CHECK-NEXT:    [[SUB_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn half 0xH3C00, [[MUL_2_I]]
+// CHECK-NEXT:    [[MUL_3_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[MUL_I]], [[SUB_I]]
+// CHECK-NEXT:    [[SUB4_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn half 0xH3C00, [[MUL_3_I]]
+// CHECK-NEXT:    [[CAST_VTRUNC_I:%.*]] = fcmp reassoc nnan ninf nsz arcp afn olt half [[SUB4_I]], 0xH0000
+// CHECK:    br i1 [[CAST_VTRUNC_I]], label %_ZN4hlsl8__detail16refract_vec_implIDhLi2EEEDvT0__T_S3_S3_S2_.exit, label %if.else.i
+// CHECK:   if.else.i:                                        ; preds = %entry
+// CHECK-NEXT:    [[SPLAT_SPLATINSERT_I:%.*]] = insertelement <2 x half> poison, half [[SUB4_I]], i64 0
+// CHECK-NEXT:    [[SPLAT_SPLAT_I:%.*]] = shufflevector <2 x half> [[SPLAT_SPLATINSERT_I]], <2 x half> poison, <2 x i32> zeroinitializer
+// CHECK-NEXT:    [[SPLAT_SPLAT6_I:%.*]] = shufflevector <2 x half> [[ETA]], <2 x half> poison, <2 x i32> zeroinitializer
+// CHECK-NEXT:    [[MUL_7_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn <2 x half> [[SPLAT_SPLAT6_I]], [[I]]
+// CHECK-NEXT:    [[MUL_9_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[HLSL_DOT_I]], [[CAST_VTRUNC]]
+// CHECK-NEXT:    [[SPLAT_SPLATINSERT10_I:%.*]] = insertelement <2 x half> poison, half [[MUL_9_I]], i64 0
+// CHECK-NEXT:    [[SPLAT_SPLAT11_I:%.*]] = shufflevector <2 x half> [[SPLAT_SPLATINSERT10_I]], <2 x half> poison, <2 x i32> zeroinitializer
+// CHECK-NEXT:    [[TMP0:%.*]] = tail call reassoc nnan ninf nsz arcp afn <2 x half> @llvm.sqrt.v2f16(<2 x half> [[SPLAT_SPLAT_I]])
+// CHECK-NEXT:    [[ADD_I:%.*]] = fadd reassoc nnan ninf nsz arcp afn <2 x half> [[TMP0]], [[SPLAT_SPLAT11_I]]
+// CHECK-NEXT:    [[MUL_12_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn <2 x half> [[ADD_I]], [[N]]
+// CHECK-NEXT:    [[SUB13_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn <2 x half> [[MUL_7_I]], [[MUL_12_I]]
+// CHECK-NEXT:    br label %_ZN4hlsl8__detail16refract_vec_implIDhLi2EEEDvT0__T_S3_S3_S2_.exit
+// CHECK:  _ZN4hlsl8__detail16refract_vec_implIDhLi2EEEDvT0__T_S3_S3_S2_.exit: ; preds = %entry, %if.else.i
+// CHECK-NEXT:    [[RETVAL_0_I:%.*]] = phi nsz <2 x half> [ %sub13.i, %if.else.i ], [ zeroinitializer, %entry ]
+// CHECK-NEXT:    ret <2 x half> [[RETVAL_0_I]]
+
+//
+// SPVCHECK-LABEL: define spir_func noundef nofpclass(nan inf) <2 x half> @_Z18test_refract_half2Dv2_DhS_S_(
+// SPVCHECK-SAME: <2 x half> noundef nofpclass(nan inf) [[I:%.*]], <2 x half> noundef nofpclass(nan inf) [[N:%.*]], <2 x half> noundef nofpclass(nan inf) [[ETA:%.*]]) local_unnamed_addr #[[ATTR0:[0-9]+]] {
+// SPVCHECK-NEXT:  [[ENTRY:.*:]]
+// SPVCHECK-NEXT:    [[CAST_VTRUNC:%.*]] = extractelement <2 x half> [[ETA]], i64 0
+// SPVCHECK-NEXT:    [[CONV_I:%.*]] = fpext reassoc nnan ninf nsz arcp afn half [[CAST_VTRUNC]] to double
+// SPVCHECK-NEXT:    [[SPV_REFRACT_I:%.*]] = tail call reassoc nnan ninf nsz arcp afn noundef <2 x half> @llvm.spv.refract.v2f16.f64(<2 x half> [[I]], <2 x half> [[N]], double [[CONV_I]])
+// SPVCHECK-NEXT:    ret <2 x half> [[SPV_REFRACT_I]]
+//
+half2 test_refract_half2(half2 I, half2 N, half2 ETA) {
+    return refract(I, N, ETA);
+}
+
+// CHECK-LABEL: define noundef nofpclass(nan inf) <3 x half> @_Z18test_refract_half3Dv3_DhS_Dh(
+// CHECK-SAME: <3 x half> noundef nofpclass(nan inf) [[I:%.*]], <3 x half> noundef nofpclass(nan inf) [[N:%.*]], half noundef nofpclass(nan inf) [[ETA:%.*]]) local_unnamed_addr #[[ATTR0:[0-9]+]] {
+// CHECK-NEXT:  [[ENTRY:.*:]]
+// CHECK-NEXT:    [[MUL_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[ETA]], [[ETA]]
+// CHECK-NEXT:    [[HLSL_DOT_I:%.*]] = tail call reassoc nnan ninf nsz arcp afn half @llvm.dx.fdot.v3f16(<3 x half> [[N]], <3 x half> [[I]])
+// CHECK-NEXT:    [[MUL_2_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[HLSL_DOT_I]], [[HLSL_DOT_I]]
+// CHECK-NEXT:    [[SUB_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn half 0xH3C00, [[MUL_2_I]]
+// CHECK-NEXT:    [[MUL_3_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[MUL_I]], [[SUB_I]]
+// CHECK-NEXT:    [[SUB4_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn half 0xH3C00, [[MUL_3_I]]
+// CHECK-NEXT:    [[CAST_VTRUNC_I:%.*]] = fcmp reassoc nnan ninf nsz arcp afn olt half [[SUB4_I]], 0xH0000
+// CHECK:    br i1 [[CAST_VTRUNC_I]], label %_ZN4hlsl8__detail16refract_vec_implIDhLi3EEEDvT0__T_S3_S3_S2_.exit, label %if.else.i
+// CHECK:   if.else.i:                                        ; preds = %entry
+// CHECK-NEXT:    [[SPLAT_SPLATINSERT_I:%.*]] = insertelement <3 x half> poison, half [[SUB4_I]], i64 0
+// CHECK-NEXT:    [[SPLAT_SPLAT_I:%.*]] = shufflevector <3 x half> [[SPLAT_SPLATINSERT_I]], <3 x half> poison, <3 x i32> zeroinitializer
+// CHECK-NEXT:    [[SPLAT_SPLATINSERT5_I:%.*]] = insertelement <3 x half> poison, half [[ETA]], i64 0
+// CHECK-NEXT:    [[SPLAT_SPLAT6_I:%.*]] = shufflevector <3 x half> [[SPLAT_SPLATINSERT5_I]], <3 x half> poison, <3 x i32> zeroinitializer
+// CHECK-NEXT:    [[MUL_7_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn <3 x half> [[SPLAT_SPLAT6_I]], [[I]]
+// CHECK-NEXT:    [[MUL_9_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[HLSL_DOT_I]], [[ETA]]
+// CHECK-NEXT:    [[SPLAT_SPLATINSERT10_I:%.*]] = insertelement <3 x half> poison, half [[MUL_9_I]], i64 0
+// CHECK-NEXT:    [[SPLAT_SPLAT11_I:%.*]] = shufflevector <3 x half> [[SPLAT_SPLATINSERT10_I]], <3 x half> poison, <3 x i32> zeroinitializer
+// CHECK-NEXT:    [[TMP0:%.*]] = tail call reassoc nnan ninf nsz arcp afn <3 x half> @llvm.sqrt.v3f16(<3 x half> [[SPLAT_SPLAT_I]])
+// CHECK-NEXT:    [[ADD_I:%.*]] = fadd reassoc nnan ninf nsz arcp afn <3 x half> [[TMP0]], [[SPLAT_SPLAT11_I]]
+// CHECK-NEXT:    [[MUL_12_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn <3 x half> [[ADD_I]], [[N]]
+// CHECK-NEXT:   [[SUB13_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn <3 x half> [[MUL_7_I]], [[MUL_12_I]]
+// CHECK:    br label %_ZN4hlsl8__detail16refract_vec_implIDhLi3EEEDvT0__T_S3_S3_S2_.exit
+// CHECK:  _ZN4hlsl8__detail16refract_vec_implIDhLi3EEEDvT0__T_S3_S3_S2_.exit: ; preds = %entry, %if.else.i
+// CHECK-NEXT:    [[RETVAL_0_I:%.*]] = phi nsz <3 x half> [ [[SUB13_I]], %if.else.i ], [ zeroinitializer, %entry ]
+// CHECK-NEXT:    ret <3 x half> [[RETVAL_0_I]]
+//
+// SPVCHECK-LABEL: define spir_func noundef nofp...
[truncated]

@raoanag raoanag changed the title User/raoanag/refract [HLSL][DXIL] Implement refract intrinsic Apr 29, 2025
//
// SPVCHECK-LABEL: define spir_func noundef nofpclass(nan inf) half @_Z17test_refract_halfDhDhDh(
// SPVCHECK-SAME: half noundef nofpclass(nan inf) [[I:%.*]], half noundef nofpclass(nan inf) [[N:%.*]], half noundef nofpclass(nan inf) [[ETA:%.*]]) local_unnamed_addr #[[ATTR0:[0-9]+]]
// SPVCHECK-NEXT: [[ENTRY:.*:]]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the SPIRV and CHECK lines are the same we can combine these by adding a third check that covers the cases where the runlines are the same. There should be examples of this.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is https://github.com/llvm/llvm-project/blob/main/clang/test/CodeGenHLSL/builtins/any.hlsl a good example?
It just seems to not have SPVCHECK

@@ -71,6 +71,27 @@ constexpr vector<T, L> reflect_vec_impl(vector<T, L> I, vector<T, L> N) {
#endif
}

template <typename T> constexpr T refract_impl(T I, T N, T eta) {
T k = 1 - eta * eta * (1 - (N * I * N *I));
if(k < 0)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I feel like we can simplify these conditionals to use a select instead.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the latest commit uses select, but the hlsl tests have not been updated, will decide next steps based on offline discuss

Copy link

⚠️ C/C++ code formatter, clang-format found issues in your code. ⚠️

You can test this locally with the following command:
git-clang-format --diff HEAD~1 HEAD --extensions cpp,c,h -- clang/test/CodeGenSPIRV/Builtins/refract.c clang/test/SemaSPIRV/BuiltIns/refract-errors.c clang/lib/CodeGen/TargetBuiltins/SPIR.cpp clang/lib/Headers/hlsl/hlsl_intrinsic_helpers.h clang/lib/Headers/hlsl/hlsl_intrinsics.h clang/lib/Sema/SemaSPIRV.cpp llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
View the diff from clang-format here.
diff --git a/clang/lib/Headers/hlsl/hlsl_intrinsic_helpers.h b/clang/lib/Headers/hlsl/hlsl_intrinsic_helpers.h
index e9d376b7f..e076f4ded 100644
--- a/clang/lib/Headers/hlsl/hlsl_intrinsic_helpers.h
+++ b/clang/lib/Headers/hlsl/hlsl_intrinsic_helpers.h
@@ -72,8 +72,8 @@ constexpr vector<T, L> reflect_vec_impl(vector<T, L> I, vector<T, L> N) {
 }
 
 template <typename T> constexpr T refract_impl(T I, T N, T eta) {
-  T k = 1 - eta * eta * (1 - (N * I * N *I));
-  if(k < 0)
+  T k = 1 - eta * eta * (1 - (N * I * N * I));
+  if (k < 0)
     return 0;
   else
     return (eta * I - (eta * N * I + sqrt(k)) * N);
@@ -85,7 +85,7 @@ constexpr vector<T, L> refract_vec_impl(vector<T, L> I, vector<T, L> N, T eta) {
   return __builtin_spirv_refract(I, N, eta);
 #else
   vector<T, L> k = 1 - eta * eta * (1 - dot(N, I) * dot(N, I));
-  if(k < 0)
+  if (k < 0)
     return 0;
   else
     return (eta * I - (eta * dot(N, I) + sqrt(k)) * N);
diff --git a/clang/lib/Headers/hlsl/hlsl_intrinsics.h b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
index 54fae9f3b..ccfd0b75a 100644
--- a/clang/lib/Headers/hlsl/hlsl_intrinsics.h
+++ b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
@@ -445,8 +445,8 @@ reflect(__detail::HLSL_FIXED_VECTOR<float, L> I,
 /// \param eta The refraction index.
 ///
 /// The return value is a floating-point vector that represents the refraction
-/// using the refraction index, \a eta, for the direction of the entering ray, \a I,
-/// off a surface with the normal \a N.
+/// using the refraction index, \a eta, for the direction of the entering ray,
+/// \a I, off a surface with the normal \a N.
 ///
 /// This function calculates the refraction vector using the following formulas:
 /// k = 1.0 - eta * eta * (1.0 - dot(N, I) * dot(N, I))
@@ -473,7 +473,7 @@ const inline __detail::enable_if_t<__detail::is_arithmetic<T>::Value &&
 template <typename T>
 const inline __detail::enable_if_t<
     __detail::is_arithmetic<T>::Value && __detail::is_same<float, T>::value, T>
-    refract(T I, T N, T eta) {
+refract(T I, T N, T eta) {
   return __detail::refract_impl(I, N, eta);
 }
 
diff --git a/clang/lib/Sema/SemaSPIRV.cpp b/clang/lib/Sema/SemaSPIRV.cpp
index b10c6b767..0f8c8c660 100644
--- a/clang/lib/Sema/SemaSPIRV.cpp
+++ b/clang/lib/Sema/SemaSPIRV.cpp
@@ -117,8 +117,7 @@ bool SemaSPIRV::CheckSPIRVBuiltinFunctionCall(unsigned BuiltinID,
     ExprResult C = TheCall->getArg(2);
     QualType ArgTyC = C.get()->getType();
     if (!ArgTyC->hasFloatingRepresentation()) {
-      SemaRef.Diag(C.get()->getBeginLoc(),
-                   diag::err_builtin_invalid_arg_type)
+      SemaRef.Diag(C.get()->getBeginLoc(), diag::err_builtin_invalid_arg_type)
           << 3 << /* scalar or vector */ 5 << /* no int */ 0 << /* fp */ 1
           << ArgTyC;
       return true;

@farzonl
Copy link
Member

farzonl commented Apr 29, 2025

⚠️ C/C++ code formatter, clang-format found issues in your code. ⚠️

You can test this locally with the following command:
View the diff from clang-format here.

run git clang-format <git branch or hash before your commits>

Copy link
Member

@farzonl farzonl left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The tests need to be redone. Try not to use any automation tooling to generate your tests. its causing you to miss sublte but important implementation details.

// SPVCHECK-LABEL: define spir_func noundef nofpclass(nan inf) <3 x half> @_Z18test_refract_half3Dv3_DhS_Dh(
// SPVCHECK-SAME: <3 x half> noundef nofpclass(nan inf) [[I:%.*]], <3 x half> noundef nofpclass(nan inf) [[N:%.*]], half noundef nofpclass(nan inf) [[ETA:%.*]]) #[[ATTR0:[0-9]+]] {
// SPVCHECK: [[ENTRY:.*:]]
// SPVCHECK: [[CONV_I:%.*]] = fpext reassoc nnan ninf nsz arcp afn half %{{.*}} to double
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is wrong we should not be casting half to double.

@@ -26,6 +26,12 @@ def SPIRVReflect : Builtin {
let Prototype = "void(...)";
}

def SPIRVRefract : Builtin {
let Spellings = ["__builtin_spirv_refract"];
let Attributes = [NoThrow, Const];
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You need to add CustomTypeChecking.

Suggested change
let Attributes = [NoThrow, Const];
let Attributes = [NoThrow, Const, CustomTypeChecking];

Comment on lines +82 to +87
#if (__has_builtin(__builtin_spirv_refract))
return __builtin_spirv_refract(I, N, Eta);
#else
vector<T, L> K = 1 - Eta * Eta * (1 - dot(N, I) * dot(N, I));
vector<T, L> Result = (Eta * I - (Eta * dot(N, I) + sqrt(K)) * N);
return select<vector<T, L>>(K < 0, vector<T, L>(0), Result);
Copy link
Member

@farzonl farzonl May 22, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Our most recent pattern on these has been to do one function when possible instead of a scalar and vector implementation. I beleive this is a case where that might be possible. Can you experiment with the following?

In hlsl_detail.h add

template <typename T>
struct is_vector {
  static const bool value = false;
};

/*NOTE: (don't include this comment) what I am doing here is adding a specialization for vector<T, N>*/
template <typename T, int N>
struct is_vector<vector<T, N>> {
  static const bool value = true;
};

Then line 82-87 changes to

Suggested change
#if (__has_builtin(__builtin_spirv_refract))
return __builtin_spirv_refract(I, N, Eta);
#else
vector<T, L> K = 1 - Eta * Eta * (1 - dot(N, I) * dot(N, I));
vector<T, L> Result = (Eta * I - (Eta * dot(N, I) + sqrt(K)) * N);
return select<vector<T, L>>(K < 0, vector<T, L>(0), Result);
#if (__has_builtin(__builtin_spirv_refract) && is_vector<T>)
return __builtin_spirv_refract(I, N, Eta);
#else
T K = 1 - Eta * Eta * (1 - dot(N, I) * dot(N, I));
T Result = (Eta * I - (Eta * dot(N, I) + sqrt(K)) * N);
return select(K < 0, 0, Result);

@@ -16,88 +16,90 @@ namespace clang {

SemaSPIRV::SemaSPIRV(Sema &S) : SemaBase(S) {}

/// Checks if the first `NumArgsToCheck` arguments of a function call are of
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@spall did a bunch of work to get semantic checking to be more consistent. She should review at least this file for this PR.

/// \return `true` if any of the arguments is not a vector type, `false`
/// otherwise.

bool SemaSPIRV::CheckVectorArgs(CallExpr *TheCall, unsigned NumArgsToCheck) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I suspect most cases all the function arguments will be the same type. creae an overload of this function that does not take a NumArgsToCheck and instead just calls this function:

Suggested change
bool SemaSPIRV::CheckVectorArgs(CallExpr *TheCall, unsigned NumArgsToCheck) {
bool SemaSPIRV::CheckVectorArgs(CallExpr *TheCall) {
return CheckVectorArgs(theCall, TheCall->getNumArgs());
}
bool SemaSPIRV::CheckVectorArgs(CallExpr *TheCall, unsigned NumArgsToCheck) {

reflect and distance should use the one that doesn't require you to specify argument size.

@@ -1,4 +1,3 @@
// NOTE: Assertions have been autogenerated by utils/update_cc_test_checks.py UTC_ARGS: --version 5
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

undo this change.

#if (__has_builtin(__builtin_spirv_refract))
return __builtin_spirv_refract(I, N, Eta);
#else
vector<T, L> K = 1 - Eta * Eta * (1 - dot(N, I) * dot(N, I));
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we should really be storing dot(N, I) somewhere so that we aren't repeating the computation 3 times. That should simplify the codegen for the -O0 cases.


float2 test_no_second_arg(float2 p0) {
return __builtin_spirv_refract(p0);
// expected-error@-1 {{too few arguments to function call, expected 3, have 1}}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you only need one less than test

@@ -20,7 +20,7 @@ namespace clang {
class SemaSPIRV : public SemaBase {
public:
SemaSPIRV(Sema &S);

bool CheckVectorArgs(CallExpr *TheCall, unsigned NumArgsToCheck);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a helper, we don't need to expose it in the class. No one needs to use this outside of spirv and if they did it wouldn't live in the semaSPIRV file it would live somewhere like SemaChecking.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
backend:SPIR-V backend:X86 clang:codegen IR generation bugs: mangling, exceptions, etc. clang:frontend Language frontend issues, e.g. anything involving "Sema" clang:headers Headers provided by Clang, e.g. for intrinsics clang Clang issues not falling into any other category HLSL HLSL Language Support llvm:ir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Implement the refract HLSL Function
4 participants