Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Enhancement] Val Dataset Transform Decoupled #1816

Merged
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 3 additions & 10 deletions contrib/Matting/bg_replace.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,12 @@
import paddle
from paddleseg.cvlibs import manager, Config
from paddleseg.utils import get_sys_env, logger
from paddleseg.transforms import Compose

from core import predict
import model
from dataset import MattingDataset
from transforms import Compose
from utils import get_image_list, estimate_foreground_ml


Expand Down Expand Up @@ -81,23 +83,14 @@ def main(args):
raise RuntimeError('No configuration file specified.')

cfg = Config(args.cfg)
val_dataset = cfg.val_dataset
if val_dataset is None:
raise RuntimeError(
'The verification dataset is not specified in the configuration file.'
)
elif len(val_dataset) == 0:
raise ValueError(
'The length of val_dataset is 0. Please check if your dataset is valid'
)

msg = '\n---------------Config Information---------------\n'
msg += str(cfg)
msg += '------------------------------------------------'
logger.info(msg)

model = cfg.model
transforms = val_dataset.transforms
transforms = Compose(cfg.val_transforms)

alpha = predict(
model,
Expand Down
12 changes: 2 additions & 10 deletions contrib/Matting/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from core import predict
from model import *
from dataset import MattingDataset
from transforms import Compose
from utils import get_image_list


Expand Down Expand Up @@ -70,23 +71,14 @@ def main(args):
raise RuntimeError('No configuration file specified.')

cfg = Config(args.cfg)
val_dataset = cfg.val_dataset
if val_dataset is None:
raise RuntimeError(
'The verification dataset is not specified in the configuration file.'
)
elif len(val_dataset) == 0:
raise ValueError(
'The length of val_dataset is 0. Please check if your dataset is valid'
)

msg = '\n---------------Config Information---------------\n'
msg += str(cfg)
msg += '------------------------------------------------'
logger.info(msg)

model = cfg.model
transforms = val_dataset.transforms
transforms = Compose(cfg.val_transforms)

image_list, image_dir = get_image_list(args.image_path)
if args.trimap_path is None:
Expand Down
35 changes: 25 additions & 10 deletions paddleseg/cvlibs/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,16 +300,19 @@ def model(self) -> paddle.nn.Layer:
raise RuntimeError('No model specified in the configuration file.')
if not 'num_classes' in model_cfg:
num_classes = None
if self.train_dataset_config:
if hasattr(self.train_dataset_class, 'NUM_CLASSES'):
num_classes = self.train_dataset_class.NUM_CLASSES
elif hasattr(self.train_dataset, 'num_classes'):
num_classes = self.train_dataset.num_classes
elif self.val_dataset_config:
if hasattr(self.val_dataset_class, 'NUM_CLASSES'):
num_classes = self.val_dataset_class.NUM_CLASSES
elif hasattr(self.val_dataset, 'num_classes'):
num_classes = self.val_dataset.num_classes
try:
if self.train_dataset_config:
if hasattr(self.train_dataset_class, 'NUM_CLASSES'):
num_classes = self.train_dataset_class.NUM_CLASSES
elif hasattr(self.train_dataset, 'num_classes'):
num_classes = self.train_dataset.num_classes
elif self.val_dataset_config:
if hasattr(self.val_dataset_class, 'NUM_CLASSES'):
num_classes = self.val_dataset_class.NUM_CLASSES
elif hasattr(self.val_dataset, 'num_classes'):
num_classes = self.val_dataset.num_classes
except FileNotFoundError:
michaelowenliu marked this conversation as resolved.
Show resolved Hide resolved
pass

if num_classes is not None:
model_cfg['num_classes'] = num_classes
Expand Down Expand Up @@ -402,3 +405,15 @@ def _is_meta_type(self, item: Any) -> bool:

def __str__(self) -> str:
return yaml.dump(self.dic)

@property
def val_transforms(self) -> list:
"""Get val_transform from val_dataset"""
_val_dataset = self.val_dataset_config
if not _val_dataset:
return []
_transforms = _val_dataset.get('transforms', [])
transforms = []
for i in _transforms:
transforms.append(self._load_object(i))
return transforms
11 changes: 3 additions & 8 deletions predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,9 @@
import paddle

from paddleseg.cvlibs import manager, Config
from paddleseg.utils import get_sys_env, logger, config_check, get_image_list
from paddleseg.utils import get_sys_env, logger, get_image_list
from paddleseg.core import predict
from paddleseg.transforms import Compose


def parse_args():
Expand Down Expand Up @@ -141,24 +142,18 @@ def main(args):
raise RuntimeError('No configuration file specified.')

cfg = Config(args.cfg)
val_dataset = cfg.val_dataset
if not val_dataset:
raise RuntimeError(
'The verification dataset is not specified in the configuration file.'
)

msg = '\n---------------Config Information---------------\n'
msg += str(cfg)
msg += '------------------------------------------------'
logger.info(msg)

model = cfg.model
transforms = val_dataset.transforms
transforms = Compose(cfg.val_transforms)
image_list, image_dir = get_image_list(args.image_path)
logger.info('Number of predict images = {}'.format(len(image_list)))

test_config = get_test_config(cfg, args)
config_check(cfg, val_dataset=val_dataset)

predict(
model,
Expand Down