Skip to content

Commit

Permalink
fix: data type mapping and codegen
Browse files Browse the repository at this point in the history
  • Loading branch information
si3nloong committed Sep 25, 2024
1 parent f67879f commit f27854f
Show file tree
Hide file tree
Showing 9 changed files with 202 additions and 73 deletions.
26 changes: 13 additions & 13 deletions codegen/code_generator.go
Original file line number Diff line number Diff line change
Expand Up @@ -595,38 +595,38 @@ func (g *Generator) buildUpdateByPK(importPkgs *Package, t *tableInfo) {
func (g *Generator) valuer(importPkgs *Package, goPath string, t types.Type) string {
utype, isPtr := underlyingType(t)
if columnType, ok := g.columnTypes[t.String()]; ok && columnType.Valuer != "" {
return Expr(columnType.Valuer).Format(importPkgs, ExprParams{GoPath: goPath})
return Expr(columnType.Valuer).Format(importPkgs, ExprParams{GoPath: goPath, IsPtr: isPtr})
} else if _, wrong := types.MissingMethod(utype, goSqlValuer, true); wrong {
if isPtr {
return Expr("(database/sql/driver.Valuer)({{goPath}})").Format(importPkgs, ExprParams{GoPath: goPath})
return Expr("(database/sql/driver.Valuer)({{goPath}})").Format(importPkgs, ExprParams{GoPath: goPath, IsPtr: isPtr})
}
return Expr("(database/sql/driver.Valuer)({{goPath}})").Format(importPkgs, ExprParams{GoPath: goPath})
return Expr("(database/sql/driver.Valuer)({{goPath}})").Format(importPkgs, ExprParams{GoPath: goPath, IsPtr: isPtr})
} else if columnType, ok := g.columnDataType(t); ok && columnType.Valuer != "" {
return Expr(columnType.Valuer).Format(importPkgs, ExprParams{GoPath: goPath, Len: arraySize(t)})
return Expr(columnType.Valuer).Format(importPkgs, ExprParams{GoPath: goPath, IsPtr: isPtr, Len: arraySize(t)})
} else if isImplemented(utype, textMarshaler) {
return Expr("github.com/si3nloong/sqlgen/sequel/types.TextMarshaler({{goPath}})").Format(importPkgs, ExprParams{GoPath: goPath})
return Expr("github.com/si3nloong/sqlgen/sequel/types.TextMarshaler({{goPath}})").Format(importPkgs, ExprParams{GoPath: goPath, IsPtr: isPtr})
}
return Expr(g.defaultColumnTypes["*"].Valuer).Format(importPkgs, ExprParams{GoPath: goPath})
return Expr(g.defaultColumnTypes["*"].Valuer).Format(importPkgs, ExprParams{GoPath: goPath, IsPtr: isPtr})
}

func (g *Generator) scanner(importPkgs *Package, goPath string, t types.Type) string {
ptr, isPtr := pointerType(t)
if columnType, ok := g.columnTypes[t.String()]; ok && columnType.Scanner != "" {
return Expr(columnType.Scanner).Format(importPkgs, ExprParams{GoPath: goPath})
return Expr(columnType.Scanner).Format(importPkgs, ExprParams{GoPath: goPath, IsPtr: isPtr})
} else if isImplemented(ptr, goSqlScanner) {
if isPtr {
return Expr("(database/sql.Scanner)({{goPath}})").Format(importPkgs, ExprParams{GoPath: goPath})
return Expr("(database/sql.Scanner)({{goPath}})").Format(importPkgs, ExprParams{GoPath: goPath, IsPtr: isPtr})
}
return Expr("(database/sql.Scanner)({{addrOfGoPath}})").Format(importPkgs, ExprParams{GoPath: goPath})
return Expr("(database/sql.Scanner)({{addrOfGoPath}})").Format(importPkgs, ExprParams{GoPath: goPath, IsPtr: isPtr})
} else if columnType, ok := g.columnDataType(t); ok && columnType.Scanner != "" {
return Expr(columnType.Scanner).Format(importPkgs, ExprParams{GoPath: goPath, Len: arraySize(t)})
return Expr(columnType.Scanner).Format(importPkgs, ExprParams{GoPath: goPath, IsPtr: isPtr, Len: arraySize(t)})
} else if isImplemented(ptr, textUnmarshaler) {
if isPtr {
return Expr("github.com/si3nloong/sqlgen/sequel/types.TextUnmarshaler({{goPath}})").Format(importPkgs, ExprParams{GoPath: goPath})
return Expr("github.com/si3nloong/sqlgen/sequel/types.TextUnmarshaler({{goPath}})").Format(importPkgs, ExprParams{GoPath: goPath, IsPtr: isPtr})
}
return Expr("github.com/si3nloong/sqlgen/sequel/types.TextUnmarshaler({{addrOfGoPath}})").Format(importPkgs, ExprParams{GoPath: goPath})
return Expr("github.com/si3nloong/sqlgen/sequel/types.TextUnmarshaler({{addrOfGoPath}})").Format(importPkgs, ExprParams{GoPath: goPath, IsPtr: isPtr})
}
return Expr(g.defaultColumnTypes["*"].Scanner).Format(importPkgs, ExprParams{GoPath: goPath})
return Expr(g.defaultColumnTypes["*"].Scanner).Format(importPkgs, ExprParams{GoPath: goPath, IsPtr: isPtr})
}

func (g *Generator) sqlScanner(f *columnInfo) string {
Expand Down
2 changes: 1 addition & 1 deletion codegen/codegen.go
Original file line number Diff line number Diff line change
Expand Up @@ -638,7 +638,7 @@ func parseGoPackage(
return fmt.Errorf("sqlgen: struct %q has duplicate column name %q in directory %q", s.name, column.columnName, dir)
}

if v, ok := column.getOption(TagOptionSize); ok {
if v, ok := column.getOptionValue(TagOptionSize); ok {
column.size, err = strconv.Atoi(v)
if err != nil {
return fmt.Errorf(`sqlgen: invalid size value %q %w`, v, err)
Expand Down
15 changes: 11 additions & 4 deletions codegen/dialect/dialect.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,13 +47,21 @@ type ColumnType struct {
SQLValuer string
}

type Index interface {
// Indexed columns
Columns() []string

// Whether the index is unique
Unique() bool
}

type Schema interface {
DBName() string
TableName() string
Columns() []string
Keys() []string
ColumnGoType(i int) GoColumn
Indexes() []string
RangeIndex(func(Index, int))
}

type GoColumn interface {
Expand Down Expand Up @@ -84,12 +92,11 @@ type GoColumn interface {
// Determine whether this column is auto increment or not
AutoIncr() bool

Size() int

// Key is to identify whether column is primary or foreign key
Key() bool

// Implements(*types.Interface) (*types.Func, bool)
// Column size that declared by user
Size() int
}

func RegisterDialect(name string, d Dialect) {
Expand Down
11 changes: 8 additions & 3 deletions codegen/dialect/postgres/column.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,19 @@ type column struct {
DatetimePrecision sql.NullInt64
}

func (c column) Equal(v column) bool {
return c.Name == v.Name &&
c.IsNullable == v.IsNullable
}

func (c column) ColumnType() string {
switch c.DataType {
case "varchar":
if c.CharacterMaxLength.Valid {
if c.CharacterMaxLength.Valid && c.CharacterMaxLength.Int64 > 0 {
return fmt.Sprintf("%s(%d)", c.DataType, c.CharacterMaxLength.Int64)
}
case "timestamptz":
if c.DatetimePrecision.Valid {
if c.DatetimePrecision.Valid && c.DatetimePrecision.Int64 > 0 {
return fmt.Sprintf("%s(%d)", c.DataType, c.DatetimePrecision.Int64)
}
}
Expand All @@ -44,7 +49,7 @@ func (s *postgresDriver) tableColumns(ctx context.Context, sqlConn *sql.DB, dbNa
FROM
information_schema.columns
WHERE
table_schema = $1 AND
table_catalog = $1 AND
table_name = $2
ORDER BY
ordinal_position;`, dbName, tableName)
Expand Down
72 changes: 56 additions & 16 deletions codegen/dialect/postgres/column_type.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,34 +11,40 @@ import (
)

func (s *postgresDriver) ColumnDataTypes() map[string]*dialect.ColumnType {
return map[string]*dialect.ColumnType{
dataTypes := map[string]*dialect.ColumnType{
"rune": {
DataType: s.columnDataType("char(1)"),
Valuer: "(string)({{goPath}})",
Scanner: "{{addrOfGoPath}}",
},
"string": {
DataType: s.columnDataType("varchar(255)", ""),
Valuer: "(string)({{goPath}})",
Scanner: "github.com/si3nloong/sqlgen/sequel/types.String({{addrOfGoPath}})",
DataType: func(col dialect.GoColumn) string {
size := 255
if n := col.Size(); n > 0 {
size = n
}
return fmt.Sprintf("varchar(%d)", size)
},
Valuer: "(string)({{goPath}})",
Scanner: "github.com/si3nloong/sqlgen/sequel/types.String({{addrOfGoPath}})",
},
"bool": {
DataType: s.columnDataType("boolean", false),
DataType: s.columnDataType("bool", false),
Valuer: "(bool)({{goPath}})",
Scanner: "github.com/si3nloong/sqlgen/sequel/types.Bool({{addrOfGoPath}})",
},
"int8": {
DataType: s.intDataType("smallint", int64(0)),
DataType: s.intDataType("int2", int64(0)),
Valuer: "(int64)({{goPath}})",
Scanner: "github.com/si3nloong/sqlgen/sequel/types.Integer({{addrOfGoPath}})",
},
"int16": {
DataType: s.intDataType("smallint", int64(0)),
DataType: s.intDataType("int2", int64(0)),
Valuer: "(int64)({{goPath}})",
Scanner: "github.com/si3nloong/sqlgen/sequel/types.Integer({{addrOfGoPath}})",
},
"int32": {
DataType: s.intDataType("integer", int64(0)),
DataType: s.intDataType("int4", int64(0)),
Valuer: "(int64)({{goPath}})",
Scanner: "github.com/si3nloong/sqlgen/sequel/types.Integer({{addrOfGoPath}})",
},
Expand All @@ -48,22 +54,22 @@ func (s *postgresDriver) ColumnDataTypes() map[string]*dialect.ColumnType {
Scanner: "github.com/si3nloong/sqlgen/sequel/types.Integer({{addrOfGoPath}})",
},
"int": {
DataType: s.intDataType("integer", int64(0)),
DataType: s.intDataType("int4", int64(0)),
Valuer: "(int64)({{goPath}})",
Scanner: "github.com/si3nloong/sqlgen/sequel/types.Integer({{addrOfGoPath}})",
},
"uint8": {
DataType: s.intDataType("smallint", uint64(0)),
DataType: s.intDataType("int2", uint64(0)),
Valuer: "(int64)({{goPath}})",
Scanner: "github.com/si3nloong/sqlgen/sequel/types.Integer({{addrOfGoPath}})",
},
"uint16": {
DataType: s.intDataType("smallint", uint64(0)),
DataType: s.intDataType("int2", uint64(0)),
Valuer: "(int64)({{goPath}})",
Scanner: "github.com/si3nloong/sqlgen/sequel/types.Integer({{addrOfGoPath}})",
},
"uint32": {
DataType: s.intDataType("integer", uint64(0)),
DataType: s.intDataType("int4", uint64(0)),
Valuer: "(int64)({{goPath}})",
Scanner: "github.com/si3nloong/sqlgen/sequel/types.Integer({{addrOfGoPath}})",
},
Expand All @@ -73,7 +79,7 @@ func (s *postgresDriver) ColumnDataTypes() map[string]*dialect.ColumnType {
Scanner: "github.com/si3nloong/sqlgen/sequel/types.Integer({{addrOfGoPath}})",
},
"uint": {
DataType: s.intDataType("integer", uint64(0)),
DataType: s.intDataType("int4", uint64(0)),
Valuer: "(int64)({{goPath}})",
Scanner: "github.com/si3nloong/sqlgen/sequel/types.Integer({{addrOfGoPath}})",
},
Expand All @@ -88,14 +94,14 @@ func (s *postgresDriver) ColumnDataTypes() map[string]*dialect.ColumnType {
Scanner: "github.com/si3nloong/sqlgen/sequel/types.Float({{addrOfGoPath}})",
},
"time.Time": {
DataType: s.columnDataType("timestamp(6) with time zone", sql.RawBytes(`NOW()`)),
DataType: s.columnDataType("timestamptz(6)", sql.RawBytes(`NOW()`)),
Valuer: "(time.Time)({{goPath}})",
Scanner: "(*time.Time)({{addrOfGoPath}})",
},
"*time.Time": {
DataType: s.columnDataType("timestamp(6) with time zone"),
DataType: s.columnDataType("timestamptz(6)"),
Valuer: "github.com/si3nloong/sqlgen/sequel/types.Time({{goPath}})",
Scanner: "github.com/si3nloong/sqlgen/sequel/types.PtrOfTime({{addrOfGoPath}})",
Scanner: "github.com/si3nloong/sqlgen/sequel/types.Time({{addrOfGoPath}})",
},
"[...]rune": {
DataType: func(c dialect.GoColumn) string {
Expand All @@ -104,6 +110,13 @@ func (s *postgresDriver) ColumnDataTypes() map[string]*dialect.ColumnType {
Valuer: "string({{goPath}}[:])",
Scanner: "github.com/si3nloong/sqlgen/sequel/types.FixedSizeRunes(({{goPath}})[:],{{len}})",
},
"[...]byte": {
DataType: func(c dialect.GoColumn) string {
return s.columnDataType(fmt.Sprintf("varchar(%d)", c.Size()))(c)
},
Valuer: "string({{goPath}}[:])",
Scanner: "github.com/si3nloong/sqlgen/sequel/types.FixedSizeBytes(({{goPath}})[:],{{len}})",
},
"[...]string": {
DataType: s.columnDataType("text[{{len}}]"),
Valuer: "github.com/si3nloong/sqlgen/sequel/encoding.MarshalStringSlice(({{goPath}})[:],[2]byte{'{','}'})",
Expand Down Expand Up @@ -205,6 +218,33 @@ func (s *postgresDriver) ColumnDataTypes() map[string]*dialect.ColumnType {
Scanner: "github.com/si3nloong/sqlgen/sequel/types.JSONUnmarshaler({{addrOfGoPath}})",
},
}
s.mapIntegers(dataTypes)
return dataTypes
}

func (s *postgresDriver) mapIntegers(dict map[string]*dialect.ColumnType) {
types := [][2]string{
{"int", "int4"}, {"int8", "int2"}, {"int16", "int2"}, {"int32", "int4"}, {"int64", "bigint"},
{"uint", "int4"}, {"uint8", "int2"}, {"uint16", "int2"}, {"uint32", "int4"}, {"uint64", "bigint"},
}
for _, t := range types {
dict[t[0]] = &dialect.ColumnType{
DataType: s.intDataType(t[1], int64(0)),
Valuer: "(int64)({{goPath}})",
Scanner: "github.com/si3nloong/sqlgen/sequel/types.Integer({{addrOfGoPath}})",
}
}
types = [][2]string{
{"*int", "int4"}, {"*int8", "int2"}, {"*int16", "int2"}, {"*int32", "int4"}, {"*int64", "bigint"},
{"*uint", "int4"}, {"*uint8", "int2"}, {"*uint16", "int2"}, {"*uint32", "int4"}, {"*uint64", "bigint"},
}
for _, t := range types {
dict[t[0]] = &dialect.ColumnType{
DataType: s.intDataType(t[1], int64(0)),
Valuer: "github.com/si3nloong/sqlgen/sequel/types.Integer({{addrOfGoPath}})",
Scanner: "github.com/si3nloong/sqlgen/sequel/types.Integer({{addrOfGoPath}})",
}
}
}

func (s *postgresDriver) intDataType(dataType string, defaultValue ...any) func(dialect.GoColumn) string {
Expand Down
28 changes: 28 additions & 0 deletions codegen/dialect/postgres/index.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,26 @@ package postgres

import (
"context"
"crypto/md5"
"database/sql"
"encoding/hex"
"strings"
"unsafe"

"github.com/lib/pq"
)

//go:generate stringer --type indexType --linecomment
type indexType uint8

const (
bTree indexType = iota // BTREE
hash // HASH
brin // BRIN
unique // UNIQUE
pk
)

type index struct {
Name string
IsPK bool
Expand Down Expand Up @@ -47,3 +62,16 @@ GROUP BY index_name, is_pk;`, tableName)
}
return idxs, rows.Err()
}

func indexName(columns []string, idxType indexType) string {
str := strings.Join(columns, ",")
hash := md5.Sum(unsafe.Slice(unsafe.StringData(str), len(str)))
prefix := "IX"
switch idxType {
case unique:
prefix = "UQ"
case pk:
prefix = "PK"
}
return prefix + "_" + hex.EncodeToString(hash[:])
}
Loading

0 comments on commit f27854f

Please sign in to comment.