Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

基于mnn框架推理 #73

Merged
merged 4 commits into from
Sep 7, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
基于mnn框架推理
  • Loading branch information
zjkhahah committed Sep 7, 2024
commit 5263639b58a831df5ea94856076ee3ca77a9c9cb
34 changes: 33 additions & 1 deletion hivision/creator/human_matting.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@
from .context import Context
import cv2
import os

import MNN.expr as expr
import MNN.nn as nn

WEIGHTS = {
"hivision_modnet": os.path.join(
Expand All @@ -25,6 +26,11 @@
"weights",
"modnet_photographic_portrait_matting.onnx",
),
"mnn_hivision_modnet": os.path.join(
os.path.dirname(__file__),
"weights",
"mnn_hivision_modnet.mnn",
)
}


Expand All @@ -40,6 +46,27 @@ def extract_human(ctx: Context):
ctx.matting_image = ctx.processing_image.copy()


def get_mnn_modnet_matting(input_image, checkpoint_path, ref_size=512):
config = {}
config['precision'] = 'low' # 当硬件支持(armv8.2)时使用fp16推理
config['backend'] = 0 # CPU
config['numThread'] = 4 # 线程数
im, width, length = read_modnet_image(input_image, ref_size=512)
rt = nn.create_runtime_manager((config,))
net = nn.load_module_from_file(checkpoint_path, ['input1'], ['output1'], runtime_manager=rt)
input_var = expr.convert(im, expr.NCHW)
output_var = net.forward(input_var)
matte = expr.convert(output_var, expr.NCHW)
matte = matte.read()#var转换为np
matte = (matte * 255).astype("uint8")
matte = np.squeeze(matte)
mask = cv2.resize(matte, (width, length), interpolation=cv2.INTER_AREA)
b, g, r = cv2.split(np.uint8(input_image))

output_image = cv2.merge((b, g, r, mask))

return output_image

def extract_human_modnet_photographic_portrait_matting(ctx: Context):
"""
人像抠图
Expand All @@ -53,6 +80,11 @@ def extract_human_modnet_photographic_portrait_matting(ctx: Context):
ctx.processing_image = hollow_out_fix(matting_image)
ctx.matting_image = ctx.processing_image.copy()

def extract_human_mnn_modnet(ctx: Context):
matting_image = get_mnn_modnet_matting(ctx.processing_image, WEIGHTS["mnn_hivision_modnet"])
ctx.processing_image = hollow_out_fix(matting_image)
ctx.matting_image = ctx.processing_image.copy()


def hollow_out_fix(src: np.ndarray) -> np.ndarray:
"""
Expand Down
Binary file added hivision/creator/weights/mnn_hivision_modnet.mnn
Binary file not shown.
5 changes: 4 additions & 1 deletion inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from hivision.creator.human_matting import (
extract_human_modnet_photographic_portrait_matting,
extract_human,
extract_human_mnn_modnet,
)

parser = argparse.ArgumentParser(description="HivisionIDPhotos 证件照制作推理程序。")
Expand All @@ -24,7 +25,7 @@
"add_background",
"generate_layout_photos",
]
MATTING_MODEL = ["hivision_modnet", "modnet_photographic_portrait_matting"]
MATTING_MODEL = ["hivision_modnet", "modnet_photographic_portrait_matting", "mnn_hivision_modnet"]
RENDER = [0, 1, 2]

parser.add_argument(
Expand Down Expand Up @@ -64,6 +65,8 @@
creator.matting_handler = extract_human
elif args.matting_model == "modnet_photographic_portrait_matting":
creator.matting_handler = extract_human_modnet_photographic_portrait_matting
elif args.matting_model == "mnn_hivision_modnet":
creator.matting_handler = extract_human_mnn_modnet

root_dir = os.path.dirname(os.path.abspath(__file__))
input_image = cv2.imread(args.input_image_dir, cv2.IMREAD_UNCHANGED)
Expand Down
9 changes: 7 additions & 2 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
opencv-python>=4.8.1.78
onnxruntime>=1.15.0
numpy<=1.26.4
requests
mtcnn-runtime
requests~=2.31.0
mtcnn-runtime

gradio~=4.43.0
mnn~=2.9.3
pillow~=10.1.0
fastapi~=0.112.4