Skip to content

Commit

Permalink
Fix mypy typing for airflow/models and their tests
Browse files Browse the repository at this point in the history
The re-ordering of setting attributes in BaseOperator is because
_something_ about that function (throwing the exceptions?) causes mypy
to think that BaseOperator objects could be missing those attributes
  • Loading branch information
ashb committed Dec 14, 2021
1 parent b20e6d3 commit 38c2737
Show file tree
Hide file tree
Showing 12 changed files with 62 additions and 56 deletions.
4 changes: 2 additions & 2 deletions airflow/api/common/experimental/mark_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
from airflow.operators.subdag import SubDagOperator
from airflow.utils import timezone
from airflow.utils.session import provide_session
from airflow.utils.state import State
from airflow.utils.state import State, TaskInstanceState
from airflow.utils.types import DagRunType


Expand Down Expand Up @@ -68,7 +68,7 @@ def set_state(
downstream: bool = False,
future: bool = False,
past: bool = False,
state: str = State.SUCCESS,
state: TaskInstanceState = TaskInstanceState.SUCCESS,
commit: bool = False,
session=None,
):
Expand Down
62 changes: 30 additions & 32 deletions airflow/models/baseoperator.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,12 +147,12 @@ def apply_defaults(self, *args: Any, **kwargs: Any) -> Any:
if len(args) > 0:
raise AirflowException("Use keyword arguments when initializing operators")
dag_args: Dict[str, Any] = {}
dag_params: Dict[str, Any] = {}
dag_params = ParamsDict()

