Skip to content

Commit

Permalink
Support multihost connection string (dbcli#978)
Browse files Browse the repository at this point in the history
* Switch to psycopg2 parse_dsn instead of urlparse

* Added wronglink to contributors and updated changelog

* Fix test codestyle

* Support for PGPORT customization in tests

* Support for PGPORT customization in tests

* Refactored PGExecute init and moved short_host generation to object property

* Fix test util codestyle

* Fix local tests run

* Store PGExecute initial params in  _conn_params and added PGExecute.copy method

* Fix codestyle

* Added docstring to PGExecute.copy() method
  • Loading branch information
wronglink authored and j-bennet committed Jan 3, 2019
1 parent a765088 commit f614cef
Show file tree
Hide file tree
Showing 10 changed files with 119 additions and 81 deletions.
1 change: 1 addition & 0 deletions AUTHORS
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ Contributors:
* Max Rothman
* Daniel Egger
* Ignacio Campabadal
* Mikhail Elovskikh (wronglink)

Creator:
--------
Expand Down
2 changes: 2 additions & 0 deletions changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ Features:

* Allows passing the ``-u`` flag to specify a username. (Thanks: `Ignacio Campabadal`_)
* Fix for lag in v2 (#979). (Thanks: `Irina Truong`_)
* Support for multihost connection string that is convenient if you have postgres cluster. (Thanks: `Mikhail Elovskikh`_)

Internal:
---------
Expand Down Expand Up @@ -922,3 +923,4 @@ Improvements:
.. _`Max Rothman`: https://github.com/maxrothman
.. _`Daniel Egger`: https://github.com/DanEEStar
.. _`Ignacio Campabadal`: https://github.com/igncampa
.. _`Mikhail Elovskikh`: https://github.com/wronglink
6 changes: 1 addition & 5 deletions pgcli/completion_refresher.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,11 +54,7 @@ def _bg_refresh(self, pgexecute, special, callbacks, history=None,
executor = pgexecute
else:
# Create a new pgexecute method to popoulate the completions.
e = pgexecute
executor = PGExecute(
e.dbname, e.user, e.password, e.host, e.port, e.dsn,
**e.extra_args)

executor = pgexecute.copy()
# If callbacks is a single function then push it into a list.
if callable(callbacks):
callbacks = [callbacks]
Expand Down
32 changes: 9 additions & 23 deletions pgcli/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -386,25 +386,13 @@ def connect_dsn(self, dsn):
self.connect(dsn=dsn)

def connect_uri(self, uri):
uri = urlparse(uri)
database = uri.path[1:] # ignore the leading fwd slash

def fixup_possible_percent_encoding(s):
return unquote(str(s)) if s else s

arguments = dict(database=fixup_possible_percent_encoding(database),
host=fixup_possible_percent_encoding(uri.hostname),
user=fixup_possible_percent_encoding(uri.username),
port=fixup_possible_percent_encoding(uri.port),
passwd=fixup_possible_percent_encoding(uri.password))
# Deal with extra params e.g. ?sslmode=verify-ca&sslrootcert=/myrootcert
if uri.query:
arguments = dict(
{k: v for k, (v,) in parse_qs(uri.query).items()},
**arguments)

# unquote str(each URI part (they may be percent encoded)
self.connect(**arguments)
kwargs = psycopg2.extensions.parse_dsn(uri)
remap = {
'dbname': 'database',
'password': 'passwd',
}
kwargs = {remap.get(k, k): v for k, v in kwargs.items()}
self.connect(**kwargs)

def connect(self, database='', host='', user='', port='', passwd='',
dsn='', **kwargs):
Expand Down Expand Up @@ -900,10 +888,8 @@ def get_prompt(self, string):
string = string.replace('\\dsn_alias', self.dsn_alias or '')
string = string.replace('\\t', self.now.strftime('%x %X'))
string = string.replace('\\u', self.pgexecute.user or '(none)')
host = self.pgexecute.host or '(none)'
string = string.replace('\\H', host)
short_host, _, _ = host.partition('.')
string = string.replace('\\h', short_host)
string = string.replace('\\H', self.pgexecute.host or '(none)')
string = string.replace('\\h', self.pgexecute.short_host or '(none)')
string = string.replace('\\d', self.pgexecute.dbname or '(none)')
string = string.replace('\\p', str(
self.pgexecute.port) if self.pgexecute.port is not None else '5432')
Expand Down
105 changes: 58 additions & 47 deletions pgcli/pgexecute.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,16 +185,21 @@ class PGExecute(object):

version_query = "SELECT version();"

def __init__(self, database, user, password, host, port, dsn, **kwargs):
self.dbname = database
self.user = user
self.password = password
self.host = host
self.port = port
self.dsn = dsn
self.extra_args = {k: unicode2utf8(v) for k, v in kwargs.items()}
def __init__(self, database=None, user=None, password=None, host=None,
port=None, dsn=None, **kwargs):
self._conn_params = {}
self.conn = None
self.dbname = None
self.user = None
self.password = None
self.host = None
self.port = None
self.server_version = None
self.connect()
self.connect(database, user, password, host, port, dsn, **kwargs)

def copy(self):
"""Returns a clone of the current executor."""
return self.__class__(**self._conn_params)

def get_server_version(self):
if self.server_version:
Expand All @@ -216,51 +221,48 @@ def get_server_version(self):
def connect(self, database=None, user=None, password=None, host=None,
port=None, dsn=None, **kwargs):

db = (database or self.dbname)
user = (user or self.user)
password = (password or self.password)
host = (host or self.host)
port = (port or self.port)
dsn = (dsn or self.dsn)
kwargs = (kwargs or self.extra_args)
pid = -1
if dsn:
if password:
dsn = "{0} password={1}".format(dsn, password)
conn = psycopg2.connect(dsn=unicode2utf8(dsn))
cursor = conn.cursor()
else:
conn = psycopg2.connect(
database=unicode2utf8(db),
user=unicode2utf8(user),
password=unicode2utf8(password),
host=unicode2utf8(host),
port=unicode2utf8(port),
**kwargs)

cursor = conn.cursor()
conn_params = self._conn_params.copy()

new_params = {
'database': database,
'user': user,
'password': password,
'host': host,
'port': port,
'dsn': dsn,
}
new_params.update(kwargs)
conn_params.update({
k: unicode2utf8(v) for k, v in new_params.items() if v is not None
})

if 'password' in conn_params and 'dsn' in conn_params:
conn_params['dsn'] = "{0} password={1}".format(
conn_params['dsn'], conn_params.pop('password')
)

conn = psycopg2.connect(**conn_params)
cursor = conn.cursor()
conn.set_client_encoding('utf8')
if hasattr(self, 'conn'):

self._conn_params = conn_params
if self.conn:
self.conn.close()
self.conn = conn
self.conn.autocommit = True

if dsn:
# When we connect using a DSN, we don't really know what db,
# user, etc. we connected to. Let's read it.
# Note: moved this after setting autocommit because of #664.
dsn_parameters = conn.get_dsn_parameters()
db = dsn_parameters['dbname']
user = dsn_parameters['user']
host = dsn_parameters['host']
port = dsn_parameters['port']

self.dbname = db
self.user = user
# When we connect using a DSN, we don't really know what db,
# user, etc. we connected to. Let's read it.
# Note: moved this after setting autocommit because of #664.
# TODO: use actual connection info from psycopg2.extensions.Connection.info as psycopg>2.8 is available and required dependency # noqa
dsn_parameters = conn.get_dsn_parameters()

self.dbname = dsn_parameters['dbname']
self.user = dsn_parameters['user']
self.password = password
self.host = host
self.port = port
self.host = dsn_parameters['host']
self.port = dsn_parameters['port']
self.extra_args = kwargs

if not self.host:
self.host = self.get_socket_directory()
Expand All @@ -276,6 +278,15 @@ def connect(self, database=None, user=None, password=None, host=None,
register_json_typecasters(self.conn, self._json_typecaster)
register_hstore_typecaster(self.conn)

@property
def short_host(self):
if ',' in self.host:
host, _, _ = self.host.partition(',')
else:
host = self.host
short_host, _, _ = host.partition('.')
return short_host

def _select_one(self, cur, sql):
"""
Helper method to run a select and retrieve a single field value
Expand Down
6 changes: 3 additions & 3 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import os
import pytest
from utils import (POSTGRES_HOST, POSTGRES_USER, POSTGRES_PASSWORD, create_db, db_connection,
from utils import (POSTGRES_HOST, POSTGRES_PORT, POSTGRES_USER, POSTGRES_PASSWORD, create_db, db_connection,
drop_tables)
import pgcli.pgexecute

Expand All @@ -25,8 +25,8 @@ def cursor(connection):

@pytest.fixture
def executor(connection):
return pgcli.pgexecute.PGExecute(database='_test_db', user=POSTGRES_USER,
host=POSTGRES_HOST, password=POSTGRES_PASSWORD, port=None, dsn=None)
return pgcli.pgexecute.PGExecute(database='_test_db', user=POSTGRES_USER, host=POSTGRES_HOST,
password=POSTGRES_PASSWORD, port=POSTGRES_PORT, dsn=None)


@pytest.fixture
Expand Down
14 changes: 12 additions & 2 deletions tests/test_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,7 +267,6 @@ def test_quoted_db_uri(tmpdir):
cli = PGCli(pgclirc_file=str(tmpdir.join("rcfile")))
cli.connect_uri('postgres://bar%5E:%5Dfoo@baz.com/testdb%5B')
mock_connect.assert_called_with(database='testdb[',
port=None,
host='baz.com',
user='bar^',
passwd=']foo')
Expand All @@ -281,7 +280,6 @@ def test_ssl_db_uri(tmpdir):
'sslmode=verify-full&sslcert=m%79.pem&sslkey=my-key.pem&sslrootcert=c%61.pem')
mock_connect.assert_called_with(database='testdb[',
host='baz.com',
port=None,
user='bar^',
passwd=']foo',
sslmode='verify-full',
Expand All @@ -299,3 +297,15 @@ def test_port_db_uri(tmpdir):
user='bar',
passwd='foo',
port='2543')


def test_multihost_db_uri(tmpdir):
with mock.patch.object(PGCli, 'connect') as mock_connect:
cli = PGCli(pgclirc_file=str(tmpdir.join("rcfile")))
cli.connect_uri(
'postgres://bar:foo@baz1.com:2543,baz2.com:2543,baz3.com:2543/testdb')
mock_connect.assert_called_with(database='testdb',
host='baz1.com,baz2.com,baz3.com',
user='bar',
passwd='foo',
port='2543,2543,2543')
25 changes: 25 additions & 0 deletions tests/test_pgexecute.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,21 @@ def test_conn(executor):
SELECT 1""")


@dbtest
def test_copy(executor):
executor_copy = executor.copy()
run(executor_copy, '''create table test(a text)''')
run(executor_copy, '''insert into test values('abc')''')
assert run(executor_copy, '''select * from test''', join=True) == dedent("""\
+-----+
| a |
|-----|
| abc |
+-----+
SELECT 1""")



@dbtest
def test_bools_are_treated_as_strings(executor):
run(executor, '''create table test(a boolean)''')
Expand Down Expand Up @@ -400,3 +415,13 @@ def test_nonexistent_view_definition(executor):
result = executor.view_definition('there_is_no_such_view')
with pytest.raises(RuntimeError):
result = executor.view_definition('mvw1')


@dbtest
def test_short_host(executor):
with patch.object(executor, 'host', 'localhost'):
assert executor.short_host == 'localhost'
with patch.object(executor, 'host', 'localhost.example.org'):
assert executor.short_host == 'localhost'
with patch.object(executor, 'host', 'localhost1.example.org,localhost2.example.org'):
assert executor.short_host == 'localhost1'
7 changes: 6 additions & 1 deletion tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,16 @@

POSTGRES_USER = getenv('PGUSER', 'postgres')
POSTGRES_HOST = getenv('PGHOST', 'localhost')
POSTGRES_PORT = getenv('PGPORT', 5432)
POSTGRES_PASSWORD = getenv('PGPASSWORD', '')


def db_connection(dbname=None):
conn = psycopg2.connect(user=POSTGRES_USER, host=POSTGRES_HOST, database=dbname)
conn = psycopg2.connect(user=POSTGRES_USER,
host=POSTGRES_HOST,
password=POSTGRES_PASSWORD,
port=POSTGRES_PORT,
database=dbname)
conn.autocommit = True
return conn

Expand Down
2 changes: 2 additions & 0 deletions tox.ini
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,10 @@ deps = pytest
mock
pgspecial
humanize
psycopg2
commands = py.test
behave tests/features
passenv = PGHOST
PGPORT
PGUSER
PGPASSWORD

0 comments on commit f614cef

Please sign in to comment.