Skip to content

Commit 6825831

Browse files
committed
Tilde Expansions
1 parent 0f6a1d9 commit 6825831

File tree

2 files changed

+37
-0
lines changed

2 files changed

+37
-0
lines changed

src/tilde_expansions/__init__.py

Whitespace-only changes.
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
import argparse
2+
from src.deep_impact.models.original import DeepImpact
3+
from datasets import load_dataset
4+
from tqdm import tqdm
5+
6+
def create_collection(original_collection_path, output_collection_path):
7+
with open(original_collection_path, 'r', encoding='utf-8') as f:
8+
original_collection = [line.strip().split('\t') for line in f]
9+
10+
expansions = load_dataset("pxyu/MSMARCO-TILDE-Top200-CSV300k")['train']
11+
already_present = 0
12+
13+
with open(output_collection_path, 'w', encoding='utf-8') as f, tqdm(total=len(original_collection)) as pbar:
14+
for i, (passage, passage_expansions) in enumerate(zip(original_collection, expansions)):
15+
assert passage[0] == passage_expansions['pid']
16+
pre_tokenized_str = DeepImpact.tokenizer.pre_tokenizer.pre_tokenize_str(passage[1])
17+
terms = {x[0] for x in pre_tokenized_str}
18+
string_ = ' [SEP]'
19+
20+
for term in passage_expansions['psg']:
21+
if term not in terms:
22+
string_ += ' ' + term
23+
else:
24+
already_present += 1
25+
26+
f.write(passage[0] + '\t' + passage[1] + string_ + '\n')
27+
28+
pbar.update(1)
29+
pbar.set_description(f"Average duplicate terms per passage: {already_present / (i + 1):.2f}")
30+
31+
32+
if __name__ == '__main__':
33+
parser = argparse.ArgumentParser(description='Create tilde expanded collection.')
34+
parser.add_argument('--original_collection_path', type=str, help='Path to the original collection')
35+
parser.add_argument('--output_collection_path', type=str, help='Path to the output collection')
36+
args = parser.parse_args()
37+
create_collection(args.original_collection_path, args.output_collection_path)

0 commit comments

Comments
 (0)