Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions flashinfer-aot/csrc_aot/flashinfer_ops.cu
Original file line number Diff line number Diff line change
Expand Up @@ -61,10 +61,10 @@ torch::Tensor top_k_renorm_probs(torch::Tensor probs, std::optional<torch::Tenso
torch::Tensor top_k_mask_logits(torch::Tensor logits, std::optional<torch::Tensor> maybe_top_k_arr,
unsigned int top_k_val);

std::vector<torch::Tensor> chain_speculative_sampling(
torch::Tensor chain_speculative_sampling(
torch::Tensor draft_probs, torch::Tensor draft_token_ids, torch::Tensor uniform_samples,
torch::Tensor target_probs, std::optional<torch::Tensor> maybe_output_accepted_token_num,
std::optional<torch::Tensor> maybe_output_emitted_token_num, bool deterministic);
torch::Tensor target_probs, torch::Tensor output_accepted_token_num,
torch::Tensor output_emitted_token_num, bool deterministic);

void rmsnorm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight, double eps);

Expand Down
16 changes: 10 additions & 6 deletions python/flashinfer/quantization.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,15 @@ def get_quantization_module():


@register_custom_op("flashinfer::packbits", mutates_args=())
def _packbits(x: torch.Tensor, bitorder: str) -> torch.Tensor:
    """Pack a binary-valued tensor into uint8 via the JIT quantization kernel.

    Registered as the real implementation of the ``flashinfer::packbits``
    custom op; ``bitorder`` selects bit ordering within each output byte.
    """
    module = get_quantization_module()
    packed = module.packbits(x, bitorder)
    return packed


@register_fake_op("flashinfer::packbits")
def _fake_packbits(x: torch.Tensor, bitorder: str) -> torch.Tensor:
    """Fake (meta) implementation of ``flashinfer::packbits`` for tracing.

    Allocates an empty uint8 tensor of the packed length without running the
    kernel. NOTE(review): the ceil-divide over ``x.size(0)`` assumes a 1-D
    input — confirm against the real kernel's contract.
    """
    num_bits = x.size(0)
    packed_len = (num_bits + 7) // 8
    return torch.empty(packed_len, dtype=torch.uint8, device=x.device)


def packbits(x: torch.Tensor, bitorder: str = "big") -> torch.Tensor:
r"""Pack the elements of a binary-valued array into bits in a uint8 array.

Expand Down Expand Up @@ -74,12 +83,7 @@ def packbits(x: torch.Tensor, bitorder: str = "big") -> torch.Tensor:
--------
segment_packbits
"""
return get_quantization_module().packbits(x, bitorder)


@register_fake_op("flashinfer::packbits")
def _fake_packbits(x: torch.Tensor, bitorder: str = "big") -> torch.Tensor:
return torch.empty((x.size(0) + 7) // 8, dtype=torch.uint8, device=x.device)
return _packbits(x, bitorder)


def segment_packbits(
Expand Down
4 changes: 2 additions & 2 deletions python/flashinfer/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,7 @@ def register_custom_op(
schema: Optional[str] = None,
) -> Callable:
if TorchVersion(torch_version) < TorchVersion("2.4"):
return fn
return lambda x: x
return torch.library.custom_op(
name, fn, mutates_args=mutates_args, device_types=device_types, schema=schema
)
Expand All @@ -223,5 +223,5 @@ def register_fake_op(
fn: Optional[Callable] = None,
) -> Callable:
if TorchVersion(torch_version) < TorchVersion("2.4"):
return fn
return lambda x: x
return torch.library.register_fake(name, fn)