Skip to content

Commit 4cd87ad

Browse files
manikyabarddarsnacklorenzoh
authored
added TableDataset (#26)
* added TableDataset * Update src/datasets/containers.jl Co-authored-by: Kyle Daruwalla <daruwalla.k.public@icloud.com> * changed append to push in getobs for TableDataset * updated path type, added direct methods for DataFrames and CSV.File, and changed nobs order * added TableDataset * Update src/datasets/containers.jl Co-authored-by: Kyle Daruwalla <daruwalla.k.public@icloud.com> * changed append to push in getobs for TableDataset * updated path type, added direct methods for DataFrames and CSV.File, and changed nobs order * Update src/datasets/containers.jl Co-authored-by: Kyle Daruwalla <daruwalla.k.public@icloud.com> * fixed nobs typo * added tests and made getobs consistent * fixed typo Co-authored-by: lorenzoh <lorenz.ohly@gmail.com> * changed getobs back for DataFrame * Updated getobs * Changed getobs to return NamedTuple * Update test/datasets/containers.jl Co-authored-by: Kyle Daruwalla <daruwalla.k.public@icloud.com> * Updated tests for TableDataset container. Co-authored-by: Kyle Daruwalla <daruwalla.k.public@icloud.com> * remove Manifest.toml * fixed path in test * updated csv TableDataset test * removed old csv testcase Co-authored-by: Kyle Daruwalla <daruwalla.k.public@icloud.com> Co-authored-by: lorenzoh <lorenz.ohly@gmail.com>
1 parent ac16863 commit 4cd87ad

File tree

6 files changed

+112
-3
lines changed

6 files changed

+112
-3
lines changed

Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,12 @@ version = "0.1.0"
66
[deps]
77
Animations = "27a7e980-b3e6-11e9-2bcd-0b925532e340"
88
BSON = "fbb218c0-5317-5bc6-957e-2ee96dd4b1f0"
9+
CSV = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b"
910
Colors = "5ae59095-9a9b-59fe-a467-6f913c188581"
1011
DLPipelines = "e6530d7c-7faa-4ede-a0d6-9eff9baad396"
1112
DataAugmentation = "88a5189c-e7ff-4f85-ac6b-e6158070f02e"
1213
DataDeps = "124859b0-ceae-595e-8997-d05f6a7a8dfe"
14+
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
1315
DataLoaders = "2e981812-ef13-4a9c-bfa0-ab13047b12a9"
1416
FileIO = "5789e2e9-d7fb-5bc7-8068-2c6fae9b9549"
1517
FilePathsBase = "48062228-2e41-5def-b9a4-89aafe57970f"

src/datasets/Datasets.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,9 @@ using MLDataPattern: splitobs
126126
import LearnBase
127127
using Colors
128128
using FixedPointNumbers
129+
using DataFrames
130+
using Tables
131+
using CSV
129132

130133
include("fastaidatasets.jl")
131134

src/datasets/containers.jl

Lines changed: 37 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
2-
31
# FileDataset
42

53
struct FileDataset
@@ -45,5 +43,41 @@ isimagefile(file::File) = isimagefile(file.name)
4543
isimagefile(file::String) = occursin(IMAGEFILE_REGEX, lowercase(file))
4644
const IMAGEFILE_REGEX = r"\.(gif|jpe?g|tiff?|png|webp|bmp)$"
4745

46+
#TableDataset
47+
48+
struct TableDataset{T}
49+
table::T #Should implement Tables.jl interface
50+
TableDataset{T}(table::T) where T = Tables.istable(table) ? new{T}(table) : error("Object doesn't implement Tables.jl interface")
51+
end
52+
53+
TableDataset(table::T) where {T} = TableDataset{T}(table)
54+
TableDataset(path::AbstractPath) = TableDataset(DataFrame(CSV.File(path)))
55+
56+
function LearnBase.getobs(dataset::FastAI.Datasets.TableDataset, idx)
57+
if Tables.rowaccess(dataset.table)
58+
row, _ = Iterators.peel(Iterators.drop(Tables.rows(dataset.table), idx - 1))
59+
return row
60+
elseif Tables.columnaccess(dataset.table)
61+
colnames = Tables.columnnames(dataset.table)
62+
rowvals = [Tables.getcolumn(dataset.table, i)[idx] for i in 1:length(colnames)]
63+
return (; zip(colnames, rowvals)...)
64+
else
65+
error("The Tables.jl implementation used should have either rowaccess or columnaccess.")
66+
end
67+
end
68+
69+
function LearnBase.nobs(dataset::TableDataset)
70+
if Tables.columnaccess(dataset.table)
71+
return length(Tables.getcolumn(dataset.table, 1))
72+
elseif Tables.rowaccess(dataset.table)
73+
return length(Tables.rows(dataset.table)) # length might not be defined, but has to be for this to work.
74+
else
75+
error("The Tables.jl implementation used should have either rowaccess or columnaccess.")
76+
end
77+
end
78+
79+
LearnBase.getobs(dataset::TableDataset{<:DataFrame}, idx) = dataset.table[idx, :]
80+
LearnBase.nobs(dataset::TableDataset{<:DataFrame}) = nrow(dataset.table)
4881

49-
## TODO: TableDataset
82+
LearnBase.getobs(dataset::TableDataset{<:CSV.File}, idx) = dataset.table[idx]
83+
LearnBase.nobs(dataset::TableDataset{<:CSV.File}) = length(dataset.table)

test/Project.toml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,15 @@
11
[deps]
2+
CSV = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b"
23
Colors = "5ae59095-9a9b-59fe-a467-6f913c188581"
34
DLPipelines = "e6530d7c-7faa-4ede-a0d6-9eff9baad396"
45
DataAugmentation = "88a5189c-e7ff-4f85-ac6b-e6158070f02e"
6+
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
57
FilePathsBase = "48062228-2e41-5def-b9a4-89aafe57970f"
68
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
9+
JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819"
10+
Makie = "ee78f7c6-11fb-53f2-987a-cfe4a2b5a57a"
11+
ShowCases = "605ecd9f-84a6-4c9e-81e2-4798472b76a3"
712
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
13+
Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"
814
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
915
TestSetExtensions = "98d24dd4-01ad-11ea-1b02-c9a08f80db04"

test/datasets/containers.jl

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
include("../imports.jl")
2+
3+
@testset ExtendedTestSet "TableDataset" begin
4+
5+
@testset ExtendedTestSet "TableDataset from rowaccess table" begin
6+
Tables.columnaccess(::Type{<:Tables.MatrixTable}) = false
7+
Tables.rowaccess(::Type{<:Tables.MatrixTable}) = true
8+
9+
testtable = Tables.table([1 4.0 "7"; 2 5.0 "8"; 3 6.0 "9"])
10+
td = TableDataset(testtable)
11+
12+
@test all(getobs(td, 1) .== [1, 4.0, "7"])
13+
@test nobs(td) == 3
14+
end
15+
16+
@testset ExtendedTestSet "TableDataset from columnaccess table" begin
17+
Tables.columnaccess(::Type{<:Tables.MatrixTable}) = true
18+
Tables.rowaccess(::Type{<:Tables.MatrixTable}) = false
19+
20+
testtable = Tables.table([1 4.0 "7"; 2 5.0 "8"; 3 6.0 "9"])
21+
td = TableDataset(testtable)
22+
23+
@test [data for data in getobs(td, 2)] == [2, 5.0, "8"]
24+
@test nobs(td) == 3
25+
26+
@test getobs(td, 1) isa NamedTuple
27+
end
28+
29+
@testset ExtendedTestSet "TableDataset from DataFrames" begin
30+
testtable = DataFrame(
31+
col1=[1, 2, 3, 4, 5],
32+
col2=["a", "b", "c", "d", "e"],
33+
col3=[10, 20, 30, 40, 50],
34+
col4=["A", "B", "C", "D", "E"],
35+
col5=[100., 200., 300., 400., 500.],
36+
split=["train", "train", "train", "valid", "valid"]
37+
)
38+
td = TableDataset(testtable)
39+
@test td isa TableDataset{<:DataFrame}
40+
41+
@test [data for data in getobs(td, 1)] == [1, "a", 10, "A", 100., "train"]
42+
@test nobs(td) == 5
43+
end
44+
45+
@testset ExtendedTestSet "TableDataset from CSV" begin
46+
open("test.csv", "w") do io
47+
write(io, "col1,col2,col3,col4,col5, split\n1,a,10,A,100.,train")
48+
end
49+
testtable = CSV.File("test.csv")
50+
td = TableDataset(testtable)
51+
@test td isa TableDataset{<:CSV.File}
52+
@test [data for data in getobs(td, 1)] == [1,
53+
"a",
54+
10,
55+
"A",
56+
100.,
57+
"train"]
58+
@test nobs(td) == 1
59+
rm("test.csv")
60+
end
61+
end

test/imports.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,9 @@ using Flux.Optimise: Optimiser, apply!
1212
using StaticArrays
1313
using Test
1414
using TestSetExtensions
15+
using DataFrames
16+
using Tables
17+
using CSV
1518

1619
ENV["DATADEPS_ALWAYS_ACCEPT"] = "true"
1720
include("testdata.jl")

0 commit comments

Comments
 (0)