@@ -1115,6 +1115,13 @@ class MLEngineStartTrainingJobOperator(BaseOperator):
1115
1115
:param job_dir: A Google Cloud Storage path in which to store training
1116
1116
outputs and other data needed for training. (templated)
1117
1117
:type job_dir: str
1118
+ :param service_account: Optional service account to use when running the training application.
1119
+ (templated)
1120
+ The specified service account must have the `iam.serviceAccounts.actAs` role. The
1121
+ Google-managed Cloud ML Engine service account must have the `iam.serviceAccountAdmin` role
1122
+ for the specified service account.
1123
+ If set to None or missing, the Google-managed Cloud ML Engine service account will be used.
1124
+ :type service_account: str
1118
1125
:param project_id: The Google Cloud project name within which MLEngine training job should run.
1119
1126
If set to None or missing, the default project_id from the Google Cloud connection is used.
1120
1127
(templated)
@@ -1156,6 +1163,7 @@ class MLEngineStartTrainingJobOperator(BaseOperator):
1156
1163
'_runtime_version' ,
1157
1164
'_python_version' ,
1158
1165
'_job_dir' ,
1166
+ '_service_account' ,
1159
1167
'_impersonation_chain' ,
1160
1168
]
1161
1169
@@ -1176,6 +1184,7 @@ def __init__(
1176
1184
runtime_version : Optional [str ] = None ,
1177
1185
python_version : Optional [str ] = None ,
1178
1186
job_dir : Optional [str ] = None ,
1187
+ service_account : Optional [str ] = None ,
1179
1188
project_id : Optional [str ] = None ,
1180
1189
gcp_conn_id : str = 'google_cloud_default' ,
1181
1190
delegate_to : Optional [str ] = None ,
@@ -1197,6 +1206,7 @@ def __init__(
1197
1206
self ._runtime_version = runtime_version
1198
1207
self ._python_version = python_version
1199
1208
self ._job_dir = job_dir
1209
+ self ._service_account = service_account
1200
1210
self ._gcp_conn_id = gcp_conn_id
1201
1211
self ._delegate_to = delegate_to
1202
1212
self ._mode = mode
@@ -1244,6 +1254,9 @@ def execute(self, context):
1244
1254
if self ._job_dir :
1245
1255
training_request ['trainingInput' ]['jobDir' ] = self ._job_dir
1246
1256
1257
+ if self ._service_account :
1258
+ training_request ['trainingInput' ]['serviceAccount' ] = self ._service_account
1259
+
1247
1260
if self ._scale_tier is not None and self ._scale_tier .upper () == "CUSTOM" :
1248
1261
training_request ['trainingInput' ]['masterType' ] = self ._master_type
1249
1262
0 commit comments