Skip to content

Commit

Permalink
[Embedding] Refactor the restore interface of EmbeddingVariable. (Dee…
Browse files Browse the repository at this point in the history
…pRec-AI#903)

support restore parameters from single or partitioned EmbeddingVariable

Signed-off-by: JunqiHu <silenceki@hotmail.com>
  • Loading branch information
Mesilenceki authored Jul 11, 2023
1 parent 96d66ab commit 56cc51e
Show file tree
Hide file tree
Showing 21 changed files with 1,754 additions and 1,712 deletions.
45 changes: 17 additions & 28 deletions tensorflow/core/framework/embedding/bloom_filter_policy.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,13 @@ const static std::vector<int64> default_seeds = {

template<typename K, typename V, typename EV>
class BloomFilterPolicy : public FilterPolicy<K, V, EV> {
using FilterPolicy<K, V, EV>::ev_;
using FilterPolicy<K, V, EV>::config_;

public:
BloomFilterPolicy(const EmbeddingConfig& config, EV* ev)
: config_(config), ev_(ev) {
BloomFilterPolicy(const EmbeddingConfig& config, EV* ev) :
FilterPolicy<K, V, EV>(config, ev) {

switch (config_.counter_type){
case DT_UINT64:
VLOG(2) << "The type of bloom counter is uint64";
Expand Down Expand Up @@ -303,16 +307,18 @@ class BloomFilterPolicy : public FilterPolicy<K, V, EV> {
}
}

Status Import(RestoreBuffer& restore_buff,
int64 key_num,
int bucket_num,
int64 partition_id,
int64 partition_num,
bool is_filter) override {
Status Restore(int64 key_num, int bucket_num, int64 partition_id,
int64 partition_num, int64 value_len, bool is_filter,
bool to_dram, bool is_incr, RestoreBuffer& restore_buff) override {
K* key_buff = (K*)restore_buff.key_buffer;
V* value_buff = (V*)restore_buff.value_buffer;
int64* version_buff = (int64*)restore_buff.version_buffer;
int64* freq_buff = (int64*)restore_buff.freq_buffer;
if (to_dram) {
LOG(FATAL)<<"BloomFilter dosen't support ImportToDRAM";
return Status::OK();
}

for (auto i = 0; i < key_num; ++i) {
// this can describe by graph(Mod + DynamicPartition),
// but memory waste and slow
Expand All @@ -333,33 +339,19 @@ class BloomFilterPolicy : public FilterPolicy<K, V, EV> {
SetBloomFreq(key_buff[i], freq_buff[i]);
}
if (new_freq >= config_.filter_freq){
ev_->CreateKey(key_buff[i], &value_ptr);
ev_->CreateKey(key_buff[i], &value_ptr, to_dram);
if (config_.steps_to_live != 0 || config_.record_version) {
value_ptr->SetStep(version_buff[i]);
}
if (!is_filter){
ev_->LookupOrCreateEmb(value_ptr,
value_buff + i * ev_->ValueLen());
value_buff + i * ev_->ValueLen());
} else {
ev_->LookupOrCreateEmb(value_ptr,
ev_->GetDefaultValue(key_buff[i]));
ev_->GetDefaultValue(key_buff[i]));
}
}
}
if (ev_->IsMultiLevel() && !ev_->IsUseHbm() && config_.is_primary()) {
ev_->UpdateCache(key_buff, key_num, version_buff, freq_buff);
}
return Status::OK();
}

Status ImportToDram(RestoreBuffer& restore_buff,
int64 key_num,
int bucket_num,
int64 partition_id,
int64 partition_num,
bool is_filter,
V* default_values) override {
LOG(FATAL)<<"BloomFilter dosen't support ImportToDRAM";
return Status::OK();
}

Expand Down Expand Up @@ -455,11 +447,8 @@ class BloomFilterPolicy : public FilterPolicy<K, V, EV> {
}
}
}

private:
void* bloom_counter_;
EmbeddingConfig config_;
EV* ev_;
std::vector<int64> seeds_;
};
} // tensorflow
Expand Down
88 changes: 14 additions & 74 deletions tensorflow/core/framework/embedding/counter_filter_policy.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,13 @@ namespace tensorflow {

template<typename K, typename V, typename EV>
class CounterFilterPolicy : public FilterPolicy<K, V, EV> {
using FilterPolicy<K, V, EV>::ev_;
using FilterPolicy<K, V, EV>::config_;
using FilterPolicy<K, V, EV>::LookupOrCreateEmbInternal;

public:
CounterFilterPolicy(const EmbeddingConfig& config, EV* ev)
: config_(config), ev_(ev){
}
CounterFilterPolicy(const EmbeddingConfig& config, EV* ev) :
FilterPolicy<K, V, EV>(config, ev) {}

Status Lookup(K key, V* val, const V* default_value_ptr,
const V* default_value_no_permission) override {
Expand Down Expand Up @@ -115,60 +118,13 @@ class CounterFilterPolicy : public FilterPolicy<K, V, EV> {
return value_ptr->GetFreq();
}

Status Import(RestoreBuffer& restore_buff,
int64 key_num,
int bucket_num,
int64 partition_id,
int64 partition_num,
bool is_filter) override {
K* key_buff = (K*)restore_buff.key_buffer;
V* value_buff = (V*)restore_buff.value_buffer;
int64* version_buff = (int64*)restore_buff.version_buffer;
int64* freq_buff = (int64*)restore_buff.freq_buffer;
for (auto i = 0; i < key_num; ++i) {
// this can describe by graph(Mod + DynamicPartition),
// but memory waste and slow
if (*(key_buff + i) % bucket_num % partition_num != partition_id) {
LOG(INFO) << "skip EV key:" << *(key_buff + i);
continue;
}
ValuePtr<V>* value_ptr = nullptr;
ev_->CreateKey(key_buff[i], &value_ptr);
if (!is_filter) {
if (freq_buff[i] >= config_.filter_freq) {
value_ptr->SetFreq(freq_buff[i]);
} else {
value_ptr->SetFreq(config_.filter_freq);
}
} else {
value_ptr->SetFreq(freq_buff[i]);
}
if (config_.steps_to_live != 0 || config_.record_version) {
value_ptr->SetStep(version_buff[i]);
}
if (value_ptr->GetFreq() >= config_.filter_freq) {
if (!is_filter) {
ev_->LookupOrCreateEmb(value_ptr,
value_buff + i * ev_->ValueLen());
} else {
ev_->LookupOrCreateEmb(value_ptr,
ev_->GetDefaultValue(key_buff[i]));
}
}
}
if (ev_->IsMultiLevel() && !ev_->IsUseHbm() && config_.is_primary()) {
ev_->UpdateCache(key_buff, key_num, version_buff, freq_buff);
}
return Status::OK();
bool is_admit(K key, ValuePtr<V>* value_ptr) override {
return (GetFreq(key, value_ptr) >= config_.filter_freq);
}

Status ImportToDram(RestoreBuffer& restore_buff,
int64 key_num,
int bucket_num,
int64 partition_id,
int64 partition_num,
bool is_filter,
V* default_values) override {
Status Restore(int64 key_num, int bucket_num, int64 partition_id,
int64 partition_num, int64 value_len, bool is_filter,
bool to_dram, bool is_incr, RestoreBuffer& restore_buff) override {
K* key_buff = (K*)restore_buff.key_buffer;
V* value_buff = (V*)restore_buff.value_buffer;
int64* version_buff = (int64*)restore_buff.version_buffer;
Expand All @@ -181,7 +137,7 @@ class CounterFilterPolicy : public FilterPolicy<K, V, EV> {
continue;
}
ValuePtr<V>* value_ptr = nullptr;
ev_->CreateKeyOnDram(key_buff[i], &value_ptr);
ev_->CreateKey(key_buff[i], &value_ptr, to_dram);
if (!is_filter) {
if (freq_buff[i] >= config_.filter_freq) {
value_ptr->SetFreq(freq_buff[i]);
Expand All @@ -195,28 +151,12 @@ class CounterFilterPolicy : public FilterPolicy<K, V, EV> {
value_ptr->SetStep(version_buff[i]);
}
if (value_ptr->GetFreq() >= config_.filter_freq) {
if (!is_filter) {
ev_->LookupOrCreateEmb(value_ptr,
value_buff + i * ev_->ValueLen(), ev_allocator());
} else {
ev_->LookupOrCreateEmb(value_ptr,
default_values +
(key_buff[i] % config_.default_value_dim)
* ev_->ValueLen(),
ev_allocator());
}
LookupOrCreateEmbInternal(is_filter, to_dram, i, value_len,
value_ptr, value_buff, key_buff);
}
}
return Status::OK();
}

bool is_admit(K key, ValuePtr<V>* value_ptr) override {
return (GetFreq(key, value_ptr) >= config_.filter_freq);
}

private:
EmbeddingConfig config_;
EV* ev_;
};

} // tensorflow
Expand Down
9 changes: 1 addition & 8 deletions tensorflow/core/framework/embedding/dram_leveldb_storage.h
Original file line number Diff line number Diff line change
Expand Up @@ -68,10 +68,9 @@ class DramLevelDBStore : public MultiTierStorage<K, V> {
}

void Insert(K key, ValuePtr<V>** value_ptr,
size_t alloc_len) override {
size_t alloc_len, bool to_dram = false) override {
dram_->Insert(key, value_ptr, alloc_len);
}

Status GetOrCreate(K key, ValuePtr<V>** value_ptr,
size_t size, CopyBackFlag &need_copyback) override {
LOG(FATAL)<<"GetOrCreate(K key, ValuePtr<V>** value_ptr, "
Expand Down Expand Up @@ -112,12 +111,6 @@ class DramLevelDBStore : public MultiTierStorage<K, V> {
return false;
}

bool IsUsePersistentStorage() override {
/*The return value is set to false temporarily,
because the corresponding interface is not implemented.*/
return false;
}

void iterator_mutex_lock() override {
leveldb_->get_mutex()->lock();
}
Expand Down
9 changes: 1 addition & 8 deletions tensorflow/core/framework/embedding/dram_pmem_storage.h
Original file line number Diff line number Diff line change
Expand Up @@ -76,10 +76,9 @@ class DramPmemStorage : public MultiTierStorage<K, V> {
}

void Insert(K key, ValuePtr<V>** value_ptr,
size_t alloc_len) override {
size_t alloc_len, bool to_dram = false) override {
dram_->Insert(key, value_ptr, alloc_len);
}

Status GetOrCreate(K key, ValuePtr<V>** value_ptr,
size_t size, CopyBackFlag &need_copyback) override {
LOG(FATAL)<<"GetOrCreate(K key, ValuePtr<V>** value_ptr, "
Expand All @@ -95,12 +94,6 @@ class DramPmemStorage : public MultiTierStorage<K, V> {
return false;
}

bool IsUsePersistentStorage() override {
/*The return value is set to false temporarily,
because the corresponding interface is not implemented.*/
return false;
}

Status GetOrCreate(K key, ValuePtr<V>** value_ptr,
size_t size) override {
Status s = dram_->Get(key, value_ptr);
Expand Down
36 changes: 18 additions & 18 deletions tensorflow/core/framework/embedding/dram_ssd_storage.h
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ class DramSsdHashStorage : public MultiTierStorage<K, V> {
}

void Insert(K key, ValuePtr<V>** value_ptr,
size_t alloc_len) override {
size_t alloc_len, bool to_dram = false) override {
dram_->Insert(key, value_ptr, alloc_len);
}

Expand Down Expand Up @@ -210,27 +210,27 @@ class DramSsdHashStorage : public MultiTierStorage<K, V> {
return key_list->size() + ssd_rec_desc->key_list.size();
}

void RestoreSsdHashmap(
K* key_list, int64* key_file_id_list,
int64* key_offset_list, int64 num_of_keys,
int64* file_list, int64* invalid_record_count_list,
int64* record_count_list, int64 num_of_files,
const std::string& ssd_emb_file_name) override {
Status RestoreSSD(int64 emb_index, int64 emb_slot_num, int64 value_len,
const std::string& ssd_emb_file_name, EmbeddingVar<K, V>* ev,
RestoreSSDBuffer<K>& restore_buff) override {
int64 alloc_len = Storage<K, V>::ComputeAllocLen(value_len);
std::map<int64, int64> file_id_map;
for (int64 i = 0; i < num_of_files; i++) {
file_id_map[file_list[i]] = i;
for (int64 i = 0; i < restore_buff.num_of_files; i++) {
file_id_map[restore_buff.file_list_buf[i]] = i;
}

ssd_hash_->CopyEmbFilesFromCkpt(
file_list, invalid_record_count_list,
record_count_list, num_of_files,
ssd_emb_file_name);

ssd_hash_->Import(key_list, key_file_id_list,
key_offset_list, num_of_keys,
file_id_map);
ssd_hash_->CopyEmbFilesFromCkpt(restore_buff.file_list_buf,
restore_buff.invalid_record_count_list_buf,
restore_buff.record_count_list_buf,
restore_buff.num_of_files,
ssd_emb_file_name);

ssd_hash_->Import(restore_buff.key_list_buf,
restore_buff.key_file_id_list_buf,
restore_buff.key_offset_list_buf,
restore_buff.num_of_keys,
file_id_map);
}

Status Eviction(K* evict_ids, int64 evict_size) override {
ValuePtr<V>* value_ptr = nullptr;
for (int64 i = 0; i < evict_size; ++i) {
Expand Down
Loading

0 comments on commit 56cc51e

Please sign in to comment.