Skip to content

Saving model that accepts text as input fails #512

Closed
@Jmkernes

Description

@Jmkernes

Hi,

I've built an NMT model and am trying to save it, but I run into errors. I think the problem has to do with TensorFlow not being able to recognize text input. To reproduce the problem, I've created a tiny program below. It loads a SentencePiece model from a "tokens.model" file, then creates a model that just takes text as input and spits out the tokenization.

%%capture
!pip install tensorflow_text

import tensorflow as tf
import tensorflow_text as tf_text

# Read the serialized SentencePiece model file as raw bytes ('rb' mode).
proto = tf.io.gfile.GFile('tokens.model', 'rb').read()
# Build the tokenizer directly from those serialized bytes.
tokenizer = tf_text.SentencepieceTokenizer(model=proto, nbest_size=1)

class MyModel(tf.keras.Model):
  """Minimal Keras model whose forward pass only tokenizes its input.

  Used here as the smallest reproduction of a model that takes text as input.
  """
  def __init__(self, tokenizer, **kwargs):
    """Store the tokenizer; extra keyword arguments go to tf.keras.Model."""
    super().__init__(**kwargs)
    self.tokenizer = tokenizer
  def call(self, x):
    """Return the result of tokenizer.tokenize(x) unchanged."""
    return self.tokenizer.tokenize(x)

mm = MyModel(tokenizer)
# Both a direct call and predict() are run on a raw Python string before saving.
mm('hello world!')
mm.predict('hello world!')

# This is the call that raises the ValueError reported in this issue.
mm.save('temp')

Running mm = MyModel(tokenizer) and mm.predict('hello world!') works perfectly fine. I mention that because when saving the model I encounter the error:

ValueError: Model <__main__.MyModel object at 0x7f2071b13a20> cannot be saved because the input shapes have not been set. Usually, input shapes are automatically determined from calling .fit() or .predict(). To manually set the shapes, call model.build(input_shape).

I can't figure out how to remedy the issue, as I've clearly built the model and run it. Trying to set up a build method gives me the error that TensorFlow only allows that for float32 inputs.

In my more complicated model, I would generally like the pipeline Input text ---> complicated inner model workings ---> translated text; maybe I'm going about that in a totally incorrect way.

Below is the full error message


ValueError Traceback (most recent call last)
in ()
----> 1 mm.save('temp')

