Skip to content

Commit

Permalink
Merge pull request #303 from seanmacavaney/field_named_query
Browse files Browse the repository at this point in the history
irds fix error when a query field is named "query"
  • Loading branch information
cmacdonald authored May 5, 2022
2 parents d4cd7ef + 584c923 commit 943764f
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 0 deletions.
5 changes: 5 additions & 0 deletions pyterrier/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -431,6 +431,11 @@ def get_topics(self, variant=None, tokenise_query=True):
df.rename(columns={"query_id": "qid"}, inplace=True) # pyterrier uses "qid"

if variant is not None:
# Some datasets have a query field called "query". We need to remove it or
# we'll end up with multiple "query" columns, which will cause problems
# because many components are written assuming no columns have the same name.
if variant != 'query' and 'query' in df.columns:
df.drop(['query'], 1, inplace=True)
df.rename(columns={variant: "query"}, inplace=True) # user specified which version of the query they want
df.drop(df.columns.difference(['qid','query']), 1, inplace=True)
elif len(qcls._fields) == 2:
Expand Down
26 changes: 26 additions & 0 deletions tests/test_irds_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ def test_results(self):
self.assertEqual('who is robert gray', results.iloc[0].query)
# ensure it's terrier-tokenised (orig text is "tracheids are part of _____.")
self.assertEqual('tracheids are part of', results[results.qid=='1124210'].iloc[0].query)

def test_nonexistant(self):
# Should raise an error when you request an irds: dataset that doesn't exist
with self.assertRaises(KeyError):
Expand All @@ -79,5 +80,30 @@ def test_nonexistant(self):
with self.assertRaises(KeyError):
dataset = pt.datasets.get_dataset('bla-bla-bla')

def test_variants(self):
dataset = pt.get_dataset('irds:clueweb09/catb/trec-web-2009')

with self.subTest('all fields'):
topics = dataset.get_topics()
self.assertEqual(['qid', 'query', 'description', 'type', 'subtopics'], list(topics.columns))

with self.subTest('specific field'):
topics = dataset.get_topics('description')
self.assertEqual(['qid', 'query'], list(topics.columns)) # description mapped to query
self.assertEqual(topics.iloc[0]['query'], 'find information on president barack obama s family history including genealogy national origins places and dates of birth etc')

with self.subTest('specific field'):
topics = dataset.get_topics('description', tokenise_query=False)
self.assertEqual(['qid', 'query'], list(topics.columns)) # description mapped to query
self.assertEqual(topics.iloc[0]['query'], "Find information on President Barack Obama's family\n history, including genealogy, national origins, places and dates of\n birth, etc.\n ")

with self.subTest('field named query'):
topics = dataset.get_topics('query')
self.assertEqual(['qid', 'query'], list(topics.columns))
self.assertEqual(topics.iloc[0]['query'], 'obama family tree')

with self.assertRaises(AssertionError):
dataset.get_topics('field_that_does_not_exist')

if __name__ == '__main__':
unittest.main()

0 comments on commit 943764f

Please sign in to comment.