Skip to content

Commit

Permalink
Add logistic regression example on iris dataset.
Browse files Browse the repository at this point in the history
  • Loading branch information
JoshuaWhittemore committed Feb 28, 2019
1 parent 227bd29 commit bdba607
Show file tree
Hide file tree
Showing 5 changed files with 388 additions and 0 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -45,3 +45,4 @@ We welcome contributions of new models. They should be in a folder with a projec
* [BitString Parity Challenge](other/bitstring-parity)
* [MLP on housing data](other/housing/housing.jl) (low level API)
* [FizzBuzz](other/fizzbuzz/fizzbuzz.jl)
* [Logistic Regression Iris](other/iris/iris.jl)
278 changes: 278 additions & 0 deletions other/iris/Manifest.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,278 @@
[[AbstractTrees]]
deps = ["Markdown", "Test"]
git-tree-sha1 = "6621d9645702c1c4e6970cc6a3eae440c768000b"
uuid = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
version = "0.2.1"

[[Adapt]]
deps = ["LinearAlgebra", "Test"]
git-tree-sha1 = "53d8fec4f662088c1202530e338a11a919407f3b"
uuid = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
version = "0.4.2"

[[Base64]]
uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f"

[[BinDeps]]
deps = ["Compat", "Libdl", "SHA", "URIParser"]
git-tree-sha1 = "12093ca6cdd0ee547c39b1870e0c9c3f154d9ca9"
uuid = "9e28174c-4ba2-5203-b857-d8d62c4213ee"
version = "0.8.10"

[[BinaryProvider]]
deps = ["Libdl", "Pkg", "SHA", "Test"]
git-tree-sha1 = "055eb2690182ebc31087859c3dd8598371d3ef9e"
uuid = "b99e7846-7c00-51b0-8f62-c81ae34c0232"
version = "0.5.3"

[[CodecZlib]]
deps = ["BinaryProvider", "Libdl", "Test", "TranscodingStreams"]
git-tree-sha1 = "e3df104c84dfc108f0ca203fd7f5bbdc98641ae9"
uuid = "944b1d66-785c-5afd-91f1-9de20f533193"
version = "0.5.1"

[[ColorTypes]]
deps = ["FixedPointNumbers", "Random", "Test"]
git-tree-sha1 = "f73b0e10f2a5756de7019818a41654686da06b09"
uuid = "3da002f7-5984-5a60-b8a6-cbb66c0b333f"
version = "0.7.5"

[[Colors]]
deps = ["ColorTypes", "FixedPointNumbers", "InteractiveUtils", "Printf", "Reexport", "Test"]
git-tree-sha1 = "9f0a0210450acb91c730b730a994f8eef1d3d543"
uuid = "5ae59095-9a9b-59fe-a467-6f913c188581"
version = "0.9.5"

[[CommonSubexpressions]]
deps = ["Test"]
git-tree-sha1 = "efdaf19ab11c7889334ca247ff4c9f7c322817b0"
uuid = "bbf7d656-a473-5ed7-a52c-81e309532950"
version = "0.2.0"

[[Compat]]
deps = ["Base64", "Dates", "DelimitedFiles", "Distributed", "InteractiveUtils", "LibGit2", "Libdl", "LinearAlgebra", "Markdown", "Mmap", "Pkg", "Printf", "REPL", "Random", "Serialization", "SharedArrays", "Sockets", "SparseArrays", "Statistics", "Test", "UUIDs", "Unicode"]
git-tree-sha1 = "49269e311ffe11ac5b334681d212329002a9832a"
uuid = "34da2185-b29b-5c13-b0c7-acf172513d20"
version = "1.5.1"

[[DataStructures]]
deps = ["InteractiveUtils", "OrderedCollections", "Random", "Serialization", "Test"]
git-tree-sha1 = "ca971f03e146cf144a9e2f2ce59674f5bf0e8038"
uuid = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
version = "0.15.0"

