[ascend] add ascend graph mode (#2647)
* [pytorch] ascend enable atbgraph

* add paged prefill attention

* refine ascend-update-step-ctx (#26)

refine ascend-update-step-ctx

---------

Co-authored-by: CyCle1024 <chenchiyu@pjlab.org.cn>

* fix: rewrite enable graph for ascend

* fix backend error due to folder refactor

* remove unnecessary comment

* fix rotary_embedding (#27)

---------

Co-authored-by: jinminxi104 <jinminxi104@hotmail.com>
Co-authored-by: tangzhiyi11 <tangzhiyi11@users.noreply.github.com>
3 people authored Oct 25, 2024
1 parent c25520a commit 44a0cd3
Showing 9 changed files with 315 additions and 29 deletions.
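The net effect of this commit is that graph mode can be enabled on Ascend devices: when eager mode is off and tp=1, the model is compiled with the dlinfer 'atbgraph' backend. A minimal usage sketch from the caller's side, assuming the PytorchEngineConfig fields referenced in this diff (device_type, eager_mode, tp) and an example model id:

```python
# Hedged sketch: turning on Ascend graph mode from user code.
# `device_type`, `eager_mode` and `tp` are assumed PytorchEngineConfig fields;
# the model id is only an example.
from lmdeploy import pipeline, PytorchEngineConfig

backend_config = PytorchEngineConfig(
    device_type='ascend',  # route to the dlinfer/ascend op backend
    eager_mode=False,      # False -> AscendGraphRunner compiles with 'atbgraph'
    tp=1,                  # graph mode currently falls back to eager when tp > 1
)
pipe = pipeline('internlm/internlm2-chat-7b', backend_config=backend_config)
print(pipe(['Hello from Ascend graph mode!']))
```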
21 changes: 21 additions & 0 deletions lmdeploy/pytorch/backends/dlinfer/activation.py
@@ -0,0 +1,21 @@
# Copyright (c) OpenMMLab. All rights reserved.
from lmdeploy.pytorch.kernels.dlinfer.activation import silu_and_mul

from ..activation import SiluAndMulBuilder, SiluAndMulImpl


class DlinferSiluAndMulImpl(SiluAndMulImpl):
"""silu + multiple fused implementation."""

def forward(self, x):
"""forward."""
return silu_and_mul(x)


class DlinferSiluAndMulBuilder(SiluAndMulBuilder):
"""silu and mul implementation builder."""

@staticmethod
def build(inplace: bool = False):
"""build."""
return DlinferSiluAndMulImpl()
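For readers unfamiliar with the fused op, here is a plain-PyTorch reference sketch of the semantics this builder wraps, assuming the usual gated-MLP convention (the input concatenates the gate and up projections along the last dimension); the actual dlinfer kernel may differ in layout or memory behaviour:

```python
# Reference-only sketch of silu_and_mul semantics; not the dlinfer kernel.
import torch
import torch.nn.functional as F


def silu_and_mul_reference(x: torch.Tensor) -> torch.Tensor:
    """Split the last dim into [gate, up] halves and return silu(gate) * up."""
    gate, up = x.chunk(2, dim=-1)
    return F.silu(gate) * up
```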
116 changes: 116 additions & 0 deletions lmdeploy/pytorch/backends/dlinfer/ascend/graph_runner.py
@@ -0,0 +1,116 @@
# Copyright (c) OpenMMLab. All rights reserved.
import warnings
from importlib import import_module

import torch
import torch.distributed

from lmdeploy.pytorch.config import BackendConfig, CacheConfig, ModelConfig
from lmdeploy.utils import get_logger

from ...graph_runner import GraphRunner

logger = get_logger('lmdeploy')


class AscendGraphRunner(GraphRunner):
"""ascend graph runner."""

def __init__(self, model: torch.nn.Module, model_config: ModelConfig,
cache_config: CacheConfig, backend_config: BackendConfig,
device: torch.device):
super().__init__(model, model_config, cache_config, backend_config,
device)

self.enable_graph = self.check_enable_graph()
if self.enable_graph:
import dlinfer.graph
dlinfer.graph.config.enable_graph_mode = True
self.patch_kernels_custom_op()
self.patch_kvcache_static_shape()
self.model = torch.compile(self.model,
fullgraph=True,
dynamic=True,
backend='atbgraph')

def check_enable_graph(self):
"""check enable graph."""
# eager_mode
if self.backend_config.eager_mode:
return False
# tp
if torch.distributed.is_initialized():
warnings.warn(
"Graph mode of device_type 'ascend' only supports tp=1 "
'for now, fallback to eager mode', RuntimeWarning)
return False
# model support
self.supported_model = {
'Llama2': 'LlamaConfig',
'InternLM2': 'InternLM2Config',
'Qwen2': 'Qwen2Config',
}
is_model_support = True
model_config_name = str(type(self.model_config.hf_config).__name__)
if model_config_name not in self.supported_model.values():
is_model_support = False
if not is_model_support:
warnings.warn(
"Graph mode of device_type 'ascend' only supports models: "
f"{', '.join(self.supported_model.keys())} when tp=1 for now",
RuntimeWarning)
return True

def patch_kernels_custom_op(self):
from dlinfer.graph.custom_op import register_custom_op
dlinfer_kernels_module = import_module(
'lmdeploy.pytorch.kernels.dlinfer')
dlinfer_backends_module = import_module(
'lmdeploy.pytorch.backends.dlinfer')

# prefill_attention
module_str = 'pagedattention'
paged_attn_module = getattr(dlinfer_kernels_module, module_str)
func_str = 'prefill_attention'
prefill_attn_origin = getattr(paged_attn_module, func_str)
prefill_attn_registered = register_custom_op(
f'lmdeploy::{func_str}', ['attn_output'])(prefill_attn_origin)
setattr(paged_attn_module, func_str, prefill_attn_registered)

# apply_rotary_pos_emb
def apply_rotary_emb_abstract_impl(q, k, cos, sin, q_out, k_out):
result = [q, k]
if q_out is not None:
result[0] = q_out
if k_out is not None:
result[1] = k_out
return tuple(result)

module_str = 'apply_rotary_emb'
apply_rotary_emb_module = getattr(dlinfer_backends_module, module_str)
func_str = 'apply_rotary_pos_emb'
apply_rotary_pos_emb_origin = getattr(apply_rotary_emb_module,
func_str)
apply_rotary_pos_emb_registered = register_custom_op(
f'lmdeploy::{func_str}',
impl_abstract_func=apply_rotary_emb_abstract_impl)(
apply_rotary_pos_emb_origin)
setattr(apply_rotary_emb_module, func_str,
apply_rotary_pos_emb_registered)

def patch_kvcache_static_shape(self):
import torch._dynamo as dynamo
from torch.utils._pytree import tree_map
cache_engine_module = import_module(
'lmdeploy.pytorch.engine.cache_engine')
class_str = 'CacheEngine'
cache_engine_class = getattr(cache_engine_module, class_str)
func_str = 'allocate_gpu_cache'
allocate_gpu_cache_origin = getattr(cache_engine_class, func_str)

def allocate_gpu_cache_mark_static(self):
gpu_cache = allocate_gpu_cache_origin(self)
tree_map(lambda x: dynamo.mark_static(x), gpu_cache)
return gpu_cache

setattr(cache_engine_class, func_str, allocate_gpu_cache_mark_static)
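Both patch_* helpers monkey-patch existing modules rather than subclass them: the dlinfer kernels are re-registered as torch custom ops so dynamo can trace through them, and the KV-cache allocator is wrapped so every cache tensor is marked static (fixed shape) for torch.compile. A standalone sketch of the wrap-and-mark-static pattern, with placeholder names:

```python
# Sketch of the pattern behind patch_kvcache_static_shape.
# `SomeCacheEngine` / `allocate_cache` are placeholders, not lmdeploy names.
import torch._dynamo as dynamo
from torch.utils._pytree import tree_map


def wrap_with_static_marking(allocate_fn):
    """Return a drop-in replacement that marks every allocated tensor static."""

    def allocate_and_mark_static(self):
        cache = allocate_fn(self)
        # Fixed shapes keep torch.compile from re-specializing on cache tensors.
        tree_map(dynamo.mark_static, cache)
        return cache

    return allocate_and_mark_static


# Usage:
# SomeCacheEngine.allocate_cache = wrap_with_static_marking(
#     SomeCacheEngine.allocate_cache)
```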
99 changes: 74 additions & 25 deletions lmdeploy/pytorch/backends/dlinfer/ascend/op_backend.py
@@ -3,6 +3,7 @@

import torch

from lmdeploy.pytorch.config import BackendConfig, CacheConfig, ModelConfig
from lmdeploy.utils import get_logger

from ..op_backend import DlinferOpsBackend
@@ -12,6 +13,9 @@

class AscendOpsBackend(DlinferOpsBackend):
"""ascend layer backend."""
enable_graph = False
half_negative_inf = torch.finfo(torch.float16).min
total_slots = None

@staticmethod
def get_name() -> str:
@@ -45,21 +49,23 @@ def get_v_block_shape(
@classmethod
def update_step_context(cls, step_context):
"""update step context."""

def get_total_slots():
if cls.total_slots is None:
cls.total_slots = torch.arange(
block_num * block_size,
dtype=torch.long,
device=step_context.block_offsets.device)
cls.total_slots = cls.total_slots.view(block_num, block_size)
return cls.total_slots

kv_start_indices, attention_mask = [], []
block_num, block_size, _ = step_context.kv_caches[0][0].shape
device = step_context.block_offsets.device

is_unpaged_prefill = False
if not step_context.is_decoding:
is_unpaged_prefill = \
all((step_context.q_seqlens ==
step_context.kv_seqlens).tolist())

total_slots = torch.arange(block_num * block_size,
dtype=torch.long,
device=device)
total_slots = total_slots.view(block_num, block_size)

q_seqlens_list = step_context.q_seqlens.tolist()
kv_seqlens_list = step_context.kv_seqlens.tolist()
max_q_seq_len = max(q_seqlens_list)
@@ -71,9 +77,9 @@ def update_step_context(cls, step_context):

# collect kv start indices.
history_length = kv_seq_len - q_seq_len
slot_tables = total_slots[step_context.block_offsets[i]].flatten()
slot_indices = [p for p in range(history_length, kv_seq_len)]
slots = slot_tables[slot_indices].reshape((-1, 1))
total_slots = get_total_slots()
slot_tables = total_slots[step_context.block_offsets[i]].view(-1)
slots = slot_tables[history_length:kv_seq_len]
kv_start_indices.append(slots)

# collect attention mask of paged_prefill attention stage.
@@ -83,19 +89,19 @@
torch.ones(q_seq_len,
step_context.block_offsets.shape[1] *
block_size,
dtype=torch.bool).cuda(),
dtype=torch.bool,
device=step_context.block_offsets.device),
diagonal=kv_seq_len - q_seq_len,
))
attention_mask.append(single_attention_mask)

kv_start_indices = torch.cat(kv_start_indices)

if step_context.is_decoding:
# prepare somae params of paged_decode attention stage.
# prepare some params of paged_decode attention stage.
q_start_loc_cpu, q_seqlens_cpu = None, None
kv_seqlens_cpu = step_context.kv_seqlens.cpu()
elif is_unpaged_prefill:
# prepare somae params of unpaged_prefill attention stage.
# prepare some params of unpaged_prefill attention stage.
q_start_loc_cpu, kv_seqlens_cpu = None, None
q_seqlens_cpu = step_context.q_seqlens.cpu()
single_attention_mask = torch.logical_not(
@@ -106,24 +112,54 @@
))
attention_mask.append(single_attention_mask)
else:
# prepare somae params of paged_prefill attention stage.
# prepare some params of paged_prefill attention stage.
q_start_loc_cpu, q_seqlens_cpu = None, None
kv_seqlens_cpu = step_context.kv_seqlens.repeat_interleave(
step_context.q_seqlens, 0).cpu()
block_offsets_int32 = step_context.block_offsets.to(torch.int32)
step_context.block_offsets = block_offsets_int32.repeat_interleave(
step_context.q_seqlens, 0)
attention_mask = [
torch.cat([mask for mask in attention_mask]).unsqueeze(1)
]
attention_mask = [torch.cat([mask for mask in attention_mask])]

if cls.enable_graph:
kv_start_indices = kv_start_indices.view(-1).to(torch.int32)
import torch._dynamo as dynamo
if not is_unpaged_prefill:
step_context.block_offsets = step_context.block_offsets.to(
torch.int32)
if not step_context.is_decoding:
step_context.block_offsets = step_context.block_offsets\
.repeat_interleave(step_context.q_seqlens, 0)
dynamo.mark_dynamic(step_context.block_offsets, [0, 1])
kv_seqlens = step_context.kv_seqlens.to(torch.int32)
if not step_context.is_decoding:
if is_unpaged_prefill:
attention_mask = [mask.half() for mask in attention_mask]
else:
attention_mask = [
torch.cat([
mask.half() * cls.half_negative_inf
for mask in attention_mask
]).unsqueeze(1)
]
kv_seqlens = kv_seqlens.repeat_interleave(
step_context.q_seqlens, 0)
else:
if step_context.is_decoding:
kv_seqlens_cpu = step_context.kv_seqlens.cpu()
elif is_unpaged_prefill:
pass
else:
kv_seqlens_cpu = step_context.kv_seqlens.repeat_interleave(
step_context.q_seqlens, 0).cpu()
block_offsets_int32 = step_context.block_offsets.to(
torch.int32)
step_context.block_offsets = block_offsets_int32\
.repeat_interleave(step_context.q_seqlens, 0)
kv_seqlens = kv_seqlens_cpu

attn_meta_cls = cls.get_attention_metadata_cls()
attn_metadata = attn_meta_cls(
step_context.is_decoding,
step_context.block_offsets,
q_start_loc=q_start_loc_cpu,
q_seqlens=q_seqlens_cpu,
kv_seqlens=kv_seqlens_cpu,
kv_seqlens=kv_seqlens,
kv_start_indices=kv_start_indices,
block_size=block_size,
attention_mask=attention_mask,
@@ -134,3 +170,16 @@ def update_step_context(cls, step_context):

step_context.attn_metadata = attn_metadata
return step_context

@staticmethod
def build_graph_runner(model: torch.nn.Module, model_config: ModelConfig,
cache_config: CacheConfig,
backend_config: BackendConfig,
device: torch.device):
"""build graph runner."""
from .graph_runner import AscendGraphRunner
ascend_graph_runner = AscendGraphRunner(model, model_config,
cache_config, backend_config,
device)
AscendOpsBackend.enable_graph = ascend_graph_runner.enable_graph
return ascend_graph_runner
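The kv_start_indices logic above maps each new token to a physical slot in the paged KV cache: total_slots is a (block_num, block_size) grid of global slot ids, indexing it with a sequence's block_offsets row gives that sequence's slot table, and the [history_length:kv_seq_len] slice selects the slots for the tokens arriving in this step. A toy walk-through with made-up sizes:

```python
# Toy walk-through of the slot-table computation; sizes are made up.
import torch

block_num, block_size = 4, 4
total_slots = torch.arange(block_num * block_size).view(block_num, block_size)

# Suppose one sequence owns physical blocks [2, 0], has 5 cached tokens and
# 3 new ones: kv_seq_len = 8, q_seq_len = 3, history_length = 5.
block_offsets = torch.tensor([2, 0])
slot_tables = total_slots[block_offsets].view(-1)  # [8, 9, 10, 11, 0, 1, 2, 3]
history_length, kv_seq_len = 5, 8
slots = slot_tables[history_length:kv_seq_len]     # [1, 2, 3]
print(slots)  # physical cache slots written by the 3 new tokens
```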
6 changes: 6 additions & 0 deletions lmdeploy/pytorch/backends/dlinfer/op_backend.py
@@ -28,6 +28,9 @@ def get_layer_impl_builder(cls, layer_type: OpType):
elif layer_type == OpType.ApplyRotaryEmb:
from .apply_rotary_emb import DlinferApplyRotaryEmbBuilder
return DlinferApplyRotaryEmbBuilder
elif layer_type == OpType.SiluAndMul:
from .activation import DlinferSiluAndMulBuilder
return DlinferSiluAndMulBuilder
elif layer_type == OpType.RMSNorm:
from .norm import DlinferRMSNormBuilder
return DlinferRMSNormBuilder
@@ -40,6 +43,9 @@ def get_layer_impl_builder(cls, layer_type: OpType):
elif layer_type == OpType.LinearW4A16:
from .awq_modules import AwqLinearW4A16Builder
return AwqLinearW4A16Builder
elif layer_type == OpType.RotaryEmbedding:
from .rotary_embedding import DlinferRotaryEmbeddingBuilder
return DlinferRotaryEmbeddingBuilder
else:
logger.debug(
f'Op {layer_type} fallback to default implementation.')
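The new OpType.SiluAndMul and OpType.RotaryEmbedding branches plug the dlinfer implementations into the same builder lookup the rest of the engine already uses. A hedged sketch of that lookup; the OpType import path is an assumption based on the package layout touched in this diff:

```python
# Sketch only: resolving a dlinfer op implementation via the builder registry.
from lmdeploy.pytorch.backends.base import OpType  # assumed location of OpType
from lmdeploy.pytorch.backends.dlinfer.op_backend import DlinferOpsBackend

builder = DlinferOpsBackend.get_layer_impl_builder(OpType.SiluAndMul)
silu_and_mul_impl = builder.build()  # -> DlinferSiluAndMulImpl
# silu_and_mul_impl.forward(x) dispatches to the fused dlinfer kernel.
```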