diff --git a/src/crate/client/sqlalchemy/tests/dialect_test.py b/src/crate/client/sqlalchemy/tests/dialect_test.py index d744f835..75145768 100644 --- a/src/crate/client/sqlalchemy/tests/dialect_test.py +++ b/src/crate/client/sqlalchemy/tests/dialect_test.py @@ -33,12 +33,7 @@ from sqlalchemy.orm import Session from sqlalchemy.testing import eq_, in_ -fake_cursor = MagicMock(name='fake_cursor') -fake_cursor.description = ( - ('foo', None, None, None, None, None, None), -) FakeCursor = MagicMock(name='FakeCursor', spec=Cursor) -FakeCursor.return_value = fake_cursor @patch('crate.client.connection.Cursor', FakeCursor) @@ -46,9 +41,13 @@ class DialectTest(TestCase): def execute_wrapper(self, query, *args, **kwargs): self.executed_statement = query - return fake_cursor + return self.fake_cursor def setUp(self): + + self.fake_cursor = MagicMock(name='fake_cursor') + FakeCursor.return_value = self.fake_cursor + self.engine = sa.create_engine('crate://') self.executed_statement = None @@ -58,7 +57,7 @@ def setUp(self): self.engine.execute = self.execute_wrapper self.connection.execute = self.execute_wrapper else: - fake_cursor.execute = self.execute_wrapper + self.fake_cursor.execute = self.execute_wrapper self.base = declarative_base(bind=self.engine) @@ -73,31 +72,56 @@ class Character(self.base): self.character = Character self.session = Session() - def test_pks_are_retrieved_depending_on_version_set(self): + def test_pks_are_retrieved_depending_on_version_set_old(self): + """ + Test the old pk retrieval. + """ meta = self.character.metadata - - # test the old pk retrieval insp = inspect(meta.bind) self.engine.dialect.server_version_info = (0, 54, 0) - fake_cursor.rowcount = 1 - fake_cursor.fetchone = MagicMock(return_value=[["id", "id2", "id3"]]) + + # Setup fake cursor. + self.fake_cursor.rowcount = 1 + self.fake_cursor.description = ( + ('foo', None, None, None, None, None, None), + ) + self.fake_cursor.fetchone = MagicMock(return_value=[["id", "id2", "id3"]]) + + # Verify outcome. eq_(insp.get_pk_constraint("characters")['constrained_columns'], {"id", "id2", "id3"}) - fake_cursor.fetchone.assert_called_once_with() + self.fake_cursor.fetchone.assert_called_once_with() in_("information_schema.table_constraints", self.executed_statement) - # test the new pk retrieval + def test_pks_are_retrieved_depending_on_version_set_new(self): + """ + Test the new pk retrieval. + """ + meta = self.character.metadata insp = inspect(meta.bind) self.engine.dialect.server_version_info = (2, 3, 0) - fake_cursor.rowcount = 3 - fake_cursor.fetchall = MagicMock(return_value=[["id"], ["id2"], ["id3"]]) + + # Setup fake cursor. + self.fake_cursor.rowcount = 3 + self.fake_cursor.description = ( + ('foo', None, None, None, None, None, None), + ) + self.fake_cursor.fetchall = MagicMock(return_value=[["id"], ["id2"], ["id3"]]) + + # Verify outcome. eq_(insp.get_pk_constraint("characters")['constrained_columns'], {"id", "id2", "id3"}) - fake_cursor.fetchall.assert_called_once_with() + self.fake_cursor.fetchall.assert_called_once_with() in_("information_schema.key_column_usage", self.executed_statement) def test_get_table_names(self): - fake_cursor.rowcount = 1 - fake_cursor.fetchall = MagicMock(return_value=[["t1"], ["t2"]]) + # Setup fake cursor. + self.fake_cursor.rowcount = 1 + self.fake_cursor.description = ( + ('foo', None, None, None, None, None, None), + ) + self.fake_cursor.fetchall = MagicMock(return_value=[["t1"], ["t2"]]) + + # Verify outcome. insp = inspect(self.character.metadata.bind) self.engine.dialect.server_version_info = (2, 0, 0) eq_(insp.get_table_names(schema="doc"), @@ -117,11 +141,18 @@ def test_get_table_names(self): in_("WHERE schema_name = ? ORDER BY", self.executed_statement) def test_get_view_names(self): - fake_cursor.rowcount = 1 - fake_cursor.fetchall = MagicMock(return_value=[["v1"], ["v2"]]) insp = inspect(self.character.metadata.bind) self.engine.dialect.server_version_info = (2, 0, 0) + + # Setup fake cursor. + self.fake_cursor.rowcount = 1 + self.fake_cursor.description = ( + ('foo', None, None, None, None, None, None), + ) + self.fake_cursor.fetchall = MagicMock(return_value=[["v1"], ["v2"]]) + + # Verify outcome. eq_(insp.get_view_names(schema="doc"), ['v1', 'v2']) eq_(self.executed_statement, "SELECT table_name FROM information_schema.views "