Skip to content

Commit 4495db4

Browse files
authored
Merge pull request #4 from actiontech/issue_442
Issue 442
2 parents 1c4141e + 53012e7 commit 4495db4

File tree

3 files changed

+302
-0
lines changed

3 files changed

+302
-0
lines changed

ast/mapper.go

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,3 +66,20 @@ func (m *Mapper) GetStmt(ctx *Context) (string, error) {
6666
}
6767
return strings.TrimSuffix(buff.String(), "\n"), nil
6868
}
69+
70+
func (m *Mapper) GetStmts(ctx *Context, skipErrorQuery bool) ([]string, error) {
71+
var stmts []string
72+
ctx.Sqls = m.SqlNodes
73+
for _, a := range m.QueryNodes {
74+
data, err := a.GetStmt(ctx)
75+
if err == nil {
76+
stmts = append(stmts, data)
77+
continue
78+
}
79+
if skipErrorQuery {
80+
continue
81+
}
82+
return nil, err
83+
}
84+
return stmts, nil
85+
}

parser.go

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,14 @@ package parser
22

33
import (
44
"encoding/xml"
5+
"fmt"
56
"io"
67
"strings"
78

89
"github.com/actiontech/mybatis-mapper-2-sql/ast"
910
)
1011

12+
// ParseXML is a parser for parse all query in XML to string.
1113
func ParseXML(data string) (string, error) {
1214
r := strings.NewReader(data)
1315
d := xml.NewDecoder(r)
@@ -25,6 +27,29 @@ func ParseXML(data string) (string, error) {
2527
return stmt, nil
2628
}
2729

30+
// ParseXMLQuery is a parser for parse all query in XML to []string one by one;
31+
// you can set `skipErrorQuery` true to ignore invalid query.
32+
func ParseXMLQuery(data string, skipErrorQuery bool) ([]string, error) {
33+
r := strings.NewReader(data)
34+
d := xml.NewDecoder(r)
35+
n, err := parse(d, nil)
36+
if err != nil {
37+
return nil, err
38+
}
39+
if n == nil {
40+
return nil, nil
41+
}
42+
m, ok := n.(*ast.Mapper)
43+
if !ok {
44+
return nil, fmt.Errorf("the mapper is not found")
45+
}
46+
stmts, err := m.GetStmts(ast.NewContext(), skipErrorQuery)
47+
if err != nil {
48+
return nil, err
49+
}
50+
return stmts, nil
51+
}
52+
2853
func parse(d *xml.Decoder, start *xml.StartElement) (node ast.Node, err error) {
2954
if start != nil {
3055
node, err = scan(start)

parser_test.go

Lines changed: 260 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -572,3 +572,263 @@ func TestParserSQLRefIdNotFound(t *testing.T) {
572572
t.Errorf("actual error is [%s]", err.Error())
573573
}
574574
}
575+
576+
func testParserQuery(t *testing.T, skipError bool, xmlData string, expect []string) {
577+
actual, err := ParseXMLQuery(xmlData, skipError)
578+
if err != nil {
579+
t.Errorf("parse error: %v", err)
580+
return
581+
}
582+
if len(actual) != len(expect) {
583+
t.Errorf("the length of actual is not the same as the length of expected, actual length is %d, expect is %d",
584+
len(actual), len(expect))
585+
return
586+
}
587+
for i := range actual {
588+
if actual[i] != expect[i] {
589+
t.Errorf("\nexpect[%d]: [%s]\nactual[%d]: [%s]", i, expect, i, actual)
590+
}
591+
}
592+
593+
}
594+
595+
func TestParserQueryFullFile(t *testing.T) {
596+
testParserQuery(t, false,
597+
`
598+
<?xml version="1.0" encoding="UTF-8"?>
599+
<!DOCTYPE mapper PUBLIC "-//mybatis.org//DTD Mapper 3.0//EN" "http://mybatis.org/dtd/mybatis-3-mapper.dtd">
600+
<mapper namespace="Test">
601+
<sql id="sometable">
602+
fruits
603+
</sql>
604+
<sql id="somewhere">
605+
WHERE
606+
category = #{category}
607+
</sql>
608+
<sql id="someinclude">
609+
FROM
610+
<include refid="${include_target}"/>
611+
<include refid="somewhere"/>
612+
</sql>
613+
<select id="testParameters">
614+
SELECT
615+
name,
616+
category,
617+
price
618+
FROM
619+
fruits
620+
WHERE
621+
category = #{category}
622+
AND price > ${price}
623+
</select>
624+
<select id="testInclude">
625+
SELECT
626+
name,
627+
category,
628+
price
629+
<include refid="someinclude">
630+
<property name="prefix" value="Some"/>
631+
<property name="include_target" value="sometable"/>
632+
</include>
633+
</select>
634+
<select id="testIf">
635+
SELECT
636+
name,
637+
category,
638+
price
639+
FROM
640+
fruits
641+
WHERE
642+
1=1
643+
<if test="category != null and category !=''">
644+
AND category = #{category}
645+
</if>
646+
<if test="price != null and price !=''">
647+
AND price = ${price}
648+
<if test="price >= 400">
649+
AND name = 'Fuji'
650+
</if>
651+
</if>
652+
</select>
653+
<select id="testTrim">
654+
SELECT
655+
name,
656+
category,
657+
price
658+
FROM
659+
fruits
660+
<trim prefix="WHERE" prefixOverrides="AND|OR">
661+
OR category = 'apple'
662+
OR price = 200
663+
</trim>
664+
</select>
665+
<select id="testWhere">
666+
SELECT
667+
name,
668+
category,
669+
price
670+
FROM
671+
fruits
672+
<where>
673+
AND category = 'apple'
674+
<if test="price != null and price !=''">
675+
AND price = ${price}
676+
</if>
677+
</where>
678+
</select>
679+
<update id="testSet">
680+
UPDATE
681+
fruits
682+
<set>
683+
<if test="category != null and category !=''">
684+
category = #{category},
685+
</if>
686+
<if test="price != null and price !=''">
687+
price = ${price},
688+
</if>
689+
</set>
690+
WHERE
691+
name = #{name}
692+
</update>
693+
<select id="testChoose">
694+
SELECT
695+
name,
696+
category,
697+
price
698+
FROM
699+
fruits
700+
<where>
701+
<choose>
702+
<when test="name != null">
703+
AND name = #{name}
704+
</when>
705+
<when test="category == 'banana'">
706+
AND category = #{category}
707+
<if test="price != null and price !=''">
708+
AND price = ${price}
709+
</if>
710+
</when>
711+
<otherwise>
712+
AND category = 'apple'
713+
</otherwise>
714+
</choose>
715+
</where>
716+
</select>
717+
<select id="testForeach">
718+
SELECT
719+
name,
720+
category,
721+
price
722+
FROM
723+
fruits
724+
<where>
725+
category = 'apple' AND
726+
<foreach collection="apples" item="name" open="(" close=")" separator="OR">
727+
<if test="name == 'Jonathan' or name == 'Fuji'">
728+
name = #{name}
729+
</if>
730+
</foreach>
731+
</where>
732+
</select>
733+
<insert id="testInsertMulti">
734+
INSERT INTO
735+
fruits
736+
(
737+
name,
738+
category,
739+
price
740+
)
741+
VALUES
742+
<foreach collection="fruits" item="fruit" separator=",">
743+
(
744+
#{fruit.name},
745+
#{fruit.category},
746+
${fruit.price}
747+
)
748+
</foreach>
749+
</insert>
750+
<select id="testBind">
751+
<bind name="likeName" value="'%' + name + '%'"/>
752+
SELECT
753+
name,
754+
category,
755+
price
756+
FROM
757+
fruits
758+
WHERE
759+
name like #{likeName}
760+
</select>
761+
</mapper>`,
762+
[]string{
763+
"SELECT `name`,`category`,`price` FROM `fruits` WHERE `category`=? AND `price`>?",
764+
"SELECT `name`,`category`,`price` FROM `fruits` WHERE `category`=?",
765+
"SELECT `name`,`category`,`price` FROM `fruits` WHERE 1=1 AND `category`=? AND `price`=? AND `name`=\"Fuji\"",
766+
"SELECT `name`,`category`,`price` FROM `fruits` WHERE `category`=\"apple\" OR `price`=200",
767+
"SELECT `name`,`category`,`price` FROM `fruits` WHERE `category`=\"apple\" AND `price`=?",
768+
"UPDATE `fruits` SET `category`=?, `price`=? WHERE `name`=?",
769+
"SELECT `name`,`category`,`price` FROM `fruits` WHERE `name`=? AND `category`=? AND `price`=? AND `category`=\"apple\"",
770+
"SELECT `name`,`category`,`price` FROM `fruits` WHERE `category`=\"apple\" AND (`name`=? OR `name`=?)",
771+
"INSERT INTO `fruits` (`name`,`category`,`price`) VALUES (?,?,?),(?,?,?)",
772+
"SELECT `name`,`category`,`price` FROM `fruits` WHERE `name` LIKE ?",
773+
})
774+
}
775+
776+
func TestParserQueryHasInvalidQuery(t *testing.T) {
777+
_, err := ParseXMLQuery(
778+
`
779+
<mapper namespace="Test">
780+
<sql id="someinclude">
781+
*
782+
</sql>
783+
<select id="testBind">
784+
<bind name="likeName" value="'%' + name + '%'"/>
785+
SELECT
786+
name,
787+
category,
788+
price
789+
FROM
790+
fruits
791+
WHERE
792+
name like #{likeName}
793+
</select>
794+
<select id="select" resultType="map">
795+
select
796+
<include refid="someinclude2" />
797+
from t
798+
</select>
799+
</mapper>`, false)
800+
if err == nil {
801+
t.Errorf("expect has error, but no error")
802+
}
803+
if err.Error() != "sql someinclude2 is not exist" {
804+
t.Errorf("actual error is [%s]", err.Error())
805+
}
806+
}
807+
808+
func TestParserQueryHasInvalidQueryButSkip(t *testing.T) {
809+
testParserQuery(t, true,
810+
`
811+
<mapper namespace="Test">
812+
<sql id="someinclude">
813+
*
814+
</sql>
815+
<select id="testBind">
816+
<bind name="likeName" value="'%' + name + '%'"/>
817+
SELECT
818+
name,
819+
category,
820+
price
821+
FROM
822+
fruits
823+
WHERE
824+
name like #{likeName}
825+
</select>
826+
<select id="select" resultType="map">
827+
select
828+
<include refid="someinclude2" />
829+
from t
830+
</select>
831+
</mapper>`, []string{
832+
"SELECT `name`,`category`,`price` FROM `fruits` WHERE `name` LIKE ?",
833+
})
834+
}

0 commit comments

Comments
 (0)