From 7e8b5aba0f8b821de277fd451674915bd8be54a3 Mon Sep 17 00:00:00 2001 From: Jim Smart Date: Sat, 10 Apr 2021 22:46:40 +0100 Subject: [PATCH] Cleanups --- dialect_mssql.go | 7 ++----- dialect_mysql.go | 7 ++----- dialect_oracle.go | 7 ++----- dialect_postgres.go | 7 ++----- dialect_sqlite.go | 7 ++----- schema.go | 3 ++- 6 files changed, 12 insertions(+), 26 deletions(-) diff --git a/dialect_mssql.go b/dialect_mssql.go index 3df9755..a28ce86 100644 --- a/dialect_mssql.go +++ b/dialect_mssql.go @@ -2,7 +2,6 @@ package schema import ( "database/sql" - "fmt" ) const mssqlAllColumns = `SELECT * FROM %s WHERE 1=0` @@ -88,8 +87,7 @@ func (mssqlDialect) PrimaryKey(db *sql.DB, name string) ([]string, error) { } func (d mssqlDialect) Table(db *sql.DB, name string) ([]*sql.ColumnType, error) { - q := fmt.Sprintf(mssqlAllColumns, d.escapeIdent(name)) - return fetchColumnTypes(db, q) + return fetchColumnTypes(db, mssqlAllColumns, name, d.escapeIdent) } func (mssqlDialect) TableNames(db *sql.DB) ([]string, error) { @@ -101,8 +99,7 @@ func (d mssqlDialect) Tables(db *sql.DB) (map[string][]*sql.ColumnType, error) { } func (d mssqlDialect) View(db *sql.DB, name string) ([]*sql.ColumnType, error) { - q := fmt.Sprintf(mssqlAllColumns, d.escapeIdent(name)) - return fetchColumnTypes(db, q) + return fetchColumnTypes(db, mssqlAllColumns, name, d.escapeIdent) } func (mssqlDialect) ViewNames(db *sql.DB) ([]string, error) { diff --git a/dialect_mysql.go b/dialect_mysql.go index 9cc1273..e23f9d4 100644 --- a/dialect_mysql.go +++ b/dialect_mysql.go @@ -2,7 +2,6 @@ package schema import ( "database/sql" - "fmt" ) const mysqlAllColumns = `SELECT * FROM %s LIMIT 0` @@ -61,8 +60,7 @@ func (mysqlDialect) PrimaryKey(db *sql.DB, name string) ([]string, error) { } func (d mysqlDialect) Table(db *sql.DB, name string) ([]*sql.ColumnType, error) { - q := fmt.Sprintf(mysqlAllColumns, d.escapeIdent(name)) - return fetchColumnTypes(db, q) + return fetchColumnTypes(db, mysqlAllColumns, name, d.escapeIdent) } func (mysqlDialect) TableNames(db *sql.DB) ([]string, error) { @@ -74,8 +72,7 @@ func (d mysqlDialect) Tables(db *sql.DB) (map[string][]*sql.ColumnType, error) { } func (d mysqlDialect) View(db *sql.DB, name string) ([]*sql.ColumnType, error) { - q := fmt.Sprintf(mysqlAllColumns, d.escapeIdent(name)) - return fetchColumnTypes(db, q) + return fetchColumnTypes(db, mysqlAllColumns, name, d.escapeIdent) } func (mysqlDialect) ViewNames(db *sql.DB) ([]string, error) { diff --git a/dialect_oracle.go b/dialect_oracle.go index df6f4dd..bca8271 100644 --- a/dialect_oracle.go +++ b/dialect_oracle.go @@ -2,7 +2,6 @@ package schema import ( "database/sql" - "fmt" ) // TODO(js) Is there some way to filter system tables (like mssql)? Or should we always just be using our own schema? @@ -59,8 +58,7 @@ func (oracleDialect) PrimaryKey(db *sql.DB, name string) ([]string, error) { } func (d oracleDialect) Table(db *sql.DB, name string) ([]*sql.ColumnType, error) { - q := fmt.Sprintf(oracleAllColumns, d.escapeIdent(name)) - return fetchColumnTypes(db, q) + return fetchColumnTypes(db, oracleAllColumns, name, d.escapeIdent) } func (oracleDialect) TableNames(db *sql.DB) ([]string, error) { @@ -72,8 +70,7 @@ func (d oracleDialect) Tables(db *sql.DB) (map[string][]*sql.ColumnType, error) } func (d oracleDialect) View(db *sql.DB, name string) ([]*sql.ColumnType, error) { - q := fmt.Sprintf(oracleAllColumns, d.escapeIdent(name)) - return fetchColumnTypes(db, q) + return fetchColumnTypes(db, oracleAllColumns, name, d.escapeIdent) } func (oracleDialect) ViewNames(db *sql.DB) ([]string, error) { diff --git a/dialect_postgres.go b/dialect_postgres.go index 2e34360..910d808 100644 --- a/dialect_postgres.go +++ b/dialect_postgres.go @@ -2,7 +2,6 @@ package schema import ( "database/sql" - "fmt" ) const postgresAllColumns = `SELECT * FROM %s LIMIT 0` @@ -61,8 +60,7 @@ func (postgresDialect) PrimaryKey(db *sql.DB, name string) ([]string, error) { } func (d postgresDialect) Table(db *sql.DB, name string) ([]*sql.ColumnType, error) { - q := fmt.Sprintf(postgresAllColumns, d.escapeIdent(name)) - return fetchColumnTypes(db, q) + return fetchColumnTypes(db, postgresAllColumns, name, d.escapeIdent) } func (postgresDialect) TableNames(db *sql.DB) ([]string, error) { @@ -74,8 +72,7 @@ func (d postgresDialect) Tables(db *sql.DB) (map[string][]*sql.ColumnType, error } func (d postgresDialect) View(db *sql.DB, name string) ([]*sql.ColumnType, error) { - q := fmt.Sprintf(postgresAllColumns, d.escapeIdent(name)) - return fetchColumnTypes(db, q) + return fetchColumnTypes(db, postgresAllColumns, name, d.escapeIdent) } func (postgresDialect) ViewNames(db *sql.DB) ([]string, error) { diff --git a/dialect_sqlite.go b/dialect_sqlite.go index 27d5d21..0d412b1 100644 --- a/dialect_sqlite.go +++ b/dialect_sqlite.go @@ -2,7 +2,6 @@ package schema import ( "database/sql" - "fmt" ) const sqliteAllColumns = `SELECT * FROM %s LIMIT 0` @@ -25,8 +24,7 @@ func (sqliteDialect) PrimaryKey(db *sql.DB, name string) ([]string, error) { } func (d sqliteDialect) Table(db *sql.DB, name string) ([]*sql.ColumnType, error) { - q := fmt.Sprintf(sqliteAllColumns, d.escapeIdent(name)) - return fetchColumnTypes(db, q) + return fetchColumnTypes(db, sqliteAllColumns, name, d.escapeIdent) } func (sqliteDialect) TableNames(db *sql.DB) ([]string, error) { @@ -38,8 +36,7 @@ func (d sqliteDialect) Tables(db *sql.DB) (map[string][]*sql.ColumnType, error) } func (d sqliteDialect) View(db *sql.DB, name string) ([]*sql.ColumnType, error) { - q := fmt.Sprintf(sqliteAllColumns, d.escapeIdent(name)) - return fetchColumnTypes(db, q) + return fetchColumnTypes(db, sqliteAllColumns, name, d.escapeIdent) } func (sqliteDialect) ViewNames(db *sql.DB) ([]string, error) { diff --git a/schema.go b/schema.go index 77354ee..ec2f30f 100644 --- a/schema.go +++ b/schema.go @@ -198,7 +198,8 @@ func View(db *sql.DB, name string) ([]*sql.ColumnType, error) { // fetchColumnTypes queries the database and returns column's type metadata // for a single table or view. -func fetchColumnTypes(db *sql.DB, query string) ([]*sql.ColumnType, error) { +func fetchColumnTypes(db *sql.DB, query, name string, escapeIdent func(string) string) ([]*sql.ColumnType, error) { + query = fmt.Sprintf(query, escapeIdent(name)) rows, err := db.Query(query) if err != nil { return nil, err