Skip to content

Commit d9b2dec

Browse files
authored
Merge pull request #2 from ROCmSoftwarePlatform/mlir-rocblas-succeeds
inc files added
2 parents dc0829b + 8563310 commit d9b2dec

File tree

5 files changed

+113
-92
lines changed

5 files changed

+113
-92
lines changed

third_party/hip/hip_stub.cc.inc

Lines changed: 16 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -12,16 +12,6 @@ hipError_t hipRuntimeGetVersion(int* runtimeVersion) {
1212
"hipRuntimeGetVersion", runtimeVersion);
1313
}
1414

15-
hipError_t hipGetLastError(void) {
16-
return DynamicCall<decltype(hipGetLastError), &hipGetLastError>(
17-
"hipGetLastError");
18-
}
19-
20-
hipError_t hipPeekAtLastError(void) {
21-
return DynamicCall<decltype(hipPeekAtLastError), &hipPeekAtLastError>(
22-
"hipPeekAtLastError");
23-
}
24-
2515
hipError_t hipDeviceGet(hipDevice_t* device, int ordinal) {
2616
return DynamicCall<decltype(hipDeviceGet), &hipDeviceGet>("hipDeviceGet",
2717
device, ordinal);
@@ -73,6 +63,16 @@ hipError_t hipDeviceGetLimit(size_t* pValue, enum hipLimit_t limit) {
7363
"hipDeviceGetLimit", pValue, limit);
7464
}
7565

