Skip to content
This repository was archived by the owner on Oct 30, 2024. It is now read-only.

add: rolling markdown splitter #160

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pkg/datastore/defaults/defaults.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ const (
TopK int = 10

TextSplitterTokenModel = "gpt-4"
TextSplitterChunkSize = 1024
TextSplitterChunkSize = 2048
TextSplitterChunkOverlap = 256
TextSplitterTokenEncoding = "cl100k_base"
)
187 changes: 187 additions & 0 deletions pkg/datastore/textsplitter/markdown_rolling/markdown_rolling.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,187 @@
package markdown_rolling

import (
"fmt"
"strings"

"github.com/pkoukk/tiktoken-go"
lcgosplitter "github.com/tmc/langchaingo/textsplitter"
)

// NewMarkdownTextSplitter builds a MarkdownTextSplitter from DefaultOptions
// with the given opts applied on top, in the order they are passed.
//
// The tokenizer is looked up by EncodingName when that option is non-empty and
// by ModelName otherwise. A failed lookup is returned as a wrapped error and no
// splitter is created.
func NewMarkdownTextSplitter(opts ...Option) (*MarkdownTextSplitter, error) {
	cfg := DefaultOptions()
	for i := range opts {
		opts[i](&cfg)
	}

	// An explicit encoding name takes precedence over the model name.
	lookup := func() (*tiktoken.Tiktoken, error) {
		if cfg.EncodingName == "" {
			return tiktoken.EncodingForModel(cfg.ModelName)
		}
		return tiktoken.GetEncoding(cfg.EncodingName)
	}

	encoder, err := lookup()
	if err != nil {
		return nil, fmt.Errorf("couldn't get encoding: %w", err)
	}

	// Token-based splitter used for blocks that exceed the chunk size.
	fallback := lcgosplitter.TokenSplitter{
		ChunkSize:         cfg.ChunkSize,
		ChunkOverlap:      cfg.ChunkOverlap,
		ModelName:         cfg.ModelName,
		EncodingName:      cfg.EncodingName,
		AllowedSpecial:    []string{},
		DisallowedSpecial: []string{"all"},
	}

	return &MarkdownTextSplitter{
		Options:       cfg,
		Tiktoken:      encoder,
		tokenSplitter: fallback,
	}, nil
}

// MarkdownTextSplitter splits markdown text into heading-delimited blocks and
// combines those blocks into chunks bounded by a token count (see SplitText).
type MarkdownTextSplitter struct {
// Options is the configuration the splitter was created with.
Options
// Tiktoken is the tokenizer used to count tokens (see getTokenSize).
*tiktoken.Tiktoken
// tokenSplitter cuts the content of blocks that are larger than ChunkSize
// (see finishBlock).
tokenSplitter lcgosplitter.TokenSplitter
}

// block is one heading-delimited section of the input while it is being
// collected, and the finished text unit that is later combined into chunks.
type block struct {
// headings are the non-empty headings that apply to the block, in the
// order they appear in the heading stack.
headings []string
// lines are the non-heading lines collected for the block.
lines []string
// text is the finished block text: the headings followed by the content.
text string
// tokenSize is the token count of text as reported by getTokenSize.
tokenSize int
}

// getTokenSize returns the number of tokens the embedded tokenizer produces
// for text. Encode is called with an empty allowed-special list and "all" as
// the disallowed-special list.
func (s *MarkdownTextSplitter) getTokenSize(text string) int {
	tokens := s.Encode(text, []string{}, []string{"all"})
	return len(tokens)
}

