
Commit fb6c501

Add flag apply_gcs_prefix to S3ToGCSOperator (b/245077385) (#31127)
1 parent fdc7a31 commit fb6c501

File tree

3 files changed: +211 -64 lines changed


airflow/providers/google/cloud/transfers/s3_to_gcs.py

Lines changed: 70 additions & 53 deletions
```diff
@@ -45,6 +45,15 @@ class S3ToGCSOperator(S3ListOperator):
     :param bucket: The S3 bucket where to find the objects. (templated)
     :param prefix: Prefix string which filters objects whose name begin with
         such prefix. (templated)
+    :param apply_gcs_prefix: (Optional) Whether to replace source objects' paths with the given GCS destination path.
+        If apply_gcs_prefix is False (default), then objects from S3 will be copied to the GCS bucket into the given
+        GCS path, and the source path will be placed inside it. For example,
+        <s3_bucket><s3_prefix><content> => <gcs_prefix><s3_prefix><content>
+
+        If apply_gcs_prefix is True, then objects from S3 will be copied to the GCS bucket into the given
+        GCS path, and the source path will be omitted. For example:
+        <s3_bucket><s3_prefix><content> => <gcs_prefix><content>
+
     :param delimiter: the delimiter marks key hierarchy. (templated)
     :param aws_conn_id: The source S3 connection
     :param verify: Whether or not to verify SSL certificates for S3 connection.
```
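To make the mapping described above concrete, here is a minimal, Airflow-free sketch of the same transformation. The function name and sample paths are illustrative only, not part of the commit:

```python
# Hypothetical stand-in for S3ToGCSOperator.s3_to_gcs_object (added later in
# this diff); shown standalone so the two modes are easy to compare.
def map_s3_to_gcs(s3_object: str, s3_prefix: str, gcs_prefix: str, apply_gcs_prefix: bool) -> str:
    if apply_gcs_prefix:
        # True: the S3 prefix is replaced by the GCS prefix.
        return s3_object.replace(s3_prefix, gcs_prefix, 1)
    # False (default): the GCS prefix is prepended, keeping the S3 path under it.
    return gcs_prefix + s3_object


assert map_s3_to_gcs("TEST/file.csv", "TEST/", "data/", apply_gcs_prefix=False) == "data/TEST/file.csv"
assert map_s3_to_gcs("TEST/file.csv", "TEST/", "data/", apply_gcs_prefix=True) == "data/file.csv"
```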
```diff
@@ -106,6 +115,7 @@ def __init__(
         *,
         bucket,
         prefix="",
+        apply_gcs_prefix=False,
         delimiter="",
         aws_conn_id="aws_default",
         verify=None,
@@ -118,6 +128,7 @@ def __init__(
     ):
 
         super().__init__(bucket=bucket, prefix=prefix, delimiter=delimiter, aws_conn_id=aws_conn_id, **kwargs)
+        self.apply_gcs_prefix = apply_gcs_prefix
         self.gcp_conn_id = gcp_conn_id
         self.dest_gcs = dest_gcs
         self.replace = replace
```
```diff
@@ -139,68 +150,74 @@ def _check_inputs(self) -> None:
     def execute(self, context: Context):
         self._check_inputs()
         # use the super method to list all the files in an S3 bucket/key
-        files = super().execute(context)
+        s3_objects = super().execute(context)
 
         gcs_hook = GCSHook(
             gcp_conn_id=self.gcp_conn_id,
             impersonation_chain=self.google_impersonation_chain,
         )
-
         if not self.replace:
-            # if we are not replacing -> list all files in the GCS bucket
-            # and only keep those files which are present in
-            # S3 and not in Google Cloud Storage
-            bucket_name, object_prefix = _parse_gcs_url(self.dest_gcs)
-            existing_files_prefixed = gcs_hook.list(bucket_name, prefix=object_prefix)
-
-            existing_files = []
-
-            if existing_files_prefixed:
-                # Remove the object prefix itself, an empty directory was found
-                if object_prefix in existing_files_prefixed:
-                    existing_files_prefixed.remove(object_prefix)
-
-                # Remove the object prefix from all object string paths
-                for f in existing_files_prefixed:
-                    if f.startswith(object_prefix):
-                        existing_files.append(f[len(object_prefix) :])
-                    else:
-                        existing_files.append(f)
-
-            files = list(set(files) - set(existing_files))
-            if len(files) > 0:
-                self.log.info("%s files are going to be synced: %s.", len(files), files)
-            else:
-                self.log.info("There are no new files to sync. Have a nice day!")
-
-        if files:
+            s3_objects = self.exclude_existing_objects(s3_objects=s3_objects, gcs_hook=gcs_hook)
+
+        if s3_objects:
             hook = S3Hook(aws_conn_id=self.aws_conn_id, verify=self.verify)
 
-            for file in files:
-                # GCS hook builds its own in-memory file so we have to create
+            dest_gcs_bucket, dest_gcs_object_prefix = _parse_gcs_url(self.dest_gcs)
+            for obj in s3_objects:
+                # GCS hook builds its own in-memory file, so we have to create
                 # and pass the path
-                file_object = hook.get_key(file, self.bucket)
-                with NamedTemporaryFile(mode="wb", delete=True) as f:
-                    file_object.download_fileobj(f)
-                    f.flush()
-
-                    dest_gcs_bucket, dest_gcs_object_prefix = _parse_gcs_url(self.dest_gcs)
-                    # There will always be a '/' before file because it is
-                    # enforced at instantiation time
-                    dest_gcs_object = dest_gcs_object_prefix + file
-
-                    # Sync is sequential and the hook already logs too much
-                    # so skip this for now
-                    # self.log.info(
-                    #     'Saving file {0} from S3 bucket {1} in GCS bucket {2}'
-                    #     ' as object {3}'.format(file, self.bucket,
-                    #                             dest_gcs_bucket,
-                    #                             dest_gcs_object))
-
-                    gcs_hook.upload(dest_gcs_bucket, dest_gcs_object, f.name, gzip=self.gzip)
-
-            self.log.info("All done, uploaded %d files to Google Cloud Storage", len(files))
+                file_object = hook.get_key(obj, self.bucket)
+                with NamedTemporaryFile(mode="wb", delete=True) as file:
+                    file_object.download_fileobj(file)
+                    file.flush()
+                    gcs_file = self.s3_to_gcs_object(s3_object=obj)
+                    gcs_hook.upload(dest_gcs_bucket, gcs_file, file.name, gzip=self.gzip)
+
+            self.log.info("All done, uploaded %d files to Google Cloud Storage", len(s3_objects))
         else:
             self.log.info("In sync, no files needed to be uploaded to Google Cloud Storage")
 
-        return files
+        return s3_objects
+
+    def exclude_existing_objects(self, s3_objects: list[str], gcs_hook: GCSHook) -> list[str]:
+        """Excludes from the list objects that already exist in the GCS bucket."""
+        bucket_name, object_prefix = _parse_gcs_url(self.dest_gcs)
+
+        existing_gcs_objects = set(gcs_hook.list(bucket_name, prefix=object_prefix))
+
+        s3_paths = set(self.gcs_to_s3_object(gcs_object=gcs_object) for gcs_object in existing_gcs_objects)
+        s3_objects_reduced = list(set(s3_objects) - s3_paths)
+
+        if s3_objects_reduced:
+            self.log.info("%s files are going to be synced: %s.", len(s3_objects_reduced), s3_objects_reduced)
+        else:
+            self.log.info("There are no new files to sync. Have a nice day!")
+        return s3_objects_reduced
+
+    def s3_to_gcs_object(self, s3_object: str) -> str:
+        """
+        Transforms an S3 path to a GCS path according to the operator's logic.
+
+        If apply_gcs_prefix == True then <s3_prefix><content> => <gcs_prefix><content>
+        If apply_gcs_prefix == False then <s3_prefix><content> => <gcs_prefix><s3_prefix><content>
+        """
+        gcs_bucket, gcs_prefix = _parse_gcs_url(self.dest_gcs)
+        if self.apply_gcs_prefix:
+            gcs_object = s3_object.replace(self.prefix, gcs_prefix, 1)
+            return gcs_object
+        return gcs_prefix + s3_object
+
+    def gcs_to_s3_object(self, gcs_object: str) -> str:
+        """
+        Transforms a GCS path to an S3 path according to the operator's logic.
+
+        If apply_gcs_prefix == True then <gcs_prefix><content> => <s3_prefix><content>
+        If apply_gcs_prefix == False then <gcs_prefix><s3_prefix><content> => <s3_prefix><content>
+        """
+        gcs_bucket, gcs_prefix = _parse_gcs_url(self.dest_gcs)
+        s3_object = gcs_object.replace(gcs_prefix, "", 1)
+        if self.apply_gcs_prefix:
+            return self.prefix + s3_object
+        return s3_object
```
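The replace=False path above now filters by mapping existing GCS objects back to the S3 keys they would have come from, then taking a set difference. A self-contained sketch of that round trip, under the same illustrative assumptions as the earlier sketch (helper name and sample paths are hypothetical):

```python
# Hypothetical mirror of exclude_existing_objects(): invert the path mapping
# for each object already in GCS, then drop those keys from the S3 listing.
def exclude_existing(
    s3_objects: list[str], existing_gcs: set[str],
    s3_prefix: str, gcs_prefix: str, apply_gcs_prefix: bool,
) -> list[str]:
    def gcs_to_s3(gcs_object: str) -> str:
        s3_object = gcs_object.replace(gcs_prefix, "", 1)
        return s3_prefix + s3_object if apply_gcs_prefix else s3_object

    already_copied = {gcs_to_s3(obj) for obj in existing_gcs}
    return list(set(s3_objects) - already_copied)


# gs://<bucket>/data/TEST/a.csv maps back to TEST/a.csv, so only b.csv is new:
assert exclude_existing(
    ["TEST/a.csv", "TEST/b.csv"], {"data/TEST/a.csv"},
    s3_prefix="TEST/", gcs_prefix="data/", apply_gcs_prefix=False,
) == ["TEST/b.csv"]
```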

tests/providers/google/cloud/transfers/test_s3_to_gcs.py

Lines changed: 136 additions & 10 deletions
```diff
@@ -19,17 +19,39 @@
 
 from unittest import mock
 
+import pytest
+
 from airflow.providers.google.cloud.transfers.s3_to_gcs import S3ToGCSOperator
 
 TASK_ID = "test-s3-gcs-operator"
 S3_BUCKET = "test-bucket"
 S3_PREFIX = "TEST"
 S3_DELIMITER = "/"
-GCS_PATH_PREFIX = "gs://gcs-bucket/data/"
-MOCK_FILES = ["TEST1.csv", "TEST2.csv", "TEST3.csv"]
+GCS_BUCKET = "gcs-bucket"
+GCS_BUCKET_URI = "gs://" + GCS_BUCKET
+GCS_PREFIX = "data/"
+GCS_PATH_PREFIX = GCS_BUCKET_URI + "/" + GCS_PREFIX
+MOCK_FILE_1 = "TEST1.csv"
+MOCK_FILE_2 = "TEST2.csv"
+MOCK_FILE_3 = "TEST3.csv"
+MOCK_FILES = [MOCK_FILE_1, MOCK_FILE_2, MOCK_FILE_3]
 AWS_CONN_ID = "aws_default"
 GCS_CONN_ID = "google_cloud_default"
 IMPERSONATION_CHAIN = ["ACCOUNT_1", "ACCOUNT_2", "ACCOUNT_3"]
+APPLY_GCS_PREFIX = False
+PARAMETRIZED_OBJECT_PATHS = (
+    "apply_gcs_prefix, s3_prefix, s3_object, gcs_destination, gcs_object",
+    [
+        (False, "", MOCK_FILE_1, GCS_PATH_PREFIX, GCS_PREFIX + MOCK_FILE_1),
+        (False, S3_PREFIX, MOCK_FILE_1, GCS_PATH_PREFIX, GCS_PREFIX + S3_PREFIX + MOCK_FILE_1),
+        (False, "", MOCK_FILE_1, GCS_BUCKET_URI, MOCK_FILE_1),
+        (False, S3_PREFIX, MOCK_FILE_1, GCS_BUCKET_URI, S3_PREFIX + MOCK_FILE_1),
+        (True, "", MOCK_FILE_1, GCS_PATH_PREFIX, GCS_PREFIX + MOCK_FILE_1),
+        (True, S3_PREFIX, MOCK_FILE_1, GCS_PATH_PREFIX, GCS_PREFIX + MOCK_FILE_1),
+        (True, "", MOCK_FILE_1, GCS_BUCKET_URI, MOCK_FILE_1),
+        (True, S3_PREFIX, MOCK_FILE_1, GCS_BUCKET_URI, MOCK_FILE_1),
+    ],
+)
 
 
 class TestS3ToGoogleCloudStorageOperator:
@@ -44,6 +66,7 @@ def test_init(self):
             gcp_conn_id=GCS_CONN_ID,
             dest_gcs=GCS_PATH_PREFIX,
             google_impersonation_chain=IMPERSONATION_CHAIN,
+            apply_gcs_prefix=APPLY_GCS_PREFIX,
         )
 
         assert operator.task_id == TASK_ID
@@ -53,6 +76,7 @@ def test_init(self):
         assert operator.gcp_conn_id == GCS_CONN_ID
         assert operator.dest_gcs == GCS_PATH_PREFIX
         assert operator.google_impersonation_chain == IMPERSONATION_CHAIN
+        assert operator.apply_gcs_prefix == APPLY_GCS_PREFIX
 
     @mock.patch("airflow.providers.google.cloud.transfers.s3_to_gcs.S3Hook")
     @mock.patch("airflow.providers.amazon.aws.operators.s3.S3Hook")
@@ -73,12 +97,12 @@ def test_execute(self, gcs_mock_hook, s3_one_mock_hook, s3_two_mock_hook):
         s3_one_mock_hook.return_value.list_keys.return_value = MOCK_FILES
         s3_two_mock_hook.return_value.list_keys.return_value = MOCK_FILES
 
-        uploaded_files = operator.execute(None)
+        uploaded_files = operator.execute(context={})
         gcs_mock_hook.return_value.upload.assert_has_calls(
             [
-                mock.call("gcs-bucket", "data/TEST1.csv", mock.ANY, gzip=False),
-                mock.call("gcs-bucket", "data/TEST3.csv", mock.ANY, gzip=False),
-                mock.call("gcs-bucket", "data/TEST2.csv", mock.ANY, gzip=False),
+                mock.call(GCS_BUCKET, GCS_PREFIX + MOCK_FILE_1, mock.ANY, gzip=False),
+                mock.call(GCS_BUCKET, GCS_PREFIX + MOCK_FILE_2, mock.ANY, gzip=False),
+                mock.call(GCS_BUCKET, GCS_PREFIX + MOCK_FILE_3, mock.ANY, gzip=False),
             ],
             any_order=True,
         )
@@ -112,16 +136,118 @@ def test_execute_with_gzip(self, gcs_mock_hook, s3_one_mock_hook, s3_two_mock_hook):
         s3_one_mock_hook.return_value.list_keys.return_value = MOCK_FILES
         s3_two_mock_hook.return_value.list_keys.return_value = MOCK_FILES
 
-        operator.execute(None)
+        operator.execute(context={})
         gcs_mock_hook.assert_called_once_with(
             gcp_conn_id=GCS_CONN_ID,
             impersonation_chain=None,
         )
         gcs_mock_hook.return_value.upload.assert_has_calls(
             [
-                mock.call("gcs-bucket", "data/TEST2.csv", mock.ANY, gzip=True),
-                mock.call("gcs-bucket", "data/TEST1.csv", mock.ANY, gzip=True),
-                mock.call("gcs-bucket", "data/TEST3.csv", mock.ANY, gzip=True),
+                mock.call(GCS_BUCKET, GCS_PREFIX + MOCK_FILE_1, mock.ANY, gzip=True),
+                mock.call(GCS_BUCKET, GCS_PREFIX + MOCK_FILE_2, mock.ANY, gzip=True),
+                mock.call(GCS_BUCKET, GCS_PREFIX + MOCK_FILE_3, mock.ANY, gzip=True),
             ],
             any_order=True,
         )
+
+    @pytest.mark.parametrize(
+        "source_objects, existing_objects, objects_expected",
+        [
+            (MOCK_FILES, [], MOCK_FILES),
+            (MOCK_FILES, [MOCK_FILE_1], [MOCK_FILE_2, MOCK_FILE_3]),
+            (MOCK_FILES, [MOCK_FILE_1, MOCK_FILE_2], [MOCK_FILE_3]),
+            (MOCK_FILES, [MOCK_FILE_3, MOCK_FILE_2], [MOCK_FILE_1]),
+            (MOCK_FILES, MOCK_FILES, []),
+        ],
+    )
+    @mock.patch("airflow.providers.google.cloud.transfers.s3_to_gcs.GCSHook")
+    def test_exclude_existing_objects(
+        self, mock_gcs_hook, source_objects, existing_objects, objects_expected
+    ):
+        operator = S3ToGCSOperator(
+            task_id=TASK_ID,
+            bucket=S3_BUCKET,
+            prefix=S3_PREFIX,
+            delimiter=S3_DELIMITER,
+            gcp_conn_id=GCS_CONN_ID,
+            dest_gcs=GCS_PATH_PREFIX,
+            gzip=True,
+        )
+        mock_gcs_hook.list.return_value = existing_objects
+        files_reduced = operator.exclude_existing_objects(s3_objects=source_objects, gcs_hook=mock_gcs_hook)
+        assert set(files_reduced) == set(objects_expected)
+
+    @pytest.mark.parametrize(*PARAMETRIZED_OBJECT_PATHS)
+    def test_s3_to_gcs_object(self, apply_gcs_prefix, s3_prefix, s3_object, gcs_destination, gcs_object):
+        operator = S3ToGCSOperator(
+            task_id=TASK_ID,
+            bucket=S3_BUCKET,
+            prefix=s3_prefix,
+            delimiter=S3_DELIMITER,
+            gcp_conn_id=GCS_CONN_ID,
+            dest_gcs=gcs_destination,
+            gzip=True,
+            apply_gcs_prefix=apply_gcs_prefix,
+        )
+        assert operator.s3_to_gcs_object(s3_object=s3_prefix + s3_object) == gcs_object
+
+    @pytest.mark.parametrize(*PARAMETRIZED_OBJECT_PATHS)
+    def test_gcs_to_s3_object(self, apply_gcs_prefix, s3_prefix, s3_object, gcs_destination, gcs_object):
+        operator = S3ToGCSOperator(
+            task_id=TASK_ID,
+            bucket=S3_BUCKET,
+            prefix=s3_prefix,
+            delimiter=S3_DELIMITER,
+            gcp_conn_id=GCS_CONN_ID,
+            dest_gcs=gcs_destination,
+            gzip=True,
+            apply_gcs_prefix=apply_gcs_prefix,
+        )
+        assert operator.gcs_to_s3_object(gcs_object=gcs_object) == s3_prefix + s3_object
+
+    @pytest.mark.parametrize(*PARAMETRIZED_OBJECT_PATHS)
+    @mock.patch("airflow.providers.google.cloud.transfers.s3_to_gcs.S3Hook")
+    @mock.patch("airflow.providers.amazon.aws.operators.s3.S3Hook")
+    @mock.patch("airflow.providers.google.cloud.transfers.s3_to_gcs.GCSHook")
+    def test_execute_apply_gcs_prefix(
+        self,
+        gcs_mock_hook,
+        s3_one_mock_hook,
+        s3_two_mock_hook,
+        apply_gcs_prefix,
+        s3_prefix,
+        s3_object,
+        gcs_destination,
+        gcs_object,
+    ):
+        operator = S3ToGCSOperator(
+            task_id=TASK_ID,
+            bucket=S3_BUCKET,
+            prefix=s3_prefix,
+            delimiter=S3_DELIMITER,
+            gcp_conn_id=GCS_CONN_ID,
+            dest_gcs=gcs_destination,
+            google_impersonation_chain=IMPERSONATION_CHAIN,
+            apply_gcs_prefix=apply_gcs_prefix,
+        )
+
+        s3_one_mock_hook.return_value.list_keys.return_value = [s3_prefix + s3_object]
+        s3_two_mock_hook.return_value.list_keys.return_value = [s3_prefix + s3_object]
+
+        uploaded_files = operator.execute(context={})
+        gcs_mock_hook.return_value.upload.assert_has_calls(
+            [
+                mock.call(GCS_BUCKET, gcs_object, mock.ANY, gzip=False),
+            ],
+            any_order=True,
+        )
+
+        s3_one_mock_hook.assert_called_once_with(aws_conn_id=AWS_CONN_ID, verify=None)
+        s3_two_mock_hook.assert_called_once_with(aws_conn_id=AWS_CONN_ID, verify=None)
+        gcs_mock_hook.assert_called_once_with(
+            gcp_conn_id=GCS_CONN_ID,
+            impersonation_chain=IMPERSONATION_CHAIN,
+        )
+
+        assert sorted([s3_prefix + s3_object]) == sorted(uploaded_files)
```
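The new tests share one parameter table, PARAMETRIZED_OBJECT_PATHS, by storing the (argnames, argvalues) pair as a tuple and star-unpacking it into pytest.mark.parametrize. A minimal sketch of the same idiom, with illustrative names:

```python
import pytest

# One (argnames, argvalues) tuple reused across several tests.
CASES = (
    "left, right, expected",
    [(1, 2, 3), (2, 3, 5)],
)


@pytest.mark.parametrize(*CASES)
def test_add(left, right, expected):
    assert left + right == expected


@pytest.mark.parametrize(*CASES)
def test_add_commutes(left, right, expected):
    assert right + left == expected
```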

tests/system/providers/google/cloud/gcs/example_s3_to_gcs.py

Lines changed: 5 additions & 1 deletion
```diff
@@ -62,7 +62,11 @@ def upload_file():
     )
     # [START howto_transfer_s3togcs_operator]
     transfer_to_gcs = S3ToGCSOperator(
-        task_id="s3_to_gcs_task", bucket=BUCKET_NAME, prefix=PREFIX, dest_gcs=GCS_BUCKET_URL
+        task_id="s3_to_gcs_task",
+        bucket=BUCKET_NAME,
+        prefix=PREFIX,
+        dest_gcs=GCS_BUCKET_URL,
+        apply_gcs_prefix=True,
     )
     # [END howto_transfer_s3togcs_operator]
```
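For comparison, a hedged sketch of both modes side by side; the bucket names, prefixes, and task ids below are hypothetical, not taken from the example DAG:

```python
from airflow.providers.google.cloud.transfers.s3_to_gcs import S3ToGCSOperator

# s3://my-s3-bucket/data-prefix/file.csv -> gs://my-gcs-bucket/backup/data-prefix/file.csv
keep_source_layout = S3ToGCSOperator(
    task_id="s3_to_gcs_keep_layout",
    bucket="my-s3-bucket",
    prefix="data-prefix/",
    dest_gcs="gs://my-gcs-bucket/backup/",
    apply_gcs_prefix=False,  # default: S3 prefix kept under the GCS prefix
)

# s3://my-s3-bucket/data-prefix/file.csv -> gs://my-gcs-bucket/backup/file.csv
flatten_to_gcs_prefix = S3ToGCSOperator(
    task_id="s3_to_gcs_flatten",
    bucket="my-s3-bucket",
    prefix="data-prefix/",
    dest_gcs="gs://my-gcs-bucket/backup/",
    apply_gcs_prefix=True,  # S3 prefix replaced by the GCS prefix
)
```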
