Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/flyte/_bin/runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ async def _run_and_stop():

logger.error(f"Flyte runtime failed for action {name} with run name {run_name}, error: {e}")
err = convert_from_native_to_error(e)
path = await upload_error(err.err, outputs_path, recoverable=err.recoverable)
path = await upload_error(err.err, outputs_path)
logger.error(f"Run {run_name} Action {name} failed with error: {err}. Uploaded error to {path}")
await controller.stop()

Expand Down
22 changes: 19 additions & 3 deletions src/flyte/_internal/runtime/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,15 @@ class Outputs:
@dataclass
class Error:
    """Wrapper around the ``ExecutionError`` proto produced for a failed action.

    Recoverability is intentionally not stored as a separate dataclass field:
    a field named ``recoverable`` would collide with the read-only property
    below and could drift from the value stamped on the proto. It is derived
    from ``err`` on every access instead.
    """

    # The wrapped execution error; its ``recoverability`` value is the single
    # source of truth for the ``recoverable`` property.
    err: execution_pb2.ExecutionError

    @property
    def recoverable(self) -> bool:
        """Return the result of ``_is_execution_error_recoverable`` for ``err``."""
        return _is_execution_error_recoverable(self.err)


def _is_execution_error_recoverable(err: execution_pb2.ExecutionError) -> bool:
    """Tell whether ``err`` is stamped as anything other than NON_RECOVERABLE.

    :param err: The execution error whose ``recoverability`` value is inspected.
    :return: False when ``err.recoverability`` equals
        ``ContainerError.NON_RECOVERABLE``, True otherwise.
    """
    # Producers are expected to stamp RECOVERABLE explicitly; the proto3 zero
    # value is NON_RECOVERABLE, so an unstamped error reads as non-recoverable.
    non_recoverable = execution_pb2.ContainerError.NON_RECOVERABLE
    return err.recoverability != non_recoverable


# ------------------------------- CONVERT Methods ------------------------------- #
Expand Down Expand Up @@ -297,6 +305,10 @@ def convert_error_to_native(
case execution_pb2.ExecutionError.UNKNOWN:
return flyte.errors.RuntimeUnknownError(code=user_code, message=err.message, worker=err.worker)
case execution_pb2.ExecutionError.USER:
if not _is_execution_error_recoverable(err):
exc = flyte.errors.NonRecoverableError(err.message, code=user_code)
exc.worker = err.worker
return exc
if "OOM" in err.code.upper():
return flyte.errors.OOMError(code=user_code, message=err.message, worker=err.worker)
elif "Interrupted" in err.code:
Expand Down Expand Up @@ -327,8 +339,8 @@ def convert_from_native_to_error(err: BaseException) -> Error:
code=err.code,
message=str(err),
worker=err.worker,
),
recoverable=False,
recoverability=execution_pb2.ContainerError.NON_RECOVERABLE,
)
)
elif isinstance(err, flyte.errors.RuntimeUnknownError):
return Error(
Expand All @@ -337,6 +349,7 @@ def convert_from_native_to_error(err: BaseException) -> Error:
code=err.code,
message=str(err),
worker=err.worker,
recoverability=execution_pb2.ContainerError.RECOVERABLE,
)
)
elif isinstance(err, flyte.errors.RuntimeUserError):
Expand All @@ -346,6 +359,7 @@ def convert_from_native_to_error(err: BaseException) -> Error:
code=err.code,
message=str(err),
worker=err.worker,
recoverability=execution_pb2.ContainerError.RECOVERABLE,
)
)
elif isinstance(err, flyte.errors.RuntimeSystemError):
Expand All @@ -355,6 +369,7 @@ def convert_from_native_to_error(err: BaseException) -> Error:
code=err.code,
message=str(err),
worker=err.worker,
recoverability=execution_pb2.ContainerError.RECOVERABLE,
)
)
else:
Expand All @@ -364,6 +379,7 @@ def convert_from_native_to_error(err: BaseException) -> Error:
code=type(err).__name__,
message=str(err),
worker="UNKNOWN",
recoverability=execution_pb2.ContainerError.RECOVERABLE,
)
)