[[Dates]]
deps = ["Printf"]
uuid = "ade2ca70-3891-5945-98fb-dc099432e06a"

[[DelimitedFiles]]
deps = ["Mmap"]
uuid = "8bb1440f-4735-579b-a4ab-409b98df4dab"

[[DiffResults]]
deps = ["Compat", "StaticArrays"]
git-tree-sha1 = "34a4a1e8be7bc99bc9c611b895b5baf37a80584c"
uuid = "163ba53b-c6d8-5494-b064-1a9d43ac40c5"
version = "0.0.4"

[[DiffRules]]
deps = ["Random", "Test"]
git-tree-sha1 = "dc0869fb2f5b23466b32ea799bd82c76480167f7"
uuid = "b552c78f-8df3-52c6-915a-8e097449b14b"
version = "0.0.10"

[[Distributed]]
deps = ["LinearAlgebra", "Random", "Serialization", "Sockets"]
uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b"

[[FixedPointNumbers]]
deps = ["Test"]
git-tree-sha1 = "b8045033701c3b10bf2324d7203404be7aef88ba"
uuid = "53c48c17-4a7d-5ca2-90c5-79b7896eea93"
version = "0.5.3"

[[Flux]]
deps = ["AbstractTrees", "Adapt", "CodecZlib", "Colors", "DelimitedFiles", "DiffRules", "ForwardDiff", "Juno", "LinearAlgebra", "MacroTools", "NNlib", "NaNMath", "Pkg", "Printf", "Random", "Reexport", "Requires", "SHA", "SpecialFunctions", "Statistics", "StatsBase", "Test", "ZipFile"]
git-tree-sha1 = "1a683b7d156e0b1bf7909e17a6766df9f32f18e4"
repo-rev = "add-iris-dataset"
repo-url = "https://github.com/joshua-whittemore/Flux.jl"
uuid = "587475ba-b771-5e3f-ad9e-33799f191a9c"
version = "0.7.3+"

[[ForwardDiff]]
deps = ["CommonSubexpressions", "DiffResults", "DiffRules", "InteractiveUtils", "LinearAlgebra", "NaNMath", "Random", "SparseArrays", "SpecialFunctions", "StaticArrays", "Test"]
git-tree-sha1 = "4c4d727f1b7e0092134fabfab6396b8945c1ea5b"
uuid = "f6369f11-7733-5829-9624-2563aa707210"
version = "0.10.3"

[[InteractiveUtils]]
deps = ["LinearAlgebra", "Markdown"]
uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240"

[[Juno]]
deps = ["Base64", "Logging", "Media", "Profile", "Test"]
git-tree-sha1 = "ce6246e19061e36cbdce954caaae717498daeed8"
uuid = "e5e0dc1b-0480-54bc-9374-aad01c23163d"
version = "0.5.4"

[[LibGit2]]
uuid = "76f85450-5226-5b5a-8eaa-529ad045b433"

[[Libdl]]
uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb"

[[LinearAlgebra]]
deps = ["Libdl"]
uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"

[[Logging]]
uuid = "56ddb016-857b-54e1-b83d-db4d58db5568"

[[MacroTools]]
deps = ["Compat"]
git-tree-sha1 = "3fd1a3022952128935b449c33552eb65895380c1"
uuid = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
version = "0.4.5"

[[Markdown]]
deps = ["Base64"]
uuid = "d6f4376e-aef5-505a-96c1-9c027394607a"

[[Media]]
deps = ["MacroTools", "Test"]
git-tree-sha1 = "75a54abd10709c01f1b86b84ec225d26e840ed58"
uuid = "e89f7d12-3494-54d1-8411-f7d8b9ae1f27"
version = "0.5.0"

