Skip to content

Commit 80c1ce7

Browse files
Wojciech JanuszekWojciech Januszek
andauthored
Cloud Storage assets & StorageLink update (#23865)
Co-authored-by: Wojciech Januszek <januszek@google.com>
1 parent 048b617 commit 80c1ce7

File tree

5 files changed

+69
-9
lines changed

5 files changed

+69
-9
lines changed

β€Žairflow/providers/google/cloud/operators/dataproc_metastore.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -711,7 +711,7 @@ def execute(self, context: "Context"):
711711

712712
DataprocMetastoreLink.persist(context=context, task_instance=self, url=METASTORE_EXPORT_LINK)
713713
uri = self._get_uri_from_destination(MetadataExport.to_dict(metadata_export)["destination_gcs_uri"])
714-
StorageLink.persist(context=context, task_instance=self, uri=uri)
714+
StorageLink.persist(context=context, task_instance=self, uri=uri, project_id=self.project_id)
715715
return MetadataExport.to_dict(metadata_export)
716716

717717
def _get_uri_from_destination(self, destination_uri: str):

β€Žairflow/providers/google/cloud/operators/datastore.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,7 @@ def execute(self, context: 'Context') -> dict:
140140
context=context,
141141
task_instance=self,
142142
uri=f"{self.bucket}/{result['response']['outputUrl'].split('/')[3]}",
143+
project_id=self.project_id or ds_hook.project_id,
143144
)
144145
return result
145146

β€Žairflow/providers/google/cloud/operators/gcs.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
from airflow.exceptions import AirflowException
3636
from airflow.models import BaseOperator
3737
from airflow.providers.google.cloud.hooks.gcs import GCSHook
38+
from airflow.providers.google.common.links.storage import FileDetailsLink, StorageLink
3839
from airflow.utils import timezone
3940

4041

@@ -107,6 +108,7 @@ class GCSCreateBucketOperator(BaseOperator):
107108
'impersonation_chain',
108109
)
109110
ui_color = '#f0eee4'
111+
operator_extra_links = (StorageLink(),)
110112

111113
def __init__(
112114
self,
@@ -139,6 +141,12 @@ def execute(self, context: "Context") -> None:
139141
delegate_to=self.delegate_to,
140142
impersonation_chain=self.impersonation_chain,
141143
)
144+
StorageLink.persist(
145+
context=context,
146+
task_instance=self,
147+
uri=self.bucket_name,
148+
project_id=self.project_id or hook.project_id,
149+
)
142150
try:
143151
hook.create_bucket(
144152
bucket_name=self.bucket_name,
@@ -200,6 +208,8 @@ class GCSListObjectsOperator(BaseOperator):
200208

201209
ui_color = '#f0eee4'
202210

211+
operator_extra_links = (StorageLink(),)
212+
203213
def __init__(
204214
self,
205215
*,
@@ -234,6 +244,13 @@ def execute(self, context: "Context") -> list:
234244
self.prefix,
235245
)
236246

247+
StorageLink.persist(
248+
context=context,
249+
task_instance=self,
250+
uri=self.bucket,
251+
project_id=hook.project_id,
252+
)
253+
237254
return hook.list(bucket_name=self.bucket, prefix=self.prefix, delimiter=self.delimiter)
238255

239256

@@ -346,6 +363,7 @@ class GCSBucketCreateAclEntryOperator(BaseOperator):
346363
'impersonation_chain',
347364
)
348365
# [END gcs_bucket_create_acl_template_fields]
366+
operator_extra_links = (StorageLink(),)
349367

