diff --git a/pkg/executor/mem_reader.go b/pkg/executor/mem_reader.go index 6fad3d09b9a26..fd2046ef3c5ee 100644 --- a/pkg/executor/mem_reader.go +++ b/pkg/executor/mem_reader.go @@ -91,15 +91,31 @@ func buildMemIndexReader(ctx context.Context, us *UnionScanExec, idxReader *Inde } func (m *memIndexReader) getMemRowsIter(ctx context.Context) (memRowsIter, error) { - data, err := m.getMemRows(ctx) + if m.keepOrder && m.table.GetPartitionInfo() != nil { + data, err := m.getMemRows(ctx) + if err != nil { + return nil, errors.Trace(err) + } + return &defaultRowsIter{data: data}, nil + } + + kvIter, err := newTxnMemBufferIter(m.ctx, m.cacheTable, m.kvRanges, m.desc) if err != nil { return nil, errors.Trace(err) } - return &defaultRowsIter{data: data}, nil + tps := m.getTypes() + colInfos := tables.BuildRowcodecColInfoForIndexColumns(m.index, m.table) + colInfos = tables.TryAppendCommonHandleRowcodecColInfos(colInfos, m.table) + return &memRowsIterForIndex{ + kvIter: kvIter, + tps: tps, + mutableRow: chunk.MutRowFromTypes(m.retFieldTypes), + memIndexReader: m, + colInfos: colInfos, + }, nil } -func (m *memIndexReader) getMemRows(ctx context.Context) ([][]types.Datum, error) { - defer tracing.StartRegion(ctx, "memIndexReader.getMemRows").End() +func (m *memIndexReader) getTypes() []*types.FieldType { tps := make([]*types.FieldType, 0, len(m.index.Columns)+1) cols := m.table.Columns for _, col := range m.index.Columns { @@ -122,10 +138,18 @@ func (m *memIndexReader) getMemRows(ctx context.Context) ([][]types.Datum, error default: // ExtraHandle Column tp. tps = append(tps, types.NewFieldType(mysql.TypeLonglong)) } + return tps +} + +func (m *memIndexReader) getMemRows(ctx context.Context) ([][]types.Datum, error) { + defer tracing.StartRegion(ctx, "memIndexReader.getMemRows").End() + tps := m.getTypes() + colInfos := tables.BuildRowcodecColInfoForIndexColumns(m.index, m.table) + colInfos = tables.TryAppendCommonHandleRowcodecColInfos(colInfos, m.table) mutableRow := chunk.MutRowFromTypes(m.retFieldTypes) err := iterTxnMemBuffer(m.ctx, m.cacheTable, m.kvRanges, m.desc, func(key, value []byte) error { - data, err := m.decodeIndexKeyValue(key, value, tps) + data, err := m.decodeIndexKeyValue(key, value, tps, colInfos) if err != nil { return err } @@ -156,13 +180,11 @@ func (m *memIndexReader) getMemRows(ctx context.Context) ([][]types.Datum, error return m.addedRows, nil } -func (m *memIndexReader) decodeIndexKeyValue(key, value []byte, tps []*types.FieldType) ([]types.Datum, error) { +func (m *memIndexReader) decodeIndexKeyValue(key, value []byte, tps []*types.FieldType, colInfos []rowcodec.ColInfo) ([]types.Datum, error) { hdStatus := tablecodec.HandleDefault if mysql.HasUnsignedFlag(tps[len(tps)-1].GetFlag()) { hdStatus = tablecodec.HandleIsUnsigned } - colInfos := tables.BuildRowcodecColInfoForIndexColumns(m.index, m.table) - colInfos = tables.TryAppendCommonHandleRowcodecColInfos(colInfos, m.table) values, err := tablecodec.DecodeIndexKV(key, value, len(m.index.Columns), hdStatus, colInfos) if err != nil { return nil, errors.Trace(err) @@ -259,108 +281,91 @@ func buildMemTableReader(ctx context.Context, us *UnionScanExec, kvRanges []kv.K } } +// txnMemBufferIter implements a kv.Iterator, it is an iterator that combines the membuffer data and snapshot data. type txnMemBufferIter struct { - *memTableReader - txn kv.Transaction - idx int - curr kv.Iterator - - reverse bool - cd *rowcodec.ChunkDecoder - chk *chunk.Chunk - datumRow []types.Datum + sctx sessionctx.Context + kvRanges []kv.KeyRange + cacheTable kv.MemBuffer + txn kv.Transaction + idx int + curr kv.Iterator + reverse bool + err error } -func (iter *txnMemBufferIter) Next() ([]types.Datum, error) { - var ret []types.Datum - for iter.idx < len(iter.kvRanges) { - if iter.curr == nil { - rg := iter.kvRanges[iter.idx] - var tmp kv.Iterator - if !iter.reverse { - tmp = iter.txn.GetMemBuffer().SnapshotIter(rg.StartKey, rg.EndKey) - } else { - tmp = iter.txn.GetMemBuffer().SnapshotIterReverse(rg.EndKey, rg.StartKey) - } - snapCacheIter, err := getSnapIter(iter.ctx, iter.cacheTable, rg, iter.reverse) - if err != nil { - return nil, err - } - if snapCacheIter != nil { - tmp, err = transaction.NewUnionIter(tmp, snapCacheIter, iter.reverse) - if err != nil { - return nil, err - } - } - iter.curr = tmp - } else { - var err error - ret, err = iter.next() - if err != nil { - return nil, errors.Trace(err) - } - if ret != nil { - break - } - iter.idx++ - iter.curr = nil - } +func newTxnMemBufferIter(sctx sessionctx.Context, cacheTable kv.MemBuffer, kvRanges []kv.KeyRange, reverse bool) (*txnMemBufferIter, error) { + txn, err := sctx.Txn(true) + if err != nil { + return nil, errors.Trace(err) } - return ret, nil + return &txnMemBufferIter{ + sctx: sctx, + txn: txn, + kvRanges: kvRanges, + cacheTable: cacheTable, + reverse: reverse, + }, nil } -func (iter *txnMemBufferIter) next() ([]types.Datum, error) { - var err error - curr := iter.curr - for ; err == nil && curr.Valid(); err = curr.Next() { - // check whether the key was been deleted. - if len(curr.Value()) == 0 { - continue +func (iter *txnMemBufferIter) Valid() bool { + if iter.curr != nil { + if iter.curr.Valid() { + return true } - - handle, err := tablecodec.DecodeRowKey(curr.Key()) + iter.curr = nil + iter.idx++ + } + for iter.idx < len(iter.kvRanges) { + rg := iter.kvRanges[iter.idx] + var tmp kv.Iterator + if !iter.reverse { + tmp = iter.txn.GetMemBuffer().SnapshotIter(rg.StartKey, rg.EndKey) + } else { + tmp = iter.txn.GetMemBuffer().SnapshotIterReverse(rg.EndKey, rg.StartKey) + } + snapCacheIter, err := getSnapIter(iter.sctx, iter.cacheTable, rg, iter.reverse) if err != nil { - return nil, errors.Trace(err) + iter.err = errors.Trace(err) + return true } - iter.chk.Reset() - - if !rowcodec.IsNewFormat(curr.Value()) { - // TODO: remove the legacy code! - // fallback to the old way. - iter.datumRow, err = iter.decodeRecordKeyValue(curr.Key(), curr.Value(), &iter.datumRow) - if err != nil { - return nil, errors.Trace(err) - } - - mutableRow := chunk.MutRowFromTypes(iter.retFieldTypes) - mutableRow.SetDatums(iter.datumRow...) - matched, _, err := expression.EvalBool(iter.ctx, iter.conditions, mutableRow.ToRow()) + if snapCacheIter != nil { + tmp, err = transaction.NewUnionIter(tmp, snapCacheIter, iter.reverse) if err != nil { - return nil, errors.Trace(err) + iter.err = errors.Trace(err) + return true } - if !matched { - continue - } - return iter.datumRow, curr.Next() } - - err = iter.cd.DecodeToChunk(curr.Value(), handle, iter.chk) - if err != nil { - return nil, errors.Trace(err) + iter.curr = tmp + if iter.curr.Valid() { + return true } + iter.curr = nil + iter.idx++ + } + return false +} - row := iter.chk.GetRow(0) - matched, _, err := expression.EvalBool(iter.ctx, iter.conditions, row) - if err != nil { - return nil, errors.Trace(err) - } - if !matched { - continue +func (iter *txnMemBufferIter) Next() error { + if iter.err != nil { + return errors.Trace(iter.err) + } + if iter.curr != nil { + if iter.curr.Valid() { + return iter.curr.Next() } - ret := row.GetDatumRowWithBuffer(iter.retFieldTypes, iter.datumRow) - return ret, curr.Next() } - return nil, err + return nil +} + +func (iter *txnMemBufferIter) Key() kv.Key { + return iter.curr.Key() +} + +func (iter *txnMemBufferIter) Value() []byte { + return iter.curr.Value() +} + +func (*txnMemBufferIter) Close() { } func (m *memTableReader) getMemRowsIter(ctx context.Context) (memRowsIter, error) { @@ -377,22 +382,20 @@ func (m *memTableReader) getMemRowsIter(ctx context.Context) (memRowsIter, error for i, col := range m.columns { m.offsets[i] = m.colIDs[col.ID] } - txn, err := m.ctx.Txn(true) + + kvIter, err := newTxnMemBufferIter(m.ctx, m.cacheTable, m.kvRanges, m.desc) if err != nil { - return nil, err + return nil, errors.Trace(err) } - - return &txnMemBufferIter{ - memTableReader: m, - txn: txn, + return &memRowsIterForTable{ + kvIter: kvIter, cd: m.buffer.cd, chk: chunk.New(m.retFieldTypes, 1, 1), datumRow: make([]types.Datum, len(m.retFieldTypes)), - reverse: m.desc, + memTableReader: m, }, nil } -// TODO: Try to make memXXXReader lazy, There is no need to decode many rows when parent operator only need 1 row. func (m *memTableReader) getMemRows(ctx context.Context) ([][]types.Datum, error) { defer tracing.StartRegion(ctx, "memTableReader.getMemRows").End() mutableRow := chunk.MutRowFromTypes(m.retFieldTypes) @@ -859,6 +862,115 @@ func (iter *defaultRowsIter) Next() ([]types.Datum, error) { return nil, nil } +// memRowsIterForTable combine a kv.Iterator and a kv decoder to get a memRowsIter. +type memRowsIterForTable struct { + kvIter *txnMemBufferIter // txnMemBufferIter is the kv.Iterator + cd *rowcodec.ChunkDecoder + chk *chunk.Chunk + datumRow []types.Datum + *memTableReader +} + +func (iter *memRowsIterForTable) Next() ([]types.Datum, error) { + curr := iter.kvIter + var ret []types.Datum + for curr.Valid() { + key := curr.Key() + value := curr.Value() + if err := curr.Next(); err != nil { + return nil, errors.Trace(err) + } + + // check whether the key was been deleted. + if len(value) == 0 { + continue + } + handle, err := tablecodec.DecodeRowKey(key) + if err != nil { + return nil, errors.Trace(err) + } + iter.chk.Reset() + + if !rowcodec.IsNewFormat(value) { + // TODO: remove the legacy code! + // fallback to the old way. + iter.datumRow, err = iter.memTableReader.decodeRecordKeyValue(key, value, &iter.datumRow) + if err != nil { + return nil, errors.Trace(err) + } + + mutableRow := chunk.MutRowFromTypes(iter.retFieldTypes) + mutableRow.SetDatums(iter.datumRow...) + matched, _, err := expression.EvalBool(iter.ctx, iter.conditions, mutableRow.ToRow()) + if err != nil { + return nil, errors.Trace(err) + } + if !matched { + continue + } + return iter.datumRow, nil + } + + err = iter.cd.DecodeToChunk(value, handle, iter.chk) + if err != nil { + return nil, errors.Trace(err) + } + + row := iter.chk.GetRow(0) + matched, _, err := expression.EvalBool(iter.ctx, iter.conditions, row) + if err != nil { + return nil, errors.Trace(err) + } + if !matched { + continue + } + ret = row.GetDatumRowWithBuffer(iter.retFieldTypes, iter.datumRow) + break + } + return ret, nil +} + +type memRowsIterForIndex struct { + kvIter *txnMemBufferIter + tps []*types.FieldType + mutableRow chunk.MutRow + *memIndexReader + colInfos []rowcodec.ColInfo +} + +func (iter *memRowsIterForIndex) Next() ([]types.Datum, error) { + var ret []types.Datum + curr := iter.kvIter + for curr.Valid() { + key := curr.Key() + value := curr.Value() + if err := curr.Next(); err != nil { + return nil, errors.Trace(err) + } + // check whether the key was been deleted. + if len(value) == 0 { + continue + } + + data, err := iter.memIndexReader.decodeIndexKeyValue(key, value, iter.tps, iter.colInfos) + if err != nil { + return nil, err + } + + iter.mutableRow.SetDatums(data...) + matched, _, err := expression.EvalBool(iter.memIndexReader.ctx, iter.memIndexReader.conditions, iter.mutableRow.ToRow()) + if err != nil { + return nil, errors.Trace(err) + } + if !matched { + continue + } + ret = data + break + } + return ret, nil +} + func (m *memIndexMergeReader) getMemRowsIter(ctx context.Context) (memRowsIter, error) { data, err := m.getMemRows(ctx) if err != nil {