Skip to content

Commit

Permalink
fix: types
Browse files Browse the repository at this point in the history
  • Loading branch information
si3nloong committed Sep 27, 2024
1 parent c50d785 commit 95c7ce4
Show file tree
Hide file tree
Showing 20 changed files with 1,097 additions and 129 deletions.
26 changes: 13 additions & 13 deletions codegen/code_generator.go
Original file line number Diff line number Diff line change
Expand Up @@ -614,38 +614,38 @@ func (g *Generator) buildUpdateByPK(importPkgs *Package, t *tableInfo) {
func (g *Generator) valuer(importPkgs *Package, goPath string, t types.Type) string {
utype, isPtr := underlyingType(t)
if columnType, ok := g.columnTypes[t.String()]; ok && columnType.Valuer != "" {
return Expr(columnType.Valuer).Format(importPkgs, ExprParams{GoPath: goPath, IsPtr: isPtr})
return Expr(columnType.Valuer).Format(importPkgs, ExprParams{GoPath: goPath, Type: t, IsPtr: isPtr})
} else if _, wrong := types.MissingMethod(utype, goSqlValuer, true); wrong {
if isPtr {
return Expr("(database/sql/driver.Valuer)({{goPath}})").Format(importPkgs, ExprParams{GoPath: goPath, IsPtr: isPtr})
return Expr("(database/sql/driver.Valuer)({{goPath}})").Format(importPkgs, ExprParams{GoPath: goPath, Type: t, IsPtr: isPtr})
}
return Expr("(database/sql/driver.Valuer)({{goPath}})").Format(importPkgs, ExprParams{GoPath: goPath, IsPtr: isPtr})
return Expr("(database/sql/driver.Valuer)({{goPath}})").Format(importPkgs, ExprParams{GoPath: goPath, Type: t, IsPtr: isPtr})
} else if columnType, ok := g.columnDataType(t); ok && columnType.Valuer != "" {
return Expr(columnType.Valuer).Format(importPkgs, ExprParams{GoPath: goPath, IsPtr: isPtr, Len: arraySize(t)})
return Expr(columnType.Valuer).Format(importPkgs, ExprParams{GoPath: goPath, IsPtr: isPtr, Type: t, Len: arraySize(t)})
} else if isImplemented(utype, textMarshaler) {
return Expr("github.com/si3nloong/sqlgen/sequel/types.TextMarshaler({{goPath}})").Format(importPkgs, ExprParams{GoPath: goPath, IsPtr: isPtr})
return Expr("github.com/si3nloong/sqlgen/sequel/types.TextMarshaler({{goPath}})").Format(importPkgs, ExprParams{GoPath: goPath, Type: t, IsPtr: isPtr})
}
return Expr(g.defaultColumnTypes["*"].Valuer).Format(importPkgs, ExprParams{GoPath: goPath, IsPtr: isPtr})
return Expr(g.defaultColumnTypes["*"].Valuer).Format(importPkgs, ExprParams{GoPath: goPath, Type: t, IsPtr: isPtr})
}

func (g *Generator) scanner(importPkgs *Package, goPath string, t types.Type) string {
ptr, isPtr := pointerType(t)
if columnType, ok := g.columnTypes[t.String()]; ok && columnType.Scanner != "" {
return Expr(columnType.Scanner).Format(importPkgs, ExprParams{GoPath: goPath, IsPtr: isPtr})
return Expr(columnType.Scanner).Format(importPkgs, ExprParams{GoPath: goPath, Type: t, IsPtr: isPtr})
} else if isImplemented(ptr, goSqlScanner) {
if isPtr {
return Expr("(database/sql.Scanner)({{goPath}})").Format(importPkgs, ExprParams{GoPath: goPath, IsPtr: isPtr})
return Expr("(database/sql.Scanner)({{goPath}})").Format(importPkgs, ExprParams{GoPath: goPath, Type: t, IsPtr: isPtr})
}
return Expr("(database/sql.Scanner)({{addrOfGoPath}})").Format(importPkgs, ExprParams{GoPath: goPath, IsPtr: isPtr})
return Expr("(database/sql.Scanner)({{addrOfGoPath}})").Format(importPkgs, ExprParams{GoPath: goPath, Type: t, IsPtr: isPtr})
} else if columnType, ok := g.columnDataType(t); ok && columnType.Scanner != "" {
return Expr(columnType.Scanner).Format(importPkgs, ExprParams{GoPath: goPath, IsPtr: isPtr, Len: arraySize(t)})
return Expr(columnType.Scanner).Format(importPkgs, ExprParams{GoPath: goPath, Type: t, IsPtr: isPtr, Len: arraySize(t)})
} else if isImplemented(ptr, textUnmarshaler) {
if isPtr {
return Expr("github.com/si3nloong/sqlgen/sequel/types.TextUnmarshaler({{goPath}})").Format(importPkgs, ExprParams{GoPath: goPath, IsPtr: isPtr})
return Expr("github.com/si3nloong/sqlgen/sequel/types.TextUnmarshaler({{goPath}})").Format(importPkgs, ExprParams{GoPath: goPath, Type: t, IsPtr: isPtr})
}
return Expr("github.com/si3nloong/sqlgen/sequel/types.TextUnmarshaler({{addrOfGoPath}})").Format(importPkgs, ExprParams{GoPath: goPath, IsPtr: isPtr})
return Expr("github.com/si3nloong/sqlgen/sequel/types.TextUnmarshaler({{addrOfGoPath}})").Format(importPkgs, ExprParams{GoPath: goPath, Type: t, IsPtr: isPtr})
}
return Expr(g.defaultColumnTypes["*"].Scanner).Format(importPkgs, ExprParams{GoPath: goPath, IsPtr: isPtr})
return Expr(g.defaultColumnTypes["*"].Scanner).Format(importPkgs, ExprParams{GoPath: goPath, Type: t, IsPtr: isPtr})
}

func (g *Generator) sqlScanner(f *columnInfo) string {
Expand Down
24 changes: 19 additions & 5 deletions codegen/dialect/dialect.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ type Dialect interface {
// Column data types
ColumnDataTypes() map[string]*ColumnType

Migrate(ctx context.Context, dsn string, w Writer, schema Schema) error
Migrate(ctx context.Context, dsn string, w Writer, m TableMigrator) error
}

type ColumnType struct {
Expand All @@ -55,12 +55,20 @@ type Index interface {
Unique() bool
}

type Schema interface {
type TableMigrator interface {
DBName() string

// Table name
TableName() string

// Return the columns of the table
Columns() []string
Keys() []string
ColumnGoType(i int) GoColumn

// Return the table primary key
PK() []string

ColumnByIndex(i int) GoColumn

RangeIndex(func(Index, int))
}

Expand All @@ -86,7 +94,7 @@ type GoColumn interface {
DataType() string

// SQL default value, this can be
// string, bool, int64, float64, sql.RawBytes
// string, []byte, bool, int64, float64, sql.RawBytes
Default() (driver.Value, bool)

// Determine whether this column is auto increment or not
Expand All @@ -97,6 +105,12 @@ type GoColumn interface {

// Column size that declared by user
Size() int

// CharacterMaxLength() (int64, bool)

// NumericPrecision() (int64, bool)

// DatetimePrecision() (int64, bool)
}

func RegisterDialect(name string, d Dialect) {
Expand Down
2 changes: 1 addition & 1 deletion codegen/dialect/mysql/migration.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,6 @@ import (
"github.com/si3nloong/sqlgen/codegen/dialect"
)

func (s *mysqlDriver) Migrate(ctx context.Context, dsn string, w dialect.Writer, schema dialect.Schema) error {
func (s *mysqlDriver) Migrate(ctx context.Context, dsn string, w dialect.Writer, schema dialect.TableMigrator) error {
return nil
}
41 changes: 20 additions & 21 deletions codegen/dialect/postgres/column.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,18 @@ package postgres
import (
"context"
"database/sql"
"fmt"

"github.com/si3nloong/sqlgen/internal/sqltype"
)

type column struct {
Name string
DataType string
Default sql.RawBytes
IsNullable sqltype.Bool
CharacterMaxLength sql.NullInt64
NumericPrecision sql.NullInt64
DatetimePrecision sql.NullInt64
Name string
DataType string
Default sql.RawBytes
IsNullable sqltype.Bool
// CharacterMaxLength sql.NullInt64
// NumericPrecision sql.NullInt64
// DatetimePrecision sql.NullInt64
}

func (c column) Equal(v column) bool {
Expand All @@ -26,13 +25,13 @@ func (c column) Equal(v column) bool {
func (c column) ColumnType() string {
switch c.DataType {
case "varchar":
if c.CharacterMaxLength.Valid && c.CharacterMaxLength.Int64 > 0 {
return fmt.Sprintf("%s(%d)", c.DataType, c.CharacterMaxLength.Int64)
}
// if c.CharacterMaxLength.Valid && c.CharacterMaxLength.Int64 > 0 {
// return fmt.Sprintf("%s(%d)", c.DataType, c.CharacterMaxLength.Int64)
// }
case "timestamptz":
if c.DatetimePrecision.Valid && c.DatetimePrecision.Int64 > 0 {
return fmt.Sprintf("%s(%d)", c.DataType, c.DatetimePrecision.Int64)
}
// if c.DatetimePrecision.Valid && c.DatetimePrecision.Int64 > 0 {
// return fmt.Sprintf("%s(%d)", c.DataType, c.DatetimePrecision.Int64)
// }
}
return c.DataType
}
Expand All @@ -42,10 +41,10 @@ func (s *postgresDriver) tableColumns(ctx context.Context, sqlConn *sql.DB, dbNa
column_name,
column_default,
is_nullable,
udt_name,
character_maximum_length,
numeric_precision,
datetime_precision
udt_name
-- character_maximum_length,
-- numeric_precision,
-- datetime_precision
FROM
information_schema.columns
WHERE
Expand All @@ -66,9 +65,9 @@ ORDER BY
&col.Default,
&col.IsNullable,
&col.DataType,
&col.CharacterMaxLength,
&col.NumericPrecision,
&col.DatetimePrecision,
// &col.CharacterMaxLength,
// &col.NumericPrecision,
// &col.DatetimePrecision,
); err != nil {
return nil, err
}
Expand Down
64 changes: 32 additions & 32 deletions codegen/dialect/postgres/column_type.go
Original file line number Diff line number Diff line change
Expand Up @@ -150,8 +150,8 @@ func (s *postgresDriver) ColumnDataTypes() map[string]*dialect.ColumnType {
},
"[...]string": {
DataType: s.columnDataType("text[{{len}}]"),
Valuer: "github.com/si3nloong/sqlgen/sequel/encoding.MarshalStringSlice(({{goPath}})[:],[2]byte{'{','}'})",
Scanner: "github.com/lib/pq.Array(({{addrOfGoPath}})[:])",
Valuer: "github.com/si3nloong/sqlgen/sequel/types/pgtype.StringArray[{{elemType}}]({{goPath}}[:])",
Scanner: "github.com/si3nloong/sqlgen/sequel/types/pgtype.StringArray[{{elemType}}]({{goPath}}[:])",
},
"[...]bool": {
DataType: s.columnDataType("bool[{{len}}]"),
Expand All @@ -160,8 +160,8 @@ func (s *postgresDriver) ColumnDataTypes() map[string]*dialect.ColumnType {
},
"[]string": {
DataType: s.columnDataType("text[]"),
Valuer: "github.com/si3nloong/sqlgen/sequel/encoding.MarshalStringSlice({{goPath}},[2]byte{'{', '}'})",
Scanner: "github.com/lib/pq.Array({{addrOfGoPath}})",
Valuer: "github.com/si3nloong/sqlgen/sequel/types/pgtype.StringArray[{{elemType}}]({{goPath}})",
Scanner: "github.com/si3nloong/sqlgen/sequel/types/pgtype.StringArrayScanner({{addrOfGoPath}})",
},
"[]rune": {
DataType: s.columnDataType("text"),
Expand All @@ -175,73 +175,73 @@ func (s *postgresDriver) ColumnDataTypes() map[string]*dialect.ColumnType {
},
"[]bool": {
DataType: s.columnDataType("bool[]"),
Valuer: "github.com/lib/pq.Array({{goPath}})",
Scanner: "github.com/lib/pq.Array({{addrOfGoPath}})",
Valuer: "github.com/si3nloong/sqlgen/sequel/types/pgtype.BoolArray[{{elemType}}]({{goPath}})",
Scanner: "github.com/si3nloong/sqlgen/sequel/types/pgtype.BoolArrayScanner({{addrOfGoPath}})",
},
"[][]byte": {
DataType: s.columnDataType("bytea"),
Valuer: "github.com/lib/pq.Array({{goPath}})",
Scanner: "github.com/lib/pq.Array({{addrOfGoPath}})",
Valuer: "github.com/si3nloong/sqlgen/sequel/types/pgtype.ByteArray({{goPath}})",
Scanner: "github.com/si3nloong/sqlgen/sequel/types/pgtype.ByteArray({{goPath}})",
},
"[]float32": {
DataType: s.columnDataType("double[]"),
Valuer: "github.com/lib/pq.Array({{goPath}})",
Scanner: "github.com/lib/pq.Array({{addrOfGoPath}})",
Valuer: "github.com/si3nloong/sqlgen/sequel/types/pgtype.Float32Array[{{elemType}}]({{goPath}})",
Scanner: "github.com/si3nloong/sqlgen/sequel/types/pgtype.Float32Array[{{elemType}}]({{goPath}})",
},
"[]float64": {
DataType: s.columnDataType("double[]"),
Valuer: "github.com/lib/pq.Array({{goPath}})",
Scanner: "github.com/lib/pq.Array({{addrOfGoPath}})",
Valuer: "github.com/si3nloong/sqlgen/sequel/types/pgtype.Float64Array[{{elemType}}]({{goPath}})",
Scanner: "github.com/si3nloong/sqlgen/sequel/types/pgtype.Float64Array[{{elemType}}]({{goPath}})",
},
"[]int": {
DataType: s.columnDataType("int4[]"),
Valuer: "github.com/lib/pq.Array({{goPath}})",
Scanner: "github.com/lib/pq.Array({{addrOfGoPath}})",
Valuer: "github.com/si3nloong/sqlgen/sequel/types/pgtype.IntArrayValue({{goPath}})",
Scanner: "github.com/si3nloong/sqlgen/sequel/types/pgtype.IntArrayScanner({{addrOfGoPath}})",
},
"[]int8": {
DataType: s.columnDataType("int2[]"),
Valuer: "github.com/lib/pq.Array({{goPath}})",
Scanner: "github.com/lib/pq.Array({{addrOfGoPath}})",
Valuer: "github.com/si3nloong/sqlgen/sequel/types/pgtype.IntArrayValue({{goPath}})",
Scanner: "github.com/si3nloong/sqlgen/sequel/types/pgtype.IntArrayScanner({{addrOfGoPath}})",
},
"[]int16": {
DataType: s.columnDataType("int2[]"),
Valuer: "github.com/lib/pq.Array({{goPath}})",
Scanner: "github.com/lib/pq.Array({{addrOfGoPath}})",
Valuer: "github.com/si3nloong/sqlgen/sequel/types/pgtype.Int16Array[{{elemType}}]({{goPath}})",
Scanner: "github.com/si3nloong/sqlgen/sequel/types/pgtype.IntArrayScanner({{addrOfGoPath}})",
},
"[]int32": {
DataType: s.columnDataType("int4[]"),
Valuer: "github.com/lib/pq.Array({{goPath}})",
Scanner: "github.com/lib/pq.Array({{addrOfGoPath}})",
Valuer: "github.com/si3nloong/sqlgen/sequel/types/pgtype.Int32Array[{{elemType}}]({{goPath}})",
Scanner: "github.com/si3nloong/sqlgen/sequel/types/pgtype.IntArrayScanner({{addrOfGoPath}})",
},
"[]int64": {
DataType: s.columnDataType("int8[]"),
Valuer: "github.com/lib/pq.Array({{goPath}})",
Scanner: "github.com/lib/pq.Array({{addrOfGoPath}})",
Valuer: "github.com/si3nloong/sqlgen/sequel/types/pgtype.Int64Array[{{elemType}}]({{goPath}})",
Scanner: "github.com/si3nloong/sqlgen/sequel/types/pgtype.IntArrayScanner({{addrOfGoPath}})",
},
"[]uint": {
DataType: s.columnDataType("int4[]"),
Valuer: "github.com/lib/pq.Array({{goPath}})",
Scanner: "github.com/lib/pq.Array({{addrOfGoPath}})",
Valuer: "github.com/si3nloong/sqlgen/sequel/types/pgtype.UintArrayValue({{goPath}})",
Scanner: "github.com/si3nloong/sqlgen/sequel/types/pgtype.UintArrayScanner({{addrOfGoPath}})",
},
"[]uint8": {
DataType: s.columnDataType("int2[]"),
Valuer: "github.com/lib/pq.Array({{goPath}})",
Scanner: "github.com/lib/pq.Array({{addrOfGoPath}})",
Valuer: "github.com/si3nloong/sqlgen/sequel/types/pgtype.UintArrayValue({{goPath}})",
Scanner: "github.com/si3nloong/sqlgen/sequel/types/pgtype.UintArrayScanner({{addrOfGoPath}})",
},
"[]uint16": {
DataType: s.columnDataType("int2[]"),
Valuer: "github.com/lib/pq.Array({{goPath}})",
Scanner: "github.com/lib/pq.Array({{addrOfGoPath}})",
Valuer: "github.com/si3nloong/sqlgen/sequel/types/pgtype.UintArrayValue({{goPath}})",
Scanner: "github.com/si3nloong/sqlgen/sequel/types/pgtype.UintArrayScanner({{addrOfGoPath}})",
},
"[]uint32": {
DataType: s.columnDataType("int4[]"),
Valuer: "github.com/lib/pq.Array({{goPath}})",
Scanner: "github.com/lib/pq.Array({{addrOfGoPath}})",
Valuer: "github.com/si3nloong/sqlgen/sequel/types/pgtype.UintArrayValue({{goPath}})",
Scanner: "github.com/si3nloong/sqlgen/sequel/types/pgtype.UintArrayScanner({{addrOfGoPath}})",
},
"[]uint64": {
DataType: s.columnDataType("int8[]"),
Valuer: "github.com/lib/pq.Array({{goPath}})",
Scanner: "github.com/lib/pq.Array({{addrOfGoPath}})",
Valuer: "github.com/si3nloong/sqlgen/sequel/types/pgtype.UintArrayValue({{goPath}})",
Scanner: "github.com/si3nloong/sqlgen/sequel/types/pgtype.UintArrayScanner({{addrOfGoPath}})",
},
"*": {
DataType: s.columnDataType("json"),
Expand Down
Loading

0 comments on commit 95c7ce4

Please sign in to comment.