Skip to content

Commit

Permalink
Update train(hyp, *args) to accept hyp file or dict (#3668)
Browse files Browse the repository at this point in the history
  • Loading branch information
glenn-jocher authored Jun 17, 2021
1 parent 6d6e2ca commit fa201f9
Showing 1 changed file with 10 additions and 7 deletions.
17 changes: 10 additions & 7 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,12 +39,11 @@
logger = logging.getLogger(__name__)


def train(hyp,
def train(hyp, # path/to/hyp.yaml or hyp dictionary
opt,
device,
tb_writer=None
):
logger.info(colorstr('hyperparameters: ') + ', '.join(f'{k}={v}' for k, v in hyp.items()))
save_dir, epochs, batch_size, total_batch_size, weights, rank, single_cls = \
Path(opt.save_dir), opt.epochs, opt.batch_size, opt.total_batch_size, opt.weights, opt.global_rank, \
opt.single_cls
Expand All @@ -56,6 +55,12 @@ def train(hyp,
best = wdir / 'best.pt'
results_file = save_dir / 'results.txt'

# Hyperparameters
if isinstance(hyp, str):
with open(hyp) as f:
hyp = yaml.safe_load(f) # load hyps dict
logger.info(colorstr('hyperparameters: ') + ', '.join(f'{k}={v}' for k, v in hyp.items()))

# Save run settings
with open(save_dir / 'hyp.yaml', 'w') as f:
yaml.safe_dump(hyp, f, sort_keys=False)
Expand Down Expand Up @@ -529,10 +534,6 @@ def train(hyp,
assert not opt.image_weights, '--image-weights argument is not compatible with DDP training'
opt.batch_size = opt.total_batch_size // opt.world_size

# Hyperparameters
with open(opt.hyp) as f:
hyp = yaml.safe_load(f) # load hyps

# Train
logger.info(opt)
if not opt.evolve:
Expand All @@ -541,7 +542,7 @@ def train(hyp,
prefix = colorstr('tensorboard: ')
logger.info(f"{prefix}Start with 'tensorboard --logdir {opt.project}', view at http://localhost:6006/")
tb_writer = SummaryWriter(opt.save_dir) # Tensorboard
train(hyp, opt, device, tb_writer)
train(opt.hyp, opt, device, tb_writer)

# Evolve hyperparameters (optional)
else:
Expand Down Expand Up @@ -575,6 +576,8 @@ def train(hyp,
'mosaic': (1, 0.0, 1.0), # image mixup (probability)
'mixup': (1, 0.0, 1.0)} # image mixup (probability)

with open(opt.hyp) as f:
hyp = yaml.safe_load(f) # load hyps dict
assert opt.local_rank == -1, 'DDP mode not implemented for --evolve'
opt.notest, opt.nosave = True, True # only test/save final epoch
# ei = [isinstance(x, (int, float)) for x in hyp.values()] # evolvable indices
Expand Down

0 comments on commit fa201f9

Please sign in to comment.