Skip to content

Commit 25d463c

Browse files
Deprecate AutoMLTrainModelOperator for NL (#34212)
1 parent cad983d commit 25d463c

File tree

5 files changed

+240
-148
lines changed

5 files changed

+240
-148
lines changed

airflow/providers/google/cloud/operators/automl.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from __future__ import annotations
2020

2121
import ast
22+
import warnings
2223
from typing import TYPE_CHECKING, Sequence, Tuple
2324

2425
from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault
@@ -31,6 +32,7 @@
3132
TableSpec,
3233
)
3334

35+
from airflow.exceptions import AirflowProviderDeprecationWarning
3436
from airflow.providers.google.cloud.hooks.automl import CloudAutoMLHook
3537
from airflow.providers.google.cloud.links.automl import (
3638
AutoMLDatasetLink,
@@ -53,6 +55,10 @@ class AutoMLTrainModelOperator(GoogleCloudBaseOperator):
5355
"""
5456
Creates Google Cloud AutoML model.
5557
58+
AutoMLTrainModelOperator for text prediction is deprecated. Please use
59+
:class:`airflow.providers.google.cloud.operators.vertex_ai.auto_ml.CreateAutoMLTextTrainingJobOperator`
60+
instead.
61+
5662
.. seealso::
5763
For more information on how to use this operator, take a look at the guide:
5864
:ref:`howto/operator:AutoMLTrainModelOperator`
@@ -102,7 +108,6 @@ def __init__(
102108
**kwargs,
103109
) -> None:
104110
super().__init__(**kwargs)
105-
106111
self.model = model
107112
self.location = location
108113
self.project_id = project_id
@@ -113,6 +118,20 @@ def __init__(
113118
self.impersonation_chain = impersonation_chain
114119

115120
def execute(self, context: Context):
121+
# Output warning if running AutoML Natural Language prediction job
122+
automl_nl_model_keys = [
123+
"text_classification_model_metadata",
124+
"text_extraction_model_metadata",
125+
"text_sentiment_dataset_metadata",
126+
]
127+
if any(key in automl_nl_model_keys for key in self.model):
128+
warnings.warn(
129+
"AutoMLTrainModelOperator for text prediction is deprecated. All the functionality of legacy "
130+
"AutoML Natural Language and new features are available on the Vertex AI platform. "
131+
"Please use `CreateAutoMLTextTrainingJobOperator`",
132+
AirflowProviderDeprecationWarning,
133+
stacklevel=2,
134+
)
116135
hook = CloudAutoMLHook(
117136
gcp_conn_id=self.gcp_conn_id,
118137
impersonation_chain=self.impersonation_chain,

docs/apache-airflow-providers-google/operators/cloud/automl.rst

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,16 @@ To create a Google AutoML model you can use
102102
The operator will wait for the operation to complete. Additionally the operator
103103
returns the id of model in :ref:`XCom <concepts:xcom>` under ``model_id`` key.
104104

105+
This Operator is deprecated when running for text prediction and will be removed soon.
106+
All the functionality of legacy AutoML Natural Language and new features are available on the
107+
Vertex AI platform. Please use
108+
:class:`~airflow.providers.google.cloud.operators.vertex_ai.auto_ml.CreateAutoMLTextTrainingJobOperator`.
109+
When running Vertex AI Operator for training data, please ensure that your data is correctly stored in Vertex AI
110+
datasets. To create and import data to the dataset please use
111+
:class:`~airflow.providers.google.cloud.operators.vertex_ai.dataset.CreateDatasetOperator`
112+
and
113+
:class:`~airflow.providers.google.cloud.operators.vertex_ai.dataset.ImportDataOperator`
114+
105115
.. exampleinclude:: /../../tests/system/providers/google/cloud/automl/example_automl_model.py
106116
:language: python
107117
:dedent: 4
@@ -164,7 +174,7 @@ the model must be deployed.
164174
Listing And Deleting Datasets
165175
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
166176

167-
You can get a list of AutoML models using
177+
You can get a list of AutoML datasets using
168178
:class:`~airflow.providers.google.cloud.operators.automl.AutoMLListDatasetOperator`. The operator returns list
169179
of datasets ids in :ref:`XCom <concepts:xcom>` under ``dataset_id_list`` key.
170180

@@ -174,7 +184,7 @@ of datasets ids in :ref:`XCom <concepts:xcom>` under ``dataset_id_list`` key.
174184
:start-after: [START howto_operator_list_dataset]
175185
:end-before: [END howto_operator_list_dataset]
176186

177-
To delete a model you can use :class:`~airflow.providers.google.cloud.operators.automl.AutoMLDeleteDatasetOperator`.
187+
To delete a dataset you can use :class:`~airflow.providers.google.cloud.operators.automl.AutoMLDeleteDatasetOperator`.
178188
The delete operator allows also to pass list or coma separated string of datasets ids to be deleted.
179189

180190
.. exampleinclude:: /../../tests/system/providers/google/cloud/automl/example_automl_dataset.py

tests/system/providers/google/cloud/automl/example_automl_nl_text_classification.py

Lines changed: 69 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -24,47 +24,54 @@
2424
from datetime import datetime
2525
from typing import cast
2626

27+
from google.cloud.aiplatform import schema
28+
from google.protobuf.struct_pb2 import Value
29+
2730
from airflow import models
2831
from airflow.models.xcom_arg import XComArg
2932
from airflow.providers.google.cloud.hooks.automl import CloudAutoMLHook
30-
from airflow.providers.google.cloud.operators.automl import (
31-
AutoMLCreateDatasetOperator,
32-
AutoMLDeleteDatasetOperator,
33-
AutoMLDeleteModelOperator,
34-
AutoMLDeployModelOperator,
35-
AutoMLImportDataOperator,
36-
AutoMLTrainModelOperator,
37-
)
3833
from airflow.providers.google.cloud.operators.gcs import (
3934
GCSCreateBucketOperator,
4035
GCSDeleteBucketOperator,
4136
GCSSynchronizeBucketsOperator,
4237
)
38+
from airflow.providers.google.cloud.operators.vertex_ai.auto_ml import (
39+
CreateAutoMLTextTrainingJobOperator,
40+
DeleteAutoMLTrainingJobOperator,
41+
)
42+
from airflow.providers.google.cloud.operators.vertex_ai.dataset import (
43+
CreateDatasetOperator,
44+
DeleteDatasetOperator,
45+
ImportDataOperator,
46+
)
4347
from airflow.utils.trigger_rule import TriggerRule
4448

4549
ENV_ID = os.environ.get("SYSTEM_TESTS_ENV_ID", "default")
46-
DAG_ID = "example_automl_text_cls"
4750
GCP_PROJECT_ID = os.environ.get("SYSTEM_TESTS_GCP_PROJECT", "default")
51+
DAG_ID = "example_automl_text_cls"
4852

4953
GCP_AUTOML_LOCATION = "us-central1"
5054
DATA_SAMPLE_GCS_BUCKET_NAME = f"bucket_{DAG_ID}_{ENV_ID}".replace("_", "-")
5155
RESOURCE_DATA_BUCKET = "airflow-system-tests-resources"
5256

53-
MODEL_NAME = "text_clss_test_model"
54-
MODEL = {
55-
"display_name": MODEL_NAME,
56-
"text_classification_model_metadata": {},
57-
}
57+
TEXT_CLSS_DISPLAY_NAME = f"{DAG_ID}-{ENV_ID}".replace("_", "-")
58+
AUTOML_DATASET_BUCKET = f"gs://{DATA_SAMPLE_GCS_BUCKET_NAME}/automl/classification.csv"
59+
60+
MODEL_NAME = f"{DAG_ID}-{ENV_ID}".replace("_", "-")
5861

5962
DATASET_NAME = f"ds_clss_{ENV_ID}".replace("-", "_")
6063
DATASET = {
6164
"display_name": DATASET_NAME,
62-
"text_classification_dataset_metadata": {"classification_type": "MULTICLASS"},
65+
"metadata_schema_uri": schema.dataset.metadata.text,
66+
"metadata": Value(string_value="clss-dataset"),
6367
}
6468

65-
AUTOML_DATASET_BUCKET = f"gs://{DATA_SAMPLE_GCS_BUCKET_NAME}/automl/text_classification.csv"
66-
IMPORT_INPUT_CONFIG = {"gcs_source": {"input_uris": [AUTOML_DATASET_BUCKET]}}
67-
69+
DATA_CONFIG = [
70+
{
71+
"import_schema_uri": schema.dataset.ioformat.text.single_label_classification,
72+
"gcs_source": {"uris": [AUTOML_DATASET_BUCKET]},
73+
},
74+
]
6875
extract_object_id = CloudAutoMLHook.extract_object_id
6976

7077
# Example DAG for AutoML Natural Language Text Classification
@@ -85,67 +92,77 @@
8592
move_dataset_file = GCSSynchronizeBucketsOperator(
8693
task_id="move_dataset_to_bucket",
8794
source_bucket=RESOURCE_DATA_BUCKET,
88-
source_object="automl/datasets/text",
95+
source_object="vertex-ai/automl/datasets/text",
8996
destination_bucket=DATA_SAMPLE_GCS_BUCKET_NAME,
9097
destination_object="automl",
9198
recursive=True,
9299
)
93100

94-
create_dataset = AutoMLCreateDatasetOperator(
95-
task_id="create_dataset",
101+
create_clss_dataset = CreateDatasetOperator(
102+
task_id="create_clss_dataset",
96103
dataset=DATASET,
97-
location=GCP_AUTOML_LOCATION,
104+
region=GCP_AUTOML_LOCATION,
98105
project_id=GCP_PROJECT_ID,
99106
)
107+
clss_dataset_id = create_clss_dataset.output["dataset_id"]
100108

101-
dataset_id = cast(str, XComArg(create_dataset, key="dataset_id"))
102-
MODEL["dataset_id"] = dataset_id
103-
import_dataset = AutoMLImportDataOperator(
104-
task_id="import_dataset",
105-
dataset_id=dataset_id,
106-
location=GCP_AUTOML_LOCATION,
107-
input_config=IMPORT_INPUT_CONFIG,
109+
import_clss_dataset = ImportDataOperator(
110+
task_id="import_clss_data",
111+
dataset_id=clss_dataset_id,
112+
region=GCP_AUTOML_LOCATION,
113+
project_id=GCP_PROJECT_ID,
114+
import_configs=DATA_CONFIG,
108115
)
109-
MODEL["dataset_id"] = dataset_id
110-
111-
create_model = AutoMLTrainModelOperator(task_id="create_model", model=MODEL, location=GCP_AUTOML_LOCATION)
112-
model_id = cast(str, XComArg(create_model, key="model_id"))
113116

114-
deploy_model = AutoMLDeployModelOperator(
115-
task_id="deploy_model",
116-
model_id=model_id,
117-
location=GCP_AUTOML_LOCATION,
117+
# [START howto_operator_automl_create_model]
118+
create_clss_training_job = CreateAutoMLTextTrainingJobOperator(
119+
task_id="create_clss_training_job",
120+
display_name=TEXT_CLSS_DISPLAY_NAME,
121+
prediction_type="classification",
122+
multi_label=False,
123+
dataset_id=clss_dataset_id,
124+
model_display_name=MODEL_NAME,
125+
training_fraction_split=0.7,
126+
validation_fraction_split=0.2,
127+
test_fraction_split=0.1,
128+
sync=True,
129+
region=GCP_AUTOML_LOCATION,
118130
project_id=GCP_PROJECT_ID,
119131
)
132+
# [END howto_operator_automl_create_model]
133+
model_id = cast(str, XComArg(create_clss_training_job, key="model_id"))
120134

121-
delete_model = AutoMLDeleteModelOperator(
122-
task_id="delete_model",
123-
model_id=model_id,
124-
location=GCP_AUTOML_LOCATION,
135+
delete_clss_training_job = DeleteAutoMLTrainingJobOperator(
136+
task_id="delete_clss_training_job",
137+
training_pipeline_id=create_clss_training_job.output["training_id"],
138+
region=GCP_AUTOML_LOCATION,
125139
project_id=GCP_PROJECT_ID,
140+
trigger_rule=TriggerRule.ALL_DONE,
126141
)
127142

128-
delete_dataset = AutoMLDeleteDatasetOperator(
129-
task_id="delete_dataset",
130-
dataset_id=dataset_id,
131-
location=GCP_AUTOML_LOCATION,
143+
delete_clss_dataset = DeleteDatasetOperator(
144+
task_id="delete_clss_dataset",
145+
dataset_id=clss_dataset_id,
146+
region=GCP_AUTOML_LOCATION,
132147
project_id=GCP_PROJECT_ID,
148+
trigger_rule=TriggerRule.ALL_DONE,
133149
)
134150

135151
delete_bucket = GCSDeleteBucketOperator(
136-
task_id="delete_bucket", bucket_name=DATA_SAMPLE_GCS_BUCKET_NAME, trigger_rule=TriggerRule.ALL_DONE
152+
task_id="delete_bucket",
153+
bucket_name=DATA_SAMPLE_GCS_BUCKET_NAME,
154+
trigger_rule=TriggerRule.ALL_DONE,
137155
)
138156

139157
(
140158
# TEST SETUP
141-
[create_bucket >> move_dataset_file, create_dataset]
159+
[create_bucket >> move_dataset_file, create_clss_dataset]
142160
# TEST BODY
143-
>> import_dataset
144-
>> create_model
145-
>> deploy_model
161+
>> import_clss_dataset
162+
>> create_clss_training_job
146163
# TEST TEARDOWN
147-
>> delete_model
148-
>> delete_dataset
164+
>> delete_clss_training_job
165+
>> delete_clss_dataset
149166
>> delete_bucket
150167
)
151168

0 commit comments

Comments
 (0)