Skip to content

Commit 81b4320

Browse files
Elias Ellisonfacebook-github-bot
authored andcommitted
Refactor Type Parser b/w Schemas & IRParser into a type common parser (#17383)
Summary: Creates a new shared type parser to be shared between the IR parser and the Schema Parser. Also adds parsing of CompleteTensorType and DimensionedTensorType, and feature-gates that for the IRParser. Renames the existing type_parser for python annotations, python_type_parser, and names the new one jit_type_parser. Pull Request resolved: #17383 Differential Revision: D14186438 Pulled By: eellison fbshipit-source-id: bbd5e337917d8862c7c6fa0a0006efa101c76afe
1 parent b0c1857 commit 81b4320

File tree

11 files changed

+346
-179
lines changed

11 files changed

+346
-179
lines changed

test/cpp/jit/test_irparser.h

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,67 @@ graph(%0 : Tensor,
150150
return (%7)
151151
)IR");
152152
}
153+
154+
{
155+
checkRoundtrip(
156+
R"IR(
157+
graph(%0 : Tensor,
158+
%1 : Tensor,
159+
%2 : Tensor):
160+
%3 : int? = prim::Constant()
161+
return (%3)
162+
)IR");
163+
}
164+
165+
{
166+
checkRoundtrip(
167+
R"IR(
168+
graph(%0 : Tensor,
169+
%1 : Tensor,
170+
%2 : Tensor):
171+
%3 : Float(*, *, *) = prim::Constant()
172+
return (%3)
173+
)IR");
174+
}
175+
176+
{
177+
checkRoundtrip(
178+
R"IR(
179+
graph(%0 : Tensor,
180+
%1 : Tensor,
181+
%2 : Tensor):
182+
%3 : Long() = prim::Constant()
183+
return (%3)
184+
)IR");
185+
}
186+
187+
{
188+
checkRoundtrip(
189+
R"IR(
190+
graph(%0 : Tensor,
191+
%1 : Tensor,
192+
%2 : Tensor):
193+
%3 : Double(4, 4, 5) = prim::Constant()
194+
return (%3)
195+
)IR");
196+
}
197+
198+
{
199+
bool error_thrown = false;
200+
try {
201+
checkRoundtrip(
202+
R"IR(
203+
graph(%0 : Tensor,
204+
%1 : Tensor,
205+
%2 : Tensor):
206+
%3 : Double(4!, 4, 5) = prim::Constant()
207+
return (%3)
208+
)IR");
209+
} catch (const std::exception& error) {
210+
error_thrown = true;
211+
}
212+
AT_ASSERT(error_thrown);
213+
}
153214
}
154215
} // namespace jit
155216
} // namespace torch

tools/build_variables.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,8 @@
9494
"torch/csrc/jit/script/compiler.cpp",
9595
"torch/csrc/jit/script/edit_distance.cpp",
9696
"torch/csrc/jit/script/final_returns.cpp",
97-
"torch/csrc/jit/script/type_parser.cpp",
97+
"torch/csrc/jit/script/schema_type_parser.cpp",
98+
"torch/csrc/jit/script/script_type_parser.cpp",
9899
"torch/csrc/jit/script/sugared_value.cpp",
99100
"torch/csrc/jit/script/schema_matching.cpp",
100101
"torch/csrc/jit/script/parser.cpp",