// finishBlock closes the block that is currently being collected and appends
// the resulting text block(s) to blocks.
//
// The non-empty entries of headingStack are added to currentBlock's headings
// and form the heading path that prefixes the block's content. Nothing is
// appended when the block has neither headings nor content, or when it has no
// content and IgnoreHeadingOnly is set. A block of at most ChunkSize tokens is
// appended as is; a larger one is cut with the token splitter and the heading
// path is prepended to every piece.
//
// It returns the extended blocks slice, an empty block to continue collecting
// into, and any error from the token splitter or an invalid ChunkSize.
func (s *MarkdownTextSplitter) finishBlock(blocks []block, currentBlock block, headingStack []string) ([]block, block, error) {
	for _, header := range headingStack {
		if header != "" {
			currentBlock.headings = append(currentBlock.headings, header)
		}
	}

	headingStr := strings.TrimSpace(strings.Join(currentBlock.headings, "\n"))
	contentStr := strings.TrimSpace(strings.Join(currentBlock.lines, "\n"))

	// Judge emptiness on the trimmed content: blank lines between two headings
	// are collected as lines but do not make a block worth keeping.
	if contentStr == "" && (headingStr == "" || s.IgnoreHeadingOnly) {
		return blocks, block{}, nil
	}

	// Only add the separating newline when there is a heading to separate, so
	// content that precedes the first heading does not start with "\n".
	var prefix string
	if headingStr != "" {
		prefix = headingStr + "\n"
	}
	text := strings.TrimSpace(prefix + contentStr)

	textTokenSize := s.getTokenSize(text)
	if textTokenSize <= s.ChunkSize {
		return append(blocks, block{
			text:      text,
			tokenSize: textTokenSize,
		}), block{}, nil
	}

	if s.ChunkSize <= 0 {
		return blocks, block{}, fmt.Errorf("invalid chunk size %d", s.ChunkSize)
	}

	// Work on a copy so the reduced chunk size does not leak into later calls
	// through the shared splitter stored on the receiver.
	splitter := s.tokenSplitter
	toSplit := contentStr

	// Leave room for the heading path (including its newline) in every piece.
	splitter.ChunkSize = s.ChunkSize - s.getTokenSize(prefix)
	if splitter.ChunkSize <= 0 || contentStr == "" {
		// The heading path alone fills a chunk (or is all there is): split
		// the whole text instead of repeating the headings on every piece.
		prefix, toSplit = "", text
		splitter.ChunkSize = s.ChunkSize
	}

	// NOTE(review): keeps the overlap below the chunk size on the assumption
	// that the token splitter advances by ChunkSize-ChunkOverlap per piece and
	// would not make progress otherwise — confirm against the library.
	if splitter.ChunkOverlap >= splitter.ChunkSize {
		splitter.ChunkOverlap = splitter.ChunkSize - 1
	}

	splits, err := splitter.SplitText(toSplit)
	if err != nil {
		return blocks, block{}, err
	}

	for _, split := range splits {
		piece := prefix + split
		blocks = append(blocks, block{
			text:      piece,
			tokenSize: s.getTokenSize(piece),
		})
	}

	return blocks, block{}, nil
}

// SplitText splits markdown text into chunks.
//
// The text is read line by line. Every line that starts with "#" is treated as
// a heading: it ends the block collected so far and updates the stack of
// headings that apply to the following lines. All other lines are collected as
// content of the current block. The finished blocks are then concatenated, in
// order and separated by a newline, into chunks whose summed block token sizes
// stay within ChunkSize.
//
// It returns the chunks in document order (nil when the text yields no
// blocks) or the first error reported while finishing a block.
func (s *MarkdownTextSplitter) SplitText(text string) ([]string, error) {
	var (
		headingStack []string
		currentBlock block
		blocks       []block
		err          error
	)

	// Parse markdown line-by-line and build heading-delimited blocks.
	for _, line := range strings.Split(text, "\n") {
		if !strings.HasPrefix(line, "#") {
			currentBlock.lines = append(currentBlock.lines, line)
			continue
		}

		// A heading starts a new block, so finish the previous one first: it
		// still belongs to the headings seen before this line.
		blocks, currentBlock, err = s.finishBlock(blocks, currentBlock, headingStack)
		if err != nil {
			return nil, err
		}

		headingStack = pushHeading(headingStack, line)
	}

	// Finish the last block.
	blocks, _, err = s.finishBlock(blocks, currentBlock, headingStack)
	if err != nil {
		return nil, err
	}

	return mergeBlocks(blocks, s.ChunkSize), nil
}

