diff --git a/examples/run_pipeline_deezy_reldisamb+wmtops.ipynb b/examples/run_pipeline_deezy_reldisamb+wmtops.ipynb index 97aee962..6df6aaa7 100644 --- a/examples/run_pipeline_deezy_reldisamb+wmtops.ipynb +++ b/examples/run_pipeline_deezy_reldisamb+wmtops.ipynb @@ -72,7 +72,7 @@ "metadata": {}, "outputs": [], "source": [ - "with sqlite3.connect(\"../resources/rel_db/embedding_database.db\") as conn:\n", + "with sqlite3.connect(\"../resources/rel_db/embeddings_database.db\") as conn:\n", " cursor = conn.cursor()\n", " mylinker = linking.Linker(\n", " method=\"reldisamb\",\n", diff --git a/examples/run_pipeline_deezy_reldisamb+wpubl+wmtops.ipynb b/examples/run_pipeline_deezy_reldisamb+wpubl+wmtops.ipynb index bd4adc99..f2f4cff2 100644 --- a/examples/run_pipeline_deezy_reldisamb+wpubl+wmtops.ipynb +++ b/examples/run_pipeline_deezy_reldisamb+wpubl+wmtops.ipynb @@ -72,7 +72,7 @@ "metadata": {}, "outputs": [], "source": [ - "with sqlite3.connect(\"../resources/rel_db/embedding_database.db\") as conn:\n", + "with sqlite3.connect(\"../resources/rel_db/embeddings_database.db\") as conn:\n", " cursor = conn.cursor()\n", " mylinker = linking.Linker(\n", " method=\"reldisamb\",\n", diff --git a/examples/run_pipeline_deezy_reldisamb+wpubl.ipynb b/examples/run_pipeline_deezy_reldisamb+wpubl.ipynb index ccbd53d3..ad475493 100644 --- a/examples/run_pipeline_deezy_reldisamb+wpubl.ipynb +++ b/examples/run_pipeline_deezy_reldisamb+wpubl.ipynb @@ -72,7 +72,7 @@ "metadata": {}, "outputs": [], "source": [ - "with sqlite3.connect(\"../resources/rel_db/embedding_database.db\") as conn:\n", + "with sqlite3.connect(\"../resources/rel_db/embeddings_database.db\") as conn:\n", " cursor = conn.cursor()\n", " mylinker = linking.Linker(\n", " method=\"reldisamb\",\n", diff --git a/examples/run_pipeline_deezy_reldisamb.ipynb b/examples/run_pipeline_deezy_reldisamb.ipynb index 85102e3a..4d7bf262 100644 --- a/examples/run_pipeline_deezy_reldisamb.ipynb +++ b/examples/run_pipeline_deezy_reldisamb.ipynb @@ -72,7 +72,7 @@ "metadata": {}, "outputs": [], "source": [ - "with sqlite3.connect(\"../resources/rel_db/embedding_database.db\") as conn:\n", + "with sqlite3.connect(\"../resources/rel_db/embeddings_database.db\") as conn:\n", " cursor = conn.cursor()\n", " mylinker = linking.Linker(\n", " method=\"reldisamb\",\n", diff --git a/experiments/toponym_resolution.py b/experiments/toponym_resolution.py index 747489e9..8a3864f8 100644 --- a/experiments/toponym_resolution.py +++ b/experiments/toponym_resolution.py @@ -119,7 +119,7 @@ # -------------------------------------- # Instantiate the linker: - with sqlite3.connect("../resources/rel_db/embedding_database.db") as conn: + with sqlite3.connect("../resources/rel_db/embeddings_database.db") as conn: cursor = conn.cursor() mylinker = linking.Linker( method=top_res_method, diff --git a/tests/test_disambiguation.py b/tests/test_disambiguation.py index db85e44c..00841784 100644 --- a/tests/test_disambiguation.py +++ b/tests/test_disambiguation.py @@ -20,7 +20,7 @@ def test_embeddings(): """ # Test 1: Check glove embeddings mentions = ["in", "apple"] - with sqlite3.connect("resources/rel_db/embedding_database.db") as conn: + with sqlite3.connect("resources/rel_db/embeddings_database.db") as conn: cursor = conn.cursor() embs = rel_utils.get_db_emb(cursor, mentions, "snd") assert len(mentions) == len(embs) @@ -47,7 +47,9 @@ def test_embeddings(): def test_prepare_initial_data(): - df = pd.read_csv("experiments/outputs/data/lwm/linking_df_split.tsv", sep="\t").iloc[:1] + df = pd.read_csv( + "experiments/outputs/data/lwm/linking_df_split.tsv", sep="\t" + ).iloc[:1] parsed_doc = rel_utils.prepare_initial_data(df, context_len=100) assert parsed_doc["4939308_1"][0]["mention"] == "STALYBRIDGE" assert parsed_doc["4939308_1"][0]["gold"][0] == "Q1398653" @@ -106,7 +108,7 @@ def test_train(): "do_test": False, }, ) - with sqlite3.connect("resources/rel_db/embedding_database.db") as conn: + with sqlite3.connect("resources/rel_db/embeddings_database.db") as conn: cursor = conn.cursor() mylinker = linking.Linker( @@ -148,7 +150,10 @@ def test_train(): # candidates to the training set): mylinker.rel_params["ed_model"] = mylinker.train_load_model(myranker) - assert type(mylinker.rel_params["ed_model"]) == entity_disambiguation.EntityDisambiguation + assert ( + type(mylinker.rel_params["ed_model"]) + == entity_disambiguation.EntityDisambiguation + ) # assert expected performance on test set assert mylinker.rel_params["ed_model"].best_performance["f1"] == 0.6583541147132169 @@ -206,7 +211,7 @@ def test_load_eval_model(): }, ) - with sqlite3.connect("resources/rel_db/embedding_database.db") as conn: + with sqlite3.connect("resources/rel_db/embeddings_database.db") as conn: cursor = conn.cursor() mylinker = linking.Linker( @@ -249,7 +254,10 @@ def test_load_eval_model(): # candidates to the training set): mylinker.rel_params["ed_model"] = mylinker.train_load_model(myranker) - assert type(mylinker.rel_params["ed_model"]) == entity_disambiguation.EntityDisambiguation + assert ( + type(mylinker.rel_params["ed_model"]) + == entity_disambiguation.EntityDisambiguation + ) def test_predict(): @@ -303,7 +311,7 @@ def test_predict(): "do_test": False, }, ) - with sqlite3.connect("resources/rel_db/embedding_database.db") as conn: + with sqlite3.connect("resources/rel_db/embeddings_database.db") as conn: cursor = conn.cursor() mylinker = linking.Linker( diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index 1d278d80..b31bebb9 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -147,7 +147,7 @@ def test_deezy_rel_wpubl_wmtops(): }, ) - with sqlite3.connect("resources/rel_db/embedding_database.db") as conn: + with sqlite3.connect("resources/rel_db/embeddings_database.db") as conn: cursor = conn.cursor() mylinker = linking.Linker( method="reldisamb", @@ -238,7 +238,7 @@ def test_perfect_rel_wpubl_wmtops(): }, ) - with sqlite3.connect("resources/rel_db/embedding_database.db") as conn: + with sqlite3.connect("resources/rel_db/embeddings_database.db") as conn: cursor = conn.cursor() mylinker = linking.Linker( method="reldisamb", diff --git a/utils/REL/entity_disambiguation.py b/utils/REL/entity_disambiguation.py index 39f211a1..d04e0602 100644 --- a/utils/REL/entity_disambiguation.py +++ b/utils/REL/entity_disambiguation.py @@ -47,6 +47,18 @@ def __init__(self, db_embs, user_config, reset_embeddings=False): assert ( test is not None ), "DB embeddings in wrong folder..? Test embedding not found.." + test = rel_utils.get_db_emb(self.db_embs, ["#ENTITY/UNK#"], "entity")[0] + assert ( + test is not None + ), "DB embeddings in wrong folder..? Test embedding not found.." + test = rel_utils.get_db_emb(self.db_embs, ["#WORD/UNK#"], "word")[0] + assert ( + test is not None + ), "DB embeddings in wrong folder..? Test embedding not found.." + test = rel_utils.get_db_emb(self.db_embs, ["#SND/UNK#"], "snd")[0] + assert ( + test is not None + ), "DB embeddings in wrong folder..? Test embedding not found.." # Initialise embedding dictionary: self.__load_embeddings()