torch/CMakeLists.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -174,7 +174,8 @@ set(TORCH_SRCS
174174
${TORCH_SRC_DIR}/csrc/jit/testing/file_check.cpp
175175
${TORCH_SRC_DIR}/csrc/jit/script/final_returns.cpp
176176
${TORCH_SRC_DIR}/csrc/jit/script/schema_matching.cpp
177-
${TORCH_SRC_DIR}/csrc/jit/script/type_parser.cpp
177+
${TORCH_SRC_DIR}/csrc/jit/script/schema_type_parser.cpp
178+
${TORCH_SRC_DIR}/csrc/jit/script/script_type_parser.cpp
178179
${TORCH_SRC_DIR}/csrc/jit/script/sugared_value.cpp
179180
${TORCH_SRC_DIR}/csrc/jit/script/parser.cpp
180181
${TORCH_SRC_DIR}/csrc/jit/script/builtin_functions.cpp

torch/csrc/jit/irparser.cpp

Lines changed: 13 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
#include <torch/csrc/jit/ir.h>
33
#include <torch/csrc/jit/script/lexer.h>
44
#include <torch/csrc/jit/script/parse_string_literal.h>
5+
#include <torch/csrc/jit/script/schema_type_parser.h>
56

67
#include <string>
78
#include <vector>
@@ -16,7 +17,9 @@ struct ParsedLiteral;
1617
class IRParser {
1718
friend void parseIR(const std::string& str, torch::jit::Graph* graph);
1819
IRParser(const std::string& str, torch::jit::Graph* graph)
19-
: L(str), g(graph) {}
20+
: L(str),
21+
g(graph),
22+
type_parser(L, /*parse_complete_tensor_types*/ true) {}
2023

2124
std::string parseVar();
2225
VarWithType parseVarWithType();
@@ -48,6 +51,7 @@ class IRParser {
4851
torch::jit::script::Lexer L;
4952
torch::jit::Graph* g = nullptr;
5053
std::unordered_map<std::string, Value*> vmap;
54+
SchemaTypeParser type_parser;
5155
};
5256

5357
struct ParsedLiteral {
@@ -66,31 +70,14 @@ struct ParsedLiteral {
6670
struct VarWithType {
6771
VarWithType() = default;
6872
std::string name;
69-
std::string type;
73+
TypePtr type;
7074
};
7175

7276
void parseIR(const std::string& str, torch::jit::Graph* graph) {
7377
torch::jit::script::IRParser p(str, graph);
7478
p.parse();
7579
}
7680

77-
TypePtr parseType(const std::string& s) {
78-
if (s == "Tensor") {
79-
return TensorType::get();
80-
}
81-
if (s == "int") {
82-
return IntType::get();
83-
}
84-
if (s == "float") {
85-
return FloatType::get();
86-
}
87-
if (s == "string") {
88-
return StringType::get();
89-
}
90-
// TODO: Support other types.
91-
AT_ASSERTM(false, "Type not supported by parser:", s);
92-
}
93-
9481
VarWithType IRParser::parseVarWithType() {
9582
L.expect('%');
9683
VarWithType r;
@@ -99,9 +86,11 @@ VarWithType IRParser::parseVarWithType() {
9986
} else {
10087
r.name = L.expect(TK_NUMBER).text();
10188
}
102-
r.type = "Tensor";
89+
r.type = TensorType::get();
10390
if (L.nextIf(':')) {
104-
r.type = L.expect(TK_IDENT).text();
91+
auto type_alias = type_parser.parseType();
92+
AT_ASSERTM(!type_alias.second, "Parsing IR with Alias Info not handled");
93+
r.type = type_alias.first;
10594
}
10695
return r;
10796
}
@@ -268,7 +257,7 @@ void IRParser::parseBlockInputs(Block* b) {
268257
// If the name isn't valid, don't use it
269258
std::string uniq_name = Value::isValidName(v.name) ? v.name : "";
270259
vmap[v.name] = b->addInput(uniq_name);
271-
vmap[v.name]->setType(parseType(v.type));
260+
vmap[v.name]->setType(v.type);
272261
});
273262
}
274263

@@ -345,7 +334,7 @@ void IRParser::parseOperator(Block* b) {
345334
int idx = 0;
346335
for (const VarWithType& v : outs) {
347336
vmap[v.name] = n->outputs()[idx++];
348-
vmap[v.name]->setType(parseType(v.type));
337+
vmap[v.name]->setType(v.type);
349338
}
350339

351340
// Insert the new node into block B.
@@ -364,7 +353,7 @@ void IRParser::parseGraphInputs() {
364353
// If the name isn't valid, don't use it
365354
std::string uniq_name = Value::isValidName(v.name) ? v.name : "";
366355
vmap[v.name] = g->addInput(uniq_name);
367-
vmap[v.name]->setType(parseType(v.type));
356+
vmap[v.name]->setType(v.type);
368357
});
369358
}
370359

@@ -433,7 +422,6 @@ void IRParser::parseList(
433422
L.expect(end);
434423
}
435424
}
436-
437425
} // namespace script
438426
} // namespace jit
439427
} // namespace torch

0 commit comments

Comments
 (0)