18
18
import copy
19
19
import datetime
20
20
import unittest
21
- from unittest .mock import ANY , patch
21
+ from unittest .mock import ANY , MagicMock , patch
22
22
23
23
import httplib2
24
24
from googleapiclient .errors import HttpError
25
25
26
26
from airflow .exceptions import AirflowException
27
+ from airflow .models import TaskInstance
27
28
from airflow .models .dag import DAG
28
29
from airflow .providers .google .cloud .operators .mlengine import (
29
- MLEngineCreateModelOperator , MLEngineCreateVersionOperator , MLEngineDeleteModelOperator ,
30
- MLEngineDeleteVersionOperator , MLEngineGetModelOperator , MLEngineListVersionsOperator ,
31
- MLEngineManageModelOperator , MLEngineManageVersionOperator , MLEngineSetDefaultVersionOperator ,
32
- MLEngineStartBatchPredictionJobOperator , MLEngineStartTrainingJobOperator ,
33
- MLEngineTrainingJobFailureOperator ,
30
+ AIPlatformConsoleLink , MLEngineCreateModelOperator , MLEngineCreateVersionOperator ,
31
+ MLEngineDeleteModelOperator , MLEngineDeleteVersionOperator , MLEngineGetModelOperator ,
32
+ MLEngineListVersionsOperator , MLEngineManageModelOperator , MLEngineManageVersionOperator ,
33
+ MLEngineSetDefaultVersionOperator , MLEngineStartBatchPredictionJobOperator ,
34
+ MLEngineStartTrainingJobOperator , MLEngineTrainingJobFailureOperator ,
34
35
)
36
+ from airflow .serialization .serialized_objects import SerializedDAG
37
+ from airflow .utils .dates import days_ago
35
38
36
39
DEFAULT_DATE = datetime .datetime (2017 , 6 , 6 )
37
40
41
+ TEST_DAG_ID = "test-mlengine-operators"
38
42
TEST_PROJECT_ID = "test-project-id"
39
43
TEST_MODEL_NAME = "test-model-name"
40
44
TEST_VERSION_NAME = "test-version"
@@ -304,7 +308,8 @@ class TestMLEngineTrainingOperator(unittest.TestCase):
304
308
'training_args' : '--some_arg=\' aaa\' ' ,
305
309
'region' : 'us-east1' ,
306
310
'scale_tier' : 'STANDARD_1' ,
307
- 'task_id' : 'test-training'
311
+ 'task_id' : 'test-training' ,
312
+ 'start_date' : days_ago (1 )
308
313
}
309
314
TRAINING_INPUT = {
310
315
'jobId' : 'test_training' ,
@@ -317,6 +322,9 @@ class TestMLEngineTrainingOperator(unittest.TestCase):
317
322
}
318
323
}
319
324
325
+ def setUp (self ):
326
+ self .dag = DAG (TEST_DAG_ID , default_args = self .TRAINING_DEFAULT_ARGS )
327
+
320
328
@patch ('airflow.providers.google.cloud.operators.mlengine.MLEngineHook' )
321
329
def test_success_create_training_job (self , mock_hook ):
322
330
success_response = self .TRAINING_INPUT .copy ()
@@ -326,7 +334,7 @@ def test_success_create_training_job(self, mock_hook):
326
334
327
335
training_op = MLEngineStartTrainingJobOperator (
328
336
** self .TRAINING_DEFAULT_ARGS )
329
- training_op .execute (None )
337
+ training_op .execute (MagicMock () )
330
338
331
339
mock_hook .assert_called_once_with (
332
340
gcp_conn_id = 'google_cloud_default' , delegate_to = None )
@@ -352,7 +360,7 @@ def test_success_create_training_job_with_optional_args(self, mock_hook):
352
360
python_version = '3.5' ,
353
361
job_dir = 'gs://some-bucket/jobs/test_training' ,
354
362
** self .TRAINING_DEFAULT_ARGS )
355
- training_op .execute (None )
363
+ training_op .execute (MagicMock () )
356
364
357
365
mock_hook .assert_called_once_with (gcp_conn_id = 'google_cloud_default' , delegate_to = None )
358
366
# Make sure only 'create_job' is invoked on hook instance
@@ -404,6 +412,73 @@ def test_failed_job_error(self, mock_hook):
404
412
project_id = 'test-project' , job = self .TRAINING_INPUT , use_existing_job_fn = ANY )
405
413
self .assertEqual ('A failure message' , str (context .exception ))
406
414
415
+ @patch ('airflow.providers.google.cloud.operators.mlengine.MLEngineHook' )
416
+ def test_console_extra_link (self , mock_hook ):
417
+ training_op = MLEngineStartTrainingJobOperator (
418
+ ** self .TRAINING_DEFAULT_ARGS )
419
+
420
+ ti = TaskInstance (
421
+ task = training_op ,
422
+ execution_date = DEFAULT_DATE ,
423
+ )
424
+
425
+ job_id = self .TRAINING_DEFAULT_ARGS ['job_id' ]
426
+ project_id = self .TRAINING_DEFAULT_ARGS ['project_id' ]
427
+ gcp_metadata = {
428
+ "job_id" : job_id ,
429
+ "project_id" : project_id ,
430
+ }
431
+ ti .xcom_push (key = 'gcp_metadata' , value = gcp_metadata )
432
+
433
+ self .assertEqual (
434
+ f"https://console.cloud.google.com/ai-platform/jobs/{ job_id } ?project={ project_id } " ,
435
+ training_op .get_extra_links (DEFAULT_DATE , AIPlatformConsoleLink .name ),
436
+ )
437
+
438
+ self .assertEqual (
439
+ '' ,
440
+ training_op .get_extra_links (datetime .datetime (2019 , 1 , 1 ), AIPlatformConsoleLink .name ),
441
+ )
442
+
443
+ def test_console_extra_link_serialized_field (self ):
444
+ with self .dag :
445
+ training_op = MLEngineStartTrainingJobOperator (** self .TRAINING_DEFAULT_ARGS )
446
+ serialized_dag = SerializedDAG .to_dict (self .dag )
447
+ dag = SerializedDAG .from_dict (serialized_dag )
448
+ simple_task = dag .task_dict [self .TRAINING_DEFAULT_ARGS ['task_id' ]]
449
+
450
+ # Check Serialized version of operator link
451
+ self .assertEqual (
452
+ serialized_dag ["dag" ]["tasks" ][0 ]["_operator_extra_links" ],
453
+ [{"airflow.providers.google.cloud.operators.mlengine.AIPlatformConsoleLink" : {}}]
454
+ )
455
+
456
+ # Check DeSerialized version of operator link
457
+ self .assertIsInstance (list (simple_task .operator_extra_links )[0 ], AIPlatformConsoleLink )
458
+
459
+ job_id = self .TRAINING_DEFAULT_ARGS ['job_id' ]
460
+ project_id = self .TRAINING_DEFAULT_ARGS ['project_id' ]
461
+ gcp_metadata = {
462
+ "job_id" : job_id ,
463
+ "project_id" : project_id ,
464
+ }
465
+
466
+ ti = TaskInstance (
467
+ task = training_op ,
468
+ execution_date = DEFAULT_DATE ,
469
+ )
470
+ ti .xcom_push (key = 'gcp_metadata' , value = gcp_metadata )
471
+
472
+ self .assertEqual (
473
+ f"https://console.cloud.google.com/ai-platform/jobs/{ job_id } ?project={ project_id } " ,
474
+ simple_task .get_extra_links (DEFAULT_DATE , AIPlatformConsoleLink .name ),
475
+ )
476
+
477
+ self .assertEqual (
478
+ '' ,
479
+ simple_task .get_extra_links (datetime .datetime (2019 , 1 , 1 ), AIPlatformConsoleLink .name ),
480
+ )
481
+
407
482
408
483
class TestMLEngineTrainingJobFailureOperator (unittest .TestCase ):
409
484
0 commit comments