@@ -41,6 +41,7 @@ namespace occa {
41
41
device::device (const occa::properties &properties_) :
42
42
occa::modeDevice_t (properties_) {
43
43
44
+ hipDeviceProp_t props;
44
45
if (!properties.has (" wrapped" )) {
45
46
OCCA_ERROR (" [HIP] device not given a [device_id] integer" ,
46
47
properties.has (" device_id" ) &&
@@ -53,6 +54,9 @@ namespace occa {
53
54
54
55
OCCA_HIP_ERROR (" Device: Creating Context" ,
55
56
hipCtxCreate (&hipContext, 0 , hipDevice));
57
+
58
+ OCCA_HIP_ERROR (" Getting device properties" ,
59
+ hipGetDeviceProperties (&props, deviceID));
56
60
}
57
61
58
62
p2pEnabled = false ;
@@ -82,6 +86,7 @@ namespace occa {
82
86
83
87
archMajorVersion = properties.get (" hip/arch/major" , archMajorVersion);
84
88
archMinorVersion = properties.get (" hip/arch/minor" , archMinorVersion);
89
+ properties[" kernel/target" ] = toString (props.gcnArch );
85
90
86
91
properties[" kernel/verbose" ] = properties.get (" verbose" , false );
87
92
}
@@ -339,6 +344,17 @@ namespace occa {
339
344
lock);
340
345
}
341
346
347
+ void device::setArchCompilerFlags (occa::properties &kernelProps) {
348
+ if (kernelProps.get <std::string>(" compiler_flags" ).find (" -t gfx" ) == std::string::npos) {
349
+ std::stringstream ss;
350
+ std::string arch = kernelProps[" target" ];
351
+ if (arch.size ()) {
352
+ ss << " -t gfx" << arch << ' ' ;
353
+ kernelProps[" compiler_flags" ] += ss.str ();
354
+ }
355
+ }
356
+ }
357
+
342
358
void device::compileKernel (const std::string &hashDir,
343
359
const std::string &kernelName,
344
360
occa::properties &kernelProps,
@@ -350,6 +366,8 @@ namespace occa {
350
366
std::string binaryFilename = hashDir + kc::binaryFile;
351
367
const std::string ptxBinaryFilename = hashDir + " ptx_binary.o" ;
352
368
369
+ setArchCompilerFlags (kernelProps);
370
+
353
371
std::stringstream command;
354
372
355
373
// ---[ Compiling Command ]--------
@@ -358,7 +376,7 @@ namespace occa {
358
376
<< " --genco "
359
377
<< " " << sourceFilename
360
378
<< " -o " << binaryFilename
361
- << ' ' << kernelProps[" compilerFlags " ]
379
+ << ' ' << kernelProps[" compiler_flags " ]
362
380
#if (OCCA_OS == OCCA_WINDOWS_OS)
363
381
<< " -D OCCA_OS=OCCA_WINDOWS_OS -D _MSC_VER=1800"
364
382
#endif
0 commit comments