Skip to content

Commit

Permalink
Fix api (#321)
Browse files: browse the repository at this point in the history
  • Loading branch information
AnyaCoder authored Jul 4, 2024
1 parent cee42e7 commit 61b6609
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 21 deletions.
12 changes: 3 additions & 9 deletions tools/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,23 +200,17 @@ def inference(req: InvokeRequest):
lab_path, wav_path = get_random_paths(ref_base, ref_data, req.speaker, req.emotion)

if lab_path and wav_path:
with open(wav_path, "rb") as wav_file:
audio_bytes = wav_file.read()
with open(lab_path, "r", encoding="utf-8") as lab_file:
ref_text = lab_file.read()
req.reference_audio = base64.b64encode(audio_bytes).decode("utf-8")
req.reference_audio = wav_path
req.reference_text = ref_text
logger.info("ref_path: " + str(wav_path))
logger.info("ref_text: " + ref_text)

# Parse reference audio aka prompt
prompt_tokens = encode_reference(
decoder_model=decoder_model,
reference_audio=(
io.BytesIO(base64.b64decode(req.reference_audio))
if req.reference_audio is not None
else None
),
reference_audio=req.reference_audio,
enable_reference_audio=req.reference_audio is not None,
)

Expand Down Expand Up @@ -423,7 +417,7 @@ def parse_args():
text="Hello world.",
reference_text=None,
reference_audio=None,
max_new_tokens=1024,
max_new_tokens=0,
top_p=0.7,
repetition_penalty=1.2,
temperature=0.7,
Expand Down
13 changes: 1 addition & 12 deletions tools/post_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,6 @@
import requests


def wav_to_base64(file_path):
    """Return the contents of a file as a base64-encoded UTF-8 string.

    Args:
        file_path: Path of the file to read in binary mode. A falsy value
            (``None`` or an empty string) means no file was supplied.

    Returns:
        The base64 text of the file's bytes, or ``None`` when ``file_path``
        is falsy.

    Raises:
        OSError: If the file cannot be opened or read.
    """
    # No reference file given: signal "nothing to send" instead of failing.
    if not file_path:
        return None

    with open(file_path, "rb") as wav_file:
        wav_content = wav_file.read()

    # Encode after the file is closed; b64encode returns bytes, so decode
    # to get a plain string.
    return base64.b64encode(wav_content).decode("utf-8")


def play_audio(audio_content, format, channels, rate):
p = pyaudio.PyAudio()
stream = p.open(format=format, channels=channels, rate=rate, output=True)
Expand Down Expand Up @@ -88,12 +79,10 @@ def play_audio(audio_content, format, channels, rate):

args = parser.parse_args()

base64_audio = wav_to_base64(args.reference_audio)

data = {
"text": args.text,
"reference_text": args.reference_text,
"reference_audio": base64_audio,
"reference_audio": args.reference_audio,
"max_new_tokens": args.max_new_tokens,
"chunk_length": args.chunk_length,
"top_p": args.top_p,
Expand Down

0 comments on commit 61b6609

Please sign in to comment.