Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix passing of keyword args to Dense layers in create_tower #339

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Fix passing of keyword args to Dense layers in create_tower
Current behavior: kwargs are passed to tf.keras.Sequential.add, so they
are not passed on to tf.keras.layers.Dense as intended. For example,
when passing `use_bias=False` to create_tower with the kwarg name
`kernel_regularizer`, it throws an exception:

Traceback (most recent call last):
  File "/Users/brussell/development/ranking/tensorflow_ranking/python/keras/layers_test.py", line 33, in test_create_tower_with_kwargs
    tower = layers.create_tower([3, 2, 1], 1, activation='relu', use_bias=False)
  File "/Users/brussell/development/ranking/tensorflow_ranking/python/keras/layers.py", line 70, in create_tower
    model.add(tf.keras.layers.Dense(units=layer_width), **kwargs)
  File "/usr/local/anaconda3/lib/python3.9/site-packages/tensorflow/python/trackable/base.py", line 205, in _method_wrapper
    result = method(self, *args, **kwargs)
  File "/usr/local/anaconda3/lib/python3.9/site-packages/keras/utils/traceback_utils.py", line 61, in error_handler
    return fn(*args, **kwargs)
TypeError: add() got an unexpected keyword argument 'use_bias'
test_create_tower_with_kwargs

Fix: This PR fixes the behavior by shifting the closing paren of
tf.keras.layers.Dense to the correct location.
  • Loading branch information
b4russell committed Dec 8, 2022
commit 1400f70d3449ae1969ec7284fbe512dbddbaf865
6 changes: 3 additions & 3 deletions tensorflow_ranking/python/keras/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def create_tower(hidden_layer_dims: List[int],
dropout: When not `None`, the probability we will drop out a given
coordinate.
name: Name of the Keras layer.
**kwargs: Keyword arguments for every `tf.keras.Dense` layers.
**kwargs: Keyword arguments for every `tf.keras.layers.Dense` layer.

Returns:
A `tf.keras.Sequential` object.
Expand All @@ -67,13 +67,13 @@ def create_tower(hidden_layer_dims: List[int],
if input_batch_norm:
model.add(tf.keras.layers.BatchNormalization(momentum=batch_norm_moment))
for layer_width in hidden_layer_dims:
model.add(tf.keras.layers.Dense(units=layer_width), **kwargs)
model.add(tf.keras.layers.Dense(units=layer_width, **kwargs))
if use_batch_norm:
model.add(tf.keras.layers.BatchNormalization(momentum=batch_norm_moment))
model.add(tf.keras.layers.Activation(activation=activation))
if dropout:
model.add(tf.keras.layers.Dropout(rate=dropout))
model.add(tf.keras.layers.Dense(units=output_units), **kwargs)
model.add(tf.keras.layers.Dense(units=output_units, **kwargs))
return model


Expand Down
4 changes: 4 additions & 0 deletions tensorflow_ranking/python/keras/layers_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,10 @@ def test_create_tower(self):
outputs = tower(inputs)
self.assertAllEqual([2, 3, 1], outputs.get_shape().as_list())

def test_create_tower_with_bias_kwarg(self):
tower = layers.create_tower([3, 2], 1, use_bias=False)
tower_layers_bias = [tower.get_layer(name).use_bias for name in ['dense_1', 'dense_2']]
self.assertAllEqual([False, False], tower_layers_bias)

class FlattenListTest(tf.test.TestCase):

Expand Down