Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions pkg/component/operator/text/v0/chunk_text.go
Original file line number Diff line number Diff line change
Expand Up @@ -174,20 +174,20 @@ func chunkMarkdown(input ChunkTextInput) (ChunkTextOutput, error) {
err := sp.Validate()

if err != nil {
return output, fmt.Errorf("failed to validate MarkdownTextSplitter: %w", err)
return output, fmt.Errorf("validating MarkdownTextSplitter: %w", err)
}

chunks, err := sp.SplitText()

if err != nil {
return output, fmt.Errorf("failed to split text: %w", err)
return output, fmt.Errorf("splitting text: %w", err)
}

tkm, err := tiktoken.EncodingForModel(setting.ModelName)

mergedChunks := mergeChunks(chunks, input, tkm)
if err != nil {
return output, fmt.Errorf("failed to get encoding for model: %w", err)
return output, fmt.Errorf("getting encoding for model: %w", err)
}

totalTokenCount := 0
Expand Down
51 changes: 50 additions & 1 deletion pkg/component/operator/text/v0/chunk_text_test.go
Original file line number Diff line number Diff line change
@@ -1,14 +1,63 @@
package text

import (
"encoding/json"
"os"
"testing"

"github.com/frankban/quicktest"
)

func TestChunkText(t *testing.T) {
// TestChunkText_LongMD runs chunkMarkdown over a long Markdown fixture and
// compares the chunk count, the token counts and the per-chunk positions
// against expectations stored in a JSON file.
func TestChunkText_LongMD(t *testing.T) {
	c := quicktest.New(t)

	inputData, err := os.ReadFile("testdata/chunk-markdown-input.md")
	c.Assert(err, quicktest.IsNil)

	// Expectations are complex and hence stored in a file.
	expectedData, err := os.ReadFile("testdata/chunk-markdown-output.json")
	c.Assert(err, quicktest.IsNil)

	var want ChunkTextOutput
	err = json.Unmarshal(expectedData, &want)
	c.Assert(err, quicktest.IsNil)

	input := ChunkTextInput{
		Text: string(inputData),
		Strategy: Strategy{
			Setting: Setting{
				ChunkMethod:  "Markdown",
				ModelName:    "gpt-4",
				ChunkSize:    1024,
				ChunkOverlap: 200,
			},
		},
	}

	// Run the algorithm.
	got, err := chunkMarkdown(input)
	c.Assert(err, quicktest.IsNil)

	// Compare the aggregated results.
	c.Check(got.ChunkNum, quicktest.Equals, want.ChunkNum)
	c.Check(got.TokenCount, quicktest.Equals, want.TokenCount)
	c.Check(got.ChunksTokenCount, quicktest.Equals, want.ChunksTokenCount)

	// Assert rather than Check: the loop below indexes want.TextChunks with
	// the indices of got.TextChunks, so a length mismatch must stop the test
	// here instead of panicking with an index out of range.
	c.Assert(len(got.TextChunks), quicktest.Equals, len(want.TextChunks))

	// Compare each chunk.
	for i, chunk := range got.TextChunks {
		// NOTE: We don't compare the Text field as it can be large and the
		// positions/tokens are sufficient.
		wantChunk := want.TextChunks[i]
		comment := quicktest.Commentf("chunk %d", i)

		c.Check(chunk.StartPosition, quicktest.Equals, wantChunk.StartPosition, comment)
		c.Check(chunk.EndPosition, quicktest.Equals, wantChunk.EndPosition, comment)
		c.Check(chunk.TokenCount, quicktest.Equals, wantChunk.TokenCount, comment)
	}
}

func TestChunkText(t *testing.T) {
c := quicktest.New(t)

testCases := []struct {
Expand Down
95 changes: 40 additions & 55 deletions pkg/component/operator/text/v0/helper.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,8 @@ type mergedChunk struct {
ContentEndPosition int
}

// mergeChunks does the following things:
// 1. collect the merging chunks that need to be merged
// 2. merge the merging chunks
//
// The reason why we need to separate the merging process into 2 parts is that - to calculate position correctly after dealing with overlap, we need to retain the position information.
// mergeChunks groups chunks to have a token count under the specified chunk
// size.
func mergeChunks(chunks []ContentChunk, inputStruct ChunkTextInput, tkm *tiktoken.Tiktoken) []mergedChunk {
if len(chunks) <= 1 {
mergedChunks := []mergedChunk{}
Expand All @@ -54,70 +51,62 @@ func mergeChunks(chunks []ContentChunk, inputStruct ChunkTextInput, tkm *tiktoke
}
return mergedChunks
}
mergingChunks := collectMergingChunks(chunks, inputStruct, tkm)

// The merging process is divided in 2 steps in order to compute the
// positions of each chunk taking into account the overlap.
mergingChunks := collectMergingChunks(chunks, inputStruct.Strategy.Setting, tkm)
mergedChunks := processMergingChunks(mergingChunks, inputStruct, tkm)

return mergedChunks
}

// Collect the chunks that need to be merged
// The chunk that need to be merged is the chunk that the token size is less than the chunk size
// We need to check with the sequence:
// 1. if currentChunk.PrependHeader != nextChunk.PrependHeader
// - check currentChunk.PrependHeader + currentChunk.Chunk + diff(currentChunk.PrependHeader, nextChunk.PrependHeader) + nextChunk.Chunk < chunkSize
// - if yes, add nextChunk to the currentMergingChunk
// - if no, break
//
// 2. if currentChunk.PrependHeader == nextChunk.PrependHeader
// - check currentChunk.PrependHeader + currentChunk.Chunk + "\n" + nextChunk.Chunk < chunkSize
// - if yes, add nextChunk to the currentMergingChunk
// - if no, break
func collectMergingChunks(chunks []ContentChunk, inputStruct ChunkTextInput, tkm *tiktoken.Tiktoken) []mergingChunks {
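// collectMergingChunks groups consecutive chunks for as long as adding the
// next chunk keeps the group's token size within the chunk size in the
// settings. When the previously collected group holds more than one chunk,
// the limit is reduced by the chunk overlap to leave room for it. When
// adjacent chunks have different headers, only the header diff (plus the next
// chunk's content) is added to the token count.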
func collectMergingChunks(chunks []ContentChunk, setting Setting, tkm *tiktoken.Tiktoken) []mergingChunks {
var collectedMergingChunks []mergingChunks
var currentMergingChunk mergingChunks

for i := 0; i < len(chunks); i++ {
currentChunk := chunks[i]
nextIndex := i + 1

currentMergingChunk.Chunks = append(currentMergingChunk.Chunks, currentChunk)
prependedChunk := currentChunk.PrependHeader + currentChunk.Chunk
currentMergingChunk.CollectedTokenSize += getTokenSize(prependedChunk, &inputStruct.Strategy.Setting, tkm)

currentMergingChunk := mergingChunks{
Chunks: []ContentChunk{currentChunk},
CollectedTokenSize: getTokenSize(prependedChunk, setting, tkm),
}

nextIndex := i + 1
for nextIndex < len(chunks) {
nextChunk := chunks[nextIndex]

var potentialSize int
potentialSize := currentMergingChunk.CollectedTokenSize
if currentChunk.PrependHeader != nextChunk.PrependHeader {
diffHeader := headerDiff(currentChunk.PrependHeader, nextChunk.PrependHeader)
addedChunk := diffHeader + nextChunk.Chunk
potentialSize = currentMergingChunk.CollectedTokenSize + getTokenSize(addedChunk, &inputStruct.Strategy.Setting, tkm)
potentialSize += getTokenSize(addedChunk, setting, tkm)
} else {
potentialSize = currentMergingChunk.CollectedTokenSize + getTokenSize(nextChunk.Chunk, &inputStruct.Strategy.Setting, tkm)
potentialSize += getTokenSize(nextChunk.Chunk, setting, tkm)
}

// We need to leave the overlap part for the next chunk
var cannotOverSize int
cannotOverSize := setting.ChunkSize
if len(collectedMergingChunks) > 0 {
previousCollectedChunk := collectedMergingChunks[len(collectedMergingChunks)-1]
if len(previousCollectedChunk.Chunks) > 1 {
cannotOverSize = inputStruct.Strategy.Setting.ChunkSize - inputStruct.Strategy.Setting.ChunkOverlap
} else {
cannotOverSize = inputStruct.Strategy.Setting.ChunkSize
cannotOverSize = setting.ChunkSize - setting.ChunkOverlap
}
} else {
cannotOverSize = inputStruct.Strategy.Setting.ChunkSize
}

if potentialSize <= cannotOverSize {
currentMergingChunk.Chunks = append(currentMergingChunk.Chunks, nextChunk)
currentMergingChunk.CollectedTokenSize = potentialSize
nextIndex++
} else {
if potentialSize > cannotOverSize {
break
}

// If the next chunk has no header, we use the current chunk's header to avoid the duplicate header
currentMergingChunk.Chunks = append(currentMergingChunk.Chunks, nextChunk)
currentMergingChunk.CollectedTokenSize = potentialSize
nextIndex++

// If the next chunk has no header, we use the current chunk's
// header to avoid the duplicate header.
if nextChunk.PrependHeader == "" {
nextChunk.PrependHeader = currentChunk.PrependHeader
}
Expand All @@ -126,7 +115,6 @@ func collectMergingChunks(chunks []ContentChunk, inputStruct ChunkTextInput, tkm
}

collectedMergingChunks = append(collectedMergingChunks, currentMergingChunk)
currentMergingChunk = mergingChunks{}
i = nextIndex - 1
}

Expand All @@ -146,14 +134,16 @@ func processMergingChunks(mergingChunks []mergingChunks, inputStruct ChunkTextIn
firstMergedChunk := mergeMergingChunks(firstMergingChunk)
mergedChunks = append(mergedChunks, firstMergedChunk)

// merge the rest merging chunks, we need to consider the overlap part
mergingIdx := 1
for mergingIdx < len(mergingChunks) {
previousMergingChunk := mergingChunks[mergingIdx-1]
currentMergingChunk := mergingChunks[mergingIdx]
if len(mergingChunks) < 2 {
return mergedChunks
}

// Merge the other chunks taking into account the chunk overlap.
for prevIdx, currentMergingChunk := range mergingChunks[1:] {
previousMergingChunk := mergingChunks[prevIdx]

if len(previousMergingChunk.Chunks) > 1 {
overlapText, overlapPosition := getOverlapForSameHeader(previousMergingChunk, currentMergingChunk, &inputStruct.Strategy.Setting, tkm)
overlapText, overlapPosition := getOverlapForSameHeader(previousMergingChunk, currentMergingChunk, inputStruct.Strategy.Setting, tkm)
if overlapText != "" {
currentMergingChunk.Chunks[0].Chunk = overlapText + currentMergingChunk.Chunks[0].Chunk
currentMergingChunk.Chunks[0].ContentStartPosition = overlapPosition
Expand All @@ -162,8 +152,6 @@ func processMergingChunks(mergingChunks []mergingChunks, inputStruct ChunkTextIn

mergedChunk := mergeMergingChunks(currentMergingChunk)
mergedChunks = append(mergedChunks, mergedChunk)
mergingIdx++

}

return mergedChunks
Expand Down Expand Up @@ -197,12 +185,9 @@ func headerDiff(currentChunkHeader, nextChunkHeader string) string {
currentHeaders := strings.Split(strings.TrimSpace(currentChunkHeader), "\n")
nextHeaders := strings.Split(strings.TrimSpace(nextChunkHeader), "\n")

minLen := len(currentHeaders)
if len(nextHeaders) < minLen {
minLen = len(nextHeaders)
}
minLen := min(len(currentHeaders), len(nextHeaders))

for i := 0; i < minLen; i++ {
for i := range minLen {
if currentHeaders[i] != nextHeaders[i] {
return strings.Join(nextHeaders[i:], "\n")
}
Expand All @@ -216,7 +201,7 @@ func headerDiff(currentChunkHeader, nextChunkHeader string) string {
return ""
}

func getOverlapForSameHeader(previousMergingChunk mergingChunks, currentMergingChunks mergingChunks, setting *Setting, tkm *tiktoken.Tiktoken) (string, int) {
func getOverlapForSameHeader(previousMergingChunk mergingChunks, currentMergingChunks mergingChunks, setting Setting, tkm *tiktoken.Tiktoken) (string, int) {
overlapText := ""
overlapSize := setting.ChunkOverlap
var overlapPosition int
Expand All @@ -241,6 +226,6 @@ func getOverlapForSameHeader(previousMergingChunk mergingChunks, currentMergingC
}

// getTokenSize returns the number of tokens tkm produces when encoding text,
// using the allowed and disallowed special tokens from setting.
func getTokenSize(text string, setting *Setting, tkm *tiktoken.Tiktoken) int {
func getTokenSize(text string, setting Setting, tkm *tiktoken.Tiktoken) int {
return len(tkm.Encode(text, setting.AllowedSpecial, setting.DisallowedSpecial))
}
2 changes: 1 addition & 1 deletion pkg/component/operator/text/v0/markdown_document.go
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ func buildDocument(rawRunes []rune, previousDocument *MarkdownDocument, startPos
currentContent.Type = "plaintext"
currentContent.PlainText = block

currentContent.BlockStartPosition = currentPosition - sizeOfString(block) - 1
currentContent.BlockStartPosition = max(currentPosition-sizeOfString(block)-1, 0)
currentContent.BlockEndPosition = currentPosition
doc.Contents = append(doc.Contents, currentContent)
}
Expand Down
13 changes: 2 additions & 11 deletions pkg/component/operator/text/v0/markdown_splitter.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@ func (sp *MarkdownTextSplitter) SplitText() ([]ContentChunk, error) {
var chunks []ContentChunk

rawRunes := []rune(sp.RawText)

docs, err := buildDocuments(rawRunes)

if err != nil {
Expand Down Expand Up @@ -195,11 +194,7 @@ func (sp MarkdownTextSplitter) processChunks(lists []List, headers []Header) []C
}

addListCount := 0
countI := map[int]int{}
for i := 0; i < len(lists); i++ {
countI[i] = 0
}

countI := make([]int, len(lists))
for i := 0; i < len(lists); i++ {
countI[i]++
list := lists[i]
Expand Down Expand Up @@ -435,14 +430,12 @@ func (sp MarkdownTextSplitter) chunkLargeList(list List, prependStringSize int)
}

func (sp MarkdownTextSplitter) chunkPlainText(content Content, headers []Header) ([]ContentChunk, error) {

split := textsplitter.NewRecursiveCharacter(
textsplitter.WithChunkSize(sp.ChunkSize),
textsplitter.WithChunkOverlap(sp.ChunkOverlap),
)

chunks, err := split.SplitText(content.PlainText)

if err != nil {
return nil, err
}
Expand All @@ -457,12 +450,11 @@ func (sp MarkdownTextSplitter) chunkPlainText(content Content, headers []Header)
}

rawRunes := []rune(sp.RawText)
startScanPosition := 0
startScanPosition := content.BlockStartPosition

contentChunks := []ContentChunk{}
for _, chunk := range chunks {
chunkRunes := []rune(chunk)

startPosition, endPosition := getChunkPositions(rawRunes, chunkRunes, startScanPosition)

if shouldScanRawTextFromPreviousChunk(startPosition, endPosition) {
Expand Down Expand Up @@ -492,7 +484,6 @@ func (sp MarkdownTextSplitter) chunkPlainText(content Content, headers []Header)
}

func getChunkPositions(rawText, chunk []rune, startScanPosition int) (startPosition int, endPosition int) {

for i := startScanPosition; i < len(rawText); i++ {
if rawText[i] == chunk[0] {

Expand Down
Loading
Loading