|
39 | 39 | GCP_AUTOML_LOCATION = os.environ.get("GCP_AUTOML_LOCATION", "us-central1")
|
40 | 40 | GCP_AUTOML_TEXT_CLS_BUCKET = os.environ.get("GCP_AUTOML_TEXT_CLS_BUCKET", "gs://INVALID BUCKET NAME")
|
41 | 41 |
|
42 |
| -# Example values |
43 |
| -DATASET_ID = "" |
44 |
| - |
45 | 42 | # Example model
|
46 | 43 | MODEL = {
|
47 | 44 | "display_name": "auto_model_1",
|
48 |
| - "dataset_id": DATASET_ID, |
49 | 45 | "text_classification_model_metadata": {},
|
50 | 46 | }
|
51 | 47 |
|
|
55 | 51 | "text_classification_dataset_metadata": {"classification_type": "MULTICLASS"},
|
56 | 52 | }
|
57 | 53 |
|
| 54 | + |
58 | 55 | IMPORT_INPUT_CONFIG = {"gcs_source": {"input_uris": [GCP_AUTOML_TEXT_CLS_BUCKET]}}
|
59 | 56 |
|
60 | 57 | extract_object_id = CloudAutoMLHook.extract_object_id
|
|
65 | 62 | start_date=datetime(2021, 1, 1),
|
66 | 63 | catchup=False,
|
67 | 64 | tags=["example"],
|
68 |
| -) as example_dag: |
| 65 | +) as dag: |
69 | 66 | create_dataset_task = AutoMLCreateDatasetOperator(
|
70 | 67 | task_id="create_dataset_task", dataset=DATASET, location=GCP_AUTOML_LOCATION
|
71 | 68 | )
|
72 | 69 |
|
73 | 70 | dataset_id = cast(str, XComArg(create_dataset_task, key="dataset_id"))
|
| 71 | + MODEL["dataset_id"] = dataset_id |
74 | 72 |
|
75 | 73 | import_dataset_task = AutoMLImportDataOperator(
|
76 | 74 | task_id="import_dataset_task",
|
77 | 75 | dataset_id=dataset_id,
|
78 | 76 | location=GCP_AUTOML_LOCATION,
|
79 | 77 | input_config=IMPORT_INPUT_CONFIG,
|
80 | 78 | )
|
81 |
| - |
82 | 79 | MODEL["dataset_id"] = dataset_id
|
83 | 80 |
|
84 | 81 | create_model = AutoMLTrainModelOperator(task_id="create_model", model=MODEL, location=GCP_AUTOML_LOCATION)
|
85 |
| - |
86 | 82 | model_id = cast(str, XComArg(create_model, key="model_id"))
|
87 | 83 |
|
88 | 84 | delete_model_task = AutoMLDeleteModelOperator(
|
|
99 | 95 | project_id=GCP_PROJECT_ID,
|
100 | 96 | )
|
101 | 97 |
|
| 98 | + # TEST BODY |
102 | 99 | import_dataset_task >> create_model
|
| 100 | + # TEST TEARDOWN |
103 | 101 | delete_model_task >> delete_datasets_task
|
104 | 102 |
|
105 | 103 | # Task dependencies created via `XComArgs`:
|
106 | 104 | # create_dataset_task >> import_dataset_task
|
107 | 105 | # create_dataset_task >> create_model
|
108 | 106 | # create_dataset_task >> delete_datasets_task
|
| 107 | + |
| 108 | + from tests.system.utils.watcher import watcher |
| 109 | + |
| 110 | + # This test needs watcher in order to properly mark success/failure |
| 111 | + # when "tearDown" task with trigger rule is part of the DAG |
| 112 | + list(dag.tasks) >> watcher() |
| 113 | + |
| 114 | +from tests.system.utils import get_test_run # noqa: E402 |
| 115 | + |
| 116 | +# Needed to run the example DAG with pytest (see: tests/system/README.md#run_via_pytest) |
| 117 | +test_run = get_test_run(dag) |
0 commit comments