Skip to content

Commit 9e03e23

Browse files
committed
feat: add simple RabbitMQ support
1 parent 25d26b4 commit 9e03e23

File tree

7 files changed

+215
-10
lines changed

7 files changed

+215
-10
lines changed

khl/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,8 @@
1717
FriendTypes
1818
)
1919
from .cert import Cert
20-
from .receiver import Receiver, WebhookReceiver, WebsocketReceiver
20+
from .rabbitmq import RabbitMQ, RabbitMQProducer
21+
from .receiver import Receiver, WebhookReceiver, WebsocketReceiver, RabbitMQReceiver
2122
from .requester import HTTPRequester
2223
from .gateway import Gateway, Requestable
2324
from .client import Client

khl/bot/bot.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from .. import Cert, HTTPRequester, WebhookReceiver, WebsocketReceiver, Gateway, Client # net related
1010
from .. import MessageTypes, EventTypes, SlowModeTypes, SoftwareTypes # types
1111
from .. import User, Channel, PublicChannel, Guild, Event, Message # concepts
12+
from .. import RabbitMQ, RabbitMQProducer, RabbitMQReceiver # rabbitmq
1213
from ..command import CommandManager
1314
from ..game import Game
1415
from ..task import TaskManager
@@ -49,7 +50,8 @@ def __init__(self,
4950
out: HTTPRequester = None,
5051
compress: bool = True,
5152
port=5000,
52-
route='/khl-wh'):
53+
route='/khl-wh',
54+
rabbitmq: RabbitMQ = None):
5355
"""
5456
The most common usage: ``Bot(token='xxxxxx')``
5557
@@ -62,11 +64,14 @@ def __init__(self,
6264
:param compress: used to tune the receiver
6365
:param port: used to tune the WebhookReceiver
6466
:param route: used to tune the WebhookReceiver
67+
:param rabbitmq: used to tune the RabbitMQ Receiver or Producer
6568
"""
6669
if not token and not cert:
6770
raise ValueError('require token or cert')
6871

69-
self._init_client(cert or Cert(token=token), client, gate, out, compress, port, route)
72+
is_rabbitmq_receiver = rabbitmq is not None and not rabbitmq.is_producer
73+
self._init_client(cert or Cert(token=token, is_rabbitmq_receiver=is_rabbitmq_receiver), client, gate, out,
74+
compress, port, route, rabbitmq)
7075
self._register_client_handler()
7176

