Skip to content

added TableDataset #26

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 24 commits into from
Jun 15, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
19a77ac
added TableDataset
manikyabard Apr 5, 2021
81369e8
Update src/datasets/containers.jl
manikyabard Apr 6, 2021
8eaa656
changed append to push in getobs for TableDataset
manikyabard Apr 7, 2021
b9c65eb
updated path type, added direct methods for DataFrames and CSV.File, …
manikyabard Apr 21, 2021
8585fc7
added TableDataset
manikyabard Apr 5, 2021
dc4b007
Update src/datasets/containers.jl
manikyabard Apr 6, 2021
d483b6d
changed append to push in getobs for TableDataset
manikyabard Apr 7, 2021
52a456d
updated path type, added direct methods for DataFrames and CSV.File, …
manikyabard Apr 21, 2021
1d52456
Merge branch 'manikyabard/table_container' of https://github.com/mani…
manikyabard May 30, 2021
134f1bc
Update src/datasets/containers.jl
manikyabard May 30, 2021
a2f6261
fixed nobs typo
manikyabard May 30, 2021
d30c7f8
removed redundant line
manikyabard May 30, 2021
81b36dd
added tests and made getobs consistent
manikyabard May 31, 2021
9b36af1
fixed typo
manikyabard May 31, 2021
77d9d2b
changed getobs back for DataFrame
manikyabard Jun 1, 2021
519a3a7
Merge branch 'manikyabard/table_container' of https://github.com/mani…
manikyabard Jun 1, 2021
6f00d25
Updated getobs
manikyabard Jun 5, 2021
2462546
Changed getobs to return NamedTuple
manikyabard Jun 6, 2021
d40c1f8
Update test/datasets/containers.jl
manikyabard Jun 9, 2021
77459fa
Updated tests for TableDataset container.
manikyabard Jun 9, 2021
c0f3fcc
remove Manifest.toml
manikyabard Jun 15, 2021
d92e32b
fixed path in test
manikyabard Jun 15, 2021
6161b0a
updated csv TableDataset test
manikyabard Jun 15, 2021
aa9f9af
removed old csv testcase
manikyabard Jun 15, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,12 @@ version = "0.1.0"
[deps]
Animations = "27a7e980-b3e6-11e9-2bcd-0b925532e340"
BSON = "fbb218c0-5317-5bc6-957e-2ee96dd4b1f0"
CSV = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b"
Colors = "5ae59095-9a9b-59fe-a467-6f913c188581"
DLPipelines = "e6530d7c-7faa-4ede-a0d6-9eff9baad396"
DataAugmentation = "88a5189c-e7ff-4f85-ac6b-e6158070f02e"
DataDeps = "124859b0-ceae-595e-8997-d05f6a7a8dfe"
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
DataLoaders = "2e981812-ef13-4a9c-bfa0-ab13047b12a9"
FileIO = "5789e2e9-d7fb-5bc7-8068-2c6fae9b9549"
FilePathsBase = "48062228-2e41-5def-b9a4-89aafe57970f"
Expand Down
3 changes: 3 additions & 0 deletions src/datasets/Datasets.jl
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,9 @@ using MLDataPattern: splitobs
import LearnBase
using Colors
using FixedPointNumbers
using DataFrames
using Tables
using CSV

include("fastaidatasets.jl")

Expand Down
40 changes: 37 additions & 3 deletions src/datasets/containers.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@


# FileDataset

struct FileDataset
Expand Down Expand Up @@ -45,5 +43,41 @@ isimagefile(file::File) = isimagefile(file.name)
# Pattern matching the file extensions treated as images. It is anchored to the end
# of the name and written in lowercase, so callers must lowercase the name first.
const IMAGEFILE_REGEX = r"\.(gif|jpe?g|tiff?|png|webp|bmp)$"

"""
    isimagefile(file::String)

Return `true` if the lowercased `file` name ends in one of the image extensions
matched by `IMAGEFILE_REGEX`, and `false` otherwise.
"""
function isimagefile(file::String)
    name = lowercase(file)
    return occursin(IMAGEFILE_REGEX, name)
end

#TableDataset

