Skip to content

Commit f4fee4b

Browse files
author
cruvie
committed
fix(pgtype):Adjust the usage logic of the fmt. Stringer interface to prioritize handling renamed base types - avoiding renamed base types being automatically treated as string types
1 parent d685c94 commit f4fee4b

File tree

3 files changed

+63
-22
lines changed

3 files changed

+63
-22
lines changed

pgtype/pgtype.go

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1497,7 +1497,18 @@ func TryWrapBuiltinTypeEncodePlan(value any) (plan WrappedEncodePlanNextSetter,
14971497
case []byte:
14981498
return &wrapByteSliceEncodePlan{}, byteSliceWrapper(value), true
14991499
case fmt.Stringer:
1500-
return &wrapFmtStringerEncodePlan{}, fmtStringerWrapper{value}, true
1500+
// Check if the value is a renamed basic type. If it is, prefer the basic type encoding.
1501+
rv := reflect.ValueOf(value)
1502+
switch rv.Kind() {
1503+
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
1504+
reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64,
1505+
reflect.Float32, reflect.Float64, reflect.Bool, reflect.String:
1506+
// For renamed basic types, don't use Stringer interface automatically
1507+
// Let the specific type match above handle it
1508+
default:
1509+
// For structs and other complex types that implement Stringer, use the Stringer interface
1510+
return &wrapFmtStringerEncodePlan{}, fmtStringerWrapper{value}, true
1511+
}
15011512
}
15021513

15031514
return nil, nil, false

values.go

Lines changed: 0 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,7 @@
11
package pgx
22

33
import (
4-
"database/sql/driver"
54
"errors"
6-
"reflect"
7-
"time"
85

96
"github.com/jackc/pgx/v5/internal/pgio"
107
"github.com/jackc/pgx/v5/pgtype"
@@ -17,24 +14,6 @@ const (
1714
)
1815

1916
func convertSimpleArgument(m *pgtype.Map, arg any) (any, error) {
20-
// If arg implements driver.Valuer, use Value method
21-
if valuer, ok := arg.(driver.Valuer); ok && valuer != nil {
22-
v, err := valuer.Value()
23-
if err != nil {
24-
return nil, err
25-
}
26-
arg = v
27-
}
28-
29-
fieldValue := reflect.ValueOf(arg)
30-
switch fieldValue.Kind() {
31-
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
32-
if _, ok := arg.(time.Duration); !ok {
33-
arg = fieldValue.Int()
34-
}
35-
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
36-
arg = fieldValue.Uint()
37-
}
3817
buf, err := m.Encode(0, TextFormatCode, arg, []byte{})
3918
if err != nil {
4019
return nil, err

values_test.go

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1004,6 +1004,57 @@ func TestEncodeTypeRename(t *testing.T) {
10041004
})
10051005
}
10061006

1007+
// Define custom types that are aliases of basic types but also implement fmt.Stringer
1008+
type StringerInt32 int32
1009+
type StringerFloat64 float64
1010+
1011+
// Implement the String() method for these types
1012+
func (s StringerInt32) String() string {
1013+
return fmt.Sprintf("StringerInt32(%d)", int32(s))
1014+
}
1015+
1016+
func (s StringerFloat64) String() string {
1017+
return fmt.Sprintf("StringerFloat64(%f)", float64(s))
1018+
}
1019+
1020+
// TestStringerTypes tests custom type aliases that implement the fmt.Stringer interface
1021+
func TestStringerTypes(t *testing.T) {
1022+
t.Parallel()
1023+
1024+
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
1025+
defer cancel()
1026+
1027+
pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
1028+
// Test values
1029+
inInt := StringerInt32(42)
1030+
var outInt StringerInt32
1031+
1032+
inFloat := StringerFloat64(553.36)
1033+
var outFloat StringerFloat64
1034+
1035+
// Register types with the connection
1036+
conn.TypeMap().RegisterDefaultPgType(inInt, "int4")
1037+
conn.TypeMap().RegisterDefaultPgType(inFloat, "float8")
1038+
1039+
// Test that the underlying values are properly encoded/decoded,
1040+
// not the String() representation
1041+
err := conn.QueryRow(context.Background(), "select $1::int4, $2::float8", inInt, inFloat).
1042+
Scan(&outInt, &outFloat)
1043+
if err != nil {
1044+
t.Fatalf("Failed with Stringer types: %v", err)
1045+
}
1046+
1047+
// Check that the values are correctly preserved (not converted to their String() representation)
1048+
if inInt != outInt {
1049+
t.Errorf("StringerInt32: expected %v, got %v", inInt, outInt)
1050+
}
1051+
1052+
if inFloat != outFloat {
1053+
t.Errorf("StringerFloat64: expected %v, got %v", inFloat, outFloat)
1054+
}
1055+
})
1056+
}
1057+
10071058
// func TestRowDecodeBinary(t *testing.T) {
10081059
// t.Parallel()
10091060

0 commit comments

Comments
 (0)