Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/first class pg #174

Merged
merged 1 commit into from
Oct 19, 2016
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
make new PostgresTarget, new target type
  • Loading branch information
Drew Banin committed Oct 1, 2016
commit 851dcfde341b75621b69ced85580fbee138b7953
11 changes: 9 additions & 2 deletions dbt/compilation.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from dbt.source import Source
from dbt.utils import find_model_by_fqn, find_model_by_name, dependency_projects, split_path, This, Var, compiler_error
from dbt.linker import Linker
import dbt.targets
import time
import sqlparse

Expand All @@ -15,6 +16,7 @@ def __init__(self, project, create_template_class):
self.project = project
self.create_template = create_template_class()
self.macro_generator = None
self.target = self.get_target()

def initialize(self):
if not os.path.exists(self.project['target-path']):
Expand All @@ -25,7 +27,7 @@ def initialize(self):

def get_target(self):
target_cfg = self.project.run_environment()
return RedshiftTarget(target_cfg)
return dbt.targets.get_target(target_cfg)

def model_sources(self, this_project, own_project=None):
if own_project is None:
Expand Down Expand Up @@ -154,16 +156,21 @@ def wrapped_do_ref(*args):

def get_context(self, linker, model, models):
context = self.project.context()

# built-ins
context['ref'] = self.__ref(linker, context, model, models)
context['config'] = self.__model_config(model, linker)
context['this'] = This(context['env']['schema'], model.immediate_name, model.name)
context['var'] = Var(model, context=context)

# these get re-interpolated at runtime!
context['run_started_at'] = '{{ run_started_at }}'
context['invocation_id'] = '{{ invocation_id }}'

context['var'] = Var(model, context=context)
# add in context from run target
context.update(self.target.context)

# add in macros (can we cache these somehow?)
for macro_name, macro in self.macro_generator(context):
context[macro_name] = macro

Expand Down
4 changes: 2 additions & 2 deletions dbt/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from dbt.compilation import Compiler
from dbt.linker import Linker
from dbt.templates import BaseCreateTemplate
from dbt.targets import RedshiftTarget
import dbt.targets
from dbt.source import Source
from dbt.utils import find_model_by_fqn, find_model_by_name, dependency_projects
from dbt.compiled_model import make_compiled_model
Expand Down Expand Up @@ -205,7 +205,7 @@ def __init__(self, project, target_path, graph_type):
self.target_path = target_path
self.graph_type = graph_type

self.target = RedshiftTarget(self.project.run_environment())
self.target = dbt.targets.get_target(self.project.run_environment())

if self.target.should_open_tunnel():
print("Opening ssh tunnel to host {}... ".format(self.target.ssh_host), end="")
Expand Down
4 changes: 2 additions & 2 deletions dbt/schema_tester.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import os

from dbt.targets import RedshiftTarget
import dbt.targets

import psycopg2
import logging
Expand Down Expand Up @@ -58,7 +58,7 @@ def __init__(self, project):

def get_target(self):
target_cfg = self.project.run_environment()
return RedshiftTarget(target_cfg)
return dbt.targets.get_target(target_cfg)

def execute_query(self, model, sql):
target = self.get_target()
Expand Down
4 changes: 2 additions & 2 deletions dbt/seeder.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,13 @@
import psycopg2

from dbt.source import Source
from dbt.targets import RedshiftTarget
import dbt.targets

class Seeder:
def __init__(self, project):
self.project = project
run_environment = self.project.run_environment()
self.target = RedshiftTarget(run_environment)
self.target = dbt.targets.get_target(run_environment)

def find_csvs(self):
return Source(self.project).get_csvs(self.project['data-paths'])
Expand Down
44 changes: 41 additions & 3 deletions dbt/targets.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,9 @@
BAD_THREADS_ERROR = """Invalid value given for "threads" in active run-target.
Value given was {supplied} but it should be an int between {min_val} and {max_val}"""

class RedshiftTarget:
class BaseSQLTarget:
def __init__(self, cfg):
assert cfg['type'] == 'redshift'
self.target_type = cfg['type']
self.host = cfg['host']
self.user = cfg['user']
self.password = cfg['pass']
Expand Down Expand Up @@ -63,7 +63,7 @@ def should_open_tunnel(self):
return False

# make the user explicitly call this function to enable the ssh tunnel
# we don't want it to be automatically opened any time someone makes a RedshiftTarget()
# we don't want it to be automatically opened any time someone makes a new target
def open_tunnel_if_needed(self):
#self.ssh_tunnel = self.__open_tunnel()
pass
Expand Down Expand Up @@ -105,3 +105,41 @@ def get_handle(self):
def rollback(self):
if self.handle is not None:
self.handle.rollback()

@property
def type(self):
return self.target_type

class RedshiftTarget(BaseSQLTarget):
def __init__(self, cfg):
super(RedshiftTarget, self).__init__(cfg)

@property
def context(self):
return {
"sql_now": "getdate()"
}

class PostgresTarget(BaseSQLTarget):
def __init__(self, cfg):
super(PostgresTarget, self).__init__(cfg)

@property
def context(self):
return {
"sql_now": "clock_timestamp()"
}

target_map = {
'postgres': PostgresTarget,
'redshift': RedshiftTarget
}

def get_target(cfg):
target_type = cfg['type']
if target_type in target_map:
klass = target_map[target_type]
return klass(cfg)
else:
valid_csv = ", ".join(["'{}'".format(t) for t in target_map])
raise RuntimeError("Invalid target type provided: '{}'. Must be one of {}".format(target_type, valid_csv))