350368
def __init__(
351369
self,
@@ -371,6 +389,12 @@ def execute(self, context: "Context") -> None:
371389
gcp_conn_id=self.gcp_conn_id,
372390
impersonation_chain=self.impersonation_chain,
373391
)
392+
StorageLink.persist(
393+
context=context,
394+
task_instance=self,
395+
uri=self.bucket,
396+
project_id=hook.project_id,
397+
)
374398
hook.insert_bucket_acl(
375399
bucket_name=self.bucket, entity=self.entity, role=self.role, user_project=self.user_project
376400
)
@@ -418,6 +442,7 @@ class GCSObjectCreateAclEntryOperator(BaseOperator):
418442
'impersonation_chain',
419443
)
420444
# [END gcs_object_create_acl_template_fields]
445+
operator_extra_links = (FileDetailsLink(),)
421446

422447
def __init__(
423448
self,
@@ -447,6 +472,12 @@ def execute(self, context: "Context") -> None:
447472
gcp_conn_id=self.gcp_conn_id,
448473
impersonation_chain=self.impersonation_chain,
449474
)
475+
FileDetailsLink.persist(
476+
context=context,
477+
task_instance=self,
478+
uri=f"{self.bucket}/{self.object_name}",
479+
project_id=hook.project_id,
480+
)
450481
hook.insert_object_acl(
451482
bucket_name=self.bucket,
452483
object_name=self.object_name,
@@ -498,6 +529,7 @@ class GCSFileTransformOperator(BaseOperator):
498529
'transform_script',
499530
'impersonation_chain',
500531
)
532+
operator_extra_links = (FileDetailsLink(),)
501533

502534
def __init__(
503535
self,
@@ -549,6 +581,12 @@ def execute(self, context: "Context") -> None:
549581
self.log.info("Transformation succeeded. Output temporarily located at %s", destination_file.name)
550582

551583
self.log.info("Uploading file to %s as %s", self.destination_bucket, self.destination_object)
584+
FileDetailsLink.persist(
585+
context=context,
586+
task_instance=self,
587+
uri=f"{self.destination_bucket}/{self.destination_object}",
588+
project_id=hook.project_id,
589+
)
552590
hook.upload(
553591
bucket_name=self.destination_bucket,
554592
object_name=self.destination_object,
@@ -628,6 +666,7 @@ class GCSTimeSpanFileTransformOperator(BaseOperator):
628666
'source_impersonation_chain',
629667
'destination_impersonation_chain',
630668
)
669+
operator_extra_links = (StorageLink(),)
631670

632671
@staticmethod
633672
def interpolate_prefix(prefix: str, dt: datetime.datetime) -> Optional[str]:
@@ -718,6 +757,12 @@ def execute(self, context: "Context") -> List[str]:
718757
gcp_conn_id=self.destination_gcp_conn_id,
719758
impersonation_chain=self.destination_impersonation_chain,
720759
)
760+
StorageLink.persist(
761+
context=context,
762+
task_instance=self,
763+
uri=self.destination_bucket,
764+
project_id=destination_hook.project_id,
765+
)
721766

722767
# Fetch list of files.
723768
blobs_to_transform = source_hook.list_by_timespan(
@@ -904,6 +949,7 @@ class GCSSynchronizeBucketsOperator(BaseOperator):
904949
'delegate_to',
905950
'impersonation_chain',
906951
)
952+
operator_extra_links = (StorageLink(),)
907953

908954
def __init__(
909955
self,
@@ -938,6 +984,12 @@ def execute(self, context: "Context") -> None:
938984
delegate_to=self.delegate_to,
939985
impersonation_chain=self.impersonation_chain,
940986
)
987+
StorageLink.persist(
988+
context=context,
989+
task_instance=self,
990+
uri=self._get_uri(self.destination_bucket, self.destination_object),
991+
project_id=hook.project_id,
992+
)
941993
hook.sync(
942994
source_bucket=self.source_bucket,
943995
destination_bucket=self.destination_bucket,
@@ -947,3 +999,8 @@ def execute(self, context: "Context") -> None:
947999
delete_extra_files=self.delete_extra_files,
9481000
allow_overwrite=self.allow_overwrite,
9491001
)
1002+
1003+
def _get_uri(self, gcs_bucket: str, gcs_object: Optional[str]) -> str:
1004+
if gcs_object and gcs_object[-1] == "/":
1005+
gcs_object = gcs_object[:-1]
1006+
return f"{gcs_bucket}/{gcs_object}" if gcs_object else gcs_bucket

