Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Bump to 1.0.0rc3. #1425

Merged
merged 63 commits into from
Nov 10, 2022
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
63 commits
Select commit Hold shift + click to select a range
9eff9b1
[Improve] add a version option in docs menu (#1162)
zengyh1900 Oct 8, 2022
f6886a1
[Enhance] update dev_scripts for link checking (#1164)
zengyh1900 Oct 8, 2022
944d3a8
[Refactoring] decompose the implementations of different metrics into…
zengyh1900 Oct 9, 2022
5a080ff
[Fix] Fix PPL bug (#1172)
plyfager Oct 9, 2022
0fdd4a8
[Fix] Fix some known bugs. (#1200)
LeoXing1996 Oct 10, 2022
11dcf18
[Fix] Benchmark related bugs (#1236)
plyfager Oct 11, 2022
1ca720d
[Enhancement] Support rerun failed or canceled jobs in `train_benchma…
LeoXing1996 Oct 11, 2022
e25023b
[Fix] Fix bugs in `sr test config`, `realbasicvsr config` and `pconv…
Z-Fran Oct 12, 2022
2547173
[Fix] fix test of Vid4 datasets bug (#1293)
Z-Fran Oct 14, 2022
fcd43a6
[Feature] Support multi-metrics with different sample-model (#1171)
LeoXing1996 Oct 14, 2022
5a38f5b
[Fix] fix GenerateSegmentIndices ut (#1302)
Z-Fran Oct 14, 2022
f33a09d
[Enhancement] Reduce the randomness in unit test of `stylegan3_utils.…
LeoXing1996 Oct 17, 2022
97d74d8
[CI] Fix GitHub windows CI (#1320)
LeoXing1996 Oct 17, 2022
836e180
[Fix] fix basicvsr++ mirror sequence bug (#1304)
Z-Fran Oct 17, 2022
591b666
[Fix] fix sisr-test psnr config (#1319)
Z-Fran Oct 17, 2022
1f71028
[Fix] fix vsr models pytorch2onnx (#1300)
Z-Fran Oct 17, 2022
3721df4
[Bug] Ensure the output type of `GenerateFacialHeatmap` is `np.float3…
LeoXing1996 Oct 17, 2022
013f106
[Bug] Fix sampling behavior of `unpaired_dataset.py` and urls in cyc…
LeoXing1996 Oct 17, 2022
c844581
[README] Fix TTSR's README (#1325)
LeoXing1996 Oct 17, 2022
1dd0529
[CI] Update `paths-ignore` for GitHub CI (#1327)
LeoXing1996 Oct 17, 2022
6498c61
[Bug] Save gt images in PGGAN's `forward` (#1328)
LeoXing1996 Oct 18, 2022
1e741ea
[Bug] Correct RDN number of channels (#1332)
ryanxingql Oct 18, 2022
d1725ae
[Bug] Revise flip transformation in some conditional gan's setting (#…
LeoXing1996 Oct 19, 2022
6261f9e
[Unit Test] Fix unit test of SNR (#1335)
LeoXing1996 Oct 19, 2022
20e5850
[Bug] Revise flavr config (#1336)
LeoXing1996 Oct 20, 2022
b90ac53
[Fix] fix realesrgan ema (#1341)
Z-Fran Oct 21, 2022
7713e5c
[Fix] Fix bugs find during benchmark running (#1348)
plyfager Oct 24, 2022
4cf90a6
[Fix] fix liif test config (#1353)
Z-Fran Oct 25, 2022
4353ab8
[Enhancement] Complete save_best in configs (#1349)
plyfager Oct 28, 2022
eb9fab5
[Config] Revise discriminator's learning rate of TTSR to align with 0…
LeoXing1996 Oct 28, 2022
38f805a
[Fix] fix edsr configs (#1367)
Z-Fran Oct 30, 2022
da10b0f
[Enhancement] Add pixel value clip in visualizer (#1365)
LeoXing1996 Oct 30, 2022
6653aa6
[Bug] Fix randomness in FixedCrop + add L1 loss in Pix2Pix (#1364)
LeoXing1996 Oct 30, 2022
4b55cdd
[Fix] fix realbasicvsr config (#1358)
Z-Fran Oct 30, 2022
a576d02
[Enhancement] Fix PESinGAN-inter-pad setting + add SinGAN Dataset + a…
LeoXing1996 Oct 30, 2022
6d83fb6
[Fix] fix types of exceptions in demos (#1372)
gaoyang07 Oct 31, 2022
3375fff
[Enhancement] Support deterministic training in benchmark (#1356)
LeoXing1996 Oct 31, 2022
9c0768f
[Fix] Avoid cast int and float in GenDataPreprocessor (#1385)
LeoXing1996 Oct 31, 2022
71961a7
[Config] Update metric config in ggan (#1386)
LeoXing1996 Oct 31, 2022
c3cff1d
[Config] Revise batch size in wang-gp's config (#1384)
LeoXing1996 Oct 31, 2022
bf7714e
[Fix]: add type and change default number of preprocess_div2k_dataset…
ruoningYu Oct 31, 2022
00c55b5
[Feature] Support qualitative comparison tools (#1303)
Z-Fran Nov 2, 2022
525767c
[Docs] Revise docs (change PackGenInputs and GenDataSample to mmediti…
LeoXing1996 Nov 2, 2022
424c48c
[Config] Revise Pix2Pix edges2shoes config (#1391)
LeoXing1996 Nov 2, 2022
73ba837
[Bug] fix rdn and srcnn train configs (#1392)
Z-Fran Nov 2, 2022
239bdf1
[Fix] Fix test/val pipeline of pegan configs (#1393)
plyfager Nov 2, 2022
5f2a4c0
[Fix] Modify Readme of S3 (#1398)
plyfager Nov 3, 2022
31fe019
[Fix] Correct fid of ggan (#1397)
plyfager Nov 3, 2022
3575ba9
[Feature] support instance_aware_colorization inference (#1370)
zengyh1900 Nov 3, 2022
ec162c2
Merge branch '1.x' of github.com:open-mmlab/mmediting into dev-1.x
zengyh1900 Nov 4, 2022
76722c1
[Bug] fix cain config (#1404)
Z-Fran Nov 4, 2022
c93fb02
[Fix] Revise config and pretrain model loading in esrgan (#1407)
LeoXing1996 Nov 4, 2022
3a18501
[Fix] Fix lsgan config (#1409)
plyfager Nov 4, 2022
b70695d
[Enhancement] Support `try_import` for `mmdet` (#1408)
LeoXing1996 Nov 4, 2022
3349986
[Enhancement] Set ``real_feat`` to cpu in inception_utils (#1415)
plyfager Nov 7, 2022
b616736
[Enhancement] git ignore slurm generated files (#1416)
plyfager Nov 7, 2022
cb8b6df
[Fix] modify readme and configs of stylegan2&pegan (#1418)
plyfager Nov 8, 2022
ce064f4
[Enhancement] Support try-import for `clip` (#1420)
LeoXing1996 Nov 9, 2022
17e6d13
[Enhancement]: Improve the rendering of Docs-API (#1373)
ruoningYu Nov 9, 2022
807239d
[Fix] Complete requirements (#1419)
plyfager Nov 9, 2022
a32d838
[Doc] Update changelog and README for 1.0.0rc3. (#1421)
LeoXing1996 Nov 10, 2022
454b914
[Bug] Install clip in merge stage test (#1423)
LeoXing1996 Nov 10, 2022
4f8df7a
[Fix] Install clip in windows CI (#1424)
LeoXing1996 Nov 10, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
[Refactoring] decompose the implementations of different metrics into…
… several files (#1161)

* refactor metrics
* add UT for refactored metrics
  • Loading branch information
zengyh1900 authored Oct 9, 2022
commit 944d3a891df3ecc3fe26eb45e048c7129736579d
31 changes: 18 additions & 13 deletions mmedit/evaluation/metrics/__init__.py
Original file line number Diff line number Diff line change
@@ -1,40 +1,45 @@
# Copyright (c) OpenMMLab. All rights reserved.

from .connectivity_error import ConnectivityError
from .equivariance import Equivariance
from .fid import FrechetInceptionDistance, TransFID
from .gradient_error import GradientError
from .inception_score import InceptionScore, TransIS
from .matting import SAD, ConnectivityError, GradientError, MattingMSE
from .mae import MAE
from .matting_mse import MattingMSE
from .ms_ssim import MultiScaleStructureSimilarity
from .mse import MSE
from .niqe import NIQE, niqe
from .pixel_metrics import MAE, MSE, PSNR, SNR, psnr, snr
from .ppl import PerceptualPathLength
from .precision_and_recall import PrecisionAndRecall
from .psnr import PSNR, psnr
from .sad import SAD
from .snr import SNR, snr
from .ssim import SSIM, ssim
from .swd import SlicedWassersteinDistance

__all__ = [
'ConnectivityError',
'GradientError',
'MAE',
'MattingMSE',
'MSE',
'NIQE',
'niqe',
'PSNR',
'psnr',
'SAD',
'SNR',
'snr',
'SSIM',
'ssim',
'Equivariance',
'MultiScaleStructureSimilarity',
'FrechetInceptionDistance',
'TransFID',
'InceptionScore',
'MultiScaleStructureSimilarity',
'TransIS',
'SAD',
'MattingMSE',
'ConnectivityError',
'GradientError',
'PerceptualPathLength',
'MultiScaleStructureSimilarity',
'PrecisionAndRecall',
'SlicedWassersteinDistance',
'TransFID',
'TransIS',
'NIQE',
'niqe',
'Equivariance',
]
117 changes: 117 additions & 0 deletions mmedit/evaluation/metrics/connectivity_error.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
# Copyright (c) OpenMMLab. All rights reserved.
"""Evaluation metrics used in Image Matting."""

from typing import List, Sequence

import cv2
import numpy as np
from mmengine.evaluator import BaseMetric

from mmedit.registry import METRICS
from .metrics_utils import _fetch_data_and_check, average


@METRICS.register_module()
class ConnectivityError(BaseMetric):
"""Connectivity error for evaluating alpha matte prediction.

.. note::

Current implementation assume image / alpha / trimap array in numpy
format and with pixel value ranging from 0 to 255.

.. note::

pred_alpha should be masked by trimap before passing
into this metric

Args:
step (float): Step of threshold when computing intersection between
`alpha` and `pred_alpha`. Default to 0.1 .
norm_const (int): Divide the result to reduce its magnitude.
Default to 1000.

Default prefix: ''

Metrics:
- ConnectivityError (float): Connectivity Error
"""

def __init__(
self,
step=0.1,
norm_constant=1000,
**kwargs,
) -> None:
self.step = step
self.norm_constant = norm_constant
super().__init__(**kwargs)

def process(self, data_batch: Sequence[dict],
data_samples: Sequence[dict]) -> None:
"""Process one batch of data samples and predictions. The processed
results should be stored in ``self.results``, which will be used to
compute the metrics when all batches have been processed.

Args:
data_batch (Sequence[dict]): A batch of data from the dataloader.
predictions (Sequence[dict]): A batch of outputs from
the model.
"""

for data_sample in data_samples:
pred_alpha, gt_alpha, trimap = _fetch_data_and_check(data_sample)

thresh_steps = np.arange(0, 1 + self.step, self.step)
round_down_map = -np.ones_like(gt_alpha)
for i in range(1, len(thresh_steps)):
gt_alpha_thresh = gt_alpha >= thresh_steps[i]
pred_alpha_thresh = pred_alpha >= thresh_steps[i]
intersection = gt_alpha_thresh & pred_alpha_thresh
intersection = intersection.astype(np.uint8)

# connected components
_, output, stats, _ = cv2.connectedComponentsWithStats(
intersection, connectivity=4)
# start from 1 in dim 0 to exclude background
size = stats[1:, -1]

# largest connected component of the intersection
omega = np.zeros_like(gt_alpha)
if len(size) != 0:
max_id = np.argmax(size)
# plus one to include background
omega[output == max_id + 1] = 1

mask = (round_down_map == -1) & (omega == 0)
round_down_map[mask] = thresh_steps[i - 1]
round_down_map[round_down_map == -1] = 1

gt_alpha_diff = gt_alpha - round_down_map
pred_alpha_diff = pred_alpha - round_down_map
# only calculate difference larger than or equal to 0.15
gt_alpha_phi = 1 - gt_alpha_diff * (gt_alpha_diff >= 0.15)
pred_alpha_phi = 1 - pred_alpha_diff * (pred_alpha_diff >= 0.15)

connectivity_error = np.sum(
np.abs(gt_alpha_phi - pred_alpha_phi) * (trimap == 128))

# divide by 1000 to reduce the magnitude of the result
connectivity_error /= self.norm_constant

self.results.append({'conn_err': connectivity_error})

def compute_metrics(self, results: List):
"""Compute the metrics from processed results.

Args:
results (dict): The processed results of each batch.

Returns:
Dict: The computed metrics. The keys are the names of the metrics,
and the values are corresponding results.
"""

conn_err = average(results, 'conn_err')

return {'ConnectivityError': conn_err}
95 changes: 95 additions & 0 deletions mmedit/evaluation/metrics/gradient_error.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Sequence

import cv2
import numpy as np
from mmengine.evaluator import BaseMetric

from mmedit.registry import METRICS
from ..functional import gauss_gradient
from .metrics_utils import _fetch_data_and_check, average


@METRICS.register_module()
class GradientError(BaseMetric):
"""Gradient error for evaluating alpha matte prediction.

.. note::

Current implementation assume image / alpha / trimap array in numpy
format and with pixel value ranging from 0 to 255.

.. note::

pred_alpha should be masked by trimap before passing
into this metric

Args:
sigma (float): Standard deviation of the gaussian kernel.
Defaults to 1.4 .
norm_const (int): Divide the result to reduce its magnitude.
Defaults to 1000 .

Default prefix: ''

Metrics:
- GradientError (float): Gradient Error
"""

def __init__(
self,
sigma=1.4,
norm_constant=1000,
**kwargs,
) -> None:
self.sigma = sigma
self.norm_constant = norm_constant
super().__init__(**kwargs)

def process(self, data_batch: Sequence[dict],
data_samples: Sequence[dict]) -> None:
"""Process one batch of data samples and predictions. The processed
results should be stored in ``self.results``, which will be used to
compute the metrics when all batches have been processed.

Args:
data_batch (Sequence[dict]): A batch of data from the dataloader.
predictions (Sequence[dict]): A batch of outputs from
the model.
"""

for data_sample in data_samples:
pred_alpha, gt_alpha, trimap = _fetch_data_and_check(data_sample)

gt_alpha_normed = np.zeros_like(gt_alpha)
pred_alpha_normed = np.zeros_like(pred_alpha)

cv2.normalize(gt_alpha, gt_alpha_normed, 1.0, 0.0, cv2.NORM_MINMAX)
cv2.normalize(pred_alpha, pred_alpha_normed, 1.0, 0.0,
cv2.NORM_MINMAX)

gt_alpha_grad = gauss_gradient(gt_alpha_normed, self.sigma)
pred_alpha_grad = gauss_gradient(pred_alpha_normed, self.sigma)
# this is the sum over n samples
grad_loss = ((gt_alpha_grad - pred_alpha_grad)**2 *
(trimap == 128)).sum()

# divide by 1000 to reduce the magnitude of the result
grad_loss /= self.norm_constant

self.results.append({'grad_err': grad_loss})

def compute_metrics(self, results: List):
"""Compute the metrics from processed results.

Args:
results (dict): The processed results of each batch.

Returns:
Dict: The computed metrics. The keys are the names of the metrics,
and the values are corresponding results.
"""

grad_err = average(results, 'grad_err')

return {'GradientError': grad_err}
60 changes: 60 additions & 0 deletions mmedit/evaluation/metrics/mae.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
# Copyright (c) OpenMMLab. All rights reserved.
"""Evaluation metrics based on pixels."""

import numpy as np

from mmedit.registry import METRICS
from .base_sample_wise_metric import BaseSampleWiseMetric


@METRICS.register_module()
class MAE(BaseSampleWiseMetric):
"""Mean Absolute Error metric for image.

mean(abs(a-b))

Args:

gt_key (str): Key of ground-truth. Default: 'gt_img'
pred_key (str): Key of prediction. Default: 'pred_img'
mask_key (str, optional): Key of mask, if mask_key is None, calculate
all regions. Default: None
collect_device (str): Device name used for collecting results from
different ranks during distributed training. Must be 'cpu' or
'gpu'. Defaults to 'cpu'.
prefix (str, optional): The prefix that will be added in the metric
names to disambiguate homonymous metrics of different evaluators.
If prefix is not provided in the argument, self.default_prefix
will be used instead. Default: None

Metrics:
- MAE (float): Mean of Absolute Error
"""

metric = 'MAE'

def process_image(self, gt, pred, mask):
"""Process an image.

Args:
gt (Tensor | np.ndarray): GT image.
pred (Tensor | np.ndarray): Pred image.
mask (Tensor | np.ndarray): Mask of evaluation.
Returns:
result (np.ndarray): MAE result.
"""

gt = gt / 255.
pred = pred / 255.

diff = gt - pred
diff = abs(diff)

if self.mask_key is not None:
diff *= mask # broadcast for channel dimension
scale = np.prod(diff.shape) / np.prod(mask.shape)
result = diff.sum() / (mask.sum() * scale + 1e-12)
else:
result = diff.mean()

return result
Loading