Skip to content

Commit 426a798

Browse files
authored
Imrove support for laatest API in MLEngineStartTrainingJobOperator (#7812)
1 parent cdf1809 commit 426a798

File tree

1 file changed

+10
-2
lines changed

1 file changed

+10
-2
lines changed

β€Žairflow/providers/google/cloud/operators/mlengine.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1027,8 +1027,16 @@ def execute(self, context):
10271027
# Helper method to check if the existing job's training input is the
10281028
# same as the request we get here.
10291029
def check_existing_job(existing_job):
1030-
return existing_job.get('trainingInput', None) == \
1031-
training_request['trainingInput']
1030+
existing_training_input = existing_job.get('trainingInput', None)
1031+
requested_training_input = training_request['trainingInput']
1032+
if 'scaleTier' not in existing_training_input:
1033+
existing_training_input['scaleTier'] = None
1034+
1035+
existing_training_input['args'] = existing_training_input.get('args', None)
1036+
requested_training_input["args"] = requested_training_input['args'] \
1037+
if requested_training_input["args"] else None
1038+
1039+
return existing_training_input == requested_training_input
10321040

10331041
finished_training_job = hook.create_job(
10341042
project_id=self._project_id, job=training_request, use_existing_job_fn=check_existing_job

0 commit comments

Comments
 (0)