Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions airflow/providers/apache/beam/operators/beam.py
Original file line number Diff line number Diff line change
Expand Up @@ -468,6 +468,7 @@ def on_kill(self) -> None:
self.dataflow_hook.cancel_job(
job_id=self.dataflow_job_id,
project_id=self.dataflow_config.project_id,
location=self.dataflow_config.location,
)


Expand Down Expand Up @@ -573,6 +574,7 @@ def execute_sync(self, context: Context):
is_running = self.dataflow_hook.is_job_dataflow_running(
name=self.dataflow_config.job_name,
variables=self.pipeline_options,
location=self.dataflow_config.location,
)

if not is_running:
Expand Down Expand Up @@ -656,6 +658,7 @@ def on_kill(self) -> None:
self.dataflow_hook.cancel_job(
job_id=self.dataflow_job_id,
project_id=self.dataflow_config.project_id,
location=self.dataflow_config.location,
)


Expand Down Expand Up @@ -807,6 +810,7 @@ def on_kill(self) -> None:
self.dataflow_hook.cancel_job(
job_id=self.dataflow_job_id,
project_id=self.dataflow_config.project_id,
location=self.dataflow_config.location,
)


Expand Down
19 changes: 14 additions & 5 deletions airflow/providers/google/cloud/hooks/dataflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -1124,18 +1124,17 @@ def build_dataflow_job_name(job_name: str, append_job_name: bool = True) -> str:

return safe_job_name

@_fallback_to_location_from_variables
@_fallback_to_project_id_from_variables
@GoogleBaseHook.fallback_to_default_project_id
def is_job_dataflow_running(
self,
name: str,
project_id: str,
location: str = DEFAULT_DATAFLOW_LOCATION,
location: str | None = None,
variables: dict | None = None,
) -> bool:
"""
Check if jos is still running in dataflow.
Check if job is still running in dataflow.

:param name: The name of the job.
:param project_id: Optional, the Google Cloud project ID in which to start a job.
Expand All @@ -1145,11 +1144,21 @@ def is_job_dataflow_running(
"""
if variables:
warnings.warn(
"The variables parameter has been deprecated. You should pass location using "
"the location parameter.",
"The variables parameter has been deprecated. You should pass project_id using "
"the project_id parameter.",
AirflowProviderDeprecationWarning,
stacklevel=4,
)

if location is None:
location = DEFAULT_DATAFLOW_LOCATION
warnings.warn(
                "The location argument will become mandatory in future versions; "
                f"currently, it defaults to {DEFAULT_DATAFLOW_LOCATION}. Please set the location explicitly.",
AirflowProviderDeprecationWarning,
stacklevel=4,
)

jobs_controller = _DataflowJobsController(
dataflow=self.get_conn(),
project_number=project_id,
Expand Down
2 changes: 2 additions & 0 deletions airflow/providers/google/cloud/operators/dataflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -432,11 +432,13 @@ def set_current_job_id(job_id):
is_running = self.dataflow_hook.is_job_dataflow_running(
name=self.job_name,
variables=pipeline_options,
location=self.location,
)
while is_running and self.check_if_running == CheckJobRunning.WaitForRun:
is_running = self.dataflow_hook.is_job_dataflow_running(
name=self.job_name,
variables=pipeline_options,
location=self.location,
)
if not is_running:
pipeline_options["jobName"] = job_name
Expand Down
20 changes: 15 additions & 5 deletions tests/providers/apache/beam/operators/test_beam.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,9 @@ def test_on_kill_dataflow_runner(self, dataflow_hook_mock, _, __, ___):
op.dataflow_job_id = JOB_ID
op.on_kill()

dataflow_cancel_job.assert_called_once_with(job_id=JOB_ID, project_id=op.dataflow_config.project_id)
dataflow_cancel_job.assert_called_once_with(
job_id=JOB_ID, project_id=op.dataflow_config.project_id, location=op.dataflow_config.location
)

@mock.patch(BEAM_OPERATOR_PATH.format("BeamHook"))
@mock.patch(BEAM_OPERATOR_PATH.format("DataflowHook"))
Expand Down Expand Up @@ -465,7 +467,9 @@ def test_on_kill_dataflow_runner(self, dataflow_hook_mock, _, __, ___):
op.dataflow_job_id = JOB_ID
op.on_kill()

dataflow_cancel_job.assert_called_once_with(job_id=JOB_ID, project_id=op.dataflow_config.project_id)
dataflow_cancel_job.assert_called_once_with(
job_id=JOB_ID, project_id=op.dataflow_config.project_id, location=op.dataflow_config.location
)

@mock.patch(BEAM_OPERATOR_PATH.format("BeamHook"))
@mock.patch(BEAM_OPERATOR_PATH.format("DataflowHook"))
Expand Down Expand Up @@ -859,7 +863,9 @@ def test_on_kill_dataflow_runner(self, dataflow_hook_mock, _, __, ___):
op.dataflow_job_id = JOB_ID
op.on_kill()

dataflow_cancel_job.assert_called_once_with(job_id=JOB_ID, project_id=op.dataflow_config.project_id)
dataflow_cancel_job.assert_called_once_with(
job_id=JOB_ID, project_id=op.dataflow_config.project_id, location=op.dataflow_config.location
)

@mock.patch(BEAM_OPERATOR_PATH.format("BeamHook"))
@mock.patch(BEAM_OPERATOR_PATH.format("DataflowHook"))
Expand Down Expand Up @@ -989,7 +995,9 @@ def test_on_kill_dataflow_runner(self, dataflow_hook_mock, _, __, ___):
op.execute(context=mock.MagicMock())
op.dataflow_job_id = JOB_ID
op.on_kill()
dataflow_cancel_job.assert_called_once_with(job_id=JOB_ID, project_id=op.dataflow_config.project_id)
dataflow_cancel_job.assert_called_once_with(
job_id=JOB_ID, project_id=op.dataflow_config.project_id, location=op.dataflow_config.location
)

@mock.patch(BEAM_OPERATOR_PATH.format("BeamHook"))
@mock.patch(BEAM_OPERATOR_PATH.format("DataflowHook"))
Expand Down Expand Up @@ -1108,7 +1116,9 @@ def test_on_kill_dataflow_runner(self, dataflow_hook_mock, _, __, ___):
op.execute(context=mock.MagicMock())
op.dataflow_job_id = JOB_ID
op.on_kill()
dataflow_cancel_job.assert_called_once_with(job_id=JOB_ID, project_id=op.dataflow_config.project_id)
dataflow_cancel_job.assert_called_once_with(
job_id=JOB_ID, project_id=op.dataflow_config.project_id, location=op.dataflow_config.location
)

@mock.patch(BEAM_OPERATOR_PATH.format("BeamHook"))
@mock.patch(BEAM_OPERATOR_PATH.format("DataflowHook"))
Expand Down
2 changes: 1 addition & 1 deletion tests/providers/google/cloud/operators/test_dataflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,7 +315,7 @@ def test_check_job_running_exec(self, gcs_hook, dataflow_mock, beam_hook_mock):
"output": "gs://test/output",
"labels": {"foo": "bar", "airflow-version": self.expected_airflow_version},
}
dataflow_running.assert_called_once_with(name=JOB_NAME, variables=variables)
dataflow_running.assert_called_once_with(name=JOB_NAME, variables=variables, location=TEST_LOCATION)

@mock.patch(
"airflow.providers.google.cloud.operators.dataflow.process_line_and_extract_dataflow_job_id_callback"
Expand Down