Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

support finding dependencies of a variable #63

Open
wants to merge 12 commits into
base: master
Choose a base branch
from
2 changes: 1 addition & 1 deletion src/IRTools.jl
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ let exports = :[
definitions, usages, dominators, domtree, domorder, domorder!, renumber,
merge_returns!, expand!, prune!, ssa!, inlineable!, log!, pis!, func, evalir,
Simple, Loop, Multiple, reloop, stackify, functional, cond, WorkQueue,
Graph, liveness, interference, colouring, inline,
Graph, liveness, interference, colouring, inline, dependencies,
# Reflection, Dynamo
Meta, Lambda, meta, dynamo, transform, refresh, recurse!, self,
varargs!, slots!,
Expand Down
91 changes: 91 additions & 0 deletions src/passes/passes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,97 @@ function usages(b::Block)
return uses
end

usages(st::Statement) = usages(st.expr)
usages(ex) = Set{Variable}()

function usages(ex::Expr)
uses = Set{Variable}()
for x in ex.args
x isa Variable && push!(uses, x)
end
return uses
end

function block_changes_deps(deps, ir, b)
for (v, st) in b
if haskey(deps, v)
(usages(st) ⊆ deps[v]) || return true
else
return true
end
end

brs = branches(b)
for br in brs
if br.block > 0
next_block = block(ir, br.block)
if !isempty(br.args)
for (x, y) in zip(arguments(next_block), br.args)
haskey(deps, x) && (y in deps[x]) && return true
end
end
end
end
return false
end

function update_deps!(deps, v, direct)
set = get!(deps, v, Set{Variable}())
union!(set, setdiff(direct, (v, )))

for x in direct
if (v != x) && haskey(deps, x) && !(deps[x] ⊆ set)
update_deps!(deps, v, deps[x])
end
end
return deps
end

"""
dependencies(ir::IR)

Return the list of direct dependencies for each variable.
"""
function dependencies(ir::IR)
worklist = [block(ir, 1)]
deps = Dict()
while !isempty(worklist)
b = pop!(worklist)
for (v, st) in b
update_deps!(deps, v, usages(st))
end

brs = branches(b)
jump_next_block = true
for br in brs
if br.condition === nothing
jump_next_block = false
end

if br.block > 0 # reachable
next_block = block(ir, br.block)
if !isempty(br.args) # pass arguments
for (x, y) in zip(arguments(next_block), br.args)
y isa Variable && update_deps!(deps, x, (y, ))
end
end

if block_changes_deps(deps, ir, next_block)
push!(worklist, next_block)
end
end
end

if jump_next_block
next_block = block(ir, b.id + 1)
if block_changes_deps(deps, ir, next_block)
push!(worklist, next_block)
end
end
end
return deps
end

function usecounts(ir::IR)
counts = Dict{Variable,Int}()
prewalk(ir) do x
Expand Down
44 changes: 43 additions & 1 deletion test/analysis.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
using IRTools, Test
using IRTools: CFG, dominators, domtree
using IRTools: CFG, dominators, domtree, dependencies, var

relu(x) = (y = x > 0 ? x : 0)
ir = @code_ir relu(1)
Expand All @@ -10,3 +10,45 @@ ir = @code_ir relu(1)
@test domtree(CFG(ir)) == (1 => [2 => [], 3 => [], 4 => []])

@test domtree(CFG(ir)', entry = 4) == (4 => [1 => [], 2 => [], 3 => []])

function f(x)
x = sin(x)
y = cos(x)

if x > 1
x = cos(x) + 1
else
x = y + 1
end
return x
end

ir = @code_ir f(1.0)

deps = dependencies(ir)

@test deps[var(9)] == Set(var.([2, 8, 7, 3, 4, 6]))
@test deps[var(8)] == Set(var.([3, 2, 4]))
@test deps[var(7)] == Set(var.([6, 3, 2]))
@test deps[var(6)] == Set(var.([3, 2]))
@test deps[var(5)] == Set(var.([3, 2]))
@test deps[var(4)] == Set(var.([3, 2]))
@test deps[var(3)] == Set([var(2)])

function pow(x, n)
r = 1
while n > 0
n -= 1
r *= x
end
return r
end

ir = @code_ir pow(1.0, 2)
deps = dependencies(ir)

@test deps[var(8)] == Set(var.([5, 2]))
@test deps[var(7)] == Set(var.([4, 3]))
@test deps[var(6)] == Set(var.([4, 3]))
@test deps[var(5)] == Set(var.([2, 8]))
@test deps[var(4)] == Set(var.([3, 7]))