Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions swanlab/plugin/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,5 @@
"SlackCallback",
"LogdirFileWriter",
"BarkCallback",
"TelegramCallback",
]
205 changes: 196 additions & 9 deletions swanlab/plugin/notification.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,11 @@
class PrintCallback(SwanKitCallback):
"""Basic callback for printing experiment status information."""

def on_init(self, proj_name: str, workspace: str, logdir: str = None, *args, **kwargs):
def on_init(self, proj_name: str, workspace: str, logdir: Optional[str] = None, *args, **kwargs):
"""Called when experiment initialization completes."""
print(f"🚀 My callback: on_init: {proj_name}, {workspace}, {logdir}, {kwargs}")

def on_stop(self, error: str = None, *args, **kwargs):
def on_stop(self, error: Optional[str] = None, *args, **kwargs):
"""Called when experiment stops or encounters error."""
status = f"with error: {error}" if error else "successfully"
print(f"🚀 My callback: Experiment stopped {status}")
Expand Down Expand Up @@ -81,7 +81,7 @@ def __init__(
self.port = port
self.language = language

def _create_email_content(self, error: Optional[str] = None) -> Dict[str, str]:
def _create_email_content(self, error: Optional[str] = None) -> Tuple[str, str]:
"""Generate bilingual email content based on experiment status."""
templates = self.DEFAULT_TEMPLATES[self.language]

Expand Down Expand Up @@ -125,7 +125,7 @@ def send_email(self, subject: str, body: str) -> None:
except smtplib.SMTPException as e:
print(f"❌ Email sending failed: {str(e)}")

def on_init(self, proj_name: str, workspace: str, logdir: str = None, *args, **kwargs):
def on_init(self, proj_name: str, workspace: str, logdir: Optional[str] = None, *args, **kwargs):
self.project = proj_name
self.workspace = workspace

Expand All @@ -142,7 +142,7 @@ def before_init_experiment(
self.exp_name = exp_name
self.description = description

def on_stop(self, error: str = None, *args, **kwargs):
def on_stop(self, error: Optional[str] = None, *args, **kwargs):
"""Trigger email notification when experiment stops."""
print("📧 Preparing email notification...")
subject, body = self._create_email_content(error)
Expand Down Expand Up @@ -206,7 +206,7 @@ def send_msg(self, content: str) -> None:
"""发送消息的具体实现"""
pass

def on_init(self, proj_name: str, workspace: str, logdir: str = None, *args, **kwargs):
def on_init(self, proj_name: str, workspace: str, logdir: Optional[str] = None, *args, **kwargs):
self.project = proj_name
self.workspace = workspace

Expand All @@ -223,7 +223,7 @@ def before_init_experiment(
self.exp_name = exp_name
self.description = description

def on_stop(self, error: str = None, *args, **kwargs):
def on_stop(self, error: Optional[str] = None, *args, **kwargs):
print(f"🤖 Preparing {self.__class__.__name__} notification...")
content = self._create_content(error)
self.send_msg(content)
Expand Down Expand Up @@ -563,7 +563,7 @@ def send_notification(self, data: dict):
print("✅ Bark sending successfully")


def on_init(self, proj_name: str, workspace: str, public: bool = None, logdir: str = None, *args, **kwargs):
def on_init(self, proj_name: str, workspace: str, public: Optional[bool] = None, logdir: Optional[str] = None, *args, **kwargs):
self.project = proj_name
self.workspace = workspace

Expand All @@ -580,6 +580,193 @@ def before_init_experiment(
self.exp_name = exp_name
self.description = description

def on_stop(self, error: str = None, *args, **kwargs):
def on_stop(self, error: Optional[str] = None, *args, **kwargs):
content = self._create_notification_message(error)
self.send_notification(content)


class TelegramBot:
"""
Telegram Bot notification helper.
docs: https://core.telegram.org/bots/api#sendmessage
"""

def __init__(self, bot_token: str, chat_id: str):
"""
Initialize Telegram Bot.

:param bot_token: Telegram Bot API token (get from @BotFather)
:param chat_id: Target chat ID (can be user ID, group ID, or channel username)
"""
self.bot_token = bot_token
self.chat_id = chat_id
self.api_base = f"https://api.telegram.org/bot{bot_token}"

def send_message(self, text: str, parse_mode: str = "HTML") -> Dict[str, Any]:
"""
Send a message via Telegram Bot API.

:param text: Message text
:param parse_mode: Parse mode for formatting (HTML or Markdown)
:return: API response
"""
url = f"{self.api_base}/sendMessage"
data = {
"chat_id": self.chat_id,
"text": text,
"parse_mode": parse_mode,
}
resp = requests.post(url, json=data)
resp.raise_for_status()
return resp.json()


class TelegramCallback(SwanKitCallback):
"""
Telegram notification callback with bilingual support.
Send notifications to Telegram chat when experiment starts/stops.

Usage:
1. Create a bot via @BotFather and get the bot token
2. Get your chat ID (send /start to @userinfobot)
3. Initialize the callback:

```python
from swanlab.plugin import TelegramCallback

telegram = TelegramCallback(
bot_token="YOUR_BOT_TOKEN",
chat_id="YOUR_CHAT_ID",
language="zh" # or "en"
)
swanlab.init(callbacks=[telegram])
```
"""

DEFAULT_TEMPLATES = {
"en": {
"title": "🧪 <b>SwanLab Notification</b>\n\n",
"msg_start": "🚀 Experiment started\n",
"msg_success": "✅ Experiment completed successfully\n",
"msg_error": "❌ Experiment failed: {error}\n",
"link_text": (
"<b>Project:</b> {project}\n"
"<b>Workspace:</b> {workspace}\n"
"<b>Name:</b> {exp_name}\n"
"<b>Description:</b> {description}\n"
"<b>Link:</b> {link}"
),
"offline_text": "📴 Running in offline mode",
},
"zh": {
"title": "🧪 <b>SwanLab 消息通知</b>\n\n",
"msg_start": "🚀 实验已开始\n",
"msg_success": "✅ 实验已成功完成\n",
"msg_error": "❌ 实验遇到错误: {error}\n",
"link_text": (
"<b>项目:</b> {project}\n"
"<b>工作区:</b> {workspace}\n"
"<b>实验名:</b> {exp_name}\n"
"<b>描述:</b> {description}\n"
"<b>链接:</b> {link}"
),
"offline_text": "📴 离线模式运行中",
},
}

def __init__(
self,
bot_token: str,
chat_id: str,
language: str = "zh",
notify_on_start: bool = False,
):
"""
Initialize Telegram callback configuration.

:param bot_token: Telegram Bot API token (get from @BotFather)
:param chat_id: Target chat ID (user ID, group ID, or @channel_username)
:param language: Notification language (en/zh)
:param notify_on_start: Whether to send notification when experiment starts
"""
self.bot = TelegramBot(bot_token, chat_id)
self.language = language
self.notify_on_start = notify_on_start

def _create_content(self, event: str = "stop", error: Optional[str] = None) -> str:
"""
Create notification content based on event type.

:param event: Event type ("start" or "stop")
:param error: Error message if experiment failed
:return: Formatted message text
"""
templates = self.DEFAULT_TEMPLATES[self.language]
content = templates["title"]

if event == "start":
content += templates["msg_start"]
elif error:
content += templates["msg_error"].format(error=error)
else:
content += templates["msg_success"]

exp_link = swanlab.get_url()
if exp_link:
content += templates["link_text"].format(
project=self.project,
workspace=self.workspace,
exp_name=self.exp_name,
description=self.description or "N/A",
link=exp_link,
)
else:
content += templates["offline_text"]

return content

def send_msg(self, content: str) -> None:
"""Send message via Telegram Bot."""
try:
result = self.bot.send_message(content)
if result.get("ok"):
print("✅ Telegram message sent successfully")
else:
print(f"❌ Telegram sending failed: {result.get('description')}")
except requests.RequestException as e:
print(f"❌ Telegram sending failed: {str(e)}")

def on_init(self, proj_name: str, workspace: str, public: Optional[bool] = None, logdir: Optional[str] = None, *args, **kwargs):
"""Called when experiment initialization completes."""
self.project = proj_name
self.workspace = workspace

def before_init_experiment(
self,
run_id: str,
exp_name: str,
description: str,
colors: Tuple[str, str],
*args,
**kwargs,
):
"""Called before experiment initialization."""
self.run_id = run_id
self.exp_name = exp_name
self.description = description

def on_run(self, *args, **kwargs):
"""Called when experiment starts running."""
if self.notify_on_start:
print("📱 Preparing Telegram start notification...")
content = self._create_content(event="start")
self.send_msg(content)

def on_stop(self, error: Optional[str] = None, *args, **kwargs):
"""Called when experiment stops."""
print("📱 Preparing Telegram notification...")
content = self._create_content(event="stop", error=error)
self.send_msg(content)

def __str__(self) -> str:
return "TelegramCallback"
47 changes: 47 additions & 0 deletions test/plugin/notification_telegram.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
"""
@author: yugangcao
@file: notification_telegram.py
@time: 2025/12/05
@description: 使用Telegram通知实验运行情况
"""

import argparse
import random
import time

import swanlab
from swanlab.plugin import TelegramCallback

parser = argparse.ArgumentParser(description="测试Telegram通知功能")
parser.add_argument('--api-key', type=str, default=None, help="SwanLab的API Key")
parser.add_argument('--host', type=str, default=None, help="SwanLab的服务器地址")
parser.add_argument("--bot-token", type=str, required=True, help="Telegram Bot Token (从@BotFather获取)")
parser.add_argument("--chat-id", type=str, required=True, help="Telegram Chat ID (从@userinfobot获取)")
parser.add_argument("--language", type=str, default="zh", choices=["zh", "en"], help="通知语言")
parser.add_argument("--notify-on-start", action="store_true", help="实验开始时也发送通知")
args = parser.parse_args()

if args.api_key:
swanlab.login(api_key=args.api_key, host=args.host)
# 集成TelegramCallback
swanlab.register_callbacks([TelegramCallback(
bot_token=args.bot_token,
chat_id=args.chat_id,
language=args.language,
notify_on_start=args.notify_on_start,
)])

epochs = 50
lr = 0.01
offset = random.random() / 5
# 初始化
swanlab.init(description="测试Telegram通知功能", mode="cloud")
swanlab.config.epochs = epochs
swanlab.config.learning_rate = lr
# 模拟训练
for epoch in range(2, swanlab.config.epochs):
acc = 1 - 2**-epoch - random.random() / epoch - offset
loss = 2**-epoch + random.random() / epoch + offset
print(f"epoch={epoch}, accuracy={acc}, loss={loss}")
swanlab.log({"t/accuracy": acc, "loss": loss})
time.sleep(0.5)
Loading