Skip to content

[HIP] Hip adapter multi dev ctx #999

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Dec 5, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions source/adapters/hip/context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,15 +40,13 @@ ur_context_handle_t_::getOwningURPool(umf_memory_pool_t *UMFPool) {
UR_APIEXPORT ur_result_t UR_APICALL urContextCreate(
uint32_t DeviceCount, const ur_device_handle_t *phDevices,
const ur_context_properties_t *, ur_context_handle_t *phContext) {
std::ignore = DeviceCount;
assert(DeviceCount == 1);
ur_result_t RetErr = UR_RESULT_SUCCESS;

std::unique_ptr<ur_context_handle_t_> ContextPtr{nullptr};
try {
// Create a scoped context.
ContextPtr = std::unique_ptr<ur_context_handle_t_>(
new ur_context_handle_t_{*phDevices});
new ur_context_handle_t_{phDevices, DeviceCount});

static std::once_flag InitFlag;
std::call_once(
Expand Down Expand Up @@ -78,9 +76,9 @@ urContextGetInfo(ur_context_handle_t hContext, ur_context_info_t propName,

switch (uint32_t{propName}) {
case UR_CONTEXT_INFO_NUM_DEVICES:
return ReturnValue(1);
return ReturnValue(static_cast<uint32_t>(hContext->Devices.size()));
case UR_CONTEXT_INFO_DEVICES:
return ReturnValue(hContext->getDevice());
return ReturnValue(hContext->getDevices());
case UR_CONTEXT_INFO_REFERENCE_COUNT:
return ReturnValue(hContext->getReferenceCount());
case UR_CONTEXT_INFO_ATOMIC_MEMORY_ORDER_CAPABILITIES:
Expand Down Expand Up @@ -124,8 +122,10 @@ urContextRetain(ur_context_handle_t hContext) {

UR_APIEXPORT ur_result_t UR_APICALL urContextGetNativeHandle(
ur_context_handle_t hContext, ur_native_handle_t *phNativeContext) {
// FIXME: this entry point has been deprecated in the SYCL RT and should be
// changed to unsupported once the deprecation period has elapsed
*phNativeContext = reinterpret_cast<ur_native_handle_t>(
hContext->getDevice()->getNativeContext());
hContext->getDevices()[0]->getNativeContext());
return UR_RESULT_SUCCESS;
}

Expand Down
61 changes: 40 additions & 21 deletions source/adapters/hip/context.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,26 +28,26 @@ typedef void (*ur_context_extended_deleter_t)(void *UserData);
///
/// One of the main differences between the UR API and the HIP driver API is
/// that the second modifies the state of the threads by assigning
/// `hipCtx_t` objects to threads. `hipCtx_t` objects store data associated
/// \c hipCtx_t objects to threads. \c hipCtx_t objects store data associated
/// with a given device and control access to said device from the user side.
/// UR API context are objects that are passed to functions, and not bound
/// to threads.
/// The ur_context_handle_t_ object doesn't implement this behavior. It only
/// holds the HIP context data. The RAII object \ref ScopedContext implements
/// the active context behavior.
///
/// <b> Primary vs UserDefined context </b>
/// Since the \c ur_context_handle_t can contain multiple devices, and a \c
/// hipCtx_t refers to only a single device, the \c hipCtx_t is more tightly
/// coupled to a \c ur_device_handle_t than a \c ur_context_handle_t. In order
/// to remove some ambiguities about the different semantics of \c
/// \c ur_context_handle_t and native \c hipCtx_t, we access the native \c
/// hipCtx_t solely through the \c ur_device_handle_t class, by using the object
/// \ref ScopedContext, which sets the active device (by setting the active
/// native \c hipCtx_t).
///
/// HIP has two different types of context, the Primary context,
/// which is usable by all threads on a given process for a given device, and
/// the aforementioned custom contexts.
/// The HIP documentation, and performance analysis, suggest using the Primary
/// context whenever possible. The Primary context is also used by the HIP
/// Runtime API. For UR applications to interop with HIP Runtime API, they have
/// to use the primary context - and make that active in the thread. The
/// `ur_context_handle_t_` object can be constructed with a `kind` parameter
/// that allows to construct a Primary or `UserDefined` context, so that
/// the UR object interface is always the same.
/// <b> Primary vs User-defined \c hipCtx_t </b>
///
/// HIP has two different types of \c hipCtx_t, the Primary context, which is
/// usable by all threads on a given process for a given device, and the
/// aforementioned custom \c hipCtx_t s. The HIP documentation, confirmed with
/// performance analysis, suggest using the Primary context whenever possible.
///
/// <b> Destructor callback </b>
///
Expand All @@ -57,6 +57,16 @@ typedef void (*ur_context_extended_deleter_t)(void *UserData);
/// See proposal for details.
/// https://github.com/codeplaysoftware/standards-proposals/blob/master/extended-context-destruction/index.md
///
/// <b> Memory Management for Devices in a Context <\b>
///
/// A \c ur_mem_handle_t is associated with a \c ur_context_handle_t_, which
/// may refer to multiple devices. Therefore the \c ur_mem_handle_t must
/// handle a native allocation for each device in the context. UR is
/// responsible for automatically handling event dependencies for kernels
/// writing to or reading from the same \c ur_mem_handle_t and migrating memory
/// between native allocations for devices in the same \c ur_context_handle_t_
/// if necessary.
///
struct ur_context_handle_t_ {

struct deleter_data {
Expand All @@ -68,15 +78,22 @@ struct ur_context_handle_t_ {

using native_type = hipCtx_t;

ur_device_handle_t DeviceId;
std::vector<ur_device_handle_t> Devices;

std::atomic_uint32_t RefCount;

ur_context_handle_t_(ur_device_handle_t DevId)
: DeviceId{DevId}, RefCount{1} {
urDeviceRetain(DeviceId);
ur_context_handle_t_(const ur_device_handle_t *Devs, uint32_t NumDevices)
: Devices{Devs, Devs + NumDevices}, RefCount{1} {
for (auto &Dev : Devices) {
urDeviceRetain(Dev);
}
};

~ur_context_handle_t_() { urDeviceRelease(DeviceId); }
~ur_context_handle_t_() {
for (auto &Dev : Devices) {
urDeviceRelease(Dev);
}
}

void invokeExtendedDeleters() {
std::lock_guard<std::mutex> Guard(Mutex);
Expand All @@ -91,7 +108,9 @@ struct ur_context_handle_t_ {
ExtendedDeleters.emplace_back(deleter_data{Function, UserData});
}

ur_device_handle_t getDevice() const noexcept { return DeviceId; }
const std::vector<ur_device_handle_t> &getDevices() const noexcept {
return Devices;
}

uint32_t incrementReferenceCount() noexcept { return ++RefCount; }

Expand Down
11 changes: 8 additions & 3 deletions source/adapters/hip/device.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,13 @@ struct ur_device_handle_t_ {
std::atomic_uint32_t RefCount;
ur_platform_handle_t Platform;
hipCtx_t HIPContext;
uint32_t DeviceIndex;

public:
ur_device_handle_t_(native_type HipDevice, hipCtx_t Context,
ur_platform_handle_t Platform)
ur_platform_handle_t Platform, uint32_t DeviceIndex)
: HIPDevice(HipDevice), RefCount{1}, Platform(Platform),
HIPContext(Context) {}
HIPContext(Context), DeviceIndex(DeviceIndex) {}

~ur_device_handle_t_() {
UR_CHECK_ERROR(hipDevicePrimaryCtxRelease(HIPDevice));
Expand All @@ -42,7 +43,11 @@ struct ur_device_handle_t_ {

ur_platform_handle_t getPlatform() const noexcept { return Platform; };

hipCtx_t getNativeContext() { return HIPContext; };
hipCtx_t getNativeContext() const noexcept { return HIPContext; };

// Returns the index of the device relative to the other devices in the same
// platform
uint32_t getIndex() const noexcept { return DeviceIndex; };
};

int getAttribute(ur_device_handle_t Device, hipDeviceAttribute_t Attribute);
Loading