Skip to content

Commit

Permalink
support train.json output (#712)
Browse files Browse the repository at this point in the history
  • Loading branch information
liuhongen1234567 authored Nov 22, 2024
1 parent 510fa7c commit 2a55a96
Show file tree
Hide file tree
Showing 2 changed files with 138 additions and 4 deletions.
119 changes: 119 additions & 0 deletions paddlevideo/tasks/save_result.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import json
import yaml
import paddle

from paddlevideo.utils import get_logger

logger = get_logger("paddlevideo")


def save_predict_result(save_path, result):
    """Serialize ``result`` to ``save_path`` as a JSON file.

    Args:
        save_path (str): Destination path. May be given without an
            extension (any trailing ``/`` is stripped and a ``.json``
            suffix appended) or already ending in ``.json``.
        result: Any ``json``-serializable object (e.g. predictions dict).

    Raises:
        Exception: If ``save_path`` carries an extension other than
            ``.json``.
    """
    ext = os.path.splitext(save_path)[-1]
    if ext == "":
        # Bare path, possibly ending in one or more "/" (rstrip also
        # protects against an IndexError on an empty string).
        save_path = save_path.rstrip("/") + ".json"
    elif ext != ".json":
        raise Exception(
            f"{save_path} is invalid input path, only files in json format are supported."
        )

    if os.path.exists(save_path):
        logger.warning(f"The file {save_path} will be overwritten.")
    with open(save_path, "w", encoding="utf-8") as f:
        json.dump(result, f)


def update_train_results(
    config, prefix, metric_info, done_flag=False, last_num=5, ema=False
):
    """Create or refresh ``train_result.json`` in ``config["output_dir"]``.

    The file tracks a rolling window of the ``last_num`` most recent
    checkpoints plus a ``best`` entry; each entry maps artifact tags
    (weights, optimizer state, inference files, ...) to paths relative to
    the output directory.

    Args:
        config (dict): Training config. Reads ``output_dir``, ``Global``
            (for ``pdx_model_name``) and, if present, ``Infer`` (for
            ``PostProcess.class_id_map_file``).
        prefix (str): Checkpoint name, e.g. ``epoch_3`` or ``best_model``.
            ``best_model`` updates the ``best`` slot; any other value
            shifts the ``last_i`` history down by one and fills ``last_1``.
        metric_info (dict): Must contain key ``metric`` (the eval score).
        done_flag (bool): Recorded as-is; truthy when training finished.
        last_num (int): How many recent checkpoints to keep (>= 1).
        ema (bool): If truthy, also track a ``pdema`` weight file.
    """

    # Only the rank-0 worker writes the result file in distributed runs.
    if paddle.distributed.get_rank() != 0:
        return

    assert last_num >= 1
    train_results_path = os.path.join(config["output_dir"], "train_result.json")
    # Artifact tags recorded for training checkpoints vs. exported inference.
    save_model_tag = ["pdparams", "pdopt", "pdstates"]
    save_inference_tag = ["inference_config", "pdmodel", "pdiparams", "pdiparams.info"]
    if ema:
        save_model_tag.append("pdema")
    # Resume from an existing result file so earlier last_i entries survive;
    # otherwise build the skeleton from scratch.
    if os.path.exists(train_results_path):
        with open(train_results_path, "r") as fp:
            train_results = json.load(fp)
    else:
        train_results = {}
        train_results["model_name"] = config["Global"].get("pdx_model_name", None)
        if config.get("Infer", None):
            train_results["label_dict"] = config["Infer"]["PostProcess"].get(
                "class_id_map_file", ""
            )
        else:
            train_results["label_dict"] = ""
        train_results["train_log"] = "train.log"
        train_results["visualdl_log"] = ""
        train_results["config"] = "config.yaml"
        train_results["models"] = {}
        for i in range(1, last_num + 1):
            train_results["models"][f"last_{i}"] = {}
        train_results["models"]["best"] = {}
    train_results["done_flag"] = done_flag
    if prefix == "best_model":
        train_results["models"]["best"]["score"] = metric_info["metric"]
        for tag in save_model_tag:
            train_results["models"]["best"][tag] = os.path.join(
                prefix, f"{prefix}.{tag}"
            )
        for tag in save_inference_tag:
            # Exported inference artifacts live under <prefix>/inference/;
            # the config file is conventionally named inference.yml.
            train_results["models"]["best"][tag] = os.path.join(
                prefix,
                "inference",
                f"inference.{tag}" if tag != "inference_config" else "inference.yml",
            )
    else:
        # Shift history: last_(i) -> last_(i+1), oldest entry dropped,
        # iterating from the end so nothing is overwritten prematurely.
        for i in range(last_num - 1, 0, -1):
            train_results["models"][f"last_{i + 1}"] = train_results["models"][
                f"last_{i}"
            ].copy()
        train_results["models"][f"last_{1}"]["score"] = metric_info["metric"]
        for tag in save_model_tag:
            train_results["models"][f"last_{1}"][tag] = os.path.join(
                prefix, f"{prefix}.{tag}"
            )
        for tag in save_inference_tag:
            train_results["models"][f"last_{1}"][tag] = os.path.join(
                prefix,
                "inference",
                f"inference.{tag}" if tag != "inference_config" else "inference.yml",
            )

    with open(train_results_path, "w") as fp:
        json.dump(train_results, fp)


def save_model_info(model_info, save_path, prefix):
    """Dump ``model_info`` as JSON to ``<save_path>/<prefix>/<prefix>.info.json``.

    Only the rank-0 worker writes in distributed runs; other ranks return
    immediately.

    Args:
        model_info (dict): JSON-serializable metadata (e.g. metric/epoch).
        save_path (str): Base output directory.
        prefix (str): Sub-directory name and file-name stem.
    """
    if paddle.distributed.get_rank() != 0:
        return
    save_path = os.path.join(save_path, prefix)
    # exist_ok avoids the race between a separate exists() check and
    # makedirs() when several processes share the output directory.
    os.makedirs(save_path, exist_ok=True)
    with open(os.path.join(save_path, f"{prefix}.info.json"), "w") as f:
        json.dump(model_info, f)
    logger.info("Already save model info in {}".format(save_path))
23 changes: 19 additions & 4 deletions paddlevideo/tasks/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
from ..utils import do_preciseBN
from tools.export_model import get_input_spec
from .download import get_weights_path_from_url
from .save_result import update_train_results, save_model_info


def _mkdir_if_not_exist(path, logger):
Expand Down Expand Up @@ -436,8 +437,8 @@ def evaluate(best):
):
best = record_list[top_flag].avg
best_flag = True

