Skip to content

Commit 91557c6

Browse files
authored
[AIRFLOW-7073] GKEStartPodOperator always use connection credentials (#7738)
1 parent 2a54512 commit 91557c6

File tree

8 files changed

+299
-187
lines changed

8 files changed

+299
-187
lines changed

β€Žairflow/providers/google/cloud/hooks/base.py

Lines changed: 80 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,11 @@
2626
import os
2727
import tempfile
2828
from contextlib import contextmanager
29-
from typing import Any, Callable, Dict, Optional, Sequence, TypeVar
29+
from subprocess import check_output
30+
from typing import Any, Callable, Dict, Optional, Sequence, Tuple, TypeVar
3031

3132
import google.auth
33+
import google.auth.credentials
3234
import google.oauth2.service_account
3335
import google_auth_httplib2
3436
import httplib2
@@ -37,13 +39,15 @@
3739
AlreadyExists, Forbidden, GoogleAPICallError, ResourceExhausted, RetryError, TooManyRequests,
3840
)
3941
from google.api_core.gapic_v1.client_info import ClientInfo
42+
from google.auth import _cloud_sdk
4043
from google.auth.environment_vars import CREDENTIALS
4144
from googleapiclient.errors import HttpError
4245
from googleapiclient.http import set_user_agent
4346

4447
from airflow import version
4548
from airflow.exceptions import AirflowException
4649
from airflow.hooks.base_hook import BaseHook
50+
from airflow.utils.process_utils import patch_environ
4751

4852
log = logging.getLogger(__name__)
4953

