@@ -29,12 +29,16 @@ limitations under the License.
2929#include " tensorflow/lite/c/common.h"
3030#include " tensorflow/lite/core/api/op_resolver.h"
3131#include " tensorflow/lite/kernels/internal/tensor_ctypes.h"
32+ #include " tensorflow/lite/type_to_tflitetype.h"
3233#include " tensorflow_lite_support/cc/common.h"
3334#include " tensorflow_lite_support/cc/port/status_macros.h"
3435#include " tensorflow_lite_support/cc/port/statusor.h"
3536#include " tensorflow_lite_support/cc/task/core/category.h"
3637#include " tensorflow_lite_support/cc/task/core/task_api_factory.h"
3738#include " tensorflow_lite_support/cc/task/core/task_utils.h"
39+ #include " tensorflow_lite_support/cc/text/tokenizers/regex_tokenizer.h"
40+ #include " tensorflow_lite_support/cc/text/tokenizers/tokenizer.h"
41+ #include " tensorflow_lite_support/cc/text/tokenizers/tokenizer_utils.h"
3842#include " tensorflow_lite_support/cc/utils/common_utils.h"
3943
4044namespace tflite {
@@ -52,8 +56,16 @@ using ::tflite::support::StatusOr;
5256using ::tflite::support::task::core::Dequantize;
5357using ::tflite::support::task::core::GetStringAtIndex;
5458using ::tflite::support::task::core::PopulateTensor;
59+ using ::tflite::support::text::tokenizer::CreateTokenizerFromProcessUnit;
60+ using ::tflite::support::text::tokenizer::RegexTokenizer;
61+ using ::tflite::support::text::tokenizer::TokenizerResult;
5562using ::tflite::support::utils::LoadVocabFromBuffer;
5663
namespace {
// Index of the input tensor whose metadata is inspected for a regex tokenizer.
constexpr int kRegexTokenizerInputTensorIndex{0};
// Index, within that tensor's process units, of the unit that is inspected.
constexpr int kRegexTokenizerProcessUnitIndex{0};
}  // namespace
68+
5769const NLClassifierOptions& NLClassifier::GetOptions () const { return options_; }
5870
5971absl::Status NLClassifier::TrySetLabelFromMetadata (
@@ -102,11 +114,59 @@ std::vector<core::Category> NLClassifier::Classify(const std::string& text) {
102114
103115absl::Status NLClassifier::Preprocess (
104116 const std::vector<TfLiteTensor*>& input_tensors, const std::string& input) {
105- PopulateTensor (
106- input,
107- FindTensorWithNameOrIndex (
108- input_tensors, GetMetadataExtractor ()->GetInputTensorMetadata (),
109- options_.input_tensor_name , options_.input_tensor_index ));
117+ TfLiteTensor* input_tensor = FindTensorWithNameOrIndex (
118+ input_tensors, GetMetadataExtractor ()->GetInputTensorMetadata (),
119+ options_.input_tensor_name , options_.input_tensor_index );
120+ if (input_tensor == nullptr ) {
121+ return CreateStatusWithPayload (
122+ absl::StatusCode::kInvalidArgument ,
123+ " No input tensor found from NLClassifierOptions." ,
124+ TfLiteSupportStatus::kInputTensorNotFoundError );
125+ }
126+
127+ if (HasRegexTokenizerMetadata ()) {
128+ RETURN_IF_ERROR (SetupRegexTokenizer ());
129+
130+ // |<-------sentence_length-------->|
131+ // input_tensor <START>, t1, t2... <PAD>, <PAD>...
132+ // <START> is optional, t1, t2... will be replaced by <UNKNOWN> if it's not
133+ // found in tokenizer vocab.
134+ TokenizerResult result = tokenizer_->Tokenize (input);
135+
136+ size_t max_sentence_length = input_tensor->dims ->size == 2
137+ ? input_tensor->dims ->data [1 ]
138+ : input_tensor->dims ->data [0 ];
139+
140+ int unknown_token_id = 0 ;
141+ tokenizer_->LookupId (RegexTokenizer::kUnknown , &unknown_token_id);
142+
143+ int pad_token_id = 0 ;
144+ tokenizer_->LookupId (RegexTokenizer::kPad , &pad_token_id);
145+
146+ std::vector<float > input_tokens (max_sentence_length, pad_token_id);
147+ int start_token_id = 0 ;
148+ size_t input_token_index = 0 ;
149+ if (tokenizer_->LookupId (RegexTokenizer::kStart , &start_token_id)) {
150+ input_tokens[0 ] = start_token_id;
151+ input_token_index = 1 ;
152+ }
153+
154+ for (size_t i = 0 ; (i < result.subwords .size ()) &&
155+ (input_token_index < max_sentence_length);
156+ ++i) {
157+ const std::string& token = result.subwords [i];
158+ int token_id = 0 ;
159+ if (tokenizer_->LookupId (token, &token_id)) {
160+ input_tokens[input_token_index] = token_id;
161+ } else {
162+ input_tokens[input_token_index] = unknown_token_id;
163+ }
164+ }
165+
166+ PopulateTensor (input_tokens, input_tensor);
167+ } else {
168+ PopulateTensor (input, input_tensor);
169+ }
110170 return absl::OkStatus ();
111171}
112172
@@ -172,7 +232,7 @@ absl::Status NLClassifier::Initialize(const NLClassifierOptions& options) {
172232 options.input_tensor_index ),
173233 TfLiteSupportStatus::kInputTensorNotFoundError );
174234 }
175- if (input_tensor->type != kTfLiteString ) {
235+ if (! HasRegexTokenizerMetadata () && input_tensor->type != kTfLiteString ) {
176236 return CreateStatusWithPayload (
177237 StatusCode::kInvalidArgument ,
178238 absl::StrCat (" Type mismatch for input tensor " , input_tensor->name ,
@@ -278,6 +338,38 @@ StatusOr<std::unique_ptr<NLClassifier>> NLClassifier::CreateNLClassifier(
278338 return std::move (nl_classifier);
279339}
280340
341+ bool NLClassifier::HasRegexTokenizerMetadata () {
342+ if (GetMetadataExtractor ()->GetInputTensorMetadata (
343+ kRegexTokenizerInputTensorIndex ) == nullptr ||
344+ GetMetadataExtractor ()
345+ ->GetInputTensorMetadata (kRegexTokenizerInputTensorIndex )
346+ ->process_units () == nullptr ||
347+ GetMetadataExtractor ()
348+ ->GetInputTensorMetadata (kRegexTokenizerInputTensorIndex )
349+ ->process_units ()
350+ ->Get (kRegexTokenizerProcessUnitIndex ) == nullptr ) {
351+ return false ;
352+ }
353+ return GetMetadataExtractor ()
354+ ->GetInputTensorMetadata (kRegexTokenizerInputTensorIndex )
355+ ->process_units ()
356+ ->Get (kRegexTokenizerProcessUnitIndex )
357+ ->options_type () == ProcessUnitOptions_RegexTokenizerOptions;
358+ }
359+
360+ absl::Status NLClassifier::SetupRegexTokenizer () {
361+ ASSIGN_OR_RETURN (
362+ tokenizer_,
363+ CreateTokenizerFromProcessUnit (
364+ GetMetadataExtractor ()
365+ ->GetInputTensorMetadata (kRegexTokenizerInputTensorIndex )
366+ ->process_units ()
367+ ->Get (kRegexTokenizerProcessUnitIndex ),
368+ GetMetadataExtractor ()));
369+
370+ return absl::OkStatus ();
371+ }
372+
281373} // namespace nlclassifier
282374} // namespace text
283375} // namespace task
0 commit comments