66+
hipError_t hipGetLastError(void) {
67+
return DynamicCall<decltype(hipGetLastError), &hipGetLastError>(
68+
"hipGetLastError");
69+
}
70+
71+
hipError_t hipPeekAtLastError(void) {
72+
return DynamicCall<decltype(hipPeekAtLastError), &hipPeekAtLastError>(
73+
"hipPeekAtLastError");
74+
}
75+
7676
hipError_t hipStreamCreateWithFlags(hipStream_t* stream, unsigned int flags) {
7777
return DynamicCall<decltype(hipStreamCreateWithFlags),
7878
&hipStreamCreateWithFlags>("hipStreamCreateWithFlags",
@@ -209,6 +209,12 @@ hipError_t hipMemcpy(void* dst, const void* src, size_t sizeBytes,
209209
sizeBytes, kind);
210210
}
211211

212+
hipError_t hipModuleGetGlobal(hipDeviceptr_t* dptr, size_t* bytes,
213+
hipModule_t hmod, const char* name) {
214+
return DynamicCall<decltype(hipModuleGetGlobal), &hipModuleGetGlobal>(
215+
"hipModuleGetGlobal", dptr, bytes, hmod, name);
216+
}
217+
212218
hipError_t hipMemcpyAsync(void* dst, const void* src, size_t sizeBytes,
213219
hipMemcpyKind kind, hipStream_t stream __dparm(0)) {
214220
return DynamicCall<decltype(hipMemcpyAsync), &hipMemcpyAsync>(
@@ -376,13 +382,6 @@ hipError_t hipModuleGetFunction(hipFunction_t* function, hipModule_t module,
376382
"hipModuleGetFunction", function, module, kname);
377383
}
378384

379-
hipError_t hipModuleGetGlobal(void** ptr, size_t* bytes, hipModule_t module,
380-
const char* kname) {
381-
return DynamicCall<decltype(hipModuleGetGlobal), &hipModuleGetGlobal>(
382-
"hipModuleGetGlobal", ptr, bytes, module, kname);
383-
}
384-
385-
386385
hipError_t hipFuncGetAttributes(struct hipFuncAttributes* attr,
387386
const void* func) {
388387
return DynamicCall<decltype(hipFuncGetAttributes), &hipFuncGetAttributes>(

third_party/hip/hip_stub.h.inc

Lines changed: 27 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -62,12 +62,22 @@ enum hipError_t {
6262
hipErrorPeerAccessAlreadyEnabled = 704,
6363
hipErrorPeerAccessNotEnabled = 705,
6464
hipErrorSetOnActiveProcess = 708,
65+
hipErrorContextIsDestroyed = 709,
6566
hipErrorAssert = 710,
6667
hipErrorHostMemoryAlreadyRegistered = 712,
6768
hipErrorHostMemoryNotRegistered = 713,
6869
hipErrorLaunchFailure = 719,
6970
hipErrorCooperativeLaunchTooLarge = 720,
7071
hipErrorNotSupported = 801,
72+
hipErrorStreamCaptureUnsupported = 900,
73+
hipErrorStreamCaptureInvalidated = 901,
74+
hipErrorStreamCaptureMerge = 902,
75+
hipErrorStreamCaptureUnmatched = 903,
76+
hipErrorStreamCaptureUnjoined = 904,
77+
hipErrorStreamCaptureIsolation = 905,
78+
hipErrorStreamCaptureImplicit = 906,
79+
hipErrorCapturedEvent = 907,
80+
hipErrorStreamCaptureWrongThread = 908,
7181
hipErrorUnknown = 999,
7282
hipErrorRuntimeMemory = 1052,
7383
hipErrorRuntimeOther = 1053,
@@ -153,6 +163,7 @@ enum hipFunction_attribute {
153163
};
154164

155165
enum hipLimit_t {
166+
hipLimitPrintfFifoSize = 0x01,
156167
hipLimitMallocHeapSize = 0x02,
157168
};
158169

@@ -196,10 +207,6 @@ hipError_t hipDriverGetVersion(int* driverVersion);
196207

197208
hipError_t hipRuntimeGetVersion(int* runtimeVersion);
198209

199-
hipError_t hipGetLastError(void);
200-
201-
hipError_t hipPeekAtLastError(void);
202-
203210
hipError_t hipDeviceGet(hipDevice_t* device, int ordinal);
204211

205212
hipError_t hipDeviceGetName(char* name, int len, hipDevice_t device);
@@ -221,6 +228,10 @@ hipError_t hipGetDeviceProperties(hipDeviceProp_t* prop, int deviceId);
221228

222229
hipError_t hipDeviceGetLimit(size_t* pValue, enum hipLimit_t limit);
223230

231+
hipError_t hipGetLastError(void);
232+
233+
hipError_t hipPeekAtLastError(void);
234+
224235
hipError_t hipStreamCreateWithFlags(hipStream_t* stream, unsigned int flags);
225236

226237
hipError_t hipStreamCreateWithPriority(hipStream_t* stream, unsigned int flags,
@@ -278,6 +289,9 @@ hipError_t hipHostFree(void* ptr);
278289
hipError_t hipMemcpy(void* dst, const void* src, size_t sizeBytes,
279290
hipMemcpyKind kind);
280291

292+
hipError_t hipModuleGetGlobal(hipDeviceptr_t* dptr, size_t* bytes,
293+
hipModule_t hmod, const char* name);
294+
281295
hipError_t hipMemcpyAsync(void* dst, const void* src, size_t sizeBytes,
282296
hipMemcpyKind kind, hipStream_t stream __dparm(0));
283297

@@ -353,9 +367,6 @@ hipError_t hipModuleUnload(hipModule_t module);
353367
hipError_t hipModuleGetFunction(hipFunction_t* function, hipModule_t module,
354368
const char* kname);
355369

356-
hipError_t hipModuleGetGlobal(void** ptr, size_t* bytes, hipModule_t module,
357-
const char* kname);
358-
359370
hipError_t hipFuncGetAttributes(struct hipFuncAttributes* attr,
360371
const void* func);
361372

@@ -385,3 +396,12 @@ hipError_t hipOccupancyMaxPotentialBlockSize(int* gridSize, int* blockSize,
385396
const void* f,
386397
size_t dynSharedMemPerBlk,
387398
int blockSizeLimit);
399+
400+
enum hipDataType {
401+
HIP_R_16F = 2,
402+
HIP_R_32F = 0,
403+
HIP_R_64F = 1,
404+
HIP_C_16F = 6,
405+
HIP_C_32F = 4,
406+
HIP_C_64F = 5,
407+
};

third_party/hip/miopen_stub.h.inc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ typedef enum {
4646
miopenInt8 = 3,
4747
miopenInt8x4 = 4,
4848
miopenBFloat16 = 5,
49+
miopenDouble = 6,
4950
} miopenDataType_t;
5051

5152
typedef enum {

third_party/hip/rocblas_stub.cc.inc

Lines changed: 42 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -34,48 +34,6 @@ ROCBLAS_EXPORT rocblas_status rocblas_get_pointer_mode(
3434
handle, pointer_mode);
3535
}
3636

37-
ROCBLAS_EXPORT rocblas_status rocblas_gemm_ex(
38-
rocblas_handle handle, rocblas_operation transA, rocblas_operation transB,
39-
rocblas_int m, rocblas_int n, rocblas_int k, const void* alpha,
40-
const void* a, rocblas_datatype a_type, rocblas_int lda, const void* b,
41-
rocblas_datatype b_type, rocblas_int ldb, const void* beta, const void* c,
42-
rocblas_datatype c_type, rocblas_int ldc, void* d, rocblas_datatype d_type,
43-
rocblas_int ldd, rocblas_datatype compute_type, rocblas_gemm_algo algo,
44-
int32_t solution_index, uint32_t flags) {
45-
return DynamicCall<decltype(rocblas_gemm_ex), &rocblas_gemm_ex>(
46-
"rocblas_gemm_ex", handle, transA, transB, m, n, k, alpha, a, a_type, lda,
47-
b, b_type, ldb, beta, c, c_type, ldc, d, d_type, ldd, compute_type, algo,
48-
solution_index, flags);
49-
}
50-
51-
ROCBLAS_EXPORT rocblas_status rocblas_gemm_strided_batched_ex(
52-
rocblas_handle handle, rocblas_operation transA, rocblas_operation transB,
53-
rocblas_int m, rocblas_int n, rocblas_int k, const void* alpha,
54-
const void* a, rocblas_datatype a_type, rocblas_int lda,
55-
rocblas_stride stride_a, const void* b, rocblas_datatype b_type,
56-
rocblas_int ldb, rocblas_stride stride_b, const void* beta, const void* c,
57-
rocblas_datatype c_type, rocblas_int ldc, rocblas_stride stride_c, void* d,
58-
rocblas_datatype d_type, rocblas_int ldd, rocblas_stride stride_d,
59-
rocblas_int batch_count, rocblas_datatype compute_type,
60-
rocblas_gemm_algo algo, int32_t solution_index, uint32_t flags) {
61-
return DynamicCall<decltype(rocblas_gemm_strided_batched_ex),
62-
&rocblas_gemm_strided_batched_ex>(
63-
"rocblas_gemm_strided_batched_ex", handle, transA, transB, m, n, k, alpha,
64-
a, a_type, lda, stride_a, b, b_type, ldb, stride_b, beta, c, c_type, ldc,
65-
stride_c, d, d_type, ldd, stride_d, batch_count, compute_type, algo,
66-
solution_index, flags);
67-
}
68-
69-
ROCBLAS_EXPORT rocblas_status rocblas_axpy_ex(
70-
rocblas_handle handle, rocblas_int n, const void* alpha,
71-
rocblas_datatype alpha_type, const void* x, rocblas_datatype x_type,
72-
rocblas_int incx, void* y, rocblas_datatype y_type, rocblas_int incy,
73-
rocblas_datatype execution_type) {
74-
return DynamicCall<decltype(rocblas_axpy_ex), &rocblas_axpy_ex>(
75-
"rocblas_axpy_ex", handle, n, alpha, alpha_type, x, x_type, incx, y,
76-
y_type, incy, execution_type);
77-
}
78-
7937
ROCBLAS_EXPORT rocblas_status rocblas_strsm_batched(
8038
rocblas_handle handle, rocblas_side side, rocblas_fill uplo,
8139
rocblas_operation transA, rocblas_diagonal diag, rocblas_int m,
@@ -120,3 +78,45 @@ ROCBLAS_EXPORT rocblas_status rocblas_ztrsm_batched(
12078
"rocblas_ztrsm_batched", handle, side, uplo, transA, diag, m, n, alpha, A,
12179
lda, B, ldb, batch_count);
12280
}
81+
82+
ROCBLAS_EXPORT rocblas_status rocblas_gemm_ex(
83+
rocblas_handle handle, rocblas_operation transA, rocblas_operation transB,
84+
rocblas_int m, rocblas_int n, rocblas_int k, const void* alpha,
85+
const void* a, rocblas_datatype a_type, rocblas_int lda, const void* b,
86+
rocblas_datatype b_type, rocblas_int ldb, const void* beta, const void* c,
87+
rocblas_datatype c_type, rocblas_int ldc, void* d, rocblas_datatype d_type,
88+
rocblas_int ldd, rocblas_datatype compute_type, rocblas_gemm_algo algo,
89+
int32_t solution_index, uint32_t flags) {
90+
return DynamicCall<decltype(rocblas_gemm_ex), &rocblas_gemm_ex>(
91+
"rocblas_gemm_ex", handle, transA, transB, m, n, k, alpha, a, a_type, lda,
92+
b, b_type, ldb, beta, c, c_type, ldc, d, d_type, ldd, compute_type, algo,
93+
solution_index, flags);
94+
}
95+
96+
ROCBLAS_EXPORT rocblas_status rocblas_gemm_strided_batched_ex(
97+
rocblas_handle handle, rocblas_operation transA, rocblas_operation transB,
98+
rocblas_int m, rocblas_int n, rocblas_int k, const void* alpha,
99+
const void* a, rocblas_datatype a_type, rocblas_int lda,
100+
rocblas_stride stride_a, const void* b, rocblas_datatype b_type,
101+
rocblas_int ldb, rocblas_stride stride_b, const void* beta, const void* c,
102+
rocblas_datatype c_type, rocblas_int ldc, rocblas_stride stride_c, void* d,
103+
rocblas_datatype d_type, rocblas_int ldd, rocblas_stride stride_d,
104+
rocblas_int batch_count, rocblas_datatype compute_type,
105+
rocblas_gemm_algo algo, int32_t solution_index, uint32_t flags) {
106+
return DynamicCall<decltype(rocblas_gemm_strided_batched_ex),
107+
&rocblas_gemm_strided_batched_ex>(
108+
"rocblas_gemm_strided_batched_ex", handle, transA, transB, m, n, k, alpha,
109+
a, a_type, lda, stride_a, b, b_type, ldb, stride_b, beta, c, c_type, ldc,
110+
stride_c, d, d_type, ldd, stride_d, batch_count, compute_type, algo,
111+
solution_index, flags);
112+
}
113+
114+
ROCBLAS_EXPORT rocblas_status rocblas_axpy_ex(
115+
rocblas_handle handle, rocblas_int n, const void* alpha,
116+
rocblas_datatype alpha_type, const void* x, rocblas_datatype x_type,
117+
rocblas_int incx, void* y, rocblas_datatype y_type, rocblas_int incy,
118+
rocblas_datatype execution_type) {
119+
return DynamicCall<decltype(rocblas_axpy_ex), &rocblas_axpy_ex>(
120+
"rocblas_axpy_ex", handle, n, alpha, alpha_type, x, x_type, incx, y,
121+
y_type, incy, execution_type);
122+
}

third_party/hip/rocblas_stub.h.inc

Lines changed: 27 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@ typedef enum rocblas_gemm_algo_ {
7171
typedef enum rocblas_gemm_flags_ {
7272
rocblas_gemm_flags_none = 0x0,
7373
rocblas_gemm_flags_pack_int8x4 = 0x1,
74+
rocblas_gemm_flags_use_cu_efficiency = 0x2,
7475
} rocblas_gemm_flags;
7576

7677
ROCBLAS_EXPORT rocblas_status rocblas_create_handle(rocblas_handle* handle);
@@ -89,32 +90,6 @@ ROCBLAS_EXPORT rocblas_status rocblas_set_pointer_mode(
8990
ROCBLAS_EXPORT rocblas_status rocblas_get_pointer_mode(
9091
rocblas_handle handle, rocblas_pointer_mode* pointer_mode);
9192

92-
ROCBLAS_EXPORT rocblas_status rocblas_gemm_ex(
93-
rocblas_handle handle, rocblas_operation transA, rocblas_operation transB,
94-
rocblas_int m, rocblas_int n, rocblas_int k, const void* alpha,
95-
const void* a, rocblas_datatype a_type, rocblas_int lda, const void* b,
96-
rocblas_datatype b_type, rocblas_int ldb, const void* beta, const void* c,
97-
rocblas_datatype c_type, rocblas_int ldc, void* d, rocblas_datatype d_type,
98-
rocblas_int ldd, rocblas_datatype compute_type, rocblas_gemm_algo algo,
99-
int32_t solution_index, uint32_t flags);
100-
101-
ROCBLAS_EXPORT rocblas_status rocblas_gemm_strided_batched_ex(
102-
rocblas_handle handle, rocblas_operation transA, rocblas_operation transB,
103-
rocblas_int m, rocblas_int n, rocblas_int k, const void* alpha,
104-
const void* a, rocblas_datatype a_type, rocblas_int lda,
105-
rocblas_stride stride_a, const void* b, rocblas_datatype b_type,
106-
rocblas_int ldb, rocblas_stride stride_b, const void* beta, const void* c,
107-
rocblas_datatype c_type, rocblas_int ldc, rocblas_stride stride_c, void* d,
108-
rocblas_datatype d_type, rocblas_int ldd, rocblas_stride stride_d,
109-
rocblas_int batch_count, rocblas_datatype compute_type,
110-
rocblas_gemm_algo algo, int32_t solution_index, uint32_t flags);
111-
112-
ROCBLAS_EXPORT rocblas_status rocblas_axpy_ex(
113-
rocblas_handle handle, rocblas_int n, const void* alpha,
114-
rocblas_datatype alpha_type, const void* x, rocblas_datatype x_type,
115-
rocblas_int incx, void* y, rocblas_datatype y_type, rocblas_int incy,
116-
rocblas_datatype execution_type);
117-
11893
ROCBLAS_EXPORT rocblas_status rocblas_strsm_batched(
11994
rocblas_handle handle, rocblas_side side, rocblas_fill uplo,
12095
rocblas_operation transA, rocblas_diagonal diag, rocblas_int m,
@@ -142,3 +117,29 @@ ROCBLAS_EXPORT rocblas_status rocblas_ztrsm_batched(
142117
const rocblas_double_complex* const A[], rocblas_int lda,
143118
rocblas_double_complex* const B[], rocblas_int ldb,
144119
rocblas_int batch_count);
120+
121+
ROCBLAS_EXPORT rocblas_status rocblas_gemm_ex(
122+
rocblas_handle handle, rocblas_operation transA, rocblas_operation transB,
123+
rocblas_int m, rocblas_int n, rocblas_int k, const void* alpha,
124+
const void* a, rocblas_datatype a_type, rocblas_int lda, const void* b,
125+
rocblas_datatype b_type, rocblas_int ldb, const void* beta, const void* c,
126+
rocblas_datatype c_type, rocblas_int ldc, void* d, rocblas_datatype d_type,
127+
rocblas_int ldd, rocblas_datatype compute_type, rocblas_gemm_algo algo,
128+
int32_t solution_index, uint32_t flags);
129+
130+
ROCBLAS_EXPORT rocblas_status rocblas_gemm_strided_batched_ex(
131+
rocblas_handle handle, rocblas_operation transA, rocblas_operation transB,
132+
rocblas_int m, rocblas_int n, rocblas_int k, const void* alpha,
133+
const void* a, rocblas_datatype a_type, rocblas_int lda,
134+
rocblas_stride stride_a, const void* b, rocblas_datatype b_type,
135+
rocblas_int ldb, rocblas_stride stride_b, const void* beta, const void* c,
136+
rocblas_datatype c_type, rocblas_int ldc, rocblas_stride stride_c, void* d,
137+
rocblas_datatype d_type, rocblas_int ldd, rocblas_stride stride_d,
138+
rocblas_int batch_count, rocblas_datatype compute_type,
139+
rocblas_gemm_algo algo, int32_t solution_index, uint32_t flags);
140+
141+
ROCBLAS_EXPORT rocblas_status rocblas_axpy_ex(
142+
rocblas_handle handle, rocblas_int n, const void* alpha,
143+
rocblas_datatype alpha_type, const void* x, rocblas_datatype x_type,
144+
rocblas_int incx, void* y, rocblas_datatype y_type, rocblas_int incy,
145+
rocblas_datatype execution_type);

0 commit comments

Comments
 (0)