forked from DeepRec-AI/DeepRec
[Embedding] Add GPU fused_input_from_feature_column and fuse_embedding_lookup_sparse. (DeepRec-AI#65)
Showing 13 changed files with 2,102 additions and 1 deletion.
139 changes: 139 additions & 0 deletions
tensorflow/contrib/layers/python/layers/feature_column_fused_ops.py
@@ -0,0 +1,139 @@
from tensorflow.python.framework import ops
from tensorflow.python.ops import partitioned_variables
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import array_ops
from tensorflow.python.ops.fused_embedding_ops import fused_embedding_lookup_sparse
from tensorflow.python.framework import dtypes
from tensorflow.contrib.framework.python.ops import variables as contrib_variables
from tensorflow.contrib.layers.python.layers.feature_column_ops import check_feature_columns
from tensorflow.contrib.layers.python.layers.feature_column_ops import _Transformer
from tensorflow.contrib.layers.python.layers import feature_column as fc


def input_from_feature_columns_fused(columns_to_tensors,
                                     feature_columns,
                                     trainable=True,
                                     scope=None,
                                     cols_to_outs=None):
  """Implementation of `input_from(_sequence)_feature_columns`."""
  columns_to_tensors = columns_to_tensors.copy()
  check_feature_columns(feature_columns)
  if cols_to_outs is not None and not isinstance(cols_to_outs, dict):
    raise ValueError('cols_to_outs must be a dict unless None')
  with variable_scope.variable_scope(scope,
                                     default_name="input_from_feature_columns_fused",
                                     values=columns_to_tensors.values()):
    output_tensors = []
    transformer = _Transformer(columns_to_tensors)

    for column in sorted(set(feature_columns), key=lambda x: x.key):
      with variable_scope.variable_scope(None,
                                         default_name=column.name,
                                         values=columns_to_tensors.values()):
        transformed_tensor = transformer.transform(column)
        # pylint: disable=protected-access
        args = column._deep_embedding_lookup_arguments(
            transformed_tensor)
        output = embeddings_from_arguments_fused(
            column, args, trainable)
        output_tensors.append(output)
        if cols_to_outs is not None:
          cols_to_outs[column] = output_tensors[-1]
    return array_ops.concat(output_tensors, 1)


def embeddings_from_arguments_fused(column,
                                    args,
                                    trainable):
  # This option is only enabled for scattered_embedding_column.
  if args.hash_key:
    raise NotImplementedError("not implemented yet for hash_key")

  graph = ops.get_default_graph()
  partition_num = args.embedding_var_part_num
  if partition_num is None:
    partitioner = None
  else:
    partitioner = partitioned_variables.fixed_size_partitioner(partition_num)

  # 1. get the embedding_weights
  if args.shared_embedding_name is not None:
    shared_embedding_collection_name = ("SHARED_EMBEDDING_COLLECTION_" +
                                        args.shared_embedding_name.upper())
    shared_embedding_collection = (
        graph.get_collection_ref(shared_embedding_collection_name))
    shape = [args.vocab_size, args.dimension]
    if shared_embedding_collection:
      if len(shared_embedding_collection) > 1:
        raise ValueError("Collection %s can only contain one "
                         "(partitioned) variable." %
                         shared_embedding_collection_name)
      else:
        embeddings = shared_embedding_collection[0]
        if (not args.use_embedding_var and embeddings.get_shape() != shape):
          raise ValueError("The embedding variable with name {} already "
                           "exists, but its shape does not match required "
                           "embedding shape here. Please make sure to use "
                           "different shared_embedding_name for different "
                           "shared embeddings.".format(args.shared_embedding_name))
    else:
      if args.use_embedding_var:
        embeddings = variable_scope.get_embedding_variable_internal(
            name=args.shared_embedding_name,
            embedding_dim=args.dimension,
            key_dtype=dtypes.int64,
            initializer=args.initializer,
            trainable=(trainable and args.trainable),
            collections=None,
            partitioner=partitioner,
            steps_to_live=args.steps_to_live,
            init_data_source=args.init_data_source,
            ht_partition_num=args.ht_partition_num,
            evconfig=args.evconfig)
        graph.add_to_collection(
            ops.GraphKeys.EMBEDDING_VARIABLES, embeddings)
      else:
        embeddings = contrib_variables.model_variable(
            name=args.shared_embedding_name,
            shape=shape,
            dtype=dtypes.float32,
            initializer=args.initializer,
            trainable=(trainable and args.trainable),
            collections=None)
      graph.add_to_collection(
          shared_embedding_collection_name, embeddings)
  else:
    if args.use_embedding_var:
      embeddings = variable_scope.get_embedding_variable_internal(
          name="weights",
          embedding_dim=args.dimension,
          key_dtype=dtypes.int64,
          initializer=args.initializer,
          trainable=(trainable and args.trainable),
          collections=None,
          partitioner=partitioner,
          steps_to_live=args.steps_to_live,
          init_data_source=args.init_data_source,
          ht_partition_num=args.ht_partition_num,
          evconfig=args.evconfig)
      graph.add_to_collection(
          ops.GraphKeys.EMBEDDING_VARIABLES, embeddings)
    else:
      embeddings = contrib_variables.model_variable(
          name="weights",
          shape=[args.vocab_size, args.dimension],
          dtype=dtypes.float32,
          initializer=args.initializer,
          trainable=(trainable and args.trainable),
          collections=None)

  if fc._is_variable(embeddings):
    embeddings = [embeddings]
  else:
    embeddings = embeddings._get_variable_list()  # pylint: disable=protected-access
  # pylint: disable=protected-access
  fc._maybe_restore_from_checkpoint(column._checkpoint_path(), embeddings)

  # 2. look up
  return fused_embedding_lookup_sparse(embeddings, args.input_tensor,
                                       combiner=args.combiner, max_norm=args.max_norm)
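For orientation, a minimal usage sketch of the new entry point follows. The feature name "item_id", the bucket size, and the embedding dimension are made up for illustration, and it assumes the module is importable from the path added in this commit and that the fused_embedding_lookup_sparse GPU kernel is available in the build; this is a sketch, not part of the change itself.

    import tensorflow as tf
    from tensorflow.contrib import layers
    from tensorflow.contrib.layers.python.layers import feature_column_fused_ops

    # Hypothetical sparse id feature: 3 examples, at most 2 ids each.
    ids = tf.SparseTensor(
        indices=[[0, 0], [1, 0], [1, 1], [2, 0]],
        values=tf.constant([7, 42, 999, 3], dtype=tf.int64),
        dense_shape=[3, 2])

    # An ordinary embedding column over an integerized sparse column.
    id_column = layers.sparse_column_with_integerized_feature(
        "item_id", bucket_size=1000)
    emb_column = layers.embedding_column(id_column, dimension=8, combiner="mean")

    # Same inputs as layers.input_from_feature_columns, but the lookups go
    # through the fused GPU kernel; the result is the concatenated dense
    # output, here a [3, 8] float tensor (one 8-dim embedding per example).
    net = feature_column_fused_ops.input_from_feature_columns_fused(
        columns_to_tensors={"item_id": ids},
        feature_columns=[emb_column])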
60 changes: 60 additions & 0 deletions
tensorflow/core/kernels/fused_embedding/fused_embedding_common.cu.h
@@ -0,0 +1,60 @@
#ifndef TENSORFLOW_CORE_KERNELS_FUSED_EMBEDDING_FUSED_EMBEDDING_COMMON_CU_H_
#define TENSORFLOW_CORE_KERNELS_FUSED_EMBEDDING_FUSED_EMBEDDING_COMMON_CU_H_

#if GOOGLE_CUDA

namespace tensorflow {

namespace {
enum Combiner { Mean, Sum, Sqrtn };

template <Combiner combiner>
__forceinline__ __device__ float Combine(const float in,
                                         const int feature_num);

template <>
__forceinline__ __device__ float Combine<Sqrtn>(const float in,
                                                const int feature_num) {
  return in / sqrtf(feature_num);
}

template <>
__forceinline__ __device__ float Combine<Mean>(const float in,
                                               const int feature_num) {
  return in / feature_num;
}

template <>
__forceinline__ __device__ float Combine<Sum>(const float in,
                                              const int feature_num) {
  return in;
}

template <Combiner combiner>
__forceinline__ __device__ float CombineGrad(const float grad,
                                             const int feature_num);

template <>
__forceinline__ __device__ float CombineGrad<Sqrtn>(const float grad,
                                                    const int feature_num) {
  return grad / sqrtf(feature_num);
}

template <>
__forceinline__ __device__ float CombineGrad<Mean>(const float grad,
                                                   const int feature_num) {
  return grad / feature_num;
}

template <>
__forceinline__ __device__ float CombineGrad<Sum>(const float grad,
                                                  const int feature_num) {
  return grad;
}
}  // namespace

}  // namespace tensorflow

#endif  // GOOGLE_CUDA

#endif  // TENSORFLOW_CORE_KERNELS_FUSED_EMBEDDING_FUSED_EMBEDDING_COMMON_CU_H_
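These specializations apply the per-element scaling for the three combiners: Sum leaves the accumulated value unchanged, Mean divides by the number of ids gathered for the row, and Sqrtn divides by its square root, with CombineGrad applying the matching scaling to incoming gradients. A small NumPy sketch of the equivalent row-level math, with illustrative names only (not part of the kernel):

    import numpy as np

    def combine(summed_row, feature_num, combiner):
        # summed_row is the element-wise sum of the feature_num embedding
        # vectors gathered for one sample; mirrors Combine<...> above.
        if combiner == "sum":
            return summed_row
        if combiner == "mean":
            return summed_row / feature_num
        if combiner == "sqrtn":
            return summed_row / np.sqrt(feature_num)
        raise ValueError("unknown combiner: %s" % combiner)

    rows = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])  # 3 ids for one sample
    summed = rows.sum(axis=0)                               # -> [9.0, 12.0]
    print(combine(summed, rows.shape[0], "sqrtn"))          # [9., 12.] / sqrt(3)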