Skip to content

Commit

Permalink
Replace keep_dims with keepdims in TF calls.
Browse files Browse the repository at this point in the history
TF replaced keep_dims with keepdims a while ago
and now shows a warning when using the old name.
  • Loading branch information
waleedka committed Apr 21, 2018
1 parent 9cea282 commit 4129a27
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 4 deletions.
7 changes: 5 additions & 2 deletions mrcnn/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -1082,6 +1082,9 @@ def mrcnn_class_loss_graph(target_class_ids, pred_class_logits,
classes that are in the dataset of the image, and 0
for classes that are not in the dataset.
"""
# During model building, Keras calls this function with
# target_class_ids of type float32. Unclear why. Cast it
# to int to get around it.
target_class_ids = tf.cast(target_class_ids, 'int64')

# Find predictions of classes that are not in the dataset.
Expand Down Expand Up @@ -2137,7 +2140,7 @@ def compile(self, learning_rate, momentum):
if layer.output in self.keras_model.losses:
continue
loss = (
tf.reduce_mean(layer.output, keep_dims=True)
tf.reduce_mean(layer.output, keepdims=True)
* self.config.LOSS_WEIGHTS.get(name, 1.))
self.keras_model.add_loss(loss)

Expand All @@ -2161,7 +2164,7 @@ def compile(self, learning_rate, momentum):
layer = self.keras_model.get_layer(name)
self.keras_model.metrics_names.append(name)
loss = (
tf.reduce_mean(layer.output, keep_dims=True)
tf.reduce_mean(layer.output, keepdims=True)
* self.config.LOSS_WEIGHTS.get(name, 1.))
self.keras_model.metrics_tensors.append(loss)

Expand Down
4 changes: 2 additions & 2 deletions samples/coco/coco.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@
# Train a new model starting from pre-trained COCO weights
python3 coco.py train --dataset=/path/to/coco/ --model=coco
# Train a new model starting from ImageNet weights
python3 coco.py train --dataset=/path/to/coco/ --model=imagenet
# Train a new model starting from ImageNet weights. Also auto download COCO dataset
python3 coco.py train --dataset=/path/to/coco/ --model=imagenet --download=True
# Continue training a model that you had trained earlier
python3 coco.py train --dataset=/path/to/coco/ --model=/path/to/weights.h5
Expand Down

2 comments on commit 4129a27

@JonathanCMitchell
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This breaks training for tensorflow versions <= 1.4.0

@waleedka
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh! Thanks for letting me know. Since TF is at 1.8 now, I'm leaning towards keeping this change and modifying the version check in model.py to make it explicit that this code requires 1.5+.

Please sign in to comment.