16 frames
/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/training.py in save(self, filepath, overwrite, include_optimizer, save_format, signatures, options, save_traces)
2000 # pylint: enable=line-too-long
2001 save.save_model(self, filepath, overwrite, include_optimizer, save_format,
-> 2002 signatures, options, save_traces)
2003
2004 def save_weights(self,

/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/saving/save.py in save_model(model, filepath, overwrite, include_optimizer, save_format, signatures, options, save_traces)
155 else:
156 saved_model_save.save(model, filepath, overwrite, include_optimizer,
--> 157 signatures, options, save_traces)
158
159

/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/saving/saved_model/save.py in save(model, filepath, overwrite, include_optimizer, signatures, options, save_traces)
87 with distribution_strategy_context._get_default_replica_context(): # pylint: disable=protected-access
88 with utils.keras_option_scope(save_traces):
---> 89 save_lib.save(model, filepath, signatures, options)
90
91 if not include_optimizer:

/usr/local/lib/python3.6/dist-packages/tensorflow/python/saved_model/save.py in save(obj, export_dir, signatures, options)
1031
1032 _, exported_graph, object_saver, asset_info = _build_meta_graph(
-> 1033 obj, signatures, options, meta_graph_def)
1034 saved_model.saved_model_schema_version = constants.SAVED_MODEL_SCHEMA_VERSION
1035

/usr/local/lib/python3.6/dist-packages/tensorflow/python/saved_model/save.py in _build_meta_graph(obj, signatures, options, meta_graph_def)
1196
1197 with save_context.save_context(options):
-> 1198 return _build_meta_graph_impl(obj, signatures, options, meta_graph_def)

/usr/local/lib/python3.6/dist-packages/tensorflow/python/saved_model/save.py in _build_meta_graph_impl(obj, signatures, options, meta_graph_def)
1131 if signatures is None:
1132 signatures = signature_serialization.find_function_to_export(
-> 1133 checkpoint_graph_view)
1134
1135 signatures, wrapped_functions = (

/usr/local/lib/python3.6/dist-packages/tensorflow/python/saved_model/signature_serialization.py in find_function_to_export(saveable_view)
73 # If the user did not specify signatures, check the root object for a function
74 # that can be made into a signature.
---> 75 functions = saveable_view.list_functions(saveable_view.root)
76 signature = functions.get(DEFAULT_SIGNATURE_ATTR, None)
77 if signature is not None:

/usr/local/lib/python3.6/dist-packages/tensorflow/python/saved_model/save.py in list_functions(self, obj, extra_functions)
149 if obj_functions is None:
150 obj_functions = obj._list_functions_for_serialization( # pylint: disable=protected-access
--> 151 self._serialization_cache)
152 self._functions[obj] = obj_functions
153 if extra_functions:

/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/training.py in _list_functions_for_serialization(self, serialization_cache)
2611 self.predict_function = None
2612 functions = super(
-> 2613 Model, self)._list_functions_for_serialization(serialization_cache)
2614 self.train_function = train_function
2615 self.test_function = test_function

/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/base_layer.py in _list_functions_for_serialization(self, serialization_cache)
3085 def _list_functions_for_serialization(self, serialization_cache):
3086 return (self._trackable_saved_model_saver
-> 3087 .list_functions_for_serialization(serialization_cache))
3088
3089 def __getstate__(self):

/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/saving/saved_model/base_serialization.py in list_functions_for_serialization(self, serialization_cache)
92 return {}
93
---> 94 fns = self.functions_to_serialize(serialization_cache)
95
96 # The parent AutoTrackable class saves all user-defined tf.functions, and

/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/saving/saved_model/layer_serialization.py in functions_to_serialize(self, serialization_cache)
77 def functions_to_serialize(self, serialization_cache):
78 return (self._get_serialized_attributes(
---> 79 serialization_cache).functions_to_serialize)
80
81 def _get_serialized_attributes(self, serialization_cache):

/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/saving/saved_model/layer_serialization.py in _get_serialized_attributes(self, serialization_cache)
93
94 object_dict, function_dict = self._get_serialized_attributes_internal(
---> 95 serialization_cache)
96
97 serialized_attr.set_and_validate_objects(object_dict)

/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/saving/saved_model/model_serialization.py in _get_serialized_attributes_internal(self, serialization_cache)
49 # cache (i.e. this is the root level object).
50 if len(serialization_cache[constants.KERAS_CACHE_KEY]) == 1:
---> 51 default_signature = save_impl.default_save_signature(self.obj)
52
53 # Other than the default signature function, all other attributes match with

/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/saving/saved_model/save_impl.py in default_save_signature(layer)
202 def default_save_signature(layer):
203 original_losses = _reset_layer_losses(layer)
--> 204 fn = saving_utils.trace_model_call(layer)
205 fn.get_concrete_function()
206 _restore_layer_losses(original_losses)

/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/saving/saving_utils.py in trace_model_call(model, input_signature)
121
122 if input_signature is None:
--> 123 raise_model_input_error(model)
124
125 # TODO(mdan): Should the model's call be autographed by default?

/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/saving/saving_utils.py in raise_model_input_error(model)
96 'set. Usually, input shapes are automatically determined from calling'
97 ' .fit() or .predict(). To manually set the shapes, call '
---> 98 'model.build(input_shape).'.format(model))
99
100

ValueError: Model <__main__.MyModel object at 0x7f2071b13a20> cannot be saved because the input shapes have not been set. Usually, input shapes are automatically determined from calling .fit() or .predict(). To manually set the shapes, call model.build(input_shape).

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions