Skip to content

Commit 11c9d29

Browse files
feat: ai-rag plugin (#11568)
1 parent 5eb9f6a commit 11c9d29

File tree

11 files changed

+954
-1
lines changed

11 files changed

+954
-1
lines changed

Makefile

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -377,6 +377,11 @@ install: runtime
377377
$(ENV_INSTALL) -d $(ENV_INST_LUADIR)/apisix/plugins/ai-proxy/drivers
378378
$(ENV_INSTALL) apisix/plugins/ai-proxy/drivers/*.lua $(ENV_INST_LUADIR)/apisix/plugins/ai-proxy/drivers
379379

380+
$(ENV_INSTALL) -d $(ENV_INST_LUADIR)/apisix/plugins/ai-rag/embeddings
381+
$(ENV_INSTALL) apisix/plugins/ai-rag/embeddings/*.lua $(ENV_INST_LUADIR)/apisix/plugins/ai-rag/embeddings
382+
$(ENV_INSTALL) -d $(ENV_INST_LUADIR)/apisix/plugins/ai-rag/vector-search
383+
$(ENV_INSTALL) apisix/plugins/ai-rag/vector-search/*.lua $(ENV_INST_LUADIR)/apisix/plugins/ai-rag/vector-search
384+
380385
# ai-content-moderation plugin
381386
$(ENV_INSTALL) -d $(ENV_INST_LUADIR)/apisix/plugins/ai
382387
$(ENV_INSTALL) apisix/plugins/ai/*.lua $(ENV_INST_LUADIR)/apisix/plugins/ai

apisix/cli/config.lua

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -216,6 +216,7 @@ local _M = {
216216
"body-transformer",
217217
"ai-prompt-template",
218218
"ai-prompt-decorator",
219+
"ai-rag",
219220
"ai-content-moderation",
220221
"proxy-mirror",
221222
"proxy-rewrite",

apisix/plugins/ai-rag.lua

Lines changed: 156 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,156 @@
1+
--
2+
-- Licensed to the Apache Software Foundation (ASF) under one or more
3+
-- contributor license agreements. See the NOTICE file distributed with
4+
-- this work for additional information regarding copyright ownership.
5+
-- The ASF licenses this file to You under the Apache License, Version 2.0
6+
-- (the "License"); you may not use this file except in compliance with
7+
-- the License. You may obtain a copy of the License at
8+
--
9+
-- http://www.apache.org/licenses/LICENSE-2.0
10+
--
11+
-- Unless required by applicable law or agreed to in writing, software
12+
-- distributed under the License is distributed on an "AS IS" BASIS,
13+
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
-- See the License for the specific language governing permissions and
15+
-- limitations under the License.
16+
--
17+
local next = next
18+
local require = require
19+
local ngx_req = ngx.req
20+
21+
local http = require("resty.http")
22+
local core = require("apisix.core")
23+
24+
local azure_openai_embeddings = require("apisix.plugins.ai-rag.embeddings.azure_openai").schema
25+
local azure_ai_search_schema = require("apisix.plugins.ai-rag.vector-search.azure_ai_search").schema
26+
27+
local HTTP_INTERNAL_SERVER_ERROR = ngx.HTTP_INTERNAL_SERVER_ERROR
28+
local HTTP_BAD_REQUEST = ngx.HTTP_BAD_REQUEST
29+
30+
-- Plugin configuration schema.
-- Both an embeddings provider and a vector search provider are required. The
-- option schema of each provider is supplied by its driver module.
-- `properties` maps field names to sub-schemas only, so it must not carry
-- schema keywords such as `type` itself.
local schema = {
    type = "object",
    properties = {
        embeddings_provider = {
            type = "object",
            properties = {
                azure_openai = azure_openai_embeddings
            },
            -- ensure only one provider can be configured while implementing support for
            -- other providers
            required = { "azure_openai" },
            maxProperties = 1,
        },
        vector_search_provider = {
            type = "object",
            properties = {
                azure_ai_search = azure_ai_search_schema
            },
            -- ensure only one provider can be configured while implementing support for
            -- other providers
            required = { "azure_ai_search" },
            maxProperties = 1
        },
    },
    required = { "embeddings_provider", "vector_search_provider" }
}
57+
58+
-- Schema for the client request body.
-- The `vector_search` and `embeddings` entries start out as empty placeholders;
-- the access handler assigns the selected drivers' request schemas to them
-- before validating a request.
local request_schema = {
    type = "object",
    properties = {
        ai_rag = {
            type = "object",
            required = { "vector_search", "embeddings" },
            properties = {
                embeddings = {},
                vector_search = {},
            },
        },
    },
}
71+
72+
-- Plugin module table: name, version, priority and the configuration schema.
local _M = {
    name = "ai-rag",
    version = 0.1,
    priority = 1060,
    schema = schema,
}
78+
79+
80+
--- Validate a plugin configuration against the plugin schema.
-- @param conf table plugin configuration to validate
-- @return whatever `core.schema.check` returns for (schema, conf)
function _M.check_schema(conf)
    local check = core.schema.check
    return check(schema, conf)
end
83+
84+
85+
--- Access-phase handler: retrieve context for the request and append it to `messages`.
--
-- Steps:
--   1. read the JSON request body and validate its `ai_rag` section against
--      the request schemas of the configured drivers;
--   2. ask the embeddings driver for an embedding of `ai_rag.embeddings`;
--   3. pass that embedding plus `ai_rag.vector_search` to the vector search driver;
--   4. drop `ai_rag`, append the search result as a `user` message and
--      replace the request body with the re-encoded table.
--
-- @param conf table plugin configuration (already validated by check_schema)
-- @param ctx table request context (not used here)
-- @return nothing on success; on failure an HTTP status and, where available,
--         an error message or upstream response body
function _M.access(conf, ctx)
    local body_tab, err = core.request.get_json_request_body_table()
    if not body_tab then
        return HTTP_BAD_REQUEST, err
    end
    if not body_tab["ai_rag"] then
        -- the request field is spelled with an underscore, so name it that way
        core.log.error("request body must have \"ai_rag\" field")
        return HTTP_BAD_REQUEST
    end

    -- the config schema requires each provider table to be present, so next()
    -- yields the name of a configured provider
    local embeddings_provider = next(conf.embeddings_provider)
    local embeddings_provider_conf = conf.embeddings_provider[embeddings_provider]
    local embeddings_driver = require("apisix.plugins.ai-rag.embeddings." .. embeddings_provider)

    local vector_search_provider = next(conf.vector_search_provider)
    local vector_search_provider_conf = conf.vector_search_provider[vector_search_provider]
    local vector_search_driver = require("apisix.plugins.ai-rag.vector-search." ..
                                         vector_search_provider)

    -- NOTE(review): request_schema is a module-level table shared by all
    -- requests and is mutated here on every call. That is harmless while only
    -- one driver of each kind exists; before adding more providers, confirm
    -- that core.schema.check does not cache validators per schema table.
    request_schema.properties.ai_rag.properties.vector_search = vector_search_driver.request_schema
    request_schema.properties.ai_rag.properties.embeddings = embeddings_driver.request_schema

    local ok, check_err = core.schema.check(request_schema, body_tab)
    if not ok then
        core.log.error("request body fails schema check: ", check_err)
        return HTTP_BAD_REQUEST
    end

    -- create the HTTP client only once the request is known to be valid
    local httpc = http.new()

    local embeddings, emb_status, emb_err = embeddings_driver.get_embeddings(
                                                embeddings_provider_conf,
                                                body_tab["ai_rag"].embeddings, httpc)
    if not embeddings then
        core.log.error("could not get embeddings: ", emb_err)
        return emb_status, emb_err
    end

    local search_body = body_tab["ai_rag"].vector_search
    search_body.embeddings = embeddings
    local res, search_status, search_err = vector_search_driver.search(
                                               vector_search_provider_conf,
                                               search_body, httpc)
    if not res then
        core.log.error("could not get vector_search result: ", search_err)
        return search_status, search_err
    end

    -- remove ai_rag from request body because their purpose is served
    -- also, these values will cause failure when proxying requests to LLM.
    body_tab["ai_rag"] = nil

    if not body_tab.messages then
        body_tab.messages = {}
    end

    local augment = {
        role = "user",
        content = res
    }
    core.table.insert_tail(body_tab.messages, augment)

    local req_body_json, encode_err = core.json.encode(body_tab)
    if not req_body_json then
        return HTTP_INTERNAL_SERVER_ERROR, encode_err
    end

    ngx_req.set_body_data(req_body_json)
end
154+
155+
156+
return _M
apisix/plugins/ai-rag/embeddings/azure_openai.lua

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
--
2+
-- Licensed to the Apache Software Foundation (ASF) under one or more
3+
-- contributor license agreements. See the NOTICE file distributed with
4+
-- this work for additional information regarding copyright ownership.
5+
-- The ASF licenses this file to You under the Apache License, Version 2.0
6+
-- (the "License"); you may not use this file except in compliance with
7+
-- the License. You may obtain a copy of the License at
8+
--
9+
-- http://www.apache.org/licenses/LICENSE-2.0
10+
--
11+
-- Unless required by applicable law or agreed to in writing, software
12+
-- distributed under the License is distributed on an "AS IS" BASIS,
13+
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
-- See the License for the specific language governing permissions and
15+
-- limitations under the License.
16+
--
17+
local core = require("apisix.core")
18+
local HTTP_INTERNAL_SERVER_ERROR = ngx.HTTP_INTERNAL_SERVER_ERROR
19+
local HTTP_OK = ngx.HTTP_OK
20+
local type = type
21+
22+
local _M = {}
23+
24+
-- Provider configuration schema.
-- `endpoint` is the URL that get_embeddings posts to; `api_key` is sent in the
-- `api-key` request header.
_M.schema = {
    type = "object",
    required = { "endpoint", "api_key" },
    properties = {
        api_key = { type = "string" },
        endpoint = { type = "string" },
    },
}
36+
37+
--- Fetch an embedding for the given request payload from the configured endpoint.
--
-- The payload is JSON-encoded and POSTed to `conf.endpoint` with `conf.api_key`
-- in the `api-key` header. The response must be a JSON object whose `data`
-- field is a non-empty array; the `embedding` of its first entry is returned.
--
-- @param conf table provider configuration (`endpoint`, `api_key`)
-- @param body table request payload to send
-- @param httpc HTTP client object exposing `request_uri`
-- @return the `embedding` value of `data[1]` on success
-- @return nil, status, err on failure: the upstream status and body when the
--         upstream answered with a non-200 status, otherwise 500 with an
--         error message or the offending response body
function _M.get_embeddings(conf, body, httpc)
    local req_json, err = core.json.encode(body)
    if not req_json then
        return nil, HTTP_INTERNAL_SERVER_ERROR, err
    end

    local res, err = httpc:request_uri(conf.endpoint, {
        method = "POST",
        headers = {
            ["Content-Type"] = "application/json",
            ["api-key"] = conf.api_key,
        },
        body = req_json
    })

    if not res or not res.body then
        return nil, HTTP_INTERNAL_SERVER_ERROR, err
    end

    if res.status ~= HTTP_OK then
        return nil, res.status, res.body
    end

    local res_tab, err = core.json.decode(res.body)
    if not res_tab then
        return nil, HTTP_INTERNAL_SERVER_ERROR, err
    end

    -- a valid JSON document may still be a scalar or lack a usable `data` array
    if type(res_tab) ~= "table" or type(res_tab.data) ~= "table"
       or core.table.isempty(res_tab.data) then
        return nil, HTTP_INTERNAL_SERVER_ERROR, res.body
    end

    -- report a missing embedding explicitly; returning a bare nil would leave
    -- the caller with neither a status code nor an error message
    local first = res_tab.data[1]
    if type(first) ~= "table" or first.embedding == nil then
        return nil, HTTP_INTERNAL_SERVER_ERROR, res.body
    end

    return first.embedding
end
76+
77+
78+
-- Schema for the embeddings part of a client request: a mandatory string `input`.
_M.request_schema = {
    type = "object",
    required = { "input" },
    properties = {
        input = { type = "string" },
    },
}
87+
88+
return _M
apisix/plugins/ai-rag/vector-search/azure_ai_search.lua

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
--
2+
-- Licensed to the Apache Software Foundation (ASF) under one or more
3+
-- contributor license agreements. See the NOTICE file distributed with
4+
-- this work for additional information regarding copyright ownership.
5+
-- The ASF licenses this file to You under the Apache License, Version 2.0
6+
-- (the "License"); you may not use this file except in compliance with
7+
-- the License. You may obtain a copy of the License at
8+
--
9+
-- http://www.apache.org/licenses/LICENSE-2.0
10+
--
11+
-- Unless required by applicable law or agreed to in writing, software
12+
-- distributed under the License is distributed on an "AS IS" BASIS,
13+
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
-- See the License for the specific language governing permissions and
15+
-- limitations under the License.
16+
--
17+
local core = require("apisix.core")
18+
local HTTP_INTERNAL_SERVER_ERROR = ngx.HTTP_INTERNAL_SERVER_ERROR
19+
local HTTP_OK = ngx.HTTP_OK
20+
21+
local _M = {}
22+
23+
-- Provider configuration schema.
-- `endpoint` is the URL that search posts to; `api_key` is sent in the
-- `api-key` request header.
_M.schema = {
    type = "object",
    required = {"endpoint", "api_key"},
    properties = {
        api_key = { type = "string" },
        endpoint = { type = "string" },
    },
}
35+
36+
37+
--- Run a vector query against the configured search endpoint.
--
-- Builds a single `vectorQueries` entry from `search_body.embeddings` and
-- `search_body.fields`, POSTs it as JSON to `conf.endpoint` with
-- `conf.api_key` in the `api-key` header, and hands back the raw response body.
--
-- @param conf table provider configuration (`endpoint`, `api_key`)
-- @param search_body table carries `embeddings` and `fields`
-- @param httpc HTTP client object exposing `request_uri`
-- @return the response body string on success
-- @return nil, status, err on failure
function _M.search(conf, search_body, httpc)
    local vector_query = {
        kind = "vector",
        vector = search_body.embeddings,
        fields = search_body.fields,
    }

    local payload, encode_err = core.json.encode({ vectorQueries = { vector_query } })
    if not payload then
        return nil, HTTP_INTERNAL_SERVER_ERROR, encode_err
    end

    local request_opts = {
        method = "POST",
        headers = {
            ["Content-Type"] = "application/json",
            ["api-key"] = conf.api_key,
        },
        body = payload,
    }

    local res, request_err = httpc:request_uri(conf.endpoint, request_opts)
    if not (res and res.body) then
        return nil, HTTP_INTERNAL_SERVER_ERROR, request_err
    end

    if res.status == HTTP_OK then
        return res.body
    end

    -- pass the upstream status and body through for non-200 answers
    return nil, res.status, res.body
end
71+
72+
73+
-- Schema for the vector search part of a client request: a mandatory string
-- `fields`, which search forwards as the `fields` of the vector query.
_M.request_schema = {
    type = "object",
    required = { "fields" },
    properties = {
        fields = { type = "string" },
    },
}
82+
83+
return _M

conf/config.yaml.example

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -479,6 +479,7 @@ plugins: # plugin list (sorted by priority)
479479
- body-transformer # priority: 1080
480480
- ai-prompt-template # priority: 1071
481481
- ai-prompt-decorator # priority: 1070
482+
- ai-rag # priority: 1060
482483
- ai-content-moderation # priority: 1040 TODO: compare priority with other ai plugins
483484
- proxy-mirror # priority: 1010
484485
- proxy-rewrite # priority: 1008

docs/en/latest/config.json

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,8 @@
100100
"plugins/degraphql",
101101
"plugins/body-transformer",
102102
"plugins/ai-proxy",
103-
"plugins/attach-consumer-label"
103+
"plugins/attach-consumer-label",
104+
"plugins/ai-rag"
104105
]
105106
},
106107
{

0 commit comments

Comments
 (0)