Skip to content

Commit db952d2

Browse files
committed
修正 sql error 参数截断逻辑不生效的 bug,并补充相关单元测试。
1 parent 46b6f2a commit db952d2

File tree

2 files changed

+171
-17
lines changed

2 files changed

+171
-17
lines changed

errors.go

Lines changed: 32 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -27,9 +27,6 @@ var (
2727
ErrExecutingSql = errors.New("dbClient: failed to execute sql")
2828
)
2929

30-
// MaxLengthErrorValue 用于限制错误中 Value 值的长度,超过该大小将会进行截断。
31-
var MaxLengthErrorValue = 64 * 1024
32-
3330
// getExecutingSqlError 用于生成一个带着 SQL 和参数列表的 ErrExecutingSql。
3431
// 错误内容中包含了:原始传入的 SQL,解析后的 SQL,参数列表。
3532
func getExecutingSqlError(err error, rawSql string, fixedSql string, params []any) error {
@@ -56,24 +53,42 @@ func printSqlParams(params []interface{}) string {
5653
switch v := param.(type) {
5754
// 如果是命名参数,打印出 name/value 对。
5855
case sql.NamedArg:
59-
logVal := v.Value
60-
61-
stringVal, ok := v.Value.(fmt.Stringer)
62-
if ok {
63-
logStringValue := stringVal.String()
64-
// string 类型的日志,参考 MaxLengthErrorValue 的值,对输出长度进行截取,以避免 Value 长度过长时候,输出过大的日志。
65-
if len(logStringValue) > MaxLengthErrorValue {
66-
logStringValue = logStringValue[:MaxLengthErrorValue]
67-
}
68-
logVal = logStringValue
69-
}
70-
71-
sb.WriteString(fmt.Sprintf("@%s=%v", v.Name, logVal))
56+
sb.WriteString(fmt.Sprintf("@%s=%v", v.Name, cutLongStringParams(v.Value)))
7257

7358
// 非命名参数,索引按顺序打印。
7459
default:
75-
sb.WriteString(fmt.Sprintf("@p%d=%v", i+1, v))
60+
sb.WriteString(fmt.Sprintf("@p%d=%v", i+1, cutLongStringParams(v)))
7661
}
7762
}
7863
return sb.String()
7964
}
65+
66+
// MaxLengthErrorValue 用于限制错误输出中参数的长度,超过该大小将会进行截断。
67+
// NOTE: 可自行调整该值。
68+
var MaxLengthErrorValue = 64 * 1024
69+
70+
// 本方法用于对参数进行处理,以避免在 error 中输出过大的字符串。
71+
func cutLongStringParams(paramVal any) any {
72+
var paramValString string
73+
switch v := paramVal.(type) {
74+
case string:
75+
paramValString = v
76+
case fmt.Stringer:
77+
paramValString = v.String()
78+
default:
79+
return paramVal
80+
}
81+
82+
if len(paramValString) <= MaxLengthErrorValue {
83+
return paramValString
84+
}
85+
86+
// 超过长度的字符串,截取前 MaxLengthErrorValue 个字符,后面用 ... 填充,并注明参数大小。
87+
var paramStringBuilder strings.Builder
88+
paramStringBuilder.Grow(MaxLengthErrorValue + 24)
89+
paramStringBuilder.WriteString(paramValString[:MaxLengthErrorValue])
90+
paramStringBuilder.WriteString("...")
91+
paramStringBuilder.WriteString(fmt.Sprintf("(length=%d)", len(paramValString)))
92+
93+
return paramStringBuilder.String()
94+
}

errors_test.go