7277
self.command = CommandManager()
@@ -78,7 +83,8 @@ def __init__(self,
7883
self._startup_index = []
7984
self._shutdown_index = []
8085

81-
def _init_client(self, cert: Cert, client: Client, gate: Gateway, out: HTTPRequester, compress: bool, port, route):
86+
def _init_client(self, cert: Cert, client: Client, gate: Gateway, out: HTTPRequester, compress: bool, port, route,
87+
rabbitmq: RabbitMQ):
8288
"""
8389
construct self.client from args.
8490
@@ -92,6 +98,7 @@ def _init_client(self, cert: Cert, client: Client, gate: Gateway, out: HTTPReque
9298
:param compress: used to tune the receiver
9399
:param port: used to tune the WebhookReceiver
94100
:param route: used to tune the WebhookReceiver
101+
:param rabbitmq: used to tune the RabbitMQ Receiver or Producer
95102
:return:
96103
"""
97104
if client:
@@ -107,10 +114,16 @@ def _init_client(self, cert: Cert, client: Client, gate: Gateway, out: HTTPReque
107114
_in = WebsocketReceiver(cert, compress)
108115
elif cert.type == Cert.Types.WEBHOOK:
109116
_in = WebhookReceiver(cert, port=port, route=route, compress=compress)
117+
elif cert.type == Cert.Types.RABBITMQ:
118+
_in = RabbitMQReceiver(rabbitmq, compress)
110119
else:
111120
raise ValueError(f'cert type: {cert.type} not supported')
112121

113-
self.client = Client(Gateway(_out, _in))
122+
rabbitmq_producer = None
123+
if rabbitmq is not None and rabbitmq.is_producer:
124+
rabbitmq_producer = RabbitMQProducer(rabbitmq, compress)
125+
126+
self.client = Client(Gateway(_out, _in), rabbitmq_producer)
114127

115128
def _register_client_handler(self):
116129
# text and kmd -> msg

khl/cert.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,12 @@ class Types(Enum):
2929
"""
3030
webhook cert
3131
"""
32-
33-
def __init__(self, *, type: Types = Types.NOTSET, token: str, verify_token: str = '', encrypt_key: str = ''):
32+
RABBITMQ = "rabbitmq"
33+
"""
34+
rabbitmq cert
35+
"""
36+
def __init__(self, *, type: Types = Types.NOTSET, token: str, verify_token: str = '', encrypt_key: str = '',
37+
is_rabbitmq_receiver: bool = False):
3438
"""
3539
all fields from bot config panel
3640
"""
@@ -39,6 +43,8 @@ def __init__(self, *, type: Types = Types.NOTSET, token: str, verify_token: str
3943
else:
4044
if verify_token:
4145
self.type = self.Types.WEBHOOK
46+
elif is_rabbitmq_receiver:
47+
self.type = self.Types.RABBITMQ
4248
else:
4349
self.type = self.Types.WEBSOCKET
4450
self.token = token

khl/client.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from ._types import SoftwareTypes, MessageTypes, SlowModeTypes, GameTypes
1717
from .user import User, Friend, FriendRequest
1818
from .util import unpack_id, unpack_value
19+
from .rabbitmq import RabbitMQProducer
1920

2021
log = logging.getLogger(__name__)
2122

@@ -33,13 +34,14 @@ class Client(Requestable, AsyncRunnable):
3334
"""
3435
_handler_map: Dict[MessageTypes, List[TypeHandler]]
3536

36-
def __init__(self, gate: Gateway):
37+
def __init__(self, gate: Gateway, rabbitmq_producer: RabbitMQProducer = None):
3738
self.gate = gate
3839
self.ignore_self_msg = True
3940
self._me = None
4041

4142
self._handler_map = {}
4243
self._pkg_queue = asyncio.Queue()
44+
self._rabbitmq_producer: RabbitMQProducer = rabbitmq_producer
4345

4446
def register(self, type: MessageTypes, handler: TypeHandler):
4547
"""register handler to handle messages of type"""
@@ -61,7 +63,10 @@ async def handle_pkg(self):
6163
log.debug(f'upcoming pkg: {pkg}')
6264

6365
try:
64-
await self._consume_pkg(pkg)
66+
if self._rabbitmq_producer is not None:
67+
await self._rabbitmq_producer.publish(pkg)
68+
else:
69+
await self._consume_pkg(pkg)
6570
except Exception as e:
6671
log.exception(e)
6772

@@ -362,4 +367,7 @@ async def fetch_blocked_friends(self) -> List[Friend]:
362367
return [Friend(_gate_=self.gate, user_id=i['friend_info']['id'], **i) for i in friends]
363368

364369
async def start(self):
365-
await asyncio.gather(self.handle_pkg(), self.gate.run(self._pkg_queue))
370+
if self._rabbitmq_producer is not None:
371+
await asyncio.gather(self.handle_pkg(), self.gate.run(self._pkg_queue), self._rabbitmq_producer.start())
372+
else:
373+
await asyncio.gather(self.handle_pkg(), self.gate.run(self._pkg_queue))

khl/rabbitmq.py

Lines changed: 141 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,141 @@
1+
import asyncio
2+
import hashlib
3+
import json
4+
import logging
5+
import zlib
6+
7+
import aio_pika
8+
from Cryptodome.Cipher import AES
9+
from Cryptodome.Util import Padding
10+
from aio_pika.abc import AbstractRobustConnection, AbstractRobustChannel, AbstractExchange, AbstractRobustQueue
11+
12+
from .interface import AsyncRunnable
13+
14+
log = logging.getLogger(__name__)
15+
16+
17+
class RabbitMQ:
18+
"""rabbitmq configurate/init/connect/encrypt/decrypt/encode/decode"""
19+
20+
_connection: AbstractRobustConnection = None
21+
_channel: AbstractRobustChannel = None
22+
_queue: AbstractRobustQueue = None
23+
_exchange: AbstractExchange = None
24+
25+
def __init__(self, login: str, password: str, host: str = '127.0.0.1', port: int = 5672, queue: str = 'kook',
26+
qos: int = 10, heartbeat: int = 30, key: str = '', key_digits: int = 16, is_producer: bool = False):
27+
self._host = host
28+
self._port = port
29+
self.queue = queue
30+
self.qos = qos
31+
self._heartbeat = heartbeat
32+
self._login = login
33+
self._password = password
34+
self._key = key
35+
self._key_digits = key_digits
36+
self.is_producer = is_producer
37+
38+
# check aes is 128, 192 or 256
39+
if key_digits not in AES.key_size:
40+
raise ValueError(f'rabbitmq key_digits: {key_digits} not in {AES.key_size}')
41+
if key != '':
42+
key_encoded = key.encode('utf-8').ljust(key_digits, b'\x00')
43+
else:
44+
# if rabbitmq_key is not defined, use sha256 to generate one
45+
key_encoded = hashlib.sha256(f'{login}:{password}'.encode('utf-8')).digest()
46+
47+
# make sure key digits is right
48+
self._aes_key = key_encoded[:key_digits]
49+
50+
def decrypt(self, data: bytes) -> bytes:
51+
""" decrypt data
52+
53+
:param data: encrypted byte array
54+
:return: decrypted byte array
55+
"""
56+
decipher = AES.new(self._aes_key, AES.MODE_CBC, iv=data[:16])
57+
data = decipher.decrypt(data[16:])
58+
data = Padding.unpad(data, 16)
59+
return data
60+
61+
def encrypt(self, data: bytes) -> bytes:
62+
""" encrypt data
63+
64+
:param data: byte array
65+
:return: encrypted byte array
66+
"""
67+
data = Padding.pad(data, 16)
68+
cipher = AES.new(self._aes_key, AES.MODE_CBC)
69+
data = cipher.encrypt(data)
70+
return cipher.iv + data
71+
72+
def decode(self, data: bytes, compress: bool) -> dict:
73+
"""decode raw rabbitmq data into plaintext data"""
74+
data = self.decrypt(data)
75+
if compress:
76+
data = zlib.decompress(data)
77+
return json.loads(str(data, encoding='utf-8'))
78+
79+
def encode(self, data: dict, compress: bool) -> bytes:
80+
"""encode pkg into rabbitmq data"""
81+
data = json.dumps(data).encode(encoding='utf-8')
82+
if compress:
83+
data = zlib.compress(data)
84+
data = self.encrypt(data)
85+
return data
86+
87+
async def get_connection(self) -> AbstractRobustConnection:
88+
"""get rabbitmq connection"""
89+
if self._connection is None:
90+
self._connection = await aio_pika.connect_robust(
91+
host=self._host,
92+
port=self._port,
93+
login=self._login,
94+
password=self._password,
95+
heartbeat=self._heartbeat
96+
)
97+
await self._connection.connect()
98+
return self._connection
99+
100+
async def get_channel(self) -> AbstractRobustChannel:
101+
"""get rabbitmq channel"""
102+
if self._channel is None:
103+
connection = await self.get_connection()
104+
self._channel = await connection.channel()
105+
return self._channel
106+
107+
async def get_queue(self) -> AbstractRobustQueue:
108+
"""get rabbitmq queue"""
109+
if self._queue is None:
110+
channel = await self.get_channel()
111+
self._queue = await channel.declare_queue(self.queue)
112+
return self._queue
113+
114+
async def get_exchange(self) -> AbstractExchange:
115+
"""get rabbitmq default exchange"""
116+
if self._exchange is None:
117+
channel = await self.get_channel()
118+
self._exchange = channel.default_exchange
119+
return self._exchange
120+
121+
122+
class RabbitMQProducer(AsyncRunnable):
123+
"""produce data to rabbitmq"""
124+
125+
_exchange: AbstractExchange
126+
_connection: AbstractRobustConnection
127+
128+
def __init__(self, rabbitmq: RabbitMQ, compress: bool):
129+
super().__init__()
130+
self._rabbitmq = rabbitmq
131+
self._compress = compress
132+
133+
async def start(self):
134+
self._exchange = await self._rabbitmq.get_exchange()
135+
while True:
136+
await asyncio.sleep(3600) # sleep forever
137+
138+
async def publish(self, data: dict):
139+
"""produce data"""
140+
await self._exchange.publish(aio_pika.Message(body=self._rabbitmq.encode(data, self._compress)),
141+
routing_key=self._rabbitmq.queue)

khl/receiver.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
from aiohttp import ClientWebSocketResponse, ClientSession, web, WSMessage
99

10+
from .rabbitmq import RabbitMQ
1011
from .cert import Cert
1112
from .interface import AsyncRunnable
1213

@@ -120,6 +121,40 @@ async def _handle_raw(self, raw: WSMessage):
120121
log.exception(e)
121122

122123

124+
class RabbitMQReceiver(Receiver):
125+
"""receive data in RabbitMQ mode"""
126+
127+
def __init__(self, rabbitmq: RabbitMQ, compress: bool):
128+
super().__init__()
129+
self._rabbitmq = rabbitmq
130+
self.compress = compress
131+
132+
@property
133+
def type(self) -> str:
134+
return 'rabbitmq'
135+
136+
async def start(self):
137+
queue = await self._rabbitmq.get_queue()
138+
139+
log.info('[ init ] launched')
140+
141+
async with queue.iterator() as queue_iter:
142+
async for message in queue_iter:
143+
async with message.process():
144+
try:
145+
pkg: Dict = self._rabbitmq.decode(message.body, self.compress)
146+
except Exception as e:
147+
log.exception(e)
148+
continue
149+
150+
if not pkg: # empty pkg
151+
continue
152+
153+
while self.pkg_queue.qsize() >= self._rabbitmq.qos:
154+
await asyncio.sleep(0.001)
155+
await self.pkg_queue.put(pkg)
156+
157+
123158
class WebhookReceiver(Receiver):
124159
"""receive data in webhook mode"""
125160

setup.cfg

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ install_requires =
2424
aiohttp
2525
pycryptodomex
2626
apscheduler
27+
aio_pika
2728

2829
[options.packages.find]
2930
where = .

0 commit comments

Comments
 (0)