Skip to content

Commit 1dc9023

Browse files
bors[bot]Dsantra92
andauthored
Merge #1256
1256: Updated onehot.jl r=CarloLucibello a=Dsantra92 @janEbert @CarloLucibello Updated doctstrings and doctests for onehot.jl and onehot.md Co-authored-by: Dsantra92 <santradibbo@gmail.com> Co-authored-by: Deeptendu Santra <santradibbo@gmail.com>
2 parents e3e5e54 + 319fefb commit 1dc9023

File tree

2 files changed

+22
-21
lines changed

2 files changed

+22
-21
lines changed

docs/src/data/onehot.md

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
It's common to encode categorical variables (like `true`, `false` or `cat`, `dog`) in "one-of-k" or ["one-hot"](https://en.wikipedia.org/wiki/One-hot) form. Flux provides the `onehot` function to make this easy.
44

5-
```
5+
```jldoctest onehot
66
julia> using Flux: onehot, onecold
77
88
julia> onehot(:b, [:a, :b, :c])
@@ -20,7 +20,7 @@ julia> onehot(:c, [:a, :b, :c])
2020

2121
The inverse is `onecold` (which can take a general probability distribution, as well as just booleans).
2222

23-
```julia
23+
```jldoctest onehot
2424
julia> onecold(ans, [:a, :b, :c])
2525
:c
2626
@@ -40,20 +40,20 @@ Flux.onecold
4040

4141
`onehotbatch` creates a batch (matrix) of one-hot vectors, and `onecold` treats matrices as batches.
4242

43-
```julia
43+
```jldoctest onehot
4444
julia> using Flux: onehotbatch
4545
4646
julia> onehotbatch([:b, :a, :b], [:a, :b, :c])
47-
3×3 Flux.OneHotMatrix:
48-
false true false
49-
true false true
50-
false false false
51-
52-
julia> onecold(ans, [:a, :b, :c])
53-
3-element Array{Symbol,1}:
54-
:b
55-
:a
56-
:b
47+
3×3 Flux.OneHotMatrix{Array{Flux.OneHotVector,1}}:
48+
0 1 0
49+
1 0 1
50+
0 0 0
51+
52+
julia> onecold(ans, [:a, :b, :c])
53+
3-element Array{Symbol,1}:
54+
:b
55+
:a
56+
:b
5757
```
5858

5959
Note that these operations returned `OneHotVector` and `OneHotMatrix` rather than `Array`s. `OneHotVector`s behave like normal vectors but avoid any unnecessary cost compared to using an integer index directly. For example, multiplying a matrix with a one-hot vector simply slices out the relevant row of the matrix under the hood.

src/onehot.jl

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -51,10 +51,11 @@ cudaconvert(x::OneHotMatrix{<:CuArray}) = OneHotMatrix(x.height, cudaconvert(x.d
5151
"""
5252
onehot(l, labels[, unk])
5353
54-
Create a `OneHotVector` with its `l`-th element `true` based on the
55-
possible set of `labels`.
56-
If `unk` is given, return `onehot(unk, labels)` if the input label `l` is not found
57-
in `labels`; otherwise, it will raise an error.
54+
Return a `OneHotVector` where only first occourence of `l` in `labels` is `1` and
55+
all other elements are `0`.
56+
57+
If `l` is not found in labels and `unk` is present, the function returns
58+
`onehot(unk, labels)`; otherwise the function raises an error.
5859
5960
# Examples
6061
```jldoctest
@@ -86,10 +87,10 @@ end
8687
"""
8788
onehotbatch(ls, labels[, unk...])
8889
89-
Create a `OneHotMatrix` with a batch of labels based on the
90-
possible set of `labels`.
91-
If `unk` is given, return [`onehot(unk, labels)`](@ref) if one of the input
92-
labels `ls` is not found in `labels`; otherwise it will error.
90+
Return a `OneHotMatrix` where `k`th column of the matrix is `onehot(ls[k], labels)`.
91+
92+
If one of the input labels `ls` is not found in `labels` and `unk` is given,
93+
return [`onehot(unk, labels)`](@ref) ; otherwise the function will raise an error.
9394
9495
# Examples
9596
```jldoctest

0 commit comments

Comments
 (0)