forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[pytorch][nnc] protocol classes to persist the context for compiled functions (pytorch#56851) Summary: Pull Request resolved: pytorch#56851 This is part of the changes to enable NNC AOT compilation for mobile. At the end of the ahead-of-time compilation the compiler produces two sets of artifacts: 1. "compiled assembly code" - kernel functions in assembly format optimized for target platforms; 2. "compiled model" - regular TorchScript model that contains serialized parameters (weights/bias/etc) and invokes kernel functions via "handles" (name/version id/input & output specs/etc of the kernel functions). This PR introduces a set of classes to represent kernel functions (a.k.a. "handles"), which can be serialized/deserialized into/from the "compiled model" as an IValue. It also introduces APIs to register/look up "compiled assembly code". ghstack-source-id: 128285802 Test Plan: - unit tests - for FB build environment: buck test //caffe2/test/mobile/nnc:mobile_nnc Reviewed By: kimishpatel, raziel Differential Revision: D27921866 fbshipit-source-id: 4c2a4d8a4d072fc259416ae674b3b494f0ca56f3
- Loading branch information
1 parent
db7b313
commit d82333e
Showing
10 changed files
with
753 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
# Build configuration for the mobile NNC (Neural Network Compiler) unit tests.
set(MOBILE_NNC_TEST_ROOT ${TORCH_ROOT}/test/mobile/nnc)

# Test sources covering the compiled-function context and the kernel registry.
set(MOBILE_NNC_TEST_SRCS
  ${MOBILE_NNC_TEST_ROOT}/test_context.cpp
  ${MOBILE_NNC_TEST_ROOT}/test_registry.cpp
)

# Reuse the lite-interpreter runtime's gtest main() as the entry point.
add_executable(test_mobile_nnc
  ${TORCH_ROOT}/test/cpp/lite_interpreter_runtime/main.cpp
  ${MOBILE_NNC_TEST_SRCS})

target_link_libraries(test_mobile_nnc PRIVATE torch gtest)
target_include_directories(test_mobile_nnc PRIVATE ${ATen_CPU_INCLUDE})
target_compile_definitions(test_mobile_nnc PRIVATE USE_GTEST)

if(INSTALL_TEST)
  install(TARGETS test_mobile_nnc DESTINATION bin)
  # Install PDB files for MSVC builds
  if(MSVC AND BUILD_SHARED_LIBS)
    install(FILES $<TARGET_PDB_FILE:test_mobile_nnc> DESTINATION bin OPTIONAL)
  endif()
endif()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,156 @@ | ||
#include <gtest/gtest.h> | ||
#include <torch/csrc/jit/mobile/nnc/context.h> | ||
#include <torch/csrc/jit/mobile/nnc/registry.h> | ||
#include <ATen/Functions.h> | ||
|
||
namespace torch { | ||
namespace jit { | ||
namespace mobile { | ||
namespace nnc { | ||
|
||
extern "C" { | ||
|
||
// out = a * n (doing calculation in the `tmp` buffer) | ||
int slow_mul_kernel(void** args) { | ||
const int size = 128; | ||
at::Tensor a = at::from_blob(args[0], {size}, at::kFloat); | ||
at::Tensor out = at::from_blob(args[1], {size}, at::kFloat); | ||
at::Tensor n = at::from_blob(args[2], {1}, at::kInt); | ||
at::Tensor tmp = at::from_blob(args[3], {size}, at::kFloat); | ||
|
||
tmp.zero_(); | ||
for (int i = n.item().toInt(); i > 0; i--) { | ||
tmp.add_(a); | ||
} | ||
out.copy_(tmp); | ||
return 0; | ||
} | ||
|
||
int dummy_kernel(void** /* args */) { | ||
return 0; | ||
} | ||
|
||
} // extern "C" | ||
|
||
// Register the test kernels so Function::run() can look them up by id.
REGISTER_NNC_KERNEL("slow_mul", slow_mul_kernel)
REGISTER_NNC_KERNEL("dummy", dummy_kernel)
|
||
// Builds a float InputSpec with the given sizes for test setup.
InputSpec create_test_input_spec(const std::vector<int64_t>& sizes) {
  InputSpec input_spec;
  input_spec.sizes_ = sizes;
  input_spec.dtype_ = at::kFloat;
  return input_spec;
}
|
||
// Builds a float OutputSpec with the given sizes for test setup.
OutputSpec create_test_output_spec(const std::vector<int64_t>& sizes) {
  OutputSpec output_spec;
  output_spec.sizes_ = sizes;
  output_spec.dtype_ = at::kFloat;
  return output_spec;
}
|
||
// Builds a MemoryPlan describing the scratch buffers a kernel needs.
MemoryPlan create_test_memory_plan(const std::vector<int64_t>& buffer_sizes) {
  MemoryPlan memory_plan;
  memory_plan.buffer_sizes_ = buffer_sizes;
  return memory_plan;
}
|
||
// End-to-end run: Function binds input/parameter/scratch-buffer pointers
// and dispatches to the registered "slow_mul" kernel; verifies out == a * n.
TEST(Function, ExecuteSlowMul) {
  const int a = 999;
  const int n = 100;
  const int size = 128;
  Function f;

  f.set_nnc_kernel_id("slow_mul");
  f.set_input_specs({create_test_input_spec({size})});
  f.set_output_spec({create_test_output_spec({size})});
  // The multiplier `n` travels as a serialized parameter tensor.
  f.set_parameters({at::ones({1}, at::kInt).mul(n)});
  // One scratch buffer of `size` floats backs the kernel's `tmp`.
  f.set_memory_plan(create_test_memory_plan({sizeof(float) * size}));

  c10::List<at::Tensor> input({
      at::ones({size}, at::kFloat).mul(a)
  });
  auto outputs = f.run(c10::impl::toList(input));
  auto output = ((const c10::IValue&) outputs[0]).toTensor();
  auto expected_output = at::ones({size}, at::kFloat).mul(a * n);
  EXPECT_TRUE(output.equal(expected_output));
}
|
||
// Round-trip: serialize a fully-populated Function to an IValue and rebuild
// it; every field (name, kernel id, specs, parameters, memory plan) must
// survive unchanged.
TEST(Function, Serialization) {
  Function f;
  f.set_name("test_function");
  f.set_nnc_kernel_id("test_kernel");
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
  f.set_input_specs({create_test_input_spec({1, 3, 224, 224})});
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
  f.set_output_spec({create_test_output_spec({1000})});
  f.set_parameters({
      // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
      at::ones({1, 16, 3, 3}, at::kFloat),
      // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
      at::ones({16, 32, 1, 1}, at::kFloat),
      // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
      at::ones({32, 1, 3, 3}, at::kFloat)
  });
  f.set_memory_plan(create_test_memory_plan({
      // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
      sizeof(float) * 1024,
      // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
      sizeof(float) * 2048,
  }));

  auto serialized = f.serialize();
  Function f2(serialized);
  EXPECT_EQ(f2.name(), "test_function");
  EXPECT_EQ(f2.nnc_kernel_id(), "test_kernel");
  EXPECT_EQ(f2.input_specs().size(), 1);
  EXPECT_EQ(f2.input_specs()[0].sizes_, std::vector<int64_t>({1, 3, 224, 224}));
  EXPECT_EQ(f2.input_specs()[0].dtype_, at::kFloat);

  EXPECT_EQ(f2.output_specs().size(), 1);
  EXPECT_EQ(f2.output_specs()[0].sizes_, std::vector<int64_t>({1000}));
  EXPECT_EQ(f2.output_specs()[0].dtype_, at::kFloat);

  EXPECT_EQ(f2.parameters().size(), 3);
  EXPECT_EQ(f2.parameters()[0].sizes(), at::IntArrayRef({1, 16, 3, 3}));
  EXPECT_EQ(f2.parameters()[1].sizes(), at::IntArrayRef({16, 32, 1, 1}));
  EXPECT_EQ(f2.parameters()[2].sizes(), at::IntArrayRef({32, 1, 3, 3}));

  EXPECT_EQ(f2.memory_plan().buffer_sizes_.size(), 2);
  EXPECT_EQ(f2.memory_plan().buffer_sizes_[0], sizeof(float) * 1024);
  EXPECT_EQ(f2.memory_plan().buffer_sizes_[1], sizeof(float) * 2048);
}
|
||
// Input matching the declared spec ({128} float) must pass validation.
TEST(Function, ValidInput) {
  const int size = 128;
  Function f;
  f.set_nnc_kernel_id("dummy");
  f.set_input_specs({create_test_input_spec({size})});

  c10::List<at::Tensor> input({
      at::ones({size}, at::kFloat)
  });
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto)
  EXPECT_NO_THROW(
      f.run(c10::impl::toList(input)));
}
|
||
// Input whose shape ({256}) violates the declared spec ({128}) must be
// rejected with c10::Error rather than silently executed.
TEST(Function, InvalidInput) {
  const int size = 128;
  Function f;
  f.set_nnc_kernel_id("dummy");
  f.set_input_specs({create_test_input_spec({size})});

  c10::List<at::Tensor> input({
      at::ones({size * 2}, at::kFloat)
  });
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto)
  EXPECT_THROW(
      f.run(c10::impl::toList(input)),
      c10::Error);
}
|
||
} // namespace nnc | ||
} // namespace mobile | ||
} // namespace jit | ||
} // namespace torch |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
#include <gtest/gtest.h> | ||
#include <torch/csrc/jit/mobile/nnc/registry.h> | ||
|
||
namespace torch { | ||
namespace jit { | ||
namespace mobile { | ||
namespace nnc { | ||
|
||
extern "C" {
// Stand-ins for compiler-generated assembly kernels; each returns a
// distinct code so tests can verify the registry dispatched to the
// correct function.
int generated_asm_kernel_foo(void**) {
  return 1;
}

int generated_asm_kernel_bar(void**) {
  return 2;
}
} // extern "C"
|
||
// Kernel ids follow the "<name>:<version>:<token>" convention used by the
// compiled-model handles.
REGISTER_NNC_KERNEL("foo:v1:VERTOKEN", generated_asm_kernel_foo)
REGISTER_NNC_KERNEL("bar:v1:VERTOKEN", generated_asm_kernel_bar)
|
||
// Registered kernels must be retrievable by id and executable, each
// returning its distinctive code.
TEST(MobileNNCRegistryTest, FindAndRun) {
  auto foo_kernel = registry::get_nnc_kernel("foo:v1:VERTOKEN");
  EXPECT_EQ(foo_kernel->execute(nullptr), 1);

  auto bar_kernel = registry::get_nnc_kernel("bar:v1:VERTOKEN");
  EXPECT_EQ(bar_kernel->execute(nullptr), 2);
}
|
||
// Looking up an id that was never registered must report absence.
TEST(MobileNNCRegistryTest, NoKernel) {
  // EXPECT_FALSE is the idiomatic gtest form for asserting a false predicate.
  EXPECT_FALSE(registry::has_nnc_kernel("missing"));
}
|
||
} // namespace nnc | ||
} // namespace mobile | ||
} // namespace jit | ||
} // namespace torch |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.