Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 8 additions & 5 deletions .mockery.yml
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
dir: "{{.PackageName}}/testing"
filename: "{{.InterfaceName}}_mock_gen.go"
mockname: "Mock{{.InterfaceName}}"
outpkg: "{{.PackageName}}testing"
replace-type:
- "log=github.com/symflower/eval-dev-quality/log"
with-expecter: false
dir: '{{.PackageName}}/testing'
filename: '{{.InterfaceName}}_mock_gen.go'
mockname: 'Mock{{.InterfaceName}}'
outpkg: '{{.PackageName}}testing'

packages:
github.com/symflower/eval-dev-quality/language:
interfaces:
Expand All @@ -12,6 +15,6 @@ packages:
Model:
github.com/symflower/eval-dev-quality/provider:
interfaces:
Loader:
Provider:
Query:
replace-type: 'log=github.com/symflower/eval-dev-quality/log'
7 changes: 5 additions & 2 deletions cmd/eval-dev-quality/cmd/evaluate.go
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,7 @@ func (command *Evaluate) Execute(args []string) (err error) {

// Gather models.
modelsSelected := map[string]model.Model{}
providerForModel := map[model.Model]provider.Provider{}
{
models := map[string]model.Model{}
for _, p := range provider.Providers {
Expand Down Expand Up @@ -204,6 +205,7 @@ func (command *Evaluate) Execute(args []string) (err error) {

for _, m := range ms {
models[m.ID()] = m
providerForModel[m] = p
}
}
modelIDs := maps.Keys(models)
Expand Down Expand Up @@ -249,8 +251,9 @@ func (command *Evaluate) Execute(args []string) (err error) {

Languages: ls,

Models: ms,
QueryAttempts: command.QueryAttempts,
Models: ms,
ProviderForModel: providerForModel,
QueryAttempts: command.QueryAttempts,

RepositoryPaths: command.Repositories,
ResultPath: command.ResultPath,
Expand Down
2 changes: 2 additions & 0 deletions cmd/eval-dev-quality/cmd/evaluate_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -517,6 +517,8 @@ func TestEvaluateExecute(t *testing.T) {
// Since the model is non-deterministic, we can only assert that the model did at least not error.
assert.Contains(t, data, `Evaluation score for "ollama/qwen:0.5b"`)
assert.Contains(t, data, "response-no-error=1")
assert.Contains(t, data, "preloading model")
assert.Contains(t, data, "unloading model")
},
"golang-summed.csv": nil,
"models-summed.csv": nil,
Expand Down
96 changes: 60 additions & 36 deletions evaluate/evaluate.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"github.com/symflower/eval-dev-quality/evaluate/report"
evallanguage "github.com/symflower/eval-dev-quality/language"
evalmodel "github.com/symflower/eval-dev-quality/model"
"github.com/symflower/eval-dev-quality/provider"
)

// Context holds an evaluation context.
Expand All @@ -20,6 +21,8 @@ type Context struct {

// Models determines which models should be used for the evaluation, or empty if all models should be used.
Models []evalmodel.Model
// ProviderForModel holds the models and their associated provider.
ProviderForModel map[evalmodel.Model]provider.Provider
// QueryAttempts holds the number of query attempts to perform when a model request errors in the process of solving a task.
QueryAttempts uint

Expand Down Expand Up @@ -101,27 +104,29 @@ func Evaluate(ctx *Context) (assessments report.AssessmentPerModelPerLanguagePer
r.SetQueryAttempts(ctx.QueryAttempts)
}

for rm := uint(0); rm < ctx.runsAtModelLevel(); rm++ {
if err := ResetTemporaryRepository(ctx.Log, temporaryRepositoryPath); err != nil {
ctx.Log.Panicf("ERROR: unable to reset temporary repository path: %s", err)
withLoadedModel(ctx.Log, model, ctx.ProviderForModel[model], func() {
for rm := uint(0); rm < ctx.runsAtModelLevel(); rm++ {
if err := ResetTemporaryRepository(ctx.Log, temporaryRepositoryPath); err != nil {
ctx.Log.Panicf("ERROR: unable to reset temporary repository path: %s", err)
}

if ctx.Runs > 1 && ctx.RunsSequential {
ctx.Log.Printf("Run %d/%d for model %q", rm+1, ctx.Runs, modelID)
}

assessment, ps, err := Repository(ctx.Log, ctx.ResultPath, model, language, temporaryRepositoryPath, repositoryPath)
assessments[model][language][repositoryPath].Add(assessment)
if err != nil {
ps = append(ps, err)
}
if len(ps) > 0 {
ctx.Log.Printf("Model %q was not able to solve the %q repository for language %q: %+v", modelID, repositoryPath, languageID, ps)
problemsPerModel[modelID] = append(problemsPerModel[modelID], ps...)
} else {
modelSucceededBasicChecksOfLanguage[model][language] = true
}
}

if ctx.Runs > 1 && ctx.RunsSequential {
ctx.Log.Printf("Run %d/%d for model %q", rm+1, ctx.Runs, modelID)
}

assessment, ps, err := Repository(ctx.Log, ctx.ResultPath, model, language, temporaryRepositoryPath, repositoryPath)
assessments[model][language][repositoryPath].Add(assessment)
if err != nil {
ps = append(ps, err)
}
if len(ps) > 0 {
ctx.Log.Printf("Model %q was not able to solve the %q repository for language %q: %+v", modelID, repositoryPath, languageID, ps)
problemsPerModel[modelID] = append(problemsPerModel[modelID], ps...)
} else {
modelSucceededBasicChecksOfLanguage[model][language] = true
}
}
})
}
}
}
Expand Down Expand Up @@ -189,23 +194,24 @@ func Evaluate(ctx *Context) (assessments report.AssessmentPerModelPerLanguagePer

continue
}

for rm := uint(0); rm < ctx.runsAtModelLevel(); rm++ {
if ctx.Runs > 1 && ctx.RunsSequential {
ctx.Log.Printf("Run %d/%d for model %q", rm+1, ctx.Runs, modelID)
}

if err := ResetTemporaryRepository(ctx.Log, temporaryRepositoryPath); err != nil {
ctx.Log.Panicf("ERROR: unable to reset temporary repository path: %s", err)
}

assessment, ps, err := Repository(ctx.Log, ctx.ResultPath, model, language, temporaryRepositoryPath, repositoryPath)
assessments[model][language][repositoryPath].Add(assessment)
problemsPerModel[modelID] = append(problemsPerModel[modelID], ps...)
if err != nil {
ctx.Log.Printf("ERROR: Model %q encountered a hard error for language %q, repository %q: %+v", modelID, languageID, repositoryPath, err)
withLoadedModel(ctx.Log, model, ctx.ProviderForModel[model], func() {
for rm := uint(0); rm < ctx.runsAtModelLevel(); rm++ {
if ctx.Runs > 1 && ctx.RunsSequential {
ctx.Log.Printf("Run %d/%d for model %q", rm+1, ctx.Runs, modelID)
}

if err := ResetTemporaryRepository(ctx.Log, temporaryRepositoryPath); err != nil {
ctx.Log.Panicf("ERROR: unable to reset temporary repository path: %s", err)
}

assessment, ps, err := Repository(ctx.Log, ctx.ResultPath, model, language, temporaryRepositoryPath, repositoryPath)
assessments[model][language][repositoryPath].Add(assessment)
problemsPerModel[modelID] = append(problemsPerModel[modelID], ps...)
if err != nil {
ctx.Log.Printf("ERROR: Model %q encountered a hard error for language %q, repository %q: %+v", modelID, languageID, repositoryPath, err)
}
}
}
})
}
}
}
Expand All @@ -226,3 +232,21 @@ func Evaluate(ctx *Context) (assessments report.AssessmentPerModelPerLanguagePer

return assessments, totalScore
}

