Skip to content

Commit

Permalink
Cleanups
Browse files Browse the repository at this point in the history
  • Loading branch information
jimsmart committed Apr 10, 2021
1 parent 0a1866c commit 7e8b5ab
Show file tree
Hide file tree
Showing 6 changed files with 12 additions and 26 deletions.
7 changes: 2 additions & 5 deletions dialect_mssql.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package schema

import (
"database/sql"
"fmt"
)

const mssqlAllColumns = `SELECT * FROM %s WHERE 1=0`
Expand Down Expand Up @@ -88,8 +87,7 @@ func (mssqlDialect) PrimaryKey(db *sql.DB, name string) ([]string, error) {
}

func (d mssqlDialect) Table(db *sql.DB, name string) ([]*sql.ColumnType, error) {
q := fmt.Sprintf(mssqlAllColumns, d.escapeIdent(name))
return fetchColumnTypes(db, q)
return fetchColumnTypes(db, mssqlAllColumns, name, d.escapeIdent)
}

func (mssqlDialect) TableNames(db *sql.DB) ([]string, error) {
Expand All @@ -101,8 +99,7 @@ func (d mssqlDialect) Tables(db *sql.DB) (map[string][]*sql.ColumnType, error) {
}

func (d mssqlDialect) View(db *sql.DB, name string) ([]*sql.ColumnType, error) {
q := fmt.Sprintf(mssqlAllColumns, d.escapeIdent(name))
return fetchColumnTypes(db, q)
return fetchColumnTypes(db, mssqlAllColumns, name, d.escapeIdent)
}

func (mssqlDialect) ViewNames(db *sql.DB) ([]string, error) {
Expand Down
7 changes: 2 additions & 5 deletions dialect_mysql.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package schema

import (
"database/sql"
"fmt"
)

const mysqlAllColumns = `SELECT * FROM %s LIMIT 0`
Expand Down Expand Up @@ -61,8 +60,7 @@ func (mysqlDialect) PrimaryKey(db *sql.DB, name string) ([]string, error) {
}

func (d mysqlDialect) Table(db *sql.DB, name string) ([]*sql.ColumnType, error) {
q := fmt.Sprintf(mysqlAllColumns, d.escapeIdent(name))
return fetchColumnTypes(db, q)
return fetchColumnTypes(db, mysqlAllColumns, name, d.escapeIdent)
}

func (mysqlDialect) TableNames(db *sql.DB) ([]string, error) {
Expand All @@ -74,8 +72,7 @@ func (d mysqlDialect) Tables(db *sql.DB) (map[string][]*sql.ColumnType, error) {
}

func (d mysqlDialect) View(db *sql.DB, name string) ([]*sql.ColumnType, error) {
q := fmt.Sprintf(mysqlAllColumns, d.escapeIdent(name))
return fetchColumnTypes(db, q)
return fetchColumnTypes(db, mysqlAllColumns, name, d.escapeIdent)
}

func (mysqlDialect) ViewNames(db *sql.DB) ([]string, error) {
Expand Down
7 changes: 2 additions & 5 deletions dialect_oracle.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package schema

import (
"database/sql"
"fmt"
)

// TODO(js) Is there some way to filter system tables (like mssql)? Or should we always just be using our own schema?
Expand Down Expand Up @@ -59,8 +58,7 @@ func (oracleDialect) PrimaryKey(db *sql.DB, name string) ([]string, error) {
}

func (d oracleDialect) Table(db *sql.DB, name string) ([]*sql.ColumnType, error) {
q := fmt.Sprintf(oracleAllColumns, d.escapeIdent(name))
return fetchColumnTypes(db, q)
return fetchColumnTypes(db, oracleAllColumns, name, d.escapeIdent)
}

func (oracleDialect) TableNames(db *sql.DB) ([]string, error) {
Expand All @@ -72,8 +70,7 @@ func (d oracleDialect) Tables(db *sql.DB) (map[string][]*sql.ColumnType, error)
}

func (d oracleDialect) View(db *sql.DB, name string) ([]*sql.ColumnType, error) {
q := fmt.Sprintf(oracleAllColumns, d.escapeIdent(name))
return fetchColumnTypes(db, q)
return fetchColumnTypes(db, oracleAllColumns, name, d.escapeIdent)
}

func (oracleDialect) ViewNames(db *sql.DB) ([]string, error) {
Expand Down
7 changes: 2 additions & 5 deletions dialect_postgres.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package schema

import (
"database/sql"
"fmt"
)

const postgresAllColumns = `SELECT * FROM %s LIMIT 0`
Expand Down Expand Up @@ -61,8 +60,7 @@ func (postgresDialect) PrimaryKey(db *sql.DB, name string) ([]string, error) {
}

func (d postgresDialect) Table(db *sql.DB, name string) ([]*sql.ColumnType, error) {
q := fmt.Sprintf(postgresAllColumns, d.escapeIdent(name))
return fetchColumnTypes(db, q)
return fetchColumnTypes(db, postgresAllColumns, name, d.escapeIdent)
}

func (postgresDialect) TableNames(db *sql.DB) ([]string, error) {
Expand All @@ -74,8 +72,7 @@ func (d postgresDialect) Tables(db *sql.DB) (map[string][]*sql.ColumnType, error
}

func (d postgresDialect) View(db *sql.DB, name string) ([]*sql.ColumnType, error) {
q := fmt.Sprintf(postgresAllColumns, d.escapeIdent(name))
return fetchColumnTypes(db, q)
return fetchColumnTypes(db, postgresAllColumns, name, d.escapeIdent)
}

func (postgresDialect) ViewNames(db *sql.DB) ([]string, error) {
Expand Down
7 changes: 2 additions & 5 deletions dialect_sqlite.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package schema

import (
"database/sql"
"fmt"
)

const sqliteAllColumns = `SELECT * FROM %s LIMIT 0`
Expand All @@ -25,8 +24,7 @@ func (sqliteDialect) PrimaryKey(db *sql.DB, name string) ([]string, error) {
}

func (d sqliteDialect) Table(db *sql.DB, name string) ([]*sql.ColumnType, error) {
q := fmt.Sprintf(sqliteAllColumns, d.escapeIdent(name))
return fetchColumnTypes(db, q)
return fetchColumnTypes(db, sqliteAllColumns, name, d.escapeIdent)
}

func (sqliteDialect) TableNames(db *sql.DB) ([]string, error) {
Expand All @@ -38,8 +36,7 @@ func (d sqliteDialect) Tables(db *sql.DB) (map[string][]*sql.ColumnType, error)
}

func (d sqliteDialect) View(db *sql.DB, name string) ([]*sql.ColumnType, error) {
q := fmt.Sprintf(sqliteAllColumns, d.escapeIdent(name))
return fetchColumnTypes(db, q)
return fetchColumnTypes(db, sqliteAllColumns, name, d.escapeIdent)
}

func (sqliteDialect) ViewNames(db *sql.DB) ([]string, error) {
Expand Down
3 changes: 2 additions & 1 deletion schema.go
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,8 @@ func View(db *sql.DB, name string) ([]*sql.ColumnType, error) {

// fetchColumnTypes queries the database and returns column's type metadata
// for a single table or view.
func fetchColumnTypes(db *sql.DB, query string) ([]*sql.ColumnType, error) {
func fetchColumnTypes(db *sql.DB, query, name string, escapeIdent func(string) string) ([]*sql.ColumnType, error) {
query = fmt.Sprintf(query, escapeIdent(name))
rows, err := db.Query(query)
if err != nil {
return nil, err
Expand Down

0 comments on commit 7e8b5ab

Please sign in to comment.