Skip to content

Commit

Permalink
Define datetime and StringID column types centrally in migrations
Browse files Browse the repository at this point in the history
We have various flavours of the code all over the place in many
migration files -- which leads to duplication and things not being in
sync.

This pulls them into a single central location.
  • Loading branch information
ashb committed Nov 8, 2021
1 parent aa15cba commit 64d3f98
Show file tree
Hide file tree
Showing 24 changed files with 238 additions and 264 deletions.
86 changes: 86 additions & 0 deletions airflow/migrations/db_types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
import sys

import sqlalchemy as sa
from alembic import context
from lazy_object_proxy import Proxy

######################################
# Note about this module:
#
# It loads the specific type dynamically at runtime. For IDE/typing support
# there is an associated db_types.pyi. If you add a new type in here, add a
# simple version in there too.
######################################


def _mssql_use_date_time2():
    """Report whether ``DATETIME2`` should be used on the connected SQL Server.

    Runs a query against the current Alembic bind that maps the server's
    ``productversion`` property to a major-version label: ``'2000'`` for
    versions starting with ``8``, ``'2005'`` for versions starting with ``9``
    and ``'2005Plus'`` for everything else.

    :return: ``False`` when the label is ``'2000'`` or ``'2005'``, otherwise ``True``.
    """
    conn = context.get_bind()
    # Wrap the statement in ``sa.text()``: passing a bare string to
    # ``Connection.execute`` is deprecated in SQLAlchemy 1.4 and rejected in 2.0.
    result = conn.execute(
        sa.text(
            """SELECT CASE WHEN CONVERT(VARCHAR(128), SERVERPROPERTY ('productversion'))
            like '8%' THEN '2000' WHEN CONVERT(VARCHAR(128), SERVERPROPERTY ('productversion'))
            like '9%' THEN '2005' ELSE '2005Plus' END AS MajorVersion"""
        )
    ).fetchone()
    # The query returns a single row with a single column (MajorVersion).
    mssql_version = result[0]
    return mssql_version not in ("2000", "2005")


# Wrapped in a lazy proxy rather than called directly, so the version query in
# `_mssql_use_date_time2` is not issued when this module is imported.
# NOTE(review): presumably evaluated on first use of the value -- confirm
# against the lazy_object_proxy documentation.
MSSQL_USE_DATE_TIME2 = Proxy(_mssql_use_date_time2)


def _mssql_TIMESTAMP():
    """Pick the SQL Server column type used for timestamps.

    Gives ``DATETIME2`` with precision 6 when ``MSSQL_USE_DATE_TIME2`` is
    truthy, and the plain ``DATETIME`` type class otherwise.
    """
    from sqlalchemy.dialects.mssql import DATETIME, DATETIME2

    if not MSSQL_USE_DATE_TIME2:
        return DATETIME
    return DATETIME2(precision=6)


def _mysql_TIMESTAMP():
    """Build the MySQL timestamp type with ``fsp=6`` and ``timezone=True``."""
    from sqlalchemy.dialects.mysql import TIMESTAMP as mysql_timestamp_type

    options = {"fsp": 6, "timezone": True}
    return mysql_timestamp_type(**options)


def _sa_TIMESTAMP():
    """Build the generic SQLAlchemy timestamp type with ``timezone=True``."""
    generic_timestamp = sa.TIMESTAMP(timezone=True)
    return generic_timestamp


def _sa_StringID():
    """Return the ``StringID`` object from ``airflow.models.base``.

    The import happens inside the function, so ``airflow.models.base`` is only
    loaded when this factory is actually called.
    """
    from airflow.models import base as models_base

    return models_base.StringID


