Skip to content

Commit

Permalink
Update RetinaNet to use better LR schedule and PredictionDecoder conf…
Browse files Browse the repository at this point in the history
…iguration (keras-team#1211)

* RetinaNet

* RetinaNet update guide

* RetinaNet update guide

* Update guides

* Update guides

* Update guides

* Update guides

* Update guides

* README update
  • Loading branch information
LukeWood authored Jan 26, 2023
1 parent 3037412 commit 3080d7f
Show file tree
Hide file tree
Showing 7 changed files with 52 additions and 181 deletions.
Binary file modified guides/img/retina_net_overview/retina_net_overview_13_2.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified guides/img/retina_net_overview/retina_net_overview_28_1.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified guides/img/retina_net_overview/retina_net_overview_30_1.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified guides/img/retina_net_overview/retina_net_overview_8_0.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
23 changes: 16 additions & 7 deletions guides/ipynb/keras_cv/retina_net_overview.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@
"BATCH_SIZE = 16\n",
"EPOCHS = int(os.getenv(\"EPOCHS\", \"1\"))\n",
"# To fully train a RetinaNet, comment out this line.\n",
"# EPOCHS = 50\n",
"# EPOCHS = 100\n",
"CHECKPOINT_PATH = os.getenv(\"CHECKPOINT_PATH\", \"checkpoint/\")\n",
"INFERENCE_CHECKPOINT_PATH = os.getenv(\"INFERENCE_CHECKPOINT_PATH\", CHECKPOINT_PATH)\n",
"\n",
Expand Down Expand Up @@ -441,7 +441,6 @@
"source": [
"callbacks = [\n",
" keras.callbacks.TensorBoard(log_dir=\"logs\"),\n",
" keras.callbacks.ReduceLROnPlateau(patience=5),\n",
" keras.callbacks.ModelCheckpoint(CHECKPOINT_PATH, save_weights_only=True),\n",
"]\n",
""
Expand Down Expand Up @@ -470,7 +469,15 @@
"outputs": [],
"source": [
"# including a global_clipnorm is extremely important in object detection tasks\n",
"optimizer = tf.optimizers.SGD(global_clipnorm=10.0)\n",
"base_lr = 0.01\n",
"lr_decay = tf.keras.optimizers.schedules.PiecewiseConstantDecay(\n",
" boundaries=[12000 * 16, 16000 * 16],\n",
" values=[base_lr, 0.1 * base_lr, 0.01 * base_lr],\n",
")\n",
"\n",
"optimizer = tf.keras.optimizers.SGD(\n",
" learning_rate=lr_decay, momentum=0.9, global_clipnorm=10.0\n",
")\n",
"model.compile(\n",
" classification_loss=\"focal\",\n",
" box_loss=\"smoothl1\",\n",
Expand Down Expand Up @@ -529,7 +536,7 @@
"\n",
"\n",
"def visualize_detections(model, bounding_box_format):\n",
" images, y_true = next(iter(train_ds.take(1)))\n",
" images, y_true = next(iter(eval_ds.take(1)))\n",
" y_pred = model.predict(images)\n",
" y_pred = bounding_box.to_ragged(y_pred)\n",
" visualization.plot_bounding_box_gallery(\n",
Expand Down Expand Up @@ -573,8 +580,10 @@
"prediction_decoder = keras_cv.layers.MultiClassNonMaxSuppression(\n",
" bounding_box_format=\"xywh\",\n",
" from_logits=True,\n",
" iou_threshold=0.75,\n",
" confidence_threshold=0.85,\n",
" # Decrease the required threshold to make predictions get pruned out\n",
" iou_threshold=0.35,\n",
" # Tune confidence threshold for predictions to pass NMS\n",
" confidence_threshold=0.95,\n",
")\n",
"model.prediction_decoder = prediction_decoder\n",
"visualize_detections(model, bounding_box_format=\"xywh\")"
Expand All @@ -596,7 +605,7 @@
"Some follow up exercises for the reader:\n",
"\n",
"- add additional augmentation techniques to improve model performance\n",
"- grid search `confidence_threshold` and `iou_threshold` on `NmsPredictionDecoder` to\n",
"- grid search `confidence_threshold` and `iou_threshold` on `MultiClassNonMaxSuppression` to\n",
" achieve an optimal Mean Average Precision\n",
"- tune the hyperparameters and data augmentation used to produce high quality results\n",
"- train an object detection model on another dataset"
Expand Down
23 changes: 16 additions & 7 deletions guides/keras_cv/retina_net_overview.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
BATCH_SIZE = 16
EPOCHS = int(os.getenv("EPOCHS", "1"))
# To fully train a RetinaNet, uncomment the following line.
# EPOCHS = 50
# EPOCHS = 100
CHECKPOINT_PATH = os.getenv("CHECKPOINT_PATH", "checkpoint/")
INFERENCE_CHECKPOINT_PATH = os.getenv("INFERENCE_CHECKPOINT_PATH", CHECKPOINT_PATH)

Expand Down Expand Up @@ -284,7 +284,6 @@ def dict_to_tuple(inputs):

callbacks = [
keras.callbacks.TensorBoard(log_dir="logs"),
keras.callbacks.ReduceLROnPlateau(patience=5),
keras.callbacks.ModelCheckpoint(CHECKPOINT_PATH, save_weights_only=True),
]

Expand All @@ -299,7 +298,15 @@ def dict_to_tuple(inputs):
"""

# including a global_clipnorm is extremely important in object detection tasks
optimizer = tf.optimizers.SGD(global_clipnorm=10.0)
base_lr = 0.01
lr_decay = tf.keras.optimizers.schedules.PiecewiseConstantDecay(
boundaries=[12000 * 16, 16000 * 16],
values=[base_lr, 0.1 * base_lr, 0.01 * base_lr],
)

optimizer = tf.keras.optimizers.SGD(
learning_rate=lr_decay, momentum=0.9, global_clipnorm=10.0
)
model.compile(
classification_loss="focal",
box_loss="smoothl1",
Expand Down Expand Up @@ -330,7 +337,7 @@ def dict_to_tuple(inputs):


def visualize_detections(model, bounding_box_format):
images, y_true = next(iter(train_ds.take(1)))
images, y_true = next(iter(eval_ds.take(1)))
y_pred = model.predict(images)
y_pred = bounding_box.to_ragged(y_pred)
visualization.plot_bounding_box_gallery(
Expand Down Expand Up @@ -360,8 +367,10 @@ def visualize_detections(model, bounding_box_format):
prediction_decoder = keras_cv.layers.MultiClassNonMaxSuppression(
bounding_box_format="xywh",
from_logits=True,
iou_threshold=0.75,
confidence_threshold=0.85,
# Lower the IoU threshold so that more overlapping predictions get pruned out
iou_threshold=0.35,
# Tune the confidence threshold that predictions must exceed to pass NMS
confidence_threshold=0.95,
)
model.prediction_decoder = prediction_decoder
visualize_detections(model, bounding_box_format="xywh")
Expand All @@ -376,7 +385,7 @@ def visualize_detections(model, bounding_box_format):
Some follow-up exercises for the reader:
- add additional augmentation techniques to improve model performance
- grid search `confidence_threshold` and `iou_threshold` on `NmsPredictionDecoder` to
- grid search `confidence_threshold` and `iou_threshold` on `MultiClassNonMaxSuppression` to
achieve an optimal Mean Average Precision
- tune the hyperparameters and data augmentation used to produce high quality results
- train an object detection model on another dataset
Expand Down
Loading

0 comments on commit 3080d7f

Please sign in to comment.