diff --git a/CHANGELOG.md b/CHANGELOG.md index 6fc4402194a..0344a282bea 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,12 @@ # 更新日志(Changelog) +## v1.2.2 + +### 2024/6/16 + +- 优化在线查询更新速度与修复无更新结果情况(Optimize online query update speed and fix no update result situation) +- 解决个别环境运行更新报错(Solved the problem of running updates in some environments) + ## v1.2.1 ### 2024/6/14 diff --git a/main.py b/main.py index 547ef412f66..50f46b7bc45 100644 --- a/main.py +++ b/main.py @@ -5,21 +5,20 @@ update_file, sort_urls_by_speed_and_resolution, get_total_urls_from_info_list, - use_accessible_url, get_channels_by_subscribe_urls, check_url_by_patterns, get_channels_by_fofa, - async_get_channels_info_list_by_online_search, + get_channels_by_online_search, format_channel_name, resource_path, load_external_config, + get_pbar_remaining, ) import logging from logging.handlers import RotatingFileHandler import os from tqdm import tqdm from tqdm.asyncio import tqdm_asyncio -import threading from time import time config_path = resource_path("user_config.py") @@ -38,29 +37,12 @@ def __init__(self): self.tasks = [] self.channel_items = get_channel_items() self.results = {} - self.channel_queue = asyncio.Queue() self.semaphore = asyncio.Semaphore(10) self.channel_data = {} self.pbar = None self.total = 0 self.start_time = None - def get_pbar_remaining(self): - try: - elapsed = time() - self.start_time - completed_tasks = self.pbar.n - if completed_tasks > 0: - avg_time_per_task = elapsed / completed_tasks - remaining_tasks = self.pbar.total - completed_tasks - remaining_time = self.pbar.format_interval( - avg_time_per_task * remaining_tasks - ) - else: - remaining_time = "未知" - return remaining_time - except Exception as e: - print(f"Error: {e}") - def append_data_to_info_data(self, cate, name, data): for url, date, resolution in data: if url and check_url_by_patterns(url): @@ -96,14 +78,13 @@ async def sort_channel_list(self, cate, name, info_list): f"Sorting, {self.pbar.total - self.pbar.n} urls remaining" ) self.update_progress( - f"正在测速排序, 剩余{self.pbar.total - self.pbar.n}个接口, 预计剩余时间: {self.get_pbar_remaining()}", + f"正在测速排序, 剩余{self.pbar.total - self.pbar.n}个接口, 预计剩余时间: {get_pbar_remaining(self.pbar, self.start_time)}", int((self.pbar.n / self.total) * 100), ) - async def process_channel(self): - async with self.semaphore: - try: - cate, name, old_urls = await self.channel_queue.get() + def process_channel(self): + for cate, channel_obj in self.channel_items.items(): + for name, old_urls in channel_obj.items(): format_name = format_channel_name(name) if config.open_subscribe: self.append_data_to_info_data( @@ -113,31 +94,16 @@ async def process_channel(self): self.append_data_to_info_data( cate, name, self.results["open_multicast"].get(format_name, []) ) - if config.open_online_search and self.results["open_online_search"]: - online_info_list = ( - await async_get_channels_info_list_by_online_search( - self.results["open_online_search"], - format_name, - ) + if config.open_online_search: + self.append_data_to_info_data( + cate, + name, + self.results["open_online_search"].get(format_name, []), ) - if online_info_list: - self.append_data_to_info_data(cate, name, online_info_list) if len(self.channel_data.get(cate, {}).get(name, [])) == 0: self.append_data_to_info_data( cate, name, [(url, None, None) for url in old_urls] ) - except asyncio.exceptions.CancelledError: - print("Update cancelled!") - finally: - self.channel_queue.task_done() - self.pbar.update() - self.pbar.set_description( - f"Processing, {self.total - self.pbar.n} channels remaining" - ) - self.update_progress( - f"正在更新, 剩余{self.total - self.pbar.n}个频道待处理, 预计剩余时间: {self.get_pbar_remaining()}", - int((self.pbar.n / self.total) * 100), - ) def write_channel_to_file(self): self.pbar = tqdm(total=self.total) @@ -155,25 +121,25 @@ def write_channel_to_file(self): f"Writing, {self.pbar.total - self.pbar.n} channels remaining" ) self.update_progress( - f"正在写入结果, 剩余{self.pbar.total - self.pbar.n}个接口, 预计剩余时间: {self.get_pbar_remaining()}", + f"正在写入结果, 剩余{self.pbar.total - self.pbar.n}个接口, 预计剩余时间: {get_pbar_remaining(self.pbar, self.start_time)}", int((self.pbar.n / self.total) * 100), ) - async def visit_page(self): + async def visit_page(self, channel_names=None): task_dict = { "open_subscribe": get_channels_by_subscribe_urls, "open_multicast": get_channels_by_fofa, - "open_online_search": use_accessible_url, + "open_online_search": get_channels_by_online_search, } for config_name, task_func in task_dict.items(): if getattr(config, config_name): task = None - if config_name == "open_subscribe": - task = asyncio.create_task(task_func(self.update_progress)) - elif config_name == "open_multicast": + if config_name == "open_subscribe" or config_name == "open_multicast": task = asyncio.create_task(task_func(self.update_progress)) else: - task = asyncio.create_task(task_func(self.update_progress)) + task = asyncio.create_task( + task_func(channel_names, self.update_progress) + ) if task: self.tasks.append(task) task_results = await tqdm_asyncio.gather(*self.tasks, disable=True) @@ -182,26 +148,18 @@ async def visit_page(self): [name for name in task_dict if getattr(config, name)] ): self.results[config_name] = task_results[i] - for cate, channel_obj in self.channel_items.items(): - for name in channel_obj.keys(): - await self.channel_queue.put((cate, name, channel_obj[name])) async def main(self): try: self.tasks = [] - await self.visit_page() - self.total = self.channel_queue.qsize() - self.tasks = [ - asyncio.create_task(self.process_channel()) for _ in range(self.total) + channel_names = [ + name + for cate, channel_obj in self.channel_items.items() + for name in channel_obj.keys() ] - self.pbar = tqdm_asyncio(total=self.total) - self.pbar.set_description(f"Processing, {self.total} channels remaining") - self.update_progress( - f"正在更新, 共{self.total}个频道", - int((self.pbar.n / self.total) * 100), - ) - self.start_time = time() - await tqdm_asyncio.gather(*self.tasks, disable=True) + self.total = len(channel_names) + await self.visit_page(channel_names) + self.process_channel() if config.open_sort: self.tasks = [ asyncio.create_task(self.sort_channel_list(cate, name, info_list)) @@ -251,8 +209,8 @@ def stop(self): for task in self.tasks: task.cancel() self.tasks = [] - asyncio.get_event_loop().stop() - self.pbar.close() + if self.pbar: + self.pbar.close() if __name__ == "__main__": diff --git a/tkinter_ui.py b/tkinter_ui.py index 4c1a8007629..1f6ad551e17 100644 --- a/tkinter_ui.py +++ b/tkinter_ui.py @@ -23,7 +23,7 @@ class TkinterUI: def __init__(self, root): self.root = root self.root.title("直播源接口更新工具") - self.version = "v1.0.1" + self.version = "v1.0.2" self.update_source = UpdateSource() self.update_running = False self.config_entrys = [ @@ -194,10 +194,7 @@ def on_run_update(self): def run_loop(): loop = asyncio.new_event_loop() asyncio.set_event_loop(loop) - try: - loop.run_until_complete(self.run_update()) - finally: - loop.close() + loop.run_until_complete(self.run_update()) self.thread = threading.Thread(target=run_loop, daemon=True) self.thread.start() diff --git a/utils.py b/utils.py index 29f24682fcf..f616d0872cd 100644 --- a/utils.py +++ b/utils.py @@ -1,7 +1,7 @@ from selenium import webdriver import aiohttp import asyncio -import time +from time import time import re import datetime import os @@ -162,6 +162,24 @@ def get_channel_items(): return channels +def get_pbar_remaining(pbar, start_time): + """ + Get the remaining time of the progress bar + """ + try: + elapsed = time() - start_time + completed_tasks = pbar.n + if completed_tasks > 0: + avg_time_per_task = elapsed / completed_tasks + remaining_tasks = pbar.total - completed_tasks + remaining_time = pbar.format_interval(avg_time_per_task * remaining_tasks) + else: + remaining_time = "未知" + return remaining_time + except Exception as e: + print(f"Error: {e}") + + async def get_channels_by_subscribe_urls(callback): """ Get the channels by subscribe urls @@ -170,6 +188,7 @@ async def get_channels_by_subscribe_urls(callback): pattern = r"^(.*?),(?!#genre#)(.*?)$" subscribe_urls_len = len(config.subscribe_urls) pbar = tqdm_asyncio(total=subscribe_urls_len) + start_time = time() def process_subscribe_channels(subscribe_url): try: @@ -204,9 +223,11 @@ def process_subscribe_channels(subscribe_url): remain = subscribe_urls_len - pbar.n pbar.set_description(f"Processing subscribe, {remain} urls remaining") callback( - f"正在获取订阅源更新, 剩余{remain}个订阅源待获取", + f"正在获取订阅源更新, 剩余{remain}个订阅源待获取, 预计剩余时间: {get_pbar_remaining(pbar, start_time)}", int((pbar.n / subscribe_urls_len) * 100), ) + if config.open_online_search and pbar.n / subscribe_urls_len == 1: + callback("正在获取在线搜索结果, 请耐心等待", 0) pbar.set_description(f"Processing subscribe, {subscribe_urls_len} urls remaining") callback(f"正在获取订阅源更新, 共{subscribe_urls_len}个订阅源", 0) @@ -222,74 +243,95 @@ def process_subscribe_channels(subscribe_url): return channels -def get_channels_info_list_by_online_search(pageUrl, name): +async def get_channels_by_online_search(names, callback): """ - Get the channels info list by online search + Get the channels by online search """ - driver = setup_driver() - wait = WebDriverWait(driver, 10) - info_list = [] - try: - driver.get(pageUrl) - search_box = wait.until( - EC.presence_of_element_located((By.XPATH, '//input[@type="text"]')) - ) - search_box.clear() - search_box.send_keys(name) - submit_button = wait.until( - EC.element_to_be_clickable((By.XPATH, '//input[@type="submit"]')) - ) - driver.execute_script("arguments[0].click();", submit_button) - isFavorite = name in config.favorite_list - pageNum = config.favorite_page_num if isFavorite else config.default_page_num - for page in range(1, pageNum + 1): - try: - if page > 1: - page_link = wait.until( - EC.element_to_be_clickable( - ( - By.XPATH, - f'//a[contains(@href, "={page}") and contains(@href, "{name}")]', + channels = {} + pageUrl = await use_accessible_url(callback) + if not pageUrl: + return channels + start_time = time() + + def process_channel_by_online_search(name): + driver = setup_driver() + wait = WebDriverWait(driver, 10) + info_list = [] + try: + driver.get(pageUrl) + search_box = wait.until( + EC.presence_of_element_located((By.XPATH, '//input[@type="text"]')) + ) + search_box.clear() + search_box.send_keys(name) + submit_button = wait.until( + EC.element_to_be_clickable((By.XPATH, '//input[@type="submit"]')) + ) + driver.execute_script("arguments[0].click();", submit_button) + isFavorite = name in config.favorite_list + pageNum = ( + config.favorite_page_num if isFavorite else config.default_page_num + ) + for page in range(1, pageNum + 1): + try: + if page > 1: + page_link = wait.until( + EC.element_to_be_clickable( + ( + By.XPATH, + f'//a[contains(@href, "={page}") and contains(@href, "{name}")]', + ) ) ) + driver.execute_script("arguments[0].click();", page_link) + source = re.sub( + r"", + "", + driver.page_source, + flags=re.DOTALL, ) - driver.execute_script("arguments[0].click();", page_link) - source = re.sub( - r"", - "", - driver.page_source, - flags=re.DOTALL, - ) - soup = BeautifulSoup(source, "html.parser") - if soup: - results = get_results_from_soup(soup, name) - for result in results: - url, date, resolution = result - if url and check_url_by_patterns(url): - info_list.append((url, date, resolution)) - except Exception as e: - # print(f"Error on page {page}: {e}") - continue - except Exception as e: - # print(f"Error on search: {e}") - pass - finally: - if driver: - driver.quit() - return info_list - - -async def async_get_channels_info_list_by_online_search(pageUrl, name): - """ - Get the channels info list by online search - """ - # with concurrent.futures.ThreadPoolExecutor() as pool: - # loop = asyncio.geto_running_loop() - # info_list = await loop.run_in_executor( - # pool, get_channels_info_list_by_online_search, pageUrl, name - # ) - info_list = get_channels_info_list_by_online_search(pageUrl, name) - return info_list + soup = BeautifulSoup(source, "html.parser") + if soup: + results = get_results_from_soup(soup, name) + for result in results: + url, date, resolution = result + if url and check_url_by_patterns(url): + info_list.append((url, date, resolution)) + except Exception as e: + # print(f"Error on page {page}: {e}") + continue + except Exception as e: + # print(f"Error on search: {e}") + pass + finally: + names_queue.task_done() + pbar.update() + pbar.set_description( + f"Processing online search, {names_len - pbar.n} channels remaining" + ) + callback( + f"正在线上查询更新, 剩余{names_len - pbar.n}个频道待查询, 预计剩余时间: {get_pbar_remaining(pbar, start_time)}", + int((pbar.n / names_len) * 100), + ) + if driver: + driver.quit() + channels[name] = info_list + + names_queue = asyncio.Queue() + for name in names: + await names_queue.put(name) + names_len = names_queue.qsize() + pbar = tqdm_asyncio(total=names_len) + pbar.set_description(f"Processing online search, {names_len} channels remaining") + callback(f"正在线上查询更新, 共{names_len}个频道", 0) + with concurrent.futures.ThreadPoolExecutor(max_workers=10) as pool: + while not names_queue.empty(): + loop = asyncio.get_running_loop() + name = await names_queue.get() + loop.run_in_executor(pool, process_channel_by_online_search, name) + print("Finished processing online search") + pbar.close() + return channels def update_channel_urls_txt(cate, name, urls): @@ -382,13 +424,13 @@ async def get_speed(url, urlTimeout=5): Get the speed of the url """ async with aiohttp.ClientSession() as session: - start = time.time() + start = time() try: async with session.get(url, timeout=urlTimeout) as response: resStatus = response.status except: return float("inf") - end = time.time() + end = time() if resStatus == 200: return int(round((end - start) * 1000)) else: @@ -595,6 +637,7 @@ async def get_channels_by_fofa(callback): fofa_urls = get_fofa_urls_from_region_list() fofa_urls_len = len(fofa_urls) pbar = tqdm_asyncio(total=fofa_urls_len) + start_time = time() fofa_results = {} def process_fofa_channels(fofa_url, pbar, fofa_urls_len, callback): @@ -643,9 +686,11 @@ def process_fofa_channels(fofa_url, pbar, fofa_urls_len, callback): remain = fofa_urls_len - pbar.n pbar.set_description(f"Processing multicast, {remain} regions remaining") callback( - f"正在获取组播源更新, 剩余{remain}个地区待获取", + f"正在获取组播源更新, 剩余{remain}个地区待获取, 预计剩余时间: {get_pbar_remaining(pbar, start_time)}", int((pbar.n / fofa_urls_len) * 100), ) + if config.open_online_search and pbar.n / fofa_urls_len == 1: + callback("正在获取在线搜索结果, 请耐心等待", 0) if driver: driver.quit() diff --git a/version.json b/version.json index fc0ad7c997f..7964167ef5e 100644 --- a/version.json +++ b/version.json @@ -1,3 +1,3 @@ { - "version": "1.2.1" + "version": "1.2.2" } \ No newline at end of file