Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Enhancement] Refine User Scripts #2849

Merged
merged 15 commits into from
Dec 29, 2022
Prev Previous commit
Next Next commit
Decouple builder from config.py
  • Loading branch information
Bobholamovic committed Dec 19, 2022
commit 3c5fe0ea69fa3a6eb6cfea7d226d7577e70b1a9d
12 changes: 6 additions & 6 deletions paddleseg/cvlibs/_config_checkers.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def apply_all_rules(self, cfg):
# Do nothing here as `self.apply_rule()` already handles the exceptions


class _Rule(object):
class Rule(object):
def check_and_correct(self, cfg):
# Be free to add in-place modification here
raise NotImplementedError
Expand All @@ -61,15 +61,15 @@ def apply(self, cfg, allow_update):
self.check_and_correct(cfg)


class DefaultPrimaryRule(_Rule):
class DefaultPrimaryRule(Rule):
    """Checks the primary entries of the config: a model must be specified,
    and at least one of the training/validation datasets must be given."""

    def check_and_correct(self, cfg):
        """Validate `cfg`.

        Args:
            cfg: A config object exposing `dic` (the raw config dict) and the
                `train_dataset_config` / `val_dataset_config` properties.

        Raises:
            AssertionError: If no model is specified, or if neither a training
                nor a validation dataset is given.
        """
        assert cfg.dic.get('model', None) is not None, \
            'No model specified in the configuration file.'
        # Bug fix: the original read `self.val_dataset_config`, but rule
        # objects have no such attribute -- it lives on the config object.
        assert cfg.train_dataset_config or cfg.val_dataset_config, \
            'One of `train_dataset` or `val_dataset` should be given, but there are none.'


class DefaultLossRule(_Rule):
class DefaultLossRule(Rule):
def __init__(self, loss_name):
super().__init__()
self.loss_name = loss_name
Expand All @@ -93,7 +93,7 @@ def check_and_correct(self, cfg):
len_types, len_coef))


class DefaultSyncNumClassesRule(_Rule):
class DefaultSyncNumClassesRule(Rule):
def check_and_correct(self, cfg):
num_classes_set = set()

Expand Down Expand Up @@ -129,7 +129,7 @@ def check_and_correct(self, cfg):
cfg.dic['val_dataset']['num_classes'] = num_classes


class DefaultSyncImgChannelsRule(_Rule):
class DefaultSyncImgChannelsRule(Rule):
def check_and_correct(self, cfg):
img_channels_set = set()
model_cfg = cfg.dic['model']
Expand Down Expand Up @@ -168,7 +168,7 @@ def check_and_correct(self, cfg):
cfg.dic['val_dataset']['img_channels'] = img_channels


class DefaultSyncIgnoreIndexRule(_Rule):
class DefaultSyncIgnoreIndexRule(Rule):
def __init__(self, loss_name):
super().__init__()
self.loss_name = loss_name
Expand Down
98 changes: 98 additions & 0 deletions paddleseg/cvlibs/builder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


class ComponentBuilder(object):
    """
    This class is responsible for building components. All component classes must be available
    in the list of maintained components.

    Args:
        com_list (list): A list of component classes.
    """

    def __init__(self, com_list):
        super().__init__()
        self.com_list = com_list

    def load_component_class(self, com_name):
        """
        Load component class, such as model, loss, dataset, etc.

        Args:
            com_name (str): Name the component class was registered under.

        Returns:
            The first class registered as `com_name` in `self.com_list`.

        Raises:
            RuntimeError: If no registry in the list provides `com_name`.
        """
        for com in self.com_list:
            if com_name in com.components_dict:
                return com[com_name]

        raise RuntimeError("The specified component ({}) was not found.".format(
            com_name))

    def create_object(self, cfg):
        """
        Create Python object, such as model, loss, dataset, etc.

        Args:
            cfg (dict): Component description. Values recognized by
                `is_meta_type()` (also inside list values) are built
                recursively; the rest are passed through as constructor
                arguments.

        Returns:
            The constructed object.
        """
        # Work on a shallow copy so the caller's dict is never mutated
        # (`load_component_class_from_config()` may pop the 'type' key).
        cfg = dict(cfg)
        component = self.load_component_class_from_config(cfg)

        params = {}
        for key, val in cfg.items():
            if self.is_meta_type(val):
                params[key] = self.create_object(val)
            elif isinstance(val, list):
                params[key] = [
                    self.create_object(item)
                    if self.is_meta_type(item) else item for item in val
                ]
            else:
                params[key] = val

        return self._create_object_impl(component, **params)

    def _create_object_impl(self, component, *args, **kwargs):
        # Subclasses define how a component class is instantiated.
        raise NotImplementedError

    def load_component_class_from_config(self, cfg):
        # Subclasses define how the component class is extracted from `cfg`.
        raise NotImplementedError

    @classmethod
    def is_meta_type(cls, obj):
        # Subclasses define which values describe a nested component.
        raise NotImplementedError


