Skip to content

Commit

Permalink
update to the new plugin format
Browse files Browse the repository at this point in the history
  • Loading branch information
andantei committed Feb 25, 2023
1 parent c9f39ff commit f36469d
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 20 deletions.
28 changes: 28 additions & 0 deletions bundle.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
{
"plugins": [
{
"providerId": "andantei",
"providerDisplayName": {
"zh": "Andantei行板",
"en": "Andantei"
},
"pluginId": "audioldm-generate",
"pluginDisplayName": {
"zh": "[AI] 文字生成音频 (AudioLDM)",
"en": "[AI] Text-to-Audio (AudioLDM)"
},
"pluginDescription": {
"zh": "用一段提示词生成音频或音乐",
"en": "Generate audio or music from a short prompt text"
},
"version": "1.0.0",
"supportedPlatforms": ["web", "desktop"],
"isInDevelopment": true,
"minRequiredDesktopVersion": "1.8.9",
"options": {
"allowReset": true,
"allowManualApplyAdjust": true
}
}
]
}
4 changes: 3 additions & 1 deletion debug.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from plugin import AudioLDMPlugin
from tuneflow_devkit import Debugger
from pathlib import Path

if __name__ == "__main__":
Debugger(plugin_class=AudioLDMPlugin).start()
Debugger(plugin_class=AudioLDMPlugin, bundle_file_path=str(
Path(__file__).parent.joinpath('bundle.json').absolute())).start()
33 changes: 14 additions & 19 deletions plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,7 @@ def provider_display_name() -> LabelText:
}

@staticmethod
def plugin_display_name() -> LabelText:
return {
"zh": "[AI] 文字生成音频 (AudioLDM)",
"en": "[AI] Text-to-Audio (AudioLDM)"
}

def params(self) -> Dict[str, ParamDescriptor]:
# TODO: Limit prompt length
def params(song: Song, read_apis: ReadAPIs) -> Dict[str, ParamDescriptor]:
return {
"prompt": {
"displayName": {
Expand Down Expand Up @@ -70,7 +63,7 @@ def params(self) -> Dict[str, ParamDescriptor]:
"type": WidgetType.InputNumber.value,
"config": {
"minValue": 0.1,
"maxValue": 5,
"maxValue": 10,
"step": 0.1
}
}
Expand All @@ -85,20 +78,20 @@ def params(self) -> Dict[str, ParamDescriptor]:
"type": WidgetType.InputNumber.value,
"config": {
"minValue": 2.5,
"maxValue": 100,
"maxValue": 10,
"step": 2.5
}
}
}
}

def init(self, song: Song, read_apis: ReadAPIs):
@staticmethod
def run(song: Song, params: Dict[str, Any], read_apis: ReadAPIs):
model_path = str(Path(__file__).parent.joinpath('ckpt/ldm_trimmed.ckpt').absolute())
self.model = build_model(ckpt_path=model_path)

def run(self, song: Song, params: Dict[str, Any], read_apis: ReadAPIs):
model = build_model(ckpt_path=model_path)
# TODO: Support prompt i18n
file_bytes_list = self._text2audio(
file_bytes_list = AudioLDMPlugin._text2audio(
model,
text=params["prompt"],
duration=params["duration"],
guidance_scale=params["guidance_scale"],
Expand All @@ -119,20 +112,22 @@ def run(self, song: Song, params: Dict[str, Any], read_apis: ReadAPIs):
except:
print(traceback.format_exc())

def _text2audio(self, text, duration, guidance_scale, random_seed):
@staticmethod
def _text2audio(model, text, duration, guidance_scale, random_seed):
# print(text, length, guidance_scale)
waveform = text_to_audio(
self.model,
model,
text=text,
seed=random_seed,
duration=duration,
guidance_scale=guidance_scale,
n_candidate_gen_per_text=3,
batchsize=1,
)
return self._save_wave(waveform)
return AudioLDMPlugin._save_wave(waveform)

def _save_wave(self, waveform):
@staticmethod
def _save_wave(waveform):
saved_file_bytes: List[BytesIO] = []
for i in range(waveform.shape[0]):
file_bytes = BytesIO()
Expand Down

0 comments on commit f36469d

Please sign in to comment.