Skip to content

Commit 755fe52

Browse files
authored
[AIRFLOW-6915] Add AI Platform Console Link for MLEngineStartTrainingJobOperator (#7535)
1 parent 5bddf60 commit 755fe52

File tree

3 files changed

+114
-10
lines changed

3 files changed

+114
-10
lines changed

airflow/providers/google/cloud/operators/mlengine.py

Lines changed: 29 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -24,7 +24,8 @@
2424
from typing import List, Optional
2525

2626
from airflow.exceptions import AirflowException
27-
from airflow.models import BaseOperator
27+
from airflow.models import BaseOperator, BaseOperatorLink
28+
from airflow.models.taskinstance import TaskInstance
2829
from airflow.providers.google.cloud.hooks.mlengine import MLEngineHook
2930
from airflow.utils.decorators import apply_defaults
3031

@@ -852,6 +853,23 @@ def execute(self, context):
852853
)
853854

854855

856+
class AIPlatformConsoleLink(BaseOperatorLink):
857+
"""
858+
Helper class for constructing AI Platform Console link.
859+
"""
860+
name = "AI Platform Console"
861+
862+
def get_link(self, operator, dttm):
863+
task_instance = TaskInstance(task=operator, execution_date=dttm)
864+
gcp_metadata_dict = task_instance.xcom_pull(task_ids=operator.task_id, key="gcp_metadata")
865+
if not gcp_metadata_dict:
866+
return ''
867+
job_id = gcp_metadata_dict['job_id']
868+
project_id = gcp_metadata_dict['project_id']
869+
console_link = f"https://console.cloud.google.com/ai-platform/jobs/{job_id}?project={project_id}"
870+
return console_link
871+
872+
855873
class MLEngineStartTrainingJobOperator(BaseOperator):
856874
"""
857875
Operator for launching a MLEngine training job.
@@ -915,6 +933,10 @@ class MLEngineStartTrainingJobOperator(BaseOperator):
915933
'_job_dir'
916934
]
917935

936+
operator_extra_links = (
937+
AIPlatformConsoleLink(),
938+
)
939+
918940
@apply_defaults
919941
def __init__(self, # pylint: disable=too-many-arguments
920942
job_id: str,
@@ -1016,6 +1038,12 @@ def check_existing_job(existing_job):
10161038
self.log.error('MLEngine training job failed: %s', str(finished_training_job))
10171039
raise RuntimeError(finished_training_job['errorMessage'])
10181040

1041+
gcp_metadata = {
1042+
"job_id": job_id,
1043+
"project_id": self._project_id,
1044+
}
1045+
context['task_instance'].xcom_push("gcp_metadata", gcp_metadata)
1046+
10191047

10201048
class MLEngineTrainingJobFailureOperator(BaseOperator):
10211049

airflow/serialization/serialized_objects.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242
BUILTIN_OPERATOR_EXTRA_LINKS: List[str] = [
4343
"airflow.providers.google.cloud.operators.bigquery.BigQueryConsoleLink",
4444
"airflow.providers.google.cloud.operators.bigquery.BigQueryConsoleIndexableLink",
45+
"airflow.providers.google.cloud.operators.mlengine.AIPlatformConsoleLink",
4546
"airflow.providers.qubole.operators.qubole.QDSLink"
4647
]
4748

tests/providers/google/cloud/operators/test_mlengine.py

Lines changed: 84 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -18,23 +18,27 @@
1818
import copy
1919
import datetime
2020
import unittest
21-
from unittest.mock import ANY, patch
21+
from unittest.mock import ANY, MagicMock, patch
2222

2323
import httplib2
2424
from googleapiclient.errors import HttpError
2525

2626
from airflow.exceptions import AirflowException
27+
from airflow.models import TaskInstance
2728
from airflow.models.dag import DAG
2829
from airflow.providers.google.cloud.operators.mlengine import (
29-
MLEngineCreateModelOperator, MLEngineCreateVersionOperator, MLEngineDeleteModelOperator,
30-
MLEngineDeleteVersionOperator, MLEngineGetModelOperator, MLEngineListVersionsOperator,
31-
MLEngineManageModelOperator, MLEngineManageVersionOperator, MLEngineSetDefaultVersionOperator,
32-
MLEngineStartBatchPredictionJobOperator, MLEngineStartTrainingJobOperator,
33-
MLEngineTrainingJobFailureOperator,
30+
AIPlatformConsoleLink, MLEngineCreateModelOperator, MLEngineCreateVersionOperator,
31+
MLEngineDeleteModelOperator, MLEngineDeleteVersionOperator, MLEngineGetModelOperator,
32+
MLEngineListVersionsOperator, MLEngineManageModelOperator, MLEngineManageVersionOperator,
33+
MLEngineSetDefaultVersionOperator, MLEngineStartBatchPredictionJobOperator,
34+
MLEngineStartTrainingJobOperator, MLEngineTrainingJobFailureOperator,
3435
)
36+
from airflow.serialization.serialized_objects import SerializedDAG
37+
from airflow.utils.dates import days_ago
3538

3639
DEFAULT_DATE = datetime.datetime(2017, 6, 6)
3740

41+
TEST_DAG_ID = "test-mlengine-operators"
3842
TEST_PROJECT_ID = "test-project-id"
3943
TEST_MODEL_NAME = "test-model-name"
4044
TEST_VERSION_NAME = "test-version"
@@ -304,7 +308,8 @@ class TestMLEngineTrainingOperator(unittest.TestCase):
304308
'training_args': '--some_arg=\'aaa\'',
305309
'region': 'us-east1',
306310
'scale_tier': 'STANDARD_1',
307-
'task_id': 'test-training'
311+
'task_id': 'test-training',
312+
'start_date': days_ago(1)
308313
}
309314
TRAINING_INPUT = {
310315
'jobId': 'test_training',
@@ -317,6 +322,9 @@ class TestMLEngineTrainingOperator(unittest.TestCase):
317322
}
318323
}
319324

325+
def setUp(self):
326+
self.dag = DAG(TEST_DAG_ID, default_args=self.TRAINING_DEFAULT_ARGS)
327+
320328
@patch('airflow.providers.google.cloud.operators.mlengine.MLEngineHook')
321329
def test_success_create_training_job(self, mock_hook):
322330
success_response = self.TRAINING_INPUT.copy()
@@ -326,7 +334,7 @@ def test_success_create_training_job(self, mock_hook):
326334

327335
training_op = MLEngineStartTrainingJobOperator(
328336
**self.TRAINING_DEFAULT_ARGS)
329-
training_op.execute(None)
337+
training_op.execute(MagicMock())
330338

331339
mock_hook.assert_called_once_with(
332340
gcp_conn_id='google_cloud_default', delegate_to=None)
@@ -352,7 +360,7 @@ def test_success_create_training_job_with_optional_args(self, mock_hook):
352360
python_version='3.5',
353361
job_dir='gs://some-bucket/jobs/test_training',
354362
**self.TRAINING_DEFAULT_ARGS)
355-
training_op.execute(None)
363+
training_op.execute(MagicMock())
356364

357365
mock_hook.assert_called_once_with(gcp_conn_id='google_cloud_default', delegate_to=None)
358366
# Make sure only 'create_job' is invoked on hook instance
@@ -404,6 +412,73 @@ def test_failed_job_error(self, mock_hook):
404412
project_id='test-project', job=self.TRAINING_INPUT, use_existing_job_fn=ANY)
405413
self.assertEqual('A failure message', str(context.exception))
406414

415+
@patch('airflow.providers.google.cloud.operators.mlengine.MLEngineHook')
416+
def test_console_extra_link(self, mock_hook):
417+
training_op = MLEngineStartTrainingJobOperator(
418+
**self.TRAINING_DEFAULT_ARGS)
419+
420+
ti = TaskInstance(
421+
task=training_op,
422+
execution_date=DEFAULT_DATE,
423+
)
424+
425+
job_id = self.TRAINING_DEFAULT_ARGS['job_id']
426+
project_id = self.TRAINING_DEFAULT_ARGS['project_id']
427+
gcp_metadata = {
428+
"job_id": job_id,
429+
"project_id": project_id,
430+
}
431+
ti.xcom_push(key='gcp_metadata', value=gcp_metadata)
432+
433+
self.assertEqual(
434+
f"https://console.cloud.google.com/ai-platform/jobs/{job_id}?project={project_id}",
435+
training_op.get_extra_links(DEFAULT_DATE, AIPlatformConsoleLink.name),
436+
)
437+
438+
self.assertEqual(
439+
'',
440+
training_op.get_extra_links(datetime.datetime(2019, 1, 1), AIPlatformConsoleLink.name),
441+
)
442+
443+
def test_console_extra_link_serialized_field(self):
444+
with self.dag:
445+
training_op = MLEngineStartTrainingJobOperator(**self.TRAINING_DEFAULT_ARGS)
446+
serialized_dag = SerializedDAG.to_dict(self.dag)
447+
dag = SerializedDAG.from_dict(serialized_dag)
448+
simple_task = dag.task_dict[self.TRAINING_DEFAULT_ARGS['task_id']]
449+
450+
# Check Serialized version of operator link
451+
self.assertEqual(
452+
serialized_dag["dag"]["tasks"][0]["_operator_extra_links"],
453+
[{"airflow.providers.google.cloud.operators.mlengine.AIPlatformConsoleLink": {}}]
454+
)
455+
456+
# Check DeSerialized version of operator link
457+
self.assertIsInstance(list(simple_task.operator_extra_links)[0], AIPlatformConsoleLink)
458+
459+
job_id = self.TRAINING_DEFAULT_ARGS['job_id']
460+
project_id = self.TRAINING_DEFAULT_ARGS['project_id']
461+
gcp_metadata = {
462+
"job_id": job_id,
463+
"project_id": project_id,
464+
}
465+
466+
ti = TaskInstance(
467+
task=training_op,
468+
execution_date=DEFAULT_DATE,
469+
)
470+
ti.xcom_push(key='gcp_metadata', value=gcp_metadata)
471+
472+
self.assertEqual(
473+
f"https://console.cloud.google.com/ai-platform/jobs/{job_id}?project={project_id}",
474+
simple_task.get_extra_links(DEFAULT_DATE, AIPlatformConsoleLink.name),
475+
)
476+
477+
self.assertEqual(
478+
'',
479+
simple_task.get_extra_links(datetime.datetime(2019, 1, 1), AIPlatformConsoleLink.name),
480+
)
481+
407482

408483
class TestMLEngineTrainingJobFailureOperator(unittest.TestCase):
409484

0 commit comments

Comments
 (0)