Expand Down
9 changes: 4 additions & 5 deletions src/flyte/_internal/runtime/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,19 +72,16 @@ async def upload_outputs(outputs: Outputs, output_path: str, max_bytes: int = -1
logger.debug(f"Uploaded {output_uri} to {output_path}")


async def upload_error(err: execution_pb2.ExecutionError, output_prefix: str, recoverable: bool = True) -> str:
async def upload_error(err: execution_pb2.ExecutionError, output_prefix: str) -> str:
"""
:param err: execution_pb2.ExecutionError
:param output_prefix: The output prefix of the remote uri.
:param recoverable: If False, sets ContainerError.kind to NON_RECOVERABLE so the engine skips retries.
"""
error_document = execution_pb2.ErrorDocument(
error=execution_pb2.ContainerError(
code=err.code,
message=err.message,
kind=execution_pb2.ContainerError.RECOVERABLE
if recoverable
else execution_pb2.ContainerError.NON_RECOVERABLE,
kind=err.recoverability,
origin=err.kind,
)
)
Expand Down Expand Up @@ -176,11 +173,13 @@ async def load_error(path: str) -> execution_pb2.ExecutionError:
message=err.error.message,
kind=err.error.origin,
error_uri=path,
recoverability=err.error.kind,
)

return execution_pb2.ExecutionError(
code="Unknown",
message=f"Received unloadable error from path {path}",
kind=execution_pb2.ExecutionError.SYSTEM,
error_uri=path,
recoverability=execution_pb2.ContainerError.RECOVERABLE,
)
2 changes: 1 addition & 1 deletion src/flyte/_internal/runtime/taskrunner.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,7 @@ async def extract_download_run_upload(
)
logger.debug(f"Task {action.name} completed at {t}, with outputs: {outputs}")
if err is not None:
path = await upload_error(err.err, output_path, recoverable=err.recoverable)
path = await upload_error(err.err, output_path)
logger.error(f"Task {task.name} failed with error: {err}. Uploaded error to {path}")
return
if outputs is None:
Expand Down
126 changes: 126 additions & 0 deletions tests/flyte/internal/runtime/test_convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import pytest
import pytest_asyncio
from flyteidl2.core import execution_pb2
from flyteidl2.core.interface_pb2 import TypedInterface, Variable, VariableEntry, VariableMap
from flyteidl2.core.literals_pb2 import (
Literal,
Expand All @@ -27,6 +28,8 @@
from flyteidl2.task import common_pb2 as run_definition_pb2

import flyte._internal.runtime.convert as convert
import flyte.errors
from flyte._internal.runtime import io as runtime_io
from flyte._internal.runtime.convert import Inputs, generate_sub_action_id_and_output_path
from flyte._internal.runtime.types_serde import transform_native_to_typed_interface
from flyte.models import ActionID, NativeInterface, RawDataPath, TaskContext
Expand All @@ -42,6 +45,129 @@
]


def test_convert_from_native_to_error_stamps_recoverability():
    """Converting a native exception stamps the matching recoverability on the proto."""
    # A plain runtime user error is converted as RECOVERABLE.
    retryable_exc = flyte.errors.RuntimeUserError("ValueError", "retry me")
    retryable = convert.convert_from_native_to_error(retryable_exc)
    assert retryable.recoverable is True
    assert retryable.err.recoverability == execution_pb2.ContainerError.RECOVERABLE

    # A NonRecoverableError is converted as NON_RECOVERABLE.
    terminal_exc = flyte.errors.NonRecoverableError("do not retry")
    terminal = convert.convert_from_native_to_error(terminal_exc)
    assert terminal.recoverable is False
    assert terminal.err.recoverability == execution_pb2.ContainerError.NON_RECOVERABLE


def test_convert_error_to_native_uses_non_recoverable_error_for_non_recoverable_user_errors():
    """A USER error stamped NON_RECOVERABLE becomes a NonRecoverableError and round-trips."""
    proto_err = execution_pb2.ExecutionError(
        kind=execution_pb2.ExecutionError.USER,
        code="ValueError",
        message="bad input",
        recoverability=execution_pb2.ContainerError.NON_RECOVERABLE,
    )

    native_exc = convert.convert_error_to_native(proto_err)

    assert isinstance(native_exc, flyte.errors.NonRecoverableError)
    assert native_exc.code == "ValueError"
    # The native exception itself carries no ``recoverable`` attribute.
    assert not hasattr(native_exc, "recoverable")

    # Converting back must yield a non-recoverable error again.
    converted_back = convert.convert_from_native_to_error(native_exc)
    assert converted_back.recoverable is False
    assert converted_back.err.recoverability == execution_pb2.ContainerError.NON_RECOVERABLE


