|
18 | 18 |
|
19 | 19 | import logging |
20 | 20 | from types import MappingProxyType |
21 | | -from typing import Optional |
| 21 | +from typing import List, Optional, Union |
22 | 22 |
|
23 | 23 | from sparsezoo import Model |
24 | 24 |
|
@@ -75,21 +75,41 @@ def _accuracy(model: Model, metric_name=None) -> float: |
75 | 75 |
|
76 | 76 | if metric_name is not None: |
77 | 77 | for result in validation_results: |
78 | | - if metric_name in result.recorded_units.lower(): |
| 78 | + if _metric_name_matches(metric_name, result.recorded_units.lower()): |
79 | 79 | return result.recorded_value |
80 | 80 | _LOGGER.info(f"metric name {metric_name} not found for model {model}") |
81 | 81 |
|
82 | 82 | # fallback to if any accuracy metric found |
83 | 83 | accuracy_metrics = ["accuracy", "f1", "recall", "map", "top1 accuracy"] |
84 | 84 | for result in validation_results: |
85 | | - if result.recorded_units.lower() in accuracy_metrics: |
| 85 | + if _metric_name_matches(result.recorded_units.lower(), accuracy_metrics): |
86 | 86 | return result.recorded_value |
87 | 87 |
|
88 | 88 | raise ValueError( |
89 | 89 | f"Could not find any accuracy metric {accuracy_metrics} for model {model}" |
90 | 90 | ) |
91 | 91 |
|
92 | 92 |
|
| 93 | +def _metric_name_matches( |
| 94 | + metric_name: str, target_metrics: Union[str, List[str]] |
| 95 | +) -> bool: |
| 96 | + # returns true if metric name is included in the target metrics |
| 97 | + if isinstance(target_metrics, str): |
| 98 | + target_metrics = [target_metrics] |
| 99 | + return any( |
| 100 | + _standardized_str_eq(metric_name, target_metric) |
| 101 | + for target_metric in target_metrics |
| 102 | + ) |
| 103 | + |
| 104 | + |
| 105 | +def _standardized_str_eq(str_1: str, str_2: str) -> bool: |
| 106 | + # strings are equal if lowercase, striped of spaces, -, and _ are equal |
| 107 | + def _standardize(string): |
| 108 | + return string.lower().replace(" ", "").replace("-", "").replace("_", "") |
| 109 | + |
| 110 | + return _standardize(str_1) == _standardize(str_2) |
| 111 | + |
| 112 | + |
93 | 113 | EXTRACTORS = MappingProxyType( |
94 | 114 | { |
95 | 115 | "compression": _size, |
|
0 commit comments