Skip to content

Commit 87262d6

Browse files
FredericWantiezKDr2yebai
authored
Update args at copy (#141)
* Update mappingx * Revert "Update mappingx" This reverts commit 0355eab. * store and copy args of tapedtask * check type of args * Update src/tapedtask.jl * Update src/tapedtask.jl Co-authored-by: KDr2 <zhuo.dev@gmail.com> Co-authored-by: Hong Ge <3279477+yebai@users.noreply.github.com>
1 parent f4941a0 commit 87262d6

File tree

2 files changed

+30
-6
lines changed

2 files changed

+30
-6
lines changed

src/tapedtask.jl

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,20 +3,22 @@ struct TapedTaskException
33
backtrace::Vector{Any}
44
end
55

6-
struct TapedTask{F}
6+
struct TapedTask{F, AT<:Tuple}
77
task::Task
88
tf::TapedFunction{F}
9+
args::AT
910
produce_ch::Channel{Any}
1011
consume_ch::Channel{Int}
1112
produced_val::Vector{Any}
1213

1314
function TapedTask(
1415
t::Task,
1516
tf::TapedFunction{F},
17+
args::AT,
1618
produce_ch::Channel{Any},
1719
consume_ch::Channel{Int}
18-
) where {F}
19-
new{F}(t, tf, produce_ch, consume_ch, Any[])
20+
) where {F, AT<:Tuple}
21+
new{F, AT}(t, tf, args, produce_ch, consume_ch, Any[])
2022
end
2123
end
2224

@@ -55,7 +57,7 @@ function TapedTask(tf::TapedFunction, args...)
5557
produce_ch = Channel()
5658
consume_ch = Channel{Int}()
5759
task = @task wrap_task(tf, produce_ch, consume_ch, args...)
58-
t = TapedTask(task, tf, produce_ch, consume_ch)
60+
t = TapedTask(task, tf, args, produce_ch, consume_ch)
5961
task.storage === nothing && (task.storage = IdDict())
6062
task.storage[:tapedtask] = t
6163
return t
@@ -159,9 +161,15 @@ Base.IteratorEltype(::Type{<:TapedTask}) = Base.EltypeUnknown()
159161

160162
# copy the task
161163

162-
function Base.copy(t::TapedTask)
164+
function Base.copy(t::TapedTask; args=())
163165
tf = copy(t.tf)
164-
new_t = TapedTask(tf)
166+
task_args = if length(args) > 0
167+
typeof(args) == typeof(t.args) || error("bad arguments")
168+
args
169+
else
170+
tape_copy.(t.args)
171+
end
172+
new_t = TapedTask(tf, task_args...)
165173
storage = t.task.storage::IdDict{Any,Any}
166174
new_t.task.storage = copy(storage)
167175
new_t.task.storage[:tapedtask] = new_t

test/tapedtask.jl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -184,4 +184,20 @@
184184
end
185185
end
186186
end
187+
188+
@testset "Issues" begin
189+
@testset "Issue-140, copy unstarted task" begin
190+
function f(x)
191+
for i in 1:3
192+
produce(i + x)
193+
end
194+
end
195+
196+
ttask = TapedTask(f, 3)
197+
ttask2 = copy(ttask)
198+
@test consume(ttask2) == 4
199+
ttask3 = copy(ttask; args=(4,))
200+
@test consume(ttask3) == 5
201+
end
202+
end
187203
end

0 commit comments

Comments
 (0)