forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathaot_model_compiler.cc
141 lines (129 loc) · 5.18 KB
/
aot_model_compiler.cc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
#include <sstream>
#include <string>
#include <ATen/core/jit_type.h>
#include <c10/core/ScalarType.h>
#include <torch/csrc/jit/backends/backend.h>
#include <torch/csrc/jit/backends/backend_detail.h>
#include <torch/csrc/jit/backends/backend_preprocess.h>
#include <torch/csrc/jit/mobile/nnc/aot_compiler.h>
#include <torch/csrc/jit/passes/freeze_module.h>
#include <torch/csrc/jit/serialization/export.h>
#include <torch/csrc/jit/serialization/import.h>
#include <torch/csrc/jit/tensorexpr/graph_opt.h>
#include <torch/csrc/jit/tensorexpr/kernel.h>
#include <torch/script.h>
// Command-line flags of the AOT model compiler. They are parsed in main() via
// c10::ParseCommandLineFlags and read by main() and createCompileSpec().
// Note: each multi-part help text below is built from adjacent string
// literals, so every fragment except the last must end with a space.
C10_DEFINE_string(model, "", "The torch script model to optimize.");
C10_DEFINE_string(model_name, "", "The name of the model.");
C10_DEFINE_string(model_version, "", "The version of the model.");
C10_DEFINE_string(
    input_dims,
    "",
    "The dimensions of input TensorCPUs using comma separated numbers. "
    "If multiple inputs needed, use semicolon to separate "
    "the dimension of different tensors.");
C10_DEFINE_string(
    input_types,
    "float",
    "The dtype of input TensorCPUs. "
    "If multiple inputs needed, use semicolon to separate "
    "the dtype of different tensors. "
    "Supported dtypes: float, int64, uint8");
C10_DEFINE_string(
    input_memory_formats,
    "",
    "Input memory format. "
    "If multiple inputs needed, use semicolon to separate. "
    "Supported values: contiguous, channels_last");
C10_DEFINE_string(
    dynamic_dims,
    "",
    "Comma separated dimensions of input tensors that can be dynamic");
C10_DEFINE_string(method_name, "forward", "The name of the method.");
C10_DEFINE_string(
    output_llvm,
    "",
    "Name of the output llvm assembly to be saved.");
C10_DEFINE_string(output_model, "", "Name of the output model to be saved.");
namespace {

// Splits `string` on every occurrence of `separator` and returns the pieces
// in their original order.
//
// When `ignore_empty` is true (the default), empty pieces -- e.g. those
// between two consecutive separators -- are dropped. Splitting is done with
// std::getline, so an empty input produces no pieces and a trailing
// separator never produces a trailing empty piece.
std::vector<std::string> split(
    char separator,
    const std::string& string,
    bool ignore_empty = true) {
  std::vector<std::string> result;
  std::istringstream stream(string);
  // `token` is reused across iterations; std::getline clears it before
  // extracting, so reading into it after a move is well-defined.
  for (std::string token; std::getline(stream, token, separator);) {
    if (ignore_empty && token.empty()) {
      continue;
    }
    result.emplace_back(std::move(token));
  }
  return result;
}

// Builds the compile spec that main() passes to the "nnc" backend.
//
// The result has a single entry keyed by FLAGS_method_name. Its value is a
// string-keyed dict that holds the raw, unparsed flag strings for that
// method under the keys listed in `entries` below.
c10::Dict<c10::IValue, c10::IValue> createCompileSpec() {
  const auto make_string_keyed_dict = [] {
    return c10::Dict<c10::IValue, c10::IValue>(
        c10::StringType::get(), c10::AnyType::get());
  };

  // (spec key, source flag) pairs, inserted in exactly this order.
  struct SpecEntry {
    const char* key;
    const std::string& value;
  };
  const SpecEntry entries[] = {
      {"sizes", FLAGS_input_dims},
      {"types", FLAGS_input_types},
      {"memory_formats", FLAGS_input_memory_formats},
      {"dynamic_sizes", FLAGS_dynamic_dims},
      {"asmfile", FLAGS_output_llvm},
      {"model_name", FLAGS_model_name},
      {"model_version", FLAGS_model_version},
  };

  auto method_spec = make_string_keyed_dict();
  for (const auto& entry : entries) {
    method_spec.insert(entry.key, entry.value);
  }

  auto compile_spec = make_string_keyed_dict();
  compile_spec.insert(FLAGS_method_name, method_spec);
  return compile_spec;
}

} // namespace
namespace {

// Returns `path` with the extension of its final component removed,
// e.g. "./models/mobilenet.pt" -> "./models/mobilenet".
//
// Only a dot inside the last path component counts as an extension
// separator. Dots in directory names ("./", "../", "dir.v1/") are ignored,
// and so is a leading dot of the file name itself (".hidden"). If there is
// no extension, `path` is returned unchanged.
std::string stripModelExtension(const std::string& path) {
  const auto last_separator = path.find_last_of("/\\");
  const std::string::size_type name_start =
      last_separator == std::string::npos ? 0 : last_separator + 1;
  const auto dot = path.rfind('.');
  if (dot == std::string::npos || dot <= name_start) {
    return path;
  }
  return path.substr(0, dot);
}

} // namespace

