This is an automated email from the ASF dual-hosted git repository.

ash pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new e6c56c4  Ensure that dag_id, run_id and execution_date are non-null on DagRun (#18804)
e6c56c4 is described below

commit e6c56c4ae475605636f4a1b5ab3884383884a8cf
Author: Ash Berlin-Taylor <[email protected]>
AuthorDate: Fri Oct 8 01:32:20 2021 +0100

    Ensure that dag_id, run_id and execution_date are non-null on DagRun (#18804)
    
    These _should_ be non-nullable, and are always created as such. Without
    this constraint it was possible that someone had manually edited the
    table, which caused the TaskInstance FK migration to fail to apply
    correctly.
    
    Co-authored-by: Jed Cunningham <[email protected]>
---
 .../7b2661a43ba3_taskinstance_keyed_to_dagrun.py   | 147 +++++++++++++++++----
 airflow/models/dagrun.py                           |   6 +-
 airflow/models/taskinstance.py                     |   6 +-
 airflow/utils/db.py                                |  32 +++++
 tests/api_connexion/schemas/test_dag_run_schema.py |   9 +-
 tests/models/test_dagrun.py                        |   3 +
 6 files changed, 168 insertions(+), 35 deletions(-)

diff --git a/airflow/migrations/versions/7b2661a43ba3_taskinstance_keyed_to_dagrun.py b/airflow/migrations/versions/7b2661a43ba3_taskinstance_keyed_to_dagrun.py
index 8c62101..059144e 100644
--- a/airflow/migrations/versions/7b2661a43ba3_taskinstance_keyed_to_dagrun.py
+++ b/airflow/migrations/versions/7b2661a43ba3_taskinstance_keyed_to_dagrun.py
@@ -41,10 +41,17 @@ branch_labels = None
 depends_on = None
 
 
-def _mssql_datetime():
-    from sqlalchemy.dialects import mssql
+def _datetime_type(dialect_name):
+    if dialect_name == "mssql":
+        from sqlalchemy.dialects import mssql
+
+        return mssql.DATETIME2(precision=6)
+    elif dialect_name == "mysql":
+        from sqlalchemy.dialects import mysql
 
-    return mssql.DATETIME2(precision=6)
+        return mysql.DATETIME(fsp=6)
+
+    return sa.TIMESTAMP(timezone=True)
 
 
 # Just Enough Table to run the conditions for update.
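
The new _datetime_type helper above picks a microsecond-precision datetime
type per backend. A quick standalone illustration of what it resolves to
(reproduced here because migration modules are not meant to be imported
directly):

    import sqlalchemy as sa

    def _datetime_type(dialect_name):
        # MSSQL and MySQL need explicit fractional-second precision;
        # every other backend gets a timezone-aware TIMESTAMP.
        if dialect_name == "mssql":
            from sqlalchemy.dialects import mssql
            return mssql.DATETIME2(precision=6)
        elif dialect_name == "mysql":
            from sqlalchemy.dialects import mysql
            return mysql.DATETIME(fsp=6)
        return sa.TIMESTAMP(timezone=True)

    for name in ("mssql", "mysql", "postgresql", "sqlite"):
        # prints the dialect-specific type chosen for each backend
        print(name, "->", repr(_datetime_type(name)))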
@@ -101,21 +108,30 @@ def upgrade():
     """Apply TaskInstance keyed to DagRun"""
     conn = op.get_bind()
     dialect_name = conn.dialect.name
+    dt_type = _datetime_type(dialect_name)
 
-    run_id_col_type = sa.String(length=ID_LEN, **COLLATION_ARGS)
+    string_id_col_type = sa.String(length=ID_LEN, **COLLATION_ARGS)
 
     if dialect_name == 'sqlite':
         naming_convention = {
             "uq": "%(table_name)s_%(column_0_N_name)s_key",
         }
-        with op.batch_alter_table('dag_run', naming_convention=naming_convention, recreate="always"):
-            # The naming_convention force the previously un-named UNIQUE constraints to have the right name --
-            # but we still need to enter the context manager to trigger it
-            pass
+        # The naming_convention forces the previously unnamed UNIQUE constraints to have the right name
+        with op.batch_alter_table(
+            'dag_run', naming_convention=naming_convention, recreate="always"
+        ) as batch_op:
+            batch_op.alter_column('dag_id', existing_type=string_id_col_type, nullable=False)
+            batch_op.alter_column('run_id', existing_type=string_id_col_type, nullable=False)
+            batch_op.alter_column('execution_date', existing_type=dt_type, nullable=False)
     elif dialect_name == 'mysql':
         with op.batch_alter_table('dag_run') as batch_op:
-            batch_op.alter_column('dag_id', existing_type=sa.String(length=ID_LEN), type_=run_id_col_type)
-            batch_op.alter_column('run_id', existing_type=sa.String(length=ID_LEN), type_=run_id_col_type)
+            batch_op.alter_column(
+                'dag_id', existing_type=sa.String(length=ID_LEN), type_=string_id_col_type, nullable=False
+            )
+            batch_op.alter_column(
+                'run_id', existing_type=sa.String(length=ID_LEN), type_=string_id_col_type, nullable=False
+            )
+            batch_op.alter_column('execution_date', existing_type=dt_type, nullable=False)
             batch_op.drop_constraint('dag_id', 'unique')
             batch_op.drop_constraint('dag_id_2', 'unique')
             batch_op.create_unique_constraint(
@@ -124,16 +140,47 @@ def upgrade():
             batch_op.create_unique_constraint('dag_run_dag_id_run_id_key', ['dag_id', 'run_id'])
     elif dialect_name == 'mssql':
 
-        # _Somehow_ mssql was missing these constraints entirely!
         with op.batch_alter_table('dag_run') as batch_op:
+            batch_op.drop_index('idx_not_null_dag_id_execution_date')
+            batch_op.drop_index('idx_not_null_dag_id_run_id')
+
+            batch_op.drop_index('dag_id_state')
+            batch_op.drop_index('idx_dag_run_dag_id')
+            batch_op.drop_index('idx_dag_run_running_dags')
+            batch_op.drop_index('idx_dag_run_queued_dags')
+
+            batch_op.alter_column('dag_id', existing_type=string_id_col_type, nullable=False)
+            batch_op.alter_column('execution_date', existing_type=dt_type, nullable=False)
+            batch_op.alter_column('run_id', existing_type=string_id_col_type, nullable=False)
+
+            # _Somehow_ mssql was missing these constraints entirely
             batch_op.create_unique_constraint(
                'dag_run_dag_id_execution_date_key', ['dag_id', 'execution_date']
             )
             batch_op.create_unique_constraint('dag_run_dag_id_run_id_key', ['dag_id', 'run_id'])
 
+            batch_op.create_index('dag_id_state', ['dag_id', 'state'], unique=False)
+            batch_op.create_index('idx_dag_run_dag_id', ['dag_id'])
+            batch_op.create_index(
+                'idx_dag_run_running_dags',
+                ["state", "dag_id"],
+                mssql_where=sa.text("state='running'"),
+            )
+            batch_op.create_index(
+                'idx_dag_run_queued_dags',
+                ["state", "dag_id"],
+                mssql_where=sa.text("state='queued'"),
+            )
+    else:
+        # Make sure DagRun id columns are non-nullable
+        with op.batch_alter_table('dag_run', schema=None) as batch_op:
+            batch_op.alter_column('dag_id', existing_type=string_id_col_type, nullable=False)
+            batch_op.alter_column('execution_date', existing_type=dt_type, nullable=False)
+            batch_op.alter_column('run_id', existing_type=string_id_col_type, nullable=False)
+
     # First create column nullable
-    op.add_column('task_instance', sa.Column('run_id', type_=run_id_col_type, nullable=True))
-    op.add_column('task_reschedule', sa.Column('run_id', type_=run_id_col_type, nullable=True))
+    op.add_column('task_instance', sa.Column('run_id', type_=string_id_col_type, nullable=True))
+    op.add_column('task_reschedule', sa.Column('run_id', type_=string_id_col_type, nullable=True))
 
     # Then update the new column by selecting the right value from DagRun
     update_query = _multi_table_update(dialect_name, task_instance, task_instance.c.run_id)
@@ -147,7 +194,9 @@ def upgrade():
     op.execute(update_query)
 
     with op.batch_alter_table('task_reschedule', schema=None) as batch_op:
-        batch_op.alter_column('run_id', existing_type=run_id_col_type, existing_nullable=True, nullable=False)
+        batch_op.alter_column(
+            'run_id', existing_type=string_id_col_type, existing_nullable=True, nullable=False
+        )
 
         batch_op.drop_constraint('task_reschedule_dag_task_date_fkey', 'foreignkey')
         if dialect_name == "mysql":
@@ -157,7 +206,14 @@ def upgrade():
 
     with op.batch_alter_table('task_instance', schema=None) as batch_op:
         # Then make it non-nullable
-        batch_op.alter_column('run_id', existing_type=run_id_col_type, existing_nullable=True, nullable=False)
+        batch_op.alter_column(
+            'run_id', existing_type=string_id_col_type, existing_nullable=True, nullable=False
+        )
+
+        batch_op.alter_column(
+            'dag_id', existing_type=string_id_col_type, existing_nullable=True, nullable=False
+        )
+        batch_op.alter_column('execution_date', existing_type=dt_type, existing_nullable=True, nullable=False)
 
         # TODO: Is this right for non-postgres?
         if dialect_name == 'mssql':
@@ -212,14 +268,11 @@ def upgrade():
 def downgrade():
     """Unapply TaskInstance keyed to DagRun"""
     dialect_name = op.get_bind().dialect.name
+    dt_type = _datetime_type(dialect_name)
+    string_id_col_type = sa.String(length=ID_LEN, **COLLATION_ARGS)
 
-    if dialect_name == "mssql":
-        col_type = _mssql_datetime()
-    else:
-        col_type = sa.TIMESTAMP(timezone=True)
-
-    op.add_column('task_instance', sa.Column('execution_date', col_type, nullable=True))
-    op.add_column('task_reschedule', sa.Column('execution_date', col_type, nullable=True))
+    op.add_column('task_instance', sa.Column('execution_date', dt_type, nullable=True))
+    op.add_column('task_reschedule', sa.Column('execution_date', dt_type, nullable=True))
 
     update_query = _multi_table_update(dialect_name, task_instance, task_instance.c.execution_date)
     op.execute(update_query)
@@ -228,9 +281,7 @@ def downgrade():
     op.execute(update_query)
 
     with op.batch_alter_table('task_reschedule', schema=None) as batch_op:
-        batch_op.alter_column(
-            'execution_date', existing_type=col_type, existing_nullable=True, nullable=False
-        )
+        batch_op.alter_column('execution_date', existing_type=dt_type, existing_nullable=True, nullable=False)
 
         # Can't drop PK index while there is a FK referencing it
         batch_op.drop_constraint('task_reschedule_ti_fkey')
@@ -238,8 +289,9 @@ def downgrade():
         batch_op.drop_index('idx_task_reschedule_dag_task_run')
 
     with op.batch_alter_table('task_instance', schema=None) as batch_op:
+        batch_op.alter_column('execution_date', existing_type=dt_type, existing_nullable=True, nullable=False)
         batch_op.alter_column(
-            'execution_date', existing_type=col_type, existing_nullable=True, nullable=False
+            'dag_id', existing_type=string_id_col_type, existing_nullable=True, nullable=True
         )
 
         batch_op.drop_constraint('task_instance_pkey', type_='primary')
@@ -269,6 +321,49 @@ def downgrade():
             ondelete='CASCADE',
         )
 
+    if dialect_name == "mssql":
+
+        with op.batch_alter_table('dag_run', schema=None) as batch_op:
+            batch_op.drop_constraint('dag_run_dag_id_execution_date_key', 'unique')
+            batch_op.drop_constraint('dag_run_dag_id_run_id_key', 'unique')
+            batch_op.drop_index('dag_id_state')
+            batch_op.drop_index('idx_dag_run_running_dags')
+            batch_op.drop_index('idx_dag_run_queued_dags')
+
+            batch_op.alter_column('dag_id', existing_type=string_id_col_type, nullable=True)
+            batch_op.alter_column('execution_date', existing_type=dt_type, nullable=True)
+            batch_op.alter_column('run_id', existing_type=string_id_col_type, nullable=True)
+
+            batch_op.create_index('dag_id_state', ['dag_id', 'state'], unique=False)
+            batch_op.create_index('idx_dag_run_dag_id', ['dag_id'])
+            batch_op.create_index(
+                'idx_dag_run_running_dags',
+                ["state", "dag_id"],
+                mssql_where=sa.text("state='running'"),
+            )
+            batch_op.create_index(
+                'idx_dag_run_queued_dags',
+                ["state", "dag_id"],
+                mssql_where=sa.text("state='queued'"),
+            )
+        op.execute(
+            """CREATE UNIQUE NONCLUSTERED INDEX 
idx_not_null_dag_id_execution_date
+                    ON dag_run(dag_id,execution_date)
+                    WHERE dag_id IS NOT NULL and execution_date is not null"""
+        )
+        op.execute(
+            """CREATE UNIQUE NONCLUSTERED INDEX idx_not_null_dag_id_run_id
+                     ON dag_run(dag_id,run_id)
+                     WHERE dag_id IS NOT NULL and run_id is not null"""
+        )
+    else:
+        with op.batch_alter_table('dag_run', schema=None) as batch_op:
+            batch_op.drop_index('dag_id_state')
+            batch_op.alter_column('run_id', existing_type=sa.VARCHAR(length=250), nullable=True)
+            batch_op.alter_column('execution_date', existing_type=dt_type, nullable=True)
+            batch_op.alter_column('dag_id', existing_type=sa.VARCHAR(length=250), nullable=True)
+            batch_op.create_index('dag_id_state', ['dag_id', 'state'], unique=False)
+
 
 def _multi_table_update(dialect_name, target, column):
     condition = dag_run.c.dag_id == target.c.dag_id
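
The _multi_table_update helper (shown truncated above) builds a correlated
UPDATE that copies values across from dag_run. A rough sketch of the core
pattern in plain SQLAlchemy; the toy tables are simplified, the join on
dag_id plus execution_date is an assumption based on the pre-migration
TaskInstance key, and the real helper also dispatches per dialect:

    import sqlalchemy as sa
    from sqlalchemy.dialects import postgresql

    metadata = sa.MetaData()
    dag_run = sa.Table(
        'dag_run', metadata,
        sa.Column('dag_id', sa.String(250)),
        sa.Column('run_id', sa.String(250)),
        sa.Column('execution_date', sa.TIMESTAMP(timezone=True)),
    )
    task_instance = sa.Table(
        'task_instance', metadata,
        sa.Column('dag_id', sa.String(250)),
        sa.Column('run_id', sa.String(250)),
        sa.Column('execution_date', sa.TIMESTAMP(timezone=True)),
    )

    condition = sa.and_(
        dag_run.c.dag_id == task_instance.c.dag_id,
        dag_run.c.execution_date == task_instance.c.execution_date,
    )
    update_query = task_instance.update().values(run_id=dag_run.c.run_id).where(condition)
    # Compile for a backend that supports UPDATE ... FROM:
    print(update_query.compile(dialect=postgresql.dialect()))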
diff --git a/airflow/models/dagrun.py b/airflow/models/dagrun.py
index 3d50b05..2b651c5 100644
--- a/airflow/models/dagrun.py
+++ b/airflow/models/dagrun.py
@@ -78,13 +78,13 @@ class DagRun(Base, LoggingMixin):
     __NO_VALUE = object()
 
     id = Column(Integer, primary_key=True)
-    dag_id = Column(String(ID_LEN, **COLLATION_ARGS))
+    dag_id = Column(String(ID_LEN, **COLLATION_ARGS), nullable=False)
     queued_at = Column(UtcDateTime)
-    execution_date = Column(UtcDateTime, default=timezone.utcnow)
+    execution_date = Column(UtcDateTime, default=timezone.utcnow, nullable=False)
     start_date = Column(UtcDateTime)
     end_date = Column(UtcDateTime)
     _state = Column('state', String(50), default=State.QUEUED)
-    run_id = Column(String(ID_LEN, **COLLATION_ARGS))
+    run_id = Column(String(ID_LEN, **COLLATION_ARGS), nullable=False)
     creating_job_id = Column(Integer)
     external_trigger = Column(Boolean, default=True)
     run_type = Column(String(50), nullable=False)
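
With nullable=False on the model columns, the rendered DDL now carries NOT
NULL for all three identifiers. A quick way to see this, using a toy table
that mirrors only the columns touched above rather than the real DagRun
model:

    import sqlalchemy as sa
    from sqlalchemy.schema import CreateTable

    metadata = sa.MetaData()
    toy_dag_run = sa.Table(
        'dag_run', metadata,
        sa.Column('id', sa.Integer, primary_key=True),
        sa.Column('dag_id', sa.String(250), nullable=False),
        sa.Column('execution_date', sa.TIMESTAMP(timezone=True), nullable=False),
        sa.Column('run_id', sa.String(250), nullable=False),
    )
    # Each of the three columns renders with NOT NULL; exact type names
    # vary by dialect.
    print(CreateTable(toy_dag_run))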
diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py
index 8453a1e..b178ff0 100644
--- a/airflow/models/taskinstance.py
+++ b/airflow/models/taskinstance.py
@@ -324,9 +324,9 @@ class TaskInstance(Base, LoggingMixin):
 
     __tablename__ = "task_instance"
 
-    task_id = Column(String(ID_LEN, **COLLATION_ARGS), primary_key=True)
-    dag_id = Column(String(ID_LEN, **COLLATION_ARGS), primary_key=True)
-    run_id = Column(String(ID_LEN, **COLLATION_ARGS), primary_key=True)
+    task_id = Column(String(ID_LEN, **COLLATION_ARGS), primary_key=True, nullable=False)
+    dag_id = Column(String(ID_LEN, **COLLATION_ARGS), primary_key=True, nullable=False)
+    run_id = Column(String(ID_LEN, **COLLATION_ARGS), primary_key=True, nullable=False)
     start_date = Column(UtcDateTime)
     end_date = Column(UtcDateTime)
     duration = Column(Float)
diff --git a/airflow/utils/db.py b/airflow/utils/db.py
index db249fe..13dd401 100644
--- a/airflow/utils/db.py
+++ b/airflow/utils/db.py
@@ -697,6 +697,37 @@ def check_conn_type_null(session=None) -> Iterable[str]:
         )
 
 
+def check_run_id_null(session) -> Iterable[str]:
+    import sqlalchemy.schema
+
+    metadata = sqlalchemy.schema.MetaData(session.bind)
+    try:
+        metadata.reflect(only=["dag_run"])
+    except exc.InvalidRequestError:
+        # Table doesn't exist -- empty db
+        return
+
+    dag_run = metadata.tables["dag_run"]
+
+    for colname in ('run_id', 'dag_id', 'execution_date'):
+
+        col = dag_run.columns.get(colname)
+        if col is None:
+            continue
+
+        if not col.nullable:
+            continue
+
+        num = session.query(dag_run).filter(col.is_(None)).count()
+        if num > 0:
+            yield (
+                f'The {dag_run.name} table has {num} row{"s" if num != 1 else ""} with a NULL value in '
+                f'{col.name!r}. You must manually correct this problem (possibly by deleting the problem '
+                'rows).'
+            )
+    session.rollback()
+
+
 def check_task_tables_without_matching_dagruns(session) -> Iterable[str]:
     from itertools import chain
 
@@ -762,6 +793,7 @@ def _check_migration_errors(session=None) -> Iterable[str]:
     for check_fn in (
         check_conn_id_duplicates,
         check_conn_type_null,
+        check_run_id_null,
         check_task_tables_without_matching_dagruns,
     ):
         yield from check_fn(session)
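
The new check runs automatically alongside the other pre-migration checks,
but it can also be exercised by hand; a rough sketch (assumes a fully
initialized Airflow environment so that settings.Session is usable):

    # check_run_id_null is a generator of human-readable error strings;
    # an empty result means dag_run is safe to migrate.
    from airflow.settings import Session
    from airflow.utils.db import check_run_id_null

    session = Session()
    errors = list(check_run_id_null(session))
    for message in errors:
        print("ERROR:", message)
    if not errors:
        print("dag_run id columns look good")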
diff --git a/tests/api_connexion/schemas/test_dag_run_schema.py b/tests/api_connexion/schemas/test_dag_run_schema.py
index ba5acae..6f42ec0 100644
--- a/tests/api_connexion/schemas/test_dag_run_schema.py
+++ b/tests/api_connexion/schemas/test_dag_run_schema.py
@@ -51,6 +51,7 @@ class TestDAGRunSchema(TestDAGRunBase):
     @provide_session
     def test_serialize(self, session):
         dagrun_model = DagRun(
+            dag_id="my-dag-run",
             run_id="my-dag-run",
             state='running',
             run_type=DagRunType.MANUAL.value,
@@ -64,7 +65,7 @@ class TestDAGRunSchema(TestDAGRunBase):
         deserialized_dagrun = dagrun_schema.dump(dagrun_model)
 
         assert deserialized_dagrun == {
-            "dag_id": None,
+            "dag_id": "my-dag-run",
             "dag_run_id": "my-dag-run",
             "end_date": None,
             "state": "running",
@@ -128,6 +129,7 @@ class TestDagRunCollection(TestDAGRunBase):
     @provide_session
     def test_serialize(self, session):
         dagrun_model_1 = DagRun(
+            dag_id="my-dag-run",
             run_id="my-dag-run",
             state='running',
             execution_date=timezone.parse(self.default_time),
@@ -136,6 +138,7 @@ class TestDagRunCollection(TestDAGRunBase):
             conf='{"start": "stop"}',
         )
         dagrun_model_2 = DagRun(
+            dag_id="my-dag-run",
             run_id="my-dag-run-2",
             state='running',
             execution_date=timezone.parse(self.second_time),
@@ -150,7 +153,7 @@ class TestDagRunCollection(TestDAGRunBase):
         assert deserialized_dagruns == {
             "dag_runs": [
                 {
-                    "dag_id": None,
+                    "dag_id": "my-dag-run",
                     "dag_run_id": "my-dag-run",
                     "end_date": None,
                     "execution_date": self.default_time,
@@ -161,7 +164,7 @@ class TestDagRunCollection(TestDAGRunBase):
                     "conf": {"start": "stop"},
                 },
                 {
-                    "dag_id": None,
+                    "dag_id": "my-dag-run",
                     "dag_run_id": "my-dag-run-2",
                     "end_date": None,
                     "state": "running",
diff --git a/tests/models/test_dagrun.py b/tests/models/test_dagrun.py
index 3622603..c4ef287 100644
--- a/tests/models/test_dagrun.py
+++ b/tests/models/test_dagrun.py
@@ -114,6 +114,7 @@ class TestDagRun(unittest.TestCase):
         dag_id1 = "test_dagrun_find_externally_triggered"
         dag_run = models.DagRun(
             dag_id=dag_id1,
+            run_id=dag_id1,
             run_type=DagRunType.MANUAL,
             execution_date=now,
             start_date=now,
@@ -125,6 +126,7 @@ class TestDagRun(unittest.TestCase):
         dag_id2 = "test_dagrun_find_not_externally_triggered"
         dag_run = models.DagRun(
             dag_id=dag_id2,
+            run_id=dag_id2,
             run_type=DagRunType.MANUAL,
             execution_date=now,
             start_date=now,
@@ -532,6 +534,7 @@ class TestDagRun(unittest.TestCase):
         # don't want
         dag_run = models.DagRun(
             dag_id=dag.dag_id,
+            run_id="test_get_task_instance_on_empty_dagrun",
             run_type=DagRunType.MANUAL,
             execution_date=now,
             start_date=now,
