diff --git a/src/Flux.jl b/src/Flux.jl index 6288dba715..c335488d4f 100644 --- a/src/Flux.jl +++ b/src/Flux.jl @@ -36,8 +36,6 @@ include("layers/normalisation.jl") include("data/Data.jl") -include("jit/JIT.jl") - @require CuArrays include("cuda/cuda.jl") end # module diff --git a/src/jit/JIT.jl b/src/jit/JIT.jl deleted file mode 100644 index 06a7da6b6e..0000000000 --- a/src/jit/JIT.jl +++ /dev/null @@ -1,9 +0,0 @@ -module JIT - -using MacroTools - -include("shapes.jl") -include("trace.jl") -include("lib.jl") - -end diff --git a/src/jit/lib.jl b/src/jit/lib.jl deleted file mode 100644 index 5301e57957..0000000000 --- a/src/jit/lib.jl +++ /dev/null @@ -1,40 +0,0 @@ -# Primitive definitions - -shape(::typeof(*), A::MatShape{T}, B::VecShape{T}) where T = - Shape{T}(size(A,1)) - -shape(::typeof(*), A::MatShape{T}, B::MatShape{T}) where T = - Shape{T}(size(A,1),size(B,2)) - -inplace!(::typeof(*), C::AbstractArray, A::AbstractMatrix, B::AbstractArray) = - A_mul_B!(C, A, B) - -shape(::typeof(broadcast), f, xs...) = - Shape{eltype(xs[1])}(Base.Broadcast.broadcast_shape(size.(xs)...)...) - -inplace!(::typeof(broadcast), y, f, xs...) = broadcast!(f, y, xs...) - -shape(::typeof(reshape), x::Shape{T}, i...) where T = - Shape{T}(Base._reshape_uncolon(x, i)) - -inplace!(::typeof(reshape), y, x, i...) = copy!(y, x) - -# NNlib - -using NNlib -using ..Tracker: _conv, _maxpool - -shape(::typeof(softmax), x) = x -inplace!(::typeof(softmax), y, x) = NNlib.softmax!(y, x) - -shape(::typeof(_conv), x::Shape{T}, w::Shape{T}, stride, pad) where T = - Shape{T}(NNlib.cdims(size(x), size(w), pad, stride)) - -inplace!(::typeof(_conv), y, x, w, stride, pad) = - NNlib.conv!(y, x, w, stride = stride, pad = pad) - -shape(::typeof(_maxpool), x::Shape{T}, k, pad) where T = - Shape{T}(NNlib.pdims(size(x), k, pad, k)) - -inplace!(::typeof(_maxpool), y, x, k, pad) = - NNlib.maxpool!(y, x, k, pad = pad) diff --git a/src/jit/shapes.jl b/src/jit/shapes.jl deleted file mode 100644 index 3985c18511..0000000000 --- a/src/jit/shapes.jl +++ /dev/null @@ -1,56 +0,0 @@ -using ..Tracker: TrackedArray - -struct Shape{T,N} - dims::NTuple{N,Int} -end - -VecShape{T} = Shape{T,1} -MatShape{T} = Shape{T,2} - -Shape{T}(dims::Vararg{Integer,N}) where {T,N} = Shape{T,N}(dims) -Shape{T}(dims::NTuple{N,Integer}) where {T,N} = Shape{T,N}(dims) - -Base.size(s::Shape) = s.dims -Base.size(s::Shape, n) = s.dims[n] -Base.ndims(s::Shape{T,N}) where {T,N} = N -Base.length(s::Shape) = prod(s.dims) -Base.eltype(s::Shape{T}) where T = T - -Base.sizeof(s::Shape{T}) where T = sizeof(T)*prod(size(s)) - -function Base.show(io::IO, s::Shape{T}) where T - print(io, "Shape{$T}(") - join(io, s.dims, ", ") - print(io, ")") -end - -shape(x) = x -shape(x::Shape) = x -shape(x::Tuple) = shape.(x) -shape(x::AbstractArray) = Shape{eltype(x)}(size(x)...) -shape(x::TrackedArray) = shape(x.data) - -bytes(s::Shape) = sizeof(s) -bytes(x::Tuple) = sum(bytes.(x)) - -# Recover structure from byte buffers -# Make sure to hold on to the parent buffer for the lifetime of the data. - -function restructure(sh::Shape{T}, buf::Vector{UInt8}) where T - buf = unsafe_wrap(Array, pointer(buf), sizeof(sh)) - reshape(reinterpret(T, buf), size(sh)) -end - -# Execution with caches - -mutable struct Cached{F,A} - f::F - buffer::A -end - -function (c::Cached)(args...) - sh = shape(c.f, shape(args)...) - bytes(sh) > length(c.buffer) && (c.buffer = similar(c.buffer, bytes(sh))) - y = restructure(sh, c.buffer) - inplace!(c.f, y, args...) -end diff --git a/src/jit/trace.jl b/src/jit/trace.jl deleted file mode 100644 index 8266096fad..0000000000 --- a/src/jit/trace.jl +++ /dev/null @@ -1,75 +0,0 @@ -# This is hacky; we'll eventually reuse Cassette for better tracing. - -using ..Tracker, DataFlow -using ..Tracker: Tracked, Broadcasted, param, tracker, istracked, isleaf -using DataFlow: Call, Lambda, iscall, isconstant, prewalk, vertex, syntax, - inputnode, constant - -vcall(f, args...) = vertex(DataFlow.Call(), constant(f), args...) -vcall(f::Broadcasted, args...) = vcall(broadcast, constant(f.f), args...) - -graph(x::Tracked, inputs...; cache = ObjectIdDict()) = - vcall(x.f.func, map(x -> graph(x, inputs...; cache = cache), x.f.args)...) - -function graph(x, inputs...; cache = ObjectIdDict()) - haskey(cache, x) && return cache[x] - i = findfirst(y -> x === y, inputs) - cache[x] = - i > 0 ? inputnode(i) : - istracked(x) && !isleaf(x) ? graph(tracker(x), inputs...; cache = cache) : - constant(x) -end - -function trace(f, args...) - inputs = param.(args) - graph(f(inputs...), inputs...) -end - -# Graph manipulation - -function liftparams(v) - ps = [] - v = prewalk(DataFlow.bumpinputs(v)) do v - isconstant(v) && istracked(v.value.value) || return v - push!(ps, v.value.value) - DataFlow.vcall(getindex, inputnode(1), length(ps)) - end - return v, ps -end - -function cacheall(v, buf = () -> UInt8[]) - prewalk(v) do v - iscall(v) && isconstant(v[1]) || return v - f = v[1].value.value - return vertex(Call(), constant(Cached(f, buf())), v[2:end]...) - end -end - -code(v, n) = syntax(vertex(Lambda(n, v))) - -struct Compiled{F,T<:Tuple} - model - func::F - params::T -end - -# TODO when we support derivatives -# (c::Compiled)(args...) = -# Tracker.track(Tracker.Call(c, args...), -# c.func(Tracker.data.(c.params), args...)) - -(c::Compiled)(args...) = c.func(Tracker.data.(c.params), Tracker.data.(args)...) - -Base.show(io::IO, c::Compiled) = print(io, "Compiled(", c.model, ")") - -function compile(f, args...) - v = trace(f, args...) - v, ps = liftparams(cacheall(v, () -> similar(args[1], UInt8, 1))) # no empty arrays on GPU - Compiled(f, eval(code(v, length(args)+1)), (ps...,)) -end - -function source(f, args...) - v = trace(f, args...) - v, ps = liftparams(v) - code(v, length(args)+1) |> prettify -end diff --git a/test/jit.jl b/test/jit.jl deleted file mode 100644 index 09efb02331..0000000000 --- a/test/jit.jl +++ /dev/null @@ -1,12 +0,0 @@ -using Flux, Base.Test -using Flux.JIT: compile - -@testset "JIT" begin - -m = Dense(10, 5) -f = compile(m, rand(10)) -x = rand(10) - -@test Tracker.data(m(x)) == f(x) - -end diff --git a/test/runtests.jl b/test/runtests.jl index f447c3b5b2..47f7e9e5be 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -10,7 +10,6 @@ include("layers/normalisation.jl") include("layers/stateless.jl") include("optimise.jl") include("data.jl") -include("jit.jl") if Base.find_in_path("CuArrays") ≠ nothing include("cuda/cuda.jl")