Skip to content

Commit afb686c

Browse files
moiseenkov and Lee-W authored
Implement deferrable mode for GKEStartJobOperator (#38454)
* Implement deferrable mode for GKEStartJobOperator
* Specify trigger return type
  Co-authored-by: Wei Lee <weilee.rx@gmail.com>
* Fix f-string
  Co-authored-by: Wei Lee <weilee.rx@gmail.com>
* Refactor trigger event yielding
  Co-authored-by: Wei Lee <weilee.rx@gmail.com>
* Fix typo
---------
Co-authored-by: Wei Lee <weilee.rx@gmail.com>
1 parent 14e1b4c commit afb686c

File tree

7 files changed

+365
-23
lines changed

7 files changed

+365
-23
lines changed

airflow/providers/cncf/kubernetes/triggers/job.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ def __init__(
4545
job_name: str,
4646
job_namespace: str,
4747
kubernetes_conn_id: str | None = None,
48-
poll_interval: float = 2,
48+
poll_interval: float = 10.0,
4949
cluster_context: str | None = None,
5050
config_file: str | None = None,
5151
in_cluster: bool | None = None,

airflow/providers/google/cloud/operators/kubernetes_engine.py

Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
from google.cloud.container_v1.types import Cluster
3232
from kubernetes.client import V1JobList
3333
from kubernetes.utils.create_from_yaml import FailToCreateError
34+
from packaging.version import parse as parse_version
3435

3536
from airflow.configuration import conf
3637
from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
@@ -55,7 +56,12 @@
5556
KubernetesEngineWorkloadsLink,
5657
)
5758
from airflow.providers.google.cloud.operators.cloud_base import GoogleCloudBaseOperator
58-
from airflow.providers.google.cloud.triggers.kubernetes_engine import GKEOperationTrigger, GKEStartPodTrigger
59+
from airflow.providers.google.cloud.triggers.kubernetes_engine import (
60+
GKEJobTrigger,
61+
GKEOperationTrigger,
62+
GKEStartPodTrigger,
63+
)
64+
from airflow.providers_manager import ProvidersManager
5965
from airflow.utils.timezone import utcnow
6066

6167
if TYPE_CHECKING:
@@ -834,6 +840,9 @@ class GKEStartJobOperator(KubernetesJobOperator):
834840
Service Account Token Creator IAM role to the directly preceding identity, with first
835841
account from the list granting this role to the originating account (templated).
836842
:param location: The location param is region name.
843+
:param deferrable: Run operator in the deferrable mode.
844+
:param job_poll_interval: (Deferrable mode only) polling period in seconds to
845+
check for the status of job.
837846
"""
838847

839848
template_fields: Sequence[str] = tuple(
@@ -850,6 +859,8 @@ def __init__(
850859
project_id: str | None = None,
851860
gcp_conn_id: str = "google_cloud_default",
852861
impersonation_chain: str | Sequence[str] | None = None,
862+
deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
863+
job_poll_interval: float = 10.0,
853864
**kwargs,
854865
) -> None:
855866
super().__init__(**kwargs)
@@ -859,6 +870,8 @@ def __init__(
859870
self.gcp_conn_id = gcp_conn_id
860871
self.impersonation_chain = impersonation_chain
861872
self.use_internal_ip = use_internal_ip
873+
self.deferrable = deferrable
874+
self.job_poll_interval = job_poll_interval
862875

863876
self.job: V1Job | None = None
864877
self._ssl_ca_cert: str | None = None
@@ -900,6 +913,18 @@ def hook(self) -> GKEJobHook:
900913

901914
def execute(self, context: Context):
902915
"""Execute process of creating Job."""
916+
if self.deferrable:
917+
kubernetes_provider = ProvidersManager().providers["apache-airflow-providers-cncf-kubernetes"]
918+
kubernetes_provider_name = kubernetes_provider.data["package-name"]
919+
kubernetes_provider_version = kubernetes_provider.version
920+
min_version = "8.0.1"
921+
if parse_version(kubernetes_provider_version) <= parse_version(min_version):
922+
raise AirflowException(
923+
"You are trying to use `GKEStartJobOperator` in deferrable mode with the provider "
924+
f"package {kubernetes_provider_name}=={kubernetes_provider_version} which doesn't "
925+
f"support this feature. Please upgrade it to version higher than {min_version}."
926+
)
927+
903928
self._cluster_url, self._ssl_ca_cert = GKEClusterAuthDetails(
904929
cluster_name=self.cluster_name,
905930
project_id=self.project_id,
@@ -909,6 +934,20 @@ def execute(self, context: Context):
909934

910935
return super().execute(context)
911936

937+
def execute_deferrable(self):
938+
self.defer(
939+
trigger=GKEJobTrigger(
940+
cluster_url=self._cluster_url,
941+
ssl_ca_cert=self._ssl_ca_cert,
942+
job_name=self.job.metadata.name, # type: ignore[union-attr]
943+
job_namespace=self.job.metadata.namespace, # type: ignore[union-attr]
944+
gcp_conn_id=self.gcp_conn_id,
945+
poll_interval=self.job_poll_interval,
946+
impersonation_chain=self.impersonation_chain,
947+
),
948+
method_name="execute_complete",
949+
)
950+
912951

913952
class GKEDescribeJobOperator(GoogleCloudBaseOperator):
914953
"""

airflow/providers/google/cloud/triggers/kubernetes_engine.py

Lines changed: 71 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,12 +27,18 @@
2727
from airflow.exceptions import AirflowProviderDeprecationWarning
2828
from airflow.providers.cncf.kubernetes.triggers.pod import KubernetesPodTrigger
2929
from airflow.providers.cncf.kubernetes.utils.pod_manager import OnFinishAction
30-
from airflow.providers.google.cloud.hooks.kubernetes_engine import GKEAsyncHook, GKEPodAsyncHook
30+
from airflow.providers.google.cloud.hooks.kubernetes_engine import (
31+
GKEAsyncHook,
32+
GKEKubernetesAsyncHook,
33+
GKEPodAsyncHook,
34+
)
3135
from airflow.triggers.base import BaseTrigger, TriggerEvent
3236

3337
if TYPE_CHECKING:
3438
from datetime import datetime
3539

40+
from kubernetes_asyncio.client import V1Job
41+
3642

3743
class GKEStartPodTrigger(KubernetesPodTrigger):
3844
"""
@@ -237,3 +243,67 @@ def _get_hook(self) -> GKEAsyncHook:
237243
impersonation_chain=self.impersonation_chain,
238244
)
239245
return self._hook
246+
247+
248+
class GKEJobTrigger(BaseTrigger):
249+
"""GKEJobTrigger runs on the trigger worker to check the state of a Job."""
250+
251+
def __init__(
252+
self,
253+
cluster_url: str,
254+
ssl_ca_cert: str,
255+
job_name: str,
256+
job_namespace: str,
257+
gcp_conn_id: str = "google_cloud_default",
258+
poll_interval: float = 2,
259+
impersonation_chain: str | Sequence[str] | None = None,
260+
) -> None:
261+
super().__init__()
262+
self.cluster_url = cluster_url
263+
self.ssl_ca_cert = ssl_ca_cert
264+
self.job_name = job_name
265+
self.job_namespace = job_namespace
266+
self.gcp_conn_id = gcp_conn_id
267+
self.poll_interval = poll_interval
268+
self.impersonation_chain = impersonation_chain
269+
270+
def serialize(self) -> tuple[str, dict[str, Any]]:
271+
"""Serialize GKEJobTrigger arguments and classpath."""
272+
return (
273+
"airflow.providers.google.cloud.triggers.kubernetes_engine.GKEJobTrigger",
274+
{
275+
"cluster_url": self.cluster_url,
276+
"ssl_ca_cert": self.ssl_ca_cert,
277+
"job_name": self.job_name,
278+
"job_namespace": self.job_namespace,
279+
"gcp_conn_id": self.gcp_conn_id,
280+
"poll_interval": self.poll_interval,
281+
"impersonation_chain": self.impersonation_chain,
282+
},
283+
)
284+
285+
async def run(self) -> AsyncIterator[TriggerEvent]: # type: ignore[override]
286+
"""Get current job status and yield a TriggerEvent."""
287+
job: V1Job = await self.hook.wait_until_job_complete(name=self.job_name, namespace=self.job_namespace)
288+
job_dict = job.to_dict()
289+
error_message = self.hook.is_job_failed(job=job)
290+
status = "error" if error_message else "success"
291+
message = f"Job failed with error: {error_message}" if error_message else "Job completed successfully"
292+
yield TriggerEvent(
293+
{
294+
"name": job.metadata.name,
295+
"namespace": job.metadata.namespace,
296+
"status": status,
297+
"message": message,
298+
"job": job_dict,
299+
}
300+
)
301+
302+
@cached_property
303+
def hook(self) -> GKEKubernetesAsyncHook:
304+
return GKEKubernetesAsyncHook(
305+
cluster_url=self.cluster_url,
306+
ssl_ca_cert=self.ssl_ca_cert,
307+
gcp_conn_id=self.gcp_conn_id,
308+
impersonation_chain=self.impersonation_chain,
309+
)

docs/apache-airflow-providers-google/operators/cloud/kubernetes_engine.rst

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -213,6 +213,15 @@ All Kubernetes parameters (except ``config_file``) are also valid for the ``GKES
213213
:start-after: [START howto_operator_gke_start_job]
214214
:end-before: [END howto_operator_gke_start_job]
215215

216+
``GKEStartJobOperator`` also supports deferrable mode. Note that it makes sense only if the ``wait_until_job_complete``
217+
parameter is set to ``True``.
218+
219+
.. exampleinclude:: /../../tests/system/providers/google/cloud/kubernetes_engine/example_kubernetes_engine_job.py
220+
:language: python
221+
:dedent: 4
222+
:start-after: [START howto_operator_gke_start_job_def]
223+
:end-before: [END howto_operator_gke_start_job_def]
224+
216225
To run a Job on a GKE cluster with Kueue enabled, use ``GKEStartKueueJobOperator``.
217226

218227
.. exampleinclude:: /../../tests/system/providers/google/cloud/kubernetes_engine/example_kubernetes_engine_kueue.py

tests/providers/google/cloud/operators/test_kubernetes_engine.py

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@
7171
JOB_NAME = "test-job"
7272
NAMESPACE = ("default",)
7373
IMAGE = "bash"
74+
JOB_POLL_INTERVAL = 20.0
7475

7576
GCLOUD_COMMAND = "gcloud container clusters get-credentials {} --zone {} --project {}"
7677
KUBE_ENV_VAR = "KUBECONFIG"
@@ -708,6 +709,81 @@ def test_execute(self, mock_hook, fetch_cluster_info_mock, file_mock, exec_mock)
708709
self.gke_op.execute(context=mock.MagicMock())
709710
fetch_cluster_info_mock.assert_called_once()
710711

712+
@mock.patch(KUB_JOB_OPERATOR_EXEC)
713+
@mock.patch(f"{GKE_CLUSTER_AUTH_DETAILS_PATH}.fetch_cluster_info")
714+
@mock.patch(GKE_HOOK_PATH)
715+
@mock.patch(f"{GKE_HOOK_MODULE_PATH}.ProvidersManager")
716+
def test_execute_in_deferrable_mode(
717+
self, mock_providers_manager, mock_hook, fetch_cluster_info_mock, exec_mock
718+
):
719+
kubernetes_package_name = "apache-airflow-providers-cncf-kubernetes"
720+
mock_providers_manager.return_value.providers = {
721+
kubernetes_package_name: mock.MagicMock(
722+
data={
723+
"package-name": kubernetes_package_name,
724+
},
725+
version="8.0.2",
726+
)
727+
}
728+
self.gke_op.deferrable = True
729+
730+
fetch_cluster_info_mock.return_value = (CLUSTER_URL, SSL_CA_CERT)
731+
self.gke_op.execute(context=mock.MagicMock())
732+
fetch_cluster_info_mock.assert_called_once()
733+
734+
@mock.patch(f"{GKE_HOOK_MODULE_PATH}.ProvidersManager")
735+
def test_execute_in_deferrable_mode_exception(self, mock_providers_manager):
736+
kubernetes_package_name = "apache-airflow-providers-cncf-kubernetes"
737+
mock_providers_manager.return_value.providers = {
738+
kubernetes_package_name: mock.MagicMock(
739+
data={
740+
"package-name": kubernetes_package_name,
741+
},
742+
version="8.0.1",
743+
)
744+
}
745+
self.gke_op.deferrable = True
746+
with pytest.raises(AirflowException):
747+
self.gke_op.execute({})
748+
749+
@mock.patch(f"{GKE_HOOK_MODULE_PATH}.GKEJobTrigger")
750+
def test_execute_deferrable(self, mock_trigger):
751+
mock_trigger_instance = mock_trigger.return_value
752+
753+
op = GKEStartJobOperator(
754+
project_id=TEST_GCP_PROJECT_ID,
755+
location=PROJECT_LOCATION,
756+
cluster_name=CLUSTER_NAME,
757+
task_id=PROJECT_TASK_ID,
758+
name=TASK_NAME,
759+
namespace=NAMESPACE,
760+
image=IMAGE,
761+
job_poll_interval=JOB_POLL_INTERVAL,
762+
)
763+
op._ssl_ca_cert = SSL_CA_CERT
764+
op._cluster_url = CLUSTER_URL
765+
766+
with mock.patch.object(op, "job") as mock_job:
767+
mock_metadata = mock_job.metadata
768+
mock_metadata.name = TASK_NAME
769+
mock_metadata.namespace = NAMESPACE
770+
with mock.patch.object(op, "defer") as mock_defer:
771+
op.execute_deferrable()
772+
773+
mock_trigger.assert_called_once_with(
774+
cluster_url=CLUSTER_URL,
775+
ssl_ca_cert=SSL_CA_CERT,
776+
job_name=TASK_NAME,
777+
job_namespace=NAMESPACE,
778+
gcp_conn_id="google_cloud_default",
779+
poll_interval=JOB_POLL_INTERVAL,
780+
impersonation_chain=None,
781+
)
782+
mock_defer.assert_called_once_with(
783+
trigger=mock_trigger_instance,
784+
method_name="execute_complete",
785+
)
786+
711787
def test_config_file_throws_error(self):
712788
with pytest.raises(AirflowException):
713789
GKEStartJobOperator(

0 commit comments

Comments
 (0)