Skip to content

Commit

Permalink
[Embedding] Support different layout of ValuePtr in EmbeddingVariable.
Browse files Browse the repository at this point in the history
  • Loading branch information
liutongxuan committed Sep 25, 2021
1 parent 5ddb7c1 commit bccc9ca
Show file tree
Hide file tree
Showing 7 changed files with 471 additions and 200 deletions.
186 changes: 106 additions & 80 deletions tensorflow/core/framework/embedding/embedding_var.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,20 +53,34 @@ struct EmbeddingConfig {
int64 slot_num;
std::string name;
int64 steps_to_live;
int64 max_freq;
int64 filter_freq;
int64 max_freq;
float l2_weight_threshold;
LayoutType layout_type;

EmbeddingConfig(int64 emb_index = 0, int64 primary_emb_index = 0,
int64 block_num = 1, int slot_num = 1,
const std::string& name = "", int64 steps_to_live = 0,
int64 filter_freq = 0, int64 max_freq = 999999, float l2_weight_threshold = -1.0):
emb_index(emb_index), primary_emb_index(primary_emb_index),
block_num(block_num), slot_num(slot_num),
name(name), steps_to_live(steps_to_live),
int64 filter_freq = 0, int64 max_freq = 999999,
float l2_weight_threshold = -1.0, const std::string& layout = "normal"):
emb_index(emb_index),
primary_emb_index(primary_emb_index),
block_num(block_num),
slot_num(slot_num),
name(name),
steps_to_live(steps_to_live),
filter_freq(filter_freq),
max_freq(max_freq),
l2_weight_threshold(l2_weight_threshold) {}
l2_weight_threshold(l2_weight_threshold) {
if ("normal" == layout) {
layout_type = LayoutType::NORMAL;
} else if ("light" == layout) {
layout_type = LayoutType::LIGHT;
} else {
LOG(WARNING) << "Unknown layout: " << layout << ", use LayoutType::NORMAL by default.";
layout_type = LayoutType::NORMAL;
}
}

bool is_primary() const {
return emb_index == primary_emb_index;
Expand All @@ -80,12 +94,21 @@ struct EmbeddingConfig {
return filter_freq;
}

// Returns the layout type stored in this config (set by the constructor from
// the `layout` string). Declared const so it can also be called on const
// EmbeddingConfig objects, matching the struct's other read-only accessors.
LayoutType get_layout_type() const {
  return layout_type;
}

std::string DebugString() const {
return strings::StrCat("opname: ", name,
" emb_index: ", emb_index,
" primary_emb_index: ", primary_emb_index,
" block_num: ", block_num,
" slot_num: ", slot_num);
" slot_num: ", slot_num,
" layout_type: ", static_cast<int>(layout_type),
" steps_to_live: ", steps_to_live,
" filter_freq: ", filter_freq,
" max_freq: ", max_freq,
" l2_weight_threshold: ", l2_weight_threshold);
}
};

