diff --git a/pyathenajdbc/__init__.py b/pyathenajdbc/__init__.py index e9f7412..64e1d05 100644 --- a/pyathenajdbc/__init__.py +++ b/pyathenajdbc/__init__.py @@ -20,6 +20,7 @@ ATHENA_JAR) ATHENA_DRIVER_CLASS_NAME = 'com.simba.athena.jdbc.Driver' ATHENA_CONNECTION_STRING = 'jdbc:awsathena://AwsRegion={region};' +LOG4J_PROPERTIES = 'log4j.properties' class DBAPITypeObject: @@ -60,9 +61,9 @@ def __eq__(self, other): def connect(s3_staging_dir=None, access_key=None, secret_key=None, region_name=None, schema_name='default', profile_name=None, credential_file=None, jvm_path=None, jvm_options=None, converter=None, formatter=None, - driver_path=None, **kwargs): + driver_path=None, log4j_conf=None, **kwargs): from pyathenajdbc.connection import Connection return Connection(s3_staging_dir, access_key, secret_key, region_name, schema_name, profile_name, credential_file, jvm_path, jvm_options, converter, formatter, - driver_path, **kwargs) + driver_path, log4j_conf, **kwargs) diff --git a/pyathenajdbc/connection.py b/pyathenajdbc/connection.py index f4bf2c1..769e7ea 100644 --- a/pyathenajdbc/connection.py +++ b/pyathenajdbc/connection.py @@ -8,7 +8,8 @@ import jpype from future.utils import iteritems -from pyathenajdbc import (ATHENA_CONNECTION_STRING, ATHENA_DRIVER_CLASS_NAME, ATHENA_JAR) +from pyathenajdbc import (ATHENA_CONNECTION_STRING, ATHENA_DRIVER_CLASS_NAME, + ATHENA_JAR, LOG4J_PROPERTIES) from pyathenajdbc.converter import JDBCTypeConverter from pyathenajdbc.cursor import Cursor from pyathenajdbc.error import NotSupportedError, ProgrammingError @@ -21,11 +22,12 @@ class Connection(object): _ENV_S3_STAGING_DIR = 'AWS_ATHENA_S3_STAGING_DIR' + _BASE_PATH = os.path.dirname(os.path.abspath(__file__)) def __init__(self, s3_staging_dir=None, access_key=None, secret_key=None, region_name=None, schema_name='default', profile_name=None, credential_file=None, jvm_path=None, jvm_options=None, converter=None, formatter=None, - driver_path=None, **driver_kwargs): + driver_path=None, log4j_conf=None, **driver_kwargs): if s3_staging_dir: self.s3_staging_dir = s3_staging_dir else: @@ -63,7 +65,7 @@ def __init__(self, s3_staging_dir=None, access_key=None, secret_key=None, self.region_name = session.get_config_variable('region') assert self.region_name, 'Required argument `region_name` not found.' - self._start_jvm(jvm_path, jvm_options, driver_path) + self._start_jvm(jvm_path, jvm_options, driver_path, log4j_conf) props = self._build_driver_args(**driver_kwargs) jpype.JClass(ATHENA_DRIVER_CLASS_NAME) @@ -75,14 +77,20 @@ def __init__(self, s3_staging_dir=None, access_key=None, secret_key=None, @classmethod @synchronized - def _start_jvm(cls, jvm_path, jvm_options, driver_path): + def _start_jvm(cls, jvm_path, jvm_options, driver_path, log4j_conf): if jvm_path is None: jvm_path = jpype.get_default_jvm_path() if driver_path is None: - driver_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), ATHENA_JAR) + driver_path = os.path.join(cls._BASE_PATH, ATHENA_JAR) + if log4j_conf is None: + log4j_conf = os.path.join(cls._BASE_PATH, LOG4J_PROPERTIES) if not jpype.isJVMStarted(): _logger.debug('JVM path: %s', jvm_path) - args = ['-server', '-Djava.class.path={0}'.format(driver_path)] + args = [ + '-server', + '-Djava.class.path={0}'.format(driver_path), + '-Dlog4j.configuration=file:{0}'.format(log4j_conf) + ] if jvm_options: args.extend(jvm_options) _logger.debug('JVM args: %s', args) diff --git a/pyathenajdbc/log4j.properties b/pyathenajdbc/log4j.properties new file mode 100644 index 0000000..a091230 --- /dev/null +++ b/pyathenajdbc/log4j.properties @@ -0,0 +1,2 @@ +log4j.rootLogger=FATAL, null +log4j.appender.null=com.simba.athena.shaded.apache.log4j.varia.NullAppender diff --git a/setup.py b/setup.py index f3c1c64..2f78d04 100755 --- a/setup.py +++ b/setup.py @@ -105,7 +105,7 @@ def run(self): package_data={ '': ['LICENSE', '*.rst'], 'jdbc': ['*.txt'], - _PACKAGE_NAME.lower(): ['*.jar'], + _PACKAGE_NAME.lower(): ['*.jar', '*.properties'], }, install_requires=[ 'future',