
Commit c8de9a5

Authored by Brenda Pham and shaniyaclement
Adding Support for Google Cloud's Data Pipelines Run Operator (#32846)
Co-authored-by: shaniyaclement <shaniya.clement17@gmail.com>
Co-authored-by: Brenda Pham <bloop@google.com>
Co-authored-by: Shaniya Clement <93938197+shaniyaclement@users.noreply.github.com>
1 parent 46fa5a2 commit c8de9a5

6 files changed: +234 −1 lines changed


airflow/providers/google/cloud/hooks/datapipeline.py

Lines changed: 31 additions & 0 deletions
@@ -85,6 +85,37 @@ def create_data_pipeline(
         response = request.execute(num_retries=self.num_retries)
         return response
 
+    @GoogleBaseHook.fallback_to_default_project_id
+    def run_data_pipeline(
+        self,
+        data_pipeline_name: str,
+        project_id: str,
+        location: str = DEFAULT_DATAPIPELINE_LOCATION,
+    ) -> None:
+        """
+        Runs a Data Pipelines Instance using the Data Pipelines API.
+
+        :param data_pipeline_name: The display name of the pipeline. In example
+            projects/PROJECT_ID/locations/LOCATION_ID/pipelines/PIPELINE_ID it would be the PIPELINE_ID.
+        :param project_id: The ID of the GCP project that owns the job.
+        :param location: The location to direct the Data Pipelines instance to (for example us-central1).
+
+        Returns the created Job in JSON representation.
+        """
+        parent = self.build_parent_name(project_id, location)
+        service = self.get_conn()
+        request = (
+            service.projects()
+            .locations()
+            .pipelines()
+            .run(
+                name=f"{parent}/pipelines/{data_pipeline_name}",
+                body={},
+            )
+        )
+        response = request.execute(num_retries=self.num_retries)
+        return response
+
     @staticmethod
     def build_parent_name(project_id: str, location: str):
         return f"projects/{project_id}/locations/{location}"
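For orientation, here is a minimal sketch of calling the new hook method directly; it is not part of this commit, and the connection ID, project, location, and pipeline name are placeholders.

from airflow.providers.google.cloud.hooks.datapipeline import DataPipelineHook

# Placeholder values, for illustration only.
hook = DataPipelineHook(gcp_conn_id="google_cloud_default")
response = hook.run_data_pipeline(
    data_pipeline_name="my-pipeline",   # the PIPELINE_ID part of the full resource name
    project_id="my-project",
    location="us-central1",
)
# `response` is the raw JSON response of the run call, i.e. the Job created by the run.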

airflow/providers/google/cloud/operators/datapipeline.py

Lines changed: 52 additions & 0 deletions
@@ -100,3 +100,55 @@ def execute(self, context: Context):
             raise AirflowException(self.data_pipeline.get("error").get("message"))
 
         return self.data_pipeline
+
+
+class RunDataPipelineOperator(GoogleCloudBaseOperator):
+    """
+    Runs a Data Pipelines Instance using the Data Pipelines API.
+
+    :param data_pipeline_name: The display name of the pipeline. In example
+        projects/PROJECT_ID/locations/LOCATION_ID/pipelines/PIPELINE_ID it would be the PIPELINE_ID.
+    :param project_id: The ID of the GCP project that owns the job.
+    :param location: The location to direct the Data Pipelines instance to (for example us-central1).
+    :param gcp_conn_id: The connection ID to connect to the Google Cloud
+        Platform.
+
+    Returns the created Job in JSON representation.
+    """
+
+    def __init__(
+        self,
+        data_pipeline_name: str,
+        project_id: str | None = None,
+        location: str = DEFAULT_DATAPIPELINE_LOCATION,
+        gcp_conn_id: str = "google_cloud_default",
+        **kwargs,
+    ) -> None:
+        super().__init__(**kwargs)
+
+        self.data_pipeline_name = data_pipeline_name
+        self.project_id = project_id
+        self.location = location
+        self.gcp_conn_id = gcp_conn_id
+
+    def execute(self, context: Context):
+        self.data_pipeline_hook = DataPipelineHook(gcp_conn_id=self.gcp_conn_id)
+
+        if self.data_pipeline_name is None:
+            raise AirflowException("Data Pipeline name not given; cannot run unspecified pipeline.")
+        if self.project_id is None:
+            raise AirflowException("Data Pipeline Project ID not given; cannot run pipeline.")
+        if self.location is None:
+            raise AirflowException("Data Pipeline location not given; cannot run pipeline.")
+
+        self.response = self.data_pipeline_hook.run_data_pipeline(
+            data_pipeline_name=self.data_pipeline_name,
+            project_id=self.project_id,
+            location=self.location,
+        )
+
+        if self.response:
+            if "error" in self.response:
+                raise AirflowException(self.response.get("error").get("message"))
+
+        return self.response
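As a rough usage sketch (not part of this commit; the DAG id, schedule, and resource names are placeholders), the operator is instantiated inside a DAG like any other Google Cloud operator:

from datetime import datetime

from airflow import models
from airflow.providers.google.cloud.operators.datapipeline import RunDataPipelineOperator

with models.DAG(
    "example_run_data_pipeline",           # placeholder DAG id
    start_date=datetime(2023, 1, 1),
    schedule=None,
    catchup=False,
) as dag:
    run_pipeline = RunDataPipelineOperator(
        task_id="run_data_pipeline",
        data_pipeline_name="my-pipeline",  # placeholder pipeline name
        project_id="my-project",           # placeholder project ID
        # location and gcp_conn_id fall back to their default values
    )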

docs/apache-airflow-providers-google/operators/cloud/datapipeline.rst

Lines changed: 29 additions & 0 deletions
@@ -55,6 +55,35 @@ Here is an example of how you can create a Data Pipelines instance by running the
     :start-after: [START howto_operator_create_data_pipeline]
     :end-before: [END howto_operator_create_data_pipeline]
 
+Running a Data Pipeline
+^^^^^^^^^^^^^^^^^^^^^^^
+
+To run a Data Pipelines instance, use :class:`~airflow.providers.google.cloud.operators.datapipeline.RunDataPipelineOperator`.
+The operator accesses Google Cloud's Data Pipelines API and calls upon the
+`run method <https://cloud.google.com/dataflow/docs/reference/data-pipelines/rest/v1/projects.locations.pipelines/run>`__
+to run the given pipeline.
+
+:class:`~airflow.providers.google.cloud.operators.datapipeline.RunDataPipelineOperator` can take in four parameters:
+
+- ``data_pipeline_name``: the name of the Data Pipelines instance
+- ``project_id``: the ID of the GCP project that owns the job
+- ``location``: the location of the Data Pipelines instance
+- ``gcp_conn_id``: the connection ID to connect to the Google Cloud Platform
+
+Only the Data Pipeline name and Project ID are required parameters, as the Location and GCP Connection ID have default values.
+The Project ID and Location will be used to build the parent name, which is where the given Data Pipeline should be located.
+
+You can run a Data Pipelines instance by running the above parameters with RunDataPipelineOperator:
+
+.. exampleinclude:: /../../tests/system/providers/google/cloud/datapipelines/example_datapipeline.py
+    :language: python
+    :dedent: 4
+    :start-after: [START howto_operator_run_data_pipeline]
+    :end-before: [END howto_operator_run_data_pipeline]
+
+Once called, the RunDataPipelineOperator will return the Google Cloud `Dataflow Job <https://cloud.google.com/dataflow/docs/reference/data-pipelines/rest/v1/Job>`__
+created by running the given pipeline.
+
 For further information regarding the API usage, see
 `Data Pipelines API REST Resource <https://cloud.google.com/dataflow/docs/reference/data-pipelines/rest/v1/projects.locations.pipelines#Pipeline>`__
 in the Google Cloud documentation.
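To make the parent-name remark above concrete, here is how the documented parameters map onto the call made by the hook added in this commit; the parameter values are placeholders.

# Placeholder parameter values.
project_id = "my-project"
location = "us-central1"
data_pipeline_name = "my-pipeline"

parent = f"projects/{project_id}/locations/{location}"        # built by DataPipelineHook.build_parent_name()
resource_name = f"{parent}/pipelines/{data_pipeline_name}"    # passed as `name` to pipelines().run()
# The run request itself carries an empty body: body={}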

tests/providers/google/cloud/hooks/test_datapipeline.py

Lines changed: 23 additions & 0 deletions
@@ -108,3 +108,26 @@ def test_create_data_pipeline(self, mock_connection):
             body=TEST_BODY,
         )
         assert result == {"name": TEST_PARENT}
+
+    @mock.patch("airflow.providers.google.cloud.hooks.datapipeline.DataPipelineHook.get_conn")
+    def test_run_data_pipeline(self, mock_connection):
+        """
+        Test that run_data_pipeline is called with correct parameters and
+        calls Google Data Pipelines API
+        """
+        mock_request = (
+            mock_connection.return_value.projects.return_value.locations.return_value.pipelines.return_value.run
+        )
+        mock_request.return_value.execute.return_value = {"job": {"id": TEST_JOB_ID}}
+
+        result = self.datapipeline_hook.run_data_pipeline(
+            data_pipeline_name=TEST_DATA_PIPELINE_NAME,
+            project_id=TEST_PROJECTID,
+            location=TEST_LOCATION,
+        )
+
+        mock_request.assert_called_once_with(
+            name=TEST_NAME,
+            body={},
+        )
+        assert result == {"job": {"id": TEST_JOB_ID}}

tests/providers/google/cloud/operators/test_datapipeline.py

Lines changed: 90 additions & 0 deletions
@@ -24,6 +24,7 @@
 from airflow.exceptions import AirflowException
 from airflow.providers.google.cloud.operators.datapipeline import (
     CreateDataPipelineOperator,
+    RunDataPipelineOperator,
 )
 
 TASK_ID = "test-datapipeline-operators"
@@ -136,3 +137,92 @@ def test_response_invalid(self):
         }
         with pytest.raises(AirflowException):
             CreateDataPipelineOperator(**init_kwargs).execute(mock.MagicMock())
+
+
+class TestRunDataPipelineOperator:
+    @pytest.fixture
+    def run_operator(self):
+        """
+        Create a RunDataPipelineOperator instance with test data
+        """
+        return RunDataPipelineOperator(
+            task_id=TASK_ID,
+            data_pipeline_name=TEST_DATA_PIPELINE_NAME,
+            project_id=TEST_PROJECTID,
+            location=TEST_LOCATION,
+            gcp_conn_id=TEST_GCP_CONN_ID,
+        )
+
+    @mock.patch("airflow.providers.google.cloud.operators.datapipeline.DataPipelineHook")
+    def test_execute(self, data_pipeline_hook_mock, run_operator):
+        """
+        Test Run Operator execute with correct parameters
+        """
+        run_operator.execute(mock.MagicMock())
+        data_pipeline_hook_mock.assert_called_once_with(
+            gcp_conn_id=TEST_GCP_CONN_ID,
+        )
+
+        data_pipeline_hook_mock.return_value.run_data_pipeline.assert_called_once_with(
+            data_pipeline_name=TEST_DATA_PIPELINE_NAME,
+            project_id=TEST_PROJECTID,
+            location=TEST_LOCATION,
+        )
+
+    def test_invalid_data_pipeline_name(self):
+        """
+        Test that AirflowException is raised if Run Operator is not given a data pipeline name.
+        """
+        init_kwargs = {
+            "task_id": TASK_ID,
+            "data_pipeline_name": None,
+            "project_id": TEST_PROJECTID,
+            "location": TEST_LOCATION,
+            "gcp_conn_id": TEST_GCP_CONN_ID,
+        }
+        with pytest.raises(AirflowException):
+            RunDataPipelineOperator(**init_kwargs).execute(mock.MagicMock())
+
+    def test_invalid_project_id(self):
+        """
+        Test that AirflowException is raised if Run Operator is not given a project ID.
+        """
+        init_kwargs = {
+            "task_id": TASK_ID,
+            "data_pipeline_name": TEST_DATA_PIPELINE_NAME,
+            "project_id": None,
+            "location": TEST_LOCATION,
+            "gcp_conn_id": TEST_GCP_CONN_ID,
+        }
+        with pytest.raises(AirflowException):
+            RunDataPipelineOperator(**init_kwargs).execute(mock.MagicMock())
+
+    def test_invalid_location(self):
+        """
+        Test that AirflowException is raised if Run Operator is not given a location.
+        """
+        init_kwargs = {
+            "task_id": TASK_ID,
+            "data_pipeline_name": TEST_DATA_PIPELINE_NAME,
+            "project_id": TEST_PROJECTID,
+            "location": None,
+            "gcp_conn_id": TEST_GCP_CONN_ID,
+        }
+        with pytest.raises(AirflowException):
+            RunDataPipelineOperator(**init_kwargs).execute(mock.MagicMock())
+
+    def test_invalid_response(self):
+        """
+        Test that AirflowException is raised if Run Operator fails execution and returns error.
+        """
+        init_kwargs = {
+            "task_id": TASK_ID,
+            "data_pipeline_name": TEST_DATA_PIPELINE_NAME,
+            "project_id": TEST_PROJECTID,
+            "location": TEST_LOCATION,
+            "gcp_conn_id": TEST_GCP_CONN_ID,
+        }
+        with pytest.raises(AirflowException):
+            RunDataPipelineOperator(**init_kwargs).execute(mock.MagicMock()).return_value = {
+                "error": {"message": "example error"}
+            }

tests/system/providers/google/cloud/datapipelines/example_datapipeline.py

Lines changed: 9 additions & 1 deletion
@@ -28,6 +28,7 @@
 from airflow import models
 from airflow.providers.google.cloud.operators.datapipeline import (
     CreateDataPipelineOperator,
+    RunDataPipelineOperator,
 )
 from airflow.providers.google.cloud.operators.gcs import GCSCreateBucketOperator, GCSDeleteBucketOperator
 from airflow.providers.google.cloud.transfers.local_to_gcs import LocalFilesystemToGCSOperator
@@ -38,7 +39,7 @@
 GCP_PROJECT_ID = os.environ.get("SYSTEM_TESTS_GCP_PROJECT")
 GCP_LOCATION = os.environ.get("location", "us-central1")
 
-PIPELINE_NAME = "defualt-pipeline-name"
+PIPELINE_NAME = os.environ.get("DATA_PIPELINE_NAME", "defualt-pipeline-name")
 PIPELINE_TYPE = "PIPELINE_TYPE_BATCH"
 
 BUCKET_NAME = f"bucket_{DAG_ID}_{ENV_ID}"
@@ -117,6 +118,13 @@
     # when "teardown" task with trigger rule is part of the DAG
     list(dag.tasks) >> watcher()
 
+    # [START howto_operator_run_data_pipeline]
+    run_data_pipeline = RunDataPipelineOperator(
+        task_id="run_data_pipeline",
+        data_pipeline_name=PIPELINE_NAME,
+        project_id=GCP_PROJECT_ID,
+    )
+    # [END howto_operator_run_data_pipeline]
 
 from tests.system.utils import get_test_run  # noqa: E402
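The snippet above does not wire the new task into the rest of the example DAG; a plausible ordering (an assumption, not part of this commit) would trigger the run only after the existing create_data_pipeline task has created the pipeline:

# Assumed dependency inside the same DAG: create the pipeline, then run it.
create_data_pipeline >> run_data_pipeline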
