diff --git a/.jenkins/pytorch/macos-lite-interpreter-build-test.sh b/.jenkins/pytorch/macos-lite-interpreter-build-test.sh
index 901f4517ddbd5..0e23f13bba1db 100644
--- a/.jenkins/pytorch/macos-lite-interpreter-build-test.sh
+++ b/.jenkins/pytorch/macos-lite-interpreter-build-test.sh
@@ -28,6 +28,7 @@ if [ "${BUILD_LITE_INTERPRETER}" == 1 ]; then
 
   popd || exit
   "${CPP_BUILD}/caffe2/build/bin/test_lite_interpreter_runtime"
+  "${CPP_BUILD}/caffe2/build/bin/test_mobile_nnc"
 
   # Change the permission manually from 755 to 644 to keep git clean
   chmod 644 "${HOME}/project/.jenkins/pytorch/macos-lite-interpreter-build-test.sh"
diff --git a/caffe2/CMakeLists.txt b/caffe2/CMakeLists.txt
index 9bfbe5a6c1e2a..8bc68615b9834 100644
--- a/caffe2/CMakeLists.txt
+++ b/caffe2/CMakeLists.txt
@@ -575,6 +575,8 @@ if(NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE)
     ${TORCH_SRC_DIR}/csrc/jit/mobile/observer.cpp
     ${TORCH_SRC_DIR}/csrc/jit/mobile/interpreter.cpp
     ${TORCH_SRC_DIR}/csrc/jit/mobile/export_data.cpp
+    ${TORCH_SRC_DIR}/csrc/jit/mobile/nnc/context.cpp
+    ${TORCH_SRC_DIR}/csrc/jit/mobile/nnc/registry.cpp
     ${TORCH_SRC_DIR}/csrc/jit/mobile/optim/sgd.cpp
     ${TORCH_SRC_DIR}/csrc/jit/mobile/sequential.cpp
   )
@@ -1057,6 +1059,10 @@ endif()
     ${TORCH_ROOT}/test/cpp/lite_interpreter_runtime
     ${CMAKE_BINARY_DIR}/test_lite_interpreter_runtime
   )
+  add_subdirectory(
+    ${TORCH_ROOT}/test/mobile/nnc
+    ${CMAKE_BINARY_DIR}/test_mobile_nnc
+  )
 else()
   add_subdirectory(${TORCH_ROOT}/test/cpp/jit ${CMAKE_BINARY_DIR}/test_jit)
   add_subdirectory(
diff --git a/test/mobile/nnc/CMakeLists.txt b/test/mobile/nnc/CMakeLists.txt
new file mode 100644
index 0000000000000..001c7f32d5b8b
--- /dev/null
+++ b/test/mobile/nnc/CMakeLists.txt
@@ -0,0 +1,22 @@
+set(MOBILE_NNC_TEST_ROOT ${TORCH_ROOT}/test/mobile/nnc)
+
+set(MOBILE_NNC_TEST_SRCS
+  ${MOBILE_NNC_TEST_ROOT}/test_context.cpp
+  ${MOBILE_NNC_TEST_ROOT}/test_registry.cpp
+)
+
+add_executable(test_mobile_nnc
+  ${TORCH_ROOT}/test/cpp/lite_interpreter_runtime/main.cpp
+  ${MOBILE_NNC_TEST_SRCS})
+
+target_link_libraries(test_mobile_nnc PRIVATE torch gtest)
+target_include_directories(test_mobile_nnc PRIVATE ${ATen_CPU_INCLUDE})
+target_compile_definitions(test_mobile_nnc PRIVATE USE_GTEST)
+
+if(INSTALL_TEST)
+  install(TARGETS test_mobile_nnc DESTINATION bin)
+  # Install PDB files for MSVC builds
+  if(MSVC AND BUILD_SHARED_LIBS)
+    install(FILES $<TARGET_PDB_FILE:test_mobile_nnc> DESTINATION bin OPTIONAL)
+  endif()
+endif()
diff --git a/test/mobile/nnc/test_context.cpp b/test/mobile/nnc/test_context.cpp
new file mode 100644
index 0000000000000..c5e30511bce38
--- /dev/null
+++ b/test/mobile/nnc/test_context.cpp
@@ -0,0 +1,156 @@
+#include <gtest/gtest.h>
+#include <ATen/Functions.h>
+#include <torch/csrc/jit/mobile/nnc/context.h>
+#include <torch/csrc/jit/mobile/nnc/registry.h>
+
+namespace torch {
+namespace jit {
+namespace mobile {
+namespace nnc {
+
+extern "C" {
+
+// out = a * n (doing calculation in the `tmp` buffer)
+int slow_mul_kernel(void** args) {
+  const int size = 128;
+  at::Tensor a = at::from_blob(args[0], {size}, at::kFloat);
+  at::Tensor out = at::from_blob(args[1], {size}, at::kFloat);
+  at::Tensor n = at::from_blob(args[2], {1}, at::kInt);
+  at::Tensor tmp = at::from_blob(args[3], {size}, at::kFloat);
+
+  tmp.zero_();
+  for (int i = n.item().toInt(); i > 0; i--) {
+    tmp.add_(a);
+  }
+  out.copy_(tmp);
+  return 0;
+}
+
+int dummy_kernel(void** /* args */) {
+  return 0;
+}
+
+} // extern "C"
+
+REGISTER_NNC_KERNEL("slow_mul", slow_mul_kernel)
+REGISTER_NNC_KERNEL("dummy", dummy_kernel)
+
+InputSpec create_test_input_spec(const std::vector<int64_t>& sizes) {
+  InputSpec input_spec;
+  input_spec.sizes_ = sizes;
+  input_spec.dtype_ = at::kFloat;
+  return input_spec;
+}
+
+OutputSpec create_test_output_spec(const std::vector<int64_t>& sizes) {
+  OutputSpec output_spec;
+  output_spec.sizes_ = sizes;
+  output_spec.dtype_ = at::kFloat;
+  return output_spec;
+}
+
+MemoryPlan create_test_memory_plan(const std::vector<int64_t>& buffer_sizes) {
+  MemoryPlan memory_plan;
+  memory_plan.buffer_sizes_ = buffer_sizes;
+  return memory_plan;
+}
+
+TEST(Function, ExecuteSlowMul) {
+  const int a = 999;
+  const int n = 100;
+  const int size = 128;
+  Function f;
+
+  f.set_nnc_kernel_id("slow_mul");
+  f.set_input_specs({create_test_input_spec({size})});
+  f.set_output_spec({create_test_output_spec({size})});
+  f.set_parameters({at::ones({1}, at::kInt).mul(n)});
+  f.set_memory_plan(create_test_memory_plan({sizeof(float) * size}));
+
+  c10::List<at::Tensor> input({
+      at::ones({size}, at::kFloat).mul(a)
+  });
+  auto outputs = f.run(c10::impl::toList(input));
+  auto output = ((const c10::IValue&) outputs[0]).toTensor();
+  auto expected_output = at::ones({size}, at::kFloat).mul(a * n);
+  EXPECT_TRUE(output.equal(expected_output));
+}
+
+TEST(Function, Serialization) {
+  Function f;
+  f.set_name("test_function");
+  f.set_nnc_kernel_id("test_kernel");
+  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
+  f.set_input_specs({create_test_input_spec({1, 3, 224, 224})});
+  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
+  f.set_output_spec({create_test_output_spec({1000})});
+  f.set_parameters({
+      // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
+      at::ones({1, 16, 3, 3}, at::kFloat),
+      // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
+      at::ones({16, 32, 1, 1}, at::kFloat),
+      // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
+      at::ones({32, 1, 3, 3}, at::kFloat)
+  });
+  f.set_memory_plan(create_test_memory_plan({
+      // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
+      sizeof(float) * 1024,
+      // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
+      sizeof(float) * 2048,
+  }));
+
+  auto serialized = f.serialize();
+  Function f2(serialized);
+  EXPECT_EQ(f2.name(), "test_function");
+  EXPECT_EQ(f2.nnc_kernel_id(), "test_kernel");
+  EXPECT_EQ(f2.input_specs().size(), 1);
+  EXPECT_EQ(f2.input_specs()[0].sizes_, std::vector<int64_t>({1, 3, 224, 224}));
+  EXPECT_EQ(f2.input_specs()[0].dtype_, at::kFloat);
+
+  EXPECT_EQ(f2.output_specs().size(), 1);
+  EXPECT_EQ(f2.output_specs()[0].sizes_, std::vector<int64_t>({1000}));
+  EXPECT_EQ(f2.output_specs()[0].dtype_, at::kFloat);
+
+  EXPECT_EQ(f2.parameters().size(), 3);
+  EXPECT_EQ(f2.parameters()[0].sizes(), at::IntArrayRef({1, 16, 3, 3}));
+  EXPECT_EQ(f2.parameters()[1].sizes(), at::IntArrayRef({16, 32, 1, 1}));
+  EXPECT_EQ(f2.parameters()[2].sizes(), at::IntArrayRef({32, 1, 3, 3}));
+
+  EXPECT_EQ(f2.memory_plan().buffer_sizes_.size(), 2);
+  EXPECT_EQ(f2.memory_plan().buffer_sizes_[0], sizeof(float) * 1024);
+  EXPECT_EQ(f2.memory_plan().buffer_sizes_[1], sizeof(float) * 2048);
+}
+
+TEST(Function, ValidInput) {
+  const int size = 128;
+  Function f;
+  f.set_nnc_kernel_id("dummy");
+  f.set_input_specs({create_test_input_spec({size})});
+
+  c10::List<at::Tensor> input({
+      at::ones({size}, at::kFloat)
+  });
+  // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto)
+  EXPECT_NO_THROW(
+      f.run(c10::impl::toList(input)));
+}
+
+TEST(Function, InvalidInput) {
+  const int size = 128;
+  Function f;
+  f.set_nnc_kernel_id("dummy");
+  f.set_input_specs({create_test_input_spec({size})});
+
+  c10::List<at::Tensor> input({
+      at::ones({size * 2}, at::kFloat)
+  });
+  // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto)
+  EXPECT_THROW(
+      f.run(c10::impl::toList(input)),
+      c10::Error);
+}
+
+} // namespace nnc
+} // namespace mobile
+} // namespace jit
+} // namespace torch
diff --git a/test/mobile/nnc/test_registry.cpp b/test/mobile/nnc/test_registry.cpp
new file mode 100644
index 0000000000000..e7adeb864da3a
--- /dev/null
+++ b/test/mobile/nnc/test_registry.cpp
@@ -0,0 +1,37 @@
+#include <gtest/gtest.h>
+#include <torch/csrc/jit/mobile/nnc/registry.h>
+
+namespace torch {
+namespace jit {
+namespace mobile {
+namespace nnc {
+
+extern "C" {
+int generated_asm_kernel_foo(void**) {
+  return 1;
+}
+
+int generated_asm_kernel_bar(void**) {
+  return 2;
+}
+} // extern "C"
+
+REGISTER_NNC_KERNEL("foo:v1:VERTOKEN", generated_asm_kernel_foo)
+REGISTER_NNC_KERNEL("bar:v1:VERTOKEN", generated_asm_kernel_bar)
+
+TEST(MobileNNCRegistryTest, FindAndRun) {
+  auto foo_kernel = registry::get_nnc_kernel("foo:v1:VERTOKEN");
+  EXPECT_EQ(foo_kernel->execute(nullptr), 1);
+
+  auto bar_kernel = registry::get_nnc_kernel("bar:v1:VERTOKEN");
+  EXPECT_EQ(bar_kernel->execute(nullptr), 2);
+}
+
+TEST(MobileNNCRegistryTest, NoKernel) {
+  EXPECT_EQ(registry::has_nnc_kernel("missing"), false);
+}
+
+} // namespace nnc
+} // namespace mobile
+} // namespace jit
+} // namespace torch
diff --git a/tools/build_variables.bzl b/tools/build_variables.bzl
index 078a4c9635d77..d529b8235d1b5 100644
--- a/tools/build_variables.bzl
+++ b/tools/build_variables.bzl
@@ -397,6 +397,8 @@ libtorch_extra_sources = libtorch_core_jit_sources + [
    "torch/csrc/jit/mobile/observer.cpp",
    "torch/csrc/jit/mobile/optim/sgd.cpp",
    "torch/csrc/jit/mobile/sequential.cpp",
+    "torch/csrc/jit/mobile/nnc/context.cpp",
+    "torch/csrc/jit/mobile/nnc/registry.cpp",
    "torch/csrc/jit/serialization/onnx.cpp",
    "torch/csrc/jit/serialization/export.cpp",
    "torch/csrc/jit/serialization/export_module.cpp",
diff --git a/torch/csrc/jit/mobile/nnc/context.cpp b/torch/csrc/jit/mobile/nnc/context.cpp
new file mode 100644
index 0000000000000..8b20f0ca919c7
--- /dev/null
+++ b/torch/csrc/jit/mobile/nnc/context.cpp
@@ -0,0 +1,262 @@
+#include <torch/csrc/jit/mobile/nnc/context.h>
+
+#include <ATen/Functions.h>
+#include <ATen/core/functional.h>
+#include <c10/core/CPUAllocator.h>
+
+#include <torch/csrc/jit/mobile/nnc/registry.h>
+
+namespace torch {
+namespace jit {
+namespace mobile {
+namespace nnc {
+
+constexpr int64_t kProducedNNCFileFormatVersion = 0x1L;
+
+namespace {
+
+c10::IValue Tup(std::vector<c10::IValue>&& ivalues) {
+  return c10::ivalue::Tuple::create(ivalues);
+}
+
+} // namespace
+
+InputSpec::InputSpec(const c10::IValue& value) {
+  auto dict = value.toGenericDict();
+  sizes_ = dict.at("sizes").toIntVector();
+  dtype_ = dict.at("dtype").toScalarType();
+}
+
+c10::IValue InputSpec::serialize() const {
+  c10::Dict<c10::IValue, c10::IValue> dict(
+      at::StringType::get(), at::AnyType::get());
+  dict.insert("sizes", sizes_);
+  dict.insert("dtype", dtype_);
+  return dict;
+}
+
+bool InputSpec::validate(const at::Tensor& input) const {
+  return input.sizes() == sizes_ && input.scalar_type() == dtype_;
+}
+
+OutputSpec::OutputSpec(const c10::IValue& value) {
+  auto dict = value.toGenericDict();
+  sizes_ = dict.at("sizes").toIntVector();
+  dtype_ = dict.at("dtype").toScalarType();
+}
+
+c10::IValue OutputSpec::serialize() const {
+  c10::Dict<c10::IValue, c10::IValue> dict(
+      at::StringType::get(), at::AnyType::get());
+  dict.insert("sizes", sizes_);
+  dict.insert("dtype", dtype_);
+  return dict;
+}
+
+at::Tensor OutputSpec::allocate() const {
+  return at::empty(
+      sizes_,
+      at::TensorOptions()
+          .dtype(dtype_)
+          .layout(at::kStrided)
+          .device(at::kCPU)
+          .requires_grad(false));
+}
+
+MemoryPlan::MemoryPlan(const c10::IValue& value) {
+  auto dict = value.toGenericDict();
+  buffer_sizes_ = dict.at("buffer_sizes").toIntVector();
+}
+
+c10::IValue MemoryPlan::serialize() const {
+  c10::Dict<c10::IValue, c10::IValue> dict(
+      at::StringType::get(), at::AnyType::get());
+  dict.insert("buffer_sizes", buffer_sizes_);
+  return dict;
+}
+
+void MemoryPlan::allocate(ExecutionState* state) const {
+  auto& allocations = state->preallocations_;
+  allocations.clear();
+  allocations.reserve(buffer_sizes_.size());
+  for (int64_t buffer_size : buffer_sizes_) {
+    at::DataPtr buffer = c10::GetCPUAllocator()->allocate(buffer_size);
+    allocations.emplace_back(std::move(buffer));
+  }
+}
+
+Function::Function(const c10::IValue& value) {
+  auto dict = value.toGenericDict();
+  name_ = c10::QualifiedName(dict.at("name").toStringRef());
+  nnc_kernel_id_ = dict.at("nnc_kernel_id").toStringRef();
+  parameters_ = dict.at("parameters").toTensorVector();
+
+  // input_specs_
+  for (const auto& input_value : dict.at("input_specs").toTuple()->elements()) {
+    input_specs_.emplace_back(input_value);
+  }
+
+  // output_specs_
+  for (const auto& output_value :
+       dict.at("output_specs").toTuple()->elements()) {
+    output_specs_.emplace_back(output_value);
+  }
+
+  // memory_plan_
+  memory_plan_ = MemoryPlan(dict.at("memory_plan"));
+}
+
+c10::IValue Function::serialize() const {
+  c10::Dict<c10::IValue, c10::IValue> dict(
+      at::StringType::get(), at::AnyType::get());
+
+  dict.insert("name", name_.qualifiedName());
+  dict.insert("nnc_kernel_id", nnc_kernel_id_);
+  // TODO: should serialize parameters with Module instead of with each Method.
+  // And ideally the parameters should be shared between the compiled model
+  // and the original model if we can serialize both in the same model file.
+  dict.insert("parameters", parameters_);
+
+  // input_specs_
+  std::vector<c10::IValue> input_specs;
+  for (const auto& input_spec : input_specs_) {
+    input_specs.emplace_back(input_spec.serialize());
+  }
+  dict.insert("input_specs", Tup(std::move(input_specs)));
+
+  // output_specs_
+  std::vector<c10::IValue> output_specs;
+  for (const auto& output_spec : output_specs_) {
+    output_specs.emplace_back(output_spec.serialize());
+  }
+  dict.insert("output_specs", Tup(std::move(output_specs)));
+
+  // memory_plan_
+  dict.insert("memory_plan", memory_plan_.serialize());
+  return dict;
+}
+
+void Function::init_execution_state() const {
+  if (execution_state_.get() != nullptr) {
+    return;
+  }
+
+  ExecutionState state;
+  memory_plan_.allocate(&state);
+
+  // The arguments vector consists of 4 sections: inputs, outputs, parameters
+  // and buffers.
+  auto input_args = input_specs_.size();
+  auto output_args = output_specs_.size();
+  auto param_args = parameters_.size();
+  auto buffer_args = state.preallocations_.size();
+
+  auto& arguments = state.arguments_;
+  arguments.reserve(input_args + output_args + param_args + buffer_args);
+
+  // Keep empty slots to fill in inputs/outputs pointers at execution time.
+  arguments.resize(input_args + output_args);
+
+  // Fill in parameter pointers.
+  for (const auto& param : parameters_) {
+    arguments.emplace_back(param.data_ptr());
+  }
+
+  // Fill in preallocated buffer pointers.
+  for (const auto& preallocation : state.preallocations_) {
+    arguments.emplace_back(preallocation.get());
+  }
+
+  execution_state_ = std::make_unique<ExecutionState>(std::move(state));
+}
+
+c10::impl::GenericList Function::run(
+    const c10::impl::GenericList& inputs) const {
+  TORCH_CHECK(
+      registry::has_nnc_kernel(nnc_kernel_id_),
+      "Cannot find NNC kernel: ",
+      nnc_kernel_id_);
+
+  init_execution_state();
+
+  std::vector<void*>& args = execution_state_->arguments_;
+
+  // Fill in input tensors.
+  TORCH_CHECK(
+      input_specs_.size() == inputs.size(),
+      "Input size doesn't match the spec, expect: ",
+      input_specs_.size(),
+      " actual: ",
+      inputs.size());
+  for (size_t i = 0; i < inputs.size(); ++i) {
+    const c10::IValue& input = inputs[i];
+    const auto& input_tensor = input.toTensor();
+    TORCH_CHECK(
+        input_specs_[i].validate(input_tensor), "Invalid input at pos: ", i);
+    args[i] = input_tensor.data_ptr();
+  }
+
+  // Preallocate and fill in output tensors.
+  c10::List<at::Tensor> outputs;
+  outputs.reserve(output_specs_.size());
+  for (size_t i = 0; i < output_specs_.size(); ++i) {
+    at::Tensor output = output_specs_[i].allocate();
+    outputs.emplace_back(output);
+    args[inputs.size() + i] = output.data_ptr();
+  }
+
+  // TODO: check consistency, e.g.: code version, input shape and compiled
+  // shape, etc.
+  auto kernel = registry::get_nnc_kernel(nnc_kernel_id_);
+  kernel->execute(args.data());
+
+  return c10::impl::toList(outputs);
+}
+
+CompilationUnit::CompilationUnit(const c10::IValue& value) {
+  const auto& root = value.toTuple()->elements();
+  const auto& functions = root[1].toTuple()->elements();
+  for (const auto& function : functions) {
+    register_function(std::make_unique<Function>(function));
+  }
+}
+
+c10::IValue CompilationUnit::serialize() const {
+  auto functions =
+      c10::fmap(functions_, [](decltype(functions_)::const_reference func) {
+        return func.second->serialize();
+      });
+  return Tup({kProducedNNCFileFormatVersion, Tup(std::move(functions))});
+}
+
+c10::impl::GenericList CompilationUnit::run(
+    const c10::QualifiedName& name,
+    const c10::impl::GenericList& inputs) const {
+  Function* func = find_function(name);
+  TORCH_CHECK(
+      func != nullptr, "Function '", name.qualifiedName(), "' is not defined.");
+  return func->run(inputs);
+}
+
+void CompilationUnit::register_function(std::unique_ptr<Function> fn) {
+  TORCH_CHECK(
+      0 == functions_.count(fn->name()),
+      "method '",
+      fn->name().qualifiedName(),
+      "' already defined.");
+  const auto& name = fn->name();
+  functions_.emplace(name, std::move(fn));
+}
+
+Function* CompilationUnit::find_function(const c10::QualifiedName& name) const {
+  auto it = functions_.find(name);
+  if (it == functions_.end()) {
+    return nullptr;
+  }
+  return it->second.get();
+}
+
+} // namespace nnc
+} // namespace mobile
+} // namespace jit
+} // namespace torch
diff --git a/torch/csrc/jit/mobile/nnc/context.h b/torch/csrc/jit/mobile/nnc/context.h
new file mode 100644
index 0000000000000..a2ec1760eb517
--- /dev/null
+++ b/torch/csrc/jit/mobile/nnc/context.h
@@ -0,0 +1,207 @@
+#pragma once
+
+#include <ATen/core/ivalue.h>
+#include <c10/core/Allocator.h>
+#include <c10/core/ScalarType.h>
+#include <c10/macros/Macros.h>
+
+#include <memory>
+#include <unordered_map>
+
+namespace torch {
+namespace jit {
+namespace mobile {
+namespace nnc {
+
+// Specify the requirements on an input tensor.
+// TODO: support input tensor with dynamic shape (PR #54982)
+struct TORCH_API InputSpec {
+  InputSpec() = default;
+
+  // Deserialize the spec from an IValue.
+  explicit InputSpec(const c10::IValue& value);
+
+  // Serialize the spec into an IValue.
+  C10_NODISCARD c10::IValue serialize() const;
+
+  // Check whether the input tensor adheres to the spec.
+  C10_NODISCARD bool validate(const at::Tensor& input) const;
+
+  std::vector<int64_t> sizes_;
+  c10::ScalarType dtype_{c10::ScalarType::Undefined};
+};
+
+// Specify the sizes/dtype/... of the output tensor so it can be preallocated.
+// TODO: support the case where the kernel allocates output tensors dynamically.
+struct TORCH_API OutputSpec {
+  OutputSpec() = default;
+
+  // Deserialize the spec from an IValue.
+  explicit OutputSpec(const c10::IValue& value);
+
+  // Serialize the spec into an IValue.
+  C10_NODISCARD c10::IValue serialize() const;
+
+  // Allocate an output tensor in accordance with the spec.
+  C10_NODISCARD at::Tensor allocate() const;
+
+  std::vector<int64_t> sizes_;
+  c10::ScalarType dtype_{c10::ScalarType::Undefined};
+};
+
+// Hold the temporary buffers / states needed during the execution.
+struct TORCH_API ExecutionState {
+  ExecutionState() = default;
+  ExecutionState(const ExecutionState&) = delete;
+  ExecutionState(ExecutionState&&) = default;
+  ExecutionState& operator=(const ExecutionState&) = delete;
+  ExecutionState& operator=(ExecutionState&&) = default;
+
+  // Preallocated buffers needed by the NNC kernel.
+  std::vector<c10::DataPtr> preallocations_;
+
+  // The NNC kernel expects the following arguments layout:
+  //   input tensor 1
+  //   ...
+  //   input tensor INPUT_NUM
+  //   output tensor 1
+  //   ...
+  //   output tensor OUTPUT_NUM
+  //   parameter tensor 1
+  //   ...
+  //   parameter tensor PARAM_NUM
+  //   temporary buffer 1
+  //   ...
+  //   temporary buffer BUFFER_NUM
+  std::vector<void*> arguments_;
+};
+
+// Specify how to allocate temporary buffers at initialization.
+struct TORCH_API MemoryPlan {
+  MemoryPlan() = default;
+
+  explicit MemoryPlan(const c10::IValue& value);
+
+  C10_NODISCARD c10::IValue serialize() const;
+
+  void allocate(ExecutionState* state) const;
+
+  std::vector<int64_t> buffer_sizes_;
+};
+
+// Represents a compiled NNC function which has a 1-1 correspondence with a
+// `Method` (e.g. `forward`). It's similar to torch::jit::mobile::Function.
+class TORCH_API Function {
+ public:
+  explicit Function() = default;
+
+  // Deserialize from an IValue that is generated by the 'serialize()' method.
+  explicit Function(const c10::IValue& value);
+
+  // Serialize into an IValue.
+  c10::IValue serialize() const;
+
+  // Execute the compiled NNC function.
+  c10::impl::GenericList run(const c10::impl::GenericList& inputs) const;
+
+  // The name of the function as specified in the model code.
+  c10::QualifiedName name() const {
+    return name_;
+  }
+
+  void set_name(const c10::QualifiedName& name) {
+    name_ = name;
+  }
+
+  // The unique id of the generated NNC kernel corresponding to the function.
+  const std::string& nnc_kernel_id() const {
+    return nnc_kernel_id_;
+  }
+
+  void set_nnc_kernel_id(const std::string& name) {
+    nnc_kernel_id_ = name;
+  }
+
+  // The parameters (e.g. weights / bias tensors) to be passed to the generated
+  // NNC kernel.
+  const std::vector<at::Tensor>& parameters() const {
+    return parameters_;
+  }
+
+  void set_parameters(const std::vector<at::Tensor>& parameters) {
+    parameters_ = parameters;
+  }
+
+  const std::vector<InputSpec>& input_specs() const {
+    return input_specs_;
+  }
+
+  void set_input_specs(const std::vector<InputSpec>& input_specs) {
+    input_specs_ = input_specs;
+  }
+
+  const std::vector<OutputSpec>& output_specs() const {
+    return output_specs_;
+  }
+
+  void set_output_spec(const std::vector<OutputSpec>& output_specs) {
+    output_specs_ = output_specs;
+  }
+
+  const MemoryPlan& memory_plan() const {
+    return memory_plan_;
+  }
+
+  void set_memory_plan(const MemoryPlan& memory_plan) {
+    memory_plan_ = memory_plan;
+  }
+
+ private:
+  void init_execution_state() const;
+
+  c10::QualifiedName name_;
+  std::string nnc_kernel_id_;
+  std::vector<at::Tensor> parameters_;
+  std::vector<InputSpec> input_specs_;
+  std::vector<OutputSpec> output_specs_;
+  MemoryPlan memory_plan_;
+  mutable std::unique_ptr<ExecutionState> execution_state_;
+};
+
+// CompilationUnit consists of a set of compiled NNC functions. It has a 1-1
+// correspondence with a `Module`.
+// It's similar to torch::jit::mobile::CompilationUnit.
+class TORCH_API CompilationUnit {
+ public:
+  CompilationUnit() = default;
+  CompilationUnit(const CompilationUnit&) = delete;
+  CompilationUnit(CompilationUnit&&) = default;
+  CompilationUnit& operator=(const CompilationUnit&) = delete;
+  CompilationUnit& operator=(CompilationUnit&&) = default;
+
+  // Deserialize from an IValue that is generated by the 'serialize()' method.
+  explicit CompilationUnit(const c10::IValue& value);
+
+  // Serialize all registered functions into an IValue. The IValue will be
+  // saved into the compiled TorchScript model file ahead-of-time on the host,
+  // and will be deserialized at runtime on the target device.
+  C10_NODISCARD c10::IValue serialize() const;
+
+  // Execute a registered function.
+  C10_NODISCARD c10::impl::GenericList run(
+      const c10::QualifiedName& function_name,
+      const c10::impl::GenericList& inputs) const;
+
+  // Register a function to the compilation unit.
+  void register_function(std::unique_ptr<Function> fn);
+
+ private:
+  C10_NODISCARD Function* find_function(const c10::QualifiedName& qn) const;
+
+  std::unordered_map<c10::QualifiedName, std::unique_ptr<Function>> functions_;
+};
+
+} // namespace nnc
+} // namespace mobile
+} // namespace jit
+} // namespace torch
diff --git a/torch/csrc/jit/mobile/nnc/registry.cpp b/torch/csrc/jit/mobile/nnc/registry.cpp
new file mode 100644
index 0000000000000..0dc5648463402
--- /dev/null
+++ b/torch/csrc/jit/mobile/nnc/registry.cpp
@@ -0,0 +1,14 @@
+#include <torch/csrc/jit/mobile/nnc/registry.h>
+
+namespace torch {
+namespace jit {
+namespace mobile {
+namespace nnc {
+
+// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
+C10_DEFINE_REGISTRY(NNCKernelRegistry, NNCKernel);
+
+} // namespace nnc
+} // namespace mobile
+} // namespace jit
+} // namespace torch
diff --git a/torch/csrc/jit/mobile/nnc/registry.h b/torch/csrc/jit/mobile/nnc/registry.h
new file mode 100644
index 0000000000000..14c6939d4c4f4
--- /dev/null
+++ b/torch/csrc/jit/mobile/nnc/registry.h
@@ -0,0 +1,46 @@
+#pragma once
+
+#include <ATen/core/ivalue.h>
+#include <c10/util/Registry.h>
+
+namespace torch {
+namespace jit {
+namespace mobile {
+namespace nnc {
+
+using nnc_kernel_function_type = int(void**);
+
+struct TORCH_API NNCKernel {
+  virtual ~NNCKernel() = default;
+  virtual int execute(void** /* args */) = 0;
+};
+
+C10_DECLARE_REGISTRY(NNCKernelRegistry, NNCKernel);
+
+#define REGISTER_NNC_KERNEL(id, kernel, ...)       \
+  extern "C" {                                     \
+  nnc_kernel_function_type kernel;                 \
+  }                                                \
+  struct NNCKernel_##kernel : public NNCKernel {   \
+    int execute(void** args) override {            \
+      return kernel(args);                         \
+    }                                              \
+  };                                               \
+  C10_REGISTER_TYPED_CLASS(NNCKernelRegistry, id, NNCKernel_##kernel);
+
+namespace registry {
+
+inline bool has_nnc_kernel(const std::string& id) {
+  return NNCKernelRegistry()->Has(id);
+}
+
+inline std::unique_ptr<NNCKernel> get_nnc_kernel(const std::string& id) {
+  return NNCKernelRegistry()->Create(id);
+}
+
+} // namespace registry
+
+} // namespace nnc
+} // namespace mobile
+} // namespace jit
+} // namespace torch
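
Note for reviewers: the argument-passing convention documented in `ExecutionState` (inputs, then outputs, then parameters, then temporary buffers, all passed as raw `void*`) is easiest to see with a hand-written kernel. The sketch below is illustrative only and is not part of the patch; the kernel name `add_bias_kernel`, the `"add_bias:v1:VERTOKEN"` id, and the fixed size of 8 floats are made up, and it assumes only the `REGISTER_NNC_KERNEL` / `Function` APIs introduced above.

```cpp
#include <torch/csrc/jit/mobile/nnc/context.h>
#include <torch/csrc/jit/mobile/nnc/registry.h>

namespace torch {
namespace jit {
namespace mobile {
namespace nnc {

extern "C" {
// Hypothetical generated kernel: out = input + bias, elementwise over 8 floats.
// `args` follows the ExecutionState layout: [input, output, parameter] here,
// since this Function declares one input, one output, one parameter, no buffers.
int add_bias_kernel(void** args) {
  const float* input = static_cast<const float*>(args[0]); // input tensor 1
  float* out = static_cast<float*>(args[1]);                // output tensor 1
  const float* bias = static_cast<const float*>(args[2]);   // parameter tensor 1
  for (int i = 0; i < 8; ++i) {
    out[i] = input[i] + bias[i];
  }
  return 0; // a non-zero return would indicate a kernel error
}
} // extern "C"

// Make the kernel discoverable by the string id stored in the serialized Function.
REGISTER_NNC_KERNEL("add_bias:v1:VERTOKEN", add_bias_kernel)

// Assemble a Function that dispatches to the kernel, mirroring what the AOT
// compiler would serialize into the model file. `input` must be a float
// tensor of 8 elements to pass InputSpec::validate().
inline c10::impl::GenericList run_add_bias(const at::Tensor& input) {
  Function f;
  f.set_nnc_kernel_id("add_bias:v1:VERTOKEN");

  InputSpec in_spec;
  in_spec.sizes_ = {8};
  in_spec.dtype_ = at::kFloat;
  f.set_input_specs({in_spec});

  OutputSpec out_spec;
  out_spec.sizes_ = {8};
  out_spec.dtype_ = at::kFloat;
  f.set_output_spec({out_spec});

  f.set_parameters({at::ones({8}, at::kFloat)}); // the bias parameter

  c10::List<at::Tensor> inputs({input});
  return f.run(c10::impl::toList(inputs));
}

} // namespace nnc
} // namespace mobile
} // namespace jit
} // namespace torch
```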
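
On the deserialization side, a minimal sketch of how a `CompilationUnit` would be driven at runtime, assuming the IValue produced by `CompilationUnit::serialize()` has already been read back from the model file; the helper name and the `"forward"` qualified name are illustrative, not part of the patch:

```cpp
#include <torch/csrc/jit/mobile/nnc/context.h>

namespace torch {
namespace jit {
namespace mobile {
namespace nnc {

// `serialized` is the tuple written ahead-of-time on the host:
// (kProducedNNCFileFormatVersion, (function_0, function_1, ...)).
inline c10::impl::GenericList run_forward(
    const c10::IValue& serialized,
    const c10::impl::GenericList& inputs) {
  CompilationUnit cu(serialized); // re-registers each serialized Function
  return cu.run(c10::QualifiedName("forward"), inputs);
}

} // namespace nnc
} // namespace mobile
} // namespace jit
} // namespace torch
```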