-
Notifications
You must be signed in to change notification settings - Fork 201
Add pytorch style dataloader #463
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
This is "hard-coded" to only two tensors which doesn't fit general case scenarios. We have scenarios when many input and many output tensors. In python this is handled by the dynamic type system and being able to handle tuples etc. I think :) This cannot be mirrored with a good experience in C# in my view. The best generalization is to base this on Sorry, editing old comment, since DataLoader is present here but wanted to say that, since we need to batch we need to be able to iterator over However, we would then also very much like to be able to run data loading threaded, which involves handling reproducibility in the face of global variables, ensuring randomized sampling is reproducible. So shuflling should be fully reproducible and seeded. Hence, random created at ctor. While we are at then, do fischer-yates shuffle and not use I also don't think Dataset should have a GetDataEnumerable(), Dataset should not be iterable, but only be indexable. Since that is how DataLoader can do parallel, shuffled loading. I had been thinking about adding this myself, so very nice to see it being worked on 👍 However, as I see there are some fundamental issues that need to be addressed regarding multi-threading loading and reproducibility that need to be taken into account. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I've made some review comments too. :) I have more but just some quick remarks.
TODO: Write test about it |
This is such an important addition, and we must make sure to get it right. I would like to see some more extensive code examples on how to use this -- the user experience is essential. |
I'm making a new example for reflection about review |
I'm going to add a few things to the code comments in a review, but I have some questions:
Also, it would be great to demonstrate how to use this by converting the Examples to use this instead of the ad hoc loading I added earlier. |
@dayo05 -- thanks for all the work so far. I'll be taking some time off in December, but I'm looking forward to integrating this when I come back. Thanks, Niklas |
I'm busy now, I'll going to work on two weeks later :( |
@NiklasGustafsson I finished all my job, review please. |
I found shuffler is not correctly working. I'll fix ASAP. |
This is example for classifying fruits360 dataset with imagesharp. Now shuffler is not working fine so loss is big but if I use previous shuffler(which before making shufflegenerator class), It works fine. public class Fruits360: Dataset
{
private List<string> Labels = new();
private List<string> images = new();
public Fruits360(bool isTrain, Device device)
{
var root = "/home/dayo/datasets/fruits-360/" + (isTrain ? "Training" : "Test");
//Labels.AddRange(Directory.GetDirectories(root));
foreach(var x in Directory.GetDirectories(root))
images.AddRange(Directory.GetFiles(x));
Labels.AddRange(Directory.GetDirectories(root).Select(x => x.Split('/')[^1]));
}
public override long Count => images.Count;
public override Dictionary<string, torch.Tensor> GetTensor(long index)
{
var image = Image.Load<Rgb24>(images[(int) index], new JpegDecoder());
using var r = tensor(image.GetPixelMemoryGroup()[0].Span.ToArray().Select(x => x.R / 255.0f).ToList(),
new long[] {1, 100, 100});
using var g = tensor(image.GetPixelMemoryGroup()[0].Span.ToArray().Select(x => x.G / 255.0f).ToList(),
new long[] {1, 100, 100});
using var b = tensor(image.GetPixelMemoryGroup()[0].Span.ToArray().Select(x => x.B / 255.0f).ToList(),
new long[] {1, 100, 100});
return new()
{
{"image", cat(new List<Tensor> {r, g, b}, 0)},
{"label", tensor(Labels.IndexOf(images[(int)index].Split('/')[^2]), ScalarType.Int64)}
};
}
} using var trainDataset = new Fruits360(true, CUDA);
using var testDataset = new Fruits360(false, CUDA);
using var train = new DataLoader(trainDataset, 256, true, CUDA);
using var test = new DataLoader(testDataset, 512, false, CUDA);
var model = new Fruits360Model(CUDA);
var optimizer = optim.Adam(model.parameters(), learningRate: 0.01);
foreach (var epoch in Range(1, 1000))
{
model.Train();
var batchId = 1;
Console.WriteLine($"Epoch{epoch} running");
foreach (var x in train)
{
optimizer.zero_grad();
var prediction = model.forward(x["image"]);
var output = functional.nll_loss(reduction: Reduction.Mean)(prediction, x["label"]);
output.backward();
optimizer.step();
Console.Write($"\rTrain: epoch {epoch} {batchId * 1.0 / train.Count:P2} [{batchId} / {train.Count}] Loss: {output.ToSingle():F9}");
batchId++;
prediction.Dispose();
output.Dispose();
GC.Collect();
}
using (no_grad())
{
model.Eval();
var testLoss = 0.0;
var correct = 0;
var total = 0L;
var idx = 0;
foreach (var x in test)
{
idx++;
Console.Write($"\rTest running: {idx * 1.0 / test.Count:P2}");
var prediction = model.forward(x["image"]);
var output = functional.nll_loss(reduction: Reduction.Sum)(prediction, x["label"]);
testLoss += output.ToSingle();
var pred = prediction.argmax(1);
total += pred.size()[0];
correct += pred.eq(x["label"]).sum().ToInt32();
pred.Dispose();
prediction.Dispose();
output.Dispose();
GC.Collect();
}
Console.WriteLine(
$"\rTest set: Average loss {(testLoss / testDataset.Count):F9} | Accuracy {((double) correct / testDataset.Count):P2}");
}
}
class Fruits360Model : Module
{
private Module layer1 = Sequential(
Conv2d(3, 32, 3),
ReLU(),
MaxPool2d(2, 2));
private Module layer2 = Sequential(
Conv2d(32, 64, 3),
ReLU(),
MaxPool2d(2, 2));
private Module layer3 = Sequential(
Conv2d(64, 64, 3),
ReLU(),
MaxPool2d(2, 2));
private Module fc = Sequential(
Flatten(),
Linear(6400, 1024),
ReLU(),
Dropout(),
Linear(1024, 625),
ReLU(),
Linear(625, 131));
public Fruits360Model(Device? device) : base("fruits360")
{
RegisterComponents();
to(device ?? CPU);
}
public override Tensor forward(Tensor t)
{
t = layer1.forward(t);
t = layer2.forward(t);
t = layer3.forward(t);
t = fc.forward(t);
return LogSoftmax(0).forward(t);
}
} |
I found that this shuffler is depend on seed. So, I'll change to other shuffler and enable to use custom shuffler with implementation of IEnumerable. |
@dayo05 -- A couple of comments:
|
I make fisher yates shuffler as default because I thought that default value is friendly for beginner and previous shuffler is not good for beginner who has less amount of dataset. But, I left previous shuffler and allow to call like this var data = new DataLoader(dataset, batchSize, new FastShuffler(dataset.Count), torch.CUDA); |
https://github.com/dayo05/DataLoaderExample/tree/master |
That's great. Could you take the MNIST example in this repo and convert it to use your API? |
In previous code, dataset is able to dispose twice
@NiklasGustafsson I updated here https://github.com/dayo05/DataLoaderExample/tree/master |
I've been trying to restart the MacOS builds for a couple of days now. I have no idea why there are failing, but it is very early, way before it gets to building TorchSharp. Update: @dayo05 -- the 'main' branch still builds just fine, so there must be something in your PR. I suspect it may have something to do with your changing the SDK version number, but I'm not sure. |
@NiklasGustafsson I fix it |
Okay, I'm going to merge this now. Let's still get some of the examples in this repository using the DataLoader API. |
In pytorch there are DataLoader on torch.utils.data, Torchsharp has DataIterator on TorchSharp.Data.
In pytorch, Creating Dataset is easy.
This is sample code that I found on pytorch official tutorial.
But in DataIterator, I don't know how to use that, It is not similar as pytorch style, I cannot found tutorial about it.
So I make dataloader using like this.
This style is very similar as what pytorch does.
This will make iterating data more easier.