
Plain pytorch LLaMA implementation (no fairscale, use as many GPUs as you want) #179

Open
galatolofederico opened this issue Mar 11, 2023 · 8 comments
Labels
feedback-blogpost (if the issue or fix has potential for broader announcement and blog post) · model-usage (issues related to how models are used/loaded)

Comments

@galatolofederico

It might be a good idea to also release a LLaMA version without the fairscale layers. It is possible to run the 65B model using just 2 A100-SXM-80GB GPUs, but this code forces you to use 8 GPUs no matter what.

Here is a vanilla PyTorch implementation of LLaMA (and a script to convert the weights): https://github.com/galatolofederico/vanilla-llama
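For context, the approach is roughly: build the model skeleton on the meta device, then let accelerate shard the checkpoint across whatever GPUs are visible. The following is only a sketch; LLaMAModel, model_args and the checkpoint path are placeholders, not the actual vanilla-llama API.

import torch
from accelerate import init_empty_weights, load_checkpoint_and_dispatch

# Create the model skeleton on the meta device: no RAM is allocated yet.
with init_empty_weights():
    model = LLaMAModel(model_args)  # placeholder class and args

# Load the converted (fairscale-free) checkpoint and split the layers across
# all visible GPUs, spilling to CPU if the model doesn't fit in VRAM.
model = load_checkpoint_and_dispatch(
    model,
    "converted-weights/7B",  # placeholder path to the converted weights
    device_map="auto",
    no_split_module_classes=["TransformerBlock"],  # keep each block on a single device
)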

@veelion

veelion commented Mar 13, 2023

Great work!
But I got an error, NotImplementedError: Cannot copy out of meta tensor; no data!, which comes from accelerate:

╭─────────────────────────────── Traceback (most recent call last) ────────────────────────────────╮
│ ~/github/vanilla-llama/example.py:18 in <module>                                  │
│                                                                                                  │
│   15 │   a = input('continue?>')                                                                 │
│   16 │                                                                                           │
│   17 │   start_generation = time.time()                                                          │
│ ❱ 18 │   print(llama.generate(["Chat:\nHuman: Hi i am an human\nAI:"], stop_ids=[13]))           │
│   19 │   print(f"Inference took {time.time() - start_generation:.2f} seconds")                   │
│   20 │   while 1:                                                                                │
│   21 │   │   prompt = input('>')                                                                 │
│                                                                                                  │
│ ~/github/vanilla-llama/inference.py:51 in generate                                │
│                                                                                                  │
│   48 │   │   self.generator = LLaMA(self.model, self.tokenizer)                                  │
│   49 │                                                                                           │
│   50 │   def generate(self, texts, temperature=0.8, top_p=0.95, max_length=256, stop_ids=None    │
│ ❱ 51 │   │   results = self.generator.generate(                                                  │
│   52 │   │   │   texts,                                                                          │
│   53 │   │   │   max_gen_len=max_length,                                                         │
│   54 │   │   │   temperature=temperature,                                                        │
│                                                                                                  │

│ ~/github/vanilla-llama/llama/generation.py:73 in generate                         │
│                                                                                                  │
│    70 │   │   start_pos = min_prompt_size                                                        │
│    71 │   │   prev_pos = 0                                                                       │
│    72 │   │   for cur_pos in range(start_pos, total_len):                                        │
│ ❱  73 │   │   │   logits = self.model.forward(tokens[:, prev_pos:cur_pos], prev_pos)             │
│    74 │   │   │   if temperature > 0:                                                            │
│    75 │   │   │   │   probs = torch.softmax(logits / temperature, dim=-1)                        │
│    76 │   │   │   │   next_token = sample_top_p(probs, top_p)                                    │
│                                                                                                  │
│ ~/miniconda3/lib/python3.8/site-packages/accelerate/hooks.py:165 in new_forward     │
│                                                                                                  │
│   162 │   │   │   with torch.no_grad():                                                          │
│   163 │   │   │   │   output = old_forward(*args, **kwargs)                                      │
│   164 │   │   else:                                                                              │
│ ❱ 165 │   │   │   output = old_forward(*args, **kwargs)                                          │
│   166 │   │   return module._hf_hook.post_forward(module, output)                                │
│   167 │                                                                                          │
│   168 │   module.forward = new_forward                                                           │
│ ~/miniconda3/lib/python3.8/site-packages/torch/autograd/grad_mode.py:27 in                      │
│ decorate_context                                                                                 │
│                                                                                                  │
│    24 │   │   @functools.wraps(func)                                                             │
│    25 │   │   def decorate_context(*args, **kwargs):                                             │
│    26 │   │   │   with self.clone():                                                             │
│ ❱  27 │   │   │   │   return func(*args, **kwargs)                                               │
│    28 │   │   return cast(F, decorate_context)                                                   │
│    29 │                                                                                          │
│    30 │   def _wrap_generator(self, func):                                                       │
│                                                                                                  │
│ ~/github/vanilla-llama/llama/model.py:250 in forward                              │
│                                                                                                  │
│   247 │   │                                                                                      │
│   248 │   │   for layer in self.layers:                                                          │
│   249 │   │   │   h = h.to(layer.parameters().__next__().device)                                 │
│ ❱ 250 │   │   │   h = layer(h, start_pos, freqs_cis, mask)                                       │
│   251 │   │   h = h.to(self.norm.parameters().__next__().device)                                 │
│   252 │   │   h = self.norm(h)                                                                   │
│   253                                                                                            │
│                                                                                                  │
│ ~/miniconda3/lib/python3.8/site-packages/torch/nn/modules/module.py:1130 in         │
│ _call_impl                                                                                      │
│                                                                                                  │
│   1128 │   │   if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks   │
│   1129 │   │   │   │   or _global_forward_hooks or _global_forward_pre_hooks):                   │
│ ❱ 1130 │   │   │   return forward_call(*input, **kwargs)                                         │
│   1131 │   │   # Do not call functions when jit is used                                          │
│   1132 │   │   full_backward_hooks, non_full_backward_hooks = [], []                             │
│   1133 │   │   if self._backward_hooks or _global_backward_hooks:                                │
│                                                                                                  │
│ ~/miniconda3/lib/python3.8/site-packages/accelerate/hooks.py:160 in new_forward     │
│                                                                                                  │
│   157 │                                                                                          │
│   158 │   @functools.wraps(old_forward)                                                          │
│   159 │   def new_forward(*args, **kwargs):                                                      │
│ ❱ 160 │   │   args, kwargs = module._hf_hook.pre_forward(module, *args, **kwargs)                │
│   161 │   │   if module._hf_hook.no_grad:                                                        │
│   162 │   │   │   with torch.no_grad():                                                          │
│   163 │   │   │   │   output = old_forward(*args, **kwargs)                                      │
│                                                                                                  │
│ ~/miniconda3/lib/python3.8/site-packages/accelerate/hooks.py:275 in pre_forward     │
│                                                                                                  │
│   272 │   │   │   ):                                                                             │
│   273 │   │   │   │   set_module_tensor_to_device(module, name, self.execution_device, value=s   │
│   274 │   │                                                                                      │
│ ❱ 275 │   │   return send_to_device(args, self.execution_device), send_to_device(kwargs, self.   │
│   276 │                                                                                          │
│   277 │   def post_forward(self, module, output):                                                │
│   278 │   │   if self.offload:                                                                   │
│                                                                                                  │
│ ~/miniconda3/lib/python3.8/site-packages/accelerate/utils/operations.py:133 in      │
│ send_to_device                                                                                   │
│                                                                                                  │
│   130 │   def _has_to_method(t):                                                                 │
│   131 │   │   return hasattr(t, "to")                                                            │
│   132 │                                                                                          │
│ ❱ 133 │   return recursively_apply(_send_to_device, tensor, device, non_blocking, test_type=_h   │
│   134                                                                                            │
│   135                                                                                            │
│   136 def get_data_structure(data):                                                              │
│                                                                                                  │
│ ~/miniconda3/lib/python3.8/site-packages/accelerate/utils/operations.py:82 in       │
│ recursively_apply                                                                                │
│                                                                                                  │
│    79 │   │   The same data structure as `data` with `func` applied to every object of type `m   │
│    80 │   """                                                                                    │
│    81 │   if isinstance(data, (tuple, list)):                                                    │
│ ❱  82 │   │   return honor_type(                                                                 │
│    83 │   │   │   data,                                                                          │
│    84 │   │   │   (                                                                              │
│    85 │   │   │   │   recursively_apply(                                                         │
│                                                                                                  │
│ ~/miniconda3/lib/python3.8/site-packages/accelerate/utils/operations.py:53 in       │
│ honor_type                                                                                       │
│                                                                                                  │
│    50 │   Cast a generator to the same type as obj (list, tuple, or namedtuple)                  │
│    51 │   """                                                                                    │
│    52 │   try:                                                                                   │
│ ❱  53 │   │   return type(obj)(generator)                                                        │
│    54 │   except TypeError:                                                                      │
│    55 │   │   # Some objects may not be able to instantiate from a generator directly            │
│    56 │   │   return type(obj)(*list(generator))                                                 │
│                                                                                                  │
│ ~/miniconda3/lib/python3.8/site-packages/accelerate/utils/operations.py:85 in       │
│ <genexpr>                                                                                        │
│                                                                                                  │
│    82 │   │   return honor_type(                                                                 │
│    83 │   │   │   data,                                                                          │
│    84 │   │   │   (                                                                              │
│ ❱  85 │   │   │   │   recursively_apply(                                                         │
│    86 │   │   │   │   │   func, o, *args, test_type=test_type, error_on_other_type=error_on_ot   │
│    87 │   │   │   │   )                                                                          │
│    88 │   │   │   │   for o in data                                                              │
│                                                                                                  │
│ ~/miniconda3/lib/python3.8/site-packages/accelerate/utils/operations.py:101 in      │
│ recursively_apply                                                                                │
│                                                                                                  │
│    98 │   │   │   }                                                                              │
│    99 │   │   )                                                                                  │
│   100 │   elif test_type(data):                                                                  │
│ ❱ 101 │   │   return func(data, *args, **kwargs)                                                 │
│   102 │   elif error_on_other_type:                                                              │
│   103 │   │   raise TypeError(                                                                   │
│   104 │   │   │   f"Can't apply {func.__name__} on object of type {type(data)}, only of nested   │
│                                                                                                  │
│ ~/miniconda3/lib/python3.8/site-packages/accelerate/utils/operations.py:126 in      │
│ _send_to_device                                                                                  │
│                                                                                                  │
│   123 │                                                                                          │
│   124 │   def _send_to_device(t, device, non_blocking):                                          │
│   125 │   │   try:                                                                               │
│ ❱ 126 │   │   │   return t.to(device, non_blocking=non_blocking)                                 │
│   127 │   │   except TypeError:  # .to() doesn't accept non_blocking as kwarg                    │
│   128 │   │   │   return t.to(device)                                                            │
│   129                                                                                            │
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
NotImplementedError: Cannot copy out of meta tensor; no data!
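(For reference, this exact error is what PyTorch raises whenever a tensor that still lives on the meta device, i.e. a tensor with shape and dtype but no backing data, is moved to a real device:

import torch

t = torch.empty(2, 2, device="meta")  # metadata only, no storage
t.to("cuda")  # NotImplementedError: Cannot copy out of meta tensor; no data!

So some weights were apparently never materialized from the checkpoint.)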

@yokie121

You can run vanilla-llama on 1, 2, 4, 8 or 100 GPUs.
I want to know how to do it. Thanks!

@veelion

veelion commented Mar 13, 2023

Run the 7B model on 1 GPU (1070, 8 GB):

 python example.py --llama-path models/ --model 7B

@randaller

Yes, it isn't loading the weights in load_checkpoint_and_dispatch.

device_map = infer_auto_device_map(model)
print(device_map)

It prints {'': 'cpu'}, i.e. everything is mapped to the CPU.
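One thing that might help, assuming the skeleton was built with init_empty_weights: give the planner an explicit memory budget and the checkpoint dtype, since otherwise it may plan with float32 sizes and fall back to CPU-only placement. The budgets and the block class name below are placeholders:

import torch
from accelerate import infer_auto_device_map

device_map = infer_auto_device_map(
    model,
    max_memory={0: "7GiB", "cpu": "30GiB"},       # placeholder budgets, adjust to your hardware
    dtype=torch.float16,                          # plan with the checkpoint's dtype, not fp32
    no_split_module_classes=["TransformerBlock"], # placeholder block class name
)
print(device_map)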

@galatolofederico
Author

Hi @veelion, your error is a weird one. Try converting the weights again; it looks like something went wrong in that step. I have only tested the conversion script on a cluster, so I have never experienced memory problems. Try again and let me know if there are any out-of-memory errors. If that is the case, please open an issue on vanilla-llama and I will try to make the conversion script use less RAM.

@yokie121 vanilla-llama uses all the available GPUs by default via accelerate. You just need enough VRAM to fit the model and it will work.
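If you want to restrict it to a subset of your GPUs, the standard CUDA environment variable works (this is generic CUDA behavior, not a vanilla-llama flag), e.g. to use only the first two GPUs:

 CUDA_VISIBLE_DEVICES=0,1 python example.py --llama-path models/ --model 65B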

@randaller

randaller commented Mar 13, 2023

Hi @galatolofederico, I'm experiencing the same error as @veelion. I was able to load the weights (merged by my own script, which differs a bit), and the device_map looked OK, but I'm still getting NotImplementedError: Cannot copy out of meta tensor; no data!

We cannot open an issue in your repo, as issues are disabled there :)

@veelion

veelion commented Mar 13, 2023

Hi @galatolofederico, I have tried converting the weights many times, but still get the same error as before. I followed @randaller's method and printed the device_map; almost everything is on CPU:

{'tok_embeddings': 0, 'layers.0.attention.wq': 0, 'layers.0.attention.wk': 0, 'layers.0.attention.wv': 0,
 'layers.0.attention.wo': 'cpu', 'layers.0.feed_forward': 'cpu', 'layers.0.attention_norm': 'cpu',
 'layers.0.ffn_norm': 'cpu', 'layers.1': 'cpu', 'layers.2': 'cpu', 'layers.3': 'cpu', 'layers.4': 'cpu', 'layers.5': 'cpu',
 'layers.6': 'cpu', 'layers.7': 'cpu', 'layers.8': 'cpu', 'layers.9': 'cpu', 'layers.10': 'cpu', 'layers.11': 'cpu',
 'layers.12': 'cpu', 'layers.13': 'cpu', 'layers.14': 'cpu', 'layers.15': 'cpu', 'layers.16': 'cpu', 'layers.17': 'cpu',
 'layers.18': 'cpu', 'layers.19': 'cpu', 'layers.20': 'cpu', 'layers.21': 'cpu', 'layers.22': 'cpu', 'layers.23': 'cpu',
 'layers.24': 'cpu', 'layers.25': 'cpu', 'layers.26': 'cpu', 'layers.27': 'cpu', 'layers.28': 'cpu', 'layers.29': 'cpu',
 'layers.30': 'cpu', 'layers.31': 'cpu', 'norm': 'cpu', 'output': 'cpu'}
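(A quick check that might narrow this down, assuming model is the module returned by load_checkpoint_and_dispatch: list any parameters that are still on the meta device, i.e. were never actually loaded from the checkpoint:

meta = [n for n, p in model.named_parameters() if p.device.type == "meta"]
print(len(meta), "parameters still on meta, e.g.:", meta[:5])

If that list is non-empty, the dispatch step skipped those weights.)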

@galatolofederico
Author

We cannot open an issue in your repo, as issues are disabled there :)

Sorry, I didn't notice it 🤦. I have enabled issues now!

@WuhanMonkey added the model-usage (issues related to how models are used/loaded) and feedback-blogpost (potential for broader announcement and blog post) labels on Sep 6, 2023