
Live reload for Gradio demo #89

Merged · 16 commits · Aug 20, 2024
Justfile (2 changes: 1 addition & 1 deletion)
@@ -55,7 +55,7 @@ mds *FLAGS:
poetry run python -m ultravox.tools.mds_tool {{FLAGS}}

gradio *FLAGS:
poetry run python -m ultravox.tools.gradio_demo {{FLAGS}}
poetry run gradio ultravox/tools/gradio_demo.py {{FLAGS}}

run *FLAGS:
poetry run mcli run -f mcloud.yaml --follow {{FLAGS}}
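Context for the Justfile change above: launching the demo through the `gradio` CLI (rather than `python -m`) enables Gradio's auto-reload mode, which re-executes the script every time it is saved. That is also why the demo script below is restructured to build `demo` at module scope and to guard the expensive model construction with `if gr.NO_RELOAD:`. A minimal sketch of that layout, assuming Gradio 4.x and its `gradio` CLI; `expensive_setup()` here is a hypothetical stand-in for constructing `UltravoxInference`, not the project's real loader:

```python
import time

import gradio as gr


def expensive_setup() -> str:
    time.sleep(1)  # pretend this is a slow model load
    return "model-handle"


if gr.NO_RELOAD:
    # Runs once; the reload CLI skips this block when it re-runs the file on save.
    model = expensive_setup()


def echo(message: str) -> str:
    return f"{model} heard: {message}"


# The Blocks object lives at module scope (named `demo`, which the CLI looks for
# by default) so auto-reload can find and relaunch it after each edit.
with gr.Blocks() as demo:
    box = gr.Textbox(label="prompt")
    out = gr.Textbox(label="reply")
    box.submit(echo, [box], [out])

if __name__ == "__main__":
    # Plain `python app.py` still works; `gradio app.py` adds live reload.
    demo.launch()
```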
ultravox/tools/gradio_demo.py (215 changes: 111 additions & 104 deletions)
@@ -5,6 +5,7 @@
import simple_parsing

from ultravox.data import datasets
from ultravox.inference import base as infer_base
from ultravox.inference import ultravox_infer

demo_instruction: str = """Enter your prompt here (audio will be inserted at the end or at <|audio|>).
@@ -28,7 +29,7 @@ class DemoConfig:
temperature: float = 0


def main():
if gr.NO_RELOAD:
args = simple_parsing.parse(config_class=DemoConfig)
inference = ultravox_infer.UltravoxInference(
args.model_path,
@@ -37,116 +38,122 @@ def main():
conversation_mode=True,
)

def add_text(chatbot: gr.Chatbot, text: str) -> gr.Chatbot:
return chatbot + [(text, None)]

def add_audio(chatbot: gr.Chatbot, audio: str) -> gr.Chatbot:
return chatbot + [((audio,), None)]

def process_turn(
chatbot: gr.Chatbot,
prompt: str,
audio: Optional[str] = None,
max_new_tokens: int = 200,
temperature: float = 0,
):
# We want to keep the prompt (mixed audio/text instruction) as is in voice mode, but set it to "" in anticipation of new prompt in text mode.
prompt_to_return = prompt
if audio:
if "<|audio|>" not in prompt:
prompt += "<|audio|>"
sample = datasets.VoiceSample.from_prompt_and_file(prompt, audio)
else:
sample = datasets.VoiceSample.from_prompt(prompt)
prompt_to_return = ""

if len(sample.messages) != 1:
raise ValueError(
f"Expected exactly 1 message in sample but got {len(sample.messages)}"
)

output = inference.infer(
sample,
max_tokens=max_new_tokens,
temperature=temperature,
)
def add_text(chatbot: gr.Chatbot, text: str) -> gr.Chatbot:
# We set the prompt to "" in anticipation of the next prompt in text mode.
return chatbot + [[text, None]], ""

chatbot = chatbot + [(None, output.text)]
return chatbot, gr.update(value=prompt_to_return)

def process_text(chatbot, prompt, max_new_tokens, temperature):
return process_turn(
chatbot, prompt, max_new_tokens=max_new_tokens, temperature=temperature
)
def add_audio(chatbot: gr.Chatbot, audio: str) -> gr.Chatbot:
# We want to keep the prompt (mixed audio/text instruction) as is in voice mode.
return chatbot + [[(audio,), None]]

def process_audio(chatbot, prompt, audio, max_new_tokens, temperature):
return process_turn(
chatbot,
prompt,
audio=audio,
max_new_tokens=max_new_tokens,
temperature=temperature,
)

def gradio_reset():
inference.update_conversation()
return [], "", None

with gr.Blocks() as demo:
chatbot = gr.Chatbot(scale=10, height=1000)

with gr.Row():
with gr.Column(scale=1):
reset = gr.Button("Reset")
audio = gr.Audio(
label="🎤",
sources=["microphone"],
type="filepath",
visible=True,
)
with gr.Column(scale=8):
prompt = gr.Textbox(
show_label=False,
lines=5,
placeholder=demo_instruction,
value=args.default_prompt,
container=True,
)
with gr.Column(scale=1):
max_new_tokens = gr.Slider(
minimum=50,
maximum=2000,
value=args.max_new_tokens,
step=10,
interactive=True,
label="max_new_tokens",
)
temperature = gr.Slider(
minimum=0,
maximum=5.0,
value=args.temperature,
step=0.1,
interactive=True,
label="temperature",
)

prompt.submit(add_text, [chatbot, prompt], [chatbot], queue=False).then(
process_text,
[chatbot, prompt, max_new_tokens, temperature],
[chatbot, prompt],
queue=False,
)
audio.stop_recording(add_audio, [chatbot, audio], [chatbot], queue=False).then(
process_audio,
[chatbot, prompt, audio, max_new_tokens, temperature],
[chatbot, prompt],
queue=False,
def process_turn(
chatbot: gr.Chatbot,
prompt: str,
audio: Optional[str] = None,
max_new_tokens: int = 200,
temperature: float = 0,
):
if audio:
if "<|audio|>" not in prompt:
prompt += "<|audio|>"
sample = datasets.VoiceSample.from_prompt_and_file(prompt, audio)
else:
# Note that prompt will be "" here, since we cleared it in add_text.
# Instead, we can just get it from the chat history.
sample = datasets.VoiceSample.from_prompt(chatbot[-1][0])

if len(sample.messages) != 1:
raise ValueError(
f"Expected exactly 1 message in sample but got {len(sample.messages)}"
)
reset.click(gradio_reset, [], [chatbot, prompt, audio], queue=False)
demo.load(gradio_reset, [], [chatbot, prompt, audio], queue=False)

demo.launch(share=True)
output = inference.infer_stream(
sample,
max_tokens=max_new_tokens,
temperature=temperature,
)
chatbot += [[None, ""]]
for chunk in output:
if isinstance(chunk, infer_base.InferenceChunk):
chatbot[-1][1] += chunk.text
yield chatbot


def process_text(chatbot, prompt, max_new_tokens, temperature):
yield from process_turn(
chatbot, prompt, max_new_tokens=max_new_tokens, temperature=temperature
)


def process_audio(chatbot, prompt, audio, max_new_tokens, temperature):
yield from process_turn(
chatbot,
prompt,
audio=audio,
max_new_tokens=max_new_tokens,
temperature=temperature,
)


def gradio_reset():
inference.reset_conversation()
return [], "", None


with gr.Blocks() as demo:
chatbot = gr.Chatbot(scale=10, height=1000)

with gr.Row():
with gr.Column(scale=1):
reset = gr.Button("Reset")
audio = gr.Audio(
label="🎤",
sources=["microphone"],
type="filepath",
visible=True,
)
with gr.Column(scale=8):
prompt = gr.Textbox(
show_label=False,
lines=5,
placeholder=demo_instruction,
value=args.default_prompt,
container=True,
)
with gr.Column(scale=1):
max_new_tokens = gr.Slider(
minimum=50,
maximum=2000,
value=args.max_new_tokens,
step=10,
interactive=True,
label="max_new_tokens",
)
temperature = gr.Slider(
minimum=0,
maximum=5.0,
value=args.temperature,
step=0.1,
interactive=True,
label="temperature",
)
prompt.submit(add_text, [chatbot, prompt], [chatbot, prompt], queue=False).then(
process_text,
[chatbot, prompt, max_new_tokens, temperature],
[chatbot],
)
audio.stop_recording(add_audio, [chatbot, audio], [chatbot], queue=False).then(
process_audio,
[chatbot, prompt, audio, max_new_tokens, temperature],
[chatbot],
)
reset.click(gradio_reset, [], [chatbot, prompt, audio], queue=False)
demo.load(gradio_reset, [], [chatbot, prompt, audio], queue=False)


if __name__ == "__main__":
main()
demo.queue()
demo.launch(share=True)
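A note on the new handlers above: `process_turn` is now a generator that yields partial chat histories as `infer_stream` produces chunks, and Gradio runs generator event handlers through its queue, which is presumably why `demo.queue()` is called before `demo.launch(share=True)`. A minimal sketch of that streaming pattern, assuming Gradio 4.x; the token list stands in for the chunks returned by `inference.infer_stream()`:

```python
import time
from typing import Iterator

import gradio as gr


def stream_reply(history: list) -> Iterator[list]:
    # Append an empty assistant turn, then fill it in chunk by chunk.
    history = history + [[None, ""]]
    for token in ["Hello", ", ", "world", "!"]:  # stand-in for streamed chunks
        time.sleep(0.2)
        history[-1][1] += token
        yield history  # each yield pushes a partial update to the Chatbot


with gr.Blocks() as demo:
    chatbot = gr.Chatbot()
    go = gr.Button("Stream")
    go.click(stream_reply, [chatbot], [chatbot])

if __name__ == "__main__":
    demo.queue()  # generator handlers stream their updates through the queue
    demo.launch()
```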