Skip to content

Commit

Permalink
feat: support Postgres
Browse files Browse the repository at this point in the history
  • Loading branch information
si3nloong committed Nov 18, 2023
1 parent a1a6b8c commit cbfdab8
Show file tree
Hide file tree
Showing 5 changed files with 115 additions and 17 deletions.
26 changes: 23 additions & 3 deletions codegen/generate.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,23 @@ import (
"golang.org/x/tools/imports"
)

type Generator struct {
dialect sequel.Dialect
quoteChar rune
}

func (g Generator) QuoteStart() string {
return string(g.quoteChar)
}

func (g Generator) Quote(v string) string {
return string(g.quoteChar) + v + string(g.quoteChar)
}

func (g Generator) QuoteEnd() string {
return string(g.quoteChar)
}

func Init(cfg *config.Config) error {
tmpl, err := template.ParseFS(codegenTemplates, "templates/init.yml.go.tpl")
if err != nil {
Expand Down Expand Up @@ -50,21 +67,24 @@ func renderTemplate[T templates.ModelTmplParams | struct{}](
strpool.ReleaseString(blr)
}()

quoteChar := rune('"')
quote := strconv.Quote
switch dialect.Driver() {
case "postgres", "sqlite":
quoteChar = rune('`')
quote = func(s string) string {
return "`" + s + "`"
}
}

g := &Generator{quoteChar: quoteChar}
impPkg := NewPackage(pkgPath, pkgName)
tmpl, err := template.New(tmplName).Funcs(template.FuncMap{
"quote": quote,
"createTable": createTableStmt(dialect),
"createTable": g.createTableStmt(dialect),
"alterTable": alterTableStmt(dialect),
"insertOneStmt": insertOneStmt(dialect),
"findByPKStmt": findByPKStmt(dialect),
"insertOneStmt": g.insertOneStmt(dialect),
"findByPKStmt": g.findByPKStmt(dialect),
"reserveImport": reserveImport(impPkg),
"castAs": castAs(impPkg),
"addrOf": addrOf(impPkg),
Expand Down
28 changes: 14 additions & 14 deletions codegen/template.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,13 +63,13 @@ func addrOf(impPkgs *Package) func(string, *templates.Field) string {
}
}

func createTableStmt(dialect sequel.Dialect) func(string, *templates.Model) string {
func (g *Generator) createTableStmt(dialect sequel.Dialect) func(string, *templates.Model) string {
return func(n string, model *templates.Model) string {
buf := strpool.AcquireString()
defer strpool.ReleaseString(buf)

buf.WriteString(`"CREATE TABLE IF NOT EXISTS "+ `)
buf.WriteString(n + `.TableName() +" (`)
buf.WriteString(g.Quote("CREATE TABLE IF NOT EXISTS "))
buf.WriteString("+ " + n + ".TableName() +" + g.QuoteStart() + " (")
for i, f := range model.Fields {
if i > 0 {
buf.WriteByte(',')
Expand All @@ -86,7 +86,7 @@ func createTableStmt(dialect sequel.Dialect) func(string, *templates.Model) stri
if model.PK != nil {
buf.WriteString(",PRIMARY KEY (" + dialect.Wrap(model.PK.Field.ColumnName) + ")")
}
buf.WriteString(`);"`)
buf.WriteString(");" + g.QuoteEnd())
return buf.String()
}
}
Expand Down Expand Up @@ -122,11 +122,11 @@ func alterTableStmt(dialect sequel.Dialect) func(*templates.Model) string {
}
}

func insertOneStmt(dialect sequel.Dialect) func(*templates.Model) string {
func (g *Generator) insertOneStmt(dialect sequel.Dialect) func(*templates.Model) string {
return func(model *templates.Model) string {
buf := strpool.AcquireString()
defer strpool.ReleaseString(buf)
buf.WriteString(`"INSERT INTO ` + dialect.Wrap(model.TableName) + " (")
buf.WriteString("INSERT INTO " + dialect.Wrap(model.TableName) + " (")
var fields []*templates.Field
if model.PK != nil && model.PK.IsAutoIncr {
for _, f := range model.Fields {
Expand All @@ -148,18 +148,18 @@ func insertOneStmt(dialect sequel.Dialect) func(*templates.Model) string {
if i > 0 {
buf.WriteByte(',')
}
buf.WriteString(dialect.Var(i))
buf.WriteString(dialect.Var(i + 1))
}
buf.WriteString(`);"`)
return buf.String()
buf.WriteString(");")
return g.Quote(buf.String())
}
}

func findByPKStmt(dialect sequel.Dialect) func(*templates.Model) string {
func (g *Generator) findByPKStmt(dialect sequel.Dialect) func(*templates.Model) string {
return func(model *templates.Model) string {
buf := strpool.AcquireString()
defer strpool.ReleaseString(buf)
buf.WriteString(`"SELECT `)
buf.WriteString("SELECT ")
for i := range model.Fields {
if i > 0 {
buf.WriteByte(',')
Expand All @@ -168,8 +168,8 @@ func findByPKStmt(dialect sequel.Dialect) func(*templates.Model) string {
}
buf.WriteString(" FROM " + dialect.Wrap(model.TableName) + " WHERE ")
buf.WriteString(dialect.Wrap(model.PK.Field.ColumnName))
buf.WriteString(` = ` + dialect.Var(1) + ` LIMIT 1;"`)
return buf.String()
buf.WriteString(" = " + dialect.Var(1) + " LIMIT 1;")
return g.Quote(buf.String())
}
}

Expand All @@ -186,7 +186,7 @@ func varStmt(dialect sequel.Dialect) func(*templates.Model) string {
if i > 0 {
blr.WriteByte(',')
}
blr.WriteString(dialect.Var(i))
blr.WriteString(dialect.Var(i + 1))
}
blr.WriteByte(')')
return blr.String()
Expand Down
4 changes: 4 additions & 0 deletions codegen/templates/init.yml.go.tpl
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,10 @@ struct_tag: {{ .Tag }}
# Optional:
skip_escape: {{ .SkipEscape }}

# Optional: to add prefix to getter
getter:
prefix: "Prefix"

# Optional: Where should the generated model code go?
exec:
skip_empty: {{ .Exec.SkipEmpty }}
Expand Down
64 changes: 64 additions & 0 deletions sequel/types/binary_test.go
Original file line number Diff line number Diff line change
@@ -1 +1,65 @@
package types

import (
"database/sql/driver"
"encoding/base64"
"testing"

"github.com/stretchr/testify/require"
)

// base64Text should implement encoding.BinaryMarshaler and encoding.BinaryUnmarshaler
type base64Text string

func (b base64Text) MarshalBinary() ([]byte, error) {
// Implement your binary marshal logic here
return []byte(base64.StdEncoding.EncodeToString([]byte(b))), nil
}

func (b *base64Text) UnmarshalBinary(data []byte) error {
bytes, err := base64.StdEncoding.DecodeString(string(data))
if err != nil {
return err
}
*b = base64Text(bytes)
// Implement your binary unmarshal logic here
return nil
}

func TestBinaryMarshaler(t *testing.T) {
// Create an instance of base64Text
v := base64Text(`abcd`) // Initialize your instance with appropriate values

// Create a binaryMarshaler for base64Text
bm := BinaryMarshaler(v)

// Call the Value method of binaryMarshaler
val, err := bm.Value()
require.NoError(t, err)

require.Equal(t, []byte("YWJjZA=="), val)

// Assert that the returned value is of type driver.Value
if _, ok := val.(driver.Value); !ok {
t.Errorf("Value() did not return a value of type driver.Value")
}
}

func TestBinaryUnmarshaler(t *testing.T) {
// Create an instance of base64Text
v := base64Text(``) // Initialize your instance with appropriate values

// Create a binaryUnmarshaler for base64Text
bum := BinaryUnmarshaler(&v)

// Create a sample []byte data to be unmarshaled
sampleData := []byte(`aGVsbG8gd29ybGQgIQ==`) // Provide a valid byte slice for testing

// Call the Scan method of binaryUnmarshaler
require.NoError(t, bum.Scan(sampleData))

require.Equal(t, `hello world !`, string(v))
// Assert that the base64Text instance has been properly unmarshaled
// Compare the fields of yourInstance with the expected values
// Add assertions based on your specific implementation
}
10 changes: 10 additions & 0 deletions sequel/types/bool_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,4 +45,14 @@ func TestBool(t *testing.T) {
require.False(t, flag)
require.False(t, v.Interface())
})

t.Run("Value method", func(t *testing.T) {
var value bool
b := Bool(&value)

// Test Value method
val, err := b.Value()
require.NoError(t, err)
require.Equal(t, false, val)
})
}

0 comments on commit cbfdab8

Please sign in to comment.