From 44a0cd31ec504d14fdb6f440368f885c1134e93a Mon Sep 17 00:00:00 2001 From: CyCle1024 Date: Fri, 25 Oct 2024 16:32:58 +0800 Subject: [PATCH] [ascend] add ascend graph mode (#2647) * [pytorch] ascend enable atbgraph * add paged prefill attention * refine ascend-update-step-ctx (#26) refine ascend-update-step-ctx --------- Co-authored-by: CyCle1024 * fix: rewrite enable graph for ascend * fix backend error due to folder refactor * remove unnecessary comment * fix rotary_embedding (#27) --------- Co-authored-by: jinminxi104 Co-authored-by: tangzhiyi11 --- .../pytorch/backends/dlinfer/activation.py | 21 ++++ .../backends/dlinfer/ascend/graph_runner.py | 116 ++++++++++++++++++ .../backends/dlinfer/ascend/op_backend.py | 99 +++++++++++---- .../pytorch/backends/dlinfer/op_backend.py | 6 + .../backends/dlinfer/rotary_embedding.py | 84 +++++++++++++ lmdeploy/pytorch/engine/logits_process.py | 1 + .../pytorch/kernels/dlinfer/activation.py | 7 ++ .../kernels/dlinfer/apply_rotary_pos_emb.py | 8 +- .../pytorch/kernels/dlinfer/pagedattention.py | 2 +- 9 files changed, 315 insertions(+), 29 deletions(-) create mode 100644 lmdeploy/pytorch/backends/dlinfer/activation.py create mode 100644 lmdeploy/pytorch/backends/dlinfer/ascend/graph_runner.py create mode 100644 lmdeploy/pytorch/backends/dlinfer/rotary_embedding.py create mode 100644 lmdeploy/pytorch/kernels/dlinfer/activation.py diff --git a/lmdeploy/pytorch/backends/dlinfer/activation.py b/lmdeploy/pytorch/backends/dlinfer/activation.py new file mode 100644 index 000000000..566fe1162 --- /dev/null +++ b/lmdeploy/pytorch/backends/dlinfer/activation.py @@ -0,0 +1,21 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from lmdeploy.pytorch.kernels.dlinfer.activation import silu_and_mul + +from ..activation import SiluAndMulBuilder, SiluAndMulImpl + + +class DlinferSiluAndMulImpl(SiluAndMulImpl): + """silu + multiple fused implementation.""" + + def forward(self, x): + """forward.""" + return silu_and_mul(x) + + +class DlinferSiluAndMulBuilder(SiluAndMulBuilder): + """silu and mul implementation builder.""" + + @staticmethod + def build(inplace: bool = False): + """build.""" + return DlinferSiluAndMulImpl() diff --git a/lmdeploy/pytorch/backends/dlinfer/ascend/graph_runner.py b/lmdeploy/pytorch/backends/dlinfer/ascend/graph_runner.py new file mode 100644 index 000000000..3ecc4223b --- /dev/null +++ b/lmdeploy/pytorch/backends/dlinfer/ascend/graph_runner.py @@ -0,0 +1,116 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+import warnings +from importlib import import_module + +import torch +import torch.distributed + +from lmdeploy.pytorch.config import BackendConfig, CacheConfig, ModelConfig +from lmdeploy.utils import get_logger + +from ...graph_runner import GraphRunner + +logger = get_logger('lmdeploy') + + +class AscendGraphRunner(GraphRunner): + """ascend graph runner.""" + + def __init__(self, model: torch.nn.Module, model_config: ModelConfig, + cache_config: CacheConfig, backend_config: BackendConfig, + device: torch.device): + super().__init__(model, model_config, cache_config, backend_config, + device) + + self.enable_graph = self.check_enable_graph() + if self.enable_graph: + import dlinfer.graph + dlinfer.graph.config.enable_graph_mode = True + self.patch_kernels_custom_op() + self.patch_kvcache_static_shape() + self.model = torch.compile(self.model, + fullgraph=True, + dynamic=True, + backend='atbgraph') + + def check_enable_graph(self): + """check enable graph.""" + # eager_mode + if self.backend_config.eager_mode: + return False + # tp + if torch.distributed.is_initialized(): + warnings.warn( + "Graph mode of device_type 'ascend' only supports tp=1 " + 'for now, fallback to eager mode', RuntimeWarning) + return False + # model support + self.supported_model = { + 'Llama2': 'LlamaConfig', + 'InternLM2': 'InternLM2Config', + 'Qwen2': 'Qwen2Config', + } + is_model_support = True + model_config_name = str(type(self.model_config.hf_config).__name__) + if model_config_name not in self.supported_model.values(): + is_model_support = False + if not is_model_support: + warnings.warn( + "Graph mode of device_type 'ascend' only supports models: " + f"{', '.join(self.supported_model.keys())} when tp=1 for now", + RuntimeWarning) + return True + + def patch_kernels_custom_op(self): + from dlinfer.graph.custom_op import register_custom_op + dlinfer_kernels_module = import_module( + 'lmdeploy.pytorch.kernels.dlinfer') + dlinfer_backends_module = import_module( + 'lmdeploy.pytorch.backends.dlinfer') + + # prefill_attention + module_str = 'pagedattention' + paged_attn_module = getattr(dlinfer_kernels_module, module_str) + func_str = 'prefill_attention' + prefill_attn_origin = getattr(paged_attn_module, func_str) + prefill_attn_registered = register_custom_op( + f'lmdeploy::{func_str}', ['attn_output'])(prefill_attn_origin) + setattr(paged_attn_module, func_str, prefill_attn_registered) + + # apply_rotary_pos_emb + def apply_rotary_emb_abstract_impl(q, k, cos, sin, q_out, k_out): + result = [q, k] + if q_out is not None: + result[0] = q_out + if k_out is not None: + result[1] = k_out + return tuple(result) + + module_str = 'apply_rotary_emb' + apply_rotary_emb_module = getattr(dlinfer_backends_module, module_str) + func_str = 'apply_rotary_pos_emb' + apply_rotary_pos_emb_origin = getattr(apply_rotary_emb_module, + func_str) + apply_rotary_pos_emb_registered = register_custom_op( + f'lmdeploy::{func_str}', + impl_abstract_func=apply_rotary_emb_abstract_impl)( + apply_rotary_pos_emb_origin) + setattr(apply_rotary_emb_module, func_str, + apply_rotary_pos_emb_registered) + + def patch_kvcache_static_shape(self): + import torch._dynamo as dynamo + from torch.utils._pytree import tree_map + cache_engine_module = import_module( + 'lmdeploy.pytorch.engine.cache_engine') + class_str = 'CacheEngine' + cache_engine_class = getattr(cache_engine_module, class_str) + func_str = 'allocate_gpu_cache' + allocate_gpu_cache_origin = getattr(cache_engine_class, func_str) + + def allocate_gpu_cache_mark_static(self): + gpu_cache = 
allocate_gpu_cache_origin(self) + tree_map(lambda x: dynamo.mark_static(x), gpu_cache) + return gpu_cache + + setattr(cache_engine_class, func_str, allocate_gpu_cache_mark_static) diff --git a/lmdeploy/pytorch/backends/dlinfer/ascend/op_backend.py b/lmdeploy/pytorch/backends/dlinfer/ascend/op_backend.py index 065e39b42..79e528836 100644 --- a/lmdeploy/pytorch/backends/dlinfer/ascend/op_backend.py +++ b/lmdeploy/pytorch/backends/dlinfer/ascend/op_backend.py @@ -3,6 +3,7 @@ import torch +from lmdeploy.pytorch.config import BackendConfig, CacheConfig, ModelConfig from lmdeploy.utils import get_logger from ..op_backend import DlinferOpsBackend @@ -12,6 +13,9 @@ class AscendOpsBackend(DlinferOpsBackend): """ascend layer backend.""" + enable_graph = False + half_negative_inf = torch.finfo(torch.float16).min + total_slots = None @staticmethod def get_name() -> str: @@ -45,21 +49,23 @@ def get_v_block_shape( @classmethod def update_step_context(cls, step_context): """update step context.""" + + def get_total_slots(): + if cls.total_slots is None: + cls.total_slots = torch.arange( + block_num * block_size, + dtype=torch.long, + device=step_context.block_offsets.device) + cls.total_slots = cls.total_slots.view(block_num, block_size) + return cls.total_slots + kv_start_indices, attention_mask = [], [] block_num, block_size, _ = step_context.kv_caches[0][0].shape - device = step_context.block_offsets.device - is_unpaged_prefill = False if not step_context.is_decoding: is_unpaged_prefill = \ all((step_context.q_seqlens == step_context.kv_seqlens).tolist()) - - total_slots = torch.arange(block_num * block_size, - dtype=torch.long, - device=device) - total_slots = total_slots.view(block_num, block_size) - q_seqlens_list = step_context.q_seqlens.tolist() kv_seqlens_list = step_context.kv_seqlens.tolist() max_q_seq_len = max(q_seqlens_list) @@ -71,9 +77,9 @@ def update_step_context(cls, step_context): # collect kv start indices. history_length = kv_seq_len - q_seq_len - slot_tables = total_slots[step_context.block_offsets[i]].flatten() - slot_indices = [p for p in range(history_length, kv_seq_len)] - slots = slot_tables[slot_indices].reshape((-1, 1)) + total_slots = get_total_slots() + slot_tables = total_slots[step_context.block_offsets[i]].view(-1) + slots = slot_tables[history_length:kv_seq_len] kv_start_indices.append(slots) # collect attention mask of paged_prefill attention stage. @@ -83,7 +89,8 @@ def update_step_context(cls, step_context): torch.ones(q_seq_len, step_context.block_offsets.shape[1] * block_size, - dtype=torch.bool).cuda(), + dtype=torch.bool, + device=step_context.block_offsets.device), diagonal=kv_seq_len - q_seq_len, )) attention_mask.append(single_attention_mask) @@ -91,11 +98,10 @@ def update_step_context(cls, step_context): kv_start_indices = torch.cat(kv_start_indices) if step_context.is_decoding: - # prepare somae params of paged_decode attention stage. + # prepare some params of paged_decode attention stage. q_start_loc_cpu, q_seqlens_cpu = None, None - kv_seqlens_cpu = step_context.kv_seqlens.cpu() elif is_unpaged_prefill: - # prepare somae params of unpaged_prefill attention stage. + # prepare some params of unpaged_prefill attention stage. q_start_loc_cpu, kv_seqlens_cpu = None, None q_seqlens_cpu = step_context.q_seqlens.cpu() single_attention_mask = torch.logical_not( @@ -106,16 +112,46 @@ def update_step_context(cls, step_context): )) attention_mask.append(single_attention_mask) else: - # prepare somae params of paged_prefill attention stage. 
+ # prepare some params of paged_prefill attention stage. q_start_loc_cpu, q_seqlens_cpu = None, None - kv_seqlens_cpu = step_context.kv_seqlens.repeat_interleave( - step_context.q_seqlens, 0).cpu() - block_offsets_int32 = step_context.block_offsets.to(torch.int32) - step_context.block_offsets = block_offsets_int32.repeat_interleave( - step_context.q_seqlens, 0) - attention_mask = [ - torch.cat([mask for mask in attention_mask]).unsqueeze(1) - ] + attention_mask = [torch.cat([mask for mask in attention_mask])] + + if cls.enable_graph: + kv_start_indices = kv_start_indices.view(-1).to(torch.int32) + import torch._dynamo as dynamo + if not is_unpaged_prefill: + step_context.block_offsets = step_context.block_offsets.to( + torch.int32) + if not step_context.is_decoding: + step_context.block_offsets = step_context.block_offsets\ + .repeat_interleave(step_context.q_seqlens, 0) + dynamo.mark_dynamic(step_context.block_offsets, [0, 1]) + kv_seqlens = step_context.kv_seqlens.to(torch.int32) + if not step_context.is_decoding: + if is_unpaged_prefill: + attention_mask = [mask.half() for mask in attention_mask] + else: + attention_mask = [ + torch.cat([ + mask.half() * cls.half_negative_inf + for mask in attention_mask + ]).unsqueeze(1) + ] + kv_seqlens = kv_seqlens.repeat_interleave( + step_context.q_seqlens, 0) + else: + if step_context.is_decoding: + kv_seqlens_cpu = step_context.kv_seqlens.cpu() + elif is_unpaged_prefill: + pass + else: + kv_seqlens_cpu = step_context.kv_seqlens.repeat_interleave( + step_context.q_seqlens, 0).cpu() + block_offsets_int32 = step_context.block_offsets.to( + torch.int32) + step_context.block_offsets = block_offsets_int32\ + .repeat_interleave(step_context.q_seqlens, 0) + kv_seqlens = kv_seqlens_cpu attn_meta_cls = cls.get_attention_metadata_cls() attn_metadata = attn_meta_cls( @@ -123,7 +159,7 @@ def update_step_context(cls, step_context): step_context.block_offsets, q_start_loc=q_start_loc_cpu, q_seqlens=q_seqlens_cpu, - kv_seqlens=kv_seqlens_cpu, + kv_seqlens=kv_seqlens, kv_start_indices=kv_start_indices, block_size=block_size, attention_mask=attention_mask, @@ -134,3 +170,16 @@ def update_step_context(cls, step_context): step_context.attn_metadata = attn_metadata return step_context + + @staticmethod + def build_graph_runner(model: torch.nn.Module, model_config: ModelConfig, + cache_config: CacheConfig, + backend_config: BackendConfig, + device: torch.device): + """build graph runner.""" + from .graph_runner import AscendGraphRunner + ascend_graph_runner = AscendGraphRunner(model, model_config, + cache_config, backend_config, + device) + AscendOpsBackend.enable_graph = ascend_graph_runner.enable_graph + return ascend_graph_runner diff --git a/lmdeploy/pytorch/backends/dlinfer/op_backend.py b/lmdeploy/pytorch/backends/dlinfer/op_backend.py index 124633f85..031f51fdc 100644 --- a/lmdeploy/pytorch/backends/dlinfer/op_backend.py +++ b/lmdeploy/pytorch/backends/dlinfer/op_backend.py @@ -28,6 +28,9 @@ def get_layer_impl_builder(cls, layer_type: OpType): elif layer_type == OpType.ApplyRotaryEmb: from .apply_rotary_emb import DlinferApplyRotaryEmbBuilder return DlinferApplyRotaryEmbBuilder + elif layer_type == OpType.SiluAndMul: + from .activation import DlinferSiluAndMulBuilder + return DlinferSiluAndMulBuilder elif layer_type == OpType.RMSNorm: from .norm import DlinferRMSNormBuilder return DlinferRMSNormBuilder @@ -40,6 +43,9 @@ def get_layer_impl_builder(cls, layer_type: OpType): elif layer_type == OpType.LinearW4A16: from .awq_modules import AwqLinearW4A16Builder 
return AwqLinearW4A16Builder + elif layer_type == OpType.RotaryEmbedding: + from .rotary_embedding import DlinferRotaryEmbeddingBuilder + return DlinferRotaryEmbeddingBuilder else: logger.debug( f'Op {layer_type} fallback to default implementation.') diff --git a/lmdeploy/pytorch/backends/dlinfer/rotary_embedding.py b/lmdeploy/pytorch/backends/dlinfer/rotary_embedding.py new file mode 100644 index 000000000..e97c9d133 --- /dev/null +++ b/lmdeploy/pytorch/backends/dlinfer/rotary_embedding.py @@ -0,0 +1,84 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +from torch import nn + +from ..default.rotary_embedding import (Llama3RotaryEmbeddingImpl, + LlamaDynamicNTKScalingRotaryEmbedding) +from ..rotary_embedding import (Llama3Parameters, LongRoPEScalingParameters, + RopeType, RotaryEmbeddingBuilder, + RotaryEmbeddingImpl, YarnParameters) + + +class DlinferRotaryEmbeddingImpl(RotaryEmbeddingImpl, nn.Module): + """base rotary embedding.""" + + def __init__(self, + dim: int, + base: int = 10000, + scaling_factor: float = 1.0): + super().__init__() + self.scaling_factor = scaling_factor + self.dim = dim + self.base = base + inv_freq = 1.0 / (self.base**( + torch.arange(0, self.dim, 2, dtype=torch.int64).float() / + self.dim)).float().cuda() + self.register_buffer('inv_freq', inv_freq, persistent=False) + + def forward(self, x, position_ids): + """forward.""" + # x: [bs, num_attention_heads, seq_len, head_size] + if self.inv_freq.device != x.device: + self.inv_freq = self.inv_freq.to(x.device) + + if self.scaling_factor != 1.0: + position_ids = position_ids.float() / self.scaling_factor + else: + position_ids = position_ids.float() + + inv_freq_expanded = self.inv_freq.view(1, -1, 1) + position_ids_expanded = position_ids.unsqueeze(1) + + # # Force float32 since bfloat16 loses precision on long contexts + # See https://github.com/huggingface/transformers/pull/29285 + device_type = x.device.type + device_type = device_type if isinstance( + device_type, str) and device_type != 'mps' else 'cpu' + inv_freq_expanded = inv_freq_expanded + position_ids_expanded = position_ids_expanded + tmp = torch.bmm(inv_freq_expanded, position_ids_expanded) + freqs = tmp.transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() + sin = emb.sin() + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +class DlinferRotaryEmbeddingBuilder(RotaryEmbeddingBuilder): + """rotary embedding builder.""" + + @staticmethod + def build( + dim: int, + max_position_embeddings: int = 2048, + base: int = 10000, + scaling_factor: float = 1.0, + yarn_params: YarnParameters = None, + longrope_params: LongRoPEScalingParameters = None, + llama3_params: Llama3Parameters = None, + emb_type: RopeType = RopeType.Default, + ): + """build.""" + if emb_type in (RopeType.Default, RopeType.LinearScaling): + return DlinferRotaryEmbeddingImpl(dim, base, scaling_factor) + elif emb_type == RopeType.DynamicNTKScaling: + return LlamaDynamicNTKScalingRotaryEmbedding( + dim, base, scaling_factor, max_position_embeddings) + elif emb_type == RopeType.Llama3: + return Llama3RotaryEmbeddingImpl(dim, base, scaling_factor, + llama3_params.low_freq_factor, + llama3_params.high_freq_factor, + max_position_embeddings) + else: + raise NotImplementedError( + f'Unsupported embedding type: {emb_type}') diff --git a/lmdeploy/pytorch/engine/logits_process.py b/lmdeploy/pytorch/engine/logits_process.py index 2ee2eaced..44eb25a8c 100644 --- a/lmdeploy/pytorch/engine/logits_process.py +++ b/lmdeploy/pytorch/engine/logits_process.py @@ 
-336,6 +336,7 @@ def __call__(self, all_ids: torch.LongTensor, guided_input_ids, self.tokenizer) return scores + @torch.inference_mode() def sampling(self, logits: torch.Tensor): """sampling.""" sampling_inputs = self.sampling_inputs diff --git a/lmdeploy/pytorch/kernels/dlinfer/activation.py b/lmdeploy/pytorch/kernels/dlinfer/activation.py new file mode 100644 index 000000000..b862fdfb8 --- /dev/null +++ b/lmdeploy/pytorch/kernels/dlinfer/activation.py @@ -0,0 +1,7 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import dlinfer.ops as ext_ops +from torch import Tensor + + +def silu_and_mul(input_tensor: Tensor, ) -> Tensor: + return ext_ops.silu_and_mul(input_tensor) diff --git a/lmdeploy/pytorch/kernels/dlinfer/apply_rotary_pos_emb.py b/lmdeploy/pytorch/kernels/dlinfer/apply_rotary_pos_emb.py index e67cfda23..0f13f3f38 100644 --- a/lmdeploy/pytorch/kernels/dlinfer/apply_rotary_pos_emb.py +++ b/lmdeploy/pytorch/kernels/dlinfer/apply_rotary_pos_emb.py @@ -1,4 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. +from typing import Optional, Tuple + import dlinfer.ops as ext_ops from torch import Tensor @@ -8,9 +10,9 @@ def apply_rotary_pos_emb( key_states: Tensor, cos: Tensor, sin: Tensor, - q_embed: Tensor = None, - k_embed: Tensor = None, -): + q_embed: Optional[Tensor], + k_embed: Optional[Tensor], +) -> Tuple[Tensor, Tensor]: query_states = query_states.contiguous() key_states = key_states.contiguous() bs = query_states.shape[0] diff --git a/lmdeploy/pytorch/kernels/dlinfer/pagedattention.py b/lmdeploy/pytorch/kernels/dlinfer/pagedattention.py index c8fc4e90e..21c72074a 100644 --- a/lmdeploy/pytorch/kernels/dlinfer/pagedattention.py +++ b/lmdeploy/pytorch/kernels/dlinfer/pagedattention.py @@ -19,7 +19,7 @@ def prefill_attention( block_size: int, attn_mask: Sequence[Optional[Tensor]], is_unpaged_prefill: Optional[bool], -): +) -> Tensor: num_q_heads = query_states.shape[1] num_kv_heads = value_states.shape[1]
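
Usage sketch (illustrative, not part of the patch): the AscendGraphRunner added above compiles the model with torch.compile(backend='atbgraph') unless eager mode is requested or tensor parallelism is initialized. A minimal way to exercise the new graph mode from the Python API might look like the snippet below; the PytorchEngineConfig fields (device_type, eager_mode, tp) and the model path are assumptions inferred from the checks in check_enable_graph, not definitions introduced by this patch, so verify them against your lmdeploy version.

    # Hypothetical usage sketch; config fields and model path are assumptions.
    from lmdeploy import pipeline, PytorchEngineConfig

    backend_config = PytorchEngineConfig(
        device_type='ascend',  # route to the dlinfer/ascend op backend
        eager_mode=False,      # lets AscendGraphRunner attempt torch.compile(backend='atbgraph')
        tp=1,                  # graph mode falls back to eager mode when tp > 1
    )
    # Architectures supported by graph mode in this patch: Llama2, InternLM2, Qwen2.
    pipe = pipeline('/path/to/a/supported/model', backend_config=backend_config)
    print(pipe(['Hello from Ascend graph mode']))

Passing eager_mode=True, or running with torch.distributed initialized (tp > 1), makes check_enable_graph return False, and the runner keeps the uncompiled model with a RuntimeWarning.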