diff --git a/cmd/main.go b/cmd/main.go index 843145b..6e7a503 100644 --- a/cmd/main.go +++ b/cmd/main.go @@ -296,7 +296,6 @@ func main() { } func writeOutputs(wg *sync.WaitGroup, processedChannel chan []byte, errorChannel chan error, writeTarget io.WriteCloser) { - for processedChannel != nil || errorChannel != nil { select { case output, ok := <-processedChannel: diff --git a/hugot.go b/hugot.go index 5daf02b..f3b7bfc 100644 --- a/hugot.go +++ b/hugot.go @@ -38,21 +38,21 @@ func (m pipelineMap[T]) GetStats() []string { return stats } -// TokenClassificationConfig is the configuration for a token classification pipeline -type TokenClassificationConfig = pipelines.PipelineConfig[*pipelines.TokenClassificationPipeline] - -// TextClassificationConfig is the configuration for a text classification pipeline -type TextClassificationConfig = pipelines.PipelineConfig[*pipelines.TextClassificationPipeline] - // FeatureExtractionConfig is the configuration for a feature extraction pipeline type FeatureExtractionConfig = pipelines.PipelineConfig[*pipelines.FeatureExtractionPipeline] -// TokenClassificationOption is an option for a token classification pipeline -type TokenClassificationOption = pipelines.PipelineOption[*pipelines.TokenClassificationPipeline] +// TextClassificationConfig is the configuration for a text classification pipeline +type TextClassificationConfig = pipelines.PipelineConfig[*pipelines.TextClassificationPipeline] // TextClassificationOption is an option for a text classification pipeline type TextClassificationOption = pipelines.PipelineOption[*pipelines.TextClassificationPipeline] +// TokenClassificationConfig is the configuration for a token classification pipeline +type TokenClassificationConfig = pipelines.PipelineConfig[*pipelines.TokenClassificationPipeline] + +// TokenClassificationOption is an option for a token classification pipeline +type TokenClassificationOption = pipelines.PipelineOption[*pipelines.TokenClassificationPipeline] + // FeatureExtractionOption is an option for a feature extraction pipeline type FeatureExtractionOption = pipelines.PipelineOption[*pipelines.FeatureExtractionPipeline] @@ -70,8 +70,8 @@ func NewSession(options ...WithOption) (*Session, error) { session := &Session{ featureExtractionPipelines: map[string]*pipelines.FeatureExtractionPipeline{}, - tokenClassificationPipelines: map[string]*pipelines.TokenClassificationPipeline{}, textClassificationPipelines: map[string]*pipelines.TextClassificationPipeline{}, + tokenClassificationPipelines: map[string]*pipelines.TokenClassificationPipeline{}, } // set session options and initialise @@ -302,7 +302,8 @@ func (s *Session) Destroy() error { // the average time per onnxruntime inference batch call func (s *Session) GetStats() []string { // slices.Concat() is not implemented in experimental x/exp/slices package - return append(append(s.tokenClassificationPipelines.GetStats(), + return append(append( + s.tokenClassificationPipelines.GetStats(), s.textClassificationPipelines.GetStats()...), s.featureExtractionPipelines.GetStats()..., ) diff --git a/hugot_test.go b/hugot_test.go index 5cea24d..189b5d1 100644 --- a/hugot_test.go +++ 
b/hugot_test.go @@ -10,10 +10,11 @@ import ( "strings" "testing" - "github.com/stretchr/testify/assert" - "github.com/knights-analytics/hugot/pipelines" util "github.com/knights-analytics/hugot/utils" + "github.com/stretchr/testify/assert" + + ort "github.com/yalue/onnxruntime_go" ) //go:embed testData/tokenExpected.json @@ -35,6 +36,162 @@ func TestDownloadValidation(t *testing.T) { assert.Error(t, err) } +// FEATURE EXTRACTION + +func TestFeatureExtractionPipelineValidation(t *testing.T) { + session, err := NewSession(WithOnnxLibraryPath(onnxRuntimeSharedLibrary)) + check(t, err) + defer func(session *Session) { + err := session.Destroy() + check(t, err) + }(session) + + modelPath := downloadModelIfNotExists(session, "KnightsAnalytics/all-MiniLM-L6-v2", "./models") + config := FeatureExtractionConfig{ + ModelPath: modelPath, + Name: "testPipeline", + } + pipeline, err := NewPipeline(session, config) + check(t, err) + + pipeline.InputsMeta[0].Dimensions = ort.NewShape(-1, -1, -1) + + err = pipeline.Validate() + assert.Error(t, err) + + pipeline.InputsMeta[0].Dimensions = ort.NewShape(1, 1, 1, 1) + err = pipeline.Validate() + assert.Error(t, err) +} + +func TestFeatureExtractionPipeline(t *testing.T) { + session, err := NewSession(WithOnnxLibraryPath(onnxRuntimeSharedLibrary)) + check(t, err) + defer func(session *Session) { + err := session.Destroy() + check(t, err) + }(session) + + modelPath := downloadModelIfNotExists(session, "sentence-transformers/all-MiniLM-L6-v2", "./models") + + config := FeatureExtractionConfig{ + ModelPath: modelPath, + Name: "testPipeline", + } + pipeline, err := NewPipeline(session, config) + check(t, err) + + var expectedResults map[string][][]float32 + err = json.Unmarshal(resultsByte, &expectedResults) + check(t, err) + var testResults [][]float32 + + // test 'robert smith' + testResults = expectedResults["test1output"] + batchResult, err := pipeline.RunPipeline([]string{"robert smith"}) + if err != nil { + t.Fatal(err) + } + for i := range batchResult.Embeddings { + e := floatsEqual(batchResult.Embeddings[i], testResults[i]) + if e != nil { + t.Logf("Test 1: The neural network didn't produce the correct result for input %d: %s\n", i, e) + t.FailNow() + } + } + + // test ['robert smith junior', 'francis ford coppola'] + testResults = expectedResults["test2output"] + batchResult, err = pipeline.RunPipeline([]string{"robert smith junior", "francis ford coppola"}) + if err != nil { + t.Fatal(err) + } + for i := range batchResult.Embeddings { + e := floatsEqual(batchResult.Embeddings[i], testResults[i]) + if e != nil { + t.Logf("Test 2: The neural network didn't produce the correct result for input %d: %s\n", i, e) + t.FailNow() + } + } + + // determinism test to make sure embeddings of a string are not influenced by other strings in the batch + testPairs := map[string][][]string{} + testPairs["identity"] = [][]string{{"sinopharm", "yo"}, {"sinopharm", "yo"}} + testPairs["contextOverlap"] = [][]string{{"sinopharm", "yo"}, {"sinopharm", "yo mama yo"}} + testPairs["contextDisjoint"] = [][]string{{"sinopharm", "yo"}, {"sinopharm", "another test"}} + + for k, sentencePair := range testPairs { + // these vectors should be the same + firstBatchResult, err2 := pipeline.RunPipeline(sentencePair[0]) + check(t, err2) + firstEmbedding := firstBatchResult.Embeddings[0] + + secondBatchResult, err3 := pipeline.RunPipeline(sentencePair[1]) + check(t, err3) + secondEmbedding := secondBatchResult.Embeddings[0] + e := floatsEqual(firstEmbedding, secondEmbedding) + if e != nil 
{ + t.Logf("Equality failed for determinism test %s with pairs %s and %s", k, strings.Join(sentencePair[0], ","), strings.Join(sentencePair[1], ",")) + t.Log("First vector", firstEmbedding) + t.Log("Second vector", secondEmbedding) + t.Fail() + } + } + + zero := uint64(0) + assert.Greater(t, pipeline.PipelineTimings.NumCalls, zero, "PipelineTimings.NumCalls should be greater than 0") + assert.Greater(t, pipeline.PipelineTimings.TotalNS, zero, "PipelineTimings.TotalNS should be greater than 0") + assert.Greater(t, pipeline.TokenizerTimings.NumCalls, zero, "TokenizerTimings.NumCalls should be greater than 0") + assert.Greater(t, pipeline.TokenizerTimings.TotalNS, zero, "TokenizerTimings.TotalNS should be greater than 0") + + // test normalization + testResults = expectedResults["normalizedOutput"] + config = FeatureExtractionConfig{ + ModelPath: modelPath, + Name: "testPipelineNormalise", + Options: []FeatureExtractionOption{ + pipelines.WithNormalization(), + }, + } + pipeline, err = NewPipeline(session, config) + check(t, err) + normalizationStrings := []string{"Onnxruntime is a great inference backend"} + normalizedEmbedding, err := pipeline.RunPipeline(normalizationStrings) + check(t, err) + for i, embedding := range normalizedEmbedding.Embeddings { + e := floatsEqual(embedding, testResults[i]) + if e != nil { + t.Fatalf("Normalization test failed: %s", normalizationStrings[i]) + } + } + + // test getting sentence embeddings + configSentence := FeatureExtractionConfig{ + ModelPath: modelPath, + Name: "testPipelineSentence", + Options: []FeatureExtractionOption{pipelines.WithOutputName("sentence_embedding")}, + } + pipelineSentence, err := NewPipeline(session, configSentence) + check(t, err) + outputSentence, err := pipelineSentence.RunPipeline([]string{"Onnxruntime is a great inference backend"}) + if err != nil { + t.Fatal(err) + } + fmt.Println(outputSentence.Embeddings[0]) + configSentence = FeatureExtractionConfig{ + ModelPath: modelPath, + Name: "testPipelineToken", + } + pipelineToken, err := NewPipeline(session, configSentence) + check(t, err) + out, err := pipelineToken.RunPipeline([]string{"Onnxruntime is a great inference backend"}) + if err != nil { + t.Fatal(err) + } + fmt.Println(out) + // TODO: assert the result here +} + // Text classification func TestTextClassificationPipeline(t *testing.T) { @@ -259,20 +416,26 @@ func TestTextClassificationPipelineValidation(t *testing.T) { } sentimentPipeline, err := NewPipeline(session, config) check(t, err) - sentimentPipeline.IdLabelMap = map[int]string{} - err = sentimentPipeline.Validate() - assert.Error(t, err) - if err != nil { - errInt := err.(interface{ Unwrap() []error }) - assert.Equal(t, 3, len(errInt.Unwrap())) - } - sentimentPipeline.OutputDim = 0 - err = sentimentPipeline.Validate() - assert.Error(t, err) - if err != nil { - errInt := err.(interface{ Unwrap() []error }) - assert.Equal(t, 3, len(errInt.Unwrap())) - } + + t.Run("id-label-map", func(t *testing.T) { + labelMapInitial := sentimentPipeline.IDLabelMap + defer func() { + sentimentPipeline.IDLabelMap = labelMapInitial + }() + sentimentPipeline.IDLabelMap = map[int]string{} + err = sentimentPipeline.Validate() + assert.Error(t, err) + }) + + t.Run("output-shape", func(t *testing.T) { + dimensionInitial := sentimentPipeline.OutputsMeta[0].Dimensions + defer func() { + sentimentPipeline.OutputsMeta[0].Dimensions = dimensionInitial + }() + sentimentPipeline.OutputsMeta[0].Dimensions = ort.NewShape(-1, -1, -1) + err = sentimentPipeline.Validate() + assert.Error(t, 
err) + }) } // Token classification @@ -374,20 +537,25 @@ func TestTokenClassificationPipelineValidation(t *testing.T) { pipelineSimple, err2 := NewPipeline(session, configSimple) check(t, err2) - pipelineSimple.IdLabelMap = map[int]string{} - err = pipelineSimple.Validate() - assert.Error(t, err) - if err != nil { - errInt := err.(interface{ Unwrap() []error }) - assert.Equal(t, 2, len(errInt.Unwrap())) - } - pipelineSimple.OutputDim = 0 - err = pipelineSimple.Validate() - assert.Error(t, err) - if err != nil { - errInt := err.(interface{ Unwrap() []error }) - assert.Equal(t, 2, len(errInt.Unwrap())) - } + t.Run("id-label-map", func(t *testing.T) { + labelMapInitial := pipelineSimple.IDLabelMap + defer func() { + pipelineSimple.IDLabelMap = labelMapInitial + }() + pipelineSimple.IDLabelMap = map[int]string{} + err = pipelineSimple.Validate() + assert.Error(t, err) + }) + + t.Run("output-shape", func(t *testing.T) { + dimensionInitial := pipelineSimple.OutputsMeta[0].Dimensions + defer func() { + pipelineSimple.OutputsMeta[0].Dimensions = dimensionInitial + }() + pipelineSimple.OutputsMeta[0].Dimensions = ort.NewShape(-1, -1, -1) + err = pipelineSimple.Validate() + assert.Error(t, err) + }) } func TestNoSameNamePipeline(t *testing.T) { @@ -415,129 +583,6 @@ func TestNoSameNamePipeline(t *testing.T) { assert.Error(t, err3) } -// feature extraction - -func TestFeatureExtractionPipeline(t *testing.T) { - session, err := NewSession(WithOnnxLibraryPath(onnxRuntimeSharedLibrary)) - check(t, err) - defer func(session *Session) { - err := session.Destroy() - check(t, err) - }(session) - - modelPath := downloadModelIfNotExists(session, "KnightsAnalytics/all-MiniLM-L6-v2", "./models") - - config := FeatureExtractionConfig{ - ModelPath: modelPath, - Name: "testPipeline", - } - pipeline, err := NewPipeline(session, config) - check(t, err) - - var expectedResults map[string][][]float32 - err = json.Unmarshal(resultsByte, &expectedResults) - check(t, err) - var testResults [][]float32 - - // test 'robert smith' - testResults = expectedResults["test1output"] - for i := 1; i <= 10; i++ { - batchResult, err := pipeline.RunPipeline([]string{"robert smith"}) - check(t, err) - e := floatsEqual(batchResult.Embeddings[0], testResults[0]) - if e != nil { - t.Logf("Test 1: The neural network didn't produce the correct result on loop %d: %s\n", i, e) - t.FailNow() - } - } - - // test ['robert smith junior', 'francis ford coppola'] - testResults = expectedResults["test2output"] - for i := 1; i <= 10; i++ { - batchResult, err := pipeline.RunPipeline([]string{"robert smith junior", "francis ford coppola"}) - check(t, err) - for j, res := range batchResult.Embeddings { - e := floatsEqual(res, testResults[j]) - if e != nil { - t.Logf("Test 2: The neural network didn't produce the correct result on loop %d: %s\n", i, e) - t.FailNow() - } - } - } - - // determinism test to make sure embeddings of a string are not influenced by other strings in the batch - testPairs := map[string][][]string{} - testPairs["identity"] = [][]string{{"sinopharm", "yo"}, {"sinopharm", "yo"}} - testPairs["contextOverlap"] = [][]string{{"sinopharm", "yo"}, {"sinopharm", "yo mama yo"}} - testPairs["contextDisjoint"] = [][]string{{"sinopharm", "yo"}, {"sinopharm", "another test"}} - - for k, sentencePair := range testPairs { - // these vectors should be the same - firstBatchResult, err2 := pipeline.RunPipeline(sentencePair[0]) - check(t, err2) - firstEmbedding := firstBatchResult.Embeddings[0] - - secondBatchResult, err3 := 
pipeline.RunPipeline(sentencePair[1]) - check(t, err3) - secondEmbedding := secondBatchResult.Embeddings[0] - e := floatsEqual(firstEmbedding, secondEmbedding) - if e != nil { - t.Logf("Equality failed for determinism test %s test with pairs %s and %s", k, strings.Join(sentencePair[0], ","), strings.Join(sentencePair[1], ",")) - t.Log("First vector", firstEmbedding) - t.Log("second vector", secondEmbedding) - t.Fail() - } - } - - zero := uint64(0) - assert.Greater(t, pipeline.PipelineTimings.NumCalls, zero, "PipelineTimings.NumCalls should be greater than 0") - assert.Greater(t, pipeline.PipelineTimings.TotalNS, zero, "PipelineTimings.TotalNS should be greater than 0") - assert.Greater(t, pipeline.TokenizerTimings.NumCalls, zero, "TokenizerTimings.NumCalls should be greater than 0") - assert.Greater(t, pipeline.TokenizerTimings.TotalNS, zero, "TokenizerTimings.TotalNS should be greater than 0") - - // test normalization - testResults = expectedResults["normalizedOutput"] - config = FeatureExtractionConfig{ - ModelPath: modelPath, - Name: "testPipelineNormalise", - Options: []FeatureExtractionOption{ - pipelines.WithNormalization(), - }, - } - pipeline, err = NewPipeline(session, config) - check(t, err) - normalizationStrings := []string{"Onnxruntime is a great inference backend"} - normalizedEmbedding, err := pipeline.RunPipeline(normalizationStrings) - check(t, err) - for i, embedding := range normalizedEmbedding.Embeddings { - e := floatsEqual(embedding, testResults[i]) - if e != nil { - t.Fatalf("Normalization test failed: %s", normalizationStrings[i]) - } - } -} - -func TestFeatureExtractionPipelineValidation(t *testing.T) { - session, err := NewSession(WithOnnxLibraryPath(onnxRuntimeSharedLibrary)) - check(t, err) - defer func(session *Session) { - err := session.Destroy() - check(t, err) - }(session) - - modelPath := downloadModelIfNotExists(session, "KnightsAnalytics/all-MiniLM-L6-v2", "./models") - config := FeatureExtractionConfig{ - ModelPath: modelPath, - Name: "testPipeline", - } - pipeline, err := NewPipeline(session, config) - check(t, err) - - pipeline.OutputDim = 0 - err = pipeline.Validate() - assert.Error(t, err) -} - // README: test the readme examples func TestReadmeExample(t *testing.T) { @@ -697,7 +742,17 @@ func BenchmarkCPUEmbedding(b *testing.B) { } } // utilities + +func checkClassificationOutput(t *testing.T, inputResult []pipelines.ClassificationOutput, inputExpected []pipelines.ClassificationOutput) { + t.Helper() + assert.Equal(t, len(inputResult), len(inputExpected)) + for i, output := range inputResult { + resultExpected := inputExpected[i] + assert.Equal(t, output.Label, resultExpected.Label) + assert.True(t, almostEqual(float64(output.Score), float64(resultExpected.Score))) + } +} // Returns an error if any element between a and b don't match. 
func floatsEqual(a, b []float32) error { @@ -718,16 +773,6 @@ func floatsEqual(a, b []float32) error { return nil } -func checkClassificationOutput(t *testing.T, inputResult []pipelines.ClassificationOutput, inputExpected []pipelines.ClassificationOutput) { - t.Helper() - assert.Equal(t, len(inputResult), len(inputExpected)) - for i, output := range inputResult { - resultExpected := inputExpected[i] - assert.Equal(t, output.Label, resultExpected.Label) - assert.True(t, almostEqual(float64(output.Score), float64(resultExpected.Score))) - } -} - func almostEqual(a, b float64) bool { return math.Abs(a-b) <= 0.0001 } diff --git a/pipelines/featureExtraction.go b/pipelines/featureExtraction.go index 5aa1a8a..a5ab852 100644 --- a/pipelines/featureExtraction.go +++ b/pipelines/featureExtraction.go @@ -2,6 +2,11 @@ package pipelines import ( "errors" + "fmt" + "math" + "strings" + "sync/atomic" + "time" ort "github.com/yalue/onnxruntime_go" @@ -11,16 +16,11 @@ import ( // FeatureExtractionPipeline A feature extraction pipeline is a go version of // https://github.com/huggingface/transformers/blob/main/src/transformers/pipelines/feature_extraction.py - -// types - type FeatureExtractionPipeline struct { - BasePipeline - Normalization bool -} - -type FeatureExtractionPipelineConfig struct { - IdLabelMap map[int]string `json:"id2label"` + basePipeline + normalization bool + outputName string + output ort.InputOutputInfo } type FeatureExtractionOutput struct { @@ -35,15 +35,24 @@ func (t *FeatureExtractionOutput) GetOutput() []any { return out } -// options +// PIPELINE OPTIONS +// WithNormalization applies normalization to the mean pooled output of the feature pipeline. func WithNormalization() PipelineOption[*FeatureExtractionPipeline] { return func(pipeline *FeatureExtractionPipeline) { - pipeline.Normalization = true + pipeline.normalization = true + } +} + +// WithOutputName selects which output of the underlying model is returned when the model +// has multiple outputs. If not set, the first output of the model is returned. +func WithOutputName(outputName string) PipelineOption[*FeatureExtractionPipeline] { + return func(pipeline *FeatureExtractionPipeline) { + pipeline.outputName = outputName } } -// NewFeatureExtractionPipeline Initialize a feature extraction pipeline +// NewFeatureExtractionPipeline initializes a feature extraction pipeline. func NewFeatureExtractionPipeline(config PipelineConfig[*FeatureExtractionPipeline], ortOptions *ort.SessionOptions) (*FeatureExtractionPipeline, error) { pipeline := &FeatureExtractionPipeline{} pipeline.ModelPath = config.ModelPath @@ -55,79 +64,194 @@ func NewFeatureExtractionPipeli o(pipeline) } - // tokenizer + // tokenizer init pipeline.TokenizerOptions = []tokenizers.EncodeOption{tokenizers.WithReturnTypeIDs(), tokenizers.WithReturnAttentionMask()} + tk, err := loadTokenizer(pipeline.ModelPath) + if err != nil { + return nil, err + } + pipeline.Tokenizer = tk - pipeline.PipelineTimings = &Timings{} - pipeline.TokenizerTimings = &Timings{} + // onnx model init + model, err := loadOnnxModelBytes(pipeline.ModelPath, pipeline.OnnxFilename) + if err != nil { + return nil, err + } - // load onnx model - err := pipeline.loadModel() + // init of inputs and outputs + inputs, outputs, err := loadInputOutputMeta(model) if err != nil { return nil, err } + pipeline.InputsMeta = inputs + pipeline.OutputsMeta = outputs - // the dimension of the output is taken from the output meta. 
For the moment we assume that there is only one output - pipeline.OutputDim = int(pipeline.OutputsMeta[0].Dimensions[2]) + // filter outputs + if pipeline.outputName != "" { + for _, output := range outputs { + if output.Name == pipeline.outputName { + pipeline.output = output + break + } + } + if pipeline.output.Name == "" { + return nil, fmt.Errorf("output %s is not available, outputs are: %s", pipeline.outputName, strings.Join(getNames(outputs), ", ")) + } + } else { + pipeline.output = outputs[0] // we take the first output otherwise, like transformers does + } - err = pipeline.Validate() + // creation of the session. Only one output (either token or sentence embedding). + session, err := createSession(model, inputs, []ort.InputOutputInfo{pipeline.output}, ortOptions) if err != nil { return nil, err } + pipeline.OrtSession = session + + // initialize timings + + pipeline.PipelineTimings = &timings{} + pipeline.TokenizerTimings = &timings{} + // validate pipeline + err = pipeline.Validate() + if err != nil { + errDestroy := pipeline.Destroy() + return nil, errors.Join(err, errDestroy) + } return pipeline, nil } +// INTERFACE IMPLEMENTATION + +// Destroy frees the feature extraction pipeline resources. +func (p *FeatureExtractionPipeline) Destroy() error { + return destroySession(p.Tokenizer, p.OrtSession) +} + +// GetStats returns the runtime statistics for the pipeline. +func (p *FeatureExtractionPipeline) GetStats() []string { + return []string{ + fmt.Sprintf("Statistics for pipeline: %s", p.PipelineName), + fmt.Sprintf("Tokenizer: Total time=%s, Execution count=%d, Average query time=%s", + time.Duration(p.TokenizerTimings.TotalNS), + p.TokenizerTimings.NumCalls, + time.Duration(float64(p.TokenizerTimings.TotalNS)/math.Max(1, float64(p.TokenizerTimings.NumCalls)))), + fmt.Sprintf("ONNX: Total time=%s, Execution count=%d, Average query time=%s", + time.Duration(p.PipelineTimings.TotalNS), + p.PipelineTimings.NumCalls, + time.Duration(float64(p.PipelineTimings.TotalNS)/math.Max(1, float64(p.PipelineTimings.NumCalls)))), + } +} + +// Validate checks that the pipeline is valid. func (p *FeatureExtractionPipeline) Validate() error { var validationErrors []error - if p.OutputDim <= 0 { - validationErrors = append(validationErrors, errors.New("pipeline configuration invalid: outputDim parameter must be greater than zero")) + for _, input := range p.InputsMeta { + dims := []int64(input.Dimensions) + if len(dims) > 3 { + validationErrors = append(validationErrors, fmt.Errorf("inputs and outputs currently can have at most 3 dimensions")) + } + nDynamicDimensions := 0 + for _, d := range dims { + if d == -1 { + nDynamicDimensions++ + } + } + if nDynamicDimensions > 2 { + validationErrors = append(validationErrors, fmt.Errorf(`input %s has dimensions: %s. + There can only be max 2 dynamic dimensions (batch size and sequence length)`, + input.Name, input.Dimensions.String())) + } } return errors.Join(validationErrors...) } -// Postprocess Parse the results of the forward pass into the output. Token embeddings are mean pooled. 
-func (p *FeatureExtractionPipeline) Postprocess(batch PipelineBatch) (*FeatureExtractionOutput, error) { - maxSequence := batch.MaxSequence - vectorCounter := 0 - tokenCounter := 0 - inputCounter := 0 - outputs := make([][]float32, len(batch.Input)) - tokens := make([][]float32, maxSequence) - vectors := make([]float32, p.OutputDim) - - for _, result := range batch.OutputTensor { - vectors[vectorCounter] = result - if vectorCounter == p.OutputDim-1 { - tokens[tokenCounter] = vectors - vectorCounter = 0 - vectors = make([]float32, p.OutputDim) - if tokenCounter == maxSequence-1 { - outputs[inputCounter] = meanPooling(tokens, batch.Input[inputCounter], maxSequence, p.OutputDim) - tokenCounter = 0 - tokens = make([][]float32, maxSequence) - inputCounter++ +// Preprocess tokenizes the input strings. +func (p *FeatureExtractionPipeline) Preprocess(batch *PipelineBatch, inputs []string) error { + start := time.Now() + tokenizeInputs(batch, p.Tokenizer, inputs, p.TokenizerOptions) + atomic.AddUint64(&p.TokenizerTimings.NumCalls, 1) + atomic.AddUint64(&p.TokenizerTimings.TotalNS, uint64(time.Since(start))) + err := createInputTensors(batch, p.InputsMeta) + return err +} + +// Forward performs the forward inference of the feature extraction pipeline. +func (p *FeatureExtractionPipeline) Forward(batch *PipelineBatch) error { + start := time.Now() + err := runSessionOnBatch(batch, p.OrtSession, []ort.InputOutputInfo{p.output}) + if err != nil { + return err + } + atomic.AddUint64(&p.PipelineTimings.NumCalls, 1) + atomic.AddUint64(&p.PipelineTimings.TotalNS, uint64(time.Since(start))) + return nil +} + +// Postprocess parses the first output from the network similar to the transformers implementation. +func (p *FeatureExtractionPipeline) Postprocess(batch *PipelineBatch) (*FeatureExtractionOutput, error) { + // TODO: this works if token embeddings are returned or sentence embeddings are returned. + // in the former case embeddings are mean pooled. In the latter they are just returned. + // to make this more general for other pipelines and to allow return of raw token embeddings, + // we need an ndarray type that can be the return type of this pipeline. Need to think + // about how to do this in a lightweight manner. 
+ + batchEmbeddings := make([][]float32, len(batch.Input)) + outputDimensions := []int64(p.output.Dimensions) + embeddingDimension := outputDimensions[len(outputDimensions)-1] + maxSequenceLength := batch.MaxSequenceLength + + // now take the output slice and gather the results as a "matrix" + outputEmbedding := make([]float32, embeddingDimension) + outputEmbeddingCounter := 0 + tokenEmbeddings := make([][]float32, maxSequenceLength) + tokenEmbeddingsCounter := 0 + batchInputCounter := 0 + + for _, result := range batch.OutputTensors[0].GetData() { + outputEmbedding[outputEmbeddingCounter] = result + if outputEmbeddingCounter == int(embeddingDimension)-1 { + // we gathered one embedding + if len(outputDimensions) <= 2 { + // it is already a sentence embedding, just add it to batch outputs + batchEmbeddings[batchInputCounter] = outputEmbedding + outputEmbedding = make([]float32, embeddingDimension) + batchInputCounter++ } else { - tokenCounter++ + // output is embedding for a token, add to token embeddings + tokenEmbeddings[tokenEmbeddingsCounter] = outputEmbedding + outputEmbedding = make([]float32, embeddingDimension) + if tokenEmbeddingsCounter == maxSequenceLength-1 { + // computed all embeddings for the tokens, calculate sentence embedding, add to batch outputs, and reset token embeddings and counter + batchEmbeddings[batchInputCounter] = meanPooling(tokenEmbeddings, batch.Input[batchInputCounter], maxSequenceLength, int(embeddingDimension)) + tokenEmbeddings = make([][]float32, maxSequenceLength) + tokenEmbeddingsCounter = 0 + batchInputCounter++ + } else { + // still more tokens to go + tokenEmbeddingsCounter++ + } } + outputEmbeddingCounter = 0 } else { - vectorCounter++ + // still more elements of the embedding to go + outputEmbeddingCounter++ } } // Normalize embeddings (if asked), like in https://huggingface.co/sentence-transformers/all-mpnet-base-v2 - if p.Normalization { - for i, output := range outputs { - outputs[i] = util.Normalize(output, 2) + if p.normalization { + for i, output := range batchEmbeddings { + batchEmbeddings[i] = util.Normalize(output, 2) } } - return &FeatureExtractionOutput{Embeddings: outputs}, nil + return &FeatureExtractionOutput{Embeddings: batchEmbeddings}, nil } -func meanPooling(tokens [][]float32, input TokenizedInput, maxSequence int, dimensions int) []float32 { - +func meanPooling(tokens [][]float32, input tokenizedInput, maxSequence int, dimensions int) []float32 { length := len(input.AttentionMask) vector := make([]float32, dimensions) for j := 0; j < maxSequence; j++ { @@ -146,16 +270,30 @@ func meanPooling(tokens [][]float32, input TokenizedInput, maxSequence int, dime return vector } -// Run the pipeline on a string batch +// Run the pipeline on a batch of strings. func (p *FeatureExtractionPipeline) Run(inputs []string) (PipelineBatchOutput, error) { return p.RunPipeline(inputs) } +// RunPipeline is like Run, but returns the concrete feature extraction output type rather than the interface. 
func (p *FeatureExtractionPipeline) RunPipeline(inputs []string) (*FeatureExtractionOutput, error) { - batch := p.Preprocess(inputs) - batch, forwardError := p.Forward(batch) - if forwardError != nil { - return nil, forwardError + var runErrors []error + batch := NewBatch() + defer func(*PipelineBatch) { + runErrors = append(runErrors, batch.Destroy()) + }(batch) + + runErrors = append(runErrors, p.Preprocess(batch, inputs)) + if e := errors.Join(runErrors...); e != nil { + return nil, e } - return p.Postprocess(batch) + + runErrors = append(runErrors, p.Forward(batch)) + if e := errors.Join(runErrors...); e != nil { + return nil, e + } + + result, postErr := p.Postprocess(batch) + runErrors = append(runErrors, postErr) + return result, errors.Join(runErrors...) } diff --git a/pipelines/pipeline.go b/pipelines/pipeline.go index fd3e02b..bb94c05 100644 --- a/pipelines/pipeline.go +++ b/pipelines/pipeline.go @@ -5,11 +5,8 @@ import ( "errors" "fmt" "io" - "math" "os" "strings" - "sync/atomic" - "time" "github.com/knights-analytics/tokenizers" ort "github.com/yalue/onnxruntime_go" util "github.com/knights-analytics/hugot/utils" ) -// BasePipeline is a basic pipeline type used for struct composition in the other pipelines. -type BasePipeline struct { +// basePipeline can be embedded by a pipeline. +type basePipeline struct { ModelPath string OnnxFilename string PipelineName string @@ -28,27 +25,27 @@ type BasePipeline struct { TokenizerOptions []tokenizers.EncodeOption InputsMeta []ort.InputOutputInfo OutputsMeta []ort.InputOutputInfo - hasTokenTypeIds bool - hasAttentionMask bool - OutputDim int - TokenizerTimings *Timings - PipelineTimings *Timings + TokenizerTimings *timings + PipelineTimings *timings } type PipelineBatchOutput interface { GetOutput() []any } +// Pipeline is the interface that any pipeline must implement. type Pipeline interface { - Destroy() error - GetStats() []string - GetOutputDim() int - Validate() error - Run([]string) (PipelineBatchOutput, error) + Destroy() error // Destroy the pipeline along with its onnx session + GetStats() []string // Get the pipeline running stats + Validate() error // Validate the pipeline for correctness + Run([]string) (PipelineBatchOutput, error) // Run the pipeline on an input } +// PipelineOption is an option for a pipeline type. type PipelineOption[T Pipeline] func(eo T) +// PipelineConfig is a configuration for a pipeline type that can be used +// to create that pipeline. type PipelineConfig[T Pipeline] struct { ModelPath string Name string @@ -56,81 +53,84 @@ type PipelineConfig[T Pipeline] struct { Options []PipelineOption[T] } -type Timings struct { +type timings struct { NumCalls uint64 TotalNS uint64 } -type TokenizedInput struct { +// tokenizedInput holds the result of running tokenizer on an input. +type tokenizedInput struct { Raw string Tokens []string - TokenIds []uint32 - TypeIds []uint32 + TokenIDs []uint32 + TypeIDs []uint32 AttentionMask []uint32 SpecialTokensMask []uint32 MaxAttentionIndex int Offsets []tokenizers.Offset } +// PipelineBatch represents a batch of inputs that runs through the pipeline. 
type PipelineBatch struct { - Input []TokenizedInput - IdsTensor []int64 - TypeIdsTensor []int64 - AttentionMasksTensor []int64 - MaxSequence int - OutputTensor []float32 + Input []tokenizedInput + InputTensors []*ort.Tensor[int64] + MaxSequenceLength int + OutputTensors []*ort.Tensor[float32] } -func (p *BasePipeline) GetOutputDim() int { - return p.OutputDim -} +func (b *PipelineBatch) Destroy() error { + destroyErrors := make([]error, 0, len(b.InputTensors)+len(b.OutputTensors)) -func getOnnxFiles(path string) ([][]string, error) { - var onnxFiles [][]string - walker := func(_ context.Context, _ string, parent string, info os.FileInfo, _ io.Reader) (toContinue bool, err error) { - if strings.HasSuffix(info.Name(), ".onnx") { - onnxFiles = append(onnxFiles, []string{util.PathJoinSafe(path, parent), info.Name()}) - } - return true, nil + for _, tensor := range b.InputTensors { + destroyErrors = append(destroyErrors, tensor.Destroy()) } - err := util.FileSystem.Walk(context.Background(), path, walker) - return onnxFiles, err + + for _, tensor := range b.OutputTensors { + destroyErrors = append(destroyErrors, tensor.Destroy()) + } + return errors.Join(destroyErrors...) } -// Load the ort model supporting the pipeline. -func (p *BasePipeline) loadModel() error { - tokenizerBytes, err := util.ReadFileBytes(util.PathJoinSafe(p.ModelPath, "tokenizer.json")) +// NewBatch initializes a new batch for inference. +func NewBatch() *PipelineBatch { + return &PipelineBatch{} +} + +func loadTokenizer(modelPath string) (*tokenizers.Tokenizer, error) { + tokenizerBytes, err := util.ReadFileBytes(util.PathJoinSafe(modelPath, "tokenizer.json")) if err != nil { - return err + return nil, err } tk, err := tokenizers.FromBytes(tokenizerBytes) if err != nil { - return err + return nil, err } + return tk, nil +} - // we look for .onnx files. +func loadOnnxModelBytes(modelPath string, modelFilename string) ([]byte, error) { var modelOnnxFile string - onnxFiles, err := getOnnxFiles(p.ModelPath) + onnxFiles, err := getOnnxFiles(modelPath) if err != nil { - return err + return nil, err } if len(onnxFiles) == 0 { - return fmt.Errorf("no .onnx file detected at %s. There should be exactly .onnx file", p.ModelPath) + return nil, fmt.Errorf("no .onnx file detected at %s. There should be exactly one .onnx file", modelPath) } if len(onnxFiles) > 1 { - if p.OnnxFilename == "" { - return fmt.Errorf("multiple .onnx file detected at %s and no OnnxFilename specified", p.ModelPath) + if modelFilename == "" { + return nil, fmt.Errorf("multiple .onnx files detected at %s and no OnnxFilename specified", modelPath) } modelNameFound := false for i := range onnxFiles { - if onnxFiles[i][1] == p.OnnxFilename { + if onnxFiles[i][1] == modelFilename { modelNameFound = true modelOnnxFile = util.PathJoinSafe(onnxFiles[i]...) } } if !modelNameFound { - return fmt.Errorf("file %s not found at %s", p.OnnxFilename, p.ModelPath) + return nil, fmt.Errorf("file %s not found at %s", modelFilename, modelPath) } } else { modelOnnxFile = util.PathJoinSafe(onnxFiles[0]...) 
@@ -138,70 +138,57 @@ func (p *BasePipeline) loadModel() error { onnxBytes, err := util.ReadFileBytes(modelOnnxFile) if err != nil { - return err + return nil, err } + return onnxBytes, err +} +func loadInputOutputMeta(onnxBytes []byte) ([]ort.InputOutputInfo, []ort.InputOutputInfo, error) { inputs, outputs, err := ort.GetInputOutputInfoWithONNXData(onnxBytes) if err != nil { - return err + return nil, nil, err } + return inputs, outputs, nil +} - p.InputsMeta = inputs - p.OutputsMeta = outputs - - inputNames := make([]string, len(inputs)) - for i, meta := range inputs { - inputNames[i] = meta.Name - switch meta.Name { - case "token_type_ids": - p.hasTokenTypeIds = true - case "attention_mask": - p.hasAttentionMask = true - } +func createSession(onnxBytes []byte, inputs, outputs []ort.InputOutputInfo, options *ort.SessionOptions) (*ort.DynamicAdvancedSession, error) { + inputNames := []string{} + outputNames := []string{} + for _, v := range inputs { + inputNames = append(inputNames, v.Name) } - outputNames := make([]string, len(outputs)) - for i, meta := range outputs { - outputNames[i] = meta.Name + for _, v := range outputs { + outputNames = append(outputNames, v.Name) } session, err := ort.NewDynamicAdvancedSessionWithONNXData( onnxBytes, inputNames, outputNames, - p.OrtOptions, + options, ) - if err != nil { - return err - } - - p.OrtSession = session - p.Tokenizer = tk - return nil + return session, err } -func (p *BasePipeline) Destroy() error { - var finalErr error - errTokenizer := p.Tokenizer.Close() - if errTokenizer != nil { - finalErr = errTokenizer - } - ortError := p.OrtSession.Destroy() - if ortError != nil { - finalErr = ortError +func getOnnxFiles(path string) ([][]string, error) { + var onnxFiles [][]string + walker := func(_ context.Context, _ string, parent string, info os.FileInfo, _ io.Reader) (toContinue bool, err error) { + if strings.HasSuffix(info.Name(), ".onnx") { + onnxFiles = append(onnxFiles, []string{util.PathJoinSafe(path, parent), info.Name()}) + } + return true, nil } - return finalErr + err := util.FileSystem.Walk(context.Background(), path, walker) + return onnxFiles, err } -// Preprocess the input strings in the batch -func (p *BasePipeline) Preprocess(inputs []string) PipelineBatch { - start := time.Now() - - outputs := make([]TokenizedInput, len(inputs)) +func tokenizeInputs(batch *PipelineBatch, tk *tokenizers.Tokenizer, inputs []string, options []tokenizers.EncodeOption) { + outputs := make([]tokenizedInput, len(inputs)) maxSequence := 0 for i, input := range inputs { - output := p.Tokenizer.EncodeWithOptions(input, + output := tk.EncodeWithOptions(input, true, - p.TokenizerOptions..., + options..., ) maxAttentionIndex := 0 @@ -211,11 +198,11 @@ func (p *BasePipeline) Preprocess(inputs []string) PipelineBatch { } } - outputs[i] = TokenizedInput{ + outputs[i] = tokenizedInput{ Raw: input, Tokens: output.Tokens, - TokenIds: output.IDs, - TypeIds: output.TypeIDs, + TokenIDs: output.IDs, + TypeIDs: output.TypeIDs, AttentionMask: output.AttentionMask, MaxAttentionIndex: maxAttentionIndex, SpecialTokensMask: output.SpecialTokensMask, @@ -225,114 +212,121 @@ func (p *BasePipeline) Preprocess(inputs []string) PipelineBatch { maxSequence = maxAttentionIndex } } - - atomic.AddUint64(&p.TokenizerTimings.NumCalls, 1) - atomic.AddUint64(&p.TokenizerTimings.TotalNS, uint64(time.Since(start))) - batch := p.convertInputToTensors(outputs, maxSequence+1) - return batch + batch.Input = outputs + batch.MaxSequenceLength = maxSequence + 1 } -func (p *BasePipeline) 
getInputTensors(batch PipelineBatch, actualBatchSize int64, maxSequence int64) ([]ort.ArbitraryTensor, error) { - inputTensors := make([]ort.ArbitraryTensor, len(p.InputsMeta)) - var err error - - for i, input := range p.InputsMeta { - var inputTensor *ort.Tensor[int64] - - // create the tensor for the input name - switch input.Name { - case "input_ids": - inputTensor, err = ort.NewTensor(ort.NewShape(actualBatchSize, maxSequence), batch.IdsTensor) - case "token_type_ids": - inputTensor, err = ort.NewTensor(ort.NewShape(actualBatchSize, maxSequence), batch.TypeIdsTensor) - case "attention_mask": - inputTensor, err = ort.NewTensor(ort.NewShape(actualBatchSize, maxSequence), batch.AttentionMasksTensor) +// createInputTensors creates ort input tensors. +func createInputTensors(batch *PipelineBatch, inputsMeta []ort.InputOutputInfo) error { + tensorSize := len(batch.Input) * (batch.MaxSequenceLength) + batchSize := int64(len(batch.Input)) + + inputTensors := make([]*ort.Tensor[int64], len(inputsMeta)) + var tensorCreationErr error + + for i, inputMeta := range inputsMeta { + backingSlice := make([]int64, tensorSize) + counter := 0 + + for _, input := range batch.Input { + length := len(input.TokenIDs) + for j := 0; j < batch.MaxSequenceLength; j++ { + if j+1 <= length { + switch inputMeta.Name { + case "input_ids": + backingSlice[counter] = int64(input.TokenIDs[j]) + case "token_type_ids": + backingSlice[counter] = int64(input.TypeIDs[j]) + case "attention_mask": + backingSlice[counter] = int64(input.AttentionMask[j]) + default: + return fmt.Errorf("input %s not recognized", inputMeta.Name) + } + } else { + backingSlice[counter] = 0 // pad with zero + } + counter++ + } + } + inputTensors[i], tensorCreationErr = ort.NewTensor(ort.NewShape(batchSize, int64(batch.MaxSequenceLength)), backingSlice) + if tensorCreationErr != nil { + return tensorCreationErr } - - inputTensors[i] = inputTensor } - return inputTensors, err + batch.InputTensors = inputTensors + return nil } -// Forward pass of the neural network on the tokenized input -func (p *BasePipeline) Forward(batch PipelineBatch) (PipelineBatch, error) { - start := time.Now() - - actualBatchSize := int64(len(batch.Input)) - maxSequence := int64(batch.MaxSequence) - inputTensors, err := p.getInputTensors(batch, actualBatchSize, maxSequence) - if err != nil { - return batch, err - } - - outputTensor, err4 := ort.NewEmptyTensor[float32](ort.NewShape(actualBatchSize, maxSequence, int64(p.OutputDim))) - if err4 != nil { - return batch, err4 - } - - defer func(inputTensors []ort.ArbitraryTensor) { - for _, tensor := range inputTensors { - err = errors.Join(err, tensor.Destroy()) - } - }(inputTensors) - - // Run Onnx model - errOnnx := p.OrtSession.Run(inputTensors, []ort.ArbitraryTensor{outputTensor}) - if errOnnx != nil { - return batch, errOnnx +func getNames(info []ort.InputOutputInfo) []string { + names := make([]string, 0, len(info)) + for _, v := range info { + names = append(names, v.Name) } - batch.OutputTensor = outputTensor.GetData() - defer func(outputTensor *ort.Tensor[float32]) { - err = errors.Join(err, outputTensor.Destroy()) - }(outputTensor) - - atomic.AddUint64(&p.PipelineTimings.NumCalls, 1) - atomic.AddUint64(&p.PipelineTimings.TotalNS, uint64(time.Since(start))) - return batch, err + return names } -// convert tokenized input to the format required by the onnxruntime library -func (p *BasePipeline) convertInputToTensors(inputs []TokenizedInput, maxSequence int) PipelineBatch { - tensorSize := len(inputs) * maxSequence - counter 
:= 0 - - idsTensor := make([]int64, tensorSize) - typeIdsTensor := make([]int64, tensorSize) - attentionMasksTensor := make([]int64, tensorSize) - - for _, input := range inputs { - length := len(input.TokenIds) - for j := 0; j < maxSequence; j++ { - if j+1 <= length { - idsTensor[counter] = int64(input.TokenIds[j]) - if p.hasTokenTypeIds { - typeIdsTensor[counter] = int64(input.TypeIds[j]) - } - if p.hasAttentionMask { - attentionMasksTensor[counter] = int64(input.AttentionMask[j]) +func runSessionOnBatch(batch *PipelineBatch, session *ort.DynamicAdvancedSession, outputs []ort.InputOutputInfo) error { + actualBatchSize := int64(len(batch.Input)) + maxSequenceLength := int64(batch.MaxSequenceLength) + + // allocate vectors with right dimensions for the output + outputTensors := make([]*ort.Tensor[float32], len(outputs)) + arbitraryOutputTensors := make([]ort.ArbitraryTensor, len(outputs)) + var outputCreationErr error + + for outputIndex, meta := range outputs { + var batchDimSet bool + var tokenDimSet bool + actualDims := make([]int64, 0, len(meta.Dimensions)) + + for _, dim := range meta.Dimensions { + if dim == -1 { + if !batchDimSet { + actualDims = append(actualDims, actualBatchSize) + batchDimSet = true + } else if !tokenDimSet { + actualDims = append(actualDims, maxSequenceLength) + tokenDimSet = true + } else { + return fmt.Errorf("only two axes can be dynamic (batch size and number of tokens)") } } else { - // padding all vectors to max sequence length - idsTensor[counter] = 0 - typeIdsTensor[counter] = 0 - attentionMasksTensor[counter] = 0 + actualDims = append(actualDims, dim) } - counter++ } + outputShape := ort.NewShape(actualDims...) + outputTensors[outputIndex], outputCreationErr = ort.NewEmptyTensor[float32](outputShape) + if outputCreationErr != nil { + return outputCreationErr + } + arbitraryOutputTensors[outputIndex] = ort.ArbitraryTensor(outputTensors[outputIndex]) } - return PipelineBatch{ - Input: inputs, - IdsTensor: idsTensor, - TypeIdsTensor: typeIdsTensor, - AttentionMasksTensor: attentionMasksTensor, - MaxSequence: maxSequence, + + // Run Onnx model + arbitraryInputTensors := make([]ort.ArbitraryTensor, len(batch.InputTensors)) + for i, t := range batch.InputTensors { + arbitraryInputTensors[i] = ort.ArbitraryTensor(t) + } + + errOnnx := session.Run(arbitraryInputTensors, arbitraryOutputTensors) + if errOnnx != nil { + return errOnnx } + + // store resulting tensors + batch.OutputTensors = outputTensors + return nil } -func (p *BasePipeline) GetStats() []string { - return []string{ - fmt.Sprintf("Statistics for pipeline: %s", p.PipelineName), - fmt.Sprintf("Tokenizer: Total time=%s, Execution count=%d, Average query time=%s", time.Duration(p.TokenizerTimings.TotalNS), p.TokenizerTimings.NumCalls, time.Duration(float64(p.TokenizerTimings.TotalNS)/math.Max(1, float64(p.TokenizerTimings.NumCalls)))), - fmt.Sprintf("ONNX: Total time=%s, Execution count=%d, Average query time=%s", time.Duration(p.PipelineTimings.TotalNS), p.PipelineTimings.NumCalls, time.Duration(float64(p.PipelineTimings.TotalNS)/math.Max(1, float64(p.PipelineTimings.NumCalls)))), +func destroySession(tk *tokenizers.Tokenizer, session *ort.DynamicAdvancedSession) error { + var finalErr error + errTokenizer := tk.Close() + if errTokenizer != nil { + finalErr = errTokenizer } + ortError := session.Destroy() + if ortError != nil { + finalErr = ortError + } + return finalErr } diff --git a/pipelines/textClassification.go b/pipelines/textClassification.go index 02504bd..7568a56 100644 --- 
a/pipelines/textClassification.go +++ b/pipelines/textClassification.go @@ -3,6 +3,7 @@ package pipelines import ( "errors" "fmt" + "math" "sync/atomic" "time" @@ -16,14 +17,14 @@ import ( // types type TextClassificationPipeline struct { - BasePipeline - IdLabelMap map[int]string + basePipeline + IDLabelMap map[int]string AggregationFunctionName string ProblemType string } type TextClassificationPipelineConfig struct { - IdLabelMap map[int]string `json:"id2label"` + IDLabelMap map[int]string `json:"id2label"` } type ClassificationOutput struct { @@ -71,7 +72,7 @@ func WithMultiLabel() PipelineOption[*TextClassificationPipeline] { } } -// NewTextClassificationPipeline initializes a new text classification pipeline +// NewTextClassificationPipeline initializes a new text classification pipeline. func NewTextClassificationPipeline(config PipelineConfig[*TextClassificationPipeline], ortOptions *ort.SessionOptions) (*TextClassificationPipeline, error) { pipeline := &TextClassificationPipeline{} pipeline.ModelPath = config.ModelPath @@ -94,10 +95,17 @@ func NewTextClassificationPipeline(config PipelineConfig[*TextClassificationPipe } } + // tokenizer init pipeline.TokenizerOptions = []tokenizers.EncodeOption{ tokenizers.WithReturnAttentionMask(), } + tk, err := loadTokenizer(pipeline.ModelPath) + if err != nil { + return nil, err + } + pipeline.Tokenizer = tk + // read id to label map configPath := util.PathJoinSafe(pipeline.ModelPath, "config.json") pipelineInputConfig := TextClassificationPipelineConfig{} mapBytes, err := util.ReadFileBytes(configPath) @@ -109,88 +117,122 @@ func NewTextClassificationPipeline(config PipelineConfig[*TextClassificationPipe return nil, err } - pipeline.IdLabelMap = pipelineInputConfig.IdLabelMap - pipeline.PipelineTimings = &Timings{} - pipeline.TokenizerTimings = &Timings{} + pipeline.IDLabelMap = pipelineInputConfig.IDLabelMap - // load onnx model - loadErr := pipeline.loadModel() - if loadErr != nil { - return nil, loadErr + // onnx model init + model, err := loadOnnxModelBytes(pipeline.ModelPath, pipeline.OnnxFilename) + if err != nil { + return nil, err } - pipeline.OutputDim = int(pipeline.OutputsMeta[0].Dimensions[1]) + // init of inputs and outputs + inputs, outputs, err := loadInputOutputMeta(model) + if err != nil { + return nil, err + } + pipeline.InputsMeta = inputs + pipeline.OutputsMeta = outputs - // validate - validationErrors := pipeline.Validate() - if validationErrors != nil { - return nil, validationErrors + // creation of the session + session, err := createSession(model, inputs, pipeline.OutputsMeta, ortOptions) + if err != nil { + return nil, err } + pipeline.OrtSession = session + // initialize timings + pipeline.PipelineTimings = &timings{} + pipeline.TokenizerTimings = &timings{} + + // validate + err = pipeline.Validate() + if err != nil { + errDestroy := pipeline.Destroy() + return nil, errors.Join(err, errDestroy) + } return pipeline, nil } +// INTERFACE IMPLEMENTATION + +// Destroy frees the text classification pipeline resources. +func (p *TextClassificationPipeline) Destroy() error { + return destroySession(p.Tokenizer, p.OrtSession) +} + +// GetStats returns the runtime statistics for the pipeline. 
+func (p *TextClassificationPipeline) GetStats() []string { + return []string{ + fmt.Sprintf("Statistics for pipeline: %s", p.PipelineName), + fmt.Sprintf("Tokenizer: Total time=%s, Execution count=%d, Average query time=%s", + time.Duration(p.TokenizerTimings.TotalNS), + p.TokenizerTimings.NumCalls, + time.Duration(float64(p.TokenizerTimings.TotalNS)/math.Max(1, float64(p.TokenizerTimings.NumCalls)))), + fmt.Sprintf("ONNX: Total time=%s, Execution count=%d, Average query time=%s", + time.Duration(p.PipelineTimings.TotalNS), + p.PipelineTimings.NumCalls, + time.Duration(float64(p.PipelineTimings.TotalNS)/math.Max(1, float64(p.PipelineTimings.NumCalls)))), + } +} + +// Validate checks that the pipeline is valid. func (p *TextClassificationPipeline) Validate() error { var validationErrors []error - if len(p.IdLabelMap) < 1 { - validationErrors = append(validationErrors, fmt.Errorf("only single label classification models are currently supported and more than one label is required")) + if len(p.IDLabelMap) == 0 { + validationErrors = append(validationErrors, fmt.Errorf("pipeline configuration invalid: length of id2label map for text classification pipeline must be greater than zero")) } - if p.OutputDim <= 0 { - validationErrors = append(validationErrors, fmt.Errorf("pipeline configuration invalid: outputDim parameter must be greater than zero")) + + outDims := p.OutputsMeta[0].Dimensions + if len(outDims) != 2 { + validationErrors = append(validationErrors, fmt.Errorf("pipeline configuration invalid: text classification must have a two-dimensional output")) } - if len(p.IdLabelMap) <= 0 { - validationErrors = append(validationErrors, fmt.Errorf("pipeline configuration invalid: length of id2label map for token classification pipeline must be greater than zero")) + dynamicBatch := false + for _, d := range outDims { + if d == -1 { + if dynamicBatch { + validationErrors = append(validationErrors, fmt.Errorf("pipeline configuration invalid: text classification must have at most one dynamic dimension (the batch size)")) + break + } + dynamicBatch = true + } } - if len(p.IdLabelMap) != p.OutputDim { - validationErrors = append(validationErrors, fmt.Errorf("pipeline configuration invalid: length of id2label map does not match model output dimension")) + nLogits := int(outDims[len(outDims)-1]) + if len(p.IDLabelMap) != nLogits { + validationErrors = append(validationErrors, fmt.Errorf("pipeline configuration invalid: length of id2label map does not match number of logits in output (%d)", nLogits)) } return errors.Join(validationErrors...) } -func (p *TextClassificationPipeline) Forward(batch PipelineBatch) (PipelineBatch, error) { +// Preprocess tokenizes the input strings. 
+func (p *TextClassificationPipeline) Preprocess(batch *PipelineBatch, inputs []string) error { start := time.Now() + tokenizeInputs(batch, p.Tokenizer, inputs, p.TokenizerOptions) + atomic.AddUint64(&p.TokenizerTimings.NumCalls, 1) + atomic.AddUint64(&p.TokenizerTimings.TotalNS, uint64(time.Since(start))) + err := createInputTensors(batch, p.InputsMeta) + return err +} - actualBatchSize := int64(len(batch.Input)) - maxSequence := int64(batch.MaxSequence) - inputTensors, err := p.getInputTensors(batch, actualBatchSize, maxSequence) +func (p *TextClassificationPipeline) Forward(batch *PipelineBatch) error { + start := time.Now() + err := runSessionOnBatch(batch, p.OrtSession, p.OutputsMeta) if err != nil { - return batch, err - } - - defer func(inputTensors []ort.ArbitraryTensor) { - for _, tensor := range inputTensors { - err = errors.Join(err, tensor.Destroy()) - } - }(inputTensors) - - outputTensor, errTensor := ort.NewEmptyTensor[float32](ort.NewShape(actualBatchSize, int64(p.OutputDim))) - if errTensor != nil { - return batch, errTensor + return err } - - defer func(outputTensor *ort.Tensor[float32]) { - err = errors.Join(err, outputTensor.Destroy()) - }(outputTensor) - - // Run Onnx model - errOnnx := p.OrtSession.Run(inputTensors, []ort.ArbitraryTensor{outputTensor}) - if errOnnx != nil { - return batch, errOnnx - } - batch.OutputTensor = outputTensor.GetData() - atomic.AddUint64(&p.PipelineTimings.NumCalls, 1) atomic.AddUint64(&p.PipelineTimings.TotalNS, uint64(time.Since(start))) - return batch, err + return nil } -func (p *TextClassificationPipeline) Postprocess(batch PipelineBatch) (*TextClassificationOutput, error) { - outputTensor := batch.OutputTensor +func (p *TextClassificationPipeline) Postprocess(batch *PipelineBatch) (*TextClassificationOutput, error) { + outputTensor := batch.OutputTensors[0] + outputDims := p.OutputsMeta[0].Dimensions + nLogit := outputDims[len(outputDims)-1] output := make([][]float32, len(batch.Input)) inputCounter := 0 vectorCounter := 0 - inputVector := make([]float32, p.OutputDim) + inputVector := make([]float32, nLogit) var aggregationFunction func([]float32) []float32 switch p.AggregationFunctionName { case "SIGMOID": @@ -201,13 +243,12 @@ func (p *TextClassificationPipeline) Postprocess(batch PipelineBatch) (*TextClas return nil, fmt.Errorf("aggregation function %s is not supported", p.AggregationFunctionName) } - for _, result := range outputTensor { + for _, result := range outputTensor.GetData() { inputVector[vectorCounter] = result - if vectorCounter == p.OutputDim-1 { - + if vectorCounter == int(nLogit)-1 { output[inputCounter] = aggregationFunction(inputVector) vectorCounter = 0 - inputVector = make([]float32, p.OutputDim) + inputVector = make([]float32, nLogit) inputCounter++ } else { vectorCounter++ @@ -229,7 +270,7 @@ func (p *TextClassificationPipeline) Postprocess(batch PipelineBatch) (*TextClas err = errArgMax continue } - class, ok := p.IdLabelMap[index] + class, ok := p.IDLabelMap[index] if !ok { err = fmt.Errorf("class with index number %d not found in id label map", index) } @@ -239,9 +280,9 @@ func (p *TextClassificationPipeline) Postprocess(batch PipelineBatch) (*TextClas } batchClassificationOutputs.ClassificationOutputs[i] = inputClassificationOutputs case "multiLabel": - inputClassificationOutputs := make([]ClassificationOutput, len(p.IdLabelMap)) + inputClassificationOutputs := make([]ClassificationOutput, len(p.IDLabelMap)) for j := range output[i] { - class, ok := p.IdLabelMap[j] + class, ok := p.IDLabelMap[j] if !ok { 
err = fmt.Errorf("class with index number %d not found in id label map", j) } @@ -258,16 +299,29 @@ func (p *TextClassificationPipeline) Postprocess(batch PipelineBatch) (*TextClas return &batchClassificationOutputs, err } -// Run the pipeline on a string batch +// Run the pipeline on a string batch. func (p *TextClassificationPipeline) Run(inputs []string) (PipelineBatchOutput, error) { return p.RunPipeline(inputs) } func (p *TextClassificationPipeline) RunPipeline(inputs []string) (*TextClassificationOutput, error) { - batch := p.Preprocess(inputs) - batch, err := p.Forward(batch) - if err != nil { - return nil, err + var runErrors []error + batch := NewBatch() + defer func(*PipelineBatch) { + runErrors = append(runErrors, batch.Destroy()) + }(batch) + + runErrors = append(runErrors, p.Preprocess(batch, inputs)) + if e := errors.Join(runErrors...); e != nil { + return nil, e } - return p.Postprocess(batch) + + runErrors = append(runErrors, p.Forward(batch)) + if e := errors.Join(runErrors...); e != nil { + return nil, e + } + + result, postErr := p.Postprocess(batch) + runErrors = append(runErrors, postErr) + return result, errors.Join(runErrors...) } diff --git a/pipelines/tokenClassification.go b/pipelines/tokenClassification.go index 6372435..713b896 100644 --- a/pipelines/tokenClassification.go +++ b/pipelines/tokenClassification.go @@ -3,10 +3,11 @@ package pipelines import ( "errors" "fmt" + "math" + "slices" "strings" - - // according to https://freshman.tech/snippets/go/check-if-slice-contains-element - "golang.org/x/exp/slices" + "sync/atomic" + "time" ort "github.com/yalue/onnxruntime_go" @@ -16,17 +17,17 @@ import ( "github.com/knights-analytics/tokenizers" ) -// types - +// TokenClassificationPipeline is a go version of huggingface tokenClassificationPipeline. +// https://github.com/huggingface/transformers/blob/main/src/transformers/pipelines/token_classification.py type TokenClassificationPipeline struct { - BasePipeline - IdLabelMap map[int]string + basePipeline + IDLabelMap map[int]string AggregationStrategy string IgnoreLabels []string } type TokenClassificationPipelineConfig struct { - IdLabelMap map[int]string `json:"id2label"` + IDLabelMap map[int]string `json:"id2label"` } type Entity struct { @@ -35,7 +36,7 @@ type Entity struct { Scores []float32 Index int Word string - TokenId uint32 + TokenID uint32 Start uint End uint IsSubword bool } @@ -55,12 +56,17 @@ func (t *TokenClassificationOutput) GetOutput() []any { // options +// TODO: need to implement the other types of aggregation (max etc) + +// WithSimpleAggregation sets the aggregation strategy for the token labels to simple. +// It reproduces simple aggregation from the huggingface implementation. func WithSimpleAggregation() PipelineOption[*TokenClassificationPipeline] { return func(pipeline *TokenClassificationPipeline) { pipeline.AggregationStrategy = "SIMPLE" } } +// WithoutAggregation returns the token labels. func WithoutAggregation() PipelineOption[*TokenClassificationPipeline] { return func(pipeline *TokenClassificationPipeline) { pipeline.AggregationStrategy = "NONE" } } @@ -73,7 +79,7 @@ func WithIgnoreLabels(ignoreLabels []string) PipelineOption[*TokenClassification } } -// NewTokenClassificationPipeline Initializes a feature extraction pipeline +// NewTokenClassificationPipeline initializes a token classification pipeline. 
 func NewTokenClassificationPipeline(config PipelineConfig[*TokenClassificationPipeline], ortOptions *ort.SessionOptions) (*TokenClassificationPipeline, error) {
 	pipeline := &TokenClassificationPipeline{}
 	pipeline.ModelPath = config.ModelPath
@@ -84,7 +90,7 @@ func NewTokenClassificationPipeline(config PipelineConfig[*TokenClassificationPi
 		o(pipeline)
 	}

-	// inputs and encoding options
+	// tokenizer init
 	pipeline.TokenizerOptions = []tokenizers.EncodeOption{
 		tokenizers.WithReturnTokens(),
 		tokenizers.WithReturnTypeIDs(),
@@ -92,8 +98,27 @@ func NewTokenClassificationPipeline(config PipelineConfig[*TokenClassificationPi
 		tokenizers.WithReturnSpecialTokensMask(),
 		tokenizers.WithReturnOffsets(),
 	}
+	tk, err := loadTokenizer(pipeline.ModelPath)
+	if err != nil {
+		return nil, err
+	}
+	pipeline.Tokenizer = tk

-	// load json model config and set pipeline settings
+	// onnx model init
+	model, err := loadOnnxModelBytes(pipeline.ModelPath, pipeline.OnnxFilename)
+	if err != nil {
+		return nil, err
+	}
+
+	// init of inputs and outputs
+	inputs, outputs, err := loadInputOutputMeta(model)
+	if err != nil {
+		return nil, err
+	}
+	pipeline.InputsMeta = inputs
+	pipeline.OutputsMeta = outputs
+
+	// id label map
 	configPath := util.PathJoinSafe(config.ModelPath, "config.json")
 	pipelineInputConfig := TokenClassificationPipelineConfig{}
 	mapBytes, err := util.ReadFileBytes(configPath)
@@ -105,13 +130,9 @@ func NewTokenClassificationPipeline(config PipelineConfig[*TokenClassificationPi
 	if err != nil {
 		return nil, err
 	}
-	pipeline.IdLabelMap = pipelineInputConfig.IdLabelMap
-
-	pipeline.PipelineTimings = &Timings{}
-	pipeline.TokenizerTimings = &Timings{}
-
-	// defaults
+	pipeline.IDLabelMap = pipelineInputConfig.IDLabelMap

+	// default strategies if not set
 	if pipeline.AggregationStrategy == "" {
 		pipeline.AggregationStrategy = "SIMPLE"
 	}
@@ -119,14 +140,15 @@ func NewTokenClassificationPipeline(config PipelineConfig[*TokenClassificationPi
 		pipeline.IgnoreLabels = []string{"O"}
 	}

-	// load onnx model
-	errModel := pipeline.loadModel()
-	if errModel != nil {
-		return nil, errModel
-	}
+	pipeline.PipelineTimings = &timings{}
+	pipeline.TokenizerTimings = &timings{}

-	// the dimension of the output is taken from the output meta.
-	pipeline.OutputDim = int(pipeline.OutputsMeta[0].Dimensions[2])
+	// creation of the onnxruntime session
+	session, err := createSession(model, inputs, outputs, ortOptions)
+	if err != nil {
+		return nil, err
+	}
+	pipeline.OrtSession = session

 	err = pipeline.Validate()
 	if err != nil {
@@ -135,54 +157,108 @@ func NewTokenClassificationPipeline(config PipelineConfig[*TokenClassificationPi
 	return pipeline, nil
 }

+// INTERFACE IMPLEMENTATION
+
+// Destroy frees the token classification pipeline resources.
+func (p *TokenClassificationPipeline) Destroy() error {
+	return destroySession(p.Tokenizer, p.OrtSession)
+}
+
+// GetStats returns the runtime statistics for the pipeline.
+func (p *TokenClassificationPipeline) GetStats() []string {
+	return []string{
+		fmt.Sprintf("Statistics for pipeline: %s", p.PipelineName),
+		fmt.Sprintf("Tokenizer: Total time=%s, Execution count=%d, Average query time=%s",
+			time.Duration(p.TokenizerTimings.TotalNS),
+			p.TokenizerTimings.NumCalls,
+			time.Duration(float64(p.TokenizerTimings.TotalNS)/math.Max(1, float64(p.TokenizerTimings.NumCalls)))),
+		fmt.Sprintf("ONNX: Total time=%s, Execution count=%d, Average query time=%s",
+			time.Duration(p.PipelineTimings.TotalNS),
+			p.PipelineTimings.NumCalls,
+			time.Duration(float64(p.PipelineTimings.TotalNS)/math.Max(1, float64(p.PipelineTimings.NumCalls)))),
+	}
+}
+
+// Validate checks that the pipeline is valid.
 func (p *TokenClassificationPipeline) Validate() error {
 	var validationErrors []error

-	if p.OutputDim <= 0 {
-		validationErrors = append(validationErrors, fmt.Errorf("p configuration invalid: outputDim parameter must be greater than zero"))
+	outputDim := p.OutputsMeta[0].Dimensions
+	if len(outputDim) != 3 {
+		validationErrors = append(validationErrors,
+			fmt.Errorf("output for token classification must be three dimensional (input, sequence, logits)"))
 	}
-	if len(p.IdLabelMap) <= 0 {
-		validationErrors = append(validationErrors, fmt.Errorf("p configuration invalid: length of id2label map for token classification p must be greater than zero"))
+
+	if outputDim[len(outputDim)-1] == -1 {
+		validationErrors = append(validationErrors,
+			fmt.Errorf("logit dimension cannot be dynamic"))
 	}
-	if len(p.IdLabelMap) != p.OutputDim {
-		validationErrors = append(validationErrors, fmt.Errorf("p configuration invalid: length of id2label map does not match model output dimension"))
+	if len(p.IDLabelMap) <= 0 {
+		validationErrors = append(validationErrors, fmt.Errorf("pipeline configuration invalid: length of id2label map for token classification pipeline must be greater than zero"))
 	}
 	return errors.Join(validationErrors...)
 }

-// Postprocess function for a token classification pipeline
-func (p *TokenClassificationPipeline) Postprocess(batch PipelineBatch) (*TokenClassificationOutput, error) {
+// Preprocess tokenizes the input strings.
+func (p *TokenClassificationPipeline) Preprocess(batch *PipelineBatch, inputs []string) error {
+	start := time.Now()
+	tokenizeInputs(batch, p.Tokenizer, inputs, p.TokenizerOptions)
+	atomic.AddUint64(&p.TokenizerTimings.NumCalls, 1)
+	atomic.AddUint64(&p.TokenizerTimings.TotalNS, uint64(time.Since(start)))
+	err := createInputTensors(batch, p.InputsMeta)
+	return err
+}
+
+// Forward performs the forward inference of the pipeline.
+func (p *TokenClassificationPipeline) Forward(batch *PipelineBatch) error {
+	start := time.Now()
+	err := runSessionOnBatch(batch, p.OrtSession, p.OutputsMeta)
+	if err != nil {
+		return err
+	}
+	atomic.AddUint64(&p.PipelineTimings.NumCalls, 1)
+	atomic.AddUint64(&p.PipelineTimings.TotalNS, uint64(time.Since(start)))
+	return nil
+}
+
+// Postprocess function for a token classification pipeline.
+func (p *TokenClassificationPipeline) Postprocess(batch *PipelineBatch) (*TokenClassificationOutput, error) {
+	if len(batch.Input) == 0 {
+		return &TokenClassificationOutput{}, nil
+	}

-	outputs := make([][][]float32, len(batch.Input))        // holds the final output
-	inputVectors := make([][]float32, 0, batch.MaxSequence) // holds the embeddings of each original token (no padding) for an input
-	tokenVector := make([]float32, p.OutputDim)             // holds the vector embedding for a token
-	inputTokens := batch.Input[0].TokenIds
+	outputDims := p.OutputsMeta[0].Dimensions
+	tokenLogitsDim := int(outputDims[len(outputDims)-1])
+	outputs := make([][][]float32, len(batch.Input))              // holds the final output
+	inputVectors := make([][]float32, 0, batch.MaxSequenceLength) // holds the embeddings of each original token (no padding) for an input
+	tokenVector := make([]float32, tokenLogitsDim)                // holds the vector embedding for a token
+	inputTokens := batch.Input[0].TokenIDs                        // original tokens from the input excluding the padded tokens
 	tokenVectorCounter := 0
 	tokenCounter := 0
 	inputCounter := 0
 	nInputs := len(batch.Input)

-	// construct the output vectors, however discard the embeddings of the padding tokens so that the output vector length
+	// construct the output vectors by gathering the logits,
+	// however discard the embeddings of the padding tokens so that the output vector length
 	// for an input is equal to the number of original tokens
-
-	for _, result := range batch.OutputTensor {
+	for _, result := range batch.OutputTensors[0].GetData() {
 		tokenVector[tokenVectorCounter] = result
-		if tokenVectorCounter == p.OutputDim-1 {
+		if tokenVectorCounter == tokenLogitsDim-1 {
 			// raw result vector for token is now complete
 			if tokenCounter < len(inputTokens) {
 				// it is an original token (not resulting from padding), keep it
 				inputVectors = append(inputVectors, util.SoftMax(tokenVector))
 			}
 			tokenVectorCounter = 0
-			tokenVector = make([]float32, p.OutputDim)
-			if tokenCounter == batch.MaxSequence-1 {
+			tokenVector = make([]float32, tokenLogitsDim)
+			if tokenCounter == batch.MaxSequenceLength-1 {
 				// we went through all tokens in the sequence for this input
 				outputs[inputCounter] = inputVectors
 				tokenCounter = 0
-				inputVectors = make([][]float32, 0, batch.MaxSequence)
+				inputVectors = make([][]float32, 0, batch.MaxSequenceLength)
 				inputCounter++
 				if inputCounter < nInputs {
-					inputTokens = batch.Input[inputCounter].TokenIds
+					inputTokens = batch.Input[inputCounter].TokenIDs
 				}
 			} else {
 				tokenCounter++
@@ -216,8 +292,7 @@ func (p *TokenClassificationPipeline) Postprocess(batch PipelineBatch) (*TokenCl
 }

 // GatherPreEntities from batch of logits to list of pre-aggregated outputs
-func (p *TokenClassificationPipeline) GatherPreEntities(input TokenizedInput, output [][]float32) []Entity {
-
+func (p *TokenClassificationPipeline) GatherPreEntities(input tokenizedInput, output [][]float32) []Entity {
 	sentence := input.Raw
 	var preEntities []Entity

@@ -229,7 +304,7 @@ func (p *TokenClassificationPipeline) GatherPreEntities(input TokenizedInput, ou
 		}
 		// TODO: the python code uses id_to_token to get the token here which is a method on the rust tokenizer, check if it's better
 		word := input.Tokens[j]
-		tokenId := input.TokenIds[j]
+		tokenID := input.TokenIDs[j]
 		// TODO: the determination of subword can probably be better done by exporting the words field from the tokenizer directly
 		startInd := input.Offsets[j][0]
 		endInd := input.Offsets[j][1]
@@ -239,7 +314,7 @@ func (p *TokenClassificationPipeline) GatherPreEntities(input TokenizedInput, ou
 		// in that case set the subword as in the python code
 		preEntities = append(preEntities, Entity{
 			Word:      word,
-			TokenId:   tokenId,
+			TokenID:   tokenID,
 			Scores:    tokenScores,
 			Start:     startInd,
 			End:       endInd,
@@ -250,7 +325,7 @@ func (p *TokenClassificationPipeline) GatherPreEntities(input TokenizedInput, ou
 	return preEntities
 }

-func (p *TokenClassificationPipeline) Aggregate(input TokenizedInput, preEntities []Entity) ([]Entity, error) {
+func (p *TokenClassificationPipeline) Aggregate(input tokenizedInput, preEntities []Entity) ([]Entity, error) {
 	entities := make([]Entity, len(preEntities))
 	if p.AggregationStrategy == "SIMPLE" || p.AggregationStrategy == "NONE" {
 		for i, preEntity := range preEntities {
@@ -258,7 +333,7 @@ func (p *TokenClassificationPipeline) Aggregate(input TokenizedInput, preEntitie
 			if argMaxErr != nil {
 				return nil, argMaxErr
 			}
-			label, ok := p.IdLabelMap[entityIdx]
+			label, ok := p.IDLabelMap[entityIdx]
 			if !ok {
 				return nil, fmt.Errorf("could not determine entity type for input %s, predicted entity index %d", input.Raw, entityIdx)
 			}
@@ -267,7 +342,7 @@ func (p *TokenClassificationPipeline) Aggregate(input TokenizedInput, preEntitie
 				Score:   score,
 				Index:   preEntity.Index,
 				Word:    preEntity.Word,
-				TokenId: preEntity.TokenId,
+				TokenID: preEntity.TokenID,
 				Start:   preEntity.Start,
 				End:     preEntity.End,
 			}
@@ -310,7 +385,7 @@ func (p *TokenClassificationPipeline) groupSubEntities(entities []Entity) Entity
 	tokens := make([]uint32, len(entities))
 	for i, s := range entities {
 		scores[i] = s.Score
-		tokens[i] = s.TokenId
+		tokens[i] = s.TokenID
 	}
 	score := util.Mean(scores)
 	// note: here we directly appeal to the tokenizer decoder with the tokenIds
@@ -326,7 +401,7 @@ func (p *TokenClassificationPipeline) groupSubEntities(entities []Entity) Entity
 	}
 }

-// GroupEntities group together adjacent tokens with the same entity predicted
+// GroupEntities groups together adjacent tokens with the same predicted entity.
 func (p *TokenClassificationPipeline) GroupEntities(entities []Entity) ([]Entity, error) {
 	var entityGroups []Entity
 	var currentGroupDisagg []Entity
@@ -355,16 +430,30 @@ func (p *TokenClassificationPipeline) GroupEntities(entities []Entity) ([]Entity
 	return entityGroups, nil
 }

-// Run the pipeline on a string batch
+// Run the pipeline on a string batch.
 func (p *TokenClassificationPipeline) Run(inputs []string) (PipelineBatchOutput, error) {
 	return p.RunPipeline(inputs)
 }

+// RunPipeline is like Run but returns the concrete type rather than the interface.
 func (p *TokenClassificationPipeline) RunPipeline(inputs []string) (*TokenClassificationOutput, error) {
-	batch := p.Preprocess(inputs)
-	batch, errForward := p.Forward(batch)
-	if errForward != nil {
-		return nil, errForward
+	var runErrors []error
+	batch := NewBatch()
+	defer func(*PipelineBatch) {
+		runErrors = append(runErrors, batch.Destroy())
+	}(batch)
+
+	runErrors = append(runErrors, p.Preprocess(batch, inputs))
+	if e := errors.Join(runErrors...); e != nil {
+		return nil, e
 	}
-	return p.Postprocess(batch)
+
+	runErrors = append(runErrors, p.Forward(batch))
+	if e := errors.Join(runErrors...); e != nil {
+		return nil, e
+	}
+
+	result, postErr := p.Postprocess(batch)
+	runErrors = append(runErrors, postErr)
+	return result, errors.Join(runErrors...)
 }
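
Both Postprocess implementations in this patch walk the flattened ONNX output with running counters rather than reshaping it up front. For intuition, the standalone sketch below regroups a flat (batch, sequence, logits) tensor into per-input, per-token vectors using index arithmetic. The function reshapeLogits is hypothetical and not part of the patch; unlike the pipeline code, it also keeps the padding tokens.

package main

import "fmt"

// reshapeLogits regroups a flattened (batch, sequence, logits) tensor into
// per-input, per-token logit vectors, mirroring what the counter-based loop
// in Postprocess computes. The returned sub-slices alias the flat slice.
func reshapeLogits(flat []float32, nInputs, maxSeq, nLogits int) [][][]float32 {
	out := make([][][]float32, nInputs)
	for i := 0; i < nInputs; i++ {
		out[i] = make([][]float32, maxSeq)
		for t := 0; t < maxSeq; t++ {
			start := (i*maxSeq + t) * nLogits
			out[i][t] = flat[start : start+nLogits]
		}
	}
	return out
}

func main() {
	// 2 inputs, 2 tokens each, 3 logits per token.
	flat := []float32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}
	fmt.Println(reshapeLogits(flat, 2, 2, 3))
	// [[[1 2 3] [4 5 6]] [[7 8 9] [10 11 12]]]
}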
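
The text classification Postprocess selects an aggregation function by AggregationFunctionName: sigmoid for multi-label outputs and softmax for single-label outputs. The patch takes these functions from the util package; the stand-ins below are only meant to illustrate the math, not the actual util implementations.

package main

import (
	"fmt"
	"math"
)

// softMax converts logits to probabilities that sum to 1 (single-label case).
func softMax(v []float32) []float32 {
	out := make([]float32, len(v))
	var sum float64
	for i, x := range v {
		e := math.Exp(float64(x))
		out[i] = float32(e)
		sum += e
	}
	for i := range out {
		out[i] = float32(float64(out[i]) / sum)
	}
	return out
}

// sigmoid maps each logit independently to (0, 1) (multi-label case).
func sigmoid(v []float32) []float32 {
	out := make([]float32, len(v))
	for i, x := range v {
		out[i] = float32(1.0 / (1.0 + math.Exp(-float64(x))))
	}
	return out
}

func main() {
	logits := []float32{1.0, 2.0, -0.5}
	fmt.Println(softMax(logits)) // probabilities over classes, summing to 1
	fmt.Println(sigmoid(logits)) // independent per-class probabilities
}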
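
The average query times reported by GetStats divide the total time by math.Max(1, NumCalls), so a pipeline that has never run reports a zero average rather than dividing by zero. A minimal illustration of that computation, with a hypothetical helper averageDuration:

package main

import (
	"fmt"
	"math"
	"time"
)

// averageDuration mirrors the guard used in GetStats: when numCalls is
// zero, math.Max(1, ...) keeps the division defined and the average at 0.
func averageDuration(totalNS, numCalls uint64) time.Duration {
	return time.Duration(float64(totalNS) / math.Max(1, float64(numCalls)))
}

func main() {
	fmt.Println(averageDuration(3_000_000, 3)) // 1ms
	fmt.Println(averageDuration(0, 0))         // 0s, no division by zero
}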