Skip to content

Commit

Permalink
Decoupling raft handle from underlying resources (#1111)
Browse files Browse the repository at this point in the history
This implements a design idea that a few of us have been discussing for a while: decoupling the underlying resources from the raft handle, so that users never have to explicitly include headers for resources they do not use (such as cublas, cusolver, cusparse, comms, etc.).

This effectively breaks the existing `raft::handle_t` into separate headers for the various resources it contains, providing functions that can be individually included and invoked on a `raft::resources`. This still allows us to write something like a `raft::device_resources`, and it also allows us to maintain API compatibility in the meantime by backing the existing `raft::handle_t` with a `raft::resources`.

Another major goal of this PR is to enable a handle to be used with more than just CUDA resources, and to ensure that unused resources need not be loaded or compiled into downstream user code at all.


Follow-on work after this PR will include:
1. Updating all of RAFT's public functions to accept `raft::resources` and using the individual resource accessors instead of assuming `device_resources` everywhere. 
2. Deprecating the `handle_t` in favor of the more explicit `device_resources`

Authors:
  - Corey J. Nolet (https://github.com/cjnolet)

Approvers:
  - Divye Gala (https://github.com/divyegala)
  - Dante Gama Dessavre (https://github.com/dantegd)
  - William Hicks (https://github.com/wphicks)
  - Ben Frederickson (https://github.com/benfred)

URL: #1111
  • Loading branch information
cjnolet authored Jan 10, 2023
1 parent de7d361 commit 2c97abe
Show file tree
Hide file tree
Showing 61 changed files with 2,284 additions and 503 deletions.
3 changes: 2 additions & 1 deletion build.sh
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#!/bin/bash

# Copyright (c) 2020-2022, NVIDIA CORPORATION.
# Copyright (c) 2020-2023, NVIDIA CORPORATION.

# raft build script

Expand Down Expand Up @@ -153,6 +153,7 @@ function limitTests {
# Remove the full LIMIT_TEST_TARGETS argument from list of args so that it passes validArgs function
ARGS=${ARGS//--limit-tests=$LIMIT_TEST_TARGETS/}
TEST_TARGETS=${LIMIT_TEST_TARGETS}
echo "Limiting tests to $TEST_TARGETS"
fi
fi
}
Expand Down
2 changes: 1 addition & 1 deletion cpp/include/raft/comms/detail/test.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2021-2022, NVIDIA CORPORATION.
* Copyright (c) 2021-2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down
3 changes: 2 additions & 1 deletion cpp/include/raft/core/comms.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2021-2022, NVIDIA CORPORATION.
* Copyright (c) 2021-2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -16,6 +16,7 @@

#pragma once

#include <cuda_runtime.h>
#include <memory>
#include <raft/core/error.hpp>
#include <vector>
Expand Down
241 changes: 241 additions & 0 deletions cpp/include/raft/core/device_resources.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,241 @@
/*
* Copyright (c) 2019-2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef __RAFT_DEVICE_RESOURCES
#define __RAFT_DEVICE_RESOURCES

#pragma once

#include <memory>
#include <mutex>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include <cublas_v2.h>
#include <cuda_runtime.h>
#include <cusolverDn.h>
#include <cusolverSp.h>
#include <cusparse.h>

#include <raft/core/comms.hpp>
#include <rmm/cuda_stream_pool.hpp>
#include <rmm/exec_policy.hpp>

#include <raft/core/resource/comms.hpp>
#include <raft/core/resource/cublas_handle.hpp>
#include <raft/core/resource/cuda_event.hpp>
#include <raft/core/resource/cuda_stream.hpp>
#include <raft/core/resource/cuda_stream_pool.hpp>
#include <raft/core/resource/cusolver_dn_handle.hpp>
#include <raft/core/resource/cusolver_sp_handle.hpp>
#include <raft/core/resource/cusparse_handle.hpp>
#include <raft/core/resource/device_id.hpp>
#include <raft/core/resource/device_properties.hpp>
#include <raft/core/resource/sub_comms.hpp>
#include <raft/core/resource/thrust_policy.hpp>
#include <raft/core/resources.hpp>

namespace raft {

/**
 * @brief Main resource container object that stores all necessary resources
 * used for calling necessary device functions, cuda kernels and/or libraries
 *
 * Every accessor on this class is a thin convenience wrapper that forwards
 * `*this` to the matching free function in the `raft::resource` namespace.
 */
class device_resources : public resources {
 public:
  // delete copy/move constructors and assignment operators as
  // copying and moving underlying resources is unsafe
  device_resources(const device_resources&) = delete;
  device_resources& operator=(const device_resources&) = delete;
  device_resources(device_resources&&)                 = delete;
  device_resources& operator=(device_resources&&) = delete;

  /**
   * @brief Construct a resources instance with a stream view and stream pool
   *
   * Registers resource factories for the device id, the main CUDA stream and
   * the CUDA stream pool. No other factories are registered by this constructor.
   *
   * @param[in] stream_view the default stream (which has the default per-thread stream if
   *                        unspecified)
   * @param[in] stream_pool the stream pool used (which has default of nullptr if unspecified)
   */
  device_resources(rmm::cuda_stream_view stream_view                  = rmm::cuda_stream_per_thread,
                   std::shared_ptr<rmm::cuda_stream_pool> stream_pool = {nullptr})
    : resources{}
  {
    resources::add_resource_factory(std::make_shared<resource::device_id_resource_factory>());
    resources::add_resource_factory(
      std::make_shared<resource::cuda_stream_resource_factory>(stream_view));
    resources::add_resource_factory(
      std::make_shared<resource::cuda_stream_pool_resource_factory>(stream_pool));
  }

  /** Destroys all held-up resources */
  virtual ~device_resources() {}

  /**
   * @brief returns the device id of the current container
   */
  int get_device() const { return resource::get_device_id(*this); }

  /**
   * @brief returns the cuBLAS handle of the current container
   */
  cublasHandle_t get_cublas_handle() const { return resource::get_cublas_handle(*this); }

  /**
   * @brief returns the cuSOLVER dense handle of the current container
   */
  cusolverDnHandle_t get_cusolver_dn_handle() const
  {
    return resource::get_cusolver_dn_handle(*this);
  }

  /**
   * @brief returns the cuSOLVER sparse handle of the current container
   */
  cusolverSpHandle_t get_cusolver_sp_handle() const
  {
    return resource::get_cusolver_sp_handle(*this);
  }

  /**
   * @brief returns the cuSPARSE handle of the current container
   */
  cusparseHandle_t get_cusparse_handle() const { return resource::get_cusparse_handle(*this); }

  /**
   * @brief returns the thrust execution policy of the current container
   */
  rmm::exec_policy& get_thrust_policy() const { return resource::get_thrust_policy(*this); }

  /**
   * @brief synchronize a stream on the current container
   *
   * @param[in] stream the stream to synchronize
   */
  void sync_stream(rmm::cuda_stream_view stream) const { resource::sync_stream(*this, stream); }

  /**
   * @brief synchronize main stream on the current container
   */
  void sync_stream() const { resource::sync_stream(*this); }

  /**
   * @brief returns main stream on the current container
   */
  rmm::cuda_stream_view get_stream() const { return resource::get_cuda_stream(*this); }

  /**
   * @brief returns whether stream pool was initialized on the current container
   */
  bool is_stream_pool_initialized() const { return resource::is_stream_pool_initialized(*this); }

  /**
   * @brief returns stream pool on the current container
   */
  const rmm::cuda_stream_pool& get_stream_pool() const
  {
    return resource::get_cuda_stream_pool(*this);
  }

  /**
   * @brief returns the size of the stream pool on the current container
   */
  std::size_t get_stream_pool_size() const { return resource::get_stream_pool_size(*this); }

  /**
   * @brief return stream from pool
   */
  rmm::cuda_stream_view get_stream_from_stream_pool() const
  {
    return resource::get_stream_from_stream_pool(*this);
  }

  /**
   * @brief return stream from pool at index
   *
   * @param[in] stream_idx the index of the stream in the stream pool
   */
  rmm::cuda_stream_view get_stream_from_stream_pool(std::size_t stream_idx) const
  {
    return resource::get_stream_from_stream_pool(*this, stream_idx);
  }

  /**
   * @brief return stream from pool if size > 0, else main stream on current container
   */
  rmm::cuda_stream_view get_next_usable_stream() const
  {
    return resource::get_next_usable_stream(*this);
  }

  /**
   * @brief return stream from pool at index if size > 0, else main stream on current container
   *
   * @param[in] stream_idx the required index of the stream in the stream pool if available
   */
  rmm::cuda_stream_view get_next_usable_stream(std::size_t stream_idx) const
  {
    return resource::get_next_usable_stream(*this, stream_idx);
  }

  /**
   * @brief synchronize the stream pool on the current container
   */
  void sync_stream_pool() const { resource::sync_stream_pool(*this); }

  /**
   * @brief synchronize subset of stream pool
   *
   * The indices are taken by const reference so that calling this wrapper does
   * not copy the vector.
   *
   * @param[in] stream_indices the indices of the streams in the stream pool to synchronize
   */
  void sync_stream_pool(const std::vector<std::size_t>& stream_indices) const
  {
    resource::sync_stream_pool(*this, stream_indices);
  }

  /**
   * @brief ask stream pool to wait on last event in main stream
   */
  void wait_stream_pool_on_stream() const { resource::wait_stream_pool_on_stream(*this); }

  /**
   * @brief set the communicator on the current container
   *
   * @param[in] communicator the communicator to store
   */
  void set_comms(std::shared_ptr<comms::comms_t> communicator)
  {
    resource::set_comms(*this, communicator);
  }

  /**
   * @brief returns the communicator of the current container
   */
  const comms::comms_t& get_comms() const { return resource::get_comms(*this); }

  /**
   * @brief store a sub-communicator under a key on the current container
   *
   * @param[in] key the name under which the sub-communicator is stored
   * @param[in] subcomm the sub-communicator to store
   */
  void set_subcomm(std::string key, std::shared_ptr<comms::comms_t> subcomm)
  {
    resource::set_subcomm(*this, key, subcomm);
  }

  /**
   * @brief returns the sub-communicator stored under a key on the current container
   *
   * @param[in] key the name of the sub-communicator to look up
   */
  const comms::comms_t& get_subcomm(std::string key) const
  {
    return resource::get_subcomm(*this, key);
  }

  /**
   * @brief returns whether a communicator was initialized on the current container
   */
  bool comms_initialized() const { return resource::comms_initialized(*this); }

  /**
   * @brief returns the device properties of the current container
   */
  const cudaDeviceProp& get_device_properties() const
  {
    return resource::get_device_properties(*this);
  }
};  // class device_resources

/**
* @brief RAII approach to synchronizing across all streams in the current container
*/
class stream_syncer {
public:
explicit stream_syncer(const device_resources& handle) : handle_(handle)
{
handle_.sync_stream();
}
~stream_syncer()
{
handle_.wait_stream_pool_on_stream();
handle_.sync_stream_pool();
}

stream_syncer(const stream_syncer& other) = delete;
stream_syncer& operator=(const stream_syncer& other) = delete;

private:
const device_resources& handle_;
}; // class stream_syncer

} // namespace raft

#endif
Loading

0 comments on commit 2c97abe

Please sign in to comment.