
Add pytorch style dataloader #463


Merged
merged 45 commits into from
Jan 12, 2022

Conversation

dayo05
Contributor

@dayo05 dayo05 commented Nov 27, 2021

In PyTorch there is DataLoader in torch.utils.data; TorchSharp has DataIterator in TorchSharp.Data.

In PyTorch, creating a Dataset is easy:

import os
import pandas as pd
from torch.utils.data import Dataset
from torchvision.io import read_image

class CustomImageDataset(Dataset):
    def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):
        self.img_labels = pd.read_csv(annotations_file)
        self.img_dir = img_dir
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.img_labels)

    def __getitem__(self, idx):
        img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])
        image = read_image(img_path)
        label = self.img_labels.iloc[idx, 1]
        if self.transform:
            image = self.transform(image)
        if self.target_transform:
            label = self.target_transform(label)
        return image, label

This is sample code from the official PyTorch tutorial.
With DataIterator, I don't know how to do the equivalent; it is not similar to the PyTorch style, and I could not find a tutorial about it.

So I made a DataLoader that is used like this.
The style is very similar to what PyTorch does.

This will make iterating over data much easier.

@nietras
Contributor

nietras commented Nov 29, 2021

This is "hard-coded" to only two tensors, which doesn't fit general scenarios. We have scenarios with many input and many output tensors. In Python this is handled by the dynamic type system and being able to handle tuples etc., I think :) This cannot be mirrored with a good experience in C#, in my view. The best generalization is to base this on Dictionary<string, Tensor> (or Dictionary<object, Tensor> if generalizing the key). Otherwise we need to go down a generic, reflection-based path, which is problematic in my view.

Sorry, editing an old comment since DataLoader is present here, but I wanted to say that since we need to batch, we need to be able to iterate over Tensors and produce the same structured output, which is why Dictionary<string, Tensor> seems apt.

However, we would then also very much like to be able to run data loading threaded, which involves handling reproducibility in the face of global variables, ensuring randomized sampling is reproducible. So shuffling should be fully reproducible and seeded; hence, the random generator should be created in the ctor. While we are at it, do a Fisher-Yates shuffle, and don't use object indices but statically typed ones. Hence, it should be GetTensor(int index). Or this could probably be Dictionary<string, Tensor> Get(int index), or even just an indexer.

I also don't think Dataset should have a GetDataEnumerable(); Dataset should not be iterable, only indexable, since that is how DataLoader can do parallel, shuffled loading.

I had been thinking about adding this myself, so it's very nice to see it being worked on 👍 However, as I see it, there are some fundamental issues regarding multi-threaded loading and reproducibility that need to be taken into account.
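
The seeded, ctor-created Fisher-Yates shuffle described here can be sketched as follows (Python for illustration; the function name is mine, not part of the PR):

```python
import random

def fisher_yates_indices(n, seed):
    """Return a reproducibly shuffled list of dataset indices [0, n)."""
    rng = random.Random(seed)        # RNG created once, up front ("at ctor")
    idx = list(range(n))
    for i in range(n - 1, 0, -1):    # classic Fisher-Yates: swap down from the end
        j = rng.randint(0, i)
        idx[i], idx[j] = idx[j], idx[i]
    return idx
```

Because the RNG is local and seeded, two loaders constructed with the same seed visit the dataset in the same order, which is what makes threaded loading reproducible.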

Contributor

@nietras nietras left a comment


I've made some review comments too. :) I have more, but these are just some quick remarks.

@dayo05
Contributor Author

dayo05 commented Nov 30, 2021

TODO: Write tests for it

@NiklasGustafsson
Contributor

This is such an important addition, and we must make sure to get it right. I would like to see some more extensive code examples on how to use this -- the user experience is essential.

@dayo05
Contributor Author

dayo05 commented Nov 30, 2021

I'm making a new example reflecting the review comments.

@NiklasGustafsson
Contributor

NiklasGustafsson commented Dec 1, 2021

I'm going to add a few things to the code comments in a review, but I have some questions:

  1. Should we allow data loaders to accept custom shufflers? That is, if I think (I'd probably be wrong) that I have a better shuffle algorithm, shouldn't I be allowed to use that, instead?
  2. There's no way to specify the device on which the 'Current' tensors end up.
  3. Why does 'Current' return a dictionary? Is it input vs. labels?
  4. How do I split a data set into train, test, and validation data subsets?
  5. How do we add data augmentation support?
  6. When the dataset is not shuffled, we're still creating fresh batches every time through the data set. It would be great to think of some way to keep data in memory between epochs, as long as it fits, of course.

Also, it would be great to demonstrate how to use this by converting the Examples to use this instead of the ad hoc loading I added earlier.
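
On question 4: one common approach, independent of any particular Dataset type, is to split at the index level and feed each index subset to its own loader. A minimal Python sketch (illustrative only, not part of this PR):

```python
import random

def split_indices(n, fractions=(0.8, 0.1, 0.1), seed=0):
    """Partition dataset indices [0, n) into train/val/test index lists."""
    idx = list(range(n))
    random.Random(seed).shuffle(idx)          # seeded, so splits are reproducible
    n_train = int(n * fractions[0])
    n_val = int(n * fractions[1])
    return idx[:n_train], idx[n_train:n_train + n_val], idx[n_train + n_val:]
```

An index-based Dataset can then serve each subset without copying any data.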

@NiklasGustafsson
Contributor

@dayo05 -- thanks for all the work so far. I'll be taking some time off in December, but I'm looking forward to integrating this when I come back.

Thanks,

Niklas

@dayo05
Contributor Author

dayo05 commented Dec 10, 2021

I'm busy now; I'm going to resume work two weeks from now :(

@dayo05
Contributor Author

dayo05 commented Dec 25, 2021

@NiklasGustafsson I finished all my work, please review.

@dayo05
Contributor Author

dayo05 commented Dec 25, 2021

I found that the shuffler is not working correctly. I'll fix it ASAP.

@dayo05
Contributor Author

dayo05 commented Dec 25, 2021

This is an example of classifying the fruits360 dataset with ImageSharp. Right now the shuffler is not working well, so the loss is large, but if I use the previous shuffler (from before I made the ShuffleGenerator class), it works fine.

public class Fruits360: Dataset
{
    private List<string> Labels = new();
    private List<string> images = new();
    public Fruits360(bool isTrain, Device device)
    {
        var root = "/home/dayo/datasets/fruits-360/" + (isTrain ? "Training" : "Test");
        //Labels.AddRange(Directory.GetDirectories(root));
        foreach(var x in Directory.GetDirectories(root))
            images.AddRange(Directory.GetFiles(x));
        Labels.AddRange(Directory.GetDirectories(root).Select(x => x.Split('/')[^1]));
    }

    public override long Count => images.Count;

    public override Dictionary<string, torch.Tensor> GetTensor(long index)
    {
        var image = Image.Load<Rgb24>(images[(int) index], new JpegDecoder());
        using var r = tensor(image.GetPixelMemoryGroup()[0].Span.ToArray().Select(x => x.R / 255.0f).ToList(),
            new long[] {1, 100, 100});
        using var g = tensor(image.GetPixelMemoryGroup()[0].Span.ToArray().Select(x => x.G / 255.0f).ToList(),
            new long[] {1, 100, 100});
        using var b = tensor(image.GetPixelMemoryGroup()[0].Span.ToArray().Select(x => x.B / 255.0f).ToList(),
            new long[] {1, 100, 100});
        return new()
        {
            {"image", cat(new List<Tensor> {r, g, b}, 0)},
            {"label", tensor(Labels.IndexOf(images[(int)index].Split('/')[^2]), ScalarType.Int64)}
        };
    }
}
using var trainDataset = new Fruits360(true, CUDA);
using var testDataset = new Fruits360(false, CUDA);
using var train = new DataLoader(trainDataset, 256, true, CUDA);
using var test = new DataLoader(testDataset, 512, false, CUDA);

var model = new Fruits360Model(CUDA);
var optimizer = optim.Adam(model.parameters(), learningRate: 0.01);

foreach (var epoch in Range(1, 1000))
{
    model.Train();
    var batchId = 1;
    Console.WriteLine($"Epoch{epoch} running");
    foreach (var x in train)
    {
        optimizer.zero_grad();

        var prediction = model.forward(x["image"]);
        var output = functional.nll_loss(reduction: Reduction.Mean)(prediction, x["label"]);
        
        output.backward();
        optimizer.step();
        
        Console.Write($"\rTrain: epoch {epoch} {batchId * 1.0 / train.Count:P2} [{batchId} / {train.Count}] Loss: {output.ToSingle():F9}");
        batchId++;
        
        prediction.Dispose();
        output.Dispose();
        GC.Collect();
    }

    using (no_grad())
    {
        model.Eval();
        
        var testLoss = 0.0;
        var correct = 0;
        var total = 0L;
        var idx = 0;
        foreach (var x in test)
        {
            idx++;
            Console.Write($"\rTest running: {idx * 1.0 / test.Count:P2}");
            var prediction = model.forward(x["image"]);
            var output = functional.nll_loss(reduction: Reduction.Sum)(prediction, x["label"]);
            testLoss += output.ToSingle();

            var pred = prediction.argmax(1);
            total += pred.size()[0];
            correct += pred.eq(x["label"]).sum().ToInt32();
            pred.Dispose();
            prediction.Dispose();
            output.Dispose();
            GC.Collect();
        }
        Console.WriteLine(
            $"\rTest set: Average loss {(testLoss / testDataset.Count):F9} | Accuracy {((double) correct / testDataset.Count):P2}");
    }
}


class Fruits360Model : Module
{
    private Module layer1 = Sequential(
        Conv2d(3, 32, 3),
        ReLU(),
        MaxPool2d(2, 2));

    private Module layer2 = Sequential(
        Conv2d(32, 64, 3),
        ReLU(),
        MaxPool2d(2, 2));

    private Module layer3 = Sequential(
        Conv2d(64, 64, 3),
        ReLU(),
        MaxPool2d(2, 2));

    private Module fc = Sequential(
        Flatten(),
        Linear(6400, 1024),
        ReLU(),
        Dropout(),
        Linear(1024, 625),
        ReLU(),
        Linear(625, 131));
    public Fruits360Model(Device? device) : base("fruits360")
    {
        RegisterComponents();
        to(device ?? CPU);
    }

    public override Tensor forward(Tensor t)
    {
        t = layer1.forward(t);
        t = layer2.forward(t);
        t = layer3.forward(t);
        t = fc.forward(t);
        return LogSoftmax(1).forward(t); // log-softmax over the class dimension, not the batch dimension
    }
}

@dayo05
Contributor Author

dayo05 commented Dec 27, 2021

I found that this shuffler depends on the seed. So I'll change to another shuffler and make it possible to use a custom shuffler by implementing IEnumerable.

@NiklasGustafsson
Contributor

@dayo05 -- A couple of comments:

  1. This looks really good now. I will be happy to approve this later.
  2. There seems to be some problem with the Azure pipelines that are doing the builds for MacOS. We'll have to wait for that to go away before merging.
  3. Before we merge, I would like to see at least one of the examples in this repo modified to use this API instead of the version I put together for temporary purposes.

@dayo05
Contributor Author

dayo05 commented Jan 7, 2022

I made the Fisher-Yates shuffler the default because I thought the default should be beginner-friendly, and the previous shuffler is not good for beginners with smaller datasets. But I kept the previous shuffler, which can be requested like this:

var data = new DataLoader(dataset, batchSize, new FastShuffler(dataset.Count), torch.CUDA);
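
The pluggable-shuffler idea above boils down to the loader consuming a shuffler as a stream of indices and grouping them into batches. A minimal Python sketch of that contract (illustrative only; the function name is mine, not the PR's API):

```python
def batches(shuffler, batch_size):
    """Group an index stream from any shuffler into batches.

    'shuffler' can be any iterable of dataset indices: a Fisher-Yates
    permutation, a FastShuffler-style stream, or plain range(n) when
    shuffling is disabled.
    """
    batch = []
    for idx in shuffler:
        batch.append(idx)
        if len(batch) == batch_size:
            yield batch
            batch = []
    if batch:  # final partial batch
        yield batch
```

With this shape, swapping shufflers never touches the batching or loading code.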

@dayo05
Contributor Author

dayo05 commented Jan 7, 2022

https://github.com/dayo05/DataLoaderExample/tree/master
I wrote example usage for this commit.
See Fruits360.cs and Program.cs.

@NiklasGustafsson
Contributor

https://github.com/dayo05/DataLoaderExample/tree/master I wrote example usage for this commit. See Fruits360.cs and Program.cs.

That's great. Could you take the MNIST example in this repo and convert it to use your API?

@dayo05
Contributor Author

dayo05 commented Jan 10, 2022

@NiklasGustafsson I updated it here: https://github.com/dayo05/DataLoaderExample/tree/master
It is in MNIST.cs and MNISTReader.cs.
I adapted it from the TorchSharp main repo source.

@NiklasGustafsson
Contributor

NiklasGustafsson commented Jan 11, 2022

I've been trying to restart the MacOS builds for a couple of days now. I have no idea why they are failing, but it happens very early, way before it gets to building TorchSharp.

Update: @dayo05 -- the 'main' branch still builds just fine, so there must be something in your PR. I suspect it may have something to do with your changing the SDK version number, but I'm not sure.

@dayo05
Contributor Author

dayo05 commented Jan 12, 2022

@NiklasGustafsson I fixed it.

@NiklasGustafsson
Contributor

Okay, I'm going to merge this now. Let's still get some of the examples in this repository using the DataLoader API.

@NiklasGustafsson NiklasGustafsson merged commit 25d6ad2 into dotnet:main Jan 12, 2022