Skip to content

Commit 1d557f9

Browse files
committed
Remove PlanScan memoization
Previously, PlanScan used a cache to improve performance. However, the cache could get confused in certain cases. For example, the following would fail: m := pgtype.NewMap() var err error var tags any err = m.Scan(pgtype.TextArrayOID, pgx.TextFormatCode, []byte("{foo,bar,baz}"), &tags) require.NoError(t, err) var cells [][]string err = m.Scan(pgtype.TextArrayOID, pgx.TextFormatCode, []byte("{{foo,bar},{baz,quz}}"), &cells) require.NoError(t, err) This commit removes the memoization and adds a test to ensure that this case works. The benchmarks were also updated to include an array of strings to ensure this path is benchmarked. As it turned out, there was next to no performance difference between the cached and non-cached versions. It's possible there may be a performance impact in certain complicated cases, but I have not encountered any. If there are any performance issues, we can optimize the narrower case rather than adding memoization everywhere.
1 parent de7fe81 commit 1d557f9

File tree

4 files changed

+38
-47
lines changed

4 files changed

+38
-47
lines changed

bench_test.go

+20-17
Original file line numberDiff line numberDiff line change
@@ -944,6 +944,7 @@ type BenchRowSimple struct {
944944
BirthDate time.Time
945945
Weight int32
946946
Height int32
947+
Tags []string
947948
UpdateTime time.Time
948949
}
949950

