Skip to content

Loosen class requirements in from_preset #2276

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 1 commit into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 1 addition & 5 deletions keras_hub/src/layers/preprocessing/audio_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
PreprocessingLayer,
)
from keras_hub.src.utils.preset_utils import builtin_presets
from keras_hub.src.utils.preset_utils import find_subclass
from keras_hub.src.utils.preset_utils import get_preset_loader
from keras_hub.src.utils.preset_utils import get_preset_saver
from keras_hub.src.utils.python_utils import classproperty
Expand Down Expand Up @@ -89,10 +88,7 @@ class like `keras_hub.models.AudioConverter.from_preset()`, or from a
```
"""
loader = get_preset_loader(preset)
backbone_cls = loader.check_backbone_class()
if cls.backbone_cls != backbone_cls:
cls = find_subclass(preset, cls, backbone_cls)
return loader.load_audio_converter(cls, **kwargs)
return loader.load_audio_converter(cls=cls, kwargs=kwargs)

def save_to_preset(self, preset_dir):
"""Save audio converter to a preset directory.
Expand Down
6 changes: 1 addition & 5 deletions keras_hub/src/layers/preprocessing/image_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
)
from keras_hub.src.utils.keras_utils import standardize_data_format
from keras_hub.src.utils.preset_utils import builtin_presets
from keras_hub.src.utils.preset_utils import find_subclass
from keras_hub.src.utils.preset_utils import get_preset_loader
from keras_hub.src.utils.preset_utils import get_preset_saver
from keras_hub.src.utils.python_utils import classproperty
Expand Down Expand Up @@ -380,10 +379,7 @@ def from_preset(
```
"""
loader = get_preset_loader(preset)
backbone_cls = loader.check_backbone_class()
if cls.backbone_cls != backbone_cls:
cls = find_subclass(preset, cls, backbone_cls)
return loader.load_image_converter(cls, **kwargs)
return loader.load_image_converter(cls=cls, kwargs=kwargs)

def save_to_preset(self, preset_dir):
"""Save image converter to a preset directory.
Expand Down
11 changes: 3 additions & 8 deletions keras_hub/src/models/backbone.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,14 +168,9 @@ class like `keras_hub.models.Backbone.from_preset()`, or from
```
"""
loader = get_preset_loader(preset)
backbone_cls = loader.check_backbone_class()
if not issubclass(backbone_cls, cls):
raise ValueError(
f"Saved preset has type `{backbone_cls.__name__}` which is not "
f"a subclass of calling class `{cls.__name__}`. Call "
f"`from_preset` directly on `{backbone_cls.__name__}` instead."
)
return loader.load_backbone(backbone_cls, load_weights, **kwargs)
return loader.load_backbone(
cls=cls, load_weights=load_weights, kwargs=kwargs
)

def save_to_preset(self, preset_dir, max_shard_size=10):
"""Save backbone to a preset directory.
Expand Down
44 changes: 19 additions & 25 deletions keras_hub/src/models/preprocessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
)
from keras_hub.src.utils.preset_utils import PREPROCESSOR_CONFIG_FILE
from keras_hub.src.utils.preset_utils import builtin_presets
from keras_hub.src.utils.preset_utils import find_subclass
from keras_hub.src.utils.preset_utils import get_preset_loader
from keras_hub.src.utils.preset_utils import get_preset_saver
from keras_hub.src.utils.python_utils import classproperty
Expand Down Expand Up @@ -171,43 +170,38 @@ def from_preset(
)
```
"""
if cls == Preprocessor:
if cls is Preprocessor:
raise ValueError(
"Do not call `Preprocessor.from_preset()` directly. Instead "
"choose a particular task preprocessing class, e.g. "
"`keras_hub.models.TextClassifierPreprocessor.from_preset()`."
)

loader = get_preset_loader(preset)
backbone_cls = loader.check_backbone_class()
# Detect the correct subclass if we need to.
if cls.backbone_cls != backbone_cls:
cls = find_subclass(preset, cls, backbone_cls)
return loader.load_preprocessor(cls, config_file, **kwargs)
return loader.load_preprocessor(
cls=cls, config_file=config_file, kwargs=kwargs
)

@classmethod
def _add_missing_kwargs(cls, loader, kwargs):
"""Fill in required kwargs when loading from preset.

