Skip to content

Commit

Permalink
fix import ipex problem (#9323)
Browse files Browse the repository at this point in the history
* fix import ipex problem

* fix style
  • Loading branch information
yangw1234 committed Nov 1, 2023
1 parent 9f3d467 commit e1bc18f
Showing 1 changed file with 10 additions and 8 deletions.
18 changes: 10 additions & 8 deletions python/llm/src/bigdl/llm/transformers/models/llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,21 +59,23 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
KV_CACHE_ALLOC_BLOCK_LENGTH = 256


# Cached intel_extension_for_pytorch version string. Populated lazily by
# get_ipex_version() so that merely importing this module never triggers the
# (slow, potentially failing) import of intel_extension_for_pytorch.
_ipex_version = None


def get_ipex_version():
    """Return the installed intel_extension_for_pytorch version string.

    The lookup is performed at most once: the result is cached in the
    module-level ``_ipex_version`` and returned directly on later calls.

    Returns:
        str: ``ipex.__version__`` when intel_extension_for_pytorch is
            installed.
        None: when the package is not available (preserves the original
            eager implementation's contract instead of raising ImportError).
    """
    global _ipex_version
    if _ipex_version is not None:
        return _ipex_version

    # Probe for the package without importing it, so a missing IPEX install
    # yields None rather than an ImportError.
    if importlib.util.find_spec("intel_extension_for_pytorch") is None:
        return None

    # Import only on first successful probe; cache the version for reuse.
    import intel_extension_for_pytorch as ipex
    _ipex_version = ipex.__version__
    return _ipex_version


def llama_rms_norm_forward(self, hidden_states):
if hidden_states.device.type == "xpu" and not (self.training and hidden_states.requires_grad):
if ipex_version == "2.0.110+xpu":
if get_ipex_version() == "2.0.110+xpu":
hidden_states, _ = torch.ops.torch_ipex.rms_norm(hidden_states,
[self.weight.size(0)], self.weight)
else:
Expand Down

0 comments on commit e1bc18f

Please sign in to comment.