
Commit 55cc7f3

[Fix] Fix Inference Example, Tests, and Requirements (#5688)

* clean requirements
* modify example inference struct
* add test ci scripts
* mark test_infer as submodule
* rm deprecated cls & deps
* import of HAS_FLASH_ATTN
* prune inference tests to be run
* prune triton kernel tests
* increment pytest timeout mins
* revert import path in openmoe
1 parent f9afe0a commit 55cc7f3

23 files changed: +46 −328 lines


.github/workflows/build_on_pr.yml

Lines changed: 1 addition & 1 deletion

@@ -91,7 +91,7 @@ jobs:
     container:
       image: hpcaitech/pytorch-cuda:2.1.0-12.1.0
       options: --gpus all --rm -v /dev/shm -v /data/scratch/llama-tiny:/data/scratch/llama-tiny
-    timeout-minutes: 60
+    timeout-minutes: 75
     defaults:
       run:
         shell: bash

colossalai/inference/README.md

Lines changed: 1 addition & 1 deletion

@@ -81,7 +81,7 @@ import colossalai
 from colossalai.inference import InferenceEngine, InferenceConfig
 from pprint import pprint

-colossalai.launch_from_torch(config={})
+colossalai.launch_from_torch()

 # Step 1: create a model in "transformers" way
 model_path = "lmsys/vicuna-7b-v1.3"
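
For context, here is a minimal sketch of how the updated README snippet fits together after this change. The engine construction follows the class names imported above; the exact constructor arguments and the use of transformers' Auto classes are illustrative assumptions, not a verbatim copy of the docs.

import colossalai
from colossalai.inference import InferenceEngine, InferenceConfig
from transformers import AutoModelForCausalLM, AutoTokenizer

# Set up the distributed environment; the config dict argument is no longer passed.
colossalai.launch_from_torch()

# Step 1: create a model in the "transformers" way (model path taken from the README).
model_path = "lmsys/vicuna-7b-v1.3"
model = AutoModelForCausalLM.from_pretrained(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path)

# Step 2: wrap it with the inference engine. The (model, tokenizer, config)
# argument order is an assumption based on the imported class names.
inference_config = InferenceConfig()
engine = InferenceEngine(model, tokenizer, inference_config)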

colossalai/inference/spec/README.md

Lines changed: 1 addition & 1 deletion

@@ -23,7 +23,7 @@ from colossalai.inference.core.engine import InferenceEngine, GenerationConfig
 from colossalai.inference.modeling.models.glide_llama import GlideLlamaForCausalLM, GlideLlamaConfig

 # launch colossalai, setup distributed environment
-colossalai.launch_from_torch(config={})
+colossalai.launch_from_torch()

 # main model
 model_path_or_name = "REPLACE_TO_VICUNA_7B_PATH_OR_MODEL_CARD"

colossalai/inference/struct.py

Lines changed: 1 addition & 241 deletions

@@ -1,11 +1,7 @@
 import enum
 from dataclasses import dataclass
-from typing import Any, List, Tuple, Union
+from typing import Any, List

-import torch
-from ordered_set import OrderedSet
-
-from colossalai.inference.flash_decoding_utils import FDIntermTensors
 from colossalai.logging import get_dist_logger

 logger = get_dist_logger(__name__)
@@ -170,242 +166,6 @@ def __repr__(self) -> str:
         )


-@dataclass
-class BatchInfo:
-    """
-    Information to be passed and used for a batch of sequences.
-    """
-
-    max_batch_size: int
-    kv_max_split_num: int
-    num_heads: int
-    head_dim: int
-    sequences_set: OrderedSet[Sequence] = None
-    is_prompts: bool = True
-    device: torch.device = None
-    dtype: torch.dtype = None
-    fd_inter_tensor: FDIntermTensors = None
-
-    def __post_init__(self):
-        if self.device is None:
-            self.device = torch.cuda.current_device()
-        if self.sequences_set is None:
-            self.sequences_set = OrderedSet()
-        if self.fd_inter_tensor is None:
-            self.fd_inter_tensor = FDIntermTensors()
-
-    def init_fd_tensors(self):
-        if not self.fd_inter_tensor.is_initialized:
-            self.fd_inter_tensor.initialize(
-                max_batch_size=self.max_batch_size,
-                num_attn_heads=self.num_heads,
-                kv_max_split_num=self.kv_max_split_num,
-                head_dim=self.head_dim,
-                dtype=self.dtype,
-                device=self.device,
-            )
-
-    def get_block_table_tensor(self) -> None:
-        tesnor_list = []
-        block_table = None
-
-        assert len(self.sequences_set) > 0, "Batch has not been initialized yet. Please initialize batch first."
-
-        for seq in self.sequences_set:
-            block_table = seq.block_table
-            assert (
-                block_table is not None
-            ), f"The sequence(request_id {seq.request_id}) has not initialized the block_table."
-            tesnor_list.append(seq.block_table)
-
-        block_table = torch.stack(tesnor_list)
-        return block_table
-
-    def clear_batch(self) -> None:
-        """
-        Clear sequence set and block table if we need to abort this batch.
-        Prefill: clear sequence set and move them to running batch(external)
-        Decoding: mark unfinished sequences as aborted.
-        """
-        if self.is_prompts:
-            self.sequences_set.clear()
-        else:
-            for seq in self.sequences_set:
-                seq.mark_aborted()
-                if seq.check_finish():
-                    seq.mark_finished()
-
-            self.sequences_set.clear()
-
-    def fliter_batch(self) -> List["Sequence"]:
-        """
-        Remove completed sentences from a batch.
-
-        Returns:
-            List["Sequence"]: List of finished sequences.
-        """
-        finish_seqs = []
-        for seq in self.sequences_set:
-            if seq.check_finish():
-                finish_seqs.append(seq)
-        for finish_seq in finish_seqs:
-            self.sequences_set.discard(finish_seq)
-        return finish_seqs
-
-    def abort_seq(self, seq: "Sequence") -> "Sequence":
-        """
-        Remove sequence from the batch.
-        """
-        if not seq.check_finish():
-            seq.status = RequestStatus.ABORTED
-        self.sequences_set.discard(seq)
-        return seq
-
-    def add_seqs(self, seqs: Union[Sequence, List[Sequence]]) -> None:
-        """
-        Add new sequence to batch
-
-        Args:
-            seqs (List["Sequence"]): The list of new sequences.
-        """
-        # covnert single sequence to list
-        if isinstance(seqs, Sequence):
-            seqs = [seqs]
-
-        for seq in seqs:
-            if seq in self.sequences_set:
-                logger.warning(f"The sequence(request_id {seq.request_id}) is already in sequences_set.")
-                continue
-            self.sequences_set.add(seq)
-
-    def del_seq(self, seq: Sequence) -> Sequence:
-        """
-        Delete sequence in batch
-        """
-        self.sequences_set.discard(seq)
-
-    @property
-    def is_empty(self) -> None:
-        """
-        Check whether sequences_set is empty.
-        """
-        return not self.sequences_set
-
-    def update_batch_tokens(self, tokens: Union[List[int], List[List[int]], torch.Tensor]) -> None:
-        """
-        Add an output token for each sentence in the batch.
-
-        Args:
-            tokens (List[int]): A batch of tokens
-        """
-
-        if isinstance(tokens, torch.Tensor):
-            tokens = tokens.tolist()
-
-        assert self.get_batch_size() == len(tokens), "The number of tokens does not match batch_size."
-
-        for seq, token in zip(self.sequences_set, tokens):
-            if not isinstance(token, list):
-                if not isinstance(token, int):
-                    raise TypeError(f"The token type must be List[int] or int, but got {type(token)}.")
-                token = [token]
-            seq.output_token_id += token
-            seq.check_finish()
-
-    def get_batch_size(self) -> int:
-        """
-        Get batch_size of this batch
-        """
-        return len(self.sequences_set)
-
-    def get_batch_inputs(self) -> torch.LongTensor:
-        """
-        Get bacth inputs for forward inference computation.
-        """
-
-        input_list = []
-
-        assert len(self.sequences_set) > 0, "Batch has not been initialized yet. Please initialize batch first."
-
-        for seq in self.sequences_set:
-            if self.is_prompts:
-                if seq.output_len > 0:
-                    input_list.append(seq.input_token_id + seq.output_token_id)
-                else:
-                    input_list.append(seq.input_token_id)
-            else:
-                input_list.append([seq.output_token_id[-1]])
-
-        max_seq_len = max(len(sub_list) for sub_list in input_list)
-
-        # We assume that all the padding_id in seq are the same at present.
-        return _make_tensor_with_pad(input_list, max_seq_len, self.sequences_set[0].pad_token_id, dtype=torch.int)
-
-    def get_1D_inputs(self) -> Tuple[torch.LongTensor, torch.Tensor]:
-        """
-        Flattening the input tokens.
-        """
-        input_list = []
-
-        assert len(self.sequences_set) > 0, "Batch has not been initialized yet. Please initialize batch first."
-
-        for seq in self.sequences_set:
-            if self.is_prompts:
-                input_list.extend(seq.input_token_id)
-            else:
-                input_list.append(seq.output_token_id[-1])
-
-        return torch.tensor(input_list, dtype=torch.long, device=self.device)
-
-    def get_sequence_lengths(self):
-        """
-        Get the input_len of each sentence in this batch.
-        """
-        len_list = []
-
-        assert len(self.sequences_set) > 0, "Batch has not been initialized yet. Please initialize batch first."
-
-        for seq in self.sequences_set:
-            len_list.append(seq.sentence_len)
-
-        return torch.tensor(len_list, dtype=torch.int, device=self.device)
-
-    def get_attn_mask(self) -> torch.Tensor:
-        """
-        Generate and return attention mask.
-        """
-        assert len(self.sequences_set) > 0, "Batch has not been initialized yet. Please initialize batch first."
-
-        past_values = []
-        # We assume that all the padding_id in seq are the same at present.
-        padding_id = self.sequences_set[0].pad_token_id
-
-        for seq in self.sequences_set:
-            past_values.append(seq.input_token_id + seq.output_token_id)
-
-        max_seq_len = max(len(sub_list) for sub_list in past_values)
-        attn_mask = _make_tensor_with_pad(
-            past_values, max_seq_len, self.sequences_set[0].pad_token_id, dtype=torch.int, device=self.device
-        )
-
-        return attn_mask.ne(padding_id).long()
-
-    def __repr__(self) -> str:
-        return f"(sequences_set={self.sequences_set}, " f"is_prompts={self.is_prompts})"
-
-
 def _pad_to_max(x: List[int], max_len: int, pad: int) -> List[int]:
     assert len(x) <= max_len
     return [pad] * (max_len - len(x)) + x
-
-
-def _make_tensor_with_pad(
-    x: Union[List[List[int]], List[int]],
-    max_len: int,
-    pad: int,
-    dtype: torch.dtype,
-    device: Union[str, torch.device] = "cuda",
-    pin_memory: bool = False,
-):
-    padded_x = [_pad_to_max(x_i, max_len, pad) for x_i in x]
-    return torch.tensor(padded_x, dtype=dtype, device=device, pin_memory=pin_memory and str(device) == "cpu")
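
The only helper kept in struct.py after this change, _pad_to_max, left-pads a token list so that all sequences in a batch share the same length. A small standalone illustration follows; the example values are made up and the function is restated here without type hints.

def _pad_to_max(x, max_len, pad):
    # Same behavior as the kept helper in struct.py:
    # left-pad `x` with `pad` until it reaches `max_len`.
    assert len(x) <= max_len
    return [pad] * (max_len - len(x)) + x

# Pad a 3-token sequence to length 5 with pad id 0 (illustrative values only).
print(_pad_to_max([11, 22, 33], max_len=5, pad=0))  # -> [0, 0, 11, 22, 33]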

examples/inference/benchmark_ops/test_ci.sh

Whitespace-only changes.

examples/inference/benchmark_llama3.py renamed to examples/inference/llama/benchmark_llama3.py

Lines changed: 1 addition & 1 deletion

@@ -182,7 +182,7 @@ def benchmark_inference(args):


 def inference(rank, world_size, port, args):
-    colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
+    colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
     benchmark_inference(args)

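For reference, a per-rank entry point like inference(rank, world_size, port, args) above is typically driven with torch.multiprocessing.spawn, which passes the process index as the first positional argument. The wrapper below is a minimal sketch under that assumption; it presumes the inference function from the diff is in scope, and the function name main and the default port are illustrative rather than taken from the benchmark script.

import torch.multiprocessing as mp

def main(args, world_size: int = 1, port: int = 29500):
    # spawn() calls inference(rank, world_size, port, args) once per process,
    # supplying the process index (the rank) as the first positional argument.
    mp.spawn(inference, args=(world_size, port, args), nprocs=world_size)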
examples/inference/llama_generation.py renamed to examples/inference/llama/llama_generation.py

Lines changed: 2 additions & 2 deletions

@@ -17,7 +17,7 @@ def infer(args):
     # ==============================
     # Launch colossalai, setup distributed environment
     # ==============================
-    colossalai.launch_from_torch(config={})
+    colossalai.launch_from_torch()
     coordinator = DistCoordinator()

     # ==============================
@@ -59,7 +59,7 @@ def infer(args):
     coordinator.print_on_master(out[0])


-# colossalai run --nproc_per_node 1 llama_gen.py -m MODEL_PATH
+# colossalai run --nproc_per_node 1 llama_generation.py -m MODEL_PATH
 if __name__ == "__main__":
     # ==============================
     # Parse Arguments
File renamed without changes.
Lines changed: 4 additions & 0 deletions

@@ -0,0 +1,4 @@
+#!/bin/bash
+echo "Skip the test (this test is slow)"
+
+# bash ./run_benchmark.sh
