Skip to content

Commit e195a98

Browse files
authored
Add type annotations for mlengine_operator_utils (#10297)
Add type annotations, including a few changes to ensure that values of the right types are passed through. Specifically, if `region` is not given as an argument, it must be provided in the DAG's `default_args`.
1 parent 382c101 commit e195a98

File tree

1 file changed

+23
-16
lines changed

1 file changed

+23
-16
lines changed

airflow/providers/google/cloud/utils/mlengine_operator_utils.py

Lines changed: 23 additions & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -24,31 +24,35 @@
2424
import json
2525
import os
2626
import re
27+
from typing import Callable, Dict, Iterable, List, Optional, Tuple, TypeVar
2728
from urllib.parse import urlsplit
2829

2930
import dill
3031

32+
from airflow import DAG
3133
from airflow.exceptions import AirflowException
3234
from airflow.operators.python import PythonOperator
3335
from airflow.providers.google.cloud.hooks.gcs import GCSHook
3436
from airflow.providers.google.cloud.operators.dataflow import DataflowCreatePythonJobOperator
3537
from airflow.providers.google.cloud.operators.mlengine import MLEngineStartBatchPredictionJobOperator
3638

37-
38-
def create_evaluate_ops(task_prefix, # pylint: disable=too-many-arguments
39-
data_format,
40-
input_paths,
41-
prediction_path,
42-
metric_fn_and_keys,
43-
validate_fn,
44-
batch_prediction_job_id=None,
45-
project_id=None,
46-
region=None,
47-
dataflow_options=None,
48-
model_uri=None,
49-
model_name=None,
50-
version_name=None,
51-
dag=None,
39+
T = TypeVar("T", bound=Callable) # pylint: disable=invalid-name
40+
41+
42+
def create_evaluate_ops(task_prefix: str, # pylint: disable=too-many-arguments
43+
data_format: str,
44+
input_paths: List[str],
45+
prediction_path: str,
46+
metric_fn_and_keys: Tuple[T, Iterable[str]],
47+
validate_fn: T,
48+
batch_prediction_job_id: Optional[str] = None,
49+
region: Optional[str] = None,
50+
project_id: Optional[str] = None,
51+
dataflow_options: Optional[Dict] = None,
52+
model_uri: Optional[str] = None,
53+
model_name: Optional[str] = None,
54+
version_name: Optional[str] = None,
55+
dag: Optional[DAG] = None,
5256
py_interpreter="python3"):
5357
"""
5458
Creates Operators needed for model evaluation and returns.
@@ -186,6 +190,9 @@ def validate_err_and_count(summary):
186190
:rtype: tuple(DataFlowPythonOperator, DataFlowPythonOperator,
187191
PythonOperator)
188192
"""
193+
batch_prediction_job_id = batch_prediction_job_id or ""
194+
dataflow_options = dataflow_options or {}
195+
region = region or ""
189196

190197
# Verify that task_prefix doesn't have any special characters except hyphen
191198
# '-', which is the only allowed non-alphanumeric character by Dataflow.
@@ -203,7 +210,7 @@ def validate_err_and_count(summary):
203210
if dag is not None and dag.default_args is not None:
204211
default_args = dag.default_args
205212
project_id = project_id or default_args.get('project_id')
206-
region = region or default_args.get('region')
213+
region = region or default_args['region']
207214
model_name = model_name or default_args.get('model_name')
208215
version_name = version_name or default_args.get('version_name')
209216
dataflow_options = dataflow_options or \

0 commit comments

Comments
 (0)