Skip to content

[DirectX] Scalarize Allocas as part of data scalarization #140165

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

farzonl
Copy link
Member

@farzonl farzonl commented May 16, 2025

- DXILDataScalarization should not just be limited to global data
- Add a scalarization for alloca
- Add ReversePostOrderTraversal of functions and iterate over basic
  blocks and run DataScalarizerVisitor.
- fixes llvm#140143
@llvmbot
Copy link
Member

llvmbot commented May 16, 2025

@llvm/pr-subscribers-backend-directx

Author: Farzon Lotfi (farzonl)

Changes
  • DXILDataScalarization should not just be limited to global data
  • Add a scalarization for alloca
  • Add ReversePostOrderTraversal of functions and iterate over basic blocks and run DataScalarizerVisitor.
  • fixes #140143

Full diff: https://github.com/llvm/llvm-project/pull/140165.diff

3 Files Affected:

  • (modified) llvm/lib/Target/DirectX/DXILDataScalarization.cpp (+54-29)
  • (modified) llvm/test/CodeGen/DirectX/scalar-bug-117273.ll (+12-6)
  • (added) llvm/test/CodeGen/DirectX/scalarize-alloca.ll (+10)
diff --git a/llvm/lib/Target/DirectX/DXILDataScalarization.cpp b/llvm/lib/Target/DirectX/DXILDataScalarization.cpp
index 1f2700ac55647..1209bcdfb2891 100644
--- a/llvm/lib/Target/DirectX/DXILDataScalarization.cpp
+++ b/llvm/lib/Target/DirectX/DXILDataScalarization.cpp
@@ -10,6 +10,7 @@
 #include "DirectX.h"
 #include "llvm/ADT/PostOrderIterator.h"
 #include "llvm/ADT/STLExtras.h"
+#include "llvm/IR/DerivedTypes.h"
 #include "llvm/IR/GlobalVariable.h"
 #include "llvm/IR/IRBuilder.h"
 #include "llvm/IR/InstVisitor.h"
