@@ -78,7 +78,7 @@ def __init__(self, req: function_pb.RunRequest):
78
78
79
79
self ._has_input = req .HasField ("input" )
80
80
if self ._has_input :
81
- self ._input = _pb_any_unpack (req .input )
81
+ self ._input = _any_unpickle (req .input )
82
82
else :
83
83
if req .poll_result .coroutine_state :
84
84
raise IncompatibleStateError # coroutine_state is deprecated
@@ -141,7 +141,7 @@ def from_input_arguments(cls, function: str, *args, **kwargs):
141
141
return Input (
142
142
req = function_pb .RunRequest (
143
143
function = function ,
144
- input = _pb_any_pickle (input ),
144
+ input = _any_pickle (input ),
145
145
)
146
146
)
147
147
@@ -157,7 +157,7 @@ def from_poll_results(
157
157
req = function_pb .RunRequest (
158
158
function = function ,
159
159
poll_result = poll_pb .PollResult (
160
- typed_coroutine_state = _pb_any_pickle (coroutine_state ),
160
+ typed_coroutine_state = _any_pickle (coroutine_state ),
161
161
results = [result ._as_proto () for result in call_results ],
162
162
error = error ._as_proto () if error else None ,
163
163
),
@@ -241,7 +241,7 @@ def poll(
241
241
else None
242
242
)
243
243
poll = poll_pb .Poll (
244
- typed_coroutine_state = _pb_any_pickle (coroutine_state ),
244
+ typed_coroutine_state = _any_pickle (coroutine_state ),
245
245
min_results = min_results ,
246
246
max_results = max_results ,
247
247
max_wait = max_wait ,
@@ -279,7 +279,7 @@ class Call:
279
279
correlation_id : Optional [int ] = None
280
280
281
281
def _as_proto (self ) -> call_pb .Call :
282
- input_bytes = _pb_any_pickle (self .input )
282
+ input_bytes = _any_pickle (self .input )
283
283
return call_pb .Call (
284
284
correlation_id = self .correlation_id ,
285
285
endpoint = self .endpoint ,
@@ -301,7 +301,7 @@ def _as_proto(self) -> call_pb.CallResult:
301
301
output_any = None
302
302
error_proto = None
303
303
if self .output is not None :
304
- output_any = _pb_any_pickle (self .output )
304
+ output_any = _any_pickle (self .output )
305
305
if self .error is not None :
306
306
error_proto = self .error ._as_proto ()
307
307
@@ -440,31 +440,17 @@ def _as_proto(self) -> error_pb.Error:
440
440
)
441
441
442
442
443
- def _any_unpickle (any : google .protobuf .any_pb2 .Any ) -> Any :
444
- if any .Is (pickled_pb .Pickled .DESCRIPTOR ):
445
- p = pickled_pb .Pickled ()
446
- any .Unpack (p )
447
- return pickle .loads (p .pickled_value )
448
-
449
- elif any .Is (google .protobuf .wrappers_pb2 .BytesValue .DESCRIPTOR ): # legacy container
450
- b = google .protobuf .wrappers_pb2 .BytesValue ()
451
- any .Unpack (b )
452
- return pickle .loads (b .value )
453
-
454
- elif not any .type_url and not any .value :
455
- return None
456
-
457
- raise InvalidArgumentError (f"unsupported pickled value container: { any .type_url } " )
458
-
459
-
460
- def _pb_any_pickle (value : Any ) -> google .protobuf .any_pb2 .Any :
461
- p = pickled_pb .Pickled (pickled_value = pickle .dumps (value ))
443
+ def _any_pickle (value : Any ) -> google .protobuf .any_pb2 .Any :
462
444
any = google .protobuf .any_pb2 .Any ()
463
- any .Pack (p , type_url_prefix = "buf.build/stealthrocket/dispatch-proto/" )
445
+ if isinstance (value , google .protobuf .message .Message ):
446
+ any .Pack (value )
447
+ else :
448
+ p = pickled_pb .Pickled (pickled_value = pickle .dumps (value ))
449
+ any .Pack (p , type_url_prefix = "buf.build/stealthrocket/dispatch-proto/" )
464
450
return any
465
451
466
452
467
- def _pb_any_unpack (any : google .protobuf .any_pb2 .Any ) -> Any :
453
+ def _any_unpickle (any : google .protobuf .any_pb2 .Any ) -> Any :
468
454
if any .Is (pickled_pb .Pickled .DESCRIPTOR ):
469
455
p = pickled_pb .Pickled ()
470
456
any .Unpack (p )
0 commit comments