Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/pylint.yml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ jobs:
run: |
python -m pip install --upgrade pip
pip install pylint
pip install --upgrade aiohttp pycryptodomex apscheduler
pip install --upgrade aiohttp pycryptodomex apscheduler aio_pika
- name: Analysing the code with pylint
run: |
pylint $(git ls-files '*.py')
3 changes: 2 additions & 1 deletion khl/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@
FriendTypes
)
from .cert import Cert
from .receiver import Receiver, WebhookReceiver, WebsocketReceiver
from .rabbitmq import RabbitMQ, RabbitMQProducer
from .receiver import Receiver, WebhookReceiver, WebsocketReceiver, RabbitMQReceiver
from .requester import HTTPRequester
from .gateway import Gateway, Requestable
from .client import Client
Expand Down
21 changes: 17 additions & 4 deletions khl/bot/bot.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from .. import Cert, HTTPRequester, WebhookReceiver, WebsocketReceiver, Gateway, Client # net related
from .. import MessageTypes, EventTypes, SlowModeTypes, SoftwareTypes # types
from .. import User, Channel, PublicChannel, Guild, Event, Message # concepts
from .. import RabbitMQ, RabbitMQProducer, RabbitMQReceiver # rabbitmq
from ..command import CommandManager
from ..game import Game
from ..task import TaskManager
Expand Down Expand Up @@ -49,7 +50,8 @@ def __init__(self,
out: HTTPRequester = None,
compress: bool = True,
port=5000,
route='/khl-wh'):
route='/khl-wh',
rabbitmq: RabbitMQ = None):
"""
The most common usage: ``Bot(token='xxxxxx')``

Expand All @@ -62,11 +64,14 @@ def __init__(self,
:param compress: used to tune the receiver
:param port: used to tune the WebhookReceiver
:param route: used to tune the WebhookReceiver
:param rabbitmq: used to tune the RabbitMQ Receiver or Producer
"""
if not token and not cert:
raise ValueError('require token or cert')

self._init_client(cert or Cert(token=token), client, gate, out, compress, port, route)
is_rabbitmq_receiver = rabbitmq is not None and not rabbitmq.is_producer
self._init_client(cert or Cert(token=token, is_rabbitmq_receiver=is_rabbitmq_receiver), client, gate, out,
compress, port, route, rabbitmq)
self._register_client_handler()

