[Embedding] Add GPU fused_input_from_feature_column and fuse_embedding_lookup_sparse. (DeepRec-AI#65)
nvzhou authored Dec 4, 2021
1 parent b52e668 commit 6204d65
Showing 13 changed files with 2,102 additions and 1 deletion.
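Among these files, the commit adds a Python wrapper, input_from_feature_columns_fused, whose embedding lookups go through the new fused_embedding_lookup_sparse op. A minimal, hedged sketch of calling the fused lookup directly; the ids, shapes, and variable names below are illustrative assumptions, not taken from the commit:

# Hedged sketch only: the feature values, shapes, and variable names are
# illustrative; the call shape mirrors the use in feature_column_fused_ops.py.
# The fused kernels are built as GPU sources, so this assumes a CUDA build.
import tensorflow as tf
from tensorflow.python.ops.fused_embedding_ops import fused_embedding_lookup_sparse

# Sparse ids for a batch of two examples (int64 values, as produced by the
# feature-column transform).
sp_ids = tf.SparseTensor(
    indices=[[0, 0], [0, 1], [1, 0]],
    values=tf.constant([3, 7, 5], dtype=tf.int64),
    dense_shape=[2, 2])

# A dense embedding table, passed as a single-element list to match how
# feature_column_fused_ops.py calls the op.
embedding_weights = tf.get_variable("embedding_weights", shape=[10, 4])

# Pools the looked-up rows per example; combiner presumably follows the usual
# "mean" / "sum" / "sqrtn" semantics of embedding_lookup_sparse.
pooled = fused_embedding_lookup_sparse(
    [embedding_weights], sp_ids, combiner="mean")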
2 changes: 2 additions & 0 deletions tensorflow/contrib/layers/BUILD
@@ -54,6 +54,7 @@ tf_custom_op_py_library(
        "python/layers/encoders.py",
        "python/layers/feature_column.py",
        "python/layers/feature_column_ops.py",
        "python/layers/feature_column_fused_ops.py",
        "python/layers/initializers.py",
        "python/layers/layers.py",
        "python/layers/normalization.py",
@@ -100,6 +101,7 @@ tf_custom_op_py_library(
        "//tensorflow/python:embedding_ops",
        "//tensorflow/python:framework",
        "//tensorflow/python:framework_for_generated_wrappers",
        "//tensorflow/python:fused_embedding_ops",
        "//tensorflow/python:init_ops",
        "//tensorflow/python:layers",
        "//tensorflow/python:layers_base",
1 change: 0 additions & 1 deletion tensorflow/contrib/layers/python/layers/feature_column.py
@@ -1433,7 +1433,6 @@ def _embeddings_from_arguments(column,
      name=column.name + "weights",
      max_norm=args.max_norm)


def _maybe_restore_from_checkpoint(checkpoint_path, variable):
  if checkpoint_path is not None:
    path, tensor_name = checkpoint_path
139 changes: 139 additions & 0 deletions tensorflow/contrib/layers/python/layers/feature_column_fused_ops.py
@@ -0,0 +1,139 @@
from tensorflow.python.framework import ops
from tensorflow.python.ops import partitioned_variables
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import array_ops
from tensorflow.python.ops.fused_embedding_ops import fused_embedding_lookup_sparse
from tensorflow.python.framework import dtypes
from tensorflow.contrib.framework.python.ops import variables as contrib_variables
from tensorflow.contrib.layers.python.layers.feature_column_ops import check_feature_columns
from tensorflow.contrib.layers.python.layers.feature_column_ops import _Transformer
from tensorflow.contrib.layers.python.layers import feature_column as fc


def input_from_feature_columns_fused(columns_to_tensors,
                                     feature_columns,
                                     trainable=True,
                                     scope=None,
                                     cols_to_outs=None):
  """Implementation of `input_from(_sequence)_feature_columns`."""
  columns_to_tensors = columns_to_tensors.copy()
  check_feature_columns(feature_columns)
  if cols_to_outs is not None and not isinstance(cols_to_outs, dict):
    raise ValueError('cols_to_outs must be a dict unless None')
  with variable_scope.variable_scope(scope,
                                     default_name="input_from_feature_columns_fused",
                                     values=columns_to_tensors.values()):
    output_tensors = []
    transformer = _Transformer(columns_to_tensors)

    for column in sorted(set(feature_columns), key=lambda x: x.key):
      with variable_scope.variable_scope(None,
                                         default_name=column.name,
                                         values=columns_to_tensors.values()):
        transformed_tensor = transformer.transform(column)
        # pylint: disable=protected-access
        args = column._deep_embedding_lookup_arguments(
            transformed_tensor)
        output = embeddings_from_arguments_fused(
            column, args, trainable)
        output_tensors.append(output)
      if cols_to_outs is not None:
        cols_to_outs[column] = output_tensors[-1]
    return array_ops.concat(output_tensors, 1)


def embeddings_from_arguments_fused(column,
                                    args,
                                    trainable):
  # This option is only enabled for scattered_embedding_column.
  if args.hash_key:
    raise NotImplementedError("not implemented yet for hash_key")

  graph = ops.get_default_graph()
  partition_num = args.embedding_var_part_num
  if partition_num is None:
    partitioner = None
  else:
    partitioner = partitioned_variables.fixed_size_partitioner(partition_num)

  # 1. get the embedding_weights
  if args.shared_embedding_name is not None:
    shared_embedding_collection_name = ("SHARED_EMBEDDING_COLLECTION_" +
                                        args.shared_embedding_name.upper())
    shared_embedding_collection = (
        graph.get_collection_ref(shared_embedding_collection_name))
    shape = [args.vocab_size, args.dimension]
    if shared_embedding_collection:
      if len(shared_embedding_collection) > 1:
        raise ValueError("Collection %s can only contain one "
                         "(partitioned) variable." %
                         shared_embedding_collection_name)
      else:
        embeddings = shared_embedding_collection[0]
        if (not args.use_embedding_var and embeddings.get_shape() != shape):
          raise ValueError("The embedding variable with name {} already "
                           "exists, but its shape does not match required "
                           "embedding shape here. Please make sure to use "
                           "different shared_embedding_name for different "
                           "shared embeddings.".format(args.shared_embedding_name))
    else:
      if args.use_embedding_var:
        embeddings = variable_scope.get_embedding_variable_internal(
            name=args.shared_embedding_name,
            embedding_dim=args.dimension,
            key_dtype=dtypes.int64,
            initializer=args.initializer,
            trainable=(trainable and args.trainable),
            collections=None,
            partitioner=partitioner,
            steps_to_live=args.steps_to_live,
            init_data_source=args.init_data_source,
            ht_partition_num=args.ht_partition_num,
            evconfig=args.evconfig)
        graph.add_to_collection(
            ops.GraphKeys.EMBEDDING_VARIABLES, embeddings)
      else:
        embeddings = contrib_variables.model_variable(
            name=args.shared_embedding_name,
            shape=shape,
            dtype=dtypes.float32,
            initializer=args.initializer,
            trainable=(trainable and args.trainable),
            collections=None)
      graph.add_to_collection(
          shared_embedding_collection_name, embeddings)
  else:
    if args.use_embedding_var:
      embeddings = variable_scope.get_embedding_variable_internal(
          name="weights",
          embedding_dim=args.dimension,
          key_dtype=dtypes.int64,
          initializer=args.initializer,
          trainable=(trainable and args.trainable),
          collections=None,
          partitioner=partitioner,
          steps_to_live=args.steps_to_live,
          init_data_source=args.init_data_source,
          ht_partition_num=args.ht_partition_num,
          evconfig=args.evconfig)
      graph.add_to_collection(
          ops.GraphKeys.EMBEDDING_VARIABLES, embeddings)
    else:
      embeddings = contrib_variables.model_variable(
          name="weights",
          shape=[args.vocab_size, args.dimension],
          dtype=dtypes.float32,
          initializer=args.initializer,
          trainable=(trainable and args.trainable),
          collections=None)

  if fc._is_variable(embeddings):
    embeddings = [embeddings]
  else:
    embeddings = embeddings._get_variable_list()  # pylint: disable=protected-access
  # pylint: disable=protected-access
  fc._maybe_restore_from_checkpoint(column._checkpoint_path(), embeddings)

  # 2. look up
  return fused_embedding_lookup_sparse(embeddings, args.input_tensor,
                                       combiner=args.combiner, max_norm=args.max_norm)
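For orientation, a minimal, hedged sketch of how this wrapper might be driven, mirroring the contract of layers.input_from_feature_columns; the feature column and tensor names below are illustrative and not part of the commit:

# Illustrative sketch only: assumes a hashed sparse column with an
# embedding_column on top; the fused path replaces the per-column
# embedding lookups with the new GPU op.
import tensorflow as tf
from tensorflow.contrib import layers
from tensorflow.contrib.layers.python.layers import feature_column_fused_ops

item_ids = layers.sparse_column_with_hash_bucket("item_id", hash_bucket_size=1000)
item_embedding = layers.embedding_column(item_ids, dimension=8, combiner="mean")

features = {
    "item_id": tf.SparseTensor(indices=[[0, 0], [1, 0], [1, 1]],
                               values=["a", "b", "c"],
                               dense_shape=[2, 2]),
}

# Same interface as layers.input_from_feature_columns, but the embedding
# lookups go through the fused op added by this commit.
net = feature_column_fused_ops.input_from_feature_columns_fused(
    columns_to_tensors=features, feature_columns=[item_embedding])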
3 changes: 3 additions & 0 deletions tensorflow/core/BUILD
@@ -1171,6 +1171,7 @@ tf_gen_op_libs(
        "feature_column_ops",
        "function_ops",
        "functional_ops",
        "fused_embedding_ops",
        "hash_ops",
        "hash_training_ops",
        "fuserecv_ops",
@@ -1427,6 +1428,7 @@ cc_library(
        ":feature_column_ops_op_lib",
        ":function_ops_op_lib",
        ":functional_ops_op_lib",
        ":fused_embedding_ops_op_lib",
        ":fuserecv_ops_op_lib",
        ":hash_ops_op_lib",
        ":hash_training_ops_op_lib",
@@ -1609,6 +1611,7 @@ cc_library(
        "//tensorflow/core/kernels:feature_column_ops",
        "//tensorflow/core/kernels:function_ops",
        "//tensorflow/core/kernels:functional_ops",
        "//tensorflow/core/kernels:fused_embedding_ops",
        "//tensorflow/core/kernels:grappler",
        "//tensorflow/core/kernels:hash_ops",
        "//tensorflow/core/kernels:histogram_op",
38 changes: 38 additions & 0 deletions tensorflow/core/kernels/BUILD
@@ -1945,6 +1945,29 @@ tf_cuda_cc_test(
    ],
)

tf_cuda_cc_test(
    name = "fused_embedding_ops_test",
    size = "small",
    srcs = ["fused_embedding/fused_embedding_local_ops_test.cc",
            "fused_embedding/fused_embedding_ops_test.cc"],
    tags = tf_cuda_tests_tags(),
    deps = [
        ":fused_embedding_ops",
        ":ops_testutil",
        ":ops_util",
        "//tensorflow/cc:cc_ops",
        "//tensorflow/core:core_cpu",
        "//tensorflow/core:framework",
        "//tensorflow/core:framework_internal",
        "//tensorflow/core:lib",
        "//tensorflow/core:protos_all_cc",
        "//tensorflow/core:tensorflow",
        "//tensorflow/core:test",
        "//tensorflow/core:test_main",
        "//tensorflow/core:testlib",
    ],
)

tf_cc_test(
    name = "in_topk_op_test",
    size = "small",
@@ -5259,6 +5282,21 @@ tf_kernel_library(
    deps = REQUIRED_DEPS,
)

tf_cuda_library(
    name = "fused_embedding_common_cuh",
    hdrs = ["fused_embedding/fused_embedding_common.cu.h"],
)

tf_kernel_library(
    name = "fused_embedding_ops",
    gpu_srcs = [
        "fused_embedding/fused_embedding_local_ops_gpu.cu.cc",
        "fused_embedding/fused_embedding_ops_gpus.cu.cc"
    ],
    deps = ["//third_party/eigen3"] + DYNAMIC_DEPS +
           if_cuda(["@cub_archive//:cub", ":fused_embedding_common_cuh"]),
)

tf_kernel_library(
    name = "run_graph_op",
    prefix = "run_graph_op",
60 changes: 60 additions & 0 deletions tensorflow/core/kernels/fused_embedding/fused_embedding_common.cu.h
@@ -0,0 +1,60 @@
#ifndef TENSORFLOW_CORE_KERNELS_FUSED_EMBEDDING_FUSED_EMBEDDING_COMMON_CU_H_
#define TENSORFLOW_CORE_KERNELS_FUSED_EMBEDDING_FUSED_EMBEDDING_COMMON_CU_H_

#if GOOGLE_CUDA

namespace tensorflow {

namespace {
enum Combiner { Mean, Sum, Sqrtn };

template <Combiner combiner>
__forceinline__ __device__ float Combine(const float in,
                                         const int feature_num);

template <>
__forceinline__ __device__ float Combine<Sqrtn>(const float in,
                                                const int feature_num) {
  return in / sqrtf(feature_num);
}

template <>
__forceinline__ __device__ float Combine<Mean>(const float in,
                                               const int feature_num) {
  return in / feature_num;
}

template <>
__forceinline__ __device__ float Combine<Sum>(const float in,
                                              const int feature_num) {
  return in;
}

template <Combiner combiner>
__forceinline__ __device__ float CombineGrad(const float grad,
                                             const int feature_num);

template <>
__forceinline__ __device__ float CombineGrad<Sqrtn>(const float grad,
                                                    const int feature_num) {
  return grad / sqrtf(feature_num);
}

template <>
__forceinline__ __device__ float CombineGrad<Mean>(const float grad,
                                                   const int feature_num) {
  return grad / feature_num;
}

template <>
__forceinline__ __device__ float CombineGrad<Sum>(const float grad,
                                                  const int feature_num) {
  return grad;
}
} // namespace

} // namespace tensorflow

#endif // GOOGLE_CUDA

#endif // TENSORFLOW_CORE_KERNELS_FUSED_EMBEDDING_FUSED_EMBEDDING_COMMON_CU_H_
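The Combine / CombineGrad specializations in this header encode the standard pooling semantics of embedding_lookup_sparse: sum leaves the accumulated value untouched, mean divides by the number of ids pooled for the row, and sqrtn divides by its square root, with gradients scaled the same way. A reference-only NumPy sketch of the forward combiner math, not part of the commit and ignoring sparse weights and max_norm:

# Reference sketch of the combiner semantics the CUDA templates implement.
import numpy as np

def combine(rows, combiner="mean"):
  """Pools the embedding rows looked up for a single example."""
  feature_num = rows.shape[0]
  pooled = rows.sum(axis=0)
  if combiner == "sum":
    return pooled
  if combiner == "mean":
    return pooled / feature_num
  if combiner == "sqrtn":
    return pooled / np.sqrt(feature_num)
  raise ValueError("unknown combiner: %s" % combiner)

rows = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])  # three ids in one example
print(combine(rows, "sqrtn"))  # each summed component divided by sqrt(3)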