// withLoadedModel loads the model for the duration of the given task if supported by the model's provider.
func withLoadedModel(log *log.Logger, model evalmodel.Model, modelProvider provider.Provider, task func()) {
if loader, ok := modelProvider.(provider.Loader); ok {
log.Printf("preloading model %q", model.ID())
if err := loader.Load(model.ID()); err != nil {
log.Panicf("ERROR: could not load model %q with provider %q", model.ID(), modelProvider.ID())
}
defer func() {
log.Printf("unloading model %q", model.ID())
if err := loader.Unload(model.ID()); err != nil {
log.Panicf("ERROR: could not unload model %q with provider %q", model.ID(), modelProvider.ID())
}
}()
}

task()
}
149 changes: 149 additions & 0 deletions evaluate/evaluate_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,11 @@ import (
"github.com/symflower/eval-dev-quality/language"
"github.com/symflower/eval-dev-quality/language/golang"
"github.com/symflower/eval-dev-quality/log"
"github.com/symflower/eval-dev-quality/model"
evalmodel "github.com/symflower/eval-dev-quality/model"
"github.com/symflower/eval-dev-quality/model/llm"
modeltesting "github.com/symflower/eval-dev-quality/model/testing"
"github.com/symflower/eval-dev-quality/provider"
providertesting "github.com/symflower/eval-dev-quality/provider/testing"
)

