Skip to content

Commit

Permalink
DagCode.bulk_write_code is no longer used
Browse files Browse the repository at this point in the history
  • Loading branch information
ephraimbuddy committed Oct 20, 2024
1 parent b4c73e0 commit 9dc8ec6
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 96 deletions.
74 changes: 3 additions & 71 deletions airflow/models/dagcode.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,9 @@
from __future__ import annotations

import logging
import os
import struct
import uuid
from datetime import datetime
from typing import TYPE_CHECKING, Collection, Iterable
from typing import TYPE_CHECKING, Collection

from sqlalchemy import BigInteger, Column, ForeignKey, String, Text, delete, select
from sqlalchemy.dialects.mysql import MEDIUMTEXT
Expand All @@ -30,10 +28,10 @@
from sqlalchemy_utils import UUIDType

from airflow.api_internal.internal_api_call import internal_api_call
from airflow.exceptions import AirflowException, DagCodeNotFound
from airflow.exceptions import DagCodeNotFound
from airflow.models.base import Base
from airflow.utils import timezone
from airflow.utils.file import correct_maybe_zipped, open_maybe_zipped
from airflow.utils.file import open_maybe_zipped
from airflow.utils.session import NEW_SESSION, provide_session
from airflow.utils.sqlalchemy import UtcDateTime

Expand Down Expand Up @@ -83,72 +81,6 @@ def write_dag(cls, fileloc: str, session: Session = NEW_SESSION) -> DagCode:
log.debug("DAG file %s written into DagCode table", fileloc)
return dag_code

@provide_session
def sync_to_db(self, session: Session = NEW_SESSION) -> None:
    """
    Persist this object's source code to the database.

    Delegates to ``bulk_sync_to_db`` with a one-element list holding ``self.fileloc``.

    :param session: ORM Session
    """
    single_fileloc = [self.fileloc]
    self.bulk_sync_to_db(single_fileloc, session)

@classmethod
@provide_session
def bulk_sync_to_db(cls, filelocs: Iterable[str], session: Session = NEW_SESSION) -> None:
    """
    Write code in bulk into database.

    Rows whose ``fileloc_hash`` matches one of the requested files are selected
    ``FOR UPDATE``. Files with no matching row are inserted; files that already
    have a row are refreshed only when the file's modification time on disk is
    newer than the row's ``last_updated``.

    :param filelocs: file paths of DAGs to sync
    :param session: ORM Session
    :raises AirflowException: if a selected row has a ``fileloc`` that is not among
        ``filelocs``, i.e. two different paths share the same hash
    """
    filelocs = set(filelocs)
    filelocs_to_hashes = {fileloc: DagCode.dag_fileloc_hash(fileloc) for fileloc in filelocs}
    existing_orm_dag_codes = session.scalars(
        select(DagCode)
        .filter(DagCode.fileloc_hash.in_(filelocs_to_hashes.values()))
        .with_for_update(of=DagCode)
    ).all()

    # Both lookups are empty dicts when no rows matched, so no special-casing is needed.
    existing_orm_dag_codes_map = {orm.fileloc: orm for orm in existing_orm_dag_codes}
    existing_orm_dag_codes_by_fileloc_hashes = {orm.fileloc_hash: orm for orm in existing_orm_dag_codes}

    existing_orm_filelocs = {orm.fileloc for orm in existing_orm_dag_codes_by_fileloc_hashes.values()}
    conflicting_filelocs = existing_orm_filelocs.difference(filelocs)
    if conflicting_filelocs:
        # A row was selected by hash but its path was not requested: a hash collision.
        # Reverse the already-computed mapping instead of hashing every path again.
        hashes_to_filelocs = {
            fileloc_hash: fileloc for fileloc, fileloc_hash in filelocs_to_hashes.items()
        }
        collision_messages = []
        for fileloc in conflicting_filelocs:
            filename = hashes_to_filelocs[DagCode.dag_fileloc_hash(fileloc)]
            collision_messages.append(
                f"Filename '{filename}' causes a hash collision in the "
                f"database with '{fileloc}'. Please rename the file."
            )
        # Separate the sentences so several collisions do not run together.
        raise AirflowException(" ".join(collision_messages))

    existing_filelocs = set(existing_orm_dag_codes_map)
    missing_filelocs = filelocs.difference(existing_filelocs)

    for fileloc in missing_filelocs:
        orm_dag_code = DagCode(fileloc, cls._get_code_from_file(fileloc))
        session.add(orm_dag_code)

    for fileloc in existing_filelocs:
        current_version = existing_orm_dag_codes_by_fileloc_hashes[filelocs_to_hashes[fileloc]]
        file_mod_time = datetime.fromtimestamp(
            os.path.getmtime(correct_maybe_zipped(fileloc)), tz=timezone.utc
        )

        # Only rewrite the stored source when the file on disk is newer than the row.
        if file_mod_time > current_version.last_updated:
            orm_dag_code = existing_orm_dag_codes_map[fileloc]
            orm_dag_code.last_updated = file_mod_time
            orm_dag_code.source_code = cls._get_code_from_file(orm_dag_code.fileloc)
            session.merge(orm_dag_code)

@classmethod
@internal_api_call
@provide_session
Expand Down
26 changes: 1 addition & 25 deletions tests/models/test_dagcode.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,36 +69,12 @@ def _write_example_dags(self):
SDM.write_dag(dag)
return example_dags

def test_sync_to_db(self):
def test_write_to_db(self):
"""Dag code can be written into database."""
example_dags = self._write_example_dags()

self._compare_example_dags(example_dags)

def test_bulk_sync_to_db(self):
    """Dag code for every example DAG file can be bulk written into the database."""
    dags_by_id = make_example_dags(example_dags_module)
    dag_files = []
    for dag in dags_by_id.values():
        dag_files.append(dag.fileloc)

    with create_session() as session:
        DagCode.bulk_sync_to_db(dag_files, session=session)
        session.commit()

    self._compare_example_dags(dags_by_id)

def test_bulk_sync_to_db_half_files(self):
    """Bulk-write the first half of the example DAG files, then all of them, and compare."""
    dags_by_id = make_example_dags(example_dags_module)
    all_files = [dag.fileloc for dag in dags_by_id.values()]
    first_half = all_files[: len(all_files) // 2]

    # Two separate sessions: the second sync sees rows committed by the first.
    for batch in (first_half, all_files):
        with create_session() as session:
            DagCode.bulk_sync_to_db(batch, session=session)
            session.commit()

    self._compare_example_dags(dags_by_id)

@patch.object(DagCode, "dag_fileloc_hash")
def test_detecting_duplicate_key(self, mock_hash):
"""Dag code detects duplicate key."""
Expand Down

0 comments on commit 9dc8ec6

Please sign in to comment.