AZ Evaluation #124

Draft · wants to merge 6 commits into base: main
1 change: 1 addition & 0 deletions .gitignore
@@ -12,3 +12,4 @@ scripts/cooper/training/copy_testset.py
scripts/rizzoli/upsample_data.py
scripts/cooper/training/find_rec_testset.py
synapse-net-models/
scripts/portal/upscale_tomo.py
152 changes: 150 additions & 2 deletions scripts/cooper/ground_truth/az/evaluate_az.py
@@ -1,11 +1,15 @@
import os
import argparse

import h5py
import pandas as pd
import numpy as np
from elf.evaluation import dice_score

from torch_em.transform.label import connected_components
from scipy.ndimage import binary_dilation, binary_closing
from tqdm import tqdm
from elf.evaluation import matching
from elf.evaluation.matching import _compute_scores, _compute_tps


def _expand_AZ(az):
@@ -89,4 +93,148 @@ def main():
print(scores["Dice"].mean(), "+-", scores["Dice"].std())


main()
def get_bounding_box(mask, halo=2):
""" Get bounding box coordinates around a mask with a halo. """
coords = np.argwhere(mask)
if coords.size == 0:
return None # No labels found

min_coords = coords.min(axis=0)
max_coords = coords.max(axis=0)

min_coords = np.maximum(min_coords - halo, 0)
max_coords = np.minimum(max_coords + halo, mask.shape)

slices = tuple(slice(min_c, max_c) for min_c, max_c in zip(min_coords, max_coords))
return slices


def evaluate(labels, seg):
assert labels.shape == seg.shape
stats = matching(seg, labels)
    n_true, n_matched, n_pred, scores = _compute_scores(seg, labels, criterion="iou", ignore_label=None)
    tp = _compute_tps(scores, n_matched, threshold=0.5)
fp = n_pred - tp
fn = n_true - tp
return stats["f1"], stats["precision"], stats["recall"], tp, fp, fn


def evaluate_file(labels_path, seg_path, segment_key, anno_key, mask_key=None, cc=False):
print(f"Evaluating: {os.path.basename(labels_path)}")

ds_name = os.path.basename(os.path.dirname(labels_path))
tomo = os.path.basename(labels_path)

with h5py.File(labels_path) as label_file:
labels = label_file["labels"]
gt = labels[anno_key][:]
if mask_key is not None:
mask = labels[mask_key][:]

with h5py.File(seg_path) as seg_file:
az = seg_file["AZ"][segment_key][:]

if mask_key is not None:
gt[mask == 0] = 0
az[mask == 0] = 0

# Optionally apply connected components
if cc:
gt = connected_components(gt)
az = connected_components(az)

# Optionally crop to bounding box
use_bb = False
if use_bb:
bb_slices = get_bounding_box(gt, halo=2)
gt = gt[bb_slices]
az = az[bb_slices]

f1, precision, recall, tp, fp, fn = evaluate(gt, az)

return pd.DataFrame([[ds_name, tomo, f1, precision, recall, tp, fp, fn]],
columns=["dataset", "tomogram", "f1-score", "precision", "recall", "tp", "fp", "fn"])


def evaluate_folder(labels_path, seg_path, model_name, segment_key, anno_key, mask_key=None, cc=False):
print(f"\nEvaluating folder: {seg_path}")
label_files = os.listdir(labels_path)
seg_files = os.listdir(seg_path)

all_results = []

for seg_file in seg_files:
if seg_file in label_files:
label_fp = os.path.join(labels_path, seg_file)
seg_fp = os.path.join(seg_path, seg_file)

res = evaluate_file(label_fp, seg_fp, segment_key, anno_key, mask_key, cc)
all_results.append(res)

if not all_results:
print("No matched files found for evaluation.")
return

results_df = pd.concat(all_results, ignore_index=True)

# Convert TP, FP, FN to integers
results_df[["tp", "fp", "fn"]] = results_df[["tp", "fp", "fn"]].astype(int)

# Compute folder-level TP/FP/FN and final metrics
total_tp = results_df["tp"].sum()
total_fp = results_df["fp"].sum()
total_fn = results_df["fn"].sum()

precision = total_tp / (total_tp + total_fp) if (total_tp + total_fp) > 0 else 0.0
recall = total_tp / (total_tp + total_fn) if (total_tp + total_fn) > 0 else 0.0
f1_score = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0.0

folder_name = os.path.basename(seg_path)
print(f"\n Folder-Level Metrics for '{folder_name}':")
print(f"Precision: {precision:.4f}")
print(f"Recall: {recall:.4f}")
print(f"F1 Score: {f1_score:.4f}\n")

# Save results
result_dir = "/user/muth9/u12095/synapse-net/scripts/cooper/evaluation_results"
os.makedirs(result_dir, exist_ok=True)

# Per-file results
file_results_path = os.path.join(result_dir, f"evaluation_{model_name}_per_file.csv")
if os.path.exists(file_results_path):
existing_df = pd.read_csv(file_results_path)
results_df = pd.concat([existing_df, results_df], ignore_index=True)
results_df.to_csv(file_results_path, index=False)
print(f"Per-file results saved to {file_results_path}")

# Folder-level summary
folder_summary_path = os.path.join(result_dir, f"evaluation_{model_name}_per_folder.csv")
header_needed = not os.path.exists(folder_summary_path)
with open(folder_summary_path, "a") as f:
if header_needed:
f.write("folder,f1-score,precision,recall,tp,fp,fn\n")
f.write(f"{folder_name},{f1_score:.4f},{precision:.4f},{recall:.4f},{total_tp},{total_fp},{total_fn}\n")
print(f"Folder summary appended to {folder_summary_path}")


def main_f1():
parser = argparse.ArgumentParser()
parser.add_argument("-l", "--labels_path", required=True)
parser.add_argument("-seg", "--seg_path", required=True)
parser.add_argument("-n", "--model_name", required=True)
parser.add_argument("-sk", "--segment_key", required=True)
parser.add_argument("-ak", "--anno_key", required=True)
parser.add_argument("-m", "--mask_key")
parser.add_argument("--cc", "--connected_components", dest="cc", action="store_true",
help="Apply connected components to ground truth and segmentation before evaluation")

args = parser.parse_args()

if os.path.isdir(args.seg_path):
evaluate_folder(args.labels_path, args.seg_path, args.model_name, args.segment_key, args.anno_key, args.mask_key, args.cc)
else:
print("Please pass a folder to get folder-level metrics.")

if __name__ == "__main__":
    # main()  # for Dice score
main_f1()
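
The folder-level metrics in evaluate_folder are micro-averaged: true positives, false positives, and false negatives are summed over all tomograms before precision, recall, and F1 are derived, so tomograms with more objects weigh more. A minimal sketch of that aggregation, using hypothetical per-file counts:

# Micro-averaged folder metrics, mirroring evaluate_folder above.
# The (tp, fp, fn) tuples are hypothetical example values.
per_file_counts = [(12, 3, 2), (8, 1, 4), (15, 0, 1)]

total_tp = sum(tp for tp, _, _ in per_file_counts)
total_fp = sum(fp for _, fp, _ in per_file_counts)
total_fn = sum(fn for _, _, fn in per_file_counts)

# Guard against division by zero for empty folders, as the script does.
precision = total_tp / (total_tp + total_fp) if (total_tp + total_fp) > 0 else 0.0
recall = total_tp / (total_tp + total_fn) if (total_tp + total_fn) > 0 else 0.0
f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0.0

print(f"precision={precision:.4f}, recall={recall:.4f}, f1={f1:.4f}")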
48 changes: 48 additions & 0 deletions scripts/cooper/ground_truth/az/postprocess_az.py
@@ -0,0 +1,48 @@
import os
import glob
import h5py
import argparse
import sys
sys.path.append('/user/muth9/u12095/synapse-net')
from synapse_net.ground_truth.shape_refinement import edge_filter

def postprocess_file(file_path):
"""Processes a single .h5 file, applying an edge filter and saving the membrane mask."""#
print(f"processing file {file_path}")
with h5py.File(file_path, "a") as f:
raw = f["raw"][:]
print("applying the edge filter ...")
hmap = edge_filter(raw, sigma=1.0, method="sato", per_slice=True, n_threads=8)
membrane_mask = hmap > 0.5
print("saving results ....")
try:
f.create_dataset("labels/membrane_mask", data=membrane_mask, compression="gzip")
except:
print(f"membrane mask aleady saved for {file_path}")
print("Done!")

def postprocess_folder(folder_path):
"""Processes all .h5 files in a given folder recursively."""
files = sorted(glob.glob(os.path.join(folder_path, '**', '*.h5'), recursive=True))
print("Processing files:", files)

for file_path in files:
postprocess_file(file_path)

def main():
#/mnt/lustre-emmy-hdd/projects/nim00007/data/synaptic-reconstruction/cooper/exported_imod_objects/01_hoi_maus_2020_incomplete
#/mnt/lustre-emmy-hdd/projects/nim00007/data/synaptic-reconstruction/cooper/exported_imod_objects/04_hoi_stem_examples
#/mnt/lustre-emmy-hdd/projects/nim00007/data/synaptic-reconstruction/cooper/exported_imod_objects/06_hoi_wt_stem750_fm
#/mnt/lustre-emmy-hdd/projects/nim00007/data/synaptic-reconstruction/cooper/exported_imod_objects/12_chemical_fix_cryopreparation

parser = argparse.ArgumentParser(description="Postprocess .h5 files by applying edge filtering.")
parser.add_argument("-p", "--data_path", required=True, help="Path to the .h5 file or folder.")
args = parser.parse_args()

if os.path.isdir(args.data_path):
postprocess_folder(args.data_path)
else:
postprocess_file(args.data_path)

if __name__ == "__main__":
main()
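
The try/except around create_dataset in postprocess_file is needed because h5py refuses to create a dataset under a name that already exists. If rerunning should overwrite stale masks instead of skipping them, a delete-then-create pattern works; a minimal sketch, with illustrative file and dataset names:

import h5py
import numpy as np

mask = np.zeros((8, 8), dtype=bool)  # stand-in for a computed membrane mask
with h5py.File("example.h5", "a") as f:
    if "labels/membrane_mask" in f:
        del f["labels/membrane_mask"]  # drop the stale dataset first
    f.create_dataset("labels/membrane_mask", data=mask, compression="gzip")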
76 changes: 62 additions & 14 deletions scripts/cooper/training/evaluation.py
@@ -3,25 +3,40 @@

import h5py
import pandas as pd
import numpy as np

from elf.evaluation import matching

from elf.evaluation import matching, symmetric_best_dice_score

def get_bounding_box(mask, halo=2):
""" Get bounding box coordinates around a mask with a halo."""
coords = np.argwhere(mask)
if coords.size == 0:
return None # No labels found

min_coords = coords.min(axis=0)
max_coords = coords.max(axis=0)

min_coords = np.maximum(min_coords - halo, 0)
max_coords = np.minimum(max_coords + halo, mask.shape)

slices = tuple(slice(min_c, max_c) for min_c, max_c in zip(min_coords, max_coords))
return slices

def evaluate(labels, vesicles):
assert labels.shape == vesicles.shape
stats = matching(vesicles, labels)
return [stats["f1"], stats["precision"], stats["recall"]]
sbd = symmetric_best_dice_score(vesicles, labels)
return [stats["f1"], stats["precision"], stats["recall"], sbd]


def summarize_eval(results):
summary = results[["dataset", "f1-score", "precision", "recall"]].groupby("dataset").mean().reset_index("dataset")
total = results[["f1-score", "precision", "recall"]].mean().values.tolist()
summary = results[["dataset", "f1-score", "precision", "recall", "SBD score"]].groupby("dataset").mean().reset_index("dataset")
total = results[["f1-score", "precision", "recall", "SBD score"]].mean().values.tolist()
summary.iloc[-1] = ["all"] + total
table = summary.to_markdown(index=False)
print(table)

def evaluate_file(labels_path, vesicles_path, model_name, segment_key, anno_key):
def evaluate_file(labels_path, vesicles_path, model_name, segment_key, anno_key, mask_key=None):
print(f"Evaluate labels {labels_path} and vesicles {vesicles_path}")

ds_name = os.path.basename(os.path.dirname(labels_path))
@@ -33,16 +48,31 @@ def evaluate_file(labels_path, vesicles_path, model_name, segment_key, anno_key)
#vesicles = labels["vesicles"]
gt = labels[anno_key][:]

if mask_key is not None:
mask = labels[mask_key][:]

with h5py.File(vesicles_path) as seg_file:
segmentation = seg_file["vesicles"]
vesicles = segmentation[segment_key][:]


#evaluate the match of ground truth and vesicles
if mask_key is not None:
gt[mask == 0] = 0
vesicles[mask == 0] = 0

    bb = False
if bb:
# Get bounding box and crop
bb_slices = get_bounding_box(gt, halo=2)
gt = gt[bb_slices]
vesicles = vesicles[bb_slices]
else:
print("not using bb")

#evaluate the match of ground truth and vesicles
scores = evaluate(gt, vesicles)

#store results
result_folder ="/user/muth9/u12095/synaptic-reconstruction/scripts/cooper/evaluation_results"
result_folder ="/user/muth9/u12095/synapse-net/scripts/cooper/evaluation_results"
os.makedirs(result_folder, exist_ok=True)
result_path=os.path.join(result_folder, f"evaluation_{model_name}.csv")
print("Evaluation results are saved to:", result_path)
@@ -53,7 +83,7 @@ def evaluate_file(labels_path, vesicles_path, model_name, segment_key, anno_key)
results = None

res = pd.DataFrame(
[[ds_name, tomo] + scores], columns=["dataset", "tomogram", "f1-score", "precision", "recall"]
[[ds_name, tomo] + scores], columns=["dataset", "tomogram", "f1-score", "precision", "recall", "SBD score"]
)
if results is None:
results = res
Expand All @@ -65,7 +95,7 @@ def evaluate_file(labels_path, vesicles_path, model_name, segment_key, anno_key)
summarize_eval(results)


def evaluate_folder(labels_path, vesicles_path, model_name, segment_key, anno_key):
def evaluate_folder(labels_path, vesicles_path, model_name, segment_key, anno_key, mask_key=None):
print(f"Evaluating folder {vesicles_path}")
print(f"Using labels stored in {labels_path}")

@@ -75,9 +105,25 @@ def evaluate_folder(labels_path, vesicles_path, model_name, segment_key, anno_key)
for vesicle_file in vesicles_files:
if vesicle_file in label_files:

evaluate_file(os.path.join(labels_path, vesicle_file), os.path.join(vesicles_path, vesicle_file), model_name, segment_key, anno_key)
evaluate_file(os.path.join(labels_path, vesicle_file), os.path.join(vesicles_path, vesicle_file), model_name, segment_key, anno_key, mask_key)

def evaluate_folder_onlyGTAnnotations(labels_path, vesicles_path, model_name, segment_key, anno_key, mask_key=None):
print(f"Evaluating folder {vesicles_path}")
print(f"Using labels stored in {labels_path}")

label_files = set(os.listdir(labels_path))
vesicles_files = os.listdir(vesicles_path)

for vesicle_file in vesicles_files:
if vesicle_file.endswith('_processed.h5'):
# Remove '_processed' to get the corresponding label file name
label_file = vesicle_file.replace('_processed', '')
if label_file in label_files:
evaluate_file(
os.path.join(labels_path, label_file),
os.path.join(vesicles_path, vesicle_file),
model_name, segment_key, anno_key, mask_key
)

def main():

@@ -87,13 +133,15 @@ def main():
parser.add_argument("-n", "--model_name", required=True)
parser.add_argument("-sk", "--segment_key", required=True)
parser.add_argument("-ak", "--anno_key", required=True)
parser.add_argument("-m", "--mask_key")
args = parser.parse_args()

vesicles_path = args.vesicles_path
if os.path.isdir(vesicles_path):
evaluate_folder(args.labels_path, vesicles_path, args.model_name, args.segment_key, args.anno_key)
evaluate_folder(args.labels_path, vesicles_path, args.model_name, args.segment_key, args.anno_key, args.mask_key)
#evaluate_folder_onlyGTAnnotations(args.labels_path, vesicles_path, args.model_name, args.segment_key, args.anno_key, args.mask_key)
else:
evaluate_file(args.labels_path, vesicles_path, args.model_name, args.segment_key, args.anno_key)
evaluate_file(args.labels_path, vesicles_path, args.model_name, args.segment_key, args.anno_key, args.mask_key)
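
get_bounding_box, added in both evaluation scripts, restricts the comparison to the annotated region plus a small halo; cropping ground truth and segmentation with the same slices keeps them aligned. A minimal usage sketch on synthetic volumes, mirroring the helper's logic:

import numpy as np

# Synthetic stand-ins for a ground-truth annotation and a segmentation.
gt = np.zeros((32, 32, 32), dtype=np.uint8)
gt[10:15, 12:18, 8:20] = 1
seg = np.zeros_like(gt)
seg[11:16, 12:17, 9:19] = 1

# Same logic as get_bounding_box(mask, halo=2) above.
coords = np.argwhere(gt)
min_coords = np.maximum(coords.min(axis=0) - 2, 0)
max_coords = np.minimum(coords.max(axis=0) + 2, gt.shape)
bb = tuple(slice(lo, hi) for lo, hi in zip(min_coords, max_coords))

# Crop both volumes with the same slices so they stay aligned.
gt_crop, seg_crop = gt[bb], seg[bb]
print(gt_crop.shape, seg_crop.shape)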


