Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

在线查询 支持AI根据描述生成查询语句 #2726

Merged
merged 7 commits into from
Jul 18, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 31 additions & 1 deletion common/templates/config.html
Original file line number Diff line number Diff line change
Expand Up @@ -334,7 +334,7 @@ <h5 style="color: darkgrey"><b>SQL上线</b></h5>
</div>
</div>
<h5 style="color: darkgrey"><b>SQL查询</b></h5>
<h6 style="color:red">注:开启脱敏功能必须要配置goInception信息,用于SQL语法解析</h6>
<h6 style="color:red">注:开启脱敏功能必须要配置goInception信息,用于SQL语法解析;若无OPENAI配置则不开启AI生成SQL语句的功能</h6>
<hr/>
<div class="form-horizontal">
<div class="form-group">
Expand Down Expand Up @@ -399,6 +399,36 @@ <h6 style="color:red">注:开启脱敏功能必须要配置goInception信息
placeholder="管理员/DBA查询结果集限制" />
</div>
</div>
<div class="form-group">
<label for="openai_base_url"
class="col-sm-4 control-label">OPENAI_BASE_URL</label>
<div class="col-sm-5">
<input type="text" class="form-control" id="openai_base_url"
key="openai_base_url"
value="{{ config.openai_base_url }}"
placeholder="openai base url" />
</div>
</div>
<div class="form-group">
<label for="openai_api_key"
class="col-sm-4 control-label">OPENAI_API_KEY</label>
<div class="col-sm-5">
<input type="text" class="form-control" id="openai_api_key"
key="openai_api_key"
value="{{ config.openai_api_key }}"
placeholder="openai api key" />
</div>
</div>
<div class="form-group">
<label for="default_chat_model"
class="col-sm-4 control-label">DEFAULT_CHAT_MODEL</label>
<div class="col-sm-5">
<input type="text" class="form-control" id="default_chat_model"
key="default_chat_model"
value="{{ config.default_chat_model }}"
placeholder="openai default chat model" />
</div>
</div>
<h5 style="color: darkgrey"><b>SQL优化</b></h5>
<hr/>
<div class="form-group">
Expand Down
32 changes: 32 additions & 0 deletions common/utils/openai.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
from openai import OpenAI
import logging
from common.config import SysConfig

logger = logging.getLogger("default")


class OpenaiClient:
def __init__(self):
all_config = SysConfig()
self.base_url = all_config.get("openai_base_url", "")
self.api_key = all_config.get("openai_api_key", "")
self.default_chat_model = all_config.get("default_chat_model", "")
LeoQuote marked this conversation as resolved.
Show resolved Hide resolved
self.client = OpenAI(base_url=self.base_url, api_key=self.api_key)

def request_chat_completion(self, messages, **kwargs):
"""chat_completion"""
completion = self.client.chat.completions.create(
model=self.default_chat_model, messages=messages, **kwargs
)
return completion

def generate_sql_by_openai(self, db_type: str, table_schema: str, query_desc: str):
"""根据传入的基本信息生成查询语句"""
tips = f"你是一个熟悉 {db_type} 的工程师, 我会给你一些基本信息和要求, 你会生成一个查询语句给我使用, 不要返回任何注释和序号, 仅返回查询语句"
QSummerY marked this conversation as resolved.
Show resolved Hide resolved
messages = [dict(role="user", content=f"{tips}: {table_schema}\n{query_desc}")]
logger.info(messages)
try:
res = self.request_chat_completion(messages)
return res.choices[0].message.content
except Exception as e:
raise ValueError(f"请求openai生成查询语句失败: {e}")
2 changes: 2 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -42,3 +42,5 @@ mozilla-django-oidc==3.0.0
django-auth-dingding==0.0.3
django-cas-ng==4.3.0
cassandra-driver
httpx
OpenAI
76 changes: 76 additions & 0 deletions sql/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from django.http import HttpResponse
from common.config import SysConfig
from common.utils.extend_json_encoder import ExtendJSONEncoder, ExtendJSONEncoderFTime
from common.utils.openai import OpenaiClient
from common.utils.timer import FuncTimer
from sql.query_privileges import query_priv_check
from sql.utils.resource_group import user_instances
Expand Down Expand Up @@ -313,3 +314,78 @@ def kill_query_conn(instance_id, thread_id):
instance = Instance.objects.get(pk=instance_id)
query_engine = get_engine(instance)
query_engine.kill_connection(thread_id)


