Skip to content

Commit

Permalink
[jit][edge] Load interface methods to corresponding ClassTypes. (pyto…
Browse files Browse the repository at this point in the history
…rch#65971)

Summary:
Pull Request resolved: pytorch#65971

ghstack-source-id: 141842335

We should be able to load methods into their ClassTypes. Right now mobile runtime only loads data member to ClassTypes but not for methods. To support interface call, we inject methods into ClassTypes when the methods are loaded.

Test Plan: existing tests should all pass.

Reviewed By: qihqi

Differential Revision: D31326146

fbshipit-source-id: fb1dbea619910ef1f8fa26146da3ebab348fe902
  • Loading branch information
zhxchen17 authored and facebook-github-bot committed Oct 29, 2021
1 parent 6259601 commit d6b15bf
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 2 deletions.
19 changes: 18 additions & 1 deletion torch/csrc/jit/mobile/import.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
#include <torch/csrc/jit/mobile/parse_operators.h>

#include <ATen/core/ivalue.h>
#include <ATen/core/qualified_name.h>
#include <c10/util/Exception.h>
#include <c10/util/ScopeExit.h>
#include <c10/util/irange.h>
#include <caffe2/serialize/inline_container.h>
Expand Down Expand Up @@ -171,6 +173,19 @@ bool isTensorInBytecodeArchive(

namespace {

void tryRegisterMethod(const std::vector<c10::Argument>& args, Function& func) {
if (args.empty() || args[0].name() != "self") {
return;
}

if (auto cls = args[0].type()->castRaw<ClassType>()) {
if (C10_UNLIKELY(cls->findMethod(func.name()))) {
return;
}
cls->addMethod(&func);
}
}

// The deserializer class which loads the bytecode package from bc files.
class BytecodeDeserializer final {
public:
Expand Down Expand Up @@ -227,7 +242,8 @@ void BytecodeDeserializer::parseFunctionSchema(
mobile::Function* function) {
// function schema
if (schemaTable) { // (schema is optional for back compat)
auto parseArgList = [this](c10::ivalue::TupleElements&& argTables) {
auto parseArgList = [this,
function](c10::ivalue::TupleElements&& argTables) {
std::vector<c10::Argument> args;
for (auto&& argTable : std::move(argTables)) {
auto argTableElements =
Expand All @@ -249,6 +265,7 @@ void BytecodeDeserializer::parseFunctionSchema(
c10::nullopt /*N*/,
std::move(default_value));
}
tryRegisterMethod(args, *function);
return args;
};
auto schemaTableElements =
Expand Down
2 changes: 1 addition & 1 deletion torch/csrc/jit/mobile/module.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ Method Module::get_method(const std::string& name) const {
}

c10::optional<Method> Module::find_method(const std::string& basename) const {
for (auto& fn : cu_->methods()) {
for (const auto& fn : cu_->methods()) {
if (fn->name() == basename) {
return c10::make_optional<Method>(Method(this, fn.get()));
}
Expand Down

0 comments on commit d6b15bf

Please sign in to comment.