forked from fishaudio/Bert-VITS2
-
Notifications
You must be signed in to change notification settings - Fork 91
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
複数プロセス利用時にユーザー辞書を活用するためのpyopenjtalk別プロセス化 #89
Merged
Merged
Changes from 10 commits
Commits
Show all changes
12 commits
Select commit
Hold shift + click to select a range
beb6a64
clean
litagin02 8b17d4d
Try to fix colab encoding error
litagin02 46935c6
Try to fix encoding error
litagin02 bc10582
Delete pyopenjtalk import
kale4eat 1eb8fb4
add openjtalk worker pkg
kale4eat 85f5b9b
replace pyopenjtalk import with worker
kale4eat 98ff976
modify logging
kale4eat 2ab025d
Run in a separate process group to avoid receiving signals by Ctrl + C
kale4eat 45c6bde
In Windows, create new console and hide it
kale4eat ed90af1
Enhanced server error handling
kale4eat 419d0e5
Delete debugging traces
kale4eat 2994873
add no client timeout
kale4eat File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -19,7 +19,6 @@ | |
import yaml | ||
|
||
import numpy as np | ||
import pyopenjtalk | ||
import requests | ||
import torch | ||
import uvicorn | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,129 @@ | ||
""" | ||
Run the pyopenjtalk worker in a separate process | ||
to avoid user dictionary access error | ||
""" | ||
|
||
from typing import Optional, Any | ||
|
||
from .worker_common import WOKER_PORT | ||
from .worker_client import WorkerClient | ||
|
||
from common.log import logger | ||
|
||
WORKER_CLIENT: Optional[WorkerClient] = None | ||
|
||
# pyopenjtalk interface | ||
|
||
# g2p: not used | ||
|
||
|
||
def run_frontend(text: str) -> list[dict[str, Any]]: | ||
assert WORKER_CLIENT | ||
ret = WORKER_CLIENT.dispatch_pyopenjtalk("run_frontend", text) | ||
assert isinstance(ret, list) | ||
return ret | ||
|
||
|
||
def make_label(njd_features) -> list[str]: | ||
assert WORKER_CLIENT | ||
ret = WORKER_CLIENT.dispatch_pyopenjtalk("make_label", njd_features) | ||
assert isinstance(ret, list) | ||
return ret | ||
|
||
|
||
def mecab_dict_index(path: str, out_path: str, dn_mecab: Optional[str] = None): | ||
assert WORKER_CLIENT | ||
WORKER_CLIENT.dispatch_pyopenjtalk("mecab_dict_index", path, out_path, dn_mecab) | ||
|
||
|
||
def update_global_jtalk_with_user_dict(path: str): | ||
assert WORKER_CLIENT | ||
WORKER_CLIENT.dispatch_pyopenjtalk("update_global_jtalk_with_user_dict", path) | ||
|
||
|
||
def unset_user_dict(): | ||
assert WORKER_CLIENT | ||
WORKER_CLIENT.dispatch_pyopenjtalk("unset_user_dict") | ||
|
||
|
||
# initialize module when imported | ||
|
||
|
||
def initialize(port: int = WOKER_PORT): | ||
import time | ||
import socket | ||
import sys | ||
import atexit | ||
import signal | ||
|
||
logger.debug("initialize") | ||
global WORKER_CLIENT | ||
if WORKER_CLIENT: | ||
return | ||
|
||
client = None | ||
try: | ||
client = WorkerClient(port) | ||
except (socket.timeout, socket.error): | ||
logger.debug("try starting pyopenjtalk worker server") | ||
import os | ||
import subprocess | ||
|
||
worker_pkg_path = os.path.relpath( | ||
os.path.dirname(__file__), os.getcwd() | ||
).replace(os.sep, ".") | ||
args = [sys.executable, "-m", worker_pkg_path, "--port", str(port)] | ||
# new session, new process group | ||
if sys.platform.startswith("win"): | ||
cf = subprocess.CREATE_NEW_CONSOLE | subprocess.CREATE_NEW_PROCESS_GROUP # type: ignore | ||
si = subprocess.STARTUPINFO() # type: ignore | ||
si.dwFlags |= subprocess.STARTF_USESHOWWINDOW # type: ignore | ||
si.wShowWindow = subprocess.SW_HIDE # type: ignore | ||
subprocess.Popen(args, creationflags=cf, startupinfo=si) | ||
else: | ||
# align with Windows behavior | ||
# start_new_session is same as specifying setsid in preexec_fn | ||
subprocess.Popen(args, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, start_new_session=True) # type: ignore | ||
|
||
# wait until server listening | ||
count = 0 | ||
while True: | ||
try: | ||
client = WorkerClient(port) | ||
break | ||
except socket.error: | ||
time.sleep(1) | ||
count += 1 | ||
# 10: max number of retries | ||
if count == 10: | ||
raise TimeoutError("サーバーに接続できませんでした") | ||
|
||
WORKER_CLIENT = client | ||
atexit.register(terminate) | ||
|
||
# when the process is killed | ||
def signal_handler(signum, frame): | ||
with open("signal_handler.txt", mode="w") as f: | ||
|
||
pass | ||
terminate() | ||
|
||
signal.signal(signal.SIGTERM, signal_handler) | ||
|
||
|
||
# top-level declaration | ||
def terminate(): | ||
logger.debug("terminate") | ||
global WORKER_CLIENT | ||
if not WORKER_CLIENT: | ||
return | ||
|
||
# repare for unexpected errors | ||
try: | ||
if WORKER_CLIENT.status() == 1: | ||
WORKER_CLIENT.quit_server() | ||
except Exception as e: | ||
logger.error(e) | ||
|
||
WORKER_CLIENT.close() | ||
WORKER_CLIENT = None |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
import argparse | ||
|
||
from .worker_server import WorkerServer | ||
from .worker_common import WOKER_PORT | ||
|
||
|
||
def main(): | ||
parser = argparse.ArgumentParser() | ||
parser.add_argument("--port", type=int, default=WOKER_PORT) | ||
args = parser.parse_args() | ||
server = WorkerServer() | ||
server.start_server(port=args.port) | ||
|
||
|
||
if __name__ == "__main__": | ||
main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,55 @@ | ||
from typing import Any | ||
import socket | ||
|
||
from .worker_common import RequestType, receive_data, send_data | ||
|
||
from common.log import logger | ||
|
||
|
||
class WorkerClient: | ||
def __init__(self, port: int) -> None: | ||
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) | ||
# 5: timeout | ||
sock.settimeout(5) | ||
sock.connect((socket.gethostname(), port)) | ||
self.sock = sock | ||
|
||
def __enter__(self): | ||
return self | ||
|
||
def __exit__(self, exc_type, exc_value, traceback): | ||
self.close() | ||
|
||
def close(self): | ||
self.sock.close() | ||
|
||
def dispatch_pyopenjtalk(self, func: str, *args, **kwargs): | ||
data = { | ||
"request-type": RequestType.PYOPENJTALK, | ||
"func": func, | ||
"args": args, | ||
"kwargs": kwargs, | ||
} | ||
logger.trace(f"client sends request: {data}") | ||
send_data(self.sock, data) | ||
logger.trace("client sent request successfully") | ||
response = receive_data(self.sock) | ||
logger.trace(f"client received response: {response}") | ||
return response.get("return") | ||
|
||
def status(self): | ||
data = {"request-type": RequestType.STATUS} | ||
logger.trace(f"client sends request: {data}") | ||
send_data(self.sock, data) | ||
logger.trace("client sent request successfully") | ||
response = receive_data(self.sock) | ||
logger.trace(f"client received response: {response}") | ||
return response.get("client-count") | ||
|
||
def quit_server(self): | ||
data = {"request-type": RequestType.QUIT_SERVER} | ||
logger.trace(f"client sends request: {data}") | ||
send_data(self.sock, data) | ||
logger.trace("client sent request successfully") | ||
response = receive_data(self.sock) | ||
logger.trace(f"client received response: {response}") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
from typing import Any, Optional, Final | ||
from enum import IntEnum, auto | ||
import socket | ||
import json | ||
|
||
WOKER_PORT: Final[int] = 7861 | ||
HEADER_SIZE: Final[int] = 4 | ||
|
||
|
||
class RequestType(IntEnum): | ||
STATUS = auto() | ||
QUIT_SERVER = auto() | ||
PYOPENJTALK = auto() | ||
|
||
|
||
class ConnectionClosedException(Exception): | ||
pass | ||
|
||
|
||
# socket communication | ||
|
||
|
||
def send_data(sock: socket.socket, data: dict[str, Any]): | ||
json_data = json.dumps(data).encode() | ||
header = len(json_data).to_bytes(HEADER_SIZE, byteorder="big") | ||
sock.sendall(header + json_data) | ||
|
||
|
||
def _receive_until(sock: socket.socket, size: int): | ||
data = b"" | ||
while len(data) < size: | ||
part = sock.recv(size - len(data)) | ||
if part == b"": | ||
raise ConnectionClosedException("接続が閉じられました") | ||
data += part | ||
|
||
return data | ||
|
||
|
||
def receive_data(sock: socket.socket) -> dict[str, Any]: | ||
header = _receive_until(sock, HEADER_SIZE) | ||
data_length = int.from_bytes(header, byteorder="big") | ||
body = _receive_until(sock, data_length) | ||
return json.loads(body.decode()) |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ここはpassにしていますがログか何かを書き込む予定でしょうか?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
デバッグの痕跡が混入してしまったようです。
申し訳ございません。
削除いたします。