[[Missings]]
deps = ["Dates", "InteractiveUtils", "SparseArrays", "Test"]
git-tree-sha1 = "d1d2585677f2bd93a97cfeb8faa7a0de0f982042"
uuid = "e1d29d7a-bbdc-5cf2-9ac0-f12de2c33e28"
version = "0.4.0"

[[Mmap]]
uuid = "a63ad114-7e13-5084-954f-fe012c677804"

[[NNlib]]
deps = ["Libdl", "LinearAlgebra", "MacroTools", "Requires", "Test"]
git-tree-sha1 = "51330bb45927379007e089997bf548fbe232589d"
uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
version = "0.4.3"

[[NaNMath]]
deps = ["Compat"]
git-tree-sha1 = "ce3b85e484a5d4c71dd5316215069311135fa9f2"
uuid = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3"
version = "0.3.2"

[[OrderedCollections]]
deps = ["Random", "Serialization", "Test"]
git-tree-sha1 = "85619a3f3e17bb4761fe1b1fd47f0e979f964d5b"
uuid = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
version = "1.0.2"

[[Pkg]]
deps = ["Dates", "LibGit2", "Markdown", "Printf", "REPL", "Random", "SHA", "UUIDs"]
uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"

[[Printf]]
deps = ["Unicode"]
uuid = "de0858da-6303-5e67-8744-51eddeeeb8d7"

[[Profile]]
deps = ["Printf"]
uuid = "9abbd945-dff8-562f-b5e8-e1ebf5ef1b79"

[[REPL]]
deps = ["InteractiveUtils", "Markdown", "Sockets"]
uuid = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb"

[[Random]]
deps = ["Serialization"]
uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"

[[Reexport]]
deps = ["Pkg"]
git-tree-sha1 = "7b1d07f411bc8ddb7977ec7f377b97b158514fe0"
uuid = "189a3867-3050-52da-a836-e630ba90ab69"
version = "0.2.0"

[[Requires]]
deps = ["Test"]
git-tree-sha1 = "f6fbf4ba64d295e146e49e021207993b6b48c7d1"
uuid = "ae029012-a4dd-5104-9daa-d747884805df"
version = "0.5.2"

[[SHA]]
uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce"

[[Serialization]]
uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b"

[[SharedArrays]]
deps = ["Distributed", "Mmap", "Random", "Serialization"]
uuid = "1a1011a3-84de-559e-8e89-a11a2f7dc383"

[[Sockets]]
uuid = "6462fe0b-24de-5631-8697-dd941f90decc"

[[SortingAlgorithms]]
deps = ["DataStructures", "Random", "Test"]
git-tree-sha1 = "03f5898c9959f8115e30bc7226ada7d0df554ddd"
uuid = "a2af1166-a08f-5f64-846c-94a0d3cef48c"
version = "0.3.1"

[[SparseArrays]]
deps = ["LinearAlgebra", "Random"]
uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"

[[SpecialFunctions]]
deps = ["BinDeps", "BinaryProvider", "Libdl", "Test"]
git-tree-sha1 = "0b45dc2e45ed77f445617b99ff2adf0f5b0f23ea"
uuid = "276daf66-3868-5448-9aa4-cd146d93841b"
version = "0.7.2"

[[StaticArrays]]
deps = ["InteractiveUtils", "LinearAlgebra", "Random", "Statistics", "Test"]
git-tree-sha1 = "3841b39ed5f047db1162627bf5f80a9cd3e39ae2"
uuid = "90137ffa-7385-5640-81b9-e52037218182"
version = "0.10.3"

[[Statistics]]
deps = ["LinearAlgebra", "SparseArrays"]
uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"

[[StatsBase]]
deps = ["DataStructures", "DelimitedFiles", "LinearAlgebra", "Missings", "Printf", "Random", "SortingAlgorithms", "SparseArrays", "Statistics", "Test"]
git-tree-sha1 = "8f68351fc2600bab59e68406b980b13b2100c472"
uuid = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
version = "0.28.1"

