
Commit 1ed0146

josh-fell and uranusjr authored
Add output property to MappedOperator (#25604)
Co-authored-by: Tzu-ping Chung <uranusjr@gmail.com>
1 parent: 05cbba3 · commit: 1ed0146

45 files changed: +227 −108 lines

Large commits have some content hidden by default; only part of the diff is shown below.

airflow/example_dags/example_xcom.py (3 additions, 6 deletions)

@@ -19,7 +19,7 @@
 """Example DAG demonstrating the usage of XComs."""
 import pendulum
 
-from airflow import DAG
+from airflow import DAG, XComArg
 from airflow.decorators import task
 from airflow.operators.bash import BashOperator

@@ -79,8 +79,8 @@ def pull_value_from_bash_push(ti=None):
     bash_pull = BashOperator(
         task_id='bash_pull',
         bash_command='echo "bash pull demo" && '
-        f'echo "The xcom pushed manually is {bash_push.output["manually_pushed_value"]}" && '
-        f'echo "The returned_value xcom is {bash_push.output}" && '
+        f'echo "The xcom pushed manually is {XComArg(bash_push, key="manually_pushed_value")}" && '
+        f'echo "The returned_value xcom is {XComArg(bash_push)}" && '
         'echo "finished"',
         do_xcom_push=False,
     )

@@ -90,6 +90,3 @@ def pull_value_from_bash_push(ti=None):
     [bash_pull, python_pull_from_bash] << bash_push
 
     puller(push_by_returning()) << push()
-
-    # Task dependencies created via `XComArgs`:
-    # pull << push2

airflow/models/baseoperator.py (2 additions, 1 deletion)

@@ -98,6 +98,7 @@
     from airflow.models.dag import DAG
     from airflow.models.taskinstance import TaskInstanceKey
+    from airflow.models.xcom_arg import XComArg
     from airflow.utils.task_group import TaskGroup
 
 ScheduleInterval = Union[str, timedelta, relativedelta]

@@ -1365,7 +1366,7 @@ def leaves(self) -> List["BaseOperator"]:
         return [self]
 
     @property
-    def output(self):
+    def output(self) -> "XComArg":
         """Returns reference to XCom pushed by current operator"""
         from airflow.models.xcom_arg import XComArg

airflow/models/mappedoperator.py (7 additions, 0 deletions)

@@ -530,6 +530,13 @@ def get_dag(self) -> Optional["DAG"]:
         """Implementing Operator."""
         return self.dag
 
+    @property
+    def output(self) -> "XComArg":
+        """Returns reference to XCom pushed by current operator"""
+        from airflow.models.xcom_arg import XComArg
+
+        return XComArg(operator=self)
+
     def serialize_for_task_group(self) -> Tuple[DagAttributeTypes, Any]:
         """Implementing DAGNode."""
         return DagAttributeTypes.OP, self.task_id
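
For context, a minimal usage sketch of what this new property enables (not part of the commit; the DAG id, task ids, and commands below are hypothetical): a task produced by .expand() is a MappedOperator, and with this change its pushed XComs can be referenced through .output just like a regular operator's.

from datetime import datetime

from airflow import DAG
from airflow.operators.bash import BashOperator

with DAG(dag_id="mapped_output_demo", start_date=datetime(2022, 8, 1), schedule_interval=None) as dag:
    # Expanding a classic operator yields a MappedOperator instance.
    echo = BashOperator.partial(task_id="echo").expand(bash_command=["echo 1", "echo 2"])

    # .output on the mapped task now returns an XComArg; rendered in a template it pulls
    # the XComs pushed by the mapped task instances.
    report = BashOperator(
        task_id="report",
        bash_command=f'echo "mapped results: {echo.output}"',
    )

    echo >> report  # explicit dependency, since the XComArg is only embedded in an f-string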

airflow/providers/amazon/aws/example_dags/example_dms.py (8 additions, 5 deletions)

@@ -23,6 +23,7 @@
 import json
 import os
 from datetime import datetime
+from typing import cast
 
 import boto3
 from sqlalchemy import Column, MetaData, String, Table, create_engine

@@ -256,10 +257,12 @@ def delete_dms_assets():
     )
     # [END howto_operator_dms_create_task]
 
+    task_arn = cast(str, create_task.output)
+
     # [START howto_operator_dms_start_task]
     start_task = DmsStartTaskOperator(
         task_id='start_task',
-        replication_task_arn=create_task.output,
+        replication_task_arn=task_arn,
     )
     # [END howto_operator_dms_start_task]

@@ -280,30 +283,30 @@ def delete_dms_assets():
     await_task_start = DmsTaskBaseSensor(
         task_id='await_task_start',
-        replication_task_arn=create_task.output,
+        replication_task_arn=task_arn,
         target_statuses=['running'],
         termination_statuses=['stopped', 'deleting', 'failed'],
     )
 
     # [START howto_operator_dms_stop_task]
     stop_task = DmsStopTaskOperator(
         task_id='stop_task',
-        replication_task_arn=create_task.output,
+        replication_task_arn=task_arn,
     )
     # [END howto_operator_dms_stop_task]
 
     # TaskCompletedSensor actually waits until task reaches the "Stopped" state, so it will work here.
     # [START howto_sensor_dms_task_completed]
     await_task_stop = DmsTaskCompletedSensor(
         task_id='await_task_stop',
-        replication_task_arn=create_task.output,
+        replication_task_arn=task_arn,
     )
     # [END howto_sensor_dms_task_completed]
 
     # [START howto_operator_dms_delete_task]
     delete_task = DmsDeleteTaskOperator(
         task_id='delete_task',
-        replication_task_arn=create_task.output,
+        replication_task_arn=task_arn,
         trigger_rule='all_done',
     )
     # [END howto_operator_dms_delete_task]
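
A hedged aside on the cast pattern introduced above (the task names in this sketch are made up, not from the commit): create_task.output is an XComArg object at DAG-parse time, while the downstream operators annotate replication_task_arn as str, so cast(str, ...) exists purely to satisfy static type checkers; the actual string is still pulled from XCom when the task runs.

from datetime import datetime
from typing import cast

from airflow import DAG
from airflow.operators.bash import BashOperator

with DAG(dag_id="cast_output_demo", start_date=datetime(2022, 8, 1), schedule_interval=None) as dag:
    # Pushes its last line of stdout ("my-arn") to XCom under the default 'return_value' key.
    create = BashOperator(task_id="create", bash_command="echo my-arn")

    # At parse time this is an XComArg; cast() only changes how type checkers see it.
    arn = cast(str, create.output)

    # The templated field is rendered from the pushed XCom value at runtime.
    use = BashOperator(task_id="use", bash_command=f'echo "got: {arn}"')
    create >> use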

airflow/providers/amazon/aws/example_dags/example_ecs.py (6 additions, 3 deletions)

@@ -16,6 +16,7 @@
 # under the License.
 
 from datetime import datetime
+from typing import cast
 
 from airflow import DAG
 from airflow.models.baseoperator import chain

@@ -99,18 +100,20 @@
     )
     # [END howto_operator_ecs_register_task_definition]
 
+    registered_task_definition = cast(str, register_task.output)
+
     # [START howto_sensor_ecs_task_definition_state]
     await_task_definition = EcsTaskDefinitionStateSensor(
         task_id='await_task_definition',
-        task_definition=register_task.output,
+        task_definition=registered_task_definition,
     )
     # [END howto_sensor_ecs_task_definition_state]
 
     # [START howto_operator_ecs_run_task]
     run_task = EcsRunTaskOperator(
         task_id="run_task",
         cluster=EXISTING_CLUSTER_NAME,
-        task_definition=register_task.output,
+        task_definition=registered_task_definition,
         launch_type="EC2",
         overrides={
             "containerOverrides": [

@@ -156,7 +159,7 @@
     deregister_task = EcsDeregisterTaskDefinitionOperator(
         task_id='deregister_task',
         trigger_rule=TriggerRule.ALL_DONE,
-        task_definition=register_task.output,
+        task_definition=registered_task_definition,
     )
     # [END howto_operator_ecs_deregister_task_definition]

airflow/providers/amazon/aws/example_dags/example_emr.py (8 additions, 8 deletions)

@@ -17,6 +17,7 @@
 # under the License.
 import os
 from datetime import datetime
+from typing import cast
 
 from airflow import DAG
 from airflow.models.baseoperator import chain

@@ -79,39 +80,38 @@
     )
     # [END howto_operator_emr_create_job_flow]
 
+    job_flow_id = cast(str, job_flow_creator.output)
+
     # [START howto_sensor_emr_job_flow]
-    job_sensor = EmrJobFlowSensor(
-        task_id='check_job_flow',
-        job_flow_id=job_flow_creator.output,
-    )
+    job_sensor = EmrJobFlowSensor(task_id='check_job_flow', job_flow_id=job_flow_id)
     # [END howto_sensor_emr_job_flow]
 
     # [START howto_operator_emr_modify_cluster]
     cluster_modifier = EmrModifyClusterOperator(
-        task_id='modify_cluster', cluster_id=job_flow_creator.output, step_concurrency_level=1
+        task_id='modify_cluster', cluster_id=job_flow_id, step_concurrency_level=1
     )
     # [END howto_operator_emr_modify_cluster]
 
     # [START howto_operator_emr_add_steps]
     step_adder = EmrAddStepsOperator(
         task_id='add_steps',
-        job_flow_id=job_flow_creator.output,
+        job_flow_id=job_flow_id,
         steps=SPARK_STEPS,
     )
     # [END howto_operator_emr_add_steps]
 
     # [START howto_sensor_emr_step]
     step_checker = EmrStepSensor(
         task_id='watch_step',
-        job_flow_id=job_flow_creator.output,
+        job_flow_id=job_flow_id,
         step_id="{{ task_instance.xcom_pull(task_ids='add_steps', key='return_value')[0] }}",
     )
     # [END howto_sensor_emr_step]
 
     # [START howto_operator_emr_terminate_job_flow]
     cluster_remover = EmrTerminateJobFlowOperator(
         task_id='remove_cluster',
-        job_flow_id=job_flow_creator.output,
+        job_flow_id=job_flow_id,
     )
     # [END howto_operator_emr_terminate_job_flow]

airflow/providers/google/cloud/example_dags/example_automl_nl_text_classification.py (4 additions, 2 deletions)

@@ -21,8 +21,10 @@
 """
 import os
 from datetime import datetime
+from typing import cast
 
 from airflow import models
+from airflow.models.xcom_arg import XComArg
 from airflow.providers.google.cloud.hooks.automl import CloudAutoMLHook
 from airflow.providers.google.cloud.operators.automl import (
     AutoMLCreateDatasetOperator,

@@ -67,7 +69,7 @@
         task_id="create_dataset_task", dataset=DATASET, location=GCP_AUTOML_LOCATION
     )
 
-    dataset_id = create_dataset_task.output['dataset_id']
+    dataset_id = cast(str, XComArg(create_dataset_task, key='dataset_id'))
 
     import_dataset_task = AutoMLImportDataOperator(
         task_id="import_dataset_task",

@@ -80,7 +82,7 @@
 
     create_model = AutoMLTrainModelOperator(task_id="create_model", model=MODEL, location=GCP_AUTOML_LOCATION)
 
-    model_id = create_model.output['model_id']
+    model_id = cast(str, XComArg(create_model, key='model_id'))
 
     delete_model_task = AutoMLDeleteModelOperator(
         task_id="delete_model_task",

airflow/providers/google/cloud/example_dags/example_automl_nl_text_sentiment.py (4 additions, 2 deletions)

@@ -21,8 +21,10 @@
 """
 import os
 from datetime import datetime
+from typing import cast
 
 from airflow import models
+from airflow.models.xcom_arg import XComArg
 from airflow.providers.google.cloud.hooks.automl import CloudAutoMLHook
 from airflow.providers.google.cloud.operators.automl import (
     AutoMLCreateDatasetOperator,

@@ -68,7 +70,7 @@
         task_id="create_dataset_task", dataset=DATASET, location=GCP_AUTOML_LOCATION
     )
 
-    dataset_id = create_dataset_task.output['dataset_id']
+    dataset_id = cast(str, XComArg(create_dataset_task, key='dataset_id'))
 
     import_dataset_task = AutoMLImportDataOperator(
         task_id="import_dataset_task",

@@ -81,7 +83,7 @@
 
     create_model = AutoMLTrainModelOperator(task_id="create_model", model=MODEL, location=GCP_AUTOML_LOCATION)
 
-    model_id = create_model.output['model_id']
+    model_id = cast(str, XComArg(create_model, key='model_id'))
 
     delete_model_task = AutoMLDeleteModelOperator(
         task_id="delete_model_task",

airflow/providers/google/cloud/example_dags/example_automl_tables.py (5 additions, 4 deletions)

@@ -22,9 +22,10 @@
 import os
 from copy import deepcopy
 from datetime import datetime
-from typing import Dict, List
+from typing import Dict, List, cast
 
 from airflow import models
+from airflow.models.xcom_arg import XComArg
 from airflow.providers.google.cloud.hooks.automl import CloudAutoMLHook
 from airflow.providers.google.cloud.operators.automl import (
     AutoMLBatchPredictOperator,

@@ -103,7 +104,7 @@ def get_target_column_spec(columns_specs: List[Dict], column_name: str) -> str:
         project_id=GCP_PROJECT_ID,
     )
 
-    dataset_id = create_dataset_task.output['dataset_id']
+    dataset_id = cast(str, XComArg(create_dataset_task, key='dataset_id'))
     # [END howto_operator_automl_create_dataset]
 
     MODEL["dataset_id"] = dataset_id

@@ -158,7 +159,7 @@ def get_target_column_spec(columns_specs: List[Dict], column_name: str) -> str:
         project_id=GCP_PROJECT_ID,
     )
 
-    model_id = create_model_task.output['model_id']
+    model_id = cast(str, XComArg(create_model_task, key='model_id'))
     # [END howto_operator_automl_create_model]
 
     # [START howto_operator_automl_delete_model]

@@ -209,7 +210,7 @@ def get_target_column_spec(columns_specs: List[Dict], column_name: str) -> str:
         project_id=GCP_PROJECT_ID,
    )
 
-    dataset_id = create_dataset_task2.output['dataset_id']
+    dataset_id = cast(str, XComArg(create_dataset_task2, key='dataset_id'))
 
     import_dataset_task = AutoMLImportDataOperator(
         task_id="import_dataset_task",

airflow/providers/google/cloud/example_dags/example_automl_translation.py (4 additions, 2 deletions)

@@ -21,8 +21,10 @@
 """
 import os
 from datetime import datetime
+from typing import cast
 
 from airflow import models
+from airflow.models.xcom_arg import XComArg
 from airflow.providers.google.cloud.hooks.automl import CloudAutoMLHook
 from airflow.providers.google.cloud.operators.automl import (
     AutoMLCreateDatasetOperator,

@@ -74,7 +76,7 @@
         task_id="create_dataset_task", dataset=DATASET, location=GCP_AUTOML_LOCATION
     )
 
-    dataset_id = create_dataset_task.output["dataset_id"]
+    dataset_id = cast(str, XComArg(create_dataset_task, key="dataset_id"))
 
     import_dataset_task = AutoMLImportDataOperator(
         task_id="import_dataset_task",

@@ -87,7 +89,7 @@
 
     create_model = AutoMLTrainModelOperator(task_id="create_model", model=MODEL, location=GCP_AUTOML_LOCATION)
 
-    model_id = create_model.output["model_id"]
+    model_id = cast(str, XComArg(create_model, key="model_id"))
 
     delete_model_task = AutoMLDeleteModelOperator(
         task_id="delete_model_task",
