Skip to content

Commit

Permalink
batched iterator
Browse files Browse the repository at this point in the history
  • Loading branch information
Miraj98 committed Dec 20, 2022
1 parent dfc9895 commit a512ad2
Showing 1 changed file with 57 additions and 5 deletions.
62 changes: 57 additions & 5 deletions src/dataloader.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use std::{fs::File, os::unix::prelude::FileExt};
use std::{fs::File, os::unix::prelude::FileExt, ops::Range};
use tensor_rs::{
dim::{Dimension, Ix3},
dim::Dimension,
DataElement, Tensor, TensorView,
};

Expand All @@ -22,6 +22,40 @@ where
/// Builds a [`Dataset`] from a `(training_set, labels, test_set)` tuple.
///
/// The tuple elements are moved into the fields of the same name, in order;
/// the third element is optional.
pub fn new(args: (Tensor<S, E>, Tensor<S, E>, Option<Tensor<S, E>>)) -> Self {
    // Unpack positionally once, then use field-init shorthand.
    let (training_set, labels, test_set) = args;
    Dataset { training_set, labels, test_set }
}

/// Returns a [`DatasetIterator`] borrowing this dataset, as produced by
/// [`DatasetIterator::new`].
pub fn iter(&self) -> DatasetIterator<'_, S, E> {
    DatasetIterator::<S, E>::new(self)
}

/// Returns an iterator over consecutive batches of `batch_size` samples.
///
/// The number of batches is `training_set.dim()[0] / batch_size` using
/// integer division, so trailing samples that do not fill a whole batch are
/// not part of any batch.
///
/// # Panics
///
/// Panics if `batch_size` is zero.
pub fn batch_iter(&self, batch_size: usize) -> DatasetBatchIterator<'_, S, E> {
    // Fail with a descriptive message instead of an opaque divide-by-zero panic.
    assert!(batch_size > 0, "batch_size must be greater than zero");
    let nbatches = self.training_set.dim()[0] / batch_size;
    DatasetBatchIterator { dataset: self, nbatches, batch_size, batch_idx: 0 }
}
}

/// Iterator state for walking a [`Dataset`] in fixed-size batches.
///
/// Its `Iterator` implementation yields one [`DatasetIterator`] per batch.
pub struct DatasetBatchIterator<'a, S: Dimension, E: DataElement> {
    /// Dataset the batches are drawn from.
    dataset: &'a Dataset<S, E>,
    /// Total number of batches to produce.
    nbatches: usize,
    /// Number of samples in each batch.
    batch_size: usize,
    /// Zero-based index of the next batch to produce.
    batch_idx: usize,
}

impl<'a, S, E> Iterator for DatasetBatchIterator<'a, S, E>
where
S: Dimension,
E: DataElement
{
type Item = DatasetIterator<'a, S, E>;
fn next(&mut self) -> Option<Self::Item> {
if self.batch_idx >= self.nbatches - 1 { return None }
let val = DatasetIterator::new(self.dataset).range((self.batch_size * self.batch_idx)..(self.batch_size * self.batch_idx + self.batch_size));
self.batch_idx += 1;
return Some(val);
}
}

pub struct DatasetIterator<'a, S, E>
Expand All @@ -30,7 +64,25 @@ where
E: DataElement
{
dataset: &'a Dataset<S, E>,
index: usize
index: usize,
range: Range<usize>,
}

impl<'a, S, E> DatasetIterator<'a, S, E>
where
    S: Dimension,
    E: DataElement,
{
    /// Creates an iterator positioned at index 0 whose range spans
    /// `0..dataset.training_set.dim()[0]`.
    pub fn new(dataset: &'a Dataset<S, E>) -> Self {
        Self { dataset, index: 0, range: 0..dataset.training_set.dim()[0] }
    }

    /// Restricts the iterator to the half-open index range `range` and
    /// positions it at `range.start`.
    ///
    /// # Panics
    ///
    /// Panics if `range` is empty or inverted (`start >= end`), or if
    /// `range.end` is greater than `training_set.dim()[0]`.
    pub fn range(mut self, range: Range<usize>) -> Self {
        let len = self.dataset.training_set.dim()[0];
        assert!(range.start < range.end, "range must be non-empty");
        // `end == len` is a valid half-open bound (it is what `new` uses), so
        // only reject ranges that reach past the data.
        assert!(range.end <= len, "range end exceeds the number of samples");
        // Start iterating at the beginning of the requested range rather than
        // at whatever index the iterator previously held.
        self.index = range.start;
        self.range = range;
        self
    }
}

impl<'a, S, E> Iterator for DatasetIterator<'a, S, E>
Expand All @@ -40,7 +92,7 @@ where
{
type Item = (TensorView<'a, S::Smaller, E>, TensorView<'a, S::Smaller, E>);
fn next(&mut self) -> Option<Self::Item> {
if self.index >= self.dataset.training_set.dim()[0] { return None; }
if self.index >= self.range.end - 1 { return None; }
let val = (self.dataset.training_set.outer_dim(self.index), self.dataset.labels.outer_dim(self.index));
self.index += 1;
return Some(val);
Expand All @@ -55,7 +107,7 @@ where
type IntoIter = DatasetIterator<'a, S, E>;
type Item = (TensorView<'a, S::Smaller, E>, TensorView<'a, S::Smaller, E>);
/// Creates a [`DatasetIterator`] over the borrowed dataset.
fn into_iter(self) -> Self::IntoIter {
    // Use the constructor so every iterator field, including `range`, is
    // initialised in one place instead of via a partial struct literal.
    DatasetIterator::new(self)
}
}

Expand Down

0 comments on commit a512ad2

Please sign in to comment.