diff --git a/lmdeploy/pytorch/engine/engine.py b/lmdeploy/pytorch/engine/engine.py
index afa350330c..e06e0cf80a 100644
--- a/lmdeploy/pytorch/engine/engine.py
+++ b/lmdeploy/pytorch/engine/engine.py
@@ -782,8 +782,7 @@ def __update_inputs(next_token_ids):
         logger.debug('<ForwardTask>: '
                      f'batch_size={inputs.seq_length.size(0)} '
                      f'num_tokens={inputs.input_ids.size(-1)}')
-        if self.gpu_count == 1:
-            inputs = inputs.to_device('cuda')
+        inputs = inputs.to_device('cuda')
         is_decoding = inputs.is_decoding
         if all_ids is not None:
             all_ids = all_ids.cuda()
diff --git a/lmdeploy/pytorch/engine/model_agent.py b/lmdeploy/pytorch/engine/model_agent.py
index 999fa135cc..7df0eeb021 100644
--- a/lmdeploy/pytorch/engine/model_agent.py
+++ b/lmdeploy/pytorch/engine/model_agent.py
@@ -135,7 +135,6 @@ def model_forward(
     stream = stream or torch.cuda.current_stream()
     with torch.cuda.stream(stream):
         # forward
-        inputs = inputs.to_device('cuda')
         ctx_mgr = model.ctx_mgr
         context = ctx_mgr.build_context(
             inputs=inputs,
@@ -372,14 +371,26 @@ def _broadcast_config(cache_config):
     return patched_model, cache_engine, cache_config


-def _broadcast_inputs(rank: int, inputs: Any, stream: torch.cuda.Stream):
+def _broadcast_inputs(rank: int, inputs: Any, group: dist.group,
+                      stream: torch.cuda.Stream):
     """get input tensor parallel."""
     # broadcast meta info
     if rank != 0:
         inputs = [None, None, None]
+    else:
+        device_inputs = inputs[0]
+        meta_inputs = device_inputs.to_device('meta')
+        inputs[0] = meta_inputs

     with torch.cuda.stream(stream):
-        dist.broadcast_object_list(inputs)
+        dist.broadcast_object_list(inputs, group=group)
+        if rank == 0:
+            device_inputs.broadcast()
+        else:
+            device_inputs = inputs[0].broadcast()
+
+    inputs[0] = device_inputs
+
     return inputs


@@ -392,6 +403,7 @@ def _tp_model_loop(
     adapters: Dict[str, str],
     world_size: int,
     barrier: mp.Barrier,
+    cpu_group: dist.group,
 ):
     """Start model loops for tensor parallel model inference.

@@ -417,11 +429,12 @@ def _tp_model_loop(
     while True:
         barrier.wait()
         inputs, swap_in_map, swap_out_map = _broadcast_inputs(
-            rank, None, stream)
+            rank, None, cpu_group, stream)

         cache_swapping(cache_engine,
                        swap_in_map=swap_in_map,
                        swap_out_map=swap_out_map)
+        inputs = inputs.to_device('cuda')

         model_forward(
             patched_model,
@@ -453,10 +466,13 @@ def _start_tp_process(proc_id: int,
     try:
         from lmdeploy.pytorch.check_env import check_env_deeplink
         check_env_deeplink(device_context.device_type)
+        timeout = timedelta(days=35600)
         dist.init_process_group('nccl',
                                 rank=rank,
                                 world_size=world_size,
-                                timeout=timedelta(days=35600))
+                                timeout=timeout)
+        cpu_group = dist.new_group(timeout=timeout, backend='gloo')
+        kwargs['cpu_group'] = cpu_group
         dist_ctx = DistContext(rank=rank, world_size=world_size)
         torch.cuda.set_device(rank)
         with get_dist_manager().context(dist_ctx), get_device_manager(
@@ -626,12 +642,15 @@ def _start_sub_process(self, model_path: str, model_config: ModelConfig,

         rank = 0
         try:
+            timeout = timedelta(days=35600)
             dist.init_process_group('nccl',
                                     rank=rank,
                                     world_size=world_size,
-                                    timeout=timedelta(days=35600))
+                                    timeout=timeout)
+            cpu_group = dist.new_group(timeout=timeout, backend='gloo')
             dist_ctx = DistContext(rank=rank, world_size=world_size)
             self._dist_ctx = dist_ctx
+            self._cpu_group = cpu_group
         except Exception as e:
             from traceback import print_exc
             logger.error(f'Rank[{rank}] failed.')
@@ -673,7 +692,8 @@ def _forward_impl(self, inputs: ModelInputs, swap_in_map: SwapMap,
         self.mp_bar.wait()
         rank = 0
         _broadcast_inputs(rank, [inputs, swap_in_map, swap_out_map],
-                          self.stream)
+                          self._cpu_group, self.stream)
+
         cache_swapping(self.cache_engine,
                        swap_in_map=swap_in_map,
                        swap_out_map=swap_out_map)
@@ -699,8 +719,6 @@ async def async_forward(self, inputs: ModelInputs, swap_in_map: SwapMap,
                                     swap_in_map=swap_in_map,
                                     swap_out_map=swap_out_map)
         await asyncio.sleep(0)
-        while not self.stream.query():
-            await asyncio.sleep(0)
         return output

     def get_logits(self, hidden_states: torch.Tensor):
diff --git a/lmdeploy/pytorch/model_inputs.py b/lmdeploy/pytorch/model_inputs.py
index b5b74e4f02..d10da8557a 100644
--- a/lmdeploy/pytorch/model_inputs.py
+++ b/lmdeploy/pytorch/model_inputs.py
@@ -4,12 +4,21 @@
 from typing import Any, Dict, List, Literal

 import torch
+from torch import distributed as dist

 from lmdeploy.pytorch.backends import get_backend
 from lmdeploy.pytorch.config import ModelConfig
 from lmdeploy.pytorch.multimodal.data_type import MultiModalTensor


+def _broadcast_tensor(value: torch.Tensor, src: int = 0, device: str = 'cuda'):
+    """broadcast tensor."""
+    if value.device.type == 'meta':
+        value = torch.empty_like(value, device=device)
+    dist.broadcast(value, src)
+    return value
+
+
 @dataclass
 class VisionModelInputs:
     """Vision model inputs."""
@@ -36,10 +45,45 @@ def to_device(self, device: str):
             elif k == 'input_embeddings':
                 v = [[e.to(device) for e in li] for li in v]
             elif k == 'input_multimodals':
+                new_v = []
                 for mm_datas in v:
+                    new_mm_datas = dict()
                     for modal_type, data in mm_datas.items():
                         data = [d.to_device(device) for d in data]
-                        mm_datas[modal_type] = data
+                        new_mm_datas[modal_type] = data
+                    new_v.append(new_mm_datas)
+                v = new_v
             out_dict[k] = v
+
+        return VisionModelInputs(**out_dict)
+
+    def broadcast(self):
+        """broadcast inputs.
+
+        Do `dist.broadcast_object_list(inputs.to_device('meta'))`
+        before broadcast tensors.
+        """
+        out_dict = dict()
+        for f in fields(self):
+            k = f.name
+            v = getattr(self, k)
+            if v is None:
+                continue
+            if isinstance(v, torch.Tensor):
+                v = _broadcast_tensor(v)
+            elif k == 'input_embedding_ranges':
+                v = [_broadcast_tensor(e) for e in v]
+            elif k == 'input_embeddings':
+                v = [[_broadcast_tensor(e) for e in li] for li in v]
+            elif k == 'input_multimodals':
+                new_v = []
+                for mm_datas in v:
+                    new_mm_datas = dict()
+                    for modal_type, data in mm_datas.items():
+                        data = [d.broadcast() for d in data]
+                        new_mm_datas[modal_type] = data
+                    new_v.append(new_mm_datas)
+                v = new_v
+            out_dict[k] = v

         return VisionModelInputs(**out_dict)
@@ -202,6 +246,24 @@ def to_device(self, device: str):

         return ModelInputs(**out_dict)

+    def broadcast(self):
+        """broadcast inputs.
+
+        Do `dist.broadcast_object_list(inputs.to_device('meta'))`
+        before broadcast tensors.
+        """
+        out_dict = dict()
+        for f in fields(self):
+            k = f.name
+            v = getattr(self, k)
+            if isinstance(v, torch.Tensor):
+                v = _broadcast_tensor(v)
+            elif isinstance(v, VisionModelInputs):
+                v = v.broadcast()
+            out_dict[k] = v
+
+        return ModelInputs(**out_dict)
+

 @dataclass
 class StepContext:
diff --git a/lmdeploy/pytorch/multimodal/data_type.py b/lmdeploy/pytorch/multimodal/data_type.py
index 95ec72d26e..886c7ffbd0 100644
--- a/lmdeploy/pytorch/multimodal/data_type.py
+++ b/lmdeploy/pytorch/multimodal/data_type.py
@@ -1,8 +1,10 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from dataclasses import dataclass
+from dataclasses import dataclass, fields
 from typing import Any, Dict, List, Union

+import torch
 from torch import Tensor
+from torch import distributed as dist


 class MultiModalData:
@@ -14,6 +16,14 @@ class MultiModalData:
 NestedTensor = Union[Tensor, List[Tensor]]


+def _broadcast_tensor(value: torch.Tensor, src: int = 0, device: str = 'cuda'):
+    """broadcast tensor."""
+    if value.device.type == 'meta':
+        value = torch.empty_like(value, device=device)
+    dist.broadcast(value, src)
+    return value
+
+
 @dataclass
 class MultiModalTensor:
     data: NestedTensor
@@ -28,24 +38,67 @@ def __post_init__(self):

     def to_device(self, device: str, non_blocking: bool = False):
         """to device."""
+        out_dict = dict()
+        for f in fields(self):
+            k = f.name
+            if k in ('data', 'meta'):
+                continue
+            v = getattr(self, k)
+            out_dict[k] = v
+
         if isinstance(self.data, Tensor):
-            self.data = self.data.to(device=device, non_blocking=non_blocking)
+            data = self.data.to(device=device, non_blocking=non_blocking)
         else:
             data = [
                 d.to(device=device, non_blocking=non_blocking)
                 for d in self.data
             ]
-            self.data = data
+        out_dict['data'] = data

+        new_meta = None
         if self.meta is not None:
+            new_meta = dict()
             for k, v in self.meta.items():
                 if isinstance(v, Tensor):
                     v = v.to(device=device, non_blocking=non_blocking)
-                    self.meta[k] = v
                 elif hasattr(v, 'to_device'):
                     v = v.to_device(device=device, non_blocking=non_blocking)
+                new_meta[k] = v
+
+        out_dict['meta'] = new_meta
+        return MultiModalTensor(**out_dict)
+
+    def broadcast(self):
+        """broadcast inputs tensors."""
+        out_dict = dict()
+        for f in fields(self):
+            k = f.name
+            if k in ('data', 'meta'):
+                continue
+            v = getattr(self, k)
+            out_dict[k] = v
+
+        if isinstance(self.data, Tensor):
+            data = _broadcast_tensor(self.data)
+        else:
+            data = [_broadcast_tensor(d) for d in self.data]
+        out_dict['data'] = data
+
+        new_meta = None
+        if self.meta is not None:
+            new_meta = dict()
+            for k, v in self.meta.items():
+                if isinstance(v, Tensor):
+                    v = _broadcast_tensor(v)
+                    self.meta[k] = v
+                elif hasattr(v, 'to_device'):
+                    assert hasattr(v, 'broadcast')
+                    v = v.broadcast()
                     self.meta[k] = v
-        return self
+                new_meta[k] = v
+
+        out_dict['meta'] = new_meta
+        return MultiModalTensor(**out_dict)


 MultiModalInputs = Dict[str, List[MultiModalTensor]]
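
For reviewers, here is a minimal standalone sketch of the broadcast pattern this patch introduces: rank 0 replaces each tensor with a meta placeholder so that `dist.broadcast_object_list` over the gloo `cpu_group` carries only shapes and dtypes, and the actual storage then travels through ordinary `dist.broadcast` calls on the NCCL group. The helper below mirrors `_broadcast_tensor`, but the `broadcast_payload` function and its `payload` dict are illustrative only (not part of the lmdeploy API), and the snippet assumes the default NCCL process group and the gloo group are already initialized, as in `_start_tp_process`.

```python
import torch
from torch import distributed as dist


def _broadcast_tensor(value: torch.Tensor, src: int = 0, device: str = 'cuda'):
    """Materialize a meta placeholder, then broadcast the real data into it."""
    if value.device.type == 'meta':
        # non-src ranks only received shape/dtype; allocate a buffer to fill
        value = torch.empty_like(value, device=device)
    dist.broadcast(value, src)
    return value


def broadcast_payload(rank: int, payload: dict, cpu_group) -> dict:
    """Hypothetical example: broadcast a dict of CUDA tensors from rank 0.

    The object broadcast ships meta tensors (metadata only) over the gloo
    group; the tensor storage follows as NCCL broadcasts.
    """
    if rank == 0:
        # strip storage before pickling, keeping only shape/dtype information
        wrapper = [{k: v.to('meta') for k, v in payload.items()}]
    else:
        wrapper = [None]
    dist.broadcast_object_list(wrapper, src=0, group=cpu_group)

    # rank 0 broadcasts its real tensors; other ranks fill the meta placeholders
    src_payload = payload if rank == 0 else wrapper[0]
    return {k: _broadcast_tensor(v) for k, v in src_payload.items()}
```

Compared with pushing the CUDA tensors through `broadcast_object_list` directly, this keeps the pickled object small and lets NCCL move the bulk data, which is why `_broadcast_inputs` now takes the extra gloo group alongside the CUDA stream.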