Skip to content

Commit

Permalink
optimize DataTypeOf func and other code optimize for mysql package (#43)
Browse files Browse the repository at this point in the history
* optimize DataTypeOf func and other code optimize for mysql package

* adjust DecimalSize func
  • Loading branch information
daheige authored Aug 9, 2021
1 parent 99d7c92 commit 7ed9f94
Show file tree
Hide file tree
Showing 2 changed files with 246 additions and 191 deletions.
169 changes: 87 additions & 82 deletions migrator.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,10 @@ type Column struct {
name string
nullable sql.NullString
datatype string
maxlen sql.NullInt64
maxLen sql.NullInt64
precision sql.NullInt64
scale sql.NullInt64
datetimeprecision sql.NullInt64
datetimePrecision sql.NullInt64
}

func (c Column) Name() string {
Expand All @@ -33,38 +33,37 @@ func (c Column) DatabaseTypeName() string {
return c.datatype
}

func (c Column) Length() (length int64, ok bool) {
ok = c.maxlen.Valid
if ok {
length = c.maxlen.Int64
} else {
length = 0
func (c Column) Length() (int64, bool) {
if c.maxLen.Valid {
return c.maxLen.Int64, c.maxLen.Valid
}
return

return 0, false
}

func (c Column) Nullable() (nullable bool, ok bool) {
func (c Column) Nullable() (bool, bool) {
if c.nullable.Valid {
nullable, ok = c.nullable.String == "YES", true
} else {
nullable, ok = false, false
return c.nullable.String == "YES", true
}
return

return false, false
}

func (c Column) DecimalSize() (precision int64, scale int64, ok bool) {
// DecimalSize return precision int64, scale int64, ok bool
func (c Column) DecimalSize() (int64, int64, bool) {
if c.precision.Valid {
if c.scale.Valid {
precision, scale, ok = c.precision.Int64, c.scale.Int64, true
} else {
precision, scale, ok = c.precision.Int64, 0, true
return c.precision.Int64, c.scale.Int64, true
}
} else if c.datetimeprecision.Valid {
precision, scale, ok = c.datetimeprecision.Int64, 0, true
} else {
precision, scale, ok = 0, 0, false

return c.precision.Int64, 0, true
}
return

if c.datetimePrecision.Valid {
return c.datetimePrecision.Int64, 0, true
}

return 0, 0, false
}

func (m Migrator) FullDataTypeOf(field *schema.Field) clause.Expr {
Expand All @@ -91,69 +90,71 @@ func (m Migrator) AlterColumn(value interface{}, field string) error {

func (m Migrator) RenameColumn(value interface{}, oldName, newName string) error {
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
if m.Dialector.DontSupportRenameColumn {
var field *schema.Field
if f := stmt.Schema.LookUpField(oldName); f != nil {
oldName = f.DBName
field = f
}
if !m.Dialector.DontSupportRenameColumn {
return m.Migrator.RenameColumn(value, oldName, newName)
}

if f := stmt.Schema.LookUpField(newName); f != nil {
newName = f.DBName
field = f
}
var field *schema.Field
if f := stmt.Schema.LookUpField(oldName); f != nil {
oldName = f.DBName
field = f
}

if field != nil {
return m.DB.Exec(
"ALTER TABLE ? CHANGE ? ? ?",
clause.Table{Name: stmt.Table}, clause.Column{Name: oldName}, clause.Column{Name: newName}, m.FullDataTypeOf(field),
).Error
}
} else {
return m.Migrator.RenameColumn(value, oldName, newName)
if f := stmt.Schema.LookUpField(newName); f != nil {
newName = f.DBName
field = f
}

if field != nil {
return m.DB.Exec(
"ALTER TABLE ? CHANGE ? ? ?",
clause.Table{Name: stmt.Table}, clause.Column{Name: oldName},
clause.Column{Name: newName}, m.FullDataTypeOf(field),
).Error
}

return fmt.Errorf("failed to look up field with name: %s", newName)
})
}

func (m Migrator) RenameIndex(value interface{}, oldName, newName string) error {
if m.Dialector.DontSupportRenameIndex {
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
err := m.DropIndex(value, oldName)
if err == nil {
if idx := stmt.Schema.LookIndex(newName); idx == nil {
if idx = stmt.Schema.LookIndex(oldName); idx != nil {
opts := m.BuildIndexOptions(idx.Fields, stmt)
values := []interface{}{clause.Column{Name: newName}, clause.Table{Name: stmt.Table}, opts}

createIndexSQL := "CREATE "
if idx.Class != "" {
createIndexSQL += idx.Class + " "
}
createIndexSQL += "INDEX ? ON ??"

if idx.Type != "" {
createIndexSQL += " USING " + idx.Type
}

return m.DB.Exec(createIndexSQL, values...).Error
}
}

err = m.CreateIndex(value, newName)
}

return err
})
} else {
if !m.Dialector.DontSupportRenameIndex {
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
return m.DB.Exec(
"ALTER TABLE ? RENAME INDEX ? TO ?",
clause.Table{Name: stmt.Table}, clause.Column{Name: oldName}, clause.Column{Name: newName},
).Error
})
}

return m.RunWithValue(value, func(stmt *gorm.Statement) error {
err := m.DropIndex(value, oldName)
if err != nil {
return err
}

if idx := stmt.Schema.LookIndex(newName); idx == nil {
if idx = stmt.Schema.LookIndex(oldName); idx != nil {
opts := m.BuildIndexOptions(idx.Fields, stmt)
values := []interface{}{clause.Column{Name: newName}, clause.Table{Name: stmt.Table}, opts}

createIndexSQL := "CREATE "
if idx.Class != "" {
createIndexSQL += idx.Class + " "
}
createIndexSQL += "INDEX ? ON ??"

if idx.Type != "" {
createIndexSQL += " USING " + idx.Type
}

return m.DB.Exec(createIndexSQL, values...).Error
}
}

return m.CreateIndex(value, newName)
})

}

func (m Migrator) DropTable(values ...interface{}) error {
Expand Down Expand Up @@ -187,9 +188,10 @@ func (m Migrator) DropConstraint(value interface{}, name string) error {
})
}

func (m Migrator) ColumnTypes(value interface{}) (columnTypes []gorm.ColumnType, err error) {
columnTypes = make([]gorm.ColumnType, 0)
err = m.RunWithValue(value, func(stmt *gorm.Statement) error {
// ColumnTypes column types return columnTypes,error
func (m Migrator) ColumnTypes(value interface{}) ([]gorm.ColumnType, error) {
columnTypes := make([]gorm.ColumnType, 0)
err := m.RunWithValue(value, func(stmt *gorm.Statement) error {
var (
currentDatabase = m.DB.Migrator().CurrentDatabase()
columnTypeSQL = "SELECT column_name, is_nullable, data_type, character_maximum_length, numeric_precision, numeric_scale "
Expand All @@ -200,27 +202,30 @@ func (m Migrator) ColumnTypes(value interface{}) (columnTypes []gorm.ColumnType,
}
columnTypeSQL += "FROM information_schema.columns WHERE table_schema = ? AND table_name = ?"

columns, err := m.DB.Raw(columnTypeSQL, currentDatabase, stmt.Table).Rows()
if err != nil {
return err
columns, rowErr := m.DB.Raw(columnTypeSQL, currentDatabase, stmt.Table).Rows()
if rowErr != nil {
return rowErr
}

defer columns.Close()

for columns.Next() {
var column Column
var values = []interface{}{&column.name, &column.nullable, &column.datatype, &column.maxlen, &column.precision, &column.scale}
var values = []interface{}{&column.name, &column.nullable, &column.datatype,
&column.maxLen, &column.precision, &column.scale}

if !m.DisableDatetimePrecision {
values = append(values, &column.datetimeprecision)
values = append(values, &column.datetimePrecision)
}

if err = columns.Scan(values...); err != nil {
return err
if scanErr := columns.Scan(values...); scanErr != nil {
return scanErr
}
columnTypes = append(columnTypes, column)
}

return err
return nil
})
return

return columnTypes, err
}
Loading

0 comments on commit 7ed9f94

Please sign in to comment.