Generate op table and subgraph invoke functions (#2176)
As the next step in the codegen experiment, we want to generate the invoke calls for each layer. This is slightly challenging with the existing sources, as kernels expose only a registration function, not their individual Eval functions. To keep code churn to a minimum, this PR introduces an inference-only registration structure and function. It includes just two function pointers: invoke and reset. For this CL, we've introduced it only for FullyConnected.
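Condensed, the new registration structure and the per-kernel factory look like this (taken from the micro_common.h, kernel_util.h, and fully_connected.cc changes below; Eval is each kernel's existing eval function):

struct TFLMInferenceRegistration {
  TfLiteStatus (*invoke)(TfLiteContext* context, TfLiteNode* node);
  void (*reset)(TfLiteContext* context, void* buffer);
};

// Each kernel exposes a factory that wires in its Eval function. The
// RegisterOp overload added in kernel_util.h defaults reset to nullptr.
TFLMInferenceRegistration RegisterInference_FULLY_CONNECTED() {
  return tflite::micro::RegisterOp(Eval);
}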

In the code generator, this PR creates a new op_table array in the generated source, with an enum for lookup. It also generates an invoke function for each subgraph that calls each operator's invoke function in order.
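In miniature, the generated source has this shape (mirroring the hello_world output below; the nullptr context and node arguments are placeholders at this stage of the experiment):

enum OpCode { kFullyConnected, kCount };

TFLMInferenceRegistration op_table[OpCode::kCount] = {
    tflite::RegisterInference_FULLY_CONNECTED(),
};

TfLiteStatus InvokeSubgraph0() {
  // One invoke call is emitted per operator in the subgraph.
  TF_LITE_ENSURE_OK(nullptr,
                    op_table[OpCode::kFullyConnected].invoke(nullptr, nullptr));
  return kTfLiteOk;
}

TfLiteStatus Invoke() { return InvokeSubgraph0(); }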

BUG=295174388
rascani authored Aug 18, 2023
1 parent 389d8f5 commit a0f8970
Showing 17 changed files with 230 additions and 10 deletions.
14 changes: 14 additions & 0 deletions codegen/BUILD
@@ -6,6 +6,17 @@ package(
licenses = ["notice"],
)

py_library(
name = "graph",
srcs = [
"graph.py",
],
deps = [
"//tensorflow/lite/python:schema_py",
"//tensorflow/lite/tools:visualize",
],
)

py_library(
name = "inference_generator",
srcs = [
@@ -16,6 +27,7 @@ py_library(
"templates/inference.h.mako",
],
deps = [
":graph",
requirement("mako"),
],
)
@@ -28,7 +40,9 @@ py_binary(
python_version = "PY3",
srcs_version = "PY3",
deps = [
":graph",
":inference_generator",
"//tensorflow/lite/tools:flatbuffer_utils",
"@absl_py//absl:app",
"@absl_py//absl/flags",
],
7 changes: 6 additions & 1 deletion codegen/code_generator.py
@@ -21,6 +21,8 @@
from collections.abc import Sequence

from tflite_micro.codegen import inference_generator
from tflite_micro.codegen import graph
from tflite_micro.tensorflow.lite.tools import flatbuffer_utils

# Usage information:
# Default:
@@ -51,7 +53,10 @@ def main(argv: Sequence[str]) -> None:
output_name = _OUTPUT_NAME.value or os.path.splitext(
os.path.basename(_MODEL_PATH.value))[0]

inference_generator.generate(output_dir, output_name)
model = flatbuffer_utils.read_model(_MODEL_PATH.value)

inference_generator.generate(output_dir, output_name,
graph.OpCodeTable([model]), graph.Graph(model))


if __name__ == "__main__":
6 changes: 5 additions & 1 deletion codegen/examples/hello_world/hello_world.cc
@@ -14,9 +14,13 @@ limitations under the License.
==============================================================================*/

#include "hello_world_model.h"
#include "tensorflow/lite/c/c_api_types.h"

int main(int argc, char** argv) {
hello_world_model::Invoke();
TfLiteStatus status = hello_world_model::Invoke();
if (status != kTfLiteOk) {
return -1;
}

return 0;
}
28 changes: 26 additions & 2 deletions codegen/examples/hello_world/hello_world_model.cc
@@ -17,8 +17,32 @@ limitations under the License.

#include "hello_world_model.h"

namespace hello_world_model {
#include "tensorflow/lite/c/c_api_types.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/micro/kernels/micro_ops.h"
#include "tensorflow/lite/micro/micro_common.h"

void Invoke() {}
namespace hello_world_model {
namespace {
// TODO(rjascani): We should probably split out the OpTable to a separate file
// once we start generating for multiple models.
enum OpCode { kFullyConnected, kCount };

TFLMInferenceRegistration op_table[OpCode::kCount] = {
tflite::RegisterInference_FULLY_CONNECTED(),
};
} // namespace

TfLiteStatus InvokeSubgraph0() {
TF_LITE_ENSURE_OK(nullptr,
op_table[OpCode::kFullyConnected].invoke(nullptr, nullptr));
TF_LITE_ENSURE_OK(nullptr,
op_table[OpCode::kFullyConnected].invoke(nullptr, nullptr));
TF_LITE_ENSURE_OK(nullptr,
op_table[OpCode::kFullyConnected].invoke(nullptr, nullptr));
return kTfLiteOk;
}

TfLiteStatus Invoke() { return InvokeSubgraph0(); }

} // namespace hello_world_model
4 changes: 3 additions & 1 deletion codegen/examples/hello_world/hello_world_model.h
@@ -17,8 +17,10 @@ limitations under the License.

#pragma once

#include "tensorflow/lite/c/c_api_types.h"

namespace hello_world_model {

void Invoke();
TfLiteStatus Invoke();

} // namespace hello_world_model
94 changes: 94 additions & 0 deletions codegen/graph.py
@@ -0,0 +1,94 @@
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
""" Provides object representation for the model that is conducive to code
generation using templates. """

from typing import List, Sequence

from tflite_micro.tensorflow.lite.python import schema_py_generated as schema_fb
from tflite_micro.tensorflow.lite.tools import visualize


def _to_pascal_case(s: str) -> str:
return s.title().replace('_', '')


class OpCode(object):

def __init__(self, op_code: schema_fb.OperatorCodeT):
self._op_code: schema_fb.OperatorCodeT = op_code

def name(self) -> str:
if self._op_code.customCode:
return self._op_code.customCode
return visualize.BuiltinCodeToName(self._op_code.builtinCode)

def register_function(self) -> str:
return "tflite::RegisterInference_{}".format(self.name())

def enum_name(self) -> str:
return "k{}".format(_to_pascal_case(self.name()))


class Operator(object):

def __init__(self, model: schema_fb.ModelT, operator: schema_fb.OperatorT):
self._operator: schema_fb.OperatorT = operator
self._op_code: OpCode = OpCode(
model.operatorCodes[self._operator.opcodeIndex])

@property
def op_code(self) -> OpCode:
return self._op_code


class Subgraph(object):

def __init__(self, model: schema_fb.ModelT, subgraph: schema_fb.SubGraphT):
self._subgraph: schema_fb.SubGraphT = subgraph
self._operators: List[Operator] = [
Operator(model, operator) for operator in subgraph.operators
]

@property
def operators(self) -> Sequence[Operator]:
return self._operators


class Graph(object):

def __init__(self, model: schema_fb.ModelT):
self._subgraphs: List[Subgraph] = [
Subgraph(model, subgraph) for subgraph in model.subgraphs
]

@property
def subgraphs(self) -> Sequence[Subgraph]:
return self._subgraphs


class OpCodeTable(object):

def __init__(self, models: Sequence[schema_fb.ModelT]):
op_codes = []
for model in models:
for op_code in model.operatorCodes:
op_codes.append(OpCode(op_code))

self._op_codes: List[OpCode] = list(set(op_codes))

@property
def op_codes(self) -> Sequence[OpCode]:
return self._op_codes
11 changes: 9 additions & 2 deletions codegen/inference_generator.py
@@ -19,6 +19,8 @@
from mako import template
from typing import TypedDict

from tflite_micro.codegen import graph

_TEMPLATE_DIR = pathlib.Path(__file__).parent / 'templates'
_HEADER_TEMPLATE = _TEMPLATE_DIR / 'inference.h.mako'
_SOURCE_TEMPLATE = _TEMPLATE_DIR / 'inference.cc.mako'
@@ -27,6 +29,8 @@
class ModelData(TypedDict):
header_file: str
model_name: str
op_code_table: graph.OpCodeTable
graph: graph.Graph


def _render(output_file: pathlib.Path, template_file: pathlib.Path,
@@ -45,12 +49,15 @@ def _generate_source(source_path: pathlib.Path, model_data: ModelData) -> None:
_render(source_path, _SOURCE_TEMPLATE, model_data)


def generate(output_dir: str, output_name: str) -> None:
def generate(output_dir: str, output_name: str,
op_code_table: graph.OpCodeTable, graph: graph.Graph) -> None:
""" Generate C/C++ inference code. """
header_file = f"{output_name}.h"
model_data: ModelData = {
'header_file': header_file,
'model_name': output_name
'model_name': output_name,
'op_code_table': op_code_table,
'graph': graph,
}

# Ensure output directory exists
35 changes: 33 additions & 2 deletions codegen/templates/inference.cc.mako
@@ -17,8 +17,39 @@ limitations under the License.

#include "${header_file}"

namespace ${model_name} {
#include "tensorflow/lite/c/c_api_types.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/micro/kernels/micro_ops.h"
#include "tensorflow/lite/micro/micro_common.h"

void Invoke() {}
namespace ${model_name} {
namespace {
// TODO(rjascani): We should probably split out the OpTable to a separate file
// once we start generating for multiple models.
enum OpCode {
% for op_code in op_code_table.op_codes:
${op_code.enum_name()},
% endfor
kCount
};

TFLMInferenceRegistration op_table[OpCode::kCount] = {
% for op_code in op_code_table.op_codes:
${op_code.register_function()}(),
% endfor
};
} // namespace

% for subgraph_idx, subgraph in enumerate(graph.subgraphs):
TfLiteStatus InvokeSubgraph${subgraph_idx}() {
% for operator in subgraph.operators:
TF_LITE_ENSURE_OK(nullptr,
op_table[OpCode::${operator.op_code.enum_name()}].invoke(nullptr, nullptr));
% endfor
return kTfLiteOk;
}
% endfor

TfLiteStatus Invoke() { return InvokeSubgraph0(); }

} // namespace ${model_name}
4 changes: 3 additions & 1 deletion codegen/templates/inference.h.mako
@@ -17,8 +17,10 @@ limitations under the License.

#pragma once

#include "tensorflow/lite/c/c_api_types.h"

namespace ${model_name} {

void Invoke();
TfLiteStatus Invoke();

} // namespace ${model_name}
4 changes: 4 additions & 0 deletions tensorflow/lite/micro/kernels/cmsis_nn/fully_connected.cc
@@ -433,4 +433,8 @@ TFLMRegistration Register_FULLY_CONNECTED_INT16() {
return tflite::micro::RegisterOp(Init, Prepare, EvalInt16);
}

TFLMInferenceRegistration RegisterInference_FULLY_CONNECTED() {
return tflite::micro::RegisterOp(Eval);
}

} // namespace tflite
4 changes: 4 additions & 0 deletions tensorflow/lite/micro/kernels/fully_connected.cc
@@ -203,4 +203,8 @@ TFLMRegistration Register_FULLY_CONNECTED() {
return tflite::micro::RegisterOp(Init, Prepare, Eval);
}

TFLMInferenceRegistration RegisterInference_FULLY_CONNECTED() {
return tflite::micro::RegisterOp(Eval);
}

} // namespace tflite
9 changes: 9 additions & 0 deletions tensorflow/lite/micro/kernels/kernel_util.cc
@@ -53,6 +53,15 @@ TFLMRegistration RegisterOp(
/*custom_name=*/nullptr};
}

TFLMInferenceRegistration RegisterOp(
TfLiteStatus (*invoke)(TfLiteContext* context, TfLiteNode* node),
void (*reset)(TfLiteContext* context, void* buffer)) {
return {
/*invoke=*/invoke,
/*reset=*/reset,
};
}

// Returns a mutable tensor for a given input index. is_variable must be checked
// during prepare when the full TfLiteTensor is available.
TfLiteEvalTensor* GetMutableEvalInput(const TfLiteContext* context,
4 changes: 4 additions & 0 deletions tensorflow/lite/micro/kernels/kernel_util.h
@@ -35,6 +35,10 @@ TFLMRegistration RegisterOp(
void (*free)(TfLiteContext* context, void* buffer) = nullptr,
void (*reset)(TfLiteContext* context, void* buffer) = nullptr);

TFLMInferenceRegistration RegisterOp(
TfLiteStatus (*invoke)(TfLiteContext* context, TfLiteNode* node),
void (*reset)(TfLiteContext* context, void* buffer) = nullptr);

// Prints out n bytes in a int8_t buffer as hex
void PrintNBytes(const int8_t* tensor_data, int n_bytes,
const char* prefix = nullptr);
3 changes: 3 additions & 0 deletions tensorflow/lite/micro/kernels/micro_ops.h
@@ -133,6 +133,9 @@ TFLMRegistration Register_VAR_HANDLE();
TFLMRegistration Register_WHILE();
TFLMRegistration Register_ZEROS_LIKE();

// TODO(b/295174388): Add the rest of inference only registration functions.
TFLMInferenceRegistration RegisterInference_FULLY_CONNECTED();

// TODO(b/160234179): Change custom OPs to also return by value.
namespace tflm_signal {
TFLMRegistration* Register_DELAY();
4 changes: 4 additions & 0 deletions tensorflow/lite/micro/kernels/xtensa/fully_connected.cc
@@ -125,4 +125,8 @@ TFLMRegistration Register_FULLY_CONNECTED() {
XtensaPrepareFullyConnected, Eval);
}

TFLMInferenceRegistration RegisterInference_FULLY_CONNECTED() {
return tflite::micro::RegisterOp(Eval);
}

} // namespace tflite
5 changes: 5 additions & 0 deletions tensorflow/lite/micro/micro_common.h
@@ -30,4 +30,9 @@ struct TFLMRegistration {
const char* custom_name;
};

struct TFLMInferenceRegistration {
TfLiteStatus (*invoke)(TfLiteContext* context, TfLiteNode* node);
void (*reset)(TfLiteContext* context, void* buffer);
};

#endif // THIRD_PARTY_TFLITE_MICRO_TENSORFLOW_LITE_MICRO_MICRO_COMMON_H_
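
A hedged usage sketch, not part of this diff: calling through a TFLMInferenceRegistration, guarding the optional reset pointer (which the kernel_util.h overload above defaults to nullptr). The context, node, and buffer values are assumed to be supplied by the surrounding framework:

TFLMInferenceRegistration reg = tflite::RegisterInference_FULLY_CONNECTED();
// invoke is always populated by the factory; reset may be nullptr.
TF_LITE_ENSURE_OK(context, reg.invoke(context, node));
if (reg.reset != nullptr) {
  reg.reset(context, buffer);
}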
4 changes: 4 additions & 0 deletions third_party/hexagon/fully_connected.cc
@@ -129,4 +129,8 @@ TFLMRegistration Register_FULLY_CONNECTED() {
HexagonFullyConnectedEval);
}

TFLMInferenceRegistration RegisterInference_FULLY_CONNECTED() {
return tflite::micro::RegisterOp(HexagonFullyConnectedEval);
}

} // namespace tflite
