Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -576,7 +576,26 @@ def get_or_create_pod(self, pod_request_obj: k8s.V1Pod, context: Context) -> k8s
if self.reattach_on_restart:
pod = self.find_pod(pod_request_obj.metadata.namespace, context=context)
if pod:
return pod
# If pod is terminated then delete the pod an create a new as not possible to get xcom
pod_phase = (
pod.status.phase if hasattr(pod, "status") and hasattr(pod.status, "phase") else None
)
if pod_phase and pod_phase not in (PodPhase.SUCCEEDED, PodPhase.FAILED):
return pod

self.log.info(
"Found terminated old matching pod %s with labels %s",
pod.metadata.name,
pod.metadata.labels,
)

# if not required to delete the pod then keep old logic and not automatically create new pod
deleted_pod = self.process_pod_deletion(pod)
if not deleted_pod:
return pod

self.log.info("Deleted pod to handle rerun and create new pod!")

self.log.debug("Starting pod:\n%s", yaml.safe_dump(pod_request_obj.to_dict()))
self.pod_manager.create_pod(pod=pod_request_obj)
return pod_request_obj
Expand Down Expand Up @@ -1067,7 +1086,7 @@ def kill_istio_sidecar(self, pod: V1Pod) -> None:
if self.KILL_ISTIO_PROXY_SUCCESS_MSG not in output_str:
raise AirflowException("Error while deleting istio-proxy sidecar: %s", output_str)

def process_pod_deletion(self, pod: k8s.V1Pod, *, reraise=True):
def process_pod_deletion(self, pod: k8s.V1Pod, *, reraise=True) -> bool:
with _optionally_suppress(reraise=reraise):
if pod is not None:
should_delete_pod = (self.on_finish_action == OnFinishAction.DELETE_POD) or (
Expand All @@ -1080,8 +1099,10 @@ def process_pod_deletion(self, pod: k8s.V1Pod, *, reraise=True):
if should_delete_pod:
self.log.info("Deleting pod: %s", pod.metadata.name)
self.pod_manager.delete_pod(pod)
else:
self.log.info("Skipping deleting pod: %s", pod.metadata.name)
return True
self.log.info("Skipping deleting pod: %s", pod.metadata.name)

return False

def _build_find_pod_label_selector(self, context: Context | None = None, *, exclude_checked=True) -> str:
labels = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -648,6 +648,50 @@ def test_omitted_namespace_no_conn_not_in_k8s(self, mock_find, mock_path):
)
mock_find.assert_called_once_with("default", context=context)

@pytest.mark.parametrize(
"pod_phase",
[
PodPhase.SUCCEEDED,
PodPhase.FAILED,
PodPhase.RUNNING,
],
)
@patch(f"{KPO_MODULE}.PodManager.create_pod")
@patch(f"{KPO_MODULE}.KubernetesPodOperator.process_pod_deletion")
@patch(f"{KPO_MODULE}.KubernetesPodOperator.find_pod")
def test_get_or_create_pod_reattach_terminated(
self, mock_find, mock_process_pod_deletion, mock_create_pod, pod_phase
):
"""Check that get_or_create_pod reattaches to existing pod."""
k = KubernetesPodOperator(
image="ubuntu:16.04",
cmds=["bash", "-cx"],
arguments=["echo 10"],
task_id="task",
name="hello",
log_pod_spec_on_failure=False,
)
k.reattach_on_restart = True
context = create_context(k)
mock_pod_request_obj = MagicMock()
mock_pod_request_obj.to_dict.return_value = {"metadata": {"name": "test-pod"}}

mock_found_pod = MagicMock()
mock_found_pod.status.phase = pod_phase
mock_find.return_value = mock_found_pod
result = k.get_or_create_pod(
pod_request_obj=mock_pod_request_obj,
context=context,
)
if pod_phase == PodPhase.RUNNING:
mock_create_pod.assert_not_called()
mock_process_pod_deletion.assert_not_called()
assert result == mock_found_pod
else:
mock_process_pod_deletion.assert_called_once_with(mock_found_pod)
mock_create_pod.assert_called_once_with(pod=mock_pod_request_obj)
assert result == mock_pod_request_obj

def test_xcom_sidecar_container_image_custom(self):
image = "private.repo/alpine:3.13"
with temp_override_attr(PodDefaults.SIDECAR_CONTAINER, "image", image):
Expand Down