Skip to content

Commit

Permalink
Reduce per-row allocations (#130)
Browse files Browse the repository at this point in the history
Modify several internal functions to avoid allocating unnecessary
memory. This commit employs three strategies:

- Allocate a scans slice only when scanning the first row, and store it
  inside the RowScanner to reuse on subsequent rows.
- When scanning into slices, use more sophistocated reflection code to
  extend the slie while avoid temporary slice header allocations.
- When scanning into slices of non-pointers, scan directly into the
  slice index, rather than allocate a new value and copy after scanning.
  • Loading branch information
zolstein authored Mar 24, 2024
1 parent 981a63a commit 041a992
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 16 deletions.
34 changes: 26 additions & 8 deletions dbscan/dbscan.go
Original file line number Diff line number Diff line change
Expand Up @@ -299,19 +299,37 @@ func (api *API) parseSliceDestination(dst interface{}) (*sliceDestinationMeta, e
}

func scanSliceElement(rs *RowScanner, sliceMeta *sliceDestinationMeta) error {
dstValPtr := reflect.New(sliceMeta.elementBaseType)
s := sliceMeta.val
l := s.Len()
growSliceByOne(s)
var dstValPtr reflect.Value
if sliceMeta.elementByPtr {
dstValPtr = reflect.New(sliceMeta.elementBaseType)
s.Index(l).Set(dstValPtr)
} else {
dstValPtr = s.Index(l).Addr()
}
if err := rs.Scan(dstValPtr.Interface()); err != nil {
// Undo growing the slice. Zero the value to ensure it doesn't retain garbage.
s.Index(l).Set(reflect.Zero(s.Type().Elem()))
s.SetLen(l)
return fmt.Errorf("scanning: %w", err)
}
var elemVal reflect.Value
if sliceMeta.elementByPtr {
elemVal = dstValPtr
} else {
elemVal = dstValPtr.Elem()
return nil
}

func growSliceByOne(s reflect.Value) {
// In go 1.20 and above, this could be made simpler (and possibly more efficient)
// by using Value.Grow.
l := s.Len()
c := s.Cap()
if l < c {
s.SetLen(l + 1)
return
}

sliceMeta.val.Set(reflect.Append(sliceMeta.val, elemVal))
return nil
t := s.Type().Elem()
s.Set(reflect.Append(s, reflect.Zero(t)))
}

// ScanRow is a package-level helper function that uses the DefaultAPI object.
Expand Down
25 changes: 17 additions & 8 deletions dbscan/rowscanner.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ type RowScanner struct {
started bool
scanFn func(dstVal reflect.Value) error
start startScannerFunc
scans []any
}

// NewRowScanner is a package-level helper function that uses the DefaultAPI object.
Expand Down Expand Up @@ -124,13 +125,15 @@ func startScanner(rs *RowScanner, dstValue reflect.Value) error {
}

func (rs *RowScanner) scanStruct(structValue reflect.Value) error {
scans := make([]interface{}, len(rs.columns))
if rs.scans == nil {
rs.scans = make([]interface{}, len(rs.columns))
}
for i, column := range rs.columns {
fieldIndex, ok := rs.columnToFieldIndex[column]
if !ok {
if rs.api.allowUnknownColumns {
var tmp interface{}
scans[i] = &tmp
rs.scans[i] = &tmp
continue
}
return fmt.Errorf(
Expand All @@ -144,9 +147,9 @@ func (rs *RowScanner) scanStruct(structValue reflect.Value) error {
initializeNested(structValue, fieldIndex)

fieldVal := structValue.FieldByIndex(fieldIndex)
scans[i] = fieldVal.Addr().Interface()
rs.scans[i] = fieldVal.Addr().Interface()
}
if err := rs.rows.Scan(scans...); err != nil {
if err := rs.rows.Scan(rs.scans...); err != nil {
return fmt.Errorf("scany: scan row into struct fields: %w", err)
}
return nil
Expand All @@ -157,14 +160,16 @@ func (rs *RowScanner) scanMap(mapValue reflect.Value) error {
mapValue.Set(reflect.MakeMap(mapValue.Type()))
}

scans := make([]interface{}, len(rs.columns))
if rs.scans == nil {
rs.scans = make([]interface{}, len(rs.columns))
}
values := make([]reflect.Value, len(rs.columns))
for i := range rs.columns {
valuePtr := reflect.New(rs.mapElementType)
scans[i] = valuePtr.Interface()
rs.scans[i] = valuePtr.Interface()
values[i] = valuePtr.Elem()
}
if err := rs.rows.Scan(scans...); err != nil {
if err := rs.rows.Scan(rs.scans...); err != nil {
return fmt.Errorf("scany: scan rows into map: %w", err)
}
// We can't set reflect values into destination map before scanning them,
Expand All @@ -179,7 +184,11 @@ func (rs *RowScanner) scanMap(mapValue reflect.Value) error {
}

func (rs *RowScanner) scanPrimitive(value reflect.Value) error {
if err := rs.rows.Scan(value.Addr().Interface()); err != nil {
if rs.scans == nil {
rs.scans = make([]interface{}, 1)
}
rs.scans[0] = value.Addr().Interface()
if err := rs.rows.Scan(rs.scans...); err != nil {
return fmt.Errorf("scany: scan row value into a primitive type: %w", err)
}
return nil
Expand Down

0 comments on commit 041a992

Please sign in to comment.