From c6601aaeed6b1e54914e8fd16b0cb1bc3605d92d Mon Sep 17 00:00:00 2001 From: lanvent Date: Tue, 25 Apr 2023 01:04:03 +0800 Subject: [PATCH] fix: ensure get access_token thread-safe --- channel/wechatcom/README.md | 2 +- channel/wechatcom/wechatcomapp_channel.py | 6 +++--- channel/wechatcom/wechatcomapp_client.py | 21 +++++++++++++++++++++ channel/wechatcom/wechatcomapp_message.py | 21 ++++----------------- 4 files changed, 29 insertions(+), 21 deletions(-) create mode 100644 channel/wechatcom/wechatcomapp_client.py diff --git a/channel/wechatcom/README.md b/channel/wechatcom/README.md index 5eea6882e..1728678b8 100644 --- a/channel/wechatcom/README.md +++ b/channel/wechatcom/README.md @@ -54,4 +54,4 @@ AIGC开放社区中已经部署了多个可免费使用的Bot,扫描下方的二维码会自动邀请你来体验。 - \ No newline at end of file + \ No newline at end of file diff --git a/channel/wechatcom/wechatcomapp_channel.py b/channel/wechatcom/wechatcomapp_channel.py index 6959b9e3a..bd51c5fff 100644 --- a/channel/wechatcom/wechatcomapp_channel.py +++ b/channel/wechatcom/wechatcomapp_channel.py @@ -1,4 +1,3 @@ -#!/usr/bin/env python # -*- coding=utf-8 -*- import io import os @@ -6,7 +5,7 @@ import requests import web -from wechatpy.enterprise import WeChatClient, create_reply, parse_message +from wechatpy.enterprise import create_reply, parse_message from wechatpy.enterprise.crypto import WeChatCrypto from wechatpy.enterprise.exceptions import InvalidCorpIdException from wechatpy.exceptions import InvalidSignatureException, WeChatClientException @@ -14,6 +13,7 @@ from bridge.context import Context from bridge.reply import Reply, ReplyType from channel.chat_channel import ChatChannel +from channel.wechatcom.wechatcomapp_client import WechatComAppClient from channel.wechatcom.wechatcomapp_message import WechatComAppMessage from common.log import logger from common.singleton import singleton @@ -38,7 +38,7 @@ def __init__(self): "[wechatcom] init: corp_id: {}, secret: {}, agent_id: {}, token: {}, aes_key: {}".format(self.corp_id, self.secret, self.agent_id, self.token, self.aes_key) ) self.crypto = WeChatCrypto(self.token, self.aes_key, self.corp_id) - self.client = WeChatClient(self.corp_id, self.secret) # todo: 这里可能有线程安全问题 + self.client = WechatComAppClient(self.corp_id, self.secret) def startup(self): # start message listener diff --git a/channel/wechatcom/wechatcomapp_client.py b/channel/wechatcom/wechatcomapp_client.py new file mode 100644 index 000000000..c0feb7a18 --- /dev/null +++ b/channel/wechatcom/wechatcomapp_client.py @@ -0,0 +1,21 @@ +import threading +import time + +from wechatpy.enterprise import WeChatClient + + +class WechatComAppClient(WeChatClient): + def __init__(self, corp_id, secret, access_token=None, session=None, timeout=None, auto_retry=True): + super(WechatComAppClient, self).__init__(corp_id, secret, access_token, session, timeout, auto_retry) + self.fetch_access_token_lock = threading.Lock() + + def fetch_access_token(self): # 重载父类方法,加锁避免多线程重复获取access_token + with self.fetch_access_token_lock: + access_token = self.session.get(self.access_token_key) + if access_token: + if not self.expires_at: + return access_token + timestamp = time.time() + if self.expires_at - timestamp > 60: + return access_token + return super().fetch_access_token() diff --git a/channel/wechatcom/wechatcomapp_message.py b/channel/wechatcom/wechatcomapp_message.py index f441a68a5..a70f7556e 100644 --- a/channel/wechatcom/wechatcomapp_message.py +++ b/channel/wechatcom/wechatcomapp_message.py @@ -1,14 +1,9 @@ -import re - -import requests from wechatpy.enterprise import WeChatClient from bridge.context import ContextType from channel.chat_message import ChatMessage from common.log import logger from common.tmp_dir import TmpDir -from lib import itchat -from lib.itchat.content import * class WechatComAppMessage(ChatMessage): @@ -23,9 +18,7 @@ def __init__(self, msg, client: WeChatClient, is_group=False): self.content = msg.content elif msg.type == "voice": self.ctype = ContextType.VOICE - self.content = ( - TmpDir().path() + msg.media_id + "." + msg.format - ) # content直接存临时目录路径 + self.content = TmpDir().path() + msg.media_id + "." + msg.format # content直接存临时目录路径 def download_voice(): # 如果响应状态码是200,则将响应内容写入本地文件 @@ -34,9 +27,7 @@ def download_voice(): with open(self.content, "wb") as f: f.write(response.content) else: - logger.info( - f"[wechatcom] Failed to download voice file, {response.content}" - ) + logger.info(f"[wechatcom] Failed to download voice file, {response.content}") self._prepare_fn = download_voice elif msg.type == "image": @@ -50,15 +41,11 @@ def download_image(): with open(self.content, "wb") as f: f.write(response.content) else: - logger.info( - f"[wechatcom] Failed to download image file, {response.content}" - ) + logger.info(f"[wechatcom] Failed to download image file, {response.content}") self._prepare_fn = download_image else: - raise NotImplementedError( - "Unsupported message type: Type:{} ".format(msg.type) - ) + raise NotImplementedError("Unsupported message type: Type:{} ".format(msg.type)) self.from_user_id = msg.source self.to_user_id = msg.target