diff --git a/codegen/BUILD b/codegen/BUILD
index 5473361484f..4596d18436d 100644
--- a/codegen/BUILD
+++ b/codegen/BUILD
@@ -6,6 +6,17 @@ package(
     licenses = ["notice"],
 )
 
+py_library(
+    name = "graph",
+    srcs = [
+        "graph.py",
+    ],
+    deps = [
+        "//tensorflow/lite/python:schema_py",
+        "//tensorflow/lite/tools:visualize",
+    ],
+)
+
 py_library(
     name = "inference_generator",
     srcs = [
@@ -16,6 +27,7 @@ py_library(
         "templates/inference.h.mako",
     ],
     deps = [
+        ":graph",
         requirement("mako"),
     ],
 )
@@ -28,7 +40,9 @@ py_binary(
     python_version = "PY3",
     srcs_version = "PY3",
     deps = [
+        ":graph",
         ":inference_generator",
+        "//tensorflow/lite/tools:flatbuffer_utils",
        "@absl_py//absl:app",
         "@absl_py//absl/flags",
     ],
diff --git a/codegen/code_generator.py b/codegen/code_generator.py
index da246018ed1..95d05952322 100644
--- a/codegen/code_generator.py
+++ b/codegen/code_generator.py
@@ -21,6 +21,8 @@ from collections.abc import Sequence
 
 from tflite_micro.codegen import inference_generator
+from tflite_micro.codegen import graph
+from tflite_micro.tensorflow.lite.tools import flatbuffer_utils
 
 # Usage information:
 # Default:
@@ -51,7 +53,10 @@ def main(argv: Sequence[str]) -> None:
   output_name = _OUTPUT_NAME.value or os.path.splitext(
       os.path.basename(_MODEL_PATH.value))[0]
 
-  inference_generator.generate(output_dir, output_name)
+  model = flatbuffer_utils.read_model(_MODEL_PATH.value)
+
+  inference_generator.generate(output_dir, output_name,
+                               graph.OpCodeTable([model]), graph.Graph(model))
 
 
 if __name__ == "__main__":
diff --git a/codegen/examples/hello_world/hello_world.cc b/codegen/examples/hello_world/hello_world.cc
index e597c90fbd2..81b650f0842 100644
--- a/codegen/examples/hello_world/hello_world.cc
+++ b/codegen/examples/hello_world/hello_world.cc
@@ -14,9 +14,13 @@ limitations under the License.
 ==============================================================================*/
 
 #include "hello_world_model.h"
+#include "tensorflow/lite/c/c_api_types.h"
 
 int main(int argc, char** argv) {
-  hello_world_model::Invoke();
+  TfLiteStatus status = hello_world_model::Invoke();
+  if (status != kTfLiteOk) {
+    return -1;
+  }
   return 0;
 }
diff --git a/codegen/examples/hello_world/hello_world_model.cc b/codegen/examples/hello_world/hello_world_model.cc
index 50f37695748..c92c52d63f3 100644
--- a/codegen/examples/hello_world/hello_world_model.cc
+++ b/codegen/examples/hello_world/hello_world_model.cc
@@ -17,8 +17,32 @@ limitations under the License.
 
 #include "hello_world_model.h"
 
-namespace hello_world_model {
+#include "tensorflow/lite/c/c_api_types.h"
+#include "tensorflow/lite/c/common.h"
+#include "tensorflow/lite/micro/kernels/micro_ops.h"
+#include "tensorflow/lite/micro/micro_common.h"
 
-void Invoke() {}
+namespace hello_world_model {
+namespace {
+// TODO(rjascani): We should probably split out the OpTable to a separate file
+// once we start generating for multiple models.
+enum OpCode { kFullyConnected, kCount };
+
+TFLMInferenceRegistration op_table[OpCode::kCount] = {
+    tflite::RegisterInference_FULLY_CONNECTED(),
+};
+}  // namespace
+
+TfLiteStatus InvokeSubgraph0() {
+  TF_LITE_ENSURE_OK(nullptr,
+                    op_table[OpCode::kFullyConnected].invoke(nullptr, nullptr));
+  TF_LITE_ENSURE_OK(nullptr,
+                    op_table[OpCode::kFullyConnected].invoke(nullptr, nullptr));
+  TF_LITE_ENSURE_OK(nullptr,
+                    op_table[OpCode::kFullyConnected].invoke(nullptr, nullptr));
+  return kTfLiteOk;
+}
+
+TfLiteStatus Invoke() { return InvokeSubgraph0(); }
 
 }  // namespace hello_world_model
diff --git a/codegen/examples/hello_world/hello_world_model.h b/codegen/examples/hello_world/hello_world_model.h
index 241666e002e..cd54cbd92cc 100644
--- a/codegen/examples/hello_world/hello_world_model.h
+++ b/codegen/examples/hello_world/hello_world_model.h
@@ -17,8 +17,10 @@ limitations under the License.
 
 #pragma once
 
+#include "tensorflow/lite/c/c_api_types.h"
+
 namespace hello_world_model {
 
-void Invoke();
+TfLiteStatus Invoke();
 
 }  // namespace hello_world_model
diff --git a/codegen/graph.py b/codegen/graph.py
new file mode 100644
index 00000000000..1afac73abba
--- /dev/null
+++ b/codegen/graph.py
@@ -0,0 +1,94 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+""" Provides object representation for the model that is conducive to code
+    generation using templates. """
""" + +from typing import List, Sequence + +from tflite_micro.tensorflow.lite.python import schema_py_generated as schema_fb +from tflite_micro.tensorflow.lite.tools import visualize + + +def _to_pascal_case(s: str) -> str: + return s.title().replace('_', '') + + +class OpCode(object): + + def __init__(self, op_code: schema_fb.OperatorCodeT): + self._op_code: schema_fb.OperatorCodeT = op_code + + def name(self) -> str: + if self._op_code.customCode: + return self._op_code.customCode + return visualize.BuiltinCodeToName(self._op_code.builtinCode) + + def register_function(self) -> str: + return "tflite::RegisterInference_{}".format(self.name()) + + def enum_name(self) -> str: + return "k{}".format(_to_pascal_case(self.name())) + + +class Operator(object): + + def __init__(self, model: schema_fb.ModelT, operator: schema_fb.OperatorT): + self._operator: schema_fb.OperatorT = operator + self._op_code: OpCode = OpCode( + model.operatorCodes[self._operator.opcodeIndex]) + + @property + def op_code(self) -> OpCode: + return self._op_code + + +class Subgraph(object): + + def __init__(self, model: schema_fb.ModelT, subgraph: schema_fb.SubGraphT): + self._subgraph: schema_fb.SubGraphT = subgraph + self._operators: List[Operator] = [ + Operator(model, operator) for operator in subgraph.operators + ] + + @property + def operators(self) -> Sequence[Operator]: + return self._operators + + +class Graph(object): + + def __init__(self, model: schema_fb.ModelT): + self._subgraphs: List[SubGraph] = [ + Subgraph(model, subgraph) for subgraph in model.subgraphs + ] + + @property + def subgraphs(self) -> Sequence[Subgraph]: + return self._subgraphs + + +class OpCodeTable(object): + + def __init__(self, models: Sequence[schema_fb.ModelT]): + op_codes = [] + for model in models: + for op_code in model.operatorCodes: + op_codes.append(OpCode(op_code)) + + self._op_codes: List([OpCode]) = list(set(op_codes)) + + @property + def op_codes(self) -> Sequence[OpCode]: + return self._op_codes diff --git a/codegen/inference_generator.py b/codegen/inference_generator.py index 12597afd2ec..fe351f36550 100644 --- a/codegen/inference_generator.py +++ b/codegen/inference_generator.py @@ -19,6 +19,8 @@ from mako import template from typing import TypedDict +from tflite_micro.codegen import graph + _TEMPLATE_DIR = pathlib.Path(__file__).parent / 'templates' _HEADER_TEMPLATE = _TEMPLATE_DIR / 'inference.h.mako' _SOURCE_TEMPLATE = _TEMPLATE_DIR / 'inference.cc.mako' @@ -27,6 +29,8 @@ class ModelData(TypedDict): header_file: str model_name: str + op_code_table: graph.OpCodeTable + graph: graph.Graph def _render(output_file: pathlib.Path, template_file: pathlib.Path, @@ -45,12 +49,15 @@ def _generate_source(source_path: pathlib.Path, model_data: ModelData) -> None: _render(source_path, _SOURCE_TEMPLATE, model_data) -def generate(output_dir: str, output_name: str) -> None: +def generate(output_dir: str, output_name: str, + op_code_table: graph.OpCodeTable, graph: graph.Graph) -> None: """ Generate C/C++ inference code. """ header_file = f"{output_name}.h" model_data: ModelData = { 'header_file': header_file, - 'model_name': output_name + 'model_name': output_name, + 'op_code_table': op_code_table, + 'graph': graph, } # Ensure output directory exists diff --git a/codegen/templates/inference.cc.mako b/codegen/templates/inference.cc.mako index dc4e1f7c091..0acbc1d1792 100644 --- a/codegen/templates/inference.cc.mako +++ b/codegen/templates/inference.cc.mako @@ -17,8 +17,39 @@ limitations under the License. 
#include "${header_file}" -namespace ${model_name} { +#include "tensorflow/lite/c/c_api_types.h" +#include "tensorflow/lite/c/common.h" +#include "tensorflow/lite/micro/kernels/micro_ops.h" +#include "tensorflow/lite/micro/micro_common.h" -void Invoke() {} +namespace ${model_name} { +namespace { +// TODO(rjascani): We should probably split out the OpTable to a separate file +// once we start generating for multiple models. +enum OpCode { +% for op_code in op_code_table.op_codes: + ${op_code.enum_name()}, +% endfor + kCount +}; + +TFLMInferenceRegistration op_table[OpCode::kCount] = { +% for op_code in op_code_table.op_codes: + ${op_code.register_function()}(), +% endfor +}; +} // namespace + +% for subgraph_idx, subgraph in enumerate(graph.subgraphs): +TfLiteStatus InvokeSubgraph${subgraph_idx}() { +% for operator in subgraph.operators: + TF_LITE_ENSURE_OK(nullptr, + op_table[OpCode::${operator.op_code.enum_name()}].invoke(nullptr, nullptr)); +% endfor + return kTfLiteOk; +} +% endfor + +TfLiteStatus Invoke() { return InvokeSubgraph0(); } } // namespace ${model_name} diff --git a/codegen/templates/inference.h.mako b/codegen/templates/inference.h.mako index 29f5584181f..ba4104f500d 100644 --- a/codegen/templates/inference.h.mako +++ b/codegen/templates/inference.h.mako @@ -17,8 +17,10 @@ limitations under the License. #pragma once +#include "tensorflow/lite/c/c_api_types.h" + namespace ${model_name} { -void Invoke(); +TfLiteStatus Invoke(); } // namespace ${model_name} diff --git a/tensorflow/lite/micro/kernels/cmsis_nn/fully_connected.cc b/tensorflow/lite/micro/kernels/cmsis_nn/fully_connected.cc index a7ab8f12ab3..c4cad9023ef 100644 --- a/tensorflow/lite/micro/kernels/cmsis_nn/fully_connected.cc +++ b/tensorflow/lite/micro/kernels/cmsis_nn/fully_connected.cc @@ -433,4 +433,8 @@ TFLMRegistration Register_FULLY_CONNECTED_INT16() { return tflite::micro::RegisterOp(Init, Prepare, EvalInt16); } +TFLMInferenceRegistration RegisterInference_FULLY_CONNECTED() { + return tflite::micro::RegisterOp(Eval); +} + } // namespace tflite diff --git a/tensorflow/lite/micro/kernels/fully_connected.cc b/tensorflow/lite/micro/kernels/fully_connected.cc index f732b2935a0..54576faf6ea 100644 --- a/tensorflow/lite/micro/kernels/fully_connected.cc +++ b/tensorflow/lite/micro/kernels/fully_connected.cc @@ -203,4 +203,8 @@ TFLMRegistration Register_FULLY_CONNECTED() { return tflite::micro::RegisterOp(Init, Prepare, Eval); } +TFLMInferenceRegistration RegisterInference_FULLY_CONNECTED() { + return tflite::micro::RegisterOp(Eval); +} + } // namespace tflite diff --git a/tensorflow/lite/micro/kernels/kernel_util.cc b/tensorflow/lite/micro/kernels/kernel_util.cc index ffffa084d51..12adf3614b6 100644 --- a/tensorflow/lite/micro/kernels/kernel_util.cc +++ b/tensorflow/lite/micro/kernels/kernel_util.cc @@ -53,6 +53,15 @@ TFLMRegistration RegisterOp( /*custom_name=*/nullptr}; } +TFLMInferenceRegistration RegisterOp( + TfLiteStatus (*invoke)(TfLiteContext* context, TfLiteNode* node), + void (*reset)(TfLiteContext* context, void* buffer)) { + return { + /*invoke=*/invoke, + /*reset*/ reset, + }; +} + // Returns a mutable tensor for a given input index. is_variable must be checked // during prepare when the full TfLiteTensor is available. 
 TfLiteEvalTensor* GetMutableEvalInput(const TfLiteContext* context,
diff --git a/tensorflow/lite/micro/kernels/kernel_util.h b/tensorflow/lite/micro/kernels/kernel_util.h
index 080a0b3f361..f14c927133d 100644
--- a/tensorflow/lite/micro/kernels/kernel_util.h
+++ b/tensorflow/lite/micro/kernels/kernel_util.h
@@ -35,6 +35,10 @@ TFLMRegistration RegisterOp(
     void (*free)(TfLiteContext* context, void* buffer) = nullptr,
     void (*reset)(TfLiteContext* context, void* buffer) = nullptr);
 
+TFLMInferenceRegistration RegisterOp(
+    TfLiteStatus (*invoke)(TfLiteContext* context, TfLiteNode* node),
+    void (*reset)(TfLiteContext* context, void* buffer) = nullptr);
+
 // Prints out n bytes in a int8_t buffer as hex
 void PrintNBytes(const int8_t* tensor_data, int n_bytes,
                  const char* prefix = nullptr);
diff --git a/tensorflow/lite/micro/kernels/micro_ops.h b/tensorflow/lite/micro/kernels/micro_ops.h
index 5bffa09b9c6..c723092f104 100644
--- a/tensorflow/lite/micro/kernels/micro_ops.h
+++ b/tensorflow/lite/micro/kernels/micro_ops.h
@@ -133,6 +133,9 @@ TFLMRegistration Register_VAR_HANDLE();
 TFLMRegistration Register_WHILE();
 TFLMRegistration Register_ZEROS_LIKE();
 
+// TODO(b/295174388): Add the rest of the inference-only registration functions.
+TFLMInferenceRegistration RegisterInference_FULLY_CONNECTED();
+
 // TODO(b/160234179): Change custom OPs to also return by value.
 namespace tflm_signal {
 TFLMRegistration* Register_DELAY();
diff --git a/tensorflow/lite/micro/kernels/xtensa/fully_connected.cc b/tensorflow/lite/micro/kernels/xtensa/fully_connected.cc
index 1395fc39645..df5458001b7 100644
--- a/tensorflow/lite/micro/kernels/xtensa/fully_connected.cc
+++ b/tensorflow/lite/micro/kernels/xtensa/fully_connected.cc
@@ -125,4 +125,8 @@ TFLMRegistration Register_FULLY_CONNECTED() {
                                    XtensaPrepareFullyConnected, Eval);
 }
 
+TFLMInferenceRegistration RegisterInference_FULLY_CONNECTED() {
+  return tflite::micro::RegisterOp(Eval);
+}
+
 }  // namespace tflite
diff --git a/tensorflow/lite/micro/micro_common.h b/tensorflow/lite/micro/micro_common.h
index dc0bc0843a6..9ab427f5add 100644
--- a/tensorflow/lite/micro/micro_common.h
+++ b/tensorflow/lite/micro/micro_common.h
@@ -30,4 +30,9 @@ struct TFLMRegistration {
   const char* custom_name;
 };
 
+struct TFLMInferenceRegistration {
+  TfLiteStatus (*invoke)(TfLiteContext* context, TfLiteNode* node);
+  void (*reset)(TfLiteContext* context, void* buffer);
+};
+
 #endif  // THIRD_PARTY_TFLITE_MICRO_TENSORFLOW_LITE_MICRO_MICRO_COMMON_H_
diff --git a/third_party/hexagon/fully_connected.cc b/third_party/hexagon/fully_connected.cc
index c27c238003c..99ee1f3c09d 100644
--- a/third_party/hexagon/fully_connected.cc
+++ b/third_party/hexagon/fully_connected.cc
@@ -129,4 +129,8 @@ TFLMRegistration Register_FULLY_CONNECTED() {
                                    HexagonFullyConnectedEval);
 }
 
+TFLMInferenceRegistration RegisterInference_FULLY_CONNECTED() {
+  return tflite::micro::RegisterOp(HexagonFullyConnectedEval);
+}
+
 }  // namespace tflite
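
Reviewer note (not part of the patch): the TODO in micro_ops.h asks for the same per-kernel plumbing to be repeated for the remaining ops. A minimal sketch of what that would look like for another kernel, say ADD, is below; `RegisterInference_ADD` and the stub `Eval` are assumptions for illustration, not symbols introduced by this change. In a real kernel .cc file, `Eval` already exists and the stub would be dropped.

```cpp
// Hypothetical follow-up sketch: an inference-only registration mirrors
// RegisterInference_FULLY_CONNECTED above, reusing the kernel's existing
// eval function and populating only the invoke hook.
#include "tensorflow/lite/micro/kernels/kernel_util.h"
#include "tensorflow/lite/micro/micro_common.h"

namespace tflite {
namespace {
// Stand-in for the kernel's existing Eval function.
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
  return kTfLiteOk;
}
}  // namespace

TFLMInferenceRegistration RegisterInference_ADD() {
  // reset defaults to nullptr in the new RegisterOp overload; Init/Prepare
  // are intentionally absent, since codegen is meant to do that work offline.
  return tflite::micro::RegisterOp(Eval);
}

}  // namespace tflite
```

The matching declaration would sit next to `RegisterInference_FULLY_CONNECTED()` in micro_ops.h, so generated op tables can reference it the same way the hello_world example does.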