
Commit 915f9e4

Add GCS Requester Pays bucket support to GCSToS3Operator (#32760)
* Add requester pays bucket support to GCSToS3Operator
* Update docstrings
* isort
* Fix failing unit tests
* Fix failing test
1 parent e46929b commit 915f9e4

File tree: 7 files changed (+126, -35 lines)

airflow/providers/amazon/aws/transfers/gcs_to_s3.py
Lines changed: 20 additions & 12 deletions

@@ -80,6 +80,8 @@ class GCSToS3Operator(BaseOperator):
         on the bucket is recreated within path passed in dest_s3_key.
     :param match_glob: (Optional) filters objects based on the glob pattern given by the string
         (e.g, ``'**/*/.json'``)
+    :param gcp_user_project: (Optional) The identifier of the Google Cloud project to bill for this request.
+        Required for Requester Pays buckets.
     """

     template_fields: Sequence[str] = (
@@ -88,6 +90,7 @@ class GCSToS3Operator(BaseOperator):
         "delimiter",
         "dest_s3_key",
         "google_impersonation_chain",
+        "gcp_user_project",
     )
     ui_color = "#f0eee4"

@@ -107,6 +110,7 @@ def __init__(
         s3_acl_policy: str | None = None,
         keep_directory_structure: bool = True,
         match_glob: str | None = None,
+        gcp_user_project: str | None = None,
         **kwargs,
     ) -> None:
         super().__init__(**kwargs)
@@ -130,10 +134,11 @@ def __init__(
         self.s3_acl_policy = s3_acl_policy
         self.keep_directory_structure = keep_directory_structure
         self.match_glob = match_glob
+        self.gcp_user_project = gcp_user_project

     def execute(self, context: Context) -> list[str]:
         # list all files in an Google Cloud Storage bucket
-        hook = GCSHook(
+        gcs_hook = GCSHook(
             gcp_conn_id=self.gcp_conn_id,
             impersonation_chain=self.google_impersonation_chain,
         )
@@ -145,8 +150,12 @@ def execute(self, context: Context) -> list[str]:
             self.prefix,
         )

-        files = hook.list(
-            bucket_name=self.bucket, prefix=self.prefix, delimiter=self.delimiter, match_glob=self.match_glob
+        gcs_files = gcs_hook.list(
+            bucket_name=self.bucket,
+            prefix=self.prefix,
+            delimiter=self.delimiter,
+            match_glob=self.match_glob,
+            user_project=self.gcp_user_project,
         )

         s3_hook = S3Hook(
@@ -173,24 +182,23 @@ def execute(self, context: Context) -> list[str]:
             existing_files = existing_files if existing_files is not None else []
             # remove the prefix for the existing files to allow the match
             existing_files = [file.replace(prefix, "", 1) for file in existing_files]
-            files = list(set(files) - set(existing_files))
+            gcs_files = list(set(gcs_files) - set(existing_files))

-        if files:
-
-            for file in files:
-                with hook.provide_file(object_name=file, bucket_name=self.bucket) as local_tmp_file:
+        if gcs_files:
+            for file in gcs_files:
+                with gcs_hook.provide_file(
+                    object_name=file, bucket_name=self.bucket, user_project=self.gcp_user_project
+                ) as local_tmp_file:
                     dest_key = os.path.join(self.dest_s3_key, file)
                     self.log.info("Saving file to %s", dest_key)
-
                     s3_hook.load_file(
                         filename=local_tmp_file.name,
                         key=dest_key,
                         replace=self.replace,
                         acl_policy=self.s3_acl_policy,
                     )
-
-            self.log.info("All done, uploaded %d files to S3", len(files))
+            self.log.info("All done, uploaded %d files to S3", len(gcs_files))
         else:
             self.log.info("In sync, no files needed to be uploaded to S3")

-        return files
+        return gcs_files
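
In practice, the new parameter lets a DAG copy objects out of a Requester Pays bucket while billing a project of the caller's choosing. A minimal sketch of such a task, assuming hypothetical bucket names, billing project, and DAG settings (none of these identifiers come from the commit):

from __future__ import annotations

import pendulum

from airflow import DAG
from airflow.providers.amazon.aws.transfers.gcs_to_s3 import GCSToS3Operator

with DAG(
    dag_id="example_gcs_to_s3_requester_pays",
    start_date=pendulum.datetime(2023, 8, 1, tz="UTC"),
    schedule=None,
    catchup=False,
):
    # The source bucket has Requester Pays enabled, so GCS read charges
    # are billed to "my-billing-project" rather than the bucket owner.
    copy_exports = GCSToS3Operator(
        task_id="copy_exports",
        bucket="requester-pays-source-bucket",
        prefix="exports/",
        dest_s3_key="s3://my-dest-bucket/exports/",
        gcp_user_project="my-billing-project",
    )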

airflow/providers/google/cloud/hooks/gcs.py
Lines changed: 41 additions & 11 deletions

@@ -197,7 +197,6 @@ def copy(
         destination_object = destination_object or source_object

         if source_bucket == destination_bucket and source_object == destination_object:
-
             raise ValueError(
                 f"Either source/destination bucket or source/destination object must be different, "
                 f"not both the same: bucket={source_bucket}, object={source_object}"
@@ -282,6 +281,7 @@ def download(
         chunk_size: int | None = None,
         timeout: int | None = DEFAULT_TIMEOUT,
         num_max_attempts: int | None = 1,
+        user_project: str | None = None,
     ) -> bytes:
         ...
@@ -294,6 +294,7 @@ def download(
         chunk_size: int | None = None,
         timeout: int | None = DEFAULT_TIMEOUT,
         num_max_attempts: int | None = 1,
+        user_project: str | None = None,
     ) -> str:
         ...
@@ -305,6 +306,7 @@ def download(
         chunk_size: int | None = None,
         timeout: int | None = DEFAULT_TIMEOUT,
         num_max_attempts: int | None = 1,
+        user_project: str | None = None,
     ) -> str | bytes:
         """
         Downloads a file from Google Cloud Storage.
@@ -320,6 +322,8 @@ def download(
         :param chunk_size: Blob chunk size.
         :param timeout: Request timeout in seconds.
         :param num_max_attempts: Number of attempts to download the file.
+        :param user_project: The identifier of the Google Cloud project to bill for the request.
+            Required for Requester Pays buckets.
         """
         # TODO: future improvement check file size before downloading,
         # to check for local space availability
@@ -330,7 +334,7 @@ def download(
            try:
                num_file_attempts += 1
                client = self.get_conn()
-               bucket = client.bucket(bucket_name)
+               bucket = client.bucket(bucket_name, user_project=user_project)
                blob = bucket.blob(blob_name=object_name, chunk_size=chunk_size)

                if filename:
@@ -395,6 +399,7 @@ def provide_file(
         object_name: str | None = None,
         object_url: str | None = None,
         dir: str | None = None,
+        user_project: str | None = None,
     ) -> Generator[IO[bytes], None, None]:
         """
         Downloads the file to a temporary directory and returns a file handle.
@@ -406,13 +411,20 @@ def provide_file(
         :param object_name: The object to fetch.
         :param object_url: File reference url. Must start with "gs: //"
         :param dir: The tmp sub directory to download the file to. (passed to NamedTemporaryFile)
+        :param user_project: The identifier of the Google Cloud project to bill for the request.
+            Required for Requester Pays buckets.
         :return: File handler
         """
         if object_name is None:
             raise ValueError("Object name can not be empty")
         _, _, file_name = object_name.rpartition("/")
         with NamedTemporaryFile(suffix=file_name, dir=dir) as tmp_file:
-            self.download(bucket_name=bucket_name, object_name=object_name, filename=tmp_file.name)
+            self.download(
+                bucket_name=bucket_name,
+                object_name=object_name,
+                filename=tmp_file.name,
+                user_project=user_project,
+            )
             tmp_file.flush()
             yield tmp_file

@@ -423,6 +435,7 @@ def provide_file_and_upload(
         bucket_name: str = PROVIDE_BUCKET,
         object_name: str | None = None,
         object_url: str | None = None,
+        user_project: str | None = None,
     ) -> Generator[IO[bytes], None, None]:
         """
         Creates temporary file, returns a file handle and uploads the files content on close.
@@ -433,6 +446,8 @@ def provide_file_and_upload(
         :param bucket_name: The bucket to fetch from.
         :param object_name: The object to fetch.
         :param object_url: File reference url. Must start with "gs: //"
+        :param user_project: The identifier of the Google Cloud project to bill for the request.
+            Required for Requester Pays buckets.
         :return: File handler
         """
         if object_name is None:
@@ -442,7 +457,12 @@ def provide_file_and_upload(
         with NamedTemporaryFile(suffix=file_name) as tmp_file:
             yield tmp_file
             tmp_file.flush()
-            self.upload(bucket_name=bucket_name, object_name=object_name, filename=tmp_file.name)
+            self.upload(
+                bucket_name=bucket_name,
+                object_name=object_name,
+                filename=tmp_file.name,
+                user_project=user_project,
+            )

     def upload(
         self,
@@ -458,6 +478,7 @@ def upload(
         num_max_attempts: int = 1,
         metadata: dict | None = None,
         cache_control: str | None = None,
+        user_project: str | None = None,
     ) -> None:
         """
         Uploads a local file or file data as string or bytes to Google Cloud Storage.
@@ -474,6 +495,8 @@ def upload(
         :param num_max_attempts: Number of attempts to try to upload the file.
         :param metadata: The metadata to be uploaded with the file.
         :param cache_control: Cache-Control metadata field.
+        :param user_project: The identifier of the Google Cloud project to bill for the request.
+            Required for Requester Pays buckets.
         """

         def _call_with_retry(f: Callable[[], None]) -> None:
@@ -506,7 +529,7 @@ def _call_with_retry(f: Callable[[], None]) -> None:
                 continue

         client = self.get_conn()
-        bucket = client.bucket(bucket_name)
+        bucket = client.bucket(bucket_name, user_project=user_project)
         blob = bucket.blob(blob_name=object_name, chunk_size=chunk_size)

         if metadata:
@@ -596,7 +619,6 @@ def is_updated_after(self, bucket_name: str, object_name: str, ts: datetime) -> bool:
         """
         blob_update_time = self.get_blob_update_time(bucket_name, object_name)
         if blob_update_time is not None:
-
             if not ts.tzinfo:
                 ts = ts.replace(tzinfo=timezone.utc)
             self.log.info("Verify object date: %s > %s", blob_update_time, ts)
@@ -618,7 +640,6 @@ def is_updated_between(
         """
         blob_update_time = self.get_blob_update_time(bucket_name, object_name)
         if blob_update_time is not None:
-
             if not min_ts.tzinfo:
                 min_ts = min_ts.replace(tzinfo=timezone.utc)
             if not max_ts.tzinfo:
@@ -639,7 +660,6 @@ def is_updated_before(self, bucket_name: str, object_name: str, ts: datetime) -> bool:
         """
         blob_update_time = self.get_blob_update_time(bucket_name, object_name)
         if blob_update_time is not None:
-
             if not ts.tzinfo:
                 ts = ts.replace(tzinfo=timezone.utc)
             self.log.info("Verify object date: %s < %s", blob_update_time, ts)
@@ -681,16 +701,18 @@ def delete(self, bucket_name: str, object_name: str) -> None:

         self.log.info("Blob %s deleted.", object_name)

-    def delete_bucket(self, bucket_name: str, force: bool = False) -> None:
+    def delete_bucket(self, bucket_name: str, force: bool = False, user_project: str | None = None) -> None:
         """
         Delete a bucket object from the Google Cloud Storage.

         :param bucket_name: name of the bucket which will be deleted
         :param force: false not allow to delete non empty bucket, set force=True
            allows to delete non empty bucket
+        :param user_project: The identifier of the Google Cloud project to bill for the request.
+            Required for Requester Pays buckets.
         """
         client = self.get_conn()
-        bucket = client.bucket(bucket_name)
+        bucket = client.bucket(bucket_name, user_project=user_project)

         self.log.info("Deleting %s bucket", bucket_name)
         try:
@@ -707,6 +729,7 @@ def list(
         prefix: str | List[str] | None = None,
         delimiter: str | None = None,
         match_glob: str | None = None,
+        user_project: str | None = None,
     ):
         """
         List all objects from the bucket with the given a single prefix or multiple prefixes.
@@ -718,6 +741,8 @@ def list(
         :param delimiter: (Deprecated) filters objects based on the delimiter (for e.g '.csv')
         :param match_glob: (Optional) filters objects based on the glob pattern given by the string
             (e.g, ``'**/*/.json'``).
+        :param user_project: The identifier of the Google Cloud project to bill for the request.
+            Required for Requester Pays buckets.
         :return: a stream of object names matching the filtering criteria
         """
         if delimiter and delimiter != "/":
@@ -739,6 +764,7 @@ def list(
                         prefix=prefix_item,
                         delimiter=delimiter,
                         match_glob=match_glob,
+                        user_project=user_project,
                     )
                 )
         else:
@@ -750,6 +776,7 @@ def list(
                     prefix=prefix,
                     delimiter=delimiter,
                     match_glob=match_glob,
+                    user_project=user_project,
                 )
             )
         return objects
@@ -762,6 +789,7 @@ def _list(
         prefix: str | None = None,
         delimiter: str | None = None,
         match_glob: str | None = None,
+        user_project: str | None = None,
     ) -> List:
         """
         List all objects from the bucket with the give string prefix in name.
@@ -773,10 +801,12 @@ def _list(
         :param delimiter: (Deprecated) filters objects based on the delimiter (for e.g '.csv')
         :param match_glob: (Optional) filters objects based on the glob pattern given by the string
             (e.g, ``'**/*/.json'``).
+        :param user_project: The identifier of the Google Cloud project to bill for the request.
+            Required for Requester Pays buckets.
         :return: a stream of object names matching the filtering criteria
         """
         client = self.get_conn()
-        bucket = client.bucket(bucket_name)
+        bucket = client.bucket(bucket_name, user_project=user_project)

         ids = []
         page_token = None
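
Because the support lives in GCSHook, any caller can bill a Requester Pays bucket directly; user_project is simply forwarded to google-cloud-storage's Client.bucket(), which attaches the billing project to requests made through that bucket handle. A short sketch of direct hook usage, with placeholder connection, bucket, and project names:

from airflow.providers.google.cloud.hooks.gcs import GCSHook

hook = GCSHook(gcp_conn_id="google_cloud_default")

# List objects in a Requester Pays bucket, billing "my-billing-project".
names = hook.list(
    bucket_name="requester-pays-source-bucket",
    prefix="exports/",
    user_project="my-billing-project",
)

# Download the first object (assuming the listing is non-empty) to a
# temporary file; the same project is billed for the read.
with hook.provide_file(
    bucket_name="requester-pays-source-bucket",
    object_name=names[0],
    user_project="my-billing-project",
) as tmp_file:
    print(tmp_file.name)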

airflow/providers/google/cloud/operators/gcs.py
Lines changed: 6 additions & 2 deletions

@@ -301,7 +301,6 @@ def __init__(
         impersonation_chain: str | Sequence[str] | None = None,
         **kwargs,
     ) -> None:
-
         self.bucket_name = bucket_name
         self.objects = objects
         self.prefix = prefix
@@ -875,12 +874,15 @@ class GCSDeleteBucketOperator(GoogleCloudBaseOperator):
         If set as a sequence, the identities from the list must grant
         Service Account Token Creator IAM role to the directly preceding identity, with first
         account from the list granting this role to the originating account (templated).
+    :param user_project: (Optional) The identifier of the project to bill for this request.
+        Required for Requester Pays buckets.
     """

     template_fields: Sequence[str] = (
         "bucket_name",
         "gcp_conn_id",
         "impersonation_chain",
+        "user_project",
     )

     def __init__(
@@ -890,6 +892,7 @@ def __init__(
         force: bool = True,
         gcp_conn_id: str = "google_cloud_default",
         impersonation_chain: str | Sequence[str] | None = None,
+        user_project: str | None = None,
         **kwargs,
     ) -> None:
         super().__init__(**kwargs)
@@ -898,10 +901,11 @@ def __init__(
         self.force: bool = force
         self.gcp_conn_id = gcp_conn_id
         self.impersonation_chain = impersonation_chain
+        self.user_project = user_project

     def execute(self, context: Context) -> None:
         hook = GCSHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain)
-        hook.delete_bucket(bucket_name=self.bucket_name, force=self.force)
+        hook.delete_bucket(bucket_name=self.bucket_name, force=self.force, user_project=self.user_project)


 class GCSSynchronizeBucketsOperator(GoogleCloudBaseOperator):
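
GCSDeleteBucketOperator gains the same option, so removing a Requester Pays bucket can also name the project to bill. A hypothetical task definition (all identifiers are illustrative):

from airflow.providers.google.cloud.operators.gcs import GCSDeleteBucketOperator

delete_bucket = GCSDeleteBucketOperator(
    task_id="delete_requester_pays_bucket",
    bucket_name="requester-pays-source-bucket",
    force=True,  # delete even if the bucket still contains objects
    user_project="my-billing-project",
)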

tests/providers/amazon/aws/transfers/test_gcs_to_s3.py
Lines changed: 5 additions & 1 deletion

@@ -69,7 +69,11 @@ def test_execute__match_glob(self, mock_hook):

         operator.execute(None)
         mock_hook.return_value.list.assert_called_once_with(
-            bucket_name=GCS_BUCKET, delimiter=None, match_glob=f"**/*{DELIMITER}", prefix=PREFIX
+            bucket_name=GCS_BUCKET,
+            delimiter=None,
+            match_glob=f"**/*{DELIMITER}",
+            prefix=PREFIX,
+            user_project=None,
         )

     @mock.patch("airflow.providers.amazon.aws.transfers.gcs_to_s3.GCSHook")
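
A companion test for the non-default path would assert that the operator forwards gcp_user_project to the hook. A sketch in the style of the surrounding tests (the constants GCS_BUCKET and PREFIX come from this file; the task ID, destination key, and billing project are made up, and replace=True keeps the sketch from touching S3 at all):

@mock.patch("airflow.providers.amazon.aws.transfers.gcs_to_s3.GCSHook")
def test_execute__user_project(self, mock_hook):
    mock_hook.return_value.list.return_value = []
    operator = GCSToS3Operator(
        task_id="test-gcs-to-s3",
        bucket=GCS_BUCKET,
        prefix=PREFIX,
        dest_s3_key="s3://dest-bucket/",
        replace=True,
        gcp_user_project="my-billing-project",
    )
    operator.execute(None)
    # The billing project must be passed straight through to GCSHook.list.
    mock_hook.return_value.list.assert_called_once_with(
        bucket_name=GCS_BUCKET,
        delimiter=None,
        match_glob=None,
        prefix=PREFIX,
        user_project="my-billing-project",
    )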
