Skip to content

Commit

Permalink
fix yolo train detail
Browse files Browse the repository at this point in the history
  • Loading branch information
jianzfb committed Dec 26, 2024
1 parent eec7c99 commit d0743e4
Showing 1 changed file with 9 additions and 1 deletion.
10 changes: 9 additions & 1 deletion antgo/tools/third_part_model_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@
from antgo.framework.helper.utils.config import Config
import zlib
import subprocess
import shutil
import json
import yaml
import os


Expand Down Expand Up @@ -75,13 +77,19 @@ def yolo_model_train(exp_name, cfg, root, gpu_id, pretrained_model=None):

# 启动训练
data_path = data.get('path', None)
with open(data_path, 'r') as fp:
data_info = yaml.safe_load(fp)
data_info['path'] = os.path.dirname(data_path)
with open('./data.yaml', 'w') as fp:
yaml.safe_dump(data_info, fp)

data_imgsz = data.get('imgsz', 640)
batch_size = data.get('batch_size', 32)
workers = data.get('workers', 1)

device = [int(k) for k in gpu_id.split(',')]
results = model.train(
data=data_path,
data='./data.yaml',
epochs=cfg.get('max_epochs', 100),
imgsz=data_imgsz,
device=device,
Expand Down

0 comments on commit d0743e4

Please sign in to comment.