Skip to content

Commit 3958c4d

Browse files
authored
Fix the arithmetic for wait_until_done(); add start_time kwarg in seconds, instead of ns. (#104)
Correctly account for the time sent from the API (ms since epoch), and add a new parameter, start_time, which is in *seconds* not nanoseconds. Deprecate the old parameter, to be removed in a future release. Also adds a complex mocked test for wait_until_done, which tests that it really is waiting for the expected time, based on the created_on field from the API, which comes back in milliseconds. -------------- The actual src/ changes are pretty straightforward: - change from ns to seconds: the API returns the Unix timestamp in ms and Julia's `time()` function returns the timestamp in seconds (to ~microsecond precision). But then I made the PR more complicated by: - keeping the old interface to make the change backwards compatible, and - adding a complicated mocked unit test to make sure we're testing the behavior here and it's working as we expected ... sorry for the extra complexity. ## Commits * Fix the arithmetic for wait_until_done(); add API in seconds. Correctly account for the time sent from the API (ms since epoch), and add a new parameter, start_time, which is in *seconds* not nanoseconds. Deprecate the old parameter, to be removed in a future release. * Add test for wait_until_done that it counts correctly
1 parent 59d6e20 commit 3958c4d

File tree

4 files changed

+124
-26
lines changed

4 files changed

+124
-26
lines changed

src/api.jl

Lines changed: 45 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,8 @@ const PATH_ASYNC_TRANSACTIONS = "/transactions"
3535
const PATH_USERS = "/users"
3636
const ARROW_CONTENT_TYPE = "application/vnd.apache.arrow.stream"
3737

38+
const TXN_POLLING_OVERHEAD = 0.10
39+
3840
struct HTTPError <: Exception
3941
status_code::Int
4042
status_text::String
@@ -65,29 +67,49 @@ finished. A transaction has finished once it has reached one of the terminal sta
6567
`COMPLETED` or `ABORTED`. The polling uses a low-overhead exponential backoff in order to
6668
ensure low-latency results without overloading network traffic.
6769
"""
68-
function wait_until_done(ctx::Context, rsp::TransactionResponse; start_time_ns = nothing)
69-
wait_until_done(ctx, rsp.transaction; start_time_ns)
70+
function wait_until_done(ctx::Context, rsp::TransactionResponse;
71+
start_time_ns = nothing, # deprecated
72+
start_time = nothing,
73+
)
74+
wait_until_done(ctx, rsp.transaction; start_time_ns, start_time)
7075
end
71-
function wait_until_done(ctx::Context, txn::JSON3.Object; start_time_ns = nothing)
76+
function wait_until_done(ctx::Context, txn::JSON3.Object;
77+
start_time_ns = nothing, # deprecated
78+
start_time = nothing,
79+
)
80+
if start_time_ns !== nothing
81+
start_time = start_time_ns / 1e9,
82+
@warn "wait_until_done(): start_time_ns= is deprecated; please pass start_time= as a unix timestamp instead."
83+
end
84+
7285
# If the user is calling this manually, read the start time from the transaction object.
73-
if start_time_ns === nothing &&
86+
if start_time === nothing &&
7487
# NOTE: the fast-path txn may not include the created_on key.
7588
haskey(txn, :created_on)
76-
start_time_ns = _transaction_start_time_ns(txn)
89+
start_time = _transaction_start_time(txn)
7790
end
78-
wait_until_done(ctx, transaction_id(txn); start_time_ns)
91+
wait_until_done(ctx, transaction_id(txn); start_time)
7992
end
80-
function _transaction_start_time_ns(txn::JSON3.Object)
81-
return txn[:created_on] ÷ 1_000_000_000
93+
function _transaction_start_time(txn::JSON3.Object)
94+
# The API returns *milliseconds* since the epoch
95+
return txn[:created_on] / 1e3
8296
end
83-
function wait_until_done(ctx::Context, id::AbstractString; start_time_ns = nothing)
97+
function wait_until_done(ctx::Context, id::AbstractString;
98+
start_time_ns = nothing, # deprecated
99+
start_time = nothing,
100+
)
101+
if start_time_ns !== nothing
102+
start_time = start_time_ns / 1e9,
103+
@warn "wait_until_done(): start_time_ns= is deprecated; please pass start_time= as a unix timestamp instead."
104+
end
105+
84106
# If the user is calling this manually, read the start time from the transaction object.
85-
if start_time_ns === nothing
107+
if start_time === nothing
86108
txn = get_transaction(ctx, id)
87-
start_time_ns = _transaction_start_time_ns(txn)
109+
start_time = _transaction_start_time(txn)
88110
end
89111
try
90-
_poll_with_specified_overhead(; overhead_rate = 0.10, start_time_ns) do
112+
_poll_with_specified_overhead(; overhead_rate = TXN_POLLING_OVERHEAD, start_time) do
91113
txn = get_transaction(ctx, id)
92114
return transaction_is_done(txn)
93115
end
@@ -125,14 +147,14 @@ end
125147
function _poll_with_specified_overhead(
126148
f;
127149
overhead_rate, # Add xx% overhead through polling.
128-
start_time_ns = time_ns(), # Optional start time, otherwise defaults to now()
150+
start_time = time(), # Optional start time, otherwise defaults to now()
129151
n = typemax(Int), # Maximum number of polls
130152
max_delay = 120, # 2 min
131153
timeout_secs = Inf, # no timeout by default
132154
throw_on_timeout = false,
133155
)
156+
@debug "start time: $start_time"
134157
@assert overhead_rate >= 0.0
135-
timeout_ns = timeout_secs * 1e9
136158
local iter
137159
for i in 1:n
138160
iter = i
@@ -142,17 +164,19 @@ function _poll_with_specified_overhead(
142164
if done
143165
return nothing
144166
end
145-
current_delay = time_ns() - start_time_ns
146-
if current_delay > timeout_ns
167+
t = @mock(time())
168+
@debug "time: $t"
169+
current_delay_s = t - start_time
170+
if current_delay_s > timeout_secs
147171
break
148172
end
149-
duration = (current_delay * overhead_rate) / 1e9
173+
duration = current_delay_s * overhead_rate
150174
duration = min(duration, max_delay) # clamp the duration as specified.
151-
sleep(duration)
175+
@mock sleep(duration)
152176
end
153177

154178
# We have exhausted the iterator.
155-
current_delay_secs = (time_ns() - start_time_ns) * 1e9
179+
current_delay_secs = time() - start_time
156180
throw_on_timeout && error("Timed out after $iter iterations, $current_delay_secs seconds in `_poll_with_specified_overhead`.")
157181

158182
return nothing
@@ -526,14 +550,14 @@ Dict{String, Any} with 4 entries:
526550
function exec(ctx::Context, database::AbstractString, engine::AbstractString, source; inputs = nothing, readonly = false, kw...)
527551
# Record the initial start time so that we include the time to create the transaction
528552
# in our exponential backoff in `wait_until_done()`.
529-
start_time_ns = time_ns()
553+
start_time = time()
530554
# Create an Async transaction:
531555
transactionResponse = exec_async(ctx, database, engine, source; inputs=inputs, readonly=readonly, kw...)
532556
if transactionResponse.results !== nothing
533557
return transactionResponse
534558
end
535559
# Poll until the transaction is done, and return the results.
536-
return wait_until_done(ctx, transactionResponse; start_time_ns = start_time_ns)
560+
return wait_until_done(ctx, transactionResponse; start_time = start_time)
537561
end
538562

539563
function exec_async(ctx::Context, database::AbstractString, engine::AbstractString, source; inputs = nothing, readonly = false, kw...)

test/api.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@ import ProtoBuf
1313

1414
Mocking.activate()
1515

16+
include("wait_until_done.jl")
17+
1618
# -----------------------------------
1719
# v2 transactions
1820

@@ -269,7 +271,7 @@ end
269271
)
270272

271273
apply(metadata_404_patch) do
272-
RAI.wait_until_done(ctx, "<txn-id>", start_time_ns=0)
274+
RAI.wait_until_done(ctx, "<txn-id>", start_time=0)
273275
end
274276
end
275277

test/integration.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ function with_engine(f, ctx; existing_engine=nothing)
5757
engine_name = rnd_test_name()
5858
if isnothing(existing_engine)
5959
custom_headers = get(ENV, "CUSTOM_HEADERS", nothing)
60-
start_time_ns = time_ns()
60+
start_time = time()
6161
if isnothing(custom_headers)
6262
create_engine(ctx, engine_name)
6363
else
@@ -66,7 +66,7 @@ function with_engine(f, ctx; existing_engine=nothing)
6666
headers = JSON3.read(custom_headers, Dict{String, String})
6767
create_engine(ctx, engine_name; nothing, headers)
6868
end
69-
_poll_with_specified_overhead(; POLLING_KWARGS..., start_time_ns) do
69+
_poll_with_specified_overhead(; POLLING_KWARGS..., start_time) do
7070
state = get_engine(ctx, engine_name)[:state]
7171
state == "PROVISION_FAILED" && throw("Failed to provision engine $engine_name")
7272
state == "PROVISIONED"
@@ -80,8 +80,8 @@ function with_engine(f, ctx; existing_engine=nothing)
8080
# Engines cannot be deleted if they are still provisioning. We have to at least wait
8181
# until they are ready.
8282
if isnothing(existing_engine)
83-
start_time_ns = time_ns() - 2e9 # assume we started 2 seconds ago
84-
_poll_with_specified_overhead(; POLLING_KWARGS..., start_time_ns) do
83+
start_time = time() - 2 # assume we started 2 seconds ago
84+
_poll_with_specified_overhead(; POLLING_KWARGS..., start_time) do
8585
state = get_engine(ctx, engine_name)[:state]
8686
state == "PROVISION_FAILED" && throw("Failed to provision engine $engine_name")
8787
state == "PROVISIONED"

test/wait_until_done.jl

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
using ExceptionUnwrapping: unwrap_exception_to_root
2+
3+
# This test is _pretty complicated_ since it's trying to test something that depends on
4+
# timing: testing that wait_until_done() polls for the expected amount of time in between
5+
# calls to get_transaction.
6+
# Testing anything to do with timing is always complicated. We tackle it here by mocking
7+
# both sleep() and time(), and injecting fake times, and then making sure that the
8+
# function is computing the correct duration to sleep, based on those times.
9+
@testset "wait_until_done polls correctly" begin
10+
now_ms = round(Int, time() * 1e3)
11+
txn_str = """{
12+
"id": "a3e3bc91-0a98-50ba-733c-0987e160eb7d",
13+
"results_format_version": "2.0.1",
14+
"state": "RUNNING",
15+
"created_on": $(now_ms)
16+
}"""
17+
txn = JSON3.read(txn_str)
18+
19+
ctx = Context("region", "scheme", "host", "2342", nothing, "audience")
20+
21+
start = now_ms / 1e3
22+
# Simulate OVERHEAD of 0.1 + round-trip-time of 0.5
23+
times = [
24+
start + 2, # First call takes 2 seconds then returns async
25+
start + 2.2 + 0.5, # So we slept 0.2 seconds, then get_txn takes 0.5 secs
26+
start + 2.97 + 0.5 # Now we sleep 2.7 * 1.1 ≈ 2.97, then again 0.5 RTT.
27+
]
28+
i = 1
29+
time_patch = @patch function Base.time()
30+
v = times[i]
31+
i += 1
32+
return v
33+
end
34+
# Here, we test that each call to sleep is the correct calculation of current "time"
35+
# minus start time * the overhead.
36+
sleep_patch = @patch function Base.sleep(duration)
37+
@info "Mock sleep for $duration"
38+
@test duration ≈ (times[i-1] - start) * RAI.TXN_POLLING_OVERHEAD
39+
end
40+
41+
# This is returned on each get_txn() request.
42+
unfinished_response = HTTP.Response(
43+
200,
44+
["Content-Type" => "application/json"],
45+
body = """{"transaction": $(txn_str)}"""
46+
)
47+
48+
# Stop the test after 3 polls.
49+
ABORT = :ABORT_TEST
50+
51+
request_patch = @patch function RAI.request(ctx::Context, args...; kw...)
52+
if i <= 3
53+
return unfinished_response
54+
else
55+
# Finish the test
56+
throw(ABORT)
57+
end
58+
end
59+
60+
# Call the function with the patches. Assert that it ends with our ABORT exception.
61+
apply([time_patch, sleep_patch, request_patch]) do
62+
try
63+
wait_until_done(ctx, txn)
64+
catch e
65+
@assert unwrap_exception_to_root(e) == ABORT
66+
end
67+
end
68+
69+
# Test that we made it through all the expected polls, so that we know the above
70+
# `@test`s all triggered.
71+
@test i == 4
72+
end

0 commit comments

Comments
 (0)