Live reload for Gradio demo #89

Merged: 16 commits, Aug 20, 2024
2 changes: 1 addition & 1 deletion Justfile
@@ -55,7 +55,7 @@ mds *FLAGS:
     poetry run python -m ultravox.tools.mds_tool {{FLAGS}}

 gradio *FLAGS:
-    poetry run python -m ultravox.tools.gradio_demo {{FLAGS}}
+    poetry run gradio ultravox/tools/gradio_demo.py {{FLAGS}}

 run *FLAGS:
     poetry run mcli run -f mcloud.yaml --follow {{FLAGS}}
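Reviewer note: switching from `python -m` to the `gradio` CLI is what enables live reload. The CLI watches the script and re-executes it on save, skipping anything guarded by `gr.NO_RELOAD`. Below is a standalone sketch of that pattern; it is not part of this PR, it assumes a Gradio version that provides `gr.NO_RELOAD`, and the sleep stands in for a real model load.

# sketch_reload.py -- standalone illustration, not this repo's code.
# Run with: gradio sketch_reload.py
import time

import gradio as gr

if gr.NO_RELOAD:
    # Stand-in for expensive one-time setup (e.g. loading a model); the
    # reload server skips this block when it re-executes the file on save.
    time.sleep(2)
    greeting = "model ready"

def echo(message: str) -> str:
    return f"{greeting}: {message}"

with gr.Blocks() as demo:
    inp = gr.Textbox(label="input")
    out = gr.Textbox(label="output")
    inp.submit(echo, inp, out)

if __name__ == "__main__":
    demo.launch()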
72 changes: 45 additions & 27 deletions ultravox/inference/infer.py
@@ -1,4 +1,5 @@
 import copy
+import queue
 import threading
 from typing import Dict, List, Optional, Tuple, Union

@@ -53,6 +54,25 @@ def _get_sample_with_past(
         sample.add_past_messages(self.past_messages)
         return sample

+    def _build_past_messages(
+        self,
+        query_messages: List[Dict[str, str]],
+        audio_token_len: int,
+        response_content: str,
+    ) -> List[Dict[str, str]]:
+        messages = query_messages[:]
+        if audio_token_len > 0:
+            user_content = messages[-1]["content"]
+            if user_content.count("<|audio|>") != 1:
+                raise ValueError(
+                    f"Expected 1 audio placeholder, found {user_content.count('<|audio|>')}"
+                )
+            messages[-1]["content"] = user_content.replace(
+                "<|audio|>", self.tokenizer.eos_token * audio_token_len
+            )
+        messages.append({"role": "assistant", "content": response_content})
+        return messages

     def infer(
         self,
         sample: datasets.VoiceSample,
@@ -68,55 +88,53 @@ def infer(
         output_tokens = output.sequences[0][input_len:]
         output_text = self.tokenizer.decode(output_tokens, skip_special_tokens=True)
         output_len = len(output_tokens)

         if self.conversation_mode:
-            past_messages = copy.deepcopy(extended_sample.messages)
-            audio_token_len = (
-                0 if "audio_token_len" not in inputs else inputs["audio_token_len"][0]
+            audio_token_len = inputs.get("audio_token_len", [0])[0]
+            past_messages = self._build_past_messages(
+                extended_sample.messages, audio_token_len, output_text
             )
-            if audio_token_len > 0:
-                user_content = past_messages[-1]["content"]
-                if user_content.count("<|audio|>") != 1:
-                    raise ValueError(
-                        f"Expected 1 audio placeholder, found {user_content.count('<|audio|>')}"
-                    )
-                past_messages[-1]["content"] = user_content.replace(
-                    "<|audio|>", self.tokenizer.eos_token * audio_token_len
-                )
-            past_messages.append({"role": "assistant", "content": output_text})
             self.update_conversation(past_messages, output.past_key_values)

         return base.VoiceOutput(output_text, input_len, output_len)

-    # streaming is not supported in conversation mode yet, to be implemented
     def infer_stream(
         self,
         sample: datasets.VoiceSample,
         max_tokens: Optional[int] = None,
         temperature: Optional[float] = None,
     ) -> base.InferenceGenerator:
-        inputs = self._dataproc(sample)
+        extended_sample = self._get_sample_with_past(sample)
+        inputs = self._dataproc(extended_sample)
         input_tokens = inputs["input_ids"].shape[1]
         decode_kwargs = {"skip_special_tokens": True}
         streamer = transformers.TextIteratorStreamer(
             self.tokenizer, skip_prompt=True, decode_kwargs=decode_kwargs
         )

-        thread_args = (
-            inputs,
-            max_tokens,
-            temperature,
-            streamer,
-        )
-        thread = threading.Thread(target=self._generate, args=thread_args)
+        def thunk(q: queue.Queue):
+            result = self._generate(
+                inputs, max_tokens, temperature, streamer, self.past_key_values
+            )
+            q.put(result)
+
+        result_queue: queue.Queue = queue.Queue()
+        thread = threading.Thread(target=thunk, args=(result_queue,))
         thread.start()
-        output_tokens = 0
+        output_text = ""
+        output_token_len = 0
         for chunk in streamer:
             if chunk:
+                output_text += chunk
+                output_token_len += 1
                 yield base.InferenceChunk(chunk)
-                output_tokens += 1
-        yield base.InferenceStats(input_tokens, output_tokens)
         thread.join()
+        output = result_queue.get()
+        if self.conversation_mode:
+            audio_token_len = inputs.get("audio_token_len", [0])[0]
+            past_messages = self._build_past_messages(
+                extended_sample.messages, audio_token_len, output_text
+            )
+            self.update_conversation(past_messages, output.past_key_values)
+        yield base.InferenceStats(input_tokens, output_token_len)

     def _dataproc(self, sample: datasets.VoiceSample):
         text_input = self.tokenizer.apply_chat_template(
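Reviewer note: the thunk-plus-queue shape above is the usual way to recover a return value from `transformers.TextIteratorStreamer`, which yields only decoded text. Generation runs on a worker thread while the caller drains the streamer, and the thread's result comes back through the queue. Below is a standalone sketch of the same pattern; it is not this repo's code, and `gpt2` appears only because it is small.

# standalone_stream.py -- illustration of the thread + queue streaming pattern.
import queue
import threading

import transformers

tokenizer = transformers.AutoTokenizer.from_pretrained("gpt2")
model = transformers.AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer("The quick brown fox", return_tensors="pt")
streamer = transformers.TextIteratorStreamer(tokenizer, skip_prompt=True)
results: queue.Queue = queue.Queue()

def generate_and_report(q: queue.Queue):
    # generate() blocks until done, feeding the streamer as tokens are sampled;
    # the queue carries the full result back to the consuming thread.
    q.put(model.generate(**inputs, max_new_tokens=20, streamer=streamer))

thread = threading.Thread(target=generate_and_report, args=(results,))
thread.start()
for chunk in streamer:
    print(chunk, end="", flush=True)  # chunks arrive while generation runs
thread.join()
output_ids = results.get()  # complete sequence, available after the join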
4 changes: 1 addition & 3 deletions ultravox/inference/utils.py
@@ -5,9 +5,7 @@ def default_device():
     return (
         "cuda"
         if torch.cuda.is_available()
-        # until https://github.com/pytorch/pytorch/issues/77764 is resolved
-        # else "mps" if torch.backends.mps.is_available() else "cpu"
-        else "cpu"
+        else "mps" if torch.backends.mps.is_available() else "cpu"
     )
215 changes: 111 additions & 104 deletions ultravox/tools/gradio_demo.py
@@ -5,6 +5,7 @@
 import simple_parsing

 from ultravox.data import datasets
+from ultravox.inference import base as infer_base
 from ultravox.inference import ultravox_infer

 demo_instruction: str = """Enter your prompt here (audio will be inserted at the end or at <|audio|>).
@@ -28,7 +29,7 @@ class DemoConfig:
     temperature: float = 0


-def main():
+if gr.NO_RELOAD:
     args = simple_parsing.parse(config_class=DemoConfig)
     inference = ultravox_infer.UltravoxInference(
         args.model_path,
@@ -37,116 +38,122 @@ def main():
         conversation_mode=True,
     )

-    def add_text(chatbot: gr.Chatbot, text: str) -> gr.Chatbot:
-        return chatbot + [(text, None)]
-
-    def add_audio(chatbot: gr.Chatbot, audio: str) -> gr.Chatbot:
-        return chatbot + [((audio,), None)]
-
-    def process_turn(
-        chatbot: gr.Chatbot,
-        prompt: str,
-        audio: Optional[str] = None,
-        max_new_tokens: int = 200,
-        temperature: float = 0,
-    ):
-        # We want to keep the prompt (mixed audio/text instruction) as is in voice mode, but set it to "" in anticipation of new prompt in text mode.
-        prompt_to_return = prompt
-        if audio:
-            if "<|audio|>" not in prompt:
-                prompt += "<|audio|>"
-            sample = datasets.VoiceSample.from_prompt_and_file(prompt, audio)
-        else:
-            sample = datasets.VoiceSample.from_prompt(prompt)
-            prompt_to_return = ""
-
-        if len(sample.messages) != 1:
-            raise ValueError(
-                f"Expected exactly 1 message in sample but got {len(sample.messages)}"
-            )
-
-        output = inference.infer(
-            sample,
-            max_tokens=max_new_tokens,
-            temperature=temperature,
-        )
+def add_text(chatbot: gr.Chatbot, text: str) -> gr.Chatbot:
+    # We set the prompt to "" in anticipation of the next prompt in text mode.
+    return chatbot + [[text, None]], ""

-        chatbot = chatbot + [(None, output.text)]
-        return chatbot, gr.update(value=prompt_to_return)

-    def process_text(chatbot, prompt, max_new_tokens, temperature):
-        return process_turn(
-            chatbot, prompt, max_new_tokens=max_new_tokens, temperature=temperature
-        )
+def add_audio(chatbot: gr.Chatbot, audio: str) -> gr.Chatbot:
+    # We want to keep the prompt (mixed audio/text instruction) as is in voice mode.
+    return chatbot + [[(audio,), None]]

-    def process_audio(chatbot, prompt, audio, max_new_tokens, temperature):
-        return process_turn(
-            chatbot,
-            prompt,
-            audio=audio,
-            max_new_tokens=max_new_tokens,
-            temperature=temperature,
-        )
-
-    def gradio_reset():
-        inference.update_conversation()
-        return [], "", None

-    with gr.Blocks() as demo:
-        chatbot = gr.Chatbot(scale=10, height=1000)
-
-        with gr.Row():
-            with gr.Column(scale=1):
-                reset = gr.Button("Reset")
-                audio = gr.Audio(
-                    label="🎤",
-                    sources=["microphone"],
-                    type="filepath",
-                    visible=True,
-                )
-            with gr.Column(scale=8):
-                prompt = gr.Textbox(
-                    show_label=False,
-                    lines=5,
-                    placeholder=demo_instruction,
-                    value=args.default_prompt,
-                    container=True,
-                )
-            with gr.Column(scale=1):
-                max_new_tokens = gr.Slider(
-                    minimum=50,
-                    maximum=2000,
-                    value=args.max_new_tokens,
-                    step=10,
-                    interactive=True,
-                    label="max_new_tokens",
-                )
-                temperature = gr.Slider(
-                    minimum=0,
-                    maximum=5.0,
-                    value=args.temperature,
-                    step=0.1,
-                    interactive=True,
-                    label="temperature",
-                )
-
-        prompt.submit(add_text, [chatbot, prompt], [chatbot], queue=False).then(
-            process_text,
-            [chatbot, prompt, max_new_tokens, temperature],
-            [chatbot, prompt],
-            queue=False,
-        )
-        audio.stop_recording(add_audio, [chatbot, audio], [chatbot], queue=False).then(
-            process_audio,
-            [chatbot, prompt, audio, max_new_tokens, temperature],
-            [chatbot, prompt],
-            queue=False,
+def process_turn(
+    chatbot: gr.Chatbot,
+    prompt: str,
+    audio: Optional[str] = None,
+    max_new_tokens: int = 200,
+    temperature: float = 0,
+):
+    if audio:
+        if "<|audio|>" not in prompt:
+            prompt += "<|audio|>"
+        sample = datasets.VoiceSample.from_prompt_and_file(prompt, audio)
+    else:
+        # Note that prompt will be "" here, since we cleared it in add_text.
+        # Instead, we can just get it from the chat history.
+        sample = datasets.VoiceSample.from_prompt(chatbot[-1][0])
+
+    if len(sample.messages) != 1:
+        raise ValueError(
+            f"Expected exactly 1 message in sample but got {len(sample.messages)}"
         )
-        reset.click(gradio_reset, [], [chatbot, prompt, audio], queue=False)
-        demo.load(gradio_reset, [], [chatbot, prompt, audio], queue=False)
-
-    demo.launch(share=True)
+    output = inference.infer_stream(
+        sample,
+        max_tokens=max_new_tokens,
+        temperature=temperature,
+    )
+    chatbot += [[None, ""]]
+    for chunk in output:
+        if isinstance(chunk, infer_base.InferenceChunk):
+            chatbot[-1][1] += chunk.text
+            yield chatbot

+def process_text(chatbot, prompt, max_new_tokens, temperature):
+    yield from process_turn(
+        chatbot, prompt, max_new_tokens=max_new_tokens, temperature=temperature
+    )
+
+
+def process_audio(chatbot, prompt, audio, max_new_tokens, temperature):
+    yield from process_turn(
+        chatbot,
+        prompt,
+        audio=audio,
+        max_new_tokens=max_new_tokens,
+        temperature=temperature,
+    )
+
+
+def gradio_reset():
+    inference.update_conversation()
+    return [], "", None


+with gr.Blocks() as demo:
+    chatbot = gr.Chatbot(scale=10, height=1000)
+
+    with gr.Row():
+        with gr.Column(scale=1):
+            reset = gr.Button("Reset")
+            audio = gr.Audio(
+                label="🎤",
+                sources=["microphone"],
+                type="filepath",
+                visible=True,
+            )
+        with gr.Column(scale=8):
+            prompt = gr.Textbox(
+                show_label=False,
+                lines=5,
+                placeholder=demo_instruction,
+                value=args.default_prompt,
+                container=True,
+            )
+        with gr.Column(scale=1):
+            max_new_tokens = gr.Slider(
+                minimum=50,
+                maximum=2000,
+                value=args.max_new_tokens,
+                step=10,
+                interactive=True,
+                label="max_new_tokens",
+            )
+            temperature = gr.Slider(
+                minimum=0,
+                maximum=5.0,
+                value=args.temperature,
+                step=0.1,
+                interactive=True,
+                label="temperature",
+            )
+    prompt.submit(add_text, [chatbot, prompt], [chatbot, prompt], queue=False).then(
+        process_text,
+        [chatbot, prompt, max_new_tokens, temperature],
+        [chatbot],
+    )
+    audio.stop_recording(add_audio, [chatbot, audio], [chatbot], queue=False).then(
+        process_audio,
+        [chatbot, prompt, audio, max_new_tokens, temperature],
+        [chatbot],
+    )
+    reset.click(gradio_reset, [], [chatbot, prompt, audio], queue=False)
+    demo.load(gradio_reset, [], [chatbot, prompt, audio], queue=False)


-if __name__ == "__main__":
-    main()
+demo.queue()
+demo.launch(share=True)
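Reviewer note: the new handlers stream by being generators. Each `yield chatbot` repaints the full chat history, which is why the demo now calls `demo.queue()`. Below is a standalone sketch of that wiring; it is not this repo's code, and the word-by-word reply is a stand-in for real inference.

# chat_stream_sketch.py -- illustration of streaming into gr.Chatbot
# with generator handlers, mirroring the event wiring above.
import time

import gradio as gr

def add_text(history, text):
    # Append the user turn and clear the textbox, as the demo does.
    return history + [[text, None]], ""

def respond(history):
    history[-1][1] = ""
    for word in "streamed one word at a time".split():
        history[-1][1] += word + " "
        time.sleep(0.1)  # simulate per-token latency
        yield history  # each yield repaints the chatbot in place

with gr.Blocks() as demo:
    chatbot = gr.Chatbot()
    box = gr.Textbox()
    box.submit(add_text, [chatbot, box], [chatbot, box], queue=False).then(
        respond, chatbot, chatbot
    )

demo.queue()
if __name__ == "__main__":
    demo.launch()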