Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 37 additions & 47 deletions plugins/AITagger/ai_server.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
from typing import Any, Dict, List
from typing import Any, Dict, List, Optional, Set
import aiohttp
import pydantic
import config
import stashapi.log as log

current_videopipeline = None

# ----------------- AI Server Calling Functions -----------------

async def post_api_async(session, endpoint, payload):
Expand Down Expand Up @@ -38,55 +36,47 @@ async def process_images_async(image_paths, threshold=config.IMAGE_THRESHOLD, re
async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=config.SERVER_TIMEOUT)) as session:
return await post_api_async(session, 'process_images/', {"paths": image_paths, "threshold": threshold, "return_confidence": return_confidence})

async def process_video_async(video_path, vr_video=False, frame_interval=config.FRAME_INTERVAL,threshold=config.AI_VIDEO_THRESHOLD, return_confidence=True):
async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=config.SERVER_TIMEOUT)) as session:
return await post_api_async(session, 'process_video/', {"path": video_path, "frame_interval": frame_interval, "threshold": threshold, "return_confidence": return_confidence, "vr_video": vr_video})

async def get_image_config_async(threshold=config.IMAGE_THRESHOLD):
async def process_video_async(video_path, vr_video=False, frame_interval=config.FRAME_INTERVAL,threshold=config.AI_VIDEO_THRESHOLD, return_confidence=True, existing_json=None):
async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=config.SERVER_TIMEOUT)) as session:
return await get_api_async(session, f'image_pipeline_info/?threshold={threshold}')
return await post_api_async(session, 'process_video/', {"path": video_path, "frame_interval": frame_interval, "threshold": threshold, "return_confidence": return_confidence, "vr_video": vr_video, "existing_json_data": existing_json})

async def get_video_config_async(frame_interval=config.FRAME_INTERVAL, threshold=config.AI_VIDEO_THRESHOLD):
async def find_optimal_marker_settings(existing_json, desired_timespan_data):
async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=config.SERVER_TIMEOUT)) as session:
return await get_api_async(session, f'video_pipeline_info/?frame_interval={frame_interval}&threshold={threshold}&return_confidence=True')

return await post_api_async(session, 'optimize_timeframe_settings/', {"existing_json_data": existing_json, "desired_timespan_data": desired_timespan_data})


class VideoResult(pydantic.BaseModel):
result: List[Dict[str, Any]] = pydantic.Field(..., min_items=1)
pipeline_short_name: str
pipeline_version: float
threshold: float
frame_interval: float
return_confidence: bool
result: Dict[str, Any]

class ImageResult(pydantic.BaseModel):
result: List[Dict[str, Any]] = pydantic.Field(..., min_items=1)
pipeline_short_name: str
pipeline_version: float
threshold: float
return_confidence: bool
class TimeFrame(pydantic.BaseModel):
start: float
end: float
totalConfidence: Optional[float]

class ImagePipelineInfo(pydantic.BaseModel):
pipeline_short_name: str
pipeline_version: float
threshold: float
return_confidence: bool
def to_json(self):
return self.model_dump_json(exclude_none=True)

class VideoPipelineInfo(pydantic.BaseModel):
pipeline_short_name: str
pipeline_version: float
threshold: float
frame_interval: float
return_confidence: bool
def __str__(self):
return f"TimeFrame(start={self.start}, end={self.end})"

async def get_current_video_pipeline():
global current_videopipeline
if current_videopipeline is not None:
return current_videopipeline
try:
current_videopipeline = VideoPipelineInfo(**await get_video_config_async())
except aiohttp.ClientConnectionError as e:
log.error(f"Failed to connect to AI server. Is the AI server running at {config.API_BASE_URL}? {e}")
except Exception as e:
log.error(f"Failed to get pipeline info: {e}. Ensure the AI server is running with at least version 1.3.1!")
raise
return current_videopipeline
class VideoTagInfo(pydantic.BaseModel):
video_duration: float
video_tags: Dict[str, Set[str]]
tag_totals: Dict[str, Dict[str, float]]
tag_timespans: Dict[str, Dict[str, List[TimeFrame]]]

@classmethod
def from_json(cls, json_str: str):
log.info(f"json_str: {json_str}")
log.info(f"video_duration: {json_str['video_duration']}, video_tags: {json_str['video_tags']}, tag_totals: {json_str['tag_totals']}, tag_timespans: {json_str['tag_timespans']}")
return cls(video_duration=json_str["video_duration"], video_tags=json_str["video_tags"], tag_totals=json_str["tag_totals"], tag_timespans=json_str["tag_timespans"])

def __str__(self):
return f"VideoTagInfo(video_duration={self.video_duration}, video_tags={self.video_tags}, tag_totals={self.tag_totals}, tag_timespans={self.tag_timespans})"

class ImageResult(pydantic.BaseModel):
result: List[Dict[str, Any]] = pydantic.Field(..., min_items=1)

class OptimizeMarkerSettings(pydantic.BaseModel):
existing_json_data: Any = None
desired_timespan_data: Dict[str, TimeFrame]
Loading
Loading