Skip to content

Commit

Permalink
Merge pull request #105 from c-bata/blackhole-storage
Browse files Browse the repository at this point in the history
Add BlackHoleStorage towards 100k+ evaluations
  • Loading branch information
c-bata authored Apr 13, 2020
2 parents 04f48e3 + 5f4901e commit a9d1276
Show file tree
Hide file tree
Showing 7 changed files with 662 additions and 4 deletions.
2 changes: 2 additions & 0 deletions .github/workflows/run-examples.yml
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ jobs:
GO111MODULE: on
run: |
make build
./bin/cmaes
./bin/cmaes_blackhole
./bin/simple_tpe
./bin/concurrency
./bin/trialnotify
5 changes: 5 additions & 0 deletions _examples/build.sh
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,11 @@ DIR=$(cd $(dirname $0); pwd)
BIN_DIR=$(cd $(dirname $(dirname $0)); pwd)/bin

mkdir -p ${BIN_DIR}

set -ex

go build -o ${BIN_DIR}/cmaes ${DIR}/cmaes/main.go
go build -o ${BIN_DIR}/cmaes_blackhole ${DIR}/cmaes/blackhole/main.go
go build -o ${BIN_DIR}/concurrency ${DIR}/concurrency/main.go
go build -o ${BIN_DIR}/enqueue_trial ${DIR}/enqueue_trial/main.go
go build -o ${BIN_DIR}/trialnotify ${DIR}/trialnotify/main.go
Expand Down
59 changes: 59 additions & 0 deletions _examples/cmaes/blackhole/main.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
package main

import (
"log"
"math"

"github.com/c-bata/goptuna"
"github.com/c-bata/goptuna/cmaes"
)

func objective(trial goptuna.Trial) (float64, error) {
x1, err := trial.SuggestFloat("x1", -10, 10)
if err != nil {
return -1, err
}
x2, err := trial.SuggestFloat("x2", -10, 10)
if err != nil {
return -1, err
}
return math.Pow(x1-2, 2) + math.Pow(x2+5, 2), nil
}

func main() {
relativeSampler := cmaes.NewSampler(
cmaes.SamplerOptionNStartupTrials(0))
study, err := goptuna.CreateStudy(
"goptuna-example",
goptuna.StudyOptionStorage(goptuna.NewBlackHoleStorage(20)),
goptuna.StudyOptionRelativeSampler(relativeSampler),
goptuna.StudyOptionDefineSearchSpace(map[string]interface{}{
"x1": goptuna.UniformDistribution{
High: 10,
Low: -10,
},
"x2": goptuna.UniformDistribution{
High: 10,
Low: -10,
},
}),
)
if err != nil {
log.Fatal("failed to create study:", err)
}

if err = study.Optimize(objective, 10000); err != nil {
log.Fatal("failed to optimize:", err)
}

v, err := study.GetBestValue()
if err != nil {
log.Fatal("failed to get best value:", err)
}
params, err := study.GetBestParams()
if err != nil {
log.Fatal("failed to get best params:", err)
}
log.Printf("Best evaluation=%f (x1=%f, x2=%f)",
v, params["x1"].(float64), params["x2"].(float64))
}
15 changes: 13 additions & 2 deletions cmaes/sampler.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ func (s *Sampler) SampleRelative(
sort.Strings(orderedKeys)

trials, err := study.GetTrials()
if err != nil {
if err != nil && err != goptuna.ErrTrialsPartiallyDeleted {
return nil, err
}
completed := make([]goptuna.FrozenTrial, 0, len(trials))
Expand All @@ -57,8 +57,19 @@ func (s *Sampler) SampleRelative(
}
}
if len(completed) < s.nStartUpTrials {
return nil, nil
// If catch ErrTrialsPartiallyDeleted, nStartUpTrials should be smaller than len(completed).
study.GetLogger().Error("Your BlackHoleStorage buffer is too small.",
fmt.Sprintf("nStartUpTrials:%d", s.nStartUpTrials))
return nil, err
}
if err == goptuna.ErrTrialsPartiallyDeleted && s.optimizer != nil &&
len(completed) < s.optimizer.PopulationSize() {
// If catch ErrTrialsPartiallyDeleted, population size should be smaller than len(completed).
study.GetLogger().Error("Your BlackHoleStorage buffer is too small.",
fmt.Sprintf("popsize:%d", s.optimizer.PopulationSize()))
return nil, err
}
err = nil

if s.optimizer == nil {
s.optimizer, err = s.initOptimizer(searchSpace, orderedKeys)
Expand Down
6 changes: 5 additions & 1 deletion sampler.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,11 @@ func IntersectionSearchSpace(study *Study) (map[string]interface{}, error) {
var searchSpace map[string]interface{}

trials, err := study.GetTrials()
if err != nil {
if err == ErrTrialsPartiallyDeleted {
study.logger.Warn("Some trials are not used to calculate intersection of search spaces." +
" Please use `goptuna.StudyOptionDefineSearchSpace` option.")
err = nil
} else if err != nil {
return nil, err
}

Expand Down
Loading

0 comments on commit a9d1276

Please sign in to comment.