@permission_required("sql.menu_sqlquery", raise_exception=True)
def generate_sql(request):
"""
利用AI生成查询SQL, 传入数据基本结构和查询描述
:param request:
:return:
"""
db_type = request.POST.get("db_type")
query_desc = request.POST.get("query_desc")
if not db_type or not query_desc:
return HttpResponse(
json.dumps({"status": 1, "msg": "db_type or query_desc不存在", "data": []}),
content_type="application/json",
)

instance_name = request.POST.get("instance_name")
try:
instance = Instance.objects.get(instance_name=instance_name)
except Instance.DoesNotExist:
return HttpResponse(
json.dumps({"status": 1, "msg": "实例不存在", "data": []}),
content_type="application/json",
)
db_name = request.POST.get("db_name")
schema_name = request.POST.get("schema_name")
tb_name = request.POST.get("tb_name")

result = {"status": 0, "msg": "ok", "data": ""}
try:
query_engine = get_engine(instance=instance)
query_result = query_engine.describe_table(
db_name, tb_name, schema_name=schema_name
)
openai_client = OpenaiClient()
# 有些不存在表结构, 例如 redis
if len(query_result.rows) != 0:
result["data"] = openai_client.generate_sql_by_openai(
db_type, query_result.rows[0][-1], query_desc
)
else:
result["data"] = openai_client.generate_sql_by_openai(
db_type, "", query_desc
)
except Exception as msg:
result["status"] = 1
result["msg"] = str(msg)
return HttpResponse(json.dumps(result), content_type="application/json")


def check_openai(request):
"""
校验openai配置是否存在
:param request:
:return:
"""
openai_config = ["openai_base_url", "openai_api_key", "default_chat_model"]
for key in openai_config:
if not (SysConfig().get(key)):
return HttpResponse(
json.dumps(
{
"status": 1,
"msg": f"openai 配置{key}不存在, 不支持此功能",
"data": False,
}
),
content_type="application/json",
)

return HttpResponse(
json.dumps({"status": 0, "msg": "ok", "data": True}),
content_type="application/json",
)
83 changes: 83 additions & 0 deletions sql/templates/sqlquery.html
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,8 @@ <h4 class="modal-title text-danger">收藏语句</h4>
<option value={{ sql.id }}>{{ sql.alias }}</option>
{% endfor %}
</select>
<input id="generateDesc" class="form-control" style="display: none" placeholder="AI 查询描述" />
<input id="btn-generatesql" type="button" class="btn btn-info" style="display: none" value="生成SQL"/>
QSummerY marked this conversation as resolved.
Show resolved Hide resolved
</div>
<div class="panel-body">
<form id="form-sqlquery" action="/sqlquery/" method="post" class="form-horizontal" role="form">
Expand Down Expand Up @@ -495,6 +497,27 @@ <h4 class="modal-title text-danger">收藏语句</h4>
}
sessionStorage.removeItem('re_query');
}

// 获取sysconfig
function check_openai() {
$.ajax({
type: "get",
url: "/check/openai/",
dataType: "json",
data: false,
complete: function () {
},
success: function (data) {
if (data["data"]) {
$("#generateDesc").show()
$("#btn-generatesql").show()
}
},
error: function (XMLHttpRequest, textStatus, errorThrown) {
alert(errorThrown);
}
});
}
</script>
<!-- 执行结果 -->
<script>
Expand Down Expand Up @@ -624,6 +647,32 @@ <h4 class="modal-title text-danger">收藏语句</h4>
return result;
}

