|
1 | 1 | # -*- coding: utf-8 -*-
|
2 | 2 | # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
|
3 | 3 |
|
| 4 | +from typing import Optional |
| 5 | + |
4 | 6 | import torch
|
5 | 7 | import torch.nn.functional as F
|
6 | 8 | import triton
|
@@ -43,8 +45,36 @@ def prepare_lens_from_mask(mask: torch.BoolTensor) -> torch.LongTensor:
|
43 | 45 |
|
44 | 46 |
|
@tensor_cache
def prepare_cu_seqlens_from_mask(
    mask: torch.BoolTensor,
    dtype: Optional[torch.dtype] = torch.int32
) -> torch.LongTensor:
    """Build cumulative sequence lengths (``cu_seqlens``) from a padding mask.

    Args:
        mask: boolean attention/padding mask; per-sequence lengths are
            derived via ``prepare_lens_from_mask``.
        dtype: dtype of the accumulated offsets (``None`` falls back to
            torch's default accumulation dtype).

    Returns:
        A 1-D tensor of length ``batch + 1`` whose consecutive pairs
        ``cu[i]:cu[i+1]`` delimit sequence ``i`` in the flattened batch.
    """
    lens = prepare_lens_from_mask(mask)
    offsets = lens.cumsum(dim=0, dtype=dtype)
    # Prepend a zero so the result starts at offset 0.
    return F.pad(offsets, (1, 0))
| 53 | + |
| 54 | + |
@tensor_cache
def prepare_split_cu_seqlens(
    batch_size: int,
    seq_len: int,
    split_size: int,
    cu_seqlens: Optional[torch.LongTensor] = None,
    dtype: Optional[torch.dtype] = torch.int32,
    device: Optional[torch.device] = torch.device('cpu')
) -> torch.LongTensor:
    """Subdivide sequence boundaries into chunks of at most ``split_size`` tokens.

    Args:
        batch_size: number of sequences (used only when ``cu_seqlens`` is ``None``).
        seq_len: length of each sequence (used only when ``cu_seqlens`` is ``None``).
        split_size: stride at which extra boundaries are inserted inside
            each sequence; the last chunk of a sequence may be shorter.
        cu_seqlens: optional precomputed cumulative sequence lengths of a
            varlen batch; when omitted, an equal-length batch is assumed.
        dtype: dtype of the returned boundary tensor.
        device: device of the returned boundary tensor.

    Returns:
        A 1-D tensor of split boundaries: every ``bos + k * split_size``
        inside each ``[bos, eos)`` span, terminated by the final offset.
    """
    if cu_seqlens is None:
        # Equal-length batch: boundaries fall at multiples of seq_len.
        total_tokens = batch_size * seq_len
        bounds = list(range(0, total_tokens, seq_len)) + [total_tokens]
    else:
        bounds = cu_seqlens.tolist()

    splits = []
    for bos, eos in zip(bounds[:-1], bounds[1:]):
        splits.extend(range(bos, eos, split_size))
    # The loop only emits chunk starts; close with the final offset.
    splits.append(bounds[-1])
    return torch.tensor(splits, dtype=dtype, device=device)
48 | 78 |
|
49 | 79 |
|
50 | 80 | @tensor_cache
|
|
0 commit comments