Skip to content

Commit b05cb72

Browse files
authored
Merge branch 'main' into jgreer013/async_inference_writes
2 parents 283895a + 09ab5ab commit b05cb72

File tree

5 files changed

+90
-13
lines changed

5 files changed

+90
-13
lines changed

src/oumi/builders/data.py

+14-8
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,6 @@
11
import copy
22
import warnings
3+
from pathlib import Path
34
from typing import Callable, List, Optional, Sequence, TypeVar, Union, cast
45

56
import datasets
@@ -23,6 +24,7 @@
2324
)
2425
from oumi.datasets.trl_dpo_preprocessor import trl_dpo_chat_preprocessor_fn
2526
from oumi.datasets.ultrachat_200k import trl_sft_ultrachat_200k_preprocessor_fn
27+
from oumi.utils.hf_datasets_utils import is_cached_to_disk_hf_dataset
2628
from oumi.utils.logging import logger
2729

2830
DatasetType = TypeVar("DatasetType", datasets.Dataset, datasets.IterableDataset)
@@ -368,11 +370,15 @@ def _load_dataset(
368370
)
369371
return dataset.to_hf()
370372

371-
return datasets.load_dataset(
372-
dataset_params.dataset_name,
373-
name=dataset_params.subset,
374-
split=dataset_params.split,
375-
streaming=stream,
376-
trust_remote_code=dataset_params.trust_remote_code,
377-
**dataset_params.dataset_kwargs,
378-
)
373+
dataset_name_or_path: Path = Path(dataset_params.dataset_name)
374+
if is_cached_to_disk_hf_dataset(dataset_name_or_path):
375+
return datasets.Dataset.load_from_disk(dataset_name_or_path)
376+
else:
377+
return datasets.load_dataset(
378+
dataset_params.dataset_name,
379+
name=dataset_params.subset,
380+
split=dataset_params.split,
381+
streaming=stream,
382+
trust_remote_code=dataset_params.trust_remote_code,
383+
**dataset_params.dataset_kwargs,
384+
)

src/oumi/core/datasets/base_dataset.py

+17-4
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,6 @@
11
import gc
2-
import os
32
from abc import ABC, abstractmethod
3+
from pathlib import Path
44
from typing import Literal, Optional, Union, cast
55

66
import datasets
@@ -9,6 +9,7 @@
99

1010
from oumi.core.tokenizers import BaseTokenizer
1111
from oumi.core.types.turn import Conversation
12+
from oumi.utils.hf_datasets_utils import is_cached_to_disk_hf_dataset
1213
from oumi.utils.logging import logger
1314

1415

@@ -125,11 +126,17 @@ def _load_data(self) -> pd.DataFrame:
125126
Returns:
126127
dict: The loaded dataset.
127128
"""
128-
if os.path.exists(self.dataset_name_or_path):
129-
if self.dataset_name_or_path.endswith(".jsonl"):
129+
dataset_path = Path(self.dataset_name_or_path)
130+
if dataset_path.exists():
131+
if self.dataset_name_or_path.endswith(".jsonl") and dataset_path.is_file():
130132
result = self._load_jsonl_dataset(self.dataset_name_or_path)
131-
elif self.dataset_name_or_path.endswith(".parquet"):
133+
elif (
134+
self.dataset_name_or_path.endswith(".parquet")
135+
and dataset_path.is_file()
136+
):
132137
result = self._load_parquet_dataset(self.dataset_name_or_path)
138+
elif is_cached_to_disk_hf_dataset(self.dataset_name_or_path):
139+
result = self._load_dataset_from_disk(self.dataset_name_or_path)
133140
else:
134141
raise ValueError(
135142
f"File format not supported for {self.dataset_name_or_path}"
@@ -202,6 +209,12 @@ def _load_jsonl_dataset(self, path: str) -> pd.DataFrame:
202209
def _load_parquet_dataset(self, path: str) -> pd.DataFrame:
203210
return pd.read_parquet(path)
204211

212+
def _load_dataset_from_disk(self, path: str) -> pd.DataFrame:
213+
dataset: datasets.Dataset = datasets.Dataset.load_from_disk(path)
214+
result = dataset.to_pandas()
215+
del dataset
216+
return cast(pd.DataFrame, result)
217+
205218

206219
class BaseLMSftDataset(BaseMapDataset, ABC):
207220
"""In-memory dataset for SFT data.

src/oumi/utils/hf_datasets_utils.py

+31
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,31 @@
1+
from pathlib import Path
2+
from typing import Union
3+
4+
from oumi.utils.logging import logger
5+
6+
7+
def is_cached_to_disk_hf_dataset(dataset_name_or_path: Union[str, Path]) -> bool:
8+
"""Detects whether a dataset was saved using `dataset.save_to_disk()`.
9+
10+
Such datasets should be loaded using `datasets.Dataset.load_from_disk()`.
11+
12+
Returns:
13+
Whether the dataset was saved using `dataset.save_to_disk()` method.
14+
"""
15+
if not dataset_name_or_path:
16+
return False
17+
18+
dataset_path: Path = Path(dataset_name_or_path)
19+
20+
if dataset_path.exists() and dataset_path.is_dir():
21+
for file_name in ("dataset_info.json", "state.json"):
22+
file_path: Path = dataset_path / file_name
23+
if not (file_path.exists() and file_path.is_file()):
24+
logger.warning(
25+
f"The dataset {str(dataset_path)} is missing "
26+
f"a required file: {file_name}."
27+
)
28+
return False
29+
return True
30+
31+
return False

tests/utils/test_debugging_utils.py

+1-1
Original file line number | Diff line number | Diff line change
@@ -22,7 +22,7 @@ def test_nvidia_gpu_memory_utilization():
2222
if num_devices > 0:
2323
for device_index in range(0, num_devices):
2424
memory_mib = get_nvidia_gpu_memory_utilization(device_index)
25-
assert memory_mib > 1024 # Must have at least 1 GB
25+
assert memory_mib > 1 # Must have at least 1 MB
2626
assert memory_mib < 1024 * 1024 # No known GPU has 1 TB of VRAM yet.
2727
log_nvidia_gpu_memory_utilization(device_index)
2828

tests/utils/test_hf_datasets_utils.py

+27
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,27 @@
1+
import tempfile
2+
from pathlib import Path
3+
4+
import datasets
5+
6+
from oumi.utils.hf_datasets_utils import is_cached_to_disk_hf_dataset
7+
8+
9+
def test_is_saved_to_disk_hf_dataset():
10+
with tempfile.TemporaryDirectory() as output_temp_dir:
11+
ds = datasets.Dataset.from_dict(
12+
{"pokemon": ["bulbasaur", "squirtle"], "type": ["grass", "water"]}
13+
)
14+
ds_dir = Path(output_temp_dir) / "toy_dataset"
15+
assert not is_cached_to_disk_hf_dataset(ds_dir)
16+
17+
ds_dir.mkdir(parents=True, exist_ok=True)
18+
assert not is_cached_to_disk_hf_dataset(ds_dir)
19+
20+
ds.save_to_disk(ds_dir, num_shards=2)
21+
assert is_cached_to_disk_hf_dataset(ds_dir)
22+
23+
for filename in ("dataset_info.json", "state.json"):
24+
sub_path: Path = Path(ds_dir) / filename
25+
assert sub_path.exists() and sub_path.is_file()
26+
sub_path.unlink()
27+
assert not is_cached_to_disk_hf_dataset(ds_dir)

0 commit comments

Comments (0)