Live reload for Gradio demo (#89)
* Live reload for Gradio demo

Also adds streaming of output tokens.
juberti authored Aug 20, 2024
1 parent b2dc7f1 commit cb488b9
Showing 5 changed files with 206 additions and 168 deletions.
2 changes: 1 addition & 1 deletion Justfile
@@ -55,7 +55,7 @@ mds *FLAGS:
    poetry run python -m ultravox.tools.mds_tool {{FLAGS}}

gradio *FLAGS:
    poetry run python -m ultravox.tools.gradio_demo {{FLAGS}}
    poetry run gradio ultravox/tools/gradio_demo.py {{FLAGS}}

run *FLAGS:
    poetry run mcli run -f mcloud.yaml --follow {{FLAGS}}
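Running the demo through the gradio CLI instead of plain python is what enables live reload: the CLI watches the script, re-executes it on save, and serves the module-level demo object (named "demo" by default). A minimal reload-compatible sketch is shown below; it is hypothetical and not the actual ultravox/tools/gradio_demo.py, whose diff is not shown on this page.

# gradio_demo_sketch.py -- hypothetical, minimal reload-compatible script.
# `gradio gradio_demo_sketch.py` re-runs this file on every save and serves
# the module-level object named "demo" (the default name the reload server
# looks for).
import gradio as gr

def echo(message: str) -> str:
    # Placeholder handler; the real demo would call the ultravox inference stack.
    return message

demo = gr.Interface(fn=echo, inputs="text", outputs="text")

if __name__ == "__main__":
    # Plain `python gradio_demo_sketch.py` still works, just without reload.
    demo.launch()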
119 changes: 68 additions & 51 deletions ultravox/inference/infer.py
@@ -1,5 +1,6 @@
import copy
import threading
from concurrent import futures
from typing import Dict, List, Optional, Tuple, Union

import librosa
@@ -13,8 +14,6 @@

SAMPLE_RATE = 16000
MAX_NEW_TOKENS = 1024
# Without this penalty, the model tends to repeat itself.
REPETITION_PENALTY = 1.1


class LocalInference(base.VoiceInference):
@@ -59,6 +58,49 @@ def _get_sample_with_past(
        sample.add_past_messages(self.past_messages)
        return sample

    def _build_past_messages(
        self,
        query_messages: List[Dict[str, str]],
        audio_token_len: int,
        response_content: str,
    ) -> List[Dict[str, str]]:
        messages = copy.copy(query_messages)
        if audio_token_len > 0:
            user_content = messages[-1]["content"]
            if user_content.count("<|audio|>") != 1:
                raise ValueError(
                    f"Expected 1 audio placeholder, found {user_content.count('<|audio|>')}"
                )
            messages[-1]["content"] = user_content.replace(
                "<|audio|>", self.tokenizer.eos_token * audio_token_len
            )
        messages.append({"role": "assistant", "content": response_content})
        return messages
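Illustrative usage sketch (hypothetical values): the helper pads the <|audio|> placeholder to the audio's token length before appending the assistant reply, presumably so the stored text history lines up with the cached key/value positions that the audio tokens occupied.

# Hypothetical sketch; `inference` is a constructed LocalInference and the
# values below are assumptions (eos_token == "</s>", a 3-token audio clip).
query = [{"role": "user", "content": "Listen to this: <|audio|>"}]
history = inference._build_past_messages(query, audio_token_len=3, response_content="Got it.")
# history[-2]["content"] == "Listen to this: </s></s></s>"
# history[-1] == {"role": "assistant", "content": "Got it."}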

    def infer(
        self,
        sample: datasets.VoiceSample,
        max_tokens: Optional[int] = None,
        temperature: Optional[float] = None,
    ) -> base.VoiceOutput:
        extended_sample = self._get_sample_with_past(sample)
        inputs = self._dataproc(extended_sample)
        input_len = inputs["input_ids"].shape[1]
        output = self._generate(
            inputs, max_tokens, temperature, past_key_values=self.past_key_values
        )
        output_tokens = output.sequences[0][input_len:]
        output_text = self.tokenizer.decode(output_tokens, skip_special_tokens=True)
        output_len = len(output_tokens)
        if self.conversation_mode:
            audio_token_len = inputs.get("audio_token_len", [0])[0]
            past_messages = self._build_past_messages(
                extended_sample.messages, audio_token_len, output_text
            )
            self.update_conversation(past_messages, output.past_key_values)
        return base.VoiceOutput(output_text, input_len, output_len)

    # Note: infer_batch doesn't support conversation mode or caching yet.
    def infer_batch(
        self,
        samples: List[datasets.VoiceSample],
@@ -85,70 +127,46 @@ def infer_batch(
            output_texts.append(output_text)
        return output_texts

    def infer(
        self,
        sample: datasets.VoiceSample,
        max_tokens: Optional[int] = None,
        temperature: Optional[float] = None,
    ) -> base.VoiceOutput:
        extended_sample = self._get_sample_with_past(sample)
        inputs = self._dataproc(extended_sample)
        input_len = inputs["input_ids"].shape[1]
        output = self._generate(
            inputs, max_tokens, temperature, past_key_values=self.past_key_values
        )
        output_tokens = output.sequences[0][input_len:]
        output_text = self.tokenizer.decode(output_tokens, skip_special_tokens=True)
        output_len = len(output_tokens)

        if self.conversation_mode:
            past_messages = copy.deepcopy(extended_sample.messages)
            audio_token_len = (
                0 if "audio_token_len" not in inputs else inputs["audio_token_len"][0]
            )
            if audio_token_len > 0:
                user_content = past_messages[-1]["content"]
                if user_content.count("<|audio|>") != 1:
                    raise ValueError(
                        f"Expected 1 audio placeholder, found {user_content.count('<|audio|>')}"
                    )
                past_messages[-1]["content"] = user_content.replace(
                    "<|audio|>", self.tokenizer.eos_token * audio_token_len
                )
            past_messages.append({"role": "assistant", "content": output_text})
            self.update_conversation(past_messages, output.past_key_values)

        return base.VoiceOutput(output_text, input_len, output_len)

    # streaming is not supported in conversation mode yet, to be implemented
    def infer_stream(
        self,
        sample: datasets.VoiceSample,
        max_tokens: Optional[int] = None,
        temperature: Optional[float] = None,
    ) -> base.InferenceGenerator:
        inputs = self._dataproc(sample)
        extended_sample = self._get_sample_with_past(sample)
        inputs = self._dataproc(extended_sample)
        input_tokens = inputs["input_ids"].shape[1]
        decode_kwargs = {"skip_special_tokens": True}
        streamer = transformers.TextIteratorStreamer(
            self.tokenizer, skip_prompt=True, decode_kwargs=decode_kwargs
            self.tokenizer, skip_prompt=True, skip_special_tokens=True
        )

        thread_args = (
            inputs,
            max_tokens,
            temperature,
            streamer,
        def thunk(f: futures.Future):
            result = self._generate(
                inputs, max_tokens, temperature, streamer, self.past_key_values
            )
            f.set_result(result)

        future: futures.Future[transformers.GenerateDecoderOnlyOutput] = (
            futures.Future()
        )
        thread = threading.Thread(target=self._generate, args=thread_args)
        thread = threading.Thread(target=thunk, args=(future,))
        thread.start()
        output_tokens = 0
        output_text = ""
        output_token_len = 0
        for chunk in streamer:
            if chunk:
                output_text += chunk
                output_token_len += 1
                yield base.InferenceChunk(chunk)
                output_tokens += 1
        yield base.InferenceStats(input_tokens, output_tokens)
        thread.join()
        output = future.result()
        if self.conversation_mode:
            audio_token_len = inputs.get("audio_token_len", [0])[0]
            past_messages = self._build_past_messages(
                extended_sample.messages, audio_token_len, output_text
            )
            self.update_conversation(past_messages, output.past_key_values)
        yield base.InferenceStats(input_tokens, output_token_len)
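A hedged consumption sketch of the streaming API: the generator yields one InferenceChunk per decoded text piece and ends with a single InferenceStats. The caller code and the field names used below (.text, .input_tokens, .output_tokens) are assumptions, not taken from base.py.

# Hypothetical caller; `inference` is a constructed LocalInference and
# `sample` a datasets.VoiceSample.
from ultravox.inference import base

text = ""
for msg in inference.infer_stream(sample, max_tokens=256):
    if isinstance(msg, base.InferenceChunk):
        text += msg.text  # streamed piece of the reply (assumed field name)
    elif isinstance(msg, base.InferenceStats):
        print(f"prompt tokens: {msg.input_tokens}, generated: {msg.output_tokens}")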

    def _dataproc(self, sample: datasets.VoiceSample):
        text_input = self.tokenizer.apply_chat_template(
@@ -208,7 +226,6 @@ def _generate(
            do_sample=do_sample,
            max_new_tokens=max_new_tokens or MAX_NEW_TOKENS,
            temperature=temperature,
            repetition_penalty=REPETITION_PENALTY,
            pad_token_id=self.tokenizer.eos_token_id,
            eos_token_id=terminators,
            streamer=streamer,
4 changes: 1 addition & 3 deletions ultravox/inference/utils.py
@@ -5,9 +5,7 @@ def default_device():
    return (
        "cuda"
        if torch.cuda.is_available()
        # until https://github.com/pytorch/pytorch/issues/77764 is resolved
        # else "mps" if torch.backends.mps.is_available() else "cpu"
        else "cpu"
        else "mps" if torch.backends.mps.is_available() else "cpu"
    )
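This change re-enables the "mps" fallback on Apple Silicon, since the previously linked PyTorch issue was presumably resolved. A small hedged usage sketch, assuming nothing beyond what the diff shows:

# Hypothetical usage; default_device() now resolves to "cuda", "mps", or "cpu".
import torch
from ultravox.inference import utils

device = utils.default_device()
x = torch.zeros(4, device=device)  # allocate on whichever backend was picked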


[Diffs for the remaining 2 changed files did not load.]
