forked from DeepRec-AI/DeepRec
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Modelzoo & Docs] Add examples and docs to demonstrate Collective Tra…
…ining. (DeepRec-AI#914) - Add example model to demonstrate usage of Collective Training function. - Add user documentation about Collective Training Interface. Signed-off-by: JunqiHu <silenceki@hotmail.com>
- Loading branch information
1 parent
f7bc901
commit e2037de
Showing
12 changed files
with
1,820 additions
and
193 deletions.
There are no files selected for viewing
136 changes: 136 additions & 0 deletions
136
cibuild/dockerfiles/Dockerfile.devel-py3.8-cu116-ubuntu20.04-hybridbackend
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,136 @@ | ||
FROM alideeprec/deeprec-build:deeprec-dev-gpu-py38-cu116-ubuntu20.04 | ||
|
||
RUN apt-get update && \ | ||
apt-get install -y \ | ||
--allow-unauthenticated \ | ||
--no-install-recommends \ | ||
pkg-config \ | ||
libssl-dev \ | ||
libcurl4-openssl-dev \ | ||
zlib1g-dev \ | ||
libhdf5-dev \ | ||
wget \ | ||
curl \ | ||
inetutils-ping \ | ||
net-tools \ | ||
unzip \ | ||
git \ | ||
vim \ | ||
cmake \ | ||
clang-format-7 \ | ||
openssh-server openssh-client \ | ||
openmpi-bin openmpi-common libopenmpi-dev libgtk2.0-dev && \ | ||
apt-get clean && \ | ||
rm -rf /var/lib/apt/lists/* | ||
|
||
RUN wget -nv -O /opt/openmpi-4.1.1.tar.gz \ | ||
https://www.open-mpi.org/software/ompi/v4.1/downloads/openmpi-4.1.1.tar.gz && \ | ||
cd /opt/ && tar -xvzf ./openmpi-4.1.1.tar.gz && \ | ||
cd openmpi-4.1.1 && ./configure && make && make install | ||
|
||
RUN git clone https://github.com/DeepRec-AI/HybridBackend.git /opt/HybridBackend | ||
|
||
ENV HYBRIDBACKEND_USE_CXX11_ABI=0 \ | ||
HYBRIDBACKEND_WITH_ARROW_HDFS=ON \ | ||
HYBRIDBACKEND_WITH_ARROW_S3=ON \ | ||
TMP=/tmp | ||
|
||
RUN cd /opt/HybridBackend/build/arrow && \ | ||
ARROW_USE_CXX11_ABI=${HYBRIDBACKEND_USE_CXX11_ABI} \ | ||
ARROW_HDFS=${HYBRIDBACKEND_WITH_ARROW_HDFS} \ | ||
ARROW_S3=${HYBRIDBACKEND_WITH_ARROW_S3} \ | ||
./build.sh /opt/arrow | ||
|
||
RUN pip install -U --no-cache-dir \ | ||
Cython \ | ||
nvidia-pyindex \ | ||
pybind11 \ | ||
tqdm && \ | ||
pip install -U --no-cache-dir \ | ||
nvidia-nsys-cli | ||
|
||
ARG TF_REPO=https://github.com/DeepRec-AI/DeepRec.git | ||
ARG TF_TAG=main | ||
|
||
RUN git clone ${TF_REPO} -b ${TF_TAG} /opt/DeepRec | ||
|
||
RUN wget -nv -O /opt/DeepRec/install_bazel.sh \ | ||
http://pythonrun.oss-cn-zhangjiakou.aliyuncs.com/bazel-0.26.1-installer-linux-x86_64.sh && \ | ||
chmod 777 /opt/DeepRec/install_bazel.sh && /opt/DeepRec/install_bazel.sh | ||
|
||
|
||
ENV TF_NEED_CUDA=1 \ | ||
TF_CUDA_PATHS=/usr,/usr/local/cuda \ | ||
TF_CUDA_VERSION=11.6 \ | ||
TF_CUBLAS_VERSION=11 \ | ||
TF_CUDNN_VERSION=8 \ | ||
TF_NCCL_VERSION=2 \ | ||
TF_CUDA_CLANG=0 \ | ||
TF_DOWNLOAD_CLANG=0 \ | ||
TF_NEED_TENSORRT=0 \ | ||
TF_CUDA_COMPUTE_CAPABILITIES="7.0,8.0" \ | ||
TF_ENABLE_XLA=1 \ | ||
TF_NEED_MPI=0 \ | ||
CC_OPT_FLAGS="-march=skylake -Wno-sign-compare" \ | ||
CXX_OPT_FLAGS="-D_GLIBCXX_USE_CXX11_ABI=0" | ||
|
||
RUN cd /opt/DeepRec && \ | ||
yes "" | bash ./configure || true | ||
|
||
RUN --mount=type=cache,target=/var/cache/bazel.tensorflow \ | ||
cd /opt/DeepRec && \ | ||
bazel build \ | ||
--disk_cache=/var/cache/bazel.tensorflow \ | ||
--config=nogcp \ | ||
--config=cuda \ | ||
--config=xla \ | ||
--verbose_failures \ | ||
--cxxopt="${CXX_OPT_FLAGS}" \ | ||
--host_cxxopt="${CXX_OPT_FLAGS}" \ | ||
--define tensorflow_mkldnn_contraction_kernel=0 \ | ||
//tensorflow/tools/pip_package:build_pip_package | ||
|
||
RUN mkdir -p /src/dist && \ | ||
cd /opt/DeepRec && \ | ||
./bazel-bin/tensorflow/tools/pip_package/build_pip_package \ | ||
/src/dist --gpu --project_name tensorflow | ||
|
||
RUN pip install --no-cache-dir --user \ | ||
/src/dist/tensorflow-*.whl && \ | ||
rm -f /src/dist/tensorflow-*.whl | ||
|
||
RUN mkdir -p \ | ||
$(pip show tensorflow | grep Location | cut -d " " -f 2)/tensorflow_core/include/third_party/gpus/cuda/ && \ | ||
ln -sf /usr/local/cuda/include \ | ||
$(pip show tensorflow | grep Location | cut -d " " -f 2)/tensorflow_core/include/third_party/gpus/cuda/include | ||
|
||
RUN cd /opt/DeepRec/ && \ | ||
cp tensorflow/core/kernels/gpu_device_array* \ | ||
$(pip show tensorflow | grep Location | cut -d " " -f 2)/tensorflow_core/include/tensorflow/core/kernels | ||
|
||
RUN cd /opt/DeepRec && \ | ||
bazel build --disk_cache=/var/cache/bazel.tensorflow \ | ||
-j 16 -c opt --config=opt //tensorflow/tools/pip_package:build_sok && \ | ||
./bazel-bin/tensorflow/tools/pip_package/build_sok | ||
|
||
ENV ARROW_INCLUDE=/opt/arrow/include \ | ||
ARROW_LIB=/opt/arrow/lib \ | ||
ZSTD_LIB=/opt/arrow/lib | ||
|
||
# Configure HybridBackend | ||
ENV HYBRIDBACKEND_WITH_CUDA=ON \ | ||
HYBRIDBACKEND_WITH_NCCL=ON \ | ||
HYBRIDBACKEND_WITH_ARROW_ZEROCOPY=ON \ | ||
HYBRIDBACKEND_WITH_TENSORFLOW_HALF=OFF \ | ||
HYBRIDBACKEND_WITH_TENSORFLOW_DISTRO=99881015 \ | ||
HYBRIDBACKEND_USE_CXX11_ABI=0 \ | ||
HYBRIDBACKEND_USE_RUFF=1 \ | ||
HYBRIDBACKEND_WHEEL_ALIAS=-deeprec-cu116 \ | ||
TF_DISABLE_EV_ALLOCATOR=true | ||
|
||
RUN cd /opt/HybridBackend && make -j32 | ||
|
||
RUN pip install --no-cache-dir --user \ | ||
/opt/HybridBackend/build/wheel/hybridbackend_deeprec*.whl | ||
|
||
RUN rm -rf /opt/DeepRec /opt/HybridBackend && /opt/openmpi-4.1.1.tar.gz |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,176 @@ | ||
# Collective Training | ||
|
||
## Background | ||
|
||
For sparse recommendation models like DLRM, there are a large number of parameters and heavy GEMM operations. The asynchronous training paradigm of PS makes it difficult to fully utilize the GPUs in the cluster to accelerate the entire training/inference process.We try to place all the parameters on the worker, but the large amount of memory consumed by the parameters(Embedding) cannot be stored on a single GPU, so we need to perform sharding to place on all GPUs.Native Tensorflow did not support model parallel training (MP), and the community has many excellent plug-ins based on Tensorflow, such as HybridBackend (hereinafter referred to as HB), SparseOperationKit (hereinafter referred to as SOK), and so on. DeepRec provides a unified synchronous training interface `CollectiveStrategy` for users to choose and use. Users can use different synchronous training frameworks with very little code. | ||
|
||
## Interface Introduction | ||
|
||
1. Currently the interface supports HB and SOK, users can choose through the environment variable `COLLECTIVE_STRATEGY`. `COLLECTIVE_STRATEGY` can configure hb, sok corresponding to HB and SOK respectively. The difference from normal startup of Tensorflow tasks is that when users use synchronous training, they need to pull up through additional modules, which need to be started in the following way: | ||
|
||
```bash | ||
CUDA_VISIBLE_DEVICES=0,1 COLLECTIVE_STRATEGY=hb python3 -m tensorflow.python.distribute.launch <python script.py> | ||
``` | ||
If the environment variable is not configured with `CUDA_VISIBLE_DEVICES`, the process will pull up the training sub-processes with the number of GPUs in the current environment by default. | ||
|
||
2. In the user script, a `CollectiveStrategy` needs to be initialized to complete the construction of the model. | ||
|
||
```python | ||
class CollectiveStrategy: | ||
def scope(self, *args, **kwargs): | ||
pass | ||
def embedding_scope(self, **kwargs): | ||
pass | ||
def world_size(self): | ||
pass | ||
def rank(self): | ||
pass | ||
def estimator(self): | ||
pass | ||
def export_saved_model(self): | ||
pass | ||
``` | ||
|
||
Following steps below to using synchronous training: | ||
- Mark with strategy.scope() before the entire model definition. | ||
- Use the embedding_scope() flag where model parallelism is required (embedding layer) | ||
- Use export_saved_model when exporting | ||
- (Optional) In addition, the strategy also provides the estimator interface for users to use. | ||
|
||
## Example | ||
|
||
**MonitoredTrainingSession** | ||
|
||
The following example guides users how to construct Graph through tf.train.MonitoredTrainingSession. | ||
|
||
```python | ||
import tensorflow as tf | ||
from tensorflow.python.distribute.group_embedding_collective_strategy import CollectiveStrategy | ||
|
||
#STEP1: initialize a collective strategy | ||
strategy = CollectiveStrategy() | ||
#STEP2: define the data parallel scope | ||
with strategy.scope(), tf.Graph().as_default(): | ||
#STEP3: define the model parallel scope | ||
with strategy.embedding_scope(): | ||
var = tf.get_variable( | ||
'var_1', | ||
shape=(1000, 3), | ||
initializer=tf.ones_initializer(tf.float32), | ||
partitioner=tf.fixed_size_partitioner(num_shards=strategy.world_size()) | ||
) | ||
emb = tf.nn.embedding_lookup( | ||
var, tf.cast([0, 1, 2, 5, 6, 7], tf.int64)) | ||
fun = tf.multiply(emb, 2.0, name='multiply') | ||
loss = tf.reduce_sum(fun, name='reduce_sum') | ||
opt = tf.train.FtrlOptimizer( | ||
0.1, | ||
l1_regularization_strength=2.0, | ||
l2_regularization_strength=0.00001) | ||
g_v = opt.compute_gradients(loss) | ||
train_op = opt.apply_gradients(g_v) | ||
with tf.train.MonitoredTrainingSession('') as sess: | ||
emb_result, loss_result, _ = sess.run([emb, loss, train_op]) | ||
print (emb_result, loss_result) | ||
``` | ||
|
||
**Estimator** | ||
|
||
The following example guides users how to construct Graph through tf.estimator.Estimator. | ||
```python | ||
import tensorflow as tf | ||
import tensorflow_datasets as tfds | ||
from tensorflow.python.distribute.group_embedding_collective_strategy import CollectiveStrategy | ||
|
||
#STEP1: initialize a collective strategy | ||
strategy = CollectiveStrategy() | ||
#STEP2: define the data parallel scope | ||
with strategy.scope(), tf.Graph().as_default(): | ||
def input_fn(): | ||
ratings = tfds.load("movie_lens/100k-ratings", split="train") | ||
ratings = ratings.map( | ||
lambda x: { | ||
"movie_id": tf.strings.to_number(x["movie_id"], tf.int64), | ||
"user_id": tf.strings.to_number(x["user_id"], tf.int64), | ||
"user_rating": x["user_rating"] | ||
}) | ||
shuffled = ratings.shuffle(1_000_000, | ||
seed=2021, | ||
reshuffle_each_iteration=False) | ||
dataset = shuffled.batch(256) | ||
return dataset | ||
|
||
def input_receiver(): | ||
r'''Prediction input receiver. | ||
''' | ||
inputs = { | ||
"movie_id": tf.placeholder(dtype=tf.int64, shape=[None]), | ||
"user_id": tf.placeholder(dtype=tf.int64, shape=[None]), | ||
"user_rating": tf.placeholder(dtype=tf.float32, shape=[None]) | ||
} | ||
return tf.estimator.export.ServingInputReceiver(inputs, inputs) | ||
|
||
def model_fn(features, labels, mode, params): | ||
r'''Model function for estimator. | ||
''' | ||
del params | ||
movie_id = features["movie_id"] | ||
user_id = features["user_id"] | ||
rating = features["user_rating"] | ||
|
||
embedding_columns = [ | ||
tf.feature_column.embedding_column( | ||
tf.feature_column.categorical_column_with_embedding( | ||
"movie_id", dtype=tf.int64), | ||
dimension=16, | ||
initializer=tf.random_uniform_initializer(-1e-3, 1e-3)), | ||
tf.feature_column.embedding_column( | ||
tf.feature_column.categorical_column_with_embedding( | ||
"user_id", dtype=tf.int64), | ||
dimension=16, | ||
initializer=tf.random_uniform_initializer(-1e-3, 1e-3)) | ||
] | ||
#STEP3: define the model parallel scope | ||
with strategy.embedding_scope(): | ||
with tf.variable_scope( | ||
'embedding', | ||
partitioner=tf.fixed_size_partitioner( | ||
strategy.world_size)): | ||
deep_features = [ | ||
tf.feature_column.input_layer(features, [c]) | ||
for c in embedding_columns] | ||
emb = tf.concat(deep_features, axis=-1) | ||
logits = tf.multiply(emb, 2.0, name='multiply') | ||
|
||
if mode == tf.estimator.ModeKeys.TRAIN: | ||
labels = tf.reshape(tf.to_float(labels), shape=[-1, 1]) | ||
loss = tf.reduce_mean(tf.keras.losses.binary_crossentropy(labels, logits)) | ||
step = tf.train.get_or_create_global_step() | ||
opt = tf.train.AdagradOptimizer(learning_rate=self._args.lr) | ||
train_op = opt.minimize(loss, global_step=step) | ||
return tf.estimator.EstimatorSpec( | ||
mode=mode, | ||
loss=loss, | ||
train_op=train_op, | ||
training_chief_hooks=[]) | ||
|
||
return None | ||
estimator = strategy.estimator(model_fn=model_fn, | ||
model_dir="./", | ||
config=None) | ||
estimator.train_and_evaluate( | ||
tf.estimator.TrainSpec( | ||
input_fn=input_fn, | ||
max_steps=50), | ||
tf.estimator.EvalSpec( | ||
input_fn=input_fn)) | ||
estimator.export_saved_model("./", input_receiver) | ||
``` | ||
|
||
## Appendix | ||
|
||
- Currently DeepRec provides the corresponding GPU image for users to use (alideeprec/deeprec-release:deeprec2304-gpu-py38-cu116-ubuntu20.04-hybridbackend), users can also refer to [Dockerfile](../../cibuild/dockerfiles/Dockerfile.devel-py3.8-cu116-ubuntu20.04-hybridbackend) | ||
- We also provides more detailed demos about the above two usage methods, see: [ModelZoo](../../modelzoo/features/grouped_embedding) | ||
|
||
- If further optimization is required, there are more fine-tuning parameters for HB and SOK, please refer to: | ||
[SOK](./SOK.md) 和 [HB](https://github.com/DeepRec-AI/HybridBackend) |
Oops, something went wrong.