
Commit f39607a

Merge pull request intel#1 from otcshare/ipex_dlrm
enable automix bf16 interaction, emb fw
2 parents e904bb3 + e4afb04 commit f39607a

File tree

14 files changed: +997 −36 lines

examples/README.md

Lines changed: 112 additions & 0 deletions
# Guide to running auto-mixed precision (BF16) models with Intel Extension for PyTorch

## Verified on

| Item | Value |
| -: | :- |
| OS | Ubuntu 18.04 LTS |
| Compiler | gcc 7.5.0 |
| Memory | DDR4 3200MHz, 96GB/socket |

## Environment setup

1. Install Anaconda3
```
wget https://repo.continuum.io/archive/Anaconda3-5.0.0-Linux-x86_64.sh -O anaconda3.sh
chmod +x anaconda3.sh
./anaconda3.sh -b -p ~/anaconda3
./anaconda3/bin/conda create -n ipex python=3.7
```

2. Set up the Anaconda virtual environment for IPEX
```
export PATH=~/anaconda3/bin:$PATH
source ./anaconda3/bin/activate ipex
```

3. Install dependencies
```
conda config --append channels intel
conda install ninja pyyaml setuptools cmake cffi typing numpy
conda install mkl intel-openmp mkl-include -c intel --no-update-deps
conda install jemalloc
```

4. Clone the source code and build
```
# PyTorch
git clone https://github.com/otcshare/pytorch.git
cd pytorch
git checkout tags/v1.7.0 -b v1.7.0
git submodule sync && git submodule update --init --recursive
cd ..

# extension
git clone https://github.com/otcshare/intel-extension-for-pytorch.git
cd intel-extension-for-pytorch
git checkout dlrm
git submodule update --init --recursive
cd ..

# copy the patches into the PyTorch tree
cp {path/to/intel-pytorch-extension}/torch_patches/dlrm_fp32.patch {path/to/pytorch}/
cp {path/to/intel-pytorch-extension}/torch_patches/xpu-1.7.patch {path/to/pytorch}/

# build PyTorch
cd {path/to/pytorch}
patch -p1 < xpu-1.7.patch
patch -p1 < dlrm_fp32.patch
pip install -r requirements.txt
python setup.py install

# build the extension
cd {path/to/intel-pytorch-extension}
pip install -r requirements.txt
cd third_party/mkl-dnn
patch -p1 < ../../torch_patches/FIFO.diff
cd ../../
python setup.py install
```

## Prepare DLRM

```
git clone https://github.com/otcshare/dlrm.git
cd dlrm
pip install -r requirements.txt
```

## Run models

```
export DATASET_PATH={path/to/dlrm_dataset}
```

1. Inference with vanilla PyTorch
```
bash run_inference.sh
```

2. Inference with IPEX FP32
```
bash run_inference.sh ipex
```

3. Inference with IPEX BF16
```
bash run_inference.sh ipex bf16
```

4. Training with vanilla PyTorch
```
bash run_training.sh
```

5. Training with IPEX FP32
```
bash run_training.sh ipex
```

6. Training with IPEX BF16
```
bash run_training.sh ipex bf16
```
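For orientation, the `ipex` / `bf16` switches in the run scripts boil down to a pattern like the following at the Python level. This is a minimal sketch, assuming `ipex.DEVICE` and the `AutoDNNL`/`AutoMixPrecision` helpers from `tests/cpu/common_ipex_conf.py` in this repo; the actual `run_*.sh` scripts may wire things up differently.

```python
# Hedged sketch: move a model to the IPEX device, then enable the DNNL and
# bf16 auto-mix paths via the context managers the tests in this PR use.
import torch
import torch.nn as nn
import intel_pytorch_extension as ipex
from common_ipex_conf import AutoDNNL, AutoMixPrecision  # test-only helpers

model = nn.Linear(128, 1).to(ipex.DEVICE)       # fp32 weights on the IPEX device
x = torch.randn(2048, 128, device=ipex.DEVICE)  # fp32 activations

with AutoDNNL(True), AutoMixPrecision(True):
    y = model(x)  # executed through DNNL; storage is reordered to bf16
```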

intel_pytorch_extension_py/ops/embeddingbag.py

Lines changed: 6 additions & 2 deletions
```diff
@@ -1,13 +1,17 @@
 import torch
 from torch import nn
 from torch.autograd import Function
+import intel_pytorch_extension as ipex
 import _torch_ipex as core
 
 # # extension for BF16 fast path only
 
-
+torch_embedding_bag = torch.embedding_bag
 def embeddingbag(weights, indices, offsets, scale_grad_by_freq, mode, sparse, per_sample_weights, include_last_offset):
-    ret = torch.ops.torch_ipex.embedding_bag(weights, indices, offsets, scale_grad_by_freq, mode, sparse, per_sample_weights, include_last_offset)
+    if weights.device.type in ipex.DEVICE:
+        ret = torch.ops.torch_ipex.embedding_bag(weights, indices, offsets, scale_grad_by_freq, mode, sparse, per_sample_weights, include_last_offset)
+    else:
+        ret = torch_embedding_bag(weights, indices, offsets, scale_grad_by_freq, mode, sparse, per_sample_weights, include_last_offset)
     if len(ret)==1:
         ret += [torch.Tensor(), torch.Tensor(), torch.Tensor()]
     return ret
```
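The new branch makes the monkey-patched `torch.embedding_bag` safe for plain CPU tensors: anything whose device is not `ipex.DEVICE` falls back to the stock implementation saved in `torch_embedding_bag`. A hedged usage sketch, mirroring `tests/cpu/test_emb.py` (it assumes importing `intel_pytorch_extension` installs the patch, as this ops module suggests):

```python
# Same module, two devices, two code paths.
import torch
import torch.nn as nn
import intel_pytorch_extension as ipex  # importing patches torch.embedding_bag

emb = nn.EmbeddingBag(10, 3, mode='sum', sparse=True)
idx = torch.LongTensor([1, 2, 4, 5])
off = torch.LongTensor([0, 2])

cpu_out = emb(idx, off)  # CPU weights -> falls back to torch_embedding_bag

emb_xpu = emb.to(ipex.DEVICE)
xpu_out = emb_xpu(idx.to(ipex.DEVICE), off.to(ipex.DEVICE))  # IPEX kernel
```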

intel_pytorch_extension_py/ops/interaction.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -8,7 +8,7 @@ def interaction(*args):
     # So we preserve python custom function while need backward
     # Since python custom function will meet GIL when run multi-thread in one process
     # We will drop python custom function after c++ are supported
-    if torch.is_grad_enabled():
+    if torch.is_grad_enabled() and core.get_train():
         return InteractionFunc.apply(*args)
     return torch.ops.torch_ipex.interaction_forward(args)
```
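With this change, the Python `InteractionFunc` (which exists only to provide backward and serializes under the GIL in multi-threaded runs) is used only when gradients are enabled and the extension is in training mode; otherwise the fused C++ forward is called directly. A minimal sketch of the inference path, assuming `ipex.DEVICE` and `ipex.interaction` as used in `tests/cpu/test_interaction.py`:

```python
# Under torch.no_grad(), torch.is_grad_enabled() is False, so the call below
# goes straight to torch.ops.torch_ipex.interaction_forward.
import torch
import intel_pytorch_extension as ipex

dense = torch.randn(32, 16, device=ipex.DEVICE)
sparse = [torch.randn(32, 16, device=ipex.DEVICE) for _ in range(3)]

with torch.no_grad():
    out = ipex.interaction(dense, *sparse)  # fused forward, no Python autograd
```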

tests/cpu/test_emb.py

Lines changed: 16 additions & 0 deletions
```diff
@@ -5,7 +5,23 @@
 import copy
 from common_utils import TestCase
 
+from common_ipex_conf import AutoMixPrecision, AutoDNNL
+
 class TestEMB(TestCase):
+    def test_automix_emb(self):
+        EE = nn.EmbeddingBag(10, 3, mode='sum', sparse=True)
+        emb_auto_mix = copy.deepcopy(EE).to(ipex.DEVICE)
+        emb_auto_mix.weight.requires_grad = False
+        input = torch.LongTensor([1,2,4,5,4,3,2,9])
+        offsets = torch.LongTensor([0,1,2,3,4,5,6,7])
+        res_fp32 = EE(input, offsets)
+
+        with AutoDNNL(True), AutoMixPrecision(True):
+            res_auto_mix = emb_auto_mix(input.to(ipex.DEVICE), offsets.to(ipex.DEVICE))
+            self.assertEqual(res_auto_mix.dtype, torch.float)
+            self.assertTrue(ipex.core.is_bf16_dil_tensor(res_auto_mix))
+            self.assertTrue(torch.allclose(res_fp32, res_auto_mix, rtol=1e-5, atol=1e-2))
+
     def test_emb(self):
         #E = nn.EmbeddingBag(10, 5, mode="sum", sparse=True)
         cpu_emb = nn.EmbeddingBag(10, 3, mode='sum', sparse=True)
```

tests/cpu/test_interaction.py

Lines changed: 53 additions & 28 deletions
```diff
@@ -21,47 +21,72 @@
     IS_SANDCASTLE, load_tests, brute_pdist, brute_cdist, slowTest, \
     skipCUDANonDefaultStreamIf, skipCUDAMemoryLeakCheckIf
 
+from common_ipex_conf import AutoMixPrecision, AutoDNNL
+
 class TestInteractionCases(TestCase):
-    def test_interaction(self):
-        def interact_fusion(x, ly):
-            A = [x] + ly
-            R = ipex.interaction(*A)
-            return R
+    def interact_fusion(self, x, ly):
+        A = [x] + ly
+        R = ipex.interaction(*A)
+        return R
 
-        def interact_features(x, ly):
-            (batch_size, d) = x.shape
-            T = torch.cat([x] + ly, dim=1).view((batch_size, -1, d))
-            # Z = pcl_embedding_bag.bdot(T)
-            Z = torch.bmm(T, torch.transpose(T, 1, 2))
-            _, ni, nj = Z.shape
-            offset = 0
-            li = torch.tensor([i for i in range(ni) for j in range(i + offset)], device=ipex.DEVICE)
-            lj = torch.tensor([j for i in range(nj) for j in range(i + offset)], device=ipex.DEVICE)
-            Zflat = Z[:, li, lj]
-            # concatenate dense features and interactions
-            R = torch.cat([x] + [Zflat], dim=1)
-            return R
+    def interact_features(self, x, ly):
+        (batch_size, d) = x.shape
+        T = torch.cat([x] + ly, dim=1).view((batch_size, -1, d))
+        # Z = pcl_embedding_bag.bdot(T)
+        Z = torch.bmm(T, torch.transpose(T, 1, 2))
+        _, ni, nj = Z.shape
+        offset = 0
+        li = torch.tensor([i for i in range(ni) for j in range(i + offset)], device=ipex.DEVICE)
+        lj = torch.tensor([j for i in range(nj) for j in range(i + offset)], device=ipex.DEVICE)
+        Zflat = Z[:, li, lj]
+        # concatenate dense features and interactions
+        R = torch.cat([x] + [Zflat], dim=1)
+        return R
 
-        dtypes=[torch.float32]
-        for dtype in dtypes:
+    def get_input(self, dtype):
         x1 = torch.randn([2048, 128], device=ipex.DEVICE).to(dtype).clone().detach().requires_grad_()
-        x2 = x1.clone().detach().requires_grad_()
+        x1_clone = x1.clone().detach().requires_grad_()
         ly1 = []
-        ly2 = []
+        ly1_clone = []
         for i in range(0, 26):
             V = torch.randn([2048, 128], device=ipex.DEVICE).to(dtype).clone().detach().requires_grad_()
             ly1.append(V)
-            ly2.append(V.clone().detach().requires_grad_())
+            ly1_clone.append(V.clone().detach().requires_grad_())
+        return x1, ly1, x1_clone, ly1_clone
 
-        A = interact_fusion(x1, ly1)
-        B = interact_features(x2, ly2)
-        self.assertEqual(A, B)
+    def test_interaction_fusion(self):
+        dtypes = [torch.float32, torch.bfloat16]
+        for dtype in dtypes:
+            x1, ly1, x1_clone, ly1_clone = self.get_input(dtype)
+            ipex.core.set_execution_mode(train=True)
+            A = self.interact_fusion(x1, ly1).to(torch.float32)
+            B = self.interact_features(x1_clone, ly1_clone).to(torch.float32)
+            self.assertTrue(torch.allclose(A, B, rtol=1e-4, atol=1e-4))
 
             A.mean().backward()
             B.mean().backward()
-            self.assertEqual(x1.grad, x2.grad)
+            self.assertTrue(torch.allclose(x1.grad.to(torch.float32), x1_clone.grad.to(torch.float32), rtol=1e-4, atol=1e-4))
             for i in range(0, 26):
-                self.assertEqual(ly1[i].grad, ly2[i].grad)
+                self.assertTrue(torch.allclose(ly1[i].grad.to(torch.float32), ly1_clone[i].grad.to(torch.float32), rtol=1e-4, atol=1e-4))
+
+    def test_automix_fused_interaction(self):
+        x1, ly1, x1_clone, ly1_clone = self.get_input(torch.float32)
+        man_bf16_x1 = x1_clone.to(torch.bfloat16)
+        man_bf16_ly1 = [y.to(torch.bfloat16) for y in ly1_clone]
+        with AutoDNNL(True), AutoMixPrecision(False):
+            self.assertEqual(man_bf16_x1.dtype, torch.bfloat16)
+            for i in range(0, 26):
+                self.assertEqual(man_bf16_ly1[i].dtype, torch.bfloat16)
+            res_man_bf16 = self.interact_fusion(man_bf16_x1, man_bf16_ly1)
+            self.assertEqual(res_man_bf16.dtype, torch.bfloat16)
+
+            with AutoMixPrecision(True):
+                res_auto_bf16 = self.interact_fusion(x1, ly1)
+                self.assertTrue(ipex.core.is_bf16_dil_tensor(x1))
+                for i in range(0, 26):
+                    self.assertTrue(ipex.core.is_bf16_dil_tensor(ly1[i]))
+                self.assertTrue(ipex.core.is_bf16_dil_tensor(res_auto_bf16))
+                self.assertEqual(res_man_bf16.to(torch.float32), res_auto_bf16)
 
 if __name__ == '__main__':
     test = unittest.main()
```

torch_ipex/csrc/cpu/CustomOPs.h

Lines changed: 2 additions & 2 deletions
```diff
@@ -83,13 +83,13 @@ class NewLinearOp : public torch::autograd::Function<NewLinearOp> {
             input.sizes(),
             grad_output.is_contiguous() ? grad_output
                                         : grad_output.contiguous(),
-            weight.is_contiguous() ? weight : weight.contiguous());
+            weight);
     std::tie(grad_weight, grad_bias) =
         torch_ipex::cpu::AtenIpexCPUDev::dil_linear_backward_weights(
             grad_output.is_contiguous() ? grad_output
                                         : grad_output.contiguous(),
             input.is_contiguous() ? input : input.contiguous(),
-            weight.is_contiguous() ? weight : weight.contiguous(),
+            weight,
             bias.defined());
     return {grad_input, grad_weight, grad_bias};
   }
```

torch_ipex/csrc/cpu/DevOPs.cpp

Lines changed: 1 addition & 1 deletion
```diff
@@ -1049,7 +1049,7 @@ at::Tensor AtenIpexCPUDev::dil_linear(
   // reshape first if input dim is greater than 2 and the reshape will cost a memory copy.
   auto self_reshaped = self.dim() > 2 ? dil_reshape(self, {-1, dil_size(self, self.dim() - 1)}) : self;
   const dil::tensor x = dbl::comm::try_gen_dil_tensor(self_reshaped);
-  if (!check_train() && check_tensor_own_whole_storage(weight)) {
+  if (!(check_auto_mix_bf16_fp32() && check_train()) && check_tensor_own_whole_storage(weight)) {
     dbl::linear::prepack_linear_weights(self_reshaped, x, weight);
   }
   const dil::tensor w = dbl::comm::try_gen_dil_tensor(weight);
```

torch_ipex/csrc/cpu/ExtendOPs.cpp

Lines changed: 39 additions & 1 deletion
```diff
@@ -12,6 +12,8 @@
 #include <algorithm>
 #include <c10/util/Exception.h>
 #include <torch/csrc/autograd/function.h>
+#include "ShadeDataContext.h"
+#include "torch_ipex/csrc/cpu/bf16/Bridge.hpp"
 
 namespace torch_ipex {
 
@@ -233,7 +235,7 @@ inline at::Tensor _interaction_forward(const std::vector<at::Tensor> &input) {
   std::vector<T *> input_data(input.size());
   for (int i = 0; i < input.size(); i++) {
     TORCH_INTERNAL_ASSERT_DEBUG_ONLY(input[i].is_contiguous());
-    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(input[i].device().is_xpu());
+    // TORCH_INTERNAL_ASSERT_DEBUG_ONLY(input[i].device().is_xpu());
     TORCH_INTERNAL_ASSERT_DEBUG_ONLY(input[i].dim() == 2);
     feature_sizes[i] = input[i].sizes()[1];
     total_feature_size += input[i].sizes()[1];
@@ -343,8 +345,24 @@ _interaction_backward(const at::Tensor &grad_out,
   return output;
 }
 
+at::Tensor
+interaction_forward_auto_mix_dispatch(const std::vector<at::Tensor> &input) {
+  for (auto &in : input) {
+    cpu::dbl::comm::reorder_to_bf16_for_mix_prec(in);
+    IPEX_CHECK(cpu::ShadeDataContext::isTensorMixPrecisionBF16(in));
+  }
+  const std::vector<at::Tensor>& consistent_input = cpu::bf16::gen_consistent_tensorlist(input);
+  auto &&_ipex_result = _interaction_forward<at::BFloat16>(consistent_input);
+  return cpu::bf16::gen_mix_prec_tensor(_ipex_result);
+}
+
 at::Tensor
 AtenIpexTypeExt::interaction_forward(const std::vector<at::Tensor> &input) {
+  bool auto_mix_bf16 = check_auto_mix_bf16_fp32();
+  if (auto_mix_bf16) {
+    return interaction_forward_auto_mix_dispatch(input);
+  }
+  // preserve support for the original PyTorch bfloat16 path
   if (input[0].scalar_type() == at::kFloat) {
     for (auto &in : input) {
       cpu::dbl::comm::reorder_to_public(in);
@@ -360,9 +378,29 @@ AtenIpexTypeExt::interaction_forward(const std::vector<at::Tensor> &input) {
   }
 }
 
+std::vector<at::Tensor>
+interaction_backward_auto_mix_dispatch(const at::Tensor &grad_out,
+                                       const std::vector<at::Tensor> &input) {
+  for (auto &in : input) {
+    cpu::dbl::comm::reorder_to_bf16_for_mix_prec(in);
+    IPEX_CHECK(cpu::ShadeDataContext::isTensorMixPrecisionBF16(in));
+  }
+  cpu::dbl::comm::reorder_to_bf16_for_mix_prec(grad_out);
+  IPEX_CHECK(cpu::ShadeDataContext::isTensorMixPrecisionBF16(grad_out));
+  const std::vector<at::Tensor>& consistent_input = cpu::bf16::gen_consistent_tensorlist(input);
+  const at::Tensor& consistent_grad_out = cpu::bf16::gen_consistent_tensor(grad_out);
+  auto &&_ipex_result = _interaction_backward<at::BFloat16>(consistent_grad_out, consistent_input);
+  return cpu::bf16::gen_mix_prec_tensorlist(_ipex_result);
+}
+
 std::vector<at::Tensor>
 AtenIpexTypeExt::interaction_backward(const at::Tensor &grad_out,
                                       const std::vector<at::Tensor> &input) {
+  bool auto_mix_bf16 = check_auto_mix_bf16_fp32();
+  if (auto_mix_bf16) {
+    return interaction_backward_auto_mix_dispatch(grad_out, input);
+  }
+  // preserve support for the original PyTorch bfloat16 path
   if (grad_out.scalar_type() == at::kFloat) {
     cpu::dbl::comm::reorder_to_public(grad_out);
     return _interaction_backward<float>(grad_out, input);
```
torch_ipex/csrc/cpu/ShadeDataContext.h

Lines changed: 20 additions & 0 deletions
```diff
@@ -294,6 +294,26 @@ struct ShadeDataContext {
     ShadeDataContext *shade_data_context = (ShadeDataContext*)storage_context;
     shade_data_context->packed = value;
   }
+
+  static inline bool isTensorMixPrecisionBF16(const at::Tensor &tensor) {
+    auto dil_tensor_type = getDilStorage(tensor).get_data_type();
+    auto aten_tensor_type = tensor.scalar_type();
+    if (aten_tensor_type != at::kFloat) {
+      return false;
+    }
+    auto res = (dil_tensor_type == dil::data_type::bf16);
+
+    // Check mix_precision
+    void *raw_context = tensor.storage().data_ptr().get_context();
+    ShadeDataContext *shade_data_context = (ShadeDataContext*)raw_context;
+    if (shade_data_context->mix_prec_type == MIX_PREC_TYPE::MIX_BF16_FP32) {
+      TORCH_INTERNAL_ASSERT_DEBUG_ONLY(res);
+    } else {
+      TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!res);
+    }
+
+    return res;
+  }
 };
 
 } // namespace cpu
```