Skip to content

Commit 671d97f

Browse files
committed
[API] add default value for freeSeqs.
1 parent 1742991 commit 671d97f

File tree

1 file changed

+4
-1
lines changed

1 file changed

+4
-1
lines changed

src/xfastertransformer/automodel.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,10 @@ def set_input_cb(
8181
def forward_cb(self):
8282
return self.model.forward_cb()
8383

84-
def free_seqs(self, seq_ids):
84+
def free_seqs(self, seq_ids: Optional[Union[List[int], torch.Tensor]] = None):
85+
if isinstance(seq_ids, list):
86+
seq_ids = torch.tensor(seq_ids, dtype=torch.int64)
87+
8588
return self.model.free_seqs(seq_ids)
8689

8790
@classmethod

0 commit comments

Comments
 (0)