"attn_bias is not correctly aligned" on A100 for MPT-30B #795

Closed
@dlopes78

Description

Hello,

I saw a similar issue to this for MPT-30B-chat on H100, but I see the same error on an A100 80GB, using vllm 0.1.3. Is there currently any workaround for this? It happens on seemingly random prompts, so it is not straightforward to understand where it's coming from:

│ 96 │ │ │ │ prompt_template = PromptTemplate(input_variables=["text"] │

│ 97 │ │ │ │ answer_chain = LLMChain(llm=self.llm , prompt=prompt_temp │
│ 98 │ │ │ │ │
│ ❱ 99 │ │ │ │ response = answer_chain.run(query) │
│ 100 │ │ │ │
│ 101 │ │ │ else: │
│ 102 │
│ │
│ /root/miniconda3/envs/py311/lib/python3.11/site-packages/langchain/chains/base. │
│ py:440 in run │
│ │
│ 437 │ │ if args and not kwargs: │
│ 438 │ │ │ if len(args) != 1: │
│ 439 │ │ │ │ raise ValueError("run supports only one positional argu │
│ ❱ 440 │ │ │ return self(args[0], callbacks=callbacks, tags=tags, metadata │
│ 441 │ │ │ │ _output_key │
│ 442 │ │ │ ] │
│ 443 │
│ │
│ /root/miniconda3/envs/py311/lib/python3.11/site-packages/langchain/chains/base. │
│ py:243 in __call__ │
│ │
│ 240 │ │ │ ) │
│ 241 │ │ except (KeyboardInterrupt, Exception) as e: │
│ 242 │ │ │ run_manager.on_chain_error(e) │
│ ❱ 243 │ │ │ raise e │
│ 244 │ │ run_manager.on_chain_end(outputs) │
│ 245 │ │ final_outputs: Dict[str, Any] = self.prep_outputs( │
│ 246 │ │ │ inputs, outputs, return_only_outputs │
│ │
│ /root/miniconda3/envs/py311/lib/python3.11/site-packages/langchain/chains/base. │
│ py:237 in __call__ │
│ │
│ 234 │ │ ) │
│ 235 │ │ try: │
│ 236 │ │ │ outputs = ( │
│ ❱ 237 │ │ │ │ self._call(inputs, run_manager=run_manager) │
│ 238 │ │ │ │ if new_arg_supported │
│ 239 │ │ │ │ else self._call(inputs) │
│ 240 │ │ │ ) │
│ │
│ /root/miniconda3/envs/py311/lib/python3.11/site-packages/langchain/chains/llm.p │
│ y:92 in _call │
│ │
│ 89 │ │ inputs: Dict[str, Any], │
│ 90 │ │ run_manager: Optional[CallbackManagerForChainRun] = None, │
│ 91 │ ) -> Dict[str, str]: │
│ ❱ 92 │ │ response = self.generate([inputs], run_manager=run_manager) │
│ 93 │ │ return self.create_outputs(response)[0] │
│ 94 │ │
│ 95 │ def generate( │
│ │
│ /root/miniconda3/envs/py311/lib/python3.11/site-packages/langchain/chains/llm.p │
│ y:102 in generate │
│ │
│ 99 │ ) -> LLMResult: │
│ 100 │ │ """Generate LLM result from inputs.""" │
│ 101 │ │ prompts, stop = self.prep_prompts(input_list, run_manager=run_man │
│ ❱ 102 │ │ return self.llm.generate_prompt( │
│ 103 │ │ │ prompts, │
│ 104 │ │ │ stop, │
│ 105 │ │ │ callbacks=run_manager.get_child() if run_manager else None, │
│ │
│ /root/miniconda3/envs/py311/lib/python3.11/site-packages/langchain/llms/base.py │
│ :186 in generate_prompt │
│ │
│ 183 │ │ **kwargs: Any, │
│ 184 │ ) -> LLMResult: │
│ 185 │ │ prompt_strings = [p.to_string() for p in prompts] │
│ ❱ 186 │ │ return self.generate(prompt_strings, stop=stop, callbacks=callbac │
│ 187 │ │
│ 188 │ async def agenerate_prompt( │
│ 189 │ │ self, │
│ │
│ /root/miniconda3/envs/py311/lib/python3.11/site-packages/langchain/llms/base.py │
│ :279 in generate │
│ │
│ 276 │ │ │ run_managers = callback_manager.on_llm_start( │
│ 277 │ │ │ │ dumpd(self), prompts, invocation_params=params, options=o │
│ 278 │ │ │ ) │
│ ❱ 279 │ │ │ output = self._generate_helper( │
│ 280 │ │ │ │ prompts, stop, run_managers, bool(new_arg_supported), **k │
│ 281 │ │ │ ) │
│ 282 │ │ │ return output │
│ │
│ /root/miniconda3/envs/py311/lib/python3.11/site-packages/langchain/llms/base.py │
│ :223 in _generate_helper │
│ │
│ 220 │ │ except (KeyboardInterrupt, Exception) as e: │
│ 221 │ │ │ for run_manager in run_managers: │
│ 222 │ │ │ │ run_manager.on_llm_error(e) │
│ ❱ 223 │ │ │ raise e │
│ 224 │ │ flattened_outputs = output.flatten() │
│ 225 │ │ for manager, flattened_output in zip(run_managers, flattened_outp │
│ 226 │ │ │ manager.on_llm_end(flattened_output) │
│ │
│ /root/miniconda3/envs/py311/lib/python3.11/site-packages/langchain/llms/base.py │
│ :210 in _generate_helper │
│ │
│ 207 │ ) -> LLMResult: │
│ 208 │ │ try: │
│ 209 │ │ │ output = ( │
│ ❱ 210 │ │ │ │ self._generate( │
│ 211 │ │ │ │ │ prompts, │
│ 212 │ │ │ │ │ stop=stop, │
│ 213 │ │ │ │ │ # TODO: support multiple run managers │
│ │
│ /root/miniconda3/envs/py311/lib/python3.11/site-packages/langchain/llms/base.py │
│ :604 in _generate │
│ │
│ 601 │ │ │ text = ( │
│ 602 │ │ │ │ self._call(prompt, stop=stop, run_manager=run_manager, ** │
│ 603 │ │ │ │ if new_arg_supported │
│ ❱ 604 │ │ │ │ else self._call(prompt, stop=stop, **kwargs) │
│ 605 │ │ │ ) │
│ 606 │ │ │ generations.append([Generation(text=text)]) │
│ 607 │ │ return LLMResult(generations=generations) │
│ │
│ /root/xxx.py:64 in _call │
│ │
│ 61 │ │ │ │ │ │ │ max_tokens=300, │
│ 62 │ │ │ │ │ │ │ ) │
│ 63 │ │ │
│ ❱ 64 │ │ output = model.generate(prompt, sampling_params) │
│ 65 │ │ │
│ 66 │ │ return output[0].outputs[0].text │
│ 67 │
│ │
│ /root/vllm/vllm/entrypoints/llm.py:130 in generate │
│ │
│ 127 │ │ │ else: │
│ 128 │ │ │ │ token_ids = prompt_token_ids[i] │
│ 129 │ │ │ self._add_request(prompt, sampling_params, token_ids) │
│ ❱ 130 │ │ return self._run_engine(use_tqdm) │
│ 131 │ │
│ 132 │ def _add_request( │
│ 133 │ │ self, │
│ │
│ /root/vllm/vllm/entrypoints/llm.py:150 in _run_engine │
│ │
│ 147 │ │ # Run the engine. │
│ 148 │ │ outputs: List[RequestOutput] = [] │
│ 149 │ │ while self.llm_engine.has_unfinished_requests(): │
│ ❱ 150 │ │ │ step_outputs = self.llm_engine.step() │
│ 151 │ │ │ for output in step_outputs: │
│ 152 │ │ │ │ if output.finished: │
│ 153 │ │ │ │ │ outputs.append(output) │
│ │
│ /root/vllm/vllm/engine/llm_engine.py:313 in step │
│ │
│ 310 │ │ │ ] │
│ 311 │ │ │
│ 312 │ │ # Execute the model. │
│ ❱ 313 │ │ output = self._run_workers( │
│ 314 │ │ │ "execute_model", │
│ 315 │ │ │ seq_group_metadata_list=seq_group_metadata_list, │
│ 316 │ │ │ blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in, │
│ │
│ /root/vllm/vllm/engine/llm_engine.py:470 in _run_workers │
│ │
│ 467 │ │ │ else: │
│ 468 │ │ │ │ executor = getattr(worker, method) │
│ 469 │ │ │ │
│ ❱ 470 │ │ │ output = executor(*args, **kwargs) │
│ 471 │ │ │ all_outputs.append(output) │
│ 472 │ │ │
│ 473 │ │ if self.parallel_config.worker_use_ray: │
│ │
│ /root/miniconda3/envs/py311/lib/python3.11/site-packages/torch/utils/_contextli │
│ b.py:115 in decorate_context │
│ │
│ 112 │ @functools.wraps(func) │
│ 113 │ def decorate_context(*args, **kwargs): │
│ 114 │ │ with ctx_factory(): │
│ ❱ 115 │ │ │ return func(*args, **kwargs) │
│ 116 │ │
│ 117 │ return decorate_context │
│ 118 │
│ │
│ /root/vllm/vllm/worker/worker.py:293 in execute_model │
│ │
│ 290 │ │ │ seq_group_metadata_list) │
│ 291 │ │ │
│ 292 │ │ # Execute the model. │
│ ❱ 293 │ │ output = self.model( │
│ 294 │ │ │ input_ids=input_tokens, │
│ 295 │ │ │ positions=input_positions, │
│ 296 │ │ │ kv_caches=self.gpu_cache, │
│ │
│ /root/miniconda3/envs/py311/lib/python3.11/site-packages/torch/nn/modules/modul │
│ e.py:1501 in _call_impl │
│ │
│ 1498 │ │ if not (self._backward_hooks or self._backward_pre_hooks or self │
│ 1499 │ │ │ │ or _global_backward_pre_hooks or _global_backward_hooks │
│ 1500 │ │ │ │ or _global_forward_hooks or _global_forward_pre_hooks): │
│ ❱ 1501 │ │ │ return forward_call(*args, **kwargs) │
│ 1502 │ │ # Do not call functions when jit is used │
│ 1503 │ │ full_backward_hooks, non_full_backward_hooks = [], [] │
│ 1504 │ │ backward_pre_hooks = [] │
│ │
│ /root/vllm/vllm/model_executor/models/mpt.py:234 in forward │
│ │
│ 231 │ │ input_metadata: InputMetadata, │
│ 232 │ │ cache_events: Optional[List[torch.cuda.Event]], │
│ 233 │ ) -> Dict[int, SequenceOutputs]: │
│ ❱ 234 │ │ hidden_states = self.transformer(input_ids, positions, kv_caches, │
│ 235 │ │ │ │ │ │ │ │ │ │ input_metadata, cache_events) │
│ 236 │ │ next_tokens = self.sampler(self.lm_head_weight, hidden_states, │
│ 237 │ │ │ │ │ │ │ │ input_metadata) │
│ │
│ /root/miniconda3/envs/py311/lib/python3.11/site-packages/torch/nn/modules/modul │
│ e.py:1501 in _call_impl │
│ │
│ 1498 │ │ if not (self._backward_hooks or self._backward_pre_hooks or self │
│ 1499 │ │ │ │ or _global_backward_pre_hooks or _global_backward_hooks │
│ 1500 │ │ │ │ or _global_forward_hooks or _global_forward_pre_hooks): │
│ ❱ 1501 │ │ │ return forward_call(*args, **kwargs) │
│ 1502 │ │ # Do not call functions when jit is used │
│ 1503 │ │ full_backward_hooks, non_full_backward_hooks = [], [] │
│ 1504 │ │ backward_pre_hooks = [] │
│ │
│ /root/vllm/vllm/model_executor/models/mpt.py:202 in forward │
│ │
│ 199 │ │ │ else: │
│ 200 │ │ │ │ cache_event = cache_events[i] │
│ 201 │ │ │ block = self.blocks[i] │
│ ❱ 202 │ │ │ hidden_states = block( │
│ 203 │ │ │ │ position_ids, │
│ 204 │ │ │ │ hidden_states, │
│ 205 │ │ │ │ kv_caches[i], │
│ │
│ /root/miniconda3/envs/py311/lib/python3.11/site-packages/torch/nn/modules/modul │
│ e.py:1501 in _call_impl │
│ │
│ 1498 │ │ if not (self._backward_hooks or self._backward_pre_hooks or self │
│ 1499 │ │ │ │ or _global_backward_pre_hooks or _global_backward_hooks │
│ 1500 │ │ │ │ or _global_forward_hooks or _global_forward_pre_hooks): │
│ ❱ 1501 │ │ │ return forward_call(*args, **kwargs) │
│ 1502 │ │ # Do not call functions when jit is used │
│ 1503 │ │ full_backward_hooks, non_full_backward_hooks = [], [] │
│ 1504 │ │ backward_pre_hooks = [] │
│ │
│ /root/vllm/vllm/model_executor/models/mpt.py:153 in forward │
│ │
│ 150 │ │ cache_event: Optional[torch.cuda.Event], │
│ 151 │ ) -> torch.Tensor: │
│ 152 │ │ x = self.norm_1(hidden_states) │
│ ❱ 153 │ │ x = self.attn( │
│ 154 │ │ │ position_ids=position_ids, │
│ 155 │ │ │ hidden_states=x, │
│ 156 │ │ │ kv_cache=kv_cache, │
│ │
│ /root/miniconda3/envs/py311/lib/python3.11/site-packages/torch/nn/modules/modul │
│ e.py:1501 in _call_impl │
│ │
│ 1498 │ │ if not (self._backward_hooks or self._backward_pre_hooks or self │
│ 1499 │ │ │ │ or _global_backward_pre_hooks or _global_backward_hooks │
│ 1500 │ │ │ │ or _global_forward_hooks or _global_forward_pre_hooks): │
│ ❱ 1501 │ │ │ return forward_call(*args, **kwargs) │
│ 1502 │ │ # Do not call functions when jit is used │
│ 1503 │ │ full_backward_hooks, non_full_backward_hooks = [], [] │
│ 1504 │ │ backward_pre_hooks = [] │
│ │
│ /root/vllm/vllm/model_executor/models/mpt.py:102 in forward │
│ │
│ 99 │ │ │ q = self.q_ln(q) │
│ 100 │ │ │ k = self.k_ln(k) │
│ 101 │ │ k_cache, v_cache = kv_cache │
│ ❱ 102 │ │ attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata │
│ 103 │ │ │ │ │ │ │ │ cache_event) │
│ 104 │ │ output, _ = self.out_proj(attn_output) │
│ 105 │ │ return output │
│ │
│ /root/miniconda3/envs/py311/lib/python3.11/site-packages/torch/nn/modules/modul │
│ e.py:1501 in _call_impl │
│ │
│ 1498 │ │ if not (self._backward_hooks or self._backward_pre_hooks or self │
│ 1499 │ │ │ │ or _global_backward_pre_hooks or _global_backward_hooks │
│ 1500 │ │ │ │ or _global_forward_hooks or _global_forward_pre_hooks): │
│ ❱ 1501 │ │ │ return forward_call(*args, **kwargs) │
│ 1502 │ │ # Do not call functions when jit is used │
│ 1503 │ │ full_backward_hooks, non_full_backward_hooks = [], [] │
│ 1504 │ │ backward_pre_hooks = [] │
│ │
│ /root/vllm/vllm/model_executor/layers/attention.py:202 in forward │
│ │
│ 199 │ │ │ # Prompt run. │
│ 200 │ │ │ assert input_metadata.num_generation_tokens == 0 │
│ 201 │ │ │ self.set_attn_bias(input_metadata) │
│ ❱ 202 │ │ │ self.multi_query_kv_attention( │
│ 203 │ │ │ │ output[:num_prompt_tokens], │
│ 204 │ │ │ │ query[:num_prompt_tokens], │
│ 205 │ │ │ │ key[:num_prompt_tokens], │
│ │
│ /root/vllm/vllm/model_executor/layers/attention.py:399 in │
│ multi_query_kv_attention │
│ │
│ 396 │ │ start = 0 │
│ 397 │ │ for i, prompt_len in enumerate(input_metadata.prompt_lens): │
│ 398 │ │ │ end = start + prompt_len │
│ ❱ 399 │ │ │ out = xops.memory_efficient_attention_forward( │
│ 400 │ │ │ │ query[None, start:end], │
│ 401 │ │ │ │ key[None, start:end], │
│ 402 │ │ │ │ value[None, start:end], │
│ │
│ /root/miniconda3/envs/py311/lib/python3.11/site-packages/xformers/ops/fmha/__in │
│ it__.py:213 in memory_efficient_attention_forward │
│ │
│ 210 │ """ │
│ 211 │ Calculates the forward pass of :attr:xformers.ops.memory_efficient_a │
│ 212 │ """ │
│ ❱ 213 │ return _memory_efficient_attention_forward( │
│ 214 │ │ Inputs( │
│ 215 │ │ │ query=query, key=key, value=value, p=p, attn_bias=attn_bias, │
│ 216 │ │ ), │
│ │
│ /root/miniconda3/envs/py311/lib/python3.11/site-packages/xformers/ops/fmha/__in │
│ it__.py:310 in _memory_efficient_attention_forward │
│ │
│ 307 │ else: │
│ 308 │ │ _ensure_op_supports_or_raise(ValueError, "memory_efficient_attent │
│ 309 │ │
│ ❱ 310 │ out, *_ = op.apply(inp, needs_gradient=False) │
│ 311 │ return out.reshape(output_shape) │
│ 312 │
│ 313 │
│ │
│ /root/miniconda3/envs/py311/lib/python3.11/site-packages/xformers/ops/fmha/cutl │
│ ass.py:175 in apply │
│ │
│ 172 │ │ if type(inp.attn_bias) not in FwOp.SUPPORTED_ATTN_BIAS_TYPES: │
│ 173 │ │ │ raise NotImplementedError("Unsupported attn_bias type") │
│ 174 │ │ seqstart_k, seqstart_q, max_seqlen_q, _ = _get_seqlen_info(inp) │
│ ❱ 175 │ │ out, lse, rng_seed, rng_offset = cls.OPERATOR( │
│ 176 │ │ │ query=inp.query, │
│ 177 │ │ │ key=inp.key, │
│ 178 │ │ │ value=inp.value, │
│ │
│ /root/miniconda3/envs/py311/lib/python3.11/site-packages/torch/_ops.py:502 in │
│ __call__ │
│ │
│ 499 │ │ # is still callable from JIT │
│ 500 │ │ # We save the function ptr as the `op` attribute on │
│ 501 │ │ # OpOverloadPacket to access it here. │
│ ❱ 502 │ │ return self._op(*args, **kwargs or {}) │
│ 503 │ │
│ 504 │ # TODO: use this to make a __dir__ │
│ 505 │ def overloads(self): │
╰─────────────────────────────────────────────────────────────────────────────────╯
RuntimeError: attn_bias is not correctly aligned

currently generating with:

    model = vllm.LLM(model="MPT-30B", trust_remote_code=True)
    sampling_params = vllm.SamplingParams(n=1,
                                          temperature=0.2,
                                          top_p=0.9,
                                          top_k=-1,
                                          best_of=1,
                                          use_beam_search=False,
                                          max_tokens=300,
                                          )

    output = model.generate(prompt, sampling_params)
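
If it helps narrow this down, the same check seems reproducible in isolation with xformers, outside of vLLM. This is only a sketch: the shapes below are made up, and the idea that the attn_bias needs its key dimension padded to a multiple of 8 for fp16 on the CUTLASS path is my assumption, not something confirmed here.

    # Hypothetical standalone repro (illustration only, not taken from the vLLM code).
    import torch
    import xformers.ops as xops

    B, H, M, K = 1, 8, 13, 64  # M=13: a prompt length that is not a multiple of 8
    q = torch.randn(B, M, H, K, device="cuda", dtype=torch.float16)
    k = torch.randn(B, M, H, K, device="cuda", dtype=torch.float16)
    v = torch.randn(B, M, H, K, device="cuda", dtype=torch.float16)

    # Plain [B, H, M, M] bias: its stride along the query dimension is M (=13),
    # which is not a multiple of 8.
    bias = torch.zeros(B, H, M, M, device="cuda", dtype=torch.float16)

    # If the alignment assumption holds, this raises:
    # RuntimeError: attn_bias is not correctly aligned
    out = xops.memory_efficient_attention_forward(q, k, v, attn_bias=bias)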

Other libraries version:
xformers==0.0.20
langchain==0.0.232
torch==2.0.1
pytorch-triton==2.1.0+440fd1bf20
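
For reference, the padding pattern that usually satisfies this kind of alignment requirement looks roughly like the sketch below: allocate the bias with the key dimension rounded up to a multiple of 8 and slice it back, so the underlying storage keeps an aligned stride. Again a sketch under the same assumption, not the actual vLLM fix; make_aligned_bias is a hypothetical helper.

    import torch

    def make_aligned_bias(B: int, H: int, M: int,
                          dtype=torch.float16, device="cuda") -> torch.Tensor:
        padded = (M + 7) // 8 * 8      # round the key dim up to a multiple of 8
        bias = torch.zeros(B, H, M, padded, dtype=dtype, device=device)
        return bias[:, :, :, :M]       # the sliced view keeps stride(-2) == padded

    bias = make_aligned_bias(1, 8, 13)
    assert bias.stride(-2) % 8 == 0    # aligned even though M == 13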
