Open
Description
Describe the bug
The model contains a lot of linear layers, sized like LLaMA-7B; the code is below:
import torch
from torch.utils.data.dataset import TensorDataset
from torch.utils.data.dataloader import DataLoader
import intel_extension_for_pytorch as ipex
import time
from torch import nn
class MLP(nn.Module):
    """Embedding followed by a plain chain of bias-free linear layers.

    The module is an ``nn.Sequential`` of: one embedding, ``num_blocks``
    groups of six linears (four ``hidden -> hidden``, one
    ``hidden -> intermediate``, one ``intermediate -> hidden``), and a final
    ``hidden -> vocab`` projection.  There are no activations or residual
    connections between the layers.

    The defaults reproduce the originally hard-coded sizes, so ``MLP()``
    builds exactly the same layer list as before.

    Args:
        vocab_size: Number of embedding rows and width of the final output.
        hidden_size: Embedding width and in/out width of the square linears.
        intermediate_size: Width of the expanded linear inside each group.
        num_blocks: How many six-linear groups to stack.
    """

    def __init__(
        self,
        vocab_size: int = 32000,
        hidden_size: int = 4096,
        intermediate_size: int = 11008,
        num_blocks: int = 32,
    ) -> None:
        super().__init__()
        # NOTE(review): the pasted source lost its indentation; the loop is
        # assumed to cover the six per-group linears, with the vocabulary
        # projection appended once after it -- confirm against the author's
        # original script.
        modules = [nn.Embedding(vocab_size, hidden_size)]
        for _ in range(num_blocks):
            modules.extend(
                nn.Linear(hidden_size, hidden_size, bias=False)
                for _ in range(4)
            )
            modules.append(
                nn.Linear(hidden_size, intermediate_size, bias=False)
            )
            modules.append(
                nn.Linear(intermediate_size, hidden_size, bias=False)
            )
        modules.append(nn.Linear(hidden_size, vocab_size, bias=False))
        self.layers = nn.Sequential(*modules)

    def forward(self, x):
        """Run ``x`` (integer indices for the embedding) through every layer."""
        return self.layers(x)
# Benchmark driver: build the model, move it to the XPU in fp16, run one
# warm-up pass over the data, then time a second identical pass.
model = MLP()
# 32 samples, each a single index drawn from [1, 32000).
x = torch.randint(1, 32000, (32, 1), dtype=torch.long)
# Labels exist only so TensorDataset yields (input, label) pairs; the loops
# below read sample[0] only.
y = torch.ones((32, ), dtype=torch.long)
bs = 1
model = model.half().to('xpu')
ds = TensorDataset(x, y)
dataloader = DataLoader(ds, batch_size=bs)
with torch.inference_mode():
    # warmup
    for sample in dataloader:
        mini_batch = sample[0].to('xpu')
        model(mini_batch)
    # Device work may still be queued when the Python call returns; drain it
    # so warm-up kernels are not billed to the timed loop.
    torch.xpu.synchronize()
    s = time.perf_counter()
    for sample in dataloader:
        mini_batch = sample[0].to('xpu')
        model(mini_batch)
    # Wait for the timed kernels to finish before reading the clock.
    torch.xpu.synchronize()
    print("time cost: " + str(time.perf_counter() - s))
The time cost on the 6.2 kernel is 1.1821 s, but on the 5.19 kernel it is only 1.0468 s.
Versions
CPU: i9 13900K
GPU: GUNNIR Arc A770
OS: ubuntu 22.04.3
Python: 3.9.18
Dependencies:
accelerate 0.21.0
antlr4-python3-runtime 4.9.3
certifi 2023.7.22
charset-normalizer 3.2.0
einops 0.6.1
filelock 3.12.4
fsspec 2023.9.1
huggingface-hub 0.17.2
idna 3.4
intel-extension-for-pytorch 2.0.110+xpu
Jinja2 3.1.2
MarkupSafe 2.1.3
mkl-include 2023.2.0
mkl-static 2023.2.0
mpmath 1.3.0
networkx 3.1
ninja 1.11.1
numpy 1.26.0
omegaconf 2.3.0
packaging 23.1
pandas 2.1.0
Pillow 10.0.1
pip 23.2.1
protobuf 4.24.3
psutil 5.9.5
py-cpuinfo 9.0.0
python-dateutil 2.8.2
pytz 2023.3.post1
PyYAML 6.0.1
regex 2023.8.8
requests 2.31.0
safetensors 0.3.3
sentencepiece 0.1.99
setuptools 68.0.0
six 1.16.0
sympy 1.12
tabulate 0.9.0
tiktoken 0.5.1
tokenizers 0.13.3
torch 2.0.1a0+cxx11.abi
torchvision 0.15.2a0+cxx11.abi
tqdm 4.66.1
transformers 4.31.0
transformers-stream-generator 0.0.4
typing_extensions 4.8.0
tzdata 2023.3
urllib3 2.0.4
wheel 0.38.4