Merge branch 'newest' into fix_conflict
ziqi-jin committed Aug 12, 2022
2 parents c68da87 + 2780588 commit 5c542af
Showing 17 changed files with 403 additions and 136 deletions.
138 changes: 59 additions & 79 deletions README.md
@@ -16,15 +16,11 @@
<a href="https://github.com/PaddlePaddle/FastDeploy/stargazers"><img src="https://img.shields.io/github/stars/PaddlePaddle/FastDeploy?color=ccf"></a>
</p>

<h4 align="center">
  <a href="#features"> Features </a> |
  <a href="#server-side"> Server-Side </a> |
  <a href="#on-device"> On-Device </a> |
  <a href="#community"> Community </a>
</h4>

**⚡️FastDeploy** is an **easy-to-use** inference deployment toolbox. It covers the industry's mainstream **high-quality pre-trained models** and provides an **out-of-the-box** development experience, spanning image classification, object detection, image segmentation, face detection, human keypoint detection, text recognition, and other tasks, meeting developers' needs for rapid deployment across **multiple scenarios**, **multiple hardware targets**, and **multiple platforms**.

## Release History
- [v0.2.0] 2022.08.18: Fully open-sourced the server-side deployment code, supporting 40+ vision models on CPU/GPU, with accelerated deployment via GPU TensorRT.

## Supported Models

| Task Scenario | Model | X64 CPU | Nvidia-GPU | Nvidia-GPU TensorRT |
@@ -43,83 +39,67 @@
| | [PaddleDetection/PPYOLO](./examples/vision/detection/paddledetection) | ✅ | ✅ | - |
| | [PaddleDetection/PPYOLOv2](./examples/vision/detection/paddledetection) | ✅ | ✅ | - |
| | [PaddleDetection/FasterRCNN](./examples/vision/detection/paddledetection) | ✅ | ✅ | - |
| | [WongKinYiu/YOLOv7](./examples/vision/detection/yolov7) | ✅ | ✅ | ✅ |

#### Quick Start

#### Install FastDeploy Python

Choose the installation package according to your development environment; for more supported environments, refer to the [installation docs](docs/quick_start/install.md).

```
pip install https://bj.bcebos.com/paddlehub/fastdeploy/wheels/fastdeploy_python-0.2.0-cp38-cp38-manylinux1_x86_64.whl
```

Prepare an object detection model and a test image
```
wget https://bj.bcebos.com/paddlehub/fastdeploy/ppyoloe_crn_l_300e_coco.tgz
tar xvf ppyoloe_crn_l_300e_coco.tgz
wget https://gitee.com/paddlepaddle/PaddleDetection/raw/release/2.4/demo/000000014439.jpg
```

Load the model and run prediction
```
import fastdeploy as fd
import cv2

model = fd.vision.detection.PPYOLOE("ppyoloe_crn_l_300e_coco/model.pdmodel",
                                    "ppyoloe_crn_l_300e_coco/model.pdiparams",
                                    "ppyoloe_crn_l_300e_coco/infer_cfg.yml")
im = cv2.imread("000000014439.jpg")
result = model.predict(im.copy())
print(result)

vis_im = fd.vision.vis_detection(im, result, score_threshold=0.5)
cv2.imwrite("vis_image.jpg", vis_im)
```

After prediction completes, the visualized result is saved to `vis_image.jpg`, and the detection results are printed as follows
```
DetectionResult: [xmin, ymin, xmax, ymax, score, label_id]
415.047363,89.311523, 506.009613, 283.863129, 0.950423, 0
163.665710,81.914894, 198.585342, 166.760880, 0.896433, 0
581.788635,113.027596, 612.623474, 198.521713, 0.842597, 0
267.217224,89.777321, 298.796051, 169.361496, 0.837951, 0
104.465599,45.482410, 127.688835, 93.533875, 0.773348, 0
...
```

## More Deployment Examples

FastDeploy provides a large number of deployment examples for developers' reference, supporting model deployment on CPU, GPU, and TensorRT.

- [PaddleDetection model deployment](examples/vision/detection/paddledetection)
- [PaddleClas model deployment](examples/vision/classification/paddleclas)
- [PaddleSeg model deployment](examples/vision/segmentation/paddleseg)
- [YOLOv7 deployment](examples/vision/detection/yolov7)
- [YOLOv6 deployment](examples/vision/detection/yolov6)
- [YOLOv5 deployment](examples/vision/detection/yolov5)
- [Face detection model deployment](examples/vision/facedet)
- [More vision model deployment examples...](examples/vision)


### 📱 Lightweight SDK for Rapid On-Device AI Inference Deployment


| Task Scenario | Model | Size (MB) | Edge: Linux (ARM CPU) | Mobile: Android (ARM CPU) | Mobile: iOS (ARM CPU) |
| ------------------ | ---------------------------- | --------- | --------------------- | ------------------------- | --------------------- |
| Classification | PP-LCNet | 11.9 | ✅ | ✅ | ✅ |
| | PP-LCNetv2 | 26.6 | ✅ | ✅ | ✅ |
| | EfficientNet | 31.4 | ✅ | ✅ | ✅ |
| | GhostNet | 20.8 | ✅ | ✅ | ✅ |
| | MobileNetV1 | 17 | ✅ | ✅ | ✅ |
| | MobileNetV2 | 14.2 | ✅ | ✅ | ✅ |
| | MobileNetV3 | 22 | ✅ | ✅ | ✅ |
| | ShuffleNetV2 | 9.2 | ✅ | ✅ | ✅ |
| | SqueezeNetV1.1 | 5 | ✅ | ✅ | ✅ |
| | Inceptionv3 | 95.5 | ✅ | ✅ | ✅ |
| | PP-HGNet | 59 | ✅ | ✅ | ✅ |
| | SwinTransformer_224_win7 | 352.7 | ✅ | ✅ | ✅ |
| Detection | PP-PicoDet_s_320_coco | 4.1 | ✅ | ✅ | ✅ |
| | PP-PicoDet_s_320_lcnet | 4.9 | ✅ | ✅ | ✅ |
| | CenterNet | 4.8 | ✅ | ✅ | ✅ |
| | YOLOv3_MobileNetV3 | 94.6 | ✅ | ✅ | ✅ |
| | PP-YOLO_tiny_650e_coco | 4.4 | ✅ | ✅ | ✅ |
| | SSD_MobileNetV1_300_120e_voc | 23.3 | ✅ | ✅ | ✅ |
| | PP-YOLO_ResNet50vd | 188.5 | ✅ | ✅ | ✅ |
| | PP-YOLOv2_ResNet50vd | 218.7 | ✅ | ✅ | ✅ |
| | PP-YOLO_crn_l_300e_coco | 209.1 | ✅ | ✅ | ✅ |
| | YOLOv5s | 29.3 | ✅ | ✅ | ✅ |
| Face Detection | BlazeFace | 1.5 | ✅ | ✅ | ✅ |
| Face Localisation | RetinaFace | 1.7 | ✅ | ✅ | ✅ |
| Keypoint Detection | PP-TinyPose | 5.5 | ✅ | ✅ | ✅ |
| Segmentation | PP-LiteSeg(STDC1) | 32.2 | ✅ | ✅ | ✅ |
| | PP-HumanSeg-Lite | 0.556 | ✅ | ✅ | ✅ |
| | HRNet-w18 | 38.7 | ✅ | ✅ | ✅ |
| | PP-HumanSeg-Server | 107.2 | ✅ | ✅ | ✅ |
| | Unet | 53.7 | ✅ | ✅ | ✅ |
| OCR | PP-OCRv1 | 2.3+4.4 | ✅ | ✅ | ✅ |
| | PP-OCRv2 | 2.3+4.4 | ✅ | ✅ | ✅ |
| | PP-OCRv3 | 2.4+10.6 | ✅ | ✅ | ✅ |
| | PP-OCRv3-tiny | 2.4+10.7 | ✅ | ✅ | ✅ |


#### Edge-Side Deployment

- ARM Linux systems
  - [C++ inference deployment (including video streams)](./docs/ARM-Linux-CPP-SDK-Inference.md)
  - [C++ serving deployment](./docs/ARM-Linux-CPP-SDK-Serving.md)
  - [Python inference deployment](./docs/ARM-Linux-Python-SDK-Inference.md)
  - [Python serving deployment](./docs/ARM-Linux-Python-SDK-Serving.md)

#### Mobile Deployment

- [iOS deployment](./docs/iOS-SDK.md)
- [Android deployment](./docs/Android-SDK.md)

#### Custom Model Deployment

- [Quickly swap in your own model](./docs/Replace-Model-With-Anther-One.md)

## Community

- **Join the community 👬:** Scan the QR code with WeChat, then fill out the questionnaire to join the discussion group and discuss inference deployment pain points with other developers.

<div align="center">
<img src="https://user-images.githubusercontent.com/54695910/175854075-2c0f9997-ed18-4b17-9aaf-1b43266d3996.jpeg" width = "200" height = "200" />
</div>

## Acknowledgements

The SDK generation and download in this project use the free and open capabilities of [EasyEdge](https://ai.baidu.com/easyedge/app/openSource); we thank them again.

## License

FastDeploy is licensed under the [Apache-2.0 License](./LICENSE).
154 changes: 151 additions & 3 deletions csrc/fastdeploy/function/reduce.cc
Expand Up @@ -14,6 +14,7 @@

#include "fastdeploy/function/reduce.h"

#include <limits>
#include <set>

#include "fastdeploy/function/eigen.h"
@@ -215,9 +216,139 @@ void Reduce(const FDTensor& x, FDTensor* out, const std::vector<int64_t>& dims,
}
reduce_all = (reduce_all || full_dim);

-  FD_VISIT_ALL_TYPES(x.dtype, "ReduceKernelImpl", ([&] {
-                       ReduceKernelImpl<data_t, Functor>(x, out, dims, keep_dim,
-                                                         reduce_all);
+  FD_VISIT_INT_FLOAT_TYPES(x.dtype, "ReduceKernelImpl", ([&] {
+                             ReduceKernelImpl<data_t, Functor>(
+                                 x, out, dims, keep_dim, reduce_all);
                     }));
}

enum ArgMinMaxType { kArgMin, kArgMax };

template <typename T, typename Tout, int64_t Rank, ArgMinMaxType argMinMaxValue>
struct ArgMinMaxFunctor {};

#define DECLARE_ARG_MIN_MAX_FUNCTOR(eigen_op_type, enum_argminmax_value) \
template <typename T, typename Tout, int64_t Rank> \
struct ArgMinMaxFunctor<T, Tout, Rank, enum_argminmax_value> { \
void operator()(const FDTensor& in, FDTensor* out, \
const std::vector<int64_t>& x_dims, int64_t axis, \
bool keepdims, bool flatten) { \
const auto& dev = *EigenDeviceWrapper::GetInstance()->GetDevice(); \
auto in_eigen = EigenTensor<T, Rank>::From(in, x_dims); \
if (keepdims) { \
if (!flatten) { \
auto out_eigen = EigenTensor<Tout, Rank>::From(*out); \
out_eigen.device(dev) = \
in_eigen.eigen_op_type(axis).template cast<Tout>(); \
} else { \
auto out_eigen = EigenScalar<Tout>::From(*out); \
out_eigen.device(dev) = \
in_eigen.eigen_op_type(axis).template cast<Tout>(); \
} \
} else { \
auto out_eigen = EigenTensor<Tout, Rank - 1>::From(*out); \
out_eigen.device(dev) = \
in_eigen.eigen_op_type(axis).template cast<Tout>(); \
} \
} \
}

DECLARE_ARG_MIN_MAX_FUNCTOR(argmin, ArgMinMaxType::kArgMin);
DECLARE_ARG_MIN_MAX_FUNCTOR(argmax, ArgMinMaxType::kArgMax);

template <typename T, typename Tout, ArgMinMaxType EnumArgMinMaxValue>
void ArgMinMaxKernel(const FDTensor& x, FDTensor* out, int64_t axis,
bool keepdims, bool flatten) {
  bool new_keepdims = keepdims || flatten;
  // If flatten is set, construct new dims for the calculation.
  std::vector<int64_t> x_dims;
  int new_axis = axis;
  if (flatten) {
    x_dims = {x.Numel()};
    // When flattened, the reduction axis is simply 0.
    new_axis = 0;
} else {
x_dims = x.shape;
if (axis < 0) new_axis = axis + x_dims.size();
}
#define CALL_ARG_MINMAX_FUNCTOR(rank) \
ArgMinMaxFunctor<T, Tout, rank, EnumArgMinMaxValue> functor##rank; \
functor##rank(x, out, x_dims, new_axis, new_keepdims, flatten)

switch (x_dims.size()) {
case 1:
CALL_ARG_MINMAX_FUNCTOR(1);
break;
case 2:
CALL_ARG_MINMAX_FUNCTOR(2);
break;
case 3:
CALL_ARG_MINMAX_FUNCTOR(3);
break;
case 4:
CALL_ARG_MINMAX_FUNCTOR(4);
break;
case 5:
CALL_ARG_MINMAX_FUNCTOR(5);
break;
case 6:
CALL_ARG_MINMAX_FUNCTOR(6);
break;
default:
      FDASSERT(x_dims.size() <= 6,
               "%s operator doesn't support tensors whose rank is greater "
               "than 6.",
               (EnumArgMinMaxValue == kArgMin ? "argmin" : "argmax"));
      break;
  }
#undef CALL_ARG_MINMAX_FUNCTOR
}

template <typename T, ArgMinMaxType EnumArgMinMaxValue>
void ArgMinMax(const FDTensor& x, FDTensor* out, int64_t axis,
FDDataType output_dtype, bool keepdims, bool flatten) {
const auto& x_dims = x.shape;
int64_t x_rank = x_dims.size();
FDASSERT(axis >= -x_rank,
"'axis'(%d) must be greater than or equal to -Rank(X)(%d).", axis,
-x_rank);
  FDASSERT(axis < x_rank, "'axis'(%d) must be less than Rank(X)(%d).", axis,
           x_rank);
  FDASSERT(output_dtype == FDDataType::INT32 ||
               output_dtype == FDDataType::INT64,
           "The attribute of dtype in argmin/argmax must be [%s] or [%s], but "
           "received [%s].",
           Str(FDDataType::INT32), Str(FDDataType::INT64), Str(output_dtype));
if (axis < 0) axis += x_rank;
if (output_dtype == FDDataType::INT32) {
int64_t all_element_num = 0;
if (flatten) {
all_element_num = x.Numel();

} else {
all_element_num = x_dims[axis];
}
    FDASSERT(all_element_num <= std::numeric_limits<int>::max(),
             "The element count of the argmin/argmax input along the axis is "
             "%d, which exceeds the int32 maximum value (%d); you must "
             "set the dtype of argmin/argmax to 'int64'.",
             all_element_num, std::numeric_limits<int>::max());
}
std::vector<int64_t> vec;
if (flatten) {
vec.emplace_back(static_cast<int64_t>(1));
} else {
for (int64_t i = 0; i < axis; i++) vec.emplace_back(x_dims[i]);
if (keepdims) {
vec.emplace_back(static_cast<int64_t>(1));
}
for (int64_t i = axis + 1; i < x_rank; i++) vec.emplace_back(x_dims[i]);
}
out->Allocate(vec, output_dtype);

FD_VISIT_INT_TYPES(output_dtype, "ArgMinMaxKernel", ([&] {
ArgMinMaxKernel<T, data_t, EnumArgMinMaxValue>(
x, out, axis, keepdims, flatten);
}));
}

Expand Down Expand Up @@ -255,6 +386,23 @@ void Prod(const FDTensor& x, FDTensor* out, const std::vector<int64_t>& dims,
bool keep_dim, bool reduce_all) {
Reduce<ProdFunctor>(x, out, dims, keep_dim, reduce_all);
}

void ArgMax(const FDTensor& x, FDTensor* out, int64_t axis,
FDDataType output_dtype, bool keep_dim, bool flatten) {
FD_VISIT_INT_FLOAT_TYPES(x.dtype, "ArgMaxKernel", ([&] {
ArgMinMax<data_t, kArgMax>(
x, out, axis, output_dtype, keep_dim, flatten);
}));
}

void ArgMin(const FDTensor& x, FDTensor* out, int64_t axis,
FDDataType output_dtype, bool keep_dim, bool flatten) {
  FD_VISIT_INT_FLOAT_TYPES(x.dtype, "ArgMinKernel", ([&] {
ArgMinMax<data_t, kArgMin>(
x, out, axis, output_dtype, keep_dim, flatten);
}));
}

#endif

} // namespace fastdeploy
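
Note on the dispatch change above: replacing `FD_VISIT_ALL_TYPES` with `FD_VISIT_INT_FLOAT_TYPES` restricts the reduce and argmin/argmax kernels to integer and floating-point inputs. These visit macros follow the familiar dtype-dispatch pattern (as in PyTorch's `AT_DISPATCH_*` macros). Below is a minimal sketch of that pattern under simplified assumptions — the stand-in `FDDataType` enum and macro names are illustrative, not FastDeploy's actual definitions:

```
#include <cassert>
#include <cstdint>

// Simplified stand-in; FastDeploy's real FDDataType enum is richer.
enum class FDDataType { INT32, INT64, FP32, FP64 };

// One dispatch case: bind the concrete C++ type to data_t, then invoke the
// functor body that was pasted in as __VA_ARGS__.
#define VISIT_CASE(fd_dtype, ctype, ...) \
  case fd_dtype: {                       \
    using data_t = ctype;                \
    __VA_ARGS__();                       \
    break;                               \
  }

// Dispatch over integer and floating-point dtypes only, mirroring what
// FD_VISIT_INT_FLOAT_TYPES does for the kernels in this commit.
#define VISIT_INT_FLOAT_TYPES(dtype, name, ...)         \
  switch (dtype) {                                      \
    VISIT_CASE(FDDataType::INT32, int32_t, __VA_ARGS__) \
    VISIT_CASE(FDDataType::INT64, int64_t, __VA_ARGS__) \
    VISIT_CASE(FDDataType::FP32, float, __VA_ARGS__)    \
    VISIT_CASE(FDDataType::FP64, double, __VA_ARGS__)   \
    default:                                            \
      assert(false && "unsupported dtype");             \
  }

template <typename T>
void MyKernel() { /* instantiated once per dispatched dtype */ }

void RunMyKernel(FDDataType dtype) {
  // The lambda sees data_t because it is expanded inside each case block.
  VISIT_INT_FLOAT_TYPES(dtype, "MyKernel", ([&] { MyKernel<data_t>(); }));
}
```

This keeps each kernel written once as a template while the runtime dtype selects the instantiation, and unsupported dtypes fail loudly in the `default` branch.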
28 changes: 28 additions & 0 deletions csrc/fastdeploy/function/reduce.h
@@ -96,5 +96,33 @@ FASTDEPLOY_DECL void Prod(const FDTensor& x, FDTensor* out,
const std::vector<int64_t>& dims,
bool keep_dim = false, bool reduce_all = false);

/** Execute the argmax operation on the input FDTensor along the given axis.
    @param x The input tensor.
    @param out The output tensor which stores the result.
    @param axis The axis which will be reduced.
    @param output_dtype The data type of the output FDTensor, INT64 or INT32,
           defaults to INT64.
    @param keep_dim Whether to keep the reduced dim, default false.
    @param flatten Whether to flatten the FDTensor before taking the argmax
           index, default false.
*/
FASTDEPLOY_DECL void ArgMax(const FDTensor& x, FDTensor* out, int64_t axis,
FDDataType output_dtype = FDDataType::INT64,
bool keep_dim = false, bool flatten = false);

/** Execute the argmin operation on the input FDTensor along the given axis.
    @param x The input tensor.
    @param out The output tensor which stores the result.
    @param axis The axis which will be reduced.
    @param output_dtype The data type of the output FDTensor, INT64 or INT32,
           defaults to INT64.
    @param keep_dim Whether to keep the reduced dim, default false.
    @param flatten Whether to flatten the FDTensor before taking the argmin
           index, default false.
*/
FASTDEPLOY_DECL void ArgMin(const FDTensor& x, FDTensor* out, int64_t axis,
FDDataType output_dtype = FDDataType::INT64,
bool keep_dim = false, bool flatten = false);

#endif
} // namespace fastdeploy
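
As a quick illustration of the new public API, here is a minimal usage sketch. The `MakeFloatTensor` helper below is hypothetical — the FDTensor construction API is not part of this diff — but the `ArgMax`/`ArgMin` calls, their defaults, and the resulting shapes follow the declarations above:

```
#include <vector>

#include "fastdeploy/function/reduce.h"

// Hypothetical helper (not part of this diff): build a float FDTensor of the
// given shape from row-major values.
fastdeploy::FDTensor MakeFloatTensor(const std::vector<int64_t>& shape,
                                     const std::vector<float>& values);

int main() {
  // x = [[0, 5, 1],
  //      [7, 2, 4]]
  fastdeploy::FDTensor x = MakeFloatTensor({2, 3}, {0, 5, 1, 7, 2, 4});

  // Argmax along axis 1 with the defaults (output_dtype = INT64,
  // keep_dim = false, flatten = false): result shape [2], values {1, 0}.
  fastdeploy::FDTensor max_idx;
  fastdeploy::ArgMax(x, &max_idx, /*axis=*/1);

  // Argmin along axis 1 with keep_dim = true and an INT32 output:
  // result shape [2, 1], values {0, 1}.
  fastdeploy::FDTensor min_idx;
  fastdeploy::ArgMin(x, &min_idx, /*axis=*/1, fastdeploy::FDDataType::INT32,
                     /*keep_dim=*/true);
  return 0;
}
```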
