91
91
SENTENCE_DELIMITER = ""
92
92
93
93
try :
94
- from jiwer import transforms as tr
94
+ import jiwer
95
95
96
96
_jiwer_available = True
97
97
except ImportError :
98
98
_jiwer_available = False
99
99
100
100
if _jiwer_available and version .parse (importlib_metadata .version ("jiwer" )) < version .parse ("2.3.0" ):
101
+ from jiwer import transforms as tr
101
102
102
103
class SentencesToListOfCharacters (tr .AbstractTransform ):
103
104
def __init__ (self , sentence_delimiter : str = " " ):
@@ -117,7 +118,9 @@ def process_list(self, inp: List[str]):
117
118
cer_transform = tr .Compose (
118
119
[tr .RemoveMultipleSpaces (), tr .Strip (), SentencesToListOfCharacters (SENTENCE_DELIMITER )]
119
120
)
120
- elif _jiwer_available :
121
+ elif _jiwer_available and hasattr (jiwer , "compute_measures" ):
122
+ from jiwer import transforms as tr
123
+
121
124
cer_transform = tr .Compose (
122
125
[
123
126
tr .RemoveMultipleSpaces (),
@@ -187,35 +190,59 @@ def bleu(
187
190
188
191
def wer_and_cer(preds, labels, concatenate_texts, config_name):
    """Compute word error rate (WER) and character error rate (CER) with jiwer.

    Supports both jiwer < 3.0 (which exposes ``compute_measures``) and
    jiwer >= 3.0 (which replaced it with ``process_words`` /
    ``process_characters``).

    Args:
        preds: Iterable of predicted transcription strings.
        labels: Iterable of reference transcription strings.
        concatenate_texts: If True, score all texts as one corpus in a single
            jiwer call; otherwise accumulate error counts sentence by sentence.
        config_name: Metric config name, used only in the error message when
            jiwer is missing.

    Returns:
        dict with float values under the keys ``"wer"`` and ``"cer"``.

    Raises:
        ValueError: If jiwer is not installed.
    """
    try:
        import jiwer
    except ImportError as err:
        # Keep ValueError for backward compatibility with callers that catch
        # it; chain the original ImportError so the cause is not lost.
        raise ValueError(
            f"jiwer has to be installed in order to apply the wer metric for {config_name}. "
            "You can install it via `pip install jiwer`."
        ) from err

    # jiwer < 3.0: `compute_measures` returns a dict of counts/rates.
    # NOTE(review): `cer_transform` is assumed to be the module-level
    # character-level transform defined alongside this function.
    if hasattr(jiwer, "compute_measures"):
        if concatenate_texts:
            wer = jiwer.compute_measures(labels, preds)["wer"]

            # CER is obtained by running WER over character-split text, so the
            # relevant key is still "wer".
            cer = jiwer.compute_measures(
                labels, preds, truth_transform=cer_transform, hypothesis_transform=cer_transform
            )["wer"]
            return {"wer": wer, "cer": cer}
        else:

            def compute_score(preds, labels, score_type="wer"):
                # Accumulate edit counts per sentence pair and return the
                # micro-averaged error rate.
                incorrect = 0
                total = 0
                for prediction, reference in zip(preds, labels):
                    if score_type == "wer":
                        measures = jiwer.compute_measures(reference, prediction)
                    elif score_type == "cer":
                        measures = jiwer.compute_measures(
                            reference, prediction, truth_transform=cer_transform, hypothesis_transform=cer_transform
                        )
                    incorrect += measures["substitutions"] + measures["deletions"] + measures["insertions"]
                    total += measures["substitutions"] + measures["deletions"] + measures["hits"]
                return incorrect / total

            return {"wer": compute_score(preds, labels, "wer"), "cer": compute_score(preds, labels, "cer")}
    else:
        # jiwer >= 3.0: `process_words` / `process_characters` return objects
        # with attribute access instead of dicts.
        if concatenate_texts:
            wer = jiwer.process_words(labels, preds).wer

            cer = jiwer.process_characters(labels, preds).cer
            return {"wer": wer, "cer": cer}
        else:

            def compute_score(preds, labels, score_type="wer"):
                # Same micro-averaging as above, but reading counts from the
                # result object's attributes.
                incorrect = 0
                total = 0
                for prediction, reference in zip(preds, labels):
                    if score_type == "wer":
                        measures = jiwer.process_words(reference, prediction)
                    elif score_type == "cer":
                        measures = jiwer.process_characters(reference, prediction)
                    incorrect += measures.substitutions + measures.deletions + measures.insertions
                    total += measures.substitutions + measures.deletions + measures.hits
                return incorrect / total

            return {"wer": compute_score(preds, labels, "wer"), "cer": compute_score(preds, labels, "cer")}
219
246
220
247
221
248
@evaluate .utils .file_utils .add_start_docstrings (_DESCRIPTION , _KWARGS_DESCRIPTION )
0 commit comments