self.command = CommandManager()
Expand All @@ -78,7 +83,8 @@ def __init__(self,
self._startup_index = []
self._shutdown_index = []

def _init_client(self, cert: Cert, client: Client, gate: Gateway, out: HTTPRequester, compress: bool, port, route):
def _init_client(self, cert: Cert, client: Client, gate: Gateway, out: HTTPRequester, compress: bool, port, route,
rabbitmq: RabbitMQ):
"""
construct self.client from args.

Expand All @@ -92,6 +98,7 @@ def _init_client(self, cert: Cert, client: Client, gate: Gateway, out: HTTPReque
:param compress: used to tune the receiver
:param port: used to tune the WebhookReceiver
:param route: used to tune the WebhookReceiver
:param rabbitmq: used to tune the RabbitMQ Receiver or Producer
:return:
"""
if client:
Expand All @@ -107,10 +114,16 @@ def _init_client(self, cert: Cert, client: Client, gate: Gateway, out: HTTPReque
_in = WebsocketReceiver(cert, compress)
elif cert.type == Cert.Types.WEBHOOK:
_in = WebhookReceiver(cert, port=port, route=route, compress=compress)
elif cert.type == Cert.Types.RABBITMQ:
_in = RabbitMQReceiver(rabbitmq, compress)
else:
raise ValueError(f'cert type: {cert.type} not supported')

self.client = Client(Gateway(_out, _in))
rabbitmq_producer = None
if rabbitmq is not None and rabbitmq.is_producer:
rabbitmq_producer = RabbitMQProducer(rabbitmq, compress)

self.client = Client(Gateway(_out, _in), rabbitmq_producer)

def _register_client_handler(self):
# text and kmd -> msg
Expand Down
10 changes: 8 additions & 2 deletions khl/cert.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,12 @@ class Types(Enum):
"""
webhook cert
"""

def __init__(self, *, type: Types = Types.NOTSET, token: str, verify_token: str = '', encrypt_key: str = ''):
RABBITMQ = "rabbitmq"
"""
rabbitmq cert
"""
def __init__(self, *, type: Types = Types.NOTSET, token: str, verify_token: str = '', encrypt_key: str = '',
is_rabbitmq_receiver: bool = False):
"""
all fields from bot config panel
"""
Expand All @@ -39,6 +43,8 @@ def __init__(self, *, type: Types = Types.NOTSET, token: str, verify_token: str
else:
if verify_token:
self.type = self.Types.WEBHOOK
elif is_rabbitmq_receiver:
self.type = self.Types.RABBITMQ
else:
self.type = self.Types.WEBSOCKET
self.token = token
Expand Down
14 changes: 11 additions & 3 deletions khl/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from ._types import SoftwareTypes, MessageTypes, SlowModeTypes, GameTypes
from .user import User, Friend, FriendRequest
from .util import unpack_id, unpack_value
from .rabbitmq import RabbitMQProducer

log = logging.getLogger(__name__)

Expand All @@ -33,13 +34,14 @@ class Client(Requestable, AsyncRunnable):
"""
_handler_map: Dict[MessageTypes, List[TypeHandler]]

def __init__(self, gate: Gateway):
def __init__(self, gate: Gateway, rabbitmq_producer: RabbitMQProducer = None):
self.gate = gate
self.ignore_self_msg = True
self._me = None

self._handler_map = {}
self._pkg_queue = asyncio.Queue()
self._rabbitmq_producer: RabbitMQProducer = rabbitmq_producer

def register(self, type: MessageTypes, handler: TypeHandler):
"""register handler to handle messages of type"""
Expand All @@ -61,7 +63,10 @@ async def handle_pkg(self):
log.debug(f'upcoming pkg: {pkg}')

try:
await self._consume_pkg(pkg)
if self._rabbitmq_producer is not None:
await self._rabbitmq_producer.publish(pkg)
else:
await self._consume_pkg(pkg)
except Exception as e:
log.exception(e)

Expand Down Expand Up @@ -362,4 +367,7 @@ async def fetch_blocked_friends(self) -> List[Friend]:
return [Friend(_gate_=self.gate, user_id=i['friend_info']['id'], **i) for i in friends]

async def start(self):
await asyncio.gather(self.handle_pkg(), self.gate.run(self._pkg_queue))
if self._rabbitmq_producer is not None:
await asyncio.gather(self.handle_pkg(), self.gate.run(self._pkg_queue), self._rabbitmq_producer.start())
else:
await asyncio.gather(self.handle_pkg(), self.gate.run(self._pkg_queue))
141 changes: 141 additions & 0 deletions khl/rabbitmq.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
import asyncio
import hashlib
import json
import logging
import zlib

import aio_pika
from Cryptodome.Cipher import AES
from Cryptodome.Util import Padding
from aio_pika.abc import AbstractRobustConnection, AbstractRobustChannel, AbstractExchange, AbstractRobustQueue

from .interface import AsyncRunnable

log = logging.getLogger(__name__)


class RabbitMQ:
"""rabbitmq configurate/init/connect/encrypt/decrypt/encode/decode"""

_connection: AbstractRobustConnection = None
_channel: AbstractRobustChannel = None
_queue: AbstractRobustQueue = None
_exchange: AbstractExchange = None

def __init__(self, login: str, password: str, host: str = '127.0.0.1', port: int = 5672, queue: str = 'kook',
qos: int = 10, heartbeat: int = 30, key: str = '', key_digits: int = 16, is_producer: bool = False):
self._host = host
self._port = port
self.queue = queue
self.qos = qos
self._heartbeat = heartbeat
self._login = login
self._password = password
self._key = key
self._key_digits = key_digits
self.is_producer = is_producer

# check aes is 128, 192 or 256
if key_digits not in AES.key_size:
raise ValueError(f'rabbitmq key_digits: {key_digits} not in {AES.key_size}')
if key != '':
key_encoded = key.encode('utf-8').ljust(key_digits, b'\x00')
else:
# if rabbitmq_key is not defined, use sha256 to generate one
key_encoded = hashlib.sha256(f'{login}:{password}'.encode('utf-8')).digest()

# make sure key digits is right
self._aes_key = key_encoded[:key_digits]

def decrypt(self, data: bytes) -> bytes:
""" decrypt data

:param data: encrypted byte array
:return: decrypted byte array
"""
decipher = AES.new(self._aes_key, AES.MODE_CBC, iv=data[:16])
data = decipher.decrypt(data[16:])
data = Padding.unpad(data, 16)
return data

def encrypt(self, data: bytes) -> bytes:
""" encrypt data

:param data: byte array
:return: encrypted byte array
"""
data = Padding.pad(data, 16)
cipher = AES.new(self._aes_key, AES.MODE_CBC)
data = cipher.encrypt(data)
return cipher.iv + data

def decode(self, data: bytes, compress: bool) -> dict:
"""decode raw rabbitmq data into plaintext data"""
data = self.decrypt(data)
if compress:
data = zlib.decompress(data)
return json.loads(str(data, encoding='utf-8'))

def encode(self, data: dict, compress: bool) -> bytes:
"""encode pkg into rabbitmq data"""
data = json.dumps(data).encode(encoding='utf-8')
if compress:
data = zlib.compress(data)
data = self.encrypt(data)
return data

async def get_connection(self) -> AbstractRobustConnection:
"""get rabbitmq connection"""
if self._connection is None:
self._connection = await aio_pika.connect_robust(
host=self._host,
port=self._port,
login=self._login,
password=self._password,
heartbeat=self._heartbeat
)
await self._connection.connect()
return self._connection

async def get_channel(self) -> AbstractRobustChannel:
"""get rabbitmq channel"""
if self._channel is None:
connection = await self.get_connection()
self._channel = await connection.channel()
return self._channel

async def get_queue(self) -> AbstractRobustQueue:
"""get rabbitmq queue"""
if self._queue is None:
channel = await self.get_channel()
self._queue = await channel.declare_queue(self.queue)
return self._queue

async def get_exchange(self) -> AbstractExchange:
"""get rabbitmq default exchange"""
if self._exchange is None:
channel = await self.get_channel()
self._exchange = channel.default_exchange
return self._exchange


class RabbitMQProducer(AsyncRunnable):
"""produce data to rabbitmq"""

_exchange: AbstractExchange
_connection: AbstractRobustConnection

def __init__(self, rabbitmq: RabbitMQ, compress: bool):
super().__init__()
self._rabbitmq = rabbitmq
self._compress = compress

async def start(self):
self._exchange = await self._rabbitmq.get_exchange()
while True:
await asyncio.sleep(3600) # sleep forever

async def publish(self, data: dict):
"""produce data"""
await self._exchange.publish(aio_pika.Message(body=self._rabbitmq.encode(data, self._compress)),
routing_key=self._rabbitmq.queue)
35 changes: 35 additions & 0 deletions khl/receiver.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from aiohttp import ClientWebSocketResponse, ClientSession, web, WSMessage

from .rabbitmq import RabbitMQ
from .cert import Cert
from .interface import AsyncRunnable

Expand Down Expand Up @@ -120,6 +121,40 @@ async def _handle_raw(self, raw: WSMessage):
log.exception(e)


class RabbitMQReceiver(Receiver):
"""receive data in RabbitMQ mode"""

def __init__(self, rabbitmq: RabbitMQ, compress: bool):
super().__init__()
self._rabbitmq = rabbitmq
self.compress = compress

@property
def type(self) -> str:
return 'rabbitmq'

async def start(self):
queue = await self._rabbitmq.get_queue()

log.info('[ init ] launched')

async with queue.iterator() as queue_iter:
async for message in queue_iter:
async with message.process():
try:
pkg: Dict = self._rabbitmq.decode(message.body, self.compress)
except Exception as e:
log.exception(e)
continue

if not pkg: # empty pkg
continue

while self.pkg_queue.qsize() >= self._rabbitmq.qos:
await asyncio.sleep(0.001)
await self.pkg_queue.put(pkg)


class WebhookReceiver(Receiver):
"""receive data in webhook mode"""

Expand Down
1 change: 1 addition & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ install_requires =
aiohttp
pycryptodomex
apscheduler
aio_pika

[options.packages.find]
where = .