diff --git a/dbscan/dbscan.go b/dbscan/dbscan.go index 66da180..1a4faf9 100644 --- a/dbscan/dbscan.go +++ b/dbscan/dbscan.go @@ -299,19 +299,37 @@ func (api *API) parseSliceDestination(dst interface{}) (*sliceDestinationMeta, e } func scanSliceElement(rs *RowScanner, sliceMeta *sliceDestinationMeta) error { - dstValPtr := reflect.New(sliceMeta.elementBaseType) + s := sliceMeta.val + l := s.Len() + growSliceByOne(s) + var dstValPtr reflect.Value + if sliceMeta.elementByPtr { + dstValPtr = reflect.New(sliceMeta.elementBaseType) + s.Index(l).Set(dstValPtr) + } else { + dstValPtr = s.Index(l).Addr() + } if err := rs.Scan(dstValPtr.Interface()); err != nil { + // Undo growing the slice. Zero the value to ensure it doesn't retain garbage. + s.Index(l).Set(reflect.Zero(s.Type().Elem())) + s.SetLen(l) return fmt.Errorf("scanning: %w", err) } - var elemVal reflect.Value - if sliceMeta.elementByPtr { - elemVal = dstValPtr - } else { - elemVal = dstValPtr.Elem() + return nil +} + +func growSliceByOne(s reflect.Value) { + // In go 1.20 and above, this could be made simpler (and possibly more efficient) + // by using Value.Grow. + l := s.Len() + c := s.Cap() + if l < c { + s.SetLen(l + 1) + return } - sliceMeta.val.Set(reflect.Append(sliceMeta.val, elemVal)) - return nil + t := s.Type().Elem() + s.Set(reflect.Append(s, reflect.Zero(t))) } // ScanRow is a package-level helper function that uses the DefaultAPI object. diff --git a/dbscan/rowscanner.go b/dbscan/rowscanner.go index 0e9f6bf..b1df194 100644 --- a/dbscan/rowscanner.go +++ b/dbscan/rowscanner.go @@ -34,6 +34,7 @@ type RowScanner struct { started bool scanFn func(dstVal reflect.Value) error start startScannerFunc + scans []any } // NewRowScanner is a package-level helper function that uses the DefaultAPI object. @@ -124,13 +125,15 @@ func startScanner(rs *RowScanner, dstValue reflect.Value) error { } func (rs *RowScanner) scanStruct(structValue reflect.Value) error { - scans := make([]interface{}, len(rs.columns)) + if rs.scans == nil { + rs.scans = make([]interface{}, len(rs.columns)) + } for i, column := range rs.columns { fieldIndex, ok := rs.columnToFieldIndex[column] if !ok { if rs.api.allowUnknownColumns { var tmp interface{} - scans[i] = &tmp + rs.scans[i] = &tmp continue } return fmt.Errorf( @@ -144,9 +147,9 @@ func (rs *RowScanner) scanStruct(structValue reflect.Value) error { initializeNested(structValue, fieldIndex) fieldVal := structValue.FieldByIndex(fieldIndex) - scans[i] = fieldVal.Addr().Interface() + rs.scans[i] = fieldVal.Addr().Interface() } - if err := rs.rows.Scan(scans...); err != nil { + if err := rs.rows.Scan(rs.scans...); err != nil { return fmt.Errorf("scany: scan row into struct fields: %w", err) } return nil @@ -157,14 +160,16 @@ func (rs *RowScanner) scanMap(mapValue reflect.Value) error { mapValue.Set(reflect.MakeMap(mapValue.Type())) } - scans := make([]interface{}, len(rs.columns)) + if rs.scans == nil { + rs.scans = make([]interface{}, len(rs.columns)) + } values := make([]reflect.Value, len(rs.columns)) for i := range rs.columns { valuePtr := reflect.New(rs.mapElementType) - scans[i] = valuePtr.Interface() + rs.scans[i] = valuePtr.Interface() values[i] = valuePtr.Elem() } - if err := rs.rows.Scan(scans...); err != nil { + if err := rs.rows.Scan(rs.scans...); err != nil { return fmt.Errorf("scany: scan rows into map: %w", err) } // We can't set reflect values into destination map before scanning them, @@ -179,7 +184,11 @@ func (rs *RowScanner) scanMap(mapValue reflect.Value) error { } func (rs *RowScanner) scanPrimitive(value reflect.Value) error { - if err := rs.rows.Scan(value.Addr().Interface()); err != nil { + if rs.scans == nil { + rs.scans = make([]interface{}, 1) + } + rs.scans[0] = value.Addr().Interface() + if err := rs.rows.Scan(rs.scans...); err != nil { return fmt.Errorf("scany: scan row value into a primitive type: %w", err) } return nil