Skip to content

Commit

Permalink
feat: add input streaming
Browse files Browse the repository at this point in the history
  • Loading branch information
flavioschneider committed Aug 5, 2023
1 parent c1b5560 commit b17b05f
Show file tree
Hide file tree
Showing 3 changed files with 83 additions and 5 deletions.
73 changes: 73 additions & 0 deletions elevenlabs/api/tts.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,13 @@
from __future__ import annotations

import base64
import json
import os
from typing import Iterator, Optional

import websockets
from websockets.sync.client import connect

from .base import API, api_base_url_v1
from .model import Model
from .voice import Voice
Expand Down Expand Up @@ -40,3 +46,70 @@ def generate_stream(
for chunk in response.iter_content(chunk_size=stream_chunk_size):
if chunk:
yield chunk

@staticmethod
def generate_stream_input(
text: Iterator[str], voice: Voice, model: Model, api_key: Optional[str] = None
) -> Iterator[bytes]:
message = (
"Currently input streaming is only supported for eleven_monolingual_v1"
" model"
)
assert (
model.model_id == "eleven_monolingual_v1"
), f"{message}, got {model.model_id}"

BOS = json.dumps(
dict(
text=" ",
try_trigger_generation=True,
generation_config=dict(
chunk_length_schedule=[50],
model_id=model.model_id,
voice_settings=voice.settings.dict() if voice.settings else None,
),
)
)
EOS = json.dumps(dict(text=""))

with connect(
f"wss://api.elevenlabs.io/v1/text-to-speech/{voice.voice_id}/stream-input",
additional_headers={
"xi-api-key": api_key or os.environ.get("ELEVEN_API_KEY"),
"model_id": model.model_id,
},
) as websocket:
# Send beginning of stream
websocket.send(BOS)

# Stream text chunks and receive audio
text_block = ""
for text_chunk in text:

text_block += text_chunk
if text_block.endswith((".", "!", "?")):
text_block += " "
if not text_block.endswith(" "):
continue

data = dict(text=text_block, try_trigger_generation=True)
text_block = ""
websocket.send(json.dumps(data))
try:
data = json.loads(websocket.recv(1e-4))
if data["audio"]:
yield base64.b64decode(data["audio"]) # type: ignore
except TimeoutError:
pass

# Send end of stream
websocket.send(EOS)

# Receive remaining audio
while True:
try:
data = json.loads(websocket.recv())
if data["audio"]:
yield base64.b64decode(data["audio"]) # type: ignore
except websockets.exceptions.ConnectionClosed:
break
12 changes: 8 additions & 4 deletions elevenlabs/simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ def is_voice_id(val: str) -> bool:


def generate(
text: str,
text: Union[str, Iterator[str]],
api_key: Optional[str] = None,
voice: Union[str, Voice] = VOICES_CACHE[2], # Bella
model: Union[str, Model] = "eleven_monolingual_v1",
Expand Down Expand Up @@ -123,8 +123,12 @@ def generate(
assert isinstance(model, Model)

if stream:
return TTS.generate_stream(
text, voice, model, stream_chunk_size, api_key=api_key, latency=latency
) # noqa E501
if isinstance(text, str):
return TTS.generate_stream(
text, voice, model, stream_chunk_size, api_key=api_key, latency=latency
) # noqa E501
elif isinstance(text, Iterator):
return TTS.generate_stream_input(text, voice, model, api_key=api_key)
else:
assert isinstance(text, str)
return TTS.generate(text, voice, model, api_key=api_key)
3 changes: 2 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name="elevenlabs",
packages=find_packages(exclude=[]),
version="0.2.21",
version="0.2.22",
description="The official elevenlabs python package.",
long_description_content_type="text/markdown",
author="Elevenlabs",
Expand All @@ -13,6 +13,7 @@
"pydantic>=1.10,<2.0",
"ipython>=7.0",
"requests>=2.20",
"websockets>=11.0",
],
classifiers=[
"Development Status :: 4 - Beta",
Expand Down

0 comments on commit b17b05f

Please sign in to comment.