@@ -43,6 +43,8 @@ namespace id = itanium_demangle;
43
43
44
44
#define SLM_BTI 254
45
45
46
+ #define MAX_DIMS 3
47
+
46
48
namespace {
47
49
SmallPtrSet<Type *, 4 > collectGenXVolatileTypes (Module &);
48
50
void generateKernelMetadata (Module &);
@@ -846,145 +848,131 @@ static Instruction *addCastInstIfNeeded(Instruction *OldI, Instruction *NewI) {
846
848
auto CastOpcode = CastInst::getCastOpcode (NewI, false , OITy, false );
847
849
NewI = CastInst::Create (CastOpcode, NewI, OITy,
848
850
NewI->getName () + " .cast.ty" , OldI);
851
+ NewI->setDebugLoc (OldI->getDebugLoc ());
849
852
}
850
853
return NewI;
851
854
}
852
855
853
- static int getIndexForSuffix (StringRef Suff) {
854
- return llvm::StringSwitch<int >(Suff)
855
- .Case (" x" , 0 )
856
- .Case (" y" , 1 )
857
- .Case (" z" , 2 )
858
- .Default (-1 );
856
+ // / Returns the index from the given extract element instruction \p EEI.
857
+ // / It is checked here that the index is either 0, 1, or 2.
858
+ static uint64_t getIndexFromExtract (ExtractElementInst *EEI) {
859
+ Value *IndexV = EEI->getIndexOperand ();
860
+ uint64_t IndexValue = cast<ConstantInt>(IndexV)->getZExtValue ();
861
+ assert (IndexValue < MAX_DIMS &&
862
+ " Extract element index should be either 0, 1, or 2" );
863
+ return IndexValue;
859
864
}
860
865
861
- // Helper function to convert extractelement instruction associated with the
862
- // load from SPIRV builtin global, into the GenX intrinsic that returns vector
863
- // of coordinates. It also generates required extractelement and cast
864
- // instructions. Example:
865
- // %0 = load <3 x i64>, <3 x i64> addrspace(4)* addrspacecast
866
- // (<3 x i64> addrspace(1)* @__spirv_BuiltInLocalInvocationId
867
- // to <3 x i64> addrspace(4)*), align 32
868
- // %1 = extractelement <3 x i64> %0, i64 0
869
- //
870
- // =>
871
- //
872
- // %.esimd = call <3 x i32> @llvm.genx.local.id.v3i32()
873
- // %local_id.x = extractelement <3 x i32> %.esimd, i32 0
874
- // %local_id.x.cast.ty = zext i32 %local_id.x to i64
875
- static Instruction *generateVectorGenXForSpirv (ExtractElementInst *EEI,
876
- StringRef Suff,
877
- const std::string &IntrinName,
878
- StringRef ValueName) {
879
- std::string IntrName =
880
- std::string (GenXIntrinsic::getGenXIntrinsicPrefix ()) + IntrinName;
881
- auto ID = GenXIntrinsic::lookupGenXIntrinsicID (IntrName);
882
- LLVMContext &Ctx = EEI->getModule ()->getContext ();
883
- Type *I32Ty = Type::getInt32Ty (Ctx);
884
- Function *NewFDecl = GenXIntrinsic::getGenXDeclaration (
885
- EEI->getModule (), ID, {FixedVectorType::get (I32Ty, 3 )});
886
- Instruction *IntrI =
887
- IntrinsicInst::Create (NewFDecl, {}, EEI->getName () + " .esimd" , EEI);
888
- int ExtractIndex = getIndexForSuffix (Suff);
889
- assert (ExtractIndex != -1 && " Extract index is invalid." );
890
- Twine ExtractName = ValueName + Suff;
891
-
892
- Instruction *ExtrI = ExtractElementInst::Create (
893
- IntrI, ConstantInt::get (I32Ty, ExtractIndex), ExtractName, EEI);
894
- Instruction *CastI = addCastInstIfNeeded (EEI, ExtrI);
895
- if (EEI->getDebugLoc ()) {
896
- IntrI->setDebugLoc (EEI->getDebugLoc ());
897
- ExtrI->setDebugLoc (EEI->getDebugLoc ());
898
- // It's OK if ExtrI and CastI is the same instruction
899
- CastI->setDebugLoc (EEI->getDebugLoc ());
866
+ // / Generates the call of GenX intrinsic \p IntrinName and inserts it
867
+ // / right before the given extract element instruction \p EEI using the result
868
+ // / of vector load. The parameter \p IsVectorCall tells what version of GenX
869
+ // / intrinsic (scalar or vector) to use to lower the load from SPIRV global.
870
+ static Instruction *generateGenXCall (ExtractElementInst *EEI,
871
+ StringRef IntrinName, bool IsVectorCall) {
872
+ uint64_t IndexValue = getIndexFromExtract (EEI);
873
+ std::string Suffix =
874
+ IsVectorCall
875
+ ? " .v3i32"
876
+ : (Twine (" ." ) + Twine (static_cast <char >(' x' + IndexValue))).str ();
877
+ std::string FullIntrinName = (Twine (GenXIntrinsic::getGenXIntrinsicPrefix ()) +
878
+ Twine (IntrinName) + Suffix)
879
+ .str ();
880
+ auto ID = GenXIntrinsic::lookupGenXIntrinsicID (FullIntrinName);
881
+ Type *I32Ty = Type::getInt32Ty (EEI->getModule ()->getContext ());
882
+ Function *NewFDecl =
883
+ IsVectorCall
884
+ ? GenXIntrinsic::getGenXDeclaration (
885
+ EEI->getModule (), ID, FixedVectorType::get (I32Ty, MAX_DIMS))
886
+ : GenXIntrinsic::getGenXDeclaration (EEI->getModule (), ID);
887
+
888
+ std::string ResultName =
889
+ (Twine (EEI->getNameOrAsOperand ()) + " ." + FullIntrinName).str ();
890
+ Instruction *Inst = IntrinsicInst::Create (NewFDecl, {}, ResultName, EEI);
891
+ Inst->setDebugLoc (EEI->getDebugLoc ());
892
+
893
+ if (IsVectorCall) {
894
+ Type *I32Ty = Type::getInt32Ty (EEI->getModule ()->getContext ());
895
+ std::string ExtractName =
896
+ (Twine (Inst->getNameOrAsOperand ()) + " .ext." + Twine (IndexValue)).str ();
897
+ Inst = ExtractElementInst::Create (Inst, ConstantInt::get (I32Ty, IndexValue),
898
+ ExtractName, EEI);
899
+ Inst->setDebugLoc (EEI->getDebugLoc ());
900
900
}
901
- return CastI;
901
+ Inst = addCastInstIfNeeded (EEI, Inst);
902
+ return Inst;
902
903
}
903
904
904
- // Helper function to convert extractelement instruction associated with the
905
- // load from SPIRV builtin global, into the GenX intrinsic. It also generates
906
- // required cast instructions. Example:
907
- // %0 = load <3 x i64>, <3 x i64> addrspace(4)* addrspacecast (<3 x i64>
908
- // addrspace(1)* @__spirv_BuiltInWorkgroupId to <3 x i64> addrspace(4)*), align
909
- // 32 %1 = extractelement <3 x i64> %0, i64 0
910
- // =>
911
- // %0 = load <3 x i64>, <3 x i64> addrspace(4)* addrspacecast (<3 x i64>
912
- // addrspace(1)* @__spirv_BuiltInWorkgroupId to <3 x i64> addrspace(4)*), align
913
- // 32 %group.id.x = call i32 @llvm.genx.group.id.x() %group.id.x.cast.ty = zext
914
- // i32 %group.id.x to i64
915
- static Instruction *generateGenXForSpirv (ExtractElementInst *EEI,
916
- StringRef Suff,
917
- const std::string &IntrinName) {
918
- std::string IntrName = std::string (GenXIntrinsic::getGenXIntrinsicPrefix ()) +
919
- IntrinName + Suff.str ();
920
- auto ID = GenXIntrinsic::lookupGenXIntrinsicID (IntrName);
921
- Function *NewFDecl =
922
- GenXIntrinsic::getGenXDeclaration (EEI->getModule (), ID, {});
923
-
924
- Instruction *IntrI =
925
- IntrinsicInst::Create (NewFDecl, {}, IntrinName + Suff.str (), EEI);
926
- Instruction *CastI = addCastInstIfNeeded (EEI, IntrI);
927
- if (EEI->getDebugLoc ()) {
928
- IntrI->setDebugLoc (EEI->getDebugLoc ());
929
- // It's OK if IntrI and CastI is the same instruction
930
- CastI->setDebugLoc (EEI->getDebugLoc ());
905
+ // / Replaces the load \p LI of SPIRV global with corresponding call(s) of GenX
906
+ // / intrinsic(s). The users of \p LI may also be transformed if needed for
907
+ // / def/use type correctness.
908
+ // / The replaced instructions are stored into the given container
909
+ // / \p InstsToErase.
910
+ static void
911
+ translateSpirvGlobalUses (LoadInst *LI, StringRef SpirvGlobalName,
912
+ SmallVectorImpl<Instruction *> &InstsToErase) {
913
+ // TODO: Implement support for the following intrinsics:
914
+ // uint32_t __spirv_BuiltIn NumSubgroups;
915
+ // uint32_t __spirv_BuiltIn SubgroupId;
916
+
917
+ // Translate those loads from _scalar_ SPIRV globals that can be replaced with
918
+ // a const value here.
919
+ // The loads from other scalar SPIRV globals may require insertion of GenX
920
+ // calls before each user, which is done in the loop by users of 'LI' below.
921
+ Value *NewInst = nullptr ;
922
+ if (SpirvGlobalName == " SubgroupLocalInvocationId" ) {
923
+ NewInst = llvm::Constant::getNullValue (LI->getType ());
924
+ } else if (SpirvGlobalName == " SubgroupSize" ||
925
+ SpirvGlobalName == " SubgroupMaxSize" ) {
926
+ NewInst = llvm::Constant::getIntegerValue (LI->getType (),
927
+ llvm::APInt (32 , 1 , true ));
928
+ }
929
+ if (NewInst) {
930
+ LI->replaceAllUsesWith (NewInst);
931
+ InstsToErase.push_back (LI);
932
+ return ;
931
933
}
932
- return CastI;
933
- }
934
934
935
- // This function translates one occurence of SPIRV builtin use into GenX
936
- // intrinsic.
937
- static Value *translateSpirvGlobalUse (ExtractElementInst *EEI,
938
- StringRef SpirvGlobalName) {
939
- Value *IndexV = EEI->getIndexOperand ();
940
- assert (isa<ConstantInt>(IndexV) &&
941
- " Extract element index should be a constant" );
935
+ // Only loads from _vector_ SPIRV globals reach here now. Their users are
936
+ // expected to be ExtractElementInst only, and they are replaced in this loop.
937
+ // When loads from _scalar_ SPIRV globals are handled here as well, the users
938
+ // will not be replaced by new instructions, but the GenX call replacing the
939
+ // original load 'LI' should be inserted before each user.
940
+ for (User *LU : LI->users ()) {
941
+ ExtractElementInst *EEI = cast<ExtractElementInst>(LU);
942
+ NewInst = nullptr ;
943
+
944
+ if (SpirvGlobalName == " WorkgroupSize" ) {
945
+ NewInst = generateGenXCall (EEI, " local.size" , true );
946
+ } else if (SpirvGlobalName == " LocalInvocationId" ) {
947
+ NewInst = generateGenXCall (EEI, " local.id" , true );
948
+ } else if (SpirvGlobalName == " WorkgroupId" ) {
949
+ NewInst = generateGenXCall (EEI, " group.id" , false );
950
+ } else if (SpirvGlobalName == " GlobalInvocationId" ) {
951
+ // GlobalId = LocalId + WorkGroupSize * GroupId
952
+ Instruction *LocalIdI = generateGenXCall (EEI, " local.id" , true );
953
+ Instruction *WGSizeI = generateGenXCall (EEI, " local.size" , true );
954
+ Instruction *GroupIdI = generateGenXCall (EEI, " group.id" , false );
955
+ Instruction *MulI =
956
+ BinaryOperator::CreateMul (WGSizeI, GroupIdI, " mul" , EEI);
957
+ NewInst = BinaryOperator::CreateAdd (LocalIdI, MulI, " add" , EEI);
958
+ } else if (SpirvGlobalName == " GlobalSize" ) {
959
+ // GlobalSize = WorkGroupSize * NumWorkGroups
960
+ Instruction *WGSizeI = generateGenXCall (EEI, " local.size" , true );
961
+ Instruction *NumWGI = generateGenXCall (EEI, " group.count" , true );
962
+ NewInst = BinaryOperator::CreateMul (WGSizeI, NumWGI, " mul" , EEI);
963
+ } else if (SpirvGlobalName == " GlobalOffset" ) {
964
+ // TODO: Support GlobalOffset SPIRV intrinsics
965
+ // Currently all users of load of GlobalOffset are replaced with 0.
966
+ NewInst = llvm::Constant::getNullValue (EEI->getType ());
967
+ } else if (SpirvGlobalName == " NumWorkgroups" ) {
968
+ NewInst = generateGenXCall (EEI, " group.count" , true );
969
+ }
942
970
943
- // Get the suffix based on the index of extractelement instruction
944
- ConstantInt *IndexC = cast<ConstantInt>(IndexV);
945
- std::string Suff;
946
- if (IndexC->equalsInt (0 ))
947
- Suff = ' x' ;
948
- else if (IndexC->equalsInt (1 ))
949
- Suff = ' y' ;
950
- else if (IndexC->equalsInt (2 ))
951
- Suff = ' z' ;
952
- else
953
- assert (false && " Extract element index should be either 0, 1, or 2" );
954
-
955
- // Translate SPIRV into GenX intrinsic.
956
- if (SpirvGlobalName == " WorkgroupSize" ) {
957
- return generateVectorGenXForSpirv (EEI, Suff, " local.size.v3i32" , " wgsize." );
958
- } else if (SpirvGlobalName == " LocalInvocationId" ) {
959
- return generateVectorGenXForSpirv (EEI, Suff, " local.id.v3i32" , " local_id." );
960
- } else if (SpirvGlobalName == " WorkgroupId" ) {
961
- return generateGenXForSpirv (EEI, Suff, " group.id." );
962
- } else if (SpirvGlobalName == " GlobalInvocationId" ) {
963
- // GlobalId = LocalId + WorkGroupSize * GroupId
964
- Instruction *LocalIdI =
965
- generateVectorGenXForSpirv (EEI, Suff, " local.id.v3i32" , " local_id." );
966
- Instruction *WGSizeI =
967
- generateVectorGenXForSpirv (EEI, Suff, " local.size.v3i32" , " wgsize." );
968
- Instruction *GroupIdI = generateGenXForSpirv (EEI, Suff, " group.id." );
969
- Instruction *MulI =
970
- BinaryOperator::CreateMul (WGSizeI, GroupIdI, " mul" , EEI);
971
- return BinaryOperator::CreateAdd (LocalIdI, MulI, " add" , EEI);
972
- } else if (SpirvGlobalName == " GlobalSize" ) {
973
- // GlobalSize = WorkGroupSize * NumWorkGroups
974
- Instruction *WGSizeI =
975
- generateVectorGenXForSpirv (EEI, Suff, " local.size.v3i32" , " wgsize." );
976
- Instruction *NumWGI = generateVectorGenXForSpirv (
977
- EEI, Suff, " group.count.v3i32" , " group_count." );
978
- return BinaryOperator::CreateMul (WGSizeI, NumWGI, " mul" , EEI);
979
- } else if (SpirvGlobalName == " GlobalOffset" ) {
980
- // TODO: Support GlobalOffset SPIRV intrinsics
981
- return llvm::Constant::getNullValue (EEI->getType ());
982
- } else if (SpirvGlobalName == " NumWorkgroups" ) {
983
- return generateVectorGenXForSpirv (EEI, Suff, " group.count.v3i32" ,
984
- " group_count." );
971
+ assert (NewInst && " Load from global SPIRV builtin was not translated" );
972
+ EEI->replaceAllUsesWith (NewInst);
973
+ InstsToErase.push_back (EEI);
985
974
}
986
-
987
- return nullptr ;
975
+ InstsToErase.push_back (LI);
988
976
}
989
977
990
978
static void createESIMDIntrinsicArgs (const ESIMDIntrinDesc &Desc,
@@ -1370,8 +1358,7 @@ SmallPtrSet<Type *, 4> collectGenXVolatileTypes(Module &M) {
1370
1358
1371
1359
} // namespace
1372
1360
1373
- PreservedAnalyses SYCLLowerESIMDPass::run (Module &M,
1374
- ModuleAnalysisManager &) {
1361
+ PreservedAnalyses SYCLLowerESIMDPass::run (Module &M, ModuleAnalysisManager &) {
1375
1362
generateKernelMetadata (M);
1376
1363
SmallPtrSet<Type *, 4 > GVTS = collectGenXVolatileTypes (M);
1377
1364
@@ -1507,23 +1494,11 @@ size_t SYCLLowerESIMDPass::runOnFunction(Function &F,
1507
1494
1508
1495
auto PrefLen = StringRef (SPIRV_INTRIN_PREF).size ();
1509
1496
1510
- // Go through all the uses of the load instruction from SPIRV builtin
1511
- // globals, which are required to be extractelement instructions.
1512
- // Translate each of them.
1513
- for (auto *LU : LI->users ()) {
1514
- auto *EEI = dyn_cast<ExtractElementInst>(LU);
1515
- assert (EEI && " User of load from global SPIRV builtin is not an "
1516
- " extractelement instruction" );
1517
- Value *TranslatedVal = translateSpirvGlobalUse (
1518
- EEI, SpirvGlobal->getName ().drop_front (PrefLen));
1519
- assert (TranslatedVal &&
1520
- " Load from global SPIRV builtin was not translated" );
1521
- EEI->replaceAllUsesWith (TranslatedVal);
1522
- ESIMDToErases.push_back (EEI);
1523
- }
1524
- // After all users of load were translated, we get rid of the load
1525
- // itself.
1526
- ESIMDToErases.push_back (LI);
1497
+ // Translate all uses of the load instruction from SPIRV builtin global.
1498
+ // Replaces the original global load and it is uses and stores the old
1499
+ // instructions to ESIMDToErases.
1500
+ translateSpirvGlobalUses (LI, SpirvGlobal->getName ().drop_front (PrefLen),
1501
+ ESIMDToErases);
1527
1502
}
1528
1503
}
1529
1504
// Now demangle and translate found ESIMD intrinsic calls
0 commit comments