W&B sweeps support (ultralytics#3938)
* Add support for W&B Sweeps

* Update and reformat

* Update search space

* reformat

* reformat sweep.py

* Update sweep.py

* Move sweeps files to wandb dir

* Remove print

Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
AyushExel and glenn-jocher authored Jul 14, 2021
1 parent e914a8a commit 7ac8372
Showing 3 changed files with 177 additions and 1 deletion.
33 changes: 33 additions & 0 deletions utils/wandb_logging/sweep.py
@@ -0,0 +1,33 @@
import sys
from pathlib import Path
import wandb

FILE = Path(__file__).absolute()
sys.path.append(FILE.parents[2].as_posix())  # add repo root to path (FILE.parents[2] is the repository root, making train.py and utils/ importable)

from train import train, parse_opt
import test
from utils.general import increment_path
from utils.torch_utils import select_device


def sweep():
    wandb.init()
    # Get hyp dict from sweep agent
    hyp_dict = vars(wandb.config).get("_items")

    # Workaround: get necessary opt args
    opt = parse_opt(known=True)
    opt.batch_size = hyp_dict.get("batch_size")
    opt.save_dir = str(increment_path(Path(opt.project) / opt.name, exist_ok=opt.exist_ok or opt.evolve))
    opt.epochs = hyp_dict.get("epochs")
    opt.nosave = True
    opt.data = hyp_dict.get("data")
    device = select_device(opt.device, batch_size=opt.batch_size)

    # train
    train(hyp_dict, opt, device)


if __name__ == "__main__":
    sweep()
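
For context, a minimal sketch of launching this sweep programmatically instead of through the wandb CLI. It assumes the wandb and PyYAML packages are installed and the script is run from the repository root; the project name and trial count are illustrative and not part of this commit:

# Hypothetical launcher sketch, not part of this commit.
import yaml
import wandb

from utils.wandb_logging.sweep import sweep

if __name__ == "__main__":
    with open("utils/wandb_logging/sweep.yaml") as f:
        sweep_config = yaml.safe_load(f)  # parse the search space defined in sweep.yaml below
    sweep_id = wandb.sweep(sweep_config, project="yolov5-sweeps")  # register the sweep; project name is an assumption
    wandb.agent(sweep_id, function=sweep, count=10)  # run 10 trials in this process; count is illustrative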
143 changes: 143 additions & 0 deletions utils/wandb_logging/sweep.yaml
@@ -0,0 +1,143 @@
# Hyperparameters for training
# To set a range, provide min and max values:
#   parameter:
#     min: scalar
#     max: scalar
#
# To set a specific list of values to search over:
#   parameter:
#     values: [scalar1, scalar2, scalar3...]
#
# Supported search strategies are grid, random and Bayesian (bayes)
# For more info on configuring sweeps visit https://docs.wandb.ai/guides/sweeps/configuration

program: utils/wandb_logging/sweep.py
method: random
metric:
  name: metrics/mAP_0.5
  goal: maximize

parameters:
  # hyperparameters: set either min, max range or values list
  data:
    value: "data/coco128.yaml"
  batch_size:
    values: [64]
  epochs:
    values: [10]

  lr0:
    distribution: uniform
    min: 1e-5
    max: 1e-1
  lrf:
    distribution: uniform
    min: 0.01
    max: 1.0
  momentum:
    distribution: uniform
    min: 0.6
    max: 0.98
  weight_decay:
    distribution: uniform
    min: 0.0
    max: 0.001
  warmup_epochs:
    distribution: uniform
    min: 0.0
    max: 5.0
  warmup_momentum:
    distribution: uniform
    min: 0.0
    max: 0.95
  warmup_bias_lr:
    distribution: uniform
    min: 0.0
    max: 0.2
  box:
    distribution: uniform
    min: 0.02
    max: 0.2
  cls:
    distribution: uniform
    min: 0.2
    max: 4.0
  cls_pw:
    distribution: uniform
    min: 0.5
    max: 2.0
  obj:
    distribution: uniform
    min: 0.2
    max: 4.0
  obj_pw:
    distribution: uniform
    min: 0.5
    max: 2.0
  iou_t:
    distribution: uniform
    min: 0.1
    max: 0.7
  anchor_t:
    distribution: uniform
    min: 2.0
    max: 8.0
  fl_gamma:
    distribution: uniform
    min: 0.0
    max: 0.1
  hsv_h:
    distribution: uniform
    min: 0.0
    max: 0.1
  hsv_s:
    distribution: uniform
    min: 0.0
    max: 0.9
  hsv_v:
    distribution: uniform
    min: 0.0
    max: 0.9
  degrees:
    distribution: uniform
    min: 0.0
    max: 45.0
  translate:
    distribution: uniform
    min: 0.0
    max: 0.9
  scale:
    distribution: uniform
    min: 0.0
    max: 0.9
  shear:
    distribution: uniform
    min: 0.0
    max: 10.0
  perspective:
    distribution: uniform
    min: 0.0
    max: 0.001
  flipud:
    distribution: uniform
    min: 0.0
    max: 1.0
  fliplr:
    distribution: uniform
    min: 0.0
    max: 1.0
  mosaic:
    distribution: uniform
    min: 0.0
    max: 1.0
  mixup:
    distribution: uniform
    min: 0.0
    max: 1.0
  copy_paste:
    distribution: uniform
    min: 0.0
    max: 1.0
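
With the config above, the usual W&B workflow is to register the sweep and then start one or more agents; since program points at sweep.py, each agent runs that script per trial. SWEEP_ID below is a placeholder for the ID printed by the first command:

wandb sweep utils/wandb_logging/sweep.yaml   # registers the sweep and prints its ID
wandb agent SWEEP_ID                         # starts an agent that executes sweep.py for each trial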
2 changes: 1 addition & 1 deletion utils/wandb_logging/wandb_utils.py
@@ -153,7 +153,7 @@ def setup_training(self, opt, data_dict):
     self.weights = Path(modeldir) / "last.pt"
     config = self.wandb_run.config
     opt.weights, opt.save_period, opt.batch_size, opt.bbox_interval, opt.epochs, opt.hyp = str(
-        self.weights), config.save_period, config.total_batch_size, config.bbox_interval, config.epochs, \
+        self.weights), config.save_period, config.batch_size, config.bbox_interval, config.epochs, \
         config.opt['hyp']
     data_dict = dict(self.wandb_run.config.data_dict)  # eliminates the need for config file to resume
     if 'val_artifact' not in self.__dict__:  # If --upload_dataset is set, use the existing artifact, don't download
