@@ -1080,26 +1080,30 @@ class MLEngineStartTrainingJobOperator(BaseOperator):
1080
1080
:param job_id: A unique templated id for the submitted Google MLEngine
1081
1081
training job. (templated)
1082
1082
:type job_id: str
1083
- :param package_uris: A list of package locations for MLEngine training job,
1084
- which should include the main training program + any additional
1085
- dependencies. (templated)
1086
- :type package_uris: List[str]
1087
- :param training_python_module: The Python module name to run within MLEngine
1088
- training job after installing 'package_uris' packages. (templated)
1089
- :type training_python_module: str
1090
- :param training_args: A list of templated command line arguments to pass to
1091
- the MLEngine training program. (templated)
1092
- :type training_args: List[str]
1093
1083
:param region: The Google Compute Engine region to run the MLEngine training
1094
1084
job in (templated).
1095
1085
:type region: str
1086
+ :param package_uris: A list of Python package locations for the training
1087
+ job, which should include the main training program and any additional
1088
+ dependencies. This is mutually exclusive with a custom image specified
1089
+ via master_config. (templated)
1090
+ :type package_uris: List[str]
1091
+ :param training_python_module: The name of the Python module to run within
1092
+ the training job after installing the packages. This is mutually
1093
+ exclusive with a custom image specified via master_config. (templated)
1094
+ :type training_python_module: str
1095
+ :param training_args: A list of command-line arguments to pass to the
1096
+ training program. (templated)
1097
+ :type training_args: List[str]
1096
1098
:param scale_tier: Resource tier for MLEngine training job. (templated)
1097
1099
:type scale_tier: str
1098
- :param master_type: Cloud ML Engine machine name.
1099
- Must be set when scale_tier is CUSTOM. (templated)
1100
+ :param master_type: The type of virtual machine to use for the master
1101
+ worker. It must be set whenever scale_tier is CUSTOM. (templated)
1100
1102
:type master_type: str
1101
- :param master_config: Cloud ML Engine master config.
1102
- master_type must be set if master_config is provided. (templated)
1103
+ :param master_config: The configuration for the master worker. If this is
1104
+ provided, master_type must be set as well. If a custom image is
1105
+ specified, this is mutually exclusive with package_uris and
1106
+ training_python_module. (templated)
1103
1107
 :type master_config: dict
