-
Notifications
You must be signed in to change notification settings - Fork 1.7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
adapters: postgres & redshift #259
Changes from 6 commits
3145139
13fe3e5
973c4ad
f4f0b41
d0a6f17
32e20a3
9a46a2e
c612333
7023921
e848e59
7076a7d
b2e31c9
0886bc4
31502f8
f6e0813
4a6750e
5676c34
dcb8278
7a52b80
6251c07
64ee3eb
d2ea805
62f2f68
6023586
ec50446
a5e71f0
9c95a92
3a157fd
114fb91
b959fb8
f5d7be8
2913e81
736dcf9
cd1fe4f
7c0f26b
6cf9684
887fe85
2c0e5ec
6be3d44
6e14eb8
95e4a75
4970d6d
15be495
64e4b67
e30931a
527eaa8
67cc3b4
73cde44
2b3e8dd
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
from dbt.adapters.postgres import PostgresAdapter | ||
|
||
def get_adapter(target): | ||
adapters = { | ||
'postgres': PostgresAdapter, | ||
'redshift': PostgresAdapter, | ||
} | ||
|
||
return adapters[target.target_type] |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,150 @@ | ||
import copy | ||
import psycopg2 | ||
import re | ||
import time | ||
import yaml | ||
|
||
import dbt.flags as flags | ||
|
||
from dbt.contracts.connection import validate_connection | ||
from dbt.logger import GLOBAL_LOGGER as logger | ||
from dbt.schema import Schema, READ_PERMISSION_DENIED_ERROR | ||
|
||
class PostgresAdapter: | ||
|
||
@classmethod | ||
def acquire_connection(cls, profile): | ||
|
||
# profile requires some marshalling right now because it includes a | ||
# wee bit of global config. | ||
# TODO remove this | ||
credentials = copy.deepcopy(profile) | ||
|
||
credentials.pop('type', None) | ||
credentials.pop('threads', None) | ||
|
||
result = { | ||
'type': 'postgres', | ||
'state': 'init', | ||
'handle': None, | ||
'credentials': credentials | ||
} | ||
|
||
logger.debug('Acquiring postgres connection') | ||
|
||
if flags.STRICT_MODE: | ||
validate_connection(result) | ||
|
||
return cls.open_connection(result) | ||
|
||
@classmethod | ||
def get_connection(cls, profile): | ||
return cls.acquire_connection(profile) | ||
|
||
@staticmethod | ||
def create_table(): | ||
pass | ||
|
||
@staticmethod | ||
def drop_table(): | ||
pass | ||
|
||
@classmethod | ||
def execute_model(cls, project, target, model): | ||
schema_helper = Schema(project, target) | ||
parts = re.split(r'-- (DBT_OPERATION .*)', model.compiled_contents) | ||
profile = project.run_environment() | ||
connection = cls.get_connection(profile) | ||
|
||
if flags.STRICT_MODE: | ||
validate_connection(connection) | ||
|
||
handle = connection['handle'] | ||
|
||
status = 'None' | ||
for i, part in enumerate(parts): | ||
matches = re.match(r'^DBT_OPERATION ({.*})$', part) | ||
if matches is not None: | ||
instruction_string = matches.groups()[0] | ||
instruction = yaml.safe_load(instruction_string) | ||
function = instruction['function'] | ||
kwargs = instruction['args'] | ||
|
||
func_map = { | ||
'expand_column_types_if_needed': \ | ||
lambda kwargs: schema_helper.expand_column_types_if_needed( | ||
**kwargs) | ||
} | ||
|
||
func_map[function](kwargs) | ||
else: | ||
try: | ||
handle, status = cls.add_query_to_transaction( | ||
part, handle) | ||
except psycopg2.ProgrammingError as e: | ||
if "permission denied for" in e.diag.message_primary: | ||
raise RuntimeError(READ_PERMISSION_DENIED_ERROR.format( | ||
model=model.name, | ||
error=str(e).strip(), | ||
user=target.user, | ||
)) | ||
else: | ||
raise | ||
|
||
handle.commit() | ||
return status | ||
|
||
@classmethod | ||
def open_connection(cls, connection): | ||
if connection.get('state') == 'open': | ||
logger.debug('Connection is already open, skipping open.') | ||
return connection | ||
|
||
result = connection.copy() | ||
|
||
try: | ||
handle = psycopg2.connect(cls.get_connection_spec(connection)) | ||
|
||
result['handle'] = handle | ||
result['state'] = 'open' | ||
except psycopg2.Error as e: | ||
logger.debug("Got an error when attempting to open a postgres " | ||
"connection: '{}'" | ||
.format(e)) | ||
|
||
result['handle'] = None | ||
result['state'] = 'fail' | ||
|
||
return result | ||
|
||
@staticmethod | ||
def get_connection_spec(connection): | ||
credentials = connection.get('credentials') | ||
|
||
return ("dbname='{}' user='{}' host='{}' password='{}' port='{}' " | ||
"connect_timeout=10".format( | ||
credentials.get('dbname'), | ||
credentials.get('user'), | ||
credentials.get('host'), | ||
credentials.get('pass'), | ||
credentials.get('port'), | ||
)) | ||
|
||
@staticmethod | ||
def add_query_to_transaction(sql, handle): | ||
cursor = handle.cursor() | ||
|
||
try: | ||
logger.debug("SQL: %s", sql) | ||
pre = time.time() | ||
cursor.execute(sql) | ||
post = time.time() | ||
logger.debug("SQL status: %s in %0.2f seconds", cursor.statusmessage, post-pre) | ||
return handle, cursor.statusmessage | ||
except Exception as e: | ||
handle.rollback() | ||
logger.exception("Error running SQL: %s", sql) | ||
logger.debug("rolling back connection") | ||
raise e | ||
finally: | ||
cursor.close() |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
from voluptuous import Schema, Required, All, Any, Extra, Range | ||
from voluptuous.error import MultipleInvalid | ||
|
||
from dbt.exceptions import ValidationException | ||
from dbt.logger import GLOBAL_LOGGER as logger | ||
|
||
|
||
connection_contract = Schema({ | ||
Required('type'): Any('postgres', 'redshift'), | ||
Required('state'): Any('init', 'open', 'closed', 'fail'), | ||
Required('handle'): Any(None, object), | ||
Required('credentials'): object, | ||
}) | ||
|
||
postgres_credentials_contract = Schema({ | ||
Required('dbname'): str, | ||
Required('host'): str, | ||
Required('user'): str, | ||
Required('pass'): str, | ||
Required('port'): All(int, Range(min=0, max=65535)), | ||
Required('schema'): str, | ||
}) | ||
|
||
credentials_mapping = { | ||
'postgres': postgres_credentials_contract, | ||
'redshift': postgres_credentials_contract, | ||
} | ||
|
||
def validate_connection(connection): | ||
try: | ||
connection_contract(connection) | ||
|
||
credentials_contract = credentials_mapping.get(connection.get('type')) | ||
credentials_contract(connection.get('credentials')) | ||
except MultipleInvalid as e: | ||
logger.info(e) | ||
raise ValidationException(str(e)) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
class ValidationException(Exception): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 👍 |
||
pass |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
STRICT_MODE = False |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -7,6 +7,7 @@ | |
import re | ||
|
||
import dbt.version | ||
import dbt.flags as flags | ||
import dbt.project as project | ||
import dbt.task.run as run_task | ||
import dbt.task.compile as compile_task | ||
|
@@ -37,6 +38,8 @@ def handle(args): | |
|
||
initialize_logger(parsed.debug) | ||
|
||
flags.STRICT_MODE = parsed.strict | ||
|
||
# this needs to happen after args are parsed so we can determine the correct profiles.yml file | ||
if not config.send_anonymous_usage_stats(parsed.profiles_dir): | ||
dbt.tracking.do_not_track() | ||
|
@@ -131,6 +134,7 @@ def parse_args(args): | |
p = argparse.ArgumentParser(prog='dbt: data build tool', formatter_class=argparse.RawTextHelpFormatter) | ||
p.add_argument('--version', action='version', version=dbt.version.get_version_information(), help="Show version information") | ||
p.add_argument('-d', '--debug', action='store_true', help='Display debug logging during dbt execution. Useful for debugging and making bug reports.') | ||
p.add_argument('-S', '--strict', action='store_true', help='Run schema validations at runtime. This will surface bugs in dbt, but may incur a speed penalty.') | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 👍 |
||
|
||
subs = p.add_subparsers() | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -10,6 +10,7 @@ | |
import yaml | ||
from datetime import datetime | ||
|
||
from dbt.adapters.factory import get_adapter | ||
from dbt.logger import GLOBAL_LOGGER as logger | ||
from dbt.compilation import compile_string | ||
from dbt.linker import Linker | ||
|
@@ -94,38 +95,8 @@ def execute_list(self, queries, source): | |
return status | ||
|
||
def execute_contents(self, target, model): | ||
parts = re.split(r'-- (DBT_OPERATION .*)', model.compiled_contents) | ||
handle = None | ||
|
||
status = 'None' | ||
for i, part in enumerate(parts): | ||
matches = re.match(r'^DBT_OPERATION ({.*})$', part) | ||
if matches is not None: | ||
instruction_string = matches.groups()[0] | ||
instruction = yaml.safe_load(instruction_string) | ||
function = instruction['function'] | ||
kwargs = instruction['args'] | ||
|
||
func_map = { | ||
'expand_column_types_if_needed': lambda kwargs: self.schema_helper.expand_column_types_if_needed(**kwargs), | ||
} | ||
|
||
func_map[function](kwargs) | ||
else: | ||
try: | ||
handle, status = self.schema_helper.execute_without_auto_commit(part, handle) | ||
except psycopg2.ProgrammingError as e: | ||
if "permission denied for" in e.diag.message_primary: | ||
raise RuntimeError(dbt.schema.READ_PERMISSION_DENIED_ERROR.format( | ||
model=model.name, | ||
error=str(e).strip(), | ||
user=target.user, | ||
)) | ||
else: | ||
raise | ||
|
||
handle.commit() | ||
return status | ||
return get_adapter(target).execute_model( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this is so good |
||
self.project, target, model) | ||
|
||
class ModelRunner(BaseRunner): | ||
run_type = 'run' | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -9,3 +9,4 @@ paramiko==2.0.1 | |
sshtunnel==0.0.8.2 | ||
snowplow-tracker==0.7.2 | ||
celery==3.1.23 | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
voluptuous==0.9.3 |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,7 +1,6 @@ | ||
#!/bin/bash | ||
|
||
. /usr/src/app/test/setup.sh | ||
workon dbt | ||
pip install tox | ||
|
||
cd /usr/src/app | ||
tox -e integration-py27,integration-py35 |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
#!/bin/bash | ||
|
||
pip install tox | ||
|
||
cd /usr/src/app | ||
tox -e unit-py27,unit-py35 |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
import unittest | ||
|
||
import dbt.flags as flags | ||
|
||
from dbt.adapters.postgres import PostgresAdapter | ||
from dbt.exceptions import ValidationException | ||
from dbt.logger import GLOBAL_LOGGER as logger | ||
|
||
|
||
class TestPostgresAdapter(unittest.TestCase): | ||
|
||
def setUp(self): | ||
flags.STRICT_MODE = True | ||
|
||
self.profile = { | ||
'dbname': 'postgres', | ||
'user': 'root', | ||
'host': 'database', | ||
'pass': 'password', | ||
'port': 5432, | ||
'schema': 'public' | ||
} | ||
|
||
def test_acquire_connection_validations(self): | ||
try: | ||
connection = PostgresAdapter.acquire_connection(self.profile) | ||
self.assertEquals(connection.get('type'), 'postgres') | ||
except ValidationException as e: | ||
self.fail('got ValidationException: {}'.format(str(e))) | ||
except BaseException as e: | ||
self.fail('validation failed with unknown exception: {}' | ||
.format(str(e))) | ||
|
||
def test_acquire_connection(self): | ||
connection = PostgresAdapter.acquire_connection(self.profile) | ||
|
||
self.assertEquals(connection.get('state'), 'open') | ||
self.assertNotEquals(connection.get('handle'), None) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We're probably going to rethink how
DBT_OPERATION
works in the future. I think it's fine to leave this how it is for now, but let's definitely refrain from copying this logic into future adapters!