return best, best_flag
acc = record_list["top1"].avg
return best, best_flag, acc

# use precise bn to improve acc
if cfg.get("PRECISEBN") and (
Expand All @@ -457,7 +458,7 @@ def evaluate(best):
epoch % cfg.get("val_interval", 1) == 0 or epoch == cfg.epochs - 1
):
with paddle.no_grad():
best, save_best_flag = evaluate(best)
best, save_best_flag, acc = evaluate(best)
# save best
if save_best_flag:
save_student_model_flag = (
Expand All @@ -481,7 +482,7 @@ def evaluate(best):
model_path + ".pdparams",
save_student_model=save_student_model_flag,
)

metric_info = {"metric": acc, "epoch": epoch}
if uniform_output_enabled:
save_path = os.path.join(output_dir, prefix, "inference")
export(
Expand All @@ -491,6 +492,9 @@ def evaluate(best):
uniform_output_enabled=uniform_output_enabled,
logger=logger,
)

update_train_results(cfg, prefix, metric_info, ema=None)
save_model_info(metric_info, output_dir, prefix)
else:
save(
optimizer.state_dict(),
Expand Down Expand Up @@ -531,6 +535,7 @@ def evaluate(best):
# 10. Save model and optimizer
if epoch % cfg.get("save_interval", 1) == 0 or epoch == cfg.epochs - 1:
if cfg.get("Global") is not None:
metric_info = {"metric": acc, "epoch": epoch}
prefix = "epoch_{}".format(epoch)
model_path = osp.join(output_dir, prefix)
_mkdir_if_not_exist(model_path, logger)
Expand All @@ -547,6 +552,14 @@ def evaluate(best):
uniform_output_enabled=uniform_output_enabled,
logger=logger,
)
update_train_results(
cfg,
prefix,
metric_info,
done_flag=epoch == cfg["epochs"],
ema=None,
)
save_model_info(metric_info, output_dir, prefix)
else:
save(
optimizer.state_dict(),
Expand All @@ -559,6 +572,7 @@ def evaluate(best):
),
)
if cfg.get("Global") is not None:
metric_info = {"metric": acc, "epoch": epoch}
prefix = "latest"
model_path = osp.join(output_dir, prefix)
_mkdir_if_not_exist(model_path, logger)
Expand All @@ -575,6 +589,7 @@ def evaluate(best):
uniform_output_enabled=uniform_output_enabled,
logger=logger,
)
save_model_info(metric_info, output_dir, prefix)

logger.info(f"training {model_name} finished")

Expand Down

0 comments on commit 2a55a96

Please sign in to comment.