Skip to content

Commit 1f858a7

Browse files
committed
add parser
1 parent 4187603 commit 1f858a7

File tree

9 files changed

+558
-85
lines changed

9 files changed

+558
-85
lines changed

parser/Test0.java

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
import java.sql.Connection;
2+
import java.sql.ResultSet;
3+
import java.sql.SQLException;
4+
import java.sql.Statement;
5+
6+
public class DatabaseConnection {
7+
private Connection connection;
8+
9+
10+
public Connection getConnection() {
11+
return connection;
12+
}
13+
14+
public ResultSet executeQuery(String sqlQuery, String sqlQuery1) throws SQLException {
15+
Statement statement = connection.createStatement();
16+
String a = sqlQuery+sqlQuery1;
17+
ResultSet resultSet = statement.executeQuery(a);
18+
19+
String formattedString = String.format("%s %s", sqlQuery, sqlQuery1);
20+
21+
ResultSet resultSet1 = statement.executeQuery(formattedString);
22+
return resultSet;
23+
}
24+
}
25+
26+
public class DatabaseQuery {
27+
private Connection connection;
28+
29+
public DatabaseQuery(Connection connection) {
30+
this.connection = connection;
31+
}
32+
33+
34+
35+
public static void main(String[] args) {
36+
String jdbcUrl = "jdbc:mysql://localhost:3306/mydatabase";
37+
String username = "yourUsername";
38+
String password = "yourPassword";
39+
40+
try {
41+
DatabaseConnection dbQuery = new DatabaseConnection();
42+
43+
String sqlQuery = "SELECT * FROM employees";
44+
String sqlQuery1 = " limit 1";
45+
ResultSet resultSet = dbQuery.executeQuery(sqlQuery, sqlQuery1);
46+
47+
48+
49+
connection.close();
50+
} catch (SQLException e) {
51+
e.printStackTrace();
52+
}
53+
}
54+
}

parser/parser.go

Lines changed: 71 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
package parser
22

33
import (
4+
"strings"
5+
46
"github.com/antlr4-go/antlr/v4"
57

68
javaAntlr "github.com/actiontech/java-sql-extractor/java_antlr"
@@ -12,20 +14,56 @@ type JavaVisitor struct {
1214
ExecSqlExpressions []*Expression
1315

1416
currentNode interface{}
17+
currentFile *FileBlock
1518
currentClass *ClassBlock
1619
currentFunc *FuncBlock
20+
currentExpr *Expression
1721

1822
isStatic bool
1923
}
2024

25+
// 获取变量类型,并判断变量类型是否为import引入的类
26+
func (v *JavaVisitor) getClassFromVar(variableName string) (string, bool) {
27+
var variableType string
28+
if v.currentFunc != nil {
29+
for _, vari := range v.currentFunc.Variables {
30+
if vari.Name == variableName {
31+
variableType = vari.Type
32+
}
33+
}
34+
}
35+
if variableType == "" {
36+
return variableType, false
37+
}
38+
for _, importStr := range v.currentFile.ImportClass {
39+
strSlice := strings.Split(importStr, ".")
40+
if strSlice[len(strSlice)-1] == variableName {
41+
return importStr, true
42+
}
43+
}
44+
return variableType, false
45+
}
46+
47+
func (v *JavaVisitor) getClassFromSingleFile(className string) *ClassBlock {
48+
for _, classBlock := range v.currentFile.ClassBlocks {
49+
if classBlock.Name == className {
50+
return classBlock
51+
}
52+
}
53+
return nil
54+
}
55+
2156
type FileBlock struct {
2257
ClassBlocks []*ClassBlock
58+
ImportClass []string
2359
}
2460

2561
type ClassBlock struct {
2662
Name string
2763
FuncBlocks []*FuncBlock
2864
Variables []*Variable // 静态变量
65+
66+
isDeclare bool // 判断该类是在申明状态还是解析状态
2967
}
3068

3169
func (v *ClassBlock) UpdateVariable(variable *Variable) {
@@ -36,10 +74,24 @@ func (v *ClassBlock) UpdateVariable(variable *Variable) {
3674
}
3775
}
3876

77+
func (c *ClassBlock) getFuncFromClass(funcName string) *FuncBlock {
78+
for _, funcBlock := range c.FuncBlocks {
79+
if funcBlock.Name == funcName {
80+
return funcBlock
81+
}
82+
}
83+
return nil
84+
}
85+
3986
type FuncBlock struct {
4087
Name string
4188
CodeBlock *CodeBlock
4289
Variables []*Variable // 函数中所有的变量
90+
Pointer *javaAntlr.MethodDeclarationContext
91+
isDeclare bool // 判断该方法是在申明状态还是解析状态
92+
Params []*Param // 函数参数列表
93+
94+
CalledExprs []*Expression // 调用该函数的表达式列表
4395

4496
Parent interface{}
4597
}
@@ -52,6 +104,11 @@ func (v *FuncBlock) UpdateVariable(variable *Variable) {
52104
}
53105
}
54106

107+
type Param struct {
108+
Content string
109+
Type string
110+
}
111+
55112
type CodeBlock struct {
56113
Variables []*Variable
57114
CodeBlocks []*CodeBlock
@@ -61,16 +118,26 @@ type CodeBlock struct {
61118

62119
type Variable struct {
63120
Name string
64-
Value *Primary
121+
Value *Expression
65122
Level string
66123
Type string
124+
125+
Pointer *javaAntlr.VariableDeclaratorContext
67126
}
68127

69128
type Expression struct {
70-
SubExpr *Expression
71-
Content *Primary
72-
MethodCall string
129+
Next *Expression // 子节点
130+
Before *Expression // 父节点
131+
Content string
132+
RuleIndex int
73133
Arguments []*Expression
134+
MethodCall *FuncBlock
135+
Depth int // 调用深度;例如a.b.c a的调用深度为1 b的调用深度为2 c的调用深度为3
136+
Symbol string // 表达式后面跟的符号;例如+和.
137+
138+
IsImport bool
139+
ImportName string
140+
LocalClass *ClassBlock // 指向表达式a.c()中a的class
74141

75142
Node interface{}
76143
}

parser/parser_test.go

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -64,14 +64,15 @@ func TestJavaFiles(t *testing.T) {
6464
}
6565

6666
func TestSingleJavaFile(t *testing.T) {
67-
sqls, err := GetSqlFromJavaFile("/root/javaexample/test/Test6.java")
67+
javaFile := "/root/javaexample/parser/Test0.java"
68+
sqls, err := GetSqlFromJavaFile(javaFile)
6869
if err != nil {
6970
t.Error(err)
7071
}
71-
sqlFileSqls := getSqlsFromSqlFile("/root/javaexample/test/Test6.java" + ".sql")
72+
sqlFileSqls := getSqlsFromSqlFile(javaFile + ".sql")
7273
for i, sql := range sqls {
7374
if sql != sqlFileSqls[i] {
74-
t.Error(fmt.Errorf("sql parser failed, java file: %s", "/root/javaexample/test/Test6.java"))
75+
t.Error(fmt.Errorf("sql parser failed, java file: %s", javaFile))
7576
}
7677
}
7778
}

parser/symbol.go

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
package parser
2+
3+
const STATIC = "static"
4+
5+
6+
const ASSIGN = "="
7+
const DOT = "."
8+
const PLUS = "+"
9+
10+
// 自定义 . 和 = 符号的ruleindex
11+
const TerminalRuleIndex = -1

0 commit comments

Comments
 (0)