16
16
#include " llvm/ExecutionEngine/Orc/LLJIT.h"
17
17
#include " llvm/Support/Error.h"
18
18
19
+ #include " mlir/Dialect/DLTI/DLTI.h"
19
20
#include " mlir/Dialect/Func/IR/FuncOps.h"
20
21
#include " mlir/Dialect/LLVMIR/LLVMDialect.h"
22
+ #include " mlir/Interfaces/DataLayoutInterfaces.h"
21
23
#include " mlir/Pass/PassManager.h"
22
24
23
25
namespace mlir ::gc::gpu {
@@ -655,7 +657,8 @@ OclModule::~OclModule() {
655
657
// buffers. The function will call the original function with the context,
656
658
// buffers and the offset/shape/strides, statically created from the
657
659
// memref descriptor.
658
- StringRef createStaticMain (ModuleOp &module , const StringRef &funcName,
660
+ StringRef createStaticMain (OpBuilder &builder, ModuleOp &module ,
661
+ const StringRef &funcName,
659
662
const ArrayRef<Type> argTypes) {
660
663
auto mainFunc = module .lookupSymbol <LLVM::LLVMFuncOp>(funcName);
661
664
if (!mainFunc) {
@@ -670,11 +673,8 @@ StringRef createStaticMain(ModuleOp &module, const StringRef &funcName,
670
673
" ' must have an least 3 arguments." );
671
674
}
672
675
673
- auto ctx = module .getContext ();
674
- ctx->getOrLoadDialect <LLVM::LLVMDialect>();
675
- OpBuilder builder (ctx);
676
676
auto i64Type = builder.getI64Type ();
677
- auto ptrType = LLVM::LLVMPointerType::get (ctx );
677
+ auto ptrType = LLVM::LLVMPointerType::get (builder. getContext () );
678
678
679
679
if (mainArgTypes[nargs - 3 ] != ptrType ||
680
680
mainArgTypes[nargs - 2 ] != ptrType ||
@@ -722,7 +722,7 @@ StringRef createStaticMain(ModuleOp &module, const StringRef &funcName,
722
722
auto loc = mainFunc.getLoc ();
723
723
auto newFuncType = LLVM::LLVMFunctionType::get (
724
724
mainFunc.getNumResults () ? mainFunc->getResult (0 ).getType ()
725
- : LLVM::LLVMVoidType::get (ctx ),
725
+ : LLVM::LLVMVoidType::get (builder. getContext () ),
726
726
{ptrType, ptrType});
727
727
auto newFunc =
728
728
OpBuilder::atBlockEnd (module .getBody ())
@@ -848,17 +848,58 @@ OclModuleBuilder::build(cl_device_id device, cl_context context) {
848
848
849
849
llvm::Expected<std::shared_ptr<const OclModule>>
850
850
OclModuleBuilder::build (const OclRuntime::Ext &ext) {
851
- auto mod = mlirModule.clone ();
852
- PassManager pm{mod.getContext ()};
853
- pipeline (pm);
854
- CHECK (!pm.run (mod).failed (), " GPU pipeline failed!" );
851
+ auto ctx = mlirModule.getContext ();
852
+ ctx->getOrLoadDialect <DLTIDialect>();
853
+ ctx->getOrLoadDialect <LLVM::LLVMDialect>();
854
+ OpBuilder builder (ctx);
855
+ DataLayoutEntryInterface dltiAttrs[6 ];
855
856
856
- auto staticMain = createStaticMain (mod, funcName, argTypes);
857
+ {
858
+ struct DevInfo {
859
+ cl_device_info key;
860
+ const char *attrName;
861
+ };
862
+ DevInfo devInfo[]{
863
+ {CL_DEVICE_MAX_COMPUTE_UNITS, " num_exec_units" },
864
+ {CL_DEVICE_NUM_EUS_PER_SUB_SLICE_INTEL, " num_exec_units_per_slice" },
865
+ {CL_DEVICE_NUM_THREADS_PER_EU_INTEL, " num_threads_per_eu" },
866
+ // Assuming the cache size is equal to the local mem
867
+ {CL_DEVICE_LOCAL_MEM_SIZE, " L1_cache_size_in_bytes" },
868
+ };
857
869
858
- if (printIr) {
859
- mod.dump ();
860
- }
870
+ unsigned i = 0 ;
871
+ for (auto &[key, attrName] : devInfo) {
872
+ int64_t value = 0 ;
873
+ CL_CHECK (
874
+ clGetDeviceInfo (ext.device , key, sizeof (cl_ulong), &value, nullptr ),
875
+ " Failed to get the device property " , attrName);
876
+ gcLogD (" Device property " , attrName, " =" , value);
877
+ dltiAttrs[i++] =
878
+ DataLayoutEntryAttr::get (ctx, builder.getStringAttr (attrName),
879
+ builder.getI64IntegerAttr (value));
880
+ }
861
881
882
+ // There is no a corresponding property in the OpenCL API, using the
883
+ // hardcoded value.
884
+ // TODO: Get the real value.
885
+ dltiAttrs[i] = DataLayoutEntryAttr::get (
886
+ ctx, builder.getStringAttr (" max_vector_op_width" ),
887
+ builder.getI64IntegerAttr (512 ));
888
+ }
889
+
890
+ OclRuntime rt (ext);
891
+ auto expectedQueue = rt.createQueue ();
892
+ CHECKE (expectedQueue, " Failed to create queue!" );
893
+ struct OclQueue {
894
+ cl_command_queue queue;
895
+ ~OclQueue () { clReleaseCommandQueue (queue); }
896
+ } queue{*expectedQueue};
897
+ OclContext oclCtx{rt, queue.queue , false };
898
+
899
+ ModuleOp mod;
900
+ StringRef staticMain;
901
+ std::unique_ptr<ExecutionEngine> eng;
902
+ auto devStr = builder.getStringAttr (" GPU" /* device ID*/ );
862
903
ExecutionEngineOptions opts;
863
904
opts.jitCodeGenOptLevel = llvm::CodeGenOptLevel::Aggressive;
864
905
opts.enableObjectDump = enableObjectDump;
@@ -868,18 +909,75 @@ OclModuleBuilder::build(const OclRuntime::Ext &ext) {
868
909
opts.enablePerfNotificationListener = false ;
869
910
#endif
870
911
871
- auto eng = ExecutionEngine::create (mod, opts);
872
- CHECKE (eng, " Failed to create ExecutionEngine!" );
873
- eng->get ()->registerSymbols (OclRuntime::Exports::symbolMap);
912
+ // Build the module and check the kernels workgroup size. If the workgroup
913
+ // size is different, rebuild the module with the new size.
914
+ for (size_t wgSize = 64 ;;) {
915
+ dltiAttrs[sizeof (dltiAttrs) / sizeof (DataLayoutEntryInterface) - 1 ] =
916
+ DataLayoutEntryAttr::get (
917
+ ctx, builder.getStringAttr (" max_work_group_size" ),
918
+ builder.getI64IntegerAttr (static_cast <int64_t >(wgSize)));
919
+ TargetDeviceSpecInterface devSpec =
920
+ TargetDeviceSpecAttr::get (ctx, dltiAttrs);
921
+ auto sysSpec =
922
+ TargetSystemSpecAttr::get (ctx, ArrayRef (std::pair (devStr, devSpec)));
923
+ mod = mlirModule.clone ();
924
+ mod.getOperation ()->setAttr (" #dlti.sys_spec" , sysSpec);
925
+ PassManager pm{ctx};
926
+ pipeline (pm);
927
+ CHECK (!pm.run (mod).failed (), " GPU pipeline failed!" );
928
+ staticMain = createStaticMain (builder, mod, funcName, argTypes);
929
+ auto expectedEng = ExecutionEngine::create (mod, opts);
930
+ CHECKE (expectedEng, " Failed to create ExecutionEngine!" );
931
+ expectedEng->get ()->registerSymbols (OclRuntime::Exports::symbolMap);
932
+
933
+ // Find all kernels and query the workgroup size
934
+ size_t minSize = std::numeric_limits<size_t >::max ();
935
+ mod.walk <>([&](LLVM::LLVMFuncOp func) {
936
+ auto name = func.getName ();
937
+ if (!name.starts_with (" createGcGpuOclKernel_" )) {
938
+ return WalkResult::skip ();
939
+ }
940
+ auto fn = expectedEng.get ()->lookup (name);
941
+ if (!fn) {
942
+ gcLogE (" Function not found: " , name.data ());
943
+ return WalkResult::skip ();
944
+ }
945
+
946
+ Kernel *kernel =
947
+ reinterpret_cast <Kernel *(*)(OclContext *)>(fn.get ())(&oclCtx);
948
+ size_t s = 0 ;
949
+ auto err = clGetKernelWorkGroupInfo (kernel->kernel , ext.device ,
950
+ CL_KERNEL_WORK_GROUP_SIZE,
951
+ sizeof (size_t ), &s, nullptr );
952
+ if (err == CL_SUCCESS) {
953
+ minSize = std::min (minSize, s);
954
+ } else {
955
+ gcLogE (" Failed to get the kernel workgroup size: " , err);
956
+ }
957
+ return WalkResult::skip ();
958
+ });
959
+
960
+ if (minSize == std::numeric_limits<size_t >::max () || minSize == wgSize) {
961
+ eng = std::move (*expectedEng);
962
+ break ;
963
+ }
964
+
965
+ gcLogD (" Changing the workgroup size to " , minSize);
966
+ wgSize = minSize;
967
+ }
968
+
969
+ if (printIr) {
970
+ mod.dump ();
971
+ }
874
972
875
973
OclModule::MainFunc main = {nullptr };
876
974
877
975
if (staticMain.empty ()) {
878
- auto expect = eng. get () ->lookupPacked (funcName);
976
+ auto expect = eng->lookupPacked (funcName);
879
977
CHECKE (expect, " Packed function '" , funcName.begin (), " ' not found!" );
880
978
main.wrappedMain = *expect;
881
979
} else {
882
- auto expect = eng. get () ->lookup (staticMain);
980
+ auto expect = eng->lookup (staticMain);
883
981
CHECKE (expect, " Compiled function '" , staticMain.begin (), " ' not found!" );
884
982
main.staticMain = reinterpret_cast <OclModule::StaticMainFunc>(*expect);
885
983
}
@@ -889,8 +987,7 @@ OclModuleBuilder::build(const OclRuntime::Ext &ext) {
889
987
return it->second ;
890
988
}
891
989
std::shared_ptr<const OclModule> ptr (
892
- new OclModule (OclRuntime (ext), !staticMain.empty (), main, argTypes,
893
- std::move (eng.get ())));
990
+ new OclModule (rt, !staticMain.empty (), main, argTypes, std::move (eng)));
894
991
return cache.emplace (OclDevCtxPair (ext.device , ext.context ), ptr)
895
992
.first ->second ;
896
993
}
0 commit comments