Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Refactor] refactor data flow and engine library #1054

Merged
merged 26 commits into from
Aug 31, 2022
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
refactor tools
  • Loading branch information
zengyh1900 committed Aug 30, 2022
commit 0914393e1e6dacc85d05f4060e412dbdd8f1bdeb
13 changes: 8 additions & 5 deletions mmedit/apis/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,15 @@
from .restoration_face_inference import restoration_face_inference
from .restoration_inference import restoration_inference
from .restoration_video_inference import restoration_video_inference
from .test import multi_gpu_test, single_gpu_test
from .video_interpolation_inference import video_interpolation_inference

__all__ = [
'init_model', 'matting_inference', 'inpainting_inference',
'restoration_inference', 'restoration_video_inference',
'restoration_face_inference', 'video_interpolation_inference',
'multi_gpu_test', 'single_gpu_test', 'delete_cfg'
'init_model',
'delete_cfg',
'matting_inference',
'inpainting_inference',
'restoration_inference',
'restoration_video_inference',
'restoration_face_inference',
'video_interpolation_inference',
]
5 changes: 3 additions & 2 deletions mmedit/apis/inference.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
import mmcv
from mmcv.runner import load_checkpoint
from mmengine.config import ConfigDict
from mmengine.runner import load_checkpoint

from mmedit.registry import MODELS
from mmedit.utils import register_all_modules
Expand All @@ -10,7 +11,7 @@ def delete_cfg(cfg, key='init_cfg'):
if key in cfg:
cfg.pop(key)
for _key in cfg.keys():
if isinstance(cfg[_key], mmcv.utils.config.ConfigDict):
if isinstance(cfg[_key], ConfigDict):
delete_cfg(cfg[_key], key)


Expand Down
234 changes: 0 additions & 234 deletions mmedit/apis/test.py

This file was deleted.

2 changes: 1 addition & 1 deletion mmedit/datasets/basic_frames_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import os.path as osp
from typing import Callable, List, Optional, Union

from mmengine import BaseDataset
from mmengine.dataset import BaseDataset
from mmengine.fileio import FileClient, list_from_file

from ..registry import DATASETS
Expand Down
2 changes: 1 addition & 1 deletion mmedit/datasets/basic_image_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import re
from typing import Callable, List, Optional, Tuple, Union

from mmengine import BaseDataset
from mmengine.dataset import BaseDataset
from mmengine.fileio import FileClient, list_from_file

from mmedit.registry import DATASETS
Expand Down
2 changes: 1 addition & 1 deletion mmedit/datasets/transforms/trans_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import mmcv
import numpy as np
import torch
from mmcv.utils import print_log
from mmengine.logging import print_log
from PIL import Image, ImageDraw


Expand Down
6 changes: 3 additions & 3 deletions mmedit/engine/hooks/ema.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@

import mmcv
import torch
from mmcv.parallel import is_module_wrapper
from mmengine.hooks import Hook
from mmengine.model.wrappers import is_model_wrapper
from mmengine.registry import HOOKS
from mmengine.runner import Runner

Expand Down Expand Up @@ -86,7 +86,7 @@ def after_train_iter(self,
if not self.every_n_iters(runner, self.interval):
return

model = runner.model.module if is_module_wrapper(
model = runner.model.module if is_model_wrapper(
runner.model) else runner.model

for key in self.module_keys:
Expand All @@ -106,7 +106,7 @@ def after_train_iter(self,
ema_net.load_state_dict(states_ema, strict=True)

def before_run(self, runner):
model = runner.model.module if is_module_wrapper(
model = runner.model.module if is_model_wrapper(
runner.model) else runner.model
# sanity check for ema model
for k in self.module_keys:
Expand Down
2 changes: 1 addition & 1 deletion mmedit/engine/hooks/visualization_hook.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional, Sequence

from mmengine.data import BaseDataElement
from mmengine.hooks import Hook
from mmengine.registry import HOOKS
from mmengine.structures import BaseDataElement


@HOOKS.register_module()
Expand Down
2 changes: 1 addition & 1 deletion mmedit/models/base_models/base_backbone.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch.nn as nn
from mmcv.runner import load_checkpoint
from mmengine import MMLogger
from mmengine.runner import load_checkpoint

from mmedit.registry import BACKBONES

Expand Down
3 changes: 1 addition & 2 deletions mmedit/models/base_models/base_mattor.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,7 @@

import torch
import torch.nn.functional as F
from mmcv import ConfigDict
from mmengine.config import Config
from mmengine.config import Config, ConfigDict
from mmengine.model import BaseModel

from mmedit.registry import MODELS
Expand Down
2 changes: 1 addition & 1 deletion mmedit/models/base_models/multi_layer_disc.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch.nn as nn
from mmcv.cnn import ConvModule
from mmcv.runner import load_checkpoint
from mmengine import MMLogger
from mmengine.runner import load_checkpoint

from mmedit.models.layers import LinearModule
from mmedit.registry import COMPONENTS
Expand Down
2 changes: 1 addition & 1 deletion mmedit/models/base_models/patch_disc.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch.nn as nn
from mmcv.cnn import ConvModule, build_conv_layer
from mmcv.runner import load_checkpoint
from mmengine import MMLogger
from mmengine.runner import load_checkpoint

from mmedit.models.utils import generation_init_weights
from mmedit.registry import COMPONENTS
Expand Down
6 changes: 3 additions & 3 deletions mmedit/models/base_models/resnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@
import torch.nn as nn
import torch.utils.checkpoint as cp
from mmcv.cnn import build_activation_layer, build_conv_layer, build_norm_layer
from mmcv.runner import load_checkpoint
from mmcv.utils.parrots_wrapper import _BatchNorm
from mmengine import MMLogger
from mmengine.model.utils import constant_init, kaiming_init
from mmengine.model.weight_init import constant_init, kaiming_init
from mmengine.runner import load_checkpoint
from mmengine.utils.dl_utils.parrots_wrapper import _BatchNorm


class BasicBlock(nn.Module):
Expand Down
Loading