115 changes: 107 additions & 8 deletions colossalai/inference/config.py
@@ -2,11 +2,11 @@
Our config contains various options for inference optimization; it is a unified API that wraps all the configurations for inference.
"""
import logging
from abc import ABC, abstractmethod
from dataclasses import dataclass, fields
from typing import Any, Dict, Optional, Union
from typing import Any, Dict, List, Optional, Union

import torch
import torch.distributed as dist
from transformers.generation import GenerationConfig

from colossalai.inference.flash_decoding_utils import FDIntermTensors
@@ -30,8 +30,25 @@
}


class RPC_PARAM(ABC):
"""
NOTE(lry89757): We use rpyc to transport parameters between the client and the server.
rpyc only supports plain-old-data (POD) Python types as parameters, so tensors and other sophisticated objects must be converted before transport.
Borrowing the logic of `__getstate__`/`__setstate__`, classes that are later passed as RPC parameters inherit from this base class and override `to_rpc_param` and `from_rpc_param`: the client calls `to_rpc_param` to serialize the parameters, and the server reconstructs them with `from_rpc_param`.
"""

@abstractmethod
def to_rpc_param(self):
raise NotImplementedError

@staticmethod
@abstractmethod
def from_rpc_param():
raise NotImplementedError
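
A minimal sketch of how a subclass might satisfy this contract; the `ExamplePayload` class and its fields are hypothetical and only illustrate the pattern:

```python
import torch


class ExamplePayload(RPC_PARAM):
    """Hypothetical payload that carries a tensor across an rpyc call."""

    def __init__(self, token_ids: torch.Tensor, scale: float):
        self.token_ids = token_ids
        self.scale = scale

    def to_rpc_param(self) -> dict:
        # Convert non-POD members (here, a tensor) into plain Python types before sending.
        return {"token_ids": self.token_ids.tolist(), "scale": self.scale}

    @staticmethod
    def from_rpc_param(rpc_dict: dict) -> "ExamplePayload":
        # Rebuild the original object on the receiving side.
        return ExamplePayload(
            token_ids=torch.tensor(rpc_dict["token_ids"], dtype=torch.int),
            scale=rpc_dict["scale"],
        )
```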


@dataclass
class InputMetaData:
class InputMetaData(RPC_PARAM):
"""The input info for a single step

Args:
@@ -48,6 +65,7 @@ class InputMetaData:
dtype (torch.dtype, optional): The computation type of tensor, Defaults to torch.float32.
use_spec_dec (bool): Indicate whether to use speculative decoding.
num_tokens_to_verify (int): The number of tokens to verify in speculative decoding. Only valid when `use_spec_dec` is set to True.
batch_token_ids (List[List[int]], optional): The input_token_ids + output_token_ids of the current batch. Only used for `repetition_penalty` and `no_repeat_ngram_size` in the sampling process.
"""

block_tables: torch.Tensor = None
@@ -63,6 +81,54 @@ class InputMetaData:
dtype: torch.dtype = torch.float32
use_spec_dec: bool = False
num_tokens_to_verify: int = 0
batch_token_ids: Optional[
List[List[int]]
] = None # for `repetition_penalty`, `no_repeat_ngram_size` in sampler process

def to_rpc_param(self) -> Dict[str, Any]:
return {
"block_tables": self.block_tables.tolist(),
"sequence_lengths": self.sequence_lengths.tolist(),
"batch_size": self.batch_size,
"is_prompts": self.is_prompts,
"use_cuda_kernel": self.use_cuda_kernel,
"use_cuda_graph": self.use_cuda_graph,
"kv_seq_len": self.kv_seq_len,
"head_dim": self.head_dim,
"high_precision": self.high_precision,
"dtype": str(self.dtype).split(".")[-1],
"use_spec_dec": self.use_spec_dec,
"num_tokens_to_verify": self.num_tokens_to_verify,
"batch_token_ids": self.batch_token_ids,
}

@staticmethod
def from_rpc_param(rpc_dict: Dict[str, Any]) -> "InputMetaData":
"""
We intentionally avoid `dict.get` so that a missing RPC param raises an error instead of being silently replaced by a default.
"""
from colossalai.accelerator import get_accelerator

dtype = getattr(torch, rpc_dict["dtype"])
return InputMetaData(
block_tables=torch.tensor(
rpc_dict["block_tables"], dtype=torch.int, device=get_accelerator().get_current_device()
),
sequence_lengths=torch.tensor(
rpc_dict["sequence_lengths"], dtype=torch.int, device=get_accelerator().get_current_device()
),
batch_size=rpc_dict["batch_size"],
is_prompts=rpc_dict["is_prompts"],
use_cuda_kernel=rpc_dict["use_cuda_kernel"],
use_cuda_graph=rpc_dict["use_cuda_graph"],
kv_seq_len=rpc_dict["kv_seq_len"],
head_dim=rpc_dict["head_dim"],
high_precision=rpc_dict["high_precision"],
dtype=dtype,
use_spec_dec=rpc_dict["use_spec_dec"],
num_tokens_to_verify=rpc_dict["num_tokens_to_verify"],
batch_token_ids=rpc_dict["batch_token_ids"],
)
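
A hedged usage sketch of the round trip for `InputMetaData` (the values are illustrative, and `from_rpc_param` assumes a Colossal-AI accelerator/device is available):

```python
import torch

meta = InputMetaData(
    block_tables=torch.tensor([[0, 1]], dtype=torch.int),
    sequence_lengths=torch.tensor([2], dtype=torch.int),
    batch_size=1,
    dtype=torch.float16,
)
payload = meta.to_rpc_param()                     # plain dict of POD types, safe to pass through rpyc
restored = InputMetaData.from_rpc_param(payload)  # tensors are rebuilt on the current device
```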

def __repr__(self) -> str:
return (
@@ -80,7 +146,7 @@ def __repr__(self) -> str:


@dataclass
class InferenceConfig:
class InferenceConfig(RPC_PARAM):
"""The inference configuration.

Args:
@@ -193,10 +259,6 @@ def _verify_config(self) -> None:
if self.dtype == torch.float32:
self.high_precision = False

# check distributed
assert (not torch.distributed.is_initialized() and self.tp_size * self.pp_size == 1) or (
self.tp_size * self.pp_size == dist.get_world_size()
), f"TP size({self.tp_size}) * PP size({self.pp_size}) should be equal to the global world size ({dist.get_world_size()})"
# check prompt template
if self.prompt_template is None:
return
Expand Down Expand Up @@ -226,6 +288,43 @@ def to_generation_config(self, model_config) -> GenerationConfig:

return GenerationConfig.from_dict(meta_config)

def to_rpc_param(self) -> dict:
kwargs = {
"dtype": str(self.dtype).split(".")[-1],
"max_n_spec_tokens": self.max_n_spec_tokens,
"max_batch_size": self.max_batch_size,
"max_input_len": self.max_input_len,
"max_output_len": self.max_output_len,
"tp_size": self.tp_size,
"pp_size": self.pp_size,
"pad_input": self.pad_input,
"early_stopping": self.early_stopping,
"do_sample": self.do_sample,
"beam_width": self.beam_width,
"kv_cache_dtype": str(self.kv_cache_dtype).split(".")[-1],
}
return kwargs

@staticmethod
def from_rpc_param(rpc_dict: dict) -> "InferenceConfig":
"""
We intentionally avoid `dict.get` so that a missing RPC param raises an error instead of being silently replaced by a default.
"""
return InferenceConfig(
dtype=getattr(torch, rpc_dict["dtype"]),
max_n_spec_tokens=rpc_dict["max_n_spec_tokens"],
max_batch_size=rpc_dict["max_batch_size"],
max_input_len=rpc_dict["max_input_len"],
max_output_len=rpc_dict["max_output_len"],
tp_size=rpc_dict["tp_size"],
pp_size=rpc_dict["pp_size"],
pad_input=rpc_dict["pad_input"],
early_stopping=rpc_dict["early_stopping"],
do_sample=rpc_dict["do_sample"],
beam_width=rpc_dict["beam_width"],
kv_cache_dtype=getattr(torch, rpc_dict["kv_cache_dtype"], None),
)
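
The same round trip applies to `InferenceConfig`; dtypes are serialized as their string names and recovered with `getattr(torch, ...)`. A sketch, assuming the fields omitted from `to_rpc_param` keep their defaults:

```python
import torch

config = InferenceConfig(dtype=torch.float16, max_batch_size=8)
payload = config.to_rpc_param()                    # e.g. payload["dtype"] == "float16"
restored = InferenceConfig.from_rpc_param(payload)
assert restored.dtype is torch.float16
```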

@classmethod
def from_dict(cls, config_dict: Dict[str, Any]) -> "InferenceConfig":
# Get the list of attributes of this dataclass.
17 changes: 14 additions & 3 deletions colossalai/inference/core/engine.py
@@ -21,6 +21,7 @@
from colossalai.inference.config import InferenceConfig, InputMetaData
from colossalai.inference.graph_runner import CUDAGraphRunner
from colossalai.inference.modeling.policy import model_policy_map
from colossalai.inference.sampler import search_tokens
from colossalai.inference.spec import Drafter, GlideInput
from colossalai.inference.struct import Sequence
from colossalai.inference.utils import get_model_size, has_index_file
@@ -424,7 +425,7 @@ def steps_spec_dec(self) -> List[Sequence]:

# 2. Prefill main model (Verifier) - fill past kv cache for main model
logits = model_executable(input_token_ids, output_tensor, input_meta_data, self.k_cache, self.v_cache)
next_tokens = self.request_handler.search_tokens(self.generation_config, logits, batch)
next_tokens = search_tokens(self.generation_config, logits, batch_token_ids=batch.batch_token_ids)
# append new inputs to the batch, temporarily
batch.append_batch_tokens(next_tokens)
self.request_handler.allocate_batch_spec_dec(batch, 1)
@@ -472,7 +473,7 @@ def steps_spec_dec(self) -> List[Sequence]:
input_token_ids, output_tensor, input_meta_data = self.prepare_input(batch)
logits = model_executable(input_token_ids, output_tensor, input_meta_data, self.k_cache, self.v_cache)

next_tokens = self.request_handler.search_tokens(self.generation_config, logits, batch)
next_tokens = search_tokens(self.generation_config, logits, batch_token_ids=batch.batch_token_ids)

# 5. Compare and process the results
diff_indexes = torch.nonzero(~(next_tokens[:-1] == next_token_ids_spec))
@@ -689,6 +690,13 @@ def prepare_input(self, batch: BatchBucket) -> Tuple[torch.Tensor, torch.Tensor,
(n_tokens, batch.num_heads * batch.head_dim), dtype=batch.dtype, device=batch.device
)

batch_token_ids = None
config_dict = self.generation_config.to_dict()
# process repetition_penalty, no_repeat_ngram_size
for type in ["repetition_penalty", "no_repeat_ngram_size"]:
if type in config_dict and config_dict[type] is not None:
batch_token_ids = batch.batch_token_ids

# only when we have the graph for specific decoding batch size can we use the cuda graph for inference
use_cuda_graph = False
if self.use_cuda_graph and not batch.is_prompts and batch.current_batch_size in self.graph_runners.keys():
@@ -708,6 +716,7 @@ def prepare_input(self, batch: BatchBucket) -> Tuple[torch.Tensor, torch.Tensor,
dtype=batch.dtype,
use_spec_dec=batch.use_spec_dec,
num_tokens_to_verify=batch.num_tokens_to_verify,
batch_token_ids=batch_token_ids,
)

return input_ids, output_tensor, input_meta_data
@@ -738,7 +747,9 @@ def step(self) -> List[str]:
logits = model_executable(input_token_ids, output_tensor, input_meta_data, self.k_cache, self.v_cache)
if self.inference_config.pad_input:
logits = logits[:, -1, :]
next_tokens = self.request_handler.search_tokens(self.generation_config, logits, batch)
next_tokens = search_tokens(
self.generation_config, logits, input_meta_data.is_prompts, batch_token_ids=input_meta_data.batch_token_ids
)
self.request_handler.append_next_tokens(next_tokens)
finished_sequences = self.request_handler.update()

95 changes: 54 additions & 41 deletions colossalai/inference/core/request_handler.py
@@ -7,10 +7,11 @@
from colossalai.inference.batch_bucket import BatchBucket
from colossalai.inference.config import InferenceConfig
from colossalai.inference.flash_decoding_utils import FDIntermTensors
from colossalai.inference.kv_cache import KVCacheManager
from colossalai.inference.logit_processors import logit_processor
from colossalai.inference.sampler import *
from colossalai.inference.kv_cache import KVCacheManager, RPCKVCacheManager
from colossalai.inference.struct import RequestStatus, Sequence
from colossalai.logging import get_dist_logger

logger = get_dist_logger(__name__)

__all__ = ["RunningList", "RequestHandler"]

@@ -295,17 +296,6 @@ def _find_sequence(self, request_id: int) -> Sequence:

return None

def _sample(self, probs: torch.Tensor, logprobs: torch.Tensor, generation_config: GenerationConfig):
if generation_config.num_beams == 1:
if generation_config.do_sample:
sample_tokens = multinomial_sample(generation_config, probs)
else:
sample_tokens = greedy_sample(generation_config, logprobs)
else:
sample_tokens = beam_search_sample(generation_config, logprobs, is_prompt=not self.prefill_bb.is_empty)

return sample_tokens

def update_seq_finished(self, sequence: Sequence, generation_config: GenerationConfig):
if (
sequence.output_token_id[-1] == generation_config.eos_token_id
@@ -328,33 +318,6 @@ def check_unfinished_seqs(self) -> bool:
def total_requests_in_batch_bucket(self) -> int:
return self.prefill_bb.current_batch_size + self.running_bb.current_batch_size

def search_tokens(self, generation_config: GenerationConfig, logits, cur_batch: BatchBucket):
"""
Sample tokens for finished requests.
"""

# NOTE: need to decide the granularity to process logits (sequence or batch)
config_dict = generation_config.to_dict()
# process repetition_penalty, no_repeat_ngram_size
for type in ["repetition_penalty", "no_repeat_ngram_size"]:
if type in config_dict and config_dict[type] is not None:
logits = logit_processor(type, logits, config_dict[type], cur_batch)

# do logit processor
if generation_config.do_sample:
# process temperature, top_k, top_p
for type in ["temperature", "top_k", "top_p"]:
if type in config_dict and config_dict[type] is not None:
logits = logit_processor(type, logits, config_dict[type])

# calculate probs
probs = torch.softmax(logits, dim=-1, dtype=torch.float)
logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float)

# sample the next tokens
sample_tokens = self._sample(probs, logprobs, generation_config)
return sample_tokens

def append_next_tokens(self, sample_tokens: torch.Tensor):
assert sample_tokens.dim() == 1
n_elements = sample_tokens.size(0)
@@ -386,3 +349,53 @@ def update(self):
self.done_list.extend(finished_seqs)

return finished_seqs


class RPCRequestHandler(RequestHandler):
"""
RPC version of the request handler.
"""

def __init__(self, inference_config: InferenceConfig, model_config: PretrainedConfig) -> None:
self.inference_config = inference_config
self.running_list: RunningList = RunningList(inference_config.prefill_ratio)
self.waiting_list: List[List] = [[], [], []]
self.done_list: List[Sequence] = []
self.dtype = inference_config.dtype
self.max_batch_size = inference_config.max_batch_size

# initialize cache
self._init_cache(model_config)

# initialize batch
torch.cuda.current_device()
kv_max_split_num = (
inference_config.max_input_len + inference_config.max_output_len + inference_config.block_size - 1
) // inference_config.block_size
head_dim = model_config.hidden_size // model_config.num_attention_heads

# TODO: In the continuous batching scenario, the batch size may exceed max_batch_size,
# which can cause bugs; this should be fixed later.
self.running_bb = BatchBucket(
num_heads=model_config.num_attention_heads // inference_config.tp_size,
head_dim=head_dim,
max_batch_size=self.max_batch_size,
max_length=inference_config.max_input_len + inference_config.max_output_len,
block_size=inference_config.block_size,
kv_max_split_num=kv_max_split_num,
fd_interm_tensor=None,
dtype=self.dtype,
)
self.prefill_bb = BatchBucket(
num_heads=model_config.num_attention_heads // inference_config.tp_size,
head_dim=head_dim,
max_batch_size=self.max_batch_size,
max_length=inference_config.max_input_len + inference_config.max_output_len,
block_size=inference_config.block_size,
kv_max_split_num=kv_max_split_num,
fd_interm_tensor=None,
dtype=self.dtype,
)

def _init_cache(self, model_config):
self.cache_manager = RPCKVCacheManager(self.inference_config, model_config)