import argparse

from src.deep_impact.models.original import DeepImpact
from datasets import load_dataset
from tqdm import tqdm

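# Build a TILDE-expanded collection: for each passage in the original
# tab-separated (pid, passage) collection, append the TILDE-predicted
# expansion terms that are not already present, after a [SEP] marker.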
def create_collection(original_collection_path, output_collection_path):
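    # Load the original collection: one tab-separated (pid, passage) pair per line.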
    with open(original_collection_path, 'r', encoding='utf-8') as f:
        original_collection = [line.strip().split('\t') for line in f]

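    # Precomputed TILDE top-200 expansion terms per passage (Hugging Face dataset).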
    expansions = load_dataset("pxyu/MSMARCO-TILDE-Top200-CSV300k")['train']
    already_present = 0

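    # Walk both collections in lockstep and write the expanded passages.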
    with open(output_collection_path, 'w', encoding='utf-8') as f, tqdm(total=len(original_collection)) as pbar:
        for i, (passage, passage_expansions) in enumerate(zip(original_collection, expansions)):
            assert passage[0] == passage_expansions['pid']
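            # Pre-tokenize the passage with DeepImpact's tokenizer to collect the
            # terms that already appear in it.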
            pre_tokenized_str = DeepImpact.tokenizer.pre_tokenizer.pre_tokenize_str(passage[1])
            terms = {x[0] for x in pre_tokenized_str}
            string_ = ' [SEP]'

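            # Append only expansion terms that are not already in the passage;
            # count the rest as duplicates.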
            for term in passage_expansions['psg']:
                if term not in terms:
                    string_ += ' ' + term
                else:
                    already_present += 1

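            # Output line: pid \t original passage [SEP] appended expansion terms.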
            f.write(passage[0] + '\t' + passage[1] + string_ + '\n')

            pbar.update(1)
            pbar.set_description(f"Average duplicate terms per passage: {already_present / (i + 1):.2f}")


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Create a TILDE-expanded collection.')
    parser.add_argument('--original_collection_path', type=str, help='Path to the original collection')
    parser.add_argument('--output_collection_path', type=str, help='Path to the output collection')
    args = parser.parse_args()
    create_collection(args.original_collection_path, args.output_collection_path)
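
# Example invocation (script name and paths below are illustrative):
#   python create_tilde_expanded_collection.py \
#       --original_collection_path /path/to/collection.tsv \
#       --output_collection_path /path/to/expanded_collection.tsv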