@@ -40,9 +41,10 @@ static bool findAndReplaceVectors(Module &M);
 class DataScalarizerVisitor : public InstVisitor<DataScalarizerVisitor, bool> {
 public:
   DataScalarizerVisitor() : GlobalMap() {}
-  bool visit(Instruction &I);
+  bool visit(Function &F);
   // InstVisitor methods.  They return true if the instruction was scalarized,
   // false if nothing changed.
+  bool visitAllocaInst(AllocaInst &AI);
   bool visitInstruction(Instruction &I) { return false; }
   bool visitSelectInst(SelectInst &SI) { return false; }
   bool visitICmpInst(ICmpInst &ICI) { return false; }
@@ -65,11 +67,17 @@ class DataScalarizerVisitor : public InstVisitor<DataScalarizerVisitor, bool> {
 private:
   GlobalVariable *lookupReplacementGlobal(Value *CurrOperand);
   DenseMap<GlobalVariable *, GlobalVariable *> GlobalMap;
+  static bool isArrayOfVectors(Type *T);
 };
 
-bool DataScalarizerVisitor::visit(Instruction &I) {
-  assert(!GlobalMap.empty());
-  return InstVisitor::visit(I);
+bool DataScalarizerVisitor::visit(Function &F) {
+  bool MadeChange = false;
+  ReversePostOrderTraversal<Function *> RPOT(&F);
+  for (BasicBlock *BB : make_early_inc_range(RPOT)) {
+    for (Instruction &I : make_early_inc_range(*BB))
+      MadeChange |= InstVisitor::visit(I);
+  }
+  return MadeChange;
 }
 
 GlobalVariable *
@@ -83,6 +91,42 @@ DataScalarizerVisitor::lookupReplacementGlobal(Value *CurrOperand) {
   return nullptr; // Not found
 }
 
+// Recursively Creates and Array like version of the given vector like type.
+static Type *replaceVectorWithArray(Type *T, LLVMContext &Ctx) {
+  if (auto *VecTy = dyn_cast<VectorType>(T))
+    return ArrayType::get(VecTy->getElementType(),
+                          dyn_cast<FixedVectorType>(VecTy)->getNumElements());
+  if (auto *ArrayTy = dyn_cast<ArrayType>(T)) {
+    Type *NewElementType =
+        replaceVectorWithArray(ArrayTy->getElementType(), Ctx);
+    return ArrayType::get(NewElementType, ArrayTy->getNumElements());
+  }
+  // If it's not a vector or array, return the original type.
+  return T;
+}
+
+bool DataScalarizerVisitor::isArrayOfVectors(Type *T) {
+  if (ArrayType *ArrType = dyn_cast<ArrayType>(T))
+    return isa<VectorType>(ArrType->getElementType());
+  return false;
+}
+
+bool DataScalarizerVisitor::visitAllocaInst(AllocaInst &AI) {
+  if (!isArrayOfVectors(AI.getAllocatedType()))
+    return false;
+
+  ArrayType *ArrType = cast<ArrayType>(AI.getAllocatedType());
+  IRBuilder<> Builder(&AI);
+  LLVMContext &Ctx = AI.getContext();
+  Type *NewType = replaceVectorWithArray(ArrType, Ctx);
+  AllocaInst *ArrAlloca =
+      Builder.CreateAlloca(NewType, nullptr, AI.getName() + ".scalarize");
+  ArrAlloca->setAlignment(AI.getAlign());
+  AI.replaceAllUsesWith(ArrAlloca);
+  AI.eraseFromParent();
+  return true;
+}
+
 bool DataScalarizerVisitor::visitLoadInst(LoadInst &LI) {
   unsigned NumOperands = LI.getNumOperands();
   for (unsigned I = 0; I < NumOperands; ++I) {
@@ -154,20 +198,6 @@ bool DataScalarizerVisitor::visitGetElementPtrInst(GetElementPtrInst &GEPI) {
   return true;
 }
 
-// Recursively Creates and Array like version of the given vector like type.
-static Type *replaceVectorWithArray(Type *T, LLVMContext &Ctx) {
-  if (auto *VecTy = dyn_cast<VectorType>(T))
-    return ArrayType::get(VecTy->getElementType(),
-                          dyn_cast<FixedVectorType>(VecTy)->getNumElements());
-  if (auto *ArrayTy = dyn_cast<ArrayType>(T)) {
-    Type *NewElementType =
-        replaceVectorWithArray(ArrayTy->getElementType(), Ctx);
-    return ArrayType::get(NewElementType, ArrayTy->getNumElements());
-  }
-  // If it's not a vector or array, return the original type.
-  return T;
-}
-
 Constant *transformInitializer(Constant *Init, Type *OrigType, Type *NewType,
                                LLVMContext &Ctx) {
   // Handle ConstantAggregateZero (zero-initialized constants)
@@ -253,20 +283,15 @@ static bool findAndReplaceVectors(Module &M) {
       // Note: we want to do G.replaceAllUsesWith(NewGlobal);, but it assumes
       // type equality. Instead we will use the visitor pattern.
       Impl.GlobalMap[&G] = NewGlobal;
-      for (User *U : make_early_inc_range(G.users())) {
-        if (isa<ConstantExpr>(U) && isa<Operator>(U)) {
-          ConstantExpr *CE = cast<ConstantExpr>(U);
-          for (User *UCE : make_early_inc_range(CE->users())) {
-            if (Instruction *Inst = dyn_cast<Instruction>(UCE))
-              Impl.visit(*Inst);
-          }
-        }
-        if (Instruction *Inst = dyn_cast<Instruction>(U))
-          Impl.visit(*Inst);
-      }
     }
   }
 
+  for (auto &F : make_early_inc_range(M.functions())) {
+    if (F.isDeclaration())
+      continue;
+    MadeChange |= Impl.visit(F);
+  }
+
   // Remove the old globals after the iteration
   for (auto &[Old, New] : Impl.GlobalMap) {
     Old->eraseFromParent();
diff --git a/llvm/test/CodeGen/DirectX/scalar-bug-117273.ll b/llvm/test/CodeGen/DirectX/scalar-bug-117273.ll
index 25dc2c36b4e1f..2676abec1d8ae 100644
--- a/llvm/test/CodeGen/DirectX/scalar-bug-117273.ll
+++ b/llvm/test/CodeGen/DirectX/scalar-bug-117273.ll
@@ -8,12 +8,18 @@
 define internal void @main() #1 {
 ; CHECK-LABEL: define internal void @main() {
 ; CHECK-NEXT:  [[ENTRY:.*:]]
-; CHECK-NEXT:    [[DOTI0:%.*]] = load float, ptr getelementptr inbounds ([24 x float], ptr @StaticArr.scalarized.1dim, i32 1), align 16
-; CHECK-NEXT:    [[DOTI1:%.*]] = load float, ptr getelementptr (float, ptr getelementptr inbounds ([24 x float], ptr @StaticArr.scalarized.1dim, i32 1), i32 1), align 4
-; CHECK-NEXT:    [[DOTI2:%.*]] = load float, ptr getelementptr (float, ptr getelementptr inbounds ([24 x float], ptr @StaticArr.scalarized.1dim, i32 1), i32 2), align 8
-; CHECK-NEXT:    [[DOTI01:%.*]] = load float, ptr getelementptr inbounds ([24 x float], ptr @StaticArr.scalarized.1dim, i32 2), align 16
-; CHECK-NEXT:    [[DOTI12:%.*]] = load float, ptr getelementptr (float, ptr getelementptr inbounds ([24 x float], ptr @StaticArr.scalarized.1dim, i32 2), i32 1), align 4
-; CHECK-NEXT:    [[DOTI23:%.*]] = load float, ptr getelementptr (float, ptr getelementptr inbounds ([24 x float], ptr @StaticArr.scalarized.1dim, i32 2), i32 2), align 8
+; CHECK-NEXT:    [[TMP0:%.*]] = getelementptr inbounds [24 x float], ptr @StaticArr.scalarized.1dim, i32 1
+; CHECK-NEXT:    [[DOTI0:%.*]] = load float, ptr [[TMP0]], align 16
+; CHECK-NEXT:    [[DOTI1:%.*]] = getelementptr float, ptr [[TMP0]], i32 1
+; CHECK-NEXT:    [[DOTI11:%.*]] = load float, ptr [[DOTI1]], align 4
+; CHECK-NEXT:    [[DOTI2:%.*]] = getelementptr float, ptr [[TMP0]], i32 2
+; CHECK-NEXT:    [[DOTI22:%.*]] = load float, ptr [[DOTI2]], align 8
+; CHECK-NEXT:    [[TMP1:%.*]] = getelementptr inbounds [24 x float], ptr @StaticArr.scalarized.1dim, i32 2
+; CHECK-NEXT:    [[DOTI03:%.*]] = load float, ptr [[TMP1]], align 16
+; CHECK-NEXT:    [[DOTI14:%.*]] = getelementptr float, ptr [[TMP1]], i32 1
+; CHECK-NEXT:    [[DOTI15:%.*]] = load float, ptr [[DOTI14]], align 4
+; CHECK-NEXT:    [[DOTI26:%.*]] = getelementptr float, ptr [[TMP1]], i32 2
+; CHECK-NEXT:    [[DOTI27:%.*]] = load float, ptr [[DOTI26]], align 8
 ; CHECK-NEXT:    ret void
 ;
 entry:
diff --git a/llvm/test/CodeGen/DirectX/scalarize-alloca.ll b/llvm/test/CodeGen/DirectX/scalarize-alloca.ll
new file mode 100644
index 0000000000000..4829f3a31791f
--- /dev/null
+++ b/llvm/test/CodeGen/DirectX/scalarize-alloca.ll
@@ -0,0 +1,10 @@
+; RUN: opt -S -passes='dxil-data-scalarization' -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s --check-prefix=SCHECK
+; RUN: opt -S -passes='dxil-data-scalarization,dxil-flatten-arrays' -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s --check-prefix=FCHECK
+
+; CHECK-LABEL: alloca_2d__vec_test
+define void @alloca_2d__vec_test() local_unnamed_addr #2 {
+  ; SCHECK:  alloca [2 x [4 x i32]], align 16
+  ; FCHECK:  alloca [8 x i32], align 16
+  %1 = alloca [2 x <4 x i32>], align 16
+  ret void
+}

@@ -8,12 +8,18 @@
define internal void @main() #1 {
; CHECK-LABEL: define internal void @main() {
; CHECK-NEXT: [[ENTRY:.*:]]
; CHECK-NEXT: [[DOTI0:%.*]] = load float, ptr getelementptr inbounds ([24 x float], ptr @StaticArr.scalarized.1dim, i32 1), align 16
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was an unexpected test update, Both the before and after both look right.

This behavior seems to have something to do with the fact that we aren't walking uses and doing cast<ConstantExpr>(U); anymore. Instrad walking instructons in basic blocks via for (Instruction &I : make_early_inc_range(*BB)) seems to force the ConstantExpr to instructions so updated these tests.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a little surprising but it seems fine. Do we also have tests where the GEP was a constant expression in the first place to make sure that that's still doing the right thing?

@farzonl farzonl self-assigned this May 16, 2025
@farzonl farzonl moved this to Active in HLSL Support May 16, 2025
@@ -83,6 +91,42 @@ DataScalarizerVisitor::lookupReplacementGlobal(Value *CurrOperand) {
return nullptr; // Not found
}

// Recursively Creates and Array like version of the given vector like type.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: 'creates an array'

@@ -65,11 +67,17 @@ class DataScalarizerVisitor : public InstVisitor<DataScalarizerVisitor, bool> {
private:
GlobalVariable *lookupReplacementGlobal(Value *CurrOperand);
DenseMap<GlobalVariable *, GlobalVariable *> GlobalMap;
static bool isArrayOfVectors(Type *T);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't see why this is static?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't have a great reason. I tend to make things static when I want to ensure helpers remain stateless. It doesn't need access to the this pointer and I don't want to allow method overriding. The only thing I thought might be useful was the private visibility modifier.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does it need to be a member function of DataScalarizerVisitor at all? I think it'd be clearer to just make this a static freestanding function.

@@ -83,6 +91,42 @@ DataScalarizerVisitor::lookupReplacementGlobal(Value *CurrOperand) {
return nullptr; // Not found
}

// Recursively Creates and Array like version of the given vector like type.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: I know the original comment read like this but I think it may as well be changed now to make more sense.

Suggested change
// Recursively Creates and Array like version of the given vector like type.
// Recursively creates an array-like version of the given vector type.

Copy link
Contributor

@Icohedron Icohedron left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, just had a couple of nits.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
Status: Active
Development

Successfully merging this pull request may close these issues.

[DirectX] Arrays of vectors remain in alloca and global variables after scalarization
5 participants