Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@

from typing import TYPE_CHECKING

from psycopg2.extensions import register_adapter
from psycopg2.extras import Json

from airflow.providers.google.cloud.hooks.bigquery import BigQueryHook
from airflow.providers.google.cloud.transfers.bigquery_to_sql import BigQueryToSqlBaseOperator
from airflow.providers.google.cloud.utils.bigquery_get_data import bigquery_get_data
Expand Down Expand Up @@ -76,6 +79,8 @@ def __init__(
self.replace_index = replace_index

def get_sql_hook(self) -> PostgresHook:
    """Return a :class:`PostgresHook` for the configured Postgres connection.

    Before building the hook, register psycopg2 adapters so that Python
    ``list`` and ``dict`` cells coming from BigQuery are serialized as
    JSON query arguments (``psycopg2.extras.Json``) instead of failing
    or being adapted incorrectly.
    """
    # Both container types get the same JSON adaptation; order matches
    # the original (list first, then dict).
    for container_type in (list, dict):
        register_adapter(container_type, Json)
    return PostgresHook(database=self.database, postgres_conn_id=self.postgres_conn_id)

def execute(self, context: Context) -> None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from unittest import mock

import pytest
from psycopg2.extras import Json

from airflow.providers.google.cloud.transfers.bigquery_to_postgres import BigQueryToPostgresOperator

Expand Down Expand Up @@ -85,3 +86,16 @@ def test_init_raises_exception_if_replace_is_true_and_missing_params(
selected_fields=selected_fields,
replace_index=replace_index,
)

@mock.patch("airflow.providers.google.cloud.transfers.bigquery_to_postgres.register_adapter")
@mock.patch("airflow.providers.google.cloud.transfers.bigquery_to_postgres.BigQueryHook")
def test_adapters_to_json_registered(self, mock_hook, mock_register_adapter):
    # Verify that executing the operator registers psycopg2 JSON adapters
    # for both list and dict, so container cells are sent to Postgres as
    # JSON query arguments.  BigQueryHook is patched out so execute() does
    # not reach out to BigQuery; register_adapter is patched so the global
    # psycopg2 adapter registry is not mutated by the test.
    # NOTE: mock arguments are injected bottom-up — innermost @mock.patch
    # (BigQueryHook) binds to the first parameter after self.
    BigQueryToPostgresOperator(
        task_id=TASK_ID,
        dataset_table=f"{TEST_DATASET}.{TEST_TABLE_ID}",
        target_table_name=TEST_DESTINATION_TABLE,
        replace=False,
    ).execute(context=mock.MagicMock())

    # assert_any_call: other registrations may occur; we only require these two.
    mock_register_adapter.assert_any_call(list, Json)
    mock_register_adapter.assert_any_call(dict, Json)
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
import psycopg2
import psycopg2.extensions
import psycopg2.extras
from psycopg2.extras import DictCursor, Json, NamedTupleCursor, RealDictCursor
from psycopg2.extras import DictCursor, NamedTupleCursor, RealDictCursor
from sqlalchemy.engine import URL

from airflow.exceptions import (
Expand Down Expand Up @@ -282,20 +282,19 @@ def _serialize_cell(cell: object, conn: connection | None = None) -> Any:
"""
Serialize a cell.

In order to pass a Python object to the database as query argument you can use the
Json (class psycopg2.extras.Json) adapter.
Psycopg2 adapts all arguments to the ``execute()`` method internally,
hence we return the cell without any conversion.

Reading from the database, json and jsonb values will be automatically converted to Python objects.
See https://www.psycopg.org/docs/extensions.html#sql-adaptation-protocol-objects
for more information.

See https://www.psycopg.org/docs/extras.html#json-adaptation for
more information.
To perform custom type adaptation please use register_adapter function
https://www.psycopg.org/docs/extensions.html#psycopg2.extensions.register_adapter.

:param cell: The cell to insert into the table
:param conn: The database connection
:return: The cell
"""
if isinstance(cell, (dict, list)):
cell = Json(cell)
return cell

def get_iam_token(self, conn: Connection) -> tuple[str, str, int]:
Expand Down
19 changes: 0 additions & 19 deletions providers/postgres/tests/unit/postgres/hooks/test_postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
import psycopg2.extras
import pytest
import sqlalchemy
from psycopg2.extras import Json

from airflow.exceptions import AirflowException
from airflow.models import Connection
Expand Down Expand Up @@ -503,24 +502,6 @@ def test_bulk_dump(self, tmp_path):

assert sorted(input_data) == sorted(results)

@pytest.mark.parametrize(
"raw_cell, expected_serialized",
[
("cell content", "cell content"),
(342, 342),
(
{"key1": "value2", "n_key": {"sub_key": "sub_value"}},
{"key1": "value2", "n_key": {"sub_key": "sub_value"}},
),
([1, 2, {"key1": "value2"}, "some data"], [1, 2, {"key1": "value2"}, "some data"]),
],
)
def test_serialize_cell(self, raw_cell, expected_serialized):
if isinstance(raw_cell, Json):
assert expected_serialized == raw_cell.adapted
else:
assert expected_serialized == raw_cell

@pytest.mark.parametrize(
"df_type, expected_type",
[
Expand Down