Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: ai-rag plugin #11568

Merged
merged 22 commits into from
Oct 16, 2024
Merged
Show file tree
Hide file tree
Changes from 19 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -380,6 +380,11 @@ install: runtime
$(ENV_INSTALL) -d $(ENV_INST_LUADIR)/apisix/plugins/ai-proxy/drivers
$(ENV_INSTALL) apisix/plugins/ai-proxy/drivers/*.lua $(ENV_INST_LUADIR)/apisix/plugins/ai-proxy/drivers

$(ENV_INSTALL) -d $(ENV_INST_LUADIR)/apisix/plugins/ai-rag/embeddings
$(ENV_INSTALL) apisix/plugins/ai-rag/embeddings/*.lua $(ENV_INST_LUADIR)/apisix/plugins/ai-rag/embeddings
$(ENV_INSTALL) -d $(ENV_INST_LUADIR)/apisix/plugins/ai-rag/vector-search
$(ENV_INSTALL) apisix/plugins/ai-rag/vector-search/*.lua $(ENV_INST_LUADIR)/apisix/plugins/ai-rag/vector-search

$(ENV_INSTALL) bin/apisix $(ENV_INST_BINDIR)/apisix


Expand Down
1 change: 1 addition & 0 deletions apisix/cli/config.lua
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,7 @@ local _M = {
"body-transformer",
"ai-prompt-template",
"ai-prompt-decorator",
"ai-rag",
"proxy-mirror",
"proxy-rewrite",
"workflow",
Expand Down
154 changes: 154 additions & 0 deletions apisix/plugins/ai-rag.lua
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
--
-- Licensed to the Apache Software Foundation (ASF) under one or more
-- contributor license agreements. See the NOTICE file distributed with
-- this work for additional information regarding copyright ownership.
-- The ASF licenses this file to You under the Apache License, Version 2.0
-- (the "License"); you may not use this file except in compliance with
-- the License. You may obtain a copy of the License at
--
-- http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
--
local next = next
local require = require
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

move to the front?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what do you mean?

local ngx_req = ngx.req

local http = require("resty.http")
local core = require("apisix.core")

local azure_openai_embeddings = require("apisix.plugins.ai-rag.embeddings.azure_openai").schema
local azure_ai_search_schema = require("apisix.plugins.ai-rag.vector-search.azure_ai_search").schema

local HTTP_INTERNAL_SERVER_ERROR = ngx.HTTP_INTERNAL_SERVER_ERROR
local HTTP_BAD_REQUEST = ngx.HTTP_BAD_REQUEST

local schema = {
type = "object",
properties = {
type = "object",
embeddings_provider = {
type = "object",
properties = {
azure_openai = azure_openai_embeddings
},
-- ensure only one provider can be configured while implementing support for
-- other providers
required = { "azure_openai" },
},
vector_search_provider = {
type = "object",
properties = {
azure_ai_search = azure_ai_search_schema
},
-- ensure only one provider can be configured while implementing support for
-- other providers
required = { "azure_ai_search" }
},
},
required = { "embeddings_provider", "vector_search_provider" }
}

local request_schema = {
type = "object",
properties = {
ai_rag = {
type = "object",
properties = {
vector_search = {},
embeddings = {},
},
required = { "vector_search", "embeddings" }
}
}
}

local _M = {
version = 0.1,
priority = 1060,
name = "ai-rag",
schema = schema,
}


function _M.check_schema(conf)
return core.schema.check(schema, conf)
end


function _M.access(conf, ctx)
local httpc = http.new()
local body_tab, err = core.request.get_json_request_body_table()
if not body_tab then
return HTTP_BAD_REQUEST, err
end
if not body_tab["ai_rag"] then
core.log.error("request body must have \"ai-rag\" field")
shreemaan-abhishek marked this conversation as resolved.
Show resolved Hide resolved
return HTTP_BAD_REQUEST
end

local embeddings_provider = next(conf.embeddings_provider)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what if embeddings_provider have multiple properties?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the schema should ensure only one provider can be configured.

local embeddings_provider_conf = conf.embeddings_provider[embeddings_provider]
local embeddings_driver = require("apisix.plugins.ai-rag.embeddings." .. embeddings_provider)

local vector_search_provider = next(conf.vector_search_provider)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what if vector_search_provider have multiple properties?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

also ref #11541 (comment) and adjust related?

local vector_search_provider_conf = conf.vector_search_provider[vector_search_provider]
local vector_search_driver = require("apisix.plugins.ai-rag.vector-search." ..
vector_search_provider)

local vs_req_schema = vector_search_driver.request_schema
local emb_req_schema = embeddings_driver.request_schema

request_schema.properties.ai_rag.properties.vector_search = vs_req_schema
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are there risks with module variables in concurrent scenarios?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think so, can you elaborate?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I checked the process again and there is no risk of race condition. Ignore it.

request_schema.properties.ai_rag.properties.embeddings = emb_req_schema

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

construct a request_body_schema according to the service providers

local ok, err = core.schema.check(request_schema, body_tab)
if not ok then
core.log.error("request body fails schema check: ", err)
return HTTP_BAD_REQUEST
end

local embeddings, status, err = embeddings_driver.get_embeddings(embeddings_provider_conf,
body_tab["ai_rag"].embeddings, httpc)
if not embeddings then
core.log.error("could not get embeddings: ", err)
return status, err
end

local search_body = body_tab["ai_rag"].vector_search
search_body.embeddings = embeddings
shreemaan-abhishek marked this conversation as resolved.
Show resolved Hide resolved
local res, status, err = vector_search_driver.search(vector_search_provider_conf,
search_body, httpc)
if not res then
core.log.error("could not get vector_search result: ", err)
return status, err
end

-- remove ai_rag from request body because their purpose is served
-- also, these values will cause failure when proxying requests to LLM.
body_tab["ai_rag"] = nil
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

remove ai_rag from request body because their purpose is served and also these values will cause failure when proxying requests to LLM.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

comment the purpose in source code?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done


if not body_tab.messages then
body_tab.messages = {}
end

local augment = {
role = "user",
content = res
}
core.table.insert_tail(body_tab.messages, augment)

local req_body_json, err = core.json.encode(body_tab)
if not req_body_json then
return HTTP_INTERNAL_SERVER_ERROR, err
end

ngx_req.set_body_data(req_body_json)
end


return _M
87 changes: 87 additions & 0 deletions apisix/plugins/ai-rag/embeddings/azure_openai.lua
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
--
-- Licensed to the Apache Software Foundation (ASF) under one or more
-- contributor license agreements. See the NOTICE file distributed with
-- this work for additional information regarding copyright ownership.
-- The ASF licenses this file to You under the Apache License, Version 2.0
-- (the "License"); you may not use this file except in compliance with
-- the License. You may obtain a copy of the License at
--
-- http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
--
local core = require("apisix.core")
local HTTP_INTERNAL_SERVER_ERROR = ngx.HTTP_INTERNAL_SERVER_ERROR
local type = type

local _M = {}

_M.schema = {
type = "object",
properties = {
endpoint = {
type = "string",
},
api_key = {
type = "string",
},
},
required = { "endpoint", "api_key" }
}

function _M.get_embeddings(conf, body, httpc)
local body_tab, err = core.json.encode(body)
if not body_tab then
return nil, HTTP_INTERNAL_SERVER_ERROR, err
end

local res, err = httpc:request_uri(conf.endpoint, {
method = "POST",
headers = {
["Content-Type"] = "application/json",
["api-key"] = conf.api_key,
},
body = body_tab
})

if not res or not res.body then
return nil, HTTP_INTERNAL_SERVER_ERROR, err
end

if res.status ~= 200 then
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

use ngx const var instead?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

return nil, res.status, res.body
end

local res_tab, err = core.json.decode(res.body)
if not res_tab then
return nil, HTTP_INTERNAL_SERVER_ERROR, err
end

if type(res_tab.data) ~= "table" or core.table.isempty(res_tab.data) then
return nil, HTTP_INTERNAL_SERVER_ERROR, res.body
end

local embeddings, err = core.json.encode(res_tab.data[1].embedding)
if not embeddings then
return nil, HTTP_INTERNAL_SERVER_ERROR, err
end

return res_tab.data[1].embedding
end


_M.request_schema = {
type = "object",
properties = {
input = {
type = "string"
}
},
required = { "input" }
}

return _M
82 changes: 82 additions & 0 deletions apisix/plugins/ai-rag/vector-search/azure_ai_search.lua
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
--
-- Licensed to the Apache Software Foundation (ASF) under one or more
-- contributor license agreements. See the NOTICE file distributed with
-- this work for additional information regarding copyright ownership.
-- The ASF licenses this file to You under the Apache License, Version 2.0
-- (the "License"); you may not use this file except in compliance with
-- the License. You may obtain a copy of the License at
--
-- http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
--
local core = require("apisix.core")
local HTTP_INTERNAL_SERVER_ERROR = ngx.HTTP_INTERNAL_SERVER_ERROR

local _M = {}

_M.schema = {
type = "object",
properties = {
endpoint = {
type = "string",
},
api_key = {
type = "string",
},
},
required = {"endpoint", "api_key"}
}


function _M.search(conf, search_body, httpc)
local body = {
vectorQueries = {
{
kind = "vector",
vector = search_body.embeddings,
fields = search_body.fields
}
}
}
local final_body, err = core.json.encode(body)
if not final_body then
return nil, HTTP_INTERNAL_SERVER_ERROR, err
end

local res, err = httpc:request_uri(conf.endpoint, {
method = "POST",
headers = {
["Content-Type"] = "application/json",
["api-key"] = conf.api_key,
},
body = final_body
})

if not res or not res.body then
return nil, HTTP_INTERNAL_SERVER_ERROR, err
end

if res.status ~= 200 then
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

use ngx const var instead?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

return nil, res.status, res.body
end

return res.body
end


_M.request_schema = {
type = "object",
properties = {
fields = {
type = "string"
}
},
required = { "fields" }
}

return _M
1 change: 1 addition & 0 deletions conf/config.yaml.example
Original file line number Diff line number Diff line change
Expand Up @@ -479,6 +479,7 @@ plugins: # plugin list (sorted by priority)
- body-transformer # priority: 1080
- ai-prompt-template # priority: 1071
- ai-prompt-decorator # priority: 1070
- ai-rag # priority: 1060
- proxy-mirror # priority: 1010
- proxy-rewrite # priority: 1008
- workflow # priority: 1006
Expand Down
3 changes: 2 additions & 1 deletion docs/en/latest/config.json
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,8 @@
"plugins/degraphql",
"plugins/body-transformer",
"plugins/ai-proxy",
"plugins/attach-consumer-label"
"plugins/attach-consumer-label",
"plugins/ai-rag"
]
},
{
Expand Down
Loading
Loading