@@ -138,7 +142,7 @@ def __init__(self, gcp_conn_id: str = 'google_cloud_default', delegate_to: Optio
138142
self.delegate_to = delegate_to
139143
self.extras = self.get_connection(self.gcp_conn_id).extra_dejson # type: Dict
140144

141-
def _get_credentials_and_project_id(self) -> google.auth.credentials.Credentials:
145+
def _get_credentials_and_project_id(self) -> Tuple[google.auth.credentials.Credentials, Optional[str]]:
142146
"""
143147
Returns the Credentials object for Google API and the associated project_id
144148
"""
@@ -387,28 +391,77 @@ def provide_gcp_credential_file_as_context(self):
387391
It can be used to provide credentials for external programs (e.g. gcloud) that expect authorization
388392
file in ``GOOGLE_APPLICATION_CREDENTIALS`` environment variable.
389393
"""
390-
with tempfile.NamedTemporaryFile(mode='w+t') as conf_file:
391-
key_path = self._get_field('key_path', None) # type: Optional[str] # noqa: E501 # pylint: disable=protected-access
392-
keyfile_dict = self._get_field('keyfile_dict', None) # type: Optional[Dict] # noqa: E501 # pylint: disable=protected-access
393-
current_env_state = os.environ.get(CREDENTIALS)
394-
try:
395-
if key_path:
396-
if key_path.endswith('.p12'):
397-
raise AirflowException(
398-
'Legacy P12 key file are not supported, use a JSON key file.'
399-
)
400-
os.environ[CREDENTIALS] = key_path
401-
elif keyfile_dict:
402-
conf_file.write(keyfile_dict)
403-
conf_file.flush()
404-
os.environ[CREDENTIALS] = conf_file.name
405-
else:
406-
# We will use the default service account credentials.
407-
pass
408-
yield conf_file
409-
finally:
410-
if current_env_state is None:
411-
if CREDENTIALS in os.environ:
412-
del os.environ[CREDENTIALS]
413-
else:
414-
os.environ[CREDENTIALS] = current_env_state
394+
key_path = self._get_field('key_path', None) # type: Optional[str] # noqa: E501 # pylint: disable=protected-access
395+
keyfile_dict = self._get_field('keyfile_dict', None) # type: Optional[Dict] # noqa: E501 # pylint: disable=protected-access
396+
if key_path and keyfile_dict:
397+
raise AirflowException(
398+
"The `keyfile_dict` and `key_path` fields are mutually exclusive. "
399+
"Please provide only one value."
400+
)
401+
elif key_path:
402+
if key_path.endswith('.p12'):
403+
raise AirflowException(
404+
'Legacy P12 key file are not supported, use a JSON key file.'
405+
)
406+
with patch_environ({CREDENTIALS: key_path}):
407+
yield key_path
408+
elif keyfile_dict:
409+
with tempfile.NamedTemporaryFile(mode='w+t') as conf_file:
410+
conf_file.write(keyfile_dict)
411+
conf_file.flush()
412+
with patch_environ({CREDENTIALS: conf_file.name}):
413+
yield conf_file.name
414+
else:
415+
# We will use the default service account credentials.
416+
yield None
417+
418+
@contextmanager
419+
def provide_authorized_gcloud(self):
420+
"""
421+
Provides a separate gcloud configuration with current credentials.
422+
423+
The gcloud allows you to login to GCP only - ``gcloud auth login`` and
424+
for the needs of Application Default Credentials ``gcloud auth application-default login``.
425+
In our case, we want all commands to use only the credentials from ADCm so
426+
we need to configure the credentials in gcloud manually.
427+
"""
428+
credentials_path = _cloud_sdk.get_application_default_credentials_path()
429+
project_id = self.project_id
430+
431+
with self.provide_gcp_credential_file_as_context(), \
432+
tempfile.TemporaryDirectory() as gcloud_config_tmp, \
433+
patch_environ({'CLOUDSDK_CONFIG': gcloud_config_tmp}):
434+
435+
if project_id:
436+
# Don't display stdout/stderr for security reason
437+
check_output([
438+
"gcloud", "config", "set", "core/project", project_id
439+
])
440+
if CREDENTIALS in os.environ:
441+
# This solves most cases when we are logged in using the service key in Airflow.
442+
# Don't display stdout/stderr for security reason
443+
check_output([
444+
"gcloud", "auth", "activate-service-account", f"--key-file={os.environ[CREDENTIALS]}",
445+
])
446+
elif os.path.exists(credentials_path):
447+
# If we are logged in by `gcloud auth application-default` then we need to log in manually.
448+
# This will make the `gcloud auth application-default` and `gcloud auth` credentials equals.
449+
with open(credentials_path) as creds_file:
450+
creds_content = json.loads(creds_file.read())
451+
# Don't display stdout/stderr for security reason
452+
check_output([
453+
"gcloud", "config", "set", "auth/client_id", creds_content["client_id"]
454+
])
455+
# Don't display stdout/stderr for security reason
456+
check_output([
457+
"gcloud", "config", "set", "auth/client_secret", creds_content["client_secret"]
458+
])
459+
# Don't display stdout/stderr for security reason
460+
check_output([
461+
"gcloud",
462+
"auth",
463+
"activate-refresh-token",
464+
creds_content["client_id"],
465+
creds_content["refresh_token"],
466+
])
467+
yield

β€Žairflow/providers/google/cloud/operators/kubernetes_engine.py

Lines changed: 19 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@
2121
"""
2222

2323
import os
24-
import subprocess
2524
import tempfile
2625
from typing import Dict, Optional, Union
2726

@@ -33,6 +32,7 @@
3332
from airflow.providers.google.cloud.hooks.base import CloudBaseHook
3433
from airflow.providers.google.cloud.hooks.kubernetes_engine import GKEHook
3534
from airflow.utils.decorators import apply_defaults
35+
from airflow.utils.process_utils import execute_in_subprocess, patch_environ
3636

3737

3838
class GKEDeleteClusterOperator(BaseOperator):
@@ -254,22 +254,21 @@ def execute(self, context):
254254

255255
# Write config to a temp file and set the environment variable to point to it.
256256
# This is to avoid race conditions of reading/writing a single file
257-
with tempfile.NamedTemporaryFile() as conf_file:
258-
os.environ[KUBE_CONFIG_ENV_VAR] = conf_file.name
259-
260-
with hook.provide_gcp_credential_file_as_context():
261-
# Attempt to get/update credentials
262-
# We call gcloud directly instead of using google-cloud-python api
263-
# because there is no way to write kubernetes config to a file, which is
264-
# required by KubernetesPodOperator.
265-
# The gcloud command looks at the env variable `KUBECONFIG` for where to save
266-
# the kubernetes config file.
267-
subprocess.check_call(
268-
["gcloud", "container", "clusters", "get-credentials",
269-
self.cluster_name,
270-
"--zone", self.location,
271-
"--project", self.project_id])
272-
273-
# Tell `KubernetesPodOperator` where the config file is located
274-
self.config_file = os.environ[KUBE_CONFIG_ENV_VAR]
275-
return super().execute(context)
257+
with tempfile.NamedTemporaryFile() as conf_file,\
258+
patch_environ({KUBE_CONFIG_ENV_VAR: conf_file.name}), \
259+
hook.provide_authorized_gcloud():
260+
# Attempt to get/update credentials
261+
# We call gcloud directly instead of using google-cloud-python api
262+
# because there is no way to write kubernetes config to a file, which is
263+
# required by KubernetesPodOperator.
264+
# The gcloud command looks at the env variable `KUBECONFIG` for where to save
265+
# the kubernetes config file.
266+
execute_in_subprocess(
267+
["gcloud", "container", "clusters", "get-credentials",
268+
self.cluster_name,
269+
"--zone", self.location,
270+
"--project", self.project_id])
271+
272+
# Tell `KubernetesPodOperator` where the config file is located
273+
self.config_file = os.environ[KUBE_CONFIG_ENV_VAR]
274+
return super().execute(context)

β€Žairflow/providers/google/cloud/utils/credentials_provider.py

Lines changed: 3 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020
Google Cloud Platform authentication.
2121
"""
2222
import json
23-
import os
2423
import tempfile
2524
from contextlib import contextmanager
2625
from typing import Dict, Optional, Sequence
@@ -29,6 +28,7 @@
2928
from google.auth.environment_vars import CREDENTIALS
3029

3130
from airflow.exceptions import AirflowException
31+
from airflow.utils.process_utils import patch_environ
3232

3333
AIRFLOW_CONN_GOOGLE_CLOUD_DEFAULT = "AIRFLOW_CONN_GOOGLE_CLOUD_DEFAULT"
3434

@@ -66,31 +66,6 @@ def build_gcp_conn(
6666
return conn.format(query)
6767

6868

69-
@contextmanager
70-
def temporary_environment_variable(variable_name: str, value: str):
71-
"""
72-
Context manager that set up temporary value for a given environment
73-
variable and the restore initial state.
74-
75-
:param variable_name: Name of the environment variable
76-
:type variable_name: str
77-
:param value: The temporary value
78-
:type value: str
79-
"""
80-
# Save initial value
81-
init_value = os.environ.get(variable_name)
82-
try:
83-
# set temporary value
84-
os.environ[variable_name] = value
85-
yield
86-
finally:
87-
# Restore initial state (remove or restore)
88-
if variable_name in os.environ:
89-
del os.environ[variable_name]
90-
if init_value:
91-
os.environ[variable_name] = init_value
92-
93-
9469
@contextmanager
9570
def provide_gcp_credentials(
9671
key_file_path: Optional[str] = None, key_file_dict: Optional[Dict] = None
@@ -121,7 +96,7 @@ def provide_gcp_credentials(
12196
conf_file.flush()
12297
key_file_path = conf_file.name
12398
if key_file_path:
124-
with temporary_environment_variable(CREDENTIALS, key_file_path):
99+
with patch_environ({CREDENTIALS: key_file_path}):
125100
yield
126101
else:
127102
# We will use the default service account credentials.
@@ -155,7 +130,7 @@ def provide_gcp_connection(
155130
scopes=scopes, key_file_path=key_file_path, project_id=project_id
156131
)
157132

158-
with temporary_environment_variable(AIRFLOW_CONN_GOOGLE_CLOUD_DEFAULT, conn):
133+
with patch_environ({AIRFLOW_CONN_GOOGLE_CLOUD_DEFAULT: conn}):
159134
yield
160135

161136

β€Žairflow/utils/process_utils.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,8 @@
2525
import shlex
2626
import signal
2727
import subprocess
28-
from typing import List
28+
from contextlib import contextmanager
29+
from typing import Dict, List
2930

3031
import psutil
3132

@@ -184,3 +185,26 @@ def kill_child_processes_by_pids(pids_to_kill: List[int], timeout: int = 5) -> N
184185
log.info("Killing child PID: %s", child.pid)
185186
child.kill()
186187
child.wait()
188+
189+
190+
@contextmanager
191+
def patch_environ(new_env_variables: Dict[str, str]):
192+
"""
193+
Sets environment variables in context. After leaving the context, it restores its original state.
194+
195+
:param new_env_variables: Environment variables to set
196+
"""
197+
current_env_state = {
198+
key: os.environ.get(key)
199+
for key in new_env_variables.keys()
200+
}
201+
os.environ.update(new_env_variables)
202+
try: # pylint: disable=too-many-nested-blocks
203+
yield
204+
finally:
205+
for key, old_value in current_env_state.items():
206+
if old_value is None:
207+
if key in os.environ:
208+
del os.environ[key]
209+
else:
210+
os.environ[key] = old_value

0 commit comments

Comments
 (0)