Skip to content

Commit 15273f0

Browse files
authored
Check for same task instead of Equality to detect Duplicate Tasks (#8828)
1 parent f4edd90 commit 15273f0

File tree

17 files changed

+78
-82
lines changed

17 files changed

+78
-82
lines changed

airflow/example_dags/example_complex.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@
8181
)
8282

8383
create_tag_template_field_result2 = BashOperator(
84-
task_id="create_tag_template_field_result", bash_command="echo create_tag_template_field_result"
84+
task_id="create_tag_template_field_result2", bash_command="echo create_tag_template_field_result"
8585
)
8686

8787
# Delete

airflow/models/baseoperator.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@
3636
from sqlalchemy.orm import Session
3737

3838
from airflow.configuration import conf
39-
from airflow.exceptions import AirflowException, DuplicateTaskIdFound
39+
from airflow.exceptions import AirflowException
4040
from airflow.lineage import apply_lineage, prepare_lineage
4141
from airflow.models.base import Operator
4242
from airflow.models.pool import Pool
@@ -600,9 +600,8 @@ def dag(self, dag: Any):
600600
"The DAG assigned to {} can not be changed.".format(self))
601601
elif self.task_id not in dag.task_dict:
602602
dag.add_task(self)
603-
elif self.task_id in dag.task_dict and dag.task_dict[self.task_id] != self:
604-
raise DuplicateTaskIdFound(
605-
"Task id '{}' has already been added to the DAG".format(self.task_id))
603+
elif self.task_id in dag.task_dict and dag.task_dict[self.task_id] is not self:
604+
dag.add_task(self)
606605

607606
self._dag = dag # pylint: disable=attribute-defined-outside-init
608607

airflow/models/dag.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1337,7 +1337,7 @@ def add_task(self, task):
13371337
elif task.end_date and self.end_date:
13381338
task.end_date = min(task.end_date, self.end_date)
13391339

1340-
if task.task_id in self.task_dict and self.task_dict[task.task_id] != task:
1340+
if task.task_id in self.task_dict and self.task_dict[task.task_id] is not task:
13411341
raise DuplicateTaskIdFound(
13421342
"Task id '{}' has already been added to the DAG".format(task.task_id))
13431343
else:

airflow/providers/google/cloud/example_dags/example_datacatalog.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -181,7 +181,7 @@
181181

182182
# [START howto_operator_gcp_datacatalog_create_tag_template_field_result2]
183183
create_tag_template_field_result2 = BashOperator(
184-
task_id="create_tag_template_field_result",
184+
task_id="create_tag_template_field_result2",
185185
bash_command="echo \"{{ task_instance.xcom_pull('create_tag_template_field') }}\"",
186186
)
187187
# [END howto_operator_gcp_datacatalog_create_tag_template_field_result2]

airflow/providers/google/cloud/example_dags/example_gcs_to_gcs.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@
100100

