Skip to content

Commit

Permalink
fix: interface and actual package name
Browse files Browse the repository at this point in the history
  • Loading branch information
si3nloong committed Nov 13, 2023
1 parent 4e7002a commit 3b01715
Show file tree
Hide file tree
Showing 9 changed files with 250 additions and 4 deletions.
2 changes: 2 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
@@ -1,2 +1,4 @@
copy:
cp -rf ./sequel/sequel.go ./codegen/sequel.go.tpl
build:
go build -o sqlgen -ldflags="-s -w" ./main.go
30 changes: 26 additions & 4 deletions codegen/interface.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,14 @@
package codegen

import (
"go/ast"
"go/importer"
"go/parser"
"go/token"
"go/types"

_ "embed"

"golang.org/x/tools/go/packages"
)

Expand All @@ -11,12 +17,15 @@ var (
sqlTabler, sqlColumner,
binaryMarshaler, binaryUnmarshaler,
textMarshaler, textUnmarshaler *types.Interface

//go:embed sequel.go.tpl
sqlBytes []byte
)

func init() {
pkgs, err := packages.Load(&packages.Config{
Mode: packages.NeedTypes,
}, "database/sql...", "github.com/si3nloong/sqlgen...", "encoding")
}, "database/sql...", "encoding")
if err != nil {
panic(err)
}
Expand All @@ -32,10 +41,23 @@ func init() {
binaryUnmarshaler = p.Types.Scope().Lookup("BinaryUnmarshaler").Type().Underlying().(*types.Interface)
textMarshaler = p.Types.Scope().Lookup("TextMarshaler").Type().Underlying().(*types.Interface)
textUnmarshaler = p.Types.Scope().Lookup("TextUnmarshaler").Type().Underlying().(*types.Interface)
case "github.com/si3nloong/sqlgen/sequel":
sqlTabler = p.Types.Scope().Lookup("Tabler").Type().Underlying().(*types.Interface)
sqlColumner = p.Types.Scope().Lookup("Columner").Type().Underlying().(*types.Interface)
// case "github.com/si3nloong/sqlgen/sequel":
// sqlTabler = p.Types.Scope().Lookup("Tabler").Type().Underlying().(*types.Interface)
// sqlColumner = p.Types.Scope().Lookup("Columner").Type().Underlying().(*types.Interface)
// sqlRower = p.Types.Scope().Lookup("Valuer").Type().Underlying().(*types.Interface)
}

fset := token.NewFileSet()
f, err := parser.ParseFile(fset, "", sqlBytes, parser.AllErrors)
if err != nil {
panic(err)
}
conf := types.Config{Importer: importer.Default()}
pkg, err := conf.Check("sequel", fset, []*ast.File{f}, nil)
if err != nil {
panic(err)
}
sqlTabler = pkg.Scope().Lookup("Tabler").Type().Underlying().(*types.Interface)
sqlColumner = pkg.Scope().Lookup("Columner").Type().Underlying().(*types.Interface)
}
}
10 changes: 10 additions & 0 deletions codegen/package.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,19 @@ func (p *Package) Import(pkg *types.Package) (*types.Package, bool) {
}); i > -1 {
return p.imports[i], false
}

pkgs, err := packages.Load(&packages.Config{
Mode: packages.NeedName,
}, pkg.Path())
if err != nil {
return pkg, false
}

// If the import path is this package, skip to import
if pkg.Path() == p.PkgPath() {
return nil, false
} else if pkgs[0].Name != "" && pkgs[0].Name != pkg.Name() {
pkg.SetName(pkgs[0].Name)
}
alias := p.newAliasIfExists(pkg)
pkg.SetName(alias)
Expand Down
112 changes: 112 additions & 0 deletions codegen/sequel.go.tpl
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
package sequel

import (
"context"
"database/sql"
"database/sql/driver"
"fmt"
"io"
)

// For rename table name
type Name struct{}

type Scanner[T any] interface {
*T
Addrs() []any
}

type TableColumnValuer[T any] interface {
Tabler
Columner
Valuer
}

type KeyValuer[T any] interface {
Keyer
Tabler
Columner
Valuer
}

type KeyValueScanner[T any] interface {
KeyValuer[T]
Scanner[T]
}

type Keyer interface {
PK() (columnName string, pos int, value driver.Value)
}

type AutoIncrKeyer interface {
Keyer
IsAutoIncr()
}

type Tabler interface {
TableName() string
}

