Skip to content

Commit

Permalink
add mlflow code
Browse files Browse the repository at this point in the history
  • Loading branch information
Your Name committed Oct 9, 2021
1 parent b3e97fe commit fb3d318
Show file tree
Hide file tree
Showing 2 changed files with 183 additions and 0 deletions.
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,5 @@ timm
nbnb
mmpycocotools
omegaconf
mlflow

182 changes: 182 additions & 0 deletions train_tl_mlfow.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,182 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.

"""
Training script for a custom COCO-format dataset, with MLflow tracking.

To adapt it to your own data, change the image directory and annotation
paths defined below, and define your own categories.
"""

import os
from datetime import timedelta
from detectron2.checkpoint import DetectionCheckpointer
from detectron2.config import get_cfg
from detectron2.engine import (
DefaultTrainer,
default_argument_parser,
default_setup,
launch,
)
from detectron2.evaluation import COCOEvaluator
from detectron2.data import (
MetadataCatalog,
build_detection_train_loader,
DatasetCatalog,
)
from detectron2.data.datasets.coco import load_coco_json
from detectron2.data.dataset_mapper import DatasetMapper
from detectron2.data.datasets.coco import load_coco_json, register_coco_instances
from detectron2.utils import comm
from detectron2.engine import hooks, HookBase

from yolov7.config import add_yolo_config
from yolov7.data.dataset_mapper import MyDatasetMapper2

import mlflow

# Dataset layout: every path below is derived from DATASET_ROOT.
DATASET_ROOT = "./datasets/tl"
ANN_ROOT = os.path.join(DATASET_ROOT, "annotations")

# Both splits read their images from the same directory; only the
# annotation files differ.
_IMAGE_DIR = os.path.join(DATASET_ROOT, "JPEGImages")
TRAIN_PATH = _IMAGE_DIR
VAL_PATH = _IMAGE_DIR

TRAIN_JSON = os.path.join(ANN_ROOT, "annotations_coco_tls_train.json")
VAL_JSON = os.path.join(ANN_ROOT, "annotations_coco_tls_val.json")

# Register the two splits at import time under the names "tl_train" and
# "tl_val", with empty extra metadata.
register_coco_instances(
    name="tl_train",
    metadata={},
    json_file=TRAIN_JSON,
    image_root=TRAIN_PATH,
)
register_coco_instances(
    name="tl_val",
    metadata={},
    json_file=VAL_JSON,
    image_root=VAL_PATH,
)


class MLFlowSnapshotHook(HookBase):
    """
    Upload the trained checkpoints to MLflow once training is over.

    After training, the hook logs ``model_final.pth`` from ``cfg.OUTPUT_DIR``
    under the ``model`` artifact path. It then renames the periodic
    checkpoint ``model_<best_iter, zero-padded to 7 digits>.pth`` to
    ``model_best.pth`` and logs that file under the same artifact path.

    The hook reads ``trainer.cfg.OUTPUT_DIR`` and ``trainer.best_iter``.
    A checkpoint that is not present on disk is skipped with a warning
    instead of raising.
    """

    def after_train(self):
        """Log the final and best checkpoints as MLflow artifacts."""
        import logging

        # The periodic checkpointer is only installed on the main process,
        # so only that process has files to rename and upload. Running this
        # on every rank would race on the rename and duplicate the uploads.
        if not comm.is_main_process():
            return

        logger = logging.getLogger(__name__)
        output_dir = self.trainer.cfg.OUTPUT_DIR

        final_model_path = os.path.join(output_dir, "model_final.pth")
        if os.path.isfile(final_model_path):
            mlflow.log_artifact(final_model_path, "model")
        else:
            # Can happen when training did not run to completion.
            logger.warning(
                "Final checkpoint %s not found; skipping upload.",
                final_model_path,
            )

        best_iter = getattr(self.trainer, "best_iter", None)
        if best_iter is None:
            logger.warning("Trainer has no best_iter; skipping best model.")
            return

        # Periodic checkpoints are named with a 7-digit zero-padded iteration.
        best_model_path = os.path.join(output_dir, f"model_{best_iter:07d}.pth")
        if not os.path.isfile(best_model_path):
            # No checkpoint matches best_iter, e.g. no evaluation ever
            # improved on the initial best_ap, or the evaluation and
            # checkpoint periods do not line up.
            logger.warning(
                "Best checkpoint %s not found; skipping upload.",
                best_model_path,
            )
            return

        new_path = os.path.join(output_dir, "model_best.pth")
        # os.replace overwrites a stale model_best.pth on every platform.
        os.replace(best_model_path, new_path)
        mlflow.log_artifact(new_path, "model")


