Skip to content

Commit 2d0b2ae

Browse files
authored
Gracefully handle crashed / cancelled txns (#73)
* Gracefully handle crashed / cancelled txns
* Add test and fix the exception handling to unwrap the nested task errors.
1 parent ba4572c commit 2d0b2ae

File tree

3 files changed

+65
-21
lines changed

3 files changed

+65
-21
lines changed

Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ Arrow = "69666777-d1a9-59fb-9406-91d4454c9d45"
88
ConfParser = "88353bc9-fd38-507d-a820-d3b43837d6b9"
99
Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
1010
EnumX = "4e289a0a-7415-4d19-859d-a7e5c4648b56"
11+
ExceptionUnwrapping = "460bff9d-24e4-43bc-9d9f-a8973cb893f4"
1112
HTTP = "cd3eb016-35fb-5094-929b-558a96fad6f3"
1213
JSON3 = "0f8b85d8-7281-11e9-16c2-39a750bddbf1"
1314
Mocking = "78c3b35d-d492-501b-9361-3d52fe80e533"

src/api.jl

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ import ProtoBuf
2323
using Base.Threads: @spawn
2424
import Dates
2525
import JSON3
26+
using ExceptionUnwrapping: has_wrapped_exception, unwrap_exception_to_root
2627

2728
using Mocking: Mocking, @mock # For unit testing, by mocking API server responses
2829

@@ -91,12 +92,23 @@ function wait_until_done(ctx::Context, id::AbstractString; start_time_ns = nothi
9192
txn = get_transaction(ctx, id)
9293
return transaction_is_done(txn)
9394
end
94-
t = @spawn get_transaction(ctx, id)
9595
m = @spawn get_transaction_metadata(ctx, id)
9696
p = @spawn get_transaction_problems(ctx, id)
9797
r = @spawn get_transaction_results(ctx, id)
98-
99-
return TransactionResponse(fetch(t), fetch(m), fetch(p), fetch(r))
98+
try
99+
return TransactionResponse(txn, fetch(m), fetch(p), fetch(r))
100+
catch e
101+
# (We use has_wrapped_exception to unwrap the TaskFailedException.)
102+
if has_wrapped_exception(e, HTTPError) &&
103+
unwrap_exception_to_root(e).status_code == 404
104+
# This is an (unfortunately) expected case if the engine crashes during a
105+
# transaction, or the transaction is cancelled. The transaction is marked
106+
# as ABORTED, but it has no results.
107+
return TransactionResponse(txn, nothing, nothing, nothing)
108+
else
109+
rethrow()
110+
end
111+
end
100112
catch
101113
# Always print out the transaction id so that users can still get the txn ID even
102114
# if there's an error during polling (such as an InterruptException).
@@ -613,7 +625,7 @@ function get_transaction_metadata(ctx::Context, id::AbstractString; kw...)
613625
path = PATH_ASYNC_TRANSACTIONS * "/$id/metadata"
614626
path = _mkurl(ctx, path)
615627
headers = _ensure_proto_accept_header(get(kw, :headers, []))
616-
rsp = request(ctx, "GET", path; kw..., headers)
628+
rsp = @mock request(ctx, "GET", path; kw..., headers)
617629
d = ProtoBuf.ProtoDecoder(IOBuffer(rsp.body));
618630
metadata = ProtoBuf.decode(d, protocol.MetadataInfo)
619631
return metadata

test/api.jl

Lines changed: 48 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,12 @@ const v2_get_transaction_results_response = HTTP.Response(200, [
100100
"",
101101
], "\r\n"))
102102

103+
const v2_get_transaction_json_completed = """{"id":"a3e3bc91-0a98-50ba-733c-0987e160eb7d","results_format_version":"2.0.1","state":"COMPLETED"}"""
104+
const v2_get_transaction_response_completed() = HTTP.Response(200,
105+
"""
106+
{"transaction": $(v2_get_transaction_json_completed)}
107+
""")
108+
103109
const v2_fastpath_response = HTTP.Response(200, [
104110
"Content-Type" => "Content-Type: multipart/form-data; boundary=8a89e52be8efe57f0b68ea75388314a3",
105111
"Transfer-Encoding" => "chunked",
@@ -111,7 +117,7 @@ const v2_fastpath_response = HTTP.Response(200, [
111117
"Content-Disposition: form-data; name=\"transaction\"; filename=\"\"",
112118
"Content-Type: application/json",
113119
"",
114-
"""{"id":"a3e3bc91-0a98-50ba-733c-0987e160eb7d","results_format_version":"2.0.1","state":"COMPLETED"}""",
120+
v2_get_transaction_json_completed,
115121
"--8a89e52be8efe57f0b68ea75388314a3",
116122
"Content-Disposition: form-data; name=\"metadata.proto\"; filename=\"\"",
117123
"Content-Type: application/x-protobuf",
@@ -208,14 +214,16 @@ end
208214
end
209215

210216
struct NetworkError code::Int end
211-
function make_fail_second_time_patch(first_response, fail_code)
217+
make_fail_after_second_time_patch(args...) =
218+
make_fail_after_nth_time_patch(2, args...)
219+
function make_fail_after_nth_time_patch(n, first_response, exception)
212220
request_idx = 0
213221
return (ctx::Context, args...; kw...) -> begin
214222
request_idx += 1
215-
if request_idx == 1
216-
return first_response
223+
if request_idx >= n
224+
throw(exception)
217225
else
218-
throw(NetworkError(fail_code))
226+
return first_response
219227
end
220228
end
221229
end
@@ -228,26 +236,49 @@ end
228236
@test_throws NetworkError(404) RAI.exec(ctx, "engine", "db", "2+2")
229237
end
230238

231-
# Test for an error thrown _after_ the transaction is created, before it completes.
232-
sync_error_patch = Mocking.Patch(RAI.request,
233-
make_fail_second_time_patch(v2_async_response, 500))
239+
@testset "test that txn ID is logged for txn errors while polling" begin
240+
# Test for an error thrown _after_ the transaction is created, before it completes.
241+
sync_error_patch = Mocking.Patch(RAI.request,
242+
make_fail_after_second_time_patch(v2_async_response, NetworkError(500)))
243+
244+
# See https://discourse.julialang.org/t/how-to-test-the-value-of-a-variable-from-info-log/37380/3
245+
# for an explanation of this logs-testing pattern.
246+
logs, _ = Test.collect_test_logs() do
247+
apply(sync_error_patch) do
248+
@test_throws NetworkError(500) RAI.exec(ctx, "engine", "db", "2+2")
249+
end
250+
end
251+
sym, val = collect(pairs(logs[1].kwargs))[1]
252+
@test sym ≡ :transaction_id
253+
@test val == "1fc9001b-1b88-8685-452e-c01bc6812429"
254+
end
234255

235-
# See https://discourse.julialang.org/t/how-to-test-the-value-of-a-variable-from-info-log/37380/3
236-
# for an explanation of this logs-testing pattern.
237-
logs, _ = Test.collect_test_logs() do
238-
apply(sync_error_patch) do
239-
@test_throws NetworkError(500) RAI.exec(ctx, "engine", "db", "2+2")
256+
@testset "Handle Aborted Txns with no metadata" begin
257+
# Test for the _specific case_ of a 404 from the RelationalAI service, once the txn
258+
# completes.
259+
260+
# Attempt to wait until a txn is done. This will attempt to fetch the metadata &
261+
# results once it's finished.
262+
metadata_404_patch = Mocking.Patch(RAI.request,
263+
make_fail_after_second_time_patch(
264+
# get_transaction() returns a completed Transaction resource
265+
v2_get_transaction_response_completed(),
266+
# So then we attempt to fetch the metadata or results or problems, and error
267+
RAI.HTTPError(404)
268+
)
269+
)
270+
271+
apply(metadata_404_patch) do
272+
RAI.wait_until_done(ctx, "<txn-id>", start_time_ns=0)
240273
end
241274
end
242-
sym, val = collect(pairs(logs[1].kwargs))[1]
243-
@test sym ≡ :transaction_id
244-
@test val == "1fc9001b-1b88-8685-452e-c01bc6812429"
275+
245276
end
246277

247278
@testset "exec with fast-path response only makes one request" begin
248279
# Throw an error if the SDK attempts to make two requests to RAI API:
249280
only_1_request_patch = Mocking.Patch(RAI.request,
250-
make_fail_second_time_patch(v2_fastpath_response, 500))
281+
make_fail_after_second_time_patch(v2_fastpath_response, NetworkError(500)))
251282

252283
ctx = Context("region", "scheme", "host", "2342", nothing, "audience")
253284
apply(only_1_request_patch) do

0 commit comments

Comments
 (0)