@@ -38,7 +38,7 @@
     ServicesClient,
     UpdateJobRequest,
 )
-from google.longrunning import operations_pb2  # type: ignore[attr-defined]
+from google.longrunning import operations_pb2

 from airflow.exceptions import AirflowException
 from airflow.providers.google.common.consts import CLIENT_INFO
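This same change, dropping `# type: ignore[attr-defined]` from the `google.longrunning` import, recurs throughout the PR. The ignore was presumably needed while mypy could not resolve `operations_pb2`; recent `googleapis-common-protos` releases ship typing information for the module, so the import checks cleanly. A minimal sanity check, assuming a current install of that package:

from google.longrunning import operations_pb2

# "operations/example" is a placeholder operation name; the point is only
# that the module resolves and the message class is importable.
op = operations_pb2.Operation(name="operations/example")
print(op.name)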

@@ -58,7 +58,7 @@
     from google.cloud.storage_transfer_v1.services.storage_transfer_service.pagers import (
         ListTransferJobsAsyncPager,
     )
-    from google.longrunning import operations_pb2  # type: ignore[attr-defined]
+    from google.longrunning import operations_pb2
     from proto import Message

 log = logging.getLogger(__name__)

@@ -30,7 +30,7 @@
 from google.auth.transport import requests as google_requests

 # not sure why but mypy complains on missing `container_v1` but it is clearly there and is importable
-from google.cloud import exceptions  # type: ignore[attr-defined]
+from google.cloud import exceptions
 from google.cloud.container_v1 import ClusterManagerAsyncClient, ClusterManagerClient
 from google.cloud.container_v1.types import Cluster, Operation
 from kubernetes import client
@@ -498,7 +498,7 @@ def __init__(
         )

     @contextlib.asynccontextmanager
-    async def get_conn(self) -> async_client.ApiClient:  # type: ignore[override]
+    async def get_conn(self) -> async_client.ApiClient:
         kube_client = None
         try:
             kube_client = await self._load_config()

@@ -1098,13 +1098,13 @@ def create_auto_ml_text_training_job(
             raise AirflowException("AutoMLTextTrainingJob was not created")

         model = self._job.run(
-            dataset=dataset,  # type: ignore[arg-type]
-            training_fraction_split=training_fraction_split,  # type: ignore[call-arg]
-            validation_fraction_split=validation_fraction_split,  # type: ignore[call-arg]
+            dataset=dataset,
+            training_fraction_split=training_fraction_split,
+            validation_fraction_split=validation_fraction_split,
             test_fraction_split=test_fraction_split,
             training_filter_split=training_filter_split,
             validation_filter_split=validation_filter_split,
-            test_filter_split=test_filter_split,  # type: ignore[call-arg]
+            test_filter_split=test_filter_split,
             model_display_name=model_display_name,
             model_labels=model_labels,
             sync=sync,

@@ -159,7 +159,7 @@ def _transport(self) -> Transport:
         """Object responsible for sending data to Stackdriver."""
         # The Transport object is badly defined (no init) but in the docs client/name as constructor
         # arguments are a requirement for any class that derives from Transport class, hence ignore:
-        return self.transport_type(self._client, self.gcp_log_name)  # type: ignore[call-arg]
+        return self.transport_type(self._client, self.gcp_log_name)

     def _get_labels(self, task_instance=None):
         if task_instance:
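The retained comment above explains the history: the base `Transport` class in google-cloud-logging declared no `__init__`, while concrete transports are documented to accept `client` and `name`, which is what mypy used to report as `call-arg`. A minimal sketch of such a subclass, assuming the google-cloud-logging handler interface:

from google.cloud.logging_v2.handlers.transports import Transport


class PrintingTransport(Transport):
    """Hypothetical transport for illustration only."""

    def __init__(self, client, name):
        # Concrete transports are expected to accept (client, name) even
        # though older versions of the base class declared no such __init__.
        self.client = client
        self.name = name

    def send(self, record, message, **kwargs):
        # Real transports forward records to Cloud Logging; this one prints.
        print(f"[{self.name}] {message}")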

@@ -80,7 +80,7 @@ def get_openlineage_facets_on_complete(self, _):
         from airflow.providers.openlineage.sqlparser import SQLParser

         if not self.job_id:
-            self.log.warning("No BigQuery job_id was found by OpenLineage.")  # type: ignore[attr-defined]
+            self.log.warning("No BigQuery job_id was found by OpenLineage.")
             return OperatorLineage()

         if not self.hook:
@@ -92,34 +92,34 @@ def get_openlineage_facets_on_complete(self, _):
                 impersonation_chain=self.impersonation_chain,
             )

-        self.log.debug("Extracting data from bigquery job: `%s`", self.job_id)  # type: ignore[attr-defined]
+        self.log.debug("Extracting data from bigquery job: `%s`", self.job_id)
         inputs, outputs = [], []
         run_facets: dict[str, RunFacet] = {
             "externalQuery": ExternalQueryRunFacet(externalQueryId=self.job_id, source="bigquery")
         }
         self._client = self.hook.get_client(project_id=self.hook.project_id, location=self.location)
         try:
-            job_properties = self._client.get_job(job_id=self.job_id)._properties  # type: ignore
+            job_properties = self._client.get_job(job_id=self.job_id)._properties

             if get_from_nullable_chain(job_properties, ["status", "state"]) != "DONE":
                 raise ValueError(f"Trying to extract data from running bigquery job: `{self.job_id}`")

             run_facets["bigQueryJob"] = self._get_bigquery_job_run_facet(job_properties)

             if get_from_nullable_chain(job_properties, ["statistics", "numChildJobs"]):
-                self.log.debug("Found SCRIPT job. Extracting lineage from child jobs instead.")  # type: ignore[attr-defined]
+                self.log.debug("Found SCRIPT job. Extracting lineage from child jobs instead.")
                 # SCRIPT job type has no input / output information but spawns child jobs that have one
                 # https://cloud.google.com/bigquery/docs/information-schema-jobs#multi-statement_query_job
                 for child_job_id in self._client.list_jobs(parent_job=self.job_id):
-                    child_job_properties = self._client.get_job(job_id=child_job_id)._properties  # type: ignore
+                    child_job_properties = self._client.get_job(job_id=child_job_id)._properties
                     child_inputs, child_outputs = self._get_inputs_and_outputs(child_job_properties)
                     inputs.extend(child_inputs)
                     outputs.extend(child_outputs)
             else:
                 inputs, outputs = self._get_inputs_and_outputs(job_properties)

         except Exception as e:
-            self.log.warning("Cannot retrieve job details from BigQuery.Client. %s", e, exc_info=True)  # type: ignore[attr-defined]
+            self.log.warning("Cannot retrieve job details from BigQuery.Client. %s", e, exc_info=True)
             exception_msg = traceback.format_exc()
             run_facets.update(
                 {
@@ -173,7 +173,7 @@ def _deduplicate_outputs(self, outputs: Iterable[OutputDataset | None]) -> list[
                 if (
                     single_output.facets
                     and final_outputs[key].facets
-                    and "columnLineage" in single_output.facets  # type: ignore
+                    and "columnLineage" in single_output.facets
                     and "columnLineage" in final_outputs[key].facets  # type: ignore
                 ):
                     single_output.facets["columnLineage"] = merge_column_lineage_facets(
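The SCRIPT branch above encodes a BigQuery quirk: a multi-statement (SCRIPT) job carries no input/output tables of its own, so lineage has to be read from its child jobs. The same pattern as a standalone sketch against the google-cloud-bigquery client, with a placeholder job ID:

from google.cloud import bigquery

client = bigquery.Client()
parent = client.get_job("example-script-job-id")  # placeholder job ID

if parent._properties.get("statistics", {}).get("numChildJobs"):
    # SCRIPT jobs expose their lineage only through child jobs.
    for child in client.list_jobs(parent_job=parent.job_id):
        state = child._properties.get("status", {}).get("state")
        print(child.job_id, state)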

@@ -244,7 +244,7 @@ def __init__(

         self.model_id = model_id
         self.endpoint_id = endpoint_id
-        self.operation_params = operation_params  # type: ignore
+        self.operation_params = operation_params
         self.instances = instances
         self.location = location
         self.project_id = project_id

@@ -34,7 +34,7 @@

 from airflow.configuration import conf
 from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning, AirflowSkipException
-from airflow.providers.common.sql.operators.sql import (  # type: ignore[attr-defined] # for _parse_boolean
+from airflow.providers.common.sql.operators.sql import (  # for _parse_boolean
     SQLCheckOperator,
     SQLColumnCheckOperator,
     SQLIntervalCheckOperator,
@@ -311,9 +311,7 @@ def _validate_records(self, records) -> None:
         if not records:
             raise AirflowException(f"The following query returned zero rows: {self.sql}")
         if not all(records):
-            self._raise_exception(  # type: ignore[attr-defined]
-                f"Test failed.\nQuery:\n{self.sql}\nResults:\n{records!s}"
-            )
+            self._raise_exception(f"Test failed.\nQuery:\n{self.sql}\nResults:\n{records!s}")

     def execute_complete(self, context: Context, event: dict[str, Any]) -> None:
         """
@@ -430,7 +428,7 @@ def _submit_job(
             nowait=True,
         )

-    def execute(self, context: Context) -> None:  # type: ignore[override]
+    def execute(self, context: Context) -> None:
         if not self.deferrable:
             super().execute(context=context)
         else:
@@ -3041,7 +3039,7 @@ def execute(self, context: Any):

         if self.project_id:
             job_id_path = convert_job_id(
-                job_id=self.job_id,  # type: ignore[arg-type]
+                job_id=self.job_id,
                 project_id=self.project_id,
                 location=self.location,
             )
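The `# for _parse_boolean` note kept on the import refers to a private helper in the common.sql provider that the BigQuery check operators reuse; the `[attr-defined]` ignore dates from when the module's stubs did not expose it. A usage sketch, assuming a current apache-airflow-providers-common-sql in which the helper normalizes SQL-style boolean strings:

from airflow.providers.common.sql.operators.sql import _parse_boolean

# Hypothetical standalone use; the check operators apply it to query results.
assert _parse_boolean("true") is True
assert _parse_boolean("FALSE") is False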

@@ -96,8 +96,8 @@ def _inject_aws_credentials(self) -> None:

         aws_hook = AwsBaseHook(self.aws_conn_id, resource_type="s3")
         aws_credentials = aws_hook.get_credentials()
-        aws_access_key_id = aws_credentials.access_key  # type: ignore[attr-defined]
-        aws_secret_access_key = aws_credentials.secret_key  # type: ignore[attr-defined]
+        aws_access_key_id = aws_credentials.access_key
+        aws_secret_access_key = aws_credentials.secret_key
         self.body[TRANSFER_SPEC][AWS_S3_DATA_SOURCE][AWS_ACCESS_KEY] = {
             ACCESS_KEY_ID: aws_access_key_id,
             SECRET_ACCESS_KEY: aws_secret_access_key,

@@ -1834,7 +1834,7 @@ def execute(self, context: Context):
             project_id=project_id,
         )

-        return [DeidentifyTemplate.to_dict(template) for template in templates]  # type: ignore[arg-type]
+        return [DeidentifyTemplate.to_dict(template) for template in templates]


 class CloudDLPListDLPJobsOperator(GoogleCloudBaseOperator):
@@ -1930,7 +1930,7 @@ def execute(self, context: Context):
         )

         # the DlpJob.to_dict does not have the right type defined as possible to pass in constructor
-        return [DlpJob.to_dict(job) for job in jobs]  # type: ignore[arg-type]
+        return [DlpJob.to_dict(job) for job in jobs]


 class CloudDLPListInfoTypesOperator(GoogleCloudBaseOperator):
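The comment retained in the second hunk points at proto-plus: `Message.to_dict` is a classmethod taking the message instance, but its parameter annotation was once too narrow for the concrete classes, hence the former `[arg-type]` ignores. A small sketch, assuming google-cloud-dlp and a placeholder job name:

from google.cloud.dlp_v2.types import DlpJob

# to_dict() is inherited from proto.Message and returns a plain dict.
job = DlpJob(name="projects/example-project/dlpJobs/job-1")  # placeholder
print(DlpJob.to_dict(job)["name"])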

@@ -787,10 +787,10 @@ def execute_deferrable(self):
             trigger=GKEJobTrigger(
                 cluster_url=self.cluster_url,
                 ssl_ca_cert=self.ssl_ca_cert,
-                job_name=self.job.metadata.name,  # type: ignore[union-attr]
-                job_namespace=self.job.metadata.namespace,  # type: ignore[union-attr]
-                pod_name=self.pod.metadata.name,  # type: ignore[union-attr]
-                pod_namespace=self.pod.metadata.namespace,  # type: ignore[union-attr]
+                job_name=self.job.metadata.name,
+                job_namespace=self.job.metadata.namespace,
+                pod_name=self.pod.metadata.name,
+                pod_namespace=self.pod.metadata.namespace,
                 base_container_name=self.base_container_name,
                 gcp_conn_id=self.gcp_conn_id,
                 poll_interval=self.job_poll_interval,
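The `[union-attr]` ignores existed because `V1Job.metadata` and `V1Pod.metadata` are Optional in the Kubernetes client, so attribute access on them fails strict checking unless the value is narrowed first. A sketch of the narrowing that makes such ignores unnecessary, using a hypothetical helper:

from kubernetes.client import V1Job


def job_coordinates(job: V1Job) -> tuple[str, str]:
    # Narrow the Optional metadata before attribute access; mypy then
    # accepts .name and .namespace without a [union-attr] ignore.
    if job.metadata is None or job.metadata.name is None or job.metadata.namespace is None:
        raise ValueError("Job has no metadata yet")
    return job.metadata.name, job.metadata.namespace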

@@ -693,7 +693,7 @@ def execute(self, context: Context):
             location=self.location,
             product_id=self.product_id,
             project_id=self.project_id,
-            update_mask=self.update_mask,  # type: ignore
+            update_mask=self.update_mask,
             retry=self.retry,
             timeout=self.timeout,
             metadata=self.metadata,

@@ -46,15 +46,15 @@ class OracleToGCSOperator(BaseSQLToGCSOperator):
     ui_color = "#a0e08c"

     type_map = {
-        oracledb.DB_TYPE_BINARY_DOUBLE: "DECIMAL",  # type: ignore
-        oracledb.DB_TYPE_BINARY_FLOAT: "DECIMAL",  # type: ignore
-        oracledb.DB_TYPE_BINARY_INTEGER: "INTEGER",  # type: ignore
-        oracledb.DB_TYPE_BOOLEAN: "BOOLEAN",  # type: ignore
-        oracledb.DB_TYPE_DATE: "TIMESTAMP",  # type: ignore
-        oracledb.DB_TYPE_NUMBER: "NUMERIC",  # type: ignore
-        oracledb.DB_TYPE_TIMESTAMP: "TIMESTAMP",  # type: ignore
-        oracledb.DB_TYPE_TIMESTAMP_LTZ: "TIMESTAMP",  # type: ignore
-        oracledb.DB_TYPE_TIMESTAMP_TZ: "TIMESTAMP",  # type: ignore
+        oracledb.DB_TYPE_BINARY_DOUBLE: "DECIMAL",
+        oracledb.DB_TYPE_BINARY_FLOAT: "DECIMAL",
+        oracledb.DB_TYPE_BINARY_INTEGER: "INTEGER",
+        oracledb.DB_TYPE_BOOLEAN: "BOOLEAN",
+        oracledb.DB_TYPE_DATE: "TIMESTAMP",
+        oracledb.DB_TYPE_NUMBER: "NUMERIC",
+        oracledb.DB_TYPE_TIMESTAMP: "TIMESTAMP",
+        oracledb.DB_TYPE_TIMESTAMP_LTZ: "TIMESTAMP",
+        oracledb.DB_TYPE_TIMESTAMP_TZ: "TIMESTAMP",
     }

     def __init__(self, *, oracle_conn_id="oracle_default", ensure_utc=False, **kwargs):
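`type_map` is how the operator translates the oracledb cursor's column types into BigQuery schema types. A hypothetical lookup mirroring that use — `STRING` is the usual fallback for unmapped types, and the import path assumes the provider's standard layout:

import oracledb

from airflow.providers.google.cloud.transfers.oracle_to_gcs import OracleToGCSOperator


def bigquery_type(db_type: oracledb.DbType) -> str:
    # Fall back to STRING for Oracle types without an explicit mapping.
    return OracleToGCSOperator.type_map.get(db_type, "STRING")


assert bigquery_type(oracledb.DB_TYPE_NUMBER) == "NUMERIC"
assert bigquery_type(oracledb.DB_TYPE_VARCHAR) == "STRING"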

@@ -154,7 +154,7 @@ async def safe_to_cancel(self) -> bool:
         task_state = task_instance.state
         return task_state != TaskInstanceState.DEFERRED

-    async def run(self) -> AsyncIterator[TriggerEvent]:  # type: ignore[override]
+    async def run(self) -> AsyncIterator[TriggerEvent]:
         """Get current job execution status and yields a TriggerEvent."""
         hook = self._get_async_hook()
         try:
@@ -192,9 +192,7 @@ async def run(self) -> AsyncIterator[TriggerEvent]:  # type: ignore[override]
                         self.location,
                         self.job_id,
                     )
-                    await hook.cancel_job(  # type: ignore[union-attr]
-                        job_id=self.job_id, project_id=self.project_id, location=self.location
-                    )
+                    await hook.cancel_job(job_id=self.job_id, project_id=self.project_id, location=self.location)
                 else:
                     self.log.info(
                         "Trigger may have shutdown. Skipping to cancel job because the airflow "
@@ -231,7 +229,7 @@ def serialize(self) -> tuple[str, dict[str, Any]]:
             },
         )

-    async def run(self) -> AsyncIterator[TriggerEvent]:  # type: ignore[override]
+    async def run(self) -> AsyncIterator[TriggerEvent]:
         """Get current job execution status and yields a TriggerEvent."""
         hook = self._get_async_hook()
         try:
@@ -308,7 +306,7 @@ def serialize(self) -> tuple[str, dict[str, Any]]:
             },
         )

-    async def run(self) -> AsyncIterator[TriggerEvent]:  # type: ignore[override]
+    async def run(self) -> AsyncIterator[TriggerEvent]:
         """Get current job execution status and yields a TriggerEvent with response data."""
         hook = self._get_async_hook()
         try:
@@ -433,7 +431,7 @@ def serialize(self) -> tuple[str, dict[str, Any]]:
             },
         )

-    async def run(self) -> AsyncIterator[TriggerEvent]:  # type: ignore[override]
+    async def run(self) -> AsyncIterator[TriggerEvent]:
         """Get current job execution status and yields a TriggerEvent."""
         hook = self._get_async_hook()
         try:
@@ -581,7 +579,7 @@ def serialize(self) -> tuple[str, dict[str, Any]]:
             },
         )

-    async def run(self) -> AsyncIterator[TriggerEvent]:  # type: ignore[override]
+    async def run(self) -> AsyncIterator[TriggerEvent]:
         """Get current job execution status and yields a TriggerEvent."""
         hook = self._get_async_hook()
         try:
@@ -667,7 +665,7 @@ def _get_async_hook(self) -> BigQueryTableAsyncHook:
             gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain
         )

-    async def run(self) -> AsyncIterator[TriggerEvent]:  # type: ignore[override]
+    async def run(self) -> AsyncIterator[TriggerEvent]:
         """Will run until the table exists in the Google Big Query."""
         try:
             while True:
@@ -750,7 +748,7 @@ def serialize(self) -> tuple[str, dict[str, Any]]:
             },
         )

-    async def run(self) -> AsyncIterator[TriggerEvent]:  # type: ignore[override]
+    async def run(self) -> AsyncIterator[TriggerEvent]:
         """Will run until the table exists in the Google Big Query."""
         hook = BigQueryAsyncHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain)
         job_id = None
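Every trigger hunk in this PR drops the same `# type: ignore[override]` from `run`: current Airflow declares `BaseTrigger.run` as `async def run(self) -> AsyncIterator[TriggerEvent]`, so a subclass with the matching signature no longer trips mypy. A minimal trigger sketch with that signature and a hypothetical payload:

from collections.abc import AsyncIterator
from typing import Any

from airflow.triggers.base import BaseTrigger, TriggerEvent


class ExampleTrigger(BaseTrigger):
    """Hypothetical trigger for illustration only."""

    def serialize(self) -> tuple[str, dict[str, Any]]:
        return ("example_module.ExampleTrigger", {})

    async def run(self) -> AsyncIterator[TriggerEvent]:
        # Matches the base class signature, so no [override] ignore is needed.
        yield TriggerEvent({"status": "success"})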

@@ -76,7 +76,7 @@ def serialize(self) -> tuple[str, dict[str, Any]]:
             },
         )

-    async def run(self) -> AsyncIterator[TriggerEvent]:  # type: ignore[override]
+    async def run(self) -> AsyncIterator[TriggerEvent]:
         """Get current build execution status and yields a TriggerEvent."""
         hook = self._get_async_hook()
         try:

@@ -26,7 +26,7 @@
 from airflow.triggers.base import BaseTrigger, TriggerEvent

 if TYPE_CHECKING:
-    from google.longrunning import operations_pb2  # type: ignore[attr-defined]
+    from google.longrunning import operations_pb2

 DEFAULT_BATCH_LOCATION = "us-central1"


@@ -68,7 +68,7 @@ def serialize(self) -> tuple[str, dict[str, Any]]:
             },
         )

-    async def run(self) -> AsyncIterator[TriggerEvent]:  # type: ignore[override]
+    async def run(self) -> AsyncIterator[TriggerEvent]:
         """Get current data storage transfer jobs and yields a TriggerEvent."""
         async_hook: CloudDataTransferServiceAsyncHook = self.get_async_hook()


@@ -86,7 +86,7 @@ def serialize(self) -> tuple[str, dict[str, Any]]:
             },
         )

-    async def run(self) -> AsyncIterator[TriggerEvent]:  # type: ignore[override]
+    async def run(self) -> AsyncIterator[TriggerEvent]:
         """Get current pipeline status and yields a TriggerEvent."""
         hook = self._get_async_hook()
         try:

@@ -484,9 +484,9 @@ async def run(self) -> AsyncIterator[TriggerEvent]:
         try:
             while self.end_time > time.time():
                 cluster = await self.get_async_hook().get_cluster(
-                    region=self.region,  # type: ignore[arg-type]
+                    region=self.region,
                     cluster_name=self.cluster_name,
-                    project_id=self.project_id,  # type: ignore[arg-type]
+                    project_id=self.project_id,
                     metadata=self.metadata,
                 )
                 self.log.info(

@@ -153,7 +153,7 @@ def serialize(self) -> tuple[str, dict[str, Any]]:
         )

     @cached_property
-    def hook(self) -> GKEKubernetesAsyncHook:  # type: ignore[override]
+    def hook(self) -> GKEKubernetesAsyncHook:
         return GKEKubernetesAsyncHook(
             cluster_url=self._cluster_url,
             ssl_ca_cert=self._ssl_ca_cert,
@@ -200,7 +200,7 @@ def serialize(self) -> tuple[str, dict[str, Any]]:
             },
         )

-    async def run(self) -> AsyncIterator[TriggerEvent]:  # type: ignore[override]
+    async def run(self) -> AsyncIterator[TriggerEvent]:
         """Get operation status and yields corresponding event."""
         hook = self._get_hook()
         try:
@@ -303,7 +303,7 @@ def serialize(self) -> tuple[str, dict[str, Any]]:
             },
         )

-    async def run(self) -> AsyncIterator[TriggerEvent]:  # type: ignore[override]
+    async def run(self) -> AsyncIterator[TriggerEvent]:
         """Get current job status and yield a TriggerEvent."""
         if self.get_logs or self.do_xcom_push:
             pod = await self.hook.get_pod(name=self.pod_name, namespace=self.pod_namespace)