101101
# [START howto_operator_gcs_to_gcs_delimiter]
102102
copy_files_with_delimiter = GCSToGCSOperator(
103-
task_id="copy_files_with_wildcard",
103+
task_id="copy_files_with_delimiter",
104104
source_bucket=BUCKET_1_SRC,
105105
source_object="data/",
106106
destination_bucket=BUCKET_1_DST,

tests/models/test_dag.py

Lines changed: 1 addition & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -979,34 +979,18 @@ def test_duplicate_task_ids_not_allowed_with_dag_context_manager(self):
979979

980980
self.assertEqual(dag.task_dict, {op1.task_id: op1})
981981

982-
# Also verify that DAGs with duplicate task_ids don't raise errors
983-
with DAG("test_dag_1", start_date=DEFAULT_DATE) as dag1:
984-
op3 = DummyOperator(task_id="t3")
985-
op4 = BashOperator(task_id="t4", bash_command="sleep 1")
986-
op3 >> op4
987-
988-
self.assertEqual(dag1.task_dict, {op3.task_id: op3, op4.task_id: op4})
989-
990982
def test_duplicate_task_ids_not_allowed_without_dag_context_manager(self):
991983
"""Verify tasks with Duplicate task_id raises error"""
992984
with self.assertRaisesRegex(
993985
DuplicateTaskIdFound, "Task id 't1' has already been added to the DAG"
994986
):
995987
dag = DAG("test_dag", start_date=DEFAULT_DATE)
996988
op1 = DummyOperator(task_id="t1", dag=dag)
997-
op2 = BashOperator(task_id="t1", bash_command="sleep 1", dag=dag)
989+
op2 = DummyOperator(task_id="t1", dag=dag)
998990
op1 >> op2
999991

1000992
self.assertEqual(dag.task_dict, {op1.task_id: op1})
1001993

1002-
# Also verify that DAGs with duplicate task_ids don't raise errors
1003-
dag1 = DAG("test_dag_1", start_date=DEFAULT_DATE)
1004-
op3 = DummyOperator(task_id="t3", dag=dag1)
1005-
op4 = DummyOperator(task_id="t4", dag=dag1)
1006-
op3 >> op4
1007-
1008-
self.assertEqual(dag1.task_dict, {op3.task_id: op3, op4.task_id: op4})
1009-
1010994
def test_duplicate_task_ids_for_same_task_is_allowed(self):
1011995
"""Verify that same tasks with Duplicate task_id do not raise error"""
1012996
with DAG("test_dag", start_date=DEFAULT_DATE) as dag:

tests/models/test_taskinstance.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -373,20 +373,20 @@ def test_ti_updates_with_task(self, session=None):
373373
"""
374374
test that updating the executor_config propogates to the TaskInstance DB
375375
"""
376-
dag = models.DAG(dag_id='test_run_pooling_task')
377-
task = DummyOperator(task_id='test_run_pooling_task_op', dag=dag, owner='airflow',
378-
executor_config={'foo': 'bar'},
379-
start_date=timezone.datetime(2016, 2, 1, 0, 0, 0))
376+
with models.DAG(dag_id='test_run_pooling_task') as dag:
377+
task = DummyOperator(task_id='test_run_pooling_task_op', owner='airflow',
378+
executor_config={'foo': 'bar'},
379+
start_date=timezone.datetime(2016, 2, 1, 0, 0, 0))
380380
ti = TI(
381381
task=task, execution_date=timezone.utcnow())
382382

383383
ti.run(session=session)
384384
tis = dag.get_task_instances()
385385
self.assertEqual({'foo': 'bar'}, tis[0].executor_config)
386-
387-
task2 = DummyOperator(task_id='test_run_pooling_task_op', dag=dag, owner='airflow',
388-
executor_config={'bar': 'baz'},
389-
start_date=timezone.datetime(2016, 2, 1, 0, 0, 0))
386+
with models.DAG(dag_id='test_run_pooling_task') as dag:
387+
task2 = DummyOperator(task_id='test_run_pooling_task_op', owner='airflow',
388+
executor_config={'bar': 'baz'},
389+
start_date=timezone.datetime(2016, 2, 1, 0, 0, 0))
390390

391391
ti = TI(
392392
task=task2, execution_date=timezone.utcnow())

tests/providers/amazon/aws/operators/test_s3_to_sftp.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,7 @@ def test_s3_to_sftp_operation(self):
133133
def delete_remote_resource(self):
134134
# check the remote file content
135135
remove_file_task = SSHOperator(
136-
task_id="test_check_file",
136+
task_id="test_rm_file",
137137
ssh_hook=self.hook,
138138
command="rm {0}".format(self.sftp_path),
139139
do_xcom_push=True,

tests/providers/google/cloud/operators/test_mlengine_utils.py

Lines changed: 24 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -131,16 +131,18 @@ def test_successful_run(self):
131131
self.assertEqual('err=0.9', result)
132132

133133
def test_failures(self):
134-
dag = DAG(
135-
'test_dag',
136-
default_args={
137-
'owner': 'airflow',
138-
'start_date': DEFAULT_DATE,
139-
'end_date': DEFAULT_DATE,
140-
'project_id': 'test-project',
141-
'region': 'us-east1',
142-
},
143-
schedule_interval='@daily')
134+
def create_test_dag(dag_id):
135+
dag = DAG(
136+
dag_id,
137+
default_args={
138+
'owner': 'airflow',
139+
'start_date': DEFAULT_DATE,
140+
'end_date': DEFAULT_DATE,
141+
'project_id': 'test-project',
142+
'region': 'us-east1',
143+
},
144+
schedule_interval='@daily')
145+
return dag
144146

145147
input_with_model = self.INPUT_MISSING_ORIGIN.copy()
146148
other_params_but_models = {
@@ -151,26 +153,30 @@ def test_failures(self):
151153
'prediction_path': input_with_model['outputPath'],
152154
'metric_fn_and_keys': (self.metric_fn, ['err']),
153155
'validate_fn': (lambda x: 'err=%.1f' % x['err']),
154-
'dag': dag,
155156
}
156157

157158
with self.assertRaisesRegex(AirflowException, 'Missing model origin'):
158-
mlengine_operator_utils.create_evaluate_ops(**other_params_but_models)
159+
mlengine_operator_utils.create_evaluate_ops(
160+
dag=create_test_dag('test_dag_1'), **other_params_but_models)
159161

160162
with self.assertRaisesRegex(AirflowException, 'Ambiguous model origin'):
161-
mlengine_operator_utils.create_evaluate_ops(model_uri='abc', model_name='cde',
162-
**other_params_but_models)
163+
mlengine_operator_utils.create_evaluate_ops(
164+
dag=create_test_dag('test_dag_2'), model_uri='abc', model_name='cde',
165+
**other_params_but_models)
163166

164167
with self.assertRaisesRegex(AirflowException, 'Ambiguous model origin'):
165-
mlengine_operator_utils.create_evaluate_ops(model_uri='abc', version_name='vvv',
166-
**other_params_but_models)
168+
mlengine_operator_utils.create_evaluate_ops(
169+
dag=create_test_dag('test_dag_3'), model_uri='abc', version_name='vvv',
170+
**other_params_but_models)
167171

168172
with self.assertRaisesRegex(AirflowException, '`metric_fn` param must be callable'):
169173
params = other_params_but_models.copy()
170174
params['metric_fn_and_keys'] = (None, ['abc'])
171-
mlengine_operator_utils.create_evaluate_ops(model_uri='gs://blah', **params)
175+
mlengine_operator_utils.create_evaluate_ops(
176+
dag=create_test_dag('test_dag_4'), model_uri='gs://blah', **params)
172177

173178
with self.assertRaisesRegex(AirflowException, '`validate_fn` param must be callable'):
174179
params = other_params_but_models.copy()
175180
params['validate_fn'] = None
176-
mlengine_operator_utils.create_evaluate_ops(model_uri='gs://blah', **params)
181+
mlengine_operator_utils.create_evaluate_ops(
182+
dag=create_test_dag('test_dag_5'), model_uri='gs://blah', **params)

tests/providers/google/cloud/sensors/test_gcs.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -205,7 +205,7 @@ def setUp(self):
205205
self.dag = dag
206206

207207
self.sensor = GCSUploadSessionCompleteSensor(
208-
task_id='sensor',
208+
task_id='sensor_1',
209209
bucket='test-bucket',
210210
prefix='test-prefix/path',
211211
inactivity_period=12,
@@ -227,7 +227,7 @@ def test_files_deleted_between_pokes_throw_error(self):
227227
@mock.patch('airflow.providers.google.cloud.sensors.gcs.get_time', mock_time)
228228
def test_files_deleted_between_pokes_allow_delete(self):
229229
self.sensor = GCSUploadSessionCompleteSensor(
230-
task_id='sensor',
230+
task_id='sensor_2',
231231
bucket='test-bucket',
232232
prefix='test-prefix/path',
233233
inactivity_period=12,

0 commit comments

Comments
 (0)