Skip to content

Commit

Permalink
Merge pull request #69 from futaringoto/fix/do_formatting
Browse files Browse the repository at this point in the history
フォーマッタを実行した
  • Loading branch information
shun-harutaro authored Aug 13, 2024
2 parents 34a0392 + 9cd06db commit 30cdb76
Show file tree
Hide file tree
Showing 10 changed files with 99 additions and 104 deletions.
8 changes: 7 additions & 1 deletion api/main.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,28 @@
from contextlib import asynccontextmanager
from typing import Union

from fastapi import FastAPI
from contextlib import asynccontextmanager

from routers import raspi
from utils.config import check_env_variables


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan hook passed to ``FastAPI(lifespan=...)``.

    Calls ``check_env_variables()`` before yielding control to the
    application and prints a shutdown message once the context is resumed
    after the ``yield``.
    """
    # Startup phase: everything before the yield.
    check_env_variables()
    yield
    # Shutdown phase: everything after the yield.
    print("Shutting down...")


# Build the application with the `lifespan` hook and register the routes
# exposed by the `raspi` router module.
app = FastAPI(lifespan=lifespan)
app.include_router(raspi.router)


@app.get("/")
def read_root():
    """Handle ``GET /`` by returning a fixed greeting payload."""
    greeting = {"Hello": "World"}
    return greeting


@app.get("/items/{item_id}")
def read_item(item_id: int, q: Union[str, None] = None):
    """Handle ``GET /items/{item_id}``.

    Echoes back ``item_id`` together with the optional ``q`` value
    (``None`` when it is not supplied).
    """
    payload = dict(item_id=item_id, q=q)
    return payload
118 changes: 45 additions & 73 deletions api/routers/raspi.py
Original file line number Diff line number Diff line change
@@ -1,31 +1,27 @@
import os
import tempfile
from datetime import datetime
from typing import Any, Dict

from fastapi import APIRouter, File, HTTPException, UploadFile
from fastapi.responses import FileResponse, JSONResponse
from datetime import datetime
from httpx import HTTPStatusError, RequestError

from services.gpt import generate_text
from services.tts import text2speech
from services.voicevox import audio_query, synthesis
from services.whisper import speech2text
#from services.tts import text2speech
from services.voicevox_api import get_voicevox_audio
from services.whisper import speech2text
from utils.log import upload_json_to_blob
from httpx import RequestError, HTTPStatusError
import tempfile
import os

# Router that collects the /raspi endpoints declared in this module.
router = APIRouter()

# Directory (relative to the working directory) under which the paths for
# uploaded files are built; it is created at import time if it is missing.
UPLOAD_DIR = "uploads"
os.makedirs(UPLOAD_DIR, exist_ok=True)

@router.post(
"/raspi/",
tags=["raspi"],
summary="一連の動作全て"
)
async def all(
speaker: int = 1,
file: UploadFile = File(...)
) -> JSONResponse:

@router.post("/raspi/", tags=["raspi"], summary="一連の動作全て")
async def all(speaker: int = 1, file: UploadFile = File(...)) -> JSONResponse:
file_location = os.path.join(UPLOAD_DIR, file.filename)
try:
# whisper
Expand All @@ -45,72 +41,63 @@ async def all(
log_data = {
"timestamp": iso_time,
"transcription": transcription.text,
"generated_text": generated_text
"generated_text": generated_text,
}
upload_json_to_blob(log_data)

# voicevox
#query: Dict[str, Any] = await audio_query(generated_text, speaker)
#audio: bytes = await synthesis(query, speaker)
#audio: bytes = text2speech(generated_text)
# query: Dict[str, Any] = await audio_query(generated_text, speaker)
# audio: bytes = await synthesis(query, speaker)
# audio: bytes = text2speech(generated_text)
audio: bytes = await get_voicevox_audio(generated_text, speaker)
temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".wav")
with open(temp_file.name, "wb") as f:
f.write(audio)
except RequestError as e:
raise HTTPException(status_code=500, detail=f"RequestError fetching data: {str(e)}")
raise HTTPException(
status_code=500, detail=f"RequestError fetching data: {str(e)}"
)
except HTTPStatusError as e:
raise HTTPException(status_code=e.response.status_code, detail=f"Error fetching data: {str(e)}")
return FileResponse(
temp_file.name,
media_type="audio/wav",
filename="audio.wav"
)

@router.post(
"/raspi/audio",
tags=["raspi"],
summary="VOICEVOXによる音声合成"
)
raise HTTPException(
status_code=e.response.status_code, detail=f"Error fetching data: {str(e)}"
)
return FileResponse(temp_file.name, media_type="audio/wav", filename="audio.wav")


@router.post("/raspi/audio", tags=["raspi"], summary="VOICEVOXによる音声合成")
async def audio(text: str, speaker: int = 1):
    """Synthesize ``text`` with VOICEVOX and return the result as a WAV file.

    Args:
        text: Text passed to ``audio_query``.
        speaker: Speaker id passed to both ``audio_query`` and ``synthesis``.

    Returns:
        A ``FileResponse`` (``audio/wav``, download name ``audio.wav``)
        backed by a temporary file that holds the synthesized bytes.

    Raises:
        HTTPException: 500 on ``httpx.RequestError``; the upstream status
            code on ``httpx.HTTPStatusError``.
    """
    try:
        query: Dict[str, Any] = await audio_query(text, speaker)
        content: bytes = await synthesis(query, speaker)
        # Write through the NamedTemporaryFile handle and let the `with`
        # block close it, instead of leaving that handle open and reopening
        # the file by name. delete=False keeps the file on disk so the
        # FileResponse below can still read it after the handle is closed.
        # NOTE(review): nothing in this function removes the temp file later.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as temp_file:
            temp_file.write(content)
    except RequestError as e:
        raise HTTPException(
            status_code=500, detail=f"Error fetching data: {str(e)}"
        ) from e
    except HTTPStatusError as e:
        raise HTTPException(
            status_code=e.response.status_code, detail=f"Error fetching data: {str(e)}"
        ) from e
    return FileResponse(temp_file.name, media_type="audio/wav", filename="audio.wav")


@router.post("/raspi/tts")
async def tts(text: str):
    """Convert ``text`` to speech with ``text2speech`` and return a WAV file.

    Args:
        text: Text to be spoken.

    Returns:
        A ``FileResponse`` (``audio/wav``, download name ``audio.wav``)
        backed by a temporary file that holds the generated bytes.

    Raises:
        HTTPException: 500 on ``httpx.RequestError``; the upstream status
            code on ``httpx.HTTPStatusError``.
    """
    try:
        content: bytes = text2speech(text)
        # Write through the NamedTemporaryFile handle and let the `with`
        # block close it, instead of leaving that handle open and reopening
        # the file by name. delete=False keeps the file on disk so the
        # FileResponse below can still read it after the handle is closed.
        # NOTE(review): nothing in this function removes the temp file later.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as temp_file:
            temp_file.write(content)
    # NOTE(review): text2speech goes through the OpenAI client, which
    # presumably raises its own exception types rather than these httpx
    # ones; confirm whether these handlers can actually fire.
    except RequestError as e:
        raise HTTPException(
            status_code=500, detail=f"Error fetching data: {str(e)}"
        ) from e
    except HTTPStatusError as e:
        raise HTTPException(
            status_code=e.response.status_code, detail=f"Error fetching data: {str(e)}"
        ) from e
    return FileResponse(temp_file.name, media_type="audio/wav", filename="audio.wav")


@router.post("/raspi/transcript", tags=["raspi"], summary="whisperによる文字起こし")
async def transcript(file: UploadFile = File(...)) -> JSONResponse:
file_location = os.path.join(UPLOAD_DIR, file.filename)
try:
Expand All @@ -120,30 +107,15 @@ async def transcript(file: UploadFile = File(...)) -> JSONResponse:
transcription: str = speech2text(file_location)
os.remove(file_location)

return JSONResponse(
content={"transcript": transcription.text},
status_code=200
)
return JSONResponse(content={"transcript": transcription.text}, status_code=200)
except Exception as e:
return JSONResponse(
content={"error": str(e)},
status_code=500
)
return JSONResponse(content={"error": str(e)}, status_code=500)

@router.post(
"/raspi/gpt",
tags=["raspi"],
summary="chatGPTによる文章生成"
)

@router.post("/raspi/gpt", tags=["raspi"], summary="chatGPTによる文章生成")
async def gpt(text: str) -> JSONResponse:
    """Generate a reply for ``text`` and return it as JSON.

    Responds 200 with ``{"generatedText": ...}`` on success; any exception
    raised while generating or building the response is reported as a 500
    with ``{"error": <message>}``.
    """
    try:
        body = {"generatedText": generate_text(text)}
        response = JSONResponse(content=body, status_code=200)
    except Exception as exc:
        response = JSONResponse(content={"error": str(exc)}, status_code=500)
    return response
9 changes: 7 additions & 2 deletions api/services/gpt.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,20 @@
from openai import OpenAI

from utils.config import get_openai_api_key

# Hand the key to the client explicitly. The v1 `OpenAI` client takes its key
# from the `api_key` constructor argument (falling back to the OPENAI_API_KEY
# environment variable when it is None), not from an attribute set on the class.
client = OpenAI(api_key=get_openai_api_key())


def generate_text(input_text: str) -> str:
    """Ask the chat model for a reply to ``input_text``.

    Sends a fixed Japanese system prompt followed by ``input_text`` as the
    user message to ``gpt-4o-mini`` and returns the message content of the
    first choice.
    """
    system_message = {
        "role": "system",
        "content": "テスト期間中の学生を親の視点から励ますような言葉を生成してください。日本語で40文字程度になるようにお願いします。",
    }
    user_message = {"role": "user", "content": input_text}
    completion = client.chat.completions.create(
        model="gpt-4o-mini",
        messages=[system_message, user_message],
    )
    first_choice = completion.choices[0]
    return first_choice.message.content
6 changes: 2 additions & 4 deletions api/services/tts.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,9 @@
# Hand the key to the client explicitly. The v1 `OpenAI` client takes its key
# from the `api_key` constructor argument (falling back to the OPENAI_API_KEY
# environment variable when it is None), not from an attribute set on the class.
client = OpenAI(api_key=get_openai_api_key())


def text2speech(text: str) -> bytes:
    """Convert ``text`` to speech and return the audio bytes.

    Requests WAV output from the ``tts-1`` model with the ``shimmer`` voice
    and returns the ``content`` of the response.
    """
    request_options = {
        "model": "tts-1",
        "voice": "shimmer",
        "input": text,
        "response_format": "wav",
    }
    speech = client.audio.speech.create(**request_options)
    return speech.content
13 changes: 7 additions & 6 deletions api/services/voicevox.py
Original file line number Diff line number Diff line change
@@ -1,32 +1,33 @@
import json
import urllib.parse
from typing import Any, Dict

import httpx
import urllib.parse

from utils.config import get_voicevox_url

# Base URL of the VOICEVOX engine, read once at import time; the endpoint
# paths used in this module are joined onto it.
# NOTE(review): get_voicevox_url() returns os.getenv(...), so this can be None
# when VOICEVOX_URL is unset despite the `str` annotation.
url: str = get_voicevox_url()


async def audio_query(text: str, speaker: int) -> Dict[str, Any]:
    """POST to the ``/audio_query`` endpoint and return the decoded JSON body.

    ``text`` and ``speaker`` are sent as query parameters. An error status
    raises ``httpx.HTTPStatusError`` via ``raise_for_status``.
    """
    endpoint = urllib.parse.urljoin(url, "/audio_query")
    async with httpx.AsyncClient() as client:
        res = await client.post(endpoint, params={"text": text, "speaker": speaker})
        res.raise_for_status()
        return res.json()


async def synthesis(query: Dict[str, Any], speaker: int) -> bytes:
    """POST ``query`` to the ``/synthesis`` endpoint and return the audio bytes.

    Args:
        query: Audio query dict, serialized with ``json.dumps`` and sent as
            the request body.
        speaker: Speaker id, sent as the ``speaker`` query parameter.

    Returns:
        The raw response body.

    Raises:
        httpx.HTTPStatusError: If the response has an error status.
    """
    # This call uses an explicit 20 s timeout rather than the httpx default.
    async with httpx.AsyncClient(timeout=httpx.Timeout(20.0)) as client:
        res = await client.post(
            urllib.parse.urljoin(url, "/synthesis"),
            params={"speaker": speaker},
            # Pre-encoded bodies belong in `content=`; passing a str through
            # `data=` is deprecated in httpx. The bytes sent are unchanged.
            content=json.dumps(query),
            headers={"Content-Type": "application/json"},
        )
        res.raise_for_status()
        return res.content


7 changes: 3 additions & 4 deletions api/services/voicevox_api.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,15 @@
import httpx

from utils.config import get_voicevox_api_key

# Endpoint that get_voicevox_audio posts to.
url: str = "https://deprecatedapis.tts.quest/v2/voicevox/audio/"
# API key sent as the `key` query parameter; read once at import time.
API_KEY = get_voicevox_api_key()


async def get_voicevox_audio(text: str, speaker: int) -> bytes:
    """Request synthesized audio for ``text`` from the VOICEVOX web API.

    ``text``, ``speaker`` and the API key are sent as query parameters of a
    POST request made with a 10 s timeout. Returns the raw response body; an
    error status raises ``httpx.HTTPStatusError`` via ``raise_for_status``.
    """
    query_params = {"text": text, "speaker": speaker, "key": API_KEY}
    timeout = httpx.Timeout(10.0)
    async with httpx.AsyncClient(timeout=timeout) as client:
        res = await client.post(url, params=query_params)
        res.raise_for_status()
        return res.content
7 changes: 3 additions & 4 deletions api/services/whisper.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,14 @@
from openai import OpenAI

from utils.config import get_openai_api_key

# Hand the key to the client explicitly. The v1 `OpenAI` client takes its key
# from the `api_key` constructor argument (falling back to the OPENAI_API_KEY
# environment variable when it is None), not from an attribute set on the class.
client = OpenAI(api_key=get_openai_api_key())


def speech2text(file_location: str) -> str:
    """Transcribe the audio file at ``file_location`` with ``whisper-1``.

    Args:
        file_location: Path of the audio file to send for transcription.

    Returns:
        Whatever ``client.audio.transcriptions.create`` returns for a
        ``response_format="json"`` request in Japanese (``language="ja"``).
    """
    # NOTE(review): the `-> str` annotation is kept for interface stability,
    # but the value returned is the API response object; the router code
    # reads `.text` from it.
    # Open the file in a `with` block so the handle is closed as soon as the
    # request finishes. The routers delete the file right after this call.
    with open(file_location, "rb") as audio_file:
        transcription = client.audio.transcriptions.create(
            model="whisper-1", file=audio_file, response_format="json", language="ja"
        )
    return transcription
4 changes: 3 additions & 1 deletion api/tests/test_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,17 @@

from main import app


@pytest.fixture
def client():
    """Provide a ``TestClient`` bound to the application under test."""
    test_client = TestClient(app)
    return test_client


@pytest.mark.asyncio
async def test_transcribe_and_respond(client: TestClient):
audio_file_path = "tests/audio1.wav"
with open(audio_file_path, "rb") as audio_file:
files = {'file': ("audio1.wav", audio_file, 'multipart/form-data')}
files = {"file": ("audio1.wav", audio_file, "multipart/form-data")}

response = client.post("/raspi", files=files)

Expand Down
19 changes: 14 additions & 5 deletions api/utils/config.py
Original file line number Diff line number Diff line change
@@ -1,36 +1,45 @@
import os


def check_env_variables():
    """Verify that every required environment variable is set.

    ``OPENAI_API_KEY`` and ``VOICEVOX_API_KEY`` are always required;
    ``STORAGE_ACCOUNT_NAME`` and ``SAS_TOKEN`` are additionally required when
    ``get_is_dev_mode()`` returns false.

    Raises:
        EnvironmentError: Naming every required variable that is unset.
    """
    required = ["OPENAI_API_KEY", "VOICEVOX_API_KEY"]
    if not get_is_dev_mode():
        required += ["STORAGE_ACCOUNT_NAME", "SAS_TOKEN"]
    missing = []
    for name in required:
        if os.getenv(name) is None:
            missing.append(name)
    if not missing:
        return
    raise EnvironmentError(f"Missing environment variables: {', '.join(missing)}")


def get_is_dev_mode() -> bool:
    """Report whether ``IS_DEV_MODE`` selects development mode.

    Returns:
        ``True`` only when ``IS_DEV_MODE`` parses as the integer ``1``.
        An unset or non-numeric value yields ``False`` instead of raising,
        so a missing variable falls back to the stricter non-dev behavior.
    """
    raw = os.getenv("IS_DEV_MODE")
    if raw is None:
        # int(None) would raise TypeError; treat "unset" as not dev mode.
        return False
    try:
        return int(raw) == 1
    except ValueError:
        # Values such as "" or "true" are not a valid flag.
        return False


def get_voicevox_url():
    """Return the ``VOICEVOX_URL`` environment variable, or ``None`` if unset."""
    return os.environ.get("VOICEVOX_URL")


def get_openai_api_key():
    """Return the ``OPENAI_API_KEY`` environment variable, or ``None`` if unset."""
    return os.environ.get("OPENAI_API_KEY")


def get_storage_account_name():
    """Return the ``STORAGE_ACCOUNT_NAME`` environment variable, or ``None`` if unset."""
    return os.environ.get("STORAGE_ACCOUNT_NAME")


def get_sas_token():
    """Return the ``SAS_TOKEN`` environment variable, or ``None`` if unset."""
    return os.environ.get("SAS_TOKEN")


def get_voicevox_api_key():
    """Return the ``VOICEVOX_API_KEY`` environment variable, or ``None`` if unset."""
    return os.environ.get("VOICEVOX_API_KEY")
Loading

0 comments on commit 30cdb76

Please sign in to comment.