Skip to content

Commit

Permalink
fix: catch possible pickling errors (#118)
Browse files Browse the repository at this point in the history
### Summary of Changes

Catch pickling errors in two other locations where they might occur.

---------

Co-authored-by: megalinter-bot <129584137+megalinter-bot@users.noreply.github.com>
  • Loading branch information
lars-reimann and megalinter-bot authored May 29, 2024
1 parent 27cc616 commit 6333b64
Show file tree
Hide file tree
Showing 12 changed files with 466 additions and 78 deletions.
4 changes: 3 additions & 1 deletion src/safeds_runner/cli/_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@ def _get_args() -> argparse.Namespace: # pragma: no cover
return parser.parse_args()


def _add_start_subparser(subparsers: argparse._SubParsersAction) -> None: # pragma: no cover
def _add_start_subparser(
subparsers: argparse._SubParsersAction,
) -> None: # pragma: no cover
parser = subparsers.add_parser(Commands.START, help="start the Safe-DS Runner server")
parser.add_argument("-p", "--port", type=int, default=5000, help="the port to use")
13 changes: 12 additions & 1 deletion src/safeds_runner/memoization/_memoization_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,18 @@ def memoized_function_call(
memoizable_value = _wrap_value_to_shared_memory(computed_value)
if self.max_size is not None:
self.ensure_capacity(_get_size_of_value(memoized_value))
self._map_values[key] = memoizable_value

try:
self._map_values[key] = memoizable_value
# Pickling may raise AttributeError in combination with multiprocessing
except AttributeError as exception:
# Fallback to returning computed value, but inform user about this failure
logging.exception(
"Could not store value for function %s.",
fully_qualified_function_name,
exc_info=exception,
)
return computed_value

self._update_stats_on_miss(
fully_qualified_function_name,
Expand Down
8 changes: 7 additions & 1 deletion src/safeds_runner/memoization/_memoization_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,13 @@ def update_on_hit(self, access_timestamp: int, lookup_time: int) -> None:
self.access_timestamps.append(access_timestamp)
self.lookup_times.append(lookup_time)

def update_on_miss(self, access_timestamp: int, lookup_time: int, computation_time: int, memory_size: int) -> None:
def update_on_miss(
self,
access_timestamp: int,
lookup_time: int,
computation_time: int,
memory_size: int,
) -> None:
"""
Update the memoization stats on a cache miss.
Expand Down
6 changes: 5 additions & 1 deletion src/safeds_runner/memoization/_memoization_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -424,7 +424,11 @@ def _create_memoization_key(
A memoization key, which contains the lists converted to tuples
"""
arguments = [*positional_arguments, *keyword_arguments.values()]
return fully_qualified_function_name, _make_hashable(arguments), _make_hashable(hidden_arguments)
return (
fully_qualified_function_name,
_make_hashable(arguments),
_make_hashable(hidden_arguments),
)


def _wrap_value_to_shared_memory(
Expand Down
5 changes: 4 additions & 1 deletion src/safeds_runner/server/_json_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,9 @@ def default(self, o: Any) -> Any:
}
elif isinstance(o, Image):
# Send images together with their format, by default images are encoded only as PNG
return {"format": "png", "bytes": str(base64.encodebytes(o._repr_png_()), "utf-8")}
return {
"format": "png",
"bytes": str(base64.encodebytes(o._repr_png_()), "utf-8"),
}
else:
return json.JSONEncoder.default(self, o)
39 changes: 32 additions & 7 deletions src/safeds_runner/server/_messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,10 @@ def create_placeholder_value(placeholder_query: QueryMessageData, type_: str, va

message: dict[str, Any] = {"name": placeholder_query.name, "type": type_}
# Start Index >= 0
start_index = max(placeholder_query.window.begin if placeholder_query.window.begin is not None else 0, 0)
start_index = max(
placeholder_query.window.begin if placeholder_query.window.begin is not None else 0,
0,
)
# Length >= 0
length = max(placeholder_query.window.size, 0) if placeholder_query.window.size is not None else None
if isinstance(value, safeds.data.labeled.containers.TabularDataset):
Expand All @@ -244,7 +247,11 @@ def create_placeholder_value(placeholder_query: QueryMessageData, type_: str, va
):
max_index = value.number_of_rows
value = value.slice_rows(start=start_index, length=length)
window_information: dict[str, int] = {"begin": start_index, "size": value.number_of_rows, "max": max_index}
window_information: dict[str, int] = {
"begin": start_index,
"size": value.number_of_rows,
"max": max_index,
}
message["window"] = window_information
message["value"] = value
return message
Expand Down Expand Up @@ -281,7 +288,9 @@ def create_runtime_progress_done() -> str:
return "done"


def parse_validate_message(message: str) -> tuple[Message | None, str | None, str | None]:
def parse_validate_message(
message: str,
) -> tuple[Message | None, str | None, str | None]:
"""
Validate the basic structure of a received message string and return a parsed message object.
Expand All @@ -300,14 +309,30 @@ def parse_validate_message(message: str) -> tuple[Message | None, str | None, st
except json.JSONDecodeError:
return None, f"Invalid message received: {message}", "Invalid Message: not JSON"
if "type" not in message_dict:
return None, f"No message type specified in: {message}", "Invalid Message: no type"
return (
None,
f"No message type specified in: {message}",
"Invalid Message: no type",
)
elif "id" not in message_dict:
return None, f"No message id specified in: {message}", "Invalid Message: no id"
elif "data" not in message_dict:
return None, f"No message data specified in: {message}", "Invalid Message: no data"
return (
None,
f"No message data specified in: {message}",
"Invalid Message: no data",
)
elif not isinstance(message_dict["type"], str):
return None, f"Message type is not a string: {message}", "Invalid Message: invalid type"
return (
None,
f"Message type is not a string: {message}",
"Invalid Message: invalid type",
)
elif not isinstance(message_dict["id"], str):
return None, f"Message id is not a string: {message}", "Invalid Message: invalid id"
return (
None,
f"Message id is not a string: {message}",
"Invalid Message: invalid id",
)
else:
return Message(**message_dict), None, None
19 changes: 17 additions & 2 deletions src/safeds_runner/server/_pipeline_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,10 @@ def _send_message(self, message_type: str, value: dict[Any, Any] | str) -> None:

def _send_exception(self, exception: BaseException) -> None:
backtrace = get_backtrace_info(exception)
self._send_message(message_type_runtime_error, create_runtime_error_description(exception.__str__(), backtrace))
self._send_message(
message_type_runtime_error,
create_runtime_error_description(exception.__str__(), backtrace),
)

def save_placeholder(self, placeholder_name: str, value: Any) -> None:
"""
Expand All @@ -177,7 +180,19 @@ def save_placeholder(self, placeholder_name: str, value: Any) -> None:
and _has_explicit_identity_memory(value)
):
value = ExplicitIdentityWrapper.existing(value)
self._placeholder_map[placeholder_name] = value

try:
self._placeholder_map[placeholder_name] = value
# Pickling may raise AttributeError in combination with multiprocessing
except AttributeError as exception: # pragma: no cover
# Don't crash, but inform user about this failure
logging.exception(
"Could not store value for placeholder %s.",
placeholder_name,
exc_info=exception,
)
return

self._send_message(
message_type_placeholder_type,
create_placeholder_description(placeholder_name, placeholder_type),
Expand Down
8 changes: 6 additions & 2 deletions src/safeds_runner/server/_process_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,11 @@ def _message_queue(self) -> Queue[Message]:

@cached_property
def _message_queue_thread(self) -> threading.Thread:
return threading.Thread(daemon=True, target=self._consume_queue_messages, args=[asyncio.get_event_loop()])
return threading.Thread(
daemon=True,
target=self._consume_queue_messages,
args=[asyncio.get_event_loop()],
)

@cached_property
def _process_pool(self) -> ProcessPoolExecutor:
Expand Down Expand Up @@ -147,7 +151,7 @@ def _warmup_worker() -> None:

from safeds.data.tabular.containers import Table # pragma: no cover

Table({"a": [1]}).get_column("a").plot_histogram() # pragma: no cover
Table({"a": [1]}).get_column("a").plot.histogram() # pragma: no cover


_State: TypeAlias = Literal["initial", "started", "shutdown"]
Expand Down
Loading

0 comments on commit 6333b64

Please sign in to comment.