From 11c4331058e5fed7d05707c5a6a6997947509f41 Mon Sep 17 00:00:00 2001 From: molon Date: Mon, 24 Jun 2024 17:42:59 +0800 Subject: [PATCH] feat: add MapColumns method (#6901) * add MapColumns method * fix MapColumns desc * add TestMapColumns --- chainable_api.go | 7 +++++++ scan.go | 9 +++++++++ statement.go | 6 ++++-- tests/query_test.go | 24 ++++++++++++++++++++++-- 4 files changed, 42 insertions(+), 4 deletions(-) diff --git a/chainable_api.go b/chainable_api.go index e6c90cefe2..8953413d5f 100644 --- a/chainable_api.go +++ b/chainable_api.go @@ -185,6 +185,13 @@ func (db *DB) Omit(columns ...string) (tx *DB) { return } +// MapColumns modify the column names in the query results to facilitate align to the corresponding structural fields +func (db *DB) MapColumns(m map[string]string) (tx *DB) { + tx = db.getInstance() + tx.Statement.ColumnMapping = m + return +} + // Where add conditions // // See the [docs] for details on the various formats that where clauses can take. By default, where clauses chain with AND. diff --git a/scan.go b/scan.go index eac6ca9de0..d852c2c9f9 100644 --- a/scan.go +++ b/scan.go @@ -131,6 +131,15 @@ func Scan(rows Rows, db *DB, mode ScanMode) { onConflictDonothing = mode&ScanOnConflictDoNothing != 0 ) + if len(db.Statement.ColumnMapping) > 0 { + for i, column := range columns { + v, ok := db.Statement.ColumnMapping[column] + if ok { + columns[i] = v + } + } + } + db.RowsAffected = 0 switch dest := db.Statement.Dest.(type) { diff --git a/statement.go b/statement.go index ae79aa3218..39e05d093b 100644 --- a/statement.go +++ b/statement.go @@ -30,8 +30,9 @@ type Statement struct { Clauses map[string]clause.Clause BuildClauses []string Distinct bool - Selects []string // selected columns - Omits []string // omit columns + Selects []string // selected columns + Omits []string // omit columns + ColumnMapping map[string]string // map columns Joins []join Preloads map[string][]interface{} Settings sync.Map @@ -513,6 +514,7 @@ func (stmt *Statement) clone() *Statement { Distinct: stmt.Distinct, Selects: stmt.Selects, Omits: stmt.Omits, + ColumnMapping: stmt.ColumnMapping, Preloads: map[string][]interface{}{}, ConnPool: stmt.ConnPool, Schema: stmt.Schema, diff --git a/tests/query_test.go b/tests/query_test.go index 79f7182bbb..566763c515 100644 --- a/tests/query_test.go +++ b/tests/query_test.go @@ -860,6 +860,28 @@ func TestOmitWithAllFields(t *testing.T) { } } +func TestMapColumns(t *testing.T) { + user := User{Name: "MapColumnsUser", Age: 12} + DB.Save(&user) + + type result struct { + Name string + Nickname string + Age uint + } + var res result + DB.Table("users").Where("name = ?", user.Name).MapColumns(map[string]string{"name": "nickname"}).Scan(&res) + if res.Nickname != user.Name { + t.Errorf("Expected res.Nickname to be %s, but got %s", user.Name, res.Nickname) + } + if res.Name != "" { + t.Errorf("Expected res.Name to be empty, but got %s", res.Name) + } + if res.Age != user.Age { + t.Errorf("Expected res.Age to be %d, but got %d", user.Age, res.Age) + } +} + func TestPluckWithSelect(t *testing.T) { users := []User{ {Name: "pluck_with_select_1", Age: 25}, @@ -1194,7 +1216,6 @@ func TestSubQueryWithRaw(t *testing.T) { Where("age >= ? and name in (?)", 20, []string{"subquery_raw_1", "subquery_raw_3"}). Group("name"), ).Count(&count).Error - if err != nil { t.Errorf("Expected to get no errors, but got %v", err) } @@ -1210,7 +1231,6 @@ func TestSubQueryWithRaw(t *testing.T) { Not("age <= ?", 10).Not("name IN (?)", []string{"subquery_raw_1", "subquery_raw_3"}). Group("name"), ).Count(&count).Error - if err != nil { t.Errorf("Expected to get no errors, but got %v", err) }