Skip to content

Commit 7d54b71

Browse files
NRHelmiNHDaly
andauthored
Models actions from v1 to v2 protocol (#76)
* moving models actions from v1 to v2 * models integration tests * refactoring && cleanup * remove unused _list_models * Update test/integration.jl Co-authored-by: Nathan Daly <nathan.daly@relational.ai> * updates * fix * changelog updates * addressing PR comments Co-authored-by: Nathan Daly <nathan.daly@relational.ai>
1 parent 355d466 commit 7d54b71

File tree

6 files changed

+150
-54
lines changed

6 files changed

+150
-54
lines changed

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
# Changelog
22

3+
## main
4+
* Update models actions to use v2 protocol
5+
* Update `load_model` to `load_models`
36
## v0.2.1
47
* Increased `connection_limit` to 4096
58

examples/delete-model.jl renamed to examples/delete-models.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ include("parseargs.jl")
2121
function run(database, engine, model; profile)
2222
cfg = load_config(; profile = profile)
2323
ctx = Context(cfg)
24-
rsp = delete_model(ctx, database, engine, model)
24+
rsp = delete_models(ctx, database, engine, model)
2525
println(rsp)
2626
end
2727

examples/load-model.jl renamed to examples/load-models.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
# Load the given Rel model into the given database.
1616

17-
using RAI: Context, HTTPError, load_config, load_model
17+
using RAI: Context, HTTPError, load_config, load_models
1818

1919
include("parseargs.jl")
2020

@@ -25,7 +25,7 @@ function run(database, engine, fullname; profile)
2525
models = Dict(_sansext(fullname) => read(fullname, String))
2626
cfg = load_config(; profile = profile)
2727
ctx = Context(cfg)
28-
rsp = load_model(ctx, database, engine, models)
28+
rsp = load_models(ctx, database, engine, models)
2929
println(rsp)
3030
end
3131

src/RAI.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -46,10 +46,10 @@ export
4646
list_engines
4747

4848
export
49-
delete_model,
49+
delete_models,
5050
get_model,
5151
list_models,
52-
load_model
52+
load_models
5353

5454
export
5555
create_oauth_client,
@@ -66,7 +66,7 @@ export
6666
update_user
6767

6868
export
69-
delete_model,
69+
delete_models,
7070
get_model,
7171
list_models
7272

src/api.jl

Lines changed: 75 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -234,11 +234,8 @@ function delete_engine(ctx::Context, engine::AbstractString; kw...)
234234
return _delete(ctx, PATH_ENGINE; body = JSON3.write(data), kw...)
235235
end
236236

237-
function delete_model(ctx::Context, database::AbstractString, engine::AbstractString, model::AbstractString; kw...)
238-
tx = Transaction(ctx.region, database, engine, "OPEN"; readonly = false)
239-
actions = [_make_delete_models_action([model])]
240-
return _post(ctx, PATH_TRANSACTION; query = query(tx), body = body(tx, actions...), kw...)
241-
end
237+
# escape rel special string
238+
_escape_string_for_rel(str) = replace(repr(str), '%' => "\\%")
242239

243240
function delete_oauth_client(ctx::Context, id::AbstractString; kw...)
244241
return _delete(ctx, joinpath(PATH_OAUTH_CLIENTS, id); kw...)
@@ -270,15 +267,6 @@ function get_database(ctx::Context, database::AbstractString; kw...)
270267
return rsp[1]
271268
end
272269

273-
# todo: move to rel query
274-
function get_model(ctx::Context, database::AbstractString, engine::AbstractString, name::AbstractString; kw...)
275-
models = _list_models(ctx, database, engine; kw...)
276-
for model in models
277-
model["name"] == name && return model["value"]
278-
end
279-
throw(HTTPError(404))
280-
end
281-
282270
function get_oauth_client(ctx::Context, id::AbstractString; kw...)
283271
return _get(ctx, joinpath(PATH_OAUTH_CLIENTS, id); kw...).client
284272
end
@@ -387,22 +375,6 @@ function _make_actions(actions...)
387375
return result
388376
end
389377

390-
function _make_delete_models_action(models::Vector)
391-
return Dict(
392-
"type" => "ModifyWorkspaceAction",
393-
"delete_source" => models)
394-
end
395-
396-
function _make_load_model_action(name, model)
397-
return Dict(
398-
"type" => "InstallAction",
399-
"sources" => [_make_query_source(name, model)])
400-
end
401-
402-
function _make_list_models_action()
403-
return Dict("type" => "ListSourceAction")
404-
end
405-
406378
function _make_list_edb_action()
407379
return Dict("type" => "ListEdbAction")
408380
end
@@ -695,19 +667,6 @@ function list_edbs(ctx::Context, database::AbstractString, engine::AbstractStrin
695667
return rsp.actions[1].result.rels
696668
end
697669

698-
function _list_models(ctx::Context, database::AbstractString, engine::AbstractString; kw...)
699-
tx = Transaction(ctx.region, database, engine, "OPEN"; readonly = true)
700-
data = body(tx, _make_list_models_action())
701-
rsp = _post(ctx, PATH_TRANSACTION; query = query(tx), body = data, kw...).actions
702-
length(rsp) == 0 && return []
703-
return rsp[1].result.sources
704-
end
705-
706-
function list_models(ctx::Context, database::AbstractString, engine::AbstractString; kw...)
707-
models = _list_models(ctx, database, engine; kw...)
708-
return [model["name"] for model in models]
709-
end
710-
711670
function _gen_literal(value)
712671
return "$value"
713672
end
@@ -777,13 +736,82 @@ function load_json(ctx::Context, database::AbstractString, engine::AbstractStrin
777736
return exec(ctx, database, engine, source; inputs = inputs, readonly = false, kw...)
778737
end
779738

780-
function load_model(ctx::Context, database::AbstractString, engine::AbstractString, models::Dict; kw...)
781-
tx = Transaction(ctx.region, database, engine, "OPEN"; readonly = false)
782-
actions = [_make_load_model_action(name, model) for (name, model) in models]
783-
return _post(ctx, PATH_TRANSACTION; query = query(tx), body = body(tx, actions...), kw...)
739+
function load_models(ctx::Context, database::AbstractString, engine::AbstractString, models::Dict; kw...)
740+
queries = []
741+
queries_inputs = Dict()
742+
rand_uint = rand(UInt64)
743+
744+
index = 0
745+
for (name, value) in models
746+
input_name = string("input_", rand_uint, "_", index)
747+
push!(queries, """
748+
def delete:rel:catalog:model["$name"] = rel:catalog:model["$name"]
749+
def insert:rel:catalog:model["$name"] = $input_name
750+
""")
751+
752+
queries_inputs[input_name] = value
753+
index+=1
754+
end
755+
756+
return exec(ctx, database, engine, join(queries, "\n"); inputs = queries_inputs, readonly = false, kw...)
784757
end
785758

759+
function load_models_async(ctx::Context, database::AbstractString, engine::AbstractString, models::Dict; kw...)
760+
queries = []
761+
queries_inputs = Dict()
762+
rand_uint = rand(UInt64)
763+
764+
index = 0
765+
for (name, value) in models
766+
input_name = string("input_", rand_uint, "_", index)
767+
push!(queries, """
768+
def delete:rel:catalog:model["$name"] = rel:catalog:model["$name"]
769+
def insert:rel:catalog:model["$name"] = $input_name
770+
""")
786771

772+
queries_inputs[input_name] = value
773+
index+=1
774+
end
775+
776+
return exec_async(ctx, database, engine, join(queries, "\n"); inputs = queries_inputs, readonly = false, kw...)
777+
end
778+
779+
function list_models(ctx::Context, database::AbstractString, engine::AbstractString; kw...)
780+
out_name = "model$(rand(UInt64))"
781+
query = """ def output:$out_name[name] = rel:catalog:model(name, _) """
782+
resp = exec(ctx, database, engine, query)
783+
for result in resp.results
784+
if occursin("/:output/:$out_name", result.first)
785+
return [name for name in result.second.v1]
786+
end
787+
end
788+
end
789+
790+
function get_model(ctx::Context, database::AbstractString, engine::AbstractString, name::AbstractString; kw...)
791+
out_name = "model$(rand(UInt64))"
792+
query = """def output:$out_name = rel:catalog:model[$(_escape_string_for_rel(name))]"""
793+
resp = exec(ctx, database, engine, query)
794+
for result in resp.results
795+
if occursin("/:output/:$out_name", result.first)
796+
return first(result.second.v1)
797+
end
798+
end
799+
throw(HTTPError(404))
800+
end
801+
802+
function delete_models(ctx::Context, database::AbstractString, engine::AbstractString, models::Vector{String}; kw...)
803+
queries = ["""
804+
def delete:rel:catalog:model[$(_escape_string_for_rel(model))] = rel:catalog:model[$(_escape_string_for_rel(model))]
805+
""" for model in models]
806+
return exec(ctx, database, engine, join(queries, "\n"); readonly=false, kw...)
807+
end
808+
809+
function delete_models_async(ctx::Context, database::AbstractString, engine::AbstractString, model::AbstractString; kw...)
810+
queries = ["""
811+
def delete:rel:catalog:model[$(_escape_string_for_rel(model))] = rel:catalog:model[$(_escape_string_for_rel(model))
812+
""" for model in models]
813+
return exec_async(ctx, database, engine, join(queries, "\n"); readonly=false, kw...)
814+
end
787815

788816
# --- utils -------------------------
789817
# Patch for older versions of HTTP package that don't support parsing multipart responses:

test/integration.jl

Lines changed: 66 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -278,7 +278,72 @@ with_engine(CTX) do engine_name
278278

279279
# -----------------------------------
280280
# models
281-
@testset "models" begin end
281+
@testset "models" begin
282+
models = list_models(CTX, database_name, engine_name)
283+
@test length(models) > 0
284+
285+
models = Dict("test_model" => "def foo = :bar")
286+
resp = load_models(CTX, database_name, engine_name, models)
287+
@test resp.transaction.state == "COMPLETED"
288+
289+
value = get_model(CTX, database_name, engine_name, "test_model")
290+
@test models["test_model"] == value
291+
292+
models = list_models(CTX, database_name, engine_name)
293+
@test "test_model" in models
294+
295+
resp = delete_models(CTX, database_name, engine_name, ["test_model"])
296+
@test resp.transaction.state == "COMPLETED"
297+
@test length(resp.problems) == 0
298+
299+
models = list_models(CTX, database_name, engine_name)
300+
@test !("test_model" in models)
301+
302+
# test escape special rel character
303+
models = Dict("percent" => "def foo = \"98%\"")
304+
resp = load_models(CTX, database_name, engine_name, models)
305+
@test resp.transaction.state == "COMPLETED"
306+
@test length(resp.problems) > 0
307+
resp = delete_models(CTX, database_name, engine_name, ["percent"])
308+
@test resp.transaction.state == "COMPLETED"
309+
@test length(resp.problems) == 0
310+
311+
models = Dict("percent" => "def foo = \"98\\%\"")
312+
resp = load_models(CTX, database_name, engine_name, models)
313+
@test resp.transaction.state == "COMPLETED"
314+
@test length(resp.problems) == 0
315+
value = get_model(CTX, database_name, engine_name, "percent")
316+
@test models["percent"] == value
317+
318+
models = list_models(CTX, database_name, engine_name)
319+
@test "percent" in models
320+
321+
resp = delete_models(CTX, database_name, engine_name, ["percent"])
322+
@test resp.transaction.state == "COMPLETED"
323+
@test length(resp.problems) == 0
324+
325+
models = list_models(CTX, database_name, engine_name)
326+
@test !("percent" in models)
327+
328+
# test escape """
329+
models = Dict("triple_quoted" => "def foo = \"\"\"98\\%\"\"\" ")
330+
resp = load_models(CTX, database_name, engine_name, models)
331+
@test resp.transaction.state == "COMPLETED"
332+
@test length(resp.problems) == 0
333+
value = get_model(CTX, database_name, engine_name, "triple_quoted")
334+
@test models["triple_quoted"] == value
335+
336+
models = list_models(CTX, database_name, engine_name)
337+
@test "triple_quoted" in models
338+
@test length(resp.problems) == 0
339+
340+
resp = delete_models(CTX, database_name, engine_name, ["triple_quoted"])
341+
@test resp.transaction.state == "COMPLETED"
342+
@test length(resp.problems) == 0
343+
344+
models = list_models(CTX, database_name, engine_name)
345+
@test !("triple_quoted" in models)
346+
end
282347
end
283348
end
284349

0 commit comments

Comments
 (0)