Skip to content

Commit f6518dd

Browse files
authored
Generalize MLEngineStartTrainingJobOperator to custom images (#13318)
1 parent 6e1a6ff commit f6518dd

File tree

2 files changed

+103
-37
lines changed

2 files changed

+103
-37
lines changed

β€Žairflow/providers/google/cloud/operators/mlengine.py

Lines changed: 56 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -1080,26 +1080,30 @@ class MLEngineStartTrainingJobOperator(BaseOperator):
10801080
:param job_id: A unique templated id for the submitted Google MLEngine
10811081
training job. (templated)
10821082
:type job_id: str
1083-
:param package_uris: A list of package locations for MLEngine training job,
1084-
which should include the main training program + any additional
1085-
dependencies. (templated)
1086-
:type package_uris: List[str]
1087-
:param training_python_module: The Python module name to run within MLEngine
1088-
training job after installing 'package_uris' packages. (templated)
1089-
:type training_python_module: str
1090-
:param training_args: A list of templated command line arguments to pass to
1091-
the MLEngine training program. (templated)
1092-
:type training_args: List[str]
10931083
:param region: The Google Compute Engine region to run the MLEngine training
10941084
job in (templated).
10951085
:type region: str
1086+
:param package_uris: A list of Python package locations for the training
1087+
job, which should include the main training program and any additional
1088+
dependencies. This is mutually exclusive with a custom image specified
1089+
via master_config. (templated)
1090+
:type package_uris: List[str]
1091+
:param training_python_module: The name of the Python module to run within
1092+
the training job after installing the packages. This is mutually
1093+
exclusive with a custom image specified via master_config. (templated)
1094+
:type training_python_module: str
1095+
:param training_args: A list of command-line arguments to pass to the
1096+
training program. (templated)
1097+
:type training_args: List[str]
10961098
:param scale_tier: Resource tier for MLEngine training job. (templated)
10971099
:type scale_tier: str
1098-
:param master_type: Cloud ML Engine machine name.
1099-
Must be set when scale_tier is CUSTOM. (templated)
1100+
:param master_type: The type of virtual machine to use for the master
1101+
worker. It must be set whenever scale_tier is CUSTOM. (templated)
11001102
:type master_type: str
1101-
:param master_config: Cloud ML Engine master config.
1102-
master_type must be set if master_config is provided. (templated)
1103+
:param master_config: The configuration for the master worker. If this is
1104+
provided, master_type must be set as well. If a custom image is
1105+
specified, this is mutually exclusive with package_uris and
1106+
training_python_module. (templated)
11031107
:type master_type: dict
11041108
:param runtime_version: The Google Cloud ML runtime version to use for
11051109
training. (templated)
@@ -1147,10 +1151,10 @@ class MLEngineStartTrainingJobOperator(BaseOperator):
11471151
template_fields = [
11481152
'_project_id',
11491153
'_job_id',
1154+
'_region',
11501155
'_package_uris',
11511156
'_training_python_module',
11521157
'_training_args',
1153-
'_region',
11541158
'_scale_tier',
11551159
'_master_type',
11561160
'_master_config',
@@ -1168,10 +1172,10 @@ def __init__(
11681172
self, # pylint: disable=too-many-arguments
11691173
*,
11701174
job_id: str,
1171-
package_uris: List[str],
1172-
training_python_module: str,
1173-
training_args: List[str],
11741175
region: str,
1176+
package_uris: List[str] = None,
1177+
training_python_module: str = None,
1178+
training_args: List[str] = None,
11751179
scale_tier: Optional[str] = None,
11761180
master_type: Optional[str] = None,
11771181
master_config: Optional[Dict] = None,
@@ -1190,10 +1194,10 @@ def __init__(
11901194
super().__init__(**kwargs)
11911195
self._project_id = project_id
11921196
self._job_id = job_id
1197+
self._region = region
11931198
self._package_uris = package_uris
11941199
self._training_python_module = training_python_module
11951200
self._training_args = training_args
1196-
self._region = region
11971201
self._scale_tier = scale_tier
11981202
self._master_type = master_type
11991203
self._master_config = master_config
@@ -1207,37 +1211,56 @@ def __init__(
12071211
self._labels = labels
12081212
self._impersonation_chain = impersonation_chain
12091213

1214+
custom = self._scale_tier is not None and self._scale_tier.upper() == 'CUSTOM'
1215+
custom_image = (
1216+
custom
1217+
and self._master_config is not None
1218+
and self._master_config.get('imageUri', None) is not None
1219+
)
1220+
12101221
if not self._project_id:
12111222
raise AirflowException('Google Cloud project id is required.')
12121223
if not self._job_id:
12131224
raise AirflowException('An unique job id is required for Google MLEngine training job.')
1214-
if not package_uris:
1215-
raise AirflowException('At least one python package is required for MLEngine Training job.')
1216-
if not training_python_module:
1217-
raise AirflowException(
1218-
'Python module name to run after installing required packages is required.'
1219-
)
12201225
if not self._region:
12211226
raise AirflowException('Google Compute Engine region is required.')
1222-
if self._scale_tier is not None and self._scale_tier.upper() == "CUSTOM" and not self._master_type:
1227+
if custom and not self._master_type:
12231228
raise AirflowException('master_type must be set when scale_tier is CUSTOM')
12241229
if self._master_config and not self._master_type:
12251230
raise AirflowException('master_type must be set when master_config is provided')
1231+
if not (package_uris and training_python_module) and not custom_image:
1232+
raise AirflowException(
1233+
'Either a Python package with a Python module or a custom Docker image should be provided.'
1234+
)
1235+
if (package_uris or training_python_module) and custom_image:
1236+
raise AirflowException(
1237+
'Either a Python package with a Python module or '
1238+
'a custom Docker image should be provided but not both.'
1239+
)
12261240

12271241
def execute(self, context):
12281242
job_id = _normalize_mlengine_job_id(self._job_id)
12291243
training_request = {
12301244
'jobId': job_id,
12311245
'trainingInput': {
12321246
'scaleTier': self._scale_tier,
1233-
'packageUris': self._package_uris,
1234-
'pythonModule': self._training_python_module,
12351247
'region': self._region,
1236-
'args': self._training_args,
12371248
},
12381249
}
1239-
if self._labels:
1240-
training_request['labels'] = self._labels
1250+
if self._package_uris:
1251+
training_request['trainingInput']['packageUris'] = self._package_uris
1252+
1253+
if self._training_python_module:
1254+
training_request['trainingInput']['pythonModule'] = self._training_python_module
1255+
1256+
if self._training_args:
1257+
training_request['trainingInput']['args'] = self._training_args
1258+
1259+
if self._master_type:
1260+
training_request['trainingInput']['masterType'] = self._master_type
1261+
1262+
if self._master_config:
1263+
training_request['trainingInput']['masterConfig'] = self._master_config
12411264

12421265
if self._runtime_version:
12431266
training_request['trainingInput']['runtimeVersion'] = self._runtime_version
@@ -1251,11 +1274,8 @@ def execute(self, context):
12511274
if self._service_account:
12521275
training_request['trainingInput']['serviceAccount'] = self._service_account
12531276

1254-
if self._scale_tier is not None and self._scale_tier.upper() == "CUSTOM":
1255-
training_request['trainingInput']['masterType'] = self._master_type
1256-
1257-
if self._master_config:
1258-
training_request['trainingInput']['masterConfig'] = self._master_config
1277+
if self._labels:
1278+
training_request['labels'] = self._labels
12591279

12601280
if self._mode == 'DRY_RUN':
12611281
self.log.info('In dry_run mode.')

β€Žtests/providers/google/cloud/operators/test_mlengine.py

Lines changed: 47 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -315,7 +315,7 @@ def test_failed_job_error(self, mock_hook):
315315
self.assertEqual('A failure message', str(context.exception))
316316

317317

318-
class TestMLEngineTrainingOperator(unittest.TestCase):
318+
class TestMLEngineStartTrainingJobOperator(unittest.TestCase):
319319
TRAINING_DEFAULT_ARGS = {
320320
'project_id': 'test-project',
321321
'job_id': 'test_training',
@@ -407,6 +407,52 @@ def test_success_create_training_job_with_master_config(self, mock_hook):
407407
project_id='test-project', job=training_input, use_existing_job_fn=ANY
408408
)
409409

410+
@patch('airflow.providers.google.cloud.operators.mlengine.MLEngineHook')
411+
def test_success_create_training_job_with_master_image(self, hook):
412+
arguments = {
413+
'project_id': 'test-project',
414+
'job_id': 'test_training',
415+
'region': 'europe-west1',
416+
'scale_tier': 'CUSTOM',
417+
'master_type': 'n1-standard-8',
418+
'master_config': {
419+
'imageUri': 'eu.gcr.io/test-project/test-image:test-version',
420+
},
421+
'task_id': 'test-training',
422+
'start_date': days_ago(1),
423+
}
424+
request = {
425+
'jobId': 'test_training',
426+
'trainingInput': {
427+
'region': 'europe-west1',
428+
'scaleTier': 'CUSTOM',
429+
'masterType': 'n1-standard-8',
430+
'masterConfig': {
431+
'imageUri': 'eu.gcr.io/test-project/test-image:test-version',
432+
},
433+
},
434+
}
435+
436+
response = request.copy()
437+
response['state'] = 'SUCCEEDED'
438+
hook_instance = hook.return_value
439+
hook_instance.create_job.return_value = response
440+
441+
operator = MLEngineStartTrainingJobOperator(**arguments)
442+
operator.execute(MagicMock())
443+
444+
hook.assert_called_once_with(
445+
gcp_conn_id='google_cloud_default',
446+
delegate_to=None,
447+
impersonation_chain=None,
448+
)
449+
self.assertEqual(len(hook_instance.mock_calls), 1)
450+
hook_instance.create_job.assert_called_once_with(
451+
project_id='test-project',
452+
job=request,
453+
use_existing_job_fn=ANY,
454+
)
455+
410456
@patch('airflow.providers.google.cloud.operators.mlengine.MLEngineHook')
411457
def test_success_create_training_job_with_optional_args(self, mock_hook):
412458
training_input = copy.deepcopy(self.TRAINING_INPUT)

0 commit comments

Comments
 (0)