Skip to content

Commit

Permalink
[Embedding] Enable EmbeddingVariable HBM unit tests. (DeepRec-AI#430)
Browse files Browse the repository at this point in the history
  • Loading branch information
candyzone authored Sep 9, 2022
1 parent 712419e commit 0f63079
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 6 deletions.
5 changes: 4 additions & 1 deletion tensorflow/core/framework/embedding/embedding_var.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ class EmbeddingVar : public ResourceBase {
name_(name),
storage_manager_(storage_manager),
default_value_(nullptr),
default_value_no_permission_(nullptr),
value_len_(0),
alloc_(alloc),
emb_config_(emb_cfg) {
Expand Down Expand Up @@ -475,7 +476,9 @@ class EmbeddingVar : public ResourceBase {
}
}
TypedAllocator::Deallocate(alloc_, default_value_, value_len_);
TypedAllocator::Deallocate(alloc_, default_value_no_permission_, value_len_);
if (default_value_no_permission_) {
TypedAllocator::Deallocate(alloc_, default_value_no_permission_, value_len_);
}
}
TF_DISALLOW_COPY_AND_ASSIGN(EmbeddingVar);
};
Expand Down
5 changes: 0 additions & 5 deletions tensorflow/python/ops/embedding_variable_ops_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2265,7 +2265,6 @@ def testEmbeddingVariableForInference(self):
self.assertAllEqual(np.array([0,3]), s)
del os.environ["INFERENCE_MODE"]

'''
@test_util.run_gpu_only
def testEmbeddingVariableForHBMandDRAM(self):
print("testEmbeddingVariableForHBMandDRAM")
Expand Down Expand Up @@ -2299,19 +2298,15 @@ def runTestAdagrad(self, var, g):
embedding_dim = 128,
initializer=init_ops.ones_initializer(dtypes.float32),
partitioner=partitioned_variables.fixed_size_partitioner(num_shards=1),
#steps_to_live=5,
ev_option = variables.EmbeddingVariableOption(storage_option=variables.StorageOption(storage_type=config_pb2.StorageType.HBM_DRAM)))
var = variable_scope.get_variable("var_2", shape=[1024, 128], initializer=init_ops.ones_initializer(dtypes.float32))

emb1 = runTestAdagrad(self, emb_var, g)
emb2 = runTestAdagrad(self, var, g)
print(emb1)
print(emb2)

for i in range(0, 6):
for j in range(0, 3):
self.assertEqual(emb1[i][j], emb2[i][j])
'''

if __name__ == "__main__":
googletest.main()

0 comments on commit 0f63079

Please sign in to comment.