From 123b59a57e453d73779eeb5a74ec25195da7d34b Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 5 Oct 2024 12:20:50 -0500 Subject: [PATCH] Ensure planning encodes and scans cannot infinitely recurse https://github.com/jackc/pgx/issues/2141 --- pgtype/json.go | 2 +- pgtype/pgtype.go | 33 +++++++++++++++++++++++++-------- pgtype/xml.go | 2 +- 3 files changed, 27 insertions(+), 10 deletions(-) diff --git a/pgtype/json.go b/pgtype/json.go index c2aa0d3bf..f65fa492e 100644 --- a/pgtype/json.go +++ b/pgtype/json.go @@ -130,7 +130,7 @@ func (c *JSONCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanP // https://github.com/jackc/pgx/issues/1691 -- ** anything else if wrapperPlan, nextDst, ok := TryPointerPointerScanPlan(target); ok { - if nextPlan := m.planScan(oid, format, nextDst); nextPlan != nil { + if nextPlan := m.planScan(oid, format, nextDst, 0); nextPlan != nil { if _, failed := nextPlan.(*scanPlanFail); !failed { wrapperPlan.SetNext(nextPlan) return wrapperPlan diff --git a/pgtype/pgtype.go b/pgtype/pgtype.go index bdd9f05ca..75445d43c 100644 --- a/pgtype/pgtype.go +++ b/pgtype/pgtype.go @@ -449,14 +449,14 @@ func (plan *scanPlanFail) Scan(src []byte, dst any) error { // As a horrible hack try all types to find anything that can scan into dst. for oid := range plan.m.oidToType { // using planScan instead of Scan or PlanScan to avoid polluting the planned scan cache. - plan := plan.m.planScan(oid, plan.formatCode, dst) + plan := plan.m.planScan(oid, plan.formatCode, dst, 0) if _, ok := plan.(*scanPlanFail); !ok { return plan.Scan(src, dst) } } for oid := range defaultMap.oidToType { if _, ok := plan.m.oidToType[oid]; !ok { - plan := plan.m.planScan(oid, plan.formatCode, dst) + plan := plan.m.planScan(oid, plan.formatCode, dst, 0) if _, ok := plan.(*scanPlanFail); !ok { return plan.Scan(src, dst) } @@ -1064,6 +1064,14 @@ func (plan *wrapPtrArrayReflectScanPlan) Scan(src []byte, target any) error { // PlanScan prepares a plan to scan a value into target. func (m *Map) PlanScan(oid uint32, formatCode int16, target any) ScanPlan { + return m.planScanDepth(oid, formatCode, target, 0) +} + +func (m *Map) planScanDepth(oid uint32, formatCode int16, target any, depth int) ScanPlan { + if depth > 8 { + return &scanPlanFail{m: m, oid: oid, formatCode: formatCode} + } + oidMemo := m.memoizedScanPlans[oid] if oidMemo == nil { oidMemo = make(map[reflect.Type][2]ScanPlan) @@ -1073,7 +1081,7 @@ func (m *Map) PlanScan(oid uint32, formatCode int16, target any) ScanPlan { typeMemo := oidMemo[targetReflectType] plan := typeMemo[formatCode] if plan == nil { - plan = m.planScan(oid, formatCode, target) + plan = m.planScan(oid, formatCode, target, depth) typeMemo[formatCode] = plan oidMemo[targetReflectType] = typeMemo } @@ -1081,7 +1089,7 @@ func (m *Map) PlanScan(oid uint32, formatCode int16, target any) ScanPlan { return plan } -func (m *Map) planScan(oid uint32, formatCode int16, target any) ScanPlan { +func (m *Map) planScan(oid uint32, formatCode int16, target any, depth int) ScanPlan { if target == nil { return &scanPlanFail{m: m, oid: oid, formatCode: formatCode} } @@ -1141,7 +1149,7 @@ func (m *Map) planScan(oid uint32, formatCode int16, target any) ScanPlan { for _, f := range m.TryWrapScanPlanFuncs { if wrapperPlan, nextDst, ok := f(target); ok { - if nextPlan := m.planScan(oid, formatCode, nextDst); nextPlan != nil { + if nextPlan := m.planScanDepth(oid, formatCode, nextDst, depth+1); nextPlan != nil { if _, failed := nextPlan.(*scanPlanFail); !failed { wrapperPlan.SetNext(nextPlan) return wrapperPlan @@ -1201,6 +1209,15 @@ func codecDecodeToTextFormat(codec Codec, m *Map, oid uint32, format int16, src // PlanEncode returns an Encode plan for encoding value into PostgreSQL format for oid and format. If no plan can be // found then nil is returned. func (m *Map) PlanEncode(oid uint32, format int16, value any) EncodePlan { + return m.planEncodeDepth(oid, format, value, 0) +} + +func (m *Map) planEncodeDepth(oid uint32, format int16, value any, depth int) EncodePlan { + // Guard against infinite recursion. + if depth > 8 { + return nil + } + oidMemo := m.memoizedEncodePlans[oid] if oidMemo == nil { oidMemo = make(map[reflect.Type][2]EncodePlan) @@ -1210,7 +1227,7 @@ func (m *Map) PlanEncode(oid uint32, format int16, value any) EncodePlan { typeMemo := oidMemo[targetReflectType] plan := typeMemo[format] if plan == nil { - plan = m.planEncode(oid, format, value) + plan = m.planEncode(oid, format, value, depth) typeMemo[format] = plan oidMemo[targetReflectType] = typeMemo } @@ -1218,7 +1235,7 @@ func (m *Map) PlanEncode(oid uint32, format int16, value any) EncodePlan { return plan } -func (m *Map) planEncode(oid uint32, format int16, value any) EncodePlan { +func (m *Map) planEncode(oid uint32, format int16, value any, depth int) EncodePlan { if format == TextFormatCode { switch value.(type) { case string: @@ -1249,7 +1266,7 @@ func (m *Map) planEncode(oid uint32, format int16, value any) EncodePlan { for _, f := range m.TryWrapEncodePlanFuncs { if wrapperPlan, nextValue, ok := f(value); ok { - if nextPlan := m.PlanEncode(oid, format, nextValue); nextPlan != nil { + if nextPlan := m.planEncodeDepth(oid, format, nextValue, depth+1); nextPlan != nil { wrapperPlan.SetNext(nextPlan) return wrapperPlan } diff --git a/pgtype/xml.go b/pgtype/xml.go index fb4c49ad9..79e3698a4 100644 --- a/pgtype/xml.go +++ b/pgtype/xml.go @@ -113,7 +113,7 @@ func (c *XMLCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPl // https://github.com/jackc/pgx/issues/1691 -- ** anything else if wrapperPlan, nextDst, ok := TryPointerPointerScanPlan(target); ok { - if nextPlan := m.planScan(oid, format, nextDst); nextPlan != nil { + if nextPlan := m.planScan(oid, format, nextDst, 0); nextPlan != nil { if _, failed := nextPlan.(*scanPlanFail); !failed { wrapperPlan.SetNext(nextPlan) return wrapperPlan