diff --git a/src/runtime/rocm/rocm_device_api.cc b/src/runtime/rocm/rocm_device_api.cc index cff72f58f69a3..a2caff6d68c24 100644 --- a/src/runtime/rocm/rocm_device_api.cc +++ b/src/runtime/rocm/rocm_device_api.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -22,13 +22,13 @@ * \file rocm_device_api.cc * \brief GPU specific API */ -#include - #include #include -#include #include #include +#include +#include + #include "rocm_common.h" namespace tvm { @@ -36,9 +36,7 @@ namespace runtime { class ROCMDeviceAPI final : public DeviceAPI { public: - void SetDevice(TVMContext ctx) final { - ROCM_CALL(hipSetDevice(ctx.device_id)); - } + void SetDevice(TVMContext ctx) final { ROCM_CALL(hipSetDevice(ctx.device_id)); } void GetAttr(TVMContext ctx, DeviceAttrKind kind, TVMRetValue* rv) final { int value = 0; switch (kind) { @@ -54,35 +52,59 @@ class ROCMDeviceAPI final : public DeviceAPI { break; } case kMaxThreadsPerBlock: { - value = 1024; + ROCM_CALL( + hipDeviceGetAttribute(&value, hipDeviceAttributeMaxThreadsPerBlock, ctx.device_id)); break; } case kWarpSize: { - value = 64; + ROCM_CALL(hipDeviceGetAttribute(&value, hipDeviceAttributeWarpSize, ctx.device_id)); + break; + } + case kMaxSharedMemoryPerBlock: { + ROCM_CALL(hipDeviceGetAttribute(&value, hipDeviceAttributeMaxSharedMemoryPerBlock, + ctx.device_id)); break; } - case kMaxSharedMemoryPerBlock: return; case kComputeVersion: { - hipDeviceProp_t prop; - ROCM_CALL(hipGetDeviceProperties(&prop, ctx.device_id)); - *rv = prop.gcnArch; + std::ostringstream os; + ROCM_CALL( + hipDeviceGetAttribute(&value, hipDeviceAttributeComputeCapabilityMajor, ctx.device_id)); + os << value << "."; + ROCM_CALL( + hipDeviceGetAttribute(&value, hipDeviceAttributeComputeCapabilityMinor, ctx.device_id)); + os << value; + *rv = os.str(); + return; + } + case kDeviceName: + return; + case kMaxClockRate: { + ROCM_CALL(hipDeviceGetAttribute(&value, hipDeviceAttributeClockRate, ctx.device_id)); + break; + } + case kMultiProcessorCount: { + ROCM_CALL( + hipDeviceGetAttribute(&value, hipDeviceAttributeMultiprocessorCount, ctx.device_id)); + break; + } + case kMaxThreadDimensions: { + int dims[3]; + ROCM_CALL(hipDeviceGetAttribute(&dims[0], hipDeviceAttributeMaxBlockDimX, ctx.device_id)); + ROCM_CALL(hipDeviceGetAttribute(&dims[1], hipDeviceAttributeMaxBlockDimY, ctx.device_id)); + ROCM_CALL(hipDeviceGetAttribute(&dims[2], hipDeviceAttributeMaxBlockDimZ, ctx.device_id)); + + std::stringstream ss; + ss << "[" << dims[0] << ", " << dims[1] << ", " << dims[2] << "]"; + *rv = ss.str(); return; } - case kDeviceName: return; - case kMaxClockRate: return; - case kMultiProcessorCount: return; - case kMaxThreadDimensions: return; } *rv = value; } - void* AllocDataSpace(TVMContext ctx, - size_t nbytes, - size_t alignment, - TVMType type_hint) final { + void* AllocDataSpace(TVMContext ctx, size_t nbytes, size_t alignment, TVMType type_hint) final { ROCM_CALL(hipSetDevice(ctx.device_id)); - CHECK_EQ(256 % alignment, 0U) - << "ROCM space is aligned at 256 bytes"; - void *ret; + CHECK_EQ(256 % alignment, 0U) << "ROCM space is aligned at 256 bytes"; + void* ret; ROCM_CALL(hipMalloc(&ret, nbytes)); return ret; } @@ -92,14 +114,8 @@ class ROCMDeviceAPI final : public DeviceAPI { ROCM_CALL(hipFree(ptr)); } - void CopyDataFromTo(const void* from, - size_t from_offset, - void* to, - size_t to_offset, - size_t size, - TVMContext ctx_from, - TVMContext ctx_to, - TVMType type_hint, + void CopyDataFromTo(const void* from, size_t from_offset, void* to, size_t to_offset, size_t size, + TVMContext ctx_from, TVMContext ctx_to, TVMType type_hint, TVMStreamHandle stream) final { hipStream_t hip_stream = static_cast(stream); from = static_cast(from) + from_offset; @@ -109,9 +125,7 @@ class ROCMDeviceAPI final : public DeviceAPI { if (ctx_from.device_id == ctx_to.device_id) { GPUCopy(from, to, size, hipMemcpyDeviceToDevice, hip_stream); } else { - hipMemcpyPeerAsync(to, ctx_to.device_id, - from, ctx_from.device_id, - size, hip_stream); + hipMemcpyPeerAsync(to, ctx_to.device_id, from, ctx_from.device_id, size, hip_stream); } } else if (ctx_from.device_type == kDLROCM && ctx_to.device_type == kDLCPU) { ROCM_CALL(hipSetDevice(ctx_from.device_id)); @@ -130,8 +144,7 @@ class ROCMDeviceAPI final : public DeviceAPI { } void SetStream(TVMContext ctx, TVMStreamHandle stream) final { - ROCMThreadEntry::ThreadLocal() - ->stream = static_cast(stream); + ROCMThreadEntry::ThreadLocal()->stream = static_cast(stream); } void* AllocWorkspace(TVMContext ctx, size_t size, TVMType type_hint) final { @@ -143,16 +156,12 @@ class ROCMDeviceAPI final : public DeviceAPI { } static const std::shared_ptr& Global() { - static std::shared_ptr inst = - std::make_shared(); + static std::shared_ptr inst = std::make_shared(); return inst; } private: - static void GPUCopy(const void* from, - void* to, - size_t size, - hipMemcpyKind kind, + static void GPUCopy(const void* from, void* to, size_t size, hipMemcpyKind kind, hipStream_t stream) { if (stream != 0) { ROCM_CALL(hipMemcpyAsync(to, from, size, kind, stream)); @@ -164,19 +173,14 @@ class ROCMDeviceAPI final : public DeviceAPI { typedef dmlc::ThreadLocalStore ROCMThreadStore; -ROCMThreadEntry::ROCMThreadEntry() - : pool(kDLROCM, ROCMDeviceAPI::Global()) { -} +ROCMThreadEntry::ROCMThreadEntry() : pool(kDLROCM, ROCMDeviceAPI::Global()) {} -ROCMThreadEntry* ROCMThreadEntry::ThreadLocal() { - return ROCMThreadStore::Get(); -} +ROCMThreadEntry* ROCMThreadEntry::ThreadLocal() { return ROCMThreadStore::Get(); } -TVM_REGISTER_GLOBAL("device_api.rocm") -.set_body([](TVMArgs args, TVMRetValue* rv) { - DeviceAPI* ptr = ROCMDeviceAPI::Global().get(); - *rv = static_cast(ptr); - }); +TVM_REGISTER_GLOBAL("device_api.rocm").set_body([](TVMArgs args, TVMRetValue* rv) { + DeviceAPI* ptr = ROCMDeviceAPI::Global().get(); + *rv = static_cast(ptr); +}); } // namespace runtime } // namespace tvm