Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

複数プロセス利用時にユーザー辞書を活用するためのpyopenjtalk別プロセス化 #89

Merged
merged 12 commits into from
Mar 8, 2024
1 change: 1 addition & 0 deletions common/subprocess_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ def run_script_with_log(cmd: list[str], ignore_warning=False) -> tuple[bool, str
stdout=SAFE_STDOUT, # type: ignore
stderr=subprocess.PIPE,
text=True,
encoding="utf-8",
)
if result.returncode != 0:
logger.error(f"Error: {' '.join(cmd)}\n{result.stderr}")
Expand Down
12 changes: 8 additions & 4 deletions preprocess_text.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,10 @@
import click
from tqdm import tqdm

from common.log import logger
from common.stdout_wrapper import SAFE_STDOUT
from config import config
from text.cleaner import clean_text
from common.stdout_wrapper import SAFE_STDOUT
from common.log import logger

preprocess_text_config = config.preprocess_text_config

Expand Down Expand Up @@ -89,7 +89,10 @@ def preprocess(
)
)
except Exception as e:
logger.error(f"An error occurred at line:\n{line.strip()}\n{e}")
logger.error(
f"An error occurred at line:\n{line.strip()}\n{e}",
encoding="utf-8",
)
with open(error_log_path, "a", encoding="utf-8") as error_log:
error_log.write(f"{line.strip()}\n{e}\n\n")
error_count += 1
Expand Down Expand Up @@ -172,8 +175,9 @@ def preprocess(
f"An error occurred in {error_count} lines. Please check {error_log_path} for details."
)
raise Exception(
f"An error occurred in {error_count} lines. Please check {error_log_path} for details."
f"An error occurred in {error_count} lines. Please check `Data/you_model_name/text_error.log` file for details."
)
# 何故か{error_log_path}をraiseすると文字コードエラーが起きるので上のように書いている
else:
logger.info(
"Training set and validation set generation from texts is complete!"
Expand Down
1 change: 0 additions & 1 deletion server_editor.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
import yaml

import numpy as np
import pyopenjtalk
import requests
import torch
import uvicorn
Expand Down
4 changes: 3 additions & 1 deletion text/japanese.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@
import unicodedata
from pathlib import Path

import pyopenjtalk
from . import pyopenjtalk_worker as pyopenjtalk

pyopenjtalk.initialize()
from num2words import num2words
from transformers import AutoTokenizer

Expand Down
129 changes: 129 additions & 0 deletions text/pyopenjtalk_worker/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
"""
Run the pyopenjtalk worker in a separate process
to avoid user dictionary access error
"""

from typing import Optional, Any

from .worker_common import WOKER_PORT
from .worker_client import WorkerClient

from common.log import logger

WORKER_CLIENT: Optional[WorkerClient] = None

# pyopenjtalk interface

# g2p: not used


def run_frontend(text: str) -> list[dict[str, Any]]:
assert WORKER_CLIENT
ret = WORKER_CLIENT.dispatch_pyopenjtalk("run_frontend", text)
assert isinstance(ret, list)
return ret


def make_label(njd_features) -> list[str]:
assert WORKER_CLIENT
ret = WORKER_CLIENT.dispatch_pyopenjtalk("make_label", njd_features)
assert isinstance(ret, list)
return ret


def mecab_dict_index(path: str, out_path: str, dn_mecab: Optional[str] = None):
assert WORKER_CLIENT
WORKER_CLIENT.dispatch_pyopenjtalk("mecab_dict_index", path, out_path, dn_mecab)


def update_global_jtalk_with_user_dict(path: str):
assert WORKER_CLIENT
WORKER_CLIENT.dispatch_pyopenjtalk("update_global_jtalk_with_user_dict", path)


def unset_user_dict():
assert WORKER_CLIENT
WORKER_CLIENT.dispatch_pyopenjtalk("unset_user_dict")


# initialize module when imported


def initialize(port: int = WOKER_PORT):
import time
import socket
import sys
import atexit
import signal

logger.debug("initialize")
global WORKER_CLIENT
if WORKER_CLIENT:
return

client = None
try:
client = WorkerClient(port)
except (socket.timeout, socket.error):
logger.debug("try starting pyopenjtalk worker server")
import os
import subprocess

worker_pkg_path = os.path.relpath(
os.path.dirname(__file__), os.getcwd()
).replace(os.sep, ".")
args = [sys.executable, "-m", worker_pkg_path, "--port", str(port)]
# new session, new process group
if sys.platform.startswith("win"):
cf = subprocess.CREATE_NEW_CONSOLE | subprocess.CREATE_NEW_PROCESS_GROUP # type: ignore
si = subprocess.STARTUPINFO() # type: ignore
si.dwFlags |= subprocess.STARTF_USESHOWWINDOW # type: ignore
si.wShowWindow = subprocess.SW_HIDE # type: ignore
subprocess.Popen(args, creationflags=cf, startupinfo=si)
else:
# align with Windows behavior
# start_new_session is same as specifying setsid in preexec_fn
subprocess.Popen(args, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, start_new_session=True) # type: ignore

# wait until server listening
count = 0
while True:
try:
client = WorkerClient(port)
break
except socket.error:
time.sleep(1)
count += 1
# 10: max number of retries
if count == 10:
raise TimeoutError("サーバーに接続できませんでした")

WORKER_CLIENT = client
atexit.register(terminate)

# when the process is killed
def signal_handler(signum, frame):
with open("signal_handler.txt", mode="w") as f:

pass
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ここはpassにしていますがログか何かを書き込む予定でしょうか?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

デバッグの痕跡が混入してしまったようです。
申し訳ございません。
削除いたします。

terminate()

signal.signal(signal.SIGTERM, signal_handler)


# top-level declaration
def terminate():
logger.debug("terminate")
global WORKER_CLIENT
if not WORKER_CLIENT:
return

# repare for unexpected errors
try:
if WORKER_CLIENT.status() == 1:
WORKER_CLIENT.quit_server()
except Exception as e:
logger.error(e)

WORKER_CLIENT.close()
WORKER_CLIENT = None
16 changes: 16 additions & 0 deletions text/pyopenjtalk_worker/__main__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
import argparse

from .worker_server import WorkerServer
from .worker_common import WOKER_PORT


def main():
parser = argparse.ArgumentParser()
parser.add_argument("--port", type=int, default=WOKER_PORT)
args = parser.parse_args()
server = WorkerServer()
server.start_server(port=args.port)


if __name__ == "__main__":
main()
55 changes: 55 additions & 0 deletions text/pyopenjtalk_worker/worker_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
from typing import Any
import socket

from .worker_common import RequestType, receive_data, send_data

from common.log import logger


class WorkerClient:
def __init__(self, port: int) -> None:
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
# 5: timeout
sock.settimeout(5)
sock.connect((socket.gethostname(), port))
self.sock = sock

def __enter__(self):
return self

def __exit__(self, exc_type, exc_value, traceback):
self.close()

def close(self):
self.sock.close()

def dispatch_pyopenjtalk(self, func: str, *args, **kwargs):
data = {
"request-type": RequestType.PYOPENJTALK,
"func": func,
"args": args,
"kwargs": kwargs,
}
logger.trace(f"client sends request: {data}")
send_data(self.sock, data)
logger.trace("client sent request successfully")
response = receive_data(self.sock)
logger.trace(f"client received response: {response}")
return response.get("return")

def status(self):
data = {"request-type": RequestType.STATUS}
logger.trace(f"client sends request: {data}")
send_data(self.sock, data)
logger.trace("client sent request successfully")
response = receive_data(self.sock)
logger.trace(f"client received response: {response}")
return response.get("client-count")

def quit_server(self):
data = {"request-type": RequestType.QUIT_SERVER}
logger.trace(f"client sends request: {data}")
send_data(self.sock, data)
logger.trace("client sent request successfully")
response = receive_data(self.sock)
logger.trace(f"client received response: {response}")
44 changes: 44 additions & 0 deletions text/pyopenjtalk_worker/worker_common.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
from typing import Any, Optional, Final
from enum import IntEnum, auto
import socket
import json

WOKER_PORT: Final[int] = 7861
HEADER_SIZE: Final[int] = 4


class RequestType(IntEnum):
STATUS = auto()
QUIT_SERVER = auto()
PYOPENJTALK = auto()


class ConnectionClosedException(Exception):
pass


# socket communication


def send_data(sock: socket.socket, data: dict[str, Any]):
json_data = json.dumps(data).encode()
header = len(json_data).to_bytes(HEADER_SIZE, byteorder="big")
sock.sendall(header + json_data)


def _receive_until(sock: socket.socket, size: int):
data = b""
while len(data) < size:
part = sock.recv(size - len(data))
if part == b"":
raise ConnectionClosedException("接続が閉じられました")
data += part

return data


def receive_data(sock: socket.socket) -> dict[str, Any]:
header = _receive_until(sock, HEADER_SIZE)
data_length = int.from_bytes(header, byteorder="big")
body = _receive_until(sock, data_length)
return json.loads(body.decode())
Loading