Skip to content

Commit

Permalink
Ensure planning encodes and scans cannot infinitely recurse
Browse files Browse the repository at this point in the history
  • Loading branch information
jackc committed Oct 5, 2024
1 parent a95cfbb commit 123b59a
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 10 deletions.
2 changes: 1 addition & 1 deletion pgtype/json.go
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ func (c *JSONCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanP
// https://github.com/jackc/pgx/issues/1691 -- ** anything else

if wrapperPlan, nextDst, ok := TryPointerPointerScanPlan(target); ok {
if nextPlan := m.planScan(oid, format, nextDst); nextPlan != nil {
if nextPlan := m.planScan(oid, format, nextDst, 0); nextPlan != nil {
if _, failed := nextPlan.(*scanPlanFail); !failed {
wrapperPlan.SetNext(nextPlan)
return wrapperPlan
Expand Down
33 changes: 25 additions & 8 deletions pgtype/pgtype.go
Original file line number Diff line number Diff line change
Expand Up @@ -449,14 +449,14 @@ func (plan *scanPlanFail) Scan(src []byte, dst any) error {
// As a horrible hack try all types to find anything that can scan into dst.
for oid := range plan.m.oidToType {
// using planScan instead of Scan or PlanScan to avoid polluting the planned scan cache.
plan := plan.m.planScan(oid, plan.formatCode, dst)
plan := plan.m.planScan(oid, plan.formatCode, dst, 0)
if _, ok := plan.(*scanPlanFail); !ok {
return plan.Scan(src, dst)
}
}
for oid := range defaultMap.oidToType {
if _, ok := plan.m.oidToType[oid]; !ok {
plan := plan.m.planScan(oid, plan.formatCode, dst)
plan := plan.m.planScan(oid, plan.formatCode, dst, 0)
if _, ok := plan.(*scanPlanFail); !ok {
return plan.Scan(src, dst)
}
Expand Down Expand Up @@ -1064,6 +1064,14 @@ func (plan *wrapPtrArrayReflectScanPlan) Scan(src []byte, target any) error {

// PlanScan prepares a plan to scan a value into target.
func (m *Map) PlanScan(oid uint32, formatCode int16, target any) ScanPlan {
return m.planScanDepth(oid, formatCode, target, 0)
}

func (m *Map) planScanDepth(oid uint32, formatCode int16, target any, depth int) ScanPlan {
if depth > 8 {
return &scanPlanFail{m: m, oid: oid, formatCode: formatCode}
}

oidMemo := m.memoizedScanPlans[oid]
if oidMemo == nil {
oidMemo = make(map[reflect.Type][2]ScanPlan)
Expand All @@ -1073,15 +1081,15 @@ func (m *Map) PlanScan(oid uint32, formatCode int16, target any) ScanPlan {
typeMemo := oidMemo[targetReflectType]
plan := typeMemo[formatCode]
if plan == nil {
plan = m.planScan(oid, formatCode, target)
plan = m.planScan(oid, formatCode, target, depth)
typeMemo[formatCode] = plan
oidMemo[targetReflectType] = typeMemo
}

return plan
}

func (m *Map) planScan(oid uint32, formatCode int16, target any) ScanPlan {
func (m *Map) planScan(oid uint32, formatCode int16, target any, depth int) ScanPlan {
if target == nil {
return &scanPlanFail{m: m, oid: oid, formatCode: formatCode}
}
Expand Down Expand Up @@ -1141,7 +1149,7 @@ func (m *Map) planScan(oid uint32, formatCode int16, target any) ScanPlan {

for _, f := range m.TryWrapScanPlanFuncs {
if wrapperPlan, nextDst, ok := f(target); ok {
if nextPlan := m.planScan(oid, formatCode, nextDst); nextPlan != nil {
if nextPlan := m.planScanDepth(oid, formatCode, nextDst, depth+1); nextPlan != nil {
if _, failed := nextPlan.(*scanPlanFail); !failed {
wrapperPlan.SetNext(nextPlan)
return wrapperPlan
Expand Down Expand Up @@ -1201,6 +1209,15 @@ func codecDecodeToTextFormat(codec Codec, m *Map, oid uint32, format int16, src
// PlanEncode returns an Encode plan for encoding value into PostgreSQL format for oid and format. If no plan can be
// found then nil is returned.
func (m *Map) PlanEncode(oid uint32, format int16, value any) EncodePlan {
return m.planEncodeDepth(oid, format, value, 0)
}

func (m *Map) planEncodeDepth(oid uint32, format int16, value any, depth int) EncodePlan {
// Guard against infinite recursion.
if depth > 8 {
return nil
}

oidMemo := m.memoizedEncodePlans[oid]
if oidMemo == nil {
oidMemo = make(map[reflect.Type][2]EncodePlan)
Expand All @@ -1210,15 +1227,15 @@ func (m *Map) PlanEncode(oid uint32, format int16, value any) EncodePlan {
typeMemo := oidMemo[targetReflectType]
plan := typeMemo[format]
if plan == nil {
plan = m.planEncode(oid, format, value)
plan = m.planEncode(oid, format, value, depth)
typeMemo[format] = plan
oidMemo[targetReflectType] = typeMemo
}

return plan
}

func (m *Map) planEncode(oid uint32, format int16, value any) EncodePlan {
func (m *Map) planEncode(oid uint32, format int16, value any, depth int) EncodePlan {
if format == TextFormatCode {
switch value.(type) {
case string:
Expand Down Expand Up @@ -1249,7 +1266,7 @@ func (m *Map) planEncode(oid uint32, format int16, value any) EncodePlan {

for _, f := range m.TryWrapEncodePlanFuncs {
if wrapperPlan, nextValue, ok := f(value); ok {
if nextPlan := m.PlanEncode(oid, format, nextValue); nextPlan != nil {
if nextPlan := m.planEncodeDepth(oid, format, nextValue, depth+1); nextPlan != nil {
wrapperPlan.SetNext(nextPlan)
return wrapperPlan
}
Expand Down
2 changes: 1 addition & 1 deletion pgtype/xml.go
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ func (c *XMLCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPl
// https://github.com/jackc/pgx/issues/1691 -- ** anything else

if wrapperPlan, nextDst, ok := TryPointerPointerScanPlan(target); ok {
if nextPlan := m.planScan(oid, format, nextDst); nextPlan != nil {
if nextPlan := m.planScan(oid, format, nextDst, 0); nextPlan != nil {
if _, failed := nextPlan.(*scanPlanFail); !failed {
wrapperPlan.SetNext(nextPlan)
return wrapperPlan
Expand Down

0 comments on commit 123b59a

Please sign in to comment.