def __getattr__(name):
    """Resolve ``TIMESTAMP`` and ``StringID`` on first access (module-level attribute hook).

    The dialect name of the current Alembic bind selects a factory named
    ``_<dialect>_<name>`` from this module's globals; if there is none, the
    generic ``_sa_<name>`` factory is used. The factory's result is stored in
    the module globals under ``name`` and returned.

    :param name: attribute name being looked up on the module.
    :return: the value produced by the selected factory.
    :raises AttributeError: for any name other than ``TIMESTAMP`` or ``StringID``.
    """
    if name not in ("TIMESTAMP", "StringID"):
        raise AttributeError(f"module {__name__} has no attribute {name}")

    dialect_name = context.get_bind().dialect.name
    namespace = globals()

    # Prefer the dialect specific factory, falling back to the generic one.
    factory = namespace.get(f'_{dialect_name}_{name}', None)
    if not factory:
        factory = namespace.get(f'_sa_{name}')

    resolved = factory()
    # Store the result in the module namespace under the requested name.
    namespace[name] = resolved
    return resolved


# Module-level ``__getattr__`` is only honoured natively from Python 3.7
# (PEP 562); on older interpreters apply the ``pep562`` backport to this module.
if sys.version_info < (3, 7):
    from pep562 import Pep562

    Pep562(__name__)
28 changes: 28 additions & 0 deletions airflow/migrations/db_types.pyi
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#

import sqlalchemy as sa

# Static declarations for IDEs and type checkers; each name is declared here
# with a generic SQLAlchemy type (or a plain ``bool``).
TIMESTAMP = sa.TIMESTAMP
"""Database specific timestamp with timezone"""

StringID = sa.String
"""String column type with correct DB collation applied"""

# Declared as a plain bool for typing purposes.
MSSQL_USE_DATE_TIME2: bool
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from alembic import op
from sqlalchemy.engine.reflection import Inspector

from airflow.models.base import COLLATION_ARGS
from airflow.migrations.db_types import StringID

