Skip to content

Commit 08575dd

Browse files
authored
Change BaseOperatorLink interface to take a ti_key, not a datetime (#21798)
1 parent 5befc7f commit 08575dd

File tree

36 files changed

+508
-513
lines changed

36 files changed

+508
-513
lines changed

UPDATING.md

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,16 @@ This setting is also used for the deprecated experimental API, which only uses t
163163

164164
To allow the Airflow UI to use the API, the previous default authorization backend `airflow.api.auth.backend.deny_all` is changed to `airflow.api.auth.backend.session`, and this is automatically added to the list of API authorization backends if a non-default value is set.
165165

166+
### BaseOperatorLink's `get_link` method changed to take a `ti_key` keyword argument
167+
168+
In v2.2 we "deprecated" passing an execution date to XCom.get methods, but there was no other option for operator links as they were only passed an execution_date.
169+
170+
Now in 2.3 as part of Dynamic Task Mapping (AIP-42) we will need to add map_index to the XCom row to support the "reduce" part of the API.
171+
172+
In order to support that cleanly we have changed the interface for BaseOperatorLink to take a `TaskInstanceKey` as the `ti_key` keyword argument (as execution_date + task is no longer unique for mapped operators).
173+
174+
The existing signature will be detected (by the absence of the `ti_key` argument) and continue to work.
175+
166176
## Airflow 2.2.4
167177

168178
### Smart sensors deprecated

airflow/api_connexion/endpoints/extra_link_endpoint.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@
2424
from airflow.api_connexion.types import APIResponse
2525
from airflow.exceptions import TaskNotFound
2626
from airflow.models.dagbag import DagBag
27-
from airflow.models.dagrun import DagRun as DR
2827
from airflow.security import permissions
2928
from airflow.utils.session import NEW_SESSION, provide_session
3029

@@ -45,6 +44,8 @@ def get_extra_links(
4544
session: Session = NEW_SESSION,
4645
) -> APIResponse:
4746
"""Get extra links for task instance"""
47+
from airflow.models.taskinstance import TaskInstance
48+
4849
dagbag: DagBag = current_app.dag_bag
4950
dag: DAG = dagbag.get_dag(dag_id)
5051
if not dag:
@@ -55,14 +56,21 @@ def get_extra_links(
5556
except TaskNotFound:
5657
raise NotFound("Task not found", detail=f'Task with ID = "{task_id}" not found')
5758

58-
execution_date = (
59-
session.query(DR.execution_date).filter(DR.dag_id == dag_id).filter(DR.run_id == dag_run_id).scalar()
59+
ti = (
60+
session.query(TaskInstance)
61+
.filter(
62+
TaskInstance.dag_id == dag_id,
63+
TaskInstance.run_id == dag_run_id,
64+
TaskInstance.task_id == task_id,
65+
)
66+
.one_or_none()
6067
)
61-
if not execution_date:
68+
69+
if not ti:
6270
raise NotFound("DAG Run not found", detail=f'DAG Run with ID = "{dag_run_id}" not found')
6371

6472
all_extra_link_pairs = (
65-
(link_name, task.get_extra_links(execution_date, link_name)) for link_name in task.extra_links
73+
(link_name, task.get_extra_links(ti, link_name)) for link_name in task.extra_links
6674
)
6775
all_extra_links = {
6876
link_name: link_url if link_url else None for link_name, link_url in all_extra_link_pairs

airflow/models/abstractoperator.py

Lines changed: 30 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
# under the License.
1818

1919
import datetime
20+
import inspect
2021
from typing import TYPE_CHECKING, Any, Callable, Collection, Dict, Iterable, List, Optional, Set, Type, Union
2122

2223
from sqlalchemy.orm import Session
@@ -32,23 +33,26 @@
3233
from airflow.utils.trigger_rule import TriggerRule
3334
from airflow.utils.weight_rule import WeightRule
3435

36+
TaskStateChangeCallback = Callable[[Context], None]
37+
3538
if TYPE_CHECKING:
3639
import jinja2 # Slow import.
3740

3841
from airflow.models.baseoperator import BaseOperator, BaseOperatorLink
3942
from airflow.models.dag import DAG
4043
from airflow.models.operator import Operator
44+
from airflow.models.taskinstance import TaskInstance
4145

42-
DEFAULT_OWNER = conf.get("operators", "default_owner")
43-
DEFAULT_POOL_SLOTS = 1
44-
DEFAULT_PRIORITY_WEIGHT = 1
45-
DEFAULT_QUEUE = conf.get("operators", "default_queue")
46-
DEFAULT_RETRIES = conf.getint("core", "default_task_retries", fallback=0)
47-
DEFAULT_RETRY_DELAY = datetime.timedelta(seconds=300)
48-
DEFAULT_WEIGHT_RULE = conf.get("core", "default_task_weight_rule", fallback=WeightRule.DOWNSTREAM)
49-
DEFAULT_TRIGGER_RULE = TriggerRule.ALL_SUCCESS
50-
51-
TaskStateChangeCallback = Callable[[Context], None]
46+
DEFAULT_OWNER: str = conf.get("operators", "default_owner")
47+
DEFAULT_POOL_SLOTS: int = 1
48+
DEFAULT_PRIORITY_WEIGHT: int = 1
49+
DEFAULT_QUEUE: str = conf.get("operators", "default_queue")
50+
DEFAULT_RETRIES: int = conf.getint("core", "default_task_retries", fallback=0)
51+
DEFAULT_RETRY_DELAY: datetime.timedelta = datetime.timedelta(seconds=300)
52+
DEFAULT_WEIGHT_RULE: WeightRule = WeightRule(
53+
conf.get("core", "default_task_weight_rule", fallback=WeightRule.DOWNSTREAM)
54+
)
55+
DEFAULT_TRIGGER_RULE: TriggerRule = TriggerRule.ALL_SUCCESS
5256

5357

5458
class AbstractOperator(LoggingMixin, DAGNode):
@@ -239,19 +243,29 @@ def global_operator_extra_link_dict(self) -> Dict[str, Any]:
239243
def extra_links(self) -> List[str]:
240244
return list(set(self.operator_extra_link_dict).union(self.global_operator_extra_link_dict))
241245

242-
def get_extra_links(self, dttm: datetime.datetime, link_name: str) -> Optional[Dict[str, Any]]:
246+
def get_extra_links(self, ti: "TaskInstance", link_name: str) -> Optional[str]:
243247
"""For an operator, gets the URLs that the ``extra_links`` entry points to.
244248
249+
:meta private:
250+
245251
:raise ValueError: The error message of a ValueError will be passed on through to
246252
the frontend to show up as a tooltip on the disabled link.
247-
:param dttm: The datetime parsed execution date for the URL being searched for.
253+
:param ti: The TaskInstance for the URL being searched for.
248254
:param link_name: The name of the link we're looking for the URL for. Should be
249255
one of the options specified in ``extra_links``.
250256
"""
251-
if link_name in self.operator_extra_link_dict:
252-
return self.operator_extra_link_dict[link_name].get_link(self, dttm)
253-
elif link_name in self.global_operator_extra_link_dict:
254-
return self.global_operator_extra_link_dict[link_name].get_link(self, dttm)
257+
link: Optional["BaseOperatorLink"] = self.operator_extra_link_dict.get(link_name)
258+
if not link:
259+
link = self.global_operator_extra_link_dict.get(link_name)
260+
if not link:
261+
return None
262+
# Check for old function signature
263+
parameters = inspect.signature(link.get_link).parameters
264+
args = [name for name, p in parameters.items() if p.kind != p.VAR_KEYWORD]
265+
if "ti_key" in args:
266+
return link.get_link(self, ti_key=ti.key)
267+
else:
268+
return link.get_link(self, ti.dag_run.logical_date) # type: ignore[misc]
255269
return None
256270

257271
def render_template_fields(

airflow/models/baseoperator.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,7 @@
9393
import jinja2 # Slow import.
9494

9595
from airflow.models.dag import DAG
96+
from airflow.models.taskinstance import TaskInstanceKey
9697
from airflow.utils.task_group import TaskGroup
9798

9899
ScheduleInterval = Union[str, timedelta, relativedelta]
@@ -1730,11 +1731,14 @@ def name(self) -> str:
17301731
"""
17311732

17321733
@abstractmethod
1733-
def get_link(self, operator: BaseOperator, dttm: datetime) -> str:
1734+
def get_link(self, operator: AbstractOperator, *, ti_key: "TaskInstanceKey") -> str:
17341735
"""
17351736
Link to external system.
17361737
1738+
Note: The old signature of this function was ``(self, operator, dttm: datetime)``. That is still
1739+
supported at runtime but is deprecated.
1740+
17371741
:param operator: airflow operator
1738-
:param dttm: datetime
1742+
:param ti_key: TaskInstance ID to return link for
17391743
:return: link to external system
17401744
"""

airflow/models/xcom.py

Lines changed: 36 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,9 @@
5050
# run without storing it in the database.
5151
IN_MEMORY_DAGRUN_ID = "__airflow_in_memory_dagrun__"
5252

53+
if TYPE_CHECKING:
54+
from airflow.models.taskinstance import TaskInstanceKey
55+
5356

5457
class BaseXCom(Base, LoggingMixin):
5558
"""Base class for XCom objects."""
@@ -205,11 +208,8 @@ def set(
205208
def get_one(
206209
cls,
207210
*,
208-
run_id: str,
209211
key: Optional[str] = None,
210-
task_id: Optional[str] = None,
211-
dag_id: Optional[str] = None,
212-
include_prior_dates: bool = False,
212+
ti_key: "TaskInstanceKey",
213213
session: Session = NEW_SESSION,
214214
) -> Optional[Any]:
215215
"""Retrieve an XCom value, optionally meeting certain criteria.
@@ -223,20 +223,29 @@ def get_one(
223223
A deprecated form of this function accepts ``execution_date`` instead of
224224
``run_id``. The two arguments are mutually exclusive.
225225
226-
:param run_id: DAG run ID for the task.
226+
:param ti_key: The TaskInstanceKey to look up the XCom for
227227
:param key: A key for the XCom. If provided, only XCom with matching
228228
keys will be returned. Pass *None* (default) to remove the filter.
229-
:param task_id: Only XCom from task with matching ID will be pulled.
230-
Pass *None* (default) to remove the filter.
231-
:param dag_id: Only pull XCom from this DAG. If *None* (default), the
232-
DAG of the calling task is used.
233229
:param include_prior_dates: If *False* (default), only XCom from the
234230
specified DAG run is returned. If *True*, the latest matching XCom is
235231
returned regardless of the run it belongs to.
236232
:param session: Database session. If not given, a new session will be
237233
created for this function.
238234
"""
239235

236+
@overload
237+
@classmethod
238+
def get_one(
239+
cls,
240+
*,
241+
key: Optional[str] = None,
242+
task_id: str,
243+
dag_id: str,
244+
run_id: str,
245+
session: Session = NEW_SESSION,
246+
) -> Optional[Any]:
247+
...
248+
240249
@overload
241250
@classmethod
242251
def get_one(
@@ -256,24 +265,35 @@ def get_one(
256265
cls,
257266
execution_date: Optional[datetime.datetime] = None,
258267
key: Optional[str] = None,
259-
task_id: Optional[Union[str, Iterable[str]]] = None,
260-
dag_id: Optional[Union[str, Iterable[str]]] = None,
268+
task_id: Optional[str] = None,
269+
dag_id: Optional[str] = None,
261270
include_prior_dates: bool = False,
262271
session: Session = NEW_SESSION,
263272
*,
264273
run_id: Optional[str] = None,
274+
ti_key: Optional["TaskInstanceKey"] = None,
265275
) -> Optional[Any]:
266276
""":sphinx-autoapi-skip:"""
267-
if not exactly_one(execution_date is not None, run_id is not None):
268-
raise ValueError("Exactly one of run_id or execution_date must be passed")
269-
270-
if run_id is not None:
277+
if not exactly_one(execution_date is not None, ti_key is not None, run_id is not None):
278+
raise ValueError("Exactly one of ti_key, run_id, or execution_date must be passed")
279+
280+
if ti_key is not None:
281+
query = session.query(cls).filter_by(
282+
dag_id=ti_key.dag_id,
283+
run_id=ti_key.run_id,
284+
task_id=ti_key.task_id,
285+
)
286+
if key:
287+
query = query.filter_by(key=key)
288+
query = query.limit(1)
289+
elif run_id:
271290
query = cls.get_many(
272291
run_id=run_id,
273292
key=key,
274293
task_ids=task_id,
275294
dag_ids=dag_id,
276295
include_prior_dates=include_prior_dates,
296+
limit=1,
277297
session=session,
278298
)
279299
elif execution_date is not None:
@@ -288,6 +308,7 @@ def get_one(
288308
task_ids=task_id,
289309
dag_ids=dag_id,
290310
include_prior_dates=include_prior_dates,
311+
limit=1,
291312
session=session,
292313
)
293314
else:

airflow/operators/trigger_dagrun.py

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
import datetime
2020
import json
2121
import time
22-
from typing import Dict, List, Optional, Sequence, Union
22+
from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Union, cast
2323

2424
from airflow.api.common.trigger_dag import trigger_dag
2525
from airflow.exceptions import AirflowException, DagNotFound, DagRunAlreadyExists
@@ -35,6 +35,11 @@
3535
XCOM_RUN_ID = "trigger_run_id"
3636

3737

38+
if TYPE_CHECKING:
39+
from airflow.models.abstractoperator import AbstractOperator
40+
from airflow.models.taskinstance import TaskInstanceKey
41+
42+
3843
class TriggerDagRunLink(BaseOperatorLink):
3944
"""
4045
Operator link for TriggerDagRunOperator. It allows users to access
@@ -43,14 +48,16 @@ class TriggerDagRunLink(BaseOperatorLink):
4348

4449
name = 'Triggered DAG'
4550

46-
def get_link(self, operator, dttm):
51+
def get_link(
52+
self,
53+
operator: "AbstractOperator",
54+
*,
55+
ti_key: "TaskInstanceKey",
56+
) -> str:
4757
# Fetch the correct execution date for the triggerED dag which is
4858
# stored in xcom during execution of the triggerING task.
49-
trigger_execution_date_iso = XCom.get_one(
50-
execution_date=dttm, key=XCOM_EXECUTION_DATE_ISO, task_id=operator.task_id, dag_id=operator.dag_id
51-
)
52-
53-
query = {"dag_id": operator.trigger_dag_id, "base_date": trigger_execution_date_iso}
59+
when = XCom.get_one(ti_key=ti_key, key=XCOM_EXECUTION_DATE_ISO)
60+
query = {"dag_id": cast(TriggerDagRunOperator, operator).trigger_dag_id, "base_date": when}
5461
return build_airflow_url_with_query(query)
5562

5663

airflow/providers/amazon/aws/operators/emr.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from airflow.providers.amazon.aws.hooks.emr import EmrHook
2727

2828
if TYPE_CHECKING:
29+
from airflow.models.taskinstance import TaskInstanceKey
2930
from airflow.utils.context import Context
3031

3132

@@ -230,17 +231,26 @@ class EmrClusterLink(BaseOperatorLink):
230231

231232
name = 'EMR Cluster'
232233

233-
def get_link(self, operator: BaseOperator, dttm: datetime) -> str:
234+
def get_link(
235+
self,
236+
operator,
237+
dttm: Optional[datetime] = None,
238+
ti_key: Optional["TaskInstanceKey"] = None,
239+
) -> str:
234240
"""
235241
Get link to EMR cluster.
236242
237243
:param operator: operator
238244
:param dttm: datetime
239245
:return: url link
240246
"""
241-
flow_id = XCom.get_one(
242-
key="return_value", dag_id=operator.dag.dag_id, task_id=operator.task_id, execution_date=dttm
243-
)
247+
if ti_key:
248+
flow_id = XCom.get_one(key="return_value", ti_key=ti_key)
249+
else:
250+
assert dttm
251+
flow_id = XCom.get_one(
252+
key="return_value", dag_id=operator.dag_id, task_id=operator.task_id, execution_date=dttm
253+
)
244254
return (
245255
f'https://console.aws.amazon.com/elasticmapreduce/home#cluster-details:{flow_id}'
246256
if flow_id

airflow/providers/dbt/cloud/operators/dbt.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -33,10 +33,14 @@ class DbtCloudRunJobOperatorLink(BaseOperatorLink):
3333

3434
name = "Monitor Job Run"
3535

36-
def get_link(self, operator, dttm):
37-
job_run_url = XCom.get_one(
38-
dag_id=operator.dag.dag_id, task_id=operator.task_id, execution_date=dttm, key="job_run_url"
39-
)
36+
def get_link(self, operator, dttm=None, *, ti_key=None):
37+
if ti_key:
38+
job_run_url = XCom.get_one(key="job_run_url", ti_key=ti_key)
39+
else:
40+
assert dttm
41+
job_run_url = XCom.get_one(
42+
dag_id=operator.dag.dag_id, task_id=operator.task_id, execution_date=dttm, key="job_run_url"
43+
)
4044

4145
return job_run_url
4246

0 commit comments

Comments
 (0)