Skip to content

Commit 48abf57

Browse files
mai-nakagawapotiuk
authored andcommitted
Add endpoint_id arg to google.cloud.operators.vertex_ai.CreateEndpointOperator
1 parent b45240a commit 48abf57

File tree

4 files changed

+13
-0
lines changed

4 files changed

+13
-0
lines changed

β€Žairflow/providers/google/cloud/hooks/vertex_ai/endpoint_service.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,7 @@ def create_endpoint(
8181
project_id: str,
8282
region: str,
8383
endpoint: Union[Endpoint, Dict],
84+
endpoint_id: Optional[str] = None,
8485
retry: Union[Retry, _MethodDefault] = DEFAULT,
8586
timeout: Optional[float] = None,
8687
metadata: Sequence[Tuple[str, str]] = (),
@@ -91,6 +92,7 @@ def create_endpoint(
9192
:param project_id: Required. The ID of the Google Cloud project that the service belongs to.
9293
:param region: Required. The ID of the Google Cloud region that the service belongs to.
9394
:param endpoint: Required. The Endpoint to create.
95+
:param endpoint_id: The ID of Endpoint. If not provided, Vertex AI will generate a value for this ID.
9496
:param retry: Designation of what errors, if any, should be retried.
9597
:param timeout: The timeout for this request.
9698
:param metadata: Strings which should be sent along with the request as metadata.
@@ -102,6 +104,7 @@ def create_endpoint(
102104
request={
103105
'parent': parent,
104106
'endpoint': endpoint,
107+
'endpoint_id': endpoint_id,
105108
},
106109
retry=retry,
107110
timeout=timeout,

β€Žairflow/providers/google/cloud/operators/vertex_ai/endpoint_service.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,7 @@ def __init__(
8181
region: str,
8282
project_id: str,
8383
endpoint: Union[Endpoint, Dict],
84+
endpoint_id: Optional[str] = None,
8485
retry: Union[Retry, _MethodDefault] = DEFAULT,
8586
timeout: Optional[float] = None,
8687
metadata: Sequence[Tuple[str, str]] = (),
@@ -93,6 +94,7 @@ def __init__(
9394
self.region = region
9495
self.project_id = project_id
9596
self.endpoint = endpoint
97+
self.endpoint_id = endpoint_id
9698
self.retry = retry
9799
self.timeout = timeout
98100
self.metadata = metadata
@@ -112,6 +114,7 @@ def execute(self, context: 'Context'):
112114
project_id=self.project_id,
113115
region=self.region,
114116
endpoint=self.endpoint,
117+
endpoint_id=self.endpoint_id,
115118
retry=self.retry,
116119
timeout=self.timeout,
117120
metadata=self.metadata,

β€Žtests/providers/google/cloud/hooks/vertex_ai/test_endpoint_service.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
TEST_REGION: str = "test-region"
3232
TEST_PROJECT_ID: str = "test-project-id"
3333
TEST_ENDPOINT: dict = {}
34+
TEST_ENDPOINT_ID: str = "test_endpoint_id"
3435
TEST_ENDPOINT_NAME: str = "test_endpoint_name"
3536
TEST_DEPLOYED_MODEL: dict = {}
3637
TEST_DEPLOYED_MODEL_ID: str = "test-deployed-model-id"
@@ -54,12 +55,14 @@ def test_create_endpoint(self, mock_client) -> None:
5455
project_id=TEST_PROJECT_ID,
5556
region=TEST_REGION,
5657
endpoint=TEST_ENDPOINT,
58+
endpoint_id=TEST_ENDPOINT_ID,
5759
)
5860
mock_client.assert_called_once_with(TEST_REGION)
5961
mock_client.return_value.create_endpoint.assert_called_once_with(
6062
request=dict(
6163
parent=mock_client.return_value.common_location_path.return_value,
6264
endpoint=TEST_ENDPOINT,
65+
endpoint_id=TEST_ENDPOINT_ID,
6366
),
6467
metadata=(),
6568
retry=DEFAULT,
@@ -223,12 +226,14 @@ def test_create_endpoint(self, mock_client) -> None:
223226
project_id=TEST_PROJECT_ID,
224227
region=TEST_REGION,
225228
endpoint=TEST_ENDPOINT,
229+
endpoint_id=TEST_ENDPOINT_ID,
226230
)
227231
mock_client.assert_called_once_with(TEST_REGION)
228232
mock_client.return_value.create_endpoint.assert_called_once_with(
229233
request=dict(
230234
parent=mock_client.return_value.common_location_path.return_value,
231235
endpoint=TEST_ENDPOINT,
236+
endpoint_id=TEST_ENDPOINT_ID,
232237
),
233238
metadata=(),
234239
retry=DEFAULT,

β€Žtests/providers/google/cloud/operators/test_vertex_ai.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1137,6 +1137,7 @@ def test_execute(self, mock_hook, to_dict_mock):
11371137
region=GCP_LOCATION,
11381138
project_id=GCP_PROJECT,
11391139
endpoint=TEST_ENDPOINT,
1140+
endpoint_id=TEST_ENDPOINT_ID,
11401141
retry=RETRY,
11411142
timeout=TIMEOUT,
11421143
metadata=METADATA,
@@ -1149,6 +1150,7 @@ def test_execute(self, mock_hook, to_dict_mock):
11491150
region=GCP_LOCATION,
11501151
project_id=GCP_PROJECT,
11511152
endpoint=TEST_ENDPOINT,
1153+
endpoint_id=TEST_ENDPOINT_ID,
11521154
retry=RETRY,
11531155
timeout=TIMEOUT,
11541156
metadata=METADATA,

0 commit comments

Comments
 (0)