|  | 
|  | 1 | +import asyncio | 
|  | 2 | +import hashlib | 
|  | 3 | +import json | 
|  | 4 | +import logging | 
|  | 5 | +import zlib | 
|  | 6 | + | 
|  | 7 | +import aio_pika | 
|  | 8 | +from Cryptodome.Cipher import AES | 
|  | 9 | +from Cryptodome.Util import Padding | 
|  | 10 | +from aio_pika.abc import AbstractRobustConnection, AbstractRobustChannel, AbstractExchange, AbstractRobustQueue | 
|  | 11 | + | 
|  | 12 | +from .interface import AsyncRunnable | 
|  | 13 | + | 
|  | 14 | +log = logging.getLogger(__name__) | 
|  | 15 | + | 
|  | 16 | + | 
|  | 17 | +class RabbitMQ: | 
|  | 18 | +    """rabbitmq configurate/init/connect/encrypt/decrypt/encode/decode""" | 
|  | 19 | + | 
|  | 20 | +    _connection: AbstractRobustConnection = None | 
|  | 21 | +    _channel: AbstractRobustChannel = None | 
|  | 22 | +    _queue: AbstractRobustQueue = None | 
|  | 23 | +    _exchange: AbstractExchange = None | 
|  | 24 | + | 
|  | 25 | +    def __init__(self, host: str, port: int, queue: str, qos: int, heartbeat: int, login: str = '', password: str = '', | 
|  | 26 | +                 key: str = '', key_digits: int = 16, is_producer: bool = False): | 
|  | 27 | +        self._host = host | 
|  | 28 | +        self._port = port | 
|  | 29 | +        self.queue = queue | 
|  | 30 | +        self.qos = qos | 
|  | 31 | +        self._heartbeat = heartbeat | 
|  | 32 | +        self._login = login | 
|  | 33 | +        self._password = password | 
|  | 34 | +        self._key = key | 
|  | 35 | +        self._key_digits = key_digits | 
|  | 36 | +        self.is_producer = is_producer | 
|  | 37 | + | 
|  | 38 | +        # check aes is 128, 192 or 256 | 
|  | 39 | +        if key_digits not in AES.key_size: | 
|  | 40 | +            raise ValueError(f'rabbitmq key_digits: {key_digits} not in {AES.key_size}') | 
|  | 41 | +        if key != '': | 
|  | 42 | +            key_encoded = key.encode('utf-8').ljust(key_digits, b'\x00') | 
|  | 43 | +        else: | 
|  | 44 | +            # if rabbitmq_key is not defined, use sha256 to generate one | 
|  | 45 | +            key_encoded = hashlib.sha256(f'{login}:{password}'.encode('utf-8')).digest() | 
|  | 46 | + | 
|  | 47 | +        # make sure key digits is right | 
|  | 48 | +        self._aes_key = key_encoded[:key_digits] | 
|  | 49 | + | 
|  | 50 | +    def decrypt(self, data: bytes) -> bytes: | 
|  | 51 | +        """ decrypt data | 
|  | 52 | +
 | 
|  | 53 | +        :param data: encrypted byte array | 
|  | 54 | +        :return: decrypted byte array | 
|  | 55 | +        """ | 
|  | 56 | +        decipher = AES.new(self._aes_key, AES.MODE_CBC, iv=data[:16]) | 
|  | 57 | +        data = decipher.decrypt(data[16:]) | 
|  | 58 | +        data = Padding.unpad(data, 16) | 
|  | 59 | +        return data | 
|  | 60 | + | 
|  | 61 | +    def encrypt(self, data: bytes) -> bytes: | 
|  | 62 | +        """ encrypt data | 
|  | 63 | +
 | 
|  | 64 | +        :param data: byte array | 
|  | 65 | +        :return: encrypted byte array | 
|  | 66 | +        """ | 
|  | 67 | +        data = Padding.pad(data, 16) | 
|  | 68 | +        cipher = AES.new(self._aes_key, AES.MODE_CBC) | 
|  | 69 | +        data = cipher.encrypt(data) | 
|  | 70 | +        return cipher.iv + data | 
|  | 71 | + | 
|  | 72 | +    def decode(self, data: bytes, compress: bool) -> dict: | 
|  | 73 | +        """decode raw rabbitmq data into plaintext data""" | 
|  | 74 | +        data = self.decrypt(data) | 
|  | 75 | +        if compress: | 
|  | 76 | +            data = zlib.decompress(data) | 
|  | 77 | +        return json.loads(str(data, encoding='utf-8')) | 
|  | 78 | + | 
|  | 79 | +    def encode(self, data: dict, compress: bool) -> bytes: | 
|  | 80 | +        """encode pkg into rabbitmq data""" | 
|  | 81 | +        data = json.dumps(data).encode(encoding='utf-8') | 
|  | 82 | +        if compress: | 
|  | 83 | +            data = zlib.compress(data) | 
|  | 84 | +        data = self.encrypt(data) | 
|  | 85 | +        return data | 
|  | 86 | + | 
|  | 87 | +    async def get_connection(self) -> AbstractRobustConnection: | 
|  | 88 | +        """get rabbitmq connection""" | 
|  | 89 | +        if self._connection is None: | 
|  | 90 | +            self._connection = await aio_pika.connect_robust( | 
|  | 91 | +                host=self._host, | 
|  | 92 | +                port=self._port, | 
|  | 93 | +                login=self._login, | 
|  | 94 | +                password=self._password, | 
|  | 95 | +                heartbeat=self._heartbeat | 
|  | 96 | +            ) | 
|  | 97 | +            await self._connection.connect() | 
|  | 98 | +        return self._connection | 
|  | 99 | + | 
|  | 100 | +    async def get_channel(self) -> AbstractRobustChannel: | 
|  | 101 | +        """get rabbitmq channel""" | 
|  | 102 | +        if self._channel is None: | 
|  | 103 | +            connection = await self.get_connection() | 
|  | 104 | +            self._channel = await connection.channel() | 
|  | 105 | +        return self._channel | 
|  | 106 | + | 
|  | 107 | +    async def get_queue(self) -> AbstractRobustQueue: | 
|  | 108 | +        """get rabbitmq queue""" | 
|  | 109 | +        if self._queue is None: | 
|  | 110 | +            channel = await self.get_channel() | 
|  | 111 | +            self._queue = await channel.declare_queue(self.queue) | 
|  | 112 | +        return self._queue | 
|  | 113 | + | 
|  | 114 | +    async def get_exchange(self) -> AbstractExchange: | 
|  | 115 | +        """get rabbitmq default exchange""" | 
|  | 116 | +        if self._exchange is None: | 
|  | 117 | +            channel = await self.get_channel() | 
|  | 118 | +            self._exchange = channel.default_exchange | 
|  | 119 | +        return self._exchange | 
|  | 120 | + | 
|  | 121 | + | 
|  | 122 | +class RabbitMQProducer(AsyncRunnable): | 
|  | 123 | +    """produce data to rabbitmq""" | 
|  | 124 | + | 
|  | 125 | +    _exchange: AbstractExchange | 
|  | 126 | +    _connection: AbstractRobustConnection | 
|  | 127 | + | 
|  | 128 | +    def __init__(self, rabbitmq: RabbitMQ, compress: bool): | 
|  | 129 | +        super().__init__() | 
|  | 130 | +        self._rabbitmq = rabbitmq | 
|  | 131 | +        self._compress = compress | 
|  | 132 | + | 
|  | 133 | +    async def start(self): | 
|  | 134 | +        self._exchange = await self._rabbitmq.get_exchange() | 
|  | 135 | +        while True: | 
|  | 136 | +            await asyncio.sleep(3600)  # sleep forever | 
|  | 137 | + | 
|  | 138 | +    async def publish(self, data: dict): | 
|  | 139 | +        """produce data""" | 
|  | 140 | +        await self._exchange.publish(aio_pika.Message(body=self._rabbitmq.encode(data, self._compress)), | 
|  | 141 | +                                     routing_key=self._rabbitmq.queue) | 
0 commit comments