|
1 | 1 | import gc
|
2 |
| -import os |
3 | 2 | from abc import ABC, abstractmethod
|
| 3 | +from pathlib import Path |
4 | 4 | from typing import Literal, Optional, Union, cast
|
5 | 5 |
|
6 | 6 | import datasets
|
|
9 | 9 |
|
10 | 10 | from oumi.core.tokenizers import BaseTokenizer
|
11 | 11 | from oumi.core.types.turn import Conversation
|
| 12 | +from oumi.utils.hf_datasets_utils import is_cached_to_disk_hf_dataset |
12 | 13 | from oumi.utils.logging import logger
|
13 | 14 |
|
14 | 15 |
|
@@ -125,11 +126,17 @@ def _load_data(self) -> pd.DataFrame:
|
125 | 126 | Returns:
|
126 | 127 | dict: The loaded dataset.
|
127 | 128 | """
|
128 |
| - if os.path.exists(self.dataset_name_or_path): |
129 |
| - if self.dataset_name_or_path.endswith(".jsonl"): |
| 129 | + dataset_path = Path(self.dataset_name_or_path) |
| 130 | + if dataset_path.exists(): |
| 131 | + if self.dataset_name_or_path.endswith(".jsonl") and dataset_path.is_file(): |
130 | 132 | result = self._load_jsonl_dataset(self.dataset_name_or_path)
|
131 |
| - elif self.dataset_name_or_path.endswith(".parquet"): |
| 133 | + elif ( |
| 134 | + self.dataset_name_or_path.endswith(".parquet") |
| 135 | + and dataset_path.is_file() |
| 136 | + ): |
132 | 137 | result = self._load_parquet_dataset(self.dataset_name_or_path)
|
| 138 | + elif is_cached_to_disk_hf_dataset(self.dataset_name_or_path): |
| 139 | + result = self._load_dataset_from_disk(self.dataset_name_or_path) |
133 | 140 | else:
|
134 | 141 | raise ValueError(
|
135 | 142 | f"File format not supported for {self.dataset_name_or_path}"
|
@@ -202,6 +209,12 @@ def _load_jsonl_dataset(self, path: str) -> pd.DataFrame:
|
202 | 209 | def _load_parquet_dataset(self, path: str) -> pd.DataFrame:
|
203 | 210 | return pd.read_parquet(path)
|
204 | 211 |
|
| 212 | + def _load_dataset_from_disk(self, path: str) -> pd.DataFrame: |
| 213 | + dataset: datasets.Dataset = datasets.Dataset.load_from_disk(path) |
| 214 | + result = dataset.to_pandas() |
| 215 | + del dataset |
| 216 | + return cast(pd.DataFrame, result) |
| 217 | + |
205 | 218 |
|
206 | 219 | class BaseLMSftDataset(BaseMapDataset, ABC):
|
207 | 220 | """In-memory dataset for SFT data.
|
|
0 commit comments