
[Executorch][BE] Rename sdpa_with_kv_cache.py to custom_ops.py #7210

Merged — 1 commit, Dec 6, 2024
2 changes: 1 addition & 1 deletion examples/models/llama/eval_llama_lib.py
@@ -106,7 +106,7 @@ def __init__(

# Note: import this after portable_lib
from executorch.extension.llm.custom_ops import ( # noqa
-    sdpa_with_kv_cache, # usort: skip
+    custom_ops, # usort: skip
)
from executorch.kernels import quantized # noqa

2 changes: 1 addition & 1 deletion examples/models/llama/runner/native.py
@@ -23,7 +23,7 @@
from executorch.examples.models.llama.runner.generation import LlamaRunner

# Note: import this after portable_lib
-from executorch.extension.llm.custom_ops import sdpa_with_kv_cache # noqa # usort: skip
+from executorch.extension.llm.custom_ops import custom_ops # noqa # usort: skip
from executorch.kernels import quantized # noqa


2 changes: 1 addition & 1 deletion examples/models/llama/source_transformation/sdpa.py
@@ -99,7 +99,7 @@ def _replace_sdpa_with_custom_op(module: torch.nn.Module):


def replace_sdpa_with_custom_op(module: torch.nn.Module) -> torch.nn.Module:
-    from executorch.extension.llm.custom_ops import sdpa_with_kv_cache # noqa
+    from executorch.extension.llm.custom_ops import custom_ops # noqa

_replace_sdpa_with_custom_op(module)
return module
2 changes: 1 addition & 1 deletion examples/models/llava/test/test_llava.py
@@ -18,7 +18,7 @@
from executorch.extension.pybindings.portable_lib import (
_load_for_executorch_from_buffer,
)
-from executorch.extension.llm.custom_ops import sdpa_with_kv_cache # noqa # usort: skip
+from executorch.extension.llm.custom_ops import custom_ops # noqa # usort: skip
from executorch.kernels import quantized # noqa # usort: skip

logging.basicConfig(level=logging.INFO)
2 changes: 1 addition & 1 deletion examples/models/llava/test/test_pte.py
@@ -14,7 +14,7 @@
from PIL import Image

# Custom ops has to be loaded after portable_lib.
-from executorch.extension.llm.custom_ops import sdpa_with_kv_cache # noqa # usort: skip
+from executorch.extension.llm.custom_ops import custom_ops # noqa # usort: skip
from executorch.kernels import quantized # noqa # usort: skip

FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
2 changes: 1 addition & 1 deletion extension/llm/README.md
@@ -38,7 +38,7 @@ A sampler class in C++ to sample the logistics given some hyperparameters.
## custom_ops
Contains custom op, such as:
- custom sdpa: implements CPU flash attention and avoids copies by taking the kv cache as one of its arguments.
-- _sdpa_with_kv_cache.py_, _op_sdpa_aot.cpp_: custom op definition in PyTorch with C++ registration.
+- _custom_ops.py_, _op_sdpa_aot.cpp_: custom op definition in PyTorch with C++ registration.
- _op_sdpa.cpp_: the optimized operator implementation and registration of _sdpa_with_kv_cache.out_.

## runner
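Editor's aside (not part of this PR's diff): the call sites changed above all rely on the same ordering — load `portable_lib` first, then the renamed `custom_ops` module, which registers the custom SDPA op under `torch.ops.llama`. A minimal sketch of that pattern, assembled from the imports shown in the hunks above:

```python
# Minimal sketch of the import pattern this PR standardizes on; illustrative
# only, not a file from the diff. custom_ops must be imported after
# portable_lib (see the "import this after portable_lib" notes above); the
# import registers the custom SDPA op with PyTorch.
import torch

from executorch.extension.pybindings.portable_lib import (  # noqa # usort: skip
    _load_for_executorch_from_buffer,
)
from executorch.extension.llm.custom_ops import custom_ops  # noqa # usort: skip
from executorch.kernels import quantized  # noqa # usort: skip

# After the imports, the op should be visible in the torch.ops namespace.
assert torch.ops.llama.sdpa_with_kv_cache.default is not None
```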
extension/llm/custom_ops/sdpa_with_kv_cache.py → extension/llm/custom_ops/custom_ops.py (renamed)
@@ -17,7 +17,6 @@

from torch.library import impl

-# TODO rename this file to custom_ops_meta_registration.py
try:
op = torch.ops.llama.sdpa_with_kv_cache.default
assert op is not None
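The hunk above is truncated after the `try` block. For readers skimming the diff, the renamed module's shape is roughly "use the op if it is already registered, otherwise load the library that registers it"; the sketch below reconstructs that shape, with the fallback branch and the library name being assumptions rather than code from this PR.

```python
# Illustrative reconstruction of the register-or-load pattern; only the
# try/assert portion appears in the diff above. The except branch and the
# library name are assumptions for illustration, not taken from this PR.
import torch

try:
    op = torch.ops.llama.sdpa_with_kv_cache.default
    assert op is not None
except (AssertionError, AttributeError, RuntimeError):
    # Hypothetical fallback: load an AOT-built shared library that registers
    # the op. "libcustom_ops_aot_lib.so" is a placeholder path.
    torch.ops.load_library("libcustom_ops_aot_lib.so")
    op = torch.ops.llama.sdpa_with_kv_cache.default
```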
2 changes: 1 addition & 1 deletion extension/llm/custom_ops/targets.bzl
@@ -81,7 +81,7 @@ def define_common_targets():
runtime.python_library(
name = "custom_ops_aot_py",
srcs = [
"sdpa_with_kv_cache.py",
"custom_ops.py",
],
visibility = [
"//executorch/...",
2 changes: 1 addition & 1 deletion extension/llm/custom_ops/test_sdpa_with_kv_cache.py
@@ -11,7 +11,7 @@
import torch
import torch.nn.functional as F

-from .sdpa_with_kv_cache import custom_ops_lib # noqa
+from .custom_ops import custom_ops_lib # noqa


def _sdpa_with_kv_cache_ref(q, k, v, k_cache, v_cache, attn_mask, start_pos, seq_len):