Skip to content

Commit 9442435

Browse files
authored
Rename Vertex AI AutoML operators fields' names to comply with templated fields validation (#38049)
1 parent 777a216 commit 9442435

File tree

3 files changed

+38
-3
lines changed

3 files changed

+38
-3
lines changed

β€Ž.pre-commit-config.yaml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -337,7 +337,6 @@ repos:
337337
^airflow\/providers\/google\/cloud\/operators\/mlengine.py$|
338338
^airflow\/providers\/google\/cloud\/operators\/cloud_storage_transfer_service.py$|
339339
^airflow\/providers\/apache\/spark\/operators\/spark_submit.py\.py$|
340-
^airflow\/providers\/google\/cloud\/operators\/vertex_ai\/auto_ml\.py$|
341340
^airflow\/providers\/apache\/spark\/operators\/spark_submit\.py$|
342341
^airflow\/providers\/databricks\/operators\/databricks_sql\.py$|
343342
)$

β€Žairflow/providers/google/cloud/operators/vertex_ai/auto_ml.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,12 +22,14 @@
2222

2323
from typing import TYPE_CHECKING, Sequence
2424

25+
from deprecated import deprecated
2526
from google.api_core.exceptions import NotFound
2627
from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault
2728
from google.cloud.aiplatform import datasets
2829
from google.cloud.aiplatform.models import Model
2930
from google.cloud.aiplatform_v1.types.training_pipeline import TrainingPipeline
3031

32+
from airflow.exceptions import AirflowProviderDeprecationWarning
3133
from airflow.providers.google.cloud.hooks.vertex_ai.auto_ml import AutoMLHook
3234
from airflow.providers.google.cloud.links.vertex_ai import (
3335
VertexAIModelLink,
@@ -607,7 +609,7 @@ class DeleteAutoMLTrainingJobOperator(GoogleCloudBaseOperator):
607609
AutoMLTabularTrainingJob, AutoMLTextTrainingJob, or AutoMLVideoTrainingJob.
608610
"""
609611

610-
template_fields = ("training_pipeline", "region", "project_id", "impersonation_chain")
612+
template_fields = ("training_pipeline_id", "region", "project_id", "impersonation_chain")
611613

612614
def __init__(
613615
self,
@@ -623,7 +625,7 @@ def __init__(
623625
**kwargs,
624626
) -> None:
625627
super().__init__(**kwargs)
626-
self.training_pipeline = training_pipeline_id
628+
self.training_pipeline_id = training_pipeline_id
627629
self.region = region
628630
self.project_id = project_id
629631
self.retry = retry
@@ -632,6 +634,16 @@ def __init__(
632634
self.gcp_conn_id = gcp_conn_id
633635
self.impersonation_chain = impersonation_chain
634636

637+
@property
638+
@deprecated(
639+
reason="`training_pipeline` is deprecated and will be removed in the future. "
640+
"Please use `training_pipeline_id` instead.",
641+
category=AirflowProviderDeprecationWarning,
642+
)
643+
def training_pipeline(self):
644+
"""Alias for ``training_pipeline_id``, used for compatibility (deprecated)."""
645+
return self.training_pipeline_id
646+
635647
def execute(self, context: Context):
636648
hook = AutoMLHook(
637649
gcp_conn_id=self.gcp_conn_id,

β€Žtests/providers/google/cloud/operators/test_vertex_ai.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1028,6 +1028,30 @@ def test_execute(self, mock_hook):
10281028
metadata=METADATA,
10291029
)
10301030

1031+
@pytest.mark.db_test
1032+
def test_templating(self, create_task_instance_of_operator):
1033+
ti = create_task_instance_of_operator(
1034+
DeleteAutoMLTrainingJobOperator,
1035+
# Templated fields
1036+
training_pipeline_id="{{ 'training-pipeline-id' }}",
1037+
region="{{ 'region' }}",
1038+
project_id="{{ 'project-id' }}",
1039+
impersonation_chain="{{ 'impersonation-chain' }}",
1040+
# Other parameters
1041+
dag_id="test_template_body_templating_dag",
1042+
task_id="test_template_body_templating_task",
1043+
execution_date=timezone.datetime(2024, 2, 1, tzinfo=timezone.utc),
1044+
)
1045+
ti.render_templates()
1046+
task: DeleteAutoMLTrainingJobOperator = ti.task
1047+
assert task.training_pipeline_id == "training-pipeline-id"
1048+
assert task.region == "region"
1049+
assert task.project_id == "project-id"
1050+
assert task.impersonation_chain == "impersonation-chain"
1051+
1052+
with pytest.warns(AirflowProviderDeprecationWarning):
1053+
assert task.training_pipeline == "training-pipeline-id"
1054+
10311055

10321056
class TestVertexAIListAutoMLTrainingJobOperator:
10331057
@mock.patch(VERTEX_AI_PATH.format("auto_ml.AutoMLHook"))

0 commit comments

Comments
 (0)