|
17 | 17 | from __future__ import annotations
|
18 | 18 |
|
19 | 19 | import json
|
| 20 | +import logging |
20 | 21 | from unittest import mock
|
21 | 22 |
|
| 23 | +import httplib2 |
22 | 24 | import pytest
|
| 25 | +from googleapiclient.errors import HttpError |
| 26 | +from paramiko.ssh_exception import SSHException |
23 | 27 |
|
| 28 | +from airflow import AirflowException |
24 | 29 | from airflow.models import Connection
|
25 | 30 | from airflow.providers.google.cloud.hooks.compute_ssh import ComputeEngineSSHHook
|
26 | 31 |
|
@@ -99,7 +104,45 @@ def test_get_conn_default_configuration(
|
99 | 104 | ]
|
100 | 105 | )
|
101 | 106 |
|
102 |
| - mock_compute_hook.return_value.set_instance_metadata.assert_not_called() |
| 107 | + @pytest.mark.parametrize( |
| 108 | + "exception_type, error_message", |
| 109 | + [(SSHException, r"Error occurred when establishing SSH connection using Paramiko")], |
| 110 | + ) |
| 111 | + @mock.patch("airflow.providers.google.cloud.hooks.compute_ssh.ComputeEngineHook") |
| 112 | + @mock.patch("airflow.providers.google.cloud.hooks.compute_ssh.OSLoginHook") |
| 113 | + @mock.patch("airflow.providers.google.cloud.hooks.compute_ssh.paramiko") |
| 114 | + @mock.patch("airflow.providers.google.cloud.hooks.compute_ssh._GCloudAuthorizedSSHClient") |
| 115 | + @mock.patch("airflow.providers.google.cloud.hooks.compute_ssh.ComputeEngineSSHHook._connect_to_instance") |
| 116 | + def test_get_conn_default_configuration_test_exceptions( |
| 117 | + self, |
| 118 | + mock_connect, |
| 119 | + mock_ssh_client, |
| 120 | + mock_paramiko, |
| 121 | + mock_os_login_hook, |
| 122 | + mock_compute_hook, |
| 123 | + exception_type, |
| 124 | + error_message, |
| 125 | + caplog, |
| 126 | + ): |
| 127 | + mock_paramiko.SSHException = Exception |
| 128 | + mock_paramiko.RSAKey.generate.return_value.get_name.return_value = "NAME" |
| 129 | + mock_paramiko.RSAKey.generate.return_value.get_base64.return_value = "AYZ" |
| 130 | + |
| 131 | + mock_compute_hook.return_value.project_id = TEST_PROJECT_ID |
| 132 | + mock_compute_hook.return_value.get_instance_address.return_value = EXTERNAL_IP |
| 133 | + |
| 134 | + mock_os_login_hook.return_value._get_credentials_email.return_value = "test-example@example.org" |
| 135 | + mock_os_login_hook.return_value.import_ssh_public_key.return_value.login_profile.posix_accounts = [ |
| 136 | + mock.MagicMock(username="test-username") |
| 137 | + ] |
| 138 | + |
| 139 | + hook = ComputeEngineSSHHook(instance_name=TEST_INSTANCE_NAME, zone=TEST_ZONE) |
| 140 | + mock_connect.side_effect = [exception_type, mock_ssh_client] |
| 141 | + |
| 142 | + with caplog.at_level(logging.INFO): |
| 143 | + hook.get_conn() |
| 144 | + assert error_message in caplog.text |
| 145 | + assert "Failed establish SSH connection" in caplog.text |
103 | 146 |
|
104 | 147 | @mock.patch("airflow.providers.google.cloud.hooks.compute_ssh.ComputeEngineHook")
|
105 | 148 | @mock.patch("airflow.providers.google.cloud.hooks.compute_ssh.OSLoginHook")
|
@@ -159,6 +202,49 @@ def test_get_conn_authorize_using_instance_metadata(
|
159 | 202 |
|
160 | 203 | mock_os_login_hook.return_value.import_ssh_public_key.assert_not_called()
|
161 | 204 |
|
| 205 | + @pytest.mark.parametrize( |
| 206 | + "exception_type, error_message", |
| 207 | + [ |
| 208 | + ( |
| 209 | + HttpError(resp=httplib2.Response({"status": 412}), content=b"Error content"), |
| 210 | + r"Error occurred when trying to update instance metadata", |
| 211 | + ), |
| 212 | + ( |
| 213 | + AirflowException("412 PRECONDITION FAILED"), |
| 214 | + r"Error occurred when trying to update instance metadata", |
| 215 | + ), |
| 216 | + ], |
| 217 | + ) |
| 218 | + @mock.patch("airflow.providers.google.cloud.hooks.compute_ssh.ComputeEngineHook") |
| 219 | + @mock.patch("airflow.providers.google.cloud.hooks.compute_ssh.OSLoginHook") |
| 220 | + @mock.patch("airflow.providers.google.cloud.hooks.compute_ssh.paramiko") |
| 221 | + @mock.patch("airflow.providers.google.cloud.hooks.compute_ssh._GCloudAuthorizedSSHClient") |
| 222 | + def test_get_conn_authorize_using_instance_metadata_test_exception( |
| 223 | + self, |
| 224 | + mock_ssh_client, |
| 225 | + mock_paramiko, |
| 226 | + mock_os_login_hook, |
| 227 | + mock_compute_hook, |
| 228 | + exception_type, |
| 229 | + error_message, |
| 230 | + caplog, |
| 231 | + ): |
| 232 | + mock_paramiko.SSHException = Exception |
| 233 | + mock_paramiko.RSAKey.generate.return_value.get_name.return_value = "NAME" |
| 234 | + mock_paramiko.RSAKey.generate.return_value.get_base64.return_value = "AYZ" |
| 235 | + |
| 236 | + mock_compute_hook.return_value.project_id = TEST_PROJECT_ID |
| 237 | + mock_compute_hook.return_value.get_instance_address.return_value = EXTERNAL_IP |
| 238 | + |
| 239 | + mock_compute_hook.return_value.get_instance_info.return_value = {"metadata": {}} |
| 240 | + mock_compute_hook.return_value.set_instance_metadata.side_effect = [exception_type, None] |
| 241 | + |
| 242 | + hook = ComputeEngineSSHHook(instance_name=TEST_INSTANCE_NAME, zone=TEST_ZONE, use_oslogin=False) |
| 243 | + with caplog.at_level(logging.INFO): |
| 244 | + hook.get_conn() |
| 245 | + assert error_message in caplog.text |
| 246 | + assert "Failed establish SSH connection" in caplog.text |
| 247 | + |
162 | 248 | @mock.patch("airflow.providers.google.cloud.hooks.compute_ssh.ComputeEngineHook")
|
163 | 249 | @mock.patch("airflow.providers.google.cloud.hooks.compute_ssh.OSLoginHook")
|
164 | 250 | @mock.patch("airflow.providers.google.cloud.hooks.compute_ssh.paramiko")
|
@@ -274,6 +360,41 @@ def test_get_conn_iap_tunnel(self, mock_ssh_client, mock_paramiko, mock_os_login
|
274 | 360 | f"--zone={TEST_ZONE} --verbosity=warning"
|
275 | 361 | )
|
276 | 362 |
|
| 363 | + @pytest.mark.parametrize( |
| 364 | + "exception_type, error_message", |
| 365 | + [(SSHException, r"Error occurred when establishing SSH connection using Paramiko")], |
| 366 | + ) |
| 367 | + @mock.patch("airflow.providers.google.cloud.hooks.compute_ssh.ComputeEngineHook") |
| 368 | + @mock.patch("airflow.providers.google.cloud.hooks.compute_ssh.OSLoginHook") |
| 369 | + @mock.patch("airflow.providers.google.cloud.hooks.compute_ssh.paramiko") |
| 370 | + @mock.patch("airflow.providers.google.cloud.hooks.compute_ssh._GCloudAuthorizedSSHClient") |
| 371 | + @mock.patch("airflow.providers.google.cloud.hooks.compute_ssh.ComputeEngineSSHHook._connect_to_instance") |
| 372 | + def test_get_conn_iap_tunnel_test_exception( |
| 373 | + self, |
| 374 | + mock_connect, |
| 375 | + mock_ssh_client, |
| 376 | + mock_paramiko, |
| 377 | + mock_os_login_hook, |
| 378 | + mock_compute_hook, |
| 379 | + exception_type, |
| 380 | + error_message, |
| 381 | + caplog, |
| 382 | + ): |
| 383 | + del mock_os_login_hook |
| 384 | + mock_paramiko.SSHException = Exception |
| 385 | + |
| 386 | + mock_compute_hook.return_value.project_id = TEST_PROJECT_ID |
| 387 | + |
| 388 | + hook = ComputeEngineSSHHook( |
| 389 | + instance_name=TEST_INSTANCE_NAME, zone=TEST_ZONE, use_oslogin=False, use_iap_tunnel=True |
| 390 | + ) |
| 391 | + mock_connect.side_effect = [exception_type, mock_ssh_client] |
| 392 | + |
| 393 | + with caplog.at_level(logging.INFO): |
| 394 | + hook.get_conn() |
| 395 | + assert error_message in caplog.text |
| 396 | + assert "Failed establish SSH connection" in caplog.text |
| 397 | + |
277 | 398 | @mock.patch("airflow.providers.google.cloud.hooks.compute_ssh.ComputeEngineHook")
|
278 | 399 | @mock.patch("airflow.providers.google.cloud.hooks.compute_ssh.OSLoginHook")
|
279 | 400 | @mock.patch("airflow.providers.google.cloud.hooks.compute_ssh.paramiko")
|
|
0 commit comments