@@ -957,13 +958,13 @@ func BenchmarkSelectRowsScanSimple(b *testing.B) {
957958
b.Run(fmt.Sprintf("%d rows", rowCount), func(b *testing.B) {
958959
br := &BenchRowSimple{}
959960
for i := 0; i < b.N; i++ {
960-
rows, err := conn.Query(context.Background(), "select n, 'Adam', 'Smith ' || n, 'male', '1952-06-16'::date, 258, 72, '2001-01-28 01:02:03-05'::timestamptz from generate_series(100001, 100000 + $1) n", rowCount)
961+
rows, err := conn.Query(context.Background(), "select n, 'Adam', 'Smith ' || n, 'male', '1952-06-16'::date, 258, 72, '{foo,bar,baz}'::text[], '2001-01-28 01:02:03-05'::timestamptz from generate_series(100001, 100000 + $1) n", rowCount)
961962
if err != nil {
962963
b.Fatal(err)
963964
}
964965

965966
for rows.Next() {
966-
rows.Scan(&br.ID, &br.FirstName, &br.LastName, &br.Sex, &br.BirthDate, &br.Weight, &br.Height, &br.UpdateTime)
967+
rows.Scan(&br.ID, &br.FirstName, &br.LastName, &br.Sex, &br.BirthDate, &br.Weight, &br.Height, &br.Tags, &br.UpdateTime)
967968
}
968969

969970
if rows.Err() != nil {
@@ -982,6 +983,7 @@ type BenchRowStringBytes struct {
982983
BirthDate time.Time
983984
Weight int32
984985
Height int32
986+
Tags []string
985987
UpdateTime time.Time
986988
}
987989

@@ -995,13 +997,13 @@ func BenchmarkSelectRowsScanStringBytes(b *testing.B) {
995997
b.Run(fmt.Sprintf("%d rows", rowCount), func(b *testing.B) {
996998
br := &BenchRowStringBytes{}
997999
for i := 0; i < b.N; i++ {
998-
rows, err := conn.Query(context.Background(), "select n, 'Adam', 'Smith ' || n, 'male', '1952-06-16'::date, 258, 72, '2001-01-28 01:02:03-05'::timestamptz from generate_series(100001, 100000 + $1) n", rowCount)
1000+
rows, err := conn.Query(context.Background(), "select n, 'Adam', 'Smith ' || n, 'male', '1952-06-16'::date, 258, 72, '{foo,bar,baz}'::text[], '2001-01-28 01:02:03-05'::timestamptz from generate_series(100001, 100000 + $1) n", rowCount)
9991001
if err != nil {
10001002
b.Fatal(err)
10011003
}
10021004

10031005
for rows.Next() {
1004-
rows.Scan(&br.ID, &br.FirstName, &br.LastName, &br.Sex, &br.BirthDate, &br.Weight, &br.Height, &br.UpdateTime)
1006+
rows.Scan(&br.ID, &br.FirstName, &br.LastName, &br.Sex, &br.BirthDate, &br.Weight, &br.Height, &br.Tags, &br.UpdateTime)
10051007
}
10061008

10071009
if rows.Err() != nil {
@@ -1020,6 +1022,7 @@ type BenchRowDecoder struct {
10201022
BirthDate pgtype.Date
10211023
Weight pgtype.Int4
10221024
Height pgtype.Int4
1025+
Tags pgtype.FlatArray[string]
10231026
UpdateTime pgtype.Timestamptz
10241027
}
10251028

@@ -1045,7 +1048,7 @@ func BenchmarkSelectRowsScanDecoder(b *testing.B) {
10451048
for i := 0; i < b.N; i++ {
10461049
rows, err := conn.Query(
10471050
context.Background(),
1048-
"select n, 'Adam', 'Smith ' || n, 'male', '1952-06-16'::date, 258, 72, '2001-01-28 01:02:03-05'::timestamptz from generate_series(100001, 100000 + $1) n",
1051+
"select n, 'Adam', 'Smith ' || n, 'male', '1952-06-16'::date, 258, 72, '{foo,bar,baz}'::text[], '2001-01-28 01:02:03-05'::timestamptz from generate_series(100001, 100000 + $1) n",
10491052
pgx.QueryResultFormats{format.code},
10501053
rowCount,
10511054
)
@@ -1054,7 +1057,7 @@ func BenchmarkSelectRowsScanDecoder(b *testing.B) {
10541057
}
10551058

10561059
for rows.Next() {
1057-
rows.Scan(&br.ID, &br.FirstName, &br.LastName, &br.Sex, &br.BirthDate, &br.Weight, &br.Height, &br.UpdateTime)
1060+
rows.Scan(&br.ID, &br.FirstName, &br.LastName, &br.Sex, &br.BirthDate, &br.Weight, &br.Height, &br.Tags, &br.UpdateTime)
10581061
}
10591062

10601063
if rows.Err() != nil {
@@ -1076,7 +1079,7 @@ func BenchmarkSelectRowsPgConnExecText(b *testing.B) {
10761079
for _, rowCount := range rowCounts {
10771080
b.Run(fmt.Sprintf("%d rows", rowCount), func(b *testing.B) {
10781081
for i := 0; i < b.N; i++ {
1079-
mrr := conn.PgConn().Exec(context.Background(), fmt.Sprintf("select n, 'Adam', 'Smith ' || n, 'male', '1952-06-16'::date, 258, 72, '2001-01-28 01:02:03-05'::timestamptz from generate_series(100001, 100000 + %d) n", rowCount))
1082+
mrr := conn.PgConn().Exec(context.Background(), fmt.Sprintf("select n, 'Adam', 'Smith ' || n, 'male', '1952-06-16'::date, 258, 72, '{foo,bar,baz}'::text[], '2001-01-28 01:02:03-05'::timestamptz from generate_series(100001, 100000 + %d) n", rowCount))
10801083
for mrr.NextResult() {
10811084
rr := mrr.ResultReader()
10821085
for rr.NextRow() {
@@ -1113,11 +1116,11 @@ func BenchmarkSelectRowsPgConnExecParams(b *testing.B) {
11131116
for i := 0; i < b.N; i++ {
11141117
rr := conn.PgConn().ExecParams(
11151118
context.Background(),
1116-
"select n, 'Adam', 'Smith ' || n, 'male', '1952-06-16'::date, 258, 72, '2001-01-28 01:02:03-05'::timestamptz from generate_series(100001, 100000 + $1) n",
1119+
"select n, 'Adam', 'Smith ' || n, 'male', '1952-06-16'::date, 258, 72, '{foo,bar,baz}'::text[], '2001-01-28 01:02:03-05'::timestamptz from generate_series(100001, 100000 + $1) n",
11171120
[][]byte{[]byte(strconv.FormatInt(rowCount, 10))},
11181121
nil,
11191122
nil,
1120-
[]int16{format.code, pgx.TextFormatCode, pgx.TextFormatCode, pgx.TextFormatCode, format.code, format.code, format.code, format.code},
1123+
[]int16{format.code, pgx.TextFormatCode, pgx.TextFormatCode, pgx.TextFormatCode, format.code, format.code, format.code, format.code, format.code},
11211124
)
11221125
for rr.NextRow() {
11231126
rr.Values()
@@ -1143,7 +1146,7 @@ func BenchmarkSelectRowsSimpleCollectRowsRowToStructByPos(b *testing.B) {
11431146
for _, rowCount := range rowCounts {
11441147
b.Run(fmt.Sprintf("%d rows", rowCount), func(b *testing.B) {
11451148
for i := 0; i < b.N; i++ {
1146-
rows, _ := conn.Query(context.Background(), "select n, 'Adam', 'Smith ' || n, 'male', '1952-06-16'::date, 258, 72, '2001-01-28 01:02:03-05'::timestamptz from generate_series(100001, 100000 + $1) n", rowCount)
1149+
rows, _ := conn.Query(context.Background(), "select n, 'Adam', 'Smith ' || n, 'male', '1952-06-16'::date, 258, 72, '{foo,bar,baz}'::text[], '2001-01-28 01:02:03-05'::timestamptz from generate_series(100001, 100000 + $1) n", rowCount)
11471150
benchRows, err := pgx.CollectRows(rows, pgx.RowToStructByPos[BenchRowSimple])
11481151
if err != nil {
11491152
b.Fatal(err)
@@ -1167,7 +1170,7 @@ func BenchmarkSelectRowsSimpleAppendRowsRowToStructByPos(b *testing.B) {
11671170
benchRows := make([]BenchRowSimple, 0, rowCount)
11681171
for i := 0; i < b.N; i++ {
11691172
benchRows = benchRows[:0]
1170-
rows, _ := conn.Query(context.Background(), "select n, 'Adam', 'Smith ' || n, 'male', '1952-06-16'::date, 258, 72, '2001-01-28 01:02:03-05'::timestamptz from generate_series(100001, 100000 + $1) n", rowCount)
1173+
rows, _ := conn.Query(context.Background(), "select n, 'Adam', 'Smith ' || n, 'male', '1952-06-16'::date, 258, 72, '{foo,bar,baz}'::text[], '2001-01-28 01:02:03-05'::timestamptz from generate_series(100001, 100000 + $1) n", rowCount)
11711174
var err error
11721175
benchRows, err = pgx.AppendRows(benchRows, rows, pgx.RowToStructByPos[BenchRowSimple])
11731176
if err != nil {
@@ -1190,7 +1193,7 @@ func BenchmarkSelectRowsSimpleCollectRowsRowToStructByName(b *testing.B) {
11901193
for _, rowCount := range rowCounts {
11911194
b.Run(fmt.Sprintf("%d rows", rowCount), func(b *testing.B) {
11921195
for i := 0; i < b.N; i++ {
1193-
rows, _ := conn.Query(context.Background(), "select n as id, 'Adam' as first_name, 'Smith ' || n as last_name, 'male' as sex, '1952-06-16'::date as birth_date, 258 as weight, 72 as height, '2001-01-28 01:02:03-05'::timestamptz as update_time from generate_series(100001, 100000 + $1) n", rowCount)
1196+
rows, _ := conn.Query(context.Background(), "select n as id, 'Adam' as first_name, 'Smith ' || n as last_name, 'male' as sex, '1952-06-16'::date as birth_date, 258 as weight, 72 as height, '{foo,bar,baz}'::text[] as tags, '2001-01-28 01:02:03-05'::timestamptz as update_time from generate_series(100001, 100000 + $1) n", rowCount)
11941197
benchRows, err := pgx.CollectRows(rows, pgx.RowToStructByName[BenchRowSimple])
11951198
if err != nil {
11961199
b.Fatal(err)
@@ -1214,7 +1217,7 @@ func BenchmarkSelectRowsSimpleAppendRowsRowToStructByName(b *testing.B) {
12141217
benchRows := make([]BenchRowSimple, 0, rowCount)
12151218
for i := 0; i < b.N; i++ {
12161219
benchRows = benchRows[:0]
1217-
rows, _ := conn.Query(context.Background(), "select n as id, 'Adam' as first_name, 'Smith ' || n as last_name, 'male' as sex, '1952-06-16'::date as birth_date, 258 as weight, 72 as height, '2001-01-28 01:02:03-05'::timestamptz as update_time from generate_series(100001, 100000 + $1) n", rowCount)
1220+
rows, _ := conn.Query(context.Background(), "select n, 'Adam', 'Smith ' || n, 'male', '1952-06-16'::date, 258, 72, '{foo,bar,baz}'::text[], '2001-01-28 01:02:03-05'::timestamptz from generate_series(100001, 100000 + $1) n", rowCount)
12181221
var err error
12191222
benchRows, err = pgx.AppendRows(benchRows, rows, pgx.RowToStructByPos[BenchRowSimple])
12201223
if err != nil {
@@ -1234,7 +1237,7 @@ func BenchmarkSelectRowsPgConnExecPrepared(b *testing.B) {
12341237

12351238
rowCounts := getSelectRowsCounts(b)
12361239

1237-
_, err := conn.PgConn().Prepare(context.Background(), "ps1", "select n, 'Adam', 'Smith ' || n, 'male', '1952-06-16'::date, 258, 72, '2001-01-28 01:02:03-05'::timestamptz from generate_series(100001, 100000 + $1) n", nil)
1240+
_, err := conn.PgConn().Prepare(context.Background(), "ps1", "select n, 'Adam', 'Smith ' || n, 'male', '1952-06-16'::date, 258, 72, '{foo,bar,baz}'::text[], '2001-01-28 01:02:03-05'::timestamptz from generate_series(100001, 100000 + $1) n", nil)
12381241
if err != nil {
12391242
b.Fatal(err)
12401243
}
@@ -1256,7 +1259,7 @@ func BenchmarkSelectRowsPgConnExecPrepared(b *testing.B) {
12561259
"ps1",
12571260
[][]byte{[]byte(strconv.FormatInt(rowCount, 10))},
12581261
nil,
1259-
[]int16{format.code, pgx.TextFormatCode, pgx.TextFormatCode, pgx.TextFormatCode, format.code, format.code, format.code, format.code},
1262+
[]int16{format.code, pgx.TextFormatCode, pgx.TextFormatCode, pgx.TextFormatCode, format.code, format.code, format.code, format.code, format.code},
12601263
)
12611264
for rr.NextRow() {
12621265
rr.Values()
@@ -1335,7 +1338,7 @@ func BenchmarkSelectRowsRawPrepared(b *testing.B) {
13351338
conn := mustConnectString(b, os.Getenv("PGX_TEST_DATABASE")).PgConn()
13361339
defer conn.Close(context.Background())
13371340

1338-
_, err := conn.Prepare(context.Background(), "ps1", "select n, 'Adam', 'Smith ' || n, 'male', '1952-06-16'::date, 258, 72, '2001-01-28 01:02:03-05'::timestamptz from generate_series(100001, 100000 + $1) n", nil)
1341+
_, err := conn.Prepare(context.Background(), "ps1", "select n, 'Adam', 'Smith ' || n, 'male', '1952-06-16'::date, 258, 72, '{foo,bar,baz}'::text[], '2001-01-28 01:02:03-05'::timestamptz from generate_series(100001, 100000 + $1) n", nil)
13391342
if err != nil {
13401343
b.Fatal(err)
13411344
}
@@ -1358,7 +1361,7 @@ func BenchmarkSelectRowsRawPrepared(b *testing.B) {
13581361
"ps1",
13591362
[][]byte{[]byte(strconv.FormatInt(rowCount, 10))},
13601363
nil,
1361-
[]int16{format.code, pgx.TextFormatCode, pgx.TextFormatCode, pgx.TextFormatCode, format.code, format.code, format.code, format.code},
1364+
[]int16{format.code, pgx.TextFormatCode, pgx.TextFormatCode, pgx.TextFormatCode, format.code, format.code, format.code, format.code, format.code},
13621365
)
13631366
_, err := rr.Close()
13641367
require.NoError(b, err)

pgtype/pgtype.go

+3-29
Original file line numberDiff line numberDiff line change
@@ -202,7 +202,6 @@ type Map struct {
202202

203203
reflectTypeToType map[reflect.Type]*Type
204204

205-
memoizedScanPlans map[uint32]map[reflect.Type][2]ScanPlan
206205
memoizedEncodePlans map[uint32]map[reflect.Type][2]EncodePlan
207206

208207
// TryWrapEncodePlanFuncs is a slice of functions that will wrap a value that cannot be encoded by the Codec. Every
@@ -236,7 +235,6 @@ func NewMap() *Map {
236235
reflectTypeToName: make(map[reflect.Type]string),
237236
oidToFormatCode: make(map[uint32]int16),
238237

239-
memoizedScanPlans: make(map[uint32]map[reflect.Type][2]ScanPlan),
240238
memoizedEncodePlans: make(map[uint32]map[reflect.Type][2]EncodePlan),
241239

242240
TryWrapEncodePlanFuncs: []TryWrapEncodePlanFunc{
@@ -276,9 +274,6 @@ func (m *Map) RegisterType(t *Type) {
276274

277275
// Invalidated by type registration
278276
m.reflectTypeToType = nil
279-
for k := range m.memoizedScanPlans {
280-
delete(m.memoizedScanPlans, k)
281-
}
282277
for k := range m.memoizedEncodePlans {
283278
delete(m.memoizedEncodePlans, k)
284279
}
@@ -292,9 +287,6 @@ func (m *Map) RegisterDefaultPgType(value any, name string) {
292287

293288
// Invalidated by type registration
294289
m.reflectTypeToType = nil
295-
for k := range m.memoizedScanPlans {
296-
delete(m.memoizedScanPlans, k)
297-
}
298290
for k := range m.memoizedEncodePlans {
299291
delete(m.memoizedEncodePlans, k)
300292
}
@@ -1067,32 +1059,14 @@ func (plan *wrapPtrArrayReflectScanPlan) Scan(src []byte, target any) error {
10671059

10681060
// PlanScan prepares a plan to scan a value into target.
10691061
func (m *Map) PlanScan(oid uint32, formatCode int16, target any) ScanPlan {
1070-
return m.planScanDepth(oid, formatCode, target, 0)
1062+
return m.planScan(oid, formatCode, target, 0)
10711063
}
10721064

1073-
func (m *Map) planScanDepth(oid uint32, formatCode int16, target any, depth int) ScanPlan {
1065+
func (m *Map) planScan(oid uint32, formatCode int16, target any, depth int) ScanPlan {
10741066
if depth > 8 {
10751067
return &scanPlanFail{m: m, oid: oid, formatCode: formatCode}
10761068
}
10771069

1078-
oidMemo := m.memoizedScanPlans[oid]
1079-
if oidMemo == nil {
1080-
oidMemo = make(map[reflect.Type][2]ScanPlan)
1081-
m.memoizedScanPlans[oid] = oidMemo
1082-
}
1083-
targetReflectType := reflect.TypeOf(target)
1084-
typeMemo := oidMemo[targetReflectType]
1085-
plan := typeMemo[formatCode]
1086-
if plan == nil {
1087-
plan = m.planScan(oid, formatCode, target, depth)
1088-
typeMemo[formatCode] = plan
1089-
oidMemo[targetReflectType] = typeMemo
1090-
}
1091-
1092-
return plan
1093-
}
1094-
1095-
func (m *Map) planScan(oid uint32, formatCode int16, target any, depth int) ScanPlan {
10961070
if target == nil {
10971071
return &scanPlanFail{m: m, oid: oid, formatCode: formatCode}
10981072
}
@@ -1152,7 +1126,7 @@ func (m *Map) planScan(oid uint32, formatCode int16, target any, depth int) Scan
11521126

11531127
for _, f := range m.TryWrapScanPlanFuncs {
11541128
if wrapperPlan, nextDst, ok := f(target); ok {
1155-
if nextPlan := m.planScanDepth(oid, formatCode, nextDst, depth+1); nextPlan != nil {
1129+
if nextPlan := m.planScan(oid, formatCode, nextDst, depth+1); nextPlan != nil {
11561130
if _, failed := nextPlan.(*scanPlanFail); !failed {
11571131
wrapperPlan.SetNext(nextPlan)
11581132
return wrapperPlan

pgtype/pgtype_default.go

-1
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@ func initDefaultMap() {
2323
reflectTypeToName: make(map[reflect.Type]string),
2424
oidToFormatCode: make(map[uint32]int16),
2525

26-
memoizedScanPlans: make(map[uint32]map[reflect.Type][2]ScanPlan),
2726
memoizedEncodePlans: make(map[uint32]map[reflect.Type][2]EncodePlan),
2827

2928
TryWrapEncodePlanFuncs: []TryWrapEncodePlanFunc{

pgtype/pgtype_test.go

+15
Original file line numberDiff line numberDiff line change
@@ -555,6 +555,21 @@ func TestMapEncodeRawJSONIntoUnknownOID(t *testing.T) {
555555
require.Equal(t, []byte(`{"foo": "bar"}`), buf)
556556
}
557557

558+
// PlanScan previously used a cache to improve performance. However, the cache could get confused in certain cases. The
559+
// example below was one such failure case.
560+
func TestCachedPlanScanConfusion(t *testing.T) {
561+
m := pgtype.NewMap()
562+
var err error
563+
564+
var tags any
565+
err = m.Scan(pgtype.TextArrayOID, pgx.TextFormatCode, []byte("{foo,bar,baz}"), &tags)
566+
require.NoError(t, err)
567+
568+
var cells [][]string
569+
err = m.Scan(pgtype.TextArrayOID, pgx.TextFormatCode, []byte("{{foo,bar},{baz,quz}}"), &cells)
570+
require.NoError(t, err)
571+
}
572+
558573
func BenchmarkMapScanInt4IntoBinaryDecoder(b *testing.B) {
559574
m := pgtype.NewMap()
560575
src := []byte{0, 0, 0, 42}

0 commit comments

Comments
 (0)