type Columner interface {
Columns() []string
}

type Valuer interface {
Values() []any
}

type DB interface {
ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error)
QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error)
QueryRowContext(ctx context.Context, query string, args ...any) *sql.Row
}

type Dialect interface {
Var(n int) string
Wrap(v string) string
Driver() string
}

type Migrator interface {
Tabler
CreateTableStmt() string
AlterTableStmt() string
}

type SingleInserter interface {
InsertOneStmt() string
}

type KeyFinder interface {
Keyer
FindByPKStmt() string
}

type Inserter interface {
Columner
InsertVarQuery() string
}

type StmtWriter interface {
io.StringWriter
io.ByteWriter
}

type StmtBuilder interface {
StmtWriter
Var(query string, v any)
Vars(query string, v []any)
}

type Stmt interface {
StmtBuilder
fmt.Stringer
Args() []any
Reset()
}

type (
WhereClause func(StmtBuilder)
SetClause func(StmtBuilder)
OrderByClause func(StmtWriter)
)
1 change: 1 addition & 0 deletions examples/go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ toolchain go1.21.3
require (
cloud.google.com/go v0.110.10
github.com/go-sql-driver/mysql v1.7.1
github.com/gofrs/uuid/v5 v5.0.0
github.com/google/uuid v1.4.0
github.com/jaswdr/faker v1.16.0
github.com/lib/pq v1.10.7
Expand Down
2 changes: 2 additions & 0 deletions examples/go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/go-sql-driver/mysql v1.7.1 h1:lUIinVbN1DY0xBg0eMOzmmtGoHwWBbvnWubQUrtU8EI=
github.com/go-sql-driver/mysql v1.7.1/go.mod h1:OXbVy3sEdcQ2Doequ6Z5BW6fXNQTmx+9S1MCJN5yJMI=
github.com/gofrs/uuid/v5 v5.0.0 h1:p544++a97kEL+svbcFbCQVM9KFu0Yo25UoISXGNNH9M=
github.com/gofrs/uuid/v5 v5.0.0/go.mod h1:CDOjlDMVAtN56jqyRUZh58JT31Tiw7/oQyEXZV+9bD8=
github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38=
github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/google/uuid v1.4.0 h1:MtMxsa51/r9yyhkyLsVeVt0B+BGQZzpQiTQ4eHZ8bc4=
Expand Down
45 changes: 45 additions & 0 deletions examples/testcase/struct-field/version/generated.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

45 changes: 45 additions & 0 deletions examples/testcase/struct-field/version/generated.go.tpl
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
// Code generated by sqlgen, version v1.0.0-alpha.2. DO NOT EDIT.

package version

import (
"database/sql"
"database/sql/driver"

uuid "github.com/gofrs/uuid/v5"
"github.com/si3nloong/sqlgen/sequel"
)

func (v Version) CreateTableStmt() string {
return "CREATE TABLE IF NOT EXISTS " + v.TableName() + " (`id` VARCHAR(36) NOT NULL,PRIMARY KEY (`id`));"
}
func (Version) AlterTableStmt() string {
return "ALTER TABLE `version` MODIFY `id` VARCHAR(36) NOT NULL;"
}
func (Version) TableName() string {
return "`version`"
}
func (v Version) InsertOneStmt() string {
return "INSERT INTO `version` (`id`) VALUES (?);"
}
func (Version) InsertVarQuery() string {
return "(?)"
}
func (Version) Columns() []string {
return []string{"`id`"}
}
func (v Version) PK() (columnName string, pos int, value driver.Value) {
return "`id`", 0, (driver.Valuer)(v.ID)
}
func (v Version) FindByPKStmt() string {
return "SELECT `id` FROM `version` WHERE `id` = ? LIMIT 1;"
}
func (v Version) Values() []any {
return []any{(driver.Valuer)(v.ID)}
}
func (v *Version) Addrs() []any {
return []any{(sql.Scanner)(&v.ID)}
}
func (v Version) GetID() sequel.ColumnValuer[uuid.UUID] {
return sequel.Column[uuid.UUID]("`id`", v.ID, func(vi uuid.UUID) driver.Value { return (driver.Valuer)(vi) })
}
7 changes: 7 additions & 0 deletions examples/testcase/struct-field/version/version.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
package version

import "github.com/gofrs/uuid/v5"

type Version struct {
ID uuid.UUID `sql:",pk"`
}

0 comments on commit 3b01715

Please sign in to comment.