-
Notifications
You must be signed in to change notification settings - Fork 2.1k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
fix cyclegan save error #925
Conversation
Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA). View this failed invocation of the CLA check for more information. For the most up to date status, view the checks section at the bottom of the pull request. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the PR!
examples/generative/cyclegan.py
Outdated
@@ -588,6 +591,8 @@ def discriminator_loss_fn(real, fake): | |||
generator_G=gen_G, generator_F=gen_F, discriminator_X=disc_X, discriminator_Y=disc_Y | |||
) | |||
|
|||
cycle_gan_model.compute_output_shape(input_shape=(None, 256, 256, 3)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why this line?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can we remove this line? What happens if we do?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Without this line we get this error.
I followed the error message and tried to specify the input shape using the build method and got the same error.
ValueError Traceback (most recent call last)
in ()
42 tf.data.Dataset.zip((train_horses, train_zebras)),
43 epochs=1,
---> 44 callbacks=[plotter, model_checkpoint_callback],
45 )1 frames
/usr/local/lib/python3.7/dist-packages/keras/utils/traceback_utils.py in error_handler(*args, **kwargs)
65 except Exception as e: # pylint: disable=broad-except
66 filtered_tb = _process_traceback_frames(e.traceback)
---> 67 raise e.with_traceback(filtered_tb) from None
68 finally:
69 del filtered_tb/usr/local/lib/python3.7/dist-packages/keras/saving/saving_utils.py in raise_model_input_error(model)
92 # If the model is not aSequential
, it is intended to be a subclassed model.
93 raise ValueError(
---> 94 f'Model {model} cannot be saved either because the input shape is not '
95 'available or because the forward pass of the model is not defined.'
96 'To define a forward pass, please overrideModel.call()
. To specify 'ValueError: Model <main.CycleGan object at 0x7f3e527afb10> cannot be saved either because the input shape is not available or because the forward pass of the model is not defined.To define a forward pass, please override
Model.call()
. To specify an input shape, either callbuild(input_shape)
directly, or call the model on actual data usingModel()
,Model.fit()
, orModel.predict()
. If you have a custom training step, please make sure to invoke the forward pass in train step throughModel.__call__
, i.e.model(inputs)
, as opposed tomodel.call()
.
The only way to fix it was to add this line.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What version of TF are you using?
You don't need to add this line then. Those who want to save the model can add the line (but this example doesn't save the model and thus doesn't need it) -- and the error message tells them exactly what they need to do.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The versions used were 2.8.2 and 2.9.2.
I have removed this line.
examples/generative/cyclegan.py
Outdated
@@ -409,6 +407,9 @@ def __init__( | |||
self.lambda_cycle = lambda_cycle | |||
self.lambda_identity = lambda_identity | |||
|
|||
def call(self, inputs): | |||
return self.disc_X(inputs), self.disc_Y(inputs), self.gen_G(inputs), self.gen_F(inputs) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please run black
on the file to fix the code style.
In this problem, the error occurred because the input shape was not given. The code style has been corrected. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, thanks!
Hi all, So please add this line to the code for other users :) |
This is the solution for # 691
This fixes an issue that caused an error at CycleGan's Model Checkpoint.