From b7ebd849eb74a6e7a1b4929fdc5a1a2c991b93d1 Mon Sep 17 00:00:00 2001 From: Flavio Schneider Date: Fri, 2 Jun 2023 00:27:32 +0200 Subject: [PATCH] feat: add listable items, fix voices to return voices object --- .gitignore | 3 +- elevenlabs/api/base.py | 19 +++++- elevenlabs/api/history.py | 12 ++-- elevenlabs/api/model.py | 12 ++-- elevenlabs/api/voice.py | 12 ++-- elevenlabs/simple.py | 122 +++++++++++++++++++------------------- 6 files changed, 96 insertions(+), 84 deletions(-) diff --git a/.gitignore b/.gitignore index f5d5a25..f9eab5f 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,5 @@ __pycache__ .mypy_cache +.pytest_cache .DS_Store -TODO.md +NOTES.md diff --git a/elevenlabs/api/base.py b/elevenlabs/api/base.py index 5cb6df9..fd13135 100644 --- a/elevenlabs/api/base.py +++ b/elevenlabs/api/base.py @@ -1,5 +1,5 @@ import os -from typing import Optional +from typing import Optional, Sequence import requests # type: ignore from pydantic import BaseModel @@ -12,7 +12,7 @@ UnauthenticatedRateLimitError, ) -api_base_url_v1 = "https://api.elevenlabs.io/v1" +api_base_url_v1 = os.environ.get("ELEVEN_BASE_URL", "https://api.elevenlabs.io/v1") class API(BaseModel): @@ -67,3 +67,18 @@ def post(url: str, *args, **kwargs): @staticmethod def delete(url: str, *args, **kwargs): return API.request(url, method="delete", *args, **kwargs) # type: ignore + + +class Listable: + @property + def items(self) -> Sequence: + raise NotImplementedError + + def __getitem__(self, idx: int): + return self.items[idx] + + def __iter__(self): + return iter(self.items) + + def __len__(self) -> int: + return len(self.items) diff --git a/elevenlabs/api/history.py b/elevenlabs/api/history.py index 3e5385c..2744c68 100644 --- a/elevenlabs/api/history.py +++ b/elevenlabs/api/history.py @@ -5,7 +5,7 @@ from pydantic import root_validator -from .base import API, api_base_url_v1 +from .base import API, Listable, api_base_url_v1 from .voice import VoiceSettings @@ -59,7 +59,7 @@ def audio(self) -> bytes: return self._audio -class History(API): +class History(Listable, API): history: List[HistoryItem] @classmethod @@ -68,8 +68,6 @@ def from_api(cls) -> History: response = API.get(url).json() return cls(**response) - def __getitem__(self, idx: int) -> HistoryItem: - return self.history[idx] - - def __iter__(self): - return iter(self.history) + @property + def items(self): + return self.history diff --git a/elevenlabs/api/model.py b/elevenlabs/api/model.py index 6dea5c0..f138397 100644 --- a/elevenlabs/api/model.py +++ b/elevenlabs/api/model.py @@ -2,7 +2,7 @@ from typing import List, Optional -from .base import API, api_base_url_v1 +from .base import API, Listable, api_base_url_v1 class Model(API): @@ -12,7 +12,7 @@ class Model(API): description: Optional[str] -class Models(API): +class Models(Listable, API): models: List[Model] @classmethod @@ -21,8 +21,6 @@ def from_api(cls) -> Models: response = cls.get(url).json() return cls(models=response) - def __getitem__(self, idx: int) -> Model: - return self.models[idx] - - def __iter__(self): - return iter(self.models) + @property + def items(self): + return self.models diff --git a/elevenlabs/api/voice.py b/elevenlabs/api/voice.py index 5ab4766..5111d96 100644 --- a/elevenlabs/api/voice.py +++ b/elevenlabs/api/voice.py @@ -6,7 +6,7 @@ from pydantic import Field, root_validator, validator -from .base import API, api_base_url_v1 +from .base import API, Listable, api_base_url_v1 from .error import APIError @@ -138,7 +138,7 @@ def delete(self): API.delete(f"{api_base_url_v1}/voices/{self.voice_id}") -class Voices(API): +class Voices(Listable, API): voices: List[Voice] @classmethod @@ -150,8 +150,6 @@ def from_api(cls, api_key: Optional[str] = None): def add_clone(self, voice_clone: VoiceClone) -> Voice: pass - def __getitem__(self, idx: int) -> Voice: - return self.voices[idx] - - def __iter__(self): - return iter(self.voices) + @property + def items(self): + return self.voices diff --git a/elevenlabs/simple.py b/elevenlabs/simple.py index 210eba2..8b8a9b9 100644 --- a/elevenlabs/simple.py +++ b/elevenlabs/simple.py @@ -1,6 +1,6 @@ import os import re -from typing import Iterator, List, Optional, Union +from typing import Iterator, Optional, Union from .api import TTS, Model, Voice, VoiceClone, Voices, VoiceSettings @@ -14,65 +14,67 @@ def get_api_key() -> Optional[str]: # Save default voices to avoid querying the API for unathorized users -VOICES_CACHE = [ - Voice( - voice_id="21m00Tcm4TlvDq8ikWAM", - name="Rachel", - category="premade", - settings=VoiceSettings(stability=0.75, similarity_boost=0.75), - ), - Voice( - voice_id="AZnzlk1XvdvUeBnXmlld", - name="Domi", - category="premade", - settings=VoiceSettings(stability=0.1, similarity_boost=0.75), - ), - Voice( - voice_id="EXAVITQu4vr4xnSDxMaL", - name="Bella", - category="premade", - settings=VoiceSettings(stability=0.245, similarity_boost=0.75), - ), - Voice( - voice_id="ErXwobaYiN019PkySvjV", - name="Antoni", - category="premade", - settings=VoiceSettings(stability=0.195, similarity_boost=0.75), - ), - Voice( - voice_id="MF3mGyEYCl7XYWbV9V6O", - name="Elli", - category="premade", - settings=VoiceSettings(stability=0.755, similarity_boost=0.75), - ), - Voice( - voice_id="TxGEqnHWrfWFTfGW9XjX", - name="Josh", - category="premade", - settings=VoiceSettings(stability=0.15, similarity_boost=0.51), - ), - Voice( - voice_id="VR6AewLTigWG4xSOukaG", - name="Arnold", - category="premade", - settings=VoiceSettings(stability=0.15, similarity_boost=0.75), - ), - Voice( - voice_id="pNInz6obpgDQGcFmaJgB", - name="Adam", - category="premade", - settings=VoiceSettings(stability=0.2, similarity_boost=0.75), - ), - Voice( - voice_id="yoZ06aMxZJJ28mfd3POQ", - name="Sam", - category="premade", - settings=VoiceSettings(stability=0.25, similarity_boost=0.75), - ), -] - - -def voices() -> List[Voice]: +VOICES_CACHE = Voices( + voices=[ + Voice( + voice_id="21m00Tcm4TlvDq8ikWAM", + name="Rachel", + category="premade", + settings=VoiceSettings(stability=0.75, similarity_boost=0.75), + ), + Voice( + voice_id="AZnzlk1XvdvUeBnXmlld", + name="Domi", + category="premade", + settings=VoiceSettings(stability=0.1, similarity_boost=0.75), + ), + Voice( + voice_id="EXAVITQu4vr4xnSDxMaL", + name="Bella", + category="premade", + settings=VoiceSettings(stability=0.245, similarity_boost=0.75), + ), + Voice( + voice_id="ErXwobaYiN019PkySvjV", + name="Antoni", + category="premade", + settings=VoiceSettings(stability=0.195, similarity_boost=0.75), + ), + Voice( + voice_id="MF3mGyEYCl7XYWbV9V6O", + name="Elli", + category="premade", + settings=VoiceSettings(stability=0.755, similarity_boost=0.75), + ), + Voice( + voice_id="TxGEqnHWrfWFTfGW9XjX", + name="Josh", + category="premade", + settings=VoiceSettings(stability=0.15, similarity_boost=0.51), + ), + Voice( + voice_id="VR6AewLTigWG4xSOukaG", + name="Arnold", + category="premade", + settings=VoiceSettings(stability=0.15, similarity_boost=0.75), + ), + Voice( + voice_id="pNInz6obpgDQGcFmaJgB", + name="Adam", + category="premade", + settings=VoiceSettings(stability=0.2, similarity_boost=0.75), + ), + Voice( + voice_id="yoZ06aMxZJJ28mfd3POQ", + name="Sam", + category="premade", + settings=VoiceSettings(stability=0.25, similarity_boost=0.75), + ), + ] +) + + +def voices() -> Voices: """Lists all voices in the API, if authenticated for the current user""" api_key = get_api_key() global VOICES_CACHE