[Semantic Segmentation] Add first segmentation weights (credit @tanzhenyu) #1059

Merged · 5 commits · Nov 29, 2022
@@ -0,0 +1,22 @@
{
    "deeplabv3": {
        "v0": {
            "accelerators": 4,
            "args": {},
            "contributor": "tanzhenyu",
            "epochs_trained": 99,
            "script": {
                "name": "deeplab_v3.py",
                "version": "6a518c900b6533939e80e027d38e741a9d01ff48"
            },
            "tensorboard_logs": "https://tensorboard.dev/experiment/Wh9RZvNNRMeLjFyqObrsag/",
            "validation_accuracy": "0.9141",
            "validation_mean_iou": "0.6863"
        }
    },
    "script_authors": {
        "deeplab_v3.py": [
            "tanzhenyu"
        ]
    }
}
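For readers unfamiliar with these history files: each model name maps version tags (here "v0") to the training metadata gathered by shell/weights/update_training_history.py, which is diffed further down. A quick illustrative read-back, using a hypothetical local filename:

import json

# Hypothetical local copy of the metrics file shown above.
with open("training_history.json") as f:
    history = json.load(f)

entry = history["deeplabv3"]["v0"]
# Note that the metrics are stored as formatted strings, not floats.
print(entry["validation_mean_iou"])  # "0.6863"
print(entry["script"]["version"])    # git SHA of the training script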
12 changes: 10 additions & 2 deletions keras_cv/models/segmentation/deeplab.py
@@ -42,6 +42,14 @@ class DeepLabV3(tf.keras.models.Model):
instance. The supported pre-defined backbone models are:
1. "resnet50_v2", a ResNet50 V2 model
Defaults to 'resnet50_v2'.
+backbone_weights: weights for the backbone model. One of `None` (random
+    initialization), a pretrained weight file path, or a reference to
+    pre-trained weights (e.g. 'imagenet/classification'); see the available
+    pre-trained weights in weights.py.
+weights: weights for the complete DeepLabV3 model. One of `None` (random
+    initialization), a pretrained weight file path, or a reference to
+    pre-trained weights (e.g. 'voc'); see the available pre-trained
+    weights in weights.py.
decoder: an optional decoder network for the segmentation model, e.g. FPN. The
supported premade decoder is "fpn". The decoder is called on
the output of the backbone network to up-sample the feature output.
@@ -56,7 +64,7 @@ def __init__(
classes,
include_rescaling,
backbone,
-    weights,
+    backbone_weights=None,
spatial_pyramid_pooling=None,
segmentation_head=None,
**kwargs,
@@ -90,7 +98,7 @@ def __init__(
include_rescaling=include_rescaling,
include_top=False,
name="resnet50v2",
-    weights=parse_weights(weights, False, "resnet50v2"),
+    weights=parse_weights(backbone_weights, False, "resnet50v2"),
pooling=None,
**kwargs,
)
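For orientation, a minimal usage sketch of the resulting API (not part of the diff): `backbone_weights` now controls only the backbone's initialization, while the rest of the model stays randomly initialized. The import path mirrors the tests below; `classes=21` (Pascal VOC) and the "imagenet" backbone alias are illustrative assumptions.

import tensorflow as tf

from keras_cv.models import segmentation

# Fully random initialization, as in the unit tests below.
model = segmentation.DeepLabV3(
    classes=21, include_rescaling=True, backbone="resnet50_v2"
)

# Same model, but with a pretrained ResNet50V2 backbone; the "imagenet"
# alias is resolved by parse_weights in weights.py.
pretrained_backbone_model = segmentation.DeepLabV3(
    classes=21,
    include_rescaling=True,
    backbone="resnet50_v2",
    backbone_weights="imagenet",
)

output = model(tf.random.uniform(shape=[2, 256, 256, 3]), training=False)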
10 changes: 4 additions & 6 deletions keras_cv/models/segmentation/deeplab_test.py
@@ -25,7 +25,7 @@
class DeeplabTest(tf.test.TestCase):
def test_deeplab_model_construction_with_preconfigured_setting(self):
model = segmentation.DeepLabV3(
classes=11, include_rescaling=True, backbone="resnet50_v2", weights=None
classes=11, include_rescaling=True, backbone="resnet50_v2"
)
input_image = tf.random.uniform(shape=[2, 256, 256, 3])
output = model(input_image, training=True)
@@ -35,7 +35,7 @@ def test_deeplab_model_construction_with_preconfigured_setting(self):
def test_deeplab_model_with_components(self):
backbone = models.ResNet50V2(include_rescaling=True, include_top=False)
model = segmentation.DeepLabV3(
-    classes=11, include_rescaling=True, backbone=backbone, weights=None
+    classes=11, include_rescaling=True, backbone=backbone
)

input_image = tf.random.uniform(shape=[2, 256, 256, 3])
@@ -46,7 +46,7 @@ def test_deeplab_model_with_components(self):
def test_mixed_precision(self):
tf.keras.mixed_precision.set_global_policy("mixed_float16")
model = segmentation.DeepLabV3(
-    classes=11, include_rescaling=True, backbone="resnet50_v2", weights=None
+    classes=11, include_rescaling=True, backbone="resnet50_v2"
)
input_image = tf.random.uniform(shape=[2, 256, 256, 3])
output = model(input_image, training=True)
@@ -61,7 +61,6 @@ def test_invalid_backbone_model(self):
classes=11,
include_rescaling=True,
backbone="resnet_v3",
-    weights=None,
)
with self.assertRaisesRegex(
ValueError, "Backbone need to be a `tf.keras.layers.Layer`"
@@ -70,7 +69,6 @@ def test_invalid_backbone_model(self):
classes=11,
include_rescaling=True,
backbone=tf.Module(),
-    weights=None,
)

@pytest.mark.skipif(
@@ -81,7 +79,7 @@ def test_invalid_backbone_model(self):
)
def test_model_train(self):
model = segmentation.DeepLabV3(
-    classes=1, include_rescaling=True, backbone="resnet50_v2", weights=None
+    classes=1, include_rescaling=True, backbone="resnet50_v2"
)

gcs_data_pattern = "gs://caltech_birds2011_mask/0.1.1/*.tfrecord*"
6 changes: 6 additions & 0 deletions keras_cv/models/weights.py
@@ -47,6 +47,9 @@ def parse_weights(weights, include_top, model_type):
"imagenet": "imagenet/classification-v0",
"imagenet/classification": "imagenet/classification-v0",
},
"deeplabv3": {
"voc": "voc/segmentation-v0",
},
"densenet121": {
"imagenet": "imagenet/classification-v0",
"imagenet/classification": "imagenet/classification-v0",
@@ -90,6 +93,9 @@ def parse_weights(weights, include_top, model_type):
"imagenet/classification-v0": "7bc5589f7f7f7ee3878e61ab9323a71682bfb617eb57f530ca8757c742f00c77",
"imagenet/classification-v0-notop": "8dcce43163e4b4a63e74330ba1902e520211db72d895b0b090b6bfe103e7a8a5",
},
"deeplabv3": {
"voc/segmentation-v0": "732042e8b6c9ddba3d51c861f26dc41865187e9f85a0e5d43dfef75a405cca18",
},
"densenet121": {
"imagenet/classification-v0": "13de3d077ad9d9816b9a0acc78215201d9b6e216c7ed8e71d69cc914f8f0775b",
"imagenet/classification-v0-notop": "709afe0321d9f2b2562e562ff9d0dc44cca10ed09e0e2cfba08d783ff4dab6bf",
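To make the new table entries concrete: a short alias like "voc" expands to the pinned, versioned id "voc/segmentation-v0", and the SHA-256 string lets the downloader verify the fetched file. The sketch below illustrates that flow; it is not the actual parse_weights implementation, and the GCS base URL and file layout are assumptions.

import tensorflow as tf

# Subset of the tables added in this diff.
ALIASES = {"deeplabv3": {"voc": "voc/segmentation-v0"}}
HASHES = {
    "deeplabv3": {
        "voc/segmentation-v0": "732042e8b6c9ddba3d51c861f26dc41865187e9f85a0e5d43dfef75a405cca18",
    }
}
BASE_URL = "https://storage.googleapis.com/keras-cv/models"  # assumed layout


def fetch_weights(weights, model_type):
    # Expand a short alias ("voc") into its versioned identifier.
    weights = ALIASES[model_type].get(weights, weights)
    # get_file verifies the SHA-256 checksum after downloading and
    # fails loudly on a mismatch.
    return tf.keras.utils.get_file(
        origin=f"{BASE_URL}/{model_type}/{weights.replace('/', '_')}.h5",
        file_hash=HASHES[model_type][weights],
    )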
46 changes: 38 additions & 8 deletions shell/weights/update_training_history.py
@@ -75,13 +75,38 @@
tensorboard_results = tensorboard_experiment.get_scalars()

training_epochs = max(tensorboard_results[tensorboard_results.run == "train"].step)
-max_validation_accuracy = max(
-    tensorboard_results[
-        (tensorboard_results.run == "validation")
-        & (tensorboard_results.tag == "epoch_categorical_accuracy")
-    ].value
-)
-max_validation_accuracy = f"{max_validation_accuracy:.4f}"
+results_tags = tensorboard_results.tag.unique()
+
+# Validation accuracy won't exist in all logs (e.g. for object detection tasks).
+# We capture the max validation accuracy if it exists, but otherwise omit it.
+max_validation_accuracy = None
+if (
+    "epoch_categorical_accuracy" in results_tags
+    or "epoch_sparse_categorical_accuracy" in results_tags
+):
+    max_validation_accuracy = max(
+        tensorboard_results[
+            (tensorboard_results.run == "validation")
+            & (
+                (tensorboard_results.tag == "epoch_categorical_accuracy")
+                | (tensorboard_results.tag == "epoch_sparse_categorical_accuracy")
+            )
+        ].value
+    )
+    max_validation_accuracy = f"{max_validation_accuracy:.4f}"
+
+# Mean IoU won't exist in all logs (e.g. for classification tasks).
+# We capture the max IoU if it exists, but otherwise omit it.
+max_mean_iou = None
+if "epoch_mean_io_u" in results_tags:
+    max_mean_iou = max(
+        tensorboard_results[
+            (tensorboard_results.run == "validation")
+            & (tensorboard_results.tag == "epoch_mean_io_u")
+        ].value
+    )
+    max_mean_iou = f"{max_mean_iou:.4f}"

contributor = FLAGS.contributor or input(
"Input your GitHub username (or the username of the contributor, if it's not you)\n"
@@ -106,14 +131,19 @@

new_results = {
    "script": {"name": "/".join(training_script_dirs[2:]), "version": script_version},
-    "validation_accuracy": max_validation_accuracy,
    "epochs_trained": training_epochs,
    "tensorboard_logs": f"https://tensorboard.dev/experiment/{tensorboard_experiment_id}/",
    "contributor": contributor,
    "args": args_dict,
    "accelerators": int(accelerators),
}

+if max_validation_accuracy is not None:
+    new_results["validation_accuracy"] = max_validation_accuracy
+
+if max_mean_iou is not None:
+    new_results["validation_mean_iou"] = max_mean_iou

# Check if the JSON file already exists
results_file = open(training_script_json_path, "r")
results_string = results_file.read()
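The conditional logic above is plain pandas: get_scalars() returns a long-format DataFrame with "run", "tag", "step", and "value" columns, so the metric lookups are simple boolean masks. A toy reproduction of the mean-IoU branch, with invented values for illustration:

import pandas as pd

# Stand-in for tensorboard_experiment.get_scalars().
tensorboard_results = pd.DataFrame(
    {
        "run": ["train", "validation", "validation"],
        "tag": ["epoch_loss", "epoch_mean_io_u", "epoch_mean_io_u"],
        "step": [0, 0, 1],
        "value": [0.71, 0.52, 0.69],
    }
)

max_mean_iou = None
if "epoch_mean_io_u" in tensorboard_results.tag.unique():
    max_mean_iou = max(
        tensorboard_results[
            (tensorboard_results.run == "validation")
            & (tensorboard_results.tag == "epoch_mean_io_u")
        ].value
    )
    max_mean_iou = f"{max_mean_iou:.4f}"

print(max_mean_iou)  # "0.6900"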
12 changes: 12 additions & 0 deletions shell/weights/upload_weights.sh
@@ -0,0 +1,12 @@
if [ "$#" -ne 2 ]; then
echo USAGE: ./process_backbone_weights.sh WEIGHTS_PATH GCS_PATH
exit 1
fi

WEIGHTS=$1
GCS_PATH=$2

echo Checksum: $(shasum -a 256 $WEIGHTS)

gsutil cp $WEIGHTS $GCS_PATH/
gsutil acl ch -u AllUsers:R $GCS_PATH/$WEIGHTS
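A typical invocation (paths hypothetical) would be `./upload_weights.sh deeplabv3_voc.h5 gs://keras-cv/models/deeplabv3`; the printed checksum is then pasted into the hash table in keras_cv/models/weights.py, as in the deeplabv3 entry added above.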