Skip to content

Commit

Permalink
Merge pull request #1874 from FluxML/cl/mlutils
Browse files Browse the repository at this point in the history
use MLUtils
  • Loading branch information
CarloLucibello authored Feb 18, 2022
2 parents 1f3915d + bdbcaaa commit 13a65be
Show file tree
Hide file tree
Showing 12 changed files with 46 additions and 420 deletions.
2 changes: 2 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ been removed in favour of MLDatasets.jl.
* `flatten` is not exported anymore due to clash with Iterators.flatten.
* Remove Juno.jl progress bar support as it is now obsolete.
* `Dropout` gained improved compatibility with Int and Complex arrays and is now twice-differentiable.
* Many utily functions and the `DataLoader` are [now provided by MLUtils.jl](https://github.com/FluxML/Flux.jl/pull/1874).
* The DataLoader is now compatible with generic dataset types implementing `MLUtils.numobs` and `MLUtils.getobs`.

## v0.12.10
* `Dropout`/`AlphaDropout` now supports [user-specified RNGs](https://github.com/FluxML/Flux.jl/pull/1838)
Expand Down
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
NNlibCUDA = "a00861dc-f156-4864-bf3c-e6376f28a68d"
Expand All @@ -25,6 +26,7 @@ Adapt = "3.0"
ArrayInterface = "3.1, 4"
CUDA = "3"
Functors = "0.2.1"
MLUtils = "0.1.4"
MacroTools = "0.5"
NNlib = "0.8.2"
NNlibCUDA = "0.2"
Expand Down
5 changes: 4 additions & 1 deletion docs/src/models/nnlib.md
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,10 @@ NNlib.gather
NNlib.gather!
NNlib.scatter
NNlib.scatter!
```

## Miscellaneous

## Utilities
```@docs
NNlib.logsumexp
```
18 changes: 10 additions & 8 deletions docs/src/utilities.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,17 @@ callback functions.

## Working with Data

Utilities for data processing are provided by [MLUtils.jl](https://github.com/JuliaML/MLUtils.jl). Below is a non-exhaustive list.

```@docs
Flux.unsqueeze
Flux.stack
Flux.unstack
Flux.chunk
Flux.frequencies
Flux.batch
Flux.unbatch
Flux.batchseq
MLUtils.unsqueeze
MLUtils.stack
MLUtils.unstack
MLUtils.chunk
MLUtils.group_counts
MLUtils.batch
MLUtils.unbatch
MLUtils.batchseq
Base.rpad(v::AbstractVector, n::Integer, p)
```

Expand Down
4 changes: 4 additions & 0 deletions src/Flux.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@ using Statistics, Random, LinearAlgebra
using Zygote, MacroTools, ProgressLogging, Reexport
using MacroTools: @forward
@reexport using NNlib

using MLUtils

using Zygote: Params, @adjoint, gradient, pullback, @nograd
export gradient

Expand Down Expand Up @@ -50,6 +53,7 @@ include("outputsize.jl")
include("data/Data.jl")
using .Data


include("losses/Losses.jl")
using .Losses # TODO: stop importing Losses in Flux's namespace in v0.12

Expand Down
5 changes: 1 addition & 4 deletions src/data/Data.jl
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
module Data

using Random: shuffle!
using Base: @propagate_inbounds

include("dataloader.jl")
using MLUtils
export DataLoader

end#module
121 changes: 0 additions & 121 deletions src/data/dataloader.jl

This file was deleted.

35 changes: 0 additions & 35 deletions src/data/tree.jl

This file was deleted.

1 change: 1 addition & 0 deletions src/deprecations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,4 @@ ones32(::Type, dims...) = throw(ArgumentError("Flux.ones32 is always Float32, us
zeros32(::Type, dims...) = throw(ArgumentError("Flux.zeros32 is always Float32, use Base.zeros to specify the element type"))

# v0.13 deprecations
@deprecate frequencies(xs) group_counts(xs)
2 changes: 1 addition & 1 deletion src/onehot.jl
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ Base.hcat(x::T, xs::T...) where {L, T <: OneHotLike{<:Any, L, <:Any, 2}} =
Base.hcat(x::T, xs::T...) where {L, T <: OneHotLike{<:Any, L, <:Any, 1}} =
OneHotMatrix(reduce(vcat, _indices.(xs); init = _indices(x)), L)

batch(xs::AbstractArray{<:OneHotVector{<:Any, L}}) where L = OneHotArray(_indices.(xs), L)
MLUtils.batch(xs::AbstractArray{<:OneHotVector{<:Any, L}}) where L = OneHotMatrix(_indices.(xs), L)

Adapt.adapt_structure(T, x::OneHotArray{<:Any, L}) where L = OneHotArray(adapt(T, _indices(x)), L)

Expand Down
Loading

0 comments on commit 13a65be

Please sign in to comment.