Skip to content

Commit

Permalink
fix syncbn model zoo configs and update docs
Browse files Browse the repository at this point in the history
Reviewed By: rbgirshick

Differential Revision: D20007996

fbshipit-source-id: 555636b8b2fe5df45b677164ec7836748c962063
  • Loading branch information
ppwwyyxx authored and facebook-github-bot committed Feb 23, 2020
1 parent c14fcae commit 5db4ea3
Show file tree
Hide file tree
Showing 12 changed files with 69 additions and 34 deletions.
2 changes: 1 addition & 1 deletion LICENSE
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.

Copyright 2019, Facebook, Inc
Copyright 2019 - present, Facebook, Inc

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
Expand Down
16 changes: 8 additions & 8 deletions MODEL_ZOO.md
Original file line number Diff line number Diff line change
Expand Up @@ -792,19 +792,19 @@ Ablations for normalization methods:
<!-- ROW: mask_rcnn_R_50_FPN_3x_syncbn -->
<tr><td align="left"><a href="configs/Misc/mask_rcnn_R_50_FPN_3x_syncbn.yaml">SyncBN</a></td>
<td align="center">3x</td>
<td align="center">0.464</td>
<td align="center">0.063</td>
<td align="center">5.6</td>
<td align="center">42.0</td>
<td align="center">0.412</td>
<td align="center">0.053</td>
<td align="center">5.5</td>
<td align="center">41.9</td>
<td align="center">37.8</td>
<td align="center">143915318</td>
<td align="center"><a href="https://dl.fbaipublicfiles.com/detectron2/Misc/mask_rcnn_R_50_FPN_3x_syncbn/143915318/model_final_220cfb.pkl">model</a>&nbsp;|&nbsp;<a href="https://dl.fbaipublicfiles.com/detectron2/Misc/mask_rcnn_R_50_FPN_3x_syncbn/143915318/metrics.json">metrics</a></td>
<td align="center">169527823</td>
<td align="center"><a href="https://dl.fbaipublicfiles.com/detectron2/Misc/mask_rcnn_R_50_FPN_3x_syncbn/169527823/model_final_3b3c51.pkl">model</a>&nbsp;|&nbsp;<a href="https://dl.fbaipublicfiles.com/detectron2/Misc/mask_rcnn_R_50_FPN_3x_syncbn/169527823/metrics.json">metrics</a></td>
</tr>
<!-- ROW: mask_rcnn_R_50_FPN_3x_gn -->
<tr><td align="left"><a href="configs/Misc/mask_rcnn_R_50_FPN_3x_gn.yaml">GN</a></td>
<td align="center">3x</td>
<td align="center">0.356</td>
<td align="center">0.077</td>
<td align="center">0.070</td>
<td align="center">7.3</td>
<td align="center">42.6</td>
<td align="center">38.6</td>
Expand All @@ -815,7 +815,7 @@ Ablations for normalization methods:
<tr><td align="left"><a href="configs/Misc/scratch_mask_rcnn_R_50_FPN_3x_gn.yaml">GN (scratch)</a></td>
<td align="center">3x</td>
<td align="center">0.400</td>
<td align="center">0.077</td>
<td align="center">0.070</td>
<td align="center">9.8</td>
<td align="center">39.9</td>
<td align="center">36.6</td>
Expand Down
2 changes: 1 addition & 1 deletion configs/Misc/mask_rcnn_R_50_FPN_3x_syncbn.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ MODEL:
RESNETS:
DEPTH: 50
NORM: "SyncBN"
STRIDE_IN_1X1: False
STRIDE_IN_1X1: True
FPN:
NORM: "SyncBN"
ROI_BOX_HEAD:
Expand Down
8 changes: 6 additions & 2 deletions demo/demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,8 +83,12 @@ def get_parser():
start_time = time.time()
predictions, visualized_output = demo.run_on_image(img)
logger.info(
"{}: detected {} instances in {:.2f}s".format(
path, len(predictions["instances"]), time.time() - start_time
"{}: {} in {:.2f}s".format(
path,
"detected {} instances".format(len(predictions["instances"]))
if "instances" in predictions
else "finished",
time.time() - start_time,
)
)

Expand Down
4 changes: 3 additions & 1 deletion detectron2/config/defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,9 @@
_C.MODEL.BACKBONE = CN()

_C.MODEL.BACKBONE.NAME = "build_resnet_backbone"
# Add StopGrad at a specified stage so the bottom layers are frozen
# Freeze the first several stages so they are not trained.
# There are 5 stages in ResNet. The first is a convolution, and the following
# stages are each group of residual blocks.
_C.MODEL.BACKBONE.FREEZE_AT = 2


Expand Down
2 changes: 1 addition & 1 deletion detectron2/data/datasets/coco.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def load_coco_json(json_file, image_root, dataset_name=None, extra_annotation_ke
For example, the densepose annotations are loaded in this way.
Returns:
list[dict]: a list of dicts in Detectron2 standard format. (See
list[dict]: a list of dicts in Detectron2 standard dataset dicts format. (See
`Using Custom Datasets </tutorials/datasets.html>`_ )
Notes:
Expand Down
4 changes: 2 additions & 2 deletions detectron2/engine/defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,7 +337,7 @@ def test_and_save_results():

if comm.is_main_process():
# run writers in the end, so that evaluation metrics are written
ret.append(hooks.PeriodicWriter(self.build_writers()))
ret.append(hooks.PeriodicWriter(self.build_writers(), period=20))
return ret

def build_writers(self):
Expand All @@ -362,7 +362,7 @@ def build_writers(self):
]
"""
# Assume the default print/log frequency.
# Here the default print/log frequency of each writer is used.
return [
# It may not always print what you want to see, since it prints "common" metrics only.
CommonMetricPrinter(self.max_iter),
Expand Down
4 changes: 2 additions & 2 deletions detectron2/model_zoo/model_zoo.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ class _ModelZooUrls(object):
"COCO-Detection/faster_rcnn_R_101_DC5_3x.yaml": "138204841/model_final_3e0943.pkl",
"COCO-Detection/faster_rcnn_R_101_FPN_3x.yaml": "137851257/model_final_f6e8b1.pkl",
"COCO-Detection/faster_rcnn_X_101_32x8d_FPN_3x.yaml": "139173657/model_final_68b088.pkl",
# COCO Detection with Retina-Net
# COCO Detection with RetinaNet
"COCO-Detection/retinanet_R_50_FPN_1x.yaml": "137593951/model_final_b796dc.pkl",
"COCO-Detection/retinanet_R_50_FPN_3x.yaml": "137849486/model_final_4cafe0.pkl",
"COCO-Detection/retinanet_R_101_FPN_3x.yaml": "138363263/model_final_59f53c.pkl",
Expand Down Expand Up @@ -68,7 +68,7 @@ class _ModelZooUrls(object):
"Misc/mask_rcnn_R_50_FPN_3x_dconv_c3-c5.yaml": "144998336/model_final_821d0b.pkl",
"Misc/cascade_mask_rcnn_R_50_FPN_1x.yaml": "138602847/model_final_e9d89b.pkl",
"Misc/cascade_mask_rcnn_R_50_FPN_3x.yaml": "144998488/model_final_480dd8.pkl",
"Misc/mask_rcnn_R_50_FPN_3x_syncbn.yaml": "143915318/model_final_220cfb.pkl",
"Misc/mask_rcnn_R_50_FPN_3x_syncbn.yaml": "169527823/model_final_3b3c51.pkl",
"Misc/mask_rcnn_R_50_FPN_3x_gn.yaml": "138602888/model_final_dc5d9e.pkl",
"Misc/scratch_mask_rcnn_R_50_FPN_3x_gn.yaml": "138602908/model_final_01ca85.pkl",
"Misc/panoptic_fpn_R_101_dconv_cascade_gn_3x.yaml": "139797668/model_final_be35db.pkl",
Expand Down
2 changes: 1 addition & 1 deletion docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@


project = "detectron2"
copyright = "2019, detectron2 contributors"
copyright = "2019-2020, detectron2 contributors"
author = "detectron2 contributors"

# The short X.Y version
Expand Down
44 changes: 33 additions & 11 deletions docs/tutorials/datasets.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,18 +27,23 @@ DatasetCatalog.register("my_dataset", get_dicts)
```

Here, the snippet associates a dataset "my_dataset" with a function that returns the data.
If you do not modify downstream code (i.e., you use the standard data loader and data mapper),
then the function has to return a list of dicts in detectron2's standard dataset format, described
next. You can also use arbitrary custom data format, as long as the
downstream code (mainly the [custom data loader](data_loading.html)) supports it.
The function can processes data from its original format into either one of the following:
1. Detectron2's standard dataset dict, described below. This will work with many other builtin
features in detectron2, so it's recommended to use it when it's sufficient for your task.
2. Your custom dataset dict. You can also returns arbitrary dicts in your own format,
such as adding extra keys for new tasks.
Then you will need to handle them properly in the downstream as well.
See below for more details.

#### Standard Dataset Dicts

For standard tasks
(instance detection, instance/semantic/panoptic segmentation, keypoint detection),
we use a format similar to COCO's json annotations
as the basic dataset representation.
we load the original dataset into `list[dict]` with a specification similar to COCO's json annotations.
This is our standard representation for a dataset.

The format uses one dict to represent the annotations of
one image. The dict may have the following fields.
Each dict contains information about one image.
The dict may have the following fields.
The fields are often optional, and some functions may be able to
infer certain fields from others if needed, e.g., the data loader
will load the image from "file_name" and load "sem_seg" from "sem_seg_file_name".
Expand Down Expand Up @@ -89,10 +94,10 @@ The following keys are used by Fast R-CNN style training, which is rare today.
+ `proposal_bbox_mode` (int): the format of the precomputed proposal bbox.
It must be a member of
[structures.BoxMode](../modules/structures.html#detectron2.structures.BoxMode).
Default format is `BoxMode.XYXY_ABS`.
Default is `BoxMode.XYXY_ABS`.


If your dataset is already in the COCO format, you can simply register it by
If your dataset is already a json file in COCO format, you can simply register it by
```python
from detectron2.data.datasets import register_coco_instances
register_coco_instances("my_dataset", {}, "json_annotation.json", "path/to/image/dir")
Expand All @@ -102,12 +107,29 @@ which will take care of everything (including metadata) for you.
If your dataset is in COCO format with custom per-instance annotations,
the [load_coco_json](../modules/data.html#detectron2.data.datasets.load_coco_json) function can be used.

#### Custom Dataset Dicts

In the `list[dict]` that your dataset function return, the dictionary can also has arbitrary custom data.
This can be useful when you're doing a new task and needs extra information not supported
by the standard dataset dicts. In this case, you need to make sure the downstream code can handle your data
correctly. Usually this requires writing a new `mapper` for the dataloader (see [Use Custom Dataloaders](data_loading.html))

When designing your custom format, note that all dicts are stored in memory
(sometimes serialized and with multiple copies).
To save memory, each dict is meant to contain small but sufficient information
about each sample, such as file names and annotations.
Loading full samples typically happens in the data loader.

For attributes shared among the entire dataset, use `Metadata` (see below).
To avoid exmemory, do not save such information repeatly for each sample.


### "Metadata" for Datasets

Each dataset is associated with some metadata, accessible through
`MetadataCatalog.get(dataset_name).some_metadata`.
Metadata is a key-value mapping that contains primitive information that helps interpret what's in the dataset, e.g.,
Metadata is a key-value mapping that contains information that's shared among
the entire dataset, and usually is used to interpret what's in the dataset, e.g.,
names of classes, colors of classes, root of files, etc.
This information will be useful for augmentation, evaluation, visualization, logging, etc.
The structure of metadata depends on the what is needed from the corresponding downstream code.
Expand Down
1 change: 1 addition & 0 deletions docs/tutorials/models.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ corresponds to information about one image.
The dict may contain the following keys:

* "image": `Tensor` in (C, H, W) format. The meaning of channels are defined by `cfg.INPUT.FORMAT`.
Image normalization, if any, will be performed inside the model.
* "instances": an [Instances](../modules/structures.html#detectron2.structures.Instances)
object, with the following fields:
+ "gt_boxes": a [Boxes](../modules/structures.html#detectron2.structures.Boxes) object storing N boxes, one for each instance.
Expand Down
14 changes: 10 additions & 4 deletions docs/tutorials/training.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ From the previous tutorials, you may now have a custom model and data loader.

You are free to create your own optimizer, and write the training logic: it's
usually easy with PyTorch, and allow researchers to see the entire training
logic more clearly.
logic more clearly and have full control.
One such example is provided in [tools/plain_train_net.py](../../tools/plain_train_net.py).

We also provide a standarized "trainer" abstraction with a
Expand All @@ -13,12 +13,18 @@ that helps simplify the standard types of training.

You can use
[SimpleTrainer().train()](../modules/engine.html#detectron2.engine.SimpleTrainer)
which does single-cost single-optimizer single-data-source training.
Or use [DefaultTrainer().train()](../modules/engine.html#detectron2.engine.defaults.DefaultTrainer)
which includes more standard behavior that one might want to opt in.
which provides minimal abstraction for single-cost single-optimizer single-data-source training.
The builtin `train_net.py` script uses
[DefaultTrainer().train()](../modules/engine.html#detectron2.engine.defaults.DefaultTrainer),
which includes more standard default behavior that one might want to opt in.
This also means that it's less likely to support some non-standard behavior
you might want during research.

To customize the training loops, you can either start
from [tools/plain_train_net.py](../../tools/plain_train_net.py),
or look at the source code of [DefaultTrainer](../../detectron2/engine/defaults.py)
and overwrite some of its behaviors with new parameters or new hooks.


### Logging of Metrics

Expand Down

0 comments on commit 5db4ea3

Please sign in to comment.