Skip to content

Commit

Permalink
Merge pull request #175 from Visual-Behavior/merge_dataset_weights
Browse files Browse the repository at this point in the history
Merge dataset weights
  • Loading branch information
thibo73800 authored May 16, 2022
2 parents 997e772 + 4cfda4e commit 6c718f4
Showing 1 changed file with 20 additions and 3 deletions.
23 changes: 20 additions & 3 deletions alodataset/merge_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,18 +20,35 @@ class MergeDataset(torch.utils.data.Dataset):
List of datasets
transform_fn : function
transformation applied to each sample
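weights : List[int] | None
For N datasets, a list of N integer weights, one per dataset and in the same order as `datasets`.
The samples from a dataset with weight `w` will appear `w` times in the MergeDataset.
If None (the default), every dataset gets a weight of 1.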
"""

def __init__(self, datasets, transform_fn=None, weights=None):
    """Build a merged view over several datasets.

    Parameters
    ----------
    datasets : list
        List of datasets to merge.
    transform_fn : function
        Transformation applied to each sample.
    weights : List[int] | None
        One integer weight per dataset. The samples of a dataset with
        weight `w` appear `w` times in the merged index. None means a
        weight of 1 for every dataset.
    """
    # Order matters: `_init_weights` reads `self.datasets`, and
    # `_init_indices` reads both `self.datasets` and `self.weights`.
    self.datasets = datasets
    self.weights = self._init_weights(weights)
    self.indices = self._init_indices()
    self.transform_fn = transform_fn

def _init_weights(self, weights):
    """Validate the per-dataset weights and return them as a new list.

    Parameters
    ----------
    weights : List[int] | None
        One integer weight per entry of `self.datasets`, or None to give
        every dataset a weight of 1.

    Returns
    -------
    List[int]
        A fresh list holding one weight per dataset.

    Raises
    ------
    RuntimeError
        If the number of weights differs from the number of datasets, if
        a weight is not an int, or if a weight is negative.
    """
    n_datasets = len(self.datasets)
    if weights is None:
        return [1] * n_datasets

    if len(weights) != n_datasets:
        raise RuntimeError("The number of weights should be equal to the number of datasets.")

    # bool is a subclass of int; reject it explicitly so True/False are not
    # silently accepted as weights 1/0.
    if not all(isinstance(w, int) and not isinstance(w, bool) for w in weights):
        raise RuntimeError("weights should be a list of int.")

    # A negative weight would make `range(w)` empty and silently drop the
    # whole dataset; fail loudly instead. Zero stays valid (dataset disabled).
    if any(w < 0 for w in weights):
        raise RuntimeError("weights should be non-negative.")

    # Return a copy so later mutation of the caller's list cannot
    # desynchronise the stored weights from the indices built from them.
    return list(weights)

def _init_indices(self):
    """Build the flat lookup table of the merged dataset.

    Returns
    -------
    list of (int, int)
        `(dataset_index, sample_index)` pairs. For each dataset, in order,
        its full range of sample indices is repeated as many times as the
        matching entry of `self.weights`.
    """
    indices = []
    for dset_idx, dset in enumerate(self.datasets):
        # One pass over the dataset's samples...
        one_pass = [(dset_idx, idx) for idx in range(len(dset))]
        # ...repeated `weight` times (a weight of 0 contributes nothing).
        indices.extend(one_pass * self.weights[dset_idx])
    return indices

def __len__(self):
Expand Down

0 comments on commit 6c718f4

Please sign in to comment.