Skip to content

Commit

Permalink
Merge pull request #88 from c-bata/add-number-field2
Browse files Browse the repository at this point in the history
Add number field in trials table.
  • Loading branch information
c-bata authored Mar 11, 2020
2 parents d9758c3 + d6b1c67 commit 92ae970
Show file tree
Hide file tree
Showing 4 changed files with 14 additions and 42 deletions.
2 changes: 0 additions & 2 deletions rdb/attrs.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,6 @@ import (
// See https://github.com/c-bata/goptuna/issues/34
// for the reason why we need following code.

// Caution "_number" in trial_system_attributes must not be encoded.

func encodeAttrValue(xr string) string {
return fmt.Sprintf("\"%s\"",
base64.StdEncoding.EncodeToString([]byte(xr)))
Expand Down
23 changes: 4 additions & 19 deletions rdb/converter.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@ package rdb

import (
"errors"
"fmt"
"strconv"
"time"

"github.com/c-bata/goptuna"
Expand All @@ -21,13 +19,9 @@ func toFrozenTrial(trial trialModel) (goptuna.FrozenTrial, error) {

systemAttrs := make(map[string]string, len(trial.SystemAttributes))
for i := range trial.SystemAttributes {
if trial.SystemAttributes[i].Key == keyNumber {
systemAttrs[trial.SystemAttributes[i].Key] = trial.SystemAttributes[i].ValueJSON
} else {
systemAttrs[trial.SystemAttributes[i].Key], err = decodeAttrValue(trial.SystemAttributes[i].ValueJSON)
if err != nil {
return goptuna.FrozenTrial{}, err
}
systemAttrs[trial.SystemAttributes[i].Key], err = decodeAttrValue(trial.SystemAttributes[i].ValueJSON)
if err != nil {
return goptuna.FrozenTrial{}, err
}
}

Expand All @@ -52,15 +46,6 @@ func toFrozenTrial(trial trialModel) (goptuna.FrozenTrial, error) {
}
}

numberStr, ok := systemAttrs[keyNumber]
if !ok {
return goptuna.FrozenTrial{}, errors.New("number is not exist in system attrs")
}
number, err := strconv.Atoi(numberStr)
if err != nil {
return goptuna.FrozenTrial{}, fmt.Errorf("invalid trial number: %s", err)
}

state, err := toStateExternalRepresentation(trial.State)
if err != nil {
return goptuna.FrozenTrial{}, err
Expand All @@ -82,7 +67,7 @@ func toFrozenTrial(trial trialModel) (goptuna.FrozenTrial, error) {
return goptuna.FrozenTrial{
ID: trial.ID,
StudyID: trial.TrialReferStudy,
Number: number,
Number: trial.Number,
State: state,
Value: trial.Value,
IntermediateValues: intermediateValue,
Expand Down
1 change: 1 addition & 0 deletions rdb/model.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ func (m studySystemAttributeModel) TableName() string {

type trialModel struct {
ID int `gorm:"column:trial_id;PRIMARY_KEY"`
Number int `gorm:"column:number"`
TrialReferStudy int `gorm:"column:study_id"`
State string `gorm:"column:state;NOT NULL"`
Value float64 `gorm:"column:value"`
Expand Down
30 changes: 9 additions & 21 deletions rdb/storage.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package rdb

import (
"fmt"
"strconv"
"time"

"github.com/c-bata/goptuna"
Expand All @@ -12,8 +11,6 @@ import (

var _ goptuna.Storage = &Storage{}

const keyNumber = "_number"

// NewStorage returns new RDB storage.
func NewStorage(db *gorm.DB) *Storage {
return &Storage{
Expand Down Expand Up @@ -234,12 +231,9 @@ func (s *Storage) CreateNewTrial(studyID int) (int, error) {
return -1, err
}

// Set '_number' in trial_system_attributes.
err = tx.Create(&trialSystemAttributeModel{
SystemAttributeReferTrial: trial.ID,
Key: keyNumber,
ValueJSON: strconv.Itoa(number),
}).Error
err = tx.Model(&trialModel{}).
Where("trial_id = ?", trial.ID).
Update("number", number).Error
if err != nil {
tx.Rollback()
return -1, err
Expand Down Expand Up @@ -313,9 +307,6 @@ func (s *Storage) CloneTrial(studyID int, baseTrial goptuna.FrozenTrial) (int, e

// system attrs
for key := range baseTrial.SystemAttrs {
if key == "_number" {
continue
}
err := tx.Create(&trialSystemAttributeModel{
SystemAttributeReferTrial: trial.ID,
Key: key,
Expand Down Expand Up @@ -364,11 +355,10 @@ func (s *Storage) CloneTrial(studyID int, baseTrial goptuna.FrozenTrial) (int, e
tx.Rollback()
return -1, err
}
err = tx.Create(&trialSystemAttributeModel{
SystemAttributeReferTrial: trial.ID,
Key: keyNumber,
ValueJSON: strconv.Itoa(number),
}).Error

err = tx.Model(&trialModel{}).
Where("trial_id = ?", trial.ID).
Update("number", number).Error
if err != nil {
tx.Rollback()
return -1, err
Expand Down Expand Up @@ -579,13 +569,11 @@ func (s *Storage) SetTrialSystemAttr(trialID int, key string, value string) erro

// GetTrialNumberFromID returns the trial's number.
func (s *Storage) GetTrialNumberFromID(trialID int) (int, error) {
var attr trialSystemAttributeModel
err := s.db.First(&attr, "trial_id = ? AND key = ?", trialID, keyNumber).Error
trial, err := s.GetTrial(trialID)
if err != nil {
return -1, err
}
number, err := strconv.Atoi(attr.ValueJSON)
return number, err
return trial.Number, err
}

// GetTrialParam returns the internal parameter of the trial
Expand Down

0 comments on commit 92ae970

Please sign in to comment.