Skip to content

Commit

Permalink
Use new API to register custom ops for llama model (#2916)
Browse files Browse the repository at this point in the history
Summary:

Retry of D55713944

Use `EXECUTORCH_LIBRARY` to register custom kernel to ExecuTorch runtime.

Differential Revision: D55856491
  • Loading branch information
larryliu0820 authored and facebook-github-bot committed Apr 9, 2024
1 parent 9c2b0d7 commit a361aac
Show file tree
Hide file tree
Showing 11 changed files with 98 additions and 126 deletions.
2 changes: 1 addition & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -543,7 +543,7 @@ if(EXECUTORCH_BUILD_PYBIND)
endif()

if(EXECUTORCH_BUILD_CUSTOM)
list(APPEND _dep_libs custom_ops_lib)
list(APPEND _dep_libs custom_ops)
endif()

# compile options for pybind
Expand Down
4 changes: 2 additions & 2 deletions examples/models/llama2/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -92,8 +92,8 @@ else()
endif()

if(EXECUTORCH_BUILD_CUSTOM)
target_link_options_shared_lib(custom_ops_lib)
list(APPEND link_libraries custom_ops_lib)
target_link_options_shared_lib(custom_ops)
list(APPEND link_libraries custom_ops)
endif()

# XNNPACK pthreadpool cpuinfo
Expand Down
16 changes: 3 additions & 13 deletions examples/models/llama2/custom_ops/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -50,15 +50,6 @@ list(APPEND custom_ops_libs cpuinfo)
list(APPEND custom_ops_libs cpublas)
list(APPEND custom_ops_libs eigen_blas)

# Generate C++ bindings to register kernels into both PyTorch (for AOT) and
# ExecuTorch (for runtime). Here select all ops in custom_ops.yaml
set(_yaml "${CMAKE_CURRENT_LIST_DIR}/custom_ops.yaml")
gen_selected_ops("${_yaml}" "" "")

generate_bindings_for_kernels(FUNCTIONS_YAML
${CMAKE_CURRENT_SOURCE_DIR}/custom_ops.yaml)
message("Generated files ${gen_command_sources}")

list(TRANSFORM _custom_ops__srcs PREPEND "${EXECUTORCH_ROOT}/")

# TODO: Consider moving xnnpack/threadpool in a separate lib since it's now used
Expand All @@ -70,6 +61,8 @@ if(NOT EXECUTORCH_BUILD_XNNPACK)
"${CMAKE_CURRENT_SOURCE_DIR}/../../../../backends/xnnpack/threadpool/threadpool.cpp"
"${CMAKE_CURRENT_SOURCE_DIR}/../../../../backends/xnnpack/threadpool/threadpool_guard.cpp"
)
else()
list(APPEND custom_ops_libs xnnpack_backend)
endif()

add_library(custom_ops ${_custom_ops__srcs})
Expand All @@ -82,7 +75,4 @@ target_link_libraries(custom_ops PUBLIC ${custom_ops_libs})
target_compile_options(custom_ops PUBLIC ${_common_compile_options}
-DET_USE_THREADPOOL)

# Build a library for _custom_ops_srcs
#
# custom_ops_lib: Register custom ops kernels into the ExecuTorch runtime
gen_operators_lib("custom_ops_lib" KERNEL_LIBS custom_ops DEPS executorch)
install(TARGETS custom_ops DESTINATION lib)
Empty file.
14 changes: 0 additions & 14 deletions examples/models/llama2/custom_ops/custom_ops.yaml

This file was deleted.

8 changes: 7 additions & 1 deletion examples/models/llama2/custom_ops/op_sdpa.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
* LICENSE file in the root directory of this source tree.
*/

#include <executorch/runtime/kernel/kernel_includes.h>
#include <executorch/examples/models/llama2/custom_ops/op_sdpa.h>

#include <executorch/kernels/optimized/blas/CPUBlas.h>
#include <executorch/kernels/optimized/vec/functional.h>
Expand All @@ -22,6 +22,7 @@
#include <executorch/backends/xnnpack/threadpool/threadpool.h>
#include <executorch/extension/parallel/thread_parallel.h>
#endif
#include <executorch/extension/kernel_util/make_boxed_from_unboxed_functor.h>

namespace torch {
namespace executor {
Expand Down Expand Up @@ -843,3 +844,8 @@ Tensor& sdpa_with_kv_cache_out(
} // namespace native
} // namespace executor
} // namespace torch

EXECUTORCH_LIBRARY(
llama,
"sdpa_with_kv_cache.out",
torch::executor::native::sdpa_with_kv_cache_out);
48 changes: 48 additions & 0 deletions examples/models/llama2/custom_ops/op_sdpa.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#pragma once