Expand All @@ -101,7 +124,15 @@ class EmbeddingVar : public ResourceBase {
default_value_(nullptr),
value_len_(0),
alloc_(alloc),
emb_config_(emb_cfg) {}
emb_config_(emb_cfg) {
if (LayoutType::LIGHT == emb_config_.get_layout_type()) {
new_value_ptr_fn = [] (size_t size) { return new LightValuePtr<V>(size); };
} else if (LayoutType::NORMAL == emb_config_.get_layout_type()) {
new_value_ptr_fn = [] (size_t size) { return new NormalValuePtr<V>(size); };
} else {
LOG(FATAL) << name_ << ", Unsupport EmbeddingVariable LayoutType.";
}
}

Status Init(const Tensor& default_tensor) {
if (default_tensor.dims() != 1) {
Expand All @@ -126,27 +157,6 @@ class EmbeddingVar : public ResourceBase {
return is_initialized_;
}

// Looks `key` up in kv_; if the lookup fails, allocates a ValuePtr<V> of
// `size`, and tries to insert it. If the insert fails as well, the freshly
// allocated object is deleted and the lookup is repeated.
// Returns the Status of the last kv_ call that was made.
Status LookupOrCreateKeyInternal(K key, ValuePtr<V>** value_ptr, size_t size) {
  Status status = kv_->Lookup(key, value_ptr);
  if (status.ok()) {
    // Found: nothing to create.
    return status;
  }

  // Not found: create a candidate entry and try to publish it.
  *value_ptr = new ValuePtr<V>(size);
  status = kv_->Insert(key, *value_ptr);
  if (!status.ok()) {
    // Insert failed (key already exists): discard our candidate and fetch
    // whatever is stored under the key instead.
    delete *value_ptr;
    status = kv_->Lookup(key, value_ptr);
  }
  return status;
}

Status LookupOrCreateKey(K key, ValuePtr<V>** value_ptr, int64 update_version = -1) {
Status s = LookupOrCreateKeyInternal(key, value_ptr, emb_config_.total_num());
TF_CHECK_OK(s);
Expand Down Expand Up @@ -207,20 +217,12 @@ class EmbeddingVar : public ResourceBase {
return emb_config_.filter_freq;
}

// Overwrites emb_config_.filter_freq with `min_freq`.
void SetMinFreq(int64 min_freq) { emb_config_.filter_freq = min_freq; }

// Overwrites emb_config_.l2_weight_threshold with `l2_weight_threshold`.
void SetL2WeightThreshold(float l2_weight_threshold) {
  emb_config_.l2_weight_threshold = l2_weight_threshold;
}

// Returns the current value of emb_config_.l2_weight_threshold.
float GetL2WeightThreshold() { return emb_config_.l2_weight_threshold; }

std::string DebugString() const {
return kv_->DebugString();
return emb_config_.DebugString();
}

Status Import(RestoreBuffer& restore_buff,
Expand All @@ -239,12 +241,17 @@ class EmbeddingVar : public ResourceBase {
continue;
}
ValuePtr<V>* value_ptr = nullptr;
TF_CHECK_OK(LookupOrCreateKey(key_buff[i], &value_ptr, version_buff[i]));
if (emb_config_.is_primary()){
if (freq_buff[i] <= emb_config_.filter_freq) {
value_ptr->SetFreq(emb_config_.filter_freq);
}else {
value_ptr->SetFreq(freq_buff[i]);
TF_CHECK_OK(LookupOrCreateKey(key_buff[i], &value_ptr));
if (emb_config_.is_primary()) {
if (emb_config_.filter_freq != 0) {
if (freq_buff[i] <= emb_config_.filter_freq) {
value_ptr->SetFreq(emb_config_.filter_freq);
} else {
value_ptr->SetFreq(freq_buff[i]);
}
}
if (emb_config_.steps_to_live != 0) {
value_ptr->SetStep(version_buff[i]);
}
}
LookupOrCreateEmb(value_ptr, emb_config_, value_buff + i * value_len_);
Expand All @@ -265,43 +272,36 @@ class EmbeddingVar : public ResourceBase {
key_list->push_back(key_list_tmp[i]);
if (emb_config_.filter_freq != 0) {
int64 dump_freq = value_ptr_list[i]->GetFreq();
freq_list->push_back(dump_freq);
} else {
freq_list->push_back(0);
freq_list->push_back(dump_freq);
}
if (emb_config_.steps_to_live != 0) {
int64 dump_version = value_ptr_list[i]->GetStep();
version_list->push_back(dump_version);
} else {
version_list->push_back(0);
}
}
}
return key_list->size();
}

Status Shrink(int64 gs) {
if (emb_config_.steps_to_live > 0) {
std::vector<K> key_list;
std::vector<ValuePtr<V>* > value_ptr_list;
kv_->GetSnapshot(&key_list, &value_ptr_list);
std::vector<std::pair<K, ValuePtr<V>* > > to_deleted;
for (int64 i = 0; i < key_list.size(); ++i) {
int64 version = value_ptr_list[i]->GetStep();
if (gs - version > emb_config_.steps_to_live) {
to_deleted.push_back(std::pair<K, ValuePtr<V>*>(key_list[i], value_ptr_list[i]));
}
}
for (const auto it : to_deleted) {
// TODO memory recycle
(it.second)->Destroy(value_len_);
delete it.second;
kv_->Remove(it.first);
}
// Takes a snapshot of kv_ and, for every ValuePtr in it, calls
// Destroy(value_len) on the object and then deletes it.
// Always returns Status::OK().
// NOTE(review): the keys are not removed from kv_ here — presumably this is
// only called when the whole variable is being torn down; confirm at callers.
Status Destroy(int64 value_len) {
  std::vector<K> keys;
  std::vector<ValuePtr<V>*> value_ptrs;
  kv_->GetSnapshot(&keys, &value_ptrs);

  for (size_t idx = 0; idx < value_ptrs.size(); ++idx) {
    ValuePtr<V>* entry = value_ptrs[idx];
    entry->Destroy(value_len);
    delete entry;
  }
  return Status::OK();
}

// Returns the address of this object's mu_ member.
mutex* mu() { return &mu_; }

// Returns the KVInterface pointer held in kv_.
KVInterface<K, V>* kv() { return kv_; }

Status Shrink() {
std::vector<K> key_list;
std::vector<ValuePtr<V>* > value_ptr_list;
Expand All @@ -327,29 +327,55 @@ class EmbeddingVar : public ResourceBase {
return Status::OK();
}

Status Destroy(int64 value_len) {
std::vector<K> key_list;
std::vector<ValuePtr<V>* > value_ptr_list;
kv_->GetSnapshot(&key_list, &value_ptr_list);
for (auto value_ptr : value_ptr_list) {
value_ptr->Destroy(value_len);
delete value_ptr;
// Evicts entries based on their recorded step.
//
// When emb_config_.steps_to_live is positive, takes a snapshot of kv_ and
// removes every entry for which `gs - GetStep()` exceeds steps_to_live:
// the ValuePtr is destroyed with value_len_, deleted, and its key is removed
// from kv_. When steps_to_live is not positive this is a no-op.
//
// gs: the step value that each entry's GetStep() is compared against.
// Always returns Status::OK().
Status Shrink(int64 gs) {
  if (emb_config_.steps_to_live <= 0) {
    return Status::OK();
  }

  std::vector<K> key_list;
  std::vector<ValuePtr<V>*> value_ptr_list;
  kv_->GetSnapshot(&key_list, &value_ptr_list);

  // Collect victims first so that kv_ is not modified while the snapshot is
  // being scanned.
  std::vector<std::pair<K, ValuePtr<V>*>> to_deleted;
  // size_t index: avoids the signed/unsigned comparison against size().
  for (size_t i = 0; i < key_list.size(); ++i) {
    const int64 version = value_ptr_list[i]->GetStep();
    if (gs - version > emb_config_.steps_to_live) {
      to_deleted.emplace_back(key_list[i], value_ptr_list[i]);
    }
  }

  // Bind by const reference so each pair is not copied per iteration.
  for (const auto& it : to_deleted) {
    // TODO memory recycle
    (it.second)->Destroy(value_len_);
    delete it.second;
    kv_->Remove(it.first);
  }
  return Status::OK();
}

// Returns the address of this object's mu_ member.
mutex* mu() { return &mu_; }

KVInterface<K, V>* kv() {
return kv_;
private:
// Looks `key` up in kv_; if the lookup fails, obtains a new ValuePtr<V> of
// `size` from new_value_ptr_fn and tries to insert it. If the insert fails
// as well, the freshly created object is deleted and the lookup is repeated.
// Returns the Status of the last kv_ call that was made.
Status LookupOrCreateKeyInternal(K key, ValuePtr<V>** value_ptr, size_t size) {
  Status status = kv_->Lookup(key, value_ptr);
  if (status.ok()) {
    // Found: nothing to create.
    return status;
  }

  // Not found: build a candidate through the layout-specific factory and try
  // to publish it.
  *value_ptr = new_value_ptr_fn(size);
  status = kv_->Insert(key, *value_ptr);
  if (!status.ok()) {
    // Insert failed (key already exists): discard our candidate and fetch
    // whatever is stored under the key instead.
    delete *value_ptr;
    status = kv_->Lookup(key, value_ptr);
  }
  return status;
}

private:
std::string name_;
KVInterface<K, V>* kv_;
bool is_initialized_ = false;
std::function<ValuePtr<V>*(size_t)> new_value_ptr_fn;

mutex mu_;

Expand Down
Loading

0 comments on commit bccc9ca

Please sign in to comment.