Skip to content

Commit 516d006

Browse files
committed
Don't pickle proto messages before wrapping as google.protobuf.Any
1 parent bb6ec79 commit 516d006

File tree

1 file changed

+13
-27
lines changed

1 file changed

+13
-27
lines changed

src/dispatch/proto.py

Lines changed: 13 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ def __init__(self, req: function_pb.RunRequest):
7878

7979
self._has_input = req.HasField("input")
8080
if self._has_input:
81-
self._input = _pb_any_unpack(req.input)
81+
self._input = _any_unpickle(req.input)
8282
else:
8383
if req.poll_result.coroutine_state:
8484
raise IncompatibleStateError # coroutine_state is deprecated
@@ -141,7 +141,7 @@ def from_input_arguments(cls, function: str, *args, **kwargs):
141141
return Input(
142142
req=function_pb.RunRequest(
143143
function=function,
144-
input=_pb_any_pickle(input),
144+
input=_any_pickle(input),
145145
)
146146
)
147147

@@ -157,7 +157,7 @@ def from_poll_results(
157157
req=function_pb.RunRequest(
158158
function=function,
159159
poll_result=poll_pb.PollResult(
160-
typed_coroutine_state=_pb_any_pickle(coroutine_state),
160+
typed_coroutine_state=_any_pickle(coroutine_state),
161161
results=[result._as_proto() for result in call_results],
162162
error=error._as_proto() if error else None,
163163
),
@@ -241,7 +241,7 @@ def poll(
241241
else None
242242
)
243243
poll = poll_pb.Poll(
244-
typed_coroutine_state=_pb_any_pickle(coroutine_state),
244+
typed_coroutine_state=_any_pickle(coroutine_state),
245245
min_results=min_results,
246246
max_results=max_results,
247247
max_wait=max_wait,
@@ -279,7 +279,7 @@ class Call:
279279
correlation_id: Optional[int] = None
280280

281281
def _as_proto(self) -> call_pb.Call:
282-
input_bytes = _pb_any_pickle(self.input)
282+
input_bytes = _any_pickle(self.input)
283283
return call_pb.Call(
284284
correlation_id=self.correlation_id,
285285
endpoint=self.endpoint,
@@ -301,7 +301,7 @@ def _as_proto(self) -> call_pb.CallResult:
301301
output_any = None
302302
error_proto = None
303303
if self.output is not None:
304-
output_any = _pb_any_pickle(self.output)
304+
output_any = _any_pickle(self.output)
305305
if self.error is not None:
306306
error_proto = self.error._as_proto()
307307

@@ -440,31 +440,17 @@ def _as_proto(self) -> error_pb.Error:
440440
)
441441

442442

443-
def _any_unpickle(any: google.protobuf.any_pb2.Any) -> Any:
444-
if any.Is(pickled_pb.Pickled.DESCRIPTOR):
445-
p = pickled_pb.Pickled()
446-
any.Unpack(p)
447-
return pickle.loads(p.pickled_value)
448-
449-
elif any.Is(google.protobuf.wrappers_pb2.BytesValue.DESCRIPTOR): # legacy container
450-
b = google.protobuf.wrappers_pb2.BytesValue()
451-
any.Unpack(b)
452-
return pickle.loads(b.value)
453-
454-
elif not any.type_url and not any.value:
455-
return None
456-
457-
raise InvalidArgumentError(f"unsupported pickled value container: {any.type_url}")
458-
459-
460-
def _pb_any_pickle(value: Any) -> google.protobuf.any_pb2.Any:
461-
p = pickled_pb.Pickled(pickled_value=pickle.dumps(value))
443+
def _any_pickle(value: Any) -> google.protobuf.any_pb2.Any:
462444
any = google.protobuf.any_pb2.Any()
463-
any.Pack(p, type_url_prefix="buf.build/stealthrocket/dispatch-proto/")
445+
if isinstance(value, google.protobuf.message.Message):
446+
any.Pack(value)
447+
else:
448+
p = pickled_pb.Pickled(pickled_value=pickle.dumps(value))
449+
any.Pack(p, type_url_prefix="buf.build/stealthrocket/dispatch-proto/")
464450
return any
465451

466452

467-
def _pb_any_unpack(any: google.protobuf.any_pb2.Any) -> Any:
453+
def _any_unpickle(any: google.protobuf.any_pb2.Any) -> Any:
468454
if any.Is(pickled_pb.Pickled.DESCRIPTOR):
469455
p = pickled_pb.Pickled()
470456
any.Unpack(p)

0 commit comments

Comments
 (0)