
Move domain-specific functionality to submodules #186


Merged · 12 commits · Dec 19, 2021
7 changes: 1 addition & 6 deletions .github/workflows/ci.yml
@@ -17,12 +17,7 @@ jobs:
- windows-latest
arch:
- x64
- x86
exclude:
- os: macOS-latest
arch: x86
- os: windows-latest
arch: x86

steps:
- uses: actions/checkout@v2
- uses: julia-actions/setup-julia@v1
15 changes: 13 additions & 2 deletions CHANGELOG.md
@@ -5,7 +5,18 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## v0.3.0
## v0.4.0 (Unreleased)

### Added

### Changed

- (INTERNAL) Domain-specific functionality has moved to the submodules `FastAI.Vision` (computer vision) and `FastAI.Tabular` (tabular data). The exports of `FastAI` are not affected; see the short sketch after this file's diff.
- (INTERNAL) The test suite now runs on InlineTest.jl.

### Removed

## v0.3.0 (2021/12/11)

### Added

@@ -21,7 +32,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- The old visualization API incl. all its `plot*` methods: `plotbatch`, `plotsample`, `plotsamples`, `plotpredictions`


## 0.2.0
## 0.2.0 (2021/09/21)

### Added

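As referenced in the CHANGELOG bullet above: because `FastAI.Vision` and `FastAI.Tabular` are re-exported from the top-level module, user code that only relies on exported names should keep working unchanged. A minimal sketch of what that implies (the equalities are expectations after this PR, not assertions taken from the diff):

```julia
using FastAI

# Exported names still resolve through FastAI itself ...
block = Image{2}()

# ... while their definitions now live in the domain submodules,
# so both access paths refer to the same types:
FastAI.Image === FastAI.Vision.Image
FastAI.TableRow === FastAI.Tabular.TableRow
```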
4 changes: 2 additions & 2 deletions docs/background/blocksencodings.md
@@ -63,14 +63,14 @@ Where do we draw the line between model and data processing? In general, the enc
```julia
FastAI.testencoding(enc, Image{2}())
```
- The default implementation of `encodedblock` and `decodedblock` is to return `nothing`, indicating that the encoding doesn't transform the data. This is overridden for blocks for which `encode` and `decode` are implemented, to indicate that the data is transformed. Using `encodedblock(block, data, true)` will replace returned `nothing`s with the unchanged block.
- The default implementation of `encodedblock` and `decodedblock` is to return `nothing`, indicating that the encoding doesn't transform the data. This is overridden for blocks for which `encode` and `decode` are implemented, to indicate that the data is transformed. Using `encodedblockfilled(block, data)` will replace returned `nothing`s with the unchanged block.
{cell=main}
```julia
encodedblock(enc, Label(1:10)) === nothing
```
{cell=main}
```julia
encodedblock(enc, Label(1:10), true) == Label(1:10)
encodedblockfilled(enc, Label(1:10)) == Label(1:10)
```
- Encodings can be applied to tuples of blocks. The default behavior is to apply the encoding to each block separately.
{cell=main}
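The last bullet in this hunk is cut off by the diff view; as a rough illustration (not the original cell contents), applying an encoding to a tuple of blocks encodes each block separately. Assuming the behavior described above, `OneHot` only knows how to encode `Label`-like blocks, so the `Image` block would pass through unchanged:

```julia
using FastAI

# Only the Label block is transformed; the Image block is returned as-is.
FastAI.encodedblockfilled(OneHot(), (Image{2}(), Label(1:10)))
```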
78 changes: 42 additions & 36 deletions src/FastAI.jl
@@ -2,34 +2,29 @@ module FastAI


using Base: NamedTuple
using Colors: colormaps_sequential
using Reexport
@reexport using DLPipelines
@reexport using FluxTraining
@reexport using DataLoaders
@reexport using Flux

using Animations
using Colors
using DataAugmentation
using DataAugmentation: getbounds, Bounds
import DataAugmentation
import DataAugmentation: getbounds, Bounds

import DLPipelines: methoddataset, methodmodel, methodlossfn, methoddataloaders,
mockmodel, mocksample, predict, predictbatch, mockmodel, encode, encodeinput,
encodetarget, decodeŷ, decodey
using IndirectArrays: IndirectArray
using LearnBase: getobs, nobs
using FilePathsBase
using FixedPointNumbers
using Flux
using Flux.Optimise
import Flux.Optimise: apply!, Optimiser, WeightDecay
using FluxTraining: Learner, handle
using FluxTraining.Events
using JLD2: jldsave, jldopen
using Markdown
import ImageInTerminal
using MLDataPattern
using Parameters
using PrettyTables
using Requires
using StaticArrays
@@ -41,48 +36,47 @@ import UnicodePlots
using Statistics
using InlineTest

include("learner.jl")

# Data block API
# ## Data block API
include("datablock/block.jl")
include("datablock/encoding.jl")
include("datablock/method.jl")
include("datablock/describe.jl")
include("datablock/checks.jl")
include("datablock/wrappers.jl")

# submodules
include("datasets/Datasets.jl")
@reexport using .Datasets

include("models/Models.jl")
using .Models
# ## Blocks
# ### Wrapper blocks
include("blocks/many.jl")

# Blocks
# ### Other
include("blocks/continuous.jl")
include("blocks/label.jl")

include("blocks/bounded.jl")
# ## Encodings
# ### Wrapper encodings
include("encodings/only.jl")

# Encodings
include("encodings/tabularpreprocessing.jl")
# ### Other
include("encodings/onehot.jl")
include("encodings/imagepreprocessing.jl")
include("encodings/projective.jl")
include("encodings/keypointpreprocessing.jl")


# Training interface
include("datablock/models.jl")
include("datablock/loss.jl")


# Interpretation
include("interpretation/backend.jl")
include("interpretation/text.jl")
include("interpretation/detect.jl")
include("interpretation/method.jl")
include("interpretation/showinterpretable.jl")
include("interpretation/learner.jl")
include("interpretation/detect.jl")

# training

# Training
include("learner.jl")
include("training/paramgroups.jl")
include("training/discriminativelrs.jl")
include("training/utils.jl")
@@ -94,20 +88,41 @@ include("training/metrics.jl")
include("serialization.jl")



# submodules
include("datasets/Datasets.jl")
@reexport using .Datasets


include("fasterai/methodregistry.jl")
include("fasterai/learningmethods.jl")
include("fasterai/defaults.jl")



# Domain-specific
include("Vision/Vision.jl")
@reexport using .Vision
export Image
export Vision

include("Tabular/Tabular.jl")
@reexport using .Tabular


include("interpretation/makie/stub.jl")
function __init__()
@require Makie="ee78f7c6-11fb-53f2-987a-cfe4a2b5a57a" begin
using .Makie
include("interpretation/makie/recipes.jl")
include("interpretation/makie/showmakie.jl")
include("interpretation/makie/lrfind.jl")
end
end

module Models
using ..FastAI.Tabular: TabularModel
using ..FastAI.Vision.Models: xresnet18, xresnet50, UNetDynamic
end


export
@@ -133,11 +148,9 @@ export
predictbatch,

# blocks
Image,
Mask,

Label,
LabelMulti,
Keypoints,
Many,
TableRow,
Continuous,
@@ -146,10 +159,7 @@ export
encode,
decode,
setup,
ProjectiveTransforms,
ImagePreprocessing,
OneHot,
KeypointPreprocessing,
Only,
Named,
augs_projection, augs_lighting,
@@ -179,10 +189,6 @@ export

# learning methods
findlearningmethods,
ImageClassificationSingle,
ImageClassificationMulti,
ImageSegmentation,
ImageKeypointRegression,
TabularClassificationSingle,
TabularRegression,

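One detail worth calling out in the diff above: the small `module Models ... end` stub keeps the old `FastAI.Models` entry point working by pulling the moved definitions back in from their new submodule locations. A hedged sketch of the expected aliasing (illustrative only, not part of the PR):

```julia
using FastAI

# The legacy FastAI.Models namespace now just forwards to the domain
# submodules, so both access paths should refer to the same objects:
FastAI.Models.xresnet18 === FastAI.Vision.Models.xresnet18
FastAI.Models.TabularModel === FastAI.Tabular.TabularModel
```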
62 changes: 62 additions & 0 deletions src/Tabular/Tabular.jl
@@ -0,0 +1,62 @@
module Tabular


using ..FastAI
using ..FastAI:
# blocks
Block, WrapperBlock, AbstractBlock, OneHotTensor, OneHotTensorMulti, Label,
LabelMulti, wrapped, Continuous,
# encodings
Encoding, StatefulEncoding, OneHot,
# visualization
ShowText,
# other
FASTAI_METHOD_REGISTRY, registerlearningmethod!

# for tests
using ..FastAI: testencoding

# extending
import ..FastAI:
blockmodel, blockbackbone, blocklossfn, encode, decode, checkblock,
encodedblock, decodedblock, showblock!, mockblock, setup


import DataAugmentation
import DataFrames: DataFrame
import Flux: Embedding, Chain, Dropout, Dense, Parallel
import PrettyTables
import Requires: @require
import ShowCases: ShowCase
import Tables
import Statistics

using InlineTest


# Blocks
include("blocks/tablerow.jl")

# Encodings
include("encodings/tabularpreprocessing.jl")


include("models.jl")
include("learningmethods/classification.jl")
include("learningmethods/regression.jl")
include("recipes.jl")


function __init__()
_registerrecipes()
@require Makie="ee78f7c6-11fb-53f2-987a-cfe4a2b5a57a" begin
import .Makie
import .Makie: @recipe, @lift
import .FastAI: ShowMakie
include("makie.jl")
end
end

export TableRow, TabularPreprocessing, TabularClassificationSingle, TabularRegression

end
83 changes: 83 additions & 0 deletions src/Tabular/blocks/tablerow.jl
@@ -0,0 +1,83 @@


# TableRow

"""
TableRow{M, N}(catcols, contcols, categorydict) <: Block

`Block` for table rows with `M` categorical and `N` continuous columns. `data`
is valid if it satisfies the `AbstractRow` interface in Tables.jl, is indexable
by the elements of `catcols` and `contcols`, and its values are consistent with
those columns: categorical values must appear in `categorydict` and continuous
values must be numbers (`missing` is allowed for both).
"""
struct TableRow{M,N,T} <: Block
catcols::NTuple{M}
contcols::NTuple{N}
categorydict::T
end

function TableRow(catcols, contcols, categorydict)
TableRow{length(catcols),length(contcols)}(catcols, contcols, categorydict)
end

function checkblock(block::TableRow, x)
columns = Tables.columnnames(x)
(
all(col -> col ∈ columns, (block.catcols..., block.contcols...)) &&
all(
col ->
haskey(block.categorydict, col) &&
(ismissing(x[col]) || x[col] ∈ block.categorydict[col]),
block.catcols,
) &&
all(col -> ismissing(x[col]) || x[col] isa Number, block.contcols)
)
end

function mockblock(block::TableRow)
cols = (block.catcols..., block.contcols...)
vals = map(cols) do col
col in block.catcols ? rand(block.categorydict[col]) : rand()
end
return NamedTuple(zip(cols, vals))
end

"""
setup(TableRow, data[; catcols, contcols])

Create a `TableRow` block from data container `data::TableDataset`. If the
categorical and continuous columns are not specified manually, try to
guess them from the dataset's column types.
"""
function setup(::Type{TableRow}, data; catcols = nothing, contcols = nothing)
catcols_, contcols_ = getcoltypes(data)
catcols = isnothing(catcols) ? catcols_ : catcols
contcols = isnothing(contcols) ? contcols_ : contcols

return TableRow(
catcols,
contcols,
gettransformdict(data, DataAugmentation.Categorify, catcols),
)
end

function Base.show(io::IO, block::TableRow)
print(io, ShowCase(block, (:catcols, :contcols), show_params = false, new_lines = true))
end


# ## Interpretation

function showblock!(io, ::ShowText, block::TableRow, obs)
rowdata = vcat(
[obs[col] for col in block.catcols],
[obs[col] for col in block.contcols],
)
rownames = [block.catcols..., block.contcols...]
tabledata = hcat(rownames, rowdata)
PrettyTables.pretty_table(
io, tabledata;
alignment=[:r, :l],
highlighters=PrettyTables.Highlighter((obs, i, j) -> (j == 2), bold=true),
noheader=true, tf=PrettyTables.tf_borderless,)
end
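To make the `TableRow` block above concrete, here is a minimal usage sketch. The column names, category dictionary, and example row are invented for illustration; only `TableRow`, `checkblock`, and `mockblock` come from the code in this diff:

```julia
using FastAI

# Hypothetical schema: two categorical columns and one continuous column.
catcols = (:color, :size)
contcols = (:price,)
categorydict = Dict(:color => ["red", "green", "blue"], :size => ["S", "M", "L"])

block = TableRow(catcols, contcols, categorydict)

# A NamedTuple satisfies the Tables.jl row interface, so it is a valid
# observation as long as its values are consistent with `categorydict`:
row = (color = "red", size = "M", price = 9.99)
FastAI.checkblock(block, row)    # expected to return true

# `mockblock` generates a random but valid row with the same columns:
FastAI.mockblock(block)
```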