Skip to content

Commit ce35ee3

Browse files
committed
test: Add unit tests for panic handling and SQL extraction from Java code
1 parent 100f90b commit ce35ee3

File tree

1 file changed

+178
-0
lines changed

1 file changed

+178
-0
lines changed

parser/util_test.go

Lines changed: 178 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,178 @@
1+
package parser
2+
3+
import (
4+
"os"
5+
"testing"
6+
)
7+
8+
// TestSimplePanicFix 简单测试修复是否有效
9+
func TestSimplePanicFix(t *testing.T) {
10+
// 创建一个会导致panic的Expression
11+
expr := &Expression{
12+
Content: "\"SELECT * FROM users\"",
13+
RuleIndex: 52, // JavaParserRULE_literal
14+
Symbol: "+",
15+
Next: nil, // 这里是nil,修复前会导致panic
16+
}
17+
18+
// 测试修复后的函数不会panic
19+
defer func() {
20+
if r := recover(); r != nil {
21+
t.Errorf("函数仍然panic: %v", r)
22+
}
23+
}()
24+
25+
// 调用修复后的函数
26+
result := getStrValueFromExpression(expr)
27+
28+
// 验证函数正常返回(不panic)
29+
if result == nil {
30+
t.Error("结果不应该为nil")
31+
}
32+
33+
// 验证返回了正确的字符串(因为Next为nil,应该返回原始字符串)
34+
if len(result) == 1 && result[0] != "SELECT * FROM users" {
35+
t.Errorf("期望: SELECT * FROM users, 实际: %s", result[0])
36+
}
37+
}
38+
39+
// TestGetSqlFromJavaFile 测试GetSqlFromJavaFile方法的各种场景
40+
func TestGetSqlFromJavaFile(t *testing.T) {
41+
testCases := []struct {
42+
name string
43+
javaCode string
44+
shouldPanic bool
45+
}{
46+
{
47+
name: "包含null拼接的表达式",
48+
javaCode: `
49+
public class TestClass {
50+
public void testMethod() {
51+
String sql = "SELECT * FROM users" + null;
52+
executeQuery(sql);
53+
}
54+
55+
private void executeQuery(String sql) {
56+
// JDBC调用
57+
}
58+
}`,
59+
shouldPanic: false,
60+
},
61+
{
62+
name: "不完整的表达式_加号后换行",
63+
javaCode: `
64+
public class TestClass {
65+
public void testMethod() {
66+
String sql = "SELECT * FROM users" +
67+
executeQuery(sql);
68+
}
69+
70+
private void executeQuery(String sql) {
71+
// JDBC调用
72+
}
73+
}`,
74+
shouldPanic: false, // 修复后不应该panic
75+
},
76+
{
77+
name: "正常的多行表达式",
78+
javaCode: `
79+
public class TestClass {
80+
public void testMethod() {
81+
String sql = "SELECT * FROM users" +
82+
" WHERE id = 1";
83+
executeQuery(sql);
84+
}
85+
86+
private void executeQuery(String sql) {
87+
// JDBC调用
88+
}
89+
}`,
90+
shouldPanic: false,
91+
},
92+
{
93+
name: "复杂的多行拼接",
94+
javaCode: `
95+
public class TestClass {
96+
public void testMethod() {
97+
String sql = "SELECT " +
98+
"id, name, email " +
99+
"FROM users " +
100+
"WHERE status = 'active'";
101+
executeQuery(sql);
102+
}
103+
104+
private void executeQuery(String sql) {
105+
// JDBC调用
106+
}
107+
}`,
108+
shouldPanic: false,
109+
},
110+
{
111+
name: "语法错误的表达式",
112+
javaCode: `
113+
public class TestClass {
114+
public void testMethod() {
115+
String sql = "SELECT * FROM users" + ; // 语法错误
116+
executeQuery(sql);
117+
}
118+
119+
private void executeQuery(String sql) {
120+
// JDBC调用
121+
}
122+
}`,
123+
shouldPanic: false, // 修复后不应该panic,即使语法错误
124+
},
125+
}
126+
127+
for _, tc := range testCases {
128+
t.Run(tc.name, func(t *testing.T) {
129+
// 创建临时文件
130+
tmpFile, err := os.CreateTemp("", "test_*.java")
131+
if err != nil {
132+
t.Fatalf("创建临时文件失败: %v", err)
133+
}
134+
defer os.Remove(tmpFile.Name())
135+
136+
// 写入Java代码
137+
if _, err := tmpFile.WriteString(tc.javaCode); err != nil {
138+
t.Fatalf("写入Java代码失败: %v", err)
139+
}
140+
tmpFile.Close()
141+
142+
// 测试GetSqlFromJavaFile函数不会panic
143+
defer func() {
144+
if r := recover(); r != nil {
145+
if !tc.shouldPanic {
146+
t.Errorf("测试用例 %s 意外panic: %v", tc.name, r)
147+
}
148+
}
149+
}()
150+
151+
// 调用函数
152+
sqls, err := GetSqlFromJavaFile(tmpFile.Name())
153+
if err != nil {
154+
t.Logf("解析错误(可能正常): %v", err)
155+
}
156+
157+
// 验证函数正常返回(不panic)
158+
if sqls == nil {
159+
t.Error("结果不应该为nil")
160+
}
161+
})
162+
}
163+
}
164+
165+
// BenchmarkPanicFix 性能测试确保修复不影响性能
166+
func BenchmarkPanicFix(b *testing.B) {
167+
expr := &Expression{
168+
Content: "\"SELECT * FROM users\"",
169+
RuleIndex: 52,
170+
Symbol: "+",
171+
Next: nil,
172+
}
173+
174+
b.ResetTimer()
175+
for i := 0; i < b.N; i++ {
176+
getStrValueFromExpression(expr)
177+
}
178+
}

0 commit comments

Comments
 (0)