This is a private method hit when loading a preprocessing layer that
was not directly saved in the preset. This method should fill in
all the kwargs required to call the class constructor. For almost
all preprocessors, the only required args are `tokenizer`,
`image_converter`, and `audio_converter`, but this can be overridden,
e.g. for a preprocessor with multiple tokenizers for different
encoders.
def _from_defaults(cls, loader, kwargs):
"""Load a preprocessor from default values.

This is a private method hit for loading a preprocessing layer that was
not directly saved in the preset. Usually this means loading a
tokenizer, image_converter and/or audio_converter and calling the
constructor. But this can be overridden by subclasses as needed.
"""
defaults = {}
# Allow loading any tokenizer, image_converter or audio_converter config
# we find on disk. We allow mixing and matching tokenizers and
# preprocessing layers (though this is usually not a good idea).
if "tokenizer" not in kwargs and cls.tokenizer_cls:
kwargs["tokenizer"] = loader.load_tokenizer(cls.tokenizer_cls)
defaults["tokenizer"] = loader.load_tokenizer()
if "audio_converter" not in kwargs and cls.audio_converter_cls:
kwargs["audio_converter"] = loader.load_audio_converter(
cls.audio_converter_cls
)
defaults["audio_converter"] = loader.load_audio_converter()
if "image_converter" not in kwargs and cls.image_converter_cls:
kwargs["image_converter"] = loader.load_image_converter(
cls.image_converter_cls
)
return kwargs
defaults["image_converter"] = loader.load_image_converter()
return cls(**{**defaults, **kwargs})

def load_preset_assets(self, preset):
"""Load all static assets needed by the preprocessing layer.
Expand Down
41 changes: 25 additions & 16 deletions keras_hub/src/models/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,11 @@
from keras_hub.src.api_export import keras_hub_export
from keras_hub.src.layers.preprocessing.audio_converter import AudioConverter
from keras_hub.src.layers.preprocessing.image_converter import ImageConverter
from keras_hub.src.models.backbone import Backbone
from keras_hub.src.models.preprocessor import Preprocessor
from keras_hub.src.tokenizers.tokenizer import Tokenizer
from keras_hub.src.utils.keras_utils import print_msg
from keras_hub.src.utils.pipeline_model import PipelineModel
from keras_hub.src.utils.preset_utils import builtin_presets
from keras_hub.src.utils.preset_utils import find_subclass
from keras_hub.src.utils.preset_utils import get_preset_loader
from keras_hub.src.utils.preset_utils import get_preset_saver
from keras_hub.src.utils.python_utils import classproperty
Expand Down Expand Up @@ -175,27 +173,38 @@ def from_preset(
)
```
"""
if cls == Task:
if cls is Task:
raise ValueError(
"Do not call `Task.from_preset()` directly. Instead call a "
"particular task class, e.g. "
"`keras_hub.models.TextClassifier.from_preset()`."
)

loader = get_preset_loader(preset)
backbone_cls = loader.check_backbone_class()
# Detect the correct subclass if we need to.
if (
issubclass(backbone_cls, Backbone)
and cls.backbone_cls != backbone_cls
):
cls = find_subclass(preset, cls, backbone_cls)
# Specifically for classifiers, we never load task weights if
# num_classes is supplied. We handle this in the task base class because
# it is the same logic for classifiers regardless of modality (text,
# images, audio).
load_task_weights = "num_classes" not in kwargs
return loader.load_task(cls, load_weights, load_task_weights, **kwargs)
return loader.load_task(
cls=cls, load_weights=load_weights, kwargs=kwargs
)

@classmethod
def _from_defaults(cls, loader, load_weights, kwargs, backbone_kwargs):
"""Load a task from default values.

This is a private method hit for loading a task layer that was
not directly saved in the preset. Usually this means loading a backbone
and preprocessor and calling the constructor. But this can be overridden
by subclasses as needed.
"""
defaults = {}
if "backbone" not in kwargs:
defaults["backbone"] = loader.load_backbone(
load_weights=load_weights, kwargs=backbone_kwargs
)
if "preprocessor" not in kwargs and cls.preprocessor_cls:
# Only load the "matching" preprocessor class for a task class.
defaults["preprocessor"] = loader.load_preprocessor(
cls=cls.preprocessor_cls
)
return cls(**{**defaults, **kwargs})

def load_task_weights(self, filepath):
"""Load only the tasks specific weights not in the backbone."""
Expand Down
8 changes: 3 additions & 5 deletions keras_hub/src/tokenizers/tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from keras_hub.src.utils.preset_utils import ASSET_DIR
from keras_hub.src.utils.preset_utils import TOKENIZER_CONFIG_FILE
from keras_hub.src.utils.preset_utils import builtin_presets
from keras_hub.src.utils.preset_utils import find_subclass
from keras_hub.src.utils.preset_utils import get_file
from keras_hub.src.utils.preset_utils import get_preset_loader
from keras_hub.src.utils.preset_utils import get_preset_saver
Expand Down Expand Up @@ -257,7 +256,6 @@ class like `keras_hub.models.Tokenizer.from_preset()`, or from
```
"""
loader = get_preset_loader(preset)
backbone_cls = loader.check_backbone_class()
if cls.backbone_cls != backbone_cls:
cls = find_subclass(preset, cls, backbone_cls)
return loader.load_tokenizer(cls, config_file, **kwargs)
return loader.load_tokenizer(
cls=cls, config_file=config_file, kwargs=kwargs
)
Loading
Loading