This doc introduces how to use SparseOperationKit (SOK) lookup together with Embedding Variable in DeepRec. Users can now leverage both the highly efficient lookup operations provided by SOK and the flexible functionality of the DeepRec Embedding Variable. This doc includes 3 parts:
- The API of SOK embedding lookup.
- A demo that shows how to use SOK embedding lookup together with Embedding Variable, with only a few lines of change.
- A guide on how to build and test this new feature.
```python
def lookup_sparse(params, indices, combiners)
```
- `params`: a list of variables. Each variable should be created by `tf.get_embedding_variable`.
- `indices`: a list of `tf.SparseTensor`. The indices/keys to look up.
- `combiners`: a list of strings. The combiner type for each table. Can be `mean` or `sum`.
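For example, a single call can fuse lookups over multiple tables, with one combiner per table. The sketch below is illustrative only: `var0`/`var1` stand for variables created with `tf.get_embedding_variable`, and `sp0`/`sp1` for `tf.SparseTensor` batches of int64 keys (all hypothetical names):

```python
# Illustrative sketch (hypothetical names): fuse two table lookups in one call.
# var0, var1 were created with tf.get_embedding_variable (see the demo below);
# sp0, sp1 are tf.SparseTensor batches of int64 keys, one per table.
emb0, emb1 = sok.lookup_sparse([var0, var1],   # params: one variable per table
                               [sp0, sp1],     # indices: one SparseTensor per table
                               combiners=["sum", "mean"])  # one combiner per table
# emb0/emb1 are dense tensors after combining, one row per batch sample.
```

The full demo below shows these pieces end to end in a runnable script.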
```python
import time
import numpy as np
import tensorflow as tf
import horovod.tensorflow as hvd
from sparse_operation_kit import experiment as sok

# 1. Init Horovod and SOK
hvd.init()
gpus = tf.config.experimental.list_physical_devices("GPU")
for gpu in gpus:
    tf.config.experimental.set_memory_growth(gpu, True)
if gpus:
    tf.config.experimental.set_visible_devices(gpus[hvd.local_rank()], "GPU")
sok.init()

var = tf.get_embedding_variable("var_0",
                                embedding_dim=3,
                                initializer=tf.ones_initializer(tf.float32))

# 2. Set target_gpu for the Embedding Variable
var.target_gpu = -1

indices = tf.SparseTensor(
    indices=tf.convert_to_tensor([[0, 0], [0, 1], [1, 0], [1, 1], [1, 2]], dtype=tf.int64),
    values=tf.convert_to_tensor([1, 1, 3, 4, 5], dtype=tf.int64),
    dense_shape=[2, 3]
)

# 3. Use sok.lookup_sparse
emb = sok.lookup_sparse([var], [indices], combiners=['sum'])

fun = tf.multiply(emb, 2.0, name='multiply')
loss = tf.reduce_sum(fun, name='reduce_sum')
opt = tf.train.AdagradOptimizer(0.1)
g_v = opt.compute_gradients(loss)
train_op = opt.apply_gradients(g_v)
loss = hvd.allreduce(loss, op=hvd.Sum)

init = tf.global_variables_initializer()
sess_config = tf.ConfigProto(allow_soft_placement=True, log_device_placement=False)
with tf.Session(config=sess_config) as sess:
    sess.run([init])
    print(sess.run([emb, train_op, loss]))
    print(sess.run([emb, train_op, loss]))
    print(sess.run([emb, train_op, loss]))
```
You can also try `demo.py` by running:

```bash
horovodrun -np ${NUM_GPU} -H localhost:${NUM_GPU} python3 demo.py
```
There are 2 additional steps to use SOK embedding lookup with Embedding Variable, compared with using `tf.nn.embedding_lookup_sparse`:
- Initialize Horovod and SOK at the beginning of your script. SOK uses Horovod as its communication backend, and it also needs explicit initialization to allocate the necessary buffers and exchange the necessary data before training.
- Specify `target_gpu` for the Embedding Variable. `target_gpu` specifies which GPU you want to place the variable on. Two options are supported right now (see the sketch after this list):
  - `-1`: distribute the current variable across all GPUs.
  - `${GPU_ID}`: an integer in the range 0 to ${NUM_GPU} - 1; the embedding variable is placed on the specified GPU.
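For example (a minimal sketch; `var_a` and `var_b` are hypothetical variables created the same way as `var` in the demo above):

```python
# Distribute var_a across all GPUs (model-parallel placement).
var_a.target_gpu = -1

# Pin var_b to one specific GPU, e.g. GPU 0 (any integer in 0..NUM_GPU-1 works).
var_b.target_gpu = 0
```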
- Clone DeepRec:

```bash
git clone https://github.com/alibaba/DeepRec.git /DeepRec
```
- Build DeepRec from source code. You can follow the instructions in DeepRec Build; make sure you configure GPU EV support.
- Build SOK:

```bash
bazel --output_base /tmp build -j 16 -c opt --config=opt //tensorflow/tools/pip_package:build_sok && ./bazel-bin/tensorflow/tools/pip_package/build_sok
```
- Run the unit test:

```bash
cd /DeepRec/addons/sparse_operation_kit/python && horovodrun -np ${NUM_GPU} -H localhost:${NUM_GPU} python3 embedding_var_lookup_utest.py
```
- Download the Kaggle Display Advertising Challenge Dataset (Criteo Dataset) from https://storage.googleapis.com/dataset-uploader/criteo-kaggle/large_version/train.csv
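For example, one way to fetch it (assuming `wget` is available; any HTTP client works):

```bash
wget https://storage.googleapis.com/dataset-uploader/criteo-kaggle/large_version/train.csv
```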
- Install Horovod:

```bash
HOROVOD_NCCL_LINK=SHARED HOROVOD_GPU_OPERATIONS=NCCL pip install --no-cache-dir horovod
```
- Run the benchmark:

```bash
# 8 GPUs
horovodrun -np 8 python benchmark_sok.py --batch_size 65536
```