File tree Expand file tree Collapse file tree 2 files changed +24
-2
lines changed Expand file tree Collapse file tree 2 files changed +24
-2
lines changed Original file line number Diff line number Diff line change
1
+ import math
1
2
import re
2
- from typing import Optional , TypeVar
3
+ from typing import List , Optional , Sequence , TypeVar
3
4
5
+ import torch
6
+ from torch .utils .data .dataset import Dataset , Subset
4
7
from typing_extensions import TypeGuard
5
8
6
9
T = TypeVar ("T" )
@@ -13,3 +16,22 @@ def exists(val: Optional[T]) -> TypeGuard[T]:
13
16
def camel_to_snake (name : str ) -> str :
14
17
name = re .sub ("(.)([A-Z][a-z]+)" , r"\1_\2" , name )
15
18
return re .sub ("([a-z0-9])([A-Z])" , r"\1_\2" , name ).lower ()
19
+
20
+
21
+ def fractional_random_split (
22
+ dataset : Dataset [T ], fractions : Sequence [int ]
23
+ ) -> List [Subset [T ]]:
24
+ """Fractional split that follows the same convention as random_split"""
25
+ assert sum (fractions ) == 1.0 , "Fractions must sum to 1.0"
26
+
27
+ length = len (dataset ) # type: ignore[arg-type]
28
+ indices = torch .randperm (length )
29
+ splits = []
30
+ cursor = 0
31
+
32
+ for fraction in fractions :
33
+ next_cursor = math .ceil (length * fraction + cursor )
34
+ splits += [Subset (dataset , indices [cursor :next_cursor ])] # type: ignore[arg-type] # noqa
35
+ cursor = next_cursor
36
+
37
+ return splits
Original file line number Diff line number Diff line change 3
3
setup (
4
4
name = "audio-data-pytorch" ,
5
5
packages = find_packages (exclude = []),
6
- version = "0.0.4 " ,
6
+ version = "0.0.5 " ,
7
7
license = "MIT" ,
8
8
description = "Audio Data - PyTorch" ,
9
9
long_description_content_type = "text/markdown" ,
You can’t perform that action at this time.
0 commit comments