Skip to content

Commit 9fc3a01

Browse files
committed
Routine updates.
1 parent 7eed9f1 commit 9fc3a01

File tree

1 file changed

+8
-6
lines changed

1 file changed

+8
-6
lines changed

validate.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,22 @@
11
from typing import Callable
22

33
from cv2 import imread, cvtColor, COLOR_RGB2GRAY
4-
from numpy import ndarray, logical_and, load
4+
from medpy.metric.binary import dc
5+
from numpy import ndarray, load, abs as npabs, max as npmax, sum as npsum
56
from rich.progress import Progress
67

78
from utils import get_items
89

910

1011
def calculate_dcs(a: ndarray, b: ndarray) -> float:
11-
a, b = a.astype(bool), b.astype(bool)
12-
return float(2 * logical_and(a, b).sum() / (a.sum() + b.sum()))
12+
return dc((a / npmax(a)) == 1, (b / npmax(b)) == 1)
1313

1414

1515
def calculate_nsd(a: ndarray, b: ndarray) -> float:
16-
return abs(a - b).sum() / max(a.sum(), b.sum())
16+
a, b = (a / npmax(a)).astype(int), (b / npmax(b)).astype(int)
17+
sum_diff = npsum(npabs(a - b))
18+
max_sum = max(npsum(a), npsum(b))
19+
return sum_diff / max_sum
1720

1821

1922
def evaluate(src: str, val: str, method: Callable[[ndarray, ndarray], float]) -> float:
@@ -25,8 +28,7 @@ def evaluate(src: str, val: str, method: Callable[[ndarray, ndarray], float]) ->
2528
for path in get_items(val):
2629
if not path.endswith(".npy"):
2730
continue
28-
r += method(load(f"{val}/{path}"),
29-
cvtColor(imread(f"{src}/case_{str(i).zfill(4)}.png"), COLOR_RGB2GRAY) / 256)
31+
r += method(load(f"{val}/{path}"), cvtColor(imread(f"{src}/case_{str(i).zfill(4)}.png"), COLOR_RGB2GRAY))
3032
i += 1
3133
progress.update(task, advance=1)
3234
return r / i

0 commit comments

Comments
 (0)