Skip to content

Commit 43efde6

Browse files
authored
Fix MyPy Errors for Apache Beam (and Dataflow) provider. (#20301)
1 parent 5712e2b commit 43efde6

File tree

3 files changed

+19
-12
lines changed

3 files changed

+19
-12
lines changed

β€Žairflow/providers/apache/beam/operators/beam.py

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,9 @@ class BeamDataflowMixin(metaclass=ABCMeta):
4141
"""
4242

4343
dataflow_hook: Optional[DataflowHook]
44-
dataflow_config: Optional[DataflowConfiguration]
44+
dataflow_config: DataflowConfiguration
45+
gcp_conn_id: str
46+
delegate_to: Optional[str]
4547

4648
def _set_dataflow(
4749
self, pipeline_options: dict, job_name_variable_key: Optional[str] = None
@@ -198,11 +200,17 @@ def __init__(
198200
self.py_system_site_packages = py_system_site_packages
199201
self.gcp_conn_id = gcp_conn_id
200202
self.delegate_to = delegate_to
201-
self.dataflow_config = dataflow_config or {}
202203
self.beam_hook: Optional[BeamHook] = None
203204
self.dataflow_hook: Optional[DataflowHook] = None
204205
self.dataflow_job_id: Optional[str] = None
205206

207+
if dataflow_config is None:
208+
self.dataflow_config = DataflowConfiguration()
209+
elif isinstance(dataflow_config, dict):
210+
self.dataflow_config = DataflowConfiguration(**dataflow_config)
211+
else:
212+
self.dataflow_config = dataflow_config
213+
206214
if self.dataflow_config and self.runner.lower() != BeamRunnerType.DataflowRunner.lower():
207215
self.log.warning(
208216
"dataflow_config is defined but runner is different than DataflowRunner (%s)", self.runner
@@ -216,9 +224,6 @@ def execute(self, context):
216224
is_dataflow = self.runner.lower() == BeamRunnerType.DataflowRunner.lower()
217225
dataflow_job_name: Optional[str] = None
218226

219-
if isinstance(self.dataflow_config, dict):
220-
self.dataflow_config = DataflowConfiguration(**self.dataflow_config)
221-
222227
if is_dataflow:
223228
dataflow_job_name, pipeline_options, process_line_callback = self._set_dataflow(
224229
pipeline_options=pipeline_options, job_name_variable_key="job_name"
@@ -366,14 +371,20 @@ def __init__(
366371
self.default_pipeline_options = default_pipeline_options or {}
367372
self.pipeline_options = pipeline_options or {}
368373
self.job_class = job_class
369-
self.dataflow_config = dataflow_config or {}
370374
self.gcp_conn_id = gcp_conn_id
371375
self.delegate_to = delegate_to
372376
self.dataflow_job_id = None
373377
self.dataflow_hook: Optional[DataflowHook] = None
374378
self.beam_hook: Optional[BeamHook] = None
375379
self._dataflow_job_name: Optional[str] = None
376380

381+
if dataflow_config is None:
382+
self.dataflow_config = DataflowConfiguration()
383+
elif isinstance(dataflow_config, dict):
384+
self.dataflow_config = DataflowConfiguration(**dataflow_config)
385+
else:
386+
self.dataflow_config = dataflow_config
387+
377388
if self.dataflow_config and self.runner.lower() != BeamRunnerType.DataflowRunner.lower():
378389
self.log.warning(
379390
"dataflow_config is defined but runner is different than DataflowRunner (%s)", self.runner
@@ -387,9 +398,6 @@ def execute(self, context):
387398
is_dataflow = self.runner.lower() == BeamRunnerType.DataflowRunner.lower()
388399
dataflow_job_name: Optional[str] = None
389400

390-
if isinstance(self.dataflow_config, dict):
391-
self.dataflow_config = DataflowConfiguration(**self.dataflow_config)
392-
393401
if is_dataflow:
394402
dataflow_job_name, pipeline_options, process_line_callback = self._set_dataflow(
395403
pipeline_options=pipeline_options, job_name_variable_key=None

β€Žairflow/providers/google/cloud/hooks/dataflow.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -212,7 +212,6 @@ def __init__(
212212
self._jobs: Optional[List[dict]] = None
213213
self.drain_pipeline = drain_pipeline
214214
self._wait_until_finished = wait_until_finished
215-
self._jobs: Optional[List[dict]] = None
216215

217216
def is_job_running(self) -> bool:
218217
"""
@@ -1064,7 +1063,7 @@ def start_sql_job(
10641063
DeprecationWarning,
10651064
stacklevel=3,
10661065
)
1067-
on_new_job_id_callback(job.get("id"))
1066+
on_new_job_id_callback(job["id"])
10681067

10691068
if on_new_job_callback:
10701069
on_new_job_callback(job)

β€Žairflow/providers/google/cloud/operators/dataflow.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,7 @@ class DataflowConfiguration:
140140
def __init__(
141141
self,
142142
*,
143-
job_name: Optional[str] = "{{task.task_id}}",
143+
job_name: str = "{{task.task_id}}",
144144
append_job_name: bool = True,
145145
project_id: Optional[str] = None,
146146
location: Optional[str] = DEFAULT_DATAFLOW_LOCATION,

0 commit comments

Comments
 (0)