Skip to content

[HLSL] Allow truncation to scalar #104844

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Sep 11, 2024
Merged

Conversation

llvm-beanz
Copy link
Collaborator

HLSL allows implicit conversions to truncate vectors to scalar pr-values. These conversions are scored as vector truncations and should warn appropriately.

This change allows forming a truncation cast to a pr-value, but not an l-value. Truncating a vector to a scalar is performed by loading the first element of the vector and disregarding the remaining elements.

Fixes #102964

HLSL allows implicit conversions to truncate vectors to scalar
pr-values. These conversions are scored as vector truncations and
should warn appropriately.

This change allows forming a truncation cast to a pr-value, but not an
l-value. Truncating a vector to a scalar is performed by loading the
first element of the vector and disregarding the remaining elements.

Fixes llvm#102964
@llvmbot llvmbot added clang Clang issues not falling into any other category backend:X86 clang:frontend Language frontend issues, e.g. anything involving "Sema" clang:headers Headers provided by Clang, e.g. for intrinsics clang:codegen IR generation bugs: mangling, exceptions, etc. HLSL HLSL Language Support labels Aug 19, 2024
@llvmbot
Copy link
Member

llvmbot commented Aug 19, 2024

@llvm/pr-subscribers-hlsl

@llvm/pr-subscribers-clang-codegen

Author: Chris B (llvm-beanz)

Changes

HLSL allows implicit conversions to truncate vectors to scalar pr-values. These conversions are scored as vector truncations and should warn appropriately.

This change allows forming a truncation cast to a pr-value, but not an l-value. Truncating a vector to a scalar is performed by loading the first element of the vector and disregarding the remaining elements.

Fixes #102964


Patch is 22.47 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/104844.diff

10 Files Affected:

  • (modified) clang/lib/CodeGen/CGExprScalar.cpp (+11-6)
  • (modified) clang/lib/Headers/hlsl/hlsl_intrinsics.h (+1-4)
  • (modified) clang/lib/Sema/SemaExprCXX.cpp (+28-21)
  • (modified) clang/lib/Sema/SemaOverload.cpp (+33-17)
  • (modified) clang/test/CodeGenHLSL/BasicFeatures/standard_conversion_sequences.hlsl (+24)
  • (modified) clang/test/CodeGenHLSL/builtins/dot.hlsl (-16)
  • (modified) clang/test/CodeGenHLSL/builtins/lerp.hlsl (-21)
  • (modified) clang/test/CodeGenHLSL/builtins/mad.hlsl (-18)
  • (modified) clang/test/SemaHLSL/TruncationOverloadResolution.hlsl (+39-3)
  • (modified) clang/test/SemaHLSL/Types/BuiltinVector/ScalarSwizzleErrors.hlsl (+5-1)
diff --git a/clang/lib/CodeGen/CGExprScalar.cpp b/clang/lib/CodeGen/CGExprScalar.cpp
index 84392745ea6144..2ccdb840579a98 100644
--- a/clang/lib/CodeGen/CGExprScalar.cpp
+++ b/clang/lib/CodeGen/CGExprScalar.cpp
@@ -2692,14 +2692,19 @@ Value *ScalarExprEmitter::VisitCastExpr(CastExpr *CE) {
     return CGF.CGM.createOpenCLIntToSamplerConversion(E, CGF);
 
   case CK_HLSLVectorTruncation: {
-    assert(DestTy->isVectorType() && "Expected dest type to be vector type");
+    assert((DestTy->isVectorType() || DestTy->isBuiltinType()) &&
+           "Destination type must be a vector or builtin type.");
     Value *Vec = Visit(const_cast<Expr *>(E));
-    SmallVector<int, 16> Mask;
-    unsigned NumElts = DestTy->castAs<VectorType>()->getNumElements();
-    for (unsigned I = 0; I != NumElts; ++I)
-      Mask.push_back(I);
+    if (auto *VecTy = DestTy->getAs<VectorType>()) {
+      SmallVector<int, 16> Mask;
+      unsigned NumElts = VecTy->getNumElements();
+      for (unsigned I = 0; I != NumElts; ++I)
+        Mask.push_back(I);
 
-    return Builder.CreateShuffleVector(Vec, Mask, "trunc");
+      return Builder.CreateShuffleVector(Vec, Mask, "trunc");
+    }
+    llvm::Value *Zero = llvm::Constant::getNullValue(CGF.SizeTy);
+    return Builder.CreateExtractElement(Vec, Zero, "cast.vtrunc");
   }
 
   } // end of switch
diff --git a/clang/lib/Headers/hlsl/hlsl_intrinsics.h b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
index 678cdc77f8a71b..be65e149515b9f 100644
--- a/clang/lib/Headers/hlsl/hlsl_intrinsics.h
+++ b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
@@ -673,9 +673,6 @@ float dot(float3, float3);
 _HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
 float dot(float4, float4);
 
-_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
-double dot(double, double);
-
 _HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
 int dot(int, int);
 _HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
@@ -916,7 +913,7 @@ float4 lerp(float4, float4, float4);
 /// \brief Returns the length of the specified floating-point vector.
 /// \param x [in] The vector of floats, or a scalar float.
 ///
-/// Length is based on the following formula: sqrt(x[0]^2 + x[1]^2 + �).
+/// Length is based on the following formula: sqrt(x[0]^2 + x[1]^2 + ...).
 
 _HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
 _HLSL_BUILTIN_ALIAS(__builtin_hlsl_length)
