diff --git a/.CHANGELOG.md b/.CHANGELOG.md index 40992d1..25adb7b 100644 --- a/.CHANGELOG.md +++ b/.CHANGELOG.md @@ -40,6 +40,7 @@ - [eorm:Update 补充 Exec 方法](https://github.com/gotomicro/eorm/pull/98) - [eorm: 支持使用组合定义模型](https://github.com/gotomicro/eorm/pull/99) - [eorm: 增加 Result 抽象](https://github.com/gotomicro/eorm/pull/100) +- [eorm: 实现in查询功能](https://github.com/gotomicro/eorm/pull/102) - [eorm: 补充实现组合定义模型后的测试用例](https://github.com/gotomicro/eorm/pull/104) ### 文档, 代码质量以及文档 diff --git a/builder.go b/builder.go index a96703d..b64b197 100644 --- a/builder.go +++ b/builder.go @@ -179,6 +179,10 @@ func (b *builder) buildExpr(expr Expr) error { if err := b.buildBinaryExpr(binaryExpr(e)); err != nil { return err } + case values: + if err := b.buildIns(e); err != nil { + return err + } case nil: default: return errors.New("unsupported expr") @@ -247,3 +251,18 @@ func (b *builder) buildSubExpr(subExpr Expr) error { } return nil } + +func (b *builder) buildIns(is values) error { + _ = b.buffer.WriteByte('(') + for idx, inVal := range is.data { + if idx > 0 { + _ = b.buffer.WriteByte(',') + } + + b.args = append(b.args, inVal) + _ = b.buffer.WriteByte('?') + + } + _ = b.buffer.WriteByte(')') + return nil +} diff --git a/column.go b/column.go index 0b09fcb..cbe202a 100644 --- a/column.go +++ b/column.go @@ -127,11 +127,11 @@ type columns struct { cs []string } -func (c columns) selected() { +func (columns) selected() { panic("implement me") } -func (c columns) assign() { +func (columns) assign() { panic("implement me") } @@ -141,3 +141,44 @@ func Columns(cs ...string) columns { cs: cs, } } + +// In 方法没有元素传入,会被认为是false,被解释成where false这种形式 +func (c Column) In(data ...any) Predicate { + if len(data) == 0 { + return Predicate{ + op: opFalse, + } + } + + return Predicate{ + left: c, + op: opIn, + right: values{ + data: data, + }, + } +} + +// NotIn 方法没有元素传入,会被认为是false,被解释成where false这种形式 +func (c Column) NotIn(data ...any) Predicate { + if len(data) == 0 { + return Predicate{ + op: opFalse, + } + } + return Predicate{ + left: c, + op: opNotIN, + right: values{ + data: data, + }, + } +} + +type values struct { + data []any +} + +func (values) expr() (string, error) { + panic("implement me") +} diff --git a/predicate.go b/predicate.go index 21ba62e..3c8982a 100644 --- a/predicate.go +++ b/predicate.go @@ -31,9 +31,12 @@ var ( // opMinus = op{symbol:"-", text: "-"} opMulti = op{symbol: "*", text: "*"} // opDiv = op{symbol:"/", text: "/"} - opAnd = op{symbol: "AND", text: " AND "} - opOr = op{symbol: "OR", text: " OR "} - opNot = op{symbol: "NOT", text: "NOT "} + opAnd = op{symbol: "AND", text: " AND "} + opOr = op{symbol: "OR", text: " OR "} + opNot = op{symbol: "NOT", text: "NOT "} + opIn = op{symbol: "IN", text: " IN "} + opNotIN = op{symbol: "NOT IN", text: " NOT IN "} + opFalse = op{symbol: "FALSE", text: "FALSE"} ) // Predicate will be used in Where Or Having diff --git a/select_test.go b/select_test.go index cee72d1..69a3194 100644 --- a/select_test.go +++ b/select_test.go @@ -120,6 +120,47 @@ func TestSelectable(t *testing.T) { builder: NewSelector[TestModel](db).Select(Columns("Id"), Columns("FirstName"), Avg("Age").As("avg_age")).From(&TestModel{}).GroupBy("FirstName").Having(C("Invalid").LT(20)), wantErr: errs.NewInvalidFieldError("Invalid"), }, + { + name: "in", + builder: NewSelector[TestModel](db).Select(Columns("Id")).From(&TestModel{}).Where(C("Id").In(1, 2, 3)), + wantSql: "SELECT `id` FROM `test_model` WHERE `id` IN (?,?,?);", + wantArgs: []interface{}{1, 2, 3}, + }, + { + name: "not in", + builder: NewSelector[TestModel](db).Select(Columns("Id")).From(&TestModel{}).Where(C("Id").NotIn(1, 2, 3)), + wantSql: "SELECT `id` FROM `test_model` WHERE `id` NOT IN (?,?,?);", + wantArgs: []interface{}{1, 2, 3}, + }, + { + // 传入的参数为切片 + name: "slice in", + builder: NewSelector[TestModel](db).Select(Columns("Id")).From(&TestModel{}).Where(C("Id").In([]int{1, 2, 3})), + wantSql: "SELECT `id` FROM `test_model` WHERE `id` IN (?);", + wantArgs: []interface{}{[]int{1, 2, 3}}, + }, + { + // in 后面没有值 + name: "no in", + builder: NewSelector[TestModel](db).Select(Columns("Id")).From(&TestModel{}).Where(C("Id").In()), + wantSql: "SELECT `id` FROM `test_model` WHERE FALSE;", + }, + { + // Notin 后面没有值 + name: "no in", + builder: NewSelector[TestModel](db).Select(Columns("Id")).From(&TestModel{}).Where(C("Id").NotIn()), + wantSql: "SELECT `id` FROM `test_model` WHERE FALSE;", + }, + { + name: "in empty slice", + builder: NewSelector[TestModel](db).Select(Columns("Id")).From(&TestModel{}).Where(C("Id").In([]any{}...)), + wantSql: "SELECT `id` FROM `test_model` WHERE FALSE;", + }, + { + name: "NOT In empty slice", + builder: NewSelector[TestModel](db).Select(Columns("Id")).From(&TestModel{}).Where(C("Id").NotIn([]any{}...)), + wantSql: "SELECT `id` FROM `test_model` WHERE FALSE;", + }, } for _, tc := range testCases {