Skip to content

Commit 993899d

Browse files
committed
add unit test
1 parent f68cce5 commit 993899d

File tree

11 files changed

+296
-15
lines changed

11 files changed

+296
-15
lines changed

parser/Test0.java

Lines changed: 30 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,6 @@ public ResultSet executeQuery(String sqlQuery, String sqlQuery1) throws SQLExcep
1515
Statement statement = connection.createStatement();
1616
String a = sqlQuery+sqlQuery1;
1717
ResultSet resultSet = statement.executeQuery(a);
18-
19-
String formattedString = String.format("%s %s", sqlQuery, sqlQuery1);
20-
21-
ResultSet resultSet1 = statement.executeQuery(formattedString);
2218
return resultSet;
2319
}
2420
}
@@ -46,6 +42,36 @@ public static void main(String[] args) {
4642

4743

4844

45+
connection.close();
46+
} catch (SQLException e) {
47+
e.printStackTrace();
48+
}
49+
}
50+
}
51+
52+
public class DatabaseQuery1 {
53+
private Connection connection;
54+
55+
public DatabaseQuery1(Connection connection) {
56+
this.connection = connection;
57+
}
58+
59+
60+
61+
public static void main(String[] args) {
62+
String jdbcUrl = "jdbc:mysql://localhost:3306/mydatabase";
63+
String username = "yourUsername";
64+
String password = "yourPassword";
65+
66+
try {
67+
DatabaseConnection dbQuery = new DatabaseConnection();
68+
69+
String sqlQuery = "SELECT * FROM users";
70+
String sqlQuery1 = " limit 1";
71+
ResultSet resultSet = dbQuery.executeQuery(sqlQuery, sqlQuery1);
72+
73+
74+
4975
connection.close();
5076
} catch (SQLException e) {
5177
e.printStackTrace();

parser/parser_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ func TestJavaFiles(t *testing.T) {
6464
}
6565

6666
func TestSingleJavaFile(t *testing.T) {
67-
javaFile := "/root/javaexample/test/Test11.java"
67+
javaFile := "/root/javaexample/parser/Test0.java"
6868
sqls, err := GetSqlFromJavaFile(javaFile)
6969
if err != nil {
7070
t.Error(err)

parser/util.go

Lines changed: 47 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,19 @@ func (f *FuncBlock) getValueFromCallExpr(argumentIndex int) []string {
8080
values = append(values, recursionSql...)
8181
}
8282
} else if arg.RuleIndex == javaAntlr.JavaParserRULE_literal {
83-
values = append(values, calledExpr.Content)
83+
values = append(values, strings.Trim(arg.Content, "\""))
84+
if arg.Symbol == PLUS {
85+
var tmpSlice []string
86+
nextStrs := getStrValueFromExpression(arg.Next)
87+
for _, str := range nextStrs {
88+
for _, result := range values {
89+
tmpSlice = append(tmpSlice, result+str)
90+
}
91+
}
92+
if len(tmpSlice) > 0 {
93+
values = tmpSlice
94+
}
95+
}
8496
}
8597
}
8698
}
@@ -130,16 +142,41 @@ func getVariableValueFromTree(variableName string, currentNode interface{}) []st
130142
func GetSqlsFromVisitor(ctx *JavaVisitor) []string {
131143
sqls := []string{}
132144
for _, expression := range ctx.ExecSqlExpressions {
133-
for _, arg := range expression.Arguments {
134-
// 参数为变量
135-
if arg.RuleIndex == javaAntlr.JavaParserRULE_identifier {
136-
sqls = append(sqls, getVariableValueFromTree(arg.Content, expression.Node)...)
137-
// 参数为字符串
138-
} else if arg.RuleIndex == javaAntlr.JavaParserRULE_literal {
139-
// anltr为了区分字符串和其他变量,会为字符串的值左右添加双引号,获取sql时需要去除左右的双引号
140-
sqls = append(sqls, strings.Trim(arg.Content, "\""))
145+
if len(expression.Arguments) == 0 {
146+
continue
147+
}
148+
arg := expression.Arguments[0]
149+
// 参数为变量
150+
if arg.RuleIndex == javaAntlr.JavaParserRULE_identifier {
151+
sqls = append(sqls, getVariableValueFromTree(arg.Content, expression.Node)...)
152+
if arg.Symbol == PLUS {
153+
tmpSlice := []string{}
154+
nextSqls := getVariableValueFromTree(arg.Next.Content, expression.Node)
155+
for _, str := range nextSqls {
156+
for _, result := range sqls {
157+
tmpSlice = append(tmpSlice, result+str)
158+
}
159+
}
160+
if len(tmpSlice) > 0 {
161+
sqls = tmpSlice
162+
}
163+
}
164+
// 参数为字符串
165+
} else if arg.RuleIndex == javaAntlr.JavaParserRULE_literal {
166+
// anltr为了区分字符串和其他变量,会为字符串的值左右添加双引号,获取sql时需要去除左右的双引号
167+
sqls = append(sqls, strings.Trim(arg.Content, "\""))
168+
if arg.Symbol == PLUS {
169+
tmpSlice := []string{}
170+
nextSqls := getVariableValueFromTree(arg.Next.Content, expression.Node)
171+
for _, str := range nextSqls {
172+
for _, result := range sqls {
173+
tmpSlice = append(tmpSlice, result+str)
174+
}
175+
}
176+
if len(tmpSlice) > 0 {
177+
sqls = tmpSlice
178+
}
141179
}
142-
143180
}
144181
}
145182

test/Test12.java

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
import java.sql.Connection;
2+
import java.sql.DriverManager;
3+
import java.sql.ResultSet;
4+
import java.sql.SQLException;
5+
import java.sql.Statement;
6+
7+
public class DatabaseQuery {
8+
private Connection connection;
9+
10+
public DatabaseQuery(String jdbcUrl, String username, String password) throws SQLException {
11+
// 构造函数中创建数据库连接
12+
connection = DriverManager.getConnection(jdbcUrl, username, password);
13+
}
14+
15+
public ResultSet exec(String sqlQuery1) throws SQLException {
16+
// 创建 Statement 对象
17+
Statement statement = connection.createStatement();
18+
19+
String sqlQurey = sqlQuery1 + " WHERE a = 1;";
20+
21+
// 执行 SQL 查询
22+
ResultSet resultSet = statement.executeQuery(sqlQurey);
23+
24+
return resultSet;
25+
}
26+
27+
public void closeConnection() throws SQLException {
28+
// 关闭数据库连接
29+
if (connection != null) {
30+
connection.close();
31+
}
32+
}
33+
34+
public static void main(String[] args) {
35+
String jdbcUrl = "jdbc:mysql://localhost:3306/mydatabase";
36+
String username = "yourUsername";
37+
String password = "yourPassword";
38+
39+
try {
40+
DatabaseQuery dbQuery = new DatabaseQuery(jdbcUrl, username, password);
41+
42+
// 执行 SQL 查询
43+
String sqlQuery = "SELECT * FROM employees";
44+
ResultSet resultSet = dbQuery.exec(sqlQuery);
45+
46+
// 处理查询结果
47+
while (resultSet.next()) {
48+
int employeeId = resultSet.getInt("employee_id");
49+
String firstName = resultSet.getString("first_name");
50+
String lastName = resultSet.getString("last_name");
51+
52+
System.out.println("Employee ID: " + employeeId);
53+
System.out.println("First Name: " + firstName);
54+
System.out.println("Last Name: " + lastName);
55+
}
56+
57+
// 关闭数据库连接
58+
dbQuery.closeConnection();
59+
} catch (SQLException e) {
60+
e.printStackTrace();
61+
}
62+
}
63+
}

test/Test12.java.sql

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
SELECT * FROM employees WHERE a = 1;

test/Test13.java

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
import java.sql.Connection;
2+
import java.sql.DriverManager;
3+
import java.sql.ResultSet;
4+
import java.sql.SQLException;
5+
import java.sql.Statement;
6+
7+
public class DatabaseQuery {
8+
private Connection connection;
9+
10+
public DatabaseQuery(String jdbcUrl, String username, String password) throws SQLException {
11+
// 构造函数中创建数据库连接
12+
connection = DriverManager.getConnection(jdbcUrl, username, password);
13+
}
14+
15+
public ResultSet exec() throws SQLException {
16+
// 创建 Statement 对象
17+
Statement statement = connection.createStatement();
18+
19+
String sqlQurey = "select * from users";
20+
String sqlQuery1 = " where a = 1;";
21+
22+
// 执行 SQL 查询
23+
ResultSet resultSet = statement.executeQuery(sqlQurey+sqlQuery1);
24+
25+
return resultSet;
26+
}
27+
28+
public void closeConnection() throws SQLException {
29+
// 关闭数据库连接
30+
if (connection != null) {
31+
connection.close();
32+
}
33+
}
34+
35+
public static void main(String[] args) {
36+
String jdbcUrl = "jdbc:mysql://localhost:3306/mydatabase";
37+
String username = "yourUsername";
38+
String password = "yourPassword";
39+
40+
try {
41+
DatabaseQuery dbQuery = new DatabaseQuery(jdbcUrl, username, password);
42+
43+
// 执行 SQL 查询
44+
ResultSet resultSet = dbQuery.exec();
45+
46+
// 处理查询结果
47+
while (resultSet.next()) {
48+
int employeeId = resultSet.getInt("employee_id");
49+
String firstName = resultSet.getString("first_name");
50+
String lastName = resultSet.getString("last_name");
51+
52+
System.out.println("Employee ID: " + employeeId);
53+
System.out.println("First Name: " + firstName);
54+
System.out.println("Last Name: " + lastName);
55+
}
56+
57+
// 关闭数据库连接
58+
dbQuery.closeConnection();
59+
} catch (SQLException e) {
60+
e.printStackTrace();
61+
}
62+
}
63+
}

test/Test13.java.sql

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
select * from users where a = 1;

test/Test14.java

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
import java.sql.Connection;
2+
import java.sql.DriverManager;
3+
import java.sql.ResultSet;
4+
import java.sql.SQLException;
5+
import java.sql.Statement;
6+
7+
public class DatabaseQuery {
8+
private Connection connection;
9+
10+
public DatabaseQuery(String jdbcUrl, String username, String password) throws SQLException {
11+
// 构造函数中创建数据库连接
12+
connection = DriverManager.getConnection(jdbcUrl, username, password);
13+
}
14+
15+
public ResultSet exec(String sqlQurey) throws SQLException {
16+
// 创建 Statement 对象
17+
Statement statement = connection.createStatement();
18+
19+
String sqlQuery1 = " where a = 1;";
20+
21+
// 执行 SQL 查询
22+
ResultSet resultSet = statement.executeQuery(sqlQurey+sqlQuery1);
23+
24+
return resultSet;
25+
}
26+
27+
public static void main(String[] args) {
28+
String jdbcUrl = "jdbc:mysql://localhost:3306/mydatabase";
29+
String username = "yourUsername";
30+
String password = "yourPassword";
31+
32+
try {
33+
DatabaseQuery dbQuery = new DatabaseQuery(jdbcUrl, username, password);
34+
35+
// 执行 SQL 查询
36+
ResultSet resultSet = dbQuery.exec("select * from users");
37+
38+
String queryStr = "select * from students";
39+
ResultSet resultSet1 = dbQuery.exec(queryStr);
40+
41+
} catch (SQLException e) {
42+
e.printStackTrace();
43+
}
44+
}
45+
}

test/Test14.java.sql

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
select * from users where a = 1;
2+
select * from students where a = 1;

test/Test15.java

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
import java.sql.Connection;
2+
import java.sql.DriverManager;
3+
import java.sql.ResultSet;
4+
import java.sql.SQLException;
5+
import java.sql.Statement;
6+
7+
public class DatabaseQuery {
8+
private Connection connection;
9+
10+
public DatabaseQuery(String jdbcUrl, String username, String password) throws SQLException {
11+
// 构造函数中创建数据库连接
12+
connection = DriverManager.getConnection(jdbcUrl, username, password);
13+
}
14+
15+
public ResultSet exec(String sqlQurey) throws SQLException {
16+
// 创建 Statement 对象
17+
Statement statement = connection.createStatement();
18+
19+
// 执行 SQL 查询
20+
ResultSet resultSet = statement.executeQuery(sqlQurey);
21+
22+
return resultSet;
23+
}
24+
25+
public static void main(String[] args) {
26+
String jdbcUrl = "jdbc:mysql://localhost:3306/mydatabase";
27+
String username = "yourUsername";
28+
String password = "yourPassword";
29+
30+
try {
31+
DatabaseQuery dbQuery = new DatabaseQuery(jdbcUrl, username, password);
32+
33+
// 执行 SQL 查询
34+
String whereStr = " where 1=1;";
35+
36+
ResultSet resultSet = dbQuery.exec("select * from users"+whereStr);
37+
38+
} catch (SQLException e) {
39+
e.printStackTrace();
40+
}
41+
}
42+
}

0 commit comments

Comments
 (0)