|
24 | 24 | from datetime import datetime
|
25 | 25 | from typing import cast
|
26 | 26 |
|
| 27 | +from google.cloud.aiplatform import schema |
| 28 | +from google.protobuf.struct_pb2 import Value |
| 29 | + |
27 | 30 | from airflow import models
|
28 | 31 | from airflow.models.xcom_arg import XComArg
|
29 | 32 | from airflow.providers.google.cloud.hooks.automl import CloudAutoMLHook
|
30 |
| -from airflow.providers.google.cloud.operators.automl import ( |
31 |
| - AutoMLCreateDatasetOperator, |
32 |
| - AutoMLDeleteDatasetOperator, |
33 |
| - AutoMLDeleteModelOperator, |
34 |
| - AutoMLDeployModelOperator, |
35 |
| - AutoMLImportDataOperator, |
36 |
| - AutoMLTrainModelOperator, |
37 |
| -) |
38 | 33 | from airflow.providers.google.cloud.operators.gcs import (
|
39 | 34 | GCSCreateBucketOperator,
|
40 | 35 | GCSDeleteBucketOperator,
|
41 | 36 | GCSSynchronizeBucketsOperator,
|
42 | 37 | )
|
| 38 | +from airflow.providers.google.cloud.operators.vertex_ai.auto_ml import ( |
| 39 | + CreateAutoMLTextTrainingJobOperator, |
| 40 | + DeleteAutoMLTrainingJobOperator, |
| 41 | +) |
| 42 | +from airflow.providers.google.cloud.operators.vertex_ai.dataset import ( |
| 43 | + CreateDatasetOperator, |
| 44 | + DeleteDatasetOperator, |
| 45 | + ImportDataOperator, |
| 46 | +) |
43 | 47 | from airflow.utils.trigger_rule import TriggerRule
|
44 | 48 |
|
# --- Environment / identity -------------------------------------------------
ENV_ID = os.environ.get("SYSTEM_TESTS_ENV_ID", "default")
GCP_PROJECT_ID = os.environ.get("SYSTEM_TESTS_GCP_PROJECT", "default")
DAG_ID = "example_automl_text_cls"

# --- GCS layout -------------------------------------------------------------
GCP_AUTOML_LOCATION = "us-central1"
DATA_SAMPLE_GCS_BUCKET_NAME = f"bucket_{DAG_ID}_{ENV_ID}".replace("_", "-")
RESOURCE_DATA_BUCKET = "airflow-system-tests-resources"

# Display names must be hyphenated (Vertex AI naming rules).
TEXT_CLSS_DISPLAY_NAME = f"{DAG_ID}-{ENV_ID}".replace("_", "-")
AUTOML_DATASET_BUCKET = f"gs://{DATA_SAMPLE_GCS_BUCKET_NAME}/automl/classification.csv"

MODEL_NAME = f"{DAG_ID}-{ENV_ID}".replace("_", "-")

# --- Vertex AI dataset definition -------------------------------------------
DATASET_NAME = f"ds_clss_{ENV_ID}".replace("-", "_")
DATASET = {
    "display_name": DATASET_NAME,
    "metadata_schema_uri": schema.dataset.metadata.text,
    "metadata": Value(string_value="clss-dataset"),
}

# Import config for the CSV staged under AUTOML_DATASET_BUCKET;
# single-label text classification schema.
DATA_CONFIG = [
    {
        "import_schema_uri": schema.dataset.ioformat.text.single_label_classification,
        "gcs_source": {"uris": [AUTOML_DATASET_BUCKET]},
    },
]
extract_object_id = CloudAutoMLHook.extract_object_id