β€Žairflow/providers/google/common/links/storage.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,11 +36,11 @@ class StorageLink(BaseGoogleLink):
3636
format_str = GCS_STORAGE_LINK
3737

3838
@staticmethod
39-
def persist(context: "Context", task_instance, uri: str):
39+
def persist(context: "Context", task_instance, uri: str, project_id: Optional[str]):
4040
task_instance.xcom_push(
4141
context=context,
4242
key=StorageLink.key,
43-
value={"uri": uri, "project_id": task_instance.project_id},
43+
value={"uri": uri, "project_id": project_id},
4444
)
4545

4646

β€Žtests/providers/google/cloud/operators/test_gcs.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ def test_execute(self, mock_hook):
5757
project_id=TEST_PROJECT,
5858
)
5959

60-
operator.execute(None)
60+
operator.execute(context=mock.MagicMock())
6161
mock_hook.return_value.create_bucket.assert_called_once_with(
6262
bucket_name=TEST_BUCKET,
6363
storage_class="MULTI_REGIONAL",
@@ -78,7 +78,7 @@ def test_bucket_create_acl(self, mock_hook):
7878
user_project="test-user-project",
7979
task_id="id",
8080
)
81-
operator.execute(None)
81+
operator.execute(context=mock.MagicMock())
8282
mock_hook.return_value.insert_bucket_acl.assert_called_once_with(
8383
bucket_name="test-bucket",
8484
entity="test-entity",
@@ -97,7 +97,7 @@ def test_object_create_acl(self, mock_hook):
9797
user_project="test-user-project",
9898
task_id="id",
9999
)
100-
operator.execute(None)
100+
operator.execute(context=mock.MagicMock())
101101
mock_hook.return_value.insert_object_acl.assert_called_once_with(
102102
bucket_name="test-bucket",
103103
object_name="test-object",
@@ -148,7 +148,7 @@ def test_execute(self, mock_hook):
148148
task_id=TASK_ID, bucket=TEST_BUCKET, prefix=PREFIX, delimiter=DELIMITER
149149
)
150150

151-
files = operator.execute(None)
151+
files = operator.execute(context=mock.MagicMock())
152152
mock_hook.return_value.list.assert_called_once_with(
153153
bucket_name=TEST_BUCKET, prefix=PREFIX, delimiter=DELIMITER
154154
)
@@ -197,7 +197,7 @@ def test_execute(self, mock_hook, mock_subprocess, mock_tempfile):
197197
destination_bucket=destination_bucket,
198198
transform_script=transform_script,
199199
)
200-
op.execute(None)
200+
op.execute(context=mock.MagicMock())
201201

202202
mock_hook.return_value.download.assert_called_once_with(
203203
bucket_name=source_bucket, object_name=source_object, filename=source
@@ -273,9 +273,11 @@ def test_execute(self, mock_hook, mock_subprocess, mock_tempdir):
273273
timespan_end = timespan_start + timedelta(hours=1)
274274
mock_dag = mock.Mock()
275275
mock_dag.following_schedule = lambda x: x + timedelta(hours=1)
276+
mock_ti = mock.Mock()
276277
context = dict(
277278
execution_date=timespan_start,
278279
dag=mock_dag,
280+
ti=mock_ti,
279281
)
280282

281283
mock_tempdir.return_value.__enter__.side_effect = [source, destination]
@@ -397,7 +399,7 @@ def test_execute(self, mock_hook):
397399
delegate_to="DELEGATE_TO",
398400
impersonation_chain=IMPERSONATION_CHAIN,
399401
)
400-
task.execute({})
402+
task.execute(context=mock.MagicMock())
401403
mock_hook.assert_called_once_with(
402404
gcp_conn_id='GCP_CONN_ID',
403405
delegate_to='DELEGATE_TO',

0 commit comments

Comments
 (0)