@@ -100,9 +100,30 @@ def translate(translation_file, i18ns):
100100 translations .append ((original , translation ))
101101 return translations
102102
103+ def frequent_ngrams (corpus , sz , n ):
104+ return collections .Counter (corpus [i :i + sz ] for i in range (len (corpus )- sz )).most_common (n )
105+
106+ def encode_ngrams (translation , ngrams ):
107+ if len (ngrams ) > 32 :
108+ start = 0xe000
109+ else :
110+ start = 0x80
111+ for i , g in enumerate (ngrams ):
112+ translation = translation .replace (g , chr (start + i ))
113+ return translation
114+
115+ def decode_ngrams (compressed , ngrams ):
116+ if len (ngrams ) > 32 :
117+ start , end = 0xe000 , 0xf8ff
118+ else :
119+ start , end = 0x80 , 0x9f
120+ return "" .join (ngrams [ord (c ) - start ] if (start <= ord (c ) <= end ) else c for c in compressed )
121+
103122def compute_huffman_coding (translations , qstrs , compression_filename ):
104123 all_strings = [x [1 ] for x in translations ]
105124 all_strings_concat = "" .join (all_strings )
125+ ngrams = [i [0 ] for i in frequent_ngrams (all_strings_concat , 2 , 32 )]
126+ all_strings_concat = encode_ngrams (all_strings_concat , ngrams )
106127 counts = collections .Counter (all_strings_concat )
107128 cb = huffman .codebook (counts .items ())
108129 values = []
@@ -125,21 +146,31 @@ def compute_huffman_coding(translations, qstrs, compression_filename):
125146 last_l = l
126147 lengths = bytearray ()
127148 print ("// length count" , length_count )
149+ print ("// bigrams" , ngrams )
128150 for i in range (1 , max (length_count ) + 2 ):
129151 lengths .append (length_count .get (i , 0 ))
130152 print ("// values" , values , "lengths" , len (lengths ), lengths )
131- print ("// estimated total memory size" , len (lengths ) + 2 * len (values ) + sum (len (cb [u ]) for u in all_strings_concat ))
153+ ngramdata = [ord (ni ) for i in ngrams for ni in i ]
154+ print ("// estimated total memory size" , len (lengths ) + 2 * len (values ) + 2 * len (ngramdata ) + sum ((len (cb [u ]) + 7 )// 8 for u in all_strings_concat ))
132155 print ("//" , values , lengths )
133156 values_type = "uint16_t" if max (ord (u ) for u in values ) > 255 else "uint8_t"
134157 max_translation_encoded_length = max (len (translation .encode ("utf-8" )) for original ,translation in translations )
135158 with open (compression_filename , "w" ) as f :
136159 f .write ("const uint8_t lengths[] = {{ {} }};\n " .format (", " .join (map (str , lengths ))))
137160 f .write ("const {} values[] = {{ {} }};\n " .format (values_type , ", " .join (str (ord (u )) for u in values )))
138161 f .write ("#define compress_max_length_bits ({})\n " .format (max_translation_encoded_length .bit_length ()))
139- return values , lengths
162+ f .write ("const {} bigrams[] = {{ {} }};\n " .format (values_type , ", " .join (str (u ) for u in ngramdata )))
163+ if len (ngrams ) > 32 :
164+ bigram_start = 0xe000
165+ else :
166+ bigram_start = 0x80
167+ bigram_end = bigram_start + len (ngrams ) - 1 # End is inclusive
168+ f .write ("#define bigram_start {}\n " .format (bigram_start ))
169+ f .write ("#define bigram_end {}\n " .format (bigram_end ))
170+ return values , lengths , ngrams
140171
141172def decompress (encoding_table , encoded , encoded_length_bits ):
142- values , lengths = encoding_table
173+ values , lengths , ngrams = encoding_table
143174 dec = []
144175 this_byte = 0
145176 this_bit = 7
@@ -187,14 +218,16 @@ def decompress(encoding_table, encoded, encoded_length_bits):
187218 searched_length += lengths [bit_length ]
188219
189220 v = values [searched_length + bits - max_code ]
221+ v = decode_ngrams (v , ngrams )
190222 i += len (v .encode ('utf-8' ))
191223 dec .append (v )
192224 return '' .join (dec )
193225
194226def compress (encoding_table , decompressed , encoded_length_bits , len_translation_encoded ):
195227 if not isinstance (decompressed , str ):
196228 raise TypeError ()
197- values , lengths = encoding_table
229+ values , lengths , ngrams = encoding_table
230+ decompressed = encode_ngrams (decompressed , ngrams )
198231 enc = bytearray (len (decompressed ) * 3 )
199232 #print(decompressed)
200233 #print(lengths)
0 commit comments