Skip to content

Commit

Permalink
bugfix: casting int array to int32 for rope input arguments (#697)
Browse files Browse the repository at this point in the history
To avoid potential bugs when users pass a LongTensor to the RoPE APIs.
Also removes some files that are not used in the current codebase.

Fixed a bug in AOT mode where `apply_rope_pos_ids_cos_sin_cache` was not
registered in pybind.
  • Loading branch information
yzh119 authored Dec 25, 2024
1 parent 398cd2b commit d158717
Show file tree
Hide file tree
Showing 4 changed files with 9 additions and 263 deletions.
192 changes: 0 additions & 192 deletions csrc/dispatch_type_code.h

This file was deleted.

71 changes: 0 additions & 71 deletions csrc/dispatch_utils.h

This file was deleted.

2 changes: 2 additions & 0 deletions csrc/flashinfer_ops.cu
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("apply_rope_pos_ids", &apply_rope_pos_ids, "Apply RoPE with positional ids");
m.def("apply_llama31_rope_pos_ids", &apply_llama31_rope_pos_ids,
"Apply Llama 3.1 style RoPE with positional ids");
m.def("apply_rope_pos_ids_cos_sin_cache", &apply_rope_pos_ids_cos_sin_cache,
"Apply RoPE with positional ids and cosine/sine cache");

// sampling
m.def("sampling_from_probs", &sampling_from_probs, "Sample from probabilities");
Expand Down
7 changes: 7 additions & 0 deletions flashinfer/rope.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,8 @@ def _apply_rope(
rope_theta: float,
) -> None:
with q.device as device:
indptr = indptr.int()
offsets = offsets.int()
get_rope_module().apply_rope(
q,
k,
Expand Down Expand Up @@ -104,6 +106,8 @@ def _apply_llama31_rope(
old_context_len: float,
) -> None:
with q.device as device:
indptr = indptr.int()
offsets = offsets.int()
get_rope_module().apply_llama31_rope(
q,
k,
Expand Down Expand Up @@ -154,6 +158,7 @@ def _apply_rope_pos_ids(
rope_theta: float,
) -> None:
with q.device as device:
pos_ids = pos_ids.int()
get_rope_module().apply_rope_pos_ids(
q,
k,
Expand Down Expand Up @@ -197,6 +202,7 @@ def _apply_rope_pos_ids_cos_sin_cache(
interleave: bool,
) -> None:
with q.device as device:
pos_ids = pos_ids.int()
get_rope_module().apply_rope_pos_ids_cos_sin_cache(
q,
k,
Expand Down Expand Up @@ -242,6 +248,7 @@ def _apply_llama31_rope_pos_ids(
old_context_len: float,
) -> None:
with q.device as device:
pos_ids = pos_ids.int()
get_rope_module().apply_llama31_rope_pos_ids(
q,
k,
Expand Down

0 comments on commit d158717

Please sign in to comment.