Skip to content

Commit 695ea3c

Browse files
feat: ai-content-moderation plugin (#11541)
1 parent 1cd688b commit 695ea3c

File tree

12 files changed

+1216
-2
lines changed

12 files changed

+1216
-2
lines changed

Makefile

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -377,6 +377,10 @@ install: runtime
377377
$(ENV_INSTALL) -d $(ENV_INST_LUADIR)/apisix/plugins/ai-proxy/drivers
378378
$(ENV_INSTALL) apisix/plugins/ai-proxy/drivers/*.lua $(ENV_INST_LUADIR)/apisix/plugins/ai-proxy/drivers
379379

380+
# ai-content-moderation plugin
381+
$(ENV_INSTALL) -d $(ENV_INST_LUADIR)/apisix/plugins/ai
382+
$(ENV_INSTALL) apisix/plugins/ai/*.lua $(ENV_INST_LUADIR)/apisix/plugins/ai
383+
380384
$(ENV_INSTALL) bin/apisix $(ENV_INST_BINDIR)/apisix
381385

382386

apisix-master-0.rockspec

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ dependencies = {
8282
"lua-resty-t1k = 1.1.5",
8383
"brotli-ffi = 0.3-1",
8484
"lua-ffi-zlib = 0.6-0",
85-
"api7-lua-resty-aws == 2.0.1-1",
85+
"api7-lua-resty-aws == 2.0.2-1",
8686
}
8787

8888
build = {

apisix/cli/config.lua

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -216,6 +216,7 @@ local _M = {
216216
"body-transformer",
217217
"ai-prompt-template",
218218
"ai-prompt-decorator",
219+
"ai-content-moderation",
219220
"proxy-mirror",
220221
"proxy-rewrite",
221222
"workflow",
Lines changed: 179 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,179 @@
1+
--
2+
-- Licensed to the Apache Software Foundation (ASF) under one or more
3+
-- contributor license agreements. See the NOTICE file distributed with
4+
-- this work for additional information regarding copyright ownership.
5+
-- The ASF licenses this file to You under the Apache License, Version 2.0
6+
-- (the "License"); you may not use this file except in compliance with
7+
-- the License. You may obtain a copy of the License at
8+
--
9+
-- http://www.apache.org/licenses/LICENSE-2.0
10+
--
11+
-- Unless required by applicable law or agreed to in writing, software
12+
-- distributed under the License is distributed on an "AS IS" BASIS,
13+
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
-- See the License for the specific language governing permissions and
15+
-- limitations under the License.
16+
--
17+
local core = require("apisix.core")
18+
local aws_instance = require("resty.aws")()
19+
local http = require("resty.http")
20+
local fetch_secrets = require("apisix.secret").fetch_secrets
21+
22+
local next = next
23+
local pairs = pairs
24+
local unpack = unpack
25+
local type = type
26+
local ipairs = ipairs
27+
local require = require
28+
local HTTP_INTERNAL_SERVER_ERROR = ngx.HTTP_INTERNAL_SERVER_ERROR
29+
local HTTP_BAD_REQUEST = ngx.HTTP_BAD_REQUEST
30+
31+
32+
-- Schema of the `provider.aws_comprehend` section of the plugin configuration.
-- `access_key_id`, `secret_access_key` and `region` are mandatory; `endpoint`
-- optionally overrides the service URL and must start with http:// or https://.
local aws_comprehend_schema = {
    type = "object",
    properties = {
        access_key_id = { type = "string" },
        secret_access_key = { type = "string" },
        -- optional; the rewrite handler forwards it to the AWS credentials
        -- object as `sessionToken`, so declare it here to get it type-checked
        session_token = { type = "string" },
        region = { type = "string" },
        endpoint = {
            type = "string",
            pattern = [[^https?://]]
        },
        ssl_verify = {
            type = "boolean",
            default = true
        }
    },
    required = { "access_key_id", "secret_access_key", "region", }
}
49+
50+
-- Key pattern for `moderation_categories`: only these category names may be
-- given an individual threshold.
local moderation_categories_pattern = "^(PROFANITY|HATE_SPEECH|INSULT|HARASSMENT_OR_ABUSE|"
                                      .. "SEXUAL|VIOLENCE_OR_THREAT)$"

-- Moderation backend selection. Exactly one provider entry is accepted so
-- that the configuration stays unambiguous once more providers are supported.
local provider_schema = {
    type = "object",
    properties = {
        aws_comprehend = aws_comprehend_schema
    },
    maxProperties = 1,
    required = { "aws_comprehend" }
}

-- Optional per-category thresholds, each a number in [0, 1]; keys that do not
-- match the category pattern are rejected.
local moderation_categories_schema = {
    type = "object",
    patternProperties = {
        [moderation_categories_pattern] = { type = "number", minimum = 0, maximum = 1 }
    },
    additionalProperties = false
}

-- Top-level plugin configuration schema.
local schema = {
    type = "object",
    properties = {
        provider = provider_schema,
        moderation_categories = moderation_categories_schema,
        -- overall toxicity threshold in [0, 1]
        moderation_threshold = { type = "number", minimum = 0, maximum = 1, default = 0.5 },
        -- selects the `apisix.plugins.ai.<llm_provider>` helper module
        llm_provider = { type = "string", enum = { "openai" } }
    },
    required = { "provider", "llm_provider" },
}
89+
90+
91+
-- Plugin descriptor: name, version, priority and configuration schema.
local _M = {
    name = "ai-content-moderation",
    version = 0.1,
    priority = 1040, -- TODO: might change
    schema = schema,
}
97+
98+
99+
--- Validate a plugin configuration against the plugin schema.
-- @param conf  configuration table to validate
-- @return the outcome of `core.schema.check`: a success flag followed by an
--         error value when validation fails
function _M.check_schema(conf)
    local ok, err = core.schema.check(schema, conf)
    return ok, err
end
102+
103+
104+
--- Rewrite-phase handler: moderate the chat messages of the request body.
--
-- The `messages` array of the JSON request body is converted to text segments
-- by the configured `llm_provider` helper and submitted to AWS Comprehend
-- `detectToxicContent`. The request is rejected when a label score exceeds its
-- configured `moderation_categories` threshold or when the overall toxicity
-- exceeds `moderation_threshold`.
--
-- @param conf  plugin configuration (secret references are resolved first)
-- @param ctx   request context (not used by this handler)
-- @return nothing when the request is allowed; otherwise an HTTP status code
--         and an error message
function _M.rewrite(conf, ctx)
    conf = fetch_secrets(conf, true, conf, "")
    if not conf then
        return HTTP_INTERNAL_SERVER_ERROR, "failed to retrieve secrets from conf"
    end

    local body, err = core.request.get_json_request_body_table()
    if not body then
        return HTTP_BAD_REQUEST, err
    end

    local msgs = body.messages
    if type(msgs) ~= "table" or #msgs < 1 then
        return HTTP_BAD_REQUEST, "messages not found in request body"
    end

    -- the schema allows exactly one provider entry; keep its name for logging
    local provider_name, provider = next(conf.provider)

    local credentials = aws_instance:Credentials({
        accessKeyId = provider.access_key_id,
        secretAccessKey = provider.secret_access_key,
        sessionToken = provider.session_token,
    })

    local default_endpoint = "https://comprehend." .. provider.region .. ".amazonaws.com"
    -- the schema only checks the "http(s)://" prefix, so parsing can still fail
    local uri_parts, parse_err = http:parse_uri(provider.endpoint or default_endpoint)
    if not uri_parts then
        core.log.error("failed to parse endpoint of ", provider_name, ": ", parse_err)
        return HTTP_INTERNAL_SERVER_ERROR, "failed to parse moderation provider endpoint"
    end

    local scheme, host, port = unpack(uri_parts)
    local endpoint = scheme .. "://" .. host
    -- NOTE: aws_instance is created once at module load, so these assignments
    -- modify an object that is reused by later calls of this handler
    aws_instance.config.endpoint = endpoint
    aws_instance.config.ssl_verify = provider.ssl_verify

    local comprehend = aws_instance:Comprehend({
        credentials = credentials,
        endpoint = endpoint,
        region = provider.region,
        port = port,
    })

    local ai_module = require("apisix.plugins.ai." .. conf.llm_provider)
    local text_segments = ai_module.create_request_text_segments(msgs)

    local res, err = comprehend:detectToxicContent({
        LanguageCode = "en",
        TextSegments = text_segments,
    })

    if not res then
        -- log the provider name only: the provider table holds the credentials
        -- and is not a loggable value
        core.log.error("failed to send request to ", provider_name, ": ", err)
        return HTTP_INTERNAL_SERVER_ERROR, err
    end

    local results = res.body and res.body.ResultList
    if type(results) ~= "table" or core.table.isempty(results) then
        return HTTP_INTERNAL_SERVER_ERROR, "failed to get moderation results from response"
    end

    local categories = conf.moderation_categories
    for _, result in ipairs(results) do
        if categories and type(result.Labels) == "table" then
            for _, item in pairs(result.Labels) do
                -- labels without a configured threshold are ignored
                local limit = categories[item.Name]
                if limit and type(item.Score) == "number" and item.Score > limit then
                    return HTTP_BAD_REQUEST, "request body exceeds " .. item.Name .. " threshold"
                end
            end
        end

        -- a result without a numeric toxicity score cannot be evaluated;
        -- reject it instead of raising a runtime error on the comparison
        if type(result.Toxicity) ~= "number" then
            return HTTP_INTERNAL_SERVER_ERROR, "failed to get moderation results from response"
        end

        if result.Toxicity > conf.moderation_threshold then
            return HTTP_BAD_REQUEST, "request body exceeds toxicity threshold"
        end
    end
end
178+
179+
return _M

apisix/plugins/ai/openai.lua

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
--
2+
-- Licensed to the Apache Software Foundation (ASF) under one or more
3+
-- contributor license agreements. See the NOTICE file distributed with
4+
-- this work for additional information regarding copyright ownership.
5+
-- The ASF licenses this file to You under the Apache License, Version 2.0
6+
-- (the "License"); you may not use this file except in compliance with
7+
-- the License. You may obtain a copy of the License at
8+
--
9+
-- http://www.apache.org/licenses/LICENSE-2.0
10+
--
11+
-- Unless required by applicable law or agreed to in writing, software
12+
-- distributed under the License is distributed on an "AS IS" BASIS,
13+
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
-- See the License for the specific language governing permissions and
15+
-- limitations under the License.
16+
--
17+
local core = require("apisix.core")
local ipairs = ipairs
local type = type
19+
20+
local _M = {}
21+
22+
23+
--- Build moderation text segments from OpenAI-style chat messages.
--
-- A message whose `content` is a non-empty string yields one segment. When
-- `content` is an array of content parts, every part carrying a non-empty
-- string `text` field yields one segment, so text sent as content parts is
-- moderated as well. Messages without text (for example JSON `null` content)
-- yield no segment instead of an invalid one.
--
-- @param msgs  array of chat messages
-- @return array of `{ Text = <string> }` tables, in message order
function _M.create_request_text_segments(msgs)
    local text_segments = {}
    for _, msg in ipairs(msgs) do
        local content = type(msg) == "table" and msg.content or nil
        if type(content) == "string" then
            if content ~= "" then
                core.table.insert_tail(text_segments, { Text = content })
            end
        elseif type(content) == "table" then
            for _, part in ipairs(content) do
                if type(part) == "table" and type(part.text) == "string"
                   and part.text ~= "" then
                    core.table.insert_tail(text_segments, { Text = part.text })
                end
            end
        end
    end
    return text_segments
end
32+
33+
return _M

conf/config.yaml.example

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -479,6 +479,7 @@ plugins: # plugin list (sorted by priority)
479479
- body-transformer # priority: 1080
480480
- ai-prompt-template # priority: 1071
481481
- ai-prompt-decorator # priority: 1070
482+
- ai-content-moderation # priority: 1040 TODO: compare priority with other ai plugins
482483
- proxy-mirror # priority: 1010
483484
- proxy-rewrite # priority: 1008
484485
- workflow # priority: 1006

docs/en/latest/config.json

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,8 @@
8181
"plugins/ext-plugin-post-req",
8282
"plugins/ext-plugin-post-resp",
8383
"plugins/inspect",
84-
"plugins/ocsp-stapling"
84+
"plugins/ocsp-stapling",
85+
"plugins/ai-content-moderation"
8586
]
8687
},
8788
{

0 commit comments

Comments
 (0)