Skip to content

Commit

Permalink
Revert D33284352: [jit][edge] Do not reuse mobile type parser for all…
Browse files Browse the repository at this point in the history
… unpicklers.

Test Plan: revert-hammer

Differential Revision:
D33284352 (pytorch@0a921ba)

Original commit changeset: 997c4f110b36

Original Phabricator Diff: D33284352 (pytorch@0a921ba)

fbshipit-source-id: af316727442a64f1ae40d53d7a9d26ec550d634e
  • Loading branch information
zhxchen17 authored and facebook-github-bot committed Jan 8, 2022
1 parent f626bef commit 9762aa0
Show file tree
Hide file tree
Showing 20 changed files with 52 additions and 79 deletions.
6 changes: 5 additions & 1 deletion test/cpp/jit/test_mobile_type_parser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,11 @@
#include <test/cpp/jit/test_utils.h>

#include <ATen/core/jit_type.h>
#include <torch/csrc/jit/mobile/type_parser.h>

namespace c10 {
TypePtr parseType(const std::string& pythonStr);
std::vector<TypePtr> parseType(std::vector<std::string>& pythonStr);
} // namespace c10

namespace torch {
namespace jit {
Expand Down
8 changes: 5 additions & 3 deletions test/mobile/test_upgrader_bytecode_table_example.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,13 @@
* cd ~/pytorch && python torch/csrc/jit/mobile/upgrader_mobile.cpp
*/

#include <caffe2/serialize/versions.h>
#include <torch/csrc/jit/mobile/upgrader_mobile.h>

#include <ATen/core/ivalue.h>
#include <caffe2/serialize/versions.h>
#include <torch/csrc/jit/mobile/type_parser.h>

namespace c10 {
TypePtr parseType(const std::string& pythonStr);
} // namespace c10

namespace torch {
namespace jit {
Expand Down
8 changes: 5 additions & 3 deletions tools/codegen/operator_versions/gen_mobile_upgraders.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,11 +101,13 @@ class ByteCode(Enum):
* cd ~/pytorch && python torch/csrc/jit/mobile/upgrader_mobile.cpp
*/
#include <caffe2/serialize/versions.h>
#include <torch/csrc/jit/mobile/upgrader_mobile.h>
#include <ATen/core/ivalue.h>
#include <caffe2/serialize/versions.h>
#include <torch/csrc/jit/mobile/type_parser.h>
namespace c10 {
TypePtr parseType(const std::string& pythonStr);
} // namespace c10
namespace torch {
namespace jit {
Expand Down
9 changes: 5 additions & 4 deletions torch/csrc/jit/frontend/tree.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ struct Tree;
using TreeRef = c10::intrusive_ptr<Tree>;
using TreeList = at::SmallVector<TreeRef, 4>;

static const TreeList empty_trees = {};

struct Tree : c10::intrusive_ptr_target {
Tree(int kind_) : kind_(kind_) {}
int kind() const {
Expand All @@ -44,7 +46,6 @@ struct Tree : c10::intrusive_ptr_target {
throw std::runtime_error("stringValue can only be called on TK_STRING");
}
virtual const TreeList& trees() const {
static const TreeList empty_trees = {};
return empty_trees;
}
const TreeRef& tree(size_t i) const {
Expand Down Expand Up @@ -148,11 +149,11 @@ struct Compound : public Tree {
return false;
}
TreeRef map(const std::function<TreeRef(TreeRef)>& fn) override {
TreeList ret;
TreeList trees_;
for (auto& t : trees()) {
ret.push_back(fn(t));
trees_.push_back(fn(t));
}
return Compound::create(kind(), range(), std::move(ret));
return Compound::create(kind(), range(), std::move(trees_));
}

const SourceRange& range() const override {
Expand Down
12 changes: 4 additions & 8 deletions torch/csrc/jit/mobile/debug_info.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
#include <torch/csrc/jit/frontend/source_range.h>
#include <torch/csrc/jit/mobile/debug_info.h>
#include <torch/csrc/jit/mobile/type_parser.h>
#include <torch/csrc/jit/serialization/callstack_debug_info_serialization.h>
#include <torch/csrc/jit/serialization/source_range_serialization.h>

Expand Down Expand Up @@ -123,13 +122,10 @@ MobileDebugTable::MobileDebugTable(
size_t debug_size{0};
std::tie(debug_data, debug_size) = reader->getRecord(record_name);
auto ivalues =
std::move(*jit::unpickle(
reinterpret_cast<const char*>(debug_data.get()),
debug_size,
nullptr,
{},
c10::parseType)
.toTuple())
std::move(
*jit::unpickle(
reinterpret_cast<const char*>(debug_data.get()), debug_size)
.toTuple())
.elements();
SourceRangeDeserializer deserializer;
for (auto& val : ivalues) {
Expand Down
9 changes: 6 additions & 3 deletions torch/csrc/jit/mobile/import.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
#include <torch/csrc/jit/api/compilation_unit.h>
#include <torch/csrc/jit/mobile/interpreter.h>
#include <torch/csrc/jit/mobile/observer.h>
#include <torch/csrc/jit/mobile/type_parser.h>
#include <torch/csrc/jit/mobile/upgrader_mobile.h>
#include <torch/csrc/jit/runtime/instruction.h>
#include <torch/csrc/jit/serialization/import_export_constants.h>
Expand Down Expand Up @@ -79,6 +78,11 @@
// - Argument::{known_length_,kwarg_only_}
// - FunctionSchema::{overload_name_, is_vararg_, is_varret_}

namespace c10 {
// std::string serializeType(const Type &t);
TypePtr parseType(const std::string& pythonStr);
} // namespace c10

namespace torch {
namespace jit {
using caffe2::serialize::IStreamAdapter;
Expand Down Expand Up @@ -498,8 +502,7 @@ c10::IValue BytecodeDeserializer::readArchive(
type_resolver,
obj_loader,
device_,
*reader_.get(),
nullptr);
*reader_.get());
return ivalues;
}

Expand Down
10 changes: 6 additions & 4 deletions torch/csrc/jit/mobile/import_data.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
#include <caffe2/serialize/inline_container.h>
#include <torch/csrc/jit/api/compilation_unit.h>
#include <torch/csrc/jit/mobile/observer.h>
#include <torch/csrc/jit/mobile/type_parser.h>
#include <torch/csrc/jit/runtime/instruction.h>
#include <torch/csrc/jit/serialization/unpickler.h>
#include <torch/custom_class.h>
Expand All @@ -15,6 +14,11 @@
#include <string>
#include <vector>

namespace c10 {
// std::string serializeType(const Type &t);
TypePtr parseType(const std::string& pythonStr);
} // namespace c10

namespace torch {
namespace jit {
using caffe2::serialize::IStreamAdapter;
Expand Down Expand Up @@ -147,9 +151,7 @@ c10::IValue BytecodeDeserializer::readArchive(
std::move(obj_loader),
std::move(read_record),
// NOLINTNEXTLINE(performance-move-const-arg)
std::move(device),
false,
nullptr);
std::move(device));
return unpickler.parse_ivalue();
}

Expand Down
3 changes: 1 addition & 2 deletions torch/csrc/jit/mobile/model_compatibility.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,7 @@ c10::IValue readArchive(
type_resolver,
obj_loader,
device,
stream_reader,
nullptr);
stream_reader);
return ivalues;
}

Expand Down
2 changes: 0 additions & 2 deletions torch/csrc/jit/mobile/type_parser.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
#include <torch/csrc/jit/mobile/type_parser.h>

#include <ATen/core/jit_type.h>
#include <c10/util/string_view.h>
#include <torch/csrc/jit/frontend/parser_constants.h>
Expand Down
2 changes: 0 additions & 2 deletions torch/csrc/jit/mobile/type_parser.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
#pragma once

#include <ATen/core/dynamic_type.h>
#include <ATen/core/jit_type.h>

Expand Down
4 changes: 2 additions & 2 deletions torch/csrc/jit/runtime/register_ops_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -384,8 +384,8 @@ void listAdd(Stack& stack) {
}

void listInplaceAdd(Stack& stack) {
c10::List<IValue> b = pop(stack).to<c10::List<IValue>>();
c10::List<IValue> a = pop(stack).to<c10::List<IValue>>();
c10::List<IValue> b = pop(stack).to<List<IValue>>();
c10::List<IValue> a = pop(stack).to<List<IValue>>();
a.append(std::move(b));
push(stack, std::move(a));
}
Expand Down
8 changes: 4 additions & 4 deletions torch/csrc/jit/runtime/register_prim_ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -973,7 +973,7 @@ static const std::vector<OperatorGeneratorArgs> opGenArgs{
TORCH_SELECTIVE_SCHEMA(
"aten::index.Tensor_hacked_twin(Tensor self, Tensor[] indices) -> Tensor"),
[](Stack& stack) {
auto indices = pop(stack).to<c10::List<c10::optional<at::Tensor>>>();
auto indices = pop(stack).to<List<c10::optional<at::Tensor>>>();
auto self = pop(stack).toTensor();
auto result = at::index(self, indices);
push(stack, std::move(result));
Expand All @@ -986,7 +986,7 @@ static const std::vector<OperatorGeneratorArgs> opGenArgs{
auto unsafe = pop(stack).toBool();
auto accumulate = pop(stack).toBool();
auto values = pop(stack).toTensor();
auto indices = pop(stack).to<c10::List<c10::optional<at::Tensor>>>();
auto indices = pop(stack).to<List<c10::optional<at::Tensor>>>();
auto self = pop(stack).toTensor();
auto result =
at::_index_put_impl_(self, indices, values, accumulate, unsafe);
Expand All @@ -999,7 +999,7 @@ static const std::vector<OperatorGeneratorArgs> opGenArgs{
[](Stack& stack) {
auto accumulate = pop(stack).toBool();
auto values = pop(stack).toTensor();
auto indices = pop(stack).to<c10::List<c10::optional<at::Tensor>>>();
auto indices = pop(stack).to<List<c10::optional<at::Tensor>>>();
auto self = pop(stack).toTensor();
auto result = at::index_put_(self, indices, values, accumulate);
push(stack, std::move(result));
Expand All @@ -1011,7 +1011,7 @@ static const std::vector<OperatorGeneratorArgs> opGenArgs{
[](Stack& stack) {
auto accumulate = pop(stack).toBool();
auto values = pop(stack).toTensor();
auto indices = pop(stack).to<c10::List<c10::optional<at::Tensor>>>();
auto indices = pop(stack).to<List<c10::optional<at::Tensor>>>();
auto self = pop(stack).toTensor();
auto result = at::index_put_(self, indices, values, accumulate);
push(stack, std::move(result));
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
#include <torch/csrc/jit/api/compilation_unit.h>
#include <torch/csrc/jit/mobile/type_parser.h>
#include <torch/csrc/jit/serialization/callstack_debug_info_serialization.h>
#include <torch/csrc/jit/serialization/pickle.h>

Expand Down Expand Up @@ -215,12 +214,7 @@ ska::flat_hash_map<int64_t, DebugInfoTuple> CallStackDebugInfoUnpickler::
size_t size,
const ska::flat_hash_map<int64_t, SourceRange>& source_range_map,
const std::shared_ptr<CompilationUnit>& cu) {
auto ival = jit::unpickle(
reinterpret_cast<const char*>(data.get()),
size,
nullptr,
{},
c10::parseType);
auto ival = jit::unpickle(reinterpret_cast<const char*>(data.get()), size);
ska::flat_hash_map<int64_t, DebugInfoTuple> callstack_ptrs;
auto ivalues = std::move(*std::move(ival).toTuple()).elements();
for (auto& val : ivalues) {
Expand Down
1 change: 0 additions & 1 deletion torch/csrc/jit/serialization/import.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,6 @@ IValue ScriptModuleDeserializer::readArchive(const std::string& archive_name) {
obj_loader,
device_,
*reader_.get(),
nullptr,
storage_context_);
}

Expand Down
2 changes: 0 additions & 2 deletions torch/csrc/jit/serialization/import_read.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ IValue readArchiveAndTensors(
c10::optional<ObjLoader> obj_loader,
c10::optional<at::Device> device,
caffe2::serialize::PyTorchStreamReader& stream_reader,
c10::TypePtr (*type_parser)(const std::string&),
std::shared_ptr<DeserializationStorageContext> storage_context) {
std::string picklename = pickle_prefix + archive_name + ".pkl";
at::DataPtr pickle_ptr;
Expand Down Expand Up @@ -48,7 +47,6 @@ IValue readArchiveAndTensors(
std::move(read_record),
device,
false,
type_parser,
storage_context);
unpickler.set_version(stream_reader.version());
return unpickler.parse_ivalue();
Expand Down
2 changes: 0 additions & 2 deletions torch/csrc/jit/serialization/import_read.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,6 @@ TORCH_API IValue readArchiveAndTensors(
c10::optional<ObjLoader> obj_loader,
c10::optional<at::Device> device,
caffe2::serialize::PyTorchStreamReader& stream_reader,
c10::TypePtr (*type_parser)(const std::string&) =
Unpickler::defaultTypeParser,
std::shared_ptr<DeserializationStorageContext> storage_context = nullptr);

bool check_zip_file(
Expand Down
11 changes: 4 additions & 7 deletions torch/csrc/jit/serialization/pickle.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -120,19 +120,17 @@ IValue pickle_load(const std::vector<char>& data) {
IValue unpickle(
std::function<size_t(char*, size_t)> reader,
TypeResolver type_resolver,
c10::ArrayRef<at::Tensor> tensor_table,
c10::TypePtr (*type_parser)(const std::string&)) {
c10::ArrayRef<at::Tensor> tensor_table) {
Unpickler unpickler(
std::move(reader), std::move(type_resolver), tensor_table, type_parser);
std::move(reader), std::move(type_resolver), tensor_table);
return unpickler.parse_ivalue();
}

IValue unpickle(
const char* data,
size_t size,
TypeResolver type_resolver,
c10::ArrayRef<at::Tensor> tensor_table,
c10::TypePtr (*type_parser)(const std::string&)) {
c10::ArrayRef<at::Tensor> tensor_table) {
size_t bytes_read = 0;
return unpickle(
[&](char* buffer, size_t len) -> size_t {
Expand All @@ -147,8 +145,7 @@ IValue unpickle(
return len;
},
std::move(type_resolver),
tensor_table,
type_parser);
tensor_table);
}

} // namespace jit
Expand Down
8 changes: 2 additions & 6 deletions torch/csrc/jit/serialization/pickle.h
Original file line number Diff line number Diff line change
Expand Up @@ -69,9 +69,7 @@ TORCH_API IValue pickle_load(const std::vector<char>& data);
TORCH_API IValue unpickle(
std::function<size_t(char*, size_t)> reader,
TypeResolver type_resolver,
c10::ArrayRef<at::Tensor> tensor_table,
c10::TypePtr (*type_parser)(const std::string&) =
Unpickler::defaultTypeParser);
c10::ArrayRef<at::Tensor> tensor_table);

/// Decode a chunk of memory containing pickled data into its `torch::IValue`s.
///
Expand All @@ -83,9 +81,7 @@ TORCH_API IValue unpickle(
const char* data,
size_t size,
TypeResolver type_resolver = nullptr,
c10::ArrayRef<at::Tensor> tensor_table = {},
c10::TypePtr (*type_parser)(const std::string&) =
Unpickler::defaultTypeParser);
c10::ArrayRef<at::Tensor> tensor_table = {});

} // namespace jit
} // namespace torch
2 changes: 1 addition & 1 deletion torch/csrc/jit/serialization/unpickler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -565,7 +565,7 @@ void Unpickler::readGlobal(
if (type_resolver_ == nullptr) {
// If we haven't injected a custom way of retrieving types from
// names, use a barebones type parser.
type = type_parser_(type_str);
type = c10::parseType(type_str);
} else {
type = type_resolver_(type_str).type_;
}
Expand Down
Loading

0 comments on commit 9762aa0

Please sign in to comment.