-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathtokenize.go
123 lines (96 loc) · 2.71 KB
/
tokenize.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
package main
import (
	"bufio"
	"fmt"
	"os"
	"path"
	"strings"
	"sync"

	"infinigram/tokenizers"

	"github.com/schollz/progressbar/v3"
)
// isAllWhitespace reports whether the document pointed to by lineP is empty
// or consists solely of whitespace characters.
func isAllWhitespace(lineP *string) bool {
	trimmed := strings.TrimSpace(*lineP)
	return len(trimmed) == 0
}
// initTokenizer loads a tokenizer from the configuration file at
// tokenizerConfig. Each call returns an independent instance; the caller is
// responsible for calling Close() on it.
func initTokenizer(tokenizerConfig string) (*tokenizers.Tokenizer, error) {
	// Return directly instead of assigning to throwaway locals first.
	return tokenizers.FromFile(tokenizerConfig)
}
// worker consumes documents from textJobs, tokenizes each one, and sends the
// uint16-encoded byte buffer (tokens plus sentinel padding) on results. Each
// worker owns its own tokenizer instance. A tokenizer that fails to load makes
// all work impossible, so that failure panics rather than being reported back.
func worker(wg *sync.WaitGroup, tokenizerConfig string, sentinalVal int, sentinalSize int, textJobs <-chan *string, results chan<- []byte) {
	defer wg.Done()

	tokenizer, err := initTokenizer(tokenizerConfig)
	if err != nil {
		panic(err)
	}
	defer tokenizer.Close()

	for docP := range textJobs {
		// Second return value is deliberately discarded (presumably the
		// token strings — confirm against the tokenizers package API).
		ids, _ := tokenizer.Encode(*docP, false)

		// Two bytes per token (uint16), plus room for the sentinel tokens.
		buf := make([]byte, (len(ids)+sentinalSize)*2)
		encodeSequence(buf, ids, sentinalVal, sentinalSize)
		results <- buf
	}
}
func writeWorker(wg *sync.WaitGroup, filename string, results <-chan []byte) error {
defer wg.Done()
f, err := os.Create(filename)
if err != nil {
return err
}
defer f.Close()
for res := range results {
if _, err := f.Write(res); err != nil {
return err
}
}
return nil
}
// tokenizeMultiprocess tokenizes the documents in filename (split on docSplit)
// using numWorkers goroutines and streams the tokenized data directly to
// <outpath>/data.bin, returning that path. The tokenizer configuration file
// path is tokenizerConfig; sentinel padding between documents is controlled by
// sentinalVal and sentinalSize. Documents that are all whitespace are skipped.
// Tokens are written in uint16 format.
func tokenizeMultiprocess(filename, docSplit, outpath, tokenizerConfig string, sentinalVal, sentinalSize, numWorkers int) (string, error) {
	// Initialize output path.
	if err := makeFolder(outpath); err != nil {
		return "", err
	}
	saPath := path.Join(outpath, "data.bin")

	// Count documents up front so the progress bar has a total.
	fileNumLines, err := numLines(filename, docSplit)
	if err != nil {
		return "", err
	}
	fmt.Println("Num lines: ", fileNumLines)

	// Buffered channels decouple the reader, the tokenizer workers, and the
	// single writer goroutine.
	textJobs := make(chan *string, numWorkers*4)
	results := make(chan []byte, numWorkers*4)
	wgWorkers := &sync.WaitGroup{}
	wgWriter := &sync.WaitGroup{}

	for w := 0; w < numWorkers; w++ {
		wgWorkers.Add(1)
		go worker(wgWorkers, tokenizerConfig, sentinalVal, sentinalSize, textJobs, results)
	}

	wgWriter.Add(1)
	go writeWorker(wgWriter, saPath, results)

	// Enqueue each non-whitespace document for tokenization.
	// NOTE(review): previously the input file was also os.Open'ed here and
	// never used — readDocuments does its own reading — so that dead open has
	// been removed.
	bar := progressbar.Default(int64(fileNumLines))
	readErr := readDocuments(filename, docSplit, func(lineP *string) error {
		bar.Add(1)
		if !isAllWhitespace(lineP) {
			textJobs <- lineP
		}
		return nil
	})

	// Shut down the pipeline in order even when reading failed, so no worker
	// goroutine is left blocked on a channel: stop feeding workers, wait for
	// them to finish, then let the writer drain the remaining results.
	close(textJobs)
	wgWorkers.Wait()
	close(results)
	wgWriter.Wait()

	if readErr != nil {
		return "", readErr
	}
	return saPath, nil
}