\n",
" \n",
"\n",
@@ -844,60 +786,75 @@
"text/plain": [
" name map\n",
"0 PL2 Baseline 0.206031\n",
- "1 LTR Baseline 0.144980"
+ "1 LTR Baseline 0.144662"
]
},
- "metadata": {
- "tags": []
- },
- "execution_count": 24
+ "execution_count": 18,
+ "metadata": {},
+ "output_type": "execute_result"
}
+ ],
+ "source": [
+ "from sklearn.ensemble import RandomForestRegressor\n",
+ "\n",
+ "BaselineLTR = fbr3f >> pt.ltr.apply_learned_model(RandomForestRegressor(n_estimators=400))\n",
+ "BaselineLTR.fit(train_topics, qrels)\n",
+ "\n",
+ "results = pt.Experiment([PL2, BaselineLTR], test_topics, qrels, [\"map\"], names=[\"PL2 Baseline\", \"LTR Baseline\"])\n",
+ "results"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Here, the RandomForest pipeline wasnt very good. LambdaMART is normally a bit better. Lets try that next..."
]
},
{
"cell_type": "markdown",
"metadata": {
- "id": "iGw58PCuumuT",
- "colab_type": "text"
+ "id": "iGw58PCuumuT"
},
"source": [
"## XgBoost Pipeline\n",
"\n",
- "We now demonstrate the use of a LambdaMART implementation from [xgBoost](https://xgboost.readthedocs.io/en/latest/). Again, pyTerrier provides a transformer object, namely `XGBoostLTR_pipeline`, which takes in the constrcutor the actual xgBoost model that you want to train. We took the xgBoost configuration from [their example code](https://github.com/dmlc/xgboost/blob/master/demo/rank/rank.py).\n",
+ "We now demonstrate the use of a LambdaMART implementation from [xgBoost](https://xgboost.readthedocs.io/en/latest/). Again, PyTerrier provides a Transformer object from `pt.ltr.apply_learned_model()`, this time passing `form='ltr'` as kwarg.\n",
"\n",
- "Call the `fit()` method on the full pipeline with the training and validation topics.\n",
+ "This takes in the constrcutor the actual xgBoost model that you want to train. We took the xgBoost configuration from [their example code](https://github.com/dmlc/xgboost/blob/master/demo/rank/rank.py).\n",
"\n",
- "Evaluate the results with the Experiment function by using the test topics"
+ "Call the `fit()` method on the full pipeline with the training *and validation* topics.\n",
+ "\n",
+ "The same pipeline can also be used with [LightGBM](https://github.com/microsoft/LightGBM).\n",
+ "\n",
+ "Evaluate the results with the Experiment function by using the test topics."
]
},
{
"cell_type": "code",
+ "execution_count": 19,
"metadata": {
- "id": "nM0r8EgFuGtQ",
- "colab_type": "code",
- "colab": {}
+ "id": "nM0r8EgFuGtQ"
},
+ "outputs": [],
"source": [
"import xgboost as xgb\n",
- "params = {'objective': 'rank:ndcg', \n",
- " 'learning_rate': 0.1, \n",
- " 'gamma': 1.0, 'min_child_weight': 0.1,\n",
+ "params = {'objective': 'rank:ndcg',\n",
+ " 'learning_rate': 0.1,\n",
+ " 'gamma': 1.0, \n",
+ " 'min_child_weight': 0.1,\n",
" 'max_depth': 6,\n",
- " 'verbose': 2,\n",
- " 'random_state': 42 \n",
+ " 'random_state': 42\n",
" }\n",
"\n",
- "BaseLTR_LM = fbr >> pt.pipelines.XGBoostLTR_pipeline(xgb.sklearn.XGBRanker(**params))\n",
+ "BaseLTR_LM = fbr3f >> pt.ltr.apply_learned_model(xgb.sklearn.XGBRanker(**params), form='ltr')\n",
"BaseLTR_LM.fit(train_topics, qrels, valid_topics, qrels)"
- ],
- "execution_count": 25,
- "outputs": []
+ ]
},
{
"cell_type": "markdown",
"metadata": {
- "id": "HVXoNhzSP-k2",
- "colab_type": "text"
+ "id": "HVXoNhzSP-k2"
},
"source": [
"And evaluate the results."
@@ -905,26 +862,17 @@
},
{
"cell_type": "code",
+ "execution_count": 20,
"metadata": {
- "id": "Dn56DKZMTQ_m",
- "colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
- "height": 111
+ "height": 112
},
- "outputId": "6688d85e-8599-4f11-db18-231abd0d7aee"
+ "id": "Dn56DKZMTQ_m",
+ "outputId": "133260ca-e979-4006-9120-5339682331e0"
},
- "source": [
- "allresultsLM = pt.pipelines.Experiment([PL2, BaseLTR_LM],\n",
- " test_topics, \n",
- " qrels, [\"map\"], \n",
- " names=[\"PL2 Baseline\", \"LambdaMART\"])\n",
- "allresultsLM"
- ],
- "execution_count": 26,
"outputs": [
{
- "output_type": "execute_result",
"data": {
"text/html": [
"
\n",
@@ -958,7 +906,7 @@
"
\n",
" 1 | \n",
" LambdaMART | \n",
- " 0.204391 | \n",
+ " 0.210969 | \n",
"
\n",
" \n",
"\n",
@@ -967,15 +915,51 @@
"text/plain": [
" name map\n",
"0 PL2 Baseline 0.206031\n",
- "1 LambdaMART 0.204391"
+ "1 LambdaMART 0.210969"
]
},
- "metadata": {
- "tags": []
- },
- "execution_count": 26
+ "execution_count": 20,
+ "metadata": {},
+ "output_type": "execute_result"
}
+ ],
+ "source": [
+ "allresultsLM = pt.Experiment([PL2, BaseLTR_LM],\n",
+ " test_topics,\n",
+ " qrels, [\"map\"],\n",
+ " names=[\"PL2 Baseline\", \"LambdaMART\"])\n",
+ "allresultsLM"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Excellent, event on this small dataset, adding a few more features and LambdaMART can enhance effectiveness!"
]
}
- ]
-}
\ No newline at end of file
+ ],
+ "metadata": {
+ "colab": {
+ "provenance": []
+ },
+ "kernelspec": {
+ "display_name": "Python 3",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.9.13"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 0
+}
diff --git a/pyterrier/__init__.py b/pyterrier/__init__.py
index 98d5ff7d..8198b684 100644
--- a/pyterrier/__init__.py
+++ b/pyterrier/__init__.py
@@ -1,4 +1,4 @@
-__version__ = "0.10.0"
+__version__ = "0.10.1"
import os
diff --git a/pyterrier/apply_base.py b/pyterrier/apply_base.py
index 0290fdaf..bdd69449 100644
--- a/pyterrier/apply_base.py
+++ b/pyterrier/apply_base.py
@@ -210,11 +210,18 @@ def transform(self, inputRes):
outputRes = push_queries(inputRes.copy(), inplace=True, keep_original=True)
else:
outputRes = inputRes.copy()
- if self.verbose:
- tqdm.pandas(desc="pt.apply.query", unit="d")
- outputRes["query"] = outputRes.progress_apply(fn, axis=1)
- else:
- outputRes["query"] = outputRes.apply(fn, axis=1)
+ try:
+ if self.verbose:
+ tqdm.pandas(desc="pt.apply.query", unit="d")
+ outputRes["query"] = outputRes.progress_apply(fn, axis=1)
+ else:
+ outputRes["query"] = outputRes.apply(fn, axis=1)
+ except ValueError as ve:
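+ # pandas raises "Columns must be same length as key" when the apply function returns a sequence rather than a single string per row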
+ msg = str(ve)
+ if "Columns must be same length as key" in msg:
+ raise TypeError("Could not coerce return from pt.apply.query function into a list of strings. Check your function returns a string.") from ve
+ else:
+ raise ve
return outputRes
class ApplyGenericTransformer(ApplyTransformerBase):
diff --git a/pyterrier/batchretrieve.py b/pyterrier/batchretrieve.py
index 3c59ca9d..357a3116 100644
--- a/pyterrier/batchretrieve.py
+++ b/pyterrier/batchretrieve.py
@@ -553,7 +553,10 @@ class TextScorer(TextIndexProcessor):
takes(str): configuration - what is needed as input: `"queries"`, or `"docs"`. Default is `"docs"` since v0.8.
returns(str): configuration - what is needed as output: `"queries"`, or `"docs"`. Default is `"docs"`.
body_attr(str): what dataframe input column contains the text of the document. Default is `"body"`.
- wmodel(str): example of configuration passed to BatchRetrieve.
+ wmodel(str): name of the weighting model to use for scoring.
+ background_index(index_like): An optional background index to use for term and collection statistics. If a weighting
+ model such as BM25 or TF_IDF or PL2 is used without setting the background_index, the background statistics
+ will be calculated from the dataframe, which is usually not the desired behaviour.
Example::
@@ -562,9 +565,21 @@ class TextScorer(TextIndexProcessor):
["q1", "chemical reactions", "d1", "professor protor poured the chemicals"],
["q1", "chemical reactions", "d2", "chemical brothers turned up the beats"],
], columns=["qid", "query", "text"])
- textscorer = pt.TextScorer(takes="docs", body_attr="text", wmodel="TF_IDF")
+ textscorer = pt.TextScorer(takes="docs", body_attr="text", wmodel="Tf")
rtr = textscorer.transform(df)
- #rtr will score each document for the query "chemical reactions" based on the provided document contents
+ #rtr will score each document by term frequency for the query "chemical reactions" based on the provided document contents
+
+ Example::
+
+ df = pd.DataFrame(
+ [
+ ["q1", "chemical reactions", "d1", "professor protor poured the chemicals"],
+ ["q1", "chemical reactions", "d2", "chemical brothers turned up the beats"],
+ ], columns=["qid", "query", "text"])
+ existing_index = pt.IndexFactory.of(...)
+ textscorer = pt.TextScorer(takes="docs", body_attr="text", wmodel="TF_IDF", background_index=existing_index)
+ rtr = textscorer.transform(df)
+ #rtr will score each document by TF_IDF for the query "chemical reactions" based on the provided document contents
"""
def __init__(self, takes="docs", **kwargs):
@@ -606,6 +621,12 @@ def __init__(self, index_location, features, controls=None, properties=None, thr
self.wmodel = kwargs["wmodel"]
if "wmodel" in controls:
self.wmodel = controls["wmodel"]
+
+ # check for terrier-core#246 bug when using FatFull
+ if self.wmodel is not None:
+ from . import check_version
+ assert check_version(5.9), "Terrier 5.9 is required for this functionality, see https://github.com/terrier-org/terrier-core/pull/246"
+
if threads > 1:
raise ValueError("Multi-threaded retrieval not yet supported by FeaturesBatchRetrieve")
@@ -657,7 +678,7 @@ def transform(self, queries):
Performs the retrieval with multiple features
Args:
- queries: String for a single query, list of queries, or a pandas.Dataframe with columns=['qid', 'query']. For re-ranking,
+ queries: A pandas.Dataframe with columns=['qid', 'query']. For re-ranking,
the DataFrame may also have a 'docid' and or 'docno' column.
Returns:
@@ -846,4 +867,4 @@ def push_fbr_earlier(_br1, _fbr):
global rewrites_setup
rewrites_setup = True
-setup_rewrites()
\ No newline at end of file
+setup_rewrites()
diff --git a/pyterrier/bootstrap.py b/pyterrier/bootstrap.py
index a606ef71..99be0550 100644
--- a/pyterrier/bootstrap.py
+++ b/pyterrier/bootstrap.py
@@ -46,7 +46,7 @@ def _load_into_memory(index, structures=['lexicon', 'direct', 'inverted', 'meta'
},
'inverted' : {
'org.terrier.structures.bit.BitPostingIndex' : {
- 'index.direct.data-source' : 'fileinmem'}
+ 'index.inverted.data-source' : 'fileinmem'}
},
}
if "direct" in structures:
@@ -271,6 +271,60 @@ def _index_add(self, other):
raise ValueError("Cannot document-wise merge indices with and without positions (%r vs %r)" % (blocks_1, blocks_2))
multiindex_cls = autoclass("org.terrier.realtime.multi.MultiIndex")
return multiindex_cls([self, other], blocks_1, fields_1 > 0)
+
+ def _index_corpusiter(self, return_toks=True):
+ def _index_corpusiter_meta(self):
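+ # iterate over the metaindex input stream, yielding one dict of metadata values (e.g. docno) per document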
+ meta_inputstream = self.getIndexStructureInputStream("meta")
+ keys = self.getMetaIndex().getKeys()
+ keys_offset = { k: offset for offset, k in enumerate(keys) }
+ while meta_inputstream.hasNext():
+ item = meta_inputstream.next()
+ yield {k : item[keys_offset[k]] for k in keys_offset}
+
+ def _index_corpusiter_direct_pretok(self):
+ import sys
+ MIN_PYTHON = (3, 8)
+ if sys.version_info < MIN_PYTHON:
+ raise NotImplementedError("Sorry, Python 3.8+ is required for this functionality")
+
+ meta_inputstream = self.getIndexStructureInputStream("meta")
+ keys = self.getMetaIndex().getKeys()
+ keys_offset = { k: offset for offset, k in enumerate(keys) }
+ keys_offset = {'docno' : keys_offset['docno'] }
+ direct_inputstream = self.getIndexStructureInputStream("direct")
+ lex = self.getLexicon()
+
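+ # getEntriesSkipped() reports how many empty documents the direct index stream passed over, so empty token dicts can be emitted for them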
+ ip = None
+ while (ip := direct_inputstream.getNextPostings()) is not None: # this is the next() method
+
+ # yield empty toks dicts for empty documents
+ for skipped in range(0, direct_inputstream.getEntriesSkipped()):
+ meta = meta_inputstream.next()
+ rtr = {k : meta[keys_offset[k]] for k in keys_offset}
+ rtr['toks'] = {}
+ yield rtr
+
+ toks = {}
+ while ip.next() != ip.EOL:
+ t, _ = lex[ip.getId()]
+ toks[t] = ip.getFrequency()
+ meta = meta_inputstream.next()
+ rtr = {'toks' : toks}
+ rtr.update({k : meta[keys_offset[k]] for k in keys_offset})
+ yield rtr
+
+ # yield for trailing empty documents
+ for skipped in range(0, direct_inputstream.getEntriesSkipped()):
+ meta = meta_inputstream.next()
+ rtr = {k : meta[keys_offset[k]] for k in keys_offset}
+ rtr['toks'] = {}
+ yield rtr
+
+ if return_toks:
+ if not self.hasIndexStructureInputStream("direct"):
+ raise ValueError("No direct index input stream available, cannot use return_toks=True")
+ return _index_corpusiter_direct_pretok(self)
+ return _index_corpusiter_meta(self)
protocol_map["org.terrier.structures.Index"] = {
# this means that len(index) returns the number of documents in the index
@@ -278,7 +332,10 @@ def _index_add(self, other):
# document-wise composition of indices: adding more documents to an index, by merging two indices with
# different numbers of documents. This implemented by the overloading the `+` Python operator
- '__add__': _index_add
+ '__add__': _index_add,
+
+ # get_corpus_iter returns a generator that yields dicts such as {"docno": "d1", "toks" : {'a' : 1}}
+ 'get_corpus_iter' : _index_corpusiter
}
def setup_terrier(file_path, terrier_version=None, helper_version=None, boot_packages=[], force_download=True):
diff --git a/pyterrier/datasets.py b/pyterrier/datasets.py
index 87aa3aeb..247b1825 100644
--- a/pyterrier/datasets.py
+++ b/pyterrier/datasets.py
@@ -644,21 +644,21 @@ def msmarco_document_generate(dataset):
MSMARCO_DOC_FILES = {
"corpus" :
- [("msmarco-docs.trec.gz", "https://msmarco.blob.core.windows.net/msmarcoranking/msmarco-docs.trec.gz")],
+ [("msmarco-docs.trec.gz", "https://msmarco.z22.web.core.windows.net/msmarcoranking/msmarco-docs.trec.gz")],
"corpus-tsv":
- [("msmarco-docs.tsv.gz", "https://msmarco.blob.core.windows.net/msmarcoranking/msmarco-docs.tsv.gz")],
+ [("msmarco-docs.tsv.gz", "https://msmarco.z22.web.core.windows.net/msmarcoranking/msmarco-docs.tsv.gz")],
"topics" :
{
- "train" : ("msmarco-doctrain-queries.tsv.gz", "https://msmarco.blob.core.windows.net/msmarcoranking/msmarco-doctrain-queries.tsv.gz", "singleline"),
- "dev" : ("msmarco-docdev-queries.tsv.gz", "https://msmarco.blob.core.windows.net/msmarcoranking/msmarco-docdev-queries.tsv.gz", "singleline"),
- "test" : ("msmarco-test2019-queries.tsv.gz", "https://msmarco.blob.core.windows.net/msmarcoranking/msmarco-test2019-queries.tsv.gz", "singleline"),
- "test-2020" : ("msmarco-test2020-queries.tsv.gz" , "https://msmarco.blob.core.windows.net/msmarcoranking/msmarco-test2020-queries.tsv.gz", "singleline"),
- 'leaderboard-2020' : ("docleaderboard-queries.tsv.gz" , "https://msmarco.blob.core.windows.net/msmarcoranking/docleaderboard-queries.tsv.gz", "singleline")
+ "train" : ("msmarco-doctrain-queries.tsv.gz", "https://msmarco.z22.web.core.windows.net/msmarcoranking/msmarco-doctrain-queries.tsv.gz", "singleline"),
+ "dev" : ("msmarco-docdev-queries.tsv.gz", "https://msmarco.z22.web.core.windows.net/msmarcoranking/msmarco-docdev-queries.tsv.gz", "singleline"),
+ "test" : ("msmarco-test2019-queries.tsv.gz", "https://msmarco.z22.web.core.windows.net/msmarcoranking/msmarco-test2019-queries.tsv.gz", "singleline"),
+ "test-2020" : ("msmarco-test2020-queries.tsv.gz" , "https://msmarco.z22.web.core.windows.net/msmarcoranking/msmarco-test2020-queries.tsv.gz", "singleline"),
+ 'leaderboard-2020' : ("docleaderboard-queries.tsv.gz" , "https://msmarco.z22.web.core.windows.net/msmarcoranking/docleaderboard-queries.tsv.gz", "singleline")
},
"qrels" :
{
- "train" : ("msmarco-doctrain-qrels.tsv.gz", "https://msmarco.blob.core.windows.net/msmarcoranking/msmarco-doctrain-qrels.tsv.gz"),
- "dev" : ("msmarco-docdev-qrels.tsv.gz", "https://msmarco.blob.core.windows.net/msmarcoranking/msmarco-docdev-qrels.tsv.gz"),
+ "train" : ("msmarco-doctrain-qrels.tsv.gz", "https://msmarco.z22.web.core.windows.net/msmarcoranking/msmarco-doctrain-qrels.tsv.gz"),
+ "dev" : ("msmarco-docdev-qrels.tsv.gz", "https://msmarco.z22.web.core.windows.net/msmarcoranking/msmarco-docdev-qrels.tsv.gz"),
"test" : ("2019qrels-docs.txt", "https://trec.nist.gov/data/deep/2019qrels-docs.txt"),
"test-2020" : ("2020qrels-docs.txt", "https://trec.nist.gov/data/deep/2020qrels-docs.txt")
},
@@ -685,18 +685,18 @@ def msmarco_document_generate(dataset):
"dev.small" : ("queries.dev.small.tsv", "collectionandqueries.tar.gz#queries.dev.small.tsv", "singleline"),
"eval" : ("queries.eval.tsv", "queries.tar.gz#queries.eval.tsv", "singleline"),
"eval.small" : ("queries.eval.small.tsv", "collectionandqueries.tar.gz#queries.eval.small.tsv", "singleline"),
- "test-2019" : ("msmarco-test2019-queries.tsv.gz", "https://msmarco.blob.core.windows.net/msmarcoranking/msmarco-test2019-queries.tsv.gz", "singleline"),
- "test-2020" : ("msmarco-test2020-queries.tsv.gz", "https://msmarco.blob.core.windows.net/msmarcoranking/msmarco-test2020-queries.tsv.gz", "singleline")
+ "test-2019" : ("msmarco-test2019-queries.tsv.gz", "https://msmarco.z22.web.core.windows.net/msmarcoranking/msmarco-test2019-queries.tsv.gz", "singleline"),
+ "test-2020" : ("msmarco-test2020-queries.tsv.gz", "https://msmarco.z22.web.core.windows.net/msmarcoranking/msmarco-test2020-queries.tsv.gz", "singleline")
},
"tars" : {
- "queries.tar.gz" : ("queries.tar.gz", "https://msmarco.blob.core.windows.net/msmarcoranking/queries.tar.gz"),
- "collection.tar.gz" : ("collection.tar.gz", "https://msmarco.blob.core.windows.net/msmarcoranking/collection.tar.gz"),
- "collectionandqueries.tar.gz" : ("collectionandqueries.tar.gz", "https://msmarco.blob.core.windows.net/msmarcoranking/collectionandqueries.tar.gz")
+ "queries.tar.gz" : ("queries.tar.gz", "https://msmarco.z22.web.core.windows.net/msmarcoranking/queries.tar.gz"),
+ "collection.tar.gz" : ("collection.tar.gz", "https://msmarco.z22.web.core.windows.net/msmarcoranking/collection.tar.gz"),
+ "collectionandqueries.tar.gz" : ("collectionandqueries.tar.gz", "https://msmarco.z22.web.core.windows.net/msmarcoranking/collectionandqueries.tar.gz")
},
"qrels" :
{
- "train" : ("qrels.train.tsv", "https://msmarco.blob.core.windows.net/msmarcoranking/qrels.train.tsv"),
- "dev" : ("qrels.dev.tsv", "https://msmarco.blob.core.windows.net/msmarcoranking/qrels.dev.tsv"),
+ "train" : ("qrels.train.tsv", "https://msmarco.z22.web.core.windows.net/msmarcoranking/qrels.train.tsv"),
+ "dev" : ("qrels.dev.tsv", "https://msmarco.z22.web.core.windows.net/msmarcoranking/qrels.dev.tsv"),
"test-2019" : ("2019qrels-docs.txt", "https://trec.nist.gov/data/deep/2019qrels-pass.txt"),
"test-2020" : ("2020qrels-docs.txt", "https://trec.nist.gov/data/deep/2020qrels-pass.txt"),
"dev.small" : ("qrels.dev.small.tsv", "collectionandqueries.tar.gz#qrels.dev.small.tsv"),
@@ -709,19 +709,19 @@ def msmarco_document_generate(dataset):
MSMARCOv2_DOC_FILES = {
"info_url" : "https://microsoft.github.io/msmarco/TREC-Deep-Learning.html",
"topics" : {
- "train" : ("docv2_train_queries.tsv", "https://msmarco.blob.core.windows.net/msmarcoranking/docv2_train_queries.tsv", "singleline"),
- "dev1" :("docv2_dev_queries.tsv", "https://msmarco.blob.core.windows.net/msmarcoranking/docv2_dev_queries.tsv", "singleline"),
- "dev2" :("docv2_dev2_queries.tsv", "https://msmarco.blob.core.windows.net/msmarcoranking/docv2_dev2_queries.tsv", "singleline"),
- "valid1" : ("msmarco-test2019-queries.tsv.gz" , "https://msmarco.blob.core.windows.net/msmarcoranking/msmarco-test2019-queries.tsv.gz", "singleline"),
- "valid2" : ("msmarco-test2020-queries.tsv.gz" , "https://msmarco.blob.core.windows.net/msmarcoranking/msmarco-test2020-queries.tsv.gz", "singleline"),
- "trec_2021" : ("2021_queries.tsv" , "https://msmarco.blob.core.windows.net/msmarcoranking/2021_queries.tsv", "singleline"),
+ "train" : ("docv2_train_queries.tsv", "https://msmarco.z22.web.core.windows.net/msmarcoranking/docv2_train_queries.tsv", "singleline"),
+ "dev1" :("docv2_dev_queries.tsv", "https://msmarco.z22.web.core.windows.net/msmarcoranking/docv2_dev_queries.tsv", "singleline"),
+ "dev2" :("docv2_dev2_queries.tsv", "https://msmarco.z22.web.core.windows.net/msmarcoranking/docv2_dev2_queries.tsv", "singleline"),
+ "valid1" : ("msmarco-test2019-queries.tsv.gz" , "https://msmarco.z22.web.core.windows.net/msmarcoranking/msmarco-test2019-queries.tsv.gz", "singleline"),
+ "valid2" : ("msmarco-test2020-queries.tsv.gz" , "https://msmarco.z22.web.core.windows.net/msmarcoranking/msmarco-test2020-queries.tsv.gz", "singleline"),
+ "trec_2021" : ("2021_queries.tsv" , "https://msmarco.z22.web.core.windows.net/msmarcoranking/2021_queries.tsv", "singleline"),
},
"qrels" : {
- "train" : ("docv2_train_qrels.tsv", "https://msmarco.blob.core.windows.net/msmarcoranking/docv2_train_qrels.tsv"),
- "dev1" :("docv2_dev_qrels.tsv", "https://msmarco.blob.core.windows.net/msmarcoranking/docv2_dev_qrels.tsv"),
- "dev2" :("docv2_dev2_qrels.tsv", "https://msmarco.blob.core.windows.net/msmarcoranking/docv2_dev2_qrels.tsv"),
- "valid1" : ("docv2_trec2019_qrels.txt.gz" , "https://msmarco.blob.core.windows.net/msmarcoranking/docv2_trec2019_qrels.txt.gz"),
- "valid2" : ("docv2_trec2020_qrels.txt.gz" , "https://msmarco.blob.core.windows.net/msmarcoranking/docv2_trec2020_qrels.txt.gz")
+ "train" : ("docv2_train_qrels.tsv", "https://msmarco.z22.web.core.windows.net/msmarcoranking/docv2_train_qrels.tsv"),
+ "dev1" :("docv2_dev_qrels.tsv", "https://msmarco.z22.web.core.windows.net/msmarcoranking/docv2_dev_qrels.tsv"),
+ "dev2" :("docv2_dev2_qrels.tsv", "https://msmarco.z22.web.core.windows.net/msmarcoranking/docv2_dev2_qrels.tsv"),
+ "valid1" : ("docv2_trec2019_qrels.txt.gz" , "https://msmarco.z22.web.core.windows.net/msmarcoranking/docv2_trec2019_qrels.txt.gz"),
+ "valid2" : ("docv2_trec2020_qrels.txt.gz" , "https://msmarco.z22.web.core.windows.net/msmarcoranking/docv2_trec2020_qrels.txt.gz")
},
"index" : _datarepo_index,
}
@@ -729,15 +729,15 @@ def msmarco_document_generate(dataset):
MSMARCOv2_PASSAGE_FILES = {
"info_url" : "https://microsoft.github.io/msmarco/TREC-Deep-Learning.html",
"topics" : {
- "train" : ("passv2_train_queries.tsv", "https://msmarco.blob.core.windows.net/msmarcoranking/passv2_train_queries.tsv", "singleline"),
- "dev1" : ("passv2_dev_queries.tsv", "https://msmarco.blob.core.windows.net/msmarcoranking/passv2_dev_queries.tsv", "singleline"),
- "dev2" : ("passv2_dev2_queries.tsv", "https://msmarco.blob.core.windows.net/msmarcoranking/passv2_dev2_queries.tsv", "singleline"),
- "trec_2021" : ("2021_queries.tsv" , "https://msmarco.blob.core.windows.net/msmarcoranking/2021_queries.tsv", "singleline"),
+ "train" : ("passv2_train_queries.tsv", "https://msmarco.z22.web.core.windows.net/msmarcoranking/passv2_train_queries.tsv", "singleline"),
+ "dev1" : ("passv2_dev_queries.tsv", "https://msmarco.z22.web.core.windows.net/msmarcoranking/passv2_dev_queries.tsv", "singleline"),
+ "dev2" : ("passv2_dev2_queries.tsv", "https://msmarco.z22.web.core.windows.net/msmarcoranking/passv2_dev2_queries.tsv", "singleline"),
+ "trec_2021" : ("2021_queries.tsv" , "https://msmarco.z22.web.core.windows.net/msmarcoranking/2021_queries.tsv", "singleline"),
},
"qrels" : {
- "train" : ("passv2_train_qrels.tsv" "https://msmarco.blob.core.windows.net/msmarcoranking/passv2_train_qrels.tsv"),
- "dev1" : ("passv2_dev_qrels.tsv", "https://msmarco.blob.core.windows.net/msmarcoranking/passv2_dev_qrels.tsv"),
- "dev2" : ("passv2_dev2_qrels.tsv", "https://msmarco.blob.core.windows.net/msmarcoranking/passv2_dev2_qrels.tsv"),
+ "train" : ("passv2_train_qrels.tsv" "https://msmarco.z22.web.core.windows.net/msmarcoranking/passv2_train_qrels.tsv"),
+ "dev1" : ("passv2_dev_qrels.tsv", "https://msmarco.z22.web.core.windows.net/msmarcoranking/passv2_dev_qrels.tsv"),
+ "dev2" : ("passv2_dev2_qrels.tsv", "https://msmarco.z22.web.core.windows.net/msmarcoranking/passv2_dev2_qrels.tsv"),
},
"index" : _datarepo_index,
}
diff --git a/pyterrier/io.py b/pyterrier/io.py
index 9673b56f..3d220e25 100644
--- a/pyterrier/io.py
+++ b/pyterrier/io.py
@@ -228,12 +228,8 @@ def _parse_line(l):
def _read_results_trec(filename):
results = []
- df = pd.read_csv(filename, sep=r'\s+', names=["qid", "iter", "docno", "rank", "score", "name"])
+ df = pd.read_csv(filename, sep=r'\s+', names=["qid", "iter", "docno", "rank", "score", "name"], dtype={'qid': str, 'docno': str, 'rank': int, 'score': float})
df = df.drop(columns="iter")
- df["qid"] = df["qid"].astype(str)
- df["docno"] = df["docno"].astype(str)
- df["rank"] = df["rank"].astype(int)
- df["score"] = df["score"].astype(float)
return df
def write_results(res, filename, format="trec", append=False, **kwargs):
@@ -294,13 +290,13 @@ def read_topics(filename, format="trec", **kwargs):
Supported Formats:
* "trec" -- an SGML-formatted TREC topics file. Delimited by TOP tags, each having NUM and TITLE tags; DESC and NARR tags are skipped by default. Control using whitelist and blacklist kwargs
- * "trecxml" -- a more modern XML formatted topics file. Delimited by topic tags, each having nunber tags. query, question and narrative tags are parsed by default. Control using tags kwarg.
+ * "trecxml" -- a more modern XML formatted topics file. Delimited by topic tags, each having number tags. query, question and narrative tags are parsed by default. Control using tags kwarg.
* "singeline" -- one query per line, preceeded by a space or colon. Tokenised by default, use tokenise=False kwargs to prevent tokenisation.
"""
if format is None:
format = "trec"
if not format in SUPPORTED_TOPICS_FORMATS:
- raise ValueError("Format %s not known, supported types are %s" % (format, str(SUPPORTED_RESULTS_FORMATS.keys())))
+ raise ValueError("Format %s not known, supported types are %s" % (format, str(SUPPORTED_TOPICS_FORMATS.keys())))
return SUPPORTED_TOPICS_FORMATS[format](filename, **kwargs)
def _read_topics_trec(file_path, doc_tag="TOP", id_tag="NUM", whitelist=["TITLE"], blacklist=["DESC","NARR"]):
@@ -339,7 +335,10 @@ def _read_topics_trecxml(filename, tags=["query", "question", "narrative"], toke
from jnius import autoclass
tokeniser = autoclass("org.terrier.indexing.tokenisation.Tokeniser").getTokeniser()
for child in root.iter('topic'):
- qid = child.attrib["number"]
+ try:
+ qid = child.attrib["number"]
+ except KeyError:
+ qid = child.find("number").text
query = ""
for tag in child:
if tag.tag in tags:
@@ -347,7 +346,7 @@ def _read_topics_trecxml(filename, tags=["query", "question", "narrative"], toke
if tokenise:
query_text = " ".join(tokeniser.getTokens(query_text))
query += " " + query_text
- topics.append((str(qid), query))
+ topics.append((str(qid), query.strip()))
return pd.DataFrame(topics, columns=["qid", "query"])
def _read_topics_singleline(filepath, tokenise=True):
diff --git a/pyterrier/pipelines.py b/pyterrier/pipelines.py
index 345812c8..8188ad08 100644
--- a/pyterrier/pipelines.py
+++ b/pyterrier/pipelines.py
@@ -561,8 +561,11 @@ def _apply_round(measure, value):
for pcol in p_col_names:
pcol_reject = pcol.replace("p-value", "reject")
pcol_corrected = pcol + " corrected"
- reject, corrected, _, _ = statsmodels.stats.multitest.multipletests(df[pcol], alpha=correction_alpha, method=correction)
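+ # exclude the baseline row from the correction procedure, as its p-value (the baseline compared with itself) is not meaningful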
+ reject, corrected, _, _ = statsmodels.stats.multitest.multipletests(df[pcol].drop(df.index[baseline]), alpha=correction_alpha, method=correction)
insert_pos = df.columns.get_loc(pcol)
+ # add reject/corrected values for the baseline
+ reject = np.insert(reject, baseline, False)
+ corrected = np.insert(corrected, baseline, np.nan)
# add extra columns, put place directly after the p-value column
df.insert(insert_pos+1, pcol_reject, reject)
df.insert(insert_pos+2, pcol_corrected, corrected)
diff --git a/pyterrier/rewrite.py b/pyterrier/rewrite.py
index c2f65e44..46e313b7 100644
--- a/pyterrier/rewrite.py
+++ b/pyterrier/rewrite.py
@@ -205,6 +205,7 @@ def __init__(self, index_like, fb_terms=10, fb_docs=3, qeclass="org.terrier.quer
else:
self.qe = qeclass
self.indexref = _parse_index_like(index_like)
+ self.properties = properties
for k,v in properties.items():
pt.ApplicationSetup.setProperty(k, str(v))
self.applytp = pt.autoclass("org.terrier.querying.ApplyTermPipeline")()
@@ -212,6 +213,34 @@ def __init__(self, index_like, fb_terms=10, fb_docs=3, qeclass="org.terrier.quer
self.fb_docs = fb_docs
self.manager = pt.autoclass("org.terrier.querying.ManagerFactory")._from_(self.indexref)
+ def __reduce__(self):
+ return (
+ self.__class__,
+ (self.indexref,),
+ self.__getstate__()
+ )
+
+ def __getstate__(self):
+ if isinstance(self.qe, str):
+ qe = self.qe
+ else:
+ qe = self.qe.getClass().getName()
+ return {
+ 'fb_terms' : self.fb_terms,
+ 'fb_docs' : self.fb_docs,
+ 'qeclass' : qe,
+ 'properties' : self.properties
+ }
+
+ def __setstate__(self, d):
+ self.fb_terms = d["fb_terms"]
+ self.fb_docs = d["fb_docs"]
+ self.qe = pt.autoclass(d['qeclass'])()
+ self.properties.update(d["properties"])
+ for key,value in d["properties"].items():
+ self.appSetup.setProperty(key, str(value))
+ self.manager = pt.autoclass("org.terrier.querying.ManagerFactory")._from_(self.indexref)
+
def _populate_resultset(self, topics_and_res, qid, index):
docids=None
@@ -387,6 +416,15 @@ def __init__(self, *args, fb_terms=10, fb_docs=3, fb_lambda=0.6, **kwargs):
kwargs["qeclass"] = rm
super().__init__(*args, fb_terms=fb_terms, fb_docs=fb_docs, **kwargs)
+ def __getstate__(self):
+ rtr = super().__getstate__()
+ rtr['fb_lambda'] = self.fb_lambda
+ return rtr
+
+ def __setstate__(self, d):
+ super().__setstate__(d)
+ self.fb_lambda = d["fb_lambda"]
+
def _configure_request(self, rq):
super()._configure_request(rq)
rq.setControl("rm3.lambda", str(self.fb_lambda))
diff --git a/pyterrier/text.py b/pyterrier/text.py
index a3ac6f87..51a2dbce 100644
--- a/pyterrier/text.py
+++ b/pyterrier/text.py
@@ -135,6 +135,13 @@ def scorer(*args, **kwargs) -> Transformer:
This is an alias to pt.TextScorer(). Internally, a Terrier memory index is created, before being
used for scoring.
+ Arguments:
+ body_attr(str): what dataframe input column contains the text of the document. Default is `"body"`.
+ wmodel(str): name of the weighting model to use for scoring.
+ background_index(index_like): An optional background index to use for collection statistics. If a weighting
+ model such as BM25 or TF_IDF or PL2 is used without setting the background_index, the background statistics
+ will be calculated from the dataframe, which is usually not the desired behaviour.
+
Example::
df = pd.DataFrame(
@@ -149,8 +156,9 @@ def scorer(*args, **kwargs) -> Transformer:
# ["q1", "chemical reactions", "d1", "professor protor poured the chemicals", 0, 1]
# ["q1", "chemical reactions", "d2", "chemical brothers turned up the beats", 0, 1]
- For calculating the scores of documents using any weighting model with the concept of IDF, it may be useful to make use of
- an existing Terrier index for background statistics::
+ For calculating the scores of documents using any weighting model with the concept of IDF, it is strongly advised to make use of
+ an existing Terrier index for background statistics. Without a background index, IDF will be calculated based on the supplied
+ dataframe (for models such as BM25, this can lead to negative scores)::
textscorerTfIdf = pt.text.scorer(body_attr="text", wmodel="TF_IDF", background_index=index)
@@ -512,8 +520,8 @@ def applyPassaging(self, df, labels=True):
newRows.append(newRow)
passageCount+=1
newDF = pd.DataFrame(newRows)
- newDF['query'].fillna('',inplace=True)
- newDF[self.text_attr].fillna('',inplace=True)
- newDF['qid'].fillna('',inplace=True)
+ newDF['query'] = newDF['query'].fillna('')
+ newDF[self.text_attr] = newDF[self.text_attr].fillna('')
+ newDF['qid'] = newDF['qid'].fillna('')
newDF.reset_index(inplace=True,drop=True)
return newDF
diff --git a/pyterrier/transformer.py b/pyterrier/transformer.py
index 38978227..7d1c55e3 100644
--- a/pyterrier/transformer.py
+++ b/pyterrier/transformer.py
@@ -39,7 +39,7 @@ def get_transformer(v, stacklevel=1):
if isinstance(v, pd.DataFrame):
warn('Coercion of a dataframe into a transformer is deprecated; use a pt.Transformer.from_df() instead', stacklevel=stacklevel, category=DeprecationWarning)
return SourceTransformer(v)
- raise ValueError("Passed parameter %s of type %s cannot be coerced into a transformer" % (str(v), type(v)), stacklevel=stacklevel, category=DeprecationWarning)
+ raise ValueError("Passed parameter %s of type %s cannot be coerced into a transformer" % (str(v), type(v)))
rewrite_rules = []
@@ -281,8 +281,10 @@ def __init__(self, *args, **kwargs):
class Indexer(Transformer):
def index(self, iter : Iterable[dict], **kwargs):
"""
- Takes an iterable of dictionaries ("iterdict"), and consumes them. There is no return;
- This method is typically used to implement indexers.
+ Takes an iterable of dictionaries ("iterdict"), and consumes them. The index method may return
+ an instance of the index or retriever. This method is typically used to implement indexers that
+ consume a corpus (or to consume the output of previous pipeline components that have
+ transformed the documents being consumed).
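+
+ Example (a minimal sketch; the indexer class and index location shown are illustrative)::
+
+     indexer = pt.IterDictIndexer("/path/to/index")
+     indexref = indexer.index([{"docno" : "d1", "text" : "some sample text"}])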
"""
pass
@@ -368,4 +370,4 @@ def __init__(self, rtr, **kwargs):
def transform(self, topics):
rtr = self.rtr.copy()
- return rtr
\ No newline at end of file
+ return rtr
diff --git a/requirements-test.txt b/requirements-test.txt
index a2cc5b7a..9e41c8e0 100644
--- a/requirements-test.txt
+++ b/requirements-test.txt
@@ -4,3 +4,4 @@ fastrank>=0.7.0
torch
lz4
transformers
+scikit-learn
diff --git a/requirements.txt b/requirements.txt
index 55baa50f..6106ceb7 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -4,7 +4,6 @@ wget
tqdm
pyjnius>=1.4.2
matchpy
-scikit-learn
deprecated
chest
scipy
diff --git a/setup.py b/setup.py
index 07913e5c..be9968c1 100644
--- a/setup.py
+++ b/setup.py
@@ -53,6 +53,12 @@ def get_version(rel_path):
author="Craig Macdonald",
author_email='craigm@dcs.gla.ac.uk',
description="Terrier IR platform Python API",
+ project_urls={
+ 'Documentation': 'https://pyterrier.readthedocs.io',
+ 'Changelog': 'https://github.com/terrier-org/pyterrier/releases',
+ 'Issue Tracker': 'https://github.com/terrier-org/pyterrier/issues',
+ 'CI': 'https://github.com/terrier-org/pyterrier/actions',
+ },
long_description=long_description,
long_description_content_type="text/markdown",
package_data={'': ['LICENSE.txt', 'requirements.txt', 'requirements-test.txt']},
@@ -65,5 +71,5 @@ def get_version(rel_path):
"Operating System :: OS Independent",
],
install_requires=requirements,
- python_requires='>=3.7',
+ python_requires='>=3.8',
)
diff --git a/terrier-python-helper/pom.xml b/terrier-python-helper/pom.xml
index c1281371..c829982a 100644
--- a/terrier-python-helper/pom.xml
+++ b/terrier-python-helper/pom.xml
@@ -110,14 +110,14 @@
ch.qos.logback
logback-classic
- 1.2.0
+ 1.2.13
provided
ch.qos.logback
logback-core
- 1.2.9
+ 1.2.13
provided
diff --git a/tests/base.py b/tests/base.py
index 93594a10..501cd234 100644
--- a/tests/base.py
+++ b/tests/base.py
@@ -10,14 +10,20 @@ class BaseTestCase(unittest.TestCase):
def __init__(self, *args, **kwargs):
super(BaseTestCase, self).__init__(*args, **kwargs)
terrier_version = os.environ.get("TERRIER_VERSION", None)
- if terrier_version is not None:
- print("Testing with Terrier version " + terrier_version)
terrier_helper_version = os.environ.get("TERRIER_HELPER_VERSION", None)
- if terrier_helper_version is not None:
- print("Testing with Terrier Helper version " + terrier_helper_version)
if not pt.started():
+
+ # display for debugging what is being used
+ if terrier_version is not None:
+ print("Testing with Terrier version " + terrier_version)
+ if terrier_helper_version is not None:
+ print("Testing with Terrier Helper version " + terrier_helper_version)
+
pt.init(version=terrier_version, logging="DEBUG", helper_version=terrier_helper_version)
+ # jvm_opts=['-ea'] can be added here to enable checking of Java assertions
self.here = os.path.dirname(os.path.realpath(__file__))
+
+ # check that pt.init() is saving its arguments
assert "version" in pt.init_args
assert pt.init_args["version"] == terrier_version
@@ -42,4 +48,4 @@ def tearDown(self):
except:
pass
-
\ No newline at end of file
+
diff --git a/tests/fixtures/topics.trecxml b/tests/fixtures/topics.trecxml
new file mode 100644
index 00000000..bea592cb
--- /dev/null
+++ b/tests/fixtures/topics.trecxml
@@ -0,0 +1,20 @@
+
+
+ 1
+ lights
+ Description lights
+ Documents are relevant if they describe lights.
+
+
+ 2
+ radiowaves
+ Description radiowaves
+ Documents are relevant if they describe radiowaves.
+
+
+
+ sounds
+ Description sound
+ Documents are relevant if they describe sounds.
+
+
\ No newline at end of file
diff --git a/tests/test_apply.py b/tests/test_apply.py
index 0f6faf8f..ad834eaf 100644
--- a/tests/test_apply.py
+++ b/tests/test_apply.py
@@ -59,6 +59,14 @@ def test_query_apply(self):
rtrDR2 = pt.apply.query(lambda row : row["qid"] )(testDF2)
self.assertEqual(rtrDR2.iloc[0]["query"], "q1")
+ def test_query_apply_error(self):
+ origquery="the bear and the wolf"
+ testDF = pd.DataFrame([["q1", origquery]], columns=["qid", "query"])
+ p = pt.apply.query(lambda q : q) # should throw an error, as pt.apply.query should return a string, not a row
+ with self.assertRaises(TypeError) as te:
+ p(testDF)
+ self.assertTrue("Could not coerce return from pt.apply.query function into a list of strings" in str(te.exception))
+
def test_by_query_apply(self):
inputDf = pt.new.ranked_documents([[1], [2]], qid=["1", "2"])
def _inc_score(res):
diff --git a/tests/test_experiment.py b/tests/test_experiment.py
index a4f380f9..5498ba30 100644
--- a/tests/test_experiment.py
+++ b/tests/test_experiment.py
@@ -321,7 +321,7 @@ def test_baseline_and_tests(self):
# user-specified TOST
# TOST will omit warnings here, due to low numbers of topics
import statsmodels.stats.weightstats
- fn = lambda X,Y: (0, statsmodels.stats.weightstats.ttost_ind(X, Y, -0.01, 0.01)[0])
+ fn = lambda X,Y: (0, statsmodels.stats.weightstats.ttost_paired(X, Y, -0.01, 0.01)[0])
#This filter doesnt work
with warnings.catch_warnings(record=True) as w:
@@ -363,15 +363,17 @@ def test_baseline_corrected(self):
dataset = pt.get_dataset("vaswani")
res1 = pt.BatchRetrieve(dataset.get_index(), wmodel="BM25")(dataset.get_topics().head(10))
res2 = pt.BatchRetrieve(dataset.get_index(), wmodel="DPH")(dataset.get_topics().head(10))
- for corr in ['hs', 'bonferroni', 'holm-sidak']:
+ baseline = 0
+ for corr in ['hs', 'bonferroni', 'hommel']:
df = pt.Experiment(
[res1, res2],
dataset.get_topics().head(10),
dataset.get_qrels(),
eval_metrics=["map", "ndcg"],
- baseline=0, correction='hs')
+ baseline=baseline, correction=corr)
self.assertTrue("map +" in df.columns)
self.assertTrue("map -" in df.columns)
self.assertTrue("map p-value" in df.columns)
self.assertTrue("map p-value corrected" in df.columns)
self.assertTrue("map reject" in df.columns)
+ self.assertFalse(any(df["map p-value corrected"].drop(df.index[baseline]).isna()))
diff --git a/tests/test_fbr.py b/tests/test_fbr.py
index b25c58ee..d7c49e25 100644
--- a/tests/test_fbr.py
+++ b/tests/test_fbr.py
@@ -137,6 +137,35 @@ def test_fbr(self):
if "matching" in retrBasic.controls:
self.assertNotEqual(retrBasic.controls["matching"], "FatFeaturedScoringMatching,org.terrier.matching.daat.FatFull")
+ def test_fbr_example(self):
+ JIR = pt.autoclass('org.terrier.querying.IndexRef')
+ indexref = JIR.of(self.here + "/fixtures/index/data.properties")
+ index = pt.IndexFactory.of(indexref)
+ # this ranker will make the candidate set of documents for each query
+ BM25 = pt.BatchRetrieve(index, wmodel="BM25")
+
+ # these rankers we will use to re-rank the BM25 results
+ TF_IDF = pt.BatchRetrieve(index, wmodel="Dl")
+ PL2 = pt.BatchRetrieve(index, wmodel="PL2")
+
+ pipe = (BM25 %2) >> (TF_IDF ** PL2)
+ fbr = pt.FeaturesBatchRetrieve(indexref, ["WMODEL:Dl", "WMODEL:PL2"], wmodel="BM25") % 2
+ resultP = pipe.search("chemical")
+ resultF = fbr.search("chemical")
+ pd.set_option('display.max_columns', None)
+
+ self.assertEqual(resultP.iloc[0].docno, resultF.iloc[0].docno)
+ self.assertEqual(resultP.iloc[0].score, resultF.iloc[0].score)
+ self.assertEqual(resultP.iloc[0].features[0], resultF.iloc[0].features[0])
+ self.assertEqual(resultP.iloc[0].features[1], resultF.iloc[0].features[1])
+
+ pipeCompiled = pipe.compile()
+ resultC = pipeCompiled.search("chemical")
+ self.assertEqual(resultP.iloc[0].docno, resultC.iloc[0].docno)
+ self.assertEqual(resultP.iloc[0].score, resultC.iloc[0].score)
+ self.assertEqual(resultP.iloc[0].features[0], resultC.iloc[0].features[0])
+ self.assertEqual(resultP.iloc[0].features[1], resultC.iloc[0].features[1])
+
def test_fbr_empty(self):
JIR = pt.autoclass('org.terrier.querying.IndexRef')
indexref = JIR.of(self.here + "/fixtures/index/data.properties")
diff --git a/tests/test_index_op.py b/tests/test_index_op.py
index 25b0ae3e..a6f27538 100644
--- a/tests/test_index_op.py
+++ b/tests/test_index_op.py
@@ -10,6 +10,98 @@
class TestIndexOp(TempDirTestCase):
+ def test_index_corpus_iter(self):
+ import sys
+ MIN_PYTHON = (3, 8)
+ if sys.version_info < MIN_PYTHON:
+ self.skipTest("Not minimum Python requirements")
+
+ documents = [
+ {'docno' : 'd1', 'text': 'stemming stopwords stopwords'},
+ ]
+ index = pt.IndexFactory.of( pt.IterDictIndexer(tempfile.mkdtemp(), stopwords=None, stemmer=None).index(documents) )
+ self.assertEqual(1, len(index))
+ self.assertEqual(2, index.getCollectionStatistics().getNumberOfUniqueTerms())
+ self.assertEqual(3, index.getCollectionStatistics().getNumberOfTokens())
+
+ # check that get_corpus_iter() contains the correct information
+ iter = index.get_corpus_iter()
+ first_doc = next(iter)
+ self.assertTrue(first_doc is not None)
+ self.assertIn('docno', first_doc)
+ self.assertIn('toks', first_doc)
+ self.assertIn('stemming', first_doc['toks'])
+ self.assertIn('stopwords', first_doc['toks'])
+ self.assertEqual(1, first_doc['toks']['stemming'])
+ self.assertEqual(2, first_doc['toks']['stopwords'])
+ with(self.assertRaises(StopIteration)):
+ next(iter)
+
+ # now check that a static pruning pipe can operate as expected. this example comes from terrier-index-api.rst
+ index_pipe = (
+ # update the toks column for each document, keeping only terms with frequency > 1
+ pt.apply.toks(lambda row: { t : row['toks'][t] for t in row['toks'] if row['toks'][t] > 1 } )
+ >> pt.IterDictIndexer(tempfile.mkdtemp(), pretokenised=True)
+ )
+ new_index_ref = index_pipe.index( index.get_corpus_iter())
+ pruned_index = pt.IndexFactory.of(new_index_ref)
+ self.assertEqual(1, len(pruned_index))
+ self.assertEqual(1, pruned_index.getCollectionStatistics().getNumberOfUniqueTerms())
+ self.assertEqual(2, pruned_index.getCollectionStatistics().getNumberOfTokens())
+
+ def test_index_corpus_iter_empty(self):
+ import sys
+ MIN_PYTHON = (3, 8)
+ if sys.version_info < MIN_PYTHON:
+ self.skipTest("Not minimum Python requirements")
+
+ # compared to test_index_corpus_iter, this tests empty documents are handled correctly.
+ documents = [
+ {'docno' : 'd0', 'text':''},
+ {'docno' : 'd1', 'text':''},
+ {'docno' : 'd2', 'text': 'stemming stopwords stopwords'},
+ {'docno' : 'd3', 'text':''},
+ {'docno' : 'd4', 'text': 'stemming stopwords stopwords'},
+ {'docno' : 'd5', 'text': ''}
+ ]
+ index = pt.IndexFactory.of( pt.IterDictIndexer(tempfile.mkdtemp(), stopwords=None, stemmer=None).index(documents) )
+ self.assertEqual(6, len(index))
+ self.assertEqual(2, index.getCollectionStatistics().getNumberOfUniqueTerms())
+ self.assertEqual(6, index.getCollectionStatistics().getNumberOfTokens())
+
+ iter = index.get_corpus_iter()
+
+ counter = 0
+ for doc in documents:
+ next_doc = next(iter)
+ counter += 1
+ self.assertTrue(next_doc is not None)
+ self.assertIn('docno', next_doc)
+ self.assertIn('toks', next_doc)
+ if doc['text'] == '':
+ self.assertEqual(0, len(next_doc['toks']))
+ else:
+ self.assertIn('stemming', next_doc['toks'])
+ self.assertIn('stopwords', next_doc['toks'])
+ self.assertEqual(1, next_doc['toks']['stemming'])
+ self.assertEqual(2, next_doc['toks']['stopwords'])
+
+ with(self.assertRaises(StopIteration)):
+ next(iter)
+ self.assertEqual(counter, len(documents))
+
+ # now check that a static pruning pipe can operate as expected. this example comes from terrier-index-api.rst
+ index_pipe = (
+ # update the toks column for each document, keeping only terms with frequency > 1
+ pt.apply.toks(lambda row: { t : row['toks'][t] for t in row['toks'] if row['toks'][t] > 1 } )
+ >> pt.IterDictIndexer(tempfile.mkdtemp(), pretokenised=True)
+ )
+ new_index_ref = index_pipe.index( index.get_corpus_iter())
+ pruned_index = pt.IndexFactory.of(new_index_ref)
+ self.assertEqual(6, len(pruned_index))
+ self.assertEqual(1, pruned_index.getCollectionStatistics().getNumberOfUniqueTerms())
+ self.assertEqual(4, pruned_index.getCollectionStatistics().getNumberOfTokens())
+
def test_index_add_write(self):
# inspired by https://github.com/terrier-org/pyterrier/issues/390
documents = [
diff --git a/tests/test_ltr_pipelines.py b/tests/test_ltr_pipelines.py
index 9824d8bf..fdc967c7 100644
--- a/tests/test_ltr_pipelines.py
+++ b/tests/test_ltr_pipelines.py
@@ -39,7 +39,6 @@ def test_xgltr_pipeline(self):
'learning_rate': 0.1,
'gamma': 1.0, 'min_child_weight': 0.1,
'max_depth': 6,
- 'verbose': 2,
'random_state': 42
}
diff --git a/tests/test_pickle.py b/tests/test_pickle.py
index 8d1a7064..769b228d 100644
--- a/tests/test_pickle.py
+++ b/tests/test_pickle.py
@@ -98,6 +98,22 @@ def test_fbr_joblib(self):
self._fix_joblib()
self._fbr(joblib)
+ def test_qe_pickle(self):
+ self._qe(pickle)
+
+ def _qe(self, pickler):
+ vaswani = pt.datasets.get_dataset("vaswani")
+ index = vaswani.get_index()
+ bm25 = pt.BatchRetrieve(index, wmodel='BM25', controls={"c" : 0.75}, num_results=15)
+ br = bm25 >> pt.rewrite.Bo1QueryExpansion(index) >> bm25
+ q = pd.DataFrame([["q1", "chemical"]], columns=["qid", "query"])
+ res1 = br(q)
+ byterep = pickler.dumps(br)
+ br2 = pickler.loads(byterep)
+
+ res2 = br2(q)
+ pd.testing.assert_frame_equal(res1, res2)
+
def _br(self, pickler, wmodel='BM25'):
vaswani = pt.datasets.get_dataset("vaswani")
br = pt.BatchRetrieve(vaswani.get_index(), wmodel=wmodel, controls={"c" : 0.75}, num_results=15)
diff --git a/tests/test_text.py b/tests/test_text.py
index f24dd6db..12e248c0 100644
--- a/tests/test_text.py
+++ b/tests/test_text.py
@@ -41,7 +41,8 @@ def test_scorer_rerank(self):
self.assertEqual(1, dfOut.iloc[0]["rank"])
def test_snippets(self):
- br = pt.BatchRetrieve.from_dataset("vaswani", "terrier_stemmed_text", metadata=["docno", "text"])
+ br = pt.BatchRetrieve.from_dataset("vaswani", "terrier_stemmed") >> pt.text.get_text(pt.get_dataset('irds:vaswani'), "text")
+ #br = pt.BatchRetrieve.from_dataset("vaswani", "terrier_stemmed_text", metadata=["docno", "text"])
psg_scorer = (
pt.text.sliding(text_attr='text', length=25, stride=12, prepend_attr=None)
>> pt.text.scorer(body_attr="text", wmodel='Tf', takes='docs')
diff --git a/tests/test_topicsparsing.py b/tests/test_topicsparsing.py
index cedcebbb..20ecde94 100644
--- a/tests/test_topicsparsing.py
+++ b/tests/test_topicsparsing.py
@@ -1,14 +1,19 @@
-import pyterrier as pt
-import unittest
-from .base import BaseTestCase
import os
+import unittest
+
import pandas as pd
-class TestTopicsParsing(BaseTestCase):
+import pyterrier as pt
+from .base import BaseTestCase
+
+
+class TestTopicsParsing(BaseTestCase):
def testSingleLine(self):
topics = pt.io.read_topics(
- os.path.dirname(os.path.realpath(__file__)) + "/fixtures/singleline.topics", format="singleline")
+ os.path.dirname(os.path.realpath(__file__)) + "/fixtures/singleline.topics",
+ format="singleline",
+ )
self.assertEqual(2, len(topics))
self.assertTrue("qid" in topics.columns)
self.assertTrue("query" in topics.columns)
@@ -19,12 +24,29 @@ def testSingleLine(self):
def test_parse_trec_topics_file_T(self):
input = os.path.dirname(os.path.realpath(__file__)) + "/fixtures/topics.trec"
- exp_result = pd.DataFrame([["1", "light"], ["2", "radiowave"], ["3", "sound"]], columns=['qid', 'query'])
+ exp_result = pd.DataFrame(
+ [["1", "light"], ["2", "radiowave"], ["3", "sound"]],
+ columns=["qid", "query"],
+ )
result = pt.io.read_topics(input)
self.assertTrue(exp_result.equals(result))
def test_parse_trec_topics_file_D(self):
input = os.path.dirname(os.path.realpath(__file__)) + "/fixtures/topics.trec"
- exp_result = pd.DataFrame([["1", "lights"], ["2", "radiowaves"], ["3", "sounds"]], columns=['qid', 'query'])
- result = pt.io.read_topics(input, format="trec", whitelist=["DESC"], blacklist=["TITLE"])
- self.assertTrue(exp_result.equals(result))
\ No newline at end of file
+ exp_result = pd.DataFrame(
+ [["1", "lights"], ["2", "radiowaves"], ["3", "sounds"]],
+ columns=["qid", "query"],
+ )
+ result = pt.io.read_topics(
+ input, format="trec", whitelist=["DESC"], blacklist=["TITLE"]
+ )
+ self.assertTrue(exp_result.equals(result))
+
+ def test_parse_trecxml_topics_file(self):
+ input = os.path.dirname(os.path.realpath(__file__)) + "/fixtures/topics.trecxml"
+ result = pt.io.read_topics(input, format="trecxml", tags=["title"])
+ exp_result = pd.DataFrame(
+ [["1", "lights"], ["2", "radiowaves"], ["3", "sounds"]],
+ columns=["qid", "query"],
+ )
+ self.assertTrue(exp_result.equals(result))