diff --git a/rdb/attrs.go b/rdb/attrs.go index 8e8b74e9..24139039 100644 --- a/rdb/attrs.go +++ b/rdb/attrs.go @@ -8,8 +8,6 @@ import ( // See https://github.com/c-bata/goptuna/issues/34 // for the reason why we need following code. -// Caution "_number" in trial_system_attributes must not be encoded. - func encodeAttrValue(xr string) string { return fmt.Sprintf("\"%s\"", base64.StdEncoding.EncodeToString([]byte(xr))) diff --git a/rdb/converter.go b/rdb/converter.go index 57e9b058..fac264f5 100644 --- a/rdb/converter.go +++ b/rdb/converter.go @@ -2,8 +2,6 @@ package rdb import ( "errors" - "fmt" - "strconv" "time" "github.com/c-bata/goptuna" @@ -21,13 +19,9 @@ func toFrozenTrial(trial trialModel) (goptuna.FrozenTrial, error) { systemAttrs := make(map[string]string, len(trial.SystemAttributes)) for i := range trial.SystemAttributes { - if trial.SystemAttributes[i].Key == keyNumber { - systemAttrs[trial.SystemAttributes[i].Key] = trial.SystemAttributes[i].ValueJSON - } else { - systemAttrs[trial.SystemAttributes[i].Key], err = decodeAttrValue(trial.SystemAttributes[i].ValueJSON) - if err != nil { - return goptuna.FrozenTrial{}, err - } + systemAttrs[trial.SystemAttributes[i].Key], err = decodeAttrValue(trial.SystemAttributes[i].ValueJSON) + if err != nil { + return goptuna.FrozenTrial{}, err } } @@ -52,15 +46,6 @@ func toFrozenTrial(trial trialModel) (goptuna.FrozenTrial, error) { } } - numberStr, ok := systemAttrs[keyNumber] - if !ok { - return goptuna.FrozenTrial{}, errors.New("number is not exist in system attrs") - } - number, err := strconv.Atoi(numberStr) - if err != nil { - return goptuna.FrozenTrial{}, fmt.Errorf("invalid trial number: %s", err) - } - state, err := toStateExternalRepresentation(trial.State) if err != nil { return goptuna.FrozenTrial{}, err @@ -82,7 +67,7 @@ func toFrozenTrial(trial trialModel) (goptuna.FrozenTrial, error) { return goptuna.FrozenTrial{ ID: trial.ID, StudyID: trial.TrialReferStudy, - Number: number, + Number: trial.Number, State: state, Value: trial.Value, IntermediateValues: intermediateValue, diff --git a/rdb/model.go b/rdb/model.go index 8e1a7275..64a34c30 100644 --- a/rdb/model.go +++ b/rdb/model.go @@ -61,6 +61,7 @@ func (m studySystemAttributeModel) TableName() string { type trialModel struct { ID int `gorm:"column:trial_id;PRIMARY_KEY"` + Number int `gorm:"column:number"` TrialReferStudy int `gorm:"column:study_id"` State string `gorm:"column:state;NOT NULL"` Value float64 `gorm:"column:value"` diff --git a/rdb/storage.go b/rdb/storage.go index a6949628..0f64f96b 100644 --- a/rdb/storage.go +++ b/rdb/storage.go @@ -2,7 +2,6 @@ package rdb import ( "fmt" - "strconv" "time" "github.com/c-bata/goptuna" @@ -12,8 +11,6 @@ import ( var _ goptuna.Storage = &Storage{} -const keyNumber = "_number" - // NewStorage returns new RDB storage. func NewStorage(db *gorm.DB) *Storage { return &Storage{ @@ -234,12 +231,9 @@ func (s *Storage) CreateNewTrial(studyID int) (int, error) { return -1, err } - // Set '_number' in trial_system_attributes. - err = tx.Create(&trialSystemAttributeModel{ - SystemAttributeReferTrial: trial.ID, - Key: keyNumber, - ValueJSON: strconv.Itoa(number), - }).Error + err = tx.Model(&trialModel{}). + Where("trial_id = ?", trial.ID). + Update("number", number).Error if err != nil { tx.Rollback() return -1, err @@ -313,9 +307,6 @@ func (s *Storage) CloneTrial(studyID int, baseTrial goptuna.FrozenTrial) (int, e // system attrs for key := range baseTrial.SystemAttrs { - if key == "_number" { - continue - } err := tx.Create(&trialSystemAttributeModel{ SystemAttributeReferTrial: trial.ID, Key: key, @@ -364,11 +355,10 @@ func (s *Storage) CloneTrial(studyID int, baseTrial goptuna.FrozenTrial) (int, e tx.Rollback() return -1, err } - err = tx.Create(&trialSystemAttributeModel{ - SystemAttributeReferTrial: trial.ID, - Key: keyNumber, - ValueJSON: strconv.Itoa(number), - }).Error + + err = tx.Model(&trialModel{}). + Where("trial_id = ?", trial.ID). + Update("number", number).Error if err != nil { tx.Rollback() return -1, err @@ -579,13 +569,11 @@ func (s *Storage) SetTrialSystemAttr(trialID int, key string, value string) erro // GetTrialNumberFromID returns the trial's number. func (s *Storage) GetTrialNumberFromID(trialID int) (int, error) { - var attr trialSystemAttributeModel - err := s.db.First(&attr, "trial_id = ? AND key = ?", trialID, keyNumber).Error + trial, err := s.GetTrial(trialID) if err != nil { return -1, err } - number, err := strconv.Atoi(attr.ValueJSON) - return number, err + return trial.Number, err } // GetTrialParam returns the internal parameter of the trial