Skip to content

Commit 9b92a28

Browse files
committed
Merge branch 'manikyabard/tableblock' into manikyabard/tabularmethods
2 parents 95873cb + 478036c commit 9b92a28

File tree

7 files changed

+102
-3
lines changed

7 files changed

+102
-3
lines changed

src/FastAI.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
module FastAI
22

33

4+
using Base: NamedTuple
5+
using Colors: colormaps_sequential
46
using Reexport
57
@reexport using DLPipelines
68
@reexport using FluxTraining
@@ -47,6 +49,7 @@ include("datablock/checks.jl")
4749
include("datablock/wrappers.jl")
4850

4951
# Encodings
52+
include("encodings/tabularpreprocessing.jl")
5053
include("encodings/onehot.jl")
5154
include("encodings/imagepreprocessing.jl")
5255
include("encodings/projective.jl")
@@ -120,6 +123,8 @@ export
120123
Label,
121124
LabelMulti,
122125
Keypoints,
126+
TableRow,
127+
Continuous,
123128

124129
# encodings
125130
encode,
@@ -131,6 +136,7 @@ export
131136
Only,
132137
Named,
133138
augs_projection, augs_lighting,
139+
TabularTransform,
134140

135141
BlockMethod,
136142
describemethod,

src/datablock/block.jl

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -155,3 +155,30 @@ function checkblock(
155155
end
156156

157157
"""
    mockblock(block::Keypoints{N})

Create random mock data for `block` by calling `rand` with element type
`SVector{N, Float32}` and dimensions `block.sz`.
"""
function mockblock(block::Keypoints{N}) where {N}
    return rand(SVector{N, Float32}, block.sz)
end
158+
159+
160+
# TableRow
161+
162+
"""
    TableRow{M, N}(catcols, contcols, categorydict)

Block describing a single row of a table. The three-argument outer
constructor `TableRow(catcols, contcols, categorydict)` sets `M` to
`length(catcols)` and `N` to `length(contcols)`.

# Fields
- `catcols`: columns whose values are looked up in `categorydict`.
- `contcols`: columns whose values are expected to be `Number`s.
- `categorydict`: indexed by the columns in `catcols`; each entry holds the
  values allowed for that column.
"""
struct TableRow{M, N} <: Block
    catcols::Any
    contcols::Any
    categorydict::Any
end
167+
168+
"""
    TableRow(catcols, contcols, categorydict)

Build a `TableRow{M, N}` whose type parameters are the number of entries in
`catcols` (`M`) and in `contcols` (`N`).
"""
function TableRow(catcols, contcols, categorydict)
    ncat, ncont = length(catcols), length(contcols)
    return TableRow{ncat, ncont}(catcols, contcols, categorydict)
end
171+
172+
"""
    checkblock(block::TableRow, x)

Return `true` if row `x` is valid data for `block`, i.e.

- every column in `block.catcols` has an entry in `block.categorydict` and
  `x[col]` is one of the values stored under that entry, and
- `x[col] isa Number` for every column in `block.contcols`.

Return `false` otherwise.
"""
function checkblock(block::TableRow, x)
    # Test `haskey` before indexing `categorydict`, so a column without a
    # registered category set yields `false` rather than a `KeyError`.
    catsvalid = all(block.catcols) do col
        haskey(block.categorydict, col) && x[col] ∈ block.categorydict[col]
    end
    return catsvalid && all(col -> x[col] isa Number, block.contcols)
end
177+
178+
"""
    Continuous(n)

Block with a single field `n`. `checkblock` accepts data `x` for this block
when `length(x)` equals `n`.
"""
struct Continuous <: Block
    n::Any
end
181+
182+
"""
    checkblock(block::Continuous, x)

Return whether `block.n` equals `length(x)`.
"""
checkblock(block::Continuous, x) = block.n == length(x)

src/datablock/loss.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,3 +42,8 @@ function blocklossfn(outblock::KeypointTensor{N}, yblock::KeypointTensor{N}) whe
4242
outblock.sz == yblock.sz || error("Sizes of $outblock and $yblock differ!")
4343
return Flux.Losses.mse
4444
end
45+
46+
"""
    blocklossfn(outblock::Continuous, yblock::Continuous)

Return `Flux.Losses.mse` as the loss function for a pair of `Continuous`
blocks. Raises an error if the two blocks have different `n`.
"""
function blocklossfn(outblock::Continuous, yblock::Continuous)
    if outblock.n != yblock.n
        error("Sizes of $outblock and $yblock differ!")
    end
    return Flux.Losses.mse
end

src/datablock/method.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ struct BlockMethod{B, E, O} <: LearningMethod
1010
outputblock::O
1111
end
1212

13-
function BlockMethod(blocks, encodings; outputblock = encodedblock(encodings, blocks[2]))
13+
function BlockMethod(blocks, encodings; outputblock = isnothing(encodedblock(encodings, blocks[2])) ? blocks[2] : encodedblock(encodings, blocks[2]))
1414
return BlockMethod(blocks, encodings, outputblock)
1515
end
1616

@@ -47,12 +47,12 @@ end
4747

4848
function methodmodel(method::BlockMethod, backbone)
4949
xblock = encodedblock(method.encodings, method.blocks[1])
50-
return blockmodel(xblock, method.outputblock, backbone)
50+
return blockmodel(isnothing(xblock) ? method.blocks[1] : xblock, method.outputblock, backbone)
5151
end
5252

5353
function methodlossfn(method::BlockMethod)
5454
yblock = encodedblock(method.encodings, method.blocks[2])
55-
return blocklossfn(method.outputblock, yblock)
55+
return blocklossfn(method.outputblock, isnothing(yblock) ? method.blocks[2] : yblock)
5656
end
5757

5858
# Testing interface

src/datablock/models.jl

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,3 +59,27 @@ function blockmodel(inblock::ImageTensor{N}, outblock::KeypointTensor{N}, backbo
5959
head = Models.visionhead(outch, prod(outblock.sz)*N, p = 0.)
6060
return Chain(backbone, head)
6161
end
62+
63+
"""
64+
blockmodel(inblock::TableRow{M, N}, outblock::Union{Continuous, OneHotTensor{0}}, backbone=nothing) where {M, N}
65+
66+
Construct a model for tabular classification or regression. `backbone` should
67+
either be `nothing` or a tuple of categoricalbackbone, continuousbackbone,
68+
and a classifierbackbone, with the first two taking in batches of corresponding
69+
row value matrices.
70+
"""
71+
72+
# function blockmodel(
73+
# inblock::EncodedTableRow{M, N},
74+
# outblock::Union{Continuous, OneHotTensor{0}},
75+
# backbone=nothing;) where {M, N}
76+
# outsz = outblock isa Continuous ? outblock.n : length(outblock.classes)
77+
# if isnothing(backbone)
78+
# TabularModel(inblock.catcols, N, outsz; catdict = inblock.categorydict)
79+
# else
80+
# TabularModel(backbone[1], backbone[2], backbone[3])
81+
# end
82+
# end
83+
84+
85+

src/encodings/onehot.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ function checkblock(block::OneHotTensorMulti{N}, a::AbstractArray{T, M}) where {
2424
return N + 1 == M && last(size(a)) == length(block.classes)
2525
end
2626

27+
2728
"""
2829
OneHot()
2930
OneHot(T, threshold)

src/encodings/tabularpreprocessing.jl

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
"""
    EncodedTableRow{M, N}(catcols, contcols, categorydict)

Block for a table row after encoding. It carries the same three fields as
`TableRow`; the three-argument outer constructor sets `M` to
`length(catcols)` and `N` to `length(contcols)`.
"""
struct EncodedTableRow{M, N} <: Block
    catcols::Any
    contcols::Any
    categorydict::Any
end
6+
7+
"""
    EncodedTableRow(catcols, contcols, categorydict)

Build an `EncodedTableRow{M, N}` whose type parameters are the number of
entries in `catcols` (`M`) and in `contcols` (`N`).
"""
function EncodedTableRow(catcols, contcols, categorydict)
    ncat, ncont = length(catcols), length(contcols)
    return EncodedTableRow{ncat, ncont}(catcols, contcols, categorydict)
end
10+
11+
"""
    checkblock(::EncodedTableRow{M, N}, x)

Return `true` when `x[1]` has length `M` and `x[2]` has length `N`.
`x[2]` is only inspected if the first check passes.
"""
function checkblock(::EncodedTableRow{M, N}, x) where {M, N}
    length(x[1]) == M || return false
    return length(x[2]) == N
end
14+
15+
"""
    TabularTransform(tfms)

Encoding that stores `tfms`, the transformation that `encode` hands to
`DataAugmentation.apply` when encoding a `TableRow`.
"""
struct TabularTransform <: Encoding
    tfms::Any
end
18+
19+
"""
    encodedblock(::TabularTransform, block::TableRow)

Return the `EncodedTableRow` built from the `catcols`, `contcols` and
`categorydict` of `block`.
"""
encodedblock(::TabularTransform, block::TableRow) =
    EncodedTableRow(block.catcols, block.contcols, block.categorydict)
22+
23+
"""
    encode(tt::TabularTransform, _, block::TableRow, row)

Encode a table `row` described by `block`.

The columns of `row` that appear in `block.catcols` or `block.contcols` are
collected into a `NamedTuple`, wrapped in a `DataAugmentation.TabularItem`
and transformed with `tt.tfms`. The second positional argument is ignored.

Returns the tuple `(catvals, contvals)`, holding the transformed values in
the order of `block.catcols` and `block.contcols` respectively.
"""
function encode(tt::TabularTransform, _, block::TableRow, row)
    columns = Tables.columnnames(row)
    # Drop every column the block does not declare as categorical or continuous.
    usedrow = NamedTuple(filter(
        x -> x[1] ∈ block.catcols || x[1] ∈ block.contcols,
        collect(zip(columns, row)),
    ))
    tfmrow = DataAugmentation.apply(
        tt.tfms,
        DataAugmentation.TabularItem(usedrow, keys(usedrow)),
    ).data
    catvals = map(col -> tfmrow[col], block.catcols)
    contvals = map(col -> tfmrow[col], block.contcols)
    return (catvals, contvals)
end

0 commit comments

Comments
 (0)