diff --git a/Project.toml b/Project.toml index f1545010f6..331d683991 100644 --- a/Project.toml +++ b/Project.toml @@ -18,6 +18,7 @@ Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" Requires = "ae029012-a4dd-5104-9daa-d747884805df" +SHA = "ea8e919c-243c-51af-8825-aaa63cd721ce" SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" diff --git a/src/data/Data.jl b/src/data/Data.jl index ddf0624b4b..ab78f4163c 100644 --- a/src/data/Data.jl +++ b/src/data/Data.jl @@ -1,11 +1,27 @@ module Data import ..Flux +import SHA export CMUDict, cmudict deps(path...) = joinpath(@__DIR__, "..", "..", "deps", path...) +function download_and_verify(url, path, hash) + tmppath = tempname() + download(url, tmppath) + hash_download = open(tmppath) do f + bytes2hex(SHA.sha256(f)) + end + if hash_download !== hash + msg = "Hash Mismatch!\n" + msg *= " Expected sha256: $hash\n" + msg *= " Calculated sha256: $hash_download" + error(msg) + end + mv(tmppath, path; force=true) +end + function __init__() mkpath(deps()) end diff --git a/src/data/cmudict.jl b/src/data/cmudict.jl index 926f2342d2..f89ded4fdb 100644 --- a/src/data/cmudict.jl +++ b/src/data/cmudict.jl @@ -2,23 +2,25 @@ module CMUDict export cmudict -using ..Data: deps +using ..Data: deps, download_and_verify const version = "0.7b" const cache_prefix = "https://cache.julialang.org" function load() - suffixes = ["", ".phones", ".symbols"] + suffixes_and_hashes = [("" , "209a8b4cd265013e96f4658632a9878103b0c5abf62b50d4ef3ae1be226b29e4"), + (".phones" , "ffb588a5e55684723582c7256e1d2f9fadb130011392d9e59237c76e34c2cfd6"), + (".symbols", "408ccaae803641c6d7b626b6299949320c2dbca96b2220fd3fb17887b023b027")] if isdir(deps("cmudict")) - if all(isfile(deps("cmudict", "cmudict$x")) for x in suffixes) + if all(isfile(deps("cmudict", "cmudict$x")) for (x, _) in suffixes_and_hashes) return end end @info "Downloading CMUDict dataset" mkpath(deps("cmudict")) - for x in suffixes - download("$cache_prefix/http://svn.code.sf.net/p/cmusphinx/code/trunk/cmudict/cmudict-$version$x", - deps("cmudict", "cmudict$x")) + for (x, hash) in suffixes_and_hashes + download_and_verify("$cache_prefix/http://svn.code.sf.net/p/cmusphinx/code/trunk/cmudict/cmudict-$version$x", + deps("cmudict", "cmudict$x"), hash) end end diff --git a/src/data/fashion-mnist.jl b/src/data/fashion-mnist.jl index e4510b476b..da78b605aa 100644 --- a/src/data/fashion-mnist.jl +++ b/src/data/fashion-mnist.jl @@ -1,19 +1,20 @@ module FashionMNIST using ..MNIST: gzopen, imageheader, rawimage, labelheader, rawlabel +using ..Data: download_and_verify const dir = joinpath(@__DIR__, "../../deps/fashion-mnist") function load() mkpath(dir) cd(dir) do - for file in ["train-images-idx3-ubyte", - "train-labels-idx1-ubyte", - "t10k-images-idx3-ubyte", - "t10k-labels-idx1-ubyte"] + for (file, hash) in [("train-images-idx3-ubyte", "3aede38d61863908ad78613f6a32ed271626dd12800ba2636569512369268a84"), + ("train-labels-idx1-ubyte", "a04f17134ac03560a47e3764e11b92fc97de4d1bfaf8ba1a3aa29af54cc90845"), + ("t10k-images-idx3-ubyte" , "346e55b948d973a97e58d2351dde16a484bd415d4595297633bb08f03db6a073"), + ("t10k-labels-idx1-ubyte" , "67da17c76eaffca5446c3361aaab5c3cd6d1c2608764d35dfb1850b086bf8dd5")] isfile(file) && continue @info "Downloading Fashion-MNIST dataset" - download("http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/$file.gz", "$file.gz") + download_and_verify("http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/$file.gz", "$file.gz", hash) open(file, "w") do io write(io, gzopen(read, "$file.gz")) end diff --git a/src/data/mnist.jl b/src/data/mnist.jl index 4397618d97..b9c0540a34 100644 --- a/src/data/mnist.jl +++ b/src/data/mnist.jl @@ -1,6 +1,7 @@ module MNIST using CodecZlib, Colors +using ..Data: download_and_verify const Gray = Colors.Gray{Colors.N0f8} @@ -15,13 +16,13 @@ end function load() mkpath(dir) cd(dir) do - for file in ["train-images-idx3-ubyte", - "train-labels-idx1-ubyte", - "t10k-images-idx3-ubyte", - "t10k-labels-idx1-ubyte"] + for (file, hash) in [("train-images-idx3-ubyte", "440fcabf73cc546fa21475e81ea370265605f56be210a4024d2ca8f203523609"), + ("train-labels-idx1-ubyte", "3552534a0a558bbed6aed32b30c495cca23d567ec52cac8be1a0730e8010255c"), + ("t10k-images-idx3-ubyte" , "8d422c7b0a1c1c79245a5bcf07fe86e33eeafee792b84584aec276f5a2dbc4e6"), + ("t10k-labels-idx1-ubyte" , "f7ae60f92e00ec6debd23a6088c31dbd2371eca3ffa0defaefb259924204aec6")] isfile(file) && continue @info "Downloading MNIST dataset" - download("https://cache.julialang.org/http://yann.lecun.com/exdb/mnist/$file.gz", "$file.gz") + download_and_verify("https://cache.julialang.org/http://yann.lecun.com/exdb/mnist/$file.gz", "$file.gz", hash) open(file, "w") do io write(io, gzopen(read, "$file.gz")) end diff --git a/src/data/sentiment.jl b/src/data/sentiment.jl index 56c9e8ea21..ecb1ab8dc7 100644 --- a/src/data/sentiment.jl +++ b/src/data/sentiment.jl @@ -1,13 +1,13 @@ module Sentiment using ZipFile -using ..Data: deps +using ..Data: deps, download_and_verify function load() isfile(deps("sentiment.zip")) && return @info "Downloading sentiment treebank dataset" - download("https://cache.julialang.org/https://nlp.stanford.edu/sentiment/trainDevTestTrees_PTB.zip", - deps("sentiment.zip")) + download_and_verify("https://cache.julialang.org/https://nlp.stanford.edu/sentiment/trainDevTestTrees_PTB.zip", + deps("sentiment.zip"), "5c613a4f673fc74097d523a2c83f38e0cc462984d847b82c7aaf36b01cbbbfcc") end getfile(r, name) = r.files[findfirst(x -> x.name == name, r.files)]