# Example DAG for AutoML Natural Language Text Classification
|
|
85 | 92 | move_dataset_file = GCSSynchronizeBucketsOperator(
|
86 | 93 | task_id="move_dataset_to_bucket",
|
87 | 94 | source_bucket=RESOURCE_DATA_BUCKET,
|
88 |
| - source_object="automl/datasets/text", |
| 95 | + source_object="vertex-ai/automl/datasets/text", |
89 | 96 | destination_bucket=DATA_SAMPLE_GCS_BUCKET_NAME,
|
90 | 97 | destination_object="automl",
|
91 | 98 | recursive=True,
|
92 | 99 | )
|
93 | 100 |
|
94 |
| - create_dataset = AutoMLCreateDatasetOperator( |
95 |
| - task_id="create_dataset", |
| 101 | + create_clss_dataset = CreateDatasetOperator( |
| 102 | + task_id="create_clss_dataset", |
96 | 103 | dataset=DATASET,
|
97 |
| - location=GCP_AUTOML_LOCATION, |
| 104 | + region=GCP_AUTOML_LOCATION, |
98 | 105 | project_id=GCP_PROJECT_ID,
|
99 | 106 | )
|
| 107 | + clss_dataset_id = create_clss_dataset.output["dataset_id"] |
100 | 108 |
|
101 |
| - dataset_id = cast(str, XComArg(create_dataset, key="dataset_id")) |
102 |
| - MODEL["dataset_id"] = dataset_id |
103 |
| - import_dataset = AutoMLImportDataOperator( |
104 |
| - task_id="import_dataset", |
105 |
| - dataset_id=dataset_id, |
106 |
| - location=GCP_AUTOML_LOCATION, |
107 |
| - input_config=IMPORT_INPUT_CONFIG, |
| 109 | + import_clss_dataset = ImportDataOperator( |
| 110 | + task_id="import_clss_data", |
| 111 | + dataset_id=clss_dataset_id, |
| 112 | + region=GCP_AUTOML_LOCATION, |
| 113 | + project_id=GCP_PROJECT_ID, |
| 114 | + import_configs=DATA_CONFIG, |
108 | 115 | )
|
109 |
| - MODEL["dataset_id"] = dataset_id |
110 |
| - |
111 |
| - create_model = AutoMLTrainModelOperator(task_id="create_model", model=MODEL, location=GCP_AUTOML_LOCATION) |
112 |
| - model_id = cast(str, XComArg(create_model, key="model_id")) |
113 | 116 |
|
114 |
| - deploy_model = AutoMLDeployModelOperator( |
115 |
| - task_id="deploy_model", |
116 |
| - model_id=model_id, |
117 |
| - location=GCP_AUTOML_LOCATION, |
| 117 | + # [START howto_operator_automl_create_model] |
| 118 | + create_clss_training_job = CreateAutoMLTextTrainingJobOperator( |
| 119 | + task_id="create_clss_training_job", |
| 120 | + display_name=TEXT_CLSS_DISPLAY_NAME, |
| 121 | + prediction_type="classification", |
| 122 | + multi_label=False, |
| 123 | + dataset_id=clss_dataset_id, |
| 124 | + model_display_name=MODEL_NAME, |
| 125 | + training_fraction_split=0.7, |
| 126 | + validation_fraction_split=0.2, |
| 127 | + test_fraction_split=0.1, |
| 128 | + sync=True, |
| 129 | + region=GCP_AUTOML_LOCATION, |
118 | 130 | project_id=GCP_PROJECT_ID,
|
119 | 131 | )
|
| 132 | + # [END howto_operator_automl_create_model] |
| 133 | + model_id = cast(str, XComArg(create_clss_training_job, key="model_id")) |
120 | 134 |
|
121 |
| - delete_model = AutoMLDeleteModelOperator( |
122 |
| - task_id="delete_model", |
123 |
| - model_id=model_id, |
124 |
| - location=GCP_AUTOML_LOCATION, |
| 135 | + delete_clss_training_job = DeleteAutoMLTrainingJobOperator( |
| 136 | + task_id="delete_clss_training_job", |
| 137 | + training_pipeline_id=create_clss_training_job.output["training_id"], |
| 138 | + region=GCP_AUTOML_LOCATION, |
125 | 139 | project_id=GCP_PROJECT_ID,
|
| 140 | + trigger_rule=TriggerRule.ALL_DONE, |
126 | 141 | )
|
127 | 142 |
|
128 |
| - delete_dataset = AutoMLDeleteDatasetOperator( |
129 |
| - task_id="delete_dataset", |
130 |
| - dataset_id=dataset_id, |
131 |
| - location=GCP_AUTOML_LOCATION, |
| 143 | + delete_clss_dataset = DeleteDatasetOperator( |
| 144 | + task_id="delete_clss_dataset", |
| 145 | + dataset_id=clss_dataset_id, |
| 146 | + region=GCP_AUTOML_LOCATION, |
132 | 147 | project_id=GCP_PROJECT_ID,
|
| 148 | + trigger_rule=TriggerRule.ALL_DONE, |
133 | 149 | )
|
134 | 150 |
|
135 | 151 | delete_bucket = GCSDeleteBucketOperator(
|
136 |
| - task_id="delete_bucket", bucket_name=DATA_SAMPLE_GCS_BUCKET_NAME, trigger_rule=TriggerRule.ALL_DONE |
| 152 | + task_id="delete_bucket", |
| 153 | + bucket_name=DATA_SAMPLE_GCS_BUCKET_NAME, |
| 154 | + trigger_rule=TriggerRule.ALL_DONE, |
137 | 155 | )
|
138 | 156 |
|
139 | 157 | (
|
140 | 158 | # TEST SETUP
|
141 |
| - [create_bucket >> move_dataset_file, create_dataset] |
| 159 | + [create_bucket >> move_dataset_file, create_clss_dataset] |
142 | 160 | # TEST BODY
|
143 |
| - >> import_dataset |
144 |
| - >> create_model |
145 |
| - >> deploy_model |
| 161 | + >> import_clss_dataset |
| 162 | + >> create_clss_training_job |
146 | 163 | # TEST TEARDOWN
|
147 |
| - >> delete_model |
148 |
| - >> delete_dataset |
| 164 | + >> delete_clss_training_job |
| 165 | + >> delete_clss_dataset |
149 | 166 | >> delete_bucket
|
150 | 167 | )
|
151 | 168 |
|
|
0 commit comments