diff --git a/ultravox/data/datasets.py b/ultravox/data/datasets.py index 91a34fa3..8e0dc021 100644 --- a/ultravox/data/datasets.py +++ b/ultravox/data/datasets.py @@ -29,8 +29,8 @@ TRANSCRIBE_PROMPTS = [ # from Gazelle - "Transcribe <|audio|>", - "Transcribe exactly what is said here <|audio|>", + "Transcribe\n<|audio|>", + "Transcribe exactly what is said here\n<|audio|>", "Repeat exactly what is written here: <|audio|>", "Write exactly what was said: <|audio|>", "First listen to the clip. Then, transcribe exactly what is said. <|audio|>", @@ -70,13 +70,16 @@ @dataclasses.dataclass class DataCollatorForSeq2SeqWithAudio(transformers.DataCollatorForSeq2Seq): def __call__(self, features, *args, **kwargs): - audio_features = [f.pop("audio_values") for f in features] + audio_values = [f.pop("audio_values", None) for f in features] batch = super().__call__(features, *args, **kwargs) + # Pad the last dimension of all audio_values to the same length, with 0s on the right. - max_len = max([x.shape[-1] for x in audio_features]) - batch["audio_values"] = torch.stack( - [F.pad(x, (0, max_len - x.shape[-1])) for x in audio_features] - ) + if audio_values and audio_values[0] is not None: + max_len = max([x.shape[-1] for x in audio_values]) + batch["audio_values"] = torch.stack( + [F.pad(x, (0, max_len - x.shape[-1])) for x in audio_values] + ) + return batch @@ -194,6 +197,8 @@ class VoiceDatasetArgs: """Whether to include audio in the samples.""" include_context: bool = True """Whether to include additional textual context from the dataset to the prompt.""" + max_context_length: int = 1500 + """Maximum length of context to include in the prompt. Otherwise, skip the sample.""" shuffle: bool = False """Whether to shuffle the dataset.""" shuffle_seed: int = 42 @@ -212,6 +217,30 @@ def __post_init__(self): self.split = DatasetSplit(self.split.lower()) +def _get_messages( + *turns: str, sys_prompt: Optional[str] = None, assistant_last: bool = True +) -> List[Dict[str, str]]: + """ + Convert a list of strings into a list of messages, alternating between user and assistant. + If `sys_prompt` is set, it is prepended as a system message. + If `assistant_last` is True, the assistant's message is the last one. + """ + messages = [] + + if sys_prompt: + messages.append({"role": "system", "content": sys_prompt}) + + roles = ["user", "assistant"] + + # Make sure the last turn is the assistant's iff assistant_last is True. + if (len(turns) + assistant_last) % 2 == 0: + roles = roles[::-1] + + messages += [{"role": roles[i % 2], "content": c} for i, c in enumerate(turns)] + + return messages + + class VoiceDataset(abc.ABC, data.IterableDataset): """ Base class for streaming voice datasets. @@ -273,16 +302,21 @@ def _load_audio_dataset( def __iter__(self): for _, row in enumerate(self._dataset): sample = self._get_sample(row) - if ( - self._args.max_audio_duration_secs is None - or sample.audio.shape[-1] / SAMPLE_RATE - <= self._args.max_audio_duration_secs - ): - yield sample + if sample is not None: + if ( + self._args.max_audio_duration_secs is None + or sample.audio is None + or sample.audio.shape[-1] / SAMPLE_RATE + <= self._args.max_audio_duration_secs + ): + yield sample @abc.abstractmethod - def _get_sample(self, row: transformers.BatchFeature) -> VoiceSample: - pass + def _get_sample(self, row: transformers.BatchFeature) -> Optional[VoiceSample]: + """ + Converts a row from the dataset into a VoiceSample. + Returns None if the sample should be skipped. 
+ """ def _choice(self, prompts: List[str]) -> str: return self._rng.choice(prompts[: self._args.num_prompts]) @@ -302,16 +336,13 @@ def _get_answer_messages( ) -> List[Dict[str, str]]: prompt = self._get_answer_prompt() if self._args.include_audio else question user_content = f"{context}\n\n{prompt}" if context else prompt - return [ - {"role": "user", "content": user_content}, - {"role": "assistant", "content": answer}, - ] + return _get_messages(user_content, answer) def _get_transcribe_messages(self, text: str) -> List[Dict[str, str]]: - return [ - {"role": "user", "content": self._get_transcribe_prompt()}, - {"role": "assistant", "content": text}, - ] + prompt = self._get_transcribe_prompt() + if not self._args.include_audio: + prompt = prompt.replace("<|audio|>", text) + return _get_messages(prompt, text) def _get_audio( self, row: transformers.BatchFeature, column_name: str = "audio" @@ -476,7 +507,7 @@ def __init__(self, args: VoiceDatasetArgs) -> None: ) self._init_dataset(dataset) - def _get_sample(self, row: transformers.BatchFeature) -> VoiceSample: + def _get_sample(self, row: transformers.BatchFeature) -> Optional[VoiceSample]: question = row["question"] answer = "True" if row["answer"] else "False" context = row["passage"] if self._args.include_context else None @@ -492,23 +523,10 @@ def _get_sample(self, row: transformers.BatchFeature) -> VoiceSample: return self._get_transcribe_sample(row, tcol="question") -class BoolQWithExtendedAnswerDataset(BoolQDataset): +class QAVoiceDatasetMixin(VoiceDataset): SEPARATORS = ["\n\n", "\n", "\n----\n"] - BOOLQ_PASSAGE_PROMPTS = [ - "Provide a short explanation, then respond with True/False on the last line", - "Explain briefly, concluding with True/False on a new line." - "Write a quick explanation, and finish with True/False on the last line" - "Summarize in a few words, and end with True/False on a new line." - "Give a brief explanation first, then answer with True/False on the final line", - "Start with a concise explanation, and end with a True/False response on the last line.", - "Explain briefly and follow up with True/False at the end", - "Write a short explanation, then state True/False on a new line.", - "First, offer a brief explanation, and then reply with True/False at the end.", - "Present a concise explanation, ending with a True/False answer on the final line", - "Start with a brief explanation, and then answer with True/False at the end.", - ] - QUERY_PROMPTS = ["Question: ", "Question:\n", "Q: ", "Q:\n", "Query: ", "Query:\n"] - CONTEXT_PROMPTS = [ + QUERY_PREFIX = ["Question: ", "Question:\n", "Q: ", "Q:\n", "Query: ", "Query:\n"] + CONTEXT_PREFIX = [ "Passage: ", "Passage:\n", "Context: ", @@ -516,7 +534,7 @@ class BoolQWithExtendedAnswerDataset(BoolQDataset): "Background: ", "Background:\n", ] - ANSWER_PROMPTS = [ + ANSWER_PREFIX = [ "Answer: ", "A: ", "", @@ -524,57 +542,175 @@ class BoolQWithExtendedAnswerDataset(BoolQDataset): "Result: ", "Conclusion: ", ] + # In most cases there is no extra prompt-suffix needed + PROMPT_SUFFIXES = [""] - def _get_query_prompt(self) -> str: + # TODO: combine `_get_query_prompt` and `_get_answer_messages` into a single method + # and use this mixin for all non-ASR datasets. + def _get_query_prompt(self, question_str: str, context: str) -> Optional[str]: """ - Creates a random prompt for a BoolQ sample with a passage and question. - Example prompt: - <|user|> Passage: {context} + Creates a random prompt for a QA sample with a passage and question. 
+ Example prompt: + Passage: {context} Question: {question} - - Provide a short explanation, then respond with True/False on the last line. - - <|assistant|> {short_explanation} - Answer: {answer} + {optional-prompt-suffix} """ + if len(context) > self._args.max_context_length: + # Skip samples with long context + return None + if self._args.prompt: - return self._args.prompt - prompt = self._choice(self.BOOLQ_PASSAGE_PROMPTS) + prompt = self._args.prompt + else: + prompt = self._choice(self.PROMPT_SUFFIXES) # Separate either with 1 or 2 newlines separator = self._choice(self.SEPARATORS) - query_prompt = self._choice(self.QUERY_PROMPTS) - prompt = f"{query_prompt}{{question}}{separator}{prompt}" + query_prompt = self._choice(self.QUERY_PREFIX) + question = "<|audio|>" if self._args.include_audio else question_str + prompt = f"{query_prompt}{question}{separator}{prompt}" if self._args.include_context: - context_prompt = self._choice(self.CONTEXT_PROMPTS) - prompt = f"{context_prompt}{{context}}{separator}{prompt}" + context_prompt = self._choice(self.CONTEXT_PREFIX) + prompt = f"{context_prompt}{context}{separator}{prompt}" - return prompt + return prompt.strip() - def _get_sample(self, row: transformers.BatchFeature) -> VoiceSample: + +class BoolQWithExtendedAnswerDataset(BoolQDataset, QAVoiceDatasetMixin): + """ + A version of BoolQ that includes the context in the prompt and a longer explanation in the answer. + """ + + PROMPT_SUFFIXES = [ + "Provide a short explanation, then respond with True/False on the last line", + "Explain briefly, concluding with True/False on a new line." + "Write a quick explanation, and finish with True/False on the last line" + "Summarize in a few words, and end with True/False on a new line." + "Give a brief explanation first, then answer with True/False on the final line", + "Start with a concise explanation, and end with a True/False response on the last line.", + "Explain briefly and follow up with True/False at the end", + "Write a short explanation, then state True/False on a new line.", + "First, offer a brief explanation, and then reply with True/False at the end.", + "Present a concise explanation, ending with a True/False answer on the final line", + "Start with a brief explanation, and then answer with True/False at the end.", + ] + + def _get_sample(self, row: transformers.BatchFeature) -> Optional[VoiceSample]: + """ + Example conversation: + <|user|> Passage: {context} + Question: {question} + Provide a short explanation, then respond with True/False on the last line + <|assistant|> {short_explanation} + Answer: {answer} + """ answer = "True" if row["answer"] else "False" - answer_prompt = self._choice(self.ANSWER_PROMPTS) - query_prompt = self._get_query_prompt() - user_content = query_prompt.format( - question="<|audio|>" if self._args.include_audio else row["question"], - context=row["passage"], + answer_prompt = self._choice(self.ANSWER_PREFIX) + user_message = self._get_query_prompt( + question_str=row["question"], context=row["passage"] + ) + if user_message is None: + # Skips samples with long context + return None + + messages = _get_messages( + user_message, f"{row['explanation']}\n{answer_prompt}{answer}" + ) + + return self._make_sample( + messages, self._get_audio(row), audio_transcript=row["question"] + ) + + +class HeySQuADHumanDataset(QAVoiceDatasetMixin): + """ + HeySQuAD is a large-scale Spoken Question Answering (SQA) dataset which includes 76k human-spoken questions, + 97k machine-generated questions, and their corresponding 
textual answers from the SQuAD QA dataset. + https://arxiv.org/abs/2304.13689 + + This dataset is the human-spoken version of HeySQuAD. + """ + + def __init__(self, args: VoiceDatasetArgs) -> None: + super().__init__(args) + dataset = self._load_audio_dataset( + "yijingwu/HeySQuAD_human", split=args.split.value + ) + self._init_dataset(dataset) + + def _get_sample(self, row: transformers.BatchFeature) -> Optional[VoiceSample]: + """ + Example conversation + <|user|> Context: {context} + Question: {question} + <|assistant|> {answer} + """ + if row["is_impossible"] or not row["answers"]: + # Skip samples with no answer + return None + + prompt = self._get_query_prompt( + question_str=row["question"], context=row["context"] ) - messages = [ - {"role": "user", "content": user_content}, - { - "role": "assistant", - "content": f"{row['explanation']}\n{answer_prompt}{answer}", - }, - ] + if prompt is None: + # Skips samples with long context + return None + messages = _get_messages(prompt, row["answers"][0]["text"]) return self._make_sample( messages, self._get_audio(row), audio_transcript=row["question"] ) +class SlueSQA5Dataset(QAVoiceDatasetMixin): + """ + SLUE-SQA-5 Dataset contains question texts, question audio, answer text, document text, and document audio from these datasets: + * SQuAD1.1 (for questions whose question_id starts with 'squad-') + * Natural Questions (for questions whose question_id starts with 'nq-') + * TriviaQA (for questions whose question_id starts with 'triviaqa-') + The following datasets are supposed to be included, but I haven't found them everywhere: + * WebQuestions (for questions whose question_id starts with 'wq-') + * CuratedTREC (for questions whose question_id starts with 'trec-') + * Spoken Wikipedia + + + Splits: train, validation, test, verified_test + """ + + BASE_AUDIO_COLUMNS = ["question_audio", "document_audio"] + + def __init__(self, args: VoiceDatasetArgs) -> None: + super().__init__(args) + dataset = self._load_audio_dataset( + "asapp/slue-phase-2", "sqa5", split=args.split.value + ) + self._init_dataset(dataset) + + def _get_sample(self, row: transformers.BatchFeature) -> Optional[VoiceSample]: + """ + Example conversation + <|user|> Context: {context} + Question: {question} + <|assistant|> {answer} + """ + prompt = self._get_query_prompt( + question_str=row["raw_question_text"], context=row["raw_document_text"] + ) + if prompt is None: + # Skips samples with long context + return None + + messages = _get_messages(prompt, row["answer_spans"]["answer"][0]) + return self._make_sample( + messages, + self._get_audio(row, "question_audio"), + audio_transcript=row["raw_question_text"], + ) + + class LibriSpeechDataset(VoiceDataset): """ LibriSpeech is a corpus of approximately 1000 hours of 16kHz read @@ -658,10 +794,10 @@ class CommonVoiceDataset(VoiceDataset): NOTE: requires HF login """ - def __init__(self, args: VoiceDatasetArgs) -> None: + def __init__(self, args: VoiceDatasetArgs, lang: str = "en") -> None: super().__init__(args) dataset = self._load_audio_dataset( - "mozilla-foundation/common_voice_16_1", "en", split=args.split.value + "mozilla-foundation/common_voice_16_1", lang, split=args.split.value ) self._init_dataset(dataset) @@ -719,10 +855,8 @@ def _get_sample(self, row) -> VoiceSample: sys_prompt = self._choice(self.SYS_PROMPTS) - messages = [{"role": "system", "content": sys_prompt}] - messages += [ - {"role": roles[i % 2], "content": turn} for i, turn in enumerate(turns) - ] + messages = _get_messages(*turns[:-1], 
sys_prompt=sys_prompt) + messages[-1]["content"] = row["alt_last_turn"] if self._args.include_audio: messages[-2]["content"] = "<|audio|>" @@ -742,6 +876,8 @@ def create_dataset(name: str, args: VoiceDatasetArgs) -> data.IterableDataset: "boolq": BoolQDataset, "boolq_in": BoolQInputDataset, "boolq_extended": BoolQWithExtendedAnswerDataset, + "heysquad_human": HeySQuADHumanDataset, + "slue_sqa5": SlueSQA5Dataset, "gigaspeech": GigaSpeechDataset, "librispeech": LibriSpeechDataset, "voxpopuli": VoxPopuliDataset, @@ -750,7 +886,8 @@ def create_dataset(name: str, args: VoiceDatasetArgs) -> data.IterableDataset: "soda": SodaDataset, "dummy": LibriSpeechDummyDataset, } - return DATASET_MAP[name](args) + name, *ext = name.split(":") + return DATASET_MAP[name](args, *ext) class InterleaveDataset(data.IterableDataset): diff --git a/ultravox/data/datasets_test.py b/ultravox/data/datasets_test.py index bc9110a6..65a28d76 100644 --- a/ultravox/data/datasets_test.py +++ b/ultravox/data/datasets_test.py @@ -108,7 +108,7 @@ def test_transcribe_dataset(): sample = next(iter(ds)) assert isinstance(sample, datasets.VoiceSample) assert sample.messages == [ - {"role": "user", "content": "Transcribe <|audio|>"}, + {"role": "user", "content": "Transcribe\n<|audio|>"}, {"role": "assistant", "content": "0"}, ] assert np.array_equal(sample.audio, np.zeros(256)) @@ -119,18 +119,18 @@ def test_transcribe_dataset(): def test_num_prompts(): ds = FakeTranscribeDataset(5, datasets.VoiceDatasetArgs(num_prompts=3)) samples = list(ds) - assert samples[0].messages[0]["content"] == "Transcribe <|audio|>" + assert samples[0].messages[0]["content"] == "Transcribe\n<|audio|>" assert ( samples[1].messages[0]["content"] == "Repeat exactly what is written here: <|audio|>" ) assert ( samples[2].messages[0]["content"] - == "Transcribe exactly what is said here <|audio|>" + == "Transcribe exactly what is said here\n<|audio|>" ) assert ( samples[3].messages[0]["content"] - == "Transcribe exactly what is said here <|audio|>" + == "Transcribe exactly what is said here\n<|audio|>" ) @@ -168,14 +168,14 @@ def _create_and_validate_sample(target_dtype: str = "float32"): # kHz, with an amplitude of 0.1, and the specified dtype. array = _create_sine_wave(target_dtype=target_dtype) sample = datasets.VoiceSample.from_prompt_and_raw( - "Transcribe <|audio|>", array, 16000 + "Transcribe\n<|audio|>", array, 16000 ) assert sample.sample_rate == 16000 assert sample.audio is not None, "sample.audio should not be None" assert len(sample.audio) == 16000 assert sample.audio.dtype == np.float32 assert sample.messages == [ - {"role": "user", "content": "Transcribe <|audio|>"}, + {"role": "user", "content": "Transcribe\n<|audio|>"}, ] # Serialize and deserialize the sample. json = sample.to_json() @@ -208,5 +208,29 @@ def test_create_sample__raises_on_unsupported_dtype(): with pytest.raises(AssertionError): array = np.ndarray(shape=(16000,), dtype=np.uint8) sample = datasets.VoiceSample.from_prompt_and_raw( - "Transcribe <|audio|>", array, 16000 + "Transcribe\n<|audio|>", array, 16000 ) + + +def test_get_messages(): + messages = datasets._get_messages("Yo!", "Hi!") + assert messages == [ + {"role": "user", "content": "Yo!"}, + {"role": "assistant", "content": "Hi!"}, + ] + + messages = datasets._get_messages( + "Yo!", "Hi!", assistant_last=False, sys_prompt="Be nice!" 
+ ) + assert messages == [ + {"role": "system", "content": "Be nice!"}, + {"role": "assistant", "content": "Yo!"}, + {"role": "user", "content": "Hi!"}, + ] + + messages = datasets._get_messages("A", "B", "C") + assert messages == [ + {"role": "assistant", "content": "A"}, + {"role": "user", "content": "B"}, + {"role": "assistant", "content": "C"}, + ] diff --git a/ultravox/inference/infer_test.py b/ultravox/inference/infer_test.py index 12c84e09..c597f42a 100644 --- a/ultravox/inference/infer_test.py +++ b/ultravox/inference/infer_test.py @@ -66,7 +66,7 @@ def test_infer_16kHz(tokenizer, audio_processor): inference = FakeInference(tokenizer, audio_processor) array = np.ones(16000, dtype=np.float32) sample = datasets.VoiceSample.from_prompt_and_raw( - "Transcribe <|audio|>", array, 16000 + "Transcribe\n<|audio|>", array, 16000 ) output = inference.infer(sample) assert output.input_tokens == 20 @@ -89,7 +89,7 @@ def test_infer_48kHz(tokenizer, audio_processor): inference = FakeInference(tokenizer, audio_processor) array = np.ones(48000, dtype=np.float32) sample = datasets.VoiceSample.from_prompt_and_raw( - "Transcribe <|audio|>", array, 48000 + "Transcribe\n<|audio|>", array, 48000 ) output = inference.infer(sample) assert output.input_tokens == 20 @@ -112,7 +112,7 @@ def test_infer_16kHz_stream(tokenizer, audio_processor): inference = FakeInference(tokenizer, audio_processor) array = np.ones(16000, dtype=np.float32) sample = datasets.VoiceSample.from_prompt_and_raw( - "Transcribe <|audio|>", array, 16000 + "Transcribe\n<|audio|>", array, 16000 ) gen = inference.infer_stream(sample) text = "" diff --git a/ultravox/model/data_processing.py b/ultravox/model/data_processing.py index f2293412..640fe5e0 100644 --- a/ultravox/model/data_processing.py +++ b/ultravox/model/data_processing.py @@ -58,11 +58,12 @@ def _process(self, sample: datasets.VoiceSample) -> Dict[str, Any]: ) # Extract input_ids, attention_mask, and audio_values from the processed inputs - input_ids = inputs["input_ids"].squeeze(0) - attention_mask = inputs["attention_mask"].squeeze(0) - audio_values = inputs["audio_values"].squeeze(0) - audio_token_start_idx = inputs["audio_token_start_idx"].squeeze(0) - audio_token_len = inputs["audio_token_len"].squeeze(0) + input_ids = inputs["input_ids"].squeeze_(0) + inputs["attention_mask"].squeeze_(0) + if "audio_values" in inputs: + inputs["audio_values"].squeeze_(0) + inputs["audio_token_start_idx"].squeeze_(0) + inputs["audio_token_len"].squeeze_(0) # No need to shift the labels as the model does it internally labels = input_ids.clone() @@ -73,7 +74,7 @@ def _process(self, sample: datasets.VoiceSample) -> Dict[str, Any]: # One reason is that there's very little randomness in the prompt, so the model would be forced to memorize it. # # Example (-100 is the ignore index): - # Tokens: Transcribe <|audio|> Brown fox jumps over the lazy dog + # Tokens: Transcribe\n<|audio|> Brown fox jumps over the lazy dog # Labels: -100 -100 -100 -100 Brown fox jumps over the lazy dog # # Note: The above might look weird because I'm mixing token IDs and text, but that's just for illustration. 
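# A minimal sketch of the label masking described in the comment above, with made-up token IDs;
# the real _process() applies the same slice assignment, with input_text_len covering the prompt
# portion of input_ids.
import torch

input_ids = torch.tensor([11, 12, 13, 14, 5, 6, 7, 8])  # user-prompt tokens followed by answer tokens
input_text_len = 4                                       # number of tokens belonging to the user prompt
labels = input_ids.clone()
labels[:input_text_len] = -100                           # -100 is the ignore index: no loss on the prompt
# labels -> tensor([-100, -100, -100, -100, 5, 6, 7, 8])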
@@ -91,10 +92,7 @@ def _process(self, sample: datasets.VoiceSample) -> Dict[str, Any]: labels[:input_text_len] = -100 return { - "input_ids": input_ids, - "attention_mask": attention_mask, - "audio_values": audio_values, + **inputs, + # input_ids, attention_mask, audio_values, audio_token_start_idx, audio_token_len "labels": labels, - "audio_token_start_idx": audio_token_start_idx, - "audio_token_len": audio_token_len, } diff --git a/ultravox/model/ultravox_processing.py b/ultravox/model/ultravox_processing.py index 5b755fa2..20b95611 100644 --- a/ultravox/model/ultravox_processing.py +++ b/ultravox/model/ultravox_processing.py @@ -152,7 +152,7 @@ def __call__( data["audio_token_start_idx"] = [start_idx] # Replace the audio placeholder with the audio token. - # e.g. "Transcribe <|audio|>" -> "Transcribe " + # e.g. "Transcribe\n<|audio|>" -> "Transcribe " # where the number of is the number of audio frames. text = text.replace( self.audio_placeholder, diff --git a/ultravox/tools/gradio_demo.py b/ultravox/tools/gradio_demo.py index 51e5d138..cca7de3e 100644 --- a/ultravox/tools/gradio_demo.py +++ b/ultravox/tools/gradio_demo.py @@ -16,7 +16,7 @@ class DemoConfig: # runs/llama2_asr_gigaspeech/checkpoint-1000/ # wandb://fixie/ultravox/model-llama2_asr_gigaspeech:v0 model_path: str = "fixie-ai/ultravox" - default_prompt: str = "Transcribe <|audio|>" + default_prompt: str = "Transcribe\n<|audio|>" def main(): @@ -32,7 +32,7 @@ def wrapper(text: str, audio: Tuple[int, np.ndarray]) -> str: gr.Audio(label="Audio", show_download_button=True), ] outputs = [gr.Textbox(label="Output")] - examples = [["Transcribe <|audio|>", "examples/test16.wav"]] + examples = [["Transcribe\n<|audio|>", "examples/test16.wav"]] gr.Interface(fn=wrapper, inputs=inputs, outputs=outputs, examples=examples).launch( share=True diff --git a/ultravox/tools/infer_tool.py b/ultravox/tools/infer_tool.py index 44dc8ed8..8e1dfc05 100644 --- a/ultravox/tools/infer_tool.py +++ b/ultravox/tools/infer_tool.py @@ -22,7 +22,7 @@ # transcription of the audio content and the tool can perfom a WER calculation. # Remember to set the --asr flag when using an ASR input. 
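# A minimal usage sketch of the ASR prompt defined just below, mirroring the
# VoiceSample.from_prompt_and_raw calls in the tests above; the import path and audio values
# are assumptions for illustration.
import numpy as np

from ultravox.data import datasets

audio = np.zeros(16000, dtype=np.float32)  # one second of silence at 16 kHz
sample = datasets.VoiceSample.from_prompt_and_raw("Transcribe\n<|audio|>", audio, 16000)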
DEFAULT_PROMPT = "Listen to <|audio|> and respond to it" -DEFAULT_ASR_PROMPT = "Transcribe <|audio|>" +DEFAULT_ASR_PROMPT = "Transcribe\n<|audio|>" @dataclasses.dataclass diff --git a/ultravox/training/config_base.py b/ultravox/training/config_base.py index 29fe8dbc..86a20ded 100644 --- a/ultravox/training/config_base.py +++ b/ultravox/training/config_base.py @@ -14,6 +14,7 @@ @dataclasses.dataclass class TrainConfig: data_sets: List[str] + val_sets: List[str] # language model to use text_model: str # audio encoder model to use diff --git a/ultravox/training/configs/meta_config.yaml b/ultravox/training/configs/meta_config.yaml index 08cc3224..d199c74a 100644 --- a/ultravox/training/configs/meta_config.yaml +++ b/ultravox/training/configs/meta_config.yaml @@ -2,14 +2,15 @@ text_model: "meta-llama/Meta-Llama-3-8B-Instruct" audio_model: "facebook/wav2vec2-base-960h" data_sets: ["gigaspeech"] +val_sets: ["heysquad_human", "anyinstruct", "soda", "peoplespeech"] repeat_data: True train_on_inputs: False shuffle_data: True max_audio_duration_secs: 16 -val_num_samples: 128 -val_steps: 500 +val_num_samples: 64 +val_steps: 1000 eval_num_samples: 256 eval_max_new_tokens: 32 eval_num_procs: 16 diff --git a/ultravox/training/train.py b/ultravox/training/train.py index 5bdc39d0..0d342a6b 100644 --- a/ultravox/training/train.py +++ b/ultravox/training/train.py @@ -1,3 +1,4 @@ +import copy import dataclasses import glob import logging @@ -5,7 +6,7 @@ import re import sys from datetime import datetime -from typing import List, Optional +from typing import Dict, List, Optional import datasets as hf_datasets import safetensors.torch @@ -28,7 +29,7 @@ from ultravox.training import ddp_utils from ultravox.training import evaluation -INPUT_EXAMPLE = {"text": "Transcribe <|audio|>", "audio": b"\x00\x00" * 16000} +INPUT_EXAMPLE = {"text": "Transcribe\n<|audio|>", "audio": b"\x00\x00" * 16000} OUTPUT_EXAMPLE = {"text": "Hello, world!"} @@ -160,11 +161,18 @@ def main() -> None: f"Using dtype and device (world_size): {dtype}, {device} ({world_size})" ) model.to(device=device, dtype=dtype) - # TODO: check if the whole model can now be moved to dtype instead # Prepare dataset, subsetting if needed train_dataset: data.IterableDataset - val_dataset: data.IterableDataset + val_datasets: Dict[str, data.IterableDataset] + # We use multiple validation sets here so that the results are comparable even when training set changes + # To make sure we can compare training and validation loss (e.g. for fine-tuning), we keep a special set + # called "matchtrain" that uses the same data as the training set. 
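# A minimal sketch of the mapping built by the statement below, assuming hypothetical values
# args.data_sets == ["gigaspeech"] and args.val_sets == ["heysquad_human", "soda"]:
data_sets = ["gigaspeech"]
eval_sets = ["heysquad_human", "soda"]
val_sets_example = dict(
    [("matchtrain", data_sets)]
    + [(x, [x]) for x in eval_sets]
    + [(f"text_{x}", [x]) for x in eval_sets]
)
# val_sets_example == {
#     "matchtrain": ["gigaspeech"],
#     "heysquad_human": ["heysquad_human"],
#     "soda": ["soda"],
#     "text_heysquad_human": ["heysquad_human"],
#     "text_soda": ["soda"],
# }
# The "text_" entries are evaluated without audio (include_audio=False via val_ds_args_text below).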
+ val_sets = dict( + [("matchtrain", args.data_sets)] + + [(x, [x]) for x in args.val_sets] + + [(f"text_{x}", [x]) for x in args.val_sets] + ) if is_master: train_dataset = prepare_dataset( dataset_names=args.data_sets, @@ -182,21 +190,27 @@ def main() -> None: mds_batch_size=args.batch_size, ), ) - val_dataset = prepare_dataset( - dataset_names=args.data_sets, - train_on_inputs=args.train_on_inputs, - repeat_data=args.repeat_data, - processor=processor, - num_samples=args.val_num_samples, - data_args=datasets.VoiceDatasetArgs( - num_prompts=1, - data_dir=args.data_dir, - shuffle=False, - max_audio_duration_secs=16, - use_mds=args.mds, - mds_batch_size=args.batch_size, - ), + val_ds_args = datasets.VoiceDatasetArgs( + num_prompts=1, + data_dir=args.data_dir, + shuffle=False, + max_audio_duration_secs=16, + use_mds=args.mds, + mds_batch_size=args.batch_size, ) + val_ds_args_text = copy.copy(val_ds_args) + val_ds_args_text.include_audio = False + val_datasets = { + k: prepare_dataset( + dataset_names=val_sets[k], + train_on_inputs=args.train_on_inputs, + repeat_data=args.repeat_data, + processor=processor, + num_samples=args.val_num_samples, + data_args=val_ds_args_text if k.startswith("text_") else val_ds_args, + ) + for k in val_sets + } logging.info( f"Loaded {args.data_sets} data sets, sample limit: {args.num_samples} (val sample limit: {args.val_num_samples})" ) @@ -204,7 +218,7 @@ def main() -> None: # When using DDP with split_batches=True, the primary process will distribute the batches to the workers # The point of this is to avoid unnecessary data processing/downloading in the workers. train_dataset = datasets.EmptyDataset() - val_dataset = datasets.EmptyDataset() + val_datasets = {k: datasets.EmptyDataset() for k in val_sets} # Set up the data loader data_collator = datasets.DataCollatorForSeq2SeqWithAudio(tokenizer=text_tokenizer) @@ -214,7 +228,7 @@ def main() -> None: trainer = transformers.Seq2SeqTrainer( model, train_dataset=train_dataset, - eval_dataset=val_dataset, + eval_dataset=val_datasets, data_collator=data_collator, tokenizer=text_tokenizer, args=transformers.Seq2SeqTrainingArguments(