Skip to content

Revert "concat dataset" #1358

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jul 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 0 additions & 12 deletions RELEASENOTES.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,6 @@

Releases, starting with 9/2/2021, are listed with the most recent release at the top.

# NuGet Version 0.102.7

__Breaking Changes__:

A new interface `IDataset<out T>` has been added. (Now `Dataset<T>` implements `IDataset<T>`; `Dataset` implements both `IDataset<Dictionary<string, Tensor>>` and `IDataset<IReadOnlyDictionary<string, Tensor>>`; `IterableDataset` implements both `IDataset<IList<Tensor>>` and `IDataset<IEnumerable<Tensor>>`.)<br/>
`torch.utils.data.ConcatDataset` has been added.<br/>

__API Changes__:

The parameter of `DataLoader`s has been relaxed to `IDataset`.<br/>
The parameter of `DataLoader`s' collate functions has been relaxed to `IReadOnlyList`.<br/>

# NuGet Version 0.102.6

__Breaking Changes__:
Expand Down
108 changes: 0 additions & 108 deletions src/TorchSharp/ConcatDataset.cs

This file was deleted.

52 changes: 22 additions & 30 deletions src/TorchSharp/DataLoader.cs
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,9 @@ public static partial class utils
{
public static partial class data
{

public static Modules.DataLoader DataLoader(
IDataset<IReadOnlyDictionary<string, torch.Tensor>> dataset,
Dataset dataset,
int batchSize, IEnumerable<long> shuffler,
Device device = null,
int num_worker = 1, bool drop_last = false,
Expand All @@ -33,7 +34,7 @@ public static Modules.DataLoader DataLoader(
}

public static Modules.DataLoader DataLoader(
IDataset<IReadOnlyDictionary<string, torch.Tensor>> dataset,
Dataset dataset,
int batchSize, bool shuffle = false,
Device device = null, int? seed = null,
int num_worker = 1, bool drop_last = false,
Expand All @@ -48,7 +49,7 @@ public static Modules.DataLoader DataLoader(
}

public static Modules.IterableDataLoader DataLoader(
IDataset<IEnumerable<Tensor>> dataset,
IterableDataset dataset,
int batchSize, IEnumerable<long> shuffler,
Device device = null,
int num_worker = 1, bool drop_last = false,
Expand All @@ -63,7 +64,7 @@ public static Modules.IterableDataLoader DataLoader(
}

public static Modules.IterableDataLoader DataLoader(
IDataset<IEnumerable<Tensor>> dataset,
IterableDataset dataset,
int batchSize, bool shuffle = false,
Device device = null, int? seed = null,
int num_worker = 1, bool drop_last = false,
Expand All @@ -89,8 +90,7 @@ namespace Modules
/// Data loader. Combines a dataset and a sampler, and provides an enumerator over the given dataset.
/// </summary>
/// <remarks>This class is used for map-style data sets</remarks>
public class DataLoader : DataLoader<IReadOnlyDictionary<string, torch.Tensor>,
Dictionary<string, torch.Tensor>>
public class DataLoader : DataLoader<Dictionary<string, torch.Tensor>, Dictionary<string, torch.Tensor>>
{
/// <summary>
/// Pytorch style dataloader
Expand All @@ -111,7 +111,7 @@ public class DataLoader : DataLoader<IReadOnlyDictionary<string, torch.Tensor>,
/// Indicates whether to dispose the dataset when being disposed.
/// </param>
public DataLoader(
IDataset<IReadOnlyDictionary<string, torch.Tensor>> dataset,
Dataset dataset,
int batchSize, IEnumerable<long> shuffler,
Device device = null,
int num_worker = 1, bool drop_last = false,
Expand Down Expand Up @@ -144,7 +144,7 @@ public DataLoader(
/// Indicates whether to dispose the dataset when being disposed.
/// </param>
public DataLoader(
IDataset<IReadOnlyDictionary<string, torch.Tensor>> dataset,
Dataset dataset,
int batchSize, bool shuffle = false,
Device device = null, int? seed = null,
int num_worker = 1, bool drop_last = false,
Expand All @@ -157,8 +157,7 @@ public DataLoader(
{
}

private static Dictionary<string, torch.Tensor> Collate(
IEnumerable<IReadOnlyDictionary<string, torch.Tensor>> dic, torch.Device device)
private static Dictionary<string, torch.Tensor> Collate(IEnumerable<Dictionary<string, torch.Tensor>> dic, torch.Device device)
{
using (torch.NewDisposeScope()) {
Dictionary<string, torch.Tensor> batch = new();
Expand All @@ -177,8 +176,7 @@ public DataLoader(
/// Data loader. Combines a dataset and a sampler, and provides an enumerator over the given dataset.
/// </summary>
/// <remarks>This class is used for list-style data sets</remarks>
public class IterableDataLoader :
DataLoader<IEnumerable<torch.Tensor>, IList<torch.Tensor>>
public class IterableDataLoader : DataLoader<IList<torch.Tensor>, IList<torch.Tensor>>
{
/// <summary>
/// Pytorch style dataloader
Expand All @@ -199,7 +197,7 @@ public class IterableDataLoader :
/// Indicates whether to dispose the dataset when being disposed.
/// </param>
public IterableDataLoader(
IDataset<IEnumerable<Tensor>> dataset,
IterableDataset dataset,
int batchSize, IEnumerable<long> shuffler,
Device device = null,
int num_worker = 1, bool drop_last = false,
Expand Down Expand Up @@ -232,7 +230,7 @@ public IterableDataLoader(
/// Indicates whether to dispose the dataset when being disposed.
/// </param>
public IterableDataLoader(
IDataset<IEnumerable<Tensor>> dataset,
IterableDataset dataset,
int batchSize, bool shuffle = false,
Device device = null, int? seed = null,
int num_worker = 1, bool drop_last = false,
Expand All @@ -245,18 +243,12 @@ public IterableDataLoader(
{
}

private static IList<torch.Tensor> Collate(
IReadOnlyList<IEnumerable<torch.Tensor>> dic, torch.Device device)
private static IList<torch.Tensor> Collate(IEnumerable<IList<torch.Tensor>> dic, torch.Device device)
{
var dicCopy = new List<torch.Tensor[]>();
foreach (var e in dic) {
dicCopy.Add(e.ToArray());
}

using (torch.NewDisposeScope()) {
List<torch.Tensor> batch = new();
for (var x = 0; x < dicCopy[0].Length; x++) {
var t = cat(dicCopy.Select(k => k[x].unsqueeze(0)).ToArray(), 0);
for (var x = 0; x < dic.First().Count; x++) {
var t = cat(dic.Select(k => k[x].unsqueeze(0)).ToArray(), 0);
if (t.device_type != device.type || t.device_index != device.index)
t = t.to(device);
batch.Add(t.MoveToOuterDisposeScope());
Expand All @@ -272,12 +264,12 @@ public IterableDataLoader(
/// </summary>
public class DataLoader<T, S> : IEnumerable<S>, IDisposable
{
public IDataset<T> dataset { get; }
public Dataset<T> dataset { get; }
public int batch_size { get; }
public bool drop_last { get; }
public IEnumerable<long> sampler { get; }
public int num_workers { get; }
public Func<IReadOnlyList<T>, Device, S> collate_fn { get; }
public Func<IEnumerable<T>, Device, S> collate_fn { get; }

public Device Device { get; }
public bool DisposeBatch { get; }
Expand All @@ -303,9 +295,9 @@ public class DataLoader<T, S> : IEnumerable<S>, IDisposable
/// Indicates whether to dispose the dataset when being disposed.
/// </param>
public DataLoader(
IDataset<T> dataset,
Dataset<T> dataset,
int batchSize,
Func<IReadOnlyList<T>, torch.Device, S> collate_fn,
Func<IEnumerable<T>, torch.Device, S> collate_fn,
IEnumerable<long> shuffler,
Device? device = null,
int num_worker = 1,
Expand Down Expand Up @@ -345,9 +337,9 @@ public DataLoader(
/// Indicates whether to dispose the dataset when being disposed.
/// </param>
public DataLoader(
IDataset<T> dataset,
Dataset<T> dataset,
int batchSize,
Func<IReadOnlyList<T>, torch.Device, S> collate_fn,
Func<IEnumerable<T>, torch.Device, S> collate_fn,
bool shuffle = false,
Device? device = null,
int? seed = null,
Expand Down Expand Up @@ -440,7 +432,7 @@ public bool MoveNext()
.WithDegreeOfParallelism(loader.num_workers)
.ForAll((i) => {
using var getTensorScope = torch.NewDisposeScope();
tensors[i] = loader.dataset[indices[i]];
tensors[i] = loader.dataset.GetTensor(indices[i]);
getTensorDisposables[i] = getTensorScope.DetachAllAndDispose();
});

Expand Down
44 changes: 3 additions & 41 deletions src/TorchSharp/Dataset.cs
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
// Copyright (c) .NET Foundation and Contributors. All Rights Reserved. See LICENSE in the project root for license information.
using System;
using System.Collections.Generic;
using System.Runtime.CompilerServices;

namespace TorchSharp
{
Expand All @@ -14,29 +13,21 @@ public static partial class data
/// <summary>
/// Map-style data set
/// </summary>
public abstract class Dataset : Dataset<Dictionary<string, Tensor>>,
IDataset<IReadOnlyDictionary<string, Tensor>>
public abstract class Dataset : Dataset<Dictionary<string, torch.Tensor>>
{
// Due to covariance, this should naturally already be an IDataset<IReadOnlyDictionary<string, Tensor>>.
// However, FSharp.Examples breaks without this explicit implementation.
IReadOnlyDictionary<string, Tensor> IDataset<IReadOnlyDictionary<string, Tensor>>.this[long index] => this[index];
}

/// <summary>
/// Iterable-style data sets
/// </summary>
public abstract class IterableDataset : Dataset<IList<Tensor>>,
IDataset<IEnumerable<Tensor>>
public abstract class IterableDataset : Dataset<IList<Tensor>>
{
// Due to covariance, this should naturally already be an IDataset<IEnumerable<Tensor>>.
// However, FSharp.Examples breaks without this explicit implementation.
IEnumerable<Tensor> IDataset<IEnumerable<Tensor>>.this[long index] => this[index];
}

/// <summary>
/// The base interface for all Datasets.
/// </summary>
public abstract class Dataset<T> : IDataset<T>, IDisposable
public abstract class Dataset<T> : IDisposable
{
public void Dispose()
{
Expand All @@ -49,12 +40,6 @@ public void Dispose()
/// </summary>
public abstract long Count { get; }

[IndexerName("DatasetItems")]
public T this[long index] => this.GetTensor(index);

// GetTensor is kept for compatibility.
// Perhaps we should remove that and make the indexer abstract later.

/// <summary>
/// Get tensor according to index
/// </summary>
Expand All @@ -64,31 +49,8 @@ public void Dispose()

protected virtual void Dispose(bool disposing)
{
IDataset<Dictionary<string, string>> a = null;
IDataset<IReadOnlyDictionary<string, string>> b = a;
}
}

/// <summary>
/// The base interface for all Datasets.
/// </summary>
public interface IDataset<out T> : IDisposable
{
/// <summary>
/// Size of dataset
/// </summary>
long Count { get; }

/// <summary>
/// Get tensor according to index
/// </summary>
/// <param name="index">Index for tensor</param>
/// <returns>Tensors of index. DataLoader will catenate these tensors into batches.</returns>
[IndexerName("DatasetItems")]
T this[long index] { get; }

// TODO: support System.Index
}
}
}
}
Expand Down
Loading