diff --git a/AUTHORS b/AUTHORS index 7e076e4de..e29b5813f 100644 --- a/AUTHORS +++ b/AUTHORS @@ -87,6 +87,7 @@ Contributors: * Max Rothman * Daniel Egger * Ignacio Campabadal + * Mikhail Elovskikh (wronglink) Creator: -------- diff --git a/changelog.rst b/changelog.rst index 48ee08733..491e7768a 100644 --- a/changelog.rst +++ b/changelog.rst @@ -11,6 +11,7 @@ Features: * Allows passing the ``-u`` flag to specify a username. (Thanks: `Ignacio Campabadal`_) * Fix for lag in v2 (#979). (Thanks: `Irina Truong`_) +* Support for multihost connection string that is convenient if you have postgres cluster. (Thanks: `Mikhail Elovskikh`_) Internal: --------- @@ -922,3 +923,4 @@ Improvements: .. _`Max Rothman`: https://github.com/maxrothman .. _`Daniel Egger`: https://github.com/DanEEStar .. _`Ignacio Campabadal`: https://github.com/igncampa +.. _`Mikhail Elovskikh`: https://github.com/wronglink \ No newline at end of file diff --git a/pgcli/completion_refresher.py b/pgcli/completion_refresher.py index 2f6908e13..0f33357fd 100644 --- a/pgcli/completion_refresher.py +++ b/pgcli/completion_refresher.py @@ -54,11 +54,7 @@ def _bg_refresh(self, pgexecute, special, callbacks, history=None, executor = pgexecute else: # Create a new pgexecute method to popoulate the completions. - e = pgexecute - executor = PGExecute( - e.dbname, e.user, e.password, e.host, e.port, e.dsn, - **e.extra_args) - + executor = pgexecute.copy() # If callbacks is a single function then push it into a list. if callable(callbacks): callbacks = [callbacks] diff --git a/pgcli/main.py b/pgcli/main.py index 1d941ebef..2e5cffde8 100644 --- a/pgcli/main.py +++ b/pgcli/main.py @@ -386,25 +386,13 @@ def connect_dsn(self, dsn): self.connect(dsn=dsn) def connect_uri(self, uri): - uri = urlparse(uri) - database = uri.path[1:] # ignore the leading fwd slash - - def fixup_possible_percent_encoding(s): - return unquote(str(s)) if s else s - - arguments = dict(database=fixup_possible_percent_encoding(database), - host=fixup_possible_percent_encoding(uri.hostname), - user=fixup_possible_percent_encoding(uri.username), - port=fixup_possible_percent_encoding(uri.port), - passwd=fixup_possible_percent_encoding(uri.password)) - # Deal with extra params e.g. ?sslmode=verify-ca&sslrootcert=/myrootcert - if uri.query: - arguments = dict( - {k: v for k, (v,) in parse_qs(uri.query).items()}, - **arguments) - - # unquote str(each URI part (they may be percent encoded) - self.connect(**arguments) + kwargs = psycopg2.extensions.parse_dsn(uri) + remap = { + 'dbname': 'database', + 'password': 'passwd', + } + kwargs = {remap.get(k, k): v for k, v in kwargs.items()} + self.connect(**kwargs) def connect(self, database='', host='', user='', port='', passwd='', dsn='', **kwargs): @@ -900,10 +888,8 @@ def get_prompt(self, string): string = string.replace('\\dsn_alias', self.dsn_alias or '') string = string.replace('\\t', self.now.strftime('%x %X')) string = string.replace('\\u', self.pgexecute.user or '(none)') - host = self.pgexecute.host or '(none)' - string = string.replace('\\H', host) - short_host, _, _ = host.partition('.') - string = string.replace('\\h', short_host) + string = string.replace('\\H', self.pgexecute.host or '(none)') + string = string.replace('\\h', self.pgexecute.short_host or '(none)') string = string.replace('\\d', self.pgexecute.dbname or '(none)') string = string.replace('\\p', str( self.pgexecute.port) if self.pgexecute.port is not None else '5432') diff --git a/pgcli/pgexecute.py b/pgcli/pgexecute.py index 2e9a441f9..a64fff893 100644 --- a/pgcli/pgexecute.py +++ b/pgcli/pgexecute.py @@ -185,16 +185,21 @@ class PGExecute(object): version_query = "SELECT version();" - def __init__(self, database, user, password, host, port, dsn, **kwargs): - self.dbname = database - self.user = user - self.password = password - self.host = host - self.port = port - self.dsn = dsn - self.extra_args = {k: unicode2utf8(v) for k, v in kwargs.items()} + def __init__(self, database=None, user=None, password=None, host=None, + port=None, dsn=None, **kwargs): + self._conn_params = {} + self.conn = None + self.dbname = None + self.user = None + self.password = None + self.host = None + self.port = None self.server_version = None - self.connect() + self.connect(database, user, password, host, port, dsn, **kwargs) + + def copy(self): + """Returns a clone of the current executor.""" + return self.__class__(**self._conn_params) def get_server_version(self): if self.server_version: @@ -216,51 +221,48 @@ def get_server_version(self): def connect(self, database=None, user=None, password=None, host=None, port=None, dsn=None, **kwargs): - db = (database or self.dbname) - user = (user or self.user) - password = (password or self.password) - host = (host or self.host) - port = (port or self.port) - dsn = (dsn or self.dsn) - kwargs = (kwargs or self.extra_args) - pid = -1 - if dsn: - if password: - dsn = "{0} password={1}".format(dsn, password) - conn = psycopg2.connect(dsn=unicode2utf8(dsn)) - cursor = conn.cursor() - else: - conn = psycopg2.connect( - database=unicode2utf8(db), - user=unicode2utf8(user), - password=unicode2utf8(password), - host=unicode2utf8(host), - port=unicode2utf8(port), - **kwargs) - - cursor = conn.cursor() + conn_params = self._conn_params.copy() + + new_params = { + 'database': database, + 'user': user, + 'password': password, + 'host': host, + 'port': port, + 'dsn': dsn, + } + new_params.update(kwargs) + conn_params.update({ + k: unicode2utf8(v) for k, v in new_params.items() if v is not None + }) + + if 'password' in conn_params and 'dsn' in conn_params: + conn_params['dsn'] = "{0} password={1}".format( + conn_params['dsn'], conn_params.pop('password') + ) + conn = psycopg2.connect(**conn_params) + cursor = conn.cursor() conn.set_client_encoding('utf8') - if hasattr(self, 'conn'): + + self._conn_params = conn_params + if self.conn: self.conn.close() self.conn = conn self.conn.autocommit = True - if dsn: - # When we connect using a DSN, we don't really know what db, - # user, etc. we connected to. Let's read it. - # Note: moved this after setting autocommit because of #664. - dsn_parameters = conn.get_dsn_parameters() - db = dsn_parameters['dbname'] - user = dsn_parameters['user'] - host = dsn_parameters['host'] - port = dsn_parameters['port'] - - self.dbname = db - self.user = user + # When we connect using a DSN, we don't really know what db, + # user, etc. we connected to. Let's read it. + # Note: moved this after setting autocommit because of #664. + # TODO: use actual connection info from psycopg2.extensions.Connection.info as psycopg>2.8 is available and required dependency # noqa + dsn_parameters = conn.get_dsn_parameters() + + self.dbname = dsn_parameters['dbname'] + self.user = dsn_parameters['user'] self.password = password - self.host = host - self.port = port + self.host = dsn_parameters['host'] + self.port = dsn_parameters['port'] + self.extra_args = kwargs if not self.host: self.host = self.get_socket_directory() @@ -276,6 +278,15 @@ def connect(self, database=None, user=None, password=None, host=None, register_json_typecasters(self.conn, self._json_typecaster) register_hstore_typecaster(self.conn) + @property + def short_host(self): + if ',' in self.host: + host, _, _ = self.host.partition(',') + else: + host = self.host + short_host, _, _ = host.partition('.') + return short_host + def _select_one(self, cur, sql): """ Helper method to run a select and retrieve a single field value diff --git a/tests/conftest.py b/tests/conftest.py index 9554623a8..26a408160 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -2,7 +2,7 @@ import os import pytest -from utils import (POSTGRES_HOST, POSTGRES_USER, POSTGRES_PASSWORD, create_db, db_connection, +from utils import (POSTGRES_HOST, POSTGRES_PORT, POSTGRES_USER, POSTGRES_PASSWORD, create_db, db_connection, drop_tables) import pgcli.pgexecute @@ -25,8 +25,8 @@ def cursor(connection): @pytest.fixture def executor(connection): - return pgcli.pgexecute.PGExecute(database='_test_db', user=POSTGRES_USER, - host=POSTGRES_HOST, password=POSTGRES_PASSWORD, port=None, dsn=None) + return pgcli.pgexecute.PGExecute(database='_test_db', user=POSTGRES_USER, host=POSTGRES_HOST, + password=POSTGRES_PASSWORD, port=POSTGRES_PORT, dsn=None) @pytest.fixture diff --git a/tests/test_main.py b/tests/test_main.py index fc4de7d90..b1c403562 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -267,7 +267,6 @@ def test_quoted_db_uri(tmpdir): cli = PGCli(pgclirc_file=str(tmpdir.join("rcfile"))) cli.connect_uri('postgres://bar%5E:%5Dfoo@baz.com/testdb%5B') mock_connect.assert_called_with(database='testdb[', - port=None, host='baz.com', user='bar^', passwd=']foo') @@ -281,7 +280,6 @@ def test_ssl_db_uri(tmpdir): 'sslmode=verify-full&sslcert=m%79.pem&sslkey=my-key.pem&sslrootcert=c%61.pem') mock_connect.assert_called_with(database='testdb[', host='baz.com', - port=None, user='bar^', passwd=']foo', sslmode='verify-full', @@ -299,3 +297,15 @@ def test_port_db_uri(tmpdir): user='bar', passwd='foo', port='2543') + + +def test_multihost_db_uri(tmpdir): + with mock.patch.object(PGCli, 'connect') as mock_connect: + cli = PGCli(pgclirc_file=str(tmpdir.join("rcfile"))) + cli.connect_uri( + 'postgres://bar:foo@baz1.com:2543,baz2.com:2543,baz3.com:2543/testdb') + mock_connect.assert_called_with(database='testdb', + host='baz1.com,baz2.com,baz3.com', + user='bar', + passwd='foo', + port='2543,2543,2543') diff --git a/tests/test_pgexecute.py b/tests/test_pgexecute.py index a23405475..397889502 100644 --- a/tests/test_pgexecute.py +++ b/tests/test_pgexecute.py @@ -33,6 +33,21 @@ def test_conn(executor): SELECT 1""") +@dbtest +def test_copy(executor): + executor_copy = executor.copy() + run(executor_copy, '''create table test(a text)''') + run(executor_copy, '''insert into test values('abc')''') + assert run(executor_copy, '''select * from test''', join=True) == dedent("""\ + +-----+ + | a | + |-----| + | abc | + +-----+ + SELECT 1""") + + + @dbtest def test_bools_are_treated_as_strings(executor): run(executor, '''create table test(a boolean)''') @@ -400,3 +415,13 @@ def test_nonexistent_view_definition(executor): result = executor.view_definition('there_is_no_such_view') with pytest.raises(RuntimeError): result = executor.view_definition('mvw1') + + +@dbtest +def test_short_host(executor): + with patch.object(executor, 'host', 'localhost'): + assert executor.short_host == 'localhost' + with patch.object(executor, 'host', 'localhost.example.org'): + assert executor.short_host == 'localhost' + with patch.object(executor, 'host', 'localhost1.example.org,localhost2.example.org'): + assert executor.short_host == 'localhost1' diff --git a/tests/utils.py b/tests/utils.py index 31fd12a3f..62b066027 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -7,11 +7,16 @@ POSTGRES_USER = getenv('PGUSER', 'postgres') POSTGRES_HOST = getenv('PGHOST', 'localhost') +POSTGRES_PORT = getenv('PGPORT', 5432) POSTGRES_PASSWORD = getenv('PGPASSWORD', '') def db_connection(dbname=None): - conn = psycopg2.connect(user=POSTGRES_USER, host=POSTGRES_HOST, database=dbname) + conn = psycopg2.connect(user=POSTGRES_USER, + host=POSTGRES_HOST, + password=POSTGRES_PASSWORD, + port=POSTGRES_PORT, + database=dbname) conn.autocommit = True return conn diff --git a/tox.ini b/tox.ini index c71af93cd..712db9cbc 100644 --- a/tox.ini +++ b/tox.ini @@ -5,8 +5,10 @@ deps = pytest mock pgspecial humanize + psycopg2 commands = py.test behave tests/features passenv = PGHOST + PGPORT PGUSER PGPASSWORD