fix supported model list of ascend graph mode (#2669)
jinminxi104 authored Oct 28, 2024
1 parent a41a2a2 commit f5189ce
Showing 3 changed files with 17 additions and 17 deletions.
2 changes: 1 addition & 1 deletion docs/en/get_started/ascend/get_started.md
@@ -49,7 +49,7 @@ For more information about running the Docker client on Ascend devices, please r
## Offline batch inference

> \[!TIP\]
- > Graph mode has been supported on Atlas 800T A2. Currently, InternLM2-7B/LLaMa2-7B/Qwen2-7B are tested on graph mode.
+ > Graph mode has been supported on Atlas 800T A2. Currently, LLaMa3-8B/LLaMa2-7B/Qwen2-7B are tested on graph mode.
> Users can set `eager_mode=False` to enable graph mode, or, set `eager_mode=True` to disable graph mode.
> (Please source `/usr/local/Ascend/nnal/atb/set_env.sh` before enabling graph mode)
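For context, the tip above maps onto the following usage sketch (the model id is illustrative; assumes lmdeploy with Ascend/dlinfer support installed on an Atlas 800T A2, which this environment cannot verify):

```python
# Hedged sketch: enable graph mode via eager_mode=False with tp=1.
# Remember to `source /usr/local/Ascend/nnal/atb/set_env.sh` first.
from lmdeploy import pipeline, PytorchEngineConfig

if __name__ == '__main__':
    pipe = pipeline(
        'Qwen/Qwen2-7B-Instruct',  # illustrative model id
        backend_config=PytorchEngineConfig(tp=1,
                                           device_type='ascend',
                                           eager_mode=False))
    print(pipe(['Please introduce Ascend graph mode.']))
```

Setting `eager_mode=True` instead falls back to eager execution, which the tip recommends if graph mode misbehaves with a given model.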
2 changes: 1 addition & 1 deletion docs/zh_cn/get_started/ascend/get_started.md
@@ -49,7 +49,7 @@ docker run -e ASCEND_VISIBLE_DEVICES=0 --rm --name lmdeploy -t lmdeploy-aarch64-
## Offline batch inference

> \[!TIP\]
- > Graph mode is supported on Atlas 800T A2. Currently, InternLM2-7B/LLaMa2-7B/Qwen2-7B have been tested on a single device. Users can set `eager_mode=False` to enable graph mode, or set `eager_mode=True` to disable graph mode. (Please source `/usr/local/Ascend/nnal/atb/set_env.sh` before enabling graph mode.)
+ > Graph mode is supported on Atlas 800T A2. Currently, LLaMa3-8B/LLaMa2-7B/Qwen2-7B have been tested on a single device. Users can set `eager_mode=False` to enable graph mode, or set `eager_mode=True` to disable graph mode. (Please source `/usr/local/Ascend/nnal/atb/set_env.sh` before enabling graph mode.)
### LLM Inference

30 changes: 15 additions & 15 deletions lmdeploy/pytorch/backends/dlinfer/ascend/graph_runner.py
@@ -22,6 +22,7 @@ def __init__(self, model: torch.nn.Module, model_config: ModelConfig,
super().__init__(model, model_config, cache_config, backend_config,
device)

+        self.supported_model = ['Llama3-8B', 'Llama2-7B', 'Qwen2-7B']
self.enable_graph = self.check_enable_graph()
if self.enable_graph:
import dlinfer.graph
@@ -44,21 +45,20 @@ def check_enable_graph(self):
"Graph mode of device_type 'ascend' only supports tp=1 "
'for now, fallback to eager mode', RuntimeWarning)
return False
-        # model support
-        self.supported_model = {
-            'Llama2': 'LlamaConfig',
-            'InternLM2': 'InternLM2Config',
-            'Qwen2': 'Qwen2Config',
-        }
-        is_model_support = True
-        model_config_name = str(type(self.model_config.hf_config).__name__)
-        if model_config_name not in self.supported_model.values():
-            is_model_support = False
-        if not is_model_support:
-            warnings.warn(
-                "Graph mode of device_type 'ascend' only supports models: "
-                f"{', '.join(self.supported_model.keys())} when tp=1 for now",
-                RuntimeWarning)

+        warnings.warn(
+            '\n\n'
+            '**********************************************************\n'
+            '  The following models were tested in graph mode of\n'
+            "  device_type 'ascend' when tp=1:\n"
+            f"  {', '.join(self.supported_model)}\n"
+            '  Other LLaMa-like models may work in graph mode, please\n'
+            '  check the result yourself!\n'
+            '  If graph mode does not work correctly with your model,\n'
+            '  please use eager mode instead.\n'
+            '**********************************************************\n\n',
+            RuntimeWarning)

return True
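The behavioral change in this hunk is that an unlisted model no longer disables graph mode: the hard allow-list check is replaced by an advisory warning. A self-contained sketch of the new control flow (names mirror the diff; the `AscendGraphRunner` internals are simplified to a free function, which is an assumption of this sketch):

```python
import warnings

# Tested-model list from the diff; advisory only, not an allow-list.
SUPPORTED_MODEL = ['Llama3-8B', 'Llama2-7B', 'Qwen2-7B']


def check_enable_graph(tp: int = 1) -> bool:
    """Return True if graph mode may be enabled (sketch of the new check)."""
    if tp != 1:
        # Unchanged behavior: multi-device (tp > 1) still falls back to eager.
        warnings.warn("Graph mode of device_type 'ascend' only supports tp=1 "
                      'for now, fallback to eager mode', RuntimeWarning)
        return False
    # New behavior: warn with the tested-model list, but allow any model.
    warnings.warn('The following models were tested in graph mode when tp=1: '
                  + ', '.join(SUPPORTED_MODEL), RuntimeWarning)
    return True


with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter('always')
    ok = check_enable_graph(tp=1)
print(ok, len(caught))  # → True 1
```

An untested LLaMa-like model now reaches `return True` and merely triggers the banner warning, whereas the old code returned after `warnings.warn` only for the tp check and silently kept `is_model_support` as the gate.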

def patch_kernels_custom_op(self):