class DefaultComponentBuilder(ComponentBuilder):
    """Builds components whose config dict names the class under a 'type' key."""

    def _create_object_impl(self, component, *args, **kwargs):
        """Instantiate `component`, wrapping any failure in a RuntimeError."""
        try:
            return component(*args, **kwargs)
        except Exception as e:
            com_name = getattr(component, '__name__', '')
            # Chain the original exception (`from e`) so the root-cause
            # traceback is preserved alongside the friendlier message.
            raise RuntimeError(
                f"Tried to create a {com_name} object, but the operation has failed. "
                "Please double check the arguments used to create the object.\n"
                f"The error message is: \n{str(e)}") from e

    def load_component_class_from_config(self, cfg):
        """Resolve the component class named by `cfg['type']`.

        NOTE: pops the 'type' key from the given dict in place.

        Raises:
            RuntimeError: If `cfg` carries no 'type' key.
        """
        if 'type' not in cfg:
            raise RuntimeError(
                "It is not possible to create a component object from {}, as 'type' is not specified.".
                format(cfg))
        component = self.load_component_class(cfg.pop('type'))
        return component

    @classmethod
    def is_meta_type(cls, obj):
        # A nested component description is a dict carrying a 'type' key.
        # TODO: should we define a protocol (see https://peps.python.org/pep-0544/#defining-a-protocol)
        # to make it more pythonic?
        return isinstance(obj, dict) and 'type' in obj
96 changes: 33 additions & 63 deletions paddleseg/cvlibs/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import paddle

from . import _config_checkers as checker
from . import builder
from paddleseg.cvlibs import manager
from paddleseg.utils import logger, utils
from paddleseg.utils.utils import CachedProperty as cached_property
Expand All @@ -31,7 +32,7 @@


