Refactor tolerance checking approach and tweak yaml configs #2592

Merged
merged 29 commits into from
Sep 19, 2024
Commits
f8a96f5
Tweak yaml files
lintool Sep 7, 2024
360a1d7
Merge branch 'master' of github.com:castorini/anserini into refactoring
lintool Sep 7, 2024
a94e970
Initial addition of tolerance scores in yaml
lintool Sep 10, 2024
32c22de
Merge branch 'master' into refactoring
lintool Sep 11, 2024
9552f0c
sync, regressions for flat-int8
lintool Sep 11, 2024
1ec43ea
flat-int8 onnx
lintool Sep 11, 2024
8269ae5
more tweaks, ms marco.
lintool Sep 11, 2024
4c6b7db
Fixed metrics typo.
lintool Sep 11, 2024
a15106e
tweaks to flat scores for msmarco.
lintool Sep 12, 2024
51d6dd5
initial hnsw cached beir scores.
lintool Sep 12, 2024
530bd03
adjust based on scores
lintool Sep 12, 2024
1ec9b7c
add rest of hnsw scores.
lintool Sep 12, 2024
e2716b2
Tweak tolerance.
lintool Sep 13, 2024
deee1ee
More tolerance tweaks.
lintool Sep 13, 2024
a6578a3
Added hnsw stubs.
lintool Sep 13, 2024
e9c2402
tweaks
lintool Sep 13, 2024
7db41f5
another round of tweaks.
lintool Sep 13, 2024
5996a8e
more tolerance tweaks, removed dead code.
lintool Sep 13, 2024
33ed276
tolerance tweaks
lintool Sep 15, 2024
4f1f1c3
Initial cohere metrics.
lintool Sep 15, 2024
a66c7ab
calibrating cohere
lintool Sep 15, 2024
c1bfce5
cohere tweaks.
lintool Sep 15, 2024
c1e27f5
Merge branch 'master' into refactoring
lintool Sep 15, 2024
69fed51
refresh
lintool Sep 15, 2024
87d17b5
Fixed test.
lintool Sep 15, 2024
4415737
More score tweaks.
lintool Sep 17, 2024
2ccc9cc
another round of tweaks.
lintool Sep 18, 2024
a8ae8f9
more tweaks.
lintool Sep 18, 2024
5ce0101
one final round of tweaks.
lintool Sep 19, 2024
Added hnsw stubs.
lintool committed Sep 13, 2024
commit a6578a31aed89deed95a127b86edcd4d54c79003
253 changes: 130 additions & 123 deletions src/main/python/run_regression.py
@@ -212,58 +212,58 @@ def construct_convert_commands(yaml_data):
# 'flat-cached': beir_flat_cached,
# }

beir_hnsw_int8_onnx = defaultdict(lambda: 0.005)
beir_hnsw_int8_onnx['ArguAna'] = 0.03
beir_hnsw_int8_onnx['BioASQ'] = 0.02
beir_hnsw_int8_onnx['DBPedia'] = 0.007
beir_hnsw_int8_onnx['FiQA-2018'] = 0.007
beir_hnsw_int8_onnx['HotpotQA'] = 0.008
beir_hnsw_int8_onnx['NFCorpus'] = 0.006
beir_hnsw_int8_onnx['Robust04'] = 0.006
beir_hnsw_int8_onnx['Signal-1M'] = 0.04
beir_hnsw_int8_onnx['TREC-NEWS'] = 0.02
beir_hnsw_int8_onnx['Webis-Touche2020'] = 0.01

beir_hnsw_int8_cached = defaultdict(lambda: 0.005)
beir_hnsw_int8_cached['BioASQ'] = 0.02
beir_hnsw_int8_cached['FiQA-2018'] = 0.007
beir_hnsw_int8_cached['HotpotQA'] = 0.007
beir_hnsw_int8_cached['Signal-1M'] = 0.04
beir_hnsw_int8_cached['TREC-NEWS'] = 0.02
beir_hnsw_int8_cached['Webis-Touche2020'] = 0.006

beir_hnsw_onnx = defaultdict(lambda: 0.003)
beir_hnsw_onnx['ArguAna'] = 0.02
beir_hnsw_onnx['BioASQ'] = 0.01
beir_hnsw_onnx['CQADupStack-wordpress'] = 0.004
beir_hnsw_onnx['DBPedia'] = 0.006
beir_hnsw_onnx['FEVER'] = 0.007
beir_hnsw_onnx['FiQA-2018'] = 0.007
beir_hnsw_onnx['HotpotQA'] = 0.007
beir_hnsw_onnx['Robust04'] = 0.004
beir_hnsw_onnx['Signal-1M'] = 0.05
beir_hnsw_onnx['TREC-NEWS'] = 0.02

beir_hnsw_cached = defaultdict(lambda: 0.003)
beir_hnsw_cached['BioASQ'] = 0.01
beir_hnsw_cached['DBPedia'] = 0.006
beir_hnsw_cached['FEVER'] = 0.008
beir_hnsw_cached['FiQA-2018'] = 0.008
beir_hnsw_cached['HotpotQA'] = 0.007
beir_hnsw_cached['Signal-1M'] = 0.05
beir_hnsw_cached['TREC-NEWS'] = 0.025

beir_hnsw_tolerance = {
'hnsw-int8-onnx': beir_hnsw_int8_onnx,
'hnsw-int8-cached': beir_hnsw_int8_cached,
'hnsw-onnx': beir_hnsw_onnx,
'hnsw-cached': beir_hnsw_cached,
}
# beir_hnsw_int8_onnx = defaultdict(lambda: 0.005)
# beir_hnsw_int8_onnx['ArguAna'] = 0.03
# beir_hnsw_int8_onnx['BioASQ'] = 0.02
# beir_hnsw_int8_onnx['DBPedia'] = 0.007
# beir_hnsw_int8_onnx['FiQA-2018'] = 0.007
# beir_hnsw_int8_onnx['HotpotQA'] = 0.008
# beir_hnsw_int8_onnx['NFCorpus'] = 0.006
# beir_hnsw_int8_onnx['Robust04'] = 0.006
# beir_hnsw_int8_onnx['Signal-1M'] = 0.04
# beir_hnsw_int8_onnx['TREC-NEWS'] = 0.02
# beir_hnsw_int8_onnx['Webis-Touche2020'] = 0.01
#
# beir_hnsw_int8_cached = defaultdict(lambda: 0.005)
# beir_hnsw_int8_cached['BioASQ'] = 0.02
# beir_hnsw_int8_cached['FiQA-2018'] = 0.007
# beir_hnsw_int8_cached['HotpotQA'] = 0.007
# beir_hnsw_int8_cached['Signal-1M'] = 0.04
# beir_hnsw_int8_cached['TREC-NEWS'] = 0.02
# beir_hnsw_int8_cached['Webis-Touche2020'] = 0.006
#
# beir_hnsw_onnx = defaultdict(lambda: 0.003)
# beir_hnsw_onnx['ArguAna'] = 0.02
# beir_hnsw_onnx['BioASQ'] = 0.01
# beir_hnsw_onnx['CQADupStack-wordpress'] = 0.004
# beir_hnsw_onnx['DBPedia'] = 0.006
# beir_hnsw_onnx['FEVER'] = 0.007
# beir_hnsw_onnx['FiQA-2018'] = 0.007
# beir_hnsw_onnx['HotpotQA'] = 0.007
# beir_hnsw_onnx['Robust04'] = 0.004
# beir_hnsw_onnx['Signal-1M'] = 0.05
# beir_hnsw_onnx['TREC-NEWS'] = 0.02
#
# beir_hnsw_cached = defaultdict(lambda: 0.003)
# beir_hnsw_cached['BioASQ'] = 0.01
# beir_hnsw_cached['DBPedia'] = 0.006
# beir_hnsw_cached['FEVER'] = 0.008
# beir_hnsw_cached['FiQA-2018'] = 0.008
# beir_hnsw_cached['HotpotQA'] = 0.007
# beir_hnsw_cached['Signal-1M'] = 0.05
# beir_hnsw_cached['TREC-NEWS'] = 0.025
#
# beir_hnsw_tolerance = {
# 'hnsw-int8-onnx': beir_hnsw_int8_onnx,
# 'hnsw-int8-cached': beir_hnsw_int8_cached,
# 'hnsw-onnx': beir_hnsw_onnx,
# 'hnsw-cached': beir_hnsw_cached,
# }

#flat_model_type_pattern = re.compile(r'(flat-int8-onnx|flat-int8-cached|flat-onnx|flat-cached)$')
hnsw_model_type_pattern = re.compile(r'(hnsw-int8-onnx|hnsw-int8-cached|hnsw-onnx|hnsw-cached)$')

beir_dataset_pattern = re.compile(r'BEIR \(v1.0.0\): (.*)$')
# hnsw_model_type_pattern = re.compile(r'(hnsw-int8-onnx|hnsw-int8-cached|hnsw-onnx|hnsw-cached)$')
#
# beir_dataset_pattern = re.compile(r'BEIR \(v1.0.0\): (.*)$')

# msmarco_v1_flat_int8_onnx = defaultdict(lambda: 0.002)
# msmarco_v1_flat_int8_cached = defaultdict(lambda: 0.002)
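The tables commented out in this hunk all follow one pattern: a `defaultdict` supplies a default tolerance, a few keys override it, and two regular expressions recover the table keys from the model and topic-set names. A minimal sketch of that pattern follows, using values copied from the hunk. The helper `lookup_tolerance` and the example topic-set strings are hypothetical, introduced only for illustration.

```python
import re
from collections import defaultdict

# Default tolerance of 0.003 with per-dataset overrides (values from the hunk above).
beir_hnsw_onnx = defaultdict(lambda: 0.003)
beir_hnsw_onnx['ArguAna'] = 0.02
beir_hnsw_onnx['Signal-1M'] = 0.05

# Only one model type is filled in here to keep the sketch short.
beir_hnsw_tolerance = {'hnsw-onnx': beir_hnsw_onnx}

# Patterns as they appear in the hunk.
hnsw_model_type_pattern = re.compile(r'(hnsw-int8-onnx|hnsw-int8-cached|hnsw-onnx|hnsw-cached)$')
beir_dataset_pattern = re.compile(r'BEIR \(v1.0.0\): (.*)$')


def lookup_tolerance(model_name, topic_set_name):
    """Hypothetical helper: resolve a tolerance from the hardcoded tables."""
    # The model type is the suffix of the model name, e.g. 'hnsw-onnx'.
    model_type = hnsw_model_type_pattern.search(model_name).group(1)
    # The dataset is whatever follows the 'BEIR (v1.0.0): ' prefix.
    beir_dataset = beir_dataset_pattern.search(topic_set_name).group(1)
    # Datasets without an override fall through to the defaultdict default.
    return beir_hnsw_tolerance[model_type][beir_dataset]


overridden = lookup_tolerance('bge-hnsw-onnx', 'BEIR (v1.0.0): ArguAna')
defaulted = lookup_tolerance('bge-hnsw-onnx', 'BEIR (v1.0.0): NFCorpus')
```

With this design every tolerance lives in the script and depends on naming conventions, which is what the PR moves away from by putting the values in the YAML files.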
@@ -313,52 +313,52 @@ def construct_convert_commands(yaml_data):
# 'flat-cached': dl20_flat_cached,
# }

msmarco_v1_hnsw_int8_onnx = defaultdict(lambda: 0.01)
msmarco_v1_hnsw_int8_cached = defaultdict(lambda: 0.01)
msmarco_v1_hnsw_onnx = defaultdict(lambda: 0.01)
msmarco_v1_hnsw_onnx['cos-dpr-distil-hnsw-onnx'] = 0.015
msmarco_v1_hnsw_cached = defaultdict(lambda: 0.01)
msmarco_v1_hnsw_cached['cos-dpr-distil-hnsw-cached'] = 0.015

msmarco_v1_hnsw_tolerance = {
'hnsw-int8-onnx': msmarco_v1_hnsw_int8_onnx,
'hnsw-int8-cached': msmarco_v1_hnsw_int8_cached,
'hnsw-onnx': msmarco_v1_hnsw_onnx,
'hnsw-cached': msmarco_v1_hnsw_cached,
}

dl19_hnsw_int8_onnx = defaultdict(lambda: 0.01)
dl19_hnsw_int8_onnx['bge-hnsw-int8-onnx'] = 0.025
dl19_hnsw_int8_onnx['cos-dpr-distil-hnsw-int8-onnx'] = 0.025
dl19_hnsw_int8_cached = defaultdict(lambda: 0.01)
dl19_hnsw_int8_cached['bge-hnsw-int8-cached'] = 0.02
dl19_hnsw_int8_cached['cohere-embed-english-v3.0-hnsw-int8-cached'] = 0.02
dl19_hnsw_int8_cached['cos-dpr-distil-hnsw-int8-cached'] = 0.025
dl19_hnsw_int8_cached['openai-ada2-hnsw-int8-cached'] = 0.015
dl19_hnsw_onnx = defaultdict(lambda: 0.015)
dl19_hnsw_onnx['bge-hnsw-onnx'] = 0.02
dl19_hnsw_cached = defaultdict(lambda: 0.015)
dl19_hnsw_cached['cohere-embed-english-v3.0-hnsw-cached'] = 0.02

dl19_hnsw_tolerance = {
'hnsw-int8-onnx': dl19_hnsw_int8_onnx,
'hnsw-int8-cached': dl19_hnsw_int8_cached,
'hnsw-onnx': dl19_hnsw_onnx,
'hnsw-cached': dl19_hnsw_cached,
}

dl20_hnsw_int8_onnx = defaultdict(lambda: 0.02)
dl20_hnsw_int8_cached = defaultdict(lambda: 0.02)
dl20_hnsw_onnx = defaultdict(lambda: 0.015)
dl20_hnsw_cached = defaultdict(lambda: 0.015)
dl20_hnsw_cached['cohere-embed-english-v3.0-hnsw-cached'] = 0.025

dl20_hnsw_tolerance = {
'hnsw-int8-onnx': dl20_hnsw_int8_onnx,
'hnsw-int8-cached': dl20_hnsw_int8_cached,
'hnsw-onnx': dl20_hnsw_onnx,
'hnsw-cached': dl20_hnsw_cached,
}
# msmarco_v1_hnsw_int8_onnx = defaultdict(lambda: 0.01)
# msmarco_v1_hnsw_int8_cached = defaultdict(lambda: 0.01)
# msmarco_v1_hnsw_onnx = defaultdict(lambda: 0.01)
# msmarco_v1_hnsw_onnx['cos-dpr-distil-hnsw-onnx'] = 0.015
# msmarco_v1_hnsw_cached = defaultdict(lambda: 0.01)
# msmarco_v1_hnsw_cached['cos-dpr-distil-hnsw-cached'] = 0.015
#
# msmarco_v1_hnsw_tolerance = {
# 'hnsw-int8-onnx': msmarco_v1_hnsw_int8_onnx,
# 'hnsw-int8-cached': msmarco_v1_hnsw_int8_cached,
# 'hnsw-onnx': msmarco_v1_hnsw_onnx,
# 'hnsw-cached': msmarco_v1_hnsw_cached,
# }
#
# dl19_hnsw_int8_onnx = defaultdict(lambda: 0.01)
# dl19_hnsw_int8_onnx['bge-hnsw-int8-onnx'] = 0.025
# dl19_hnsw_int8_onnx['cos-dpr-distil-hnsw-int8-onnx'] = 0.025
# dl19_hnsw_int8_cached = defaultdict(lambda: 0.01)
# dl19_hnsw_int8_cached['bge-hnsw-int8-cached'] = 0.02
# dl19_hnsw_int8_cached['cohere-embed-english-v3.0-hnsw-int8-cached'] = 0.02
# dl19_hnsw_int8_cached['cos-dpr-distil-hnsw-int8-cached'] = 0.025
# dl19_hnsw_int8_cached['openai-ada2-hnsw-int8-cached'] = 0.015
# dl19_hnsw_onnx = defaultdict(lambda: 0.015)
# dl19_hnsw_onnx['bge-hnsw-onnx'] = 0.02
# dl19_hnsw_cached = defaultdict(lambda: 0.015)
# dl19_hnsw_cached['cohere-embed-english-v3.0-hnsw-cached'] = 0.02
#
# dl19_hnsw_tolerance = {
# 'hnsw-int8-onnx': dl19_hnsw_int8_onnx,
# 'hnsw-int8-cached': dl19_hnsw_int8_cached,
# 'hnsw-onnx': dl19_hnsw_onnx,
# 'hnsw-cached': dl19_hnsw_cached,
# }
#
# dl20_hnsw_int8_onnx = defaultdict(lambda: 0.02)
# dl20_hnsw_int8_cached = defaultdict(lambda: 0.02)
# dl20_hnsw_onnx = defaultdict(lambda: 0.015)
# dl20_hnsw_cached = defaultdict(lambda: 0.015)
# dl20_hnsw_cached['cohere-embed-english-v3.0-hnsw-cached'] = 0.025
#
# dl20_hnsw_tolerance = {
# 'hnsw-int8-onnx': dl20_hnsw_int8_onnx,
# 'hnsw-int8-cached': dl20_hnsw_int8_cached,
# 'hnsw-onnx': dl20_hnsw_onnx,
# 'hnsw-cached': dl20_hnsw_cached,
# }


def evaluate_and_verify(yaml_data, dry_run):
@@ -392,13 +392,18 @@ def evaluate_and_verify(yaml_data, dry_run):
using_hnsw = True if 'type' in model and model['type'] == 'hnsw' else False
using_flat = True if 'type' in model and model['type'] == 'flat' else False

if using_flat:
if 'tolerance' in model:
#print(model['tolerance'])
#print(metric)
tolerance_ok = model['tolerance'][metric['metric']][i]
else:
tolerance_ok = 0
if 'tolerance' in model:
tolerance_ok = model['tolerance'][metric['metric']][i]
else:
tolerance_ok = 0

# if using_flat:
# if 'tolerance' in model:
# #print(model['tolerance'])
# #print(metric)
# tolerance_ok = model['tolerance'][metric['metric']][i]
# else:
# tolerance_ok = 0
# else:
# # Extract model
# match = flat_model_type_pattern.search(model['name'])
@@ -417,26 +422,28 @@ def evaluate_and_verify(yaml_data, dry_run):
# elif using_flat and 'DL20' in topic_set['name']:
# tolerance_ok = dl20_flat_tolerance[model_type][model['name']]

if using_hnsw:
if 'tolerance' in model:
tolerance_ok = model['tolerance'][metric['metric']][i]
else:
# Extract model
match = hnsw_model_type_pattern.search(model['name'])
model_type = match.group(1)

if 'BEIR' in topic_set['name']:
# Extract BEIR dataset
match = beir_dataset_pattern.search(topic_set['name'])
beir_dataset = match.group(1)

tolerance_ok = beir_hnsw_tolerance[model_type][beir_dataset]
elif 'MS MARCO Passage' in topic_set['name']:
tolerance_ok = msmarco_v1_hnsw_tolerance[model_type][model['name']]
elif 'DL19' in topic_set['name']:
tolerance_ok = dl19_hnsw_tolerance[model_type][model['name']]
elif 'DL20' in topic_set['name']:
tolerance_ok = dl20_hnsw_tolerance[model_type][model['name']]
# if using_hnsw:
# if 'tolerance' in model:
# tolerance_ok = model['tolerance'][metric['metric']][i]
# else:
# tolerance_ok = 0
# else:
# # Extract model
# match = hnsw_model_type_pattern.search(model['name'])
# model_type = match.group(1)
#
# if 'BEIR' in topic_set['name']:
# # Extract BEIR dataset
# match = beir_dataset_pattern.search(topic_set['name'])
# beir_dataset = match.group(1)
#
# tolerance_ok = beir_hnsw_tolerance[model_type][beir_dataset]
# elif 'MS MARCO Passage' in topic_set['name']:
# tolerance_ok = msmarco_v1_hnsw_tolerance[model_type][model['name']]
# elif 'DL19' in topic_set['name']:
# tolerance_ok = dl19_hnsw_tolerance[model_type][model['name']]
# elif 'DL20' in topic_set['name']:
# tolerance_ok = dl20_hnsw_tolerance[model_type][model['name']]

if using_flat or using_hnsw:
result_str = (f'expected: {expected:.4f} actual: {actual:.4f} '
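After this change, the tolerance lookup in `evaluate_and_verify` no longer depends on whether the model is flat or HNSW: it reads `model['tolerance'][metric['metric']][i]` when the model has a `tolerance` key and falls back to 0 otherwise. The sketch below isolates that lookup. The function names `get_tolerance` and `within_tolerance` are hypothetical, and the absolute-difference comparison is an assumption, since the comparison itself is not visible in this diff.

```python
def get_tolerance(model, metric_name, i):
    """Hypothetical helper mirroring the lookup shown in the diff.

    `i` is the index of the topic set, so each metric holds one
    tolerance per topic set.
    """
    if 'tolerance' in model:
        return model['tolerance'][metric_name][i]
    # No tolerance block in the YAML: require an exact match.
    return 0


def within_tolerance(expected, actual, tolerance_ok):
    # Assumption: an absolute-difference check; the real script may differ.
    return abs(expected - actual) <= tolerance_ok


# Illustrative model entries shaped like the parsed YAML.
hnsw_model = {'type': 'hnsw', 'tolerance': {'R@1000': [0.005]}}
plain_model = {'type': 'flat'}

tol_hnsw = get_tolerance(hnsw_model, 'R@1000', 0)
tol_plain = get_tolerance(plain_model, 'R@1000', 0)
```

A model without a `tolerance` block therefore passes only when the actual score equals the expected score under this comparison.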
@@ -63,3 +63,12 @@ models:
- 0.6171
R@1000:
- 0.8472
tolerance:
AP@1000:
- 0.005
nDCG@10:
- 0.005
R@100:
- 0.005
R@1000:
- 0.005
@@ -63,3 +63,12 @@ models:
- 0.6171
R@1000:
- 0.8472
tolerance:
AP@1000:
- 0.005
nDCG@10:
- 0.005
R@100:
- 0.005
R@1000:
- 0.005
@@ -63,3 +63,12 @@ models:
- 0.6171
R@1000:
- 0.8472
tolerance:
AP@1000:
- 0.005
nDCG@10:
- 0.005
R@100:
- 0.005
R@1000:
- 0.005
@@ -63,3 +63,12 @@ models:
- 0.6171
R@1000:
- 0.8472
tolerance:
AP@1000:
- 0.005
nDCG@10:
- 0.005
R@100:
- 0.005
R@1000:
- 0.005
@@ -62,4 +62,13 @@ models:
R@100:
- 0.6484
R@1000:
- 0.8630
- 0.8630
tolerance:
AP@1000:
- 0.005
nDCG@10:
- 0.005
R@100:
- 0.005
R@1000:
- 0.005
@@ -62,4 +62,13 @@ models:
R@100:
- 0.6484
R@1000:
- 0.8630
- 0.8630
tolerance:
AP@1000:
- 0.005
nDCG@10:
- 0.005
R@100:
- 0.005
R@1000:
- 0.005
@@ -63,3 +63,12 @@ models:
- 0.6173
R@1000:
- 0.8201
tolerance:
AP@1000:
- 0.005
nDCG@10:
- 0.005
R@100:
- 0.005
R@1000:
- 0.005
@@ -63,3 +63,12 @@ models:
- 0.6173
R@1000:
- 0.8201
tolerance:
AP@1000:
- 0.005
nDCG@10:
- 0.005
R@100:
- 0.005
R@1000:
- 0.005
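Each YAML hunk above adds a `tolerance` block with one list per metric, parallel to the expected scores. Because the script indexes `model['tolerance'][metric][i]` directly, a metric missing from the block or a list that is too short would raise an error at verification time. The sketch below checks for such gaps ahead of time. It is not part of the PR: the helper `find_tolerance_gaps`, the model name, and the `results` key are assumptions (the parent key of the scores is not visible in these hunks), and the model is written as a Python dict rather than parsed from YAML to keep the example dependency-free. The pairing of 0.6171 with `R@100` is inferred from the hunk context.

```python
# A model entry shaped like the parsed YAML. Scores and tolerances are taken
# from the hunks above; 'name' and the 'results' key are assumptions.
model = {
    'name': 'example-hnsw-onnx',  # hypothetical
    'results': {'R@100': [0.6171], 'R@1000': [0.8472]},
    'tolerance': {'R@100': [0.005], 'R@1000': [0.005]},
}


def find_tolerance_gaps(model):
    """Hypothetical check: list metrics whose tolerance entries do not line up
    with the expected scores."""
    problems = []
    tolerance = model.get('tolerance', {})
    for metric_name, expected_scores in model['results'].items():
        if metric_name not in tolerance:
            problems.append(f'{metric_name}: no tolerance listed')
        elif len(tolerance[metric_name]) != len(expected_scores):
            problems.append(
                f'{metric_name}: {len(tolerance[metric_name])} tolerances '
                f'for {len(expected_scores)} scores')
    return problems


gaps = find_tolerance_gaps(model)
```

An empty result means every expected score has a matching tolerance at the same index.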