
add filter response normalization #765

Merged
19 commits merged into tensorflow:master on Mar 29, 2020

Conversation

AakashKumarNain
Member

No description provided.

@AakashKumarNain
Member Author

@seanpmorgan @WindQAQ @facaiy I have added the latest contribution for normalization from the Google Brain team. Although I believe the implementation is fully correct, before I go on and write unit tests for it, can you please take a look at the implementation and let me know if it looks okay? Thank you.
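For context, the paper's Filter Response Normalization divides each channel by the root mean squared activation over the spatial dimensions, with no mean subtraction and no batch statistics. A minimal NumPy sketch of the math (the helper name `frn_reference` and single `(H, W, C)` input are my assumptions, not the PR's code):

```python
import numpy as np

def frn_reference(x, beta=0.0, gamma=1.0, epsilon=1e-6):
    """Reference FRN for a single (H, W, C) input.

    nu2 is the mean squared activation per channel, computed over the
    spatial axes only; there is no mean subtraction and no batch
    dependence, which is what distinguishes FRN from batch norm.
    """
    nu2 = np.mean(np.square(x), axis=(0, 1), keepdims=True)
    return gamma * x / np.sqrt(nu2 + epsilon) + beta
```

With `gamma=1` and `beta=0`, the output's per-channel mean square is approximately 1, which is a handy invariant to test against.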

Member

@Squadrick Squadrick left a comment

Looks good, just a clarification regarding epsilon.

4 review comments on tensorflow_addons/layers/normalizations.py (outdated, resolved)
@AakashKumarNain
Member Author

@seanpmorgan @WindQAQ @Squadrick @saurabhme I have opened a PR for TLU. Please check #857


@gabrieldemarmiesse
Member

@AakashKumarNain I've merged master into your branch to update it and fixed any formatting/conflicts it might have. If you need to do some more modifications, please do git pull beforehand.

@AakashKumarNain
Member Author

@gabrieldemarmiesse this will come with a lot of changes. Will do in the coming weeks. It has been long overdue

@AakashKumarNain
Member Author

AakashKumarNain commented Mar 10, 2020

@gabrieldemarmiesse @seanpmorgan the tests work fine in eager mode but are failing in graph mode. Let me know if you need any other info. Thanks

inputs = np.random.rand(28, 28, 1).astype(np.float32)
frn = FilterResponseNormalization(
    beta_initializer="zeros", gamma_initializer="ones"
)
frn.build((28, 28, 1))
observed = frn(inputs)
expected = self.calculate_frn(inputs, beta=0, gamma=1)
self.assertAllClose(expected, observed)  # this is where the test fails

@Squadrick
Member

@AakashKumarNain I know the problem. You aren't initializing the variables; you need to add self.evaluate(tf.compat.v1.global_variables_initializer()) after frn.build(...). Also, break test_random_input() into multiple smaller functions, one for each set of instance params of the FRN layer.

Here's a stripped-down version of test_random_inputs() that runs without any issues:

    def test_random_inputs(self):
        inputs = np.random.rand(28, 28, 1).astype(np.float32)
        frn = FilterResponseNormalization(
            beta_initializer="zeros", gamma_initializer="ones"
        )
        frn.build((28, 28, 1))
        self.evaluate(tf.compat.v1.global_variables_initializer())  # <-- missing
        observed = frn(inputs)
        expected = self.calculate_frn(inputs, beta=0, gamma=1)
        self.assertAllClose(expected, observed[0])

The output log of the successful run:

================================================================================
Target //tensorflow_addons/layers:normalizations_test up-to-date:
  bazel-bin/tensorflow_addons/layers/normalizations_test
INFO: Elapsed time: 33.660s, Critical Path: 33.31s
INFO: 2 processes: 2 local.
INFO: Build completed successfully, 2 total actions
//tensorflow_addons/layers:normalizations_test                           PASSED in 33.3s

INFO: Build completed successfully, 2 total actions

Also, prefer not to use randomized inputs, for better reproducibility. If you do want random input, at least use a fixed seed for both np and tf.
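To make the seeding point concrete: drawing test inputs from a seeded generator means every run sees identical data (tf.random.set_seed is the TF-side counterpart). A small sketch, where `make_inputs` is a hypothetical helper, not code from this PR:

```python
import numpy as np

def make_inputs(seed=42, shape=(28, 28, 1)):
    # Hypothetical helper: using a dedicated seeded RandomState means
    # every test run draws the exact same "random" input tensor,
    # without touching the global NumPy RNG state.
    rng = np.random.RandomState(seed)
    return rng.rand(*shape).astype(np.float32)
```

Two calls with the same seed return bitwise-identical arrays, so a test failure reproduces deterministically.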

@AakashKumarNain
Member Author

@Squadrick yeah 🤦‍♂ (with sessions gone, I always forget that part now). And thanks for pointing out the seed part, will update it. Thanks a lot.

@AakashKumarNain
Member Author

@seanpmorgan @Squadrick @gabrieldemarmiesse I think we can merge this now. If you think any changes are required, please let me know. Thanks.

Member

@Squadrick Squadrick left a comment

@AakashKumarNain A few changes.

7 review comments on tensorflow_addons/layers/normalizations.py (resolved)
2 review comments on tensorflow_addons/layers/normalizations_test.py (resolved)
1 review comment on tensorflow_addons/layers/__init__.py (resolved)
Member

@Squadrick Squadrick left a comment

After going through the paper, I realized that the authors also tested this for FC layers. Right now, your implementation works only for 4D inputs, and the docstring says so. However, we have no asserts or checks, which will lead to obscure shape/size errors from TensorFlow instead of a nicer error message.

If you decide to scope this PR to only 4D tensors:

  1. Check on input_shape to ensure it is 4D only.
  2. axis should be parameterized; ignore my previous comments. But it should only accept a list of size 2. Remove the check for [1, 2], since [-1, -2] or [2, 3] is perfectly acceptable when we set channels_first in the preceding conv layer.
  3. Add a TODO and create a new issue as "contributions welcome" for adding support for a generalized N-D input, similar to batch norm. (@saurabhme thoughts on this?)
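As a sketch of what checks 1 and 2 might look like, assuming a hypothetical standalone helper (not the merged implementation, which lives inside the layer's build()):

```python
def validate_frn_args(input_shape, axis):
    # Hypothetical validation sketch: restrict to 4D inputs and
    # require exactly two normalization axes, so shape mistakes fail
    # with a clear message instead of an obscure TF shape error.
    if len(input_shape) != 4:
        raise ValueError(
            "FilterResponseNormalization only supports 4D inputs, "
            "got input shape %s" % (tuple(input_shape),)
        )
    if not isinstance(axis, (list, tuple)) or len(axis) != 2:
        raise ValueError(
            "axis must be a list of two ints, e.g. [1, 2] for NHWC "
            "or [-1, -2] / [2, 3] for channels_first, got %r" % (axis,)
        )
```

Note that [-1, -2] and [2, 3] both pass here by design; only the arity and rank are enforced.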

3 review comments on tensorflow_addons/layers/normalizations.py (resolved)
@AakashKumarNain
Member Author

@gabrieldemarmiesse @Squadrick I have made all the changes. Also, I am using only pytest for all tests now (much better). Let me know if any other changes are required

2 review comments on tensorflow_addons/layers/normalizations.py (resolved)
Member

@Squadrick Squadrick left a comment

Thanks for the changes. A few more changes: some that weren't fixed, and some that I overlooked (sorry).

3 review comments on tensorflow_addons/layers/normalizations.py (resolved)
@Squadrick Squadrick self-assigned this Mar 26, 2020
@Squadrick
Member

Created #1441 for adding more tensor shape support to this layer.

Member

@Squadrick Squadrick left a comment

LGTM, thanks @AakashKumarNain

@Squadrick Squadrick dismissed gabrieldemarmiesse’s stale review March 29, 2020 17:17

The requested change will be done in #1441.

@Squadrick Squadrick merged commit aa0b8f5 into tensorflow:master Mar 29, 2020
@Squadrick
Member

Unblocks #1441

@AakashKumarNain
Member Author

Thanks @Squadrick . Will work on #1441 now.

@Squadrick
Member

@AakashKumarNain, I think someone else has started working on it already. You can check the comments in #1441.

jrruijli pushed a commit to jrruijli/addons that referenced this pull request Dec 23, 2020
* Add FRN layer (only 4D NHWC tensor support)

Co-authored-by: gabrieldemarmiesse <gabrieldemarmiesse@gmail.com>