diff --git a/tensorflow_ranking/python/keras/layers.py b/tensorflow_ranking/python/keras/layers.py
index 854c73d..8aa41dd 100644
--- a/tensorflow_ranking/python/keras/layers.py
+++ b/tensorflow_ranking/python/keras/layers.py
@@ -57,7 +57,7 @@ def create_tower(hidden_layer_dims: List[int],
     dropout: When not `None`, the probability we will drop out a given
       coordinate.
     name: Name of the Keras layer.
-    **kwargs: Keyword arguments for every `tf.keras.Dense` layers.
+    **kwargs: Keyword arguments for every `tf.keras.layers.Dense` layer.
 
   Returns:
     A `tf.keras.Sequential` object.
@@ -67,11 +67,11 @@ def create_tower(hidden_layer_dims: List[int],
   if input_batch_norm:
     model.add(tf.keras.layers.BatchNormalization(momentum=batch_norm_moment))
   for layer_width in hidden_layer_dims:
-    model.add(tf.keras.layers.Dense(units=layer_width), **kwargs)
+    model.add(tf.keras.layers.Dense(units=layer_width, **kwargs))
     if use_batch_norm:
       model.add(tf.keras.layers.BatchNormalization(momentum=batch_norm_moment))
     model.add(tf.keras.layers.Activation(activation=activation))
     if dropout:
       model.add(tf.keras.layers.Dropout(rate=dropout))
-  model.add(tf.keras.layers.Dense(units=output_units), **kwargs)
+  model.add(tf.keras.layers.Dense(units=output_units, **kwargs))
   return model
diff --git a/tensorflow_ranking/python/keras/layers_test.py b/tensorflow_ranking/python/keras/layers_test.py
index d66ba14..9383f89 100644
--- a/tensorflow_ranking/python/keras/layers_test.py
+++ b/tensorflow_ranking/python/keras/layers_test.py
@@ -28,6 +28,13 @@ def test_create_tower(self):
     outputs = tower(inputs)
     self.assertAllEqual([2, 3, 1], outputs.get_shape().as_list())
 
+  def test_create_tower_with_bias_kwarg(self):
+    tower = layers.create_tower([3, 2], 1, use_bias=False)
+    # Select Dense layers by type: auto-generated names ('dense_1', ...) come
+    # from a global Keras counter and are not stable across test ordering.
+    tower_layers_bias = [
+        layer.use_bias for layer in tower.layers
+        if isinstance(layer, tf.keras.layers.Dense)
+    ]
+    self.assertAllEqual([False, False, False], tower_layers_bias)
 
 class FlattenListTest(tf.test.TestCase):
 