Lines changed: 139 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,139 @@
1+
package sqlmer
2+
3+
import (
4+
"database/sql"
5+
"errors"
6+
"strings"
7+
"testing"
8+
)
9+
10+
func TestGetExecutingSqlError(t *testing.T) {
11+
tests := []struct {
12+
name string
13+
err error
14+
rawSql string
15+
fixedSql string
16+
params []any
17+
wantErr string
18+
}{
19+
{
20+
name: "Basic Error Test",
21+
err: errors.New("test error"),
22+
rawSql: "SELECT * FROM users WHERE id = @p1",
23+
fixedSql: "SELECT * FROM users WHERE id = ?",
24+
params: []any{1},
25+
wantErr: "dbClient: failed to execute sql\nraw error: test error\nsql:\ninput sql=SELECT * FROM users WHERE id = @p1\nexecuting sql=SELECT * FROM users WHERE id = ?\nparams:\n@p1=1",
26+
},
27+
{
28+
name: "Named Parameter Test",
29+
err: errors.New("test error"),
30+
rawSql: "SELECT * FROM users WHERE name = @name",
31+
fixedSql: "SELECT * FROM users WHERE name = ?",
32+
params: []any{sql.Named("name", "test")},
33+
wantErr: "dbClient: failed to execute sql\nraw error: test error\nsql:\ninput sql=SELECT * FROM users WHERE name = @name\nexecuting sql=SELECT * FROM users WHERE name = ?\nparams:\n@name=test",
34+
},
35+
}
36+
37+
for _, tt := range tests {
38+
t.Run(tt.name, func(t *testing.T) {
39+
gotErr := getExecutingSqlError(tt.err, tt.rawSql, tt.fixedSql, tt.params)
40+
if gotErr.Error() != tt.wantErr {
41+
t.Errorf("getExecutingSqlError() error = %v, want %v", gotErr, tt.wantErr)
42+
}
43+
})
44+
}
45+
}
46+
47+
func TestGetSqlError(t *testing.T) {
48+
tests := []struct {
49+
name string
50+
err error
51+
rawSql string
52+
params []any
53+
wantErr string
54+
}{
55+
{
56+
name: "Basic Error Test",
57+
err: ErrExpectedSizeWrong,
58+
rawSql: "UPDATE users SET name = @p1",
59+
params: []any{"test"},
60+
wantErr: "dbClient: effected rows was wrong\nsql:\ninput sql=UPDATE users SET name = @p1\nparams:\n@p1=test",
61+
},
62+
{
63+
name: "Named Parameter Test",
64+
err: ErrParseParamFailed,
65+
rawSql: "INSERT INTO users (name) VALUES (@name)",
66+
params: []any{sql.Named("name", "test")},
67+
wantErr: "dbClient: failed to parse named params\nsql:\ninput sql=INSERT INTO users (name) VALUES (@name)\nparams:\n@name=test",
68+
},
69+
}
70+
71+
for _, tt := range tests {
72+
t.Run(tt.name, func(t *testing.T) {
73+
gotErr := getSqlError(tt.err, tt.rawSql, tt.params)
74+
if gotErr.Error() != tt.wantErr {
75+
t.Errorf("getSqlError() error = %v, want %v", gotErr, tt.wantErr)
76+
}
77+
})
78+
}
79+
}
80+
81+
func TestCutLongStringParams(t *testing.T) {
82+
originalMaxLength := MaxLengthErrorValue
83+
defer func() { MaxLengthErrorValue = originalMaxLength }()
84+
85+
// 为了便于测试,将最大长度设置为较小的值
86+
MaxLengthErrorValue = 10
87+
88+
tests := []struct {
89+
name string
90+
paramVal any
91+
want any
92+
}{
93+
{
94+
name: "Short String",
95+
paramVal: "test",
96+
want: "test",
97+
},
98+
{
99+
name: "Long String",
100+
paramVal: "this is a very long string",
101+
want: "this is a ...(length=24)",
102+
},
103+
{
104+
name: "Non-String Type",
105+
paramVal: 123,
106+
want: 123,
107+
},
108+
{
109+
name: "Stringer Interface",
110+
paramVal: testStringer{"this is a very long string"},
111+
want: "this is a ...(length=24)",
112+
},
113+
}
114+
115+
for _, tt := range tests {
116+
t.Run(tt.name, func(t *testing.T) {
117+
got := cutLongStringParams(tt.paramVal)
118+
if got != tt.want {
119+
// 对于字符串类型的结果,检查是否包含预期的内容
120+
if s, ok := got.(string); ok {
121+
if !strings.Contains(s, "...(length=") {
122+
t.Errorf("cutLongStringParams() = %v, want %v", got, tt.want)
123+
}
124+
} else {
125+
t.Errorf("cutLongStringParams() = %v, want %v", got, tt.want)
126+
}
127+
}
128+
})
129+
}
130+
}
131+
132+
// 用于测试 Stringer 接口的辅助类型
133+
type testStringer struct {
134+
value string
135+
}
136+
137+
func (ts testStringer) String() string {
138+
return ts.value
139+
}

0 commit comments

Comments
 (0)