Skip to content

Commit 0c894db

Browse files
Handle multiple connections using exceptions (#32365)
1 parent df74553 commit 0c894db

File tree

5 files changed

+478
-31
lines changed

5 files changed

+478
-31
lines changed

airflow/providers/google/cloud/hooks/compute_ssh.py

Lines changed: 51 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,15 @@
1616
# under the License.
1717
from __future__ import annotations
1818

19+
import random
1920
import shlex
2021
import time
2122
from functools import cached_property
2223
from io import StringIO
2324
from typing import Any
2425

25-
from google.api_core.retry import exponential_sleep_generator
26+
from googleapiclient.errors import HttpError
27+
from paramiko.ssh_exception import SSHException
2628

2729
from airflow import AirflowException
2830
from airflow.providers.google.cloud.hooks.compute import ComputeEngineHook
@@ -82,6 +84,8 @@ class ComputeEngineSSHHook(SSHHook):
8284
keys are managed using instance metadata
8385
:param expire_time: The maximum amount of time in seconds before the private key expires
8486
:param gcp_conn_id: The connection id to use when fetching connection information
87+
:param max_retries: Maximum number of retries the process will try to establish connection to instance.
88+
Could be decreased/increased by user based on the amount of parallel SSH connections to the instance.
8589
"""
8690

8791
conn_name_attr = "gcp_conn_id"
@@ -109,6 +113,7 @@ def __init__(
109113
use_oslogin: bool = True,
110114
expire_time: int = 300,
111115
cmd_timeout: int | ArgNotSet = NOTSET,
116+
max_retries: int = 10,
112117
**kwargs,
113118
) -> None:
114119
if kwargs.get("delegate_to") is not None:
@@ -129,6 +134,7 @@ def __init__(
129134
self.expire_time = expire_time
130135
self.gcp_conn_id = gcp_conn_id
131136
self.cmd_timeout = cmd_timeout
137+
self.max_retries = max_retries
132138
self._conn: Any | None = None
133139

134140
@cached_property
@@ -225,40 +231,59 @@ def get_conn(self) -> paramiko.SSHClient:
225231
hostname = self.hostname
226232

227233
privkey, pubkey = self._generate_ssh_key(self.user)
228-
if self.use_oslogin:
229-
user = self._authorize_os_login(pubkey)
230-
else:
231-
user = self.user
232-
self._authorize_compute_engine_instance_metadata(pubkey)
233-
234-
proxy_command = None
235-
if self.use_iap_tunnel:
236-
proxy_command_args = [
237-
"gcloud",
238-
"compute",
239-
"start-iap-tunnel",
240-
str(self.instance_name),
241-
"22",
242-
"--listen-on-stdin",
243-
f"--project={self.project_id}",
244-
f"--zone={self.zone}",
245-
"--verbosity=warning",
246-
]
247-
proxy_command = " ".join(shlex.quote(arg) for arg in proxy_command_args)
248-
249-
sshclient = self._connect_to_instance(user, hostname, privkey, proxy_command)
234+
235+
max_delay = 10
236+
sshclient = None
237+
for retry in range(self.max_retries + 1):
238+
try:
239+
if self.use_oslogin:
240+
user = self._authorize_os_login(pubkey)
241+
else:
242+
user = self.user
243+
self._authorize_compute_engine_instance_metadata(pubkey)
244+
proxy_command = None
245+
if self.use_iap_tunnel:
246+
proxy_command_args = [
247+
"gcloud",
248+
"compute",
249+
"start-iap-tunnel",
250+
str(self.instance_name),
251+
"22",
252+
"--listen-on-stdin",
253+
f"--project={self.project_id}",
254+
f"--zone={self.zone}",
255+
"--verbosity=warning",
256+
]
257+
proxy_command = " ".join(shlex.quote(arg) for arg in proxy_command_args)
258+
sshclient = self._connect_to_instance(user, hostname, privkey, proxy_command)
259+
break
260+
except (HttpError, AirflowException, SSHException) as exc:
261+
if (isinstance(exc, HttpError) and exc.resp.status == 412) or (
262+
isinstance(exc, AirflowException) and "412 PRECONDITION FAILED" in str(exc)
263+
):
264+
self.log.info("Error occurred when trying to update instance metadata: %s", exc)
265+
elif isinstance(exc, SSHException):
266+
self.log.info("Error occurred when establishing SSH connection using Paramiko: %s", exc)
267+
else:
268+
raise
269+
if retry == self.max_retries:
270+
raise AirflowException("Maximum retries exceeded. Aborting operation.")
271+
delay = random.randint(0, max_delay)
272+
self.log.info(f"Failed establish SSH connection, waiting {delay} seconds to retry...")
273+
time.sleep(delay)
274+
if not sshclient:
275+
raise AirflowException("Unable to establish SSH connection.")
250276
return sshclient
251277

252278
def _connect_to_instance(self, user, hostname, pkey, proxy_command) -> paramiko.SSHClient:
253279
self.log.info("Opening remote connection to host: username=%s, hostname=%s", user, hostname)
254-
max_time_to_wait = 10
255-
for time_to_wait in exponential_sleep_generator(initial=1, maximum=max_time_to_wait):
280+
max_time_to_wait = 5
281+
for time_to_wait in range(max_time_to_wait + 1):
256282
try:
257283
client = _GCloudAuthorizedSSHClient(self._compute_hook)
258284
# Default is RejectPolicy
259285
# No known host checking since we are not storing privatekey
260286
client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
261-
262287
client.connect(
263288
hostname=hostname,
264289
username=user,
@@ -268,8 +293,6 @@ def _connect_to_instance(self, user, hostname, pkey, proxy_command) -> paramiko.
268293
)
269294
return client
270295
except paramiko.SSHException:
271-
# exponential_sleep_generator is an infinite generator, so we need to
272-
# check the end condition.
273296
if time_to_wait == max_time_to_wait:
274297
raise
275298
self.log.info("Failed to connect. Waiting %ds to retry", time_to_wait)

tests/providers/google/cloud/hooks/test_compute_ssh.py

Lines changed: 122 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,15 @@
1717
from __future__ import annotations
1818

1919
import json
20+
import logging
2021
from unittest import mock
2122

23+
import httplib2
2224
import pytest
25+
from googleapiclient.errors import HttpError
26+
from paramiko.ssh_exception import SSHException
2327

28+
from airflow import AirflowException
2429
from airflow.models import Connection
2530
from airflow.providers.google.cloud.hooks.compute_ssh import ComputeEngineSSHHook
2631

@@ -99,7 +104,45 @@ def test_get_conn_default_configuration(
99104
]
100105
)
101106

102-
mock_compute_hook.return_value.set_instance_metadata.assert_not_called()
107+
@pytest.mark.parametrize(
108+
"exception_type, error_message",
109+
[(SSHException, r"Error occurred when establishing SSH connection using Paramiko")],
110+
)
111+
@mock.patch("airflow.providers.google.cloud.hooks.compute_ssh.ComputeEngineHook")
112+
@mock.patch("airflow.providers.google.cloud.hooks.compute_ssh.OSLoginHook")
113+
@mock.patch("airflow.providers.google.cloud.hooks.compute_ssh.paramiko")
114+
@mock.patch("airflow.providers.google.cloud.hooks.compute_ssh._GCloudAuthorizedSSHClient")
115+
@mock.patch("airflow.providers.google.cloud.hooks.compute_ssh.ComputeEngineSSHHook._connect_to_instance")
116+
def test_get_conn_default_configuration_test_exceptions(
117+
self,
118+
mock_connect,
119+
mock_ssh_client,
120+
mock_paramiko,
121+
mock_os_login_hook,
122+
mock_compute_hook,
123+
exception_type,
124+
error_message,
125+
caplog,
126+
):
127+
mock_paramiko.SSHException = Exception
128+
mock_paramiko.RSAKey.generate.return_value.get_name.return_value = "NAME"
129+
mock_paramiko.RSAKey.generate.return_value.get_base64.return_value = "AYZ"
130+
131+
mock_compute_hook.return_value.project_id = TEST_PROJECT_ID
132+
mock_compute_hook.return_value.get_instance_address.return_value = EXTERNAL_IP
133+
134+
mock_os_login_hook.return_value._get_credentials_email.return_value = "test-example@example.org"
135+
mock_os_login_hook.return_value.import_ssh_public_key.return_value.login_profile.posix_accounts = [
136+
mock.MagicMock(username="test-username")
137+
]
138+
139+
hook = ComputeEngineSSHHook(instance_name=TEST_INSTANCE_NAME, zone=TEST_ZONE)
140+
mock_connect.side_effect = [exception_type, mock_ssh_client]
141+
142+
with caplog.at_level(logging.INFO):
143+
hook.get_conn()
144+
assert error_message in caplog.text
145+
assert "Failed establish SSH connection" in caplog.text
103146

104147
@mock.patch("airflow.providers.google.cloud.hooks.compute_ssh.ComputeEngineHook")
105148
@mock.patch("airflow.providers.google.cloud.hooks.compute_ssh.OSLoginHook")
@@ -159,6 +202,49 @@ def test_get_conn_authorize_using_instance_metadata(
159202

160203
mock_os_login_hook.return_value.import_ssh_public_key.assert_not_called()
161204

205+
@pytest.mark.parametrize(
206+
"exception_type, error_message",
207+
[
208+
(
209+
HttpError(resp=httplib2.Response({"status": 412}), content=b"Error content"),
210+
r"Error occurred when trying to update instance metadata",
211+
),
212+
(
213+
AirflowException("412 PRECONDITION FAILED"),
214+
r"Error occurred when trying to update instance metadata",
215+
),
216+
],
217+
)
218+
@mock.patch("airflow.providers.google.cloud.hooks.compute_ssh.ComputeEngineHook")
219+
@mock.patch("airflow.providers.google.cloud.hooks.compute_ssh.OSLoginHook")
220+
@mock.patch("airflow.providers.google.cloud.hooks.compute_ssh.paramiko")
221+
@mock.patch("airflow.providers.google.cloud.hooks.compute_ssh._GCloudAuthorizedSSHClient")
222+
def test_get_conn_authorize_using_instance_metadata_test_exception(
223+
self,
224+
mock_ssh_client,
225+
mock_paramiko,
226+
mock_os_login_hook,
227+
mock_compute_hook,
228+
exception_type,
229+
error_message,
230+
caplog,
231+
):
232+
mock_paramiko.SSHException = Exception
233+
mock_paramiko.RSAKey.generate.return_value.get_name.return_value = "NAME"
234+
mock_paramiko.RSAKey.generate.return_value.get_base64.return_value = "AYZ"
235+
236+
mock_compute_hook.return_value.project_id = TEST_PROJECT_ID
237+
mock_compute_hook.return_value.get_instance_address.return_value = EXTERNAL_IP
238+
239+
mock_compute_hook.return_value.get_instance_info.return_value = {"metadata": {}}
240+
mock_compute_hook.return_value.set_instance_metadata.side_effect = [exception_type, None]
241+
242+
hook = ComputeEngineSSHHook(instance_name=TEST_INSTANCE_NAME, zone=TEST_ZONE, use_oslogin=False)
243+
with caplog.at_level(logging.INFO):
244+
hook.get_conn()
245+
assert error_message in caplog.text
246+
assert "Failed establish SSH connection" in caplog.text
247+
162248
@mock.patch("airflow.providers.google.cloud.hooks.compute_ssh.ComputeEngineHook")
163249
@mock.patch("airflow.providers.google.cloud.hooks.compute_ssh.OSLoginHook")
164250
@mock.patch("airflow.providers.google.cloud.hooks.compute_ssh.paramiko")
@@ -274,6 +360,41 @@ def test_get_conn_iap_tunnel(self, mock_ssh_client, mock_paramiko, mock_os_login
274360
f"--zone={TEST_ZONE} --verbosity=warning"
275361
)
276362

363+
@pytest.mark.parametrize(
364+
"exception_type, error_message",
365+
[(SSHException, r"Error occurred when establishing SSH connection using Paramiko")],
366+
)
367+
@mock.patch("airflow.providers.google.cloud.hooks.compute_ssh.ComputeEngineHook")
368+
@mock.patch("airflow.providers.google.cloud.hooks.compute_ssh.OSLoginHook")
369+
@mock.patch("airflow.providers.google.cloud.hooks.compute_ssh.paramiko")
370+
@mock.patch("airflow.providers.google.cloud.hooks.compute_ssh._GCloudAuthorizedSSHClient")
371+
@mock.patch("airflow.providers.google.cloud.hooks.compute_ssh.ComputeEngineSSHHook._connect_to_instance")
372+
def test_get_conn_iap_tunnel_test_exception(
373+
self,
374+
mock_connect,
375+
mock_ssh_client,
376+
mock_paramiko,
377+
mock_os_login_hook,
378+
mock_compute_hook,
379+
exception_type,
380+
error_message,
381+
caplog,
382+
):
383+
del mock_os_login_hook
384+
mock_paramiko.SSHException = Exception
385+
386+
mock_compute_hook.return_value.project_id = TEST_PROJECT_ID
387+
388+
hook = ComputeEngineSSHHook(
389+
instance_name=TEST_INSTANCE_NAME, zone=TEST_ZONE, use_oslogin=False, use_iap_tunnel=True
390+
)
391+
mock_connect.side_effect = [exception_type, mock_ssh_client]
392+
393+
with caplog.at_level(logging.INFO):
394+
hook.get_conn()
395+
assert error_message in caplog.text
396+
assert "Failed establish SSH connection" in caplog.text
397+
277398
@mock.patch("airflow.providers.google.cloud.hooks.compute_ssh.ComputeEngineHook")
278399
@mock.patch("airflow.providers.google.cloud.hooks.compute_ssh.OSLoginHook")
279400
@mock.patch("airflow.providers.google.cloud.hooks.compute_ssh.paramiko")

tests/system/providers/google/cloud/compute/example_compute_ssh.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@
7474
schedule_interval="@once",
7575
start_date=datetime(2021, 1, 1),
7676
catchup=False,
77-
tags=["example"],
77+
tags=["example", "compute-ssh"],
7878
) as dag:
7979
# [START howto_operator_gce_insert]
8080
gce_instance_insert = ComputeEngineInsertInstanceOperator(
@@ -95,7 +95,7 @@
9595
project_id=PROJECT_ID,
9696
use_oslogin=False,
9797
use_iap_tunnel=False,
98-
cmd_timeout=100,
98+
cmd_timeout=1,
9999
),
100100
command="echo metadata_without_iap_tunnel1",
101101
)

0 commit comments

Comments
 (0)