class Trainer(DefaultTrainer):
    """
    ``DefaultTrainer`` variant that trains with ``MyDatasetMapper2``,
    evaluates with ``COCOEvaluator`` and reports bbox metrics to MLflow.

    ``best_ap`` and ``best_iter`` track the highest bbox AP seen during
    evaluation and the iteration at which it was observed.
    """

    # Class-level defaults so evaluation works even when the caller does not
    # seed these attributes on the instance.
    best_ap = 0
    best_iter = 0

    @classmethod
    def build_evaluator(cls, cfg, dataset_name, output_folder=None):
        """
        Return a ``COCOEvaluator`` for ``dataset_name``.

        When ``output_folder`` is None, results go to
        ``<cfg.OUTPUT_DIR>/inference``.
        """
        if output_folder is None:
            output_folder = os.path.join(cfg.OUTPUT_DIR, "inference")
        return COCOEvaluator(dataset_name, output_dir=output_folder)

    @classmethod
    def build_train_loader(cls, cfg):
        """Build the training loader using the project's own dataset mapper."""
        # MyDatasetMapper2 replaces the stock DatasetMapper to add more
        # augmentations.
        return build_detection_train_loader(cfg, mapper=MyDatasetMapper2(cfg, True))

    def build_hooks(self):
        """
        Build the hook list: timing, LR scheduling, MLflow checkpoint upload,
        periodic checkpointing, evaluation and event writing.

        Returns:
            list[HookBase]:
        """
        cfg = self.cfg

        ret = [
            hooks.IterationTimer(),
            hooks.LRScheduler(),
            MLFlowSnapshotHook(),
        ]

        if comm.is_main_process():
            ret.append(
                hooks.PeriodicCheckpointer(
                    self.checkpointer, cfg.SOLVER.CHECKPOINT_PERIOD
                )
            )

        def test_and_save_results():
            self._last_eval_results = self.test(self.cfg, self.model)
            self._log_eval_results(self._last_eval_results)
            return self._last_eval_results

        # Do evaluation after checkpointer, because then if it fails,
        # we can use the saved checkpoint to debug.
        ret.append(hooks.EvalHook(cfg.TEST.EVAL_PERIOD, test_and_save_results))

        if comm.is_main_process():
            # Run writers last so that evaluation metrics are written too.
            ret.append(hooks.PeriodicWriter(self.build_writers(), period=20))
        return ret

    def _log_eval_results(self, eval_results):
        """Send bbox metrics to MLflow and update ``best_ap``/``best_iter``."""
        # Log from the main process only, so that each metric is recorded
        # once per evaluation rather than once per rank.
        if not comm.is_main_process():
            return

        # Results may be empty or lack a "bbox" entry; indexing blindly
        # would raise KeyError and abort training.
        bbox_results = eval_results.get("bbox") if eval_results else None
        if not bbox_results:
            return

        for metric_name, metric_value in bbox_results.items():
            mlflow.log_metric(metric_name, metric_value, step=self.iter)

        current_ap = bbox_results.get("AP")
        if current_ap is not None and current_ap > self.best_ap:
            self.best_ap = current_ap
            self.best_iter = self.iter
            mlflow.log_metric("best_AP", self.best_ap, step=self.iter)


def setup(args):
    """
    Build the frozen config for this run and perform the basic setup.

    The config starts from ``get_cfg()`` extended by ``add_yolo_config``.
    Values from ``args.config_file`` are merged first, then the
    ``args.opts`` list. The frozen config is passed to ``default_setup``
    together with ``args`` and then returned.
    """
    config = get_cfg()
    add_yolo_config(config)

    # Merge order matters: config file first, then the opts list.
    config.merge_from_file(args.config_file)
    config.merge_from_list(args.opts)

    config.freeze()
    default_setup(config, args)
    return config


def _train_or_eval(cfg, args):
    """Run evaluation only (``args.eval_only``) or a full training run."""
    if args.eval_only:
        model = Trainer.build_model(cfg)
        DetectionCheckpointer(model, save_dir=cfg.OUTPUT_DIR).resume_or_load(
            cfg.MODEL.WEIGHTS, resume=args.resume
        )
        return Trainer.test(cfg, model)

    trainer = Trainer(cfg)
    # Seed the best-AP tracking read by the evaluation and MLflow hooks.
    trainer.best_ap = 0
    trainer.best_iter = 0
    trainer.resume_or_load(resume=args.resume)
    return trainer.train()


def main(args):
    """
    Entry point executed in each launched process.

    Builds the config, then either evaluates ``cfg.MODEL.WEIGHTS`` (when
    ``args.eval_only`` is set) or trains. On the main process the work runs
    inside an MLflow run named ``yolox_s_tl`` in the ``traffic_light``
    experiment, with ``max_iter`` and ``images_per_batch`` logged as
    parameters.

    Returns:
        The evaluation results in eval-only mode, otherwise whatever
        ``Trainer.train()`` returns.
    """
    cfg = setup(args)

    # Start the MLflow run on the main process only; opening one per rank
    # would create duplicate runs for a single training job.
    if not comm.is_main_process():
        return _train_or_eval(cfg, args)

    mlflow.set_experiment("traffic_light")
    # The context manager ends the run when the block exits, including when
    # training raises, so the run is not left open.
    with mlflow.start_run(run_name="yolox_s_tl"):
        mlflow.log_param("max_iter", cfg.SOLVER.MAX_ITER)
        mlflow.log_param("images_per_batch", cfg.SOLVER.IMS_PER_BATCH)
        return _train_or_eval(cfg, args)


if __name__ == "__main__":
    # Parse the standard detectron2 command-line arguments and echo them.
    args = default_argument_parser().parse_args()
    print("Command Line Args:", args)

    launch_options = dict(
        num_machines=args.num_machines,
        machine_rank=args.machine_rank,
        dist_url=args.dist_url,
        # NOTE(review): timedelta's first positional argument is days, so
        # the original timedelta(50) is a 50-day timeout. Confirm that this
        # is intended rather than 50 minutes.
        timeout=timedelta(days=50),
        args=(args,),
    )
    launch(main, args.num_gpus, **launch_options)

0 comments on commit fb3d318

Please sign in to comment.