Description
I benchmarked the BERT model exported from the Hugging Face TF BERT code and observed that it is much slower than Google's TF2 BERT model (https://github.com/tensorflow/models/tree/master/official/nlp/bert) when run on TF Serving with a GPU. At batch_size=1, the Hugging Face BERT model can be more than 200% slower than Google's. Here are my benchmark results (a sketch of the measurement setup follows the table):
Latency (ms/batch) on 1× V100 GPU
Batch Size | google official | HF_TF2 | relative change |
---|---|---|---|
1 | 6.7 | 21.3 | 238.10% |
2 | 9.4 | 24.2 | 171.91% |
4 | 14.4 | 28.1 | 99.29% |
8 | 24.0 | 36.9 | 55.70% |
16 | 46.6 | 58.6 | 26.57% |
64 | 171.5 | 171.4 | 0.35% |
128 | 338.5 | 324.5 | -3.71% |
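For reference, here is a minimal sketch of the kind of client-side latency loop behind the numbers above. It assumes the model is served by TF Serving's REST API at `localhost:8501` under the name `bert` and exposes an `input_ids` / `attention_mask` / `token_type_ids` signature with sequence length 128; the model name, port, sequence length, and input names are assumptions, not taken from my actual deployment.

```python
import json
import time

import numpy as np
import requests

# Hypothetical settings -- adjust to your TF Serving deployment and model signature.
URL = "http://localhost:8501/v1/models/bert:predict"  # TF Serving REST predict endpoint
BATCH_SIZE = 1
SEQ_LEN = 128
WARMUP, RUNS = 10, 100

payload = json.dumps({
    "instances": [
        {
            "input_ids": np.random.randint(0, 30522, SEQ_LEN).tolist(),  # random token ids
            "attention_mask": [1] * SEQ_LEN,
            "token_type_ids": [0] * SEQ_LEN,
        }
        for _ in range(BATCH_SIZE)
    ]
})

# Warm up so one-time graph/allocator costs do not skew the numbers.
for _ in range(WARMUP):
    requests.post(URL, data=payload).raise_for_status()

start = time.perf_counter()
for _ in range(RUNS):
    requests.post(URL, data=payload).raise_for_status()
elapsed = time.perf_counter() - start

print(f"batch_size={BATCH_SIZE}: {1000 * elapsed / RUNS:.1f} ms/batch")
```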
I used the latest TensorFlow profiling tool (https://www.tensorflow.org/tfx/serving/tensorboard) to compare the two models:
Hugging-face Bert-Base FP32 Batch_size = 128
Google Bert-Base FP32 Batch_size = 128
Comparing the HF and Google event traces, we see that before each MatMul operation on the GPU there are always some CPU processes running (I wasn't able to get more information about them because the profiling tool doesn't provide meaningful labels or descriptions, but they appear to be doing gather/reduce work), and the GPU sits idle during that time. In Google's trace, by contrast, all GPU ops are packed back-to-back with no GPU idle time; a single CPU process runs in parallel and never causes the GPU ops to wait. These short idle gaps hurt performance more when the batch size or model size is small, because less time is spent on GPU computation and the idle time is no longer negligible. See the time comparison grouped by op type below:
Hugging-face Bert-Base FP32 Batch_size = 4
Google Bert-Base FP32 Batch_size = 4
At the small batch_size=4, the Hugging Face model has about 60% GPU idle time, while the Google model has only about 25%. This also explains why the slowdown exceeds 200% at small batch sizes and why the gap shrinks at batch_size=128.
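For completeness, here is a minimal sketch of how a trace like the ones above can be captured from a running TF Serving instance with the TensorFlow profiler client. It assumes the server's gRPC port (which also exposes the profiler service) is 8500 and that the log directory is arbitrary; both are assumptions about the deployment.

```python
import tensorflow as tf

# Assumed address/path -- adjust to your deployment.
SERVICE_ADDR = "grpc://localhost:8500"   # TF Serving gRPC port (also serves the profiler)
LOGDIR = "/tmp/tfserving_profile"        # where the captured trace is written

# Capture ~2 s of host/device activity while the benchmark client keeps sending
# requests; then open LOGDIR in TensorBoard and use the Profile > trace_viewer tab.
tf.profiler.experimental.client.trace(
    service_addr=SERVICE_ADDR,
    logdir=LOGDIR,
    duration_ms=2000,
)
```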
The MatMul ops are triggered by tf.keras.layers.Dense(), which is used throughout the Transformer encoder: in the self-attention, intermediate, and output layers. In contrast, Google's BERT replaces every use of the Dense layer with DenseEinsum() (link). I modified Hugging Face's TF BERT code to use DenseEinsum(), and the slowdown went away (a rough sketch of the replacement follows the table below):
Latency (ms/batch) on 1× V100 GPU
Batch Size | google official | HF_TF2 | relative change | HF_TF2 after fix | relative change |
---|---|---|---|---|---|
1 | 6.7 | 21.3 | 238.10% | 6.60 | 4.76% |
2 | 9.4 | 24.2 | 171.91% | 8.90 | 0.00% |
4 | 14.4 | 28.1 | 99.29% | 13.40 | -4.96% |
8 | 24.0 | 36.9 | 55.70% | 22.10 | -6.75% |
16 | 46.6 | 58.6 | 26.57% | 43.20 | -6.70% |
64 | 171.5 | 171.4 | 0.35% | 158.80 | -7.03% |
128 | 338.5 | 324.5 | -3.71% | 313.10 | -7.09% |
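To illustrate the idea (this is not the actual layer from tensorflow/models, just a hypothetical minimal stand-in), here is a sketch of an einsum-based dense projection and how it would slot in where a Dense call currently sits; the layer name, shapes, and attention-block names are illustrative:

```python
import tensorflow as tf


class EinsumDense(tf.keras.layers.Layer):
    """Hypothetical minimal stand-in for DenseEinsum: a dense projection written
    as a single tf.einsum over the whole [batch, seq, hidden] tensor."""

    def __init__(self, output_dim, **kwargs):
        super().__init__(**kwargs)
        self.output_dim = output_dim

    def build(self, input_shape):
        hidden_dim = int(input_shape[-1])
        self.kernel = self.add_weight(
            name="kernel",
            shape=(hidden_dim, self.output_dim),
            initializer="glorot_uniform",
        )
        self.bias = self.add_weight(
            name="bias",
            shape=(self.output_dim,),
            initializer="zeros",
        )

    def call(self, inputs):
        # b: batch, s: sequence, h: hidden, d: output units.
        return tf.einsum("bsh,hd->bsd", inputs, self.kernel) + self.bias


# Illustrative swap inside an attention block (names are hypothetical):
#   self.query = tf.keras.layers.Dense(hidden_size)   # before
#   self.query = EinsumDense(hidden_size)             # after
x = tf.random.uniform((4, 128, 768))   # [batch, seq, hidden]
print(EinsumDense(768)(x).shape)       # -> (4, 128, 768)
```

The intent is that the projection runs as one einsum over the full 3-D activation tensor, so the extra host-side work that Dense performs around its MatMul on 3-D inputs, which appears to be what shows up as the CPU activity before each MatMul in the traces above, is avoided.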
I noticed that other open issues may be related to this one (#6264).
Do you plan to solve this issue? Thanks. @patrickvonplaten