
Conversation

jbischof
Contributor

@jbischof jbischof commented Feb 22, 2023

This PR pilots the biggest single step to unifying the KerasCV and KerasNLP APIs. It sits downstream of the pilot for functional subclasses #1401 and will be followed by a PR introducing Task models for classification.

Highlights of this PR:

  • Use from_preset constructor to load weights instead of weights arg
  • Reimplement config-in-code classes (e.g., ResNet50V2Backbone) with from_preset constructor
  • Remove pooling and include_top args to be handled by Task models
  • Add docstring examples to show basic usage
  • Introduce Backbone class to hold generic methods and properties
  • Rename as_backbone to get_feature_extractor
  • Decouple testing per model to improve readability
  • Test weight loading with pytest
  • Introduce conftest.py to allow control of weight RCP testing
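
As a rough, framework-free sketch of the mechanism behind these bullets (this is not the actual KerasCV implementation; the `presets` registry layout and `Backbone` internals here are illustrative assumptions), a `from_preset` classmethod can resolve a registered config, apply user overrides, and then optionally load weights:

```python
# Hypothetical sketch of the from_preset pattern; names are illustrative.
class Backbone:
    """Base class holding generic preset-loading machinery."""

    # Subclasses register their presets: preset name -> config dict.
    presets = {}

    def __init__(self, **kwargs):
        self.config = kwargs

    @classmethod
    def from_preset(cls, preset, load_weights=None, **overrides):
        if preset not in cls.presets:
            raise ValueError(f"Unknown preset: {preset!r}")
        # User-supplied kwargs override the preset's stored config.
        config = {**cls.presets[preset]["config"], **overrides}
        model = cls(**config)
        # A real implementation would download and load pretrained weights
        # here when load_weights is True, or by default when the preset
        # ships weights and load_weights is None.
        return model


class ResNetV2BackboneSketch(Backbone):
    presets = {
        "resnet18_v2": {"config": {"stackwise_blocks": [2, 2, 2, 2]}},
    }


model = ResNetV2BackboneSketch.from_preset(
    "resnet18_v2", include_rescaling=False
)
```

The key design point is that the constructor stays a plain config-driven constructor, while `from_preset` is a thin classmethod layered on top of a registry.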

API preview

See gist for a preview of the new API.

Quick summary:

We will rely on the docstring and keras.io to communicate preset usage.

Example docstring:

from_preset(*args, **kwargs) method of builtins.type instance
    Instantiate ResNetV2Backbone model from preset architecture and weights.
    Args:
        preset: string. Must be one of "resnet18_v2", "resnet34_v2", "resnet50_v2", "resnet101_v2", "resnet152_v2", "resnet50_v2_imagenet".
            If looking for a preset with pretrained weights, choose one of
            "resnet50_v2_imagenet".
        load_weights: Whether to load pre-trained weights into model.
            Defaults to `None`, which follows whether the preset has
            pretrained weights available.
    
    Examples:
    ```python
    # Load architecture and weights from preset
    model = keras_cv.models.ResNetV2Backbone.from_preset(
        "resnet50_v2_imagenet",
    )
    
    # Load randomly initialized model from preset architecture
    model = keras_cv.models.ResNetV2Backbone.from_preset(
        "resnet50_v2_imagenet",
        load_weights=False,
    )
    ```

Example usage:

# Generic constructor
model = ResNetV2Backbone(
    stackwise_filters=[64, 128, 256, 512],
    stackwise_blocks=[2, 2, 2, 2],
    stackwise_strides=[1, 2, 2, 2],
    include_rescaling=False,
    input_shape=[256, 256, 3],
)

# Load preset architecture without weights
# Can also include overrides
model = ResNetV2Backbone.from_preset(
    "resnet18_v2",
    include_rescaling=False,
)

# Load preset architecture with weights
model = ResNetV2Backbone.from_preset("resnet50_v2_imagenet")

# Use as backbone
model = RetinaNet(
    classes=20,
    bounding_box_format="xywh",
    backbone=ResNetV2Backbone.from_preset(
        "resnet50_v2_imagenet").get_feature_extractor(),
)

keras.applications aliases

We've also reintroduced versions of the `keras.applications` config-in-code classes, now powered by `from_preset`.

# Load a preset without weights
model = ResNet50V2Backbone(
    include_rescaling=False,
)

# Load a preset with weights
model = ResNet50V2Backbone.from_preset("resnet50_v2_imagenet")

# Use as backbone
model = RetinaNet(
    classes=20,
    bounding_box_format="xywh",
    backbone=ResNet50V2Backbone().get_feature_extractor(),
)
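
The alias pattern can be sketched in a few lines (purely illustrative, not the KerasCV source; the class bodies here are simplified stand-ins): the config-in-code class is a thin subclass that pins the architecture config and forwards everything else.

```python
# Hypothetical sketch of a config-in-code alias class.
class ResNetV2BackboneSketch:
    """Stands in for the generic, fully configurable backbone."""

    def __init__(self, stackwise_blocks, include_rescaling=True, **kwargs):
        self.stackwise_blocks = stackwise_blocks
        self.include_rescaling = include_rescaling


class ResNet50V2BackboneSketch(ResNetV2BackboneSketch):
    """Alias that fixes the ResNet50V2 architecture."""

    def __init__(self, **kwargs):
        # Pin the 50-layer stack configuration; pass everything else through.
        super().__init__(stackwise_blocks=[3, 4, 6, 3], **kwargs)


model = ResNet50V2BackboneSketch(include_rescaling=False)
```

Because the alias is a real subclass rather than a factory function, it still participates in `from_preset`, serialization, and `isinstance` checks like any other backbone.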

@jbischof
Contributor Author

jbischof commented Feb 23, 2023

/gcbrun

Contributor

@LukeWood LukeWood left a comment


I did a first pass --- overall I really like the abstractions around presets/weight loadings.

My only real concern revolves around the readability of `ResNet50V2(...)` vs `ResNetV2Backbone.from_preset(...)`.

times smaller in width and height than the input image.

Args:
min_level: optional int, the lowest level of feature to be included
Contributor


This min_level / max_level API is very mysterious. Doubt anyone not already familiar with the implementation will figure out what it means. Can we find better argument names and descriptions?

Contributor Author


Good point! Will file an Issue.

Contributor


Yeah, I really do not like this API. Realistically it should be a list, either of level names or layer names.

Contributor Author


Filed #1447

@jbischof
Contributor Author

/gcbrun

Member

@mattdangerw mattdangerw left a comment


Thanks! Mainly took a pass to educate myself, but thoughts on this.

Overall I think this is great at bringing the two libraries into a similar style, making it so that to an outside observer they look like they were written by the same folks. Which I think is the big win to have here.

I think the biggest thing I noticed was that

model = keras_cv.models.RetinaNet(
    classes=10,
    bounding_box_format="xywh",
    backbone=keras_cv.models.ResNetV2Backbone.from_preset(
        "resnet50_v2",
    ).get_feature_extractor(),
)

feels a bit different than the approach we took in KerasNLP. In KerasNLP our high-level models often do some surgery on our backbones, but that logic stays in the class that needs to extract certain features, rather than the backbone having an "export" function. The more similar move to KerasNLP seems like it would be

backbone = keras_cv.models.ResNetV2Backbone.from_preset("resnet50_v2")
model = keras_cv.models.RetinaNet(
    classes=10,
    bounding_box_format="xywh",
    backbone=backbone,
    min_backbone_level=xx,  # None can still be the default here.
    max_backbone_level=yy,  # None can still be the default here.
)

Or even just

model = keras_cv.models.RetinaNet.from_preset(
    "some_id",
    classes=10,
    bounding_box_format="xywh",
)

I suspect this is more something to think about (or it already has been) than to change on this PR, but what are the big reasons for the discrepancy here?

@fchollet
Contributor

I suspect this is more something to think about (or it already has been) than to change on this PR, but what are the big reasons for the discrepancy here?

The approach you describe is cleaner, in fact. I believe the reason for the current discrepancy is that each backbone might require custom handling. If there is no universal way to do feature extraction, then it's more practical for each backbone implementation to specify how to do it for that particular backbone.

But perhaps this is not a correct assumption.
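
The alternative under discussion, where the task model performs the feature "surgery" itself, can be sketched without any framework code (a toy illustration under stated assumptions: `ToyBackbone`, `ToyRetinaNet`, and the `pyramid_outputs` attribute are all hypothetical names, not KerasCV API):

```python
# Toy sketch: the task model, not the backbone, selects the pyramid levels
# it needs, so the backbone carries no export method.
class ToyBackbone:
    """Stands in for a real backbone; maps level names to feature tensors."""

    def __init__(self):
        # In a real backbone these values would be intermediate layer outputs.
        self.pyramid_outputs = {
            "P2": "feat_p2", "P3": "feat_p3",
            "P4": "feat_p4", "P5": "feat_p5",
        }


class ToyRetinaNet:
    def __init__(self, backbone, min_level="P3", max_level="P5"):
        levels = sorted(backbone.pyramid_outputs)
        lo, hi = levels.index(min_level), levels.index(max_level)
        # The surgery lives here, in the class that needs the features.
        self.features = {
            level: backbone.pyramid_outputs[level]
            for level in levels[lo:hi + 1]
        }


model = ToyRetinaNet(ToyBackbone())
print(sorted(model.features))  # ['P3', 'P4', 'P5']
```

If every backbone exposes its pyramid outputs through a common attribute like this, the extraction logic generalizes; fchollet's caveat above is exactly that such a universal contract may not hold for all backbones.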

Contributor

@fchollet fchollet left a comment


Thanks for the PR! Looking great.

Contributor

@LukeWood LukeWood left a comment


I'm comfortable merging this --- but it might make more sense to wait until the Classifier task model is done. Up to you ---

another note: when we do merge this we should ctrl-f ResNet50V2 in keras-io and any other r

Also: we should probably get a universal agreement

@LukeWood
Contributor

(apologies for any typos in my comments - my GitHub UI is bugged and my comments don't display as I type. The box simply remains empty - so no way to edit typos out.)

@jbischof
Contributor Author

Thanks @LukeWood, I'm working on a cleaner implementation. Will update Monday.

Contributor

@LukeWood LukeWood left a comment


I think this looks good - I also am glad to see that the Sequential([ResNet50V2Backbone(), layers.GlobalAveragePooling2D()]) API works as expected (per the Simclr training test)

Contributor

@fchollet fchollet left a comment


Thanks for the updates. Looking good, I think we can ship it.

@jbischof
Contributor Author

Thanks @LukeWood, I also replicated our quickstart using the following model:

resnet_classifier = keras.Sequential(
    [
        keras_cv.models.ResNet18V2Backbone(),
        keras.layers.GlobalAveragePooling2D(name="avg_pool"),
        keras.layers.Dense(3, activation="softmax", name="predictions"),
    ],
)

The model accuracy improved from 0.33 -> 0.51 over 10 epochs. See the attached colab for details.

@jbischof
Contributor Author

/gcbrun

@LukeWood
Contributor

I suspect this is more something to think about (or it already has been) than to change on this PR, but what are the big reasons for the discrepancy here?

I do quite like the idea of having the get_feature_extractor() logic live in the task models -- but I also think that can be a follow up PR.

thoughts?

@jbischof
Copy link
Contributor Author

jbischof commented Mar 1, 2023

/gcbrun

@jbischof jbischof merged commit f3d8582 into keras-team:master Mar 1, 2023
@jbischof jbischof deleted the backbone branch March 1, 2023 00:37
@IMvision12
Contributor

@jbischof Are these modifications required for ResNetV1?

@jbischof
Copy link
Contributor Author

jbischof commented Mar 6, 2023

Yes @IMvision12 but not quite yet! I'm planning some followup PRs to refine the design somewhat and will start filing Issues later this week. Thank you for all your amazing contributions 🚀

ghost pushed a commit to y-vectorfield/keras-cv that referenced this pull request Nov 16, 2023
* Introduce `Backbone` class and presets

* Attach presets

* Restore code we still need

* First passing tests

* Finish backbone tests

* Get preset tests working

* Remove unused marker

* Remove dangling TPU reference

* Add variable input channels test

* Improve preset names

* Better documentation for presets with weights

* Fix import

* Fix broken tests

* Respond to comments

* Respond to comments 2

* Add __init__.py files

* Fix docstring

* format

* Respond to comments

* Export new symbols

* Fix bug in ResNet50V2Backbone

* Respond to more comments

* format

* Inline error message

* Change applications alias to subclass

* Fix broken test

* Remove unneeded overrides

* Fix inheritence structure

* Respond to comments

* format

* format2

* Add docstring examples