diff --git a/flask_restless/search/drivers.py b/flask_restless/search/drivers.py index 3fb32555..302039b7 100644 --- a/flask_restless/search/drivers.py +++ b/flask_restless/search/drivers.py @@ -130,7 +130,7 @@ def search(session, model, filters=None, sort=None, group_by=None, for field_name in group_by: if '.' in field_name: field_name, field_name_in_relation = field_name.split('.') - relation_model = get_related_model(model, field_name) + relation_model = aliased(get_related_model(model, field_name)) field = getattr(relation_model, field_name_in_relation) query = query.join(relation_model) query = query.group_by(field) diff --git a/tests/test_fetching.py b/tests/test_fetching.py index c3d0c25d..c51dcc3b 100644 --- a/tests/test_fetching.py +++ b/tests/test_fetching.py @@ -18,6 +18,7 @@ specification. """ +from itertools import product from operator import itemgetter from unittest2 import skip @@ -50,6 +51,7 @@ def setUp(self): class Person(self.Base): __tablename__ = 'person' id = Column(Integer, primary_key=True) + age = Column(Integer) name = Column(Unicode) class Article(self.Base): @@ -187,13 +189,34 @@ def test_group_by_related(self): article3.author = person2 self.session.add_all([person1, person2, article1, article2, article3]) self.session.commit() - response = self.app.get('/api/article?group=author.name') + query_string = {'group': 'author.name'} + response = self.app.get('/api/article', query_string=query_string) document = loads(response.data) articles = document['data'] author_ids = sorted(article['relationships']['author']['data']['id'] for article in articles) assert ['1', '2'] == author_ids + def test_group_by_mutiple_relationship_attributes(self): + """Tests for grouping results by multiple fields of a related model.""" + names = [u'foo', u'bar'] + ages = [10, 20] + # There are two people with each combination of name and age. + for i, (name, age) in enumerate(product(names, ages), start=1): + person1 = self.Person(id=2 * i - 1, name=name, age=age) + person2 = self.Person(id=2 * i, name=name, age=age) + article1 = self.Article(author=person1) + article2 = self.Article(author=person2) + self.session.add_all([article1, article2, person1, person2]) + self.session.commit() + query_string = {'group': ','.join(['author.name', 'author.age'])} + response = self.app.get('/api/article', query_string=query_string) + document = loads(response.data) + articles = document['data'] + author_ids = sorted(article['relationships']['author']['data']['id'] + for article in articles) + assert ['2', '4', '6', '8'] == author_ids + def test_pagination_links_empty_collection(self): """Tests that pagination links work correctly for an empty collection.