Skip to content

Commit b39b266

Browse files
authored
support jiwer 4.0 (#685)
* support jiwer 4.0 * Update cer.py * style * fix xtreme_s
1 parent 5aa3982 commit b39b266

File tree

3 files changed

+146
-86
lines changed

3 files changed

+146
-86
lines changed

metrics/cer/cer.py

Lines changed: 75 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -29,39 +29,41 @@
2929
else:
3030
import importlib.metadata as importlib_metadata
3131

32-
33-
if hasattr(jiwer, "compute_measures"):
    # The character-level transform below is only handed to
    # `jiwer.compute_measures`, so it is built only when the installed jiwer
    # still exposes that function.
    SENTENCE_DELIMITER = ""

    if version.parse(importlib_metadata.version("jiwer")) >= version.parse("2.3.0"):
        # jiwer >= 2.3.0: rely on the transforms shipped with the library.
        cer_transform = tr.Compose(
            [
                tr.RemoveMultipleSpaces(),
                tr.Strip(),
                tr.ReduceToSingleSentence(SENTENCE_DELIMITER),
                tr.ReduceToListOfListOfChars(),
            ]
        )
    else:
        # jiwer < 2.3.0: use a local transform to split text into characters.

        class SentencesToListOfCharacters(tr.AbstractTransform):
            """Flatten a list of sentences into a single list of characters."""

            def __init__(self, sentence_delimiter: str = " "):
                self.sentence_delimiter = sentence_delimiter

            def process_string(self, s: str):
                """Return the characters of `s` as a list."""
                return list(s)

            def process_list(self, inp: List[str]):
                """Return the characters of every sentence in `inp`, in order.

                A delimiter that is neither `None` nor the empty string is
                inserted between consecutive sentences, never after the last.
                """
                flattened = []
                final_idx = len(inp) - 1
                for idx, sentence in enumerate(inp):
                    flattened.extend(self.process_string(sentence))
                    if idx >= final_idx:
                        # Nothing follows the last sentence.
                        continue
                    if self.sentence_delimiter is None or self.sentence_delimiter == "":
                        continue
                    flattened.append(self.sentence_delimiter)
                return flattened

        cer_transform = tr.Compose(
            [tr.RemoveMultipleSpaces(), tr.Strip(), SentencesToListOfCharacters(SENTENCE_DELIMITER)]
        )
6567

6668

6769
_CITATION = """\
@@ -136,24 +138,43 @@ def _info(self):
136138
)
137139

138140
def _compute(self, predictions, references, concatenate_texts=False):
    """Compute the character error rate of `predictions` against `references`.

    `jiwer.compute_measures` (with `cer_transform`) is used when the
    installed jiwer exposes it; otherwise `jiwer.process_characters` is used.

    Args:
        predictions: predicted texts, paired by position with `references`.
        references: reference texts.
        concatenate_texts: if true, pass both collections to jiwer in a
            single call and return the rate it reports. Otherwise score each
            pair separately and pool the edit counts.

    Returns:
        The error rate. When pooling, this is
        (substitutions + deletions + insertions) divided by
        (substitutions + deletions + hits), summed over all pairs.

    Raises:
        ZeroDivisionError: when pooling and the denominator stays at zero,
            e.g. because no pairs were supplied.
    """
    legacy_api = hasattr(jiwer, "compute_measures")

    if concatenate_texts:
        if legacy_api:
            # With the character transform applied, the "wer" entry is the
            # character-level rate.
            return jiwer.compute_measures(
                references,
                predictions,
                truth_transform=cer_transform,
                hypothesis_transform=cer_transform,
            )["wer"]
        return jiwer.process_characters(
            references,
            predictions,
        ).cer

    errors = 0
    reference_length = 0
    for prediction, reference in zip(predictions, references):
        if legacy_api:
            counts = jiwer.compute_measures(
                reference,
                prediction,
                truth_transform=cer_transform,
                hypothesis_transform=cer_transform,
            )
            subs = counts["substitutions"]
            dels = counts["deletions"]
            ins = counts["insertions"]
            hits = counts["hits"]
        else:
            output = jiwer.process_characters(
                reference,
                prediction,
            )
            subs = output.substitutions
            dels = output.deletions
            ins = output.insertions
            hits = output.hits
        errors += subs + dels + ins
        reference_length += subs + dels + hits

    return errors / reference_length

metrics/wer/wer.py

Lines changed: 22 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
""" Word Error Ratio (WER) metric. """
1515

1616
import datasets
17-
from jiwer import compute_measures
17+
import jiwer
1818

1919
import evaluate
2020

@@ -94,13 +94,25 @@ def _info(self):
9494
)
9595

9696
def _compute(self, predictions=None, references=None, concatenate_texts=False):
    """Compute the word error rate of `predictions` against `references`.

    `jiwer.compute_measures` is used when the installed jiwer exposes it;
    otherwise `jiwer.process_words` is used.

    Args:
        predictions: predicted texts, paired by position with `references`.
        references: reference texts.
        concatenate_texts: if true, pass both collections to jiwer in a
            single call and return the rate it reports. Otherwise score each
            pair separately and pool the edit counts.

    Returns:
        The error rate. When pooling, this is
        (substitutions + deletions + insertions) divided by
        (substitutions + deletions + hits), summed over all pairs.

    Raises:
        ZeroDivisionError: when pooling and the denominator stays at zero,
            e.g. because no pairs were supplied.
    """
    legacy_api = hasattr(jiwer, "compute_measures")

    if concatenate_texts:
        if legacy_api:
            return jiwer.compute_measures(references, predictions)["wer"]
        return jiwer.process_words(references, predictions).wer

    errors = 0
    reference_length = 0
    for prediction, reference in zip(predictions, references):
        if legacy_api:
            counts = jiwer.compute_measures(reference, prediction)
            subs = counts["substitutions"]
            dels = counts["deletions"]
            ins = counts["insertions"]
            hits = counts["hits"]
        else:
            output = jiwer.process_words(reference, prediction)
            subs = output.substitutions
            dels = output.deletions
            ins = output.insertions
            hits = output.hits
        errors += subs + dels + ins
        reference_length += subs + dels + hits
    return errors / reference_length

metrics/xtreme_s/xtreme_s.py

Lines changed: 49 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -91,13 +91,14 @@
9191
SENTENCE_DELIMITER = ""
9292

9393
try:
94-
from jiwer import transforms as tr
94+
import jiwer
9595

9696
_jiwer_available = True
9797
except ImportError:
9898
_jiwer_available = False
9999

100100
if _jiwer_available and version.parse(importlib_metadata.version("jiwer")) < version.parse("2.3.0"):
101+
from jiwer import transforms as tr
101102

102103
class SentencesToListOfCharacters(tr.AbstractTransform):
103104
def __init__(self, sentence_delimiter: str = " "):
@@ -117,7 +118,9 @@ def process_list(self, inp: List[str]):
117118
cer_transform = tr.Compose(
118119
[tr.RemoveMultipleSpaces(), tr.Strip(), SentencesToListOfCharacters(SENTENCE_DELIMITER)]
119120
)
120-
elif _jiwer_available:
121+
elif _jiwer_available and hasattr(jiwer, "compute_measures"):
122+
from jiwer import transforms as tr
123+
121124
cer_transform = tr.Compose(
122125
[
123126
tr.RemoveMultipleSpaces(),
@@ -187,35 +190,59 @@ def bleu(
187190

188191
def wer_and_cer(preds, labels, concatenate_texts, config_name):
    """Compute word and character error rates of `preds` against `labels`.

    `jiwer.compute_measures` is used when the installed jiwer exposes it
    (with the module-level `cer_transform` for the character-level score);
    otherwise `jiwer.process_words` and `jiwer.process_characters` are used.

    Args:
        preds: predicted texts, paired by position with `labels`.
        labels: reference texts.
        concatenate_texts: if true, pass both collections to jiwer in a
            single call per score and return the rates it reports. Otherwise
            score each pair separately and pool the edit counts.
        config_name: configuration name, used only in the error message.

    Returns:
        A dict with keys "wer" and "cer". When pooling, each value is
        (substitutions + deletions + insertions) divided by
        (substitutions + deletions + hits), summed over all pairs.

    Raises:
        ValueError: if jiwer cannot be imported.
        ZeroDivisionError: when pooling and a denominator stays at zero,
            e.g. because no pairs were supplied.
    """
    try:
        import jiwer
    except ImportError as err:
        # The trailing space keeps the two sentences from running together;
        # chaining preserves the original import failure for debugging.
        raise ValueError(
            f"jiwer has to be installed in order to apply the wer metric for {config_name}. "
            "You can install it via `pip install jiwer`."
        ) from err

    legacy_api = hasattr(jiwer, "compute_measures")

    if concatenate_texts:
        if legacy_api:
            wer = jiwer.compute_measures(labels, preds)["wer"]
            # With the character transform applied, the "wer" entry is the
            # character-level rate.
            cer = jiwer.compute_measures(
                labels, preds, truth_transform=cer_transform, hypothesis_transform=cer_transform
            )["wer"]
        else:
            wer = jiwer.process_words(labels, preds).wer
            cer = jiwer.process_characters(labels, preds).cer
        return {"wer": wer, "cer": cer}

    def _edit_counts(reference, prediction, score_type):
        """Return (substitutions, deletions, insertions, hits) for one pair."""
        if legacy_api:
            if score_type == "wer":
                measures = jiwer.compute_measures(reference, prediction)
            else:
                measures = jiwer.compute_measures(
                    reference, prediction, truth_transform=cer_transform, hypothesis_transform=cer_transform
                )
            return (
                measures["substitutions"],
                measures["deletions"],
                measures["insertions"],
                measures["hits"],
            )
        if score_type == "wer":
            output = jiwer.process_words(reference, prediction)
        else:
            output = jiwer.process_characters(reference, prediction)
        return output.substitutions, output.deletions, output.insertions, output.hits

    def compute_score(preds, labels, score_type="wer"):
        """Pool the edit counts over all pairs and return the error rate."""
        incorrect = 0
        total = 0
        for prediction, reference in zip(preds, labels):
            subs, dels, ins, hits = _edit_counts(reference, prediction, score_type)
            incorrect += subs + dels + ins
            total += subs + dels + hits
        return incorrect / total

    return {"wer": compute_score(preds, labels, "wer"), "cer": compute_score(preds, labels, "cer")}
219246

220247

221248
@evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)

0 commit comments

Comments
 (0)