Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix cyclegan save error #925

Merged
merged 5 commits into from
Jul 5, 2022
Merged

fix cyclegan save error #925

merged 5 commits into from
Jul 5, 2022

Conversation

k2ok3i
Copy link
Contributor

@k2ok3i k2ok3i commented Jun 21, 2022

This is the solution for # 691
This fixes an issue that caused an error at CycleGan's Model Checkpoint.

@google-cla
Copy link

google-cla bot commented Jun 21, 2022

Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA).

View this failed invocation of the CLA check for more information.

For the most up to date status, view the checks section at the bottom of the pull request.

Copy link
Contributor

@fchollet fchollet left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the PR!

@@ -588,6 +591,8 @@ def discriminator_loss_fn(real, fake):
generator_G=gen_G, generator_F=gen_F, discriminator_X=disc_X, discriminator_Y=disc_Y
)

cycle_gan_model.compute_output_shape(input_shape=(None, 256, 256, 3))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why this line?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we remove this line? What happens if we do?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Without this line we get this error.
I followed the error message and tried to specify the input shape using the build method and got the same error.

ValueError Traceback (most recent call last)
in ()
42 tf.data.Dataset.zip((train_horses, train_zebras)),
43 epochs=1,
---> 44 callbacks=[plotter, model_checkpoint_callback],
45 )

1 frames

/usr/local/lib/python3.7/dist-packages/keras/utils/traceback_utils.py in error_handler(*args, **kwargs)
65 except Exception as e: # pylint: disable=broad-except
66 filtered_tb = _process_traceback_frames(e.traceback)
---> 67 raise e.with_traceback(filtered_tb) from None
68 finally:
69 del filtered_tb

/usr/local/lib/python3.7/dist-packages/keras/saving/saving_utils.py in raise_model_input_error(model)
92 # If the model is not a Sequential, it is intended to be a subclassed model.
93 raise ValueError(
---> 94 f'Model {model} cannot be saved either because the input shape is not '
95 'available or because the forward pass of the model is not defined.'
96 'To define a forward pass, please override Model.call(). To specify '

ValueError: Model <main.CycleGan object at 0x7f3e527afb10> cannot be saved either because the input shape is not available or because the forward pass of the model is not defined.To define a forward pass, please override Model.call(). To specify an input shape, either call build(input_shape) directly, or call the model on actual data using Model(), Model.fit(), or Model.predict(). If you have a custom training step, please make sure to invoke the forward pass in train step through Model.__call__, i.e. model(inputs), as opposed to model.call().

The only way to fix it was to add this line.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What version of TF are you using?

You don't need to add this line then. Those who want to save the model can add the line (but this example doesn't save the model and thus doesn't need it) -- and the error message tells them exactly what they need to do.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The versions used were 2.8.2 and 2.9.2.
I have removed this line.

@@ -409,6 +407,9 @@ def __init__(
self.lambda_cycle = lambda_cycle
self.lambda_identity = lambda_identity

def call(self, inputs):
return self.disc_X(inputs), self.disc_Y(inputs), self.gen_G(inputs), self.gen_F(inputs)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please run black on the file to fix the code style.

@k2ok3i
Copy link
Contributor Author

k2ok3i commented Jun 22, 2022

In this problem, the error occurred because the input shape was not given.
Therefore, the input shape was given using the compute_output_shape method.

The code style has been corrected.

Copy link
Contributor

@fchollet fchollet left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, thanks!

@fchollet fchollet merged commit 747d64b into keras-team:master Jul 5, 2022
@emily-stueckelmaier
Copy link

Hi all,
I was trying to run the code without cycle_gan_model.compute_output_shape(input_shape=(None, 256, 256, 3))
Without it I get the same saving error, even if i run it in the latest colab version. But with the line cycle_gan_model.compute_output_shape(input_shape=(None, 256, 256, 3)) the error is fixed

So please add this line to the code for other users :)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants