diff --git a/pyathenajdbc/sqlalchemy_athena.py b/pyathenajdbc/sqlalchemy_athena.py index eaea1cc..8d321f2 100644 --- a/pyathenajdbc/sqlalchemy_athena.py +++ b/pyathenajdbc/sqlalchemy_athena.py @@ -61,23 +61,23 @@ def do_bindparam(m): _TYPE_MAPPINGS = { - 'BOOLEAN': BOOLEAN, - 'REAL': FLOAT, - 'FLOAT': FLOAT, - 'DOUBLE': FLOAT, - 'TINYINT': INTEGER, - 'SMALLINT': INTEGER, - 'INTEGER': INTEGER, - 'BIGINT': BIGINT, - 'DECIMAL': DECIMAL, - 'CHAR': STRINGTYPE, - 'VARCHAR': STRINGTYPE, - 'ARRAY': STRINGTYPE, - 'ROW': STRINGTYPE, # StructType - 'VARBINARY': BINARY, - 'MAP': STRINGTYPE, - 'DATE': DATE, - 'TIMESTAMP': TIMESTAMP, + 'boolean': BOOLEAN, + 'real': FLOAT, + 'float': FLOAT, + 'double': FLOAT, + 'tinyint': INTEGER, + 'smallint': INTEGER, + 'integer': INTEGER, + 'bigint': BIGINT, + 'decimal': DECIMAL, + 'char': STRINGTYPE, + 'varchar': STRINGTYPE, + 'array': STRINGTYPE, + 'row': STRINGTYPE, # StructType + 'varbinary': BINARY, + 'map': STRINGTYPE, + 'date': DATE, + 'timestamp': TIMESTAMP, } @@ -98,6 +98,8 @@ class AthenaDialect(DefaultDialect): description_encoding = None supports_native_boolean = True + _pattern_column_type = re.compile(r'^([a-zA-Z]+)($|\(.+\)$)') + @classmethod def dbapi(cls): return pyathenajdbc @@ -146,8 +148,6 @@ def has_table(self, connection, table_name, schema=None): @reflection.cache def get_columns(self, connection, table_name, schema=None, **kw): - # information_schema.columns fails when filtering with table_schema or table_name, - # if specifying a name that does not exist in table_schema or table_name. schema = schema if schema else connection.connection.schema_name query = """ SELECT @@ -160,20 +160,23 @@ def get_columns(self, connection, table_name, schema=None, **kw): ordinal_position, comment FROM information_schema.columns - """ + WHERE table_schema = '{schema}' + AND table_name = '{table}' + """.format(schema=schema, table=table_name) return [ { 'name': row.column_name, - 'type': _TYPE_MAPPINGS.get(re.sub(r'^([A-Z]+)($|\(.+\)$)', r'\1', - row.data_type.upper()), NULLTYPE), + 'type': _TYPE_MAPPINGS.get(self._get_column_type(row.data_type), NULLTYPE), 'nullable': True if row.is_nullable == 'YES' else False, 'default': row.column_default, 'ordinal_position': row.ordinal_position, 'comment': row.comment, } for row in connection.execute(query).fetchall() - if row.table_schema == schema and row.table_name == table_name ] + def _get_column_type(self, type_): + return self._pattern_column_type.sub(r'\1', type_) + def get_foreign_keys(self, connection, table_name, schema=None, **kw): # Athena has no support for foreign keys. return [] diff --git a/tests/test_sqlalchemy_athena.py b/tests/test_sqlalchemy_athena.py index c48585a..80c6fa9 100644 --- a/tests/test_sqlalchemy_athena.py +++ b/tests/test_sqlalchemy_athena.py @@ -114,6 +114,17 @@ def test_has_table(self, engine, connection): self.assertTrue(Table('one_row', MetaData(bind=engine)).exists()) self.assertFalse(Table('this_table_does_not_exist', MetaData(bind=engine)).exists()) + @with_engine + def test_get_columns(self, engine, connection): + insp = sqlalchemy.inspect(engine) + actual = insp.get_columns(table_name='one_row', schema=SCHEMA)[0] + self.assertEqual(actual['name'], 'number_of_rows') + self.assertTrue(isinstance(actual['type'], INTEGER)) + self.assertTrue(actual['nullable']) + self.assertIsNone(actual['default']) + self.assertEqual(actual['ordinal_position'], 1) + self.assertIsNone(actual['comment']) + @with_engine def test_char_length(self, engine, connection): one_row_complex = Table('one_row_complex', MetaData(bind=engine), autoload=True) @@ -172,6 +183,25 @@ def test_reserved_words(self, engine, connection): self.assertNotIn('`select`', query) self.assertNotIn('`current_timestamp`', query) + @with_engine + def test_get_column_type(self, engine, connection): + dialect = engine.dialect + self.assertEqual(dialect._get_column_type('boolean'), 'boolean') + self.assertEqual(dialect._get_column_type('tinyint'), 'tinyint') + self.assertEqual(dialect._get_column_type('smallint'), 'smallint') + self.assertEqual(dialect._get_column_type('integer'), 'integer') + self.assertEqual(dialect._get_column_type('bigint'), 'bigint') + self.assertEqual(dialect._get_column_type('real'), 'real') + self.assertEqual(dialect._get_column_type('double'), 'double') + self.assertEqual(dialect._get_column_type('varchar'), 'varchar') + self.assertEqual(dialect._get_column_type('timestamp'), 'timestamp') + self.assertEqual(dialect._get_column_type('date'), 'date') + self.assertEqual(dialect._get_column_type('varbinary'), 'varbinary') + self.assertEqual(dialect._get_column_type('array(integer)'), 'array') + self.assertEqual(dialect._get_column_type('map(integer, integer)'), 'map') + self.assertEqual(dialect._get_column_type('row(a integer, b integer)'), 'row') + self.assertEqual(dialect._get_column_type('decimal(10,1)'), 'decimal') + @with_engine def test_contain_percents_character_query(self, engine, connection): query = sqlalchemy.sql.text("""