
Commit a3ffbee

yathitokisan and Kyaw authored
Fix skipping non-GCS located jars (#22302)
* Fix #21989 indentation. A test is added to confirm the job is executed on Dataflow with a local jar file.

Co-authored-by: Kyaw <kyawtuns@gmail.com>
1 parent 43dfec3 commit a3ffbee
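
For orientation, here is a minimal sketch of the kind of DAG this fix is about: a DataflowCreateJavaJobOperator whose jar argument points at a file on the worker's local filesystem instead of a gs:// object. The DAG id, jar path, job class, region, and staging bucket below are illustrative placeholders, not taken from the commit; before this change such a task was silently skipped.

# Illustrative sketch only: DataflowCreateJavaJobOperator with a local jar path.
# All identifiers (DAG id, jar path, job class, region, bucket) are hypothetical.
import datetime

from airflow import DAG
from airflow.providers.google.cloud.operators.dataflow import (
    CheckJobRunning,
    DataflowCreateJavaJobOperator,
)

with DAG(
    dag_id="example_dataflow_local_jar",
    start_date=datetime.datetime(2022, 3, 1),
    schedule_interval=None,
    catchup=False,
) as dag:
    start_java_job = DataflowCreateJavaJobOperator(
        task_id="start-java-job",
        jar="/opt/airflow/jars/wordcount.jar",  # local path, no gs:// prefix
        job_name="example-local-jar-job",
        job_class="org.example.WordCount",
        location="europe-west3",
        check_if_running=CheckJobRunning.IgnoreJob,
        dataflow_default_options={"tempLocation": "gs://my-bucket/tmp/"},
    )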

2 files changed: 85 additions, 22 deletions


airflow/providers/google/cloud/operators/dataflow.py

Lines changed: 22 additions & 22 deletions
@@ -410,33 +410,33 @@ def set_current_job_id(job_id):
                 tmp_gcs_file = exit_stack.enter_context(gcs_hook.provide_file(object_url=self.jar))
                 self.jar = tmp_gcs_file.name

-                is_running = False
-                if self.check_if_running != CheckJobRunning.IgnoreJob:
+            is_running = False
+            if self.check_if_running != CheckJobRunning.IgnoreJob:
+                is_running = self.dataflow_hook.is_job_dataflow_running(
+                    name=self.job_name,
+                    variables=pipeline_options,
+                )
+                while is_running and self.check_if_running == CheckJobRunning.WaitForRun:
                     is_running = self.dataflow_hook.is_job_dataflow_running(
                         name=self.job_name,
                         variables=pipeline_options,
                     )
-                    while is_running and self.check_if_running == CheckJobRunning.WaitForRun:
-                        is_running = self.dataflow_hook.is_job_dataflow_running(
-                            name=self.job_name,
-                            variables=pipeline_options,
-                        )
-                if not is_running:
-                    pipeline_options["jobName"] = job_name
-                    with self.dataflow_hook.provide_authorized_gcloud():
-                        self.beam_hook.start_java_pipeline(
-                            variables=pipeline_options,
-                            jar=self.jar,
-                            job_class=self.job_class,
-                            process_line_callback=process_line_callback,
-                        )
-                    self.dataflow_hook.wait_for_done(
-                        job_name=job_name,
-                        location=self.location,
-                        job_id=self.job_id,
-                        multiple_jobs=self.multiple_jobs,
+            if not is_running:
+                pipeline_options["jobName"] = job_name
+                with self.dataflow_hook.provide_authorized_gcloud():
+                    self.beam_hook.start_java_pipeline(
+                        variables=pipeline_options,
+                        jar=self.jar,
+                        job_class=self.job_class,
+                        process_line_callback=process_line_callback,
                     )
+                self.dataflow_hook.wait_for_done(
+                    job_name=job_name,
+                    location=self.location,
+                    job_id=self.job_id,
+                    multiple_jobs=self.multiple_jobs,
+                )

         return {"job_id": self.job_id}
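
The change above is purely an indentation fix: the job-status check and the job submission move out of the branch that only executes for gs:// jars. As a self-contained toy (not Airflow code, all names invented here), the control-flow difference looks like this:

# Toy illustration of the bug and the fix (not Airflow code): when the submit step
# is nested inside the gs:// staging branch, local jar paths fall straight through
# and nothing is submitted; dedenting the step makes it run for every jar.
def submit_jar(jar: str, fixed: bool) -> str:
    actions = []
    if jar.lower().startswith("gs://"):
        actions.append("stage jar from GCS")
        if not fixed:
            actions.append("submit Dataflow job")  # old layout: reachable only for gs:// jars
    if fixed:
        actions.append("submit Dataflow job")  # fixed layout: runs for gs:// and local jars alike
    return ", ".join(actions) if actions else "nothing happened"


if __name__ == "__main__":
    for path in ("gs://my-bucket/example/test.jar", "/mnt/dev/example/test.jar"):
        print(path)
        print("  before the fix:", submit_jar(path, fixed=False))
        print("  after the fix: ", submit_jar(path, fixed=True))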

tests/providers/google/cloud/operators/test_dataflow.py

Lines changed: 63 additions & 0 deletions
@@ -43,6 +43,7 @@
 PY_FILE = 'gs://my-bucket/my-object.py'
 PY_INTERPRETER = 'python3'
 JAR_FILE = 'gs://my-bucket/example/test.jar'
+LOCAL_JAR_FILE = '/mnt/dev/example/test.jar'
 JOB_CLASS = 'com.test.NotMain'
 PY_OPTIONS = ['-m']
 DEFAULT_OPTIONS_PYTHON = DEFAULT_OPTIONS_JAVA = {
@@ -380,6 +381,68 @@ def set_is_job_dataflow_running_variables(*args, **kwargs):
         )


+class TestDataflowJavaOperatorWithLocal(unittest.TestCase):
+    def setUp(self):
+        self.dataflow = DataflowCreateJavaJobOperator(
+            task_id=TASK_ID,
+            jar=LOCAL_JAR_FILE,
+            job_name=JOB_NAME,
+            job_class=JOB_CLASS,
+            dataflow_default_options=DEFAULT_OPTIONS_JAVA,
+            options=ADDITIONAL_OPTIONS,
+            poll_sleep=POLL_SLEEP,
+            location=TEST_LOCATION,
+        )
+        self.expected_airflow_version = 'v' + airflow.version.version.replace(".", "-").replace("+", "-")
+
+    def test_init(self):
+        """Test DataflowTemplateOperator instance is properly initialized."""
+        assert self.dataflow.jar == LOCAL_JAR_FILE
+
+    @mock.patch('airflow.providers.google.cloud.operators.dataflow.BeamHook')
+    @mock.patch('airflow.providers.google.cloud.operators.dataflow.DataflowHook')
+    def test_check_job_not_running_exec(self, dataflow_hook_mock, beam_hook_mock):
+        """Test DataflowHook is created and the right args are passed to
+        start_java_workflow with option to check if job is running
+        """
+        is_job_dataflow_running_variables = None
+
+        def set_is_job_dataflow_running_variables(*args, **kwargs):
+            nonlocal is_job_dataflow_running_variables
+            is_job_dataflow_running_variables = copy.deepcopy(kwargs.get("variables"))
+
+        dataflow_running = dataflow_hook_mock.return_value.is_job_dataflow_running
+        dataflow_running.side_effect = set_is_job_dataflow_running_variables
+        dataflow_running.return_value = False
+        start_java_mock = beam_hook_mock.return_value.start_java_pipeline
+        self.dataflow.check_if_running = True
+
+        self.dataflow.execute(None)
+        expected_variables = {
+            'project': dataflow_hook_mock.return_value.project_id,
+            'stagingLocation': 'gs://test/staging',
+            'jobName': JOB_NAME,
+            'region': TEST_LOCATION,
+            'output': 'gs://test/output',
+            'labels': {'foo': 'bar', 'airflow-version': self.expected_airflow_version},
+        }
+        self.assertEqual(expected_variables, is_job_dataflow_running_variables)
+        job_name = dataflow_hook_mock.return_value.build_dataflow_job_name.return_value
+        expected_variables["jobName"] = job_name
+        start_java_mock.assert_called_once_with(
+            variables=expected_variables,
+            jar=LOCAL_JAR_FILE,
+            job_class=JOB_CLASS,
+            process_line_callback=mock.ANY,
+        )
+        dataflow_hook_mock.return_value.wait_for_done.assert_called_once_with(
+            job_id=mock.ANY,
+            job_name=job_name,
+            location=TEST_LOCATION,
+            multiple_jobs=False,
+        )
+
+
 class TestDataflowTemplateOperator(unittest.TestCase):
     def setUp(self):
         self.dataflow = DataflowTemplatedJobStartOperator(
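
To exercise only the new test case outside the project's usual pytest invocation, a minimal stdlib-unittest runner could look like the sketch below; it assumes it is executed from an apache/airflow source checkout so that the tests package touched by this commit is importable.

# Minimal sketch: run just the new TestDataflowJavaOperatorWithLocal case with the
# standard-library unittest runner. Assumes an apache/airflow source checkout on
# sys.path; inside the Airflow repo itself, pytest is the usual entry point.
import unittest

from tests.providers.google.cloud.operators.test_dataflow import (
    TestDataflowJavaOperatorWithLocal,
)

if __name__ == "__main__":
    suite = unittest.TestLoader().loadTestsFromTestCase(TestDataflowJavaOperatorWithLocal)
    unittest.TextTestRunner(verbosity=2).run(suite)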
