@@ -24,31 +24,35 @@
 import json
 import os
 import re
+from typing import Callable, Dict, Iterable, List, Optional, Tuple, TypeVar
 from urllib.parse import urlsplit
 
 import dill
 
+from airflow import DAG
 from airflow.exceptions import AirflowException
 from airflow.operators.python import PythonOperator
 from airflow.providers.google.cloud.hooks.gcs import GCSHook
 from airflow.providers.google.cloud.operators.dataflow import DataflowCreatePythonJobOperator
 from airflow.providers.google.cloud.operators.mlengine import MLEngineStartBatchPredictionJobOperator
 
-
-def create_evaluate_ops(task_prefix,  # pylint: disable=too-many-arguments
-                        data_format,
-                        input_paths,
-                        prediction_path,
-                        metric_fn_and_keys,
-                        validate_fn,
-                        batch_prediction_job_id=None,
-                        project_id=None,
-                        region=None,
-                        dataflow_options=None,
-                        model_uri=None,
-                        model_name=None,
-                        version_name=None,
-                        dag=None,
+T = TypeVar("T", bound=Callable)  # pylint: disable=invalid-name
+
+
+def create_evaluate_ops(task_prefix: str,  # pylint: disable=too-many-arguments
+                        data_format: str,
+                        input_paths: List[str],
+                        prediction_path: str,
+                        metric_fn_and_keys: Tuple[T, Iterable[str]],
+                        validate_fn: T,
+                        batch_prediction_job_id: Optional[str] = None,
+                        region: Optional[str] = None,
+                        project_id: Optional[str] = None,
+                        dataflow_options: Optional[Dict] = None,
+                        model_uri: Optional[str] = None,
+                        model_name: Optional[str] = None,
+                        version_name: Optional[str] = None,
+                        dag: Optional[DAG] = None,
                         py_interpreter="python3"):
     """
     Creates Operators needed for model evaluation and returns.
@@ -186,6 +190,9 @@ def validate_err_and_count(summary):
     :rtype: tuple(DataFlowPythonOperator, DataFlowPythonOperator,
         PythonOperator)
     """
+    batch_prediction_job_id = batch_prediction_job_id or ""
+    dataflow_options = dataflow_options or {}
+    region = region or ""
 
     # Verify that task_prefix doesn't have any special characters except hyphen
     # '-', which is the only allowed non-alphanumeric character by Dataflow.
@@ -203,7 +210,7 @@ def validate_err_and_count(summary):
     if dag is not None and dag.default_args is not None:
         default_args = dag.default_args
         project_id = project_id or default_args.get('project_id')
-        region = region or default_args.get('region')
+        region = region or default_args['region']
         model_name = model_name or default_args.get('model_name')
         version_name = version_name or default_args.get('version_name')
         dataflow_options = dataflow_options or \
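
For context, below is a minimal sketch of a caller exercising the annotated signature above. Only the create_evaluate_ops signature, the returned tuple(DataFlowPythonOperator, DataFlowPythonOperator, PythonOperator), and the default_args fallbacks come from this diff; the import path, DAG settings, GCS paths, and the metric/validation callables are illustrative assumptions.

# Usage sketch only: the module path and every concrete value below are
# assumptions, not part of this change.
from datetime import datetime

from airflow import DAG
from airflow.providers.google.cloud.utils.mlengine_operator_utils import create_evaluate_ops


def metric_fn(instance):
    # Hypothetical per-instance metric; returns one value per metric key.
    return (abs(instance["prediction"] - instance["label"]),)


def validate_fn(summary):
    # Hypothetical check of the aggregated metrics; raising fails the task.
    if summary["err"] > 0.5:
        raise ValueError("err too high: %s" % summary["err"])
    return summary


with DAG(
    "example_mlengine_evaluate",
    start_date=datetime(2020, 1, 1),
    schedule_interval=None,
    # 'region' must be present here when it is not passed explicitly:
    # the last hunk switches the fallback from default_args.get('region')
    # to default_args['region'].
    default_args={
        "project_id": "my-gcp-project",
        "region": "us-central1",
        "model_name": "my_model",
        "version_name": "v1",
    },
) as dag:
    evaluate_prediction, evaluate_summary, evaluate_validation = create_evaluate_ops(
        task_prefix="eval",  # alphanumerics and '-' only, per the Dataflow check
        data_format="TEXT",
        input_paths=["gs://my-bucket/eval/input-*"],
        prediction_path="gs://my-bucket/eval/output",
        metric_fn_and_keys=(metric_fn, ["err"]),  # matches Tuple[T, Iterable[str]]
        validate_fn=validate_fn,
        dag=dag,  # project_id, region, model/version fall back to default_args
    )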