dag: Optional[DAG] = kwargs.get('dag') or DagContext.get_current_dag()
if dag:
dag_args = copy.copy(dag.default_args) or {}
dag_params = copy.deepcopy(dag.params.dump())
dag_args = copy.copy(dag.default_args) or dag_args
dag_params = copy.deepcopy(dag.params) or dag_params
task_group = TaskGroupContext.get_current_task_group(dag)
if task_group:
dag_args.update(task_group.default_args)
Expand Down Expand Up @@ -567,6 +567,13 @@ def __init__(
self.email = email
self.email_on_retry = email_on_retry
self.email_on_failure = email_on_failure
self.execution_timeout = execution_timeout
self.on_execute_callback = on_execute_callback
self.on_failure_callback = on_failure_callback
self.on_success_callback = on_success_callback
self.on_retry_callback = on_retry_callback
self._pre_execute_hook = pre_execute
self._post_execute_hook = post_execute

self.start_date = start_date
if start_date and not isinstance(start_date, datetime):
Expand All @@ -578,6 +585,25 @@ def __init__(
if end_date:
self.end_date = timezone.convert_to_utc(end_date)

if retries is not None and not isinstance(retries, int):
try:
parsed_retries = int(retries)
except (TypeError, ValueError):
raise AirflowException(f"'retries' type must be int, not {type(retries).__name__}")
self.log.warning("Implicitly converting 'retries' for %s from %r to int", self, retries)
retries = parsed_retries

self.executor_config = executor_config or {}
self.run_as_user = run_as_user
self.retries = retries
self.queue = queue
self.pool = Pool.DEFAULT_POOL_NAME if pool is None else pool
self.pool_slots = pool_slots
if self.pool_slots < 1:
dag_str = f" in dag {dag.dag_id}" if dag else ""
raise ValueError(f"pool slots for {self.task_id}{dag_str} cannot be less than 1")
self.sla = sla

if trigger_rule == "dummy":
warnings.warn(
"dummy Trigger Rule is deprecated. Please use `TriggerRule.ALWAYS`.",
Expand Down Expand Up @@ -607,30 +633,6 @@ def __init__(
if wait_for_downstream:
self.depends_on_past = True

if retries is not None and not isinstance(retries, int):
try:
parsed_retries = int(retries)
except (TypeError, ValueError):
raise AirflowException(f"'retries' type must be int, not {type(retries).__name__}")
self.log.warning("Implicitly converting 'retries' for %s from %r to int", self, retries)
retries = parsed_retries

self.retries = retries
self.queue = queue
self.pool = Pool.DEFAULT_POOL_NAME if pool is None else pool
self.pool_slots = pool_slots
if self.pool_slots < 1:
dag_str = f" in dag {dag.dag_id}" if dag else ""
raise ValueError(f"pool slots for {self.task_id}{dag_str} cannot be less than 1")
self.sla = sla
self.execution_timeout = execution_timeout
self.on_execute_callback = on_execute_callback
self.on_failure_callback = on_failure_callback
self.on_success_callback = on_success_callback
self.on_retry_callback = on_retry_callback
self._pre_execute_hook = pre_execute
self._post_execute_hook = post_execute

if isinstance(retry_delay, timedelta):
self.retry_delay = retry_delay
else:
Expand Down Expand Up @@ -661,7 +663,6 @@ def __init__(
)
self.weight_rule = weight_rule
self.resources: Optional[Resources] = Resources(**resources) if resources else None
self.run_as_user = run_as_user
if task_concurrency and not max_active_tis_per_dag:
# TODO: Remove in Airflow 3.0
warnings.warn(
Expand All @@ -671,7 +672,6 @@ def __init__(
)
max_active_tis_per_dag = task_concurrency
self.max_active_tis_per_dag = max_active_tis_per_dag
self.executor_config = executor_config or {}
self.do_xcom_push = do_xcom_push

self.doc_md = doc_md
Expand Down Expand Up @@ -1043,9 +1043,7 @@ def __setstate__(self, state):
self._log = logging.getLogger("airflow.task.operators")

def render_template_fields(
self,
context: Context,
jinja_env: Optional[jinja2.Environment] = None,
self, context: Context, jinja_env: Optional[jinja2.Environment] = None
) -> None:
"""
Template all attributes listed in template_fields. Note this operation is irreversible.
Expand Down
2 changes: 1 addition & 1 deletion airflow/models/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,7 @@ def get_uri(self) -> str:

if self.extra:
try:
query = urlencode(self.extra_dejson)
query: Optional[str] = urlencode(self.extra_dejson)
except TypeError:
query = None
if query and self.extra_dejson == dict(parse_qsl(query, keep_blank_values=True)):
Expand Down
4 changes: 2 additions & 2 deletions airflow/models/dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -1638,7 +1638,7 @@ def set_task_instance_state(
future: bool = False,
past: bool = False,
commit: bool = True,
session: Session = NEW_SESSION,
session=NEW_SESSION,
) -> List[TaskInstance]:
"""
Set the state of a TaskInstance to the given state, and clear its downstream tasks that are
Expand All @@ -1649,7 +1649,7 @@ def set_task_instance_state(
:param execution_date: execution_date of the TaskInstance
:type execution_date: datetime
:param state: State to set the TaskInstance to
:type state: State
:type state: TaskInstanceState
:param upstream: Include all upstream tasks of the given task_id
:type upstream: bool
:param downstream: Include all downstream tasks of the given task_id
Expand Down
7 changes: 5 additions & 2 deletions airflow/models/param.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,18 @@
# under the License.
import copy
from typing import Any, Dict, ItemsView, MutableMapping, Optional, ValuesView
from typing import TYPE_CHECKING, Any, Dict, ItemsView, MutableMapping, Optional, ValuesView

import jsonschema
from jsonschema import FormatChecker
from jsonschema.exceptions import ValidationError

from airflow.exceptions import AirflowException
from airflow.utils.context import Context
from airflow.utils.types import NOTSET, ArgNotSet

if TYPE_CHECKING:
from airflow.utils.context import Context


class Param:
"""
Expand Down Expand Up @@ -235,7 +238,7 @@ def __init__(self, current_dag, name: str, default: Optional[Any] = None):
self._name = name
self._default = default

def resolve(self, context: Context) -> Any:
def resolve(self, context: "Context") -> Any:
"""Pull DagParam value from DagRun context. This method is run during ``op.execute()``."""
default = self._default
if not self._default:
Expand Down
2 changes: 1 addition & 1 deletion airflow/models/taskinstance.py
Original file line number Diff line number Diff line change
Expand Up @@ -1951,7 +1951,7 @@ def get_prev_ds_nodash() -> Optional[str]:
'yesterday_ds': get_yesterday_ds(),
'yesterday_ds_nodash': get_yesterday_ds_nodash(),
}
return Context(context)
return Context(context) # type: ignore

@provide_session
def get_rendered_template_fields(self, session=NEW_SESSION):
Expand Down
3 changes: 2 additions & 1 deletion airflow/models/xcom.py
Original file line number Diff line number Diff line change
Expand Up @@ -408,8 +408,9 @@ def clear(
execution_date: Optional[pendulum.DateTime] = None,
dag_id: Optional[str] = None,
task_id: Optional[str] = None,
run_id: Optional[str] = None,
session: Session = NEW_SESSION,
*,
run_id: Optional[str] = None,
) -> None:
""":sphinx-autoapi-skip:"""
# Given the historic order of this function (execution_date was first argument) to add a new optional
Expand Down
9 changes: 5 additions & 4 deletions airflow/models/xcom_arg.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,17 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from typing import Any, List, Optional, Sequence, Union
from typing import TYPE_CHECKING, Any, List, Optional, Sequence, Union

from airflow.exceptions import AirflowException
from airflow.models.baseoperator import BaseOperator
from airflow.models.taskmixin import TaskMixin
from airflow.models.xcom import XCOM_RETURN_KEY
from airflow.utils.context import Context
from airflow.utils.edgemodifier import EdgeModifier

if TYPE_CHECKING:
from airflow.utils.context import Context


class XComArg(TaskMixin):
"""
Expand Down Expand Up @@ -129,7 +130,7 @@ def set_downstream(
"""Proxy to underlying operator set_downstream method. Required by TaskMixin."""
self.operator.set_downstream(task_or_task_list, edge_modifier)

def resolve(self, context: Context) -> Any:
def resolve(self, context: "Context") -> Any:
"""
Pull XCom value for the existing arg. This method is run during ``op.execute()``
in respectable context.
Expand Down
9 changes: 5 additions & 4 deletions airflow/ti_deps/deps/trigger_rule_dep.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,16 @@
# under the License.

from collections import Counter

from sqlalchemy.orm import Session
from typing import TYPE_CHECKING

from airflow.ti_deps.deps.base_ti_dep import BaseTIDep
from airflow.utils.session import NEW_SESSION, provide_session
from airflow.utils.state import State
from airflow.utils.trigger_rule import TriggerRule as TR

if TYPE_CHECKING:
from airflow.settings import SASession as Session


class TriggerRuleDep(BaseTIDep):
"""
Expand Down Expand Up @@ -92,8 +94,7 @@ def _evaluate_trigger_rule(
upstream_failed,
done,
flag_upstream_failed,
*,
session: Session = NEW_SESSION,
session: "Session" = NEW_SESSION,
):
"""
Yields a dependency status that indicate whether the given task instance's trigger
Expand Down
9 changes: 5 additions & 4 deletions airflow/utils/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
Any,
Container,
Dict,
ItemsView,
Iterator,
List,
MutableMapping,
Expand Down Expand Up @@ -184,11 +185,11 @@ def __ne__(self, other: Any) -> bool:
def keys(self) -> AbstractSet[str]:
return self._context.keys()

def items(self) -> AbstractSet[Tuple[str, Any]]:
return self._context.items()
def items(self):
return ItemsView(self._context)

def values(self) -> ValuesView[Any]:
return self._context.values()
def values(self):
return ValuesView(self._context)

def copy_only(self, keys: Container[str]) -> "Context":
new = type(self)({k: v for k, v in self._context.items() if k in keys})
Expand Down
3 changes: 2 additions & 1 deletion airflow/utils/context.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
# undefined attribute errors from Mypy. Hopefully there will be a mechanism to
# declare "these are defined, but don't error if others are accessed" someday.

from typing import Any, Optional
from typing import Any, Optional, Union

from pendulum import DateTime

Expand Down Expand Up @@ -58,6 +58,7 @@ class Context(TypedDict, total=False):
ds: str
ds_nodash: str
execution_date: DateTime
exception: Union[Exception, str, None]
inlets: list
logical_date: DateTime
macros: Any
Expand Down
4 changes: 2 additions & 2 deletions tests/models/test_dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@
from airflow.utils import timezone
from airflow.utils.file import list_py_file_paths
from airflow.utils.session import create_session, provide_session
from airflow.utils.state import State
from airflow.utils.state import DagRunState, State
from airflow.utils.timezone import datetime as datetime_tz
from airflow.utils.types import DagRunType
from airflow.utils.weight_rule import WeightRule
Expand Down Expand Up @@ -1553,7 +1553,7 @@ def test_clear_dag(self, ti_state_begin, ti_state_end: Optional[str]):
session = settings.Session() # type: ignore
dagrun_1 = dag.create_dagrun(
run_type=DagRunType.BACKFILL_JOB,
state=State.RUNNING,
state=DagRunState.RUNNING,
start_date=DEFAULT_DATE,
execution_date=DEFAULT_DATE,
)
Expand Down

0 comments on commit 38c2737

Please sign in to comment.