diff --git a/finetune/tag_images_by_wd14_tagger.py b/finetune/tag_images_by_wd14_tagger.py index 816aaddbc..6b33af518 100644 --- a/finetune/tag_images_by_wd14_tagger.py +++ b/finetune/tag_images_by_wd14_tagger.py @@ -117,7 +117,7 @@ def main(args): ) args.batch_size = batch_size ort_sess = ort.InferenceSession( - model.SerializeToString(), + onnx_path, providers=["CUDAExecutionProvider"] if "CUDAExecutionProvider" in ort.get_available_providers() else ["CPUExecutionProvider"], @@ -154,7 +154,7 @@ def run_batch(path_imgs): imgs = np.array([im for _, im in path_imgs]) if args.onnx: - probs = ort_sess.run(None, {input_name: imgs}) # onnx output numpy + probs = ort_sess.run(None, {input_name: imgs})[0] # onnx output numpy else: probs = model(imgs, training=False) probs = probs.numpy()