From defcda95831a3069a9a216aea8e1d0e608c4f259 Mon Sep 17 00:00:00 2001 From: doug-martin Date: Wed, 21 Aug 2019 00:16:24 -0500 Subject: [PATCH] Refactor SQL Generation Logic * Created new `sqlgen` module to encapsulate sql generation * Broke SQLDialect inti new SQL generators for each statement type. * Test refactor * Moved to a test case pattern to allow for quickly adding new test cases. --- HISTORY.md | 9 +- codecov.yml | 3 + delete_dataset.go | 5 +- delete_dataset_test.go | 633 ++-- exp/exp.go | 12 + exp/exp_map.go | 2 +- insert_dataset.go | 7 +- insert_dataset_test.go | 1019 ++---- internal/sb/sql_builder.go | 8 - select_dataset.go | 12 +- select_dataset_test.go | 2888 ++++++----------- sql_dialect.go | 1149 +------ sql_dialect_test.go | 2615 +-------------- sqlgen/base_test.go | 40 + sqlgen/common_sql_generator.go | 99 + sqlgen/common_sql_generator_test.go | 341 ++ sqlgen/delete_sql_generator.go | 73 + sqlgen/delete_sql_generator_test.go | 233 ++ sqlgen/expression_sql_generator.go | 568 ++++ sqlgen/expression_sql_generator_test.go | 1198 +++++++ sqlgen/insert_sql_generator.go | 206 ++ sqlgen/insert_sql_generator_test.go | 455 +++ sqlgen/mocks/DeleteSQLGenerator.go | 31 + sqlgen/mocks/InsertSQLGenerator.go | 31 + sqlgen/mocks/SelectSQLGenerator.go | 31 + sqlgen/mocks/TruncateSQLGenerator.go | 31 + sqlgen/mocks/UpdateSQLGenerator.go | 31 + sqlgen/select_sql_generator.go | 206 ++ sqlgen/select_sql_generator_test.go | 412 +++ .../sql_dialect_options.go | 2 +- sqlgen/truncate_sql_generator.go | 68 + sqlgen/truncate_sql_generator_test.go | 120 + sqlgen/update_sql_generator.go | 117 + sqlgen/update_sql_generator_test.go | 241 ++ truncate_dataset_test.go | 200 +- update_dataset.go | 5 +- update_dataset_test.go | 1003 ++---- 37 files changed, 6466 insertions(+), 7638 deletions(-) create mode 100644 codecov.yml create mode 100644 sqlgen/base_test.go create mode 100644 sqlgen/common_sql_generator.go create mode 100644 sqlgen/common_sql_generator_test.go create mode 100644 sqlgen/delete_sql_generator.go create mode 100644 sqlgen/delete_sql_generator_test.go create mode 100644 sqlgen/expression_sql_generator.go create mode 100644 sqlgen/expression_sql_generator_test.go create mode 100644 sqlgen/insert_sql_generator.go create mode 100644 sqlgen/insert_sql_generator_test.go create mode 100644 sqlgen/mocks/DeleteSQLGenerator.go create mode 100644 sqlgen/mocks/InsertSQLGenerator.go create mode 100644 sqlgen/mocks/SelectSQLGenerator.go create mode 100644 sqlgen/mocks/TruncateSQLGenerator.go create mode 100644 sqlgen/mocks/UpdateSQLGenerator.go create mode 100644 sqlgen/select_sql_generator.go create mode 100644 sqlgen/select_sql_generator_test.go rename sql_dialect_options.go => sqlgen/sql_dialect_options.go (99%) create mode 100644 sqlgen/truncate_sql_generator.go create mode 100644 sqlgen/truncate_sql_generator_test.go create mode 100644 sqlgen/update_sql_generator.go create mode 100644 sqlgen/update_sql_generator_test.go diff --git a/HISTORY.md b/HISTORY.md index 54b8c57d..fd9721a4 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -1,6 +1,13 @@ +## 8.4.0 + +* Created new `sqlgen` module to encapsulate sql generation + * Broke SQLDialect inti new SQL generators for each statement type. +* Test refactor + * Moved to a test case pattern to allow for quickly adding new test cases. + ## v8.3.2 -* [FIXED] Data race during query factory initialisation [#133](https://github.com/doug-martin/goqu/issues/133) and [#136](https://github.com/doug-martin/goqu/issues/136) - [@o1egl](https://github.com/o1egl) +* [FIXED] Data race during query factory initialization [#133](https://github.com/doug-martin/goqu/issues/133) and [#136](https://github.com/doug-martin/goqu/issues/136) - [@o1egl](https://github.com/o1egl) ## 8.3.1 diff --git a/codecov.yml b/codecov.yml new file mode 100644 index 00000000..a3d9cd50 --- /dev/null +++ b/codecov.yml @@ -0,0 +1,3 @@ +ignore: + - "**/mocks/**" # glob accepted + - "mocks/**" # glob accepted \ No newline at end of file diff --git a/delete_dataset.go b/delete_dataset.go index 6cb32e17..0c28a778 100644 --- a/delete_dataset.go +++ b/delete_dataset.go @@ -3,9 +3,12 @@ package goqu import ( "github.com/doug-martin/goqu/v8/exec" "github.com/doug-martin/goqu/v8/exp" + "github.com/doug-martin/goqu/v8/internal/errors" "github.com/doug-martin/goqu/v8/internal/sb" ) +var errBadFromArgument = errors.New("unsupported DeleteDataset#From argument, a string or identifier expression is required") + type DeleteDataset struct { dialect SQLDialect clauses exp.DeleteClauses @@ -108,7 +111,7 @@ func (dd *DeleteDataset) From(table interface{}) *DeleteDataset { case string: return dd.copy(dd.clauses.SetFrom(exp.ParseIdentifier(t))) default: - panic("unsupported table type, a string or identifier expression is required") + panic(errBadFromArgument) } } diff --git a/delete_dataset_test.go b/delete_dataset_test.go index 0ba8c29b..bedf7d42 100644 --- a/delete_dataset_test.go +++ b/delete_dataset_test.go @@ -13,8 +13,20 @@ import ( "github.com/stretchr/testify/suite" ) -type deleteDatasetSuite struct { - suite.Suite +type ( + deleteTestCase struct { + ds *DeleteDataset + clauses exp.DeleteClauses + } + deleteDatasetSuite struct { + suite.Suite + } +) + +func (dds *deleteDatasetSuite) assertCases(cases ...deleteTestCase) { + for _, s := range cases { + dds.Equal(s.clauses, s.ds.GetClauses()) + } } func (dds *deleteDatasetSuite) SetupSuite() { @@ -62,24 +74,6 @@ func (dds *deleteDatasetSuite) TestPrepared() { dds.True(preparedDs.Where(Ex{"a": 1}).IsPrepared()) } -func (dds *deleteDatasetSuite) TestPrepared_ToSQL() { - ds1 := Delete("items") - dsql, args, err := ds1.Prepared(true).ToSQL() - dds.NoError(err) - dds.Empty(args) - dds.Equal(`DELETE FROM "items"`, dsql) - - dsql, args, err = ds1.Where(I("id").Eq(1)).Prepared(true).ToSQL() - dds.NoError(err) - dds.Equal([]interface{}{int64(1)}, args) - dds.Equal(`DELETE FROM "items" WHERE ("id" = ?)`, dsql) - - dsql, args, err = ds1.Returning("id").Where(I("id").Eq(1)).Prepared(true).ToSQL() - dds.NoError(err) - dds.Equal([]interface{}{int64(1)}, args) - dds.Equal(`DELETE FROM "items" WHERE ("id" = ?) RETURNING "id"`, dsql) -} - func (dds *deleteDatasetSuite) TestGetClauses() { ds := Delete("test") ce := exp.NewDeleteClauses().SetFrom(I("test")) @@ -88,415 +82,272 @@ func (dds *deleteDatasetSuite) TestGetClauses() { func (dds *deleteDatasetSuite) TestWith() { from := From("cte") - ds := Delete("test") - dsc := ds.GetClauses() - ec := dsc.CommonTablesAppend(exp.NewCommonTableExpression(false, "test-cte", from)) - dds.Equal(ec, ds.With("test-cte", from).GetClauses()) - dds.Equal(dsc, ds.GetClauses()) + bd := Delete("items") + dds.assertCases( + deleteTestCase{ + ds: bd.With("test-cte", from), + clauses: exp.NewDeleteClauses().SetFrom(C("items")). + CommonTablesAppend(exp.NewCommonTableExpression(false, "test-cte", from)), + }, + deleteTestCase{ + ds: bd, + clauses: exp.NewDeleteClauses().SetFrom(C("items")), + }, + ) } func (dds *deleteDatasetSuite) TestWithRecursive() { from := From("cte") - ds := Delete("test") - dsc := ds.GetClauses() - ec := dsc.CommonTablesAppend(exp.NewCommonTableExpression(true, "test-cte", from)) - dds.Equal(ec, ds.WithRecursive("test-cte", from).GetClauses()) - dds.Equal(dsc, ds.GetClauses()) -} - -func (dds *deleteDatasetSuite) TestFrom() { - ds := Delete("test") - dsc := ds.GetClauses() - ec := dsc.SetFrom(T("t")) - dds.Equal(ec, ds.From(T("t")).GetClauses()) - dds.Equal(dsc, ds.GetClauses()) -} - -func (dds *deleteDatasetSuite) TestFrom_ToSQL() { - ds1 := Delete("test") - - deleteSQL, _, err := ds1.ToSQL() - dds.NoError(err) - dds.Equal(`DELETE FROM "test"`, deleteSQL) - - ds2 := ds1.From("test2") - deleteSQL, _, err = ds2.ToSQL() - dds.NoError(err) - dds.Equal(`DELETE FROM "test2"`, deleteSQL) - - // original should not change - deleteSQL, _, err = ds1.ToSQL() - dds.NoError(err) - dds.Equal(`DELETE FROM "test"`, deleteSQL) - -} - -func (dds *deleteDatasetSuite) TestWhere() { - ds := Delete("test") - dsc := ds.GetClauses() - w := Ex{ - "a": 1, - } - ec := dsc.WhereAppend(w) - dds.Equal(ec, ds.Where(w).GetClauses()) - dds.Equal(dsc, ds.GetClauses()) -} - -func (dds *deleteDatasetSuite) TestWhere_ToSQL() { - ds1 := Delete("test") - - b := ds1.Where( - C("a").Eq(true), - C("a").Neq(true), - C("a").Eq(false), - C("a").Neq(false), - ) - deleteSQL, args, err := b.ToSQL() - dds.NoError(err) - dds.Empty(args) - dds.Equal( - `DELETE FROM "test" WHERE (("a" IS TRUE) AND ("a" IS NOT TRUE) AND ("a" IS FALSE) AND ("a" IS NOT FALSE))`, - deleteSQL, - ) - - deleteSQL, args, err = b.Prepared(true).ToSQL() - dds.NoError(err) - dds.Empty(args) - dds.Equal( - `DELETE FROM "test" WHERE (("a" IS TRUE) AND ("a" IS NOT TRUE) AND ("a" IS FALSE) AND ("a" IS NOT FALSE))`, - deleteSQL, - ) - - b = ds1.Where( - C("a").Eq("a"), - C("b").Neq("b"), - C("c").Gt("c"), - C("d").Gte("d"), - C("e").Lt("e"), - C("f").Lte("f"), - ) - deleteSQL, args, err = b.ToSQL() - dds.NoError(err) - dds.Empty(args) - dds.Equal( - `DELETE FROM "test" `+ - `WHERE (("a" = 'a') AND ("b" != 'b') AND ("c" > 'c') AND ("d" >= 'd') AND ("e" < 'e') AND ("f" <= 'f'))`, - deleteSQL, - ) - - deleteSQL, args, err = b.Prepared(true).ToSQL() - dds.NoError(err) - dds.Equal([]interface{}{"a", "b", "c", "d", "e", "f"}, args) - dds.Equal( - `DELETE FROM "test" `+ - `WHERE (("a" = ?) AND ("b" != ?) AND ("c" > ?) AND ("d" >= ?) AND ("e" < ?) AND ("f" <= ?))`, - deleteSQL, - ) - - b = ds1.Where( - C("a").Eq(From("test2").Select("id")), - ) - deleteSQL, args, err = b.ToSQL() - dds.NoError(err) - dds.Empty(args) - dds.Equal( - `DELETE FROM "test" WHERE ("a" IN (SELECT "id" FROM "test2"))`, - deleteSQL, - ) - - deleteSQL, args, err = b.Prepared(true).ToSQL() - dds.NoError(err) - dds.Empty(args) - dds.Equal( - `DELETE FROM "test" WHERE ("a" IN (SELECT "id" FROM "test2"))`, - deleteSQL, - ) - - b = ds1.Where(Ex{ - "a": "a", - "b": Op{"neq": "b"}, - "c": Op{"gt": "c"}, - "d": Op{"gte": "d"}, - "e": Op{"lt": "e"}, - "f": Op{"lte": "f"}, - }) - deleteSQL, args, err = b.ToSQL() - dds.NoError(err) - dds.Empty(args) - dds.Equal(`DELETE FROM "test" `+ - `WHERE (("a" = 'a') AND ("b" != 'b') AND ("c" > 'c') AND ("d" >= 'd') AND ("e" < 'e') AND ("f" <= 'f'))`, - deleteSQL, + bd := Delete("items") + dds.assertCases( + deleteTestCase{ + ds: bd.WithRecursive("test-cte", from), + clauses: exp.NewDeleteClauses().SetFrom(C("items")). + CommonTablesAppend(exp.NewCommonTableExpression(true, "test-cte", from)), + }, + deleteTestCase{ + ds: bd, + clauses: exp.NewDeleteClauses().SetFrom(C("items")), + }, ) +} - deleteSQL, args, err = b.Prepared(true).ToSQL() - dds.NoError(err) - dds.Equal([]interface{}{"a", "b", "c", "d", "e", "f"}, args) - dds.Equal( - `DELETE FROM "test" `+ - `WHERE (("a" = ?) AND ("b" != ?) AND ("c" > ?) AND ("d" >= ?) AND ("e" < ?) AND ("f" <= ?))`, - deleteSQL, +func (dds *deleteDatasetSuite) TestFrom_withIdentifier() { + bd := Delete("items") + dds.assertCases( + deleteTestCase{ + ds: bd.From("items2"), + clauses: exp.NewDeleteClauses().SetFrom(C("items2")), + }, + deleteTestCase{ + ds: bd.From(C("items2")), + clauses: exp.NewDeleteClauses().SetFrom(C("items2")), + }, + deleteTestCase{ + ds: bd.From(T("items2")), + clauses: exp.NewDeleteClauses().SetFrom(T("items2")), + }, + deleteTestCase{ + ds: bd.From("schema.table"), + clauses: exp.NewDeleteClauses().SetFrom(I("schema.table")), + }, + deleteTestCase{ + ds: bd, + clauses: exp.NewDeleteClauses().SetFrom(C("items")), + }, ) - b = ds1.Where(Ex{ - "a": From("test2").Select("id"), + dds.PanicsWithValue(errBadFromArgument, func() { + Delete("test").From(true) }) - deleteSQL, args, err = b.ToSQL() - dds.NoError(err) - dds.Empty(args) - dds.Equal(`DELETE FROM "test" WHERE ("a" IN (SELECT "id" FROM "test2"))`, deleteSQL) - deleteSQL, args, err = b.Prepared(true).ToSQL() - dds.NoError(err) - dds.Empty(args) - dds.Equal(`DELETE FROM "test" WHERE ("a" IN (SELECT "id" FROM "test2"))`, deleteSQL) } -func (dds *deleteDatasetSuite) TestWhere_chainToSQL() { - ds1 := Delete("test").Where( - C("x").Eq(0), - C("y").Eq(1), - ) - - ds2 := ds1.Where( - C("z").Eq(2), - ) - - a := ds2.Where( - C("a").Eq("A"), - ) - b := ds2.Where( - C("b").Eq("B"), - ) - deleteSQL, _, err := a.ToSQL() - dds.NoError(err) - dds.Equal( - `DELETE FROM "test" WHERE (("x" = 0) AND ("y" = 1) AND ("z" = 2) AND ("a" = 'A'))`, - deleteSQL, - ) - deleteSQL, _, err = b.ToSQL() - dds.NoError(err) - dds.Equal( - `DELETE FROM "test" WHERE (("x" = 0) AND ("y" = 1) AND ("z" = 2) AND ("b" = 'B'))`, - deleteSQL, +func (dds *deleteDatasetSuite) TestWhere() { + bd := Delete("items") + dds.assertCases( + deleteTestCase{ + ds: bd.Where(Ex{"a": 1}), + clauses: exp.NewDeleteClauses(). + SetFrom(C("items")). + WhereAppend(Ex{"a": 1}), + }, + deleteTestCase{ + ds: bd.Where(Ex{"a": 1}).Where(C("b").Eq("c")), + clauses: exp.NewDeleteClauses(). + SetFrom(C("items")). + WhereAppend(Ex{"a": 1}). + WhereAppend(C("b").Eq("c")), + }, + deleteTestCase{ + ds: bd, + clauses: exp.NewDeleteClauses().SetFrom(C("items")), + }, ) } -func (dds *deleteDatasetSuite) TestWhere_emptyToSQL() { - ds1 := Delete("test") - - b := ds1.Where() - deleteSQL, _, err := b.ToSQL() - dds.NoError(err) - dds.Equal(`DELETE FROM "test"`, deleteSQL) -} - func (dds *deleteDatasetSuite) TestClearWhere() { - w := Ex{ - "a": 1, - } - ds := Delete("test").Where(w) - dsc := ds.GetClauses() - ec := dsc.ClearWhere() - dds.Equal(ec, ds.ClearWhere().GetClauses()) - dds.Equal(dsc, ds.GetClauses()) -} - -func (dds *deleteDatasetSuite) TestClearWhere_ToSQL() { - ds1 := Delete("test") - - b := ds1.Where( - C("a").Eq(1), - ).ClearWhere() - deleteSQL, _, err := b.ToSQL() - dds.NoError(err) - dds.Equal(`DELETE FROM "test"`, deleteSQL) + bd := Delete("items").Where(Ex{"a": 1}) + dds.assertCases( + deleteTestCase{ + ds: bd.ClearWhere(), + clauses: exp.NewDeleteClauses(). + SetFrom(C("items")), + }, + deleteTestCase{ + ds: bd, + clauses: exp.NewDeleteClauses(). + SetFrom(C("items")). + WhereAppend(Ex{"a": 1}), + }, + ) } func (dds *deleteDatasetSuite) TestOrder() { - ds := Delete("test") - dsc := ds.GetClauses() - o := C("a").Desc() - ec := dsc.SetOrder(o) - dds.Equal(ec, ds.Order(o).GetClauses()) - dds.Equal(dsc, ds.GetClauses()) -} -func (dds *deleteDatasetSuite) TestOrder_ToSQL() { - - ds1 := Delete("test").WithDialect("order-on-delete") - - b := ds1.Order(C("a").Asc(), L(`("a" + "b" > 2)`).Asc()) - deleteSQL, args, err := b.ToSQL() - dds.NoError(err) - dds.Empty(args) - dds.Equal(`DELETE FROM "test" ORDER BY "a" ASC, ("a" + "b" > 2) ASC`, deleteSQL) - - deleteSQL, args, err = b.Prepared(true).ToSQL() - dds.NoError(err) - dds.Empty(args) - dds.Equal(`DELETE FROM "test" ORDER BY "a" ASC, ("a" + "b" > 2) ASC`, deleteSQL) + bd := Delete("items") + dds.assertCases( + deleteTestCase{ + ds: bd.Order(C("a").Asc()), + clauses: exp.NewDeleteClauses(). + SetFrom(C("items")). + SetOrder(C("a").Asc()), + }, + deleteTestCase{ + ds: bd.Order(C("a").Asc()).Order(C("b").Desc()), + clauses: exp.NewDeleteClauses(). + SetFrom(C("items")). + SetOrder(C("b").Desc()), + }, + deleteTestCase{ + ds: bd.Order(C("a").Asc(), C("b").Desc()), + clauses: exp.NewDeleteClauses(). + SetFrom(C("items")). + SetOrder(C("a").Asc(), C("b").Desc()), + }, + deleteTestCase{ + ds: bd, + clauses: exp.NewDeleteClauses().SetFrom(C("items")), + }, + ) } func (dds *deleteDatasetSuite) TestOrderAppend() { - ds := Delete("test").Order(C("a").Desc()) - dsc := ds.GetClauses() - o := C("b").Desc() - ec := dsc.OrderAppend(o) - dds.Equal(ec, ds.OrderAppend(o).GetClauses()) - dds.Equal(dsc, ds.GetClauses()) + bd := Delete("items").Order(C("a").Asc()) + dds.assertCases( + deleteTestCase{ + ds: bd.OrderAppend(C("b").Desc()), + clauses: exp.NewDeleteClauses(). + SetFrom(C("items")). + SetOrder(C("a").Asc(), C("b").Desc()), + }, + deleteTestCase{ + ds: bd, + clauses: exp.NewDeleteClauses(). + SetFrom(C("items")). + SetOrder(C("a").Asc()), + }, + ) } -func (dds *deleteDatasetSuite) TestOrderAppend_ToSQL() { - ds := Delete("test").WithDialect("order-on-delete") - b := ds.Order(C("a").Asc().NullsFirst()).OrderAppend(C("b").Desc().NullsLast()) - deleteSQL, _, err := b.ToSQL() - dds.NoError(err) - dds.Equal(`DELETE FROM "test" ORDER BY "a" ASC NULLS FIRST, "b" DESC NULLS LAST`, deleteSQL) - - b = ds.OrderAppend(C("a").Asc().NullsFirst()).OrderAppend(C("b").Desc().NullsLast()) - deleteSQL, _, err = b.ToSQL() - dds.NoError(err) - dds.Equal(`DELETE FROM "test" ORDER BY "a" ASC NULLS FIRST, "b" DESC NULLS LAST`, deleteSQL) - +func (dds *deleteDatasetSuite) TestOrderPrepend() { + bd := Delete("items").Order(C("a").Asc()) + dds.assertCases( + deleteTestCase{ + ds: bd.OrderPrepend(C("b").Desc()), + clauses: exp.NewDeleteClauses(). + SetFrom(C("items")). + SetOrder(C("b").Desc(), C("a").Asc()), + }, + deleteTestCase{ + ds: bd, + clauses: exp.NewDeleteClauses(). + SetFrom(C("items")). + SetOrder(C("a").Asc()), + }, + ) } func (dds *deleteDatasetSuite) TestClearOrder() { - ds := Delete("test").Order(C("a").Desc()) - dsc := ds.GetClauses() - ec := dsc.ClearOrder() - dds.Equal(ec, ds.ClearOrder().GetClauses()) - dds.Equal(dsc, ds.GetClauses()) -} - -func (dds *deleteDatasetSuite) TestClearOrder_ToSQL() { - ds := Delete("test").WithDialect("order-on-delete") - b := ds.Order(C("a").Asc().NullsFirst()).ClearOrder() - deleteSQL, _, err := b.ToSQL() - dds.NoError(err) - dds.Equal(`DELETE FROM "test"`, deleteSQL) + bd := Delete("items").Order(C("a").Asc()) + dds.assertCases( + deleteTestCase{ + ds: bd.ClearOrder(), + clauses: exp.NewDeleteClauses().SetFrom(C("items")), + }, + deleteTestCase{ + ds: bd, + clauses: exp.NewDeleteClauses(). + SetFrom(C("items")). + SetOrder(C("a").Asc()), + }, + ) } func (dds *deleteDatasetSuite) TestLimit() { - ds := Delete("test") - dsc := ds.GetClauses() - ec := dsc.SetLimit(uint(1)) - dds.Equal(ec, ds.Limit(1).GetClauses()) - dds.Equal(dsc, ds.Limit(0).GetClauses()) - dds.Equal(dsc, ds.GetClauses()) -} - -func (dds *deleteDatasetSuite) TestLimit_ToSQL() { - ds1 := Delete("test").WithDialect("limit-on-delete") - - b := ds1.Where(C("a").Gt(1)).Limit(10) - deleteSQL, args, err := b.ToSQL() - dds.NoError(err) - dds.Empty(args) - dds.Equal(`DELETE FROM "test" WHERE ("a" > 1) LIMIT 10`, deleteSQL) - - deleteSQL, args, err = b.Prepared(true).ToSQL() - dds.NoError(err) - dds.Equal([]interface{}{int64(1), int64(10)}, args) - dds.Equal(`DELETE FROM "test" WHERE ("a" > ?) LIMIT ?`, deleteSQL) - - b = ds1.Where(C("a").Gt(1)).Limit(0) - deleteSQL, args, err = b.ToSQL() - dds.NoError(err) - dds.Empty(args) - dds.Equal(`DELETE FROM "test" WHERE ("a" > 1)`, deleteSQL) - - deleteSQL, args, err = b.Prepared(true).ToSQL() - dds.NoError(err) - dds.Equal([]interface{}{int64(1)}, args) - dds.Equal(`DELETE FROM "test" WHERE ("a" > ?)`, deleteSQL) + bd := Delete("test") + dds.assertCases( + deleteTestCase{ + ds: bd.Limit(10), + clauses: exp.NewDeleteClauses(). + SetFrom(C("test")). + SetLimit(uint(10)), + }, + deleteTestCase{ + ds: bd.Limit(0), + clauses: exp.NewDeleteClauses().SetFrom(C("test")), + }, + deleteTestCase{ + ds: bd.Limit(10).Limit(2), + clauses: exp.NewDeleteClauses(). + SetFrom(C("test")). + SetLimit(uint(2)), + }, + deleteTestCase{ + ds: bd.Limit(10).Limit(0), + clauses: exp.NewDeleteClauses().SetFrom(C("test")), + }, + deleteTestCase{ + ds: bd, + clauses: exp.NewDeleteClauses().SetFrom(C("test")), + }, + ) } func (dds *deleteDatasetSuite) TestLimitAll() { - ds := Delete("test") - dsc := ds.GetClauses() - ec := dsc.SetLimit(L("ALL")) - dds.Equal(ec, ds.LimitAll().GetClauses()) - dds.Equal(dsc, ds.GetClauses()) -} - -func (dds *deleteDatasetSuite) TestLimitAll_ToSQL() { - ds1 := Delete("test").WithDialect("limit-on-delete") - - b := ds1.Where(C("a").Gt(1)).LimitAll() - - deleteSQL, args, err := b.ToSQL() - dds.NoError(err) - dds.Empty(args) - dds.Equal(`DELETE FROM "test" WHERE ("a" > 1) LIMIT ALL`, deleteSQL) - - deleteSQL, args, err = b.Prepared(true).ToSQL() - dds.NoError(err) - dds.Equal([]interface{}{int64(1)}, args) - dds.Equal(`DELETE FROM "test" WHERE ("a" > ?) LIMIT ALL`, deleteSQL) - - b = ds1.Where(C("a").Gt(1)).Limit(0).LimitAll() - deleteSQL, _, err = b.ToSQL() - dds.NoError(err) - dds.Equal(`DELETE FROM "test" WHERE ("a" > 1) LIMIT ALL`, deleteSQL) - - deleteSQL, args, err = b.Prepared(true).ToSQL() - dds.NoError(err) - dds.Equal([]interface{}{int64(1)}, args) - dds.Equal(`DELETE FROM "test" WHERE ("a" > ?) LIMIT ALL`, deleteSQL) + bd := Delete("test") + dds.assertCases( + deleteTestCase{ + ds: bd.LimitAll(), + clauses: exp.NewDeleteClauses(). + SetFrom(C("test")). + SetLimit(L("ALL")), + }, + deleteTestCase{ + ds: bd.Limit(10).LimitAll(), + clauses: exp.NewDeleteClauses(). + SetFrom(C("test")). + SetLimit(L("ALL")), + }, + deleteTestCase{ + ds: bd, + clauses: exp.NewDeleteClauses().SetFrom(C("test")), + }, + ) } func (dds *deleteDatasetSuite) TestClearLimit() { - ds := Delete("test").Limit(1) - dsc := ds.GetClauses() - ec := dsc.ClearLimit() - dds.Equal(ec, ds.ClearLimit().GetClauses()) - dds.Equal(dsc, ds.GetClauses()) -} - -func (dds *deleteDatasetSuite) TestClearLimit_ToSQL() { - ds1 := Delete("test") - - b := ds1.Where(C("a").Gt(1)).LimitAll().ClearLimit() - deleteSQL, args, err := b.ToSQL() - dds.NoError(err) - dds.Empty(args) - dds.Equal(`DELETE FROM "test" WHERE ("a" > 1)`, deleteSQL) - - deleteSQL, args, err = b.Prepared(true).ToSQL() - dds.NoError(err) - dds.Equal([]interface{}{int64(1)}, args) - dds.Equal(`DELETE FROM "test" WHERE ("a" > ?)`, deleteSQL) - - b = ds1.Where(C("a").Gt(1)).Limit(10).ClearLimit() - deleteSQL, args, err = b.ToSQL() - dds.NoError(err) - dds.Empty(args) - dds.Equal(`DELETE FROM "test" WHERE ("a" > 1)`, deleteSQL) - deleteSQL, args, err = b.Prepared(true).ToSQL() - dds.NoError(err) - dds.Equal([]interface{}{int64(1)}, args) - dds.Equal(`DELETE FROM "test" WHERE ("a" > ?)`, deleteSQL) + bd := Delete("test").Limit(10) + dds.assertCases( + deleteTestCase{ + ds: bd.ClearLimit(), + clauses: exp.NewDeleteClauses().SetFrom(C("test")), + }, + deleteTestCase{ + ds: bd, + clauses: exp.NewDeleteClauses().SetFrom(C("test")).SetLimit(uint(10)), + }, + ) } func (dds *deleteDatasetSuite) TestReturning() { - ds := Delete("test") - dsc := ds.GetClauses() - ec := dsc.SetReturning(exp.NewColumnListExpression(C("a"))) - dds.Equal(ec, ds.Returning("a").GetClauses()) - dds.Equal(dsc, ds.GetClauses()) -} - -func (dds *deleteDatasetSuite) TestReturning_ToSQL() { - ds := Delete("test") - b := ds.Returning("a") - - deleteSQL, args, err := b.ToSQL() - dds.NoError(err) - dds.Empty(args) - dds.Equal(`DELETE FROM "test" RETURNING "a"`, deleteSQL) - - deleteSQL, args, err = b.Prepared(true).ToSQL() - dds.NoError(err) - dds.Empty(args) - dds.Equal(`DELETE FROM "test" RETURNING "a"`, deleteSQL) + bd := Delete("items") + dds.assertCases( + deleteTestCase{ + ds: bd.Returning("a"), + clauses: exp.NewDeleteClauses(). + SetFrom(C("items")). + SetReturning(exp.NewColumnListExpression("a")), + }, + deleteTestCase{ + ds: bd.Returning("a").Returning("b"), + clauses: exp.NewDeleteClauses(). + SetFrom(C("items")). + SetReturning(exp.NewColumnListExpression("b")), + }, + deleteTestCase{ + ds: bd, + clauses: exp.NewDeleteClauses().SetFrom(C("items")), + }, + ) } func (dds *deleteDatasetSuite) TestToSQL() { diff --git a/exp/exp.go b/exp/exp.go index 71eaa127..b08fe730 100644 --- a/exp/exp.go +++ b/exp/exp.go @@ -442,6 +442,8 @@ const ( RegexpILikeOp // !~*, NOT REGEXP RegexpNotILikeOp + + betweenStr = "between" ) var ( @@ -519,6 +521,16 @@ func (bo BooleanOperation) String() string { return fmt.Sprintf("%d", bo) } +func (ro RangeOperation) String() string { + switch ro { + case BetweenOp: + return betweenStr + case NotBetweenOp: + return "not between" + } + return fmt.Sprintf("%d", ro) +} + func (jt JoinType) String() string { switch jt { case InnerJoinType: diff --git a/exp/exp_map.go b/exp/exp_map.go index e99fdfaf..bfd78245 100644 --- a/exp/exp_map.go +++ b/exp/exp_map.go @@ -138,7 +138,7 @@ func createExpressionFromOp(lhs IdentifierExpression, opKey string, op Op) (exp exp = lhs.ILike(op[opKey]) case NotILikeOp.String(): exp = lhs.NotILike(op[opKey]) - case "between": + case betweenStr: rangeVal, ok := op[opKey].(RangeVal) if ok { exp = lhs.Between(rangeVal) diff --git a/insert_dataset.go b/insert_dataset.go index 0f39711e..1147cb7d 100644 --- a/insert_dataset.go +++ b/insert_dataset.go @@ -3,6 +3,7 @@ package goqu import ( "github.com/doug-martin/goqu/v8/exec" "github.com/doug-martin/goqu/v8/exp" + "github.com/doug-martin/goqu/v8/internal/errors" "github.com/doug-martin/goqu/v8/internal/sb" ) @@ -13,6 +14,8 @@ type InsertDataset struct { queryFactory exec.QueryFactory } +var errUnsupportedIntoType = errors.New("unsupported table type, a string or identifier expression is required") + // used internally by database to create a database with a specific adapter func newInsertDataset(d string, queryFactory exec.QueryFactory) *InsertDataset { return &InsertDataset{ @@ -117,7 +120,7 @@ func (id *InsertDataset) Into(into interface{}) *InsertDataset { case string: return id.copy(id.clauses.SetInto(exp.ParseIdentifier(t))) default: - panic("unsupported table type, a string or identifier expression is required") + panic(errUnsupportedIntoType) } } @@ -128,7 +131,7 @@ func (id *InsertDataset) Cols(cols ...interface{}) *InsertDataset { // Clears the Columns to insert into func (id *InsertDataset) ClearCols() *InsertDataset { - return id.copy(id.clauses.SetCols(exp.NewColumnListExpression(exp.Star()))) + return id.copy(id.clauses.SetCols(nil)) } // Adds columns to the current list of columns clause. See examples diff --git a/insert_dataset_test.go b/insert_dataset_test.go index 8c082360..e9ccdd42 100644 --- a/insert_dataset_test.go +++ b/insert_dataset_test.go @@ -1,7 +1,6 @@ package goqu import ( - "database/sql" "testing" "time" @@ -15,18 +14,20 @@ import ( "github.com/stretchr/testify/suite" ) -type insertDatasetSuite struct { - suite.Suite -} - -func (ids *insertDatasetSuite) SetupSuite() { - noReturn := DefaultDialectOptions() - noReturn.SupportsReturn = false - RegisterDialect("no-return", noReturn) -} +type ( + insertTestCase struct { + ds *InsertDataset + clauses exp.InsertClauses + } + insertDatasetSuite struct { + suite.Suite + } +) -func (ids *insertDatasetSuite) TearDownSuite() { - DeregisterDialect("no-return") +func (ids *insertDatasetSuite) assertCases(cases ...insertTestCase) { + for _, s := range cases { + ids.Equal(s.clauses, s.ds.GetClauses()) + } } func (ids *insertDatasetSuite) TestClone() { @@ -71,557 +72,309 @@ func (ids *insertDatasetSuite) TestGetClauses() { } func (ids *insertDatasetSuite) TestWith() { - from := Insert("cte") - ds := Insert("test") - dsc := ds.GetClauses() - ec := dsc.CommonTablesAppend(exp.NewCommonTableExpression(false, "test-cte", from)) - ids.Equal(ec, ds.With("test-cte", from).GetClauses()) - ids.Equal(dsc, ds.GetClauses()) + from := From("cte") + bd := Insert("items") + ids.assertCases( + insertTestCase{ + ds: bd.With("test-cte", from), + clauses: exp.NewInsertClauses(). + SetInto(C("items")). + CommonTablesAppend(exp.NewCommonTableExpression(false, "test-cte", from)), + }, + insertTestCase{ + ds: bd, + clauses: exp.NewInsertClauses().SetInto(C("items")), + }, + ) } func (ids *insertDatasetSuite) TestWithRecursive() { - from := Insert("cte") - ds := Insert("test") - dsc := ds.GetClauses() - ec := dsc.CommonTablesAppend(exp.NewCommonTableExpression(true, "test-cte", from)) - ids.Equal(ec, ds.WithRecursive("test-cte", from).GetClauses()) - ids.Equal(dsc, ds.GetClauses()) -} - -func (ids *insertDatasetSuite) TestRows_ToSQLWithNullTimeField() { - type item struct { - CreatedAt *time.Time `db:"created_at"` - } - ds := Insert("items").Rows(item{CreatedAt: nil}) - insertSQL, args, err := ds.ToSQL() - ids.NoError(err) - ids.Empty(args) - ids.Equal(`INSERT INTO "items" ("created_at") VALUES (NULL)`, insertSQL) - - insertSQL, args, err = ds.Prepared(true).ToSQL() - ids.NoError(err) - ids.Empty(args) - ids.Equal(`INSERT INTO "items" ("created_at") VALUES (NULL)`, insertSQL) + from := From("cte") + bd := Insert("items") + ids.assertCases( + insertTestCase{ + ds: bd.WithRecursive("test-cte", from), + clauses: exp.NewInsertClauses(). + SetInto(C("items")). + CommonTablesAppend(exp.NewCommonTableExpression(true, "test-cte", from)), + }, + insertTestCase{ + ds: bd, + clauses: exp.NewInsertClauses().SetInto(C("items")), + }, + ) } -func (ids *insertDatasetSuite) TestRows_ToSQLWithInvalidValue() { - ds := Insert("test").Rows(true) - _, _, err := ds.ToSQL() - ids.EqualError(err, "goqu: unsupported insert must be map, goqu.Record, or struct type got: bool") +func (ids *insertDatasetSuite) TestInto() { + bd := Insert("items") + ids.assertCases( + insertTestCase{ + ds: bd.Into("items2"), + clauses: exp.NewInsertClauses().SetInto(C("items2")), + }, + insertTestCase{ + ds: bd.Into(L("items2")), + clauses: exp.NewInsertClauses().SetInto(L("items2")), + }, + insertTestCase{ + ds: bd, + clauses: exp.NewInsertClauses().SetInto(C("items")), + }, + ) - _, _, err = ds.Prepared(true).ToSQL() - ids.EqualError(err, "goqu: unsupported insert must be map, goqu.Record, or struct type got: bool") + ids.PanicsWithValue(errUnsupportedIntoType, func() { + bd.Into(true) + }) } -func (ids *insertDatasetSuite) TestRows_ToSQLWithStructs() { - type item struct { - Address string `db:"address"` - Name string `db:"name"` - Created time.Time `db:"created"` - } - ds := Insert("items") - created, _ := time.Parse("2006-01-02", "2015-01-01") - ds1 := ds.Rows(item{Name: "Test", Address: "111 Test Addr", Created: created}) - - insertSQL, args, err := ds1.ToSQL() - ids.NoError(err) - ids.Empty(args) - ids.Equal( - `INSERT INTO "items" ("address", "created", "name") VALUES ('111 Test Addr', '`+created.Format(time.RFC3339Nano)+`', 'Test')`, - insertSQL, - ) // #nosec - - insertSQL, args, err = ds1.Prepared(true).ToSQL() - ids.NoError(err) - ids.Equal([]interface{}{"111 Test Addr", created, "Test"}, args) - ids.Equal(`INSERT INTO "items" ("address", "created", "name") VALUES (?, ?, ?)`, insertSQL) - - ds2 := ds1.Rows( - item{Address: "111 Test Addr", Name: "Test1", Created: created}, - item{Address: "211 Test Addr", Name: "Test2", Created: created}, - item{Address: "311 Test Addr", Name: "Test3", Created: created}, - item{Address: "411 Test Addr", Name: "Test4", Created: created}, +func (ids *insertDatasetSuite) TestCols() { + bd := Insert("items") + ids.assertCases( + insertTestCase{ + ds: bd.Cols("a", "b"), + clauses: exp.NewInsertClauses(). + SetInto(C("items")). + SetCols(exp.NewColumnListExpression("a", "b")), + }, + insertTestCase{ + ds: bd.Cols("a", "b").Cols("c", "d"), + clauses: exp.NewInsertClauses(). + SetInto(C("items")). + SetCols(exp.NewColumnListExpression("c", "d")), + }, + insertTestCase{ + ds: bd, + clauses: exp.NewInsertClauses().SetInto(C("items")), + }, ) +} - insertSQL, args, err = ds2.ToSQL() - ids.NoError(err) - ids.Empty(args) - ids.Equal( - `INSERT INTO "items" ("address", "created", "name") VALUES `+ - `('111 Test Addr', '`+created.Format(time.RFC3339Nano)+`', 'Test1'), `+ - `('211 Test Addr', '`+created.Format(time.RFC3339Nano)+`', 'Test2'), `+ - `('311 Test Addr', '`+created.Format(time.RFC3339Nano)+`', 'Test3'), `+ - `('411 Test Addr', '`+created.Format(time.RFC3339Nano)+`', 'Test4')`, - insertSQL, +func (ids *insertDatasetSuite) TestClearCols() { + bd := Insert("items").Cols("a", "b") + ids.assertCases( + insertTestCase{ + ds: bd.ClearCols(), + clauses: exp.NewInsertClauses().SetInto(C("items")), + }, + insertTestCase{ + ds: bd, + clauses: exp.NewInsertClauses().SetInto(C("items")).SetCols(exp.NewColumnListExpression("a", "b")), + }, ) +} - insertSQL, args, err = ds2.Prepared(true).ToSQL() - ids.NoError(err) - ids.Equal([]interface{}{ - "111 Test Addr", created, "Test1", - "211 Test Addr", created, "Test2", - "311 Test Addr", created, "Test3", - "411 Test Addr", created, "Test4", - }, args) - ids.Equal( - `INSERT INTO "items" ("address", "created", "name") VALUES (?, ?, ?), (?, ?, ?), (?, ?, ?), (?, ?, ?)`, - insertSQL, +func (ids *insertDatasetSuite) TestColsAppend() { + bd := Insert("items").Cols("a") + ids.assertCases( + insertTestCase{ + ds: bd.ColsAppend("b"), + clauses: exp.NewInsertClauses().SetInto(C("items")).SetCols(exp.NewColumnListExpression("a", "b")), + }, + insertTestCase{ + ds: bd, + clauses: exp.NewInsertClauses().SetInto(C("items")).SetCols(exp.NewColumnListExpression("a")), + }, ) } -func (ids *insertDatasetSuite) TestRows_ToSQLWithEmbeddedStruct() { - type Phone struct { - Primary string `db:"primary_phone"` - Home string `db:"home_phone"` - } - type item struct { - Phone - Address string `db:"address"` - Name string `db:"name"` - } +func (ids *insertDatasetSuite) TestFromQuery() { bd := Insert("items") - ds := bd.Rows(item{ - Name: "Test", - Address: "111 Test Addr", - Phone: Phone{ - Home: "123123", - Primary: "456456", + ids.assertCases( + insertTestCase{ + ds: bd.FromQuery(From("other_items").Where(C("b").Gt(10))), + clauses: exp.NewInsertClauses(). + SetInto(C("items")). + SetFrom(From("other_items").Where(C("b").Gt(10))), + }, + insertTestCase{ + ds: bd.FromQuery(From("other_items").Where(C("b").Gt(10))).Cols("a", "b"), + clauses: exp.NewInsertClauses(). + SetInto(C("items")). + SetCols(exp.NewColumnListExpression("a", "b")). + SetFrom(From("other_items").Where(C("b").Gt(10))), + }, + insertTestCase{ + ds: bd, + clauses: exp.NewInsertClauses().SetInto(C("items")), }, - }) - - insertSQL, args, err := ds.ToSQL() - ids.NoError(err) - ids.Empty(args) - ids.Equal(`INSERT INTO "items" ("address", "home_phone", "name", "primary_phone") VALUES `+ - `('111 Test Addr', '123123', 'Test', '456456')`, insertSQL) - - insertSQL, args, err = ds.Prepared(true).ToSQL() - ids.NoError(err) - ids.Equal([]interface{}{"111 Test Addr", "123123", "Test", "456456"}, args) - ids.Equal( - `INSERT INTO "items" ("address", "home_phone", "name", "primary_phone") VALUES (?, ?, ?, ?)`, - insertSQL, - ) - - ds = bd.Rows( - item{Address: "111 Test Addr", Name: "Test1", Phone: Phone{Home: "123123", Primary: "456456"}}, - item{Address: "211 Test Addr", Name: "Test2", Phone: Phone{Home: "123123", Primary: "456456"}}, - item{Address: "311 Test Addr", Name: "Test3", Phone: Phone{Home: "123123", Primary: "456456"}}, - item{Address: "411 Test Addr", Name: "Test4", Phone: Phone{Home: "123123", Primary: "456456"}}, ) - insertSQL, args, err = ds.ToSQL() - ids.NoError(err) - ids.Empty(args) - ids.Equal(`INSERT INTO "items" ("address", "home_phone", "name", "primary_phone") VALUES `+ - `('111 Test Addr', '123123', 'Test1', '456456'), `+ - `('211 Test Addr', '123123', 'Test2', '456456'), `+ - `('311 Test Addr', '123123', 'Test3', '456456'), `+ - `('411 Test Addr', '123123', 'Test4', '456456')`, insertSQL) - - insertSQL, args, err = ds.Prepared(true).ToSQL() - ids.NoError(err) - ids.Equal([]interface{}{ - "111 Test Addr", "123123", "Test1", "456456", - "211 Test Addr", "123123", "Test2", "456456", - "311 Test Addr", "123123", "Test3", "456456", - "411 Test Addr", "123123", "Test4", "456456", - }, args) - ids.Equal(`INSERT INTO "items" ("address", "home_phone", "name", "primary_phone") VALUES `+ - `(?, ?, ?, ?), `+ - `(?, ?, ?, ?), `+ - `(?, ?, ?, ?), `+ - `(?, ?, ?, ?)`, insertSQL) } -func (ids *insertDatasetSuite) TestRows_ToSQLWithEmbeddedStructPtr() { - type Phone struct { - Primary string `db:"primary_phone"` - Home string `db:"home_phone"` +func (ids *insertDatasetSuite) TestVals() { + + val1 := []interface{}{ + "a", "b", } - type item struct { - *Phone - Address string `db:"address"` - Name string `db:"name"` + val2 := []interface{}{ + "c", "d", } + bd := Insert("items") - ds := bd.Rows(item{ - Name: "Test", - Address: "111 Test Addr", - Phone: &Phone{ - Home: "123123", - Primary: "456456", + ids.assertCases( + insertTestCase{ + ds: bd.Vals(val1), + clauses: exp.NewInsertClauses(). + SetInto(C("items")). + SetVals([][]interface{}{val1}), + }, + insertTestCase{ + ds: bd.Vals(val1, val2), + clauses: exp.NewInsertClauses(). + SetInto(C("items")). + SetVals([][]interface{}{val1, val2}), + }, + insertTestCase{ + ds: bd.Vals(val1).Vals(val2), + clauses: exp.NewInsertClauses(). + SetInto(C("items")). + SetVals([][]interface{}{val1, val2}), + }, + insertTestCase{ + ds: bd, + clauses: exp.NewInsertClauses().SetInto(C("items")), }, - }) - - insertSQL, args, err := ds.ToSQL() - ids.NoError(err) - ids.Empty(args) - ids.Equal(`INSERT INTO "items" ("address", "home_phone", "name", "primary_phone") VALUES `+ - `('111 Test Addr', '123123', 'Test', '456456')`, insertSQL) - - insertSQL, args, err = ds.Prepared(true).ToSQL() - ids.NoError(err) - ids.Equal([]interface{}{"111 Test Addr", "123123", "Test", "456456"}, args) - ids.Equal( - insertSQL, - `INSERT INTO "items" ("address", "home_phone", "name", "primary_phone") VALUES (?, ?, ?, ?)`, ) +} - ds = bd.Rows( - item{Address: "111 Test Addr", Name: "Test1", Phone: &Phone{Home: "123123", Primary: "456456"}}, - item{Address: "211 Test Addr", Name: "Test2", Phone: &Phone{Home: "123123", Primary: "456456"}}, - item{Address: "311 Test Addr", Name: "Test3", Phone: &Phone{Home: "123123", Primary: "456456"}}, - item{Address: "411 Test Addr", Name: "Test4", Phone: &Phone{Home: "123123", Primary: "456456"}}, +func (ids *insertDatasetSuite) TestClearVals() { + val := []interface{}{ + "a", "b", + } + bd := Insert("items").Vals(val) + ids.assertCases( + insertTestCase{ + ds: bd.ClearVals(), + clauses: exp.NewInsertClauses().SetInto(C("items")), + }, + insertTestCase{ + ds: bd, + clauses: exp.NewInsertClauses().SetInto(C("items")).SetVals([][]interface{}{val}), + }, ) - insertSQL, args, err = ds.ToSQL() - ids.NoError(err) - ids.Empty(args) - ids.Equal(`INSERT INTO "items" ("address", "home_phone", "name", "primary_phone") VALUES `+ - `('111 Test Addr', '123123', 'Test1', '456456'), `+ - `('211 Test Addr', '123123', 'Test2', '456456'), `+ - `('311 Test Addr', '123123', 'Test3', '456456'), `+ - `('411 Test Addr', '123123', 'Test4', '456456')`, insertSQL) - - insertSQL, args, err = ds.Prepared(true).ToSQL() - ids.NoError(err) - ids.Equal([]interface{}{ - "111 Test Addr", "123123", "Test1", "456456", - "211 Test Addr", "123123", "Test2", "456456", - "311 Test Addr", "123123", "Test3", "456456", - "411 Test Addr", "123123", "Test4", "456456", - }, args) - ids.Equal(`INSERT INTO "items" ("address", "home_phone", "name", "primary_phone") VALUES `+ - `(?, ?, ?, ?), `+ - `(?, ?, ?, ?), `+ - `(?, ?, ?, ?), `+ - `(?, ?, ?, ?)`, insertSQL) } -func (ids *insertDatasetSuite) TestRows_ToSQLWithValuer() { +func (ids *insertDatasetSuite) TestRows() { type item struct { - Address string `db:"address"` - Name string `db:"name"` - Valuer sql.NullInt64 `db:"valuer"` + CreatedAt *time.Time `db:"created_at"` } - + n := time.Now() + r := item{CreatedAt: nil} + r2 := item{CreatedAt: &n} bd := Insert("items") - ds := bd.Rows(item{Name: "Test", Address: "111 Test Addr", Valuer: sql.NullInt64{Int64: 10, Valid: true}}) - insertSQL, args, err := ds.ToSQL() - ids.NoError(err) - ids.Empty(args) - ids.Equal(`INSERT INTO "items" ("address", "name", "valuer") VALUES ('111 Test Addr', 'Test', 10)`, insertSQL) - - insertSQL, args, err = ds.Prepared(true).ToSQL() - ids.NoError(err) - ids.Equal([]interface{}{"111 Test Addr", "Test", int64(10)}, args) - ids.Equal(`INSERT INTO "items" ("address", "name", "valuer") VALUES (?, ?, ?)`, insertSQL) - - ds = bd.Rows( - item{Address: "111 Test Addr", Name: "Test1", Valuer: sql.NullInt64{Int64: 10, Valid: true}}, - item{Address: "211 Test Addr", Name: "Test2", Valuer: sql.NullInt64{Int64: 20, Valid: true}}, - item{Address: "311 Test Addr", Name: "Test3", Valuer: sql.NullInt64{Int64: 30, Valid: true}}, - item{Address: "411 Test Addr", Name: "Test4", Valuer: sql.NullInt64{Int64: 40, Valid: true}}, + ids.assertCases( + insertTestCase{ + ds: bd.Rows(r), + clauses: exp.NewInsertClauses().SetInto(C("items")).SetRows([]interface{}{r}), + }, + insertTestCase{ + ds: bd.Rows(r).Rows(r2), + clauses: exp.NewInsertClauses().SetInto(C("items")).SetRows([]interface{}{r2}), + }, + insertTestCase{ + ds: bd, + clauses: exp.NewInsertClauses().SetInto(C("items")), + }, ) - insertSQL, args, err = ds.ToSQL() - ids.NoError(err) - ids.Empty(args) - ids.Equal(`INSERT INTO "items" ("address", "name", "valuer") VALUES `+ - `('111 Test Addr', 'Test1', 10), `+ - `('211 Test Addr', 'Test2', 20), `+ - `('311 Test Addr', 'Test3', 30), `+ - `('411 Test Addr', 'Test4', 40)`, insertSQL) - - insertSQL, args, err = ds.Prepared(true).ToSQL() - ids.NoError(err) - ids.Equal([]interface{}{ - "111 Test Addr", "Test1", int64(10), - "211 Test Addr", "Test2", int64(20), - "311 Test Addr", "Test3", int64(30), - "411 Test Addr", "Test4", int64(40), - }, args) - ids.Equal(`INSERT INTO "items" ("address", "name", "valuer") VALUES `+ - `(?, ?, ?), `+ - `(?, ?, ?), `+ - `(?, ?, ?), `+ - `(?, ?, ?)`, insertSQL) } -func (ids *insertDatasetSuite) TestRows_ToSQLWithValuerNull() { +func (ids *insertDatasetSuite) TestClearRows() { type item struct { - Address string `db:"address"` - Name string `db:"name"` - Valuer sql.NullInt64 `db:"valuer"` + CreatedAt *time.Time `db:"created_at"` } - - bd := Insert("items") - ds := bd.Rows(item{Name: "Test", Address: "111 Test Addr"}) - insertSQL, args, err := ds.ToSQL() - ids.NoError(err) - ids.Empty(args) - ids.Equal(`INSERT INTO "items" ("address", "name", "valuer") VALUES ('111 Test Addr', 'Test', NULL)`, insertSQL) - - insertSQL, args, err = ds.Prepared(true).ToSQL() - ids.NoError(err) - ids.Equal([]interface{}{"111 Test Addr", "Test"}, args) - ids.Equal(`INSERT INTO "items" ("address", "name", "valuer") VALUES (?, ?, NULL)`, insertSQL) - - ds = bd.Rows( - item{Address: "111 Test Addr", Name: "Test1"}, - item{Address: "211 Test Addr", Name: "Test2"}, - item{Address: "311 Test Addr", Name: "Test3"}, - item{Address: "411 Test Addr", Name: "Test4"}, - ) - insertSQL, args, err = ds.ToSQL() - ids.NoError(err) - ids.Empty(args) - ids.Equal(`INSERT INTO "items" ("address", "name", "valuer") VALUES `+ - `('111 Test Addr', 'Test1', NULL), `+ - `('211 Test Addr', 'Test2', NULL), `+ - `('311 Test Addr', 'Test3', NULL), `+ - `('411 Test Addr', 'Test4', NULL)`, - insertSQL, + r := item{CreatedAt: nil} + bd := Insert("items").Rows(r) + ids.assertCases( + insertTestCase{ + ds: bd.ClearRows(), + clauses: exp.NewInsertClauses().SetInto(C("items")), + }, + insertTestCase{ + ds: bd, + clauses: exp.NewInsertClauses().SetInto(C("items")).SetRows([]interface{}{r}), + }, ) - insertSQL, args, err = ds.Prepared(true).ToSQL() - ids.NoError(err) - ids.Equal([]interface{}{ - "111 Test Addr", "Test1", - "211 Test Addr", "Test2", - "311 Test Addr", "Test3", - "411 Test Addr", "Test4", - }, args) - ids.Equal(`INSERT INTO "items" ("address", "name", "valuer") VALUES `+ - `(?, ?, NULL), `+ - `(?, ?, NULL), `+ - `(?, ?, NULL), `+ - `(?, ?, NULL)`, - insertSQL, - ) } -func (ids *insertDatasetSuite) TestRows_ToSQLWithMaps() { - ds := Insert("items") +func (ids *insertDatasetSuite) TestOnConflict() { + du := DoUpdate("other_items", Record{"a": 1}) - ds1 := ds.Rows(map[string]interface{}{"name": "Test", "address": "111 Test Addr"}) - insertSQL, args, err := ds1.ToSQL() - ids.NoError(err) - ids.Empty(args) - ids.Equal(`INSERT INTO "items" ("address", "name") VALUES ('111 Test Addr', 'Test')`, insertSQL) - - insertSQL, args, err = ds1.Prepared(true).ToSQL() - ids.NoError(err) - ids.Equal([]interface{}{"111 Test Addr", "Test"}, args) - ids.Equal(`INSERT INTO "items" ("address", "name") VALUES (?, ?)`, insertSQL) - - ds1 = ds.Rows( - map[string]interface{}{"address": "111 Test Addr", "name": "Test1"}, - map[string]interface{}{"address": "211 Test Addr", "name": "Test2"}, - map[string]interface{}{"address": "311 Test Addr", "name": "Test3"}, - map[string]interface{}{"address": "411 Test Addr", "name": "Test4"}, - ) - insertSQL, _, err = ds1.ToSQL() - ids.NoError(err) - ids.Equal(`INSERT INTO "items" ("address", "name") VALUES `+ - `('111 Test Addr', 'Test1'), `+ - `('211 Test Addr', 'Test2'), `+ - `('311 Test Addr', 'Test3'), `+ - `('411 Test Addr', 'Test4')`, - insertSQL, + bd := Insert("items") + ids.assertCases( + insertTestCase{ + ds: bd.OnConflict(nil), + clauses: exp.NewInsertClauses().SetInto(C("items")), + }, + insertTestCase{ + ds: bd.OnConflict(DoNothing()), + clauses: exp.NewInsertClauses().SetInto(C("items")).SetOnConflict(DoNothing()), + }, + insertTestCase{ + ds: bd.OnConflict(du), + clauses: exp.NewInsertClauses().SetInto(C("items")).SetOnConflict(du), + }, + insertTestCase{ + ds: bd, + clauses: exp.NewInsertClauses().SetInto(C("items")), + }, ) - - insertSQL, args, err = ds1.Prepared(true).ToSQL() - ids.NoError(err) - ids.Equal([]interface{}{ - "111 Test Addr", "Test1", - "211 Test Addr", "Test2", - "311 Test Addr", "Test3", - "411 Test Addr", "Test4", - }, args) - ids.Equal(`INSERT INTO "items" ("address", "name") VALUES (?, ?), (?, ?), (?, ?), (?, ?)`, insertSQL) } -func (ids *insertDatasetSuite) TestRows_ToSQLWithSQLBuilder() { - ds := Insert("items") +func (ids *insertDatasetSuite) TestClearOnConflict() { + du := DoUpdate("other_items", Record{"a": 1}) - ds1 := ds.Rows(From("other_items").Where(C("b").Gt(10))) - - insertSQL, args, err := ds1.ToSQL() - ids.NoError(err) - ids.Empty(args) - ids.Equal(`INSERT INTO "items" SELECT * FROM "other_items" WHERE ("b" > 10)`, insertSQL) - - insertSQL, args, err = ds1.Prepared(true).ToSQL() - ids.NoError(err) - ids.Equal([]interface{}{int64(10)}, args) - ids.Equal(`INSERT INTO "items" SELECT * FROM "other_items" WHERE ("b" > ?)`, insertSQL) -} - -func (ids *insertDatasetSuite) TestRows_ToSQLWithMapsWithDifferentLengths() { - ds1 := Insert("items").Rows( - map[string]interface{}{"address": "111 Test Addr", "name": "Test1"}, - map[string]interface{}{"address": "211 Test Addr"}, - map[string]interface{}{"address": "311 Test Addr", "name": "Test3"}, - map[string]interface{}{"address": "411 Test Addr", "name": "Test4"}, - ) - _, _, err := ds1.ToSQL() - ids.EqualError(err, "goqu: rows with different value length expected 2 got 1") - _, _, err = ds1.Prepared(true).ToSQL() - ids.EqualError(err, "goqu: rows with different value length expected 2 got 1") -} - -func (ids *insertDatasetSuite) TestRows_ToSQLWitDifferentKeys() { - ds := Insert("items").Rows( - map[string]interface{}{"address": "111 Test Addr", "name": "test"}, - map[string]interface{}{"phoneNumber": 10, "address": "111 Test Addr"}, + bd := Insert("items").OnConflict(du) + ids.assertCases( + insertTestCase{ + ds: bd.ClearOnConflict(), + clauses: exp.NewInsertClauses().SetInto(C("items")), + }, + insertTestCase{ + ds: bd, + clauses: exp.NewInsertClauses().SetInto(C("items")).SetOnConflict(du), + }, ) - _, _, err := ds.ToSQL() - ids.EqualError(err, `goqu: rows with different keys expected ["address","name"] got ["address","phoneNumber"]`) - - _, _, err = ds.Prepared(true).ToSQL() - ids.EqualError(err, `goqu: rows with different keys expected ["address","name"] got ["address","phoneNumber"]`) } -func (ids *insertDatasetSuite) TestRows_ToSQLDifferentTypes() { - type item struct { - Address string `db:"address"` - Name string `db:"name"` - } - type item2 struct { - Address string `db:"address"` - Name string `db:"name"` - } +func (ids *insertDatasetSuite) TestReturning() { bd := Insert("items") - ds := bd.Rows( - item{Address: "111 Test Addr", Name: "Test1"}, - item2{Address: "211 Test Addr", Name: "Test2"}, - item{Address: "311 Test Addr", Name: "Test3"}, - item2{Address: "411 Test Addr", Name: "Test4"}, - ) - _, _, err := ds.ToSQL() - ids.EqualError(err, "goqu: rows must be all the same type expected goqu.item got goqu.item2") - _, _, err = ds.Prepared(true).ToSQL() - ids.EqualError(err, "goqu: rows must be all the same type expected goqu.item got goqu.item2") - - ds = bd.Rows( - item{Address: "111 Test Addr", Name: "Test1"}, - map[string]interface{}{"address": "211 Test Addr", "name": "Test2"}, - item{Address: "311 Test Addr", Name: "Test3"}, - map[string]interface{}{"address": "411 Test Addr", "name": "Test4"}, - ) - _, _, err = ds.ToSQL() - ids.EqualError(err, "goqu: rows must be all the same type expected goqu.item got map[string]interface {}") - - _, _, err = ds.Prepared(true).ToSQL() - ids.EqualError(err, "goqu: rows must be all the same type expected goqu.item got map[string]interface {}") -} - -func (ids *insertDatasetSuite) TestRows_ToSQLWithGoquSkipInsertTagSQL() { - type item struct { - ID uint32 `db:"id" goqu:"skipinsert"` - Address string `db:"address"` - Name string `db:"name"` - } - ds := Insert("items") - - ds1 := ds.Rows(item{Name: "Test", Address: "111 Test Addr"}) - - insertSQL, args, err := ds1.ToSQL() - ids.NoError(err) - ids.Empty(args) - ids.Equal(`INSERT INTO "items" ("address", "name") VALUES ('111 Test Addr', 'Test')`, insertSQL) - - insertSQL, args, err = ds1.Prepared(true).ToSQL() - ids.NoError(err) - ids.Equal([]interface{}{"111 Test Addr", "Test"}, args) - ids.Equal(`INSERT INTO "items" ("address", "name") VALUES (?, ?)`, insertSQL) - - ds1 = ds.Rows( - item{Name: "Test1", Address: "111 Test Addr"}, - item{Name: "Test2", Address: "211 Test Addr"}, - item{Name: "Test3", Address: "311 Test Addr"}, - item{Name: "Test4", Address: "411 Test Addr"}, - ) - - insertSQL, args, err = ds1.ToSQL() - ids.NoError(err) - ids.Empty(args) - ids.Equal(`INSERT INTO "items" ("address", "name") VALUES `+ - `('111 Test Addr', 'Test1'), `+ - `('211 Test Addr', 'Test2'), `+ - `('311 Test Addr', 'Test3'), `+ - `('411 Test Addr', 'Test4')`, - insertSQL, + ids.assertCases( + insertTestCase{ + ds: bd.Returning("a"), + clauses: exp.NewInsertClauses(). + SetInto(C("items")). + SetReturning(exp.NewColumnListExpression("a")), + }, + insertTestCase{ + ds: bd.Returning("a").Returning("b"), + clauses: exp.NewInsertClauses(). + SetInto(C("items")). + SetReturning(exp.NewColumnListExpression("b")), + }, + insertTestCase{ + ds: bd, + clauses: exp.NewInsertClauses().SetInto(C("items")), + }, ) - - insertSQL, args, err = ds1.Prepared(true).ToSQL() - ids.NoError(err) - ids.Equal([]interface{}{ - "111 Test Addr", "Test1", - "211 Test Addr", "Test2", - "311 Test Addr", "Test3", - "411 Test Addr", "Test4", - }, args) - ids.Equal(`INSERT INTO "items" ("address", "name") VALUES (?, ?), (?, ?), (?, ?), (?, ?)`, insertSQL) -} - -func (ids *insertDatasetSuite) TestRows_ToSQLWithGoquDefaultIfEmptyTag() { - type item struct { - ID uint32 `db:"id" goqu:"skipinsert"` - Address string `db:"address" goqu:"defaultifempty"` - Name string `db:"name" goqu:"defaultifempty"` - Bool bool `db:"bool" goqu:"skipinsert,defaultifempty"` - } - ds := Insert("items") - - ds1 := ds.Rows(item{Name: "Test", Address: "111 Test Addr"}) - - insertSQL, args, err := ds1.ToSQL() - ids.NoError(err) - ids.Empty(args) - ids.Equal(`INSERT INTO "items" ("address", "name") VALUES ('111 Test Addr', 'Test')`, insertSQL) - - insertSQL, args, err = ds1.Prepared(true).ToSQL() - ids.NoError(err) - ids.Equal([]interface{}{"111 Test Addr", "Test"}, args) - ids.Equal(`INSERT INTO "items" ("address", "name") VALUES (?, ?)`, insertSQL) - - ds1 = ds.Rows(item{}) - - insertSQL, args, err = ds1.ToSQL() - ids.NoError(err) - ids.Empty(args) - ids.Equal(`INSERT INTO "items" ("address", "name") VALUES (DEFAULT, DEFAULT)`, insertSQL) - - insertSQL, args, err = ds1.Prepared(true).ToSQL() - ids.NoError(err) - ids.Empty(args) - ids.Equal(`INSERT INTO "items" ("address", "name") VALUES (DEFAULT, DEFAULT)`, insertSQL) } -func (ids *insertDatasetSuite) TestRows_ToSQLWithDefaultValues() { - ds := Insert("items") - ds1 := ds.Rows() - - insertSQL, args, err := ds1.ToSQL() +func (ids *insertDatasetSuite) TestExecutor() { + mDb, _, err := sqlmock.New() ids.NoError(err) - ids.Empty(args) - ids.Equal(`INSERT INTO "items" DEFAULT VALUES`, insertSQL) - insertSQL, args, err = ds1.Prepared(true).ToSQL() - ids.NoError(err) - ids.Empty(args) - ids.Equal(`INSERT INTO "items" DEFAULT VALUES`, insertSQL) + ds := newInsertDataset("mock", exec.NewQueryFactory(mDb)). + Into("items"). + Rows(Record{"address": "111 Test Addr", "name": "Test1"}) - ds1 = ds.Rows(map[string]interface{}{"name": Default(), "address": Default()}) - insertSQL, args, err = ds1.ToSQL() + isql, args, err := ds.Executor().ToSQL() ids.NoError(err) ids.Empty(args) - ids.Equal(`INSERT INTO "items" ("address", "name") VALUES (DEFAULT, DEFAULT)`, insertSQL) + ids.Equal(`INSERT INTO "items" ("address", "name") VALUES ('111 Test Addr', 'Test1')`, isql) - insertSQL, _, err = ds1.ToSQL() + isql, args, err = ds.Prepared(true).Executor().ToSQL() ids.NoError(err) - ids.Empty(args) - ids.Equal(`INSERT INTO "items" ("address", "name") VALUES (DEFAULT, DEFAULT)`, insertSQL) + ids.Equal([]interface{}{"111 Test Addr", "Test1"}, args) + ids.Equal(`INSERT INTO "items" ("address", "name") VALUES (?, ?)`, isql) } func (ids *insertDatasetSuite) TestToSQL() { @@ -637,14 +390,17 @@ func (ids *insertDatasetSuite) TestToSQL() { md.AssertExpectations(ids.T()) } -func (ids *insertDatasetSuite) TestToSQL_WithNoInto() { - ds1 := newInsertDataset("test", nil).Rows(map[string]interface{}{ - "address": "111 Test Addr", "name": "Test1", - }) - _, _, err := ds1.ToSQL() - ids.EqualError(err, "goqu: no source found when generating insert sql") - _, _, err = ds1.Prepared(true).ToSQL() - ids.EqualError(err, "goqu: no source found when generating insert sql") +func (ids *insertDatasetSuite) TestToSQL_Prepared() { + md := new(mocks.SQLDialect) + ds := Insert("test").SetDialect(md).Prepared(true) + c := ds.GetClauses() + sqlB := sb.NewSQLBuilder(true) + md.On("ToInsertSQL", sqlB, c).Return(nil).Once() + insertSQL, args, err := ds.ToSQL() + ids.Empty(insertSQL) + ids.Empty(args) + ids.Nil(err) + md.AssertExpectations(ids.T()) } func (ids *insertDatasetSuite) TestToSQL_ReturnedError() { @@ -664,247 +420,6 @@ func (ids *insertDatasetSuite) TestToSQL_ReturnedError() { md.AssertExpectations(ids.T()) } -func (ids *insertDatasetSuite) TestFromQuery_ToSQL() { - bd := Insert("items") - - ds := bd.FromQuery(From("other_items").Where(C("b").Gt(10))) - - insertSQL, args, err := ds.ToSQL() - ids.NoError(err) - ids.Empty(args) - ids.Equal(`INSERT INTO "items" SELECT * FROM "other_items" WHERE ("b" > 10)`, insertSQL) - - insertSQL, args, err = ds.Prepared(true).ToSQL() - ids.NoError(err) - ids.Equal([]interface{}{int64(10)}, args) - ids.Equal(`INSERT INTO "items" SELECT * FROM "other_items" WHERE ("b" > ?)`, insertSQL) -} - -func (ids *insertDatasetSuite) TestFromQuery_ToSQLWithCols() { - bd := Insert("items") - - ds := bd.Cols("a", "b").FromQuery(From("other_items").Select("c", "d").Where(C("b").Gt(10))) - - insertSQL, args, err := ds.ToSQL() - ids.NoError(err) - ids.Empty(args) - ids.Equal(`INSERT INTO "items" ("a", "b") SELECT "c", "d" FROM "other_items" WHERE ("b" > 10)`, insertSQL) - - insertSQL, args, err = ds.Prepared(true).ToSQL() - ids.NoError(err) - ids.Equal([]interface{}{int64(10)}, args) - ids.Equal(`INSERT INTO "items" ("a", "b") SELECT "c", "d" FROM "other_items" WHERE ("b" > ?)`, insertSQL) -} - -func (ids *insertDatasetSuite) TestOnConflict__ToSQLNilConflictExpression() { - type item struct { - Address string `db:"address"` - Name string `db:"name"` - } - ds := Insert("items").Rows(item{Name: "Test", Address: "111 Test Addr"}).OnConflict(nil) - insertSQL, args, err := ds.ToSQL() - ids.NoError(err) - ids.Empty(args) - ids.Equal(`INSERT INTO "items" ("address", "name") VALUES ('111 Test Addr', 'Test')`, insertSQL) - - insertSQL, args, err = ds.Prepared(true).ToSQL() - ids.NoError(err) - ids.Equal([]interface{}{"111 Test Addr", "Test"}, args) - ids.Equal(`INSERT INTO "items" ("address", "name") VALUES (?, ?)`, insertSQL) -} - -func (ids *insertDatasetSuite) TestOnConflict__ToSQLDoUpdate() { - type item struct { - Address string `db:"address"` - Name string `db:"name"` - } - i := item{Name: "Test", Address: "111 Test Addr"} - ds := Insert("items").Rows(i).OnConflict( - DoUpdate("name", Record{"address": L("excluded.address")}), - ) - insertSQL, args, err := ds.ToSQL() - ids.NoError(err) - ids.Empty(args) - ids.Equal(`INSERT INTO "items" ("address", "name") VALUES `+ - `('111 Test Addr', 'Test') `+ - `ON CONFLICT (name) `+ - `DO UPDATE `+ - `SET "address"=excluded.address`, insertSQL) - - insertSQL, args, err = ds.Prepared(true).ToSQL() - ids.NoError(err) - ids.Equal([]interface{}{"111 Test Addr", "Test"}, args) - ids.Equal( - `INSERT INTO "items" ("address", "name") VALUES (?, ?) ON CONFLICT (name) DO UPDATE SET "address"=excluded.address`, - insertSQL, - ) -} - -func (ids *insertDatasetSuite) TestOnConflict__ToSQLDoUpdateWhere() { - type item struct { - Address string `db:"address"` - Name string `db:"name"` - } - i := item{Name: "Test", Address: "111 Test Addr"} - ds := Insert("items").Rows(i).OnConflict( - DoUpdate("name", Record{"address": L("excluded.address")}). - Where(C("name").Eq("Test")), - ) - - insertSQL, args, err := ds.ToSQL() - ids.NoError(err) - ids.Empty(args) - ids.Equal(`INSERT INTO "items" ("address", "name") VALUES `+ - `('111 Test Addr', 'Test') `+ - `ON CONFLICT (name) `+ - `DO UPDATE `+ - `SET "address"=excluded.address WHERE ("name" = 'Test')`, insertSQL) - - insertSQL, args, err = ds.Prepared(true).ToSQL() - ids.NoError(err) - ids.Equal([]interface{}{"111 Test Addr", "Test", "Test"}, args) - ids.Equal(`INSERT INTO "items" ("address", "name") VALUES `+ - `(?, ?) `+ - `ON CONFLICT (name) `+ - `DO UPDATE `+ - `SET "address"=excluded.address WHERE ("name" = ?)`, insertSQL) -} - -func (ids *insertDatasetSuite) TestOnConflict__ToSQLWithDatasetDoUpdateWhere() { - fromDs := From("ds2") - ds := Insert("items"). - FromQuery(fromDs). - OnConflict( - DoUpdate("name", Record{"address": L("excluded.address")}).Where(C("name").Eq("Test")), - ) - - insertSQL, args, err := ds.ToSQL() - ids.NoError(err) - ids.Empty(args) - ids.Equal(`INSERT INTO "items" `+ - `SELECT * FROM "ds2" `+ - `ON CONFLICT (name) `+ - `DO UPDATE `+ - `SET "address"=excluded.address WHERE ("name" = 'Test')`, insertSQL) - - insertSQL, args, err = ds.Prepared(true).ToSQL() - ids.NoError(err) - ids.Equal([]interface{}{"Test"}, args) - ids.Equal(`INSERT INTO "items" `+ - `SELECT * FROM "ds2" `+ - `ON CONFLICT (name) `+ - `DO UPDATE `+ - `SET "address"=excluded.address WHERE ("name" = ?)`, insertSQL) -} - -func (ids *insertDatasetSuite) TestOnConflict_ToSQLDoNothing() { - type item struct { - Address string `db:"address"` - Name string `db:"name"` - } - ds := Insert("items").Rows(item{Name: "Test", Address: "111 Test Addr"}).OnConflict(DoNothing()) - insertSQL, args, err := ds.ToSQL() - ids.NoError(err) - ids.Empty(args) - ids.Equal(`INSERT INTO "items" ("address", "name") VALUES `+ - `('111 Test Addr', 'Test') `+ - `ON CONFLICT DO NOTHING`, insertSQL) - - insertSQL, args, err = ds.Prepared(true).ToSQL() - ids.NoError(err) - ids.Equal([]interface{}{"111 Test Addr", "Test"}, args) - ids.Equal(`INSERT INTO "items" ("address", "name") VALUES (?, ?) ON CONFLICT DO NOTHING`, insertSQL) -} - -func (ids *insertDatasetSuite) TestReturning() { - ds := Insert("test") - dsc := ds.GetClauses() - ec := dsc.SetReturning(exp.NewColumnListExpression(C("a"))) - ids.Equal(ec, ds.Returning("a").GetClauses()) - ids.Equal(dsc, ds.GetClauses()) -} - -func (ids *insertDatasetSuite) TestReturning_ToSQL() { - type item struct { - Address string `db:"address"` - Name string `db:"name"` - } - bd := Insert("items").Returning("id") - - ds := bd.FromQuery(From("other_items").Where(C("b").Gt(10))) - - insertSQL, args, err := ds.ToSQL() - ids.NoError(err) - ids.Empty(args) - ids.Equal(`INSERT INTO "items" SELECT * FROM "other_items" WHERE ("b" > 10) RETURNING "id"`, insertSQL) - - insertSQL, args, err = ds.Prepared(true).ToSQL() - ids.NoError(err) - ids.Equal([]interface{}{int64(10)}, args) - ids.Equal(`INSERT INTO "items" SELECT * FROM "other_items" WHERE ("b" > ?) RETURNING "id"`, insertSQL) - - ds = bd.Rows(map[string]interface{}{"name": "Test", "address": "111 Test Addr"}) - - insertSQL, args, err = ds.ToSQL() - ids.NoError(err) - ids.Empty(args) - ids.Equal( - `INSERT INTO "items" ("address", "name") VALUES ('111 Test Addr', 'Test') RETURNING "id"`, - insertSQL, - ) - - insertSQL, args, err = ds.Prepared(true).ToSQL() - ids.NoError(err) - ids.Equal([]interface{}{"111 Test Addr", "Test"}, args) - ids.Equal(`INSERT INTO "items" ("address", "name") VALUES (?, ?) RETURNING "id"`, insertSQL) - - ds = bd.Rows(item{Name: "Test", Address: "111 Test Addr"}) - - insertSQL, _, err = ds.ToSQL() - ids.NoError(err) - ids.Equal( - `INSERT INTO "items" ("address", "name") VALUES ('111 Test Addr', 'Test') RETURNING "id"`, - insertSQL, - ) - - insertSQL, args, err = ds.Prepared(true).ToSQL() - ids.NoError(err) - ids.Equal([]interface{}{"111 Test Addr", "Test"}, args) - ids.Equal(`INSERT INTO "items" ("address", "name") VALUES (?, ?) RETURNING "id"`, insertSQL) -} - -func (ids *insertDatasetSuite) TestReturning_ToSQLReturnNotSupported() { - ds1 := New("no-return", nil).Insert("items") - type item struct { - Address string `db:"address"` - Name string `db:"name"` - } - _, _, err := ds1.Returning("id").Rows(item{Name: "Test", Address: "111 Test Addr"}).ToSQL() - ids.EqualError(err, "goqu: dialect does not support RETURNING clause [dialect=no-return]") - - _, _, err = ds1.Returning("id").Rows(From("test2")).ToSQL() - ids.EqualError(err, "goqu: dialect does not support RETURNING clause [dialect=no-return]") -} - -func (ids *insertDatasetSuite) TestExecutor() { - mDb, _, err := sqlmock.New() - ids.NoError(err) - - ds := newInsertDataset("mock", exec.NewQueryFactory(mDb)). - Into("items"). - Rows(Record{"address": "111 Test Addr", "name": "Test1"}) - - isql, args, err := ds.Executor().ToSQL() - ids.NoError(err) - ids.Empty(args) - ids.Equal(`INSERT INTO "items" ("address", "name") VALUES ('111 Test Addr', 'Test1')`, isql) - - isql, args, err = ds.Prepared(true).Executor().ToSQL() - ids.NoError(err) - ids.Equal([]interface{}{"111 Test Addr", "Test1"}, args) - ids.Equal(`INSERT INTO "items" ("address", "name") VALUES (?, ?)`, isql) -} - func TestInsertDataset(t *testing.T) { suite.Run(t, new(insertDatasetSuite)) } diff --git a/internal/sb/sql_builder.go b/internal/sb/sql_builder.go index 8dbf660d..643b0083 100644 --- a/internal/sb/sql_builder.go +++ b/internal/sb/sql_builder.go @@ -9,7 +9,6 @@ type ( SQLBuilder interface { Error() error SetError(err error) SQLBuilder - Clear() SQLBuilder WriteArg(i ...interface{}) SQLBuilder Write(p []byte) SQLBuilder WriteStrings(ss ...string) SQLBuilder @@ -49,13 +48,6 @@ func (b *sqlBuilder) SetError(err error) SQLBuilder { return b } -func (b *sqlBuilder) Clear() SQLBuilder { - b.buf.Truncate(0) - b.args = make([]interface{}, 0) - b.err = nil - return b -} - func (b *sqlBuilder) Write(bs []byte) SQLBuilder { if b.err == nil { b.buf.Write(bs) diff --git a/select_dataset.go b/select_dataset.go index 0b3afb22..2852af42 100644 --- a/select_dataset.go +++ b/select_dataset.go @@ -112,7 +112,7 @@ func (sd *SelectDataset) Update() *UpdateDataset { if sd.clauses.Where() != nil { c = c.WhereAppend(sd.clauses.Where()) } - if c.HasLimit() { + if sd.clauses.HasLimit() { c = c.SetLimit(sd.clauses.Limit()) } if sd.clauses.HasOrder() { @@ -225,6 +225,10 @@ func (sd *SelectDataset) Select(selects ...interface{}) *SelectDataset { // See examples // Deprecated: Use Distinct() instead. func (sd *SelectDataset) SelectDistinct(selects ...interface{}) *SelectDataset { + if len(selects) == 0 { + cleared := sd.ClearSelect() + return cleared.copy(cleared.clauses.SetDistinct(nil)) + } return sd.copy(sd.clauses.SetSelect(exp.NewColumnListExpression(selects...)).SetDistinct(exp.NewColumnListExpression())) } @@ -533,6 +537,9 @@ func (sd *SelectDataset) ScanStructs(i interface{}) error { // // i: A pointer to a slice of structs func (sd *SelectDataset) ScanStructsContext(ctx context.Context, i interface{}) error { + if sd.queryFactory == nil { + return errQueryFactoryNotFoundError + } ds := sd if sd.GetClauses().IsDefaultSelect() { ds = sd.Select(i) @@ -596,6 +603,9 @@ func (sd *SelectDataset) ScanVal(i interface{}) (bool, error) { // // i: A pointer to a primitive value func (sd *SelectDataset) ScanValContext(ctx context.Context, i interface{}) (bool, error) { + if sd.queryFactory == nil { + return false, errQueryFactoryNotFoundError + } return sd.Limit(1).Executor().ScanValContext(ctx, i) } diff --git a/select_dataset_test.go b/select_dataset_test.go index 92436345..f6ae286b 100644 --- a/select_dataset_test.go +++ b/select_dataset_test.go @@ -13,2086 +13,1004 @@ import ( "github.com/stretchr/testify/suite" ) -type dsTestActionItem struct { - Address string `db:"address"` - Name string `db:"name"` -} - -type selectDatasetSuite struct { - suite.Suite -} - -func (sds *selectDatasetSuite) TestClone() { - ds := From("test") - sds.Equal(ds, ds.Clone()) -} - -func (sds *selectDatasetSuite) TestExpression() { - ds := From("test") - sds.Equal(ds, ds.Expression()) -} - -func (sds *selectDatasetSuite) TestDialect() { - ds := From("test") - sds.NotNil(ds.Dialect()) -} - -func (sds *selectDatasetSuite) TestWithDialect() { - ds := From("test") - md := new(mocks.SQLDialect) - ds = ds.SetDialect(md) - - dialect := GetDialect("default") - dialectDs := ds.WithDialect("default") - sds.Equal(md, ds.Dialect()) - sds.Equal(dialect, dialectDs.Dialect()) -} - -func (sds *selectDatasetSuite) TestPrepared() { - ds := From("test") - preparedDs := ds.Prepared(true) - sds.True(preparedDs.IsPrepared()) - sds.False(ds.IsPrepared()) - // should apply the prepared to any datasets created from the root - sds.True(preparedDs.Where(Ex{"a": 1}).IsPrepared()) -} - -func (sds *selectDatasetSuite) TestGetClauses() { - ds := From("test") - ce := exp.NewSelectClauses().SetFrom(exp.NewColumnListExpression(I("test"))) - sds.Equal(ce, ds.GetClauses()) -} - -func (sds *selectDatasetSuite) TestWith() { - from := From("cte") - ds := From("test") - dsc := ds.GetClauses() - ec := dsc.CommonTablesAppend(exp.NewCommonTableExpression(false, "test-cte", from)) - sds.Equal(ec, ds.With("test-cte", from).GetClauses()) - sds.Equal(dsc, ds.GetClauses()) -} - -func (sds *selectDatasetSuite) TestWithRecursive() { - from := From("cte") - ds := From("test") - dsc := ds.GetClauses() - ec := dsc.CommonTablesAppend(exp.NewCommonTableExpression(true, "test-cte", from)) - sds.Equal(ec, ds.WithRecursive("test-cte", from).GetClauses()) - sds.Equal(dsc, ds.GetClauses()) -} - -func (sds *selectDatasetSuite) TestSelect() { - ds := From("test") - dsc := ds.GetClauses() - ec := dsc.SetSelect(exp.NewColumnListExpression(C("a"))) - sds.Equal(ec, ds.Select(C("a")).GetClauses()) - sds.Equal(dsc, ds.GetClauses()) -} - -func (sds *selectDatasetSuite) TestSelect_ToSQL() { - ds1 := From("test") - - selectSQL, _, err := ds1.ToSQL() - sds.NoError(err) - sds.Equal(`SELECT * FROM "test"`, selectSQL) - - selectSQL, _, err = ds1.Select().ToSQL() - sds.NoError(err) - sds.Equal(`SELECT * FROM "test"`, selectSQL) - - selectSQL, _, err = ds1.Select("id").ToSQL() - sds.NoError(err) - sds.Equal(`SELECT "id" FROM "test"`, selectSQL) - - selectSQL, _, err = ds1.Select("id", "name").ToSQL() - sds.NoError(err) - sds.Equal(`SELECT "id", "name" FROM "test"`, selectSQL) - - selectSQL, _, err = ds1.Select(L("COUNT(*)").As("count")).ToSQL() - sds.NoError(err) - sds.Equal(`SELECT COUNT(*) AS "count" FROM "test"`, selectSQL) - - selectSQL, _, err = ds1.Select(C("id").As("other_id"), L("COUNT(*)").As("count")).ToSQL() - sds.NoError(err) - sds.Equal(`SELECT "id" AS "other_id", COUNT(*) AS "count" FROM "test"`, selectSQL) - - selectSQL, _, err = ds1.From().Select(ds1.From("test_1").Select("id")).ToSQL() - sds.NoError(err) - sds.Equal(`SELECT (SELECT "id" FROM "test_1")`, selectSQL) - - selectSQL, _, err = ds1.From().Select(ds1.From("test_1").Select("id").As("test_id")).ToSQL() - sds.NoError(err) - sds.Equal(`SELECT (SELECT "id" FROM "test_1") AS "test_id"`, selectSQL) - - selectSQL, _, err = ds1.From(). - Select( - DISTINCT("a").As("distinct"), - COUNT("a").As("count"), - L("CASE WHEN ? THEN ? ELSE ? END", MIN("a").Eq(10), true, false), - L("CASE WHEN ? THEN ? ELSE ? END", AVG("a").Neq(10), true, false), - L("CASE WHEN ? THEN ? ELSE ? END", FIRST("a").Gt(10), true, false), - L("CASE WHEN ? THEN ? ELSE ? END", FIRST("a").Gte(10), true, false), - L("CASE WHEN ? THEN ? ELSE ? END", LAST("a").Lt(10), true, false), - L("CASE WHEN ? THEN ? ELSE ? END", LAST("a").Lte(10), true, false), - SUM("a").As("sum"), - COALESCE(C("a"), "a").As("colaseced"), - ).ToSQL() - sds.NoError(err) - sds.Equal( - `SELECT `+ - `DISTINCT("a") AS "distinct", `+ - `COUNT("a") AS "count", `+ - `CASE WHEN (MIN("a") = 10) THEN TRUE ELSE FALSE END, `+ - `CASE WHEN (AVG("a") != 10) THEN TRUE ELSE FALSE END, `+ - `CASE WHEN (FIRST("a") > 10) THEN TRUE ELSE FALSE END, `+ - `CASE WHEN (FIRST("a") >= 10) THEN TRUE ELSE FALSE END,`+ - ` CASE WHEN (LAST("a") < 10) THEN TRUE ELSE FALSE END, `+ - `CASE WHEN (LAST("a") <= 10) THEN TRUE ELSE FALSE END, `+ - `SUM("a") AS "sum", `+ - `COALESCE("a", 'a') AS "colaseced"`, - selectSQL, - ) - - type MyStruct struct { - Name string - Address string `db:"address"` - EmailAddress string `db:"email_address"` - FakeCol string `db:"-"` - } - selectSQL, _, err = ds1.Select(&MyStruct{}).ToSQL() - sds.NoError(err) - sds.Equal(`SELECT "address", "email_address", "name" FROM "test"`, selectSQL) - - selectSQL, _, err = ds1.Select(MyStruct{}).ToSQL() - sds.NoError(err) - sds.Equal(`SELECT "address", "email_address", "name" FROM "test"`, selectSQL) - - type myStruct2 struct { - MyStruct - Zipcode string `db:"zipcode"` +type ( + selectTestCase struct { + ds *SelectDataset + clauses exp.SelectClauses } - - selectSQL, _, err = ds1.Select(&myStruct2{}).ToSQL() - sds.NoError(err) - sds.Equal(`SELECT "address", "email_address", "name", "zipcode" FROM "test"`, selectSQL) - - selectSQL, _, err = ds1.Select(myStruct2{}).ToSQL() - sds.NoError(err) - sds.Equal(`SELECT "address", "email_address", "name", "zipcode" FROM "test"`, selectSQL) - - var myStructs []MyStruct - selectSQL, _, err = ds1.Select(&myStructs).ToSQL() - sds.NoError(err) - sds.Equal(`SELECT "address", "email_address", "name" FROM "test"`, selectSQL) - - selectSQL, _, err = ds1.Select(myStructs).ToSQL() - sds.NoError(err) - sds.Equal(`SELECT "address", "email_address", "name" FROM "test"`, selectSQL) - // should not change original - selectSQL, _, err = ds1.ToSQL() - sds.NoError(err) - sds.Equal(`SELECT * FROM "test"`, selectSQL) -} - -func (sds *selectDatasetSuite) TestSelectDistinct() { - ds := From("test") - dsc := ds.GetClauses() - ec := dsc.SetSelect(exp.NewColumnListExpression(C("a"))).SetDistinct(exp.NewColumnListExpression()) - sds.Equal(ec, ds.SelectDistinct(C("a")).GetClauses()) - sds.Equal(dsc, ds.GetClauses()) -} - -func (sds *selectDatasetSuite) TestSelectDistinct_ToSQL() { - ds1 := From("test") - - selectSQL, _, err := ds1.ToSQL() - sds.NoError(err) - sds.Equal(`SELECT * FROM "test"`, selectSQL) - - selectSQL, _, err = ds1.SelectDistinct("id").ToSQL() - sds.NoError(err) - sds.Equal(`SELECT DISTINCT "id" FROM "test"`, selectSQL) - - selectSQL, _, err = ds1.SelectDistinct("id", "name").ToSQL() - sds.NoError(err) - sds.Equal(`SELECT DISTINCT "id", "name" FROM "test"`, selectSQL) - - selectSQL, _, err = ds1.SelectDistinct(L("COUNT(*)").As("count")).ToSQL() - sds.NoError(err) - sds.Equal(`SELECT DISTINCT COUNT(*) AS "count" FROM "test"`, selectSQL) - - selectSQL, _, err = ds1.SelectDistinct(C("id").As("other_id"), L("COUNT(*)").As("count")).ToSQL() - sds.NoError(err) - sds.Equal(`SELECT DISTINCT "id" AS "other_id", COUNT(*) AS "count" FROM "test"`, selectSQL) - - type MyStruct struct { - Name string - Address string `db:"address"` - EmailAddress string `db:"email_address"` - FakeCol string `db:"-"` + dsTestActionItem struct { + Address string `db:"address"` + Name string `db:"name"` } - selectSQL, _, err = ds1.SelectDistinct(&MyStruct{}).ToSQL() - sds.NoError(err) - sds.Equal(`SELECT DISTINCT "address", "email_address", "name" FROM "test"`, selectSQL) - - selectSQL, _, err = ds1.SelectDistinct(MyStruct{}).ToSQL() - sds.NoError(err) - sds.Equal(`SELECT DISTINCT "address", "email_address", "name" FROM "test"`, selectSQL) - - type myStruct2 struct { - MyStruct - Zipcode string `db:"zipcode"` + selectDatasetSuite struct { + suite.Suite } +) - selectSQL, _, err = ds1.SelectDistinct(&myStruct2{}).ToSQL() - sds.NoError(err) - sds.Equal(`SELECT DISTINCT "address", "email_address", "name", "zipcode" FROM "test"`, selectSQL) - - selectSQL, _, err = ds1.SelectDistinct(myStruct2{}).ToSQL() - sds.NoError(err) - sds.Equal(`SELECT DISTINCT "address", "email_address", "name", "zipcode" FROM "test"`, selectSQL) - - var myStructs []MyStruct - selectSQL, _, err = ds1.SelectDistinct(&myStructs).ToSQL() - sds.NoError(err) - sds.Equal(`SELECT DISTINCT "address", "email_address", "name" FROM "test"`, selectSQL) - - selectSQL, _, err = ds1.SelectDistinct(myStructs).ToSQL() - sds.NoError(err) - sds.Equal(`SELECT DISTINCT "address", "email_address", "name" FROM "test"`, selectSQL) - // should not change original - selectSQL, _, err = ds1.ToSQL() - sds.NoError(err) - sds.Equal(`SELECT * FROM "test"`, selectSQL) - // should not change original - selectSQL, _, err = ds1.ToSQL() - sds.NoError(err) - sds.Equal(`SELECT * FROM "test"`, selectSQL) -} - -func (sds *selectDatasetSuite) TestDistinct() { - ds := From("test") - dsc := ds.GetClauses() - ec := dsc.SetDistinct(exp.NewColumnListExpression()) - ecs := dsc.SetSelect(exp.NewColumnListExpression("a", "b")).SetDistinct(exp.NewColumnListExpression()) - ecsd := dsc.SetSelect(exp.NewColumnListExpression("a", "b")).SetDistinct(exp.NewColumnListExpression("c")) - sds.Equal(ec, ds.Distinct().GetClauses()) - sds.Equal(ecs, ds.Select("a", "b").Distinct().GetClauses()) - sds.Equal(ecsd, ds.Select("a", "b").Distinct("c").GetClauses()) - sds.Equal(dsc, ds.GetClauses()) -} - -func (sds *selectDatasetSuite) TestDistinct_ToSQL() { - ds1 := From("test") - - selectSQL, _, err := ds1.Distinct().ToSQL() - sds.NoError(err) - sds.Equal(`SELECT DISTINCT * FROM "test"`, selectSQL) - - selectSQL, _, err = ds1.Distinct("id").ToSQL() - sds.NoError(err) - sds.Equal(`SELECT DISTINCT ON ("id") * FROM "test"`, selectSQL) - - selectSQL, _, err = ds1.Distinct("id").Select("name").ToSQL() - sds.NoError(err) - sds.Equal(`SELECT DISTINCT ON ("id") "name" FROM "test"`, selectSQL) - - selectSQL, _, err = ds1.Select(L("COUNT(*)").As("count")).Distinct(COALESCE(C("b"), "empty")).ToSQL() - sds.NoError(err) - sds.Equal(`SELECT DISTINCT ON (COALESCE("b", 'empty')) COUNT(*) AS "count" FROM "test"`, selectSQL) - - // should not change original - selectSQL, _, err = ds1.ToSQL() - sds.NoError(err) - sds.Equal(`SELECT * FROM "test"`, selectSQL) - // should not change original - selectSQL, _, err = ds1.ToSQL() - sds.NoError(err) - sds.Equal(`SELECT * FROM "test"`, selectSQL) -} - -func (sds *selectDatasetSuite) TestClearSelect() { - ds := From("test").Select(C("a")) - dsc := ds.GetClauses() - ec := dsc.SetSelect(exp.NewColumnListExpression(Star())) - sds.Equal(ec, ds.ClearSelect().GetClauses()) - sds.Equal(dsc, ds.GetClauses()) -} - -func (sds *selectDatasetSuite) TestClearSelect_ToSQL() { - ds1 := From("test") - - selectSQL, _, err := ds1.ToSQL() - sds.NoError(err) - sds.Equal(`SELECT * FROM "test"`, selectSQL) - - b := ds1.Select("a").ClearSelect() - selectSQL, _, err = b.ToSQL() - sds.NoError(err) - sds.Equal(`SELECT * FROM "test"`, selectSQL) -} - -func (sds *selectDatasetSuite) TestSelectAppend(selects ...interface{}) { - ds := From("test").Select(C("a")) - dsc := ds.GetClauses() - ec := dsc.SelectAppend(exp.NewColumnListExpression(C("b"))) - sds.Equal(ec, ds.SelectAppend(C("b")).GetClauses()) - sds.Equal(dsc, ds.GetClauses()) -} - -func (sds *selectDatasetSuite) TestSelectAppend_ToSQL() { - ds1 := From("test") - - selectSQL, _, err := ds1.ToSQL() - sds.NoError(err) - sds.Equal(`SELECT * FROM "test"`, selectSQL) - - b := ds1.Select("a").SelectAppend("b").SelectAppend("c") - selectSQL, _, err = b.ToSQL() - sds.NoError(err) - sds.Equal(`SELECT "a", "b", "c" FROM "test"`, selectSQL) -} - -func (sds *selectDatasetSuite) TestFrom() { - ds := From("test") - dsc := ds.GetClauses() - ec := dsc.SetFrom(exp.NewColumnListExpression(T("t"))) - sds.Equal(ec, ds.From(T("t")).GetClauses()) - sds.Equal(dsc, ds.GetClauses()) -} - -func (sds *selectDatasetSuite) TestFrom_ToSQL() { - ds1 := From("test") - - selectSQL, _, err := ds1.ToSQL() - sds.NoError(err) - sds.Equal(`SELECT * FROM "test"`, selectSQL) - - ds2 := ds1.From("test2") - selectSQL, _, err = ds2.ToSQL() - sds.NoError(err) - sds.Equal(`SELECT * FROM "test2"`, selectSQL) - - ds2 = ds1.From("test2", "test3") - selectSQL, _, err = ds2.ToSQL() - sds.NoError(err) - sds.Equal(`SELECT * FROM "test2", "test3"`, selectSQL) - - ds2 = ds1.From(T("test2").As("test_2"), "test3") - selectSQL, _, err = ds2.ToSQL() - sds.NoError(err) - sds.Equal(`SELECT * FROM "test2" AS "test_2", "test3"`, selectSQL) - - ds2 = ds1.From(ds1.From("test2"), "test3") - selectSQL, _, err = ds2.ToSQL() - sds.NoError(err) - sds.Equal(`SELECT * FROM (SELECT * FROM "test2") AS "t1", "test3"`, selectSQL) - - ds2 = ds1.From(ds1.From("test2").As("test_2"), "test3") - selectSQL, _, err = ds2.ToSQL() - sds.NoError(err) - sds.Equal(`SELECT * FROM (SELECT * FROM "test2") AS "test_2", "test3"`, selectSQL) - // should not change original - selectSQL, _, err = ds1.ToSQL() - sds.NoError(err) - sds.Equal(`SELECT * FROM "test"`, selectSQL) -} - -func (sds *selectDatasetSuite) TestFromSelf() { - ds := From("test") - dsc := ds.GetClauses() - ec := dsc.SetFrom(exp.NewColumnListExpression(ds.As("t1"))) - sds.Equal(ec, ds.FromSelf().GetClauses()) - - ec2 := dsc.SetFrom(exp.NewColumnListExpression(ds.As("test"))) - sds.Equal(ec2, ds.As("test").FromSelf().GetClauses()) - sds.Equal(dsc, ds.GetClauses()) -} - -func (sds *selectDatasetSuite) TestCompoundFromSelf() { - ds := From("test") - dsc := ds.GetClauses() - sds.Equal(dsc, ds.CompoundFromSelf().GetClauses()) - - ds2 := ds.Limit(1) - dsc2 := exp.NewSelectClauses().SetFrom(exp.NewColumnListExpression(ds2.As("t1"))) - sds.Equal(dsc2, ds2.CompoundFromSelf().GetClauses()) - sds.Equal(dsc, ds.GetClauses()) -} - -func (sds *selectDatasetSuite) TestJoin() { - ds := From("test") - dsc := ds.GetClauses() - ec := dsc.JoinsAppend( - exp.NewConditionedJoinExpression(exp.InnerJoinType, T("foo"), On(C("a").IsNull())), - ) - sds.Equal(ec, ds.Join(T("foo"), On(C("a").IsNull())).GetClauses()) - sds.Equal(dsc, ds.GetClauses()) -} - -func (sds *selectDatasetSuite) TestJoin_ToSQL() { - ds1 := From("items") - - b := ds1.Join(T("players").As("p"), On(Ex{"p.id": I("items.playerId")})) - selectSQL, args, err := b.ToSQL() - sds.NoError(err) - sds.Empty(args) - sds.Equal( - `SELECT * FROM "items" INNER JOIN "players" AS "p" ON ("p"."id" = "items"."playerId")`, - selectSQL, - ) - - selectSQL, args, err = b.Prepared(true).ToSQL() - sds.NoError(err) - sds.Empty(args) - sds.Equal( - `SELECT * FROM "items" INNER JOIN "players" AS "p" ON ("p"."id" = "items"."playerId")`, - selectSQL, - ) - - b = ds1.Join(ds1.From("players").As("p"), On(Ex{"p.id": I("items.playerId")})) - selectSQL, args, err = b.ToSQL() - sds.NoError(err) - sds.Empty(args) - sds.Equal( - `SELECT * FROM "items" INNER JOIN (SELECT * FROM "players") AS "p" ON ("p"."id" = "items"."playerId")`, - selectSQL, - ) - - selectSQL, args, err = b.Prepared(true).ToSQL() - sds.NoError(err) - sds.Empty(args) - sds.Equal( - `SELECT * FROM "items" INNER JOIN (SELECT * FROM "players") AS "p" ON ("p"."id" = "items"."playerId")`, - selectSQL, - ) - - b = ds1.Join(S("v1").Table("test"), On(Ex{"v1.test.id": I("items.playerId")})) - selectSQL, args, err = b.ToSQL() - sds.NoError(err) - sds.Empty(args) - sds.Equal( - `SELECT * FROM "items" INNER JOIN "v1"."test" ON ("v1"."test"."id" = "items"."playerId")`, - selectSQL, - ) - - selectSQL, args, err = b.Prepared(true).ToSQL() - sds.NoError(err) - sds.Empty(args) - sds.Equal( - `SELECT * FROM "items" INNER JOIN "v1"."test" ON ("v1"."test"."id" = "items"."playerId")`, - selectSQL, - ) - - b = ds1.Join(T("test"), Using(C("name"), C("common_id"))) - selectSQL, args, err = b.ToSQL() - sds.NoError(err) - sds.Empty(args) - sds.Equal(`SELECT * FROM "items" INNER JOIN "test" USING ("name", "common_id")`, selectSQL) - - selectSQL, args, err = b.Prepared(true).ToSQL() - sds.NoError(err) - sds.Empty(args) - sds.Equal(`SELECT * FROM "items" INNER JOIN "test" USING ("name", "common_id")`, selectSQL) - - b = ds1.Join(T("test"), Using("name", "common_id")) - selectSQL, args, err = b.ToSQL() - sds.NoError(err) - sds.Empty(args) - sds.Equal(`SELECT * FROM "items" INNER JOIN "test" USING ("name", "common_id")`, selectSQL) - - selectSQL, args, err = b.Prepared(true).ToSQL() - sds.NoError(err) - sds.Empty(args) - sds.Equal(`SELECT * FROM "items" INNER JOIN "test" USING ("name", "common_id")`, selectSQL) - - b = ds1.Join( - T("categories"), - On( - I("categories.categoryId").Eq(I("items.id")), - I("categories.categoryId").In(1, 2, 3), - ), - ) - - selectSQL, args, err = b.ToSQL() - sds.NoError(err) - sds.Empty(args) - sds.Equal( - `SELECT * FROM "items" `+ - `INNER JOIN "categories" ON (`+ - `("categories"."categoryId" = "items"."id") AND ("categories"."categoryId" IN (1, 2, 3))`+ - `)`, - selectSQL, - ) - - selectSQL, args, err = b.Prepared(true).ToSQL() - sds.NoError(err) - sds.Equal([]interface{}{int64(1), int64(2), int64(3)}, args) - sds.Equal( - `SELECT * FROM "items" `+ - `INNER JOIN "categories" ON (`+ - `("categories"."categoryId" = "items"."id") AND ("categories"."categoryId" IN (?, ?, ?))`+ - `)`, - selectSQL, - ) -} - -func (sds *selectDatasetSuite) TestInnerJoin() { - ds := From("test") - dsc := ds.GetClauses() - ec := dsc.JoinsAppend( - exp.NewConditionedJoinExpression(exp.InnerJoinType, T("foo"), On(C("a").IsNull())), - ) - sds.Equal(ec, ds.InnerJoin(T("foo"), On(C("a").IsNull())).GetClauses()) - sds.Equal(dsc, ds.GetClauses()) -} - -func (sds *selectDatasetSuite) TestInnerJoin_ToSQL() { - ds1 := From("items") - selectSQL, _, err := ds1. - InnerJoin(T("b"), On(Ex{"b.itemsId": I("items.id")})). - LeftOuterJoin(T("c"), On(Ex{"c.b_id": I("b.id")})). - ToSQL() - sds.NoError(err) - sds.Equal( - `SELECT * FROM "items" `+ - `INNER JOIN "b" ON ("b"."itemsId" = "items"."id") `+ - `LEFT OUTER JOIN "c" ON ("c"."b_id" = "b"."id")`, - selectSQL, - ) - - selectSQL, _, err = ds1. - InnerJoin(T("b"), On(Ex{"b.itemsId": I("items.id")})). - LeftOuterJoin(T("c"), On(Ex{"c.b_id": I("b.id")})). - ToSQL() - sds.NoError(err) - sds.Equal( - `SELECT * FROM "items" `+ - `INNER JOIN "b" ON ("b"."itemsId" = "items"."id") `+ - `LEFT OUTER JOIN "c" ON ("c"."b_id" = "b"."id")`, - selectSQL, - ) - - selectSQL, _, err = ds1.InnerJoin( - T("categories"), - On(Ex{"categories.categoryId": I("items.id")}), - ).ToSQL() - sds.NoError(err) - sds.Equal( - `SELECT * FROM "items" INNER JOIN "categories" ON ("categories"."categoryId" = "items"."id")`, - selectSQL, - ) -} - -func (sds *selectDatasetSuite) TestFullOuterJoin() { - ds := From("test") - dsc := ds.GetClauses() - ec := dsc.JoinsAppend( - exp.NewConditionedJoinExpression(exp.FullOuterJoinType, T("foo"), On(C("a").IsNull())), - ) - sds.Equal(ec, ds.FullOuterJoin(T("foo"), On(C("a").IsNull())).GetClauses()) - sds.Equal(dsc, ds.GetClauses()) -} - -func (sds *selectDatasetSuite) TestFullOuterJoin_ToSQL() { - ds1 := From("items") - selectSQL, _, err := ds1. - FullOuterJoin(T("categories"), On(Ex{"categories.categoryId": I("items.id")})). - Order(C("stamp").Asc()).ToSQL() - sds.NoError(err) - sds.Equal( - `SELECT * FROM "items" `+ - `FULL OUTER JOIN "categories" ON ("categories"."categoryId" = "items"."id") ORDER BY "stamp" ASC`, - selectSQL, - ) - - selectSQL, _, err = ds1.FullOuterJoin( - T("categories"), - On(Ex{"categories.categoryId": I("items.id")}), - ).ToSQL() - sds.NoError(err) - sds.Equal( - `SELECT * FROM "items" FULL OUTER JOIN "categories" ON ("categories"."categoryId" = "items"."id")`, - selectSQL, - ) -} - -func (sds *selectDatasetSuite) TestRightOuterJoin() { - ds := From("test") - dsc := ds.GetClauses() - ec := dsc.JoinsAppend( - exp.NewConditionedJoinExpression(exp.RightOuterJoinType, T("foo"), On(C("a").IsNull())), - ) - sds.Equal(ec, ds.RightOuterJoin(T("foo"), On(C("a").IsNull())).GetClauses()) - sds.Equal(dsc, ds.GetClauses()) -} - -func (sds *selectDatasetSuite) TestRightOuterJoin_ToSQL() { - ds1 := From("items") - selectSQL, _, err := ds1.RightOuterJoin( - T("categories"), - On(Ex{"categories.categoryId": I("items.id")}), - ).ToSQL() - sds.NoError(err) - sds.Equal( - `SELECT * FROM "items" RIGHT OUTER JOIN "categories" ON ("categories"."categoryId" = "items"."id")`, - selectSQL, - ) -} - -func (sds *selectDatasetSuite) TestLeftOuterJoin() { - ds := From("test") - dsc := ds.GetClauses() - ec := dsc.JoinsAppend( - exp.NewConditionedJoinExpression(exp.LeftOuterJoinType, T("foo"), On(C("a").IsNull())), - ) - sds.Equal(ec, ds.LeftOuterJoin(T("foo"), On(C("a").IsNull())).GetClauses()) - sds.Equal(dsc, ds.GetClauses()) -} - -func (sds *selectDatasetSuite) TestLeftOuterJoin_ToSQL() { - ds1 := From("items") - - selectSQL, _, err := ds1.LeftOuterJoin(T("categories"), On(Ex{ - "categories.categoryId": I("items.id"), - })).ToSQL() - sds.NoError(err) - sds.Equal( - `SELECT * FROM "items" LEFT OUTER JOIN "categories" ON ("categories"."categoryId" = "items"."id")`, - selectSQL, - ) - - selectSQL, _, err = ds1. - LeftOuterJoin( - T("categories"), - On( - I("categories.categoryId").Eq(I("items.id")), - I("categories.categoryId").In(1, 2, 3)), - ).ToSQL() - sds.NoError(err) - sds.Equal( - `SELECT * FROM "items" `+ - `LEFT OUTER JOIN "categories" `+ - `ON (("categories"."categoryId" = "items"."id") AND ("categories"."categoryId" IN (1, 2, 3)))`, - selectSQL, - ) - -} - -func (sds *selectDatasetSuite) TestFullJoin() { - ds := From("test") - dsc := ds.GetClauses() - ec := dsc.JoinsAppend( - exp.NewConditionedJoinExpression(exp.FullJoinType, T("foo"), On(C("a").IsNull())), - ) - sds.Equal(ec, ds.FullJoin(T("foo"), On(C("a").IsNull())).GetClauses()) - sds.Equal(dsc, ds.GetClauses()) -} - -func (sds *selectDatasetSuite) TestFullJoin_ToSQL() { - ds1 := From("items") - selectSQL, _, err := ds1.FullJoin( - T("categories"), - On(Ex{"categories.categoryId": I("items.id")}), - ).ToSQL() - sds.NoError(err) - sds.Equal( - `SELECT * FROM "items" FULL JOIN "categories" ON ("categories"."categoryId" = "items"."id")`, - selectSQL, - ) -} - -func (sds *selectDatasetSuite) TestRightJoin() { - ds := From("test") - dsc := ds.GetClauses() - ec := dsc.JoinsAppend( - exp.NewConditionedJoinExpression(exp.RightJoinType, T("foo"), On(C("a").IsNull())), - ) - sds.Equal(ec, ds.RightJoin(T("foo"), On(C("a").IsNull())).GetClauses()) - sds.Equal(dsc, ds.GetClauses()) -} - -func (sds *selectDatasetSuite) TestRightJoin_ToSQL() { - ds1 := From("items") - selectSQL, _, err := ds1.RightJoin( - T("categories"), - On(Ex{"categories.categoryId": I("items.id")}), - ).ToSQL() - sds.NoError(err) - sds.Equal( - `SELECT * FROM "items" RIGHT JOIN "categories" ON ("categories"."categoryId" = "items"."id")`, - selectSQL, - ) -} - -func (sds *selectDatasetSuite) TestLeftJoin() { - ds := From("test") - dsc := ds.GetClauses() - ec := dsc.JoinsAppend( - exp.NewConditionedJoinExpression(exp.LeftJoinType, T("foo"), On(C("a").IsNull())), - ) - sds.Equal(ec, ds.LeftJoin(T("foo"), On(C("a").IsNull())).GetClauses()) - sds.Equal(dsc, ds.GetClauses()) -} - -func (sds *selectDatasetSuite) TestLeftJoin_ToSQL() { - ds1 := From("items") - selectSQL, _, err := ds1.LeftJoin( - T("categories"), - On(Ex{"categories.categoryId": I("items.id")}), - ).ToSQL() - sds.NoError(err) - sds.Equal( - `SELECT * FROM "items" LEFT JOIN "categories" ON ("categories"."categoryId" = "items"."id")`, - selectSQL, - ) -} - -func (sds *selectDatasetSuite) TestNaturalJoin() { - ds := From("test") - dsc := ds.GetClauses() - ec := dsc.JoinsAppend( - exp.NewUnConditionedJoinExpression(exp.NaturalJoinType, T("foo")), - ) - sds.Equal(ec, ds.NaturalJoin(T("foo")).GetClauses()) - sds.Equal(dsc, ds.GetClauses()) -} - -func (sds *selectDatasetSuite) TestNaturalJoin_ToSQL() { - ds1 := From("items") - selectSQL, _, err := ds1.NaturalJoin(T("categories")).ToSQL() - sds.NoError(err) - sds.Equal(`SELECT * FROM "items" NATURAL JOIN "categories"`, selectSQL) -} - -func (sds *selectDatasetSuite) TestNaturalLeftJoin() { - ds := From("test") - dsc := ds.GetClauses() - ec := dsc.JoinsAppend( - exp.NewUnConditionedJoinExpression(exp.NaturalLeftJoinType, T("foo")), - ) - sds.Equal(ec, ds.NaturalLeftJoin(T("foo")).GetClauses()) - sds.Equal(dsc, ds.GetClauses()) -} - -func (sds *selectDatasetSuite) TestNaturalLeftJoin_ToSQL() { - ds1 := From("items") - selectSQL, _, err := ds1.NaturalLeftJoin(T("categories")).ToSQL() - sds.NoError(err) - sds.Equal(`SELECT * FROM "items" NATURAL LEFT JOIN "categories"`, selectSQL) - -} - -func (sds *selectDatasetSuite) TestNaturalRightJoin_ToSQL() { - ds1 := From("items") - selectSQL, _, err := ds1.NaturalRightJoin(T("categories")).ToSQL() - sds.NoError(err) - sds.Equal(`SELECT * FROM "items" NATURAL RIGHT JOIN "categories"`, selectSQL) -} - -func (sds *selectDatasetSuite) TestNaturalRightJoin() { - ds := From("test") - dsc := ds.GetClauses() - ec := dsc.JoinsAppend( - exp.NewUnConditionedJoinExpression(exp.NaturalRightJoinType, T("foo")), - ) - sds.Equal(ec, ds.NaturalRightJoin(T("foo")).GetClauses()) - sds.Equal(dsc, ds.GetClauses()) -} -func (sds *selectDatasetSuite) TestNaturalFullJoin() { - ds := From("test") - dsc := ds.GetClauses() - ec := dsc.JoinsAppend( - exp.NewUnConditionedJoinExpression(exp.NaturalFullJoinType, T("foo")), - ) - sds.Equal(ec, ds.NaturalFullJoin(T("foo")).GetClauses()) - sds.Equal(dsc, ds.GetClauses()) -} - -func (sds *selectDatasetSuite) TestNaturalFullJoin_ToSQL() { - ds1 := From("items") - selectSQL, _, err := ds1.NaturalFullJoin(T("categories")).ToSQL() - sds.NoError(err) - sds.Equal(`SELECT * FROM "items" NATURAL FULL JOIN "categories"`, selectSQL) -} - -func (sds *selectDatasetSuite) TestCrossJoin() { - ds := From("test") - dsc := ds.GetClauses() - ec := dsc.JoinsAppend( - exp.NewUnConditionedJoinExpression(exp.CrossJoinType, T("foo")), - ) - sds.Equal(ec, ds.CrossJoin(T("foo")).GetClauses()) - sds.Equal(dsc, ds.GetClauses()) -} - -func (sds *selectDatasetSuite) TestCrossJoin_ToSQL() { - selectSQL, _, err := From("items").CrossJoin(T("categories")).ToSQL() - sds.NoError(err) - sds.Equal(`SELECT * FROM "items" CROSS JOIN "categories"`, selectSQL) -} - -func (sds *selectDatasetSuite) TestWhere() { - ds := From("test") - dsc := ds.GetClauses() - w := Ex{ - "a": 1, - } - ec := dsc.WhereAppend(w) - sds.Equal(ec, ds.Where(w).GetClauses()) - sds.Equal(dsc, ds.GetClauses()) -} - -func (sds *selectDatasetSuite) TestWhere_ToSQL() { - ds1 := From("test") - - b := ds1.Where( - C("a").Eq(true), - C("a").Neq(true), - C("a").Eq(false), - C("a").Neq(false), - ) - selectSQL, args, err := b.ToSQL() - sds.NoError(err) - sds.Empty(args) - sds.Equal( - `SELECT * FROM "test" WHERE (("a" IS TRUE) AND ("a" IS NOT TRUE) AND ("a" IS FALSE) AND ("a" IS NOT FALSE))`, - selectSQL, - ) - - selectSQL, args, err = b.Prepared(true).ToSQL() - sds.NoError(err) - sds.Empty(args) - sds.Equal( - `SELECT * FROM "test" WHERE (("a" IS TRUE) AND ("a" IS NOT TRUE) AND ("a" IS FALSE) AND ("a" IS NOT FALSE))`, - selectSQL, - ) - - b = ds1.Where( - C("a").Eq("a"), - C("b").Neq("b"), - C("c").Gt("c"), - C("d").Gte("d"), - C("e").Lt("e"), - C("f").Lte("f"), - ) - selectSQL, args, err = b.ToSQL() - sds.NoError(err) - sds.Empty(args) - sds.Equal( - `SELECT * FROM "test" `+ - `WHERE (("a" = 'a') AND ("b" != 'b') AND ("c" > 'c') AND ("d" >= 'd') AND ("e" < 'e') AND ("f" <= 'f'))`, - selectSQL, - ) - - selectSQL, args, err = b.Prepared(true).ToSQL() - sds.NoError(err) - sds.Equal([]interface{}{"a", "b", "c", "d", "e", "f"}, args) - sds.Equal( - `SELECT * FROM "test" `+ - `WHERE (("a" = ?) AND ("b" != ?) AND ("c" > ?) AND ("d" >= ?) AND ("e" < ?) AND ("f" <= ?))`, - selectSQL, - ) - - b = ds1.Where( - C("a").Eq(From("test2").Select("id")), - ) - selectSQL, args, err = b.ToSQL() - sds.NoError(err) - sds.Empty(args) - sds.Equal(`SELECT * FROM "test" WHERE ("a" IN (SELECT "id" FROM "test2"))`, selectSQL) - - selectSQL, args, err = b.Prepared(true).ToSQL() - sds.NoError(err) - sds.Empty(args) - sds.Equal(`SELECT * FROM "test" WHERE ("a" IN (SELECT "id" FROM "test2"))`, selectSQL) - - b = ds1.Where(Ex{ - "a": "a", - "b": Op{"neq": "b"}, - "c": Op{"gt": "c"}, - "d": Op{"gte": "d"}, - "e": Op{"lt": "e"}, - "f": Op{"lte": "f"}, - }) - selectSQL, args, err = b.ToSQL() - sds.NoError(err) - sds.Empty(args) - sds.Equal( - `SELECT * FROM "test" `+ - `WHERE (("a" = 'a') AND ("b" != 'b') AND ("c" > 'c') AND ("d" >= 'd') AND ("e" < 'e') AND ("f" <= 'f'))`, - selectSQL, - ) - - selectSQL, args, err = b.Prepared(true).ToSQL() - sds.NoError(err) - sds.Equal([]interface{}{"a", "b", "c", "d", "e", "f"}, args) - sds.Equal( - `SELECT * FROM "test" `+ - `WHERE (("a" = ?) AND ("b" != ?) AND ("c" > ?) AND ("d" >= ?) AND ("e" < ?) AND ("f" <= ?))`, - selectSQL, - ) - - b = ds1.Where(Ex{ - "a": From("test2").Select("id"), - }) - selectSQL, args, err = b.ToSQL() - sds.NoError(err) - sds.Empty(args) - sds.Equal(`SELECT * FROM "test" WHERE ("a" IN (SELECT "id" FROM "test2"))`, selectSQL) - - selectSQL, args, err = b.Prepared(true).ToSQL() - sds.NoError(err) - sds.Empty(args) - sds.Equal(`SELECT * FROM "test" WHERE ("a" IN (SELECT "id" FROM "test2"))`, selectSQL) -} - -func (sds *selectDatasetSuite) TestWhere_ToSQLEmpty() { - ds1 := From("test") - - b := ds1.Where() - selectSQL, _, err := b.ToSQL() - sds.NoError(err) - sds.Equal(`SELECT * FROM "test"`, selectSQL) -} - -func (sds *selectDatasetSuite) TestWhere_ToSQLWithChain() { - ds1 := From("test").Where( - C("x").Eq(0), - C("y").Eq(1), - ) - - ds2 := ds1.Where( - C("z").Eq(2), - ) - - a := ds2.Where( - C("a").Eq("A"), - ) - b := ds2.Where( - C("b").Eq("B"), - ) - selectSQL, _, err := a.ToSQL() - sds.NoError(err) - sds.Equal( - `SELECT * FROM "test" WHERE (("x" = 0) AND ("y" = 1) AND ("z" = 2) AND ("a" = 'A'))`, - selectSQL, - ) - selectSQL, _, err = b.ToSQL() - sds.NoError(err) - sds.Equal( - `SELECT * FROM "test" WHERE (("x" = 0) AND ("y" = 1) AND ("z" = 2) AND ("b" = 'B'))`, - selectSQL, - ) -} - -func (sds *selectDatasetSuite) TestClearWhere() { - w := Ex{ - "a": 1, - } - ds := From("test").Where(w) - dsc := ds.GetClauses() - ec := dsc.ClearWhere() - sds.Equal(ec, ds.ClearWhere().GetClauses()) - sds.Equal(dsc, ds.GetClauses()) -} - -func (sds *selectDatasetSuite) TestClearWhere_ToSQL() { - ds1 := From("test") - - b := ds1.Where( - C("a").Eq(1), - ).ClearWhere() - selectSQL, _, err := b.ToSQL() - sds.NoError(err) - sds.Equal(`SELECT * FROM "test"`, selectSQL) -} - -func (sds *selectDatasetSuite) TestForUpdate() { - ds := From("test") - dsc := ds.GetClauses() - ec := dsc.SetLock(exp.NewLock(exp.ForUpdate, NoWait)) - sds.Equal(ec, ds.ForUpdate(NoWait).GetClauses()) - sds.Equal(dsc, ds.GetClauses()) -} - -func (sds *selectDatasetSuite) TestForUpdate_ToSQL() { - ds1 := From("test") - - b := ds1.Where(C("a").Gt(1)).ForUpdate(Wait) - selectSQL, args, err := b.ToSQL() - sds.NoError(err) - sds.Empty(args) - sds.Equal(`SELECT * FROM "test" WHERE ("a" > 1) FOR UPDATE `, selectSQL) - - selectSQL, args, err = b.Prepared(true).ToSQL() - sds.NoError(err) - sds.Equal([]interface{}{int64(1)}, args) - sds.Equal(`SELECT * FROM "test" WHERE ("a" > ?) FOR UPDATE `, selectSQL) - - b = ds1.Where(C("a").Gt(1)).ForUpdate(NoWait) - selectSQL, args, err = b.ToSQL() - sds.NoError(err) - sds.Empty(args) - sds.Equal(`SELECT * FROM "test" WHERE ("a" > 1) FOR UPDATE NOWAIT`, selectSQL) - - selectSQL, args, err = b.Prepared(true).ToSQL() - sds.NoError(err) - sds.Equal([]interface{}{int64(1)}, args) - sds.Equal(`SELECT * FROM "test" WHERE ("a" > ?) FOR UPDATE NOWAIT`, selectSQL) - - b = ds1.Where(C("a").Gt(1)).ForUpdate(SkipLocked) - selectSQL, args, err = b.ToSQL() - sds.NoError(err) - sds.Empty(args) - sds.Equal(`SELECT * FROM "test" WHERE ("a" > 1) FOR UPDATE SKIP LOCKED`, selectSQL) - - selectSQL, args, err = b.Prepared(true).ToSQL() - sds.NoError(err) - sds.Equal([]interface{}{int64(1)}, args) - sds.Equal(`SELECT * FROM "test" WHERE ("a" > ?) FOR UPDATE SKIP LOCKED`, selectSQL) -} - -func (sds *selectDatasetSuite) TestForNoKeyUpdate() { - ds := From("test") - dsc := ds.GetClauses() - ec := dsc.SetLock(exp.NewLock(exp.ForNoKeyUpdate, NoWait)) - sds.Equal(ec, ds.ForNoKeyUpdate(NoWait).GetClauses()) - sds.Equal(dsc, ds.GetClauses()) -} - -func (sds *selectDatasetSuite) TestForNoKeyUpdate_ToSQL() { - ds1 := From("test") - - b := ds1.Where(C("a").Gt(1)).ForNoKeyUpdate(Wait) - selectSQL, args, err := b.ToSQL() - sds.NoError(err) - sds.Empty(args) - sds.Equal(`SELECT * FROM "test" WHERE ("a" > 1) FOR NO KEY UPDATE `, selectSQL) - - selectSQL, args, err = b.Prepared(true).ToSQL() - sds.NoError(err) - sds.Equal([]interface{}{int64(1)}, args) - sds.Equal(`SELECT * FROM "test" WHERE ("a" > ?) FOR NO KEY UPDATE `, selectSQL) - - b = ds1.Where(C("a").Gt(1)).ForNoKeyUpdate(NoWait) - selectSQL, args, err = b.ToSQL() - sds.NoError(err) - sds.Empty(args) - sds.Equal(`SELECT * FROM "test" WHERE ("a" > 1) FOR NO KEY UPDATE NOWAIT`, selectSQL) - - selectSQL, args, err = b.Prepared(true).ToSQL() - sds.NoError(err) - sds.Equal([]interface{}{int64(1)}, args) - sds.Equal(`SELECT * FROM "test" WHERE ("a" > ?) FOR NO KEY UPDATE NOWAIT`, selectSQL) - - b = ds1.Where(C("a").Gt(1)).ForNoKeyUpdate(SkipLocked) - selectSQL, args, err = b.ToSQL() - sds.NoError(err) - sds.Empty(args) - sds.Equal(`SELECT * FROM "test" WHERE ("a" > 1) FOR NO KEY UPDATE SKIP LOCKED`, selectSQL) - - selectSQL, args, err = b.Prepared(true).ToSQL() - sds.NoError(err) - sds.Equal([]interface{}{int64(1)}, args) - sds.Equal(`SELECT * FROM "test" WHERE ("a" > ?) FOR NO KEY UPDATE SKIP LOCKED`, selectSQL) -} - -func (sds *selectDatasetSuite) TestForKeyShare() { - ds := From("test") - dsc := ds.GetClauses() - ec := dsc.SetLock(exp.NewLock(exp.ForKeyShare, NoWait)) - sds.Equal(ec, ds.ForKeyShare(NoWait).GetClauses()) - sds.Equal(dsc, ds.GetClauses()) -} - -func (sds *selectDatasetSuite) TestForKeyShare_ToSQL() { - ds1 := From("test") - - b := ds1.Where(C("a").Gt(1)).ForKeyShare(Wait) - selectSQL, args, err := b.ToSQL() - sds.NoError(err) - sds.Empty(args) - sds.Equal(`SELECT * FROM "test" WHERE ("a" > 1) FOR KEY SHARE `, selectSQL) - - selectSQL, args, err = b.Prepared(true).ToSQL() - sds.NoError(err) - sds.Equal([]interface{}{int64(1)}, args) - sds.Equal(`SELECT * FROM "test" WHERE ("a" > ?) FOR KEY SHARE `, selectSQL) - - b = ds1.Where(C("a").Gt(1)).ForKeyShare(NoWait) - selectSQL, args, err = b.ToSQL() - sds.NoError(err) - sds.Empty(args) - sds.Equal(`SELECT * FROM "test" WHERE ("a" > 1) FOR KEY SHARE NOWAIT`, selectSQL) - - selectSQL, args, err = b.Prepared(true).ToSQL() - sds.NoError(err) - sds.Equal([]interface{}{int64(1)}, args) - sds.Equal(`SELECT * FROM "test" WHERE ("a" > ?) FOR KEY SHARE NOWAIT`, selectSQL) - - b = ds1.Where(C("a").Gt(1)).ForKeyShare(SkipLocked) - selectSQL, args, err = b.ToSQL() - sds.NoError(err) - sds.Empty(args) - sds.Equal(`SELECT * FROM "test" WHERE ("a" > 1) FOR KEY SHARE SKIP LOCKED`, selectSQL) - - selectSQL, args, err = b.Prepared(true).ToSQL() - sds.NoError(err) - sds.Equal([]interface{}{int64(1)}, args) - sds.Equal(`SELECT * FROM "test" WHERE ("a" > ?) FOR KEY SHARE SKIP LOCKED`, selectSQL) -} - -func (sds *selectDatasetSuite) TestForShare() { - ds := From("test") - dsc := ds.GetClauses() - ec := dsc.SetLock(exp.NewLock(exp.ForShare, NoWait)) - sds.Equal(ec, ds.ForShare(NoWait).GetClauses()) - sds.Equal(dsc, ds.GetClauses()) -} - -func (sds *selectDatasetSuite) TestForShare_ToSQL() { - ds1 := From("test") - - b := ds1.Where(C("a").Gt(1)).ForShare(Wait) - selectSQL, args, err := b.ToSQL() - sds.NoError(err) - sds.Empty(args) - sds.Equal(`SELECT * FROM "test" WHERE ("a" > 1) FOR SHARE `, selectSQL) - - selectSQL, args, err = b.Prepared(true).ToSQL() - sds.NoError(err) - sds.Equal([]interface{}{int64(1)}, args) - sds.Equal(`SELECT * FROM "test" WHERE ("a" > ?) FOR SHARE `, selectSQL) - - b = ds1.Where(C("a").Gt(1)).ForShare(NoWait) - selectSQL, args, err = b.ToSQL() - sds.NoError(err) - sds.Empty(args) - sds.Equal(`SELECT * FROM "test" WHERE ("a" > 1) FOR SHARE NOWAIT`, selectSQL) - - selectSQL, args, err = b.Prepared(true).ToSQL() - sds.NoError(err) - sds.Equal([]interface{}{int64(1)}, args) - sds.Equal(`SELECT * FROM "test" WHERE ("a" > ?) FOR SHARE NOWAIT`, selectSQL) - - b = ds1.Where(C("a").Gt(1)).ForShare(SkipLocked) - selectSQL, args, err = b.ToSQL() - sds.NoError(err) - sds.Empty(args) - sds.Equal(`SELECT * FROM "test" WHERE ("a" > 1) FOR SHARE SKIP LOCKED`, selectSQL) - - selectSQL, args, err = b.Prepared(true).ToSQL() - sds.NoError(err) - sds.Equal([]interface{}{int64(1)}, args) - sds.Equal(`SELECT * FROM "test" WHERE ("a" > ?) FOR SHARE SKIP LOCKED`, selectSQL) -} - -func (sds *selectDatasetSuite) TestGroupBy() { - ds := From("test") - dsc := ds.GetClauses() - ec := dsc.SetGroupBy(exp.NewColumnListExpression(C("a"))) - sds.Equal(ec, ds.GroupBy("a").GetClauses()) - sds.Equal(dsc, ds.GetClauses()) -} - -func (sds *selectDatasetSuite) TestGroupBy_ToSQL() { - ds1 := From("test") - - b := ds1.Where(C("a").Gt(1)).GroupBy("created") - selectSQL, args, err := b.ToSQL() - sds.NoError(err) - sds.Empty(args) - sds.Equal(`SELECT * FROM "test" WHERE ("a" > 1) GROUP BY "created"`, selectSQL) - - selectSQL, args, err = b.Prepared(true).ToSQL() - sds.NoError(err) - sds.Equal([]interface{}{int64(1)}, args) - sds.Equal(`SELECT * FROM "test" WHERE ("a" > ?) GROUP BY "created"`, selectSQL) - - b = ds1.Where(C("a").Gt(1)).GroupBy(L("created::DATE")) - selectSQL, args, err = b.ToSQL() - sds.NoError(err) - sds.Empty(args) - sds.Equal(`SELECT * FROM "test" WHERE ("a" > 1) GROUP BY created::DATE`, selectSQL) - - selectSQL, args, err = b.Prepared(true).ToSQL() - sds.NoError(err) - sds.Equal([]interface{}{int64(1)}, args) - sds.Equal(`SELECT * FROM "test" WHERE ("a" > ?) GROUP BY created::DATE`, selectSQL) - - b = ds1.Where(C("a").Gt(1)).GroupBy("name", L("created::DATE")) - selectSQL, args, err = b.ToSQL() - sds.NoError(err) - sds.Empty(args) - sds.Equal(`SELECT * FROM "test" WHERE ("a" > 1) GROUP BY "name", created::DATE`, selectSQL) - - selectSQL, args, err = b.Prepared(true).ToSQL() - sds.NoError(err) - sds.Equal([]interface{}{int64(1)}, args) - sds.Equal(`SELECT * FROM "test" WHERE ("a" > ?) GROUP BY "name", created::DATE`, selectSQL) -} - -func (sds *selectDatasetSuite) TestHaving() { - ds := From("test") - dsc := ds.GetClauses() - h := C("a").Gt(1) - ec := dsc.HavingAppend(h) - sds.Equal(ec, ds.Having(h).GetClauses()) - sds.Equal(dsc, ds.GetClauses()) -} - -func (sds *selectDatasetSuite) TestHaving_ToSQL() { - ds1 := From("test") - - b := ds1.Having(Ex{"a": Op{"gt": 1}}).GroupBy("created") - selectSQL, args, err := b.ToSQL() - sds.NoError(err) - sds.Empty(args) - sds.Equal(`SELECT * FROM "test" GROUP BY "created" HAVING ("a" > 1)`, selectSQL) - - selectSQL, args, err = b.Prepared(true).ToSQL() - sds.NoError(err) - sds.Equal([]interface{}{int64(1)}, args) - sds.Equal(`SELECT * FROM "test" GROUP BY "created" HAVING ("a" > ?)`, selectSQL) - - b = ds1.Where(Ex{"b": true}).Having(Ex{"a": Op{"gt": 1}}).GroupBy("created") - selectSQL, args, err = b.ToSQL() - sds.NoError(err) - sds.Empty(args) - sds.Equal(`SELECT * FROM "test" WHERE ("b" IS TRUE) GROUP BY "created" HAVING ("a" > 1)`, selectSQL) - - selectSQL, args, err = b.Prepared(true).ToSQL() - sds.NoError(err) - sds.Equal([]interface{}{int64(1)}, args) - sds.Equal(`SELECT * FROM "test" WHERE ("b" IS TRUE) GROUP BY "created" HAVING ("a" > ?)`, selectSQL) - - b = ds1.Having(Ex{"a": Op{"gt": 1}}) - selectSQL, args, err = b.ToSQL() - sds.NoError(err) - sds.Empty(args) - sds.Equal(`SELECT * FROM "test" HAVING ("a" > 1)`, selectSQL) - - selectSQL, args, err = b.Prepared(true).ToSQL() - sds.NoError(err) - sds.Equal([]interface{}{int64(1)}, args) - sds.Equal(`SELECT * FROM "test" HAVING ("a" > ?)`, selectSQL) - - b = ds1.Having(Ex{"a": Op{"gt": 1}}).Having(Ex{"b": 2}) - selectSQL, args, err = b.ToSQL() - sds.NoError(err) - sds.Empty(args) - sds.Equal(`SELECT * FROM "test" HAVING (("a" > 1) AND ("b" = 2))`, selectSQL) - - selectSQL, args, err = b.Prepared(true).ToSQL() - sds.NoError(err) - sds.Equal([]interface{}{int64(1), int64(2)}, args) - sds.Equal(`SELECT * FROM "test" HAVING (("a" > ?) AND ("b" = ?))`, selectSQL) - - b = ds1.GroupBy("name").Having(SUM("amount").Gt(0)) - selectSQL, args, err = b.ToSQL() - sds.NoError(err) - sds.Empty(args) - sds.Equal(`SELECT * FROM "test" GROUP BY "name" HAVING (SUM("amount") > 0)`, selectSQL) - - selectSQL, args, err = b.Prepared(true).ToSQL() - sds.NoError(err) - sds.Equal([]interface{}{int64(0)}, args) - sds.Equal(`SELECT * FROM "test" GROUP BY "name" HAVING (SUM("amount") > ?)`, selectSQL) -} - -func (sds *selectDatasetSuite) TestOrder() { - ds := From("test") - dsc := ds.GetClauses() - o := C("a").Desc() - ec := dsc.SetOrder(o) - sds.Equal(ec, ds.Order(o).GetClauses()) - sds.Equal(dsc, ds.GetClauses()) -} - -func (sds *selectDatasetSuite) TestOrder_ToSQL() { - - ds1 := From("test") - - b := ds1.Order(C("a").Asc(), L(`("a" + "b" > 2)`).Asc()) - selectSQL, args, err := b.ToSQL() - sds.NoError(err) - sds.Empty(args) - sds.Equal(`SELECT * FROM "test" ORDER BY "a" ASC, ("a" + "b" > 2) ASC`, selectSQL) - - selectSQL, args, err = b.Prepared(true).ToSQL() - sds.NoError(err) - sds.Empty(args) - sds.Equal(`SELECT * FROM "test" ORDER BY "a" ASC, ("a" + "b" > 2) ASC`, selectSQL) -} - -func (sds *selectDatasetSuite) TestOrderAppend() { - ds := From("test").Order(C("a").Desc()) - dsc := ds.GetClauses() - o := C("b").Desc() - ec := dsc.OrderAppend(o) - sds.Equal(ec, ds.OrderAppend(o).GetClauses()) - sds.Equal(dsc, ds.GetClauses()) -} - -func (sds *selectDatasetSuite) TestOrderAppend_ToSQL() { - b := From("test").Order(C("a").Asc().NullsFirst()).OrderAppend(C("b").Desc().NullsLast()) - selectSQL, args, err := b.ToSQL() - sds.NoError(err) - sds.Empty(args) - sds.Equal(`SELECT * FROM "test" ORDER BY "a" ASC NULLS FIRST, "b" DESC NULLS LAST`, selectSQL) - - selectSQL, args, err = b.Prepared(true).ToSQL() - sds.NoError(err) - sds.Empty(args) - sds.Equal(`SELECT * FROM "test" ORDER BY "a" ASC NULLS FIRST, "b" DESC NULLS LAST`, selectSQL) - - b = From("test").OrderAppend(C("a").Asc().NullsFirst()).OrderAppend(C("b").Desc().NullsLast()) - selectSQL, args, err = b.ToSQL() - sds.NoError(err) - sds.Empty(args) - sds.Equal(`SELECT * FROM "test" ORDER BY "a" ASC NULLS FIRST, "b" DESC NULLS LAST`, selectSQL) - - selectSQL, args, err = b.Prepared(true).ToSQL() - sds.NoError(err) - sds.Empty(args) - sds.Equal(`SELECT * FROM "test" ORDER BY "a" ASC NULLS FIRST, "b" DESC NULLS LAST`, selectSQL) - -} - -func (sds *selectDatasetSuite) TestClearOrder() { - ds := From("test").Order(C("a").Desc()) - dsc := ds.GetClauses() - ec := dsc.ClearOrder() - sds.Equal(ec, ds.ClearOrder().GetClauses()) - sds.Equal(dsc, ds.GetClauses()) -} - -func (sds *selectDatasetSuite) TestClearOrder_ToSQL() { - b := From("test").Order(C("a").Asc().NullsFirst()).ClearOrder() - selectSQL, args, err := b.ToSQL() - sds.NoError(err) - sds.Empty(args) - sds.Equal(`SELECT * FROM "test"`, selectSQL) - - selectSQL, args, err = b.Prepared(true).ToSQL() - sds.NoError(err) - sds.Empty(args) - sds.Equal(`SELECT * FROM "test"`, selectSQL) -} - -func (sds *selectDatasetSuite) TestLimit() { - ds := From("test") - dsc := ds.GetClauses() - ec := dsc.SetLimit(uint(1)) - sds.Equal(ec, ds.Limit(1).GetClauses()) - sds.Equal(dsc, ds.Limit(0).GetClauses()) - sds.Equal(dsc, ds.GetClauses()) -} - -func (sds *selectDatasetSuite) TestLimit_ToSQL() { - ds1 := From("test") - - b := ds1.Where(C("a").Gt(1)).Limit(10) - selectSQL, args, err := b.ToSQL() - sds.NoError(err) - sds.Empty(args) - sds.Equal(`SELECT * FROM "test" WHERE ("a" > 1) LIMIT 10`, selectSQL) - - selectSQL, args, err = b.Prepared(true).ToSQL() - sds.NoError(err) - sds.Equal([]interface{}{int64(1), int64(10)}, args) - sds.Equal(`SELECT * FROM "test" WHERE ("a" > ?) LIMIT ?`, selectSQL) - - b = ds1.Where(C("a").Gt(1)).Limit(0) - selectSQL, args, err = b.ToSQL() - sds.NoError(err) - sds.Empty(args) - sds.Equal(`SELECT * FROM "test" WHERE ("a" > 1)`, selectSQL) - - selectSQL, args, err = b.Prepared(true).ToSQL() - sds.NoError(err) - sds.Equal([]interface{}{int64(1)}, args) - sds.Equal(`SELECT * FROM "test" WHERE ("a" > ?)`, selectSQL) -} - -func (sds *selectDatasetSuite) TestLimitAll() { - ds := From("test") - dsc := ds.GetClauses() - ec := dsc.SetLimit(L("ALL")) - sds.Equal(ec, ds.LimitAll().GetClauses()) - sds.Equal(dsc, ds.GetClauses()) -} - -func (sds *selectDatasetSuite) TestLimitAll_ToSQL() { - ds1 := From("test") - - b := ds1.Where(C("a").Gt(1)).LimitAll() - selectSQL, args, err := b.ToSQL() - sds.NoError(err) - sds.Empty(args) - sds.Equal(`SELECT * FROM "test" WHERE ("a" > 1) LIMIT ALL`, selectSQL) - - selectSQL, args, err = b.Prepared(true).ToSQL() - sds.NoError(err) - sds.Equal([]interface{}{int64(1)}, args) - sds.Equal(`SELECT * FROM "test" WHERE ("a" > ?) LIMIT ALL`, selectSQL) - - b = ds1.Where(C("a").Gt(1)).Limit(0).LimitAll() - selectSQL, args, err = b.ToSQL() - sds.NoError(err) - sds.Empty(args) - sds.Equal(`SELECT * FROM "test" WHERE ("a" > 1) LIMIT ALL`, selectSQL) - - selectSQL, args, err = b.Prepared(true).ToSQL() - sds.NoError(err) - sds.Equal([]interface{}{int64(1)}, args) - sds.Equal(`SELECT * FROM "test" WHERE ("a" > ?) LIMIT ALL`, selectSQL) -} - -func (sds *selectDatasetSuite) TestClearLimit() { - ds := From("test").Limit(1) - dsc := ds.GetClauses() - ec := dsc.ClearLimit() - sds.Equal(ec, ds.ClearLimit().GetClauses()) - sds.Equal(dsc, ds.GetClauses()) -} - -func (sds *selectDatasetSuite) TestClearLimit_ToSQL() { - ds1 := From("test") - - b := ds1.Where(C("a").Gt(1)).LimitAll().ClearLimit() - selectSQL, args, err := b.ToSQL() - sds.NoError(err) - sds.Empty(args) - sds.Equal(`SELECT * FROM "test" WHERE ("a" > 1)`, selectSQL) - - selectSQL, args, err = b.Prepared(true).ToSQL() - sds.NoError(err) - sds.Equal([]interface{}{int64(1)}, args) - sds.Equal(`SELECT * FROM "test" WHERE ("a" > ?)`, selectSQL) - - b = ds1.Where(C("a").Gt(1)).Limit(10).ClearLimit() - selectSQL, args, err = b.ToSQL() - sds.NoError(err) - sds.Empty(args) - sds.Equal(`SELECT * FROM "test" WHERE ("a" > 1)`, selectSQL) - - selectSQL, args, err = b.Prepared(true).ToSQL() - sds.NoError(err) - sds.Equal([]interface{}{int64(1)}, args) - sds.Equal(`SELECT * FROM "test" WHERE ("a" > ?)`, selectSQL) -} - -func (sds *selectDatasetSuite) TestOffset() { - ds := From("test") - dsc := ds.GetClauses() - ec := dsc.SetOffset(1) - sds.Equal(ec, ds.Offset(1).GetClauses()) - sds.Equal(dsc, ds.GetClauses()) -} - -func (sds *selectDatasetSuite) TestOffset_ToSQL() { - ds1 := From("test") - - b := ds1.Where(C("a").Gt(1)).Offset(10) - selectSQL, args, err := b.ToSQL() - sds.NoError(err) - sds.Empty(args) - sds.Equal(`SELECT * FROM "test" WHERE ("a" > 1) OFFSET 10`, selectSQL) - - selectSQL, args, err = b.Prepared(true).ToSQL() - sds.NoError(err) - sds.Equal([]interface{}{int64(1), int64(10)}, args) - sds.Equal(`SELECT * FROM "test" WHERE ("a" > ?) OFFSET ?`, selectSQL) - - b = ds1.Where(C("a").Gt(1)).Offset(0) - selectSQL, args, err = b.ToSQL() - sds.NoError(err) - sds.Empty(args) - sds.Equal(`SELECT * FROM "test" WHERE ("a" > 1)`, selectSQL) - - selectSQL, args, err = b.Prepared(true).ToSQL() - sds.NoError(err) - sds.Equal([]interface{}{int64(1)}, args) - sds.Equal(`SELECT * FROM "test" WHERE ("a" > ?)`, selectSQL) -} - -func (sds *selectDatasetSuite) TestClearOffset() { - ds := From("test").Offset(1) - dsc := ds.GetClauses() - ec := dsc.ClearOffset() - sds.Equal(ec, ds.ClearOffset().GetClauses()) - sds.Equal(dsc, ds.GetClauses()) -} - -func (sds *selectDatasetSuite) TestClearOffset_ToSQL() { - ds1 := From("test") - - b := ds1.Where(C("a").Gt(1)).Offset(10).ClearOffset() - selectSQL, args, err := b.ToSQL() - sds.NoError(err) - sds.Empty(args) - sds.Equal(`SELECT * FROM "test" WHERE ("a" > 1)`, selectSQL) - - selectSQL, args, err = b.Prepared(true).ToSQL() - sds.NoError(err) - sds.Equal([]interface{}{int64(1)}, args) - sds.Equal(`SELECT * FROM "test" WHERE ("a" > ?)`, selectSQL) +func (sds *selectDatasetSuite) assertCases(cases ...selectTestCase) { + for _, s := range cases { + sds.Equal(s.clauses, s.ds.GetClauses()) + } } -func (sds *selectDatasetSuite) TestUnion() { - uds := From("union_test") +func (sds *selectDatasetSuite) TestClone() { ds := From("test") - dsc := ds.GetClauses() - ec := dsc.CompoundsAppend(exp.NewCompoundExpression(exp.UnionCompoundType, uds)) - sds.Equal(ec, ds.Union(uds).GetClauses()) - sds.Equal(dsc, ds.GetClauses()) + sds.Equal(ds, ds.Clone()) } -func (sds *selectDatasetSuite) TestUnion_ToSQL() { - a := From("invoice").Select("id", "amount").Where(C("amount").Gt(1000)) - b := From("invoice").Select("id", "amount").Where(C("amount").Lt(10)) - - ds := a.Union(b) - selectSQL, args, err := ds.ToSQL() - sds.NoError(err) - sds.Empty(args) - sds.Equal( - `SELECT "id", "amount" FROM "invoice" WHERE ("amount" > 1000) `+ - `UNION (SELECT "id", "amount" FROM "invoice" WHERE ("amount" < 10))`, - selectSQL, - ) - - ds = a.Limit(1).Union(b) - selectSQL, args, err = ds.ToSQL() - sds.NoError(err) - sds.Empty(args) - sds.Equal( - `SELECT * FROM (`+ - `SELECT "id", "amount" FROM "invoice" WHERE ("amount" > 1000) LIMIT 1) AS "t1" `+ - `UNION (SELECT "id", "amount" FROM "invoice" WHERE ("amount" < 10))`, - selectSQL, - ) +func (sds *selectDatasetSuite) TestExpression() { + ds := From("test") + sds.Equal(ds, ds.Expression()) +} - ds = a.Order(C("id").Asc()).Union(b) - selectSQL, args, err = ds.ToSQL() - sds.NoError(err) - sds.Empty(args) - sds.Equal( - `SELECT * FROM `+ - `(SELECT "id", "amount" FROM "invoice" WHERE ("amount" > 1000) ORDER BY "id" ASC) AS "t1" `+ - `UNION (SELECT "id", "amount" FROM "invoice" WHERE ("amount" < 10))`, - selectSQL, - ) +func (sds *selectDatasetSuite) TestDialect() { + ds := From("test") + sds.NotNil(ds.Dialect()) +} - selectSQL, args, err = ds.Prepared(true).ToSQL() - sds.NoError(err) - sds.Equal([]interface{}{int64(1000), int64(10)}, args) - sds.Equal( - `SELECT * FROM `+ - `(SELECT "id", "amount" FROM "invoice" WHERE ("amount" > ?) ORDER BY "id" ASC) AS "t1" `+ - `UNION (SELECT "id", "amount" FROM "invoice" WHERE ("amount" < ?))`, - selectSQL, - ) +func (sds *selectDatasetSuite) TestWithDialect() { + ds := From("test") + md := new(mocks.SQLDialect) + ds = ds.SetDialect(md) - ds = a.Union(b.Limit(1)) - selectSQL, args, err = ds.ToSQL() - sds.NoError(err) - sds.Empty(args) - sds.Equal( - `SELECT "id", "amount" FROM "invoice" WHERE ("amount" > 1000) `+ - `UNION (SELECT * FROM (SELECT "id", "amount" FROM "invoice" WHERE ("amount" < 10) LIMIT 1) AS "t1")`, - selectSQL, - ) + dialect := GetDialect("default") + dialectDs := ds.WithDialect("default") + sds.Equal(md, ds.Dialect()) + sds.Equal(dialect, dialectDs.Dialect()) +} - selectSQL, args, err = ds.Prepared(true).ToSQL() - sds.NoError(err) - sds.Equal([]interface{}{int64(1000), int64(10), int64(1)}, args) - sds.Equal( - `SELECT "id", "amount" FROM "invoice" WHERE ("amount" > ?) `+ - `UNION (SELECT * FROM (SELECT "id", "amount" FROM "invoice" WHERE ("amount" < ?) LIMIT ?) AS "t1")`, - selectSQL, - ) +func (sds *selectDatasetSuite) TestPrepared() { + ds := From("test") + preparedDs := ds.Prepared(true) + sds.True(preparedDs.IsPrepared()) + sds.False(ds.IsPrepared()) + // should apply the prepared to any datasets created from the root + sds.True(preparedDs.Where(Ex{"a": 1}).IsPrepared()) +} - ds = a.Union(b.Order(C("id").Desc())) - selectSQL, args, err = ds.ToSQL() - sds.NoError(err) - sds.Empty(args) - sds.Equal( - `SELECT "id", "amount" FROM "invoice" WHERE ("amount" > 1000) `+ - `UNION (SELECT * FROM (SELECT "id", "amount" FROM "invoice" WHERE ("amount" < 10) ORDER BY "id" DESC) AS "t1")`, - selectSQL, - ) +func (sds *selectDatasetSuite) TestGetClauses() { + ds := From("test") + ce := exp.NewSelectClauses().SetFrom(exp.NewColumnListExpression(I("test"))) + sds.Equal(ce, ds.GetClauses()) +} - selectSQL, args, err = ds.Prepared(true).ToSQL() - sds.NoError(err) - sds.Equal([]interface{}{int64(1000), int64(10)}, args) - sds.Equal( - `SELECT "id", "amount" FROM "invoice" WHERE ("amount" > ?) `+ - `UNION (SELECT * FROM (SELECT "id", "amount" FROM "invoice" WHERE ("amount" < ?) ORDER BY "id" DESC) AS "t1")`, - selectSQL, - ) +func (sds *selectDatasetSuite) TestUpdate() { + where := Ex{"a": 1} + from := From("cte") + limit := uint(1) + order := []exp.OrderedExpression{C("a").Asc(), C("b").Desc()} + ds := From("test"). + With("test-cte", from). + Where(where). + Limit(limit). + Order(order...) + ec := exp.NewUpdateClauses(). + SetTable(C("test")). + CommonTablesAppend(exp.NewCommonTableExpression(false, "test-cte", from)). + WhereAppend(ds.clauses.Where()). + SetLimit(limit). + SetOrder(order...) + sds.Equal(ec, ds.Update().GetClauses()) +} + +func (sds *selectDatasetSuite) TestInsert() { + where := Ex{"a": 1} + from := From("cte") + limit := uint(1) + order := []exp.OrderedExpression{C("a").Asc(), C("b").Desc()} + ds := From("test"). + With("test-cte", from). + Where(where). + Limit(limit). + Order(order...) + ec := exp.NewInsertClauses(). + SetInto(C("test")). + CommonTablesAppend(exp.NewCommonTableExpression(false, "test-cte", from)) + sds.Equal(ec, ds.Insert().GetClauses()) +} + +func (sds *selectDatasetSuite) TestDelete() { + where := Ex{"a": 1} + from := From("cte") + limit := uint(1) + order := []exp.OrderedExpression{C("a").Asc(), C("b").Desc()} + ds := From("test"). + With("test-cte", from). + Where(where). + Limit(limit). + Order(order...) + ec := exp.NewDeleteClauses(). + SetFrom(C("test")). + CommonTablesAppend(exp.NewCommonTableExpression(false, "test-cte", from)). + WhereAppend(ds.clauses.Where()). + SetLimit(limit). + SetOrder(order...) + sds.Equal(ec, ds.Delete().GetClauses()) +} + +func (sds *selectDatasetSuite) TestTruncate() { + where := Ex{"a": 1} + from := From("cte") + limit := uint(1) + order := []exp.OrderedExpression{C("a").Asc(), C("b").Desc()} + ds := From("test"). + With("test-cte", from). + Where(where). + Limit(limit). + Order(order...) + ec := exp.NewTruncateClauses(). + SetTable(exp.NewColumnListExpression("test")) + sds.Equal(ec, ds.Truncate().GetClauses()) +} - ds = a.Limit(1).Union(b.Order(C("id").Desc())) - selectSQL, args, err = ds.ToSQL() - sds.NoError(err) - sds.Empty(args) - sds.Equal( - `SELECT * FROM (`+ - `SELECT "id", "amount" FROM "invoice" WHERE ("amount" > 1000) LIMIT 1) AS "t1" `+ - `UNION (SELECT * FROM (SELECT "id", "amount" FROM "invoice" WHERE ("amount" < 10) ORDER BY "id" DESC) AS "t1")`, - selectSQL, +func (sds *selectDatasetSuite) TestWith() { + from := From("cte") + bd := From("test") + sds.assertCases( + selectTestCase{ + ds: bd.With("test-cte", from), + clauses: exp.NewSelectClauses(). + SetFrom(exp.NewColumnListExpression("test")). + CommonTablesAppend(exp.NewCommonTableExpression(false, "test-cte", from)), + }, + selectTestCase{ + ds: bd, + clauses: exp.NewSelectClauses().SetFrom(exp.NewColumnListExpression("test")), + }, ) +} - selectSQL, args, err = ds.Prepared(true).ToSQL() - sds.NoError(err) - sds.Equal([]interface{}{int64(1000), int64(1), int64(10)}, args) - sds.Equal( - `SELECT * FROM (`+ - `SELECT "id", "amount" FROM "invoice" WHERE ("amount" > ?) LIMIT ?) AS "t1" `+ - `UNION (SELECT * FROM (SELECT "id", "amount" FROM "invoice" WHERE ("amount" < ?) ORDER BY "id" DESC) AS "t1")`, - selectSQL, +func (sds *selectDatasetSuite) TestWithRecursive() { + from := From("cte") + bd := From("test") + sds.assertCases( + selectTestCase{ + ds: bd.WithRecursive("test-cte", from), + clauses: exp.NewSelectClauses(). + SetFrom(exp.NewColumnListExpression("test")). + CommonTablesAppend(exp.NewCommonTableExpression(true, "test-cte", from)), + }, + selectTestCase{ + ds: bd, + clauses: exp.NewSelectClauses().SetFrom(exp.NewColumnListExpression("test")), + }, ) +} - ds = a.Union(b).Union(b.Where(C("id").Lt(50))) - selectSQL, args, err = ds.ToSQL() - sds.NoError(err) - sds.Empty(args) - sds.Equal( - `SELECT "id", "amount" FROM "invoice" WHERE ("amount" > 1000) `+ - `UNION (SELECT "id", "amount" FROM "invoice" WHERE ("amount" < 10)) `+ - `UNION (SELECT "id", "amount" FROM "invoice" WHERE (("amount" < 10) AND ("id" < 50)))`, - selectSQL, +func (sds *selectDatasetSuite) TestSelect() { + bd := From("test") + sds.assertCases( + selectTestCase{ + ds: bd.Select("a", "b"), + clauses: exp.NewSelectClauses(). + SetFrom(exp.NewColumnListExpression("test")). + SetSelect(exp.NewColumnListExpression("a", "b")), + }, + selectTestCase{ + ds: bd.Select("a").Select("b"), + clauses: exp.NewSelectClauses(). + SetFrom(exp.NewColumnListExpression("test")). + SetSelect(exp.NewColumnListExpression("b")), + }, + selectTestCase{ + ds: bd.Select("a").Select(), + clauses: exp.NewSelectClauses(). + SetFrom(exp.NewColumnListExpression("test")), + }, + selectTestCase{ + ds: bd, + clauses: exp.NewSelectClauses().SetFrom(exp.NewColumnListExpression("test")), + }, ) +} - selectSQL, args, err = ds.Prepared(true).ToSQL() - sds.NoError(err) - sds.Equal([]interface{}{int64(1000), int64(10), int64(10), int64(50)}, args) - sds.Equal( - `SELECT "id", "amount" FROM "invoice" WHERE ("amount" > ?) `+ - `UNION (SELECT "id", "amount" FROM "invoice" WHERE ("amount" < ?)) `+ - `UNION (SELECT "id", "amount" FROM "invoice" WHERE (("amount" < ?) AND ("id" < ?)))`, - selectSQL, +func (sds *selectDatasetSuite) TestSelectDistinct() { + bd := From("test") + sds.assertCases( + selectTestCase{ + ds: bd.SelectDistinct("a", "b"), + clauses: exp.NewSelectClauses(). + SetFrom(exp.NewColumnListExpression("test")). + SetSelect(exp.NewColumnListExpression("a", "b")). + SetDistinct(exp.NewColumnListExpression()), + }, + selectTestCase{ + ds: bd.SelectDistinct("a").SelectDistinct("b"), + clauses: exp.NewSelectClauses(). + SetFrom(exp.NewColumnListExpression("test")). + SetSelect(exp.NewColumnListExpression("b")). + SetDistinct(exp.NewColumnListExpression()), + }, + selectTestCase{ + ds: bd.Select("a").SelectDistinct("b"), + clauses: exp.NewSelectClauses(). + SetFrom(exp.NewColumnListExpression("test")). + SetSelect(exp.NewColumnListExpression("b")). + SetDistinct(exp.NewColumnListExpression()), + }, + selectTestCase{ + ds: bd.Select("a").SelectDistinct(), + clauses: exp.NewSelectClauses(). + SetFrom(exp.NewColumnListExpression("test")). + SetSelect(exp.NewColumnListExpression(Star())). + SetDistinct(nil), + }, + selectTestCase{ + ds: bd.SelectDistinct("a").SelectDistinct(), + clauses: exp.NewSelectClauses(). + SetFrom(exp.NewColumnListExpression("test")). + SetSelect(exp.NewColumnListExpression(Star())). + SetDistinct(nil), + }, + selectTestCase{ + ds: bd, + clauses: exp.NewSelectClauses().SetFrom(exp.NewColumnListExpression("test")), + }, ) - } -func (sds *selectDatasetSuite) TestUnionAll() { - uds := From("union_test") - ds := From("test") - dsc := ds.GetClauses() - ec := dsc.CompoundsAppend(exp.NewCompoundExpression(exp.UnionAllCompoundType, uds)) - sds.Equal(ec, ds.UnionAll(uds).GetClauses()) - sds.Equal(dsc, ds.GetClauses()) +func (sds *selectDatasetSuite) TestClearSelect() { + bd := From("test").Select("a") + sds.assertCases( + selectTestCase{ + ds: bd.ClearSelect(), + clauses: exp.NewSelectClauses(). + SetFrom(exp.NewColumnListExpression("test")), + }, + selectTestCase{ + ds: bd, + clauses: exp.NewSelectClauses(). + SetFrom(exp.NewColumnListExpression("test")). + SetSelect(exp.NewColumnListExpression("a")), + }, + ) +} + +func (sds *selectDatasetSuite) TestSelectAppend() { + bd := From("test").Select("a") + sds.assertCases( + selectTestCase{ + ds: bd.SelectAppend("b"), + clauses: exp.NewSelectClauses(). + SetFrom(exp.NewColumnListExpression("test")). + SetSelect(exp.NewColumnListExpression("a", "b")), + }, + selectTestCase{ + ds: bd, + clauses: exp.NewSelectClauses(). + SetFrom(exp.NewColumnListExpression("test")). + SetSelect(exp.NewColumnListExpression("a")), + }, + ) } -func (sds *selectDatasetSuite) TestUnionAll_ToSQL() { - a := From("invoice").Select("id", "amount").Where(C("amount").Gt(1000)) - b := From("invoice").Select("id", "amount").Where(C("amount").Lt(10)) - - ds := a.UnionAll(b) - selectSQL, args, err := ds.ToSQL() - sds.NoError(err) - sds.Empty(args) - sds.Equal( - `SELECT "id", "amount" FROM "invoice" WHERE ("amount" > 1000) `+ - `UNION ALL (SELECT "id", "amount" FROM "invoice" WHERE ("amount" < 10))`, - selectSQL, +func (sds *selectDatasetSuite) TestDistinct() { + bd := From("test") + sds.assertCases( + selectTestCase{ + ds: bd.Distinct("a", "b"), + clauses: exp.NewSelectClauses(). + SetFrom(exp.NewColumnListExpression("test")). + SetDistinct(exp.NewColumnListExpression("a", "b")), + }, + selectTestCase{ + ds: bd.Distinct("a").Distinct("b"), + clauses: exp.NewSelectClauses(). + SetFrom(exp.NewColumnListExpression("test")). + SetDistinct(exp.NewColumnListExpression("b")), + }, + selectTestCase{ + ds: bd, + clauses: exp.NewSelectClauses().SetFrom(exp.NewColumnListExpression("test")), + }, ) +} - ds = a.Limit(1).UnionAll(b) - selectSQL, args, err = ds.ToSQL() - sds.NoError(err) - sds.Empty(args) - sds.Equal( - `SELECT * FROM (`+ - `SELECT "id", "amount" FROM "invoice" WHERE ("amount" > 1000) LIMIT 1) AS "t1" `+ - `UNION ALL (SELECT "id", "amount" FROM "invoice" WHERE ("amount" < 10))`, - selectSQL, +func (sds *selectDatasetSuite) TestFrom() { + bd := From("test") + sds.assertCases( + selectTestCase{ + ds: bd.From(T("test2")), + clauses: exp.NewSelectClauses(). + SetFrom(exp.NewColumnListExpression(T("test2"))), + }, + selectTestCase{ + ds: bd.From(From("test")), + clauses: exp.NewSelectClauses(). + SetFrom(exp.NewColumnListExpression(From("test").As("t1"))), + }, + selectTestCase{ + ds: bd, + clauses: exp.NewSelectClauses().SetFrom(exp.NewColumnListExpression("test")), + }, ) +} - ds = a.Order(C("id").Asc()).UnionAll(b) - selectSQL, args, err = ds.ToSQL() - sds.NoError(err) - sds.Empty(args) - sds.Equal( - `SELECT * FROM `+ - `(SELECT "id", "amount" FROM "invoice" WHERE ("amount" > 1000) ORDER BY "id" ASC) AS "t1" `+ - `UNION ALL (SELECT "id", "amount" FROM "invoice" WHERE ("amount" < 10))`, - selectSQL, +func (sds *selectDatasetSuite) TestFromSelf() { + bd := From("test") + sds.assertCases( + selectTestCase{ + ds: bd.FromSelf(), + clauses: exp.NewSelectClauses(). + SetFrom(exp.NewColumnListExpression(bd.As("t1"))), + }, + selectTestCase{ + ds: bd.As("alias").FromSelf(), + clauses: exp.NewSelectClauses(). + SetFrom(exp.NewColumnListExpression(bd.As("alias"))), + }, + selectTestCase{ + ds: bd, + clauses: exp.NewSelectClauses().SetFrom(exp.NewColumnListExpression("test")), + }, ) +} - selectSQL, args, err = ds.Prepared(true).ToSQL() - sds.NoError(err) - sds.Equal([]interface{}{int64(1000), int64(10)}, args) - sds.Equal( - `SELECT * FROM `+ - `(SELECT "id", "amount" FROM "invoice" WHERE ("amount" > ?) ORDER BY "id" ASC) AS "t1" `+ - `UNION ALL (SELECT "id", "amount" FROM "invoice" WHERE ("amount" < ?))`, - selectSQL, +func (sds *selectDatasetSuite) TestCompoundFromSelf() { + bd := From("test") + sds.assertCases( + selectTestCase{ + ds: bd.CompoundFromSelf(), + clauses: exp.NewSelectClauses().SetFrom(exp.NewColumnListExpression("test")), + }, + selectTestCase{ + ds: bd.Limit(10).CompoundFromSelf(), + clauses: exp.NewSelectClauses().SetFrom(exp.NewColumnListExpression(bd.Limit(10).As("t1"))), + }, + selectTestCase{ + ds: bd.Order(C("a").Asc()).CompoundFromSelf(), + clauses: exp.NewSelectClauses(). + SetFrom(exp.NewColumnListExpression(bd.Order(C("a").Asc()).As("t1"))), + }, + selectTestCase{ + ds: bd.As("alias").FromSelf(), + clauses: exp.NewSelectClauses(). + SetFrom(exp.NewColumnListExpression(bd.As("alias"))), + }, + selectTestCase{ + ds: bd, + clauses: exp.NewSelectClauses().SetFrom(exp.NewColumnListExpression("test")), + }, ) +} - ds = a.UnionAll(b.Limit(1)) - selectSQL, args, err = ds.ToSQL() - sds.NoError(err) - sds.Empty(args) - sds.Equal( - `SELECT "id", "amount" FROM "invoice" WHERE ("amount" > 1000) `+ - `UNION ALL (SELECT * FROM (SELECT "id", "amount" FROM "invoice" WHERE ("amount" < 10) LIMIT 1) AS "t1")`, - selectSQL, +func (sds *selectDatasetSuite) TestJoin() { + bd := From("test") + sds.assertCases( + selectTestCase{ + ds: bd.Join(T("foo"), On(C("a").IsNull())), + clauses: exp.NewSelectClauses(). + SetFrom(exp.NewColumnListExpression("test")). + JoinsAppend( + exp.NewConditionedJoinExpression(exp.InnerJoinType, T("foo"), On(C("a").IsNull())), + ), + }, + selectTestCase{ + ds: bd, + clauses: exp.NewSelectClauses().SetFrom(exp.NewColumnListExpression("test")), + }, ) +} - selectSQL, args, err = ds.Prepared(true).ToSQL() - sds.NoError(err) - sds.Equal([]interface{}{int64(1000), int64(10), int64(1)}, args) - sds.Equal( - `SELECT "id", "amount" FROM "invoice" WHERE ("amount" > ?) `+ - `UNION ALL (SELECT * FROM (SELECT "id", "amount" FROM "invoice" WHERE ("amount" < ?) LIMIT ?) AS "t1")`, - selectSQL, - ) +func (sds *selectDatasetSuite) TestInnerJoin() { - ds = a.UnionAll(b.Order(C("id").Desc())) - selectSQL, args, err = ds.ToSQL() - sds.NoError(err) - sds.Empty(args) - sds.Equal( - `SELECT "id", "amount" FROM "invoice" WHERE ("amount" > 1000) `+ - `UNION ALL (SELECT * FROM (SELECT "id", "amount" FROM "invoice" WHERE ("amount" < 10) ORDER BY "id" DESC) AS "t1")`, - selectSQL, + bd := From("test") + sds.assertCases( + selectTestCase{ + ds: bd.InnerJoin(T("foo"), On(C("a").IsNull())), + clauses: exp.NewSelectClauses(). + SetFrom(exp.NewColumnListExpression("test")). + JoinsAppend( + exp.NewConditionedJoinExpression(exp.InnerJoinType, T("foo"), On(C("a").IsNull())), + ), + }, + selectTestCase{ + ds: bd, + clauses: exp.NewSelectClauses().SetFrom(exp.NewColumnListExpression("test")), + }, ) +} - selectSQL, args, err = ds.Prepared(true).ToSQL() - sds.NoError(err) - sds.Equal([]interface{}{int64(1000), int64(10)}, args) - sds.Equal( - `SELECT "id", "amount" FROM "invoice" WHERE ("amount" > ?) `+ - `UNION ALL (SELECT * FROM (SELECT "id", "amount" FROM "invoice" WHERE ("amount" < ?) ORDER BY "id" DESC) AS "t1")`, - selectSQL, +func (sds *selectDatasetSuite) TestFullOuterJoin() { + bd := From("test") + sds.assertCases( + selectTestCase{ + ds: bd.FullOuterJoin(T("foo"), On(C("a").IsNull())), + clauses: exp.NewSelectClauses(). + SetFrom(exp.NewColumnListExpression("test")). + JoinsAppend( + exp.NewConditionedJoinExpression(exp.FullOuterJoinType, T("foo"), On(C("a").IsNull())), + ), + }, + selectTestCase{ + ds: bd, + clauses: exp.NewSelectClauses().SetFrom(exp.NewColumnListExpression("test")), + }, ) +} - ds = a.Limit(1).UnionAll(b.Order(C("id").Desc())) - selectSQL, args, err = ds.ToSQL() - sds.NoError(err) - sds.Empty(args) - sds.Equal( - `SELECT * FROM (`+ - `SELECT "id", "amount" FROM "invoice" WHERE ("amount" > 1000) LIMIT 1) AS "t1" `+ - `UNION ALL (SELECT * FROM (SELECT "id", "amount" FROM "invoice" WHERE ("amount" < 10) ORDER BY "id" DESC) AS "t1")`, - selectSQL, +func (sds *selectDatasetSuite) TestRightOuterJoin() { + bd := From("test") + sds.assertCases( + selectTestCase{ + ds: bd.RightOuterJoin(T("foo"), On(C("a").IsNull())), + clauses: exp.NewSelectClauses(). + SetFrom(exp.NewColumnListExpression("test")). + JoinsAppend( + exp.NewConditionedJoinExpression(exp.RightOuterJoinType, T("foo"), On(C("a").IsNull())), + ), + }, + selectTestCase{ + ds: bd, + clauses: exp.NewSelectClauses().SetFrom(exp.NewColumnListExpression("test")), + }, ) +} - selectSQL, args, err = ds.Prepared(true).ToSQL() - sds.NoError(err) - sds.Equal([]interface{}{int64(1000), int64(1), int64(10)}, args) - sds.Equal( - `SELECT * FROM (`+ - `SELECT "id", "amount" FROM "invoice" WHERE ("amount" > ?) LIMIT ?) AS "t1" `+ - `UNION ALL (SELECT * FROM (SELECT "id", "amount" FROM "invoice" WHERE ("amount" < ?) ORDER BY "id" DESC) AS "t1")`, - selectSQL, +func (sds *selectDatasetSuite) TestLeftOuterJoin() { + bd := From("test") + sds.assertCases( + selectTestCase{ + ds: bd.LeftOuterJoin(T("foo"), On(C("a").IsNull())), + clauses: exp.NewSelectClauses(). + SetFrom(exp.NewColumnListExpression("test")). + JoinsAppend( + exp.NewConditionedJoinExpression(exp.LeftOuterJoinType, T("foo"), On(C("a").IsNull())), + ), + }, + selectTestCase{ + ds: bd, + clauses: exp.NewSelectClauses().SetFrom(exp.NewColumnListExpression("test")), + }, ) +} - ds = a.UnionAll(b).UnionAll(b.Where(C("id").Lt(50))) - selectSQL, args, err = ds.ToSQL() - sds.NoError(err) - sds.Empty(args) - sds.Equal( - `SELECT "id", "amount" FROM "invoice" WHERE ("amount" > 1000) `+ - `UNION ALL (SELECT "id", "amount" FROM "invoice" WHERE ("amount" < 10)) `+ - `UNION ALL (SELECT "id", "amount" FROM "invoice" WHERE (("amount" < 10) AND ("id" < 50)))`, - selectSQL, +func (sds *selectDatasetSuite) TestFullJoin() { + bd := From("test") + sds.assertCases( + selectTestCase{ + ds: bd.FullJoin(T("foo"), On(C("a").IsNull())), + clauses: exp.NewSelectClauses(). + SetFrom(exp.NewColumnListExpression("test")). + JoinsAppend( + exp.NewConditionedJoinExpression(exp.FullJoinType, T("foo"), On(C("a").IsNull())), + ), + }, + selectTestCase{ + ds: bd, + clauses: exp.NewSelectClauses().SetFrom(exp.NewColumnListExpression("test")), + }, ) +} - selectSQL, args, err = ds.Prepared(true).ToSQL() - sds.NoError(err) - sds.Equal([]interface{}{int64(1000), int64(10), int64(10), int64(50)}, args) - sds.Equal( - `SELECT "id", "amount" FROM "invoice" WHERE ("amount" > ?) `+ - `UNION ALL (SELECT "id", "amount" FROM "invoice" WHERE ("amount" < ?)) `+ - `UNION ALL (SELECT "id", "amount" FROM "invoice" WHERE (("amount" < ?) AND ("id" < ?)))`, - selectSQL, +func (sds *selectDatasetSuite) TestRightJoin() { + bd := From("test") + sds.assertCases( + selectTestCase{ + ds: bd.RightJoin(T("foo"), On(C("a").IsNull())), + clauses: exp.NewSelectClauses(). + SetFrom(exp.NewColumnListExpression("test")). + JoinsAppend( + exp.NewConditionedJoinExpression(exp.RightJoinType, T("foo"), On(C("a").IsNull())), + ), + }, + selectTestCase{ + ds: bd, + clauses: exp.NewSelectClauses().SetFrom(exp.NewColumnListExpression("test")), + }, ) - } -func (sds *selectDatasetSuite) TestIntersect() { - uds := From("union_test") - ds := From("test") - dsc := ds.GetClauses() - ec := dsc.CompoundsAppend(exp.NewCompoundExpression(exp.IntersectCompoundType, uds)) - sds.Equal(ec, ds.Intersect(uds).GetClauses()) - sds.Equal(dsc, ds.GetClauses()) +func (sds *selectDatasetSuite) TestLeftJoin() { + bd := From("test") + sds.assertCases( + selectTestCase{ + ds: bd.LeftJoin(T("foo"), On(C("a").IsNull())), + clauses: exp.NewSelectClauses(). + SetFrom(exp.NewColumnListExpression("test")). + JoinsAppend( + exp.NewConditionedJoinExpression(exp.LeftJoinType, T("foo"), On(C("a").IsNull())), + ), + }, + selectTestCase{ + ds: bd, + clauses: exp.NewSelectClauses().SetFrom(exp.NewColumnListExpression("test")), + }, + ) } -func (sds *selectDatasetSuite) TestIntersect_ToSQL() { - a := From("invoice").Select("id", "amount").Where(C("amount").Gt(1000)) - b := From("invoice").Select("id", "amount").Where(C("amount").Lt(10)) - - ds := a.Intersect(b) - selectSQL, args, err := ds.ToSQL() - sds.NoError(err) - sds.Empty(args) - sds.Equal( - `SELECT "id", "amount" FROM "invoice" WHERE ("amount" > 1000) `+ - `INTERSECT (SELECT "id", "amount" FROM "invoice" WHERE ("amount" < 10))`, - selectSQL, +func (sds *selectDatasetSuite) TestNaturalJoin() { + bd := From("test") + sds.assertCases( + selectTestCase{ + ds: bd.NaturalJoin(T("foo")), + clauses: exp.NewSelectClauses(). + SetFrom(exp.NewColumnListExpression("test")). + JoinsAppend( + exp.NewUnConditionedJoinExpression(exp.NaturalJoinType, T("foo")), + ), + }, + selectTestCase{ + ds: bd, + clauses: exp.NewSelectClauses().SetFrom(exp.NewColumnListExpression("test")), + }, ) +} - ds = a.Limit(1).Intersect(b) - selectSQL, args, err = ds.ToSQL() - sds.NoError(err) - sds.Empty(args) - sds.Equal( - `SELECT * FROM (`+ - `SELECT "id", "amount" FROM "invoice" WHERE ("amount" > 1000) LIMIT 1) AS "t1" `+ - `INTERSECT (SELECT "id", "amount" FROM "invoice" WHERE ("amount" < 10))`, - selectSQL, +func (sds *selectDatasetSuite) TestNaturalLeftJoin() { + bd := From("test") + sds.assertCases( + selectTestCase{ + ds: bd.NaturalLeftJoin(T("foo")), + clauses: exp.NewSelectClauses(). + SetFrom(exp.NewColumnListExpression("test")). + JoinsAppend( + exp.NewUnConditionedJoinExpression(exp.NaturalLeftJoinType, T("foo")), + ), + }, + selectTestCase{ + ds: bd, + clauses: exp.NewSelectClauses().SetFrom(exp.NewColumnListExpression("test")), + }, ) +} - ds = a.Order(C("id").Asc()).Intersect(b) - selectSQL, args, err = ds.ToSQL() - sds.NoError(err) - sds.Empty(args) - sds.Equal( - `SELECT * FROM `+ - `(SELECT "id", "amount" FROM "invoice" WHERE ("amount" > 1000) ORDER BY "id" ASC) AS "t1" `+ - `INTERSECT (SELECT "id", "amount" FROM "invoice" WHERE ("amount" < 10))`, - selectSQL, +func (sds *selectDatasetSuite) TestNaturalRightJoin() { + bd := From("test") + sds.assertCases( + selectTestCase{ + ds: bd.NaturalRightJoin(T("foo")), + clauses: exp.NewSelectClauses(). + SetFrom(exp.NewColumnListExpression("test")). + JoinsAppend( + exp.NewUnConditionedJoinExpression(exp.NaturalRightJoinType, T("foo")), + ), + }, + selectTestCase{ + ds: bd, + clauses: exp.NewSelectClauses().SetFrom(exp.NewColumnListExpression("test")), + }, ) +} - selectSQL, args, err = ds.Prepared(true).ToSQL() - sds.NoError(err) - sds.Equal([]interface{}{int64(1000), int64(10)}, args) - sds.Equal( - `SELECT * FROM `+ - `(SELECT "id", "amount" FROM "invoice" WHERE ("amount" > ?) ORDER BY "id" ASC) AS "t1" `+ - `INTERSECT (SELECT "id", "amount" FROM "invoice" WHERE ("amount" < ?))`, - selectSQL, +func (sds *selectDatasetSuite) TestNaturalFullJoin() { + bd := From("test") + sds.assertCases( + selectTestCase{ + ds: bd.NaturalFullJoin(T("foo")), + clauses: exp.NewSelectClauses(). + SetFrom(exp.NewColumnListExpression("test")). + JoinsAppend( + exp.NewUnConditionedJoinExpression(exp.NaturalFullJoinType, T("foo")), + ), + }, + selectTestCase{ + ds: bd, + clauses: exp.NewSelectClauses().SetFrom(exp.NewColumnListExpression("test")), + }, ) +} - ds = a.Intersect(b.Limit(1)) - selectSQL, args, err = ds.ToSQL() - sds.NoError(err) - sds.Empty(args) - sds.Equal( - `SELECT "id", "amount" FROM "invoice" WHERE ("amount" > 1000) `+ - `INTERSECT (SELECT * FROM (SELECT "id", "amount" FROM "invoice" WHERE ("amount" < 10) LIMIT 1) AS "t1")`, - selectSQL, +func (sds *selectDatasetSuite) TestCrossJoin() { + bd := From("test") + sds.assertCases( + selectTestCase{ + ds: bd.CrossJoin(T("foo")), + clauses: exp.NewSelectClauses(). + SetFrom(exp.NewColumnListExpression("test")). + JoinsAppend( + exp.NewUnConditionedJoinExpression(exp.CrossJoinType, T("foo")), + ), + }, + selectTestCase{ + ds: bd, + clauses: exp.NewSelectClauses().SetFrom(exp.NewColumnListExpression("test")), + }, ) +} - selectSQL, args, err = ds.Prepared(true).ToSQL() - sds.NoError(err) - sds.Equal([]interface{}{int64(1000), int64(10), int64(1)}, args) - sds.Equal( - `SELECT "id", "amount" FROM "invoice" WHERE ("amount" > ?) `+ - `INTERSECT (SELECT * FROM (SELECT "id", "amount" FROM "invoice" WHERE ("amount" < ?) LIMIT ?) AS "t1")`, - selectSQL, +func (sds *selectDatasetSuite) TestWhere() { + w := Ex{"a": 1} + w2 := Ex{"b": "c"} + bd := From("test") + sds.assertCases( + selectTestCase{ + ds: bd.Where(w), + clauses: exp.NewSelectClauses(). + SetFrom(exp.NewColumnListExpression("test")). + WhereAppend(w), + }, + selectTestCase{ + ds: bd.Where(w).Where(w2), + clauses: exp.NewSelectClauses(). + SetFrom(exp.NewColumnListExpression("test")). + WhereAppend(w).WhereAppend(w2), + }, + selectTestCase{ + ds: bd, + clauses: exp.NewSelectClauses().SetFrom(exp.NewColumnListExpression("test")), + }, ) +} - ds = a.Intersect(b.Order(C("id").Desc())) - selectSQL, args, err = ds.ToSQL() - sds.NoError(err) - sds.Empty(args) - sds.Equal( - `SELECT "id", "amount" FROM "invoice" WHERE ("amount" > 1000) `+ - `INTERSECT (SELECT * FROM (SELECT "id", "amount" FROM "invoice" WHERE ("amount" < 10) ORDER BY "id" DESC) AS "t1")`, - selectSQL, +func (sds *selectDatasetSuite) TestClearWhere() { + w := Ex{"a": 1} + bd := From("test").Where(w) + sds.assertCases( + selectTestCase{ + ds: bd.ClearWhere(), + clauses: exp.NewSelectClauses(). + SetFrom(exp.NewColumnListExpression("test")), + }, + selectTestCase{ + ds: bd, + clauses: exp.NewSelectClauses().SetFrom(exp.NewColumnListExpression("test")).WhereAppend(w), + }, ) +} - selectSQL, args, err = ds.Prepared(true).ToSQL() - sds.NoError(err) - sds.Equal([]interface{}{int64(1000), int64(10)}, args) - sds.Equal( - `SELECT "id", "amount" FROM "invoice" WHERE ("amount" > ?) `+ - `INTERSECT (SELECT * FROM (SELECT "id", "amount" FROM "invoice" WHERE ("amount" < ?) ORDER BY "id" DESC) AS "t1")`, - selectSQL, +func (sds *selectDatasetSuite) TestForUpdate() { + bd := From("test") + sds.assertCases( + selectTestCase{ + ds: bd.ForUpdate(NoWait), + clauses: exp.NewSelectClauses(). + SetFrom(exp.NewColumnListExpression("test")). + SetLock(exp.NewLock(exp.ForUpdate, NoWait)), + }, + selectTestCase{ + ds: bd, + clauses: exp.NewSelectClauses().SetFrom(exp.NewColumnListExpression("test")), + }, ) +} - ds = a.Limit(1).Intersect(b.Order(C("id").Desc())) - selectSQL, args, err = ds.ToSQL() - sds.NoError(err) - sds.Empty(args) - sds.Equal( - `SELECT * FROM (`+ - `SELECT "id", "amount" FROM "invoice" WHERE ("amount" > 1000) LIMIT 1) AS "t1" `+ - `INTERSECT (SELECT * FROM (SELECT "id", "amount" FROM "invoice" WHERE ("amount" < 10) ORDER BY "id" DESC) AS "t1")`, - selectSQL, +func (sds *selectDatasetSuite) TestForNoKeyUpdate() { + bd := From("test") + sds.assertCases( + selectTestCase{ + ds: bd.ForNoKeyUpdate(NoWait), + clauses: exp.NewSelectClauses(). + SetFrom(exp.NewColumnListExpression("test")). + SetLock(exp.NewLock(exp.ForNoKeyUpdate, NoWait)), + }, + selectTestCase{ + ds: bd, + clauses: exp.NewSelectClauses().SetFrom(exp.NewColumnListExpression("test")), + }, ) +} - selectSQL, args, err = ds.Prepared(true).ToSQL() - sds.NoError(err) - sds.Equal([]interface{}{int64(1000), int64(1), int64(10)}, args) - sds.Equal( - `SELECT * FROM (`+ - `SELECT "id", "amount" FROM "invoice" WHERE ("amount" > ?) LIMIT ?) AS "t1" `+ - `INTERSECT (SELECT * FROM (SELECT "id", "amount" FROM "invoice" WHERE ("amount" < ?) ORDER BY "id" DESC) AS "t1")`, - selectSQL, +func (sds *selectDatasetSuite) TestForKeyShare() { + bd := From("test") + sds.assertCases( + selectTestCase{ + ds: bd.ForKeyShare(NoWait), + clauses: exp.NewSelectClauses(). + SetFrom(exp.NewColumnListExpression("test")). + SetLock(exp.NewLock(exp.ForKeyShare, NoWait)), + }, + selectTestCase{ + ds: bd, + clauses: exp.NewSelectClauses().SetFrom(exp.NewColumnListExpression("test")), + }, ) +} - ds = a.Intersect(b).Intersect(b.Where(C("id").Lt(50))) - selectSQL, args, err = ds.ToSQL() - sds.NoError(err) - sds.Empty(args) - sds.Equal( - `SELECT "id", "amount" FROM "invoice" WHERE ("amount" > 1000) `+ - `INTERSECT (SELECT "id", "amount" FROM "invoice" WHERE ("amount" < 10)) `+ - `INTERSECT (SELECT "id", "amount" FROM "invoice" WHERE (("amount" < 10) AND ("id" < 50)))`, - selectSQL, +func (sds *selectDatasetSuite) TestForShare() { + bd := From("test") + sds.assertCases( + selectTestCase{ + ds: bd.ForShare(NoWait), + clauses: exp.NewSelectClauses(). + SetFrom(exp.NewColumnListExpression("test")). + SetLock(exp.NewLock(exp.ForShare, NoWait)), + }, + selectTestCase{ + ds: bd, + clauses: exp.NewSelectClauses().SetFrom(exp.NewColumnListExpression("test")), + }, ) +} - selectSQL, args, err = ds.Prepared(true).ToSQL() - sds.NoError(err) - sds.Equal(args, []interface{}{int64(1000), int64(10), int64(10), int64(50)}) - sds.Equal( - `SELECT "id", "amount" FROM "invoice" WHERE ("amount" > ?) `+ - `INTERSECT (SELECT "id", "amount" FROM "invoice" WHERE ("amount" < ?)) `+ - `INTERSECT (SELECT "id", "amount" FROM "invoice" WHERE (("amount" < ?) AND ("id" < ?)))`, - selectSQL, +func (sds *selectDatasetSuite) TestGroupBy() { + bd := From("test") + sds.assertCases( + selectTestCase{ + ds: bd.GroupBy("a"), + clauses: exp.NewSelectClauses(). + SetFrom(exp.NewColumnListExpression("test")). + SetGroupBy(exp.NewColumnListExpression("a")), + }, + selectTestCase{ + ds: bd.GroupBy("a").GroupBy("b"), + clauses: exp.NewSelectClauses(). + SetFrom(exp.NewColumnListExpression("test")). + SetGroupBy(exp.NewColumnListExpression("b")), + }, + selectTestCase{ + ds: bd, + clauses: exp.NewSelectClauses().SetFrom(exp.NewColumnListExpression("test")), + }, ) } -func (sds *selectDatasetSuite) TestIntersectAll() { - uds := From("union_test") +func (sds *selectDatasetSuite) TestHaving() { ds := From("test") dsc := ds.GetClauses() - ec := dsc.CompoundsAppend(exp.NewCompoundExpression(exp.IntersectAllCompoundType, uds)) - sds.Equal(ec, ds.IntersectAll(uds).GetClauses()) + h := C("a").Gt(1) + ec := dsc.HavingAppend(h) + sds.Equal(ec, ds.Having(h).GetClauses()) sds.Equal(dsc, ds.GetClauses()) -} -func (sds *selectDatasetSuite) TestIntersectAll_ToSQL() { - a := From("invoice").Select("id", "amount").Where(C("amount").Gt(1000)) - b := From("invoice").Select("id", "amount").Where(C("amount").Lt(10)) + bd := From("test") + sds.assertCases( + selectTestCase{ + ds: bd.Having(C("a").Gt(1)), + clauses: exp.NewSelectClauses(). + SetFrom(exp.NewColumnListExpression("test")). + HavingAppend(C("a").Gt(1)), + }, + selectTestCase{ + ds: bd.Having(C("a").Gt(1)).Having(Ex{"b": "c"}), + clauses: exp.NewSelectClauses(). + SetFrom(exp.NewColumnListExpression("test")). + HavingAppend(C("a").Gt(1)).HavingAppend(Ex{"b": "c"}), + }, + selectTestCase{ + ds: bd, + clauses: exp.NewSelectClauses().SetFrom(exp.NewColumnListExpression("test")), + }, + ) +} - ds := a.IntersectAll(b) - selectSQL, args, err := ds.ToSQL() - sds.NoError(err) - sds.Empty(args) - sds.Equal( - `SELECT "id", "amount" FROM "invoice" WHERE ("amount" > 1000) `+ - `INTERSECT ALL (SELECT "id", "amount" FROM "invoice" WHERE ("amount" < 10))`, - selectSQL, +func (sds *selectDatasetSuite) TestOrder() { + bd := From("test") + sds.assertCases( + selectTestCase{ + ds: bd.Order(C("a").Asc()), + clauses: exp.NewSelectClauses(). + SetFrom(exp.NewColumnListExpression("test")). + SetOrder(C("a").Asc()), + }, + selectTestCase{ + ds: bd.Order(C("a").Asc()).Order(C("b").Asc()), + clauses: exp.NewSelectClauses(). + SetFrom(exp.NewColumnListExpression("test")). + SetOrder(C("b").Asc()), + }, + selectTestCase{ + ds: bd, + clauses: exp.NewSelectClauses().SetFrom(exp.NewColumnListExpression("test")), + }, ) +} - ds = a.Limit(1).IntersectAll(b) - selectSQL, args, err = ds.ToSQL() - sds.NoError(err) - sds.Empty(args) - sds.Equal( - `SELECT * FROM (`+ - `SELECT "id", "amount" FROM "invoice" WHERE ("amount" > 1000) LIMIT 1) AS "t1" `+ - `INTERSECT ALL (SELECT "id", "amount" FROM "invoice" WHERE ("amount" < 10))`, - selectSQL, +func (sds *selectDatasetSuite) TestOrderAppend() { + bd := From("test").Order(C("a").Asc()) + sds.assertCases( + selectTestCase{ + ds: bd.OrderAppend(C("b").Asc()), + clauses: exp.NewSelectClauses(). + SetFrom(exp.NewColumnListExpression("test")). + SetOrder(C("a").Asc(), C("b").Asc()), + }, + selectTestCase{ + ds: bd, + clauses: exp.NewSelectClauses().SetFrom(exp.NewColumnListExpression("test")). + SetOrder(C("a").Asc()), + }, + ) +} + +func (sds *selectDatasetSuite) TestOrderPrepend() { + bd := From("test").Order(C("a").Asc()) + sds.assertCases( + selectTestCase{ + ds: bd.OrderPrepend(C("b").Asc()), + clauses: exp.NewSelectClauses(). + SetFrom(exp.NewColumnListExpression("test")). + SetOrder(C("b").Asc(), C("a").Asc()), + }, + selectTestCase{ + ds: bd, + clauses: exp.NewSelectClauses().SetFrom(exp.NewColumnListExpression("test")). + SetOrder(C("a").Asc()), + }, ) +} - ds = a.Order(C("id").Asc()).IntersectAll(b) - selectSQL, args, err = ds.ToSQL() - sds.NoError(err) - sds.Empty(args) - sds.Equal( - `SELECT * FROM `+ - `(SELECT "id", "amount" FROM "invoice" WHERE ("amount" > 1000) ORDER BY "id" ASC) AS "t1" `+ - `INTERSECT ALL (SELECT "id", "amount" FROM "invoice" WHERE ("amount" < 10))`, - selectSQL, +func (sds *selectDatasetSuite) TestClearOrder() { + bd := From("test").Order(C("a").Asc()) + sds.assertCases( + selectTestCase{ + ds: bd.ClearOrder(), + clauses: exp.NewSelectClauses(). + SetFrom(exp.NewColumnListExpression("test")), + }, + selectTestCase{ + ds: bd, + clauses: exp.NewSelectClauses().SetFrom(exp.NewColumnListExpression("test")). + SetOrder(C("a").Asc()), + }, ) +} - selectSQL, args, err = ds.Prepared(true).ToSQL() - sds.NoError(err) - sds.Equal([]interface{}{int64(1000), int64(10)}, args) - sds.Equal( - `SELECT * FROM `+ - `(SELECT "id", "amount" FROM "invoice" WHERE ("amount" > ?) ORDER BY "id" ASC) AS "t1" `+ - `INTERSECT ALL (SELECT "id", "amount" FROM "invoice" WHERE ("amount" < ?))`, - selectSQL, +func (sds *selectDatasetSuite) TestLimit() { + bd := From("test") + sds.assertCases( + selectTestCase{ + ds: bd.Limit(10), + clauses: exp.NewSelectClauses(). + SetFrom(exp.NewColumnListExpression("test")). + SetLimit(uint(10)), + }, + selectTestCase{ + ds: bd.Limit(0), + clauses: exp.NewSelectClauses(). + SetFrom(exp.NewColumnListExpression("test")), + }, + selectTestCase{ + ds: bd.Limit(10).Limit(2), + clauses: exp.NewSelectClauses(). + SetFrom(exp.NewColumnListExpression("test")). + SetLimit(uint(2)), + }, + selectTestCase{ + ds: bd.Limit(10).Limit(0), + clauses: exp.NewSelectClauses(). + SetFrom(exp.NewColumnListExpression("test")), + }, + selectTestCase{ + ds: bd, + clauses: exp.NewSelectClauses().SetFrom(exp.NewColumnListExpression("test")), + }, ) +} - ds = a.IntersectAll(b.Limit(1)) - selectSQL, args, err = ds.ToSQL() - sds.NoError(err) - sds.Empty(args) - sds.Equal( - `SELECT "id", "amount" FROM "invoice" WHERE ("amount" > 1000) `+ - `INTERSECT ALL (SELECT * FROM (SELECT "id", "amount" FROM "invoice" WHERE ("amount" < 10) LIMIT 1) AS "t1")`, - selectSQL, +func (sds *selectDatasetSuite) TestLimitAll() { + bd := From("test") + sds.assertCases( + selectTestCase{ + ds: bd.LimitAll(), + clauses: exp.NewSelectClauses(). + SetFrom(exp.NewColumnListExpression("test")). + SetLimit(L("ALL")), + }, + selectTestCase{ + ds: bd.Limit(10).LimitAll(), + clauses: exp.NewSelectClauses(). + SetFrom(exp.NewColumnListExpression("test")). + SetLimit(L("ALL")), + }, + selectTestCase{ + ds: bd, + clauses: exp.NewSelectClauses().SetFrom(exp.NewColumnListExpression("test")), + }, ) +} - selectSQL, args, err = ds.Prepared(true).ToSQL() - sds.NoError(err) - sds.Equal([]interface{}{int64(1000), int64(10), int64(1)}, args) - sds.Equal( - `SELECT "id", "amount" FROM "invoice" WHERE ("amount" > ?) `+ - `INTERSECT ALL (SELECT * FROM (SELECT "id", "amount" FROM "invoice" WHERE ("amount" < ?) LIMIT ?) AS "t1")`, - selectSQL, +func (sds *selectDatasetSuite) TestClearLimit() { + bd := From("test").Limit(10) + sds.assertCases( + selectTestCase{ + ds: bd.ClearLimit(), + clauses: exp.NewSelectClauses().SetFrom(exp.NewColumnListExpression("test")), + }, + selectTestCase{ + ds: bd, + clauses: exp.NewSelectClauses(). + SetFrom(exp.NewColumnListExpression("test")). + SetLimit(uint(10)), + }, ) +} - ds = a.IntersectAll(b.Order(C("id").Desc())) - selectSQL, args, err = ds.ToSQL() - sds.NoError(err) - sds.Empty(args) - sds.Equal( - `SELECT "id", "amount" FROM "invoice" WHERE ("amount" > 1000) `+ - `INTERSECT ALL (SELECT * FROM (SELECT "id", "amount" FROM "invoice" WHERE ("amount" < 10) ORDER BY "id" DESC) AS "t1")`, - selectSQL, +func (sds *selectDatasetSuite) TestOffset() { + bd := From("test") + sds.assertCases( + selectTestCase{ + ds: bd.Offset(10), + clauses: exp.NewSelectClauses().SetFrom(exp.NewColumnListExpression("test")).SetOffset(10), + }, + selectTestCase{ + ds: bd, + clauses: exp.NewSelectClauses().SetFrom(exp.NewColumnListExpression("test")), + }, ) +} - selectSQL, args, err = ds.Prepared(true).ToSQL() - sds.NoError(err) - sds.Equal(args, []interface{}{int64(1000), int64(10)}) - sds.Equal( - `SELECT "id", "amount" FROM "invoice" WHERE ("amount" > ?) `+ - `INTERSECT ALL (SELECT * FROM (SELECT "id", "amount" FROM "invoice" WHERE ("amount" < ?) ORDER BY "id" DESC) AS "t1")`, - selectSQL, +func (sds *selectDatasetSuite) TestClearOffset() { + bd := From("test").Offset(10) + sds.assertCases( + selectTestCase{ + ds: bd.ClearOffset(), + clauses: exp.NewSelectClauses().SetFrom(exp.NewColumnListExpression("test")), + }, + selectTestCase{ + ds: bd, + clauses: exp.NewSelectClauses().SetFrom(exp.NewColumnListExpression("test")).SetOffset(10), + }, ) +} - ds = a.Limit(1).IntersectAll(b.Order(C("id").Desc())) - selectSQL, args, err = ds.ToSQL() - sds.NoError(err) - sds.Empty(args) - sds.Equal( - `SELECT * FROM (`+ - `SELECT "id", "amount" FROM "invoice" WHERE ("amount" > 1000) LIMIT 1) AS "t1" `+ - `INTERSECT ALL (SELECT * FROM (SELECT "id", "amount" FROM "invoice" WHERE ("amount" < 10) ORDER BY "id" DESC) AS "t1")`, - selectSQL, +func (sds *selectDatasetSuite) TestUnion() { + uds := From("union_test") + bd := From("test") + sds.assertCases( + selectTestCase{ + ds: bd.Union(uds), + clauses: exp.NewSelectClauses().SetFrom(exp.NewColumnListExpression("test")). + CompoundsAppend(exp.NewCompoundExpression(exp.UnionCompoundType, uds)), + }, + selectTestCase{ + ds: bd, + clauses: exp.NewSelectClauses().SetFrom(exp.NewColumnListExpression("test")), + }, ) +} - selectSQL, args, err = ds.Prepared(true).ToSQL() - sds.NoError(err) - sds.Equal(args, []interface{}{int64(1000), int64(1), int64(10)}) - sds.Equal( - `SELECT * FROM (`+ - `SELECT "id", "amount" FROM "invoice" WHERE ("amount" > ?) LIMIT ?) AS "t1" `+ - `INTERSECT ALL (SELECT * FROM (SELECT "id", "amount" FROM "invoice" WHERE ("amount" < ?) ORDER BY "id" DESC) AS "t1")`, - selectSQL, +func (sds *selectDatasetSuite) TestUnionAll() { + uds := From("union_test") + bd := From("test") + sds.assertCases( + selectTestCase{ + ds: bd.UnionAll(uds), + clauses: exp.NewSelectClauses().SetFrom(exp.NewColumnListExpression("test")). + CompoundsAppend(exp.NewCompoundExpression(exp.UnionAllCompoundType, uds)), + }, + selectTestCase{ + ds: bd, + clauses: exp.NewSelectClauses().SetFrom(exp.NewColumnListExpression("test")), + }, ) +} - ds = a.IntersectAll(b).IntersectAll(b.Where(C("id").Lt(50))) - selectSQL, args, err = ds.ToSQL() - sds.NoError(err) - sds.Empty(args) - sds.Equal( - `SELECT "id", "amount" FROM "invoice" WHERE ("amount" > 1000) `+ - `INTERSECT ALL (SELECT "id", "amount" FROM "invoice" WHERE ("amount" < 10)) `+ - `INTERSECT ALL (SELECT "id", "amount" FROM "invoice" WHERE (("amount" < 10) AND ("id" < 50)))`, - selectSQL, +func (sds *selectDatasetSuite) TestIntersect() { + uds := From("union_test") + bd := From("test") + sds.assertCases( + selectTestCase{ + ds: bd.Intersect(uds), + clauses: exp.NewSelectClauses().SetFrom(exp.NewColumnListExpression("test")). + CompoundsAppend(exp.NewCompoundExpression(exp.IntersectCompoundType, uds)), + }, + selectTestCase{ + ds: bd, + clauses: exp.NewSelectClauses().SetFrom(exp.NewColumnListExpression("test")), + }, ) +} - selectSQL, args, err = ds.Prepared(true).ToSQL() - sds.NoError(err) - sds.Equal(args, []interface{}{int64(1000), int64(10), int64(10), int64(50)}) - sds.Equal( - `SELECT "id", "amount" FROM "invoice" WHERE ("amount" > ?) `+ - `INTERSECT ALL (SELECT "id", "amount" FROM "invoice" WHERE ("amount" < ?)) `+ - `INTERSECT ALL (SELECT "id", "amount" FROM "invoice" WHERE (("amount" < ?) AND ("id" < ?)))`, - selectSQL, +func (sds *selectDatasetSuite) TestIntersectAll() { + uds := From("union_test") + bd := From("test") + sds.assertCases( + selectTestCase{ + ds: bd.IntersectAll(uds), + clauses: exp.NewSelectClauses().SetFrom(exp.NewColumnListExpression("test")). + CompoundsAppend(exp.NewCompoundExpression(exp.IntersectAllCompoundType, uds)), + }, + selectTestCase{ + ds: bd, + clauses: exp.NewSelectClauses().SetFrom(exp.NewColumnListExpression("test")), + }, ) } func (sds *selectDatasetSuite) TestAs() { - ds := From("test") - dsc := ds.GetClauses() - ec := dsc.SetAlias(T("a")) - sds.Equal(ec, ds.As("a").GetClauses()) - sds.Equal(dsc, ds.GetClauses()) + bd := From("test") + sds.assertCases( + selectTestCase{ + ds: bd.As("t"), + clauses: exp.NewSelectClauses().SetFrom(exp.NewColumnListExpression("test")). + SetAlias(T("t")), + }, + selectTestCase{ + ds: bd, + clauses: exp.NewSelectClauses().SetFrom(exp.NewColumnListExpression("test")), + }, + ) } func (sds *selectDatasetSuite) TestToSQL() { @@ -2108,6 +1026,19 @@ func (sds *selectDatasetSuite) TestToSQL() { md.AssertExpectations(sds.T()) } +func (sds *selectDatasetSuite) TestToSQL_prepared() { + md := new(mocks.SQLDialect) + ds := From("test").Prepared(true).SetDialect(md) + c := ds.GetClauses() + sqlB := sb.NewSQLBuilder(true) + md.On("ToSelectSQL", sqlB, c).Return(nil).Once() + sql, args, err := ds.ToSQL() + sds.Empty(sql) + sds.Empty(args) + sds.Nil(err) + md.AssertExpectations(sds.T()) +} + func (sds *selectDatasetSuite) TestToSQL_ReturnedError() { md := new(mocks.SQLDialect) ds := From("test").SetDialect(md) @@ -2176,6 +1107,9 @@ func (sds *selectDatasetSuite) TestScanStructs() { "goqu: type must be a pointer to a slice when scanning into structs") sds.EqualError(ds.From("items").Select("test").ScanStructs(&items), `goqu: unable to find corresponding field to column "test" returned by query`) + + ds = newDataset("mock", nil) + sds.Equal(errQueryFactoryNotFoundError, ds.From("items").ScanStructs(items)) } func (sds *selectDatasetSuite) TestScanStructs_WithPreparedStatements() { @@ -2255,6 +1189,10 @@ func (sds *selectDatasetSuite) TestScanStruct() { sds.EqualError(err, "goqu: type must be a pointer to a struct when scanning into a struct") _, err = ds.From("items").Select("test").ScanStruct(&item) sds.EqualError(err, `goqu: unable to find corresponding field to column "test" returned by query`) + + ds = newDataset("mock", nil) + _, err = ds.From("items").ScanStruct(item) + sds.Equal(errQueryFactoryNotFoundError, err) } func (sds *selectDatasetSuite) TestScanStruct_WithPreparedStatements() { @@ -2311,6 +1249,10 @@ func (sds *selectDatasetSuite) TestScanVals() { "goqu: type must be a pointer to a slice when scanning into vals") sds.EqualError(ds.From("items").ScanVals(dsTestActionItem{}), "goqu: type must be a pointer to a slice when scanning into vals") + + ds = newDataset("mock", nil) + err = ds.From("items").ScanVals(&ids) + sds.Equal(errQueryFactoryNotFoundError, err) } func (sds *selectDatasetSuite) TestScanVals_WithPreparedStatment() { @@ -2359,6 +1301,10 @@ func (sds *selectDatasetSuite) TestScanVal() { found, err = ds.From("items").ScanVal(10) sds.False(found) sds.EqualError(err, "goqu: type must be a pointer when scanning into val") + + ds = newDataset("mock", nil) + _, err = ds.From("items").ScanVal(&id) + sds.Equal(errQueryFactoryNotFoundError, err) } func (sds *selectDatasetSuite) TestScanVal_WithPreparedStatement() { diff --git a/sql_dialect.go b/sql_dialect.go index 50b9cc45..1f28c319 100644 --- a/sql_dialect.go +++ b/sql_dialect.go @@ -1,20 +1,15 @@ package goqu import ( - "database/sql/driver" - "reflect" - "strconv" "strings" - "time" - "unicode/utf8" "github.com/doug-martin/goqu/v8/exp" - "github.com/doug-martin/goqu/v8/internal/errors" "github.com/doug-martin/goqu/v8/internal/sb" - "github.com/doug-martin/goqu/v8/internal/util" + "github.com/doug-martin/goqu/v8/sqlgen" ) type ( + SQLDialectOptions = sqlgen.SQLDialectOptions // An adapter interface to be used by a Dataset to generate SQL for a specific dialect. // See DefaultAdapter for a concrete implementation and examples. SQLDialect interface { @@ -31,74 +26,19 @@ type ( sqlDialect struct { dialect string dialectOptions *SQLDialectOptions + selectGen sqlgen.SelectSQLGenerator + updateGen sqlgen.UpdateSQLGenerator + insertGen sqlgen.InsertSQLGenerator + deleteGen sqlgen.DeleteSQLGenerator + truncateGen sqlgen.TruncateSQLGenerator } ) var ( - replacementRune = '?' - dialects = make(map[string]SQLDialect) - TrueLiteral = exp.NewLiteralExpression("TRUE") - FalseLiteral = exp.NewLiteralExpression("FALSE") - - errNoUpdatedValuesProvided = errors.New("no update values provided") - errConflictUpdateValuesRequired = errors.New("values are required for on conflict update expression") - errNoSourceForUpdate = errors.New("no source found when generating update sql") - errNoSourceForInsert = errors.New("no source found when generating insert sql") - errNoSourceForDelete = errors.New("no source found when generating delete sql") - errNoSourceForTruncate = errors.New("no source found when generating truncate sql") - errNoSetValuesForUpdate = errors.New("no set values found when generating UPDATE sql") - errEmptyIdentifier = errors.New(`a empty identifier was encountered, please specify a "schema", "table" or "column"`) + dialects = make(map[string]SQLDialect) + DefaultDialectOptions = sqlgen.DefaultDialectOptions ) -func errNotSupportedFragment(sqlType string, f SQLFragmentType) error { - return errors.New("unsupported %s SQL fragment %s", sqlType, f) -} - -func errNotSupportedJoinType(j exp.JoinExpression) error { - return errors.New("dialect does not support %v", j.JoinType()) -} - -func errJoinConditionRequired(j exp.JoinExpression) error { - return errors.New("join condition required for conditioned join %v", j.JoinType()) -} - -func errMisMatchedRowLength(expectedL, actualL int) error { - return errors.New("rows with different value length expected %d got %d", expectedL, actualL) -} - -func errUnsupportedExpressionType(e exp.Expression) error { - return errors.New("unsupported expression type %T", e) -} - -func errUnsupportedIdentifierExpression(t interface{}) error { - return errors.New("unexpected col type must be string or LiteralExpression received %T", t) -} - -func errUnsupportedBooleanExpressionOperator(op exp.BooleanOperation) error { - return errors.New("boolean operator '%+v' not supported", op) -} - -func errUnsupportedRangeExpressionOperator(op exp.RangeOperation) error { - return errors.New("range operator %+v not supported", op) -} - -func errCTENotSupported(dialect string) error { - return errors.New("dialect does not support CTE WITH clause [dialect=%s]", dialect) -} -func errRecursiveCTENotSupported(dialect string) error { - return errors.New("dialect does not support CTE WITH RECURSIVE clause [dialect=%s]", dialect) -} -func errUpsertWithWhereNotSupported(dialect string) error { - return errors.New("dialect does not support upsert with where clause [dialect=%s]", dialect) -} -func errReturnNotSupported(dialect string) error { - return errors.New("dialect does not support RETURNING clause [dialect=%s]", dialect) -} - -func errDistinctOnNotSupported(dialect string) error { - return errors.New("dialect does not support DISTINCT ON clause [dialect=%s]", dialect) -} - func init() { RegisterDialect("default", DefaultDialectOptions()) } @@ -121,1078 +61,37 @@ func GetDialect(name string) SQLDialect { } func newDialect(dialect string, do *SQLDialectOptions) SQLDialect { - return &sqlDialect{dialect: dialect, dialectOptions: do} + return &sqlDialect{ + dialect: dialect, + dialectOptions: do, + selectGen: sqlgen.NewSelectSQLGenerator(dialect, do), + updateGen: sqlgen.NewUpdateSQLGenerator(dialect, do), + insertGen: sqlgen.NewInsertSQLGenerator(dialect, do), + deleteGen: sqlgen.NewDeleteSQLGenerator(dialect, do), + truncateGen: sqlgen.NewTruncateSQLGenerator(dialect, do), + } } func (d *sqlDialect) Dialect() string { return d.dialect } -func (d *sqlDialect) SupportsReturn() bool { - return d.dialectOptions.SupportsReturn -} - -func (d *sqlDialect) SupportsOrderByOnUpdate() bool { - return d.dialectOptions.SupportsOrderByOnUpdate -} - -func (d *sqlDialect) SupportsLimitOnUpdate() bool { - return d.dialectOptions.SupportsLimitOnUpdate -} - -func (d *sqlDialect) SupportsOrderByOnDelete() bool { - return d.dialectOptions.SupportsOrderByOnDelete -} -func (d *sqlDialect) SupportsLimitOnDelete() bool { - return d.dialectOptions.SupportsLimitOnDelete -} func (d *sqlDialect) ToSelectSQL(b sb.SQLBuilder, clauses exp.SelectClauses) { - for _, f := range d.dialectOptions.SelectSQLOrder { - if b.Error() != nil { - return - } - switch f { - case CommonTableSQLFragment: - d.CommonTablesSQL(b, clauses.CommonTables()) - case SelectSQLFragment: - d.SelectSQL(b, clauses) - case FromSQLFragment: - d.FromSQL(b, clauses.From()) - case JoinSQLFragment: - d.JoinSQL(b, clauses.Joins()) - case WhereSQLFragment: - d.WhereSQL(b, clauses.Where()) - case GroupBySQLFragment: - d.GroupBySQL(b, clauses.GroupBy()) - case HavingSQLFragment: - d.HavingSQL(b, clauses.Having()) - case CompoundsSQLFragment: - d.CompoundsSQL(b, clauses.Compounds()) - case OrderSQLFragment: - d.OrderSQL(b, clauses.Order()) - case LimitSQLFragment: - d.LimitSQL(b, clauses.Limit()) - case OffsetSQLFragment: - d.OffsetSQL(b, clauses.Offset()) - case ForSQLFragment: - d.ForSQL(b, clauses.Lock()) - default: - b.SetError(errNotSupportedFragment("SELECT", f)) - } - } + d.selectGen.Generate(b, clauses) } func (d *sqlDialect) ToUpdateSQL(b sb.SQLBuilder, clauses exp.UpdateClauses) { - if !clauses.HasTable() { - b.SetError(errNoSourceForUpdate) - return - } - if !clauses.HasSetValues() { - b.SetError(errNoSetValuesForUpdate) - return - } - if !d.dialectOptions.SupportsMultipleUpdateTables && clauses.HasFrom() { - b.SetError(errors.New("%s dialect does not support multiple tables in UPDATE", d.dialect)) - } - updates, err := exp.NewUpdateExpressions(clauses.SetValues()) - if err != nil { - b.SetError(err) - return - } - for _, f := range d.dialectOptions.UpdateSQLOrder { - if b.Error() != nil { - return - } - switch f { - case CommonTableSQLFragment: - d.CommonTablesSQL(b, clauses.CommonTables()) - case UpdateBeginSQLFragment: - d.UpdateBeginSQL(b) - case SourcesSQLFragment: - d.updateTableSQL(b, clauses) - case UpdateSQLFragment: - d.UpdateExpressionsSQL(b, updates...) - case UpdateFromSQLFragment: - d.updateFromSQL(b, clauses.From()) - case WhereSQLFragment: - d.WhereSQL(b, clauses.Where()) - case OrderSQLFragment: - if d.dialectOptions.SupportsOrderByOnUpdate { - d.OrderSQL(b, clauses.Order()) - } - case LimitSQLFragment: - if d.dialectOptions.SupportsLimitOnUpdate { - d.LimitSQL(b, clauses.Limit()) - } - case ReturningSQLFragment: - d.ReturningSQL(b, clauses.Returning()) - default: - b.SetError(errNotSupportedFragment("UPDATE", f)) - } - } + d.updateGen.Generate(b, clauses) } -func (d *sqlDialect) ToInsertSQL( - b sb.SQLBuilder, - clauses exp.InsertClauses, -) { - if !clauses.HasInto() { - b.SetError(errNoSourceForInsert) - return - } - for _, f := range d.dialectOptions.InsertSQLOrder { - if b.Error() != nil { - return - } - switch f { - case CommonTableSQLFragment: - d.CommonTablesSQL(b, clauses.CommonTables()) - case InsertBeingSQLFragment: - d.InsertBeginSQL(b, clauses.OnConflict()) - case IntoSQLFragment: - b.WriteRunes(d.dialectOptions.SpaceRune) - d.Literal(b, clauses.Into()) - case InsertSQLFragment: - d.InsertSQL(b, clauses) - case ReturningSQLFragment: - d.ReturningSQL(b, clauses.Returning()) - default: - b.SetError(errNotSupportedFragment("INSERT", f)) - } - } - +func (d *sqlDialect) ToInsertSQL(b sb.SQLBuilder, clauses exp.InsertClauses) { + d.insertGen.Generate(b, clauses) } func (d *sqlDialect) ToDeleteSQL(b sb.SQLBuilder, clauses exp.DeleteClauses) { - if !clauses.HasFrom() { - b.SetError(errNoSourceForDelete) - return - } - for _, f := range d.dialectOptions.DeleteSQLOrder { - if b.Error() != nil { - return - } - switch f { - case CommonTableSQLFragment: - d.CommonTablesSQL(b, clauses.CommonTables()) - case DeleteBeginSQLFragment: - d.DeleteBeginSQL(b) - case FromSQLFragment: - d.FromSQL(b, exp.NewColumnListExpression(clauses.From())) - case WhereSQLFragment: - d.WhereSQL(b, clauses.Where()) - case OrderSQLFragment: - if d.dialectOptions.SupportsOrderByOnDelete { - d.OrderSQL(b, clauses.Order()) - } - case LimitSQLFragment: - if d.dialectOptions.SupportsLimitOnDelete { - d.LimitSQL(b, clauses.Limit()) - } - case ReturningSQLFragment: - d.ReturningSQL(b, clauses.Returning()) - default: - b.SetError(errNotSupportedFragment("DELETE", f)) - } - } + d.deleteGen.Generate(b, clauses) } func (d *sqlDialect) ToTruncateSQL(b sb.SQLBuilder, clauses exp.TruncateClauses) { - if !clauses.HasTable() { - b.SetError(errNoSourceForTruncate) - return - } - for _, f := range d.dialectOptions.TruncateSQLOrder { - if b.Error() != nil { - return - } - switch f { - case TruncateSQLFragment: - d.TruncateSQL(b, clauses.Table(), clauses.Options()) - default: - b.SetError(errNotSupportedFragment("TRUNCATE", f)) - } - } -} - -// Adds the correct fragment to being an UPDATE statement -func (d *sqlDialect) UpdateBeginSQL(b sb.SQLBuilder) { - b.Write(d.dialectOptions.UpdateClause) -} - -// Adds the correct fragment to being an INSERT statement -func (d *sqlDialect) InsertBeginSQL(b sb.SQLBuilder, o exp.ConflictExpression) { - if d.dialectOptions.SupportsInsertIgnoreSyntax && o != nil { - b.Write(d.dialectOptions.InsertIgnoreClause) - } else { - b.Write(d.dialectOptions.InsertClause) - } -} - -// Adds the correct fragment to being an DELETE statement -func (d *sqlDialect) DeleteBeginSQL(b sb.SQLBuilder) { - b.Write(d.dialectOptions.DeleteClause) -} - -// Generates a TRUNCATE statement -func (d *sqlDialect) TruncateSQL(b sb.SQLBuilder, from exp.ColumnListExpression, opts exp.TruncateOptions) { - b.Write(d.dialectOptions.TruncateClause) - d.SourcesSQL(b, from) - if opts.Identity != d.dialectOptions.EmptyString { - b.WriteRunes(d.dialectOptions.SpaceRune). - WriteStrings(strings.ToUpper(opts.Identity)). - Write(d.dialectOptions.IdentityFragment) - } - if opts.Cascade { - b.Write(d.dialectOptions.CascadeFragment) - } else if opts.Restrict { - b.Write(d.dialectOptions.RestrictFragment) - } -} - -// Adds the columns list to an insert statement -func (d *sqlDialect) InsertSQL(b sb.SQLBuilder, ic exp.InsertClauses) { - switch { - case ic.HasRows(): - ie, err := exp.NewInsertExpression(ic.Rows()...) - if err != nil { - b.SetError(err) - return - } - d.InsertExpressionSQL(b, ie) - case ic.HasCols() && ic.HasVals(): - d.insertColumnsSQL(b, ic.Cols()) - d.insertValuesSQL(b, ic.Vals()) - case ic.HasCols() && ic.HasFrom(): - d.insertColumnsSQL(b, ic.Cols()) - d.insertFromSQL(b, ic.From()) - case ic.HasFrom(): - d.insertFromSQL(b, ic.From()) - default: - d.defaultValuesSQL(b) - - } - - d.onConflictSQL(b, ic.OnConflict()) -} - -func (d *sqlDialect) InsertExpressionSQL(b sb.SQLBuilder, ie exp.InsertExpression) { - switch { - case ie.IsInsertFrom(): - d.insertFromSQL(b, ie.From()) - case ie.IsEmpty(): - d.defaultValuesSQL(b) - default: - d.insertColumnsSQL(b, ie.Cols()) - d.insertValuesSQL(b, ie.Vals()) - } -} - -// Adds column setters in an update SET clause -func (d *sqlDialect) UpdateExpressionsSQL(b sb.SQLBuilder, updates ...exp.UpdateExpression) { - b.Write(d.dialectOptions.SetFragment) - d.updateValuesSQL(b, updates...) - -} - -// Adds the SELECT clause and columns to a sql statement -func (d *sqlDialect) SelectSQL(b sb.SQLBuilder, clauses exp.SelectClauses) { - b.Write(d.dialectOptions.SelectClause). - WriteRunes(d.dialectOptions.SpaceRune) - dc := clauses.Distinct() - if dc != nil { - b.Write(d.dialectOptions.DistinctFragment) - if !dc.IsEmpty() { - if d.dialectOptions.SupportsDistinctOn { - b.Write(d.dialectOptions.OnFragment).WriteRunes(d.dialectOptions.LeftParenRune) - d.Literal(b, dc) - b.WriteRunes(d.dialectOptions.RightParenRune, d.dialectOptions.SpaceRune) - } else { - b.SetError(errDistinctOnNotSupported(d.dialect)) - return - } - } else { - b.WriteRunes(d.dialectOptions.SpaceRune) - } - } - cols := clauses.Select() - if clauses.IsDefaultSelect() || len(cols.Columns()) == 0 { - b.WriteRunes(d.dialectOptions.StarRune) - } else { - d.Literal(b, cols) - } -} - -func (d *sqlDialect) ReturningSQL(b sb.SQLBuilder, returns exp.ColumnListExpression) { - if returns != nil && len(returns.Columns()) > 0 { - if d.SupportsReturn() { - b.Write(d.dialectOptions.ReturningFragment) - d.Literal(b, returns) - } else { - b.SetError(errReturnNotSupported(d.dialect)) - } - } - -} - -// Adds the FROM clause and tables to an sql statement -func (d *sqlDialect) FromSQL(b sb.SQLBuilder, from exp.ColumnListExpression) { - if from != nil && !from.IsEmpty() { - b.Write(d.dialectOptions.FromFragment) - d.SourcesSQL(b, from) - } -} - -// Adds the generates the SQL for a column list -func (d *sqlDialect) SourcesSQL(b sb.SQLBuilder, from exp.ColumnListExpression) { - b.WriteRunes(d.dialectOptions.SpaceRune) - d.Literal(b, from) -} - -// Generates the JOIN clauses for an SQL statement -func (d *sqlDialect) JoinSQL(b sb.SQLBuilder, joins exp.JoinExpressions) { - if len(joins) > 0 { - for _, j := range joins { - joinType, ok := d.dialectOptions.JoinTypeLookup[j.JoinType()] - if !ok { - b.SetError(errNotSupportedJoinType(j)) - return - } - b.Write(joinType) - d.Literal(b, j.Table()) - if t, ok := j.(exp.ConditionedJoinExpression); ok { - if t.IsConditionEmpty() { - b.SetError(errJoinConditionRequired(j)) - return - } - d.joinConditionSQL(b, t.Condition()) - } - } - } -} - -// Generates the WHERE clause for an SQL statement -func (d *sqlDialect) WhereSQL(b sb.SQLBuilder, where exp.ExpressionList) { - if where != nil && !where.IsEmpty() { - b.Write(d.dialectOptions.WhereFragment) - d.Literal(b, where) - } -} - -// Generates the GROUP BY clause for an SQL statement -func (d *sqlDialect) GroupBySQL(b sb.SQLBuilder, groupBy exp.ColumnListExpression) { - if groupBy != nil && len(groupBy.Columns()) > 0 { - b.Write(d.dialectOptions.GroupByFragment) - d.Literal(b, groupBy) - } -} - -// Generates the HAVING clause for an SQL statement -func (d *sqlDialect) HavingSQL(b sb.SQLBuilder, having exp.ExpressionList) { - if having != nil && len(having.Expressions()) > 0 { - b.Write(d.dialectOptions.HavingFragment) - d.Literal(b, having) - } -} - -// Generates the ORDER BY clause for an SQL statement -func (d *sqlDialect) OrderSQL(b sb.SQLBuilder, order exp.ColumnListExpression) { - if order != nil && len(order.Columns()) > 0 { - b.Write(d.dialectOptions.OrderByFragment) - d.Literal(b, order) - } -} - -// Generates the LIMIT clause for an SQL statement -func (d *sqlDialect) LimitSQL(b sb.SQLBuilder, limit interface{}) { - if limit != nil { - b.Write(d.dialectOptions.LimitFragment) - d.Literal(b, limit) - } -} - -// Generates the OFFSET clause for an SQL statement -func (d *sqlDialect) OffsetSQL(b sb.SQLBuilder, offset uint) { - if offset > 0 { - b.Write(d.dialectOptions.OffsetFragment) - d.Literal(b, offset) - } -} - -// Generates the sql for the WITH clauses for common table expressions (CTE) -func (d *sqlDialect) CommonTablesSQL(b sb.SQLBuilder, ctes []exp.CommonTableExpression) { - if l := len(ctes); l > 0 { - if !d.dialectOptions.SupportsWithCTE { - b.SetError(errCTENotSupported(d.dialect)) - return - } - b.Write(d.dialectOptions.WithFragment) - anyRecursive := false - for _, cte := range ctes { - anyRecursive = anyRecursive || cte.IsRecursive() - } - if anyRecursive { - if !d.dialectOptions.SupportsWithCTERecursive { - b.SetError(errRecursiveCTENotSupported(d.dialect)) - return - } - b.Write(d.dialectOptions.RecursiveFragment) - } - for i, cte := range ctes { - d.Literal(b, cte) - if i < l-1 { - b.WriteRunes(d.dialectOptions.CommaRune, d.dialectOptions.SpaceRune) - } - } - b.WriteRunes(d.dialectOptions.SpaceRune) - } -} - -// Generates the compound sql clause for an SQL statement (e.g. UNION, INTERSECT) -func (d *sqlDialect) CompoundsSQL(b sb.SQLBuilder, compounds []exp.CompoundExpression) { - for _, compound := range compounds { - d.Literal(b, compound) - } -} - -// Generates the FOR (aka "locking") clause for an SQL statement -func (d *sqlDialect) ForSQL(b sb.SQLBuilder, lockingClause exp.Lock) { - if lockingClause == nil { - return - } - switch lockingClause.Strength() { - case exp.ForNolock: - return - case exp.ForUpdate: - b.Write(d.dialectOptions.ForUpdateFragment) - case exp.ForNoKeyUpdate: - b.Write(d.dialectOptions.ForNoKeyUpdateFragment) - case exp.ForShare: - b.Write(d.dialectOptions.ForShareFragment) - case exp.ForKeyShare: - b.Write(d.dialectOptions.ForKeyShareFragment) - } - // the WAIT case is the default in Postgres, and is what you get if you don't specify NOWAIT or - // SKIP LOCKED. There's no special syntax for it in PG, so we don't do anything for it here - switch lockingClause.WaitOption() { - case exp.NoWait: - b.Write(d.dialectOptions.NowaitFragment) - case exp.SkipLocked: - b.Write(d.dialectOptions.SkipLockedFragment) - } -} - -func (d *sqlDialect) Literal(b sb.SQLBuilder, val interface{}) { - if b.Error() != nil { - return - } - if val == nil { - d.literalNil(b) - return - } - switch v := val.(type) { - case exp.Expression: - d.expressionSQL(b, v) - case int: - d.literalInt(b, int64(v)) - case int32: - d.literalInt(b, int64(v)) - case int64: - d.literalInt(b, v) - case float32: - d.literalFloat(b, float64(v)) - case float64: - d.literalFloat(b, v) - case string: - d.literalString(b, v) - case []byte: - d.literalBytes(b, v) - case bool: - d.literalBool(b, v) - case time.Time: - d.literalTime(b, v) - case *time.Time: - if v == nil { - d.literalNil(b) - return - } - d.literalTime(b, *v) - case driver.Valuer: - dVal, err := v.Value() - if err != nil { - b.SetError(errors.New(err.Error())) - return - } - d.Literal(b, dVal) - default: - d.reflectSQL(b, val) - } -} - -// Adds the DefaultValuesFragment to an SQL statement -func (d *sqlDialect) defaultValuesSQL(b sb.SQLBuilder) { - b.Write(d.dialectOptions.DefaultValuesFragment) -} - -func (d *sqlDialect) insertFromSQL(b sb.SQLBuilder, ae exp.AppendableExpression) { - b.WriteRunes(d.dialectOptions.SpaceRune) - ae.AppendSQL(b) -} - -// Adds the columns list to an insert statement -func (d *sqlDialect) insertColumnsSQL(b sb.SQLBuilder, cols exp.ColumnListExpression) { - b.WriteRunes(d.dialectOptions.SpaceRune, d.dialectOptions.LeftParenRune) - d.Literal(b, cols) - b.WriteRunes(d.dialectOptions.RightParenRune) -} - -// Adds the values clause to an SQL statement -func (d *sqlDialect) insertValuesSQL(b sb.SQLBuilder, values [][]interface{}) { - b.Write(d.dialectOptions.ValuesFragment) - rowLen := len(values[0]) - valueLen := len(values) - for i, row := range values { - if len(row) != rowLen { - b.SetError(errMisMatchedRowLength(rowLen, len(row))) - return - } - d.Literal(b, row) - if i < valueLen-1 { - b.WriteRunes(d.dialectOptions.CommaRune, d.dialectOptions.SpaceRune) - } - } -} - -// Adds the DefaultValuesFragment to an SQL statement -func (d *sqlDialect) onConflictSQL(b sb.SQLBuilder, o exp.ConflictExpression) { - if o == nil { - return - } - b.Write(d.dialectOptions.ConflictFragment) - switch t := o.(type) { - case exp.ConflictUpdateExpression: - target := t.TargetColumn() - if d.dialectOptions.SupportsConflictTarget && target != "" { - wrapParens := !strings.HasPrefix(strings.ToLower(target), "on constraint") - - b.WriteRunes(d.dialectOptions.SpaceRune) - if wrapParens { - b.WriteRunes(d.dialectOptions.LeftParenRune). - WriteStrings(target). - WriteRunes(d.dialectOptions.RightParenRune) - } else { - b.Write([]byte(target)) - } - } - d.onConflictDoUpdateSQL(b, t) - default: - b.Write(d.dialectOptions.ConflictDoNothingFragment) - } -} - -func (d *sqlDialect) updateTableSQL(b sb.SQLBuilder, uc exp.UpdateClauses) { - b.WriteRunes(d.dialectOptions.SpaceRune) - d.Literal(b, uc.Table()) - if uc.HasFrom() { - if !d.dialectOptions.UseFromClauseForMultipleUpdateTables { - b.WriteRunes(d.dialectOptions.CommaRune) - d.Literal(b, uc.From()) - } - } -} - -// Adds column setters in an update SET clause -func (d *sqlDialect) updateValuesSQL(b sb.SQLBuilder, updates ...exp.UpdateExpression) { - if len(updates) == 0 { - b.SetError(errNoUpdatedValuesProvided) - return - } - updateLen := len(updates) - for i, update := range updates { - d.Literal(b, update) - if i < updateLen-1 { - b.WriteRunes(d.dialectOptions.CommaRune) - } - } -} - -func (d *sqlDialect) updateFromSQL(b sb.SQLBuilder, ce exp.ColumnListExpression) { - if ce == nil || ce.IsEmpty() { - return - } - if d.dialectOptions.UseFromClauseForMultipleUpdateTables { - d.FromSQL(b, ce) - } -} - -func (d *sqlDialect) onConflictDoUpdateSQL(b sb.SQLBuilder, o exp.ConflictUpdateExpression) { - b.Write(d.dialectOptions.ConflictDoUpdateFragment) - update := o.Update() - if update == nil { - b.SetError(errConflictUpdateValuesRequired) - return - } - ue, err := exp.NewUpdateExpressions(update) - if err != nil { - b.SetError(err) - return - } - d.updateValuesSQL(b, ue...) - if b.Error() == nil && o.WhereClause() != nil { - if !d.dialectOptions.SupportsConflictUpdateWhere { - b.SetError(errUpsertWithWhereNotSupported(d.dialect)) - return - } - d.WhereSQL(b, o.WhereClause()) - } -} - -func (d *sqlDialect) joinConditionSQL(b sb.SQLBuilder, jc exp.JoinCondition) { - switch t := jc.(type) { - case exp.JoinOnCondition: - d.joinOnConditionSQL(b, t) - case exp.JoinUsingCondition: - d.joinUsingConditionSQL(b, t) - } -} - -func (d *sqlDialect) joinUsingConditionSQL(b sb.SQLBuilder, jc exp.JoinUsingCondition) { - b.Write(d.dialectOptions.UsingFragment). - WriteRunes(d.dialectOptions.LeftParenRune) - d.Literal(b, jc.Using()) - b.WriteRunes(d.dialectOptions.RightParenRune) -} - -func (d *sqlDialect) joinOnConditionSQL(b sb.SQLBuilder, jc exp.JoinOnCondition) { - b.Write(d.dialectOptions.OnFragment) - d.Literal(b, jc.On()) -} - -func (d *sqlDialect) reflectSQL(b sb.SQLBuilder, val interface{}) { - v := reflect.Indirect(reflect.ValueOf(val)) - valKind := v.Kind() - switch { - case util.IsInvalid(valKind): - d.literalNil(b) - case util.IsSlice(valKind): - if bs, ok := val.([]byte); ok { - d.Literal(b, bs) - return - } - d.sliceValueSQL(b, v) - case util.IsInt(valKind): - d.Literal(b, v.Int()) - case util.IsUint(valKind): - d.Literal(b, int64(v.Uint())) - case util.IsFloat(valKind): - d.Literal(b, v.Float()) - case util.IsString(valKind): - d.Literal(b, v.String()) - case util.IsBool(valKind): - d.Literal(b, v.Bool()) - default: - b.SetError(errors.NewEncodeError(val)) - } -} - -func (d *sqlDialect) expressionSQL(b sb.SQLBuilder, expression exp.Expression) { - switch e := expression.(type) { - case exp.ColumnListExpression: - d.columnListSQL(b, e) - case exp.ExpressionList: - d.expressionListSQL(b, e) - case exp.LiteralExpression: - d.literalExpressionSQL(b, e) - case exp.IdentifierExpression: - d.quoteIdentifier(b, e) - case exp.AliasedExpression: - d.aliasedExpressionSQL(b, e) - case exp.BooleanExpression: - d.booleanExpressionSQL(b, e) - case exp.RangeExpression: - d.rangeExpressionSQL(b, e) - case exp.OrderedExpression: - d.orderedExpressionSQL(b, e) - case exp.UpdateExpression: - d.updateExpressionSQL(b, e) - case exp.SQLFunctionExpression: - d.sqlFunctionExpressionSQL(b, e) - case exp.CastExpression: - d.castExpressionSQL(b, e) - case exp.AppendableExpression: - d.appendableExpressionSQL(b, e) - case exp.CommonTableExpression: - d.commonTableExpressionSQL(b, e) - case exp.CompoundExpression: - d.compoundExpressionSQL(b, e) - case exp.Ex: - d.expressionMapSQL(b, e) - case exp.ExOr: - d.expressionOrMapSQL(b, e) - default: - b.SetError(errUnsupportedExpressionType(e)) - } -} - -// Generates a placeholder (e.g. ?, $1) -func (d *sqlDialect) placeHolderSQL(b sb.SQLBuilder, i interface{}) { - b.WriteRunes(d.dialectOptions.PlaceHolderRune) - if d.dialectOptions.IncludePlaceholderNum { - b.WriteStrings(strconv.FormatInt(int64(b.CurrentArgPosition()), 10)) - } - b.WriteArg(i) -} - -// Generates creates the sql for a sub select on a Dataset -func (d *sqlDialect) appendableExpressionSQL(b sb.SQLBuilder, a exp.AppendableExpression) { - b.WriteRunes(d.dialectOptions.LeftParenRune) - a.AppendSQL(b) - b.WriteRunes(d.dialectOptions.RightParenRune) - c := a.GetClauses() - if c != nil { - alias := c.Alias() - if alias != nil { - b.Write(d.dialectOptions.AsFragment) - d.Literal(b, alias) - } - } -} - -// Quotes an identifier (e.g. "col", "table"."col" -func (d *sqlDialect) quoteIdentifier(b sb.SQLBuilder, ident exp.IdentifierExpression) { - if ident.IsEmpty() { - b.SetError(errEmptyIdentifier) - return - } - schema, table, col := ident.GetSchema(), ident.GetTable(), ident.GetCol() - if schema != d.dialectOptions.EmptyString { - b.WriteRunes(d.dialectOptions.QuoteRune). - WriteStrings(schema). - WriteRunes(d.dialectOptions.QuoteRune) - } - if table != d.dialectOptions.EmptyString { - if schema != d.dialectOptions.EmptyString { - b.WriteRunes(d.dialectOptions.PeriodRune) - } - b.WriteRunes(d.dialectOptions.QuoteRune). - WriteStrings(table). - WriteRunes(d.dialectOptions.QuoteRune) - } - switch t := col.(type) { - case nil: - case string: - if col != d.dialectOptions.EmptyString { - if table != d.dialectOptions.EmptyString || schema != d.dialectOptions.EmptyString { - b.WriteRunes(d.dialectOptions.PeriodRune) - } - b.WriteRunes(d.dialectOptions.QuoteRune). - WriteStrings(t). - WriteRunes(d.dialectOptions.QuoteRune) - } - case exp.LiteralExpression: - if table != d.dialectOptions.EmptyString || schema != d.dialectOptions.EmptyString { - b.WriteRunes(d.dialectOptions.PeriodRune) - } - d.Literal(b, t) - default: - b.SetError(errUnsupportedIdentifierExpression(col)) - } -} - -// Generates SQL NULL value -func (d *sqlDialect) literalNil(b sb.SQLBuilder) { - b.Write(d.dialectOptions.Null) -} - -// Generates SQL bool literal, (e.g. TRUE, FALSE, mysql 1, 0, sqlite3 1, 0) -func (d *sqlDialect) literalBool(b sb.SQLBuilder, bl bool) { - if b.IsPrepared() { - d.placeHolderSQL(b, bl) - return - } - if bl { - b.Write(d.dialectOptions.True) - } else { - b.Write(d.dialectOptions.False) - } -} - -// Generates SQL for a time.Time value -func (d *sqlDialect) literalTime(b sb.SQLBuilder, t time.Time) { - if b.IsPrepared() { - d.placeHolderSQL(b, t) - return - } - d.Literal(b, t.Format(d.dialectOptions.TimeFormat)) -} - -// Generates SQL for a Float Value -func (d *sqlDialect) literalFloat(b sb.SQLBuilder, f float64) { - if b.IsPrepared() { - d.placeHolderSQL(b, f) - return - } - b.WriteStrings(strconv.FormatFloat(f, 'f', -1, 64)) -} - -// Generates SQL for an int value -func (d *sqlDialect) literalInt(b sb.SQLBuilder, i int64) { - if b.IsPrepared() { - d.placeHolderSQL(b, i) - return - } - b.WriteStrings(strconv.FormatInt(i, 10)) -} - -// Generates SQL for a string -func (d *sqlDialect) literalString(b sb.SQLBuilder, s string) { - if b.IsPrepared() { - d.placeHolderSQL(b, s) - return - } - b.WriteRunes(d.dialectOptions.StringQuote) - for _, char := range s { - if e, ok := d.dialectOptions.EscapedRunes[char]; ok { - b.Write(e) - } else { - b.WriteRunes(char) - } - } - - b.WriteRunes(d.dialectOptions.StringQuote) -} - -// Generates SQL for a slice of bytes -func (d *sqlDialect) literalBytes(b sb.SQLBuilder, bs []byte) { - if b.IsPrepared() { - d.placeHolderSQL(b, bs) - return - } - b.WriteRunes(d.dialectOptions.StringQuote) - i := 0 - for len(bs) > 0 { - char, l := utf8.DecodeRune(bs) - if e, ok := d.dialectOptions.EscapedRunes[char]; ok { - b.Write(e) - } else { - b.WriteRunes(char) - } - i++ - bs = bs[l:] - } - b.WriteRunes(d.dialectOptions.StringQuote) -} - -// Generates SQL for a slice of values (e.g. []int64{1,2,3,4} -> (1,2,3,4) -func (d *sqlDialect) sliceValueSQL(b sb.SQLBuilder, slice reflect.Value) { - b.WriteRunes(d.dialectOptions.LeftParenRune) - for i, l := 0, slice.Len(); i < l; i++ { - d.Literal(b, slice.Index(i).Interface()) - if i < l-1 { - b.WriteRunes(d.dialectOptions.CommaRune, d.dialectOptions.SpaceRune) - } - } - b.WriteRunes(d.dialectOptions.RightParenRune) -} - -// Generates SQL for an AliasedExpression (e.g. I("a").As("b") -> "a" AS "b") -func (d *sqlDialect) aliasedExpressionSQL(b sb.SQLBuilder, aliased exp.AliasedExpression) { - d.Literal(b, aliased.Aliased()) - b.Write(d.dialectOptions.AsFragment) - d.Literal(b, aliased.GetAs()) -} - -// Generates SQL for a BooleanExpresion (e.g. I("a").Eq(2) -> "a" = 2) -func (d *sqlDialect) booleanExpressionSQL(b sb.SQLBuilder, operator exp.BooleanExpression) { - b.WriteRunes(d.dialectOptions.LeftParenRune) - d.Literal(b, operator.LHS()) - b.WriteRunes(d.dialectOptions.SpaceRune) - operatorOp := operator.Op() - if val, ok := d.dialectOptions.BooleanOperatorLookup[operatorOp]; ok { - b.Write(val) - } else { - b.SetError(errUnsupportedBooleanExpressionOperator(operatorOp)) - return - } - rhs := operator.RHS() - if (operatorOp == exp.IsOp || operatorOp == exp.IsNotOp) && d.dialectOptions.UseLiteralIsBools { - if rhs == true { - rhs = TrueLiteral - } else if rhs == false { - rhs = FalseLiteral - } - } - b.WriteRunes(d.dialectOptions.SpaceRune) - d.Literal(b, rhs) - b.WriteRunes(d.dialectOptions.RightParenRune) -} - -// Generates SQL for a RangeExpresion (e.g. I("a").Between(RangeVal{Start:2,End:5}) -> "a" BETWEEN 2 AND 5) -func (d *sqlDialect) rangeExpressionSQL(b sb.SQLBuilder, operator exp.RangeExpression) { - b.WriteRunes(d.dialectOptions.LeftParenRune) - d.Literal(b, operator.LHS()) - b.WriteRunes(d.dialectOptions.SpaceRune) - operatorOp := operator.Op() - if val, ok := d.dialectOptions.RangeOperatorLookup[operatorOp]; ok { - b.Write(val) - } else { - b.SetError(errUnsupportedRangeExpressionOperator(operatorOp)) - return - } - rhs := operator.RHS() - b.WriteRunes(d.dialectOptions.SpaceRune) - d.Literal(b, rhs.Start()) - b.Write(d.dialectOptions.AndFragment) - d.Literal(b, rhs.End()) - b.WriteRunes(d.dialectOptions.RightParenRune) -} - -// Generates SQL for an OrderedExpression (e.g. I("a").Asc() -> "a" ASC) -func (d *sqlDialect) orderedExpressionSQL(b sb.SQLBuilder, order exp.OrderedExpression) { - d.Literal(b, order.SortExpression()) - if order.IsAsc() { - b.Write(d.dialectOptions.AscFragment) - } else { - b.Write(d.dialectOptions.DescFragment) - } - switch order.NullSortType() { - case exp.NullsFirstSortType: - b.Write(d.dialectOptions.NullsFirstFragment) - case exp.NullsLastSortType: - b.Write(d.dialectOptions.NullsLastFragment) - } -} - -// Generates SQL for an ExpressionList (e.g. And(I("a").Eq("a"), I("b").Eq("b")) -> (("a" = 'a') AND ("b" = 'b'))) -func (d *sqlDialect) expressionListSQL(b sb.SQLBuilder, expressionList exp.ExpressionList) { - if expressionList.IsEmpty() { - return - } - var op []byte - if expressionList.Type() == exp.AndType { - op = d.dialectOptions.AndFragment - } else { - op = d.dialectOptions.OrFragment - } - exps := expressionList.Expressions() - expLen := len(exps) - 1 - needsAppending := expLen > 0 - if needsAppending { - b.WriteRunes(d.dialectOptions.LeftParenRune) - } else { - d.Literal(b, exps[0]) - return - } - for i, e := range exps { - d.Literal(b, e) - if i < expLen { - b.Write(op) - } - } - b.WriteRunes(d.dialectOptions.RightParenRune) -} - -// Generates SQL for a ColumnListExpression -func (d *sqlDialect) columnListSQL(b sb.SQLBuilder, columnList exp.ColumnListExpression) { - cols := columnList.Columns() - colLen := len(cols) - for i, col := range cols { - d.Literal(b, col) - if i < colLen-1 { - b.WriteRunes(d.dialectOptions.CommaRune, d.dialectOptions.SpaceRune) - } - } -} - -// Generates SQL for an UpdateEpxresion -func (d *sqlDialect) updateExpressionSQL(b sb.SQLBuilder, update exp.UpdateExpression) { - d.Literal(b, update.Col()) - b.WriteRunes(d.dialectOptions.SetOperatorRune) - d.Literal(b, update.Val()) -} - -// Generates SQL for a LiteralExpression -// L("a + b") -> a + b -// L("a = ?", 1) -> a = 1 -func (d *sqlDialect) literalExpressionSQL(b sb.SQLBuilder, literal exp.LiteralExpression) { - lit := literal.Literal() - args := literal.Args() - argsLen := len(args) - if argsLen > 0 { - currIndex := 0 - for _, char := range lit { - if char == replacementRune && currIndex < argsLen { - d.Literal(b, args[currIndex]) - currIndex++ - } else { - b.WriteRunes(char) - } - } - } else { - b.WriteStrings(lit) - } -} - -// Generates SQL for a SQLFunctionExpression -// COUNT(I("a")) -> COUNT("a") -func (d *sqlDialect) sqlFunctionExpressionSQL(b sb.SQLBuilder, sqlFunc exp.SQLFunctionExpression) { - b.WriteStrings(sqlFunc.Name()) - d.Literal(b, sqlFunc.Args()) -} - -// Generates SQL for a CastExpression -// I("a").Cast("NUMERIC") -> CAST("a" AS NUMERIC) -func (d *sqlDialect) castExpressionSQL(b sb.SQLBuilder, cast exp.CastExpression) { - b.Write(d.dialectOptions.CastFragment).WriteRunes(d.dialectOptions.LeftParenRune) - d.Literal(b, cast.Casted()) - b.Write(d.dialectOptions.AsFragment) - d.Literal(b, cast.Type()) - b.WriteRunes(d.dialectOptions.RightParenRune) -} - -// Generates SQL for a CommonTableExpression -func (d *sqlDialect) commonTableExpressionSQL(b sb.SQLBuilder, cte exp.CommonTableExpression) { - d.Literal(b, cte.Name()) - b.Write(d.dialectOptions.AsFragment) - d.Literal(b, cte.SubQuery()) -} - -// Generates SQL for a CompoundExpression -func (d *sqlDialect) compoundExpressionSQL(b sb.SQLBuilder, compound exp.CompoundExpression) { - switch compound.Type() { - case exp.UnionCompoundType: - b.Write(d.dialectOptions.UnionFragment) - case exp.UnionAllCompoundType: - b.Write(d.dialectOptions.UnionAllFragment) - case exp.IntersectCompoundType: - b.Write(d.dialectOptions.IntersectFragment) - case exp.IntersectAllCompoundType: - b.Write(d.dialectOptions.IntersectAllFragment) - } - if d.dialectOptions.WrapCompoundsInParens { - b.WriteRunes(d.dialectOptions.LeftParenRune) - compound.RHS().AppendSQL(b) - b.WriteRunes(d.dialectOptions.RightParenRune) - } else { - compound.RHS().AppendSQL(b) - } - -} - -func (d *sqlDialect) expressionMapSQL(b sb.SQLBuilder, ex exp.Ex) { - expressionList, err := ex.ToExpressions() - if err != nil { - b.SetError(err) - return - } - d.Literal(b, expressionList) -} - -func (d *sqlDialect) expressionOrMapSQL(b sb.SQLBuilder, ex exp.ExOr) { - expressionList, err := ex.ToExpressions() - if err != nil { - b.SetError(err) - return - } - d.Literal(b, expressionList) + d.truncateGen.Generate(b, clauses) } diff --git a/sql_dialect_test.go b/sql_dialect_test.go index 1b86fcaa..4ad2c06e 100644 --- a/sql_dialect_test.go +++ b/sql_dialect_test.go @@ -1,2623 +1,92 @@ package goqu import ( - "database/sql/driver" - "fmt" - "regexp" "testing" - "time" "github.com/doug-martin/goqu/v8/exp" - "github.com/doug-martin/goqu/v8/internal/errors" "github.com/doug-martin/goqu/v8/internal/sb" + "github.com/doug-martin/goqu/v8/sqlgen/mocks" "github.com/stretchr/testify/suite" ) -var emptyArgs = make([]interface{}, 0) - -type testAppendableExpression struct { - exp.AppendableExpression - sql string - args []interface{} - err error - clauses exp.SelectClauses -} - -func newTestAppendableExpression(sql string, args []interface{}, err error, clauses exp.SelectClauses) exp.AppendableExpression { - if clauses == nil { - clauses = exp.NewSelectClauses() - } - return &testAppendableExpression{sql: sql, args: args, err: err, clauses: clauses} -} - -func (tae *testAppendableExpression) Expression() exp.Expression { - return tae -} - -func (tae *testAppendableExpression) GetClauses() exp.SelectClauses { - return tae.clauses -} - -func (tae *testAppendableExpression) Clone() exp.Expression { - return tae -} - -func (tae *testAppendableExpression) AppendSQL(b sb.SQLBuilder) { - if tae.err != nil { - b.SetError(tae.err) - return - } - b.WriteStrings(tae.sql) - if len(tae.args) > 0 { - b.WriteArg(tae.args...) - } -} - type dialectTestSuite struct { suite.Suite } -func (dts *dialectTestSuite) assertNotPreparedSQL(b sb.SQLBuilder, expectedSQL string) { - actualSQL, actualArgs, err := b.ToSQL() - dts.NoError(err) - dts.Equal(expectedSQL, actualSQL) - dts.Empty(actualArgs) -} - -func (dts *dialectTestSuite) assertPreparedSQL( - b sb.SQLBuilder, - expectedSQL string, - expectedArgs []interface{}, -) { - actualSQL, actualArgs, err := b.ToSQL() - dts.NoError(err) - dts.Equal(expectedSQL, actualSQL) - dts.Equal(expectedArgs, actualArgs) -} - -func (dts *dialectTestSuite) assertErrorSQL(b sb.SQLBuilder, errMsg string) { - actualSQL, actualArgs, err := b.ToSQL() - dts.EqualError(err, errMsg) - dts.Empty(actualSQL) - dts.Empty(actualArgs) -} - -func (dts *dialectTestSuite) TestSupportsReturn() { - opts := DefaultDialectOptions() - opts.SupportsReturn = true - d := sqlDialect{dialect: "test", dialectOptions: opts} - - opts2 := DefaultDialectOptions() - opts2.SupportsReturn = false - d2 := sqlDialect{dialect: "test", dialectOptions: opts2} - - dts.True(d.SupportsReturn()) - dts.False(d2.SupportsReturn()) -} - -func (dts *dialectTestSuite) TestSupportsOrderByOnUpdate() { - opts := DefaultDialectOptions() - opts.SupportsOrderByOnUpdate = true - d := sqlDialect{dialect: "test", dialectOptions: opts} - - opts2 := DefaultDialectOptions() - opts2.SupportsOrderByOnUpdate = false - d2 := sqlDialect{dialect: "test", dialectOptions: opts2} - - dts.True(d.SupportsOrderByOnUpdate()) - dts.False(d2.SupportsOrderByOnUpdate()) -} - -func (dts *dialectTestSuite) TestSupportsLimitOnUpdate() { - opts := DefaultDialectOptions() - opts.SupportsLimitOnUpdate = true - d := sqlDialect{dialect: "test", dialectOptions: opts} - - opts2 := DefaultDialectOptions() - opts2.SupportsLimitOnUpdate = false - d2 := sqlDialect{dialect: "test", dialectOptions: opts2} - - dts.True(d.SupportsLimitOnUpdate()) - dts.False(d2.SupportsLimitOnUpdate()) -} - -func (dts *dialectTestSuite) TestSupportsOrderByOnDelete() { - opts := DefaultDialectOptions() - opts.SupportsOrderByOnDelete = true - d := sqlDialect{dialect: "test", dialectOptions: opts} - - opts2 := DefaultDialectOptions() - opts2.SupportsOrderByOnDelete = false - d2 := sqlDialect{dialect: "test", dialectOptions: opts2} - - dts.True(d.SupportsOrderByOnDelete()) - dts.False(d2.SupportsOrderByOnDelete()) -} - -func (dts *dialectTestSuite) TestSupportsLimitOnDelete() { - opts := DefaultDialectOptions() - opts.SupportsLimitOnDelete = true - d := sqlDialect{dialect: "test", dialectOptions: opts} - - opts2 := DefaultDialectOptions() - opts2.SupportsLimitOnDelete = false - d2 := sqlDialect{dialect: "test", dialectOptions: opts2} - - dts.True(d.SupportsLimitOnDelete()) - dts.False(d2.SupportsLimitOnDelete()) -} - -func (dts *dialectTestSuite) TestToTruncateSQL() { - opts := DefaultDialectOptions() - d := sqlDialect{dialect: "test", dialectOptions: opts} - - opts2 := DefaultDialectOptions() - opts2.TruncateClause = []byte("truncate") - d2 := sqlDialect{dialect: "test", dialectOptions: opts2} - - tables := exp.NewColumnListExpression("a") - tc := exp.NewTruncateClauses().SetTable(tables) - b := sb.NewSQLBuilder(false) - - d.ToTruncateSQL(b, tc) - dts.assertNotPreparedSQL(b, `TRUNCATE "a"`) - - d2.ToTruncateSQL(b.Clear(), tc) - dts.assertNotPreparedSQL(b, `truncate "a"`) - - b = sb.NewSQLBuilder(true) - d.ToTruncateSQL(b, tc) - dts.assertPreparedSQL(b, `TRUNCATE "a"`, emptyArgs) - - d2.ToTruncateSQL(b.Clear(), tc) - dts.assertPreparedSQL(b, `truncate "a"`, emptyArgs) -} - -func (dts *dialectTestSuite) TestToTruncateSQL_UnsupportedFragment() { - opts := DefaultDialectOptions() - opts.TruncateSQLOrder = []SQLFragmentType{UpdateBeginSQLFragment} - d := sqlDialect{dialect: "test", dialectOptions: opts} - - b := sb.NewSQLBuilder(true) - d.ToTruncateSQL(b, exp.NewTruncateClauses().SetTable(exp.NewColumnListExpression("a"))) - dts.assertErrorSQL(b, `goqu: unsupported TRUNCATE SQL fragment UpdateBeginSQLFragment`) -} - -func (dts *dialectTestSuite) TestToTruncateSQL_WithErroredBuilder() { - opts := DefaultDialectOptions() - opts.TruncateSQLOrder = []SQLFragmentType{UpdateBeginSQLFragment} - d := sqlDialect{dialect: "test", dialectOptions: opts} - - b := sb.NewSQLBuilder(true).SetError(errors.New("expected error")) - d.ToTruncateSQL(b, exp.NewTruncateClauses().SetTable(exp.NewColumnListExpression("a"))) - dts.assertErrorSQL(b, `goqu: expected error`) -} - -func (dts *dialectTestSuite) TestToTruncateSQL_withoutTable() { - opts := DefaultDialectOptions() - d := sqlDialect{dialect: "test", dialectOptions: opts} - b := sb.NewSQLBuilder(false) - - d.ToTruncateSQL(b, exp.NewTruncateClauses()) - dts.assertErrorSQL(b, "goqu: no source found when generating truncate sql") -} - -func (dts *dialectTestSuite) TestToTruncateSQL_WithCascade() { - opts := DefaultDialectOptions() - d := sqlDialect{dialect: "test", dialectOptions: opts} - - opts2 := DefaultDialectOptions() - opts2.TruncateClause = []byte("truncate") - opts2.CascadeFragment = []byte(" cascade") - d2 := sqlDialect{dialect: "test", dialectOptions: opts2} - - tables := exp.NewColumnListExpression("a") - tc := exp.NewTruncateClauses().SetTable(tables) - b := sb.NewSQLBuilder(false) - - d.ToTruncateSQL(b.Clear(), tc.SetOptions(exp.TruncateOptions{Cascade: true})) - dts.assertNotPreparedSQL(b, `TRUNCATE "a" CASCADE`) - - d2.ToTruncateSQL(b.Clear(), tc.SetOptions(exp.TruncateOptions{Cascade: true})) - dts.assertNotPreparedSQL(b, `truncate "a" cascade`) - - b = sb.NewSQLBuilder(true) - - d.ToTruncateSQL(b.Clear(), tc.SetOptions(exp.TruncateOptions{Cascade: true})) - dts.assertPreparedSQL(b, `TRUNCATE "a" CASCADE`, emptyArgs) - - d2.ToTruncateSQL(b.Clear(), tc.SetOptions(exp.TruncateOptions{Cascade: true})) - dts.assertPreparedSQL(b, `truncate "a" cascade`, emptyArgs) -} - -func (dts *dialectTestSuite) TestToTruncateSQL_WithRestrict() { - opts := DefaultDialectOptions() - d := sqlDialect{dialect: "test", dialectOptions: opts} - - opts2 := DefaultDialectOptions() - opts2.TruncateClause = []byte("truncate") - opts2.RestrictFragment = []byte(" restrict") - d2 := sqlDialect{dialect: "test", dialectOptions: opts2} - - tables := exp.NewColumnListExpression("a") - tc := exp.NewTruncateClauses().SetTable(tables) - b := sb.NewSQLBuilder(false) - - d.ToTruncateSQL(b.Clear(), tc.SetOptions(exp.TruncateOptions{Restrict: true})) - dts.assertNotPreparedSQL(b, `TRUNCATE "a" RESTRICT`) - - d2.ToTruncateSQL(b.Clear(), tc.SetOptions(exp.TruncateOptions{Restrict: true})) - dts.assertNotPreparedSQL(b, `truncate "a" restrict`) - - b = sb.NewSQLBuilder(true) - - d.ToTruncateSQL(b.Clear(), tc.SetOptions(exp.TruncateOptions{Restrict: true})) - dts.assertPreparedSQL(b, `TRUNCATE "a" RESTRICT`, emptyArgs) - - d2.ToTruncateSQL(b.Clear(), tc.SetOptions(exp.TruncateOptions{Restrict: true})) - dts.assertPreparedSQL(b, `truncate "a" restrict`, emptyArgs) -} - -func (dts *dialectTestSuite) TestToTruncateSQL_WithRestart() { +func (dts *dialectTestSuite) TestDialect() { opts := DefaultDialectOptions() - d := sqlDialect{dialect: "test", dialectOptions: opts} - - opts2 := DefaultDialectOptions() - opts2.TruncateClause = []byte("truncate") - opts2.IdentityFragment = []byte(" identity") - d2 := sqlDialect{dialect: "test", dialectOptions: opts2} - - tables := exp.NewColumnListExpression("a") - tc := exp.NewTruncateClauses().SetTable(tables) - b := sb.NewSQLBuilder(false) - - d.ToTruncateSQL(b.Clear(), tc.SetOptions(exp.TruncateOptions{Identity: "restart"})) - dts.assertNotPreparedSQL(b, `TRUNCATE "a" RESTART IDENTITY`) + sm := new(mocks.SelectSQLGenerator) + d := sqlDialect{dialect: "test", dialectOptions: opts, selectGen: sm} - d2.ToTruncateSQL(b.Clear(), tc.SetOptions(exp.TruncateOptions{Identity: "restart"})) - dts.assertNotPreparedSQL(b, `truncate "a" RESTART identity`) - - b = sb.NewSQLBuilder(true) - - d.ToTruncateSQL(b.Clear(), tc.SetOptions(exp.TruncateOptions{Identity: "restart"})) - dts.assertPreparedSQL(b, `TRUNCATE "a" RESTART IDENTITY`, emptyArgs) - - d2.ToTruncateSQL(b.Clear(), tc.SetOptions(exp.TruncateOptions{Identity: "restart"})) - dts.assertPreparedSQL(b, `truncate "a" RESTART identity`, emptyArgs) + dts.Equal("test", d.Dialect()) } -func (dts *dialectTestSuite) TestToInsertSQL_UnsupportedFragment() { +func (dts *dialectTestSuite) TestToSelectSQL() { opts := DefaultDialectOptions() - opts.InsertSQLOrder = []SQLFragmentType{UpdateBeginSQLFragment} - d := sqlDialect{dialect: "test", dialectOptions: opts} + sm := new(mocks.SelectSQLGenerator) + d := sqlDialect{dialect: "test", dialectOptions: opts, selectGen: sm} b := sb.NewSQLBuilder(true) - ic := exp.NewInsertClauses(). - SetInto(exp.NewIdentifierExpression("", "test", "")) - d.ToInsertSQL(b, ic) - dts.assertErrorSQL(b, `goqu: unsupported INSERT SQL fragment UpdateBeginSQLFragment`) -} - -func (dts *dialectTestSuite) TestToInsertSQL_empty() { - opts := DefaultDialectOptions() - d := sqlDialect{dialect: "test", dialectOptions: opts} - - opts2 := DefaultDialectOptions() - opts2.DefaultValuesFragment = []byte(" default values") - d2 := sqlDialect{dialect: "test", dialectOptions: opts2} - ic := exp.NewInsertClauses(). - SetInto(exp.NewIdentifierExpression("", "test", "")) - - b := sb.NewSQLBuilder(false) - d.ToInsertSQL(b, ic) - dts.assertNotPreparedSQL(b, `INSERT INTO "test" DEFAULT VALUES`) - - d2.ToInsertSQL(b.Clear(), ic) - dts.assertNotPreparedSQL(b, `INSERT INTO "test" default values`) -} - -func (dts *dialectTestSuite) TestToInsertSQL_nilValues() { - opts := DefaultDialectOptions() - d := sqlDialect{dialect: "test", dialectOptions: opts} - - opts2 := DefaultDialectOptions() - d2 := sqlDialect{dialect: "test", dialectOptions: opts2} - - ic := exp.NewInsertClauses(). - SetInto(exp.NewIdentifierExpression("", "test", "")). - SetCols(exp.NewColumnListExpression("a")). - SetVals([][]interface{}{ - {nil}, - }) - - b := sb.NewSQLBuilder(false) - d.ToInsertSQL(b, ic) - dts.assertNotPreparedSQL(b, `INSERT INTO "test" ("a") VALUES (NULL)`) - - d2.ToInsertSQL(b.Clear(), ic) - dts.assertNotPreparedSQL(b, `INSERT INTO "test" ("a") VALUES (NULL)`) -} - -func (dts *dialectTestSuite) TestToInsertSQL_colsAndVals() { - opts := DefaultDialectOptions() - opts.LeftParenRune = '{' - opts.RightParenRune = '}' - opts.ValuesFragment = []byte(" values ") - opts.LeftParenRune = '{' - opts.RightParenRune = '}' - opts.CommaRune = ';' - opts.PlaceHolderRune = '#' - d := sqlDialect{dialect: "test", dialectOptions: opts} - - ic := exp.NewInsertClauses(). - SetInto(exp.NewIdentifierExpression("", "test", "")). - SetCols(exp.NewColumnListExpression("a", "b")). - SetVals([][]interface{}{ - {"a1", "b1"}, - {"a2", "b2"}, - {"a3", "b3"}, - }) - - bic := ic.SetCols(exp.NewColumnListExpression("a", "b")). - SetVals([][]interface{}{ - {"a1"}, - {"a2", "b2"}, - {"a3", "b3"}, - }) - - b := sb.NewSQLBuilder(false) - d.ToInsertSQL(b, ic) - dts.assertNotPreparedSQL(b, `INSERT INTO "test" {"a"; "b"} values {'a1'; 'b1'}; {'a2'; 'b2'}; {'a3'; 'b3'}`) - - b = sb.NewSQLBuilder(true) - d.ToInsertSQL(b, ic) - dts.assertPreparedSQL(b, `INSERT INTO "test" {"a"; "b"} values {#; #}; {#; #}; {#; #}`, []interface{}{ - "a1", "b1", "a2", "b2", "a3", "b3", - }) - - d.ToInsertSQL(b.Clear(), bic) - dts.assertErrorSQL(b, "goqu: rows with different value length expected 1 got 2") -} - -func (dts *dialectTestSuite) TestToInsertSQL_withNoInto() { - opts := DefaultDialectOptions() - opts.LeftParenRune = '{' - opts.RightParenRune = '}' - opts.ValuesFragment = []byte(" values ") - opts.LeftParenRune = '{' - opts.RightParenRune = '}' - opts.CommaRune = ';' - opts.PlaceHolderRune = '#' - d := sqlDialect{dialect: "test", dialectOptions: opts} - - ic := exp.NewInsertClauses(). - SetCols(exp.NewColumnListExpression("a", "b")). - SetVals([][]interface{}{ - {"a1", "b1"}, - {"a2", "b2"}, - {"a3", "b3"}, - }) - - b := sb.NewSQLBuilder(false) - d.ToInsertSQL(b.Clear(), ic) - dts.assertErrorSQL(b, "goqu: no source found when generating insert sql") -} + sc := exp.NewSelectClauses() + sm.On("Generate", b, sc).Return(nil).Once() -func (dts *dialectTestSuite) TestToInsertSQL_withRows() { - opts := DefaultDialectOptions() - opts.LeftParenRune = '{' - opts.RightParenRune = '}' - opts.ValuesFragment = []byte(" values ") - opts.LeftParenRune = '{' - opts.RightParenRune = '}' - opts.CommaRune = ';' - opts.PlaceHolderRune = '#' - d := sqlDialect{dialect: "test", dialectOptions: opts} - - ic := exp.NewInsertClauses(). - SetInto(exp.NewIdentifierExpression("", "test", "")). - SetRows([]interface{}{ - exp.Record{"a": "a1", "b": "b1"}, - exp.Record{"a": "a2", "b": "b2"}, - exp.Record{"a": "a3", "b": "b3"}, - }) - - bic := ic. - SetRows([]interface{}{ - exp.Record{"a": "a1"}, - exp.Record{"a": "a2", "b": "b2"}, - exp.Record{"a": "a3", "b": "b3"}, - }) - - b := sb.NewSQLBuilder(false) - d.ToInsertSQL(b, ic) - dts.assertNotPreparedSQL(b, `INSERT INTO "test" {"a"; "b"} values {'a1'; 'b1'}; {'a2'; 'b2'}; {'a3'; 'b3'}`) - - b = sb.NewSQLBuilder(true) - d.ToInsertSQL(b, ic) - dts.assertPreparedSQL(b, `INSERT INTO "test" {"a"; "b"} values {#; #}; {#; #}; {#; #}`, []interface{}{ - "a1", "b1", "a2", "b2", "a3", "b3", - }) - - d.ToInsertSQL(b.Clear(), bic) - dts.assertErrorSQL(b, "goqu: rows with different value length expected 1 got 2") -} - -func (dts *dialectTestSuite) TestToInsertSQL_withRowsAppendableExpression() { - opts := DefaultDialectOptions() - opts.LeftParenRune = '{' - opts.RightParenRune = '}' - opts.ValuesFragment = []byte(" values ") - opts.LeftParenRune = '{' - opts.RightParenRune = '}' - opts.CommaRune = ';' - opts.PlaceHolderRune = '#' - d := sqlDialect{dialect: "test", dialectOptions: opts} - - ic := exp.NewInsertClauses(). - SetInto(exp.NewIdentifierExpression("", "test", "")). - SetRows([]interface{}{newTestAppendableExpression(`select * from "other"`, emptyArgs, nil, nil)}) - - b := sb.NewSQLBuilder(false) - d.ToInsertSQL(b, ic) - dts.assertNotPreparedSQL(b, `INSERT INTO "test" select * from "other"`) - - b = sb.NewSQLBuilder(true) - d.ToInsertSQL(b, ic) - dts.assertPreparedSQL(b, `INSERT INTO "test" select * from "other"`, emptyArgs) -} - -func (dts *dialectTestSuite) TestToInsertSQL_withFrom() { - opts := DefaultDialectOptions() - opts.LeftParenRune = '{' - opts.RightParenRune = '}' - opts.ValuesFragment = []byte(" values ") - opts.LeftParenRune = '{' - opts.RightParenRune = '}' - opts.CommaRune = ';' - opts.PlaceHolderRune = '#' - d := sqlDialect{dialect: "test", dialectOptions: opts} - - ic := exp.NewInsertClauses(). - SetInto(exp.NewIdentifierExpression("", "test", "")). - SetFrom(newTestAppendableExpression(`select c, d from test where a = 'b'`, nil, nil, nil)) - - b := sb.NewSQLBuilder(false) - d.ToInsertSQL(b.Clear(), ic) - dts.assertNotPreparedSQL(b, `INSERT INTO "test" select c, d from test where a = 'b'`) - - b = sb.NewSQLBuilder(true) - d.ToInsertSQL(b.Clear(), ic) - dts.assertPreparedSQL(b, `INSERT INTO "test" select c, d from test where a = 'b'`, emptyArgs) - - ic = ic.SetCols(exp.NewColumnListExpression("a", "b")) - - b = sb.NewSQLBuilder(false) - d.ToInsertSQL(b.Clear(), ic) - dts.assertNotPreparedSQL(b, `INSERT INTO "test" {"a"; "b"} select c, d from test where a = 'b'`) - - b = sb.NewSQLBuilder(true) - d.ToInsertSQL(b.Clear(), ic) - dts.assertPreparedSQL(b, `INSERT INTO "test" {"a"; "b"} select c, d from test where a = 'b'`, emptyArgs) -} - -func (dts *dialectTestSuite) TestToInsertSQL_onConflict() { - opts := DefaultDialectOptions() - // make sure the fragments are used - opts.ConflictFragment = []byte(" on conflict") - opts.ConflictDoNothingFragment = []byte(" do nothing") - opts.ConflictDoUpdateFragment = []byte(" do update set ") - d := sqlDialect{dialect: "test", dialectOptions: opts} - - icnoc := exp.NewInsertClauses(). - SetInto(exp.NewIdentifierExpression("", "test", "")). - SetCols(exp.NewColumnListExpression("a")). - SetVals([][]interface{}{ - {"a1"}, - {"a2"}, - {"a3"}, - }) - - icdn := icnoc.SetOnConflict(DoNothing()) - icdu := icnoc.SetOnConflict(DoUpdate("test", exp.Record{"a": "b"})) - icdoc := icnoc.SetOnConflict(DoUpdate("on constraint test", exp.Record{"a": "b"})) - icduw := icnoc.SetOnConflict( - exp.NewDoUpdateConflictExpression("test", exp.Record{"a": "b"}).Where(exp.Ex{"foo": true}), - ) - - b := sb.NewSQLBuilder(false) - d.ToInsertSQL(b, icnoc) - dts.assertNotPreparedSQL(b, `INSERT INTO "test" ("a") VALUES ('a1'), ('a2'), ('a3')`) - - d.ToInsertSQL(b.Clear(), icdn) - dts.assertNotPreparedSQL(b, `INSERT INTO "test" ("a") VALUES ('a1'), ('a2'), ('a3') on conflict do nothing`) - - d.ToInsertSQL(b.Clear(), icdu) - dts.assertNotPreparedSQL( - b, - `INSERT INTO "test" ("a") VALUES ('a1'), ('a2'), ('a3') on conflict (test) do update set "a"='b'`, - ) - - d.ToInsertSQL(b.Clear(), icdoc) - dts.assertNotPreparedSQL( - b, - `INSERT INTO "test" ("a") VALUES ('a1'), ('a2'), ('a3') on conflict on constraint test do update set "a"='b'`, - ) - - d.ToInsertSQL(b.Clear(), icduw) - dts.assertNotPreparedSQL(b, - `INSERT INTO "test" ("a") VALUES ('a1'), ('a2'), ('a3') on conflict (test) do update set "a"='b' WHERE ("foo" IS TRUE)`, - ) - - b = sb.NewSQLBuilder(true) - d.ToInsertSQL(b, icdn) - dts.assertPreparedSQL(b, `INSERT INTO "test" ("a") VALUES (?), (?), (?) on conflict do nothing`, []interface{}{ - "a1", "a2", "a3", - }) - - d.ToInsertSQL(b.Clear(), icdu) - dts.assertPreparedSQL( - b, - `INSERT INTO "test" ("a") VALUES (?), (?), (?) on conflict (test) do update set "a"=?`, - []interface{}{"a1", "a2", "a3", "b"}, - ) - - d.ToInsertSQL(b.Clear(), icduw) - dts.assertPreparedSQL( - b, - `INSERT INTO "test" ("a") VALUES (?), (?), (?) on conflict (test) do update set "a"=? WHERE ("foo" IS TRUE)`, - []interface{}{"a1", "a2", "a3", "b"}, - ) -} - -func (dts *dialectTestSuite) TestToInsertSQL_withSupportsInsertIgnoreSyntax() { - opts := DefaultDialectOptions() - // make sure the fragments are used - opts.SupportsInsertIgnoreSyntax = true - opts.InsertIgnoreClause = []byte("insert ignore into") - d := sqlDialect{dialect: "test", dialectOptions: opts} - - icnoc := exp.NewInsertClauses(). - SetInto(exp.NewIdentifierExpression("", "test", "")). - SetCols(exp.NewColumnListExpression("a")). - SetVals([][]interface{}{ - {"a1"}, - {"a2"}, - {"a3"}, - }) - - icdn := icnoc.SetOnConflict(DoNothing()) - icdu := icnoc.SetOnConflict(DoUpdate("test", exp.Record{"a": "b"})) - icdoc := icnoc.SetOnConflict(DoUpdate("on constraint test", exp.Record{"a": "b"})) - icduw := icnoc.SetOnConflict( - exp.NewDoUpdateConflictExpression("test", exp.Record{"a": "b"}).Where(exp.Ex{"foo": true}), - ) - - b := sb.NewSQLBuilder(false) - - d.ToInsertSQL(b.Clear(), icdu) - dts.assertNotPreparedSQL( - b, - `insert ignore into "test" ("a") VALUES ('a1'), ('a2'), ('a3') ON CONFLICT (test) DO UPDATE SET "a"='b'`, - ) - - d.ToInsertSQL(b.Clear(), icdoc) - dts.assertNotPreparedSQL( - b, - `insert ignore into "test" ("a") VALUES ('a1'), ('a2'), ('a3') ON CONFLICT on constraint test DO UPDATE SET "a"='b'`, - ) - - d.ToInsertSQL(b.Clear(), icduw) - dts.assertNotPreparedSQL(b, - `insert ignore into "test" ("a") VALUES ('a1'), ('a2'), ('a3') ON CONFLICT (test) DO UPDATE SET "a"='b' WHERE ("foo" IS TRUE)`, - ) - - b = sb.NewSQLBuilder(true) - d.ToInsertSQL(b, icdn) - dts.assertPreparedSQL(b, `insert ignore into "test" ("a") VALUES (?), (?), (?) ON CONFLICT DO NOTHING`, []interface{}{ - "a1", "a2", "a3", - }) - - d.ToInsertSQL(b.Clear(), icdu) - dts.assertPreparedSQL( - b, - `insert ignore into "test" ("a") VALUES (?), (?), (?) ON CONFLICT (test) DO UPDATE SET "a"=?`, - []interface{}{"a1", "a2", "a3", "b"}, - ) - - d.ToInsertSQL(b.Clear(), icduw) - dts.assertPreparedSQL( - b, - `insert ignore into "test" ("a") VALUES (?), (?), (?) ON CONFLICT (test) DO UPDATE SET "a"=? WHERE ("foo" IS TRUE)`, - []interface{}{"a1", "a2", "a3", "b"}, - ) + d.ToSelectSQL(b, sc) + sm.AssertExpectations(dts.T()) } -func (dts *dialectTestSuite) TestToInsertSQL_withCommonTables() { +func (dts *dialectTestSuite) TestToUpdateSQL() { opts := DefaultDialectOptions() - opts.WithFragment = []byte("with ") - opts.RecursiveFragment = []byte("recursive ") - d := sqlDialect{dialect: "test", dialectOptions: opts} - tse := newTestAppendableExpression("select * from foo", emptyArgs, nil, nil) - cte1 := exp.NewCommonTableExpression(false, "test_cte", tse) - cte2 := exp.NewCommonTableExpression(true, "test_cte", tse) - - ic := exp.NewInsertClauses(). - SetInto(exp.NewIdentifierExpression("", "test_cte", "")) - - b := sb.NewSQLBuilder(false) - - d.ToInsertSQL(b.Clear(), ic.CommonTablesAppend(cte1)) - dts.assertNotPreparedSQL(b, `with test_cte AS (select * from foo) INSERT INTO "test_cte" DEFAULT VALUES`) - - d.ToInsertSQL(b.Clear(), ic.CommonTablesAppend(cte2)) - dts.assertNotPreparedSQL(b, `with recursive test_cte AS (select * from foo) INSERT INTO "test_cte" DEFAULT VALUES`) - - d.ToInsertSQL(b.Clear(), ic.CommonTablesAppend(cte1).CommonTablesAppend(cte2)) - dts.assertNotPreparedSQL( - b, - `with recursive test_cte AS (select * from foo), test_cte AS (select * from foo) INSERT INTO "test_cte" DEFAULT VALUES`, - ) + um := new(mocks.UpdateSQLGenerator) + d := sqlDialect{dialect: "test", dialectOptions: opts, updateGen: um} - opts = DefaultDialectOptions() - opts.SupportsWithCTE = false - d = sqlDialect{dialect: "test", dialectOptions: opts} - - d.ToInsertSQL(b.Clear(), ic.CommonTablesAppend(cte1)) - dts.assertErrorSQL(b, "goqu: dialect does not support CTE WITH clause [dialect=test]") - - opts = DefaultDialectOptions() - opts.SupportsWithCTERecursive = false - d = sqlDialect{dialect: "test", dialectOptions: opts} - - d.ToInsertSQL(b.Clear(), ic.CommonTablesAppend(cte2)) - dts.assertErrorSQL(b, "goqu: dialect does not support CTE WITH RECURSIVE clause [dialect=test]") - - d.ToInsertSQL(b.Clear(), ic.CommonTablesAppend(cte1)) - dts.assertNotPreparedSQL(b, `WITH test_cte AS (select * from foo) INSERT INTO "test_cte" DEFAULT VALUES`) - -} - -func (dts *dialectTestSuite) TestToUpdateSQL_unsupportedFragment() { - opts := DefaultDialectOptions() - opts.UpdateSQLOrder = []SQLFragmentType{InsertBeingSQLFragment} - d := sqlDialect{dialect: "test", dialectOptions: opts} - uc := exp.NewUpdateClauses(). - SetTable(exp.NewIdentifierExpression("", "test", "")). - SetSetValues(exp.Record{"a": "b", "b": "c"}) b := sb.NewSQLBuilder(true) - - d.ToUpdateSQL(b, uc) - dts.assertErrorSQL(b, `goqu: unsupported UPDATE SQL fragment InsertBeingSQLFragment`) -} - -func (dts *dialectTestSuite) TestToUpdateSQL_empty() { - opts := DefaultDialectOptions() - d := sqlDialect{dialect: "test", dialectOptions: opts} uc := exp.NewUpdateClauses() + um.On("Generate", b, uc).Return(nil).Once() - b := sb.NewSQLBuilder(false) - d.ToUpdateSQL(b, uc) - dts.Equal(errNoSourceForUpdate, b.Error()) - -} - -func (dts *dialectTestSuite) TestToUpdateSQL_withBadUpdateValues() { - opts := DefaultDialectOptions() - d := sqlDialect{dialect: "test", dialectOptions: opts} - uc := exp.NewUpdateClauses(). - SetTable(exp.NewIdentifierExpression("", "test", "")). - SetSetValues(true) - - b := sb.NewSQLBuilder(false) - d.ToUpdateSQL(b, uc) - dts.EqualError(b.Error(), "goqu: unsupported update interface type bool") - -} - -func (dts *dialectTestSuite) TestToUpdateSQL_noSetValues() { - opts := DefaultDialectOptions() - d := sqlDialect{dialect: "test", dialectOptions: opts} - uc := exp.NewUpdateClauses().SetTable(exp.NewIdentifierExpression("", "test", "")) - - b := sb.NewSQLBuilder(false) - d.ToUpdateSQL(b, uc) - dts.Equal(errNoSetValuesForUpdate, b.Error()) -} - -func (dts *dialectTestSuite) TestToUpdateSQL_withFrom() { - opts := DefaultDialectOptions() - d := sqlDialect{dialect: "test", dialectOptions: opts} - uc := exp.NewUpdateClauses(). - SetTable(exp.NewIdentifierExpression("", "test", "")). - SetSetValues(exp.Record{"foo": "bar"}). - SetFrom(exp.NewColumnListExpression("other_test")) - - b := sb.NewSQLBuilder(false) d.ToUpdateSQL(b, uc) - dts.NoError(b.Error()) - dts.assertNotPreparedSQL(b, `UPDATE "test" SET "foo"='bar' FROM "other_test"`) - - opts = DefaultDialectOptions() - opts.UseFromClauseForMultipleUpdateTables = false - d = sqlDialect{dialect: "test", dialectOptions: opts} - d.ToUpdateSQL(b.Clear(), uc) - dts.NoError(b.Error()) - dts.assertNotPreparedSQL(b, `UPDATE "test","other_test" SET "foo"='bar'`) - - opts = DefaultDialectOptions() - opts.SupportsMultipleUpdateTables = false - d = sqlDialect{dialect: "test", dialectOptions: opts} - d.ToUpdateSQL(b.Clear(), uc) - dts.EqualError(b.Error(), "goqu: test dialect does not support multiple tables in UPDATE") - - opts = DefaultDialectOptions() - opts.SupportsMultipleUpdateTables = false - d = sqlDialect{dialect: "test", dialectOptions: opts} - d.ToUpdateSQL(b.Clear(), uc.SetFrom(nil)) - dts.NoError(b.Error()) - dts.assertNotPreparedSQL(b, `UPDATE "test" SET "foo"='bar'`) - -} - -func (dts *dialectTestSuite) TestToInsertSQL_withReturning() { - opts := DefaultDialectOptions() - d := sqlDialect{dialect: "test", dialectOptions: opts} - - ic := exp.NewInsertClauses(). - SetInto(exp.NewIdentifierExpression("", "test", "")). - SetCols(exp.NewColumnListExpression("a", "b")). - SetVals([][]interface{}{ - {"a1", "b1"}, - {"a2", "b2"}, - {"a3", "b3"}, - }) - b := sb.NewSQLBuilder(false) - d.ToInsertSQL(b, ic.SetReturning(exp.NewColumnListExpression("a", "b"))) - dts.assertNotPreparedSQL(b, `INSERT INTO "test" ("a", "b") VALUES ('a1', 'b1'), ('a2', 'b2'), ('a3', 'b3') RETURNING "a", "b"`) - - b = sb.NewSQLBuilder(true) - d.ToInsertSQL(b, ic.SetReturning(exp.NewColumnListExpression("a", "b"))) - dts.assertPreparedSQL(b, `INSERT INTO "test" ("a", "b") VALUES (?, ?), (?, ?), (?, ?) RETURNING "a", "b"`, []interface{}{ - "a1", "b1", "a2", "b2", "a3", "b3", - }) -} + um.AssertExpectations(dts.T()) -func (dts *dialectTestSuite) TestToUpdateSQL_withUpdateExpression() { - - opts := DefaultDialectOptions() - // make sure the fragments are used - opts.SetFragment = []byte(" set ") - d := sqlDialect{dialect: "test", dialectOptions: opts} - uc := exp.NewUpdateClauses(). - SetTable(exp.NewIdentifierExpression("", "test", "")) - - b := sb.NewSQLBuilder(false) - d.ToUpdateSQL(b, uc.SetSetValues(exp.Record{"a": "b", "b": "c"})) - dts.assertNotPreparedSQL(b, `UPDATE "test" set "a"='b',"b"='c'`) - - b = sb.NewSQLBuilder(true) - d.ToUpdateSQL(b, uc.SetSetValues(exp.Record{"a": "b", "b": "c"})) - dts.assertPreparedSQL(b, `UPDATE "test" set "a"=?,"b"=?`, []interface{}{"b", "c"}) - - b = sb.NewSQLBuilder(true) - d.ToUpdateSQL(b, uc.SetSetValues(exp.Record{})) - dts.assertErrorSQL(b, errNoUpdatedValuesProvided.Error()) } -func (dts *dialectTestSuite) TestToUpdateSQL_withOrder() { +func (dts *dialectTestSuite) TestToInsertSQL() { opts := DefaultDialectOptions() - opts.SupportsOrderByOnUpdate = true - d := sqlDialect{dialect: "test", dialectOptions: opts} - - opts2 := DefaultDialectOptions() - opts2.SupportsOrderByOnUpdate = false - d2 := sqlDialect{dialect: "test", dialectOptions: opts2} - - uc := exp.NewUpdateClauses(). - SetTable(exp.NewIdentifierExpression("", "test", "")). - SetSetValues(exp.Record{"a": "b", "b": "c"}). - SetOrder(exp.NewIdentifierExpression("", "", "c").Desc()) - - b := sb.NewSQLBuilder(false) - d.ToUpdateSQL(b.Clear(), uc) - dts.assertNotPreparedSQL(b, `UPDATE "test" SET "a"='b',"b"='c' ORDER BY "c" DESC`) + im := new(mocks.InsertSQLGenerator) + d := sqlDialect{dialect: "test", dialectOptions: opts, insertGen: im} - d2.ToUpdateSQL(b.Clear(), uc) - dts.assertNotPreparedSQL(b, `UPDATE "test" SET "a"='b',"b"='c'`) - - b = sb.NewSQLBuilder(true) - d.ToUpdateSQL(b.Clear(), uc) - dts.assertPreparedSQL(b, `UPDATE "test" SET "a"=?,"b"=? ORDER BY "c" DESC`, []interface{}{"b", "c"}) + b := sb.NewSQLBuilder(true) + ic := exp.NewInsertClauses() + im.On("Generate", b, ic).Return(nil).Once() - d2.ToUpdateSQL(b.Clear(), uc) - dts.assertPreparedSQL(b, `UPDATE "test" SET "a"=?,"b"=?`, []interface{}{"b", "c"}) + d.ToInsertSQL(b, ic) + im.AssertExpectations(dts.T()) } -func (dts *dialectTestSuite) TestToUpdateSQL_withLimit() { +func (dts *dialectTestSuite) TestToDeleteSQL() { opts := DefaultDialectOptions() - opts.SupportsLimitOnUpdate = true - d := sqlDialect{dialect: "test", dialectOptions: opts} - - opts2 := DefaultDialectOptions() - opts2.SupportsLimitOnUpdate = false - d2 := sqlDialect{dialect: "test", dialectOptions: opts2} - - uc := exp.NewUpdateClauses(). - SetTable(exp.NewIdentifierExpression("", "test", "")). - SetSetValues(exp.Record{"a": "b", "b": "c"}). - SetLimit(10) - - b := sb.NewSQLBuilder(false) - d.ToUpdateSQL(b.Clear(), uc) - dts.assertNotPreparedSQL(b, `UPDATE "test" SET "a"='b',"b"='c' LIMIT 10`) - - d2.ToUpdateSQL(b.Clear(), uc) - dts.assertNotPreparedSQL(b, `UPDATE "test" SET "a"='b',"b"='c'`) + dm := new(mocks.DeleteSQLGenerator) + d := sqlDialect{dialect: "test", dialectOptions: opts, deleteGen: dm} - b = sb.NewSQLBuilder(true) - d.ToUpdateSQL(b.Clear(), uc) - dts.assertPreparedSQL(b, `UPDATE "test" SET "a"=?,"b"=? LIMIT ?`, []interface{}{"b", "c", int64(10)}) + b := sb.NewSQLBuilder(true) + dc := exp.NewDeleteClauses() + dm.On("Generate", b, dc).Return(nil).Once() - d2.ToUpdateSQL(b.Clear(), uc) - dts.assertPreparedSQL(b, `UPDATE "test" SET "a"=?,"b"=?`, []interface{}{"b", "c"}) + d.ToDeleteSQL(b, dc) + dm.AssertExpectations(dts.T()) } -func (dts *dialectTestSuite) TestToUpdateSQL_withCommonTables() { +func (dts *dialectTestSuite) TestToTruncateSQL() { opts := DefaultDialectOptions() - opts.WithFragment = []byte("with ") - opts.RecursiveFragment = []byte("recursive ") - d := sqlDialect{dialect: "test", dialectOptions: opts} - tse := newTestAppendableExpression("select * from foo", emptyArgs, nil, nil) - cte1 := exp.NewCommonTableExpression(false, "test_cte", tse) - cte2 := exp.NewCommonTableExpression(true, "test_cte", tse) - - uc := exp.NewUpdateClauses(). - SetTable(exp.NewIdentifierExpression("", "test_cte", "")). - SetSetValues(exp.Record{"a": "b", "b": "c"}) - - b := sb.NewSQLBuilder(false) - - d.ToUpdateSQL(b.Clear(), uc.CommonTablesAppend(cte1)) - dts.assertNotPreparedSQL(b, `with test_cte AS (select * from foo) UPDATE "test_cte" SET "a"='b',"b"='c'`) - - d.ToUpdateSQL(b.Clear(), uc.CommonTablesAppend(cte2)) - dts.assertNotPreparedSQL(b, `with recursive test_cte AS (select * from foo) UPDATE "test_cte" SET "a"='b',"b"='c'`) + tm := new(mocks.TruncateSQLGenerator) + d := sqlDialect{dialect: "test", dialectOptions: opts, truncateGen: tm} - d.ToUpdateSQL(b.Clear(), uc.CommonTablesAppend(cte1).CommonTablesAppend(cte2)) - dts.assertNotPreparedSQL( - b, - `with recursive test_cte AS (select * from foo), test_cte AS (select * from foo) UPDATE "test_cte" SET "a"='b',"b"='c'`, - ) - - opts = DefaultDialectOptions() - opts.SupportsWithCTE = false - d = sqlDialect{dialect: "test", dialectOptions: opts} - - d.ToUpdateSQL(b.Clear(), uc.CommonTablesAppend(cte1)) - dts.assertErrorSQL(b, "goqu: dialect does not support CTE WITH clause [dialect=test]") - - opts = DefaultDialectOptions() - opts.SupportsWithCTERecursive = false - d = sqlDialect{dialect: "test", dialectOptions: opts} - - d.ToUpdateSQL(b.Clear(), uc.CommonTablesAppend(cte2)) - dts.assertErrorSQL(b, "goqu: dialect does not support CTE WITH RECURSIVE clause [dialect=test]") - - d.ToUpdateSQL(b.Clear(), uc.CommonTablesAppend(cte1)) - dts.assertNotPreparedSQL(b, `WITH test_cte AS (select * from foo) UPDATE "test_cte" SET "a"='b',"b"='c'`) + b := sb.NewSQLBuilder(true) + tc := exp.NewTruncateClauses() + tm.On("Generate", b, tc).Return(nil).Once() + d.ToTruncateSQL(b, tc) + tm.AssertExpectations(dts.T()) } -func (dts *dialectTestSuite) TestToDeleteSQL() { - opts := DefaultDialectOptions() - d := sqlDialect{dialect: "test", dialectOptions: opts} - - opts2 := DefaultDialectOptions() - opts2.DeleteClause = []byte("delete") - d2 := sqlDialect{dialect: "test", dialectOptions: opts2} - - dc := exp.NewDeleteClauses().SetFrom(exp.NewIdentifierExpression("", "test", "")) - b := sb.NewSQLBuilder(false) - d.ToDeleteSQL(b, dc) - dts.assertNotPreparedSQL(b, `DELETE FROM "test"`) - - d2.ToDeleteSQL(b.Clear(), dc) - dts.assertNotPreparedSQL(b, `delete FROM "test"`) - - b = sb.NewSQLBuilder(true) - d.ToDeleteSQL(b, dc) - dts.assertNotPreparedSQL(b, `DELETE FROM "test"`) - - d2.ToDeleteSQL(b.Clear(), dc) - dts.assertNotPreparedSQL(b, `delete FROM "test"`) -} - -func (dts *dialectTestSuite) TestToUpdateSQL_withUnsupportedFragment() { - opts := DefaultDialectOptions() - opts.DeleteSQLOrder = []SQLFragmentType{InsertBeingSQLFragment} - d := sqlDialect{dialect: "test", dialectOptions: opts} - dc := exp.NewDeleteClauses().SetFrom(exp.NewIdentifierExpression("", "test", "")) - b := sb.NewSQLBuilder(true) - - d.ToDeleteSQL(b, dc) - dts.assertErrorSQL(b, `goqu: unsupported DELETE SQL fragment InsertBeingSQLFragment`) -} - -func (dts *dialectTestSuite) TestToDeleteSQL_noFrom() { - opts := DefaultDialectOptions() - d := sqlDialect{dialect: "test", dialectOptions: opts} - - dc := exp.NewDeleteClauses() - b := sb.NewSQLBuilder(false) - d.ToDeleteSQL(b, dc) - dts.assertErrorSQL(b, errNoSourceForDelete.Error()) - - b = sb.NewSQLBuilder(true) - d.ToDeleteSQL(b, dc) - dts.assertErrorSQL(b, errNoSourceForDelete.Error()) -} - -func (dts *dialectTestSuite) TestToDeleteSQL_withErroredBuilder() { - opts := DefaultDialectOptions() - d := sqlDialect{dialect: "test", dialectOptions: opts} - - dc := exp.NewDeleteClauses().SetFrom(exp.NewIdentifierExpression("", "test", "")) - b := sb.NewSQLBuilder(false).SetError(errors.New("expected error")) - d.ToDeleteSQL(b, dc) - dts.assertErrorSQL(b, "goqu: expected error") - - b = sb.NewSQLBuilder(true).SetError(errors.New("expected error")) - d.ToDeleteSQL(b, dc) - dts.assertErrorSQL(b, "goqu: expected error") -} - -func (dts *dialectTestSuite) TestToDeleteSQL_withWhere() { - opts := DefaultDialectOptions() - d := sqlDialect{dialect: "test", dialectOptions: opts} - - dc := exp.NewDeleteClauses(). - SetFrom(exp.NewIdentifierExpression("", "test", "")). - WhereAppend(exp.NewLiteralExpression(`"a"=?`, 1)) - b := sb.NewSQLBuilder(false) - d.ToDeleteSQL(b, dc) - dts.assertNotPreparedSQL(b, `DELETE FROM "test" WHERE "a"=1`) - - b = sb.NewSQLBuilder(true) - d.ToDeleteSQL(b, dc) - dts.assertPreparedSQL(b, `DELETE FROM "test" WHERE "a"=?`, []interface{}{ - int64(1), - }) -} - -func (dts *dialectTestSuite) TestToDeleteSQL_withOrder() { - opts := DefaultDialectOptions() - opts.SupportsOrderByOnDelete = true - d := sqlDialect{dialect: "test", dialectOptions: opts} - - opts2 := DefaultDialectOptions() - opts2.SupportsOrderByOnDelete = false - d2 := sqlDialect{dialect: "test", dialectOptions: opts2} - - dc := exp.NewDeleteClauses(). - SetFrom(exp.NewIdentifierExpression("", "test", "")). - SetOrder(exp.NewIdentifierExpression("", "", "c").Desc()) - b := sb.NewSQLBuilder(false) - d.ToDeleteSQL(b.Clear(), dc) - dts.assertNotPreparedSQL(b, `DELETE FROM "test" ORDER BY "c" DESC`) - - d2.ToDeleteSQL(b.Clear(), dc) - dts.assertNotPreparedSQL(b, `DELETE FROM "test"`) - - b = sb.NewSQLBuilder(true) - d.ToDeleteSQL(b.Clear(), dc) - dts.assertPreparedSQL(b, `DELETE FROM "test" ORDER BY "c" DESC`, emptyArgs) - - d2.ToDeleteSQL(b.Clear(), dc) - dts.assertPreparedSQL(b, `DELETE FROM "test"`, emptyArgs) -} - -func (dts *dialectTestSuite) TestToDeleteSQL_withLimit() { - opts := DefaultDialectOptions() - opts.SupportsLimitOnDelete = true - d := sqlDialect{dialect: "test", dialectOptions: opts} - - opts2 := DefaultDialectOptions() - opts2.SupportsLimitOnDelete = false - d2 := sqlDialect{dialect: "test", dialectOptions: opts2} - - dc := exp.NewDeleteClauses(). - SetFrom(exp.NewIdentifierExpression("", "test", "")). - SetLimit(1) - b := sb.NewSQLBuilder(false) - d.ToDeleteSQL(b.Clear(), dc) - dts.assertNotPreparedSQL(b, `DELETE FROM "test" LIMIT 1`) - - d2.ToDeleteSQL(b.Clear(), dc) - dts.assertNotPreparedSQL(b, `DELETE FROM "test"`) - - b = sb.NewSQLBuilder(true) - d.ToDeleteSQL(b.Clear(), dc) - dts.assertPreparedSQL(b, `DELETE FROM "test" LIMIT ?`, []interface{}{int64(1)}) - - d2.ToDeleteSQL(b.Clear(), dc) - dts.assertPreparedSQL(b, `DELETE FROM "test"`, emptyArgs) -} - -func (dts *dialectTestSuite) TestToDeleteSQL_withReturning() { - opts := DefaultDialectOptions() - opts.SupportsReturn = true - d := sqlDialect{dialect: "test", dialectOptions: opts} - - opts2 := DefaultDialectOptions() - opts2.SupportsReturn = false - d2 := sqlDialect{dialect: "test", dialectOptions: opts2} - - dc := exp.NewDeleteClauses(). - SetFrom(exp.NewIdentifierExpression("", "test", "")). - SetReturning(exp.NewColumnListExpression("a", "b")) - b := sb.NewSQLBuilder(false) - d.ToDeleteSQL(b.Clear(), dc) - dts.assertNotPreparedSQL(b, `DELETE FROM "test" RETURNING "a", "b"`) - - d2.ToDeleteSQL(b.Clear(), dc) - dts.assertErrorSQL(b, `goqu: dialect does not support RETURNING clause [dialect=test]`) - - b = sb.NewSQLBuilder(true) - d.ToDeleteSQL(b.Clear(), dc) - dts.assertPreparedSQL(b, `DELETE FROM "test" RETURNING "a", "b"`, emptyArgs) - - d2.ToDeleteSQL(b.Clear(), dc) - dts.assertErrorSQL(b, `goqu: dialect does not support RETURNING clause [dialect=test]`) -} - -func (dts *dialectTestSuite) TestToDeleteSQL_withCommonTables() { - opts := DefaultDialectOptions() - opts.WithFragment = []byte("with ") - opts.RecursiveFragment = []byte("recursive ") - d := sqlDialect{dialect: "test", dialectOptions: opts} - tse := newTestAppendableExpression("select * from foo", emptyArgs, nil, nil) - cte1 := exp.NewCommonTableExpression(false, "test_cte", tse) - cte2 := exp.NewCommonTableExpression(true, "test_cte", tse) - - dc := exp.NewDeleteClauses(). - SetFrom(exp.NewIdentifierExpression("", "test_cte", "")) - - b := sb.NewSQLBuilder(false) - - d.ToDeleteSQL(b.Clear(), dc.CommonTablesAppend(cte1)) - dts.assertNotPreparedSQL(b, `with test_cte AS (select * from foo) DELETE FROM "test_cte"`) - - d.ToDeleteSQL(b.Clear(), dc.CommonTablesAppend(cte2)) - dts.assertNotPreparedSQL(b, `with recursive test_cte AS (select * from foo) DELETE FROM "test_cte"`) - - d.ToDeleteSQL(b.Clear(), dc.CommonTablesAppend(cte1).CommonTablesAppend(cte2)) - dts.assertNotPreparedSQL( - b, - `with recursive test_cte AS (select * from foo), test_cte AS (select * from foo) DELETE FROM "test_cte"`, - ) - - opts = DefaultDialectOptions() - opts.SupportsWithCTE = false - d = sqlDialect{dialect: "test", dialectOptions: opts} - - d.ToDeleteSQL(b.Clear(), dc.CommonTablesAppend(cte1)) - dts.assertErrorSQL(b, "goqu: dialect does not support CTE WITH clause [dialect=test]") - - opts = DefaultDialectOptions() - opts.SupportsWithCTERecursive = false - d = sqlDialect{dialect: "test", dialectOptions: opts} - - d.ToDeleteSQL(b.Clear(), dc.CommonTablesAppend(cte2)) - dts.assertErrorSQL(b, "goqu: dialect does not support CTE WITH RECURSIVE clause [dialect=test]") - - d.ToDeleteSQL(b.Clear(), dc.CommonTablesAppend(cte1)) - dts.assertNotPreparedSQL(b, `WITH test_cte AS (select * from foo) DELETE FROM "test_cte"`) - -} - -func (dts *dialectTestSuite) TestToSelectSQL() { - opts := DefaultDialectOptions() - // make sure the fragments are used - opts.SelectClause = []byte("select") - opts.StarRune = '#' - d := sqlDialect{dialect: "test", dialectOptions: opts} - sc := exp.NewSelectClauses().SetFrom(exp.NewColumnListExpression("test")) - scWithCols := sc.SetSelect(exp.NewColumnListExpression("a", "b")) - b := sb.NewSQLBuilder(false) - - d.ToSelectSQL(b, sc) - dts.assertNotPreparedSQL(b, `select # FROM "test"`) - - d.ToSelectSQL(b.Clear(), scWithCols) - dts.assertNotPreparedSQL(b, `select "a", "b" FROM "test"`) - - b = sb.NewSQLBuilder(true) - d.ToSelectSQL(b, sc) - dts.assertPreparedSQL(b, `select # FROM "test"`, emptyArgs) - - d.ToSelectSQL(b.Clear(), scWithCols) - dts.assertPreparedSQL(b, `select "a", "b" FROM "test"`, emptyArgs) -} - -func (dts *dialectTestSuite) TestToSelectSQL_UnsupportedFragment() { - opts := DefaultDialectOptions() - opts.SelectSQLOrder = []SQLFragmentType{InsertBeingSQLFragment} - d := sqlDialect{dialect: "test", dialectOptions: opts} - - b := sb.NewSQLBuilder(true) - c := exp.NewSelectClauses().SetFrom(exp.NewColumnListExpression("test")) - d.ToSelectSQL(b, c) - dts.assertErrorSQL(b, `goqu: unsupported SELECT SQL fragment InsertBeingSQLFragment`) -} - -func (dts *dialectTestSuite) TestToSelectSQL_WithErroredBuilder() { - opts := DefaultDialectOptions() - opts.SelectSQLOrder = []SQLFragmentType{InsertBeingSQLFragment} - d := sqlDialect{dialect: "test", dialectOptions: opts} - - b := sb.NewSQLBuilder(true).SetError(errors.New("test error")) - c := exp.NewSelectClauses().SetFrom(exp.NewColumnListExpression("test")) - d.ToSelectSQL(b, c) - dts.assertErrorSQL(b, `goqu: test error`) -} - -func (dts *dialectTestSuite) TestToSelectSQL_withDistinct() { - opts := DefaultDialectOptions() - // make sure the fragments are used - opts.SelectClause = []byte("select") - opts.StarRune = '#' - opts.DistinctFragment = []byte("distinct") - opts.OnFragment = []byte(" on ") - opts.SupportsDistinctOn = true - d := sqlDialect{dialect: "test", dialectOptions: opts} - - sc := exp.NewSelectClauses().SetDistinct(exp.NewColumnListExpression()) - scDistinctOn := sc.SetDistinct(exp.NewColumnListExpression("a", "b")) - b := sb.NewSQLBuilder(false) - d.SelectSQL(b, sc) - dts.assertNotPreparedSQL(b, `select distinct #`) - - d.SelectSQL(b.Clear(), scDistinctOn) - dts.assertNotPreparedSQL(b, `select distinct on ("a", "b") #`) - - b = sb.NewSQLBuilder(true) - d.SelectSQL(b.Clear(), sc) - dts.assertPreparedSQL(b, `select distinct #`, emptyArgs) - - d.SelectSQL(b.Clear(), scDistinctOn) - dts.assertPreparedSQL(b, `select distinct on ("a", "b") #`, emptyArgs) - - opts = DefaultDialectOptions() - opts.OnFragment = []byte(" on ") - opts.SupportsDistinctOn = false - d = sqlDialect{dialect: "test", dialectOptions: opts} - - b = sb.NewSQLBuilder(false) - d.SelectSQL(b, sc) - dts.assertNotPreparedSQL(b, `SELECT DISTINCT *`) - - d.SelectSQL(b.Clear(), scDistinctOn) - dts.assertErrorSQL(b, "goqu: dialect does not support DISTINCT ON clause [dialect=test]") - - b = sb.NewSQLBuilder(true) - d.SelectSQL(b.Clear(), sc) - dts.assertPreparedSQL(b, `SELECT DISTINCT *`, emptyArgs) - - d.SelectSQL(b.Clear(), scDistinctOn) - dts.assertErrorSQL(b, "goqu: dialect does not support DISTINCT ON clause [dialect=test]") -} - -func (dts *dialectTestSuite) TestToSelectSQL_withFromSQL() { - opts := DefaultDialectOptions() - // make sure the fragments are used - opts.FromFragment = []byte(" from") - d := sqlDialect{dialect: "test", dialectOptions: opts} - sc := exp.NewSelectClauses(). - SetFrom(exp.NewColumnListExpression("a", "b")) - b := sb.NewSQLBuilder(false) - d.ToSelectSQL(b.Clear(), sc) - dts.assertNotPreparedSQL(b, `SELECT * from "a", "b"`) - - b = sb.NewSQLBuilder(true) - d.ToSelectSQL(b.Clear(), sc) - dts.assertPreparedSQL(b, `SELECT * from "a", "b"`, emptyArgs) - - sc = exp.NewSelectClauses().SetFrom(exp.NewColumnListExpression()) - b = sb.NewSQLBuilder(false) - d.ToSelectSQL(b.Clear(), sc) - dts.assertNotPreparedSQL(b, `SELECT *`) - - b = sb.NewSQLBuilder(true) - d.ToSelectSQL(b.Clear(), sc) - dts.assertPreparedSQL(b, `SELECT *`, emptyArgs) -} - -func (dts *dialectTestSuite) TestToSelectSQL_withJoin() { - opts := DefaultDialectOptions() - d := sqlDialect{dialect: "test", dialectOptions: opts} - sc := exp.NewSelectClauses().SetFrom(exp.NewColumnListExpression("test")) - ti := exp.NewIdentifierExpression("", "test2", "") - uj := exp.NewUnConditionedJoinExpression(exp.NaturalJoinType, ti) - cjo := exp.NewConditionedJoinExpression(exp.LeftJoinType, ti, exp.NewJoinOnCondition(exp.Ex{"a": "foo"})) - cju := exp.NewConditionedJoinExpression(exp.LeftJoinType, ti, exp.NewJoinUsingCondition("a")) - - b := sb.NewSQLBuilder(false) - d.ToSelectSQL(b.Clear(), sc.JoinsAppend(uj)) - dts.assertNotPreparedSQL(b, `SELECT * FROM "test" NATURAL JOIN "test2"`) - - d.ToSelectSQL(b.Clear(), sc.JoinsAppend(cjo)) - dts.assertNotPreparedSQL(b, `SELECT * FROM "test" LEFT JOIN "test2" ON ("a" = 'foo')`) - - d.ToSelectSQL(b.Clear(), sc.JoinsAppend(cju)) - dts.assertNotPreparedSQL(b, `SELECT * FROM "test" LEFT JOIN "test2" USING ("a")`) - - d.ToSelectSQL(b.Clear(), sc.JoinsAppend(uj).JoinsAppend(cjo).JoinsAppend(cju)) - dts.assertNotPreparedSQL( - b, - `SELECT * FROM "test" NATURAL JOIN "test2" LEFT JOIN "test2" ON ("a" = 'foo') LEFT JOIN "test2" USING ("a")`, - ) - - b = sb.NewSQLBuilder(true) - d.ToSelectSQL(b.Clear(), sc.JoinsAppend(uj)) - dts.assertPreparedSQL(b, `SELECT * FROM "test" NATURAL JOIN "test2"`, emptyArgs) - - d.ToSelectSQL(b.Clear(), sc.JoinsAppend(cjo)) - dts.assertPreparedSQL(b, `SELECT * FROM "test" LEFT JOIN "test2" ON ("a" = ?)`, []interface{}{"foo"}) - - d.ToSelectSQL(b.Clear(), sc.JoinsAppend(cju)) - dts.assertPreparedSQL(b, `SELECT * FROM "test" LEFT JOIN "test2" USING ("a")`, emptyArgs) - - d.ToSelectSQL(b.Clear(), sc.JoinsAppend(uj).JoinsAppend(cjo).JoinsAppend(cju)) - dts.assertPreparedSQL( - b, - `SELECT * FROM "test" NATURAL JOIN "test2" LEFT JOIN "test2" ON ("a" = ?) LEFT JOIN "test2" USING ("a")`, - []interface{}{"foo"}, - ) - - opts2 := DefaultDialectOptions() - // override fragements to make sure dialect is used - opts2.UsingFragment = []byte(" using ") - opts2.OnFragment = []byte(" on ") - opts2.JoinTypeLookup = map[exp.JoinType][]byte{ - exp.LeftJoinType: []byte(" left join "), - exp.NaturalJoinType: []byte(" natural join "), - } - d2 := sqlDialect{dialect: "test", dialectOptions: opts2} - - b = sb.NewSQLBuilder(false) - d2.ToSelectSQL(b.Clear(), sc.JoinsAppend(uj)) - dts.assertNotPreparedSQL(b, `SELECT * FROM "test" natural join "test2"`) - - d2.ToSelectSQL(b.Clear(), sc.JoinsAppend(cjo)) - dts.assertNotPreparedSQL(b, `SELECT * FROM "test" left join "test2" on ("a" = 'foo')`) - - d2.ToSelectSQL(b.Clear(), sc.JoinsAppend(cju)) - dts.assertNotPreparedSQL(b, `SELECT * FROM "test" left join "test2" using ("a")`) - - d2.ToSelectSQL(b.Clear(), sc.JoinsAppend(uj).JoinsAppend(cjo).JoinsAppend(cju)) - dts.assertNotPreparedSQL( - b, - `SELECT * FROM "test" natural join "test2" left join "test2" on ("a" = 'foo') left join "test2" using ("a")`, - ) - - rj := exp.NewConditionedJoinExpression(exp.RightJoinType, ti, exp.NewJoinUsingCondition(exp.NewIdentifierExpression("", "", "a"))) - d2.ToSelectSQL(b.Clear(), sc.JoinsAppend(rj)) - dts.assertErrorSQL(b, "goqu: dialect does not support RightJoinType") - - badJoin := exp.NewConditionedJoinExpression(exp.LeftJoinType, ti, exp.NewJoinUsingCondition()) - d2.ToSelectSQL(b.Clear(), sc.JoinsAppend(badJoin)) - dts.assertErrorSQL(b, "goqu: join condition required for conditioned join LeftJoinType") -} - -func (dts *dialectTestSuite) TestToSelectSQL_withWhere() { - opts := DefaultDialectOptions() - opts.WhereFragment = []byte(" where ") - d := sqlDialect{dialect: "test", dialectOptions: opts} - - sc := exp.NewSelectClauses().SetFrom(exp.NewColumnListExpression("test")) - w := exp.Ex{"a": "b"} - w2 := exp.Ex{"b": "c"} - - b := sb.NewSQLBuilder(false) - d.ToSelectSQL(b.Clear(), sc.WhereAppend(w)) - dts.assertNotPreparedSQL(b, `SELECT * FROM "test" where ("a" = 'b')`) - - d.ToSelectSQL(b.Clear(), sc.WhereAppend(w, w2)) - dts.assertNotPreparedSQL(b, `SELECT * FROM "test" where (("a" = 'b') AND ("b" = 'c'))`) - - b = sb.NewSQLBuilder(true) - d.ToSelectSQL(b.Clear(), sc.WhereAppend(w)) - dts.assertPreparedSQL(b, `SELECT * FROM "test" where ("a" = ?)`, []interface{}{"b"}) - - d.ToSelectSQL(b.Clear(), sc.WhereAppend(w, w2)) - dts.assertPreparedSQL(b, `SELECT * FROM "test" where (("a" = ?) AND ("b" = ?))`, []interface{}{"b", "c"}) -} - -func (dts *dialectTestSuite) TestToSelectSQL_withGroupBy() { - opts := DefaultDialectOptions() - opts.GroupByFragment = []byte(" group by ") - d := sqlDialect{dialect: "test", dialectOptions: opts} - sc := exp.NewSelectClauses().SetFrom(exp.NewColumnListExpression("test")) - c1 := exp.NewIdentifierExpression("", "", "a") - c2 := exp.NewIdentifierExpression("", "", "b") - - b := sb.NewSQLBuilder(false) - d.ToSelectSQL(b.Clear(), sc.SetGroupBy(exp.NewColumnListExpression(c1))) - dts.assertNotPreparedSQL(b, `SELECT * FROM "test" group by "a"`) - - d.ToSelectSQL(b.Clear(), sc.SetGroupBy(exp.NewColumnListExpression(c1, c2))) - dts.assertNotPreparedSQL(b, `SELECT * FROM "test" group by "a", "b"`) - - b = sb.NewSQLBuilder(true) - d.ToSelectSQL(b.Clear(), sc.SetGroupBy(exp.NewColumnListExpression(c1))) - dts.assertPreparedSQL(b, `SELECT * FROM "test" group by "a"`, emptyArgs) - - d.ToSelectSQL(b.Clear(), sc.SetGroupBy(exp.NewColumnListExpression(c1, c2))) - dts.assertPreparedSQL(b, `SELECT * FROM "test" group by "a", "b"`, emptyArgs) -} - -func (dts *dialectTestSuite) TestToSelectSQL_withHaving() { - opts := DefaultDialectOptions() - opts.HavingFragment = []byte(" having ") - d := sqlDialect{dialect: "test", dialectOptions: opts} - sc := exp.NewSelectClauses().SetFrom(exp.NewColumnListExpression("test")) - w := exp.Ex{"a": "b"} - w2 := exp.Ex{"b": "c"} - - b := sb.NewSQLBuilder(false) - d.ToSelectSQL(b.Clear(), sc.HavingAppend(w)) - dts.assertNotPreparedSQL(b, `SELECT * FROM "test" having ("a" = 'b')`) - - d.ToSelectSQL(b.Clear(), sc.HavingAppend(w, w2)) - dts.assertNotPreparedSQL(b, `SELECT * FROM "test" having (("a" = 'b') AND ("b" = 'c'))`) - - b = sb.NewSQLBuilder(true) - d.ToSelectSQL(b.Clear(), sc.HavingAppend(w)) - dts.assertPreparedSQL(b, `SELECT * FROM "test" having ("a" = ?)`, []interface{}{"b"}) - - d.ToSelectSQL(b.Clear(), sc.HavingAppend(w, w2)) - dts.assertPreparedSQL(b, `SELECT * FROM "test" having (("a" = ?) AND ("b" = ?))`, []interface{}{"b", "c"}) -} - -func (dts *dialectTestSuite) TestToSelectSQL_withOrder() { - opts := DefaultDialectOptions() - // override fragments to ensure they are used - opts.OrderByFragment = []byte(" order by ") - opts.AscFragment = []byte(" asc") - opts.DescFragment = []byte(" desc") - opts.NullsFirstFragment = []byte(" nulls first") - opts.NullsLastFragment = []byte(" nulls last") - d := sqlDialect{dialect: "test", dialectOptions: opts} - - sc := exp.NewSelectClauses().SetFrom(exp.NewColumnListExpression("test")) - oa := exp.NewIdentifierExpression("", "", "a").Asc() - oanf := exp.NewIdentifierExpression("", "", "a").Asc().NullsFirst() - oanl := exp.NewIdentifierExpression("", "", "a").Asc().NullsLast() - od := exp.NewIdentifierExpression("", "", "a").Desc() - odnf := exp.NewIdentifierExpression("", "", "a").Desc().NullsFirst() - odnl := exp.NewIdentifierExpression("", "", "a").Desc().NullsLast() - - b := sb.NewSQLBuilder(false) - d.ToSelectSQL(b.Clear(), sc.SetOrder(oa)) - dts.assertNotPreparedSQL(b, `SELECT * FROM "test" order by "a" asc`) - - d.ToSelectSQL(b.Clear(), sc.SetOrder(oanf)) - dts.assertNotPreparedSQL(b, `SELECT * FROM "test" order by "a" asc nulls first`) - - d.ToSelectSQL(b.Clear(), sc.SetOrder(oanl)) - dts.assertNotPreparedSQL(b, `SELECT * FROM "test" order by "a" asc nulls last`) - - d.ToSelectSQL(b.Clear(), sc.SetOrder(od)) - dts.assertNotPreparedSQL(b, `SELECT * FROM "test" order by "a" desc`) - - d.ToSelectSQL(b.Clear(), sc.SetOrder(odnf)) - dts.assertNotPreparedSQL(b, `SELECT * FROM "test" order by "a" desc nulls first`) - - d.ToSelectSQL(b.Clear(), sc.SetOrder(odnl)) - dts.assertNotPreparedSQL(b, `SELECT * FROM "test" order by "a" desc nulls last`) - - d.ToSelectSQL(b.Clear(), sc.SetOrder(oa, od)) - dts.assertNotPreparedSQL(b, `SELECT * FROM "test" order by "a" asc, "a" desc`) - - b = sb.NewSQLBuilder(true) - d.ToSelectSQL(b.Clear(), sc.SetOrder(oa)) - dts.assertPreparedSQL(b, `SELECT * FROM "test" order by "a" asc`, emptyArgs) - - d.ToSelectSQL(b.Clear(), sc.SetOrder(oanf)) - dts.assertPreparedSQL(b, `SELECT * FROM "test" order by "a" asc nulls first`, emptyArgs) - - d.ToSelectSQL(b.Clear(), sc.SetOrder(oanl)) - dts.assertPreparedSQL(b, `SELECT * FROM "test" order by "a" asc nulls last`, emptyArgs) - - d.ToSelectSQL(b.Clear(), sc.SetOrder(od)) - dts.assertPreparedSQL(b, `SELECT * FROM "test" order by "a" desc`, emptyArgs) - - d.ToSelectSQL(b.Clear(), sc.SetOrder(odnf)) - dts.assertPreparedSQL(b, `SELECT * FROM "test" order by "a" desc nulls first`, emptyArgs) - - d.ToSelectSQL(b.Clear(), sc.SetOrder(odnl)) - dts.assertPreparedSQL(b, `SELECT * FROM "test" order by "a" desc nulls last`, emptyArgs) - - d.ToSelectSQL(b.Clear(), sc.SetOrder(oa, od)) - dts.assertPreparedSQL(b, `SELECT * FROM "test" order by "a" asc, "a" desc`, emptyArgs) - -} - -func (dts *dialectTestSuite) TestToSelectSQL_withLimit() { - opts := DefaultDialectOptions() - opts.LimitFragment = []byte(" limit ") - d := sqlDialect{dialect: "test", dialectOptions: opts} - - sc := exp.NewSelectClauses().SetFrom(exp.NewColumnListExpression("test")) - b := sb.NewSQLBuilder(false) - d.ToSelectSQL(b.Clear(), sc.SetLimit(10)) - dts.assertNotPreparedSQL(b, `SELECT * FROM "test" limit 10`) - - d.ToSelectSQL(b.Clear(), sc.SetLimit(0)) - dts.assertNotPreparedSQL(b, `SELECT * FROM "test" limit 0`) - - d.ToSelectSQL(b.Clear(), sc.SetLimit(exp.NewLiteralExpression("ALL"))) - dts.assertNotPreparedSQL(b, `SELECT * FROM "test" limit ALL`) - - b = sb.NewSQLBuilder(true) - d.ToSelectSQL(b.Clear(), sc.SetLimit(10)) - dts.assertPreparedSQL(b, `SELECT * FROM "test" limit ?`, []interface{}{int64(10)}) - - d.ToSelectSQL(b.Clear(), sc.SetLimit(0)) - dts.assertPreparedSQL(b, `SELECT * FROM "test" limit ?`, []interface{}{int64(0)}) - - d.ToSelectSQL(b.Clear(), sc.SetLimit(exp.NewLiteralExpression("ALL"))) - dts.assertPreparedSQL(b, `SELECT * FROM "test" limit ALL`, emptyArgs) -} - -func (dts *dialectTestSuite) TestToSelectSQL_withOffset() { - opts := DefaultDialectOptions() - opts.OffsetFragment = []byte(" offset ") - d := sqlDialect{dialect: "test", dialectOptions: opts} - sc := exp.NewSelectClauses().SetFrom(exp.NewColumnListExpression("test")) - - b := sb.NewSQLBuilder(false) - d.ToSelectSQL(b.Clear(), sc.SetOffset(10)) - dts.assertNotPreparedSQL(b, `SELECT * FROM "test" offset 10`) - - d.ToSelectSQL(b.Clear(), sc.SetOffset(0)) - dts.assertNotPreparedSQL(b, `SELECT * FROM "test"`) - - b = sb.NewSQLBuilder(true) - d.ToSelectSQL(b.Clear(), sc.SetOffset(10)) - dts.assertPreparedSQL(b, `SELECT * FROM "test" offset ?`, []interface{}{int64(10)}) - - d.ToSelectSQL(b.Clear(), sc.SetOffset(0)) - dts.assertPreparedSQL(b, `SELECT * FROM "test"`, emptyArgs) -} - -func (dts *dialectTestSuite) TestToSelectSQL_withCommonTables() { - opts := DefaultDialectOptions() - opts.WithFragment = []byte("with ") - opts.RecursiveFragment = []byte("recursive ") - d := sqlDialect{dialect: "test", dialectOptions: opts} - tse := newTestAppendableExpression("select * from foo", emptyArgs, nil, nil) - cte1 := exp.NewCommonTableExpression(false, "test_cte", tse) - cte2 := exp.NewCommonTableExpression(true, "test_cte", tse) - - sc := exp.NewSelectClauses().SetFrom(exp.NewColumnListExpression("test_cte")) - - b := sb.NewSQLBuilder(false) - - d.ToSelectSQL(b.Clear(), sc.CommonTablesAppend(cte1)) - dts.assertNotPreparedSQL(b, `with test_cte AS (select * from foo) SELECT * FROM "test_cte"`) - - d.ToSelectSQL(b.Clear(), sc.CommonTablesAppend(cte2)) - dts.assertNotPreparedSQL(b, `with recursive test_cte AS (select * from foo) SELECT * FROM "test_cte"`) - - d.ToSelectSQL(b.Clear(), sc.CommonTablesAppend(cte1).CommonTablesAppend(cte2)) - dts.assertNotPreparedSQL( - b, - `with recursive test_cte AS (select * from foo), test_cte AS (select * from foo) SELECT * FROM "test_cte"`, - ) - - opts = DefaultDialectOptions() - opts.SupportsWithCTE = false - d = sqlDialect{dialect: "test", dialectOptions: opts} - - d.ToSelectSQL(b.Clear(), sc.CommonTablesAppend(cte1)) - dts.assertErrorSQL(b, "goqu: dialect does not support CTE WITH clause [dialect=test]") - - opts = DefaultDialectOptions() - opts.SupportsWithCTERecursive = false - d = sqlDialect{dialect: "test", dialectOptions: opts} - - d.ToSelectSQL(b.Clear(), sc.CommonTablesAppend(cte2)) - dts.assertErrorSQL(b, "goqu: dialect does not support CTE WITH RECURSIVE clause [dialect=test]") - - d.ToSelectSQL(b.Clear(), sc.CommonTablesAppend(cte1)) - dts.assertNotPreparedSQL(b, `WITH test_cte AS (select * from foo) SELECT * FROM "test_cte"`) - -} - -func (dts *dialectTestSuite) TestToSelectSQL_withCompounds() { - opts := DefaultDialectOptions() - opts.UnionFragment = []byte(" union ") - opts.UnionAllFragment = []byte(" union all ") - opts.IntersectFragment = []byte(" intersect ") - opts.IntersectAllFragment = []byte(" intersect all ") - d := sqlDialect{dialect: "test", dialectOptions: opts} - - tse := newTestAppendableExpression("select * from foo", emptyArgs, nil, nil) - - sc := exp.NewSelectClauses().SetFrom(exp.NewColumnListExpression("test")) - b := sb.NewSQLBuilder(false) - - u := exp.NewCompoundExpression(exp.UnionCompoundType, tse) - d.ToSelectSQL(b.Clear(), sc.CompoundsAppend(u)) - dts.assertNotPreparedSQL(b, `SELECT * FROM "test" union (select * from foo)`) - - ua := exp.NewCompoundExpression(exp.UnionAllCompoundType, tse) - d.ToSelectSQL(b.Clear(), sc.CompoundsAppend(ua)) - dts.assertNotPreparedSQL(b, `SELECT * FROM "test" union all (select * from foo)`) - - i := exp.NewCompoundExpression(exp.IntersectCompoundType, tse) - d.ToSelectSQL(b.Clear(), sc.CompoundsAppend(i)) - dts.assertNotPreparedSQL(b, `SELECT * FROM "test" intersect (select * from foo)`) - - ia := exp.NewCompoundExpression(exp.IntersectAllCompoundType, tse) - d.ToSelectSQL(b.Clear(), sc.CompoundsAppend(ia)) - dts.assertNotPreparedSQL(b, `SELECT * FROM "test" intersect all (select * from foo)`) - - d.ToSelectSQL(b.Clear(), sc.CompoundsAppend(u).CompoundsAppend(ua).CompoundsAppend(i).CompoundsAppend(ia)) - dts.assertNotPreparedSQL( - b, - `SELECT * FROM "test"`+ - ` union (select * from foo)`+ - ` union all (select * from foo)`+ - ` intersect (select * from foo)`+ - ` intersect all (select * from foo)`, - ) - -} - -func (dts *dialectTestSuite) TestToSelectSQL_withFor() { - opts := DefaultDialectOptions() - opts.ForUpdateFragment = []byte(" for update ") - opts.ForNoKeyUpdateFragment = []byte(" for no key update ") - opts.ForShareFragment = []byte(" for share ") - opts.ForKeyShareFragment = []byte(" for key share ") - opts.NowaitFragment = []byte("nowait") - opts.SkipLockedFragment = []byte("skip locked") - d := sqlDialect{dialect: "test", dialectOptions: opts} - - sc := exp.NewSelectClauses().SetFrom(exp.NewColumnListExpression("test")) - b := sb.NewSQLBuilder(false) - - d.ToSelectSQL(b.Clear(), sc.SetLock(exp.NewLock(exp.ForNolock, exp.Wait))) - dts.assertNotPreparedSQL(b, `SELECT * FROM "test"`) - - d.ToSelectSQL(b.Clear(), sc.SetLock(exp.NewLock(exp.ForShare, exp.Wait))) - dts.assertNotPreparedSQL(b, `SELECT * FROM "test" for share `) - - d.ToSelectSQL(b.Clear(), sc.SetLock(exp.NewLock(exp.ForShare, exp.NoWait))) - dts.assertNotPreparedSQL(b, `SELECT * FROM "test" for share nowait`) - - d.ToSelectSQL(b.Clear(), sc.SetLock(exp.NewLock(exp.ForShare, exp.SkipLocked))) - dts.assertNotPreparedSQL(b, `SELECT * FROM "test" for share skip locked`) - - d.ToSelectSQL(b.Clear(), sc.SetLock(exp.NewLock(exp.ForKeyShare, exp.Wait))) - dts.assertNotPreparedSQL(b, `SELECT * FROM "test" for key share `) - - d.ToSelectSQL(b.Clear(), sc.SetLock(exp.NewLock(exp.ForKeyShare, exp.NoWait))) - dts.assertNotPreparedSQL(b, `SELECT * FROM "test" for key share nowait`) - - d.ToSelectSQL(b.Clear(), sc.SetLock(exp.NewLock(exp.ForKeyShare, exp.SkipLocked))) - dts.assertNotPreparedSQL(b, `SELECT * FROM "test" for key share skip locked`) - - d.ToSelectSQL(b.Clear(), sc.SetLock(exp.NewLock(exp.ForUpdate, exp.Wait))) - dts.assertNotPreparedSQL(b, `SELECT * FROM "test" for update `) - - d.ToSelectSQL(b.Clear(), sc.SetLock(exp.NewLock(exp.ForUpdate, exp.NoWait))) - dts.assertNotPreparedSQL(b, `SELECT * FROM "test" for update nowait`) - - d.ToSelectSQL(b.Clear(), sc.SetLock(exp.NewLock(exp.ForUpdate, exp.SkipLocked))) - dts.assertNotPreparedSQL(b, `SELECT * FROM "test" for update skip locked`) - - d.ToSelectSQL(b.Clear(), sc.SetLock(exp.NewLock(exp.ForNoKeyUpdate, exp.Wait))) - dts.assertNotPreparedSQL(b, `SELECT * FROM "test" for no key update `) - - d.ToSelectSQL(b.Clear(), sc.SetLock(exp.NewLock(exp.ForNoKeyUpdate, exp.NoWait))) - dts.assertNotPreparedSQL(b, `SELECT * FROM "test" for no key update nowait`) - - d.ToSelectSQL(b.Clear(), sc.SetLock(exp.NewLock(exp.ForNoKeyUpdate, exp.SkipLocked))) - dts.assertNotPreparedSQL(b, `SELECT * FROM "test" for no key update skip locked`) -} - -func (dts *dialectTestSuite) TestLiteral_FloatTypes() { - d := sqlDialect{dialect: "test", dialectOptions: DefaultDialectOptions()} - var float float64 - - b := sb.NewSQLBuilder(false) - d.Literal(b.Clear(), float32(10.01)) - dts.assertNotPreparedSQL(b, "10.010000228881836") - - d.Literal(b.Clear(), float64(10.01)) - dts.assertNotPreparedSQL(b, "10.01") - - d.Literal(b.Clear(), &float) - dts.assertNotPreparedSQL(b, "0") - - b = sb.NewSQLBuilder(true) - d.Literal(b.Clear(), float32(10.01)) - dts.assertPreparedSQL(b, "?", []interface{}{float64(float32(10.01))}) - - d.Literal(b.Clear(), float64(10.01)) - dts.assertPreparedSQL(b, "?", []interface{}{float64(10.01)}) - - d.Literal(b.Clear(), &float) - dts.assertPreparedSQL(b, "?", []interface{}{float}) -} - -func (dts *dialectTestSuite) TestLiteral_IntTypes() { - d := sqlDialect{dialect: "test", dialectOptions: DefaultDialectOptions()} - var i int64 - b := sb.NewSQLBuilder(false) - ints := []interface{}{ - int(10), - int16(10), - int32(10), - int64(10), - uint(10), - uint16(10), - uint32(10), - uint64(10), - } - for _, i := range ints { - d.Literal(b.Clear(), i) - dts.assertNotPreparedSQL(b, "10") - } - d.Literal(b.Clear(), &i) - dts.assertNotPreparedSQL(b, "0") - - b = sb.NewSQLBuilder(true) - for _, i := range ints { - d.Literal(b.Clear(), i) - dts.assertPreparedSQL(b, "?", []interface{}{int64(10)}) - } - d.Literal(b.Clear(), &i) - dts.assertPreparedSQL(b, "?", []interface{}{i}) -} - -func (dts *dialectTestSuite) TestLiteral_StringTypes() { - d := sqlDialect{dialect: "test", dialectOptions: DefaultDialectOptions()} - var str string - - b := sb.NewSQLBuilder(false) - d.Literal(b.Clear(), "Hello") - dts.assertNotPreparedSQL(b, "'Hello'") - - // should escape single quotes - d.Literal(b.Clear(), "Hello'") - dts.assertNotPreparedSQL(b, "'Hello'''") - - d.Literal(b.Clear(), &str) - dts.assertNotPreparedSQL(b, "''") - - b = sb.NewSQLBuilder(true) - d.Literal(b.Clear(), "Hello") - dts.assertPreparedSQL(b, "?", []interface{}{"Hello"}) - - // should escape single quotes - d.Literal(b.Clear(), "Hello'") - dts.assertPreparedSQL(b, "?", []interface{}{"Hello'"}) - - d.Literal(b.Clear(), &str) - dts.assertPreparedSQL(b, "?", []interface{}{str}) -} - -func (dts *dialectTestSuite) TestLiteral_BytesTypes() { - d := sqlDialect{dialect: "test", dialectOptions: DefaultDialectOptions()} - - b := sb.NewSQLBuilder(false) - d.Literal(b.Clear(), []byte("Hello")) - dts.assertNotPreparedSQL(b, "'Hello'") - - // should escape single quotes - d.Literal(b.Clear(), []byte("Hello'")) - dts.assertNotPreparedSQL(b, "'Hello'''") - - b = sb.NewSQLBuilder(true) - d.Literal(b.Clear(), []byte("Hello")) - dts.assertPreparedSQL(b, "?", []interface{}{[]byte("Hello")}) - - // should escape single quotes - d.Literal(b.Clear(), []byte("Hello'")) - dts.assertPreparedSQL(b, "?", []interface{}{[]byte("Hello'")}) -} - -func (dts *dialectTestSuite) TestLiteral_BoolTypes() { - var bl bool - - d := sqlDialect{dialect: "test", dialectOptions: DefaultDialectOptions()} - - b := sb.NewSQLBuilder(false) - d.Literal(b.Clear(), true) - dts.assertNotPreparedSQL(b, "TRUE") - - d.Literal(b.Clear(), false) - dts.assertNotPreparedSQL(b, "FALSE") - - d.Literal(b.Clear(), &bl) - dts.assertNotPreparedSQL(b, "FALSE") - - b = sb.NewSQLBuilder(true) - d.Literal(b.Clear(), true) - dts.assertPreparedSQL(b, "?", []interface{}{true}) - - d.Literal(b.Clear(), false) - dts.assertPreparedSQL(b, "?", []interface{}{false}) - - d.Literal(b.Clear(), &bl) - dts.assertPreparedSQL(b, "?", []interface{}{bl}) -} - -func (dts *dialectTestSuite) TestLiteral_TimeTypes() { - d := sqlDialect{dialect: "default", dialectOptions: DefaultDialectOptions()} - var nt *time.Time - asiaShanghai, err := time.LoadLocation("Asia/Shanghai") - dts.Require().NoError(err) - testDatas := []time.Time{ - time.Now().UTC(), - time.Now().In(asiaShanghai), - } - - for _, n := range testDatas { - var now = n - b := sb.NewSQLBuilder(false) - d.Literal(b.Clear(), now) - dts.assertNotPreparedSQL(b, "'"+now.Format(time.RFC3339Nano)+"'") - - d.Literal(b.Clear(), &now) - dts.assertNotPreparedSQL(b, "'"+now.Format(time.RFC3339Nano)+"'") - - d.Literal(b.Clear(), nt) - dts.assertNotPreparedSQL(b, "NULL") - - b = sb.NewSQLBuilder(true) - d.Literal(b.Clear(), now) - dts.assertPreparedSQL(b, "?", []interface{}{now}) - - d.Literal(b.Clear(), &now) - dts.assertPreparedSQL(b, "?", []interface{}{now}) - - d.Literal(b.Clear(), nt) - dts.assertPreparedSQL(b, "NULL", emptyArgs) - } -} - -func (dts *dialectTestSuite) TestLiteral_NilTypes() { - d := sqlDialect{dialect: "test", dialectOptions: DefaultDialectOptions()} - b := sb.NewSQLBuilder(false) - d.Literal(b.Clear(), nil) - dts.assertNotPreparedSQL(b, "NULL") - - b = sb.NewSQLBuilder(true) - d.Literal(b.Clear(), nil) - dts.assertPreparedSQL(b, "NULL", []interface{}{}) -} - -type datasetValuerType int64 - -func (j datasetValuerType) Value() (driver.Value, error) { - return []byte(fmt.Sprintf("Hello World %d", j)), nil -} - -func (dts *dialectTestSuite) TestLiteral_Valuer() { - b := sb.NewSQLBuilder(false) - d := sqlDialect{dialect: "test", dialectOptions: DefaultDialectOptions()} - - d.Literal(b.Clear(), datasetValuerType(10)) - dts.assertNotPreparedSQL(b, "'Hello World 10'") - - b = sb.NewSQLBuilder(true) - d.Literal(b.Clear(), datasetValuerType(10)) - dts.assertPreparedSQL(b, "?", []interface{}{[]byte("Hello World 10")}) -} - -func (dts *dialectTestSuite) TestLiteral_Slice() { - b := sb.NewSQLBuilder(false) - d := sqlDialect{dialect: "test", dialectOptions: DefaultDialectOptions()} - d.Literal(b.Clear(), []string{"a", "b", "c"}) - dts.assertNotPreparedSQL(b, `('a', 'b', 'c')`) - - b = sb.NewSQLBuilder(true) - d.Literal(b.Clear(), []string{"a", "b", "c"}) - dts.assertPreparedSQL(b, `(?, ?, ?)`, []interface{}{"a", "b", "c"}) -} - -type unknownExpression struct { -} - -func (ue unknownExpression) Expression() exp.Expression { - return ue -} -func (ue unknownExpression) Clone() exp.Expression { - return ue -} -func (dts *dialectTestSuite) TestLiteralUnsupportedExpression() { - d := sqlDialect{dialect: "test", dialectOptions: DefaultDialectOptions()} - b := sb.NewSQLBuilder(false) - d.Literal(b.Clear(), unknownExpression{}) - dts.assertErrorSQL(b, "goqu: unsupported expression type goqu.unknownExpression") -} - -func (dts *dialectTestSuite) TestLiteral_AppendableExpression() { - d := sqlDialect{dialect: "test", dialectOptions: DefaultDialectOptions()} - ti := exp.NewIdentifierExpression("", "b", "") - a := newTestAppendableExpression(`select * from "a"`, []interface{}{}, nil, nil) - aliasedA := newTestAppendableExpression(`select * from "a"`, []interface{}{}, nil, exp.NewSelectClauses().SetAlias(ti)) - argsA := newTestAppendableExpression(`select * from "a" where x=?`, []interface{}{true}, nil, exp.NewSelectClauses().SetAlias(ti)) - ae := newTestAppendableExpression(`select * from "a"`, emptyArgs, errors.New("expected error"), nil) - - b := sb.NewSQLBuilder(false) - d.Literal(b.Clear(), a) - dts.assertNotPreparedSQL(b, `(select * from "a")`) - - d.Literal(b.Clear(), aliasedA) - dts.assertNotPreparedSQL(b, `(select * from "a") AS "b"`) - - d.Literal(b.Clear(), ae) - dts.assertErrorSQL(b, "goqu: expected error") - - b = sb.NewSQLBuilder(true) - d.Literal(b.Clear(), a) - dts.assertPreparedSQL(b, `(select * from "a")`, emptyArgs) - - d.Literal(b.Clear(), aliasedA) - dts.assertPreparedSQL(b, `(select * from "a") AS "b"`, emptyArgs) - - d.Literal(b.Clear(), argsA) - dts.assertPreparedSQL(b, `(select * from "a" where x=?) AS "b"`, []interface{}{true}) -} - -func (dts *dialectTestSuite) TestLiteral_ColumnList() { - d := sqlDialect{dialect: "test", dialectOptions: DefaultDialectOptions()} - - b := sb.NewSQLBuilder(false) - d.Literal(b.Clear(), exp.NewColumnListExpression("a", exp.NewLiteralExpression("true"))) - dts.assertNotPreparedSQL(b, `"a", true`) - - b = sb.NewSQLBuilder(true) - d.Literal(b.Clear(), exp.NewColumnListExpression("a", exp.NewLiteralExpression("true"))) - dts.assertPreparedSQL(b, `"a", true`, emptyArgs) -} - -func (dts *dialectTestSuite) TestLiteral_ExpressionList() { - d := sqlDialect{dialect: "test", dialectOptions: DefaultDialectOptions()} - - b := sb.NewSQLBuilder(false) - d.Literal(b.Clear(), exp.NewExpressionList( - exp.AndType, - exp.NewIdentifierExpression("", "", "a").Eq("b"), - exp.NewIdentifierExpression("", "", "c").Neq(1), - )) - dts.assertNotPreparedSQL(b, `(("a" = 'b') AND ("c" != 1))`) - - d.Literal(b.Clear(), exp.NewExpressionList( - exp.OrType, - exp.NewIdentifierExpression("", "", "a").Eq("b"), - exp.NewIdentifierExpression("", "", "c").Neq(1), - )) - dts.assertNotPreparedSQL(b, `(("a" = 'b') OR ("c" != 1))`) - - d.Literal(b.Clear(), exp.NewExpressionList(exp.OrType, - exp.NewIdentifierExpression("", "", "a").Eq("b"), - exp.NewExpressionList(exp.AndType, - exp.NewIdentifierExpression("", "", "c").Neq(1), - exp.NewIdentifierExpression("", "", "d").Eq(exp.NewLiteralExpression("NOW()")), - ), - )) - dts.assertNotPreparedSQL(b, `(("a" = 'b') OR (("c" != 1) AND ("d" = NOW())))`) - - b = sb.NewSQLBuilder(true) - d.Literal(b.Clear(), exp.NewExpressionList( - exp.AndType, - exp.NewIdentifierExpression("", "", "a").Eq("b"), - exp.NewIdentifierExpression("", "", "c").Neq(1), - )) - dts.assertPreparedSQL(b, `(("a" = ?) AND ("c" != ?))`, []interface{}{"b", int64(1)}) - - d.Literal(b.Clear(), exp.NewExpressionList( - exp.OrType, - exp.NewIdentifierExpression("", "", "a").Eq("b"), - exp.NewIdentifierExpression("", "", "c").Neq(1)), - ) - dts.assertPreparedSQL(b, `(("a" = ?) OR ("c" != ?))`, []interface{}{"b", int64(1)}) - - d.Literal(b.Clear(), exp.NewExpressionList( - exp.OrType, - exp.NewIdentifierExpression("", "", "a").Eq("b"), - exp.NewExpressionList( - exp.AndType, - exp.NewIdentifierExpression("", "", "c").Neq(1), - exp.NewIdentifierExpression("", "", "d").Eq(exp.NewLiteralExpression("NOW()")), - ), - )) - dts.assertPreparedSQL(b, `(("a" = ?) OR (("c" != ?) AND ("d" = NOW())))`, []interface{}{"b", int64(1)}) -} - -func (dts *dialectTestSuite) TestLiteral_LiteralExpression() { - d := sqlDialect{dialect: "test", dialectOptions: DefaultDialectOptions()} - - b := sb.NewSQLBuilder(false) - d.Literal(b.Clear(), exp.NewLiteralExpression(`"b"::DATE = '2010-09-02'`)) - dts.assertNotPreparedSQL(b, `"b"::DATE = '2010-09-02'`) - - d.Literal(b.Clear(), exp.NewLiteralExpression( - `"b" = ? or "c" = ? or d IN ?`, - "a", 1, []int{1, 2, 3, 4}), - ) - dts.assertNotPreparedSQL(b, `"b" = 'a' or "c" = 1 or d IN (1, 2, 3, 4)`) - - b = sb.NewSQLBuilder(true) - d.Literal(b.Clear(), exp.NewLiteralExpression(`"b"::DATE = '2010-09-02'`)) - dts.assertPreparedSQL(b, `"b"::DATE = '2010-09-02'`, emptyArgs) - - d.Literal(b.Clear(), exp.NewLiteralExpression( - `"b" = ? or "c" = ? or d IN ?`, - "a", 1, []int{1, 2, 3, 4}, - )) - dts.assertPreparedSQL(b, `"b" = ? or "c" = ? or d IN (?, ?, ?, ?)`, []interface{}{ - "a", - int64(1), - int64(1), - int64(2), - int64(3), - int64(4), - }) -} - -func (dts *dialectTestSuite) TestLiteral_AliasedExpression() { - d := sqlDialect{dialect: "test", dialectOptions: DefaultDialectOptions()} - - b := sb.NewSQLBuilder(false) - d.Literal(b.Clear(), exp.NewIdentifierExpression("", "", "a").As("b")) - dts.assertNotPreparedSQL(b, `"a" AS "b"`) - - d.Literal(b.Clear(), exp.NewLiteralExpression("count(*)").As("count")) - dts.assertNotPreparedSQL(b, `count(*) AS "count"`) - - d.Literal(b.Clear(), exp.NewIdentifierExpression("", "", "a"). - As(exp.NewIdentifierExpression("", "", "b"))) - dts.assertNotPreparedSQL(b, `"a" AS "b"`) - - b = sb.NewSQLBuilder(true) - d.Literal(b.Clear(), exp.NewIdentifierExpression("", "", "a").As("b")) - dts.assertPreparedSQL(b, `"a" AS "b"`, emptyArgs) - - d.Literal(b.Clear(), exp.NewLiteralExpression("count(*)").As("count")) - dts.assertPreparedSQL(b, `count(*) AS "count"`, emptyArgs) -} - -func (dts *dialectTestSuite) TestLiteral_BooleanExpression() { - d := sqlDialect{dialect: "test", dialectOptions: DefaultDialectOptions()} - - ae := newTestAppendableExpression(`SELECT "id" FROM "test2"`, emptyArgs, nil, nil) - ident := exp.NewIdentifierExpression("", "", "a") - b := sb.NewSQLBuilder(false) - - d.Literal(b.Clear(), ident.Eq(1)) - dts.assertNotPreparedSQL(b, `("a" = 1)`) - - d.Literal(b.Clear(), ident.Eq(true)) - dts.assertNotPreparedSQL(b, `("a" IS TRUE)`) - - d.Literal(b.Clear(), ident.Eq(false)) - dts.assertNotPreparedSQL(b, `("a" IS FALSE)`) - - d.Literal(b.Clear(), ident.Eq(nil)) - dts.assertNotPreparedSQL(b, `("a" IS NULL)`) - - d.Literal(b.Clear(), ident.Eq([]int64{1, 2, 3})) - dts.assertNotPreparedSQL(b, `("a" IN (1, 2, 3))`) - - d.Literal(b.Clear(), ident.Eq(ae)) - dts.assertNotPreparedSQL(b, `("a" IN (SELECT "id" FROM "test2"))`) - - d.Literal(b.Clear(), ident.Neq(1)) - dts.assertNotPreparedSQL(b, `("a" != 1)`) - - d.Literal(b.Clear(), ident.Neq(true)) - dts.assertNotPreparedSQL(b, `("a" IS NOT TRUE)`) - - d.Literal(b.Clear(), ident.Neq(false)) - dts.assertNotPreparedSQL(b, `("a" IS NOT FALSE)`) - - d.Literal(b.Clear(), ident.Neq(nil)) - dts.assertNotPreparedSQL(b, `("a" IS NOT NULL)`) - - d.Literal(b.Clear(), ident.Neq([]int64{1, 2, 3})) - dts.assertNotPreparedSQL(b, `("a" NOT IN (1, 2, 3))`) - - d.Literal(b.Clear(), ident.Neq(ae)) - dts.assertNotPreparedSQL(b, `("a" NOT IN (SELECT "id" FROM "test2"))`) - - d.Literal(b.Clear(), ident.Is(nil)) - dts.assertNotPreparedSQL(b, `("a" IS NULL)`) - - d.Literal(b.Clear(), ident.Is(false)) - dts.assertNotPreparedSQL(b, `("a" IS FALSE)`) - - d.Literal(b.Clear(), ident.Is(true)) - dts.assertNotPreparedSQL(b, `("a" IS TRUE)`) - - d.Literal(b.Clear(), ident.IsNot(nil)) - dts.assertNotPreparedSQL(b, `("a" IS NOT NULL)`) - - d.Literal(b.Clear(), ident.IsNot(false)) - dts.assertNotPreparedSQL(b, `("a" IS NOT FALSE)`) - - d.Literal(b.Clear(), ident.IsNot(true)) - dts.assertNotPreparedSQL(b, `("a" IS NOT TRUE)`) - - d.Literal(b.Clear(), ident.Gt(1)) - dts.assertNotPreparedSQL(b, `("a" > 1)`) - - d.Literal(b.Clear(), ident.Gte(1)) - dts.assertNotPreparedSQL(b, `("a" >= 1)`) - - d.Literal(b.Clear(), ident.Lt(1)) - dts.assertNotPreparedSQL(b, `("a" < 1)`) - - d.Literal(b.Clear(), ident.Lte(1)) - dts.assertNotPreparedSQL(b, `("a" <= 1)`) - - d.Literal(b.Clear(), ident.In([]int{1, 2, 3})) - dts.assertNotPreparedSQL(b, `("a" IN (1, 2, 3))`) - - d.Literal(b.Clear(), ident.NotIn([]int{1, 2, 3})) - dts.assertNotPreparedSQL(b, `("a" NOT IN (1, 2, 3))`) - - d.Literal(b.Clear(), ident.Like("a%")) - dts.assertNotPreparedSQL(b, `("a" LIKE 'a%')`) - - d.Literal(b.Clear(), ident. - Like(regexp.MustCompile("(a|b)"))) - dts.assertNotPreparedSQL(b, `("a" ~ '(a|b)')`) - - d.Literal(b.Clear(), ident.NotLike("a%")) - dts.assertNotPreparedSQL(b, `("a" NOT LIKE 'a%')`) - - d.Literal(b.Clear(), ident. - NotLike(regexp.MustCompile("(a|b)"))) - dts.assertNotPreparedSQL(b, `("a" !~ '(a|b)')`) - - d.Literal(b.Clear(), ident.ILike("a%")) - dts.assertNotPreparedSQL(b, `("a" ILIKE 'a%')`) - - d.Literal(b.Clear(), ident. - ILike(regexp.MustCompile("(a|b)"))) - dts.assertNotPreparedSQL(b, `("a" ~* '(a|b)')`) - - d.Literal(b.Clear(), ident.NotILike("a%")) - dts.assertNotPreparedSQL(b, `("a" NOT ILIKE 'a%')`) - - d.Literal(b.Clear(), ident. - NotILike(regexp.MustCompile("(a|b)"))) - dts.assertNotPreparedSQL(b, `("a" !~* '(a|b)')`) - - b = sb.NewSQLBuilder(true) - d.Literal(b.Clear(), ident.Eq(1)) - dts.assertPreparedSQL(b, `("a" = ?)`, []interface{}{int64(1)}) - - d.Literal(b.Clear(), ident.Eq(true)) - dts.assertPreparedSQL(b, `("a" IS TRUE)`, []interface{}{}) - - d.Literal(b.Clear(), ident.Eq(false)) - dts.assertPreparedSQL(b, `("a" IS FALSE)`, emptyArgs) - - d.Literal(b.Clear(), ident.Eq(nil)) - dts.assertPreparedSQL(b, `("a" IS NULL)`, emptyArgs) - - d.Literal(b.Clear(), ident.Eq([]int64{1, 2, 3})) - dts.assertPreparedSQL(b, `("a" IN (?, ?, ?))`, []interface{}{int64(1), int64(2), int64(3)}) - - d.Literal(b.Clear(), ident.Neq(1)) - dts.assertPreparedSQL(b, `("a" != ?)`, []interface{}{int64(1)}) - - d.Literal(b.Clear(), ident.Neq(true)) - dts.assertPreparedSQL(b, `("a" IS NOT TRUE)`, emptyArgs) - - d.Literal(b.Clear(), ident.Neq(false)) - dts.assertPreparedSQL(b, `("a" IS NOT FALSE)`, emptyArgs) - - d.Literal(b.Clear(), ident.Neq(nil)) - dts.assertPreparedSQL(b, `("a" IS NOT NULL)`, emptyArgs) - - d.Literal(b.Clear(), ident.Neq([]int64{1, 2, 3})) - dts.assertPreparedSQL(b, `("a" NOT IN (?, ?, ?))`, []interface{}{int64(1), int64(2), int64(3)}) - - d.Literal(b.Clear(), ident.Is(nil)) - dts.assertPreparedSQL(b, `("a" IS NULL)`, emptyArgs) - - d.Literal(b.Clear(), ident.Is(false)) - dts.assertPreparedSQL(b, `("a" IS FALSE)`, emptyArgs) - - d.Literal(b.Clear(), ident.Is(true)) - dts.assertPreparedSQL(b, `("a" IS TRUE)`, emptyArgs) - - d.Literal(b.Clear(), ident.IsNot(nil)) - dts.assertPreparedSQL(b, `("a" IS NOT NULL)`, emptyArgs) - - d.Literal(b.Clear(), ident.IsNot(false)) - dts.assertPreparedSQL(b, `("a" IS NOT FALSE)`, emptyArgs) - - d.Literal(b.Clear(), ident.IsNot(true)) - dts.assertPreparedSQL(b, `("a" IS NOT TRUE)`, emptyArgs) - - d.Literal(b.Clear(), ident.Gt(1)) - dts.assertPreparedSQL(b, `("a" > ?)`, []interface{}{int64(1)}) - - d.Literal(b.Clear(), ident.Gte(1)) - dts.assertPreparedSQL(b, `("a" >= ?)`, []interface{}{int64(1)}) - - d.Literal(b.Clear(), ident.Lt(1)) - dts.assertPreparedSQL(b, `("a" < ?)`, []interface{}{int64(1)}) - - d.Literal(b.Clear(), ident.Lte(1)) - dts.assertPreparedSQL(b, `("a" <= ?)`, []interface{}{int64(1)}) - - d.Literal(b.Clear(), ident.In([]int{1, 2, 3})) - dts.assertPreparedSQL(b, `("a" IN (?, ?, ?))`, []interface{}{int64(1), int64(2), int64(3)}) - - d.Literal(b.Clear(), ident.NotIn([]int{1, 2, 3})) - dts.assertPreparedSQL(b, `("a" NOT IN (?, ?, ?))`, []interface{}{int64(1), int64(2), int64(3)}) - - d.Literal(b.Clear(), ident.Like("a%")) - dts.assertPreparedSQL(b, `("a" LIKE ?)`, []interface{}{"a%"}) - - d.Literal(b.Clear(), ident. - Like(regexp.MustCompile("(a|b)"))) - dts.assertPreparedSQL(b, `("a" ~ ?)`, []interface{}{"(a|b)"}) - - d.Literal(b.Clear(), ident.NotLike("a%")) - dts.assertPreparedSQL(b, `("a" NOT LIKE ?)`, []interface{}{"a%"}) - - d.Literal(b.Clear(), ident. - NotLike(regexp.MustCompile("(a|b)"))) - dts.assertPreparedSQL(b, `("a" !~ ?)`, []interface{}{"(a|b)"}) - - d.Literal(b.Clear(), ident.ILike("a%")) - dts.assertPreparedSQL(b, `("a" ILIKE ?)`, []interface{}{"a%"}) - - d.Literal(b.Clear(), ident. - ILike(regexp.MustCompile("(a|b)"))) - dts.assertPreparedSQL(b, `("a" ~* ?)`, []interface{}{"(a|b)"}) - - d.Literal(b.Clear(), ident.NotILike("a%")) - dts.assertPreparedSQL(b, `("a" NOT ILIKE ?)`, []interface{}{"a%"}) - - d.Literal(b.Clear(), ident. - NotILike(regexp.MustCompile("(a|b)"))) - dts.assertPreparedSQL(b, `("a" !~* ?)`, []interface{}{"(a|b)"}) - - // test unsupported op - opts := DefaultDialectOptions() - opts.BooleanOperatorLookup = map[exp.BooleanOperation][]byte{} - d = sqlDialect{dialect: "test", dialectOptions: opts} - b = sb.NewSQLBuilder(false) - d.Literal(b.Clear(), ident.Eq(1)) - dts.assertErrorSQL(b, "goqu: boolean operator 'eq' not supported") - d.Literal(b.Clear(), ident.Neq(1)) - dts.assertErrorSQL(b, "goqu: boolean operator 'neq' not supported") - d.Literal(b.Clear(), ident.Is(true)) - dts.assertErrorSQL(b, "goqu: boolean operator 'is' not supported") - d.Literal(b.Clear(), ident.IsNot(true)) - dts.assertErrorSQL(b, "goqu: boolean operator 'isnot' not supported") - d.Literal(b.Clear(), ident.Gt(1)) - dts.assertErrorSQL(b, "goqu: boolean operator 'gt' not supported") - d.Literal(b.Clear(), ident.Gte(1)) - dts.assertErrorSQL(b, "goqu: boolean operator 'gte' not supported") - d.Literal(b.Clear(), ident.Lt(1)) - dts.assertErrorSQL(b, "goqu: boolean operator 'lt' not supported") - d.Literal(b.Clear(), ident.Lte(1)) - dts.assertErrorSQL(b, "goqu: boolean operator 'lte' not supported") - d.Literal(b.Clear(), ident.In(1, 2, 3)) - dts.assertErrorSQL(b, "goqu: boolean operator 'in' not supported") - d.Literal(b.Clear(), ident.NotIn(1, 2, 3)) - dts.assertErrorSQL(b, "goqu: boolean operator 'notin' not supported") - d.Literal(b.Clear(), ident.Like("a%")) - dts.assertErrorSQL(b, "goqu: boolean operator 'like' not supported") - d.Literal(b.Clear(), ident.NotLike("a%")) - dts.assertErrorSQL(b, "goqu: boolean operator 'notlike' not supported") - d.Literal(b.Clear(), ident.ILike("a%")) - dts.assertErrorSQL(b, "goqu: boolean operator 'ilike' not supported") - d.Literal(b.Clear(), ident.NotILike("a%")) - dts.assertErrorSQL(b, "goqu: boolean operator 'notilike' not supported") - d.Literal(b.Clear(), ident.Like(regexp.MustCompile("(a|b)"))) - dts.assertErrorSQL(b, "goqu: boolean operator 'regexp like' not supported") - d.Literal(b.Clear(), ident.NotLike(regexp.MustCompile("(a|b)"))) - dts.assertErrorSQL(b, "goqu: boolean operator 'regexp notlike' not supported") - d.Literal(b.Clear(), ident.ILike(regexp.MustCompile("(a|b)"))) - dts.assertErrorSQL(b, "goqu: boolean operator 'regexp ilike' not supported") - d.Literal(b.Clear(), ident.NotILike(regexp.MustCompile("(a|b)"))) - dts.assertErrorSQL(b, "goqu: boolean operator 'regexp notilike' not supported") -} - -func (dts *dialectTestSuite) TestLiteral_RangeExpression() { - d := sqlDialect{dialect: "test", dialectOptions: DefaultDialectOptions()} - - b := sb.NewSQLBuilder(false) - d.Literal(b.Clear(), exp.NewIdentifierExpression("", "", "a"). - Between(exp.NewRangeVal(1, 2))) - dts.assertNotPreparedSQL(b, `("a" BETWEEN 1 AND 2)`) - - d.Literal(b.Clear(), exp.NewIdentifierExpression("", "", "a"). - NotBetween(exp.NewRangeVal(1, 2))) - dts.assertNotPreparedSQL(b, `("a" NOT BETWEEN 1 AND 2)`) - - d.Literal(b.Clear(), exp.NewIdentifierExpression("", "", "a"). - Between(exp.NewRangeVal("aaa", "zzz"))) - dts.assertNotPreparedSQL(b, `("a" BETWEEN 'aaa' AND 'zzz')`) - - b = sb.NewSQLBuilder(true) - d.Literal(b.Clear(), exp.NewIdentifierExpression("", "", "a"). - Between(exp.NewRangeVal(1, 2))) - dts.assertPreparedSQL(b, `("a" BETWEEN ? AND ?)`, []interface{}{int64(1), int64(2)}) - - d.Literal(b.Clear(), exp.NewIdentifierExpression("", "", "a"). - NotBetween(exp.NewRangeVal(1, 2))) - dts.assertPreparedSQL(b, `("a" NOT BETWEEN ? AND ?)`, []interface{}{int64(1), int64(2)}) - - d.Literal(b.Clear(), exp.NewIdentifierExpression("", "", "a"). - Between(exp.NewRangeVal("aaa", "zzz"))) - dts.assertPreparedSQL(b, `("a" BETWEEN ? AND ?)`, []interface{}{"aaa", "zzz"}) -} - -func (dts *dialectTestSuite) TestLiteral_OrderedExpression() { - d := sqlDialect{dialect: "test", dialectOptions: DefaultDialectOptions()} - - b := sb.NewSQLBuilder(false) - d.Literal(b.Clear(), exp.NewIdentifierExpression("", "", "a").Asc()) - dts.assertNotPreparedSQL(b, `"a" ASC`) - - d.Literal(b.Clear(), exp.NewIdentifierExpression("", "", "a").Desc()) - dts.assertNotPreparedSQL(b, `"a" DESC`) - - d.Literal(b.Clear(), exp.NewIdentifierExpression("", "", "a").Asc().NullsLast()) - dts.assertNotPreparedSQL(b, `"a" ASC NULLS LAST`) - - d.Literal(b.Clear(), exp.NewIdentifierExpression("", "", "a").Desc().NullsLast()) - dts.assertNotPreparedSQL(b, `"a" DESC NULLS LAST`) - - d.Literal(b.Clear(), exp.NewIdentifierExpression("", "", "a").Asc().NullsFirst()) - dts.assertNotPreparedSQL(b, `"a" ASC NULLS FIRST`) - - d.Literal(b.Clear(), exp.NewIdentifierExpression("", "", "a").Desc().NullsFirst()) - dts.assertNotPreparedSQL(b, `"a" DESC NULLS FIRST`) - - b = sb.NewSQLBuilder(true) - d.Literal(b.Clear(), exp.NewIdentifierExpression("", "", "a").Asc()) - dts.assertPreparedSQL(b, `"a" ASC`, emptyArgs) - - d.Literal(b.Clear(), exp.NewIdentifierExpression("", "", "a").Desc()) - dts.assertPreparedSQL(b, `"a" DESC`, emptyArgs) - - d.Literal(b.Clear(), exp.NewIdentifierExpression("", "", "a").Asc().NullsLast()) - dts.assertPreparedSQL(b, `"a" ASC NULLS LAST`, emptyArgs) - - d.Literal(b.Clear(), exp.NewIdentifierExpression("", "", "a").Desc().NullsLast()) - dts.assertPreparedSQL(b, `"a" DESC NULLS LAST`, emptyArgs) - - d.Literal(b.Clear(), exp.NewIdentifierExpression("", "", "a").Asc().NullsFirst()) - dts.assertPreparedSQL(b, `"a" ASC NULLS FIRST`, emptyArgs) - - d.Literal(b.Clear(), exp.NewIdentifierExpression("", "", "a").Desc().NullsFirst()) - dts.assertPreparedSQL(b, `"a" DESC NULLS FIRST`, emptyArgs) -} - -func (dts *dialectTestSuite) TestLiteral_UpdateExpression() { - d := sqlDialect{dialect: "test", dialectOptions: DefaultDialectOptions()} - - b := sb.NewSQLBuilder(false) - d.Literal(b.Clear(), exp.NewIdentifierExpression("", "", "a").Set(1)) - dts.assertNotPreparedSQL(b, `"a"=1`) - - b = sb.NewSQLBuilder(true) - d.Literal(b.Clear(), exp.NewIdentifierExpression("", "", "a").Set(1)) - dts.assertPreparedSQL(b, `"a"=?`, []interface{}{int64(1)}) -} - -func (dts *dialectTestSuite) TestLiteral_SQLFunctionExpression() { - d := sqlDialect{dialect: "test", dialectOptions: DefaultDialectOptions()} - - b := sb.NewSQLBuilder(false) - d.Literal(b.Clear(), exp.NewSQLFunctionExpression("MIN", exp.NewIdentifierExpression("", "", "a"))) - dts.assertNotPreparedSQL(b, `MIN("a")`) - - d.Literal(b.Clear(), exp.NewSQLFunctionExpression("COALESCE", exp.NewIdentifierExpression("", "", "a"), "a")) - dts.assertNotPreparedSQL(b, `COALESCE("a", 'a')`) - - b = sb.NewSQLBuilder(true) - d.Literal(b.Clear(), exp.NewSQLFunctionExpression("MIN", exp.NewIdentifierExpression("", "", "a"))) - dts.assertNotPreparedSQL(b, `MIN("a")`) - - d.Literal(b.Clear(), exp.NewSQLFunctionExpression("COALESCE", exp.NewIdentifierExpression("", "", "a"), "a")) - dts.assertPreparedSQL(b, `COALESCE("a", ?)`, []interface{}{"a"}) - -} - -func (dts *dialectTestSuite) TestLiteral_CastExpression() { - d := sqlDialect{dialect: "test", dialectOptions: DefaultDialectOptions()} - b := sb.NewSQLBuilder(false) - d.Literal(b.Clear(), exp.NewIdentifierExpression("", "", "a").Cast("DATE")) - dts.assertNotPreparedSQL(b, `CAST("a" AS DATE)`) - - b = sb.NewSQLBuilder(true) - d.Literal(b.Clear(), exp.NewIdentifierExpression("", "", "a").Cast("DATE")) - dts.assertPreparedSQL(b, `CAST("a" AS DATE)`, emptyArgs) -} - -func (dts *dialectTestSuite) TestLiteral_CommonTableExpression() { - d := sqlDialect{dialect: "test", dialectOptions: DefaultDialectOptions()} - ae := newTestAppendableExpression(`SELECT * FROM "b"`, emptyArgs, nil, nil) - b := sb.NewSQLBuilder(false) - d.Literal(b.Clear(), exp.NewCommonTableExpression(false, "a", ae)) - dts.assertNotPreparedSQL(b, `a AS (SELECT * FROM "b")`) - - d.Literal(b.Clear(), exp.NewCommonTableExpression(false, "a(x,y)", ae)) - dts.assertNotPreparedSQL(b, `a(x,y) AS (SELECT * FROM "b")`) - - d.Literal(b.Clear(), exp.NewCommonTableExpression(true, "a", ae)) - dts.assertNotPreparedSQL(b, `a AS (SELECT * FROM "b")`) - - d.Literal(b.Clear(), exp.NewCommonTableExpression(true, "a(x,y)", ae)) - dts.assertNotPreparedSQL(b, `a(x,y) AS (SELECT * FROM "b")`) - - b = sb.NewSQLBuilder(true) - d.Literal(b.Clear(), exp.NewCommonTableExpression(false, "a", ae)) - dts.assertPreparedSQL(b, `a AS (SELECT * FROM "b")`, emptyArgs) - - d.Literal(b.Clear(), exp.NewCommonTableExpression(false, "a(x,y)", ae)) - dts.assertPreparedSQL(b, `a(x,y) AS (SELECT * FROM "b")`, emptyArgs) - - d.Literal(b.Clear(), exp.NewCommonTableExpression(true, "a", ae)) - dts.assertPreparedSQL(b, `a AS (SELECT * FROM "b")`, emptyArgs) - - d.Literal(b.Clear(), exp.NewCommonTableExpression(true, "a(x,y)", ae)) - dts.assertPreparedSQL(b, `a(x,y) AS (SELECT * FROM "b")`, emptyArgs) -} - -func (dts *dialectTestSuite) TestLiteral_CompoundExpression() { - ae := newTestAppendableExpression(`SELECT * FROM "b"`, emptyArgs, nil, nil) - - b := sb.NewSQLBuilder(false) - d := sqlDialect{dialect: "test", dialectOptions: DefaultDialectOptions()} - - d.Literal(b.Clear(), exp.NewCompoundExpression(exp.UnionCompoundType, ae)) - dts.assertNotPreparedSQL(b, ` UNION (SELECT * FROM "b")`) - - d.Literal(b.Clear(), exp.NewCompoundExpression(exp.UnionAllCompoundType, ae)) - dts.assertNotPreparedSQL(b, ` UNION ALL (SELECT * FROM "b")`) - - d.Literal(b.Clear(), exp.NewCompoundExpression(exp.IntersectCompoundType, ae)) - dts.assertNotPreparedSQL(b, ` INTERSECT (SELECT * FROM "b")`) - - d.Literal(b.Clear(), exp.NewCompoundExpression(exp.IntersectAllCompoundType, ae)) - dts.assertNotPreparedSQL(b, ` INTERSECT ALL (SELECT * FROM "b")`) - - b = sb.NewSQLBuilder(true) - d.Literal(b.Clear(), exp.NewCompoundExpression(exp.UnionCompoundType, ae)) - dts.assertNotPreparedSQL(b, ` UNION (SELECT * FROM "b")`) - - d.Literal(b.Clear(), exp.NewCompoundExpression(exp.UnionAllCompoundType, ae)) - dts.assertNotPreparedSQL(b, ` UNION ALL (SELECT * FROM "b")`) - - d.Literal(b.Clear(), exp.NewCompoundExpression(exp.IntersectCompoundType, ae)) - dts.assertNotPreparedSQL(b, ` INTERSECT (SELECT * FROM "b")`) - - d.Literal(b.Clear(), exp.NewCompoundExpression(exp.IntersectAllCompoundType, ae)) - dts.assertNotPreparedSQL(b, ` INTERSECT ALL (SELECT * FROM "b")`) -} - -func (dts *dialectTestSuite) TestLiteral_IdentifierExpression() { - d := sqlDialect{dialect: "test", dialectOptions: DefaultDialectOptions()} - - b := sb.NewSQLBuilder(false) - d.Literal(b.Clear(), exp.NewIdentifierExpression("", "", "")) - dts.assertErrorSQL(b, `goqu: a empty identifier was encountered, please specify a "schema", "table" or "column"`) - - d.Literal(b.Clear(), exp.NewIdentifierExpression("", "", nil)) - dts.assertErrorSQL(b, `goqu: a empty identifier was encountered, please specify a "schema", "table" or "column"`) - - d.Literal(b.Clear(), exp.NewIdentifierExpression("", "", false)) - dts.assertErrorSQL(b, `goqu: unexpected col type must be string or LiteralExpression received bool`) - - d.Literal(b.Clear(), exp.NewIdentifierExpression("", "", "col")) - dts.assertNotPreparedSQL(b, `"col"`) - - d.Literal(b.Clear(), exp.ParseIdentifier("table.col")) - dts.assertNotPreparedSQL(b, `"table"."col"`) - - d.Literal(b.Clear(), exp.NewIdentifierExpression("", "", "col").Table("table")) - dts.assertNotPreparedSQL(b, `"table"."col"`) - - d.Literal(b.Clear(), exp.NewIdentifierExpression("", "table", "col")) - dts.assertNotPreparedSQL(b, `"table"."col"`) - - d.Literal(b.Clear(), exp.ParseIdentifier("a.b.c")) - dts.assertNotPreparedSQL(b, `"a"."b"."c"`) - - d.Literal(b.Clear(), exp.NewIdentifierExpression("schema", "table", "col")) - dts.assertNotPreparedSQL(b, `"schema"."table"."col"`) - - d.Literal(b.Clear(), exp.ParseIdentifier("schema.table.*")) - dts.assertNotPreparedSQL(b, `"schema"."table".*`) - - d.Literal(b.Clear(), exp.ParseIdentifier("table.*")) - dts.assertNotPreparedSQL(b, `"table".*`) - - b = sb.NewSQLBuilder(true) - d.Literal(b.Clear(), exp.NewIdentifierExpression("", "", "col")) - dts.assertNotPreparedSQL(b, `"col"`) - - d.Literal(b.Clear(), exp.ParseIdentifier("table.col")) - dts.assertNotPreparedSQL(b, `"table"."col"`) - - d.Literal(b.Clear(), exp.NewIdentifierExpression("", "", "col").Table("table")) - dts.assertNotPreparedSQL(b, `"table"."col"`) - - d.Literal(b.Clear(), exp.NewIdentifierExpression("", "table", "col")) - dts.assertNotPreparedSQL(b, `"table"."col"`) - - d.Literal(b.Clear(), exp.ParseIdentifier("a.b.c")) - dts.assertNotPreparedSQL(b, `"a"."b"."c"`) - - d.Literal(b.Clear(), exp.NewIdentifierExpression("schema", "table", "col")) - dts.assertNotPreparedSQL(b, `"schema"."table"."col"`) - - d.Literal(b.Clear(), exp.ParseIdentifier("schema.table.*")) - dts.assertNotPreparedSQL(b, `"schema"."table".*`) - - d.Literal(b.Clear(), exp.ParseIdentifier("table.*")) - dts.assertNotPreparedSQL(b, `"table".*`) -} - -func (dts *dialectTestSuite) TestLiteral_ExpressionMap() { - d := sqlDialect{dialect: "test", dialectOptions: DefaultDialectOptions()} - - b := sb.NewSQLBuilder(false) - d.Literal(b.Clear(), exp.Ex{"a": 1}) - dts.assertNotPreparedSQL(b, `("a" = 1)`) - - d.Literal(b.Clear(), exp.Ex{}) - dts.assertNotPreparedSQL(b, ``) - - d.Literal(b.Clear(), exp.Ex{"a": true}) - dts.assertNotPreparedSQL(b, `("a" IS TRUE)`) - - d.Literal(b.Clear(), exp.Ex{"a": false}) - dts.assertNotPreparedSQL(b, `("a" IS FALSE)`) - - d.Literal(b.Clear(), exp.Ex{"a": nil}) - dts.assertNotPreparedSQL(b, `("a" IS NULL)`) - - d.Literal(b.Clear(), exp.Ex{"a": []string{"a", "b", "c"}}) - dts.assertNotPreparedSQL(b, `("a" IN ('a', 'b', 'c'))`) - - d.Literal(b.Clear(), exp.Ex{"a": exp.Op{"neq": 1}}) - dts.assertNotPreparedSQL(b, `("a" != 1)`) - - d.Literal(b.Clear(), exp.Ex{"a": exp.Op{"isnot": true}}) - dts.assertNotPreparedSQL(b, `("a" IS NOT TRUE)`) - - d.Literal(b.Clear(), exp.Ex{"a": exp.Op{"gt": 1}}) - dts.assertNotPreparedSQL(b, `("a" > 1)`) - - d.Literal(b.Clear(), exp.Ex{"a": exp.Op{"gte": 1}}) - dts.assertNotPreparedSQL(b, `("a" >= 1)`) - - d.Literal(b.Clear(), exp.Ex{"a": exp.Op{"lt": 1}}) - dts.assertNotPreparedSQL(b, `("a" < 1)`) - - d.Literal(b.Clear(), exp.Ex{"a": exp.Op{"lte": 1}}) - dts.assertNotPreparedSQL(b, `("a" <= 1)`) - - d.Literal(b.Clear(), exp.Ex{"a": exp.Op{"like": "a%"}}) - dts.assertNotPreparedSQL(b, `("a" LIKE 'a%')`) - - d.Literal(b.Clear(), exp.Ex{"a": exp.Op{"notLike": "a%"}}) - dts.assertNotPreparedSQL(b, `("a" NOT LIKE 'a%')`) - - d.Literal(b.Clear(), exp.Ex{"a": exp.Op{"notLike": "a%"}}) - dts.assertNotPreparedSQL(b, `("a" NOT LIKE 'a%')`) - - d.Literal(b.Clear(), exp.Ex{"a": exp.Op{"in": []string{"a", "b", "c"}}}) - dts.assertNotPreparedSQL(b, `("a" IN ('a', 'b', 'c'))`) - - d.Literal(b.Clear(), exp.Ex{"a": exp.Op{"notIn": []string{"a", "b", "c"}}}) - dts.assertNotPreparedSQL(b, `("a" NOT IN ('a', 'b', 'c'))`) - - d.Literal(b.Clear(), exp.Ex{"a": exp.Op{"is": nil, "eq": 10}}) - dts.assertNotPreparedSQL(b, `(("a" = 10) OR ("a" IS NULL))`) - - d.Literal(b.Clear(), exp.Ex{"a": exp.Op{"between": exp.NewRangeVal(1, 10)}}) - dts.assertNotPreparedSQL(b, `("a" BETWEEN 1 AND 10)`) - - d.Literal(b.Clear(), exp.Ex{"a": exp.Op{"notbetween": exp.NewRangeVal(1, 10)}}) - dts.assertNotPreparedSQL(b, `("a" NOT BETWEEN 1 AND 10)`) - - b = sb.NewSQLBuilder(true) - d.Literal(b.Clear(), exp.Ex{"a": 1}) - dts.assertPreparedSQL(b, `("a" = ?)`, []interface{}{int64(1)}) - - d.Literal(b.Clear(), exp.Ex{"a": true}) - dts.assertPreparedSQL(b, `("a" IS TRUE)`, emptyArgs) - - d.Literal(b.Clear(), exp.Ex{"a": false}) - dts.assertPreparedSQL(b, `("a" IS FALSE)`, emptyArgs) - - d.Literal(b.Clear(), exp.Ex{"a": nil}) - dts.assertPreparedSQL(b, `("a" IS NULL)`, emptyArgs) - - d.Literal(b.Clear(), exp.Ex{"a": []string{"a", "b", "c"}}) - dts.assertPreparedSQL(b, `("a" IN (?, ?, ?))`, []interface{}{"a", "b", "c"}) - - d.Literal(b.Clear(), exp.Ex{"a": exp.Op{"neq": 1}}) - dts.assertPreparedSQL(b, `("a" != ?)`, []interface{}{int64(1)}) - - d.Literal(b.Clear(), exp.Ex{"a": exp.Op{"isnot": true}}) - dts.assertPreparedSQL(b, `("a" IS NOT TRUE)`, emptyArgs) - - d.Literal(b.Clear(), exp.Ex{"a": exp.Op{"gt": 1}}) - dts.assertPreparedSQL(b, `("a" > ?)`, []interface{}{int64(1)}) - - d.Literal(b.Clear(), exp.Ex{"a": exp.Op{"gte": 1}}) - dts.assertPreparedSQL(b, `("a" >= ?)`, []interface{}{int64(1)}) - - d.Literal(b.Clear(), exp.Ex{"a": exp.Op{"lt": 1}}) - dts.assertPreparedSQL(b, `("a" < ?)`, []interface{}{int64(1)}) - - d.Literal(b.Clear(), exp.Ex{"a": exp.Op{"lte": 1}}) - dts.assertPreparedSQL(b, `("a" <= ?)`, []interface{}{int64(1)}) - - d.Literal(b.Clear(), exp.Ex{"a": exp.Op{"like": "a%"}}) - dts.assertPreparedSQL(b, `("a" LIKE ?)`, []interface{}{"a%"}) - - d.Literal(b.Clear(), exp.Ex{"a": exp.Op{"notLike": "a%"}}) - dts.assertPreparedSQL(b, `("a" NOT LIKE ?)`, []interface{}{"a%"}) - - d.Literal(b.Clear(), exp.Ex{"a": exp.Op{"in": []string{"a", "b", "c"}}}) - dts.assertPreparedSQL(b, `("a" IN (?, ?, ?))`, []interface{}{"a", "b", "c"}) - - d.Literal(b.Clear(), exp.Ex{"a": exp.Op{"notIn": []string{"a", "b", "c"}}}) - dts.assertPreparedSQL(b, `("a" NOT IN (?, ?, ?))`, []interface{}{"a", "b", "c"}) - - d.Literal(b.Clear(), exp.Ex{"a": exp.Op{"is": nil, "eq": 10}}) - dts.assertPreparedSQL(b, `(("a" = ?) OR ("a" IS NULL))`, []interface{}{int64(10)}) - - d.Literal(b.Clear(), exp.Ex{"a": exp.Op{"between": exp.NewRangeVal(1, 10)}}) - dts.assertPreparedSQL(b, `("a" BETWEEN ? AND ?)`, []interface{}{int64(1), int64(10)}) - - d.Literal(b.Clear(), exp.Ex{"a": exp.Op{"notbetween": exp.NewRangeVal(1, 10)}}) - dts.assertPreparedSQL(b, `("a" NOT BETWEEN ? AND ?)`, []interface{}{int64(1), int64(10)}) -} - -func (dts *dialectTestSuite) TestLiteral_ExpressionOrMap() { - d := sqlDialect{dialect: "test", dialectOptions: DefaultDialectOptions()} - - b := sb.NewSQLBuilder(false) - d.Literal(b.Clear(), exp.ExOr{"a": 1, "b": true}) - dts.assertNotPreparedSQL(b, `(("a" = 1) OR ("b" IS TRUE))`) - - d.Literal(b.Clear(), exp.ExOr{"a": 1, "b": []string{"a", "b", "c"}}) - dts.assertNotPreparedSQL(b, `(("a" = 1) OR ("b" IN ('a', 'b', 'c')))`) - - b = sb.NewSQLBuilder(true) - d.Literal(b.Clear(), exp.ExOr{"a": 1, "b": true}) - dts.assertPreparedSQL(b, `(("a" = ?) OR ("b" IS TRUE))`, []interface{}{int64(1)}) - - d.Literal(b.Clear(), exp.ExOr{"a": 1, "b": []string{"a", "b", "c"}}) - dts.assertPreparedSQL(b, `(("a" = ?) OR ("b" IN (?, ?, ?)))`, []interface{}{int64(1), "a", "b", "c"}) - -} -func TestDialectSuite(t *testing.T) { +func TestSQLDialect(t *testing.T) { suite.Run(t, new(dialectTestSuite)) } diff --git a/sqlgen/base_test.go b/sqlgen/base_test.go new file mode 100644 index 00000000..223122a6 --- /dev/null +++ b/sqlgen/base_test.go @@ -0,0 +1,40 @@ +package sqlgen + +import ( + "github.com/doug-martin/goqu/v8/internal/sb" + "github.com/stretchr/testify/suite" +) + +type baseSQLGeneratorSuite struct { + suite.Suite +} + +func (bsgs *baseSQLGeneratorSuite) assertNotPreparedSQL(b sb.SQLBuilder, expectedSQL string) { + actualSQL, actualArgs, err := b.ToSQL() + bsgs.NoError(err) + bsgs.Equal(expectedSQL, actualSQL) + bsgs.Empty(actualArgs) +} + +func (bsgs *baseSQLGeneratorSuite) assertPreparedSQL( + b sb.SQLBuilder, + expectedSQL string, + expectedArgs []interface{}, +) { + actualSQL, actualArgs, err := b.ToSQL() + bsgs.NoError(err) + bsgs.Equal(expectedSQL, actualSQL) + if len(actualArgs) == 0 { + bsgs.Empty(expectedArgs) + } else { + bsgs.Equal(expectedArgs, actualArgs) + } + +} + +func (bsgs *baseSQLGeneratorSuite) assertErrorSQL(b sb.SQLBuilder, errMsg string) { + actualSQL, actualArgs, err := b.ToSQL() + bsgs.EqualError(err, errMsg) + bsgs.Empty(actualSQL) + bsgs.Empty(actualArgs) +} diff --git a/sqlgen/common_sql_generator.go b/sqlgen/common_sql_generator.go new file mode 100644 index 00000000..ce60c66c --- /dev/null +++ b/sqlgen/common_sql_generator.go @@ -0,0 +1,99 @@ +package sqlgen + +import ( + "github.com/doug-martin/goqu/v8/exp" + "github.com/doug-martin/goqu/v8/internal/errors" + "github.com/doug-martin/goqu/v8/internal/sb" +) + +var ( + errNoUpdatedValuesProvided = errors.New("no update values provided") +) + +func errCTENotSupported(dialect string) error { + return errors.New("dialect does not support CTE WITH clause [dialect=%s]", dialect) +} +func errRecursiveCTENotSupported(dialect string) error { + return errors.New("dialect does not support CTE WITH RECURSIVE clause [dialect=%s]", dialect) +} + +func errReturnNotSupported(dialect string) error { + return errors.New("dialect does not support RETURNING clause [dialect=%s]", dialect) +} + +func errNotSupportedFragment(sqlType string, f SQLFragmentType) error { + return errors.New("unsupported %s SQL fragment %s", sqlType, f) +} + +type commonSQLGenerator struct { + dialect string + esg ExpressionSQLGenerator + dialectOptions *SQLDialectOptions +} + +func newCommonSQLGenerator(dialect string, do *SQLDialectOptions) *commonSQLGenerator { + return &commonSQLGenerator{dialect: dialect, esg: NewExpressionSQLGenerator(dialect, do), dialectOptions: do} +} + +func (csg *commonSQLGenerator) ReturningSQL(b sb.SQLBuilder, returns exp.ColumnListExpression) { + if returns != nil && len(returns.Columns()) > 0 { + if csg.dialectOptions.SupportsReturn { + b.Write(csg.dialectOptions.ReturningFragment) + csg.esg.Generate(b, returns) + } else { + b.SetError(errReturnNotSupported(csg.dialect)) + } + } +} + +// Adds the FROM clause and tables to an sql statement +func (csg *commonSQLGenerator) FromSQL(b sb.SQLBuilder, from exp.ColumnListExpression) { + if from != nil && !from.IsEmpty() { + b.Write(csg.dialectOptions.FromFragment) + csg.SourcesSQL(b, from) + } +} + +// Adds the generates the SQL for a column list +func (csg *commonSQLGenerator) SourcesSQL(b sb.SQLBuilder, from exp.ColumnListExpression) { + b.WriteRunes(csg.dialectOptions.SpaceRune) + csg.esg.Generate(b, from) +} + +// Generates the WHERE clause for an SQL statement +func (csg *commonSQLGenerator) WhereSQL(b sb.SQLBuilder, where exp.ExpressionList) { + if where != nil && !where.IsEmpty() { + b.Write(csg.dialectOptions.WhereFragment) + csg.esg.Generate(b, where) + } +} + +// Generates the ORDER BY clause for an SQL statement +func (csg *commonSQLGenerator) OrderSQL(b sb.SQLBuilder, order exp.ColumnListExpression) { + if order != nil && len(order.Columns()) > 0 { + b.Write(csg.dialectOptions.OrderByFragment) + csg.esg.Generate(b, order) + } +} + +// Generates the LIMIT clause for an SQL statement +func (csg *commonSQLGenerator) LimitSQL(b sb.SQLBuilder, limit interface{}) { + if limit != nil { + b.Write(csg.dialectOptions.LimitFragment) + csg.esg.Generate(b, limit) + } +} + +func (csg *commonSQLGenerator) UpdateExpressionSQL(b sb.SQLBuilder, updates ...exp.UpdateExpression) { + if len(updates) == 0 { + b.SetError(errNoUpdatedValuesProvided) + return + } + updateLen := len(updates) + for i, update := range updates { + csg.esg.Generate(b, update) + if i < updateLen-1 { + b.WriteRunes(csg.dialectOptions.CommaRune) + } + } +} diff --git a/sqlgen/common_sql_generator_test.go b/sqlgen/common_sql_generator_test.go new file mode 100644 index 00000000..c4161207 --- /dev/null +++ b/sqlgen/common_sql_generator_test.go @@ -0,0 +1,341 @@ +package sqlgen + +import ( + "testing" + + "github.com/doug-martin/goqu/v8/exp" + "github.com/doug-martin/goqu/v8/internal/sb" + "github.com/stretchr/testify/suite" +) + +type ( + commonSQLTestCase struct { + gen func(builder sb.SQLBuilder) + sql string + isPrepared bool + err string + args []interface{} + } + commonSQLGeneratorSuite struct { + baseSQLGeneratorSuite + } +) + +func (csgs *commonSQLGeneratorSuite) assertCases(testCases ...commonSQLTestCase) { + for _, tc := range testCases { + b := sb.NewSQLBuilder(tc.isPrepared) + tc.gen(b) + switch { + case len(tc.err) > 0: + csgs.assertErrorSQL(b, tc.err) + case tc.isPrepared: + csgs.assertPreparedSQL(b, tc.sql, tc.args) + default: + csgs.assertNotPreparedSQL(b, tc.sql) + } + } +} + +func (csgs *commonSQLGeneratorSuite) TestReturningSQL() { + + returningGen := func(csgs *commonSQLGenerator) func(sb.SQLBuilder) { + return func(sb sb.SQLBuilder) { + csgs.ReturningSQL(sb, exp.NewColumnListExpression("a", "b")) + } + } + + returningNoColsGen := func(csgs *commonSQLGenerator) func(sb.SQLBuilder) { + return func(sb sb.SQLBuilder) { + csgs.ReturningSQL(sb, exp.NewColumnListExpression()) + } + } + + returningNilExpGen := func(csgs *commonSQLGenerator) func(sb.SQLBuilder) { + return func(sb sb.SQLBuilder) { + csgs.ReturningSQL(sb, nil) + } + } + + opts := DefaultDialectOptions() + opts.SupportsReturn = true + csgs1 := newCommonSQLGenerator("test", opts) + + opts2 := DefaultDialectOptions() + opts2.SupportsReturn = false + csgs2 := newCommonSQLGenerator("test", opts2) + + csgs.assertCases( + commonSQLTestCase{gen: returningGen(csgs1), sql: ` RETURNING "a", "b"`}, + commonSQLTestCase{gen: returningGen(csgs1), sql: ` RETURNING "a", "b"`, isPrepared: true, args: emptyArgs}, + + commonSQLTestCase{gen: returningNoColsGen(csgs1), sql: ``}, + commonSQLTestCase{gen: returningNoColsGen(csgs1), sql: ``, isPrepared: true, args: emptyArgs}, + + commonSQLTestCase{gen: returningNilExpGen(csgs1), sql: ``}, + commonSQLTestCase{gen: returningNilExpGen(csgs1), sql: ``, isPrepared: true, args: emptyArgs}, + + commonSQLTestCase{gen: returningGen(csgs2), err: `goqu: dialect does not support RETURNING clause [dialect=test]`}, + commonSQLTestCase{gen: returningGen(csgs2), err: `goqu: dialect does not support RETURNING clause [dialect=test]`}, + ) +} + +func (csgs *commonSQLGeneratorSuite) TestFromSQL() { + fromGen := func(csgs *commonSQLGenerator) func(sb.SQLBuilder) { + return func(sb sb.SQLBuilder) { + csgs.FromSQL(sb, exp.NewColumnListExpression("a", "b")) + } + } + + fromNoColsGen := func(csgs *commonSQLGenerator) func(sb.SQLBuilder) { + return func(sb sb.SQLBuilder) { + csgs.FromSQL(sb, exp.NewColumnListExpression()) + } + } + + fromNilExpGen := func(csgs *commonSQLGenerator) func(sb.SQLBuilder) { + return func(sb sb.SQLBuilder) { + csgs.FromSQL(sb, nil) + } + } + + csg := newCommonSQLGenerator("test", DefaultDialectOptions()) + + opts := DefaultDialectOptions() + opts.FromFragment = []byte(" from") + csgFromFrag := newCommonSQLGenerator("test", opts) + + csgs.assertCases( + commonSQLTestCase{gen: fromGen(csg), sql: ` FROM "a", "b"`}, + commonSQLTestCase{gen: fromGen(csg), sql: ` FROM "a", "b"`, isPrepared: true, args: emptyArgs}, + + commonSQLTestCase{gen: fromNoColsGen(csg), sql: ``}, + commonSQLTestCase{gen: fromNoColsGen(csg), sql: ``, isPrepared: true, args: emptyArgs}, + + commonSQLTestCase{gen: fromNilExpGen(csg), sql: ``}, + commonSQLTestCase{gen: fromNilExpGen(csg), sql: ``, isPrepared: true, args: emptyArgs}, + + commonSQLTestCase{gen: fromGen(csgFromFrag), sql: ` from "a", "b"`}, + commonSQLTestCase{gen: fromGen(csgFromFrag), sql: ` from "a", "b"`, isPrepared: true, args: emptyArgs}, + + commonSQLTestCase{gen: fromNoColsGen(csgFromFrag), sql: ``}, + commonSQLTestCase{gen: fromNoColsGen(csgFromFrag), sql: ``, isPrepared: true, args: emptyArgs}, + + commonSQLTestCase{gen: fromNilExpGen(csgFromFrag), sql: ``}, + commonSQLTestCase{gen: fromNilExpGen(csgFromFrag), sql: ``, isPrepared: true, args: emptyArgs}, + ) +} + +func (csgs *commonSQLGeneratorSuite) TestWhereSQL() { + + whereAndGen := func(csgs *commonSQLGenerator, exps ...exp.Expression) func(sb.SQLBuilder) { + return func(sb sb.SQLBuilder) { + csgs.WhereSQL(sb, exp.NewExpressionList(exp.AndType, exps...)) + } + } + + whereOrGen := func(csgs *commonSQLGenerator, exps ...exp.Expression) func(sb.SQLBuilder) { + return func(sb sb.SQLBuilder) { + csgs.WhereSQL(sb, exp.NewExpressionList(exp.OrType, exps...)) + } + } + + csg := newCommonSQLGenerator("test", DefaultDialectOptions()) + + opts := DefaultDialectOptions() + opts.WhereFragment = []byte(" where ") + csgWhereFrag := newCommonSQLGenerator("test", opts) + + w := exp.Ex{"a": "b"} + w2 := exp.Ex{"b": "c"} + + csgs.assertCases( + commonSQLTestCase{gen: whereAndGen(csg), sql: ``}, + commonSQLTestCase{gen: whereAndGen(csg), sql: ``, isPrepared: true, args: emptyArgs}, + + commonSQLTestCase{gen: whereAndGen(csg, w), sql: ` WHERE ("a" = 'b')`}, + commonSQLTestCase{gen: whereAndGen(csg, w), sql: ` WHERE ("a" = ?)`, isPrepared: true, args: []interface{}{"b"}}, + + commonSQLTestCase{gen: whereAndGen(csg, w, w2), sql: ` WHERE (("a" = 'b') AND ("b" = 'c'))`}, + commonSQLTestCase{gen: whereAndGen(csg, w, w2), sql: ` WHERE (("a" = ?) AND ("b" = ?))`, isPrepared: true, args: []interface{}{"b", "c"}}, + + commonSQLTestCase{gen: whereOrGen(csg), sql: ``}, + commonSQLTestCase{gen: whereOrGen(csg), sql: ``, isPrepared: true, args: emptyArgs}, + + commonSQLTestCase{gen: whereOrGen(csg, w), sql: ` WHERE ("a" = 'b')`}, + commonSQLTestCase{gen: whereOrGen(csg, w), sql: ` WHERE ("a" = ?)`, isPrepared: true, args: []interface{}{"b"}}, + + commonSQLTestCase{gen: whereOrGen(csg, w, w2), sql: ` WHERE (("a" = 'b') OR ("b" = 'c'))`}, + commonSQLTestCase{gen: whereOrGen(csg, w, w2), sql: ` WHERE (("a" = ?) OR ("b" = ?))`, isPrepared: true, args: []interface{}{"b", "c"}}, + + commonSQLTestCase{gen: whereAndGen(csgWhereFrag), sql: ``}, + commonSQLTestCase{gen: whereAndGen(csgWhereFrag), sql: ``, isPrepared: true, args: emptyArgs}, + + commonSQLTestCase{gen: whereAndGen(csgWhereFrag, w), sql: ` where ("a" = 'b')`}, + commonSQLTestCase{gen: whereAndGen(csgWhereFrag, w), sql: ` where ("a" = ?)`, isPrepared: true, args: []interface{}{"b"}}, + + commonSQLTestCase{gen: whereAndGen(csgWhereFrag, w, w2), sql: ` where (("a" = 'b') AND ("b" = 'c'))`}, + commonSQLTestCase{ + gen: whereAndGen(csgWhereFrag, w, w2), + sql: ` where (("a" = ?) AND ("b" = ?))`, + isPrepared: true, + args: []interface{}{"b", "c"}, + }, + + commonSQLTestCase{gen: whereOrGen(csgWhereFrag), sql: ``}, + commonSQLTestCase{gen: whereOrGen(csgWhereFrag), sql: ``, isPrepared: true, args: emptyArgs}, + + commonSQLTestCase{gen: whereOrGen(csgWhereFrag, w), sql: ` where ("a" = 'b')`}, + commonSQLTestCase{gen: whereOrGen(csgWhereFrag, w), sql: ` where ("a" = ?)`, isPrepared: true, args: []interface{}{"b"}}, + + commonSQLTestCase{gen: whereOrGen(csgWhereFrag, w, w2), sql: ` where (("a" = 'b') OR ("b" = 'c'))`}, + commonSQLTestCase{ + gen: whereOrGen(csgWhereFrag, w, w2), + sql: ` where (("a" = ?) OR ("b" = ?))`, + isPrepared: true, + args: []interface{}{"b", "c"}, + }, + ) +} + +func (csgs *commonSQLGeneratorSuite) TestOrderSQL() { + + orderGen := func(csgs *commonSQLGenerator, o ...exp.OrderedExpression) func(sb.SQLBuilder) { + return func(sb sb.SQLBuilder) { + csgs.OrderSQL(sb, exp.NewOrderedColumnList(o...)) + } + } + + csg := newCommonSQLGenerator("test", DefaultDialectOptions()) + + opts := DefaultDialectOptions() + // override fragments to ensure they are used + opts.OrderByFragment = []byte(" order by ") + opts.AscFragment = []byte(" asc") + opts.DescFragment = []byte(" desc") + opts.NullsFirstFragment = []byte(" nulls first") + opts.NullsLastFragment = []byte(" nulls last") + csgCustom := newCommonSQLGenerator("test", opts) + + ident := exp.NewIdentifierExpression("", "", "a") + oa := ident.Asc() + oanf := ident.Asc().NullsFirst() + oanl := ident.Asc().NullsLast() + + od := ident.Desc() + odnf := ident.Desc().NullsFirst() + odnl := ident.Desc().NullsLast() + + csgs.assertCases( + commonSQLTestCase{gen: orderGen(csg), sql: ``}, + commonSQLTestCase{gen: orderGen(csg), sql: ``, isPrepared: true, args: emptyArgs}, + + commonSQLTestCase{gen: orderGen(csg, oa), sql: ` ORDER BY "a" ASC`}, + commonSQLTestCase{gen: orderGen(csg, oa), sql: ` ORDER BY "a" ASC`, isPrepared: true, args: emptyArgs}, + + commonSQLTestCase{gen: orderGen(csg, oanf), sql: ` ORDER BY "a" ASC NULLS FIRST`}, + commonSQLTestCase{gen: orderGen(csg, oanf), sql: ` ORDER BY "a" ASC NULLS FIRST`, isPrepared: true, args: emptyArgs}, + + commonSQLTestCase{gen: orderGen(csg, oanl), sql: ` ORDER BY "a" ASC NULLS LAST`}, + commonSQLTestCase{gen: orderGen(csg, oanl), sql: ` ORDER BY "a" ASC NULLS LAST`, isPrepared: true, args: emptyArgs}, + + commonSQLTestCase{gen: orderGen(csg, od), sql: ` ORDER BY "a" DESC`}, + commonSQLTestCase{gen: orderGen(csg, od), sql: ` ORDER BY "a" DESC`, isPrepared: true, args: emptyArgs}, + + commonSQLTestCase{gen: orderGen(csg, odnf), sql: ` ORDER BY "a" DESC NULLS FIRST`}, + commonSQLTestCase{gen: orderGen(csg, odnf), sql: ` ORDER BY "a" DESC NULLS FIRST`, isPrepared: true, args: emptyArgs}, + + commonSQLTestCase{gen: orderGen(csg, odnl), sql: ` ORDER BY "a" DESC NULLS LAST`}, + commonSQLTestCase{gen: orderGen(csg, odnl), sql: ` ORDER BY "a" DESC NULLS LAST`, isPrepared: true, args: emptyArgs}, + + commonSQLTestCase{gen: orderGen(csg, oa, od), sql: ` ORDER BY "a" ASC, "a" DESC`}, + commonSQLTestCase{gen: orderGen(csg, oa, od), sql: ` ORDER BY "a" ASC, "a" DESC`, isPrepared: true, args: emptyArgs}, + + commonSQLTestCase{gen: orderGen(csgCustom), sql: ``}, + commonSQLTestCase{gen: orderGen(csgCustom), sql: ``, isPrepared: true, args: emptyArgs}, + + commonSQLTestCase{gen: orderGen(csgCustom, oa), sql: ` order by "a" asc`}, + commonSQLTestCase{gen: orderGen(csgCustom, oa), sql: ` order by "a" asc`, isPrepared: true, args: emptyArgs}, + + commonSQLTestCase{gen: orderGen(csgCustom, oanf), sql: ` order by "a" asc nulls first`}, + commonSQLTestCase{gen: orderGen(csgCustom, oanf), sql: ` order by "a" asc nulls first`, isPrepared: true, args: emptyArgs}, + + commonSQLTestCase{gen: orderGen(csgCustom, oanl), sql: ` order by "a" asc nulls last`}, + commonSQLTestCase{gen: orderGen(csgCustom, oanl), sql: ` order by "a" asc nulls last`, isPrepared: true, args: emptyArgs}, + + commonSQLTestCase{gen: orderGen(csgCustom, od), sql: ` order by "a" desc`}, + commonSQLTestCase{gen: orderGen(csgCustom, od), sql: ` order by "a" desc`, isPrepared: true, args: emptyArgs}, + + commonSQLTestCase{gen: orderGen(csgCustom, odnf), sql: ` order by "a" desc nulls first`}, + commonSQLTestCase{gen: orderGen(csgCustom, odnf), sql: ` order by "a" desc nulls first`, isPrepared: true, args: emptyArgs}, + + commonSQLTestCase{gen: orderGen(csgCustom, odnl), sql: ` order by "a" desc nulls last`}, + commonSQLTestCase{gen: orderGen(csgCustom, odnl), sql: ` order by "a" desc nulls last`, isPrepared: true, args: emptyArgs}, + + commonSQLTestCase{gen: orderGen(csgCustom, oa, od), sql: ` order by "a" asc, "a" desc`}, + commonSQLTestCase{gen: orderGen(csgCustom, oa, od), sql: ` order by "a" asc, "a" desc`, isPrepared: true, args: emptyArgs}, + ) +} + +func (csgs *commonSQLGeneratorSuite) TestLimitSQL() { + limitGen := func(csgs *commonSQLGenerator, l interface{}) func(sb.SQLBuilder) { + return func(sb sb.SQLBuilder) { + csgs.LimitSQL(sb, l) + } + } + + csg := newCommonSQLGenerator("test", DefaultDialectOptions()) + + opts := DefaultDialectOptions() + opts.LimitFragment = []byte(" limit ") + csgCustom := newCommonSQLGenerator("test", opts) + + l := int64(10) + la := exp.NewLiteralExpression("ALL") + + csgs.assertCases( + commonSQLTestCase{gen: limitGen(csg, nil), sql: ``}, + commonSQLTestCase{gen: limitGen(csg, nil), sql: ``, isPrepared: true, args: emptyArgs}, + + commonSQLTestCase{gen: limitGen(csg, l), sql: ` LIMIT 10`}, + commonSQLTestCase{gen: limitGen(csg, l), sql: ` LIMIT ?`, isPrepared: true, args: []interface{}{l}}, + + commonSQLTestCase{gen: limitGen(csg, la), sql: ` LIMIT ALL`}, + commonSQLTestCase{gen: limitGen(csg, la), sql: ` LIMIT ALL`, isPrepared: true, args: emptyArgs}, + + commonSQLTestCase{gen: limitGen(csgCustom, nil), sql: ``}, + commonSQLTestCase{gen: limitGen(csgCustom, nil), sql: ``, isPrepared: true, args: emptyArgs}, + + commonSQLTestCase{gen: limitGen(csgCustom, l), sql: ` limit 10`}, + commonSQLTestCase{gen: limitGen(csgCustom, l), sql: ` limit ?`, isPrepared: true, args: []interface{}{l}}, + + commonSQLTestCase{gen: limitGen(csgCustom, la), sql: ` limit ALL`}, + commonSQLTestCase{gen: limitGen(csgCustom, la), sql: ` limit ALL`, isPrepared: true, args: emptyArgs}, + ) +} + +func (csgs *commonSQLGeneratorSuite) TestUpdateExpressionSQL() { + updateGen := func(csgs *commonSQLGenerator, ues ...exp.UpdateExpression) func(sb.SQLBuilder) { + return func(sb sb.SQLBuilder) { + csgs.UpdateExpressionSQL(sb, ues...) + } + } + + csg := newCommonSQLGenerator("test", DefaultDialectOptions()) + ue := exp.NewIdentifierExpression("", "", "col").Set("a") + ue2 := exp.NewIdentifierExpression("", "", "col2").Set("b") + + csgs.assertCases( + commonSQLTestCase{gen: updateGen(csg), err: errNoUpdatedValuesProvided.Error()}, + commonSQLTestCase{gen: updateGen(csg), err: errNoUpdatedValuesProvided.Error()}, + + commonSQLTestCase{gen: updateGen(csg, ue), sql: `"col"='a'`}, + commonSQLTestCase{gen: updateGen(csg, ue), sql: `"col"=?`, isPrepared: true, args: []interface{}{"a"}}, + + commonSQLTestCase{gen: updateGen(csg, ue, ue2), sql: `"col"='a',"col2"='b'`}, + commonSQLTestCase{gen: updateGen(csg, ue, ue2), sql: `"col"=?,"col2"=?`, isPrepared: true, args: []interface{}{"a", "b"}}, + ) +} + +func TestCommonSQLGenerator(t *testing.T) { + suite.Run(t, new(commonSQLGeneratorSuite)) +} diff --git a/sqlgen/delete_sql_generator.go b/sqlgen/delete_sql_generator.go new file mode 100644 index 00000000..67477514 --- /dev/null +++ b/sqlgen/delete_sql_generator.go @@ -0,0 +1,73 @@ +package sqlgen + +import ( + "github.com/doug-martin/goqu/v8/exp" + "github.com/doug-martin/goqu/v8/internal/errors" + "github.com/doug-martin/goqu/v8/internal/sb" +) + +type ( + // An adapter interface to be used by a Dataset to generate SQL for a specific dialect. + // See DefaultAdapter for a concrete implementation and examples. + DeleteSQLGenerator interface { + Dialect() string + Generate(b sb.SQLBuilder, clauses exp.DeleteClauses) + } + // The default adapter. This class should be used when building a new adapter. When creating a new adapter you can + // either override methods, or more typically update default values. + // See (github.com/doug-martin/goqu/adapters/postgres) + deleteSQLGenerator struct { + *commonSQLGenerator + } +) + +var ( + errNoSourceForDelete = errors.New("no source found when generating delete sql") +) + +func NewDeleteSQLGenerator(dialect string, do *SQLDialectOptions) DeleteSQLGenerator { + return &deleteSQLGenerator{newCommonSQLGenerator(dialect, do)} +} + +func (dsg *deleteSQLGenerator) Dialect() string { + return dsg.dialect +} + +func (dsg *deleteSQLGenerator) Generate(b sb.SQLBuilder, clauses exp.DeleteClauses) { + if !clauses.HasFrom() { + b.SetError(errNoSourceForDelete) + return + } + for _, f := range dsg.dialectOptions.DeleteSQLOrder { + if b.Error() != nil { + return + } + switch f { + case CommonTableSQLFragment: + dsg.esg.Generate(b, clauses.CommonTables()) + case DeleteBeginSQLFragment: + dsg.DeleteBeginSQL(b) + case FromSQLFragment: + dsg.FromSQL(b, exp.NewColumnListExpression(clauses.From())) + case WhereSQLFragment: + dsg.WhereSQL(b, clauses.Where()) + case OrderSQLFragment: + if dsg.dialectOptions.SupportsOrderByOnDelete { + dsg.OrderSQL(b, clauses.Order()) + } + case LimitSQLFragment: + if dsg.dialectOptions.SupportsLimitOnDelete { + dsg.LimitSQL(b, clauses.Limit()) + } + case ReturningSQLFragment: + dsg.ReturningSQL(b, clauses.Returning()) + default: + b.SetError(errNotSupportedFragment("DELETE", f)) + } + } +} + +// Adds the correct fragment to being an DELETE statement +func (dsg *deleteSQLGenerator) DeleteBeginSQL(b sb.SQLBuilder) { + b.Write(dsg.dialectOptions.DeleteClause) +} diff --git a/sqlgen/delete_sql_generator_test.go b/sqlgen/delete_sql_generator_test.go new file mode 100644 index 00000000..1bda1b96 --- /dev/null +++ b/sqlgen/delete_sql_generator_test.go @@ -0,0 +1,233 @@ +package sqlgen + +import ( + "testing" + + "github.com/doug-martin/goqu/v8/exp" + "github.com/doug-martin/goqu/v8/internal/errors" + "github.com/doug-martin/goqu/v8/internal/sb" + "github.com/stretchr/testify/suite" +) + +type ( + deleteTestCase struct { + clause exp.DeleteClauses + sql string + isPrepared bool + args []interface{} + err string + } + deleteSQLGeneratorSuite struct { + baseSQLGeneratorSuite + } +) + +func (dsgs *deleteSQLGeneratorSuite) assertCases(dsg DeleteSQLGenerator, testCases ...deleteTestCase) { + for _, tc := range testCases { + b := sb.NewSQLBuilder(tc.isPrepared) + dsg.Generate(b, tc.clause) + switch { + case len(tc.err) > 0: + dsgs.assertErrorSQL(b, tc.err) + case tc.isPrepared: + dsgs.assertPreparedSQL(b, tc.sql, tc.args) + default: + dsgs.assertNotPreparedSQL(b, tc.sql) + } + } +} + +func (dsgs *deleteSQLGeneratorSuite) TestDialect() { + opts := DefaultDialectOptions() + d := NewDeleteSQLGenerator("test", opts) + dsgs.Equal("test", d.Dialect()) + + opts2 := DefaultDialectOptions() + d2 := NewDeleteSQLGenerator("test2", opts2) + dsgs.Equal("test2", d2.Dialect()) +} + +func (dsgs *deleteSQLGeneratorSuite) TestGenerate() { + + dc := exp.NewDeleteClauses(). + SetFrom(exp.NewIdentifierExpression("", "test", "")) + + dsgs.assertCases( + NewDeleteSQLGenerator("test", DefaultDialectOptions()), + deleteTestCase{clause: dc, sql: `DELETE FROM "test"`}, + deleteTestCase{clause: dc, sql: `DELETE FROM "test"`, isPrepared: true}, + ) + + opts2 := DefaultDialectOptions() + opts2.DeleteClause = []byte("delete") + + dsgs.assertCases( + NewDeleteSQLGenerator("test", opts2), + deleteTestCase{clause: dc, sql: `delete FROM "test"`}, + deleteTestCase{clause: dc, sql: `delete FROM "test"`, isPrepared: true}, + ) +} + +func (dsgs *deleteSQLGeneratorSuite) TestGenerate_withUnsupportedFragment() { + opts := DefaultDialectOptions() + opts.DeleteSQLOrder = []SQLFragmentType{InsertBeingSQLFragment} + dc := exp.NewDeleteClauses(). + SetFrom(exp.NewIdentifierExpression("", "test", "")) + + dsgs.assertCases( + NewDeleteSQLGenerator("test", opts), + deleteTestCase{clause: dc, err: `goqu: unsupported DELETE SQL fragment InsertBeingSQLFragment`}, + deleteTestCase{clause: dc, err: `goqu: unsupported DELETE SQL fragment InsertBeingSQLFragment`, isPrepared: true}, + ) +} + +func (dsgs *deleteSQLGeneratorSuite) TestGenerate_noFrom() { + dc := exp.NewDeleteClauses() + dsgs.assertCases( + NewDeleteSQLGenerator("test", DefaultDialectOptions()), + deleteTestCase{clause: dc, err: errNoSourceForDelete.Error()}, + deleteTestCase{clause: dc, err: errNoSourceForDelete.Error(), isPrepared: true}, + ) +} + +func (dsgs *deleteSQLGeneratorSuite) TestGenerate_withErroredBuilder() { + opts := DefaultDialectOptions() + d := NewDeleteSQLGenerator("test", opts) + + dc := exp.NewDeleteClauses().SetFrom(exp.NewIdentifierExpression("", "test", "")) + b := sb.NewSQLBuilder(false).SetError(errors.New("expected error")) + d.Generate(b, dc) + dsgs.assertErrorSQL(b, "goqu: expected error") + + b = sb.NewSQLBuilder(true).SetError(errors.New("expected error")) + d.Generate(b, dc) + dsgs.assertErrorSQL(b, "goqu: expected error") +} + +func (dsgs *deleteSQLGeneratorSuite) TestGenerate_withCommonTables() { + opts := DefaultDialectOptions() + opts.WithFragment = []byte("with ") + opts.RecursiveFragment = []byte("recursive ") + + tse := newTestAppendableExpression("select * from foo", emptyArgs, nil, nil) + + dc := exp.NewDeleteClauses().SetFrom(exp.NewIdentifierExpression("", "test_cte", "")) + dcCte1 := dc.CommonTablesAppend(exp.NewCommonTableExpression(false, "test_cte", tse)) + dcCte2 := dc.CommonTablesAppend(exp.NewCommonTableExpression(true, "test_cte", tse)) + + dsgs.assertCases( + NewDeleteSQLGenerator("test", opts), + deleteTestCase{clause: dcCte1, sql: `with test_cte AS (select * from foo) DELETE FROM "test_cte"`}, + deleteTestCase{clause: dcCte1, sql: `with test_cte AS (select * from foo) DELETE FROM "test_cte"`, isPrepared: true}, + + deleteTestCase{clause: dcCte2, sql: `with recursive test_cte AS (select * from foo) DELETE FROM "test_cte"`}, + deleteTestCase{clause: dcCte2, sql: `with recursive test_cte AS (select * from foo) DELETE FROM "test_cte"`, isPrepared: true}, + ) + + opts.SupportsWithCTE = false + expectedErr := errCTENotSupported("test") + dsgs.assertCases( + NewDeleteSQLGenerator("test", opts), + deleteTestCase{clause: dcCte1, err: expectedErr.Error()}, + deleteTestCase{clause: dcCte1, err: expectedErr.Error(), isPrepared: true}, + + deleteTestCase{clause: dcCte2, err: expectedErr.Error()}, + deleteTestCase{clause: dcCte2, err: expectedErr.Error(), isPrepared: true}, + ) + + opts.SupportsWithCTE = true + opts.SupportsWithCTERecursive = false + expectedErr = errRecursiveCTENotSupported("test") + dsgs.assertCases( + NewDeleteSQLGenerator("test", opts), + deleteTestCase{clause: dcCte1, sql: `with test_cte AS (select * from foo) DELETE FROM "test_cte"`}, + deleteTestCase{clause: dcCte1, sql: `with test_cte AS (select * from foo) DELETE FROM "test_cte"`, isPrepared: true}, + + deleteTestCase{clause: dcCte2, err: expectedErr.Error()}, + deleteTestCase{clause: dcCte2, err: expectedErr.Error(), isPrepared: true}, + ) +} + +func (dsgs *deleteSQLGeneratorSuite) TestGenerate_withWhere() { + dc := exp.NewDeleteClauses(). + SetFrom(exp.NewIdentifierExpression("", "test", "")). + WhereAppend(exp.NewLiteralExpression(`"a"=?`, 1)) + dsgs.assertCases( + NewDeleteSQLGenerator("test", DefaultDialectOptions()), + deleteTestCase{clause: dc, sql: `DELETE FROM "test" WHERE "a"=1`}, + deleteTestCase{clause: dc, sql: `DELETE FROM "test" WHERE "a"=?`, isPrepared: true, args: []interface{}{ + int64(1), + }}, + ) +} + +func (dsgs *deleteSQLGeneratorSuite) TestGenerate_withOrder() { + opts := DefaultDialectOptions() + opts.SupportsOrderByOnDelete = true + + dc := exp.NewDeleteClauses(). + SetFrom(exp.NewIdentifierExpression("", "test", "")). + SetOrder(exp.NewIdentifierExpression("", "", "c").Desc()) + + dsgs.assertCases( + NewDeleteSQLGenerator("test", opts), + deleteTestCase{clause: dc, sql: `DELETE FROM "test" ORDER BY "c" DESC`}, + deleteTestCase{clause: dc, sql: `DELETE FROM "test" ORDER BY "c" DESC`, isPrepared: true}, + ) + + opts.SupportsOrderByOnDelete = false + dsgs.assertCases( + NewDeleteSQLGenerator("test", opts), + deleteTestCase{clause: dc, sql: `DELETE FROM "test"`}, + deleteTestCase{clause: dc, sql: `DELETE FROM "test"`, isPrepared: true}, + ) +} + +func (dsgs *deleteSQLGeneratorSuite) TestGenerate_withLimit() { + opts := DefaultDialectOptions() + opts.SupportsLimitOnDelete = true + + dc := exp.NewDeleteClauses(). + SetFrom(exp.NewIdentifierExpression("", "test", "")). + SetLimit(1) + + dsgs.assertCases( + NewDeleteSQLGenerator("test", opts), + deleteTestCase{clause: dc, sql: `DELETE FROM "test" LIMIT 1`}, + deleteTestCase{clause: dc, sql: `DELETE FROM "test" LIMIT ?`, isPrepared: true, args: []interface{}{int64(1)}}, + ) + + opts.SupportsLimitOnDelete = false + dsgs.assertCases( + NewDeleteSQLGenerator("test", opts), + deleteTestCase{clause: dc, sql: `DELETE FROM "test"`}, + deleteTestCase{clause: dc, sql: `DELETE FROM "test"`, isPrepared: true}, + ) +} + +func (dsgs *deleteSQLGeneratorSuite) TestGenerate_withReturning() { + opts := DefaultDialectOptions() + opts.SupportsReturn = true + + dc := exp.NewDeleteClauses(). + SetFrom(exp.NewIdentifierExpression("", "test", "")). + SetReturning(exp.NewColumnListExpression("a", "b")) + + dsgs.assertCases( + NewDeleteSQLGenerator("test", opts), + deleteTestCase{clause: dc, sql: `DELETE FROM "test" RETURNING "a", "b"`}, + deleteTestCase{clause: dc, sql: `DELETE FROM "test" RETURNING "a", "b"`, isPrepared: true}, + ) + + opts.SupportsReturn = false + expectedErr := `goqu: dialect does not support RETURNING clause [dialect=test]` + dsgs.assertCases( + NewDeleteSQLGenerator("test", opts), + deleteTestCase{clause: dc, err: expectedErr}, + deleteTestCase{clause: dc, err: expectedErr, isPrepared: true}, + ) +} + +func TestDeleteSQLGenerator(t *testing.T) { + suite.Run(t, new(deleteSQLGeneratorSuite)) +} diff --git a/sqlgen/expression_sql_generator.go b/sqlgen/expression_sql_generator.go new file mode 100644 index 00000000..f7afd2b6 --- /dev/null +++ b/sqlgen/expression_sql_generator.go @@ -0,0 +1,568 @@ +package sqlgen + +import ( + "database/sql/driver" + "reflect" + "strconv" + "time" + "unicode/utf8" + + "github.com/doug-martin/goqu/v8/exp" + "github.com/doug-martin/goqu/v8/internal/errors" + "github.com/doug-martin/goqu/v8/internal/sb" + "github.com/doug-martin/goqu/v8/internal/util" +) + +type ( + // An adapter interface to be used by a Dataset to generate SQL for a specific dialect. + // See DefaultAdapter for a concrete implementation and examples. + ExpressionSQLGenerator interface { + Dialect() string + Generate(b sb.SQLBuilder, val interface{}) + } + // The default adapter. This class should be used when building a new adapter. When creating a new adapter you can + // either override methods, or more typically update default values. + // See (github.com/doug-martin/goqu/adapters/postgres) + expressionSQLGenerator struct { + dialect string + dialectOptions *SQLDialectOptions + } +) + +var ( + replacementRune = '?' + TrueLiteral = exp.NewLiteralExpression("TRUE") + FalseLiteral = exp.NewLiteralExpression("FALSE") + + errEmptyIdentifier = errors.New(`a empty identifier was encountered, please specify a "schema", "table" or "column"`) +) + +func errUnsupportedExpressionType(e exp.Expression) error { + return errors.New("unsupported expression type %T", e) +} + +func errUnsupportedIdentifierExpression(t interface{}) error { + return errors.New("unexpected col type must be string or LiteralExpression received %T", t) +} + +func errUnsupportedBooleanExpressionOperator(op exp.BooleanOperation) error { + return errors.New("boolean operator '%+v' not supported", op) +} + +func errUnsupportedRangeExpressionOperator(op exp.RangeOperation) error { + return errors.New("range operator %+v not supported", op) +} + +func NewExpressionSQLGenerator(dialect string, do *SQLDialectOptions) ExpressionSQLGenerator { + return &expressionSQLGenerator{dialect: dialect, dialectOptions: do} +} + +func (esg *expressionSQLGenerator) Dialect() string { + return esg.dialect +} + +func (esg *expressionSQLGenerator) Generate(b sb.SQLBuilder, val interface{}) { + if b.Error() != nil { + return + } + if val == nil { + esg.literalNil(b) + return + } + switch v := val.(type) { + case exp.Expression: + esg.expressionSQL(b, v) + case int: + esg.literalInt(b, int64(v)) + case int32: + esg.literalInt(b, int64(v)) + case int64: + esg.literalInt(b, v) + case float32: + esg.literalFloat(b, float64(v)) + case float64: + esg.literalFloat(b, v) + case string: + esg.literalString(b, v) + case bool: + esg.literalBool(b, v) + case time.Time: + esg.literalTime(b, v) + case *time.Time: + if v == nil { + esg.literalNil(b) + return + } + esg.literalTime(b, *v) + case driver.Valuer: + dVal, err := v.Value() + if err != nil { + b.SetError(err) + return + } + esg.Generate(b, dVal) + default: + esg.reflectSQL(b, val) + } +} + +func (esg *expressionSQLGenerator) reflectSQL(b sb.SQLBuilder, val interface{}) { + v := reflect.Indirect(reflect.ValueOf(val)) + valKind := v.Kind() + switch { + case util.IsInvalid(valKind): + esg.literalNil(b) + case util.IsSlice(valKind): + switch t := val.(type) { + case []byte: + esg.literalBytes(b, t) + case []exp.CommonTableExpression: + esg.commonTablesSliceSQL(b, t) + default: + esg.sliceValueSQL(b, v) + } + case util.IsInt(valKind): + esg.Generate(b, v.Int()) + case util.IsUint(valKind): + esg.Generate(b, int64(v.Uint())) + case util.IsFloat(valKind): + esg.Generate(b, v.Float()) + case util.IsString(valKind): + esg.Generate(b, v.String()) + case util.IsBool(valKind): + esg.Generate(b, v.Bool()) + default: + b.SetError(errors.NewEncodeError(val)) + } +} + +func (esg *expressionSQLGenerator) expressionSQL(b sb.SQLBuilder, expression exp.Expression) { + switch e := expression.(type) { + case exp.ColumnListExpression: + esg.columnListSQL(b, e) + case exp.ExpressionList: + esg.expressionListSQL(b, e) + case exp.LiteralExpression: + esg.literalExpressionSQL(b, e) + case exp.IdentifierExpression: + esg.identifierExpressionSQL(b, e) + case exp.AliasedExpression: + esg.aliasedExpressionSQL(b, e) + case exp.BooleanExpression: + esg.booleanExpressionSQL(b, e) + case exp.RangeExpression: + esg.rangeExpressionSQL(b, e) + case exp.OrderedExpression: + esg.orderedExpressionSQL(b, e) + case exp.UpdateExpression: + esg.updateExpressionSQL(b, e) + case exp.SQLFunctionExpression: + esg.sqlFunctionExpressionSQL(b, e) + case exp.CastExpression: + esg.castExpressionSQL(b, e) + case exp.AppendableExpression: + esg.appendableExpressionSQL(b, e) + case exp.CommonTableExpression: + esg.commonTableExpressionSQL(b, e) + case exp.CompoundExpression: + esg.compoundExpressionSQL(b, e) + case exp.Ex: + esg.expressionMapSQL(b, e) + case exp.ExOr: + esg.expressionOrMapSQL(b, e) + default: + b.SetError(errUnsupportedExpressionType(e)) + } +} + +// Generates a placeholder (e.g. ?, $1) +func (esg *expressionSQLGenerator) placeHolderSQL(b sb.SQLBuilder, i interface{}) { + b.WriteRunes(esg.dialectOptions.PlaceHolderRune) + if esg.dialectOptions.IncludePlaceholderNum { + b.WriteStrings(strconv.FormatInt(int64(b.CurrentArgPosition()), 10)) + } + b.WriteArg(i) +} + +// Generates creates the sql for a sub select on a Dataset +func (esg *expressionSQLGenerator) appendableExpressionSQL(b sb.SQLBuilder, a exp.AppendableExpression) { + b.WriteRunes(esg.dialectOptions.LeftParenRune) + a.AppendSQL(b) + b.WriteRunes(esg.dialectOptions.RightParenRune) + c := a.GetClauses() + if c != nil { + alias := c.Alias() + if alias != nil { + b.Write(esg.dialectOptions.AsFragment) + esg.Generate(b, alias) + } + } +} + +// Quotes an identifier (e.g. "col", "table"."col" +func (esg *expressionSQLGenerator) identifierExpressionSQL(b sb.SQLBuilder, ident exp.IdentifierExpression) { + if ident.IsEmpty() { + b.SetError(errEmptyIdentifier) + return + } + schema, table, col := ident.GetSchema(), ident.GetTable(), ident.GetCol() + if schema != esg.dialectOptions.EmptyString { + b.WriteRunes(esg.dialectOptions.QuoteRune). + WriteStrings(schema). + WriteRunes(esg.dialectOptions.QuoteRune) + } + if table != esg.dialectOptions.EmptyString { + if schema != esg.dialectOptions.EmptyString { + b.WriteRunes(esg.dialectOptions.PeriodRune) + } + b.WriteRunes(esg.dialectOptions.QuoteRune). + WriteStrings(table). + WriteRunes(esg.dialectOptions.QuoteRune) + } + switch t := col.(type) { + case nil: + case string: + if col != esg.dialectOptions.EmptyString { + if table != esg.dialectOptions.EmptyString || schema != esg.dialectOptions.EmptyString { + b.WriteRunes(esg.dialectOptions.PeriodRune) + } + b.WriteRunes(esg.dialectOptions.QuoteRune). + WriteStrings(t). + WriteRunes(esg.dialectOptions.QuoteRune) + } + case exp.LiteralExpression: + if table != esg.dialectOptions.EmptyString || schema != esg.dialectOptions.EmptyString { + b.WriteRunes(esg.dialectOptions.PeriodRune) + } + esg.Generate(b, t) + default: + b.SetError(errUnsupportedIdentifierExpression(col)) + } +} + +// Generates SQL NULL value +func (esg *expressionSQLGenerator) literalNil(b sb.SQLBuilder) { + b.Write(esg.dialectOptions.Null) +} + +// Generates SQL bool literal, (e.g. TRUE, FALSE, mysql 1, 0, sqlite3 1, 0) +func (esg *expressionSQLGenerator) literalBool(b sb.SQLBuilder, bl bool) { + if b.IsPrepared() { + esg.placeHolderSQL(b, bl) + return + } + if bl { + b.Write(esg.dialectOptions.True) + } else { + b.Write(esg.dialectOptions.False) + } +} + +// Generates SQL for a time.Time value +func (esg *expressionSQLGenerator) literalTime(b sb.SQLBuilder, t time.Time) { + if b.IsPrepared() { + esg.placeHolderSQL(b, t) + return + } + esg.Generate(b, t.Format(esg.dialectOptions.TimeFormat)) +} + +// Generates SQL for a Float Value +func (esg *expressionSQLGenerator) literalFloat(b sb.SQLBuilder, f float64) { + if b.IsPrepared() { + esg.placeHolderSQL(b, f) + return + } + b.WriteStrings(strconv.FormatFloat(f, 'f', -1, 64)) +} + +// Generates SQL for an int value +func (esg *expressionSQLGenerator) literalInt(b sb.SQLBuilder, i int64) { + if b.IsPrepared() { + esg.placeHolderSQL(b, i) + return + } + b.WriteStrings(strconv.FormatInt(i, 10)) +} + +// Generates SQL for a string +func (esg *expressionSQLGenerator) literalString(b sb.SQLBuilder, s string) { + if b.IsPrepared() { + esg.placeHolderSQL(b, s) + return + } + b.WriteRunes(esg.dialectOptions.StringQuote) + for _, char := range s { + if e, ok := esg.dialectOptions.EscapedRunes[char]; ok { + b.Write(e) + } else { + b.WriteRunes(char) + } + } + + b.WriteRunes(esg.dialectOptions.StringQuote) +} + +// Generates SQL for a slice of bytes +func (esg *expressionSQLGenerator) literalBytes(b sb.SQLBuilder, bs []byte) { + if b.IsPrepared() { + esg.placeHolderSQL(b, bs) + return + } + b.WriteRunes(esg.dialectOptions.StringQuote) + i := 0 + for len(bs) > 0 { + char, l := utf8.DecodeRune(bs) + if e, ok := esg.dialectOptions.EscapedRunes[char]; ok { + b.Write(e) + } else { + b.WriteRunes(char) + } + i++ + bs = bs[l:] + } + b.WriteRunes(esg.dialectOptions.StringQuote) +} + +// Generates SQL for a slice of values (e.g. []int64{1,2,3,4} -> (1,2,3,4) +func (esg *expressionSQLGenerator) sliceValueSQL(b sb.SQLBuilder, slice reflect.Value) { + b.WriteRunes(esg.dialectOptions.LeftParenRune) + for i, l := 0, slice.Len(); i < l; i++ { + esg.Generate(b, slice.Index(i).Interface()) + if i < l-1 { + b.WriteRunes(esg.dialectOptions.CommaRune, esg.dialectOptions.SpaceRune) + } + } + b.WriteRunes(esg.dialectOptions.RightParenRune) +} + +// Generates SQL for an AliasedExpression (e.g. I("a").As("b") -> "a" AS "b") +func (esg *expressionSQLGenerator) aliasedExpressionSQL(b sb.SQLBuilder, aliased exp.AliasedExpression) { + esg.Generate(b, aliased.Aliased()) + b.Write(esg.dialectOptions.AsFragment) + esg.Generate(b, aliased.GetAs()) +} + +// Generates SQL for a BooleanExpresion (e.g. I("a").Eq(2) -> "a" = 2) +func (esg *expressionSQLGenerator) booleanExpressionSQL(b sb.SQLBuilder, operator exp.BooleanExpression) { + b.WriteRunes(esg.dialectOptions.LeftParenRune) + esg.Generate(b, operator.LHS()) + b.WriteRunes(esg.dialectOptions.SpaceRune) + operatorOp := operator.Op() + if val, ok := esg.dialectOptions.BooleanOperatorLookup[operatorOp]; ok { + b.Write(val) + } else { + b.SetError(errUnsupportedBooleanExpressionOperator(operatorOp)) + return + } + rhs := operator.RHS() + if (operatorOp == exp.IsOp || operatorOp == exp.IsNotOp) && esg.dialectOptions.UseLiteralIsBools { + if rhs == true { + rhs = TrueLiteral + } else if rhs == false { + rhs = FalseLiteral + } + } + b.WriteRunes(esg.dialectOptions.SpaceRune) + esg.Generate(b, rhs) + b.WriteRunes(esg.dialectOptions.RightParenRune) +} + +// Generates SQL for a RangeExpresion (e.g. I("a").Between(RangeVal{Start:2,End:5}) -> "a" BETWEEN 2 AND 5) +func (esg *expressionSQLGenerator) rangeExpressionSQL(b sb.SQLBuilder, operator exp.RangeExpression) { + b.WriteRunes(esg.dialectOptions.LeftParenRune) + esg.Generate(b, operator.LHS()) + b.WriteRunes(esg.dialectOptions.SpaceRune) + operatorOp := operator.Op() + if val, ok := esg.dialectOptions.RangeOperatorLookup[operatorOp]; ok { + b.Write(val) + } else { + b.SetError(errUnsupportedRangeExpressionOperator(operatorOp)) + return + } + rhs := operator.RHS() + b.WriteRunes(esg.dialectOptions.SpaceRune) + esg.Generate(b, rhs.Start()) + b.Write(esg.dialectOptions.AndFragment) + esg.Generate(b, rhs.End()) + b.WriteRunes(esg.dialectOptions.RightParenRune) +} + +// Generates SQL for an OrderedExpression (e.g. I("a").Asc() -> "a" ASC) +func (esg *expressionSQLGenerator) orderedExpressionSQL(b sb.SQLBuilder, order exp.OrderedExpression) { + esg.Generate(b, order.SortExpression()) + if order.IsAsc() { + b.Write(esg.dialectOptions.AscFragment) + } else { + b.Write(esg.dialectOptions.DescFragment) + } + switch order.NullSortType() { + case exp.NullsFirstSortType: + b.Write(esg.dialectOptions.NullsFirstFragment) + case exp.NullsLastSortType: + b.Write(esg.dialectOptions.NullsLastFragment) + } +} + +// Generates SQL for an ExpressionList (e.g. And(I("a").Eq("a"), I("b").Eq("b")) -> (("a" = 'a') AND ("b" = 'b'))) +func (esg *expressionSQLGenerator) expressionListSQL(b sb.SQLBuilder, expressionList exp.ExpressionList) { + if expressionList.IsEmpty() { + return + } + var op []byte + if expressionList.Type() == exp.AndType { + op = esg.dialectOptions.AndFragment + } else { + op = esg.dialectOptions.OrFragment + } + exps := expressionList.Expressions() + expLen := len(exps) - 1 + needsAppending := expLen > 0 + if needsAppending { + b.WriteRunes(esg.dialectOptions.LeftParenRune) + } else { + esg.Generate(b, exps[0]) + return + } + for i, e := range exps { + esg.Generate(b, e) + if i < expLen { + b.Write(op) + } + } + b.WriteRunes(esg.dialectOptions.RightParenRune) +} + +// Generates SQL for a ColumnListExpression +func (esg *expressionSQLGenerator) columnListSQL(b sb.SQLBuilder, columnList exp.ColumnListExpression) { + cols := columnList.Columns() + colLen := len(cols) + for i, col := range cols { + esg.Generate(b, col) + if i < colLen-1 { + b.WriteRunes(esg.dialectOptions.CommaRune, esg.dialectOptions.SpaceRune) + } + } +} + +// Generates SQL for an UpdateEpxresion +func (esg *expressionSQLGenerator) updateExpressionSQL(b sb.SQLBuilder, update exp.UpdateExpression) { + esg.Generate(b, update.Col()) + b.WriteRunes(esg.dialectOptions.SetOperatorRune) + esg.Generate(b, update.Val()) +} + +// Generates SQL for a LiteralExpression +// L("a + b") -> a + b +// L("a = ?", 1) -> a = 1 +func (esg *expressionSQLGenerator) literalExpressionSQL(b sb.SQLBuilder, literal exp.LiteralExpression) { + l := literal.Literal() + args := literal.Args() + argsLen := len(args) + if argsLen > 0 { + currIndex := 0 + for _, char := range l { + if char == replacementRune && currIndex < argsLen { + esg.Generate(b, args[currIndex]) + currIndex++ + } else { + b.WriteRunes(char) + } + } + } else { + b.WriteStrings(l) + } +} + +// Generates SQL for a SQLFunctionExpression +// COUNT(I("a")) -> COUNT("a") +func (esg *expressionSQLGenerator) sqlFunctionExpressionSQL(b sb.SQLBuilder, sqlFunc exp.SQLFunctionExpression) { + b.WriteStrings(sqlFunc.Name()) + esg.Generate(b, sqlFunc.Args()) +} + +// Generates SQL for a CastExpression +// I("a").Cast("NUMERIC") -> CAST("a" AS NUMERIC) +func (esg *expressionSQLGenerator) castExpressionSQL(b sb.SQLBuilder, cast exp.CastExpression) { + b.Write(esg.dialectOptions.CastFragment).WriteRunes(esg.dialectOptions.LeftParenRune) + esg.Generate(b, cast.Casted()) + b.Write(esg.dialectOptions.AsFragment) + esg.Generate(b, cast.Type()) + b.WriteRunes(esg.dialectOptions.RightParenRune) +} + +// Generates the sql for the WITH clauses for common table expressions (CTE) +func (esg *expressionSQLGenerator) commonTablesSliceSQL(b sb.SQLBuilder, ctes []exp.CommonTableExpression) { + if l := len(ctes); l > 0 { + if !esg.dialectOptions.SupportsWithCTE { + b.SetError(errCTENotSupported(esg.dialect)) + return + } + b.Write(esg.dialectOptions.WithFragment) + anyRecursive := false + for _, cte := range ctes { + anyRecursive = anyRecursive || cte.IsRecursive() + } + if anyRecursive { + if !esg.dialectOptions.SupportsWithCTERecursive { + b.SetError(errRecursiveCTENotSupported(esg.dialect)) + return + } + b.Write(esg.dialectOptions.RecursiveFragment) + } + for i, cte := range ctes { + esg.Generate(b, cte) + if i < l-1 { + b.WriteRunes(esg.dialectOptions.CommaRune, esg.dialectOptions.SpaceRune) + } + } + b.WriteRunes(esg.dialectOptions.SpaceRune) + } +} + +// Generates SQL for a CommonTableExpression +func (esg *expressionSQLGenerator) commonTableExpressionSQL(b sb.SQLBuilder, cte exp.CommonTableExpression) { + esg.Generate(b, cte.Name()) + b.Write(esg.dialectOptions.AsFragment) + esg.Generate(b, cte.SubQuery()) +} + +// Generates SQL for a CompoundExpression +func (esg *expressionSQLGenerator) compoundExpressionSQL(b sb.SQLBuilder, compound exp.CompoundExpression) { + switch compound.Type() { + case exp.UnionCompoundType: + b.Write(esg.dialectOptions.UnionFragment) + case exp.UnionAllCompoundType: + b.Write(esg.dialectOptions.UnionAllFragment) + case exp.IntersectCompoundType: + b.Write(esg.dialectOptions.IntersectFragment) + case exp.IntersectAllCompoundType: + b.Write(esg.dialectOptions.IntersectAllFragment) + } + if esg.dialectOptions.WrapCompoundsInParens { + b.WriteRunes(esg.dialectOptions.LeftParenRune) + compound.RHS().AppendSQL(b) + b.WriteRunes(esg.dialectOptions.RightParenRune) + } else { + compound.RHS().AppendSQL(b) + } + +} + +func (esg *expressionSQLGenerator) expressionMapSQL(b sb.SQLBuilder, ex exp.Ex) { + expressionList, err := ex.ToExpressions() + if err != nil { + b.SetError(err) + return + } + esg.Generate(b, expressionList) +} + +func (esg *expressionSQLGenerator) expressionOrMapSQL(b sb.SQLBuilder, ex exp.ExOr) { + expressionList, err := ex.ToExpressions() + if err != nil { + b.SetError(err) + return + } + esg.Generate(b, expressionList) +} diff --git a/sqlgen/expression_sql_generator_test.go b/sqlgen/expression_sql_generator_test.go new file mode 100644 index 00000000..5b28e07c --- /dev/null +++ b/sqlgen/expression_sql_generator_test.go @@ -0,0 +1,1198 @@ +package sqlgen + +import ( + "database/sql/driver" + "fmt" + "regexp" + "testing" + "time" + + "github.com/doug-martin/goqu/v8/exp" + "github.com/doug-martin/goqu/v8/internal/errors" + "github.com/doug-martin/goqu/v8/internal/sb" + "github.com/stretchr/testify/suite" +) + +var emptyArgs = make([]interface{}, 0) + +type testAppendableExpression struct { + exp.AppendableExpression + sql string + args []interface{} + err error + clauses exp.SelectClauses +} + +func newTestAppendableExpression(sql string, args []interface{}, err error, clauses exp.SelectClauses) exp.AppendableExpression { + if clauses == nil { + clauses = exp.NewSelectClauses() + } + return &testAppendableExpression{sql: sql, args: args, err: err, clauses: clauses} +} + +func (tae *testAppendableExpression) Expression() exp.Expression { + return tae +} + +func (tae *testAppendableExpression) GetClauses() exp.SelectClauses { + return tae.clauses +} + +func (tae *testAppendableExpression) Clone() exp.Expression { + return tae +} + +func (tae *testAppendableExpression) AppendSQL(b sb.SQLBuilder) { + if tae.err != nil { + b.SetError(tae.err) + return + } + b.WriteStrings(tae.sql) + if len(tae.args) > 0 { + b.WriteArg(tae.args...) + } +} + +type ( + expressionTestCase struct { + val interface{} + sql string + err string + isPrepared bool + args []interface{} + } + expressionSQLGeneratorSuite struct { + suite.Suite + } +) + +func (esgs *expressionSQLGeneratorSuite) assertCases(esg ExpressionSQLGenerator, cases ...expressionTestCase) { + for i, c := range cases { + b := sb.NewSQLBuilder(c.isPrepared) + esg.Generate(b, c.val) + actualSQL, actualArgs, err := b.ToSQL() + if c.err == "" { + esgs.NoError(err, "test case %d failed", i) + } else { + esgs.EqualError(err, c.err, "test case %d failed", i) + } + esgs.Equal(c.sql, actualSQL, "test case %d failed", i) + if c.isPrepared && c.args != nil || len(c.args) > 0 { + esgs.Equal(c.args, actualArgs, "test case %d failed", i) + } else { + esgs.Empty(actualArgs, "test case %d failed", i) + } + } +} + +func (esgs *expressionSQLGeneratorSuite) TestDialect() { + esg := NewExpressionSQLGenerator("test", DefaultDialectOptions()) + esgs.Equal("test", esg.Dialect()) +} + +func (esgs *expressionSQLGeneratorSuite) TestGenerate_ErroredBuilder() { + esg := NewExpressionSQLGenerator("test", DefaultDialectOptions()) + expectedErr := errors.New("test error") + b := sb.NewSQLBuilder(false).SetError(expectedErr) + esg.Generate(b, 1) + sql, args, err := b.ToSQL() + esgs.Equal(expectedErr, err) + esgs.Empty(sql) + esgs.Empty(args) + + b = sb.NewSQLBuilder(true).SetError(err) + esg.Generate(b, true) + sql, args, err = b.ToSQL() + esgs.Equal(expectedErr, err) + esgs.Empty(sql) + esgs.Empty(args) +} + +func (esgs *expressionSQLGeneratorSuite) TestGenerate_Invalid() { + var b *bool + esgs.assertCases( + NewExpressionSQLGenerator("test", DefaultDialectOptions()), + expressionTestCase{val: b, sql: "NULL"}, + expressionTestCase{val: b, sql: "NULL", isPrepared: true}, + ) +} + +func (esgs *expressionSQLGeneratorSuite) TestGenerate_UnsupportedType() { + type strct struct { + } + esgs.assertCases( + NewExpressionSQLGenerator("test", DefaultDialectOptions()), + expressionTestCase{val: strct{}, err: "goqu_encode_error: Unable to encode value {}"}, + expressionTestCase{val: strct{}, err: "goqu_encode_error: Unable to encode value {}", isPrepared: true}, + ) +} + +func (esgs *expressionSQLGeneratorSuite) TestGenerate_IncludePlaceholderNum() { + opts := DefaultDialectOptions() + opts.IncludePlaceholderNum = true + opts.PlaceHolderRune = '$' + ex := exp.Ex{ + "a": 1, + "b": true, + "c": false, + "d": []string{"a", "b", "c"}, + } + esgs.assertCases( + NewExpressionSQLGenerator("test", opts), + expressionTestCase{ + val: ex, + sql: `(("a" = 1) AND ("b" IS TRUE) AND ("c" IS FALSE) AND ("d" IN ('a', 'b', 'c')))`, + }, + expressionTestCase{ + val: ex, + sql: `(("a" = $1) AND ("b" IS TRUE) AND ("c" IS FALSE) AND ("d" IN ($2, $3, $4)))`, + isPrepared: true, + args: []interface{}{int64(1), "a", "b", "c"}, + }, + ) +} + +func (esgs *expressionSQLGeneratorSuite) TestGenerate_FloatTypes() { + var float float64 + esgs.assertCases( + NewExpressionSQLGenerator("test", DefaultDialectOptions()), + expressionTestCase{val: float32(10.01), sql: "10.010000228881836"}, + expressionTestCase{val: float32(10.01), sql: "?", isPrepared: true, args: []interface{}{float64(float32(10.01))}}, + + expressionTestCase{val: float64(10.01), sql: "10.01"}, + expressionTestCase{val: float64(10.01), sql: "?", isPrepared: true, args: []interface{}{float64(10.01)}}, + + expressionTestCase{val: &float, sql: "0"}, + expressionTestCase{val: &float, sql: "?", isPrepared: true, args: []interface{}{float}}, + ) +} + +func (esgs *expressionSQLGeneratorSuite) TestGenerate_IntTypes() { + var i int64 + ints := []interface{}{ + int(10), + int16(10), + int32(10), + int64(10), + uint(10), + uint16(10), + uint32(10), + uint64(10), + } + for _, i := range ints { + esgs.assertCases( + NewExpressionSQLGenerator("test", DefaultDialectOptions()), + expressionTestCase{val: i, sql: "10"}, + expressionTestCase{val: i, sql: "?", isPrepared: true, args: []interface{}{int64(10)}}, + ) + } + esgs.assertCases( + NewExpressionSQLGenerator("test", DefaultDialectOptions()), + expressionTestCase{val: &i, sql: "0"}, + expressionTestCase{val: &i, sql: "?", isPrepared: true, args: []interface{}{i}}, + ) +} + +func (esgs *expressionSQLGeneratorSuite) TestGenerate_StringTypes() { + var str string + esgs.assertCases( + NewExpressionSQLGenerator("test", DefaultDialectOptions()), + expressionTestCase{val: "Hello", sql: "'Hello'"}, + expressionTestCase{val: "Hello", sql: "?", isPrepared: true, args: []interface{}{"Hello"}}, + + expressionTestCase{val: "Hello'", sql: "'Hello'''"}, + expressionTestCase{val: "Hello'", sql: "?", isPrepared: true, args: []interface{}{"Hello'"}}, + + expressionTestCase{val: &str, sql: "''"}, + expressionTestCase{val: &str, sql: "?", isPrepared: true, args: []interface{}{str}}, + ) +} + +func (esgs *expressionSQLGeneratorSuite) TestGenerate_BytesTypes() { + esgs.assertCases( + NewExpressionSQLGenerator("test", DefaultDialectOptions()), + expressionTestCase{val: []byte("Hello"), sql: "'Hello'"}, + expressionTestCase{val: []byte("Hello"), sql: "?", isPrepared: true, args: []interface{}{[]byte("Hello")}}, + + expressionTestCase{val: []byte("Hello'"), sql: "'Hello'''"}, + expressionTestCase{val: []byte("Hello'"), sql: "?", isPrepared: true, args: []interface{}{[]byte("Hello'")}}, + ) +} + +func (esgs *expressionSQLGeneratorSuite) TestGenerate_BoolTypes() { + var bl bool + esgs.assertCases( + NewExpressionSQLGenerator("test", DefaultDialectOptions()), + expressionTestCase{val: true, sql: "TRUE"}, + expressionTestCase{val: true, sql: "?", isPrepared: true, args: []interface{}{true}}, + + expressionTestCase{val: false, sql: "FALSE"}, + expressionTestCase{val: false, sql: "?", isPrepared: true, args: []interface{}{false}}, + + expressionTestCase{val: &bl, sql: "FALSE"}, + expressionTestCase{val: &bl, sql: "?", isPrepared: true, args: []interface{}{bl}}, + ) +} + +func (esgs *expressionSQLGeneratorSuite) TestGenerate_TimeTypes() { + var nt *time.Time + + asiaShanghai, err := time.LoadLocation("Asia/Shanghai") + esgs.Require().NoError(err) + testDatas := []time.Time{ + time.Now().UTC(), + time.Now().In(asiaShanghai), + } + + for _, n := range testDatas { + now := n + esgs.assertCases( + NewExpressionSQLGenerator("test", DefaultDialectOptions()), + expressionTestCase{val: now, sql: "'" + now.Format(time.RFC3339Nano) + "'"}, + expressionTestCase{val: now, sql: "?", isPrepared: true, args: []interface{}{now}}, + + expressionTestCase{val: &now, sql: "'" + now.Format(time.RFC3339Nano) + "'"}, + expressionTestCase{val: &now, sql: "?", isPrepared: true, args: []interface{}{now}}, + ) + } + esgs.assertCases( + NewExpressionSQLGenerator("test", DefaultDialectOptions()), + expressionTestCase{val: nt, sql: "NULL"}, + expressionTestCase{val: nt, sql: "NULL", isPrepared: true}, + ) +} + +func (esgs *expressionSQLGeneratorSuite) TestGenerate_NilTypes() { + esgs.assertCases( + NewExpressionSQLGenerator("test", DefaultDialectOptions()), + expressionTestCase{val: nil, sql: "NULL"}, + expressionTestCase{val: nil, sql: "NULL", isPrepared: true}, + ) +} + +type datasetValuerType struct { + int int64 + err error +} + +func (j datasetValuerType) Value() (driver.Value, error) { + if j.err != nil { + return nil, j.err + } + return []byte(fmt.Sprintf("Hello World %d", j.int)), nil +} + +func (esgs *expressionSQLGeneratorSuite) TestGenerate_Valuer() { + err := errors.New("valuer error") + esgs.assertCases( + NewExpressionSQLGenerator("test", DefaultDialectOptions()), + expressionTestCase{val: datasetValuerType{int: 10}, sql: "'Hello World 10'"}, + expressionTestCase{ + val: datasetValuerType{int: 10}, sql: "?", isPrepared: true, args: []interface{}{[]byte("Hello World 10")}, + }, + + expressionTestCase{val: datasetValuerType{err: err}, err: "goqu: valuer error"}, + expressionTestCase{ + val: datasetValuerType{err: err}, isPrepared: true, err: "goqu: valuer error", + }, + ) +} + +func (esgs *expressionSQLGeneratorSuite) TestGenerate_Slice() { + esgs.assertCases( + NewExpressionSQLGenerator("test", DefaultDialectOptions()), + expressionTestCase{val: []string{"a", "b", "c"}, sql: `('a', 'b', 'c')`}, + expressionTestCase{ + val: []string{"a", "b", "c"}, sql: "(?, ?, ?)", isPrepared: true, args: []interface{}{"a", "b", "c"}, + }, + + expressionTestCase{val: []byte{'a', 'b', 'c'}, sql: `'abc'`}, + expressionTestCase{ + val: []byte{'a', 'b', 'c'}, sql: "?", isPrepared: true, args: []interface{}{[]byte{'a', 'b', 'c'}}, + }, + ) +} + +type unknownExpression struct { +} + +func (ue unknownExpression) Expression() exp.Expression { + return ue +} +func (ue unknownExpression) Clone() exp.Expression { + return ue +} +func (esgs *expressionSQLGeneratorSuite) TestGenerateUnsupportedExpression() { + errMsg := "goqu: unsupported expression type sqlgen.unknownExpression" + esgs.assertCases( + NewExpressionSQLGenerator("test", DefaultDialectOptions()), + expressionTestCase{val: unknownExpression{}, err: errMsg}, + expressionTestCase{ + val: unknownExpression{}, isPrepared: true, err: errMsg}, + ) +} + +func (esgs *expressionSQLGeneratorSuite) TestGenerate_AppendableExpression() { + ti := exp.NewIdentifierExpression("", "b", "") + a := newTestAppendableExpression(`select * from "a"`, []interface{}{}, nil, nil) + aliasedA := newTestAppendableExpression(`select * from "a"`, []interface{}{}, nil, exp.NewSelectClauses().SetAlias(ti)) + argsA := newTestAppendableExpression(`select * from "a" where x=?`, []interface{}{true}, nil, exp.NewSelectClauses().SetAlias(ti)) + ae := newTestAppendableExpression(`select * from "a"`, emptyArgs, errors.New("expected error"), nil) + + esgs.assertCases( + NewExpressionSQLGenerator("test", DefaultDialectOptions()), + expressionTestCase{val: a, sql: `(select * from "a")`}, + expressionTestCase{val: a, sql: `(select * from "a")`, isPrepared: true}, + + expressionTestCase{val: aliasedA, sql: `(select * from "a") AS "b"`}, + expressionTestCase{val: aliasedA, sql: `(select * from "a") AS "b"`, isPrepared: true}, + + expressionTestCase{val: ae, err: "goqu: expected error"}, + expressionTestCase{val: ae, err: "goqu: expected error", isPrepared: true}, + + expressionTestCase{val: argsA, sql: `(select * from "a" where x=?) AS "b"`, args: []interface{}{true}}, + expressionTestCase{val: argsA, sql: `(select * from "a" where x=?) AS "b"`, isPrepared: true, args: []interface{}{true}}, + ) +} + +func (esgs *expressionSQLGeneratorSuite) TestGenerate_ColumnList() { + cl := exp.NewColumnListExpression("a", exp.NewLiteralExpression("true")) + esgs.assertCases( + NewExpressionSQLGenerator("test", DefaultDialectOptions()), + expressionTestCase{val: cl, sql: `"a", true`}, + expressionTestCase{val: cl, sql: `"a", true`, isPrepared: true}, + ) +} + +func (esgs *expressionSQLGeneratorSuite) TestGenerate_ExpressionList() { + + andEl := exp.NewExpressionList( + exp.AndType, + exp.NewIdentifierExpression("", "", "a").Eq("b"), + exp.NewIdentifierExpression("", "", "c").Neq(1), + ) + + orEl := exp.NewExpressionList( + exp.OrType, + exp.NewIdentifierExpression("", "", "a").Eq("b"), + exp.NewIdentifierExpression("", "", "c").Neq(1), + ) + + andOrEl := exp.NewExpressionList(exp.OrType, + exp.NewIdentifierExpression("", "", "a").Eq("b"), + exp.NewExpressionList(exp.AndType, + exp.NewIdentifierExpression("", "", "c").Neq(1), + exp.NewIdentifierExpression("", "", "d").Eq(exp.NewLiteralExpression("NOW()")), + ), + ) + + esgs.assertCases( + NewExpressionSQLGenerator("test", DefaultDialectOptions()), + expressionTestCase{val: andEl, sql: `(("a" = 'b') AND ("c" != 1))`}, + expressionTestCase{ + val: andEl, sql: `(("a" = ?) AND ("c" != ?))`, isPrepared: true, args: []interface{}{"b", int64(1)}, + }, + + expressionTestCase{val: orEl, sql: `(("a" = 'b') OR ("c" != 1))`}, + expressionTestCase{ + val: orEl, sql: `(("a" = ?) OR ("c" != ?))`, isPrepared: true, args: []interface{}{"b", int64(1)}, + }, + + expressionTestCase{val: andOrEl, sql: `(("a" = 'b') OR (("c" != 1) AND ("d" = NOW())))`}, + expressionTestCase{ + val: andOrEl, + sql: `(("a" = ?) OR (("c" != ?) AND ("d" = NOW())))`, + isPrepared: true, + args: []interface{}{"b", int64(1)}, + }, + ) +} + +func (esgs *expressionSQLGeneratorSuite) TestGenerate_LiteralExpression() { + noArgsL := exp.NewLiteralExpression(`"b"::DATE = '2010-09-02'`) + argsL := exp.NewLiteralExpression(`"b" = ? or "c" = ? or d IN ?`, "a", 1, []int{1, 2, 3, 4}) + + esgs.assertCases( + NewExpressionSQLGenerator("test", DefaultDialectOptions()), + expressionTestCase{val: noArgsL, sql: `"b"::DATE = '2010-09-02'`}, + expressionTestCase{val: noArgsL, sql: `"b"::DATE = '2010-09-02'`, isPrepared: true}, + + expressionTestCase{val: argsL, sql: `"b" = 'a' or "c" = 1 or d IN (1, 2, 3, 4)`}, + expressionTestCase{ + val: argsL, + sql: `"b" = ? or "c" = ? or d IN (?, ?, ?, ?)`, + isPrepared: true, + args: []interface{}{ + "a", + int64(1), + int64(1), + int64(2), + int64(3), + int64(4), + }, + }, + ) +} + +func (esgs *expressionSQLGeneratorSuite) TestGenerate_AliasedExpression() { + aliasedI := exp.NewIdentifierExpression("", "", "a").As("b") + aliasedWithII := exp.NewIdentifierExpression("", "", "a"). + As(exp.NewIdentifierExpression("", "", "b")) + aliasedL := exp.NewLiteralExpression("count(*)").As("count") + + esgs.assertCases( + NewExpressionSQLGenerator("test", DefaultDialectOptions()), + expressionTestCase{val: aliasedI, sql: `"a" AS "b"`}, + expressionTestCase{val: aliasedI, sql: `"a" AS "b"`, isPrepared: true}, + + expressionTestCase{val: aliasedWithII, sql: `"a" AS "b"`}, + expressionTestCase{val: aliasedWithII, sql: `"a" AS "b"`, isPrepared: true}, + + expressionTestCase{val: aliasedL, sql: `count(*) AS "count"`}, + expressionTestCase{val: aliasedL, sql: `count(*) AS "count"`, isPrepared: true}, + ) +} + +func (esgs *expressionSQLGeneratorSuite) TestGenerate_BooleanExpression() { + ae := newTestAppendableExpression(`SELECT "id" FROM "test2"`, emptyArgs, nil, nil) + re := regexp.MustCompile("(a|b)") + ident := exp.NewIdentifierExpression("", "", "a") + + esgs.assertCases( + NewExpressionSQLGenerator("test", DefaultDialectOptions()), + expressionTestCase{val: ident.Eq(1), sql: `("a" = 1)`}, + expressionTestCase{val: ident.Eq(1), sql: `("a" = ?)`, isPrepared: true, args: []interface{}{int64(1)}}, + + expressionTestCase{val: ident.Eq(true), sql: `("a" IS TRUE)`}, + expressionTestCase{val: ident.Eq(true), sql: `("a" IS TRUE)`, isPrepared: true}, + + expressionTestCase{val: ident.Eq(false), sql: `("a" IS FALSE)`}, + expressionTestCase{val: ident.Eq(false), sql: `("a" IS FALSE)`, isPrepared: true}, + + expressionTestCase{val: ident.Eq(nil), sql: `("a" IS NULL)`}, + expressionTestCase{val: ident.Eq(nil), sql: `("a" IS NULL)`, isPrepared: true}, + + expressionTestCase{val: ident.Eq([]int64{1, 2, 3}), sql: `("a" IN (1, 2, 3))`}, + expressionTestCase{val: ident.Eq([]int64{1, 2, 3}), sql: `("a" IN (?, ?, ?))`, isPrepared: true, args: []interface{}{ + int64(1), int64(2), int64(3), + }}, + + expressionTestCase{val: ident.Eq(ae), sql: `("a" IN (SELECT "id" FROM "test2"))`}, + expressionTestCase{val: ident.Eq(ae), sql: `("a" IN (SELECT "id" FROM "test2"))`, isPrepared: true}, + + expressionTestCase{val: ident.Neq(1), sql: `("a" != 1)`}, + expressionTestCase{val: ident.Neq(1), sql: `("a" != ?)`, isPrepared: true, args: []interface{}{int64(1)}}, + + expressionTestCase{val: ident.Neq(true), sql: `("a" IS NOT TRUE)`}, + expressionTestCase{val: ident.Neq(true), sql: `("a" IS NOT TRUE)`, isPrepared: true}, + + expressionTestCase{val: ident.Neq(false), sql: `("a" IS NOT FALSE)`}, + expressionTestCase{val: ident.Neq(false), sql: `("a" IS NOT FALSE)`, isPrepared: true}, + + expressionTestCase{val: ident.Neq(nil), sql: `("a" IS NOT NULL)`}, + expressionTestCase{val: ident.Neq(nil), sql: `("a" IS NOT NULL)`, isPrepared: true}, + + expressionTestCase{val: ident.Neq([]int64{1, 2, 3}), sql: `("a" NOT IN (1, 2, 3))`}, + expressionTestCase{val: ident.Neq([]int64{1, 2, 3}), sql: `("a" NOT IN (?, ?, ?))`, isPrepared: true, args: []interface{}{ + int64(1), int64(2), int64(3), + }}, + + expressionTestCase{val: ident.Neq(ae), sql: `("a" NOT IN (SELECT "id" FROM "test2"))`}, + expressionTestCase{val: ident.Neq(ae), sql: `("a" NOT IN (SELECT "id" FROM "test2"))`, isPrepared: true}, + + expressionTestCase{val: ident.Is(true), sql: `("a" IS TRUE)`}, + expressionTestCase{val: ident.Is(true), sql: `("a" IS TRUE)`, isPrepared: true}, + + expressionTestCase{val: ident.Is(false), sql: `("a" IS FALSE)`}, + expressionTestCase{val: ident.Is(false), sql: `("a" IS FALSE)`, isPrepared: true}, + + expressionTestCase{val: ident.Is(nil), sql: `("a" IS NULL)`}, + expressionTestCase{val: ident.Is(nil), sql: `("a" IS NULL)`, isPrepared: true}, + + expressionTestCase{val: ident.IsNot(true), sql: `("a" IS NOT TRUE)`}, + expressionTestCase{val: ident.IsNot(true), sql: `("a" IS NOT TRUE)`, isPrepared: true}, + + expressionTestCase{val: ident.IsNot(false), sql: `("a" IS NOT FALSE)`}, + expressionTestCase{val: ident.IsNot(false), sql: `("a" IS NOT FALSE)`, isPrepared: true}, + + expressionTestCase{val: ident.IsNot(nil), sql: `("a" IS NOT NULL)`}, + expressionTestCase{val: ident.IsNot(nil), sql: `("a" IS NOT NULL)`, isPrepared: true}, + + expressionTestCase{val: ident.Gt(1), sql: `("a" > 1)`}, + expressionTestCase{val: ident.Gt(1), sql: `("a" > ?)`, isPrepared: true, args: []interface{}{int64(1)}}, + + expressionTestCase{val: ident.Gte(1), sql: `("a" >= 1)`}, + expressionTestCase{val: ident.Gte(1), sql: `("a" >= ?)`, isPrepared: true, args: []interface{}{int64(1)}}, + + expressionTestCase{val: ident.Lt(1), sql: `("a" < 1)`}, + expressionTestCase{val: ident.Lt(1), sql: `("a" < ?)`, isPrepared: true, args: []interface{}{int64(1)}}, + + expressionTestCase{val: ident.Lte(1), sql: `("a" <= 1)`}, + expressionTestCase{val: ident.Lte(1), sql: `("a" <= ?)`, isPrepared: true, args: []interface{}{int64(1)}}, + + expressionTestCase{val: ident.In([]int64{1, 2, 3}), sql: `("a" IN (1, 2, 3))`}, + expressionTestCase{val: ident.In([]int64{1, 2, 3}), sql: `("a" IN (?, ?, ?))`, isPrepared: true, args: []interface{}{ + int64(1), int64(2), int64(3), + }}, + + expressionTestCase{val: ident.In(ae), sql: `("a" IN ((SELECT "id" FROM "test2")))`}, + expressionTestCase{val: ident.In(ae), sql: `("a" IN ((SELECT "id" FROM "test2")))`, isPrepared: true}, + + expressionTestCase{val: ident.NotIn([]int64{1, 2, 3}), sql: `("a" NOT IN (1, 2, 3))`}, + expressionTestCase{val: ident.NotIn([]int64{1, 2, 3}), sql: `("a" NOT IN (?, ?, ?))`, isPrepared: true, args: []interface{}{ + int64(1), int64(2), int64(3), + }}, + + expressionTestCase{val: ident.NotIn(ae), sql: `("a" NOT IN ((SELECT "id" FROM "test2")))`}, + expressionTestCase{val: ident.NotIn(ae), sql: `("a" NOT IN ((SELECT "id" FROM "test2")))`, isPrepared: true}, + + expressionTestCase{val: ident.Like("a%"), sql: `("a" LIKE 'a%')`}, + expressionTestCase{val: ident.Like("a%"), sql: `("a" LIKE ?)`, isPrepared: true, args: []interface{}{"a%"}}, + + expressionTestCase{val: ident.Like(re), sql: `("a" ~ '(a|b)')`}, + expressionTestCase{val: ident.Like(re), sql: `("a" ~ ?)`, isPrepared: true, args: []interface{}{"(a|b)"}}, + + expressionTestCase{val: ident.ILike("a%"), sql: `("a" ILIKE 'a%')`}, + expressionTestCase{val: ident.ILike("a%"), sql: `("a" ILIKE ?)`, isPrepared: true, args: []interface{}{"a%"}}, + + expressionTestCase{val: ident.ILike(re), sql: `("a" ~* '(a|b)')`}, + expressionTestCase{val: ident.ILike(re), sql: `("a" ~* ?)`, isPrepared: true, args: []interface{}{"(a|b)"}}, + + expressionTestCase{val: ident.NotLike("a%"), sql: `("a" NOT LIKE 'a%')`}, + expressionTestCase{val: ident.NotLike("a%"), sql: `("a" NOT LIKE ?)`, isPrepared: true, args: []interface{}{"a%"}}, + + expressionTestCase{val: ident.NotLike(re), sql: `("a" !~ '(a|b)')`}, + expressionTestCase{val: ident.NotLike(re), sql: `("a" !~ ?)`, isPrepared: true, args: []interface{}{"(a|b)"}}, + + expressionTestCase{val: ident.NotILike("a%"), sql: `("a" NOT ILIKE 'a%')`}, + expressionTestCase{val: ident.NotILike("a%"), sql: `("a" NOT ILIKE ?)`, isPrepared: true, args: []interface{}{"a%"}}, + + expressionTestCase{val: ident.NotILike(re), sql: `("a" !~* '(a|b)')`}, + expressionTestCase{val: ident.NotILike(re), sql: `("a" !~* ?)`, isPrepared: true, args: []interface{}{"(a|b)"}}, + ) + + opts := DefaultDialectOptions() + opts.BooleanOperatorLookup = map[exp.BooleanOperation][]byte{} + esgs.assertCases( + NewExpressionSQLGenerator("test", opts), + expressionTestCase{val: ident.Eq(1), err: "goqu: boolean operator 'eq' not supported"}, + expressionTestCase{val: ident.Neq(1), err: "goqu: boolean operator 'neq' not supported"}, + expressionTestCase{val: ident.Is(true), err: "goqu: boolean operator 'is' not supported"}, + expressionTestCase{val: ident.IsNot(true), err: "goqu: boolean operator 'isnot' not supported"}, + expressionTestCase{val: ident.Gt(1), err: "goqu: boolean operator 'gt' not supported"}, + expressionTestCase{val: ident.Gte(1), err: "goqu: boolean operator 'gte' not supported"}, + expressionTestCase{val: ident.Lt(1), err: "goqu: boolean operator 'lt' not supported"}, + expressionTestCase{val: ident.Lte(1), err: "goqu: boolean operator 'lte' not supported"}, + expressionTestCase{val: ident.In([]int64{1, 2, 3}), err: "goqu: boolean operator 'in' not supported"}, + expressionTestCase{val: ident.NotIn([]int64{1, 2, 3}), err: "goqu: boolean operator 'notin' not supported"}, + expressionTestCase{val: ident.Like("a%"), err: "goqu: boolean operator 'like' not supported"}, + expressionTestCase{val: ident.Like(re), err: "goqu: boolean operator 'regexp like' not supported"}, + expressionTestCase{val: ident.ILike("a%"), err: "goqu: boolean operator 'ilike' not supported"}, + expressionTestCase{val: ident.ILike(re), err: "goqu: boolean operator 'regexp ilike' not supported"}, + expressionTestCase{val: ident.NotLike("a%"), err: "goqu: boolean operator 'notlike' not supported"}, + expressionTestCase{val: ident.NotLike(re), err: "goqu: boolean operator 'regexp notlike' not supported"}, + expressionTestCase{val: ident.NotILike("a%"), err: "goqu: boolean operator 'notilike' not supported"}, + expressionTestCase{val: ident.NotILike(re), err: "goqu: boolean operator 'regexp notilike' not supported"}, + ) +} + +func (esgs *expressionSQLGeneratorSuite) TestGenerate_RangeExpression() { + betweenNum := exp.NewIdentifierExpression("", "", "a"). + Between(exp.NewRangeVal(1, 2)) + notBetweenNum := exp.NewIdentifierExpression("", "", "a"). + NotBetween(exp.NewRangeVal(1, 2)) + + betweenStr := exp.NewIdentifierExpression("", "", "a"). + Between(exp.NewRangeVal("aaa", "zzz")) + notBetweenStr := exp.NewIdentifierExpression("", "", "a"). + NotBetween(exp.NewRangeVal("aaa", "zzz")) + esgs.assertCases( + NewExpressionSQLGenerator("test", DefaultDialectOptions()), + expressionTestCase{val: betweenNum, sql: `("a" BETWEEN 1 AND 2)`}, + expressionTestCase{val: betweenNum, sql: `("a" BETWEEN ? AND ?)`, isPrepared: true, args: []interface{}{ + int64(1), + int64(2), + }}, + + expressionTestCase{val: notBetweenNum, sql: `("a" NOT BETWEEN 1 AND 2)`}, + expressionTestCase{val: notBetweenNum, sql: `("a" NOT BETWEEN ? AND ?)`, isPrepared: true, args: []interface{}{ + int64(1), + int64(2), + }}, + + expressionTestCase{val: betweenStr, sql: `("a" BETWEEN 'aaa' AND 'zzz')`}, + expressionTestCase{val: betweenStr, sql: `("a" BETWEEN ? AND ?)`, isPrepared: true, args: []interface{}{ + "aaa", + "zzz", + }}, + + expressionTestCase{val: notBetweenStr, sql: `("a" NOT BETWEEN 'aaa' AND 'zzz')`}, + expressionTestCase{val: notBetweenStr, sql: `("a" NOT BETWEEN ? AND ?)`, isPrepared: true, args: []interface{}{ + "aaa", + "zzz", + }}, + ) + + opts := DefaultDialectOptions() + opts.RangeOperatorLookup = map[exp.RangeOperation][]byte{} + esgs.assertCases( + NewExpressionSQLGenerator("test", opts), + expressionTestCase{val: betweenNum, err: "goqu: range operator between not supported"}, + expressionTestCase{val: betweenNum, err: "goqu: range operator between not supported"}, + + expressionTestCase{val: notBetweenNum, err: "goqu: range operator not between not supported"}, + expressionTestCase{val: notBetweenNum, err: "goqu: range operator not between not supported"}, + + expressionTestCase{val: betweenStr, err: "goqu: range operator between not supported"}, + expressionTestCase{val: betweenStr, err: "goqu: range operator between not supported"}, + + expressionTestCase{val: notBetweenStr, err: "goqu: range operator not between not supported"}, + expressionTestCase{val: notBetweenStr, err: "goqu: range operator not between not supported"}, + ) +} + +func (esgs *expressionSQLGeneratorSuite) TestGenerate_OrderedExpression() { + + asc := exp.NewIdentifierExpression("", "", "a").Asc() + ascNf := exp.NewIdentifierExpression("", "", "a").Asc().NullsFirst() + ascNl := exp.NewIdentifierExpression("", "", "a").Asc().NullsLast() + + desc := exp.NewIdentifierExpression("", "", "a").Desc() + descNf := exp.NewIdentifierExpression("", "", "a").Desc().NullsFirst() + descNl := exp.NewIdentifierExpression("", "", "a").Desc().NullsLast() + + esgs.assertCases( + NewExpressionSQLGenerator("test", DefaultDialectOptions()), + expressionTestCase{val: asc, sql: `"a" ASC`}, + expressionTestCase{val: asc, sql: `"a" ASC`, isPrepared: true}, + + expressionTestCase{val: ascNf, sql: `"a" ASC NULLS FIRST`}, + expressionTestCase{val: ascNf, sql: `"a" ASC NULLS FIRST`, isPrepared: true}, + + expressionTestCase{val: ascNl, sql: `"a" ASC NULLS LAST`}, + expressionTestCase{val: ascNl, sql: `"a" ASC NULLS LAST`, isPrepared: true}, + + expressionTestCase{val: desc, sql: `"a" DESC`}, + expressionTestCase{val: desc, sql: `"a" DESC`, isPrepared: true}, + + expressionTestCase{val: descNf, sql: `"a" DESC NULLS FIRST`}, + expressionTestCase{val: descNf, sql: `"a" DESC NULLS FIRST`, isPrepared: true}, + + expressionTestCase{val: descNl, sql: `"a" DESC NULLS LAST`}, + expressionTestCase{val: descNl, sql: `"a" DESC NULLS LAST`, isPrepared: true}, + ) +} + +func (esgs *expressionSQLGeneratorSuite) TestGenerate_UpdateExpression() { + ue := exp.NewIdentifierExpression("", "", "a").Set(1) + esgs.assertCases( + NewExpressionSQLGenerator("test", DefaultDialectOptions()), + expressionTestCase{val: ue, sql: `"a"=1`}, + expressionTestCase{val: ue, sql: `"a"=?`, isPrepared: true, args: []interface{}{int64(1)}}, + ) +} + +func (esgs *expressionSQLGeneratorSuite) TestGenerate_SQLFunctionExpression() { + + min := exp.NewSQLFunctionExpression("MIN", exp.NewIdentifierExpression("", "", "a")) + coalesce := exp.NewSQLFunctionExpression("COALESCE", exp.NewIdentifierExpression("", "", "a"), "a") + esgs.assertCases( + NewExpressionSQLGenerator("test", DefaultDialectOptions()), + expressionTestCase{val: min, sql: `MIN("a")`}, + expressionTestCase{val: min, sql: `MIN("a")`, isPrepared: true}, + + expressionTestCase{val: coalesce, sql: `COALESCE("a", 'a')`}, + expressionTestCase{val: coalesce, sql: `COALESCE("a", ?)`, isPrepared: true, args: []interface{}{"a"}}, + ) +} + +func (esgs *expressionSQLGeneratorSuite) TestGenerate_CastExpression() { + cast := exp.NewIdentifierExpression("", "", "a").Cast("DATE") + esgs.assertCases( + NewExpressionSQLGenerator("test", DefaultDialectOptions()), + expressionTestCase{val: cast, sql: `CAST("a" AS DATE)`}, + expressionTestCase{val: cast, sql: `CAST("a" AS DATE)`, isPrepared: true}, + ) +} + +// Generates the sql for the WITH clauses for common table expressions (CTE) +func (esgs *expressionSQLGeneratorSuite) TestGenerate_CommonTableExpressionSlice() { + ae := newTestAppendableExpression(`SELECT * FROM "b"`, emptyArgs, nil, nil) + + cteNoArgs := []exp.CommonTableExpression{ + exp.NewCommonTableExpression(false, "a", ae), + } + cteArgs := []exp.CommonTableExpression{ + exp.NewCommonTableExpression(false, "a(x,y)", ae), + } + + cteRecursiveNoArgs := []exp.CommonTableExpression{ + exp.NewCommonTableExpression(true, "a", ae), + } + cteRecursiveArgs := []exp.CommonTableExpression{ + exp.NewCommonTableExpression(true, "a(x,y)", ae), + } + + allCtes := []exp.CommonTableExpression{ + exp.NewCommonTableExpression(false, "a", ae), + exp.NewCommonTableExpression(false, "a(x,y)", ae), + } + + allRecursiveCtes := []exp.CommonTableExpression{ + exp.NewCommonTableExpression(true, "a", ae), + exp.NewCommonTableExpression(true, "a(x,y)", ae), + } + + esgs.assertCases( + NewExpressionSQLGenerator("test", DefaultDialectOptions()), + expressionTestCase{val: cteNoArgs, sql: `WITH a AS (SELECT * FROM "b") `}, + expressionTestCase{val: cteNoArgs, sql: `WITH a AS (SELECT * FROM "b") `, isPrepared: true}, + + expressionTestCase{val: cteArgs, sql: `WITH a(x,y) AS (SELECT * FROM "b") `}, + expressionTestCase{val: cteArgs, sql: `WITH a(x,y) AS (SELECT * FROM "b") `, isPrepared: true}, + + expressionTestCase{val: cteRecursiveNoArgs, sql: `WITH RECURSIVE a AS (SELECT * FROM "b") `}, + expressionTestCase{val: cteRecursiveNoArgs, sql: `WITH RECURSIVE a AS (SELECT * FROM "b") `, isPrepared: true}, + + expressionTestCase{val: cteRecursiveArgs, sql: `WITH RECURSIVE a(x,y) AS (SELECT * FROM "b") `}, + expressionTestCase{val: cteRecursiveArgs, sql: `WITH RECURSIVE a(x,y) AS (SELECT * FROM "b") `, isPrepared: true}, + + expressionTestCase{val: allCtes, sql: `WITH a AS (SELECT * FROM "b"), a(x,y) AS (SELECT * FROM "b") `}, + expressionTestCase{val: allCtes, sql: `WITH a AS (SELECT * FROM "b"), a(x,y) AS (SELECT * FROM "b") `, isPrepared: true}, + + expressionTestCase{val: allRecursiveCtes, sql: `WITH RECURSIVE a AS (SELECT * FROM "b"), a(x,y) AS (SELECT * FROM "b") `}, + expressionTestCase{ + val: allRecursiveCtes, + sql: `WITH RECURSIVE a AS (SELECT * FROM "b"), a(x,y) AS (SELECT * FROM "b") `, + isPrepared: true, + }, + ) + opts := DefaultDialectOptions() + opts.SupportsWithCTE = false + esgs.assertCases( + NewExpressionSQLGenerator("test", opts), + expressionTestCase{val: cteNoArgs, err: "goqu: dialect does not support CTE WITH clause [dialect=test]"}, + expressionTestCase{val: cteNoArgs, err: "goqu: dialect does not support CTE WITH clause [dialect=test]", isPrepared: true}, + + expressionTestCase{val: cteArgs, err: "goqu: dialect does not support CTE WITH clause [dialect=test]"}, + expressionTestCase{val: cteArgs, err: "goqu: dialect does not support CTE WITH clause [dialect=test]", isPrepared: true}, + + expressionTestCase{val: cteRecursiveNoArgs, err: "goqu: dialect does not support CTE WITH clause [dialect=test]"}, + expressionTestCase{val: cteRecursiveNoArgs, err: "goqu: dialect does not support CTE WITH clause [dialect=test]", isPrepared: true}, + + expressionTestCase{val: cteRecursiveArgs, err: "goqu: dialect does not support CTE WITH clause [dialect=test]"}, + expressionTestCase{val: cteRecursiveArgs, err: "goqu: dialect does not support CTE WITH clause [dialect=test]", isPrepared: true}, + ) + opts = DefaultDialectOptions() + opts.SupportsWithCTERecursive = false + esgs.assertCases( + NewExpressionSQLGenerator("test", opts), + expressionTestCase{val: cteNoArgs, sql: `WITH a AS (SELECT * FROM "b") `}, + expressionTestCase{val: cteNoArgs, sql: `WITH a AS (SELECT * FROM "b") `, isPrepared: true}, + + expressionTestCase{val: cteArgs, sql: `WITH a(x,y) AS (SELECT * FROM "b") `}, + expressionTestCase{val: cteArgs, sql: `WITH a(x,y) AS (SELECT * FROM "b") `, isPrepared: true}, + + expressionTestCase{ + val: cteRecursiveNoArgs, + err: "goqu: dialect does not support CTE WITH RECURSIVE clause [dialect=test]", + }, + expressionTestCase{ + val: cteRecursiveNoArgs, + err: "goqu: dialect does not support CTE WITH RECURSIVE clause [dialect=test]", + isPrepared: true, + }, + + expressionTestCase{ + val: cteRecursiveArgs, + err: "goqu: dialect does not support CTE WITH RECURSIVE clause [dialect=test]", + }, + expressionTestCase{ + val: cteRecursiveArgs, + err: "goqu: dialect does not support CTE WITH RECURSIVE clause [dialect=test]", + isPrepared: true, + }, + ) +} + +func (esgs *expressionSQLGeneratorSuite) TestGenerate_CommonTableExpression() { + ae := newTestAppendableExpression(`SELECT * FROM "b"`, emptyArgs, nil, nil) + + cteNoArgs := exp.NewCommonTableExpression(false, "a", ae) + cteArgs := exp.NewCommonTableExpression(false, "a(x,y)", ae) + + cteRecursiveNoArgs := exp.NewCommonTableExpression(true, "a", ae) + cteRecursiveArgs := exp.NewCommonTableExpression(true, "a(x,y)", ae) + + esgs.assertCases( + NewExpressionSQLGenerator("test", DefaultDialectOptions()), + expressionTestCase{val: cteNoArgs, sql: `a AS (SELECT * FROM "b")`}, + expressionTestCase{val: cteNoArgs, sql: `a AS (SELECT * FROM "b")`, isPrepared: true}, + + expressionTestCase{val: cteArgs, sql: `a(x,y) AS (SELECT * FROM "b")`}, + expressionTestCase{val: cteArgs, sql: `a(x,y) AS (SELECT * FROM "b")`, isPrepared: true}, + + expressionTestCase{val: cteRecursiveNoArgs, sql: `a AS (SELECT * FROM "b")`}, + expressionTestCase{val: cteRecursiveNoArgs, sql: `a AS (SELECT * FROM "b")`, isPrepared: true}, + + expressionTestCase{val: cteRecursiveArgs, sql: `a(x,y) AS (SELECT * FROM "b")`}, + expressionTestCase{val: cteRecursiveArgs, sql: `a(x,y) AS (SELECT * FROM "b")`, isPrepared: true}, + ) +} + +func (esgs *expressionSQLGeneratorSuite) TestGenerate_CompoundExpression() { + ae := newTestAppendableExpression(`SELECT * FROM "b"`, emptyArgs, nil, nil) + + u := exp.NewCompoundExpression(exp.UnionCompoundType, ae) + ua := exp.NewCompoundExpression(exp.UnionAllCompoundType, ae) + + i := exp.NewCompoundExpression(exp.IntersectCompoundType, ae) + ia := exp.NewCompoundExpression(exp.IntersectAllCompoundType, ae) + + esgs.assertCases( + NewExpressionSQLGenerator("test", DefaultDialectOptions()), + expressionTestCase{val: u, sql: ` UNION (SELECT * FROM "b")`}, + expressionTestCase{val: u, sql: ` UNION (SELECT * FROM "b")`, isPrepared: true}, + + expressionTestCase{val: ua, sql: ` UNION ALL (SELECT * FROM "b")`}, + expressionTestCase{val: ua, sql: ` UNION ALL (SELECT * FROM "b")`, isPrepared: true}, + + expressionTestCase{val: i, sql: ` INTERSECT (SELECT * FROM "b")`}, + expressionTestCase{val: i, sql: ` INTERSECT (SELECT * FROM "b")`, isPrepared: true}, + + expressionTestCase{val: ia, sql: ` INTERSECT ALL (SELECT * FROM "b")`}, + expressionTestCase{val: ia, sql: ` INTERSECT ALL (SELECT * FROM "b")`, isPrepared: true}, + ) + + opts := DefaultDialectOptions() + opts.WrapCompoundsInParens = false + esgs.assertCases( + NewExpressionSQLGenerator("test", opts), + expressionTestCase{val: u, sql: ` UNION SELECT * FROM "b"`}, + expressionTestCase{val: u, sql: ` UNION SELECT * FROM "b"`, isPrepared: true}, + + expressionTestCase{val: ua, sql: ` UNION ALL SELECT * FROM "b"`}, + expressionTestCase{val: ua, sql: ` UNION ALL SELECT * FROM "b"`, isPrepared: true}, + + expressionTestCase{val: i, sql: ` INTERSECT SELECT * FROM "b"`}, + expressionTestCase{val: i, sql: ` INTERSECT SELECT * FROM "b"`, isPrepared: true}, + + expressionTestCase{val: ia, sql: ` INTERSECT ALL SELECT * FROM "b"`}, + expressionTestCase{val: ia, sql: ` INTERSECT ALL SELECT * FROM "b"`, isPrepared: true}, + ) +} + +func (esgs *expressionSQLGeneratorSuite) TestGenerate_IdentifierExpression() { + col := exp.NewIdentifierExpression("", "", "col") + colStar := exp.NewIdentifierExpression("", "", "*") + table := exp.NewIdentifierExpression("", "table", "") + schema := exp.NewIdentifierExpression("schema", "", "") + tableCol := exp.NewIdentifierExpression("", "table", "col") + schemaTableCol := exp.NewIdentifierExpression("schema", "table", "col") + + parsedCol := exp.ParseIdentifier("col") + parsedTableCol := exp.ParseIdentifier("table.col") + parsedSchemaTableCol := exp.ParseIdentifier("schema.table.col") + + esgs.assertCases( + NewExpressionSQLGenerator("test", DefaultDialectOptions()), + expressionTestCase{ + val: exp.NewIdentifierExpression("", "", ""), + err: `goqu: a empty identifier was encountered, please specify a "schema", "table" or "column"`, + }, + expressionTestCase{ + val: exp.NewIdentifierExpression("", "", nil), + err: `goqu: a empty identifier was encountered, please specify a "schema", "table" or "column"`, + }, + expressionTestCase{ + val: exp.NewIdentifierExpression("", "", false), + err: `goqu: unexpected col type must be string or LiteralExpression received bool`, + }, + + expressionTestCase{val: col, sql: `"col"`}, + expressionTestCase{val: col, sql: `"col"`, isPrepared: true}, + + expressionTestCase{val: col.Table("table"), sql: `"table"."col"`}, + expressionTestCase{val: col.Table("table"), sql: `"table"."col"`, isPrepared: true}, + + expressionTestCase{val: col.Table("table").Schema("schema"), sql: `"schema"."table"."col"`}, + expressionTestCase{val: col.Table("table").Schema("schema"), sql: `"schema"."table"."col"`, isPrepared: true}, + + expressionTestCase{val: colStar, sql: `*`}, + expressionTestCase{val: colStar, sql: `*`, isPrepared: true}, + + expressionTestCase{val: colStar.Table("table"), sql: `"table".*`}, + expressionTestCase{val: colStar.Table("table"), sql: `"table".*`, isPrepared: true}, + + expressionTestCase{val: colStar.Table("table").Schema("schema"), sql: `"schema"."table".*`}, + expressionTestCase{val: colStar.Table("table").Schema("schema"), sql: `"schema"."table".*`, isPrepared: true}, + + expressionTestCase{val: table, sql: `"table"`}, + expressionTestCase{val: table, sql: `"table"`, isPrepared: true}, + + expressionTestCase{val: table.Col("col"), sql: `"table"."col"`}, + expressionTestCase{val: table.Col("col"), sql: `"table"."col"`, isPrepared: true}, + + expressionTestCase{val: table.Col(nil), sql: `"table"`}, + expressionTestCase{val: table.Col(nil), sql: `"table"`, isPrepared: true}, + + expressionTestCase{val: table.Col("*"), sql: `"table".*`}, + expressionTestCase{val: table.Col("*"), sql: `"table".*`, isPrepared: true}, + + expressionTestCase{val: table.Schema("schema").Col("col"), sql: `"schema"."table"."col"`}, + expressionTestCase{val: table.Schema("schema").Col("col"), sql: `"schema"."table"."col"`, isPrepared: true}, + + expressionTestCase{val: schema, sql: `"schema"`}, + expressionTestCase{val: schema, sql: `"schema"`, isPrepared: true}, + + expressionTestCase{val: schema.Table("table"), sql: `"schema"."table"`}, + expressionTestCase{val: schema.Table("table"), sql: `"schema"."table"`, isPrepared: true}, + + expressionTestCase{val: schema.Table("table").Col("col"), sql: `"schema"."table"."col"`}, + expressionTestCase{val: schema.Table("table").Col("col"), sql: `"schema"."table"."col"`, isPrepared: true}, + + expressionTestCase{val: schema.Table("table").Col(nil), sql: `"schema"."table"`}, + expressionTestCase{val: schema.Table("table").Col(nil), sql: `"schema"."table"`, isPrepared: true}, + + expressionTestCase{val: schema.Table("table").Col("*"), sql: `"schema"."table".*`}, + expressionTestCase{val: schema.Table("table").Col("*"), sql: `"schema"."table".*`, isPrepared: true}, + + expressionTestCase{val: tableCol, sql: `"table"."col"`}, + expressionTestCase{val: tableCol, sql: `"table"."col"`, isPrepared: true}, + + expressionTestCase{val: schemaTableCol, sql: `"schema"."table"."col"`}, + expressionTestCase{val: schemaTableCol, sql: `"schema"."table"."col"`, isPrepared: true}, + + expressionTestCase{val: parsedCol, sql: `"col"`}, + expressionTestCase{val: parsedCol, sql: `"col"`, isPrepared: true}, + + expressionTestCase{val: parsedTableCol, sql: `"table"."col"`}, + expressionTestCase{val: parsedTableCol, sql: `"table"."col"`, isPrepared: true}, + + expressionTestCase{val: parsedSchemaTableCol, sql: `"schema"."table"."col"`}, + expressionTestCase{val: parsedSchemaTableCol, sql: `"schema"."table"."col"`, isPrepared: true}, + ) +} + +func (esgs *expressionSQLGeneratorSuite) TestGenerate_ExpressionMap() { + re := regexp.MustCompile("(a|b)") + esgs.assertCases( + NewExpressionSQLGenerator("test", DefaultDialectOptions()), + expressionTestCase{val: exp.Ex{}}, + expressionTestCase{val: exp.Ex{}, isPrepared: true}, + + expressionTestCase{ + val: exp.Ex{"a": exp.Op{"badOp": true}}, + err: "goqu: unsupported expression type map[badOp:%!s(bool=true)]", + }, + expressionTestCase{ + val: exp.Ex{"a": exp.Op{"badOp": true}}, + isPrepared: true, + err: "goqu: unsupported expression type map[badOp:%!s(bool=true)]", + }, + + expressionTestCase{val: exp.Ex{"a": 1}, sql: `("a" = 1)`}, + expressionTestCase{val: exp.Ex{"a": 1}, sql: `("a" = ?)`, isPrepared: true, args: []interface{}{int64(1)}}, + + expressionTestCase{val: exp.Ex{"a": true}, sql: `("a" IS TRUE)`}, + expressionTestCase{val: exp.Ex{"a": true}, sql: `("a" IS TRUE)`, isPrepared: true}, + + expressionTestCase{val: exp.Ex{"a": false}, sql: `("a" IS FALSE)`}, + expressionTestCase{val: exp.Ex{"a": false}, sql: `("a" IS FALSE)`, isPrepared: true}, + + expressionTestCase{val: exp.Ex{"a": nil}, sql: `("a" IS NULL)`}, + expressionTestCase{val: exp.Ex{"a": nil}, sql: `("a" IS NULL)`, isPrepared: true}, + + expressionTestCase{val: exp.Ex{"a": []string{"a", "b", "c"}}, sql: `("a" IN ('a', 'b', 'c'))`}, + expressionTestCase{ + val: exp.Ex{"a": []string{"a", "b", "c"}}, + sql: `("a" IN (?, ?, ?))`, + isPrepared: true, + args: []interface{}{"a", "b", "c"}, + }, + + expressionTestCase{val: exp.Ex{"a": exp.Op{"neq": 1}}, sql: `("a" != 1)`}, + expressionTestCase{val: exp.Ex{"a": exp.Op{"neq": 1}}, sql: `("a" != ?)`, isPrepared: true, args: []interface{}{ + int64(1), + }}, + + expressionTestCase{val: exp.Ex{"a": exp.Op{"isnot": true}}, sql: `("a" IS NOT TRUE)`}, + expressionTestCase{val: exp.Ex{"a": exp.Op{"isnot": true}}, sql: `("a" IS NOT TRUE)`, isPrepared: true}, + + expressionTestCase{val: exp.Ex{"a": exp.Op{"gt": 1}}, sql: `("a" > 1)`}, + expressionTestCase{val: exp.Ex{"a": exp.Op{"gt": 1}}, sql: `("a" > ?)`, isPrepared: true, args: []interface{}{ + int64(1), + }}, + + expressionTestCase{val: exp.Ex{"a": exp.Op{"gte": 1}}, sql: `("a" >= 1)`}, + expressionTestCase{val: exp.Ex{"a": exp.Op{"gte": 1}}, sql: `("a" >= ?)`, isPrepared: true, args: []interface{}{ + int64(1), + }}, + + expressionTestCase{val: exp.Ex{"a": exp.Op{"lt": 1}}, sql: `("a" < 1)`}, + expressionTestCase{val: exp.Ex{"a": exp.Op{"lt": 1}}, sql: `("a" < ?)`, isPrepared: true, args: []interface{}{ + int64(1), + }}, + + expressionTestCase{val: exp.Ex{"a": exp.Op{"lte": 1}}, sql: `("a" <= 1)`}, + expressionTestCase{val: exp.Ex{"a": exp.Op{"lte": 1}}, sql: `("a" <= ?)`, isPrepared: true, args: []interface{}{ + int64(1), + }}, + + expressionTestCase{val: exp.Ex{"a": exp.Op{"like": "a%"}}, sql: `("a" LIKE 'a%')`}, + expressionTestCase{ + val: exp.Ex{"a": exp.Op{"like": "a%"}}, + sql: `("a" LIKE ?)`, + isPrepared: true, + args: []interface{}{"a%"}, + }, + + expressionTestCase{val: exp.Ex{"a": exp.Op{"like": re}}, sql: `("a" ~ '(a|b)')`}, + expressionTestCase{ + val: exp.Ex{"a": exp.Op{"like": re}}, + sql: `("a" ~ ?)`, + isPrepared: true, + args: []interface{}{"(a|b)"}, + }, + + expressionTestCase{val: exp.Ex{"a": exp.Op{"notLike": "a%"}}, sql: `("a" NOT LIKE 'a%')`}, + expressionTestCase{ + val: exp.Ex{"a": exp.Op{"notLike": "a%"}}, + sql: `("a" NOT LIKE ?)`, + isPrepared: true, + args: []interface{}{"a%"}, + }, + + expressionTestCase{val: exp.Ex{"a": exp.Op{"notLike": re}}, sql: `("a" !~ '(a|b)')`}, + expressionTestCase{ + val: exp.Ex{"a": exp.Op{"notLike": re}}, + sql: `("a" !~ ?)`, + isPrepared: true, + args: []interface{}{"(a|b)"}, + }, + + expressionTestCase{val: exp.Ex{"a": exp.Op{"iLike": "a%"}}, sql: `("a" ILIKE 'a%')`}, + expressionTestCase{ + val: exp.Ex{"a": exp.Op{"iLike": "a%"}}, + sql: `("a" ILIKE ?)`, + isPrepared: true, + args: []interface{}{"a%"}, + }, + + expressionTestCase{val: exp.Ex{"a": exp.Op{"iLike": re}}, sql: `("a" ~* '(a|b)')`}, + expressionTestCase{ + val: exp.Ex{"a": exp.Op{"iLike": re}}, + sql: `("a" ~* ?)`, + isPrepared: true, + args: []interface{}{"(a|b)"}, + }, + + expressionTestCase{val: exp.Ex{"a": exp.Op{"notILike": "a%"}}, sql: `("a" NOT ILIKE 'a%')`}, + expressionTestCase{ + val: exp.Ex{"a": exp.Op{"notILike": "a%"}}, + sql: `("a" NOT ILIKE ?)`, + isPrepared: true, + args: []interface{}{"a%"}, + }, + + expressionTestCase{val: exp.Ex{"a": exp.Op{"notILike": re}}, sql: `("a" !~* '(a|b)')`}, + expressionTestCase{ + val: exp.Ex{"a": exp.Op{"notILike": re}}, + sql: `("a" !~* ?)`, + isPrepared: true, + args: []interface{}{"(a|b)"}, + }, + + expressionTestCase{val: exp.Ex{"a": exp.Op{"in": []string{"a", "b", "c"}}}, sql: `("a" IN ('a', 'b', 'c'))`}, + expressionTestCase{ + val: exp.Ex{"a": exp.Op{"in": []string{"a", "b", "c"}}}, + sql: `("a" IN (?, ?, ?))`, + isPrepared: true, + args: []interface{}{"a", "b", "c"}, + }, + + expressionTestCase{ + val: exp.Ex{"a": exp.Op{"notIn": []string{"a", "b", "c"}}}, + sql: `("a" NOT IN ('a', 'b', 'c'))`, + }, + expressionTestCase{ + val: exp.Ex{"a": exp.Op{"notIn": []string{"a", "b", "c"}}}, + sql: `("a" NOT IN (?, ?, ?))`, + isPrepared: true, + args: []interface{}{"a", "b", "c"}, + }, + + expressionTestCase{ + val: exp.Ex{"a": exp.Op{"between": exp.NewRangeVal("aaa", "zzz")}}, + sql: `("a" BETWEEN 'aaa' AND 'zzz')`, + }, + expressionTestCase{ + val: exp.Ex{"a": exp.Op{"between": exp.NewRangeVal("aaa", "zzz")}}, + sql: `("a" BETWEEN ? AND ?)`, + isPrepared: true, + args: []interface{}{"aaa", "zzz"}, + }, + + expressionTestCase{ + val: exp.Ex{"a": exp.Op{"notBetween": exp.NewRangeVal("aaa", "zzz")}}, + sql: `("a" NOT BETWEEN 'aaa' AND 'zzz')`, + }, + expressionTestCase{ + val: exp.Ex{"a": exp.Op{"notBetween": exp.NewRangeVal("aaa", "zzz")}}, + sql: `("a" NOT BETWEEN ? AND ?)`, + isPrepared: true, + args: []interface{}{"aaa", "zzz"}, + }, + + expressionTestCase{ + val: exp.Ex{"a": exp.Op{"is": nil, "eq": 10}}, + sql: `(("a" = 10) OR ("a" IS NULL))`, + }, + expressionTestCase{ + val: exp.Ex{"a": exp.Op{"is": nil, "eq": 10}}, + sql: `(("a" = ?) OR ("a" IS NULL))`, + isPrepared: true, + args: []interface{}{int64(10)}, + }, + ) +} + +func (esgs *expressionSQLGeneratorSuite) TestGenerate_ExpressionOrMap() { + esgs.assertCases( + NewExpressionSQLGenerator("default", DefaultDialectOptions()), + expressionTestCase{val: exp.ExOr{}}, + expressionTestCase{val: exp.ExOr{}, isPrepared: true}, + + expressionTestCase{ + val: exp.ExOr{"a": exp.Op{"badOp": true}}, + err: "goqu: unsupported expression type map[badOp:%!s(bool=true)]", + }, + expressionTestCase{ + val: exp.ExOr{"a": exp.Op{"badOp": true}}, + isPrepared: true, + err: "goqu: unsupported expression type map[badOp:%!s(bool=true)]", + }, + + expressionTestCase{val: exp.ExOr{"a": 1, "b": true}, sql: `(("a" = 1) OR ("b" IS TRUE))`}, + expressionTestCase{ + val: exp.ExOr{"a": 1, "b": true}, + sql: `(("a" = ?) OR ("b" IS TRUE))`, + isPrepared: true, + args: []interface{}{int64(1)}, + }, + + expressionTestCase{ + val: exp.ExOr{"a": 1, "b": []string{"a", "b", "c"}}, + sql: `(("a" = 1) OR ("b" IN ('a', 'b', 'c')))`, + }, + expressionTestCase{ + val: exp.ExOr{"a": 1, "b": []string{"a", "b", "c"}}, + sql: `(("a" = ?) OR ("b" IN (?, ?, ?)))`, + isPrepared: true, + args: []interface{}{int64(1), "a", "b", "c"}, + }, + ) +} +func TestExpressionSQLGenerator(t *testing.T) { + suite.Run(t, new(expressionSQLGeneratorSuite)) +} diff --git a/sqlgen/insert_sql_generator.go b/sqlgen/insert_sql_generator.go new file mode 100644 index 00000000..9aec0f44 --- /dev/null +++ b/sqlgen/insert_sql_generator.go @@ -0,0 +1,206 @@ +package sqlgen + +import ( + "strings" + + "github.com/doug-martin/goqu/v8/exp" + "github.com/doug-martin/goqu/v8/internal/errors" + "github.com/doug-martin/goqu/v8/internal/sb" +) + +type ( + // An adapter interface to be used by a Dataset to generate SQL for a specific dialect. + // See DefaultAdapter for a concrete implementation and examples. + InsertSQLGenerator interface { + Dialect() string + Generate(b sb.SQLBuilder, clauses exp.InsertClauses) + } + // The default adapter. This class should be used when building a new adapter. When creating a new adapter you can + // either override methods, or more typically update default values. + // See (github.com/doug-martin/goqu/adapters/postgres) + insertSQLGenerator struct { + *commonSQLGenerator + } +) + +var ( + errConflictUpdateValuesRequired = errors.New("values are required for on conflict update expression") + errNoSourceForInsert = errors.New("no source found when generating insert sql") +) + +func errMisMatchedRowLength(expectedL, actualL int) error { + return errors.New("rows with different value length expected %d got %d", expectedL, actualL) +} + +func errUpsertWithWhereNotSupported(dialect string) error { + return errors.New("dialect does not support upsert with where clause [dialect=%s]", dialect) +} + +func NewInsertSQLGenerator(dialect string, do *SQLDialectOptions) InsertSQLGenerator { + return &insertSQLGenerator{newCommonSQLGenerator(dialect, do)} +} + +func (isg *insertSQLGenerator) Dialect() string { + return isg.dialect +} + +func (isg *insertSQLGenerator) Generate( + b sb.SQLBuilder, + clauses exp.InsertClauses, +) { + if !clauses.HasInto() { + b.SetError(errNoSourceForInsert) + return + } + for _, f := range isg.dialectOptions.InsertSQLOrder { + if b.Error() != nil { + return + } + switch f { + case CommonTableSQLFragment: + isg.esg.Generate(b, clauses.CommonTables()) + case InsertBeingSQLFragment: + isg.InsertBeginSQL(b, clauses.OnConflict()) + case IntoSQLFragment: + b.WriteRunes(isg.dialectOptions.SpaceRune) + isg.esg.Generate(b, clauses.Into()) + case InsertSQLFragment: + isg.InsertSQL(b, clauses) + case ReturningSQLFragment: + isg.ReturningSQL(b, clauses.Returning()) + default: + b.SetError(errNotSupportedFragment("INSERT", f)) + } + } + +} + +// Adds the correct fragment to being an INSERT statement +func (isg *insertSQLGenerator) InsertBeginSQL(b sb.SQLBuilder, o exp.ConflictExpression) { + if isg.dialectOptions.SupportsInsertIgnoreSyntax && o != nil { + b.Write(isg.dialectOptions.InsertIgnoreClause) + } else { + b.Write(isg.dialectOptions.InsertClause) + } +} + +// Adds the columns list to an insert statement +func (isg *insertSQLGenerator) InsertSQL(b sb.SQLBuilder, ic exp.InsertClauses) { + switch { + case ic.HasRows(): + ie, err := exp.NewInsertExpression(ic.Rows()...) + if err != nil { + b.SetError(err) + return + } + isg.InsertExpressionSQL(b, ie) + case ic.HasCols() && ic.HasVals(): + isg.insertColumnsSQL(b, ic.Cols()) + isg.insertValuesSQL(b, ic.Vals()) + case ic.HasCols() && ic.HasFrom(): + isg.insertColumnsSQL(b, ic.Cols()) + isg.insertFromSQL(b, ic.From()) + case ic.HasFrom(): + isg.insertFromSQL(b, ic.From()) + default: + isg.defaultValuesSQL(b) + + } + + isg.onConflictSQL(b, ic.OnConflict()) +} + +func (isg *insertSQLGenerator) InsertExpressionSQL(b sb.SQLBuilder, ie exp.InsertExpression) { + switch { + case ie.IsInsertFrom(): + isg.insertFromSQL(b, ie.From()) + case ie.IsEmpty(): + isg.defaultValuesSQL(b) + default: + isg.insertColumnsSQL(b, ie.Cols()) + isg.insertValuesSQL(b, ie.Vals()) + } +} + +// Adds the DefaultValuesFragment to an SQL statement +func (isg *insertSQLGenerator) defaultValuesSQL(b sb.SQLBuilder) { + b.Write(isg.dialectOptions.DefaultValuesFragment) +} + +func (isg *insertSQLGenerator) insertFromSQL(b sb.SQLBuilder, ae exp.AppendableExpression) { + b.WriteRunes(isg.dialectOptions.SpaceRune) + ae.AppendSQL(b) +} + +// Adds the columns list to an insert statement +func (isg *insertSQLGenerator) insertColumnsSQL(b sb.SQLBuilder, cols exp.ColumnListExpression) { + b.WriteRunes(isg.dialectOptions.SpaceRune, isg.dialectOptions.LeftParenRune) + isg.esg.Generate(b, cols) + b.WriteRunes(isg.dialectOptions.RightParenRune) +} + +// Adds the values clause to an SQL statement +func (isg *insertSQLGenerator) insertValuesSQL(b sb.SQLBuilder, values [][]interface{}) { + b.Write(isg.dialectOptions.ValuesFragment) + rowLen := len(values[0]) + valueLen := len(values) + for i, row := range values { + if len(row) != rowLen { + b.SetError(errMisMatchedRowLength(rowLen, len(row))) + return + } + isg.esg.Generate(b, row) + if i < valueLen-1 { + b.WriteRunes(isg.dialectOptions.CommaRune, isg.dialectOptions.SpaceRune) + } + } +} + +// Adds the DefaultValuesFragment to an SQL statement +func (isg *insertSQLGenerator) onConflictSQL(b sb.SQLBuilder, o exp.ConflictExpression) { + if o == nil { + return + } + b.Write(isg.dialectOptions.ConflictFragment) + switch t := o.(type) { + case exp.ConflictUpdateExpression: + target := t.TargetColumn() + if isg.dialectOptions.SupportsConflictTarget && target != "" { + wrapParens := !strings.HasPrefix(strings.ToLower(target), "on constraint") + + b.WriteRunes(isg.dialectOptions.SpaceRune) + if wrapParens { + b.WriteRunes(isg.dialectOptions.LeftParenRune). + WriteStrings(target). + WriteRunes(isg.dialectOptions.RightParenRune) + } else { + b.Write([]byte(target)) + } + } + isg.onConflictDoUpdateSQL(b, t) + default: + b.Write(isg.dialectOptions.ConflictDoNothingFragment) + } +} + +func (isg *insertSQLGenerator) onConflictDoUpdateSQL(b sb.SQLBuilder, o exp.ConflictUpdateExpression) { + b.Write(isg.dialectOptions.ConflictDoUpdateFragment) + update := o.Update() + if update == nil { + b.SetError(errConflictUpdateValuesRequired) + return + } + ue, err := exp.NewUpdateExpressions(update) + if err != nil { + b.SetError(err) + return + } + isg.UpdateExpressionSQL(b, ue...) + if b.Error() == nil && o.WhereClause() != nil { + if !isg.dialectOptions.SupportsConflictUpdateWhere { + b.SetError(errUpsertWithWhereNotSupported(isg.dialect)) + return + } + isg.WhereSQL(b, o.WhereClause()) + } +} diff --git a/sqlgen/insert_sql_generator_test.go b/sqlgen/insert_sql_generator_test.go new file mode 100644 index 00000000..7ca423ec --- /dev/null +++ b/sqlgen/insert_sql_generator_test.go @@ -0,0 +1,455 @@ +package sqlgen + +import ( + "testing" + + "github.com/doug-martin/goqu/v8/exp" + "github.com/doug-martin/goqu/v8/internal/sb" + "github.com/stretchr/testify/suite" +) + +type ( + insertTestCase struct { + clause exp.InsertClauses + sql string + isPrepared bool + args []interface{} + err string + } + insertSQLGeneratorSuite struct { + baseSQLGeneratorSuite + } +) + +func (igs *insertSQLGeneratorSuite) assertCases(isg InsertSQLGenerator, testCases ...insertTestCase) { + for _, tc := range testCases { + b := sb.NewSQLBuilder(tc.isPrepared) + isg.Generate(b, tc.clause) + switch { + case len(tc.err) > 0: + igs.assertErrorSQL(b, tc.err) + case tc.isPrepared: + igs.assertPreparedSQL(b, tc.sql, tc.args) + default: + igs.assertNotPreparedSQL(b, tc.sql) + } + } +} + +func (igs *insertSQLGeneratorSuite) TestDialect() { + opts := DefaultDialectOptions() + d := NewInsertSQLGenerator("test", opts) + igs.Equal("test", d.Dialect()) + + opts2 := DefaultDialectOptions() + d2 := NewInsertSQLGenerator("test2", opts2) + igs.Equal("test2", d2.Dialect()) +} + +func (igs *insertSQLGeneratorSuite) TestGenerate_UnsupportedFragment() { + opts := DefaultDialectOptions() + opts.InsertSQLOrder = []SQLFragmentType{UpdateBeginSQLFragment} + d := NewInsertSQLGenerator("test", opts) + + b := sb.NewSQLBuilder(true) + ic := exp.NewInsertClauses(). + SetInto(exp.NewIdentifierExpression("", "test", "")) + d.Generate(b, ic) + igs.assertErrorSQL(b, `goqu: unsupported INSERT SQL fragment UpdateBeginSQLFragment`) +} + +func (igs *insertSQLGeneratorSuite) TestGenerate_empty() { + ic := exp.NewInsertClauses(). + SetInto(exp.NewIdentifierExpression("", "test", "")) + + igs.assertCases( + NewInsertSQLGenerator("test", DefaultDialectOptions()), + insertTestCase{clause: ic, sql: `INSERT INTO "test" DEFAULT VALUES`}, + insertTestCase{clause: ic, sql: `INSERT INTO "test" DEFAULT VALUES`, isPrepared: true}, + ) + + opts2 := DefaultDialectOptions() + opts2.DefaultValuesFragment = []byte(" default values") + + igs.assertCases( + NewInsertSQLGenerator("test", opts2), + insertTestCase{clause: ic, sql: `INSERT INTO "test" default values`}, + insertTestCase{clause: ic, sql: `INSERT INTO "test" default values`, isPrepared: true}, + ) +} + +func (igs *insertSQLGeneratorSuite) TestGenerate_nilValues() { + ic := exp.NewInsertClauses(). + SetInto(exp.NewIdentifierExpression("", "test", "")). + SetCols(exp.NewColumnListExpression("a")). + SetVals([][]interface{}{ + {nil}, + }) + + igs.assertCases( + NewInsertSQLGenerator("test", DefaultDialectOptions()), + insertTestCase{clause: ic, sql: `INSERT INTO "test" ("a") VALUES (NULL)`}, + insertTestCase{clause: ic, sql: `INSERT INTO "test" ("a") VALUES (NULL)`, isPrepared: true}, + ) +} + +func (igs *insertSQLGeneratorSuite) TestGenerate_colsAndVals() { + opts := DefaultDialectOptions() + opts.LeftParenRune = '{' + opts.RightParenRune = '}' + opts.ValuesFragment = []byte(" values ") + opts.LeftParenRune = '{' + opts.RightParenRune = '}' + opts.CommaRune = ';' + opts.PlaceHolderRune = '#' + + ic := exp.NewInsertClauses(). + SetInto(exp.NewIdentifierExpression("", "test", "")). + SetCols(exp.NewColumnListExpression("a", "b")). + SetVals([][]interface{}{ + {"a1", "b1"}, + {"a2", "b2"}, + {"a3", "b3"}, + }) + + bic := ic.SetCols(exp.NewColumnListExpression("a", "b")). + SetVals([][]interface{}{ + {"a1"}, + {"a2", "b2"}, + {"a3", "b3"}, + }) + + igs.assertCases( + NewInsertSQLGenerator("test", opts), + insertTestCase{clause: ic, sql: `INSERT INTO "test" {"a"; "b"} values {'a1'; 'b1'}; {'a2'; 'b2'}; {'a3'; 'b3'}`}, + insertTestCase{clause: ic, sql: `INSERT INTO "test" {"a"; "b"} values {#; #}; {#; #}; {#; #}`, isPrepared: true, args: []interface{}{ + "a1", "b1", "a2", "b2", "a3", "b3", + }}, + + insertTestCase{clause: bic, err: `goqu: rows with different value length expected 1 got 2`}, + insertTestCase{clause: bic, err: `goqu: rows with different value length expected 1 got 2`, isPrepared: true}, + ) +} + +func (igs *insertSQLGeneratorSuite) TestGenerate_withNoInto() { + opts := DefaultDialectOptions() + opts.LeftParenRune = '{' + opts.RightParenRune = '}' + opts.ValuesFragment = []byte(" values ") + opts.LeftParenRune = '{' + opts.RightParenRune = '}' + opts.CommaRune = ';' + opts.PlaceHolderRune = '#' + + ic := exp.NewInsertClauses(). + SetCols(exp.NewColumnListExpression("a", "b")). + SetVals([][]interface{}{ + {"a1", "b1"}, + {"a2", "b2"}, + {"a3", "b3"}, + }) + expectedErr := "goqu: no source found when generating insert sql" + igs.assertCases( + NewInsertSQLGenerator("test", opts), + insertTestCase{clause: ic, err: expectedErr}, + insertTestCase{clause: ic, err: expectedErr, isPrepared: true}, + ) +} +func (igs *insertSQLGeneratorSuite) TestGenerate_withRows() { + opts := DefaultDialectOptions() + opts.LeftParenRune = '{' + opts.RightParenRune = '}' + opts.ValuesFragment = []byte(" values ") + opts.LeftParenRune = '{' + opts.RightParenRune = '}' + opts.CommaRune = ';' + opts.PlaceHolderRune = '#' + + ic := exp.NewInsertClauses(). + SetInto(exp.NewIdentifierExpression("", "test", "")). + SetRows([]interface{}{ + exp.Record{"a": "a1", "b": "b1"}, + exp.Record{"a": "a2", "b": "b2"}, + exp.Record{"a": "a3", "b": "b3"}, + }) + + bic := ic.SetRows([]interface{}{ + exp.Record{"a": "a1"}, + exp.Record{"a": "a2", "b": "b2"}, + exp.Record{"a": "a3", "b": "b3"}, + }) + + igs.assertCases( + NewInsertSQLGenerator("test", opts), + insertTestCase{clause: ic, sql: `INSERT INTO "test" {"a"; "b"} values {'a1'; 'b1'}; {'a2'; 'b2'}; {'a3'; 'b3'}`}, + insertTestCase{clause: ic, sql: `INSERT INTO "test" {"a"; "b"} values {#; #}; {#; #}; {#; #}`, isPrepared: true, args: []interface{}{ + "a1", "b1", "a2", "b2", "a3", "b3", + }}, + + insertTestCase{clause: bic, err: `goqu: rows with different value length expected 1 got 2`}, + insertTestCase{clause: bic, err: `goqu: rows with different value length expected 1 got 2`, isPrepared: true}, + ) +} + +func (igs *insertSQLGeneratorSuite) TestGenerate_withEmptyRows() { + ic := exp.NewInsertClauses(). + SetInto(exp.NewIdentifierExpression("", "test", "")). + SetRows([]interface{}{exp.Record{}}) + + igs.assertCases( + NewInsertSQLGenerator("test", DefaultDialectOptions()), + insertTestCase{clause: ic, sql: `INSERT INTO "test" DEFAULT VALUES`}, + insertTestCase{clause: ic, sql: `INSERT INTO "test" DEFAULT VALUES`, isPrepared: true}, + ) + + opts2 := DefaultDialectOptions() + opts2.DefaultValuesFragment = []byte(" default values") + + igs.assertCases( + NewInsertSQLGenerator("test", opts2), + insertTestCase{clause: ic, sql: `INSERT INTO "test" default values`}, + insertTestCase{clause: ic, sql: `INSERT INTO "test" default values`, isPrepared: true}, + ) +} + +func (igs *insertSQLGeneratorSuite) TestGenerate_withRowsAppendableExpression() { + ic := exp.NewInsertClauses(). + SetInto(exp.NewIdentifierExpression("", "test", "")). + SetRows([]interface{}{newTestAppendableExpression(`select * from "other"`, emptyArgs, nil, nil)}) + + igs.assertCases( + NewInsertSQLGenerator("test", DefaultDialectOptions()), + insertTestCase{clause: ic, sql: `INSERT INTO "test" select * from "other"`}, + insertTestCase{clause: ic, sql: `INSERT INTO "test" select * from "other"`, isPrepared: true}, + ) +} + +func (igs *insertSQLGeneratorSuite) TestGenerate_withFrom() { + ic := exp.NewInsertClauses(). + SetInto(exp.NewIdentifierExpression("", "test", "")). + SetFrom(newTestAppendableExpression(`select c, d from test where a = 'b'`, nil, nil, nil)) + + icCols := ic.SetCols(exp.NewColumnListExpression("a", "b")) + igs.assertCases( + NewInsertSQLGenerator("test", DefaultDialectOptions()), + insertTestCase{clause: ic, sql: `INSERT INTO "test" select c, d from test where a = 'b'`}, + insertTestCase{clause: ic, sql: `INSERT INTO "test" select c, d from test where a = 'b'`, isPrepared: true}, + + insertTestCase{clause: icCols, sql: `INSERT INTO "test" ("a", "b") select c, d from test where a = 'b'`}, + insertTestCase{clause: icCols, sql: `INSERT INTO "test" ("a", "b") select c, d from test where a = 'b'`, isPrepared: true}, + ) +} + +func (igs *insertSQLGeneratorSuite) TestGenerate_onConflict() { + opts := DefaultDialectOptions() + // make sure the fragments are used + opts.ConflictFragment = []byte(" on conflict") + opts.ConflictDoNothingFragment = []byte(" do nothing") + opts.ConflictDoUpdateFragment = []byte(" do update set ") + + ic := exp.NewInsertClauses(). + SetInto(exp.NewIdentifierExpression("", "test", "")). + SetCols(exp.NewColumnListExpression("a")). + SetVals([][]interface{}{ + {"a1"}, + }) + icDn := ic.SetOnConflict(exp.NewDoNothingConflictExpression()) + icDu := ic.SetOnConflict(exp.NewDoUpdateConflictExpression("test", exp.Record{"a": "b"})) + icDoc := ic.SetOnConflict(exp.NewDoUpdateConflictExpression("on constraint test", exp.Record{"a": "b"})) + icDuw := ic.SetOnConflict( + exp.NewDoUpdateConflictExpression("test", exp.Record{"a": "b"}).Where(exp.Ex{"foo": true}), + ) + + icDuNil := ic.SetOnConflict(exp.NewDoUpdateConflictExpression("test", nil)) + icDuBad := ic.SetOnConflict(exp.NewDoUpdateConflictExpression("test", true)) + + igs.assertCases( + NewInsertSQLGenerator("test", opts), + insertTestCase{clause: icDn, sql: `INSERT INTO "test" ("a") VALUES ('a1') on conflict do nothing`}, + insertTestCase{ + clause: icDn, + sql: `INSERT INTO "test" ("a") VALUES (?) on conflict do nothing`, + isPrepared: true, + args: []interface{}{"a1"}, + }, + + insertTestCase{clause: icDu, sql: `INSERT INTO "test" ("a") VALUES ('a1') on conflict (test) do update set "a"='b'`}, + insertTestCase{ + clause: icDu, + sql: `INSERT INTO "test" ("a") VALUES (?) on conflict (test) do update set "a"=?`, + isPrepared: true, + args: []interface{}{"a1", "b"}, + }, + + insertTestCase{clause: icDoc, sql: `INSERT INTO "test" ("a") VALUES ('a1') on conflict on constraint test do update set "a"='b'`}, + insertTestCase{ + clause: icDoc, + sql: `INSERT INTO "test" ("a") VALUES (?) on conflict on constraint test do update set "a"=?`, + isPrepared: true, + args: []interface{}{"a1", "b"}, + }, + + insertTestCase{ + clause: icDuw, + sql: `INSERT INTO "test" ("a") VALUES ('a1') on conflict (test) do update set "a"='b' WHERE ("foo" IS TRUE)`, + }, + insertTestCase{ + clause: icDuw, + sql: `INSERT INTO "test" ("a") VALUES (?) on conflict (test) do update set "a"=? WHERE ("foo" IS TRUE)`, + isPrepared: true, + args: []interface{}{"a1", "b"}, + }, + + insertTestCase{clause: icDuNil, err: errConflictUpdateValuesRequired.Error()}, + insertTestCase{clause: icDuNil, err: errConflictUpdateValuesRequired.Error(), isPrepared: true}, + + insertTestCase{clause: icDuBad, err: "goqu: unsupported update interface type bool"}, + insertTestCase{clause: icDuBad, err: "goqu: unsupported update interface type bool", isPrepared: true}, + ) + opts.SupportsInsertIgnoreSyntax = true + opts.InsertIgnoreClause = []byte("insert ignore into") + igs.assertCases( + NewInsertSQLGenerator("test", opts), + insertTestCase{clause: icDn, sql: `insert ignore into "test" ("a") VALUES ('a1') on conflict do nothing`}, + insertTestCase{ + clause: icDn, + sql: `insert ignore into "test" ("a") VALUES (?) on conflict do nothing`, + isPrepared: true, + args: []interface{}{"a1"}, + }, + + insertTestCase{clause: icDu, + sql: `insert ignore into "test" ("a") VALUES ('a1') on conflict (test) do update set "a"='b'`, + }, + insertTestCase{ + clause: icDu, + sql: `insert ignore into "test" ("a") VALUES (?) on conflict (test) do update set "a"=?`, + isPrepared: true, + args: []interface{}{"a1", "b"}, + }, + + insertTestCase{ + clause: icDoc, + sql: `insert ignore into "test" ("a") VALUES ('a1') on conflict on constraint test do update set "a"='b'`, + }, + insertTestCase{ + clause: icDoc, + sql: `insert ignore into "test" ("a") VALUES (?) on conflict on constraint test do update set "a"=?`, + isPrepared: true, + args: []interface{}{"a1", "b"}, + }, + + insertTestCase{ + clause: icDuw, + sql: `insert ignore into "test" ("a") VALUES ('a1') on conflict (test) do update set "a"='b' WHERE ("foo" IS TRUE)`, + }, + insertTestCase{ + clause: icDuw, + sql: `insert ignore into "test" ("a") VALUES (?) on conflict (test) do update set "a"=? WHERE ("foo" IS TRUE)`, + isPrepared: true, + args: []interface{}{"a1", "b"}, + }, + + insertTestCase{clause: icDuNil, err: errConflictUpdateValuesRequired.Error()}, + insertTestCase{clause: icDuNil, err: errConflictUpdateValuesRequired.Error(), isPrepared: true}, + + insertTestCase{clause: icDuBad, err: "goqu: unsupported update interface type bool"}, + insertTestCase{clause: icDuBad, err: "goqu: unsupported update interface type bool", isPrepared: true}, + ) + + opts.SupportsConflictUpdateWhere = false + expectedErr := "goqu: dialect does not support upsert with where clause [dialect=test]" + igs.assertCases( + NewInsertSQLGenerator("test", opts), + insertTestCase{clause: icDuw, err: expectedErr}, + insertTestCase{clause: icDuw, err: expectedErr, isPrepared: true}, + ) + +} + +func (igs *insertSQLGeneratorSuite) TestGenerate_withCommonTables() { + opts := DefaultDialectOptions() + opts.WithFragment = []byte("with ") + opts.RecursiveFragment = []byte("recursive ") + + tse := newTestAppendableExpression("select * from foo", emptyArgs, nil, nil) + + ic := exp.NewInsertClauses().SetInto(exp.NewIdentifierExpression("", "test_cte", "")) + icCte1 := ic.CommonTablesAppend(exp.NewCommonTableExpression(false, "test_cte", tse)) + icCte2 := ic.CommonTablesAppend(exp.NewCommonTableExpression(true, "test_cte", tse)) + + igs.assertCases( + NewInsertSQLGenerator("test", opts), + insertTestCase{ + clause: icCte1, + sql: `with test_cte AS (select * from foo) INSERT INTO "test_cte" DEFAULT VALUES`, + }, + insertTestCase{ + clause: icCte1, + sql: `with test_cte AS (select * from foo) INSERT INTO "test_cte" DEFAULT VALUES`, + isPrepared: true, + }, + + insertTestCase{ + clause: icCte2, + sql: `with recursive test_cte AS (select * from foo) INSERT INTO "test_cte" DEFAULT VALUES`, + }, + insertTestCase{ + clause: icCte2, + sql: `with recursive test_cte AS (select * from foo) INSERT INTO "test_cte" DEFAULT VALUES`, + isPrepared: true}, + ) + + opts.SupportsWithCTE = false + expectedErr := "goqu: dialect does not support CTE WITH clause [dialect=test]" + igs.assertCases( + NewInsertSQLGenerator("test", opts), + insertTestCase{clause: icCte1, err: expectedErr}, + insertTestCase{clause: icCte1, err: expectedErr, isPrepared: true}, + + insertTestCase{clause: icCte2, err: expectedErr}, + insertTestCase{clause: icCte2, err: expectedErr, isPrepared: true}, + ) + + opts.SupportsWithCTE = true + opts.SupportsWithCTERecursive = false + expectedErr = "goqu: dialect does not support CTE WITH RECURSIVE clause [dialect=test]" + igs.assertCases( + NewInsertSQLGenerator("test", opts), + insertTestCase{ + clause: icCte1, + sql: `with test_cte AS (select * from foo) INSERT INTO "test_cte" DEFAULT VALUES`, + }, + insertTestCase{ + clause: icCte1, + sql: `with test_cte AS (select * from foo) INSERT INTO "test_cte" DEFAULT VALUES`, + isPrepared: true, + }, + + insertTestCase{clause: icCte2, err: expectedErr}, + insertTestCase{clause: icCte2, err: expectedErr, isPrepared: true}, + ) + +} + +func (igs *insertSQLGeneratorSuite) TestGenerate_withReturning() { + ic := exp.NewInsertClauses(). + SetInto(exp.NewIdentifierExpression("", "test", "")). + SetCols(exp.NewColumnListExpression("a", "b")). + SetVals([][]interface{}{ + {"a1", "b1"}, + }). + SetReturning(exp.NewColumnListExpression("a", "b")) + + igs.assertCases( + NewInsertSQLGenerator("test", DefaultDialectOptions()), + insertTestCase{clause: ic, sql: `INSERT INTO "test" ("a", "b") VALUES ('a1', 'b1') RETURNING "a", "b"`}, + insertTestCase{clause: ic, sql: `INSERT INTO "test" ("a", "b") VALUES (?, ?) RETURNING "a", "b"`, isPrepared: true, args: []interface{}{ + "a1", "b1", + }}, + ) +} + +func TestInsertSQLGenerator(t *testing.T) { + suite.Run(t, new(insertSQLGeneratorSuite)) +} diff --git a/sqlgen/mocks/DeleteSQLGenerator.go b/sqlgen/mocks/DeleteSQLGenerator.go new file mode 100644 index 00000000..a0694f5a --- /dev/null +++ b/sqlgen/mocks/DeleteSQLGenerator.go @@ -0,0 +1,31 @@ +// Code generated by mockery v1.0.0. DO NOT EDIT. + +package mocks + +import exp "github.com/doug-martin/goqu/v8/exp" +import mock "github.com/stretchr/testify/mock" +import sb "github.com/doug-martin/goqu/v8/internal/sb" + +// DeleteSQLGenerator is an autogenerated mock type for the DeleteSQLGenerator type +type DeleteSQLGenerator struct { + mock.Mock +} + +// Dialect provides a mock function with given fields: +func (_m *DeleteSQLGenerator) Dialect() string { + ret := _m.Called() + + var r0 string + if rf, ok := ret.Get(0).(func() string); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(string) + } + + return r0 +} + +// Generate provides a mock function with given fields: b, clauses +func (_m *DeleteSQLGenerator) Generate(b sb.SQLBuilder, clauses exp.DeleteClauses) { + _m.Called(b, clauses) +} diff --git a/sqlgen/mocks/InsertSQLGenerator.go b/sqlgen/mocks/InsertSQLGenerator.go new file mode 100644 index 00000000..04e7a9bc --- /dev/null +++ b/sqlgen/mocks/InsertSQLGenerator.go @@ -0,0 +1,31 @@ +// Code generated by mockery v1.0.0. DO NOT EDIT. + +package mocks + +import exp "github.com/doug-martin/goqu/v8/exp" +import mock "github.com/stretchr/testify/mock" +import sb "github.com/doug-martin/goqu/v8/internal/sb" + +// InsertSQLGenerator is an autogenerated mock type for the InsertSQLGenerator type +type InsertSQLGenerator struct { + mock.Mock +} + +// Dialect provides a mock function with given fields: +func (_m *InsertSQLGenerator) Dialect() string { + ret := _m.Called() + + var r0 string + if rf, ok := ret.Get(0).(func() string); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(string) + } + + return r0 +} + +// Generate provides a mock function with given fields: b, clauses +func (_m *InsertSQLGenerator) Generate(b sb.SQLBuilder, clauses exp.InsertClauses) { + _m.Called(b, clauses) +} diff --git a/sqlgen/mocks/SelectSQLGenerator.go b/sqlgen/mocks/SelectSQLGenerator.go new file mode 100644 index 00000000..bdf1a155 --- /dev/null +++ b/sqlgen/mocks/SelectSQLGenerator.go @@ -0,0 +1,31 @@ +// Code generated by mockery v1.0.0. DO NOT EDIT. + +package mocks + +import exp "github.com/doug-martin/goqu/v8/exp" +import mock "github.com/stretchr/testify/mock" +import sb "github.com/doug-martin/goqu/v8/internal/sb" + +// SelectSQLGenerator is an autogenerated mock type for the SelectSQLGenerator type +type SelectSQLGenerator struct { + mock.Mock +} + +// Dialect provides a mock function with given fields: +func (_m *SelectSQLGenerator) Dialect() string { + ret := _m.Called() + + var r0 string + if rf, ok := ret.Get(0).(func() string); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(string) + } + + return r0 +} + +// Generate provides a mock function with given fields: b, clauses +func (_m *SelectSQLGenerator) Generate(b sb.SQLBuilder, clauses exp.SelectClauses) { + _m.Called(b, clauses) +} diff --git a/sqlgen/mocks/TruncateSQLGenerator.go b/sqlgen/mocks/TruncateSQLGenerator.go new file mode 100644 index 00000000..8799f8ea --- /dev/null +++ b/sqlgen/mocks/TruncateSQLGenerator.go @@ -0,0 +1,31 @@ +// Code generated by mockery v1.0.0. DO NOT EDIT. + +package mocks + +import exp "github.com/doug-martin/goqu/v8/exp" +import mock "github.com/stretchr/testify/mock" +import sb "github.com/doug-martin/goqu/v8/internal/sb" + +// TruncateSQLGenerator is an autogenerated mock type for the TruncateSQLGenerator type +type TruncateSQLGenerator struct { + mock.Mock +} + +// Dialect provides a mock function with given fields: +func (_m *TruncateSQLGenerator) Dialect() string { + ret := _m.Called() + + var r0 string + if rf, ok := ret.Get(0).(func() string); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(string) + } + + return r0 +} + +// Generate provides a mock function with given fields: b, clauses +func (_m *TruncateSQLGenerator) Generate(b sb.SQLBuilder, clauses exp.TruncateClauses) { + _m.Called(b, clauses) +} diff --git a/sqlgen/mocks/UpdateSQLGenerator.go b/sqlgen/mocks/UpdateSQLGenerator.go new file mode 100644 index 00000000..88c3a713 --- /dev/null +++ b/sqlgen/mocks/UpdateSQLGenerator.go @@ -0,0 +1,31 @@ +// Code generated by mockery v1.0.0. DO NOT EDIT. + +package mocks + +import exp "github.com/doug-martin/goqu/v8/exp" +import mock "github.com/stretchr/testify/mock" +import sb "github.com/doug-martin/goqu/v8/internal/sb" + +// UpdateSQLGenerator is an autogenerated mock type for the UpdateSQLGenerator type +type UpdateSQLGenerator struct { + mock.Mock +} + +// Dialect provides a mock function with given fields: +func (_m *UpdateSQLGenerator) Dialect() string { + ret := _m.Called() + + var r0 string + if rf, ok := ret.Get(0).(func() string); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(string) + } + + return r0 +} + +// Generate provides a mock function with given fields: b, clauses +func (_m *UpdateSQLGenerator) Generate(b sb.SQLBuilder, clauses exp.UpdateClauses) { + _m.Called(b, clauses) +} diff --git a/sqlgen/select_sql_generator.go b/sqlgen/select_sql_generator.go new file mode 100644 index 00000000..b9d358de --- /dev/null +++ b/sqlgen/select_sql_generator.go @@ -0,0 +1,206 @@ +package sqlgen + +import ( + "github.com/doug-martin/goqu/v8/exp" + "github.com/doug-martin/goqu/v8/internal/errors" + "github.com/doug-martin/goqu/v8/internal/sb" +) + +type ( + // An adapter interface to be used by a Dataset to generate SQL for a specific dialect. + // See DefaultAdapter for a concrete implementation and examples. + SelectSQLGenerator interface { + Dialect() string + Generate(b sb.SQLBuilder, clauses exp.SelectClauses) + } + // The default adapter. This class should be used when building a new adapter. When creating a new adapter you can + // either override methods, or more typically update default values. + // See (github.com/doug-martin/goqu/adapters/postgres) + selectSQLGenerator struct { + *commonSQLGenerator + } +) + +func errNotSupportedJoinType(j exp.JoinExpression) error { + return errors.New("dialect does not support %v", j.JoinType()) +} + +func errJoinConditionRequired(j exp.JoinExpression) error { + return errors.New("join condition required for conditioned join %v", j.JoinType()) +} +func errDistinctOnNotSupported(dialect string) error { + return errors.New("dialect does not support DISTINCT ON clause [dialect=%s]", dialect) +} + +func NewSelectSQLGenerator(dialect string, do *SQLDialectOptions) SelectSQLGenerator { + return &selectSQLGenerator{newCommonSQLGenerator(dialect, do)} +} + +func (ssg *selectSQLGenerator) Dialect() string { + return ssg.dialect +} + +func (ssg *selectSQLGenerator) Generate(b sb.SQLBuilder, clauses exp.SelectClauses) { + for _, f := range ssg.dialectOptions.SelectSQLOrder { + if b.Error() != nil { + return + } + switch f { + case CommonTableSQLFragment: + ssg.esg.Generate(b, clauses.CommonTables()) + case SelectSQLFragment: + ssg.SelectSQL(b, clauses) + case FromSQLFragment: + ssg.FromSQL(b, clauses.From()) + case JoinSQLFragment: + ssg.JoinSQL(b, clauses.Joins()) + case WhereSQLFragment: + ssg.WhereSQL(b, clauses.Where()) + case GroupBySQLFragment: + ssg.GroupBySQL(b, clauses.GroupBy()) + case HavingSQLFragment: + ssg.HavingSQL(b, clauses.Having()) + case CompoundsSQLFragment: + ssg.CompoundsSQL(b, clauses.Compounds()) + case OrderSQLFragment: + ssg.OrderSQL(b, clauses.Order()) + case LimitSQLFragment: + ssg.LimitSQL(b, clauses.Limit()) + case OffsetSQLFragment: + ssg.OffsetSQL(b, clauses.Offset()) + case ForSQLFragment: + ssg.ForSQL(b, clauses.Lock()) + default: + b.SetError(errNotSupportedFragment("SELECT", f)) + } + } +} + +// Adds the SELECT clause and columns to a sql statement +func (ssg *selectSQLGenerator) SelectSQL(b sb.SQLBuilder, clauses exp.SelectClauses) { + b.Write(ssg.dialectOptions.SelectClause). + WriteRunes(ssg.dialectOptions.SpaceRune) + dc := clauses.Distinct() + if dc != nil { + b.Write(ssg.dialectOptions.DistinctFragment) + if !dc.IsEmpty() { + if ssg.dialectOptions.SupportsDistinctOn { + b.Write(ssg.dialectOptions.OnFragment).WriteRunes(ssg.dialectOptions.LeftParenRune) + ssg.esg.Generate(b, dc) + b.WriteRunes(ssg.dialectOptions.RightParenRune, ssg.dialectOptions.SpaceRune) + } else { + b.SetError(errDistinctOnNotSupported(ssg.dialect)) + return + } + } else { + b.WriteRunes(ssg.dialectOptions.SpaceRune) + } + } + cols := clauses.Select() + if clauses.IsDefaultSelect() || len(cols.Columns()) == 0 { + b.WriteRunes(ssg.dialectOptions.StarRune) + } else { + ssg.esg.Generate(b, cols) + } +} + +// Generates the JOIN clauses for an SQL statement +func (ssg *selectSQLGenerator) JoinSQL(b sb.SQLBuilder, joins exp.JoinExpressions) { + if len(joins) > 0 { + for _, j := range joins { + joinType, ok := ssg.dialectOptions.JoinTypeLookup[j.JoinType()] + if !ok { + b.SetError(errNotSupportedJoinType(j)) + return + } + b.Write(joinType) + ssg.esg.Generate(b, j.Table()) + if t, ok := j.(exp.ConditionedJoinExpression); ok { + if t.IsConditionEmpty() { + b.SetError(errJoinConditionRequired(j)) + return + } + ssg.joinConditionSQL(b, t.Condition()) + } + } + } +} + +// Generates the GROUP BY clause for an SQL statement +func (ssg *selectSQLGenerator) GroupBySQL(b sb.SQLBuilder, groupBy exp.ColumnListExpression) { + if groupBy != nil && len(groupBy.Columns()) > 0 { + b.Write(ssg.dialectOptions.GroupByFragment) + ssg.esg.Generate(b, groupBy) + } +} + +// Generates the HAVING clause for an SQL statement +func (ssg *selectSQLGenerator) HavingSQL(b sb.SQLBuilder, having exp.ExpressionList) { + if having != nil && len(having.Expressions()) > 0 { + b.Write(ssg.dialectOptions.HavingFragment) + ssg.esg.Generate(b, having) + } +} + +// Generates the OFFSET clause for an SQL statement +func (ssg *selectSQLGenerator) OffsetSQL(b sb.SQLBuilder, offset uint) { + if offset > 0 { + b.Write(ssg.dialectOptions.OffsetFragment) + ssg.esg.Generate(b, offset) + } +} + +// Generates the compound sql clause for an SQL statement (e.g. UNION, INTERSECT) +func (ssg *selectSQLGenerator) CompoundsSQL(b sb.SQLBuilder, compounds []exp.CompoundExpression) { + for _, compound := range compounds { + ssg.esg.Generate(b, compound) + } +} + +// Generates the FOR (aka "locking") clause for an SQL statement +func (ssg *selectSQLGenerator) ForSQL(b sb.SQLBuilder, lockingClause exp.Lock) { + if lockingClause == nil { + return + } + switch lockingClause.Strength() { + case exp.ForNolock: + return + case exp.ForUpdate: + b.Write(ssg.dialectOptions.ForUpdateFragment) + case exp.ForNoKeyUpdate: + b.Write(ssg.dialectOptions.ForNoKeyUpdateFragment) + case exp.ForShare: + b.Write(ssg.dialectOptions.ForShareFragment) + case exp.ForKeyShare: + b.Write(ssg.dialectOptions.ForKeyShareFragment) + } + // the WAIT case is the default in Postgres, and is what you get if you don't specify NOWAIT or + // SKIP LOCKED. There's no special syntax for it in PG, so we don't do anything for it here + switch lockingClause.WaitOption() { + case exp.NoWait: + b.Write(ssg.dialectOptions.NowaitFragment) + case exp.SkipLocked: + b.Write(ssg.dialectOptions.SkipLockedFragment) + } +} + +func (ssg *selectSQLGenerator) joinConditionSQL(b sb.SQLBuilder, jc exp.JoinCondition) { + switch t := jc.(type) { + case exp.JoinOnCondition: + ssg.joinOnConditionSQL(b, t) + case exp.JoinUsingCondition: + ssg.joinUsingConditionSQL(b, t) + } +} + +func (ssg *selectSQLGenerator) joinUsingConditionSQL(b sb.SQLBuilder, jc exp.JoinUsingCondition) { + b.Write(ssg.dialectOptions.UsingFragment). + WriteRunes(ssg.dialectOptions.LeftParenRune) + ssg.esg.Generate(b, jc.Using()) + b.WriteRunes(ssg.dialectOptions.RightParenRune) +} + +func (ssg *selectSQLGenerator) joinOnConditionSQL(b sb.SQLBuilder, jc exp.JoinOnCondition) { + b.Write(ssg.dialectOptions.OnFragment) + ssg.esg.Generate(b, jc.On()) +} diff --git a/sqlgen/select_sql_generator_test.go b/sqlgen/select_sql_generator_test.go new file mode 100644 index 00000000..838f69e4 --- /dev/null +++ b/sqlgen/select_sql_generator_test.go @@ -0,0 +1,412 @@ +package sqlgen + +import ( + "testing" + + "github.com/doug-martin/goqu/v8/exp" + "github.com/doug-martin/goqu/v8/internal/errors" + "github.com/doug-martin/goqu/v8/internal/sb" + "github.com/stretchr/testify/suite" +) + +type ( + selectTestCase struct { + clause exp.SelectClauses + sql string + isPrepared bool + args []interface{} + err string + } + selectSQLGeneratorSuite struct { + baseSQLGeneratorSuite + } +) + +func (ssgs *selectSQLGeneratorSuite) assertCases(ssg SelectSQLGenerator, testCases ...selectTestCase) { + for _, tc := range testCases { + b := sb.NewSQLBuilder(tc.isPrepared) + ssg.Generate(b, tc.clause) + switch { + case len(tc.err) > 0: + ssgs.assertErrorSQL(b, tc.err) + case tc.isPrepared: + ssgs.assertPreparedSQL(b, tc.sql, tc.args) + default: + ssgs.assertNotPreparedSQL(b, tc.sql) + } + } +} + +func (ssgs *selectSQLGeneratorSuite) TestDialect() { + opts := DefaultDialectOptions() + d := NewSelectSQLGenerator("test", opts) + ssgs.Equal("test", d.Dialect()) + + opts2 := DefaultDialectOptions() + d2 := NewSelectSQLGenerator("test2", opts2) + ssgs.Equal("test2", d2.Dialect()) +} + +func (ssgs *selectSQLGeneratorSuite) TestGenerate() { + opts := DefaultDialectOptions() + opts.SelectClause = []byte("select") + opts.StarRune = '#' + + sc := exp.NewSelectClauses().SetFrom(exp.NewColumnListExpression("test")) + scWithCols := sc.SetSelect(exp.NewColumnListExpression("a", "b")) + + ssgs.assertCases( + NewSelectSQLGenerator("test", opts), + selectTestCase{clause: sc, sql: `select # FROM "test"`}, + selectTestCase{clause: sc, sql: `select # FROM "test"`, isPrepared: true}, + + selectTestCase{clause: scWithCols, sql: `select "a", "b" FROM "test"`}, + selectTestCase{clause: scWithCols, sql: `select "a", "b" FROM "test"`, isPrepared: true}, + ) +} + +func (ssgs *selectSQLGeneratorSuite) TestGenerate_UnsupportedFragment() { + opts := DefaultDialectOptions() + opts.SelectSQLOrder = []SQLFragmentType{InsertBeingSQLFragment} + + sc := exp.NewSelectClauses().SetFrom(exp.NewColumnListExpression("test")) + expectedErr := "goqu: unsupported SELECT SQL fragment InsertBeingSQLFragment" + ssgs.assertCases( + NewSelectSQLGenerator("test", opts), + selectTestCase{clause: sc, err: expectedErr}, + selectTestCase{clause: sc, err: expectedErr, isPrepared: true}, + ) +} + +func (ssgs *selectSQLGeneratorSuite) TestGenerate_WithErroredBuilder() { + opts := DefaultDialectOptions() + opts.SelectSQLOrder = []SQLFragmentType{InsertBeingSQLFragment} + d := NewSelectSQLGenerator("test", opts) + + b := sb.NewSQLBuilder(true).SetError(errors.New("test error")) + c := exp.NewSelectClauses().SetFrom(exp.NewColumnListExpression("test")) + d.Generate(b, c) + ssgs.assertErrorSQL(b, `goqu: test error`) +} + +func (ssgs *selectSQLGeneratorSuite) TestGenerate_withDistinct() { + opts := DefaultDialectOptions() + // make sure the fragments are used + opts.SelectClause = []byte("select") + opts.StarRune = '#' + opts.DistinctFragment = []byte("distinct") + opts.OnFragment = []byte(" on ") + opts.SupportsDistinctOn = true + + sc := exp.NewSelectClauses().SetDistinct(exp.NewColumnListExpression()) + scDistinctOn := sc.SetDistinct(exp.NewColumnListExpression("a", "b")) + + ssgs.assertCases( + NewSelectSQLGenerator("test", opts), + selectTestCase{clause: sc, sql: `select distinct #`}, + selectTestCase{clause: sc, sql: `select distinct #`, isPrepared: true}, + + selectTestCase{clause: scDistinctOn, sql: `select distinct on ("a", "b") #`}, + selectTestCase{clause: scDistinctOn, sql: `select distinct on ("a", "b") #`, isPrepared: true}, + ) + + opts = DefaultDialectOptions() + opts.SupportsDistinctOn = false + expectedErr := "goqu: dialect does not support DISTINCT ON clause [dialect=test]" + ssgs.assertCases( + NewSelectSQLGenerator("test", opts), + selectTestCase{clause: sc, sql: `SELECT DISTINCT *`}, + selectTestCase{clause: sc, sql: `SELECT DISTINCT *`, isPrepared: true}, + + selectTestCase{clause: scDistinctOn, err: expectedErr}, + selectTestCase{clause: scDistinctOn, err: expectedErr, isPrepared: true}, + ) +} + +func (ssgs *selectSQLGeneratorSuite) TestGenerate_withFromSQL() { + opts := DefaultDialectOptions() + opts.FromFragment = []byte(" from") + + sc := exp.NewSelectClauses() + scFrom := sc.SetFrom(exp.NewColumnListExpression("a", "b")) + ssgs.assertCases( + NewSelectSQLGenerator("test", opts), + selectTestCase{clause: sc, sql: `SELECT *`}, + selectTestCase{clause: sc, sql: `SELECT *`, isPrepared: true}, + + selectTestCase{clause: scFrom, sql: `SELECT * from "a", "b"`}, + selectTestCase{clause: scFrom, sql: `SELECT * from "a", "b"`, isPrepared: true}, + ) +} + +func (ssgs *selectSQLGeneratorSuite) TestGenerate_withJoin() { + opts := DefaultDialectOptions() + // override fragements to make sure dialect is used + opts.UsingFragment = []byte(" using ") + opts.OnFragment = []byte(" on ") + opts.JoinTypeLookup = map[exp.JoinType][]byte{ + exp.LeftJoinType: []byte(" left join "), + exp.NaturalJoinType: []byte(" natural join "), + } + + sc := exp.NewSelectClauses().SetFrom(exp.NewColumnListExpression("test")) + ti := exp.NewIdentifierExpression("", "test2", "") + uj := exp.NewUnConditionedJoinExpression(exp.NaturalJoinType, ti) + cjo := exp.NewConditionedJoinExpression(exp.LeftJoinType, ti, exp.NewJoinOnCondition(exp.Ex{"a": "foo"})) + cju := exp.NewConditionedJoinExpression(exp.LeftJoinType, ti, exp.NewJoinUsingCondition("a")) + rj := exp.NewConditionedJoinExpression(exp.RightJoinType, ti, exp.NewJoinUsingCondition(exp.NewIdentifierExpression("", "", "a"))) + badJoin := exp.NewConditionedJoinExpression(exp.LeftJoinType, ti, exp.NewJoinUsingCondition()) + + expectedRjError := "goqu: dialect does not support RightJoinType" + expectedJoinCondError := "goqu: join condition required for conditioned join LeftJoinType" + ssgs.assertCases( + NewSelectSQLGenerator("test", opts), + selectTestCase{clause: sc.JoinsAppend(uj), sql: `SELECT * FROM "test" natural join "test2"`}, + selectTestCase{clause: sc.JoinsAppend(uj), sql: `SELECT * FROM "test" natural join "test2"`, isPrepared: true}, + + selectTestCase{clause: sc.JoinsAppend(cjo), sql: `SELECT * FROM "test" left join "test2" on ("a" = 'foo')`}, + selectTestCase{ + clause: sc.JoinsAppend(cjo), + sql: `SELECT * FROM "test" left join "test2" on ("a" = ?)`, + isPrepared: true, + args: []interface{}{"foo"}, + }, + + selectTestCase{clause: sc.JoinsAppend(cju), sql: `SELECT * FROM "test" left join "test2" using ("a")`}, + selectTestCase{clause: sc.JoinsAppend(cju), sql: `SELECT * FROM "test" left join "test2" using ("a")`, isPrepared: true}, + + selectTestCase{ + clause: sc.JoinsAppend(uj).JoinsAppend(cjo).JoinsAppend(cju), + sql: `SELECT * FROM "test" natural join "test2" left join "test2" on ("a" = 'foo') left join "test2" using ("a")`, + }, + selectTestCase{ + clause: sc.JoinsAppend(uj).JoinsAppend(cjo).JoinsAppend(cju), + sql: `SELECT * FROM "test" natural join "test2" left join "test2" on ("a" = ?) left join "test2" using ("a")`, + isPrepared: true, + args: []interface{}{"foo"}, + }, + + selectTestCase{clause: sc.JoinsAppend(rj), err: expectedRjError}, + selectTestCase{clause: sc.JoinsAppend(rj), err: expectedRjError, isPrepared: true}, + + selectTestCase{clause: sc.JoinsAppend(badJoin), err: expectedJoinCondError}, + selectTestCase{clause: sc.JoinsAppend(badJoin), err: expectedJoinCondError, isPrepared: true}, + ) + +} + +func (ssgs *selectSQLGeneratorSuite) TestGenerate_withWhere() { + opts := DefaultDialectOptions() + opts.WhereFragment = []byte(" where ") + + sc := exp.NewSelectClauses().SetFrom(exp.NewColumnListExpression("test")) + w := exp.Ex{"a": "b"} + w2 := exp.Ex{"b": "c"} + scWhere1 := sc.WhereAppend(w) + scWhere2 := sc.WhereAppend(w, w2) + + ssgs.assertCases( + NewSelectSQLGenerator("test", opts), + selectTestCase{clause: scWhere1, sql: `SELECT * FROM "test" where ("a" = 'b')`}, + selectTestCase{clause: scWhere1, sql: `SELECT * FROM "test" where ("a" = ?)`, isPrepared: true, args: []interface{}{"b"}}, + + selectTestCase{clause: scWhere2, sql: `SELECT * FROM "test" where (("a" = 'b') AND ("b" = 'c'))`}, + selectTestCase{ + clause: scWhere2, + sql: `SELECT * FROM "test" where (("a" = ?) AND ("b" = ?))`, + isPrepared: true, + args: []interface{}{"b", "c"}, + }, + ) +} + +func (ssgs *selectSQLGeneratorSuite) TestGenerate_withGroupBy() { + opts := DefaultDialectOptions() + opts.GroupByFragment = []byte(" group by ") + + sc := exp.NewSelectClauses().SetFrom(exp.NewColumnListExpression("test")) + scGroup := sc.SetGroupBy(exp.NewColumnListExpression("a")) + scGroupMulti := sc.SetGroupBy(exp.NewColumnListExpression("a", "b")) + + ssgs.assertCases( + NewSelectSQLGenerator("test", opts), + selectTestCase{clause: scGroup, sql: `SELECT * FROM "test" group by "a"`}, + selectTestCase{clause: scGroup, sql: `SELECT * FROM "test" group by "a"`, isPrepared: true}, + + selectTestCase{clause: scGroupMulti, sql: `SELECT * FROM "test" group by "a", "b"`}, + selectTestCase{clause: scGroupMulti, sql: `SELECT * FROM "test" group by "a", "b"`, isPrepared: true}, + ) +} + +func (ssgs *selectSQLGeneratorSuite) TestGenerate_withHaving() { + opts := DefaultDialectOptions() + opts.HavingFragment = []byte(" having ") + + sc := exp.NewSelectClauses().SetFrom(exp.NewColumnListExpression("test")) + w := exp.Ex{"a": "b"} + w2 := exp.Ex{"b": "c"} + scHaving1 := sc.HavingAppend(w) + scHaving2 := sc.HavingAppend(w, w2) + + ssgs.assertCases( + NewSelectSQLGenerator("test", opts), + selectTestCase{clause: scHaving1, sql: `SELECT * FROM "test" having ("a" = 'b')`}, + selectTestCase{clause: scHaving1, sql: `SELECT * FROM "test" having ("a" = ?)`, isPrepared: true, args: []interface{}{"b"}}, + + selectTestCase{clause: scHaving2, sql: `SELECT * FROM "test" having (("a" = 'b') AND ("b" = 'c'))`}, + selectTestCase{ + clause: scHaving2, + sql: `SELECT * FROM "test" having (("a" = ?) AND ("b" = ?))`, + isPrepared: true, + args: []interface{}{"b", "c"}, + }, + ) +} + +func (ssgs *selectSQLGeneratorSuite) TestGenerate_withOrder() { + sc := exp.NewSelectClauses().SetFrom(exp.NewColumnListExpression("test")). + SetOrder( + exp.NewIdentifierExpression("", "", "a").Asc(), + exp.NewIdentifierExpression("", "", "b").Desc(), + ) + ssgs.assertCases( + NewSelectSQLGenerator("test", DefaultDialectOptions()), + selectTestCase{clause: sc, sql: `SELECT * FROM "test" ORDER BY "a" ASC, "b" DESC`}, + selectTestCase{clause: sc, sql: `SELECT * FROM "test" ORDER BY "a" ASC, "b" DESC`, isPrepared: true}, + ) +} + +func (ssgs *selectSQLGeneratorSuite) TestGenerate_withLimit() { + sc := exp.NewSelectClauses().SetFrom(exp.NewColumnListExpression("test")). + SetLimit(10) + ssgs.assertCases( + NewSelectSQLGenerator("test", DefaultDialectOptions()), + selectTestCase{clause: sc, sql: `SELECT * FROM "test" LIMIT 10`}, + selectTestCase{clause: sc, sql: `SELECT * FROM "test" LIMIT ?`, isPrepared: true, args: []interface{}{int64(10)}}, + ) +} + +func (ssgs *selectSQLGeneratorSuite) TestGenerate_withOffset() { + sc := exp.NewSelectClauses().SetFrom(exp.NewColumnListExpression("test")). + SetOffset(10) + ssgs.assertCases( + NewSelectSQLGenerator("test", DefaultDialectOptions()), + selectTestCase{clause: sc, sql: `SELECT * FROM "test" OFFSET 10`}, + selectTestCase{clause: sc, sql: `SELECT * FROM "test" OFFSET ?`, isPrepared: true, args: []interface{}{int64(10)}}, + ) +} + +func (ssgs *selectSQLGeneratorSuite) TestGenerate_withCommonTables() { + + tse := newTestAppendableExpression("select * from foo", emptyArgs, nil, nil) + + sc := exp.NewSelectClauses().SetFrom(exp.NewColumnListExpression("test_cte")) + scCte1 := sc.CommonTablesAppend(exp.NewCommonTableExpression(false, "test_cte", tse)) + scCte2 := sc.CommonTablesAppend(exp.NewCommonTableExpression(true, "test_cte", tse)) + + ssgs.assertCases( + NewSelectSQLGenerator("test", DefaultDialectOptions()), + selectTestCase{clause: scCte1, sql: `WITH test_cte AS (select * from foo) SELECT * FROM "test_cte"`}, + selectTestCase{clause: scCte1, sql: `WITH test_cte AS (select * from foo) SELECT * FROM "test_cte"`, isPrepared: true}, + + selectTestCase{clause: scCte2, sql: `WITH RECURSIVE test_cte AS (select * from foo) SELECT * FROM "test_cte"`}, + selectTestCase{clause: scCte2, sql: `WITH RECURSIVE test_cte AS (select * from foo) SELECT * FROM "test_cte"`, isPrepared: true}, + ) + +} + +func (ssgs *selectSQLGeneratorSuite) TestGenerate_withCompounds() { + tse := newTestAppendableExpression("select * from foo", emptyArgs, nil, nil) + sc := exp.NewSelectClauses().SetFrom(exp.NewColumnListExpression("test")). + CompoundsAppend(exp.NewCompoundExpression(exp.UnionCompoundType, tse)). + CompoundsAppend(exp.NewCompoundExpression(exp.IntersectCompoundType, tse)) + + expectedSQL := `SELECT * FROM "test" UNION (select * from foo) INTERSECT (select * from foo)` + ssgs.assertCases( + NewSelectSQLGenerator("test", DefaultDialectOptions()), + selectTestCase{clause: sc, sql: expectedSQL}, + selectTestCase{clause: sc, sql: expectedSQL, isPrepared: true}, + ) +} + +func (ssgs *selectSQLGeneratorSuite) TestToSelectSQL_withFor() { + opts := DefaultDialectOptions() + opts.ForUpdateFragment = []byte(" for update ") + opts.ForNoKeyUpdateFragment = []byte(" for no key update ") + opts.ForShareFragment = []byte(" for share ") + opts.ForKeyShareFragment = []byte(" for key share ") + opts.NowaitFragment = []byte("nowait") + opts.SkipLockedFragment = []byte("skip locked") + + sc := exp.NewSelectClauses().SetFrom(exp.NewColumnListExpression("test")) + scFnW := sc.SetLock(exp.NewLock(exp.ForNolock, exp.Wait)) + scFnNw := sc.SetLock(exp.NewLock(exp.ForNolock, exp.NoWait)) + scFnSl := sc.SetLock(exp.NewLock(exp.ForNolock, exp.SkipLocked)) + + scFsW := sc.SetLock(exp.NewLock(exp.ForShare, exp.Wait)) + scFsNw := sc.SetLock(exp.NewLock(exp.ForShare, exp.NoWait)) + scFsSl := sc.SetLock(exp.NewLock(exp.ForShare, exp.SkipLocked)) + + scFksW := sc.SetLock(exp.NewLock(exp.ForKeyShare, exp.Wait)) + scFksNw := sc.SetLock(exp.NewLock(exp.ForKeyShare, exp.NoWait)) + scFksSl := sc.SetLock(exp.NewLock(exp.ForKeyShare, exp.SkipLocked)) + + scFuW := sc.SetLock(exp.NewLock(exp.ForUpdate, exp.Wait)) + scFuNw := sc.SetLock(exp.NewLock(exp.ForUpdate, exp.NoWait)) + scFuSl := sc.SetLock(exp.NewLock(exp.ForUpdate, exp.SkipLocked)) + + scFkuW := sc.SetLock(exp.NewLock(exp.ForNoKeyUpdate, exp.Wait)) + scFkuNw := sc.SetLock(exp.NewLock(exp.ForNoKeyUpdate, exp.NoWait)) + scFkuSl := sc.SetLock(exp.NewLock(exp.ForNoKeyUpdate, exp.SkipLocked)) + ssgs.assertCases( + NewSelectSQLGenerator("test", opts), + selectTestCase{clause: scFnW, sql: `SELECT * FROM "test"`}, + selectTestCase{clause: scFnW, sql: `SELECT * FROM "test"`, isPrepared: true}, + + selectTestCase{clause: scFnNw, sql: `SELECT * FROM "test"`}, + selectTestCase{clause: scFnNw, sql: `SELECT * FROM "test"`, isPrepared: true}, + + selectTestCase{clause: scFnSl, sql: `SELECT * FROM "test"`}, + selectTestCase{clause: scFnSl, sql: `SELECT * FROM "test"`, isPrepared: true}, + + selectTestCase{clause: scFsW, sql: `SELECT * FROM "test" for share `}, + selectTestCase{clause: scFsW, sql: `SELECT * FROM "test" for share `, isPrepared: true}, + + selectTestCase{clause: scFsNw, sql: `SELECT * FROM "test" for share nowait`}, + selectTestCase{clause: scFsNw, sql: `SELECT * FROM "test" for share nowait`, isPrepared: true}, + + selectTestCase{clause: scFsSl, sql: `SELECT * FROM "test" for share skip locked`}, + selectTestCase{clause: scFsSl, sql: `SELECT * FROM "test" for share skip locked`, isPrepared: true}, + + selectTestCase{clause: scFksW, sql: `SELECT * FROM "test" for key share `}, + selectTestCase{clause: scFksW, sql: `SELECT * FROM "test" for key share `, isPrepared: true}, + + selectTestCase{clause: scFksNw, sql: `SELECT * FROM "test" for key share nowait`}, + selectTestCase{clause: scFksNw, sql: `SELECT * FROM "test" for key share nowait`, isPrepared: true}, + + selectTestCase{clause: scFksSl, sql: `SELECT * FROM "test" for key share skip locked`}, + selectTestCase{clause: scFksSl, sql: `SELECT * FROM "test" for key share skip locked`, isPrepared: true}, + + selectTestCase{clause: scFuW, sql: `SELECT * FROM "test" for update `}, + selectTestCase{clause: scFuW, sql: `SELECT * FROM "test" for update `, isPrepared: true}, + + selectTestCase{clause: scFuNw, sql: `SELECT * FROM "test" for update nowait`}, + selectTestCase{clause: scFuNw, sql: `SELECT * FROM "test" for update nowait`, isPrepared: true}, + + selectTestCase{clause: scFuSl, sql: `SELECT * FROM "test" for update skip locked`}, + selectTestCase{clause: scFuSl, sql: `SELECT * FROM "test" for update skip locked`, isPrepared: true}, + + selectTestCase{clause: scFkuW, sql: `SELECT * FROM "test" for no key update `}, + selectTestCase{clause: scFkuW, sql: `SELECT * FROM "test" for no key update `, isPrepared: true}, + + selectTestCase{clause: scFkuNw, sql: `SELECT * FROM "test" for no key update nowait`}, + selectTestCase{clause: scFkuNw, sql: `SELECT * FROM "test" for no key update nowait`, isPrepared: true}, + + selectTestCase{clause: scFkuSl, sql: `SELECT * FROM "test" for no key update skip locked`}, + selectTestCase{clause: scFkuSl, sql: `SELECT * FROM "test" for no key update skip locked`, isPrepared: true}, + ) +} + +func TestSelectSQLGenerator(t *testing.T) { + suite.Run(t, new(selectSQLGeneratorSuite)) +} diff --git a/sql_dialect_options.go b/sqlgen/sql_dialect_options.go similarity index 99% rename from sql_dialect_options.go rename to sqlgen/sql_dialect_options.go index 1afb8144..174adf8f 100644 --- a/sql_dialect_options.go +++ b/sqlgen/sql_dialect_options.go @@ -1,4 +1,4 @@ -package goqu +package sqlgen import ( "fmt" diff --git a/sqlgen/truncate_sql_generator.go b/sqlgen/truncate_sql_generator.go new file mode 100644 index 00000000..ba43aaac --- /dev/null +++ b/sqlgen/truncate_sql_generator.go @@ -0,0 +1,68 @@ +package sqlgen + +import ( + "strings" + + "github.com/doug-martin/goqu/v8/exp" + "github.com/doug-martin/goqu/v8/internal/errors" + "github.com/doug-martin/goqu/v8/internal/sb" +) + +type ( + // An adapter interface to be used by a Dataset to generate SQL for a specific dialect. + // See DefaultAdapter for a concrete implementation and examples. + TruncateSQLGenerator interface { + Dialect() string + Generate(b sb.SQLBuilder, clauses exp.TruncateClauses) + } + // The default adapter. This class should be used when building a new adapter. When creating a new adapter you can + // either override methods, or more typically update default values. + // See (github.com/doug-martin/goqu/adapters/postgres) + truncateSQLGenerator struct { + *commonSQLGenerator + } +) + +var errNoSourceForTruncate = errors.New("no source found when generating truncate sql") + +func NewTruncateSQLGenerator(dialect string, do *SQLDialectOptions) TruncateSQLGenerator { + return &truncateSQLGenerator{newCommonSQLGenerator(dialect, do)} +} + +func (tsg *truncateSQLGenerator) Dialect() string { + return tsg.dialect +} + +func (tsg *truncateSQLGenerator) Generate(b sb.SQLBuilder, clauses exp.TruncateClauses) { + if !clauses.HasTable() { + b.SetError(errNoSourceForTruncate) + return + } + for _, f := range tsg.dialectOptions.TruncateSQLOrder { + if b.Error() != nil { + return + } + switch f { + case TruncateSQLFragment: + tsg.TruncateSQL(b, clauses.Table(), clauses.Options()) + default: + b.SetError(errNotSupportedFragment("TRUNCATE", f)) + } + } +} + +// Generates a TRUNCATE statement +func (tsg *truncateSQLGenerator) TruncateSQL(b sb.SQLBuilder, from exp.ColumnListExpression, opts exp.TruncateOptions) { + b.Write(tsg.dialectOptions.TruncateClause) + tsg.SourcesSQL(b, from) + if opts.Identity != tsg.dialectOptions.EmptyString { + b.WriteRunes(tsg.dialectOptions.SpaceRune). + WriteStrings(strings.ToUpper(opts.Identity)). + Write(tsg.dialectOptions.IdentityFragment) + } + if opts.Cascade { + b.Write(tsg.dialectOptions.CascadeFragment) + } else if opts.Restrict { + b.Write(tsg.dialectOptions.RestrictFragment) + } +} diff --git a/sqlgen/truncate_sql_generator_test.go b/sqlgen/truncate_sql_generator_test.go new file mode 100644 index 00000000..99cb281d --- /dev/null +++ b/sqlgen/truncate_sql_generator_test.go @@ -0,0 +1,120 @@ +package sqlgen + +import ( + "testing" + + "github.com/doug-martin/goqu/v8/exp" + "github.com/doug-martin/goqu/v8/internal/errors" + "github.com/doug-martin/goqu/v8/internal/sb" + "github.com/stretchr/testify/suite" +) + +type ( + truncateTestCase struct { + clause exp.TruncateClauses + sql string + isPrepared bool + args []interface{} + err string + } + truncateSQLGeneratorSuite struct { + baseSQLGeneratorSuite + } +) + +func (tsgs *truncateSQLGeneratorSuite) assertCases(tsg TruncateSQLGenerator, testCases ...truncateTestCase) { + for _, tc := range testCases { + b := sb.NewSQLBuilder(tc.isPrepared) + tsg.Generate(b, tc.clause) + switch { + case len(tc.err) > 0: + tsgs.assertErrorSQL(b, tc.err) + case tc.isPrepared: + tsgs.assertPreparedSQL(b, tc.sql, tc.args) + default: + tsgs.assertNotPreparedSQL(b, tc.sql) + } + } +} + +func (tsgs *truncateSQLGeneratorSuite) TestDialect() { + opts := DefaultDialectOptions() + d := NewTruncateSQLGenerator("test", opts) + tsgs.Equal("test", d.Dialect()) + + opts2 := DefaultDialectOptions() + d2 := NewTruncateSQLGenerator("test2", opts2) + tsgs.Equal("test2", d2.Dialect()) +} + +func (tsgs *truncateSQLGeneratorSuite) TestGenerate() { + opts := DefaultDialectOptions() + opts.TruncateClause = []byte("truncate") + + tcNoTable := exp.NewTruncateClauses() + tcSingle := tcNoTable.SetTable(exp.NewColumnListExpression("a")) + tcMulti := exp.NewTruncateClauses().SetTable(exp.NewColumnListExpression("a", "b")) + + expectedNoSourceErr := "goqu: no source found when generating truncate sql" + tsgs.assertCases( + NewTruncateSQLGenerator("test", opts), + truncateTestCase{clause: tcSingle, sql: `truncate "a"`}, + truncateTestCase{clause: tcSingle, sql: `truncate "a"`, isPrepared: true}, + + truncateTestCase{clause: tcMulti, sql: `truncate "a", "b"`}, + truncateTestCase{clause: tcMulti, sql: `truncate "a", "b"`, isPrepared: true}, + + truncateTestCase{clause: tcNoTable, err: expectedNoSourceErr}, + truncateTestCase{clause: tcNoTable, err: expectedNoSourceErr, isPrepared: true}, + ) +} + +func (tsgs *truncateSQLGeneratorSuite) TestGenerate_UnsupportedFragment() { + opts := DefaultDialectOptions() + opts.TruncateSQLOrder = []SQLFragmentType{UpdateBeginSQLFragment} + tc := exp.NewTruncateClauses().SetTable(exp.NewColumnListExpression("a")) + expectedErr := "goqu: unsupported TRUNCATE SQL fragment UpdateBeginSQLFragment" + tsgs.assertCases( + NewTruncateSQLGenerator("test", opts), + truncateTestCase{clause: tc, err: expectedErr}, + truncateTestCase{clause: tc, err: expectedErr, isPrepared: true}, + ) +} + +func (tsgs *truncateSQLGeneratorSuite) TestGenerate_WithErroredBuilder() { + opts := DefaultDialectOptions() + opts.TruncateSQLOrder = []SQLFragmentType{UpdateBeginSQLFragment} + d := NewTruncateSQLGenerator("test", opts) + + b := sb.NewSQLBuilder(true).SetError(errors.New("expected error")) + d.Generate(b, exp.NewTruncateClauses().SetTable(exp.NewColumnListExpression("a"))) + tsgs.assertErrorSQL(b, `goqu: expected error`) +} + +func (tsgs *truncateSQLGeneratorSuite) TestGenerate_WithCascade() { + opts := DefaultDialectOptions() + opts.CascadeFragment = []byte(" cascade") + opts.RestrictFragment = []byte(" restrict") + opts.IdentityFragment = []byte(" identity") + + tc := exp.NewTruncateClauses().SetTable(exp.NewColumnListExpression("a")) + tcCascade := tc.SetOptions(exp.TruncateOptions{Cascade: true}) + tcRestrict := tc.SetOptions(exp.TruncateOptions{Restrict: true}) + tcRestart := tc.SetOptions(exp.TruncateOptions{Identity: "restart"}) + + tsgs.assertCases( + NewTruncateSQLGenerator("test", opts), + truncateTestCase{clause: tcCascade, sql: `TRUNCATE "a" cascade`}, + truncateTestCase{clause: tcCascade, sql: `TRUNCATE "a" cascade`, isPrepared: true}, + + truncateTestCase{clause: tcRestrict, sql: `TRUNCATE "a" restrict`}, + truncateTestCase{clause: tcRestrict, sql: `TRUNCATE "a" restrict`, isPrepared: true}, + + truncateTestCase{clause: tcRestart, sql: `TRUNCATE "a" RESTART identity`}, + truncateTestCase{clause: tcRestart, sql: `TRUNCATE "a" RESTART identity`, isPrepared: true}, + ) +} + +func TestTruncateSQLGenerator(t *testing.T) { + suite.Run(t, new(truncateSQLGeneratorSuite)) +} diff --git a/sqlgen/update_sql_generator.go b/sqlgen/update_sql_generator.go new file mode 100644 index 00000000..dcd85cfc --- /dev/null +++ b/sqlgen/update_sql_generator.go @@ -0,0 +1,117 @@ +package sqlgen + +import ( + "github.com/doug-martin/goqu/v8/exp" + "github.com/doug-martin/goqu/v8/internal/errors" + "github.com/doug-martin/goqu/v8/internal/sb" +) + +type ( + // An adapter interface to be used by a Dataset to generate SQL for a specific dialect. + // See DefaultAdapter for a concrete implementation and examples. + UpdateSQLGenerator interface { + Dialect() string + Generate(b sb.SQLBuilder, clauses exp.UpdateClauses) + } + // The default adapter. This class should be used when building a new adapter. When creating a new adapter you can + // either override methods, or more typically update default values. + // See (github.com/doug-martin/goqu/adapters/postgres) + updateSQLGenerator struct { + *commonSQLGenerator + } +) + +var ( + errNoSourceForUpdate = errors.New("no source found when generating update sql") + errNoSetValuesForUpdate = errors.New("no set values found when generating UPDATE sql") +) + +func NewUpdateSQLGenerator(dialect string, do *SQLDialectOptions) UpdateSQLGenerator { + return &updateSQLGenerator{newCommonSQLGenerator(dialect, do)} +} + +func (usg *updateSQLGenerator) Dialect() string { + return usg.dialect +} + +func (usg *updateSQLGenerator) Generate(b sb.SQLBuilder, clauses exp.UpdateClauses) { + if !clauses.HasTable() { + b.SetError(errNoSourceForUpdate) + return + } + if !clauses.HasSetValues() { + b.SetError(errNoSetValuesForUpdate) + return + } + if !usg.dialectOptions.SupportsMultipleUpdateTables && clauses.HasFrom() { + b.SetError(errors.New("%s dialect does not support multiple tables in UPDATE", usg.dialect)) + } + updates, err := exp.NewUpdateExpressions(clauses.SetValues()) + if err != nil { + b.SetError(err) + return + } + for _, f := range usg.dialectOptions.UpdateSQLOrder { + if b.Error() != nil { + return + } + switch f { + case CommonTableSQLFragment: + usg.esg.Generate(b, clauses.CommonTables()) + case UpdateBeginSQLFragment: + usg.UpdateBeginSQL(b) + case SourcesSQLFragment: + usg.updateTableSQL(b, clauses) + case UpdateSQLFragment: + usg.UpdateExpressionsSQL(b, updates...) + case UpdateFromSQLFragment: + usg.updateFromSQL(b, clauses.From()) + case WhereSQLFragment: + usg.WhereSQL(b, clauses.Where()) + case OrderSQLFragment: + if usg.dialectOptions.SupportsOrderByOnUpdate { + usg.OrderSQL(b, clauses.Order()) + } + case LimitSQLFragment: + if usg.dialectOptions.SupportsLimitOnUpdate { + usg.LimitSQL(b, clauses.Limit()) + } + case ReturningSQLFragment: + usg.ReturningSQL(b, clauses.Returning()) + default: + b.SetError(errNotSupportedFragment("UPDATE", f)) + } + } +} + +// Adds the correct fragment to being an UPDATE statement +func (usg *updateSQLGenerator) UpdateBeginSQL(b sb.SQLBuilder) { + b.Write(usg.dialectOptions.UpdateClause) +} + +// Adds column setters in an update SET clause +func (usg *updateSQLGenerator) UpdateExpressionsSQL(b sb.SQLBuilder, updates ...exp.UpdateExpression) { + b.Write(usg.dialectOptions.SetFragment) + usg.UpdateExpressionSQL(b, updates...) + +} + +func (usg *updateSQLGenerator) updateTableSQL(b sb.SQLBuilder, uc exp.UpdateClauses) { + b.WriteRunes(usg.dialectOptions.SpaceRune) + usg.esg.Generate(b, uc.Table()) + if uc.HasFrom() { + if !usg.dialectOptions.UseFromClauseForMultipleUpdateTables { + b.WriteRunes(usg.dialectOptions.CommaRune) + usg.esg.Generate(b, uc.From()) + } + } +} + +func (usg *updateSQLGenerator) updateFromSQL(b sb.SQLBuilder, ce exp.ColumnListExpression) { + if ce == nil || ce.IsEmpty() { + return + } + if usg.dialectOptions.UseFromClauseForMultipleUpdateTables { + usg.FromSQL(b, ce) + } +} diff --git a/sqlgen/update_sql_generator_test.go b/sqlgen/update_sql_generator_test.go new file mode 100644 index 00000000..46e2ce5b --- /dev/null +++ b/sqlgen/update_sql_generator_test.go @@ -0,0 +1,241 @@ +package sqlgen + +import ( + "testing" + + "github.com/doug-martin/goqu/v8/exp" + "github.com/doug-martin/goqu/v8/internal/sb" + "github.com/stretchr/testify/suite" +) + +type ( + updateTestCase struct { + clause exp.UpdateClauses + sql string + isPrepared bool + args []interface{} + err string + } + updateSQLGeneratorSuite struct { + baseSQLGeneratorSuite + } +) + +func (usgs *updateSQLGeneratorSuite) assertCases(usg UpdateSQLGenerator, testCases ...updateTestCase) { + for _, tc := range testCases { + b := sb.NewSQLBuilder(tc.isPrepared) + usg.Generate(b, tc.clause) + switch { + case len(tc.err) > 0: + usgs.assertErrorSQL(b, tc.err) + case tc.isPrepared: + usgs.assertPreparedSQL(b, tc.sql, tc.args) + default: + usgs.assertNotPreparedSQL(b, tc.sql) + } + } +} + +func (usgs *updateSQLGeneratorSuite) TestDialect() { + opts := DefaultDialectOptions() + d := NewUpdateSQLGenerator("test", opts) + usgs.Equal("test", d.Dialect()) + + opts2 := DefaultDialectOptions() + d2 := NewUpdateSQLGenerator("test2", opts2) + usgs.Equal("test2", d2.Dialect()) +} + +func (usgs *updateSQLGeneratorSuite) TestGenerate_unsupportedFragment() { + opts := DefaultDialectOptions() + opts.UpdateSQLOrder = []SQLFragmentType{InsertBeingSQLFragment} + + uc := exp.NewUpdateClauses(). + SetTable(exp.NewIdentifierExpression("", "test", "")). + SetSetValues(exp.Record{"a": "b", "b": "c"}) + expectedErr := "goqu: unsupported UPDATE SQL fragment InsertBeingSQLFragment" + usgs.assertCases( + NewUpdateSQLGenerator("test", opts), + updateTestCase{clause: uc, err: expectedErr}, + updateTestCase{clause: uc, err: expectedErr, isPrepared: true}, + ) +} + +func (usgs *updateSQLGeneratorSuite) TestGenerate_empty() { + uc := exp.NewUpdateClauses() + usgs.assertCases( + NewUpdateSQLGenerator("test", DefaultDialectOptions()), + updateTestCase{clause: uc, err: errNoSourceForUpdate.Error()}, + updateTestCase{clause: uc, err: errNoSourceForUpdate.Error(), isPrepared: true}, + ) + +} + +func (usgs *updateSQLGeneratorSuite) TestGenerate_withBadUpdateValues() { + uc := exp.NewUpdateClauses(). + SetTable(exp.NewIdentifierExpression("", "test", "")). + SetSetValues(true) + + expectedErr := "goqu: unsupported update interface type bool" + usgs.assertCases( + NewUpdateSQLGenerator("test", DefaultDialectOptions()), + updateTestCase{clause: uc, err: expectedErr}, + updateTestCase{clause: uc, err: expectedErr, isPrepared: true}, + ) +} + +func (usgs *updateSQLGeneratorSuite) TestGenerate_noSetValues() { + uc := exp.NewUpdateClauses().SetTable(exp.NewIdentifierExpression("", "test", "")) + + expectedErr := errNoSetValuesForUpdate.Error() + usgs.assertCases( + NewUpdateSQLGenerator("test", DefaultDialectOptions()), + updateTestCase{clause: uc, err: expectedErr}, + updateTestCase{clause: uc, err: expectedErr, isPrepared: true}, + ) +} + +func (usgs *updateSQLGeneratorSuite) TestGenerate_withFrom() { + uc := exp.NewUpdateClauses(). + SetTable(exp.NewIdentifierExpression("", "test", "")). + SetSetValues(exp.Record{"foo": "bar"}). + SetFrom(exp.NewColumnListExpression("other_test")) + + opts := DefaultDialectOptions() + usgs.assertCases( + NewUpdateSQLGenerator("test", opts), + updateTestCase{clause: uc, sql: `UPDATE "test" SET "foo"='bar' FROM "other_test"`}, + updateTestCase{clause: uc, sql: `UPDATE "test" SET "foo"=? FROM "other_test"`, isPrepared: true, args: []interface{}{"bar"}}, + ) + + opts = DefaultDialectOptions() + opts.UseFromClauseForMultipleUpdateTables = false + usgs.assertCases( + NewUpdateSQLGenerator("test", opts), + updateTestCase{clause: uc, sql: `UPDATE "test","other_test" SET "foo"='bar'`}, + updateTestCase{clause: uc, sql: `UPDATE "test","other_test" SET "foo"=?`, isPrepared: true, args: []interface{}{"bar"}}, + ) + + opts = DefaultDialectOptions() + opts.SupportsMultipleUpdateTables = false + expectedErr := "goqu: test dialect does not support multiple tables in UPDATE" + usgs.assertCases( + NewUpdateSQLGenerator("test", opts), + updateTestCase{clause: uc, err: expectedErr}, + updateTestCase{clause: uc, err: expectedErr, isPrepared: true}, + ) +} + +func (usgs *updateSQLGeneratorSuite) TestGenerate_withUpdateExpression() { + + opts := DefaultDialectOptions() + // make sure the fragments are used + opts.SetFragment = []byte(" set ") + uc := exp.NewUpdateClauses(). + SetTable(exp.NewIdentifierExpression("", "test", "")) + ucRecord := uc.SetSetValues(exp.Record{"a": "b", "b": "c"}) + ucEmptyRecord := uc.SetSetValues(exp.Record{}) + + usgs.assertCases( + NewUpdateSQLGenerator("test", opts), + updateTestCase{clause: ucRecord, sql: `UPDATE "test" set "a"='b',"b"='c'`}, + updateTestCase{clause: ucRecord, sql: `UPDATE "test" set "a"=?,"b"=?`, isPrepared: true, args: []interface{}{"b", "c"}}, + + updateTestCase{clause: ucEmptyRecord, err: errNoUpdatedValuesProvided.Error()}, + updateTestCase{clause: ucEmptyRecord, err: errNoUpdatedValuesProvided.Error(), isPrepared: true}, + ) +} + +func (usgs *updateSQLGeneratorSuite) TestGenerate_withOrder() { + uc := exp.NewUpdateClauses(). + SetTable(exp.NewIdentifierExpression("", "test", "")). + SetSetValues(exp.Record{"a": "b", "b": "c"}). + SetOrder( + exp.NewIdentifierExpression("", "", "a").Asc(), + exp.NewIdentifierExpression("", "", "b").Desc(), + ) + + opts := DefaultDialectOptions() + opts.SupportsOrderByOnUpdate = true + + usgs.assertCases( + NewUpdateSQLGenerator("test", opts), + updateTestCase{clause: uc, sql: `UPDATE "test" SET "a"='b',"b"='c' ORDER BY "a" ASC, "b" DESC`}, + updateTestCase{ + clause: uc, + sql: `UPDATE "test" SET "a"=?,"b"=? ORDER BY "a" ASC, "b" DESC`, + isPrepared: true, + args: []interface{}{"b", "c"}, + }, + ) + + opts = DefaultDialectOptions() + opts.SupportsOrderByOnUpdate = false + usgs.assertCases( + NewUpdateSQLGenerator("test", opts), + updateTestCase{clause: uc, sql: `UPDATE "test" SET "a"='b',"b"='c'`}, + updateTestCase{clause: uc, sql: `UPDATE "test" SET "a"=?,"b"=?`, isPrepared: true, args: []interface{}{"b", "c"}}, + ) +} + +func (usgs *updateSQLGeneratorSuite) TestGenerate_withLimit() { + uc := exp.NewUpdateClauses(). + SetTable(exp.NewIdentifierExpression("", "test", "")). + SetSetValues(exp.Record{"a": "b", "b": "c"}). + SetLimit(10) + + opts := DefaultDialectOptions() + opts.SupportsLimitOnUpdate = true + + usgs.assertCases( + NewUpdateSQLGenerator("test", opts), + updateTestCase{clause: uc, sql: `UPDATE "test" SET "a"='b',"b"='c' LIMIT 10`}, + updateTestCase{clause: uc, sql: `UPDATE "test" SET "a"=?,"b"=? LIMIT ?`, isPrepared: true, args: []interface{}{"b", "c", int64(10)}}, + ) + + opts = DefaultDialectOptions() + opts.SupportsLimitOnUpdate = false + usgs.assertCases( + NewUpdateSQLGenerator("test", opts), + updateTestCase{clause: uc, sql: `UPDATE "test" SET "a"='b',"b"='c'`}, + updateTestCase{clause: uc, sql: `UPDATE "test" SET "a"=?,"b"=?`, isPrepared: true, args: []interface{}{"b", "c"}}, + ) +} + +func (usgs *updateSQLGeneratorSuite) TestGenerate_withCommonTables() { + tse := newTestAppendableExpression("select * from foo", emptyArgs, nil, nil) + uc := exp.NewUpdateClauses(). + SetTable(exp.NewIdentifierExpression("", "test_cte", "")). + SetSetValues(exp.Record{"a": "b", "b": "c"}) + ucCte1 := uc.CommonTablesAppend(exp.NewCommonTableExpression(false, "test_cte", tse)) + ucCte2 := uc.CommonTablesAppend(exp.NewCommonTableExpression(true, "test_cte", tse)) + + usgs.assertCases( + NewUpdateSQLGenerator("test", DefaultDialectOptions()), + updateTestCase{ + clause: ucCte1, + sql: `WITH test_cte AS (select * from foo) UPDATE "test_cte" SET "a"='b',"b"='c'`, + }, + updateTestCase{ + clause: ucCte1, + sql: `WITH test_cte AS (select * from foo) UPDATE "test_cte" SET "a"=?,"b"=?`, + isPrepared: true, + args: []interface{}{"b", "c"}, + }, + + updateTestCase{ + clause: ucCte2, + sql: `WITH RECURSIVE test_cte AS (select * from foo) UPDATE "test_cte" SET "a"='b',"b"='c'`, + }, + updateTestCase{ + clause: ucCte2, + sql: `WITH RECURSIVE test_cte AS (select * from foo) UPDATE "test_cte" SET "a"=?,"b"=?`, + isPrepared: true, + args: []interface{}{"b", "c"}, + }, + ) +} + +func TestUpdateSQLGenerator(t *testing.T) { + suite.Run(t, new(updateSQLGeneratorSuite)) +} diff --git a/truncate_dataset_test.go b/truncate_dataset_test.go index 737b4e78..065d5986 100644 --- a/truncate_dataset_test.go +++ b/truncate_dataset_test.go @@ -13,8 +13,20 @@ import ( "github.com/stretchr/testify/suite" ) -type truncateDatasetSuite struct { - suite.Suite +type ( + truncateTestCase struct { + ds *TruncateDataset + clauses exp.TruncateClauses + } + truncateDatasetSuite struct { + suite.Suite + } +) + +func (tds *truncateDatasetSuite) assertCases(cases ...truncateTestCase) { + for _, s := range cases { + tds.Equal(s.clauses, s.ds.GetClauses()) + } } func (tds *truncateDatasetSuite) TestClone() { @@ -58,78 +70,148 @@ func (tds *truncateDatasetSuite) TestGetClauses() { tds.Equal(ce, ds.GetClauses()) } -func (tds *truncateDatasetSuite) TestTable(from ...interface{}) { - ds := Truncate("test") - dsc := ds.GetClauses() - ec := dsc.SetTable(exp.NewColumnListExpression(T("t"))) - tds.Equal(ec, ds.Table(T("t")).GetClauses()) - tds.Equal(dsc, ds.GetClauses()) +func (tds *truncateDatasetSuite) TestTable() { + bd := Truncate("test") + tds.assertCases( + truncateTestCase{ + ds: bd.Table("test2"), + clauses: exp.NewTruncateClauses(). + SetTable(exp.NewColumnListExpression("test2")), + }, + truncateTestCase{ + ds: bd.Table("test1", "test2"), + clauses: exp.NewTruncateClauses(). + SetTable(exp.NewColumnListExpression("test1", "test2")), + }, + truncateTestCase{ + ds: bd, + clauses: exp.NewTruncateClauses(). + SetTable(exp.NewColumnListExpression("test")), + }, + ) } func (tds *truncateDatasetSuite) TestCascade() { - ds := Truncate("test") - dsc := ds.GetClauses() - ec := dsc.SetOptions(exp.TruncateOptions{Cascade: true}) - tds.Equal(ec, ds.Cascade().GetClauses()) - tds.Equal(dsc, ds.GetClauses()) -} - -func (tds *truncateDatasetSuite) TestCascade_ToSQL() { - ds1 := Truncate("items") - tsql, _, err := ds1.Cascade().ToSQL() - tds.NoError(err) - tds.Equal(`TRUNCATE "items" CASCADE`, tsql) + bd := Truncate("test") + tds.assertCases( + truncateTestCase{ + ds: bd.Cascade(), + clauses: exp.NewTruncateClauses(). + SetTable(exp.NewColumnListExpression("test")). + SetOptions(exp.TruncateOptions{Cascade: true}), + }, + truncateTestCase{ + ds: bd.Restrict().Cascade(), + clauses: exp.NewTruncateClauses(). + SetTable(exp.NewColumnListExpression("test")). + SetOptions(exp.TruncateOptions{Cascade: true, Restrict: true}), + }, + truncateTestCase{ + ds: bd, + clauses: exp.NewTruncateClauses(). + SetTable(exp.NewColumnListExpression("test")), + }, + ) } func (tds *truncateDatasetSuite) TestNoCascade() { - ds := Truncate("test").Cascade() - dsc := ds.GetClauses() - ec := dsc.SetOptions(exp.TruncateOptions{Cascade: false}) - tds.Equal(ec, ds.NoCascade().GetClauses()) - tds.Equal(dsc, ds.GetClauses()) + bd := Truncate("test").Cascade() + tds.assertCases( + truncateTestCase{ + ds: bd.NoCascade(), + clauses: exp.NewTruncateClauses(). + SetTable(exp.NewColumnListExpression("test")). + SetOptions(exp.TruncateOptions{}), + }, + truncateTestCase{ + ds: bd.Restrict().NoCascade(), + clauses: exp.NewTruncateClauses(). + SetTable(exp.NewColumnListExpression("test")). + SetOptions(exp.TruncateOptions{Cascade: false, Restrict: true}), + }, + truncateTestCase{ + ds: bd, + clauses: exp.NewTruncateClauses(). + SetTable(exp.NewColumnListExpression("test")). + SetOptions(exp.TruncateOptions{Cascade: true}), + }, + ) } func (tds *truncateDatasetSuite) TestRestrict() { - ds := Truncate("test") - dsc := ds.GetClauses() - ec := dsc.SetOptions(exp.TruncateOptions{Restrict: true}) - tds.Equal(ec, ds.Restrict().GetClauses()) - tds.Equal(dsc, ds.GetClauses()) -} - -func (tds *truncateDatasetSuite) TestRestrict_ToSQL() { - ds1 := Truncate("items") - tsql, _, err := ds1.Restrict().ToSQL() - tds.NoError(err) - tds.Equal(`TRUNCATE "items" RESTRICT`, tsql) + bd := Truncate("test") + tds.assertCases( + truncateTestCase{ + ds: bd.Restrict(), + clauses: exp.NewTruncateClauses(). + SetTable(exp.NewColumnListExpression("test")). + SetOptions(exp.TruncateOptions{Restrict: true}), + }, + truncateTestCase{ + ds: bd.Cascade().Restrict(), + clauses: exp.NewTruncateClauses(). + SetTable(exp.NewColumnListExpression("test")). + SetOptions(exp.TruncateOptions{Cascade: true, Restrict: true}), + }, + truncateTestCase{ + ds: bd, + clauses: exp.NewTruncateClauses(). + SetTable(exp.NewColumnListExpression("test")), + }, + ) } func (tds *truncateDatasetSuite) TestNoRestrict() { - ds := Truncate("test").Restrict() - dsc := ds.GetClauses() - ec := dsc.SetOptions(exp.TruncateOptions{Restrict: false}) - tds.Equal(ec, ds.NoRestrict().GetClauses()) - tds.Equal(dsc, ds.GetClauses()) + bd := Truncate("test").Restrict() + tds.assertCases( + truncateTestCase{ + ds: bd.NoRestrict(), + clauses: exp.NewTruncateClauses(). + SetTable(exp.NewColumnListExpression("test")). + SetOptions(exp.TruncateOptions{}), + }, + truncateTestCase{ + ds: bd.Cascade().NoRestrict(), + clauses: exp.NewTruncateClauses(). + SetTable(exp.NewColumnListExpression("test")). + SetOptions(exp.TruncateOptions{Cascade: true, Restrict: false}), + }, + truncateTestCase{ + ds: bd, + clauses: exp.NewTruncateClauses(). + SetTable(exp.NewColumnListExpression("test")). + SetOptions(exp.TruncateOptions{Restrict: true}), + }, + ) } func (tds *truncateDatasetSuite) TestIdentity() { - ds := Truncate("test") - dsc := ds.GetClauses() - ec := dsc.SetOptions(exp.TruncateOptions{Identity: "RESTART"}) - tds.Equal(ec, ds.Identity("RESTART").GetClauses()) - tds.Equal(dsc, ds.GetClauses()) -} - -func (tds *truncateDatasetSuite) TestIdentity_ToSQL() { - ds1 := Truncate("items") - - tsql, _, err := ds1.Identity("restart").ToSQL() - tds.NoError(err) - tds.Equal(`TRUNCATE "items" RESTART IDENTITY`, tsql) - - tsql, _, err = ds1.Identity("continue").ToSQL() - tds.NoError(err) - tds.Equal(`TRUNCATE "items" CONTINUE IDENTITY`, tsql) + bd := Truncate("test") + tds.assertCases( + truncateTestCase{ + ds: bd.Identity("RESTART"), + clauses: exp.NewTruncateClauses(). + SetTable(exp.NewColumnListExpression("test")). + SetOptions(exp.TruncateOptions{Identity: "RESTART"}), + }, + truncateTestCase{ + ds: bd.Identity("CONTINUE"), + clauses: exp.NewTruncateClauses(). + SetTable(exp.NewColumnListExpression("test")). + SetOptions(exp.TruncateOptions{Identity: "CONTINUE"}), + }, + truncateTestCase{ + ds: bd.Cascade().Restrict().Identity("CONTINUE"), + clauses: exp.NewTruncateClauses(). + SetTable(exp.NewColumnListExpression("test")). + SetOptions(exp.TruncateOptions{Cascade: true, Restrict: true, Identity: "CONTINUE"}), + }, + truncateTestCase{ + ds: bd, + clauses: exp.NewTruncateClauses(). + SetTable(exp.NewColumnListExpression("test")), + }, + ) } func (tds *truncateDatasetSuite) TestToSQL() { diff --git a/update_dataset.go b/update_dataset.go index 2d5025cf..b3c137b2 100644 --- a/update_dataset.go +++ b/update_dataset.go @@ -3,6 +3,7 @@ package goqu import ( "github.com/doug-martin/goqu/v8/exec" "github.com/doug-martin/goqu/v8/exp" + "github.com/doug-martin/goqu/v8/internal/errors" "github.com/doug-martin/goqu/v8/internal/sb" ) @@ -13,6 +14,8 @@ type UpdateDataset struct { queryFactory exec.QueryFactory } +var errUnsupportedUpdateTableType = errors.New("unsupported table type, a string or identifier expression is required") + // used internally by database to create a database with a specific adapter func newUpdateDataset(d string, queryFactory exec.QueryFactory) *UpdateDataset { return &UpdateDataset{ @@ -112,7 +115,7 @@ func (ud *UpdateDataset) Table(table interface{}) *UpdateDataset { case string: return ud.copy(ud.clauses.SetTable(exp.ParseIdentifier(t))) default: - panic("unsupported table type, a string or identifier expression is required") + panic(errUnsupportedUpdateTableType) } } diff --git a/update_dataset_test.go b/update_dataset_test.go index 8a2cb5e0..5dd93424 100644 --- a/update_dataset_test.go +++ b/update_dataset_test.go @@ -1,11 +1,7 @@ package goqu import ( - "database/sql" - "database/sql/driver" - "fmt" "testing" - "time" "github.com/DATA-DOG/go-sqlmock" "github.com/doug-martin/goqu/v8/exec" @@ -17,28 +13,20 @@ import ( "github.com/stretchr/testify/suite" ) -type updateDatasetSuite struct { - suite.Suite -} - -func (uds *updateDatasetSuite) SetupSuite() { - noReturn := DefaultDialectOptions() - noReturn.SupportsReturn = false - RegisterDialect("no-return", noReturn) - - limitOnUpdate := DefaultDialectOptions() - limitOnUpdate.SupportsLimitOnUpdate = true - RegisterDialect("limit-on-update", limitOnUpdate) - - orderOnUpdate := DefaultDialectOptions() - orderOnUpdate.SupportsOrderByOnUpdate = true - RegisterDialect("order-on-update", orderOnUpdate) -} +type ( + updateTestCase struct { + ds *UpdateDataset + clauses exp.UpdateClauses + } + updateDatasetSuite struct { + suite.Suite + } +) -func (uds *updateDatasetSuite) TearDownSuite() { - DeregisterDialect("no-return") - DeregisterDialect("limit-on-update") - DeregisterDialect("order-on-update") +func (uds *updateDatasetSuite) assertCases(cases ...updateTestCase) { + for _, s := range cases { + uds.Equal(s.clauses, s.ds.GetClauses()) + } } func (uds *updateDatasetSuite) TestClone() { @@ -84,758 +72,291 @@ func (uds *updateDatasetSuite) TestGetClauses() { func (uds *updateDatasetSuite) TestWith() { from := Update("cte") - ds := Update("test") - dsc := ds.GetClauses() - ec := dsc.CommonTablesAppend(exp.NewCommonTableExpression(false, "test-cte", from)) - uds.Equal(ec, ds.With("test-cte", from).GetClauses()) - uds.Equal(dsc, ds.GetClauses()) + bd := Update("items") + uds.assertCases( + updateTestCase{ + ds: bd.With("test-cte", from), + clauses: exp.NewUpdateClauses(). + SetTable(C("items")). + CommonTablesAppend(exp.NewCommonTableExpression(false, "test-cte", from)), + }, + updateTestCase{ + ds: bd, + clauses: exp.NewUpdateClauses().SetTable(C("items")), + }, + ) } func (uds *updateDatasetSuite) TestWithRecursive() { from := Update("cte") - ds := Update("test") - dsc := ds.GetClauses() - ec := dsc.CommonTablesAppend(exp.NewCommonTableExpression(true, "test-cte", from)) - uds.Equal(ec, ds.WithRecursive("test-cte", from).GetClauses()) - uds.Equal(dsc, ds.GetClauses()) + bd := Update("items") + uds.assertCases( + updateTestCase{ + ds: bd.WithRecursive("test-cte", from), + clauses: exp.NewUpdateClauses(). + SetTable(C("items")). + CommonTablesAppend(exp.NewCommonTableExpression(true, "test-cte", from)), + }, + updateTestCase{ + ds: bd, + clauses: exp.NewUpdateClauses().SetTable(C("items")), + }, + ) } func (uds *updateDatasetSuite) TestTable() { - ds := Update("test") - dsc := ds.GetClauses() - ec := dsc.SetTable(T("t")) - uds.Equal(ec, ds.Table(T("t")).GetClauses()) - uds.Equal(dsc, ds.GetClauses()) -} - -func (uds *updateDatasetSuite) TestTable_ToSQL() { - ds1 := Update("test").Set(C("a").Set("a1")) - - updateSQL, _, err := ds1.ToSQL() - uds.NoError(err) - uds.Equal(`UPDATE "test" SET "a"='a1'`, updateSQL) - - ds2 := ds1.Table("test2") - updateSQL, _, err = ds2.ToSQL() - uds.NoError(err) - uds.Equal(`UPDATE "test2" SET "a"='a1'`, updateSQL) - - // should not change original - updateSQL, _, err = ds1.ToSQL() - uds.NoError(err) - uds.Equal(`UPDATE "test" SET "a"='a1'`, updateSQL) -} - -func (uds *updateDatasetSuite) TestSet_ToSQLWithStructs() { - type item struct { - Address string `db:"address"` - Name string `db:"name"` - } - ds1 := Update("items").Set(item{Name: "Test", Address: "111 Test Addr"}) - updateSQL, args, err := ds1.ToSQL() - uds.NoError(err) - uds.Empty(args) - uds.Equal(`UPDATE "items" SET "address"='111 Test Addr',"name"='Test'`, updateSQL) - - updateSQL, args, err = ds1.Prepared(true).ToSQL() - uds.NoError(err) - uds.Equal(args, []interface{}{"111 Test Addr", "Test"}) - uds.Equal(`UPDATE "items" SET "address"=?,"name"=?`, updateSQL) -} - -func (uds *updateDatasetSuite) TestSet_ToSQLWithMaps() { - ds1 := Update("items").Set(Record{"name": "Test", "address": "111 Test Addr"}) - updateSQL, args, err := ds1.ToSQL() - uds.NoError(err) - uds.Empty(args) - uds.Equal(`UPDATE "items" SET "address"='111 Test Addr',"name"='Test'`, updateSQL) - - updateSQL, args, err = ds1.Prepared(true).ToSQL() - uds.NoError(err) - uds.Equal([]interface{}{"111 Test Addr", "Test"}, args) - uds.Equal(`UPDATE "items" SET "address"=?,"name"=?`, updateSQL) - -} - -func (uds *updateDatasetSuite) TestSet_ToSQLWithByteSlice() { - type item struct { - Name string `db:"name"` - Data []byte `db:"data"` - } - ds1 := Update("items").Set(item{Name: "Test", Data: []byte(`{"someJson":"data"}`)}) - updateSQL, args, err := ds1.ToSQL() - uds.NoError(err) - uds.Empty(args) - uds.Equal(`UPDATE "items" SET "data"='{"someJson":"data"}',"name"='Test'`, updateSQL) - - updateSQL, args, err = ds1.Prepared(true).ToSQL() - uds.NoError(err) - uds.Equal(`UPDATE "items" SET "data"=?,"name"=?`, updateSQL) - uds.Equal(args, []interface{}{[]byte(`{"someJson":"data"}`), "Test"}) -} - -type valuerType []byte - -func (j valuerType) Value() (driver.Value, error) { - return []byte(fmt.Sprintf("%s World", string(j))), nil -} - -func (uds *updateDatasetSuite) TestSet_ToSQLWithCustomValuer() { - type item struct { - Name string `db:"name"` - Data valuerType `db:"data"` - } - ds1 := Update("items").Set(item{Name: "Test", Data: []byte(`Hello`)}) - updateSQL, args, err := ds1.ToSQL() - uds.NoError(err) - uds.Empty(args) - uds.Equal(`UPDATE "items" SET "data"='Hello World',"name"='Test'`, updateSQL) - - updateSQL, args, err = ds1.Prepared(true).ToSQL() - uds.NoError(err) - uds.Equal([]interface{}{[]byte("Hello World"), "Test"}, args) - uds.Equal(`UPDATE "items" SET "data"=?,"name"=?`, updateSQL) -} - -func (uds *updateDatasetSuite) TestSet_ToSQLWithValuer() { - type item struct { - Name string `db:"name"` - Data sql.NullString `db:"data"` - } - ds1 := Update("items").Set(item{Name: "Test", Data: sql.NullString{String: "Hello World", Valid: true}}) - - updateSQL, args, err := ds1.ToSQL() - uds.NoError(err) - uds.Empty(args) - uds.Equal(`UPDATE "items" SET "data"='Hello World',"name"='Test'`, updateSQL) - - updateSQL, args, err = ds1.Prepared(true).ToSQL() - uds.NoError(err) - uds.Equal(args, []interface{}{"Hello World", "Test"}) - uds.Equal(`UPDATE "items" SET "data"=?,"name"=?`, updateSQL) -} - -func (uds *updateDatasetSuite) TestSet_ToSQLWithValuerNull() { - type item struct { - Name string `db:"name"` - Data sql.NullString `db:"data"` - } - ds1 := Update("items").Set(item{Name: "Test"}) - updateSQL, args, err := ds1.ToSQL() - uds.NoError(err) - uds.Empty(args) - uds.Equal(`UPDATE "items" SET "data"=NULL,"name"='Test'`, updateSQL) - - updateSQL, args, err = ds1.Prepared(true).ToSQL() - uds.NoError(err) - uds.Equal([]interface{}{"Test"}, args) - uds.Equal(`UPDATE "items" SET "data"=NULL,"name"=?`, updateSQL) -} - -func (uds *updateDatasetSuite) TestSet_ToSQLWithEmbeddedStruct() { - type Phone struct { - Primary string `db:"primary_phone"` - Home string `db:"home_phone"` - Created time.Time `db:"phone_created"` - } - type item struct { - Phone - Address string `db:"address" goqu:"skipupdate"` - Name string `db:"name"` - Created time.Time `db:"created"` - NilPointer interface{} `db:"nil_pointer"` - } - created, _ := time.Parse("2006-01-02", "2015-01-01") - ds1 := Update("items").Set(item{ - Name: "Test", - Address: "111 Test Addr", - Created: created, - Phone: Phone{ - Home: "123123", - Primary: "456456", - Created: created, + bd := Update("items") + uds.assertCases( + updateTestCase{ + ds: bd.Table("items2"), + clauses: exp.NewUpdateClauses().SetTable(C("items2")), }, - }) - updateSQL, args, err := ds1.ToSQL() - uds.NoError(err) - uds.Empty(args) - uds.Equal(`UPDATE "items" SET `+ - `"created"='2015-01-01T00:00:00Z',`+ - `"home_phone"='123123',`+ - `"name"='Test',`+ - `"nil_pointer"=NULL,`+ - `"phone_created"='2015-01-01T00:00:00Z',`+ - `"primary_phone"='456456'`, updateSQL) - - updateSQL, args, err = ds1.Prepared(true).ToSQL() - uds.NoError(err) - uds.Equal(`UPDATE "items" SET `+ - `"created"=?,"home_phone"=?,"name"=?,"nil_pointer"=NULL,"phone_created"=?,"primary_phone"=?`, updateSQL) - uds.Equal([]interface{}{created, "123123", "Test", created, "456456"}, args) -} - -func (uds *updateDatasetSuite) TestSet_ToSQLWithEmbeddedStructPtr() { - type Phone struct { - Primary string `db:"primary_phone"` - Home string `db:"home_phone"` - Created time.Time `db:"phone_created"` - } - type item struct { - *Phone - Address string `db:"address" goqu:"skipupdate"` - Name string `db:"name"` - Created time.Time `db:"created"` - } - created, _ := time.Parse("2006-01-02", "2015-01-01") - - ds1 := Update("items").Set(item{ - Name: "Test", - Address: "111 Test Addr", - Created: created, - Phone: &Phone{ - Home: "123123", - Primary: "456456", - Created: created, + updateTestCase{ + ds: bd.Table(L("literal_table")), + clauses: exp.NewUpdateClauses().SetTable(L("literal_table")), + }, + updateTestCase{ + ds: bd, + clauses: exp.NewUpdateClauses().SetTable(C("items")), }, + ) + uds.PanicsWithValue(errUnsupportedUpdateTableType, func() { + bd.Table(true) }) - - updateSQL, args, err := ds1.ToSQL() - uds.NoError(err) - uds.Empty(args) - uds.Equal(`UPDATE "items" SET `+ - `"created"='2015-01-01T00:00:00Z',`+ - `"home_phone"='123123',`+ - `"name"='Test',`+ - `"phone_created"='2015-01-01T00:00:00Z',`+ - `"primary_phone"='456456'`, updateSQL) - - updateSQL, args, err = ds1.Prepared(true).ToSQL() - uds.NoError(err) - uds.Equal(`UPDATE "items" `+ - `SET "created"=?,"home_phone"=?,"name"=?,"phone_created"=?,"primary_phone"=?`, updateSQL) - uds.Equal([]interface{}{created, "123123", "Test", created, "456456"}, args) -} - -func (uds *updateDatasetSuite) TestSet_ToSQLWithUnsupportedType() { - ds1 := Update("items").Set([]string{"HELLO"}) - - _, _, err := ds1.ToSQL() - uds.EqualError(err, "goqu: unsupported update interface type []string") - - _, _, err = ds1.Prepared(true).ToSQL() - uds.EqualError(err, "goqu: unsupported update interface type []string") } -func (uds *updateDatasetSuite) TestSet_ToSQLWithSkipupdateTag() { +func (uds *updateDatasetSuite) TestSet() { type item struct { - Address string `db:"address" goqu:"skipupdate"` + Address string `db:"address"` Name string `db:"name"` } - ds1 := Update("items").Set(item{Name: "Test", Address: "111 Test Addr"}) - - updateSQL, args, err := ds1.ToSQL() - uds.NoError(err) - uds.Empty(args) - uds.Equal(`UPDATE "items" SET "name"='Test'`, updateSQL) - - updateSQL, args, err = ds1.Prepared(true).ToSQL() - uds.NoError(err) - uds.Equal([]interface{}{"Test"}, args) - uds.Equal(`UPDATE "items" SET "name"=?`, updateSQL) -} - -func (uds *updateDatasetSuite) TestSet_ToSQLWithDefaultIfEmptyTag() { - type item struct { - Address string `db:"address" goqu:"skipupdate, defaultifempty"` - Name string `db:"name" goqu:"defaultifempty"` - Alias *string `db:"alias" goqu:"defaultifempty"` - } - ds := Update("items").Set(item{Name: "Test", Address: "111 Test Addr"}) - - updateSQL, args, err := ds.ToSQL() - uds.NoError(err) - uds.Empty(args) - uds.Equal(`UPDATE "items" SET "alias"=DEFAULT,"name"='Test'`, updateSQL) - - updateSQL, args, err = ds.Prepared(true).ToSQL() - uds.NoError(err) - uds.Equal([]interface{}{"Test"}, args) - uds.Equal(`UPDATE "items" SET "alias"=DEFAULT,"name"=?`, updateSQL) - - var alias = "" - ds = ds.Set(item{Alias: &alias}) - - updateSQL, args, err = ds.ToSQL() - uds.NoError(err) - uds.Empty(args) - uds.Equal(`UPDATE "items" SET "alias"='',"name"=DEFAULT`, updateSQL) - - updateSQL, args, err = ds.Prepared(true).ToSQL() - uds.NoError(err) - uds.Equal([]interface{}{""}, args) - uds.Equal(`UPDATE "items" SET "alias"=?,"name"=DEFAULT`, updateSQL) + bd := Update("items") + uds.assertCases( + updateTestCase{ + ds: bd.Set(item{Name: "Test", Address: "111 Test Addr"}), + clauses: exp.NewUpdateClauses(). + SetTable(C("items")). + SetSetValues(item{Name: "Test", Address: "111 Test Addr"}), + }, + updateTestCase{ + ds: bd.Set(Record{"name": "Test", "address": "111 Test Addr"}), + clauses: exp.NewUpdateClauses(). + SetTable(C("items")). + SetSetValues(Record{"name": "Test", "address": "111 Test Addr"}), + }, + updateTestCase{ + ds: bd, + clauses: exp.NewUpdateClauses(). + SetTable(C("items")), + }, + ) } func (uds *updateDatasetSuite) TestFrom() { - ds := Update("test") - dsc := ds.GetClauses() - ec := dsc.SetFrom(exp.NewColumnListExpression("other")) - uds.Equal(ec, ds.From("other").GetClauses()) - uds.Equal(dsc, ds.GetClauses()) -} - -func (uds *updateDatasetSuite) TestFrom_ToSQL() { - ds1 := Update("test").Set(C("a").Set("a1")).From("other_table").Where(Ex{ - "test.name": T("other_test").Col("name"), - }) - - updateSQL, args, err := ds1.ToSQL() - uds.NoError(err) - uds.Empty(args) - uds.Equal(`UPDATE "test" SET "a"='a1' FROM "other_table" WHERE ("test"."name" = "other_test"."name")`, updateSQL) - - updateSQL, args, err = ds1.Prepared(true).ToSQL() - uds.NoError(err) - uds.Equal([]interface{}{"a1"}, args) - uds.Equal(`UPDATE "test" SET "a"=? FROM "other_table" WHERE ("test"."name" = "other_test"."name")`, updateSQL) -} - -func (uds *updateDatasetSuite) TestWhere() { - ds := Update("test") - dsc := ds.GetClauses() - w := Ex{ - "a": 1, - } - ec := dsc.WhereAppend(w) - uds.Equal(ec, ds.Where(w).GetClauses()) - uds.Equal(dsc, ds.GetClauses()) -} - -func (uds *updateDatasetSuite) TestWhere_ToSQL() { - ds1 := Update("test").Set(C("a").Set("a1")) - - b := ds1.Where( - C("a").Eq(true), - C("a").Neq(true), - C("a").Eq(false), - C("a").Neq(false), - ) - updateSQL, args, err := b.ToSQL() - uds.NoError(err) - uds.Empty(args) - uds.Equal( - `UPDATE "test" SET "a"='a1' `+ - `WHERE (("a" IS TRUE) AND ("a" IS NOT TRUE) AND ("a" IS FALSE) AND ("a" IS NOT FALSE))`, - updateSQL, - ) - - updateSQL, args, err = b.Prepared(true).ToSQL() - uds.NoError(err) - uds.Equal([]interface{}{"a1"}, args) - uds.Equal( - `UPDATE "test" SET "a"=? `+ - `WHERE (("a" IS TRUE) AND ("a" IS NOT TRUE) AND ("a" IS FALSE) AND ("a" IS NOT FALSE))`, - updateSQL, - ) - - b = ds1.Where( - C("a").Eq("a"), - C("b").Neq("b"), - C("c").Gt("c"), - C("d").Gte("d"), - C("e").Lt("e"), - C("f").Lte("f"), - ) - updateSQL, args, err = b.ToSQL() - uds.NoError(err) - uds.Empty(args) - uds.Equal( - `UPDATE "test" SET "a"='a1' `+ - `WHERE (("a" = 'a') AND ("b" != 'b') AND ("c" > 'c') AND ("d" >= 'd') AND ("e" < 'e') AND ("f" <= 'f'))`, - updateSQL, - ) - - updateSQL, args, err = b.Prepared(true).ToSQL() - uds.NoError(err) - uds.Equal([]interface{}{"a1", "a", "b", "c", "d", "e", "f"}, args) - uds.Equal( - `UPDATE "test" SET "a"=? `+ - `WHERE (("a" = ?) AND ("b" != ?) AND ("c" > ?) AND ("d" >= ?) AND ("e" < ?) AND ("f" <= ?))`, - updateSQL, - ) - - b = ds1.Where( - C("a").Eq(From("test2").Select("id")), - ) - updateSQL, args, err = b.ToSQL() - uds.NoError(err) - uds.Empty(args) - uds.Equal(`UPDATE "test" SET "a"='a1' WHERE ("a" IN (SELECT "id" FROM "test2"))`, updateSQL) - - updateSQL, args, err = b.Prepared(true).ToSQL() - uds.NoError(err) - uds.Equal([]interface{}{"a1"}, args) - uds.Equal(`UPDATE "test" SET "a"=? WHERE ("a" IN (SELECT "id" FROM "test2"))`, updateSQL) - - b = ds1.Where(Ex{ - "a": "a", - "b": Op{"neq": "b"}, - "c": Op{"gt": "c"}, - "d": Op{"gte": "d"}, - "e": Op{"lt": "e"}, - "f": Op{"lte": "f"}, - }) - updateSQL, args, err = b.ToSQL() - uds.NoError(err) - uds.Empty(args) - uds.Equal( - `UPDATE "test" SET "a"='a1' `+ - `WHERE (("a" = 'a') AND ("b" != 'b') AND ("c" > 'c') AND ("d" >= 'd') AND ("e" < 'e') AND ("f" <= 'f'))`, - updateSQL, - ) - - updateSQL, args, err = b.Prepared(true).ToSQL() - uds.NoError(err) - uds.Equal([]interface{}{"a1", "a", "b", "c", "d", "e", "f"}, args) - uds.Equal( - `UPDATE "test" SET "a"=? `+ - `WHERE (("a" = ?) AND ("b" != ?) AND ("c" > ?) AND ("d" >= ?) AND ("e" < ?) AND ("f" <= ?))`, - updateSQL, + bd := Update("items") + uds.assertCases( + updateTestCase{ + ds: bd.From("other"), + clauses: exp.NewUpdateClauses(). + SetTable(C("items")). + SetFrom(exp.NewColumnListExpression("other")), + }, + updateTestCase{ + ds: bd.From("other").From("other2"), + clauses: exp.NewUpdateClauses(). + SetTable(C("items")). + SetFrom(exp.NewColumnListExpression("other2")), + }, + updateTestCase{ + ds: bd, + clauses: exp.NewUpdateClauses(). + SetTable(C("items")), + }, ) - - b = ds1.Where(Ex{ - "a": From("test2").Select("id"), - }) - updateSQL, args, err = b.ToSQL() - uds.NoError(err) - uds.Empty(args) - uds.Equal(`UPDATE "test" SET "a"='a1' WHERE ("a" IN (SELECT "id" FROM "test2"))`, updateSQL) - - updateSQL, args, err = b.Prepared(true).ToSQL() - uds.NoError(err) - uds.Equal([]interface{}{"a1"}, args) - uds.Equal(`UPDATE "test" SET "a"=? WHERE ("a" IN (SELECT "id" FROM "test2"))`, updateSQL) -} - -func (uds *updateDatasetSuite) TestWhere_ToSQLEmpty() { - ds1 := Update("test").Set(C("a").Set("a1")) - - b := ds1.Where() - updateSQL, args, err := b.ToSQL() - uds.NoError(err) - uds.Empty(args) - uds.Equal(`UPDATE "test" SET "a"='a1'`, updateSQL) - - updateSQL, args, err = b.Prepared(true).ToSQL() - uds.NoError(err) - uds.Equal([]interface{}{"a1"}, args) - uds.Equal(`UPDATE "test" SET "a"=?`, updateSQL) } -func (uds *updateDatasetSuite) TestWhere_ToSQLWithChain() { - ds1 := Update("test").Set(C("a").Set("a1")).Where( - C("x").Eq(0), - C("y").Eq(1), - ) - - ds2 := ds1.Where( - C("z").Eq(2), - ) - - a := ds2.Where( - C("a").Eq("A"), - ) - b := ds2.Where( - C("b").Eq("B"), - ) - updateSQL, args, err := a.ToSQL() - uds.NoError(err) - uds.Empty(args) - uds.Equal( - `UPDATE "test" SET "a"='a1' WHERE (("x" = 0) AND ("y" = 1) AND ("z" = 2) AND ("a" = 'A'))`, - updateSQL, - ) - - updateSQL, args, err = a.Prepared(true).ToSQL() - uds.NoError(err) - uds.Equal([]interface{}{"a1", int64(0), int64(1), int64(2), "A"}, args) - uds.Equal( - `UPDATE "test" SET "a"=? WHERE (("x" = ?) AND ("y" = ?) AND ("z" = ?) AND ("a" = ?))`, - updateSQL, - ) - - updateSQL, args, err = b.ToSQL() - uds.NoError(err) - uds.Empty(args) - uds.Equal( - `UPDATE "test" SET "a"='a1' WHERE (("x" = 0) AND ("y" = 1) AND ("z" = 2) AND ("b" = 'B'))`, - updateSQL, - ) - - updateSQL, args, err = b.Prepared(true).ToSQL() - uds.NoError(err) - uds.Equal([]interface{}{"a1", int64(0), int64(1), int64(2), "B"}, args) - uds.Equal( - `UPDATE "test" SET "a"=? WHERE (("x" = ?) AND ("y" = ?) AND ("z" = ?) AND ("b" = ?))`, - updateSQL, +func (uds *updateDatasetSuite) TestWhere() { + bd := Update("items") + uds.assertCases( + updateTestCase{ + ds: bd.Where(Ex{"a": 1}), + clauses: exp.NewUpdateClauses(). + SetTable(C("items")). + WhereAppend(Ex{"a": 1}), + }, + updateTestCase{ + ds: bd.Where(Ex{"a": 1}).Where(C("b").Eq("c")), + clauses: exp.NewUpdateClauses(). + SetTable(C("items")). + WhereAppend(Ex{"a": 1}).WhereAppend(C("b").Eq("c")), + }, + updateTestCase{ + ds: bd, + clauses: exp.NewUpdateClauses(). + SetTable(C("items")), + }, ) } func (uds *updateDatasetSuite) TestClearWhere() { - w := Ex{ - "a": 1, - } - ds := Update("test").Where(w) - dsc := ds.GetClauses() - ec := dsc.ClearWhere() - uds.Equal(ec, ds.ClearWhere().GetClauses()) - uds.Equal(dsc, ds.GetClauses()) -} - -func (uds *updateDatasetSuite) TestClearWhere_ToSQL() { - ds1 := Update("test").Set(C("a").Set("a1")) - - b := ds1.Where( - C("a").Eq(1), - ).ClearWhere() - updateSQL, args, err := b.ToSQL() - uds.NoError(err) - uds.Empty(args) - uds.Equal(`UPDATE "test" SET "a"='a1'`, updateSQL) - - updateSQL, args, err = b.Prepared(true).ToSQL() - uds.NoError(err) - uds.Equal([]interface{}{"a1"}, args) - uds.Equal(`UPDATE "test" SET "a"=?`, updateSQL) + bd := Update("items").Where(Ex{"a": 1}) + uds.assertCases( + updateTestCase{ + ds: bd.ClearWhere(), + clauses: exp.NewUpdateClauses().SetTable(C("items")), + }, + updateTestCase{ + ds: bd, + clauses: exp.NewUpdateClauses(). + SetTable(C("items")). + WhereAppend(Ex{"a": 1}), + }, + ) } func (uds *updateDatasetSuite) TestOrder() { - ds := Update("test") - dsc := ds.GetClauses() - o := C("a").Desc() - ec := dsc.SetOrder(o) - uds.Equal(ec, ds.Order(o).GetClauses()) - uds.Equal(dsc, ds.GetClauses()) -} - -func (uds *updateDatasetSuite) TestOrder_ToSQL() { - - ds1 := Update("test").WithDialect("order-on-update").Set(C("a").Set("a1")) - - b := ds1.Order(C("a").Asc(), L(`("a" + "b" > 2)`).Asc()) - - updateSQL, args, err := b.ToSQL() - uds.NoError(err) - uds.Empty(args) - uds.Equal(`UPDATE "test" SET "a"='a1' ORDER BY "a" ASC, ("a" + "b" > 2) ASC`, updateSQL) - - updateSQL, args, err = b.Prepared(true).ToSQL() - uds.NoError(err) - uds.Equal([]interface{}{"a1"}, args) - uds.Equal(`UPDATE "test" SET "a"=? ORDER BY "a" ASC, ("a" + "b" > 2) ASC`, updateSQL) + bd := Update("items") + uds.assertCases( + updateTestCase{ + ds: bd.Order(C("a").Desc()), + clauses: exp.NewUpdateClauses(). + SetTable(C("items")).OrderAppend(C("a").Desc()), + }, + updateTestCase{ + ds: bd.Order(C("a").Desc()).Order(C("b").Asc()), + clauses: exp.NewUpdateClauses(). + SetTable(C("items")). + OrderAppend(C("b").Asc()), + }, + updateTestCase{ + ds: bd, + clauses: exp.NewUpdateClauses().SetTable(C("items")), + }, + ) } func (uds *updateDatasetSuite) TestOrderAppend() { - ds := Update("test").Order(C("a").Desc()) - dsc := ds.GetClauses() - o := C("b").Desc() - ec := dsc.OrderAppend(o) - uds.Equal(ec, ds.OrderAppend(o).GetClauses()) - uds.Equal(dsc, ds.GetClauses()) + bd := Update("items").Order(C("a").Desc()) + uds.assertCases( + updateTestCase{ + ds: bd.OrderAppend(C("b").Asc()), + clauses: exp.NewUpdateClauses(). + SetTable(C("items")). + OrderAppend(C("a").Desc()). + OrderAppend(C("b").Asc()), + }, + updateTestCase{ + ds: bd, + clauses: exp.NewUpdateClauses(). + SetTable(C("items")). + OrderAppend(C("a").Desc()), + }, + ) } - -func (uds *updateDatasetSuite) TestOrderAppend_ToSQL() { - ds := Update("test").WithDialect("order-on-update").Set(C("a").Set("a1")) - - b := ds.Order(C("a").Asc().NullsFirst()).OrderAppend(C("b").Desc().NullsLast()) - updateSQL, args, err := b.ToSQL() - uds.NoError(err) - uds.Empty(args) - uds.Equal(`UPDATE "test" SET "a"='a1' ORDER BY "a" ASC NULLS FIRST, "b" DESC NULLS LAST`, updateSQL) - - updateSQL, args, err = b.Prepared(true).ToSQL() - uds.NoError(err) - uds.Equal([]interface{}{"a1"}, args) - uds.Equal(updateSQL, `UPDATE "test" SET "a"=? ORDER BY "a" ASC NULLS FIRST, "b" DESC NULLS LAST`) - - b = ds.OrderAppend(C("a").Asc().NullsFirst()).OrderAppend(C("b").Desc().NullsLast()) - updateSQL, args, err = b.ToSQL() - uds.NoError(err) - uds.Empty(args) - uds.Equal(`UPDATE "test" SET "a"='a1' ORDER BY "a" ASC NULLS FIRST, "b" DESC NULLS LAST`, updateSQL) - - updateSQL, args, err = b.Prepared(true).ToSQL() - uds.NoError(err) - uds.Equal([]interface{}{"a1"}, args) - uds.Equal(`UPDATE "test" SET "a"=? ORDER BY "a" ASC NULLS FIRST, "b" DESC NULLS LAST`, updateSQL) +func (uds *updateDatasetSuite) TestOrderPrepend() { + bd := Update("items").Order(C("a").Desc()) + uds.assertCases( + updateTestCase{ + ds: bd.OrderPrepend(C("b").Asc()), + clauses: exp.NewUpdateClauses(). + SetTable(C("items")). + OrderAppend(C("b").Asc()). + OrderAppend(C("a").Desc()), + }, + updateTestCase{ + ds: bd, + clauses: exp.NewUpdateClauses(). + SetTable(C("items")). + OrderAppend(C("a").Desc()), + }, + ) } func (uds *updateDatasetSuite) TestClearOrder() { - ds := Update("test").Order(C("a").Desc()) - dsc := ds.GetClauses() - ec := dsc.ClearOrder() - uds.Equal(ec, ds.ClearOrder().GetClauses()) - uds.Equal(dsc, ds.GetClauses()) -} - -func (uds *updateDatasetSuite) TestClearOrder_ToSQL() { - b := Update("test"). - WithDialect("order-on-update"). - Set(C("a").Set("a1")). - Order(C("a").Asc().NullsFirst()). - ClearOrder() - - updateSQL, args, err := b.ToSQL() - uds.NoError(err) - uds.Empty(args) - uds.Equal(`UPDATE "test" SET "a"='a1'`, updateSQL) - - updateSQL, args, err = b.Prepared(true).ToSQL() - uds.NoError(err) - uds.Equal([]interface{}{"a1"}, args) - uds.Equal(`UPDATE "test" SET "a"=?`, updateSQL) + bd := Update("items").Order(C("a").Desc()) + uds.assertCases( + updateTestCase{ + ds: bd.ClearOrder(), + clauses: exp.NewUpdateClauses().SetTable(C("items")), + }, + updateTestCase{ + ds: bd, + clauses: exp.NewUpdateClauses(). + SetTable(C("items")). + OrderAppend(C("a").Desc()), + }, + ) } func (uds *updateDatasetSuite) TestLimit() { - ds := Update("test") - dsc := ds.GetClauses() - ec := dsc.SetLimit(uint(1)) - uds.Equal(ec, ds.Limit(1).GetClauses()) - uds.Equal(dsc, ds.Limit(0).GetClauses()) - uds.Equal(dsc, ds.GetClauses()) -} - -func (uds *updateDatasetSuite) TestLimit_ToSQL() { - ds1 := Update("test").WithDialect("limit-on-update").Set(C("a").Set("a1")) - - b := ds1.Limit(10) - - updateSQL, args, err := b.ToSQL() - uds.NoError(err) - uds.Empty(args) - uds.Equal(`UPDATE "test" SET "a"='a1' LIMIT 10`, updateSQL) - - updateSQL, args, err = b.Prepared(true).ToSQL() - uds.NoError(err) - uds.Equal([]interface{}{"a1", int64(10)}, args) - uds.Equal(`UPDATE "test" SET "a"=? LIMIT ?`, updateSQL) - - b = ds1.Limit(0) - updateSQL, args, err = b.ToSQL() - uds.NoError(err) - uds.Empty(args) - uds.Equal(`UPDATE "test" SET "a"='a1'`, updateSQL) - - updateSQL, args, err = b.Prepared(true).ToSQL() - uds.NoError(err) - uds.Equal([]interface{}{"a1"}, args) - uds.Equal(`UPDATE "test" SET "a"=?`, updateSQL) + bd := Update("items") + uds.assertCases( + updateTestCase{ + ds: bd.Limit(10), + clauses: exp.NewUpdateClauses().SetTable(C("items")).SetLimit(uint(10)), + }, + updateTestCase{ + ds: bd.Limit(0), + clauses: exp.NewUpdateClauses().SetTable(C("items")), + }, + updateTestCase{ + ds: bd, + clauses: exp.NewUpdateClauses().SetTable(C("items")), + }, + ) } func (uds *updateDatasetSuite) TestLimitAll() { - ds := Update("test") - dsc := ds.GetClauses() - ec := dsc.SetLimit(L("ALL")) - uds.Equal(ec, ds.LimitAll().GetClauses()) - uds.Equal(dsc, ds.GetClauses()) -} - -func (uds *updateDatasetSuite) TestLimitAll_ToSQL() { - ds1 := Update("test").WithDialect("limit-on-update").Set(C("a").Set("a1")) - - b := ds1.LimitAll() - updateSQL, args, err := b.ToSQL() - uds.NoError(err) - uds.Empty(args) - uds.Equal(`UPDATE "test" SET "a"='a1' LIMIT ALL`, updateSQL) - - updateSQL, args, err = b.Prepared(true).ToSQL() - uds.NoError(err) - uds.Equal([]interface{}{"a1"}, args) - uds.Equal(`UPDATE "test" SET "a"=? LIMIT ALL`, updateSQL) - - b = ds1.Limit(0).LimitAll() - updateSQL, args, err = b.ToSQL() - uds.NoError(err) - uds.Empty(args) - uds.Equal(`UPDATE "test" SET "a"='a1' LIMIT ALL`, updateSQL) - - updateSQL, args, err = b.Prepared(true).ToSQL() - uds.NoError(err) - uds.Equal([]interface{}{"a1"}, args) - uds.Equal(`UPDATE "test" SET "a"=? LIMIT ALL`, updateSQL) + bd := Update("items") + uds.assertCases( + updateTestCase{ + ds: bd.LimitAll(), + clauses: exp.NewUpdateClauses().SetTable(C("items")).SetLimit(L("ALL")), + }, + updateTestCase{ + ds: bd, + clauses: exp.NewUpdateClauses().SetTable(C("items")), + }, + ) } - func (uds *updateDatasetSuite) TestClearLimit() { - ds := Update("test").Limit(1) - dsc := ds.GetClauses() - ec := dsc.ClearLimit() - uds.Equal(ec, ds.ClearLimit().GetClauses()) - uds.Equal(dsc, ds.GetClauses()) -} - -func (uds *updateDatasetSuite) TestClearLimit_ToSQL() { - ds1 := Update("test").WithDialect("limit-on-update").Set(C("a").Set("a1")) - - b := ds1.LimitAll().ClearLimit() - updateSQL, args, err := b.ToSQL() - uds.NoError(err) - uds.Empty(args) - uds.Equal(`UPDATE "test" SET "a"='a1'`, updateSQL) - - updateSQL, args, err = b.Prepared(true).ToSQL() - uds.NoError(err) - uds.Equal([]interface{}{"a1"}, args) - uds.Equal(`UPDATE "test" SET "a"=?`, updateSQL) - - b = ds1.Limit(10).ClearLimit() - updateSQL, args, err = b.ToSQL() - uds.NoError(err) - uds.Empty(args) - uds.Equal(`UPDATE "test" SET "a"='a1'`, updateSQL) - - updateSQL, args, err = b.Prepared(true).ToSQL() - uds.NoError(err) - uds.Equal([]interface{}{"a1"}, args) - uds.Equal(`UPDATE "test" SET "a"=?`, updateSQL) + bd := Update("items") + uds.assertCases( + updateTestCase{ + ds: bd.LimitAll().ClearLimit(), + clauses: exp.NewUpdateClauses().SetTable(C("items")), + }, + updateTestCase{ + ds: bd.Limit(10).ClearLimit(), + clauses: exp.NewUpdateClauses().SetTable(C("items")), + }, + updateTestCase{ + ds: bd, + clauses: exp.NewUpdateClauses().SetTable(C("items")), + }, + ) } func (uds *updateDatasetSuite) TestReturning() { - ds := Update("test") - dsc := ds.GetClauses() - ec := dsc.SetReturning(exp.NewColumnListExpression(C("a"))) - uds.Equal(ec, ds.Returning("a").GetClauses()) - uds.Equal(dsc, ds.GetClauses()) -} - -func (uds *updateDatasetSuite) TestReturning_ToSQL() { - type item struct { - Address string `db:"address"` - Name string `db:"name"` - } - ds := Update("items") - ds1 := ds.Set(item{Name: "Test", Address: "111 Test Addr"}).Returning(T("items").All()) - ds2 := ds.Set(Record{"name": "Test", "address": "111 Test Addr"}).Returning(T("items").All()) - - updateSQL, args, err := ds1.ToSQL() - uds.NoError(err) - uds.Empty(args) - uds.Equal(`UPDATE "items" SET "address"='111 Test Addr',"name"='Test' RETURNING "items".*`, updateSQL) - - updateSQL, args, err = ds1.Prepared(true).ToSQL() - uds.NoError(err) - uds.Equal([]interface{}{"111 Test Addr", "Test"}, args) - uds.Equal(`UPDATE "items" SET "address"=?,"name"=? RETURNING "items".*`, updateSQL) - - updateSQL, args, err = ds2.ToSQL() - uds.NoError(err) - uds.Empty(args) - uds.Equal(`UPDATE "items" SET "address"='111 Test Addr',"name"='Test' RETURNING "items".*`, updateSQL) - - updateSQL, args, err = ds2.Prepared(true).ToSQL() - uds.NoError(err) - uds.Equal([]interface{}{"111 Test Addr", "Test"}, args) - uds.Equal(`UPDATE "items" SET "address"=?,"name"=? RETURNING "items".*`, updateSQL) + bd := Update("items") + uds.assertCases( + updateTestCase{ + ds: bd.Returning("a", "b"), + clauses: exp.NewUpdateClauses(). + SetTable(C("items")). + SetReturning(exp.NewColumnListExpression("a", "b")), + }, + updateTestCase{ + ds: bd.Returning("a", "b").Returning("c"), + clauses: exp.NewUpdateClauses(). + SetTable(C("items")). + SetReturning(exp.NewColumnListExpression("c")), + }, + updateTestCase{ + ds: bd, + clauses: exp.NewUpdateClauses().SetTable(C("items")), + }, + ) } func (uds *updateDatasetSuite) TestToSQL() { @@ -866,26 +387,6 @@ func (uds *updateDatasetSuite) TestToSQL_Prepared() { md.AssertExpectations(uds.T()) } -func (uds *updateDatasetSuite) TestToSQL_withNoSources() { - ds1 := newUpdateDataset("test", nil) - type item struct { - Address string `db:"address"` - Name string `db:"name"` - } - _, _, err := ds1.Set(item{Name: "Test", Address: "111 Test Addr"}).ToSQL() - uds.EqualError(err, "goqu: no source found when generating update sql") -} - -func (uds *updateDatasetSuite) TestToSQL_withReturnNotSupported() { - ds1 := New("no-return", nil).Update("items") - type item struct { - Address string `db:"address"` - Name string `db:"name"` - } - _, _, err := ds1.Set(item{Name: "Test", Address: "111 Test Addr"}).Returning("id").ToSQL() - uds.EqualError(err, "goqu: dialect does not support RETURNING clause [dialect=no-return]") -} - func (uds *updateDatasetSuite) TestToSQL_WithError() { md := new(mocks.SQLDialect) ds := Update("test").SetDialect(md)