@@ -61,10 +61,11 @@ torch::Tensor top_k_renorm_probs(torch::Tensor probs, std::optional<torch::Tenso
6161torch::Tensor top_k_mask_logits (torch::Tensor logits, std::optional<torch::Tensor> maybe_top_k_arr,
6262 unsigned int top_k_val);
6363
64- torch::Tensor chain_speculative_sampling (
65- torch::Tensor draft_probs, torch::Tensor draft_token_ids, torch::Tensor uniform_samples,
66- torch::Tensor target_probs, torch::Tensor output_accepted_token_num,
67- torch::Tensor output_emitted_token_num, bool deterministic);
64+ torch::Tensor chain_speculative_sampling (torch::Tensor draft_probs, torch::Tensor draft_token_ids,
65+ torch::Tensor uniform_samples, torch::Tensor target_probs,
66+ torch::Tensor output_accepted_token_num,
67+ torch::Tensor output_emitted_token_num,
68+ bool deterministic);
6869
6970void rmsnorm (torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight, double eps);
7071
@@ -82,24 +83,30 @@ void gelu_tanh_and_mul(torch::Tensor& out, torch::Tensor& input);
8283
8384void gelu_and_mul (torch::Tensor& out, torch::Tensor& input);
8485
85- void apply_rope_inplace (torch::Tensor q, torch::Tensor k, torch::Tensor indptr,
86- torch::Tensor offsets, bool interleave, float rope_scale, float rope_theta);
87-
88- void apply_llama31_rope_inplace (torch::Tensor q, torch::Tensor k, torch::Tensor indptr,
89- torch::Tensor offsets, bool interleave, float rope_scale,
90- float rope_theta, float low_freq_factor, float high_freq_factor,
91- float old_context_length);
92-
93- std::vector<torch::Tensor> apply_rope (torch::Tensor q, torch::Tensor k, torch::Tensor indptr,
86+ std::vector<torch::Tensor> apply_rope (torch::Tensor q, torch::Tensor k, torch::Tensor q_rope,
87+ torch::Tensor k_rope, torch::Tensor indptr,
9488 torch::Tensor offsets, bool interleave, float rope_scale,
9589 float rope_theta);
9690
9791std::vector<torch::Tensor> apply_llama31_rope (torch::Tensor q, torch::Tensor k,
92+ torch::Tensor q_rope, torch::Tensor k_rope,
9893 torch::Tensor indptr, torch::Tensor offsets,
9994 bool interleave, float rope_scale, float rope_theta,
10095 float low_freq_factor, float high_freq_factor,
10196 float old_context_length);
10297
98+ std::vector<torch::Tensor> apply_rope_pos_ids (torch::Tensor q, torch::Tensor k,
99+ torch::Tensor q_rope, torch::Tensor k_rope,
100+ torch::Tensor pos_ids, bool interleave,
101+ float rope_scale, float rope_theta);
102+
103+ std::vector<torch::Tensor> apply_llama31_rope_pos_ids (torch::Tensor q, torch::Tensor k,
104+ torch::Tensor q_rope, torch::Tensor k_rope,
105+ torch::Tensor pos_ids, bool interleave,
106+ float rope_scale, float rope_theta,
107+ float low_freq_factor, float high_freq_factor,
108+ float old_context_length);
109+
103110torch::Tensor packbits (torch::Tensor x, const std::string& bitorder);
104111
105112torch::Tensor segment_packbits (torch::Tensor x, torch::Tensor input_indptr,
@@ -141,11 +148,11 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
141148 m.def (" silu_and_mul" , &silu_and_mul, " Fused SiLU and Mul" );
142149 m.def (" gelu_tanh_and_mul" , &gelu_tanh_and_mul, " Fused GeLU Tanh and Mul" );
143150 m.def (" gelu_and_mul" , &gelu_and_mul, " Fused GeLU and Mul" );
144- m.def (" apply_rope_inplace" , &apply_rope_inplace, " Apply RoPE in-place" );
145- m.def (" apply_llama31_rope_inplace" , &apply_llama31_rope_inplace,
146- " Apply Llama 3.1 style RoPE in-place" );
147151 m.def (" apply_rope" , &apply_rope, " Apply RoPE" );
148152 m.def (" apply_llama31_rope" , &apply_llama31_rope, " Apply Llama 3.1 style RoPE" );
153+ m.def (" apply_rope_pos_ids" , &apply_rope_pos_ids, " Apply RoPE with positional ids" );
154+ m.def (" apply_llama31_rope_pos_ids" , &apply_llama31_rope_pos_ids,
155+ " Apply Llama 3.1 style RoPE with positional ids" );
149156 m.def (" packbits" , &packbits, " GPU packbits operator" );
150157 m.def (" segment_packbits" , &segment_packbits, " GPU segment packbits operator" );
151158 m.def (" cutlass_segment_gemm" , &CutlassSegmentGEMM, " Cutlass Segment GEMM operator" );
0 commit comments