Skip to content
Open
2 changes: 1 addition & 1 deletion .github/workflows/pylint.yml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ jobs:
run: |
python -m pip install --upgrade pip
pip install pylint
pip install --upgrade aiohttp pycryptodomex apscheduler
pip install --upgrade aiohttp pycryptodomex apscheduler aio_pika
- name: Analysing the code with pylint
run: |
pylint $(git ls-files '*.py')
16 changes: 11 additions & 5 deletions khl/bot/bot.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
from typing import Dict, Callable, List, Optional, Union, Coroutine, IO

from .. import AsyncRunnable # interfaces
from .. import Cert, HTTPRequester, RateLimiter, WebhookReceiver, WebsocketReceiver, Gateway, Client # net related
from .. import Cert, HTTPRequester, RateLimiter, Gateway, Client # net related
from .. import Receiver, WebhookReceiver, WebsocketReceiver # net related, receivers
from .. import MessageTypes, EventTypes, SlowModeTypes, SoftwareTypes # types
from .. import User, Channel, PublicChannel, Guild, Event, Message # concepts
from ..command import CommandManager
Expand Down Expand Up @@ -46,6 +47,7 @@ def __init__(self,
cert: Cert = None,
client: Client = None,
gate: Gateway = None,
receiver: Receiver = None,
out: HTTPRequester = None,
compress: bool = True,
port=5000,
Expand All @@ -59,6 +61,7 @@ def __init__(self,
:param cert: used to build requester and receiver
:param client: the bot relies on
:param gate: the client relies on
:param receiver: custom receiver as the gate's component
:param out: the gate's component
:param compress: used to tune the receiver
:param port: used to tune the WebhookReceiver
Expand All @@ -67,7 +70,7 @@ def __init__(self,
if not token and not cert:
raise ValueError('require token or cert')

self._init_client(cert or Cert(token=token), client, gate, out, compress, port, route, ratelimiter)
self._init_client(cert or Cert(token=token), client, gate, receiver, out, compress, port, route, ratelimiter)
self._register_client_handler()

self.command = CommandManager()
Expand All @@ -79,8 +82,8 @@ def __init__(self,
self._startup_index = []
self._shutdown_index = []

def _init_client(self, cert: Cert, client: Client, gate: Gateway, out: HTTPRequester, compress: bool, port, route,
ratelimiter):
def _init_client(self, cert: Cert, client: Client, gate: Gateway, receiver: Receiver, out: HTTPRequester,
compress: bool, port, route, ratelimiter):
"""
construct self.client from args.

Expand All @@ -90,6 +93,7 @@ def _init_client(self, cert: Cert, client: Client, gate: Gateway, out: HTTPReque
:param cert: used to build requester and receiver
:param client: the bot relies on
:param gate: the client relies on
:param receiver: custom receiver as the gate's component
:param out: the gate's component
:param compress: used to tune the receiver
:param port: used to tune the WebhookReceiver
Expand All @@ -105,7 +109,9 @@ def _init_client(self, cert: Cert, client: Client, gate: Gateway, out: HTTPReque

# client and gate not in args, build them
_out = out if out else HTTPRequester(cert, ratelimiter)
if cert.type == Cert.Types.WEBSOCKET:
if receiver:
_in = receiver
elif cert.type == Cert.Types.WEBSOCKET:
_in = WebsocketReceiver(cert, compress)
elif cert.type == Cert.Types.WEBHOOK:
_in = WebhookReceiver(cert, port=port, route=route, compress=compress)
Expand Down
3 changes: 1 addition & 2 deletions khl/command/parser.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
"""parser: component used in command args handling, convert string token to fit the command signature"""
import asyncio
import copy
import inspect
import logging
Expand Down Expand Up @@ -29,7 +28,7 @@ def _get_param_type(param: Union[inspect.Parameter, None]):


def _wrap_one_param_func(func: Callable) -> Callable:
def wrapper(msg: Message, client: Client, token: str):
def wrapper(_msg: Message, _client: Client, token: str):
return func(token)
return wrapper

Expand Down
3 changes: 3 additions & 0 deletions khl/rabbitmq/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .rabbitmq import RabbitMQ
from .rabbitmq_receiver import RabbitMQReceiver
from .rabbitmq_producer import RabbitMQProducer, RabbitMQProductionBot
114 changes: 114 additions & 0 deletions khl/rabbitmq/rabbitmq.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
import hashlib
import json
import logging
import zlib
import aio_pika
from Cryptodome.Cipher import AES
from Cryptodome.Util import Padding
from aio_pika.abc import AbstractRobustConnection, AbstractRobustChannel, AbstractExchange, AbstractRobustQueue
log = logging.getLogger(__name__)


class RabbitMQ:
"""rabbitmq configurate/init/connect/encrypt/decrypt/encode/decode"""
_connection: AbstractRobustConnection = None
_channel: AbstractRobustChannel = None
_queue: AbstractRobustQueue = None
_exchange: AbstractExchange = None
def __init__(self, login: str, password: str, host: str = '127.0.0.1', port: int = 5672, queue: str = 'kook',
qos: int = 10, heartbeat: int = 30, key: str = '', salt: str = '', key_digits: int = 32,
compress: bool = True):
self._host = host
self._port = port
self.queue = queue
self.qos = qos
self._heartbeat = heartbeat
self._login = login
self._password = password
self.compress = compress
# check aes is 128, 192 or 256
if key_digits not in AES.key_size:
raise ValueError(f'rabbitmq key_digits: {key_digits} not in {AES.key_size}')
if key != '':
key_string = key
else:
# if rabbitmq_key is not defined, use sha256 to generate one
# construct a variable complex certain string
# login, password, queue, compress, key_digits are same in both sides
key_string = f'khl.py://{login}:{password}/{queue}?&compress={compress}&key_digits={key_digits}'

if salt != '':
salt_encoded = salt.encode('utf-8')
else:
salt_encoded = b'rabbitmq in khl.py'

# use pbkdf2_hmac to generate key with key_digits
self._aes_key = hashlib.pbkdf2_hmac(
'sha256',
key_string.encode('utf-8'),
salt_encoded,
100000,
dklen=key_digits
)

def decrypt(self, data: bytes) -> bytes:
""" decrypt data
:param data: encrypted byte array
:return: decrypted byte array
"""
decipher = AES.new(self._aes_key, AES.MODE_CBC, iv=data[:16])
data = decipher.decrypt(data[16:])
data = Padding.unpad(data, 16)
return data
def encrypt(self, data: bytes) -> bytes:
""" encrypt data
:param data: byte array
:return: encrypted byte array
"""
data = Padding.pad(data, 16)
cipher = AES.new(self._aes_key, AES.MODE_CBC)
data = cipher.encrypt(data)
return cipher.iv + data
def decode(self, data: bytes) -> dict:
"""decode raw rabbitmq data into plaintext data"""
data = self.decrypt(data)
if self.compress:
data = zlib.decompress(data)
return json.loads(str(data, encoding='utf-8'))
def encode(self, data: dict) -> bytes:
"""encode pkg into rabbitmq data"""
data = json.dumps(data).encode(encoding='utf-8')
if self.compress:
data = zlib.compress(data)
data = self.encrypt(data)
return data
async def get_connection(self) -> AbstractRobustConnection:
"""get rabbitmq connection"""
if self._connection is None:
self._connection = await aio_pika.connect_robust(
host=self._host,
port=self._port,
login=self._login,
password=self._password,
heartbeat=self._heartbeat
)
await self._connection.connect()
return self._connection
async def get_channel(self) -> AbstractRobustChannel:
"""get rabbitmq channel"""
if self._channel is None:
connection = await self.get_connection()
self._channel = await connection.channel()
return self._channel
async def get_queue(self) -> AbstractRobustQueue:
"""get rabbitmq queue"""
if self._queue is None:
channel = await self.get_channel()
self._queue = await channel.declare_queue(self.queue)
return self._queue
async def get_exchange(self) -> AbstractExchange:
"""get rabbitmq default exchange"""
if self._exchange is None:
channel = await self.get_channel()
self._exchange = channel.default_exchange
return self._exchange
93 changes: 93 additions & 0 deletions khl/rabbitmq/rabbitmq_producer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
import asyncio
import logging
from typing import Dict
import aio_pika
from aio_pika.abc import AbstractExchange, AbstractRobustConnection
from . import RabbitMQ
from ..cert import Cert
from ..receiver import WebhookReceiver, WebsocketReceiver
from ..interface import AsyncRunnable
log = logging.getLogger(__name__)

class RabbitMQProducer(AsyncRunnable):
"""produce data to rabbitmq"""
_exchange: AbstractExchange
_connection: AbstractRobustConnection
def __init__(self, rabbitmq: RabbitMQ):
super().__init__()
self._rabbitmq = rabbitmq
async def start(self):
self._exchange = await self._rabbitmq.get_exchange()
while True:
await asyncio.sleep(3600) # sleep forever
async def publish(self, data: dict):
"""produce data"""
await self._exchange.publish(aio_pika.Message(body=self._rabbitmq.encode(data)),
routing_key=self._rabbitmq.queue)

class RabbitMQProductionBot(AsyncRunnable):
"""
This bot class is made for rabbitmq data production, to send package unwrapped by Receiver to rabbitmq
"""
def __init__(self,
token: str = '',
*,
cert: Cert = None,
compress: bool = True,
port=5000,
route='/khl-wh',
rabbitmq: RabbitMQ = None):
"""
:param cert: used to build requester and receiver
:param compress: used to tune the receiver
:param port: used to tune the WebhookReceiver
:param route: used to tune the WebhookReceiver
:param rabbitmq: used to tune the RabbitMQ Receiver or Producer
"""
if not token and not cert:
raise ValueError('require token or cert')

if rabbitmq is None:
raise ValueError('require rabbitmq')

cert = cert or Cert(token=token)

if cert.type == Cert.Types.WEBSOCKET:
receiver = WebsocketReceiver(cert, compress)
elif cert.type == Cert.Types.WEBHOOK:
receiver = WebhookReceiver(cert, port=port, route=route, compress=compress)
else:
raise ValueError(f'cert type: {cert.type} not supported')

self._pkg_queue = asyncio.Queue()
self._receiver = receiver
self._producer = RabbitMQProducer(rabbitmq)


async def publish_pkg_to_rabbitmq(self):
"""publish pkg to rabbitmq"""
while True:
pkg: Dict = await self._pkg_queue.get()
log.debug(f'publishing pkg: {pkg}')
try:
await self._producer.publish(pkg)
except Exception as e:
log.exception(e)
self._pkg_queue.task_done()

async def start(self):
"""start the rabbitmq bot"""
self._receiver.pkg_queue = self._pkg_queue # pass the pkg_queue to the receiver
await asyncio.gather(self.publish_pkg_to_rabbitmq(), self._receiver.start(), self._producer.start())

# keep the usage compatible with Bot interface, so ignored pylint warning
# pylint: disable=duplicate-code
def run(self):
"""run the rabbitmq bot in blocking mode"""
if not self.loop:
self.loop = asyncio.get_event_loop()
try:
self.loop.run_until_complete(self.start())
except KeyboardInterrupt:
log.info('see you next time')
# pylint: enable=duplicate-code
32 changes: 32 additions & 0 deletions khl/rabbitmq/rabbitmq_receiver.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
import asyncio
import logging
from typing import Dict
from ..receiver import Receiver
from .rabbitmq import RabbitMQ
log = logging.getLogger(__name__)


class RabbitMQReceiver(Receiver):
"""receive data in RabbitMQ mode"""
def __init__(self, rabbitmq: RabbitMQ):
super().__init__()
self._rabbitmq = rabbitmq
@property
def type(self) -> str:
return 'rabbitmq'
async def start(self):
queue = await self._rabbitmq.get_queue()
log.info('[ init ] launched')
async with queue.iterator() as queue_iter:
async for message in queue_iter:
async with message.process():
try:
pkg: Dict = self._rabbitmq.decode(message.body)
except Exception as e:
log.exception(e)
continue
if not pkg: # empty pkg
continue
while self.pkg_queue.qsize() >= self._rabbitmq.qos:
await asyncio.sleep(0.001)
await self.pkg_queue.put(pkg)
1 change: 1 addition & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ install_requires =
aiohttp
pycryptodomex
apscheduler
aio_pika

[options.packages.find]
where = .
Loading