Skip to content

Commit d27401a

Browse files
authored
backport basic refactor from PR#100 (TuringLang#103)
* back port basic refactor from PR#100 * remove unused code, and some minor changes * move increase_counter * minor update
1 parent 375e2f8 commit d27401a

File tree

2 files changed

+73
-66
lines changed

2 files changed

+73
-66
lines changed

src/tapedfunction.jl

Lines changed: 32 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -1,30 +1,37 @@
1-
mutable struct Instruction{F}
2-
fun::F
3-
input::Tuple
4-
output
5-
tape
6-
end
7-
1+
abstract type AbstractInstruction end
82

93
mutable struct Tape
10-
tape::Vector{Instruction}
4+
tape::Vector{<:AbstractInstruction}
5+
counter::Int
116
owner
127
end
138

14-
Tape() = Tape(Vector{Instruction}(), nothing)
15-
Tape(owner) = Tape(Vector{Instruction}(), owner)
9+
mutable struct Instruction{F} <: AbstractInstruction
10+
fun::F
11+
input::Tuple
12+
output
13+
tape::Tape
14+
end
15+
16+
Tape() = Tape(Vector{AbstractInstruction}(), 1, nothing)
17+
Tape(owner) = Tape(Vector{AbstractInstruction}(), 1, owner)
1618
MacroTools.@forward Tape.tape Base.iterate, Base.length
1719
MacroTools.@forward Tape.tape Base.push!, Base.getindex, Base.lastindex
1820
const NULL_TAPE = Tape()
1921

22+
function setowner!(tape::Tape, owner)
23+
tape.owner = owner
24+
return tape
25+
end
26+
2027
mutable struct Box{T}
2128
val::T
2229
end
2330

2431
val(x) = x
2532
val(x::Box) = x.val
2633
box(x) = Box(x)
27-
any_box(x) = Box{Any}(x)
34+
box(x::Box) = x
2835

2936
gettape(x) = nothing
3037
gettape(x::Instruction) = x.tape
@@ -63,11 +70,21 @@ function (instr::Instruction{F})() where F
6370
instr.output.val = output
6471
end
6572

73+
function increase_counter!(t::Tape)
74+
t.counter > length(t) && return
75+
# instr = t[t.counter]
76+
t.counter += 1
77+
return t
78+
end
79+
6680
function run(tape::Tape, args...)
67-
input = map(box, args)
68-
tape[1].input = input
81+
if length(args) > 0
82+
input = map(box, args)
83+
tape[1].input = input
84+
end
6985
for instruction in tape
7086
instruction()
87+
increase_counter!(tape)
7188
end
7289
end
7390

@@ -77,21 +94,13 @@ function run_and_record!(tape::Tape, f, args...)
7794
box(f(map(val, args)...))
7895
catch e
7996
@warn e
80-
any_box(nothing)
97+
Box{Any}(nothing)
8198
end
8299
ins = Instruction(f, args, output, tape)
83100
push!(tape, ins)
84101
return output
85102
end
86103

87-
function dry_record!(tape::Tape, f, args...)
88-
# We don't know the type of box.val now, so we use Box{Any}
89-
output = any_box(nothing)
90-
ins = Instruction(f, args, output, tape)
91-
push!(tape, ins)
92-
return output
93-
end
94-
95104
function unbox_condition(ir)
96105
for blk in IRTools.blocks(ir)
97106
vars = keys(blk)
@@ -188,27 +197,14 @@ function (tf::TapedFunction)(args...)
188197
tape = IRTools.evalir(ir, tf.func, args...)
189198
tf.ir = ir
190199
tf.tape = tape
191-
tape.owner = tf
200+
setowner!(tape, tf)
192201
return result(tape)
193202
end
194203
# TODO: use cache
195204
run(tf.tape, args...)
196205
return result(tf.tape)
197206
end
198207

199-
function dry_run(tf::TapedFunction)
200-
isempty(tf.tape) || (return tf)
201-
@assert tf.arity >= 0 "TapedFunction need a fixed arity to dry run."
202-
args = fill(nothing, tf.arity)
203-
ir = IRTools.@code_ir tf.func(args...)
204-
ir = intercept(ir; recorder=:dry_record!)
205-
tape = IRTools.evalir(ir, tf.func, args...)
206-
tf.ir = ir
207-
tf.tape = tape
208-
tape.owner = tf
209-
return tf
210-
end
211-
212208
function Base.show(io::IO, tf::TapedFunction)
213209
buf = IOBuffer()
214210
println(buf, "TapedFunction:")

src/tapedtask.jl

Lines changed: 41 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,33 +1,32 @@
11
struct TapedTaskException
2-
exc
2+
exc::Exception
3+
backtrace
34
end
45

56
struct TapedTask
67
task::Task
78
tf::TapedFunction
8-
counter::Ref{Int}
99
produce_ch::Channel{Any}
1010
consume_ch::Channel{Int}
1111
produced_val::Vector{Any}
1212

1313
function TapedTask(
14-
t::Task, tf::TapedFunction, counter, pch::Channel{Any}, cch::Channel{Int})
15-
new(t, tf, counter, pch, cch, Any[])
14+
t::Task, tf::TapedFunction, pch::Channel{Any}, cch::Channel{Int})
15+
new(t, tf, pch, cch, Any[])
1616
end
1717
end
1818

