Skip to content

Commit 2d854c3

Browse files
authored
Add service_account to Google ML Engine operator (#11619)
1 parent ae06ad0 commit 2d854c3

File tree

2 files changed

+15
-0
lines changed

2 files changed

+15
-0
lines changed

β€Žairflow/providers/google/cloud/operators/mlengine.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1115,6 +1115,13 @@ class MLEngineStartTrainingJobOperator(BaseOperator):
11151115
:param job_dir: A Google Cloud Storage path in which to store training
11161116
outputs and other data needed for training. (templated)
11171117
:type job_dir: str
1118+
:param service_account: Optional service account to use when running the training application.
1119+
(templated)
1120+
The specified service account must have the `iam.serviceAccounts.actAs` role. The
1121+
Google-managed Cloud ML Engine service account must have the `iam.serviceAccountAdmin` role
1122+
for the specified service account.
1123+
If set to None or missing, the Google-managed Cloud ML Engine service account will be used.
1124+
:type service_account: str
11181125
:param project_id: The Google Cloud project name within which MLEngine training job should run.
11191126
If set to None or missing, the default project_id from the Google Cloud connection is used.
11201127
(templated)
@@ -1156,6 +1163,7 @@ class MLEngineStartTrainingJobOperator(BaseOperator):
11561163
'_runtime_version',
11571164
'_python_version',
11581165
'_job_dir',
1166+
'_service_account',
11591167
'_impersonation_chain',
11601168
]
11611169

@@ -1176,6 +1184,7 @@ def __init__(
11761184
runtime_version: Optional[str] = None,
11771185
python_version: Optional[str] = None,
11781186
job_dir: Optional[str] = None,
1187+
service_account: Optional[str] = None,
11791188
project_id: Optional[str] = None,
11801189
gcp_conn_id: str = 'google_cloud_default',
11811190
delegate_to: Optional[str] = None,
@@ -1197,6 +1206,7 @@ def __init__(
11971206
self._runtime_version = runtime_version
11981207
self._python_version = python_version
11991208
self._job_dir = job_dir
1209+
self._service_account = service_account
12001210
self._gcp_conn_id = gcp_conn_id
12011211
self._delegate_to = delegate_to
12021212
self._mode = mode
@@ -1244,6 +1254,9 @@ def execute(self, context):
12441254
if self._job_dir:
12451255
training_request['trainingInput']['jobDir'] = self._job_dir
12461256

1257+
if self._service_account:
1258+
training_request['trainingInput']['serviceAccount'] = self._service_account
1259+
12471260
if self._scale_tier is not None and self._scale_tier.upper() == "CUSTOM":
12481261
training_request['trainingInput']['masterType'] = self._master_type
12491262

β€Žtests/providers/google/cloud/operators/test_mlengine.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -413,6 +413,7 @@ def test_success_create_training_job_with_optional_args(self, mock_hook):
413413
training_input['trainingInput']['runtimeVersion'] = '1.6'
414414
training_input['trainingInput']['pythonVersion'] = '3.5'
415415
training_input['trainingInput']['jobDir'] = 'gs://some-bucket/jobs/test_training'
416+
training_input['trainingInput']['serviceAccount'] = 'test@serviceaccount.com'
416417

417418
success_response = self.TRAINING_INPUT.copy()
418419
success_response['state'] = 'SUCCEEDED'
@@ -423,6 +424,7 @@ def test_success_create_training_job_with_optional_args(self, mock_hook):
423424
runtime_version='1.6',
424425
python_version='3.5',
425426
job_dir='gs://some-bucket/jobs/test_training',
427+
service_account='test@serviceaccount.com',
426428
**self.TRAINING_DEFAULT_ARGS,
427429
)
428430
training_op.execute(MagicMock())

0 commit comments

Comments
 (0)