1616 along with this program. If not, see <https://www.gnu.org/licenses/>.
1717"""
1818
19- import socket
2019import asyncio
21- import netifaces
2220
2321from enum import Enum , auto
2422
2523from .Enums import *
24+ from .Encoding import *
2625
2726class MsgRxStatus (Enum ):
2827 MSG_RX_MAGIC = auto ()
@@ -33,14 +32,7 @@ class MsgRxStatus(Enum):
3332class MessageReceiver (asyncio .Protocol ):
3433 MSG_MAGIC_IDENTIFIER = b"\xFF ZB\x02 "
3534
36- def __init__ (self , on_connection_made , on_message , on_error ):
37- loop = asyncio .get_running_loop ()
38-
39- self .on_connection_lost = loop .create_future ()
40- self .on_connection_made = on_connection_made
41- self .on_message = on_message
42- self .on_error = on_error
43-
35+ def __init__ (self ):
4436 self .status = MsgRxStatus .MSG_RX_MAGIC
4537 self .buffer = b"\x00 \x00 \x00 \x00 "
4638 self .count = 0
@@ -50,9 +42,6 @@ def __init__(self, on_connection_made, on_message, on_error):
5042 def connection_made (self , transport ):
5143 self .transport = transport
5244
53- if self .on_connection_made != None :
54- self .on_connection_made ()
55-
5645 def bytes_received (self , data ):
5746 for b in data :
5847 b = b .to_bytes (1 , "little" )
@@ -69,17 +58,15 @@ def bytes_received(self, data):
6958 self .buffer += b
7059
7160 if len (self .buffer ) == 4 :
72- self .id = int . from_bytes (self .buffer [0 :2 ], byteorder = 'little' , signed = False )
73- self .size = int . from_bytes (self .buffer [2 :4 ], byteorder = 'little' , signed = False )
61+ self .id = decode_u16 (self .buffer [0 :2 ])
62+ self .size = decode_u16 (self .buffer [2 :4 ])
7463 self .buffer = b""
7564
7665 if self .id == Messages .MSG_ID_EXTENDED :
7766 self .status = MsgRxStatus .MSG_RX_EXTENDED_HEADER
7867 self .id = self .size
7968 elif self .size == 0 :
80- if self .on_message != None :
81- self .on_message (self .id , self .buffer )
82-
69+ self .message_received (self .id , self .buffer )
8370 self .status = MsgRxStatus .MSG_RX_MAGIC
8471 self .buffer = b"\x00 \x00 \x00 \x00 "
8572 else :
@@ -94,30 +81,29 @@ def bytes_received(self, data):
9481 self .buffer += b
9582
9683 if len (self .buffer ) == 4 :
97- self .size = int . from_bytes (self .buffer , byteorder = 'little' , signed = False )
84+ self .size = decode_u32 (self .buffer )
9885 self .buffer = b""
9986
10087 if self .size == 0 :
10188 self .status = MsgRxStatus .MSG_RX_DATA
10289 else :
103- if self .on_message != None :
104- self .on_message (self .id , self .buffer )
105-
90+ self .message_received (self .id , self .buffer )
10691 self .status = MsgRxStatus .MSG_RX_MAGIC
10792 self .buffer = b"\x00 \x00 \x00 \x00 "
10893 else :
10994 self .buffer += b
11095
11196 if len (self .buffer ) == self .size :
112- if self .on_message != None :
113- self .on_message (self .id , self .buffer )
114-
97+ self .message_received (self .id , self .buffer )
11598 self .status = MsgRxStatus .MSG_RX_MAGIC
11699 self .buffer = b"\x00 \x00 \x00 \x00 "
117100
118- def error_received (self , exc ):
119- if self .on_error != None :
120- self .on_error (exc )
101+ def message_received (self , msg_id , msg_payload ):
102+ pass
103+
104+ def error_received (self , err ):
105+ raise ConnectionError (err )
121106
122107 def connection_lost (self , exc ):
123- self .on_connection_lost .set_result (True )
108+ if exc != None :
109+ raise ConnectionError (exc )
0 commit comments