Skip to content

Commit

Permalink
feat: add the audio tool (langgenius#10695)
Browse files Browse the repository at this point in the history
  • Loading branch information
hjlarry authored Nov 14, 2024
1 parent b358490 commit 15f341b
Show file tree
Hide file tree
Showing 7 changed files with 224 additions and 0 deletions.
3 changes: 3 additions & 0 deletions api/core/tools/provider/builtin/audio/_assets/icon.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
6 changes: 6 additions & 0 deletions api/core/tools/provider/builtin/audio/audio.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController


class AudioToolProvider(BuiltinToolProviderController):
def _validate_credentials(self, credentials: dict) -> None:
pass
11 changes: 11 additions & 0 deletions api/core/tools/provider/builtin/audio/audio.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
identity:
author: hjlarry
name: audio
label:
en_US: Audio
description:
en_US: A tool for tts and asr.
zh_Hans: 一个用于文本转语音和语音转文本的工具。
icon: icon.svg
tags:
- utilities
70 changes: 70 additions & 0 deletions api/core/tools/provider/builtin/audio/tools/asr.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
import io
from typing import Any

from core.file.enums import FileType
from core.file.file_manager import download
from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter, ToolParameterOption
from core.tools.tool.builtin_tool import BuiltinTool
from services.model_provider_service import ModelProviderService


class ASRTool(BuiltinTool):
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> list[ToolInvokeMessage]:
file = tool_parameters.get("audio_file")
if file.type != FileType.AUDIO:
return [self.create_text_message("not a valid audio file")]
audio_binary = io.BytesIO(download(file))
audio_binary.name = "temp.mp3"
provider, model = tool_parameters.get("model").split("#")
model_manager = ModelManager()
model_instance = model_manager.get_model_instance(
tenant_id=self.runtime.tenant_id,
provider=provider,
model_type=ModelType.SPEECH2TEXT,
model=model,
)
text = model_instance.invoke_speech2text(
file=audio_binary,
user=user_id,
)
return [self.create_text_message(text)]

def get_available_models(self) -> list[tuple[str, str]]:
model_provider_service = ModelProviderService()
models = model_provider_service.get_models_by_model_type(
tenant_id=self.runtime.tenant_id, model_type="speech2text"
)
items = []
for provider_model in models:
provider = provider_model.provider
for model in provider_model.models:
items.append((provider, model.model))
return items

def get_runtime_parameters(self) -> list[ToolParameter]:
parameters = []

options = []
for provider, model in self.get_available_models():
option = ToolParameterOption(value=f"{provider}#{model}", label=I18nObject(en_US=f"{model}({provider})"))
options.append(option)

parameters.append(
ToolParameter(
name="model",
label=I18nObject(en_US="Model", zh_Hans="Model"),
human_description=I18nObject(
en_US="All available ASR models",
zh_Hans="所有可用的 ASR 模型",
),
type=ToolParameter.ToolParameterType.SELECT,
form=ToolParameter.ToolParameterForm.FORM,
required=True,
default=options[0].value,
options=options,
)
)
return parameters
22 changes: 22 additions & 0 deletions api/core/tools/provider/builtin/audio/tools/asr.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
identity:
name: asr
author: hjlarry
label:
en_US: Speech To Text
description:
human:
en_US: Convert audio file to text.
zh_Hans: 将音频文件转换为文本。
llm: Convert audio file to text.
parameters:
- name: audio_file
type: file
required: true
label:
en_US: Audio File
zh_Hans: 音频文件
human_description:
en_US: The audio file to be converted.
zh_Hans: 要转换的音频文件。
llm_description: The audio file to be converted.
form: llm
90 changes: 90 additions & 0 deletions api/core/tools/provider/builtin/audio/tools/tts.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
import io
from typing import Any

from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter, ToolParameterOption
from core.tools.tool.builtin_tool import BuiltinTool
from services.model_provider_service import ModelProviderService


class TTSTool(BuiltinTool):
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> list[ToolInvokeMessage]:
provider, model = tool_parameters.get("model").split("#")
voice = tool_parameters.get(f"voice#{provider}#{model}")
model_manager = ModelManager()
model_instance = model_manager.get_model_instance(
tenant_id=self.runtime.tenant_id,
provider=provider,
model_type=ModelType.TTS,
model=model,
)
tts = model_instance.invoke_tts(
content_text=tool_parameters.get("text"),
user=user_id,
tenant_id=self.runtime.tenant_id,
voice=voice,
)
buffer = io.BytesIO()
for chunk in tts:
buffer.write(chunk)

wav_bytes = buffer.getvalue()
return [
self.create_text_message("Audio generated successfully"),
self.create_blob_message(
blob=wav_bytes,
meta={"mime_type": "audio/x-wav"},
save_as=self.VariableKey.AUDIO,
),
]

def get_available_models(self) -> list[tuple[str, str, list[Any]]]:
model_provider_service = ModelProviderService()
models = model_provider_service.get_models_by_model_type(tenant_id=self.runtime.tenant_id, model_type="tts")
items = []
for provider_model in models:
provider = provider_model.provider
for model in provider_model.models:
voices = model.model_properties.get(ModelPropertyKey.VOICES, [])
items.append((provider, model.model, voices))
return items

def get_runtime_parameters(self) -> list[ToolParameter]:
parameters = []

options = []
for provider, model, voices in self.get_available_models():
option = ToolParameterOption(value=f"{provider}#{model}", label=I18nObject(en_US=f"{model}({provider})"))
options.append(option)
parameters.append(
ToolParameter(
name=f"voice#{provider}#{model}",
label=I18nObject(en_US=f"Voice of {model}({provider})"),
type=ToolParameter.ToolParameterType.SELECT,
form=ToolParameter.ToolParameterForm.FORM,
options=[
ToolParameterOption(value=voice.get("mode"), label=I18nObject(en_US=voice.get("name")))
for voice in voices
],
)
)

parameters.insert(
0,
ToolParameter(
name="model",
label=I18nObject(en_US="Model", zh_Hans="Model"),
human_description=I18nObject(
en_US="All available TTS models",
zh_Hans="所有可用的 TTS 模型",
),
type=ToolParameter.ToolParameterType.SELECT,
form=ToolParameter.ToolParameterForm.FORM,
required=True,
default=options[0].value,
options=options,
),
)
return parameters
22 changes: 22 additions & 0 deletions api/core/tools/provider/builtin/audio/tools/tts.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
identity:
name: tts
author: hjlarry
label:
en_US: Text To Speech
description:
human:
en_US: Convert text to audio file.
zh_Hans: 将文本转换为音频文件。
llm: Convert text to audio file.
parameters:
- name: text
type: string
required: true
label:
en_US: Text
zh_Hans: 文本
human_description:
en_US: The text to be converted.
zh_Hans: 要转换的文本。
llm_description: The text to be converted.
form: llm

0 comments on commit 15f341b

Please sign in to comment.