
Mixtral-8x7B-Instruct-v0.1-GPTQ weight loading error #2202

Closed
jbohnslav opened this issue Dec 19, 2023 · 5 comments · Fixed by #2208
Labels
bug Something isn't working

Comments

@jbohnslav

Command:

python3 -m vllm.entrypoints.openai.api_server     --model /models/Mixtral-8x7B-Instruct-v0.1-GPTQ -tp 2 --dtype float16 

Result:

INFO 12-19 00:45:14 api_server.py:727] args: Namespace(host=None, port=8000, allow_credentials=False, allowed_origins=['*'], allowed_methods=['*'], allowed_headers=['*'], served_model_name=None, chat_template=None, response_role='assistant', ssl_keyfile=None, ssl_certfile=None, model='/models/Mixtral-8x7B-Instruct-v0.1-GPTQ', tokenizer=None, revision=None, tokenizer_revision=None, tokenizer_mode='auto', trust_remote_code=False, download_dir=None, load_format='auto', dtype='float16', max_model_len=None, worker_use_ray=False, pipeline_parallel_size=1, tensor_parallel_size=2, max_parallel_loading_workers=None, block_size=16, seed=0, swap_space=4, gpu_memory_utilization=0.9, max_num_batched_tokens=None, max_num_seqs=256, max_paddings=256, disable_log_stats=False, quantization=None, enforce_eager=False, max_context_len_to_capture=8192, engine_use_ray=False, disable_log_requests=False, max_log_len=None)       
WARNING 12-19 00:45:14 config.py:463] Casting torch.bfloat16 to torch.float16.                                                                               
WARNING 12-19 00:45:14 config.py:175] gptq quantization is not fully optimized yet. The speed can be slower than non-quantized models.                                                                                     
WARNING 12-19 00:45:14 config.py:187] gptq does not support CUDA graph yet. Disabling CUDA graph.                     
2023-12-19 00:45:16,379 INFO worker.py:1673 -- Started a local Ray instance.    
INFO 12-19 00:45:17 llm_engine.py:73] Initializing an LLM engine with config: model='/models/Mixtral-8x7B-Instruct-v0.1-GPTQ', tokenizer='/models/Mixtral-8x7B-Instruct-v0.1-GPTQ', tokenizer_mode=auto, revision=None, tokenizer_revision=None, trust_remote_code=False, dtype=torch.float16, max_seq_len=32768, download_dir=None, load_format=auto, tensor_parallel_size=2, quantization=gptq, enforce_eager=True, seed=0)                         
(RayWorkerVllm pid=2792) /usr/local/lib/python3.10/dist-packages/torch/nn/init.py:412: UserWarning: Initializing zero-element tensors is a no-op                                                                           
(RayWorkerVllm pid=2792)   warnings.warn("Initializing zero-element tensors is a no-op")

Traceback (most recent call last):
  File "/usr/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/usr/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/workspace/vllm/entrypoints/openai/api_server.py", line 737, in <module>
    engine = AsyncLLMEngine.from_engine_args(engine_args)
  File "/workspace/vllm/engine/async_llm_engine.py", line 496, in from_engine_args
    engine = cls(parallel_config.worker_use_ray,
  File "/workspace/vllm/engine/async_llm_engine.py", line 269, in __init__
    self.engine = self._init_engine(*args, **kwargs)
  File "/workspace/vllm/engine/async_llm_engine.py", line 314, in _init_engine
    return engine_class(*args, **kwargs)
  File "/workspace/vllm/engine/llm_engine.py", line 108, in __init__
    self._init_workers_ray(placement_group)
  File "/workspace/vllm/engine/llm_engine.py", line 195, in _init_workers_ray
    self._run_workers(
  File "/workspace/vllm/engine/llm_engine.py", line 763, in _run_workers
    self._run_workers_in_batch(workers, method, *args, **kwargs))
  File "/workspace/vllm/engine/llm_engine.py", line 740, in _run_workers_in_batch
    all_outputs = ray.get(all_outputs)
  File "/usr/local/lib/python3.10/dist-packages/ray/_private/auto_init_hook.py", line 24, in auto_init_wrapper
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/ray/_private/client_mode_hook.py", line 103, in wrapper
    return func(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/ray/_private/worker.py", line 2563, in get
    raise value.as_instanceof_cause()
ray.exceptions.RayTaskError(KeyError): ray::RayWorkerVllm.execute_method() (pid=2793, ip=172.17.0.2, actor_id=650e8e31e738a77f21d4661701000000, repr=<vllm.engine.ray_utils.RayWorkerVllm object at 0x7f22cd099a50>)
  File "/workspace/vllm/engine/ray_utils.py", line 31, in execute_method
    return executor(*args, **kwargs)
  File "/workspace/vllm/worker/worker.py", line 79, in load_model
    self.model_runner.load_model()
  File "/workspace/vllm/worker/model_runner.py", line 57, in load_model
    self.model = get_model(self.model_config)
  File "/workspace/vllm/model_executor/model_loader.py", line 72, in get_model
    model.load_weights(model_config.model, model_config.download_dir,
  File "/workspace/vllm/model_executor/models/mixtral.py", line 430, in load_weights
    param = params_dict[name]
KeyError: 'model.layers.0.block_sparse_moe.experts.0.w1.g_idx'

I think there's a mismatch between the weight keys in the HF quantized checkpoints and the parameter names vLLM builds for the default model uploaded by Mistral. Maybe we just need a lookup table? Or do the instruct versions have different params?
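
For anyone who wants to confirm the mismatch locally before a fix lands, here is a minimal diagnostic sketch, assuming the GPTQ checkpoint has already been downloaded and the `safetensors` package is installed (the path is just the one from the command above):

```python
from pathlib import Path

from safetensors import safe_open

# Local GPTQ checkpoint directory (same path as in the command above).
ckpt_dir = Path("/models/Mixtral-8x7B-Instruct-v0.1-GPTQ")

# Collect every tensor name stored in the checkpoint shards.
ckpt_keys = set()
for shard in sorted(ckpt_dir.glob("*.safetensors")):
    with safe_open(str(shard), framework="pt") as f:
        ckpt_keys.update(f.keys())

# GPTQ checkpoints carry extra per-layer tensors (qweight, qzeros,
# scales, g_idx) that the unquantized Mixtral checkpoint does not have.
gptq_keys = sorted(
    k for k in ckpt_keys if k.endswith((".qweight", ".qzeros", ".scales", ".g_idx"))
)
print(f"{len(gptq_keys)} GPTQ-specific tensors, e.g. {gptq_keys[:3]}")
```

The `g_idx` entries on the expert projections are exactly the names that show up in the KeyError above.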

@Sirri69

Sirri69 commented Dec 19, 2023

Yep, I'm facing the same issue.
Sys config:

2x A30 -- On Runpod
16 vCPUs
62GB RAM

Command:

python -m vllm.entrypoints.api_server \
    --model TheBloke/dolphin-2.5-mixtral-8x7b-GPTQ \
    --tensor-parallel-size 2 --port 8000 --dtype float16
WARNING 12-19 18:20:29 config.py:467] Casting torch.bfloat16 to torch.float16.
WARNING 12-19 18:20:29 config.py:179] gptq quantization is not fully optimized yet. The speed can be slower than non-quantized models.
WARNING 12-19 18:20:29 config.py:191] gptq does not support CUDA graph yet. Disabling CUDA graph.
2023-12-19 18:20:30,880 WARNING utils.py:581 -- Detecting docker specified CPUs. In previous versions of Ray, CPU detection in containers was incorrect. Please ensure that Ray has enough CPUs allocated. As a temporary workaround to revert to the prior behavior, set RAY_USE_MULTIPROCESSING_CPU_COUNT=1 as an env var before starting Ray. Set the env var: RAY_DISABLE_DOCKER_CPU_WARNING=1 to mute this warning.
2023-12-19 18:20:30,880 WARNING utils.py:593 -- Ray currently does not support initializing Ray with fractional cpus. Your num_cpus will be truncated from 13.6 to 13.
2023-12-19 18:20:31,953 INFO worker.py:1673 -- Started a local Ray instance.
INFO 12-19 18:20:32 llm_engine.py:73] Initializing an LLM engine with config: model='TheBloke/dolphin-2.5-mixtral-8x7b-GPTQ', tokenizer='TheBloke/dolphin-2.5-mixtral-8x7b-GPTQ', tokenizer_mode=auto, revision=None, tokenizer_revision=None, trust_remote_code=False, dtype=torch.float16, max_seq_len=32768, download_dir=None, load_format=auto, tensor_parallel_size=2, quantization=gptq, enforce_eager=True, seed=0)
tokenizer_config.json: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1.70k/1.70k [00:00<00:00, 9.19MB/s]
tokenizer.model: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 493k/493k [00:00<00:00, 47.5MB/s]
added_tokens.json: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 51.0/51.0 [00:00<00:00, 231kB/s]
special_tokens_map.json: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 443/443 [00:00<00:00, 1.56MB/s]
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
(RayWorkerVllm pid=2228) /usr/local/lib/python3.10/dist-packages/torch/nn/init.py:412: UserWarning: Initializing zero-element tensors is a no-op
(RayWorkerVllm pid=2228)   warnings.warn("Initializing zero-element tensors is a no-op")
model.safetensors: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 23.8G/23.8G [01:48<00:00, 219MB/s]

Traceback (most recent call last):
  File "/usr/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/usr/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/usr/local/lib/python3.10/dist-packages/vllm/entrypoints/api_server.py", line 80, in <module>
    engine = AsyncLLMEngine.from_engine_args(engine_args)
  File "/usr/local/lib/python3.10/dist-packages/vllm/engine/async_llm_engine.py", line 496, in from_engine_args
    engine = cls(parallel_config.worker_use_ray,
  File "/usr/local/lib/python3.10/dist-packages/vllm/engine/async_llm_engine.py", line 269, in _init_
    self.engine = self._init_engine(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/vllm/engine/async_llm_engine.py", line 314, in _init_engine
    return engine_class(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/vllm/engine/llm_engine.py", line 108, in _init_
    self._init_workers_ray(placement_group)
  File "/usr/local/lib/python3.10/dist-packages/vllm/engine/llm_engine.py", line 195, in _init_workers_ray
    self._run_workers(
  File "/usr/local/lib/python3.10/dist-packages/vllm/engine/llm_engine.py", line 755, in _run_workers
    self._run_workers_in_batch(workers, method, *args, **kwargs))
  File "/usr/local/lib/python3.10/dist-packages/vllm/engine/llm_engine.py", line 732, in _run_workers_in_batch
    all_outputs = ray.get(all_outputs)
  File "/usr/local/lib/python3.10/dist-packages/ray/_private/auto_init_hook.py", line 24, in auto_init_wrapper
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/ray/_private/client_mode_hook.py", line 103, in wrapper
    return func(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/ray/_private/worker.py", line 2563, in get
    raise value.as_instanceof_cause()
ray.exceptions.RayTaskError(KeyError): ray::RayWorkerVllm.execute_method() (pid=2228, ip=192.168.16.2, actor_id=68d3f2408f6899611388f36101000000, repr=<vllm.engine.ray_utils.RayWorkerVllm object at 0x7ef674220790>)
  File "/usr/local/lib/python3.10/dist-packages/vllm/engine/ray_utils.py", line 31, in execute_method
    return executor(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/vllm/worker/worker.py", line 79, in load_model
    self.model_runner.load_model()
  File "/usr/local/lib/python3.10/dist-packages/vllm/worker/model_runner.py", line 57, in load_model
    self.model = get_model(self.model_config)
  File "/usr/local/lib/python3.10/dist-packages/vllm/model_executor/model_loader.py", line 72, in get_model
    model.load_weights(model_config.model, model_config.download_dir,
  File "/usr/local/lib/python3.10/dist-packages/vllm/model_executor/models/mixtral.py", line 430, in load_weights
    param = params_dict[name]
KeyError: 'model.layers.0.block_sparse_moe.experts.4.w1.g_idx'

@Sirri69

Sirri69 commented Dec 19, 2023

None of the params in params_dict (mixtral.py:406)

params_dict = dict(self.named_parameters())

matches the weight names loaded from the safetensors file by hf_model_weights_iterator in weight_utils.py:191.

Seems like a simple mistake. Please fix ASAP.
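
To make the failure mode concrete, here is a heavily simplified, hypothetical sketch of a `load_weights`-style loop (not vLLM's actual implementation; `remap_rules` stands in for its real name-mapping tables). Any checkpoint name that no rule covers is looked up verbatim in `params_dict`, which is exactly where the expert `g_idx` keys raise:

```python
import torch

# Hypothetical, heavily simplified sketch of a load_weights-style loop;
# not vLLM's actual code. remap_rules maps checkpoint-name fragments onto
# the parameter names the model actually registers.
def load_weights_sketch(model: torch.nn.Module, weight_iter, remap_rules):
    params_dict = dict(model.named_parameters())
    for ckpt_name, tensor in weight_iter:
        for ckpt_frag, param_frag in remap_rules:
            if ckpt_frag in ckpt_name:
                name = ckpt_name.replace(ckpt_frag, param_frag)
                break
        else:
            # No rule matched: fall back to the checkpoint name as-is.
            name = ckpt_name
        # This lookup is what raises
        # KeyError: 'model.layers.0.block_sparse_moe.experts.4.w1.g_idx'
        # when the quantized expert tensors have no registered counterpart.
        param = params_dict[name]
        param.data.copy_(tensor)
```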

@WoosukKwon
Collaborator

@jbohnslav @Sirri69 Thanks for reporting the bug! I only tested the Mixtral GPTQ model with a single GPU. I reproduced the bug and fixed it in #2208.

@jbohnslav
Author

Thanks so much @WoosukKwon! I'll check it out now.

@jbohnslav
Author

Confirmed fixed!
