From 95c7ce4598c76d8ef2dacf8c0af9091cb49aed71 Mon Sep 17 00:00:00 2001 From: si3nloong Date: Fri, 27 Sep 2024 12:45:22 +0800 Subject: [PATCH] fix: types --- codegen/code_generator.go | 26 +-- codegen/dialect/dialect.go | 24 ++- codegen/dialect/mysql/migration.go | 2 +- codegen/dialect/postgres/column.go | 41 +++-- codegen/dialect/postgres/column_type.go | 64 +++---- codegen/dialect/postgres/migration.go | 75 +++++---- codegen/dialect/sqlite/migration.go | 2 +- codegen/expr.go | 23 ++- codegen/types.go | 4 +- sequel/types/int.go | 20 +-- sequel/types/pgtype/array.go | 211 ++++++++++++++++++++++++ sequel/types/pgtype/bool_array.go | 139 ++++++++++++++++ sequel/types/pgtype/byte_array.go | 129 +++++++++++++++ sequel/types/pgtype/float_array.go | 133 +++++++++++++++ sequel/types/pgtype/int_array.go | 113 +++++++++++++ sequel/types/pgtype/string_array.go | 109 ++++++++++++ sequel/types/pgtype/uint_array.go | 89 ++++++++++ sequel/types/string.go | 8 +- sequel/types/time.go | 2 +- sequel/types/types.go | 12 +- 20 files changed, 1097 insertions(+), 129 deletions(-) create mode 100644 sequel/types/pgtype/array.go create mode 100644 sequel/types/pgtype/bool_array.go create mode 100644 sequel/types/pgtype/byte_array.go create mode 100644 sequel/types/pgtype/float_array.go create mode 100644 sequel/types/pgtype/int_array.go create mode 100644 sequel/types/pgtype/string_array.go create mode 100644 sequel/types/pgtype/uint_array.go diff --git a/codegen/code_generator.go b/codegen/code_generator.go index eda7f07..dc824ba 100644 --- a/codegen/code_generator.go +++ b/codegen/code_generator.go @@ -614,38 +614,38 @@ func (g *Generator) buildUpdateByPK(importPkgs *Package, t *tableInfo) { func (g *Generator) valuer(importPkgs *Package, goPath string, t types.Type) string { utype, isPtr := underlyingType(t) if columnType, ok := g.columnTypes[t.String()]; ok && columnType.Valuer != "" { - return Expr(columnType.Valuer).Format(importPkgs, ExprParams{GoPath: goPath, IsPtr: isPtr}) + return Expr(columnType.Valuer).Format(importPkgs, ExprParams{GoPath: goPath, Type: t, IsPtr: isPtr}) } else if _, wrong := types.MissingMethod(utype, goSqlValuer, true); wrong { if isPtr { - return Expr("(database/sql/driver.Valuer)({{goPath}})").Format(importPkgs, ExprParams{GoPath: goPath, IsPtr: isPtr}) + return Expr("(database/sql/driver.Valuer)({{goPath}})").Format(importPkgs, ExprParams{GoPath: goPath, Type: t, IsPtr: isPtr}) } - return Expr("(database/sql/driver.Valuer)({{goPath}})").Format(importPkgs, ExprParams{GoPath: goPath, IsPtr: isPtr}) + return Expr("(database/sql/driver.Valuer)({{goPath}})").Format(importPkgs, ExprParams{GoPath: goPath, Type: t, IsPtr: isPtr}) } else if columnType, ok := g.columnDataType(t); ok && columnType.Valuer != "" { - return Expr(columnType.Valuer).Format(importPkgs, ExprParams{GoPath: goPath, IsPtr: isPtr, Len: arraySize(t)}) + return Expr(columnType.Valuer).Format(importPkgs, ExprParams{GoPath: goPath, IsPtr: isPtr, Type: t, Len: arraySize(t)}) } else if isImplemented(utype, textMarshaler) { - return Expr("github.com/si3nloong/sqlgen/sequel/types.TextMarshaler({{goPath}})").Format(importPkgs, ExprParams{GoPath: goPath, IsPtr: isPtr}) + return Expr("github.com/si3nloong/sqlgen/sequel/types.TextMarshaler({{goPath}})").Format(importPkgs, ExprParams{GoPath: goPath, Type: t, IsPtr: isPtr}) } - return Expr(g.defaultColumnTypes["*"].Valuer).Format(importPkgs, ExprParams{GoPath: goPath, IsPtr: isPtr}) + return Expr(g.defaultColumnTypes["*"].Valuer).Format(importPkgs, ExprParams{GoPath: goPath, Type: 
t, IsPtr: isPtr}) } func (g *Generator) scanner(importPkgs *Package, goPath string, t types.Type) string { ptr, isPtr := pointerType(t) if columnType, ok := g.columnTypes[t.String()]; ok && columnType.Scanner != "" { - return Expr(columnType.Scanner).Format(importPkgs, ExprParams{GoPath: goPath, IsPtr: isPtr}) + return Expr(columnType.Scanner).Format(importPkgs, ExprParams{GoPath: goPath, Type: t, IsPtr: isPtr}) } else if isImplemented(ptr, goSqlScanner) { if isPtr { - return Expr("(database/sql.Scanner)({{goPath}})").Format(importPkgs, ExprParams{GoPath: goPath, IsPtr: isPtr}) + return Expr("(database/sql.Scanner)({{goPath}})").Format(importPkgs, ExprParams{GoPath: goPath, Type: t, IsPtr: isPtr}) } - return Expr("(database/sql.Scanner)({{addrOfGoPath}})").Format(importPkgs, ExprParams{GoPath: goPath, IsPtr: isPtr}) + return Expr("(database/sql.Scanner)({{addrOfGoPath}})").Format(importPkgs, ExprParams{GoPath: goPath, Type: t, IsPtr: isPtr}) } else if columnType, ok := g.columnDataType(t); ok && columnType.Scanner != "" { - return Expr(columnType.Scanner).Format(importPkgs, ExprParams{GoPath: goPath, IsPtr: isPtr, Len: arraySize(t)}) + return Expr(columnType.Scanner).Format(importPkgs, ExprParams{GoPath: goPath, Type: t, IsPtr: isPtr, Len: arraySize(t)}) } else if isImplemented(ptr, textUnmarshaler) { if isPtr { - return Expr("github.com/si3nloong/sqlgen/sequel/types.TextUnmarshaler({{goPath}})").Format(importPkgs, ExprParams{GoPath: goPath, IsPtr: isPtr}) + return Expr("github.com/si3nloong/sqlgen/sequel/types.TextUnmarshaler({{goPath}})").Format(importPkgs, ExprParams{GoPath: goPath, Type: t, IsPtr: isPtr}) } - return Expr("github.com/si3nloong/sqlgen/sequel/types.TextUnmarshaler({{addrOfGoPath}})").Format(importPkgs, ExprParams{GoPath: goPath, IsPtr: isPtr}) + return Expr("github.com/si3nloong/sqlgen/sequel/types.TextUnmarshaler({{addrOfGoPath}})").Format(importPkgs, ExprParams{GoPath: goPath, Type: t, IsPtr: isPtr}) } - return Expr(g.defaultColumnTypes["*"].Scanner).Format(importPkgs, ExprParams{GoPath: goPath, IsPtr: isPtr}) + return Expr(g.defaultColumnTypes["*"].Scanner).Format(importPkgs, ExprParams{GoPath: goPath, Type: t, IsPtr: isPtr}) } func (g *Generator) sqlScanner(f *columnInfo) string { diff --git a/codegen/dialect/dialect.go b/codegen/dialect/dialect.go index 59d43c3..bf2aacb 100644 --- a/codegen/dialect/dialect.go +++ b/codegen/dialect/dialect.go @@ -36,7 +36,7 @@ type Dialect interface { // Column data types ColumnDataTypes() map[string]*ColumnType - Migrate(ctx context.Context, dsn string, w Writer, schema Schema) error + Migrate(ctx context.Context, dsn string, w Writer, m TableMigrator) error } type ColumnType struct { @@ -55,12 +55,20 @@ type Index interface { Unique() bool } -type Schema interface { +type TableMigrator interface { DBName() string + + // Table name TableName() string + + // Return the columns of the table Columns() []string - Keys() []string - ColumnGoType(i int) GoColumn + + // Return the table primary key + PK() []string + + ColumnByIndex(i int) GoColumn + RangeIndex(func(Index, int)) } @@ -86,7 +94,7 @@ type GoColumn interface { DataType() string // SQL default value, this can be - // string, bool, int64, float64, sql.RawBytes + // string, []byte, bool, int64, float64, sql.RawBytes Default() (driver.Value, bool) // Determine whether this column is auto increment or not @@ -97,6 +105,12 @@ type GoColumn interface { // Column size that declared by user Size() int + + // CharacterMaxLength() (int64, bool) + + // NumericPrecision() (int64, bool) + + 
// DatetimePrecision() (int64, bool) } func RegisterDialect(name string, d Dialect) { diff --git a/codegen/dialect/mysql/migration.go b/codegen/dialect/mysql/migration.go index f8a3b68..7b442d8 100644 --- a/codegen/dialect/mysql/migration.go +++ b/codegen/dialect/mysql/migration.go @@ -6,6 +6,6 @@ import ( "github.com/si3nloong/sqlgen/codegen/dialect" ) -func (s *mysqlDriver) Migrate(ctx context.Context, dsn string, w dialect.Writer, schema dialect.Schema) error { +func (s *mysqlDriver) Migrate(ctx context.Context, dsn string, w dialect.Writer, schema dialect.TableMigrator) error { return nil } diff --git a/codegen/dialect/postgres/column.go b/codegen/dialect/postgres/column.go index 5390e94..a2a8a84 100644 --- a/codegen/dialect/postgres/column.go +++ b/codegen/dialect/postgres/column.go @@ -3,19 +3,18 @@ package postgres import ( "context" "database/sql" - "fmt" "github.com/si3nloong/sqlgen/internal/sqltype" ) type column struct { - Name string - DataType string - Default sql.RawBytes - IsNullable sqltype.Bool - CharacterMaxLength sql.NullInt64 - NumericPrecision sql.NullInt64 - DatetimePrecision sql.NullInt64 + Name string + DataType string + Default sql.RawBytes + IsNullable sqltype.Bool + // CharacterMaxLength sql.NullInt64 + // NumericPrecision sql.NullInt64 + // DatetimePrecision sql.NullInt64 } func (c column) Equal(v column) bool { @@ -26,13 +25,13 @@ func (c column) Equal(v column) bool { func (c column) ColumnType() string { switch c.DataType { case "varchar": - if c.CharacterMaxLength.Valid && c.CharacterMaxLength.Int64 > 0 { - return fmt.Sprintf("%s(%d)", c.DataType, c.CharacterMaxLength.Int64) - } + // if c.CharacterMaxLength.Valid && c.CharacterMaxLength.Int64 > 0 { + // return fmt.Sprintf("%s(%d)", c.DataType, c.CharacterMaxLength.Int64) + // } case "timestamptz": - if c.DatetimePrecision.Valid && c.DatetimePrecision.Int64 > 0 { - return fmt.Sprintf("%s(%d)", c.DataType, c.DatetimePrecision.Int64) - } + // if c.DatetimePrecision.Valid && c.DatetimePrecision.Int64 > 0 { + // return fmt.Sprintf("%s(%d)", c.DataType, c.DatetimePrecision.Int64) + // } } return c.DataType } @@ -42,10 +41,10 @@ func (s *postgresDriver) tableColumns(ctx context.Context, sqlConn *sql.DB, dbNa column_name, column_default, is_nullable, - udt_name, - character_maximum_length, - numeric_precision, - datetime_precision + udt_name + -- character_maximum_length, + -- numeric_precision, + -- datetime_precision FROM information_schema.columns WHERE @@ -66,9 +65,9 @@ ORDER BY &col.Default, &col.IsNullable, &col.DataType, - &col.CharacterMaxLength, - &col.NumericPrecision, - &col.DatetimePrecision, + // &col.CharacterMaxLength, + // &col.NumericPrecision, + // &col.DatetimePrecision, ); err != nil { return nil, err } diff --git a/codegen/dialect/postgres/column_type.go b/codegen/dialect/postgres/column_type.go index 96fec39..657b5fa 100644 --- a/codegen/dialect/postgres/column_type.go +++ b/codegen/dialect/postgres/column_type.go @@ -150,8 +150,8 @@ func (s *postgresDriver) ColumnDataTypes() map[string]*dialect.ColumnType { }, "[...]string": { DataType: s.columnDataType("text[{{len}}]"), - Valuer: "github.com/si3nloong/sqlgen/sequel/encoding.MarshalStringSlice(({{goPath}})[:],[2]byte{'{','}'})", - Scanner: "github.com/lib/pq.Array(({{addrOfGoPath}})[:])", + Valuer: "github.com/si3nloong/sqlgen/sequel/types/pgtype.StringArray[{{elemType}}]({{goPath}}[:])", + Scanner: "github.com/si3nloong/sqlgen/sequel/types/pgtype.StringArray[{{elemType}}]({{goPath}}[:])", }, "[...]bool": { DataType: 
s.columnDataType("bool[{{len}}]"), @@ -160,8 +160,8 @@ func (s *postgresDriver) ColumnDataTypes() map[string]*dialect.ColumnType { }, "[]string": { DataType: s.columnDataType("text[]"), - Valuer: "github.com/si3nloong/sqlgen/sequel/encoding.MarshalStringSlice({{goPath}},[2]byte{'{', '}'})", - Scanner: "github.com/lib/pq.Array({{addrOfGoPath}})", + Valuer: "github.com/si3nloong/sqlgen/sequel/types/pgtype.StringArray[{{elemType}}]({{goPath}})", + Scanner: "github.com/si3nloong/sqlgen/sequel/types/pgtype.StringArrayScanner({{addrOfGoPath}})", }, "[]rune": { DataType: s.columnDataType("text"), @@ -175,73 +175,73 @@ func (s *postgresDriver) ColumnDataTypes() map[string]*dialect.ColumnType { }, "[]bool": { DataType: s.columnDataType("bool[]"), - Valuer: "github.com/lib/pq.Array({{goPath}})", - Scanner: "github.com/lib/pq.Array({{addrOfGoPath}})", + Valuer: "github.com/si3nloong/sqlgen/sequel/types/pgtype.BoolArray[{{elemType}}]({{goPath}})", + Scanner: "github.com/si3nloong/sqlgen/sequel/types/pgtype.BoolArrayScanner({{addrOfGoPath}})", }, "[][]byte": { DataType: s.columnDataType("bytea"), - Valuer: "github.com/lib/pq.Array({{goPath}})", - Scanner: "github.com/lib/pq.Array({{addrOfGoPath}})", + Valuer: "github.com/si3nloong/sqlgen/sequel/types/pgtype.ByteArray({{goPath}})", + Scanner: "github.com/si3nloong/sqlgen/sequel/types/pgtype.ByteArray({{goPath}})", }, "[]float32": { DataType: s.columnDataType("double[]"), - Valuer: "github.com/lib/pq.Array({{goPath}})", - Scanner: "github.com/lib/pq.Array({{addrOfGoPath}})", + Valuer: "github.com/si3nloong/sqlgen/sequel/types/pgtype.Float32Array[{{elemType}}]({{goPath}})", + Scanner: "github.com/si3nloong/sqlgen/sequel/types/pgtype.Float32Array[{{elemType}}]({{goPath}})", }, "[]float64": { DataType: s.columnDataType("double[]"), - Valuer: "github.com/lib/pq.Array({{goPath}})", - Scanner: "github.com/lib/pq.Array({{addrOfGoPath}})", + Valuer: "github.com/si3nloong/sqlgen/sequel/types/pgtype.Float64Array[{{elemType}}]({{goPath}})", + Scanner: "github.com/si3nloong/sqlgen/sequel/types/pgtype.Float64Array[{{elemType}}]({{goPath}})", }, "[]int": { DataType: s.columnDataType("int4[]"), - Valuer: "github.com/lib/pq.Array({{goPath}})", - Scanner: "github.com/lib/pq.Array({{addrOfGoPath}})", + Valuer: "github.com/si3nloong/sqlgen/sequel/types/pgtype.IntArrayValue({{goPath}})", + Scanner: "github.com/si3nloong/sqlgen/sequel/types/pgtype.IntArrayScanner({{addrOfGoPath}})", }, "[]int8": { DataType: s.columnDataType("int2[]"), - Valuer: "github.com/lib/pq.Array({{goPath}})", - Scanner: "github.com/lib/pq.Array({{addrOfGoPath}})", + Valuer: "github.com/si3nloong/sqlgen/sequel/types/pgtype.IntArrayValue({{goPath}})", + Scanner: "github.com/si3nloong/sqlgen/sequel/types/pgtype.IntArrayScanner({{addrOfGoPath}})", }, "[]int16": { DataType: s.columnDataType("int2[]"), - Valuer: "github.com/lib/pq.Array({{goPath}})", - Scanner: "github.com/lib/pq.Array({{addrOfGoPath}})", + Valuer: "github.com/si3nloong/sqlgen/sequel/types/pgtype.Int16Array[{{elemType}}]({{goPath}})", + Scanner: "github.com/si3nloong/sqlgen/sequel/types/pgtype.IntArrayScanner({{addrOfGoPath}})", }, "[]int32": { DataType: s.columnDataType("int4[]"), - Valuer: "github.com/lib/pq.Array({{goPath}})", - Scanner: "github.com/lib/pq.Array({{addrOfGoPath}})", + Valuer: "github.com/si3nloong/sqlgen/sequel/types/pgtype.Int32Array[{{elemType}}]({{goPath}})", + Scanner: "github.com/si3nloong/sqlgen/sequel/types/pgtype.IntArrayScanner({{addrOfGoPath}})", }, "[]int64": { DataType: s.columnDataType("int8[]"), - Valuer: 
"github.com/lib/pq.Array({{goPath}})", - Scanner: "github.com/lib/pq.Array({{addrOfGoPath}})", + Valuer: "github.com/si3nloong/sqlgen/sequel/types/pgtype.Int64Array[{{elemType}}]({{goPath}})", + Scanner: "github.com/si3nloong/sqlgen/sequel/types/pgtype.IntArrayScanner({{addrOfGoPath}})", }, "[]uint": { DataType: s.columnDataType("int4[]"), - Valuer: "github.com/lib/pq.Array({{goPath}})", - Scanner: "github.com/lib/pq.Array({{addrOfGoPath}})", + Valuer: "github.com/si3nloong/sqlgen/sequel/types/pgtype.UintArrayValue({{goPath}})", + Scanner: "github.com/si3nloong/sqlgen/sequel/types/pgtype.UintArrayScanner({{addrOfGoPath}})", }, "[]uint8": { DataType: s.columnDataType("int2[]"), - Valuer: "github.com/lib/pq.Array({{goPath}})", - Scanner: "github.com/lib/pq.Array({{addrOfGoPath}})", + Valuer: "github.com/si3nloong/sqlgen/sequel/types/pgtype.UintArrayValue({{goPath}})", + Scanner: "github.com/si3nloong/sqlgen/sequel/types/pgtype.UintArrayScanner({{addrOfGoPath}})", }, "[]uint16": { DataType: s.columnDataType("int2[]"), - Valuer: "github.com/lib/pq.Array({{goPath}})", - Scanner: "github.com/lib/pq.Array({{addrOfGoPath}})", + Valuer: "github.com/si3nloong/sqlgen/sequel/types/pgtype.UintArrayValue({{goPath}})", + Scanner: "github.com/si3nloong/sqlgen/sequel/types/pgtype.UintArrayScanner({{addrOfGoPath}})", }, "[]uint32": { DataType: s.columnDataType("int4[]"), - Valuer: "github.com/lib/pq.Array({{goPath}})", - Scanner: "github.com/lib/pq.Array({{addrOfGoPath}})", + Valuer: "github.com/si3nloong/sqlgen/sequel/types/pgtype.UintArrayValue({{goPath}})", + Scanner: "github.com/si3nloong/sqlgen/sequel/types/pgtype.UintArrayScanner({{addrOfGoPath}})", }, "[]uint64": { DataType: s.columnDataType("int8[]"), - Valuer: "github.com/lib/pq.Array({{goPath}})", - Scanner: "github.com/lib/pq.Array({{addrOfGoPath}})", + Valuer: "github.com/si3nloong/sqlgen/sequel/types/pgtype.UintArrayValue({{goPath}})", + Scanner: "github.com/si3nloong/sqlgen/sequel/types/pgtype.UintArrayScanner({{addrOfGoPath}})", }, "*": { DataType: s.columnDataType("json"), diff --git a/codegen/dialect/postgres/migration.go b/codegen/dialect/postgres/migration.go index fe1b63c..b628b99 100644 --- a/codegen/dialect/postgres/migration.go +++ b/codegen/dialect/postgres/migration.go @@ -11,7 +11,7 @@ import ( "github.com/si3nloong/sqlgen/internal/sqltype" ) -func (s *postgresDriver) Migrate(ctx context.Context, dsn string, w dialect.Writer, schema dialect.Schema) error { +func (s *postgresDriver) Migrate(ctx context.Context, dsn string, w dialect.Writer, schema dialect.TableMigrator) error { sqlConn, err := sql.Open("pgx", dsn) if err != nil { return err @@ -35,15 +35,27 @@ func (s *postgresDriver) Migrate(ctx context.Context, dsn string, w dialect.Writ // Check existing columns and new columns is not matching // If it's not matching then we need to do alter table columns := schema.Columns() + // log.Println(existedColumns) + // log.Println(reflect.DeepEqual(existedColumns, lo.Map(columns, func(_ string, i int) column { + // field := schema.ColumnByIndex(i) + // return column{ + // Name: field.ColumnName(), + // DataType: field.DataType(), + // IsNullable: sqltype.Bool(field.GoNullable()), + // // CharacterMaxLength: toNullInt64(field.CharacterMaxLength()), + // // NumericPrecision: toNullInt64(field.NumericPrecision()), + // // DatetimePrecision: toNullInt64(field.DatetimePrecision()), + // } + // }))) if reflect.DeepEqual(existedColumns, lo.Map(columns, func(_ string, i int) column { - field := schema.ColumnGoType(i) + field := 
schema.ColumnByIndex(i) return column{ - Name: field.ColumnName(), - DataType: field.DataType(), - IsNullable: sqltype.Bool(field.GoNullable()), - CharacterMaxLength: sql.NullInt64{}, - NumericPrecision: sql.NullInt64{}, - DatetimePrecision: sql.NullInt64{}, + Name: field.ColumnName(), + DataType: field.DataType(), + IsNullable: sqltype.Bool(field.GoNullable()), + // CharacterMaxLength: toNullInt64(field.CharacterMaxLength()), + // NumericPrecision: toNullInt64(field.NumericPrecision()), + // DatetimePrecision: toNullInt64(field.DatetimePrecision()), } })) { return dialect.ErrNoNewMigration @@ -56,7 +68,7 @@ func (s *postgresDriver) Migrate(ctx context.Context, dsn string, w dialect.Writ if i > 0 { w.WriteString(",\n") } - column := schema.ColumnGoType(i) + column := schema.ColumnByIndex(i) w.WriteString("\t" + s.QuoteIdentifier(column.ColumnName()) + " " + column.DataType()) if !column.GoNullable() { w.WriteString(" NOT NULL") @@ -66,13 +78,14 @@ func (s *postgresDriver) Migrate(ctx context.Context, dsn string, w dialect.Writ } } - if keys := schema.Keys(); len(keys) > 0 { - w.WriteString(",\n\tPRIMARY KEY (") - for i := range keys { + pks := schema.PK() + if len(pks) > 0 { + w.WriteString(",\n\tCONSTRAINT " + s.QuoteIdentifier(indexName(pks, pk)) + " PRIMARY KEY (") + for i := range pks { if i > 0 { - w.WriteString("," + s.QuoteIdentifier(keys[i])) + w.WriteString("," + s.QuoteIdentifier(pks[i])) } else { - w.WriteString(s.QuoteIdentifier(keys[i])) + w.WriteString(s.QuoteIdentifier(pks[i])) } } w.WriteString(")") @@ -105,7 +118,7 @@ func (s *postgresDriver) Migrate(ctx context.Context, dsn string, w dialect.Writ down.WriteString(",\n") } - col := schema.ColumnGoType(i) + col := schema.ColumnByIndex(i) if prevColumn, idx, _ := lo.FindIndexOf(existedColumns, func(v column) bool { return v.Name == col.ColumnName() }); idx < 0 { @@ -145,31 +158,33 @@ func (s *postgresDriver) Migrate(ctx context.Context, dsn string, w dialect.Writ return err } - keys := schema.Keys() - if existedIndex, _, ok := lo.FindIndexOf(existedIndexes, func(v index) bool { + if pkIndex, _, ok := lo.FindIndexOf(existedIndexes, func(v index) bool { return v.IsPK - }); !ok && len(keys) > 0 { - up.WriteString(",\n\tADD CONSTRAINT " + s.QuoteIdentifier(indexName(keys, pk)) + " PRIMARY KEY (") - for i := range keys { + }); !ok && len(pks) > 0 { + up.WriteString(",\n\tADD CONSTRAINT " + s.QuoteIdentifier(indexName(pks, pk)) + " PRIMARY KEY (") + for i := range pks { if i > 0 { - up.WriteString("," + s.QuoteIdentifier(keys[i])) + up.WriteString("," + s.QuoteIdentifier(pks[i])) } else { - up.WriteString(s.QuoteIdentifier(keys[i])) + up.WriteString(s.QuoteIdentifier(pks[i])) } } up.WriteByte(')') down.WriteString(",\n\tDROP PRIMARY KEY") } else { // log.Println(existedIndexes, len(existedIndexes)) + // existedIndexes = append(existedIndexes[:idx], existedIndexes[idx+1:]...) 
- up.WriteString(",\n\tDROP CONSTRAINT " + s.QuoteIdentifier(existedIndex.Name)) - if len(keys) > 0 { - up.WriteString(",\n\tADD " + s.QuoteIdentifier(indexName(keys, pk)) + " PRIMARY KEY (") - for i := range keys { + up.WriteString(",\n\tDROP CONSTRAINT " + s.QuoteIdentifier(pkIndex.Name)) + // If the current primary key index name is not similar to the previous one + // we definitely need to replace it + if len(pks) > 0 && indexName(pks, pk) != pkIndex.Name { + up.WriteString(",\n\tADD " + s.QuoteIdentifier(indexName(pks, pk)) + " PRIMARY KEY (") + for i := range pks { if i > 0 { - up.WriteString("," + s.QuoteIdentifier(keys[i])) + up.WriteString("," + s.QuoteIdentifier(pks[i])) } else { - up.WriteString(s.QuoteIdentifier(keys[i])) + up.WriteString(s.QuoteIdentifier(pks[i])) } } up.WriteByte(')') @@ -212,3 +227,7 @@ func (s *postgresDriver) Migrate(ctx context.Context, dsn string, w dialect.Writ w.WriteString(down.String()) return nil } + +func toNullInt64(n int64, ok bool) sql.NullInt64 { + return sql.NullInt64{Int64: n, Valid: ok} +} diff --git a/codegen/dialect/sqlite/migration.go b/codegen/dialect/sqlite/migration.go index 843da6e..5c181bf 100644 --- a/codegen/dialect/sqlite/migration.go +++ b/codegen/dialect/sqlite/migration.go @@ -6,6 +6,6 @@ import ( "github.com/si3nloong/sqlgen/codegen/dialect" ) -func (s *sqliteDriver) Migrate(ctx context.Context, dsn string, w dialect.Writer, schema dialect.Schema) error { +func (s *sqliteDriver) Migrate(ctx context.Context, dsn string, w dialect.Writer, schema dialect.TableMigrator) error { return nil } diff --git a/codegen/expr.go b/codegen/expr.go index b3148ff..86edd17 100644 --- a/codegen/expr.go +++ b/codegen/expr.go @@ -33,6 +33,7 @@ type ExprParams struct { GoPath string IsPtr bool Len int64 + Type types.Type } func (e Expr) Format(pkg *Package, args ...ExprParams) string { @@ -52,6 +53,15 @@ func (e Expr) Format(pkg *Package, args ...ExprParams) string { "goPath": func() string { return params.GoPath }, + "elemType": func() string { + switch t := params.Type.(type) { + case *types.Array: + return importPkgIfNeeded(pkg, t.Elem().String()) + case *types.Slice: + return importPkgIfNeeded(pkg, t.Elem().String()) + } + return "" + }, "addrOfGoPath": func() string { if params.IsPtr { return params.GoPath @@ -75,14 +85,19 @@ func (e Expr) Format(pkg *Package, args ...ExprParams) string { panic(err) } str = buf.String() - matches := pkgRegexp.FindStringSubmatch(str) + str = importPkgIfNeeded(pkg, str) + return str +} + +func importPkgIfNeeded(pkg *Package, importPath string) string { + matches := pkgRegexp.FindStringSubmatch(importPath) if len(matches) > 0 { p, _ := pkg.Import(types.NewPackage(matches[1], filepath.Base(matches[1]))) if p != nil { - str = strings.Replace(str, matches[1], p.Name(), -1) + importPath = strings.Replace(importPath, matches[1], p.Name(), -1) } else { - str = strings.Replace(str, matches[1]+".", "", -1) + importPath = strings.Replace(importPath, matches[1]+".", "", -1) } } - return str + return importPath } diff --git a/codegen/types.go b/codegen/types.go index a1e4b11..78e5d3c 100644 --- a/codegen/types.go +++ b/codegen/types.go @@ -59,7 +59,7 @@ func (b *tableInfo) TableName() string { return b.tableName } -func (b *tableInfo) Keys() []string { +func (b *tableInfo) PK() []string { return lo.Map(b.keys, func(c *columnInfo, _ int) string { return c.columnName }) @@ -71,7 +71,7 @@ func (b *tableInfo) Columns() []string { }) } -func (b *tableInfo) ColumnGoType(i int) dialect.GoColumn { +func (b *tableInfo) 
ColumnByIndex(i int) dialect.GoColumn { return b.columns[i] } diff --git a/sequel/types/int.go b/sequel/types/int.go index e2f02b3..18732a3 100644 --- a/sequel/types/int.go +++ b/sequel/types/int.go @@ -1,7 +1,6 @@ package types import ( - "database/sql" "database/sql/driver" "fmt" "strconv" @@ -15,12 +14,7 @@ type intLike[T constraints.Integer] struct { strictType bool } -var ( - _ sql.Scanner = (*intLike[int])(nil) - _ driver.Valuer = (*intLike[int])(nil) -) - -func Integer[T constraints.Integer](addr *T, strict ...bool) *intLike[T] { +func Integer[T constraints.Integer](addr *T, strict ...bool) ValueScanner[T] { var strictType bool if len(strict) > 0 { strictType = strict[0] @@ -45,6 +39,12 @@ func (i intLike[T]) Value() (driver.Value, error) { func (i *intLike[T]) Scan(v any) error { var val T switch vi := v.(type) { + case []byte: + m, err := strconv.ParseInt(unsafe.String(unsafe.SliceData(vi), len(vi)), 10, 64) + if err != nil { + return err + } + val = T(m) case int64: val = T(vi) case nil: @@ -57,12 +57,6 @@ func (i *intLike[T]) Scan(v any) error { } switch vi := v.(type) { - case []byte: - m, err := strconv.ParseInt(unsafe.String(unsafe.SliceData(vi), len(vi)), 10, 64) - if err != nil { - return err - } - val = T(m) case string: m, err := strconv.ParseInt(string(vi), 10, 64) if err != nil { diff --git a/sequel/types/pgtype/array.go b/sequel/types/pgtype/array.go new file mode 100644 index 0000000..8262a56 --- /dev/null +++ b/sequel/types/pgtype/array.go @@ -0,0 +1,211 @@ +package pgtype + +import ( + "bytes" + "database/sql/driver" + "fmt" + "strconv" + "strings" + + "golang.org/x/exp/constraints" +) + +func arrayScan[T constraints.Integer, Arr interface{ ~[]T }](a *Arr, src any, t string) error { + switch src := src.(type) { + case []byte: + return scanBytes(a, src, t) + case string: + return scanBytes(a, ([]byte)(src), t) + case nil: + *a = nil + return nil + } + return fmt.Errorf("pgtype: cannot convert %T to %s", src, t) +} + +func intArrayValue[T constraints.Signed](a []T) (driver.Value, error) { + if a == nil { + return nil, nil + } + + if n := len(a); n > 0 { + // There will be at least two curly brackets, N bytes of values, + // and N-1 bytes of delimiters. + b := make([]byte, 1, 1+2*n) + b[0] = '{' + + b = strconv.AppendInt(b, (int64)(a[0]), 10) + for i := 1; i < n; i++ { + b = append(b, ',') + b = strconv.AppendInt(b, (int64)(a[i]), 10) + } + + return string(append(b, '}')), nil + } + return "{}", nil +} + +func uintArrayValue[T constraints.Unsigned](a []T) (driver.Value, error) { + if a == nil { + return nil, nil + } + + if n := len(a); n > 0 { + // There will be at least two curly brackets, N bytes of values, + // and N-1 bytes of delimiters. + b := make([]byte, 1, 1+2*n) + b[0] = '{' + + b = strconv.AppendUint(b, (uint64)(a[0]), 10) + for i := 1; i < n; i++ { + b = append(b, ',') + b = strconv.AppendUint(b, (uint64)(a[i]), 10) + } + + return string(append(b, '}')), nil + } + return "{}", nil +} + +// parseArray extracts the dimensions and elements of an array represented in +// text format. Only representations emitted by the backend are supported. +// Notably, whitespace around brackets and delimiters is significant, and NULL +// is case-sensitive. 
+// +// See http://www.postgresql.org/docs/current/static/arrays.html#ARRAYS-IO +func parseArray(src, del []byte) (dims []int, elems [][]byte, err error) { + var depth, i int + + if len(src) < 1 || src[0] != '{' { + return nil, nil, fmt.Errorf("pgtype: unable to parse array; expected %q at offset %d", '{', 0) + } + +Open: + for i < len(src) { + switch src[i] { + case '{': + depth++ + i++ + case '}': + elems = make([][]byte, 0) + goto Close + default: + break Open + } + } + dims = make([]int, i) + +Element: + for i < len(src) { + switch src[i] { + case '{': + if depth == len(dims) { + break Element + } + depth++ + dims[depth-1] = 0 + i++ + case '"': + var elem = []byte{} + var escape bool + for i++; i < len(src); i++ { + if escape { + elem = append(elem, src[i]) + escape = false + } else { + switch src[i] { + default: + elem = append(elem, src[i]) + case '\\': + escape = true + case '"': + elems = append(elems, elem) + i++ + break Element + } + } + } + default: + for start := i; i < len(src); i++ { + if bytes.HasPrefix(src[i:], del) || src[i] == '}' { + elem := src[start:i] + if len(elem) == 0 { + return nil, nil, fmt.Errorf("pgtype: unable to parse array; unexpected %q at offset %d", src[i], i) + } + if bytes.Equal(elem, []byte("NULL")) { + elem = nil + } + elems = append(elems, elem) + break Element + } + } + } + } + + for i < len(src) { + if bytes.HasPrefix(src[i:], del) && depth > 0 { + dims[depth-1]++ + i += len(del) + goto Element + } else if src[i] == '}' && depth > 0 { + dims[depth-1]++ + depth-- + i++ + } else { + return nil, nil, fmt.Errorf("pgtype: unable to parse array; unexpected %q at offset %d", src[i], i) + } + } + +Close: + for i < len(src) { + if src[i] == '}' && depth > 0 { + depth-- + i++ + } else { + return nil, nil, fmt.Errorf("pgtype: unable to parse array; unexpected %q at offset %d", src[i], i) + } + } + if depth > 0 { + err = fmt.Errorf("pgtype: unable to parse array; expected %q at offset %d", '}', i) + } + if err == nil { + for _, d := range dims { + if (len(elems) % d) != 0 { + err = fmt.Errorf("pgtype: multidimensional arrays must have elements with matching dimensions") + } + } + } + return +} + +func scanLinearArray(src, del []byte, typ string) (elems [][]byte, err error) { + dims, elems, err := parseArray(src, del) + if err != nil { + return nil, err + } + if len(dims) > 1 { + return nil, fmt.Errorf("pgtype: cannot convert ARRAY%s to %s", strings.Replace(fmt.Sprint(dims), " ", "][", -1), typ) + } + return elems, err +} + +func appendArrayQuotedBytes(b, v []byte) []byte { + b = append(b, '"') + for { + i := bytes.IndexAny(v, `"\`) + if i < 0 { + b = append(b, v...) + break + } + if i > 0 { + b = append(b, v[:i]...) + } + b = append(b, '\\', v[i]) + v = v[i+1:] + } + return append(b, '"') +} + +// func appendValue(b []byte, v driver.Value) ([]byte, error) { +// return append(b, encode(nil, v, 0)...), nil +// } diff --git a/sequel/types/pgtype/bool_array.go b/sequel/types/pgtype/bool_array.go new file mode 100644 index 0000000..9c6a3c8 --- /dev/null +++ b/sequel/types/pgtype/bool_array.go @@ -0,0 +1,139 @@ +package pgtype + +import ( + "database/sql" + "database/sql/driver" + "fmt" + "strconv" + "unsafe" +) + +func BoolArrayValue[T ~bool](b []T) driver.Valuer { + return BoolArray[T](b) +} + +func BoolArrayScanner[T ~bool](b *[]T) sql.Scanner { + return &boolArray[T]{b: b} +} + +type boolArray[T ~bool] struct { + b *[]T +} + +// Scan implements the sql.Scanner interface. 
+func (a *boolArray[T]) Scan(src any) error { + switch src := src.(type) { + case []byte: + return a.scanBytes(src) + case string: + return a.scanBytes([]byte(src)) + case nil: + *a.b = nil + return nil + } + return fmt.Errorf("pgtype: cannot convert %T to BoolArray", src) +} + +func (a boolArray[T]) scanBytes(src []byte) error { + elems, err := scanLinearArray(src, []byte{','}, "BoolArray") + if err != nil { + return err + } + if *a.b != nil && len(elems) == 0 { + *a.b = (*a.b)[:0] + } else { + b := make(BoolArray[T], len(elems)) + for i, v := range elems { + if len(v) != 1 { + return fmt.Errorf("pgtype: could not parse boolean array index %d: invalid boolean %q", i, v) + } + switch v[0] { + case 't': + b[i] = true + case 'f': + b[i] = false + default: + f, err := strconv.ParseBool(unsafe.String(unsafe.SliceData(v), len(v))) + if err != nil { + return fmt.Errorf("pgtype: could not parse boolean array index %d: invalid boolean %q", i, v) + } + b[i] = (T)(f) + } + } + *a.b = b + } + return nil +} + +// BoolArray represents a one-dimensional array of the PostgreSQL boolean type. +type BoolArray[T ~bool] []T + +// Scan implements the sql.Scanner interface. +func (a *BoolArray[T]) Scan(src interface{}) error { + switch src := src.(type) { + case []byte: + return a.scanBytes(src) + case string: + return a.scanBytes([]byte(src)) + case nil: + *a = nil + return nil + } + return fmt.Errorf("pgtype: cannot convert %T to BoolArray", src) +} + +func (a *BoolArray[T]) scanBytes(src []byte) error { + elems, err := scanLinearArray(src, []byte{','}, "BoolArray") + if err != nil { + return err + } + if *a != nil && len(elems) == 0 { + *a = (*a)[:0] + } else { + b := make(BoolArray[T], len(elems)) + for i, v := range elems { + if len(v) != 1 { + return fmt.Errorf("pgtype: could not parse boolean array index %d: invalid boolean %q", i, v) + } + switch v[0] { + case 't': + b[i] = true + case 'f': + b[i] = false + default: + return fmt.Errorf("pgtype: could not parse boolean array index %d: invalid boolean %q", i, v) + } + } + *a = b + } + return nil +} + +// Value implements the driver.Valuer interface. +func (a BoolArray[T]) Value() (driver.Value, error) { + if a == nil { + return nil, nil + } + + if n := len(a); n > 0 { + // There will be exactly two curly brackets, N bytes of values, + // and N-1 bytes of delimiters. + b := make([]byte, 1+2*n) + + for i := 0; i < n; i++ { + b[2*i] = ',' + if a[i] { + b[1+2*i] = 't' + } else { + b[1+2*i] = 'f' + } + } + + b[0] = '{' + b[2*n] = '}' + + return string(b), nil + } + + return "{}", nil +} diff --git a/sequel/types/pgtype/byte_array.go b/sequel/types/pgtype/byte_array.go new file mode 100644 index 0000000..77bbdf1 --- /dev/null +++ b/sequel/types/pgtype/byte_array.go @@ -0,0 +1,129 @@ +package pgtype + +import ( + "bytes" + "database/sql/driver" + "encoding/hex" + "fmt" + "strconv" +) + +// ByteaArray represents a one-dimensional array of the PostgreSQL bytea type. +type ByteaArray [][]byte + +// Scan implements the sql.Scanner interface. 
+func (a *ByteaArray) Scan(src interface{}) error { + switch src := src.(type) { + case []byte: + return a.scanBytes(src) + case string: + return a.scanBytes([]byte(src)) + case nil: + *a = nil + return nil + } + + return fmt.Errorf("pq: cannot convert %T to ByteaArray", src) +} + +func (a *ByteaArray) scanBytes(src []byte) error { + elems, err := scanLinearArray(src, []byte{','}, "ByteaArray") + if err != nil { + return err + } + if *a != nil && len(elems) == 0 { + *a = (*a)[:0] + } else { + b := make(ByteaArray, len(elems)) + for i, v := range elems { + b[i], err = parseBytea(v) + if err != nil { + return fmt.Errorf("could not parse bytea array index %d: %s", i, err.Error()) + } + } + *a = b + } + return nil +} + +// Value implements the driver.Valuer interface. It uses the "hex" format which +// is only supported on PostgreSQL 9.0 or newer. +func (a ByteaArray) Value() (driver.Value, error) { + if a == nil { + return nil, nil + } + + if n := len(a); n > 0 { + // There will be at least two curly brackets, 2*N bytes of quotes, + // 3*N bytes of hex formatting, and N-1 bytes of delimiters. + size := 1 + 6*n + for _, x := range a { + size += hex.EncodedLen(len(x)) + } + + b := make([]byte, size) + + for i, s := 0, b; i < n; i++ { + o := copy(s, `,"\\x`) + o += hex.Encode(s[o:], a[i]) + s[o] = '"' + s = s[o+1:] + } + + b[0] = '{' + b[size-1] = '}' + + return string(b), nil + } + + return "{}", nil +} + +// Parse a bytea value received from the server. Both "hex" and the legacy +// "escape" format are supported. +func parseBytea(s []byte) (result []byte, err error) { + if len(s) >= 2 && bytes.Equal(s[:2], []byte("\\x")) { + // bytea_output = hex + s = s[2:] // trim off leading "\\x" + result = make([]byte, hex.DecodedLen(len(s))) + _, err := hex.Decode(result, s) + if err != nil { + return nil, err + } + } else { + // bytea_output = escape + for len(s) > 0 { + if s[0] == '\\' { + // escaped '\\' + if len(s) >= 2 && s[1] == '\\' { + result = append(result, '\\') + s = s[2:] + continue + } + + // '\\' followed by an octal number + if len(s) < 4 { + return nil, fmt.Errorf("invalid bytea sequence %v", s) + } + r, err := strconv.ParseUint(string(s[1:4]), 8, 8) + if err != nil { + return nil, fmt.Errorf("could not parse bytea value: %s", err.Error()) + } + result = append(result, byte(r)) + s = s[4:] + } else { + // We hit an unescaped, raw byte. Try to read in as many as + // possible in one go. + i := bytes.IndexByte(s, '\\') + if i == -1 { + result = append(result, s...) + break + } + result = append(result, s[:i]...) + s = s[i:] + } + } + } + + return result, nil +} diff --git a/sequel/types/pgtype/float_array.go b/sequel/types/pgtype/float_array.go new file mode 100644 index 0000000..d0fa265 --- /dev/null +++ b/sequel/types/pgtype/float_array.go @@ -0,0 +1,133 @@ +package pgtype + +import ( + "database/sql/driver" + "fmt" + "strconv" +) + +// Float64Array represents a one-dimensional array of the PostgreSQL double +// precision type. +type Float64Array[T ~float64] []T + +// Scan implements the sql.Scanner interface. 
+func (a *Float64Array[T]) Scan(src interface{}) error {
+	switch src := src.(type) {
+	case []byte:
+		return a.scanBytes(src)
+	case string:
+		return a.scanBytes([]byte(src))
+	case nil:
+		*a = nil
+		return nil
+	}
+
+	return fmt.Errorf("pgtype: cannot convert %T to Float64Array", src)
+}
+
+func (a *Float64Array[T]) scanBytes(src []byte) error {
+	elems, err := scanLinearArray(src, []byte{','}, "Float64Array")
+	if err != nil {
+		return err
+	}
+	if *a != nil && len(elems) == 0 {
+		*a = (*a)[:0]
+	} else {
+		b := make(Float64Array[T], len(elems))
+		for i, v := range elems {
+			f, err := strconv.ParseFloat(string(v), 64)
+			if err != nil {
+				return fmt.Errorf("pgtype: parsing array element index %d: %v", i, err)
+			}
+			b[i] = T(f)
+		}
+		*a = b
+	}
+	return nil
+}
+
+// Value implements the driver.Valuer interface.
+func (a Float64Array[T]) Value() (driver.Value, error) {
+	if a == nil {
+		return nil, nil
+	}
+
+	if n := len(a); n > 0 {
+		// There will be at least two curly brackets, N bytes of values,
+		// and N-1 bytes of delimiters.
+		b := make([]byte, 1, 1+2*n)
+		b[0] = '{'
+
+		b = strconv.AppendFloat(b, (float64)(a[0]), 'f', -1, 64)
+		for i := 1; i < n; i++ {
+			b = append(b, ',')
+			b = strconv.AppendFloat(b, (float64)(a[i]), 'f', -1, 64)
+		}
+
+		return string(append(b, '}')), nil
+	}
+	return "{}", nil
+}
+
+// Float32Array represents a one-dimensional array of the PostgreSQL real
+// (single precision) type.
+type Float32Array[T ~float32] []T
+
+// Scan implements the sql.Scanner interface.
+func (a *Float32Array[T]) Scan(src interface{}) error {
+	switch src := src.(type) {
+	case []byte:
+		return a.scanBytes(src)
+	case string:
+		return a.scanBytes([]byte(src))
+	case nil:
+		*a = nil
+		return nil
+	}
+	return fmt.Errorf("pgtype: cannot convert %T to Float32Array", src)
+}
+
+func (a *Float32Array[T]) scanBytes(src []byte) error {
+	elems, err := scanLinearArray(src, []byte{','}, "Float32Array")
+	if err != nil {
+		return err
+	}
+	if *a != nil && len(elems) == 0 {
+		*a = (*a)[:0]
+	} else {
+		b := make(Float32Array[T], len(elems))
+		for i, v := range elems {
+			f, err := strconv.ParseFloat(string(v), 32)
+			if err != nil {
+				return fmt.Errorf("pgtype: parsing array element index %d: %v", i, err)
+			}
+			b[i] = T(f)
+		}
+		*a = b
+	}
+	return nil
+}
+
+// Value implements the driver.Valuer interface.
+func (a Float32Array[T]) Value() (driver.Value, error) {
+	if a == nil {
+		return nil, nil
+	}
+
+	if n := len(a); n > 0 {
+		// There will be at least two curly brackets, N bytes of values,
+		// and N-1 bytes of delimiters.
+		b := make([]byte, 1, 1+2*n)
+		b[0] = '{'
+
+		b = strconv.AppendFloat(b, (float64)(a[0]), 'f', -1, 32)
+		for i := 1; i < n; i++ {
+			b = append(b, ',')
+			b = strconv.AppendFloat(b, (float64)(a[i]), 'f', -1, 32)
+		}
+
+		return string(append(b, '}')), nil
+	}
+
+	return "{}", nil
+}
diff --git a/sequel/types/pgtype/int_array.go b/sequel/types/pgtype/int_array.go
new file mode 100644
index 0000000..4498332
--- /dev/null
+++ b/sequel/types/pgtype/int_array.go
@@ -0,0 +1,113 @@
+package pgtype
+
+import (
+	"database/sql"
+	"database/sql/driver"
+	"fmt"
+	"strconv"
+	"unsafe"
+
+	"golang.org/x/exp/constraints"
+)
+
+// IntArray, Int8Array, Int16Array, Int32Array and Int64Array represent one-dimensional arrays of the PostgreSQL integer types.
+type ( + IntArray[T ~int] []T + Int8Array[T ~int8] []T + Int16Array[T ~int16] []T + Int32Array[T ~int32] []T + Int64Array[T ~int64] []T +) + +func IntArrayValue[T constraints.Signed](a []T) driver.Valuer { + return intArray[T]{v: &a} +} + +func IntArrayScanner[T constraints.Signed](a *[]T) sql.Scanner { + return intArray[T]{v: a} +} + +type intArray[T constraints.Signed] struct { + v *[]T +} + +// Value implements the driver.Valuer interface. +func (a intArray[T]) Value() (driver.Value, error) { + return intArrayValue(*a.v) +} + +// Scan implements the sql.Scanner interface. +func (a intArray[T]) Scan(src any) error { + return arrayScan(a.v, src, "IntArray") +} + +// Scan implements the sql.Scanner interface. +func (a *IntArray[T]) Scan(src any) error { + return arrayScan(a, src, "IntArray") +} + +// Value implements the driver.Valuer interface. +func (a IntArray[T]) Value() (driver.Value, error) { + return intArrayValue(a) +} + +// Scan implements the sql.Scanner interface. +func (a *Int8Array[T]) Scan(src any) error { + return arrayScan(a, src, "Int8Array") +} + +// Value implements the driver.Valuer interface. +func (a Int8Array[T]) Value() (driver.Value, error) { + return intArrayValue(a) +} + +// Scan implements the sql.Scanner interface. +func (a *Int16Array[T]) Scan(src any) error { + return arrayScan(a, src, "Int16Array") +} + +// Value implements the driver.Valuer interface. +func (a Int16Array[T]) Value() (driver.Value, error) { + return intArrayValue(a) +} + +// Scan implements the sql.Scanner interface. +func (a *Int32Array[T]) Scan(src any) error { + return arrayScan(a, src, "Int32Array") +} + +// Value implements the driver.Valuer interface. +func (a Int32Array[T]) Value() (driver.Value, error) { + return intArrayValue(a) +} + +// Scan implements the sql.Scanner interface. +func (a *Int64Array[T]) Scan(src any) error { + return arrayScan(a, src, "Int64Array") +} + +// Value implements the driver.Valuer interface. +func (a Int64Array[T]) Value() (driver.Value, error) { + return intArrayValue(a) +} + +func scanBytes[T constraints.Integer, Arr interface{ ~[]T }](a *Arr, src []byte, t string) error { + elems, err := scanLinearArray(src, []byte{','}, t) + if err != nil { + return err + } + if *a != nil && len(elems) == 0 { + *a = (*a)[:0] + } else { + b := make([]T, len(elems)) + for i, v := range elems { + n, err := strconv.ParseInt(unsafe.String(unsafe.SliceData(v), len(v)), 10, 64) + if err != nil { + return fmt.Errorf("pgtype: parsing array element index %d: %v", i, err) + } + b[i] = T(n) + } + *a = b + } + return nil +} diff --git a/sequel/types/pgtype/string_array.go b/sequel/types/pgtype/string_array.go new file mode 100644 index 0000000..6da109b --- /dev/null +++ b/sequel/types/pgtype/string_array.go @@ -0,0 +1,109 @@ +package pgtype + +import ( + "database/sql" + "database/sql/driver" + "fmt" +) + +func StringArrayScanner[T ~string](a *[]T) sql.Scanner { + return &stringArray[T]{v: a} +} + +type stringArray[T ~string] struct { + v *[]T +} + +// Scan implements the sql.Scanner interface. 
+func (a *stringArray[T]) Scan(src interface{}) error { + switch src := src.(type) { + case []byte: + return a.scanBytes(src) + case string: + return a.scanBytes([]byte(src)) + case nil: + *a.v = nil + return nil + } + return fmt.Errorf("pgtype: cannot convert %T to StringArray", src) +} + +func (a stringArray[T]) scanBytes(src []byte) error { + elems, err := scanLinearArray(src, []byte{','}, "StringArray") + if err != nil { + return err + } + if *a.v != nil && len(elems) == 0 { + *a.v = (*a.v)[:0] + } else { + b := make(StringArray[T], len(elems)) + for i, v := range elems { + if v == nil { + return fmt.Errorf("pgtype: parsing array element index %d: cannot convert nil to string", i) + } + b[i] = T(v) + } + *a.v = b + } + return nil +} + +// StringArray represents a one-dimensional array of the PostgreSQL character types. +type StringArray[T ~string] []T + +// Scan implements the sql.Scanner interface. +func (a *StringArray[T]) Scan(src interface{}) error { + switch src := src.(type) { + case []byte: + return a.scanBytes(src) + case string: + return a.scanBytes([]byte(src)) + case nil: + *a = nil + return nil + } + return fmt.Errorf("pgtype: cannot convert %T to StringArray", src) +} + +func (a *StringArray[T]) scanBytes(src []byte) error { + elems, err := scanLinearArray(src, []byte{','}, "StringArray") + if err != nil { + return err + } + if *a != nil && len(elems) == 0 { + *a = (*a)[:0] + } else { + b := make(StringArray[T], len(elems)) + for i, v := range elems { + if v == nil { + return fmt.Errorf("pgtype: parsing array element index %d: cannot convert nil to string", i) + } + b[i] = T(v) + } + *a = b + } + return nil +} + +// Value implements the driver.Valuer interface. +func (a StringArray[T]) Value() (driver.Value, error) { + if a == nil { + return nil, nil + } + + if n := len(a); n > 0 { + // There will be at least two curly brackets, 2*N bytes of quotes, + // and N-1 bytes of delimiters. + b := make([]byte, 1, 1+3*n) + b[0] = '{' + + b = appendArrayQuotedBytes(b, []byte(a[0])) + for i := 1; i < n; i++ { + b = append(b, ',') + b = appendArrayQuotedBytes(b, []byte(a[i])) + } + + return string(append(b, '}')), nil + } + return "{}", nil +} diff --git a/sequel/types/pgtype/uint_array.go b/sequel/types/pgtype/uint_array.go new file mode 100644 index 0000000..d22a769 --- /dev/null +++ b/sequel/types/pgtype/uint_array.go @@ -0,0 +1,89 @@ +package pgtype + +import ( + "database/sql" + "database/sql/driver" + + "golang.org/x/exp/constraints" +) + +// Uint64Array represents a one-dimensional array of the PostgreSQL integer types. +type ( + UintArray[T ~uint] []T + Uint8Array[T ~uint8] []T + Uint16Array[T ~uint16] []T + Uint32Array[T ~uint32] []T + Uint64Array[T ~uint64] []T +) + +func UintArrayValue[T constraints.Unsigned](a []T) driver.Valuer { + return uintArray[T]{v: &a} +} + +func UintArrayScanner[T constraints.Unsigned](a *[]T) sql.Scanner { + return uintArray[T]{v: a} +} + +type uintArray[T constraints.Unsigned] struct { + v *[]T +} + +// Value implements the driver.Valuer interface. +func (a uintArray[T]) Value() (driver.Value, error) { + return uintArrayValue(*a.v) +} + +// Scan implements the sql.Scanner interface. +func (a uintArray[T]) Scan(src any) error { + return arrayScan(a.v, src, "UintArray") +} + +// Scan implements the sql.Scanner interface. +func (a *UintArray[T]) Scan(src any) error { + return arrayScan(a, src, "UintArray") +} + +// Value implements the driver.Valuer interface. 
+func (a UintArray[T]) Value() (driver.Value, error) { + return uintArrayValue(a) +} + +// Scan implements the sql.Scanner interface. +func (a *Uint8Array[T]) Scan(src any) error { + return arrayScan(a, src, "Uint8Array") +} + +// Value implements the driver.Valuer interface. +func (a Uint8Array[T]) Value() (driver.Value, error) { + return uintArrayValue(a) +} + +// Scan implements the sql.Scanner interface. +func (a *Uint16Array[T]) Scan(src any) error { + return arrayScan(a, src, "Uint16Array") +} + +// Value implements the driver.Valuer interface. +func (a Uint16Array[T]) Value() (driver.Value, error) { + return uintArrayValue(a) +} + +// Scan implements the sql.Scanner interface. +func (a *Uint32Array[T]) Scan(src any) error { + return arrayScan(a, src, "Uint32Array") +} + +// Value implements the driver.Valuer interface. +func (a Uint32Array[T]) Value() (driver.Value, error) { + return uintArrayValue(a) +} + +// Scan implements the sql.Scanner interface. +func (a *Uint64Array[T]) Scan(src any) error { + return arrayScan(a, src, "Uint64Array") +} + +// Value implements the driver.Valuer interface. +func (a Uint64Array[T]) Value() (driver.Value, error) { + return uintArrayValue(a) +} diff --git a/sequel/types/string.go b/sequel/types/string.go index 9632fe9..7c35216 100644 --- a/sequel/types/string.go +++ b/sequel/types/string.go @@ -1,7 +1,6 @@ package types import ( - "database/sql" "database/sql/driver" "fmt" "strconv" @@ -16,12 +15,7 @@ type strLike[T StringLikeType] struct { strictType bool } -var ( - _ sql.Scanner = (*strLike[string])(nil) - _ driver.Valuer = (*strLike[string])(nil) -) - -func String[T StringLikeType](addr *T, strict ...bool) *strLike[T] { +func String[T StringLikeType](addr *T, strict ...bool) ValueScanner[T] { var strictType bool if len(strict) > 0 { strictType = strict[0] diff --git a/sequel/types/time.go b/sequel/types/time.go index b7e3240..bfa6296 100644 --- a/sequel/types/time.go +++ b/sequel/types/time.go @@ -25,7 +25,7 @@ var ( _ driver.Valuer = (*timestamp[time.Time])(nil) ) -func Time[T time.Time](addr *T, strict ...bool) *timestamp[T] { +func Time[T time.Time](addr *T, strict ...bool) ValueScanner[T] { var strictType bool if len(strict) > 0 { strictType = strict[0] diff --git a/sequel/types/types.go b/sequel/types/types.go index 284895e..d7884fc 100644 --- a/sequel/types/types.go +++ b/sequel/types/types.go @@ -5,7 +5,17 @@ // This package is a helper library to prevent the value being fallback using reflection in `database/sql`. package types -import "unsafe" +import ( + "database/sql" + "database/sql/driver" + "unsafe" +) + +type ValueScanner[T any] interface { + driver.Valuer + sql.Scanner + Interface() T +} const nullStr = "null"
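Usage note (outside the diff): a minimal, self-contained sketch of how the new pgtype array helpers introduced above could be exercised on their own. It assumes only what this patch adds — the import path github.com/si3nloong/sqlgen/sequel/types/pgtype with IntArrayValue, IntArrayScanner and StringArrayScanner — plus Go 1.18+ generics; the main package and the sample literals are illustrative, not taken from the repository.

package main

import (
	"fmt"

	"github.com/si3nloong/sqlgen/sequel/types/pgtype"
)

func main() {
	// Valuer side: IntArrayValue wraps a signed-integer slice and renders it
	// as a PostgreSQL array literal such as {1,2,3}.
	ids := []int64{1, 2, 3}
	v, err := pgtype.IntArrayValue(ids).Value()
	if err != nil {
		panic(err)
	}
	fmt.Println(v) // {1,2,3}

	// Scanner side: IntArrayScanner decodes an array payload (string or []byte)
	// back into the destination slice.
	var nums []int
	if err := pgtype.IntArrayScanner(&nums).Scan("{4,5,6}"); err != nil {
		panic(err)
	}
	fmt.Println(nums) // [4 5 6]

	// Text arrays: quoted elements are unescaped by the shared parseArray helper.
	var tags []string
	if err := pgtype.StringArrayScanner(&tags).Scan([]byte(`{go,sql,"hello world"}`)); err != nil {
		panic(err)
	}
	fmt.Println(tags) // [go sql hello world]
}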