convert-hf : remove einops requirement for InternLM2
compilade committed May 5, 2024
1 parent 0c38332 commit 98db434
Showing 4 changed files with 22 additions and 20 deletions.
1 change: 0 additions & 1 deletion .devops/nix/package.nix
@@ -86,7 +86,6 @@ let
   # TODO(Green-Sky): find a better way to opt-into the heavy ml python runtime
   llama-python-extra = python3.withPackages (
     ps: [
-      ps.einops
       ps.numpy
       ps.sentencepiece
       ps.tiktoken
39 changes: 22 additions & 17 deletions convert-hf-to-gguf.py
@@ -1890,16 +1890,18 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
         qkv_pattern = r"model\.layers\.(\d+)\.attention\.wqkv"
 
         if re.match(qkv_pattern, name):
-            from einops import rearrange
-
             bid = re.findall(qkv_pattern, name)[0]
             qkv = data_torch
-            qkv = rearrange(qkv.T, " o (g n i) ->o g n i", g=num_groups, n=q_per_kv + 2, i=head_dim)
+            # qkv = rearrange(qkv.T, " o (g n i) ->o g n i", g=num_groups, n=q_per_kv + 2, i=head_dim)
+            qkv = qkv.T.reshape((-1, num_groups, q_per_kv + 2, head_dim))
             q, k, v = qkv[..., : q_per_kv, :], qkv[..., q_per_kv: q_per_kv + 1, :], qkv[..., q_per_kv + 1: q_per_kv + 2, :]
             # The model weights of q and k require additional reshape.
-            q = self._hf_permute_qk(rearrange(q, " o g n i -> o (g n i)").T, num_heads, num_heads)
-            k = self._hf_permute_qk(rearrange(k, " o g n i -> o (g n i)").T, num_heads, num_kv_heads)
-            v = rearrange(v, " o g n i -> o (g n i)").T
+            # q = self._hf_permute_qk(rearrange(q, " o g n i -> o (g n i)").T, num_heads, num_heads)
+            q = self._hf_permute_qk(q.reshape((q.shape[0], -1)).T, num_heads, num_heads)
+            # k = self._hf_permute_qk(rearrange(k, " o g n i -> o (g n i)").T, num_heads, num_kv_heads)
+            k = self._hf_permute_qk(k.reshape((k.shape[0], -1)).T, num_heads, num_kv_heads)
+            # v = rearrange(v, " o g n i -> o (g n i)").T
+            v = v.reshape((v.shape[0], -1)).T
             return [
                 (self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_Q, bid), q),
                 (self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_K, bid), k),
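
Editor's note: the einops calls above and their reshape replacements are layout-equivalent, since "o (g n i) -> o g n i" only splits the last axis and "o g n i -> o (g n i)" only flattens the last three. A minimal shape check, with made-up dimensions chosen purely for illustration:

    import torch

    # Hypothetical InternLM2-like dimensions, for illustration only.
    num_groups, q_per_kv, head_dim = 8, 4, 64
    rows = 4096
    qkv = torch.randn(num_groups * (q_per_kv + 2) * head_dim, rows)

    # rearrange(qkv.T, "o (g n i) -> o g n i", ...) == split the last axis:
    split = qkv.T.reshape((-1, num_groups, q_per_kv + 2, head_dim))
    assert split.shape == (rows, num_groups, q_per_kv + 2, head_dim)

    # rearrange(q, "o g n i -> o (g n i)") == flatten the last three axes:
    q = split[..., :q_per_kv, :]
    flat = q.reshape((q.shape[0], -1))
    assert flat.shape == (rows, num_groups * q_per_kv * head_dim)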
@@ -2238,13 +2240,13 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
 class LazyTorchTensor:
     _meta: Tensor
     _data: Tensor | None
-    _args: list[Any]
-    _func: Callable[[list[Any]], Tensor] | None = None
+    _args: tuple
+    _func: Callable[[tuple], Tensor] | None
 
-    def __init__(self, *, meta: Tensor, data: Tensor | None = None, args: list[Any] | None = None, func: Callable[[list[Any]], Tensor] | None = None):
+    def __init__(self, *, meta: Tensor, data: Tensor | None = None, args: tuple = (), func: Callable[[tuple], Tensor] | None = None):
         self._meta = meta
         self._data = data
-        self._args = args if args is not None else []
+        self._args = args
         self._func = func
 
     @staticmethod
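
Editor's note on the `args: tuple = ()` default above: the old signature needed the `None` sentinel because a mutable list default is shared across all calls; an immutable tuple default has no such pitfall. A quick illustration, unrelated to the file itself:

    def bad(args=[]):       # one list object shared by every call
        args.append(1)
        return args

    def good(args=()):      # immutable, so sharing the default is harmless
        return args + (1,)

    print(bad(), bad())     # [1, 1] [1, 1] -- state leaked between calls
    print(good(), good())   # (1,) (1,)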
@@ -2266,19 +2268,22 @@ def _wrap_fn(self, fn: Callable, use_self: bool = False) -> Callable[[Any], Lazy
         def wrapped_fn(*args, **kwargs):
             if kwargs is None:
                 kwargs = {}
-            args_list = ([self] if use_self else []) + list(args)
+            args = ((self,) if use_self else ()) + args
 
-            meta_args = LazyTorchTensor._recurse_apply(args_list, lambda t: t._meta)
+            meta_args = LazyTorchTensor._recurse_apply(args, lambda t: t._meta)
 
-            return LazyTorchTensor(meta=fn(*meta_args, **kwargs), args=args_list, func=lambda a: fn(*a, **kwargs))
+            return LazyTorchTensor(meta=fn(*meta_args, **kwargs), args=args, func=lambda a: fn(*a, **kwargs))
         return wrapped_fn
 
     def __getattr__(self, __name: str) -> Any:
         meta_attr = getattr(self._meta, __name)
-        if not callable(meta_attr):
-            return meta_attr
-        else:
+        if callable(meta_attr):
             return self._wrap_fn(getattr(torch.Tensor, __name), use_self=True)
+        elif isinstance(meta_attr, torch.Tensor):
+            # for things like self.T
+            return self._wrap_fn(lambda s: getattr(s, __name))(self)
+        else:
+            return meta_attr
 
     _dtype_map: dict[torch.dtype, type] = {
         torch.float16: np.float16,
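
Editor's note: the `__getattr__` change is what lets the reshape-based code above run lazily. `Tensor.T` is a property, not a method, so `getattr` on the meta tensor already yields a tensor and must be re-wrapped as a deferred computation rather than returned directly. A stripped-down sketch of the pattern, using a hypothetical `MiniLazy` rather than the real class:

    import torch

    class MiniLazy:
        def __init__(self, meta, args=(), func=None):
            self._meta = meta    # shape/dtype-only tensor on the "meta" device
            self._args = args    # lazy inputs the computation depends on
            self._func = func    # how to produce the real data later

        def __getattr__(self, name):
            meta_attr = getattr(self._meta, name)
            if isinstance(meta_attr, torch.Tensor):
                # properties like .T: defer the attribute access itself
                return MiniLazy(meta_attr, args=(self,),
                                func=lambda a: getattr(a[0].to_eager(), name))
            return meta_attr     # plain attributes (e.g. .ndim) pass through

        def to_eager(self):
            return self._func(self._args) if self._func is not None else self._meta

    t = MiniLazy(torch.empty(2, 3, device="meta"))
    print(t.T._meta.shape)   # torch.Size([3, 2]), computed without real data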
@@ -2295,7 +2300,7 @@ def to_eager(t: Tensor | LazyTorchTensor) -> Tensor: ...

     @overload
     @staticmethod
-    def to_eager(t: list[Tensor | LazyTorchTensor]) -> list[Tensor]: ...
+    def to_eager(t: tuple) -> tuple: ...
 
     @staticmethod
     def to_eager(t: Any) -> Any:
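Editor's note: the `to_eager` overload change just keeps the public type hints in step with the tuple-based `_args`. For anyone unfamiliar with the idiom, `@overload` stubs exist only for type checkers and share one runtime implementation; a reduced, module-level sketch with simplified signatures:

    from typing import Any, overload

    import torch

    @overload
    def to_eager(t: torch.Tensor) -> torch.Tensor: ...
    @overload
    def to_eager(t: tuple) -> tuple: ...

    def to_eager(t: Any) -> Any:
        # Single runtime implementation; the stubs above are typing-only.
        if isinstance(t, tuple):
            return tuple(to_eager(x) for x in t)
        return t   # a real LazyTorchTensor would materialize its data here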
1 change: 0 additions & 1 deletion requirements/requirements-convert-hf-to-gguf-update.txt
@@ -1,3 +1,2 @@
 -r ./requirements-convert.txt
 torch~=2.1.1
-einops~=0.7.0
1 change: 0 additions & 1 deletion requirements/requirements-convert-hf-to-gguf.txt
@@ -1,3 +1,2 @@
 -r ./requirements-convert.txt
 torch~=2.1.1
-einops~=0.7.0
