Merged
59 changes: 50 additions & 9 deletions README.rst
@@ -1,5 +1,5 @@
.. image:: https://img.shields.io/pypi/pyversions/PyAthena.svg
:target: https://pypi.python.org/pypi/PyAthena/
:target: https://pypi.org/project/PyAthena/

.. image:: https://travis-ci.com/laughingman7743/PyAthena.svg?branch=master
:target: https://travis-ci.com/laughingman7743/PyAthena
@@ -20,7 +20,7 @@ PyAthena
PyAthena is a Python `DB API 2.0 (PEP 249)`_ compliant client for `Amazon Athena`_.

.. _`DB API 2.0 (PEP 249)`: https://www.python.org/dev/peps/pep-0249/
.. _`Amazon Athena`: http://docs.aws.amazon.com/athena/latest/APIReference/Welcome.html
.. _`Amazon Athena`: https://docs.aws.amazon.com/athena/latest/APIReference/Welcome.html

Requirements
------------
@@ -136,13 +136,13 @@ Supported SQLAlchemy is 1.0.0 or higher and less than 2.0.0.

The connection string has the following format:

.. code:: python
.. code:: text

awsathena+rest://{aws_access_key_id}:{aws_secret_access_key}@athena.{region_name}.amazonaws.com:443/{schema_name}?s3_staging_dir={s3_staging_dir}&...

If you do not specify ``aws_access_key_id`` and ``aws_secret_access_key``, the credentials from an instance profile or the boto3 configuration file are used:

.. code:: python
.. code:: text

awsathena+rest://:@athena.{region_name}.amazonaws.com:443/{schema_name}?s3_staging_dir={s3_staging_dir}&...

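For illustration, the credential-less connection string can be assembled in Python before handing it to SQLAlchemy; a minimal sketch, where the region, schema, and S3 bucket are placeholder values:

```python
from urllib.parse import quote_plus

# Build the credential-less connection string shown above.
# s3_staging_dir must be URL-quoted because it contains ':' and '/'.
conn_str = (
    'awsathena+rest://:@athena.{region_name}.amazonaws.com:443/'
    '{schema_name}?s3_staging_dir={s3_staging_dir}'
).format(
    region_name='us-west-2',
    schema_name='default',
    s3_staging_dir=quote_plus('s3://YOUR_S3_BUCKET/path/to/'))

print(conn_str)
```

The resulting string can then be passed to ``sqlalchemy.create_engine``.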
@@ -151,7 +151,10 @@ NOTE: ``s3_staging_dir`` requires quote. If ``aws_access_key_id``, ``aws_secret_
Pandas
~~~~~~

Minimal example for Pandas DataFrame:
As DataFrame
^^^^^^^^^^^^

You can use the `pandas.read_sql`_ to handle the query results as a `DataFrame object`_.

.. code:: python

@@ -165,7 +168,7 @@ Minimal example for Pandas DataFrame:
df = pd.read_sql("SELECT * FROM many_rows", conn)
print(df.head())

As Pandas DataFrame:
The ``pyathena.util`` package also has helper methods.

.. code:: python

@@ -180,7 +183,45 @@ As Pandas DataFrame:
df = as_pandas(cursor)
print(df.describe())

If you want to use Pandas `DataFrame object`_ directly, you can use `PandasCursor`_.
If you want to read the query results written to S3 directly, you can use `PandasCursor`_.
This cursor fetches query results faster than the default cursor. (See `benchmark results`_.)

.. _`pandas.read_sql`: https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.read_sql.html
.. _`benchmark results`: benchmarks/README.rst

To SQL
^^^^^^

You can use `pandas.DataFrame.to_sql`_ to write records stored in DataFrame to Amazon Athena.
`pandas.DataFrame.to_sql`_ uses `SQLAlchemy`_, so you need to install it.

.. code:: python

import pandas as pd
from urllib.parse import quote_plus
from sqlalchemy import create_engine

conn_str = 'awsathena+rest://:@athena.{region_name}.amazonaws.com:443/'\
'{schema_name}?s3_staging_dir={s3_staging_dir}&s3_dir={s3_dir}&compression=snappy'
engine = create_engine(conn_str.format(
region_name='us-west-2',
schema_name='YOUR_SCHEMA',
s3_staging_dir=quote_plus('s3://YOUR_S3_BUCKET/path/to/'),
s3_dir=quote_plus('s3://YOUR_S3_BUCKET/path/to/')))

df = pd.DataFrame({'a': [1, 2, 3, 4, 5]})
df.to_sql('YOUR_TABLE', engine, schema="YOUR_SCHEMA", index=False, if_exists='replace', method='multi')

The location of the Amazon S3 table is specified by the ``s3_dir`` parameter in the connection string.
If ``s3_dir`` is not specified, the ``s3_staging_dir`` parameter will be used. The following rule applies.

.. code:: text

s3://{s3_dir or s3_staging_dir}/{schema}/{table}/

Only the Parquet data format is supported. The compression format is specified by the ``compression`` parameter in the connection string.

.. _`pandas.DataFrame.to_sql`: https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.DataFrame.to_sql.html
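The location rule above can be expressed as a small helper; a sketch for illustration only (``table_location`` is not part of PyAthena, and the bucket is a placeholder):

```python
def table_location(schema, table, s3_staging_dir, s3_dir=None):
    """Return s3://{s3_dir or s3_staging_dir}/{schema}/{table}/."""
    # s3_dir takes precedence; fall back to s3_staging_dir.
    base = (s3_dir or s3_staging_dir).rstrip('/')
    return '{0}/{1}/{2}/'.format(base, schema, table)

print(table_location('default', 'many_rows',
                     s3_staging_dir='s3://YOUR_S3_BUCKET/path/to/'))
# -> s3://YOUR_S3_BUCKET/path/to/default/many_rows/
```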

AsynchronousCursor
~~~~~~~~~~~~~~~~~~
@@ -493,8 +534,8 @@ Then you simply specify an instance of this class in the converters argument when

NOTE: PandasCursor loads the entire CSV file into memory. Pay attention to the memory capacity.

.. _`DataFrame object`: https://pandas.pydata.org/pandas-docs/stable/generated/pandas.DataFrame.html
.. _`pandas.Timestamp`: https://pandas.pydata.org/pandas-docs/stable/generated/pandas.Timestamp.html
.. _`DataFrame object`: https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.DataFrame.html
.. _`pandas.Timestamp`: https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.Timestamp.html

AsyncPandasCursor
~~~~~~~~~~~~~~~~~
7 changes: 5 additions & 2 deletions pyathena/cursor.py
@@ -5,7 +5,7 @@
import logging

from pyathena.common import BaseCursor, CursorIterator
from pyathena.error import NotSupportedError, OperationalError, ProgrammingError
from pyathena.error import OperationalError, ProgrammingError
from pyathena.model import AthenaQueryExecution
from pyathena.result_set import AthenaResultSet, WithResultSet
from pyathena.util import synchronized
@@ -58,7 +58,10 @@ def execute(self, operation, parameters=None, work_group=None, s3_staging_dir=No
return self

def executemany(self, operation, seq_of_parameters):
raise NotSupportedError
for parameters in seq_of_parameters:
self.execute(operation, parameters)
# Operations that have result sets are not allowed with executemany.
self._reset_state()

@synchronized
def cancel(self):
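The new ``executemany`` simply issues one Athena query per parameter set and then discards any result set. A sketch of that loop with a stub cursor, so it runs without an Athena connection (``StubCursor`` is hypothetical, not part of PyAthena):

```python
class StubCursor:
    """Stands in for pyathena.cursor.Cursor to illustrate executemany."""

    def __init__(self):
        self.executed = []
        self.result_set = object()

    def execute(self, operation, parameters=None):
        # The real cursor runs the query on Athena; here we just record it.
        self.executed.append((operation, parameters))

    def _reset_state(self):
        self.result_set = None  # discard any result set

    def executemany(self, operation, seq_of_parameters):
        for parameters in seq_of_parameters:
            self.execute(operation, parameters)
        # Operations that have result sets are not allowed with executemany.
        self._reset_state()

cursor = StubCursor()
cursor.executemany('INSERT INTO execute_many (a) VALUES (%(a)d)',
                   [{'a': i} for i in range(3)])
print(len(cursor.executed))  # -> 3
```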
7 changes: 5 additions & 2 deletions pyathena/pandas_cursor.py
@@ -7,7 +7,7 @@

from pyathena.common import CursorIterator
from pyathena.cursor import BaseCursor
from pyathena.error import NotSupportedError, OperationalError, ProgrammingError
from pyathena.error import OperationalError, ProgrammingError
from pyathena.model import AthenaQueryExecution
from pyathena.result_set import AthenaPandasResultSet, WithResultSet
from pyathena.util import synchronized
@@ -60,7 +60,10 @@ def execute(self, operation, parameters=None, work_group=None, s3_staging_dir=No
return self

def executemany(self, operation, seq_of_parameters):
raise NotSupportedError
for parameters in seq_of_parameters:
self.execute(operation, parameters)
# Operations that have result sets are not allowed with executemany.
self._reset_state()

@synchronized
def cancel(self):
7 changes: 5 additions & 2 deletions pyathena/result_set.py
@@ -357,7 +357,10 @@ def __init__(self, connection, converter, query_execution, arraysize, retry_conf
self._arraysize = arraysize
self._client = self._connection.session.client(
's3', region_name=self._connection.region_name, **self._connection._client_kwargs)
if self._query_execution.state == AthenaQueryExecution.STATE_SUCCEEDED:
if self._query_execution.state == AthenaQueryExecution.STATE_SUCCEEDED and \
self._query_execution.statement_type == \
AthenaQueryExecution.STATEMENT_TYPE_DML and \
self._query_execution.output_location.endswith('.csv'):
self._df = self._as_pandas()
else:
import pandas as pd
@@ -455,7 +458,7 @@ def _as_pandas(self):
parse_dates=self.parse_dates,
infer_datetime_format=True)
df = self._trunc_date(df)
else: # Allow empty response so DDL can be used
else: # Allow empty response
df = pd.DataFrame()
return df

115 changes: 110 additions & 5 deletions pyathena/sqlalchemy_athena.py
@@ -8,11 +8,12 @@

import tenacity
from future.utils import raise_from
from sqlalchemy import exc, util
from sqlalchemy.engine import reflection, Engine
from sqlalchemy.engine.default import DefaultDialect
from sqlalchemy.exc import NoSuchTableError, OperationalError
from sqlalchemy.sql.compiler import (BIND_PARAMS, BIND_PARAMS_ESC,
IdentifierPreparer, SQLCompiler)
IdentifierPreparer, SQLCompiler, DDLCompiler)
from sqlalchemy.sql.sqltypes import (BIGINT, BINARY, BOOLEAN, DATE, DECIMAL, FLOAT,
INTEGER, NULLTYPE, STRINGTYPE, TIMESTAMP)
from tenacity import retry_if_exception, stop_after_attempt, wait_exponential
@@ -28,17 +29,39 @@ def __contains__(self, item):
return True


class AthenaIdentifierPreparer(IdentifierPreparer):
class AthenaDMLIdentifierPreparer(IdentifierPreparer):
"""PrestoIdentifierPreparer

https://github.com/dropbox/PyHive/blob/master/pyhive/sqlalchemy_presto.py"""
reserved_words = UniversalSet()


class AthenaCompiler(SQLCompiler):
class AthenaDDLIdentifierPreparer(IdentifierPreparer):

def __init__(
self,
dialect,
initial_quote='`',
final_quote=None,
escape_quote='`',
quote_case_sensitive_collations=True,
omit_schema=False
):
super(AthenaDDLIdentifierPreparer, self).__init__(
dialect=dialect,
initial_quote=initial_quote,
final_quote=final_quote,
escape_quote=escape_quote,
quote_case_sensitive_collations=quote_case_sensitive_collations,
omit_schema=omit_schema
)


class AthenaStatementCompiler(SQLCompiler):
"""PrestoCompiler

https://github.com/dropbox/PyHive/blob/master/pyhive/sqlalchemy_presto.py"""

def visit_char_length_func(self, fn, **kw):
return 'length{0}'.format(self.function_argspec(fn, **kw))

@@ -66,6 +89,85 @@ def do_bindparam(m):
)


class AthenaDDLCompiler(DDLCompiler):

@property
def preparer(self):
return self._preparer

@preparer.setter
def preparer(self, value):
pass

def __init__(
self,
dialect,
statement,
bind=None,
schema_translate_map=None,
compile_kwargs=util.immutabledict()):
self._preparer = AthenaDDLIdentifierPreparer(dialect)
super(AthenaDDLCompiler, self).__init__(
dialect=dialect,
statement=statement,
bind=bind,
schema_translate_map=schema_translate_map,
compile_kwargs=compile_kwargs)

def visit_create_table(self, create):
table = create.element
preparer = self.preparer

text = '\nCREATE EXTERNAL '
text += 'TABLE ' + preparer.format_table(table) + ' '
text += '('

separator = '\n'
for create_column in create.columns:
column = create_column.element
try:
processed = self.process(create_column)
if processed is not None:
text += separator
separator = ", \n"
text += "\t" + processed
except exc.CompileError as ce:
util.raise_from_cause(
exc.CompileError(
util.u("(in table '{0}', column '{1}'): {2}").format(
table.description, column.name, ce.args[0])
)
)

const = self.create_table_constraints(
table,
_include_foreign_key_constraints=create.include_foreign_key_constraints,
)
if const:
text += separator + "\t" + const

text += "\n)\n%s\n\n" % self.post_create_table(table)
return text

def post_create_table(self, table):
raw_connection = table.bind.raw_connection()
# TODO Supports orc, avro, json, csv or tsv format
text = 'STORED AS PARQUET\n'

location = raw_connection._kwargs['s3_dir'] if 's3_dir' in raw_connection._kwargs \
else raw_connection.s3_staging_dir
if not location:
raise exc.CompileError('`s3_dir` or `s3_staging_dir` parameter is required'
' in the connection string.')
text += "LOCATION '{0}{1}/{2}/'\n".format(location, table.schema, table.name)

compression = raw_connection._kwargs.get('compression')
if compression:
text += "TBLPROPERTIES ('parquet.compress'='{0}')\n".format(compression.upper())

return text


_TYPE_MAPPINGS = {
'boolean': BOOLEAN,
'real': FLOAT,
@@ -91,18 +193,21 @@ class AthenaDialect(DefaultDialect):

name = 'awsathena'
driver = 'rest'
preparer = AthenaIdentifierPreparer
statement_compiler = AthenaCompiler
preparer = AthenaDMLIdentifierPreparer
statement_compiler = AthenaStatementCompiler
ddl_compiler = AthenaDDLCompiler
default_paramstyle = pyathena.paramstyle
supports_alter = False
supports_pk_autoincrement = False
supports_default_values = False
supports_empty_insert = False
supports_multivalues_insert = True
supports_unicode_statements = True
supports_unicode_binds = True
returns_unicode_strings = True
description_encoding = None
supports_native_boolean = True
postfetch_lastrowid = False

_pattern_data_catlog_exception = re.compile(
r'(((Database|Namespace)\ (?P<schema>.+))|(Table\ (?P<table>.+)))\ not\ found\.')
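The ``post_create_table`` hook above appends the Athena-specific DDL tail (storage format, S3 location, and optional Parquet compression). A standalone sketch of the text it produces, using a plain function and ``ValueError`` in place of the compiler plumbing (the bucket is a placeholder):

```python
def post_create_table_sql(schema, table, s3_staging_dir=None, s3_dir=None,
                          compression=None):
    """Sketch of the DDL tail emitted after CREATE EXTERNAL TABLE (...)."""
    location = s3_dir or s3_staging_dir
    if not location:
        raise ValueError('`s3_dir` or `s3_staging_dir` parameter is required'
                         ' in the connection string.')
    text = 'STORED AS PARQUET\n'
    text += "LOCATION '{0}{1}/{2}/'\n".format(location, schema, table)
    if compression:
        text += "TBLPROPERTIES ('parquet.compress'='{0}')\n".format(
            compression.upper())
    return text

ddl_tail = post_create_table_sql(
    'default', 'many_rows',
    s3_staging_dir='s3://YOUR_S3_BUCKET/path/to/',
    compression='snappy')
print(ddl_tail)
```

Note that the location string is used as-is, so it should end with a trailing slash, matching the connection-string examples above.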
11 changes: 10 additions & 1 deletion tests/conftest.py
@@ -4,6 +4,7 @@

import contextlib
import os
import uuid

import pytest

@@ -50,6 +51,12 @@ def _create_table(cursor):
ENV.s3_staging_dir, S3_PREFIX, 'integer_na_values')
location_boolean_na_values = '{0}{1}/{2}/'.format(
ENV.s3_staging_dir, S3_PREFIX, 'boolean_na_values')
location_execute_many = '{0}{1}/{2}/'.format(
ENV.s3_staging_dir, S3_PREFIX, 'execute_many_{0}'.format(
str(uuid.uuid4()).replace('-', '')))
location_execute_many_pandas = '{0}{1}/{2}/'.format(
ENV.s3_staging_dir, S3_PREFIX, 'execute_many_pandas_{0}'.format(
str(uuid.uuid4()).replace('-', '')))
for q in read_query(
os.path.join(BASE_PATH, 'sql', 'create_table.sql')):
cursor.execute(q.format(schema=SCHEMA,
@@ -58,4 +65,6 @@ def _create_table(cursor):
location_one_row_complex=location_one_row_complex,
location_partition_table=location_partition_table,
location_integer_na_values=location_integer_na_values,
location_boolean_na_values=location_boolean_na_values))
location_boolean_na_values=location_boolean_na_values,
location_execute_many=location_execute_many,
location_execute_many_pandas=location_execute_many_pandas))
14 changes: 14 additions & 0 deletions tests/sql/create_table.sql
@@ -54,3 +54,17 @@ CREATE EXTERNAL TABLE IF NOT EXISTS {schema}.boolean_na_values (
)
ROW FORMAT DELIMITED FIELDS TERMINATED BY '\t' LINES TERMINATED BY '\n' STORED AS TEXTFILE
LOCATION '{location_boolean_na_values}';

DROP TABLE IF EXISTS {schema}.execute_many;
CREATE EXTERNAL TABLE IF NOT EXISTS {schema}.execute_many (
a INT
)
ROW FORMAT DELIMITED FIELDS TERMINATED BY '\t' LINES TERMINATED BY '\n' STORED AS TEXTFILE
LOCATION '{location_execute_many}';

DROP TABLE IF EXISTS {schema}.execute_many_pandas;
CREATE EXTERNAL TABLE IF NOT EXISTS {schema}.execute_many_pandas (
a INT
)
ROW FORMAT DELIMITED FIELDS TERMINATED BY '\t' LINES TERMINATED BY '\n' STORED AS TEXTFILE
LOCATION '{location_execute_many_pandas}';