KEP-2401: Kubeflow LLM Trainer V2 #2410

Open · wants to merge 10 commits into master
doc: add data preprocess section in design chapter.
Signed-off-by: Electronic-Waste <2690692950@qq.com>
Electronic-Waste committed Feb 1, 2025
commit c2f7307c71a2a4a399d62984a6d67d1bddeecf86
39 changes: 37 additions & 2 deletions docs/proposals/2401-llm-trainer-v2/README.md
@@ -67,8 +67,7 @@ TrainingClient().train(
     trainer=Trainer(
         fine_tuning_config=FineTuningConfig(
             backend="huggingface",
-            launch="torchrun",
-            dataset_class="InstructionDataset",
+            dataset_class="Instruction",
             peft_config=LoraConfig(r=4),
             sharding_config=FsdpConfig(...),
             kwargs={},
@@ -149,6 +148,42 @@ def fine_tune(model_name, dataset, backend, **kwargs):

```

### Data Preprocess

Datasets differ widely in their keys and intended usage. For example, instruction datasets (e.g. [tatsu-lab/alpaca](https://huggingface.co/datasets/tatsu-lab/alpaca)) typically include keys like `instruction`, `input`, `output` and `text`, while question answering datasets (e.g. [openai/gsm8k](https://huggingface.co/datasets/openai/gsm8k)) contain columns like `question` and `answer`. It is impossible to implement a single, unified dataset class that suits every dataset on HuggingFace. **Different types of tasks need different implementations** so that each can preprocess its data in a task-specific way, as the illustrative records below show.
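
The two records below are purely illustrative (the field values are invented; only the key names mirror the real datasets) and show why one preprocessing path cannot serve both formats:

```python
# Illustrative samples only: values are invented, key names mirror the real datasets.
alpaca_sample = {  # instruction dataset (tatsu-lab/alpaca)
    "instruction": "Summarize the following text.",
    "input": "Kubeflow Trainer runs distributed training jobs on Kubernetes.",
    "output": "Kubeflow Trainer manages distributed training on Kubernetes.",
    "text": "Below is an instruction that describes a task ...",
}

gsm8k_sample = {  # question answering dataset (openai/gsm8k)
    "question": "A store sells 3 apples for $2. How much do 12 apples cost?",
    "answer": "12 apples are 4 sets of 3, so the cost is 4 * $2 = $8.\n#### 8",
}
```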

For the reasons above, we decided to **provide multiple built-in dataset classes** for data preprocessing. They can be used directly by specifying the `dataset_class` parameter in the Python SDK (e.g. `dataset_class="instruction"`). Meanwhile, we also **allow users to define customized dataset classes that implement the required methods** and pass them to the `dataset_class` parameter in the Python SDK (e.g. `dataset_class=CustomDatasetClass`).

```python
from abc import ABC, abstractmethod

from torch.utils.data import Dataset

# Registry mapping a dataset class name (e.g. "instruction") to its implementation.
DATASET_REGISTRY = {}

def register_dataset(name):
    def decorator(cls):
        DATASET_REGISTRY[name] = cls
        return cls
    return decorator

# Abstract Dataset Class: fixes the constructor signature every dataset class must implement.
class InitMethod(ABC):
    @abstractmethod
    def __init__(self, dataset_config, tokenizer, partition="train"):
        raise NotImplementedError()

@register_dataset("instruction")
class InstructionDataset(Dataset, InitMethod):
    def __init__(self, dataset_config, tokenizer, partition="train"):
        # Load the requested split and tokenize it according to dataset_config.
        ...

    def __len__(self):
        # Return the number of preprocessed samples.
        ...

    def __getitem__(self, index):
        # Return the tokenized sample at the given index.
        ...
```
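
For illustration, the sketch below shows how the trainer could resolve the `dataset_class` value passed from the SDK. The `resolve_dataset_class` helper is hypothetical (it is not part of this proposal) and builds on `DATASET_REGISTRY` and `InitMethod` from the sketch above: built-in classes are looked up by their registered name, while customized classes are accepted as long as they implement the required methods.

```python
# Hypothetical resolution helper; assumes DATASET_REGISTRY and InitMethod defined above.
def resolve_dataset_class(dataset_class):
    if isinstance(dataset_class, str):
        # Built-in dataset classes are referenced by name, e.g. dataset_class="instruction".
        return DATASET_REGISTRY[dataset_class.lower()]
    if isinstance(dataset_class, type) and issubclass(dataset_class, InitMethod):
        # Customized dataset classes are passed directly, e.g. dataset_class=CustomDatasetClass.
        return dataset_class
    raise ValueError(f"Unsupported dataset_class: {dataset_class!r}")

# Built-in usage: resolves to the InstructionDataset registered above.
instruction_cls = resolve_dataset_class("instruction")
```

Because both built-in and customized classes follow the `InitMethod` constructor signature, the trainer can instantiate either one the same way, e.g. `dataset_cls(dataset_config, tokenizer, partition="train")`.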

## Implementation History

- 2025-01-31: Create KEP-2401 doc