"""
    TableDataset{T}(table::T)

Data container wrapping a `table` that implements the Tables.jl interface.

Construction fails with an error if `Tables.istable(table)` is `false`.
"""
struct TableDataset{T}
    table::T  # wrapped table; validated with `Tables.istable` on construction

    function TableDataset{T}(table::T) where {T}
        if !Tables.istable(table)
            error("Object doesn't implement Tables.jl interface")
        end
        return new{T}(table)
    end
end

# Convenience constructor: infer the type parameter from the wrapped table.
function TableDataset(table::T) where {T}
    return TableDataset{T}(table)
end

# Build a dataset from a path by reading it with `CSV.File` and materializing the
# result as a `DataFrame`.
function TableDataset(path::AbstractPath)
    csvfile = CSV.File(path)
    frame = DataFrame(csvfile)
    return TableDataset(frame)
end

"""
    getobs(dataset::TableDataset, idx)

Return the observation (row) at position `idx` of `dataset.table`.

If the table reports `Tables.rowaccess`, the `idx`-th element of `Tables.rows` is
returned as-is. Otherwise, if it reports `Tables.columnaccess`, the row is assembled
into a `NamedTuple` keyed by column name.

Throws a `BoundsError` when a row-access table yields fewer than `idx` rows, and an
error when the table supports neither access pattern.
"""
function LearnBase.getobs(dataset::TableDataset, idx)
    table = dataset.table
    if Tables.rowaccess(table)
        # The row iterator is not guaranteed to be indexable, so skip ahead lazily
        # and take the next element.
        next = iterate(Iterators.drop(Tables.rows(table), idx - 1))
        next === nothing && throw(BoundsError(table, idx))
        return first(next)
    elseif Tables.columnaccess(table)
        # `columnnames`/`getcolumn` are part of the interface of the object returned
        # by `Tables.columns`, not necessarily of the table itself.
        cols = Tables.columns(table)
        colnames = Tables.columnnames(cols)
        return (; (name => Tables.getcolumn(cols, name)[idx] for name in colnames)...)
    else
        error("The Tables.jl implementation used should have either rowaccess or columnaccess.")
    end
end

"""
    nobs(dataset::TableDataset)

Return the number of observations (rows) in `dataset.table`.

For tables reporting `Tables.columnaccess` this is the length of the first column
(or `0` when the table has no columns). For tables reporting only `Tables.rowaccess`
it is `length(Tables.rows(table))`, which requires the row iterator to define
`length`. Throws an error when the table supports neither access pattern.
"""
function LearnBase.nobs(dataset::TableDataset)
    table = dataset.table
    if Tables.columnaccess(table)
        # `columnnames`/`getcolumn` are part of the interface of the object returned
        # by `Tables.columns`, not necessarily of the table itself.
        cols = Tables.columns(table)
        # A table without columns has no observations; there is no column 1 to measure.
        isempty(Tables.columnnames(cols)) && return 0
        return length(Tables.getcolumn(cols, 1))
    elseif Tables.rowaccess(table)
        # length might not be defined, but has to be for this to work.
        return length(Tables.rows(table))
    else
        error("The Tables.jl implementation used should have either rowaccess or columnaccess.")
    end
end

# Specializations for `DataFrame`-backed datasets: index the frame directly instead
# of going through the generic Tables.jl code path.
function LearnBase.getobs(dataset::TableDataset{<:DataFrame}, idx)
    frame = dataset.table
    return frame[idx, :]
end

function LearnBase.nobs(dataset::TableDataset{<:DataFrame})
    frame = dataset.table
    return nrow(frame)
end

## TableDataset methods specialized for CSV.File
# Specializations for `CSV.File`-backed datasets: index the file object directly and
# use its `length` as the observation count.
function LearnBase.getobs(dataset::TableDataset{<:CSV.File}, idx)
    csvfile = dataset.table
    return csvfile[idx]
end

function LearnBase.nobs(dataset::TableDataset{<:CSV.File})
    csvfile = dataset.table
    return length(csvfile)