// Entry point: loads a TorchScript model, freezes it, lowers it through the
// "nnc" backend using the spec built from the command-line flags, and saves
// the result for the mobile runtime.
//
// Returns 1 if flag parsing fails. Missing required flags or inconsistent
// per-input flag counts are rejected via CAFFE_ENFORCE.
int main(int argc, char** argv) {
  c10::SetUsageMessage(
      "Run NNC AOT compiler for pytorch model. Example usage:\n"
      "build/bin/aot_model_compiler"
      " --model=<model file>"
      " --model_name=<model name>"
      " --model_version=<model version>"
      " --input_dims=<input dimensions like '1,3,224,224;2,2'>"
      " --input_types=<input dtypes like 'float;float'>"
      " [--input_memory_formats=<input memory formats like 'channels_last;contiguous'>]"
      " [--dynamic_dims=<comma separated dynamic dimensions>]"
      " [--method_name=<method name>]"
      " [--output_llvm=<llvm assembly output file path>]"
      " [--output_model=<output model file path>]");
  if (!c10::ParseCommandLineFlags(&argc, &argv)) {
    std::cerr << "Failed to parse command line flags!" << std::endl;
    std::cout << c10::UsageMessage() << std::endl;
    return 1;
  }
  CAFFE_ENFORCE(!FLAGS_model.empty(), c10::UsageMessage());
  CAFFE_ENFORCE(!FLAGS_model_name.empty(), c10::UsageMessage());
  CAFFE_ENFORCE(!FLAGS_model_version.empty(), c10::UsageMessage());
  CAFFE_ENFORCE(!FLAGS_input_dims.empty(), c10::UsageMessage());

  // Every input needs one dims entry and one dtype entry.
  const auto dims_size = split(';', FLAGS_input_dims).size();
  CAFFE_ENFORCE(
      dims_size == split(';', FLAGS_input_types).size(),
      "Number of input_dims and input_types should be the same");
  // Memory formats are optional: none at all, or one per input.
  const auto mem_formats_size = split(';', FLAGS_input_memory_formats).size();
  CAFFE_ENFORCE(
      mem_formats_size == 0 || mem_formats_size == dims_size,
      "Number of input_memory_formats should be 0 (default contiguous) or the same as number of input_dims");

  // Default output paths sit next to the input model, with its extension
  // replaced. Only the extension of the file name is stripped, so paths like
  // "./model.pt" or "dir.v1/model.pt" keep their directory part.
  const std::string model_stem = stripModelExtension(FLAGS_model);
  if (FLAGS_output_llvm.empty()) {
    // Written back into the flag because createCompileSpec() reads
    // FLAGS_output_llvm when building the spec.
    FLAGS_output_llvm = model_stem + ".compiled.ll";
  }
  std::string output_model_name = FLAGS_output_model;
  if (output_model_name.empty()) {
    output_model_name = model_stem + ".compiled.pt";
  }

  auto m = torch::jit::load(FLAGS_model);
  m.eval();
  auto frozen_m = torch::jit::freeze_module(m.clone());
  auto compile_spec = createCompileSpec();
  auto any_dict_ty =
      c10::DictType::create(c10::StringType::get(), c10::AnyType::get());
  auto compiled_module = torch::jit::detail::codegen_backend_module(
      "nnc", frozen_m, compile_spec, any_dict_ty);
  compiled_module._save_for_mobile(output_model_name);
  std::cout << "The compiled model was saved to " << output_model_name
            << std::endl;
  return 0;
}