Skip to content

Commit 4360363

Browse files
committed
[Utils] Add length preparation fn with split_size option
1 parent f8e89a1 commit 4360363

File tree

1 file changed

+32
-2
lines changed

1 file changed

+32
-2
lines changed

fla/ops/utils/index.py

Lines changed: 32 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
# -*- coding: utf-8 -*-
22
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
33

4+
from typing import Optional
5+
46
import torch
57
import torch.nn.functional as F
68
import triton
@@ -43,8 +45,36 @@ def prepare_lens_from_mask(mask: torch.BoolTensor) -> torch.LongTensor:
4345

4446

4547
@tensor_cache
def prepare_cu_seqlens_from_mask(
    mask: torch.BoolTensor,
    dtype: Optional[torch.dtype] = torch.int32
) -> torch.LongTensor:
    """Build cumulative sequence-length offsets from a padding mask.

    Args:
        mask: boolean padding mask; per-sequence lengths are derived from it
            via ``prepare_lens_from_mask``.
        dtype: integer dtype of the returned offsets (defaults to ``int32``).

    Returns:
        A 1-D tensor of length ``num_seqs + 1`` where entry ``i`` is the start
        offset of sequence ``i`` and the last entry is the total token count.
    """
    lens = prepare_lens_from_mask(mask)
    running = lens.cumsum(dim=0, dtype=dtype)
    # Prepend a leading zero so the result is a proper cu_seqlens vector.
    return F.pad(running, (1, 0))
53+
54+
55+
@tensor_cache
def prepare_split_cu_seqlens(
    batch_size: int,
    seq_len: int,
    split_size: int,
    cu_seqlens: Optional[torch.LongTensor] = None,
    dtype: Optional[torch.dtype] = torch.int32,
    device: Optional[torch.device] = torch.device('cpu')
) -> torch.LongTensor:
    """Subdivide sequence boundaries into chunks of at most ``split_size`` tokens.

    Args:
        batch_size: number of sequences; only consulted when ``cu_seqlens`` is
            ``None`` (equal-length batch).
        seq_len: tokens per sequence in the equal-length case.
        split_size: maximum chunk length; each ``[bos, eos)`` interval is cut
            at every ``split_size`` tokens.
        cu_seqlens: optional precomputed cumulative sequence boundaries; when
            given, ``batch_size`` and ``seq_len`` are ignored.
        dtype: integer dtype of the returned tensor (defaults to ``int32``).
        device: device of the returned tensor (defaults to CPU).

    Returns:
        A 1-D tensor of split boundaries, ending with the total token count.
    """
    if cu_seqlens is None:
        # Equal-length batch: a boundary falls every `seq_len` tokens.
        n_tokens = batch_size * seq_len
        boundaries = list(range(0, n_tokens, seq_len)) + [n_tokens]
    else:
        boundaries = cu_seqlens.tolist()
    split_offsets = []
    for bos, eos in zip(boundaries[:-1], boundaries[1:]):
        # Cut each sequence interval into `split_size`-sized chunks.
        split_offsets.extend(range(bos, eos, split_size))
    # Close the final interval with the total token count.
    split_offsets.append(boundaries[-1])
    return torch.tensor(split_offsets, dtype=dtype, device=device)
4878

4979

5080
@tensor_cache

0 commit comments

Comments
 (0)