[deeplab] Training deeplab model with ADE20K dataset #3730

Open
@walkerlala

System information

  • What is the top-level directory of the model you are using: deeplab
  • Have I written custom code (as opposed to using a stock example script provided in TensorFlow): Yes
  • OS Platform and Distribution (e.g., Linux Ubuntu 16.04): Linux Ubuntu 16.04
  • TensorFlow installed from (source or binary): binary
  • TensorFlow version (use command below): 1.6.0
  • Bazel version (if compiling from source):
  • CUDA/cuDNN version: 9.0/7.0.4
  • GPU model and memory: 1080Ti * 2 , 10Gb * 2
  • Exact command to reproduce:

Describe the problem

This is a feature request. I am trying to train the deeplab model on the ADE20K dataset (see this presentation). I have finished the data format conversion and have "successfully" trained the model on a small subset of ADE20K. Below is my modification to the file research/deeplab/datasets/segmentation_dataset.py, which is used to extract the segmentation data.

diff --git a/research/deeplab/datasets/segmentation_dataset.py b/research/deeplab/datasets/segmentation_dataset.py
index a777252..8648fb2 100644
--- a/research/deeplab/datasets/segmentation_dataset.py
+++ b/research/deeplab/datasets/segmentation_dataset.py
@@ -85,10 +85,22 @@ _PASCAL_VOC_SEG_INFORMATION = DatasetDescriptor(
     ignore_label=255,
 )
 
+_ADE20K_INFORMATION = DatasetDescriptor(
+    splits_to_sizes = {
+        'train': 40,
+        'val': 5,
+    },
+    # TODO temporarily change it to 21 otherwise dimension mismatch
+    num_classes=21,
+    ignore_label=255,
+)
+
 
 _DATASETS_INFORMATION = {
     'cityscapes': _CITYSCAPES_INFORMATION,
     'pascal_voc_seg': _PASCAL_VOC_SEG_INFORMATION,
+    'ade20k': _ADE20K_INFORMATION,
 }
 
 # Default file pattern of TFRecord of TensorFlow Example.

The problem is that the ADE20K dataset has 150 classes, which differs from the number of classes in the PASCAL VOC and Cityscapes datasets. This causes a problem with the checkpoint file: currently there are pretrained models only for the VOC and Cityscapes datasets, so with num_classes=150 the dimensions of the restored model no longer match (hence the TODO in the diff above). So we have two choices here:

  1. Do not use a checkpoint file. In this case, training fails with an error:
absl.flags._exceptions.IllegalFlagValueError: flag --tf_initial_checkpoint=None: Flag --tf_initial_checkpoint must be specified.
  2. Set num_classes=21 so that one of the provided checkpoint files can be used.

Are there any alternatives to these?
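
One alternative I can imagine is to restore only the variables whose shapes match the checkpoint and leave the final classification layer randomly initialized. Below is a minimal sketch of that selection logic in plain Python. It is not deeplab's actual code: the function name `select_restorable`, the variable names, and the shapes are all hypothetical and only illustrate the idea.

```python
def select_restorable(model_shapes, checkpoint_shapes, exclude_prefixes=()):
    """Return the sorted names of variables that can be restored safely.

    model_shapes / checkpoint_shapes map variable name -> shape tuple.
    A variable is kept only if it exists in the checkpoint with exactly
    the same shape and its name does not start with an excluded prefix.
    """
    restorable = []
    for name, shape in model_shapes.items():
        if any(name.startswith(prefix) for prefix in exclude_prefixes):
            continue  # explicitly excluded, e.g. the last layer
        if checkpoint_shapes.get(name) == shape:
            restorable.append(name)
    return sorted(restorable)


# Hypothetical variable names and shapes, for illustration only.
model_shapes = {
    'backbone/conv1/weights': (3, 3, 3, 32),
    'decoder/conv/weights': (3, 3, 304, 256),
    'logits/weights': (1, 1, 256, 150),  # 150 ADE20K classes
    'logits/biases': (150,),
}
checkpoint_shapes = {
    'backbone/conv1/weights': (3, 3, 3, 32),
    'decoder/conv/weights': (3, 3, 304, 256),
    'logits/weights': (1, 1, 256, 21),   # 21 classes in the VOC checkpoint
    'logits/biases': (21,),
}

# The logits variables are skipped because their shapes differ.
print(select_restorable(model_shapes, checkpoint_shapes))
```

In TensorFlow 1.x, the resulting names could presumably be used to build the `var_list` passed to `tf.train.Saver`, so that only the matching variables are loaded from `--tf_initial_checkpoint`. I have not verified this against the deeplab training script.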

If anyone has a workable solution for the ADE20K dataset, it would be really appreciated.