# revision identifiers, used by Alembic.
revision = '03afc6b6f902'
Expand Down Expand Up @@ -63,7 +63,7 @@ def upgrade():
op.alter_column(
table_name='ab_view_menu',
column_name='name',
type_=sa.String(length=250, **COLLATION_ARGS),
type_=StringID(length=250),
nullable=False,
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,8 @@
"""
import sqlalchemy as sa
from alembic import op
from sqlalchemy.dialects import mysql

from airflow.models.base import COLLATION_ARGS
from airflow.migrations.db_types import TIMESTAMP, StringID

# revision identifiers, used by Alembic.
revision = '0a2a5b66e19d'
Expand All @@ -38,43 +37,25 @@
INDEX_NAME = 'idx_' + TABLE_NAME + '_dag_task_date'


# For Microsoft SQL Server, TIMESTAMP is a row-id type,
# having nothing to do with date-time. DateTime() will
# be sufficient.
def mssql_timestamp():
return sa.DateTime()


def mysql_timestamp():
return mysql.TIMESTAMP(fsp=6)


def sa_timestamp():
return sa.TIMESTAMP(timezone=True)


def upgrade():
# See 0e2a74e0fc9f_add_time_zone_awareness
conn = op.get_bind()
if conn.dialect.name == 'mysql':
timestamp = mysql_timestamp
elif conn.dialect.name == 'mssql':
timestamp = mssql_timestamp
else:
timestamp = sa_timestamp
timestamp = TIMESTAMP
if op.get_bind().dialect.name == 'mssql':
# We need to keep this as it was for this old migration on mssql
timestamp = sa.DateTime()

op.create_table(
TABLE_NAME,
sa.Column('id', sa.Integer(), nullable=False),
sa.Column('task_id', sa.String(length=250, **COLLATION_ARGS), nullable=False),
sa.Column('dag_id', sa.String(length=250, **COLLATION_ARGS), nullable=False),
sa.Column('task_id', StringID(), nullable=False),
sa.Column('dag_id', StringID(), nullable=False),
# use explicit server_default=None otherwise mysql implies defaults for first timestamp column
sa.Column('execution_date', timestamp(), nullable=False, server_default=None),
sa.Column('execution_date', timestamp, nullable=False, server_default=None),
sa.Column('try_number', sa.Integer(), nullable=False),
sa.Column('start_date', timestamp(), nullable=False),
sa.Column('end_date', timestamp(), nullable=False),
sa.Column('start_date', timestamp, nullable=False),
sa.Column('end_date', timestamp, nullable=False),
sa.Column('duration', sa.Integer(), nullable=False),
sa.Column('reschedule_date', timestamp(), nullable=False),
sa.Column('reschedule_date', timestamp, nullable=False),
sa.PrimaryKeyConstraint('id'),
sa.ForeignKeyConstraint(
['task_id', 'dag_id', 'execution_date'],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,9 @@
"""

from alembic import op
from sqlalchemy import TIMESTAMP, Column
from sqlalchemy.dialects import mssql, mysql
from sqlalchemy import Column

from airflow.migrations.db_types import TIMESTAMP

# Revision identifiers, used by Alembic.
revision = "142555e44c17"
Expand All @@ -35,36 +36,14 @@
depends_on = None


def _use_date_time2(conn):
result = conn.execute(
"""SELECT CASE WHEN CONVERT(VARCHAR(128), SERVERPROPERTY ('productversion'))
like '8%' THEN '2000' WHEN CONVERT(VARCHAR(128), SERVERPROPERTY ('productversion'))
like '9%' THEN '2005' ELSE '2005Plus' END AS MajorVersion"""
).fetchone()
mssql_version = result[0]
return mssql_version not in ("2000", "2005")


def _get_timestamp(conn):
dialect_name = conn.dialect.name
if dialect_name == "mysql":
return mysql.TIMESTAMP(fsp=6, timezone=True)
if dialect_name != "mssql":
return TIMESTAMP(timezone=True)
if _use_date_time2(conn):
return mssql.DATETIME2(precision=6)
return mssql.DATETIME


def upgrade():
"""Apply data_interval fields to DagModel and DagRun."""
column_type = _get_timestamp(op.get_bind())
with op.batch_alter_table("dag_run") as batch_op:
batch_op.add_column(Column("data_interval_start", column_type))
batch_op.add_column(Column("data_interval_end", column_type))
batch_op.add_column(Column("data_interval_start", TIMESTAMP))
batch_op.add_column(Column("data_interval_end", TIMESTAMP))
with op.batch_alter_table("dag") as batch_op:
batch_op.add_column(Column("next_dagrun_data_interval_start", column_type))
batch_op.add_column(Column("next_dagrun_data_interval_end", column_type))
batch_op.add_column(Column("next_dagrun_data_interval_start", TIMESTAMP))
batch_op.add_column(Column("next_dagrun_data_interval_end", TIMESTAMP))


def downgrade():
Expand Down
6 changes: 3 additions & 3 deletions airflow/migrations/versions/1b38cef5b76e_add_dagrun.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
import sqlalchemy as sa
from alembic import op

from airflow.models.base import COLLATION_ARGS
from airflow.migrations.db_types import StringID

# revision identifiers, used by Alembic.
revision = '1b38cef5b76e'
Expand All @@ -40,10 +40,10 @@ def upgrade():
op.create_table(
'dag_run',
sa.Column('id', sa.Integer(), nullable=False),
sa.Column('dag_id', sa.String(length=250, **COLLATION_ARGS), nullable=True),
sa.Column('dag_id', StringID(), nullable=True),
sa.Column('execution_date', sa.DateTime(), nullable=True),
sa.Column('state', sa.String(length=50), nullable=True),
sa.Column('run_id', sa.String(length=250, **COLLATION_ARGS), nullable=True),
sa.Column('run_id', StringID(), nullable=True),
sa.Column('external_trigger', sa.Boolean(), nullable=True),
sa.PrimaryKeyConstraint('id'),
sa.UniqueConstraint('dag_id', 'execution_date'),
Expand Down
10 changes: 6 additions & 4 deletions airflow/migrations/versions/3c20cacc0044_add_dagrun_run_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
from sqlalchemy.engine.reflection import Inspector
from sqlalchemy.ext.declarative import declarative_base

from airflow.models.base import ID_LEN
from airflow.migrations.db_types import StringID
from airflow.utils import timezone
from airflow.utils.sqlalchemy import UtcDateTime
from airflow.utils.state import State
Expand All @@ -55,12 +55,12 @@ class DagRun(Base): # type: ignore
__tablename__ = "dag_run"

id = Column(Integer, primary_key=True)
dag_id = Column(String(ID_LEN))
dag_id = Column(StringID())
execution_date = Column(UtcDateTime, default=timezone.utcnow)
start_date = Column(UtcDateTime, default=timezone.utcnow)
end_date = Column(UtcDateTime)
_state = Column('state', String(50), default=State.RUNNING)
run_id = Column(String(ID_LEN))
run_id = Column(StringID())
external_trigger = Column(Boolean, default=True)
run_type = Column(String(50), nullable=False)
conf = Column(PickleType)
Expand Down Expand Up @@ -96,7 +96,9 @@ def upgrade():

# Make run_type not nullable
with op.batch_alter_table("dag_run") as batch_op:
batch_op.alter_column("run_type", type_=run_type_col_type, nullable=False)
batch_op.alter_column(
"run_type", existing_type=run_type_col_type, type_=run_type_col_type, nullable=False
)


def downgrade():
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
import sqlalchemy as sa
from alembic import op

from airflow.models.base import COLLATION_ARGS
from airflow.migrations.db_types import StringID

# revision identifiers, used by Alembic.
revision = '64de9cddf6c9'
Expand All @@ -39,8 +39,8 @@ def upgrade():
op.create_table(
'task_fail',
sa.Column('id', sa.Integer(), nullable=False),
sa.Column('task_id', sa.String(length=250, **COLLATION_ARGS), nullable=False),
sa.Column('dag_id', sa.String(length=250, **COLLATION_ARGS), nullable=False),
sa.Column('task_id', StringID(), nullable=False),
sa.Column('dag_id', StringID(), nullable=False),
sa.Column('execution_date', sa.DateTime(), nullable=False),
sa.Column('start_date', sa.DateTime(), nullable=True),
sa.Column('end_date', sa.DateTime(), nullable=True),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
from sqlalchemy import Column, Float, Integer, PickleType, String
from sqlalchemy.ext.declarative import declarative_base

from airflow.models.base import COLLATION_ARGS
from airflow.migrations.db_types import StringID
from airflow.utils.session import create_session
from airflow.utils.sqlalchemy import UtcDateTime

Expand Down Expand Up @@ -60,8 +60,8 @@ class TaskInstance(Base): # type: ignore

__tablename__ = "task_instance"

task_id = Column(String(ID_LEN, **COLLATION_ARGS), primary_key=True)
dag_id = Column(String(ID_LEN, **COLLATION_ARGS), primary_key=True)
task_id = Column(StringID(), primary_key=True)
dag_id = Column(StringID(), primary_key=True)
execution_date = Column(UtcDateTime, primary_key=True)
start_date = Column(UtcDateTime)
end_date = Column(UtcDateTime)
Expand Down
4 changes: 2 additions & 2 deletions airflow/migrations/versions/7939bcff74ba_add_dagtags_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
import sqlalchemy as sa
from alembic import op

from airflow.models.base import COLLATION_ARGS
from airflow.migrations.db_types import StringID

# revision identifiers, used by Alembic.
revision = '7939bcff74ba'
Expand All @@ -41,7 +41,7 @@ def upgrade():
op.create_table(
'dag_tag',
sa.Column('name', sa.String(length=100), nullable=False),
sa.Column('dag_id', sa.String(length=250, **COLLATION_ARGS), nullable=False),
sa.Column('dag_id', StringID(), nullable=False),
sa.ForeignKeyConstraint(
['dag_id'],
['dag.dag_id'],
Expand Down
Loading

0 comments on commit 64d3f98

Please sign in to comment.