Skip to content

Commit

Permalink
update tests
Browse files Browse the repository at this point in the history
  • Loading branch information
rwood-97 committed Aug 30, 2023
1 parent 6ad592b commit 15531a0
Show file tree
Hide file tree
Showing 13 changed files with 410 additions and 513 deletions.
9 changes: 4 additions & 5 deletions experiments/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,10 +223,7 @@ def prepare_data(self) -> dict:
# Obtain candidates per sentence:
for sentence_id in tqdm(dMentionsPred):
pred_mentions_sent = dMentionsPred[sentence_id]
(
wk_cands,
self.myranker.already_collected_cands,
) = self.myranker.find_candidates(pred_mentions_sent)
wk_cands = self.myranker.find_candidates(pred_mentions_sent)
dCandidates[sentence_id] = wk_cands

# -------------------------------------------
Expand Down Expand Up @@ -466,10 +463,12 @@ def create_mentions_df(self) -> pd.DataFrame:
data=rows,
)

print(f"Saving to {os.path.join(self.data_path,self.dataset,f'{self.myner.model}_{cand_approach}')}")
output_path = (
self.data_path + self.dataset + "/" + self.myner.model + "_" + cand_approach
os.path.join(self.data_path,self.dataset,f"{self.myner.model}_{cand_approach}")
)


# List of columns to merge (i.e. columns where we have indicated
# out data splits), and "article_id", the columns on which we
# will merge the data:
Expand Down
5 changes: 5 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -65,3 +65,8 @@ include = '\.pyi?$'

[tool.isort]
profile = "black"

[tool.pytest.ini_options]
markers = [
"deezy: tests which need a deezy model",
]
3 changes: 0 additions & 3 deletions t_res/geoparser/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,9 +150,6 @@ def __init__(
self.myranker
)

# Check we've actually loaded the mentions2wikidata dictionary:
assert self.myranker.mentions_to_wikidata["London"] is not None

