Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions astrbot/api/event/filter/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
register_llm_tool as llm_tool,
register_on_decorating_result as on_decorating_result,
register_after_message_sent as after_message_sent,
register_on_star_activated as on_star_activated,
register_on_star_deactivated as on_star_deactivated,
)

from astrbot.core.star.filter.event_message_type import (
Expand Down Expand Up @@ -46,4 +48,6 @@
"on_decorating_result",
"after_message_sent",
"on_llm_response",
"on_star_activated",
"on_star_deactivated",
]
4 changes: 4 additions & 0 deletions astrbot/core/star/register/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
register_agent,
register_on_decorating_result,
register_after_message_sent,
register_on_star_activated,
register_on_star_deactivated,
)

__all__ = [
Expand All @@ -32,4 +34,6 @@
"register_agent",
"register_on_decorating_result",
"register_after_message_sent",
"register_on_star_activated",
"register_on_star_deactivated",
]
28 changes: 28 additions & 0 deletions astrbot/core/star/register/star_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -450,3 +450,31 @@ def decorator(awaitable):
return awaitable

return decorator


def register_on_star_activated(star_name: str = None, **kwargs):
"""当指定插件被激活时"""

def decorator(awaitable):
handler_md = get_handler_or_create(
awaitable, EventType.OnStarActivatedEvent, **kwargs
)
if star_name:
handler_md.extras_configs["target_star_name"] = star_name
return awaitable

return decorator


def register_on_star_deactivated(star_name: str = None, **kwargs):
"""当指定插件被停用时"""

def decorator(awaitable):
handler_md = get_handler_or_create(
awaitable, EventType.OnStarDeactivatedEvent, **kwargs
)
if star_name:
handler_md.extras_configs["target_star_name"] = star_name
return awaitable

return decorator
2 changes: 2 additions & 0 deletions astrbot/core/star/star.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ class StarMetadata:
"""插件版本"""
repo: str | None = None
"""插件仓库地址"""
dependencies: list[str] = field(default_factory=list)
"""插件依赖列表"""

star_cls_type: type[Star] | None = None
"""插件的类对象的类型"""
Expand Down
2 changes: 2 additions & 0 deletions astrbot/core/star/star_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,8 @@ class EventType(enum.Enum):
OnCallingFuncToolEvent = enum.auto() # 调用函数工具
OnAfterMessageSentEvent = enum.auto() # 发送消息后

OnStarActivatedEvent = enum.auto() # 插件启用
OnStarDeactivatedEvent = enum.auto() # 插件禁用

@dataclass
class StarHandlerMetadata:
Expand Down
210 changes: 175 additions & 35 deletions astrbot/core/star/star_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@
from .star import star_map, star_registry
from .star_handler import star_handlers_registry
from .updator import PluginUpdator
from .star_handler import EventType, StarHandlerMetadata
import networkx as nx

try:
from watchfiles import PythonFilter, awatch
Expand Down Expand Up @@ -144,13 +146,11 @@ def _get_modules(path):
if os.path.exists(os.path.join(path, d, "main.py")) or os.path.exists(
os.path.join(path, d, d + ".py")
):
modules.append(
{
"pname": d,
"module": module_str,
"module_path": os.path.join(path, d, module_str),
}
)
modules.append({
"pname": d,
"module": module_str,
"module_path": os.path.join(path, d, module_str),
})
return modules

def _get_plugin_modules(self) -> list[dict]:
Expand Down Expand Up @@ -226,6 +226,7 @@ def _load_plugin_metadata(plugin_path: str, plugin_obj=None) -> StarMetadata | N
desc=metadata["desc"],
version=metadata["version"],
repo=metadata["repo"] if "repo" in metadata else None,
dependencies=metadata.get("dependencies", []),
)

return metadata
Expand Down Expand Up @@ -321,25 +322,17 @@ async def reload(self, specified_plugin_name=None):
star_handlers_registry.clear()
star_map.clear()
star_registry.clear()
plugin_modules = await self._get_load_order()
result = await self.load(plugin_modules=plugin_modules)
else:
# 只重载指定插件
smd = star_map.get(specified_module_path)
if smd:
try:
await self._terminate_plugin(smd)
except Exception as e:
logger.warning(traceback.format_exc())
logger.warning(
f"插件 {smd.name} 未被正常终止: {str(e)}, 可能会导致该插件运行不正常。"
)
if smd.name:
await self._unbind_plugin(smd.name, specified_module_path)

result = await self.load(specified_module_path)
result = await self.batch_reload(
specified_module_path=specified_module_path
)

return result

