Skip to content

Commit 2cd9a75

Browse files
flamearrow and tflite-support-robot
authored and committed
add support for the tokenized input in NLClassifier.
PiperOrigin-RevId: 322493255
1 parent 59cfb05 commit 2cd9a75

File tree

7 files changed

+233
-63
lines changed

7 files changed

+233
-63
lines changed

tensorflow_lite_support/cc/task/text/nlclassifier/BUILD

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,9 @@ cc_library(
1919
"//tensorflow_lite_support/cc/task/core:category",
2020
"//tensorflow_lite_support/cc/task/core:task_api_factory",
2121
"//tensorflow_lite_support/cc/task/core:task_utils",
22+
"//tensorflow_lite_support/cc/text/tokenizers:regex_tokenizer",
23+
"//tensorflow_lite_support/cc/text/tokenizers:tokenizer",
24+
"//tensorflow_lite_support/cc/text/tokenizers:tokenizer_utils",
2225
"//tensorflow_lite_support/cc/utils:common_utils",
2326
"//tensorflow_lite_support/metadata/cc:metadata_extractor",
2427
"@com_google_absl//absl/algorithm:container",
@@ -27,6 +30,7 @@ cc_library(
2730
"@com_google_absl//absl/strings",
2831
"@flatbuffers",
2932
"@org_tensorflow//tensorflow/lite:string",
33+
"@org_tensorflow//tensorflow/lite:type_to_tflitetype",
3034
"@org_tensorflow//tensorflow/lite/c:common",
3135
"@org_tensorflow//tensorflow/lite/core/api",
3236
"@org_tensorflow//tensorflow/lite/kernels:builtin_ops",

tensorflow_lite_support/cc/task/text/nlclassifier/bert_nl_classifier.cc

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ namespace nlclassifier {
4747

4848
using ::tflite::support::task::core::FindTensorByName;
4949
using ::tflite::support::task::core::PopulateTensor;
50-
using ::tflite::support::text::tokenizer::CreateTokenizerFromMetadata;
50+
using ::tflite::support::text::tokenizer::CreateTokenizerFromProcessUnit;
5151
using ::tflite::support::text::tokenizer::TokenizerResult;
5252

5353
namespace {
@@ -57,6 +57,7 @@ constexpr char kSegmentIdsTensorName[] = "segment_ids";
5757
constexpr char kScoreTensorName[] = "probability";
5858
constexpr char kClassificationToken[] = "[CLS]";
5959
constexpr char kSeparator[] = "[SEP]";
60+
constexpr int kTokenizerProcessUnitIndex = 0;
6061
} // namespace
6162

6263
absl::Status BertNLClassifier::Preprocess(
@@ -160,8 +161,17 @@ BertNLClassifier::CreateBertNLClassifierWithMetadataFromBinary(
160161

161162
absl::Status BertNLClassifier::InitializeFromMetadata() {
162163
// Set up mandatory tokenizer.
164+
const ProcessUnit* tokenizer_process_unit =
165+
GetMetadataExtractor()->GetInputProcessUnit(kTokenizerProcessUnitIndex);
166+
if (tokenizer_process_unit == nullptr) {
167+
return CreateStatusWithPayload(
168+
absl::StatusCode::kInvalidArgument,
169+
"No input process unit found from metadata.",
170+
TfLiteSupportStatus::kMetadataInvalidTokenizerError);
171+
}
163172
ASSIGN_OR_RETURN(tokenizer_,
164-
CreateTokenizerFromMetadata(*GetMetadataExtractor()));
173+
CreateTokenizerFromProcessUnit(tokenizer_process_unit,
174+
GetMetadataExtractor()));
165175

166176
// Set up optional label vector.
167177
TrySetLabelFromMetadata(

tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier.cc

Lines changed: 98 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -29,12 +29,16 @@ limitations under the License.
2929
#include "tensorflow/lite/c/common.h"
3030
#include "tensorflow/lite/core/api/op_resolver.h"
3131
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
32+
#include "tensorflow/lite/type_to_tflitetype.h"
3233
#include "tensorflow_lite_support/cc/common.h"
3334
#include "tensorflow_lite_support/cc/port/status_macros.h"
3435
#include "tensorflow_lite_support/cc/port/statusor.h"
3536
#include "tensorflow_lite_support/cc/task/core/category.h"
3637
#include "tensorflow_lite_support/cc/task/core/task_api_factory.h"
3738
#include "tensorflow_lite_support/cc/task/core/task_utils.h"
39+
#include "tensorflow_lite_support/cc/text/tokenizers/regex_tokenizer.h"
40+
#include "tensorflow_lite_support/cc/text/tokenizers/tokenizer.h"
41+
#include "tensorflow_lite_support/cc/text/tokenizers/tokenizer_utils.h"
3842
#include "tensorflow_lite_support/cc/utils/common_utils.h"
3943

4044
namespace tflite {
@@ -52,8 +56,16 @@ using ::tflite::support::StatusOr;
5256
using ::tflite::support::task::core::Dequantize;
5357
using ::tflite::support::task::core::GetStringAtIndex;
5458
using ::tflite::support::task::core::PopulateTensor;
59+
using ::tflite::support::text::tokenizer::CreateTokenizerFromProcessUnit;
60+
using ::tflite::support::text::tokenizer::RegexTokenizer;
61+
using ::tflite::support::text::tokenizer::TokenizerResult;
5562
using ::tflite::support::utils::LoadVocabFromBuffer;
5663

64+
namespace {
65+
// Index of the input tensor whose metadata is inspected for a regex tokenizer
// process unit (passed to GetInputTensorMetadata()).
constexpr int kRegexTokenizerInputTensorIndex = 0;
66+
// Position, within that tensor's process_units() list, at which the regex
// tokenizer process unit is looked up.
constexpr int kRegexTokenizerProcessUnitIndex = 0;
67+
} // namespace
68+
5769
const NLClassifierOptions& NLClassifier::GetOptions() const { return options_; }
5870

5971
absl::Status NLClassifier::TrySetLabelFromMetadata(
@@ -102,11 +114,59 @@ std::vector<core::Category> NLClassifier::Classify(const std::string& text) {
102114

103115
// Fills the model's input tensor from `input`.
//
// The input tensor is resolved from `input_tensors` using the tensor name /
// index stored in `options_`. Two input flavours are handled:
//   * The model metadata declares a regex tokenizer: `input` is tokenized,
//     every token is mapped to its vocabulary id and the resulting id sequence
//     (as floats) is written to the tensor.
//   * Otherwise: `input` is written to the tensor as a raw string.
//
// Returns kInvalidArgument (payload kInputTensorNotFoundError) when the input
// tensor cannot be resolved, the error from SetupRegexTokenizer() if the
// tokenizer cannot be created, and OkStatus otherwise.
absl::Status NLClassifier::Preprocess(
    const std::vector<TfLiteTensor*>& input_tensors, const std::string& input) {
  TfLiteTensor* input_tensor = FindTensorWithNameOrIndex(
      input_tensors, GetMetadataExtractor()->GetInputTensorMetadata(),
      options_.input_tensor_name, options_.input_tensor_index);
  if (input_tensor == nullptr) {
    return CreateStatusWithPayload(
        absl::StatusCode::kInvalidArgument,
        "No input tensor found from NLClassifierOptions.",
        TfLiteSupportStatus::kInputTensorNotFoundError);
  }

  if (!HasRegexTokenizerMetadata()) {
    // Plain string model: hand the text over untouched.
    PopulateTensor(input, input_tensor);
    return absl::OkStatus();
  }

  // NOTE(review): the tokenizer is rebuilt on every call; it could probably be
  // created once at initialization time - confirm before changing.
  RETURN_IF_ERROR(SetupRegexTokenizer());

  //                              |<-------sentence_length-------->|
  // input_tensor                 <START>, t1, t2... <PAD>, <PAD>...
  // <START> is optional, t1, t2... will be replaced by <UNKNOWN> if it's not
  // found in tokenizer vocab.
  TokenizerResult result = tokenizer_->Tokenize(input);

  // For a 2-D tensor the sentence length is the second dimension, otherwise
  // the first one.
  // NOTE(review): assumes the tensor has rank >= 1 and a non-negative static
  // length in that dimension - confirm against supported models.
  const size_t max_sentence_length = input_tensor->dims->size == 2
                                         ? input_tensor->dims->data[1]
                                         : input_tensor->dims->data[0];

  // The lookup results are intentionally not checked: both ids keep their
  // initial value of 0 unless LookupId() assigns them.
  int unknown_token_id = 0;
  tokenizer_->LookupId(RegexTokenizer::kUnknown, &unknown_token_id);

  int pad_token_id = 0;
  tokenizer_->LookupId(RegexTokenizer::kPad, &pad_token_id);

  // Start from an all-<PAD> sentence and overwrite the leading slots.
  std::vector<float> input_tokens(max_sentence_length, pad_token_id);
  size_t input_token_index = 0;

  // Optional <START> marker. Guard against a zero-length tensor so that the
  // write below can never go out of bounds.
  int start_token_id = 0;
  if (tokenizer_->LookupId(RegexTokenizer::kStart, &start_token_id) &&
      !input_tokens.empty()) {
    input_tokens[0] = start_token_id;
    input_token_index = 1;
  }

  // Copy token ids until either the tokens or the tensor slots run out. The
  // write position must advance with every token; otherwise each token would
  // overwrite the same slot and only the last one would survive.
  for (const std::string& token : result.subwords) {
    if (input_token_index >= max_sentence_length) {
      break;
    }
    int token_id = 0;
    input_tokens[input_token_index] =
        tokenizer_->LookupId(token, &token_id) ? token_id : unknown_token_id;
    ++input_token_index;
  }

  PopulateTensor(input_tokens, input_tensor);
  return absl::OkStatus();
}
112172

@@ -172,7 +232,7 @@ absl::Status NLClassifier::Initialize(const NLClassifierOptions& options) {
172232
options.input_tensor_index),
173233
TfLiteSupportStatus::kInputTensorNotFoundError);
174234
}
175-
if (input_tensor->type != kTfLiteString) {
235+
if (!HasRegexTokenizerMetadata() && input_tensor->type != kTfLiteString) {
176236
return CreateStatusWithPayload(
177237
StatusCode::kInvalidArgument,
178238
absl::StrCat("Type mismatch for input tensor ", input_tensor->name,
@@ -278,6 +338,38 @@ StatusOr<std::unique_ptr<NLClassifier>> NLClassifier::CreateNLClassifier(
278338
return std::move(nl_classifier);
279339
}
280340

341+
bool NLClassifier::HasRegexTokenizerMetadata() {
342+
if (GetMetadataExtractor()->GetInputTensorMetadata(
343+
kRegexTokenizerInputTensorIndex) == nullptr ||
344+
GetMetadataExtractor()
345+
->GetInputTensorMetadata(kRegexTokenizerInputTensorIndex)
346+
->process_units() == nullptr ||
347+
GetMetadataExtractor()
348+
->GetInputTensorMetadata(kRegexTokenizerInputTensorIndex)
349+
->process_units()
350+
->Get(kRegexTokenizerProcessUnitIndex) == nullptr) {
351+
return false;
352+
}
353+
return GetMetadataExtractor()
354+
->GetInputTensorMetadata(kRegexTokenizerInputTensorIndex)
355+
->process_units()
356+
->Get(kRegexTokenizerProcessUnitIndex)
357+
->options_type() == ProcessUnitOptions_RegexTokenizerOptions;
358+
}
359+
360+
// Creates `tokenizer_` from the process unit stored at
// kRegexTokenizerProcessUnitIndex in the metadata of the input tensor at
// kRegexTokenizerInputTensorIndex. The metadata chain is dereferenced without
// checks here. Propagates the error of CreateTokenizerFromProcessUnit() on
// failure, returns OkStatus otherwise.
absl::Status NLClassifier::SetupRegexTokenizer() {
  const auto* regex_process_unit =
      GetMetadataExtractor()
          ->GetInputTensorMetadata(kRegexTokenizerInputTensorIndex)
          ->process_units()
          ->Get(kRegexTokenizerProcessUnitIndex);

  ASSIGN_OR_RETURN(tokenizer_,
                   CreateTokenizerFromProcessUnit(regex_process_unit,
                                                  GetMetadataExtractor()));
  return absl::OkStatus();
}
372+
281373
} // namespace nlclassifier
282374
} // namespace text
283375
} // namespace task

tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ limitations under the License.
3434
#include "tensorflow_lite_support/cc/port/statusor.h"
3535
#include "tensorflow_lite_support/cc/task/core/base_task_api.h"
3636
#include "tensorflow_lite_support/cc/task/core/category.h"
37+
#include "tensorflow_lite_support/cc/text/tokenizers/tokenizer.h"
3738

3839
namespace tflite {
3940
namespace support {
@@ -59,6 +60,8 @@ struct NLClassifierOptions {
5960
// The API expects a TFLite model with the following input/output tensor:
6061
// Input tensor:
6162
// (kTfLiteString) - input of the model, accepts a string.
63+
// or
64+
// (kTfLiteFloat32) - input of the model, accepts a tokenized input of a string
6265
// Output score tensor:
6366
// (kTfLiteUInt8/kTfLiteInt8/kTfLiteInt16/kTfLiteFloat32/kTfLiteFloat64)
6467
// - output scores for each class, if type is one of the Int types,
@@ -155,10 +158,14 @@ class NLClassifier : public core::BaseTaskApi<std::vector<core::Category>,
155158
}
156159

157160
private:
161+
bool HasRegexTokenizerMetadata();
162+
absl::Status SetupRegexTokenizer();
163+
158164
NLClassifierOptions options_;
159165
// labels vector initialized from output tensor's associated file, if one
160166
// exists.
161167
std::unique_ptr<std::vector<std::string>> labels_vector_;
168+
std::unique_ptr<tflite::support::text::tokenizer::Tokenizer> tokenizer_;
162169
};
163170

164171
} // namespace nlclassifier

tensorflow_lite_support/cc/task/text/qa/bert_question_answerer.cc

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,10 +34,14 @@ using ::tflite::support::task::core::PopulateTensor;
3434
using ::tflite::support::task::core::PopulateVector;
3535
using ::tflite::support::task::core::ReverseSortIndices;
3636
using ::tflite::support::text::tokenizer::BertTokenizer;
37-
using ::tflite::support::text::tokenizer::CreateTokenizerFromMetadata;
37+
using ::tflite::support::text::tokenizer::CreateTokenizerFromProcessUnit;
3838
using ::tflite::support::text::tokenizer::SentencePieceTokenizer;
3939
using ::tflite::support::text::tokenizer::TokenizerResult;
4040

41+
namespace {
42+
// Index of the input process unit from which the tokenizer is created
// (passed to GetInputProcessUnit()).
constexpr int kTokenizerProcessUnitIndex = 0;
43+
}
44+
4145
StatusOr<std::unique_ptr<QuestionAnswerer>>
4246
BertQuestionAnswerer::CreateQuestionAnswererWithMetadata(
4347
const std::string& path_to_model_with_metadata) {
@@ -327,8 +331,17 @@ std::string BertQuestionAnswerer::ConvertIndexToString(int start, int end) {
327331
}
328332

329333
// Builds `tokenizer_` from the input process unit found at
// kTokenizerProcessUnitIndex in the model metadata.
//
// Returns kInvalidArgument (payload kMetadataInvalidTokenizerError) when the
// metadata has no such process unit, the error reported by
// CreateTokenizerFromProcessUnit() if tokenizer creation fails, and OkStatus
// otherwise.
absl::Status BertQuestionAnswerer::InitializeFromMetadata() {
  const ProcessUnit* process_unit =
      GetMetadataExtractor()->GetInputProcessUnit(kTokenizerProcessUnitIndex);
  if (!process_unit) {
    return CreateStatusWithPayload(
        absl::StatusCode::kInvalidArgument,
        "No input process unit found from metadata.",
        TfLiteSupportStatus::kMetadataInvalidTokenizerError);
  }

  ASSIGN_OR_RETURN(
      tokenizer_,
      CreateTokenizerFromProcessUnit(process_unit, GetMetadataExtractor()));
  return absl::OkStatus();
}
334347

0 commit comments

Comments
 (0)