def run_sentence(
self,
sentence: str,
Expand Down
4 changes: 2 additions & 2 deletions t_res/geoparser/recogniser.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,7 +270,7 @@ def compute_metrics(p: Tuple[list, list]) -> dict:
training_args = TrainingArguments(
output_dir=self.model_path,
evaluation_strategy="epoch",
logging_dir=self.model_path + "runs/" + self.model,
logging_dir=os.path.join(self.model_path,"runs/",self.model),
learning_rate=self.training_args["learning_rate"],
per_device_train_batch_size=self.training_args["batch_size"],
per_device_eval_batch_size=self.training_args["batch_size"],
Expand All @@ -295,7 +295,7 @@ def compute_metrics(p: Tuple[list, list]) -> dict:
trainer.evaluate()

# Save the model:
trainer.save_model(self.model_path + self.model + ".model")
trainer.save_model(os.path.join(self.model_path,f"{self.model}.model"))

# -------------------------------------------------------------
def create_pipeline(self) -> Pipeline:
Expand Down
82 changes: 25 additions & 57 deletions tests/test_data_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,13 @@
from pathlib import Path

import pandas as pd
import pytest

large_resources = "/resources/" # path to large resources
small_resources = "./resources/" # path to small resources
processed_path_lwm = "./experiments/outputs/data/lwm/" # path to processed LwM data
processed_path_hipe = "./experiments/outputs/data/hipe/" # path to processed LwM data
# Directory containing this test module; the sample-data paths below are built relative to it.
current_dir = Path(__file__).parent.resolve()

small_resources = os.path.join(current_dir,"sample_files/resources/") # path to small resources
processed_path_lwm = os.path.join(current_dir,"sample_files/experiments/outputs/data/lwm/") # path to processed LwM data
processed_path_hipe = os.path.join(current_dir,"sample_files/experiments/outputs/data/hipe/") # path to processed HIPE data


def test_publication_metadata_exists():
Expand Down Expand Up @@ -53,8 +55,8 @@ def test_original_lwm_data():
train_metadata = pd.read_csv(path_train_metadata, sep="\t")
test_metadata = pd.read_csv(path_test_metadata, sep="\t")
# Assert the size of the metadata files:
assert train_metadata.shape[0] == 343
assert test_metadata.shape[0] == 112
assert train_metadata.shape[0] == 1
assert test_metadata.shape[0] == 1
assert train_metadata.shape[1] == 10
assert test_metadata.shape[1] == 10
# Items in metadata match number of files in directory, for test:
Expand Down Expand Up @@ -98,54 +100,20 @@ def test_lwm_ner_conversion_fine():
dtype={"id": str},
)
# Assert size of the train and dev sets:
assert df_ner_train.shape == (5216, 3)
assert df_ner_dev.shape == (1304, 3)
assert df_ner_train.shape == (141, 3)
assert df_ner_dev.shape == (41, 3)
# Assert number of sentences in train and dev (length of list and set should be the same):
assert (
len(list(df_ner_train["id"]) + list(df_ner_dev["id"]))
== len(set(list(df_ner_train["id"]) + list(df_ner_dev["id"])))
== df_ner_train.shape[0] + df_ner_dev.shape[0]
)
# Assert ID is read as string:
assert type(df_ner_train["id"].iloc[0]) == str
assert isinstance(df_ner_train["id"].iloc[0],str)
# Assert number of unique articles:
train_articles = [x.split("_")[0] for x in list(df_ner_train["id"])]
dev_articles = [x.split("_")[0] for x in list(df_ner_dev["id"])]
assert len(set(train_articles + dev_articles)) == 343


def test_lwm_ner_conversion_coarse():
    """
    Test process_lwm_for_ner is not missing articles.

    Reads the coarse-grained NER train and dev files (JSON lines) found under
    ``processed_path_lwm`` and checks the size of each split, that sentence
    IDs are unique across both splits, that IDs are read as strings, and the
    number of distinct articles the sentences come from.
    """

    def _read_split(filename):
        # Both splits are stored the same way; "id" is forced to str so that
        # numeric-looking IDs are not converted to integers by pandas.
        return pd.read_json(
            os.path.join(processed_path_lwm, filename),
            orient="records",
            lines=True,
            dtype={"id": str},
        )

    df_ner_train = _read_split("ner_coarse_train.json")
    df_ner_dev = _read_split("ner_coarse_dev.json")

    # Assert size of the train and dev sets:
    assert df_ner_train.shape == (5216, 3)
    assert df_ner_dev.shape == (1304, 3)

    # Assert number of sentences in train and dev (length of list and set
    # should be the same, i.e. no sentence ID appears twice):
    all_ids = list(df_ner_train["id"]) + list(df_ner_dev["id"])
    assert (
        len(all_ids)
        == len(set(all_ids))
        == df_ner_train.shape[0] + df_ner_dev.shape[0]
    )

    # Assert ID is read as string:
    assert isinstance(df_ner_train["id"].iloc[0], str)

    # Assert number of unique articles (the article ID is the part of the
    # sentence ID before the first underscore):
    unique_articles = {sentence_id.split("_")[0] for sentence_id in all_ids}
    assert len(unique_articles) == 343

assert len(set(train_articles + dev_articles)) == 11

def test_lwm_linking_conversion():
"""
Expand All @@ -156,26 +124,26 @@ def test_lwm_linking_conversion():
sep="\t",
)
# Assert size of the dataset (i.e. number of articles):
assert df_linking.shape[0] == 455
assert df_linking.shape[0] == 14
# Assert if place has been filled correctly:
for x in df_linking.place:
assert type(x) == str
assert isinstance(x,str)
assert x != ""
# Assert if place QID has been filled correctly:
for x in df_linking.place_wqid:
assert type(x) == str
assert isinstance(x,str)
assert x != ""
for x in df_linking.annotations:
x = literal_eval(x)
for ann in x:
assert ann["wkdt_qid"] == "NIL" or ann["wkdt_qid"].startswith("Q")
assert df_linking[df_linking["originalsplit"] == "train"].shape[0] == 229
assert df_linking[df_linking["originalsplit"] == "dev"].shape[0] == 114
assert df_linking[df_linking["originalsplit"] == "test"].shape[0] == 112
assert df_linking[df_linking["withouttest"] == "train"].shape[0] == 153
assert df_linking[df_linking["withouttest"] == "dev"].shape[0] == 76
assert df_linking[df_linking["withouttest"] == "test"].shape[0] == 114
assert df_linking[df_linking["withouttest"] == "left_out"].shape[0] == 112
assert df_linking[df_linking["originalsplit"] == "train"].shape[0] == 10
assert df_linking[df_linking["originalsplit"] == "dev"].shape[0] == 2
assert df_linking[df_linking["originalsplit"] == "test"].shape[0] == 2
assert df_linking[df_linking["withouttest"] == "train"].shape[0] == 8
assert df_linking[df_linking["withouttest"] == "dev"].shape[0] == 2
assert df_linking[df_linking["withouttest"] == "test"].shape[0] == 2
assert df_linking[df_linking["withouttest"] == "left_out"].shape[0] == 2
test_withouttest = set(
list(df_linking[df_linking["withouttest"] == "test"].article_id)
)
Expand All @@ -185,7 +153,7 @@ def test_lwm_linking_conversion():
# Test articles of the original split and without test should not overlap:
assert not (test_withouttest & test_originalsplit)


@pytest.mark.skip(reason="Requires HIPE data")
def test_hipe_linking_conversion():
"""
Test process_hipe_for_linking is not missing articles.
Expand All @@ -211,11 +179,11 @@ def test_hipe_linking_conversion():
assert not (test_withouttest & test_originalsplit)
# Assert if place has been filled correctly:
for x in df_linking.place:
assert type(x) == str
assert isinstance(x,str)
assert x != ""
# Assert if place QID has been filled correctly:
for x in df_linking.place_wqid:
assert type(x) == str
assert isinstance(x,str)
assert x != ""
# Do HIPE stats match https://github.com/hipe-eval/HIPE-2022-data/blob/main/notebooks/hipe2022-datasets-stats.ipynb
number_locs = 0
Expand Down
44 changes: 44 additions & 0 deletions tests/test_deezy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
import os
from pathlib import Path

import pytest
from DeezyMatch import candidate_ranker

current_dir = Path(__file__).parent.resolve()

@pytest.mark.deezy(reason="Needs deezy model")
def test_deezy_match_deezy_candidate_ranker(tmp_path):
    """
    Check that DeezyMatch's ``candidate_ranker`` yields one entry per query.

    The candidate scenario is looked up under the ``sample_files`` directory
    next to this test module, and the ranker is given an output path inside
    pytest's ``tmp_path``.
    """
    settings = {
        # Paths and filenames of DeezyMatch models and data:
        "dm_path": os.path.join(current_dir,"sample_files/resources/deezymatch/"),
        "dm_cands": "wkdtalts",
        "dm_model": "w2v_ocr",
        "dm_output": "deezymatch_on_the_fly",
        # Ranking measures:
        "ranking_metric": "faiss",
        "selection_threshold": 50,
        "num_candidates": 1,
        "verbose": False,
        # DeezyMatch training:
        "overwrite_training": False,
        "do_test": False,
    }

    # Noisy, OCR-like strings used as queries.
    noisy_queries = ["-", "ST G", "• - , i", "- P", "• FERRIS"]

    # Scenario folder is "<dm_path>/combined/<dm_cands>_<dm_model>".
    scenario_name = settings["dm_cands"] + "_" + settings["dm_model"]
    scenario_dir = os.path.join(settings["dm_path"], "combined", scenario_name)
    results_dir = os.path.join(tmp_path, settings["dm_output"])

    ranked = candidate_ranker(
        candidate_scenario=scenario_dir,
        query=noisy_queries,
        ranking_metric=settings["ranking_metric"],
        selection_threshold=settings["selection_threshold"],
        num_candidates=settings["num_candidates"],
        # The search size is deliberately tied to the number of candidates.
        search_size=settings["num_candidates"],
        verbose=settings["verbose"],
        output_path=results_dir,
    )

    assert len(ranked) == len(noisy_queries)
Loading

0 comments on commit 15531a0

Please sign in to comment.