[Embedding] Initialize the layout of ValuePtr in InitializeOp and ImportOp. (#405)
lixy9474 authored Aug 29, 2022
1 parent ba3da62 commit cec4417
Showing 4 changed files with 32 additions and 17 deletions.
31 changes: 28 additions & 3 deletions tensorflow/core/kernels/kv_variable_ops.cc
@@ -138,7 +138,6 @@ class InitializeKvVariableOp : public OpKernel {
&false_positive_probability_));
OP_REQUIRES_OK(c, c->GetAttr("l2_weight_threshold",
&l2_weight_threshold_));
OP_REQUIRES_OK(c, c->GetAttr("layout", &layout_));
OP_REQUIRES_OK(c, c->GetAttr("default_value_dim", &default_value_dim_));
OP_REQUIRES_OK(c, c->GetAttr("slot_num", &slot_num_));
OP_REQUIRES_OK(c, c->GetAttr("record_freq", &record_freq_));
@@ -156,6 +155,20 @@ class InitializeKvVariableOp : public OpKernel {
filter_freq_ = 0;
}

if ((filter_freq_ != 0 && max_element_size_ == 0)
|| steps_to_live_ != 0 || record_freq_
|| record_version_ || storage_type > 5) {
if (block_num_ > 1 || (filter_freq_ != 0 && storage_type <= 5)) {
layout_ = "normal";
} else {
layout_ = "normal_contiguous";
}
} else {
layout_ = "light";
}

CHECK(block_num_ == 1 || layout_ != "normal_contiguous");

if (steps_to_live_ == kEmbeddingVarUseDB ||
steps_to_live_ == kInitializableEmbeddingVarUseDB) {
LOG(INFO) << "hashmap use db";
@@ -199,7 +212,6 @@ class InitializeKvVariableOp : public OpKernel {
std::string opname = handle_self.name();

EmbeddingVar<TKey, TValue>* ev = nullptr;
CHECK(block_num_ == 1 || layout_ != "normal_contiguous");

if (handle_self.name() == handle_primary.name() &&
handle_self.container() == handle_primary.container()) {
@@ -642,7 +654,6 @@ class KvResourceImportV2Op: public AsyncOpKernel {
&false_positive_probability_));
OP_REQUIRES_OK(c, c->GetAttr("l2_weight_threshold",
&l2_weight_threshold_));
OP_REQUIRES_OK(c, c->GetAttr("layout", &layout_));
OP_REQUIRES_OK(c, c->GetAttr("max_freq", &max_freq_));
OP_REQUIRES_OK(c, c->GetAttr("default_value_dim",
&default_value_dim_));
@@ -656,6 +667,20 @@ class KvResourceImportV2Op: public AsyncOpKernel {
OP_REQUIRES_OK(c, c->GetAttr("record_freq", &record_freq_));
OP_REQUIRES_OK(c, c->GetAttr("record_version", &record_version_));

if ((filter_freq_ != 0 && max_element_size_ == 0)
|| steps_to_live_ != -1 || record_freq_
|| record_version_ || storage_type > 5) {
if (block_num_ > 1 || (filter_freq_ != 0 && storage_type <= 5)) {
layout_ = "normal";
} else {
layout_ = "normal_contiguous";
}
} else {
layout_ = "light";
}

CHECK(block_num_ == 1 || layout_ != "normal_contiguous");

TF_CHECK_OK(ReadBoolFromEnvVar("TF_ENABLE_EV_ASYNC_RESTORE", true,
&ev_async_restore_));
}
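For reference, the layout selection that the two hunks above add to InitializeKvVariableOp and KvResourceImportV2Op can be read as one standalone rule. The sketch below is illustrative only: ChooseLayout is a hypothetical helper (not a function in this commit), steps_to_live_set folds together the two kernels' slightly different sentinels (!= 0 vs. != -1), and treating storage_type > 5 as the multi-level-storage test is an assumption taken directly from the diff.

#include <cstdint>
#include <iostream>
#include <string>

// Hypothetical helper mirroring the constructor logic added above.
// "light"             -> no per-key metadata needed
// "normal"            -> pointer-based layout (multi-block, or counter filter
//                        on single-level storage)
// "normal_contiguous" -> metadata needed, values stored contiguously
std::string ChooseLayout(int64_t filter_freq, int64_t max_element_size,
                         bool steps_to_live_set, bool record_freq,
                         bool record_version, int storage_type, int block_num) {
  const bool needs_metadata =
      (filter_freq != 0 && max_element_size == 0) || steps_to_live_set ||
      record_freq || record_version || storage_type > 5;
  if (!needs_metadata) return "light";
  if (block_num > 1 || (filter_freq != 0 && storage_type <= 5)) return "normal";
  return "normal_contiguous";
}

int main() {
  // Counter filter on single-level storage, single block -> "normal".
  std::cout << ChooseLayout(/*filter_freq=*/5, /*max_element_size=*/0,
                            /*steps_to_live_set=*/false, /*record_freq=*/false,
                            /*record_version=*/false, /*storage_type=*/1,
                            /*block_num=*/1)
            << "\n";
  return 0;
}

Both kernels then CHECK that a multi-block variable never ends up with the "normal_contiguous" layout, which the rule above already guarantees; the Python side (next file) simply stops computing a layout and passes an empty string.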
14 changes: 1 addition & 13 deletions tensorflow/python/ops/kv_variable_ops.py
@@ -293,18 +293,6 @@ def _init_from_args(self,
self._storage_path = evconfig.storage_path
self._storage_size = evconfig.storage_size
self._default_value_dim = evconfig.default_value_dim
if (isinstance(evconfig.filter_strategy, variables.CounterFilter) and self._filter_freq != 0) or \
self._steps_to_live not in [0, None] or self._record_version or \
self._storage_type in multi_level_list or self._record_freq:
if self._block_num not in [1, None] and self._storage_type in multi_level_list:
raise ValueError("Dynamic-dimension Embedding and Multi-level EV can't be enabled together")
if self._block_num not in [1, None] or \
(self._filter_freq != 0 and self._storage_type not in multi_level_list):
self._layout = "normal"
else:
self._layout = "normal_contiguous"
else:
self._layout = "light"

if self._primary is None:
self._is_primary = True
@@ -409,7 +397,7 @@ def _init_from_args(self,
false_positive_probability = self._false_positive_probability,
counter_type = self._counter_type,
max_freq = 99999,
layout = self._layout,
layout = "",
storage_type = self._storage_type,
storage_path = self._storage_path,
storage_size = self._storage_size,
2 changes: 1 addition & 1 deletion tensorflow/python/training/saving/saveable_object_util.py
@@ -208,7 +208,7 @@ def restore(self, restored_tensors, unused_restored_shapes):
max_element_size = self.var._max_element_size,
false_positive_probability = self.var._false_positive_probability,
counter_type = self.var._counter_type,
layout = self.var._layout,
layout = "",
storage_type=self.var._storage_type,
storage_path=self.var._storage_path,
storage_size=self.var._storage_size,
2 changes: 2 additions & 0 deletions tensorflow/python/training/slot_creator.py
@@ -123,6 +123,8 @@ def _create_slot_var(primary, val, scope, validate_shape, shape, dtype, slot_con
primary=primary._primary,
slot_num=slot_config.slot_num,
storage_type=primary.storage_type,
storage_path=primary._storage_path,
storage_size=primary._storage_size,
l2_weight_threshold=primary._l2_weight_threshold,
filter_strategy=filter_strategy)
)
