Description
Hi,
I've built an NMT model and am trying to save it, but run into errors. I think the problem has to do with tensorflow not being able to recognize text input. To reproduce the problem, I've created a tiny program below. It loads a SentencePiece from a "tokens.model" file then creates a model that just takes text as input, and spits out the tokenization.
%%capture
!pip install tensorflow_text
import tensorflow as tf
import tensorflow_text as tf_text
proto = tf.io.gfile.GFile('tokens.model', 'rb').read()
tokenizer = tf_text.SentencepieceTokenizer(model=proto, nbest_size=1)
class MyModel(tf.keras.Model):
def __init__(self, tokenizer, **kwargs):
super().__init__(**kwargs)
self.tokenizer = tokenizer
def call(self, x):
return self.tokenizer.tokenize(x)
mm = MyModel(tokenizer)
mm('hello world!')
mm.predict('hello world!')
mm.save('temp')
Running mm = MyModel(tokenizer)
and mm.predict('hello world!')
work perfectly fine. I mention that because when saving the model I encounter the error:
ValueError: Model <main.MyModel object at 0x7f2071b13a20> cannot be saved because the input shapes have not been set. Usually, input shapes are automatically determined from calling .fit()
or .predict()
. To manually set the shapes, call model.build(input_shape)
.
I can't figure out how to remedy the issue, as I've clearly built the model and ran it. Trying to setup a build method gives me the error that tensorflow only allows that for float32 inputs.
In my more complicated model, I would generally like the pipeline Input text ---> complicated inner model workings ---> translated text; maybe I'm going about that a totally incorrect way
Below is the full error message
ValueError Traceback (most recent call last)
in ()
----> 1 mm.save('temp')
16 frames
/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/training.py in save(self, filepath, overwrite, include_optimizer, save_format, signatures, options, save_traces)
2000 # pylint: enable=line-too-long
2001 save.save_model(self, filepath, overwrite, include_optimizer, save_format,
-> 2002 signatures, options, save_traces)
2003
2004 def save_weights(self,
/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/saving/save.py in save_model(model, filepath, overwrite, include_optimizer, save_format, signatures, options, save_traces)
155 else:
156 saved_model_save.save(model, filepath, overwrite, include_optimizer,
--> 157 signatures, options, save_traces)
158
159
/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/saving/saved_model/save.py in save(model, filepath, overwrite, include_optimizer, signatures, options, save_traces)
87 with distribution_strategy_context._get_default_replica_context(): # pylint: disable=protected-access
88 with utils.keras_option_scope(save_traces):
---> 89 save_lib.save(model, filepath, signatures, options)
90
91 if not include_optimizer:
/usr/local/lib/python3.6/dist-packages/tensorflow/python/saved_model/save.py in save(obj, export_dir, signatures, options)
1031
1032 _, exported_graph, object_saver, asset_info = _build_meta_graph(
-> 1033 obj, signatures, options, meta_graph_def)
1034 saved_model.saved_model_schema_version = constants.SAVED_MODEL_SCHEMA_VERSION
1035
/usr/local/lib/python3.6/dist-packages/tensorflow/python/saved_model/save.py in _build_meta_graph(obj, signatures, options, meta_graph_def)
1196
1197 with save_context.save_context(options):
-> 1198 return _build_meta_graph_impl(obj, signatures, options, meta_graph_def)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/saved_model/save.py in _build_meta_graph_impl(obj, signatures, options, meta_graph_def)
1131 if signatures is None:
1132 signatures = signature_serialization.find_function_to_export(
-> 1133 checkpoint_graph_view)
1134
1135 signatures, wrapped_functions = (
/usr/local/lib/python3.6/dist-packages/tensorflow/python/saved_model/signature_serialization.py in find_function_to_export(saveable_view)
73 # If the user did not specify signatures, check the root object for a function
74 # that can be made into a signature.
---> 75 functions = saveable_view.list_functions(saveable_view.root)
76 signature = functions.get(DEFAULT_SIGNATURE_ATTR, None)
77 if signature is not None:
/usr/local/lib/python3.6/dist-packages/tensorflow/python/saved_model/save.py in list_functions(self, obj, extra_functions)
149 if obj_functions is None:
150 obj_functions = obj._list_functions_for_serialization( # pylint: disable=protected-access
--> 151 self._serialization_cache)
152 self._functions[obj] = obj_functions
153 if extra_functions:
/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/training.py in _list_functions_for_serialization(self, serialization_cache)
2611 self.predict_function = None
2612 functions = super(
-> 2613 Model, self)._list_functions_for_serialization(serialization_cache)
2614 self.train_function = train_function
2615 self.test_function = test_function
/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/base_layer.py in _list_functions_for_serialization(self, serialization_cache)
3085 def _list_functions_for_serialization(self, serialization_cache):
3086 return (self._trackable_saved_model_saver
-> 3087 .list_functions_for_serialization(serialization_cache))
3088
3089 def getstate(self):
/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/saving/saved_model/base_serialization.py in list_functions_for_serialization(self, serialization_cache)
92 return {}
93
---> 94 fns = self.functions_to_serialize(serialization_cache)
95
96 # The parent AutoTrackable class saves all user-defined tf.functions, and
/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/saving/saved_model/layer_serialization.py in functions_to_serialize(self, serialization_cache)
77 def functions_to_serialize(self, serialization_cache):
78 return (self._get_serialized_attributes(
---> 79 serialization_cache).functions_to_serialize)
80
81 def _get_serialized_attributes(self, serialization_cache):
/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/saving/saved_model/layer_serialization.py in _get_serialized_attributes(self, serialization_cache)
93
94 object_dict, function_dict = self._get_serialized_attributes_internal(
---> 95 serialization_cache)
96
97 serialized_attr.set_and_validate_objects(object_dict)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/saving/saved_model/model_serialization.py in _get_serialized_attributes_internal(self, serialization_cache)
49 # cache (i.e. this is the root level object).
50 if len(serialization_cache[constants.KERAS_CACHE_KEY]) == 1:
---> 51 default_signature = save_impl.default_save_signature(self.obj)
52
53 # Other than the default signature function, all other attributes match with
/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/saving/saved_model/save_impl.py in default_save_signature(layer)
202 def default_save_signature(layer):
203 original_losses = _reset_layer_losses(layer)
--> 204 fn = saving_utils.trace_model_call(layer)
205 fn.get_concrete_function()
206 _restore_layer_losses(original_losses)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/saving/saving_utils.py in trace_model_call(model, input_signature)
121
122 if input_signature is None:
--> 123 raise_model_input_error(model)
124
125 # TODO(mdan): Should the model's call be autographed by default?
/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/saving/saving_utils.py in raise_model_input_error(model)
96 'set. Usually, input shapes are automatically determined from calling'
97 ' .fit()
or .predict()
. To manually set the shapes, call '
---> 98 'model.build(input_shape)
.'.format(model))
99
100
ValueError: Model <main.MyModel object at 0x7f2071b13a20> cannot be saved because the input shapes have not been set. Usually, input shapes are automatically determined from calling .fit()
or .predict()
. To manually set the shapes, call model.build(input_shape)
.