Skip to content

Commit

Permalink
Minor cleanup to rpc.py (#295)
Browse files Browse the repository at this point in the history
* Remove copies and flush print buffers in rpc.py

* format code
  • Loading branch information
dcrankshaw authored and Corey-Zumar committed Sep 21, 2017
1 parent b4ed00f commit f3dc0fa
Showing 1 changed file with 16 additions and 10 deletions.
26 changes: 16 additions & 10 deletions containers/python/rpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,14 +207,14 @@ def run(self):
if time_delta_millis >= SOCKET_ACTIVITY_TIMEOUT_MILLIS:
# Terminate the session
print("Connection timed out, reconnecting...")
sys.stdout.flush()
sys.stderr.flush()
connected = False
poller.unregister(socket)
socket.close()
break
else:
self.send_heartbeat(socket)
sys.stdout.flush()
sys.stderr.flush()
continue

# Received a message before the polling timeout
Expand All @@ -230,6 +230,8 @@ def run(self):
if msg_type == MESSAGE_TYPE_HEARTBEAT:
self.event_history.insert(EVENT_HISTORY_RECEIVED_HEARTBEAT)
print("Received heartbeat!")
sys.stdout.flush()
sys.stderr.flush()
heartbeat_type_bytes = socket.recv()
heartbeat_type = struct.unpack("<I",
heartbeat_type_bytes)[0]
Expand Down Expand Up @@ -327,6 +329,8 @@ def send_container_metadata(self, socket):
socket.send_string(str(self.model_input_type))
self.event_history.insert(EVENT_HISTORY_SENT_CONTAINER_METADATA)
print("Sent container metadata!")
sys.stdout.flush()
sys.stderr.flush()

def send_heartbeat(self, socket):
socket.send("", zmq.SNDMORE)
Expand All @@ -340,9 +344,9 @@ class PredictionRequest:
Parameters
----------
msg_id : bytes
The raw message id associated with the RPC
The raw message id associated with the RPC
prediction request message
inputs :
inputs :
One of [[byte]], [[int]], [[float]], [[double]], [string]
"""

Expand Down Expand Up @@ -373,8 +377,9 @@ def __init__(self, msg_id, num_outputs, total_string_length):
self.num_outputs = num_outputs
self.expand_buffer_if_necessary(
total_string_length * MAXIMUM_UTF_8_CHAR_LENGTH_BYTES)
self.memview = memoryview(self.output_buffer)
struct.pack_into("<I", self.output_buffer, 0, num_outputs)
self.memview = memoryview(PredictionResponse.output_buffer)
struct.pack_into("<I", PredictionResponse.output_buffer, 0,
num_outputs)
self.string_content_end_position = BYTES_PER_INT + (
BYTES_PER_INT * num_outputs)
self.current_output_sizes_position = BYTES_PER_INT
Expand All @@ -387,7 +392,7 @@ def add_output(self, output):
"""
output = unicode(output, "utf-8").encode("utf-8")
output_len = len(output)
struct.pack_into("<I", self.output_buffer,
struct.pack_into("<I", PredictionResponse.output_buffer,
self.current_output_sizes_position, output_len)
self.current_output_sizes_position += BYTES_PER_INT
self.memview[self.string_content_end_position:
Expand All @@ -400,12 +405,13 @@ def send(self, socket, event_history):
struct.pack("<I", MESSAGE_TYPE_CONTAINER_CONTENT),
flags=zmq.SNDMORE)
socket.send(self.msg_id, flags=zmq.SNDMORE)
socket.send(self.output_buffer[0:self.string_content_end_position])
socket.send(PredictionResponse.output_buffer[
0:self.string_content_end_position])
event_history.insert(EVENT_HISTORY_SENT_CONTAINER_CONTENT)

def expand_buffer_if_necessary(self, size):
if len(self.output_buffer) < size:
self.output_buffer = bytearray(size * 2)
if len(PredictionResponse.output_buffer) < size:
PredictionResponse.output_buffer = bytearray(size * 2)


class FeedbackRequest():
Expand Down

0 comments on commit f3dc0fa

Please sign in to comment.