// pushHeading makes line the innermost heading of the stack. The level is the
// number of leading "#" characters minus one; deeper headings are dropped and
// skipped levels are filled with empty strings, so the stack never exposes
// stale entries and never needs to be sliced beyond its length.
func pushHeading(stack []string, line string) []string {
	level := len(line) - len(strings.TrimLeft(line, "#")) - 1

	if level < len(stack) {
		stack = stack[:level]
	}
	for len(stack) < level {
		stack = append(stack, "")
	}

	return append(stack, line)
}

// mergeBlocks concatenates consecutive blocks, separated by a newline, into
// chunks whose summed block token sizes do not exceed chunkSize. A block that
// is larger than chunkSize on its own becomes a chunk of its own.
func mergeBlocks(blocks []block, chunkSize int) []string {
	var (
		chunks  []string
		current strings.Builder
		size    int
	)

	for _, b := range blocks {
		// Start a new chunk when this block no longer fits; never emit an
		// empty chunk when the very first block is already too large.
		if current.Len() > 0 && size+b.tokenSize > chunkSize {
			chunks = append(chunks, current.String())
			current.Reset()
			size = 0
		}

		if current.Len() > 0 {
			current.WriteString("\n")
		}
		current.WriteString(b.text)
		size += b.tokenSize
	}

	// Flush the chunk that was still being filled when the blocks ran out.
	if current.Len() > 0 {
		chunks = append(chunks, current.String())
	}

	return chunks
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
package markdown_rolling

import (
"testing"

"github.com/stretchr/testify/assert"
)

func TestSplitTextWithBasicMarkdown(t *testing.T) {
splitter := NewMarkdownTextSplitter()

Check failure on line 10 in pkg/datastore/textsplitter/markdown_rolling/markdown_rolling_test.go

View workflow job for this annotation

GitHub Actions / Full Test Suite

assignment mismatch: 1 variable but NewMarkdownTextSplitter returns 2 values
chunks, err := splitter.SplitText("# Heading\n\nThis is a paragraph.")
assert.NoError(t, err)
assert.Equal(t, 1, len(chunks))

expected := []string{"# Heading\nThis is a paragraph."}

assert.Equal(t, expected, chunks)
}

func TestSplitTextWithOptions(t *testing.T) {
md := `
# Heading 1

some p under h1

## Heading 2
### Heading 3

- some
- list
- items

**bold**

# 2nd Heading 1
#### Heading 4

some p under h4
`

testcases := []struct {
name string
splitter *MarkdownTextSplitter
expected []string
}{
{
name: "default",
splitter: NewMarkdownTextSplitter(),

Check failure on line 48 in pkg/datastore/textsplitter/markdown_rolling/markdown_rolling_test.go

View workflow job for this annotation

GitHub Actions / Full Test Suite

multiple-value NewMarkdownTextSplitter() (value of type (*MarkdownTextSplitter, error)) in single-value context
expected: []string{
"# Heading 1\nsome p under h1",
"# Heading 1\n## Heading 2",
"# Heading 1\n## Heading 2\n### Heading 3\n- some\n- list\n- items\n\n**bold**",
"# 2nd Heading 1",
"# 2nd Heading 1\n#### Heading 4\nsome p under h4",
},
},
{
name: "ignore_heading_only",
splitter: NewMarkdownTextSplitter(WithIgnoreHeadingOnly(true)),

Check failure on line 59 in pkg/datastore/textsplitter/markdown_rolling/markdown_rolling_test.go

View workflow job for this annotation

GitHub Actions / Full Test Suite

multiple-value NewMarkdownTextSplitter(WithIgnoreHeadingOnly(true)) (value of type (*MarkdownTextSplitter, error)) in single-value context
expected: []string{
"# Heading 1\nsome p under h1",
"# Heading 1\n## Heading 2\n### Heading 3\n- some\n- list\n- items\n\n**bold**",
"# 2nd Heading 1\n#### Heading 4\nsome p under h4",
},
},
{
name: "split_h1_only",
splitter: NewMarkdownTextSplitter(),

Check failure on line 68 in pkg/datastore/textsplitter/markdown_rolling/markdown_rolling_test.go

View workflow job for this annotation

GitHub Actions / Full Test Suite

multiple-value NewMarkdownTextSplitter() (value of type (*MarkdownTextSplitter, error)) in single-value context
expected: []string{
"# Heading 1\nsome p under h1\n\n## Heading 2\n### Heading 3\n\n- some\n- list\n- items\n\n**bold**",
"# 2nd Heading 1\n#### Heading 4\n\nsome p under h4",
},
},
}

for _, tc := range testcases {
t.Run(tc.name, func(t *testing.T) {
chunks, err := tc.splitter.SplitText(md)
assert.NoError(t, err)
assert.Equal(t, len(tc.expected), len(chunks))

assert.Equal(t, tc.expected, chunks)
})
}
}
69 changes: 69 additions & 0 deletions pkg/datastore/textsplitter/markdown_rolling/options.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
package markdown_rolling

import (
"github.com/gptscript-ai/knowledge/pkg/datastore/defaults"
lcgosplitter "github.com/tmc/langchaingo/textsplitter"
)

// Options is a struct that contains options for a text splitter.
type Options struct {
// ChunkSize is the token limit a block or chunk is compared against.
ChunkSize int
// ChunkOverlap is handed to the token splitter as its chunk overlap.
ChunkOverlap int
// NOTE(review): Separators, KeepSeparator and SecondSplitter are not read
// by the splitter code in this package as far as visible here — presumably
// kept for parity with other splitters; confirm before relying on them.
Separators []string
KeepSeparator bool
// ModelName selects the tokenizer by model when EncodingName is empty.
ModelName string
// EncodingName selects the tokenizer by encoding name; when non-empty it
// is used instead of ModelName.
EncodingName string
SecondSplitter lcgosplitter.TextSplitter

IgnoreHeadingOnly bool // Ignore chunks that only contain headings
}

// DefaultOptions returns the Options a splitter starts from: chunk size,
// chunk overlap, model name and encoding name come from the defaults package,
// and heading-only blocks are ignored. All other fields keep their zero value.
func DefaultOptions() Options {
	var o Options

	o.ChunkSize = defaults.TextSplitterChunkSize
	o.ChunkOverlap = defaults.TextSplitterChunkOverlap

	o.ModelName = defaults.TextSplitterTokenModel
	o.EncodingName = defaults.TextSplitterTokenEncoding

	o.IgnoreHeadingOnly = true

	return o
}

// Option is a function that can be used to set options for a text splitter.
// It receives a pointer to the Options being built and modifies it in place.
type Option func(*Options)

// WithChunkSize returns an Option that sets Options.ChunkSize to chunkSize.
func WithChunkSize(chunkSize int) Option {
	apply := func(target *Options) {
		target.ChunkSize = chunkSize
	}
	return apply
}

// WithChunkOverlap returns an Option that sets Options.ChunkOverlap to
// chunkOverlap.
func WithChunkOverlap(chunkOverlap int) Option {
	apply := func(target *Options) {
		target.ChunkOverlap = chunkOverlap
	}
	return apply
}

// WithModelName returns an Option that sets Options.ModelName to modelName.
func WithModelName(modelName string) Option {
	apply := func(target *Options) {
		target.ModelName = modelName
	}
	return apply
}

// WithEncodingName returns an Option that sets Options.EncodingName to
// encodingName.
func WithEncodingName(encodingName string) Option {
	apply := func(target *Options) {
		target.EncodingName = encodingName
	}
	return apply
}

// WithIgnoreHeadingOnly returns an Option that sets Options.IgnoreHeadingOnly
// to ignoreHeadingOnly.
func WithIgnoreHeadingOnly(ignoreHeadingOnly bool) Option {
	apply := func(target *Options) {
		target.IgnoreHeadingOnly = ignoreHeadingOnly
	}
	return apply
}
Loading