async def load(self, specified_module_path=None, specified_dir_name=None):
async def load(self, plugin_modules=None):
"""载入插件。
当 specified_module_path 或者 specified_dir_name 不为 None 时,只载入指定的插件。

Expand All @@ -356,10 +349,11 @@ async def load(self, specified_module_path=None, specified_dir_name=None):
inactivated_llm_tools = await sp.global_get("inactivated_llm_tools", [])
alter_cmd = await sp.global_get("alter_cmd", {})

plugin_modules = self._get_plugin_modules()
if plugin_modules is None:
return False, "未找到任何插件模块"

logger.info(
f"正在按顺序加载插件: {[plugin_module['pname'] for plugin_module in plugin_modules]}"
)
fail_rec = ""

# 导入插件模块,并尝试实例化插件类
Expand All @@ -375,12 +369,6 @@ async def load(self, specified_module_path=None, specified_dir_name=None):
path = "data.plugins." if not reserved else "packages."
path += root_dir_name + "." + module_str

# 检查是否需要载入指定的插件
if specified_module_path and path != specified_module_path:
continue
if specified_dir_name and root_dir_name != specified_dir_name:
continue

logger.info(f"正在载入插件 {root_dir_name} ...")

# 尝试导入模块
Expand Down Expand Up @@ -451,6 +439,9 @@ async def load(self, specified_module_path=None, specified_dir_name=None):
metadata.star_cls = metadata.star_cls_type(
context=self.context
)
await self._trigger_star_lifecycle_event(
EventType.OnStarActivatedEvent, metadata
)
else:
logger.info(f"插件 {metadata.name} 已被禁用。")

Expand Down Expand Up @@ -622,7 +613,8 @@ async def install_plugin(self, repo_url: str, proxy=""):
plugin_path = await self.updator.install(repo_url, proxy)
# reload the plugin
dir_name = os.path.basename(plugin_path)
await self.load(specified_dir_name=dir_name)
plugin_modules = await self._get_load_order(specified_dir_name=dir_name)
await self.batch_reload(plugin_modules=plugin_modules)

# Get the plugin metadata to return repo info
plugin = self.context.get_registered_star(dir_name)
Expand Down Expand Up @@ -778,8 +770,7 @@ async def turn_off_plugin(self, plugin_name: str):

plugin.activated = False

@staticmethod
async def _terminate_plugin(star_metadata: StarMetadata):
async def _terminate_plugin(self, star_metadata: StarMetadata):
"""终止插件,调用插件的 terminate() 和 __del__() 方法"""
logger.info(f"正在终止插件 {star_metadata.name} ...")

Expand All @@ -788,14 +779,18 @@ async def _terminate_plugin(star_metadata: StarMetadata):
logger.debug(f"插件 {star_metadata.name} 未被激活,不需要终止,跳过。")
return

await self._trigger_star_lifecycle_event(
EventType.OnStarDeactivatedEvent, star_metadata
)

if star_metadata.star_cls is None:
return

if '__del__' in star_metadata.star_cls_type.__dict__:
if "__del__" in star_metadata.star_cls_type.__dict__:
asyncio.get_event_loop().run_in_executor(
None, star_metadata.star_cls.__del__
)
elif 'terminate' in star_metadata.star_cls_type.__dict__:
elif "terminate" in star_metadata.star_cls_type.__dict__:
await star_metadata.star_cls.terminate()

async def turn_on_plugin(self, plugin_name: str):
Expand Down Expand Up @@ -832,7 +827,8 @@ async def install_plugin_from_file(self, zip_file_path: str):
except BaseException as e:
logger.warning(f"删除插件压缩包失败: {str(e)}")
# await self.reload()
await self.load(specified_dir_name=dir_name)
plugin_modules = await self._get_load_order(specified_dir_name=dir_name)
await self.batch_reload(plugin_modules=plugin_modules)

# Get the plugin metadata to return repo info
plugin = self.context.get_registered_star(dir_name)
Expand Down Expand Up @@ -865,3 +861,147 @@ async def install_plugin_from_file(self, zip_file_path: str):
}

return plugin_info

async def _trigger_star_lifecycle_event(
self, event_type: EventType, star_metadata: StarMetadata
):
"""
内部辅助函数,用于触发插件(Star)相关的生命周期事件。
Args:
event_type: 要触发的事件类型 (EventType.OnStarActivatedEvent 或 EventType.OnStarDeactivatedEvent)。
star_metadata: 触发事件的插件的 StarMetadata 对象。
"""
handlers_to_run: list[StarHandlerMetadata] = []
# 获取所有监听该事件类型的 handlers
handlers = star_handlers_registry.get_handlers_by_event_type(event_type)

for handler in handlers:
# 检查这个 handler 是否监听了特定的插件名
target_star_name = handler.extras_configs.get("target_star_name")
if target_star_name and target_star_name == star_metadata.name:
# 如果指定了目标插件名,则只在匹配时添加
handlers_to_run.append(handler)

for handler in handlers_to_run:
try:
# 调用插件的钩子函数,并传入 StarMetadata 对象
logger.info(
f"hook({event_type.name}) -> {star_map[handler.handler_module_path].name} - {handler.handler_name} (目标插件: {star_metadata.name})"
)
await handler.handler(star_metadata) # 传递参数
except Exception:
logger.error(
f"执行插件 {handler.handler_name} 的 {event_type.name} 钩子时出错: {traceback.format_exc()}"
)

def _get_plugin_dir_path(self, root_dir_name: str, is_reserved: bool) -> str:
"""根据插件的根目录名和是否为保留插件,返回插件的完整文件路径。"""
return (
os.path.join(self.plugin_store_path, root_dir_name)
if not is_reserved
else os.path.join(self.reserved_plugin_path, root_dir_name)
)

def _build_module_path(self, plugin_module_info: dict) -> str:
"""根据插件模块信息构建完整的模块路径。"""
reserved = plugin_module_info.get("reserved", False)
path_prefix = "packages." if reserved else "data.plugins."
return (
f"{path_prefix}{plugin_module_info['pname']}.{plugin_module_info['module']}"
)

async def _get_load_order(
self, specified_dir_name: str = None, specified_module_path: str = None
):
star_graph = self._build_star_graph()
if star_graph is None:
return None
try:
if specified_dir_name:
for node in star_graph:
if (
star_graph.nodes[node]["data"].get("pname")
== specified_dir_name
):
dependent_nodes = nx.descendants(star_graph, node)
sub_graph = star_graph.subgraph(dependent_nodes.union({node}))
load_order = list(nx.topological_sort(sub_graph))
return [star_graph.nodes[node]["data"] for node in load_order]
elif specified_module_path:
for node in star_graph:
if specified_module_path == self._build_module_path(
star_graph.nodes[node].get("data")
):
dependent_nodes = nx.descendants(star_graph, node)
sub_graph = star_graph.subgraph(dependent_nodes.union({node}))
load_order = list(nx.topological_sort(sub_graph))
return [star_graph.nodes[node]["data"] for node in load_order]
else:
sorted_nodes = list(nx.topological_sort(star_graph))

reserved_plugins = [
star_graph.nodes[node]["data"]
for node in sorted_nodes
if star_graph.nodes[node]["data"].get("reserved", False)
]
non_reserved_plugins = [
star_graph.nodes[node]["data"]
for node in sorted_nodes
if not star_graph.nodes[node]["data"].get("reserved", False)
]

return reserved_plugins + non_reserved_plugins

except nx.NetworkXUnfeasible:
logger.error("出现循环依赖,无法确定加载顺序,按自然顺序加载")
return [star_graph.nodes[node]["data"] for node in star_graph]

def _build_star_graph(self):
plugin_modules = self._get_plugin_modules()
if plugin_modules is None:
return None
G = nx.DiGraph()
for plugin_module in plugin_modules:
root_dir_name = plugin_module["pname"]
is_reserved = plugin_module.get("reserved", False)
plugin_dir_path = self._get_plugin_dir_path(root_dir_name, is_reserved)
G.add_node(root_dir_name, data=plugin_module)
try:
metadata = self._load_plugin_metadata(plugin_dir_path)
if metadata:
for dep_name in metadata.dependencies:
G.add_edge(root_dir_name, dep_name)
except Exception:
pass
# 过滤不存在的依赖(出边没有data, 就删除指向的节点)
nodes_to_remove = []
for node_name in list(G.nodes()):
for neighbor in list(G.neighbors(node_name)):
if G.nodes[neighbor].get("data") is None:
nodes_to_remove.append(neighbor)
logger.warning(
f"插件 {node_name} 声明依赖 {neighbor}, 但该插件未被发现,跳过加载。"
)
for node in nodes_to_remove:
G.remove_node(node)
return G

async def batch_reload(self, specified_module_path=None, plugin_modules=None):
if not plugin_modules:
plugin_modules = await self._get_load_order(
specified_module_path=specified_module_path
)
for plugin_module in plugin_modules:
specified_module_path = self._build_module_path(plugin_module)
smd = star_map.get(specified_module_path)
if smd:
try:
await self._terminate_plugin(smd)
except Exception as e:
logger.warning(traceback.format_exc())
logger.warning(
f"插件 {smd.name} 未被正常终止: {str(e)}, 可能会导致该插件运行不正常。"
)
await self._unbind_plugin(smd.name, specified_module_path)

return await self.load(plugin_modules=plugin_modules)
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ dependencies = [
"watchfiles>=1.0.5",
"websockets>=15.0.1",
"wechatpy>=1.8.18",
"networkx>=3.4.2",
]

[project.scripts]
Expand Down
Loading