diff --git a/clang/lib/Sema/SemaExprCXX.cpp b/clang/lib/Sema/SemaExprCXX.cpp
index 5356bcf172f752..e62e3bcbead3be 100644
--- a/clang/lib/Sema/SemaExprCXX.cpp
+++ b/clang/lib/Sema/SemaExprCXX.cpp
@@ -4301,8 +4301,10 @@ Sema::PerformImplicitConversion(Expr *From, QualType ToType,
 // from type to the elements of the to type without resizing the vector.
 static QualType adjustVectorType(ASTContext &Context, QualType FromTy,
                                  QualType ToType, QualType *ElTy = nullptr) {
-  auto *ToVec = ToType->castAs<VectorType>();
-  QualType ElType = ToVec->getElementType();
+  QualType ElType = ToType;
+  if (auto *ToVec = ToType->getAs<VectorType>())
+    ElType = ToVec->getElementType();
+
   if (ElTy)
     *ElTy = ElType;
   if (!FromTy->isVectorType())
@@ -4463,7 +4465,7 @@ Sema::PerformImplicitConversion(Expr *From, QualType ToType,
   case ICK_Integral_Conversion: {
     QualType ElTy = ToType;
     QualType StepTy = ToType;
-    if (ToType->isVectorType())
+    if (FromType->isVectorType() || ToType->isVectorType())
       StepTy = adjustVectorType(Context, FromType, ToType, &ElTy);
     if (ElTy->isBooleanType()) {
       assert(FromType->castAs<EnumType>()->getDecl()->isFixed() &&
@@ -4483,7 +4485,7 @@ Sema::PerformImplicitConversion(Expr *From, QualType ToType,
   case ICK_Floating_Promotion:
   case ICK_Floating_Conversion: {
     QualType StepTy = ToType;
-    if (ToType->isVectorType())
+    if (FromType->isVectorType() || ToType->isVectorType())
       StepTy = adjustVectorType(Context, FromType, ToType);
     From = ImpCastExprToType(From, StepTy, CK_FloatingCast, VK_PRValue,
                              /*BasePath=*/nullptr, CCK)
@@ -4515,7 +4517,7 @@ Sema::PerformImplicitConversion(Expr *From, QualType ToType,
   case ICK_Floating_Integral: {
     QualType ElTy = ToType;
     QualType StepTy = ToType;
-    if (ToType->isVectorType())
+    if (FromType->isVectorType() || ToType->isVectorType())
       StepTy = adjustVectorType(Context, FromType, ToType, &ElTy);
     if (ElTy->isRealFloatingType())
       From = ImpCastExprToType(From, StepTy, CK_IntegralToFloating, VK_PRValue,
@@ -4656,11 +4658,11 @@ Sema::PerformImplicitConversion(Expr *From, QualType ToType,
     }
     QualType ElTy = FromType;
     QualType StepTy = ToType;
-    if (FromType->isVectorType()) {
-      if (getLangOpts().HLSL)
-        StepTy = adjustVectorType(Context, FromType, ToType);
+    if (FromType->isVectorType())
       ElTy = FromType->castAs<VectorType>()->getElementType();
-    }
+    if (getLangOpts().HLSL &&
+        (FromType->isVectorType() || ToType->isVectorType()))
+      StepTy = adjustVectorType(Context, FromType, ToType);
 
     From = ImpCastExprToType(From, StepTy, ScalarTypeToBooleanCastKind(ElTy),
                              VK_PRValue,
@@ -4815,8 +4817,8 @@ Sema::PerformImplicitConversion(Expr *From, QualType ToType,
     // TODO: Support HLSL matrices.
     assert((!From->getType()->isMatrixType() && !ToType->isMatrixType()) &&
            "Dimension conversion for matrix types is not implemented yet.");
-    assert(ToType->isVectorType() &&
-           "Dimension conversion is only supported for vector types.");
+    assert((ToType->isVectorType() || ToType->isBuiltinType()) &&
+           "Dimension conversion output must be vector or scalar type.");
     switch (SCS.Dimension) {
     case ICK_HLSL_Vector_Splat: {
       // Vector splat from any arithmetic type to a vector.
@@ -4828,18 +4830,23 @@ Sema::PerformImplicitConversion(Expr *From, QualType ToType,
     }
     case ICK_HLSL_Vector_Truncation: {
       // Note: HLSL built-in vectors are ExtVectors. Since this truncates a
-      // vector to a smaller vector, this can only operate on arguments where
-      // the source and destination types are ExtVectors.
-      assert(From->getType()->isExtVectorType() && ToType->isExtVectorType() &&
-             "HLSL vector truncation should only apply to ExtVectors");
+      // vector to a smaller vector or to a scalar, this can only operate on
+      // arguments where the source type is an ExtVector and the destination
+      // type is destination type is either an ExtVectorType or a builtin scalar
+      // type.
       auto *FromVec = From->getType()->castAs<VectorType>();
-      auto *ToVec = ToType->castAs<VectorType>();
       QualType ElType = FromVec->getElementType();
-      QualType TruncTy =
-          Context.getExtVectorType(ElType, ToVec->getNumElements());
-      From = ImpCastExprToType(From, TruncTy, CK_HLSLVectorTruncation,
-                               From->getValueKind())
-                 .get();
+      if (auto *ToVec = ToType->getAs<VectorType>()) {
+        QualType TruncTy =
+            Context.getExtVectorType(ElType, ToVec->getNumElements());
+        From = ImpCastExprToType(From, TruncTy, CK_HLSLVectorTruncation,
+                                 From->getValueKind())
+                   .get();
+      } else {
+        From = ImpCastExprToType(From, ElType, CK_HLSLVectorTruncation,
+                                 From->getValueKind())
+                   .get();
+      }
       break;
     }
     case ICK_Identity:
diff --git a/clang/lib/Sema/SemaOverload.cpp b/clang/lib/Sema/SemaOverload.cpp
index 52f640eb96b73b..f6a097305564a0 100644
--- a/clang/lib/Sema/SemaOverload.cpp
+++ b/clang/lib/Sema/SemaOverload.cpp
@@ -2032,26 +2032,42 @@ static bool IsVectorConversion(Sema &S, QualType FromType, QualType ToType,
   if (S.Context.hasSameUnqualifiedType(FromType, ToType))
     return false;
 
+  // HLSL allows implicit truncation of vector types.
+  if (S.getLangOpts().HLSL) {
+    auto *ToExtType = ToType->getAs<ExtVectorType>();
+    auto *FromExtType = FromType->getAs<ExtVectorType>();
+
+    // If both arguments are vectors, handle possible vector truncation and
+    // element conversion.
+    if (ToExtType && FromExtType) {
+      unsigned FromElts = FromExtType->getNumElements();
+      unsigned ToElts = ToExtType->getNumElements();
+      if (FromElts < ToElts)
+        return false;
+      if (FromElts == ToElts)
+        ElConv = ICK_Identity;
+      else
+        ElConv = ICK_HLSL_Vector_Truncation;
+
+      QualType FromElTy = FromExtType->getElementType();
+      QualType ToElTy = ToExtType->getElementType();
+      if (S.Context.hasSameUnqualifiedType(FromElTy, ToElTy))
+        return true;
+      return IsVectorElementConversion(S, FromElTy, ToElTy, ICK, From);
+    }
+    if (FromExtType && nullptr == ToExtType) {
+      ElConv = ICK_HLSL_Vector_Truncation;
+      QualType FromElTy = FromExtType->getElementType();
+      if (S.Context.hasSameUnqualifiedType(FromElTy, ToType))
+        return true;
+      return IsVectorElementConversion(S, FromElTy, ToType, ICK, From);
+    }
+    // Fallthrough for the case where ToType is a vector and FromType is not.
+  }
+
   // There are no conversions between extended vector types, only identity.
   if (auto *ToExtType = ToType->getAs<ExtVectorType>()) {
     if (auto *FromExtType = FromType->getAs<ExtVectorType>()) {
-      // HLSL allows implicit truncation of vector types.
-      if (S.getLangOpts().HLSL) {
-        unsigned FromElts = FromExtType->getNumElements();
-        unsigned ToElts = ToExtType->getNumElements();
-        if (FromElts < ToElts)
-          return false;
-        if (FromElts == ToElts)
-          ElConv = ICK_Identity;
-        else
-          ElConv = ICK_HLSL_Vector_Truncation;
-
-        QualType FromElTy = FromExtType->getElementType();
-        QualType ToElTy = ToExtType->getElementType();
-        if (S.Context.hasSameUnqualifiedType(FromElTy, ToElTy))
-          return true;
-        return IsVectorElementConversion(S, FromElTy, ToElTy, ICK, From);
-      }
       // There are no conversions between extended vector types other than the
       // identity conversion.
       return false;
diff --git a/clang/test/CodeGenHLSL/BasicFeatures/standard_conversion_sequences.hlsl b/clang/test/CodeGenHLSL/BasicFeatures/standard_conversion_sequences.hlsl
index 5d751be6dae066..6478ea67e32a0d 100644
--- a/clang/test/CodeGenHLSL/BasicFeatures/standard_conversion_sequences.hlsl
+++ b/clang/test/CodeGenHLSL/BasicFeatures/standard_conversion_sequences.hlsl
@@ -117,3 +117,27 @@ void d4_to_b2() {
   vector<double,4> d4 = 9.0;
   vector<bool, 2> b2 = d4;
 }
+
+// CHECK-LABEL: d4_to_d1
+// CHECK: [[d4:%.*]] = alloca <4 x double>
+// CHECK: [[d1:%.*]] = alloca <1 x double>
+// CHECK: store <4 x double> <double 9.000000e+00, double 9.000000e+00, double 9.000000e+00, double 9.000000e+00>, ptr [[d4]]
+// CHECK: [[vecd4:%.*]] = load <4 x double>, ptr [[d4]]
+// CHECK: [[vecd1:%.*]] = shufflevector <4 x double> [[vecd4]], <4 x double> poison, <1 x i32> zeroinitializer
+// CHECK: store <1 x double> [[vecd1]], ptr [[d1:%.*]], align 8
+void d4_to_d1() {
+  vector<double,4> d4 = 9.0;
+  vector<double,1> d1 = d4;
+}
+
+// CHECK-LABEL: d4_to_dScalar
+// CHECK: [[d4:%.*]] = alloca <4 x double>
+// CHECK: [[d:%.*]] = alloca double
+// CHECK: store <4 x double> <double 9.000000e+00, double 9.000000e+00, double 9.000000e+00, double 9.000000e+00>, ptr [[d4]]
+// CHECK: [[vecd4:%.*]] = load <4 x double>, ptr [[d4]]
+// CHECK: [[d4x:%.*]] = extractelement <4 x double> [[vecd4]], i32 0
+// CHECK: store double [[d4x]], ptr [[d]]
+void d4_to_dScalar() {
+  vector<double,4> d4 = 9.0;
+  double d = d4;
+}
diff --git a/clang/test/CodeGenHLSL/builtins/dot.hlsl b/clang/test/CodeGenHLSL/builtins/dot.hlsl
index ae6e45c3f9482a..51ce5a88b302b9 100644
--- a/clang/test/CodeGenHLSL/builtins/dot.hlsl
+++ b/clang/test/CodeGenHLSL/builtins/dot.hlsl
@@ -143,19 +143,3 @@ float test_dot_float3(float3 p0, float3 p1) { return dot(p0, p1); }
 // CHECK: %dx.dot = call float @llvm.dx.dot4.v4f32(<4 x float> %0, <4 x float> %1)
 // CHECK: ret float %dx.dot
 float test_dot_float4(float4 p0, float4 p1) { return dot(p0, p1); }
-
-// CHECK:  %dx.dot = call float @llvm.dx.dot2.v2f32(<2 x float> %splat.splat, <2 x float> %1)
-// CHECK: ret float %dx.dot
-float test_dot_float2_splat(float p0, float2 p1) { return dot(p0, p1); }
-
-// CHECK:  %dx.dot = call float @llvm.dx.dot3.v3f32(<3 x float> %splat.splat, <3 x float> %1)
-// CHECK: ret float %dx.dot
-float test_dot_float3_splat(float p0, float3 p1) { return dot(p0, p1); }
-
-// CHECK:  %dx.dot = call float @llvm.dx.dot4.v4f32(<4 x float> %splat.splat, <4 x float> %1)
-// CHECK: ret float %dx.dot
-float test_dot_float4_splat(float p0, float4 p1) { return dot(p0, p1); }
-
-// CHECK: %dx.dot = fmul double %0, %1
-// CHECK: ret double %dx.dot
-double test_dot_double(double p0, double p1) { return dot(p0, p1); }
diff --git a/clang/test/CodeGenHLSL/builtins/lerp.hlsl b/clang/test/CodeGenHLSL/builtins/lerp.hlsl
index 53ac24dd456930..4c0cbe8a671a9b 100644
--- a/clang/test/CodeGenHLSL/builtins/lerp.hlsl
+++ b/clang/test/CodeGenHLSL/builtins/lerp.hlsl
@@ -65,24 +65,3 @@ float3 test_lerp_float3(float3 p0) { return lerp(p0, p0, p0); }
 // SPIR_CHECK: %hlsl.lerp = call <4 x float> @llvm.spv.lerp.v4f32(<4 x float> %{{.*}}, <4 x float> %{{.*}}, <4 x float> %{{.*}})
 // CHECK: ret <4 x float> %hlsl.lerp
 float4 test_lerp_float4(float4 p0) { return lerp(p0, p0, p0); }
-
-// CHECK: %[[b:.*]] = load <2 x float>, ptr %p1.addr, align 8
-// CHECK: %[[c:.*]] = load <2 x float>, ptr %p1.addr, align 8
-// DXIL_CHECK: %hlsl.lerp = call <2 x float> @llvm.dx.lerp.v2f32(<2 x float> %splat.splat, <2 x float> %[[b]], <2 x float> %[[c]])
-// SPIR_CHECK: %hlsl.lerp = call <2 x float> @llvm.spv.lerp.v2f32(<2 x float> %splat.splat, <2 x float> %[[b]], <2 x float> %[[c]])
-// CHECK: ret <2 x float> %hlsl.lerp
-float2 test_lerp_float2_splat(float p0, float2 p1) { return lerp(p0, p1, p1); }
-
-// CHECK: %[[b:.*]] = load <3 x float>, ptr %p1.addr, align 16
-// CHECK: %[[c:.*]] = load <3 x float>, ptr %p1.addr, align 16
-// DXIL_CHECK: %hlsl.lerp = call <3 x float> @llvm.dx.lerp.v3f32(<3 x float> %splat.splat, <3 x float> %[[b]], <3 x float> %[[c]])
-// SPIR_CHECK: %hlsl.lerp = call <3 x float> @llvm.spv.lerp.v3f32(<3 x float> %splat.splat, <3 x float> %[[b]], <3 x float> %[[c]])
-// CHECK: ret <3 x float> %hlsl.lerp
-float3 test_lerp_float3_splat(float p0, float3 p1) { return lerp(p0, p1, p1); }
-
-// CHECK: %[[b:.*]] = load <4 x float>, ptr %p1.addr, align 16
-// CHECK: %[[c:.*]] = load <4 x float>, ptr %p1.addr, align 16
-// DXIL_CHECK: %hlsl.lerp = call <4 x float> @llvm.dx.lerp.v4f32(<4 x float> %splat.splat, <4 x float> %[[b]], <4 x float> %[[c]])
-// SPIR_CHECK: %hlsl.lerp = call <4 x float> @llvm.spv.lerp.v4f32(<4 x float> %splat.splat, <4 x float> %[[b]], <4 x float> %[[c]])
-// CHECK:  ret <4 x float> %hlsl.lerp
-float4 test_lerp_float4_splat(float p0, float4 p1) { return lerp(p0, p1, p1); }
diff --git a/clang/test/CodeGenHLSL/builtins/mad.hlsl b/clang/test/CodeGenHLSL/builtins/mad.hlsl
index 449a793caf93b7..265a2552c80fb4 100644
--- a/clang/test/CodeGenHLSL/builtins/mad.hlsl
+++ b/clang/test/CodeGenHLSL/builtins/mad.hlsl
@@ -263,21 +263,3 @@ uint64_t3 test_mad_uint64_t3(uint64_t3 p0, uint64_t3 p1, uint64_t3 p2) { return
 // SPIR_CHECK: mul nuw <4 x i64>  %{{.*}}, %{{.*}}
 // SPIR_CHECK: add nuw <4 x i64>  %{{.*}}, %{{.*}}
 uint64_t4 test_mad_uint64_t4(uint64_t4 p0, uint64_t4 p1, uint64_t4 p2) { return mad(p0, p1, p2); }
-
-// CHECK: %[[p1:.*]] = load <2 x float>, ptr %p1.addr, align 8
-// CHECK: %[[p2:.*]] = load <2 x float>, ptr %p2.addr, align 8
-// CHECK: %hlsl.fmad = call <2 x float>  @llvm.fmuladd.v2f32(<2 x float> %splat.splat, <2 x float> %[[p1]], <2 x float> %[[p2]])
-// CHECK: ret <2 x float> %hlsl.fmad
-float2 test_mad_float2_splat(float p0, float2 p1, float2 p2) { return mad(p0, p1, p2); }
-
-// CHECK: %[[p1:.*]] = load <3 x float>, ptr %p1.addr, align 16
-// CHECK: %[[p2:.*]] = load <3 x float>, ptr %p2.addr, align 16
-// CHECK: %hlsl.fmad = call <3 x float>  @llvm.fmuladd.v3f32(<3 x float> %splat.splat, <3 x float> %[[p1]], <3 x float> %[[p2]])
-// CHECK: ret <3 x float> %hlsl.fmad
-float3 test_mad_float3_splat(float p0, float3 p1, float3 p2) { return mad(p0, p1, p2); }
-
-// CHECK: %[[p1:.*]] = load <4 x float>, ptr %p1.addr, align 16
-// CHECK: %[[p2:.*]] = load <4 x float>, ptr %p2.addr, align 16
-// CHECK:  %hlsl.fmad = call <4 x float>  @llvm.fmuladd.v4f32(<4 x float> %splat.splat, <4 x float> %[[p1]], <4 x float> %[[p2]])
-// CHECK:  ret <4 x float> %hlsl.fmad
-float4 test_mad_float4_splat(float p0, float4 p1, float4 p2) { return mad(p0, p1, p2); }
diff --git a/clang/test/SemaHLSL/TruncationOverloadResolution.hlsl b/clang/test/SemaHLSL/TruncationOverloadResolution.hlsl
index f8cfe22372e885..cb00d2e226b7cf 100644
--- a/clang/test/SemaHLSL/TruncationOverloadResolution.hlsl
+++ b/clang/test/SemaHLSL/TruncationOverloadResolution.hlsl
@@ -1,4 +1,4 @@
-// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.3-library -fnative-half-type -finclude-default-header -fsyntax-only %s -DERROR=1 -verify
+// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.3-library -fnative-half-type -finclude-default-header -Wconversion -fsyntax-only %s -DERROR=1 -verify
 // RUN: %clang_cc1 -triple dxil-pc-shadermodel6.3-library -fnative-half-type -finclude-default-header -ast-dump %s | FileCheck %s
 
 // Case 1: Prefer exact-match truncation over conversion.
@@ -32,6 +32,42 @@ void Case2(float4 F) {
   Half2Double2(F); // expected-warning{{implicit conversion truncates vector: 'float4' (aka 'vector<float, 4>') to 'vector<double, 2>' (vector of 2 'double' values)}}
 }
 
+// Case 3: Allow truncation down to vector<T,1> or T.
+void Half(half H);
+void Float(float F);
+void Double(double D);
+
+void Half1(half1 H);
+void Float1(float1 F);
+void Double1(double1 D);
+
+void Case3(half3 H, float3 F, double3 D) {
+  Half(H); // expected-warning{{implicit conversion turns vector to scalar: 'half3' (aka 'vector<half, 3>') to 'half'}}
+  Half(F); // expected-warning{{implicit conversion turns vector to scalar: 'float3' (aka 'vector<float, 3>') to 'half'}}
+  Half(D); // expected-warning{{implicit conversion turns vector to scalar: 'double3' (aka 'vector<double, 3>') to 'half'}}
+
+  Float(H); // expected-warning{{implicit conversion turns vector to scalar: 'half3' (aka 'vector<half, 3>') to 'float'}}
+  Float(F); // expected-warning{{implicit conversion turns vector to scalar: 'float3' (aka 'vector<float, 3>') to 'float'}}
+  Float(D); // expected-warning{{implicit conversion turns vector to scalar: 'double3' (aka 'vector<double, 3>') to 'float'}}
+
+  Double(H); // expected-warning{{implicit conversion turns vector to scalar: 'half3' (aka 'vector<half, 3>') to 'double'}}
+  Double(F); // expected-warning{{implicit conversion turns vector to scalar: 'float3' (aka 'vector<float, 3>') to 'double'}}
+  Double(D); // expected-warning{{implicit conversion turns vector to scalar: 'double3' (aka 'vector<double, 3>') to 'double'}}
+
+  Half1(H); // expected-warning{{implicit conversion truncates vector: 'half3' (aka 'vector<half, 3>') to 'vector<half, 1>' (vector of 1 'half' value)}}
+  Half1(F); // expected-warning{{implicit conversion truncates vector: 'float3' (aka 'vector<float, 3>') to 'vector<half, 1>' (vector of 1 'half' value)}} expected-warning{{implicit conversion loses floating-point precision: 'float3' (aka 'vector<float, 3>') to 'vector<half, 1>' (vector of 1 'half' value)}}
+  Half1(D); // expected-warning{{implicit conversion truncates vector: 'double3' (aka 'vector<double, 3>') to 'vector<half, 1>' (vector of 1 'half' value)}} expected-warning{{implicit conversion loses floating-point precision: 'double3' (aka 'vector<double, 3>') to 'vector<half, 1>' (vector of 1 'half' value)}}
+
+  Float1(H); // expected-warning{{implicit conversion truncates vector: 'half3' (aka 'vector<half, 3>') to 'vector<float, 1>' (vector of 1 'float' value)}}
+  Float1(F); // expected-warning{{implicit conversion truncates vector: 'float3' (aka 'vector<float, 3>') to 'vector<float, 1>' (vector of 1 'float' value)}}
+  Float1(D); // expected-warning{{implicit conversion truncates vector: 'double3' (aka 'vector<double, 3>') to 'vector<float, 1>' (vector of 1 'float' value)}} expected-warning{{implicit conversion loses floating-point precision: 'double3' (aka 'vector<double, 3>') to...
[truncated]

@llvmbot
Copy link
Member

llvmbot commented Aug 19, 2024

@llvm/pr-subscribers-backend-x86

Author: Chris B (llvm-beanz)

Changes

HLSL allows implicit conversions to truncate vectors to scalar pr-values. These conversions are scored as vector truncations and should warn appropriately.

This change allows forming a truncation cast to a pr-value, but not an l-value. Truncating a vector to a scalar is performed by loading the first element of the vector and disregarding the remaining elements.

Fixes #102964


Patch is 22.47 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/104844.diff

10 Files Affected:

  • (modified) clang/lib/CodeGen/CGExprScalar.cpp (+11-6)
  • (modified) clang/lib/Headers/hlsl/hlsl_intrinsics.h (+1-4)
  • (modified) clang/lib/Sema/SemaExprCXX.cpp (+28-21)
  • (modified) clang/lib/Sema/SemaOverload.cpp (+33-17)
  • (modified) clang/test/CodeGenHLSL/BasicFeatures/standard_conversion_sequences.hlsl (+24)
  • (modified) clang/test/CodeGenHLSL/builtins/dot.hlsl (-16)
  • (modified) clang/test/CodeGenHLSL/builtins/lerp.hlsl (-21)
  • (modified) clang/test/CodeGenHLSL/builtins/mad.hlsl (-18)
  • (modified) clang/test/SemaHLSL/TruncationOverloadResolution.hlsl (+39-3)
  • (modified) clang/test/SemaHLSL/Types/BuiltinVector/ScalarSwizzleErrors.hlsl (+5-1)
diff --git a/clang/lib/CodeGen/CGExprScalar.cpp b/clang/lib/CodeGen/CGExprScalar.cpp
index 84392745ea6144..2ccdb840579a98 100644
--- a/clang/lib/CodeGen/CGExprScalar.cpp
+++ b/clang/lib/CodeGen/CGExprScalar.cpp
@@ -2692,14 +2692,19 @@ Value *ScalarExprEmitter::VisitCastExpr(CastExpr *CE) {
     return CGF.CGM.createOpenCLIntToSamplerConversion(E, CGF);
 
   case CK_HLSLVectorTruncation: {
-    assert(DestTy->isVectorType() && "Expected dest type to be vector type");
+    assert((DestTy->isVectorType() || DestTy->isBuiltinType()) &&
+           "Destination type must be a vector or builtin type.");
     Value *Vec = Visit(const_cast<Expr *>(E));
-    SmallVector<int, 16> Mask;
-    unsigned NumElts = DestTy->castAs<VectorType>()->getNumElements();
-    for (unsigned I = 0; I != NumElts; ++I)
-      Mask.push_back(I);
+    if (auto *VecTy = DestTy->getAs<VectorType>()) {
+      SmallVector<int, 16> Mask;
+      unsigned NumElts = VecTy->getNumElements();
+      for (unsigned I = 0; I != NumElts; ++I)
+        Mask.push_back(I);
 
-    return Builder.CreateShuffleVector(Vec, Mask, "trunc");
+      return Builder.CreateShuffleVector(Vec, Mask, "trunc");
+    }
+    llvm::Value *Zero = llvm::Constant::getNullValue(CGF.SizeTy);
+    return Builder.CreateExtractElement(Vec, Zero, "cast.vtrunc");
   }
 
   } // end of switch
diff --git a/clang/lib/Headers/hlsl/hlsl_intrinsics.h b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
index 678cdc77f8a71b..be65e149515b9f 100644
--- a/clang/lib/Headers/hlsl/hlsl_intrinsics.h
+++ b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
@@ -673,9 +673,6 @@ float dot(float3, float3);
 _HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
 float dot(float4, float4);
 
-_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
-double dot(double, double);
-
 _HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
 int dot(int, int);
 _HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
@@ -916,7 +913,7 @@ float4 lerp(float4, float4, float4);
 /// \brief Returns the length of the specified floating-point vector.
 /// \param x [in] The vector of floats, or a scalar float.
 ///
-/// Length is based on the following formula: sqrt(x[0]^2 + x[1]^2 + �).
+/// Length is based on the following formula: sqrt(x[0]^2 + x[1]^2 + ...).
 
 _HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
 _HLSL_BUILTIN_ALIAS(__builtin_hlsl_length)
diff --git a/clang/lib/Sema/SemaExprCXX.cpp b/clang/lib/Sema/SemaExprCXX.cpp
index 5356bcf172f752..e62e3bcbead3be 100644
--- a/clang/lib/Sema/SemaExprCXX.cpp
+++ b/clang/lib/Sema/SemaExprCXX.cpp
@@ -4301,8 +4301,10 @@ Sema::PerformImplicitConversion(Expr *From, QualType ToType,
 // from type to the elements of the to type without resizing the vector.
 static QualType adjustVectorType(ASTContext &Context, QualType FromTy,
                                  QualType ToType, QualType *ElTy = nullptr) {
-  auto *ToVec = ToType->castAs<VectorType>();
-  QualType ElType = ToVec->getElementType();
+  QualType ElType = ToType;
+  if (auto *ToVec = ToType->getAs<VectorType>())
+    ElType = ToVec->getElementType();
+
   if (ElTy)
     *ElTy = ElType;
   if (!FromTy->isVectorType())
@@ -4463,7 +4465,7 @@ Sema::PerformImplicitConversion(Expr *From, QualType ToType,
   case ICK_Integral_Conversion: {
     QualType ElTy = ToType;
     QualType StepTy = ToType;
-    if (ToType->isVectorType())
+    if (FromType->isVectorType() || ToType->isVectorType())
       StepTy = adjustVectorType(Context, FromType, ToType, &ElTy);
     if (ElTy->isBooleanType()) {
       assert(FromType->castAs<EnumType>()->getDecl()->isFixed() &&
@@ -4483,7 +4485,7 @@ Sema::PerformImplicitConversion(Expr *From, QualType ToType,
   case ICK_Floating_Promotion:
   case ICK_Floating_Conversion: {
     QualType StepTy = ToType;
-    if (ToType->isVectorType())
+    if (FromType->isVectorType() || ToType->isVectorType())
       StepTy = adjustVectorType(Context, FromType, ToType);
     From = ImpCastExprToType(From, StepTy, CK_FloatingCast, VK_PRValue,
                              /*BasePath=*/nullptr, CCK)
@@ -4515,7 +4517,7 @@ Sema::PerformImplicitConversion(Expr *From, QualType ToType,
   case ICK_Floating_Integral: {
     QualType ElTy = ToType;
     QualType StepTy = ToType;
-    if (ToType->isVectorType())
+    if (FromType->isVectorType() || ToType->isVectorType())
       StepTy = adjustVectorType(Context, FromType, ToType, &ElTy);
     if (ElTy->isRealFloatingType())
       From = ImpCastExprToType(From, StepTy, CK_IntegralToFloating, VK_PRValue,
@@ -4656,11 +4658,11 @@ Sema::PerformImplicitConversion(Expr *From, QualType ToType,
     }
     QualType ElTy = FromType;
     QualType StepTy = ToType;
-    if (FromType->isVectorType()) {
-      if (getLangOpts().HLSL)
-        StepTy = adjustVectorType(Context, FromType, ToType);
+    if (FromType->isVectorType())
       ElTy = FromType->castAs<VectorType>()->getElementType();
-    }
+    if (getLangOpts().HLSL &&
+        (FromType->isVectorType() || ToType->isVectorType()))
+      StepTy = adjustVectorType(Context, FromType, ToType);
 
     From = ImpCastExprToType(From, StepTy, ScalarTypeToBooleanCastKind(ElTy),
                              VK_PRValue,
@@ -4815,8 +4817,8 @@ Sema::PerformImplicitConversion(Expr *From, QualType ToType,
     // TODO: Support HLSL matrices.
     assert((!From->getType()->isMatrixType() && !ToType->isMatrixType()) &&
            "Dimension conversion for matrix types is not implemented yet.");
-    assert(ToType->isVectorType() &&
-           "Dimension conversion is only supported for vector types.");
+    assert((ToType->isVectorType() || ToType->isBuiltinType()) &&
+           "Dimension conversion output must be vector or scalar type.");
     switch (SCS.Dimension) {
     case ICK_HLSL_Vector_Splat: {
       // Vector splat from any arithmetic type to a vector.
@@ -4828,18 +4830,23 @@ Sema::PerformImplicitConversion(Expr *From, QualType ToType,
     }
     case ICK_HLSL_Vector_Truncation: {
       // Note: HLSL built-in vectors are ExtVectors. Since this truncates a
-      // vector to a smaller vector, this can only operate on arguments where
-      // the source and destination types are ExtVectors.
-      assert(From->getType()->isExtVectorType() && ToType->isExtVectorType() &&
-             "HLSL vector truncation should only apply to ExtVectors");
+      // vector to a smaller vector or to a scalar, this can only operate on
+      // arguments where the source type is an ExtVector and the destination
+      // type is destination type is either an ExtVectorType or a builtin scalar
+      // type.
       auto *FromVec = From->getType()->castAs<VectorType>();
-      auto *ToVec = ToType->castAs<VectorType>();
       QualType ElType = FromVec->getElementType();
-      QualType TruncTy =
-          Context.getExtVectorType(ElType, ToVec->getNumElements());
-      From = ImpCastExprToType(From, TruncTy, CK_HLSLVectorTruncation,
-                               From->getValueKind())
-                 .get();
+      if (auto *ToVec = ToType->getAs<VectorType>()) {
+        QualType TruncTy =
+            Context.getExtVectorType(ElType, ToVec->getNumElements());
+        From = ImpCastExprToType(From, TruncTy, CK_HLSLVectorTruncation,
+                                 From->getValueKind())
+                   .get();
+      } else {
+        From = ImpCastExprToType(From, ElType, CK_HLSLVectorTruncation,
+                                 From->getValueKind())
+                   .get();
+      }
       break;
     }
     case ICK_Identity:
diff --git a/clang/lib/Sema/SemaOverload.cpp b/clang/lib/Sema/SemaOverload.cpp
index 52f640eb96b73b..f6a097305564a0 100644
--- a/clang/lib/Sema/SemaOverload.cpp
+++ b/clang/lib/Sema/SemaOverload.cpp
@@ -2032,26 +2032,42 @@ static bool IsVectorConversion(Sema &S, QualType FromType, QualType ToType,
   if (S.Context.hasSameUnqualifiedType(FromType, ToType))
     return false;
 
+  // HLSL allows implicit truncation of vector types.
+  if (S.getLangOpts().HLSL) {
+    auto *ToExtType = ToType->getAs<ExtVectorType>();
+    auto *FromExtType = FromType->getAs<ExtVectorType>();
+
+    // If both arguments are vectors, handle possible vector truncation and
+    // element conversion.
+    if (ToExtType && FromExtType) {
+      unsigned FromElts = FromExtType->getNumElements();
+      unsigned ToElts = ToExtType->getNumElements();
+      if (FromElts < ToElts)
+        return false;
+      if (FromElts == ToElts)
+        ElConv = ICK_Identity;
+      else
+        ElConv = ICK_HLSL_Vector_Truncation;
+
+      QualType FromElTy = FromExtType->getElementType();
+      QualType ToElTy = ToExtType->getElementType();
+      if (S.Context.hasSameUnqualifiedType(FromElTy, ToElTy))
+        return true;
+      return IsVectorElementConversion(S, FromElTy, ToElTy, ICK, From);
+    }
+    if (FromExtType && nullptr == ToExtType) {
+      ElConv = ICK_HLSL_Vector_Truncation;
+      QualType FromElTy = FromExtType->getElementType();
+      if (S.Context.hasSameUnqualifiedType(FromElTy, ToType))
+        return true;
+      return IsVectorElementConversion(S, FromElTy, ToType, ICK, From);
+    }
+    // Fallthrough for the case where ToType is a vector and FromType is not.
+  }
+
   // There are no conversions between extended vector types, only identity.
   if (auto *ToExtType = ToType->getAs<ExtVectorType>()) {
     if (auto *FromExtType = FromType->getAs<ExtVectorType>()) {
-      // HLSL allows implicit truncation of vector types.
-      if (S.getLangOpts().HLSL) {
-        unsigned FromElts = FromExtType->getNumElements();
-        unsigned ToElts = ToExtType->getNumElements();
-        if (FromElts < ToElts)
-          return false;
-        if (FromElts == ToElts)
-          ElConv = ICK_Identity;
-        else
-          ElConv = ICK_HLSL_Vector_Truncation;
-
-        QualType FromElTy = FromExtType->getElementType();
-        QualType ToElTy = ToExtType->getElementType();
-        if (S.Context.hasSameUnqualifiedType(FromElTy, ToElTy))
-          return true;
-        return IsVectorElementConversion(S, FromElTy, ToElTy, ICK, From);
-      }
       // There are no conversions between extended vector types other than the
       // identity conversion.
       return false;
diff --git a/clang/test/CodeGenHLSL/BasicFeatures/standard_conversion_sequences.hlsl b/clang/test/CodeGenHLSL/BasicFeatures/standard_conversion_sequences.hlsl
index 5d751be6dae066..6478ea67e32a0d 100644
--- a/clang/test/CodeGenHLSL/BasicFeatures/standard_conversion_sequences.hlsl
+++ b/clang/test/CodeGenHLSL/BasicFeatures/standard_conversion_sequences.hlsl
@@ -117,3 +117,27 @@ void d4_to_b2() {
   vector<double,4> d4 = 9.0;
   vector<bool, 2> b2 = d4;
 }
+
+// CHECK-LABEL: d4_to_d1
+// CHECK: [[d4:%.*]] = alloca <4 x double>
+// CHECK: [[d1:%.*]] = alloca <1 x double>
+// CHECK: store <4 x double> <double 9.000000e+00, double 9.000000e+00, double 9.000000e+00, double 9.000000e+00>, ptr [[d4]]
+// CHECK: [[vecd4:%.*]] = load <4 x double>, ptr [[d4]]
+// CHECK: [[vecd1:%.*]] = shufflevector <4 x double> [[vecd4]], <4 x double> poison, <1 x i32> zeroinitializer
+// CHECK: store <1 x double> [[vecd1]], ptr [[d1:%.*]], align 8
+void d4_to_d1() {
+  vector<double,4> d4 = 9.0;
+  vector<double,1> d1 = d4;
+}
+
+// CHECK-LABEL: d4_to_dScalar
+// CHECK: [[d4:%.*]] = alloca <4 x double>
+// CHECK: [[d:%.*]] = alloca double
+// CHECK: store <4 x double> <double 9.000000e+00, double 9.000000e+00, double 9.000000e+00, double 9.000000e+00>, ptr [[d4]]
+// CHECK: [[vecd4:%.*]] = load <4 x double>, ptr [[d4]]
+// CHECK: [[d4x:%.*]] = extractelement <4 x double> [[vecd4]], i32 0
+// CHECK: store double [[d4x]], ptr [[d]]
+void d4_to_dScalar() {
+  vector<double,4> d4 = 9.0;
+  double d = d4;
+}
diff --git a/clang/test/CodeGenHLSL/builtins/dot.hlsl b/clang/test/CodeGenHLSL/builtins/dot.hlsl
index ae6e45c3f9482a..51ce5a88b302b9 100644
--- a/clang/test/CodeGenHLSL/builtins/dot.hlsl
+++ b/clang/test/CodeGenHLSL/builtins/dot.hlsl
@@ -143,19 +143,3 @@ float test_dot_float3(float3 p0, float3 p1) { return dot(p0, p1); }
 // CHECK: %dx.dot = call float @llvm.dx.dot4.v4f32(<4 x float> %0, <4 x float> %1)
 // CHECK: ret float %dx.dot
 float test_dot_float4(float4 p0, float4 p1) { return dot(p0, p1); }
-
-// CHECK:  %dx.dot = call float @llvm.dx.dot2.v2f32(<2 x float> %splat.splat, <2 x float> %1)
-// CHECK: ret float %dx.dot
-float test_dot_float2_splat(float p0, float2 p1) { return dot(p0, p1); }
-
-// CHECK:  %dx.dot = call float @llvm.dx.dot3.v3f32(<3 x float> %splat.splat, <3 x float> %1)
-// CHECK: ret float %dx.dot
-float test_dot_float3_splat(float p0, float3 p1) { return dot(p0, p1); }
-
-// CHECK:  %dx.dot = call float @llvm.dx.dot4.v4f32(<4 x float> %splat.splat, <4 x float> %1)
-// CHECK: ret float %dx.dot
-float test_dot_float4_splat(float p0, float4 p1) { return dot(p0, p1); }
-
-// CHECK: %dx.dot = fmul double %0, %1
-// CHECK: ret double %dx.dot
-double test_dot_double(double p0, double p1) { return dot(p0, p1); }
diff --git a/clang/test/CodeGenHLSL/builtins/lerp.hlsl b/clang/test/CodeGenHLSL/builtins/lerp.hlsl
index 53ac24dd456930..4c0cbe8a671a9b 100644
--- a/clang/test/CodeGenHLSL/builtins/lerp.hlsl
+++ b/clang/test/CodeGenHLSL/builtins/lerp.hlsl
@@ -65,24 +65,3 @@ float3 test_lerp_float3(float3 p0) { return lerp(p0, p0, p0); }
 // SPIR_CHECK: %hlsl.lerp = call <4 x float> @llvm.spv.lerp.v4f32(<4 x float> %{{.*}}, <4 x float> %{{.*}}, <4 x float> %{{.*}})
 // CHECK: ret <4 x float> %hlsl.lerp
 float4 test_lerp_float4(float4 p0) { return lerp(p0, p0, p0); }
-
-// CHECK: %[[b:.*]] = load <2 x float>, ptr %p1.addr, align 8
-// CHECK: %[[c:.*]] = load <2 x float>, ptr %p1.addr, align 8
-// DXIL_CHECK: %hlsl.lerp = call <2 x float> @llvm.dx.lerp.v2f32(<2 x float> %splat.splat, <2 x float> %[[b]], <2 x float> %[[c]])
-// SPIR_CHECK: %hlsl.lerp = call <2 x float> @llvm.spv.lerp.v2f32(<2 x float> %splat.splat, <2 x float> %[[b]], <2 x float> %[[c]])
-// CHECK: ret <2 x float> %hlsl.lerp
-float2 test_lerp_float2_splat(float p0, float2 p1) { return lerp(p0, p1, p1); }
-
-// CHECK: %[[b:.*]] = load <3 x float>, ptr %p1.addr, align 16
-// CHECK: %[[c:.*]] = load <3 x float>, ptr %p1.addr, align 16
-// DXIL_CHECK: %hlsl.lerp = call <3 x float> @llvm.dx.lerp.v3f32(<3 x float> %splat.splat, <3 x float> %[[b]], <3 x float> %[[c]])
-// SPIR_CHECK: %hlsl.lerp = call <3 x float> @llvm.spv.lerp.v3f32(<3 x float> %splat.splat, <3 x float> %[[b]], <3 x float> %[[c]])
-// CHECK: ret <3 x float> %hlsl.lerp
-float3 test_lerp_float3_splat(float p0, float3 p1) { return lerp(p0, p1, p1); }
-
-// CHECK: %[[b:.*]] = load <4 x float>, ptr %p1.addr, align 16
-// CHECK: %[[c:.*]] = load <4 x float>, ptr %p1.addr, align 16
-// DXIL_CHECK: %hlsl.lerp = call <4 x float> @llvm.dx.lerp.v4f32(<4 x float> %splat.splat, <4 x float> %[[b]], <4 x float> %[[c]])
-// SPIR_CHECK: %hlsl.lerp = call <4 x float> @llvm.spv.lerp.v4f32(<4 x float> %splat.splat, <4 x float> %[[b]], <4 x float> %[[c]])
-// CHECK:  ret <4 x float> %hlsl.lerp
-float4 test_lerp_float4_splat(float p0, float4 p1) { return lerp(p0, p1, p1); }
diff --git a/clang/test/CodeGenHLSL/builtins/mad.hlsl b/clang/test/CodeGenHLSL/builtins/mad.hlsl
index 449a793caf93b7..265a2552c80fb4 100644
--- a/clang/test/CodeGenHLSL/builtins/mad.hlsl
+++ b/clang/test/CodeGenHLSL/builtins/mad.hlsl
@@ -263,21 +263,3 @@ uint64_t3 test_mad_uint64_t3(uint64_t3 p0, uint64_t3 p1, uint64_t3 p2) { return
 // SPIR_CHECK: mul nuw <4 x i64>  %{{.*}}, %{{.*}}
 // SPIR_CHECK: add nuw <4 x i64>  %{{.*}}, %{{.*}}
 uint64_t4 test_mad_uint64_t4(uint64_t4 p0, uint64_t4 p1, uint64_t4 p2) { return mad(p0, p1, p2); }
-
-// CHECK: %[[p1:.*]] = load <2 x float>, ptr %p1.addr, align 8
-// CHECK: %[[p2:.*]] = load <2 x float>, ptr %p2.addr, align 8
-// CHECK: %hlsl.fmad = call <2 x float>  @llvm.fmuladd.v2f32(<2 x float> %splat.splat, <2 x float> %[[p1]], <2 x float> %[[p2]])
-// CHECK: ret <2 x float> %hlsl.fmad
-float2 test_mad_float2_splat(float p0, float2 p1, float2 p2) { return mad(p0, p1, p2); }
-
-// CHECK: %[[p1:.*]] = load <3 x float>, ptr %p1.addr, align 16
-// CHECK: %[[p2:.*]] = load <3 x float>, ptr %p2.addr, align 16
-// CHECK: %hlsl.fmad = call <3 x float>  @llvm.fmuladd.v3f32(<3 x float> %splat.splat, <3 x float> %[[p1]], <3 x float> %[[p2]])
-// CHECK: ret <3 x float> %hlsl.fmad
-float3 test_mad_float3_splat(float p0, float3 p1, float3 p2) { return mad(p0, p1, p2); }
-
-// CHECK: %[[p1:.*]] = load <4 x float>, ptr %p1.addr, align 16
-// CHECK: %[[p2:.*]] = load <4 x float>, ptr %p2.addr, align 16
-// CHECK:  %hlsl.fmad = call <4 x float>  @llvm.fmuladd.v4f32(<4 x float> %splat.splat, <4 x float> %[[p1]], <4 x float> %[[p2]])
-// CHECK:  ret <4 x float> %hlsl.fmad
-float4 test_mad_float4_splat(float p0, float4 p1, float4 p2) { return mad(p0, p1, p2); }
diff --git a/clang/test/SemaHLSL/TruncationOverloadResolution.hlsl b/clang/test/SemaHLSL/TruncationOverloadResolution.hlsl
index f8cfe22372e885..cb00d2e226b7cf 100644
--- a/clang/test/SemaHLSL/TruncationOverloadResolution.hlsl
+++ b/clang/test/SemaHLSL/TruncationOverloadResolution.hlsl
@@ -1,4 +1,4 @@
-// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.3-library -fnative-half-type -finclude-default-header -fsyntax-only %s -DERROR=1 -verify
+// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.3-library -fnative-half-type -finclude-default-header -Wconversion -fsyntax-only %s -DERROR=1 -verify
 // RUN: %clang_cc1 -triple dxil-pc-shadermodel6.3-library -fnative-half-type -finclude-default-header -ast-dump %s | FileCheck %s
 
 // Case 1: Prefer exact-match truncation over conversion.
@@ -32,6 +32,42 @@ void Case2(float4 F) {
   Half2Double2(F); // expected-warning{{implicit conversion truncates vector: 'float4' (aka 'vector<float, 4>') to 'vector<double, 2>' (vector of 2 'double' values)}}
 }
 
+// Case 3: Allow truncation down to vector<T,1> or T.
+void Half(half H);
+void Float(float F);
+void Double(double D);
+
+void Half1(half1 H);
+void Float1(float1 F);
+void Double1(double1 D);
+
+void Case3(half3 H, float3 F, double3 D) {
+  Half(H); // expected-warning{{implicit conversion turns vector to scalar: 'half3' (aka 'vector<half, 3>') to 'half'}}
+  Half(F); // expected-warning{{implicit conversion turns vector to scalar: 'float3' (aka 'vector<float, 3>') to 'half'}}
+  Half(D); // expected-warning{{implicit conversion turns vector to scalar: 'double3' (aka 'vector<double, 3>') to 'half'}}
+
+  Float(H); // expected-warning{{implicit conversion turns vector to scalar: 'half3' (aka 'vector<half, 3>') to 'float'}}
+  Float(F); // expected-warning{{implicit conversion turns vector to scalar: 'float3' (aka 'vector<float, 3>') to 'float'}}
+  Float(D); // expected-warning{{implicit conversion turns vector to scalar: 'double3' (aka 'vector<double, 3>') to 'float'}}
+
+  Double(H); // expected-warning{{implicit conversion turns vector to scalar: 'half3' (aka 'vector<half, 3>') to 'double'}}
+  Double(F); // expected-warning{{implicit conversion turns vector to scalar: 'float3' (aka 'vector<float, 3>') to 'double'}}
+  Double(D); // expected-warning{{implicit conversion turns vector to scalar: 'double3' (aka 'vector<double, 3>') to 'double'}}
+
+  Half1(H); // expected-warning{{implicit conversion truncates vector: 'half3' (aka 'vector<half, 3>') to 'vector<half, 1>' (vector of 1 'half' value)}}
+  Half1(F); // expected-warning{{implicit conversion truncates vector: 'float3' (aka 'vector<float, 3>') to 'vector<half, 1>' (vector of 1 'half' value)}} expected-warning{{implicit conversion loses floating-point precision: 'float3' (aka 'vector<float, 3>') to 'vector<half, 1>' (vector of 1 'half' value)}}
+  Half1(D); // expected-warning{{implicit conversion truncates vector: 'double3' (aka 'vector<double, 3>') to 'vector<half, 1>' (vector of 1 'half' value)}} expected-warning{{implicit conversion loses floating-point precision: 'double3' (aka 'vector<double, 3>') to 'vector<half, 1>' (vector of 1 'half' value)}}
+
+  Float1(H); // expected-warning{{implicit conversion truncates vector: 'half3' (aka 'vector<half, 3>') to 'vector<float, 1>' (vector of 1 'float' value)}}
+  Float1(F); // expected-warning{{implicit conversion truncates vector: 'float3' (aka 'vector<float, 3>') to 'vector<float, 1>' (vector of 1 'float' value)}}
+  Float1(D); // expected-warning{{implicit conversion truncates vector: 'double3' (aka 'vector<double, 3>') to 'vector<float, 1>' (vector of 1 'float' value)}} expected-warning{{implicit conversion loses floating-point precision: 'double3' (aka 'vector<double, 3>') to...
[truncated]

@llvmbot
Copy link
Member

llvmbot commented Aug 19, 2024

@llvm/pr-subscribers-clang

Author: Chris B (llvm-beanz)

Changes

HLSL allows implicit conversions to truncate vectors to scalar pr-values. These conversions are scored as vector truncations and should warn appropriately.

This change allows forming a truncation cast to a pr-value, but not an l-value. Truncating a vector to a scalar is performed by loading the first element of the vector and disregarding the remaining elements.

Fixes #102964


Patch is 22.47 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/104844.diff

10 Files Affected:

  • (modified) clang/lib/CodeGen/CGExprScalar.cpp (+11-6)
  • (modified) clang/lib/Headers/hlsl/hlsl_intrinsics.h (+1-4)
  • (modified) clang/lib/Sema/SemaExprCXX.cpp (+28-21)
  • (modified) clang/lib/Sema/SemaOverload.cpp (+33-17)
  • (modified) clang/test/CodeGenHLSL/BasicFeatures/standard_conversion_sequences.hlsl (+24)
  • (modified) clang/test/CodeGenHLSL/builtins/dot.hlsl (-16)
  • (modified) clang/test/CodeGenHLSL/builtins/lerp.hlsl (-21)
  • (modified) clang/test/CodeGenHLSL/builtins/mad.hlsl (-18)
  • (modified) clang/test/SemaHLSL/TruncationOverloadResolution.hlsl (+39-3)
  • (modified) clang/test/SemaHLSL/Types/BuiltinVector/ScalarSwizzleErrors.hlsl (+5-1)
diff --git a/clang/lib/CodeGen/CGExprScalar.cpp b/clang/lib/CodeGen/CGExprScalar.cpp
index 84392745ea6144..2ccdb840579a98 100644
--- a/clang/lib/CodeGen/CGExprScalar.cpp
+++ b/clang/lib/CodeGen/CGExprScalar.cpp
@@ -2692,14 +2692,19 @@ Value *ScalarExprEmitter::VisitCastExpr(CastExpr *CE) {
     return CGF.CGM.createOpenCLIntToSamplerConversion(E, CGF);
 
   case CK_HLSLVectorTruncation: {
-    assert(DestTy->isVectorType() && "Expected dest type to be vector type");
+    assert((DestTy->isVectorType() || DestTy->isBuiltinType()) &&
+           "Destination type must be a vector or builtin type.");
     Value *Vec = Visit(const_cast<Expr *>(E));
-    SmallVector<int, 16> Mask;
-    unsigned NumElts = DestTy->castAs<VectorType>()->getNumElements();
-    for (unsigned I = 0; I != NumElts; ++I)
-      Mask.push_back(I);
+    if (auto *VecTy = DestTy->getAs<VectorType>()) {
+      SmallVector<int, 16> Mask;
+      unsigned NumElts = VecTy->getNumElements();
+      for (unsigned I = 0; I != NumElts; ++I)
+        Mask.push_back(I);
 
-    return Builder.CreateShuffleVector(Vec, Mask, "trunc");
+      return Builder.CreateShuffleVector(Vec, Mask, "trunc");
+    }
+    llvm::Value *Zero = llvm::Constant::getNullValue(CGF.SizeTy);
+    return Builder.CreateExtractElement(Vec, Zero, "cast.vtrunc");
   }
 
   } // end of switch
diff --git a/clang/lib/Headers/hlsl/hlsl_intrinsics.h b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
index 678cdc77f8a71b..be65e149515b9f 100644
--- a/clang/lib/Headers/hlsl/hlsl_intrinsics.h
+++ b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
@@ -673,9 +673,6 @@ float dot(float3, float3);
 _HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
 float dot(float4, float4);
 
-_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
-double dot(double, double);
-
 _HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
 int dot(int, int);
 _HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
@@ -916,7 +913,7 @@ float4 lerp(float4, float4, float4);
 /// \brief Returns the length of the specified floating-point vector.
 /// \param x [in] The vector of floats, or a scalar float.
 ///
-/// Length is based on the following formula: sqrt(x[0]^2 + x[1]^2 + �).
+/// Length is based on the following formula: sqrt(x[0]^2 + x[1]^2 + ...).
 
 _HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
 _HLSL_BUILTIN_ALIAS(__builtin_hlsl_length)
diff --git a/clang/lib/Sema/SemaExprCXX.cpp b/clang/lib/Sema/SemaExprCXX.cpp
index 5356bcf172f752..e62e3bcbead3be 100644
--- a/clang/lib/Sema/SemaExprCXX.cpp
+++ b/clang/lib/Sema/SemaExprCXX.cpp
@@ -4301,8 +4301,10 @@ Sema::PerformImplicitConversion(Expr *From, QualType ToType,
 // from type to the elements of the to type without resizing the vector.
 static QualType adjustVectorType(ASTContext &Context, QualType FromTy,
                                  QualType ToType, QualType *ElTy = nullptr) {
-  auto *ToVec = ToType->castAs<VectorType>();
-  QualType ElType = ToVec->getElementType();
+  QualType ElType = ToType;
+  if (auto *ToVec = ToType->getAs<VectorType>())
+    ElType = ToVec->getElementType();
+
   if (ElTy)
     *ElTy = ElType;
   if (!FromTy->isVectorType())
@@ -4463,7 +4465,7 @@ Sema::PerformImplicitConversion(Expr *From, QualType ToType,
   case ICK_Integral_Conversion: {
     QualType ElTy = ToType;
     QualType StepTy = ToType;
-    if (ToType->isVectorType())
+    if (FromType->isVectorType() || ToType->isVectorType())
       StepTy = adjustVectorType(Context, FromType, ToType, &ElTy);
     if (ElTy->isBooleanType()) {
       assert(FromType->castAs<EnumType>()->getDecl()->isFixed() &&
@@ -4483,7 +4485,7 @@ Sema::PerformImplicitConversion(Expr *From, QualType ToType,
   case ICK_Floating_Promotion:
   case ICK_Floating_Conversion: {
     QualType StepTy = ToType;
-    if (ToType->isVectorType())
+    if (FromType->isVectorType() || ToType->isVectorType())
       StepTy = adjustVectorType(Context, FromType, ToType);
     From = ImpCastExprToType(From, StepTy, CK_FloatingCast, VK_PRValue,
                              /*BasePath=*/nullptr, CCK)
@@ -4515,7 +4517,7 @@ Sema::PerformImplicitConversion(Expr *From, QualType ToType,
   case ICK_Floating_Integral: {
     QualType ElTy = ToType;
     QualType StepTy = ToType;
-    if (ToType->isVectorType())
+    if (FromType->isVectorType() || ToType->isVectorType())
       StepTy = adjustVectorType(Context, FromType, ToType, &ElTy);
     if (ElTy->isRealFloatingType())
       From = ImpCastExprToType(From, StepTy, CK_IntegralToFloating, VK_PRValue,
@@ -4656,11 +4658,11 @@ Sema::PerformImplicitConversion(Expr *From, QualType ToType,
     }
     QualType ElTy = FromType;
     QualType StepTy = ToType;
-    if (FromType->isVectorType()) {
-      if (getLangOpts().HLSL)
-        StepTy = adjustVectorType(Context, FromType, ToType);
+    if (FromType->isVectorType())
       ElTy = FromType->castAs<VectorType>()->getElementType();
-    }
+    if (getLangOpts().HLSL &&
+        (FromType->isVectorType() || ToType->isVectorType()))
+      StepTy = adjustVectorType(Context, FromType, ToType);
 
     From = ImpCastExprToType(From, StepTy, ScalarTypeToBooleanCastKind(ElTy),
                              VK_PRValue,
@@ -4815,8 +4817,8 @@ Sema::PerformImplicitConversion(Expr *From, QualType ToType,
     // TODO: Support HLSL matrices.
     assert((!From->getType()->isMatrixType() && !ToType->isMatrixType()) &&
            "Dimension conversion for matrix types is not implemented yet.");
-    assert(ToType->isVectorType() &&
-           "Dimension conversion is only supported for vector types.");
+    assert((ToType->isVectorType() || ToType->isBuiltinType()) &&
+           "Dimension conversion output must be vector or scalar type.");
     switch (SCS.Dimension) {
     case ICK_HLSL_Vector_Splat: {
       // Vector splat from any arithmetic type to a vector.
@@ -4828,18 +4830,23 @@ Sema::PerformImplicitConversion(Expr *From, QualType ToType,
     }
     case ICK_HLSL_Vector_Truncation: {
       // Note: HLSL built-in vectors are ExtVectors. Since this truncates a
-      // vector to a smaller vector, this can only operate on arguments where
-      // the source and destination types are ExtVectors.
-      assert(From->getType()->isExtVectorType() && ToType->isExtVectorType() &&
-             "HLSL vector truncation should only apply to ExtVectors");
+      // vector to a smaller vector or to a scalar, this can only operate on
+      // arguments where the source type is an ExtVector and the destination
+      // type is destination type is either an ExtVectorType or a builtin scalar
+      // type.
       auto *FromVec = From->getType()->castAs<VectorType>();
-      auto *ToVec = ToType->castAs<VectorType>();
       QualType ElType = FromVec->getElementType();
-      QualType TruncTy =
-          Context.getExtVectorType(ElType, ToVec->getNumElements());
-      From = ImpCastExprToType(From, TruncTy, CK_HLSLVectorTruncation,
-                               From->getValueKind())
-                 .get();
+      if (auto *ToVec = ToType->getAs<VectorType>()) {
+        QualType TruncTy =
+            Context.getExtVectorType(ElType, ToVec->getNumElements());
+        From = ImpCastExprToType(From, TruncTy, CK_HLSLVectorTruncation,
+                                 From->getValueKind())
+                   .get();
+      } else {
+        From = ImpCastExprToType(From, ElType, CK_HLSLVectorTruncation,
+                                 From->getValueKind())
+                   .get();
+      }
       break;
     }
     case ICK_Identity:
diff --git a/clang/lib/Sema/SemaOverload.cpp b/clang/lib/Sema/SemaOverload.cpp
index 52f640eb96b73b..f6a097305564a0 100644
--- a/clang/lib/Sema/SemaOverload.cpp
+++ b/clang/lib/Sema/SemaOverload.cpp
@@ -2032,26 +2032,42 @@ static bool IsVectorConversion(Sema &S, QualType FromType, QualType ToType,
   if (S.Context.hasSameUnqualifiedType(FromType, ToType))
     return false;
 
+  // HLSL allows implicit truncation of vector types.
+  if (S.getLangOpts().HLSL) {
+    auto *ToExtType = ToType->getAs<ExtVectorType>();
+    auto *FromExtType = FromType->getAs<ExtVectorType>();
+
+    // If both arguments are vectors, handle possible vector truncation and
+    // element conversion.
+    if (ToExtType && FromExtType) {
+      unsigned FromElts = FromExtType->getNumElements();
+      unsigned ToElts = ToExtType->getNumElements();
+      if (FromElts < ToElts)
+        return false;
+      if (FromElts == ToElts)
+        ElConv = ICK_Identity;
+      else
+        ElConv = ICK_HLSL_Vector_Truncation;
+
+      QualType FromElTy = FromExtType->getElementType();
+      QualType ToElTy = ToExtType->getElementType();
+      if (S.Context.hasSameUnqualifiedType(FromElTy, ToElTy))
+        return true;
+      return IsVectorElementConversion(S, FromElTy, ToElTy, ICK, From);
+    }
+    if (FromExtType && nullptr == ToExtType) {
+      ElConv = ICK_HLSL_Vector_Truncation;
+      QualType FromElTy = FromExtType->getElementType();
+      if (S.Context.hasSameUnqualifiedType(FromElTy, ToType))
+        return true;
+      return IsVectorElementConversion(S, FromElTy, ToType, ICK, From);
+    }
+    // Fallthrough for the case where ToType is a vector and FromType is not.
+  }
+
   // There are no conversions between extended vector types, only identity.
   if (auto *ToExtType = ToType->getAs<ExtVectorType>()) {
     if (auto *FromExtType = FromType->getAs<ExtVectorType>()) {
-      // HLSL allows implicit truncation of vector types.
-      if (S.getLangOpts().HLSL) {
-        unsigned FromElts = FromExtType->getNumElements();
-        unsigned ToElts = ToExtType->getNumElements();
-        if (FromElts < ToElts)
-          return false;
-        if (FromElts == ToElts)
-          ElConv = ICK_Identity;
-        else
-          ElConv = ICK_HLSL_Vector_Truncation;
-
-        QualType FromElTy = FromExtType->getElementType();
-        QualType ToElTy = ToExtType->getElementType();
-        if (S.Context.hasSameUnqualifiedType(FromElTy, ToElTy))
-          return true;
-        return IsVectorElementConversion(S, FromElTy, ToElTy, ICK, From);
-      }
       // There are no conversions between extended vector types other than the
       // identity conversion.
       return false;
diff --git a/clang/test/CodeGenHLSL/BasicFeatures/standard_conversion_sequences.hlsl b/clang/test/CodeGenHLSL/BasicFeatures/standard_conversion_sequences.hlsl
index 5d751be6dae066..6478ea67e32a0d 100644
--- a/clang/test/CodeGenHLSL/BasicFeatures/standard_conversion_sequences.hlsl
+++ b/clang/test/CodeGenHLSL/BasicFeatures/standard_conversion_sequences.hlsl
@@ -117,3 +117,27 @@ void d4_to_b2() {
   vector<double,4> d4 = 9.0;
   vector<bool, 2> b2 = d4;
 }
+
+// CHECK-LABEL: d4_to_d1
+// CHECK: [[d4:%.*]] = alloca <4 x double>
+// CHECK: [[d1:%.*]] = alloca <1 x double>
+// CHECK: store <4 x double> <double 9.000000e+00, double 9.000000e+00, double 9.000000e+00, double 9.000000e+00>, ptr [[d4]]
+// CHECK: [[vecd4:%.*]] = load <4 x double>, ptr [[d4]]
+// CHECK: [[vecd1:%.*]] = shufflevector <4 x double> [[vecd4]], <4 x double> poison, <1 x i32> zeroinitializer
+// CHECK: store <1 x double> [[vecd1]], ptr [[d1:%.*]], align 8
+void d4_to_d1() {
+  vector<double,4> d4 = 9.0;
+  vector<double,1> d1 = d4;
+}
+
+// CHECK-LABEL: d4_to_dScalar
+// CHECK: [[d4:%.*]] = alloca <4 x double>
+// CHECK: [[d:%.*]] = alloca double
+// CHECK: store <4 x double> <double 9.000000e+00, double 9.000000e+00, double 9.000000e+00, double 9.000000e+00>, ptr [[d4]]
+// CHECK: [[vecd4:%.*]] = load <4 x double>, ptr [[d4]]
+// CHECK: [[d4x:%.*]] = extractelement <4 x double> [[vecd4]], i32 0
+// CHECK: store double [[d4x]], ptr [[d]]
+void d4_to_dScalar() {
+  vector<double,4> d4 = 9.0;
+  double d = d4;
+}
diff --git a/clang/test/CodeGenHLSL/builtins/dot.hlsl b/clang/test/CodeGenHLSL/builtins/dot.hlsl
index ae6e45c3f9482a..51ce5a88b302b9 100644
--- a/clang/test/CodeGenHLSL/builtins/dot.hlsl
+++ b/clang/test/CodeGenHLSL/builtins/dot.hlsl
@@ -143,19 +143,3 @@ float test_dot_float3(float3 p0, float3 p1) { return dot(p0, p1); }
 // CHECK: %dx.dot = call float @llvm.dx.dot4.v4f32(<4 x float> %0, <4 x float> %1)
 // CHECK: ret float %dx.dot
 float test_dot_float4(float4 p0, float4 p1) { return dot(p0, p1); }
-
-// CHECK:  %dx.dot = call float @llvm.dx.dot2.v2f32(<2 x float> %splat.splat, <2 x float> %1)
-// CHECK: ret float %dx.dot
-float test_dot_float2_splat(float p0, float2 p1) { return dot(p0, p1); }
-
-// CHECK:  %dx.dot = call float @llvm.dx.dot3.v3f32(<3 x float> %splat.splat, <3 x float> %1)
-// CHECK: ret float %dx.dot
-float test_dot_float3_splat(float p0, float3 p1) { return dot(p0, p1); }
-
-// CHECK:  %dx.dot = call float @llvm.dx.dot4.v4f32(<4 x float> %splat.splat, <4 x float> %1)
-// CHECK: ret float %dx.dot
-float test_dot_float4_splat(float p0, float4 p1) { return dot(p0, p1); }
-
-// CHECK: %dx.dot = fmul double %0, %1
-// CHECK: ret double %dx.dot
-double test_dot_double(double p0, double p1) { return dot(p0, p1); }
diff --git a/clang/test/CodeGenHLSL/builtins/lerp.hlsl b/clang/test/CodeGenHLSL/builtins/lerp.hlsl
index 53ac24dd456930..4c0cbe8a671a9b 100644
--- a/clang/test/CodeGenHLSL/builtins/lerp.hlsl
+++ b/clang/test/CodeGenHLSL/builtins/lerp.hlsl
@@ -65,24 +65,3 @@ float3 test_lerp_float3(float3 p0) { return lerp(p0, p0, p0); }
 // SPIR_CHECK: %hlsl.lerp = call <4 x float> @llvm.spv.lerp.v4f32(<4 x float> %{{.*}}, <4 x float> %{{.*}}, <4 x float> %{{.*}})
 // CHECK: ret <4 x float> %hlsl.lerp
 float4 test_lerp_float4(float4 p0) { return lerp(p0, p0, p0); }
-
-// CHECK: %[[b:.*]] = load <2 x float>, ptr %p1.addr, align 8
-// CHECK: %[[c:.*]] = load <2 x float>, ptr %p1.addr, align 8
-// DXIL_CHECK: %hlsl.lerp = call <2 x float> @llvm.dx.lerp.v2f32(<2 x float> %splat.splat, <2 x float> %[[b]], <2 x float> %[[c]])
-// SPIR_CHECK: %hlsl.lerp = call <2 x float> @llvm.spv.lerp.v2f32(<2 x float> %splat.splat, <2 x float> %[[b]], <2 x float> %[[c]])
-// CHECK: ret <2 x float> %hlsl.lerp
-float2 test_lerp_float2_splat(float p0, float2 p1) { return lerp(p0, p1, p1); }
-
-// CHECK: %[[b:.*]] = load <3 x float>, ptr %p1.addr, align 16
-// CHECK: %[[c:.*]] = load <3 x float>, ptr %p1.addr, align 16
-// DXIL_CHECK: %hlsl.lerp = call <3 x float> @llvm.dx.lerp.v3f32(<3 x float> %splat.splat, <3 x float> %[[b]], <3 x float> %[[c]])
-// SPIR_CHECK: %hlsl.lerp = call <3 x float> @llvm.spv.lerp.v3f32(<3 x float> %splat.splat, <3 x float> %[[b]], <3 x float> %[[c]])
-// CHECK: ret <3 x float> %hlsl.lerp
-float3 test_lerp_float3_splat(float p0, float3 p1) { return lerp(p0, p1, p1); }
-
-// CHECK: %[[b:.*]] = load <4 x float>, ptr %p1.addr, align 16
-// CHECK: %[[c:.*]] = load <4 x float>, ptr %p1.addr, align 16
-// DXIL_CHECK: %hlsl.lerp = call <4 x float> @llvm.dx.lerp.v4f32(<4 x float> %splat.splat, <4 x float> %[[b]], <4 x float> %[[c]])
-// SPIR_CHECK: %hlsl.lerp = call <4 x float> @llvm.spv.lerp.v4f32(<4 x float> %splat.splat, <4 x float> %[[b]], <4 x float> %[[c]])
-// CHECK:  ret <4 x float> %hlsl.lerp
-float4 test_lerp_float4_splat(float p0, float4 p1) { return lerp(p0, p1, p1); }
diff --git a/clang/test/CodeGenHLSL/builtins/mad.hlsl b/clang/test/CodeGenHLSL/builtins/mad.hlsl
index 449a793caf93b7..265a2552c80fb4 100644
--- a/clang/test/CodeGenHLSL/builtins/mad.hlsl
+++ b/clang/test/CodeGenHLSL/builtins/mad.hlsl
@@ -263,21 +263,3 @@ uint64_t3 test_mad_uint64_t3(uint64_t3 p0, uint64_t3 p1, uint64_t3 p2) { return
 // SPIR_CHECK: mul nuw <4 x i64>  %{{.*}}, %{{.*}}
 // SPIR_CHECK: add nuw <4 x i64>  %{{.*}}, %{{.*}}
 uint64_t4 test_mad_uint64_t4(uint64_t4 p0, uint64_t4 p1, uint64_t4 p2) { return mad(p0, p1, p2); }
-
-// CHECK: %[[p1:.*]] = load <2 x float>, ptr %p1.addr, align 8
-// CHECK: %[[p2:.*]] = load <2 x float>, ptr %p2.addr, align 8
-// CHECK: %hlsl.fmad = call <2 x float>  @llvm.fmuladd.v2f32(<2 x float> %splat.splat, <2 x float> %[[p1]], <2 x float> %[[p2]])
-// CHECK: ret <2 x float> %hlsl.fmad
-float2 test_mad_float2_splat(float p0, float2 p1, float2 p2) { return mad(p0, p1, p2); }
-
-// CHECK: %[[p1:.*]] = load <3 x float>, ptr %p1.addr, align 16
-// CHECK: %[[p2:.*]] = load <3 x float>, ptr %p2.addr, align 16
-// CHECK: %hlsl.fmad = call <3 x float>  @llvm.fmuladd.v3f32(<3 x float> %splat.splat, <3 x float> %[[p1]], <3 x float> %[[p2]])
-// CHECK: ret <3 x float> %hlsl.fmad
-float3 test_mad_float3_splat(float p0, float3 p1, float3 p2) { return mad(p0, p1, p2); }
-
-// CHECK: %[[p1:.*]] = load <4 x float>, ptr %p1.addr, align 16
-// CHECK: %[[p2:.*]] = load <4 x float>, ptr %p2.addr, align 16
-// CHECK:  %hlsl.fmad = call <4 x float>  @llvm.fmuladd.v4f32(<4 x float> %splat.splat, <4 x float> %[[p1]], <4 x float> %[[p2]])
-// CHECK:  ret <4 x float> %hlsl.fmad
-float4 test_mad_float4_splat(float p0, float4 p1, float4 p2) { return mad(p0, p1, p2); }
diff --git a/clang/test/SemaHLSL/TruncationOverloadResolution.hlsl b/clang/test/SemaHLSL/TruncationOverloadResolution.hlsl
index f8cfe22372e885..cb00d2e226b7cf 100644
--- a/clang/test/SemaHLSL/TruncationOverloadResolution.hlsl
+++ b/clang/test/SemaHLSL/TruncationOverloadResolution.hlsl
@@ -1,4 +1,4 @@
-// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.3-library -fnative-half-type -finclude-default-header -fsyntax-only %s -DERROR=1 -verify
+// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.3-library -fnative-half-type -finclude-default-header -Wconversion -fsyntax-only %s -DERROR=1 -verify
 // RUN: %clang_cc1 -triple dxil-pc-shadermodel6.3-library -fnative-half-type -finclude-default-header -ast-dump %s | FileCheck %s
 
 // Case 1: Prefer exact-match truncation over conversion.
@@ -32,6 +32,42 @@ void Case2(float4 F) {
   Half2Double2(F); // expected-warning{{implicit conversion truncates vector: 'float4' (aka 'vector<float, 4>') to 'vector<double, 2>' (vector of 2 'double' values)}}
 }
 
+// Case 3: Allow truncation down to vector<T,1> or T.
+void Half(half H);
+void Float(float F);
+void Double(double D);
+
+void Half1(half1 H);
+void Float1(float1 F);
+void Double1(double1 D);
+
+void Case3(half3 H, float3 F, double3 D) {
+  Half(H); // expected-warning{{implicit conversion turns vector to scalar: 'half3' (aka 'vector<half, 3>') to 'half'}}
+  Half(F); // expected-warning{{implicit conversion turns vector to scalar: 'float3' (aka 'vector<float, 3>') to 'half'}}
+  Half(D); // expected-warning{{implicit conversion turns vector to scalar: 'double3' (aka 'vector<double, 3>') to 'half'}}
+
+  Float(H); // expected-warning{{implicit conversion turns vector to scalar: 'half3' (aka 'vector<half, 3>') to 'float'}}
+  Float(F); // expected-warning{{implicit conversion turns vector to scalar: 'float3' (aka 'vector<float, 3>') to 'float'}}
+  Float(D); // expected-warning{{implicit conversion turns vector to scalar: 'double3' (aka 'vector<double, 3>') to 'float'}}
+
+  Double(H); // expected-warning{{implicit conversion turns vector to scalar: 'half3' (aka 'vector<half, 3>') to 'double'}}
+  Double(F); // expected-warning{{implicit conversion turns vector to scalar: 'float3' (aka 'vector<float, 3>') to 'double'}}
+  Double(D); // expected-warning{{implicit conversion turns vector to scalar: 'double3' (aka 'vector<double, 3>') to 'double'}}
+
+  Half1(H); // expected-warning{{implicit conversion truncates vector: 'half3' (aka 'vector<half, 3>') to 'vector<half, 1>' (vector of 1 'half' value)}}
+  Half1(F); // expected-warning{{implicit conversion truncates vector: 'float3' (aka 'vector<float, 3>') to 'vector<half, 1>' (vector of 1 'half' value)}} expected-warning{{implicit conversion loses floating-point precision: 'float3' (aka 'vector<float, 3>') to 'vector<half, 1>' (vector of 1 'half' value)}}
+  Half1(D); // expected-warning{{implicit conversion truncates vector: 'double3' (aka 'vector<double, 3>') to 'vector<half, 1>' (vector of 1 'half' value)}} expected-warning{{implicit conversion loses floating-point precision: 'double3' (aka 'vector<double, 3>') to 'vector<half, 1>' (vector of 1 'half' value)}}
+
+  Float1(H); // expected-warning{{implicit conversion truncates vector: 'half3' (aka 'vector<half, 3>') to 'vector<float, 1>' (vector of 1 'float' value)}}
+  Float1(F); // expected-warning{{implicit conversion truncates vector: 'float3' (aka 'vector<float, 3>') to 'vector<float, 1>' (vector of 1 'float' value)}}
+  Float1(D); // expected-warning{{implicit conversion truncates vector: 'double3' (aka 'vector<double, 3>') to 'vector<float, 1>' (vector of 1 'float' value)}} expected-warning{{implicit conversion loses floating-point precision: 'double3' (aka 'vector<double, 3>') to...
[truncated]

Comment on lines 676 to 678
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
double dot(double, double);

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Surprised to see this here based on the description. Wonder if we could hand hold readers a bit more to explain why this was removed?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, I think this can stay in once this PR lands: #106811

Comment on lines 147 to 172
// CHECK: %dx.dot = call float @llvm.dx.dot2.v2f32(<2 x float> %splat.splat, <2 x float> %1)
// CHECK: ret float %dx.dot
float test_dot_float2_splat(float p0, float2 p1) { return dot(p0, p1); }

// CHECK: %dx.dot = call float @llvm.dx.dot3.v3f32(<3 x float> %splat.splat, <3 x float> %1)
// CHECK: ret float %dx.dot
float test_dot_float3_splat(float p0, float3 p1) { return dot(p0, p1); }

// CHECK: %dx.dot = call float @llvm.dx.dot4.v4f32(<4 x float> %splat.splat, <4 x float> %1)
// CHECK: ret float %dx.dot
float test_dot_float4_splat(float p0, float4 p1) { return dot(p0, p1); }

// CHECK: %dx.dot = fmul double %0, %1
// CHECK: ret double %dx.dot
double test_dot_double(double p0, double p1) { return dot(p0, p1); }
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are these no longer valid cases? So the float p0, float2 p1 for example generates an error now, rather than implicitly truncating p1 to a float?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, these are not valid, but also we really shouldn't be testing overload resolution in the intrinsic tests because that just makes the fragile to other changes. The important part of these tests was to verify that the functions codegen to correct code, verifying overload resolution in each of these tests is just excess duplication of test coverage.

Copy link
Contributor

@bogner bogner left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just a few style nitpicks. LG!

for (unsigned I = 0; I != NumElts; ++I)
Mask.push_back(I);
if (auto *VecTy = DestTy->getAs<VectorType>()) {
SmallVector<int, 16> Mask;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It was already here, but while I see it: better to use SmallVector<int> and let it pick the size here

Comment on lines 4839 to 4849
if (auto *ToVec = ToType->getAs<VectorType>()) {
QualType TruncTy =
Context.getExtVectorType(ElType, ToVec->getNumElements());
From = ImpCastExprToType(From, TruncTy, CK_HLSLVectorTruncation,
From->getValueKind())
.get();
} else {
From = ImpCastExprToType(From, ElType, CK_HLSLVectorTruncation,
From->getValueKind())
.get();
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Arguably more readable to do something like

Suggested change
if (auto *ToVec = ToType->getAs<VectorType>()) {
QualType TruncTy =
Context.getExtVectorType(ElType, ToVec->getNumElements());
From = ImpCastExprToType(From, TruncTy, CK_HLSLVectorTruncation,
From->getValueKind())
.get();
} else {
From = ImpCastExprToType(From, ElType, CK_HLSLVectorTruncation,
From->getValueKind())
.get();
}
QualType ToType = ElType;
if (auto *ToVec = ToType->getAs<VectorType>())
ToType = Context.getExtVectorType(ElType, ToVec->getNumElements());
From = ImpCastExprToType(From, ToType, CK_HLSLVectorTruncation,
From->getValueKind())
.get();

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I changed the naming on this because ToType is already used, so this caused shadowing, but otherwise I adopted this structure.

return true;
return IsVectorElementConversion(S, FromElTy, ToElTy, ICK, From);
}
if (FromExtType && nullptr == ToExtType) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Style nit, just using ! is how this is usually done.

Suggested change
if (FromExtType && nullptr == ToExtType) {
if (FromExtType && !ToExtType) {

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yoda conditions are a shockingly hard habit to break...

Copy link
Collaborator

@efriedma-quic efriedma-quic left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ExprConstant.cpp currently assumes that an CK_HLSLVectorTruncation can't return a scalar type.

This fixes the truncation cast constant evaluation for vectors,
integers and floats.
@llvm-beanz
Copy link
Collaborator Author

ExprConstant.cpp currently assumes that an CK_HLSLVectorTruncation can't return a scalar type.

Thank you for catching this! I've updated the PR and included a test that constant evaluates some vector truncations in static asserts.

@llvm-beanz llvm-beanz merged commit a29afb7 into llvm:main Sep 11, 2024
8 checks passed
mikaelholmen added a commit that referenced this pull request Sep 12, 2024
Last use of the variable was removed in a29afb7
 [HLSL] Allow truncation to scalar (#104844)

gcc warned about this:
  ../../clang/lib/Sema/SemaOverload.cpp:2070:15: warning: unused variable 'FromExtType' [-Wunused-variable]
  2070 |     if (auto *FromExtType = FromType->getAs<ExtVectorType>()) {
       |               ^~~~~~~~~~~
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
backend:X86 clang:codegen IR generation bugs: mangling, exceptions, etc. clang:frontend Language frontend issues, e.g. anything involving "Sema" clang:headers Headers provided by Clang, e.g. for intrinsics clang Clang issues not falling into any other category HLSL HLSL Language Support
Projects
Archived in project
Development

Successfully merging this pull request may close these issues.

[HLSL] Enable truncation to scalar
5 participants