[[Test]]
deps = ["Distributed", "InteractiveUtils", "Logging", "Random"]
uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[[TranscodingStreams]]
deps = ["Pkg", "Random", "Test"]
git-tree-sha1 = "a34a2d588e2d2825602bf14a24216d5c8b0921ec"
uuid = "3bb67fe8-82b1-5028-8e26-92a6c54297fa"
version = "0.8.1"

[[URIParser]]
deps = ["Test", "Unicode"]
git-tree-sha1 = "6ddf8244220dfda2f17539fa8c9de20d6c575b69"
uuid = "30578b45-9adc-5946-b283-645ec420af67"
version = "0.4.0"

[[UUIDs]]
deps = ["Random"]
uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"

[[Unicode]]
uuid = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5"

[[ZipFile]]
deps = ["BinaryProvider", "Libdl", "Printf", "Test"]
git-tree-sha1 = "4000c633efe994b2e10b31b6d91382c4b7412dac"
uuid = "a5390f91-8eb1-5f08-bee0-b1d1ffed6cea"
version = "0.8.0"
3 changes: 3 additions & 0 deletions other/iris/Project.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
[deps]
DelimitedFiles = "8bb1440f-4735-579b-a4ab-409b98df4dab"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
39 changes: 39 additions & 0 deletions other/iris/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@


# Use Flux to do Logistic Regression on the Iris dataset

This is a very simple model, with a single layer that outputs to softmax.

Logistic regression can basically be thought of as a [single layer neural network](https://sebastianraschka.com/faq/docs/logisticregr-neuralnet.html).

## Data Source

The data source is Fisher's classic dataset, retrieved from the [UCI Machine Learning Repository](https://archive.ics.uci.edu/ml/datasets/iris).

## Usage

`cd` into `model-zoo/other/iris`, start the Julia REPL and instantiate the environment:

```julia

using Pkg; Pkg.activate("."); Pkg.instantiate()

```

Then train and evaluate the model:

```julia
julia> include("iris.jl")

Accuracy: 0.92

Confusion Matrix:

3×3 Array{Int64,2}:
16 0 0
0 15 2
0 2 15

julia>

```
67 changes: 67 additions & 0 deletions other/iris/iris.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@


using Flux
using Flux: crossentropy, normalise, onecold, onehotbatch
using Statistics: mean


labels = Flux.Data.Iris.labels()
features = Flux.Data.Iris.features()


# Subract mean, divide by std dev for normed mean of 0 and std dev of 1.
normed_features = normalise(features, dims=2)


klasses = sort(unique(labels))
onehot_labels = onehotbatch(labels, klasses)


# Split into training and test sets, 2/3 for training, 1/3 for test.
train_indices = [1:3:150 ; 2:3:150]

X_train = normed_features[:, train_indices]
y_train = onehot_labels[:, train_indices]

X_test = normed_features[:, 3:3:150]
y_test = onehot_labels[:, 3:3:150]


# Declare model taking 4 features as inputs and outputting 3 probabiltiies,
# one for each species of iris.
model = Chain(
Dense(4, 3),
softmax
)

loss(x, y) = crossentropy(model(x), y)

# Gradient descent optimiser with learning rate 0.5.
optimiser = Descent(0.5)


# Start Training.
for epoch in 1:100
Flux.train!(loss, params(model), [(X_train, y_train)], optimiser)
end


# Evaluate trained model against test set.
accuracy(x, y) = mean(onecold(model(x)) .== onecold(y))

accuracy_score = accuracy(X_test, y_test)

println("\nAccuracy: $accuracy_score")

# Sanity check.
@assert accuracy_score > 0.8


function confusion_matrix(X, y)
= onehotbatch(onecold(model(X)), 1:3)
y *'
end

println("\nConfusion Matrix:\n")
display(confusion_matrix(X_test, y_test))

0 comments on commit bdba607

Please sign in to comment.