Skip to content

Commit

Permalink
Add option to load historic operators in IR when the operator is depr…
Browse files Browse the repository at this point in the history
…ecated (pytorch#71148)

Summary: Pull Request resolved: pytorch#71148

Test Plan: Imported from OSS

Reviewed By: gmagogsfm

Differential Revision: D33521300

Pulled By: tugsbayasgalan

fbshipit-source-id: a0607dba5e7233590384326537017eb0b18da419
  • Loading branch information
tugsbayasgalan authored and facebook-github-bot committed Jan 12, 2022
1 parent 8f4cec2 commit 7095188
Show file tree
Hide file tree
Showing 6 changed files with 111 additions and 6 deletions.
37 changes: 37 additions & 0 deletions test/cpp/jit/test_upgrader_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

#include <test/cpp/jit/test_utils.h>

#include <vector>

namespace torch {
namespace jit {

Expand Down Expand Up @@ -58,5 +60,40 @@ TEST(UpgraderUtils, FindIfOpIsCurrent) {
test_only_remove_entry("foo");
}

TEST(UpgraderUtils, CanLoadHistoricOp) {
std::vector<UpgraderEntry> dummy_entry = {
{4, "foo__0_3", "foo.bar()"},
{8, "foo__4_7", "foo.foo()"},
};

std::vector<std::string> schemas = {"foo.bar()", "foo.foo()"};

// symbol based look up
test_only_add_entry("old_op_not_exist.first", dummy_entry[0]);
test_only_add_entry("old_op_not_exist.second", dummy_entry[1]);

auto oldSchemas = loadPossibleHistoricOps("old_op_not_exist", 2);
EXPECT_EQ(oldSchemas.size(), 2);
for (const auto& entry : oldSchemas) {
EXPECT_TRUE(
std::find(schemas.begin(), schemas.end(), entry) != schemas.end());
}

auto oldSchemasWithCurrentVersion =
loadPossibleHistoricOps("old_op_not_exist", 9);
EXPECT_EQ(oldSchemasWithCurrentVersion.size(), 0);

test_only_remove_entry("old_op_not_exist.first");
test_only_remove_entry("old_op_not_exist.first");

// it is ok to have old schemas without overload
test_only_add_entry("old_op_not_exist_no_overload", dummy_entry[0]);
auto oldSchemasNoOverload =
loadPossibleHistoricOps("old_op_not_exist_no_overload", 2);
EXPECT_EQ(oldSchemasNoOverload.size(), 1);
EXPECT_EQ(oldSchemasNoOverload[0], "foo.bar()");
test_only_remove_entry("old_op_not_exist_no_overload");
}

} // namespace jit
} // namespace torch
16 changes: 15 additions & 1 deletion torch/csrc/jit/frontend/schema_matching.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -615,6 +615,8 @@ static Value* emitBuiltinNode(
if (!version.has_value() ||
isOpSymbolCurrent(matched_schema.schema_name, version.value())) {
n->getOperation();
} else {
n->setHistoricSchemaName(matched_schema.schema_name);
}

return packOutputs(graph, n->outputs(), matched_schema.return_field_names);
Expand Down Expand Up @@ -678,6 +680,18 @@ Value* emitBuiltinCall(
schemas.push_back(&op->schema());
}

// we might have seen old historic
// ops that are deprecated
if (variants.empty()) {
auto oldSchemas =
loadPossibleHistoricOps(name.toQualString(), graph_version);
upgrader_schemas.reserve(oldSchemas.size());
for (const auto& old_schema_entry : oldSchemas) {
FunctionSchema old_schema = parseSchema(old_schema_entry);
upgrader_schemas.emplace_back(old_schema);
}
}

// TODO (tugsuu): make sure this is optimized later
for (const auto& schema : upgrader_schemas) {
schemas.push_back(&schema);
Expand Down Expand Up @@ -710,7 +724,7 @@ Value* emitBuiltinCall(

auto matched = matchSchemas(schemas, loc, graph, args, kwargs, self);

if (matched.first < variants.size()) {
if (matched.first < variants.size() + upgrader_schemas.size()) {
return emitBuiltinNode(matched.second, loc, graph, name, graph_version);
} else {
auto& fn = *builtin_functions[matched.first - variants.size()];
Expand Down
14 changes: 14 additions & 0 deletions torch/csrc/jit/ir/ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -338,6 +338,12 @@ struct TORCH_API Node {
topo_position_t topo_position_ = 0;
// a managing wrapper for Python to allow invalidation
std::shared_ptr<Wrap<Node>> wrap_;
// Stores the full schema name, if the operator is historic
// When the operator is deprecated or the name of the operator
// is changed, we need to rely on this name
// to retrieve old schemas to successfully apply upgraders
// for this operator.
c10::optional<std::string> historic_schema_name_ = c10::nullopt;

protected:
Node(Graph* graph_, NodeKind kind_); // defined after graph
Expand All @@ -362,6 +368,14 @@ struct TORCH_API Node {
return wrap_;
}

const c10::optional<std::string> getHistoricSchemaName() {
return historic_schema_name_;
}

void setHistoricSchemaName(const std::string& name) {
historic_schema_name_ = name;
}

Node*& next() {
return next_in_graph[kNextDirection];
}
Expand Down
24 changes: 24 additions & 0 deletions torch/csrc/jit/operator_upgraders/utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,5 +49,29 @@ bool isOpSymbolCurrent(const std::string& name, size_t current_version) {
return true;
}

std::vector<std::string> loadPossibleHistoricOps(
const std::string& name,
c10::optional<size_t> version) {
std::vector<std::string> possibleSchemas;

if (!version.has_value()) {
return possibleSchemas;
}

for (const auto& entry : get_operator_version_map()) {
auto old_symbol_name = entry.first;
// strip off the overload name, if exist
auto base_name = old_symbol_name.substr(0, old_symbol_name.find('.'));
if (base_name == name) {
auto possibleUpgrader = findUpgrader(entry.second, version.value());
if (possibleUpgrader.has_value()) {
possibleSchemas.push_back(possibleUpgrader.value().old_schema);
}
}
}

return possibleSchemas;
}

} // namespace jit
} // namespace torch
8 changes: 8 additions & 0 deletions torch/csrc/jit/operator_upgraders/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,5 +30,13 @@ TORCH_API bool isOpSymbolCurrent(
const std::string& name,
size_t current_version);

// Returns the possible old schemas for the operator that
// doesn't exist anymore. This can be true for deprecated
// operators. Since name is always a symbol name, there
// can be multiple schemas for different overloads.
TORCH_API std::vector<std::string> loadPossibleHistoricOps(
const std::string& name,
c10::optional<size_t> version);

} // namespace jit
} // namespace torch
18 changes: 13 additions & 5 deletions torch/csrc/jit/passes/replacement_of_old_operators.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,20 +28,28 @@ struct OldOpsReplacerWithUpgraders {
DepthFirstGraphNodeIterator graph_it(graph_);
Node* node = graph_it.next();
while (node) {
if (auto schema = node->maybeSchema()) {
auto schema_name = getFullSchemaName(*schema);
// load the schema name for this op
c10::optional<std::string> schema_name = c10::nullopt;
if (auto op_schema = node->maybeSchema()) {
schema_name = getFullSchemaName(*op_schema);
} else {
schema_name = node->getHistoricSchemaName();
}

if (schema_name.has_value()) {
// this implies there was a version bump because of this operator
auto version_entry = get_operator_version_map().find(schema_name);
auto version_entry =
get_operator_version_map().find(schema_name.value());
if (version_entry != get_operator_version_map().end()) {
const auto& entry = version_entry->second;
auto upgrader_entry =
findUpgrader(version_entry->second, current_version);
if (!upgrader_entry.has_value()) {
if (!isOpSymbolCurrent(schema_name, current_version)) {
if (!isOpSymbolCurrent(schema_name.value(), current_version)) {
TORCH_INTERNAL_ASSERT(
false,
"Upgrader must be present for ",
schema_name,
schema_name.value(),
". The upgrader might have deprecated");
}
node = graph_it.next();
Expand Down

0 comments on commit 7095188

Please sign in to comment.