Add function to load saved model for tflite mlir converter.
PiperOrigin-RevId: 296139934
Change-Id: I1c608c2971d81e5efa38925ee9fe4b80f437726a
renjie-liu authored and tensorflower-gardener committed Feb 20, 2020
1 parent d027ba1 commit f8b2a05
Showing 6 changed files with 96 additions and 6 deletions.
1 change: 1 addition & 0 deletions tensorflow/compiler/mlir/lite/BUILD
@@ -651,6 +651,7 @@ tf_cc_binary(
"//tensorflow/compiler/mlir:init_mlir",
"//tensorflow/compiler/mlir/tensorflow:translate_cl_options",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/platform:errors",
"//tensorflow/lite:framework",
"//tensorflow/lite/schema:schema_fbs",
"//tensorflow/stream_executor/lib",
25 changes: 19 additions & 6 deletions tensorflow/compiler/mlir/lite/tf_tfl_translate.cc
@@ -36,6 +36,7 @@ limitations under the License.
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate_cl.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/lite/model.h"
#include "tensorflow/lite/schema/schema_generated.h"
#include "tensorflow/stream_executor/lib/statusor.h"
@@ -132,12 +133,24 @@ int main(int argc, char **argv) {
llvm::SourceMgr source_mgr;
mlir::SourceMgrDiagnosticHandler sourceMgrHandler(source_mgr, &context);

StatusOr<mlir::OwningModuleRef> module =
tensorflow::LoadFromGraphdefOrMlirSource(
input_file_name, input_mlir, use_splatted_constant, custom_opdefs,
debug_info_file, input_arrays, input_dtypes, input_shapes,
output_arrays,
/*prune_unused_nodes=*/true, &source_mgr, &context);
StatusOr<mlir::OwningModuleRef> module;

// TODO(b/147435528): We need to test the e2e behavior once the graph freezing
// inside mlir is done.
if (import_saved_model || import_saved_model_v1) {
if (input_mlir)
module = tensorflow::errors::InvalidArgument(
"Importing saved model should not have input_mlir set");
else
module = tensorflow::ImportSavedModel(
import_saved_model, import_saved_model_v1, input_file_name,
saved_model_tags, saved_model_exported_names, &context);
} else {
module = tensorflow::LoadFromGraphdefOrMlirSource(
input_file_name, input_mlir, use_splatted_constant, custom_opdefs,
debug_info_file, input_arrays, input_dtypes, input_shapes,
output_arrays,
/*prune_unused_nodes=*/true, &source_mgr, &context);
}

// If errors occur, the library call in the above already logged the error
// message. So we can just return here.
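To make the new dispatch concrete, here is a minimal, self-contained sketch of the same StatusOr-based pattern. It uses absl::StatusOr in place of TensorFlow's StatusOr alias, and the flag values, helper bodies, and exit code are illustrative assumptions, not part of this commit:

#include <iostream>
#include <string>

#include "absl/status/status.h"
#include "absl/status/statusor.h"

// Stand-ins for tensorflow::ImportSavedModel and
// tensorflow::LoadFromGraphdefOrMlirSource.
absl::StatusOr<std::string> ImportSavedModel(const std::string& dir) {
  if (dir.empty()) return absl::InvalidArgumentError("empty saved model dir");
  return "module<saved_model:" + dir + ">";
}
absl::StatusOr<std::string> LoadFromGraphdefOrMlirSource(const std::string& f) {
  return "module<graphdef:" + f + ">";
}

int main() {
  const bool import_saved_model = true;  // set by --savedmodel-to-mlir
  const bool input_mlir = false;         // set by the input_mlir option
  absl::StatusOr<std::string> module;
  if (import_saved_model) {
    if (input_mlir)
      module = absl::InvalidArgumentError(
          "Importing saved model should not have input_mlir set");
    else
      module = ImportSavedModel("/tmp/saved_model");
  } else {
    module = LoadFromGraphdefOrMlirSource("model.pb");
  }
  // Mirrors the "errors were already logged, just return" comment above.
  if (!module.ok()) {
    std::cerr << module.status() << "\n";
    return 1;
  }
  std::cout << *module << "\n";
  return 0;
}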
27 changes: 27 additions & 0 deletions tensorflow/compiler/mlir/lite/tf_tfl_translate_cl.cc
@@ -22,6 +22,33 @@ using llvm::cl::opt
opt<std::string> input_file_name(llvm::cl::Positional,
llvm::cl::desc("<input file>"),
llvm::cl::init("-"));

// NOLINTNEXTLINE
opt<bool> import_saved_model(
"savedmodel-to-mlir",
llvm::cl::desc("Import a saved model to its MLIR representation"),
llvm::cl::value_desc("dir"));

// NOLINTNEXTLINE
opt<bool> import_saved_model_v1(
"savedmodel-v1-to-mlir",
llvm::cl::desc("Import a saved model V1 to its MLIR representation"),
llvm::cl::value_desc("dir"));

// NOLINTNEXTLINE
opt<std::string> saved_model_tags(
"tf-savedmodel-tags",
llvm::cl::desc("Tags used to indicate which MetaGraphDef to import, "
"separated by ','"),
llvm::cl::init("serve"));

// NOLINTNEXTLINE
opt<std::string> saved_model_exported_names(
"tf-savedmodel-exported-names",
llvm::cl::desc("Names to export from SavedModel, separated by ','. Empty "
"(the default) means export all."),
llvm::cl::init(""));

// NOLINTNEXTLINE
opt<std::string> output_file_name("o", llvm::cl::desc("<output file>"),
llvm::cl::value_desc("filename"),
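For reference, a minimal sketch of how llvm::cl options like the ones above are registered and parsed in a standalone tool; the option names are copied from the diff, but the demo binary itself is hypothetical:

#include <iostream>
#include <string>

#include "llvm/Support/CommandLine.h"

// NOLINTNEXTLINE
static llvm::cl::opt<bool> import_saved_model(
    "savedmodel-to-mlir",
    llvm::cl::desc("Import a saved model to its MLIR representation"));

// NOLINTNEXTLINE
static llvm::cl::opt<std::string> saved_model_tags(
    "tf-savedmodel-tags",
    llvm::cl::desc("Tags used to indicate which MetaGraphDef to import, "
                   "separated by ','"),
    llvm::cl::init("serve"));

int main(int argc, char** argv) {
  // Wires argv to every cl::opt registered above.
  llvm::cl::ParseCommandLineOptions(argc, argv, "flag demo\n");
  std::cout << "import_saved_model=" << import_saved_model
            << " tags=" << saved_model_tags.getValue() << "\n";
  return 0;
}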
6 changes: 6 additions & 0 deletions tensorflow/compiler/mlir/lite/tf_tfl_translate_cl.h
@@ -39,4 +39,10 @@ extern llvm::cl::opt<bool> inline_functions;
extern llvm::cl::list<std::string> custom_opdefs;
extern llvm::cl::opt<bool> emit_quant_adaptor_ops;
extern llvm::cl::opt<std::string> quant_stats_file_name;

// Import saved model.
extern llvm::cl::opt<bool> import_saved_model;
extern llvm::cl::opt<bool> import_saved_model_v1;
extern llvm::cl::opt<std::string> saved_model_tags;
extern llvm::cl::opt<std::string> saved_model_exported_names;
#endif // TENSORFLOW_COMPILER_MLIR_LITE_TF_TFL_TRANSLATE_CL_H_
37 changes: 37 additions & 0 deletions tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.cc
@@ -15,6 +15,10 @@ limitations under the License.

#include "tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.h"

#include <string>
#include <unordered_set>

#include "llvm/Support/raw_ostream.h"
#include "mlir/IR/Attributes.h" // TF:llvm-project
#include "mlir/IR/Module.h" // TF:llvm-project
#include "mlir/Parser.h" // TF:llvm-project
@@ -155,4 +159,37 @@ Status ConvertTFExecutorToTFLOrFlatbuffer(
return Status::OK();
}

StatusOr<mlir::OwningModuleRef> ImportSavedModel(
bool import_saved_model, bool import_saved_model_v1,
const std::string& input_filename, const std::string& saved_model_tags,
const std::string& saved_model_exported_names, mlir::MLIRContext* context) {
if (import_saved_model) {
std::unordered_set<std::string> tags =
absl::StrSplit(saved_model_tags, ',');
std::vector<std::string> exported_names =
absl::StrSplit(saved_model_exported_names, ',', absl::SkipEmpty());

auto module = tensorflow::SavedModelToMlirImport(
input_filename, tags, absl::Span<std::string>(exported_names), context);
if (!module)
return tensorflow::errors::InvalidArgument("failed to open input file");

return module;
} else if (import_saved_model_v1) {
std::unordered_set<std::string> tags =
absl::StrSplit(saved_model_tags, ',');

auto module =
tensorflow::SavedModelV1ToMlirImport(input_filename, tags, context);

if (!module)
return tensorflow::errors::InvalidArgument("failed to open input file");

return module;
} else {
return tensorflow::errors::InvalidArgument(
"Should be either saved model v1 or v2");
}
}

} // namespace tensorflow
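The tag and exported-name parsing above relies on absl::StrSplit materializing directly into different container types; a small self-contained sketch with made-up inputs:

#include <iostream>
#include <string>
#include <unordered_set>
#include <vector>

#include "absl/strings/str_split.h"

int main() {
  // Comma-separated tags collect into a set, as in ImportSavedModel.
  std::unordered_set<std::string> tags = absl::StrSplit("serve,gpu", ',');
  std::cout << "tags: " << tags.size() << "\n";  // 2

  // absl::SkipEmpty drops empty tokens, so the default "" yields an empty
  // list, which the importer interprets as "export everything".
  std::vector<std::string> exported_names =
      absl::StrSplit("", ',', absl::SkipEmpty());
  std::cout << "exported names: " << exported_names.size() << "\n";  // 0
  return 0;
}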
6 changes: 6 additions & 0 deletions tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.h
@@ -40,6 +40,12 @@ LoadFromGraphdefOrMlirSource(
absl::string_view output_arrays, bool prune_unused_nodes,
llvm::SourceMgr* source_mgr, mlir::MLIRContext* context);

// Loads a saved model (either v1 or v2) into MLIR.
stream_executor::port::StatusOr<mlir::OwningModuleRef> ImportSavedModel(
bool import_saved_model, bool import_saved_model_v1,
const std::string& input_filename, const std::string& saved_model_tags,
const std::string& saved_model_exported_names, mlir::MLIRContext* context);

// Taking an MLIR module in TF executor dialect and a set of parameters,
// applies a set of passes to convert the module to TF Lite dialect and
// serializes the result to a string. Depending on an attribute in the module
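A hypothetical call site for the new entry point, mirroring the one added to tf_tfl_translate.cc above; the path, flag values, and the ConsumeValueOrDie accessor are assumptions about the surrounding API of this era, not part of the commit:

// Sketch only: assumes the TensorFlow/MLIR headers included by this file.
tensorflow::Status ConvertSavedModelExample(mlir::MLIRContext* context) {
  auto module_or = tensorflow::ImportSavedModel(
      /*import_saved_model=*/true, /*import_saved_model_v1=*/false,
      /*input_filename=*/"/path/to/saved_model",
      /*saved_model_tags=*/"serve",
      /*saved_model_exported_names=*/"", context);
  // InvalidArgument is returned if the model cannot be opened or if
  // neither v1 nor v2 import was requested.
  if (!module_or.ok()) return module_or.status();
  mlir::OwningModuleRef module = module_or.ConsumeValueOrDie();
  // ... hand the module to the TF Lite conversion passes ...
  return tensorflow::Status::OK();
}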