//提交AI生成sql语句请求
$("#btn-generatesql").click(function () {
var check = false
var optgroup = $('#instance_name :selected').parent().attr('label')
var instance_name = $("#instance_name").val()
var db_name = $("#db_name").val()
var tb_name = $("#table_name").val()
var query_desc = $("#generateDesc").val()

if (!instance_name) {
alert("请选择实例!")
} else if (!db_name) {
alert("请选择数据库!")
} else if (optgroup !== 'Redis' && !tb_name){
alert("请选择表结构!")
} else if (!query_desc) {
alert("请输入查询描述!")
} else {
check = true
}
if (check) {
generatesql()
}
}
);

//先做表单验证,验证成功再成功提交查询请求
$("#btn-sqlquery").click(function () {
dosqlquery();
Expand Down Expand Up @@ -1023,6 +1072,37 @@ <h4 class="modal-title text-danger">收藏语句</h4>
});
}

function generatesql() {
var optgroup = $('#instance_name :selected').parent().attr('label');
const data = {
db_type: optgroup,
instance_name: $("#instance_name").val(),
db_name: $("#db_name").val(),
schema_name: $("#schema_name").val(),
tb_name: $("#table_name").val(),
query_desc: $("#generateDesc").val(),
}
//提交请求
$.ajax({
type: "post",
url: "/query/generate_sql/",
dataType: "json",
data: data,
complete: function () {
$('input[type=button]').removeClass('disabled');
$('input[type=button]').prop('disabled', false);
optgroup_control();
},
success: function (data) {
editor.setValue(data["data"]);
editor.clearSelection();
},
error: function (XMLHttpRequest, textStatus, errorThrown) {
alert(errorThrown);
}
});
}

function dosqlquery() {
if (sqlquery_validate()) {
$('input[type=button]').addClass('disabled');
Expand Down Expand Up @@ -1325,6 +1405,9 @@ <h4 class="modal-title text-danger">收藏语句</h4>
} else {
editor.setValue("");
}

// check openai 配置是否存在以支持AI生成查询语句功能
check_openai()

//默认获取查询历史
get_querylog();
Expand Down
48 changes: 48 additions & 0 deletions sql/test_openai.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
import pytest


def test_check_openai(admin_client, setup_sys_config):
"""校验openai配置"""
setup_sys_config.set("openai_base_url", "https://platform.openai.com/")
response = admin_client.get("/check/openai/")
assert response.status_code == 200
assert response.json()["data"] == False

setup_sys_config.set("openai_api_key", "sk-test-api-key")
response = admin_client.get("/check/openai/")
assert response.status_code == 200
assert response.json()["data"] == False

setup_sys_config.set("default_chat_model", "gpt-3.5-turbo")
response = admin_client.get("/check/openai/")
assert response.status_code == 200
assert response.json()["data"] == True


@pytest.mark.parametrize(
"data, expected_status",
[
(dict(), 1),
(
dict(
db_type="mysql",
query_desc="获取所有用户名为test的记录",
instance_name="test_instance",
),
1,
),
(
dict(
db_type="mysql",
query_desc="获取所有用户名为test的记录",
instance_name="some_ins",
),
1,
),
],
)
def test_generate_sql(admin_client, db_instance, data, expected_status):
"""测试openai生成sql"""
response = admin_client.post("/query/generate_sql/", data=data)
assert response.status_code == 200
assert response.json()["status"] == expected_status
2 changes: 2 additions & 0 deletions sql/urls.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,8 @@
path("query/querylog/", query.querylog),
path("query/querylog_audit/", query.querylog_audit),
path("query/favorite/", query.favorite),
path("query/generate_sql/", query.generate_sql),
path("check/openai/", query.check_openai),
path("query/explain/", sql.sql_optimize.explain),
path("query/applylist/", sql.query_privileges.query_priv_apply_list),
path("query/userprivileges/", sql.query_privileges.user_query_priv),
Expand Down
Loading