@@ -80,7 +80,19 @@ func (f *FuncBlock) getValueFromCallExpr(argumentIndex int) []string {
8080 values = append (values , recursionSql ... )
8181 }
8282 } else if arg .RuleIndex == javaAntlr .JavaParserRULE_literal {
83- values = append (values , calledExpr .Content )
83+ values = append (values , strings .Trim (arg .Content , "\" " ))
84+ if arg .Symbol == PLUS {
85+ var tmpSlice []string
86+ nextStrs := getStrValueFromExpression (arg .Next )
87+ for _ , str := range nextStrs {
88+ for _ , result := range values {
89+ tmpSlice = append (tmpSlice , result + str )
90+ }
91+ }
92+ if len (tmpSlice ) > 0 {
93+ values = tmpSlice
94+ }
95+ }
8496 }
8597 }
8698 }
@@ -130,16 +142,41 @@ func getVariableValueFromTree(variableName string, currentNode interface{}) []st
130142func GetSqlsFromVisitor (ctx * JavaVisitor ) []string {
131143 sqls := []string {}
132144 for _ , expression := range ctx .ExecSqlExpressions {
133- for _ , arg := range expression .Arguments {
134- // 参数为变量
135- if arg .RuleIndex == javaAntlr .JavaParserRULE_identifier {
136- sqls = append (sqls , getVariableValueFromTree (arg .Content , expression .Node )... )
137- // 参数为字符串
138- } else if arg .RuleIndex == javaAntlr .JavaParserRULE_literal {
139- // anltr为了区分字符串和其他变量,会为字符串的值左右添加双引号,获取sql时需要去除左右的双引号
140- sqls = append (sqls , strings .Trim (arg .Content , "\" " ))
145+ if len (expression .Arguments ) == 0 {
146+ continue
147+ }
148+ arg := expression .Arguments [0 ]
149+ // 参数为变量
150+ if arg .RuleIndex == javaAntlr .JavaParserRULE_identifier {
151+ sqls = append (sqls , getVariableValueFromTree (arg .Content , expression .Node )... )
152+ if arg .Symbol == PLUS {
153+ tmpSlice := []string {}
154+ nextSqls := getVariableValueFromTree (arg .Next .Content , expression .Node )
155+ for _ , str := range nextSqls {
156+ for _ , result := range sqls {
157+ tmpSlice = append (tmpSlice , result + str )
158+ }
159+ }
160+ if len (tmpSlice ) > 0 {
161+ sqls = tmpSlice
162+ }
163+ }
164+ // 参数为字符串
165+ } else if arg .RuleIndex == javaAntlr .JavaParserRULE_literal {
166+ // anltr为了区分字符串和其他变量,会为字符串的值左右添加双引号,获取sql时需要去除左右的双引号
167+ sqls = append (sqls , strings .Trim (arg .Content , "\" " ))
168+ if arg .Symbol == PLUS {
169+ tmpSlice := []string {}
170+ nextSqls := getVariableValueFromTree (arg .Next .Content , expression .Node )
171+ for _ , str := range nextSqls {
172+ for _ , result := range sqls {
173+ tmpSlice = append (tmpSlice , result + str )
174+ }
175+ }
176+ if len (tmpSlice ) > 0 {
177+ sqls = tmpSlice
178+ }
141179 }
142-
143180 }
144181 }
145182
0 commit comments