Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ Arrow = "69666777-d1a9-59fb-9406-91d4454c9d45"
ConfParser = "88353bc9-fd38-507d-a820-d3b43837d6b9"
Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
EnumX = "4e289a0a-7415-4d19-859d-a7e5c4648b56"
ExceptionUnwrapping = "460bff9d-24e4-43bc-9d9f-a8973cb893f4"
HTTP = "cd3eb016-35fb-5094-929b-558a96fad6f3"
JSON3 = "0f8b85d8-7281-11e9-16c2-39a750bddbf1"
Mocking = "78c3b35d-d492-501b-9361-3d52fe80e533"
Expand Down
20 changes: 16 additions & 4 deletions src/api.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import ProtoBuf
using Base.Threads: @spawn
import Dates
import JSON3
using ExceptionUnwrapping: has_wrapped_exception, unwrap_exception_to_root

using Mocking: Mocking, @mock # For unit testing, by mocking API server responses

Expand Down Expand Up @@ -91,12 +92,23 @@ function wait_until_done(ctx::Context, id::AbstractString; start_time_ns = nothi
txn = get_transaction(ctx, id)
return transaction_is_done(txn)
end
t = @spawn get_transaction(ctx, id)
m = @spawn get_transaction_metadata(ctx, id)
p = @spawn get_transaction_problems(ctx, id)
r = @spawn get_transaction_results(ctx, id)

return TransactionResponse(fetch(t), fetch(m), fetch(p), fetch(r))
try
return TransactionResponse(txn, fetch(m), fetch(p), fetch(r))
catch e
# (We use has_wrapped_exception to unwrap the TaskFailedException.)
if has_wrapped_exception(e, HTTPError) &&
unwrap_exception_to_root(e).status_code == 404
# This is an (unfortunately) expected case if the engine crashes during a
# transaction, or the transaction is cancelled. The transaction is marked
# as ABORTED, but it has no results.
return TransactionResponse(txn, nothing, nothing, nothing)
else
rethrow()
end
end
catch
# Always print out the transaction id so that users can still get the txn ID even
# if there's an error during polling (such as an InterruptException).
Expand Down Expand Up @@ -613,7 +625,7 @@ function get_transaction_metadata(ctx::Context, id::AbstractString; kw...)
path = PATH_ASYNC_TRANSACTIONS * "/$id/metadata"
path = _mkurl(ctx, path)
headers = _ensure_proto_accept_header(get(kw, :headers, []))
rsp = request(ctx, "GET", path; kw..., headers)
rsp = @mock request(ctx, "GET", path; kw..., headers)
d = ProtoBuf.ProtoDecoder(IOBuffer(rsp.body));
metadata = ProtoBuf.decode(d, protocol.MetadataInfo)
return metadata
Expand Down
65 changes: 48 additions & 17 deletions test/api.jl
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,12 @@ const v2_get_transaction_results_response = HTTP.Response(200, [
"",
], "\r\n"))

const v2_get_transaction_json_completed = """{"id":"a3e3bc91-0a98-50ba-733c-0987e160eb7d","results_format_version":"2.0.1","state":"COMPLETED"}"""
const v2_get_transaction_response_completed() = HTTP.Response(200,
"""
{"transaction": $(v2_get_transaction_json_completed)}
""")

const v2_fastpath_response = HTTP.Response(200, [
"Content-Type" => "Content-Type: multipart/form-data; boundary=8a89e52be8efe57f0b68ea75388314a3",
"Transfer-Encoding" => "chunked",
Expand All @@ -111,7 +117,7 @@ const v2_fastpath_response = HTTP.Response(200, [
"Content-Disposition: form-data; name=\"transaction\"; filename=\"\"",
"Content-Type: application/json",
"",
"""{"id":"a3e3bc91-0a98-50ba-733c-0987e160eb7d","results_format_version":"2.0.1","state":"COMPLETED"}""",
v2_get_transaction_json_completed,
"--8a89e52be8efe57f0b68ea75388314a3",
"Content-Disposition: form-data; name=\"metadata.proto\"; filename=\"\"",
"Content-Type: application/x-protobuf",
Expand Down Expand Up @@ -208,14 +214,16 @@ end
end

struct NetworkError code::Int end
function make_fail_second_time_patch(first_response, fail_code)
make_fail_after_second_time_patch(args...) =
make_fail_after_nth_time_patch(2, args...)
function make_fail_after_nth_time_patch(n, first_response, exception)
request_idx = 0
return (ctx::Context, args...; kw...) -> begin
request_idx += 1
if request_idx == 1
return first_response
if request_idx >= n
throw(exception)
else
throw(NetworkError(fail_code))
return first_response
end
end
end
Expand All @@ -228,26 +236,49 @@ end
@test_throws NetworkError(404) RAI.exec(ctx, "engine", "db", "2+2")
end

# Test for an error thrown _after_ the transaction is created, before it completes.
sync_error_patch = Mocking.Patch(RAI.request,
make_fail_second_time_patch(v2_async_response, 500))
@testset "test that txn ID is logged for txn errors while polling" begin
# Test for an error thrown _after_ the transaction is created, before it completes.
sync_error_patch = Mocking.Patch(RAI.request,
make_fail_after_second_time_patch(v2_async_response, NetworkError(500)))

# See https://discourse.julialang.org/t/how-to-test-the-value-of-a-variable-from-info-log/37380/3
# for an explanation of this logs-testing pattern.
logs, _ = Test.collect_test_logs() do
apply(sync_error_patch) do
@test_throws NetworkError(500) RAI.exec(ctx, "engine", "db", "2+2")
end
end
sym, val = collect(pairs(logs[1].kwargs))[1]
@test sym ≡ :transaction_id
@test val == "1fc9001b-1b88-8685-452e-c01bc6812429"
end

# See https://discourse.julialang.org/t/how-to-test-the-value-of-a-variable-from-info-log/37380/3
# for an explanation of this logs-testing pattern.
logs, _ = Test.collect_test_logs() do
apply(sync_error_patch) do
@test_throws NetworkError(500) RAI.exec(ctx, "engine", "db", "2+2")
@testset "Handle Aborted Txns with no metadata" begin
# Test for the _specific case_ of a 404 from the RelationalAI service, once the txn
# completes.

# Attempt to wait until a txn is done. This will attempt to fetch the metadata &
# results once it's finished.
metadata_404_patch = Mocking.Patch(RAI.request,
make_fail_after_second_time_patch(
# get_transaction() returns a completed Transaction resource
v2_get_transaction_response_completed(),
# So then we attempt to fetch the metadata or results or problems, and error
RAI.HTTPError(404)
)
)

apply(metadata_404_patch) do
RAI.wait_until_done(ctx, "<txn-id>", start_time_ns=0)
end
end
sym, val = collect(pairs(logs[1].kwargs))[1]
@test sym ≡ :transaction_id
@test val == "1fc9001b-1b88-8685-452e-c01bc6812429"

end

@testset "exec with fast-path response only makes one request" begin
# Throw an error if the SDK attempts to make two requests to RAI API:
only_1_request_patch = Mocking.Patch(RAI.request,
make_fail_second_time_patch(v2_fastpath_response, 500))
make_fail_after_second_time_patch(v2_fastpath_response, NetworkError(500)))

ctx = Context("region", "scheme", "host", "2342", nothing, "audience")
apply(only_1_request_patch) do
Expand Down