end
6 changes: 6 additions & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
@@ -1,9 +1,15 @@
[deps]
CSV = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b"
Colors = "5ae59095-9a9b-59fe-a467-6f913c188581"
DLPipelines = "e6530d7c-7faa-4ede-a0d6-9eff9baad396"
DataAugmentation = "88a5189c-e7ff-4f85-ac6b-e6158070f02e"
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
FilePathsBase = "48062228-2e41-5def-b9a4-89aafe57970f"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819"
Makie = "ee78f7c6-11fb-53f2-987a-cfe4a2b5a57a"
ShowCases = "605ecd9f-84a6-4c9e-81e2-4798472b76a3"
Comment on lines +10 to +11
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These seem orthogonal, maybe @lorenzoh can double-check?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

JLD2, Makie and ShowCases are already in the master branch, not sure why they're shown here. Might just be that the list was sorted on a ]pkg command.

StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
TestSetExtensions = "98d24dd4-01ad-11ea-1b02-c9a08f80db04"
61 changes: 61 additions & 0 deletions test/datasets/containers.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
include("../imports.jl")

# Tests for the `TableDataset` container: generic row-access and column-access
# tables, plus the `DataFrame` and `CSV.File` specializations.
@testset ExtendedTestSet "TableDataset" begin

    @testset ExtendedTestSet "TableDataset from rowaccess table" begin
        # Force the generic row-access code path by overriding the access traits of
        # `Tables.MatrixTable`.
        # NOTE(review): these overrides extend Tables.jl functions for a Tables.jl
        # type and stay in effect after this testset — consider a dedicated test
        # table type instead.
        Tables.columnaccess(::Type{<:Tables.MatrixTable}) = false
        Tables.rowaccess(::Type{<:Tables.MatrixTable}) = true

        testtable = Tables.table([1 4.0 "7"; 2 5.0 "8"; 3 6.0 "9"])
        td = TableDataset(testtable)

        @test all(getobs(td, 1) .== [1, 4.0, "7"])
        @test nobs(td) == 3
    end

    @testset ExtendedTestSet "TableDataset from columnaccess table" begin
        # Flip the traits back so the generic column-access code path is exercised.
        Tables.columnaccess(::Type{<:Tables.MatrixTable}) = true
        Tables.rowaccess(::Type{<:Tables.MatrixTable}) = false

        testtable = Tables.table([1 4.0 "7"; 2 5.0 "8"; 3 6.0 "9"])
        td = TableDataset(testtable)

        @test [data for data in getobs(td, 2)] == [2, 5.0, "8"]
        @test nobs(td) == 3

        @test getobs(td, 1) isa NamedTuple
    end

    @testset ExtendedTestSet "TableDataset from DataFrames" begin
        testtable = DataFrame(
            col1=[1, 2, 3, 4, 5],
            col2=["a", "b", "c", "d", "e"],
            col3=[10, 20, 30, 40, 50],
            col4=["A", "B", "C", "D", "E"],
            col5=[100., 200., 300., 400., 500.],
            split=["train", "train", "train", "valid", "valid"]
        )
        td = TableDataset(testtable)
        @test td isa TableDataset{<:DataFrame}

        @test [data for data in getobs(td, 1)] == [1, "a", 10, "A", 100., "train"]
        @test nobs(td) == 5
    end

    @testset ExtendedTestSet "TableDataset from CSV" begin
        # Write the fixture into a temporary directory so that it is cleaned up even
        # when parsing or construction throws, and the working directory stays clean.
        mktempdir() do dir
            csvpath = joinpath(dir, "test.csv")
            write(csvpath, "col1,col2,col3,col4,col5,split\n1,a,10,A,100.,train")

            testtable = CSV.File(csvpath)
            td = TableDataset(testtable)
            @test td isa TableDataset{<:CSV.File}
            @test [data for data in getobs(td, 1)] == [1, "a", 10, "A", 100., "train"]
            @test nobs(td) == 1
        end
    end
end
3 changes: 3 additions & 0 deletions test/imports.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@ using Flux.Optimise: Optimiser, apply!
using StaticArrays
using Test
using TestSetExtensions
using DataFrames
using Tables
using CSV

ENV["DATADEPS_ALWAYS_ACCEPT"] = "true"
include("testdata.jl")