1919
function TapedTask(tf::TapedFunction, args...)
2020
tf.owner != nothing && error("TapedFunction is owned to another task.")
21-
# dry_run(tf)
2221
isempty(tf.tape) && tf(args...)
23-
counter = Ref{Int}(1)
2422
produce_ch = Channel()
2523
consume_ch = Channel{Int}()
2624
task = @task try
27-
step_in(tf, counter, args)
25+
step_in(tf.tape, args)
2826
catch e
29-
put!(produce_ch, TapedTaskException(e))
30-
# @error "TapedTask Error: " exception=(e, catch_backtrace())
27+
bt = catch_backtrace()
28+
put!(produce_ch, TapedTaskException(e, bt))
29+
# @error "TapedTask Error: " exception=(e, bt)
3130
rethrow()
3231
finally
3332
@static if VERSION >= v"1.4"
@@ -40,7 +39,7 @@ function TapedTask(tf::TapedFunction, args...)
4039
close(produce_ch)
4140
close(consume_ch)
4241
end
43-
t = TapedTask(task, tf, counter, produce_ch, consume_ch)
42+
t = TapedTask(task, tf, produce_ch, consume_ch)
4443
task.storage === nothing && (task.storage = IdDict())
4544
task.storage[:tapedtask] = t
4645
tf.owner = t
@@ -53,25 +52,31 @@ TapedTask(f, args...) = TapedTask(TapedFunction(f, arity=length(args)), args...)
5352
TapedTask(t::TapedTask, args...) = TapedTask(func(t), args...)
5453
func(t::TapedTask) = t.tf.func
5554

56-
function step_in(tf::TapedFunction, counter::Ref{Int}, args)
57-
len = length(tf.tape)
58-
if(counter[] <= 1 && length(args) > 0)
55+
56+
function step_in(t::Tape, args)
57+
len = length(t)
58+
if(t.counter <= 1 && length(args) > 0)
5959
input = map(box, args)
60-
tf.tape[1].input = input
60+
t[1].input = input
6161
end
62-
while counter[] <= len
63-
tf.tape[counter[]]()
62+
while t.counter <= len
63+
t[t.counter]()
6464
# produce and wait after an instruction is done
65-
ttask = tf.owner
65+
ttask = t.owner.owner
6666
if length(ttask.produced_val) > 0
6767
val = pop!(ttask.produced_val)
6868
put!(ttask.produce_ch, val)
6969
take!(ttask.consume_ch) # wait for next consumer
7070
end
71-
counter[] += 1
71+
increase_counter!(t)
7272
end
7373
end
7474

75+
function next_step!(t::TapedTask)
76+
increase_counter!(t.tf.tape)
77+
return t
78+
end
79+
7580
#=
7681
# ** Approach (A) to implement `produce`:
7782
# Make`produce` a standalone instturction. This approach does NOT
@@ -186,18 +191,21 @@ function copy_box(old_box::Box{T}, roster::Dict{UInt64, Any}) where T
186191
end
187192
copy_box(o, roster::Dict{UInt64, Any}) = o
188193

189-
function Base.copy(t::Tape)
194+
function Base.copy(x::Instruction, on_tape::Tape, roster::Dict{UInt64, Any})
195+
input = map(x.input) do ob
196+
copy_box(ob, roster)
197+
end
198+
output = copy_box(x.output, roster)
199+
Instruction(x.fun, input, output, on_tape)
200+
end
201+
202+
function Base.copy(t::Tape, roster::Dict{UInt64, Any})
190203
old_data = t.tape
191-
new_data = Vector{Instruction}()
192-
new_tape = Tape(new_data, t.owner)
204+
new_data = Vector{AbstractInstruction}()
205+
new_tape = Tape(new_data, t.counter, t.owner)
193206

194-
roster = Dict{UInt64, Any}()
195207
for x in old_data
196-
input = map(x.input) do ob
197-
copy_box(ob, roster)
198-
end
199-
output = copy_box(x.output, roster)
200-
new_ins = Instruction(x.fun, input, output, new_tape)
208+
new_ins = copy(x, new_tape, roster)
201209
push!(new_data, new_ins)
202210
end
203211

@@ -207,8 +215,9 @@ end
207215
function Base.copy(tf::TapedFunction)
208216
new_tf = TapedFunction(tf.func; arity=tf.arity)
209217
new_tf.ir = tf.ir
210-
new_tape = copy(tf.tape)
211-
new_tape.owner = new_tf
218+
roster = Dict{UInt64, Any}()
219+
new_tape = copy(tf.tape, roster)
220+
setowner!(new_tape, new_tf)
212221
new_tf.tape = new_tape
213222
return new_tf
214223
end
@@ -217,6 +226,8 @@ function Base.copy(t::TapedTask)
217226
# t.counter[] <= 1 && error("Can't copy a TapedTask which is not running.")
218227
tf = copy(t.tf)
219228
new_t = TapedTask(tf)
220-
new_t.counter[] = t.counter[] + 1
229+
new_t.task.storage = copy(t.task.storage)
230+
new_t.task.storage[:tapedtask] = new_t
231+
next_step!(new_t)
221232
return new_t
222233
end

0 commit comments

Comments
 (0)