Add ability for a mobile::Module to save as flatbuffer (pytorch#67351)
Summary:
Included functions:

* save_mobile_module -> saves a mobile::Module to a flatbuffer file
* load_mobile_module_from_file -> loads a flatbuffer file into a mobile::Module
* parse_mobile_module -> parses a mobile::Module from raw bytes or from an
  already-deserialized flatbuffer Module object

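Example (a minimal round-trip sketch, not authoritative API documentation:
the exact declarations live in flatbuffer_serializer.h and
flatbuffer_loader.h, and the file name, function qualification, and
argument types below are illustrative):

    #include <torch/csrc/jit/mobile/flatbuffer_loader.h>
    #include <torch/csrc/jit/mobile/module.h>
    #include <torch/csrc/jit/serialization/flatbuffer_serializer.h>

    using torch::jit::mobile::Module;

    void round_trip(const Module& m) {
      // Save the mobile module as a flatbuffer file (illustrative path).
      torch::jit::save_mobile_module(m, "model.ff");

      // Load the flatbuffer file back into a mobile::Module.
      Module loaded = torch::jit::load_mobile_module_from_file("model.ff");

      // Or parse a module from bytes already in memory, e.g. a buffer
      // produced by save_mobile_module_to_bytes (used in the tests below):
      // Module parsed = torch::jit::parse_mobile_module(data, size);
    }
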
Fixes #{issue number}

Pull Request resolved: pytorch#67351

Reviewed By: iseeyuan

Differential Revision: D32010095

Pulled By: qihqi

fbshipit-source-id: d763b0557780f7c2661b6485105b045e41a5e8f1
qihqi authored and facebook-github-bot committed Dec 2, 2021
1 parent 40fb28e commit 41d35dc
Showing 29 changed files with 2,307 additions and 19 deletions.
3 changes: 3 additions & 0 deletions .gitignore
@@ -311,3 +311,6 @@ pr.diff

# coverage files
*/**/.coverage.*

# generated flatbuffer schema header
torch/csrc/jit/serialization/mobile_bytecode_generated.h
3 changes: 3 additions & 0 deletions .gitmodules
@@ -142,3 +142,6 @@
[submodule "third_party/breakpad"]
path = third_party/breakpad
url = https://github.com/driazati/breakpad.git
[submodule "third_party/flatbuffers"]
path = third_party/flatbuffers
url = https://github.com/google/flatbuffers.git
2 changes: 2 additions & 0 deletions .jenkins/pytorch/win-build.sh
@@ -56,6 +56,8 @@ if [[ $PYLONG_API_CHECK == 0 ]]; then
fi
set -ex

echo 'python %*' > /c/Windows/py.bat

"$SCRIPT_HELPERS_DIR"/build_pytorch.bat

assert_git_not_dirty
38 changes: 38 additions & 0 deletions BUILD.bazel
@@ -7,6 +7,7 @@ load("//:tools/build_variables.bzl", "torch_cpp_srcs", "libtorch_python_core_sou
load("//tools/rules:cu.bzl", "cu_library")
load("//tools/config:defs.bzl", "if_cuda")
load("//:aten.bzl", "intern_build_aten_ops")
load("@com_github_google_flatbuffers//:build_defs.bzl", "flatbuffer_cc_library")

COMMON_COPTS = [
"-DHAVE_MALLOC_USABLE_SIZE=1",
@@ -1833,6 +1834,14 @@ genrule(
tools = [':gen_version_header']
)

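# Generates the C++ header for the mobile bytecode flatbuffer schema;
# the flags match the flatc invocation in scripts/gen_flatbuffer.sh.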
flatbuffer_cc_library(
    name = "mobile_bytecode_header",
    srcs = ["torch/csrc/jit/serialization/mobile_bytecode.fbs"],
    out_prefix = "torch/csrc/jit/serialization/",
    flatc_args = ["--gen-mutable", "--scoped-enums"],
)


torch_cuda_headers = glob(["torch/csrc/cuda/*.h"])
cc_library(
name = "torch_headers",
@@ -1864,6 +1873,7 @@ cc_library(
":aten_headers",
":c10_headers",
":caffe2_headers",
":mobile_bytecode_header",
"@local_config_python//:python_headers",
"@onnx",
],
@@ -1906,6 +1916,32 @@ cc_library(
alwayslink = True,
)

cc_library(
    name = "flatbuffer_loader",
    srcs = [
        "torch/csrc/jit/mobile/flatbuffer_loader.cpp",
    ],
    hdrs = [
        "torch/csrc/jit/mobile/flatbuffer_loader.h",
    ],
    deps = [
        ":torch",
    ],
)

cc_library(
    name = "flatbuffer_serializer",
    srcs = [
        "torch/csrc/jit/serialization/flatbuffer_serializer.cpp",
    ],
    hdrs = [
        "torch/csrc/jit/serialization/flatbuffer_serializer.h",
    ],
    deps = [
        ":torch",
    ],
)

cc_library(
name = "shm",
srcs = glob(["torch/lib/libshm/*.cpp"]),
@@ -2056,6 +2092,8 @@ cc_test(
],
deps = [
":torch",
":flatbuffer_serializer",
":flatbuffer_loader",
"@com_google_googletest//:gtest_main",
],
)
1 change: 1 addition & 0 deletions CMakeLists.txt
@@ -770,6 +770,7 @@ if(NOT MSVC)
string(APPEND CMAKE_CXX_FLAGS " -Wno-error=deprecated-declarations")
if(CMAKE_COMPILER_IS_GNUCXX AND NOT (CMAKE_CXX_COMPILER_VERSION VERSION_LESS 7.0.0))
string(APPEND CMAKE_CXX_FLAGS " -Wno-stringop-overflow")
string(APPEND CMAKE_CXX_FLAGS " -Wno-noexcept-type")
endif()
if(CMAKE_COMPILER_IS_GNUCXX)
# Suppress "The ABI for passing parameters with 64-byte alignment has changed in GCC 4.6"
5 changes: 5 additions & 0 deletions WORKSPACE
@@ -181,3 +181,8 @@ new_empty_repository(
name = "cuda",
build_file = "//third_party:cuda.BUILD",
)

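# Make the flatbuffers git submodule under third_party available to Bazel.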
local_repository(
    name = "com_github_google_flatbuffers",
    path = "third_party/flatbuffers",
)
29 changes: 29 additions & 0 deletions caffe2/CMakeLists.txt
@@ -557,6 +557,7 @@ if(NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE)
${TORCH_SRC_DIR}/csrc/jit/mobile/interpreter.cpp
${TORCH_SRC_DIR}/csrc/jit/mobile/model_compatibility.cpp
${TORCH_SRC_DIR}/csrc/jit/mobile/module.cpp
${TORCH_SRC_DIR}/csrc/jit/mobile/flatbuffer_loader.cpp
${TORCH_SRC_DIR}/csrc/jit/mobile/observer.cpp
${TORCH_SRC_DIR}/csrc/jit/mobile/parse_bytecode.cpp
${TORCH_SRC_DIR}/csrc/jit/mobile/parse_operators.cpp
@@ -591,8 +592,10 @@ if(NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE)
${TORCH_SRC_DIR}/csrc/jit/serialization/export.cpp
${TORCH_SRC_DIR}/csrc/jit/serialization/export_bytecode.cpp
${TORCH_SRC_DIR}/csrc/jit/serialization/export_module.cpp
${TORCH_SRC_DIR}/csrc/jit/serialization/flatbuffer_serializer.cpp
${TORCH_SRC_DIR}/csrc/jit/codegen/fuser/cpu/fused_kernel.cpp
${TORCH_SRC_DIR}/csrc/jit/api/module_save.cpp
${TORCH_SRC_DIR}/csrc/jit/testing/module_differ.cpp
${TORCH_SRC_DIR}/csrc/utils/byte_order.cpp
)

@@ -1640,6 +1643,32 @@ if(APPLE AND USE_PYTORCH_METAL)
endif()
endif()


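# Generate mobile_bytecode_generated.h from the flatbuffer schema at build
# time and expose it to torch_cpu through an INTERFACE library.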
set(schema ${TORCH_SRC_DIR}/csrc/jit/serialization/mobile_bytecode.fbs)
set(generated_include
    "${TORCH_ROOT}/build/torch/csrc/jit/serialization/mobile_bytecode_generated.h")
# --reflect-names can also be added here.
add_custom_command(
    OUTPUT ${generated_include}
    COMMAND bash ${TORCH_ROOT}/scripts/gen_flatbuffer.sh
    DEPENDS ${schema}
    WORKING_DIRECTORY "${TORCH_ROOT}"
    COMMENT "Generating mobile_bytecode_generated.h"
)
add_library(mobile_bytecode_generated_h INTERFACE)
target_sources(
    mobile_bytecode_generated_h
    INTERFACE ${generated_include}
)
add_dependencies(mobile_bytecode_generated_h flatc ${generated_include})
target_include_directories(
    mobile_bytecode_generated_h
    INTERFACE ${TORCH_ROOT}/build)

add_dependencies(torch_cpu mobile_bytecode_generated_h)
target_link_libraries(
    torch_cpu PRIVATE mobile_bytecode_generated_h flatbuffers)

# Note [Global dependencies]
# Some libraries (e.g. OpenMPI) like to dlopen plugins after they're initialized,
# and they assume that all of their symbols will be available in the global namespace.
3 changes: 3 additions & 0 deletions cmake/Dependencies.cmake
@@ -1995,3 +1995,6 @@ if(USE_KINETO)
message(STATUS "Configured Kineto")
endif()
endif()

# Include google/FlatBuffers
include(${CMAKE_CURRENT_LIST_DIR}/FlatBuffers.cmake)
2 changes: 2 additions & 0 deletions cmake/FlatBuffers.cmake
@@ -0,0 +1,2 @@
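# Builds the vendored FlatBuffers (including the flatc compiler);
# EXCLUDE_FROM_ALL keeps its targets out of the default build.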
set(CMAKE_BUILD_WITH_INSTALL_RPATH ON CACHE BOOL "" FORCE)
add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/../third_party/flatbuffers ${CMAKE_CURRENT_BINARY_DIR}/flatbuffers-build EXCLUDE_FROM_ALL)
1 change: 1 addition & 0 deletions requirements.txt
@@ -1,6 +1,7 @@
# Python dependencies required for development
astunparse
expecttest
flatbuffers
future
numpy
psutil
15 changes: 15 additions & 0 deletions scripts/gen_flatbuffer.sh
@@ -0,0 +1,15 @@
#!/bin/bash
ROOT=$(pwd)
FF_LOCATION="$ROOT/third_party/flatbuffers"
cd "$FF_LOCATION" || exit
# Build the flatc compiler from the vendored flatbuffers checkout.
mkdir -p build
cd build || exit
py() { command python "$@"; }  # helper to invoke python (currently unused)
cmake ..
cmake --build . --target flatc
# Generate the C++ header for the mobile bytecode schema into the build tree.
mkdir -p "$ROOT/build/torch/csrc/jit/serialization"
./flatc --cpp --gen-mutable --scoped-enums \
  -o "$ROOT/build/torch/csrc/jit/serialization" \
  -c "$ROOT/torch/csrc/jit/serialization/mobile_bytecode.fbs"
cd "$ROOT" || exit
exit
3 changes: 3 additions & 0 deletions test/cpp/jit/CMakeLists.txt
@@ -97,6 +97,9 @@ add_executable(test_jit
${TORCH_ROOT}/test/cpp/common/main.cpp
${JIT_TEST_SRCS}
)
add_dependencies(test_jit flatbuffers)
target_link_libraries(test_jit PRIVATE flatbuffers)


# TODO temporary until we can delete the old gtest polyfills.
target_compile_definitions(test_jit PRIVATE USE_GTEST)
58 changes: 57 additions & 1 deletion test/cpp/jit/test_backend.cpp
@@ -2,13 +2,21 @@
#include <test/cpp/jit/test_utils.h>
#include <torch/csrc/jit/api/module.h>
#include <torch/csrc/jit/backends/backend_detail.h>
#include <torch/csrc/jit/mobile/flatbuffer_loader.h>
#include <torch/csrc/jit/mobile/import.h>
#include <torch/csrc/jit/serialization/flatbuffer_serializer.h>
#include <torch/csrc/jit/serialization/import.h>
#include <torch/torch.h>

// Tests go in torch::jit
namespace torch {
namespace jit {

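// Test helper: reinterpret a buffer produced by save_mobile_module_to_bytes
// as a mutable flatbuffer Module and materialize a runnable mobile::Module.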
mobile::Module load_mobile_module(void* data, size_t) {
  auto* flatbuffer_module = mobile::serialization::GetMutableModule(data);
  return initialize_mobile_module(flatbuffer_module);
}

TEST(BackendTest, ToBackend) {
Module m("m");
m.define(R"(
@@ -141,6 +149,11 @@ TEST(BackendTest, TestCompiler) {
auto mlm = _load_for_mobile(ss);
auto mres = mlm.forward(inputs);
AT_ASSERT(mres.toTensor().equal(ref.toTensor()));

auto buff = save_mobile_module_to_bytes(mlm);
mobile::Module mlm2 = load_mobile_module(buff.data(), buff.size());
auto mres2 = mlm2.forward(inputs);
AT_ASSERT(mres2.toTensor().equal(ref.toTensor()));
}

TEST(BackendTest, TestComposite) {
@@ -183,8 +196,12 @@ TEST(BackendTest, TestComposite) {
c._save_for_mobile(ss);
auto mc = _load_for_mobile(ss);
auto res_mobile = mc.forward(inputs);

AT_ASSERT(res_jit.toTensor().equal(res_mobile.toTensor()));

auto buff = save_mobile_module_to_bytes(mc);
mobile::Module mlm2 = load_mobile_module(buff.data(), buff.size());
auto mres2 = mlm2.forward(inputs);
AT_ASSERT(mres2.toTensor().equal(res_jit.toTensor()));
}

Module getCompositeModuleWithSameNameSubModules() {
@@ -241,6 +258,11 @@ TEST(BackendTest, TestCompositeWithSetStates) {
auto mc = _load_for_mobile(ss);
auto res_mobile = mc.forward(inputs);
AT_ASSERT(res_jit.toTensor().equal(res_mobile.toTensor()));

auto buff = save_mobile_module_to_bytes(mc);
mobile::Module mlm2 = load_mobile_module(buff.data(), buff.size());
auto mres2 = mlm2.forward(inputs);
AT_ASSERT(mres2.toTensor().equal(res_jit.toTensor()));
}

TEST(BackendTest, TestConsistencyOfCompositeWithSetStates) {
@@ -256,6 +278,11 @@ TEST(BackendTest, TestConsistencyOfCompositeWithSetStates) {
auto mc = _load_for_mobile(ss);
auto res_mobile = mc.forward(inputs);

auto buff = save_mobile_module_to_bytes(mc);
mobile::Module mlm2 = load_mobile_module(buff.data(), buff.size());
auto mres2 = mlm2.forward(inputs);
AT_ASSERT(mres2.toTensor().equal(res_mobile.toTensor()));

// check if the methods names are always the same
// by reloading the script module and saving it back as mobile
// The below checks ensure that the names of Methods
@@ -354,6 +381,13 @@ Traceback of TorchScript (most recent call last):
~~~~~ <--- HERE
)";
ASSERT_THROWS_WITH_MESSAGE(mlm.forward(inputs), error_pattern);

/* TODO(add debug info to flatbuffer)
auto buff = save_mobile_module_to_bytes(mlm);
mobile::Module mlm2 = load_mobile_module(buff.data(), buff.size());
mlm2.forward(inputs);
ASSERT_THROWS_WITH_MESSAGE(mlm2.forward(inputs), error_pattern);
*/
}

TEST(BackendTestDebugInfo, TestExceptionStackForCompilerWithModuleHierarchy) {
Expand Down Expand Up @@ -414,6 +448,12 @@ Traceback of TorchScript (most recent call last):
~~~~~ <--- HERE
)";
ASSERT_THROWS_WITH_MESSAGE(mlm.forward(inputs), error_pattern);

/* TODO(add debug info to flatbuffer)
auto buff = save_mobile_module_to_bytes(mlm);
mobile::Module mlm2 = load_mobile_module(buff.data(), buff.size());
ASSERT_THROWS_WITH_MESSAGE(mlm2.forward(inputs), error_pattern);
*/
}

TEST(
@@ -512,7 +552,13 @@ Traceback of TorchScript (most recent call last):
return x + y
~~~~~ <--- HERE
)";

ASSERT_THROWS_WITH_MESSAGE(mlm.forward(inputs), error_pattern);
/* TODO(add debug info to flatbuffer)
auto buff = save_mobile_module_to_bytes(mlm);
mobile::Module mlm2 = load_mobile_module(buff.data(), buff.size());
ASSERT_THROWS_WITH_MESSAGE(mlm2.forward(inputs), error_pattern);
*/
}

TEST(BackendTestDebugInfo, TestExceptionStackForCompilerWithLoweredSubModule) {
Expand Down Expand Up @@ -594,6 +640,11 @@ Traceback of TorchScript (most recent call last):
~~~~~ <--- HERE
)";
ASSERT_THROWS_WITH_MESSAGE(c_loaded.forward(inputs), error_pattern);
/* TODO(add debug info to flatbuffer)
auto buff = save_mobile_module_to_bytes(c_loaded);
mobile::Module mlm2 = load_mobile_module(buff.data(), buff.size());
ASSERT_THROWS_WITH_MESSAGE(mlm2.forward(inputs), error_pattern);
*/
}

TEST(
Expand Down Expand Up @@ -721,6 +772,11 @@ Traceback of TorchScript (most recent call last):
~~~~~ <--- HERE
)";
ASSERT_THROWS_WITH_MESSAGE(c_loaded.forward(inputs), error_pattern);
/* TODO(add debug info to flatbuffer)
auto buff = save_mobile_module_to_bytes(c_loaded);
mobile::Module mlm2 = load_mobile_module(buff.data(), buff.size());
ASSERT_THROWS_WITH_MESSAGE(mlm2.forward(inputs), error_pattern);
*/
}

} // namespace jit