@@ -93,16 +93,32 @@ class IfExistAction(enum.Enum):
SKIP = "skip"


class _BigQueryHookWithFlexibleProjectId(BigQueryHook):
@property
def project_id(self) -> str:
_, project_id = self.get_credentials_and_project_id()
return project_id or PROVIDE_PROJECT_ID

@project_id.setter
def project_id(self, value: str) -> None:
cached_creds, _ = self.get_credentials_and_project_id()
self._cached_project_id = value or PROVIDE_PROJECT_ID
self._cached_credentials = cached_creds


class _BigQueryDbHookMixin:
def get_db_hook(self: BigQueryCheckOperator) -> BigQueryHook: # type:ignore[misc]
def get_db_hook(self: BigQueryCheckOperator) -> _BigQueryHookWithFlexibleProjectId: # type:ignore[misc]
"""Get BigQuery DB Hook."""
return BigQueryHook(
hook = _BigQueryHookWithFlexibleProjectId(
gcp_conn_id=self.gcp_conn_id,
use_legacy_sql=self.use_legacy_sql,
location=self.location,
impersonation_chain=self.impersonation_chain,
labels=self.labels,
)
if self.project_id:
hook.project_id = self.project_id
return hook
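
For reference, a minimal sketch (not part of the diff) of the behaviour the property/setter pair above gives get_db_hook(): it assumes, as in GoogleBaseHook, that get_credentials_and_project_id() returns the cached (credentials, project_id) pair once both caches are populated, and the override value is hypothetical.

from airflow.providers.google.cloud.operators.bigquery import _BigQueryHookWithFlexibleProjectId

hook = _BigQueryHookWithFlexibleProjectId(gcp_conn_id="google_cloud_default")
default_project = hook.project_id        # resolved from the connection's credentials
hook.project_id = "my-override-project"  # hypothetical id; caches the override alongside the credentials
assert hook.project_id == "my-override-project"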


class _BigQueryOperatorsEncryptionConfigurationMixin:
@@ -190,6 +206,7 @@ class BigQueryCheckOperator(
https://cloud.google.com/bigquery/docs/reference/rest/v2/jobs.
For example, [{ 'name': 'corpus', 'parameterType': { 'type': 'STRING' },
'parameterValue': { 'value': 'romeoandjuliet' } }]. (templated)
:param project_id: Google Cloud project where the job runs; falls back to the project inferred from the connection if not provided.
"""

template_fields: Sequence[str] = (
@@ -208,6 +225,7 @@ def __init__(
*,
sql: str,
gcp_conn_id: str = "google_cloud_default",
project_id: str = PROVIDE_PROJECT_ID,
use_legacy_sql: bool = True,
location: str | None = None,
impersonation_chain: str | Sequence[str] | None = None,
@@ -228,6 +246,7 @@ def __init__(
self.deferrable = deferrable
self.poll_interval = poll_interval
self.query_params = query_params
self.project_id = project_id

def _submit_job(
self,
@@ -243,7 +262,7 @@ def _submit_job(

return hook.insert_job(
configuration=configuration,
project_id=hook.project_id,
project_id=self.project_id,
location=self.location,
job_id=job_id,
nowait=True,
@@ -257,6 +276,8 @@ def execute(self, context: Context):
gcp_conn_id=self.gcp_conn_id,
impersonation_chain=self.impersonation_chain,
)
if self.project_id is None:
self.project_id = hook.project_id
job = self._submit_job(hook, job_id="")
context["ti"].xcom_push(key="job_id", value=job.job_id)
if job.running():
@@ -265,7 +286,7 @@ def execute(self, context: Context):
trigger=BigQueryCheckTrigger(
conn_id=self.gcp_conn_id,
job_id=job.job_id,
project_id=hook.project_id,
project_id=self.project_id,
location=self.location or hook.location,
poll_interval=self.poll_interval,
impersonation_chain=self.impersonation_chain,
@@ -342,6 +363,7 @@ class BigQueryValueCheckOperator(
:param deferrable: Run operator in the deferrable mode.
:param poll_interval: (Deferrable mode only) polling period in seconds to
check for the status of job.
:param project_id: Google Cloud project where the job runs; falls back to the project inferred from the connection if not provided.
"""

template_fields: Sequence[str] = (
@@ -363,6 +385,7 @@ def __init__(
tolerance: Any = None,
encryption_configuration: dict | None = None,
gcp_conn_id: str = "google_cloud_default",
project_id: str = PROVIDE_PROJECT_ID,
use_legacy_sql: bool = True,
location: str | None = None,
impersonation_chain: str | Sequence[str] | None = None,
@@ -380,6 +403,7 @@ def __init__(
self.labels = labels
self.deferrable = deferrable
self.poll_interval = poll_interval
self.project_id = project_id

def _submit_job(
self,
@@ -398,7 +422,7 @@ def _submit_job(

return hook.insert_job(
configuration=configuration,
project_id=hook.project_id,
project_id=self.project_id,
location=self.location,
job_id=job_id,
nowait=True,
@@ -409,7 +433,8 @@ def execute(self, context: Context) -> None:  # type: ignore[override]
super().execute(context=context)
else:
hook = BigQueryHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain)

if self.project_id is None:
self.project_id = hook.project_id
job = self._submit_job(hook, job_id="")
context["ti"].xcom_push(key="job_id", value=job.job_id)
if job.running():
@@ -418,7 +443,7 @@ def execute(self, context: Context) -> None:  # type: ignore[override]
trigger=BigQueryValueCheckTrigger(
conn_id=self.gcp_conn_id,
job_id=job.job_id,
project_id=hook.project_id,
project_id=self.project_id,
location=self.location or hook.location,
sql=self.sql,
pass_value=self.pass_value,
@@ -575,6 +600,9 @@ def execute(self, context: Context):
hook = BigQueryHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain)
self.log.info("Using ratio formula: %s", self.ratio_formula)

if self.project_id is None:
self.project_id = hook.project_id

self.log.info("Executing SQL check: %s", self.sql1)
job_1 = self._submit_job(hook, sql=self.sql1, job_id="")
context["ti"].xcom_push(key="job_id", value=job_1.job_id)
@@ -587,7 +615,7 @@ def execute(self, context: Context):
conn_id=self.gcp_conn_id,
first_job_id=job_1.job_id,
second_job_id=job_2.job_id,
project_id=hook.project_id,
project_id=self.project_id,
table=self.table,
location=self.location or hook.location,
metrics_thresholds=self.metrics_thresholds,
@@ -654,6 +682,7 @@ class BigQueryColumnCheckOperator(
Service Account Token Creator IAM role to the directly preceding identity, with first
account from the list granting this role to the originating account (templated).
:param labels: a dictionary containing labels for the table, passed to BigQuery
:param project_id: Google Cloud project where the job runs; falls back to the project inferred from the connection if not provided.
"""

template_fields: Sequence[str] = tuple(set(SQLColumnCheckOperator.template_fields) | {"gcp_conn_id"})
@@ -670,6 +699,7 @@ def __init__(
accept_none: bool = True,
encryption_configuration: dict | None = None,
gcp_conn_id: str = "google_cloud_default",
project_id: str = PROVIDE_PROJECT_ID,
use_legacy_sql: bool = True,
location: str | None = None,
impersonation_chain: str | Sequence[str] | None = None,
@@ -695,6 +725,7 @@ def __init__(
self.location = location
self.impersonation_chain = impersonation_chain
self.labels = labels
self.project_id = project_id

def _submit_job(
self,
@@ -706,7 +737,7 @@ def _submit_job(
self.include_encryption_configuration(configuration, "query")
return hook.insert_job(
configuration=configuration,
project_id=hook.project_id,
project_id=self.project_id,
location=self.location,
job_id=job_id,
nowait=False,
@@ -715,6 +746,9 @@ def _submit_job(
def execute(self, context=None):
"""Perform checks on the given columns."""
hook = self.get_db_hook()

if self.project_id is None:
self.project_id = hook.project_id
failed_tests = []

job = self._submit_job(hook, job_id="")
@@ -786,6 +820,7 @@ class BigQueryTableCheckOperator(
account from the list granting this role to the originating account (templated).
:param labels: a dictionary containing labels for the table, passed to BigQuery
:param encryption_configuration: (Optional) Custom encryption configuration (e.g., Cloud KMS keys).
:param project_id: Google Cloud project where the job runs; falls back to the project inferred from the connection if not provided.

.. code-block:: python

@@ -805,6 +840,7 @@ def __init__(
checks: dict,
partition_clause: str | None = None,
gcp_conn_id: str = "google_cloud_default",
project_id: str = PROVIDE_PROJECT_ID,
use_legacy_sql: bool = True,
location: str | None = None,
impersonation_chain: str | Sequence[str] | None = None,
@@ -819,6 +855,7 @@ def __init__(
self.impersonation_chain = impersonation_chain
self.labels = labels
self.encryption_configuration = encryption_configuration
self.project_id = project_id

def _submit_job(
self,
@@ -832,7 +869,7 @@ def _submit_job(

return hook.insert_job(
configuration=configuration,
project_id=hook.project_id,
project_id=self.project_id,
location=self.location,
job_id=job_id,
nowait=False,
@@ -841,6 +878,8 @@ def execute(self, context=None):
def execute(self, context=None):
"""Execute the given checks on the table."""
hook = self.get_db_hook()
if self.project_id is None:
self.project_id = hook.project_id
job = self._submit_job(hook, job_id="")
context["ti"].xcom_push(key="job_id", value=job.job_id)
records = job.result().to_dataframe()
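
Taken together, the operator-side changes in this file let a DAG pin the project the check job runs in instead of inheriting it from the connection. A usage sketch follows; the task id, SQL, project id, and location are illustrative values, not from the PR.

from airflow.providers.google.cloud.operators.bigquery import BigQueryCheckOperator

check_rows = BigQueryCheckOperator(
    task_id="check_rows_exist",                      # hypothetical task id
    sql="SELECT COUNT(*) FROM my_dataset.my_table",  # hypothetical query
    use_legacy_sql=False,
    project_id="analytics-jobs-project",             # hypothetical; overrides the connection's project
    location="EU",
)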
@@ -25,6 +25,7 @@
from airflow.models import BaseOperator
from airflow.providers.google.cloud.hooks.bigquery import BigQueryHook
from airflow.providers.google.cloud.links.bigquery import BigQueryTableLink
from airflow.providers.google.common.hooks.base_google import PROVIDE_PROJECT_ID

if TYPE_CHECKING:
from airflow.utils.context import Context
Expand Down Expand Up @@ -73,6 +74,7 @@ class BigQueryToBigQueryOperator(BaseOperator):
If set as a sequence, the identities from the list must grant
Service Account Token Creator IAM role to the directly preceding identity, with first
account from the list granting this role to the originating account (templated).
:param project_id: Google Cloud project where the job runs; falls back to the project inferred from the connection if not provided.
"""

template_fields: Sequence[str] = (
@@ -93,6 +95,7 @@ def __init__(
write_disposition: str = "WRITE_EMPTY",
create_disposition: str = "CREATE_IF_NEEDED",
gcp_conn_id: str = "google_cloud_default",
project_id: str = PROVIDE_PROJECT_ID,
labels: dict | None = None,
encryption_configuration: dict | None = None,
location: str | None = None,
@@ -112,6 +115,7 @@ def __init__(
self.impersonation_chain = impersonation_chain
self.hook: BigQueryHook | None = None
self._job_conf: dict = {}
self.project_id = project_id

def _prepare_job_configuration(self):
self.source_project_dataset_tables = (
@@ -124,7 +128,7 @@ def _prepare_job_configuration(self):
for source_project_dataset_table in self.source_project_dataset_tables:
source_project, source_dataset, source_table = self.hook.split_tablename(
table_input=source_project_dataset_table,
default_project_id=self.hook.project_id,
default_project_id=self.project_id,
var_name="source_project_dataset_table",
)
source_project_dataset_tables_fixup.append(
@@ -133,7 +137,7 @@ def _prepare_job_configuration(self):

destination_project, destination_dataset, destination_table = self.hook.split_tablename(
table_input=self.destination_project_dataset_table,
default_project_id=self.hook.project_id,
default_project_id=self.project_id,
)
configuration = {
"copy": {
@@ -168,12 +172,12 @@ def execute(self, context: Context) -> None:
impersonation_chain=self.impersonation_chain,
)

if not self.hook.project_id:
raise ValueError("The project_id should be set")
if not self.project_id:
self.project_id = self.hook.project_id

configuration = self._prepare_job_configuration()
self._job_conf = self.hook.insert_job(
configuration=configuration, project_id=self.hook.project_id
configuration=configuration, project_id=self.project_id
).to_api_repr()

dest_table_info = self._job_conf["configuration"]["copy"]["destinationTable"]
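
The execute() change above replaces the hard ValueError with a fallback: an explicit project_id wins, otherwise the hook's credential-derived project is used. A usage sketch for the transfer operator; all table and project names are hypothetical.

from airflow.providers.google.cloud.transfers.bigquery_to_bigquery import BigQueryToBigQueryOperator

copy_table = BigQueryToBigQueryOperator(
    task_id="copy_table",                                                   # hypothetical task id
    source_project_dataset_tables="src-project.my_dataset.src_table",       # hypothetical source
    destination_project_dataset_table="dst-project.my_dataset.dst_table",   # hypothetical destination
    project_id="jobs-project",        # hypothetical; the project the copy job is submitted to
    write_disposition="WRITE_TRUNCATE",
)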
@@ -2442,6 +2442,17 @@ def test_bigquery_interval_check_operator_execute_failure(self):
context=None, event={"status": "error", "message": "test failure message"}
)

def test_bigquery_interval_check_operator_project_id(self):
operator = BigQueryIntervalCheckOperator(
task_id="bq_interval_check_operator_project_id",
table="test_table",
metrics_thresholds={"COUNT(*)": 1.5},
location=TEST_DATASET_LOCATION,
project_id=TEST_JOB_PROJECT_ID,
)

assert operator.project_id == TEST_JOB_PROJECT_ID

@pytest.mark.db_test
@mock.patch("airflow.providers.google.cloud.operators.bigquery.BigQueryHook")
def test_bigquery_interval_check_operator_async(self, mock_hook, create_task_instance_of_operator):
@@ -2710,6 +2721,16 @@ def test_bigquery_check_operator_execute_failure(self):
context=None, event={"status": "error", "message": "test failure message"}
)

def test_bigquery_check_operator_project_id(self):
operator = BigQueryCheckOperator(
task_id="bq_check_operator_project_id",
sql="SELECT * FROM any",
location=TEST_DATASET_LOCATION,
project_id=TEST_JOB_PROJECT_ID,
)

assert operator.project_id == TEST_JOB_PROJECT_ID

def test_bigquery_check_op_execute_complete_with_no_records(self):
"""Asserts that exception is raised with correct expected exception message"""

@@ -2869,6 +2890,17 @@ def test_bigquery_value_check_empty(self):
BigQueryValueCheckOperator(deferrable=True, kwargs={})
assert missing_param.value.args[0] in (expected, expected1)

def test_bigquery_value_check_project_id(self):
operator = BigQueryValueCheckOperator(
task_id="check_value",
sql="SELECT COUNT(*) FROM Any",
pass_value=2,
use_legacy_sql=False,
project_id=TEST_JOB_PROJECT_ID,
)

assert operator.project_id == TEST_JOB_PROJECT_ID

def test_bigquery_value_check_operator_execute_complete_success(self):
"""Tests response message in case of success event"""
operator = BigQueryValueCheckOperator(
@@ -2954,7 +2986,7 @@ class TestBigQueryColumnCheckOperator:
("leq_to", 0, -1),
],
)
@mock.patch("airflow.providers.google.cloud.operators.bigquery.BigQueryHook")
@mock.patch("airflow.providers.google.cloud.operators.bigquery._BigQueryHookWithFlexibleProjectId")
@mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryJob")
def test_bigquery_column_check_operator_succeeds(
self, mock_job, mock_hook, check_type, check_value, check_result, create_task_instance_of_operator
@@ -2986,7 +3018,7 @@ def test_bigquery_column_check_operator_succeeds(
("leq_to", 0, 1),
],
)
@mock.patch("airflow.providers.google.cloud.operators.bigquery.BigQueryHook")
@mock.patch("airflow.providers.google.cloud.operators.bigquery._BigQueryHookWithFlexibleProjectId")
@mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryJob")
def test_bigquery_column_check_operator_fails(
self, mock_job, mock_hook, check_type, check_value, check_result, create_task_instance_of_operator
@@ -3017,7 +3049,7 @@ def test_bigquery_column_check_operator_fails(
("less_than", 0, -1),
],
)
@mock.patch("airflow.providers.google.cloud.operators.bigquery.BigQueryHook")
@mock.patch("airflow.providers.google.cloud.operators.bigquery._BigQueryHookWithFlexibleProjectId")
@mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryJob")
def test_encryption_configuration(self, mock_job, mock_hook, check_type, check_value, check_result):
encryption_configuration = {
@@ -3058,7 +3090,7 @@ def test_encryption_configuration(self, mock_job, mock_hook, check_type, check_v


class TestBigQueryTableCheckOperator:
@mock.patch("airflow.providers.google.cloud.operators.bigquery.BigQueryHook")
@mock.patch("airflow.providers.google.cloud.operators.bigquery._BigQueryHookWithFlexibleProjectId")
@mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryJob")
def test_encryption_configuration(self, mock_job, mock_hook):
encryption_configuration = {
Expand Down