class Config(object):
'''
"""
Training configuration parsing. The only yaml/yml file is supported.

The following hyper-parameters are available in the config file:
Expand Down Expand Up @@ -70,15 +71,16 @@ class Config(object):
# since the model builder uses some properties in dataset.
model = cfg.model
...
'''
"""

def __init__(self,
path: str,
learning_rate: Optional[float]=None,
batch_size: Optional[int]=None,
iters: Optional[int]=None,
opts: Optional[list]=None,
sanity_checker: Optional[checker.ConfigChecker]=None):
sanity_checker: Optional[checker.ConfigChecker]=None,
component_builder: Optional[builder.ComponentBuilder]=None):
assert os.path.exists(path), \
'Config path ({}) does not exist'.format(path)
assert path.endswith('yml') or path.endswith('yaml'), \
Expand All @@ -92,6 +94,13 @@ def __init__(self,
iters=iters,
opts=opts)

# We have to build the component builder before doing any sanity checks
# This is because during a sanity check, some component objects are (possibly)
# required to be constructed.
if component_builder is None:
component_builder = self._build_default_component_builder()
self.builder = component_builder

if sanity_checker is None:
sanity_checker = self._build_default_sanity_checker()
sanity_checker.apply_all_rules(self)
Expand Down Expand Up @@ -223,15 +232,15 @@ def _prepare_loss(self, loss_name):
args = self.dic.get(loss_name, {}).copy()
losses = {'coef': args['coef'], "types": []}
for loss_cfg in args['types']:
losses['types'].append(self.create_object(loss_cfg))
losses['types'].append(self.builder.create_object(loss_cfg))
return losses

#################### model
@cached_property
def model(self) -> paddle.nn.Layer:
    """Build (at most once) and return the model described by the config."""
    model_spec = self.dic.get('model').copy()
    if not self._model:
        self._model = self.builder.create_object(model_spec)
    return self._model

#################### dataset
Expand All @@ -246,26 +255,26 @@ def val_dataset_config(self) -> Dict:
@cached_property
def train_dataset_class(self) -> Any:
    """Resolve and return the class of the training dataset."""
    return self.builder.load_component_class(
        self.train_dataset_config['type'])

@cached_property
def val_dataset_class(self) -> Any:
    """Resolve and return the class of the validation dataset."""
    return self.builder.load_component_class(
        self.val_dataset_config['type'])

@cached_property
def train_dataset(self) -> paddle.io.Dataset:
    """Instantiate the training dataset, or None if none is configured."""
    dataset_cfg = self.train_dataset_config
    return self.builder.create_object(dataset_cfg) if dataset_cfg else None

@cached_property
def val_dataset(self) -> paddle.io.Dataset:
    """Instantiate the validation dataset, or None if none is configured."""
    dataset_cfg = self.val_dataset_config
    return self.builder.create_object(dataset_cfg) if dataset_cfg else None

@cached_property
def val_transforms(self) -> list:
Expand All @@ -275,8 +284,8 @@ def val_transforms(self) -> list:
return []
_transforms = _val_dataset.get('transforms', [])
transforms = []
for i in _transforms:
transforms.append(self.create_object(i))
for tf in _transforms:
transforms.append(self.builder.create_object(tf))
return transforms

#################### test and export
Expand All @@ -298,10 +307,7 @@ def parse_from_yaml(cls, path: str, *args, **kwargs) -> dict:
return parse_from_yaml(path, *args, **kwargs)

@classmethod
def create_object(cls, cfg: dict, *args, **kwargs) -> Any:
return create_object(cfg, *args, **kwargs)

def _build_default_sanity_checker(self):
def _build_default_sanity_checker(cls):
rules = []
rules.append(checker.DefaultPrimaryRule())
rules.append(checker.DefaultSyncNumClassesRule())
Expand All @@ -315,9 +321,18 @@ def _build_default_sanity_checker(self):

return checker.ConfigChecker(rules, allow_update=True)

@classmethod
def _build_default_component_builder(cls):
    """Create the component builder used when the caller supplies none.

    The default builder searches the standard component registries:
    models, backbones, datasets, transforms and losses.
    """
    registries = [
        manager.MODELS, manager.BACKBONES, manager.DATASETS,
        manager.TRANSFORMS, manager.LOSSES
    ]
    return builder.DefaultComponentBuilder(com_list=registries)


def merge_config_dicts(dic, base_dic):
'''Merge dic to base_dic and return base_dic.'''
"""Merge dic to base_dic and return base_dic."""
base_dic = base_dic.copy()
dic = dic.copy()

Expand All @@ -335,7 +350,7 @@ def merge_config_dicts(dic, base_dic):


def parse_from_yaml(path: str):
'''Parse a yaml file and build config'''
"""Parse a yaml file and build config"""
with codecs.open(path, 'r', 'utf-8') as file:
dic = yaml.load(file, Loader=yaml.FullLoader)

Expand All @@ -356,7 +371,7 @@ def _update_config_dict(dic: dict,
batch_size: Optional[int]=None,
iters: Optional[int]=None,
opts: Optional[list]=None):
'''Update config'''
"""Update config"""
# TODO: If the items to update are marked as anchors in the yaml file,
# we should synchronize the references.
dic = dic.copy()
Expand Down Expand Up @@ -391,48 +406,3 @@ def _update_config_dict(dic: dict,
tmp_dic[key_list[-1]] = value

return dic


def load_component_class(com_name: str, com_list: Optional[list]=None) -> Any:
    """Load component class, such as model, loss, dataset, etc.

    Args:
        com_name (str): Name the component class was registered under.
        com_list (list, optional): Registries to search. Defaults to the
            standard model/backbone/dataset/transform/loss registries.

    Raises:
        RuntimeError: If `com_name` is not found in any registry.
    """
    registries = com_list if com_list is not None else [
        manager.MODELS, manager.BACKBONES, manager.DATASETS,
        manager.TRANSFORMS, manager.LOSSES
    ]

    for registry in registries:
        if com_name in registry.components_dict:
            return registry[com_name]

    raise RuntimeError('The specified component ({}) was not found.'.format(
        com_name))


def create_object(cfg: dict, com_list: Optional[list]=None) -> Any:
    """Create Python object, such as model, loss, dataset, etc.

    Args:
        cfg (dict): Object description; the 'type' key names the component
            class and every other key becomes a constructor argument. Nested
            descriptions (dicts carrying a 'type' key), including those found
            inside list values, are built recursively.
        com_list (list, optional): Registries forwarded to
            `load_component_class` for the top-level lookup.

    Raises:
        RuntimeError: If `cfg` carries no 'type' key.
    """
    cfg = cfg.copy()
    if 'type' not in cfg:
        raise RuntimeError('No object information in {}.'.format(cfg))

    component = load_component_class(cfg.pop('type'), com_list=com_list)

    def _maybe_build(value):
        # Recursively build nested component descriptions; note that nested
        # lookups use the default registries, matching historical behavior.
        if is_meta_type(value):
            return create_object(value)
        if isinstance(value, list):
            return [create_object(v) if is_meta_type(v) else v for v in value]
        return value

    return component(**{key: _maybe_build(val) for key, val in cfg.items()})


def is_meta_type(obj):
    """Tell whether `obj` describes a component to be built, i.e. whether it
    is a dict carrying a 'type' key."""
    # TODO: should we define a protocol (see https://peps.python.org/pep-0544/#defining-a-protocol)
    # to make it more pythonic?
    return isinstance(obj, dict) and 'type' in obj