Skip to content

Commit 288b474

Browse files
feat: add fractional random split util
1 parent dc76913 commit 288b474

File tree

2 files changed

+24
-2
lines changed

2 files changed

+24
-2
lines changed

audio_data_pytorch/utils.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
1+
import math
12
import re
2-
from typing import Optional, TypeVar
3+
from typing import List, Optional, Sequence, TypeVar
34

5+
import torch
6+
from torch.utils.data.dataset import Dataset, Subset
47
from typing_extensions import TypeGuard
58

69
T = TypeVar("T")
@@ -13,3 +16,22 @@ def exists(val: Optional[T]) -> TypeGuard[T]:
1316
def camel_to_snake(name: str) -> str:
1417
name = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", name)
1518
return re.sub("([a-z0-9])([A-Z])", r"\1_\2", name).lower()
19+
20+
21+
def fractional_random_split(
22+
dataset: Dataset[T], fractions: Sequence[int]
23+
) -> List[Subset[T]]:
24+
"""Fractional split that follows the same convention as random_split"""
25+
assert sum(fractions) == 1.0, "Fractions must sum to 1.0"
26+
27+
length = len(dataset) # type: ignore[arg-type]
28+
indices = torch.randperm(length)
29+
splits = []
30+
cursor = 0
31+
32+
for fraction in fractions:
33+
next_cursor = math.ceil(length * fraction + cursor)
34+
splits += [Subset(dataset, indices[cursor:next_cursor])] # type: ignore[arg-type] # noqa
35+
cursor = next_cursor
36+
37+
return splits

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
setup(
44
name="audio-data-pytorch",
55
packages=find_packages(exclude=[]),
6-
version="0.0.4",
6+
version="0.0.5",
77
license="MIT",
88
description="Audio Data - PyTorch",
99
long_description_content_type="text/markdown",

0 commit comments

Comments
 (0)