Skip to content

Commit c8af059

Browse files
authored
Improve taskflow type hints with ParamSpec (#25173)
1 parent 5758454 commit c8af059

File tree

8 files changed

+96
-63
lines changed

8 files changed

+96
-63
lines changed

β€Žairflow/decorators/__init__.pyi

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,9 @@
2020
# necessarily exist at run time. See "Creating Custom @task Decorators"
2121
# documentation for more details.
2222

23-
from typing import Any, Dict, Iterable, List, Mapping, Optional, Union, overload
23+
from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Union, overload
2424

25-
from airflow.decorators.base import Function, Task, TaskDecorator
25+
from airflow.decorators.base import FParams, FReturn, Task, TaskDecorator
2626
from airflow.decorators.branch_python import branch_task
2727
from airflow.decorators.python import python_task
2828
from airflow.decorators.python_virtualenv import virtualenv_task
@@ -68,7 +68,7 @@ class TaskDecoratorCollection:
6868
"""
6969
# [START mixin_for_typing]
7070
@overload
71-
def python(self, python_callable: Function) -> Task[Function]: ...
71+
def python(self, python_callable: Callable[FParams, FReturn]) -> Task[FParams, FReturn]: ...
7272
# [END mixin_for_typing]
7373
@overload
7474
def __call__(
@@ -81,7 +81,7 @@ class TaskDecoratorCollection:
8181
) -> TaskDecorator:
8282
"""Aliasing ``python``; signature should match exactly."""
8383
@overload
84-
def __call__(self, python_callable: Function) -> Task[Function]:
84+
def __call__(self, python_callable: Callable[FParams, FReturn]) -> Task[FParams, FReturn]:
8585
"""Aliasing ``python``; signature should match exactly."""
8686
@overload
8787
def virtualenv(
@@ -122,7 +122,7 @@ class TaskDecoratorCollection:
122122
such as transmission a large amount of XCom to TaskAPI.
123123
"""
124124
@overload
125-
def virtualenv(self, python_callable: Function) -> Task[Function]: ...
125+
def virtualenv(self, python_callable: Callable[FParams, FReturn]) -> Task[FParams, FReturn]: ...
126126
@overload
127127
def branch(self, *, multiple_outputs: Optional[bool] = None, **kwargs) -> TaskDecorator:
128128
"""Create a decorator to wrap the decorated callable into a BranchPythonOperator.
@@ -134,7 +134,7 @@ class TaskDecoratorCollection:
134134
Dict will unroll to XCom values with keys as XCom keys. Defaults to False.
135135
"""
136136
@overload
137-
def branch(self, python_callable: Function) -> Task[Function]: ...
137+
def branch(self, python_callable: Callable[FParams, FReturn]) -> Task[FParams, FReturn]: ...
138138
# [START decorator_signature]
139139
def docker(
140140
self,

β€Žairflow/decorators/base.py

Lines changed: 45 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
# specific language governing permissions and limitations
1616
# under the License.
1717

18-
import functools
1918
import inspect
2019
import re
2120
from typing import (
@@ -68,7 +67,7 @@
6867
)
6968
from airflow.models.pool import Pool
7069
from airflow.models.xcom_arg import XComArg
71-
from airflow.typing_compat import Protocol
70+
from airflow.typing_compat import ParamSpec, Protocol
7271
from airflow.utils import timezone
7372
from airflow.utils.context import KNOWN_CONTEXT_KEYS, Context
7473
from airflow.utils.task_group import TaskGroup, TaskGroupContext
@@ -236,13 +235,15 @@ def _hook_apply_defaults(self, *args, **kwargs):
236235
return args, kwargs
237236

238237

239-
Function = TypeVar("Function", bound=Callable)
238+
FParams = ParamSpec("FParams")
239+
240+
FReturn = TypeVar("FReturn")
240241

241242
OperatorSubclass = TypeVar("OperatorSubclass", bound="BaseOperator")
242243

243244

244245
@attr.define(slots=False)
245-
class _TaskDecorator(Generic[Function, OperatorSubclass]):
246+
class _TaskDecorator(Generic[FParams, FReturn, OperatorSubclass]):
246247
"""
247248
Helper class for providing dynamic task mapping to decorated functions.
248249
@@ -251,7 +252,7 @@ class _TaskDecorator(Generic[Function, OperatorSubclass]):
251252
:meta private:
252253
"""
253254

254-
function: Function = attr.ib()
255+
function: Callable[FParams, FReturn] = attr.ib()
255256
operator_class: Type[OperatorSubclass]
256257
multiple_outputs: bool = attr.ib()
257258
kwargs: Dict[str, Any] = attr.ib(factory=dict)
@@ -272,7 +273,7 @@ def __attrs_post_init__(self):
272273
raise TypeError(f"@{self.decorator_name} does not support methods")
273274
self.kwargs.setdefault('task_id', self.function.__name__)
274275

275-
def __call__(self, *args, **kwargs) -> XComArg:
276+
def __call__(self, *args: "FParams.args", **kwargs: "FParams.kwargs") -> XComArg:
276277
op = self.operator_class(
277278
python_callable=self.function,
278279
op_args=args,
@@ -285,7 +286,7 @@ def __call__(self, *args, **kwargs) -> XComArg:
285286
return XComArg(op)
286287

287288
@property
288-
def __wrapped__(self) -> Function:
289+
def __wrapped__(self) -> Callable[FParams, FReturn]:
289290
return self.function
290291

291292
@cached_property
@@ -337,9 +338,7 @@ def expand(self, **map_kwargs: "Mappable") -> XComArg:
337338
# to False to skip the checks on execution.
338339
return self._expand(DictOfListsExpandInput(map_kwargs), strict=False)
339340

340-
def expand_kwargs(self, kwargs: "XComArg", *, strict: bool = True) -> XComArg:
341-
from airflow.models.xcom_arg import XComArg
342-
341+
def expand_kwargs(self, kwargs: XComArg, *, strict: bool = True) -> XComArg:
343342
if not isinstance(kwargs, XComArg):
344343
raise TypeError(f"expected XComArg object, not {type(kwargs).__name__}")
345344
return self._expand(ListOfDictsExpandInput(kwargs), strict=strict)
@@ -420,14 +419,14 @@ def _expand(self, expand_input: ExpandInput, *, strict: bool) -> XComArg:
420419
)
421420
return XComArg(operator=operator)
422421

423-
def partial(self, **kwargs: Any) -> "_TaskDecorator[Function, OperatorSubclass]":
422+
def partial(self, **kwargs: Any) -> "_TaskDecorator[FParams, FReturn, OperatorSubclass]":
424423
self._validate_arg_names("partial", kwargs)
425424
old_kwargs = self.kwargs.get("op_kwargs", {})
426425
prevent_duplicates(old_kwargs, kwargs, fail_reason="duplicate partial")
427426
kwargs.update(old_kwargs)
428427
return attr.evolve(self, kwargs={**self.kwargs, "op_kwargs": kwargs})
429428

430-
def override(self, **kwargs: Any) -> "_TaskDecorator[Function, OperatorSubclass]":
429+
def override(self, **kwargs: Any) -> "_TaskDecorator[FParams, FReturn, OperatorSubclass]":
431430
return attr.evolve(self, kwargs={**self.kwargs, **kwargs})
432431

433432

@@ -506,7 +505,7 @@ def _render_if_not_already_resolved(key: str, value: Any):
506505
return {k: _render_if_not_already_resolved(k, v) for k, v in value.items()}
507506

508507

509-
class Task(Generic[Function]):
508+
class Task(Generic[FParams, FReturn]):
510509
"""Declaration of a @task-decorated callable for type-checking.
511510
512511
An instance of this type inherits the call signature of the decorated
@@ -517,26 +516,32 @@ class Task(Generic[Function]):
517516
This type is implemented by ``_TaskDecorator`` at runtime.
518517
"""
519518

520-
__call__: Function
519+
__call__: Callable[FParams, XComArg]
521520

522-
function: Function
521+
function: Callable[FParams, FReturn]
523522

524523
@property
525-
def __wrapped__(self) -> Function:
524+
def __wrapped__(self) -> Callable[FParams, FReturn]:
525+
...
526+
527+
def partial(self, **kwargs: Any) -> "Task[FParams, FReturn]":
526528
...
527529

528530
def expand(self, **kwargs: "Mappable") -> XComArg:
529531
...
530532

531-
def partial(self, **kwargs: Any) -> "Task[Function]":
533+
def expand_kwargs(self, kwargs: XComArg, *, strict: bool = True) -> XComArg:
532534
...
533535

534536

535537
class TaskDecorator(Protocol):
536538
"""Type declaration for ``task_decorator_factory`` return type."""
537539

538540
@overload
539-
def __call__(self, python_callable: Function) -> Task[Function]:
541+
def __call__( # type: ignore[misc]
542+
self,
543+
python_callable: Callable[FParams, FReturn],
544+
) -> Task[FParams, FReturn]:
540545
"""For the "bare decorator" ``@task`` case."""
541546

542547
@overload
@@ -545,7 +550,7 @@ def __call__(
545550
*,
546551
multiple_outputs: Optional[bool] = None,
547552
**kwargs: Any,
548-
) -> Callable[[Function], Task[Function]]:
553+
) -> Callable[[Callable[FParams, FReturn]], Task[FParams, FReturn]]:
549554
"""For the decorator factory ``@task()`` case."""
550555

551556

@@ -556,16 +561,20 @@ def task_decorator_factory(
556561
decorated_operator_class: Type[BaseOperator],
557562
**kwargs,
558563
) -> TaskDecorator:
559-
"""
560-
A factory that generates a wrapper that wraps a function into an Airflow operator.
561-
Accepts kwargs for operator kwarg. Can be reused in a single DAG.
564+
"""Generate a wrapper that wraps a function into an Airflow operator.
562565
563-
:param python_callable: Function to decorate
564-
:param multiple_outputs: If set to True, the decorated function's return value will be unrolled to
565-
multiple XCom values. Dict will unroll to XCom values with its keys as XCom keys. Defaults to False.
566-
:param decorated_operator_class: The operator that executes the logic needed to run the python function in
567-
the correct environment
566+
Can be reused in a single DAG.
568567
568+
:param python_callable: Function to decorate.
569+
:param multiple_outputs: If set to True, the decorated function's return
570+
value will be unrolled to multiple XCom values. Dict will unroll to XCom
571+
values with its keys as XCom keys. If set to False (default), only at
572+
most one XCom value is pushed.
573+
:param decorated_operator_class: The operator that executes the logic needed
574+
to run the python function in the correct environment.
575+
576+
Other kwargs are directly forwarded to the underlying operator class when
577+
it's instantiated.
569578
"""
570579
if multiple_outputs is None:
571580
multiple_outputs = cast(bool, attr.NOTHING)
@@ -579,10 +588,13 @@ def task_decorator_factory(
579588
return cast(TaskDecorator, decorator)
580589
elif python_callable is not None:
581590
raise TypeError('No args allowed while using @task, use kwargs instead')
582-
decorator_factory = functools.partial(
583-
_TaskDecorator,
584-
multiple_outputs=multiple_outputs,
585-
operator_class=decorated_operator_class,
586-
kwargs=kwargs,
587-
)
591+
592+
def decorator_factory(python_callable):
593+
return _TaskDecorator(
594+
function=python_callable,
595+
multiple_outputs=multiple_outputs,
596+
operator_class=decorated_operator_class,
597+
kwargs=kwargs,
598+
)
599+
588600
return cast(TaskDecorator, decorator_factory)

β€Žairflow/providers/cncf/kubernetes/operators/kubernetes_pod.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -328,7 +328,7 @@ def _get_ti_pod_labels(
328328
if include_try_number:
329329
labels.update(try_number=ti.try_number)
330330
# In the case of sub dags this is just useful
331-
if context['dag'].is_subdag:
331+
if context['dag'].parent_dag:
332332
labels['parent_dag_id'] = context['dag'].parent_dag.dag_id
333333
# Ensure that label is valid for Kube,
334334
# and if not truncate/remove invalid chars and replace with short hash.

β€Žairflow/providers/dbt/cloud/hooks/dbt.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ def __call__(self, request: PreparedRequest) -> PreparedRequest:
8787
class JobRunInfo(TypedDict):
8888
"""Type class for the ``job_run_info`` dictionary."""
8989

90-
account_id: int
90+
account_id: Optional[int]
9191
run_id: int
9292

9393

β€Žairflow/providers/google/cloud/operators/gcs.py

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,6 @@
3030

3131
from google.api_core.exceptions import Conflict
3232
from google.cloud.exceptions import GoogleCloudError
33-
from pendulum.datetime import DateTime
3433

3534
from airflow.exceptions import AirflowException
3635
from airflow.models import BaseOperator
@@ -723,22 +722,25 @@ def __init__(
723722
def execute(self, context: "Context") -> List[str]:
724723
# Define intervals and prefixes.
725724
try:
726-
timespan_start = context["data_interval_start"]
727-
timespan_end = context["data_interval_end"]
725+
orig_start = context["data_interval_start"]
726+
orig_end = context["data_interval_end"]
728727
except KeyError:
729-
timespan_start = pendulum.instance(context["execution_date"])
728+
orig_start = pendulum.instance(context["execution_date"])
730729
following_execution_date = context["dag"].following_schedule(context["execution_date"])
731730
if following_execution_date is None:
732-
timespan_end = None
731+
orig_end = None
733732
else:
734-
timespan_end = pendulum.instance(following_execution_date)
735-
736-
if timespan_end is None: # Only possible in Airflow before 2.2.
737-
self.log.warning("No following schedule found, setting timespan end to max %s", timespan_end)
738-
timespan_end = DateTime.max
739-
elif timespan_start >= timespan_end: # Airflow 2.2 sets start == end for non-perodic schedules.
740-
self.log.warning("DAG schedule not periodic, setting timespan end to max %s", timespan_end)
741-
timespan_end = DateTime.max
733+
orig_end = pendulum.instance(following_execution_date)
734+
735+
timespan_start = orig_start
736+
if orig_end is None: # Only possible in Airflow before 2.2.
737+
self.log.warning("No following schedule found, setting timespan end to max %s", orig_end)
738+
timespan_end = pendulum.instance(datetime.datetime.max)
739+
elif orig_start >= orig_end: # Airflow 2.2 sets start == end for non-perodic schedules.
740+
self.log.warning("DAG schedule not periodic, setting timespan end to max %s", orig_end)
741+
timespan_end = pendulum.instance(datetime.datetime.max)
742+
else:
743+
timespan_end = orig_end
742744

743745
timespan_start = timespan_start.in_timezone(timezone.utc)
744746
timespan_end = timespan_end.in_timezone(timezone.utc)

β€Žairflow/providers/qubole/hooks/qubole.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@
4646
from airflow.utils.state import State
4747

4848
if TYPE_CHECKING:
49+
from airflow.models.taskinstance import TaskInstance
4950
from airflow.utils.context import Context
5051

5152

@@ -139,7 +140,7 @@ def __init__(self, *args, **kwargs) -> None:
139140
self.kwargs = kwargs
140141
self.cls = COMMAND_CLASSES[self.kwargs['command_type']]
141142
self.cmd: Optional[Command] = None
142-
self.task_instance = None
143+
self.task_instance: Optional["TaskInstance"] = None
143144

144145
@staticmethod
145146
def handle_failure_retry(context) -> None:

β€Žairflow/providers/salesforce/operators/bulk.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ class SalesforceBulkOperator(BaseOperator):
4747
def __init__(
4848
self,
4949
*,
50-
operation: Literal[available_operations],
50+
operation: Literal['insert', 'update', 'upsert', 'delete', 'hard_delete'],
5151
object_name: str,
5252
payload: list,
5353
external_id_field: str = 'Id',

β€Žairflow/typing_compat.py

Lines changed: 25 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,28 @@
2121
codebase easier.
2222
"""
2323

24-
try:
25-
# Literal, Protocol and TypedDict are only added to typing module starting from
26-
# python 3.8 we can safely remove this shim import after Airflow drops
27-
# support for <3.8
28-
from typing import Literal, Protocol, TypedDict, runtime_checkable # type: ignore
29-
except ImportError:
30-
from typing_extensions import Literal, Protocol, TypedDict, runtime_checkable # type: ignore # noqa
24+
__all__ = [
25+
"Literal",
26+
"ParamSpec",
27+
"Protocol",
28+
"TypedDict",
29+
"runtime_checkable",
30+
]
31+
32+
import sys
33+
34+
if sys.version_info >= (3, 8):
35+
from typing import Protocol, TypedDict, runtime_checkable
36+
else:
37+
from typing_extensions import Protocol, TypedDict, runtime_checkable
38+
39+
# Literal in 3.8 is limited to one single argument, not e.g. "Literal[1, 2]".
40+
if sys.version_info >= (3, 9):
41+
from typing import Literal
42+
else:
43+
from typing_extensions import Literal
44+
45+
if sys.version_info >= (3, 10):
46+
from typing import ParamSpec
47+
else:
48+
from typing_extensions import ParamSpec

0 commit comments

Comments
 (0)