Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[HLSL] Implement the smoothstep intrinsic #132288

Open
wants to merge 11 commits into
base: main
Choose a base branch
from
Open
6 changes: 6 additions & 0 deletions clang/include/clang/Basic/BuiltinsSPIRV.td
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,9 @@ def SPIRVReflect : Builtin {
let Attributes = [NoThrow, Const];
let Prototype = "void(...)";
}

def SPIRVSmoothStep : Builtin {
let Spellings = ["__builtin_spirv_smoothstep"];
let Attributes = [NoThrow, Const, CustomTypeChecking];
let Prototype = "void(...)";
}
13 changes: 13 additions & 0 deletions clang/lib/CodeGen/TargetBuiltins/SPIR.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,19 @@ Value *CodeGenFunction::EmitSPIRVBuiltinExpr(unsigned BuiltinID,
/*ReturnType=*/I->getType(), Intrinsic::spv_reflect,
ArrayRef<Value *>{I, N}, nullptr, "spv.reflect");
}
case SPIRV::BI__builtin_spirv_smoothstep: {
Value *Min = EmitScalarExpr(E->getArg(0));
Value *Max = EmitScalarExpr(E->getArg(1));
Value *X = EmitScalarExpr(E->getArg(2));
assert(E->getArg(0)->getType()->hasFloatingRepresentation() &&
E->getArg(1)->getType()->hasFloatingRepresentation() &&
E->getArg(2)->getType()->hasFloatingRepresentation() &&
"SmoothStep operands must have a float representation");
return Builder.CreateIntrinsic(
/*ReturnType=*/Min->getType(), Intrinsic::spv_smoothstep,
ArrayRef<Value *>{Min, Max, X}, /*FMFSource=*/nullptr,
"spv.smoothstep");
}
}
return nullptr;
}
20 changes: 20 additions & 0 deletions clang/lib/Headers/hlsl/hlsl_intrinsic_helpers.h
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,26 @@ constexpr vector<T, N> fmod_vec_impl(vector<T, N> X, vector<T, N> Y) {
#endif
}

template <typename T> constexpr T smoothstep_impl(T Min, T Max, T X) {
#if (__has_builtin(__builtin_spirv_smoothstep))
return __builtin_spirv_smoothstep(Min, Max, X);
#else
T S = saturate((X - Min) / (Max - Min));
return (3 - 2 * S) * S * S;
#endif
}

template <typename T, int N>
constexpr vector<T, N> smoothstep_vec_impl(vector<T, N> Min, vector<T, N> Max,
vector<T, N> X) {
#if (__has_builtin(__builtin_spirv_smoothstep))
return __builtin_spirv_smoothstep(Min, Max, X);
#else
vector<T, N> S = saturate((X - Min) / (Max - Min));
return (3 - 2 * S) * S * S;
#endif
}

} // namespace __detail
} // namespace hlsl

Expand Down
48 changes: 48 additions & 0 deletions clang/lib/Headers/hlsl/hlsl_intrinsics.h
Original file line number Diff line number Diff line change
Expand Up @@ -322,5 +322,53 @@ reflect(__detail::HLSL_FIXED_VECTOR<float, L> I,
__detail::HLSL_FIXED_VECTOR<float, L> N) {
return __detail::reflect_vec_impl(I, N);
}

//===----------------------------------------------------------------------===//
// smoothstep builtin
//===----------------------------------------------------------------------===//

/// \fn T smoothstep(T Min, T Max, T X)
/// \brief Returns a smooth Hermite interpolation between 0 and 1, if \a X is in
/// the range [\a Min, \a Max].
/// \param Min The minimum range of the x parameter.
/// \param Max The maximum range of the x parameter.
/// \param X The specified value to be interpolated.
///
/// The return value is 0.0 if \a X ≤ \a Min and 1.0 if \a X ≥ \a Max. When \a
/// Min < \a X < \a Max, the function performs smooth Hermite interpolation
/// between 0 and 1.

template <typename T>
_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
const inline __detail::enable_if_t<__detail::is_arithmetic<T>::Value &&
__detail::is_same<half, T>::value,
T> smoothstep(T Min, T Max, T X) {
return __detail::smoothstep_impl(Min, Max, X);
}

template <typename T>
const inline __detail::enable_if_t<
__detail::is_arithmetic<T>::Value && __detail::is_same<float, T>::value, T>
smoothstep(T Min, T Max, T X) {
return __detail::smoothstep_impl(Min, Max, X);
}

template <int N>
_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
const inline __detail::HLSL_FIXED_VECTOR<half, N> smoothstep(
__detail::HLSL_FIXED_VECTOR<half, N> Min,
__detail::HLSL_FIXED_VECTOR<half, N> Max,
__detail::HLSL_FIXED_VECTOR<half, N> X) {
return __detail::smoothstep_vec_impl(Min, Max, X);
}

template <int N>
const inline __detail::HLSL_FIXED_VECTOR<float, N>
smoothstep(__detail::HLSL_FIXED_VECTOR<float, N> Min,
__detail::HLSL_FIXED_VECTOR<float, N> Max,
__detail::HLSL_FIXED_VECTOR<float, N> X) {
return __detail::smoothstep_vec_impl(Min, Max, X);
}

} // namespace hlsl
#endif //_HLSL_HLSL_INTRINSICS_H_
36 changes: 36 additions & 0 deletions clang/lib/Sema/SemaSPIRV.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,42 @@ bool SemaSPIRV::CheckSPIRVBuiltinFunctionCall(unsigned BuiltinID,
TheCall->setType(RetTy);
break;
}
case SPIRV::BI__builtin_spirv_smoothstep: {
if (SemaRef.checkArgCount(TheCall, 3))
return true;

// check if the all arguments have floating representation
for (unsigned i = 0; i < TheCall->getNumArgs(); ++i) {
ExprResult Arg = TheCall->getArg(i);
QualType ArgTy = Arg.get()->getType();
if (!ArgTy->hasFloatingRepresentation()) {
SemaRef.Diag(Arg.get()->getBeginLoc(),
diag::err_builtin_invalid_arg_type)
<< i + 1 << /* scalar or vector */ 5 << /* no int */ 0 << /* fp */ 1
<< ArgTy;
return true;
}
}

// check if all arguments are of the same type
ExprResult A = TheCall->getArg(0);
ExprResult B = TheCall->getArg(1);
ExprResult C = TheCall->getArg(2);
if (!(SemaRef.getASTContext().hasSameUnqualifiedType(A.get()->getType(),
B.get()->getType()) &&
SemaRef.getASTContext().hasSameUnqualifiedType(A.get()->getType(),
C.get()->getType()))) {
SemaRef.Diag(TheCall->getBeginLoc(),
diag::err_vec_builtin_incompatible_vector)
<< TheCall->getDirectCallee() << /*useAllTerminology*/ true
<< SourceRange(A.get()->getBeginLoc(), C.get()->getEndLoc());
return true;
}

QualType RetTy = A.get()->getType();
TheCall->setType(RetTy);
break;
}
}
return false;
}
Expand Down
Loading
Loading