15
15
# specific language governing permissions and limitations
16
16
# under the License.
17
17
18
- import functools
19
18
import inspect
20
19
import re
21
20
from typing import (
68
67
)
69
68
from airflow .models .pool import Pool
70
69
from airflow .models .xcom_arg import XComArg
71
- from airflow .typing_compat import Protocol
70
+ from airflow .typing_compat import ParamSpec , Protocol
72
71
from airflow .utils import timezone
73
72
from airflow .utils .context import KNOWN_CONTEXT_KEYS , Context
74
73
from airflow .utils .task_group import TaskGroup , TaskGroupContext
@@ -236,13 +235,15 @@ def _hook_apply_defaults(self, *args, **kwargs):
236
235
return args , kwargs
237
236
238
237
239
- Function = TypeVar ("Function" , bound = Callable )
238
+ FParams = ParamSpec ("FParams" )
239
+
240
+ FReturn = TypeVar ("FReturn" )
240
241
241
242
OperatorSubclass = TypeVar ("OperatorSubclass" , bound = "BaseOperator" )
242
243
243
244
244
245
@attr .define (slots = False )
245
- class _TaskDecorator (Generic [Function , OperatorSubclass ]):
246
+ class _TaskDecorator (Generic [FParams , FReturn , OperatorSubclass ]):
246
247
"""
247
248
Helper class for providing dynamic task mapping to decorated functions.
248
249
@@ -251,7 +252,7 @@ class _TaskDecorator(Generic[Function, OperatorSubclass]):
251
252
:meta private:
252
253
"""
253
254
254
- function : Function = attr .ib ()
255
+ function : Callable [ FParams , FReturn ] = attr .ib ()
255
256
operator_class : Type [OperatorSubclass ]
256
257
multiple_outputs : bool = attr .ib ()
257
258
kwargs : Dict [str , Any ] = attr .ib (factory = dict )
@@ -272,7 +273,7 @@ def __attrs_post_init__(self):
272
273
raise TypeError (f"@{ self .decorator_name } does not support methods" )
273
274
self .kwargs .setdefault ('task_id' , self .function .__name__ )
274
275
275
- def __call__ (self , * args , ** kwargs ) -> XComArg :
276
+ def __call__ (self , * args : "FParams.args" , ** kwargs : "FParams.kwargs" ) -> XComArg :
276
277
op = self .operator_class (
277
278
python_callable = self .function ,
278
279
op_args = args ,
@@ -285,7 +286,7 @@ def __call__(self, *args, **kwargs) -> XComArg:
285
286
return XComArg (op )
286
287
287
288
@property
288
- def __wrapped__ (self ) -> Function :
289
+ def __wrapped__ (self ) -> Callable [ FParams , FReturn ] :
289
290
return self .function
290
291
291
292
@cached_property
@@ -337,9 +338,7 @@ def expand(self, **map_kwargs: "Mappable") -> XComArg:
337
338
# to False to skip the checks on execution.
338
339
return self ._expand (DictOfListsExpandInput (map_kwargs ), strict = False )
339
340
340
- def expand_kwargs (self , kwargs : "XComArg" , * , strict : bool = True ) -> XComArg :
341
- from airflow .models .xcom_arg import XComArg
342
-
341
+ def expand_kwargs (self , kwargs : XComArg , * , strict : bool = True ) -> XComArg :
343
342
if not isinstance (kwargs , XComArg ):
344
343
raise TypeError (f"expected XComArg object, not { type (kwargs ).__name__ } " )
345
344
return self ._expand (ListOfDictsExpandInput (kwargs ), strict = strict )
@@ -420,14 +419,14 @@ def _expand(self, expand_input: ExpandInput, *, strict: bool) -> XComArg:
420
419
)
421
420
return XComArg (operator = operator )
422
421
423
- def partial (self , ** kwargs : Any ) -> "_TaskDecorator[Function , OperatorSubclass]" :
422
+ def partial (self , ** kwargs : Any ) -> "_TaskDecorator[FParams, FReturn , OperatorSubclass]" :
424
423
self ._validate_arg_names ("partial" , kwargs )
425
424
old_kwargs = self .kwargs .get ("op_kwargs" , {})
426
425
prevent_duplicates (old_kwargs , kwargs , fail_reason = "duplicate partial" )
427
426
kwargs .update (old_kwargs )
428
427
return attr .evolve (self , kwargs = {** self .kwargs , "op_kwargs" : kwargs })
429
428
430
- def override (self , ** kwargs : Any ) -> "_TaskDecorator[Function , OperatorSubclass]" :
429
+ def override (self , ** kwargs : Any ) -> "_TaskDecorator[FParams, FReturn , OperatorSubclass]" :
431
430
return attr .evolve (self , kwargs = {** self .kwargs , ** kwargs })
432
431
433
432
@@ -506,7 +505,7 @@ def _render_if_not_already_resolved(key: str, value: Any):
506
505
return {k : _render_if_not_already_resolved (k , v ) for k , v in value .items ()}
507
506
508
507
509
- class Task (Generic [Function ]):
508
+ class Task (Generic [FParams , FReturn ]):
510
509
"""Declaration of a @task-decorated callable for type-checking.
511
510
512
511
An instance of this type inherits the call signature of the decorated
@@ -517,26 +516,32 @@ class Task(Generic[Function]):
517
516
This type is implemented by ``_TaskDecorator`` at runtime.
518
517
"""
519
518
520
- __call__ : Function
519
+ __call__ : Callable [ FParams , XComArg ]
521
520
522
- function : Function
521
+ function : Callable [ FParams , FReturn ]
523
522
524
523
@property
525
- def __wrapped__ (self ) -> Function :
524
+ def __wrapped__ (self ) -> Callable [FParams , FReturn ]:
525
+ ...
526
+
527
+ def partial (self , ** kwargs : Any ) -> "Task[FParams, FReturn]" :
526
528
...
527
529
528
530
def expand (self , ** kwargs : "Mappable" ) -> XComArg :
529
531
...
530
532
531
- def partial (self , ** kwargs : Any ) -> "Task[Function]" :
533
+ def expand_kwargs (self , kwargs : XComArg , * , strict : bool = True ) -> XComArg :
532
534
...
533
535
534
536
535
537
class TaskDecorator (Protocol ):
536
538
"""Type declaration for ``task_decorator_factory`` return type."""
537
539
538
540
@overload
539
- def __call__ (self , python_callable : Function ) -> Task [Function ]:
541
+ def __call__ ( # type: ignore[misc]
542
+ self ,
543
+ python_callable : Callable [FParams , FReturn ],
544
+ ) -> Task [FParams , FReturn ]:
540
545
"""For the "bare decorator" ``@task`` case."""
541
546
542
547
@overload
@@ -545,7 +550,7 @@ def __call__(
545
550
* ,
546
551
multiple_outputs : Optional [bool ] = None ,
547
552
** kwargs : Any ,
548
- ) -> Callable [[Function ] , Task [Function ]]:
553
+ ) -> Callable [[Callable [ FParams , FReturn ]] , Task [FParams , FReturn ]]:
549
554
"""For the decorator factory ``@task()`` case."""
550
555
551
556
@@ -556,16 +561,20 @@ def task_decorator_factory(
556
561
decorated_operator_class : Type [BaseOperator ],
557
562
** kwargs ,
558
563
) -> TaskDecorator :
559
- """
560
- A factory that generates a wrapper that wraps a function into an Airflow operator.
561
- Accepts kwargs for operator kwarg. Can be reused in a single DAG.
564
+ """Generate a wrapper that wraps a function into an Airflow operator.
562
565
563
- :param python_callable: Function to decorate
564
- :param multiple_outputs: If set to True, the decorated function's return value will be unrolled to
565
- multiple XCom values. Dict will unroll to XCom values with its keys as XCom keys. Defaults to False.
566
- :param decorated_operator_class: The operator that executes the logic needed to run the python function in
567
- the correct environment
566
+ Can be reused in a single DAG.
568
567
568
+ :param python_callable: Function to decorate.
569
+ :param multiple_outputs: If set to True, the decorated function's return
570
+ value will be unrolled to multiple XCom values. Dict will unroll to XCom
571
+ values with its keys as XCom keys. If set to False (default), only at
572
+ most one XCom value is pushed.
573
+ :param decorated_operator_class: The operator that executes the logic needed
574
+ to run the python function in the correct environment.
575
+
576
+ Other kwargs are directly forwarded to the underlying operator class when
577
+ it's instantiated.
569
578
"""
570
579
if multiple_outputs is None :
571
580
multiple_outputs = cast (bool , attr .NOTHING )
@@ -579,10 +588,13 @@ def task_decorator_factory(
579
588
return cast (TaskDecorator , decorator )
580
589
elif python_callable is not None :
581
590
raise TypeError ('No args allowed while using @task, use kwargs instead' )
582
- decorator_factory = functools .partial (
583
- _TaskDecorator ,
584
- multiple_outputs = multiple_outputs ,
585
- operator_class = decorated_operator_class ,
586
- kwargs = kwargs ,
587
- )
591
+
592
+ def decorator_factory (python_callable ):
593
+ return _TaskDecorator (
594
+ function = python_callable ,
595
+ multiple_outputs = multiple_outputs ,
596
+ operator_class = decorated_operator_class ,
597
+ kwargs = kwargs ,
598
+ )
599
+
588
600
return cast (TaskDecorator , decorator_factory )
0 commit comments