1104
1108
:param runtime_version: The Google Cloud ML runtime version to use for
1105
1109
training. (templated)
@@ -1147,10 +1151,10 @@ class MLEngineStartTrainingJobOperator(BaseOperator):
1147
1151
template_fields = [
1148
1152
'_project_id' ,
1149
1153
'_job_id' ,
1154
+ '_region' ,
1150
1155
'_package_uris' ,
1151
1156
'_training_python_module' ,
1152
1157
'_training_args' ,
1153
- '_region' ,
1154
1158
'_scale_tier' ,
1155
1159
'_master_type' ,
1156
1160
'_master_config' ,
@@ -1168,10 +1172,10 @@ def __init__(
1168
1172
self , # pylint: disable=too-many-arguments
1169
1173
* ,
1170
1174
job_id : str ,
1171
- package_uris : List [str ],
1172
- training_python_module : str ,
1173
- training_args : List [str ],
1174
1175
region : str ,
1176
+ package_uris : List [str ] = None ,
1177
+ training_python_module : str = None ,
1178
+ training_args : List [str ] = None ,
1175
1179
scale_tier : Optional [str ] = None ,
1176
1180
master_type : Optional [str ] = None ,
1177
1181
master_config : Optional [Dict ] = None ,
@@ -1190,10 +1194,10 @@ def __init__(
1190
1194
super ().__init__ (** kwargs )
1191
1195
self ._project_id = project_id
1192
1196
self ._job_id = job_id
1197
+ self ._region = region
1193
1198
self ._package_uris = package_uris
1194
1199
self ._training_python_module = training_python_module
1195
1200
self ._training_args = training_args
1196
- self ._region = region
1197
1201
self ._scale_tier = scale_tier
1198
1202
self ._master_type = master_type
1199
1203
self ._master_config = master_config
@@ -1207,37 +1211,56 @@ def __init__(
1207
1211
self ._labels = labels
1208
1212
self ._impersonation_chain = impersonation_chain
1209
1213
1214
+ custom = self ._scale_tier is not None and self ._scale_tier .upper () == 'CUSTOM'
1215
+ custom_image = (
1216
+ custom
1217
+ and self ._master_config is not None
1218
+ and self ._master_config .get ('imageUri' , None ) is not None
1219
+ )
1220
+
1210
1221
if not self ._project_id :
1211
1222
raise AirflowException ('Google Cloud project id is required.' )
1212
1223
if not self ._job_id :
1213
1224
raise AirflowException ('An unique job id is required for Google MLEngine training job.' )
1214
- if not package_uris :
1215
- raise AirflowException ('At least one python package is required for MLEngine Training job.' )
1216
- if not training_python_module :
1217
- raise AirflowException (
1218
- 'Python module name to run after installing required packages is required.'
1219
- )
1220
1225
if not self ._region :
1221
1226
raise AirflowException ('Google Compute Engine region is required.' )
1222
- if self . _scale_tier is not None and self . _scale_tier . upper () == "CUSTOM" and not self ._master_type :
1227
+ if custom and not self ._master_type :
1223
1228
raise AirflowException ('master_type must be set when scale_tier is CUSTOM' )
1224
1229
if self ._master_config and not self ._master_type :
1225
1230
raise AirflowException ('master_type must be set when master_config is provided' )
1231
+ if not (package_uris and training_python_module ) and not custom_image :
1232
+ raise AirflowException (
1233
+ 'Either a Python package with a Python module or a custom Docker image should be provided.'
1234
+ )
1235
+ if (package_uris or training_python_module ) and custom_image :
1236
+ raise AirflowException (
1237
+ 'Either a Python package with a Python module or '
1238
+ 'a custom Docker image should be provided but not both.'
1239
+ )
1226
1240
1227
1241
def execute (self , context ):
1228
1242
job_id = _normalize_mlengine_job_id (self ._job_id )
1229
1243
training_request = {
1230
1244
'jobId' : job_id ,
1231
1245
'trainingInput' : {
1232
1246
'scaleTier' : self ._scale_tier ,
1233
- 'packageUris' : self ._package_uris ,
1234
- 'pythonModule' : self ._training_python_module ,
1235
1247
'region' : self ._region ,
1236
- 'args' : self ._training_args ,
1237
1248
},
1238
1249
}
1239
- if self ._labels :
1240
- training_request ['labels' ] = self ._labels
1250
+ if self ._package_uris :
1251
+ training_request ['trainingInput' ]['packageUris' ] = self ._package_uris
1252
+
1253
+ if self ._training_python_module :
1254
+ training_request ['trainingInput' ]['pythonModule' ] = self ._training_python_module
1255
+
1256
+ if self ._training_args :
1257
+ training_request ['trainingInput' ]['args' ] = self ._training_args
1258
+
1259
+ if self ._master_type :
1260
+ training_request ['trainingInput' ]['masterType' ] = self ._master_type
1261
+
1262
+ if self ._master_config :
1263
+ training_request ['trainingInput' ]['masterConfig' ] = self ._master_config
1241
1264
1242
1265
if self ._runtime_version :
1243
1266
training_request ['trainingInput' ]['runtimeVersion' ] = self ._runtime_version
@@ -1251,11 +1274,8 @@ def execute(self, context):
1251
1274
if self ._service_account :
1252
1275
training_request ['trainingInput' ]['serviceAccount' ] = self ._service_account
1253
1276
1254
- if self ._scale_tier is not None and self ._scale_tier .upper () == "CUSTOM" :
1255
- training_request ['trainingInput' ]['masterType' ] = self ._master_type
1256
-
1257
- if self ._master_config :
1258
- training_request ['trainingInput' ]['masterConfig' ] = self ._master_config
1277
+ if self ._labels :
1278
+ training_request ['labels' ] = self ._labels
1259
1279
1260
1280
if self ._mode == 'DRY_RUN' :
1261
1281
self .log .info ('In dry_run mode.' )
0 commit comments