Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 15 additions & 5 deletions airflow/providers/postgres/hooks/postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import psycopg2.extras
from deprecated import deprecated
from psycopg2.extras import DictCursor, NamedTupleCursor, RealDictCursor
from sqlalchemy.engine import URL

from airflow.exceptions import AirflowProviderDeprecationWarning
from airflow.providers.common.sql.hooks.sql import DbApiHook
Expand Down Expand Up @@ -113,6 +114,18 @@ def schema(self):
def schema(self, value):
self.database = value

@property
def sqlalchemy_url(self) -> URL:
conn = self.get_connection(getattr(self, self.conn_name_attr))
return URL.create(
drivername="postgresql",
username=conn.login,
password=conn.password,
host=conn.host,
port=conn.port,
database=self.database or conn.schema,
)

def _get_cursor(self, raw_cursor: str) -> CursorType:
_cursor = raw_cursor.lower()
cursor_types = {
Expand Down Expand Up @@ -186,12 +199,9 @@ def copy_expert(self, sql: str, filename: str) -> None:
def get_uri(self) -> str:
"""Extract the URI from the connection.

:return: the extracted uri.
:return: the extracted URI in Sqlalchemy URI format.
"""
conn = self.get_connection(getattr(self, self.conn_name_attr))
conn.schema = self.database or conn.schema
uri = conn.get_uri().replace("postgres://", "postgresql://")
return uri
return self.sqlalchemy_url.render_as_string(hide_password=False)

def bulk_load(self, table: str, tmp_file: str) -> None:
"""Load a tab-delimited file into a database table."""
Expand Down
4 changes: 2 additions & 2 deletions tests/providers/postgres/hooks/test_postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,11 +58,11 @@ def test_get_conn(self, mock_connect):

@mock.patch("airflow.providers.postgres.hooks.postgres.psycopg2.connect")
def test_get_uri(self, mock_connect):
self.connection.extra = json.dumps({"client_encoding": "utf-8"})
self.connection.conn_type = "postgres"
self.connection.port = 5432
self.db_hook.get_conn()
assert mock_connect.call_count == 1
assert self.db_hook.get_uri() == "postgresql://login:password@host/database?client_encoding=utf-8"
assert self.db_hook.get_uri() == "postgresql://login:password@host:5432/database"

@mock.patch("airflow.providers.postgres.hooks.postgres.psycopg2.connect")
def test_get_conn_cursor(self, mock_connect):
Expand Down