[GCU] Support inference for GCU
EnflameGCU committed Nov 1, 2024
1 parent d1bc416 commit 6f3a1aa
Showing 2 changed files with 33 additions and 3 deletions.
26 changes: 25 additions & 1 deletion tools/infer/utility.py
@@ -25,6 +25,8 @@
 import random
 from ppocr.utils.logging import get_logger
 
+import paddle_custom_device.gcu.passes as gcu_passes
+
 
 def str2bool(v):
     return v.lower() in ("true", "yes", "t", "y", "1")
@@ -41,6 +43,7 @@ def init_args():
     parser.add_argument("--use_xpu", type=str2bool, default=False)
     parser.add_argument("--use_npu", type=str2bool, default=False)
     parser.add_argument("--use_mlu", type=str2bool, default=False)
+    parser.add_argument("--use_gcu", type=str2bool, default=False)
     parser.add_argument("--ir_optim", type=str2bool, default=True)
     parser.add_argument("--use_tensorrt", type=str2bool, default=False)
     parser.add_argument("--min_subgraph_size", type=int, default=15)
@@ -287,6 +290,26 @@ def create_predictor(args, mode, logger):
             config.enable_custom_device("mlu")
         elif args.use_xpu:
             config.enable_xpu(10 * 1024 * 1024)
+        elif args.use_gcu:
+            gcu_passes.setUp()
+            if args.precision == "fp16":
+                config.enable_custom_device(
+                    "gcu", 0, paddle.inference.PrecisionType.Half
+                )
+                gcu_passes.set_exp_enable_mixed_precision_ops(config)
+            else:
+                config.enable_custom_device("gcu")
+
+            if paddle.framework.use_pir_api():
+                config.enable_new_ir(True)
+                config.enable_new_executor(True)
+                kPirGcuPasses = gcu_passes.inference_passes(
+                    use_pir=True, name="PaddleOCR"
+                )
+                config.enable_custom_passes(kPirGcuPasses, True)
+            else:
+                pass_builder = config.pass_builder()
+                gcu_passes.append_passes_for_legacy_ir(pass_builder, "PaddleOCR")
         else:
             config.disable_gpu()
             if args.enable_mkldnn:
@@ -303,7 +326,8 @@ def create_predictor(args, mode, logger):
         # enable memory optim
         config.enable_memory_optim()
         config.disable_glog_info()
-        config.delete_pass("conv_transpose_eltwiseadd_bn_fuse_pass")
+        if not args.use_gcu:
+            config.delete_pass("conv_transpose_eltwiseadd_bn_fuse_pass")
         config.delete_pass("matmul_transpose_reshape_fuse_pass")
         if mode == "rec" and args.rec_algorithm == "SRN":
             config.delete_pass("gpu_cpu_map_matmul_v2_to_matmul_pass")
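For context, the block below is a minimal standalone sketch of what the new branch in create_predictor() sets up; it is not part of the commit. It assumes the Enflame GCU plugin for PaddleCustomDevice is installed so that paddle_custom_device.gcu.passes is importable, and the model paths are placeholders. In the PaddleOCR scripts themselves the same path is reached by passing --use_gcu=True to the tools/infer entry points.

# Sketch only: build a Paddle inference Config for GCU, mirroring the
# elif args.use_gcu branch added above. Paths are placeholders.
import paddle
from paddle import inference
import paddle_custom_device.gcu.passes as gcu_passes

model_file = "./inference/det/inference.pdmodel"      # placeholder path
params_file = "./inference/det/inference.pdiparams"   # placeholder path
precision = "fp16"

config = inference.Config(model_file, params_file)
gcu_passes.setUp()
if precision == "fp16":
    # Run on GCU device 0 in half precision and enable the mixed-precision ops.
    config.enable_custom_device("gcu", 0, paddle.inference.PrecisionType.Half)
    gcu_passes.set_exp_enable_mixed_precision_ops(config)
else:
    config.enable_custom_device("gcu")

if paddle.framework.use_pir_api():
    # New IR (PIR) path: register the GCU inference passes on the config.
    config.enable_new_ir(True)
    config.enable_new_executor(True)
    config.enable_custom_passes(
        gcu_passes.inference_passes(use_pir=True, name="PaddleOCR"), True
    )
else:
    # Legacy IR path: append the GCU passes to the existing pass builder.
    gcu_passes.append_passes_for_legacy_ir(config.pass_builder(), "PaddleOCR")

predictor = inference.create_predictor(config)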
10 changes: 8 additions & 2 deletions tools/program.py
@@ -115,7 +115,7 @@ def merge_config(config, opts):
     return config
 
 
-def check_device(use_gpu, use_xpu=False, use_npu=False, use_mlu=False):
+def check_device(use_gpu, use_xpu=False, use_npu=False, use_mlu=False, use_gcu=False):
     """
     Log error and exit when set use_gpu=true in paddlepaddle
     cpu version.
@@ -154,6 +154,9 @@ def check_device(use_gpu, use_xpu=False, use_npu=False, use_mlu=False):
         if use_mlu and not paddle.device.is_compiled_with_mlu():
             print(err.format("use_mlu", "mlu", "mlu", "use_mlu"))
             sys.exit(1)
+        if use_gcu and not paddle.device.is_compiled_with_custom_device("gcu"):
+            print(err.format("use_gcu", "gcu", "gcu", "use_gcu"))
+            sys.exit(1)
     except Exception as e:
         pass

@@ -799,6 +802,7 @@ def preprocess(is_train=False):
     use_xpu = config["Global"].get("use_xpu", False)
     use_npu = config["Global"].get("use_npu", False)
     use_mlu = config["Global"].get("use_mlu", False)
+    use_gcu = config["Global"].get("use_gcu", False)
 
     alg = config["Architecture"]["algorithm"]
     assert alg in [
@@ -853,9 +857,11 @@
         device = "npu:{0}".format(os.getenv("FLAGS_selected_npus", 0))
     elif use_mlu:
         device = "mlu:{0}".format(os.getenv("FLAGS_selected_mlus", 0))
+    elif use_gcu:
+        device = "gcu:{0}".format(os.getenv("FLAGS_selected_gcus", 0))
     else:
         device = "gpu:{}".format(dist.ParallelEnv().dev_id) if use_gpu else "cpu"
-    check_device(use_gpu, use_xpu, use_npu, use_mlu)
+    check_device(use_gpu, use_xpu, use_npu, use_mlu, use_gcu)
 
     device = paddle.set_device(device)

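For context, a minimal sketch (not part of the commit) of how the new Global.use_gcu option and the FLAGS_selected_gcus environment variable are expected to resolve to a Paddle device in preprocess(); the inline config dict is illustrative only.

# Sketch only: device selection for GCU, mirroring the changes to preprocess().
import os
import paddle

# Illustrative stand-in for config["Global"] as loaded from a YAML config file.
global_config = {"use_gcu": True}

use_gcu = global_config.get("use_gcu", False)
if use_gcu and not paddle.device.is_compiled_with_custom_device("gcu"):
    # Mirrors the new check_device() branch: stop when the GCU custom-device
    # plugin is not available in this Paddle build.
    raise SystemExit("use_gcu is set but Paddle was not compiled with the 'gcu' device")

# FLAGS_selected_gcus picks the GCU card index, defaulting to card 0.
device = "gcu:{0}".format(os.getenv("FLAGS_selected_gcus", 0)) if use_gcu else "cpu"
device = paddle.set_device(device)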
