# [QEff. Finetune]: Added Base dataset class and SFT dataset classes along with their test cases. #647
Base branch: `ft_experimental`
## Changes from all commits
`@@ -4,3 +4,263 @@`

```python
# SPDX-License-Identifier: BSD-3-Clause
#
# -----------------------------------------------------------------------------

"""
Dataset components for the training system.
"""

import importlib
import re
from typing import Any, Callable, Dict

from datasets import load_dataset, load_dataset_builder
from torch.utils.data import Dataset

from QEfficient.finetune.experimental.core.component_registry import registry


class BaseDataset(Dataset):
    """Base class for all datasets to ensure consistent interface."""

    def __init__(self, dataset_name: str, split: str, seed: int = 42, **kwargs):
        self.dataset_name = dataset_name
        self.split = split
        self.seed = seed
        self.kwargs = kwargs
        self._initialize_dataset()

    def _initialize_dataset(self):
        """Subclasses should implement this to load and prepare the dataset."""
        raise NotImplementedError

    @property
    def hf_dataset(self):
        """Return the underlying Hugging Face dataset object."""
        return self.dataset
```

> **Contributor:** We can skip this and directly use `dataset_instance.dataset`. By the way, what extra does `hf_dataset` bring? If it brings some more metadata information, let's keep it; otherwise remove it.
```python
    def __len__(self):
        return len(self.dataset)
```

> **Contributor:** Mark this as `@abstractmethod`; same for `__getitem__`.
```python
    def __getitem__(self, idx):
        """Should return a dictionary with 'input_ids', 'attention_mask', and 'labels'."""
        raise NotImplementedError


@registry.dataset("sft_dataset")
class SFTDataset(BaseDataset):
    """
    A Supervised Fine-Tuning (SFT) dataset class for text data.

    This class handles loading data from Hugging Face datasets or custom JSON files,
    filtering out invalid samples, and applying prompt/completion templating for SFT tasks.

    Args:
        dataset_name (str): The name of the dataset to load from Hugging Face datasets.
            Ignored if json_file_path is provided.
        split (str): The dataset split to use (e.g., "train", "validation", "test").
        split_ratio (float): Ratio for the train/test split when only one split is available.
        seed (int): Random seed for reproducibility.
        json_file_path (str, optional): Path to a custom JSON file containing the dataset.
            If provided, this takes precedence over dataset_name.
        prompt_template (str): A string template for constructing the prompt. Variables in the
            template should be enclosed in curly braces, e.g., "Answer the question: {question}".
        completion_template (str): A string template for constructing the completion (target).
            Variables should be enclosed in curly braces, e.g., "{answer}".

    Raises:
        RuntimeError: If any variables specified in `prompt_template` or `completion_template`
            are not found as columns in the loaded dataset.
    """

    def __init__(
        self,
        dataset_name: str,
        split: str,
        split_ratio: float = 0.8,
        seed: int = 42,
        **kwargs,
    ):
        self.split_ratio = split_ratio
        self.json_file_path = kwargs.get("json_file_path", None)
        self.prompt_template = kwargs.get("prompt_template", None)
        self.completion_template = kwargs.get("completion_template", None)
        self.prompt_func_path = kwargs.get("prompt_func", None)
        self.completion_func_path = kwargs.get("completion_func", None)
        self.remove_samples_with_empty_columns = kwargs.get("remove_samples_with_empty_columns", True)

        if (self.prompt_template is None and self.prompt_func_path is None) or (
            self.prompt_template is not None and self.prompt_func_path is not None
        ):
            raise RuntimeError("Provide exactly one of prompt_template or prompt_func in the config.")
        if (self.completion_template is None and self.completion_func_path is None) or (
            self.completion_template is not None and self.completion_func_path is not None
        ):
            raise RuntimeError("Provide exactly one of completion_template or completion_func in the config.")

        # Call parent class __init__, which will call _initialize_dataset
        super().__init__(dataset_name, split, seed, **kwargs)
```

> **Contributor:** Good cleanup in `__init__`.
```python
    def _initialize_dataset(self):
        """
        Initialize the dataset from either Hugging Face or a custom JSON file.

        This method loads the dataset, applies splitting if necessary, and prepares
        it for preprocessing with prompt/completion templates.
        """
        if self.json_file_path:
            # Load the dataset from a JSON file
            self.dataset = load_dataset("json", data_files=self.json_file_path, split="train")

            # Apply a train/test split if needed
            if self.split in ["train", "test"]:
                splitted_dataset = self.dataset.train_test_split(
                    test_size=(1 - self.split_ratio),
                    seed=self.seed,
                )
                if self.split == "test":
                    self.dataset = splitted_dataset["test"]
                else:
                    self.dataset = splitted_dataset["train"]
        else:
            # Load the dataset from Hugging Face
            db = load_dataset_builder(self.dataset_name)
```

> **Contributor:** This is a good addition over `load_dataset`.
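For context on the comment above, `load_dataset_builder` fetches only dataset metadata, so split names can be inspected without downloading the data (a small sketch; the dataset id is a placeholder):

```python
from datasets import load_dataset_builder

db = load_dataset_builder("some_org/qa_dataset")  # placeholder dataset id
# db.info.splits maps split names to SplitInfo; it can be None for some builders
available = list(db.info.splits.keys()) if db.info.splits else []
print(available)  # e.g. ["train", "test"]
```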
```python
            available_splits = []
            if db.info.splits is not None:
                available_splits = list(db.info.splits.keys())

            if self.split not in available_splits and self.split == "train":
                raise ValueError(f"Split {self.split} is not available for dataset {self.dataset_name}.")
```

> **Contributor:** Should we simplify this to `if self.split not in available_splits:`? (Refers to L130 to L135.) In the reference code I had added this for some reason, but it is difficult to read and the reasoning behind it is hard to interpret.
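One possible consolidation of the split handling questioned above (a sketch, not code from the PR; it assumes that falling back to `"train"` and re-splitting is the intended behavior for missing splits, and it would live inside `_initialize_dataset`):

```python
# Sketch: make the fallback explicit instead of special-casing "train".
if self.split in available_splits:
    load_split = self.split
elif "train" in available_splits:
    # Fall back to "train"; the requested split is carved out later
    # via train_test_split.
    load_split = "train"
else:
    raise ValueError(f"Split {self.split} is not available for dataset {self.dataset_name}.")
```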
```python
            load_split = self.split
            if self.split not in available_splits:
                load_split = "train"

            # FIXME: Add streaming support for larger datasets.
            self.dataset = load_dataset(self.dataset_name, split=load_split)
            if len(available_splits) == 1:
                splitted_dataset = self.dataset.train_test_split(
                    test_size=(1 - self.split_ratio),
                    seed=self.seed,
                )
                if self.split == "test":
                    self.dataset = splitted_dataset["test"]
                else:
                    self.dataset = splitted_dataset["train"]
```

> **Contributor:** L140 to L147 is the same as L115 to L122.
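The duplication flagged above could be factored into a small helper (a sketch; the name `_apply_train_test_split` is an assumption, not part of the PR):

```python
    def _apply_train_test_split(self, dataset):
        """Return the side of a fresh train/test split matching self.split (sketch)."""
        splitted = dataset.train_test_split(test_size=(1 - self.split_ratio), seed=self.seed)
        return splitted["test"] if self.split == "test" else splitted["train"]
```

Both call sites would then collapse to `self.dataset = self._apply_train_test_split(self.dataset)`.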
```python
        self.dataset_columns = self.dataset.column_names
        self._setup_templates_and_preprocessing()

    def _setup_templates_and_preprocessing(self):
        """
        Set up prompt/completion templates or functions and apply preprocessing.
        """
        if self.prompt_template:
            self.prompt_func = None
            # Extract variables from the template and check that they exist in the dataset columns
            prompt_variables = re.findall(r"\{(.*?)\}", self.prompt_template)
            for var in prompt_variables:
                if var not in self.dataset_columns:
                    raise RuntimeError(
                        f"Prompt template variable '{var}' not found in dataset columns: {self.dataset_columns}."
                    )
        else:
            prompt_variables = self.dataset_columns
            self.prompt_func = self.import_func(self.prompt_func_path)

        if self.completion_template:
            self.completion_func = None
            # Extract variables from the template and check that they exist in the dataset columns
            completion_variables = re.findall(r"\{(.*?)\}", self.completion_template)
            for var in completion_variables:
                if var not in self.dataset_columns:
                    raise RuntimeError(
                        f"Completion template variable '{var}' not found in dataset columns: {self.dataset_columns}."
                    )
        else:
            completion_variables = self.dataset_columns
            self.completion_func = self.import_func(self.completion_func_path)

        # Filter out samples with None or empty strings in relevant columns
        # Only filter columns that are actually used in the templates
```

> **Contributor:** Make this a single-line comment.
```python
        self.relevant_columns = list(set(prompt_variables + completion_variables))
```

> **Contributor:** Can we pass this to `self._filter_empty_or_none_samples` rather than making it a `self` variable?
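A sketch of the suggested alternative, forwarding the columns through `functools.partial` instead of storing them on `self` (the `partial` approach and the extra parameter are assumptions, not part of the PR):

```python
from functools import partial

# _filter_empty_or_none_samples would then take the columns as a parameter:
# def _filter_empty_or_none_samples(self, example, relevant_columns): ...
relevant_columns = list(set(prompt_variables + completion_variables))
if self.remove_samples_with_empty_columns:
    self.dataset = self.dataset.filter(
        partial(self._filter_empty_or_none_samples, relevant_columns=relevant_columns)
    )
```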
```python
        if self.remove_samples_with_empty_columns:
            self.dataset = self.dataset.filter(self._filter_empty_or_none_samples)
        self.dataset = self.dataset.map(self._preprocess_sample)
```

> **Contributor:** Should we preprocess beforehand? We could do it on the fly in the `__getitem__` function instead, right?
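As posted, `__getitem__` (further down) already re-applies `_preprocess_sample` per access, so the eager `.map` adds 'prompt'/'completion' columns that are then recomputed anyway. One way to act on the comment above would simply be to drop the `.map` call (a sketch, not the PR's code):

```python
        if self.remove_samples_with_empty_columns:
            self.dataset = self.dataset.filter(self._filter_empty_or_none_samples)
        # No eager .map: templating happens lazily in __getitem__ via _preprocess_sample.
```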
```python
    def import_func(self, func_path: str) -> Callable:
        if ":" not in func_path:
            raise ValueError("func_path must be in the format 'module_file_path:function_name'.")
        module_file_path, function_name = func_path.split(":")

        try:
            module = importlib.import_module(module_file_path)
        except Exception:
            raise RuntimeError(f"Unable to import module: {module_file_path}.")
        if not hasattr(module, function_name):
            raise ValueError(f"Function {function_name} not found in module {module_file_path}.")
        return getattr(module, function_name)
```
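A quick sketch of how `import_func` resolves a callable (the module and function names here are hypothetical). Note that despite the `module_file_path` name, `importlib.import_module` expects a dotted module name importable from `sys.path`, not a literal file path:

```python
# my_prompts.py (hypothetical module on the Python path):
# def build_prompt(example):
#     return f"Question: {example['question']}"

# 'dataset' is an existing SFTDataset instance.
fn = dataset.import_func("my_prompts:build_prompt")
print(fn({"question": "What is 2 + 2?"}))  # "Question: What is 2 + 2?"
```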
```python
    def _filter_empty_or_none_samples(self, example: Dict[str, Any]) -> bool:
        """
        Filters out samples where any of the relevant columns are None or contain only whitespace.

        Args:
            example (Dict[str, Any]): A single sample from the dataset.

        Returns:
            bool: True if the sample should be kept, False otherwise.
        """
        for column in self.relevant_columns:
            value = example.get(column)
            if value is None or (isinstance(value, str) and not value.strip()):
                return False
        return True

    def _preprocess_sample(self, example: Dict[str, Any]) -> Dict[str, str]:
        """
        Applies the prompt and completion templates to a single example.

        Args:
            example (Dict[str, Any]): A single sample from the dataset.

        Returns:
            Dict[str, str]: A dictionary containing the 'prompt' and 'completion' strings.
        """
        prompt_text = (
            self.prompt_func(example) if self.prompt_func is not None else self.prompt_template.format(**example)
        )
        completion_text = (
            self.completion_func(example)
            if self.completion_func is not None
            else self.completion_template.format(**example)
        )
        return {
            "prompt": prompt_text,
            "completion": completion_text,
        }

    def __len__(self) -> int:
        """
        Returns the number of samples in the dataset.

        Returns:
            int: The total number of samples.
        """
        return self.dataset.num_rows

    def __getitem__(self, idx: int) -> Dict[str, str]:
        """
        Retrieves a processed sample from the dataset at the given index.

        Args:
            idx (int): The index of the sample to retrieve.

        Returns:
            Dict[str, str]: A dictionary containing the processed 'prompt' and 'completion' for the sample.
        """
        # Get the raw example using .select and access the first element
        example = self.dataset.select(indices=[int(idx)])[0]

        # Apply preprocessing (templating) on the fly
        processed_example = self._preprocess_sample(example)

        return processed_example
```
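For reference, a minimal usage sketch of `SFTDataset` as posted (the dataset name and column names below are illustrative assumptions, not part of the PR):

```python
# Template-based usage; assumes an HF dataset exposing "question" and "answer" columns.
ds = SFTDataset(
    dataset_name="some_org/qa_dataset",  # hypothetical dataset id
    split="train",
    split_ratio=0.8,
    prompt_template="Answer the question: {question}",
    completion_template="{answer}",
)
print(len(ds))
print(ds[0])  # {"prompt": "Answer the question: ...", "completion": "..."}

# Function-based usage; "my_prompts:build_prompt" is a hypothetical
# 'module:function' path resolved by import_func.
ds_fn = SFTDataset(
    dataset_name="some_org/qa_dataset",
    split="train",
    prompt_func="my_prompts:build_prompt",
    completion_func="my_prompts:build_completion",
)
```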
> **Contributor:** Make the `BaseDataset` class inherit from `ABC` so that it becomes an abstract base class:
>
> ```python
> from abc import ABC, abstractmethod
>
> class BaseDataset(Dataset, ABC):
> ```
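Combining this with the earlier `@abstractmethod` comment, the base class might end up looking like this (a sketch of the suggested direction, not code from the PR):

```python
from abc import ABC, abstractmethod

from torch.utils.data import Dataset


class BaseDataset(Dataset, ABC):
    """Base class for all datasets to ensure a consistent interface."""

    def __init__(self, dataset_name: str, split: str, seed: int = 42, **kwargs):
        self.dataset_name = dataset_name
        self.split = split
        self.seed = seed
        self.kwargs = kwargs
        self._initialize_dataset()

    @abstractmethod
    def _initialize_dataset(self):
        """Load and prepare the dataset."""

    @abstractmethod
    def __len__(self):
        """Return the number of samples in the dataset."""

    @abstractmethod
    def __getitem__(self, idx):
        """Return a dict with 'input_ids', 'attention_mask', and 'labels'."""
```

With `ABC` in place, instantiating `BaseDataset` directly, or a subclass that misses one of these methods, raises `TypeError` at construction time instead of `NotImplementedError` at call time.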