diff --git a/tensorflow/core/framework/embedding/embedding_var.h b/tensorflow/core/framework/embedding/embedding_var.h index 9d224fb0d6f..1cdbcfcbd89 100644 --- a/tensorflow/core/framework/embedding/embedding_var.h +++ b/tensorflow/core/framework/embedding/embedding_var.h @@ -43,9 +43,6 @@ namespace tensorflow { EventMgr* event_mgr); #endif //GOOGLE_CUDA -namespace { -const char* kInferenceMode = "INFERENCE_MODE"; -} template class GPUHashTable; @@ -632,6 +629,13 @@ class EmbeddingVar : public ResourceBase { storage_->BatchLookupOrCreateKeys(key, item_idxs, n, device); } + void Lookup(const K* key, V* val, V* default_v, + int32 default_v_num, bool is_use_default_value_tensor, + size_t n, const Eigen::GpuDevice& device) { + storage_->BatchLookup(key, val, default_v, default_v_num, + is_use_default_value_tensor, n, device); + } + int32 SlotNum() { return (emb_config_.block_num * (1 + emb_config_.slot_num)); } diff --git a/tensorflow/core/framework/embedding/gpu_hash_map_kv.h b/tensorflow/core/framework/embedding/gpu_hash_map_kv.h index 1b4ca32f689..82edf045f60 100644 --- a/tensorflow/core/framework/embedding/gpu_hash_map_kv.h +++ b/tensorflow/core/framework/embedding/gpu_hash_map_kv.h @@ -16,50 +16,60 @@ limitations under the License. #define TENSORFLOW_CORE_FRAMEWORK_EMBEDDING_GPU_HASH_MAP_KV_H_ #if GOOGLE_CUDA -#include "tensorflow/core/framework/embedding/kv_interface.h" + #include "tensorflow/core/framework/embedding/gpu_hash_table.h" +#include "tensorflow/core/framework/embedding/kv_interface.h" +#include "tensorflow/core/util/env_var.h" namespace tensorflow { + namespace embedding { -template +template class GPUHashMapKV : public KVInterface { public: GPUHashMapKV(const EmbeddingConfig& config, Allocator* alloc) - : config_(config), alloc_(alloc) { - hash_table_ = new GPUHashTable(-1, alloc); + : config_(config), alloc_(alloc), static_hash_table_(nullptr) { + TF_CHECK_OK(ReadBoolFromEnvVar(kInferenceMode, false, &is_inference_)); + if (!is_inference_) { + hash_table_ = new GPUHashTable(-1, alloc); + } } ~GPUHashMapKV() override { - for (int i = 0; i < hash_table_->bank_ptrs.size(); ++i) { - TypedAllocator::Deallocate( - alloc_, hash_table_->bank_ptrs[i], - value_len_ * hash_table_->initial_bank_size); - TypedAllocator::Deallocate( - alloc_, hash_table_->existence_flag_ptrs[i], - hash_table_->initial_bank_size); - } - if (hash_table_->mem_bank_num != 0) { - auto num_elements = hash_table_->mem_bank_num * - (config_.block_num * (1 + config_.slot_num)); + if (is_inference_) { TypedAllocator::Deallocate( - alloc_, hash_table_->d_bank_ptrs, num_elements); - TypedAllocator::Deallocate( - alloc_, hash_table_->d_existence_flag_ptrs, num_elements); + alloc_, static_hash_table_->values_d, + static_hash_table_->capacity_ * static_hash_table_->dimension_); + delete static_hash_table_; + } else { + for (int i = 0; i < hash_table_->bank_ptrs.size(); ++i) { + TypedAllocator::Deallocate(alloc_, hash_table_->bank_ptrs[i], + value_len_ * hash_table_->initial_bank_size); + TypedAllocator::Deallocate(alloc_, hash_table_->existence_flag_ptrs[i], + hash_table_->initial_bank_size); + } + if (hash_table_->mem_bank_num != 0) { + auto num_elements = hash_table_->mem_bank_num * + (config_.block_num * (1 + config_.slot_num)); + TypedAllocator::Deallocate(alloc_, hash_table_->d_bank_ptrs, + num_elements); + TypedAllocator::Deallocate(alloc_, hash_table_->d_existence_flag_ptrs, + num_elements); + } + delete hash_table_; } - delete hash_table_; } TF_DISALLOW_COPY_AND_ASSIGN(GPUHashMapKV); - void SetValueLen(int64 
value_len) { - value_len_ = value_len; - } + void SetValueLen(int64 value_len) { value_len_ = value_len; } - Status BatchLookupOrCreateKeys(const K* keys, size_t n, - int32* item_idxs, const Eigen::GpuDevice& device) { + Status BatchLookupOrCreateKeys(const K* keys, size_t n, int32* item_idxs, + const Eigen::GpuDevice& device) { mutex_lock lock(lock_); - int remaining_size = n + *(hash_table_->start_idx) - + int remaining_size = + n + *(hash_table_->start_idx) - hash_table_->mem_bank_num * hash_table_->initial_bank_size; if (remaining_size > 0) { Resize(remaining_size); @@ -71,99 +81,126 @@ class GPUHashMapKV : public KVInterface { } Status BatchLookupOrCreate(const K* keys, V* val, V* default_v, - int32 default_v_num, bool is_use_default_value_tensor, - size_t n, const Eigen::GpuDevice& device) { - int32* item_idxs = TypedAllocator::Allocate(alloc_, n, - AllocationAttributes()); + int32 default_v_num, + bool is_use_default_value_tensor, size_t n, + const Eigen::GpuDevice& device) { + int32* item_idxs = + TypedAllocator::Allocate(alloc_, n, AllocationAttributes()); BatchLookupOrCreateKeys(keys, n, item_idxs, device); functor::KvLookupCreateEmb()( - keys, val, default_v, value_len_, item_idxs, n, - config_.emb_index, default_v_num, is_use_default_value_tensor, - hash_table_->d_bank_ptrs, hash_table_->d_existence_flag_ptrs, + keys, val, default_v, value_len_, item_idxs, n, config_.emb_index, + default_v_num, is_use_default_value_tensor, hash_table_->d_bank_ptrs, + hash_table_->d_existence_flag_ptrs, (config_.block_num * (1 + config_.slot_num)), hash_table_->initial_bank_size, device.stream()); TypedAllocator::Deallocate(alloc_, item_idxs, n); return Status::OK(); } - void GetSnapshot(std::vector* key_list, - std::vector* value_list, - const EmbeddingConfig& emb_config) { + void GetSnapshot(std::vector* key_list, std::vector* value_list, + const EmbeddingConfig& emb_config) { + if (is_inference_) return; // Special case for testing in training mode; auto size = hash_table_->Size(); - if (size > 0) { - int32* item_idxs = TypedAllocator::Allocate( - alloc_, size, AllocationAttributes()); - K* keys_gpu = TypedAllocator::Allocate( - alloc_, size, AllocationAttributes()); - V* values_gpu = TypedAllocator::Allocate( - alloc_, size * value_len_, AllocationAttributes()); - V* values = TypedAllocator::Allocate( - cpu_allocator(), size * value_len_, AllocationAttributes()); - key_list->resize(size); - - auto slot_num = config_.block_num * (1 + config_.slot_num); - functor::KvKeyGetSnapshot()( - keys_gpu, item_idxs, emb_config.emb_index, - emb_config.primary_emb_index, hash_table_->d_existence_flag_ptrs, - hash_table_->mem_bank_num, slot_num, - hash_table_->initial_bank_size, hash_table_, size, NULL); - functor::KvEmbGetSnapshot()( - keys_gpu, values_gpu, -1, value_len_, item_idxs,size, - emb_config.emb_index, hash_table_->d_bank_ptrs, - hash_table_->mem_bank_num, slot_num, - hash_table_->initial_bank_size, NULL); - - for (int64 i = 0; i < size; i++) { - value_list->emplace_back(values + i * value_len_); - } - - cudaMemcpyAsync(const_cast(key_list->data()), - keys_gpu, size * sizeof(K), cudaMemcpyDeviceToHost); - cudaMemcpyAsync(values, values_gpu, size * value_len_ * sizeof(V), - cudaMemcpyDeviceToHost); - EventSynchronize(NULL); - - TypedAllocator::Deallocate(alloc_, item_idxs, size); - TypedAllocator::Deallocate(alloc_, keys_gpu, size); - TypedAllocator::Deallocate(alloc_, values_gpu, size * value_len_); + if (size <= 0) return; + + int32* item_idxs = + TypedAllocator::Allocate(alloc_, size, 
AllocationAttributes()); + K* keys_gpu = + TypedAllocator::Allocate(alloc_, size, AllocationAttributes()); + V* values_gpu = TypedAllocator::Allocate(alloc_, size * value_len_, + AllocationAttributes()); + V* values = TypedAllocator::Allocate(cpu_allocator(), size * value_len_, + AllocationAttributes()); + key_list->resize(size); + for (int64 i = 0; i < size; i++) { + value_list->emplace_back(values + i * value_len_); } + + auto slot_num = emb_config.block_num * (1 + emb_config.slot_num); + functor::KvKeyGetSnapshot()( + keys_gpu, item_idxs, emb_config.emb_index, emb_config.primary_emb_index, + hash_table_->d_existence_flag_ptrs, hash_table_->mem_bank_num, slot_num, + hash_table_->initial_bank_size, hash_table_, size, NULL); + + functor::KvEmbGetSnapshot()( + keys_gpu, values_gpu, -1, value_len_, item_idxs, size, + emb_config.emb_index, hash_table_->d_bank_ptrs, + hash_table_->mem_bank_num, slot_num, hash_table_->initial_bank_size, + NULL); + + cudaMemcpyAsync(const_cast(key_list->data()), keys_gpu, + size * sizeof(K), cudaMemcpyDeviceToHost); + cudaMemcpyAsync(values, values_gpu, size * value_len_ * sizeof(V), + cudaMemcpyDeviceToHost); + EventSynchronize(NULL); + TypedAllocator::Deallocate(alloc_, item_idxs, size); + TypedAllocator::Deallocate(alloc_, keys_gpu, size); + TypedAllocator::Deallocate(alloc_, values_gpu, size * value_len_); } Status Import(const std::vector& key_import, - const std::vector& value_import, - const Eigen::GpuDevice* device, - const EmbeddingConfig& emb_config) { + const std::vector& value_import, + const Eigen::GpuDevice* device, + const EmbeddingConfig& emb_config) { int n = key_import.size(); auto stream = device->stream(); - if (n > 0) { - int32* item_idxs = TypedAllocator::Allocate( - alloc_, n, AllocationAttributes()); - K* key_gpu = TypedAllocator::Allocate( - alloc_, n, AllocationAttributes()); - cudaMemcpyAsync(key_gpu, key_import.data(), - key_import.size() * sizeof(K), cudaMemcpyHostToDevice, stream); - BatchLookupOrCreateKeys(key_gpu, n, item_idxs, *device); - V* value_gpu = TypedAllocator::Allocate( + + if (is_inference_) { + if (n == 0) { + LOG(INFO) << "Size of keys in EmbeddingVar: " << emb_config.name + << " is 0 while loading in inference mode!"; + return Status::OK(); + } + static_hash_table_ = + new GPUStaticHashTable(n, value_len_, -1, -1, alloc_, stream); + K* keys_d = + TypedAllocator::Allocate(alloc_, n, AllocationAttributes()); + cudaMemcpyAsync(keys_d, key_import.data(), n * sizeof(K), + cudaMemcpyHostToDevice, stream); + static_hash_table_->values_d = TypedAllocator::Allocate( alloc_, value_import.size(), AllocationAttributes()); - cudaMemcpyAsync(value_gpu, value_import.data(), - value_import.size() * sizeof(V), cudaMemcpyHostToDevice, stream); - - functor::KvUpdateEmb()( - key_import.data(), value_gpu, value_len_, item_idxs, n, - emb_config.emb_index, key_import.size(), - hash_table_->d_bank_ptrs, hash_table_->d_existence_flag_ptrs, - (emb_config.block_num * (1 + emb_config.slot_num)), - hash_table_->initial_bank_size, stream); + cudaMemcpyAsync(static_hash_table_->values_d, value_import.data(), + value_import.size() * sizeof(V), cudaMemcpyHostToDevice, + stream); + functor::KvInitStaticMap()( + keys_d, static_hash_table_, n, value_len_, stream); EventSynchronize(stream); - TypedAllocator::Deallocate(alloc_, item_idxs, n); - TypedAllocator::Deallocate(alloc_, value_gpu, value_import.size()); - TypedAllocator::Deallocate(alloc_, key_gpu, n); + + TypedAllocator::Deallocate(alloc_, keys_d, n); + } else { + if (n > 0) { + int32* item_idxs 
= + TypedAllocator::Allocate(alloc_, n, AllocationAttributes()); + K* key_gpu = + TypedAllocator::Allocate(alloc_, n, AllocationAttributes()); + cudaMemcpyAsync(key_gpu, key_import.data(), + key_import.size() * sizeof(K), cudaMemcpyHostToDevice, + stream); + BatchLookupOrCreateKeys(key_gpu, n, item_idxs, *device); + V* value_gpu = TypedAllocator::Allocate(alloc_, value_import.size(), + AllocationAttributes()); + cudaMemcpyAsync(value_gpu, value_import.data(), + value_import.size() * sizeof(V), cudaMemcpyHostToDevice, + stream); + + functor::KvUpdateEmb()( + key_import.data(), value_gpu, value_len_, item_idxs, n, + emb_config.emb_index, key_import.size(), hash_table_->d_bank_ptrs, + hash_table_->d_existence_flag_ptrs, + (emb_config.block_num * (1 + emb_config.slot_num)), + hash_table_->initial_bank_size, stream); + EventSynchronize(stream); + TypedAllocator::Deallocate(alloc_, item_idxs, n); + TypedAllocator::Deallocate(alloc_, value_gpu, value_import.size()); + TypedAllocator::Deallocate(alloc_, key_gpu, n); + } } + return Status::OK(); } Status BatchLookupOrCreate(const K* keys, size_t n, - ValuePtr** value_ptrs) override { + ValuePtr** value_ptrs) override { return Status::OK(); } @@ -171,25 +208,21 @@ class GPUHashMapKV : public KVInterface { return Status::OK(); } - Status Contains(K key) override { - return Status::OK(); - } + Status Contains(K key) override { return Status::OK(); } Status Insert(K key, const ValuePtr* value_ptr) override { return Status::OK(); } - Status Remove(K key) override { - return Status::OK(); - } + Status Remove(K key) override { return Status::OK(); } Status BatchLookup(const K* keys, size_t size, - ValuePtr** value_ptrs) override { + ValuePtr** value_ptrs) override { return Status::OK(); } Status BatchInsert(const std::vector& keys, - const std::vector*>& value_ptrs) override { + const std::vector*>& value_ptrs) override { return Status::OK(); } @@ -198,46 +231,43 @@ class GPUHashMapKV : public KVInterface { } Status BatchCommit(const std::vector& keys, - const std::vector*>& value_ptrs) override { + const std::vector*>& value_ptrs) override { return Status::OK(); } - int64 Size() const override { - return 0; - } + int64 Size() const override { return 0; } - void SetTotalDims(int total_dims) override { - } + void SetTotalDims(int total_dims) override {} - void FreeValuePtr(ValuePtr* value_ptr) override { - } + void FreeValuePtr(ValuePtr* value_ptr) override {} Status Commit(K key, const ValuePtr* value_ptr) override { return Status::OK(); } Status GetSnapshot(std::vector* key_list, - std::vector*>* value_ptr_list) override { + std::vector*>* value_ptr_list) override { return Status::OK(); } - std::string DebugString() const override { - return std::string(); - } + std::string DebugString() const override { return std::string(); } - Iterator* GetIterator() override { - return nullptr; - } + Iterator* GetIterator() override { return nullptr; } - GPUHashTable* HashTable() override { - return hash_table_; + GPUHashTable* HashTable() override { return hash_table_; } + + Status BatchLookup(const K* keys, V* val, V* default_v, int32 default_v_num, + bool is_use_default_value_tensor, size_t n, + const Eigen::GpuDevice& device) override { + functor::KvLookupKey()( + keys, val, n, value_len_, static_hash_table_, device.stream()); + return Status::OK(); } private: void Resize(int hint) { while (hint > 0) { - for (int i = 0; i < (config_.block_num * - (1 + config_.slot_num)); ++i) { + for (int i = 0; i < (config_.block_num * (1 + config_.slot_num)); ++i) { V* ptr = 
TypedAllocator::Allocate( alloc_, value_len_ * hash_table_->initial_bank_size, AllocationAttributes()); @@ -251,23 +281,23 @@ class GPUHashMapKV : public KVInterface { ++hash_table_->mem_bank_num; } - auto num_elements = hash_table_->mem_bank_num * ( - config_.block_num * (1 + config_.slot_num)); + auto num_elements = hash_table_->mem_bank_num * + (config_.block_num * (1 + config_.slot_num)); if (hash_table_->d_bank_ptrs) { TypedAllocator::Deallocate(alloc_, hash_table_->d_bank_ptrs, - num_elements); + num_elements); TypedAllocator::Deallocate(alloc_, hash_table_->d_existence_flag_ptrs, - num_elements); + num_elements); } hash_table_->d_bank_ptrs = TypedAllocator::Allocate( alloc_, num_elements, AllocationAttributes()); cudaMemcpy(hash_table_->d_bank_ptrs, hash_table_->bank_ptrs.data(), - num_elements * sizeof(V*), cudaMemcpyHostToDevice); + num_elements * sizeof(V*), cudaMemcpyHostToDevice); hash_table_->d_existence_flag_ptrs = TypedAllocator::Allocate( alloc_, num_elements, AllocationAttributes()); cudaMemcpy(hash_table_->d_existence_flag_ptrs, - hash_table_->existence_flag_ptrs.data(), - num_elements * sizeof(bool*), cudaMemcpyHostToDevice); + hash_table_->existence_flag_ptrs.data(), + num_elements * sizeof(bool*), cudaMemcpyHostToDevice); } void EventSynchronize(const cudaStream_t& stream) { @@ -280,14 +310,16 @@ class GPUHashMapKV : public KVInterface { private: EmbeddingConfig config_; + bool is_inference_; + GPUStaticHashTable* static_hash_table_; GPUHashTable* hash_table_; Allocator* alloc_; int64 value_len_; mutex lock_; }; -} // namespace embedding -} // namespace tensorflow +} // namespace embedding +} // namespace tensorflow -#endif // GOOGLE_CUDA -#endif // TENSORFLOW_CORE_FRAMEWORK_EMBEDDING_GPU_HASH_MAP_KV_H_ +#endif // GOOGLE_CUDA +#endif // TENSORFLOW_CORE_FRAMEWORK_EMBEDDING_GPU_HASH_MAP_KV_H_ diff --git a/tensorflow/core/framework/embedding/gpu_hash_table.cu.cc b/tensorflow/core/framework/embedding/gpu_hash_table.cu.cc index 1a2465bf0b7..b56bd5b7210 100644 --- a/tensorflow/core/framework/embedding/gpu_hash_table.cu.cc +++ b/tensorflow/core/framework/embedding/gpu_hash_table.cu.cc @@ -22,9 +22,8 @@ #include #include -#include - #include "cuco/dynamic_map.cuh" +#include "cuco/static_map.cuh" #include "tensorflow/core/framework/allocator.h" #include "tensorflow/core/framework/embedding/gpu_hash_table.h" #include "tensorflow/core/framework/register_types.h" @@ -125,15 +124,189 @@ int32 GPUHashTable::Size() { return hash_table->map_.get_size(); } -#define REGISTER_ALL_TYPE(type) \ - template class GPUHashTable; \ - template class GPUHashTable; +template > +class StaticHashTable { + public: + cuco::static_map map_; + + StaticHashTable(size_t initial_capacity, K empty_key_sentinel, + int32 empty_value_sentinel, CUCOAllocator alloc) + : map_(initial_capacity, empty_key_sentinel, empty_value_sentinel, + alloc) {} +}; + +template +GPUStaticHashTable::GPUStaticHashTable(size_t capacity, int dimension, + K empty_key_sentinel, + int32 empty_value_sentinel, + Allocator* alloc, + cudaStream_t stream) { + capacity_ = capacity; + dimension_ = dimension; + // cudaMallocAsync(&values_d, sizeof(V) * dimension * capacity, stream); + // cudaMallocManaged(&values_d, sizeof(V) * dimension * capacity); + + hash_table = new StaticHashTable( + capacity / 0.8 /*load_factor*/, empty_key_sentinel, empty_value_sentinel, + gpu_hash_map_tf_allocator(alloc)); +} + +template +GPUStaticHashTable::~GPUStaticHashTable() { + delete hash_table; + delete default_values; + cudaFree(values_d); +} + +template 
+std::size_t GPUStaticHashTable::Size() { + return hash_table->map_.get_size(); +} + +#define REGISTER_ALL_TYPE(type) \ + template class GPUHashTable; \ + template class GPUHashTable; \ + template class GPUStaticHashTable; \ + template class GPUStaticHashTable; TF_CALL_REAL_NUMBER_TYPES(REGISTER_ALL_TYPE) #undef REGISTER_ALL_TYPE namespace functor { using atomicT = cuda::atomic; +template , + typename KeyEqual = thrust::equal_to> +__global__ void kv_initialize_static_map(const Key* key_first, int32 num_items, + int32 dimension, + mutableViewT map_mutable_view, + atomicT* num_successes, + Hash hash = Hash{}, + KeyEqual key_equal = KeyEqual{}) { + typedef cub::BlockReduce BlockReduce; + __shared__ typename BlockReduce::TempStorage temp_storage; + std::size_t thread_num_successes = 0; + + auto tile = cg::tiled_partition(cg::this_thread_block()); + auto tid = blockDim.x * blockIdx.x + threadIdx.x; + auto key_idx = tid / tile_size; + + while (key_idx < num_items) { + auto key = *(key_first + key_idx); + int32 value = key_idx * dimension; + + auto const insert_pair = cuco::pair_type{key, value}; + if (map_mutable_view.insert(tile, insert_pair, hash, key_equal) && + tile.thread_rank() == 0) { + thread_num_successes++; + } + + key_idx += (gridDim.x * blockDim.x) / tile_size; + } + std::size_t block_num_successes = + BlockReduce(temp_storage).Sum(thread_num_successes); + if (threadIdx.x == 0) { + *num_successes += block_num_successes; + } +} + +template +struct KvInitStaticMap { + void operator()(const Key* keys, GPUStaticHashTable* hash_table, + int32 num_items, int32 dimension, cudaStream_t stream) { + using MutableViewT = typename cuco::static_map< + Key, int32, cuda::thread_scope_device, + gpu_hash_map_tf_allocator>::device_mutable_view; + + auto& map = hash_table->hash_table->map_; + size_t num_to_insert = num_items; + while (num_to_insert > 0) { + static_assert(sizeof(std::size_t) == sizeof(atomicT)); + CUCO_CUDA_TRY( + cudaMemsetAsync(map.get_num_success(), 0, sizeof(atomicT), stream)); + + auto n = std::min((size_t)65535, num_to_insert); + auto const block_size = 128; + auto stride = 1; + auto const tile_size = 4; + auto const grid_size = + (tile_size * n + stride * block_size - 1) / (stride * block_size); + TF_CHECK_OK(GpuLaunchKernel( + kv_initialize_static_map, + thrust::equal_to>, + grid_size, block_size, 0, stream, keys, n, dimension, + map.get_device_mutable_view(), map.get_num_success(), + cuco::detail::MurmurHash3_32{}, thrust::equal_to{})); + + CUCO_CUDA_TRY(cudaStreamSynchronize(stream)); + + std::size_t h_num_successes = + map.get_num_success()->load(cuda::std::memory_order_relaxed); + map.update_size(h_num_successes); + keys += n; + num_to_insert -= n; + } + } +}; + +template , + typename KeyEqual = thrust::equal_to> +__global__ void kv_lookup_key_kernel(const Key* key_first, const V* value_srcs, + V* value_first, size_t num_items, + int32 dimension, ViewT map_views, + Hash hash = Hash{}, + KeyEqual key_equal = KeyEqual{}) { + auto grid = cooperative_groups::this_grid(); + auto block = cooperative_groups::this_thread_block(); + auto tile = cooperative_groups::tiled_partition(block); + + auto tid = blockDim.x * blockIdx.x + threadIdx.x; + auto key_idx = tid / tile_size; // actual thread idx + auto empty_value_sentinel = map_views.get_empty_value_sentinel(); + + while (key_idx < num_items) { + auto key = *(key_first + key_idx); + int32 found_value = empty_value_sentinel; + auto found = map_views.find(tile, key, hash, key_equal); + if (found != map_views.end()) { + found_value = 
found->second; + } + + if (tile.thread_rank() == 0) { + for (auto id = threadIdx.x; id < dimension; id += blockDim.x) { + value_first[key_idx * dimension + id] = value_srcs[found_value + id]; + } + } + key_idx += (gridDim.x * blockDim.x) / tile_size; + } +} + +template +struct KvLookupKey { + void operator()(const Key* keys, V* vals, int32 num_items, int32 dimension, + GPUStaticHashTable* hash_table, cudaStream_t stream) { + using ViewT = typename cuco::static_map< + Key, int32, cuda::thread_scope_device, + gpu_hash_map_tf_allocator>::device_view; + auto& map = hash_table->hash_table->map_; + + auto const block_size = 128; + auto const stride = 1; + auto const tile_size = 4; + auto const grid_size = (tile_size * num_items + stride * block_size - 1) / + (stride * block_size); + TF_CHECK_OK(GpuLaunchKernel( + kv_lookup_key_kernel, grid_size, + block_size, 0, stream, keys, hash_table->values_d, vals, num_items, + dimension, map.get_device_view(), cuco::detail::MurmurHash3_32{}, + thrust::equal_to{})); + } +}; + template , @@ -220,6 +393,7 @@ struct KvLookupInsertKey { sizeof(atomicT), device_id)); auto n = std::min(capacity_remaining, num_to_insert); + auto const block_size = 128; auto const stride = 1; auto const tile_size = 4; @@ -274,7 +448,8 @@ __global__ void kv_lookup_or_create_emb_kernel( } } for (auto id = threadIdx.x; id < dim; id += blockDim.x) { - val[item_idx * dim + id] = d_banks[slot_offset][offset_in_bank * dim + id]; + val[item_idx * dim + id] = + d_banks[slot_offset][offset_in_bank * dim + id]; } } @@ -434,6 +609,10 @@ struct KvEmbGetSnapshot { } // namespace functor #define REGISTER_ALL_TYPE(type) \ + template struct functor::KvInitStaticMap; \ + template struct functor::KvInitStaticMap; \ + template struct functor::KvLookupKey; \ + template struct functor::KvLookupKey; \ template struct functor::KvLookupInsertKey; \ template struct functor::KvLookupInsertKey; \ template struct functor::KvLookupCreateEmb; \ @@ -449,4 +628,4 @@ TF_CALL_REAL_NUMBER_TYPES(REGISTER_ALL_TYPE) } // namespace tensorflow -#endif // GOOGLE_CUDA \ No newline at end of file +#endif // GOOGLE_CUDA diff --git a/tensorflow/core/framework/embedding/gpu_hash_table.h b/tensorflow/core/framework/embedding/gpu_hash_table.h index d57970aecd6..076f3e767c7 100644 --- a/tensorflow/core/framework/embedding/gpu_hash_table.h +++ b/tensorflow/core/framework/embedding/gpu_hash_table.h @@ -17,6 +17,7 @@ limitations under the License. 
#if GOOGLE_CUDA #include + #include "tensorflow/core/framework/typed_allocator.h" #include "tensorflow/core/lib/core/status.h" #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" @@ -28,10 +29,32 @@ class gpu_hash_map_tf_allocator; template class DynamicHashTable; +template +class StaticHashTable; + +template +class GPUStaticHashTable { + public: + GPUStaticHashTable(size_t capacity, int dimension, K empty_key_sentinel, + int32 empty_value_sentinel, Allocator* alloc, + cudaStream_t stream); + + ~GPUStaticHashTable(); + + std::size_t Size(); + + StaticHashTable>* hash_table; + V* values_d{nullptr}; + int dimension_; + V* default_values{nullptr}; + int capacity_; +}; + template class GPUHashTable { -public: - GPUHashTable(K empty_key_sentinel, Allocator* alloc, size_t initial_capacity=50000); + public: + GPUHashTable(K empty_key_sentinel, Allocator* alloc, + size_t initial_capacity = 50000); ~GPUHashTable(); @@ -49,83 +72,65 @@ class GPUHashTable { }; namespace functor { + template -struct KvLookupInsertKey { - void operator()(const Key* key_first, - int32* value_first, - int32 num_items, - GPUHashTable* hash_table, - cuda::atomic* start_idx, +struct KvLookupKey { + void operator()(const Key* key_first, V* value_first, int32 num_items, + int32 dimension, GPUStaticHashTable* hash_table, cudaStream_t stream); }; +template +struct KvInitStaticMap { + void operator()(const Key* key_first, GPUStaticHashTable* hash_table, + int32 num_items, int32 dimension, cudaStream_t stream); +}; + +template +struct KvLookupInsertKey { + void operator()( + const Key* key_first, int32* value_first, int32 num_items, + GPUHashTable* hash_table, + cuda::atomic* start_idx, + cudaStream_t stream); +}; + template struct KvLookupCreateEmb { - void operator()(const Key* key_first, - Value* val, - Value* default_v, - int64 dim, - int32* item_idxs, - int32 num_items, - int32 slot_idx, - int32 default_v_num, - bool is_use_default_value_tensor, - Value** d_banks, - bool** d_flags, - int32 slot_num, - int32 bank_size, - cudaStream_t stream); + void operator()(const Key* key_first, Value* val, Value* default_v, int64 dim, + int32* item_idxs, int32 num_items, int32 slot_idx, + int32 default_v_num, bool is_use_default_value_tensor, + Value** d_banks, bool** d_flags, int32 slot_num, + int32 bank_size, cudaStream_t stream); }; template struct KvUpdateEmb { - void operator()(const Key* key_first, - Value* default_v, - int64 dim, - int32* item_idxs, - int32 num_items, - int32 slot_idx, - int32 default_v_num, - Value** d_banks, - bool** d_flags, - int32 slot_num, - int32 bank_size, - cudaStream_t stream); + void operator()(const Key* key_first, Value* default_v, int64 dim, + int32* item_idxs, int32 num_items, int32 slot_idx, + int32 default_v_num, Value** d_banks, bool** d_flags, + int32 slot_num, int32 bank_size, cudaStream_t stream); }; template struct KvKeyGetSnapshot { - void operator()(Key* key_first, - int32* value_first, - int32 slot_idx, - int32 primary_slot_idx, - bool** d_flags, - int32 bank_num, - int32 slot_num, - int32 bank_size, - GPUHashTable* hash_table, - int32 ev_size, + void operator()(Key* key_first, int32* value_first, int32 slot_idx, + int32 primary_slot_idx, bool** d_flags, int32 bank_num, + int32 slot_num, int32 bank_size, + GPUHashTable* hash_table, int32 ev_size, cudaStream_t stream); }; template struct KvEmbGetSnapshot { - void operator()(Key* key, - Value* val, - Key empty_key_sentinel, - int64 dim, - int32* item_idxs, - int32 num_items, - int32 slot_idx, - Value** d_banks, - int32 bank_num, - 
int32 slot_num, - int32 bank_size, - cudaStream_t stream); + void operator()(Key* key, Value* val, Key empty_key_sentinel, int64 dim, + int32* item_idxs, int32 num_items, int32 slot_idx, + Value** d_banks, int32 bank_num, int32 slot_num, + int32 bank_size, cudaStream_t stream); }; -} // namespace functor -} // namespace tensorflow +} // namespace functor +} // namespace tensorflow -#endif // GOOGLE_CUDA -#endif // TENSORFLOW_CORE_FRAMEWORK_EMBEDDING_GPU_HASH_TABLE_H_ +#endif // GOOGLE_CUDA +#endif // TENSORFLOW_CORE_FRAMEWORK_EMBEDDING_GPU_HASH_TABLE_H_ \ No newline at end of file diff --git a/tensorflow/core/framework/embedding/kv_interface.h b/tensorflow/core/framework/embedding/kv_interface.h index 7e4436a7845..64e0c4685f0 100644 --- a/tensorflow/core/framework/embedding/kv_interface.h +++ b/tensorflow/core/framework/embedding/kv_interface.h @@ -19,6 +19,9 @@ limitations under the License. #include "tensorflow/core/lib/core/status.h" namespace tensorflow { +namespace { +const char* kInferenceMode = "INFERENCE_MODE"; +} template class ValuePtr; @@ -107,12 +110,18 @@ class KVInterface { return Status::OK(); } + virtual Status BatchLookup(const K* keys, V* val, V* default_v, + int32 default_v_num, bool is_use_default_value_tensor, + size_t n, const Eigen::GpuDevice& device) { + return Status(error::Code::UNIMPLEMENTED, + "Unimplemented for BatchLookup in KVInterface."); + } + virtual GPUHashTable* HashTable() { return nullptr; } virtual void SetValueLen(int64 value_len) {} - }; } // namespace embedding diff --git a/tensorflow/core/framework/embedding/single_tier_storage.h b/tensorflow/core/framework/embedding/single_tier_storage.h index df1ed5e7bfb..ad9dc4e15b6 100644 --- a/tensorflow/core/framework/embedding/single_tier_storage.h +++ b/tensorflow/core/framework/embedding/single_tier_storage.h @@ -455,6 +455,13 @@ class HbmStorage : public SingleTierStorage { SingleTierStorage::kv_->BatchLookupOrCreateKeys(key, n, item_idxs, device); } + void BatchLookup(const K* key, V* val, V* default_v, + int32 default_v_num, bool is_use_default_value_tensor, + size_t n, const Eigen::GpuDevice& device) override { + SingleTierStorage::kv_->BatchLookup(key, val, default_v, default_v_num, + is_use_default_value_tensor, n, device); + } + int64 GetSnapshot(std::vector* key_list, std::vector* value_list, std::vector* version_list, diff --git a/tensorflow/core/framework/embedding/storage.h b/tensorflow/core/framework/embedding/storage.h index fa87b574f79..cc22bb4712a 100644 --- a/tensorflow/core/framework/embedding/storage.h +++ b/tensorflow/core/framework/embedding/storage.h @@ -114,6 +114,9 @@ class Storage { size_t n, const Eigen::GpuDevice& device) {} virtual void BatchLookupOrCreateKeys(const K* key, int32* item_idxs, size_t n, const Eigen::GpuDevice& device) {} + virtual void BatchLookup(const K* keys, V* val, V* default_v, + int32 default_v_num, bool is_use_default_value_tensor, + size_t n, const Eigen::GpuDevice& device) {} virtual void ImportToHbm(const std::vector& keys, const std::vector& values, const Eigen::GpuDevice* device, const EmbeddingConfig& emb_config) {}; diff --git a/tensorflow/core/kernels/group_embedding/group_embedding_lookup_sparse_forward_ops.cu.cc b/tensorflow/core/kernels/group_embedding/group_embedding_lookup_sparse_forward_ops.cu.cc index 1ad0c352b4d..97f44eb8ae6 100644 --- a/tensorflow/core/kernels/group_embedding/group_embedding_lookup_sparse_forward_ops.cu.cc +++ b/tensorflow/core/kernels/group_embedding/group_embedding_lookup_sparse_forward_ops.cu.cc @@ -52,6 +52,26 @@ class 
GroupEmbeddingVarLookupOp return default_v + len * (id % total_dim); }; } + bool is_inference; + TF_CHECK_OK(ReadBoolFromEnvVar(kInferenceMode, false, &is_inference)); + if (!is_inference) { + lookup_fn_ = [](EmbeddingVar* ev, const TFKey* key, + TValue* val, TValue* default_v, int32 default_v_num, + bool is_use_default_value_tensor, + size_t n, const Eigen::GpuDevice& device) { + ev->LookupOrCreate(key, val, default_v, default_v_num, + is_use_default_value_tensor, n, device); + }; + } else { + lookup_fn_ = [](EmbeddingVar* ev, const TFKey* key, + TValue* val, TValue* default_v, int32 default_v_num, + bool is_use_default_value_tensor, + size_t n, const Eigen::GpuDevice& device) { + ev->Lookup(key, val, default_v, default_v_num, + is_use_default_value_tensor, n, device); + }; + } + } ~GroupEmbeddingVarLookupOp() { delete[] occupy_flag_; } @@ -67,6 +87,7 @@ class GroupEmbeddingVarLookupOp this->num_lookups_, this->dimension_, this->max_norm_, gpu_allocator); std::vector tensor_list; + tensor_list.reserve(this->num_lookups_); for (int i = 0; i < this->num_lookups_; ++i) { EmbeddingVar* ev = nullptr; @@ -107,12 +128,13 @@ class GroupEmbeddingVarLookupOp auto default_values_matrix = default_values.shaped({default_value_num, dimension}); TValue* default_v_base = &default_values_matrix(0, 0); - ev->LookupOrCreate(key_base, out_base, default_v_base, - default_value_num, is_use_default_value_tensor_, N, - device); + lookup_fn_(ev, key_base, out_base, default_v_base, + default_value_num, is_use_default_value_tensor_, N, + device); + } else { - ev->LookupOrCreate(key_base, out_base, ev->GetDefaultValuePtr(), - ev->GetDefaultValueDim(), true, N, device); + lookup_fn_(ev, key_base, out_base, ev->GetDefaultValuePtr(), + ev->GetDefaultValueDim(), true, N, device); } } else { auto out_flat = @@ -287,6 +309,10 @@ class GroupEmbeddingVarLookupOp private: std::map hash_map_; std::function get_default_v_fn_; + std::function* ev, const TFKey* key, + TValue* val, TValue* default_v, int32 default_v_num, + bool is_use_default_value_tensor, + size_t n, const Eigen::GpuDevice& device)> lookup_fn_; mutable easy_spinrwlock_t mu_ = EASY_SPINRWLOCK_INITIALIZER; bool* occupy_flag_{nullptr}; mutex m_init_occupy_flag_; diff --git a/tensorflow/core/kernels/kv_variable_lookup_ops.cc b/tensorflow/core/kernels/kv_variable_lookup_ops.cc index c5c5cc22c33..6b3139645c0 100644 --- a/tensorflow/core/kernels/kv_variable_lookup_ops.cc +++ b/tensorflow/core/kernels/kv_variable_lookup_ops.cc @@ -783,6 +783,25 @@ class KvResourceGatherGPUOp : public OpKernel { return 1; }; } + bool is_inference; + TF_CHECK_OK(ReadBoolFromEnvVar(kInferenceMode, false, &is_inference)); + if (!is_inference) { + lookup_fn_ = [](EmbeddingVar* ev, const TKey* key, + TValue* val, TValue* default_v, int32 default_v_num, + bool is_use_default_value_tensor, + size_t n, const Eigen::GpuDevice& device) { + ev->LookupOrCreate(key, val, default_v, default_v_num, + is_use_default_value_tensor, n, device); + }; + } else { + lookup_fn_ = [](EmbeddingVar* ev, const TKey* key, + TValue* val, TValue* default_v, int32 default_v_num, + bool is_use_default_value_tensor, + size_t n, const Eigen::GpuDevice& device) { + ev->Lookup(key, val, default_v, default_v_num, + is_use_default_value_tensor, n, device); + }; + } } ~KvResourceGatherGPUOp() { @@ -851,11 +870,11 @@ class KvResourceGatherGPUOp : public OpKernel { auto default_values_matrix = default_values.shaped( {default_value_num, ev->ValueLen()}); TValue* default_v_base = &default_values_matrix(0, 0); - 
ev->LookupOrCreate(key_base, out_base, default_v_base, + lookup_fn_(ev, key_base, out_base, default_v_base, default_value_num, is_use_default_value_tensor_, indices_size, device); } else { - ev->LookupOrCreate(key_base, out_base, ev->GetDefaultValuePtr(), + lookup_fn_(ev, key_base, out_base, ev->GetDefaultValuePtr(), ev->GetDefaultValueDim(), is_use_default_value_tensor_, indices_size, device); } @@ -967,6 +986,10 @@ class KvResourceGatherGPUOp : public OpKernel { std::function< TValue*(TValue*, TKey, int64, int64, int64)> get_default_v_fn_; std::function get_count_fn_; + std::function* ev, const TKey* key, + TValue* val, TValue* default_v, int32 default_v_num, + bool is_use_default_value_tensor, + size_t n, const Eigen::GpuDevice& device)> lookup_fn_; std::map hash_map_; mutable easy_spinrwlock_t mu_ = EASY_SPINRWLOCK_INITIALIZER; bool* occupy_flag_ = nullptr; diff --git a/tensorflow/python/ops/embedding_variable_ops_gpu_test.py b/tensorflow/python/ops/embedding_variable_ops_gpu_test.py index 445118a3926..26ae99126b9 100644 --- a/tensorflow/python/ops/embedding_variable_ops_gpu_test.py +++ b/tensorflow/python/ops/embedding_variable_ops_gpu_test.py @@ -146,7 +146,7 @@ def testEmbeddingVariableForLookupInt64(self): g_v = opt.compute_gradients(loss) train_op = opt.apply_gradients(g_v) init = variables.global_variables_initializer() - with self.test_session(force_gpu=True) as sess: + with self.test_session() as sess: sess.run(ops.get_collection(ops.GraphKeys.EV_INIT_VAR_OPS)) sess.run(ops.get_collection(ops.GraphKeys.EV_INIT_SLOT_OPS)) sess.run([init]) @@ -198,7 +198,7 @@ def testEmbeddingVariableForGetShape(self): emb = embedding_ops.embedding_lookup(var, math_ops.cast([0,1,2,5,6,7], dtypes.int64)) shape = var.total_count() init = variables.global_variables_initializer() - with self.test_session(force_gpu=True) as sess: + with self.test_session() as sess: sess.run(ops.get_collection(ops.GraphKeys.EV_INIT_VAR_OPS)) sess.run(ops.get_collection(ops.GraphKeys.EV_INIT_SLOT_OPS)) sess.run([init]) @@ -226,7 +226,7 @@ def testEmbeddingVariableForSparseColumnSharedEmbeddingCol(self): train_op = opt.apply_gradients(g_v) init = variables.global_variables_initializer() - with self.test_session(force_gpu=True) as sess: + with self.test_session() as sess: sess.run(ops.get_collection(ops.GraphKeys.EV_INIT_VAR_OPS)) sess.run(ops.get_collection(ops.GraphKeys.EV_INIT_SLOT_OPS)) sess.run(init) @@ -255,7 +255,7 @@ def testEmbeddingVariableForFeatureFilterFromContribFeatureColumn(self): train_op = opt.apply_gradients(g_v) init = variables.global_variables_initializer() - with self.test_session(force_gpu=True) as sess: + with self.test_session() as sess: sess.run(ops.get_collection(ops.GraphKeys.EV_INIT_VAR_OPS)) sess.run(ops.get_collection(ops.GraphKeys.EV_INIT_SLOT_OPS)) sess.run([init]) @@ -291,7 +291,7 @@ def testEmbeddingVariableForSparseColumnEmbeddingCol(self): g_v = opt.compute_gradients(loss) train_op = opt.apply_gradients(g_v) init = variables.global_variables_initializer() - with self.test_session(force_gpu=True) as sess: + with self.test_session() as sess: sess.run(ops.get_collection(ops.GraphKeys.EV_INIT_VAR_OPS)) sess.run(ops.get_collection(ops.GraphKeys.EV_INIT_SLOT_OPS)) sess.run(init) @@ -311,7 +311,7 @@ def runTestAdagrad(self, var): g_v = opt.compute_gradients(loss) train_op = opt.apply_gradients(g_v) init = variables.global_variables_initializer() - with self.test_session(force_gpu=True) as sess: + with self.test_session() as sess: sess.run(ops.get_collection(ops.GraphKeys.EV_INIT_VAR_OPS)) 
sess.run(ops.get_collection(ops.GraphKeys.EV_INIT_SLOT_OPS)) sess.run([init]) @@ -346,7 +346,7 @@ def runTestFtrl(self, var, g): g_v = opt.compute_gradients(loss) train_op = opt.apply_gradients(g_v) init = variables.global_variables_initializer() - with self.test_session(graph=g, force_gpu=True) as sess: + with self.test_session(graph=g) as sess: sess.run(ops.get_collection(ops.GraphKeys.EV_INIT_VAR_OPS)) sess.run(ops.get_collection(ops.GraphKeys.EV_INIT_SLOT_OPS)) sess.run([init]) @@ -383,7 +383,7 @@ def testEmbeddingVariableForGeneralConstInitializer(self): partitioner=partitioned_variables.fixed_size_partitioner(num_shards=4)) emb = embedding_ops.embedding_lookup(var, math_ops.cast([1,6], dtypes.int64)) init = variables.global_variables_initializer() - with self.test_session(force_gpu=True) as sess: + with self.test_session() as sess: sess.run(ops.get_collection(ops.GraphKeys.EV_INIT_VAR_OPS)) sess.run(ops.get_collection(ops.GraphKeys.EV_INIT_SLOT_OPS)) sess.run([init]) @@ -402,7 +402,7 @@ def testEmbeddingVariableForGeneralRandomInitializer(self): partitioner=partitioned_variables.fixed_size_partitioner(num_shards=4)) emb = embedding_ops.embedding_lookup(var, math_ops.cast([1,6], dtypes.int64)) init = variables.global_variables_initializer() - with self.test_session(force_gpu=True) as sess: + with self.test_session() as sess: sess.run(ops.get_collection(ops.GraphKeys.EV_INIT_VAR_OPS)) sess.run(ops.get_collection(ops.GraphKeys.EV_INIT_SLOT_OPS)) sess.run([init]) @@ -428,7 +428,7 @@ def testEVInitializerWithKeyFetch(self): var_emb = embedding_ops.embedding_lookup(var, math_ops.cast([0,1,2,3,4,5,6,7], dtypes.int64)) emb_emb = embedding_ops.embedding_lookup(emb_var, math_ops.cast([0,1,2,5,6,7,8,9,10], dtypes.int64)) init = variables.global_variables_initializer() - with self.test_session(graph=g, force_gpu=True) as sess: + with self.test_session(graph=g) as sess: sess.run(ops.get_collection(ops.GraphKeys.EV_INIT_VAR_OPS)) sess.run(ops.get_collection(ops.GraphKeys.EV_INIT_SLOT_OPS)) sess.run([init]) @@ -454,7 +454,7 @@ def runTest(self, var, g): g_v = opt.compute_gradients(loss) train_op = opt.apply_gradients(g_v) init = variables.global_variables_initializer() - with self.test_session(graph=g, force_gpu=True) as sess: + with self.test_session(graph=g) as sess: sess.run(ops.get_collection(ops.GraphKeys.EV_INIT_VAR_OPS)) sess.run(ops.get_collection(ops.GraphKeys.EV_INIT_SLOT_OPS)) sess.run([init]) diff --git a/third_party/cucollection.patch b/third_party/cucollection.patch index e5ca14fbcd0..fc3d06603b4 100644 --- a/third_party/cucollection.patch +++ b/third_party/cucollection.patch @@ -1,17 +1,17 @@ -From 874376f7e6b597bc288d3b945e706fd83a7033bf Mon Sep 17 00:00:00 2001 -From: Hongxiao Bai -Date: Thu, 20 Jan 2022 21:33:03 +0800 -Subject: [PATCH] cuco_modification_for_deeprec +From b47364f0bf2c1e630c600e4e2e09e54020bac7fa Mon Sep 17 00:00:00 2001 +From: Mesilenceki +Date: Tue, 18 Apr 2023 11:56:47 +0800 +Subject: [PATCH] cuco patch --- - include/cuco/detail/dynamic_map.inl | 47 +++++++- - include/cuco/detail/dynamic_map_kernels.cuh | 71 +++++++++++- - include/cuco/detail/pair.cuh | 14 +++ - include/cuco/detail/static_map.inl | 115 ++++++++++++++++---- - include/cuco/dynamic_map.cuh | 49 ++++++++- - include/cuco/static_map.cuh | 51 ++++++++- + include/cuco/detail/dynamic_map.inl | 47 ++++++- + include/cuco/detail/dynamic_map_kernels.cuh | 71 +++++++++- + include/cuco/detail/pair.cuh | 14 ++ + include/cuco/detail/static_map.inl | 138 ++++++++++++++++---- + include/cuco/dynamic_map.cuh | 49 
++++++- + include/cuco/static_map.cuh | 57 +++++++- include/cuco/traits.hpp | 1 + - 7 files changed, 317 insertions(+), 31 deletions(-) + 7 files changed, 340 insertions(+), 37 deletions(-) diff --git a/include/cuco/detail/dynamic_map.inl b/include/cuco/detail/dynamic_map.inl index 57950ea..78543c5 100644 @@ -190,10 +190,61 @@ index 0d8a85e..4aa8481 100644 namespace detail { diff --git a/include/cuco/detail/static_map.inl b/include/cuco/detail/static_map.inl -index 1719970..febc1fb 100644 +index 1719970..23482f8 100644 --- a/include/cuco/detail/static_map.inl +++ b/include/cuco/detail/static_map.inl -@@ -271,18 +271,18 @@ __device__ bool static_map::device_mutable_view::i +@@ -31,7 +31,10 @@ static_map::static_map(std::size_t capacity, + counter_allocator_{alloc} + { + slots_ = std::allocator_traits::allocate(slot_allocator_, capacity_); +- num_successes_ = std::allocator_traits::allocate(counter_allocator_, 1); ++ // num_successes_ = std::allocator_traits::allocate(counter_allocator_, 1); ++ CUCO_CUDA_TRY(cudaMallocManaged(&num_successes_, sizeof(atomic_ctr_type))); ++ // static_assert(sizeof(std::size_t) == sizeof(atomic_ctr_type)); ++ // CUCO_CUDA_TRY(cudaMemsetAsync(num_successes_, 0, sizeof(atomic_ctr_type), stream)); + + auto constexpr block_size = 256; + auto constexpr stride = 4; +@@ -45,7 +48,8 @@ template ::~static_map() + { + std::allocator_traits::deallocate(slot_allocator_, slots_, capacity_); +- std::allocator_traits::deallocate(counter_allocator_, num_successes_, 1); ++ // std::allocator_traits::deallocate(counter_allocator_, num_successes_, 1); ++ CUCO_ASSERT_CUDA_SUCCESS(cudaFree(num_successes_)); + } + + template +@@ -63,8 +67,12 @@ void static_map::insert( + auto view = get_device_mutable_view(); + + // TODO: memset an atomic variable is unsafe +- static_assert(sizeof(std::size_t) == sizeof(atomic_ctr_type)); +- CUCO_CUDA_TRY(cudaMemsetAsync(num_successes_, 0, sizeof(atomic_ctr_type), stream)); ++ // static_assert(sizeof(std::size_t) == sizeof(atomic_ctr_type)); ++ int device_id; ++ CUCO_CUDA_TRY(cudaGetDevice(&device_id)); ++ CUCO_CUDA_TRY(cudaMemPrefetchAsync(num_successes_, sizeof(atomic_ctr_type), device_id)); ++ ++ // CUCO_CUDA_TRY(cudaMemsetAsync(num_successes_, 0, sizeof(atomic_ctr_type), stream)); + std::size_t h_num_successes; + + detail::insert<<>>( +@@ -101,8 +109,11 @@ void static_map::insert_if(InputIt first, + auto view = get_device_mutable_view(); + + // TODO: memset an atomic variable is unsafe +- static_assert(sizeof(std::size_t) == sizeof(atomic_ctr_type)); +- CUCO_CUDA_TRY(cudaMemsetAsync(num_successes_, 0, sizeof(atomic_ctr_type), stream)); ++ // static_assert(sizeof(std::size_t) == sizeof(atomic_ctr_type)); ++ // CUCO_CUDA_TRY(cudaMemsetAsync(num_successes_, 0, sizeof(atomic_ctr_type), stream)); ++ int device_id; ++ CUCO_CUDA_TRY(cudaGetDevice(&device_id)); ++ CUCO_CUDA_TRY(cudaMemPrefetchAsync(num_successes_, sizeof(atomic_ctr_type), device_id)); + std::size_t h_num_successes; + + detail::insert_if_n<<>>( +@@ -271,18 +282,18 @@ __device__ bool static_map::device_mutable_view::i if (slot_is_empty) { auto const status = [&]() { @@ -222,7 +273,7 @@ index 1719970..febc1fb 100644 }(); // successful insert -@@ -325,18 +325,18 @@ __device__ bool static_map::device_mutable_view::i +@@ -325,18 +336,18 @@ __device__ bool static_map::device_mutable_view::i uint32_t src_lane = __ffs(window_contains_empty) - 1; if (g.thread_rank() == src_lane) { @@ -252,7 +303,7 @@ index 1719970..febc1fb 100644 } uint32_t res_status = g.shfl(static_cast(status), src_lane); -@@ 
-358,6 +358,43 @@ __device__ bool static_map::device_mutable_view::i +@@ -358,6 +369,43 @@ __device__ bool static_map::device_mutable_view::i } } @@ -296,7 +347,7 @@ index 1719970..febc1fb 100644 template template __device__ typename static_map::device_view::iterator -@@ -482,6 +519,42 @@ static_map::device_view::find(CG g, +@@ -482,6 +530,42 @@ static_map::device_view::find(CG g, } } @@ -340,7 +391,7 @@ index 1719970..febc1fb 100644 template __device__ bool static_map::device_view::contains( diff --git a/include/cuco/dynamic_map.cuh b/include/cuco/dynamic_map.cuh -index 2e57ac6..64f4d3f 100644 +index 2e57ac6..b85759d 100644 --- a/include/cuco/dynamic_map.cuh +++ b/include/cuco/dynamic_map.cuh @@ -96,8 +96,8 @@ class dynamic_map { @@ -420,7 +471,7 @@ index 2e57ac6..64f4d3f 100644 thrust::device_vector submap_views_; ///< vector of device views for each submap thrust::device_vector diff --git a/include/cuco/static_map.cuh b/include/cuco/static_map.cuh -index 321b1f3..cc7601b 100644 +index 321b1f3..fa810e4 100644 --- a/include/cuco/static_map.cuh +++ b/include/cuco/static_map.cuh @@ -123,10 +123,10 @@ class static_map { @@ -502,6 +553,19 @@ index 321b1f3..cc7601b 100644 /** * @brief Indicates whether the key `k` was inserted into the map. * +@@ -1053,6 +1096,12 @@ class static_map { + * @return The number of elements in the map + */ + std::size_t get_size() const noexcept { return size_; } ++ ++ void update_size(std::size_t n) noexcept { size_ += n; } ++ ++ atomic_ctr_type* get_num_success() noexcept { ++ return num_successes_; ++ } + + /** + * @brief Gets the load factor of the hash map. diff --git a/include/cuco/traits.hpp b/include/cuco/traits.hpp index 445a40d..07fe954 100644 --- a/include/cuco/traits.hpp @@ -515,5 +579,4 @@ index 445a40d..07fe954 100644 namespace cuco { -- -2.33.0 - +2.37.1 (Apple Git-137.1) \ No newline at end of file