Skip to content

基于BERT预训练模型使用pythorch训练中文句子转结构化sql

Notifications You must be signed in to change notification settings

WECENG/bert-nl2sql

Repository files navigation

自然语言转结构化SQL

将自然语言转为机器可以理解的SQL语言,旨在拉近用户与结构化数据间的距离,实现人机交互体验升级。实现基本思路:

  1. 定义结构化sql
  2. label训练数据
  3. 采用bert预训练模型
  4. 下游模型构建
  5. pytorch编程
  6. 模型训练
  7. 模型预测

目录结构

.

├── bert-base-chinese/ # pytorch使用的bert模型(哈工大)

├── train-datas/ # 存放数据文件 │ ├── train.jsonal # label好的训练数据 │ └── train_test.jsonal # 测试数据

├── result-model/ # 下游模型训练结果

├── doc/ # 存放文档 ├── src/ # 存放源代码 │ ├── utils.py # 存放通用工具函数 │ ├── model.py # 存放模型代码

│ ├── train.py # 模型训练代码

│ └── predict.py # 模型预测代码 ├── gaggle-bert-nl2sql # 在kaggle环境上运行的notebook ├── requirements.txt # 项目依赖的 Python 包列表 └── README.md # 项目说明文档

定义结构化sql

首先,明确算法输入为自然语言,算法输出为结构化sql。大体如下图所示:

nl2sql

从上图可知,需要明确输出,因此需要定义标准结构的sqlsql关键信息包含表名(table_name)查询列(sel_col)聚合操作(agg)条件(cond),其中条件又包含条件列(cond_col)条件操作符(cond_op)条件值(cond_value)以及连接操作符(conn_op)。本方案sql定义如下:

nl2sql

如果采用json格式输出,则输出结果如下:

nl2sql

上图,观察"agg":[5]会发现该结构存在问题,agg聚会操作应该需要能够表示出两部分信息,分别为聚会函数(AVG、MIN等)和聚合列。但是"agg":[5]只能表示出聚合函数,并无法表示出聚合列。同时观察"sel_col":["fee"]"conds": [["date", 2, "2024"],["companyName", 2, "分公司A"]]出现了列名feedatecompanyName不利于算法训练。因此为解决上述问题,需要对表中所有的列进行编码,以便利用编码后的索引,本方案采用简单的编码方式如下图所示:

nl2sql

已对表中的所有列进行编码,若要使用其列编号,则agg需要显示所有列的聚合函数。解决上述问题后,尚存在一个问题,即如何对条件列表进行编码,该问题要关注的重点有两个,(1)条件的数量如何确定,(2)条件列可重复出现

因为每个question解析成对应的sql对应的cond数量各异,为了方便编码,采用最大值,默认cond数量最大不超过question长度,由此可进行编码。条件列采用亦是采用最大question长度进行编码,因此即使重复出现,编码后亦可体现出来。

假设question最大长度为16,最终输出结果如下所示:

nl2sql

上图可见,agg的长度为6,即tabel column的长度,agg的值表示要进行的聚合操作,cond_*的长度为question的最大长度,cond_cols的值表示条件列,如0表示date这一列,6则表示不存在的列,即这个cond不存在。cond_opscond_vals也是这个编码逻辑。

从最终的输出可以看出,本方法对NL2SQL问题的处理方案采用的是解决多分类问题

label训练数据

在已经确定算法输入输出的前提下,可以进行标签数据收集。标签数据是指在监督学习中用于训练和评估模型的一组已知输出。在监督学习任务中,我们训练模型来学习输入数据和相应输出标签之间的映射关系。这个映射关系用于模型对新的、之前未见过的输入数据进行预测。

结构化SQL可知需要收集的标签数据有两部分:(1)表及其所有需训练的列 (2)问题及对应的结构化SQL

  • 表及其所有需训练的列 需要收集希望被训练及预测的表及其列,并对列进行编码。如下所示:

    nl2sql

  • 问题及对应的结构化SQL 若使用json格式的训练数据,最终效果如下:

    {
    "question": "2024年分公司A的合同总额",
    "table_id": "contract",
    "sql": {
    "sel_col": ["fee"],
    "agg": [[5,"fee"]],
    "limit": 0,
    "orderby": [],
    "asc_desc": 0,
    "conn_op": 1,
    "conds": [
    ["date", 2, "2024"],
    ["companyName", 2, "分公司A"]
    ]
    },
    "keywords": {
    "sel_cols": ["总额"],
    "values": ["2024", "分公司A"]
    }
    }

    结构化SQL可知,上述训练数据还不能直接用于训练,需要进行编码转换,最终:

    {
    "question": "2024年分公司A的合同总额",
    "table_id": "contract",
    "sql": {
    "agg": [6, 6, 6, 6, 6, 5],
    "limit": 0,
    "orderby": [],
    "asc_desc": 0,
    "conn_op": 1,
    "cond_cols": [0, 2, 6, 6, 6, 6
    , 6, 6, 6, 6, 6, 6, 6, 6], #6表示不存在的列
    "cond_ops": [2, 2, 7, 7, 7, 7
    , 7, 7, 7, 7, 7, 7, 7, 7],
    "cond_vals": [1,1,1,1,0,2,2,2,0,0,0,0,0,0,0,0]
    },
    "keywords": {
    "sel_cols": ["总额"],
    "values": ["2024", "分公司A"]
    }
    }

采用bert预训练模型

解决自然语言转结构化SQL问题涉及到自然语言识别(NLP),同时考虑到label数据数据量小的问题,采用已经预训练好的模型bert作为本算法的上游模型,然后本算法模型在bert模型的基础上搭建下游模型,该下游模型主要处理的问题是结构化SQL中所分析出来的多分类问题

bert模型使用

下游模型构建

结构化SQL可知,下游模型需要处理的是多分类问题,主要需要训练的内容是**agg聚合操作、conn_op连接符、cond_op条件操作符、cond_vals条件操作值**。主要存在的问题是查询列、条件数量是不固定的,由于已经对列进行的确定的编码,因此查询列、条件数量的最大值是确定的。因此**aggconn_opcond_colcond_opcond_vals会对所有列进行分类训练**。

将模型拆分成两部分,第一部分训练&预测aggconn_op,第二部分训练&预测conds,整体结构如下图所示:

nl2sql

model A

模型A负责训练aggconn_opcond_ops。因为采用bert作为上游训练模型,bert的输入要求一个tokenization。需要将questiontable column按顺序拼接成一个tokenization作为bert模型的输入。

nl2sql

模型A的架构如下:

nl2sql

model B

模型B负责训练cond_vals,架构如下:

nl2sql

pytorch编程

pytorch框架已对bert模型做了很好的封装。本方案采用pytorch进行算法编程。

模型编程

模型训练

模型预测

About

基于BERT预训练模型使用pythorch训练中文句子转结构化sql

Topics

Resources

Stars

Watchers

Forks

Releases

No releases published

Packages

No packages published