def test_convert_error_to_native_preserves_recoverable_signal():
    """A SYSTEM error stamped RECOVERABLE stays recoverable across a native round trip."""
    proto_err = execution_pb2.ExecutionError(
        kind=execution_pb2.ExecutionError.SYSTEM,
        code="TransientSystemError",
        message="retry later",
        recoverability=execution_pb2.ContainerError.RECOVERABLE,
    )

    native_exc = convert.convert_error_to_native(proto_err)

    assert isinstance(native_exc, flyte.errors.RuntimeSystemError)
    # The native exception itself carries no ``recoverable`` attribute.
    assert not hasattr(native_exc, "recoverable")

    # Converting back must yield a recoverable error again.
    converted_back = convert.convert_from_native_to_error(native_exc)
    assert converted_back.recoverable is True
    assert converted_back.err.recoverability == execution_pb2.ContainerError.RECOVERABLE


async def _assert_error_round_trip(tmp_path, child_exc, expected_recoverability, expected_type):
    """Run the child-to-parent error round trip and check that nothing is lost.

    The child exception is converted to an ``Error``, uploaded under
    ``tmp_path``, loaded back from the returned URI and converted to a native
    exception again, mirroring the calls the individual tests used to repeat.

    :param tmp_path: Directory used as the output prefix for the upload.
    :param child_exc: Native exception raised on the child side.
    :param expected_recoverability: ``ContainerError`` kind expected on the loaded error.
    :param expected_type: Native exception class expected on the parent side.
    """
    child_err = convert.convert_from_native_to_error(child_exc)
    error_uri = await runtime_io.upload_error(child_err.err, str(tmp_path))
    downloaded_err = await runtime_io.load_error(error_uri)
    parent_exc = convert.convert_error_to_native(downloaded_err)

    # The error document lands at the conventional error path for the prefix.
    assert error_uri == runtime_io.error_path(str(tmp_path))
    # The loaded proto matches what the child produced.
    assert downloaded_err.kind == child_err.err.kind
    assert downloaded_err.code == child_exc.code
    assert downloaded_err.message == str(child_exc)
    assert downloaded_err.recoverability == expected_recoverability
    # The parent sees the same native exception type, code and message.
    assert isinstance(parent_exc, expected_type)
    assert parent_exc.code == child_exc.code
    assert str(parent_exc) == str(child_exc)


@pytest.mark.asyncio
async def test_child_error_upload_download_round_trips_to_parent_native_error(tmp_path):
    """A NonRecoverableError round-trips as NON_RECOVERABLE."""
    await _assert_error_round_trip(
        tmp_path,
        flyte.errors.NonRecoverableError("invalid customer id", code="InvalidCustomerID"),
        execution_pb2.ContainerError.NON_RECOVERABLE,
        flyte.errors.NonRecoverableError,
    )


@pytest.mark.asyncio
async def test_child_runtime_unknown_error_upload_download_round_trips_to_parent_native_error(tmp_path):
    """A RuntimeUnknownError round-trips as RECOVERABLE."""
    await _assert_error_round_trip(
        tmp_path,
        flyte.errors.RuntimeUnknownError("MysteryFailure", "something unexpected happened"),
        execution_pb2.ContainerError.RECOVERABLE,
        flyte.errors.RuntimeUnknownError,
    )


@pytest.mark.asyncio
async def test_child_runtime_user_error_upload_download_round_trips_to_parent_native_error(tmp_path):
    """An OOMError round-trips as RECOVERABLE."""
    await _assert_error_round_trip(
        tmp_path,
        flyte.errors.OOMError("OOMKilled", "container was killed after using too much memory"),
        execution_pb2.ContainerError.RECOVERABLE,
        flyte.errors.OOMError,
    )


@pytest.mark.asyncio
async def test_child_runtime_system_error_upload_download_round_trips_to_parent_native_error(tmp_path):
    """A RuntimeSystemError round-trips as RECOVERABLE."""
    await _assert_error_round_trip(
        tmp_path,
        flyte.errors.RuntimeSystemError("StorageUnavailable", "metadata store is temporarily unavailable"),
        execution_pb2.ContainerError.RECOVERABLE,
        flyte.errors.RuntimeSystemError,
    )


@pytest_asyncio.fixture(params=test_cases)
async def generate_inputs(request) -> Tuple[Inputs, str]:
if request.param[0] is None:
Expand Down
Loading