#include <executorch/runtime/kernel/kernel_includes.h>

namespace torch {
namespace executor {

namespace native {

Tensor& sdpa_with_kv_cache_out(
RuntimeContext& ctx,
const Tensor& q_projected,
const Tensor& k_projected,
const Tensor& v_projected,
Tensor& key_cache,
Tensor& value_cache,
const int64_t start_pos,
const int64_t seq_len,
const optional<Tensor>& attn_mask,
const double dropout_p,
const bool is_causal,
// @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
const optional<double> scale,
Tensor& output);

Tensor& flash_attention_kernel_out(
RuntimeContext& ctx,
const Tensor& query,
const Tensor& key,
const Tensor& value,
const optional<Tensor>& attn_mask,
const double dropout_p,
const bool is_causal,
// @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
const optional<double> scale,
Tensor& output);

} // namespace native
} // namespace executor
} // namespace torch
5 changes: 3 additions & 2 deletions examples/models/llama2/custom_ops/op_sdpa_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@

#include <limits>

#include <executorch/examples/models/llama2/custom_ops/FunctionHeaderWrapper.h> // Declares the operator
#include <executorch/examples/models/llama2/custom_ops/op_sdpa.h>

#include <executorch/kernels/test/TestUtil.h>
#include <executorch/runtime/core/exec_aten/exec_aten.h>
#include <executorch/runtime/core/exec_aten/testing_util/tensor_factory.h>
Expand All @@ -28,7 +29,7 @@ exec_aten::Tensor op_scaled_dot_product_attention(
exec_aten::optional<double> scale,
exec_aten::Tensor& out) {
exec_aten::RuntimeContext context{};
return torch::executor::llama::sdpa_outf(
return torch::executor::native::flash_attention_kernel_out(
context, query, key, value, attn_mask, dropout_p, is_causal, scale, out);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

#include <limits>

#include <executorch/examples/models/llama2/custom_ops/FunctionHeaderWrapper.h> // Declares the operator
#include <executorch/examples/models/llama2/custom_ops/op_sdpa.h> // Declares the operator
#include <executorch/kernels/test/TestUtil.h>
#include <executorch/runtime/core/exec_aten/exec_aten.h>
#include <executorch/runtime/core/exec_aten/testing_util/tensor_factory.h>
Expand All @@ -32,7 +32,7 @@ exec_aten::Tensor op_sdpa_with_kv_cache(
exec_aten::optional<double> scale,
exec_aten::Tensor& out) {
exec_aten::RuntimeContext context{};
return torch::executor::llama::sdpa_with_kv_cache_outf(
return torch::executor::native::sdpa_with_kv_cache_out(
context,
query,
key,
Expand Down
115 changes: 30 additions & 85 deletions examples/models/llama2/custom_ops/targets.bzl
Original file line number Diff line number Diff line change
@@ -1,49 +1,11 @@
load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")
load("@fbsource//xplat/executorch/codegen:codegen.bzl", "et_operator_library", "executorch_generated_lib")
load("@fbsource//xplat/executorch/kernels/test:util.bzl", "codegen_function_header_wrapper")

def define_tests():
codegen_function_header_wrapper("executorch/examples/models/llama2/custom_ops", "custom_ops")

# In the long run we should really have aten variant available as well
deps = [":function_header_wrapper_custom_ops"]
generated_lib_and_op_deps = [
":custom_ops",
":sdpa",
":custom_ops_headers",
]
runtime.cxx_test(
name = "op_sdpa_test",
srcs = [
"op_sdpa_test.cpp",
],
visibility = ["//executorch/..."],
deps = [
"//executorch/runtime/core/exec_aten:lib",
"//executorch/runtime/core/exec_aten/testing_util:tensor_util",
"//executorch/kernels/test:test_util",
] + generated_lib_and_op_deps + deps,
)
runtime.cxx_test(
name = "op_sdpa_with_kv_cache_test",
srcs = [
"op_sdpa_with_kv_cache_test.cpp",
],
visibility = ["//executorch/..."],
deps = [
"//executorch/runtime/core/exec_aten:lib",
"//executorch/runtime/core/exec_aten/testing_util:tensor_util",
"//executorch/kernels/test:test_util",
] + generated_lib_and_op_deps + deps,
)

def define_common_targets():
"""Defines targets that should be shared between fbcode and xplat.
The directory containing this targets.bzl file should also contain both
TARGETS and BUCK files that call this function.
"""

runtime.python_library(
name = "llama_custom_ops_aot_lib",
srcs = [
Expand All @@ -58,71 +20,54 @@ def define_common_targets():
],
)

runtime.export_file(
name = "custom_ops.yaml",
visibility = [
"//executorch/...",
"@EXECUTORCH_CLIENTS",
],
)

# ~~~ START of custom ops 1 `my_ops::mul3` library definitions ~~~
et_operator_library(
name = "sdpa_op",
ops = [
"llama::sdpa.out",
],
define_static_targets = True,
visibility = [
"//executorch/codegen/...",
"@EXECUTORCH_CLIENTS",
],
)

et_operator_library(
name = "sdpa_with_kv_cache",
ops = [
"llama::sdpa_with_kv_cache.out",
],
define_static_targets = True,
visibility = [
"//executorch/codegen/...",
"@EXECUTORCH_CLIENTS",
],
)

runtime.cxx_library(
name = "sdpa",
name = "custom_ops",
srcs = ["op_sdpa.cpp"],
deps = [
exported_headers = ["op_sdpa.h"],
exported_deps = [
"//executorch/runtime/kernel:kernel_includes",
"//executorch/kernels/portable/cpu:scalar_utils",
"//executorch/kernels/optimized:libblas",
"//executorch/kernels/optimized:libvec",
"//executorch/extension/kernel_util:kernel_util",
"//executorch/extension/parallel:thread_parallel",
"//executorch/backends/xnnpack/threadpool:threadpool",
],
compiler_flags = ["-Wno-missing-prototypes"],
compiler_flags = ["-Wno-missing-prototypes", "-Wno-global-constructors"],
visibility = [
"//executorch/...",
"//executorch/examples/models/llama2/custom_ops/...",
"@EXECUTORCH_CLIENTS",
],
# @lint-ignore BUCKLINT link_whole
link_whole = True,
force_static = True,
)

executorch_generated_lib(
name = "custom_ops",
runtime.cxx_test(
name = "op_sdpa_test",
srcs = [
"op_sdpa_test.cpp",
],
visibility = ["//executorch/..."],
deps = [
":sdpa_op",
":sdpa_with_kv_cache",
":sdpa",
"//executorch/runtime/core/exec_aten:lib",
"//executorch/runtime/core/exec_aten/testing_util:tensor_util",
"//executorch/kernels/test:test_util",
":custom_ops",
],
custom_ops_yaml_target = ":custom_ops.yaml",
visibility = [
"//executorch/...",
"@EXECUTORCH_CLIENTS",
)

runtime.cxx_test(
name = "op_sdpa_with_kv_cache_test",
srcs = [
"op_sdpa_with_kv_cache_test.cpp",
],
visibility = ["//executorch/..."],
deps = [
"//executorch/runtime/core/exec_aten:lib",
"//executorch/runtime/core/exec_aten/testing_util:tensor_util",
"//executorch/kernels/test:test_util",
":custom_ops",
],
define_static_targets = True,
)
define_tests()
8 changes: 2 additions & 6 deletions extension/android/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -58,14 +58,10 @@ if(EXECUTORCH_BUILD_LLAMA_JNI)
add_library(llama_runner STATIC IMPORTED)
set_property(TARGET llama_runner PROPERTY IMPORTED_LOCATION ${LLAMA_RUNNER_PATH})

set(CUSTOM_OPS_LIB_PATH ${CMAKE_CURRENT_BINARY_DIR}/../../examples/models/llama2/custom_ops/libcustom_ops_lib.a)
add_library(custom_ops_lib STATIC IMPORTED)
set_property(TARGET custom_ops_lib PROPERTY IMPORTED_LOCATION ${CUSTOM_OPS_LIB_PATH})

set(CUSTOM_OPS_PATH ${CMAKE_CURRENT_BINARY_DIR}/../../examples/models/llama2/custom_ops/libcustom_ops.a)
add_library(custom_ops STATIC IMPORTED)
set_property(TARGET custom_ops PROPERTY IMPORTED_LOCATION ${CUSTOM_OPS_PATH})
target_link_options_shared_lib(custom_ops_lib)
target_link_options_shared_lib(custom_ops)

if(TARGET pthreadpool)
set(LLAMA_JNI_SRCS jni/jni_layer_llama.cpp ../../backends/xnnpack/threadpool/cpuinfo_utils.cpp)
Expand All @@ -82,6 +78,6 @@ if(EXECUTORCH_BUILD_LLAMA_JNI)
endif()
target_include_directories(executorch_llama_jni PRIVATE ${_common_include_directories})
target_link_libraries(executorch_llama_jni ${link_libraries} llama_runner
custom_ops custom_ops_lib cpublas eigen_blas)
custom_ops cpublas eigen_blas)
target_compile_options(executorch_llama_jni PUBLIC ${_common_compile_options})
endif()

0 comments on commit a361aac

Please sign in to comment.