Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor operator links to not create ad hoc TaskInstances #21285

Merged
1 commit merged on Feb 3, 2022
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions airflow/models/xcom.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,7 @@ def get_one(
@classmethod
def get_one(
cls,
execution_date: pendulum.DateTime,
execution_date: datetime.datetime,
key: Optional[str] = None,
task_id: Optional[str] = None,
dag_id: Optional[str] = None,
Expand All @@ -233,7 +233,7 @@ def get_one(
@provide_session
def get_one(
cls,
execution_date: Optional[pendulum.DateTime] = None,
execution_date: Optional[datetime.datetime] = None,
key: Optional[str] = None,
task_id: Optional[Union[str, Iterable[str]]] = None,
dag_id: Optional[Union[str, Iterable[str]]] = None,
Expand Down Expand Up @@ -314,7 +314,7 @@ def get_many(
@classmethod
def get_many(
cls,
execution_date: pendulum.DateTime,
execution_date: datetime.datetime,
key: Optional[str] = None,
task_ids: Union[str, Iterable[str], None] = None,
dag_ids: Union[str, Iterable[str], None] = None,
Expand All @@ -328,7 +328,7 @@ def get_many(
@provide_session
def get_many(
cls,
execution_date: Optional[pendulum.DateTime] = None,
execution_date: Optional[datetime.datetime] = None,
key: Optional[str] = None,
task_ids: Optional[Union[str, Iterable[str]]] = None,
dag_ids: Optional[Union[str, Iterable[str]]] = None,
Expand Down
7 changes: 4 additions & 3 deletions airflow/providers/amazon/aws/operators/emr.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from uuid import uuid4

from airflow.exceptions import AirflowException
from airflow.models import BaseOperator, BaseOperatorLink, TaskInstance
from airflow.models import BaseOperator, BaseOperatorLink, XCom
from airflow.providers.amazon.aws.hooks.emr import EmrHook

if TYPE_CHECKING:
Expand Down Expand Up @@ -238,8 +238,9 @@ def get_link(self, operator: BaseOperator, dttm: datetime) -> str:
:param dttm: datetime
:return: url link
"""
ti = TaskInstance(task=operator, execution_date=dttm)
flow_id = ti.xcom_pull(task_ids=operator.task_id)
flow_id = XCom.get_one(
key="return_value", dag_id=operator.dag.dag_id, task_id=operator.task_id, execution_date=dttm
)
return (
f'https://console.aws.amazon.com/elasticmapreduce/home#cluster-details:{flow_id}'
if flow_id
Expand Down
6 changes: 3 additions & 3 deletions airflow/providers/google/cloud/operators/bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@

from airflow.exceptions import AirflowException
from airflow.models import BaseOperator, BaseOperatorLink
from airflow.models.taskinstance import TaskInstance
from airflow.models.xcom import XCom
from airflow.operators.sql import SQLCheckOperator, SQLIntervalCheckOperator, SQLValueCheckOperator
from airflow.providers.google.cloud.hooks.bigquery import BigQueryHook, BigQueryJob
Expand Down Expand Up @@ -84,8 +83,9 @@ def name(self) -> str:
return f'BigQuery Console #{self.index + 1}'

def get_link(self, operator: BaseOperator, dttm: datetime):
ti = TaskInstance(task=operator, execution_date=dttm)
job_ids = ti.xcom_pull(task_ids=operator.task_id, key='job_id')
job_ids = XCom.get_one(
key='job_id', dag_id=operator.dag.dag_id, task_id=operator.task_id, execution_date=dttm
)
if not job_ids:
return None
if len(job_ids) < self.index:
Expand Down
13 changes: 7 additions & 6 deletions airflow/providers/google/cloud/operators/dataproc.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,7 @@
from google.protobuf.field_mask_pb2 import FieldMask

from airflow.exceptions import AirflowException
from airflow.models import BaseOperator, BaseOperatorLink
from airflow.models.taskinstance import TaskInstance
from airflow.models import BaseOperator, BaseOperatorLink, XCom
from airflow.providers.google.cloud.hooks.dataproc import DataprocHook, DataProcJobBuilder
from airflow.providers.google.cloud.hooks.gcs import GCSHook
from airflow.utils import timezone
Expand All @@ -59,8 +58,9 @@ class DataprocJobLink(BaseOperatorLink):
name = "Dataproc Job"

def get_link(self, operator, dttm):
ti = TaskInstance(task=operator, execution_date=dttm)
job_conf = ti.xcom_pull(task_ids=operator.task_id, key="job_conf")
job_conf = XCom.get_one(
key="job_conf", dag_id=operator.dag.dag_id, task_id=operator.task_id, execution_date=dttm
)
return (
DATAPROC_JOB_LOG_LINK.format(
job_id=job_conf["job_id"],
Expand All @@ -78,8 +78,9 @@ class DataprocClusterLink(BaseOperatorLink):
name = "Dataproc Cluster"

def get_link(self, operator, dttm):
ti = TaskInstance(task=operator, execution_date=dttm)
cluster_conf = ti.xcom_pull(task_ids=operator.task_id, key="cluster_conf")
cluster_conf = XCom.get_one(
key="cluster_conf", dag_id=operator.dag.dag_id, task_id=operator.task_id, execution_date=dttm
)
return (
DATAPROC_CLUSTER_LINK.format(
cluster_name=cluster_conf["cluster_name"],
Expand Down
8 changes: 4 additions & 4 deletions airflow/providers/google/cloud/operators/mlengine.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,7 @@
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Union

from airflow.exceptions import AirflowException
from airflow.models import BaseOperator, BaseOperatorLink
from airflow.models.taskinstance import TaskInstance
from airflow.models import BaseOperator, BaseOperatorLink, XCom
from airflow.providers.google.cloud.hooks.mlengine import MLEngineHook

if TYPE_CHECKING:
Expand Down Expand Up @@ -980,8 +979,9 @@ class AIPlatformConsoleLink(BaseOperatorLink):
name = "AI Platform Console"

def get_link(self, operator, dttm):
task_instance = TaskInstance(task=operator, execution_date=dttm)
gcp_metadata_dict = task_instance.xcom_pull(task_ids=operator.task_id, key="gcp_metadata")
gcp_metadata_dict = XCom.get_one(
key="gcp_metadata", dag_id=operator.dag.dag_id, task_id=operator.task_id, execution_date=dttm
)
if not gcp_metadata_dict:
return ''
job_id = gcp_metadata_dict['job_id']
Expand Down
10 changes: 7 additions & 3 deletions airflow/providers/microsoft/azure/operators/data_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from typing import TYPE_CHECKING, Any, Dict, Optional, Sequence

from airflow.hooks.base import BaseHook
from airflow.models import BaseOperator, BaseOperatorLink, TaskInstance
from airflow.models import BaseOperator, BaseOperatorLink, XCom
from airflow.providers.microsoft.azure.hooks.data_factory import (
AzureDataFactoryHook,
AzureDataFactoryPipelineRunException,
Expand All @@ -35,8 +35,12 @@ class AzureDataFactoryPipelineRunLink(BaseOperatorLink):
name = "Monitor Pipeline Run"

def get_link(self, operator, dttm):
ti = TaskInstance(task=operator, execution_date=dttm)
run_id = ti.xcom_pull(task_ids=operator.task_id, key="run_id")
run_id = XCom.get_one(
key="run_id",
dag_id=operator.dag.dag_id,
task_id=operator.task_id,
execution_date=dttm,
)

conn = BaseHook.get_connection(operator.azure_data_factory_conn_id)
subscription_id = conn.extra_dejson["extra__azure_data_factory__subscriptionId"]
Expand Down
8 changes: 4 additions & 4 deletions airflow/providers/qubole/operators/qubole.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,7 @@
from typing import TYPE_CHECKING, Optional, Sequence

from airflow.hooks.base import BaseHook
from airflow.models import BaseOperator, BaseOperatorLink
from airflow.models.taskinstance import TaskInstance
from airflow.models import BaseOperator, BaseOperatorLink, XCom
from airflow.providers.qubole.hooks.qubole import (
COMMAND_ARGS,
HYPHEN_ARGS,
Expand All @@ -48,7 +47,6 @@ def get_link(self, operator: BaseOperator, dttm: datetime) -> str:
:param dttm: datetime
:return: url link
"""
ti = TaskInstance(task=operator, execution_date=dttm)
conn = BaseHook.get_connection(
getattr(operator, "qubole_conn_id", None)
or operator.kwargs['qubole_conn_id'] # type: ignore[attr-defined]
Expand All @@ -57,7 +55,9 @@ def get_link(self, operator: BaseOperator, dttm: datetime) -> str:
host = re.sub(r'api$', 'v2/analyze?command_id=', conn.host)
else:
host = 'https://api.qubole.com/v2/analyze?command_id='
qds_command_id = ti.xcom_pull(task_ids=operator.task_id, key='qbol_cmd_id')
qds_command_id = XCom.get_one(
key='qbol_cmd_id', dag_id=operator.dag.dag_id, task_id=operator.task_id, execution_date=dttm
)
url = host + str(qds_command_id) if qds_command_id else ''
return url

Expand Down