Skip to content

Commit

Permalink
[Thrust] Use no sync exec policy and caching allocator (apache#16386)
Browse files Browse the repository at this point in the history
  • Loading branch information
vinx13 authored Jan 11, 2024
1 parent e2e33dd commit 5d4c01e
Show file tree
Hide file tree
Showing 2 changed files with 126 additions and 122 deletions.
245 changes: 123 additions & 122 deletions src/runtime/contrib/thrust/thrust.cu
Original file line number Diff line number Diff line change
Expand Up @@ -21,48 +21,55 @@
* \file Use external Thrust library call
*/

#include <dlpack/dlpack.h>
#include <thrust/detail/caching_allocator.h>
#include <thrust/device_ptr.h>
#include <thrust/device_vector.h>
#include <thrust/sort.h>
#include <thrust/gather.h>
#include <thrust/scan.h>
#include <thrust/sequence.h>

#include <thrust/sort.h>
#include <tvm/runtime/registry.h>
#include <dlpack/dlpack.h>

#include <algorithm>
#include <vector>
#include <functional>
#include <vector>

#include "../../cuda/cuda_common.h"

namespace tvm {
namespace contrib {

using namespace runtime;

// Build the Thrust execution policy shared by all algorithms in this file.
// par_nosync skips the implicit stream synchronization Thrust normally performs
// after each algorithm, and the thread-local (TLS) caching allocator reuses
// temporary device buffers across calls instead of calling cudaMalloc/cudaFree
// each time. The policy is bound to TVM's current CUDA stream (GetCUDAStream()).
// NOTE(review): single_device_tls_caching_allocator lives in thrust::detail,
// so this depends on a Thrust internal — confirm it exists for the pinned
// Thrust version before upgrading.
auto get_thrust_exec_policy() {
return thrust::cuda::par_nosync(thrust::detail::single_device_tls_caching_allocator())
.on(GetCUDAStream());
}

// Performs sorting along axis -1 and returns both sorted values and indices.
template<typename DataType, typename IndicesType>
void thrust_sort(DLTensor* input,
DLTensor* out_values,
DLTensor* out_indices,
bool is_ascend,
template <typename DataType, typename IndicesType>
void thrust_sort(DLTensor* input, DLTensor* out_values, DLTensor* out_indices, bool is_ascend,
int n_values) {
thrust::device_ptr<DataType> data_ptr(static_cast<DataType *>(input->data));
thrust::device_ptr<DataType> values_ptr(static_cast<DataType *>(out_values->data));
thrust::device_ptr<IndicesType> indices_ptr(static_cast<IndicesType *>(out_indices->data));
thrust::device_ptr<DataType> data_ptr(static_cast<DataType*>(input->data));
thrust::device_ptr<DataType> values_ptr(static_cast<DataType*>(out_values->data));
thrust::device_ptr<IndicesType> indices_ptr(static_cast<IndicesType*>(out_indices->data));

auto policy = get_thrust_exec_policy();

size_t size = 1;
for (int i = 0; i < input->ndim; ++i) {
size *= input->shape[i];
}
thrust::copy(data_ptr, data_ptr + size, values_ptr);
thrust::copy(policy, data_ptr, data_ptr + size, values_ptr);

if (size == static_cast<size_t>(input->shape[input->ndim - 1])) {
// A fast path for single segment case
thrust::sequence(indices_ptr, indices_ptr + n_values);
if (is_ascend) {
thrust::sort_by_key(values_ptr, values_ptr + n_values, indices_ptr);
thrust::sort_by_key(policy, values_ptr, values_ptr + n_values, indices_ptr);
} else {
thrust::sort_by_key(values_ptr, values_ptr + n_values, indices_ptr,
thrust::sort_by_key(policy, values_ptr, values_ptr + n_values, indices_ptr,
thrust::greater<DataType>());
}
} else {
Expand All @@ -74,9 +81,9 @@ void thrust_sort(DLTensor* input,

// First, sort values and store the sorted order in argsort_order.
if (is_ascend) {
thrust::stable_sort_by_key(values_ptr, values_ptr + size, argsort_order.begin());
thrust::stable_sort_by_key(policy, values_ptr, values_ptr + size, argsort_order.begin());
} else {
thrust::stable_sort_by_key(values_ptr, values_ptr + size, argsort_order.begin(),
thrust::stable_sort_by_key(policy, values_ptr, values_ptr + size, argsort_order.begin(),
thrust::greater<DataType>());
}

Expand All @@ -85,36 +92,33 @@ void thrust_sort(DLTensor* input,
auto counting_iter = thrust::counting_iterator<int64_t>(0);
auto linear_index_to_sort_axis_index = [n_values] __host__ __device__(int64_t i) {
return i % n_values;
}; // NOLINT(*)
auto init_indices_iter = thrust::make_transform_iterator(counting_iter,
linear_index_to_sort_axis_index);
}; // NOLINT(*)
auto init_indices_iter =
thrust::make_transform_iterator(counting_iter, linear_index_to_sort_axis_index);

// This will reorder indices 0, 1, 2 ... in the sorted order of values_ptr
thrust::gather(argsort_order.begin(), argsort_order.end(), init_indices_iter, indices_ptr);
thrust::gather(policy, argsort_order.begin(), argsort_order.end(), init_indices_iter,
indices_ptr);

thrust::device_vector<int> segment_ids(size);
auto linear_index_to_segment_id = [n_values] __host__ __device__(int64_t i) {
return i / n_values;
}; // NOLINT(*)
}; // NOLINT(*)
// We also reorder segment indices 0, 0, 0, 1, 1, 1 ... in the order of values_ptr
thrust::transform(argsort_order.begin(), argsort_order.end(), segment_ids.begin(),
thrust::transform(policy, argsort_order.begin(), argsort_order.end(), segment_ids.begin(),
linear_index_to_segment_id);

// The second sort key-ed by segment_ids would bring segment_ids back to 0, 0, 0, 1, 1, 1 ...
// values_ptr and indices_ptr will also be sorted in the order of segmend_ids above
// Since sorting has been done in a stable way, relative orderings of values and indices
// in the segment do not change and hence they remain sorted.
auto key_val_zip = thrust::make_zip_iterator(thrust::make_tuple(values_ptr, indices_ptr));
thrust::stable_sort_by_key(segment_ids.begin(), segment_ids.end(), key_val_zip);
thrust::stable_sort_by_key(policy, segment_ids.begin(), segment_ids.end(), key_val_zip);
}
}

void thrust_sort_common(DLTensor* input,
DLTensor* values_out,
DLTensor* indices_out,
bool is_ascend,
int sort_len,
std::string data_dtype,
void thrust_sort_common(DLTensor* input, DLTensor* values_out, DLTensor* indices_out,
bool is_ascend, int sort_len, std::string data_dtype,
std::string out_dtype) {
if (data_dtype == "float32") {
if (out_dtype == "int32") {
Expand Down Expand Up @@ -152,7 +156,7 @@ void thrust_sort_common(DLTensor* input,
} else {
LOG(FATAL) << "Unsupported output dtype: " << out_dtype;
}
} else if (data_dtype == "int64") {
} else if (data_dtype == "int64") {
if (out_dtype == "int32") {
thrust_sort<int64_t, int32_t>(input, values_out, indices_out, is_ascend, sort_len);
} else if (out_dtype == "int64") {
Expand All @@ -169,8 +173,7 @@ void thrust_sort_common(DLTensor* input,
}
}

TVM_REGISTER_GLOBAL("tvm.contrib.thrust.sort")
.set_body([](TVMArgs args, TVMRetValue* ret) {
TVM_REGISTER_GLOBAL("tvm.contrib.thrust.sort").set_body([](TVMArgs args, TVMRetValue* ret) {
ICHECK_GE(args.num_args, 4);
DLTensor* input = args[0];
DLTensor* values_out = args[1];
Expand All @@ -181,97 +184,94 @@ TVM_REGISTER_GLOBAL("tvm.contrib.thrust.sort")
auto out_dtype = DLDataType2String(indices_out->dtype);

int n_values = input->shape[input->ndim - 1];
thrust_sort_common(input, values_out, indices_out, is_ascend, n_values,
data_dtype, out_dtype);
thrust_sort_common(input, values_out, indices_out, is_ascend, n_values, data_dtype, out_dtype);
});

template<typename KeyType, typename ValueType>
void thrust_stable_sort_by_key(DLTensor* keys_in,
DLTensor* values_in,
DLTensor* keys_out,
DLTensor* values_out,
bool for_scatter) {
template <typename KeyType, typename ValueType>
void thrust_stable_sort_by_key(DLTensor* keys_in, DLTensor* values_in, DLTensor* keys_out,
DLTensor* values_out, bool for_scatter) {
const auto size = keys_in->shape[0];
thrust::device_ptr<KeyType> keys_in_ptr(static_cast<KeyType *>(keys_in->data));
thrust::device_ptr<ValueType> values_in_ptr(static_cast<ValueType *>(values_in->data));
thrust::device_ptr<KeyType> keys_out_ptr(static_cast<KeyType *>(keys_out->data));
thrust::device_ptr<ValueType> values_out_ptr(static_cast<ValueType *>(values_out->data));
thrust::device_ptr<KeyType> keys_in_ptr(static_cast<KeyType*>(keys_in->data));
thrust::device_ptr<ValueType> values_in_ptr(static_cast<ValueType*>(values_in->data));
thrust::device_ptr<KeyType> keys_out_ptr(static_cast<KeyType*>(keys_out->data));
thrust::device_ptr<ValueType> values_out_ptr(static_cast<ValueType*>(values_out->data));

auto policy = get_thrust_exec_policy();

if (for_scatter) {
thrust::transform(keys_in_ptr, keys_in_ptr + size, keys_out_ptr, [size] __device__(KeyType k) {
if (k < 0) return k + static_cast<KeyType>(size);
return k;
});
thrust::transform(policy, keys_in_ptr, keys_in_ptr + size, keys_out_ptr,
[size] __device__(KeyType k) {
if (k < 0) return k + static_cast<KeyType>(size);
return k;
});
} else {
thrust::copy(keys_in_ptr, keys_in_ptr + size, keys_out_ptr);
thrust::copy(policy, keys_in_ptr, keys_in_ptr + size, keys_out_ptr);
}
thrust::copy(values_in_ptr, values_in_ptr + size, values_out_ptr);
thrust::copy(policy, values_in_ptr, values_in_ptr + size, values_out_ptr);

thrust::stable_sort_by_key(keys_out_ptr, keys_out_ptr + size, values_out_ptr);
thrust::stable_sort_by_key(policy, keys_out_ptr, keys_out_ptr + size, values_out_ptr);
}

TVM_REGISTER_GLOBAL("tvm.contrib.thrust.stable_sort_by_key")
.set_body([](TVMArgs args, TVMRetValue* ret) {
ICHECK_GE(args.num_args, 5);
DLTensor* keys_in = args[0];
DLTensor* values_in = args[1];
DLTensor* keys_out = args[2];
DLTensor* values_out = args[3];
bool for_scatter = args[4];

auto key_dtype = DLDataType2String(keys_in->dtype);
auto value_dtype = DLDataType2String(values_in->dtype);

if (key_dtype == "int32") {
if (value_dtype == "int32") {
thrust_stable_sort_by_key<int, int>(keys_in, values_in, keys_out, values_out,
for_scatter);
} else if (value_dtype == "int64") {
thrust_stable_sort_by_key<int, int64_t>(keys_in, values_in, keys_out, values_out,
.set_body([](TVMArgs args, TVMRetValue* ret) {
ICHECK_GE(args.num_args, 5);
DLTensor* keys_in = args[0];
DLTensor* values_in = args[1];
DLTensor* keys_out = args[2];
DLTensor* values_out = args[3];
bool for_scatter = args[4];

auto key_dtype = DLDataType2String(keys_in->dtype);
auto value_dtype = DLDataType2String(values_in->dtype);

if (key_dtype == "int32") {
if (value_dtype == "int32") {
thrust_stable_sort_by_key<int, int>(keys_in, values_in, keys_out, values_out,
for_scatter);
} else if (value_dtype == "float32") {
thrust_stable_sort_by_key<int, float>(keys_in, values_in, keys_out, values_out,
for_scatter);
} else {
LOG(FATAL) << "Unsupported value dtype: " << value_dtype;
}
} else if (key_dtype == "int64") {
if (value_dtype == "int32") {
thrust_stable_sort_by_key<int64_t, int>(keys_in, values_in, keys_out, values_out,
for_scatter);
} else if (value_dtype == "int64") {
thrust_stable_sort_by_key<int64_t, int64_t>(keys_in, values_in, keys_out, values_out,
} else if (value_dtype == "int64") {
thrust_stable_sort_by_key<int, int64_t>(keys_in, values_in, keys_out, values_out,
for_scatter);
} else if (value_dtype == "float32") {
thrust_stable_sort_by_key<int64_t, float>(keys_in, values_in, keys_out, values_out,
} else if (value_dtype == "float32") {
thrust_stable_sort_by_key<int, float>(keys_in, values_in, keys_out, values_out,
for_scatter);
} else {
LOG(FATAL) << "Unsupported value dtype: " << value_dtype;
}
} else if (key_dtype == "float32") {
if (value_dtype == "int32") {
thrust_stable_sort_by_key<float, int>(keys_in, values_in, keys_out, values_out,
for_scatter);
} else if (value_dtype == "int64") {
thrust_stable_sort_by_key<float, int64_t>(keys_in, values_in, keys_out, values_out,
for_scatter);
} else if (value_dtype == "float32") {
thrust_stable_sort_by_key<float, float>(keys_in, values_in, keys_out, values_out,
for_scatter);
} else {
LOG(FATAL) << "Unsupported value dtype: " << value_dtype;
}
} else {
LOG(FATAL) << "Unsupported key dtype: " << key_dtype;
}
});
} else {
LOG(FATAL) << "Unsupported value dtype: " << value_dtype;
}
} else if (key_dtype == "int64") {
if (value_dtype == "int32") {
thrust_stable_sort_by_key<int64_t, int>(keys_in, values_in, keys_out, values_out,
for_scatter);
} else if (value_dtype == "int64") {
thrust_stable_sort_by_key<int64_t, int64_t>(keys_in, values_in, keys_out, values_out,
for_scatter);
} else if (value_dtype == "float32") {
thrust_stable_sort_by_key<int64_t, float>(keys_in, values_in, keys_out, values_out,
for_scatter);
} else {
LOG(FATAL) << "Unsupported value dtype: " << value_dtype;
}
} else if (key_dtype == "float32") {
if (value_dtype == "int32") {
thrust_stable_sort_by_key<float, int>(keys_in, values_in, keys_out, values_out,
for_scatter);
} else if (value_dtype == "int64") {
thrust_stable_sort_by_key<float, int64_t>(keys_in, values_in, keys_out, values_out,
for_scatter);
} else if (value_dtype == "float32") {
thrust_stable_sort_by_key<float, float>(keys_in, values_in, keys_out, values_out,
for_scatter);
} else {
LOG(FATAL) << "Unsupported value dtype: " << value_dtype;
}
} else {
LOG(FATAL) << "Unsupported key dtype: " << key_dtype;
}
});

template<typename InType, typename OutType>
void thrust_scan(DLTensor* data,
DLTensor* output,
bool exclusive) {
thrust::device_ptr<InType> data_ptr(static_cast<InType *>(data->data));
thrust::device_ptr<OutType> output_ptr(static_cast<OutType *>(output->data));
template <typename InType, typename OutType>
void thrust_scan(DLTensor* data, DLTensor* output, bool exclusive) {
thrust::device_ptr<InType> data_ptr(static_cast<InType*>(data->data));
thrust::device_ptr<OutType> output_ptr(static_cast<OutType*>(output->data));
const auto scan_size = data->shape[data->ndim - 1];

if (scan_size == 0) return;
Expand All @@ -281,19 +281,20 @@ void thrust_scan(DLTensor* data,

const bool need_cast = std::is_same<InType, OutType>::value == false;

auto data_cast_ptr = thrust::make_transform_iterator(data_ptr, [] __host__ __device__(InType v) {
return static_cast<OutType>(v);
}); // NOLINT(*)
auto data_cast_ptr = thrust::make_transform_iterator(
data_ptr, [] __host__ __device__(InType v) { return static_cast<OutType>(v); }); // NOLINT(*)

auto policy = get_thrust_exec_policy();

if (size == static_cast<size_t>(data->shape[data->ndim - 1])) {
if (exclusive && need_cast) {
thrust::exclusive_scan(data_cast_ptr, data_cast_ptr + scan_size, output_ptr);
thrust::exclusive_scan(policy, data_cast_ptr, data_cast_ptr + scan_size, output_ptr);
} else if (exclusive && !need_cast) {
thrust::exclusive_scan(data_ptr, data_ptr + scan_size, output_ptr);
thrust::exclusive_scan(policy, data_ptr, data_ptr + scan_size, output_ptr);
} else if (!exclusive && need_cast) {
thrust::inclusive_scan(data_cast_ptr, data_cast_ptr + scan_size, output_ptr);
thrust::inclusive_scan(policy, data_cast_ptr, data_cast_ptr + scan_size, output_ptr);
} else {
thrust::inclusive_scan(data_ptr, data_ptr + scan_size, output_ptr);
thrust::inclusive_scan(policy, data_ptr, data_ptr + scan_size, output_ptr);
}
} else {
// Use thrust segmented scan to compute scan on the inner most axis
Expand All @@ -305,18 +306,18 @@ void thrust_scan(DLTensor* data,
auto counting_iter = thrust::counting_iterator<size_t>(0);
// Without __host__ annotation, cub crashes
auto linear_index_to_scan_key = [scan_size] __host__ __device__(size_t i) {
return i / scan_size;
}; // NOLINT(*)
return i / scan_size;
}; // NOLINT(*)
auto key_iter = thrust::make_transform_iterator(counting_iter, linear_index_to_scan_key);

if (exclusive && need_cast) {
thrust::exclusive_scan_by_key(key_iter, key_iter + size, data_cast_ptr, output_ptr);
thrust::exclusive_scan_by_key(policy, key_iter, key_iter + size, data_cast_ptr, output_ptr);
} else if (exclusive && !need_cast) {
thrust::exclusive_scan_by_key(key_iter, key_iter + size, data_ptr, output_ptr);
thrust::exclusive_scan_by_key(policy, key_iter, key_iter + size, data_ptr, output_ptr);
} else if (!exclusive && need_cast) {
thrust::inclusive_scan_by_key(key_iter, key_iter + size, data_cast_ptr, output_ptr);
thrust::inclusive_scan_by_key(policy, key_iter, key_iter + size, data_cast_ptr, output_ptr);
} else {
thrust::inclusive_scan_by_key(key_iter, key_iter + size, data_ptr, output_ptr);
thrust::inclusive_scan_by_key(policy, key_iter, key_iter + size, data_ptr, output_ptr);
}
}
}
Expand Down
3 changes: 3 additions & 0 deletions src/runtime/cuda/cuda_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,9 @@ class CUDAThreadEntry {
// get the threadlocal workspace
static CUDAThreadEntry* ThreadLocal();
};

/*! \brief Get the CUDA stream stored on the calling thread's CUDAThreadEntry. */
inline cudaStream_t GetCUDAStream() {
  // Thread-local lookup: each thread owns one CUDAThreadEntry with its stream.
  CUDAThreadEntry* entry = CUDAThreadEntry::ThreadLocal();
  return entry->stream;
}

} // namespace runtime
} // namespace tvm
#endif // TVM_RUNTIME_CUDA_CUDA_COMMON_H_

0 comments on commit 5d4c01e

Please sign in to comment.