Skip to content

Commit

Permalink
change: refactor of pipelines to avoid struct method inheritance, sim…
Browse files Browse the repository at this point in the history
…iplification of pipeline and batch structs, and support for output selection for feature extraction.
  • Loading branch information
riccardopinosio committed Jul 18, 2024
1 parent fd59728 commit 143dd73
Show file tree
Hide file tree
Showing 7 changed files with 869 additions and 549 deletions.
1 change: 0 additions & 1 deletion cmd/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -296,7 +296,6 @@ func main() {
}

func writeOutputs(wg *sync.WaitGroup, processedChannel chan []byte, errorChannel chan error, writeTarget io.WriteCloser) {

for processedChannel != nil || errorChannel != nil {
select {
case output, ok := <-processedChannel:
Expand Down
23 changes: 12 additions & 11 deletions hugot.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,21 +38,21 @@ func (m pipelineMap[T]) GetStats() []string {
return stats
}

// TokenClassificationConfig is the configuration for a token classification pipeline
type TokenClassificationConfig = pipelines.PipelineConfig[*pipelines.TokenClassificationPipeline]

// TextClassificationConfig is the configuration for a text classification pipeline
type TextClassificationConfig = pipelines.PipelineConfig[*pipelines.TextClassificationPipeline]

// FeatureExtractionConfig is the configuration for a feature extraction pipeline
type FeatureExtractionConfig = pipelines.PipelineConfig[*pipelines.FeatureExtractionPipeline]

// TokenClassificationOption is an option for a token classification pipeline
type TokenClassificationOption = pipelines.PipelineOption[*pipelines.TokenClassificationPipeline]
// TextClassificationConfig is the configuration for a text classification pipeline
type TextClassificationConfig = pipelines.PipelineConfig[*pipelines.TextClassificationPipeline]

// TextClassificationOption is an option for a text classification pipeline
type TextClassificationOption = pipelines.PipelineOption[*pipelines.TextClassificationPipeline]

// TokenClassificationConfig is the configuration for a token classification pipeline
type TokenClassificationConfig = pipelines.PipelineConfig[*pipelines.TokenClassificationPipeline]

// // TokenClassificationOption is an option for a token classification pipeline
type TokenClassificationOption = pipelines.PipelineOption[*pipelines.TokenClassificationPipeline]

// FeatureExtractionOption is an option for a feature extraction pipeline
type FeatureExtractionOption = pipelines.PipelineOption[*pipelines.FeatureExtractionPipeline]

Expand All @@ -70,8 +70,8 @@ func NewSession(options ...WithOption) (*Session, error) {

session := &Session{
featureExtractionPipelines: map[string]*pipelines.FeatureExtractionPipeline{},
tokenClassificationPipelines: map[string]*pipelines.TokenClassificationPipeline{},
textClassificationPipelines: map[string]*pipelines.TextClassificationPipeline{},
tokenClassificationPipelines: map[string]*pipelines.TokenClassificationPipeline{},
}

// set session options and initialise
Expand Down Expand Up @@ -286,7 +286,7 @@ func GetPipeline[T pipelines.Pipeline](s *Session, name string) (T, error) {
func (s *Session) Destroy() error {
return errors.Join(
s.featureExtractionPipelines.Destroy(),
s.tokenClassificationPipelines.Destroy(),
// s.tokenClassificationPipelines.Destroy(),
s.textClassificationPipelines.Destroy(),
s.ortOptions.Destroy(),
ort.DestroyEnvironment(),
Expand All @@ -302,7 +302,8 @@ func (s *Session) Destroy() error {
// the average time per onnxruntime inference batch call
func (s *Session) GetStats() []string {
// slices.Concat() is not implemented in experimental x/exp/slices package
return append(append(s.tokenClassificationPipelines.GetStats(),
return append(append(
s.tokenClassificationPipelines.GetStats(),
s.textClassificationPipelines.GetStats()...),
s.featureExtractionPipelines.GetStats()...,
)
Expand Down
Loading

0 comments on commit 143dd73

Please sign in to comment.