Skip to content

Commit 9e2342b

Browse files
authored
[Hotfix] Fix bugs in testing continuous batching (#5270)
* fix bug * fix bugs * fix bugs * fix bugs and add padding * add funcs and fix bugs * fix typos * fix bugs * add func
1 parent 5ae9099 commit 9e2342b

File tree

6 files changed

+86
-23
lines changed

6 files changed

+86
-23
lines changed

colossalai/inference/core/request_handler.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,9 @@ def ready_for_prefill(self):
5757
def is_empty(self):
    """
    Report whether this running list currently holds no sequence at all.

    Returns:
        True when both `self.decoding` and `self.prefill` are empty (falsy),
        False otherwise.
    """
    return not self.decoding and not self.prefill
5959

60+
def total_seq_num(self) -> int:
    """
    Count every sequence tracked by this running list.

    Returns:
        The sum of the lengths of `self.decoding` and `self.prefill`.
    """
    return len(self.decoding) + len(self.prefill)
62+
6063

6164
class RequestHandler:
6265
"""
@@ -105,7 +108,13 @@ def schedule(self):
105108
f"the prompt(Request id = {seq.request_id}) length is longer than max_input_len, abort this sequence."
106109
)
107110
self.abort_sequence(seq.request_id)
111+
remove_list.append(seq)
108112
break
113+
114+
# Stop feeding new sequences into the running list once the available cache blocks no longer exceed the number of running sequences.
115+
if self.cache_manager.num_available_blocks <= self.running_list.total_seq_num():
116+
break
117+
109118
# Try to allocate cache blocks for the sequence.
110119
if (
111120
self.cache_manager.check_allocation(seq)
@@ -115,7 +124,7 @@ def schedule(self):
115124
# If succeed, add the sequence to running list.
116125
remove_list.append(seq)
117126
self.running_list.append(seq)
118-
self.cache_manager.allocate_context_from_block_table(seq.block_table, seq.input_len)
127+
self.cache_manager.allocate_context_from_block_table(seq.block_table, seq.sentence_len)
119128
for seq in remove_list:
120129
lst.remove(seq)
121130
if self.running_list.ready_for_prefill():
@@ -126,7 +135,13 @@ def schedule(self):
126135

127136
if not self.running_batch.is_empty:
128137
for seq in self.running_batch.sequences_set:
129-
self.cache_manager.allocate_token_from_block_table(seq.block_table, seq.sentence_len)
138+
recycle = self.cache_manager.allocate_token_from_block_table(seq.block_table, seq.sentence_len)
139+
if recycle:
140+
seq.recycle()
141+
self.running_batch.del_seq(seq)
142+
self.running_list.remove(seq)
143+
self.waiting_list[-1].append(seq)
144+
# the recycled sequences are handled with highest priority.
130145

131146
return self.running_batch
132147

colossalai/inference/modeling/layers/attention.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ def convert_kvcache(cache, lengths, block_tables, pad_id=0):
6969
)
7070
padding = seq_len - _cache.size(0)
7171
if padding > 0:
72-
_cache = F.pad(_cache, (0, 0, 0, 0, 0, 1), value=pad_id)
72+
_cache = F.pad(_cache, (0, 0, 0, 0, 0, padding), value=pad_id)
7373
padded_cache.append(_cache)
7474
return torch.stack(padded_cache, dim=0)
7575

colossalai/inference/modeling/models/llama.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -173,7 +173,10 @@ def llama_attn_forward(
173173
key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
174174
value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
175175

176+
kv_seq_len = max(sequence_lengths).item()
177+
176178
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
179+
177180
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
178181

179182
query_states = query_states.transpose(1, 2)

colossalai/inference/struct.py

Lines changed: 58 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,9 @@ class RequestStatus(enum.Enum):
2929
COMPLETED = enum.auto()
3030
LENGTH_CAPPED = enum.auto()
3131

32+
# recycle status
33+
RECYCLED = enum.auto()
34+
3235
@staticmethod
3336
def is_finished(status: "RequestStatus") -> bool:
3437
return status in [
def mark_running(self) -> None:
    """
    Set status for prefill reqs.

    The sequence must currently be in WAITING or RECYCLED status; it is then
    switched to RUNNING.

    Raises:
        AssertionError: if the current status is neither WAITING nor RECYCLED.
    """
    # Use a membership test: `status == WAITING or RECYCLED` evaluates the bare
    # enum member as the right-hand operand, which is always truthy, so that
    # form of the check could never fail.
    assert self.status in (
        RequestStatus.WAITING,
        RequestStatus.RECYCLED,
    ), "Sequence is not in WAITTING/RECYCLED STATUS"
    self.status = RequestStatus.RUNNING
124129

125130
def mark_finished(self) -> None:
@@ -139,10 +144,10 @@ def recycle(self) -> None:
139144
Recycle a running sequence to the waiting list
140145
"""
141146
assert (
142-
not self.status.is_finished and not self.status == RequestStatus.ABORTED
147+
not self.check_finish() and not self.status == RequestStatus.ABORTED
143148
), "The running sequence \
144149
is already done but it still in running list"
145-
self.status = RequestStatus.WAITING
150+
self.status = RequestStatus.RECYCLED
146151

147152
def __repr__(self) -> str:
148153
return (
@@ -162,7 +167,7 @@ class BatchInfo:
162167
Information to be passed and used for a batch of sequences.
163168
"""
164169

165-
sequences_set: OrderedSet["Sequence"] = None
170+
sequences_set: OrderedSet[Sequence] = None
166171
is_prompts: bool = True
167172
device: torch.device = None
168173

@@ -207,12 +212,20 @@ def get_block_table_tensor(self) -> None:
207212

208213
def clear_batch(self) -> None:
    """
    Empty the sequence set of this batch, used when the batch is dropped.

    Prefill batch (`is_prompts` is True): the sequence set is simply cleared.
    Decoding batch: every sequence is marked aborted, and any sequence whose
    `check_finish()` then reports True is marked finished, before the set is
    cleared.
    """
    if not self.is_prompts:
        for seq in self.sequences_set:
            # NOTE(review): the abort is applied unconditionally and the finish
            # check runs afterwards. The earlier docstring said only unfinished
            # sequences are aborted -- confirm this order is intended.
            seq.mark_aborted()
            if seq.check_finish():
                seq.mark_finished()

    # Both the prefill and the decoding path end by emptying the set.
    self.sequences_set.clear()
216229

217230
def fliter_batch(self) -> List["Sequence"]:
218231
"""
@@ -255,6 +268,12 @@ def add_seqs(self, seqs: List["Sequence"]) -> None:
255268
continue
256269
self.sequences_set.add(seq)
257270

271+
def del_seq(self, seq: "Sequence") -> None:
    """
    Delete sequence in batch.

    Args:
        seq: The sequence to drop from `sequences_set`. `discard` is used, so
            a sequence that is not in the set is ignored rather than raising.
    """
    self.sequences_set.discard(seq)
276+
258277
@property
259278
def is_empty(self) -> None:
260279
"""
@@ -297,11 +316,19 @@ def get_batch_inputs(self) -> torch.LongTensor:
297316

298317
for seq in self.sequences_set:
299318
if self.is_prompts:
300-
input_list.append(seq.input_token_id)
319+
if seq.output_len > 0:
320+
print(seq.output_token_id)
321+
seq_data = seq.input_token_id + seq.output_token_id
322+
print(seq_data)
323+
input_list.append(seq.input_token_id + seq.output_token_id)
324+
else:
325+
input_list.append(seq.input_token_id)
301326
else:
302327
input_list.append([seq.output_token_id[-1]])
303328

304-
return torch.tensor(input_list, dtype=torch.long, device=self.device)
329+
max_seq_len = max(len(sub_list) for sub_list in input_list)
330+
331+
return _make_tensor_with_pad(input_list, max_seq_len, 0, dtype=torch.int)
305332

306333
def get_1D_inputs(self) -> Tuple[torch.LongTensor, torch.Tensor]:
307334
"""
@@ -340,12 +367,27 @@ def get_attn_mask(self, padding_id: int) -> torch.Tensor:
340367
for seq in self.sequences_set:
341368
past_values.append(seq.input_token_id + seq.output_token_id)
342369

343-
attn_mask = torch.tensor(past_values, dtype=torch.int, device=self.device).ne(padding_id).long()
370+
max_seq_len = max(len(sub_list) for sub_list in past_values)
371+
attn_mask = _make_tensor_with_pad(past_values, max_seq_len, 0, dtype=torch.int, device=self.device)
344372

345-
if torch.any(attn_mask == 0):
346-
return attn_mask
347-
else:
348-
return None
373+
return attn_mask.ne(padding_id).long()
349374

350375
def __repr__(self) -> str:
351376
return f"(sequences_set={self.sequences_set}, " f"is_prompts={self.is_prompts})"
377+
378+
379+
def _pad_to_max(x: List[int], max_len: int, pad: int) -> List[int]:
380+
assert len(x) <= max_len
381+
return x + [pad] * (max_len - len(x))
382+
383+
384+
def _make_tensor_with_pad(
    x: Union[List[List[int]], List[int]],
    max_len: int,
    pad: int,
    dtype: torch.dtype,
    device: Union[str, torch.device] = "cuda",
    pin_memory: bool = False,
) -> torch.Tensor:
    """
    Pad every row of `x` to `max_len` with `pad` and build a tensor from the result.

    Args:
        x: Rows to pad. Each element is handed to `_pad_to_max`, so each is
            expected to be a list of ints.
            NOTE(review): the annotation also admits a flat `List[int]`, which
            this body does not handle -- confirm whether that case is needed.
        max_len: Length every row is padded to.
        pad: Padding value.
        dtype: dtype of the resulting tensor.
        device: Device the tensor is created on (defaults to "cuda").
        pin_memory: Request pinned memory; only honoured when `device` is "cpu".

    Returns:
        A tensor built from the padded rows.
    """
    padded_x = [_pad_to_max(x_i, max_len, pad) for x_i in x]
    # The pin_memory flag is forwarded only when the target device is the CPU.
    return torch.tensor(padded_x, dtype=dtype, device=device, pin_memory=pin_memory and str(device) == "cpu")

examples/inference/benchmark_llama.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -95,11 +95,10 @@ def benchmark_inference(args):
9595

9696
if args.dtype == "fp16":
9797
model = model.half()
98-
elif args.dtype == "bf16":
98+
elif args.dtype == "fp16":
9999
model = model.to(torch.bfloat16)
100100

101-
# mbsz = args.mbsz
102-
mbsz = args.batch_size
101+
mbsz = args.mbsz
103102
if args.mode == "caiinference":
104103
inference_config = InferenceConfig(
105104
dtype=args.dtype,

tests/test_infer/test_config_and_struct.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import colossalai
44
from colossalai.inference.config import InferenceConfig
5-
from colossalai.inference.struct import BatchInfo, Sequence
5+
from colossalai.inference.struct import BatchInfo, RequestStatus, Sequence
66
from colossalai.testing import rerun_if_address_is_in_use, spawn
77

88

@@ -41,6 +41,10 @@ def check_config_and_inference():
4141
eos_token_id=2,
4242
max_output_len=256,
4343
)
44+
sequence.mark_running()
45+
assert sequence.status == RequestStatus.RUNNING
46+
sequence.recycle()
47+
assert sequence.status == RequestStatus.RECYCLED
4448

4549
assert sequence.sentence_len == 3
4650
assert sequence.input_len == 3

0 commit comments

Comments
 (0)