From e7a10cdf86abfa8d457cb7270387715a4e5f5ecc Mon Sep 17 00:00:00 2001 From: Josh Whittemore Date: Wed, 6 Feb 2019 16:23:01 -0800 Subject: [PATCH 1/3] Add logistic regression example on iris dataset. --- README.md | 2 +- other/iris/Manifest.toml | 278 +++++++++++++++++++++++++++++++++++++++ other/iris/Project.toml | 3 + other/iris/README.md | 39 ++++++ other/iris/iris.jl | 67 ++++++++++ 5 files changed, 388 insertions(+), 1 deletion(-) create mode 100644 other/iris/Manifest.toml create mode 100644 other/iris/Project.toml create mode 100644 other/iris/README.md create mode 100644 other/iris/iris.jl diff --git a/README.md b/README.md index 5e99a46bb..0c042d8c7 100644 --- a/README.md +++ b/README.md @@ -46,4 +46,4 @@ We welcome contributions of new models. They should be in a folder with a projec * [MLP on housing data](other/housing/housing.jl) (low level API) * [FizzBuzz](other/fizzbuzz/fizzbuzz.jl) * [Meta-Learning](other/meta-learning/MetaLearning.jl) - + * [Logistic Regression Iris](other/iris/iris.jl) diff --git a/other/iris/Manifest.toml b/other/iris/Manifest.toml new file mode 100644 index 000000000..dd7a74d82 --- /dev/null +++ b/other/iris/Manifest.toml @@ -0,0 +1,278 @@ +[[AbstractTrees]] +deps = ["Markdown", "Test"] +git-tree-sha1 = "6621d9645702c1c4e6970cc6a3eae440c768000b" +uuid = "1520ce14-60c1-5f80-bbc7-55ef81b5835c" +version = "0.2.1" + +[[Adapt]] +deps = ["LinearAlgebra", "Test"] +git-tree-sha1 = "53d8fec4f662088c1202530e338a11a919407f3b" +uuid = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" +version = "0.4.2" + +[[Base64]] +uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f" + +[[BinDeps]] +deps = ["Compat", "Libdl", "SHA", "URIParser"] +git-tree-sha1 = "12093ca6cdd0ee547c39b1870e0c9c3f154d9ca9" +uuid = "9e28174c-4ba2-5203-b857-d8d62c4213ee" +version = "0.8.10" + +[[BinaryProvider]] +deps = ["Libdl", "Pkg", "SHA", "Test"] +git-tree-sha1 = "055eb2690182ebc31087859c3dd8598371d3ef9e" +uuid = "b99e7846-7c00-51b0-8f62-c81ae34c0232" +version = "0.5.3" + +[[CodecZlib]] +deps = ["BinaryProvider", "Libdl", "Test", "TranscodingStreams"] +git-tree-sha1 = "e3df104c84dfc108f0ca203fd7f5bbdc98641ae9" +uuid = "944b1d66-785c-5afd-91f1-9de20f533193" +version = "0.5.1" + +[[ColorTypes]] +deps = ["FixedPointNumbers", "Random", "Test"] +git-tree-sha1 = "f73b0e10f2a5756de7019818a41654686da06b09" +uuid = "3da002f7-5984-5a60-b8a6-cbb66c0b333f" +version = "0.7.5" + +[[Colors]] +deps = ["ColorTypes", "FixedPointNumbers", "InteractiveUtils", "Printf", "Reexport", "Test"] +git-tree-sha1 = "9f0a0210450acb91c730b730a994f8eef1d3d543" +uuid = "5ae59095-9a9b-59fe-a467-6f913c188581" +version = "0.9.5" + +[[CommonSubexpressions]] +deps = ["Test"] +git-tree-sha1 = "efdaf19ab11c7889334ca247ff4c9f7c322817b0" +uuid = "bbf7d656-a473-5ed7-a52c-81e309532950" +version = "0.2.0" + +[[Compat]] +deps = ["Base64", "Dates", "DelimitedFiles", "Distributed", "InteractiveUtils", "LibGit2", "Libdl", "LinearAlgebra", "Markdown", "Mmap", "Pkg", "Printf", "REPL", "Random", "Serialization", "SharedArrays", "Sockets", "SparseArrays", "Statistics", "Test", "UUIDs", "Unicode"] +git-tree-sha1 = "49269e311ffe11ac5b334681d212329002a9832a" +uuid = "34da2185-b29b-5c13-b0c7-acf172513d20" +version = "1.5.1" + +[[DataStructures]] +deps = ["InteractiveUtils", "OrderedCollections", "Random", "Serialization", "Test"] +git-tree-sha1 = "ca971f03e146cf144a9e2f2ce59674f5bf0e8038" +uuid = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" +version = "0.15.0" + +[[Dates]] +deps = ["Printf"] +uuid = "ade2ca70-3891-5945-98fb-dc099432e06a" + +[[DelimitedFiles]] +deps = ["Mmap"] +uuid = "8bb1440f-4735-579b-a4ab-409b98df4dab" + +[[DiffResults]] +deps = ["Compat", "StaticArrays"] +git-tree-sha1 = "34a4a1e8be7bc99bc9c611b895b5baf37a80584c" +uuid = "163ba53b-c6d8-5494-b064-1a9d43ac40c5" +version = "0.0.4" + +[[DiffRules]] +deps = ["Random", "Test"] +git-tree-sha1 = "dc0869fb2f5b23466b32ea799bd82c76480167f7" +uuid = "b552c78f-8df3-52c6-915a-8e097449b14b" +version = "0.0.10" + +[[Distributed]] +deps = ["LinearAlgebra", "Random", "Serialization", "Sockets"] +uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b" + +[[FixedPointNumbers]] +deps = ["Test"] +git-tree-sha1 = "b8045033701c3b10bf2324d7203404be7aef88ba" +uuid = "53c48c17-4a7d-5ca2-90c5-79b7896eea93" +version = "0.5.3" + +[[Flux]] +deps = ["AbstractTrees", "Adapt", "CodecZlib", "Colors", "DelimitedFiles", "DiffRules", "ForwardDiff", "Juno", "LinearAlgebra", "MacroTools", "NNlib", "NaNMath", "Pkg", "Printf", "Random", "Reexport", "Requires", "SHA", "SpecialFunctions", "Statistics", "StatsBase", "Test", "ZipFile"] +git-tree-sha1 = "1a683b7d156e0b1bf7909e17a6766df9f32f18e4" +repo-rev = "add-iris-dataset" +repo-url = "https://github.com/joshua-whittemore/Flux.jl" +uuid = "587475ba-b771-5e3f-ad9e-33799f191a9c" +version = "0.7.3+" + +[[ForwardDiff]] +deps = ["CommonSubexpressions", "DiffResults", "DiffRules", "InteractiveUtils", "LinearAlgebra", "NaNMath", "Random", "SparseArrays", "SpecialFunctions", "StaticArrays", "Test"] +git-tree-sha1 = "4c4d727f1b7e0092134fabfab6396b8945c1ea5b" +uuid = "f6369f11-7733-5829-9624-2563aa707210" +version = "0.10.3" + +[[InteractiveUtils]] +deps = ["LinearAlgebra", "Markdown"] +uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240" + +[[Juno]] +deps = ["Base64", "Logging", "Media", "Profile", "Test"] +git-tree-sha1 = "ce6246e19061e36cbdce954caaae717498daeed8" +uuid = "e5e0dc1b-0480-54bc-9374-aad01c23163d" +version = "0.5.4" + +[[LibGit2]] +uuid = "76f85450-5226-5b5a-8eaa-529ad045b433" + +[[Libdl]] +uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb" + +[[LinearAlgebra]] +deps = ["Libdl"] +uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" + +[[Logging]] +uuid = "56ddb016-857b-54e1-b83d-db4d58db5568" + +[[MacroTools]] +deps = ["Compat"] +git-tree-sha1 = "3fd1a3022952128935b449c33552eb65895380c1" +uuid = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" +version = "0.4.5" + +[[Markdown]] +deps = ["Base64"] +uuid = "d6f4376e-aef5-505a-96c1-9c027394607a" + +[[Media]] +deps = ["MacroTools", "Test"] +git-tree-sha1 = "75a54abd10709c01f1b86b84ec225d26e840ed58" +uuid = "e89f7d12-3494-54d1-8411-f7d8b9ae1f27" +version = "0.5.0" + +[[Missings]] +deps = ["Dates", "InteractiveUtils", "SparseArrays", "Test"] +git-tree-sha1 = "d1d2585677f2bd93a97cfeb8faa7a0de0f982042" +uuid = "e1d29d7a-bbdc-5cf2-9ac0-f12de2c33e28" +version = "0.4.0" + +[[Mmap]] +uuid = "a63ad114-7e13-5084-954f-fe012c677804" + +[[NNlib]] +deps = ["Libdl", "LinearAlgebra", "MacroTools", "Requires", "Test"] +git-tree-sha1 = "51330bb45927379007e089997bf548fbe232589d" +uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" +version = "0.4.3" + +[[NaNMath]] +deps = ["Compat"] +git-tree-sha1 = "ce3b85e484a5d4c71dd5316215069311135fa9f2" +uuid = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3" +version = "0.3.2" + +[[OrderedCollections]] +deps = ["Random", "Serialization", "Test"] +git-tree-sha1 = "85619a3f3e17bb4761fe1b1fd47f0e979f964d5b" +uuid = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" +version = "1.0.2" + +[[Pkg]] +deps = ["Dates", "LibGit2", "Markdown", "Printf", "REPL", "Random", "SHA", "UUIDs"] +uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" + +[[Printf]] +deps = ["Unicode"] +uuid = "de0858da-6303-5e67-8744-51eddeeeb8d7" + +[[Profile]] +deps = ["Printf"] +uuid = "9abbd945-dff8-562f-b5e8-e1ebf5ef1b79" + +[[REPL]] +deps = ["InteractiveUtils", "Markdown", "Sockets"] +uuid = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb" + +[[Random]] +deps = ["Serialization"] +uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" + +[[Reexport]] +deps = ["Pkg"] +git-tree-sha1 = "7b1d07f411bc8ddb7977ec7f377b97b158514fe0" +uuid = "189a3867-3050-52da-a836-e630ba90ab69" +version = "0.2.0" + +[[Requires]] +deps = ["Test"] +git-tree-sha1 = "f6fbf4ba64d295e146e49e021207993b6b48c7d1" +uuid = "ae029012-a4dd-5104-9daa-d747884805df" +version = "0.5.2" + +[[SHA]] +uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce" + +[[Serialization]] +uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b" + +[[SharedArrays]] +deps = ["Distributed", "Mmap", "Random", "Serialization"] +uuid = "1a1011a3-84de-559e-8e89-a11a2f7dc383" + +[[Sockets]] +uuid = "6462fe0b-24de-5631-8697-dd941f90decc" + +[[SortingAlgorithms]] +deps = ["DataStructures", "Random", "Test"] +git-tree-sha1 = "03f5898c9959f8115e30bc7226ada7d0df554ddd" +uuid = "a2af1166-a08f-5f64-846c-94a0d3cef48c" +version = "0.3.1" + +[[SparseArrays]] +deps = ["LinearAlgebra", "Random"] +uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" + +[[SpecialFunctions]] +deps = ["BinDeps", "BinaryProvider", "Libdl", "Test"] +git-tree-sha1 = "0b45dc2e45ed77f445617b99ff2adf0f5b0f23ea" +uuid = "276daf66-3868-5448-9aa4-cd146d93841b" +version = "0.7.2" + +[[StaticArrays]] +deps = ["InteractiveUtils", "LinearAlgebra", "Random", "Statistics", "Test"] +git-tree-sha1 = "3841b39ed5f047db1162627bf5f80a9cd3e39ae2" +uuid = "90137ffa-7385-5640-81b9-e52037218182" +version = "0.10.3" + +[[Statistics]] +deps = ["LinearAlgebra", "SparseArrays"] +uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" + +[[StatsBase]] +deps = ["DataStructures", "DelimitedFiles", "LinearAlgebra", "Missings", "Printf", "Random", "SortingAlgorithms", "SparseArrays", "Statistics", "Test"] +git-tree-sha1 = "8f68351fc2600bab59e68406b980b13b2100c472" +uuid = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" +version = "0.28.1" + +[[Test]] +deps = ["Distributed", "InteractiveUtils", "Logging", "Random"] +uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40" + +[[TranscodingStreams]] +deps = ["Pkg", "Random", "Test"] +git-tree-sha1 = "a34a2d588e2d2825602bf14a24216d5c8b0921ec" +uuid = "3bb67fe8-82b1-5028-8e26-92a6c54297fa" +version = "0.8.1" + +[[URIParser]] +deps = ["Test", "Unicode"] +git-tree-sha1 = "6ddf8244220dfda2f17539fa8c9de20d6c575b69" +uuid = "30578b45-9adc-5946-b283-645ec420af67" +version = "0.4.0" + +[[UUIDs]] +deps = ["Random"] +uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4" + +[[Unicode]] +uuid = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5" + +[[ZipFile]] +deps = ["BinaryProvider", "Libdl", "Printf", "Test"] +git-tree-sha1 = "4000c633efe994b2e10b31b6d91382c4b7412dac" +uuid = "a5390f91-8eb1-5f08-bee0-b1d1ffed6cea" +version = "0.8.0" diff --git a/other/iris/Project.toml b/other/iris/Project.toml new file mode 100644 index 000000000..477b202d8 --- /dev/null +++ b/other/iris/Project.toml @@ -0,0 +1,3 @@ +[deps] +DelimitedFiles = "8bb1440f-4735-579b-a4ab-409b98df4dab" +Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" diff --git a/other/iris/README.md b/other/iris/README.md new file mode 100644 index 000000000..c9e5cc96d --- /dev/null +++ b/other/iris/README.md @@ -0,0 +1,39 @@ + + +# Use Flux to do Logistic Regression on the Iris dataset + +This is a very simple model, with a single layer that outputs to softmax. + +Logistic regression can basically be thought of as a [single layer neural network](https://sebastianraschka.com/faq/docs/logisticregr-neuralnet.html). + +## Data Source + +The data source is Fisher's classic dataset, retrieved from the [UCI Machine Learning Repository](https://archive.ics.uci.edu/ml/datasets/iris). + +## Usage + +`cd` into `model-zoo/other/iris`, start the Julia REPL and instantiate the environment: + +```julia + +using Pkg; Pkg.activate("."); Pkg.instantiate() + +``` + +Then train and evaluate the model: + +```julia +julia> include("iris.jl") + +Accuracy: 0.92 + +Confusion Matrix: + +3×3 Array{Int64,2}: + 16 0 0 + 0 15 2 + 0 2 15 + +julia> + +``` diff --git a/other/iris/iris.jl b/other/iris/iris.jl new file mode 100644 index 000000000..0d0da16ef --- /dev/null +++ b/other/iris/iris.jl @@ -0,0 +1,67 @@ + + +using Flux +using Flux: crossentropy, normalise, onecold, onehotbatch +using Statistics: mean + + +labels = Flux.Data.Iris.labels() +features = Flux.Data.Iris.features() + + +# Subract mean, divide by std dev for normed mean of 0 and std dev of 1. +normed_features = normalise(features, dims=2) + + +klasses = sort(unique(labels)) +onehot_labels = onehotbatch(labels, klasses) + + +# Split into training and test sets, 2/3 for training, 1/3 for test. +train_indices = [1:3:150 ; 2:3:150] + +X_train = normed_features[:, train_indices] +y_train = onehot_labels[:, train_indices] + +X_test = normed_features[:, 3:3:150] +y_test = onehot_labels[:, 3:3:150] + + +# Declare model taking 4 features as inputs and outputting 3 probabiltiies, +# one for each species of iris. +model = Chain( + Dense(4, 3), + softmax +) + +loss(x, y) = crossentropy(model(x), y) + +# Gradient descent optimiser with learning rate 0.5. +optimiser = Descent(0.5) + + +# Start Training. +for epoch in 1:100 + Flux.train!(loss, params(model), [(X_train, y_train)], optimiser) +end + + +# Evaluate trained model against test set. +accuracy(x, y) = mean(onecold(model(x)) .== onecold(y)) + +accuracy_score = accuracy(X_test, y_test) + +println("\nAccuracy: $accuracy_score") + +# Sanity check. +@assert accuracy_score > 0.8 + + +function confusion_matrix(X, y) + ŷ = onehotbatch(onecold(model(X)), 1:3) + y * ŷ' +end + +println("\nConfusion Matrix:\n") +display(confusion_matrix(X_test, y_test)) + From de0df39a9e267670de8ee3291e77f9d110c486f8 Mon Sep 17 00:00:00 2001 From: Joshua Whittemore Date: Sat, 9 Mar 2019 15:25:54 -0800 Subject: [PATCH 2/3] call `train!` once with an iterator instead of 110 times with a for-loop --- other/iris/Manifest.toml | 36 ++++++++++++++++++------------------ other/iris/Project.toml | 1 - other/iris/README.md | 6 ++++-- other/iris/iris.jl | 8 ++++---- 4 files changed, 26 insertions(+), 25 deletions(-) diff --git a/other/iris/Manifest.toml b/other/iris/Manifest.toml index dd7a74d82..bdf919495 100644 --- a/other/iris/Manifest.toml +++ b/other/iris/Manifest.toml @@ -1,3 +1,5 @@ +# This file is machine-generated - editing it directly is not advised + [[AbstractTrees]] deps = ["Markdown", "Test"] git-tree-sha1 = "6621d9645702c1c4e6970cc6a3eae440c768000b" @@ -27,9 +29,9 @@ version = "0.5.3" [[CodecZlib]] deps = ["BinaryProvider", "Libdl", "Test", "TranscodingStreams"] -git-tree-sha1 = "e3df104c84dfc108f0ca203fd7f5bbdc98641ae9" +git-tree-sha1 = "36bbf5374c661054d41410dc53ff752972583b9b" uuid = "944b1d66-785c-5afd-91f1-9de20f533193" -version = "0.5.1" +version = "0.5.2" [[ColorTypes]] deps = ["FixedPointNumbers", "Random", "Test"] @@ -51,9 +53,9 @@ version = "0.2.0" [[Compat]] deps = ["Base64", "Dates", "DelimitedFiles", "Distributed", "InteractiveUtils", "LibGit2", "Libdl", "LinearAlgebra", "Markdown", "Mmap", "Pkg", "Printf", "REPL", "Random", "Serialization", "SharedArrays", "Sockets", "SparseArrays", "Statistics", "Test", "UUIDs", "Unicode"] -git-tree-sha1 = "49269e311ffe11ac5b334681d212329002a9832a" +git-tree-sha1 = "195a3ffcb8b0762684b6821de18f83a16455c6ea" uuid = "34da2185-b29b-5c13-b0c7-acf172513d20" -version = "1.5.1" +version = "2.0.0" [[DataStructures]] deps = ["InteractiveUtils", "OrderedCollections", "Random", "Serialization", "Test"] @@ -82,7 +84,7 @@ uuid = "b552c78f-8df3-52c6-915a-8e097449b14b" version = "0.0.10" [[Distributed]] -deps = ["LinearAlgebra", "Random", "Serialization", "Sockets"] +deps = ["Random", "Serialization", "Sockets"] uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b" [[FixedPointNumbers]] @@ -92,12 +94,10 @@ uuid = "53c48c17-4a7d-5ca2-90c5-79b7896eea93" version = "0.5.3" [[Flux]] -deps = ["AbstractTrees", "Adapt", "CodecZlib", "Colors", "DelimitedFiles", "DiffRules", "ForwardDiff", "Juno", "LinearAlgebra", "MacroTools", "NNlib", "NaNMath", "Pkg", "Printf", "Random", "Reexport", "Requires", "SHA", "SpecialFunctions", "Statistics", "StatsBase", "Test", "ZipFile"] -git-tree-sha1 = "1a683b7d156e0b1bf7909e17a6766df9f32f18e4" -repo-rev = "add-iris-dataset" -repo-url = "https://github.com/joshua-whittemore/Flux.jl" +deps = ["AbstractTrees", "Adapt", "CodecZlib", "Colors", "DiffRules", "ForwardDiff", "Juno", "LinearAlgebra", "MacroTools", "NNlib", "NaNMath", "Pkg", "Printf", "Random", "Reexport", "Requires", "SHA", "SpecialFunctions", "Statistics", "StatsBase", "Test", "ZipFile"] +git-tree-sha1 = "28e6dbf663fed71ea607414bc5f2f099d2831c0c" uuid = "587475ba-b771-5e3f-ad9e-33799f191a9c" -version = "0.7.3+" +version = "0.7.3" [[ForwardDiff]] deps = ["CommonSubexpressions", "DiffResults", "DiffRules", "InteractiveUtils", "LinearAlgebra", "NaNMath", "Random", "SparseArrays", "SpecialFunctions", "StaticArrays", "Test"] @@ -106,14 +106,14 @@ uuid = "f6369f11-7733-5829-9624-2563aa707210" version = "0.10.3" [[InteractiveUtils]] -deps = ["LinearAlgebra", "Markdown"] +deps = ["Markdown"] uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240" [[Juno]] deps = ["Base64", "Logging", "Media", "Profile", "Test"] -git-tree-sha1 = "ce6246e19061e36cbdce954caaae717498daeed8" +git-tree-sha1 = "dc568a3dbc4d0505d252d104bed03710a9a39441" uuid = "e5e0dc1b-0480-54bc-9374-aad01c23163d" -version = "0.5.4" +version = "0.5.5" [[LibGit2]] uuid = "76f85450-5226-5b5a-8eaa-529ad045b433" @@ -244,9 +244,9 @@ uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" [[StatsBase]] deps = ["DataStructures", "DelimitedFiles", "LinearAlgebra", "Missings", "Printf", "Random", "SortingAlgorithms", "SparseArrays", "Statistics", "Test"] -git-tree-sha1 = "8f68351fc2600bab59e68406b980b13b2100c472" +git-tree-sha1 = "435707791dc85a67d98d671c1c3fcf1b20b00f94" uuid = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" -version = "0.28.1" +version = "0.29.0" [[Test]] deps = ["Distributed", "InteractiveUtils", "Logging", "Random"] @@ -254,9 +254,9 @@ uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [[TranscodingStreams]] deps = ["Pkg", "Random", "Test"] -git-tree-sha1 = "a34a2d588e2d2825602bf14a24216d5c8b0921ec" +git-tree-sha1 = "90f845c65c50bc57d6ffc815dbab2a4003ccf75c" uuid = "3bb67fe8-82b1-5028-8e26-92a6c54297fa" -version = "0.8.1" +version = "0.9.1" [[URIParser]] deps = ["Test", "Unicode"] @@ -265,7 +265,7 @@ uuid = "30578b45-9adc-5946-b283-645ec420af67" version = "0.4.0" [[UUIDs]] -deps = ["Random"] +deps = ["Random", "SHA"] uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4" [[Unicode]] diff --git a/other/iris/Project.toml b/other/iris/Project.toml index 477b202d8..77df42abf 100644 --- a/other/iris/Project.toml +++ b/other/iris/Project.toml @@ -1,3 +1,2 @@ [deps] -DelimitedFiles = "8bb1440f-4735-579b-a4ab-409b98df4dab" Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" diff --git a/other/iris/README.md b/other/iris/README.md index c9e5cc96d..886808772 100644 --- a/other/iris/README.md +++ b/other/iris/README.md @@ -23,15 +23,17 @@ using Pkg; Pkg.activate("."); Pkg.instantiate() Then train and evaluate the model: ```julia + julia> include("iris.jl") +Starting training. -Accuracy: 0.92 +Accuracy: 0.94 Confusion Matrix: 3×3 Array{Int64,2}: 16 0 0 - 0 15 2 + 0 16 1 0 2 15 julia> diff --git a/other/iris/iris.jl b/other/iris/iris.jl index 0d0da16ef..8d4438c0e 100644 --- a/other/iris/iris.jl +++ b/other/iris/iris.jl @@ -40,11 +40,11 @@ loss(x, y) = crossentropy(model(x), y) optimiser = Descent(0.5) -# Start Training. -for epoch in 1:100 - Flux.train!(loss, params(model), [(X_train, y_train)], optimiser) -end +# Create iterator to train model over 110 epochs. +data_iterator = Iterators.repeated((X_train, y_train), 110) +println("Starting training.") +Flux.train!(loss, params(model), data_iterator, optimiser) # Evaluate trained model against test set. accuracy(x, y) = mean(onecold(model(x)) .== onecold(y)) From 694a4f62c364e38fb1836a91ac4673fae566c577 Mon Sep 17 00:00:00 2001 From: Joshua Whittemore Date: Mon, 25 Mar 2019 17:05:12 -0700 Subject: [PATCH 3/3] update manifest for Flux 0.8.1 add julia prompt to REPL excerpt in README.md --- other/iris/Manifest.toml | 28 +++++++++++++++++----------- other/iris/README.md | 2 +- 2 files changed, 18 insertions(+), 12 deletions(-) diff --git a/other/iris/Manifest.toml b/other/iris/Manifest.toml index bdf919495..27927c7f5 100644 --- a/other/iris/Manifest.toml +++ b/other/iris/Manifest.toml @@ -53,9 +53,9 @@ version = "0.2.0" [[Compat]] deps = ["Base64", "Dates", "DelimitedFiles", "Distributed", "InteractiveUtils", "LibGit2", "Libdl", "LinearAlgebra", "Markdown", "Mmap", "Pkg", "Printf", "REPL", "Random", "Serialization", "SharedArrays", "Sockets", "SparseArrays", "Statistics", "Test", "UUIDs", "Unicode"] -git-tree-sha1 = "195a3ffcb8b0762684b6821de18f83a16455c6ea" +git-tree-sha1 = "84aa74986c5b9b898b0d1acaf3258741ee64754f" uuid = "34da2185-b29b-5c13-b0c7-acf172513d20" -version = "2.0.0" +version = "2.1.0" [[DataStructures]] deps = ["InteractiveUtils", "OrderedCollections", "Random", "Serialization", "Test"] @@ -94,10 +94,10 @@ uuid = "53c48c17-4a7d-5ca2-90c5-79b7896eea93" version = "0.5.3" [[Flux]] -deps = ["AbstractTrees", "Adapt", "CodecZlib", "Colors", "DiffRules", "ForwardDiff", "Juno", "LinearAlgebra", "MacroTools", "NNlib", "NaNMath", "Pkg", "Printf", "Random", "Reexport", "Requires", "SHA", "SpecialFunctions", "Statistics", "StatsBase", "Test", "ZipFile"] -git-tree-sha1 = "28e6dbf663fed71ea607414bc5f2f099d2831c0c" +deps = ["AbstractTrees", "Adapt", "CodecZlib", "Colors", "DelimitedFiles", "Juno", "LinearAlgebra", "MacroTools", "NNlib", "Pkg", "Random", "Reexport", "Requires", "SHA", "Statistics", "StatsBase", "Test", "Tracker", "ZipFile"] +git-tree-sha1 = "0c4473ee0f1109e8eaddea972aa28bac3bd1a99e" uuid = "587475ba-b771-5e3f-ad9e-33799f191a9c" -version = "0.7.3" +version = "0.8.1" [[ForwardDiff]] deps = ["CommonSubexpressions", "DiffResults", "DiffRules", "InteractiveUtils", "LinearAlgebra", "NaNMath", "Random", "SparseArrays", "SpecialFunctions", "StaticArrays", "Test"] @@ -111,9 +111,9 @@ uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240" [[Juno]] deps = ["Base64", "Logging", "Media", "Profile", "Test"] -git-tree-sha1 = "dc568a3dbc4d0505d252d104bed03710a9a39441" +git-tree-sha1 = "4e4a8d43aa7ecec66cadaf311fbd1e5c9d7b9175" uuid = "e5e0dc1b-0480-54bc-9374-aad01c23163d" -version = "0.5.5" +version = "0.7.0" [[LibGit2]] uuid = "76f85450-5226-5b5a-8eaa-529ad045b433" @@ -155,9 +155,9 @@ uuid = "a63ad114-7e13-5084-954f-fe012c677804" [[NNlib]] deps = ["Libdl", "LinearAlgebra", "MacroTools", "Requires", "Test"] -git-tree-sha1 = "51330bb45927379007e089997bf548fbe232589d" +git-tree-sha1 = "9ac5cd21484189339b27840818c4882d1b6df7fd" uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" -version = "0.4.3" +version = "0.5.0" [[NaNMath]] deps = ["Compat"] @@ -252,11 +252,17 @@ version = "0.29.0" deps = ["Distributed", "InteractiveUtils", "Logging", "Random"] uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40" +[[Tracker]] +deps = ["Adapt", "DiffRules", "ForwardDiff", "LinearAlgebra", "MacroTools", "NNlib", "NaNMath", "Printf", "Random", "Requires", "SpecialFunctions", "Statistics", "Test"] +git-tree-sha1 = "4eeea9f0ef9b8c7d1c5c5b1f8f68cb9b7f45d7df" +uuid = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" +version = "0.1.0" + [[TranscodingStreams]] deps = ["Pkg", "Random", "Test"] -git-tree-sha1 = "90f845c65c50bc57d6ffc815dbab2a4003ccf75c" +git-tree-sha1 = "f42956022d8084539f1d7219f632542b0ea686ce" uuid = "3bb67fe8-82b1-5028-8e26-92a6c54297fa" -version = "0.9.1" +version = "0.9.3" [[URIParser]] deps = ["Test", "Unicode"] diff --git a/other/iris/README.md b/other/iris/README.md index 886808772..6223e8c5b 100644 --- a/other/iris/README.md +++ b/other/iris/README.md @@ -16,7 +16,7 @@ The data source is Fisher's classic dataset, retrieved from the [UCI Machine Lea ```julia -using Pkg; Pkg.activate("."); Pkg.instantiate() +julia> using Pkg; Pkg.activate("."); Pkg.instantiate() ```