Expand Down Expand Up @@ -620,4 +622,151 @@ func TestEvaluate(t *testing.T) {
})
}
})

t.Run("Preloading", func(t *testing.T) {
generateTestsForFilePlainSuccess := func(args mock.Arguments) {
require.NoError(t, os.WriteFile(filepath.Join(args.String(2), "plain_test.go"), []byte("package plain\nimport \"testing\"\nfunc TestFunction(t *testing.T){}"), 0600))
}
generateTestsForFilePlainSuccessMetrics := metrics.Assessments{
metrics.AssessmentKeyProcessingTime: 1,
}
generateSuccess := func(mockedModel *modeltesting.MockModel) {
mockedModel.On("GenerateTestsForFile", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(generateTestsForFilePlainSuccessMetrics, nil).Run(generateTestsForFilePlainSuccess)
}

{
// Setup provider and model mocking.
languageGolang := &golang.Language{}
mockedModelID := "testing-provider/testing-model"
mockedModel := modeltesting.NewMockModelNamed(t, mockedModelID)
mockedProviderID := "testing-provider"
mockedProvider := providertesting.NewMockProviderNamedWithModels(t, mockedProviderID, []model.Model{mockedModel})
mockedLoader := providertesting.NewMockLoader(t)
embeddedProvider := &struct {
provider.Provider
provider.Loader
}{
Provider: mockedProvider,
Loader: mockedLoader,
}
repositoryPath := filepath.Join("golang", "plain")

validate(t, &testCase{
Name: "Once for combined runs",

Before: func(t *testing.T, logger *log.Logger, resultPath string) {
generateSuccess(mockedModel)
mockedLoader.On("Load", mockedModelID).Return(nil)
mockedLoader.On("Unload", mockedModelID).Return(nil)
},
After: func(t *testing.T, logger *log.Logger, resultPath string) {
delete(provider.Providers, mockedProviderID)

mockedLoader.AssertNumberOfCalls(t, "Load", 1)
mockedLoader.AssertNumberOfCalls(t, "Unload", 1)
},

Context: &Context{
Languages: []language.Language{
languageGolang,
},

Models: []evalmodel.Model{
mockedModel,
},
ProviderForModel: map[evalmodel.Model]provider.Provider{
mockedModel: embeddedProvider,
},

RepositoryPaths: []string{
repositoryPath,
},

Runs: 3,
RunsSequential: true,
},

ExpectedAssessments: map[evalmodel.Model]map[language.Language]map[string]metrics.Assessments{
mockedModel: map[language.Language]map[string]metrics.Assessments{
languageGolang: map[string]metrics.Assessments{
repositoryPath: map[metrics.AssessmentKey]uint64{
metrics.AssessmentKeyFilesExecuted: 3,
metrics.AssessmentKeyResponseNoError: 3,
},
},
},
},
ExpectedTotalScore: 3,
ExpectedResultFiles: map[string]func(t *testing.T, filePath string, data string){
filepath.Join(evalmodel.CleanModelNameForFileSystem(mockedModelID), "golang", "golang", "plain.log"): nil,
},
})
}
{
// Setup provider and model mocking.
languageGolang := &golang.Language{}
mockedModelID := "testing-provider/testing-model"
mockedModel := modeltesting.NewMockModelNamed(t, mockedModelID)
mockedProviderID := "testing-provider"
mockedProvider := providertesting.NewMockProviderNamedWithModels(t, mockedProviderID, []model.Model{mockedModel})
mockedLoader := providertesting.NewMockLoader(t)
embeddedProvider := &struct {
provider.Provider
provider.Loader
}{
Provider: mockedProvider,
Loader: mockedLoader,
}
repositoryPath := filepath.Join("golang", "plain")
validate(t, &testCase{
Name: "Multiple times for interleaved runs",

Before: func(t *testing.T, logger *log.Logger, resultPath string) {
generateSuccess(mockedModel)
mockedLoader.On("Load", mockedModelID).Return(nil)
mockedLoader.On("Unload", mockedModelID).Return(nil)
},
After: func(t *testing.T, logger *log.Logger, resultPath string) {
delete(provider.Providers, "testing-provider")

mockedLoader.AssertNumberOfCalls(t, "Load", 3)
mockedLoader.AssertNumberOfCalls(t, "Unload", 3)
},

Context: &Context{
Languages: []language.Language{
languageGolang,
},

Models: []evalmodel.Model{
mockedModel,
},
ProviderForModel: map[evalmodel.Model]provider.Provider{
mockedModel: embeddedProvider,
},

RepositoryPaths: []string{
repositoryPath,
},

Runs: 3,
},

ExpectedAssessments: map[evalmodel.Model]map[language.Language]map[string]metrics.Assessments{
mockedModel: map[language.Language]map[string]metrics.Assessments{
languageGolang: map[string]metrics.Assessments{
repositoryPath: map[metrics.AssessmentKey]uint64{
metrics.AssessmentKeyFilesExecuted: 3,
metrics.AssessmentKeyResponseNoError: 3,
},
},
},
},
ExpectedTotalScore: 3,
ExpectedResultFiles: map[string]func(t *testing.T, filePath string, data string){
filepath.Join(evalmodel.CleanModelNameForFileSystem(mockedModelID), "golang", "golang", "plain.log"): nil,
},
})
}
})
}
14 changes: 14 additions & 0 deletions provider/ollama/ollama.go
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,21 @@ func (p *Provider) client() (client *openai.Client) {
return openai.NewClientWithConfig(config)
}

var _ provider.Service = (*Provider)(nil)

// Start starts necessary background services to use this provider and returns a shutdown function.
func (p *Provider) Start(logger *log.Logger) (shutdown func() (err error), err error) {
return tools.OllamaStart(logger, p.binaryPath, p.url)
}

var _ provider.Loader = (*Provider)(nil)

// Load loads the given model.
func (p *Provider) Load(modelIdentifier string) error {
return tools.OllamaLoad(p.url, strings.TrimPrefix(modelIdentifier, p.ID()+provider.ProviderModelSeparator))
}

// Unload unloads the given model.
func (p *Provider) Unload(modelIdentifier string) error {
return tools.OllamaUnload(p.url, strings.TrimPrefix(modelIdentifier, p.ID()+provider.ProviderModelSeparator))
}
8 changes: 8 additions & 0 deletions provider/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,3 +53,11 @@ type Service interface {
// Start starts necessary background services to use this provider and returns a shutdown function.
Start(logger *log.Logger) (shutdown func() (err error), err error)
}

// Loader is a provider that is able to load and unload models.
type Loader interface {
// Load loads the given model.
Load(modelIdentifier string) error
// Unload unloads the given model.
Unload(modelIdentifier string) error
}
Loading