Commit da88ed1

Fix MyPy errors in Google Cloud (again) (#20469)
Part of #19891. The .py additions handle "default_args" passed in the example DAGs: some of the obligatory parameters are (correctly) supplied via default_args, but we have no good mechanism yet to express that for MyPy (it would require a custom MyPy plugin), so per-call workarounds are the best we can do for now.
1 parent afd84f6 commit da88ed1
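
For context, the default_args pattern the message refers to looks roughly like this; a minimal sketch with a hypothetical operator (the real operators take their required GCP parameters the same way):

    from datetime import datetime

    from airflow import models
    from airflow.models.baseoperator import BaseOperator

    class MyOperator(BaseOperator):  # hypothetical operator with a required argument
        def __init__(self, *, project_id: str, **kwargs) -> None:
            super().__init__(**kwargs)
            self.project_id = project_id

    with models.DAG(
        dag_id="example",
        start_date=datetime(2021, 1, 1),
        default_args={"project_id": "my-project"},  # merged into the call at runtime
    ) as dag:
        # At runtime project_id comes from default_args, but MyPy sees a
        # missing named argument here, hence the per-call workarounds.
        task = MyOperator(task_id="my_task")  # type: ignore[call-arg]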

File tree: 15 files changed, +94 additions, -66 deletions

airflow/providers/google/cloud/example_dags/example_datacatalog.py

Lines changed: 2 additions & 1 deletion

@@ -23,6 +23,7 @@
 from datetime import datetime
 
 from google.cloud.datacatalog_v1beta1 import FieldType, TagField, TagTemplateField
+from google.protobuf.field_mask_pb2 import FieldMask
 
 from airflow import models
 from airflow.models.baseoperator import chain
@@ -242,7 +243,7 @@
     task_id="get_entry_group",
     location=LOCATION,
     entry_group=ENTRY_GROUP_ID,
-    read_mask={"paths": ["name", "display_name"]},
+    read_mask=FieldMask(paths=["name", "display_name"]),
 )
 # [END howto_operator_gcp_datacatalog_get_entry_group]

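The dict literal was accepted by older client libraries, but the newer google-cloud clients expect a real protobuf FieldMask. A minimal standalone sketch of the replacement (outside the DAG):

    from google.protobuf.field_mask_pb2 import FieldMask

    # A FieldMask names the message fields an API call should read or update;
    # constructing it directly replaces the {"paths": [...]} dict literal.
    read_mask = FieldMask(paths=["name", "display_name"])
    print(list(read_mask.paths))  # ['name', 'display_name']
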
airflow/providers/google/cloud/example_dags/example_dataproc_metastore.py

Lines changed: 13 additions & 8 deletions

@@ -23,6 +23,9 @@
 import datetime
 import os
 
+from google.cloud.metastore_v1 import MetadataImport
+from google.protobuf.field_mask_pb2 import FieldMask
+
 from airflow import models
 from airflow.models.baseoperator import chain
 from airflow.providers.google.cloud.operators.dataproc_metastore import (
@@ -66,7 +69,7 @@
         "systemtest": "systemtest",
     }
 }
-UPDATE_MASK = {"paths": ["labels"]}
+UPDATE_MASK = FieldMask(paths=["labels"])
 # [END how_to_cloud_dataproc_metastore_update_service]
 
 # Backup definition
@@ -78,13 +81,15 @@
 
 # Metadata import definition
 # [START how_to_cloud_dataproc_metastore_create_metadata_import]
-METADATA_IMPORT = {
-    "name": "test-metadata-import",
-    "database_dump": {
-        "gcs_uri": GCS_URI,
-        "database_type": DB_TYPE,
-    },
-}
+METADATA_IMPORT = MetadataImport(
+    {
+        "name": "test-metadata-import",
+        "database_dump": {
+            "gcs_uri": GCS_URI,
+            "database_type": DB_TYPE,
+        },
+    }
+)
 # [END how_to_cloud_dataproc_metastore_create_metadata_import]

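proto-plus message classes such as MetadataImport accept a plain mapping as their first positional argument, so wrapping the existing dict is enough to give MyPy (and the client) a typed message. A small sketch with placeholder values standing in for GCS_URI and DB_TYPE:

    from google.cloud.metastore_v1 import MetadataImport

    # Nested dicts are converted to the corresponding sub-messages;
    # the gs:// URI and MYSQL below are placeholder values.
    metadata_import = MetadataImport(
        {
            "name": "test-metadata-import",
            "database_dump": {"gcs_uri": "gs://bucket/dump.sql", "database_type": "MYSQL"},
        }
    )
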
airflow/providers/google/cloud/hooks/cloud_memorystore.py

Lines changed: 1 addition & 1 deletion

@@ -939,7 +939,7 @@ def update_parameters(
         parameters: Union[Dict, cloud_memcache.MemcacheParameters],
         project_id: str,
         location: str,
-        instance_id: Optional[str] = None,
+        instance_id: str,
         retry: Optional[Retry] = None,
         timeout: Optional[float] = None,
         metadata: Sequence[Tuple[str, str]] = (),

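Declaring instance_id as a required str (it was never meaningfully optional) lets MyPy flag callers that omit it, instead of letting None flow into the API client. The idea in isolation, with a hypothetical function:

    from typing import Optional

    def update_parameters(*, instance_id: str, timeout: Optional[float] = None) -> None:
        ...  # keyword-only, so the required parameter can sit anywhere in the list

    update_parameters(timeout=5.0)  # MyPy: missing named argument "instance_id"
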
airflow/providers/google/cloud/hooks/cloud_sql.py

Lines changed: 2 additions & 2 deletions

@@ -505,15 +505,15 @@ def _download_sql_proxy_if_needed(self) -> None:
         if "follow_redirects" in signature(httpx.get).parameters.keys():
             response = httpx.get(download_url, follow_redirects=True)
         else:
-            response = httpx.get(download_url, allow_redirects=True)
+            response = httpx.get(download_url, allow_redirects=True)  # type: ignore[call-arg]
         # Downloading to .tmp file first to avoid case where partially downloaded
         # binary is used by parallel operator which uses the same fixed binary path
         with open(proxy_path_tmp, 'wb') as file:
             file.write(response.content)
         if response.status_code != 200:
             raise AirflowException(
                 "The cloud-sql-proxy could not be downloaded. "
-                f"Status code = {response.status_code}. Reason = {response.reason}"
+                f"Status code = {response.status_code}. Reason = {response.reason_phrase}"
             )
 
         self.log.info("Moving sql_proxy binary from %s to %s", proxy_path_tmp, self.sql_proxy_path)

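Two httpx details are in play here: allow_redirects was renamed to follow_redirects in newer httpx, so the legacy branch needs a type: ignore against the current stubs, and httpx.Response exposes reason_phrase rather than the requests-style reason. A minimal sketch of the signature probe:

    from inspect import signature

    import httpx

    url = "https://example.com"
    if "follow_redirects" in signature(httpx.get).parameters:
        response = httpx.get(url, follow_redirects=True)
    else:  # older httpx keyword, unknown to the current type stubs
        response = httpx.get(url, allow_redirects=True)  # type: ignore[call-arg]
    print(response.status_code, response.reason_phrase)
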
airflow/providers/google/cloud/hooks/compute.py

Lines changed: 1 addition & 1 deletion

@@ -406,7 +406,7 @@ def get_instance_address(
         if use_internal_ip:
             return instance_info["networkInterfaces"][0].get("networkIP")
 
-        access_config = instance_info.get("networkInterfaces")[0].get("accessConfigs")
+        access_config = instance_info["networkInterfaces"][0].get("accessConfigs")
         if access_config:
             return access_config[0].get("natIP")
         raise AirflowException("The target instance does not have external IP")

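The change matters to MyPy because dict.get returns an Optional value, so indexing its result is a possible None dereference; plain indexing keeps the type narrow (and raises KeyError rather than TypeError when the key is missing). Roughly:

    from typing import Any, Dict, List

    instance_info: Dict[str, List[Any]] = {"networkInterfaces": [{}]}

    # instance_info.get("networkInterfaces") is Optional[...], so [0] on it
    # is an error for MyPy; instance_info[...] is typed as the value itself.
    access_config = instance_info["networkInterfaces"][0].get("accessConfigs")
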
airflow/providers/google/cloud/hooks/dataflow.py

Lines changed: 2 additions & 1 deletion

@@ -503,7 +503,8 @@ def cancel(self) -> None:
             timeout_error_message = (
                 f"Canceling jobs failed due to timeout ({self._cancel_timeout}s): {', '.join(job_ids)}"
             )
-            with timeout(seconds=self._cancel_timeout, error_message=timeout_error_message):
+            tm = timeout(seconds=self._cancel_timeout, error_message=timeout_error_message)
+            with tm:
                 self._wait_for_states({DataflowJobStatus.JOB_STATE_CANCELLED})
         else:
             self.log.info("No jobs to cancel")

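Splitting the with statement is behavior-preserving; presumably binding the context manager to a name first gives MyPy a plain assignment whose type it can check. In isolation:

    from airflow.utils.timeout import timeout

    tm = timeout(seconds=10, error_message="Canceling jobs timed out")
    with tm:
        ...  # wait for the jobs to reach the desired state
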
airflow/providers/google/cloud/operators/datacatalog.py

Lines changed: 6 additions & 3 deletions

@@ -375,9 +375,12 @@ def execute(self, context: dict):
             )
         except AlreadyExists:
             self.log.info("Tag already exists. Skipping create operation.")
+        project_id = self.project_id or hook.project_id
+        if project_id is None:
+            raise RuntimeError("The project id must be set here")
         if self.template_id:
             template_name = DataCatalogClient.tag_template_path(
-                self.project_id or hook.project_id, self.location, self.template_id
+                project_id, self.location, self.template_id
             )
         else:
             if isinstance(self.tag, Tag):
@@ -390,7 +393,7 @@ def execute(self, context: dict):
             entry_group=self.entry_group,
             template_name=template_name,
             entry=self.entry,
-            project_id=self.project_id,
+            project_id=project_id,
             retry=self.retry,
             timeout=self.timeout,
             metadata=self.metadata,
@@ -1265,7 +1268,7 @@ def __init__(
         *,
         location: str,
         entry_group: str,
-        read_mask: Union[Dict, FieldMask],
+        read_mask: FieldMask,
         project_id: Optional[str] = None,
         retry: Optional[Retry] = None,
         timeout: Optional[float] = None,

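The added guard is the standard way to narrow Optional[str] to str once, so every later use (tag_template_path, the hook call) type-checks without per-line ignores. The shape of it, as a standalone sketch:

    from typing import Optional

    def resolve_project_id(explicit: Optional[str], fallback: Optional[str]) -> str:
        project_id = explicit or fallback
        if project_id is None:  # narrows Optional[str] to str for MyPy
            raise RuntimeError("The project id must be set here")
        return project_id
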
airflow/providers/google/cloud/operators/dataprep.py

Lines changed: 5 additions & 1 deletion

@@ -16,10 +16,14 @@
 # specific language governing permissions and limitations
 # under the License.
 """This module contains a Google Dataprep operator."""
+from typing import TYPE_CHECKING
 
 from airflow.models import BaseOperator
 from airflow.providers.google.cloud.hooks.dataprep import GoogleDataprepHook
 
+if TYPE_CHECKING:
+    from airflow.utils.context import Context
+
 
 class DataprepGetJobsForJobGroupOperator(BaseOperator):
     """
@@ -121,7 +125,7 @@ def __init__(self, *, dataprep_conn_id: str = "dataprep_default", body_request:
         self.body_request = body_request
         self.dataprep_conn_id = dataprep_conn_id
 
-    def execute(self, context: None) -> dict:
+    def execute(self, context: "Context") -> dict:
         self.log.info("Creating a job...")
         hook = GoogleDataprepHook(dataprep_conn_id=self.dataprep_conn_id)
         response = hook.run_job_group(body_request=self.body_request)

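The TYPE_CHECKING block gives execute a real annotation without importing airflow.utils.context at runtime (no import cost, no circular-import risk); the quoted "Context" stays valid because the name only exists for the type checker. The pattern, with a hypothetical operator:

    from typing import TYPE_CHECKING

    from airflow.models import BaseOperator

    if TYPE_CHECKING:
        from airflow.utils.context import Context  # imported only by the type checker

    class ExampleOperator(BaseOperator):  # hypothetical operator for illustration
        def execute(self, context: "Context") -> None:
            self.log.info("Running for logical date %s", context["ds"])
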
airflow/providers/google/cloud/operators/gcs.py

Lines changed: 23 additions & 20 deletions

@@ -22,7 +22,10 @@
 import warnings
 from pathlib import Path
 from tempfile import NamedTemporaryFile, TemporaryDirectory
-from typing import Dict, Iterable, List, Optional, Sequence, Union
+from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Sequence, Union
+
+if TYPE_CHECKING:
+    from airflow.utils.context import Context
 
 from google.api_core.exceptions import Conflict
 from google.cloud.exceptions import GoogleCloudError
@@ -152,7 +155,7 @@ def __init__(
         self.delegate_to = delegate_to
         self.impersonation_chain = impersonation_chain
 
-    def execute(self, context) -> None:
+    def execute(self, context: "Context") -> None:
         hook = GCSHook(
             gcp_conn_id=self.gcp_conn_id,
             delegate_to=self.delegate_to,
@@ -258,7 +261,7 @@ def __init__(
         self.delegate_to = delegate_to
         self.impersonation_chain = impersonation_chain
 
-    def execute(self, context) -> list:
+    def execute(self, context: "Context") -> list:
 
         hook = GCSHook(
             gcp_conn_id=self.gcp_conn_id,
@@ -320,7 +323,7 @@ def __init__(
         self,
         *,
         bucket_name: str,
-        objects: Optional[Iterable[str]] = None,
+        objects: Optional[List[str]] = None,
         prefix: Optional[str] = None,
         gcp_conn_id: str = 'google_cloud_default',
         google_cloud_storage_conn_id: Optional[str] = None,
@@ -350,7 +353,7 @@ def __init__(
 
         super().__init__(**kwargs)
 
-    def execute(self, context):
+    def execute(self, context: "Context") -> None:
         hook = GCSHook(
             gcp_conn_id=self.gcp_conn_id,
             delegate_to=self.delegate_to,
@@ -443,7 +446,7 @@ def __init__(
         self.gcp_conn_id = gcp_conn_id
         self.impersonation_chain = impersonation_chain
 
-    def execute(self, context) -> None:
+    def execute(self, context: "Context") -> None:
         hook = GCSHook(
             gcp_conn_id=self.gcp_conn_id,
             impersonation_chain=self.impersonation_chain,
@@ -541,7 +544,7 @@ def __init__(
         self.gcp_conn_id = gcp_conn_id
         self.impersonation_chain = impersonation_chain
 
-    def execute(self, context) -> None:
+    def execute(self, context: "Context") -> None:
         hook = GCSHook(
             gcp_conn_id=self.gcp_conn_id,
             impersonation_chain=self.impersonation_chain,
@@ -620,7 +623,7 @@ def __init__(
         self.output_encoding = sys.getdefaultencoding()
         self.impersonation_chain = impersonation_chain
 
-    def execute(self, context: dict) -> None:
+    def execute(self, context: "Context") -> None:
         hook = GCSHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain)
 
         with NamedTemporaryFile() as source_file, NamedTemporaryFile() as destination_file:
@@ -742,7 +745,7 @@ class GCSTimeSpanFileTransformOperator(BaseOperator):
     )
 
     @staticmethod
-    def interpolate_prefix(prefix: str, dt: datetime.datetime) -> Optional[datetime.datetime]:
+    def interpolate_prefix(prefix: str, dt: datetime.datetime) -> Optional[str]:
         """Interpolate prefix with datetime.
 
         :param prefix: The prefix to interpolate
@@ -792,7 +795,7 @@ def __init__(
         self.upload_continue_on_fail = upload_continue_on_fail
         self.upload_num_attempts = upload_num_attempts
 
-    def execute(self, context: dict) -> None:
+    def execute(self, context: "Context") -> List[str]:
         # Define intervals and prefixes.
         try:
             timespan_start = context["data_interval_start"]
@@ -838,12 +841,12 @@ def execute(self, context: dict) -> None:
             )
 
         with TemporaryDirectory() as temp_input_dir, TemporaryDirectory() as temp_output_dir:
-            temp_input_dir = Path(temp_input_dir)
-            temp_output_dir = Path(temp_output_dir)
+            temp_input_dir_path = Path(temp_input_dir)
+            temp_output_dir_path = Path(temp_output_dir)
 
             # TODO: download in parallel.
             for blob_to_transform in blobs_to_transform:
-                destination_file = temp_input_dir / blob_to_transform
+                destination_file = temp_input_dir_path / blob_to_transform
                 destination_file.parent.mkdir(parents=True, exist_ok=True)
                 try:
                     source_hook.download(
@@ -861,8 +864,8 @@ def execute(self, context: dict) -> None:
             self.log.info("Starting the transformation")
             cmd = [self.transform_script] if isinstance(self.transform_script, str) else self.transform_script
             cmd += [
-                str(temp_input_dir),
-                str(temp_output_dir),
+                str(temp_input_dir_path),
+                str(temp_output_dir_path),
                 timespan_start.replace(microsecond=0).isoformat(),
                 timespan_end.replace(microsecond=0).isoformat(),
             ]
@@ -878,16 +881,16 @@ def execute(self, context: dict) -> None:
             if process.returncode:
                 raise AirflowException(f"Transform script failed: {process.returncode}")
 
-            self.log.info("Transformation succeeded. Output temporarily located at %s", temp_output_dir)
+            self.log.info("Transformation succeeded. Output temporarily located at %s", temp_output_dir_path)
 
             files_uploaded = []
 
             # TODO: upload in parallel.
-            for upload_file in temp_output_dir.glob("**/*"):
+            for upload_file in temp_output_dir_path.glob("**/*"):
                 if upload_file.is_dir():
                     continue
 
-                upload_file_name = str(upload_file.relative_to(temp_output_dir))
+                upload_file_name = str(upload_file.relative_to(temp_output_dir_path))
 
                 if self.destination_prefix is not None:
                     upload_file_name = f"{destination_prefix_interp}/{upload_file_name}"
@@ -959,7 +962,7 @@ def __init__(
         self.gcp_conn_id = gcp_conn_id
         self.impersonation_chain = impersonation_chain
 
-    def execute(self, context) -> None:
+    def execute(self, context: "Context") -> None:
         hook = GCSHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain)
         hook.delete_bucket(bucket_name=self.bucket_name, force=self.force)
 
@@ -1056,7 +1059,7 @@ def __init__(
         self.delegate_to = delegate_to
         self.impersonation_chain = impersonation_chain
 
-    def execute(self, context) -> None:
+    def execute(self, context: "Context") -> None:
         hook = GCSHook(
             gcp_conn_id=self.gcp_conn_id,
             delegate_to=self.delegate_to,

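Most of this file is the same Context annotation applied across operators, plus two genuine typing fixes: interpolate_prefix returns a str (not a datetime), and the temp-dir variables are no longer rebound. TemporaryDirectory yields a str, and rebinding that name to a Path changes the variable's type mid-scope, which MyPy rejects; a fresh name keeps both types stable. Schematically:

    from pathlib import Path
    from tempfile import TemporaryDirectory

    with TemporaryDirectory() as temp_output_dir:  # temp_output_dir: str
        temp_output_dir_path = Path(temp_output_dir)  # new name, new type
        for upload_file in temp_output_dir_path.glob("**/*"):
            print(upload_file.relative_to(temp_output_dir_path))
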
airflow/providers/google/cloud/operators/mlengine.py

Lines changed: 3 additions & 3 deletions

@@ -1164,9 +1164,9 @@ def __init__(
         *,
         job_id: str,
         region: str,
-        package_uris: List[str] = None,
-        training_python_module: str = None,
-        training_args: List[str] = None,
+        package_uris: Optional[List[str]] = None,
+        training_python_module: Optional[str] = None,
+        training_args: Optional[List[str]] = None,
         scale_tier: Optional[str] = None,
         master_type: Optional[str] = None,
         master_config: Optional[Dict] = None,

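These were implicit-Optional declarations: a None default does not by itself make the annotation Optional, and MyPy's no_implicit_optional mode (the default in newer MyPy releases) rejects the shorthand, so the Optional must be spelled out. Compactly:

    from typing import List, Optional

    # `package_uris: List[str] = None` relies on implicit Optional and is
    # rejected under no_implicit_optional; the explicit form type-checks.
    def submit_job(package_uris: Optional[List[str]] = None) -> List[str]:
        return package_uris or []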