|
26 | 26 | import os
|
27 | 27 | import tempfile
|
28 | 28 | from contextlib import contextmanager
|
29 |
| -from typing import Any, Callable, Dict, Optional, Sequence, TypeVar |
| 29 | +from subprocess import check_output |
| 30 | +from typing import Any, Callable, Dict, Optional, Sequence, Tuple, TypeVar |
30 | 31 |
|
31 | 32 | import google.auth
|
| 33 | +import google.auth.credentials |
32 | 34 | import google.oauth2.service_account
|
33 | 35 | import google_auth_httplib2
|
34 | 36 | import httplib2
|
|
37 | 39 | AlreadyExists, Forbidden, GoogleAPICallError, ResourceExhausted, RetryError, TooManyRequests,
|
38 | 40 | )
|
39 | 41 | from google.api_core.gapic_v1.client_info import ClientInfo
|
| 42 | +from google.auth import _cloud_sdk |
40 | 43 | from google.auth.environment_vars import CREDENTIALS
|
41 | 44 | from googleapiclient.errors import HttpError
|
42 | 45 | from googleapiclient.http import set_user_agent
|
43 | 46 |
|
44 | 47 | from airflow import version
|
45 | 48 | from airflow.exceptions import AirflowException
|
46 | 49 | from airflow.hooks.base_hook import BaseHook
|
| 50 | +from airflow.utils.process_utils import patch_environ |
47 | 51 |
|
48 | 52 | log = logging.getLogger(__name__)
|
49 | 53 |
|
@@ -138,7 +142,7 @@ def __init__(self, gcp_conn_id: str = 'google_cloud_default', delegate_to: Optio
|
138 | 142 | self.delegate_to = delegate_to
|
139 | 143 | self.extras = self.get_connection(self.gcp_conn_id).extra_dejson # type: Dict
|
140 | 144 |
|
141 |
| - def _get_credentials_and_project_id(self) -> google.auth.credentials.Credentials: |
| 145 | + def _get_credentials_and_project_id(self) -> Tuple[google.auth.credentials.Credentials, Optional[str]]: |
142 | 146 | """
|
143 | 147 | Returns the Credentials object for Google API and the associated project_id
|
144 | 148 | """
|
@@ -387,28 +391,77 @@ def provide_gcp_credential_file_as_context(self):
|
387 | 391 | It can be used to provide credentials for external programs (e.g. gcloud) that expect authorization
|
388 | 392 | file in ``GOOGLE_APPLICATION_CREDENTIALS`` environment variable.
|
389 | 393 | """
|
390 |
| - with tempfile.NamedTemporaryFile(mode='w+t') as conf_file: |
391 |
| - key_path = self._get_field('key_path', None) # type: Optional[str] # noqa: E501 # pylint: disable=protected-access |
392 |
| - keyfile_dict = self._get_field('keyfile_dict', None) # type: Optional[Dict] # noqa: E501 # pylint: disable=protected-access |
393 |
| - current_env_state = os.environ.get(CREDENTIALS) |
394 |
| - try: |
395 |
| - if key_path: |
396 |
| - if key_path.endswith('.p12'): |
397 |
| - raise AirflowException( |
398 |
| - 'Legacy P12 key file are not supported, use a JSON key file.' |
399 |
| - ) |
400 |
| - os.environ[CREDENTIALS] = key_path |
401 |
| - elif keyfile_dict: |
402 |
| - conf_file.write(keyfile_dict) |
403 |
| - conf_file.flush() |
404 |
| - os.environ[CREDENTIALS] = conf_file.name |
405 |
| - else: |
406 |
| - # We will use the default service account credentials. |
407 |
| - pass |
408 |
| - yield conf_file |
409 |
| - finally: |
410 |
| - if current_env_state is None: |
411 |
| - if CREDENTIALS in os.environ: |
412 |
| - del os.environ[CREDENTIALS] |
413 |
| - else: |
414 |
| - os.environ[CREDENTIALS] = current_env_state |
| 394 | + key_path = self._get_field('key_path', None) # type: Optional[str] # noqa: E501 # pylint: disable=protected-access |
| 395 | + keyfile_dict = self._get_field('keyfile_dict', None) # type: Optional[Dict] # noqa: E501 # pylint: disable=protected-access |
| 396 | + if key_path and keyfile_dict: |
| 397 | + raise AirflowException( |
| 398 | + "The `keyfile_dict` and `key_path` fields are mutually exclusive. " |
| 399 | + "Please provide only one value." |
| 400 | + ) |
| 401 | + elif key_path: |
| 402 | + if key_path.endswith('.p12'): |
| 403 | + raise AirflowException( |
| 404 | + 'Legacy P12 key file are not supported, use a JSON key file.' |
| 405 | + ) |
| 406 | + with patch_environ({CREDENTIALS: key_path}): |
| 407 | + yield key_path |
| 408 | + elif keyfile_dict: |
| 409 | + with tempfile.NamedTemporaryFile(mode='w+t') as conf_file: |
| 410 | + conf_file.write(keyfile_dict) |
| 411 | + conf_file.flush() |
| 412 | + with patch_environ({CREDENTIALS: conf_file.name}): |
| 413 | + yield conf_file.name |
| 414 | + else: |
| 415 | + # We will use the default service account credentials. |
| 416 | + yield None |
| 417 | + |
| 418 | + @contextmanager |
| 419 | + def provide_authorized_gcloud(self): |
| 420 | + """ |
| 421 | + Provides a separate gcloud configuration with current credentials. |
| 422 | +
|
| 423 | + The gcloud allows you to login to GCP only - ``gcloud auth login`` and |
| 424 | + for the needs of Application Default Credentials ``gcloud auth application-default login``. |
| 425 | + In our case, we want all commands to use only the credentials from ADCm so |
| 426 | + we need to configure the credentials in gcloud manually. |
| 427 | + """ |
| 428 | + credentials_path = _cloud_sdk.get_application_default_credentials_path() |
| 429 | + project_id = self.project_id |
| 430 | + |
| 431 | + with self.provide_gcp_credential_file_as_context(), \ |
| 432 | + tempfile.TemporaryDirectory() as gcloud_config_tmp, \ |
| 433 | + patch_environ({'CLOUDSDK_CONFIG': gcloud_config_tmp}): |
| 434 | + |
| 435 | + if project_id: |
| 436 | + # Don't display stdout/stderr for security reason |
| 437 | + check_output([ |
| 438 | + "gcloud", "config", "set", "core/project", project_id |
| 439 | + ]) |
| 440 | + if CREDENTIALS in os.environ: |
| 441 | + # This solves most cases when we are logged in using the service key in Airflow. |
| 442 | + # Don't display stdout/stderr for security reason |
| 443 | + check_output([ |
| 444 | + "gcloud", "auth", "activate-service-account", f"--key-file={os.environ[CREDENTIALS]}", |
| 445 | + ]) |
| 446 | + elif os.path.exists(credentials_path): |
| 447 | + # If we are logged in by `gcloud auth application-default` then we need to log in manually. |
| 448 | + # This will make the `gcloud auth application-default` and `gcloud auth` credentials equals. |
| 449 | + with open(credentials_path) as creds_file: |
| 450 | + creds_content = json.loads(creds_file.read()) |
| 451 | + # Don't display stdout/stderr for security reason |
| 452 | + check_output([ |
| 453 | + "gcloud", "config", "set", "auth/client_id", creds_content["client_id"] |
| 454 | + ]) |
| 455 | + # Don't display stdout/stderr for security reason |
| 456 | + check_output([ |
| 457 | + "gcloud", "config", "set", "auth/client_secret", creds_content["client_secret"] |
| 458 | + ]) |
| 459 | + # Don't display stdout/stderr for security reason |
| 460 | + check_output([ |
| 461 | + "gcloud", |
| 462 | + "auth", |
| 463 | + "activate-refresh-token", |
| 464 | + creds_content["client_id"], |
| 465 | + creds_content["refresh_token"], |
| 466 | + ]) |
| 467 | + yield |
0 commit comments