@@ -16,10 +16,18 @@ def parse_args():
16
16
default = 'cer' ,
17
17
choices = ['cer' , 'wer' ],
18
18
help = "Error rate type. (default: %(default)s)" )
19
+ parser .add_argument (
20
+ '--special_tokens' ,
21
+ type = str ,
22
+ default = '<SPOKEN_NOISE>' ,
23
+ help = "Special tokens in scoring CER, separated by space. "
24
+ "They shouldn't be split and should be treated as one special "
25
+ "character. Example: '<SPOKEN_NOISE> <bos> <eos>' "
26
+ "(default: %(default)s)" )
19
27
parser .add_argument (
20
28
'--ref' , type = str , required = True , help = "The ground truth text." )
21
29
parser .add_argument (
22
- '--hyp' , type = str , required = True , help = "The decoding result." )
30
+ '--hyp' , type = str , required = True , help = "The decoding result text ." )
23
31
args = parser .parse_args ()
24
32
return args
25
33
@@ -31,6 +39,8 @@ def parse_args():
31
39
sum_errors , sum_ref_len = 0.0 , 0
32
40
sent_cnt , not_in_ref_cnt = 0 , 0
33
41
42
+ special_tokens = args .special_tokens .split (" " )
43
+
34
44
with open (args .ref , "r" ) as ref_txt :
35
45
line = ref_txt .readline ()
36
46
while line :
@@ -51,6 +61,8 @@ def parse_args():
51
61
continue
52
62
53
63
if args .error_rate_type == 'cer' :
64
+ for sp_tok in special_tokens :
65
+ sent = sent .replace (sp_tok , '\0 ' )
54
66
errors , ref_len = char_errors (
55
67
ref_dict [key ].decode ("utf8" ),
56
68
sent .decode ("utf8" ),
0 commit comments