More validations (#48)
* add heysquad and slue-sqa5 datasets

* multi-ds evaluations

* add spanish and chinese evals

* remove chinese and spanish val sets due to hang

* "Transcribe <|audio|>" to "Transcribe\n<|audio|>"

* _get_messages helper function

* moved context len check to _get_query_prompt
farzadab authored Jul 23, 2024
1 parent 552e7b1 commit c3c8dd1
Showing 10 changed files with 299 additions and 124 deletions.
291 changes: 214 additions & 77 deletions ultravox/data/datasets.py

Large diffs are not rendered by default.

38 changes: 31 additions & 7 deletions ultravox/data/datasets_test.py
@@ -108,7 +108,7 @@ def test_transcribe_dataset():
    sample = next(iter(ds))
    assert isinstance(sample, datasets.VoiceSample)
    assert sample.messages == [
-        {"role": "user", "content": "Transcribe <|audio|>"},
+        {"role": "user", "content": "Transcribe\n<|audio|>"},
        {"role": "assistant", "content": "0"},
    ]
    assert np.array_equal(sample.audio, np.zeros(256))
@@ -119,18 +119,18 @@ def test_transcribe_dataset():
def test_num_prompts():
    ds = FakeTranscribeDataset(5, datasets.VoiceDatasetArgs(num_prompts=3))
    samples = list(ds)
-    assert samples[0].messages[0]["content"] == "Transcribe <|audio|>"
+    assert samples[0].messages[0]["content"] == "Transcribe\n<|audio|>"
    assert (
        samples[1].messages[0]["content"]
        == "Repeat exactly what is written here: <|audio|>"
    )
    assert (
        samples[2].messages[0]["content"]
-        == "Transcribe exactly what is said here <|audio|>"
+        == "Transcribe exactly what is said here\n<|audio|>"
    )
    assert (
        samples[3].messages[0]["content"]
-        == "Transcribe exactly what is said here <|audio|>"
+        == "Transcribe exactly what is said here\n<|audio|>"
    )


@@ -168,14 +168,14 @@ def _create_and_validate_sample(target_dtype: str = "float32"):
    # kHz, with an amplitude of 0.1, and the specified dtype.
    array = _create_sine_wave(target_dtype=target_dtype)
    sample = datasets.VoiceSample.from_prompt_and_raw(
-        "Transcribe <|audio|>", array, 16000
+        "Transcribe\n<|audio|>", array, 16000
    )
    assert sample.sample_rate == 16000
    assert sample.audio is not None, "sample.audio should not be None"
    assert len(sample.audio) == 16000
    assert sample.audio.dtype == np.float32
    assert sample.messages == [
-        {"role": "user", "content": "Transcribe <|audio|>"},
+        {"role": "user", "content": "Transcribe\n<|audio|>"},
    ]
    # Serialize and deserialize the sample.
    json = sample.to_json()
@@ -208,5 +208,29 @@ def test_create_sample__raises_on_unsupported_dtype():
    with pytest.raises(AssertionError):
        array = np.ndarray(shape=(16000,), dtype=np.uint8)
        sample = datasets.VoiceSample.from_prompt_and_raw(
-            "Transcribe <|audio|>", array, 16000
+            "Transcribe\n<|audio|>", array, 16000
        )
+
+
+def test_get_messages():
+    messages = datasets._get_messages("Yo!", "Hi!")
+    assert messages == [
+        {"role": "user", "content": "Yo!"},
+        {"role": "assistant", "content": "Hi!"},
+    ]
+
+    messages = datasets._get_messages(
+        "Yo!", "Hi!", assistant_last=False, sys_prompt="Be nice!"
+    )
+    assert messages == [
+        {"role": "system", "content": "Be nice!"},
+        {"role": "assistant", "content": "Yo!"},
+        {"role": "user", "content": "Hi!"},
+    ]
+
+    messages = datasets._get_messages("A", "B", "C")
+    assert messages == [
+        {"role": "assistant", "content": "A"},
+        {"role": "user", "content": "B"},
+        {"role": "assistant", "content": "C"},
+    ]
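
Note: the _get_messages helper itself is added in ultravox/data/datasets.py, whose diff is not rendered above. As a rough sketch only, inferred from these tests rather than from the actual diff (the name is real, the body and defaults are assumptions), the helper presumably behaves like this:

def _get_messages(*turns, assistant_last=True, sys_prompt=None):
    # Optional system prompt first; then alternate user/assistant roles,
    # chosen so the final turn ends up with the desired role.
    messages = []
    if sys_prompt:
        messages.append({"role": "system", "content": sys_prompt})
    offset = 0 if assistant_last else 1
    for i, turn in enumerate(turns):
        role = "assistant" if (len(turns) - i + offset) % 2 == 1 else "user"
        messages.append({"role": role, "content": turn})
    return messages
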
6 changes: 3 additions & 3 deletions ultravox/inference/infer_test.py
@@ -66,7 +66,7 @@ def test_infer_16kHz(tokenizer, audio_processor):
    inference = FakeInference(tokenizer, audio_processor)
    array = np.ones(16000, dtype=np.float32)
    sample = datasets.VoiceSample.from_prompt_and_raw(
-        "Transcribe <|audio|>", array, 16000
+        "Transcribe\n<|audio|>", array, 16000
    )
    output = inference.infer(sample)
    assert output.input_tokens == 20
@@ -89,7 +89,7 @@ def test_infer_48kHz(tokenizer, audio_processor):
    inference = FakeInference(tokenizer, audio_processor)
    array = np.ones(48000, dtype=np.float32)
    sample = datasets.VoiceSample.from_prompt_and_raw(
-        "Transcribe <|audio|>", array, 48000
+        "Transcribe\n<|audio|>", array, 48000
    )
    output = inference.infer(sample)
    assert output.input_tokens == 20
@@ -112,7 +112,7 @@ def test_infer_16kHz_stream(tokenizer, audio_processor):
    inference = FakeInference(tokenizer, audio_processor)
    array = np.ones(16000, dtype=np.float32)
    sample = datasets.VoiceSample.from_prompt_and_raw(
-        "Transcribe <|audio|>", array, 16000
+        "Transcribe\n<|audio|>", array, 16000
    )
    gen = inference.infer_stream(sample)
    text = ""
20 changes: 9 additions & 11 deletions ultravox/model/data_processing.py
@@ -58,11 +58,12 @@ def _process(self, sample: datasets.VoiceSample) -> Dict[str, Any]:
        )

        # Extract input_ids, attention_mask, and audio_values from the processed inputs
-        input_ids = inputs["input_ids"].squeeze(0)
-        attention_mask = inputs["attention_mask"].squeeze(0)
-        audio_values = inputs["audio_values"].squeeze(0)
-        audio_token_start_idx = inputs["audio_token_start_idx"].squeeze(0)
-        audio_token_len = inputs["audio_token_len"].squeeze(0)
+        input_ids = inputs["input_ids"].squeeze_(0)
+        inputs["attention_mask"].squeeze_(0)
+        if "audio_values" in inputs:
+            inputs["audio_values"].squeeze_(0)
+            inputs["audio_token_start_idx"].squeeze_(0)
+            inputs["audio_token_len"].squeeze_(0)

        # No need to shift the labels as the model does it internally
        labels = input_ids.clone()
@@ -73,7 +74,7 @@ def _process(self, sample: datasets.VoiceSample) -> Dict[str, Any]:
        # One reason is that there's very little randomness in the prompt, so the model would be forced to memorize it.
        #
        # Example (-100 is the ignore index):
-        # Tokens: <user> Transcribe <|audio|> </s> <assistant> Brown fox jumps over the lazy dog </s>
+        # Tokens: <user> Transcribe\n<|audio|> </s> <assistant> Brown fox jumps over the lazy dog </s>
        # Labels: -100 -100 -100 -100 <assistant> Brown fox jumps over the lazy dog </s>
        #
        # Note: The above might look weird because I'm mixing token IDs and text, but that's just for illustration.
@@ -91,10 +92,7 @@ def _process(self, sample: datasets.VoiceSample) -> Dict[str, Any]:
        labels[:input_text_len] = -100

        return {
-            "input_ids": input_ids,
-            "attention_mask": attention_mask,
-            "audio_values": audio_values,
+            **inputs,
+            # input_ids, attention_mask, audio_values, audio_token_start_idx, audio_token_len
            "labels": labels,
-            "audio_token_start_idx": audio_token_start_idx,
-            "audio_token_len": audio_token_len,
        }
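
Note: a toy, self-contained illustration of the label masking described in the comments above (the token IDs and lengths are made up; only the labels[:input_text_len] = -100 step mirrors the module):

import torch

input_ids = torch.tensor([5, 42, 7, 9, 11, 13, 2])  # hypothetical prompt + answer tokens
input_text_len = 3  # length of the tokenized user prompt, audio tokens included

labels = input_ids.clone()
labels[:input_text_len] = -100  # -100 is the ignore index for the cross-entropy loss
# labels -> tensor([-100, -100, -100, 9, 11, 13, 2]);
# the loss is computed only on the assistant's reply.
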
2 changes: 1 addition & 1 deletion ultravox/model/ultravox_processing.py
@@ -152,7 +152,7 @@ def __call__(
            data["audio_token_start_idx"] = [start_idx]

        # Replace the audio placeholder with the audio token.
-        # e.g. "Transcribe <|audio|>" -> "Transcribe </s></s></s></s></s></s></s></s>"
+        # e.g. "Transcribe\n<|audio|>" -> "Transcribe </s></s></s></s></s></s></s></s>"
        # where the number of </s> is the number of audio frames.
        text = text.replace(
            self.audio_placeholder,
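
Note: the replacement described in the comment above boils down to the following sketch (the placeholder string matches the diff; the replacement token and frame count are illustrative stand-ins for the processor's actual values):

audio_placeholder = "<|audio|>"
audio_token = "</s>"  # stand-in for the token whose ID marks an audio frame
audio_token_len = 8   # hypothetical number of audio frames

text = "Transcribe\n<|audio|>"
text = text.replace(audio_placeholder, audio_token * audio_token_len)
# text -> "Transcribe\n</s></s></s></s></s></s></s></s>"
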
4 changes: 2 additions & 2 deletions ultravox/tools/gradio_demo.py
@@ -16,7 +16,7 @@ class DemoConfig:
    # runs/llama2_asr_gigaspeech/checkpoint-1000/
    # wandb://fixie/ultravox/model-llama2_asr_gigaspeech:v0
    model_path: str = "fixie-ai/ultravox"
-    default_prompt: str = "Transcribe <|audio|>"
+    default_prompt: str = "Transcribe\n<|audio|>"


def main():
@@ -32,7 +32,7 @@ def wrapper(text: str, audio: Tuple[int, np.ndarray]) -> str:
        gr.Audio(label="Audio", show_download_button=True),
    ]
    outputs = [gr.Textbox(label="Output")]
-    examples = [["Transcribe <|audio|>", "examples/test16.wav"]]
+    examples = [["Transcribe\n<|audio|>", "examples/test16.wav"]]

    gr.Interface(fn=wrapper, inputs=inputs, outputs=outputs, examples=examples).launch(
        share=True
2 changes: 1 addition & 1 deletion ultravox/tools/infer_tool.py
@@ -22,7 +22,7 @@
# transcription of the audio content and the tool can perform a WER calculation.
# Remember to set the --asr flag when using an ASR input.
DEFAULT_PROMPT = "Listen to <|audio|> and respond to it"
-DEFAULT_ASR_PROMPT = "Transcribe <|audio|>"
+DEFAULT_ASR_PROMPT = "Transcribe\n<|audio|>"


@dataclasses.dataclass
1 change: 1 addition & 0 deletions ultravox/training/config_base.py
@@ -14,6 +14,7 @@
@dataclasses.dataclass
class TrainConfig:
    data_sets: List[str]
+    val_sets: List[str]
    # language model to use
    text_model: str
    # audio encoder model to use
5 changes: 3 additions & 2 deletions ultravox/training/configs/meta_config.yaml
@@ -2,14 +2,15 @@ text_model: "meta-llama/Meta-Llama-3-8B-Instruct"
audio_model: "facebook/wav2vec2-base-960h"

data_sets: ["gigaspeech"]
+val_sets: ["heysquad_human", "anyinstruct", "soda", "peoplespeech"]
repeat_data: True

train_on_inputs: False
shuffle_data: True
max_audio_duration_secs: 16

-val_num_samples: 128
-val_steps: 500
+val_num_samples: 64
+val_steps: 1000
eval_num_samples: 256
eval_max_new_tokens: 32
eval_num_procs: 16
54 changes: 34 additions & 20 deletions ultravox/training/train.py
@@ -1,11 +1,12 @@
+import copy
import dataclasses
import glob
import logging
import os
import re
import sys
from datetime import datetime
-from typing import List, Optional
+from typing import Dict, List, Optional

import datasets as hf_datasets
import safetensors.torch
@@ -28,7 +29,7 @@
from ultravox.training import ddp_utils
from ultravox.training import evaluation

-INPUT_EXAMPLE = {"text": "Transcribe <|audio|>", "audio": b"\x00\x00" * 16000}
+INPUT_EXAMPLE = {"text": "Transcribe\n<|audio|>", "audio": b"\x00\x00" * 16000}
OUTPUT_EXAMPLE = {"text": "Hello, world!"}


@@ -160,11 +161,18 @@ def main() -> None:
        f"Using dtype and device (world_size): {dtype}, {device} ({world_size})"
    )
    model.to(device=device, dtype=dtype)
+    # TODO: check if the whole model can now be moved to dtype instead

    # Prepare dataset, subsetting if needed
    train_dataset: data.IterableDataset
-    val_dataset: data.IterableDataset
+    val_datasets: Dict[str, data.IterableDataset]
+    # We use multiple validation sets here so that the results stay comparable even when the training set changes.
+    # To make sure we can compare training and validation loss (e.g. for fine-tuning), we keep a special set
+    # called "matchtrain" that uses the same data as the training set.
+    val_sets = dict(
+        [("matchtrain", args.data_sets)]
+        + [(x, [x]) for x in args.val_sets]
+        + [(f"text_{x}", [x]) for x in args.val_sets]
+    )
    if is_master:
        train_dataset = prepare_dataset(
            dataset_names=args.data_sets,
@@ -182,29 +190,35 @@
                mds_batch_size=args.batch_size,
            ),
        )
-        val_dataset = prepare_dataset(
-            dataset_names=args.data_sets,
-            train_on_inputs=args.train_on_inputs,
-            repeat_data=args.repeat_data,
-            processor=processor,
-            num_samples=args.val_num_samples,
-            data_args=datasets.VoiceDatasetArgs(
-                num_prompts=1,
-                data_dir=args.data_dir,
-                shuffle=False,
-                max_audio_duration_secs=16,
-                use_mds=args.mds,
-                mds_batch_size=args.batch_size,
-            ),
+        val_ds_args = datasets.VoiceDatasetArgs(
+            num_prompts=1,
+            data_dir=args.data_dir,
+            shuffle=False,
+            max_audio_duration_secs=16,
+            use_mds=args.mds,
+            mds_batch_size=args.batch_size,
        )
+        val_ds_args_text = copy.copy(val_ds_args)
+        val_ds_args_text.include_audio = False
+        val_datasets = {
+            k: prepare_dataset(
+                dataset_names=val_sets[k],
+                train_on_inputs=args.train_on_inputs,
+                repeat_data=args.repeat_data,
+                processor=processor,
+                num_samples=args.val_num_samples,
+                data_args=val_ds_args_text if k.startswith("text_") else val_ds_args,
+            )
+            for k in val_sets
+        }
        logging.info(
            f"Loaded {args.data_sets} data sets, sample limit: {args.num_samples} (val sample limit: {args.val_num_samples})"
        )
    else:
        # When using DDP with split_batches=True, the primary process will distribute the batches to the workers
        # The point of this is to avoid unnecessary data processing/downloading in the workers.
        train_dataset = datasets.EmptyDataset()
-        val_dataset = datasets.EmptyDataset()
+        val_datasets = {k: datasets.EmptyDataset() for k in val_sets}

    # Set up the data loader
    data_collator = datasets.DataCollatorForSeq2SeqWithAudio(tokenizer=text_tokenizer)
@@ -214,7 +228,7 @@ def main() -> None:
    trainer = transformers.Seq2SeqTrainer(
        model,
        train_dataset=train_dataset,
-        eval_dataset=val_dataset,
+        eval_dataset=val_datasets,
        data_collator=data_collator,
        tokenizer=text_tokenizer,
        args=transformers.Seq2SeqTrainingArguments(
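
Note: with the values from meta_config.yaml above, the val_sets mapping built in train.py expands to something like the sketch below. The "matchtrain" entry reuses the training data so training and validation loss stay directly comparable, and each "text_*" entry evaluates the same set with include_audio=False as a text-only baseline over identical samples:

data_sets = ["gigaspeech"]
val_set_names = ["heysquad_human", "anyinstruct", "soda", "peoplespeech"]

val_sets = dict(
    [("matchtrain", data_sets)]
    + [(x, [x]) for x in val_set_names]
    + [(f"text_{x}", [x]) for x in val_set_names]
)
# {'matchtrain': ['gigaspeech'],
#  'heysquad_human': ['heysquad_human'], ..., 'peoplespeech': ['peoplespeech'],
#  'text_heysquad_human': ['heysquad_human'], ..., 'text_peoplespeech': ['peoplespeech']}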
