Skip to content

Commit 68558f3

Browse files
authored
Fix mypy errors in models (#58728)
1 parent 8545d3c commit 68558f3

File tree

6 files changed

+10
-10
lines changed

6 files changed

+10
-10
lines changed

airflow-core/src/airflow/api_fastapi/execution_api/datamodels/taskinstance.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,7 @@ class TIDeferredStatePayload(StrictBaseModel):
140140
trigger_timeout: timedelta | None = None
141141
next_method: str
142142
"""The name of the method on the operator to call in the worker after the trigger has fired."""
143-
next_kwargs: Annotated[dict[str, Any] | str, Field(default_factory=dict)]
143+
next_kwargs: Annotated[dict[str, Any], Field(default_factory=dict)]
144144
"""
145145
Kwargs to pass to the above method, either a plain dict or an encrypted string.
146146

airflow-core/src/airflow/models/connection.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -363,8 +363,9 @@ def password(cls):
363363
"""Password. The value is decrypted/encrypted when reading/setting the value."""
364364
return synonym("_password", descriptor=property(cls.get_password, cls.set_password))
365365

366-
def get_extra(self) -> str:
366+
def get_extra(self) -> str | None:
367367
"""Return encrypted extra-data."""
368+
extra_val: str | None
368369
if self._extra and self.is_extra_encrypted:
369370
fernet = get_fernet()
370371
if not fernet.is_encrypted:

airflow-core/src/airflow/models/dag.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,8 @@ def get_run_data_interval(timetable: Timetable, run: DagRun) -> DataInterval:
141141

142142
# Compatibility: runs created before AIP-39 implementation don't have an
143143
# explicit data interval. Try to infer from the logical date.
144+
if TYPE_CHECKING:
145+
assert run.logical_date is not None
144146
return infer_automated_data_interval(timetable, run.logical_date)
145147

146148

@@ -521,14 +523,13 @@ def get_paused_dag_ids(dag_ids: list[str], session: Session = NEW_SESSION) -> se
521523
:param session: ORM Session
522524
:return: Paused Dag_ids
523525
"""
524-
paused_dag_ids = session.execute(
526+
paused_dag_ids = session.scalars(
525527
select(DagModel.dag_id)
526528
.where(DagModel.is_paused == expression.true())
527529
.where(DagModel.dag_id.in_(dag_ids))
528530
)
529531

530-
paused_dag_ids = {paused_dag_id for (paused_dag_id,) in paused_dag_ids}
531-
return paused_dag_ids
532+
return set(paused_dag_ids)
532533

533534
@property
534535
def safe_dag_id(self):

airflow-core/src/airflow/models/serialized_dag.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -335,7 +335,7 @@ def __init__(self, dag: LazyDeserializedDAG) -> None:
335335

336336
# serve as cache so no need to decompress and load, when accessing data field
337337
# when COMPRESS_SERIALIZED_DAGS is True
338-
self.__data_cache = dag_data
338+
self.__data_cache: dict[Any, Any] | None = dag_data
339339

340340
def __repr__(self) -> str:
341341
return f"<SerializedDag: {self.dag_id}>"

airflow-core/src/airflow/models/taskinstance.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -438,9 +438,7 @@ class TaskInstance(Base, LoggingMixin):
438438
# The method to call next, and any extra arguments to pass to it.
439439
# Usually used when resuming from DEFERRED.
440440
next_method: Mapped[str | None] = mapped_column(String(1000), nullable=True)
441-
next_kwargs: Mapped[dict | str | None] = mapped_column(
442-
MutableDict.as_mutable(ExtendedJSON), nullable=True
443-
)
441+
next_kwargs: Mapped[dict | None] = mapped_column(MutableDict.as_mutable(ExtendedJSON), nullable=True)
444442

445443
_task_display_property_value: Mapped[str | None] = mapped_column(
446444
"task_display_name", String(2000), nullable=True

task-sdk/src/airflow/sdk/api/datamodels/_generated.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -184,7 +184,7 @@ class TIDeferredStatePayload(BaseModel):
184184
trigger_kwargs: Annotated[dict[str, Any] | str | None, Field(title="Trigger Kwargs")] = None
185185
trigger_timeout: Annotated[timedelta | None, Field(title="Trigger Timeout")] = None
186186
next_method: Annotated[str, Field(title="Next Method")]
187-
next_kwargs: Annotated[dict[str, Any] | str | None, Field(title="Next Kwargs")] = None
187+
next_kwargs: Annotated[dict[str, Any] | None, Field(title="Next Kwargs")] = None
188188
rendered_map_index: Annotated[str | None, Field(title="Rendered Map Index")] = None
189189

190190

0 commit comments

Comments
 (0)