Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -441,17 +441,17 @@ def execute_on_dataflow(self, context: Context):
"""Execute the Apache Beam Pipeline on Dataflow runner."""
if not self.dataflow_hook:
self.dataflow_hook = self.__set_dataflow_hook()
with self.dataflow_hook.provide_authorized_gcloud():
self.beam_hook.start_python_pipeline(
variables=self.snake_case_pipeline_options,
py_file=self.py_file,
py_options=self.py_options,
py_interpreter=self.py_interpreter,
py_requirements=self.py_requirements,
py_system_site_packages=self.py_system_site_packages,
process_line_callback=self.process_line_callback,
is_dataflow_job_id_exist_callback=self.is_dataflow_job_id_exist_callback,
)

self.beam_hook.start_python_pipeline(
variables=self.snake_case_pipeline_options,
py_file=self.py_file,
py_options=self.py_options,
py_interpreter=self.py_interpreter,
py_requirements=self.py_requirements,
py_system_site_packages=self.py_system_site_packages,
process_line_callback=self.process_line_callback,
is_dataflow_job_id_exist_callback=self.is_dataflow_job_id_exist_callback,
)

location = self.dataflow_config.location or DEFAULT_DATAFLOW_LOCATION
DataflowJobLink.persist(context=context, region=location)
Expand Down Expand Up @@ -623,14 +623,13 @@ def execute_on_dataflow(self, context: Context):

if not is_running:
self.pipeline_options["jobName"] = self.dataflow_job_name
with self.dataflow_hook.provide_authorized_gcloud():
self.beam_hook.start_java_pipeline(
variables=self.pipeline_options,
jar=self.jar,
job_class=self.job_class,
process_line_callback=self.process_line_callback,
is_dataflow_job_id_exist_callback=self.is_dataflow_job_id_exist_callback,
)
self.beam_hook.start_java_pipeline(
variables=self.pipeline_options,
jar=self.jar,
job_class=self.job_class,
process_line_callback=self.process_line_callback,
is_dataflow_job_id_exist_callback=self.is_dataflow_job_id_exist_callback,
)
if self.dataflow_job_name and self.dataflow_config.location:
DataflowJobLink.persist(context=context)
if self.deferrable:
Expand Down Expand Up @@ -790,12 +789,11 @@ def execute(self, context: Context):
go_artifact.download_from_gcs(gcs_hook=gcs_hook, tmp_dir=tmp_dir)

if is_dataflow and self.dataflow_hook:
with self.dataflow_hook.provide_authorized_gcloud():
go_artifact.start_pipeline(
beam_hook=self.beam_hook,
variables=snake_case_pipeline_options,
process_line_callback=process_line_callback,
)
go_artifact.start_pipeline(
beam_hook=self.beam_hook,
variables=snake_case_pipeline_options,
process_line_callback=process_line_callback,
)
DataflowJobLink.persist(context=context)
if dataflow_job_name and self.dataflow_config.location:
self.dataflow_hook.wait_for_done(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,6 @@ def test_exec_dataflow_runner(
process_line_callback=mock.ANY,
is_dataflow_job_id_exist_callback=mock.ANY,
)
dataflow_hook_mock.return_value.provide_authorized_gcloud.assert_called_once_with()

@mock.patch(BEAM_OPERATOR_PATH.format("DataflowJobLink.persist"))
@mock.patch(BEAM_OPERATOR_PATH.format("BeamHook"))
Expand Down Expand Up @@ -760,7 +759,6 @@ def test_exec_dataflow_runner_with_go_file(
multiple_jobs=False,
project_id=dataflow_config.project_id,
)
dataflow_hook_mock.return_value.provide_authorized_gcloud.assert_called_once_with()

@mock.patch(BEAM_OPERATOR_PATH.format("DataflowJobLink.persist"))
@mock.patch(BEAM_OPERATOR_PATH.format("DataflowHook"))
Expand Down Expand Up @@ -789,8 +787,6 @@ def gcs_download_side_effect(bucket_name: str, object_name: str, filename: str)
gcs_download_method.side_effect = gcs_download_side_effect

mock_dataflow_hook.build_dataflow_job_name.return_value = "test-job"

provide_authorized_gcloud_method = mock_dataflow_hook.return_value.provide_authorized_gcloud
start_go_pipeline_method = mock_beam_hook.return_value.start_go_pipeline_with_binary
wait_for_done_method = mock_dataflow_hook.return_value.wait_for_done

Expand Down Expand Up @@ -835,7 +831,6 @@ def gcs_download_side_effect(bucket_name: str, object_name: str, filename: str)
cancel_timeout=dataflow_config.cancel_timeout,
wait_until_finished=dataflow_config.wait_until_finished,
)
provide_authorized_gcloud_method.assert_called_once_with()
start_go_pipeline_method.assert_called_once_with(
variables=expected_options,
launcher_binary=expected_launcher_binary,
Expand Down Expand Up @@ -971,7 +966,6 @@ def test_exec_dataflow_runner(self, gcs_hook_mock, dataflow_hook_mock, beam_hook
wait_until_finished=dataflow_config.wait_until_finished,
)
beam_hook_mock.return_value.start_python_pipeline.assert_called_once()
dataflow_hook_mock.return_value.provide_authorized_gcloud.assert_called_once_with()

@mock.patch(BEAM_OPERATOR_PATH.format("DataflowJobLink.persist"))
@mock.patch(BEAM_OPERATOR_PATH.format("BeamHook"))
Expand Down Expand Up @@ -1100,7 +1094,6 @@ def test_exec_dataflow_runner(self, gcs_hook_mock, dataflow_hook_mock, beam_hook
wait_until_finished=dataflow_config.wait_until_finished,
)
beam_hook_mock.return_value.start_python_pipeline.assert_not_called()
dataflow_hook_mock.return_value.provide_authorized_gcloud.assert_called_once_with()

@mock.patch(BEAM_OPERATOR_PATH.format("DataflowJobLink.persist"))
@mock.patch(BEAM_OPERATOR_PATH.format("BeamHook"))
Expand Down