Skip to content

Commit

Permalink
refactor: unify on num instead of nb or n (#2090)
Browse files Browse the repository at this point in the history
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
Borda and pre-commit-ci[bot] authored Sep 21, 2023
1 parent 6bf705e commit 25bf259
Show file tree
Hide file tree
Showing 35 changed files with 274 additions and 270 deletions.
4 changes: 2 additions & 2 deletions src/torchmetrics/audio/pesq.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,8 @@ class PerceptualEvaluationSpeechQuality(Metric):
>>> g = torch.manual_seed(1)
>>> preds = torch.randn(8000)
>>> target = torch.randn(8000)
>>> nb_pesq = PerceptualEvaluationSpeechQuality(8000, 'nb')
>>> nb_pesq(preds, target)
>>> pesq = PerceptualEvaluationSpeechQuality(8000, 'nb')
>>> pesq(preds, target)
tensor(2.2076)
>>> wb_pesq = PerceptualEvaluationSpeechQuality(16000, 'wb')
>>> wb_pesq(preds, target)
Expand Down
12 changes: 6 additions & 6 deletions src/torchmetrics/classification/ranking.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,9 +106,9 @@ def update(self, preds: Tensor, target: Tensor) -> None:
preds, target = _multilabel_confusion_matrix_format(
preds, target, self.num_labels, threshold=0.0, ignore_index=self.ignore_index, should_threshold=False
)
measure, n_elements = _multilabel_coverage_error_update(preds, target)
measure, num_elements = _multilabel_coverage_error_update(preds, target)
self.measure += measure
self.total += n_elements
self.total += num_elements

def compute(self) -> Tensor:
"""Compute metric."""
Expand Down Expand Up @@ -226,9 +226,9 @@ def update(self, preds: Tensor, target: Tensor) -> None:
preds, target = _multilabel_confusion_matrix_format(
preds, target, self.num_labels, threshold=0.0, ignore_index=self.ignore_index, should_threshold=False
)
measure, n_elements = _multilabel_ranking_average_precision_update(preds, target)
measure, num_elements = _multilabel_ranking_average_precision_update(preds, target)
self.measure += measure
self.total += n_elements
self.total += num_elements

def compute(self) -> Tensor:
"""Compute metric."""
Expand Down Expand Up @@ -348,9 +348,9 @@ def update(self, preds: Tensor, target: Tensor) -> None:
preds, target = _multilabel_confusion_matrix_format(
preds, target, self.num_labels, threshold=0.0, ignore_index=self.ignore_index, should_threshold=False
)
measure, n_elements = _multilabel_ranking_loss_update(preds, target)
measure, num_elements = _multilabel_ranking_loss_update(preds, target)
self.measure += measure
self.total += n_elements
self.total += num_elements

def compute(self) -> Tensor:
"""Compute metric."""
Expand Down
8 changes: 4 additions & 4 deletions src/torchmetrics/collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,7 @@ def _merge_compute_groups(self) -> None:
``O(number_of_metrics_in_collection ** 2)``, as all metrics need to be compared to all other metrics.
"""
n_groups = len(self._groups)
num_groups = len(self._groups)
while True:
for cg_idx1, cg_members1 in deepcopy(self._groups).items():
for cg_idx2, cg_members2 in deepcopy(self._groups).items():
Expand All @@ -247,13 +247,13 @@ def _merge_compute_groups(self) -> None:
break

# Start over if we merged groups
if len(self._groups) != n_groups:
if len(self._groups) != num_groups:
break

# Stop when we iterate over everything and do not merge any groups
if len(self._groups) == n_groups:
if len(self._groups) == num_groups:
break
n_groups = len(self._groups)
num_groups = len(self._groups)

# Re-index groups
temp = deepcopy(self._groups)
Expand Down
98 changes: 52 additions & 46 deletions src/torchmetrics/detection/_mean_ap.py
Original file line number Diff line number Diff line change
Expand Up @@ -452,37 +452,43 @@ def _compute_iou(self, idx: int, class_id: int, max_det: int) -> Tensor:
return compute_iou(det, gt, self.iou_type).to(self.device)

def __evaluate_image_gt_no_preds(
self, gt: Tensor, gt_label_mask: Tensor, area_range: Tuple[int, int], nb_iou_thrs: int
self, gt: Tensor, gt_label_mask: Tensor, area_range: Tuple[int, int], num_iou_thrs: int
) -> Dict[str, Any]:
"""Evaluate images with a ground truth but no predictions."""
# GTs
gt = [gt[i] for i in gt_label_mask]
nb_gt = len(gt)
num_gt = len(gt)
areas = compute_area(gt, iou_type=self.iou_type).to(self.device)
ignore_area = (areas < area_range[0]) | (areas > area_range[1])
gt_ignore, _ = torch.sort(ignore_area.to(torch.uint8))
gt_ignore = gt_ignore.to(torch.bool)

# Detections
nb_det = 0
det_ignore = torch.zeros((nb_iou_thrs, nb_det), dtype=torch.bool, device=self.device)
num_det = 0
det_ignore = torch.zeros((num_iou_thrs, num_det), dtype=torch.bool, device=self.device)

return {
"dtMatches": torch.zeros((nb_iou_thrs, nb_det), dtype=torch.bool, device=self.device),
"gtMatches": torch.zeros((nb_iou_thrs, nb_gt), dtype=torch.bool, device=self.device),
"dtScores": torch.zeros(nb_det, dtype=torch.float32, device=self.device),
"dtMatches": torch.zeros((num_iou_thrs, num_det), dtype=torch.bool, device=self.device),
"gtMatches": torch.zeros((num_iou_thrs, num_gt), dtype=torch.bool, device=self.device),
"dtScores": torch.zeros(num_det, dtype=torch.float32, device=self.device),
"gtIgnore": gt_ignore,
"dtIgnore": det_ignore,
}

def __evaluate_image_preds_no_gt(
self, det: Tensor, idx: int, det_label_mask: Tensor, max_det: int, area_range: Tuple[int, int], nb_iou_thrs: int
self,
det: Tensor,
idx: int,
det_label_mask: Tensor,
max_det: int,
area_range: Tuple[int, int],
num_iou_thrs: int,
) -> Dict[str, Any]:
"""Evaluate images with a prediction but no ground truth."""
# GTs
nb_gt = 0
num_gt = 0

gt_ignore = torch.zeros(nb_gt, dtype=torch.bool, device=self.device)
gt_ignore = torch.zeros(num_gt, dtype=torch.bool, device=self.device)

# Detections

Expand All @@ -494,15 +500,15 @@ def __evaluate_image_preds_no_gt(
det = [det[i] for i in dtind]
if len(det) > max_det:
det = det[:max_det]
nb_det = len(det)
num_det = len(det)
det_areas = compute_area(det, iou_type=self.iou_type).to(self.device)
det_ignore_area = (det_areas < area_range[0]) | (det_areas > area_range[1])
ar = det_ignore_area.reshape((1, nb_det))
det_ignore = torch.repeat_interleave(ar, nb_iou_thrs, 0)
ar = det_ignore_area.reshape((1, num_det))
det_ignore = torch.repeat_interleave(ar, num_iou_thrs, 0)

return {
"dtMatches": torch.zeros((nb_iou_thrs, nb_det), dtype=torch.bool, device=self.device),
"gtMatches": torch.zeros((nb_iou_thrs, nb_gt), dtype=torch.bool, device=self.device),
"dtMatches": torch.zeros((num_iou_thrs, num_det), dtype=torch.bool, device=self.device),
"gtMatches": torch.zeros((num_iou_thrs, num_gt), dtype=torch.bool, device=self.device),
"dtScores": scores_sorted.to(self.device),
"gtIgnore": gt_ignore.to(self.device),
"dtIgnore": det_ignore.to(self.device),
Expand Down Expand Up @@ -535,15 +541,15 @@ def _evaluate_image(
if len(gt_label_mask) == 0 and len(det_label_mask) == 0:
return None

nb_iou_thrs = len(self.iou_thresholds)
num_iou_thrs = len(self.iou_thresholds)

# Some GT but no predictions
if len(gt_label_mask) > 0 and len(det_label_mask) == 0:
return self.__evaluate_image_gt_no_preds(gt, gt_label_mask, area_range, nb_iou_thrs)
return self.__evaluate_image_gt_no_preds(gt, gt_label_mask, area_range, num_iou_thrs)

# Some predictions but no GT
if len(gt_label_mask) == 0 and len(det_label_mask) > 0:
return self.__evaluate_image_preds_no_gt(det, idx, det_label_mask, max_det, area_range, nb_iou_thrs)
return self.__evaluate_image_preds_no_gt(det, idx, det_label_mask, max_det, area_range, num_iou_thrs)

gt = [gt[i] for i in gt_label_mask]
det = [det[i] for i in det_label_mask]
Expand Down Expand Up @@ -574,13 +580,13 @@ def _evaluate_image(
# load computed ious
ious = ious[idx, class_id][:, gtind] if len(ious[idx, class_id]) > 0 else ious[idx, class_id]

nb_iou_thrs = len(self.iou_thresholds)
nb_gt = len(gt)
nb_det = len(det)
gt_matches = torch.zeros((nb_iou_thrs, nb_gt), dtype=torch.bool, device=self.device)
det_matches = torch.zeros((nb_iou_thrs, nb_det), dtype=torch.bool, device=self.device)
num_iou_thrs = len(self.iou_thresholds)
num_gt = len(gt)
num_det = len(det)
gt_matches = torch.zeros((num_iou_thrs, num_gt), dtype=torch.bool, device=self.device)
det_matches = torch.zeros((num_iou_thrs, num_det), dtype=torch.bool, device=self.device)
gt_ignore = ignore_area_sorted
det_ignore = torch.zeros((nb_iou_thrs, nb_det), dtype=torch.bool, device=self.device)
det_ignore = torch.zeros((num_iou_thrs, num_det), dtype=torch.bool, device=self.device)

if torch.numel(ious) > 0:
for idx_iou, t in enumerate(self.iou_thresholds):
Expand All @@ -595,9 +601,9 @@ def _evaluate_image(
# set unmatched detections outside of area range to ignore
det_areas = compute_area(det, iou_type=self.iou_type).to(self.device)
det_ignore_area = (det_areas < area_range[0]) | (det_areas > area_range[1])
ar = det_ignore_area.reshape((1, nb_det))
ar = det_ignore_area.reshape((1, num_det))
det_ignore = torch.logical_or(
det_ignore, torch.logical_and(det_matches == 0, torch.repeat_interleave(ar, nb_iou_thrs, 0))
det_ignore, torch.logical_and(det_matches == 0, torch.repeat_interleave(ar, num_iou_thrs, 0))
)

return {
Expand Down Expand Up @@ -708,15 +714,15 @@ def _calculate(self, class_ids: List) -> Tuple[MAPMetricResults, MARMetricResult
for img_id in img_ids
]

nb_iou_thrs = len(self.iou_thresholds)
nb_rec_thrs = len(self.rec_thresholds)
nb_classes = len(class_ids)
nb_bbox_areas = len(self.bbox_area_ranges)
nb_max_det_thrs = len(self.max_detection_thresholds)
nb_imgs = len(img_ids)
precision = -torch.ones((nb_iou_thrs, nb_rec_thrs, nb_classes, nb_bbox_areas, nb_max_det_thrs))
recall = -torch.ones((nb_iou_thrs, nb_classes, nb_bbox_areas, nb_max_det_thrs))
scores = -torch.ones((nb_iou_thrs, nb_rec_thrs, nb_classes, nb_bbox_areas, nb_max_det_thrs))
num_iou_thrs = len(self.iou_thresholds)
num_rec_thrs = len(self.rec_thresholds)
num_classes = len(class_ids)
num_bbox_areas = len(self.bbox_area_ranges)
num_max_det_thrs = len(self.max_detection_thresholds)
num_imgs = len(img_ids)
precision = -torch.ones((num_iou_thrs, num_rec_thrs, num_classes, num_bbox_areas, num_max_det_thrs))
recall = -torch.ones((num_iou_thrs, num_classes, num_bbox_areas, num_max_det_thrs))
scores = -torch.ones((num_iou_thrs, num_rec_thrs, num_classes, num_bbox_areas, num_max_det_thrs))

# move tensors if necessary
rec_thresholds_tensor = torch.tensor(self.rec_thresholds)
Expand All @@ -735,8 +741,8 @@ def _calculate(self, class_ids: List) -> Tuple[MAPMetricResults, MARMetricResult
eval_imgs=eval_imgs,
rec_thresholds=rec_thresholds_tensor,
max_det=max_det,
nb_imgs=nb_imgs,
nb_bbox_areas=nb_bbox_areas,
num_imgs=num_imgs,
num_bbox_areas=num_bbox_areas,
)

return precision, recall
Expand Down Expand Up @@ -787,14 +793,14 @@ def __calculate_recall_precision_scores(
eval_imgs: list,
rec_thresholds: Tensor,
max_det: int,
nb_imgs: int,
nb_bbox_areas: int,
num_imgs: int,
num_bbox_areas: int,
) -> Tuple[Tensor, Tensor, Tensor]:
nb_rec_thrs = len(rec_thresholds)
idx_cls_pointer = idx_cls * nb_bbox_areas * nb_imgs
idx_bbox_area_pointer = idx_bbox_area * nb_imgs
num_rec_thrs = len(rec_thresholds)
idx_cls_pointer = idx_cls * num_bbox_areas * num_imgs
idx_bbox_area_pointer = idx_bbox_area * num_imgs
# Load all image evals for current class_id and area_range
img_eval_cls_bbox = [eval_imgs[idx_cls_pointer + idx_bbox_area_pointer + i] for i in range(nb_imgs)]
img_eval_cls_bbox = [eval_imgs[idx_cls_pointer + idx_bbox_area_pointer + i] for i in range(num_imgs)]
img_eval_cls_bbox = [e for e in img_eval_cls_bbox if e is not None]
if not img_eval_cls_bbox:
return recall, precision, scores
Expand Down Expand Up @@ -824,8 +830,8 @@ def __calculate_recall_precision_scores(
nd = len(tp)
rc = tp / npig
pr = tp / (fp + tp + torch.finfo(torch.float64).eps)
prec = torch.zeros((nb_rec_thrs,))
score = torch.zeros((nb_rec_thrs,))
prec = torch.zeros((num_rec_thrs,))
score = torch.zeros((num_rec_thrs,))

recall[idx, idx_cls, idx_bbox_area, idx_max_det_thrs] = rc[-1] if nd else 0

Expand All @@ -837,7 +843,7 @@ def __calculate_recall_precision_scores(
pr += diff

inds = torch.searchsorted(rc, rec_thresholds.to(rc.device), right=False)
num_inds = inds.argmax() if inds.max() >= nd else nb_rec_thrs
num_inds = inds.argmax() if inds.max() >= nd else num_rec_thrs
inds = inds[:num_inds]
prec[:num_inds] = pr[inds]
score[:num_inds] = det_scores_sorted[inds]
Expand Down
20 changes: 10 additions & 10 deletions src/torchmetrics/detection/panoptic_qualities.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,11 +109,11 @@ def __init__(
self.allow_unknown_preds_category = allow_unknown_preds_category

# per category intermediate metrics
n_categories = len(things) + len(stuffs)
self.add_state("iou_sum", default=torch.zeros(n_categories, dtype=torch.double), dist_reduce_fx="sum")
self.add_state("true_positives", default=torch.zeros(n_categories, dtype=torch.int), dist_reduce_fx="sum")
self.add_state("false_positives", default=torch.zeros(n_categories, dtype=torch.int), dist_reduce_fx="sum")
self.add_state("false_negatives", default=torch.zeros(n_categories, dtype=torch.int), dist_reduce_fx="sum")
num_categories = len(things) + len(stuffs)
self.add_state("iou_sum", default=torch.zeros(num_categories, dtype=torch.double), dist_reduce_fx="sum")
self.add_state("true_positives", default=torch.zeros(num_categories, dtype=torch.int), dist_reduce_fx="sum")
self.add_state("false_positives", default=torch.zeros(num_categories, dtype=torch.int), dist_reduce_fx="sum")
self.add_state("false_negatives", default=torch.zeros(num_categories, dtype=torch.int), dist_reduce_fx="sum")

def update(self, preds: Tensor, target: Tensor) -> None:
r"""Update state with predictions and targets.
Expand Down Expand Up @@ -287,11 +287,11 @@ def __init__(
self.allow_unknown_preds_category = allow_unknown_preds_category

# per category intermediate metrics
n_categories = len(things) + len(stuffs)
self.add_state("iou_sum", default=torch.zeros(n_categories, dtype=torch.double), dist_reduce_fx="sum")
self.add_state("true_positives", default=torch.zeros(n_categories, dtype=torch.int), dist_reduce_fx="sum")
self.add_state("false_positives", default=torch.zeros(n_categories, dtype=torch.int), dist_reduce_fx="sum")
self.add_state("false_negatives", default=torch.zeros(n_categories, dtype=torch.int), dist_reduce_fx="sum")
num_categories = len(things) + len(stuffs)
self.add_state("iou_sum", default=torch.zeros(num_categories, dtype=torch.double), dist_reduce_fx="sum")
self.add_state("true_positives", default=torch.zeros(num_categories, dtype=torch.int), dist_reduce_fx="sum")
self.add_state("false_positives", default=torch.zeros(num_categories, dtype=torch.int), dist_reduce_fx="sum")
self.add_state("false_negatives", default=torch.zeros(num_categories, dtype=torch.int), dist_reduce_fx="sum")

def update(self, preds: Tensor, target: Tensor) -> None:
r"""Update state with predictions and targets.
Expand Down
16 changes: 8 additions & 8 deletions src/torchmetrics/functional/audio/srmr.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,8 +130,8 @@ def _erb_filterbank(wave: Tensor, coefs: Tensor) -> Tensor:
Tensor: shape [B, N, time]
"""
n_batch, time = wave.shape
wave = wave.to(dtype=coefs.dtype).reshape(n_batch, 1, time) # [B, time]
num_batch, time = wave.shape
wave = wave.to(dtype=coefs.dtype).reshape(num_batch, 1, time) # [B, time]
wave = wave.expand(-1, coefs.shape[0], -1) # [B, N, time]

gain = coefs[:, 9]
Expand Down Expand Up @@ -250,7 +250,7 @@ def speech_reverberation_modulation_energy_ratio(
)
shape = preds.shape
preds = preds.reshape(1, -1) if len(shape) == 1 else preds.reshape(-1, shape[-1])
n_batch, time = preds.shape
num_batch, time = preds.shape
# convert int type to float
if not torch.is_floating_point(preds):
preds = preds.to(torch.float64) / torch.finfo(preds.dtype).max
Expand All @@ -272,7 +272,7 @@ def speech_reverberation_modulation_energy_ratio(
mfs = 400.0
temp = []
preds_np = preds.detach().cpu().numpy()
for b in range(n_batch):
for b in range(num_batch):
gt_env_b = fft_gtgram(preds_np[b], fs, 0.010, 0.0025, n_cochlear_filters, low_freq)
temp.append(torch.tensor(gt_env_b))
gt_env = torch.stack(temp, dim=0).to(device=preds.device)
Expand All @@ -291,7 +291,7 @@ def speech_reverberation_modulation_energy_ratio(
min_cf, max_cf, n=8, fs=mfs, q=2, device=preds.device
)

n_frames = int(1 + (time - w_length) // w_inc)
num_frames = int(1 + (time - w_length) // w_inc)
w = torch.hamming_window(w_length + 1, dtype=torch.float64, device=preds.device)[:-1]
mod_out = lfilter(
gt_env.unsqueeze(-2).expand(-1, -1, mf.shape[0], -1), mf[:, 1, :], mf[:, 0, :], clamp=False, batching=True
Expand All @@ -300,23 +300,23 @@ def speech_reverberation_modulation_energy_ratio(
padding = (0, max(ceil(time / w_inc) * w_inc - time, w_length - time))
mod_out_pad = pad(mod_out, pad=padding, mode="constant", value=0)
mod_out_frame = mod_out_pad.unfold(-1, w_length, w_inc)
energy = ((mod_out_frame[..., :n_frames, :] * w) ** 2).sum(dim=-1) # [B, N_filters, 8, n_frames]
energy = ((mod_out_frame[..., :num_frames, :] * w) ** 2).sum(dim=-1) # [B, N_filters, 8, n_frames]

if norm:
energy = _normalize_energy(energy)

erbs = torch.flipud(_calc_erbs(low_freq, fs, n_cochlear_filters, device=preds.device))

avg_energy = torch.mean(energy, dim=-1)
total_energy = torch.sum(avg_energy.reshape(n_batch, -1), dim=-1)
total_energy = torch.sum(avg_energy.reshape(num_batch, -1), dim=-1)
ac_energy = torch.sum(avg_energy, dim=2)
ac_perc = ac_energy * 100 / total_energy.reshape(-1, 1)
ac_perc_cumsum = ac_perc.flip(-1).cumsum(-1)
k90perc_idx = torch.nonzero((ac_perc_cumsum > 90).cumsum(-1) == 1)[:, 1]
bw = erbs[k90perc_idx]

temp = []
for b in range(n_batch):
for b in range(num_batch):
score = _cal_srmr_score(bw[b], avg_energy[b], cutoffs=cutoffs)
temp.append(score)
score = torch.stack(temp)
Expand Down
Loading

0 comments on commit 25bf259

Please sign in to comment.