17 changes: 17 additions & 0 deletions enhancers/__init__.py
@@ -0,0 +1,17 @@
from .enhancer_registry import (
ENHANCER_REGISTRY,
get_enhancer,
get_enhancer_class,
get_supported_modalities,
list_registered_enhancers,
register_enhancer,
)

__all__ = [
"ENHANCER_REGISTRY",
"get_enhancer",
"get_enhancer_class",
"get_supported_modalities",
"list_registered_enhancers",
"register_enhancer",
]
32 changes: 32 additions & 0 deletions enhancers/base_models.py
@@ -0,0 +1,32 @@
from abc import ABC, abstractmethod

from parsers.base_models import ChunkData


class InformationEnhancer(ABC):
"""信息增强器基类"""
@abstractmethod
async def enhance(self, information: ChunkData) -> ChunkData:
"""增强信息"""
pass

class TableInformationEnhancer(InformationEnhancer):
"""表格信息增强器"""

async def enhance(self, information: ChunkData) -> ChunkData:
"""增强信息"""
return information

class FormulasInformationEnhancer(InformationEnhancer):
"""公式信息增强器"""

async def enhance(self, information: ChunkData) -> ChunkData:
"""增强信息"""
return information

class ImageInformationEnhancer(InformationEnhancer):
"""图片信息增强器"""

async def enhance(self, information: ChunkData) -> ChunkData:
"""增强信息"""
return information
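
The enhancers above are pass-through stubs. For context only, a concrete subclass would override enhance() and return a modified ChunkData. The following sketch is hypothetical and assumes ChunkData exposes a mutable content string field, which is not shown in this diff:

# Hypothetical example (not part of this PR): an enhancer that appends a
# crude summary to a table chunk. Assumes ChunkData has a `content` str field.
class SummarizingTableEnhancer(InformationEnhancer):

    async def enhance(self, information: ChunkData) -> ChunkData:
        raw = information.content or ""
        # Placeholder "summary": the first 100 characters of the raw content.
        summary = raw[:100]
        information.content = f"{raw}\n\nSummary: {summary}"
        return information
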
105 changes: 105 additions & 0 deletions enhancers/enhancer_registry.py
@@ -0,0 +1,105 @@
"""
解析器注册器模块

提供基于装饰器的解析器自动注册机制,支持多种文件格式的解析器注册和查找。
"""

import logging
from collections.abc import Callable

from enhancers.base_models import InformationEnhancer
from parsers.base_models import ChunkType

logger = logging.getLogger(__name__)

# Global enhancer registry
ENHANCER_REGISTRY: dict[str, type[InformationEnhancer]] = {}


def register_enhancer(modalities: list[ChunkType]) -> Callable[[type[InformationEnhancer]], type[InformationEnhancer]]:
"""
信息增强器注册装饰器

Args:
modalities: 支持的模态类型列表,如 [ChunkType.TEXT, ChunkType.IMAGE, ChunkType.TABLE]

Returns:
装饰器函数

Example:
@register_enhancer([ChunkType.TEXT, ChunkType.IMAGE, ChunkType.TABLE])
class TextInformationEnhancer(InformationEnhancer):
...
"""
def decorator(cls: type[InformationEnhancer]) -> type[InformationEnhancer]:
# Verify that the class inherits from InformationEnhancer
if not issubclass(cls, InformationEnhancer):
raise TypeError(f"Enhancer class {cls.__name__} must inherit from InformationEnhancer")

# Register in the global registry
for modality in modalities:
modality_type = modality.value.lower() # Normalize to lowercase
if modality_type in ENHANCER_REGISTRY:
logger.error(f"Duplicate enhancer registration: {modality_type} -> {cls.__name__}")
raise ValueError(f"Attempted to overwrite an existing enhancer: {modality_type} -> {cls.__name__}")
ENHANCER_REGISTRY[modality_type] = cls
logger.info(f"Registered enhancer: {modality_type} -> {cls.__name__}")

return cls

return decorator

def get_enhancer(modality: ChunkType) -> InformationEnhancer | None:
"""
根据模态类型获取合适的信息增强器实例

Args:
modality: 模态类型

Returns:
信息增强器实例,如果没有找到则返回 None
"""
modality_type = modality.value.lower()

if modality_type not in ENHANCER_REGISTRY:
logger.warning(f"未找到支持 {modality} 格式的信息增强器")
return None

enhancer_class = ENHANCER_REGISTRY[modality_type]
try:
return enhancer_class()
except Exception as e:
logger.error(f"创建信息增强器实例失败: {enhancer_class.__name__}, 错误: {e}")
return None

def get_supported_modalities() -> list[str]:
"""
获取所有支持的模态类型

Returns:
支持的模态类型列表
"""
return list(ENHANCER_REGISTRY.keys())


def get_enhancer_class(modality: ChunkType) -> type[InformationEnhancer] | None:
"""
根据模态类型获取信息增强器类

Args:
modality: 模态类型

Returns:
信息增强器类,如果没有找到则返回 None
"""
return ENHANCER_REGISTRY.get(modality.value.lower())


def list_registered_enhancers() -> dict[str, str]:
"""
列出所有已注册的信息增强器

Returns:
模态类型到信息增强器类名的映射字典
"""
return {modality: cls.__name__ for modality, cls in ENHANCER_REGISTRY.items()}
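
To show how the pieces fit together, here is a minimal usage sketch under the assumption that ChunkType defines a TABLE member (as the docstring example suggests); the enhancer itself is a throwaway placeholder, not part of this PR:

from enhancers import get_enhancer, register_enhancer
from enhancers.base_models import InformationEnhancer
from parsers.base_models import ChunkData, ChunkType


# Hypothetical enhancer, registered only to demonstrate the lookup flow.
@register_enhancer([ChunkType.TABLE])
class DemoTableEnhancer(InformationEnhancer):
    async def enhance(self, information: ChunkData) -> ChunkData:
        return information  # no-op placeholder


enhancer = get_enhancer(ChunkType.TABLE)  # returns a DemoTableEnhancer instance
# enhanced_chunk = await enhancer.enhance(chunk)  # awaited inside an async context
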
62 changes: 0 additions & 62 deletions enhancers/information_enhancer.py

This file was deleted.

4 changes: 3 additions & 1 deletion parsers/__init__.py
@@ -1,6 +1,6 @@
# Parsers package

from .base_models import DocumentData, DocumentParser
from .base_models import ChunkData, ChunkType, DocumentData, DocumentParser
from .parser_registry import (
PARSER_REGISTRY,
get_parser,
@@ -12,6 +12,8 @@
__all__ = [
'DocumentData',
'DocumentParser',
'ChunkData',
'ChunkType',
'PARSER_REGISTRY',
'register_parser',
'get_parser',
38 changes: 22 additions & 16 deletions worker.py
@@ -3,15 +3,13 @@

from sanic import Sanic

from enhancers.information_enhancer import InformationEnhancerFactory
from parsers import get_parser, load_all_parsers
from parsers.base_models import ChunkData
from enhancers import get_enhancer
from parsers import ChunkData, ChunkType, get_parser, load_all_parsers


async def worker(app: Sanic) -> dict[str, Any]:
# Use the factory to get the appropriate parser
load_all_parsers()
enhancer_factory = InformationEnhancerFactory()
redis = app.ctx.redis
while True:
task = await redis.get_task()
Expand All @@ -25,21 +23,29 @@ async def worker(app: Sanic) -> dict[str, Any]:
parse_result = await parser.parse(file_path)
if not parse_result.success:
continue
chunk_list = parse_result.texts + parse_result.tables + parse_result.images + parse_result.formulas
# Limit concurrency so that too many simultaneous requests do not cause failures
SEMAPHORE_LIMIT = 10 # Adjust according to actual load
SEMAPHORE_LIMIT = 10
semaphore = asyncio.Semaphore(SEMAPHORE_LIMIT)

async def enhance_with_semaphore(chunk: ChunkData, semaphore: asyncio.Semaphore) -> ChunkData:
async with semaphore:
return await enhancer_factory.enhance_information(chunk)

# Enhance each chunk concurrently
enhanced_chunk_list = await asyncio.gather(
*(enhance_with_semaphore(chunk, semaphore) for chunk in chunk_list)
)
parse_result.texts = enhanced_chunk_list[:len(parse_result.texts)]
parse_result.tables = enhanced_chunk_list[len(parse_result.texts):len(parse_result.texts) + len(parse_result.tables)]
parse_result.images = enhanced_chunk_list[len(parse_result.texts) + len(parse_result.tables):len(parse_result.texts) + len(parse_result.tables) + len(parse_result.images)]
parse_result.formulas = enhanced_chunk_list[len(parse_result.texts) + len(parse_result.tables) + len(parse_result.images):]
enhancer = get_enhancer(ChunkType(chunk.type))
if not enhancer:
return chunk
return await enhancer.enhance(chunk)

text_tasks = [enhance_with_semaphore(chunk, semaphore) for chunk in parse_result.texts]
table_tasks = [enhance_with_semaphore(chunk, semaphore) for chunk in parse_result.tables]
image_tasks = [enhance_with_semaphore(chunk, semaphore) for chunk in parse_result.images]
formula_tasks = [enhance_with_semaphore(chunk, semaphore) for chunk in parse_result.formulas]

text_chunk_list = await asyncio.gather(*text_tasks)
table_chunk_list = await asyncio.gather(*table_tasks)
image_chunk_list = await asyncio.gather(*image_tasks)
formula_chunk_list = await asyncio.gather(*formula_tasks)

parse_result.texts = text_chunk_list
parse_result.tables = table_chunk_list
parse_result.images = image_chunk_list
parse_result.formulas = formula_chunk_list
return parse_result.model_dump(mode="json")
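
The per-type rewrite above works because asyncio.gather returns results in the same order as the awaitables passed in, so each enhanced list stays aligned with its source chunks while the semaphore bounds how many enhancer calls run at once. A self-contained sketch of that pattern, independent of this codebase:

import asyncio


async def enhance(item: str) -> str:
    await asyncio.sleep(0.01)  # stand-in for a real enhancer call
    return item.upper()


async def main() -> None:
    semaphore = asyncio.Semaphore(10)  # cap concurrent enhancer calls

    async def bounded(item: str) -> str:
        async with semaphore:
            return await enhance(item)

    items = ["a", "b", "c"]
    # gather preserves input order, so results[i] corresponds to items[i]
    results = await asyncio.gather(*(bounded(item) for item in items))
    assert results == ["A", "B", "C"]


asyncio.run(main())
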