From 699d6cd0326ebbb3cace63c0293fdfd89bb021f6 Mon Sep 17 00:00:00 2001
From: Yan Chunwei
Date: Fri, 16 Aug 2019 22:39:39 +0800
Subject: [PATCH] publish lite (#1800)

---
 .clang-format | 24 +-
 .pre-commit-config.yaml | 66 +-
 CMakeLists.txt | 426 +-
 cmake/FindGflags.cmake | 582 +
 cmake/FindGlog.cmake | 24 +
 cmake/FindGperftools.cmake | 63 +
 cmake/FindJeMalloc.cmake | 28 +
 cmake/FindNumPy.cmake | 38 +
 cmake/cblas.cmake | 94 +
 cmake/ccache.cmake | 9 +
 cmake/configure.cmake | 213 +
 cmake/coveralls.cmake | 103 +
 cmake/coverallsGcovJsons.cmake | 401 +
 cmake/cross_compiling/android.cmake | 85 +
 cmake/cross_compiling/armlinux.cmake | 41 +
 cmake/cross_compiling/findar.cmake | 33 +
 cmake/cross_compiling/host.cmake | 48 +
 cmake/cross_compiling/ios.cmake | 691 ++
 cmake/cross_compiling/npu.cmake | 70 +
 cmake/cross_compiling/postproject.cmake | 99 +
 cmake/cross_compiling/preproject.cmake | 59 +
 cmake/cuda.cmake | 228 +
 cmake/cudnn.cmake | 102 +
 cmake/cupti.cmake | 41 +
 cmake/external/eigen.cmake | 54 +
 cmake/external/gflags.cmake | 74 +
 cmake/external/glog.cmake | 77 +
 cmake/external/gtest.cmake | 85 +
 cmake/external/libxsmm.cmake | 55 +
 cmake/external/mkldnn.cmake | 120 +
 cmake/external/mklml.cmake | 77 +
 cmake/external/openblas.cmake | 93 +
 cmake/external/opencl-clhpp.cmake | 36 +
 cmake/external/opencl-headers.cmake | 33 +
 cmake/external/protobuf.cmake | 293 +
 cmake/external/xbyak.cmake | 57 +
 cmake/external/xxhash.cmake | 73 +
 cmake/flags.cmake | 194 +
 cmake/generic.cmake | 570 +
 cmake/hip.cmake | 53 +
 cmake/lite.cmake | 89 +
 cmake/make_resource.py | 25 +
 cmake/operators.cmake | 227 +
 cmake/package.cmake | 21 +
 cmake/simd.cmake | 99 +
 cmake/system.cmake | 85 +
 cmake/tensorrt.cmake | 38 +
 cmake/util.cmake | 55 +
 cmake/version.cmake | 66 +
 lite/CMakeLists.txt | 358 +
 lite/api/CMakeLists.txt | 223 +
 lite/api/android/.gitignore | 2 +
 lite/api/android/CMakeLists.txt | 5 +
 lite/api/android/jni/.gitignore | 3 +
 lite/api/android/jni/CMakeLists.txt | 51 +
 lite/api/android/jni/native/CMakeLists.txt | 32 +
 .../api/android/jni/native/convert_util_jni.h | 186 +
 .../api/android/jni/native/paddle_lite_jni.cc | 164 +
 lite/api/android/jni/native/paddle_lite_jni.h | 113 +
 lite/api/android/jni/native/tensor_jni.cc | 168 +
 lite/api/android/jni/native/tensor_jni.h | 90 +
 .../jni/src/com/baidu/paddle/lite/.gitignore | 2 +
 .../src/com/baidu/paddle/lite/ConfigBase.java | 31 +
 .../src/com/baidu/paddle/lite/CxxConfig.java | 39 +
 .../com/baidu/paddle/lite/MobileConfig.java | 22 +
 .../paddle/lite/PaddleLiteInitializer.java | 23 +
 .../baidu/paddle/lite/PaddlePredictor.java | 192 +
 .../jni/src/com/baidu/paddle/lite/Place.java | 148 +
 .../jni/src/com/baidu/paddle/lite/Tensor.java | 141 +
 .../paddle/lite/PaddlePredictorTest.java | 54 +
 lite/api/apis_test.cc | 112 +
 lite/api/cxx_api.cc | 141 +
 lite/api/cxx_api.h | 165 +
 lite/api/cxx_api_bin.cc | 129 +
 lite/api/cxx_api_impl.cc | 90 +
 lite/api/cxx_api_test.cc | 149 +
 lite/api/efficientnet_b0_test.cc | 102 +
 lite/api/inceptionv4_test.cc | 91 +
 lite/api/light_api.cc | 93 +
 lite/api/light_api.h | 70 +
 lite/api/light_api_impl.cc | 72 +
 lite/api/light_api_test.cc | 51 +
 lite/api/lite_api_test_helper.cc | 60 +
 lite/api/lite_api_test_helper.h | 31 +
 lite/api/mobilenetv1_int8_test.cc | 94 +
 lite/api/mobilenetv1_ssd_test.cc | 112 +
 lite/api/mobilenetv1_test.cc | 152 +
 lite/api/mobilenetv1_yolov3_test.cc | 119 +
 lite/api/mobilenetv2_test.cc | 154 +
 lite/api/model_optimize_tool.cc | 93 +
 lite/api/model_test.cc | 181 +
 lite/api/ocr_attention_test.cc | 115 +
 lite/api/paddle_api.cc | 73 +
 lite/api/paddle_api.h | 121 +
 lite/api/paddle_api_test.cc | 88 +
 lite/api/paddle_lite_factory_helper.h | 37 +
 lite/api/paddle_place.cc | 103 +
 lite/api/paddle_place.h | 153 +
 lite/api/paddle_use_kernels.h | 180 +
 lite/api/paddle_use_ops.h | 106 +
 lite/api/paddle_use_passes.h | 38 +
 lite/api/resnet18_test.cc | 87 +
 lite/api/resnet50_test.cc | 107 +
 lite/api/resnet50_test_fpga.cc | 61 +
 lite/api/shufflenetv2_test.cc | 92 +
 lite/api/test_googlenet_lite.cc | 80 +
 lite/api/test_helper.h | 37 +
 lite/api/test_inceptionv4_lite_x86.cc | 108 +
 lite/api/test_mobilenetv1_lite_x86.cc | 105 +
 lite/api/test_mobilenetv2_lite_x86.cc | 108 +
 lite/api/unet_test.cc | 103 +
 lite/arm/CMakeLists.txt | 2 +
 lite/arm/math/CMakeLists.txt | 109 +
 lite/arm/math/activation.cc | 638 ++
 lite/arm/math/activation.h | 57 +
 lite/arm/math/argmax.cc | 65 +
 lite/arm/math/argmax.h | 35 +
 lite/arm/math/axpy.cc | 203 +
 lite/arm/math/axpy.h | 49 +
 lite/arm/math/beam_search.cc | 271 +
 lite/arm/math/beam_search.h | 41 +
 lite/arm/math/box_coder.cc | 92 +
 lite/arm/math/box_coder.h | 36 +
 lite/arm/math/col_im_transform.cc | 75 +
 lite/arm/math/col_im_transform.h | 40 +
 lite/arm/math/concat.cc | 60 +
 lite/arm/math/concat.h | 35 +
 lite/arm/math/conv3x3s1_direct_int8.cc | 806 ++
 lite/arm/math/conv3x3s2_direct_int8.cc | 1081 ++
 lite/arm/math/conv_block_utils.h | 4292 ++++++++
 lite/arm/math/conv_depthwise.cc | 239 +
 lite/arm/math/conv_depthwise.h | 100 +
 lite/arm/math/conv_depthwise_3x3_int7.cc | 5322 +++++++++
 lite/arm/math/conv_depthwise_3x3_int8.cc | 5832 ++++++++++
 lite/arm/math/conv_depthwise_3x3p0.cc | 4178 +++++++
 lite/arm/math/conv_depthwise_3x3p1.cc | 4850 +++++++++
 lite/arm/math/conv_depthwise_5x5s1.cc | 9615 +++++++++++++++++
 lite/arm/math/conv_depthwise_5x5s1_int8.cc | 618 ++
 lite/arm/math/conv_depthwise_5x5s2.cc | 3746 +++++++
 lite/arm/math/conv_direct.cc | 242 +
 lite/arm/math/conv_direct.h | 107 +
 lite/arm/math/conv_direct_3x3s1.cc | 1067 ++
 lite/arm/math/conv_direct_3x3s2.cc | 1209 +++
 lite/arm/math/conv_gemmlike.cc | 285 +
 lite/arm/math/conv_gemmlike.h | 108 +
 lite/arm/math/conv_impl.cc | 900 ++
 lite/arm/math/conv_impl.h | 423 +
 lite/arm/math/conv_winograd.cc | 141 +
 lite/arm/math/conv_winograd.h | 65 +
 lite/arm/math/conv_winograd_3x3.cc | 479 +
 lite/arm/math/decode_bboxes.cc | 651 ++
 lite/arm/math/decode_bboxes.h | 39 +
 lite/arm/math/dot_toolchain_support.h | 196 +
 lite/arm/math/dropout.cc | 93 +
 lite/arm/math/dropout.h | 32 +
 lite/arm/math/elementwise.cc | 758 ++
 lite/arm/math/elementwise.h | 67 +
 lite/arm/math/fill_bias_relu.cc | 122 +
 lite/arm/math/fill_bias_relu.h | 44 +
 lite/arm/math/funcs.cc | 153 +
 lite/arm/math/funcs.h | 424 +
 lite/arm/math/gemm_prepacked_int8.cc | 3942 +++++++
 lite/arm/math/gemm_prepacked_int8.h | 94 +
 lite/arm/math/gemv_arm_int8.cc | 480 +
 lite/arm/math/gemv_arm_int8.h | 40 +
 lite/arm/math/gru_utils.h | 434 +
 lite/arm/math/im2sequence.cc | 72 +
 lite/arm/math/im2sequence.h | 44 +
 lite/arm/math/increment.cc | 37 +
 lite/arm/math/increment.h | 33 +
 lite/arm/math/interpolate.cc | 514 +
 lite/arm/math/interpolate.h | 58 +
 lite/arm/math/lrn.cc | 101 +
 lite/arm/math/lrn.h | 49 +
 lite/arm/math/multiclass_nms.cc | 299 +
 lite/arm/math/multiclass_nms.h | 45 +
 lite/arm/math/negative.cc | 37 +
 lite/arm/math/negative.h | 33 +
 lite/arm/math/norm.cc | 52 +
 lite/arm/math/norm.h | 35 +
 lite/arm/math/packed_sgemm.cc | 3454 ++++++
 lite/arm/math/packed_sgemm.h | 84 +
 lite/arm/math/pad2d.cc | 413 +
 lite/arm/math/pad2d.h | 71 +
 lite/arm/math/pooling.cc | 3173 ++++++
 lite/arm/math/pooling.h | 154 +
 lite/arm/math/power.cc | 96 +
 lite/arm/math/power.h | 33 +
 lite/arm/math/prior_box.cc | 364 +
 lite/arm/math/prior_box.h | 68 +
 lite/arm/math/reduce_max.cc | 207 +
 lite/arm/math/reduce_max.h | 89 +
 lite/arm/math/saturate.h | 320 +
 lite/arm/math/scale.cc | 177 +
 lite/arm/math/scale.h | 45 +
 lite/arm/math/sequence2batch.h | 210 +
 lite/arm/math/sequence_expand.cc | 63 +
 lite/arm/math/sequence_expand.h | 35 +
 lite/arm/math/sequence_pool.cc | 224 +
 lite/arm/math/sequence_pool.h | 69 +
 lite/arm/math/sequence_softmax.cc | 49 +
 lite/arm/math/sequence_softmax.h | 34 +
 lite/arm/math/sgemm.cc | 68 +
 lite/arm/math/sgemm.h | 48 +
 lite/arm/math/sgemv.cc | 1054 ++
 lite/arm/math/sgemv.h | 38 +
 lite/arm/math/shuffle_channel.cc | 81 +
 lite/arm/math/shuffle_channel.h | 34 +
 lite/arm/math/slice.cc | 91 +
 lite/arm/math/slice.h | 38 +
 lite/arm/math/softmax.cc | 616 ++
 lite/arm/math/softmax.h | 71 +
 lite/arm/math/split.cc | 85 +
 lite/arm/math/split.h | 37 +
 lite/arm/math/topk.cc | 53 +
 lite/arm/math/topk.h | 34 +
 lite/arm/math/type_trans.cc | 919 ++
 lite/arm/math/type_trans.h | 117 +
 lite/arm/math/yolo_box.cc | 168 +
 lite/arm/math/yolo_box.h | 37 +
 lite/core/CMakeLists.txt | 75 +
 lite/core/arena/CMakeLists.txt | 10 +
 lite/core/arena/framework.cc | 70 +
 lite/core/arena/framework.h | 258 +
 lite/core/arena/framework_test.cc | 83 +
 lite/core/context.cc | 23 +
 lite/core/context.h | 363 +
 lite/core/context_test.cc | 51 +
 lite/core/cpu_info.cc | 1073 ++
 lite/core/cpu_info.h | 135 +
 lite/core/framework.proto | 188 +
 lite/core/kernel.cc | 104 +
 lite/core/kernel.h | 189 +
 lite/core/kernel_test.cc | 63 +
 lite/core/lite.map | 6 +
 lite/core/lite_gtest_main.cc | 23 +
 lite/core/lite_tensor_test.cc | 32 +
 lite/core/memory.cc | 110 +
 lite/core/memory.h | 111 +
 lite/core/memory_test.cc | 34 +
 lite/core/mir/CMakeLists.txt | 106 +
 lite/core/mir/argument_type_display_pass.cc | 45 +
 lite/core/mir/demo_pass.cc | 37 +
 lite/core/mir/dot.h | 167 +
 lite/core/mir/elimination/CMakeLists.txt | 10 +
 .../identity_scale_eliminate_pass.cc | 72 +
 .../identity_scale_eliminate_pass_test.cc | 93 +
 lite/core/mir/fusion/CMakeLists.txt | 36 +
 .../mir/fusion/conv_activation_fuse_pass.cc | 38 +
 .../mir/fusion/conv_activation_fuse_pass.h | 32 +
 lite/core/mir/fusion/conv_activation_fuser.cc | 103 +
 lite/core/mir/fusion/conv_activation_fuser.h | 47 +
 lite/core/mir/fusion/conv_bn_fuse_pass.cc | 37 +
 lite/core/mir/fusion/conv_bn_fuse_pass.h | 32 +
 .../core/mir/fusion/conv_bn_fuse_pass_test.cc | 140 +
 lite/core/mir/fusion/conv_bn_fuser.cc | 163 +
 lite/core/mir/fusion/conv_bn_fuser.h | 58 +
 ...ementwise_add_activation_fuse_pass_test.cc | 157 +
 .../mir/fusion/conv_elementwise_fuse_pass.cc | 38 +
 .../mir/fusion/conv_elementwise_fuse_pass.h | 32 +
 .../core/mir/fusion/conv_elementwise_fuser.cc | 102 +
 lite/core/mir/fusion/conv_elementwise_fuser.h | 43 +
 .../elementwise_add_activation_fuse_pass.cc | 36 +
 .../elementwise_add_activation_fuse_pass.h | 32 +
 ...ementwise_add_activation_fuse_pass_test.cc | 117 +
 .../elementwise_add_activation_fuser.cc | 87 +
 .../fusion/elementwise_add_activation_fuser.h | 41 +
 lite/core/mir/fusion/fc_fuse_pass.cc | 34 +
 lite/core/mir/fusion/fc_fuse_pass.h | 32 +
 lite/core/mir/fusion/fc_fuse_pass_test.cc | 112 +
 lite/core/mir/fusion/fc_fuser.cc | 78 +
 lite/core/mir/fusion/fc_fuser.h | 38 +
 .../mir/fusion/quant_dequant_fuse_pass.cc | 45 +
 .../core/mir/fusion/quant_dequant_fuse_pass.h | 33 +
 .../core/mir/fusion/quant_dequant_op_fuser.cc | 198 +
 lite/core/mir/fusion/quant_dequant_op_fuser.h | 59 +
 lite/core/mir/generate_program_pass.cc | 42 +
 lite/core/mir/generate_program_pass.h | 50 +
 lite/core/mir/graph_visualize_pass.cc | 76 +
 lite/core/mir/graph_visualize_pass.h | 39 +
 lite/core/mir/io_copy_kernel_pick_pass.cc | 74 +
 lite/core/mir/node.cc | 74 +
 lite/core/mir/node.h | 173 +
 lite/core/mir/pass.cc | 15 +
 lite/core/mir/pass.h | 78 +
 lite/core/mir/pass_manager.cc | 21 +
 lite/core/mir/pass_manager.h | 87 +
 lite/core/mir/pass_manager_test.cc | 33 +
 lite/core/mir/pass_registry.cc | 21 +
 lite/core/mir/pass_registry.h | 44 +
 lite/core/mir/pattern_matcher.cc | 527 +
 lite/core/mir/pattern_matcher.h | 424 +
 lite/core/mir/pattern_matcher_high_api.cc | 80 +
 lite/core/mir/pattern_matcher_high_api.h | 83 +
 .../core/mir/pattern_matcher_high_api_test.cc | 150 +
 lite/core/mir/pattern_matcher_test.cc | 233 +
 lite/core/mir/pattern_matcher_tester.cc | 233 +
 lite/core/mir/runtime_context_assign_pass.cc | 41 +
 lite/core/mir/ssa_graph.cc | 240 +
 lite/core/mir/ssa_graph.h | 144 +
 lite/core/mir/ssa_graph_test.cc | 59 +
 lite/core/mir/static_kernel_pick_pass.cc | 135 +
 lite/core/mir/static_kernel_pick_pass.h | 97 +
 lite/core/mir/subgraph/CMakeLists.txt | 32 +
 .../mir/subgraph/generate_npu_program_pass.cc | 259 +
 .../mir/subgraph/generate_npu_program_pass.h | 65 +
 .../generate_npu_program_pass_test.cc | 65 +
 .../mir/subgraph/subgraph_program_pass.cc | 139 +
 .../core/mir/subgraph/subgraph_program_pass.h | 70 +
 .../subgraph/subgraph_program_pass_test.cc | 140 +
 lite/core/mir/type_layout_cast_pass.cc | 176 +
 lite/core/mir/type_layout_cast_pass.h | 62 +
 lite/core/mir/type_precision_cast_pass.cc | 182 +
 lite/core/mir/type_precision_cast_pass.h | 66 +
 lite/core/mir/type_target_cast_pass.cc | 182 +
 lite/core/mir/type_target_cast_pass.h | 66 +
 .../core/mir/variable_place_inference_pass.cc | 34 +
 lite/core/mir/variable_place_inference_pass.h | 157 +
 .../mir/variable_place_inference_pass_test.cc | 101 +
 lite/core/naive_test_model.py | 56 +
 lite/core/op_lite.cc | 105 +
 lite/core/op_lite.h | 231 +
 lite/core/op_lite_test.cc | 24 +
 lite/core/op_registry.cc | 152 +
 lite/core/op_registry.h | 282 +
 lite/core/optimizer.cc | 34 +
 lite/core/optimizer.h | 196 +
 lite/core/optimizer_test.cc | 51 +
 lite/core/profile/CMakeLists.txt | 8 +
 lite/core/profile/basic_profiler.cc | 26 +
 lite/core/profile/basic_profiler.h | 201 +
 lite/core/profile/basic_profiler_test.cc | 46 +
 lite/core/profile/precision_profiler.h | 102 +
 lite/core/program.cc | 133 +
 lite/core/program.h | 149 +
 lite/core/program_fake_utils.cc | 22 +
 lite/core/program_fake_utils.h | 142 +
 lite/core/scope.cc | 72 +
 lite/core/scope.h | 79 +
 lite/core/scope_test.cc | 37 +
 lite/core/target_wrapper.cc | 21 +
 lite/core/target_wrapper.h | 208 +
 lite/core/tensor.cc | 115 +
 lite/core/tensor.h | 227 +
 lite/core/type_system.cc | 157 +
 lite/core/type_system.h | 390 +
 lite/core/type_system_test.cc | 35 +
 lite/core/types.cc | 95 +
 lite/core/types.h | 116 +
 lite/core/types_test.cc | 43 +
 lite/core/variable.cc | 19 +
 lite/core/variable.h | 52 +
 lite/core/workspace.cc | 15 +
 lite/core/workspace.h | 83 +
 lite/cuda/CMakeLists.txt | 8 +
 lite/cuda/blas.cc | 57 +
 lite/cuda/blas.h | 99 +
 lite/cuda/cuda_utils.h | 76 +
 lite/cuda/target_wrapper.cc | 74 +
 lite/cuda/target_wrapper.h | 29 +
 lite/demo/cxx/Makefile.def | 35 +
 lite/demo/cxx/README.md | 42 +
 .../mobile_full/Makefile.android.armv7 | 22 +
 .../mobile_full/Makefile.android.armv8 | 22 +
 .../mobile_light/Makefile.android.armv7 | 22 +
 .../mobile_light/Makefile.android.armv8 | 22 +
 .../cxx/mobile_full/mobilenetv1_full_api.cc | 73 +
 .../cxx/mobile_light/mobilenetv1_light_api.cc | 65 +
 lite/demo/java/README.md | 118 +
 .../java/android/PaddlePredictor/.gitignore | 13 +
 .../android/PaddlePredictor/app/.gitignore | 1 +
 .../android/PaddlePredictor/app/build.gradle | 28 +
 .../PaddlePredictor/app/proguard-rules.pro | 21 +
 .../paddle/lite/ExampleInstrumentedTest.java | 114 +
 .../app/src/main/AndroidManifest.xml | 21 +
 .../app/src/main/assets/README.txt | 8 +
 .../com/baidu/paddle/lite/MainActivity.java | 204 +
 .../drawable-v24/ic_launcher_foreground.xml | 34 +
 .../res/drawable/ic_launcher_background.xml | 170 +
 .../app/src/main/res/layout/activity_main.xml | 19 +
 .../res/mipmap-anydpi-v26/ic_launcher.xml | 5 +
 .../mipmap-anydpi-v26/ic_launcher_round.xml | 5 +
 .../src/main/res/mipmap-hdpi/ic_launcher.png | Bin 0 -> 2963 bytes
 .../res/mipmap-hdpi/ic_launcher_round.png | Bin 0 -> 4905 bytes
 .../src/main/res/mipmap-mdpi/ic_launcher.png | Bin 0 -> 2060 bytes
 .../res/mipmap-mdpi/ic_launcher_round.png | Bin 0 -> 2783 bytes
 .../src/main/res/mipmap-xhdpi/ic_launcher.png | Bin 0 -> 4490 bytes
 .../res/mipmap-xhdpi/ic_launcher_round.png | Bin 0 -> 6895 bytes
 .../main/res/mipmap-xxhdpi/ic_launcher.png | Bin 0 -> 6387 bytes
 .../res/mipmap-xxhdpi/ic_launcher_round.png | Bin 0 -> 10413 bytes
 .../main/res/mipmap-xxxhdpi/ic_launcher.png | Bin 0 -> 9128 bytes
 .../res/mipmap-xxxhdpi/ic_launcher_round.png | Bin 0 -> 15132 bytes
 .../app/src/main/res/values/colors.xml | 6 +
 .../app/src/main/res/values/strings.xml | 3 +
 .../app/src/main/res/values/styles.xml | 11 +
 .../baidu/paddle/lite/ExampleUnitTest.java | 17 +
 .../java/android/PaddlePredictor/build.gradle | 27 +
 .../android/PaddlePredictor/gradle.properties | 13 +
 .../gradle/wrapper/gradle-wrapper.jar | Bin 0 -> 54329 bytes
 .../gradle/wrapper/gradle-wrapper.properties | 6 +
 .../demo/java/android/PaddlePredictor/gradlew | 172 +
 .../java/android/PaddlePredictor/gradlew.bat | 84 +
 .../android/PaddlePredictor/settings.gradle | 1 +
 lite/demo/java/android/prepare_demo.bash | 23 +
 lite/fluid/CMakeLists.txt | 4 +
 lite/fluid/data_type.cc | 101 +
 lite/fluid/data_type.h | 88 +
 lite/fluid/data_type_test.cc | 40 +
 lite/fluid/eigen.h | 141 +
 lite/fluid/float16.h | 1100 ++
 lite/fluid/lod.h | 38 +
 lite/fluid/math.h | 42 +
 lite/fpga/CMakeLists.txt | 15 +
 lite/fpga/KD/alignment.h | 26 +
 lite/fpga/KD/context.hpp | 50 +
 lite/fpga/KD/dl_engine.cpp | 27 +
 lite/fpga/KD/dl_engine.hpp | 36 +
 lite/fpga/KD/float16.hpp | 508 +
 lite/fpga/KD/fpga_cv.cpp | 80 +
 lite/fpga/KD/fpga_cv.hpp | 28 +
 lite/fpga/KD/layout.hpp | 99 +
 lite/fpga/KD/llapi/bias_scale.cpp | 102 +
 lite/fpga/KD/llapi/bias_scale.h | 30 +
 lite/fpga/KD/llapi/config.h | 19 +
 lite/fpga/KD/llapi/filter.cpp | 317 +
 lite/fpga/KD/llapi/filter.h | 58 +
 lite/fpga/KD/llapi/zynqmp_api.cpp | 323 +
 lite/fpga/KD/llapi/zynqmp_api.h | 337 +
 lite/fpga/KD/pe.hpp | 37 +
 lite/fpga/KD/pe_params.hpp | 233 +
 lite/fpga/KD/pes/batchnorm_pe.hpp | 105 +
 lite/fpga/KD/pes/concat_pe.hpp | 135 +
 lite/fpga/KD/pes/conv_pe.hpp | 138 +
 lite/fpga/KD/pes/conv_process.hpp | 418 +
 lite/fpga/KD/pes/crop_pe.cpp | 88 +
 lite/fpga/KD/pes/crop_pe.hpp | 45 +
 lite/fpga/KD/pes/depthwise_conv_pe.hpp | 102 +
 lite/fpga/KD/pes/elementwise_add_pe.hpp | 81 +
 lite/fpga/KD/pes/fully_connected_pe.hpp | 94 +
 lite/fpga/KD/pes/input_pe.hpp | 54 +
 lite/fpga/KD/pes/norm_pe.hpp | 121 +
 lite/fpga/KD/pes/output_pe.hpp | 53 +
 lite/fpga/KD/pes/pooling_pe.hpp | 176 +
 lite/fpga/KD/pes/prior_box_pe.cpp | 273 +
 lite/fpga/KD/pes/prior_box_pe.hpp | 46 +
 lite/fpga/KD/pes/relu_pe.hpp | 75 +
 lite/fpga/KD/pes/resize.hpp | 89 +
 lite/fpga/KD/pes/scale_pe.hpp | 120 +
 lite/fpga/KD/pes/softmax_pe.cpp | 162 +
 lite/fpga/KD/pes/softmax_pe.hpp | 44 +
 lite/fpga/KD/pes/split_pe.hpp | 124 +
 lite/fpga/KD/shape.hpp | 116 +
 lite/fpga/KD/tensor.hpp | 456 +
 lite/fpga/KD/tensor_util.cpp | 32 +
 lite/fpga/KD/tensor_util.hpp | 25 +
 lite/fpga/lite_tensor.cc | 110 +
 lite/fpga/lite_tensor.h | 224 +
 lite/fpga/target_wrapper.cc | 37 +
 lite/gen_code/CMakeLists.txt | 49 +
 lite/gen_code/gen_code.cc | 223 +
 lite/gen_code/gen_code.h | 258 +
 lite/gen_code/gen_code_test.cc | 165 +
 lite/gen_code/generated_code_test.cc | 87 +
 lite/gen_code/paddle_code_generator.cc | 54 +
 lite/gen_code/paddle_infer.cc | 145 +
 lite/gen_code/paddle_infer.h | 72 +
 lite/host/CMakeLists.txt | 3 +
 lite/host/target_wrapper.cc | 49 +
 lite/kernels/CMakeLists.txt | 11 +
 lite/kernels/arm/CMakeLists.txt | 141 +
 lite/kernels/arm/activation_compute.cc | 196 +
 lite/kernels/arm/activation_compute.h | 108 +
 lite/kernels/arm/argmax_compute.cc | 47 +
 lite/kernels/arm/argmax_compute.h | 37 +
 lite/kernels/arm/argmax_compute_test.cc | 139 +
 lite/kernels/arm/axpy_compute.cc | 62 +
 lite/kernels/arm/axpy_compute.h | 37 +
 lite/kernels/arm/axpy_compute_test.cc | 142 +
 lite/kernels/arm/batch_norm_compute.cc | 123 +
 lite/kernels/arm/batch_norm_compute.h | 42 +
 lite/kernels/arm/batch_norm_compute_test.cc | 221 +
 lite/kernels/arm/beam_search_compute.cc | 60 +
 lite/kernels/arm/beam_search_compute.h | 42 +
 .../kernels/arm/beam_search_decode_compute.cc | 296 +
 lite/kernels/arm/beam_search_decode_compute.h | 39 +
 lite/kernels/arm/box_coder_compute.cc | 55 +
 lite/kernels/arm/box_coder_compute.h | 36 +
 lite/kernels/arm/calib_compute.cc | 90 +
 lite/kernels/arm/calib_compute.h | 51 +
 lite/kernels/arm/calib_compute_test.cc | 156 +
 lite/kernels/arm/cast_compute.cc | 50 +
 lite/kernels/arm/cast_compute.h | 42 +
 lite/kernels/arm/compare_compute.cc | 186 +
 lite/kernels/arm/compare_compute.h | 43 +
 lite/kernels/arm/concat_compute.cc | 87 +
 lite/kernels/arm/concat_compute.h | 37 +
 lite/kernels/arm/concat_compute_test.cc | 236 +
 lite/kernels/arm/conv_compute.cc | 241 +
 lite/kernels/arm/conv_compute.h | 67 +
 lite/kernels/arm/conv_compute_test.cc | 1045 ++
 lite/kernels/arm/conv_transpose_compute.cc | 164 +
 lite/kernels/arm/conv_transpose_compute.h | 40 +
 .../arm/conv_transpose_compute_test.cc | 371 +
 lite/kernels/arm/crop_compute.cc | 77 +
 lite/kernels/arm/crop_compute.h | 49 +
 lite/kernels/arm/decode_bboxes_compute.cc | 68 +
 lite/kernels/arm/decode_bboxes_compute.h | 36 +
 .../kernels/arm/decode_bboxes_compute_test.cc | 185 +
 lite/kernels/arm/density_prior_box_compute.cc | 122 +
 lite/kernels/arm/density_prior_box_compute.h | 37 +
 lite/kernels/arm/dropout_compute.cc | 51 +
 lite/kernels/arm/dropout_compute.h | 35 +
 lite/kernels/arm/dropout_compute_test.cc | 106 +
 lite/kernels/arm/elementwise_compute.cc | 280 +
 lite/kernels/arm/elementwise_compute.h | 76 +
 lite/kernels/arm/elementwise_compute_test.cc | 721 ++
 lite/kernels/arm/fc_compute.cc | 263 +
 lite/kernels/arm/fc_compute.h | 68 +
 lite/kernels/arm/fc_compute_test.cc | 211 +
 lite/kernels/arm/fill_constant_compute.cc | 54 +
 lite/kernels/arm/gru_compute.cc | 146 +
 lite/kernels/arm/gru_compute.h | 38 +
 lite/kernels/arm/gru_unit_compute.cc | 116 +
 lite/kernels/arm/gru_unit_compute.h | 38 +
 lite/kernels/arm/im2sequence_compute.cc | 141 +
 lite/kernels/arm/im2sequence_compute.h | 42 +
 lite/kernels/arm/increment_compute.cc | 49 +
 lite/kernels/arm/increment_compute.h | 42 +
 lite/kernels/arm/interpolate_compute.cc | 94 +
 lite/kernels/arm/interpolate_compute.h | 44 +
 lite/kernels/arm/is_empty_compute.cc | 47 +
 lite/kernels/arm/is_empty_compute.h | 40 +
 lite/kernels/arm/lod_reset_compute.cc | 64 +
 lite/kernels/arm/lod_reset_compute.h | 41 +
 lite/kernels/arm/logical_compute.cc | 131 +
 lite/kernels/arm/logical_compute.h | 53 +
 lite/kernels/arm/lookup_table_compute.cc | 77 +
 lite/kernels/arm/lookup_table_compute.h | 38 +
 lite/kernels/arm/lrn_compute.cc | 56 +
 lite/kernels/arm/lrn_compute.h | 36 +
 lite/kernels/arm/lrn_compute_test.cc | 196 +
 lite/kernels/arm/mul_compute.cc | 98 +
 lite/kernels/arm/mul_compute.h | 42 +
 lite/kernels/arm/mul_compute_test.cc | 182 +
 lite/kernels/arm/multiclass_nms_compute.cc | 103 +
 lite/kernels/arm/multiclass_nms_compute.h | 38 +
 .../arm/multiclass_nms_compute_test.cc | 374 +
 lite/kernels/arm/negative_compute.cc | 53 +
 lite/kernels/arm/negative_compute.h | 37 +
 lite/kernels/arm/norm_compute.cc | 50 +
 lite/kernels/arm/norm_compute.h | 42 +
 lite/kernels/arm/pad2d_compute.cc | 72 +
 lite/kernels/arm/pad2d_compute.h | 46 +
 lite/kernels/arm/pool_compute.cc | 228 +
 lite/kernels/arm/pool_compute.h | 38 +
 lite/kernels/arm/pool_compute_test.cc | 286 +
 lite/kernels/arm/power_compute.cc | 45 +
 lite/kernels/arm/power_compute.h | 34 +
 lite/kernels/arm/prior_box_compute.cc | 103 +
 lite/kernels/arm/prior_box_compute.h | 36 +
 lite/kernels/arm/read_from_array_compute.cc | 57 +
 lite/kernels/arm/read_from_array_compute.h | 43 +
 lite/kernels/arm/reduce_max_compute.cc | 91 +
 lite/kernels/arm/reduce_max_compute.h | 38 +
 lite/kernels/arm/scale_compute.cc | 46 +
 lite/kernels/arm/scale_compute.h | 34 +
 lite/kernels/arm/scale_compute_test.cc | 117 +
 lite/kernels/arm/sequence_expand_compute.cc | 54 +
 lite/kernels/arm/sequence_expand_compute.h | 39 +
 lite/kernels/arm/sequence_pool_compute.cc | 79 +
 lite/kernels/arm/sequence_pool_compute.h | 40 +
 lite/kernels/arm/sequence_softmax_compute.cc | 58 +
 lite/kernels/arm/sequence_softmax_compute.h | 43 +
 lite/kernels/arm/shape_compute.cc | 41 +
 lite/kernels/arm/shape_compute.h | 34 +
 lite/kernels/arm/shuffle_channel_compute.cc | 50 +
 lite/kernels/arm/shuffle_channel_compute.h | 35 +
 lite/kernels/arm/slice_compute.cc | 57 +
 lite/kernels/arm/slice_compute.h | 41 +
 lite/kernels/arm/softmax_compute.cc | 80 +
 lite/kernels/arm/softmax_compute.h | 35 +
 lite/kernels/arm/softmax_compute_test.cc | 135 +
 lite/kernels/arm/split_compute.cc | 46 +
 lite/kernels/arm/split_compute.h | 35 +
 lite/kernels/arm/split_compute_test.cc | 179 +
 lite/kernels/arm/topk_compute.cc | 47 +
 lite/kernels/arm/topk_compute.h | 34 +
 lite/kernels/arm/transpose_compute.cc | 185 +
 lite/kernels/arm/transpose_compute.h | 48 +
 lite/kernels/arm/transpose_compute_test.cc | 205 +
 lite/kernels/arm/while_compute.cc | 54 +
 lite/kernels/arm/while_compute.h | 83 +
 lite/kernels/arm/write_to_array_compute.cc | 61 +
 lite/kernels/arm/write_to_array_compute.h | 42 +
 lite/kernels/arm/yolo_box_compute.cc | 60 +
 lite/kernels/arm/yolo_box_compute.h | 34 +
 lite/kernels/cuda/CMakeLists.txt | 13 +
 lite/kernels/cuda/io_copy_compute.cc | 143 +
 lite/kernels/cuda/mul_compute.cc | 31 +
 lite/kernels/cuda/mul_compute.h | 83 +
 lite/kernels/cuda/use_kernels.h | 24 +
 lite/kernels/fpga/CMakeLists.txt | 50 +
 lite/kernels/fpga/activation_compute.cc | 53 +
 lite/kernels/fpga/activation_compute.h | 46 +
 lite/kernels/fpga/activation_compute_test.cc | 97 +
 lite/kernels/fpga/calib_compute.cc | 114 +
 lite/kernels/fpga/calib_compute.h | 51 +
 lite/kernels/fpga/conv_compute.cc | 66 +
 lite/kernels/fpga/conv_compute.h | 48 +
 lite/kernels/fpga/conv_compute_test.cc | 315 +
 lite/kernels/fpga/elementwise_compute.cc | 102 +
 lite/kernels/fpga/elementwise_compute.h | 62 +
 lite/kernels/fpga/elementwise_compute_test.cc | 286 +
 lite/kernels/fpga/fc_compute.cc | 65 +
 lite/kernels/fpga/fc_compute.h | 49 +
 lite/kernels/fpga/fc_compute_test.cc | 205 +
 lite/kernels/fpga/feed_compute.cc | 60 +
 lite/kernels/fpga/feed_compute.h | 42 +
 lite/kernels/fpga/fetch_compute.cc | 59 +
 lite/kernels/fpga/fetch_compute.h | 41 +
 lite/kernels/fpga/io_copy_compute.cc | 157 +
 lite/kernels/fpga/layout_compute.cc | 146 +
 lite/kernels/fpga/pooling_compute.cc | 68 +
 lite/kernels/fpga/pooling_compute.h | 46 +
 lite/kernels/fpga/pooling_compute_test.cc | 291 +
 lite/kernels/fpga/scale_compute.cc | 39 +
 lite/kernels/fpga/scale_compute.h | 35 +
 lite/kernels/fpga/softmax_compute.cc | 55 +
 lite/kernels/fpga/softmax_compute.h | 45 +
 lite/kernels/fpga/softmax_compute_test.cc | 136 +
 lite/kernels/host/CMakeLists.txt | 17 +
 lite/kernels/host/feed_compute.cc | 46 +
 lite/kernels/host/fetch_compute.cc | 53 +
 lite/kernels/host/reshape_compute.cc | 95 +
 lite/kernels/host/reshape_compute.h | 36 +
 lite/kernels/host/reshape_compute_test.cc | 101 +
 lite/kernels/host/use_kernels.h | 21 +
 lite/kernels/npu/CMakeLists.txt | 13 +
 lite/kernels/npu/graph_compute.cc | 143 +
 lite/kernels/npu/graph_compute.h | 56 +
 lite/kernels/opencl/CMakeLists.txt | 61 +
 lite/kernels/opencl/conv_compute.cc | 296 +
 lite/kernels/opencl/conv_compute.h | 63 +
 lite/kernels/opencl/conv_compute_test.cc | 602 ++
 .../opencl/depthwise_conv2d_compute.cc | 132 +
 .../opencl/depthwise_conv2d_compute_test.cc | 181 +
 .../kernels/opencl/elementwise_add_compute.cc | 107 +
 lite/kernels/opencl/elementwise_add_compute.h | 51 +
 .../opencl/elementwise_add_compute_test.cc | 251 +
 lite/kernels/opencl/fc_compute.cc | 126 +
 lite/kernels/opencl/fc_compute_test.cc | 200 +
 ...sion_elementwise_add_activation_compute.cc | 56 +
 lite/kernels/opencl/io_copy_compute.cc | 145 +
 lite/kernels/opencl/io_copy_compute_test.cc | 83 +
 lite/kernels/opencl/mul_compute.cc | 119 +
 lite/kernels/opencl/mul_compute_test.cc | 170 +
 lite/kernels/opencl/pool_compute.cc | 127 +
 lite/kernels/opencl/pool_compute_test.cc | 147 +
 lite/kernels/opencl/relu_compute.cc | 91 +
 lite/kernels/opencl/relu_compute_test.cc | 94 +
 lite/kernels/x86/CMakeLists.txt | 53 +
 lite/kernels/x86/activation_compute.cc | 127 +
 lite/kernels/x86/batch_norm_compute.cc | 34 +
 lite/kernels/x86/batch_norm_compute.h | 159 +
 lite/kernels/x86/batch_norm_compute_test.cc | 139 +
 lite/kernels/x86/concat_compute.cc | 25 +
 lite/kernels/x86/concat_compute.h | 103 +
 lite/kernels/x86/concat_compute_test.cc | 83 +
 lite/kernels/x86/conv_compute.cc | 39 +
 lite/kernels/x86/conv_compute.h | 167 +
 lite/kernels/x86/conv_compute_test.cc | 92 +
 lite/kernels/x86/dropout_compute.cc | 26 +
 lite/kernels/x86/dropout_compute.h | 82 +
 lite/kernels/x86/dropout_compute_test.cc | 78 +
 lite/kernels/x86/elementwise_compute.cc | 55 +
 lite/kernels/x86/elementwise_compute.h | 142 +
 lite/kernels/x86/elementwise_compute_test.cc | 88 +
 lite/kernels/x86/fc_compute.cc | 23 +
 lite/kernels/x86/fc_compute.h | 106 +
 lite/kernels/x86/fc_compute_test.cc | 100 +
 lite/kernels/x86/fill_constant_compute.cc | 59 +
 lite/kernels/x86/mean_compute.cc | 108 +
 lite/kernels/x86/mul_compute.cc | 42 +
 lite/kernels/x86/mul_compute.h | 149 +
 lite/kernels/x86/mul_compute_test.cc | 86 +
 lite/kernels/x86/pool_compute.cc | 25 +
 lite/kernels/x86/pool_compute.h | 87 +
 lite/kernels/x86/pool_compute_test.cc | 79 +
 lite/kernels/x86/relu_compute.cc | 25 +
 lite/kernels/x86/relu_compute.h | 52 +
 lite/kernels/x86/relu_compute_test.cc | 75 +
 lite/kernels/x86/scale_compute.cc | 25 +
 lite/kernels/x86/scale_compute.h | 58 +
 lite/kernels/x86/scale_compute_test.cc | 76 +
 lite/kernels/x86/sgd_compute.cc | 82 +
 lite/kernels/x86/softmax_compute.cc | 25 +
 lite/kernels/x86/softmax_compute.h | 90 +
 lite/kernels/x86/softmax_compute_test.cc | 74 +
 lite/kernels/x86/uniform_random_compute.cc | 70 +
 lite/model_parser/CMakeLists.txt | 33 +
 lite/model_parser/compatible_pb.cc | 286 +
 lite/model_parser/compatible_pb.h | 71 +
 lite/model_parser/compatible_pb_test.cc | 433 +
 lite/model_parser/cpp/CMakeLists.txt | 6 +
 lite/model_parser/cpp/block_desc.cc | 47 +
 lite/model_parser/cpp/block_desc.h | 75 +
 lite/model_parser/cpp/op_desc.cc | 122 +
 lite/model_parser/cpp/op_desc.h | 122 +
 lite/model_parser/cpp/program_desc.cc | 35 +
 lite/model_parser/cpp/program_desc.h | 57 +
 lite/model_parser/cpp/var_desc.cc | 15 +
 lite/model_parser/cpp/var_desc.h | 53 +
 lite/model_parser/desc_apis.h | 229 +
 lite/model_parser/model_parser.cc | 501 +
 lite/model_parser/model_parser.h | 83 +
 lite/model_parser/model_parser_test.cc | 111 +
 lite/model_parser/naive_buffer/CMakeLists.txt | 18 +
 lite/model_parser/naive_buffer/block_desc.cc | 103 +
 lite/model_parser/naive_buffer/block_desc.h | 86 +
 .../model_parser/naive_buffer/naive_buffer.cc | 136 +
 lite/model_parser/naive_buffer/naive_buffer.h | 373 +
 .../naive_buffer/naive_buffer_test.cc | 178 +
 .../naive_buffer_wrapper_helper.h | 47 +
 .../naive_buffer/naive_buffer_wrapper_test.cc | 235 +
 lite/model_parser/naive_buffer/op_desc.cc | 129 +
 lite/model_parser/naive_buffer/op_desc.h | 234 +
 lite/model_parser/naive_buffer/param_desc.cc | 218 +
 lite/model_parser/naive_buffer/param_desc.h | 88 +
 .../model_parser/naive_buffer/program_desc.cc | 58 +
 lite/model_parser/naive_buffer/program_desc.h | 66 +
 .../naive_buffer/proto/CMakeLists.txt | 1 +
 .../naive_buffer/proto/framework.nb.cc | 15 +
 .../naive_buffer/proto/framework.nb.h | 200 +
 lite/model_parser/naive_buffer/var_desc.cc | 109 +
 lite/model_parser/naive_buffer/var_desc.h | 63 +
 lite/model_parser/pb/CMakeLists.txt | 6 +
 lite/model_parser/pb/block_desc.cc | 47 +
 lite/model_parser/pb/block_desc.h | 80 +
 lite/model_parser/pb/op_desc.cc | 132 +
 lite/model_parser/pb/op_desc.h | 215 +
 lite/model_parser/pb/program_desc.cc | 36 +
 lite/model_parser/pb/program_desc.h | 62 +
 lite/model_parser/pb/var_desc.cc | 317 +
 lite/model_parser/pb/var_desc.h | 125 +
 lite/model_parser/runtime.cc | 109 +
 lite/model_parser/runtime.h | 122 +
 lite/npu/CMakeLists.txt | 6 +
 lite/npu/bridge/CMakeLists.txt | 46 +
 lite/npu/bridge/act_op.cc | 87 +
 lite/npu/bridge/act_op_test.cc | 100 +
 lite/npu/bridge/batch_norm_op.cc | 94 +
 lite/npu/bridge/batch_norm_op_test.cc | 166 +
 lite/npu/bridge/conv_op.cc | 198 +
 lite/npu/bridge/conv_op_test.cc | 255 +
 lite/npu/bridge/elementwise_ops.cc | 76 +
 lite/npu/bridge/elementwise_ops_test.cc | 182 +
 lite/npu/bridge/fc_op.cc | 119 +
 lite/npu/bridge/fc_op_test.cc | 146 +
 lite/npu/bridge/mul_op.cc | 122 +
 lite/npu/bridge/mul_op_test.cc | 125 +
 lite/npu/bridge/paddle_use_npu_bridges.h | 27 +
 lite/npu/bridge/pool_op.cc | 87 +
 lite/npu/bridge/pool_op_test.cc | 249 +
 lite/npu/bridge/registry.cc | 39 +
 lite/npu/bridge/registry.h | 84 +
 lite/npu/bridge/scale_op.cc | 88 +
 lite/npu/bridge/scale_op_test.cc | 123 +
 lite/npu/bridge/softmax_op.cc | 65 +
 lite/npu/bridge/softmax_op_test.cc | 134 +
 lite/npu/bridge/test_helper.cc | 101 +
 lite/npu/bridge/test_helper.h | 64 +
 lite/npu/bridge/transpose_op.cc | 76 +
 lite/npu/bridge/transpose_op_test.cc | 150 +
 lite/npu/bridge/utils.cc | 179 +
 lite/npu/bridge/utils.h | 83 +
 lite/npu/npu_helper.cc | 137 +
 lite/npu/npu_helper.h | 110 +
 lite/opencl/CMakeLists.txt | 18 +
 lite/opencl/cl_caller.cc | 169 +
 lite/opencl/cl_caller.h | 52 +
 lite/opencl/cl_context.cc | 126 +
 lite/opencl/cl_context.h | 54 +
 lite/opencl/cl_functions_test.cc | 451 +
 lite/opencl/cl_im2col_test.cc | 330 +
 lite/opencl/cl_image.cc | 160 +
 lite/opencl/cl_image.h | 114 +
 lite/opencl/cl_image_converter.cc | 461 +
 lite/opencl/cl_image_converter.h | 139 +
 lite/opencl/cl_include.h | 21 +
 .../buffer/depthwise_conv2d_kernel.cl | 70 +
 .../buffer/elementwise_add_kernel.cl | 45 +
 lite/opencl/cl_kernel/buffer/fc_kernel.cl | 424 +
 lite/opencl/cl_kernel/buffer/im2col_kernel.cl | 64 +
 .../opencl/cl_kernel/buffer/mat_mul_kernel.cl | 93 +
 lite/opencl/cl_kernel/buffer/pool_kernel.cl | 112 +
 lite/opencl/cl_kernel/buffer/relu_kernel.cl | 22 +
 lite/opencl/cl_kernel/cl_common.h | 38 +
 .../cl_kernel/image/channel_add_kernel.cl | 29 +
 .../cl_kernel/image/elementwise_add_kernel.cl | 26 +
 lite/opencl/cl_kernel/image/pool_kernel.cl | 90 +
 lite/opencl/cl_runtime.cc | 170 +
 lite/opencl/cl_runtime.h | 101 +
 lite/opencl/cl_utility.cc | 84 +
 lite/opencl/cl_utility.h | 46 +
 lite/opencl/cl_wrapper.cc | 732 ++
 lite/opencl/cl_wrapper.h | 572 +
 lite/opencl/target_wrapper.cc | 341 +
 lite/opencl/target_wrapper.h | 83 +
 lite/operators/CMakeLists.txt | 187 +
 lite/operators/activation_ops.cc | 115 +
 lite/operators/activation_ops.h | 63 +
 lite/operators/argmax_op.cc | 62 +
 lite/operators/argmax_op.h | 48 +
 lite/operators/axpy_op.cc | 63 +
 lite/operators/axpy_op.h | 48 +
 lite/operators/batch_norm_op.cc | 112 +
 lite/operators/batch_norm_op.h | 46 +
 lite/operators/batch_norm_op_test.cc | 139 +
 lite/operators/beam_search_decode_op.cc | 59 +
 lite/operators/beam_search_decode_op.h | 47 +
 lite/operators/beam_search_op.cc | 69 +
 lite/operators/beam_search_op.h | 47 +
 lite/operators/box_coder_op.cc | 59 +
 lite/operators/box_coder_op.h | 45 +
 lite/operators/calib_once_op.cc | 30 +
 lite/operators/calib_once_op.h | 33 +
 lite/operators/calib_op.cc | 52 +
 lite/operators/calib_op.h | 59 +
 lite/operators/calib_op_test.cc | 62 +
 lite/operators/cast_op.cc | 52 +
 lite/operators/cast_op.h | 47 +
 lite/operators/compare_op.cc | 61 +
 lite/operators/compare_op.h | 47 +
 lite/operators/concat_op.cc | 77 +
 lite/operators/concat_op.h | 46 +
 lite/operators/concat_op_test.cc | 59 +
 lite/operators/conv_op.cc | 80 +
 lite/operators/conv_op.h | 107 +
 lite/operators/conv_transpose_op.cc | 97 +
 lite/operators/conv_transpose_op.h | 51 +
 lite/operators/crop_op.cc | 55 +
 lite/operators/crop_op.h | 46 +
 lite/operators/decode_bboxes_op.cc | 60 +
 lite/operators/decode_bboxes_op.h | 45 +
 lite/operators/density_prior_box_op.cc | 80 +
 lite/operators/density_prior_box_op.h | 46 +
 lite/operators/dropout_op.cc | 78 +
 lite/operators/elementwise_ops.cc | 96 +
 lite/operators/elementwise_ops.h | 66 +
 lite/operators/fake_dequantize_max_abs.cc | 25 +
 lite/operators/fake_dequantize_max_abs.h | 64 +
 .../fake_quantize_moving_avg_max_abs.cc | 25 +
 .../fake_quantize_moving_avg_max_abs.h | 69 +
 lite/operators/fc_op.cc | 107 +
 lite/operators/fc_op.h | 61 +
 lite/operators/fc_op_test.cc | 78 +
 lite/operators/feed_op.cc | 65 +
 lite/operators/fetch_op.cc | 60 +
 lite/operators/fill_constant_op.cc | 59 +
 .../fusion_elementwise_activation_ops.cc | 105 +
 .../fusion_elementwise_activation_ops.h | 71 +
 .../fusion_elementwise_activation_ops_test.cc | 63 +
 lite/operators/graph_op.cc | 52 +
 lite/operators/graph_op.h | 52 +
 lite/operators/gru_op.cc | 108 +
 lite/operators/gru_op.h | 46 +
 lite/operators/gru_unit_op.cc | 105 +
 lite/operators/gru_unit_op.h | 46 +
 lite/operators/im2sequence_op.cc | 77 +
 lite/operators/im2sequence_op.h | 47 +
 lite/operators/increment_op.cc | 51 +
 lite/operators/increment_op.h | 47 +
 lite/operators/interpolate_op.cc | 92 +
 lite/operators/interpolate_op.h | 47 +
 lite/operators/io_copy_once_op.cc | 30 +
 lite/operators/io_copy_once_op.h | 33 +
 lite/operators/io_copy_op.cc | 46 +
 lite/operators/io_copy_op.h | 42 +
 lite/operators/is_empty_op.cc | 40 +
 lite/operators/is_empty_op.h | 47 +
 lite/operators/layout_once_op.cc | 30 +
 lite/operators/layout_once_op.h | 33 +
 lite/operators/layout_op.cc | 46 +
 lite/operators/layout_op.h | 42 +
 lite/operators/lod_reset_op.cc | 60 +
 lite/operators/lod_reset_op.h | 47 +
 lite/operators/logical_op.cc | 80 +
 lite/operators/logical_op.h | 66 +
 lite/operators/lookup_table_op.cc | 75 +
 lite/operators/lookup_table_op.h | 46 +
 lite/operators/lrn_op.cc | 52 +
 lite/operators/lrn_op.h | 44 +
 lite/operators/mean_op.cc | 100 +
 lite/operators/mul_op.cc | 122 +
 lite/operators/mul_op.h | 93 +
 lite/operators/multiclass_nms_op.cc | 59 +
 lite/operators/multiclass_nms_op.h | 45 +
 lite/operators/negative_op.cc | 51 +
 lite/operators/negative_op.h | 46 +
 lite/operators/norm_op.cc | 52 +
 lite/operators/norm_op.h | 47 +
 lite/operators/op_params.cc | 15 +
 lite/operators/op_params.h | 699 ++
 lite/operators/pad2d_op.cc | 58 +
 lite/operators/pad2d_op.h | 46 +
 lite/operators/pool_op.cc | 90 +
 lite/operators/pool_op.h | 82 +
 lite/operators/pool_op_test.cc | 90 +
 lite/operators/power_op.cc | 53 +
 lite/operators/power_op.h | 47 +
 lite/operators/prior_box_op.cc | 75 +
 lite/operators/prior_box_op.h | 45 +
 lite/operators/read_from_array_op.cc | 47 +
 lite/operators/read_from_array_op.h | 47 +
 lite/operators/reduce_max_op.cc | 112 +
 lite/operators/reduce_max_op.h | 43 +
 lite/operators/relu_op.cc | 49 +
 lite/operators/relu_op.h | 46 +
 lite/operators/reshape_op.cc | 150 +
 lite/operators/reshape_op.h | 63 +
 lite/operators/reshape_op_test.cc | 145 +
 lite/operators/scale_op.cc | 49 +
 lite/operators/scale_op.h | 46 +
 lite/operators/scale_op_test.cc | 58 +
 lite/operators/sequence_expand_op.cc | 93 +
 lite/operators/sequence_expand_op.h | 46 +
 lite/operators/sequence_pool_op.cc | 55 +
 lite/operators/sequence_pool_op.h | 43 +
 lite/operators/sequence_softmax_op.cc | 50 +
 lite/operators/sequence_softmax_op.h | 47 +
 lite/operators/sgd_op.cc | 55 +
 lite/operators/sgd_op.h | 50 +
 lite/operators/shape_op.cc | 49 +
 lite/operators/shape_op.h | 44 +
 lite/operators/shuffle_channel_op.cc | 52 +
 lite/operators/shuffle_channel_op.h | 50 +
 lite/operators/slice_op.cc | 90 +
 lite/operators/slice_op.h | 47 +
 lite/operators/softmax_op.cc | 59 +
 lite/operators/softmax_op.h | 46 +
 lite/operators/softmax_op_test.cc | 54 +
 lite/operators/split_op.cc | 82 +
 lite/operators/split_op.h | 46 +
 lite/operators/topk_op.cc | 59 +
 lite/operators/topk_op.h | 46 +
 lite/operators/transpose_op.cc | 165 +
 lite/operators/transpose_op.h | 66 +
 lite/operators/transpose_op_test.cc | 93 +
 lite/operators/uniform_random_op.cc | 45 +
 lite/operators/uniform_random_op.h | 50 +
 lite/operators/while_op.cc | 55 +
 lite/operators/while_op.h | 48 +
 lite/operators/write_to_array_op.cc | 48 +
 lite/operators/write_to_array_op.h | 47 +
 lite/operators/yolo_box_op.cc | 70 +
 lite/operators/yolo_box_op.h | 46 +
 lite/python/lite_test.py | 103 +
 lite/tests/CMakeLists.txt | 1 +
 lite/tests/README.md | 1 +
 lite/tests/kernels/CMakeLists.txt | 37 +
 lite/tests/kernels/activation_compute_test.cc | 449 +
 lite/tests/kernels/argmax_compute_test.cc | 130 +
 lite/tests/kernels/axpy_compute_test.cc | 136 +
 .../kernels/bilinear_interp_compute_test.cc | 282 +
 lite/tests/kernels/box_coder_compute_test.cc | 212 +
 lite/tests/kernels/compare_compute_test.cc | 243 +
 .../kernels/conv2d_transpose_compute_test.cc | 465 +
 lite/tests/kernels/crop_compute_test.cc | 129 +
 .../kernels/decode_bboxes_compute_test.cc | 225 +
 .../tests/kernels/elementwise_compute_test.cc | 415 +
 lite/tests/kernels/fc_compute_test.cc | 201 +
 lite/tests/kernels/fill_data.h | 33 +
 lite/tests/kernels/gru_unit_test.cc | 363 +
 .../tests/kernels/im2sequence_compute_test.cc | 249 +
 lite/tests/kernels/increment_compute_test.cc | 94 +
 lite/tests/kernels/logical_compute_test.cc | 106 +
 lite/tests/kernels/lrn_compute_test.cc | 206 +
 .../kernels/multiclass_nms_compute_test.cc | 181 +
 .../kernels/nearest_interp_compute_test.cc | 190 +
 lite/tests/kernels/negative_compute_test.cc | 80 +
 lite/tests/kernels/norm_compute_test.cc | 110 +
 lite/tests/kernels/pad2d_compute_test.cc | 182 +
 lite/tests/kernels/power_compute_test.cc | 99 +
 lite/tests/kernels/prior_box_compute_test.cc | 752 ++
 .../kernels/read_from_array_compute_test.cc | 105 +
 lite/tests/kernels/reduce_max_compute_test.cc | 347 +
 lite/tests/kernels/scale_compute_test.cc | 125 +
 .../kernels/sequence_expand_compute_test.cc | 188 +
 .../kernels/sequence_pool_compute_test.cc | 195 +
 .../kernels/sequence_softmax_compute_test.cc | 123 +
 lite/tests/kernels/shape_compute_test.cc | 86 +
 .../kernels/shuffle_channel_compute_test.cc | 110 +
 lite/tests/kernels/test_funcs.h | 191 +
 lite/tests/kernels/test_sgemm.cc | 353 +
 lite/tests/kernels/topk_compute_test.cc | 119 +
 .../kernels/write_to_array_compute_test.cc | 116 +
 lite/tests/kernels/yolo_box_compute_test.cc | 254 +
 lite/tools/CMakeLists.txt | 1 +
 lite/tools/Dockerfile.mobile | 96 +
 lite/tools/build.sh | 153 +
 lite/tools/build_fpga.sh | 26 +
 lite/tools/build_ios_armv7_arm64.sh | 25 +
 lite/tools/ci_build.sh | 935 ++
 lite/tools/debug/CMakeLists.txt | 15 +
 lite/tools/debug/analysis_tool.py | 401 +
 lite/tools/debug/check_model.sh | 182 +
 lite/tools/debug/debug_utils.cc | 15 +
 lite/tools/debug/debug_utils.h | 337 +
 lite/tools/debug/model_debug_tool.cc | 109 +
 lite/tools/gitlab_review.sh | 75 +
 lite/tools/mobile_readme.md | 135 +
 lite/utils/CMakeLists.txt | 26 +
 lite/utils/all.h | 28 +
 lite/utils/any.cc | 23 +
 lite/utils/any.h | 71 +
 lite/utils/check.h | 41 +
 lite/utils/container.h | 51 +
 lite/utils/cp_logging.cc | 19 +
 lite/utils/cp_logging.h | 20 +
 lite/utils/factory.h | 100 +
 lite/utils/hash.h | 28 +
 lite/utils/io.h | 46 +
 lite/utils/logging.cc | 62 +
 lite/utils/logging.h | 184 +
 lite/utils/logging_test.cc | 31 +
 lite/utils/macros.h | 55 +
 lite/utils/paddle_enforce.h | 39 +
 lite/utils/replace_stl/stream.cc | 105 +
 lite/utils/replace_stl/stream.h | 76 +
 lite/utils/string.cc | 19 +
 lite/utils/string.h | 97 +
 lite/utils/varient.h | 151 +
 lite/utils/varient_test.cc | 58 +
 lite/x86/CMakeLists.txt | 14 +
 lite/x86/cpu_info.cc | 160 +
 lite/x86/cpu_info.h | 80 +
 lite/x86/cupti_lib_path.h.in | 17 +
 lite/x86/dynamic_loader.cc | 263 +
 lite/x86/dynamic_loader.h | 38 +
 lite/x86/jit/CMakeLists.txt | 26 +
 lite/x86/jit/README.en.md | 103 +
 lite/x86/jit/README.md | 94 +
 lite/x86/jit/benchmark.cc | 576 +
 lite/x86/jit/gen/CMakeLists.txt | 36 +
 lite/x86/jit/gen/act.cc | 164 +
 lite/x86/jit/gen/act.h | 347 +
 lite/x86/jit/gen/blas.cc | 190 +
 lite/x86/jit/gen/blas.h | 125 +
 lite/x86/jit/gen/embseqpool.cc | 148 +
 lite/x86/jit/gen/embseqpool.h | 81 +
 lite/x86/jit/gen/gru.cc | 116 +
 lite/x86/jit/gen/gru.h | 116 +
 lite/x86/jit/gen/hopv.cc | 103 +
 lite/x86/jit/gen/hopv.h | 92 +
 lite/x86/jit/gen/jitcode.h | 133 +
 lite/x86/jit/gen/lstm.cc | 142 +
 lite/x86/jit/gen/lstm.h | 121 +
 lite/x86/jit/gen/matmul.cc | 127 +
 lite/x86/jit/gen/matmul.h | 62 +
 lite/x86/jit/gen/seqpool.cc | 85 +
 lite/x86/jit/gen/seqpool.h | 216 +
 lite/x86/jit/gen/sgd.cc | 130 +
 lite/x86/jit/gen/sgd.h | 60 +
 lite/x86/jit/gen/vbroadcast.cc | 91 +
 lite/x86/jit/gen/vbroadcast.h | 54 +
 lite/x86/jit/gen_base.cc | 95 +
 lite/x86/jit/gen_base.h | 87 +
 lite/x86/jit/helper.cc | 139 +
 lite/x86/jit/helper.h | 267 +
 lite/x86/jit/kernel_base.h | 365 +
 lite/x86/jit/kernel_key.cc | 71 +
 lite/x86/jit/kernel_key.h | 55 +
 lite/x86/jit/kernel_pool.cc | 41 +
 lite/x86/jit/kernel_pool.h | 116 +
 lite/x86/jit/macro.h | 32 +
 lite/x86/jit/more/CMakeLists.txt | 18 +
 lite/x86/jit/more/intrinsic/CMakeLists.txt | 9 +
 lite/x86/jit/more/intrinsic/crf_decoding.cc | 185 +
 lite/x86/jit/more/intrinsic/crf_decoding.h | 45 +
 lite/x86/jit/more/intrinsic/layer_norm.cc | 181 +
 lite/x86/jit/more/intrinsic/layer_norm.h | 48 +
 lite/x86/jit/more/mix/CMakeLists.txt | 15 +
 lite/x86/jit/more/mix/mix.cc | 255 +
 lite/x86/jit/more/mix/mix.h | 65 +
 lite/x86/jit/more/mkl/CMakeLists.txt | 20 +
 lite/x86/jit/more/mkl/mkl.cc | 336 +
 lite/x86/jit/more/mkl/mkl.h | 244 +
 lite/x86/jit/refer/CMakeLists.txt | 40 +
 lite/x86/jit/refer/refer.cc | 61 +
 lite/x86/jit/refer/refer.h | 603 ++
 lite/x86/jit/registry.h | 178 +
 lite/x86/jit/test.cc | 1447 +++
 lite/x86/legacy_place.h | 30 +
 lite/x86/math/CMakeLists.txt | 62 +
 lite/x86/math/beam_search.cc | 322 +
 lite/x86/math/beam_search.h | 125 +
 lite/x86/math/beam_search_test.cc | 152 +
 lite/x86/math/blas.cc | 57 +
 lite/x86/math/blas.h | 408 +
 lite/x86/math/blas_impl.h | 812 ++
 lite/x86/math/concat_and_split.cc | 131 +
 lite/x86/math/concat_and_split.h | 83 +
 lite/x86/math/context_project.cc | 28 +
 lite/x86/math/context_project.h | 361 +
 lite/x86/math/cos_sim_functor.cc | 57 +
 lite/x86/math/cos_sim_functor.h | 187 +
 lite/x86/math/cpu_vec.h | 662 ++
 lite/x86/math/cross_entropy.cc | 78 +
 lite/x86/math/cross_entropy.h | 74 +
 lite/x86/math/detail/CMakeLists.txt | 1 +
 lite/x86/math/detail/activation_functions.h | 193 +
 lite/x86/math/detail/avx_functions.cc | 91 +
 lite/x86/math/detail/avx_mathfun.h | 731 ++
 lite/x86/math/detail/gru_cpu_kernel.h | 608 ++
 lite/x86/math/detail/gru_kernel.h | 222 +
 lite/x86/math/gru_compute.cc | 181 +
 lite/x86/math/gru_compute.h | 69 +
 lite/x86/math/im2col.cc | 292 +
 lite/x86/math/im2col.h | 108 +
 lite/x86/math/im2col_cfo_cpu.h | 256 +
 lite/x86/math/im2col_test.cc | 331 +
 lite/x86/math/math_function.cc | 158 +
 lite/x86/math/math_function.h | 93 +
 lite/x86/math/math_function_impl.h | 192 +
 lite/x86/math/math_function_test.cc | 344 +
 lite/x86/math/maxouting.cc | 106 +
 lite/x86/math/maxouting.h | 47 +
 lite/x86/math/pooling.cc | 906 ++
 lite/x86/math/pooling.h | 258 +
 lite/x86/math/prelu.h | 51 +
 lite/x86/math/sample_prob.cc | 28 +
 lite/x86/math/sample_prob.h | 128 +
 lite/x86/math/sampler.cc | 102 +
 lite/x86/math/sampler.h | 131 +
 lite/x86/math/sequence_pooling.cc | 406 +
 lite/x86/math/sequence_pooling.h | 52 +
 lite/x86/math/sequence_pooling_test.cc | 130 +
 lite/x86/math/softmax.cc | 33 +
 lite/x86/math/softmax.h | 67 +
 lite/x86/math/softmax_impl.h | 244 +
 lite/x86/math/tree2col.cc | 204 +
 lite/x86/math/tree2col.h | 95 +
 lite/x86/math/unpooling.cc | 96 +
 lite/x86/math/unpooling.h | 44 +
 lite/x86/math/vol2col.cc | 204 +
 lite/x86/math/vol2col.h | 92 +
 lite/x86/mklml.cc | 30 +
 lite/x86/mklml.h | 99 +
 lite/x86/port.h | 175 +
 lite/x86/target_wrapper.cc | 36 +
 lite/x86/target_wrapper.h | 22 +
 lite/x86/warpctc_lib_path.h.in | 17 +
 .../AppIcon.appiconset/Contents.json | 2 +-
 .../Assets.xcassets/Contents.json | 2 +-
 .../AppIcon.appiconset/Contents.json | 2 +-
 .../Assets.xcassets/Contents.json | 2 +-
 .../AppIcon.appiconset/Contents.json | 2 +-
 .../Assets.xcassets/Contents.json | 2 +-
 .../paddle-mobile.imageset/Contents.json | 2 +-
 .../OCInterface/PaddleMobileGPU.h | 2 -
 .../AppIcon.appiconset/Contents.json | 2 +-
 .../Assets.xcassets/Contents.json | 2 +-
 .../paddle-mobile/paddle_mobile.h | 2 -
 mobile/CMakeLists.txt | 288 +
 CONTRIBUTING.md => mobile/CONTRIBUTING.md | 2 -
 Dockerfile => mobile/Dockerfile | 0
 LICENSE => mobile/LICENSE | 0
 mobile/README.md | 137 +
 .../benchmark}/arm_benchmark.md | 2 +-
 .../benchmark}/metal_benchmark.md | 2 +-
 {demo => mobile/demo}/ReadMe.md | 0
 {demo => mobile/demo}/getDemo.sh | 2 +-
 {doc => mobile/doc}/build.md | 0
 {doc => mobile/doc}/design_doc.md | 6 -
 {doc => mobile/doc}/development_android.md | 1 -
 .../doc}/development_android_GPU.md | 0
 {doc => mobile/doc}/development_arm_linux.md | 1 -
 {doc => mobile/doc}/development_fpga.md | 0
 {doc => mobile/doc}/development_ios.md | 0
 {doc => mobile/doc}/quantification.md | 6 -
 {src => mobile/src}/common/common.h | 0
 {src => mobile/src}/common/enforce.h | 0
 {src => mobile/src}/common/log.h | 0
 {src => mobile/src}/common/threadpool.h | 0
 {src => mobile/src}/common/type_define.h | 0
 {src => mobile/src}/common/types.cpp | 0
 {src => mobile/src}/common/types.h | 0
 {src => mobile/src}/common/util.cpp | 0
 {src => mobile/src}/common/util.h | 0
 {src => mobile/src}/common/variant.h | 0
 {src => mobile/src}/fpga/KD/alignment.h | 0
 {src => mobile/src}/fpga/KD/context.hpp | 0
 {src => mobile/src}/fpga/KD/dl_engine.cpp | 0
 {src => mobile/src}/fpga/KD/dl_engine.hpp | 0
 {src => mobile/src}/fpga/KD/float16.hpp | 0
 {src => mobile/src}/fpga/KD/layout.hpp | 0
 .../src}/fpga/KD/llapi/bias_scale.cpp | 0
 .../src}/fpga/KD/llapi/bias_scale.h | 0
 {src => mobile/src}/fpga/KD/llapi/config.h | 0
 {src => mobile/src}/fpga/KD/llapi/filter.cpp | 0
 {src => mobile/src}/fpga/KD/llapi/filter.h | 0
 {src => mobile/src}/fpga/KD/llapi/image.cpp | 0
 {src => mobile/src}/fpga/KD/llapi/image.h | 0
 .../src}/fpga/KD/llapi/zynqmp_api.cpp | 0
 .../src}/fpga/KD/llapi/zynqmp_api.h | 0
 {src => mobile/src}/fpga/KD/pe.hpp | 0
 {src => mobile/src}/fpga/KD/pe_params.hpp | 0
 {src => mobile/src}/fpga/KD/pes/concat_pe.hpp | 0
 {src => mobile/src}/fpga/KD/pes/conv_pe.hpp | 0
 .../src}/fpga/KD/pes/conv_process.hpp | 0
 .../src}/fpga/KD/pes/depthwise_conv_pe.hpp | 0
 .../src}/fpga/KD/pes/elementwise_add_pe.hpp | 0
 .../src}/fpga/KD/pes/fully_connected_pe.hpp | 0
 {src => mobile/src}/fpga/KD/pes/input_pe.hpp | 0
 .../src}/fpga/KD/pes/math_func_neon.h | 0
 {src => mobile/src}/fpga/KD/pes/output_pe.hpp | 0
 .../src}/fpga/KD/pes/pooling_pe.hpp | 0
 .../src}/fpga/KD/pes/softmax_pe.cpp | 0
 .../src}/fpga/KD/pes/softmax_pe.hpp | 0
 {src => mobile/src}/fpga/KD/shape.hpp | 0
 {src => mobile/src}/fpga/KD/tensor.hpp | 0
 {src => mobile/src}/fpga/KD/tensor_util.cpp | 0
 {src => mobile/src}/fpga/KD/tensor_util.hpp | 0
 {src => mobile/src}/fpga/V1/api.cpp | 0
 {src => mobile/src}/fpga/V1/api.h | 0
 {src => mobile/src}/fpga/V1/bias_scale.cpp | 0
 {src => mobile/src}/fpga/V1/bias_scale.h | 0
 .../src}/fpga/V1/deconv_bias_scale.cpp | 0
 .../src}/fpga/V1/deconv_bias_scale.h | 0
 {src => mobile/src}/fpga/V1/deconv_filter.cpp | 0
 {src => mobile/src}/fpga/V1/deconv_filter.h | 0
 {src => mobile/src}/fpga/V1/filter.cpp | 0
 {src => mobile/src}/fpga/V1/filter.h | 0
 {src => mobile/src}/fpga/V1/image.cpp | 0
 {src => mobile/src}/fpga/V1/image.h | 0
 {src => mobile/src}/fpga/V1/pe.cpp | 0
 {src => mobile/src}/fpga/V2/api.cpp | 0
 {src => mobile/src}/fpga/V2/api.h | 0
 {src => mobile/src}/fpga/V2/bias_scale.cpp | 0
 {src => mobile/src}/fpga/V2/bias_scale.h | 0
 .../src}/fpga/V2/deconv_bias_scale.cpp | 0
 .../src}/fpga/V2/deconv_bias_scale.h | 0
 {src => mobile/src}/fpga/V2/deconv_filter.cpp | 0
 {src => mobile/src}/fpga/V2/deconv_filter.h | 0
 {src => mobile/src}/fpga/V2/filter.cpp | 0
 {src => mobile/src}/fpga/V2/filter.h | 0
 {src => mobile/src}/fpga/V2/image.cpp | 0
 {src => mobile/src}/fpga/V2/image.h | 0
 {src => mobile/src}/fpga/V2/pe.cpp | 0
 {src => mobile/src}/fpga/common/config.h | 0
 {src => mobile/src}/fpga/common/driver.cpp | 0
 {src => mobile/src}/fpga/common/driver.h | 0
 .../src}/fpga/common/fpga_common.cpp | 0
 {src => mobile/src}/fpga/common/fpga_common.h | 0
 {src => mobile/src}/fpga/common/pe.h | 0
 {src => mobile/src}/framework/CMakeLists.txt | 0
 {src => mobile/src}/framework/attribute.cpp | 0
 {src => mobile/src}/framework/attribute.h | 0
 {src => mobile/src}/framework/cl/cl_deleter.h | 0
 .../src}/framework/cl/cl_engine.cpp | 0
 {src => mobile/src}/framework/cl/cl_engine.h | 0
 {src => mobile/src}/framework/cl/cl_half.cpp | 0
 {src => mobile/src}/framework/cl/cl_half.h | 0
 {src => mobile/src}/framework/cl/cl_helper.h | 0
 {src => mobile/src}/framework/cl/cl_image.cpp | 0
 {src => mobile/src}/framework/cl/cl_image.h | 0
 .../src}/framework/cl/cl_image_converter.cpp | 0
 .../src}/framework/cl/cl_image_converter.h | 0
 {src => mobile/src}/framework/cl/cl_scope.h | 0
 {src => mobile/src}/framework/cl/cl_tensor.h | 0
 {src => mobile/src}/framework/cl/cl_tool.cpp | 0
 {src => mobile/src}/framework/cl/cl_tool.h | 0
 {src => mobile/src}/framework/context.cpp | 0
 {src => mobile/src}/framework/context.h | 0
 {src => mobile/src}/framework/data_layout.h | 0
 {src => mobile/src}/framework/data_type.cpp | 0
 {src => mobile/src}/framework/data_type.h | 0
 {src => mobile/src}/framework/ddim.cpp | 0
 {src => mobile/src}/framework/ddim.h | 0
 {src => mobile/src}/framework/dim.h | 0
 {src => mobile/src}/framework/executor.cpp | 0
 {src => mobile/src}/framework/executor.h | 0
 .../src}/framework/framework.pb-c.c | 0
 .../src}/framework/framework.pb-c.h | 0
 {src => mobile/src}/framework/framework.proto | 0
 {src => mobile/src}/framework/load_ops.h | 0
 {src => mobile/src}/framework/loader.cpp | 0
 {src => mobile/src}/framework/loader.h | 0
 {src => mobile/src}/framework/lod_tensor.cpp | 0
 {src => mobile/src}/framework/lod_tensor.h | 0
 {src => mobile/src}/framework/mixed_vector.h | 0
 {src => mobile/src}/framework/op_info.h | 0
 .../src}/framework/op_kernel_type.h | 0
 .../src}/framework/op_proto_maker.h | 0
 {src => mobile/src}/framework/op_registry.h | 0
 {src => mobile/src}/framework/operator.cpp | 0
 {src => mobile/src}/framework/operator.h | 0
 .../src}/framework/program/block_desc.cpp | 0
 .../src}/framework/program/block_desc.h | 0
 .../src}/framework/program/op_desc.cpp | 0
 .../src}/framework/program/op_desc.h | 0
 .../program-optimize/fusion_op_register.h | 0
 .../program/program-optimize/node.cpp | 0
 .../framework/program/program-optimize/node.h | 0
 .../program-optimize/program_optimize.cpp | 0
 .../program-optimize/program_optimize.h | 0
 .../src}/framework/program/program.h | 0
 .../src}/framework/program/program_desc.cpp | 0
 .../src}/framework/program/program_desc.h | 0
 .../src}/framework/program/tensor_desc.h | 0
 .../src}/framework/program/var_desc.h | 0
 {src => mobile/src}/framework/scope.cpp | 0
 {src => mobile/src}/framework/scope.h | 0
 .../src}/framework/selected_rows.cpp | 0
 {src => mobile/src}/framework/selected_rows.h | 0
 {src => mobile/src}/framework/tensor.h | 0
 {src => mobile/src}/framework/tensor_base.h | 0
 {src => mobile/src}/framework/tensor_util.cpp | 0
 {src => mobile/src}/framework/tensor_util.h | 0
 {src => mobile/src}/framework/type_trait.h | 0
 {src => mobile/src}/framework/variable.h | 0
 .../src}/framework/zynqmp/ztensor.hpp | 0
 {src => mobile/src}/io/api.cc | 0
 {src => mobile/src}/io/api_paddle_mobile.cc | 0
 {src => mobile/src}/io/api_paddle_mobile.h | 0
 .../src}/io/ios_io/PaddleMobileCPU.h | 0
 .../src}/io/ios_io/PaddleMobileCPU.mm | 0
 {src => mobile/src}/io/jni/PML.java | 0
 .../src}/io/jni/paddle_mobile_jni.cpp | 0
 .../src}/io/jni/paddle_mobile_jni.h | 0
 {src => mobile/src}/io/loader.h | 0
 {src => mobile/src}/io/opencl_interface.cpp | 0
 {src => mobile/src}/io/opencl_interface.h | 0
 {src => mobile/src}/io/paddle_inference_api.h | 0
 {src => mobile/src}/io/paddle_mobile.cpp | 0
 {src => mobile/src}/io/paddle_mobile.h | 0
 {src => mobile/src}/io/paddle_mobile_wrap.cpp | 0
 {src => mobile/src}/io/paddle_mobile_wrap.h | 1 +
 .../src}/io/paddle_test_inference_api.cpp | 0
 .../src}/io/paddle_test_inference_api.h | 0
 {src => mobile/src}/memory/t_malloc.cpp | 0
 {src => mobile/src}/memory/t_malloc.h | 0
 .../src}/operators/activation_op.cpp | 0
 {src => mobile/src}/operators/activation_op.h | 0
 {src => mobile/src}/operators/assign_op.cpp | 0
 {src => mobile/src}/operators/assign_op.h | 0
 .../src}/operators/assign_value_op.cpp | 0
 .../src}/operators/assign_value_op.h | 0
 .../src}/operators/batchnorm_op.cpp | 0
 {src => mobile/src}/operators/batchnorm_op.h | 0
 .../src}/operators/beam_search_decode_op.cpp | 0
 .../src}/operators/beam_search_decode_op.h | 0
 .../src}/operators/beam_search_op.cpp | 0
 .../src}/operators/beam_search_op.h | 0
 .../src}/operators/bilinear_interp_op.cpp | 0
 .../src}/operators/bilinear_interp_op.h | 0
 .../src}/operators/box_coder_op.cpp | 0
 {src => mobile/src}/operators/box_coder_op.h | 0
 {src => mobile/src}/operators/cast_op.cpp | 0
 {src => mobile/src}/operators/cast_op.h | 0
 {src => mobile/src}/operators/compare_op.cpp | 0
 {src => mobile/src}/operators/compare_op.h | 0
 {src => mobile/src}/operators/concat_op.cpp | 0
 {src => mobile/src}/operators/concat_op.h | 0
 .../src}/operators/conditional_block_op.cpp | 0
 .../src}/operators/conditional_block_op.h | 0
 .../tensor_array_read_write_op.cpp | 0
 .../controlflow/tensor_array_read_write_op.h | 0
 .../src}/operators/controlflow/while_op.cpp | 0
 .../src}/operators/controlflow/while_op.h | 0
 {src => mobile/src}/operators/conv_op.cpp | 0
 {src => mobile/src}/operators/conv_op.h | 0
 .../src}/operators/conv_transpose_op.cpp | 0
 .../src}/operators/conv_transpose_op.h | 0
 {src => mobile/src}/operators/crf_op.cpp | 0
 {src => mobile/src}/operators/crf_op.h | 0
 .../src}/operators/depthwise_conv_op.cpp | 0
 .../src}/operators/depthwise_conv_op.h | 0
 .../src}/operators/dequantize_op.cpp | 0
 {src => mobile/src}/operators/dequantize_op.h | 0
 .../src}/operators/detection_ops.cpp | 0
 {src => mobile/src}/operators/detection_ops.h | 0
 {src => mobile/src}/operators/dropout_op.cpp | 0
 {src => mobile/src}/operators/dropout_op.h | 0
 .../src}/operators/elementwise_add_op.cpp | 0
 .../src}/operators/elementwise_add_op.h | 0
 .../src}/operators/elementwise_mul_op.cpp | 0
 .../src}/operators/elementwise_mul_op.h | 0
 .../src}/operators/elementwise_sub_op.cpp | 0
 .../src}/operators/elementwise_sub_op.h | 0
 {src => mobile/src}/operators/exp_op.cpp | 0
 {src => mobile/src}/operators/exp_op.h | 0
 {src => mobile/src}/operators/feed_op.cpp | 0
 {src => mobile/src}/operators/feed_op.h | 0
 {src => mobile/src}/operators/fetch_op.cpp | 0
 {src => mobile/src}/operators/fetch_op.h | 0
 .../fill_constant_batch_size_like_op.cpp | 0
 .../fill_constant_batch_size_like_op.h | 0
 .../src}/operators/fill_constant_op.cpp | 0
 .../src}/operators/fill_constant_op.h | 0
 {src => mobile/src}/operators/flatten2_op.cpp | 0
 {src => mobile/src}/operators/flatten2_op.h | 0
 {src => mobile/src}/operators/flatten_op.cpp | 0
 {src => mobile/src}/operators/flatten_op.h | 0
 .../src}/operators/fusion_conv_add_bn_op.cpp | 0
 .../src}/operators/fusion_conv_add_bn_op.h | 0
 .../operators/fusion_conv_add_bn_relu_op.cpp | 0
 .../operators/fusion_conv_add_bn_relu_op.h | 0
 .../src}/operators/fusion_conv_add_op.cpp | 0
 .../src}/operators/fusion_conv_add_op.h | 0
 .../operators/fusion_conv_add_relu_op.cpp | 0
 .../src}/operators/fusion_conv_add_relu_op.h | 0
 .../operators/fusion_conv_bn_add_relu_op.cpp | 0
 .../operators/fusion_conv_bn_add_relu_op.h | 0
 .../src}/operators/fusion_conv_bn_op.cpp | 0
 .../src}/operators/fusion_conv_bn_op.h | 0
 .../src}/operators/fusion_conv_bn_relu_op.cpp | 0
 .../src}/operators/fusion_conv_bn_relu_op.h | 0
 .../src}/operators/fusion_conv_relu_op.cpp | 0
 .../src}/operators/fusion_conv_relu_op.h | 0
 .../operators/fusion_deconv_add_bn_op.cpp | 0
 .../src}/operators/fusion_deconv_add_bn_op.h | 0
 .../fusion_deconv_add_bn_relu_op.cpp | 0
 .../operators/fusion_deconv_add_bn_relu_op.h | 0
 .../src}/operators/fusion_deconv_add_op.cpp | 0
 .../src}/operators/fusion_deconv_add_op.h | 0
 .../operators/fusion_deconv_add_relu_op.cpp | 0
 .../operators/fusion_deconv_add_relu_op.h | 0
 .../operators/fusion_deconv_bn_relu_op.cpp | 0
 .../src}/operators/fusion_deconv_bn_relu_op.h | 0
 .../src}/operators/fusion_deconv_relu_op.cpp | 0
 .../src}/operators/fusion_deconv_relu_op.h | 0
 .../operators/fusion_dequant_add_bn_op.cpp | 0
 .../src}/operators/fusion_dequant_add_bn_op.h | 0
 .../fusion_dequant_add_bn_relu_op.cpp | 0
 .../operators/fusion_dequant_add_bn_relu_op.h | 0
 .../fusion_dequant_add_bn_relu_quant_op.cpp | 0
 .../fusion_dequant_add_bn_relu_quant_op.h | 0
 .../src}/operators/fusion_dequant_bn_op.cpp | 0
 .../src}/operators/fusion_dequant_bn_op.h | 0
 .../operators/fusion_dequant_bn_relu_op.h | 0
 .../operators/fusion_dwconv_bn_relu_op.cpp | 0
 .../src}/operators/fusion_dwconv_bn_relu_op.h | 0
 .../fusion_elementwise_add_relu_op.cpp | 0
 .../fusion_elementwise_add_relu_op.h | 0
 .../src}/operators/fusion_fc_op.cpp | 0
 {src => mobile/src}/operators/fusion_fc_op.h | 0
 .../src}/operators/fusion_fc_relu_op.cpp | 0
 .../src}/operators/fusion_fc_relu_op.h | 0
 {src => mobile/src}/operators/gru_op.cpp | 0
 {src => mobile/src}/operators/gru_op.h | 0
 {src => mobile/src}/operators/gru_unit_op.cpp | 0
 {src => mobile/src}/operators/gru_unit_op.h | 0
 .../src}/operators/im2sequence_op.cpp | 0
 .../src}/operators/im2sequence_op.h | 0
 .../src}/operators/increment_op.cpp | 0
 {src => mobile/src}/operators/increment_op.h | 0
 {src => mobile/src}/operators/is_empty_op.cpp | 0
 {src => mobile/src}/operators/is_empty_op.h | 0
 .../src}/operators/kernel/activation_kernel.h | 0
 .../kernel/arm/activation_kernel.cpp | 0
 .../kernel/arm/anchor_generator_kernel.cpp | 0
 .../operators/kernel/arm/assign_kernel.cpp | 0
 .../kernel/arm/assign_value_kernel.cpp | 0
 .../operators/kernel/arm/batchnorm_kernel.cpp | 0
 .../kernel/arm/beam_search_decode_kernel.cpp | 0
 .../kernel/arm/beam_search_kernel.cpp | 0
.../kernel/arm/bilinear_interp_kernel.cpp | 0
.../operators/kernel/arm/box_coder_kernel.cpp | 0
.../src}/operators/kernel/arm/cast_kernel.cpp | 0
.../operators/kernel/arm/compare_kernel.cpp | 0
.../operators/kernel/arm/concat_kernel.cpp | 0
.../kernel/arm/conditional_block_kernel.cpp | 0
.../convolution/conv_add_bn_relu_kernel.cpp | 0
.../arm/convolution/conv_add_kernel.cpp | 0
.../arm/convolution/conv_add_relu_kernel.cpp | 0
.../convolution/conv_bn_add_relu_kernel.cpp | 0
.../arm/convolution/conv_bn_relu_kernel.cpp | 0
.../kernel/arm/convolution/conv_common.cpp | 0
.../kernel/arm/convolution/conv_common.h | 0
.../kernel/arm/convolution/conv_kernel.cpp | 0
.../arm/convolution/conv_relu_kernel.cpp | 0
.../arm/convolution/conv_transpose_kernel.cpp | 0
.../arm/convolution/dwconv_bn_relu_kernel.cpp | 0
.../src}/operators/kernel/arm/crf_kernel.cpp | 0
.../kernel/arm/density_prior_box_kernel.cpp | 0
.../kernel/arm/dequantize_bn_kernel.cpp | 0
.../kernel/arm/dequantize_kernel.cpp | 0
.../operators/kernel/arm/dropout_kernel.cpp | 0
.../kernel/arm/elementwise_add_kernel.cpp | 0
.../kernel/arm/elementwise_mul_kernel.cpp | 0
.../kernel/arm/elementwise_sub_kernel.cpp | 0
.../src}/operators/kernel/arm/exp_kernel.cpp | 0
.../src}/operators/kernel/arm/feed_kernel.cpp | 0
.../operators/kernel/arm/fetch_kernel.cpp | 0
.../operators/kernel/arm/flatten_kernel.cpp | 0
.../operators/kernel/arm/fusion_fc_kernel.cpp | 0
.../src}/operators/kernel/arm/gru_kernel.cpp | 0
.../operators/kernel/arm/gru_unit_kernel.cpp | 0
.../kernel/arm/im2sequence_kernel.cpp | 0
.../operators/kernel/arm/increment_kernel.cpp | 0
.../operators/kernel/arm/is_empty_kernel.cpp | 0
.../operators/kernel/arm/lod_reset_kernel.cpp | 0
.../operators/kernel/arm/logical_kernel.cpp | 0
.../operators/kernel/arm/lookup_kernel.cpp | 0
.../src}/operators/kernel/arm/lrn_kernel.cpp | 0
.../src}/operators/kernel/arm/mul_kernel.cpp | 0
.../kernel/arm/multiclass_nms_kernel.cpp | 0
.../kernel/arm/nearest_interp_kernel.cpp | 0
.../src}/operators/kernel/arm/norm_kernel.cpp | 0
.../operators/kernel/arm/one_hot_kernel.cpp | 0
.../operators/kernel/arm/pad2d_kernel.cpp | 0
.../arm/polygon_box_transform_kernel.cpp | 0
.../src}/operators/kernel/arm/pool_kernel.cpp | 0
.../operators/kernel/arm/prelu_kernel.cpp | 0
.../operators/kernel/arm/prior_box_kernel.cpp | 0
.../operators/kernel/arm/proposal_kernel.cpp | 0
.../kernel/arm/psroi_pool_kernel.cpp | 0
.../operators/kernel/arm/quantize_kernel.cpp | 0
.../operators/kernel/arm/reshape2_kernel.cpp | 0
.../operators/kernel/arm/reshape_kernel.cpp | 0
.../operators/kernel/arm/resize_kernel.cpp | 0
.../kernel/arm/roi_perspective_kernel.cpp | 0
.../operators/kernel/arm/scale_kernel.cpp | 0
.../kernel/arm/sequence_expand_kernel.cpp | 0
.../kernel/arm/sequence_pool_kernel.cpp | 0
.../kernel/arm/sequence_softmax_kernel.cpp | 0
.../operators/kernel/arm/shape_kernel.cpp | 0
.../operators/kernel/arm/slice_kernel.cpp | 0
.../operators/kernel/arm/softmax_kernel.cpp | 0
.../operators/kernel/arm/split_kernel.cpp | 0
.../src}/operators/kernel/arm/sum_kernel.cpp | 0
.../arm/tensor_array_read_write_kernel.cpp | 0
.../operators/kernel/arm/top_k_kernel.cpp | 0
.../kernel/arm/transpose2_kernel.cpp | 0
.../operators/kernel/arm/transpose_kernel.cpp | 0
.../operators/kernel/arm/while_kernel.cpp | 0
.../src}/operators/kernel/assign_kernel.h | 0
.../operators/kernel/assign_value_kernel.h | 0
.../src}/operators/kernel/batchnorm_kernel.h | 0
.../kernel/beam_search_decode_kernel.h | 0
.../operators/kernel/beam_search_kernel.h | 0
.../operators/kernel/bilinear_interp_kernel.h | 0
.../src}/operators/kernel/box_coder_kernel.h | 0
.../central-arm-func/activation_arm_func.h | 0
.../central-arm-func/batchnorm_arm_func.h | 0
.../bilinear_interp_arm_func.h | 0
.../central-arm-func/box_coder_arm_func.h | 0
.../kernel/central-arm-func/concat_arm_func.h | 0
.../central-arm-func/conv_add_arm_func.h | 0
.../conv_add_bn_relu_arm_func.h | 0
.../central-arm-func/conv_add_relu_arm_func.h | 0
.../kernel/central-arm-func/conv_arm_func.cpp | 0
.../kernel/central-arm-func/conv_arm_func.h | 0
.../conv_bn_add_relu_arm_func.h | 0
.../central-arm-func/conv_bn_relu_arm_func.h | 0
.../conv_transpose_arm_func.h | 0
.../kernel/central-arm-func/crf_arm_func.h | 0
.../density_prior_box_arm_func.h | 0
.../dwconv_bn_relu_arm_func.h | 0
.../elementwise_add_arm_func.h | 0
.../elementwise_mul_arm_func.h | 0
.../elementwise_sub_arm_func.h | 0
.../central-arm-func/flatten_arm_func.h | 0
.../central-arm-func/fusion_fc_arm_func.h | 0
.../kernel/central-arm-func/gru_arm_func.h | 0
.../central-arm-func/gru_unit_arm_func.h | 0
.../central-arm-func/increment_arm_func.h | 0
.../kernel/central-arm-func/lookup_arm_func.h | 0
.../kernel/central-arm-func/lrn_arm_func.h | 0
.../kernel/central-arm-func/mul_arm_func.h | 0
.../multiclass_nms_arm_func.h | 0
.../kernel/central-arm-func/norm_arm_func.h | 0
.../polygon_box_transform_arm_func.h | 0
.../kernel/central-arm-func/pool_arm_func.h | 0
.../central-arm-func/prior_box_arm_func.h | 0
.../central-arm-func/reshape2_arm_func.h | 0
.../central-arm-func/reshape_arm_func.h | 0
.../kernel/central-arm-func/shape_arm_func.h | 0
.../central-arm-func/softmax_arm_func.h | 0
.../kernel/central-arm-func/split_arm_func.h | 0
.../kernel/central-arm-func/sum_arm_func.h | 0
.../central-arm-func/transpose_arm_func.h | 0
.../operators/kernel/cl/batchnorm_kernel.cpp | 0
.../operators/kernel/cl/box_coder_kernel.cpp | 0
.../kernel/cl/cl-kernel-func/conv_func.cpp | 0
.../kernel/cl/cl-kernel-func/conv_func.h | 0
.../kernel/cl/cl_kernel/batchnorm_kernel.cl | 0
.../kernel/cl/cl_kernel/box_coder_kernel.cl | 0
.../kernel/cl/cl_kernel/channel_add_kernel.cl | 0
.../operators/kernel/cl/cl_kernel/cl_common.h | 0
.../kernel/cl/cl_kernel/concat_kernel.cl | 0
.../kernel/cl/cl_kernel/conv_kernel.cl | 0
.../kernel/cl/cl_kernel/conv_kernel.inc.cl | 0
.../cl/cl_kernel/density_prior_box_kernel.cl | 0
.../depthwise_conv_add_bn_relu_kernel.cl | 0
.../cl/cl_kernel/depthwise_conv_kernel.cl | 0
.../kernel/cl/cl_kernel/dropout_kernel.cl | 0
.../cl/cl_kernel/elementwise_add_kernel.cl | 0
.../kernel/cl/cl_kernel/exp_kernel.cl | 0
.../kernel/cl/cl_kernel/feed_kernel.cl | 0
.../kernel/cl/cl_kernel/fetch_kernel.cl | 0
.../kernel/cl/cl_kernel/flatten2_kernel.cl | 0
.../kernel/cl/cl_kernel/leakyrelu_kernel.cl | 0
.../kernel/cl/cl_kernel/lrn_kernel.cl | 0
.../cl/cl_kernel/nearest_interp_kernel.cl | 0
.../kernel/cl/cl_kernel/pool_kernel.cl | 0
.../kernel/cl/cl_kernel/prior_box_kernel.cl | 0
.../operators/kernel/cl/cl_kernel/relu.cl | 0
.../operators/kernel/cl/cl_kernel/relu6.cl | 0
.../operators/kernel/cl/cl_kernel/reshape.cl | 0
.../kernel/cl/cl_kernel/scale_kernel.cl | 0
.../operators/kernel/cl/cl_kernel/sigmoid.cl | 0
.../kernel/cl/cl_kernel/slice_kernel.cl | 0
.../operators/kernel/cl/cl_kernel/softmax.cl | 0
.../kernel/cl/cl_kernel/transpose_kernel.cl | 0
.../operators/kernel/cl/concat_kernel.cpp | 0
.../kernel/cl/conv_add_bn_relu_kernel.cpp | 0
.../operators/kernel/cl/conv_add_kernel.cpp | 0
.../kernel/cl/conv_add_relu_kernel.cpp | 0
.../kernel/cl/conv_bn_add_relu_kernel.cpp | 0
.../kernel/cl/conv_bn_relu_kernel.cpp | 0
.../src}/operators/kernel/cl/conv_kernel.cpp | 0
.../operators/kernel/cl/conv_relu_kernel.cpp | 0
.../kernel/cl/density_prior_box_kernel.cpp | 0
.../kernel/cl/depthwise_conv_kernel.cpp | 0
.../operators/kernel/cl/dropout_kernel.cpp | 0
.../kernel/cl/dwconv_bn_relu_kernel.cpp | 0
.../kernel/cl/elementwise_add_kernel.cpp | 0
.../src}/operators/kernel/cl/exp_kernel.cpp | 0
.../src}/operators/kernel/cl/feed_kernel.cpp | 0
.../src}/operators/kernel/cl/fetch_kernel.cpp | 0
.../operators/kernel/cl/flatten2_kernel.cpp | 0
.../operators/kernel/cl/fusion_fc_kernel.cpp | 0
.../operators/kernel/cl/leakyrelu_kernel.cpp | 0
.../src}/operators/kernel/cl/lrn_kernel.cpp | 0
.../src}/operators/kernel/cl/mul_kernel.cpp | 0
.../kernel/cl/multiclass_nms_kernel.cpp | 0
.../kernel/cl/nearest_interp_kernel.cpp | 0
.../src}/operators/kernel/cl/pool_kernel.cpp | 0
.../operators/kernel/cl/prior_box_kernel.cpp | 0
.../src}/operators/kernel/cl/relu6_kernel.cpp | 0
.../src}/operators/kernel/cl/relu_kernel.cpp | 0
.../operators/kernel/cl/reshape2_kernel.cpp | 0
.../operators/kernel/cl/reshape_kernel.cpp | 0
.../src}/operators/kernel/cl/scale_kernel.cpp | 0
.../operators/kernel/cl/sigmoid_kernel.cpp | 0
.../src}/operators/kernel/cl/slice_kernel.cpp | 0
.../operators/kernel/cl/softmax_kernel.cpp | 0
.../src}/operators/kernel/cl/split_kernel.cpp | 0
.../operators/kernel/cl/transpose2_kernel.cpp | 0
.../operators/kernel/cl/transpose_kernel.cpp | 0
.../src}/operators/kernel/compare_kernel.h | 0
.../src}/operators/kernel/concat_kernel.h | 0
.../kernel/conditional_block_kernel.h | 0
.../operators/kernel/conv_add_bn_kernel.h | 0
.../kernel/conv_add_bn_relu_kernel.h | 0
.../src}/operators/kernel/conv_add_kernel.h | 0
.../operators/kernel/conv_add_relu_kernel.h | 0
.../kernel/conv_bn_add_relu_kernel.h | 0
.../src}/operators/kernel/conv_bn_kernel.h | 0
.../operators/kernel/conv_bn_relu_kernel.h | 0
.../src}/operators/kernel/conv_kernel.h | 0
.../src}/operators/kernel/conv_relu_kernel.h | 0
.../operators/kernel/conv_transpose_kernel.h | 0
.../src}/operators/kernel/crf_kernel.h | 0
.../operators/kernel/deconv_add_bn_kernel.h | 0
.../kernel/deconv_add_bn_relu_kernel.h | 0
.../src}/operators/kernel/deconv_add_kernel.h | 0
.../operators/kernel/deconv_add_relu_kernel.h | 0
.../operators/kernel/deconv_bn_relu_kernel.h | 0
.../operators/kernel/deconv_relu_kernel.h | 0
.../src}/operators/kernel/dequant_bn_kernel.h | 0
.../src}/operators/kernel/dequantize_kernel.h | 0
.../src}/operators/kernel/detection_kernel.h | 0
.../src}/operators/kernel/dropout_kernel.h | 0
.../operators/kernel/dwconv_bn_relu_kernel.h | 0
.../operators/kernel/elementwise_add_kernel.h | 0
.../kernel/elementwise_add_relu_kernel.h | 0
.../operators/kernel/elementwise_mul_kernel.h | 0
.../operators/kernel/elementwise_sub_kernel.h | 0
.../src}/operators/kernel/exp_kernel.h | 0
.../src}/operators/kernel/fc_relu_kernel.h | 0
.../src}/operators/kernel/feed_kernel.h | 0
.../src}/operators/kernel/fetch_kernel.h | 0
.../src}/operators/kernel/flatten2_kernel.h | 0
.../src}/operators/kernel/flatten_kernel.h | 0
.../kernel/fpga/KD/conv_add_bn_kernel.cpp | 0
.../kernel/fpga/KD/conv_add_kernel.cpp | 0
.../kernel/fpga/KD/conv_add_relu_kernel.cpp | 0
.../kernel/fpga/KD/conv_bn_kernel.cpp | 0
.../kernel/fpga/KD/conv_bn_relu_kernel.cpp | 0
.../fpga/KD/elementwise_add_relu_kernel.cpp | 0
.../operators/kernel/fpga/KD/feed_kernel.cpp | 0
.../operators/kernel/fpga/KD/fetch_kernel.cpp | 0
.../kernel/fpga/KD/fusion_fc_kernel.cpp | 0
.../operators/kernel/fpga/KD/pool_kernel.cpp | 0
.../kernel/fpga/KD/softmax_kernel.cpp | 0
.../fpga/V1/anchor_generator_kernel.cpp | 0
.../kernel/fpga/V1/concat_kernel.cpp | 0
.../kernel/fpga/V1/conv_add_bn_kernel.cpp | 0
.../fpga/V1/conv_add_bn_relu_kernel.cpp | 0
.../kernel/fpga/V1/conv_add_kernel.cpp | 0
.../kernel/fpga/V1/conv_add_relu_kernel.cpp | 0
.../kernel/fpga/V1/conv_bn_kernel.cpp | 0
.../kernel/fpga/V1/conv_bn_relu_kernel.cpp | 0
.../operators/kernel/fpga/V1/conv_kernel.cpp | 0
.../kernel/fpga/V1/conv_transpose_kernel.cpp | 0
.../kernel/fpga/V1/deconv_add_bn_kernel.cpp | 0
.../fpga/V1/deconv_add_bn_relu_kernel.cpp | 0
.../kernel/fpga/V1/deconv_add_kernel.cpp | 0
.../kernel/fpga/V1/deconv_add_relu_kernel.cpp | 0
.../kernel/fpga/V1/deconv_bn_relu_kernel.cpp | 0
.../kernel/fpga/V1/dropout_kernel.cpp | 0
.../kernel/fpga/V1/elementwise_add_kernel.cpp | 0
.../fpga/V1/elementwise_add_relu_kernel.cpp | 0
.../kernel/fpga/V1/elementwise_mul_kernel.cpp | 0
.../operators/kernel/fpga/V1/feed_kernel.cpp | 0
.../operators/kernel/fpga/V1/fetch_kernel.cpp | 0
.../kernel/fpga/V1/fusion_fc_kernel.cpp | 0
.../kernel/fpga/V1/fusion_fc_relu_kernel.cpp | 0
.../operators/kernel/fpga/V1/pad2d_kernel.cpp | 0
.../operators/kernel/fpga/V1/pool_kernel.cpp | 0
.../kernel/fpga/V1/proposal_kernel.cpp | 0
.../kernel/fpga/V1/psroi_pool_kernel.cpp | 0
.../operators/kernel/fpga/V1/relu_kernel.cpp | 0
.../kernel/fpga/V1/reshape2_kernel.cpp | 0
.../kernel/fpga/V1/reshape_kernel.cpp | 0
.../kernel/fpga/V1/roialign_pool_kernel.cpp | 0
.../kernel/fpga/V1/sigmoid_kernel.cpp | 0
.../operators/kernel/fpga/V1/slice_kernel.cpp | 0
.../kernel/fpga/V1/softmax_kernel.cpp | 0
.../operators/kernel/fpga/V1/split_kernel.cpp | 0
.../operators/kernel/fpga/V1/tanh_kernel.cpp | 0
.../kernel/fpga/V1/transpose2_kernel.cpp | 0
.../fpga/V2/anchor_generator_kernel.cpp | 0
.../kernel/fpga/V2/concat_kernel.cpp | 0
.../kernel/fpga/V2/conv_add_bn_kernel.cpp | 0
.../fpga/V2/conv_add_bn_relu_kernel.cpp | 0
.../kernel/fpga/V2/conv_add_kernel.cpp | 0
.../kernel/fpga/V2/conv_add_relu_kernel.cpp | 0
.../kernel/fpga/V2/conv_bn_kernel.cpp | 0
.../kernel/fpga/V2/conv_bn_relu_kernel.cpp | 0
.../operators/kernel/fpga/V2/conv_kernel.cpp | 0
.../kernel/fpga/V2/conv_transpose_kernel.cpp | 0
.../kernel/fpga/V2/deconv_add_bn_kernel.cpp | 0
.../fpga/V2/deconv_add_bn_relu_kernel.cpp | 0
.../kernel/fpga/V2/deconv_add_kernel.cpp | 0
.../kernel/fpga/V2/deconv_add_relu_kernel.cpp | 0
.../kernel/fpga/V2/deconv_bn_relu_kernel.cpp | 0
.../kernel/fpga/V2/dropout_kernel.cpp | 0
.../kernel/fpga/V2/elementwise_add_kernel.cpp | 0
.../fpga/V2/elementwise_add_relu_kernel.cpp | 0
.../kernel/fpga/V2/elementwise_mul_kernel.cpp | 0
.../operators/kernel/fpga/V2/feed_kernel.cpp | 0
.../operators/kernel/fpga/V2/fetch_kernel.cpp | 0
.../kernel/fpga/V2/fusion_fc_kernel.cpp | 0
.../kernel/fpga/V2/fusion_fc_relu_kernel.cpp | 0
.../operators/kernel/fpga/V2/pool_kernel.cpp | 0
.../kernel/fpga/V2/proposal_kernel.cpp | 0
.../kernel/fpga/V2/psroi_pool_kernel.cpp | 0
.../operators/kernel/fpga/V2/relu_kernel.cpp | 0
.../kernel/fpga/V2/reshape2_kernel.cpp | 0
.../kernel/fpga/V2/reshape_kernel.cpp | 0
.../kernel/fpga/V2/roialign_pool_kernel.cpp | 0
.../kernel/fpga/V2/sigmoid_kernel.cpp | 0
.../operators/kernel/fpga/V2/slice_kernel.cpp | 0
.../kernel/fpga/V2/softmax_kernel.cpp | 0
.../operators/kernel/fpga/V2/split_kernel.cpp | 0
.../operators/kernel/fpga/V2/tanh_kernel.cpp | 0
.../kernel/fpga/V2/transpose2_kernel.cpp | 0
.../src}/operators/kernel/fusion_fc_kernel.h | 0
.../src}/operators/kernel/gru_kernel.h | 0
.../src}/operators/kernel/gru_unit_kernel.h | 0
.../operators/kernel/im2sequence_kernel.h | 0
.../src}/operators/kernel/increment_kernel.h | 0
.../src}/operators/kernel/is_empty_kernel.h | 0
.../src}/operators/kernel/kernels.h | 0
.../src}/operators/kernel/logical_kernel.h | 0
.../src}/operators/kernel/lookup_kernel.h | 0
.../src}/operators/kernel/lrn_kernel.h | 0
.../src}/operators/kernel/mul_kernel.h | 0
.../operators/kernel/multiclass_nms_kernel.h | 0
.../operators/kernel/nearest_interp_kernel.h | 0
.../src}/operators/kernel/norm_kernel.h | 0
.../src}/operators/kernel/one_hot_kernel.h | 0
.../src}/operators/kernel/pad2d_kernel.h | 0
.../kernel/polygon_box_transform_kernel.h | 0
.../src}/operators/kernel/pool_kernel.h | 0
.../src}/operators/kernel/prelu_kernel.h | 0
.../src}/operators/kernel/prior_box_kernel.h | 0
.../src}/operators/kernel/quantize_kernel.h | 0
.../src}/operators/kernel/range_kernel.cpp | 0
.../src}/operators/kernel/range_kernel.h | 0
.../operators/kernel/reduce_prod_kernel.cpp | 0
.../operators/kernel/reduce_prod_kernel.h | 0
.../src}/operators/kernel/reshape2_kernel.h | 0
.../src}/operators/kernel/reshape_kernel.h | 0
.../src}/operators/kernel/resize_kernel.h | 0
.../src}/operators/kernel/scale_kernel.h | 0
.../src}/operators/kernel/sequence_kernels.h | 0
.../src}/operators/kernel/shape_kernel.h | 0
.../src}/operators/kernel/slice_kernel.h | 0
.../src}/operators/kernel/softmax_kernel.h | 0
.../src}/operators/kernel/split_kernel.h | 0
.../src}/operators/kernel/sum_kernel.h | 0
.../src}/operators/kernel/tanh_kernel.h | 0
.../kernel/tensor_array_read_write_kernel.h | 0
.../src}/operators/kernel/transpose2_kernel.h | 0
.../src}/operators/kernel/transpose_kernel.h | 0
.../src}/operators/kernel/while_kernel.h | 0
.../src}/operators/lod_reset_op.cpp | 0
{src => mobile/src}/operators/lod_reset_op.h | 0
{src => mobile/src}/operators/logical_op.cpp | 0
{src => mobile/src}/operators/logical_op.h | 0
{src => mobile/src}/operators/lookup_op.cpp | 0
{src => mobile/src}/operators/lookup_op.h | 0
{src => mobile/src}/operators/lrn_op.cpp | 0
{src => mobile/src}/operators/lrn_op.h | 0
.../src}/operators/math/activation.h | 0
.../math/depthwise/faster_depthwise_conv3x3.h | 0
.../depthwise/faster_depthwise_conv3x3p1.cpp | 0
.../src}/operators/math/depthwise_conv3x3.cpp | 0
.../src}/operators/math/depthwise_conv3x3.h | 0
.../operators/math/depthwise_conv3x3_int8.cpp | 0
.../src}/operators/math/depthwise_conv5x5.cpp | 0
.../src}/operators/math/depthwise_conv5x5.h | 0
.../operators/math/depthwise_conv5x5_int8.cpp | 0
.../src}/operators/math/element_wise.h | 0
.../operators/math/elementwise_op_function.h | 0
{src => mobile/src}/operators/math/gemm.cpp | 0
{src => mobile/src}/operators/math/gemm.h | 0
.../src}/operators/math/gemm/cblas.cc | 0
.../src}/operators/math/gemm/cblas.h | 0
.../src}/operators/math/gemm/executor.h | 0
.../src}/operators/math/gemm/gemm1x1s1.cpp | 0
.../src}/operators/math/gemm/gemm1x1s1.h | 0
.../src}/operators/math/gemm/gemm_kernel.h | 0
.../src}/operators/math/gemm/pack_kernel.h | 0
.../src}/operators/math/gemm/strategy.h | 0
.../src}/operators/math/gemm_int8.cpp | 0
.../src}/operators/math/gemm_omp_int8.cpp | 0
{src => mobile/src}/operators/math/gpc.cpp | 0
{src => mobile/src}/operators/math/gpc.h | 0
.../src}/operators/math/gru_compute.cpp | 0
.../src}/operators/math/gru_compute.h | 0
.../src}/operators/math/gru_cpu_kernel.h | 0
{src => mobile/src}/operators/math/im2col.cpp | 0
{src => mobile/src}/operators/math/im2col.h | 0
{src => mobile/src}/operators/math/math.h | 0
.../src}/operators/math/math_function.cpp | 0
.../src}/operators/math/math_function.h | 0
.../operators/math/math_function_int8.cpp | 0
{src => mobile/src}/operators/math/pad.cpp | 0
{src => mobile/src}/operators/math/pad.h | 0
.../src}/operators/math/poly_util.cpp | 0
.../src}/operators/math/poly_util.h | 0
.../src}/operators/math/pooling.cpp | 0
{src => mobile/src}/operators/math/pooling.h | 0
.../src}/operators/math/pooling2x2.cpp | 0
.../src}/operators/math/pooling3x3.cpp | 0
{src => mobile/src}/operators/math/quantize.h | 0
.../operators/math/selected_rows_functor.h | 0
.../src}/operators/math/sequence2batch.cpp | 0
.../src}/operators/math/sequence2batch.h | 0
.../operators/math/slidingwindow_conv3x3.cpp | 0
.../operators/math/slidingwindow_conv3x3.h | 0
.../operators/math/slidingwindow_utils.cpp | 0
.../src}/operators/math/slidingwindow_utils.h | 0
.../src}/operators/math/softmax.cpp | 0
{src => mobile/src}/operators/math/softmax.h | 0
.../src}/operators/math/transform.h | 0
.../src}/operators/math/vol2col.cpp | 0
{src => mobile/src}/operators/math/vol2col.h | 0
.../math/winograd/winograd_transform.h | 0
.../math/winograd/winograd_transform_f6k3.cpp | 0
{src => mobile/src}/operators/mul_op.cpp | 0
{src => mobile/src}/operators/mul_op.h | 0
.../src}/operators/multiclass_nms_op.cpp | 0
.../src}/operators/multiclass_nms_op.h | 0
.../src}/operators/nearest_interp_op.cpp | 0
.../src}/operators/nearest_interp_op.h | 0
{src => mobile/src}/operators/norm_op.cpp | 0
{src => mobile/src}/operators/norm_op.h | 0
{src => mobile/src}/operators/one_hot_op.cpp | 0
{src => mobile/src}/operators/one_hot_op.h | 0
{src => mobile/src}/operators/op_param.cpp | 0
{src => mobile/src}/operators/op_param.h | 0
{src => mobile/src}/operators/pad2d_op.cpp | 0
{src => mobile/src}/operators/pad2d_op.h | 0
.../operators/polygon_box_transform_op.cpp | 0
.../src}/operators/polygon_box_transform_op.h | 0
{src => mobile/src}/operators/pool_op.cpp | 0
{src => mobile/src}/operators/pool_op.h | 0
{src => mobile/src}/operators/prelu_op.cpp | 0
{src => mobile/src}/operators/prelu_op.h | 0
.../src}/operators/prior_box_op.cpp | 0
{src => mobile/src}/operators/prior_box_op.h | 0
{src => mobile/src}/operators/quantize_op.cpp | 0
{src => mobile/src}/operators/quantize_op.h | 0
{src => mobile/src}/operators/range_op.cpp | 0
{src => mobile/src}/operators/range_op.h | 0
.../src}/operators/reduce_prod_op.cpp | 0
.../src}/operators/reduce_prod_op.h | 0
{src => mobile/src}/operators/reshape2_op.cpp | 0
{src => mobile/src}/operators/reshape2_op.h | 0
{src => mobile/src}/operators/reshape_op.cpp | 0
{src => mobile/src}/operators/reshape_op.h | 0
{src => mobile/src}/operators/resize_op.cpp | 0
{src => mobile/src}/operators/resize_op.h | 0
{src => mobile/src}/operators/scale_op.cpp | 0
{src => mobile/src}/operators/scale_op.h | 0
.../sequence_ops/sequence_expand_op.cpp | 0
.../sequence_ops/sequence_expand_op.h | 0
.../sequence_ops/sequence_pool_op.cpp | 0
.../operators/sequence_ops/sequence_pool_op.h | 0
.../sequence_ops/sequence_softmax_op.cpp | 0
.../sequence_ops/sequence_softmax_op.h | 0
{src => mobile/src}/operators/shape_op.cpp | 0
{src => mobile/src}/operators/shape_op.h | 0
{src => mobile/src}/operators/slice_op.cpp | 0
{src => mobile/src}/operators/slice_op.h | 0
{src => mobile/src}/operators/softmax_op.cpp | 0
{src => mobile/src}/operators/softmax_op.h | 0
{src => mobile/src}/operators/split_op.cpp | 0
{src => mobile/src}/operators/split_op.h | 0
{src => mobile/src}/operators/sum_op.cpp | 0
{src => mobile/src}/operators/sum_op.h | 0
{src => mobile/src}/operators/top_k_op.cpp | 0
{src => mobile/src}/operators/top_k_op.h | 0
.../src}/operators/transpose2_op.cpp | 0
{src => mobile/src}/operators/transpose2_op.h | 0
.../src}/operators/transpose_op.cpp | 0
{src => mobile/src}/operators/transpose_op.h | 0
{src => mobile/src}/pass/memory_optimize.cpp | 0
{src => mobile/src}/pass/memory_optimize.h | 0
.../src}/pass/memory_optimize_super.cpp | 0
.../src}/pass/memory_optimize_super.h | 0
{src => mobile/src}/pass/model_obfuscate.cpp | 0
{src => mobile/src}/pass/model_obfuscate.h | 0
{src => mobile/src}/pass/pass_base.h | 0
{src => mobile/src}/protobuf-c/protobuf-c.c | 0
{src => mobile/src}/protobuf-c/protobuf-c.h | 0
{test => mobile/test}/CMakeLists.txt | 0
{test => mobile/test}/common/test_enforce.cpp | 0
.../test}/common/test_gemm_accuracy.cpp | 0
.../test}/common/test_gemm_int8_accuracy.cpp | 0
.../test}/common/test_gemm_perf.cpp | 0
.../test}/common/test_lib_size.cpp | 0
{test => mobile/test}/common/test_lib_size.h | 0
{test => mobile/test}/common/test_log.cpp | 0
{test => mobile/test}/common/test_openmp.cpp | 0
{test => mobile/test}/executor_for_test.h | 0
{test => mobile/test}/fpga/test_concat_op.cpp | 0
.../test}/fpga/test_densebox_combine.cpp | 0
.../test}/fpga/test_format_data.cpp | 0
{test => mobile/test}/fpga/test_marker.cpp | 0
{test => mobile/test}/fpga/test_marker2.cpp | 0
.../test}/fpga/test_marker_api.cpp | 0
.../test}/fpga/test_mobilenet_api.cpp | 0
{test => mobile/test}/fpga/test_pe.cpp | 0
{test => mobile/test}/fpga/test_resnet50.cpp | 0
{test => mobile/test}/fpga/test_rfcn.cpp | 0
{test => mobile/test}/fpga/test_rfcn_api.cpp | 0
{test => mobile/test}/fpga/test_ssd.cpp | 0
.../test}/fpga/test_tensor_quant.cpp | 0
{test => mobile/test}/fpga/test_yolo_api.cpp | 0
.../test}/framework/test_inference_api.cpp | 0
{test => mobile/test}/framework/test_load.cpp | 0
.../test}/framework/test_load_memory.cpp | 0
.../test_load_memory_inference_api.cpp | 0
.../test}/framework/test_optimize.cpp | 0
{test => mobile/test}/net/test_alexnet.cpp | 0
{test => mobile/test}/net/test_benchmark.cpp | 0
{test => mobile/test}/net/test_eng.cpp | 0
.../test}/net/test_genet_combine.cpp | 0
{test => mobile/test}/net/test_gesture.cpp | 0
{test => mobile/test}/net/test_googlenet.cpp | 0
.../test}/net/test_googlenet_quali.cpp | 0
.../test}/net/test_googlenetv1_combine.cpp | 0
.../test}/net/test_inceptionv4.cpp | 0
.../test}/net/test_mobilenet+ssd.cpp | 0
{test => mobile/test}/net/test_mobilenet.cpp | 0
.../test}/net/test_mobilenet_025_fssd.cpp | 0
.../test}/net/test_mobilenet_GPU.cpp | 0
.../test}/net/test_mobilenet_combine.cpp | 0
.../net/test_multi_inference_predict.cpp | 0
{test => mobile/test}/net/test_net.cpp | 0
.../test}/net/test_net_benchmark.cpp | 0
{test => mobile/test}/net/test_nlp.cpp | 0
{test => mobile/test}/net/test_ocr.cpp | 0
{test => mobile/test}/net/test_op_in_net.cpp | 0
{test => mobile/test}/net/test_resnet.cpp | 0
{test => mobile/test}/net/test_squeezenet.cpp | 0
{test => mobile/test}/net/test_super.cpp | 0
{test => mobile/test}/net/test_vgg16ssd.cpp | 0
{test => mobile/test}/net/test_wrap.cpp | 0
{test => mobile/test}/net/test_yolo.cpp | 0
.../test}/net/test_yolo_combined.cpp | 0
{test => mobile/test}/net/test_yologpu.cpp | 0
.../test}/operators/test_batchnorm_op.cpp | 0
.../test}/operators/test_box_coder_op.cpp | 0
.../test}/operators/test_cast_op.cpp | 0
.../test}/operators/test_concat_op.cpp | 0
.../test}/operators/test_conv_add_relu_op.cpp | 0
.../test}/operators/test_conv_bn_relu_op.cpp | 0
.../test}/operators/test_conv_gpu.cpp | 0
.../test}/operators/test_conv_op.cpp | 0
.../operators/test_depthwise_conv_op.cpp | 0
.../test}/operators/test_dequantize_op.cpp | 0
.../operators/test_dwconv_bn_relu_op.cpp | 0
.../operators/test_elementwise_add_op.cpp | 0
.../operators/test_elementwise_sub_op.cpp | 0
.../test}/operators/test_fill_constant_op.cpp | 0
.../test_fusion_conv_add_bn_relu_op.cpp | 0
.../test}/operators/test_fusion_fc_op.cpp | 0
.../test}/operators/test_gru_op.cpp | 0
.../test}/operators/test_im2sequence_op.cpp | 0
.../test}/operators/test_increment_op.cpp | 0
.../test}/operators/test_is_empty_op.cpp | 0
.../test}/operators/test_leaky_relu_op.cpp | 0
.../test}/operators/test_less_than_op.cpp | 0
.../test}/operators/test_log_op.cpp | 0
.../test}/operators/test_logical_and_op.cpp | 0
.../test}/operators/test_logical_not_op.cpp | 0
.../test}/operators/test_logical_or_op.cpp | 0
.../test}/operators/test_logical_xor_op.cpp | 0
.../test}/operators/test_lrn_op.cpp | 0
.../test}/operators/test_mul_op.cpp | 0
.../operators/test_multiclass_nms_op.cpp | 0
.../test_polygon_box_transform_op.cpp | 0
.../test}/operators/test_pool_op.cpp | 0
.../test}/operators/test_prelu_op.cpp | 0
.../test}/operators/test_prior_box_op.cpp | 0
.../test}/operators/test_quantize_op.cpp | 0
.../test}/operators/test_relu6_op.cpp | 0
.../test}/operators/test_relu_op.cpp | 0
.../test}/operators/test_reshape2_op.cpp | 0
.../test}/operators/test_reshape_op.cpp | 0
.../test}/operators/test_resize_op.cpp | 0
.../test}/operators/test_scale_op.cpp | 0
.../operators/test_sequence_expand_op.cpp | 0
.../test}/operators/test_sequence_pool_op.cpp | 0
.../operators/test_sequence_softmax_op.cpp | 0
.../test}/operators/test_sigmoid_op.cpp | 0
.../test}/operators/test_slice_op.cpp | 0
.../test}/operators/test_softmax_op.cpp | 0
.../test}/operators/test_sum_op.cpp | 0
.../test}/operators/test_tanh_op.cpp | 0
.../test}/operators/test_topk_op.cpp | 0
.../test}/operators/test_transpose2_op.cpp | 0
.../test}/operators/test_transpose_op.cpp | 0
{test => mobile/test}/test_helper.h | 0
{test => mobile/test}/test_include.h | 0
.../opencl/OpenCL-Headers/CL/cl.h | 1 -
.../opencl/OpenCL-Headers/CL/cl_d3d10.h | 1 -
.../opencl/OpenCL-Headers/CL/cl_d3d11.h | 1 -
.../OpenCL-Headers/CL/cl_dx9_media_sharing.h | 1 -
.../CL/cl_dx9_media_sharing_intel.h | 1 -
.../opencl/OpenCL-Headers/CL/cl_egl.h | 0
.../opencl/OpenCL-Headers/CL/cl_ext.h | 0
.../opencl/OpenCL-Headers/CL/cl_ext_intel.h | 1 -
.../opencl/OpenCL-Headers/CL/cl_gl.h | 0
.../opencl/OpenCL-Headers/CL/cl_gl_ext.h | 0
.../opencl/OpenCL-Headers/CL/cl_platform.h | 0
.../CL/cl_va_api_media_sharing_intel.h | 1 -
.../opencl/OpenCL-Headers/CL/cl_version.h | 0
.../opencl/OpenCL-Headers/CL/opencl.h | 1 -
.../opencl/OpenCL-Headers/LICENSE | 0
.../opencl/OpenCL-Headers/README.md | 0
.../android-cmake/android.toolchain.cmake | 0
.../android-debug-script/push2android.sh | 0
.../android-debug-script/run_on_android.sh | 2 +-
{tools => mobile/tools}/arm-platform.cmake | 0
{tools => mobile/tools}/build.sh | 0
{tools => mobile/tools}/ci_build.sh | 0
{tools => mobile/tools}/ci_run_test.sh | 0
{tools => mobile/tools}/docker_build_fpga.sh | 1 -
.../tools}/ios-cmake/ios.toolchain.cmake | 0
{tools => mobile/tools}/net-detail.awk | 0
{tools => mobile/tools}/net.awk | 0
{tools => mobile/tools}/op.cmake | 0
.../tools}/pre-commit.hooks/clang-format.hook | 0
.../tools}/pre-commit.hooks/clang-tidy.hook | 0
.../tools}/pre-commit.hooks/copyright.hook | 0
.../tools}/pre-commit.hooks/cpplint.hook | 0
.../tools}/prepare_images_and_models.sh | 0
{tools => mobile/tools}/profile_show.sh | 1 -
.../tools}/python/caffetools/run.py | 0
.../tools}/python/fluidtools/.gitignore | 0
.../tools}/python/fluidtools/run.py | 0
.../tools}/python/imagetools/README.md | 0
.../tools}/python/imagetools/imagetools.py | 0
.../tools}/python/imagetools/img2nchw.py | 0
.../tools}/python/imagetools/img2nhwc.py | 0
.../tools}/python/imagetools/numpy2binary.py | 0
.../tools}/python/misc/.gitignore | 0
.../tools}/python/misc/fluidtools.py | 0
.../tools}/python/misc/ios-test-server.py | 1 -
.../tools}/python/misc/restore-git.py | 0
.../python/misc/test-fluid-op-feature.py | 0
.../tools}/python/modeltools/.gitignore | 0
.../tools}/python/modeltools/core/__init__.py | 0
.../python/modeltools/core/framework.proto | 0
.../python/modeltools/core/framework_pb2.py | 0
.../tools}/python/modeltools/core/op_types.py | 0
.../python/modeltools/mobilenet/__init__.py | 0
.../mobilenet/converter_mobilenet.py | 0
.../python/modeltools/mobilenet/swicher.py | 0
.../python/modeltools/tools/__init__.py | 0
.../modeltools/tools/float2halffloat.py | 0
.../tools}/python/modeltools/tools/loader.py | 2 -
.../python/modeltools/tools/model_combine.py | 0
.../python/modeltools/tools/model_reader.py | 0
.../tools}/python/modeltools/yolo/__init__.py | 0
.../python/modeltools/yolo/mdl2fluid.py | 0
.../tools}/python/modeltools/yolo/swicher.py | 0
.../tools}/quantification/CMakeLists.txt | 2 +-
.../tools}/quantification/README.md | 2 -
.../tools}/quantification/convert.cpp | 1 -
.../quantification/src/block_desc_local.cpp | 0
.../quantification/src/block_desc_local.h | 0
.../tools}/quantification/src/enforce.h | 0
.../quantification/src/framework.pb-c.c | 0
.../quantification/src/framework.pb-c.h | 0
.../quantification/src/program_desc.cpp | 0
.../tools}/quantification/src/program_desc.h | 0
.../tools}/quantification/src/protobuf-c.c | 0
.../tools}/quantification/src/protobuf-c.h | 0
.../tools}/quantification/src/tensor_desc.h | 0
.../tools}/quantification/src/var_desc.h | 0
.../tools}/shell/check-bitcode.sh | 0
.../tools}/shell/check-filename.sh | 0
.../tools}/shell/generate-include/.gitignore | 0
.../generate-include/check_include_diff.sh | 0
.../tools}/shell/generate-include/main.cpp | 0
.../tools}/shell/generate-include/parse.py | 0
.../tools}/shell/generate-include/run.sh | 0
{tools => mobile/tools}/shell/merge.sh | 0
.../tools}/shell/prune_static_library.sh | 0
.../tools}/shell/restore-private-repo.sh | 0
.../tools}/toolchains/arm-android-neon.cmake | 0
.../tools}/toolchains/arm-linux-gnueabi.cmake | 0
.../toolchains/arm-linux-gnueabihf.cmake | 0
tools/codestyle/.gitignore | 1 +
tools/codestyle/clang_format.hook | 15 +
tools/codestyle/copyright.hook | 121 +
tools/codestyle/cpplint_pre_commit.hook | 27 +
tools/codestyle/docstring_checker.py | 349 +
tools/codestyle/pylint_pre_commit.hook | 19 +
tools/codestyle/test_docstring_checker.py | 232 +
tools/document_preview.sh | 13 +
web/src/ops/dummy.js | 2 +-
web/src/ops/feed.js | 2 -
web/src/ops/fetch.js | 2 +-
web/test/testUtils/diff.js | 2 +-
web/test/unitTest.html | 2 +-
web/tools/toBinaryFile.py | 2 +-
2158 files changed, 187047 insertions(+), 362 deletions(-)
create mode 100644 cmake/FindGflags.cmake
create mode 100644 cmake/FindGlog.cmake
create mode 100644 cmake/FindGperftools.cmake
create mode 100644 cmake/FindJeMalloc.cmake
create mode 100644 cmake/FindNumPy.cmake
create mode 100644 cmake/cblas.cmake
create mode 100644 cmake/ccache.cmake
create mode 100644 cmake/configure.cmake
create mode 100644 cmake/coveralls.cmake
create mode 100644 cmake/coverallsGcovJsons.cmake
create mode 100644 cmake/cross_compiling/android.cmake
create mode 100644 cmake/cross_compiling/armlinux.cmake
create mode 100644 cmake/cross_compiling/findar.cmake
create mode 100644 cmake/cross_compiling/host.cmake
create mode 100644 cmake/cross_compiling/ios.cmake
create mode 100644 cmake/cross_compiling/npu.cmake
create mode 100644 cmake/cross_compiling/postproject.cmake
create mode 100644 cmake/cross_compiling/preproject.cmake
create mode 100644 cmake/cuda.cmake
create mode 100644 cmake/cudnn.cmake
create mode 100644 cmake/cupti.cmake
create mode 100644 cmake/external/eigen.cmake
create mode 100644 cmake/external/gflags.cmake
create mode 100644 cmake/external/glog.cmake
create mode 100644 cmake/external/gtest.cmake
create mode 100644 cmake/external/libxsmm.cmake
create mode 100644 cmake/external/mkldnn.cmake
create mode 100644 cmake/external/mklml.cmake
create mode 100644 cmake/external/openblas.cmake
create mode 100644 cmake/external/opencl-clhpp.cmake
create mode 100644 cmake/external/opencl-headers.cmake
create mode 100644 cmake/external/protobuf.cmake
create mode 100644 cmake/external/xbyak.cmake
create mode 100644 cmake/external/xxhash.cmake
create mode 100644 cmake/flags.cmake
create mode 100644 cmake/generic.cmake
create mode 100644 cmake/hip.cmake
create mode 100644 cmake/lite.cmake
create mode 100644 cmake/make_resource.py
create mode 100644 cmake/operators.cmake
create mode 100644 cmake/package.cmake
create mode 100644 cmake/simd.cmake
create mode 100644 cmake/system.cmake
create mode 100644 cmake/tensorrt.cmake
create mode 100644 cmake/util.cmake
create mode 100644 cmake/version.cmake
create mode 100644 lite/CMakeLists.txt
create mode 100644 lite/api/CMakeLists.txt
create mode 100644 lite/api/android/.gitignore
create mode 100644 lite/api/android/CMakeLists.txt
create mode 100644 lite/api/android/jni/.gitignore
create mode 100644 lite/api/android/jni/CMakeLists.txt
create mode 100644 lite/api/android/jni/native/CMakeLists.txt
create mode 100644 lite/api/android/jni/native/convert_util_jni.h
create mode 100644 lite/api/android/jni/native/paddle_lite_jni.cc
create mode 100644 lite/api/android/jni/native/paddle_lite_jni.h
create mode 100644 lite/api/android/jni/native/tensor_jni.cc
create mode 100644 lite/api/android/jni/native/tensor_jni.h
create mode 100644 lite/api/android/jni/src/com/baidu/paddle/lite/.gitignore
create mode 100644 lite/api/android/jni/src/com/baidu/paddle/lite/ConfigBase.java
create mode 100644 lite/api/android/jni/src/com/baidu/paddle/lite/CxxConfig.java
create mode 100644 lite/api/android/jni/src/com/baidu/paddle/lite/MobileConfig.java
create mode 100644 lite/api/android/jni/src/com/baidu/paddle/lite/PaddleLiteInitializer.java
create mode 100644 lite/api/android/jni/src/com/baidu/paddle/lite/PaddlePredictor.java
create mode 100644 lite/api/android/jni/src/com/baidu/paddle/lite/Place.java
create mode 100644 lite/api/android/jni/src/com/baidu/paddle/lite/Tensor.java
create mode 100644 lite/api/android/jni/test/com/baidu/paddle/lite/PaddlePredictorTest.java
create mode 100644 lite/api/apis_test.cc
create mode 100644 lite/api/cxx_api.cc
create mode 100644 lite/api/cxx_api.h
create mode 100644 lite/api/cxx_api_bin.cc
create mode 100644 lite/api/cxx_api_impl.cc
create mode 100644 lite/api/cxx_api_test.cc
create mode 100644 lite/api/efficientnet_b0_test.cc
create mode 100644 lite/api/inceptionv4_test.cc
create mode 100644 lite/api/light_api.cc
create mode 100644 lite/api/light_api.h
create mode 100644 lite/api/light_api_impl.cc
create mode 100644 lite/api/light_api_test.cc
create mode 100644 lite/api/lite_api_test_helper.cc
create mode 100644 lite/api/lite_api_test_helper.h
create mode 100644 lite/api/mobilenetv1_int8_test.cc
create mode 100644 lite/api/mobilenetv1_ssd_test.cc
create mode 100644 lite/api/mobilenetv1_test.cc
create mode 100644 lite/api/mobilenetv1_yolov3_test.cc
create mode 100644 lite/api/mobilenetv2_test.cc
create mode 100644 lite/api/model_optimize_tool.cc
create mode 100644 lite/api/model_test.cc
create mode 100644 lite/api/ocr_attention_test.cc
create mode 100644 lite/api/paddle_api.cc
create mode 100644 lite/api/paddle_api.h
create mode 100644 lite/api/paddle_api_test.cc
create mode 100644 lite/api/paddle_lite_factory_helper.h
create mode 100644 lite/api/paddle_place.cc
create mode 100644 lite/api/paddle_place.h
create mode 100644 lite/api/paddle_use_kernels.h
create mode 100644 lite/api/paddle_use_ops.h
create mode 100644 lite/api/paddle_use_passes.h
create mode 100644 lite/api/resnet18_test.cc
create mode 100644 lite/api/resnet50_test.cc
create mode 100644 lite/api/resnet50_test_fpga.cc
create mode 100644 lite/api/shufflenetv2_test.cc
create mode 100644 lite/api/test_googlenet_lite.cc
create mode 100644 lite/api/test_helper.h
create mode 100644 lite/api/test_inceptionv4_lite_x86.cc
create mode 100644 lite/api/test_mobilenetv1_lite_x86.cc
create mode 100644 lite/api/test_mobilenetv2_lite_x86.cc
create mode 100644 lite/api/unet_test.cc
create mode 100644 lite/arm/CMakeLists.txt
create mode 100644 lite/arm/math/CMakeLists.txt
create mode 100644 lite/arm/math/activation.cc
create mode 100644 lite/arm/math/activation.h
create mode 100644 lite/arm/math/argmax.cc
create mode 100644 lite/arm/math/argmax.h
create mode 100644 lite/arm/math/axpy.cc
create mode 100644 lite/arm/math/axpy.h
create mode 100644 lite/arm/math/beam_search.cc
create mode 100644 lite/arm/math/beam_search.h
create mode 100644 lite/arm/math/box_coder.cc
create mode 100644 lite/arm/math/box_coder.h
create mode 100644 lite/arm/math/col_im_transform.cc
create mode 100644 lite/arm/math/col_im_transform.h
create mode 100644 lite/arm/math/concat.cc
create mode 100644 lite/arm/math/concat.h
create mode 100644 lite/arm/math/conv3x3s1_direct_int8.cc
create mode 100644 lite/arm/math/conv3x3s2_direct_int8.cc
create mode 100644 lite/arm/math/conv_block_utils.h
create mode 100644 lite/arm/math/conv_depthwise.cc
create mode 100644 lite/arm/math/conv_depthwise.h
create mode 100644 lite/arm/math/conv_depthwise_3x3_int7.cc
create mode 100644 lite/arm/math/conv_depthwise_3x3_int8.cc
create mode 100644 lite/arm/math/conv_depthwise_3x3p0.cc
create mode 100644 lite/arm/math/conv_depthwise_3x3p1.cc
create mode 100644 lite/arm/math/conv_depthwise_5x5s1.cc
create mode 100644 lite/arm/math/conv_depthwise_5x5s1_int8.cc
create mode 100644 lite/arm/math/conv_depthwise_5x5s2.cc
create mode 100644 lite/arm/math/conv_direct.cc
create mode 100644 lite/arm/math/conv_direct.h
create mode 100644 lite/arm/math/conv_direct_3x3s1.cc
create mode 100644 lite/arm/math/conv_direct_3x3s2.cc
create mode 100644 lite/arm/math/conv_gemmlike.cc
create mode 100644 lite/arm/math/conv_gemmlike.h
create mode 100644 lite/arm/math/conv_impl.cc
create mode 100644 lite/arm/math/conv_impl.h
create mode 100644 lite/arm/math/conv_winograd.cc
create mode 100644 lite/arm/math/conv_winograd.h
create mode 100644 lite/arm/math/conv_winograd_3x3.cc
create mode 100644 lite/arm/math/decode_bboxes.cc
create mode 100644 lite/arm/math/decode_bboxes.h
create mode 100644 lite/arm/math/dot_toolchain_support.h
create mode 100644 lite/arm/math/dropout.cc
create mode 100644 lite/arm/math/dropout.h
create mode 100644 lite/arm/math/elementwise.cc
create mode 100644 lite/arm/math/elementwise.h
create mode 100644 lite/arm/math/fill_bias_relu.cc
create mode 100644 lite/arm/math/fill_bias_relu.h
create mode 100644 lite/arm/math/funcs.cc
create mode 100644 lite/arm/math/funcs.h
create mode 100644 lite/arm/math/gemm_prepacked_int8.cc
create mode 100644 lite/arm/math/gemm_prepacked_int8.h
create mode 100644 lite/arm/math/gemv_arm_int8.cc
create mode 100644 lite/arm/math/gemv_arm_int8.h
create mode 100644 lite/arm/math/gru_utils.h
create mode 100644 lite/arm/math/im2sequence.cc
create mode 100644 lite/arm/math/im2sequence.h
create mode 100644 lite/arm/math/increment.cc
create mode 100644 lite/arm/math/increment.h
create mode 100644 lite/arm/math/interpolate.cc
create mode 100644 lite/arm/math/interpolate.h
create mode 100644 lite/arm/math/lrn.cc
create mode 100644 lite/arm/math/lrn.h
create mode 100644 lite/arm/math/multiclass_nms.cc
create mode 100644 lite/arm/math/multiclass_nms.h
create mode 100644 lite/arm/math/negative.cc
create mode 100644 lite/arm/math/negative.h
create mode 100644 lite/arm/math/norm.cc
create mode 100644 lite/arm/math/norm.h
create mode 100644 lite/arm/math/packed_sgemm.cc
create mode 100644 lite/arm/math/packed_sgemm.h
create mode 100644 lite/arm/math/pad2d.cc
create mode 100644 lite/arm/math/pad2d.h
create mode 100644 lite/arm/math/pooling.cc
create mode 100644 lite/arm/math/pooling.h
create mode 100644 lite/arm/math/power.cc
create mode 100644 lite/arm/math/power.h
create mode 100644 lite/arm/math/prior_box.cc
create mode 100644 lite/arm/math/prior_box.h
create mode 100644 lite/arm/math/reduce_max.cc
create mode 100644 lite/arm/math/reduce_max.h
create mode 100644 lite/arm/math/saturate.h
create mode 100644 lite/arm/math/scale.cc
create mode 100644 lite/arm/math/scale.h
create mode 100644 lite/arm/math/sequence2batch.h
create mode 100644 lite/arm/math/sequence_expand.cc
create mode 100644 lite/arm/math/sequence_expand.h
create mode 100644 lite/arm/math/sequence_pool.cc
create mode 100644 lite/arm/math/sequence_pool.h
create mode 100644 lite/arm/math/sequence_softmax.cc
create mode 100644 lite/arm/math/sequence_softmax.h
create mode 100644 lite/arm/math/sgemm.cc
create mode 100644 lite/arm/math/sgemm.h
create mode 100644 lite/arm/math/sgemv.cc
create mode 100644 lite/arm/math/sgemv.h
create mode 100644 lite/arm/math/shuffle_channel.cc
create mode 100644 lite/arm/math/shuffle_channel.h
create mode 100644 lite/arm/math/slice.cc
create mode 100644 lite/arm/math/slice.h
create mode 100644 lite/arm/math/softmax.cc
create mode 100644 lite/arm/math/softmax.h
create mode 100644 lite/arm/math/split.cc
create mode 100644 lite/arm/math/split.h
create mode 100644 lite/arm/math/topk.cc
create mode 100644 lite/arm/math/topk.h
create mode 100644 lite/arm/math/type_trans.cc
create mode 100644 lite/arm/math/type_trans.h
create mode 100644 lite/arm/math/yolo_box.cc
create mode 100644 lite/arm/math/yolo_box.h
create mode 100644 lite/core/CMakeLists.txt
create mode 100644 lite/core/arena/CMakeLists.txt
create mode 100644 lite/core/arena/framework.cc
create mode 100644 lite/core/arena/framework.h
create mode 100644 lite/core/arena/framework_test.cc
create mode 100644 lite/core/context.cc
create mode 100644 lite/core/context.h
create mode 100644 lite/core/context_test.cc
create mode 100644 lite/core/cpu_info.cc
create mode 100644 lite/core/cpu_info.h
create mode 100644 lite/core/framework.proto
create mode 100644 lite/core/kernel.cc
create mode 100644 lite/core/kernel.h
create mode 100644 lite/core/kernel_test.cc
create mode 100644 lite/core/lite.map
create mode 100644 lite/core/lite_gtest_main.cc
create mode 100644 lite/core/lite_tensor_test.cc
create mode 100644 lite/core/memory.cc
create mode 100644 lite/core/memory.h
create mode 100644 lite/core/memory_test.cc
create mode 100644 lite/core/mir/CMakeLists.txt
create mode 100644 lite/core/mir/argument_type_display_pass.cc
create mode 100644 lite/core/mir/demo_pass.cc
create mode 100644 lite/core/mir/dot.h
create mode 100644 lite/core/mir/elimination/CMakeLists.txt
create mode 100644 lite/core/mir/elimination/identity_scale_eliminate_pass.cc
create mode 100644 lite/core/mir/elimination/identity_scale_eliminate_pass_test.cc
create mode 100644 lite/core/mir/fusion/CMakeLists.txt
create mode 100644 lite/core/mir/fusion/conv_activation_fuse_pass.cc
create mode 100644 lite/core/mir/fusion/conv_activation_fuse_pass.h
create mode 100644 lite/core/mir/fusion/conv_activation_fuser.cc
create mode 100644 lite/core/mir/fusion/conv_activation_fuser.h
create mode 100644 lite/core/mir/fusion/conv_bn_fuse_pass.cc
create mode 100644 lite/core/mir/fusion/conv_bn_fuse_pass.h
create mode 100644 lite/core/mir/fusion/conv_bn_fuse_pass_test.cc
create mode 100644 lite/core/mir/fusion/conv_bn_fuser.cc
create mode 100644 lite/core/mir/fusion/conv_bn_fuser.h
create mode 100644 lite/core/mir/fusion/conv_elementwise_add_activation_fuse_pass_test.cc
create mode 100644 lite/core/mir/fusion/conv_elementwise_fuse_pass.cc
create mode 100644 lite/core/mir/fusion/conv_elementwise_fuse_pass.h
create mode 100644 lite/core/mir/fusion/conv_elementwise_fuser.cc
create mode 100644 lite/core/mir/fusion/conv_elementwise_fuser.h
create mode 100644 lite/core/mir/fusion/elementwise_add_activation_fuse_pass.cc
create mode 100644 lite/core/mir/fusion/elementwise_add_activation_fuse_pass.h
create mode 100644 lite/core/mir/fusion/elementwise_add_activation_fuse_pass_test.cc
create mode 100644 lite/core/mir/fusion/elementwise_add_activation_fuser.cc
create mode 100644 lite/core/mir/fusion/elementwise_add_activation_fuser.h
create mode 100644 lite/core/mir/fusion/fc_fuse_pass.cc
create mode 100644 lite/core/mir/fusion/fc_fuse_pass.h
create mode 100644 lite/core/mir/fusion/fc_fuse_pass_test.cc
create mode 100644 lite/core/mir/fusion/fc_fuser.cc
create mode 100644 lite/core/mir/fusion/fc_fuser.h
create mode 100644 lite/core/mir/fusion/quant_dequant_fuse_pass.cc
create mode 100644 lite/core/mir/fusion/quant_dequant_fuse_pass.h
create mode 100644 lite/core/mir/fusion/quant_dequant_op_fuser.cc
create mode 100644 lite/core/mir/fusion/quant_dequant_op_fuser.h
create mode 100644 lite/core/mir/generate_program_pass.cc
create mode 100644 lite/core/mir/generate_program_pass.h
create mode 100644 lite/core/mir/graph_visualize_pass.cc
create mode 100644 lite/core/mir/graph_visualize_pass.h
create mode 100644 lite/core/mir/io_copy_kernel_pick_pass.cc
create mode 100644 lite/core/mir/node.cc
create mode 100644 lite/core/mir/node.h
create mode 100644 lite/core/mir/pass.cc
create mode 100644 lite/core/mir/pass.h
create mode 100644 lite/core/mir/pass_manager.cc
create mode 100644 lite/core/mir/pass_manager.h
create mode 100644 lite/core/mir/pass_manager_test.cc
create mode 100644 lite/core/mir/pass_registry.cc
create mode 100644 lite/core/mir/pass_registry.h
create mode 100644 lite/core/mir/pattern_matcher.cc
create mode 100644 lite/core/mir/pattern_matcher.h
create mode 100644 lite/core/mir/pattern_matcher_high_api.cc
create mode 100644 lite/core/mir/pattern_matcher_high_api.h
create mode 100644 lite/core/mir/pattern_matcher_high_api_test.cc
create mode 100644 lite/core/mir/pattern_matcher_test.cc
create mode 100644 lite/core/mir/pattern_matcher_tester.cc
create mode 100644 lite/core/mir/runtime_context_assign_pass.cc
create mode 100644 lite/core/mir/ssa_graph.cc
create mode 100644 lite/core/mir/ssa_graph.h
create mode 100644 lite/core/mir/ssa_graph_test.cc
create mode 100644 lite/core/mir/static_kernel_pick_pass.cc
create mode 100644 lite/core/mir/static_kernel_pick_pass.h
create mode 100644 lite/core/mir/subgraph/CMakeLists.txt
create mode 100644 lite/core/mir/subgraph/generate_npu_program_pass.cc
create mode 100644 lite/core/mir/subgraph/generate_npu_program_pass.h
create mode 100644 lite/core/mir/subgraph/generate_npu_program_pass_test.cc
create mode 100644 lite/core/mir/subgraph/subgraph_program_pass.cc
create mode 100644 lite/core/mir/subgraph/subgraph_program_pass.h
create mode 100644 lite/core/mir/subgraph/subgraph_program_pass_test.cc
create mode 100644 lite/core/mir/type_layout_cast_pass.cc
create mode 100644 lite/core/mir/type_layout_cast_pass.h
create mode 100644 lite/core/mir/type_precision_cast_pass.cc
create mode 100644 lite/core/mir/type_precision_cast_pass.h
create mode 100644 lite/core/mir/type_target_cast_pass.cc
create mode 100644 lite/core/mir/type_target_cast_pass.h
create mode 100644 lite/core/mir/variable_place_inference_pass.cc
create mode 100644 lite/core/mir/variable_place_inference_pass.h
create mode 100644 lite/core/mir/variable_place_inference_pass_test.cc
create mode 100644 lite/core/naive_test_model.py
create mode 100644 lite/core/op_lite.cc
create mode 100644 lite/core/op_lite.h
create mode 100644 lite/core/op_lite_test.cc
create mode 100644 lite/core/op_registry.cc
create mode 100644 lite/core/op_registry.h
create mode 100644 lite/core/optimizer.cc
create mode 100644 lite/core/optimizer.h
create mode 100644 lite/core/optimizer_test.cc
create mode 100644 lite/core/profile/CMakeLists.txt
create mode 100644 lite/core/profile/basic_profiler.cc
create mode 100644 lite/core/profile/basic_profiler.h
create mode 100644 lite/core/profile/basic_profiler_test.cc
create mode 100644 lite/core/profile/precision_profiler.h
create mode 100644 lite/core/program.cc
create mode 100644 lite/core/program.h
create mode 100644 lite/core/program_fake_utils.cc
create mode 100644 lite/core/program_fake_utils.h
create mode 100644 lite/core/scope.cc
create mode 100644 lite/core/scope.h
create mode 100644 lite/core/scope_test.cc
create mode 100644 lite/core/target_wrapper.cc
create mode 100644 lite/core/target_wrapper.h
create mode 100644 lite/core/tensor.cc
create mode 100644 lite/core/tensor.h
create mode 100644 lite/core/type_system.cc
create mode 100644 lite/core/type_system.h
create mode 100644 lite/core/type_system_test.cc
create mode 100644 lite/core/types.cc
create mode 100644 lite/core/types.h
create mode 100644 lite/core/types_test.cc
create mode 100644 lite/core/variable.cc
create mode 100644 lite/core/variable.h
create mode 100644 lite/core/workspace.cc
create mode 100644 lite/core/workspace.h
create mode 100644 lite/cuda/CMakeLists.txt
create mode 100644 lite/cuda/blas.cc
create mode 100644 lite/cuda/blas.h
create mode 100644 lite/cuda/cuda_utils.h
create mode 100644 lite/cuda/target_wrapper.cc
create mode 100644 lite/cuda/target_wrapper.h
create mode 100644 lite/demo/cxx/Makefile.def
create mode 100644 lite/demo/cxx/README.md
create mode 100644 lite/demo/cxx/makefiles/mobile_full/Makefile.android.armv7
create mode 100644 lite/demo/cxx/makefiles/mobile_full/Makefile.android.armv8
create mode 100644 lite/demo/cxx/makefiles/mobile_light/Makefile.android.armv7
create mode 100644 lite/demo/cxx/makefiles/mobile_light/Makefile.android.armv8
create mode 100644 lite/demo/cxx/mobile_full/mobilenetv1_full_api.cc
create mode 100644 lite/demo/cxx/mobile_light/mobilenetv1_light_api.cc
create mode 100644 lite/demo/java/README.md
create mode 100644 lite/demo/java/android/PaddlePredictor/.gitignore
create mode 100644 lite/demo/java/android/PaddlePredictor/app/.gitignore
create mode 100644 lite/demo/java/android/PaddlePredictor/app/build.gradle
create mode 100644 lite/demo/java/android/PaddlePredictor/app/proguard-rules.pro
create mode 100644 lite/demo/java/android/PaddlePredictor/app/src/androidTest/java/com/baidu/paddle/lite/ExampleInstrumentedTest.java
create mode 100644 lite/demo/java/android/PaddlePredictor/app/src/main/AndroidManifest.xml
create mode 100644 lite/demo/java/android/PaddlePredictor/app/src/main/assets/README.txt
create mode 100644 lite/demo/java/android/PaddlePredictor/app/src/main/java/com/baidu/paddle/lite/MainActivity.java
create mode 100644 lite/demo/java/android/PaddlePredictor/app/src/main/res/drawable-v24/ic_launcher_foreground.xml
create mode 100644 lite/demo/java/android/PaddlePredictor/app/src/main/res/drawable/ic_launcher_background.xml
create mode 100644 lite/demo/java/android/PaddlePredictor/app/src/main/res/layout/activity_main.xml
create mode 100644 lite/demo/java/android/PaddlePredictor/app/src/main/res/mipmap-anydpi-v26/ic_launcher.xml
create mode 100644 lite/demo/java/android/PaddlePredictor/app/src/main/res/mipmap-anydpi-v26/ic_launcher_round.xml
create mode 100644 lite/demo/java/android/PaddlePredictor/app/src/main/res/mipmap-hdpi/ic_launcher.png
create mode 100644 lite/demo/java/android/PaddlePredictor/app/src/main/res/mipmap-hdpi/ic_launcher_round.png
create mode 100644 lite/demo/java/android/PaddlePredictor/app/src/main/res/mipmap-mdpi/ic_launcher.png
create mode 100644 lite/demo/java/android/PaddlePredictor/app/src/main/res/mipmap-mdpi/ic_launcher_round.png
create mode 100644 lite/demo/java/android/PaddlePredictor/app/src/main/res/mipmap-xhdpi/ic_launcher.png
create mode 100644 lite/demo/java/android/PaddlePredictor/app/src/main/res/mipmap-xhdpi/ic_launcher_round.png
create mode 100644 lite/demo/java/android/PaddlePredictor/app/src/main/res/mipmap-xxhdpi/ic_launcher.png
create mode 100644 lite/demo/java/android/PaddlePredictor/app/src/main/res/mipmap-xxhdpi/ic_launcher_round.png
create mode 100644 lite/demo/java/android/PaddlePredictor/app/src/main/res/mipmap-xxxhdpi/ic_launcher.png
create mode 100644 lite/demo/java/android/PaddlePredictor/app/src/main/res/mipmap-xxxhdpi/ic_launcher_round.png
create mode 100644 lite/demo/java/android/PaddlePredictor/app/src/main/res/values/colors.xml
create mode 100644 lite/demo/java/android/PaddlePredictor/app/src/main/res/values/strings.xml
create mode 100644 lite/demo/java/android/PaddlePredictor/app/src/main/res/values/styles.xml
create mode 100644 lite/demo/java/android/PaddlePredictor/app/src/test/java/com/baidu/paddle/lite/ExampleUnitTest.java
create mode 100644 lite/demo/java/android/PaddlePredictor/build.gradle
create mode 100644 lite/demo/java/android/PaddlePredictor/gradle.properties
create mode 100644 lite/demo/java/android/PaddlePredictor/gradle/wrapper/gradle-wrapper.jar
create mode 100644 lite/demo/java/android/PaddlePredictor/gradle/wrapper/gradle-wrapper.properties
create mode 100755 lite/demo/java/android/PaddlePredictor/gradlew
create mode 100644 lite/demo/java/android/PaddlePredictor/gradlew.bat
create mode 100644 lite/demo/java/android/PaddlePredictor/settings.gradle
create mode 100644 lite/demo/java/android/prepare_demo.bash
create mode 100644 lite/fluid/CMakeLists.txt
create mode 100644 lite/fluid/data_type.cc
create mode 100644 lite/fluid/data_type.h
create mode 100644 lite/fluid/data_type_test.cc
create mode 100644 lite/fluid/eigen.h
create mode 100644 lite/fluid/float16.h
create mode 100644 lite/fluid/lod.h
create mode 100644 lite/fluid/math.h
create mode 100644 lite/fpga/CMakeLists.txt
create mode 100644 lite/fpga/KD/alignment.h
create mode 100644 lite/fpga/KD/context.hpp
create mode 100644 lite/fpga/KD/dl_engine.cpp
create mode 100644 lite/fpga/KD/dl_engine.hpp
create mode 100755 lite/fpga/KD/float16.hpp
create mode 100644 lite/fpga/KD/fpga_cv.cpp
create mode 100644 lite/fpga/KD/fpga_cv.hpp
create mode 100644 lite/fpga/KD/layout.hpp
create mode 100644 lite/fpga/KD/llapi/bias_scale.cpp
create mode 100644 lite/fpga/KD/llapi/bias_scale.h
create mode 100755 lite/fpga/KD/llapi/config.h
create mode 100644 lite/fpga/KD/llapi/filter.cpp
create mode 100644 lite/fpga/KD/llapi/filter.h
create mode 100644 lite/fpga/KD/llapi/zynqmp_api.cpp
create mode 100644 lite/fpga/KD/llapi/zynqmp_api.h
create mode 100644 lite/fpga/KD/pe.hpp
create mode 100644 lite/fpga/KD/pe_params.hpp
create mode 100644 lite/fpga/KD/pes/batchnorm_pe.hpp
create mode 100644 lite/fpga/KD/pes/concat_pe.hpp
create mode 100644 lite/fpga/KD/pes/conv_pe.hpp
create mode 100644 lite/fpga/KD/pes/conv_process.hpp
create mode 100644 lite/fpga/KD/pes/crop_pe.cpp
create mode 100755 lite/fpga/KD/pes/crop_pe.hpp
create mode 100755 lite/fpga/KD/pes/depthwise_conv_pe.hpp
create mode 100755 lite/fpga/KD/pes/elementwise_add_pe.hpp
create mode 100644 lite/fpga/KD/pes/fully_connected_pe.hpp
create mode 100755 lite/fpga/KD/pes/input_pe.hpp
create mode 100644 lite/fpga/KD/pes/norm_pe.hpp
create mode 100644 lite/fpga/KD/pes/output_pe.hpp
create mode 100644 lite/fpga/KD/pes/pooling_pe.hpp
create mode 100644 lite/fpga/KD/pes/prior_box_pe.cpp
create mode 100755 lite/fpga/KD/pes/prior_box_pe.hpp
create mode 100755 lite/fpga/KD/pes/relu_pe.hpp
create mode 100644 lite/fpga/KD/pes/resize.hpp
create mode 100755 lite/fpga/KD/pes/scale_pe.hpp
create mode 100755 lite/fpga/KD/pes/softmax_pe.cpp
create mode 100644 lite/fpga/KD/pes/softmax_pe.hpp
create mode 100644 lite/fpga/KD/pes/split_pe.hpp
create mode 100755 lite/fpga/KD/shape.hpp
create mode 100644 lite/fpga/KD/tensor.hpp
create mode 100644 lite/fpga/KD/tensor_util.cpp
create mode 100644 lite/fpga/KD/tensor_util.hpp
create mode 100644 lite/fpga/lite_tensor.cc
create mode 100644 lite/fpga/lite_tensor.h
create mode 100644 lite/fpga/target_wrapper.cc
create mode 100644 lite/gen_code/CMakeLists.txt
create mode 100644 lite/gen_code/gen_code.cc
create mode 100644 lite/gen_code/gen_code.h
create mode 100644 lite/gen_code/gen_code_test.cc
create mode 100644 lite/gen_code/generated_code_test.cc
create mode 100644 lite/gen_code/paddle_code_generator.cc
create mode 100644 lite/gen_code/paddle_infer.cc
create mode 100644 lite/gen_code/paddle_infer.h
create mode 100644 lite/host/CMakeLists.txt
create mode 100644 lite/host/target_wrapper.cc
create mode 100644 lite/kernels/CMakeLists.txt
create mode 100644 lite/kernels/arm/CMakeLists.txt
create mode 100644 lite/kernels/arm/activation_compute.cc
create mode 100644 lite/kernels/arm/activation_compute.h
create mode 100644 lite/kernels/arm/argmax_compute.cc
create mode 100644 lite/kernels/arm/argmax_compute.h
create mode 100644 lite/kernels/arm/argmax_compute_test.cc
create mode 100644 lite/kernels/arm/axpy_compute.cc
create mode 100644 lite/kernels/arm/axpy_compute.h
create mode 100644 lite/kernels/arm/axpy_compute_test.cc
create mode 100644 lite/kernels/arm/batch_norm_compute.cc
create mode 100644 lite/kernels/arm/batch_norm_compute.h
create mode 100644 lite/kernels/arm/batch_norm_compute_test.cc
create mode 100644 lite/kernels/arm/beam_search_compute.cc
create mode 100644 lite/kernels/arm/beam_search_compute.h
create mode 100644 lite/kernels/arm/beam_search_decode_compute.cc
create mode 100644 lite/kernels/arm/beam_search_decode_compute.h
create mode 100644 lite/kernels/arm/box_coder_compute.cc
create mode 100644 lite/kernels/arm/box_coder_compute.h
create mode 100644 lite/kernels/arm/calib_compute.cc
create mode 100644 lite/kernels/arm/calib_compute.h
create mode 100644 lite/kernels/arm/calib_compute_test.cc
create mode 100644 lite/kernels/arm/cast_compute.cc
create mode 100644 lite/kernels/arm/cast_compute.h
create mode 100644 lite/kernels/arm/compare_compute.cc
create mode 100644 lite/kernels/arm/compare_compute.h
create mode 100644 lite/kernels/arm/concat_compute.cc
create mode 100644 lite/kernels/arm/concat_compute.h
create mode 100644 lite/kernels/arm/concat_compute_test.cc
create mode 100644 lite/kernels/arm/conv_compute.cc
create mode 100644 lite/kernels/arm/conv_compute.h
create mode 100644 lite/kernels/arm/conv_compute_test.cc
create mode 100644 lite/kernels/arm/conv_transpose_compute.cc
create mode 100644 lite/kernels/arm/conv_transpose_compute.h
create mode 100644 lite/kernels/arm/conv_transpose_compute_test.cc
create mode 100644 lite/kernels/arm/crop_compute.cc
create mode 100644 lite/kernels/arm/crop_compute.h
create mode 100644 lite/kernels/arm/decode_bboxes_compute.cc
create mode 100644 lite/kernels/arm/decode_bboxes_compute.h
create mode 100644 lite/kernels/arm/decode_bboxes_compute_test.cc
create mode 100644 lite/kernels/arm/density_prior_box_compute.cc
create mode 100644 lite/kernels/arm/density_prior_box_compute.h
create mode 100644 lite/kernels/arm/dropout_compute.cc
create mode 100644 lite/kernels/arm/dropout_compute.h
create mode 100644 lite/kernels/arm/dropout_compute_test.cc
create mode 100644 lite/kernels/arm/elementwise_compute.cc
create mode 100644 lite/kernels/arm/elementwise_compute.h
create mode 100644 lite/kernels/arm/elementwise_compute_test.cc
create mode 100644 lite/kernels/arm/fc_compute.cc
create mode 100644 lite/kernels/arm/fc_compute.h
create mode 100644 lite/kernels/arm/fc_compute_test.cc
create mode 100644 lite/kernels/arm/fill_constant_compute.cc
create mode 100644 lite/kernels/arm/gru_compute.cc
create mode 100644 lite/kernels/arm/gru_compute.h
create mode 100644 lite/kernels/arm/gru_unit_compute.cc
create mode 100644 lite/kernels/arm/gru_unit_compute.h
create mode 100644 lite/kernels/arm/im2sequence_compute.cc
create mode 100644 lite/kernels/arm/im2sequence_compute.h
create mode 100644 lite/kernels/arm/increment_compute.cc
create mode 100644 lite/kernels/arm/increment_compute.h
create mode 100644 lite/kernels/arm/interpolate_compute.cc
create mode 100644 lite/kernels/arm/interpolate_compute.h
create mode 100644 lite/kernels/arm/is_empty_compute.cc
create mode 100644 lite/kernels/arm/is_empty_compute.h
create mode 100644 lite/kernels/arm/lod_reset_compute.cc
create mode 100644 lite/kernels/arm/lod_reset_compute.h
create mode 100644 lite/kernels/arm/logical_compute.cc
create mode 100644 lite/kernels/arm/logical_compute.h
create mode 100644 lite/kernels/arm/lookup_table_compute.cc
create mode 100644 lite/kernels/arm/lookup_table_compute.h
create mode 100644 lite/kernels/arm/lrn_compute.cc
create mode 100644 lite/kernels/arm/lrn_compute.h
create mode 100644 lite/kernels/arm/lrn_compute_test.cc
create mode 100644 lite/kernels/arm/mul_compute.cc
create mode 100644 lite/kernels/arm/mul_compute.h
create mode 100644 lite/kernels/arm/mul_compute_test.cc
create mode 100644 lite/kernels/arm/multiclass_nms_compute.cc
create mode 100644 lite/kernels/arm/multiclass_nms_compute.h
create mode 100644 lite/kernels/arm/multiclass_nms_compute_test.cc
create mode 100644 lite/kernels/arm/negative_compute.cc
create mode 100644 lite/kernels/arm/negative_compute.h
create mode 100644 lite/kernels/arm/norm_compute.cc
create mode 100644 lite/kernels/arm/norm_compute.h
create mode 100644 lite/kernels/arm/pad2d_compute.cc
create mode 100644 lite/kernels/arm/pad2d_compute.h
create mode 100644 lite/kernels/arm/pool_compute.cc
create mode 100644 lite/kernels/arm/pool_compute.h
create mode 100644 lite/kernels/arm/pool_compute_test.cc
create mode 100644 lite/kernels/arm/power_compute.cc
create mode 100644 lite/kernels/arm/power_compute.h
create mode 100644 lite/kernels/arm/prior_box_compute.cc
create mode 100644 lite/kernels/arm/prior_box_compute.h
create mode 100644 lite/kernels/arm/read_from_array_compute.cc
create mode 100644 lite/kernels/arm/read_from_array_compute.h
create mode 100644 lite/kernels/arm/reduce_max_compute.cc
create mode 100644 lite/kernels/arm/reduce_max_compute.h
create mode 100644 lite/kernels/arm/scale_compute.cc
create mode 100644 lite/kernels/arm/scale_compute.h
create mode 100644 lite/kernels/arm/scale_compute_test.cc
create mode 100644 lite/kernels/arm/sequence_expand_compute.cc
create mode 100644 lite/kernels/arm/sequence_expand_compute.h
create mode 100644 lite/kernels/arm/sequence_pool_compute.cc
create mode 100644 lite/kernels/arm/sequence_pool_compute.h
create mode 100644 lite/kernels/arm/sequence_softmax_compute.cc
create mode 100644 lite/kernels/arm/sequence_softmax_compute.h
create mode 100644 lite/kernels/arm/shape_compute.cc
create mode 100644 lite/kernels/arm/shape_compute.h
create mode 100644 lite/kernels/arm/shuffle_channel_compute.cc
create mode 100644 lite/kernels/arm/shuffle_channel_compute.h
create mode 100644 lite/kernels/arm/slice_compute.cc
create mode 100644 lite/kernels/arm/slice_compute.h
create mode 100644 lite/kernels/arm/softmax_compute.cc
create mode 100644 lite/kernels/arm/softmax_compute.h
create mode 100644 lite/kernels/arm/softmax_compute_test.cc
create mode 100644 lite/kernels/arm/split_compute.cc
create mode 100644 lite/kernels/arm/split_compute.h
create mode 100644 lite/kernels/arm/split_compute_test.cc
create mode 100644 lite/kernels/arm/topk_compute.cc
create mode 100644 lite/kernels/arm/topk_compute.h
create mode 100644 lite/kernels/arm/transpose_compute.cc
create mode 100644 lite/kernels/arm/transpose_compute.h
create mode 100644 lite/kernels/arm/transpose_compute_test.cc
create mode 100644 lite/kernels/arm/while_compute.cc
create mode 100644 lite/kernels/arm/while_compute.h
create mode 100644 lite/kernels/arm/write_to_array_compute.cc
create mode 100644 lite/kernels/arm/write_to_array_compute.h
create mode 100644 lite/kernels/arm/yolo_box_compute.cc
create mode 100644 lite/kernels/arm/yolo_box_compute.h
create mode 100644 lite/kernels/cuda/CMakeLists.txt
create mode 100644 lite/kernels/cuda/io_copy_compute.cc
create mode 100644 lite/kernels/cuda/mul_compute.cc
create mode 100644 lite/kernels/cuda/mul_compute.h
create mode 100644 lite/kernels/cuda/use_kernels.h
create mode 100644 lite/kernels/fpga/CMakeLists.txt
create mode 100644 lite/kernels/fpga/activation_compute.cc
create mode 100644 lite/kernels/fpga/activation_compute.h
create mode 100644 lite/kernels/fpga/activation_compute_test.cc
create mode 100644 lite/kernels/fpga/calib_compute.cc
create mode 100644 lite/kernels/fpga/calib_compute.h
create mode 100644 lite/kernels/fpga/conv_compute.cc
create mode 100644 lite/kernels/fpga/conv_compute.h
create mode 100644 lite/kernels/fpga/conv_compute_test.cc
create mode 100644 lite/kernels/fpga/elementwise_compute.cc
create mode 100644 lite/kernels/fpga/elementwise_compute.h
create mode 100644 lite/kernels/fpga/elementwise_compute_test.cc
create mode 100644 lite/kernels/fpga/fc_compute.cc
create mode 100644 lite/kernels/fpga/fc_compute.h
create mode 100644 lite/kernels/fpga/fc_compute_test.cc
create mode 100644 lite/kernels/fpga/feed_compute.cc
create mode 100644 lite/kernels/fpga/feed_compute.h
create mode 100644 lite/kernels/fpga/fetch_compute.cc
create mode 100644 lite/kernels/fpga/fetch_compute.h
create mode 100644 lite/kernels/fpga/io_copy_compute.cc
create mode 100644 lite/kernels/fpga/layout_compute.cc
create mode 100644 lite/kernels/fpga/pooling_compute.cc
create mode 100644 lite/kernels/fpga/pooling_compute.h
create mode 100644 lite/kernels/fpga/pooling_compute_test.cc
create mode 100644 lite/kernels/fpga/scale_compute.cc
create mode 100644 lite/kernels/fpga/scale_compute.h
create mode 100644 lite/kernels/fpga/softmax_compute.cc
create mode 100644 lite/kernels/fpga/softmax_compute.h
create mode 100644 lite/kernels/fpga/softmax_compute_test.cc
create mode 100644 lite/kernels/host/CMakeLists.txt
create mode 100644 lite/kernels/host/feed_compute.cc
create mode 100644 lite/kernels/host/fetch_compute.cc
create mode 100644 lite/kernels/host/reshape_compute.cc
create mode 100644 lite/kernels/host/reshape_compute.h
create mode 100644 lite/kernels/host/reshape_compute_test.cc
create mode 100644 lite/kernels/host/use_kernels.h
create mode 100644 lite/kernels/npu/CMakeLists.txt
create mode 100644 lite/kernels/npu/graph_compute.cc
create mode 100644 lite/kernels/npu/graph_compute.h
create mode 100644 lite/kernels/opencl/CMakeLists.txt
create mode 100644 lite/kernels/opencl/conv_compute.cc
create mode 100644 lite/kernels/opencl/conv_compute.h
create mode 100644 lite/kernels/opencl/conv_compute_test.cc
create mode 100644 lite/kernels/opencl/depthwise_conv2d_compute.cc
create mode 100644 lite/kernels/opencl/depthwise_conv2d_compute_test.cc
create mode 100644 lite/kernels/opencl/elementwise_add_compute.cc
create mode 100644 lite/kernels/opencl/elementwise_add_compute.h
create mode 100644 lite/kernels/opencl/elementwise_add_compute_test.cc
create mode 100644 lite/kernels/opencl/fc_compute.cc
create mode 100644 lite/kernels/opencl/fc_compute_test.cc
create mode 100644 lite/kernels/opencl/fusion_elementwise_add_activation_compute.cc
create mode 100644 lite/kernels/opencl/io_copy_compute.cc
create mode 100644 lite/kernels/opencl/io_copy_compute_test.cc
create mode 100644 lite/kernels/opencl/mul_compute.cc
create mode 100644 
lite/kernels/opencl/mul_compute_test.cc create mode 100644 lite/kernels/opencl/pool_compute.cc create mode 100644 lite/kernels/opencl/pool_compute_test.cc create mode 100644 lite/kernels/opencl/relu_compute.cc create mode 100644 lite/kernels/opencl/relu_compute_test.cc create mode 100644 lite/kernels/x86/CMakeLists.txt create mode 100644 lite/kernels/x86/activation_compute.cc create mode 100644 lite/kernels/x86/batch_norm_compute.cc create mode 100644 lite/kernels/x86/batch_norm_compute.h create mode 100644 lite/kernels/x86/batch_norm_compute_test.cc create mode 100644 lite/kernels/x86/concat_compute.cc create mode 100644 lite/kernels/x86/concat_compute.h create mode 100644 lite/kernels/x86/concat_compute_test.cc create mode 100644 lite/kernels/x86/conv_compute.cc create mode 100644 lite/kernels/x86/conv_compute.h create mode 100644 lite/kernels/x86/conv_compute_test.cc create mode 100644 lite/kernels/x86/dropout_compute.cc create mode 100644 lite/kernels/x86/dropout_compute.h create mode 100644 lite/kernels/x86/dropout_compute_test.cc create mode 100644 lite/kernels/x86/elementwise_compute.cc create mode 100644 lite/kernels/x86/elementwise_compute.h create mode 100644 lite/kernels/x86/elementwise_compute_test.cc create mode 100644 lite/kernels/x86/fc_compute.cc create mode 100644 lite/kernels/x86/fc_compute.h create mode 100644 lite/kernels/x86/fc_compute_test.cc create mode 100644 lite/kernels/x86/fill_constant_compute.cc create mode 100644 lite/kernels/x86/mean_compute.cc create mode 100644 lite/kernels/x86/mul_compute.cc create mode 100644 lite/kernels/x86/mul_compute.h create mode 100644 lite/kernels/x86/mul_compute_test.cc create mode 100644 lite/kernels/x86/pool_compute.cc create mode 100644 lite/kernels/x86/pool_compute.h create mode 100644 lite/kernels/x86/pool_compute_test.cc create mode 100644 lite/kernels/x86/relu_compute.cc create mode 100644 lite/kernels/x86/relu_compute.h create mode 100644 lite/kernels/x86/relu_compute_test.cc create mode 100644 lite/kernels/x86/scale_compute.cc create mode 100644 lite/kernels/x86/scale_compute.h create mode 100644 lite/kernels/x86/scale_compute_test.cc create mode 100644 lite/kernels/x86/sgd_compute.cc create mode 100644 lite/kernels/x86/softmax_compute.cc create mode 100644 lite/kernels/x86/softmax_compute.h create mode 100644 lite/kernels/x86/softmax_compute_test.cc create mode 100644 lite/kernels/x86/uniform_random_compute.cc create mode 100644 lite/model_parser/CMakeLists.txt create mode 100644 lite/model_parser/compatible_pb.cc create mode 100644 lite/model_parser/compatible_pb.h create mode 100644 lite/model_parser/compatible_pb_test.cc create mode 100644 lite/model_parser/cpp/CMakeLists.txt create mode 100644 lite/model_parser/cpp/block_desc.cc create mode 100644 lite/model_parser/cpp/block_desc.h create mode 100644 lite/model_parser/cpp/op_desc.cc create mode 100644 lite/model_parser/cpp/op_desc.h create mode 100644 lite/model_parser/cpp/program_desc.cc create mode 100644 lite/model_parser/cpp/program_desc.h create mode 100644 lite/model_parser/cpp/var_desc.cc create mode 100644 lite/model_parser/cpp/var_desc.h create mode 100644 lite/model_parser/desc_apis.h create mode 100644 lite/model_parser/model_parser.cc create mode 100644 lite/model_parser/model_parser.h create mode 100644 lite/model_parser/model_parser_test.cc create mode 100644 lite/model_parser/naive_buffer/CMakeLists.txt create mode 100644 lite/model_parser/naive_buffer/block_desc.cc create mode 100644 lite/model_parser/naive_buffer/block_desc.h create mode 100644 
lite/model_parser/naive_buffer/naive_buffer.cc create mode 100644 lite/model_parser/naive_buffer/naive_buffer.h create mode 100644 lite/model_parser/naive_buffer/naive_buffer_test.cc create mode 100644 lite/model_parser/naive_buffer/naive_buffer_wrapper_helper.h create mode 100644 lite/model_parser/naive_buffer/naive_buffer_wrapper_test.cc create mode 100644 lite/model_parser/naive_buffer/op_desc.cc create mode 100644 lite/model_parser/naive_buffer/op_desc.h create mode 100644 lite/model_parser/naive_buffer/param_desc.cc create mode 100644 lite/model_parser/naive_buffer/param_desc.h create mode 100644 lite/model_parser/naive_buffer/program_desc.cc create mode 100644 lite/model_parser/naive_buffer/program_desc.h create mode 100644 lite/model_parser/naive_buffer/proto/CMakeLists.txt create mode 100644 lite/model_parser/naive_buffer/proto/framework.nb.cc create mode 100644 lite/model_parser/naive_buffer/proto/framework.nb.h create mode 100644 lite/model_parser/naive_buffer/var_desc.cc create mode 100644 lite/model_parser/naive_buffer/var_desc.h create mode 100644 lite/model_parser/pb/CMakeLists.txt create mode 100644 lite/model_parser/pb/block_desc.cc create mode 100644 lite/model_parser/pb/block_desc.h create mode 100644 lite/model_parser/pb/op_desc.cc create mode 100644 lite/model_parser/pb/op_desc.h create mode 100644 lite/model_parser/pb/program_desc.cc create mode 100644 lite/model_parser/pb/program_desc.h create mode 100644 lite/model_parser/pb/var_desc.cc create mode 100644 lite/model_parser/pb/var_desc.h create mode 100644 lite/model_parser/runtime.cc create mode 100644 lite/model_parser/runtime.h create mode 100644 lite/npu/CMakeLists.txt create mode 100644 lite/npu/bridge/CMakeLists.txt create mode 100644 lite/npu/bridge/act_op.cc create mode 100644 lite/npu/bridge/act_op_test.cc create mode 100644 lite/npu/bridge/batch_norm_op.cc create mode 100644 lite/npu/bridge/batch_norm_op_test.cc create mode 100644 lite/npu/bridge/conv_op.cc create mode 100644 lite/npu/bridge/conv_op_test.cc create mode 100644 lite/npu/bridge/elementwise_ops.cc create mode 100644 lite/npu/bridge/elementwise_ops_test.cc create mode 100644 lite/npu/bridge/fc_op.cc create mode 100644 lite/npu/bridge/fc_op_test.cc create mode 100644 lite/npu/bridge/mul_op.cc create mode 100644 lite/npu/bridge/mul_op_test.cc create mode 100644 lite/npu/bridge/paddle_use_npu_bridges.h create mode 100644 lite/npu/bridge/pool_op.cc create mode 100644 lite/npu/bridge/pool_op_test.cc create mode 100644 lite/npu/bridge/registry.cc create mode 100644 lite/npu/bridge/registry.h create mode 100644 lite/npu/bridge/scale_op.cc create mode 100644 lite/npu/bridge/scale_op_test.cc create mode 100644 lite/npu/bridge/softmax_op.cc create mode 100644 lite/npu/bridge/softmax_op_test.cc create mode 100644 lite/npu/bridge/test_helper.cc create mode 100644 lite/npu/bridge/test_helper.h create mode 100644 lite/npu/bridge/transpose_op.cc create mode 100644 lite/npu/bridge/transpose_op_test.cc create mode 100644 lite/npu/bridge/utils.cc create mode 100644 lite/npu/bridge/utils.h create mode 100644 lite/npu/npu_helper.cc create mode 100644 lite/npu/npu_helper.h create mode 100644 lite/opencl/CMakeLists.txt create mode 100644 lite/opencl/cl_caller.cc create mode 100644 lite/opencl/cl_caller.h create mode 100644 lite/opencl/cl_context.cc create mode 100644 lite/opencl/cl_context.h create mode 100644 lite/opencl/cl_functions_test.cc create mode 100644 lite/opencl/cl_im2col_test.cc create mode 100644 lite/opencl/cl_image.cc create mode 100644 
lite/opencl/cl_image.h create mode 100644 lite/opencl/cl_image_converter.cc create mode 100644 lite/opencl/cl_image_converter.h create mode 100644 lite/opencl/cl_include.h create mode 100644 lite/opencl/cl_kernel/buffer/depthwise_conv2d_kernel.cl create mode 100644 lite/opencl/cl_kernel/buffer/elementwise_add_kernel.cl create mode 100644 lite/opencl/cl_kernel/buffer/fc_kernel.cl create mode 100644 lite/opencl/cl_kernel/buffer/im2col_kernel.cl create mode 100644 lite/opencl/cl_kernel/buffer/mat_mul_kernel.cl create mode 100644 lite/opencl/cl_kernel/buffer/pool_kernel.cl create mode 100644 lite/opencl/cl_kernel/buffer/relu_kernel.cl create mode 100644 lite/opencl/cl_kernel/cl_common.h create mode 100644 lite/opencl/cl_kernel/image/channel_add_kernel.cl create mode 100644 lite/opencl/cl_kernel/image/elementwise_add_kernel.cl create mode 100644 lite/opencl/cl_kernel/image/pool_kernel.cl create mode 100644 lite/opencl/cl_runtime.cc create mode 100644 lite/opencl/cl_runtime.h create mode 100644 lite/opencl/cl_utility.cc create mode 100644 lite/opencl/cl_utility.h create mode 100644 lite/opencl/cl_wrapper.cc create mode 100644 lite/opencl/cl_wrapper.h create mode 100644 lite/opencl/target_wrapper.cc create mode 100644 lite/opencl/target_wrapper.h create mode 100644 lite/operators/CMakeLists.txt create mode 100644 lite/operators/activation_ops.cc create mode 100644 lite/operators/activation_ops.h create mode 100644 lite/operators/argmax_op.cc create mode 100644 lite/operators/argmax_op.h create mode 100644 lite/operators/axpy_op.cc create mode 100644 lite/operators/axpy_op.h create mode 100644 lite/operators/batch_norm_op.cc create mode 100644 lite/operators/batch_norm_op.h create mode 100644 lite/operators/batch_norm_op_test.cc create mode 100644 lite/operators/beam_search_decode_op.cc create mode 100644 lite/operators/beam_search_decode_op.h create mode 100644 lite/operators/beam_search_op.cc create mode 100644 lite/operators/beam_search_op.h create mode 100644 lite/operators/box_coder_op.cc create mode 100644 lite/operators/box_coder_op.h create mode 100644 lite/operators/calib_once_op.cc create mode 100644 lite/operators/calib_once_op.h create mode 100644 lite/operators/calib_op.cc create mode 100644 lite/operators/calib_op.h create mode 100644 lite/operators/calib_op_test.cc create mode 100644 lite/operators/cast_op.cc create mode 100644 lite/operators/cast_op.h create mode 100644 lite/operators/compare_op.cc create mode 100644 lite/operators/compare_op.h create mode 100644 lite/operators/concat_op.cc create mode 100644 lite/operators/concat_op.h create mode 100644 lite/operators/concat_op_test.cc create mode 100644 lite/operators/conv_op.cc create mode 100644 lite/operators/conv_op.h create mode 100644 lite/operators/conv_transpose_op.cc create mode 100644 lite/operators/conv_transpose_op.h create mode 100644 lite/operators/crop_op.cc create mode 100644 lite/operators/crop_op.h create mode 100644 lite/operators/decode_bboxes_op.cc create mode 100644 lite/operators/decode_bboxes_op.h create mode 100644 lite/operators/density_prior_box_op.cc create mode 100644 lite/operators/density_prior_box_op.h create mode 100644 lite/operators/dropout_op.cc create mode 100644 lite/operators/elementwise_ops.cc create mode 100644 lite/operators/elementwise_ops.h create mode 100644 lite/operators/fake_dequantize_max_abs.cc create mode 100644 lite/operators/fake_dequantize_max_abs.h create mode 100644 lite/operators/fake_quantize_moving_avg_max_abs.cc create mode 100644 
lite/operators/fake_quantize_moving_avg_max_abs.h create mode 100644 lite/operators/fc_op.cc create mode 100644 lite/operators/fc_op.h create mode 100644 lite/operators/fc_op_test.cc create mode 100644 lite/operators/feed_op.cc create mode 100644 lite/operators/fetch_op.cc create mode 100644 lite/operators/fill_constant_op.cc create mode 100644 lite/operators/fusion_elementwise_activation_ops.cc create mode 100644 lite/operators/fusion_elementwise_activation_ops.h create mode 100644 lite/operators/fusion_elementwise_activation_ops_test.cc create mode 100644 lite/operators/graph_op.cc create mode 100644 lite/operators/graph_op.h create mode 100644 lite/operators/gru_op.cc create mode 100644 lite/operators/gru_op.h create mode 100644 lite/operators/gru_unit_op.cc create mode 100644 lite/operators/gru_unit_op.h create mode 100644 lite/operators/im2sequence_op.cc create mode 100644 lite/operators/im2sequence_op.h create mode 100644 lite/operators/increment_op.cc create mode 100644 lite/operators/increment_op.h create mode 100644 lite/operators/interpolate_op.cc create mode 100644 lite/operators/interpolate_op.h create mode 100644 lite/operators/io_copy_once_op.cc create mode 100644 lite/operators/io_copy_once_op.h create mode 100644 lite/operators/io_copy_op.cc create mode 100644 lite/operators/io_copy_op.h create mode 100644 lite/operators/is_empty_op.cc create mode 100644 lite/operators/is_empty_op.h create mode 100644 lite/operators/layout_once_op.cc create mode 100644 lite/operators/layout_once_op.h create mode 100644 lite/operators/layout_op.cc create mode 100644 lite/operators/layout_op.h create mode 100644 lite/operators/lod_reset_op.cc create mode 100644 lite/operators/lod_reset_op.h create mode 100644 lite/operators/logical_op.cc create mode 100644 lite/operators/logical_op.h create mode 100644 lite/operators/lookup_table_op.cc create mode 100644 lite/operators/lookup_table_op.h create mode 100644 lite/operators/lrn_op.cc create mode 100644 lite/operators/lrn_op.h create mode 100644 lite/operators/mean_op.cc create mode 100644 lite/operators/mul_op.cc create mode 100644 lite/operators/mul_op.h create mode 100644 lite/operators/multiclass_nms_op.cc create mode 100644 lite/operators/multiclass_nms_op.h create mode 100644 lite/operators/negative_op.cc create mode 100644 lite/operators/negative_op.h create mode 100644 lite/operators/norm_op.cc create mode 100644 lite/operators/norm_op.h create mode 100644 lite/operators/op_params.cc create mode 100644 lite/operators/op_params.h create mode 100644 lite/operators/pad2d_op.cc create mode 100644 lite/operators/pad2d_op.h create mode 100644 lite/operators/pool_op.cc create mode 100644 lite/operators/pool_op.h create mode 100644 lite/operators/pool_op_test.cc create mode 100644 lite/operators/power_op.cc create mode 100644 lite/operators/power_op.h create mode 100644 lite/operators/prior_box_op.cc create mode 100644 lite/operators/prior_box_op.h create mode 100644 lite/operators/read_from_array_op.cc create mode 100644 lite/operators/read_from_array_op.h create mode 100644 lite/operators/reduce_max_op.cc create mode 100644 lite/operators/reduce_max_op.h create mode 100644 lite/operators/relu_op.cc create mode 100644 lite/operators/relu_op.h create mode 100644 lite/operators/reshape_op.cc create mode 100644 lite/operators/reshape_op.h create mode 100644 lite/operators/reshape_op_test.cc create mode 100644 lite/operators/scale_op.cc create mode 100644 lite/operators/scale_op.h create mode 100644 lite/operators/scale_op_test.cc create mode 100644 
lite/operators/sequence_expand_op.cc create mode 100644 lite/operators/sequence_expand_op.h create mode 100644 lite/operators/sequence_pool_op.cc create mode 100644 lite/operators/sequence_pool_op.h create mode 100644 lite/operators/sequence_softmax_op.cc create mode 100644 lite/operators/sequence_softmax_op.h create mode 100644 lite/operators/sgd_op.cc create mode 100644 lite/operators/sgd_op.h create mode 100644 lite/operators/shape_op.cc create mode 100644 lite/operators/shape_op.h create mode 100644 lite/operators/shuffle_channel_op.cc create mode 100644 lite/operators/shuffle_channel_op.h create mode 100644 lite/operators/slice_op.cc create mode 100644 lite/operators/slice_op.h create mode 100644 lite/operators/softmax_op.cc create mode 100644 lite/operators/softmax_op.h create mode 100644 lite/operators/softmax_op_test.cc create mode 100644 lite/operators/split_op.cc create mode 100644 lite/operators/split_op.h create mode 100644 lite/operators/topk_op.cc create mode 100644 lite/operators/topk_op.h create mode 100644 lite/operators/transpose_op.cc create mode 100644 lite/operators/transpose_op.h create mode 100644 lite/operators/transpose_op_test.cc create mode 100644 lite/operators/uniform_random_op.cc create mode 100644 lite/operators/uniform_random_op.h create mode 100644 lite/operators/while_op.cc create mode 100644 lite/operators/while_op.h create mode 100644 lite/operators/write_to_array_op.cc create mode 100644 lite/operators/write_to_array_op.h create mode 100644 lite/operators/yolo_box_op.cc create mode 100644 lite/operators/yolo_box_op.h create mode 100644 lite/python/lite_test.py create mode 100644 lite/tests/CMakeLists.txt create mode 100644 lite/tests/README.md create mode 100644 lite/tests/kernels/CMakeLists.txt create mode 100644 lite/tests/kernels/activation_compute_test.cc create mode 100644 lite/tests/kernels/argmax_compute_test.cc create mode 100644 lite/tests/kernels/axpy_compute_test.cc create mode 100644 lite/tests/kernels/bilinear_interp_compute_test.cc create mode 100644 lite/tests/kernels/box_coder_compute_test.cc create mode 100644 lite/tests/kernels/compare_compute_test.cc create mode 100644 lite/tests/kernels/conv2d_transpose_compute_test.cc create mode 100644 lite/tests/kernels/crop_compute_test.cc create mode 100644 lite/tests/kernels/decode_bboxes_compute_test.cc create mode 100644 lite/tests/kernels/elementwise_compute_test.cc create mode 100644 lite/tests/kernels/fc_compute_test.cc create mode 100644 lite/tests/kernels/fill_data.h create mode 100644 lite/tests/kernels/gru_unit_test.cc create mode 100644 lite/tests/kernels/im2sequence_compute_test.cc create mode 100644 lite/tests/kernels/increment_compute_test.cc create mode 100644 lite/tests/kernels/logical_compute_test.cc create mode 100644 lite/tests/kernels/lrn_compute_test.cc create mode 100644 lite/tests/kernels/multiclass_nms_compute_test.cc create mode 100644 lite/tests/kernels/nearest_interp_compute_test.cc create mode 100644 lite/tests/kernels/negative_compute_test.cc create mode 100644 lite/tests/kernels/norm_compute_test.cc create mode 100644 lite/tests/kernels/pad2d_compute_test.cc create mode 100644 lite/tests/kernels/power_compute_test.cc create mode 100644 lite/tests/kernels/prior_box_compute_test.cc create mode 100644 lite/tests/kernels/read_from_array_compute_test.cc create mode 100644 lite/tests/kernels/reduce_max_compute_test.cc create mode 100644 lite/tests/kernels/scale_compute_test.cc create mode 100644 lite/tests/kernels/sequence_expand_compute_test.cc create mode 100644 
lite/tests/kernels/sequence_pool_compute_test.cc create mode 100644 lite/tests/kernels/sequence_softmax_compute_test.cc create mode 100644 lite/tests/kernels/shape_compute_test.cc create mode 100644 lite/tests/kernels/shuffle_channel_compute_test.cc create mode 100644 lite/tests/kernels/test_funcs.h create mode 100644 lite/tests/kernels/test_sgemm.cc create mode 100644 lite/tests/kernels/topk_compute_test.cc create mode 100644 lite/tests/kernels/write_to_array_compute_test.cc create mode 100644 lite/tests/kernels/yolo_box_compute_test.cc create mode 100644 lite/tools/CMakeLists.txt create mode 100644 lite/tools/Dockerfile.mobile create mode 100755 lite/tools/build.sh create mode 100755 lite/tools/build_fpga.sh create mode 100755 lite/tools/build_ios_armv7_arm64.sh create mode 100755 lite/tools/ci_build.sh create mode 100644 lite/tools/debug/CMakeLists.txt create mode 100644 lite/tools/debug/analysis_tool.py create mode 100755 lite/tools/debug/check_model.sh create mode 100644 lite/tools/debug/debug_utils.cc create mode 100644 lite/tools/debug/debug_utils.h create mode 100644 lite/tools/debug/model_debug_tool.cc create mode 100755 lite/tools/gitlab_review.sh create mode 100644 lite/tools/mobile_readme.md create mode 100644 lite/utils/CMakeLists.txt create mode 100644 lite/utils/all.h create mode 100644 lite/utils/any.cc create mode 100644 lite/utils/any.h create mode 100644 lite/utils/check.h create mode 100644 lite/utils/container.h create mode 100644 lite/utils/cp_logging.cc create mode 100644 lite/utils/cp_logging.h create mode 100644 lite/utils/factory.h create mode 100644 lite/utils/hash.h create mode 100644 lite/utils/io.h create mode 100644 lite/utils/logging.cc create mode 100644 lite/utils/logging.h create mode 100644 lite/utils/logging_test.cc create mode 100644 lite/utils/macros.h create mode 100644 lite/utils/paddle_enforce.h create mode 100644 lite/utils/replace_stl/stream.cc create mode 100644 lite/utils/replace_stl/stream.h create mode 100644 lite/utils/string.cc create mode 100644 lite/utils/string.h create mode 100644 lite/utils/varient.h create mode 100644 lite/utils/varient_test.cc create mode 100644 lite/x86/CMakeLists.txt create mode 100644 lite/x86/cpu_info.cc create mode 100644 lite/x86/cpu_info.h create mode 100644 lite/x86/cupti_lib_path.h.in create mode 100644 lite/x86/dynamic_loader.cc create mode 100644 lite/x86/dynamic_loader.h create mode 100644 lite/x86/jit/CMakeLists.txt create mode 100644 lite/x86/jit/README.en.md create mode 100644 lite/x86/jit/README.md create mode 100644 lite/x86/jit/benchmark.cc create mode 100644 lite/x86/jit/gen/CMakeLists.txt create mode 100644 lite/x86/jit/gen/act.cc create mode 100644 lite/x86/jit/gen/act.h create mode 100644 lite/x86/jit/gen/blas.cc create mode 100644 lite/x86/jit/gen/blas.h create mode 100644 lite/x86/jit/gen/embseqpool.cc create mode 100644 lite/x86/jit/gen/embseqpool.h create mode 100644 lite/x86/jit/gen/gru.cc create mode 100644 lite/x86/jit/gen/gru.h create mode 100644 lite/x86/jit/gen/hopv.cc create mode 100644 lite/x86/jit/gen/hopv.h create mode 100644 lite/x86/jit/gen/jitcode.h create mode 100644 lite/x86/jit/gen/lstm.cc create mode 100644 lite/x86/jit/gen/lstm.h create mode 100644 lite/x86/jit/gen/matmul.cc create mode 100644 lite/x86/jit/gen/matmul.h create mode 100644 lite/x86/jit/gen/seqpool.cc create mode 100644 lite/x86/jit/gen/seqpool.h create mode 100644 lite/x86/jit/gen/sgd.cc create mode 100644 lite/x86/jit/gen/sgd.h create mode 100644 lite/x86/jit/gen/vbroadcast.cc create mode 100644 
lite/x86/jit/gen/vbroadcast.h create mode 100644 lite/x86/jit/gen_base.cc create mode 100644 lite/x86/jit/gen_base.h create mode 100644 lite/x86/jit/helper.cc create mode 100644 lite/x86/jit/helper.h create mode 100644 lite/x86/jit/kernel_base.h create mode 100644 lite/x86/jit/kernel_key.cc create mode 100644 lite/x86/jit/kernel_key.h create mode 100644 lite/x86/jit/kernel_pool.cc create mode 100644 lite/x86/jit/kernel_pool.h create mode 100644 lite/x86/jit/macro.h create mode 100644 lite/x86/jit/more/CMakeLists.txt create mode 100644 lite/x86/jit/more/intrinsic/CMakeLists.txt create mode 100644 lite/x86/jit/more/intrinsic/crf_decoding.cc create mode 100644 lite/x86/jit/more/intrinsic/crf_decoding.h create mode 100644 lite/x86/jit/more/intrinsic/layer_norm.cc create mode 100644 lite/x86/jit/more/intrinsic/layer_norm.h create mode 100644 lite/x86/jit/more/mix/CMakeLists.txt create mode 100644 lite/x86/jit/more/mix/mix.cc create mode 100644 lite/x86/jit/more/mix/mix.h create mode 100644 lite/x86/jit/more/mkl/CMakeLists.txt create mode 100644 lite/x86/jit/more/mkl/mkl.cc create mode 100644 lite/x86/jit/more/mkl/mkl.h create mode 100644 lite/x86/jit/refer/CMakeLists.txt create mode 100644 lite/x86/jit/refer/refer.cc create mode 100644 lite/x86/jit/refer/refer.h create mode 100644 lite/x86/jit/registry.h create mode 100644 lite/x86/jit/test.cc create mode 100644 lite/x86/legacy_place.h create mode 100644 lite/x86/math/CMakeLists.txt create mode 100644 lite/x86/math/beam_search.cc create mode 100644 lite/x86/math/beam_search.h create mode 100644 lite/x86/math/beam_search_test.cc create mode 100644 lite/x86/math/blas.cc create mode 100644 lite/x86/math/blas.h create mode 100644 lite/x86/math/blas_impl.h create mode 100644 lite/x86/math/concat_and_split.cc create mode 100644 lite/x86/math/concat_and_split.h create mode 100644 lite/x86/math/context_project.cc create mode 100644 lite/x86/math/context_project.h create mode 100644 lite/x86/math/cos_sim_functor.cc create mode 100644 lite/x86/math/cos_sim_functor.h create mode 100644 lite/x86/math/cpu_vec.h create mode 100644 lite/x86/math/cross_entropy.cc create mode 100644 lite/x86/math/cross_entropy.h create mode 100644 lite/x86/math/detail/CMakeLists.txt create mode 100644 lite/x86/math/detail/activation_functions.h create mode 100644 lite/x86/math/detail/avx_functions.cc create mode 100644 lite/x86/math/detail/avx_mathfun.h create mode 100644 lite/x86/math/detail/gru_cpu_kernel.h create mode 100644 lite/x86/math/detail/gru_kernel.h create mode 100644 lite/x86/math/gru_compute.cc create mode 100644 lite/x86/math/gru_compute.h create mode 100644 lite/x86/math/im2col.cc create mode 100644 lite/x86/math/im2col.h create mode 100644 lite/x86/math/im2col_cfo_cpu.h create mode 100644 lite/x86/math/im2col_test.cc create mode 100644 lite/x86/math/math_function.cc create mode 100644 lite/x86/math/math_function.h create mode 100644 lite/x86/math/math_function_impl.h create mode 100644 lite/x86/math/math_function_test.cc create mode 100644 lite/x86/math/maxouting.cc create mode 100644 lite/x86/math/maxouting.h create mode 100644 lite/x86/math/pooling.cc create mode 100644 lite/x86/math/pooling.h create mode 100644 lite/x86/math/prelu.h create mode 100644 lite/x86/math/sample_prob.cc create mode 100644 lite/x86/math/sample_prob.h create mode 100644 lite/x86/math/sampler.cc create mode 100644 lite/x86/math/sampler.h create mode 100644 lite/x86/math/sequence_pooling.cc create mode 100644 lite/x86/math/sequence_pooling.h create mode 100644 
lite/x86/math/sequence_pooling_test.cc create mode 100644 lite/x86/math/softmax.cc create mode 100644 lite/x86/math/softmax.h create mode 100644 lite/x86/math/softmax_impl.h create mode 100644 lite/x86/math/tree2col.cc create mode 100644 lite/x86/math/tree2col.h create mode 100644 lite/x86/math/unpooling.cc create mode 100644 lite/x86/math/unpooling.h create mode 100644 lite/x86/math/vol2col.cc create mode 100644 lite/x86/math/vol2col.h create mode 100644 lite/x86/mklml.cc create mode 100644 lite/x86/mklml.h create mode 100644 lite/x86/port.h create mode 100644 lite/x86/target_wrapper.cc create mode 100644 lite/x86/target_wrapper.h create mode 100644 lite/x86/warpctc_lib_path.h.in create mode 100644 mobile/CMakeLists.txt rename CONTRIBUTING.md => mobile/CONTRIBUTING.md (99%) rename Dockerfile => mobile/Dockerfile (100%) rename LICENSE => mobile/LICENSE (100%) create mode 100644 mobile/README.md rename {benchmark => mobile/benchmark}/arm_benchmark.md (95%) rename {benchmark => mobile/benchmark}/metal_benchmark.md (90%) rename {demo => mobile/demo}/ReadMe.md (100%) rename {demo => mobile/demo}/getDemo.sh (95%) rename {doc => mobile/doc}/build.md (100%) rename {doc => mobile/doc}/design_doc.md (99%) rename {doc => mobile/doc}/development_android.md (99%) rename {doc => mobile/doc}/development_android_GPU.md (100%) rename {doc => mobile/doc}/development_arm_linux.md (99%) rename {doc => mobile/doc}/development_fpga.md (100%) rename {doc => mobile/doc}/development_ios.md (100%) rename {doc => mobile/doc}/quantification.md (99%) rename {src => mobile/src}/common/common.h (100%) rename {src => mobile/src}/common/enforce.h (100%) rename {src => mobile/src}/common/log.h (100%) rename {src => mobile/src}/common/threadpool.h (100%) rename {src => mobile/src}/common/type_define.h (100%) rename {src => mobile/src}/common/types.cpp (100%) rename {src => mobile/src}/common/types.h (100%) rename {src => mobile/src}/common/util.cpp (100%) rename {src => mobile/src}/common/util.h (100%) rename {src => mobile/src}/common/variant.h (100%) rename {src => mobile/src}/fpga/KD/alignment.h (100%) rename {src => mobile/src}/fpga/KD/context.hpp (100%) rename {src => mobile/src}/fpga/KD/dl_engine.cpp (100%) rename {src => mobile/src}/fpga/KD/dl_engine.hpp (100%) rename {src => mobile/src}/fpga/KD/float16.hpp (100%) rename {src => mobile/src}/fpga/KD/layout.hpp (100%) rename {src => mobile/src}/fpga/KD/llapi/bias_scale.cpp (100%) rename {src => mobile/src}/fpga/KD/llapi/bias_scale.h (100%) rename {src => mobile/src}/fpga/KD/llapi/config.h (100%) rename {src => mobile/src}/fpga/KD/llapi/filter.cpp (100%) rename {src => mobile/src}/fpga/KD/llapi/filter.h (100%) rename {src => mobile/src}/fpga/KD/llapi/image.cpp (100%) rename {src => mobile/src}/fpga/KD/llapi/image.h (100%) rename {src => mobile/src}/fpga/KD/llapi/zynqmp_api.cpp (100%) rename {src => mobile/src}/fpga/KD/llapi/zynqmp_api.h (100%) rename {src => mobile/src}/fpga/KD/pe.hpp (100%) rename {src => mobile/src}/fpga/KD/pe_params.hpp (100%) rename {src => mobile/src}/fpga/KD/pes/concat_pe.hpp (100%) rename {src => mobile/src}/fpga/KD/pes/conv_pe.hpp (100%) rename {src => mobile/src}/fpga/KD/pes/conv_process.hpp (100%) rename {src => mobile/src}/fpga/KD/pes/depthwise_conv_pe.hpp (100%) rename {src => mobile/src}/fpga/KD/pes/elementwise_add_pe.hpp (100%) rename {src => mobile/src}/fpga/KD/pes/fully_connected_pe.hpp (100%) rename {src => mobile/src}/fpga/KD/pes/input_pe.hpp (100%) rename {src => mobile/src}/fpga/KD/pes/math_func_neon.h (100%) rename {src => 
mobile/src}/fpga/KD/pes/output_pe.hpp (100%) rename {src => mobile/src}/fpga/KD/pes/pooling_pe.hpp (100%) rename {src => mobile/src}/fpga/KD/pes/softmax_pe.cpp (100%) rename {src => mobile/src}/fpga/KD/pes/softmax_pe.hpp (100%) rename {src => mobile/src}/fpga/KD/shape.hpp (100%) rename {src => mobile/src}/fpga/KD/tensor.hpp (100%) rename {src => mobile/src}/fpga/KD/tensor_util.cpp (100%) rename {src => mobile/src}/fpga/KD/tensor_util.hpp (100%) rename {src => mobile/src}/fpga/V1/api.cpp (100%) rename {src => mobile/src}/fpga/V1/api.h (100%) rename {src => mobile/src}/fpga/V1/bias_scale.cpp (100%) rename {src => mobile/src}/fpga/V1/bias_scale.h (100%) rename {src => mobile/src}/fpga/V1/deconv_bias_scale.cpp (100%) rename {src => mobile/src}/fpga/V1/deconv_bias_scale.h (100%) rename {src => mobile/src}/fpga/V1/deconv_filter.cpp (100%) rename {src => mobile/src}/fpga/V1/deconv_filter.h (100%) rename {src => mobile/src}/fpga/V1/filter.cpp (100%) rename {src => mobile/src}/fpga/V1/filter.h (100%) rename {src => mobile/src}/fpga/V1/image.cpp (100%) rename {src => mobile/src}/fpga/V1/image.h (100%) rename {src => mobile/src}/fpga/V1/pe.cpp (100%) rename {src => mobile/src}/fpga/V2/api.cpp (100%) rename {src => mobile/src}/fpga/V2/api.h (100%) rename {src => mobile/src}/fpga/V2/bias_scale.cpp (100%) rename {src => mobile/src}/fpga/V2/bias_scale.h (100%) rename {src => mobile/src}/fpga/V2/deconv_bias_scale.cpp (100%) rename {src => mobile/src}/fpga/V2/deconv_bias_scale.h (100%) rename {src => mobile/src}/fpga/V2/deconv_filter.cpp (100%) rename {src => mobile/src}/fpga/V2/deconv_filter.h (100%) rename {src => mobile/src}/fpga/V2/filter.cpp (100%) rename {src => mobile/src}/fpga/V2/filter.h (100%) rename {src => mobile/src}/fpga/V2/image.cpp (100%) rename {src => mobile/src}/fpga/V2/image.h (100%) rename {src => mobile/src}/fpga/V2/pe.cpp (100%) rename {src => mobile/src}/fpga/common/config.h (100%) rename {src => mobile/src}/fpga/common/driver.cpp (100%) rename {src => mobile/src}/fpga/common/driver.h (100%) rename {src => mobile/src}/fpga/common/fpga_common.cpp (100%) rename {src => mobile/src}/fpga/common/fpga_common.h (100%) rename {src => mobile/src}/fpga/common/pe.h (100%) rename {src => mobile/src}/framework/CMakeLists.txt (100%) rename {src => mobile/src}/framework/attribute.cpp (100%) rename {src => mobile/src}/framework/attribute.h (100%) rename {src => mobile/src}/framework/cl/cl_deleter.h (100%) rename {src => mobile/src}/framework/cl/cl_engine.cpp (100%) rename {src => mobile/src}/framework/cl/cl_engine.h (100%) rename {src => mobile/src}/framework/cl/cl_half.cpp (100%) rename {src => mobile/src}/framework/cl/cl_half.h (100%) rename {src => mobile/src}/framework/cl/cl_helper.h (100%) rename {src => mobile/src}/framework/cl/cl_image.cpp (100%) rename {src => mobile/src}/framework/cl/cl_image.h (100%) rename {src => mobile/src}/framework/cl/cl_image_converter.cpp (100%) rename {src => mobile/src}/framework/cl/cl_image_converter.h (100%) rename {src => mobile/src}/framework/cl/cl_scope.h (100%) rename {src => mobile/src}/framework/cl/cl_tensor.h (100%) rename {src => mobile/src}/framework/cl/cl_tool.cpp (100%) rename {src => mobile/src}/framework/cl/cl_tool.h (100%) rename {src => mobile/src}/framework/context.cpp (100%) rename {src => mobile/src}/framework/context.h (100%) rename {src => mobile/src}/framework/data_layout.h (100%) rename {src => mobile/src}/framework/data_type.cpp (100%) rename {src => mobile/src}/framework/data_type.h (100%) rename {src => mobile/src}/framework/ddim.cpp 
(100%) rename {src => mobile/src}/framework/ddim.h (100%) rename {src => mobile/src}/framework/dim.h (100%) rename {src => mobile/src}/framework/executor.cpp (100%) rename {src => mobile/src}/framework/executor.h (100%) rename {src => mobile/src}/framework/framework.pb-c.c (100%) rename {src => mobile/src}/framework/framework.pb-c.h (100%) rename {src => mobile/src}/framework/framework.proto (100%) rename {src => mobile/src}/framework/load_ops.h (100%) rename {src => mobile/src}/framework/loader.cpp (100%) rename {src => mobile/src}/framework/loader.h (100%) rename {src => mobile/src}/framework/lod_tensor.cpp (100%) rename {src => mobile/src}/framework/lod_tensor.h (100%) rename {src => mobile/src}/framework/mixed_vector.h (100%) rename {src => mobile/src}/framework/op_info.h (100%) rename {src => mobile/src}/framework/op_kernel_type.h (100%) rename {src => mobile/src}/framework/op_proto_maker.h (100%) rename {src => mobile/src}/framework/op_registry.h (100%) rename {src => mobile/src}/framework/operator.cpp (100%) rename {src => mobile/src}/framework/operator.h (100%) rename {src => mobile/src}/framework/program/block_desc.cpp (100%) rename {src => mobile/src}/framework/program/block_desc.h (100%) rename {src => mobile/src}/framework/program/op_desc.cpp (100%) rename {src => mobile/src}/framework/program/op_desc.h (100%) rename {src => mobile/src}/framework/program/program-optimize/fusion_op_register.h (100%) rename {src => mobile/src}/framework/program/program-optimize/node.cpp (100%) rename {src => mobile/src}/framework/program/program-optimize/node.h (100%) rename {src => mobile/src}/framework/program/program-optimize/program_optimize.cpp (100%) rename {src => mobile/src}/framework/program/program-optimize/program_optimize.h (100%) rename {src => mobile/src}/framework/program/program.h (100%) rename {src => mobile/src}/framework/program/program_desc.cpp (100%) rename {src => mobile/src}/framework/program/program_desc.h (100%) rename {src => mobile/src}/framework/program/tensor_desc.h (100%) rename {src => mobile/src}/framework/program/var_desc.h (100%) rename {src => mobile/src}/framework/scope.cpp (100%) rename {src => mobile/src}/framework/scope.h (100%) rename {src => mobile/src}/framework/selected_rows.cpp (100%) rename {src => mobile/src}/framework/selected_rows.h (100%) rename {src => mobile/src}/framework/tensor.h (100%) rename {src => mobile/src}/framework/tensor_base.h (100%) rename {src => mobile/src}/framework/tensor_util.cpp (100%) rename {src => mobile/src}/framework/tensor_util.h (100%) rename {src => mobile/src}/framework/type_trait.h (100%) rename {src => mobile/src}/framework/variable.h (100%) rename {src => mobile/src}/framework/zynqmp/ztensor.hpp (100%) rename {src => mobile/src}/io/api.cc (100%) rename {src => mobile/src}/io/api_paddle_mobile.cc (100%) rename {src => mobile/src}/io/api_paddle_mobile.h (100%) rename {src => mobile/src}/io/ios_io/PaddleMobileCPU.h (100%) rename {src => mobile/src}/io/ios_io/PaddleMobileCPU.mm (100%) rename {src => mobile/src}/io/jni/PML.java (100%) rename {src => mobile/src}/io/jni/paddle_mobile_jni.cpp (100%) rename {src => mobile/src}/io/jni/paddle_mobile_jni.h (100%) rename {src => mobile/src}/io/loader.h (100%) rename {src => mobile/src}/io/opencl_interface.cpp (100%) rename {src => mobile/src}/io/opencl_interface.h (100%) rename {src => mobile/src}/io/paddle_inference_api.h (100%) rename {src => mobile/src}/io/paddle_mobile.cpp (100%) rename {src => mobile/src}/io/paddle_mobile.h (100%) rename {src => 
mobile/src}/io/paddle_mobile_wrap.cpp (100%) rename {src => mobile/src}/io/paddle_mobile_wrap.h (99%) rename {src => mobile/src}/io/paddle_test_inference_api.cpp (100%) rename {src => mobile/src}/io/paddle_test_inference_api.h (100%) rename {src => mobile/src}/memory/t_malloc.cpp (100%) rename {src => mobile/src}/memory/t_malloc.h (100%) rename {src => mobile/src}/operators/activation_op.cpp (100%) rename {src => mobile/src}/operators/activation_op.h (100%) rename {src => mobile/src}/operators/assign_op.cpp (100%) rename {src => mobile/src}/operators/assign_op.h (100%) rename {src => mobile/src}/operators/assign_value_op.cpp (100%) rename {src => mobile/src}/operators/assign_value_op.h (100%) rename {src => mobile/src}/operators/batchnorm_op.cpp (100%) rename {src => mobile/src}/operators/batchnorm_op.h (100%) rename {src => mobile/src}/operators/beam_search_decode_op.cpp (100%) rename {src => mobile/src}/operators/beam_search_decode_op.h (100%) rename {src => mobile/src}/operators/beam_search_op.cpp (100%) rename {src => mobile/src}/operators/beam_search_op.h (100%) rename {src => mobile/src}/operators/bilinear_interp_op.cpp (100%) rename {src => mobile/src}/operators/bilinear_interp_op.h (100%) rename {src => mobile/src}/operators/box_coder_op.cpp (100%) rename {src => mobile/src}/operators/box_coder_op.h (100%) rename {src => mobile/src}/operators/cast_op.cpp (100%) rename {src => mobile/src}/operators/cast_op.h (100%) rename {src => mobile/src}/operators/compare_op.cpp (100%) rename {src => mobile/src}/operators/compare_op.h (100%) rename {src => mobile/src}/operators/concat_op.cpp (100%) rename {src => mobile/src}/operators/concat_op.h (100%) rename {src => mobile/src}/operators/conditional_block_op.cpp (100%) rename {src => mobile/src}/operators/conditional_block_op.h (100%) rename {src => mobile/src}/operators/controlflow/tensor_array_read_write_op.cpp (100%) rename {src => mobile/src}/operators/controlflow/tensor_array_read_write_op.h (100%) rename {src => mobile/src}/operators/controlflow/while_op.cpp (100%) rename {src => mobile/src}/operators/controlflow/while_op.h (100%) rename {src => mobile/src}/operators/conv_op.cpp (100%) rename {src => mobile/src}/operators/conv_op.h (100%) rename {src => mobile/src}/operators/conv_transpose_op.cpp (100%) rename {src => mobile/src}/operators/conv_transpose_op.h (100%) rename {src => mobile/src}/operators/crf_op.cpp (100%) rename {src => mobile/src}/operators/crf_op.h (100%) rename {src => mobile/src}/operators/depthwise_conv_op.cpp (100%) rename {src => mobile/src}/operators/depthwise_conv_op.h (100%) rename {src => mobile/src}/operators/dequantize_op.cpp (100%) rename {src => mobile/src}/operators/dequantize_op.h (100%) rename {src => mobile/src}/operators/detection_ops.cpp (100%) rename {src => mobile/src}/operators/detection_ops.h (100%) rename {src => mobile/src}/operators/dropout_op.cpp (100%) rename {src => mobile/src}/operators/dropout_op.h (100%) rename {src => mobile/src}/operators/elementwise_add_op.cpp (100%) rename {src => mobile/src}/operators/elementwise_add_op.h (100%) rename {src => mobile/src}/operators/elementwise_mul_op.cpp (100%) rename {src => mobile/src}/operators/elementwise_mul_op.h (100%) rename {src => mobile/src}/operators/elementwise_sub_op.cpp (100%) rename {src => mobile/src}/operators/elementwise_sub_op.h (100%) rename {src => mobile/src}/operators/exp_op.cpp (100%) rename {src => mobile/src}/operators/exp_op.h (100%) rename {src => mobile/src}/operators/feed_op.cpp (100%) rename {src => 
mobile/src}/operators/feed_op.h (100%) rename {src => mobile/src}/operators/fetch_op.cpp (100%) rename {src => mobile/src}/operators/fetch_op.h (100%) rename {src => mobile/src}/operators/fill_constant_batch_size_like_op.cpp (100%) rename {src => mobile/src}/operators/fill_constant_batch_size_like_op.h (100%) rename {src => mobile/src}/operators/fill_constant_op.cpp (100%) rename {src => mobile/src}/operators/fill_constant_op.h (100%) rename {src => mobile/src}/operators/flatten2_op.cpp (100%) rename {src => mobile/src}/operators/flatten2_op.h (100%) rename {src => mobile/src}/operators/flatten_op.cpp (100%) rename {src => mobile/src}/operators/flatten_op.h (100%) rename {src => mobile/src}/operators/fusion_conv_add_bn_op.cpp (100%) rename {src => mobile/src}/operators/fusion_conv_add_bn_op.h (100%) rename {src => mobile/src}/operators/fusion_conv_add_bn_relu_op.cpp (100%) rename {src => mobile/src}/operators/fusion_conv_add_bn_relu_op.h (100%) rename {src => mobile/src}/operators/fusion_conv_add_op.cpp (100%) rename {src => mobile/src}/operators/fusion_conv_add_op.h (100%) rename {src => mobile/src}/operators/fusion_conv_add_relu_op.cpp (100%) rename {src => mobile/src}/operators/fusion_conv_add_relu_op.h (100%) rename {src => mobile/src}/operators/fusion_conv_bn_add_relu_op.cpp (100%) rename {src => mobile/src}/operators/fusion_conv_bn_add_relu_op.h (100%) rename {src => mobile/src}/operators/fusion_conv_bn_op.cpp (100%) rename {src => mobile/src}/operators/fusion_conv_bn_op.h (100%) rename {src => mobile/src}/operators/fusion_conv_bn_relu_op.cpp (100%) rename {src => mobile/src}/operators/fusion_conv_bn_relu_op.h (100%) rename {src => mobile/src}/operators/fusion_conv_relu_op.cpp (100%) rename {src => mobile/src}/operators/fusion_conv_relu_op.h (100%) rename {src => mobile/src}/operators/fusion_deconv_add_bn_op.cpp (100%) rename {src => mobile/src}/operators/fusion_deconv_add_bn_op.h (100%) rename {src => mobile/src}/operators/fusion_deconv_add_bn_relu_op.cpp (100%) rename {src => mobile/src}/operators/fusion_deconv_add_bn_relu_op.h (100%) rename {src => mobile/src}/operators/fusion_deconv_add_op.cpp (100%) rename {src => mobile/src}/operators/fusion_deconv_add_op.h (100%) rename {src => mobile/src}/operators/fusion_deconv_add_relu_op.cpp (100%) rename {src => mobile/src}/operators/fusion_deconv_add_relu_op.h (100%) rename {src => mobile/src}/operators/fusion_deconv_bn_relu_op.cpp (100%) rename {src => mobile/src}/operators/fusion_deconv_bn_relu_op.h (100%) rename {src => mobile/src}/operators/fusion_deconv_relu_op.cpp (100%) rename {src => mobile/src}/operators/fusion_deconv_relu_op.h (100%) rename {src => mobile/src}/operators/fusion_dequant_add_bn_op.cpp (100%) rename {src => mobile/src}/operators/fusion_dequant_add_bn_op.h (100%) rename {src => mobile/src}/operators/fusion_dequant_add_bn_relu_op.cpp (100%) rename {src => mobile/src}/operators/fusion_dequant_add_bn_relu_op.h (100%) rename {src => mobile/src}/operators/fusion_dequant_add_bn_relu_quant_op.cpp (100%) rename {src => mobile/src}/operators/fusion_dequant_add_bn_relu_quant_op.h (100%) rename {src => mobile/src}/operators/fusion_dequant_bn_op.cpp (100%) rename {src => mobile/src}/operators/fusion_dequant_bn_op.h (100%) rename {src => mobile/src}/operators/fusion_dequant_bn_relu_op.h (100%) rename {src => mobile/src}/operators/fusion_dwconv_bn_relu_op.cpp (100%) rename {src => mobile/src}/operators/fusion_dwconv_bn_relu_op.h (100%) rename {src => mobile/src}/operators/fusion_elementwise_add_relu_op.cpp (100%) rename {src => 
mobile/src}/operators/fusion_elementwise_add_relu_op.h (100%) rename {src => mobile/src}/operators/fusion_fc_op.cpp (100%) rename {src => mobile/src}/operators/fusion_fc_op.h (100%) rename {src => mobile/src}/operators/fusion_fc_relu_op.cpp (100%) rename {src => mobile/src}/operators/fusion_fc_relu_op.h (100%) rename {src => mobile/src}/operators/gru_op.cpp (100%) rename {src => mobile/src}/operators/gru_op.h (100%) rename {src => mobile/src}/operators/gru_unit_op.cpp (100%) rename {src => mobile/src}/operators/gru_unit_op.h (100%) rename {src => mobile/src}/operators/im2sequence_op.cpp (100%) rename {src => mobile/src}/operators/im2sequence_op.h (100%) rename {src => mobile/src}/operators/increment_op.cpp (100%) rename {src => mobile/src}/operators/increment_op.h (100%) rename {src => mobile/src}/operators/is_empty_op.cpp (100%) rename {src => mobile/src}/operators/is_empty_op.h (100%) rename {src => mobile/src}/operators/kernel/activation_kernel.h (100%) rename {src => mobile/src}/operators/kernel/arm/activation_kernel.cpp (100%) rename {src => mobile/src}/operators/kernel/arm/anchor_generator_kernel.cpp (100%) rename {src => mobile/src}/operators/kernel/arm/assign_kernel.cpp (100%) rename {src => mobile/src}/operators/kernel/arm/assign_value_kernel.cpp (100%) rename {src => mobile/src}/operators/kernel/arm/batchnorm_kernel.cpp (100%) rename {src => mobile/src}/operators/kernel/arm/beam_search_decode_kernel.cpp (100%) rename {src => mobile/src}/operators/kernel/arm/beam_search_kernel.cpp (100%) rename {src => mobile/src}/operators/kernel/arm/bilinear_interp_kernel.cpp (100%) rename {src => mobile/src}/operators/kernel/arm/box_coder_kernel.cpp (100%) rename {src => mobile/src}/operators/kernel/arm/cast_kernel.cpp (100%) rename {src => mobile/src}/operators/kernel/arm/compare_kernel.cpp (100%) rename {src => mobile/src}/operators/kernel/arm/concat_kernel.cpp (100%) rename {src => mobile/src}/operators/kernel/arm/conditional_block_kernel.cpp (100%) rename {src => mobile/src}/operators/kernel/arm/convolution/conv_add_bn_relu_kernel.cpp (100%) rename {src => mobile/src}/operators/kernel/arm/convolution/conv_add_kernel.cpp (100%) rename {src => mobile/src}/operators/kernel/arm/convolution/conv_add_relu_kernel.cpp (100%) rename {src => mobile/src}/operators/kernel/arm/convolution/conv_bn_add_relu_kernel.cpp (100%) rename {src => mobile/src}/operators/kernel/arm/convolution/conv_bn_relu_kernel.cpp (100%) rename {src => mobile/src}/operators/kernel/arm/convolution/conv_common.cpp (100%) rename {src => mobile/src}/operators/kernel/arm/convolution/conv_common.h (100%) rename {src => mobile/src}/operators/kernel/arm/convolution/conv_kernel.cpp (100%) rename {src => mobile/src}/operators/kernel/arm/convolution/conv_relu_kernel.cpp (100%) rename {src => mobile/src}/operators/kernel/arm/convolution/conv_transpose_kernel.cpp (100%) rename {src => mobile/src}/operators/kernel/arm/convolution/dwconv_bn_relu_kernel.cpp (100%) rename {src => mobile/src}/operators/kernel/arm/crf_kernel.cpp (100%) rename {src => mobile/src}/operators/kernel/arm/density_prior_box_kernel.cpp (100%) rename {src => mobile/src}/operators/kernel/arm/dequantize_bn_kernel.cpp (100%) rename {src => mobile/src}/operators/kernel/arm/dequantize_kernel.cpp (100%) rename {src => mobile/src}/operators/kernel/arm/dropout_kernel.cpp (100%) rename {src => mobile/src}/operators/kernel/arm/elementwise_add_kernel.cpp (100%) rename {src => mobile/src}/operators/kernel/arm/elementwise_mul_kernel.cpp (100%) rename {src => 
mobile/src}/operators/kernel/arm/elementwise_sub_kernel.cpp (100%) rename {src => mobile/src}/operators/kernel/arm/exp_kernel.cpp (100%) rename {src => mobile/src}/operators/kernel/arm/feed_kernel.cpp (100%) rename {src => mobile/src}/operators/kernel/arm/fetch_kernel.cpp (100%) rename {src => mobile/src}/operators/kernel/arm/flatten_kernel.cpp (100%) rename {src => mobile/src}/operators/kernel/arm/fusion_fc_kernel.cpp (100%) rename {src => mobile/src}/operators/kernel/arm/gru_kernel.cpp (100%) rename {src => mobile/src}/operators/kernel/arm/gru_unit_kernel.cpp (100%) rename {src => mobile/src}/operators/kernel/arm/im2sequence_kernel.cpp (100%) rename {src => mobile/src}/operators/kernel/arm/increment_kernel.cpp (100%) rename {src => mobile/src}/operators/kernel/arm/is_empty_kernel.cpp (100%) rename {src => mobile/src}/operators/kernel/arm/lod_reset_kernel.cpp (100%) rename {src => mobile/src}/operators/kernel/arm/logical_kernel.cpp (100%) rename {src => mobile/src}/operators/kernel/arm/lookup_kernel.cpp (100%) rename {src => mobile/src}/operators/kernel/arm/lrn_kernel.cpp (100%) rename {src => mobile/src}/operators/kernel/arm/mul_kernel.cpp (100%) rename {src => mobile/src}/operators/kernel/arm/multiclass_nms_kernel.cpp (100%) rename {src => mobile/src}/operators/kernel/arm/nearest_interp_kernel.cpp (100%) rename {src => mobile/src}/operators/kernel/arm/norm_kernel.cpp (100%) rename {src => mobile/src}/operators/kernel/arm/one_hot_kernel.cpp (100%) rename {src => mobile/src}/operators/kernel/arm/pad2d_kernel.cpp (100%) rename {src => mobile/src}/operators/kernel/arm/polygon_box_transform_kernel.cpp (100%) rename {src => mobile/src}/operators/kernel/arm/pool_kernel.cpp (100%) rename {src => mobile/src}/operators/kernel/arm/prelu_kernel.cpp (100%) rename {src => mobile/src}/operators/kernel/arm/prior_box_kernel.cpp (100%) rename {src => mobile/src}/operators/kernel/arm/proposal_kernel.cpp (100%) rename {src => mobile/src}/operators/kernel/arm/psroi_pool_kernel.cpp (100%) rename {src => mobile/src}/operators/kernel/arm/quantize_kernel.cpp (100%) rename {src => mobile/src}/operators/kernel/arm/reshape2_kernel.cpp (100%) rename {src => mobile/src}/operators/kernel/arm/reshape_kernel.cpp (100%) rename {src => mobile/src}/operators/kernel/arm/resize_kernel.cpp (100%) rename {src => mobile/src}/operators/kernel/arm/roi_perspective_kernel.cpp (100%) rename {src => mobile/src}/operators/kernel/arm/scale_kernel.cpp (100%) rename {src => mobile/src}/operators/kernel/arm/sequence_expand_kernel.cpp (100%) rename {src => mobile/src}/operators/kernel/arm/sequence_pool_kernel.cpp (100%) rename {src => mobile/src}/operators/kernel/arm/sequence_softmax_kernel.cpp (100%) rename {src => mobile/src}/operators/kernel/arm/shape_kernel.cpp (100%) rename {src => mobile/src}/operators/kernel/arm/slice_kernel.cpp (100%) rename {src => mobile/src}/operators/kernel/arm/softmax_kernel.cpp (100%) rename {src => mobile/src}/operators/kernel/arm/split_kernel.cpp (100%) rename {src => mobile/src}/operators/kernel/arm/sum_kernel.cpp (100%) rename {src => mobile/src}/operators/kernel/arm/tensor_array_read_write_kernel.cpp (100%) rename {src => mobile/src}/operators/kernel/arm/top_k_kernel.cpp (100%) rename {src => mobile/src}/operators/kernel/arm/transpose2_kernel.cpp (100%) rename {src => mobile/src}/operators/kernel/arm/transpose_kernel.cpp (100%) rename {src => mobile/src}/operators/kernel/arm/while_kernel.cpp (100%) rename {src => mobile/src}/operators/kernel/assign_kernel.h (100%) rename {src => 
mobile/src}/operators/kernel/assign_value_kernel.h (100%) rename {src => mobile/src}/operators/kernel/batchnorm_kernel.h (100%) rename {src => mobile/src}/operators/kernel/beam_search_decode_kernel.h (100%) rename {src => mobile/src}/operators/kernel/beam_search_kernel.h (100%) rename {src => mobile/src}/operators/kernel/bilinear_interp_kernel.h (100%) rename {src => mobile/src}/operators/kernel/box_coder_kernel.h (100%) rename {src => mobile/src}/operators/kernel/central-arm-func/activation_arm_func.h (100%) rename {src => mobile/src}/operators/kernel/central-arm-func/batchnorm_arm_func.h (100%) rename {src => mobile/src}/operators/kernel/central-arm-func/bilinear_interp_arm_func.h (100%) rename {src => mobile/src}/operators/kernel/central-arm-func/box_coder_arm_func.h (100%) rename {src => mobile/src}/operators/kernel/central-arm-func/concat_arm_func.h (100%) rename {src => mobile/src}/operators/kernel/central-arm-func/conv_add_arm_func.h (100%) rename {src => mobile/src}/operators/kernel/central-arm-func/conv_add_bn_relu_arm_func.h (100%) rename {src => mobile/src}/operators/kernel/central-arm-func/conv_add_relu_arm_func.h (100%) rename {src => mobile/src}/operators/kernel/central-arm-func/conv_arm_func.cpp (100%) rename {src => mobile/src}/operators/kernel/central-arm-func/conv_arm_func.h (100%) rename {src => mobile/src}/operators/kernel/central-arm-func/conv_bn_add_relu_arm_func.h (100%) rename {src => mobile/src}/operators/kernel/central-arm-func/conv_bn_relu_arm_func.h (100%) rename {src => mobile/src}/operators/kernel/central-arm-func/conv_transpose_arm_func.h (100%) rename {src => mobile/src}/operators/kernel/central-arm-func/crf_arm_func.h (100%) rename {src => mobile/src}/operators/kernel/central-arm-func/density_prior_box_arm_func.h (100%) rename {src => mobile/src}/operators/kernel/central-arm-func/dwconv_bn_relu_arm_func.h (100%) rename {src => mobile/src}/operators/kernel/central-arm-func/elementwise_add_arm_func.h (100%) rename {src => mobile/src}/operators/kernel/central-arm-func/elementwise_mul_arm_func.h (100%) rename {src => mobile/src}/operators/kernel/central-arm-func/elementwise_sub_arm_func.h (100%) rename {src => mobile/src}/operators/kernel/central-arm-func/flatten_arm_func.h (100%) rename {src => mobile/src}/operators/kernel/central-arm-func/fusion_fc_arm_func.h (100%) rename {src => mobile/src}/operators/kernel/central-arm-func/gru_arm_func.h (100%) rename {src => mobile/src}/operators/kernel/central-arm-func/gru_unit_arm_func.h (100%) rename {src => mobile/src}/operators/kernel/central-arm-func/increment_arm_func.h (100%) rename {src => mobile/src}/operators/kernel/central-arm-func/lookup_arm_func.h (100%) rename {src => mobile/src}/operators/kernel/central-arm-func/lrn_arm_func.h (100%) rename {src => mobile/src}/operators/kernel/central-arm-func/mul_arm_func.h (100%) rename {src => mobile/src}/operators/kernel/central-arm-func/multiclass_nms_arm_func.h (100%) rename {src => mobile/src}/operators/kernel/central-arm-func/norm_arm_func.h (100%) rename {src => mobile/src}/operators/kernel/central-arm-func/polygon_box_transform_arm_func.h (100%) rename {src => mobile/src}/operators/kernel/central-arm-func/pool_arm_func.h (100%) rename {src => mobile/src}/operators/kernel/central-arm-func/prior_box_arm_func.h (100%) rename {src => mobile/src}/operators/kernel/central-arm-func/reshape2_arm_func.h (100%) rename {src => mobile/src}/operators/kernel/central-arm-func/reshape_arm_func.h (100%) rename {src => mobile/src}/operators/kernel/central-arm-func/shape_arm_func.h 
(100%) rename {src => mobile/src}/operators/kernel/central-arm-func/softmax_arm_func.h (100%) rename {src => mobile/src}/operators/kernel/central-arm-func/split_arm_func.h (100%) rename {src => mobile/src}/operators/kernel/central-arm-func/sum_arm_func.h (100%) rename {src => mobile/src}/operators/kernel/central-arm-func/transpose_arm_func.h (100%) rename {src => mobile/src}/operators/kernel/cl/batchnorm_kernel.cpp (100%) rename {src => mobile/src}/operators/kernel/cl/box_coder_kernel.cpp (100%) rename {src => mobile/src}/operators/kernel/cl/cl-kernel-func/conv_func.cpp (100%) rename {src => mobile/src}/operators/kernel/cl/cl-kernel-func/conv_func.h (100%) rename {src => mobile/src}/operators/kernel/cl/cl_kernel/batchnorm_kernel.cl (100%) rename {src => mobile/src}/operators/kernel/cl/cl_kernel/box_coder_kernel.cl (100%) rename {src => mobile/src}/operators/kernel/cl/cl_kernel/channel_add_kernel.cl (100%) rename {src => mobile/src}/operators/kernel/cl/cl_kernel/cl_common.h (100%) rename {src => mobile/src}/operators/kernel/cl/cl_kernel/concat_kernel.cl (100%) rename {src => mobile/src}/operators/kernel/cl/cl_kernel/conv_kernel.cl (100%) rename {src => mobile/src}/operators/kernel/cl/cl_kernel/conv_kernel.inc.cl (100%) rename {src => mobile/src}/operators/kernel/cl/cl_kernel/density_prior_box_kernel.cl (100%) rename {src => mobile/src}/operators/kernel/cl/cl_kernel/depthwise_conv_add_bn_relu_kernel.cl (100%) rename {src => mobile/src}/operators/kernel/cl/cl_kernel/depthwise_conv_kernel.cl (100%) rename {src => mobile/src}/operators/kernel/cl/cl_kernel/dropout_kernel.cl (100%) rename {src => mobile/src}/operators/kernel/cl/cl_kernel/elementwise_add_kernel.cl (100%) rename {src => mobile/src}/operators/kernel/cl/cl_kernel/exp_kernel.cl (100%) rename {src => mobile/src}/operators/kernel/cl/cl_kernel/feed_kernel.cl (100%) rename {src => mobile/src}/operators/kernel/cl/cl_kernel/fetch_kernel.cl (100%) rename {src => mobile/src}/operators/kernel/cl/cl_kernel/flatten2_kernel.cl (100%) rename {src => mobile/src}/operators/kernel/cl/cl_kernel/leakyrelu_kernel.cl (100%) rename {src => mobile/src}/operators/kernel/cl/cl_kernel/lrn_kernel.cl (100%) rename {src => mobile/src}/operators/kernel/cl/cl_kernel/nearest_interp_kernel.cl (100%) rename {src => mobile/src}/operators/kernel/cl/cl_kernel/pool_kernel.cl (100%) rename {src => mobile/src}/operators/kernel/cl/cl_kernel/prior_box_kernel.cl (100%) rename {src => mobile/src}/operators/kernel/cl/cl_kernel/relu.cl (100%) rename {src => mobile/src}/operators/kernel/cl/cl_kernel/relu6.cl (100%) rename {src => mobile/src}/operators/kernel/cl/cl_kernel/reshape.cl (100%) rename {src => mobile/src}/operators/kernel/cl/cl_kernel/scale_kernel.cl (100%) rename {src => mobile/src}/operators/kernel/cl/cl_kernel/sigmoid.cl (100%) rename {src => mobile/src}/operators/kernel/cl/cl_kernel/slice_kernel.cl (100%) rename {src => mobile/src}/operators/kernel/cl/cl_kernel/softmax.cl (100%) rename {src => mobile/src}/operators/kernel/cl/cl_kernel/transpose_kernel.cl (100%) rename {src => mobile/src}/operators/kernel/cl/concat_kernel.cpp (100%) rename {src => mobile/src}/operators/kernel/cl/conv_add_bn_relu_kernel.cpp (100%) rename {src => mobile/src}/operators/kernel/cl/conv_add_kernel.cpp (100%) rename {src => mobile/src}/operators/kernel/cl/conv_add_relu_kernel.cpp (100%) rename {src => mobile/src}/operators/kernel/cl/conv_bn_add_relu_kernel.cpp (100%) rename {src => mobile/src}/operators/kernel/cl/conv_bn_relu_kernel.cpp (100%) rename {src => 
mobile/src}/operators/kernel/cl/conv_kernel.cpp (100%) rename {src => mobile/src}/operators/kernel/cl/conv_relu_kernel.cpp (100%) rename {src => mobile/src}/operators/kernel/cl/density_prior_box_kernel.cpp (100%) rename {src => mobile/src}/operators/kernel/cl/depthwise_conv_kernel.cpp (100%) rename {src => mobile/src}/operators/kernel/cl/dropout_kernel.cpp (100%) rename {src => mobile/src}/operators/kernel/cl/dwconv_bn_relu_kernel.cpp (100%) rename {src => mobile/src}/operators/kernel/cl/elementwise_add_kernel.cpp (100%) rename {src => mobile/src}/operators/kernel/cl/exp_kernel.cpp (100%) rename {src => mobile/src}/operators/kernel/cl/feed_kernel.cpp (100%) rename {src => mobile/src}/operators/kernel/cl/fetch_kernel.cpp (100%) rename {src => mobile/src}/operators/kernel/cl/flatten2_kernel.cpp (100%) rename {src => mobile/src}/operators/kernel/cl/fusion_fc_kernel.cpp (100%) rename {src => mobile/src}/operators/kernel/cl/leakyrelu_kernel.cpp (100%) rename {src => mobile/src}/operators/kernel/cl/lrn_kernel.cpp (100%) rename {src => mobile/src}/operators/kernel/cl/mul_kernel.cpp (100%) rename {src => mobile/src}/operators/kernel/cl/multiclass_nms_kernel.cpp (100%) rename {src => mobile/src}/operators/kernel/cl/nearest_interp_kernel.cpp (100%) rename {src => mobile/src}/operators/kernel/cl/pool_kernel.cpp (100%) rename {src => mobile/src}/operators/kernel/cl/prior_box_kernel.cpp (100%) rename {src => mobile/src}/operators/kernel/cl/relu6_kernel.cpp (100%) rename {src => mobile/src}/operators/kernel/cl/relu_kernel.cpp (100%) rename {src => mobile/src}/operators/kernel/cl/reshape2_kernel.cpp (100%) rename {src => mobile/src}/operators/kernel/cl/reshape_kernel.cpp (100%) rename {src => mobile/src}/operators/kernel/cl/scale_kernel.cpp (100%) rename {src => mobile/src}/operators/kernel/cl/sigmoid_kernel.cpp (100%) rename {src => mobile/src}/operators/kernel/cl/slice_kernel.cpp (100%) rename {src => mobile/src}/operators/kernel/cl/softmax_kernel.cpp (100%) rename {src => mobile/src}/operators/kernel/cl/split_kernel.cpp (100%) rename {src => mobile/src}/operators/kernel/cl/transpose2_kernel.cpp (100%) rename {src => mobile/src}/operators/kernel/cl/transpose_kernel.cpp (100%) rename {src => mobile/src}/operators/kernel/compare_kernel.h (100%) rename {src => mobile/src}/operators/kernel/concat_kernel.h (100%) rename {src => mobile/src}/operators/kernel/conditional_block_kernel.h (100%) rename {src => mobile/src}/operators/kernel/conv_add_bn_kernel.h (100%) rename {src => mobile/src}/operators/kernel/conv_add_bn_relu_kernel.h (100%) rename {src => mobile/src}/operators/kernel/conv_add_kernel.h (100%) rename {src => mobile/src}/operators/kernel/conv_add_relu_kernel.h (100%) rename {src => mobile/src}/operators/kernel/conv_bn_add_relu_kernel.h (100%) rename {src => mobile/src}/operators/kernel/conv_bn_kernel.h (100%) rename {src => mobile/src}/operators/kernel/conv_bn_relu_kernel.h (100%) rename {src => mobile/src}/operators/kernel/conv_kernel.h (100%) rename {src => mobile/src}/operators/kernel/conv_relu_kernel.h (100%) rename {src => mobile/src}/operators/kernel/conv_transpose_kernel.h (100%) rename {src => mobile/src}/operators/kernel/crf_kernel.h (100%) rename {src => mobile/src}/operators/kernel/deconv_add_bn_kernel.h (100%) rename {src => mobile/src}/operators/kernel/deconv_add_bn_relu_kernel.h (100%) rename {src => mobile/src}/operators/kernel/deconv_add_kernel.h (100%) rename {src => mobile/src}/operators/kernel/deconv_add_relu_kernel.h (100%) rename {src => 
mobile/src}/operators/kernel/deconv_bn_relu_kernel.h (100%) rename {src => mobile/src}/operators/kernel/deconv_relu_kernel.h (100%) rename {src => mobile/src}/operators/kernel/dequant_bn_kernel.h (100%) rename {src => mobile/src}/operators/kernel/dequantize_kernel.h (100%) rename {src => mobile/src}/operators/kernel/detection_kernel.h (100%) rename {src => mobile/src}/operators/kernel/dropout_kernel.h (100%) rename {src => mobile/src}/operators/kernel/dwconv_bn_relu_kernel.h (100%) rename {src => mobile/src}/operators/kernel/elementwise_add_kernel.h (100%) rename {src => mobile/src}/operators/kernel/elementwise_add_relu_kernel.h (100%) rename {src => mobile/src}/operators/kernel/elementwise_mul_kernel.h (100%) rename {src => mobile/src}/operators/kernel/elementwise_sub_kernel.h (100%) rename {src => mobile/src}/operators/kernel/exp_kernel.h (100%) rename {src => mobile/src}/operators/kernel/fc_relu_kernel.h (100%) rename {src => mobile/src}/operators/kernel/feed_kernel.h (100%) rename {src => mobile/src}/operators/kernel/fetch_kernel.h (100%) rename {src => mobile/src}/operators/kernel/flatten2_kernel.h (100%) rename {src => mobile/src}/operators/kernel/flatten_kernel.h (100%) rename {src => mobile/src}/operators/kernel/fpga/KD/conv_add_bn_kernel.cpp (100%) rename {src => mobile/src}/operators/kernel/fpga/KD/conv_add_kernel.cpp (100%) rename {src => mobile/src}/operators/kernel/fpga/KD/conv_add_relu_kernel.cpp (100%) rename {src => mobile/src}/operators/kernel/fpga/KD/conv_bn_kernel.cpp (100%) rename {src => mobile/src}/operators/kernel/fpga/KD/conv_bn_relu_kernel.cpp (100%) rename {src => mobile/src}/operators/kernel/fpga/KD/elementwise_add_relu_kernel.cpp (100%) rename {src => mobile/src}/operators/kernel/fpga/KD/feed_kernel.cpp (100%) rename {src => mobile/src}/operators/kernel/fpga/KD/fetch_kernel.cpp (100%) rename {src => mobile/src}/operators/kernel/fpga/KD/fusion_fc_kernel.cpp (100%) rename {src => mobile/src}/operators/kernel/fpga/KD/pool_kernel.cpp (100%) rename {src => mobile/src}/operators/kernel/fpga/KD/softmax_kernel.cpp (100%) rename {src => mobile/src}/operators/kernel/fpga/V1/anchor_generator_kernel.cpp (100%) rename {src => mobile/src}/operators/kernel/fpga/V1/concat_kernel.cpp (100%) rename {src => mobile/src}/operators/kernel/fpga/V1/conv_add_bn_kernel.cpp (100%) rename {src => mobile/src}/operators/kernel/fpga/V1/conv_add_bn_relu_kernel.cpp (100%) rename {src => mobile/src}/operators/kernel/fpga/V1/conv_add_kernel.cpp (100%) rename {src => mobile/src}/operators/kernel/fpga/V1/conv_add_relu_kernel.cpp (100%) rename {src => mobile/src}/operators/kernel/fpga/V1/conv_bn_kernel.cpp (100%) rename {src => mobile/src}/operators/kernel/fpga/V1/conv_bn_relu_kernel.cpp (100%) rename {src => mobile/src}/operators/kernel/fpga/V1/conv_kernel.cpp (100%) rename {src => mobile/src}/operators/kernel/fpga/V1/conv_transpose_kernel.cpp (100%) rename {src => mobile/src}/operators/kernel/fpga/V1/deconv_add_bn_kernel.cpp (100%) rename {src => mobile/src}/operators/kernel/fpga/V1/deconv_add_bn_relu_kernel.cpp (100%) rename {src => mobile/src}/operators/kernel/fpga/V1/deconv_add_kernel.cpp (100%) rename {src => mobile/src}/operators/kernel/fpga/V1/deconv_add_relu_kernel.cpp (100%) rename {src => mobile/src}/operators/kernel/fpga/V1/deconv_bn_relu_kernel.cpp (100%) rename {src => mobile/src}/operators/kernel/fpga/V1/dropout_kernel.cpp (100%) rename {src => mobile/src}/operators/kernel/fpga/V1/elementwise_add_kernel.cpp (100%) rename {src => 
mobile/src}/operators/kernel/fpga/V1/elementwise_add_relu_kernel.cpp (100%) rename {src => mobile/src}/operators/kernel/fpga/V1/elementwise_mul_kernel.cpp (100%) rename {src => mobile/src}/operators/kernel/fpga/V1/feed_kernel.cpp (100%) rename {src => mobile/src}/operators/kernel/fpga/V1/fetch_kernel.cpp (100%) rename {src => mobile/src}/operators/kernel/fpga/V1/fusion_fc_kernel.cpp (100%) rename {src => mobile/src}/operators/kernel/fpga/V1/fusion_fc_relu_kernel.cpp (100%) rename {src => mobile/src}/operators/kernel/fpga/V1/pad2d_kernel.cpp (100%) rename {src => mobile/src}/operators/kernel/fpga/V1/pool_kernel.cpp (100%) rename {src => mobile/src}/operators/kernel/fpga/V1/proposal_kernel.cpp (100%) rename {src => mobile/src}/operators/kernel/fpga/V1/psroi_pool_kernel.cpp (100%) rename {src => mobile/src}/operators/kernel/fpga/V1/relu_kernel.cpp (100%) rename {src => mobile/src}/operators/kernel/fpga/V1/reshape2_kernel.cpp (100%) rename {src => mobile/src}/operators/kernel/fpga/V1/reshape_kernel.cpp (100%) rename {src => mobile/src}/operators/kernel/fpga/V1/roialign_pool_kernel.cpp (100%) rename {src => mobile/src}/operators/kernel/fpga/V1/sigmoid_kernel.cpp (100%) rename {src => mobile/src}/operators/kernel/fpga/V1/slice_kernel.cpp (100%) rename {src => mobile/src}/operators/kernel/fpga/V1/softmax_kernel.cpp (100%) rename {src => mobile/src}/operators/kernel/fpga/V1/split_kernel.cpp (100%) rename {src => mobile/src}/operators/kernel/fpga/V1/tanh_kernel.cpp (100%) rename {src => mobile/src}/operators/kernel/fpga/V1/transpose2_kernel.cpp (100%) rename {src => mobile/src}/operators/kernel/fpga/V2/anchor_generator_kernel.cpp (100%) rename {src => mobile/src}/operators/kernel/fpga/V2/concat_kernel.cpp (100%) rename {src => mobile/src}/operators/kernel/fpga/V2/conv_add_bn_kernel.cpp (100%) rename {src => mobile/src}/operators/kernel/fpga/V2/conv_add_bn_relu_kernel.cpp (100%) rename {src => mobile/src}/operators/kernel/fpga/V2/conv_add_kernel.cpp (100%) rename {src => mobile/src}/operators/kernel/fpga/V2/conv_add_relu_kernel.cpp (100%) rename {src => mobile/src}/operators/kernel/fpga/V2/conv_bn_kernel.cpp (100%) rename {src => mobile/src}/operators/kernel/fpga/V2/conv_bn_relu_kernel.cpp (100%) rename {src => mobile/src}/operators/kernel/fpga/V2/conv_kernel.cpp (100%) rename {src => mobile/src}/operators/kernel/fpga/V2/conv_transpose_kernel.cpp (100%) rename {src => mobile/src}/operators/kernel/fpga/V2/deconv_add_bn_kernel.cpp (100%) rename {src => mobile/src}/operators/kernel/fpga/V2/deconv_add_bn_relu_kernel.cpp (100%) rename {src => mobile/src}/operators/kernel/fpga/V2/deconv_add_kernel.cpp (100%) rename {src => mobile/src}/operators/kernel/fpga/V2/deconv_add_relu_kernel.cpp (100%) rename {src => mobile/src}/operators/kernel/fpga/V2/deconv_bn_relu_kernel.cpp (100%) rename {src => mobile/src}/operators/kernel/fpga/V2/dropout_kernel.cpp (100%) rename {src => mobile/src}/operators/kernel/fpga/V2/elementwise_add_kernel.cpp (100%) rename {src => mobile/src}/operators/kernel/fpga/V2/elementwise_add_relu_kernel.cpp (100%) rename {src => mobile/src}/operators/kernel/fpga/V2/elementwise_mul_kernel.cpp (100%) rename {src => mobile/src}/operators/kernel/fpga/V2/feed_kernel.cpp (100%) rename {src => mobile/src}/operators/kernel/fpga/V2/fetch_kernel.cpp (100%) rename {src => mobile/src}/operators/kernel/fpga/V2/fusion_fc_kernel.cpp (100%) rename {src => mobile/src}/operators/kernel/fpga/V2/fusion_fc_relu_kernel.cpp (100%) rename {src => mobile/src}/operators/kernel/fpga/V2/pool_kernel.cpp (100%) rename {src 
=> mobile/src}/operators/kernel/fpga/V2/proposal_kernel.cpp (100%) rename {src => mobile/src}/operators/kernel/fpga/V2/psroi_pool_kernel.cpp (100%) rename {src => mobile/src}/operators/kernel/fpga/V2/relu_kernel.cpp (100%) rename {src => mobile/src}/operators/kernel/fpga/V2/reshape2_kernel.cpp (100%) rename {src => mobile/src}/operators/kernel/fpga/V2/reshape_kernel.cpp (100%) rename {src => mobile/src}/operators/kernel/fpga/V2/roialign_pool_kernel.cpp (100%) rename {src => mobile/src}/operators/kernel/fpga/V2/sigmoid_kernel.cpp (100%) rename {src => mobile/src}/operators/kernel/fpga/V2/slice_kernel.cpp (100%) rename {src => mobile/src}/operators/kernel/fpga/V2/softmax_kernel.cpp (100%) rename {src => mobile/src}/operators/kernel/fpga/V2/split_kernel.cpp (100%) rename {src => mobile/src}/operators/kernel/fpga/V2/tanh_kernel.cpp (100%) rename {src => mobile/src}/operators/kernel/fpga/V2/transpose2_kernel.cpp (100%) rename {src => mobile/src}/operators/kernel/fusion_fc_kernel.h (100%) rename {src => mobile/src}/operators/kernel/gru_kernel.h (100%) rename {src => mobile/src}/operators/kernel/gru_unit_kernel.h (100%) rename {src => mobile/src}/operators/kernel/im2sequence_kernel.h (100%) rename {src => mobile/src}/operators/kernel/increment_kernel.h (100%) rename {src => mobile/src}/operators/kernel/is_empty_kernel.h (100%) rename {src => mobile/src}/operators/kernel/kernels.h (100%) rename {src => mobile/src}/operators/kernel/logical_kernel.h (100%) rename {src => mobile/src}/operators/kernel/lookup_kernel.h (100%) rename {src => mobile/src}/operators/kernel/lrn_kernel.h (100%) rename {src => mobile/src}/operators/kernel/mul_kernel.h (100%) rename {src => mobile/src}/operators/kernel/multiclass_nms_kernel.h (100%) rename {src => mobile/src}/operators/kernel/nearest_interp_kernel.h (100%) rename {src => mobile/src}/operators/kernel/norm_kernel.h (100%) rename {src => mobile/src}/operators/kernel/one_hot_kernel.h (100%) rename {src => mobile/src}/operators/kernel/pad2d_kernel.h (100%) rename {src => mobile/src}/operators/kernel/polygon_box_transform_kernel.h (100%) rename {src => mobile/src}/operators/kernel/pool_kernel.h (100%) rename {src => mobile/src}/operators/kernel/prelu_kernel.h (100%) rename {src => mobile/src}/operators/kernel/prior_box_kernel.h (100%) rename {src => mobile/src}/operators/kernel/quantize_kernel.h (100%) rename {src => mobile/src}/operators/kernel/range_kernel.cpp (100%) rename {src => mobile/src}/operators/kernel/range_kernel.h (100%) rename {src => mobile/src}/operators/kernel/reduce_prod_kernel.cpp (100%) rename {src => mobile/src}/operators/kernel/reduce_prod_kernel.h (100%) rename {src => mobile/src}/operators/kernel/reshape2_kernel.h (100%) rename {src => mobile/src}/operators/kernel/reshape_kernel.h (100%) rename {src => mobile/src}/operators/kernel/resize_kernel.h (100%) rename {src => mobile/src}/operators/kernel/scale_kernel.h (100%) rename {src => mobile/src}/operators/kernel/sequence_kernels.h (100%) rename {src => mobile/src}/operators/kernel/shape_kernel.h (100%) rename {src => mobile/src}/operators/kernel/slice_kernel.h (100%) rename {src => mobile/src}/operators/kernel/softmax_kernel.h (100%) rename {src => mobile/src}/operators/kernel/split_kernel.h (100%) rename {src => mobile/src}/operators/kernel/sum_kernel.h (100%) rename {src => mobile/src}/operators/kernel/tanh_kernel.h (100%) rename {src => mobile/src}/operators/kernel/tensor_array_read_write_kernel.h (100%) rename {src => mobile/src}/operators/kernel/transpose2_kernel.h (100%) rename {src => 
mobile/src}/operators/kernel/transpose_kernel.h (100%) rename {src => mobile/src}/operators/kernel/while_kernel.h (100%) rename {src => mobile/src}/operators/lod_reset_op.cpp (100%) rename {src => mobile/src}/operators/lod_reset_op.h (100%) rename {src => mobile/src}/operators/logical_op.cpp (100%) rename {src => mobile/src}/operators/logical_op.h (100%) rename {src => mobile/src}/operators/lookup_op.cpp (100%) rename {src => mobile/src}/operators/lookup_op.h (100%) rename {src => mobile/src}/operators/lrn_op.cpp (100%) rename {src => mobile/src}/operators/lrn_op.h (100%) rename {src => mobile/src}/operators/math/activation.h (100%) rename {src => mobile/src}/operators/math/depthwise/faster_depthwise_conv3x3.h (100%) rename {src => mobile/src}/operators/math/depthwise/faster_depthwise_conv3x3p1.cpp (100%) rename {src => mobile/src}/operators/math/depthwise_conv3x3.cpp (100%) rename {src => mobile/src}/operators/math/depthwise_conv3x3.h (100%) rename {src => mobile/src}/operators/math/depthwise_conv3x3_int8.cpp (100%) rename {src => mobile/src}/operators/math/depthwise_conv5x5.cpp (100%) rename {src => mobile/src}/operators/math/depthwise_conv5x5.h (100%) rename {src => mobile/src}/operators/math/depthwise_conv5x5_int8.cpp (100%) rename {src => mobile/src}/operators/math/element_wise.h (100%) rename {src => mobile/src}/operators/math/elementwise_op_function.h (100%) rename {src => mobile/src}/operators/math/gemm.cpp (100%) rename {src => mobile/src}/operators/math/gemm.h (100%) rename {src => mobile/src}/operators/math/gemm/cblas.cc (100%) rename {src => mobile/src}/operators/math/gemm/cblas.h (100%) rename {src => mobile/src}/operators/math/gemm/executor.h (100%) rename {src => mobile/src}/operators/math/gemm/gemm1x1s1.cpp (100%) rename {src => mobile/src}/operators/math/gemm/gemm1x1s1.h (100%) rename {src => mobile/src}/operators/math/gemm/gemm_kernel.h (100%) rename {src => mobile/src}/operators/math/gemm/pack_kernel.h (100%) rename {src => mobile/src}/operators/math/gemm/strategy.h (100%) rename {src => mobile/src}/operators/math/gemm_int8.cpp (100%) rename {src => mobile/src}/operators/math/gemm_omp_int8.cpp (100%) rename {src => mobile/src}/operators/math/gpc.cpp (100%) rename {src => mobile/src}/operators/math/gpc.h (100%) rename {src => mobile/src}/operators/math/gru_compute.cpp (100%) rename {src => mobile/src}/operators/math/gru_compute.h (100%) rename {src => mobile/src}/operators/math/gru_cpu_kernel.h (100%) rename {src => mobile/src}/operators/math/im2col.cpp (100%) rename {src => mobile/src}/operators/math/im2col.h (100%) rename {src => mobile/src}/operators/math/math.h (100%) rename {src => mobile/src}/operators/math/math_function.cpp (100%) rename {src => mobile/src}/operators/math/math_function.h (100%) rename {src => mobile/src}/operators/math/math_function_int8.cpp (100%) rename {src => mobile/src}/operators/math/pad.cpp (100%) rename {src => mobile/src}/operators/math/pad.h (100%) rename {src => mobile/src}/operators/math/poly_util.cpp (100%) rename {src => mobile/src}/operators/math/poly_util.h (100%) rename {src => mobile/src}/operators/math/pooling.cpp (100%) rename {src => mobile/src}/operators/math/pooling.h (100%) rename {src => mobile/src}/operators/math/pooling2x2.cpp (100%) rename {src => mobile/src}/operators/math/pooling3x3.cpp (100%) rename {src => mobile/src}/operators/math/quantize.h (100%) rename {src => mobile/src}/operators/math/selected_rows_functor.h (100%) rename {src => mobile/src}/operators/math/sequence2batch.cpp (100%) rename {src => 
mobile/src}/operators/math/sequence2batch.h (100%) rename {src => mobile/src}/operators/math/slidingwindow_conv3x3.cpp (100%) rename {src => mobile/src}/operators/math/slidingwindow_conv3x3.h (100%) rename {src => mobile/src}/operators/math/slidingwindow_utils.cpp (100%) rename {src => mobile/src}/operators/math/slidingwindow_utils.h (100%) rename {src => mobile/src}/operators/math/softmax.cpp (100%) rename {src => mobile/src}/operators/math/softmax.h (100%) rename {src => mobile/src}/operators/math/transform.h (100%) rename {src => mobile/src}/operators/math/vol2col.cpp (100%) rename {src => mobile/src}/operators/math/vol2col.h (100%) rename {src => mobile/src}/operators/math/winograd/winograd_transform.h (100%) rename {src => mobile/src}/operators/math/winograd/winograd_transform_f6k3.cpp (100%) rename {src => mobile/src}/operators/mul_op.cpp (100%) rename {src => mobile/src}/operators/mul_op.h (100%) rename {src => mobile/src}/operators/multiclass_nms_op.cpp (100%) rename {src => mobile/src}/operators/multiclass_nms_op.h (100%) rename {src => mobile/src}/operators/nearest_interp_op.cpp (100%) rename {src => mobile/src}/operators/nearest_interp_op.h (100%) rename {src => mobile/src}/operators/norm_op.cpp (100%) rename {src => mobile/src}/operators/norm_op.h (100%) rename {src => mobile/src}/operators/one_hot_op.cpp (100%) rename {src => mobile/src}/operators/one_hot_op.h (100%) rename {src => mobile/src}/operators/op_param.cpp (100%) rename {src => mobile/src}/operators/op_param.h (100%) rename {src => mobile/src}/operators/pad2d_op.cpp (100%) rename {src => mobile/src}/operators/pad2d_op.h (100%) rename {src => mobile/src}/operators/polygon_box_transform_op.cpp (100%) rename {src => mobile/src}/operators/polygon_box_transform_op.h (100%) rename {src => mobile/src}/operators/pool_op.cpp (100%) rename {src => mobile/src}/operators/pool_op.h (100%) rename {src => mobile/src}/operators/prelu_op.cpp (100%) rename {src => mobile/src}/operators/prelu_op.h (100%) rename {src => mobile/src}/operators/prior_box_op.cpp (100%) rename {src => mobile/src}/operators/prior_box_op.h (100%) rename {src => mobile/src}/operators/quantize_op.cpp (100%) rename {src => mobile/src}/operators/quantize_op.h (100%) rename {src => mobile/src}/operators/range_op.cpp (100%) rename {src => mobile/src}/operators/range_op.h (100%) rename {src => mobile/src}/operators/reduce_prod_op.cpp (100%) rename {src => mobile/src}/operators/reduce_prod_op.h (100%) rename {src => mobile/src}/operators/reshape2_op.cpp (100%) rename {src => mobile/src}/operators/reshape2_op.h (100%) rename {src => mobile/src}/operators/reshape_op.cpp (100%) rename {src => mobile/src}/operators/reshape_op.h (100%) rename {src => mobile/src}/operators/resize_op.cpp (100%) rename {src => mobile/src}/operators/resize_op.h (100%) rename {src => mobile/src}/operators/scale_op.cpp (100%) rename {src => mobile/src}/operators/scale_op.h (100%) rename {src => mobile/src}/operators/sequence_ops/sequence_expand_op.cpp (100%) rename {src => mobile/src}/operators/sequence_ops/sequence_expand_op.h (100%) rename {src => mobile/src}/operators/sequence_ops/sequence_pool_op.cpp (100%) rename {src => mobile/src}/operators/sequence_ops/sequence_pool_op.h (100%) rename {src => mobile/src}/operators/sequence_ops/sequence_softmax_op.cpp (100%) rename {src => mobile/src}/operators/sequence_ops/sequence_softmax_op.h (100%) rename {src => mobile/src}/operators/shape_op.cpp (100%) rename {src => mobile/src}/operators/shape_op.h (100%) rename {src => 
mobile/src}/operators/slice_op.cpp (100%) rename {src => mobile/src}/operators/slice_op.h (100%) rename {src => mobile/src}/operators/softmax_op.cpp (100%) rename {src => mobile/src}/operators/softmax_op.h (100%) rename {src => mobile/src}/operators/split_op.cpp (100%) rename {src => mobile/src}/operators/split_op.h (100%) rename {src => mobile/src}/operators/sum_op.cpp (100%) rename {src => mobile/src}/operators/sum_op.h (100%) rename {src => mobile/src}/operators/top_k_op.cpp (100%) rename {src => mobile/src}/operators/top_k_op.h (100%) rename {src => mobile/src}/operators/transpose2_op.cpp (100%) rename {src => mobile/src}/operators/transpose2_op.h (100%) rename {src => mobile/src}/operators/transpose_op.cpp (100%) rename {src => mobile/src}/operators/transpose_op.h (100%) rename {src => mobile/src}/pass/memory_optimize.cpp (100%) rename {src => mobile/src}/pass/memory_optimize.h (100%) rename {src => mobile/src}/pass/memory_optimize_super.cpp (100%) rename {src => mobile/src}/pass/memory_optimize_super.h (100%) rename {src => mobile/src}/pass/model_obfuscate.cpp (100%) rename {src => mobile/src}/pass/model_obfuscate.h (100%) rename {src => mobile/src}/pass/pass_base.h (100%) rename {src => mobile/src}/protobuf-c/protobuf-c.c (100%) rename {src => mobile/src}/protobuf-c/protobuf-c.h (100%) rename {test => mobile/test}/CMakeLists.txt (100%) rename {test => mobile/test}/common/test_enforce.cpp (100%) rename {test => mobile/test}/common/test_gemm_accuracy.cpp (100%) rename {test => mobile/test}/common/test_gemm_int8_accuracy.cpp (100%) rename {test => mobile/test}/common/test_gemm_perf.cpp (100%) rename {test => mobile/test}/common/test_lib_size.cpp (100%) rename {test => mobile/test}/common/test_lib_size.h (100%) rename {test => mobile/test}/common/test_log.cpp (100%) rename {test => mobile/test}/common/test_openmp.cpp (100%) rename {test => mobile/test}/executor_for_test.h (100%) rename {test => mobile/test}/fpga/test_concat_op.cpp (100%) rename {test => mobile/test}/fpga/test_densebox_combine.cpp (100%) rename {test => mobile/test}/fpga/test_format_data.cpp (100%) rename {test => mobile/test}/fpga/test_marker.cpp (100%) rename {test => mobile/test}/fpga/test_marker2.cpp (100%) rename {test => mobile/test}/fpga/test_marker_api.cpp (100%) rename {test => mobile/test}/fpga/test_mobilenet_api.cpp (100%) rename {test => mobile/test}/fpga/test_pe.cpp (100%) rename {test => mobile/test}/fpga/test_resnet50.cpp (100%) rename {test => mobile/test}/fpga/test_rfcn.cpp (100%) rename {test => mobile/test}/fpga/test_rfcn_api.cpp (100%) rename {test => mobile/test}/fpga/test_ssd.cpp (100%) rename {test => mobile/test}/fpga/test_tensor_quant.cpp (100%) rename {test => mobile/test}/fpga/test_yolo_api.cpp (100%) rename {test => mobile/test}/framework/test_inference_api.cpp (100%) rename {test => mobile/test}/framework/test_load.cpp (100%) rename {test => mobile/test}/framework/test_load_memory.cpp (100%) rename {test => mobile/test}/framework/test_load_memory_inference_api.cpp (100%) rename {test => mobile/test}/framework/test_optimize.cpp (100%) rename {test => mobile/test}/net/test_alexnet.cpp (100%) rename {test => mobile/test}/net/test_benchmark.cpp (100%) rename {test => mobile/test}/net/test_eng.cpp (100%) rename {test => mobile/test}/net/test_genet_combine.cpp (100%) rename {test => mobile/test}/net/test_gesture.cpp (100%) rename {test => mobile/test}/net/test_googlenet.cpp (100%) rename {test => mobile/test}/net/test_googlenet_quali.cpp (100%) rename {test => 
mobile/test}/net/test_googlenetv1_combine.cpp (100%) rename {test => mobile/test}/net/test_inceptionv4.cpp (100%) rename {test => mobile/test}/net/test_mobilenet+ssd.cpp (100%) rename {test => mobile/test}/net/test_mobilenet.cpp (100%) rename {test => mobile/test}/net/test_mobilenet_025_fssd.cpp (100%) rename {test => mobile/test}/net/test_mobilenet_GPU.cpp (100%) rename {test => mobile/test}/net/test_mobilenet_combine.cpp (100%) rename {test => mobile/test}/net/test_multi_inference_predict.cpp (100%) rename {test => mobile/test}/net/test_net.cpp (100%) rename {test => mobile/test}/net/test_net_benchmark.cpp (100%) rename {test => mobile/test}/net/test_nlp.cpp (100%) rename {test => mobile/test}/net/test_ocr.cpp (100%) rename {test => mobile/test}/net/test_op_in_net.cpp (100%) rename {test => mobile/test}/net/test_resnet.cpp (100%) rename {test => mobile/test}/net/test_squeezenet.cpp (100%) rename {test => mobile/test}/net/test_super.cpp (100%) rename {test => mobile/test}/net/test_vgg16ssd.cpp (100%) rename {test => mobile/test}/net/test_wrap.cpp (100%) rename {test => mobile/test}/net/test_yolo.cpp (100%) rename {test => mobile/test}/net/test_yolo_combined.cpp (100%) rename {test => mobile/test}/net/test_yologpu.cpp (100%) rename {test => mobile/test}/operators/test_batchnorm_op.cpp (100%) rename {test => mobile/test}/operators/test_box_coder_op.cpp (100%) rename {test => mobile/test}/operators/test_cast_op.cpp (100%) rename {test => mobile/test}/operators/test_concat_op.cpp (100%) rename {test => mobile/test}/operators/test_conv_add_relu_op.cpp (100%) rename {test => mobile/test}/operators/test_conv_bn_relu_op.cpp (100%) rename {test => mobile/test}/operators/test_conv_gpu.cpp (100%) rename {test => mobile/test}/operators/test_conv_op.cpp (100%) rename {test => mobile/test}/operators/test_depthwise_conv_op.cpp (100%) rename {test => mobile/test}/operators/test_dequantize_op.cpp (100%) rename {test => mobile/test}/operators/test_dwconv_bn_relu_op.cpp (100%) rename {test => mobile/test}/operators/test_elementwise_add_op.cpp (100%) rename {test => mobile/test}/operators/test_elementwise_sub_op.cpp (100%) rename {test => mobile/test}/operators/test_fill_constant_op.cpp (100%) rename {test => mobile/test}/operators/test_fusion_conv_add_bn_relu_op.cpp (100%) rename {test => mobile/test}/operators/test_fusion_fc_op.cpp (100%) rename {test => mobile/test}/operators/test_gru_op.cpp (100%) rename {test => mobile/test}/operators/test_im2sequence_op.cpp (100%) rename {test => mobile/test}/operators/test_increment_op.cpp (100%) rename {test => mobile/test}/operators/test_is_empty_op.cpp (100%) rename {test => mobile/test}/operators/test_leaky_relu_op.cpp (100%) rename {test => mobile/test}/operators/test_less_than_op.cpp (100%) rename {test => mobile/test}/operators/test_log_op.cpp (100%) rename {test => mobile/test}/operators/test_logical_and_op.cpp (100%) rename {test => mobile/test}/operators/test_logical_not_op.cpp (100%) rename {test => mobile/test}/operators/test_logical_or_op.cpp (100%) rename {test => mobile/test}/operators/test_logical_xor_op.cpp (100%) rename {test => mobile/test}/operators/test_lrn_op.cpp (100%) rename {test => mobile/test}/operators/test_mul_op.cpp (100%) rename {test => mobile/test}/operators/test_multiclass_nms_op.cpp (100%) rename {test => mobile/test}/operators/test_polygon_box_transform_op.cpp (100%) rename {test => mobile/test}/operators/test_pool_op.cpp (100%) rename {test => mobile/test}/operators/test_prelu_op.cpp (100%) rename {test => 
mobile/test}/operators/test_prior_box_op.cpp (100%) rename {test => mobile/test}/operators/test_quantize_op.cpp (100%) rename {test => mobile/test}/operators/test_relu6_op.cpp (100%) rename {test => mobile/test}/operators/test_relu_op.cpp (100%) rename {test => mobile/test}/operators/test_reshape2_op.cpp (100%) rename {test => mobile/test}/operators/test_reshape_op.cpp (100%) rename {test => mobile/test}/operators/test_resize_op.cpp (100%) rename {test => mobile/test}/operators/test_scale_op.cpp (100%) rename {test => mobile/test}/operators/test_sequence_expand_op.cpp (100%) rename {test => mobile/test}/operators/test_sequence_pool_op.cpp (100%) rename {test => mobile/test}/operators/test_sequence_softmax_op.cpp (100%) rename {test => mobile/test}/operators/test_sigmoid_op.cpp (100%) rename {test => mobile/test}/operators/test_slice_op.cpp (100%) rename {test => mobile/test}/operators/test_softmax_op.cpp (100%) rename {test => mobile/test}/operators/test_sum_op.cpp (100%) rename {test => mobile/test}/operators/test_tanh_op.cpp (100%) rename {test => mobile/test}/operators/test_topk_op.cpp (100%) rename {test => mobile/test}/operators/test_transpose2_op.cpp (100%) rename {test => mobile/test}/operators/test_transpose_op.cpp (100%) rename {test => mobile/test}/test_helper.h (100%) rename {test => mobile/test}/test_include.h (100%) rename {third_party => mobile/third_party}/opencl/OpenCL-Headers/CL/cl.h (99%) rename {third_party => mobile/third_party}/opencl/OpenCL-Headers/CL/cl_d3d10.h (99%) rename {third_party => mobile/third_party}/opencl/OpenCL-Headers/CL/cl_d3d11.h (99%) rename {third_party => mobile/third_party}/opencl/OpenCL-Headers/CL/cl_dx9_media_sharing.h (99%) rename {third_party => mobile/third_party}/opencl/OpenCL-Headers/CL/cl_dx9_media_sharing_intel.h (99%) rename {third_party => mobile/third_party}/opencl/OpenCL-Headers/CL/cl_egl.h (100%) rename {third_party => mobile/third_party}/opencl/OpenCL-Headers/CL/cl_ext.h (100%) rename {third_party => mobile/third_party}/opencl/OpenCL-Headers/CL/cl_ext_intel.h (99%) rename {third_party => mobile/third_party}/opencl/OpenCL-Headers/CL/cl_gl.h (100%) rename {third_party => mobile/third_party}/opencl/OpenCL-Headers/CL/cl_gl_ext.h (100%) rename {third_party => mobile/third_party}/opencl/OpenCL-Headers/CL/cl_platform.h (100%) rename {third_party => mobile/third_party}/opencl/OpenCL-Headers/CL/cl_va_api_media_sharing_intel.h (99%) rename {third_party => mobile/third_party}/opencl/OpenCL-Headers/CL/cl_version.h (100%) rename {third_party => mobile/third_party}/opencl/OpenCL-Headers/CL/opencl.h (99%) rename {third_party => mobile/third_party}/opencl/OpenCL-Headers/LICENSE (100%) rename {third_party => mobile/third_party}/opencl/OpenCL-Headers/README.md (100%) rename {tools => mobile/tools}/android-cmake/android.toolchain.cmake (100%) rename {tools => mobile/tools}/android-debug-script/push2android.sh (100%) rename {tools => mobile/tools}/android-debug-script/run_on_android.sh (99%) rename {tools => mobile/tools}/arm-platform.cmake (100%) rename {tools => mobile/tools}/build.sh (100%) rename {tools => mobile/tools}/ci_build.sh (100%) rename {tools => mobile/tools}/ci_run_test.sh (100%) rename {tools => mobile/tools}/docker_build_fpga.sh (99%) rename {tools => mobile/tools}/ios-cmake/ios.toolchain.cmake (100%) rename {tools => mobile/tools}/net-detail.awk (100%) rename {tools => mobile/tools}/net.awk (100%) rename {tools => mobile/tools}/op.cmake (100%) rename {tools => mobile/tools}/pre-commit.hooks/clang-format.hook (100%) rename {tools => 
mobile/tools}/pre-commit.hooks/clang-tidy.hook (100%) rename {tools => mobile/tools}/pre-commit.hooks/copyright.hook (100%) rename {tools => mobile/tools}/pre-commit.hooks/cpplint.hook (100%) rename {tools => mobile/tools}/prepare_images_and_models.sh (100%) rename {tools => mobile/tools}/profile_show.sh (99%) rename {tools => mobile/tools}/python/caffetools/run.py (100%) rename {tools => mobile/tools}/python/fluidtools/.gitignore (100%) rename {tools => mobile/tools}/python/fluidtools/run.py (100%) rename {tools => mobile/tools}/python/imagetools/README.md (100%) rename {tools => mobile/tools}/python/imagetools/imagetools.py (100%) rename {tools => mobile/tools}/python/imagetools/img2nchw.py (100%) rename {tools => mobile/tools}/python/imagetools/img2nhwc.py (100%) rename {tools => mobile/tools}/python/imagetools/numpy2binary.py (100%) rename {tools => mobile/tools}/python/misc/.gitignore (100%) rename {tools => mobile/tools}/python/misc/fluidtools.py (100%) rename {tools => mobile/tools}/python/misc/ios-test-server.py (99%) rename {tools => mobile/tools}/python/misc/restore-git.py (100%) rename {tools => mobile/tools}/python/misc/test-fluid-op-feature.py (100%) rename {tools => mobile/tools}/python/modeltools/.gitignore (100%) rename {tools => mobile/tools}/python/modeltools/core/__init__.py (100%) rename {tools => mobile/tools}/python/modeltools/core/framework.proto (100%) rename {tools => mobile/tools}/python/modeltools/core/framework_pb2.py (100%) rename {tools => mobile/tools}/python/modeltools/core/op_types.py (100%) rename {tools => mobile/tools}/python/modeltools/mobilenet/__init__.py (100%) rename {tools => mobile/tools}/python/modeltools/mobilenet/converter_mobilenet.py (100%) rename {tools => mobile/tools}/python/modeltools/mobilenet/swicher.py (100%) rename {tools => mobile/tools}/python/modeltools/tools/__init__.py (100%) rename {tools => mobile/tools}/python/modeltools/tools/float2halffloat.py (100%) rename {tools => mobile/tools}/python/modeltools/tools/loader.py (99%) rename {tools => mobile/tools}/python/modeltools/tools/model_combine.py (100%) rename {tools => mobile/tools}/python/modeltools/tools/model_reader.py (100%) rename {tools => mobile/tools}/python/modeltools/yolo/__init__.py (100%) rename {tools => mobile/tools}/python/modeltools/yolo/mdl2fluid.py (100%) rename {tools => mobile/tools}/python/modeltools/yolo/swicher.py (100%) rename {tools => mobile/tools}/quantification/CMakeLists.txt (98%) rename {tools => mobile/tools}/quantification/README.md (99%) rename {tools => mobile/tools}/quantification/convert.cpp (99%) rename {tools => mobile/tools}/quantification/src/block_desc_local.cpp (100%) rename {tools => mobile/tools}/quantification/src/block_desc_local.h (100%) rename {tools => mobile/tools}/quantification/src/enforce.h (100%) rename {tools => mobile/tools}/quantification/src/framework.pb-c.c (100%) rename {tools => mobile/tools}/quantification/src/framework.pb-c.h (100%) rename {tools => mobile/tools}/quantification/src/program_desc.cpp (100%) rename {tools => mobile/tools}/quantification/src/program_desc.h (100%) rename {tools => mobile/tools}/quantification/src/protobuf-c.c (100%) rename {tools => mobile/tools}/quantification/src/protobuf-c.h (100%) rename {tools => mobile/tools}/quantification/src/tensor_desc.h (100%) rename {tools => mobile/tools}/quantification/src/var_desc.h (100%) rename {tools => mobile/tools}/shell/check-bitcode.sh (100%) rename {tools => mobile/tools}/shell/check-filename.sh (100%) rename {tools => 
mobile/tools}/shell/generate-include/.gitignore (100%) rename {tools => mobile/tools}/shell/generate-include/check_include_diff.sh (100%) rename {tools => mobile/tools}/shell/generate-include/main.cpp (100%) rename {tools => mobile/tools}/shell/generate-include/parse.py (100%) rename {tools => mobile/tools}/shell/generate-include/run.sh (100%) rename {tools => mobile/tools}/shell/merge.sh (100%) rename {tools => mobile/tools}/shell/prune_static_library.sh (100%) rename {tools => mobile/tools}/shell/restore-private-repo.sh (100%) rename {tools => mobile/tools}/toolchains/arm-android-neon.cmake (100%) rename {tools => mobile/tools}/toolchains/arm-linux-gnueabi.cmake (100%) rename {tools => mobile/tools}/toolchains/arm-linux-gnueabihf.cmake (100%) create mode 100644 tools/codestyle/.gitignore create mode 100755 tools/codestyle/clang_format.hook create mode 100644 tools/codestyle/copyright.hook create mode 100755 tools/codestyle/cpplint_pre_commit.hook create mode 100644 tools/codestyle/docstring_checker.py create mode 100755 tools/codestyle/pylint_pre_commit.hook create mode 100644 tools/codestyle/test_docstring_checker.py create mode 100755 tools/document_preview.sh diff --git a/.clang-format b/.clang-format index d59e0885794..8b583062734 100644 --- a/.clang-format +++ b/.clang-format @@ -1,5 +1,27 @@ +# This file is used by clang-format to autoformat paddle source code +# +# clang-format is part of the llvm toolchain. +# You need to install llvm and clang to format source code. +# +# The basic usage is: +# clang-format -i -style=file PATH/TO/SOURCE/CODE +# +# The -style=file option implicitly uses the ".clang-format" file located in +# one of the parent directories. +# The -i flag means in-place change. +# +# The documentation for clang-format is at +# http://clang.llvm.org/docs/ClangFormat.html +# http://clang.llvm.org/docs/ClangFormatStyleOptions.html --- Language: Cpp -BasedOnStyle: Google +BasedOnStyle: Google +IndentWidth: 2 +TabWidth: 2 +ContinuationIndentWidth: 4 +AccessModifierOffset: -1 # private/protected/public get no extra indent in a class Standard: Cpp11 +AllowAllParametersOfDeclarationOnNextLine: true +BinPackParameters: false +BinPackArguments: false ...
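The comment block added above documents the whole formatting workflow. As a minimal sketch, assuming clang-format and pre-commit are installed on the PATH (the source file named here is just an illustrative path, not mandated by the patch):

    # Reformat one file in place; -style=file walks up the directory
    # tree and picks up the .clang-format added above.
    clang-format -i -style=file lite/api/cxx_api.cc

    # Or drive all of the hooks configured in the next hunk:
    pre-commit install          # register the git pre-commit hook once
    pre-commit run --all-files  # run every formatter/linter on demand

The clang_format.hook script created by this patch (tools/codestyle/clang_format.hook) is what the pre-commit configuration below invokes.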
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 5f91f8b8aae..f4b7f4e375e 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -3,10 +3,12 @@ repos: sha: v1.0.1 hooks: - id: remove-crlf - files: (src/).*\.(md|py|mm|swift|java|c|cc|cxx|cpp|cu|h|hpp|hxx)$ - - id: remove-tabs - files: (test/|src/).*\.(md|py|mm|swift|java|c|cc|cxx|cpp|cu|h|hpp|hxx)$ - + files: (?!.*third_party)^.*$ | (?!.*book)^.*$ ^mobile/ ^metal/ ^web/ +#- repo: https://github.com/PaddlePaddle/mirrors-yapf.git + #sha: 0d79c0c469bab64f7229c9aca2b1186ef47f0e37 + #hooks: + #- id: yapf + #files: (.*\.(py|bzl)|BUILD|.*\.BUILD|WORKSPACE)$ - repo: https://github.com/pre-commit/pre-commit-hooks sha: 5bf6c09bfa1297d3692cadd621ef95f1284e33c0 hooks: @@ -14,47 +16,37 @@ repos: - id: check-merge-conflict - id: check-symlinks - id: detect-private-key - files: (?!.*tar.gz)^.*$ + files: (?!.*third_party)^.*$ | (?!.*book)^.*$ - id: end-of-file-fixer - files: (test/|src/).*\.(md|py|mm|swift|java|c|cc|cxx|cpp|h|hpp|hxx)$ - - id: trailing-whitespace - files: (test/|src/).*\.(md|py|mm|swift|java|c|cc|cxx|cpp|h|hpp|hxx)$ - - repo: local hooks: - - id: copyright - name: copyright - entry: python ./tools/pre-commit.hooks/copyright.hook - language: system - files: (test/|src/).*\.(c|cc|cxx|cpp|h|hpp|hxx|py)$ - exclude: (?!.*third_party)^.*$ | (?!.*book)^.*$ - -- repo: local - hooks: - - id: clang-format + - id: clang-format-with-version-check name: clang-format description: Format files with ClangFormat. - entry: bash ./tools/pre-commit.hooks/clang-format.hook -i + entry: bash ./tools/codestyle/clang_format.hook -i language: system - files: (test/|src/).*\.(c|cc|cxx|cpp|h|hpp|hxx)$ - + files: \.(c|cc|cxx|cpp|cu|h|hpp|hxx|proto)$ ^mobile/ ^metal/ ^web/ - repo: local hooks: - - id: cpplint + - id: cpplint-cpp-source name: cpplint - description: Check C++ code style using cpplint. - entry: bash ./tools/pre-commit.hooks/cpplint.hook + description: Check C++ code style using cpplint.py. + entry: bash ./tools/codestyle/cpplint_pre_commit.hook language: system - files: (test/|src/).*\.(c|cc|cxx|cpp|h|hpp|hxx)$ - exclude: (?!.*third_party)^.*$ | (?!.*book)^.*$i | *\.pb\.cpp - - -# + files: \.(c|cc|cxx|cpp|cu|h|hpp|hxx)$ ^mobile/ ^metal/ ^web/ #- repo: local -# hooks: -# - id: clang-tidy -# name: clang-tidy -# description: Check C++ code style using clang-tidy. -# entry: bash ./tools/pre-commit.hooks/.clang-tidy.hook -i -# language: system -# files: (src).*\.(c|cc|cxx|cpp|h|hpp|hxx)$ + #hooks: + #- id: pylint-doc-string + #name: pylint + #description: Check python docstring style using docstring_checker. 
+ #entry: bash ./tools/codestyle/pylint_pre_commit.hook + #language: system + #files: \.(py)$ +- repo: local + hooks: + - id: copyright_checker + name: copyright_checker + entry: python ./tools/codestyle/copyright.hook + language: system + files: \.(c|cc|cxx|cpp|cu|h|hpp|hxx|proto|py)$ ^mobile/ ^metal/ ^web/ + exclude: (?!.*third_party)^.*$ | (?!.*book)^.*$ diff --git a/CMakeLists.txt b/CMakeLists.txt index d34e9738a5a..a829ec6bf62 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,288 +1,180 @@ -cmake_minimum_required(VERSION 3.0.0) - -# basic build option -if(IS_IOS) - option(USE_OPENMP "build with openmp support" OFF) -else() - option(USE_OPENMP "build with openmp support" ON) -endif() -option(USE_EXCEPTION "build with exception" ON) -option(WITH_LOGGING "print logging for debug" OFF) -option(WITH_SYMBOL "build with all symbols" ON) # turn off if use jni or ios io -option(WITH_PROFILE "print op profile for debug" OFF) -option(WITH_TEST "build with unit tests" ON) - -# select platform: CPU, GPU_CL, FPGA -option(CPU "build with arm CPU support" ON) -option(GPU_CL "build with OpenCL support" OFF) -option(FPGA "build with FPGA support" OFF) -if(FPGA) - option(FPGAV1 "build with fpga v1 support" ON) - option(FPGAV2 "build with fpga v2 support" OFF) - option(FPGAKD "build with fpga KD support" OFF) -endif() - -project(paddle-mobile) - -# source code -file(GLOB_RECURSE PADDLE_MOBILE_CC src/*.cc src/*.cpp src/*.c src/*.mm) -file(GLOB_RECURSE PADDLE_MOBILE_H src/*.h) -include_directories(src/) - -# build flags -set(CMAKE_CXX_FLAGS "-O3 -s -DNDEBUG ${CMAKE_CXX_FLAGS} -Wno-attributes") -if(IS_IOS) - set(CMAKE_CXX_FLAGS "-mfpu=neon -marm -fobjc-abi-version=2 -fobjc-arc \ - -std=gnu++11 -stdlib=libc++ -isysroot ${CMAKE_OSX_SYSROOT} ${CMAKE_CXX_FLAGS}") - add_compile_options(-fembed-bitcode) -else() - set(CMAKE_CXX_FLAGS "-std=c++11 ${CMAKE_CXX_FLAGS}") -endif() - -# others -if(USE_OPENMP) - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fopenmp") - add_definitions(-DPADDLE_MOBILE_USE_OPENMP) -endif() - -if(WITH_LOGGING) - message(STATUS "Debugging mode") - add_definitions(-DPADDLE_MOBILE_DEBUG) -else() -endif() - -if(NOT WITH_SYMBOL) - add_definitions(-fvisibility=hidden -fvisibility-inlines-hidden) -endif() - -if(USE_EXCEPTION) - message(STATUS "Use exception") - add_definitions(-DENABLE_EXCEPTION -fexceptions) -else() - add_definitions(-fno-exceptions) -endif() - -if(WITH_PROFILE) - add_definitions(-DPADDLE_MOBILE_PROFILE) -endif() - -# platform control -if(ARM_LINUX) - include("${CMAKE_CURRENT_LIST_DIR}/tools/arm-platform.cmake") -endif() - -if(CPU) - add_definitions(-DPADDLE_MOBILE_CPU) -else() - file(GLOB_RECURSE _tmp_list src/operators/kernel/arm/*.cpp src/operators/kernel/arm/*.cc) - foreach(f ${_tmp_list}) - list(REMOVE_ITEM PADDLE_MOBILE_CC ${f}) - endforeach() - - file(GLOB_RECURSE _tmp_list_h src/operators/kernel/arm/*.h) - foreach(f ${_tmp_list_h}) - list(REMOVE_ITEM PADDLE_MOBILE_H ${f}) - endforeach() +# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +cmake_minimum_required(VERSION 3.0) +set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} "${CMAKE_CURRENT_SOURCE_DIR}/cmake") + +option(WITH_PADDLE_MOBILE "Use the paddle-mobile legacy build" OFF) +if (WITH_PADDLE_MOBILE) + add_subdirectory(mobile) + return() +endif(WITH_PADDLE_MOBILE) + +set(PADDLE_SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR}) +set(PADDLE_BINARY_DIR ${CMAKE_CURRENT_BINARY_DIR}) +set(CMAKE_CXX_STANDARD 11) + +include(system) +include(cross_compiling/preproject) + +project(paddle CXX C) +message(STATUS "CXX compiler: ${CMAKE_CXX_COMPILER}, version: " + "${CMAKE_CXX_COMPILER_ID} ${CMAKE_CXX_COMPILER_VERSION}") +message(STATUS "C compiler: ${CMAKE_C_COMPILER}, version: " + "${CMAKE_C_COMPILER_ID} ${CMAKE_C_COMPILER_VERSION}") +message(STATUS "AR tools: ${CMAKE_AR}") + +if(NOT LITE_WITH_LIGHT_WEIGHT_FRAMEWORK) + find_package(CUDA QUIET) +endif() +find_package(Git REQUIRED) +find_package(Threads REQUIRED) + +include(simd) + +################################ Exposed Configurations ####################################### +option(WITH_GPU "Compile PaddlePaddle with NVIDIA GPU" ${CUDA_FOUND}) +option(WITH_DSO "Compile PaddlePaddle with dynamically linked CUDA" ON) +option(WITH_AVX "Compile PaddlePaddle with AVX intrinsics" ${AVX_FOUND}) +option(WITH_PYTHON "Compile PaddlePaddle with python interpreter" ON) +option(WITH_TESTING "Compile PaddlePaddle with unit testing" OFF) +option(WITH_MKL "Compile PaddlePaddle with MKL support." ${AVX_FOUND}) +option(WITH_ARM_DOTPROD "Compile PaddlePaddle with ARM dot product instructions" ON) +option(WITH_SYSTEM_BLAS "Use system blas library" OFF) +option(WITH_DISTRIBUTE "Compile with distributed support" OFF) +option(WITH_BRPC_RDMA "Use brpc rdma as the rpc protocol" OFF) +################################ Internal Configurations ####################################### +option(WITH_PROFILER "Compile PaddlePaddle with GPU profiler and gperftools" OFF) +option(WITH_JEMALLOC "Compile PaddlePaddle with jemalloc" OFF) +option(WITH_COVERAGE "Compile PaddlePaddle with code coverage" OFF) +option(COVERALLS_UPLOAD "Package code coverage data to coveralls" OFF) +option(REPLACE_ENFORCE_GLOG "Replace PADDLE_ENFORCE with glog/CHECK for better debugging." OFF) +# TODO(Superjomn) Remove WITH_ANAKIN option if not needed later. +if(ANDROID OR IOS OR ARMLINUX) + set(WITH_GPU OFF CACHE STRING + "Disable GPU when cross-compiling for Android and iOS" FORCE) + set(WITH_DSO OFF CACHE STRING + "Disable DSO when cross-compiling for Android and iOS" FORCE) + set(WITH_AVX OFF CACHE STRING + "Disable AVX when cross-compiling for Android and iOS" FORCE) + set(WITH_PYTHON OFF CACHE STRING + "Disable PYTHON when cross-compiling for Android and iOS" FORCE) + set(WITH_RDMA OFF CACHE STRING + "Disable RDMA when cross-compiling for Android and iOS" FORCE) + set(WITH_MKL OFF CACHE STRING + "Disable MKL when cross-compiling for Android and iOS" FORCE) +endif() + +# for lite, both server and mobile frameworks.
+option(LITE_WITH_JAVA "Enable Java JNI lib in lite mode" OFF) +option(LITE_WITH_CUDA "Enable CUDA in lite mode" OFF) +option(LITE_WITH_X86 "Enable X86 in lite mode" ON) +option(LITE_WITH_ARM "Enable ARM in lite mode" OFF) +option(LITE_WITH_NPU "Enable NPU in lite mode" OFF) +option(LITE_WITH_OPENMP "Enable OpenMP in lite framework" ON) +option(LITE_WITH_OPENCL "Enable OpenCL support in lite" OFF) +option(LITE_WITH_FPGA "Enable FPGA support in lite" OFF) +option(LITE_WITH_LIGHT_WEIGHT_FRAMEWORK "Enable light-weight framework" OFF) +option(LITE_WITH_PROFILE "Enable profile mode in lite framework" OFF) +option(LITE_SHUTDOWN_LOG "Shutdown log system or not." OFF) +option(LITE_ON_TINY_PUBLISH "Publish tiny predictor lib." OFF) + +set(THIRD_PARTY_PATH "${CMAKE_BINARY_DIR}/third_party" CACHE STRING + "A path setting third party libraries download & build directories.") + +# CMAKE_BUILD_TYPE +if(NOT CMAKE_BUILD_TYPE) + set(CMAKE_BUILD_TYPE "RelWithDebInfo" CACHE STRING + "Choose the type of build, options are: Debug Release RelWithDebInfo MinSizeRel" + FORCE) +endif() + +# check options +if (LITE_ON_TINY_PUBLISH) + if (NOT (WITH_LITE AND LITE_WITH_LIGHT_WEIGHT_FRAMEWORK AND LITE_WITH_JAVA AND NOT WITH_TESTING)) + message(FATAL_ERROR "LITE_ON_TINY_PUBLISH=ON must be used with WITH_LITE=ON LITE_WITH_LIGHT_WEIGHT_FRAMEWORK=ON LITE_WITH_JAVA=ON WITH_TESTING=OFF") + return() + endif() endif() -if (GPU_CL) - add_definitions(-DPADDLE_MOBILE_CL) +include_directories("${PADDLE_SOURCE_DIR}") - # opencl version - add_definitions(-DCL_TARGET_OPENCL_VERSION=220) +# for mobile +if (WITH_LITE AND LITE_WITH_LIGHT_WEIGHT_FRAMEWORK) + message(STATUS "Building the mobile framework") + include(cross_compiling/postproject) + include(cross_compiling/npu) # check and prepare NPU DDK - link_libraries(${CMAKE_CURRENT_LIST_DIR}/third_party/opencl/libOpenCL.so) - include_directories(third_party/opencl/OpenCL-Headers) -else() - file(GLOB_RECURSE _tmp_list src/framework/cl/*.cpp src/operators/kernel/cl/*.cpp) - foreach(f ${_tmp_list}) - list(REMOVE_ITEM PADDLE_MOBILE_CC ${f}) - endforeach() + # include the necessary thirdparty dependencies + include(external/gflags) # download, build, install gflags - file(GLOB_RECURSE _tmp_list_h src/framework/cl/*.h) - foreach(f ${_tmp_list_h}) - list(REMOVE_ITEM PADDLE_MOBILE_H ${f}) - endforeach() -endif() + # LITE_WITH_LIGHT_WEIGHT_FRAMEWORK=ON will disable glog + # TODO(sangoly): refine WITH_LITE and LITE_WITH_LIGHT_WEIGHT_FRAMEWORK + #include(external/glog) # download, build, install glog + include(external/gtest) # download, build, install gtest + include(ccache) # set ccache for compilation -if(FPGA) - file(GLOB_RECURSE _tmp_list src/operators/math/*.cpp src/operators/math/*.cc src/operators/kernel/fpga/*.cc) - foreach(f ${_tmp_list}) - list(REMOVE_ITEM PADDLE_MOBILE_CC ${f}) - endforeach() - file(GLOB_RECURSE _tmp_list_h src/operators/math/*.h) - foreach(f ${_tmp_list_h}) - list(REMOVE_ITEM PADDLE_MOBILE_H ${f}) - endforeach() - list(APPEND PADDLE_MOBILE_CC src/operators/math/softmax.cpp) - list(APPEND PADDLE_MOBILE_h src/operators/math/softmax.h) - list(APPEND PADDLE_MOBILE_h src/operators/math/math_func_neon.h) - if(FPGAV1) - add_definitions(-DPADDLE_MOBILE_FPGA) - message("FPGA_V1 enabled") - add_definitions(-DPADDLE_MOBILE_FPGA_V1) - file(GLOB_RECURSE _tmp_list src/operators/kernel/fpga/V2/*.cpp src/fpga/V2/*.cpp) - foreach(f ${_tmp_list}) - list(REMOVE_ITEM PADDLE_MOBILE_CC ${f}) - endforeach() - file(GLOB_RECURSE _tmp_list src/operators/kernel/fpga/V2/*.h src/fpga/V2/*.h) - 
foreach(f ${_tmp_list}) - list(REMOVE_ITEM PADDLE_MOBILE_H ${f}) - endforeach() - file(GLOB_RECURSE _tmp_list src/operators/kernel/fpga/KD/*.cpp src/fpga/KD/*.cpp) - foreach(f ${_tmp_list}) - list(REMOVE_ITEM PADDLE_MOBILE_CC ${f}) - endforeach() - file(GLOB_RECURSE _tmp_list src/operators/kernel/fpga/KD/*.h src/operators/kernel/fpga/KD/*.hpp - src/fpga/KD/*.h src/fpga/KD/*.hpp) - foreach(f ${_tmp_list}) - list(REMOVE_ITEM PADDLE_MOBILE_H ${f}) - endforeach() + if (NOT LITE_ON_TINY_PUBLISH) + include(external/protobuf) # download, build, install protobuf endif() - if(FPGAV2) - add_definitions(-DPADDLE_MOBILE_FPGA) - message("FPGA_V2 enabled") - add_definitions(-DPADDLE_MOBILE_FPGA_V2) - file(GLOB_RECURSE _tmp_list src/operators/kernel/fpga/V1/*.cpp src/fpga/V1/*.cpp) - foreach(f ${_tmp_list}) - list(REMOVE_ITEM PADDLE_MOBILE_CC ${f}) - endforeach() - file(GLOB_RECURSE _tmp_list src/operators/kernel/fpga/V1/*.h src/fpga/V1/*.h) - foreach(f ${_tmp_list}) - list(REMOVE_ITEM PADDLE_MOBILE_H ${f}) - endforeach() - file(GLOB_RECURSE _tmp_list src/operators/kernel/fpga/KD/*.cpp src/fpga/KD/*.cpp) - foreach(f ${_tmp_list}) - list(REMOVE_ITEM PADDLE_MOBILE_CC ${f}) - endforeach() - file(GLOB_RECURSE _tmp_list src/operators/kernel/fpga/KD/*.h src/operators/kernel/fpga/KD/*.hpp - src/fpga/KD/*.h src/fpga/KD/*.hpp) - foreach(f ${_tmp_list}) - list(REMOVE_ITEM PADDLE_MOBILE_H ${f}) - endforeach() - endif() - if(FPGAKD) - message("FPGAKD enabled") - add_definitions(-DPADDLE_MOBILE_FPGA_KD) - file(GLOB_RECURSE _tmp_list src/operators/kernel/fpga/V1/*.cpp src/fpga/V1/*.cpp) - foreach(f ${_tmp_list}) - list(REMOVE_ITEM PADDLE_MOBILE_CC ${f}) - endforeach() - file(GLOB_RECURSE _tmp_list src/operators/kernel/fpga/V1/*.h src/fpga/V1/*.h) - foreach(f ${_tmp_list}) - list(REMOVE_ITEM PADDLE_MOBILE_H ${f}) - endforeach() - file(GLOB_RECURSE _tmp_list src/operators/kernel/fpga/V2/*.cpp src/fpga/V2/*.cpp) - foreach(f ${_tmp_list}) - list(REMOVE_ITEM PADDLE_MOBILE_CC ${f}) - endforeach() - file(GLOB_RECURSE _tmp_list src/operators/kernel/fpga/V2/*.h src/fpga/V2/*.h) - foreach(f ${_tmp_list}) - list(REMOVE_ITEM PADDLE_MOBILE_H ${f}) - endforeach() - - file(GLOB_RECURSE _tmp_list src/operators/kernel/central-arm-func/*.h) - foreach(f ${_tmp_list}) - list(APPEND PADDLE_MOBILE_H ${f}) - endforeach() - file(GLOB_RECURSE _tmp_list src/operators/kernel/central-arm-func/*.cpp) - foreach(f ${_tmp_list}) - list(APPEND PADDLE_MOBILE_CC ${f}) - endforeach() + # for opencl + if (LITE_WITH_OPENCL) + include(external/opencl-headers) + include(external/opencl-clhpp) endif() -else() - file(GLOB_RECURSE _tmp_list src/operators/kernel/fpga/*.cpp src/operators/kernel/fpga/*.cc) - foreach(f ${_tmp_list}) - list(REMOVE_ITEM PADDLE_MOBILE_CC ${f}) - endforeach() - - file(GLOB_RECURSE _tmp_list_h src/operators/kernel/fpga/*.h) - foreach(f ${_tmp_list_h}) - list(REMOVE_ITEM PADDLE_MOBILE_H ${f}) - endforeach() - - file(GLOB_RECURSE _tmp_list src/fpga/*.cpp src/fpga/*.cc) - foreach(f ${_tmp_list}) - list(REMOVE_ITEM PADDLE_MOBILE_CC ${f}) - endforeach() + include(generic) # simplify cmake module + include(configure) # add paddle env configuration - file(GLOB_RECURSE _tmp_list_h src/fpga/*.h) - foreach(f ${_tmp_list_h}) - list(REMOVE_ITEM PADDLE_MOBILE_H ${f}) - endforeach() + add_subdirectory(lite) + return() endif() -if(ANDROID_NDK_TOOLCHAIN_INCLUDED) - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -llog") -else() - list(REMOVE_ITEM PADDLE_MOBILE_H ${CMAKE_CURRENT_SOURCE_DIR}/src/io/jni/paddle_mobile_jni.h) - list(REMOVE_ITEM PADDLE_MOBILE_CC 
${CMAKE_CURRENT_SOURCE_DIR}/src/io/jni/paddle_mobile_jni.cpp)
-    list(REMOVE_ITEM PADDLE_MOBILE_H ${CMAKE_CURRENT_SOURCE_DIR}/src/operators/math/math_func_neon.h)
+set(WITH_MKLML ${WITH_MKL})
+if (NOT DEFINED WITH_MKLDNN)
+    if (WITH_MKL AND AVX2_FOUND)
+        set(WITH_MKLDNN ON)
+    else()
+        message(STATUS "AVX2 intrinsics not available; disabling MKL-DNN")
+        set(WITH_MKLDNN OFF)
+    endif()
 endif()
-if(IS_IOS)
-else()
-    list(REMOVE_ITEM PADDLE_MOBILE_H ${CMAKE_CURRENT_SOURCE_DIR}/src/io/ios_io/PaddleMobileCPU.h)
-    list(REMOVE_ITEM PADDLE_MOBILE_CC ${CMAKE_CURRENT_SOURCE_DIR}/src/io/ios_io/PaddleMobileCPU.mm)
-    list(REMOVE_ITEM PADDLE_MOBILE_H ${CMAKE_CURRENT_SOURCE_DIR}/src/io/ios_io/op_symbols.h)
-endif ()
+########################################################################################
-set(CMAKE_VERBOSE_MAKEFILE ON)
-set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
-set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY build)
-set(CMAKE_LIBRARY_OUTPUT_DIRECTORY build)
-set(CMAKE_RUNTIME_OUTPUT_DIRECTORY build)
+include(external/mklml)     # download mklml package
+include(external/xbyak)     # download xbyak package
+include(external/libxsmm)   # download, build, install libxsmm
+include(external/gflags)    # download, build, install gflags
+include(external/glog)      # download, build, install glog
+include(external/gtest)     # download, build, install gtest
+include(external/protobuf)  # download, build, install protobuf
+include(external/openblas)  # download, build, install openblas
+include(external/mkldnn)    # download, build, install mkldnn
+include(external/eigen)     # download eigen3
+include(external/xxhash)    # download, install xxhash, needed for the x86 jit
-# NET default
-if(FPGAV1)
-    set(NET "FPGA_NET_V1" CACHE STRING "select net type")
-elseif(FPGAV2)
-    set(NET "FPGA_NET_V2" CACHE STRING "select net type")
-elseif(FPGAKD)
-    set(NET "FPGA_OPS_KD" CACHE STRING "select net type")
-else()
-    set(NET "default" CACHE STRING "select net type")
-endif()
+include(configure)          # add paddle env configuration
-set_property(CACHE NET PROPERTY STRINGS "default" "googlenet" "mobilenet" "yolo" "squeezenet" "FPGA_NET_V1" "FPGA_NET_V2" "NLP" "op")
-include("${CMAKE_CURRENT_LIST_DIR}/tools/op.cmake")
+include(generic)            # simplify cmake module
+include(ccache)             # set ccache for compilation
+include(util)               # set unittest and link libs
+include(version)            # set PADDLE_VERSION
-# build library
-if(ANDROID_NDK_TOOLCHAIN_INCLUDED)
-    list(REMOVE_DUPLICATES CMAKE_CXX_FLAGS)
-    add_library(paddle-mobile SHARED ${PADDLE_MOBILE_CC} ${PADDLE_MOBILE_H})
-elseif(IS_IOS)
-    if(USE_OPENMP)
-        add_library(paddle-mobile-stage0 STATIC ${PADDLE_MOBILE_CC} ${PADDLE_MOBILE_H})
-        add_custom_target(paddle-mobile ALL
-            COMMAND libtool -static -o ${CMAKE_BINARY_DIR}/libpaddle-mobile.a ${CMAKE_CURRENT_LIST_DIR}/tools/libomp.a $
-            WORKING_DIRECTORY ${CMAKE_BINARY_DIR}
-            DEPENDS paddle-mobile
-        )
-        add_dependencies(paddle-mobile paddle-mobile-stage0)
-    else()
-        add_library(paddle-mobile STATIC ${PADDLE_MOBILE_CC} ${PADDLE_MOBILE_H})
-    endif()
-else()
-    add_library(paddle-mobile SHARED ${PADDLE_MOBILE_CC} ${PADDLE_MOBILE_H})
-endif()
-# unit test
-if(WITH_TEST AND WITH_SYMBOL)
-    if(IS_IOS)
-    else()
-        add_subdirectory(test)
-    endif()
-elseif(FPGA)
-    add_subdirectory(test)
-endif()
+set(CMAKE_CXX_FLAGS_RELWITHDEBINFO "-O3 -g -DNDEBUG")
+set(CMAKE_C_FLAGS_RELWITHDEBINFO "-O3 -g -DNDEBUG")
-# # if you want to combine third party static librares into paddle mobile so, please uncomment this code block
-# target_link_libraries(
-#     paddle-mobile
-#     -Wl,--whole-archive
-#     "path_to_third_party_static_library"
"path_to_third_party_static_library" -# -Wl,--no-whole-archive -# ) +add_subdirectory(lite) diff --git a/cmake/FindGflags.cmake b/cmake/FindGflags.cmake new file mode 100644 index 00000000000..6587089ba38 --- /dev/null +++ b/cmake/FindGflags.cmake @@ -0,0 +1,582 @@ +# Ceres Solver - A fast non-linear least squares minimizer +# Copyright 2015 Google Inc. All rights reserved. +# http://ceres-solver.org/ +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# * Neither the name of Google Inc. nor the names of its contributors may be +# used to endorse or promote products derived from this software without +# specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE +# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +# POSSIBILITY OF SUCH DAMAGE. +# +# Author: alexs.mac@gmail.com (Alex Stewart) +# + +# FindGflags.cmake - Find Google gflags logging library. +# +# This module will attempt to find gflags, either via an exported CMake +# configuration (generated by gflags >= 2.1 which are built with CMake), or +# by performing a standard search for all gflags components. The order of +# precedence for these two methods of finding gflags is controlled by: +# GFLAGS_PREFER_EXPORTED_GFLAGS_CMAKE_CONFIGURATION. +# +# This module defines the following variables: +# +# GFLAGS_FOUND: TRUE iff gflags is found. +# GFLAGS_INCLUDE_DIRS: Include directories for gflags. +# GFLAGS_LIBRARIES: Libraries required to link gflags. +# GFLAGS_NAMESPACE: The namespace in which gflags is defined. In versions of +# gflags < 2.1, this was google, for versions >= 2.1 it is +# by default gflags, although can be configured when building +# gflags to be something else (i.e. google for legacy +# compatibility). +# +# The following variables control the behaviour of this module when an exported +# gflags CMake configuration is not found. +# +# GFLAGS_PREFER_EXPORTED_GFLAGS_CMAKE_CONFIGURATION: TRUE/FALSE, iff TRUE then +# then prefer using an exported CMake configuration +# generated by gflags >= 2.1 over searching for the +# gflags components manually. Otherwise (FALSE) +# ignore any exported gflags CMake configurations and +# always perform a manual search for the components. +# Default: TRUE iff user does not define this variable +# before we are called, and does NOT specify either +# GFLAGS_INCLUDE_DIR_HINTS or GFLAGS_LIBRARY_DIR_HINTS +# otherwise FALSE. 
+# GFLAGS_INCLUDE_DIR_HINTS: List of additional directories in which to
+#                           search for gflags includes, e.g: /timbuktu/include.
+# GFLAGS_LIBRARY_DIR_HINTS: List of additional directories in which to
+#                           search for gflags libraries, e.g: /timbuktu/lib.
+#
+# The following variables are also defined by this module, but in line with
+# CMake recommended FindPackage() module style should NOT be referenced directly
+# by callers (use the plural variables detailed above instead). These variables
+# do however affect the behaviour of the module via FIND_[PATH/LIBRARY]() which
+# are NOT re-called (i.e. search for library is not repeated) if these variables
+# are set with valid values _in the CMake cache_. This means that if these
+# variables are set directly in the cache, either by the user in the CMake GUI,
+# or by the user passing -DVAR=VALUE directives to CMake when called (which
+# explicitly defines a cache variable), then they will be used verbatim,
+# bypassing the HINTS variables and other hard-coded search locations.
+#
+# GFLAGS_INCLUDE_DIR: Include directory for gflags, not including the
+#                     include directory of any dependencies.
+# GFLAGS_LIBRARY: gflags library, not including the libraries of any
+#                 dependencies.
+
+# Reset CALLERS_CMAKE_FIND_LIBRARY_PREFIXES to its value when FindGflags was
+# invoked, necessary for MSVC.
+macro(GFLAGS_RESET_FIND_LIBRARY_PREFIX)
+  if (MSVC)
+    set(CMAKE_FIND_LIBRARY_PREFIXES "${CALLERS_CMAKE_FIND_LIBRARY_PREFIXES}")
+  endif (MSVC)
+endmacro(GFLAGS_RESET_FIND_LIBRARY_PREFIX)
+
+# Called if we failed to find gflags or any of its required dependencies,
+# unsets all public (designed to be used externally) variables and reports
+# an error message at a priority depending upon the [REQUIRED/QUIET/] argument.
+macro(GFLAGS_REPORT_NOT_FOUND REASON_MSG)
+  unset(GFLAGS_FOUND)
+  unset(GFLAGS_INCLUDE_DIRS)
+  unset(GFLAGS_LIBRARIES)
+  # Do not use unset, as we want to keep GFLAGS_NAMESPACE in the cache,
+  # but simply clear its value.
+  set(GFLAGS_NAMESPACE "" CACHE STRING
+    "gflags namespace (google or gflags)" FORCE)
+
+  # Make results of search visible in the CMake GUI if gflags has not
+  # been found so that user does not have to toggle to advanced view.
+  mark_as_advanced(CLEAR GFLAGS_INCLUDE_DIR
+                         GFLAGS_LIBRARY
+                         GFLAGS_NAMESPACE)
+
+  gflags_reset_find_library_prefix()
+
+  # Note <package>_FIND_[REQUIRED/QUIETLY] variables defined by FindPackage()
+  # use the camelcase library name, not uppercase.
+  if (Gflags_FIND_QUIETLY)
+    message(STATUS "Failed to find gflags - " ${REASON_MSG} ${ARGN})
+  elseif (Gflags_FIND_REQUIRED)
+    message(FATAL_ERROR "Failed to find gflags - " ${REASON_MSG} ${ARGN})
+  else()
+    # Neither QUIETLY nor REQUIRED, use no priority which emits a message
+    # but continues configuration and allows generation.
+    message("-- Failed to find gflags - " ${REASON_MSG} ${ARGN})
+  endif ()
+  return()
+endmacro(GFLAGS_REPORT_NOT_FOUND)
+
+# Verify that all variable names passed as arguments are defined (can be empty
+# but must be defined) or raise a fatal error.
+macro(GFLAGS_CHECK_VARS_DEFINED)
+  foreach(CHECK_VAR ${ARGN})
+    if (NOT DEFINED ${CHECK_VAR})
+      message(FATAL_ERROR "Ceres Bug: ${CHECK_VAR} is not defined.")
+    endif()
+  endforeach()
+endmacro(GFLAGS_CHECK_VARS_DEFINED)
+
+# Use check_cxx_source_compiles() to compile trivial test programs to determine
+# the gflags namespace. This works on all OSs except Windows.
+# If using Visual Studio, it fails because msbuild forces
+# check_cxx_source_compiles() to use CMAKE_BUILD_TYPE=Debug for the test
+# project, which usually breaks detection because MSVC requires that the test
+# project use the same build type as gflags, which would normally be built in
+# Release.
+#
+# Defines: GFLAGS_NAMESPACE in the caller's scope with the detected namespace,
+#          which is blank (empty string, will test FALSE in CMake conditionals)
+#          if detection failed.
+function(GFLAGS_CHECK_GFLAGS_NAMESPACE_USING_TRY_COMPILE)
+  # Verify that all required variables are defined.
+  gflags_check_vars_defined(
+    GFLAGS_INCLUDE_DIR GFLAGS_LIBRARY)
+  # Ensure that GFLAGS_NAMESPACE is always unset on completion unless
+  # we explicitly set it after having found the correct namespace.
+  set(GFLAGS_NAMESPACE "" PARENT_SCOPE)
+
+  include(CheckCXXSourceCompiles)
+  # Setup include path & link library for gflags for CHECK_CXX_SOURCE_COMPILES.
+  set(CMAKE_REQUIRED_INCLUDES ${GFLAGS_INCLUDE_DIR})
+  set(CMAKE_REQUIRED_LIBRARIES ${GFLAGS_LIBRARY} ${GFLAGS_LINK_LIBRARIES})
+  # First try the (older) google namespace. Note that the output variable
+  # MUST be unique to the build type as otherwise the test is not repeated as
+  # it is assumed to have already been performed.
+  check_cxx_source_compiles(
+    "#include <gflags/gflags.h>
+     int main(int argc, char * argv[]) {
+       google::ParseCommandLineFlags(&argc, &argv, true);
+       return 0;
+     }"
+    GFLAGS_IN_GOOGLE_NAMESPACE)
+  if (GFLAGS_IN_GOOGLE_NAMESPACE)
+    set(GFLAGS_NAMESPACE google PARENT_SCOPE)
+    return()
+  endif()
+
+  # Try the (newer) gflags namespace instead. Note that the output variable
+  # MUST be unique to the build type as otherwise the test is not repeated as
+  # it is assumed to have already been performed.
+  set(CMAKE_REQUIRED_INCLUDES ${GFLAGS_INCLUDE_DIR})
+  set(CMAKE_REQUIRED_LIBRARIES ${GFLAGS_LIBRARY} ${GFLAGS_LINK_LIBRARIES})
+  check_cxx_source_compiles(
+    "#include <gflags/gflags.h>
+     int main(int argc, char * argv[]) {
+       gflags::ParseCommandLineFlags(&argc, &argv, true);
+       return 0;
+     }"
+    GFLAGS_IN_GFLAGS_NAMESPACE)
+  if (GFLAGS_IN_GFLAGS_NAMESPACE)
+    set(GFLAGS_NAMESPACE gflags PARENT_SCOPE)
+    return()
+  endif (GFLAGS_IN_GFLAGS_NAMESPACE)
+endfunction(GFLAGS_CHECK_GFLAGS_NAMESPACE_USING_TRY_COMPILE)
+
+# Use regex on the gflags headers to attempt to determine the gflags namespace.
+# Checks both gflags.h (contained namespace on versions < 2.1.2) and
+# gflags_declare.h, which contains the namespace on versions >= 2.1.2.
+# In general, this method should only be used when
+# GFLAGS_CHECK_GFLAGS_NAMESPACE_USING_TRY_COMPILE() cannot be used, or has
+# failed.
+#
+# Defines: GFLAGS_NAMESPACE in the caller's scope with the detected namespace,
+#          which is blank (empty string, will test FALSE in CMake conditionals)
+#          if detection failed.
+function(GFLAGS_CHECK_GFLAGS_NAMESPACE_USING_REGEX)
+  # Verify that all required variables are defined.
+  gflags_check_vars_defined(GFLAGS_INCLUDE_DIR)
+  # Ensure that GFLAGS_NAMESPACE is always undefined on completion unless
+  # we explicitly set it after having found the correct namespace.
+  set(GFLAGS_NAMESPACE "" PARENT_SCOPE)
+
+  # Scan gflags.h to identify what namespace gflags was built with. On
+  # versions of gflags < 2.1.2, gflags.h was configured with the namespace
+  # directly, on >= 2.1.2, gflags.h uses the GFLAGS_NAMESPACE #define which
+  # is defined in gflags_declare.h, we try each location in turn.
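+  #
+  # For illustration, the two shapes this function can recover look roughly
+  # like the following (hypothetical header excerpts, not taken verbatim
+  # from any gflags release):
+  #
+  #   gflags.h (< 2.1.2):          namespace google { ... }          -> google
+  #   gflags_declare.h (>= 2.1.2): #define GFLAGS_NAMESPACE gflags   -> gflags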
+ set(GFLAGS_HEADER_FILE ${GFLAGS_INCLUDE_DIR}/gflags/gflags.h) + if (NOT EXISTS ${GFLAGS_HEADER_FILE}) + gflags_report_not_found( + "Could not find file: ${GFLAGS_HEADER_FILE} " + "containing namespace information in gflags install located at: " + "${GFLAGS_INCLUDE_DIR}.") + endif() + file(READ ${GFLAGS_HEADER_FILE} GFLAGS_HEADER_FILE_CONTENTS) + + string(REGEX MATCH "namespace [A-Za-z]+" + GFLAGS_NAMESPACE "${GFLAGS_HEADER_FILE_CONTENTS}") + string(REGEX REPLACE "namespace ([A-Za-z]+)" "\\1" + GFLAGS_NAMESPACE "${GFLAGS_NAMESPACE}") + + if (NOT GFLAGS_NAMESPACE) + gflags_report_not_found( + "Failed to extract gflags namespace from header file: " + "${GFLAGS_HEADER_FILE}.") + endif (NOT GFLAGS_NAMESPACE) + + if (GFLAGS_NAMESPACE STREQUAL "google" OR + GFLAGS_NAMESPACE STREQUAL "gflags") + # Found valid gflags namespace from gflags.h. + set(GFLAGS_NAMESPACE "${GFLAGS_NAMESPACE}" PARENT_SCOPE) + return() + endif() + + # Failed to find gflags namespace from gflags.h, gflags is likely a new + # version, check gflags_declare.h, which in newer versions (>= 2.1.2) contains + # the GFLAGS_NAMESPACE #define, which is then referenced in gflags.h. + set(GFLAGS_DECLARE_FILE ${GFLAGS_INCLUDE_DIR}/gflags/gflags_declare.h) + if (NOT EXISTS ${GFLAGS_DECLARE_FILE}) + gflags_report_not_found( + "Could not find file: ${GFLAGS_DECLARE_FILE} " + "containing namespace information in gflags install located at: " + "${GFLAGS_INCLUDE_DIR}.") + endif() + file(READ ${GFLAGS_DECLARE_FILE} GFLAGS_DECLARE_FILE_CONTENTS) + + string(REGEX MATCH "#define GFLAGS_NAMESPACE [A-Za-z]+" + GFLAGS_NAMESPACE "${GFLAGS_DECLARE_FILE_CONTENTS}") + string(REGEX REPLACE "#define GFLAGS_NAMESPACE ([A-Za-z]+)" "\\1" + GFLAGS_NAMESPACE "${GFLAGS_NAMESPACE}") + + if (NOT GFLAGS_NAMESPACE) + gflags_report_not_found( + "Failed to extract gflags namespace from declare file: " + "${GFLAGS_DECLARE_FILE}.") + endif (NOT GFLAGS_NAMESPACE) + + if (GFLAGS_NAMESPACE STREQUAL "google" OR + GFLAGS_NAMESPACE STREQUAL "gflags") + # Found valid gflags namespace from gflags.h. + set(GFLAGS_NAMESPACE "${GFLAGS_NAMESPACE}" PARENT_SCOPE) + return() + endif() +endfunction(GFLAGS_CHECK_GFLAGS_NAMESPACE_USING_REGEX) + +# ----------------------------------------------------------------- +# By default, if the user has expressed no preference for using an exported +# gflags CMake configuration over performing a search for the installed +# components, and has not specified any hints for the search locations, then +# prefer a gflags exported configuration if available. +if (NOT DEFINED GFLAGS_PREFER_EXPORTED_GFLAGS_CMAKE_CONFIGURATION + AND NOT GFLAGS_INCLUDE_DIR_HINTS + AND NOT GFLAGS_LIBRARY_DIR_HINTS) + message(STATUS "No preference for use of exported gflags CMake configuration " + "set, and no hints for include/library directories provided. " + "Defaulting to preferring an installed/exported gflags CMake configuration " + "if available.") + set(GFLAGS_PREFER_EXPORTED_GFLAGS_CMAKE_CONFIGURATION TRUE) +endif() + +if (GFLAGS_PREFER_EXPORTED_GFLAGS_CMAKE_CONFIGURATION) + # Try to find an exported CMake configuration for gflags, as generated by + # gflags versions >= 2.1. + # + # We search twice, s/t we can invert the ordering of precedence used by + # find_package() for exported package build directories, and installed + # packages (found via CMAKE_SYSTEM_PREFIX_PATH), listed as items 6) and 7) + # respectively in [1]. + # + # By default, exported build directories are (in theory) detected first, and + # this is usually the case on Windows. 
However, on OS X & Linux, the install
+  # path (/usr/local) is typically present in the PATH environment variable
+  # which is checked in item 4) in [1] (i.e. before both of the above, unless
+  # NO_SYSTEM_ENVIRONMENT_PATH is passed). As such on those OSs installed
+  # packages are usually detected in preference to exported package build
+  # directories.
+  #
+  # To ensure a more consistent response across all OSs, and as users usually
+  # want to prefer an installed version of a package over a locally built one
+  # where both exist (esp. as the exported build directory might be removed
+  # after installation), we first search with NO_CMAKE_PACKAGE_REGISTRY which
+  # means any build directories exported by the user are ignored, and thus
+  # installed directories are preferred. If this fails to find the package
+  # we then search again, but without NO_CMAKE_PACKAGE_REGISTRY, so any
+  # exported build directories will now be detected.
+  #
+  # To prevent confusion on Windows, we also pass NO_CMAKE_BUILDS_PATH (which
+  # is item 5) in [1]), to not preferentially use projects that were built
+  # recently with the CMake GUI to ensure that we always prefer an installed
+  # version if available.
+  #
+  # [1] http://www.cmake.org/cmake/help/v2.8.11/cmake.html#command:find_package
+  find_package(gflags QUIET
+               NO_MODULE
+               NO_CMAKE_PACKAGE_REGISTRY
+               NO_CMAKE_BUILDS_PATH)
+  if (gflags_FOUND)
+    message(STATUS "Found installed version of gflags: ${gflags_DIR}")
+  else(gflags_FOUND)
+    # Failed to find an installed version of gflags, repeat search allowing
+    # exported build directories.
+    message(STATUS "Failed to find installed gflags CMake configuration, "
+      "searching for gflags build directories exported with CMake.")
+    # Again pass NO_CMAKE_BUILDS_PATH, as we know that gflags is exported and
+    # do not want to treat projects built with the CMake GUI preferentially.
+    find_package(gflags QUIET
+                 NO_MODULE
+                 NO_CMAKE_BUILDS_PATH)
+    if (gflags_FOUND)
+      message(STATUS "Found exported gflags build directory: ${gflags_DIR}")
+    endif(gflags_FOUND)
+  endif(gflags_FOUND)
+
+  set(FOUND_INSTALLED_GFLAGS_CMAKE_CONFIGURATION ${gflags_FOUND})
+
+  # gflags v2.1 - 2.1.2 shipped with a bug in their gflags-config.cmake [1]
+  # whereby gflags_LIBRARIES = "gflags", but there was no imported target
+  # called "gflags", they were called: gflags[_nothreads]-[static/shared].
+  # As this causes linker errors when gflags is not installed in a location
+  # on the current library paths, detect if this problem is present and
+  # fix it.
+  #
+  # [1] https://github.com/gflags/gflags/issues/110
+  if (gflags_FOUND)
+    # NOTE: This is not written as additional conditions in the outer
+    # if (gflags_FOUND) as the NOT TARGET "${gflags_LIBRARIES}"
+    # condition causes problems if gflags is not found.
+    if (${gflags_VERSION} VERSION_LESS 2.1.3 AND
+        NOT TARGET "${gflags_LIBRARIES}")
+      message(STATUS "Detected broken gflags install in: ${gflags_DIR}, "
+        "version: ${gflags_VERSION} <= 2.1.2 which defines gflags_LIBRARIES = "
+        "${gflags_LIBRARIES} which is not an imported CMake target, see: "
+        "https://github.com/gflags/gflags/issues/110. Attempting to fix by "
+        "detecting correct gflags target.")
+      # Ordering here expresses preference for detection, specifically we do not
+      # want to use the _nothreads variants if the full library is available.
+ list(APPEND CHECK_GFLAGS_IMPORTED_TARGET_NAMES + gflags-shared gflags-static + gflags_nothreads-shared gflags_nothreads-static) + foreach(CHECK_GFLAGS_TARGET ${CHECK_GFLAGS_IMPORTED_TARGET_NAMES}) + if (TARGET ${CHECK_GFLAGS_TARGET}) + message(STATUS "Found valid gflags target: ${CHECK_GFLAGS_TARGET}, " + "updating gflags_LIBRARIES.") + set(gflags_LIBRARIES ${CHECK_GFLAGS_TARGET}) + break() + endif() + endforeach() + if (NOT TARGET ${gflags_LIBRARIES}) + message(STATUS "Failed to fix detected broken gflags install in: " + "${gflags_DIR}, version: ${gflags_VERSION} <= 2.1.2, none of the " + "imported targets for gflags: ${CHECK_GFLAGS_IMPORTED_TARGET_NAMES} " + "are defined. Will continue with a manual search for gflags " + "components. We recommend you build/install a version of gflags > " + "2.1.2 (or master).") + set(FOUND_INSTALLED_GFLAGS_CMAKE_CONFIGURATION FALSE) + endif() + endif() + endif() + + if (FOUND_INSTALLED_GFLAGS_CMAKE_CONFIGURATION) + message(STATUS "Detected gflags version: ${gflags_VERSION}") + set(GFLAGS_FOUND ${gflags_FOUND}) + set(GFLAGS_INCLUDE_DIR ${gflags_INCLUDE_DIR}) + set(GFLAGS_LIBRARY ${gflags_LIBRARIES}) + + # gflags does not export the namespace in their CMake configuration, so + # use our function to determine what it should be, as it can be either + # gflags or google dependent upon version & configuration. + # + # NOTE: We use the regex method to determine the namespace here, as + # check_cxx_source_compiles() will not use imported targets, which + # is what gflags will be in this case. + gflags_check_gflags_namespace_using_regex() + + if (NOT GFLAGS_NAMESPACE) + gflags_report_not_found( + "Failed to determine gflags namespace using regex for gflags " + "version: ${gflags_VERSION} exported here: ${gflags_DIR} using CMake.") + endif (NOT GFLAGS_NAMESPACE) + else (FOUND_INSTALLED_GFLAGS_CMAKE_CONFIGURATION) + message(STATUS "Failed to find an installed/exported CMake configuration " + "for gflags, will perform search for installed gflags components.") + endif (FOUND_INSTALLED_GFLAGS_CMAKE_CONFIGURATION) +endif(GFLAGS_PREFER_EXPORTED_GFLAGS_CMAKE_CONFIGURATION) + +if (NOT GFLAGS_FOUND) + # Either failed to find an exported gflags CMake configuration, or user + # told us not to use one. Perform a manual search for all gflags components. + + # Handle possible presence of lib prefix for libraries on MSVC, see + # also GFLAGS_RESET_FIND_LIBRARY_PREFIX(). + if (MSVC) + # Preserve the caller's original values for CMAKE_FIND_LIBRARY_PREFIXES + # s/t we can set it back before returning. + set(CALLERS_CMAKE_FIND_LIBRARY_PREFIXES "${CMAKE_FIND_LIBRARY_PREFIXES}") + # The empty string in this list is important, it represents the case when + # the libraries have no prefix (shared libraries / DLLs). + set(CMAKE_FIND_LIBRARY_PREFIXES "lib" "" "${CMAKE_FIND_LIBRARY_PREFIXES}") + endif (MSVC) + + # Search user-installed locations first, so that we prefer user installs + # to system installs where both exist. + list(APPEND GFLAGS_CHECK_INCLUDE_DIRS + /usr/local/include + /usr/local/homebrew/include # Mac OS X + /opt/local/var/macports/software # Mac OS X. + /opt/local/include + /usr/include) + list(APPEND GFLAGS_CHECK_PATH_SUFFIXES + gflags/include # Windows (for C:/Program Files prefix). + gflags/Include ) # Windows (for C:/Program Files prefix). + + list(APPEND GFLAGS_CHECK_LIBRARY_DIRS + /usr/local/lib + /usr/local/homebrew/lib # Mac OS X. + /opt/local/lib + /usr/lib) + list(APPEND GFLAGS_CHECK_LIBRARY_SUFFIXES + gflags/lib # Windows (for C:/Program Files prefix). 
+    gflags/Lib )       # Windows (for C:/Program Files prefix).
+
+  # Search supplied hint directories first if supplied.
+  find_path(GFLAGS_INCLUDE_DIR
+    NAMES gflags/gflags.h
+    PATHS ${GFLAGS_INCLUDE_DIR_HINTS}
+          ${GFLAGS_CHECK_INCLUDE_DIRS}
+    PATH_SUFFIXES ${GFLAGS_CHECK_PATH_SUFFIXES})
+  if (NOT GFLAGS_INCLUDE_DIR OR
+      NOT EXISTS ${GFLAGS_INCLUDE_DIR})
+    gflags_report_not_found(
+      "Could not find gflags include directory, set GFLAGS_INCLUDE_DIR "
+      "to directory containing gflags/gflags.h")
+  endif (NOT GFLAGS_INCLUDE_DIR OR
+         NOT EXISTS ${GFLAGS_INCLUDE_DIR})
+
+  find_library(GFLAGS_LIBRARY NAMES gflags
+    PATHS ${GFLAGS_LIBRARY_DIR_HINTS}
+          ${GFLAGS_CHECK_LIBRARY_DIRS}
+    PATH_SUFFIXES ${GFLAGS_CHECK_LIBRARY_SUFFIXES})
+  if (NOT GFLAGS_LIBRARY OR
+      NOT EXISTS ${GFLAGS_LIBRARY})
+    gflags_report_not_found(
+      "Could not find gflags library, set GFLAGS_LIBRARY "
+      "to full path to libgflags.")
+  endif (NOT GFLAGS_LIBRARY OR
+         NOT EXISTS ${GFLAGS_LIBRARY})
+
+  # gflags typically requires a threading library (which is OS dependent), note
+  # that this defines the CMAKE_THREAD_LIBS_INIT variable. If we are able to
+  # detect threads, we assume that gflags requires it.
+  find_package(Threads QUIET)
+  set(GFLAGS_LINK_LIBRARIES ${CMAKE_THREAD_LIBS_INIT})
+  # On Windows (including MinGW), the Shlwapi library is used by gflags if
+  # available.
+  if (WIN32)
+    include(CheckIncludeFileCXX)
+    check_include_file_cxx("shlwapi.h" HAVE_SHLWAPI)
+    if (HAVE_SHLWAPI)
+      list(APPEND GFLAGS_LINK_LIBRARIES shlwapi.lib)
+    endif(HAVE_SHLWAPI)
+  endif (WIN32)
+
+  # Mark internally as found, then verify. GFLAGS_REPORT_NOT_FOUND() unsets
+  # if called.
+  set(GFLAGS_FOUND TRUE)
+
+  # Identify what namespace gflags was built with.
+  if (GFLAGS_INCLUDE_DIR AND NOT GFLAGS_NAMESPACE)
+    # To handle Windows peculiarities / CMake bugs on MSVC we try two approaches
+    # to detect the gflags namespace:
+    #
+    # 1) Try to use check_cxx_source_compiles() to compile a trivial program
+    #    with the two choices for the gflags namespace.
+    #
+    # 2) [In the event 1) fails] Use regex on the gflags headers to try to
+    #    determine the gflags namespace. Whilst this is less robust than 1),
+    #    it does avoid any interaction with msbuild.
+    gflags_check_gflags_namespace_using_try_compile()
+
+    if (NOT GFLAGS_NAMESPACE)
+      # Failed to determine gflags namespace using the
+      # check_cxx_source_compiles() method, try and obtain it using regex on
+      # the gflags headers instead.
+      message(STATUS "Failed to find gflags namespace using "
+        "check_cxx_source_compiles(), trying namespace regex instead, "
+        "this is expected on Windows.")
+      gflags_check_gflags_namespace_using_regex()
+
+      if (NOT GFLAGS_NAMESPACE)
+        gflags_report_not_found(
+          "Failed to determine gflags namespace either by "
+          "check_cxx_source_compiles(), or namespace regex.")
+      endif (NOT GFLAGS_NAMESPACE)
+    endif (NOT GFLAGS_NAMESPACE)
+  endif (GFLAGS_INCLUDE_DIR AND NOT GFLAGS_NAMESPACE)
+
+  # Make the GFLAGS_NAMESPACE a cache variable s/t the user can view it, and could
+  # overwrite it in the CMake GUI.
+  set(GFLAGS_NAMESPACE "${GFLAGS_NAMESPACE}" CACHE STRING
+    "gflags namespace (google or gflags)" FORCE)
+
+  # gflags does not seem to provide any record of the version in its
+  # source tree, thus cannot extract version.
+
+  # Catch case when caller has set GFLAGS_NAMESPACE in the cache / GUI
+  # with an invalid value.
+ if (GFLAGS_NAMESPACE AND + NOT GFLAGS_NAMESPACE STREQUAL "google" AND + NOT GFLAGS_NAMESPACE STREQUAL "gflags") + gflags_report_not_found( + "Caller defined GFLAGS_NAMESPACE:" + " ${GFLAGS_NAMESPACE} is not valid, not google or gflags.") + endif () + # Catch case when caller has set GFLAGS_INCLUDE_DIR in the cache / GUI and + # thus FIND_[PATH/LIBRARY] are not called, but specified locations are + # invalid, otherwise we would report the library as found. + if (GFLAGS_INCLUDE_DIR AND + NOT EXISTS ${GFLAGS_INCLUDE_DIR}/gflags/gflags.h) + gflags_report_not_found( + "Caller defined GFLAGS_INCLUDE_DIR:" + " ${GFLAGS_INCLUDE_DIR} does not contain gflags/gflags.h header.") + endif (GFLAGS_INCLUDE_DIR AND + NOT EXISTS ${GFLAGS_INCLUDE_DIR}/gflags/gflags.h) + # TODO: This regex for gflags library is pretty primitive, we use lowercase + # for comparison to handle Windows using CamelCase library names, could + # this check be better? + string(TOLOWER "${GFLAGS_LIBRARY}" LOWERCASE_GFLAGS_LIBRARY) + if (GFLAGS_LIBRARY AND + NOT "${LOWERCASE_GFLAGS_LIBRARY}" MATCHES ".*gflags[^/]*") + gflags_report_not_found( + "Caller defined GFLAGS_LIBRARY: " + "${GFLAGS_LIBRARY} does not match gflags.") + endif (GFLAGS_LIBRARY AND + NOT "${LOWERCASE_GFLAGS_LIBRARY}" MATCHES ".*gflags[^/]*") + + gflags_reset_find_library_prefix() + +endif(NOT GFLAGS_FOUND) + +# Set standard CMake FindPackage variables if found. +if (GFLAGS_FOUND) + set(GFLAGS_INCLUDE_DIRS ${GFLAGS_INCLUDE_DIR}) + set(GFLAGS_LIBRARIES ${GFLAGS_LIBRARY} ${GFLAGS_LINK_LIBRARIES}) +endif (GFLAGS_FOUND) + +# Handle REQUIRED / QUIET optional arguments. +include(FindPackageHandleStandardArgs) +find_package_handle_standard_args(Gflags DEFAULT_MSG + GFLAGS_INCLUDE_DIRS GFLAGS_LIBRARIES GFLAGS_NAMESPACE) + +# Only mark internal variables as advanced if we found gflags, otherwise +# leave them visible in the standard GUI for the user to set manually. +if (GFLAGS_FOUND) + mark_as_advanced(FORCE GFLAGS_INCLUDE_DIR + GFLAGS_LIBRARY + GFLAGS_NAMESPACE + gflags_DIR) # Autogenerated by find_package(gflags) +endif (GFLAGS_FOUND) diff --git a/cmake/FindGlog.cmake b/cmake/FindGlog.cmake new file mode 100644 index 00000000000..142e2ca96ba --- /dev/null +++ b/cmake/FindGlog.cmake @@ -0,0 +1,24 @@ +# +# Find libglog +# +# LIBGLOG_INCLUDE_DIR - where to find glog/logging.h, etc. +# LIBGLOG_LIBRARY - List of libraries when using libglog. +# LIBGLOG_FOUND - True if libglog found. +# +# from https://github.com/facebook/hhvm/blob/master/CMake/FindGlog.cmake + +IF (LIBGLOG_INCLUDE_DIR) + # Already in cache, be silent + SET(LIBGLOG_FIND_QUIETLY TRUE) +ENDIF () + +FIND_PATH(LIBGLOG_INCLUDE_DIR glog/logging.h) + +FIND_LIBRARY(LIBGLOG_LIBRARY glog) + +# handle the QUIETLY and REQUIRED arguments and set LIBGLOG_FOUND to TRUE if +# all listed variables are TRUE +INCLUDE(FindPackageHandleStandardArgs) +FIND_PACKAGE_HANDLE_STANDARD_ARGS(LIBGLOG DEFAULT_MSG LIBGLOG_LIBRARY LIBGLOG_INCLUDE_DIR) + +MARK_AS_ADVANCED(LIBGLOG_LIBRARY LIBGLOG_INCLUDE_DIR) \ No newline at end of file diff --git a/cmake/FindGperftools.cmake b/cmake/FindGperftools.cmake new file mode 100644 index 00000000000..928f573a4fb --- /dev/null +++ b/cmake/FindGperftools.cmake @@ -0,0 +1,63 @@ +# Tries to find Gperftools. 
+# +# Usage of this module as follows: +# +# find_package(Gperftools) +# +# Variables used by this module, they can change the default behaviour and need +# to be set before calling find_package: +# +# Gperftools_ROOT_DIR Set this variable to the root installation of +# Gperftools if the module has problems finding +# the proper installation path. +# +# Variables defined by this module: +# +# GPERFTOOLS_FOUND System has Gperftools libs/headers +# GPERFTOOLS_LIBRARIES The Gperftools libraries (tcmalloc & profiler) +# GPERFTOOLS_INCLUDE_DIR The location of Gperftools headers + +find_library(GPERFTOOLS_TCMALLOC + NAMES tcmalloc + HINTS ${Gperftools_ROOT_DIR}/lib) + +find_library(GPERFTOOLS_PROFILER + NAMES profiler + HINTS ${Gperftools_ROOT_DIR}/lib) + +find_library(GPERFTOOLS_TCMALLOC_AND_PROFILER + NAMES tcmalloc_and_profiler + HINTS ${Gperftools_ROOT_DIR}/lib) + +find_path(GPERFTOOLS_INCLUDE_DIR + NAMES gperftools/heap-profiler.h + HINTS ${Gperftools_ROOT_DIR}/include) + +set(GPERFTOOLS_LIBRARIES ${GPERFTOOLS_TCMALLOC_AND_PROFILER}) + +include(FindPackageHandleStandardArgs) +find_package_handle_standard_args( + Gperftools + DEFAULT_MSG + GPERFTOOLS_LIBRARIES + GPERFTOOLS_INCLUDE_DIR) + +mark_as_advanced( + Gperftools_ROOT_DIR + GPERFTOOLS_TCMALLOC + GPERFTOOLS_PROFILER + GPERFTOOLS_TCMALLOC_AND_PROFILER + GPERFTOOLS_LIBRARIES + GPERFTOOLS_INCLUDE_DIR) + +# create IMPORTED targets +if (Gperftools_FOUND AND NOT TARGET gperftools::tcmalloc) + add_library(gperftools::tcmalloc UNKNOWN IMPORTED) + set_target_properties(gperftools::tcmalloc PROPERTIES + IMPORTED_LOCATION ${GPERFTOOLS_TCMALLOC} + INTERFACE_INCLUDE_DIRECTORIES "${GPERFTOOLS_INCLUDE_DIR}") + add_library(gperftools::profiler UNKNOWN IMPORTED) + set_target_properties(gperftools::profiler PROPERTIES + IMPORTED_LOCATION ${GPERFTOOLS_PROFILER} + INTERFACE_INCLUDE_DIRECTORIES "${GPERFTOOLS_INCLUDE_DIR}") +endif() diff --git a/cmake/FindJeMalloc.cmake b/cmake/FindJeMalloc.cmake new file mode 100644 index 00000000000..b95287160ba --- /dev/null +++ b/cmake/FindJeMalloc.cmake @@ -0,0 +1,28 @@ +# - Find JeMalloc library +# Find the native JeMalloc includes and library +# +# JEMALLOC_INCLUDE_DIR - where to find jemalloc.h, etc. +# JEMALLOC_LIBRARIES - List of libraries when using jemalloc. +# JEMALLOC_FOUND - True if jemalloc found. 
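+#
+# A minimal usage sketch (assuming this module sits on CMAKE_MODULE_PATH;
+# JEMALLOC_ROOT_DIR is an optional hint, and "my_app" is a hypothetical
+# target):
+#
+#   list(APPEND CMAKE_MODULE_PATH "${CMAKE_SOURCE_DIR}/cmake")
+#   set(JEMALLOC_ROOT_DIR "/opt/jemalloc")   # optional install hint
+#   find_package(JeMalloc)
+#   if(JEMALLOC_FOUND)
+#     target_link_libraries(my_app jemalloc::jemalloc)
+#   endif()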
+
+find_path(JEMALLOC_INCLUDE_DIR
+  NAMES jemalloc/jemalloc.h
+  HINTS ${JEMALLOC_ROOT_DIR}/include)
+
+find_library(JEMALLOC_LIBRARIES
+  NAMES jemalloc
+  HINTS ${JEMALLOC_ROOT_DIR}/lib)
+
+include(FindPackageHandleStandardArgs)
+find_package_handle_standard_args(jemalloc DEFAULT_MSG JEMALLOC_LIBRARIES JEMALLOC_INCLUDE_DIR)
+
+mark_as_advanced(
+  JEMALLOC_LIBRARIES
+  JEMALLOC_INCLUDE_DIR)
+
+if (JEMALLOC_FOUND)
+  add_library(jemalloc::jemalloc UNKNOWN IMPORTED)
+  set_target_properties(jemalloc::jemalloc PROPERTIES
+    IMPORTED_LOCATION ${JEMALLOC_LIBRARIES}
+    INTERFACE_INCLUDE_DIRECTORIES "${JEMALLOC_INCLUDE_DIR}")
+endif()
diff --git a/cmake/FindNumPy.cmake b/cmake/FindNumPy.cmake
new file mode 100644
index 00000000000..8cdd642ac01
--- /dev/null
+++ b/cmake/FindNumPy.cmake
@@ -0,0 +1,38 @@
+# Find the Python NumPy package
+# PYTHON_NUMPY_INCLUDE_DIR
+# NUMPY_FOUND
+# will be set by this script
+
+cmake_minimum_required(VERSION 2.6)
+
+if(NOT PYTHON_EXECUTABLE)
+  if(NumPy_FIND_QUIETLY)
+    find_package(PythonInterp QUIET)
+  else()
+    find_package(PythonInterp)
+    set(_numpy_out 1)
+  endif()
+endif()
+
+if (PYTHON_EXECUTABLE)
+  # write a python script that finds the numpy path
+  file(WRITE ${PROJECT_BINARY_DIR}/FindNumpyPath.py
+    "try: import numpy; print(numpy.get_include())\nexcept:pass\n")
+
+  # execute the find script
+  exec_program("${PYTHON_EXECUTABLE}" ${PROJECT_BINARY_DIR}
+    ARGS "FindNumpyPath.py"
+    OUTPUT_VARIABLE NUMPY_PATH)
+elseif(_numpy_out)
+  message(STATUS "Python executable not found.")
+endif(PYTHON_EXECUTABLE)
+
+find_path(PYTHON_NUMPY_INCLUDE_DIR numpy/arrayobject.h
+  HINTS "${NUMPY_PATH}" "${PYTHON_INCLUDE_PATH}")
+
+if(PYTHON_NUMPY_INCLUDE_DIR)
+  set(PYTHON_NUMPY_FOUND 1 CACHE INTERNAL "Python numpy found")
+endif(PYTHON_NUMPY_INCLUDE_DIR)
+
+include(FindPackageHandleStandardArgs)
+find_package_handle_standard_args(NumPy DEFAULT_MSG PYTHON_NUMPY_INCLUDE_DIR)
diff --git a/cmake/cblas.cmake b/cmake/cblas.cmake
new file mode 100644
index 00000000000..52ac31d1d12
--- /dev/null
+++ b/cmake/cblas.cmake
@@ -0,0 +1,94 @@
+# Find the CBlas and lapack libraries
+#
+# It will search MKLML, atlas, OpenBlas, reference-cblas in order.
+#
+# If any cblas implementation is found, the following variables will be set.
+#    CBLAS_PROVIDER   # one of MKLML, OPENBLAS, REFERENCE
+#    CBLAS_INC_DIR    # the include directory for cblas.
+#    CBLAS_LIBRARIES  # a list of the libraries paddle should link against.
+#                     # Each library should be a full path to an object file.
+
+set(CBLAS_FOUND OFF)
+
+## Find MKLML First.
+if(WITH_MKLML AND MKLML_INC_DIR AND MKLML_LIB)
+  set(CBLAS_FOUND ON)
+  set(CBLAS_PROVIDER MKLML)
+  set(CBLAS_INC_DIR ${MKLML_INC_DIR})
+  set(CBLAS_LIBRARIES ${MKLML_LIB})
+
+  add_definitions(-DPADDLE_WITH_MKLML)
+  add_definitions(-DLAPACK_FOUND)
+
+  message(STATUS "Found cblas and lapack in MKLML "
+    "(include: ${CBLAS_INC_DIR}, library: ${CBLAS_LIBRARIES})")
+  return()
+endif()
+
+## Then find openblas.
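+# (A non-standard OpenBLAS install can be picked up by exporting OPENBLAS_ROOT
+#  before configuring; the path below is hypothetical:
+#    OPENBLAS_ROOT=/opt/OpenBLAS cmake ..
+#  The same value may also be passed as -DOPENBLAS_ROOT=... since the variable
+#  is cached below.)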
+set(OPENBLAS_ROOT $ENV{OPENBLAS_ROOT} CACHE PATH "Folder containing OpenBLAS")
+set(OPENBLAS_INCLUDE_SEARCH_PATHS
+    ${OPENBLAS_ROOT}/include
+    /usr/include
+    /usr/include/openblas
+    /usr/local/opt/openblas/include)
+set(OPENBLAS_LIB_SEARCH_PATHS
+    ${OPENBLAS_ROOT}/lib
+    /usr/lib
+    /usr/lib/blas/openblas
+    /usr/lib/openblas
+    /usr/local/opt/openblas/lib)
+
+find_path(OPENBLAS_INC_DIR NAMES cblas.h
+  PATHS ${OPENBLAS_INCLUDE_SEARCH_PATHS} NO_DEFAULT_PATH)
+find_path(OPENBLAS_LAPACKE_INC_DIR NAMES lapacke.h
+  PATHS ${OPENBLAS_INCLUDE_SEARCH_PATHS})
+find_library(OPENBLAS_LIB NAMES openblas
+  PATHS ${OPENBLAS_LIB_SEARCH_PATHS})
+
+if(OPENBLAS_LAPACKE_INC_DIR AND OPENBLAS_INC_DIR AND OPENBLAS_LIB)
+  set(CBLAS_FOUND ON)
+  set(CBLAS_PROVIDER OPENBLAS)
+  set(CBLAS_INC_DIR ${OPENBLAS_INC_DIR} ${OPENBLAS_LAPACKE_INC_DIR})
+  set(CBLAS_LIBRARIES ${OPENBLAS_LIB})
+
+  add_definitions(-DPADDLE_USE_OPENBLAS)
+  add_definitions(-DLAPACK_FOUND)
+
+  message(STATUS "Found OpenBLAS (include: ${OPENBLAS_INC_DIR}, library: ${CBLAS_LIBRARIES})")
+  message(STATUS "Found lapack in OpenBLAS (include: ${OPENBLAS_LAPACKE_INC_DIR})")
+  return()
+endif()
+
+
+## Then find the reference-cblas. www.netlib.org/blas/
+set(REFERENCE_CBLAS_ROOT $ENV{REFERENCE_CBLAS_ROOT} CACHE PATH
+  "Folder containing reference-cblas")
+set(REFERENCE_CBLAS_INCLUDE_SEARCH_PATHS
+  ${REFERENCE_CBLAS_ROOT}/include
+  /usr/include
+  /usr/include/cblas
+)
+
+set(REFERENCE_CBLAS_LIB_SEARCH_PATHS
+  ${REFERENCE_CBLAS_ROOT}/lib
+  /usr/lib
+  /usr/lib/blas/reference/
+  /usr/lib/reference/
+)
+
+if(WITH_SYSTEM_BLAS)
+  find_path(REFERENCE_CBLAS_INCLUDE_DIR NAMES cblas.h PATHS
+    ${REFERENCE_CBLAS_INCLUDE_SEARCH_PATHS})
+  find_library(REFERENCE_CBLAS_LIBRARY NAMES cblas PATHS
+    ${REFERENCE_CBLAS_LIB_SEARCH_PATHS})
+
+  if(REFERENCE_CBLAS_INCLUDE_DIR AND REFERENCE_CBLAS_LIBRARY)
+    set(CBLAS_FOUND ON)
+    set(CBLAS_PROVIDER REFERENCE)
+    set(CBLAS_INC_DIR ${REFERENCE_CBLAS_INCLUDE_DIR})
+    set(CBLAS_LIBRARIES ${REFERENCE_CBLAS_LIBRARY})
+    add_definitions(-DPADDLE_USE_REFERENCE_CBLAS)
+    message(STATUS "Found reference-cblas (include: ${CBLAS_INC_DIR}, library: ${CBLAS_LIBRARIES})")
+  endif()
+endif()
diff --git a/cmake/ccache.cmake b/cmake/ccache.cmake
new file mode 100644
index 00000000000..900f59d4cb8
--- /dev/null
+++ b/cmake/ccache.cmake
@@ -0,0 +1,9 @@
+# Use ccache if the ccache program is found
+
+find_program(CCACHE_PATH ccache)
+
+if(CCACHE_PATH)
+  message(STATUS "ccache found, using it to speed up compilation.")
+  set_property(GLOBAL PROPERTY RULE_LAUNCH_COMPILE ${CCACHE_PATH})
+  set_property(GLOBAL PROPERTY RULE_LAUNCH_LINK ${CCACHE_PATH})
+endif(CCACHE_PATH)
diff --git a/cmake/configure.cmake b/cmake/configure.cmake
new file mode 100644
index 00000000000..21bcdef6ceb
--- /dev/null
+++ b/cmake/configure.cmake
@@ -0,0 +1,213 @@
+# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
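+
+# This file translates build options into compile definitions. Each block
+# below follows the same pattern; as a sketch (LITE_WITH_ARM is one real
+# example among many):
+#
+#   if (LITE_WITH_ARM)
+#     add_definitions("-DLITE_WITH_ARM")
+#   endif()
+#
+# so sources can guard platform-specific code with the matching #ifdef.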
+
+if(NOT WITH_PYTHON)
+  add_definitions(-DPADDLE_NO_PYTHON)
+endif(NOT WITH_PYTHON)
+
+if(WITH_DSO)
+  add_definitions(-DPADDLE_USE_DSO)
+endif(WITH_DSO)
+
+if(WITH_TESTING)
+  add_definitions(-DPADDLE_WITH_TESTING)
+endif(WITH_TESTING)
+
+if(NOT WITH_PROFILER)
+  add_definitions(-DPADDLE_DISABLE_PROFILER)
+endif(NOT WITH_PROFILER)
+
+if(WITH_AVX AND AVX_FOUND)
+  set(SIMD_FLAG ${AVX_FLAG})
+elseif(SSE3_FOUND)
+  set(SIMD_FLAG ${SSE3_FLAG})
+endif()
+
+if(WIN32)
+  # windows header option for all targets.
+  add_definitions(-D_XKEYCHECK_H)
+  # Use symbols instead of absolute paths, to reduce the cmake link command length.
+  SET(CMAKE_C_USE_RESPONSE_FILE_FOR_LIBRARIES 1)
+  SET(CMAKE_CXX_USE_RESPONSE_FILE_FOR_LIBRARIES 1)
+  SET(CMAKE_C_USE_RESPONSE_FILE_FOR_OBJECTS 1)
+  SET(CMAKE_CXX_USE_RESPONSE_FILE_FOR_OBJECTS 1)
+  SET(CMAKE_C_USE_RESPONSE_FILE_FOR_INCLUDES 1)
+  SET(CMAKE_CXX_USE_RESPONSE_FILE_FOR_INCLUDES 1)
+  SET(CMAKE_C_RESPONSE_FILE_LINK_FLAG "@")
+  SET(CMAKE_CXX_RESPONSE_FILE_LINK_FLAG "@")
+
+  # Specify the program to use when building static libraries
+  SET(CMAKE_C_CREATE_STATIC_LIBRARY " lib ")
+  SET(CMAKE_CXX_CREATE_STATIC_LIBRARY " lib ")
+
+  # set definition for the dll export
+  if (NOT MSVC)
+    message(FATAL_ERROR "Windows builds only support MSVC, which is what NVIDIA's nvcc compiler is bound to.")
+  endif(NOT MSVC)
+endif(WIN32)
+
+if(WITH_PSLIB)
+  add_definitions(-DPADDLE_WITH_PSLIB)
+endif()
+
+if(WITH_GPU)
+  add_definitions(-DPADDLE_WITH_CUDA)
+  add_definitions(-DEIGEN_USE_GPU)
+
+  FIND_PACKAGE(CUDA REQUIRED)
+
+  if(${CUDA_VERSION_MAJOR} VERSION_LESS 7)
+    message(FATAL_ERROR "Paddle needs CUDA >= 7.0 to compile")
+  endif()
+
+  if(NOT CUDNN_FOUND)
+    message(FATAL_ERROR "Paddle needs cudnn to compile")
+  endif()
+  if(CUPTI_FOUND)
+    include_directories(${CUPTI_INCLUDE_DIR})
+    add_definitions(-DPADDLE_WITH_CUPTI)
+  else()
+    message(STATUS "Cannot find CUPTI; GPU profiling will be inaccurate.")
+  endif()
+  set(CUDA_NVCC_FLAGS ${CUDA_NVCC_FLAGS} "-Xcompiler ${SIMD_FLAG}")
+
+  # Include cuda and cudnn
+  include_directories(${CUDNN_INCLUDE_DIR})
+  include_directories(${CUDA_TOOLKIT_INCLUDE})
+
+  if(TENSORRT_FOUND)
+    if(${CUDA_VERSION_MAJOR} VERSION_LESS 8)
+      message(FATAL_ERROR "TensorRT needs CUDA >= 8.0 to compile")
+    endif()
+    if(${CUDNN_MAJOR_VERSION} VERSION_LESS 7)
+      message(FATAL_ERROR "TensorRT needs CUDNN >= 7.0 to compile")
+    endif()
+    if(${TENSORRT_MAJOR_VERSION} VERSION_LESS 4)
+      message(FATAL_ERROR "Paddle needs TensorRT >= 4.0 to compile")
+    endif()
+    include_directories(${TENSORRT_INCLUDE_DIR})
+  endif()
+  if(WITH_ANAKIN)
+    if(${CUDA_VERSION_MAJOR} VERSION_LESS 8)
+      message(WARNING "Anakin needs CUDA >= 8.0 to compile. Force WITH_ANAKIN=OFF")
+      set(WITH_ANAKIN OFF CACHE STRING "Anakin is valid only when CUDA >= 8.0." FORCE)
+    endif()
+    if(${CUDNN_MAJOR_VERSION} VERSION_LESS 7)
+      message(WARNING "Anakin needs CUDNN >= 7.0 to compile. Force WITH_ANAKIN=OFF")
+      set(WITH_ANAKIN OFF CACHE STRING "Anakin is valid only when CUDNN >= 7.0." FORCE)
+    endif()
+    add_definitions(-DWITH_ANAKIN)
+  endif()
+  if(WITH_ANAKIN)
+    # NOTICE(minqiyang): the end slash is important because $CUDNN_INCLUDE_DIR
+    # is a softlink to the real cudnn.h directory
+    set(ENV{CUDNN_INCLUDE_DIR} "${CUDNN_INCLUDE_DIR}/")
+    get_filename_component(CUDNN_LIBRARY_DIR ${CUDNN_LIBRARY} DIRECTORY)
+    set(ENV{CUDNN_LIBRARY} ${CUDNN_LIBRARY_DIR})
+  endif()
+elseif(WITH_AMD_GPU)
+  add_definitions(-DPADDLE_WITH_HIP)
+  set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -D__HIP_PLATFORM_HCC__")
+  set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -D__HIP_PLATFORM_HCC__")
+else()
+  add_definitions(-DHPPL_STUB_FUNC)
+  list(APPEND CMAKE_CXX_SOURCE_FILE_EXTENSIONS cu)
+endif()
+
+if (WITH_MKLML AND MKLML_IOMP_LIB)
+  message(STATUS "Enable Intel OpenMP with ${MKLML_IOMP_LIB}")
+  if(WIN32)
+    # OpenMP is not yet well supported on Windows.
+    set(OPENMP_FLAGS "")
+  else(WIN32)
+    set(OPENMP_FLAGS "-fopenmp")
+  endif(WIN32)
+  set(CMAKE_C_CREATE_SHARED_LIBRARY_FORBIDDEN_FLAGS ${OPENMP_FLAGS})
+  set(CMAKE_CXX_CREATE_SHARED_LIBRARY_FORBIDDEN_FLAGS ${OPENMP_FLAGS})
+  set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} ${OPENMP_FLAGS}")
+  set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${OPENMP_FLAGS}")
+endif()
+
+set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} ${SIMD_FLAG}")
+set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${SIMD_FLAG}")
+
+if(WITH_DISTRIBUTE)
+  add_definitions(-DPADDLE_WITH_DISTRIBUTE)
+endif()
+
+if(WITH_GRPC)
+  add_definitions(-DPADDLE_WITH_GRPC)
+endif(WITH_GRPC)
+
+if(WITH_BRPC_RDMA)
+  add_definitions(-DPADDLE_WITH_BRPC_RDMA)
+endif(WITH_BRPC_RDMA)
+
+if(ON_INFER)
+  add_definitions(-DPADDLE_ON_INFERENCE)
+endif(ON_INFER)
+
+if(WITH_WBAES)
+  add_definitions(-DPADDLE_WITH_WBAES)
+endif(WITH_WBAES)
+
+if (REPLACE_ENFORCE_GLOG)
+  add_definitions("-DREPLACE_ENFORCE_GLOG")
+endif()
+
+# for lite
+# TODO(Superjomn): this does not yet work well with the option
+if (LITE_WITH_CUDA)
+  add_definitions("-DLITE_WITH_CUDA")
+endif()
+
+if (LITE_WITH_X86)
+  add_definitions("-DLITE_WITH_X86")
+endif()
+
+if (LITE_WITH_ARM)
+  add_definitions("-DLITE_WITH_ARM")
+endif()
+
+if (WITH_ARM_DOTPROD)
+  add_definitions("-DWITH_ARM_DOTPROD")
+endif()
+
+if (LITE_WITH_NPU)
+  add_definitions("-DLITE_WITH_NPU")
+endif()
+
+if (LITE_WITH_OPENCL)
+  add_definitions("-DLITE_WITH_OPENCL")
+endif()
+
+if (LITE_WITH_FPGA)
+  add_definitions("-DLITE_WITH_FPGA")
+endif()
+
+if (LITE_WITH_PROFILE)
+  add_definitions("-DLITE_WITH_PROFILE")
+endif()
+
+if (LITE_WITH_LIGHT_WEIGHT_FRAMEWORK)
+  add_definitions("-DLITE_WITH_LIGHT_WEIGHT_FRAMEWORK")
+endif()
+
+if (LITE_SHUTDOWN_LOG)
+  add_definitions("-DLITE_SHUTDOWN_LOG")
+endif()
+
+if (LITE_ON_TINY_PUBLISH)
+  add_definitions("-DLITE_ON_TINY_PUBLISH")
+endif()
diff --git a/cmake/coveralls.cmake b/cmake/coveralls.cmake
new file mode 100644
index 00000000000..ca1471cabb5
--- /dev/null
+++ b/cmake/coveralls.cmake
@@ -0,0 +1,103 @@
+# CMake script for code coverage.
+# If _COVERALLS_UPLOAD is ON, it will upload json files to coveralls.io automatically.
+
+# Param _COVERAGE_SRCS     A list of coverage source files.
+# Param _COVERALLS_UPLOAD  Upload the result to coveralls.
+# Param _CMAKE_SCRIPT_PATH CMake script path.
+function(code_coverage _COVERAGE_SRCS _COVERALLS_UPLOAD _CMAKE_SCRIPT_PATH)
+  # clean previous gcov data.
+  file(REMOVE_RECURSE ${PROJECT_BINARY_DIR}/*.gcda)
+
+  # find curl for uploading the JSON later.
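+  # (The upload target defined further below effectively runs:
+  #    curl -S -F json_file=@coveralls.json https://coveralls.io/api/v1/jobs
+  #  where coveralls.json is generated into the build directory.)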
+  if (_COVERALLS_UPLOAD)
+    find_program(CURL_EXECUTABLE curl)
+    if (NOT CURL_EXECUTABLE)
+      message(FATAL_ERROR "Coveralls: curl not found!")
+    endif()
+  endif()
+
+  # When passing a CMake list to an external process, the list
+  # will be converted from the format "1;2;3" to "1 2 3".
+  set(COVERAGE_SRCS "")
+  foreach (SINGLE_SRC ${_COVERAGE_SRCS})
+    set(COVERAGE_SRCS "${COVERAGE_SRCS}*${SINGLE_SRC}")
+  endforeach()
+
+  # query number of logical cores
+  cmake_host_system_information(RESULT core_size QUERY NUMBER_OF_LOGICAL_CORES)
+  # coveralls json file.
+  set(COVERALLS_FILE ${PROJECT_BINARY_DIR}/coveralls.json)
+  add_custom_target(coveralls_generate
+    # Run regress tests.
+    COMMAND ${CMAKE_CTEST_COMMAND}
+            -j ${core_size}
+            --output-on-failure
+    # Generate Gcov and translate it into coveralls JSON.
+    COMMAND ${CMAKE_COMMAND}
+            -DCOVERAGE_SRCS="${COVERAGE_SRCS}"
+            -DCOVERALLS_OUTPUT_FILE="${COVERALLS_FILE}"
+            -DCOV_PATH="${PROJECT_BINARY_DIR}"
+            -DPROJECT_ROOT="${PROJECT_SOURCE_DIR}"
+            -P "${_CMAKE_SCRIPT_PATH}/coverallsGcovJsons.cmake"
+    WORKING_DIRECTORY ${PROJECT_BINARY_DIR}
+    COMMENT "Coveralls: generating coveralls output..."
+  )
+
+  if (_COVERALLS_UPLOAD)
+    message("COVERALLS UPLOAD: ON")
+    # Upload the JSON to coveralls.
+    add_custom_target(coveralls_upload
+      COMMAND ${CURL_EXECUTABLE}
+              -S -F json_file=@${COVERALLS_FILE}
+              https://coveralls.io/api/v1/jobs
+      DEPENDS coveralls_generate
+      WORKING_DIRECTORY ${PROJECT_BINARY_DIR}
+      COMMENT "Coveralls: uploading coveralls output...")
+
+    add_custom_target(coveralls DEPENDS coveralls_upload)
+  else()
+    message("COVERALLS UPLOAD: OFF")
+    add_custom_target(coveralls DEPENDS coveralls_generate)
+  endif()
+endfunction()
+
+if(WITH_COVERAGE)
+  set(CMAKE_BUILD_TYPE "Debug")
+  set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -g -O0 -fprofile-arcs -ftest-coverage")
+  set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -g -O0 -fprofile-arcs -ftest-coverage")
+
+  set(EXCLUDE_DIRS
+    "demo/"
+    "build/"
+    "tests/"
+    ".test_env/"
+  )
+
+  if(WITH_GPU)
+    file(GLOB_RECURSE PADDLE_SOURCES RELATIVE "${PROJECT_SOURCE_DIR}" "*.cpp" "*.cc" "*.c" "*.cu")
+  else()
+    file(GLOB_RECURSE PADDLE_SOURCES RELATIVE "${PROJECT_SOURCE_DIR}" "*.cpp" "*.cc" "*.c")
+  endif()
+
+  # exclude trivial files in PADDLE_SOURCES
+  foreach(EXCLUDE_DIR ${EXCLUDE_DIRS})
+    foreach(TMP_PATH ${PADDLE_SOURCES})
+      string(FIND ${TMP_PATH} ${EXCLUDE_DIR} EXCLUDE_DIR_FOUND)
+      if(NOT ${EXCLUDE_DIR_FOUND} EQUAL -1)
+        list(REMOVE_ITEM PADDLE_SOURCES ${TMP_PATH})
+      endif()
+    endforeach(TMP_PATH)
+  endforeach()
+
+  # convert to absolute path
+  set(PADDLE_SRCS "")
+  foreach(PADDLE_SRC ${PADDLE_SOURCES})
+    set(PADDLE_SRCS "${PADDLE_SRCS};${PROJECT_SOURCE_DIR}/${PADDLE_SRC}")
+  endforeach()
+
+  code_coverage(
+    "${PADDLE_SRCS}"
+    ${COVERALLS_UPLOAD}
+    "${PROJECT_SOURCE_DIR}/cmake"
+  )
+endif()
diff --git a/cmake/coverallsGcovJsons.cmake b/cmake/coverallsGcovJsons.cmake
new file mode 100644
index 00000000000..4641184fcf5
--- /dev/null
+++ b/cmake/coverallsGcovJsons.cmake
@@ -0,0 +1,401 @@
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+#
+# Copyright (C) 2014 Joakim Söderberg
+#
+# This is intended to be run by a custom target in a CMake project like this.
+# 0. Compile program with coverage support.
+# 1. Clear coverage data. (Recursively delete *.gcda in build dir)
+# 2. Run the unit tests.
+# 3. Run this script specifying which source files the coverage should be performed on.
+#
+# This script will then use gcov to generate .gcov files in the directory specified
+# via the COV_PATH var. This should probably be the same as your cmake build dir.
+#
+# It then parses the .gcov files to convert them into the Coveralls JSON format:
+# https://coveralls.io/docs/api
+#
+
+CMAKE_MINIMUM_REQUIRED(VERSION 2.8)
+
+# Since it's not possible to pass a CMake list properly in the
+# "1;2;3" format to an external process, we have replaced the
+# ";" with "*", so reverse that here so we get it back into the
+# CMake list format.
+string(REGEX REPLACE "\\*" ";" COVERAGE_SRCS ${COVERAGE_SRCS})
+
+find_program(GCOV_EXECUTABLE gcov)
+if (NOT GCOV_EXECUTABLE)
+  message(FATAL_ERROR "gcov not found! Aborting...")
+endif()
+
+find_package(Git)
+
+# TODO: Add these git things to the coveralls json.
+if (GIT_FOUND)
+  # Branch.
+  execute_process(
+    COMMAND ${GIT_EXECUTABLE} rev-parse --abbrev-ref HEAD
+    WORKING_DIRECTORY ${CMAKE_SOURCE_DIR}
+    OUTPUT_VARIABLE GIT_BRANCH
+    OUTPUT_STRIP_TRAILING_WHITESPACE
+  )
+
+  macro (git_log_format FORMAT_CHARS VAR_NAME)
+    execute_process(
+      COMMAND ${GIT_EXECUTABLE} log -1 --pretty=format:%${FORMAT_CHARS}
+      WORKING_DIRECTORY ${CMAKE_SOURCE_DIR}
+      OUTPUT_VARIABLE ${VAR_NAME}
+      OUTPUT_STRIP_TRAILING_WHITESPACE
+    )
+  endmacro()
+
+  git_log_format(an GIT_AUTHOR_NAME)
+  git_log_format(ae GIT_AUTHOR_EMAIL)
+  git_log_format(cn GIT_COMMITTER_NAME)
+  git_log_format(ce GIT_COMMITTER_EMAIL)
+  git_log_format(B GIT_COMMIT_MESSAGE)
+
+  message("Git exe: ${GIT_EXECUTABLE}")
+  message("Git branch: ${GIT_BRANCH}")
+  message("Git author: ${GIT_AUTHOR_NAME}")
+  message("Git e-mail: ${GIT_AUTHOR_EMAIL}")
+  message("Git committer name: ${GIT_COMMITTER_NAME}")
+  message("Git committer e-mail: ${GIT_COMMITTER_EMAIL}")
+  message("Git commit message: ${GIT_COMMIT_MESSAGE}")
+
+endif()
+
+############################# Macros #########################################
+
+#
+# This macro converts from the full path format gcov outputs:
+#
+#   /path/to/project/root/build/#path#to#project#root#subdir#the_file.c.gcov
+#
+# to the original source file path the .gcov is for:
+#
+#   /path/to/project/root/subdir/the_file.c
+#
+macro(get_source_path_from_gcov_filename _SRC_FILENAME _GCOV_FILENAME)
+
+  # /path/to/project/root/build/#path#to#project#root#subdir#the_file.c.gcov
+  # ->
+  # #path#to#project#root#subdir#the_file.c.gcov
+  get_filename_component(_GCOV_FILENAME_WEXT ${_GCOV_FILENAME} NAME)
+
+  # #path#to#project#root#subdir#the_file.c.gcov -> /path/to/project/root/subdir/the_file.c
+  string(REGEX REPLACE "\\.gcov$" "" SRC_FILENAME_TMP ${_GCOV_FILENAME_WEXT})
+  string(REGEX REPLACE "\#" "/" SRC_FILENAME_TMP ${SRC_FILENAME_TMP})
+ set(${_SRC_FILENAME} "${SRC_FILENAME_TMP}") +endmacro() + +############################################################################## + +# Get the coverage data. +file(GLOB_RECURSE GCDA_FILES "${COV_PATH}" "*.gcda") +message("Process GCDA files:") +message("===============================") + +# Get a list of all the object directories needed by gcov +# (The directories the .gcda files and .o files are found in) +# and run gcov on those. +foreach(GCDA ${GCDA_FILES}) + get_filename_component(GCDA_DIR ${GCDA} PATH) + + # + # The -p below refers to "Preserve path components", + # This means that the generated gcov filename of a source file will + # keep the original files entire filepath, but / is replaced with #. + # Example: + # + # /path/to/project/root/build/CMakeFiles/the_file.dir/subdir/the_file.c.gcda + # ------------------------------------------------------------------------------ + # File '/path/to/project/root/subdir/the_file.c' + # Lines executed:68.34% of 199 + # /path/to/project/root/subdir/the_file.c:creating '#path#to#project#root#subdir#the_file.c.gcov' + # + # If -p is not specified then the file is named only "the_file.c.gcov" + # + execute_process( + COMMAND ${GCOV_EXECUTABLE} -p -o ${GCDA_DIR} ${GCDA} >/dev/null + WORKING_DIRECTORY ${GCDA_DIR} + ) +endforeach() + +# TODO: Make these be absolute path +file(GLOB_RECURSE ALL_GCOV_FILES "${COV_PATH}" "*.gcov") + +# Get only the filenames to use for filtering. +#set(COVERAGE_SRCS_NAMES "") +#foreach (COVSRC ${COVERAGE_SRCS}) +# get_filename_component(COVSRC_NAME ${COVSRC} NAME) +# message("${COVSRC} -> ${COVSRC_NAME}") +# list(APPEND COVERAGE_SRCS_NAMES "${COVSRC_NAME}") +#endforeach() + +# +# Filter out all but the gcov files we want. +# +# We do this by comparing the list of COVERAGE_SRCS filepaths that the +# user wants the coverage data for with the paths of the generated .gcov files, +# so that we only keep the relevant gcov files. +# +# Example: +# COVERAGE_SRCS = +# /path/to/project/root/subdir/the_file.c +# +# ALL_GCOV_FILES = +# /path/to/project/root/build/#path#to#project#root#subdir#the_file.c.gcov +# /path/to/project/root/build/#path#to#project#root#subdir#other_file.c.gcov +# +# Result should be: +# GCOV_FILES = +# /path/to/project/root/build/#path#to#project#root#subdir#the_file.c.gcov +# +set(GCOV_FILES "") +#message("Look in coverage sources: ${COVERAGE_SRCS}") +message("\nFilter out unwanted GCOV files:") +message("===============================") + +set(COVERAGE_SRCS_REMAINING ${COVERAGE_SRCS}) + +foreach (GCOV_FILE ${ALL_GCOV_FILES}) + + # + # /path/to/project/root/build/#path#to#project#root#subdir#the_file.c.gcov + # -> + # /path/to/project/root/subdir/the_file.c + get_source_path_from_gcov_filename(GCOV_SRC_PATH ${GCOV_FILE}) + + # Is this in the list of source files? + # TODO: We want to match against relative path filenames from the source file root... + list(FIND COVERAGE_SRCS ${GCOV_SRC_PATH} WAS_FOUND) + + if (NOT WAS_FOUND EQUAL -1) + message("YES: ${GCOV_FILE}") + list(APPEND GCOV_FILES ${GCOV_FILE}) + + # We remove it from the list, so we don't bother searching for it again. + # Also files left in COVERAGE_SRCS_REMAINING after this loop ends should + # have coverage data generated from them (no lines are covered). 
+    list(REMOVE_ITEM COVERAGE_SRCS_REMAINING ${GCOV_SRC_PATH})
+  else()
+    message("NO: ${GCOV_FILE}")
+  endif()
+endforeach()
+
+# TODO: Enable setting these
+set(JSON_SERVICE_NAME "travis-ci")
+set(JSON_SERVICE_JOB_ID $ENV{TRAVIS_JOB_ID})
+
+set(JSON_TEMPLATE
+"{
+  \"service_name\": \"\@JSON_SERVICE_NAME\@\",
+  \"service_job_id\": \"\@JSON_SERVICE_JOB_ID\@\",
+  \"source_files\": \@JSON_GCOV_FILES\@
+}"
+)
+
+set(SRC_FILE_TEMPLATE
+"{
+      \"name\": \"\@GCOV_SRC_REL_PATH\@\",
+      \"source_digest\": \"\@GCOV_CONTENTS_MD5\@\",
+      \"coverage\": \@GCOV_FILE_COVERAGE\@
+  }"
+)
+
+message("\nGenerate JSON for files:")
+message("=========================")
+
+set(JSON_GCOV_FILES "[")
+
+# Read the GCOV files line by line and get the coverage data.
+foreach (GCOV_FILE ${GCOV_FILES})
+
+  get_source_path_from_gcov_filename(GCOV_SRC_PATH ${GCOV_FILE})
+  file(RELATIVE_PATH GCOV_SRC_REL_PATH "${PROJECT_ROOT}" "${GCOV_SRC_PATH}")
+
+  # The new coveralls API doesn't need the entire source (Yay!)
+  # However, still keeping that part for now. Will cleanup in the future.
+  file(MD5 "${GCOV_SRC_PATH}" GCOV_CONTENTS_MD5)
+  message("MD5: ${GCOV_SRC_PATH} = ${GCOV_CONTENTS_MD5}")
+
+  # Loads the gcov file as a list of lines.
+  # (We first open the file and replace all occurrences of [] with _
+  #  because CMake will fail to parse a line containing unmatched brackets...
+  #  also the \ escaping \n in macros screws things up.)
+  # https://public.kitware.com/Bug/view.php?id=15369
+  file(READ ${GCOV_FILE} GCOV_CONTENTS)
+  string(REPLACE "[" "_" GCOV_CONTENTS "${GCOV_CONTENTS}")
+  string(REPLACE "]" "_" GCOV_CONTENTS "${GCOV_CONTENTS}")
+  string(REPLACE "\\" "_" GCOV_CONTENTS "${GCOV_CONTENTS}")
+  file(WRITE ${GCOV_FILE}_tmp "${GCOV_CONTENTS}")
+
+  file(STRINGS ${GCOV_FILE}_tmp GCOV_LINES)
+  list(LENGTH GCOV_LINES LINE_COUNT)
+
+  # Instead of trying to parse the source from the
+  # gcov file, simply read the file contents from the source file.
+  # (Parsing it from the gcov is hard because C-code uses ; in many places
+  #  which also happens to be the same as the CMake list delimiter).
+  file(READ ${GCOV_SRC_PATH} GCOV_FILE_SOURCE)
+
+  string(REPLACE "\\" "\\\\" GCOV_FILE_SOURCE "${GCOV_FILE_SOURCE}")
+  string(REGEX REPLACE "\"" "\\\\\"" GCOV_FILE_SOURCE "${GCOV_FILE_SOURCE}")
+  string(REPLACE "\t" "\\\\t" GCOV_FILE_SOURCE "${GCOV_FILE_SOURCE}")
+  string(REPLACE "\r" "\\\\r" GCOV_FILE_SOURCE "${GCOV_FILE_SOURCE}")
+  string(REPLACE "\n" "\\\\n" GCOV_FILE_SOURCE "${GCOV_FILE_SOURCE}")
+  # According to http://json.org/ these should be escaped as well.
+  # Don't know how to do that in CMake however...
+  #string(REPLACE "\b" "\\\\b" GCOV_FILE_SOURCE "${GCOV_FILE_SOURCE}")
+  #string(REPLACE "\f" "\\\\f" GCOV_FILE_SOURCE "${GCOV_FILE_SOURCE}")
+  #string(REGEX REPLACE "\u([a-fA-F0-9]{4})" "\\\\u\\1" GCOV_FILE_SOURCE "${GCOV_FILE_SOURCE}")
+
+  # We want a json array of coverage data as a single string
+  # start building them from the contents of the .gcov
+  set(GCOV_FILE_COVERAGE "[")
+
+  set(GCOV_LINE_COUNT 1) # Line number for the .gcov.
+  set(DO_SKIP 0)
+  foreach (GCOV_LINE ${GCOV_LINES})
+    #message("${GCOV_LINE}")
+    # Example of what we're parsing:
+    # Hitcount  |Line | Source
+    # "        8:   26:        if (!allowed || (strlen(allowed) == 0))"
+    string(REGEX REPLACE
+      "^([^:]*):([^:]*):(.*)$"
+      "\\1;\\2;\\3"
+      RES
+      "${GCOV_LINE}")
+
+    # Check if we should exclude lines using the Lcov syntax.
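+    # These markers are written by hand in the sources under test, e.g.
+    # (illustrative snippet, not taken from this repo):
+    #   // LCOV_EXCL_START
+    #   void debug_only() { ... }    // whole block excluded from coverage
+    #   // LCOV_EXCL_END
+    #   risky_call();                // LCOV_EXCL_LINE -- single line excluded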
+ string(REGEX MATCH "LCOV_EXCL_START" START_SKIP "${GCOV_LINE}") + string(REGEX MATCH "LCOV_EXCL_END" END_SKIP "${GCOV_LINE}") + string(REGEX MATCH "LCOV_EXCL_LINE" LINE_SKIP "${GCOV_LINE}") + + set(RESET_SKIP 0) + if (LINE_SKIP AND NOT DO_SKIP) + set(DO_SKIP 1) + set(RESET_SKIP 1) + endif() + + if (START_SKIP) + set(DO_SKIP 1) + message("${GCOV_LINE_COUNT}: Start skip") + endif() + + if (END_SKIP) + set(DO_SKIP 0) + endif() + + list(LENGTH RES RES_COUNT) + + if (RES_COUNT GREATER 2) + list(GET RES 0 HITCOUNT) + list(GET RES 1 LINE) + list(GET RES 2 SOURCE) + + string(STRIP ${HITCOUNT} HITCOUNT) + string(STRIP ${LINE} LINE) + + # Lines with 0 line numbers are metadata and can be ignored. + if (NOT ${LINE} EQUAL 0) + + if (DO_SKIP) + set(GCOV_FILE_COVERAGE "${GCOV_FILE_COVERAGE}null, ") + else() + # Translate the hitcount into valid JSON values. + if (${HITCOUNT} STREQUAL "#####") + set(GCOV_FILE_COVERAGE "${GCOV_FILE_COVERAGE}0, ") + elseif (${HITCOUNT} STREQUAL "-") + set(GCOV_FILE_COVERAGE "${GCOV_FILE_COVERAGE}null, ") + else() + set(GCOV_FILE_COVERAGE "${GCOV_FILE_COVERAGE}${HITCOUNT}, ") + endif() + endif() + endif() + else() + message(WARNING "Failed to properly parse line (RES_COUNT = ${RES_COUNT}) ${GCOV_FILE}:${GCOV_LINE_COUNT}\n-->${GCOV_LINE}") + endif() + + if (RESET_SKIP) + set(DO_SKIP 0) + endif() + math(EXPR GCOV_LINE_COUNT "${GCOV_LINE_COUNT}+1") + endforeach() + + message("${GCOV_LINE_COUNT} of ${LINE_COUNT} lines read!") + + # Advanced way of removing the trailing comma in the JSON array. + # "[1, 2, 3, " -> "[1, 2, 3" + string(REGEX REPLACE ",[ ]*$" "" GCOV_FILE_COVERAGE ${GCOV_FILE_COVERAGE}) + + # Append the trailing ] to complete the JSON array. + set(GCOV_FILE_COVERAGE "${GCOV_FILE_COVERAGE}]") + + # Generate the final JSON for this file. + message("Generate JSON for file: ${GCOV_SRC_REL_PATH}...") + string(CONFIGURE ${SRC_FILE_TEMPLATE} FILE_JSON) + + set(JSON_GCOV_FILES "${JSON_GCOV_FILES}${FILE_JSON}, ") +endforeach() + +# Loop through all files we couldn't find any coverage for +# as well, and generate JSON for those as well with 0% coverage. +foreach(NOT_COVERED_SRC ${COVERAGE_SRCS_REMAINING}) + + # Loads the source file as a list of lines. + file(STRINGS ${NOT_COVERED_SRC} SRC_LINES) + + set(GCOV_FILE_COVERAGE "[") + set(GCOV_FILE_SOURCE "") + + foreach (SOURCE ${SRC_LINES}) + set(GCOV_FILE_COVERAGE "${GCOV_FILE_COVERAGE}0, ") + + string(REPLACE "\\" "\\\\" SOURCE "${SOURCE}") + string(REGEX REPLACE "\"" "\\\\\"" SOURCE "${SOURCE}") + string(REPLACE "\t" "\\\\t" SOURCE "${SOURCE}") + string(REPLACE "\r" "\\\\r" SOURCE "${SOURCE}") + set(GCOV_FILE_SOURCE "${GCOV_FILE_SOURCE}${SOURCE}\\n") + endforeach() + + # Remove trailing comma, and complete JSON array with ] + string(REGEX REPLACE ",[ ]*$" "" GCOV_FILE_COVERAGE ${GCOV_FILE_COVERAGE}) + set(GCOV_FILE_COVERAGE "${GCOV_FILE_COVERAGE}]") + + # Generate the final JSON for this file. + string(CONFIGURE ${SRC_FILE_TEMPLATE} FILE_JSON) + set(JSON_GCOV_FILES "${JSON_GCOV_FILES}${FILE_JSON}, ") +endforeach() + +# Get rid of trailing comma. +string(REGEX REPLACE ",[ ]*$" "" JSON_GCOV_FILES ${JSON_GCOV_FILES}) +set(JSON_GCOV_FILES "${JSON_GCOV_FILES}]") + +# Generate the final complete JSON! 
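+# The final output is shaped by JSON_TEMPLATE and SRC_FILE_TEMPLATE above;
+# a sketch with illustrative values:
+#
+#   {
+#     "service_name": "travis-ci",
+#     "service_job_id": "123456789",
+#     "source_files": [
+#       { "name": "subdir/the_file.c",
+#         "source_digest": "<md5 of the source file>",
+#         "coverage": [null, 8, 0, null] }
+#     ]
+#   }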
+message("Generate final JSON...") +string(CONFIGURE ${JSON_TEMPLATE} JSON) + +file(WRITE "${COVERALLS_OUTPUT_FILE}" "${JSON}") +message("###########################################################################") +message("Generated coveralls JSON containing coverage data:") +message("${COVERALLS_OUTPUT_FILE}") +message("###########################################################################") diff --git a/cmake/cross_compiling/android.cmake b/cmake/cross_compiling/android.cmake new file mode 100644 index 00000000000..11a803ff031 --- /dev/null +++ b/cmake/cross_compiling/android.cmake @@ -0,0 +1,85 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +if(NOT ARM_TARGET_OS STREQUAL "android") + return() +endif() + +set(ANDROID TRUE) +add_definitions(-DLITE_WITH_LINUX) + +if(NOT DEFINED ANDROID_NDK) + set(ANDROID_NDK $ENV{NDK_ROOT}) + if(NOT ANDROID_NDK) + message(FATAL_ERROR "Must set ANDROID_NDK or env NDK_ROOT") + endif() +endif() + +if(ARM_TARGET_LANG STREQUAL "gcc") + # gcc do not need set lang on android + set(ARM_TARGET_LANG "") +endif() + +if(NOT DEFINED ANDROID_API_LEVEL) + set(ANDROID_API_LEVEL "22") +endif() + +# then check input arm abi +if(ARM_TARGET_ARCH_ABI STREQUAL "armv7hf") + message(FATAL_ERROR "ANDROID does not support hardfp on v7 use armv7 instead.") +endif() + +set(ANDROID_ARCH_ABI ${ARM_TARGET_ARCH_ABI} CACHE STRING "Choose Android Arch ABI") +if(ARM_TARGET_ARCH_ABI STREQUAL "armv8") + set(ANDROID_ARCH_ABI "arm64-v8a") +endif() + +if(ARM_TARGET_ARCH_ABI STREQUAL "armv7") + set(ANDROID_ARCH_ABI "armeabi-v7a") +endif() + +check_input_var(ANDROID_ARCH_ABI DEFAULT ${ANDROID_ARCH_ABI} LIST "arm64-v8a" "armeabi-v7a" + "armeabi-v6" "armeabi" "mips" "mips64" "x86" "x86_64") +check_input_var(ANDROID_STL_TYPE DEFAULT "c++_static" LIST "c++_static" "gnustl_static" "c++_shared") + +if(ANDROID_ARCH_ABI STREQUAL "armeabi-v7a") + message(STATUS "armeabi-v7a use softfp by default.") + set(CMAKE_ANDROID_ARM_NEON ON) + message(STATUS "NEON is enabled on arm-v7a with softfp.") +endif() + +set(CMAKE_SYSTEM_NAME Android) +set(CMAKE_SYSTEM_VERSION ${ANDROID_API_LEVEL}) +set(CMAKE_ANDROID_ARCH_ABI ${ANDROID_ARCH_ABI}) +set(CMAKE_ANDROID_NDK ${ANDROID_NDK}) +set(CMAKE_ANDROID_NDK_TOOLCHAIN_VERSION ${ARM_TARGET_LANG}) +set(CMAKE_ANDROID_STL_TYPE ${ANDROID_STL_TYPE}) + +if (ARM_TARGET_LANG STREQUAL "clang") + if(ARM_TARGET_ARCH_ABI STREQUAL "armv8") + set(triple aarch64-v8a-linux-android) + elseif(ARM_TARGET_ARCH_ABI STREQUAL "armv7") + set(triple arm-v7a-linux-android) + set(LITE_WITH_OPENMP OFF CACHE STRING "Due to libomp's bug(For ARM64, it has been fixed by https://reviews.llvm.org/D19879, but still exists on ARM32), disable OpenMP on armv7 when cross-compiling using Clang" FORCE) + else() + message(FATAL_ERROR "Clang do not support this ${ARM_TARGET_ARCH_ABI}, use armv8 or armv7") + endif() + + set(CMAKE_C_COMPILER clang) + set(CMAKE_C_COMPILER_TARGET ${triple}) + set(CMAKE_CXX_COMPILER clang++) + 
+    set(CMAKE_CXX_COMPILER_TARGET ${triple})
+
+    message(STATUS "CMAKE_CXX_COMPILER_TARGET: ${CMAKE_CXX_COMPILER_TARGET}")
+endif()
diff --git a/cmake/cross_compiling/armlinux.cmake b/cmake/cross_compiling/armlinux.cmake
new file mode 100644
index 00000000000..98f23d43005
--- /dev/null
+++ b/cmake/cross_compiling/armlinux.cmake
@@ -0,0 +1,41 @@
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+#     http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+if(NOT ARM_TARGET_OS STREQUAL "armlinux")
+    return()
+endif()
+
+set(ARMLINUX TRUE)
+add_definitions(-DLITE_WITH_LINUX)
+set(CMAKE_SYSTEM_NAME Linux)
+
+check_input_var(ARMLINUX_ARCH_ABI DEFAULT ${ARM_TARGET_ARCH_ABI} LIST "armv8" "armv7" "armv7hf")
+
+if(ARMLINUX_ARCH_ABI STREQUAL "armv8")
+    set(CMAKE_SYSTEM_PROCESSOR aarch64)
+    set(CMAKE_C_COMPILER "aarch64-linux-gnu-gcc")
+    set(CMAKE_CXX_COMPILER "aarch64-linux-gnu-g++")
+endif()
+
+if(ARMLINUX_ARCH_ABI STREQUAL "armv7")
+    set(CMAKE_SYSTEM_PROCESSOR arm)
+    set(CMAKE_C_COMPILER "arm-linux-gnueabi-gcc")
+    set(CMAKE_CXX_COMPILER "arm-linux-gnueabi-g++")
+endif()
+
+if(ARMLINUX_ARCH_ABI STREQUAL "armv7hf")
+    set(CMAKE_SYSTEM_PROCESSOR arm)
+    set(CMAKE_C_COMPILER "arm-linux-gnueabihf-gcc")
+    set(CMAKE_CXX_COMPILER "arm-linux-gnueabihf-g++")
+endif()
diff --git a/cmake/cross_compiling/findar.cmake b/cmake/cross_compiling/findar.cmake
new file mode 100644
index 00000000000..bcb0dc70fd8
--- /dev/null
+++ b/cmake/cross_compiling/findar.cmake
@@ -0,0 +1,33 @@
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+if(NOT ARM_TARGET_LANG STREQUAL "clang")
+    # only clang needs the ar tool to be found explicitly
+    return()
+endif()
+
+if(NOT EXISTS "${CMAKE_CXX_COMPILER}")
+    message(FATAL_ERROR "Cannot find CMAKE_CXX_COMPILER ${CMAKE_CXX_COMPILER}")
+endif()
+
+get_filename_component(AR_PATH ${CMAKE_CXX_COMPILER} PATH)
+
+find_file(AR_TOOL NAMES llvm-ar PATHS ${AR_PATH})
+
+if(NOT AR_TOOL)
+    message(FATAL_ERROR "Failed to find AR_TOOL in ${AR_PATH}")
+else()
+    set(CMAKE_AR ${AR_TOOL})
+    message(STATUS "Found CMAKE_AR : " ${CMAKE_AR})
+endif()
diff --git a/cmake/cross_compiling/host.cmake b/cmake/cross_compiling/host.cmake
new file mode 100644
index 00000000000..b76dd600467
--- /dev/null
+++ b/cmake/cross_compiling/host.cmake
@@ -0,0 +1,48 @@
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+set(HOST_C_COMPILER $ENV{CC})
+set(HOST_CXX_COMPILER $ENV{CXX})
+
+if(IOS)
+    set(default_cc clang)
+    set(default_cxx clang++)
+else()
+    set(default_cc gcc)
+    set(default_cxx g++)
+endif()
+
+if(NOT HOST_C_COMPILER)
+    find_program(HOST_C_COMPILER NAMES ${default_cc} PATHS
+        /usr/bin
+        /usr/local/bin)
+endif()
+
+if(NOT HOST_CXX_COMPILER)
+    find_program(HOST_CXX_COMPILER NAMES ${default_cxx} PATHS
+        /usr/bin
+        /usr/local/bin)
+endif()
+
+if(NOT HOST_C_COMPILER OR NOT EXISTS ${HOST_C_COMPILER})
+    MESSAGE(FATAL_ERROR "Cannot find host C compiler. export CC=/path/to/cc")
+ENDIF()
+
+if(NOT HOST_CXX_COMPILER OR NOT EXISTS ${HOST_CXX_COMPILER})
+    MESSAGE(FATAL_ERROR "Cannot find host CXX compiler. export CXX=/path/to/cxx")
+ENDIF()
+
+MESSAGE(STATUS "Found host C compiler: " ${HOST_C_COMPILER})
+MESSAGE(STATUS "Found host CXX compiler: " ${HOST_CXX_COMPILER})
+
diff --git a/cmake/cross_compiling/ios.cmake b/cmake/cross_compiling/ios.cmake
new file mode 100644
index 00000000000..b8df182cd6d
--- /dev/null
+++ b/cmake/cross_compiling/ios.cmake
@@ -0,0 +1,691 @@
+# This file is part of the ios-cmake project. It was retrieved from
+# https://github.com/cristeab/ios-cmake.git, which is a fork of
+# https://code.google.com/p/ios-cmake/, which in turn is based off of
+# the Platform/Darwin.cmake and Platform/UnixPaths.cmake files which
+# are included with CMake 2.8.4
+#
+# The ios-cmake project is licensed under the new BSD license.
+#
+# Copyright (c) 2014, Bogdan Cristea and LTE Engineering Software,
+# Kitware, Inc., Insight Software Consortium. All rights reserved.
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions
+# are met:
+# 1. Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+#
+# 2. Redistributions in binary form must reproduce the above copyright
+# notice, this list of conditions and the following disclaimer in the
+# documentation and/or other materials provided with the distribution.
+#
+# 3. Neither the name of the copyright holder nor the names of its
+# contributors may be used to endorse or promote products derived from
+# this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
+# FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
+# COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
+# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
+# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
+# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
+# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
+# ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+# POSSIBILITY OF SUCH DAMAGE.
+#
+# This file is based off of the Platform/Darwin.cmake and
+# Platform/UnixPaths.cmake files which are included with CMake 2.8.4
+# It has been altered for iOS development.
+#
+# Updated by Alex Stewart (alexs.mac@gmail.com)
+#
+# *****************************************************************************
+#      Now maintained by Alexander Widerberg (widerbergaren [at] gmail.com)
+#                      under the BSD-3-Clause license
+#                   https://github.com/leetal/ios-cmake
+# *****************************************************************************
+#
+#                          INFORMATION / HELP
+#
+# The following arguments control the behaviour of this toolchain:
+#
+# PLATFORM: (default "OS")
+#    OS = Build for iPhoneOS.
+#    OS64 = Build for arm64 iphoneOS.
+#    OS64COMBINED = Build for arm64 x86_64 iphoneOS. Combined into FAT STATIC lib (supported on 3.14+ of CMake with "-G Xcode" argument ONLY)
+#    SIMULATOR = Build for x86 i386 iphoneOS Simulator.
+#    SIMULATOR64 = Build for x86_64 iphoneOS Simulator.
+#    TVOS = Build for arm64 tvOS.
+#    TVOSCOMBINED = Build for arm64 x86_64 tvOS. Combined into FAT STATIC lib (supported on 3.14+ of CMake with "-G Xcode" argument ONLY)
+#    SIMULATOR_TVOS = Build for x86_64 tvOS Simulator.
+#    WATCHOS = Build for armv7k arm64_32 for watchOS.
+#    WATCHOSCOMBINED = Build for armv7k arm64_32 x86_64 watchOS. Combined into FAT STATIC lib (supported on 3.14+ of CMake with "-G Xcode" argument ONLY)
+#    SIMULATOR_WATCHOS = Build for x86_64 for watchOS Simulator.
+#
+# CMAKE_OSX_SYSROOT: Path to the SDK to use. By default this is
+#    automatically determined from PLATFORM and xcodebuild, but
+#    can also be manually specified (although this should not be required).
+#
+# CMAKE_DEVELOPER_ROOT: Path to the Developer directory for the platform
+#    being compiled for. By default this is automatically determined from
+#    CMAKE_OSX_SYSROOT, but can also be manually specified (although this should
+#    not be required).
+#
+# DEPLOYMENT_TARGET: Minimum SDK version to target. Default 2.0 on watchOS and 9.0 on tvOS+iOS
+#
+# ENABLE_BITCODE: (1|0) Enables or disables bitcode support. Default 1 (true)
+#
+# ENABLE_ARC: (1|0) Enables or disables ARC support. Default 1 (true, ARC enabled by default)
+#
+# ENABLE_VISIBILITY: (1|0) Enables or disables symbol visibility support. Default 0 (false, visibility hidden by default)
+#
+# ARCHS: (armv7 armv7s armv7k arm64 arm64_32 i386 x86_64) If specified, will override the default architectures for the given PLATFORM
+#    OS = armv7 armv7s arm64 (if applicable)
+#    OS64 = arm64 (if applicable)
+#    SIMULATOR = i386
+#    SIMULATOR64 = x86_64
+#    TVOS = arm64
+#    SIMULATOR_TVOS = x86_64 (i386 has since long been deprecated)
+#    WATCHOS = armv7k arm64_32 (if applicable)
+#    SIMULATOR_WATCHOS = x86_64 (i386 has since long been deprecated)
+#
+# This toolchain defines the following variables for use externally:
+#
+# XCODE_VERSION: Version number (not including Build version) of Xcode detected.
+# SDK_VERSION: Version of SDK being used.
+# CMAKE_OSX_ARCHITECTURES: Architectures being compiled for (generated from PLATFORM).
+#
+# This toolchain defines the following macros for use externally:
+#
+# set_xcode_property (TARGET XCODE_PROPERTY XCODE_VALUE XCODE_VARIANT)
+#   A convenience macro for setting xcode specific properties on targets.
+#   Available variants are: All, Release, RelWithDebInfo, Debug, MinSizeRel
+#   example: set_xcode_property (myioslib IPHONEOS_DEPLOYMENT_TARGET "3.1" "all").
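+#
+#   Example invocation (illustrative; within Paddle-Lite this file is
+#   include()d from preproject.cmake and PLATFORM is derived from
+#   ARM_TARGET_OS rather than being passed on the command line):
+#     cmake .. -DCMAKE_TOOLCHAIN_FILE=cmake/cross_compiling/ios.cmake \
+#              -DPLATFORM=OS64 -DDEPLOYMENT_TARGET=9.0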
+#
+# find_host_package (PROGRAM ARGS)
+#   A macro used to find executable programs on the host system, not within the
+#   environment. Thanks to the android-cmake project for providing the
+#   command.
+#
+# ******************************** DEPRECATIONS *******************************
+#
+# IOS_DEPLOYMENT_TARGET: (Deprecated) Alias to DEPLOYMENT_TARGET
+# CMAKE_IOS_DEVELOPER_ROOT: (Deprecated) Alias to CMAKE_DEVELOPER_ROOT
+# IOS_PLATFORM: (Deprecated) Alias to PLATFORM
+# IOS_ARCH: (Deprecated) Alias to ARCHS
+#
+# *****************************************************************************
+#
+
+## Lite settings
+if (ARM_TARGET_OS STREQUAL "ios")
+    set(PLATFORM "OS")
+elseif(ARM_TARGET_OS STREQUAL "ios64")
+    set(PLATFORM "OS64")
+else()
+    return()
+endif()
+
+# If ARM_TARGET_ARCH_ABI is not specified, all supported default archs are used.
+if(ARM_TARGET_ARCH_ABI STREQUAL "armv7"
+        OR ARM_TARGET_ARCH_ABI STREQUAL "armv7hf"
+        OR ARM_TARGET_ARCH_ABI STREQUAL "armeabi-v7a")
+    set(ARCHS "armv7")
+elseif(ARM_TARGET_ARCH_ABI STREQUAL "armv8"
+        OR ARM_TARGET_ARCH_ABI STREQUAL "arm64-v8a")
+    set(ARCHS "arm64")
+# else() all default choice: armv7 armv7s arm64
+endif()
+
+if(PLATFORM STREQUAL "OS64" AND ARCHS STREQUAL "armv7")
+    message(FATAL_ERROR "Cannot build iOS64 with armv7")
+endif()
+
+# TODO(xxx): enable omp on ios
+set(LITE_WITH_OPENMP OFF CACHE STRING "Disable OpenMP when cross-compiling for Android and iOS" FORCE)
+set(ARM_TARGET_LANG "clang" CACHE STRING "Force the use of clang on iOS" FORCE)
+
+add_definitions(-DLITE_WITH_IPHONE)
+## End lite settings
+
+# Fix for PThread library not in path
+set(CMAKE_THREAD_LIBS_INIT "-lpthread")
+set(CMAKE_HAVE_THREADS_LIBRARY 1)
+set(CMAKE_USE_WIN32_THREADS_INIT 0)
+set(CMAKE_USE_PTHREADS_INIT 1)
+
+# Cache what generator is used
+set(USED_CMAKE_GENERATOR "${CMAKE_GENERATOR}" CACHE STRING "Expose CMAKE_GENERATOR" FORCE)
+
+if(${CMAKE_VERSION} VERSION_GREATER_EQUAL "3.14")
+  set(MODERN_CMAKE YES)
+  message(STATUS "Merging integrated CMake 3.14+ iOS, tvOS, watchOS, macOS toolchain(s) with this toolchain!")
+endif()
+
+# Get the Xcode version being used.
+execute_process(COMMAND xcodebuild -version
+  OUTPUT_VARIABLE XCODE_VERSION
+  ERROR_QUIET
+  OUTPUT_STRIP_TRAILING_WHITESPACE)
+string(REGEX MATCH "Xcode [0-9\\.]+" XCODE_VERSION "${XCODE_VERSION}")
+string(REGEX REPLACE "Xcode ([0-9\\.]+)" "\\1" XCODE_VERSION "${XCODE_VERSION}")
+message(STATUS "Building with Xcode version: ${XCODE_VERSION}")
+
+######## ALIASES (DEPRECATION WARNINGS)
+
+if(DEFINED IOS_PLATFORM)
+  set(PLATFORM ${IOS_PLATFORM})
+  message(DEPRECATION "IOS_PLATFORM argument is DEPRECATED. Consider using the new PLATFORM argument instead.")
+endif()
+
+if(DEFINED IOS_DEPLOYMENT_TARGET)
+  set(DEPLOYMENT_TARGET ${IOS_DEPLOYMENT_TARGET})
+  message(DEPRECATION "IOS_DEPLOYMENT_TARGET argument is DEPRECATED. Consider using the new DEPLOYMENT_TARGET argument instead.")
+endif()
+
+if(DEFINED CMAKE_IOS_DEVELOPER_ROOT)
+  set(CMAKE_DEVELOPER_ROOT ${CMAKE_IOS_DEVELOPER_ROOT})
+  message(DEPRECATION "CMAKE_IOS_DEVELOPER_ROOT argument is DEPRECATED. Consider using the new CMAKE_DEVELOPER_ROOT argument instead.")
+endif()
+
+if(DEFINED IOS_ARCH)
+  set(ARCHS ${IOS_ARCH})
+  message(DEPRECATION "IOS_ARCH argument is DEPRECATED. Consider using the new ARCHS argument instead.")
+endif()
+
+######## END ALIASES
+
+# Unset the FORCE on cache variables if in try_compile()
+set(FORCE_CACHE FORCE)
+get_property(_CMAKE_IN_TRY_COMPILE GLOBAL PROPERTY IN_TRY_COMPILE)
+if(_CMAKE_IN_TRY_COMPILE)
+  unset(FORCE_CACHE)
+endif()
+
+# Default to building for iPhoneOS if not specified otherwise, and we cannot
+# determine the platform from the CMAKE_OSX_ARCHITECTURES variable. The use
+# of CMAKE_OSX_ARCHITECTURES is such that try_compile() projects can correctly
+# determine the value of PLATFORM from the root project, as
+# CMAKE_OSX_ARCHITECTURES is propagated to them by CMake.
+if(NOT DEFINED PLATFORM)
+  if (CMAKE_OSX_ARCHITECTURES)
+    if(CMAKE_OSX_ARCHITECTURES MATCHES ".*arm.*" AND CMAKE_OSX_SYSROOT MATCHES ".*iphoneos.*")
+      set(PLATFORM "OS")
+    elseif(CMAKE_OSX_ARCHITECTURES MATCHES "i386" AND CMAKE_OSX_SYSROOT MATCHES ".*iphonesimulator.*")
+      set(PLATFORM "SIMULATOR")
+    elseif(CMAKE_OSX_ARCHITECTURES MATCHES "x86_64" AND CMAKE_OSX_SYSROOT MATCHES ".*iphonesimulator.*")
+      set(PLATFORM "SIMULATOR64")
+    elseif(CMAKE_OSX_ARCHITECTURES MATCHES "arm64" AND CMAKE_OSX_SYSROOT MATCHES ".*appletvos.*")
+      set(PLATFORM "TVOS")
+    elseif(CMAKE_OSX_ARCHITECTURES MATCHES "x86_64" AND CMAKE_OSX_SYSROOT MATCHES ".*appletvsimulator.*")
+      set(PLATFORM "SIMULATOR_TVOS")
+    elseif(CMAKE_OSX_ARCHITECTURES MATCHES ".*armv7k.*" AND CMAKE_OSX_SYSROOT MATCHES ".*watchos.*")
+      set(PLATFORM "WATCHOS")
+    elseif(CMAKE_OSX_ARCHITECTURES MATCHES "i386" AND CMAKE_OSX_SYSROOT MATCHES ".*watchsimulator.*")
+      set(PLATFORM "SIMULATOR_WATCHOS")
+    endif()
+  endif()
+  if (NOT PLATFORM)
+    set(PLATFORM "OS")
+  endif()
+endif()
+
+set(PLATFORM_INT "${PLATFORM}" CACHE STRING "Type of platform for which the build targets.")
+
+# Handle the case where we are targeting iOS and a version above 10.0 (32-bit support dropped officially)
+if(PLATFORM_INT STREQUAL "OS" AND DEPLOYMENT_TARGET VERSION_GREATER_EQUAL 10.0)
+  set(PLATFORM_INT "OS64")
+  message(STATUS "Targeting minimum SDK version ${DEPLOYMENT_TARGET}. Dropping 32-bit support.")
+elseif(PLATFORM_INT STREQUAL "SIMULATOR" AND DEPLOYMENT_TARGET VERSION_GREATER_EQUAL 10.0)
+  set(PLATFORM_INT "SIMULATOR64")
+  message(STATUS "Targeting minimum SDK version ${DEPLOYMENT_TARGET}. Dropping 32-bit support.")
+endif()
+
+# Determine the platform name and architectures for use in xcodebuild commands
+# from the specified PLATFORM name.
+if(PLATFORM_INT STREQUAL "OS")
+  set(SDK_NAME iphoneos)
+  if(NOT ARCHS)
+    set(ARCHS armv7 armv7s arm64)
+  endif()
+elseif(PLATFORM_INT STREQUAL "OS64")
+  set(SDK_NAME iphoneos)
+  if(NOT ARCHS)
+    if (XCODE_VERSION VERSION_GREATER 10.0)
+      set(ARCHS arm64) # Add arm64e when Apple has fixed the integration issues with it; libarclite_iphoneos.a is currently missing bitcode markers, for example
+    else()
+      set(ARCHS arm64)
+    endif()
+  endif()
+elseif(PLATFORM_INT STREQUAL "OS64COMBINED")
+  set(SDK_NAME iphoneos)
+  if(MODERN_CMAKE)
+    if(NOT ARCHS)
+      if (XCODE_VERSION VERSION_GREATER 10.0)
+        set(ARCHS arm64 x86_64) # Add arm64e when Apple has fixed the integration issues with it; libarclite_iphoneos.a is currently missing bitcode markers, for example
+      else()
+        set(ARCHS arm64 x86_64)
+      endif()
+    endif()
+  else()
+    message(FATAL_ERROR "Please make sure that you are running CMake 3.14+ to make the OS64COMBINED setting work")
+  endif()
+elseif(PLATFORM_INT STREQUAL "SIMULATOR")
+  set(SDK_NAME iphonesimulator)
+  if(NOT ARCHS)
+    set(ARCHS i386)
+  endif()
+  message(DEPRECATION "SIMULATOR IS DEPRECATED. Consider using SIMULATOR64 instead.")
+elseif(PLATFORM_INT STREQUAL "SIMULATOR64")
+  set(SDK_NAME iphonesimulator)
+  if(NOT ARCHS)
+    set(ARCHS x86_64)
+  endif()
+elseif(PLATFORM_INT STREQUAL "TVOS")
+  set(SDK_NAME appletvos)
+  if(NOT ARCHS)
+    set(ARCHS arm64)
+  endif()
+elseif (PLATFORM_INT STREQUAL "TVOSCOMBINED")
+  set(SDK_NAME appletvos)
+  if(MODERN_CMAKE)
+    if(NOT ARCHS)
+      set(ARCHS arm64 x86_64)
+    endif()
+  else()
+    message(FATAL_ERROR "Please make sure that you are running CMake 3.14+ to make the TVOSCOMBINED setting work")
+  endif()
+elseif(PLATFORM_INT STREQUAL "SIMULATOR_TVOS")
+  set(SDK_NAME appletvsimulator)
+  if(NOT ARCHS)
+    set(ARCHS x86_64)
+  endif()
+elseif(PLATFORM_INT STREQUAL "WATCHOS")
+  set(SDK_NAME watchos)
+  if(NOT ARCHS)
+    if (XCODE_VERSION VERSION_GREATER 10.0)
+      set(ARCHS armv7k arm64_32)
+    else()
+      set(ARCHS armv7k)
+    endif()
+  endif()
+elseif(PLATFORM_INT STREQUAL "WATCHOSCOMBINED")
+  set(SDK_NAME watchos)
+  if(MODERN_CMAKE)
+    if(NOT ARCHS)
+      if (XCODE_VERSION VERSION_GREATER 10.0)
+        set(ARCHS armv7k arm64_32 i386)
+      else()
+        set(ARCHS armv7k i386)
+      endif()
+    endif()
+  else()
+    message(FATAL_ERROR "Please make sure that you are running CMake 3.14+ to make the WATCHOSCOMBINED setting work")
+  endif()
+elseif(PLATFORM_INT STREQUAL "SIMULATOR_WATCHOS")
+  set(SDK_NAME watchsimulator)
+  if(NOT ARCHS)
+    set(ARCHS i386)
+  endif()
+else()
+  message(FATAL_ERROR "Invalid PLATFORM: ${PLATFORM_INT}")
+endif()
+message(STATUS "Configuring ${SDK_NAME} build for platform: ${PLATFORM_INT}, architecture(s): ${ARCHS}")
+
+if(MODERN_CMAKE AND PLATFORM_INT MATCHES ".*COMBINED" AND NOT USED_CMAKE_GENERATOR MATCHES "Xcode")
+  message(FATAL_ERROR "The COMBINED options only work with the Xcode generator, -G Xcode")
+endif()
+
+# If user did not specify the SDK root to use, then query xcodebuild for it.
+execute_process(COMMAND xcodebuild -version -sdk ${SDK_NAME} Path
+  OUTPUT_VARIABLE CMAKE_OSX_SYSROOT_INT
+  ERROR_QUIET
+  OUTPUT_STRIP_TRAILING_WHITESPACE)
+if (NOT DEFINED CMAKE_OSX_SYSROOT_INT AND NOT DEFINED CMAKE_OSX_SYSROOT)
+  message(SEND_ERROR "Please make sure that Xcode is installed and that the toolchain "
+    "is pointing to the correct path. Please run: "
+    "sudo xcode-select -s /Applications/Xcode.app/Contents/Developer "
+    "and see if that fixes the problem for you.")
+  message(FATAL_ERROR "Invalid CMAKE_OSX_SYSROOT: ${CMAKE_OSX_SYSROOT} "
+    "does not exist.")
+elseif(DEFINED CMAKE_OSX_SYSROOT)
+  message(STATUS "Using SDK: ${CMAKE_OSX_SYSROOT} for platform: ${PLATFORM_INT} when checking compatibility")
+elseif(DEFINED CMAKE_OSX_SYSROOT_INT)
+  message(STATUS "Using SDK: ${CMAKE_OSX_SYSROOT_INT} for platform: ${PLATFORM_INT}")
+  set(CMAKE_OSX_SYSROOT "${CMAKE_OSX_SYSROOT_INT}" CACHE INTERNAL "")
+endif()
+
+# Set Xcode property for SDKROOT as well if Xcode generator is used
+if(USED_CMAKE_GENERATOR MATCHES "Xcode")
+  set(CMAKE_OSX_SYSROOT "${SDK_NAME}" CACHE INTERNAL "")
+  if(NOT DEFINED CMAKE_XCODE_ATTRIBUTE_DEVELOPMENT_TEAM)
+    set(CMAKE_XCODE_ATTRIBUTE_DEVELOPMENT_TEAM 123456789A CACHE INTERNAL "")
+  endif()
+endif()
+
+# Specify minimum version of deployment target.
+if(NOT DEFINED DEPLOYMENT_TARGET)
+  if (PLATFORM_INT STREQUAL "WATCHOS" OR PLATFORM_INT STREQUAL "SIMULATOR_WATCHOS")
+    # Unless specified, SDK version 2.0 is used by default as minimum target version (watchOS).
+    set(DEPLOYMENT_TARGET "2.0"
+            CACHE STRING "Minimum SDK version to build for." )
+  else()
+    # Unless specified, SDK version 9.0 is used by default as minimum target version (iOS, tvOS).
+ set(DEPLOYMENT_TARGET "9.0" + CACHE STRING "Minimum SDK version to build for." ) + endif() + message(STATUS "Using the default min-version since DEPLOYMENT_TARGET not provided!") +endif() +# Use bitcode or not +if(NOT DEFINED ENABLE_BITCODE AND NOT ARCHS MATCHES "((^|, )(i386|x86_64))+") + # Unless specified, enable bitcode support by default + message(STATUS "Enabling bitcode support by default. ENABLE_BITCODE not provided!") + set(ENABLE_BITCODE TRUE) +elseif(NOT DEFINED ENABLE_BITCODE) + message(STATUS "Disabling bitcode support by default on simulators. ENABLE_BITCODE not provided for override!") + set(ENABLE_BITCODE FALSE) +endif() +set(ENABLE_BITCODE_INT ${ENABLE_BITCODE} CACHE BOOL "Whether or not to enable bitcode" ${FORCE_CACHE}) +# Use ARC or not +if(NOT DEFINED ENABLE_ARC) + # Unless specified, enable ARC support by default + set(ENABLE_ARC TRUE) + message(STATUS "Enabling ARC support by default. ENABLE_ARC not provided!") +endif() +set(ENABLE_ARC_INT ${ENABLE_ARC} CACHE BOOL "Whether or not to enable ARC" ${FORCE_CACHE}) +# Use hidden visibility or not +if(NOT DEFINED ENABLE_VISIBILITY) + # Unless specified, disable symbols visibility by default + set(ENABLE_VISIBILITY FALSE) + message(STATUS "Hiding symbols visibility by default. ENABLE_VISIBILITY not provided!") +endif() +set(ENABLE_VISIBILITY_INT ${ENABLE_VISIBILITY} CACHE BOOL "Whether or not to hide symbols (-fvisibility=hidden)" ${FORCE_CACHE}) +# Get the SDK version information. +execute_process(COMMAND xcodebuild -sdk ${CMAKE_OSX_SYSROOT} -version SDKVersion + OUTPUT_VARIABLE SDK_VERSION + ERROR_QUIET + OUTPUT_STRIP_TRAILING_WHITESPACE) + +# Find the Developer root for the specific iOS platform being compiled for +# from CMAKE_OSX_SYSROOT. Should be ../../ from SDK specified in +# CMAKE_OSX_SYSROOT. There does not appear to be a direct way to obtain +# this information from xcrun or xcodebuild. +if (NOT DEFINED CMAKE_DEVELOPER_ROOT AND NOT USED_CMAKE_GENERATOR MATCHES "Xcode") + get_filename_component(PLATFORM_SDK_DIR ${CMAKE_OSX_SYSROOT} PATH) + get_filename_component(CMAKE_DEVELOPER_ROOT ${PLATFORM_SDK_DIR} PATH) + + if (NOT DEFINED CMAKE_DEVELOPER_ROOT) + message(FATAL_ERROR "Invalid CMAKE_DEVELOPER_ROOT: " + "${CMAKE_DEVELOPER_ROOT} does not exist.") + endif() +endif() +# Find the C & C++ compilers for the specified SDK. +if(NOT CMAKE_C_COMPILER) + execute_process(COMMAND xcrun -sdk ${CMAKE_OSX_SYSROOT} -find clang + OUTPUT_VARIABLE CMAKE_C_COMPILER + ERROR_QUIET + OUTPUT_STRIP_TRAILING_WHITESPACE) + message(STATUS "Using C compiler: ${CMAKE_C_COMPILER}") +endif() +if(NOT CMAKE_CXX_COMPILER) + execute_process(COMMAND xcrun -sdk ${CMAKE_OSX_SYSROOT} -find clang++ + OUTPUT_VARIABLE CMAKE_CXX_COMPILER + ERROR_QUIET + OUTPUT_STRIP_TRAILING_WHITESPACE) + message(STATUS "Using CXX compiler: ${CMAKE_CXX_COMPILER}") +endif() +# Find (Apple's) libtool. +execute_process(COMMAND xcrun -sdk ${CMAKE_OSX_SYSROOT} -find libtool + OUTPUT_VARIABLE BUILD_LIBTOOL + ERROR_QUIET + OUTPUT_STRIP_TRAILING_WHITESPACE) +message(STATUS "Using libtool: ${BUILD_LIBTOOL}") +# Configure libtool to be used instead of ar + ranlib to build static libraries. +# This is required on Xcode 7+, but should also work on previous versions of +# Xcode. +set(CMAKE_C_CREATE_STATIC_LIBRARY + "${BUILD_LIBTOOL} -static -o ") +set(CMAKE_CXX_CREATE_STATIC_LIBRARY + "${BUILD_LIBTOOL} -static -o ") +# Get the version of Darwin (OS X) of the host. 
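+# (e.g. `uname -r` prints a Darwin kernel version such as 18.7.0, which
+#  corresponds to macOS 10.14; value shown is illustrative)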
+execute_process(COMMAND uname -r + OUTPUT_VARIABLE CMAKE_HOST_SYSTEM_VERSION + ERROR_QUIET + OUTPUT_STRIP_TRAILING_WHITESPACE) +# CMake 3.14+ support building for iOS, watchOS and tvOS out of the box. +if(MODERN_CMAKE) + if(SDK_NAME MATCHES "iphone") + set(CMAKE_SYSTEM_NAME iOS CACHE INTERNAL "" ${FORCE_CACHE}) + elseif(SDK_NAME MATCHES "appletv") + set(CMAKE_SYSTEM_NAME tvOS CACHE INTERNAL "" ${FORCE_CACHE}) + elseif(SDK_NAME MATCHES "watch") + set(CMAKE_SYSTEM_NAME watchOS CACHE INTERNAL "" ${FORCE_CACHE}) + endif() + + # Provide flags for a combined FAT library build on newer CMake versions + if(PLATFORM_INT MATCHES ".*COMBINED") + set(CMAKE_XCODE_ATTRIBUTE_ONLY_ACTIVE_ARCH NO CACHE INTERNAL "") + set(CMAKE_IOS_INSTALL_COMBINED YES CACHE INTERNAL "") + message(STATUS "Will combine built (static) artifacts into FAT lib...") + endif() +else() + # Legacy code path prior to CMake 3.14 + set(CMAKE_SYSTEM_NAME Darwin CACHE INTERNAL "" ${FORCE_CACHE}) +endif() +# Standard settings. +set(CMAKE_SYSTEM_VERSION ${SDK_VERSION} CACHE INTERNAL "") +set(UNIX TRUE CACHE BOOL "") +set(APPLE TRUE CACHE BOOL "") +set(IOS TRUE CACHE BOOL "") +set(CMAKE_AR ar CACHE FILEPATH "" FORCE) +set(CMAKE_RANLIB ranlib CACHE FILEPATH "" FORCE) +set(CMAKE_STRIP strip CACHE FILEPATH "" FORCE) +# Set the architectures for which to build. +set(CMAKE_OSX_ARCHITECTURES ${ARCHS} CACHE STRING "Build architecture for iOS") +# Change the type of target generated for try_compile() so it'll work when cross-compiling +set(CMAKE_TRY_COMPILE_TARGET_TYPE STATIC_LIBRARY) +# All iOS/Darwin specific settings - some may be redundant. +set(CMAKE_SHARED_LIBRARY_PREFIX "lib") +set(CMAKE_SHARED_LIBRARY_SUFFIX ".dylib") +set(CMAKE_SHARED_MODULE_PREFIX "lib") +set(CMAKE_SHARED_MODULE_SUFFIX ".so") +set(CMAKE_C_COMPILER_ABI ELF) +set(CMAKE_CXX_COMPILER_ABI ELF) +set(CMAKE_C_HAS_ISYSROOT 1) +set(CMAKE_CXX_HAS_ISYSROOT 1) +set(CMAKE_MODULE_EXISTS 1) +set(CMAKE_DL_LIBS "") +set(CMAKE_C_OSX_COMPATIBILITY_VERSION_FLAG "-compatibility_version ") +set(CMAKE_C_OSX_CURRENT_VERSION_FLAG "-current_version ") +set(CMAKE_CXX_OSX_COMPATIBILITY_VERSION_FLAG "${CMAKE_C_OSX_COMPATIBILITY_VERSION_FLAG}") +set(CMAKE_CXX_OSX_CURRENT_VERSION_FLAG "${CMAKE_C_OSX_CURRENT_VERSION_FLAG}") + +if(ARCHS MATCHES "((^|, )(arm64|arm64e|x86_64))+") + set(CMAKE_C_SIZEOF_DATA_PTR 8) + set(CMAKE_CXX_SIZEOF_DATA_PTR 8) + if(ARCHS MATCHES "((^|, )(arm64|arm64e))+") + set(CMAKE_SYSTEM_PROCESSOR "arm64") + else() + set(CMAKE_SYSTEM_PROCESSOR "x86_64") + endif() + message(STATUS "Using a data_ptr size of 8") +else() + set(CMAKE_C_SIZEOF_DATA_PTR 4) + set(CMAKE_CXX_SIZEOF_DATA_PTR 4) + set(CMAKE_SYSTEM_PROCESSOR "arm") + message(STATUS "Using a data_ptr size of 4") +endif() + +message(STATUS "Building for minimum ${SDK_NAME} version: ${DEPLOYMENT_TARGET}" + " (SDK version: ${SDK_VERSION})") +# Note that only Xcode 7+ supports the newer more specific: +# -m${SDK_NAME}-version-min flags, older versions of Xcode use: +# -m(ios/ios-simulator)-version-min instead. +if(PLATFORM_INT STREQUAL "OS" OR PLATFORM_INT STREQUAL "OS64") + if(XCODE_VERSION VERSION_LESS 7.0) + set(SDK_NAME_VERSION_FLAGS + "-mios-version-min=${DEPLOYMENT_TARGET}") + else() + # Xcode 7.0+ uses flags we can build directly from SDK_NAME. 
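+    # (e.g. SDK_NAME=iphoneos with DEPLOYMENT_TARGET=9.0 expands to
+    #  -miphoneos-version-min=9.0)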
+    set(SDK_NAME_VERSION_FLAGS
+        "-m${SDK_NAME}-version-min=${DEPLOYMENT_TARGET}")
+  endif()
+elseif(PLATFORM_INT STREQUAL "TVOS")
+  set(SDK_NAME_VERSION_FLAGS
+      "-mtvos-version-min=${DEPLOYMENT_TARGET}")
+elseif(PLATFORM_INT STREQUAL "SIMULATOR_TVOS")
+  set(SDK_NAME_VERSION_FLAGS
+      "-mtvos-simulator-version-min=${DEPLOYMENT_TARGET}")
+elseif(PLATFORM_INT STREQUAL "WATCHOS")
+  set(SDK_NAME_VERSION_FLAGS
+      "-mwatchos-version-min=${DEPLOYMENT_TARGET}")
+elseif(PLATFORM_INT STREQUAL "SIMULATOR_WATCHOS")
+  set(SDK_NAME_VERSION_FLAGS
+      "-mwatchos-simulator-version-min=${DEPLOYMENT_TARGET}")
+else()
+  # SIMULATOR or SIMULATOR64 both use -mios-simulator-version-min.
+  set(SDK_NAME_VERSION_FLAGS
+      "-mios-simulator-version-min=${DEPLOYMENT_TARGET}")
+endif()
+message(STATUS "Version flags set to: ${SDK_NAME_VERSION_FLAGS}")
+set(CMAKE_OSX_DEPLOYMENT_TARGET ${DEPLOYMENT_TARGET} CACHE STRING
+    "Set CMake deployment target" ${FORCE_CACHE})
+
+if(ENABLE_BITCODE_INT)
+  set(BITCODE "-fembed-bitcode")
+  set(CMAKE_XCODE_ATTRIBUTE_BITCODE_GENERATION_MODE bitcode CACHE INTERNAL "")
+  message(STATUS "Enabling bitcode support.")
+else()
+  set(BITCODE "")
+  set(CMAKE_XCODE_ATTRIBUTE_ENABLE_BITCODE NO CACHE INTERNAL "")
+  message(STATUS "Disabling bitcode support.")
+endif()
+
+if(ENABLE_ARC_INT)
+  set(FOBJC_ARC "-fobjc-arc")
+  set(CMAKE_XCODE_ATTRIBUTE_CLANG_ENABLE_OBJC_ARC YES CACHE INTERNAL "")
+  message(STATUS "Enabling ARC support.")
+else()
+  set(FOBJC_ARC "-fno-objc-arc")
+  set(CMAKE_XCODE_ATTRIBUTE_CLANG_ENABLE_OBJC_ARC NO CACHE INTERNAL "")
+  message(STATUS "Disabling ARC support.")
+endif()
+
+if(NOT ENABLE_VISIBILITY_INT)
+  set(VISIBILITY "-fvisibility=hidden")
+  set(CMAKE_XCODE_ATTRIBUTE_GCC_SYMBOLS_PRIVATE_EXTERN YES CACHE INTERNAL "")
+  message(STATUS "Hiding symbols (-fvisibility=hidden).")
+else()
+  set(VISIBILITY "")
+  set(CMAKE_XCODE_ATTRIBUTE_GCC_SYMBOLS_PRIVATE_EXTERN NO CACHE INTERNAL "")
+endif()
+
+# Check if the Xcode generator is used, since that will handle these flags automagically
+if(USED_CMAKE_GENERATOR MATCHES "Xcode")
+  message(STATUS "Not setting any manual command-line buildflags, since Xcode is selected as generator.")
+else()
+  set(CMAKE_C_FLAGS
+  "${SDK_NAME_VERSION_FLAGS} ${BITCODE} -fobjc-abi-version=2 ${FOBJC_ARC} ${CMAKE_C_FLAGS}")
+  # Hidden visibility is required for C++ on iOS.
+  set(CMAKE_CXX_FLAGS
+  "${SDK_NAME_VERSION_FLAGS} ${BITCODE} ${VISIBILITY} -fvisibility-inlines-hidden -fobjc-abi-version=2 ${FOBJC_ARC} ${CMAKE_CXX_FLAGS}")
+  set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS} -O0 -g ${CMAKE_CXX_FLAGS_DEBUG}")
+  set(CMAKE_CXX_FLAGS_MINSIZEREL "${CMAKE_CXX_FLAGS} -DNDEBUG -Os -ffast-math ${CMAKE_CXX_FLAGS_MINSIZEREL}")
+  set(CMAKE_CXX_FLAGS_RELWITHDEBINFO "${CMAKE_CXX_FLAGS} -DNDEBUG -O2 -g -ffast-math ${CMAKE_CXX_FLAGS_RELWITHDEBINFO}")
+  set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS} -DNDEBUG -O3 -ffast-math ${CMAKE_CXX_FLAGS_RELEASE}")
+  set(CMAKE_C_LINK_FLAGS "${SDK_NAME_VERSION_FLAGS} -Wl,-search_paths_first ${CMAKE_C_LINK_FLAGS}")
+  set(CMAKE_CXX_LINK_FLAGS "${SDK_NAME_VERSION_FLAGS} -Wl,-search_paths_first ${CMAKE_CXX_LINK_FLAGS}")
+
+  # In order to ensure that the updated compiler flags are used in try_compile()
+  # tests, we have to forcibly set them in the CMake cache, not merely set them
+  # in the local scope.
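+  # (i.e. each variable listed below is re-set(... CACHE STRING "") so that
+  #  try_compile() child projects see exactly the same flags)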
+ list(APPEND VARS_TO_FORCE_IN_CACHE + CMAKE_C_FLAGS + CMAKE_CXX_FLAGS + CMAKE_CXX_FLAGS_DEBUG + CMAKE_CXX_FLAGS_RELWITHDEBINFO + CMAKE_CXX_FLAGS_MINSIZEREL + CMAKE_CXX_FLAGS_RELEASE + CMAKE_C_LINK_FLAGS + CMAKE_CXX_LINK_FLAGS) + foreach(VAR_TO_FORCE ${VARS_TO_FORCE_IN_CACHE}) + set(${VAR_TO_FORCE} "${${VAR_TO_FORCE}}" CACHE STRING "") + endforeach() +endif() + +set(CMAKE_PLATFORM_HAS_INSTALLNAME 1) +set(CMAKE_SHARED_LINKER_FLAGS "-rpath @executable_path/Frameworks -rpath @loader_path/Frameworks") +set(CMAKE_SHARED_LIBRARY_CREATE_C_FLAGS "-dynamiclib -Wl,-headerpad_max_install_names") +set(CMAKE_SHARED_MODULE_CREATE_C_FLAGS "-bundle -Wl,-headerpad_max_install_names") +set(CMAKE_SHARED_MODULE_LOADER_C_FLAG "-Wl,-bundle_loader,") +set(CMAKE_SHARED_MODULE_LOADER_CXX_FLAG "-Wl,-bundle_loader,") +set(CMAKE_FIND_LIBRARY_SUFFIXES ".tbd" ".dylib" ".so" ".a") +set(CMAKE_SHARED_LIBRARY_SONAME_C_FLAG "-install_name") + +# Hack: if a new cmake (which uses CMAKE_INSTALL_NAME_TOOL) runs on an old +# build tree (where install_name_tool was hardcoded) and where +# CMAKE_INSTALL_NAME_TOOL isn't in the cache and still cmake didn't fail in +# CMakeFindBinUtils.cmake (because it isn't rerun) hardcode +# CMAKE_INSTALL_NAME_TOOL here to install_name_tool, so it behaves as it did +# before, Alex. +if(NOT DEFINED CMAKE_INSTALL_NAME_TOOL) + find_program(CMAKE_INSTALL_NAME_TOOL install_name_tool) +endif(NOT DEFINED CMAKE_INSTALL_NAME_TOOL) + +# Set the find root to the iOS developer roots and to user defined paths. +set(CMAKE_FIND_ROOT_PATH ${CMAKE_DEVELOPER_ROOT} ${CMAKE_OSX_SYSROOT_INT} + ${CMAKE_PREFIX_PATH} CACHE STRING "Root path that will be prepended to all search paths") +# Default to searching for frameworks first. +set(CMAKE_FIND_FRAMEWORK FIRST) +# Set up the default search directories for frameworks. +set(CMAKE_FRAMEWORK_PATH + ${CMAKE_DEVELOPER_ROOT}/Library/Frameworks + ${CMAKE_DEVELOPER_ROOT}/Library/PrivateFrameworks + ${CMAKE_OSX_SYSROOT_INT}/System/Library/Frameworks + ${CMAKE_FRAMEWORK_PATH} CACHE STRING "Frameworks search paths") + +# By default, search both the specified iOS SDK and the remainder of the host filesystem. +if(NOT CMAKE_FIND_ROOT_PATH_MODE_PROGRAM) + set(CMAKE_FIND_ROOT_PATH_MODE_PROGRAM BOTH CACHE STRING "" ${FORCE_CACHE}) +endif() +if(NOT CMAKE_FIND_ROOT_PATH_MODE_LIBRARY) + set(CMAKE_FIND_ROOT_PATH_MODE_LIBRARY BOTH CACHE STRING "" ${FORCE_CACHE}) +endif() +if(NOT CMAKE_FIND_ROOT_PATH_MODE_INCLUDE) + set(CMAKE_FIND_ROOT_PATH_MODE_INCLUDE BOTH CACHE STRING "" ${FORCE_CACHE}) +endif() +if(NOT CMAKE_FIND_ROOT_PATH_MODE_PACKAGE) + set(CMAKE_FIND_ROOT_PATH_MODE_PACKAGE BOTH CACHE STRING "" ${FORCE_CACHE}) +endif() + +# +# Some helper-macros below to simplify and beautify the CMakeFile +# + +# This little macro lets you set any XCode specific property. +macro(set_xcode_property TARGET XCODE_PROPERTY XCODE_VALUE XCODE_RELVERSION) + set(XCODE_RELVERSION_I "${XCODE_RELVERSION}") + if(XCODE_RELVERSION_I STREQUAL "All") + set_property(TARGET ${TARGET} PROPERTY + XCODE_ATTRIBUTE_${XCODE_PROPERTY} "${XCODE_VALUE}") + else() + set_property(TARGET ${TARGET} PROPERTY + XCODE_ATTRIBUTE_${XCODE_PROPERTY}[variant=${XCODE_RELVERSION_I}] "${XCODE_VALUE}") + endif() +endmacro(set_xcode_property) +# This macro lets you find executable programs on the host system. 
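+# Usage mirrors find_package(); e.g. (hypothetical call):
+#   find_host_package(Git REQUIRED)
+# temporarily lifts the iOS find-root so the host's Git is located.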
+macro(find_host_package) + set(CMAKE_FIND_ROOT_PATH_MODE_PROGRAM NEVER) + set(CMAKE_FIND_ROOT_PATH_MODE_LIBRARY NEVER) + set(CMAKE_FIND_ROOT_PATH_MODE_INCLUDE NEVER) + set(CMAKE_FIND_ROOT_PATH_MODE_PACKAGE NEVER) + set(IOS FALSE) + find_package(${ARGN}) + set(IOS TRUE) + set(CMAKE_FIND_ROOT_PATH_MODE_PROGRAM BOTH) + set(CMAKE_FIND_ROOT_PATH_MODE_LIBRARY BOTH) + set(CMAKE_FIND_ROOT_PATH_MODE_INCLUDE BOTH) + set(CMAKE_FIND_ROOT_PATH_MODE_PACKAGE BOTH) +endmacro(find_host_package) diff --git a/cmake/cross_compiling/npu.cmake b/cmake/cross_compiling/npu.cmake new file mode 100644 index 00000000000..ec5a3188967 --- /dev/null +++ b/cmake/cross_compiling/npu.cmake @@ -0,0 +1,70 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +if(NOT LITE_WITH_NPU) + return() +endif() + +if(NOT DEFINED NPU_DDK_ROOT) + set(NPU_DDK_ROOT $ENV{NPU_DDK_ROOT}) + if(NOT NPU_DDK_ROOT) + message(FATAL_ERROR "Must set NPU_DDK_ROOT or env NPU_DDK_ROOT when LITE_WITH_NPU=ON") + endif() +endif() + +message(STATUS "NPU_DDK_ROOT: ${NPU_DDK_ROOT}") +find_path(NPU_DDK_INC NAMES HiAiModelManagerService.h + PATHS ${NPU_DDK_ROOT}/include NO_DEFAULT_PATH) +if(NOT NPU_DDK_INC) + message(FATAL_ERROR "Can not find HiAiModelManagerService.h in ${NPU_DDK_ROOT}/include") +endif() + +include_directories("${NPU_DDK_ROOT}") + +find_library(NPU_DDK_HIAI_FILE NAMES hiai + PATHS ${NPU_DDK_ROOT}/lib64) + +find_library(NPU_DDK_IR_FILE NAMES hiai_ir + PATHS ${NPU_DDK_ROOT}/lib64) + +find_library(NPU_DDK_IR_BUILD_FILE NAMES hiai_ir_build + PATHS ${NPU_DDK_ROOT}/lib64) + +if(NOT NPU_DDK_HIAI_FILE) + message(FATAL_ERROR "Can not find NPU_DDK_HIAI_FILE in ${NPU_DDK_ROOT}") +else() + message(STATUS "Found NPU_DDK HIAI Library: ${NPU_DDK_HIAI_FILE}") + add_library(npu_ddk_hiai SHARED IMPORTED GLOBAL) + set_property(TARGET npu_ddk_hiai PROPERTY IMPORTED_LOCATION ${NPU_DDK_HIAI_FILE}) +endif() + +if(NOT NPU_DDK_IR_FILE) + message(FATAL_ERROR "Can not find NPU_DDK_IR_FILE in ${NPU_DDK_ROOT}") +else() + message(STATUS "Found NPU_DDK IR Library: ${NPU_DDK_IR_FILE}") + add_library(npu_ddk_ir SHARED IMPORTED GLOBAL) + set_property(TARGET npu_ddk_ir PROPERTY IMPORTED_LOCATION ${NPU_DDK_IR_FILE}) +endif() + +if(NOT NPU_DDK_IR_BUILD_FILE) + message(FATAL_ERROR "Can not find NPU_DDK_IR_BUILD_FILE in ${NPU_DDK_ROOT}") +else() + message(STATUS "Found NPU_DDK IR_BUILD Library: ${NPU_DDK_IR_BUILD_FILE}") + add_library(npu_ddk_ir_build SHARED IMPORTED GLOBAL) + set_property(TARGET npu_ddk_ir_build PROPERTY IMPORTED_LOCATION ${NPU_DDK_IR_BUILD_FILE}) +endif() + +set(npu_ddk_libs npu_ddk_hiai npu_ddk_ir npu_ddk_ir_build CACHE INTERNAL "npu ddk libs") + + diff --git a/cmake/cross_compiling/postproject.cmake b/cmake/cross_compiling/postproject.cmake new file mode 100644 index 00000000000..33254df03c4 --- /dev/null +++ b/cmake/cross_compiling/postproject.cmake @@ -0,0 +1,99 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. 
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+#     http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+if(NOT LITE_WITH_LIGHT_WEIGHT_FRAMEWORK)
+    return()
+endif()
+
+include(CheckCXXCompilerFlag)
+
+if(ANDROID)
+    include(cross_compiling/findar)
+
+    set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -llog -fPIC")
+    set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -llog -fPIC")
+endif()
+
+if(ARMLINUX)
+    if(ARMLINUX_ARCH_ABI STREQUAL "armv8")
+        set(CMAKE_CXX_FLAGS "-march=armv8-a ${CMAKE_CXX_FLAGS}")
+        set(CMAKE_C_FLAGS "-march=armv8-a ${CMAKE_C_FLAGS}")
+        message(STATUS "NEON is enabled on arm64-v8a")
+    endif()
+
+    if(ARMLINUX_ARCH_ABI STREQUAL "armv7")
+        set(CMAKE_CXX_FLAGS "-march=armv7-a -mfloat-abi=softfp -mfpu=neon-vfpv4 ${CMAKE_CXX_FLAGS}")
+        set(CMAKE_C_FLAGS "-march=armv7-a -mfloat-abi=softfp -mfpu=neon-vfpv4 ${CMAKE_C_FLAGS}")
+        message(STATUS "NEON is enabled on arm-v7a with softfp")
+    endif()
+
+    if(ARMLINUX_ARCH_ABI STREQUAL "armv7hf")
+        set(CMAKE_CXX_FLAGS "-march=armv7-a -mfloat-abi=hard -mfpu=neon-vfpv4 ${CMAKE_CXX_FLAGS}")
+        set(CMAKE_C_FLAGS "-march=armv7-a -mfloat-abi=hard -mfpu=neon-vfpv4 ${CMAKE_C_FLAGS}" )
+        message(STATUS "NEON is enabled on arm-v7a with hard float")
+    endif()
+endif()
+
+function(check_linker_flag)
+    foreach(flag ${ARGN})
+        set(CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} ${flag}")
+        check_cxx_compiler_flag("" out_var)
+        if(${out_var})
+            set(CMAKE_SHARED_LINKER_FLAGS "${CMAKE_SHARED_LINKER_FLAGS} ${flag}")
+        endif()
+    endforeach()
+    set(CMAKE_SHARED_LINKER_FLAGS ${CMAKE_SHARED_LINKER_FLAGS} PARENT_SCOPE)
+endfunction()
+set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++11")
+if (LITE_ON_TINY_PUBLISH)
+    set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -ffast-math -Ofast -Os -fno-exceptions -fomit-frame-pointer -fno-asynchronous-unwind-tables -fno-unwind-tables")
+    set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -flto -fvisibility=hidden -fvisibility-inlines-hidden -fdata-sections -ffunction-sections")
+    check_linker_flag(-Wl,--gc-sections)
+endif()
+
+if(LITE_WITH_OPENMP)
+    find_package(OpenMP REQUIRED)
+    if(OPENMP_FOUND OR OpenMP_CXX_FOUND)
+        add_definitions(-DARM_WITH_OMP)
+        set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} ${OpenMP_C_FLAGS}")
+        set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${OpenMP_CXX_FLAGS}")
+        message(STATUS "Found OpenMP ${OpenMP_VERSION} ${OpenMP_CXX_VERSION}")
+        message(STATUS "OpenMP C flags: ${OpenMP_C_FLAGS}")
+        message(STATUS "OpenMP CXX flags: ${OpenMP_CXX_FLAGS}")
+        message(STATUS "OpenMP OpenMP_CXX_LIB_NAMES: ${OpenMP_CXX_LIB_NAMES}")
+        message(STATUS "OpenMP OpenMP_CXX_LIBRARIES: ${OpenMP_CXX_LIBRARIES}")
+    else()
+        message(FATAL_ERROR "Could not find OpenMP!")
+    endif()
+endif()
+
+# third party cmake args
+set(CROSS_COMPILE_CMAKE_ARGS
+    "-DCMAKE_SYSTEM_NAME=${CMAKE_SYSTEM_NAME}"
+    "-DCMAKE_SYSTEM_VERSION=${CMAKE_SYSTEM_VERSION}")
+
+if(ANDROID)
+    set(CROSS_COMPILE_CMAKE_ARGS ${CROSS_COMPILE_CMAKE_ARGS}
+        "-DCMAKE_ANDROID_ARCH_ABI=${CMAKE_ANDROID_ARCH_ABI}"
+        "-DCMAKE_ANDROID_NDK=${CMAKE_ANDROID_NDK}"
+        "-DCMAKE_ANDROID_STL_TYPE=${CMAKE_ANDROID_STL_TYPE}"
+
"-DCMAKE_ANDROID_NDK_TOOLCHAIN_VERSION=${CMAKE_ANDROID_NDK_TOOLCHAIN_VERSION}") +endif() + +if(IOS) + set(CROSS_COMPILE_CMAKE_ARGS ${CROSS_COMPILE_CMAKE_ARGS} + "-DCMAKE_OSX_ARCHITECTURES=${CMAKE_OSX_ARCHITECTURES}" + "-DCMAKE_SYSTEM_PROCESSOR=${CMAKE_SYSTEM_PROCESSOR}" + "-DCMAKE_OSX_SYSROOT=${CMAKE_OSX_SYSROOT}") +endif() diff --git a/cmake/cross_compiling/preproject.cmake b/cmake/cross_compiling/preproject.cmake new file mode 100644 index 00000000000..813d1910fcf --- /dev/null +++ b/cmake/cross_compiling/preproject.cmake @@ -0,0 +1,59 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +if(NOT LITE_WITH_LIGHT_WEIGHT_FRAMEWORK) + return() +endif() + +cmake_minimum_required(VERSION 3.10) + +# define check function +function(check_input_var VAR_NAME) + set(options "") + set(oneValueArgs "") + set(multiValueArgs DEFAULT LIST) + cmake_parse_arguments(check_input_var "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) + + set(var_out "") + if(NOT DEFINED ${VAR_NAME}) + set(var_out ${check_input_var_DEFAULT}) + else() + set(var_out ${${VAR_NAME}}) + endif() + + if(NOT var_out IN_LIST check_input_var_LIST) + message(FATAL_ERROR "${VAR_NAME}:${var_out} must be in one of ${check_input_var_LIST}") + endif() + set(${VAR_NAME} ${var_out} PARENT_SCOPE) +endfunction(check_input_var) + +check_input_var(ARM_TARGET_OS DEFAULT "android" LIST "android" "armlinux" "ios" "ios64") +check_input_var(ARM_TARGET_ARCH_ABI DEFAULT "armv8" LIST "armv8" "armv7" "armv7hf" "arm64-v8a" "armeabi-v7a") +check_input_var(ARM_TARGET_LANG DEFAULT "gcc" LIST "gcc" "clang") +check_input_var(ARM_TARGET_LIB_TYPE DEFAULT "static" LIST "static" "shared") + +include(cross_compiling/armlinux) +include(cross_compiling/android) +include(cross_compiling/ios) +include(cross_compiling/host) + +if(NOT CMAKE_BUILD_TYPE) + set(CMAKE_BUILD_TYPE "Release" CACHE STRING "Default use Release in android" FORCE) +endif() + +if(NOT THIRD_PARTY_BUILD_TYPE) + set(THIRD_PARTY_BUILD_TYPE "MinSizeRel" CACHE STRING "Default use MinSizeRel in android" FORCE) +endif() + +message(STATUS "Lite ARM Compile ${ARM_TARGET_OS} with ${ARM_TARGET_ARCH_ABI} ${ARM_TARGET_LANG}") diff --git a/cmake/cuda.cmake b/cmake/cuda.cmake new file mode 100644 index 00000000000..735846db1db --- /dev/null +++ b/cmake/cuda.cmake @@ -0,0 +1,228 @@ +if(NOT WITH_GPU) + return() +endif() + +set(paddle_known_gpu_archs "30 35 50 52 60 61 70") +set(paddle_known_gpu_archs7 "30 35 50 52") +set(paddle_known_gpu_archs8 "30 35 50 52 60 61") +set(paddle_known_gpu_archs9 "30 35 50 52 60 61 70") +set(paddle_known_gpu_archs10 "30 35 50 52 60 61 70 75") + +###################################################################################### +# A function for automatic detection of GPUs installed (if autodetection is enabled) +# Usage: +# detect_installed_gpus(out_variable) +function(detect_installed_gpus out_variable) + if(NOT CUDA_gpu_detect_output) + set(cufile ${PROJECT_BINARY_DIR}/detect_cuda_archs.cu) + + file(WRITE ${cufile} 
"" + "#include \n" + "int main() {\n" + " int count = 0;\n" + " if (cudaSuccess != cudaGetDeviceCount(&count)) return -1;\n" + " if (count == 0) return -1;\n" + " for (int device = 0; device < count; ++device) {\n" + " cudaDeviceProp prop;\n" + " if (cudaSuccess == cudaGetDeviceProperties(&prop, device))\n" + " std::printf(\"%d.%d \", prop.major, prop.minor);\n" + " }\n" + " return 0;\n" + "}\n") + + execute_process(COMMAND "${CUDA_NVCC_EXECUTABLE}" "-ccbin=${CUDA_HOST_COMPILER}" + "--run" "${cufile}" + WORKING_DIRECTORY "${PROJECT_BINARY_DIR}/CMakeFiles/" + RESULT_VARIABLE nvcc_res OUTPUT_VARIABLE nvcc_out + ERROR_QUIET OUTPUT_STRIP_TRAILING_WHITESPACE) + + if(nvcc_res EQUAL 0) + # only keep the last line of nvcc_out + STRING(REGEX REPLACE ";" "\\\\;" nvcc_out "${nvcc_out}") + STRING(REGEX REPLACE "\n" ";" nvcc_out "${nvcc_out}") + list(GET nvcc_out -1 nvcc_out) + string(REPLACE "2.1" "2.1(2.0)" nvcc_out "${nvcc_out}") + set(CUDA_gpu_detect_output ${nvcc_out} CACHE INTERNAL "Returned GPU architetures from detect_installed_gpus tool" FORCE) + endif() + endif() + + if(NOT CUDA_gpu_detect_output) + message(STATUS "Automatic GPU detection failed. Building for all known architectures.") + set(${out_variable} ${paddle_known_gpu_archs} PARENT_SCOPE) + else() + set(${out_variable} ${CUDA_gpu_detect_output} PARENT_SCOPE) + endif() +endfunction() + + +######################################################################## +# Function for selecting GPU arch flags for nvcc based on CUDA_ARCH_NAME +# Usage: +# select_nvcc_arch_flags(out_variable) +function(select_nvcc_arch_flags out_variable) + # List of arch names + set(archs_names "Kepler" "Maxwell" "Pascal" "Volta" "Turing" "All" "Manual") + set(archs_name_default "All") + list(APPEND archs_names "Auto") + + # set CUDA_ARCH_NAME strings (so it will be seen as dropbox in CMake-Gui) + set(CUDA_ARCH_NAME ${archs_name_default} CACHE STRING "Select target NVIDIA GPU achitecture.") + set_property( CACHE CUDA_ARCH_NAME PROPERTY STRINGS "" ${archs_names} ) + mark_as_advanced(CUDA_ARCH_NAME) + + # verify CUDA_ARCH_NAME value + if(NOT ";${archs_names};" MATCHES ";${CUDA_ARCH_NAME};") + string(REPLACE ";" ", " archs_names "${archs_names}") + message(FATAL_ERROR "Only ${archs_names} architeture names are supported.") + endif() + + if(${CUDA_ARCH_NAME} STREQUAL "Manual") + set(CUDA_ARCH_BIN ${paddle_known_gpu_archs} CACHE STRING "Specify 'real' GPU architectures to build binaries for, BIN(PTX) format is supported") + set(CUDA_ARCH_PTX "50" CACHE STRING "Specify 'virtual' PTX architectures to build PTX intermediate code for") + mark_as_advanced(CUDA_ARCH_BIN CUDA_ARCH_PTX) + else() + unset(CUDA_ARCH_BIN CACHE) + unset(CUDA_ARCH_PTX CACHE) + endif() + + if(${CUDA_ARCH_NAME} STREQUAL "Kepler") + set(cuda_arch_bin "30 35") + elseif(${CUDA_ARCH_NAME} STREQUAL "Maxwell") + set(cuda_arch_bin "50") + elseif(${CUDA_ARCH_NAME} STREQUAL "Pascal") + set(cuda_arch_bin "60 61") + elseif(${CUDA_ARCH_NAME} STREQUAL "Volta") + set(cuda_arch_bin "70") + elseif(${CUDA_ARCH_NAME} STREQUAL "Turing") + set(cuda_arch_bin "75") + elseif(${CUDA_ARCH_NAME} STREQUAL "All") + set(cuda_arch_bin ${paddle_known_gpu_archs}) + elseif(${CUDA_ARCH_NAME} STREQUAL "Auto") + detect_installed_gpus(cuda_arch_bin) + else() # (${CUDA_ARCH_NAME} STREQUAL "Manual") + set(cuda_arch_bin ${CUDA_ARCH_BIN}) + endif() + + # remove dots and convert to lists + string(REGEX REPLACE "\\." "" cuda_arch_bin "${cuda_arch_bin}") + string(REGEX REPLACE "\\." 
"" cuda_arch_ptx "${CUDA_ARCH_PTX}") + string(REGEX MATCHALL "[0-9()]+" cuda_arch_bin "${cuda_arch_bin}") + string(REGEX MATCHALL "[0-9]+" cuda_arch_ptx "${cuda_arch_ptx}") + list(REMOVE_DUPLICATES cuda_arch_bin) + list(REMOVE_DUPLICATES cuda_arch_ptx) + + set(nvcc_flags "") + set(nvcc_archs_readable "") + + # Tell NVCC to add binaries for the specified GPUs + foreach(arch ${cuda_arch_bin}) + if(arch MATCHES "([0-9]+)\\(([0-9]+)\\)") + # User explicitly specified PTX for the concrete BIN + list(APPEND nvcc_flags -gencode arch=compute_${CMAKE_MATCH_2},code=sm_${CMAKE_MATCH_1}) + list(APPEND nvcc_archs_readable sm_${CMAKE_MATCH_1}) + else() + # User didn't explicitly specify PTX for the concrete BIN, we assume PTX=BIN + list(APPEND nvcc_flags -gencode arch=compute_${arch},code=sm_${arch}) + list(APPEND nvcc_archs_readable sm_${arch}) + endif() + endforeach() + + # Tell NVCC to add PTX intermediate code for the specified architectures + foreach(arch ${cuda_arch_ptx}) + list(APPEND nvcc_flags -gencode arch=compute_${arch},code=compute_${arch}) + list(APPEND nvcc_archs_readable compute_${arch}) + endforeach() + + string(REPLACE ";" " " nvcc_archs_readable "${nvcc_archs_readable}") + set(${out_variable} ${nvcc_flags} PARENT_SCOPE) + set(${out_variable}_readable ${nvcc_archs_readable} PARENT_SCOPE) +endfunction() + +message(STATUS "CUDA detected: " ${CUDA_VERSION}) +if (${CUDA_VERSION} LESS 7.0) + set(paddle_known_gpu_archs ${paddle_known_gpu_archs}) + add_definitions("-DPADDLE_CUDA_BINVER=\"60\"") +elseif (${CUDA_VERSION} LESS 8.0) # CUDA 7.x + set(paddle_known_gpu_archs ${paddle_known_gpu_archs7}) + list(APPEND CUDA_NVCC_FLAGS "-D_MWAITXINTRIN_H_INCLUDED") + list(APPEND CUDA_NVCC_FLAGS "-D__STRICT_ANSI__") + add_definitions("-DPADDLE_CUDA_BINVER=\"70\"") +elseif (${CUDA_VERSION} LESS 9.0) # CUDA 8.x + set(paddle_known_gpu_archs ${paddle_known_gpu_archs8}) + list(APPEND CUDA_NVCC_FLAGS "-D_MWAITXINTRIN_H_INCLUDED") + list(APPEND CUDA_NVCC_FLAGS "-D__STRICT_ANSI__") + # CUDA 8 may complain that sm_20 is no longer supported. Suppress the + # warning for now. + list(APPEND CUDA_NVCC_FLAGS "-Wno-deprecated-gpu-targets") + add_definitions("-DPADDLE_CUDA_BINVER=\"80\"") +elseif (${CUDA_VERSION} LESS 10.0) # CUDA 9.x + set(paddle_known_gpu_archs ${paddle_known_gpu_archs9}) + list(APPEND CUDA_NVCC_FLAGS "-D_MWAITXINTRIN_H_INCLUDED") + list(APPEND CUDA_NVCC_FLAGS "-D__STRICT_ANSI__") + add_definitions("-DPADDLE_CUDA_BINVER=\"90\"") +elseif (${CUDA_VERSION} LESS 11.0) # CUDA 10.x + set(paddle_known_gpu_archs ${paddle_known_gpu_archs10}) + list(APPEND CUDA_NVCC_FLAGS "-D_MWAITXINTRIN_H_INCLUDED") + list(APPEND CUDA_NVCC_FLAGS "-D__STRICT_ANSI__") + add_definitions("-DPADDLE_CUDA_BINVER=\"100\"") +endif() + +include_directories(${CUDA_INCLUDE_DIRS}) +if(NOT WITH_DSO) + if(WIN32) + set_property(GLOBAL PROPERTY CUDA_MODULES ${CUDNN_LIBRARY} ${CUDA_CUBLAS_LIBRARIES} ${CUDA_curand_LIBRARY}) + endif(WIN32) +endif(NOT WITH_DSO) + +# setting nvcc arch flags +select_nvcc_arch_flags(NVCC_FLAGS_EXTRA) +list(APPEND CUDA_NVCC_FLAGS ${NVCC_FLAGS_EXTRA}) +message(STATUS "Added CUDA NVCC flags for: ${NVCC_FLAGS_EXTRA_readable}") + +# Set C++11 support +set(CUDA_PROPAGATE_HOST_FLAGS OFF) + +# Release/Debug flags set by cmake. Such as -O3 -g -DNDEBUG etc. +# So, don't set these flags here. +if (NOT WIN32) # windows msvc2015 support c++11 natively. +# -std=c++11 -fPIC not recoginize by msvc, -Xcompiler will be added by cmake. 
+list(APPEND CUDA_NVCC_FLAGS "-std=c++11") +list(APPEND CUDA_NVCC_FLAGS "-Xcompiler -fPIC") +endif(NOT WIN32) + +if(WITH_FAST_MATH) + # Make use of fast math library. https://docs.nvidia.com/cuda/cuda-compiler-driver-nvcc/index.html + list(APPEND CUDA_NVCC_FLAGS "--use_fast_math") +endif() +# in cuda9, suppress cuda warning on eigen +list(APPEND CUDA_NVCC_FLAGS "-w") +# Set :expt-relaxed-constexpr to suppress Eigen warnings +list(APPEND CUDA_NVCC_FLAGS "--expt-relaxed-constexpr") + +if (NOT WIN32) + if(CMAKE_BUILD_TYPE STREQUAL "Debug") + list(APPEND CUDA_NVCC_FLAGS ${CMAKE_CXX_FLAGS_DEBUG}) + elseif(CMAKE_BUILD_TYPE STREQUAL "Release") + list(APPEND CUDA_NVCC_FLAGS ${CMAKE_CXX_FLAGS_RELEASE}) + elseif(CMAKE_BUILD_TYPE STREQUAL "RelWithDebInfo") + list(APPEND CUDA_NVCC_FLAGS ${CMAKE_CXX_FLAGS_RELWITHDEBINFO}) + elseif(CMAKE_BUILD_TYPE STREQUAL "MinSizeRel") + # nvcc 9 does not support -Os. Use Release flags instead + list(APPEND CUDA_NVCC_FLAGS ${CMAKE_CXX_FLAGS_RELEASE}) + endif() +else(NOT WIN32) + list(APPEND CUDA_NVCC_FLAGS "-Xcompiler \"/wd 4244 /wd 4267 /wd 4819\"") + list(APPEND CUDA_NVCC_FLAGS "--compiler-options;/bigobj") + if(CMAKE_BUILD_TYPE STREQUAL "Debug") + list(APPEND CUDA_NVCC_FLAGS "-g -G") + # match the cl's _ITERATOR_DEBUG_LEVEL + list(APPEND CUDA_NVCC_FLAGS "-D_DEBUG") + elseif(CMAKE_BUILD_TYPE STREQUAL "Release") + list(APPEND CUDA_NVCC_FLAGS "-O3 -DNDEBUG") + else() + message(FATAL "Windows only support Release or Debug build now. Please set visual studio build type to Release/Debug, x64 build.") +endif() +endif(NOT WIN32) + +mark_as_advanced(CUDA_BUILD_CUBIN CUDA_BUILD_EMULATION CUDA_VERBOSE_BUILD) +mark_as_advanced(CUDA_SDK_ROOT_DIR CUDA_SEPARABLE_COMPILATION) diff --git a/cmake/cudnn.cmake b/cmake/cudnn.cmake new file mode 100644 index 00000000000..fff1980637d --- /dev/null +++ b/cmake/cudnn.cmake @@ -0,0 +1,102 @@ +if(NOT WITH_GPU) + return() +endif() + +if(WIN32) + set(CUDNN_ROOT ${CUDA_TOOLKIT_ROOT_DIR}) +else(WIN32) + set(CUDNN_ROOT "/usr" CACHE PATH "CUDNN ROOT") +endif(WIN32) + +find_path(CUDNN_INCLUDE_DIR cudnn.h + PATHS ${CUDNN_ROOT} ${CUDNN_ROOT}/include + $ENV{CUDNN_ROOT} $ENV{CUDNN_ROOT}/include ${CUDA_TOOLKIT_INCLUDE} + NO_DEFAULT_PATH +) + +get_filename_component(__libpath_hist ${CUDA_CUDART_LIBRARY} PATH) + +set(TARGET_ARCH "x86_64") +if(NOT ${CMAKE_SYSTEM_PROCESSOR}) + set(TARGET_ARCH ${CMAKE_SYSTEM_PROCESSOR}) +endif() + +list(APPEND CUDNN_CHECK_LIBRARY_DIRS + ${CUDNN_ROOT} + ${CUDNN_ROOT}/lib64 + ${CUDNN_ROOT}/lib + ${CUDNN_ROOT}/lib/${TARGET_ARCH}-linux-gnu + ${CUDNN_ROOT}/local/cuda-${CUDA_VERSION}/targets/${TARGET_ARCH}-linux/lib/ + $ENV{CUDNN_ROOT} + $ENV{CUDNN_ROOT}/lib64 + $ENV{CUDNN_ROOT}/lib + /usr/lib + ${CUDA_TOOLKIT_ROOT_DIR} + ${CUDA_TOOLKIT_ROOT_DIR}/lib/x64 + ) +set(CUDNN_LIB_NAME "") +if (LINUX) +set(CUDNN_LIB_NAME "libcudnn.so") +endif(LINUX) + +if(WIN32) +# only support cudnn7 +set(CUDNN_LIB_NAME "cudnn.lib" "cudnn64_7.dll") +endif(WIN32) + +if(APPLE) +set(CUDNN_LIB_NAME "libcudnn.dylib" "libcudnn.so") +endif(APPLE) + +find_library(CUDNN_LIBRARY NAMES ${CUDNN_LIB_NAME} # libcudnn_static.a + PATHS ${CUDNN_CHECK_LIBRARY_DIRS} ${CUDNN_INCLUDE_DIR} ${__libpath_hist} + NO_DEFAULT_PATH + DOC "Path to cuDNN library.") + + +if(CUDNN_INCLUDE_DIR AND CUDNN_LIBRARY) + set(CUDNN_FOUND ON) +else() + set(CUDNN_FOUND OFF) +endif() + +if(CUDNN_FOUND) + file(READ ${CUDNN_INCLUDE_DIR}/cudnn.h CUDNN_VERSION_FILE_CONTENTS) + + get_filename_component(CUDNN_LIB_PATH ${CUDNN_LIBRARY} DIRECTORY) + + string(REGEX MATCH "define CUDNN_VERSION +([0-9]+)" + 
CUDNN_VERSION "${CUDNN_VERSION_FILE_CONTENTS}") + string(REGEX REPLACE "define CUDNN_VERSION +([0-9]+)" "\\1" + CUDNN_VERSION "${CUDNN_VERSION}") + + if("${CUDNN_VERSION}" STREQUAL "2000") + message(STATUS "Current cuDNN version is v2. ") + else() + string(REGEX MATCH "define CUDNN_MAJOR +([0-9]+)" CUDNN_MAJOR_VERSION + "${CUDNN_VERSION_FILE_CONTENTS}") + string(REGEX REPLACE "define CUDNN_MAJOR +([0-9]+)" "\\1" + CUDNN_MAJOR_VERSION "${CUDNN_MAJOR_VERSION}") + string(REGEX MATCH "define CUDNN_MINOR +([0-9]+)" CUDNN_MINOR_VERSION + "${CUDNN_VERSION_FILE_CONTENTS}") + string(REGEX REPLACE "define CUDNN_MINOR +([0-9]+)" "\\1" + CUDNN_MINOR_VERSION "${CUDNN_MINOR_VERSION}") + string(REGEX MATCH "define CUDNN_PATCHLEVEL +([0-9]+)" + CUDNN_PATCHLEVEL_VERSION "${CUDNN_VERSION_FILE_CONTENTS}") + string(REGEX REPLACE "define CUDNN_PATCHLEVEL +([0-9]+)" "\\1" + CUDNN_PATCHLEVEL_VERSION "${CUDNN_PATCHLEVEL_VERSION}") + + if(NOT CUDNN_MAJOR_VERSION) + set(CUDNN_VERSION "???") + else() + add_definitions("-DPADDLE_CUDNN_BINVER=\"${CUDNN_MAJOR_VERSION}\"") + math(EXPR CUDNN_VERSION + "${CUDNN_MAJOR_VERSION} * 1000 + + ${CUDNN_MINOR_VERSION} * 100 + ${CUDNN_PATCHLEVEL_VERSION}") + endif() + + message(STATUS "Current cuDNN header is ${CUDNN_INCLUDE_DIR}/cudnn.h. " + "Current cuDNN version is v${CUDNN_MAJOR_VERSION}. ") + + endif() +endif() diff --git a/cmake/cupti.cmake b/cmake/cupti.cmake new file mode 100644 index 00000000000..72ed0f1e585 --- /dev/null +++ b/cmake/cupti.cmake @@ -0,0 +1,41 @@ +if(NOT WITH_GPU) + return() +endif() + + +set(CUPTI_ROOT "/usr" CACHE PATH "CUPTI ROOT") +find_path(CUPTI_INCLUDE_DIR cupti.h + PATHS ${CUPTI_ROOT} ${CUPTI_ROOT}/include + $ENV{CUPTI_ROOT} $ENV{CUPTI_ROOT}/include + ${CUDA_TOOLKIT_ROOT_DIR}/extras/CUPTI/include + NO_DEFAULT_PATH + ) + +get_filename_component(__libpath_hist ${CUDA_CUDART_LIBRARY} PATH) + +set(TARGET_ARCH "x86_64") +if(NOT ${CMAKE_SYSTEM_PROCESSOR}) + set(TARGET_ARCH ${CMAKE_SYSTEM_PROCESSOR}) +endif() + +list(APPEND CUPTI_CHECK_LIBRARY_DIRS + ${CUPTI_ROOT} + ${CUPTI_ROOT}/lib64 + ${CUPTI_ROOT}/lib + ${CUPTI_ROOT}/lib/${TARGET_ARCH}-linux-gnu + $ENV{CUPTI_ROOT} + $ENV{CUPTI_ROOT}/lib64 + $ENV{CUPTI_ROOT}/lib + /usr/lib + ${CUDA_TOOLKIT_ROOT_DIR}/extras/CUPTI/lib64) +find_library(CUPTI_LIBRARY NAMES libcupti.so libcupti.dylib # libcupti_static.a + PATHS ${CUPTI_CHECK_LIBRARY_DIRS} ${CUPTI_INCLUDE_DIR} ${__libpath_hist} + NO_DEFAULT_PATH + DOC "Path to cuPTI library.") + +get_filename_component(CUPTI_LIBRARY_PATH ${CUPTI_LIBRARY} DIRECTORY) +if(CUPTI_INCLUDE_DIR AND CUPTI_LIBRARY) + set(CUPTI_FOUND ON) +else() + set(CUPTI_FOUND OFF) +endif() diff --git a/cmake/external/eigen.cmake b/cmake/external/eigen.cmake new file mode 100644 index 00000000000..72441160f89 --- /dev/null +++ b/cmake/external/eigen.cmake @@ -0,0 +1,54 @@ +INCLUDE(ExternalProject) + +SET(EIGEN_SOURCE_DIR ${THIRD_PARTY_PATH}/eigen3) +SET(EIGEN_INCLUDE_DIR ${EIGEN_SOURCE_DIR}/src/extern_eigen3) +INCLUDE_DIRECTORIES(${EIGEN_INCLUDE_DIR}) +if(NOT WITH_FAST_MATH) + # EIGEN_FAST_MATH: https://eigen.tuxfamily.org/dox/TopicPreprocessorDirectives.html + # enables some optimizations which might affect the accuracy of the result. + # This currently enables the SSE vectorization of sin() and cos(), + # and speedups sqrt() for single precision. + # Defined to 1 by default. Define it to 0 to disable. 
+ add_definitions(-DEIGEN_FAST_MATH=0) +endif() + +if(WITH_AMD_GPU) + ExternalProject_Add( + extern_eigen3 + ${EXTERNAL_PROJECT_LOG_ARGS} + GIT_REPOSITORY "https://github.com/sabreshao/hipeigen.git" + GIT_TAG 7cb2b6e5a4b4a1efe658abb215cd866c6fb2275e + PREFIX ${EIGEN_SOURCE_DIR} + UPDATE_COMMAND "" + CONFIGURE_COMMAND "" + BUILD_COMMAND "" + INSTALL_COMMAND "" + TEST_COMMAND "" + ) +else() + ExternalProject_Add( + extern_eigen3 + ${EXTERNAL_PROJECT_LOG_ARGS} + GIT_REPOSITORY "https://github.com/eigenteam/eigen-git-mirror" + # eigen on cuda9.1 missing header of math_funtions.hpp + # https://stackoverflow.com/questions/43113508/math-functions-hpp-not-found-when-using-cuda-with-eigen + GIT_TAG 917060c364181f33a735dc023818d5a54f60e54c + PREFIX ${EIGEN_SOURCE_DIR} + DOWNLOAD_NAME "eigen" + UPDATE_COMMAND "" + CONFIGURE_COMMAND "" + BUILD_COMMAND "" + INSTALL_COMMAND "" + TEST_COMMAND "" + ) +endif() + +if (${CMAKE_VERSION} VERSION_LESS "3.3.0") + set(dummyfile ${CMAKE_CURRENT_BINARY_DIR}/eigen3_dummy.c) + file(WRITE ${dummyfile} "const char *dummy_eigen3 = \"${dummyfile}\";") + add_library(eigen3 STATIC ${dummyfile}) +else() + add_library(eigen3 INTERFACE) +endif() + +add_dependencies(eigen3 extern_eigen3) diff --git a/cmake/external/gflags.cmake b/cmake/external/gflags.cmake new file mode 100644 index 00000000000..5dc5a60f509 --- /dev/null +++ b/cmake/external/gflags.cmake @@ -0,0 +1,74 @@ +# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +INCLUDE(ExternalProject) + +SET(GFLAGS_SOURCES_DIR ${THIRD_PARTY_PATH}/gflags) +SET(GFLAGS_INSTALL_DIR ${THIRD_PARTY_PATH}/install/gflags) +SET(GFLAGS_INCLUDE_DIR "${GFLAGS_INSTALL_DIR}/include" CACHE PATH "gflags include directory." 
FORCE) +IF(WIN32) + set(GFLAGS_LIBRARIES "${GFLAGS_INSTALL_DIR}/lib/libgflags.lib" CACHE FILEPATH "GFLAGS_LIBRARIES" FORCE) +ELSE(WIN32) + set(GFLAGS_LIBRARIES "${GFLAGS_INSTALL_DIR}/lib/libgflags.a" CACHE FILEPATH "GFLAGS_LIBRARIES" FORCE) +ENDIF(WIN32) + +INCLUDE_DIRECTORIES(${GFLAGS_INCLUDE_DIR}) + +SET(OPTIONAL_ARGS "-DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER}" + "-DCMAKE_C_COMPILER=${CMAKE_C_COMPILER}" + "-DCMAKE_CXX_FLAGS=${CMAKE_CXX_FLAGS}" + "-DCMAKE_CXX_FLAGS_RELEASE=${CMAKE_CXX_FLAGS_RELEASE}" + "-DCMAKE_CXX_FLAGS_DEBUG=${CMAKE_CXX_FLAGS_DEBUG}" + "-DCMAKE_C_FLAGS=${CMAKE_C_FLAGS}" + "-DCMAKE_C_FLAGS_DEBUG=${CMAKE_C_FLAGS_DEBUG}" + "-DCMAKE_C_FLAGS_RELEASE=${CMAKE_C_FLAGS_RELEASE}") + +ExternalProject_Add( + extern_gflags + ${EXTERNAL_PROJECT_LOG_ARGS} + GIT_REPOSITORY "https://github.com/gflags/gflags.git" + GIT_TAG 77592648e3f3be87d6c7123eb81cbad75f9aef5a + PREFIX ${GFLAGS_SOURCES_DIR} + UPDATE_COMMAND "" + CMAKE_ARGS -DBUILD_STATIC_LIBS=ON + -DCMAKE_INSTALL_PREFIX=${GFLAGS_INSTALL_DIR} + -DCMAKE_POSITION_INDEPENDENT_CODE=ON + -DBUILD_TESTING=OFF + -DCMAKE_BUILD_TYPE=${THIRD_PARTY_BUILD_TYPE} + ${CROSS_COMPILE_CMAKE_ARGS} + ${OPTIONAL_ARGS} + ${EXTERNAL_OPTIONAL_ARGS} + CMAKE_CACHE_ARGS -DCMAKE_INSTALL_PREFIX:PATH=${GFLAGS_INSTALL_DIR} + -DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=ON + -DCMAKE_BUILD_TYPE:STRING=${THIRD_PARTY_BUILD_TYPE} +) +IF(WIN32) + IF(NOT EXISTS "${GFLAGS_INSTALL_DIR}/lib/libgflags.lib") + add_custom_command(TARGET extern_gflags POST_BUILD + COMMAND cmake -E copy ${GFLAGS_INSTALL_DIR}/lib/gflags_static.lib ${GFLAGS_INSTALL_DIR}/lib/libgflags.lib + ) + ENDIF() +ENDIF(WIN32) +ADD_LIBRARY(gflags STATIC IMPORTED GLOBAL) +SET_PROPERTY(TARGET gflags PROPERTY IMPORTED_LOCATION ${GFLAGS_LIBRARIES}) +ADD_DEPENDENCIES(gflags extern_gflags) + +# On Windows (including MinGW), the Shlwapi library is used by gflags if available. +if (WIN32) + include(CheckIncludeFileCXX) + check_include_file_cxx("shlwapi.h" HAVE_SHLWAPI) + if (HAVE_SHLWAPI) + set_property(GLOBAL PROPERTY OS_DEPENDENCY_MODULES shlwapi.lib) + endif(HAVE_SHLWAPI) +endif (WIN32) diff --git a/cmake/external/glog.cmake b/cmake/external/glog.cmake new file mode 100644 index 00000000000..970020d784f --- /dev/null +++ b/cmake/external/glog.cmake @@ -0,0 +1,77 @@ +# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +INCLUDE(ExternalProject) + +SET(GLOG_SOURCES_DIR ${THIRD_PARTY_PATH}/glog) +SET(GLOG_INSTALL_DIR ${THIRD_PARTY_PATH}/install/glog) +SET(GLOG_INCLUDE_DIR "${GLOG_INSTALL_DIR}/include" CACHE PATH "glog include directory." FORCE) + +IF(WIN32) + SET(GLOG_LIBRARIES "${GLOG_INSTALL_DIR}/lib/libglog.lib" CACHE FILEPATH "glog library." FORCE) + SET(GLOG_CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /wd4267 /wd4530") +ELSE(WIN32) + SET(GLOG_LIBRARIES "${GLOG_INSTALL_DIR}/lib/libglog.a" CACHE FILEPATH "glog library." 
FORCE) + SET(GLOG_CMAKE_CXX_FLAGS ${CMAKE_CXX_FLAGS}) +ENDIF(WIN32) + +INCLUDE_DIRECTORIES(${GLOG_INCLUDE_DIR}) + +SET(GLOG_REPOSITORY "https://github.com/google/glog.git") +SET(GLOG_TAG "v0.3.5") + +SET(OPTIONAL_ARGS "-DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER}" + "-DCMAKE_C_COMPILER=${CMAKE_C_COMPILER}" + "-DCMAKE_CXX_FLAGS=${CMAKE_CXX_FLAGS}" + "-DCMAKE_CXX_FLAGS_RELEASE=${CMAKE_CXX_FLAGS_RELEASE}" + "-DCMAKE_CXX_FLAGS_DEBUG=${CMAKE_CXX_FLAGS_DEBUG}" + "-DCMAKE_C_FLAGS=${CMAKE_C_FLAGS}" + "-DCMAKE_C_FLAGS_DEBUG=${CMAKE_C_FLAGS_DEBUG}" + "-DCMAKE_C_FLAGS_RELEASE=${CMAKE_C_FLAGS_RELEASE}") + +ExternalProject_Add( + extern_glog + ${EXTERNAL_PROJECT_LOG_ARGS} + DEPENDS gflags + GIT_REPOSITORY ${GLOG_REPOSITORY} + GIT_TAG ${GLOG_TAG} + PREFIX ${GLOG_SOURCES_DIR} + UPDATE_COMMAND "" + CMAKE_ARGS ${CROSS_COMPILE_CMAKE_ARGS} + ${OPTIONAL_ARGS} + -DCMAKE_INSTALL_PREFIX=${GLOG_INSTALL_DIR} + -DCMAKE_INSTALL_LIBDIR=${GLOG_INSTALL_DIR}/lib + -DCMAKE_POSITION_INDEPENDENT_CODE=ON + -DWITH_GFLAGS=ON + -Dgflags_DIR=${GFLAGS_INSTALL_DIR}/lib/cmake/gflags + -DBUILD_TESTING=OFF + -DCMAKE_BUILD_TYPE=${THIRD_PARTY_BUILD_TYPE} + ${EXTERNAL_OPTIONAL_ARGS} + CMAKE_CACHE_ARGS -DCMAKE_INSTALL_PREFIX:PATH=${GLOG_INSTALL_DIR} + -DCMAKE_INSTALL_LIBDIR:PATH=${GLOG_INSTALL_DIR}/lib + -DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=ON + -DCMAKE_BUILD_TYPE:STRING=${THIRD_PARTY_BUILD_TYPE} +) +IF(WIN32) + IF(NOT EXISTS "${GLOG_INSTALL_DIR}/lib/libglog.lib") + add_custom_command(TARGET extern_glog POST_BUILD + COMMAND cmake -E copy ${GLOG_INSTALL_DIR}/lib/glog.lib ${GLOG_INSTALL_DIR}/lib/libglog.lib + ) + ENDIF() +ENDIF(WIN32) + +ADD_LIBRARY(glog STATIC IMPORTED GLOBAL) +SET_PROPERTY(TARGET glog PROPERTY IMPORTED_LOCATION ${GLOG_LIBRARIES}) +ADD_DEPENDENCIES(glog extern_glog gflags) +LINK_LIBRARIES(glog gflags) diff --git a/cmake/external/gtest.cmake b/cmake/external/gtest.cmake new file mode 100644 index 00000000000..b50f2729ce3 --- /dev/null +++ b/cmake/external/gtest.cmake @@ -0,0 +1,85 @@ +# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +#FIXME:(gongwb) Move brpc's gtest dependency. +IF(WITH_TESTING OR (WITH_DISTRIBUTE AND NOT WITH_GRPC)) + IF(WITH_TESTING) + ENABLE_TESTING() + ENDIF(WITH_TESTING) + + INCLUDE(ExternalProject) + + SET(GTEST_SOURCES_DIR ${THIRD_PARTY_PATH}/gtest) + SET(GTEST_INSTALL_DIR ${THIRD_PARTY_PATH}/install/gtest) + SET(GTEST_INCLUDE_DIR "${GTEST_INSTALL_DIR}/include" CACHE PATH "gtest include directory." FORCE) + + INCLUDE_DIRECTORIES(${GTEST_INCLUDE_DIR}) + + IF(WIN32) + set(GTEST_LIBRARIES + "${GTEST_INSTALL_DIR}/lib/gtest.lib" CACHE FILEPATH "gtest libraries." FORCE) + set(GTEST_MAIN_LIBRARIES + "${GTEST_INSTALL_DIR}/lib/gtest_main.lib" CACHE FILEPATH "gtest main libraries." FORCE) + ELSE(WIN32) + set(GTEST_LIBRARIES + "${GTEST_INSTALL_DIR}/lib/libgtest.a" CACHE FILEPATH "gtest libraries." FORCE) + set(GTEST_MAIN_LIBRARIES + "${GTEST_INSTALL_DIR}/lib/libgtest_main.a" CACHE FILEPATH "gtest main libraries." 
FORCE) + ENDIF(WIN32) + + IF(WITH_MKLML) + # wait for mklml downloading completed + SET(GTEST_DEPENDS ${MKLML_PROJECT}) + ENDIF() + + SET(OPTIONAL_ARGS "-DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER}" + "-DCMAKE_C_COMPILER=${CMAKE_C_COMPILER}" + "-DCMAKE_CXX_FLAGS=${CMAKE_CXX_FLAGS}" + "-DCMAKE_CXX_FLAGS_RELEASE=${CMAKE_CXX_FLAGS_RELEASE}" + "-DCMAKE_CXX_FLAGS_DEBUG=${CMAKE_CXX_FLAGS_DEBUG}" + "-DCMAKE_C_FLAGS=${CMAKE_C_FLAGS}" + "-DCMAKE_C_FLAGS_DEBUG=${CMAKE_C_FLAGS_DEBUG}" + "-DCMAKE_C_FLAGS_RELEASE=${CMAKE_C_FLAGS_RELEASE}") + + ExternalProject_Add( + extern_gtest + ${EXTERNAL_PROJECT_LOG_ARGS} + DEPENDS ${GTEST_DEPENDS} + GIT_REPOSITORY "https://github.com/google/googletest.git" + GIT_TAG "release-1.8.0" + PREFIX ${GTEST_SOURCES_DIR} + UPDATE_COMMAND "" + CMAKE_ARGS ${CROSS_COMPILE_CMAKE_ARGS} + ${OPTIONAL_ARGS} + -DCMAKE_INSTALL_PREFIX=${GTEST_INSTALL_DIR} + -DCMAKE_POSITION_INDEPENDENT_CODE=ON + -DBUILD_GMOCK=ON + -Dgtest_disable_pthreads=ON + -Dgtest_force_shared_crt=ON + -DCMAKE_BUILD_TYPE=${THIRD_PARTY_BUILD_TYPE} + ${EXTERNAL_OPTIONAL_ARGS} + CMAKE_CACHE_ARGS -DCMAKE_INSTALL_PREFIX:PATH=${GTEST_INSTALL_DIR} + -DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=ON + -DCMAKE_BUILD_TYPE:STRING=${THIRD_PARTY_BUILD_TYPE} + ) + + ADD_LIBRARY(gtest STATIC IMPORTED GLOBAL) + SET_PROPERTY(TARGET gtest PROPERTY IMPORTED_LOCATION ${GTEST_LIBRARIES}) + ADD_DEPENDENCIES(gtest extern_gtest) + + ADD_LIBRARY(gtest_main STATIC IMPORTED GLOBAL) + SET_PROPERTY(TARGET gtest_main PROPERTY IMPORTED_LOCATION ${GTEST_MAIN_LIBRARIES}) + ADD_DEPENDENCIES(gtest_main extern_gtest) + +ENDIF(WITH_TESTING OR (WITH_DISTRIBUTE AND NOT WITH_GRPC)) diff --git a/cmake/external/libxsmm.cmake b/cmake/external/libxsmm.cmake new file mode 100644 index 00000000000..69cdba7c592 --- /dev/null +++ b/cmake/external/libxsmm.cmake @@ -0,0 +1,55 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +OPTION(WITH_LIBXSMM "Compile with libxsmm" OFF) + +IF(NOT WITH_LIBXSMM) + return() +ENDIF() + +IF(WIN32 OR APPLE) + MESSAGE(WARNING "Windows, Mac are not supported with libxsmm in Paddle yet.") + SET(WITH_LIBXSMM OFF CACHE STRING "Disable LIBXSMM" FORCE) + return() +ENDIF() + +INCLUDE (ExternalProject) + +SET(LIBXSMM_SOURCES_DIR ${THIRD_PARTY_PATH}/libxsmm) +SET(LIBXSMM_INSTALL_DIR ${THIRD_PARTY_PATH}/install/libxsmm) +SET(LIBXSMM_INCLUDE_DIR "${LIBXSMM_INSTALL_DIR}/include" CACHE PATH "LIBXSMM include directory." FORCE) +SET(LIBXSMM_LIBRARY_DIR "${LIBXSMM_INSTALL_DIR}/lib" CACHE PATH "LIBXSMM library directory." 
FORCE) +SET(LIBXSMM_LIBS "${LIBXSMM_LIBRARY_DIR}/libxsmm.a" + "${LIBXSMM_LIBRARY_DIR}/libxsmmnoblas.a") + +ExternalProject_Add( + extern_libxsmm + GIT_REPOSITORY "https://github.com/hfp/libxsmm.git" + GIT_TAG "7cc03b5b342fdbc6b6d990b190671c5dbb8489a2" + PREFIX ${LIBXSMM_SOURCES_DIR} + UPDATE_COMMAND "" + CONFIGURE_COMMAND "" + BUILD_IN_SOURCE 1 + BUILD_COMMAND $(MAKE) --silent PREFIX=${LIBXSMM_INSTALL_DIR} CXX=g++ CC=gcc WARP=0 install + INSTALL_COMMAND "" +) +ADD_LIBRARY(libxsmm STATIC IMPORTED GLOBAL) +SET_PROPERTY(TARGET libxsmm PROPERTY IMPORTED_LOCATION "${LIBXSMM_LIBRARY_DIR}/libxsmm.a") +SET_PROPERTY(TARGET libxsmm PROPERTY IMPORTED_LOCATION "${LIBXSMM_LIBRARY_DIR}/libxsmmnoblas.a") + +MESSAGE(STATUS "Libxsmm library: ${LIBXSMM_LIBS}") +include_directories(${LIBXSMM_INCLUDE_DIR}) +ADD_DEFINITIONS(-DPADDLE_WITH_LIBXSMM) +ADD_DEPENDENCIES(libxsmm extern_libxsmm) diff --git a/cmake/external/mkldnn.cmake b/cmake/external/mkldnn.cmake new file mode 100644 index 00000000000..b1e437a9007 --- /dev/null +++ b/cmake/external/mkldnn.cmake @@ -0,0 +1,120 @@ +# Copyright (c) 2017 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +IF(NOT ${WITH_MKLDNN}) + return() +ENDIF(NOT ${WITH_MKLDNN}) + +INCLUDE(ExternalProject) + +SET(MKLDNN_PROJECT "extern_mkldnn") +SET(MKLDNN_SOURCES_DIR ${THIRD_PARTY_PATH}/mkldnn) +SET(MKLDNN_INSTALL_DIR ${THIRD_PARTY_PATH}/install/mkldnn) +SET(MKLDNN_INC_DIR "${MKLDNN_INSTALL_DIR}/include" CACHE PATH "mkldnn include directory." FORCE) + +IF(APPLE) + MESSAGE(WARNING + "Mac is not supported with MKLDNN in Paddle yet." + "Force WITH_MKLDNN=OFF") + SET(WITH_MKLDNN OFF CACHE STRING "Disable MKLDNN in MacOS" FORCE) + return() +ENDIF() + +# Introduce variables: +# * CMAKE_INSTALL_LIBDIR +INCLUDE(GNUInstallDirs) +SET(LIBDIR "lib") +if(CMAKE_INSTALL_LIBDIR MATCHES ".*lib64$") + SET(LIBDIR "lib64") +endif() + +MESSAGE(STATUS "Set ${MKLDNN_INSTALL_DIR}/l${LIBDIR} to runtime path") +SET(CMAKE_INSTALL_RPATH_USE_LINK_PATH TRUE) +SET(CMAKE_INSTALL_RPATH "${CMAKE_INSTALL_RPATH}" "${MKLDNN_INSTALL_DIR}/${LIBDIR}") + +INCLUDE_DIRECTORIES(${MKLDNN_INC_DIR}) # For MKLDNN code to include internal headers. 
+ +IF(${CBLAS_PROVIDER} STREQUAL "MKLML") + SET(MKLDNN_DEPENDS ${MKLML_PROJECT}) + MESSAGE(STATUS "Build MKLDNN with MKLML ${MKLML_ROOT}") +ELSE() + MESSAGE(FATAL_ERROR "Should enable MKLML when build MKLDNN") +ENDIF() + +IF(NOT WIN32) + SET(MKLDNN_FLAG "-Wno-error=strict-overflow -Wno-error=unused-result -Wno-error=array-bounds") + SET(MKLDNN_FLAG "${MKLDNN_FLAG} -Wno-unused-result -Wno-unused-value") + SET(MKLDNN_CFLAG "${CMAKE_C_FLAGS} ${MKLDNN_FLAG}") + SET(MKLDNN_CXXFLAG "${CMAKE_CXX_FLAGS} ${MKLDNN_FLAG}") +ELSE() + SET(MKLDNN_CXXFLAG "${CMAKE_CXX_FLAGS} /EHsc") +ENDIF(NOT WIN32) + +ExternalProject_Add( + ${MKLDNN_PROJECT} + ${EXTERNAL_PROJECT_LOG_ARGS} + DEPENDS ${MKLDNN_DEPENDS} + GIT_REPOSITORY "https://github.com/intel/mkl-dnn.git" + GIT_TAG "863ff6e7042cec7d2e29897fe9f0872e0888b0fc" + PREFIX ${MKLDNN_SOURCES_DIR} + UPDATE_COMMAND "" + CMAKE_ARGS -DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER} + CMAKE_ARGS -DCMAKE_C_COMPILER=${CMAKE_C_COMPILER} + CMAKE_ARGS -DCMAKE_CXX_FLAGS_RELEASE=${CMAKE_CXX_FLAGS_RELEASE} + CMAKE_ARGS -DCMAKE_CXX_FLAGS_DEBUG=${CMAKE_CXX_FLAGS_DEBUG} + CMAKE_ARGS -DCMAKE_C_FLAGS=${CMAKE_C_FLAGS} + CMAKE_ARGS -DCMAKE_C_FLAGS_DEBUG=${CMAKE_C_FLAGS_DEBUG} + CMAKE_ARGS -DCMAKE_C_FLAGS_RELEASE=${CMAKE_C_FLAGS_RELEASE} + CMAKE_ARGS -DCMAKE_INSTALL_PREFIX=${MKLDNN_INSTALL_DIR} + CMAKE_ARGS -DCMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE} + CMAKE_ARGS -DCMAKE_POSITION_INDEPENDENT_CODE=ON + CMAKE_ARGS -DMKLROOT=${MKLML_ROOT} + CMAKE_ARGS -DCMAKE_C_FLAGS=${MKLDNN_CFLAG} + CMAKE_ARGS -DCMAKE_CXX_FLAGS=${MKLDNN_CXXFLAG} + CMAKE_ARGS -DWITH_TEST=OFF -DWITH_EXAMPLE=OFF + CMAKE_CACHE_ARGS -DCMAKE_INSTALL_PREFIX:PATH=${MKLDNN_INSTALL_DIR} + -DMKLROOT:PATH=${MKLML_ROOT} +) +if(WIN32) + SET(MKLDNN_LIB "${MKLDNN_INSTALL_DIR}/${LIBDIR}/mkldnn.lib" CACHE FILEPATH "mkldnn library." FORCE) +else(WIN32) + SET(MKLDNN_LIB "${MKLDNN_INSTALL_DIR}/${LIBDIR}/libmkldnn.so" CACHE FILEPATH "mkldnn library." FORCE) +endif(WIN32) + +ADD_LIBRARY(shared_mkldnn SHARED IMPORTED GLOBAL) +SET_PROPERTY(TARGET shared_mkldnn PROPERTY IMPORTED_LOCATION ${MKLDNN_LIB}) +ADD_DEPENDENCIES(shared_mkldnn ${MKLDNN_PROJECT}) +MESSAGE(STATUS "MKLDNN library: ${MKLDNN_LIB}") +add_definitions(-DPADDLE_WITH_MKLDNN) + +# generate a static dummy target to track mkldnn dependencies +# for cc_library(xxx SRCS xxx.c DEPS mkldnn) +SET(dummyfile ${CMAKE_CURRENT_BINARY_DIR}/mkldnn_dummy.c) +FILE(WRITE ${dummyfile} "const char * dummy = \"${dummyfile}\";") +ADD_LIBRARY(mkldnn STATIC ${dummyfile}) +TARGET_LINK_LIBRARIES(mkldnn ${MKLDNN_LIB} ${MKLML_LIB} ${MKLML_IOMP_LIB}) +ADD_DEPENDENCIES(mkldnn ${MKLDNN_PROJECT}) + +# copy the real so.0 lib to install dir +# it can be directly contained in wheel or capi +if(WIN32) + SET(MKLDNN_SHARED_LIB ${MKLDNN_INSTALL_DIR}/bin/mkldnn.dll) +else(WIN32) + SET(MKLDNN_SHARED_LIB ${MKLDNN_INSTALL_DIR}/libmkldnn.so.0) + ADD_CUSTOM_COMMAND(OUTPUT ${MKLDNN_SHARED_LIB} + COMMAND ${CMAKE_COMMAND} -E copy ${MKLDNN_LIB} ${MKLDNN_SHARED_LIB} + DEPENDS mkldnn shared_mkldnn) +endif(WIN32) +ADD_CUSTOM_TARGET(mkldnn_shared_lib ALL DEPENDS ${MKLDNN_SHARED_LIB}) +ADD_DEPENDENCIES(mkldnn_shared_lib ${MKLDNN_PROJECT} mkldnn) diff --git a/cmake/external/mklml.cmake b/cmake/external/mklml.cmake new file mode 100644 index 00000000000..142fce816de --- /dev/null +++ b/cmake/external/mklml.cmake @@ -0,0 +1,77 @@ +# Copyright (c) 2017 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +IF(NOT ${WITH_MKLML}) + return() +ENDIF(NOT ${WITH_MKLML}) + +IF(APPLE) + MESSAGE(WARNING "Mac is not supported with MKLML in Paddle yet. Force WITH_MKLML=OFF.") + SET(WITH_MKLML OFF CACHE STRING "Disable MKLML package in MacOS" FORCE) + return() +ENDIF() + +INCLUDE(ExternalProject) +SET(MKLML_DST_DIR "mklml") +SET(MKLML_INSTALL_ROOT "${THIRD_PARTY_PATH}/install") +SET(MKLML_INSTALL_DIR ${MKLML_INSTALL_ROOT}/${MKLML_DST_DIR}) +SET(MKLML_ROOT ${MKLML_INSTALL_DIR}) +SET(MKLML_INC_DIR ${MKLML_ROOT}/include) +SET(MKLML_LIB_DIR ${MKLML_ROOT}/lib) +SET(CMAKE_INSTALL_RPATH "${CMAKE_INSTALL_RPATH}" "${MKLML_ROOT}/lib") + +SET(TIME_VERSION "2019.0.1.20181227") +IF(WIN32) + SET(MKLML_VER "mklml_win_${TIME_VERSION}" CACHE STRING "" FORCE) + SET(MKLML_URL "https://paddlepaddledeps.bj.bcebos.com/${MKLML_VER}.zip" CACHE STRING "" FORCE) + SET(MKLML_LIB ${MKLML_LIB_DIR}/mklml.lib) + SET(MKLML_IOMP_LIB ${MKLML_LIB_DIR}/libiomp5md.lib) + SET(MKLML_SHARED_LIB ${MKLML_LIB_DIR}/mklml.dll) + SET(MKLML_SHARED_IOMP_LIB ${MKLML_LIB_DIR}/libiomp5md.dll) +ELSE() + #TODO(intel-huying): + # Now enable Erf function in mklml library temporarily, it will be updated as offical version later. + SET(MKLML_VER "Glibc225_vsErf_mklml_lnx_${TIME_VERSION}" CACHE STRING "" FORCE) + SET(MKLML_URL "http://paddlepaddledeps.bj.bcebos.com/${MKLML_VER}.tgz" CACHE STRING "" FORCE) + SET(MKLML_LIB ${MKLML_LIB_DIR}/libmklml_intel.so) + SET(MKLML_IOMP_LIB ${MKLML_LIB_DIR}/libiomp5.so) + SET(MKLML_SHARED_LIB ${MKLML_LIB_DIR}/libmklml_intel.so) + SET(MKLML_SHARED_IOMP_LIB ${MKLML_LIB_DIR}/libiomp5.so) +ENDIF() + +SET(MKLML_PROJECT "extern_mklml") +MESSAGE(STATUS "MKLML_VER: ${MKLML_VER}, MKLML_URL: ${MKLML_URL}") +SET(MKLML_SOURCE_DIR "${THIRD_PARTY_PATH}/mklml") +SET(MKLML_DOWNLOAD_DIR "${MKLML_SOURCE_DIR}/src/${MKLML_PROJECT}") + +ExternalProject_Add( + ${MKLML_PROJECT} + ${EXTERNAL_PROJECT_LOG_ARGS} + PREFIX ${MKLML_SOURCE_DIR} + URL ${MKLML_URL} + DOWNLOAD_DIR ${MKLML_DOWNLOAD_DIR} + DOWNLOAD_NO_PROGRESS 1 + CONFIGURE_COMMAND "" + BUILD_COMMAND "" + UPDATE_COMMAND "" + INSTALL_COMMAND + ${CMAKE_COMMAND} -E copy_directory ${MKLML_DOWNLOAD_DIR}/include ${MKLML_INC_DIR} && + ${CMAKE_COMMAND} -E copy_directory ${MKLML_DOWNLOAD_DIR}/lib ${MKLML_LIB_DIR} +) + +INCLUDE_DIRECTORIES(${MKLML_INC_DIR}) + +ADD_LIBRARY(mklml SHARED IMPORTED GLOBAL) +SET_PROPERTY(TARGET mklml PROPERTY IMPORTED_LOCATION ${MKLML_LIB}) +ADD_DEPENDENCIES(mklml ${MKLML_PROJECT}) diff --git a/cmake/external/openblas.cmake b/cmake/external/openblas.cmake new file mode 100644 index 00000000000..d8a4a0be6f5 --- /dev/null +++ b/cmake/external/openblas.cmake @@ -0,0 +1,93 @@ +# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +INCLUDE(cblas) + +IF(NOT ${CBLAS_FOUND}) + INCLUDE(ExternalProject) + + SET(CBLAS_SOURCES_DIR ${THIRD_PARTY_PATH}/openblas) + SET(CBLAS_INSTALL_DIR ${THIRD_PARTY_PATH}/install/openblas) + SET(CBLAS_INC_DIR "${CBLAS_INSTALL_DIR}/include" CACHE PATH "openblas include directory." FORCE) + + SET(CBLAS_LIBRARIES + "${CBLAS_INSTALL_DIR}/lib/${CMAKE_STATIC_LIBRARY_PREFIX}openblas${CMAKE_STATIC_LIBRARY_SUFFIX}" + CACHE FILEPATH "openblas library." FORCE) + + ADD_DEFINITIONS(-DPADDLE_USE_OPENBLAS) + + IF (WIN32) + SET(CBLAS_FOUND true) + MESSAGE(WARNING, "In windows, openblas only support msvc build, please build it manually and put it at " ${CBLAS_INSTALL_DIR}) + ENDIF(WIN32) + + IF (NOT WIN32) + SET(OPENBLAS_CC "${CMAKE_C_COMPILER} -Wno-unused-but-set-variable -Wno-unused-variable") + SET(OPENBLAS_COMMIT "v0.2.20") + + IF(APPLE) + SET(OPENBLAS_CC "${CMAKE_C_COMPILER} -isysroot ${CMAKE_OSX_SYSROOT}") + ENDIF() + SET(OPTIONAL_ARGS "") + IF(CMAKE_SYSTEM_PROCESSOR MATCHES "^x86(_64)?$") + SET(OPTIONAL_ARGS DYNAMIC_ARCH=1 NUM_THREADS=64) + ENDIF() + + SET(COMMON_ARGS CC=${OPENBLAS_CC} NO_SHARED=1 NO_LAPACK=1 libs) + ExternalProject_Add( + extern_openblas + ${EXTERNAL_PROJECT_LOG_ARGS} + GIT_REPOSITORY https://github.com/xianyi/OpenBLAS.git + GIT_TAG ${OPENBLAS_COMMIT} + PREFIX ${CBLAS_SOURCES_DIR} + INSTALL_DIR ${CBLAS_INSTALL_DIR} + BUILD_IN_SOURCE 1 + BUILD_COMMAND ${CMAKE_MAKE_PROGRAM} ${COMMON_ARGS} ${OPTIONAL_ARGS} + INSTALL_COMMAND ${CMAKE_MAKE_PROGRAM} install NO_SHARED=1 NO_LAPACK=1 PREFIX= + && rm -r ${CBLAS_INSTALL_DIR}/lib/cmake ${CBLAS_INSTALL_DIR}/lib/pkgconfig + UPDATE_COMMAND "" + CONFIGURE_COMMAND "" + ) + ELSE() + ENDIF(NOT WIN32) + SET(CBLAS_PROVIDER openblas) +ENDIF(NOT ${CBLAS_FOUND}) + +MESSAGE(STATUS "BLAS library: ${CBLAS_LIBRARIES}") +MESSAGE(STATUS "BLAS Include: ${CBLAS_INC_DIR}") +INCLUDE_DIRECTORIES(${CBLAS_INC_DIR}) + +# FIXME(gangliao): generate cblas target to track all high performance +# linear algebra libraries for cc_library(xxx SRCS xxx.c DEPS cblas) +SET(dummyfile ${CMAKE_CURRENT_BINARY_DIR}/cblas_dummy.c) +FILE(WRITE ${dummyfile} "const char *dummy_cblas = \"${dummyfile}\";") +ADD_LIBRARY(cblas STATIC ${dummyfile}) + +IF("${CBLAS_PROVIDER}" STREQUAL "MKLML") + TARGET_LINK_LIBRARIES(cblas dynload_mklml) +ELSE() + TARGET_LINK_LIBRARIES(cblas ${CBLAS_LIBRARIES}) +ENDIF("${CBLAS_PROVIDER}" STREQUAL "MKLML") + +IF(WITH_LIBXSMM) + TARGET_LINK_LIBRARIES(cblas ${LIBXSMM_LIBS}) + ADD_DEPENDENCIES(cblas extern_libxsmm) +ENDIF() + +IF(NOT ${CBLAS_FOUND}) + ADD_DEPENDENCIES(cblas extern_openblas) +ELSE() + IF("${CBLAS_PROVIDER}" STREQUAL "MKLML") + ADD_DEPENDENCIES(cblas mklml) + ENDIF() +ENDIF(NOT ${CBLAS_FOUND}) diff --git a/cmake/external/opencl-clhpp.cmake b/cmake/external/opencl-clhpp.cmake new file mode 100644 index 00000000000..ea724860d9b --- /dev/null +++ b/cmake/external/opencl-clhpp.cmake @@ -0,0 +1,36 @@ +# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +INCLUDE(ExternalProject) + +SET(OPENCL_CLHPP_SRCS_DIR ${THIRD_PARTY_PATH}/opencl-clhpp) +SET(OPENCL_CLHPP_INSTALL_DIR ${THIRD_PARTY_PATH}/install/opencl-clhpp) +SET(OPENCL_CLHPP_INCLUDE_DIR "${OPENCL_CLHPP_INSTALL_DIR}" CACHE PATH "opencl-clhpp include directory." FORCE) + +INCLUDE_DIRECTORIES(${OPENCL_CLHPP_INCLUDE_DIR}) + +ExternalProject_Add( + opencl_clhpp + GIT_REPOSITORY "https://github.com/KhronosGroup/OpenCL-CLHPP.git" + GIT_TAG "v2.0.10" + PREFIX "${OPENCL_CLHPP_SRCS_DIR}" + CMAKE_ARGS -DBUILD_DOCS=OFF + -DBUILD_EXAMPLES=OFF + -DBUILD_TESTS=OFF + -DCMAKE_INSTALL_PREFIX=${OPENCL_CLHPP_INSTALL_DIR} + CMAKE_CACHE_ARGS -DCMAKE_INSTALL_PREFIX:PATH=${OPENCL_CLHPP_INSTALL_DIR} + -DCMAKE_BUILD_TYPE:STRING=${THIRD_PARTY_BUILD_TYPE} +) + +ADD_DEPENDENCIES(opencl_clhpp opencl_headers) diff --git a/cmake/external/opencl-headers.cmake b/cmake/external/opencl-headers.cmake new file mode 100644 index 00000000000..68c9c5251cf --- /dev/null +++ b/cmake/external/opencl-headers.cmake @@ -0,0 +1,33 @@ +# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +INCLUDE(ExternalProject) + +SET(OPENCL_HEADERS_SRCS_DIR ${THIRD_PARTY_PATH}/opencl-headers) +SET(OPENCL_HEADERS_INCLUDE_DIR "${OPENCL_HEADERS_SRCS_DIR}/src/opencl_headers" CACHE PATH "opencl-headers include directory." FORCE) + +INCLUDE_DIRECTORIES(${OPENCL_HEADERS_INCLUDE_DIR}) + +ExternalProject_Add( + opencl_headers + ${EXTERNAL_PROJECT_LOG_ARGS} + GIT_REPOSITORY "https://github.com/KhronosGroup/OpenCL-Headers.git" + GIT_TAG "c5a4bbeabb10d8ed3d1c651b93aa31737bc473dd" + PREFIX ${OPENCL_HEADERS_SRCS_DIR} + DOWNLOAD_NAME "OpenCL-Headers" + CONFIGURE_COMMAND "" + BUILD_COMMAND "" + INSTALL_COMMAND "" + TEST_COMMAND "" +) diff --git a/cmake/external/protobuf.cmake b/cmake/external/protobuf.cmake new file mode 100644 index 00000000000..a3029877945 --- /dev/null +++ b/cmake/external/protobuf.cmake @@ -0,0 +1,293 @@ +# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +INCLUDE(ExternalProject) +# Always invoke `FIND_PACKAGE(Protobuf)` for importing function protobuf_generate_cpp +IF(NOT WIN32) +FIND_PACKAGE(Protobuf QUIET) +ENDIF(NOT WIN32) +macro(UNSET_VAR VAR_NAME) + UNSET(${VAR_NAME} CACHE) + UNSET(${VAR_NAME}) +endmacro() + +UNSET_VAR(PROTOBUF_INCLUDE_DIR) +UNSET_VAR(PROTOBUF_FOUND) +UNSET_VAR(PROTOBUF_PROTOC_EXECUTABLE) +UNSET_VAR(PROTOBUF_PROTOC_LIBRARY) +UNSET_VAR(PROTOBUF_LITE_LIBRARY) +UNSET_VAR(PROTOBUF_LIBRARY) +UNSET_VAR(PROTOBUF_INCLUDE_DIR) +UNSET_VAR(Protobuf_PROTOC_EXECUTABLE) +function(protobuf_generate_python SRCS) + # shameless copy from https://github.com/Kitware/CMake/blob/master/Modules/FindProtobuf.cmake + if(NOT ARGN) + message(SEND_ERROR "Error: PROTOBUF_GENERATE_PYTHON() called without any proto files") + return() + endif() + + if(PROTOBUF_GENERATE_CPP_APPEND_PATH) + # Create an include path for each file specified + foreach(FIL ${ARGN}) + get_filename_component(ABS_FIL ${FIL} ABSOLUTE) + get_filename_component(ABS_PATH ${ABS_FIL} PATH) + list(FIND _protobuf_include_path ${ABS_PATH} _contains_already) + if(${_contains_already} EQUAL -1) + list(APPEND _protobuf_include_path -I ${ABS_PATH}) + endif() + endforeach() + else() + set(_protobuf_include_path -I ${CMAKE_CURRENT_SOURCE_DIR}) + endif() + if(DEFINED PROTOBUF_IMPORT_DIRS AND NOT DEFINED Protobuf_IMPORT_DIRS) + set(Protobuf_IMPORT_DIRS "${PROTOBUF_IMPORT_DIRS}") + endif() + + if(DEFINED Protobuf_IMPORT_DIRS) + foreach(DIR ${Protobuf_IMPORT_DIRS}) + get_filename_component(ABS_PATH ${DIR} ABSOLUTE) + list(FIND _protobuf_include_path ${ABS_PATH} _contains_already) + if(${_contains_already} EQUAL -1) + list(APPEND _protobuf_include_path -I ${ABS_PATH}) + endif() + endforeach() + endif() + + set(${SRCS}) + foreach(FIL ${ARGN}) + get_filename_component(ABS_FIL ${FIL} ABSOLUTE) + get_filename_component(FIL_WE ${FIL} NAME_WE) + if(NOT PROTOBUF_GENERATE_CPP_APPEND_PATH) + get_filename_component(FIL_DIR ${FIL} DIRECTORY) + if(FIL_DIR) + set(FIL_WE "${FIL_DIR}/${FIL_WE}") + endif() + endif() + list(APPEND ${SRCS} "${CMAKE_CURRENT_BINARY_DIR}/${FIL_WE}_pb2.py") + add_custom_command( + OUTPUT "${CMAKE_CURRENT_BINARY_DIR}/${FIL_WE}_pb2.py" + COMMAND ${PROTOBUF_PROTOC_EXECUTABLE} --python_out ${CMAKE_CURRENT_BINARY_DIR} ${_protobuf_include_path} ${ABS_FIL} + DEPENDS ${ABS_FIL} ${PROTOBUF_PROTOC_EXECUTABLE} + COMMENT "Running Python protocol buffer compiler on ${FIL}" + VERBATIM ) + endforeach() + + set(${SRCS} ${${SRCS}} PARENT_SCOPE) +endfunction() + +# Print and set the protobuf library information, +# finish this cmake process and exit from this file. +macro(PROMPT_PROTOBUF_LIB) + SET(protobuf_DEPS ${ARGN}) + + MESSAGE(STATUS "Protobuf protoc executable: ${PROTOBUF_PROTOC_EXECUTABLE}") + MESSAGE(STATUS "Protobuf-lite library: ${PROTOBUF_LITE_LIBRARY}") + MESSAGE(STATUS "Protobuf library: ${PROTOBUF_LIBRARY}") + MESSAGE(STATUS "Protoc library: ${PROTOBUF_PROTOC_LIBRARY}") + MESSAGE(STATUS "Protobuf version: ${PROTOBUF_VERSION}") + INCLUDE_DIRECTORIES(${PROTOBUF_INCLUDE_DIR}) + + # Assuming that all the protobuf libraries are of the same type. 
+ IF(${PROTOBUF_LIBRARY} MATCHES ${CMAKE_STATIC_LIBRARY_SUFFIX}) + SET(protobuf_LIBTYPE STATIC) + ELSEIF(${PROTOBUF_LIBRARY} MATCHES "${CMAKE_SHARED_LIBRARY_SUFFIX}$") + SET(protobuf_LIBTYPE SHARED) + ELSE() + MESSAGE(FATAL_ERROR "Unknown library type: ${PROTOBUF_LIBRARY}") + ENDIF() + + ADD_LIBRARY(protobuf ${protobuf_LIBTYPE} IMPORTED GLOBAL) + SET_PROPERTY(TARGET protobuf PROPERTY IMPORTED_LOCATION ${PROTOBUF_LIBRARY}) + + ADD_LIBRARY(protobuf_lite ${protobuf_LIBTYPE} IMPORTED GLOBAL) + SET_PROPERTY(TARGET protobuf_lite PROPERTY IMPORTED_LOCATION ${PROTOBUF_LITE_LIBRARY}) + + ADD_LIBRARY(libprotoc ${protobuf_LIBTYPE} IMPORTED GLOBAL) + SET_PROPERTY(TARGET libprotoc PROPERTY IMPORTED_LOCATION ${PROTOC_LIBRARY}) + + ADD_EXECUTABLE(protoc IMPORTED GLOBAL) + SET_PROPERTY(TARGET protoc PROPERTY IMPORTED_LOCATION ${PROTOBUF_PROTOC_EXECUTABLE}) + # FIND_Protobuf.cmake uses `Protobuf_PROTOC_EXECUTABLE`. + # make `protobuf_generate_cpp` happy. + SET(Protobuf_PROTOC_EXECUTABLE ${PROTOBUF_PROTOC_EXECUTABLE}) + + FOREACH(dep ${protobuf_DEPS}) + ADD_DEPENDENCIES(protobuf ${dep}) + ADD_DEPENDENCIES(protobuf_lite ${dep}) + ADD_DEPENDENCIES(libprotoc ${dep}) + ADD_DEPENDENCIES(protoc ${dep}) + ENDFOREACH() + + RETURN() +endmacro() +macro(SET_PROTOBUF_VERSION) + EXEC_PROGRAM(${PROTOBUF_PROTOC_EXECUTABLE} ARGS --version OUTPUT_VARIABLE PROTOBUF_VERSION) + STRING(REGEX MATCH "[0-9]+.[0-9]+" PROTOBUF_VERSION "${PROTOBUF_VERSION}") +endmacro() + +set(PROTOBUF_ROOT "" CACHE PATH "Folder contains protobuf") +IF (WIN32) + SET(PROTOBUF_ROOT ${THIRD_PARTY_PATH}/install/protobuf) +ENDIF(WIN32) + +if (NOT "${PROTOBUF_ROOT}" STREQUAL "") + find_path(PROTOBUF_INCLUDE_DIR google/protobuf/message.h PATHS ${PROTOBUF_ROOT}/include NO_DEFAULT_PATH) + find_library(PROTOBUF_LIBRARY protobuf libprotobuf.lib PATHS ${PROTOBUF_ROOT}/lib NO_DEFAULT_PATH) + find_library(PROTOBUF_LITE_LIBRARY protobuf-lite libprotobuf-lite.lib PATHS ${PROTOBUF_ROOT}/lib NO_DEFAULT_PATH) + find_library(PROTOBUF_PROTOC_LIBRARY protoc libprotoc.lib PATHS ${PROTOBUF_ROOT}/lib NO_DEFAULT_PATH) + find_program(PROTOBUF_PROTOC_EXECUTABLE protoc PATHS ${PROTOBUF_ROOT}/bin NO_DEFAULT_PATH) + if (PROTOBUF_INCLUDE_DIR AND PROTOBUF_LIBRARY AND PROTOBUF_LITE_LIBRARY AND PROTOBUF_PROTOC_LIBRARY AND PROTOBUF_PROTOC_EXECUTABLE) + message(STATUS "Using custom protobuf library in ${PROTOBUF_ROOT}.") + SET(PROTOBUF_FOUND true) + SET_PROTOBUF_VERSION() + PROMPT_PROTOBUF_LIB() + else() + message(WARNING "Cannot find protobuf library in ${PROTOBUF_ROOT}") + endif() +endif() + +FUNCTION(build_protobuf TARGET_NAME BUILD_FOR_HOST) + STRING(REPLACE "extern_" "" TARGET_DIR_NAME "${TARGET_NAME}") + SET(PROTOBUF_SOURCES_DIR ${THIRD_PARTY_PATH}/${TARGET_DIR_NAME}) + SET(PROTOBUF_INSTALL_DIR ${THIRD_PARTY_PATH}/install/${TARGET_DIR_NAME}) + + SET(${TARGET_NAME}_INCLUDE_DIR "${PROTOBUF_INSTALL_DIR}/include" PARENT_SCOPE) + SET(PROTOBUF_INCLUDE_DIR "${PROTOBUF_INSTALL_DIR}/include" PARENT_SCOPE) + SET(${TARGET_NAME}_LITE_LIBRARY + "${PROTOBUF_INSTALL_DIR}/lib/libprotobuf-lite${CMAKE_STATIC_LIBRARY_SUFFIX}" + PARENT_SCOPE) + SET(${TARGET_NAME}_LIBRARY + "${PROTOBUF_INSTALL_DIR}/lib/libprotobuf${CMAKE_STATIC_LIBRARY_SUFFIX}" + PARENT_SCOPE) + SET(${TARGET_NAME}_PROTOC_LIBRARY + "${PROTOBUF_INSTALL_DIR}/lib/libprotoc${CMAKE_STATIC_LIBRARY_SUFFIX}" + PARENT_SCOPE) + SET(${TARGET_NAME}_PROTOC_EXECUTABLE + "${PROTOBUF_INSTALL_DIR}/bin/protoc${CMAKE_EXECUTABLE_SUFFIX}" + PARENT_SCOPE) + + SET(PROTOBUF_REPO "https://github.com/protocolbuffers/protobuf.git") + SET(PROTOBUF_TAG 
"9f75c5aa851cd877fb0d93ccc31b8567a6706546") + SET(OPTIONAL_CACHE_ARGS "") + SET(OPTIONAL_ARGS "") + + IF(BUILD_FOR_HOST) + SET(OPTIONAL_ARGS + "-DCMAKE_C_COMPILER=${HOST_C_COMPILER}" + "-DCMAKE_CXX_COMPILER=${HOST_CXX_COMPILER}" + "-Dprotobuf_WITH_ZLIB=OFF" + "-DZLIB_ROOT:FILEPATH=${ZLIB_ROOT}") + SET(OPTIONAL_CACHE_ARGS "-DZLIB_ROOT:STRING=${ZLIB_ROOT}") + ELSE() + # protobuf have compile issue when use android stl c++_static + SET(PROTOBUF_REPO "https://github.com/tensor-tang/protobuf.git") + SET(PROTOBUF_TAG "mobile") + SET(OPTIONAL_ARGS "-Dprotobuf_WITH_ZLIB=OFF" + ${CROSS_COMPILE_CMAKE_ARGS} + "-DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER}" + "-DCMAKE_C_COMPILER=${CMAKE_C_COMPILER}" + "-DCMAKE_C_FLAGS=${CMAKE_C_FLAGS}" + "-DCMAKE_C_FLAGS_DEBUG=${CMAKE_C_FLAGS_DEBUG}" + "-DCMAKE_C_FLAGS_RELEASE=${CMAKE_C_FLAGS_RELEASE}" + "-DCMAKE_CXX_FLAGS=${CMAKE_CXX_FLAGS}" + "-DCMAKE_CXX_FLAGS_RELEASE=${CMAKE_CXX_FLAGS_RELEASE}" + "-DCMAKE_CXX_FLAGS_DEBUG=${CMAKE_CXX_FLAGS_DEBUG}") + ENDIF() + IF(WIN32) + SET(OPTIONAL_ARGS ${OPTIONAL_ARGS} "-DCMAKE_GENERATOR_PLATFORM=x64") + ENDIF() + + if(LITE_WITH_LIGHT_WEIGHT_FRAMEWORK) + ExternalProject_Add( + ${TARGET_NAME} + ${EXTERNAL_PROJECT_LOG_ARGS} + PREFIX ${PROTOBUF_SOURCES_DIR} + SOURCE_SUBDIR cmake + UPDATE_COMMAND "" + GIT_REPOSITORY ${PROTOBUF_REPO} + GIT_TAG ${PROTOBUF_TAG} + CMAKE_ARGS + ${OPTIONAL_ARGS} + -Dprotobuf_BUILD_TESTS=OFF + -DCMAKE_SKIP_RPATH=ON + -DCMAKE_POSITION_INDEPENDENT_CODE=ON + -DCMAKE_BUILD_TYPE=${THIRD_PARTY_BUILD_TYPE} + -DCMAKE_INSTALL_PREFIX=${PROTOBUF_INSTALL_DIR} + -DCMAKE_INSTALL_LIBDIR=lib + -DBUILD_SHARED_LIBS=OFF + CMAKE_CACHE_ARGS + -DCMAKE_INSTALL_PREFIX:PATH=${PROTOBUF_INSTALL_DIR} + -DCMAKE_BUILD_TYPE:STRING=${THIRD_PARTY_BUILD_TYPE} + -DCMAKE_VERBOSE_MAKEFILE:BOOL=OFF + -DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=ON + ${OPTIONAL_CACHE_ARGS} + ) + else() + ExternalProject_Add( + ${TARGET_NAME} + ${EXTERNAL_PROJECT_LOG_ARGS} + PREFIX ${PROTOBUF_SOURCES_DIR} + UPDATE_COMMAND "" + GIT_REPOSITORY ${PROTOBUF_REPO} + GIT_TAG ${PROTOBUF_TAG} + CONFIGURE_COMMAND + ${CMAKE_COMMAND} ${PROTOBUF_SOURCES_DIR}/src/${TARGET_NAME}/cmake + ${OPTIONAL_ARGS} + -Dprotobuf_BUILD_TESTS=OFF + -DCMAKE_SKIP_RPATH=ON + -DCMAKE_POSITION_INDEPENDENT_CODE=ON + -DCMAKE_BUILD_TYPE=${THIRD_PARTY_BUILD_TYPE} + -DCMAKE_INSTALL_PREFIX=${PROTOBUF_INSTALL_DIR} + -DCMAKE_INSTALL_LIBDIR=lib + -DBUILD_SHARED_LIBS=OFF + CMAKE_CACHE_ARGS + -DCMAKE_INSTALL_PREFIX:PATH=${PROTOBUF_INSTALL_DIR} + -DCMAKE_BUILD_TYPE:STRING=${THIRD_PARTY_BUILD_TYPE} + -DCMAKE_VERBOSE_MAKEFILE:BOOL=OFF + -DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=ON + ${OPTIONAL_CACHE_ARGS} + ) + endif() +ENDFUNCTION() + +SET(PROTOBUF_VERSION 3.1.0) + +IF(LITE_WITH_LIGHT_WEIGHT_FRAMEWORK) + build_protobuf(protobuf_host TRUE) + LIST(APPEND external_project_dependencies protobuf_host) + SET(PROTOBUF_PROTOC_EXECUTABLE ${protobuf_host_PROTOC_EXECUTABLE} + CACHE FILEPATH "protobuf executable." FORCE) +ENDIF() + +IF(NOT PROTOBUF_FOUND) + build_protobuf(extern_protobuf FALSE) + + SET(PROTOBUF_INCLUDE_DIR ${extern_protobuf_INCLUDE_DIR} + CACHE PATH "protobuf include directory." FORCE) + SET(PROTOBUF_LITE_LIBRARY ${extern_protobuf_LITE_LIBRARY} + CACHE FILEPATH "protobuf lite library." FORCE) + SET(PROTOBUF_LIBRARY ${extern_protobuf_LIBRARY} + CACHE FILEPATH "protobuf library." FORCE) + SET(PROTOBUF_PROTOC_LIBRARY ${extern_protobuf_PROTOC_LIBRARY} + CACHE FILEPATH "protoc library." 
FORCE) + + IF(LITE_WITH_LIGHT_WEIGHT_FRAMEWORK) + PROMPT_PROTOBUF_LIB(protobuf_host extern_protobuf) + ELSE() + SET(PROTOBUF_PROTOC_EXECUTABLE ${extern_protobuf_PROTOC_EXECUTABLE} + CACHE FILEPATH "protobuf executable." FORCE) + PROMPT_PROTOBUF_LIB(extern_protobuf) + ENDIF() + +ENDIF(NOT PROTOBUF_FOUND) diff --git a/cmake/external/xbyak.cmake b/cmake/external/xbyak.cmake new file mode 100644 index 00000000000..1d61154c0d4 --- /dev/null +++ b/cmake/external/xbyak.cmake @@ -0,0 +1,57 @@ +# Copyright (c) 2017 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +set(WITH_XBYAK ON) +if(WIN32 OR APPLE) + SET(WITH_XBYAK OFF CACHE STRING "Disable XBYAK in Windows and MacOS" FORCE) + return() +endif() + +include(ExternalProject) + +set(XBYAK_PROJECT extern_xbyak) +set(XBYAK_PREFIX_DIR ${THIRD_PARTY_PATH}/xbyak) +set(XBYAK_INSTALL_ROOT ${THIRD_PARTY_PATH}/install/xbyak) +set(XBYAK_INC_DIR ${XBYAK_INSTALL_ROOT}/include) + +include_directories(${XBYAK_INC_DIR}) +include_directories(${XBYAK_INC_DIR}/xbyak) + +add_definitions(-DPADDLE_WITH_XBYAK) + +# xbyak options +add_definitions(-DXBYAK64) +add_definitions(-DXBYAK_NO_OP_NAMES) + +ExternalProject_Add( + ${XBYAK_PROJECT} + ${EXTERNAL_PROJECT_LOG_ARGS} + DEPENDS "" + GIT_REPOSITORY "https://github.com/herumi/xbyak.git" + GIT_TAG "v5.661" # Jul 26th + PREFIX ${XBYAK_PREFIX_DIR} + UPDATE_COMMAND "" + CMAKE_ARGS -DCMAKE_INSTALL_PREFIX=${XBYAK_INSTALL_ROOT} + CMAKE_CACHE_ARGS -DCMAKE_INSTALL_PREFIX:PATH=${XBYAK_INSTALL_ROOT} +) + +if (${CMAKE_VERSION} VERSION_LESS "3.3.0") + set(dummyfile ${CMAKE_CURRENT_BINARY_DIR}/xbyak_dummy.c) + file(WRITE ${dummyfile} "const char *dummy_xbyak = \"${dummyfile}\";") + add_library(xbyak STATIC ${dummyfile}) +else() + add_library(xbyak INTERFACE) +endif() + +add_dependencies(xbyak ${XBYAK_PROJECT}) diff --git a/cmake/external/xxhash.cmake b/cmake/external/xxhash.cmake new file mode 100644 index 00000000000..23b1e021086 --- /dev/null +++ b/cmake/external/xxhash.cmake @@ -0,0 +1,73 @@ +INCLUDE(ExternalProject) + +set(XXHASH_SOURCE_DIR ${THIRD_PARTY_PATH}/xxhash) +set(XXHASH_INSTALL_DIR ${THIRD_PARTY_PATH}/install/xxhash) +set(XXHASH_INCLUDE_DIR "${XXHASH_INSTALL_DIR}/include") + +IF(WITH_STATIC_LIB) + SET(BUILD_CMD make lib) +ELSE() + IF(APPLE) + SET(BUILD_CMD sed -i \"\" "s/-Wstrict-prototypes -Wundef/-Wstrict-prototypes -Wundef -fPIC/g" ${XXHASH_SOURCE_DIR}/src/extern_xxhash/Makefile && make lib) + ELSE(APPLE) + SET(BUILD_CMD sed -i "s/-Wstrict-prototypes -Wundef/-Wstrict-prototypes -Wundef -fPIC/g" ${XXHASH_SOURCE_DIR}/src/extern_xxhash/Makefile && make lib) + ENDIF(APPLE) +ENDIF() + +if(WIN32) + ExternalProject_Add( + extern_xxhash + ${EXTERNAL_PROJECT_LOG_ARGS} + GIT_REPOSITORY "https://github.com/Cyan4973/xxHash" + GIT_TAG "v0.6.5" + PREFIX ${XXHASH_SOURCE_DIR} + DOWNLOAD_NAME "xxhash" + UPDATE_COMMAND "" + BUILD_IN_SOURCE 1 + PATCH_COMMAND + CONFIGURE_COMMAND + ${CMAKE_COMMAND} ${XXHASH_SOURCE_DIR}/src/extern_xxhash/cmake_unofficial + 
-DCMAKE_INSTALL_PREFIX:PATH=${XXHASH_INSTALL_DIR} + -DCMAKE_BUILD_TYPE:STRING=${CMAKE_BUILD_TYPE} + -DCMAKE_VERBOSE_MAKEFILE:BOOL=OFF + -DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=ON + -DBUILD_XXHSUM=OFF + -DCMAKE_GENERATOR_PLATFORM=x64 + -DBUILD_SHARED_LIBS=OFF + ${OPTIONAL_CACHE_ARGS} + TEST_COMMAND "" + ) +else() + ExternalProject_Add( + extern_xxhash + ${EXTERNAL_PROJECT_LOG_ARGS} + GIT_REPOSITORY "https://github.com/Cyan4973/xxHash" + GIT_TAG "v0.6.5" + PREFIX ${XXHASH_SOURCE_DIR} + DOWNLOAD_NAME "xxhash" + UPDATE_COMMAND "" + CONFIGURE_COMMAND "" + BUILD_IN_SOURCE 1 + PATCH_COMMAND + BUILD_COMMAND ${BUILD_CMD} + INSTALL_COMMAND export PREFIX=${XXHASH_INSTALL_DIR}/ && make install + TEST_COMMAND "" + ) +endif() + +if (WIN32) + IF(NOT EXISTS "${XXHASH_INSTALL_DIR}/lib/libxxhash.lib") + add_custom_command(TARGET extern_xxhash POST_BUILD + COMMAND cmake -E copy ${XXHASH_INSTALL_DIR}/lib/xxhash.lib ${XXHASH_INSTALL_DIR}/lib/libxxhash.lib + ) + ENDIF() + set(XXHASH_LIBRARIES "${XXHASH_INSTALL_DIR}/lib/libxxhash.lib") +else() + set(XXHASH_LIBRARIES "${XXHASH_INSTALL_DIR}/lib/libxxhash.a") +endif () +INCLUDE_DIRECTORIES(${XXHASH_INCLUDE_DIR}) + +add_library(xxhash STATIC IMPORTED GLOBAL) +set_property(TARGET xxhash PROPERTY IMPORTED_LOCATION ${XXHASH_LIBRARIES}) +include_directories(${XXHASH_INCLUDE_DIR}) +add_dependencies(xxhash extern_xxhash) diff --git a/cmake/flags.cmake b/cmake/flags.cmake new file mode 100644 index 00000000000..36b533aa4f7 --- /dev/null +++ b/cmake/flags.cmake @@ -0,0 +1,194 @@ +# Setting Paddle Compile Flags +include(CheckCXXCompilerFlag) +include(CheckCCompilerFlag) +include(CheckCXXSymbolExists) +include(CheckTypeSize) + +function(CheckCompilerCXX11Flag) + if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU") + if(${CMAKE_CXX_COMPILER_VERSION} VERSION_LESS 4.8) + message(FATAL_ERROR "Unsupported GCC version. GCC >= 4.8 required.") + endif() + elseif(CMAKE_CXX_COMPILER_ID STREQUAL "AppleClang" OR CMAKE_CXX_COMPILER_ID STREQUAL "Clang") + # cmake >= 3.0 compiler id "AppleClang" on Mac OS X, otherwise "Clang" + # Apple Clang is a different compiler than upstream Clang which havs different version numbers. + # https://gist.github.com/yamaya/2924292 + if(APPLE) # cmake < 3.0 compiler id "Clang" on Mac OS X + if(${CMAKE_CXX_COMPILER_VERSION} VERSION_LESS 5.1) + message(FATAL_ERROR "Unsupported AppleClang version. AppleClang >= 5.1 required.") + endif() + else() + if (${CMAKE_CXX_COMPILER_VERSION} VERSION_LESS 3.3) + message(FATAL_ERROR "Unsupported Clang version. Clang >= 3.3 required.") + endif() + endif() + endif() +endfunction() + +CheckCompilerCXX11Flag() +set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++11") +# safe_set_flag +# +# Set a compile flag only if compiler is support +# is_c: is C flag or C++ flag, bool type. +# src_list: The list name which the flag name will be append to. +# flag_name: the flag name for compiler, such as '-Werror' '-Wall' etc +# rest arguments: not used. 
+function(safe_set_flag is_c src_list flag_name) + string(REPLACE "-" "_" safe_name ${flag_name}) + string(REPLACE "=" "_" safe_name ${safe_name}) + if(is_c) + CHECK_C_COMPILER_FLAG(${flag_name} C_COMPILER_SUPPORT_FLAG_${safe_name}) + set(safe_name C_COMPILER_SUPPORT_FLAG_${safe_name}) + else() + CHECK_CXX_COMPILER_FLAG(${flag_name} CXX_COMPILER_SUPPORT_FLAG_${safe_name}) + set(safe_name CXX_COMPILER_SUPPORT_FLAG_${safe_name}) + endif() + if(${safe_name}) + set(${src_list} "${${src_list}} ${flag_name}" PARENT_SCOPE) + endif() +endfunction() + +# helper macro to set cflag +macro(safe_set_cflag src_list flag_name) + safe_set_flag(ON ${src_list} ${flag_name}) +endmacro() + +# helper macro to set cxxflag +macro(safe_set_cxxflag src_list flag_name) + safe_set_flag(OFF ${src_list} ${flag_name}) +endmacro() + +# helper macro to set nvcc flag +macro(safe_set_nvflag flag_name) + string(REPLACE "-" "_" safe_name ${flag_name}) + string(REPLACE "=" "_" safe_name ${safe_name}) + CHECK_C_COMPILER_FLAG(${flag_name} C_COMPILER_SUPPORT_FLAG_${safe_name}) + set(safe_name C_COMPILER_SUPPORT_FLAG_${safe_name}) + if(${safe_name}) + LIST(APPEND CUDA_NVCC_FLAGS -Xcompiler ${flag_name}) + endif() +endmacro() + +macro(safe_set_static_flag) # set c_flags and cxx_flags to static or shared + if (BUILD_SHARED_LIBS) + return() # if build shared libs, the flags keep same with '/MD' + endif(BUILD_SHARED_LIBS) + foreach(flag_var + CMAKE_CXX_FLAGS CMAKE_CXX_FLAGS_DEBUG CMAKE_CXX_FLAGS_RELEASE + CMAKE_CXX_FLAGS_MINSIZEREL CMAKE_CXX_FLAGS_RELWITHDEBINFO + CMAKE_C_FLAGS CMAKE_C_FLAGS_DEBUG CMAKE_C_FLAGS_RELEASE + CMAKE_C_FLAGS_MINSIZEREL CMAKE_C_FLAGS_RELWITHDEBINFO) + if(${flag_var} MATCHES "/MD") + string(REGEX REPLACE "/MD" "/MT" ${flag_var} "${${flag_var}}") + endif(${flag_var} MATCHES "/MD") + endforeach(flag_var) +endmacro() + +CHECK_CXX_SYMBOL_EXISTS(UINT64_MAX "stdint.h" UINT64_MAX_EXISTS) +if(NOT UINT64_MAX_EXISTS) + set(CMAKE_REQUIRED_DEFINITIONS -D__STDC_LIMIT_MACROS) + CHECK_CXX_SYMBOL_EXISTS(UINT64_MAX "stdint.h" UINT64_MAX_EXISTS_HERE) + if(UINT64_MAX_EXISTS_HERE) + set(CMAKE_REQUIRED_DEFINITIONS) + add_definitions(-D__STDC_LIMIT_MACROS) + else() + message(FATAL_ERROR "Cannot find symbol UINT64_MAX") + endif() +endif() + +SET(CMAKE_EXTRA_INCLUDE_FILES "pthread.h") +CHECK_TYPE_SIZE(pthread_spinlock_t SPINLOCK_FOUND) +CHECK_TYPE_SIZE(pthread_barrier_t BARRIER_FOUND) +if(SPINLOCK_FOUND) + add_definitions(-DPADDLE_USE_PTHREAD_SPINLOCK) +endif(SPINLOCK_FOUND) +if(BARRIER_FOUND) + add_definitions(-DPADDLE_USE_PTHREAD_BARRIER) +endif(BARRIER_FOUND) +SET(CMAKE_EXTRA_INCLUDE_FILES "") + +# Common flags. the compiler flag used for C/C++ sources whenever release or debug +# Do not care if this flag is support for gcc. 
+ +# https://github.com/PaddlePaddle/Paddle/issues/12773 +if (NOT WIN32) +set(COMMON_FLAGS + -fPIC + -fno-omit-frame-pointer + -Werror + -Wall + -Wextra + -Wnon-virtual-dtor + -Wdelete-non-virtual-dtor + -Wno-unused-parameter + -Wno-unused-function + -Wno-error=literal-suffix + -Wno-error=sign-compare + -Wno-error=unused-local-typedefs + -Wno-error=parentheses-equality # Warnings in pybind11 + -Wno-error=ignored-attributes # Warnings in Eigen, gcc 6.3 + -Wno-error=terminate # Warning in PADDLE_ENFORCE + -Wno-error=int-in-bool-context # Warning in Eigen gcc 7.2 + -Wimplicit-fallthrough=0 # Warning in tinyformat.h + -Wno-error=maybe-uninitialized # Warning in boost gcc 7.2 +) + +set(GPU_COMMON_FLAGS + -fPIC + -fno-omit-frame-pointer + -Wnon-virtual-dtor + -Wdelete-non-virtual-dtor + -Wno-unused-parameter + -Wno-unused-function + -Wno-error=sign-compare + -Wno-error=literal-suffix + -Wno-error=unused-local-typedefs + -Wno-error=unused-function # Warnings in Numpy Header. + -Wno-error=array-bounds # Warnings in Eigen::array +) +set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -m64") +endif(NOT WIN32) + +if (APPLE) + # On Mac OS X build fat binaries with x86_64 architectures by default. + set (CMAKE_OSX_ARCHITECTURES "x86_64" CACHE STRING "Build architectures for OSX" FORCE) + # On Mac OS X register class specifier is deprecated and will cause warning error on latest clang 10.0 + set (COMMON_FLAGS -Wno-deprecated-register) +endif(APPLE) + +if(LINUX) + set(GPU_COMMON_FLAGS + -Wall + -Wextra + -Werror + ${GPU_COMMON_FLAGS}) +endif(LINUX) + +if(UNIX AND NOT APPLE) + # except apple from nix*Os family + set(LINUX TRUE) +endif(UNIX AND NOT APPLE) + +foreach(flag ${COMMON_FLAGS}) + safe_set_cflag(CMAKE_C_FLAGS ${flag}) + safe_set_cxxflag(CMAKE_CXX_FLAGS ${flag}) + +endforeach() + +foreach(flag ${GPU_COMMON_FLAGS}) + safe_set_nvflag(${flag}) +endforeach() + +if(WIN32) +# windows build turn off warnings. +safe_set_static_flag() + foreach(flag_var + CMAKE_CXX_FLAGS CMAKE_CXX_FLAGS_DEBUG CMAKE_CXX_FLAGS_RELEASE + CMAKE_CXX_FLAGS_MINSIZEREL CMAKE_CXX_FLAGS_RELWITHDEBINFO + CMAKE_C_FLAGS CMAKE_C_FLAGS_DEBUG CMAKE_C_FLAGS_RELEASE + CMAKE_C_FLAGS_MINSIZEREL CMAKE_C_FLAGS_RELWITHDEBINFO) + string(REGEX REPLACE "(^| )/W[0-9]( |$)" " " ${flag_var} "${${flag_var}}") + set(flag_var "${flag_var} /w") + endforeach(flag_var) +endif(WIN32) diff --git a/cmake/generic.cmake b/cmake/generic.cmake new file mode 100644 index 00000000000..b09df6eaea9 --- /dev/null +++ b/cmake/generic.cmake @@ -0,0 +1,570 @@ +# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + + +# generic.cmake defines CMakes functions that look like Bazel's +# building rules (https://bazel.build/). 
+#
+#
+# -------------------------------------------
+#     C++        CUDA C++       Go
+# -------------------------------------------
+# cc_library    nv_library   go_library
+# cc_binary     nv_binary    go_binary
+# cc_test       nv_test      go_test
+# -------------------------------------------
+#
+# To build a static library example.a from example.cc using the system
+# compiler (like GCC):
+#
+#   cc_library(example SRCS example.cc)
+#
+# To build a static library example.a from multiple source files
+# example{1,2,3}.cc:
+#
+#   cc_library(example SRCS example1.cc example2.cc example3.cc)
+#
+# To build a shared library example.so from example.cc:
+#
+#   cc_library(example SHARED SRCS example.cc)
+#
+# To build a library using Nvidia's NVCC from .cu file(s), use the nv_
+# prefixed version:
+#
+#   nv_library(example SRCS example.cu)
+#
+# To specify that a library new_example.a depends on other libraries:
+#
+#   cc_library(new_example SRCS new_example.cc DEPS example)
+#
+# Static libraries can be composed of other static libraries:
+#
+#   cc_library(composed DEPS dependent1 dependent2 dependent3)
+#
+# To build an executable binary file from some source files and
+# dependent libraries:
+#
+#   cc_binary(example SRCS main.cc something.cc DEPS example1 example2)
+#
+# To build an executable binary file using NVCC, use the nv_ prefixed
+# version:
+#
+#   nv_binary(example SRCS main.cc something.cu DEPS example1 example2)
+#
+# To build a unit test binary, which is an executable binary with
+# GoogleTest linked:
+#
+#   cc_test(example_test SRCS example_test.cc DEPS example)
+#
+# To build a unit test binary using NVCC, use the nv_ prefixed version:
+#
+#   nv_test(example_test SRCS example_test.cu DEPS example)
+#
+# It is pretty common that executable and test binaries depend on
+# pre-defined external libraries like glog and gflags defined in
+# /cmake/external/*.cmake:
+#
+#   cc_test(example_test SRCS example_test.cc DEPS example glog gflags)
+#
+# To build a go static library using Golang, use the go_ prefixed version:
+#
+#   go_library(example STATIC)
+#
+# To build a go shared library using Golang, use the go_ prefixed version:
+#
+#   go_library(example SHARED)
+#
+
+# include the binary directory for generated headers.
+include_directories(${CMAKE_CURRENT_BINARY_DIR})
+
+if(NOT APPLE)
+  find_package(Threads REQUIRED)
+  link_libraries(${CMAKE_THREAD_LIBS_INIT})
+  set(CMAKE_CXX_LINK_EXECUTABLE "${CMAKE_CXX_LINK_EXECUTABLE} -pthread -ldl")
+  if (NOT ANDROID)
+    set(CMAKE_CXX_LINK_EXECUTABLE "${CMAKE_CXX_LINK_EXECUTABLE} -lrt")
+  endif()
+endif(NOT APPLE)
+
+set_property(GLOBAL PROPERTY FLUID_MODULES "")
+# find_fluid_modules collects all fluid modules into the FLUID_MODULES global
+# property; the list is used when building the paddle fluid static library
+# for the inference libs.
+function(find_fluid_modules TARGET_NAME)
+  get_filename_component(__target_path ${TARGET_NAME} ABSOLUTE)
+  string(REGEX REPLACE "^${PADDLE_SOURCE_DIR}/" "" __target_path ${__target_path})
+  string(FIND "${__target_path}" "fluid" pos)
+  if(pos GREATER 1)
+    get_property(fluid_modules GLOBAL PROPERTY FLUID_MODULES)
+    set(fluid_modules ${fluid_modules} ${TARGET_NAME})
+    set_property(GLOBAL PROPERTY FLUID_MODULES "${fluid_modules}")
+  endif()
+endfunction(find_fluid_modules)
+
+
+function(common_link TARGET_NAME)
+  if (WITH_PROFILER)
+    target_link_libraries(${TARGET_NAME} gperftools::profiler)
+  endif()
+
+  if (WITH_JEMALLOC)
+    target_link_libraries(${TARGET_NAME} jemalloc::jemalloc)
+  endif()
+endfunction()
+
+
+# find_fluid_thirdparties collects all third_party modules into the
+# FLUID_THIRD_PARTY global property, to reduce the dependencies when
+# building the inference libs.
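+# A consumer can read the collected list back later, e.g. (editor's sketch):
+#
+#   get_property(fluid_third_party GLOBAL PROPERTY FLUID_THIRD_PARTY)
+#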
+set_property(GLOBAL PROPERTY FLUID_THIRD_PARTY)
+function(find_fluid_thirdparties TARGET_NAME)
+  get_filename_component(__target_path ${TARGET_NAME} ABSOLUTE)
+  string(REGEX REPLACE "^${PADDLE_SOURCE_DIR}/" "" __target_path ${__target_path})
+  string(FIND "${__target_path}" "third_party" pos)
+  if(pos GREATER 1)
+    # NOTE: read the existing list into the same variable we append to,
+    # otherwise previously collected entries would be dropped.
+    get_property(fluid_third_partys GLOBAL PROPERTY FLUID_THIRD_PARTY)
+    set(fluid_third_partys ${fluid_third_partys} ${TARGET_NAME})
+    set_property(GLOBAL PROPERTY FLUID_THIRD_PARTY "${fluid_third_partys}")
+  endif()
+endfunction(find_fluid_thirdparties)
+
+function(merge_static_libs TARGET_NAME)
+  set(libs ${ARGN})
+  list(REMOVE_DUPLICATES libs)
+
+  # Get all propagation dependencies from the merged libraries
+  foreach(lib ${libs})
+    list(APPEND libs_deps ${${lib}_LIB_DEPENDS})
+  endforeach()
+  if(libs_deps)
+    list(REMOVE_DUPLICATES libs_deps)
+  endif()
+
+  # To produce a library we need at least one source file.
+  # It is created by the add_custom_command below and also
+  # helps to track dependencies.
+  set(target_SRCS ${CMAKE_CURRENT_BINARY_DIR}/${TARGET_NAME}_dummy.c)
+
+  if(APPLE) # Use OSX's libtool to merge archives
+    # Make the generated dummy source file depend on all static input
+    # libs. If an input lib changes, the source file is touched,
+    # which causes the desired effect (relink).
+    add_custom_command(OUTPUT ${target_SRCS}
+      COMMAND ${CMAKE_COMMAND} -E touch ${target_SRCS}
+      DEPENDS ${libs})
+
+    # Generate dummy static lib
+    file(WRITE ${target_SRCS} "const char *dummy_${TARGET_NAME} = \"${target_SRCS}\";")
+    add_library(${TARGET_NAME} STATIC ${target_SRCS})
+    target_link_libraries(${TARGET_NAME} ${libs_deps})
+
+    foreach(lib ${libs})
+      # Get the file names of the libraries to be merged
+      set(libfiles ${libfiles} $<TARGET_FILE:${lib}>)
+    endforeach()
+    add_custom_command(TARGET ${TARGET_NAME} POST_BUILD
+      COMMAND rm "${CMAKE_CURRENT_BINARY_DIR}/lib${TARGET_NAME}.a"
+      COMMAND /usr/bin/libtool -static -o "${CMAKE_CURRENT_BINARY_DIR}/lib${TARGET_NAME}.a" ${libfiles}
+      )
+  endif(APPLE)
+  if(LINUX) # general UNIX: use "ar" to extract objects and re-add to a common lib
+    set(target_DIR ${CMAKE_CURRENT_BINARY_DIR}/${TARGET_NAME}.dir)
+
+    foreach(lib ${libs})
+      set(objlistfile ${target_DIR}/${lib}.objlist) # list of objects in the input library
+      set(objdir ${target_DIR}/${lib}.objdir)
+
+      add_custom_command(OUTPUT ${objdir}
+        COMMAND ${CMAKE_COMMAND} -E make_directory ${objdir}
+        DEPENDS ${lib})
+
+      add_custom_command(OUTPUT ${objlistfile}
+        COMMAND ${CMAKE_AR} -x "$<TARGET_FILE:${lib}>"
+        COMMAND ${CMAKE_AR} -t "$<TARGET_FILE:${lib}>" > ${objlistfile}
+        DEPENDS ${lib} ${objdir}
+        WORKING_DIRECTORY ${objdir})
+
+      list(APPEND target_OBJS "${objlistfile}")
+    endforeach()
+
+    # Make the generated dummy source file depend on all static input
+    # libs. If an input lib changes, the source file is touched,
+    # which causes the desired effect (relink).
+    add_custom_command(OUTPUT ${target_SRCS}
+      COMMAND ${CMAKE_COMMAND} -E touch ${target_SRCS}
+      DEPENDS ${libs} ${target_OBJS})
+
+    # Generate dummy static lib
+    file(WRITE ${target_SRCS} "const char *dummy_${TARGET_NAME} = \"${target_SRCS}\";")
+    add_library(${TARGET_NAME} STATIC ${target_SRCS})
+    target_link_libraries(${TARGET_NAME} ${libs_deps})
+
+    # Get the file name of the generated library
+    set(target_LIBNAME "$<TARGET_FILE:${TARGET_NAME}>")
+
+    add_custom_command(TARGET ${TARGET_NAME} POST_BUILD
+      COMMAND ${CMAKE_AR} crs ${target_LIBNAME} `find ${target_DIR} -name '*.o'`
+      COMMAND ${CMAKE_RANLIB} ${target_LIBNAME}
+      WORKING_DIRECTORY ${target_DIR})
+  endif(LINUX)
+  if(WIN32) # Windows does not support gcc/nvcc combined compiling. Use msvc's lib.exe to merge libs.
+    # Make the generated dummy source file depend on all static input
+    # libs. If an input lib changes, the source file is touched,
+    # which causes the desired effect (relink).
+    add_custom_command(OUTPUT ${target_SRCS}
+      COMMAND ${CMAKE_COMMAND} -E touch ${target_SRCS}
+      DEPENDS ${libs})
+
+    # Generate dummy static lib
+    file(WRITE ${target_SRCS} "const char *dummy_${TARGET_NAME} = \"${target_SRCS}\";")
+    add_library(${TARGET_NAME} STATIC ${target_SRCS})
+    target_link_libraries(${TARGET_NAME} ${libs_deps})
+
+    foreach(lib ${libs})
+      # Get the file names of the libraries to be merged
+      set(libfiles ${libfiles} $<TARGET_FILE:${lib}>)
+    endforeach()
+    # msvc puts the library under the "/Release/xxxlib" directory by default
+    # COMMAND cmake -E remove "${CMAKE_CURRENT_BINARY_DIR}/${CMAKE_BUILD_TYPE}/${TARGET_NAME}.lib"
+    add_custom_command(TARGET ${TARGET_NAME} POST_BUILD
+      COMMAND cmake -E make_directory "${CMAKE_CURRENT_BINARY_DIR}/${CMAKE_BUILD_TYPE}"
+      COMMAND lib /OUT:${CMAKE_CURRENT_BINARY_DIR}/${CMAKE_BUILD_TYPE}/lib${TARGET_NAME}.lib ${libfiles}
+      )
+  endif(WIN32)
+endfunction(merge_static_libs)
+
+function(cc_library TARGET_NAME)
+  set(options STATIC static SHARED shared)
+  set(oneValueArgs "")
+  set(multiValueArgs SRCS DEPS)
+  cmake_parse_arguments(cc_library "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
+  if(WIN32)
+    # add the libxxx.lib prefix on Windows
+    set(${TARGET_NAME}_LIB_NAME "${CMAKE_STATIC_LIBRARY_PREFIX}${TARGET_NAME}${CMAKE_STATIC_LIBRARY_SUFFIX}" CACHE STRING "output library name for target ${TARGET_NAME}")
+  endif(WIN32)
+  if(cc_library_SRCS)
+    if(cc_library_SHARED OR cc_library_shared) # build *.so
+      add_library(${TARGET_NAME} SHARED ${cc_library_SRCS})
+    else()
+      add_library(${TARGET_NAME} STATIC ${cc_library_SRCS})
+      find_fluid_modules(${TARGET_NAME})
+    endif()
+
+    if(cc_library_DEPS)
+      # We don't need to link libwarpctc.so
+      if("${cc_library_DEPS};" MATCHES "warpctc;")
+        list(REMOVE_ITEM cc_library_DEPS warpctc)
+        add_dependencies(${TARGET_NAME} warpctc)
+      endif()
+      # Only depend on libmklml.so, do not link it
+      if("${cc_library_DEPS};" MATCHES "mklml;")
+        list(REMOVE_ITEM cc_library_DEPS mklml)
+        if(NOT "${TARGET_NAME}" MATCHES "dynload_mklml")
+          list(APPEND cc_library_DEPS dynload_mklml)
+        endif()
+        add_dependencies(${TARGET_NAME} mklml)
+        if(WIN32)
+          target_link_libraries(${TARGET_NAME} ${MKLML_IOMP_LIB})
+        else(WIN32)
+          target_link_libraries(${TARGET_NAME} "-L${MKLML_LIB_DIR} -liomp5 -Wl,--as-needed")
+        endif(WIN32)
+      endif()
+      # remove link to python, see notes at:
+      # https://github.com/pybind/pybind11/blob/master/docs/compiling.rst#building-manually
+      if("${cc_library_DEPS};" MATCHES "python;")
+        list(REMOVE_ITEM cc_library_DEPS python)
+        add_dependencies(${TARGET_NAME} python)
+        if(WIN32)
+          target_link_libraries(${TARGET_NAME} ${PYTHON_LIBRARIES})
+        else()
+          target_link_libraries(${TARGET_NAME} "-Wl,-undefined,dynamic_lookup")
+        endif(WIN32)
+      endif()
+      target_link_libraries(${TARGET_NAME} ${cc_library_DEPS})
+      add_dependencies(${TARGET_NAME} ${cc_library_DEPS})
+      common_link(${TARGET_NAME})
+    endif()
+
+    set(full_path_src "")
+    # cpplint code style
+    foreach(source_file ${cc_library_SRCS})
+      string(REGEX REPLACE "\\.[^.]*$" "" source ${source_file})
+      if(EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${source}.h)
+        list(APPEND cc_library_HEADERS ${CMAKE_CURRENT_SOURCE_DIR}/${source}.h)
+      endif()
+      if(${source_file} MATCHES "framework.pb.cc")
+        list(APPEND full_path_src ${source_file})
+      else()
+        list(APPEND full_path_src ${CMAKE_CURRENT_SOURCE_DIR}/${source_file})
+      endif()
+    endforeach()
+    set(__lite_cc_files ${__lite_cc_files} ${full_path_src} CACHE INTERNAL "")
+  else(cc_library_SRCS)
+    if(cc_library_DEPS)
+      merge_static_libs(${TARGET_NAME} ${cc_library_DEPS})
+    else()
+      message(FATAL_ERROR "Please specify source files or libraries in cc_library(${TARGET_NAME} ...).")
+    endif()
+  endif(cc_library_SRCS)
+endfunction(cc_library)
+
+# The link command on Windows may exceed the maximum command-line length;
+# simply breaking the link into multiple link operations fixes that, say
+# original:
+#   lib /out:target.lib a.lib b.lib c.lib d.lib
+# after:
+#   1. lib /out:dummy_lib_1.lib a.lib b.lib
+#   2. lib /out:dummy_lib_2.lib c.lib d.lib
+#   3. lib /out:target.lib dummy_lib_1.lib dummy_lib_2.lib
+function(sep_library TARGET_NAME)
+  set(options STATIC static SHARED shared)
+  set(oneValueArgs "")
+  set(multiValueArgs SRCS DEPS)
+  cmake_parse_arguments(sep_library "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
+  set(dummy_index 1)
+  set(dummy_offset 1)
+  # each dummy target is composed of at most ${dummy_limit} libraries
+  set(dummy_limit 50)
+  list(LENGTH sep_library_DEPS sep_all_len)
+  foreach(v ${sep_library_DEPS})
+    list(APPEND dummy_list ${v})
+    list(LENGTH dummy_list listlen)
+    if ((${listlen} GREATER ${dummy_limit}) OR (${dummy_offset} EQUAL ${sep_all_len}))
+      message("create dummy library ${TARGET_NAME}_dummy_lib_${dummy_index} for ${TARGET_NAME}")
+      cc_library(${TARGET_NAME}_dummy_lib_${dummy_index} STATIC DEPS ${dummy_list})
+      foreach(i ${dummy_list})
+        list(REMOVE_AT dummy_list 0)
+      endforeach()
+      list(APPEND ${TARGET_NAME}_dummy_list ${TARGET_NAME}_dummy_lib_${dummy_index})
+      MATH(EXPR dummy_index "${dummy_index}+1")
+    endif()
+    MATH(EXPR dummy_offset "${dummy_offset}+1")
+  endforeach()
+  if(${sep_library_SHARED})
+    cc_library(${TARGET_NAME} SHARED SRCS ${sep_library_SRCS} DEPS ${${TARGET_NAME}_dummy_list})
+  else(${sep_library_SHARED})
+    cc_library(${TARGET_NAME} STATIC SRCS ${sep_library_SRCS} DEPS ${${TARGET_NAME}_dummy_list})
+  endif(${sep_library_SHARED})
+endfunction(sep_library)
+
+function(cc_binary TARGET_NAME)
+  set(options "")
+  set(oneValueArgs "")
+  set(multiValueArgs SRCS DEPS)
+  cmake_parse_arguments(cc_binary "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
+  add_executable(${TARGET_NAME} ${cc_binary_SRCS})
+  if(cc_binary_DEPS)
+    target_link_libraries(${TARGET_NAME} ${cc_binary_DEPS})
+    add_dependencies(${TARGET_NAME} ${cc_binary_DEPS})
+    common_link(${TARGET_NAME})
+  endif()
+  get_property(os_dependency_modules GLOBAL PROPERTY OS_DEPENDENCY_MODULES)
+  target_link_libraries(${TARGET_NAME} ${os_dependency_modules})
+endfunction(cc_binary)
+
+function(cc_test TARGET_NAME)
+  if(WITH_TESTING)
+    set(options SERIAL)
+    set(oneValueArgs "")
+    set(multiValueArgs SRCS DEPS ARGS)
+    cmake_parse_arguments(cc_test
"${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) + add_executable(${TARGET_NAME} ${cc_test_SRCS}) + if(WIN32) + if("${cc_test_DEPS};" MATCHES "python;") + list(REMOVE_ITEM cc_test_DEPS python) + target_link_libraries(${TARGET_NAME} ${PYTHON_LIBRARIES}) + endif() + endif(WIN32) + get_property(os_dependency_modules GLOBAL PROPERTY OS_DEPENDENCY_MODULES) + target_link_libraries(${TARGET_NAME} ${cc_test_DEPS} ${os_dependency_modules} paddle_gtest_main lod_tensor memory gtest gflags glog) + add_dependencies(${TARGET_NAME} ${cc_test_DEPS} paddle_gtest_main lod_tensor memory gtest gflags glog) + common_link(${TARGET_NAME}) + add_test(NAME ${TARGET_NAME} + COMMAND ${TARGET_NAME} ${cc_test_ARGS} + WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}) + if (${cc_test_SERIAL}) + set_property(TEST ${TARGET_NAME} PROPERTY RUN_SERIAL 1) + endif() + set_property(TEST ${TARGET_NAME} PROPERTY ENVIRONMENT FLAGS_cpu_deterministic=true) + set_property(TEST ${TARGET_NAME} PROPERTY ENVIRONMENT FLAGS_init_allocated_mem=true) + set_property(TEST ${TARGET_NAME} PROPERTY ENVIRONMENT FLAGS_limit_of_tmp_allocation=4294967296) # 4G + set_property(TEST ${TARGET_NAME} PROPERTY ENVIRONMENT FLAGS_cudnn_deterministic=true) + # No unit test should exceed 10 minutes. + set_tests_properties(${TARGET_NAME} PROPERTIES TIMEOUT 600) + endif() +endfunction(cc_test) + +# cc_test without default dependencies +function(raw_cc_test TARGET_NAME) + if(WITH_TESTING) + set(options SERIAL) + set(oneValueArgs "") + set(multiValueArgs SRCS DEPS ARGS) + cmake_parse_arguments(cc_test "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) + add_executable(${TARGET_NAME} ${cc_test_SRCS}) + if(WIN32) + if("${cc_test_DEPS};" MATCHES "python;") + list(REMOVE_ITEM cc_test_DEPS python) + target_link_libraries(${TARGET_NAME} ${PYTHON_LIBRARIES}) + endif() + endif(WIN32) + get_property(os_dependency_modules GLOBAL PROPERTY OS_DEPENDENCY_MODULES) + + if(LITE_WITH_LIGHT_WEIGHT_FRAMEWORK) + target_link_libraries(${TARGET_NAME} ${cc_test_DEPS} ${os_dependency_modules} lite_gtest_main gtest gflags logging) + add_dependencies(${TARGET_NAME} ${cc_test_DEPS} lite_gtest_main gtest gflags logging) + else() + target_link_libraries(${TARGET_NAME} ${cc_test_DEPS} ${os_dependency_modules} lite_gtest_main gtest gflags glog) + add_dependencies(${TARGET_NAME} ${cc_test_DEPS} lite_gtest_main gtest gflags glog) + endif(LITE_WITH_LIGHT_WEIGHT_FRAMEWORK) + + common_link(${TARGET_NAME}) + add_test(NAME ${TARGET_NAME} + COMMAND ${TARGET_NAME} ${cc_test_ARGS} + WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}) + if (${cc_test_SERIAL}) + set_property(TEST ${TARGET_NAME} PROPERTY RUN_SERIAL 1) + endif() + # No unit test should exceed 10 minutes. 
+    set_tests_properties(${TARGET_NAME} PROPERTIES TIMEOUT 600)
+  endif()
+endfunction(raw_cc_test)
+
+function(_lite_cc_test args)
+  message(STATUS "building lite raw test: ${args}")
+  raw_cc_test(${args} ${ARGN})
+endfunction()
+
+function(nv_library TARGET_NAME)
+  if (WITH_GPU)
+    set(options STATIC static SHARED shared)
+    set(oneValueArgs "")
+    set(multiValueArgs SRCS DEPS)
+    cmake_parse_arguments(nv_library "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
+    if(nv_library_SRCS)
+      if (nv_library_SHARED OR nv_library_shared) # build *.so
+        cuda_add_library(${TARGET_NAME} SHARED ${nv_library_SRCS})
+      else()
+        cuda_add_library(${TARGET_NAME} STATIC ${nv_library_SRCS})
+        find_fluid_modules(${TARGET_NAME})
+      endif()
+      if (nv_library_DEPS)
+        add_dependencies(${TARGET_NAME} ${nv_library_DEPS})
+        target_link_libraries(${TARGET_NAME} ${nv_library_DEPS})
+      endif()
+      # cpplint code style
+      foreach(source_file ${nv_library_SRCS})
+        string(REGEX REPLACE "\\.[^.]*$" "" source ${source_file})
+        if(EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${source}.h)
+          list(APPEND nv_library_HEADERS ${CMAKE_CURRENT_SOURCE_DIR}/${source}.h)
+        endif()
+      endforeach()
+    else(nv_library_SRCS)
+      if (nv_library_DEPS)
+        merge_static_libs(${TARGET_NAME} ${nv_library_DEPS})
+      else()
+        message(FATAL_ERROR "Please specify source files or libraries in nv_library.")
+      endif()
+    endif(nv_library_SRCS)
+  endif()
+endfunction(nv_library)
+
+function(nv_binary TARGET_NAME)
+  if (WITH_GPU)
+    set(options "")
+    set(oneValueArgs "")
+    set(multiValueArgs SRCS DEPS)
+    cmake_parse_arguments(nv_binary "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
+    cuda_add_executable(${TARGET_NAME} ${nv_binary_SRCS})
+    if(nv_binary_DEPS)
+      target_link_libraries(${TARGET_NAME} ${nv_binary_DEPS})
+      add_dependencies(${TARGET_NAME} ${nv_binary_DEPS})
+      common_link(${TARGET_NAME})
+    endif()
+  endif()
+endfunction(nv_binary)
+
+function(nv_test TARGET_NAME)
+  if (WITH_GPU AND WITH_TESTING)
+    set(options SERIAL)
+    set(oneValueArgs "")
+    set(multiValueArgs SRCS DEPS)
+    cmake_parse_arguments(nv_test "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
+    cuda_add_executable(${TARGET_NAME} ${nv_test_SRCS})
+    get_property(os_dependency_modules GLOBAL PROPERTY OS_DEPENDENCY_MODULES)
+    target_link_libraries(${TARGET_NAME} ${nv_test_DEPS} paddle_gtest_main lod_tensor memory gtest gflags glog ${os_dependency_modules})
+    add_dependencies(${TARGET_NAME} ${nv_test_DEPS} paddle_gtest_main lod_tensor memory gtest gflags glog)
+    common_link(${TARGET_NAME})
+    add_test(${TARGET_NAME} ${TARGET_NAME})
+    if (nv_test_SERIAL)
+      set_property(TEST ${TARGET_NAME} PROPERTY RUN_SERIAL 1)
+    endif()
+    set_property(TEST ${TARGET_NAME} PROPERTY ENVIRONMENT FLAGS_cpu_deterministic=true)
+    set_property(TEST ${TARGET_NAME} PROPERTY ENVIRONMENT FLAGS_init_allocated_mem=true)
+    set_property(TEST ${TARGET_NAME} PROPERTY ENVIRONMENT FLAGS_limit_of_tmp_allocation=4294967296) # 4G
+    set_property(TEST ${TARGET_NAME} PROPERTY ENVIRONMENT FLAGS_cudnn_deterministic=true)
+  endif()
+endfunction(nv_test)
+
+
+# Modification of the standard 'protobuf_generate_cpp()' with protobuf-lite support
+# Usage:
+#   paddle_protobuf_generate_cpp(<srcs_var> <hdrs_var> <proto files ...>)
+
+function(paddle_protobuf_generate_cpp SRCS HDRS)
+  if(NOT ARGN)
+    message(SEND_ERROR "Error: paddle_protobuf_generate_cpp() called without any proto files")
+    return()
+  endif()
+
+  set(${SRCS})
+  set(${HDRS})
+
+  foreach(FIL ${ARGN})
+    get_filename_component(ABS_FIL ${FIL} ABSOLUTE)
+    get_filename_component(FIL_WE ${FIL} NAME_WE)
+
+    set(_protobuf_protoc_src "${CMAKE_CURRENT_BINARY_DIR}/${FIL_WE}.pb.cc")
+    set(_protobuf_protoc_hdr "${CMAKE_CURRENT_BINARY_DIR}/${FIL_WE}.pb.h")
+    list(APPEND ${SRCS} "${_protobuf_protoc_src}")
+    list(APPEND ${HDRS} "${_protobuf_protoc_hdr}")
+
+    add_custom_command(
+      OUTPUT "${_protobuf_protoc_src}"
+             "${_protobuf_protoc_hdr}"
+
+      COMMAND ${CMAKE_COMMAND} -E make_directory "${CMAKE_CURRENT_BINARY_DIR}"
+      COMMAND ${PROTOBUF_PROTOC_EXECUTABLE}
+      -I${CMAKE_CURRENT_SOURCE_DIR}
+      --cpp_out "${CMAKE_CURRENT_BINARY_DIR}" ${ABS_FIL}
+      DEPENDS ${ABS_FIL} protoc
+      COMMENT "Running C++ protocol buffer compiler on ${FIL}"
+      VERBATIM )
+  endforeach()
+
+  set_source_files_properties(${${SRCS}} ${${HDRS}} PROPERTIES GENERATED TRUE)
+  set(${SRCS} ${${SRCS}} PARENT_SCOPE)
+  set(${HDRS} ${${HDRS}} PARENT_SCOPE)
+endfunction()
+
+
+function(proto_library TARGET_NAME)
+  set(oneValueArgs "")
+  set(multiValueArgs SRCS DEPS)
+  cmake_parse_arguments(proto_library "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
+  set(proto_srcs)
+  set(proto_hdrs)
+  paddle_protobuf_generate_cpp(proto_srcs proto_hdrs ${proto_library_SRCS})
+  cc_library(${TARGET_NAME} SRCS ${proto_srcs} DEPS ${proto_library_DEPS} protobuf)
+endfunction()
diff --git a/cmake/hip.cmake b/cmake/hip.cmake
new file mode 100644
index 00000000000..c3a748db502
--- /dev/null
+++ b/cmake/hip.cmake
@@ -0,0 +1,53 @@
+if(NOT WITH_AMD_GPU)
+  return()
+endif()
+
+include_directories("/opt/rocm/include")
+include_directories("/opt/rocm/hip/include")
+include_directories("/opt/rocm/miopen/include")
+include_directories("/opt/rocm/hipblas/include")
+include_directories("/opt/rocm/hiprand/include")
+include_directories("/opt/rocm/rocrand/include")
+include_directories("/opt/rocm/rccl/include")
+include_directories("/opt/rocm/thrust")
+
+set(HIP_HCC_FLAGS "${HIP_HCC_FLAGS} -fPIC -DPADDLE_WITH_HIP -std=c++11" )
+
+if(WITH_DSO)
+  set(HIP_HCC_FLAGS "${HIP_HCC_FLAGS} -DPADDLE_USE_DSO")
+endif(WITH_DSO)
+
+if(WITH_TESTING)
+  set(HIP_HCC_FLAGS "${HIP_HCC_FLAGS} -DPADDLE_WITH_TESTING")
+endif(WITH_TESTING)
+
+if(WITH_DISTRIBUTE)
+  set(HIP_HCC_FLAGS "${HIP_HCC_FLAGS} -DPADDLE_WITH_DISTRIBUTE")
+endif(WITH_DISTRIBUTE)
+
+if(WITH_GRPC)
+  set(HIP_HCC_FLAGS "${HIP_HCC_FLAGS} -DPADDLE_WITH_GRPC")
+endif(WITH_GRPC)
+
+if(WITH_MKLDNN)
+  set(HIP_HCC_FLAGS "${HIP_HCC_FLAGS} -DPADDLE_WITH_MKLDNN")
+endif(WITH_MKLDNN)
+
+set(HIP_HCC_FLAGS "${HIP_HCC_FLAGS} -DANY_IMPL_ANY_CAST_MOVEABLE")
+
+if(CMAKE_BUILD_TYPE STREQUAL "Debug")
+  list(APPEND HIP_HCC_FLAGS ${CMAKE_CXX_FLAGS_DEBUG})
+elseif(CMAKE_BUILD_TYPE STREQUAL "RelWithDebInfo")
+  list(APPEND HIP_HCC_FLAGS ${CMAKE_CXX_FLAGS_RELWITHDEBINFO})
+elseif(CMAKE_BUILD_TYPE STREQUAL "MinSizeRel")
+  list(APPEND HIP_HCC_FLAGS ${CMAKE_CXX_FLAGS_MINSIZEREL})
+endif()
+
+if("x${HCC_HOME}" STREQUAL "x")
+  set(HCC_HOME "/opt/rocm/hcc")
+endif()
+
+set(CMAKE_HIP_LINK_EXECUTABLE "${HIP_HIPCC_CMAKE_LINKER_HELPER} ${HCC_HOME} <FLAGS> <CMAKE_CXX_LINK_FLAGS> <LINK_FLAGS> <OBJECTS> -o <TARGET> <LINK_LIBRARIES>")
+set(CMAKE_HIP_CREATE_SHARED_LIBRARY "${HIP_HIPCC_CMAKE_LINKER_HELPER} ${HCC_HOME} <CMAKE_CXX_LINK_FLAGS> <LINK_FLAGS> <OBJECTS> -o <TARGET> <LINK_LIBRARIES> -shared")
+set(CMAKE_HIP_CREATE_SHARED_MODULE "${HIP_HIPCC_CMAKE_LINKER_HELPER} ${HCC_HOME} <CMAKE_CXX_LINK_FLAGS> <LINK_FLAGS> <OBJECTS> -o <TARGET> <LINK_LIBRARIES> -shared")
+
diff --git a/cmake/lite.cmake b/cmake/lite.cmake
new file mode 100644
index 00000000000..75b7c2b112c
--- /dev/null
+++ b/cmake/lite.cmake
@@ -0,0 +1,89 @@
+# Bundle several static libraries into one.
+function(bundle_static_library tgt_name bundled_tgt_name fake_target)
+  list(APPEND static_libs ${tgt_name})
+
+  function(_recursively_collect_dependencies input_target)
+    set(_input_link_libraries LINK_LIBRARIES)
+    get_target_property(_input_type ${input_target} TYPE)
+    if (${_input_type} STREQUAL "INTERFACE_LIBRARY")
+      set(_input_link_libraries INTERFACE_LINK_LIBRARIES)
+    endif()
+    get_target_property(public_dependencies ${input_target} ${_input_link_libraries})
+    foreach(dependency IN LISTS public_dependencies)
+      if(TARGET ${dependency})
+        get_target_property(alias ${dependency} ALIASED_TARGET)
+        if (TARGET ${alias})
+          set(dependency ${alias})
+        endif()
+        get_target_property(_type ${dependency} TYPE)
+        if (${_type} STREQUAL "STATIC_LIBRARY")
+          list(APPEND static_libs ${dependency})
+        endif()
+
+        get_property(library_already_added
+          GLOBAL PROPERTY _${tgt_name}_static_bundle_${dependency})
+        if (NOT library_already_added)
+          set_property(GLOBAL PROPERTY _${tgt_name}_static_bundle_${dependency} ON)
+          _recursively_collect_dependencies(${dependency})
+        endif()
+      endif()
+    endforeach()
+    set(static_libs ${static_libs} PARENT_SCOPE)
+  endfunction()
+
+  _recursively_collect_dependencies(${tgt_name})
+
+  list(REMOVE_DUPLICATES static_libs)
+
+  set(bundled_tgt_full_name
+    ${CMAKE_BINARY_DIR}/${CMAKE_STATIC_LIBRARY_PREFIX}${bundled_tgt_name}${CMAKE_STATIC_LIBRARY_SUFFIX})
+
+  message(STATUS "+++++ bundled_tgt_full_name: ${bundled_tgt_full_name}")
+
+  if(NOT IOS)
+    file(WRITE ${CMAKE_BINARY_DIR}/${bundled_tgt_name}.ar.in
+      "CREATE ${bundled_tgt_full_name}\n" )
+
+    foreach(tgt IN LISTS static_libs)
+      file(APPEND ${CMAKE_BINARY_DIR}/${bundled_tgt_name}.ar.in
+        "ADDLIB $<TARGET_FILE:${tgt}>\n")
+    endforeach()
+
+    file(APPEND ${CMAKE_BINARY_DIR}/${bundled_tgt_name}.ar.in "SAVE\n")
+    file(APPEND ${CMAKE_BINARY_DIR}/${bundled_tgt_name}.ar.in "END\n")
+
+    file(GENERATE
+      OUTPUT ${CMAKE_BINARY_DIR}/${bundled_tgt_name}.ar
+      INPUT ${CMAKE_BINARY_DIR}/${bundled_tgt_name}.ar.in)
+
+    set(ar_tool ${CMAKE_AR})
+    if (CMAKE_INTERPROCEDURAL_OPTIMIZATION)
+      set(ar_tool ${CMAKE_CXX_COMPILER_AR})
+    endif()
+
+    add_custom_command(
+      COMMAND ${ar_tool} -M < ${CMAKE_BINARY_DIR}/${bundled_tgt_name}.ar
+      OUTPUT ${bundled_tgt_full_name}
+      COMMENT "Bundling ${bundled_tgt_name}"
+      VERBATIM)
+  else()
+    foreach(lib ${static_libs})
+      set(libfiles ${libfiles} $<TARGET_FILE:${lib}>)
+    endforeach()
+    add_custom_command(
+      COMMAND /usr/bin/libtool -static -o ${bundled_tgt_full_name} ${libfiles}
+      OUTPUT ${bundled_tgt_full_name}
+      )
+  endif()
+
+  add_custom_target(${fake_target} ALL DEPENDS ${bundled_tgt_full_name})
+  add_dependencies(${fake_target} ${tgt_name})
+
+  add_library(${bundled_tgt_name} STATIC IMPORTED)
+  set_target_properties(${bundled_tgt_name}
+    PROPERTIES
+      IMPORTED_LOCATION ${bundled_tgt_full_name}
+      INTERFACE_INCLUDE_DIRECTORIES $<TARGET_PROPERTY:${tgt_name},INTERFACE_INCLUDE_DIRECTORIES>)
+  add_dependencies(${bundled_tgt_name} ${fake_target})
+
+endfunction()
diff --git a/cmake/make_resource.py b/cmake/make_resource.py
new file mode 100644
index 00000000000..09a2ca877dd
--- /dev/null
+++ b/cmake/make_resource.py
@@ -0,0 +1,25 @@
+# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import re
+import sys
+
+res = sys.argv[1]
+out = sys.argv[2]
+var = re.sub(r'[ .-]', '_', os.path.basename(res))
+
+open(out, "w").write("const unsigned char " + var + "[] = {" + ",".join([
+    "0x%02x" % ord(c) for c in open(res).read()
+]) + ",0};\n" + "const unsigned " + var + "_size = sizeof(" + var + ");\n")
diff --git a/cmake/operators.cmake b/cmake/operators.cmake
new file mode 100644
index 00000000000..c17e718f427
--- /dev/null
+++ b/cmake/operators.cmake
@@ -0,0 +1,227 @@
+set(PART_CUDA_KERNEL_FILES)
+function(op_library TARGET)
+  # op_library is a function to create an op library. Its interface is the
+  # same as cc_library's, but it additionally handles splitting GPU/CPU code
+  # and links the common libraries for ops.
+  set(cc_srcs)
+  set(cu_srcs)
+  set(hip_cu_srcs)
+  set(miopen_hip_cc_srcs)
+  set(cu_cc_srcs)
+  set(cudnn_cu_cc_srcs)
+  set(CUDNN_FILE)
+  set(mkldnn_cc_srcs)
+  set(MKLDNN_FILE)
+  set(op_common_deps operator op_registry math_function)
+  set(options "")
+  set(oneValueArgs "")
+  set(multiValueArgs SRCS DEPS)
+  set(pybind_flag 0)
+  cmake_parse_arguments(op_library "${options}" "${oneValueArgs}"
+    "${multiValueArgs}" ${ARGN})
+
+  list(LENGTH op_library_SRCS op_library_SRCS_len)
+  if (${op_library_SRCS_len} EQUAL 0)
+    if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${TARGET}.cc)
+      list(APPEND cc_srcs ${TARGET}.cc)
+    endif()
+    if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${TARGET}.cu.cc)
+      list(APPEND cu_cc_srcs ${TARGET}.cu.cc)
+    endif()
+    if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${TARGET}.cu)
+      list(APPEND cu_srcs ${TARGET}.cu)
+    endif()
+    if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${TARGET}.part.cu)
+      set(PART_CUDA_KERNEL_FILES ${CMAKE_CURRENT_SOURCE_DIR}/${TARGET}.part.cu
+        ${PART_CUDA_KERNEL_FILES} PARENT_SCOPE)
+      list(APPEND cu_srcs ${CMAKE_CURRENT_SOURCE_DIR}/${TARGET}.part.cu)
+    endif()
+
+    if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${TARGET}.hip.cu)
+      list(APPEND hip_cu_srcs ${TARGET}.hip.cu)
+    endif()
+    string(REPLACE "_op" "_cudnn_op" CUDNN_FILE "${TARGET}")
+    if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${CUDNN_FILE}.cu.cc)
+      list(APPEND cudnn_cu_cc_srcs ${CUDNN_FILE}.cu.cc)
+    endif()
+    if(WITH_AMD_GPU)
+      string(REPLACE "_op" "_miopen_op" MIOPEN_FILE "${TARGET}")
+      if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${MIOPEN_FILE}.hip.cc)
+        list(APPEND miopen_hip_cc_srcs ${MIOPEN_FILE}.hip.cc)
+      endif()
+    endif()
+    if(WITH_MKLDNN)
+      string(REPLACE "_op" "_mkldnn_op" MKLDNN_FILE "${TARGET}")
+      if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/mkldnn/${MKLDNN_FILE}.cc)
+        list(APPEND mkldnn_cc_srcs mkldnn/${MKLDNN_FILE}.cc)
+      endif()
+    endif()
+  else()
+    foreach(src ${op_library_SRCS})
+      if (${src} MATCHES ".*\\.hip.cu$")
+        list(APPEND hip_cu_srcs ${src})
+      elseif (${src} MATCHES ".*\\.cu$")
+        list(APPEND cu_srcs ${src})
+      elseif(${src} MATCHES ".*_cudnn_op.cu.cc$")
+        list(APPEND cudnn_cu_cc_srcs ${src})
+      elseif(WITH_AMD_GPU AND ${src} MATCHES ".*_miopen_op.hip.cc$")
+        list(APPEND miopen_hip_cc_srcs ${src})
+      elseif(WITH_MKLDNN AND ${src} MATCHES ".*_mkldnn_op.cc$")
+        list(APPEND mkldnn_cc_srcs ${src})
+      elseif(${src} MATCHES ".*\\.cu.cc$")
+        list(APPEND cu_cc_srcs ${src})
+      elseif(${src} MATCHES ".*\\.cc$")
+        list(APPEND cc_srcs ${src})
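+      # anything that is not a recognized .cc/.cu variant is rejected below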
+      else()
+        message(FATAL_ERROR "${TARGET} Source file ${src} should only be .cc or .cu")
+      endif()
+    endforeach()
+  endif()
+
+  list(LENGTH cc_srcs cc_srcs_len)
+  if (${cc_srcs_len} EQUAL 0)
+    message(FATAL_ERROR "The op library ${TARGET} should contain at least one .cc file")
+  endif()
+  if (WIN32)
+    # Skip ops that are unsupported on Windows; it has no nccl, warpctc, or similar libraries.
+    foreach(windows_unsupport_op "nccl_op" "gen_nccl_id_op")
+      if ("${TARGET}" STREQUAL "${windows_unsupport_op}")
+        return()
+      endif()
+    endforeach()
+  endif(WIN32)
+  set(OP_LIBRARY ${TARGET} ${OP_LIBRARY} CACHE INTERNAL "op libs")
+
+  list(LENGTH op_library_DEPS op_library_DEPS_len)
+  if (${op_library_DEPS_len} GREATER 0)
+    set(DEPS_OPS ${TARGET} ${DEPS_OPS} PARENT_SCOPE)
+  endif()
+  if (WITH_GPU)
+    nv_library(${TARGET} SRCS ${cc_srcs} ${cu_cc_srcs} ${cudnn_cu_cc_srcs} ${mkldnn_cc_srcs} ${cu_srcs} DEPS ${op_library_DEPS}
+      ${op_common_deps})
+  elseif (WITH_AMD_GPU)
+    hip_library(${TARGET} SRCS ${cc_srcs} ${hip_cu_srcs} ${miopen_hip_cc_srcs} ${mkldnn_cc_srcs} DEPS ${op_library_DEPS}
+      ${op_common_deps})
+  else()
+    cc_library(${TARGET} SRCS ${cc_srcs} ${mkldnn_cc_srcs} DEPS ${op_library_DEPS}
+      ${op_common_deps})
+  endif()
+
+  # Define operators that don't need pybind here.
+  foreach(manual_pybind_op "compare_op" "logical_op" "nccl_op"
+"tensor_array_read_write_op" "tensorrt_engine_op" "conv_fusion_op"
+"fusion_transpose_flatten_concat_op" "fusion_conv_inception_op" "sync_batch_norm_op" "dgc_op")
+    if ("${TARGET}" STREQUAL "${manual_pybind_op}")
+      set(pybind_flag 1)
+    endif()
+  endforeach()
+
+  # For the registration of USE_OP, please refer to paddle/fluid/framework/op_registry.h.
+  # Note that it's enough to add just one operator to pybind in a *_op.cc file.
+  # For detailed pybind information, please see the generated paddle/pybind/pybind.h.
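+  # As an illustration (editor's sketch, not part of the build): for a
+  # hypothetical mul_op.cc containing
+  #
+  #   REGISTER_OPERATOR(mul, ...);
+  #
+  # the scraping below recovers the registered op name ("mul") and appends a
+  # matching USE_OP / USE_NO_KERNEL_OP / USE_CPU_ONLY_OP line to ${pybind_file},
+  # so pybind picks the operator up automatically.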
+  file(READ ${TARGET}.cc TARGET_CONTENT)
+  string(REGEX MATCH "REGISTER_OPERATOR\\(.*REGISTER_OPERATOR\\(" multi_register "${TARGET_CONTENT}")
+  string(REGEX MATCH "REGISTER_OPERATOR\\([a-z0-9_]*," one_register "${multi_register}")
+  if (one_register STREQUAL "")
+    string(REPLACE "_op" "" TARGET "${TARGET}")
+  else ()
+    string(REPLACE "REGISTER_OPERATOR(" "" TARGET "${one_register}")
+    string(REPLACE "," "" TARGET "${TARGET}")
+  endif()
+
+  # pybind USE_NO_KERNEL_OP
+  # HACK: if REGISTER_OP_CPU_KERNEL is present, the operator must have a kernel
+  string(REGEX MATCH "REGISTER_OP_CPU_KERNEL" regex_result "${TARGET_CONTENT}")
+  string(REPLACE "_op" "" TARGET "${TARGET}")
+  if (${pybind_flag} EQUAL 0 AND regex_result STREQUAL "")
+    file(APPEND ${pybind_file} "USE_NO_KERNEL_OP(${TARGET});\n")
+    set(pybind_flag 1)
+  endif()
+
+  # pybind USE_CPU_ONLY_OP
+  list(LENGTH cu_srcs cu_srcs_len)
+  list(LENGTH cu_cc_srcs cu_cc_srcs_len)
+  list(LENGTH mkldnn_cc_srcs mkldnn_cc_srcs_len)
+  list(LENGTH hip_cu_srcs hip_cu_srcs_len)
+  list(LENGTH miopen_hip_cc_srcs miopen_hip_cc_srcs_len)
+  if (${pybind_flag} EQUAL 0 AND ${mkldnn_cc_srcs_len} EQUAL 0 AND ${cu_srcs_len} EQUAL 0 AND ${cu_cc_srcs_len} EQUAL 0 AND
+      ${hip_cu_srcs_len} EQUAL 0 AND ${miopen_hip_cc_srcs_len} EQUAL 0)
+    file(APPEND ${pybind_file} "USE_CPU_ONLY_OP(${TARGET});\n")
+    set(pybind_flag 1)
+  endif()
+
+  # pybind USE_OP_DEVICE_KERNEL for CUDNN
+  list(LENGTH cudnn_cu_cc_srcs cudnn_cu_cc_srcs_len)
+  if (WITH_GPU AND ${cudnn_cu_cc_srcs_len} GREATER 0)
+    if(${TARGET} STREQUAL "activation")
+      file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(relu, CUDNN);\n")
+    else()
+      file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(${TARGET}, CUDNN);\n")
+    endif()
+  endif()
+
+  # pybind USE_OP_DEVICE_KERNEL for MIOPEN
+  if (WITH_AMD_GPU AND ${miopen_hip_cc_srcs_len} GREATER 0)
+    file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(${TARGET}, MIOPEN);\n")
+  endif()
+
+  # pybind USE_OP_DEVICE_KERNEL for MKLDNN
+  if (WITH_MKLDNN AND ${mkldnn_cc_srcs_len} GREATER 0)
+    # Append the first implemented MKLDNN activation operator
+    if (${MKLDNN_FILE} STREQUAL "activation_mkldnn_op")
+      file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(relu, MKLDNN);\n")
+    elseif(${MKLDNN_FILE} STREQUAL "conv_mkldnn_op")
+      file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL_WITH_CUSTOM_TYPE(conv2d, MKLDNN, FP32);\n")
+      file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL_WITH_CUSTOM_TYPE(conv2d, MKLDNN, S8);\n")
+      file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL_WITH_CUSTOM_TYPE(conv2d, MKLDNN, U8);\n")
+
+    else()
+      file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(${TARGET}, MKLDNN);\n")
+    endif()
+  endif()
+
+  # pybind USE_OP
+  if (${pybind_flag} EQUAL 0)
+    # NOTE(*): activation uses macros to register its kernels, so set use_op manually.
+    if(${TARGET} STREQUAL "activation")
+      file(APPEND ${pybind_file} "USE_OP(relu);\n")
+    elseif(${TARGET} STREQUAL "fake_dequantize")
+      file(APPEND ${pybind_file} "USE_OP(fake_dequantize_max_abs);\n")
+    elseif(${TARGET} STREQUAL "fake_quantize")
+      file(APPEND ${pybind_file} "USE_OP(fake_quantize_abs_max);\n")
+    elseif(${TARGET} STREQUAL "tensorrt_engine_op")
+      message(STATUS "Pybind skips [tensorrt_engine_op], since this OP is only used in inference")
+    elseif(${TARGET} STREQUAL "fc")
+      # HACK: fc only has mkldnn and cpu kernels, which would not match the cpu-only condition above
+      file(APPEND ${pybind_file} "USE_CPU_ONLY_OP(${TARGET});\n")
+    else()
+      file(APPEND ${pybind_file} "USE_OP(${TARGET});\n")
+    endif()
+  endif()
+endfunction()
+
+
+function(register_operators)
+  set(options "")
+  set(oneValueArgs "")
+  set(multiValueArgs EXCLUDES DEPS)
+  cmake_parse_arguments(register_operators "${options}" "${oneValueArgs}"
+    "${multiValueArgs}" ${ARGN})
+
+  file(GLOB OPS RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "*_op.cc")
+  string(REPLACE "_mkldnn" "" OPS "${OPS}")
+  string(REPLACE ".cc" "" OPS "${OPS}")
+  list(REMOVE_DUPLICATES OPS)
+  list(LENGTH register_operators_DEPS register_operators_DEPS_len)
+
+  foreach(src ${OPS})
+    list(FIND register_operators_EXCLUDES ${src} _index)
+    if (${_index} EQUAL -1)
+      if (${register_operators_DEPS_len} GREATER 0)
+        op_library(${src} DEPS ${register_operators_DEPS})
+      else()
+        op_library(${src})
+      endif()
+    endif()
+  endforeach()
+endfunction()
diff --git a/cmake/package.cmake b/cmake/package.cmake
new file mode 100644
index 00000000000..79e02147f3f
--- /dev/null
+++ b/cmake/package.cmake
@@ -0,0 +1,21 @@
+set(CPACK_PACKAGE_NAME paddle)
+set(CPACK_PACKAGE_VERSION_MAJOR ${PADDLE_MAJOR_VERSION})
+set(CPACK_PACKAGE_VERSION_MINOR ${PADDLE_MINOR_VERSION})
+set(CPACK_PACKAGE_VERSION_PATCH ${PADDLE_PATCH_VERSION})
+set(CPACK_PACKAGE_VERSION ${PADDLE_VERSION})
+## DEB Settings
+set(CPACK_DEBIAN_PACKAGE_NAME paddle)
+set(CPACK_DEBIAN_PACKAGE_ARCHITECTURE amd64)
+set(CPACK_DEBIAN_PACKAGE_MAINTAINER PaddlePaddle Dev )
+set(CPACK_PACKAGE_DESCRIPTION_SUMMARY "Paddle")
+set(CPACK_PACKAGE_DESCRIPTION "")
+set(CPACK_DEBIAN_PACKAGE_DEPENDS "libpython2.7-dev, libstdc++6, python-pip, curl, libgfortran3, python-pip-whl")
+set(CPACK_DEBIAN_PACKAGE_SECTION Devel)
+set(CPACK_DEBIAN_PACKAGE_VERSION ${PADDLE_VERSION})
+set(CPACK_DEBIAN_PACKAGE_CONTROL_EXTRA "${PADDLE_SOURCE_DIR}/paddle/scripts/deb/postinst")
+#set(CPACK_GENERATOR "DEB")
+# Start cpack
+include (CMakePackageConfigHelpers)
+include (CPack)
+
+
diff --git a/cmake/simd.cmake b/cmake/simd.cmake
new file mode 100644
index 00000000000..566dc75fda0
--- /dev/null
+++ b/cmake/simd.cmake
@@ -0,0 +1,99 @@
+# This file is used to check every supported SIMD/AVX level on your machine,
+# so that PaddlePaddle can unleash the vectorization power of multicore CPUs.
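+#
+# Each check below compiles and runs a tiny intrinsics snippet with the matching
+# -m flag; consumers can then branch on the cached result, e.g. (illustrative
+# only, SIMD_FLAG is a hypothetical variable):
+#
+#   if(AVX_FOUND)
+#     set(SIMD_FLAG ${AVX_FLAG})
+#   endif()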
+
+include(CheckCXXSourceRuns)
+include(CheckCXXSourceCompiles)
+
+if(CMAKE_COMPILER_IS_GNUCC OR CMAKE_COMPILER_IS_GNUCXX OR CMAKE_CXX_COMPILER_ID MATCHES "Clang")
+  set(MMX_FLAG "-mmmx")
+  set(SSE2_FLAG "-msse2")
+  set(SSE3_FLAG "-msse3")
+  set(AVX_FLAG "-mavx")
+  set(AVX2_FLAG "-mavx2")
+  set(AVX512F_FLAG "-mavx512f")
+elseif(MSVC)
+  set(MMX_FLAG "/arch:MMX")
+  set(SSE2_FLAG "/arch:SSE2")
+  set(SSE3_FLAG "/arch:SSE3")
+  SET(AVX_FLAG "/arch:AVX")
+  SET(AVX2_FLAG "/arch:AVX2")
+endif()
+
+set(CMAKE_REQUIRED_FLAGS_RETAINED ${CMAKE_REQUIRED_FLAGS})
+
+# Check MMX
+set(CMAKE_REQUIRED_FLAGS ${MMX_FLAG})
+set(MMX_FOUND_EXITCODE 1 CACHE STRING "Result from TRY_RUN" FORCE)
+CHECK_CXX_SOURCE_RUNS("
+#include <mmintrin.h>
+int main()
+{
+    _mm_setzero_si64();
+    return 0;
+}" MMX_FOUND)
+
+# Check SSE2
+set(CMAKE_REQUIRED_FLAGS ${SSE2_FLAG})
+set(SSE2_FOUND_EXITCODE 1 CACHE STRING "Result from TRY_RUN" FORCE)
+CHECK_CXX_SOURCE_RUNS("
+#include <emmintrin.h>
+int main()
+{
+    _mm_setzero_si128();
+    return 0;
+}" SSE2_FOUND)
+
+# Check SSE3
+set(CMAKE_REQUIRED_FLAGS ${SSE3_FLAG})
+set(SSE3_FOUND_EXITCODE 1 CACHE STRING "Result from TRY_RUN" FORCE)
+CHECK_CXX_SOURCE_RUNS("
+#include <pmmintrin.h>
+int main()
+{
+    __m128d a = _mm_set1_pd(6.28);
+    __m128d b = _mm_set1_pd(3.14);
+    __m128d result = _mm_addsub_pd(a, b);
+    result = _mm_movedup_pd(result);
+    return 0;
+}" SSE3_FOUND)
+
+# Check AVX
+set(CMAKE_REQUIRED_FLAGS ${AVX_FLAG})
+set(AVX_FOUND_EXITCODE 1 CACHE STRING "Result from TRY_RUN" FORCE)
+CHECK_CXX_SOURCE_RUNS("
+#include <immintrin.h>
+int main()
+{
+    __m256 a = _mm256_set_ps (-1.0f, 2.0f, -3.0f, 4.0f, -1.0f, 2.0f, -3.0f, 4.0f);
+    __m256 b = _mm256_set_ps (1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f);
+    __m256 result = _mm256_add_ps (a, b);
+    return 0;
+}" AVX_FOUND)
+
+# Check AVX2
+set(CMAKE_REQUIRED_FLAGS ${AVX2_FLAG})
+set(AVX2_FOUND_EXITCODE 1 CACHE STRING "Result from TRY_RUN" FORCE)
+CHECK_CXX_SOURCE_RUNS("
+#include <immintrin.h>
+int main()
+{
+    __m256i a = _mm256_set_epi32 (-1, 2, -3, 4, -1, 2, -3, 4);
+    __m256i result = _mm256_abs_epi32 (a);
+    return 0;
+}" AVX2_FOUND)
+
+# Check AVX512F
+set(CMAKE_REQUIRED_FLAGS ${AVX512F_FLAG})
+set(AVX512F_FOUND_EXITCODE 1 CACHE STRING "Result from TRY_RUN" FORCE)
+CHECK_CXX_SOURCE_RUNS("
+#include <immintrin.h>
+int main()
+{
+    __m512i a = _mm512_set_epi32 (-1, 2, -3, 4, -1, 2, -3, 4,
+                                  13, -5, 6, -7, 9, 2, -6, 3);
+    __m512i result = _mm512_abs_epi32 (a);
+    return 0;
+}" AVX512F_FOUND)
+
+set(CMAKE_REQUIRED_FLAGS ${CMAKE_REQUIRED_FLAGS_RETAINED})
+mark_as_advanced(MMX_FOUND SSE2_FOUND SSE3_FOUND AVX_FOUND AVX2_FOUND AVX512F_FOUND)
diff --git a/cmake/system.cmake b/cmake/system.cmake
new file mode 100644
index 00000000000..65db05bebe9
--- /dev/null
+++ b/cmake/system.cmake
@@ -0,0 +1,85 @@
+# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Detects the OS and sets appropriate variables.
+# CMAKE_SYSTEM_NAME only gives us a coarse-grained name of the OS that CMake
+# is building for, but the concrete distribution name, like centos, is
+# necessary in some scenarios to distinguish the system for customization.
+#
+# For instance, the protobuf libs path is /lib64
+# on CentOS, but /lib on other systems.
+
+IF(WIN32)
+  SET(HOST_SYSTEM "win32")
+ELSE(WIN32)
+  IF(APPLE)
+    SET(HOST_SYSTEM "macosx")
+    EXEC_PROGRAM(sw_vers ARGS -productVersion OUTPUT_VARIABLE HOST_SYSTEM_VERSION)
+    STRING(REGEX MATCH "[0-9]+.[0-9]+" MACOS_VERSION "${HOST_SYSTEM_VERSION}")
+    IF(NOT DEFINED ENV{MACOSX_DEPLOYMENT_TARGET})
+      # Set cache variable - end user may change this during ccmake or cmake-gui configure.
+      SET(CMAKE_OSX_DEPLOYMENT_TARGET ${MACOS_VERSION} CACHE STRING
+        "Minimum OS X version to target for deployment (at runtime); newer APIs weak linked. Set to empty string for default value.")
+    ENDIF()
+    set(CMAKE_EXE_LINKER_FLAGS "-framework CoreFoundation -framework Security")
+  ELSE(APPLE)
+
+    IF(EXISTS "/etc/issue")
+      FILE(READ "/etc/issue" LINUX_ISSUE)
+      IF(LINUX_ISSUE MATCHES "CentOS")
+        SET(HOST_SYSTEM "centos")
+      ELSEIF(LINUX_ISSUE MATCHES "Debian")
+        SET(HOST_SYSTEM "debian")
+      ELSEIF(LINUX_ISSUE MATCHES "Ubuntu")
+        SET(HOST_SYSTEM "ubuntu")
+      ELSEIF(LINUX_ISSUE MATCHES "Red Hat")
+        SET(HOST_SYSTEM "redhat")
+      ELSEIF(LINUX_ISSUE MATCHES "Fedora")
+        SET(HOST_SYSTEM "fedora")
+      ENDIF()
+
+      STRING(REGEX MATCH "(([0-9]+)\\.)+([0-9]+)" HOST_SYSTEM_VERSION "${LINUX_ISSUE}")
+    ENDIF(EXISTS "/etc/issue")
+
+    IF(EXISTS "/etc/redhat-release")
+      FILE(READ "/etc/redhat-release" LINUX_ISSUE)
+      IF(LINUX_ISSUE MATCHES "CentOS")
+        SET(HOST_SYSTEM "centos")
+      ENDIF()
+    ENDIF(EXISTS "/etc/redhat-release")
+
+    IF(NOT HOST_SYSTEM)
+      SET(HOST_SYSTEM ${CMAKE_SYSTEM_NAME})
+    ENDIF()
+
+  ENDIF(APPLE)
+ENDIF(WIN32)
+
+# query the number of logical cores
+CMAKE_HOST_SYSTEM_INFORMATION(RESULT CPU_CORES QUERY NUMBER_OF_LOGICAL_CORES)
+
+MARK_AS_ADVANCED(HOST_SYSTEM CPU_CORES)
+
+MESSAGE(STATUS "Found Paddle host system: ${HOST_SYSTEM}, version: ${HOST_SYSTEM_VERSION}")
+MESSAGE(STATUS "Found Paddle host system's CPU: ${CPU_CORES} cores")
+
+# external dependencies log output
+SET(EXTERNAL_PROJECT_LOG_ARGS
+    LOG_DOWNLOAD 0 # Wrap download in script to log output
+    LOG_UPDATE 1 # Wrap update in script to log output
+    LOG_CONFIGURE 1 # Wrap configure in script to log output
+    LOG_BUILD 0 # Wrap build in script to log output
+    LOG_TEST 1 # Wrap test in script to log output
+    LOG_INSTALL 0 # Wrap install in script to log output
+)
diff --git a/cmake/tensorrt.cmake b/cmake/tensorrt.cmake
new file mode 100644
index 00000000000..3bf12094e4c
--- /dev/null
+++ b/cmake/tensorrt.cmake
@@ -0,0 +1,38 @@
+if(NOT WITH_GPU)
+  return()
+endif()
+
+set(TENSORRT_ROOT "/usr" CACHE PATH "TENSORRT ROOT")
+find_path(TENSORRT_INCLUDE_DIR NvInfer.h
+  PATHS ${TENSORRT_ROOT} ${TENSORRT_ROOT}/include
+  $ENV{TENSORRT_ROOT} $ENV{TENSORRT_ROOT}/include
+  NO_DEFAULT_PATH
+)
+
+find_library(TENSORRT_LIBRARY NAMES libnvinfer.so libnvinfer.a
+  PATHS ${TENSORRT_ROOT} ${TENSORRT_ROOT}/lib
+  $ENV{TENSORRT_ROOT} $ENV{TENSORRT_ROOT}/lib
+  NO_DEFAULT_PATH
+  DOC "Path to TensorRT library.")
+
+if(TENSORRT_INCLUDE_DIR AND TENSORRT_LIBRARY)
+  if(WITH_DSO)
+    set(TENSORRT_FOUND ON)
+  endif(WITH_DSO)
+else()
+  set(TENSORRT_FOUND OFF)
+endif()
+
+if(TENSORRT_FOUND)
+  file(READ ${TENSORRT_INCLUDE_DIR}/NvInfer.h TENSORRT_VERSION_FILE_CONTENTS)
+  string(REGEX MATCH "define NV_TENSORRT_MAJOR +([0-9]+)" TENSORRT_MAJOR_VERSION
+    "${TENSORRT_VERSION_FILE_CONTENTS}")
+  string(REGEX REPLACE "define NV_TENSORRT_MAJOR +([0-9]+)" "\\1"
+    TENSORRT_MAJOR_VERSION "${TENSORRT_MAJOR_VERSION}")
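+  # NV_TENSORRT_MAJOR is scraped from NvInfer.h, e.g. from a line such as
+  # "#define NV_TENSORRT_MAJOR 5" (the concrete value depends on the install).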
+
+  message(STATUS "Current TensorRT header is ${TENSORRT_INCLUDE_DIR}/NvInfer.h. "
+    "Current TensorRT version is v${TENSORRT_MAJOR_VERSION}. ")
+  include_directories(${TENSORRT_INCLUDE_DIR})
+  link_directories(${TENSORRT_LIBRARY})
+  add_definitions(-DPADDLE_WITH_TENSORRT)
+endif()
diff --git a/cmake/util.cmake b/cmake/util.cmake
new file mode 100644
index 00000000000..02667dbce69
--- /dev/null
+++ b/cmake/util.cmake
@@ -0,0 +1,55 @@
+# Some common routines for compiling paddle.
+
+# target_circle_link_libraries
+#   Link libraries to a target that has circular dependencies.
+#
+#   First argument: the target to link the libraries into
+#   Rest arguments: the libraries to link together
+function(target_circle_link_libraries TARGET_NAME)
+  if(APPLE)
+    set(LIBS)
+    set(inArchive OFF)
+    set(libsInArgn)
+
+    foreach(arg ${ARGN})
+      if(${arg} STREQUAL "ARCHIVE_START")
+        set(inArchive ON)
+      elseif(${arg} STREQUAL "ARCHIVE_END")
+        set(inArchive OFF)
+      else()
+        if(inArchive)
+          list(APPEND LIBS "-Wl,-force_load")
+        endif()
+        list(APPEND LIBS ${arg})
+        list(APPEND libsInArgn ${arg})
+      endif()
+    endforeach()
+    if("${CMAKE_CXX_COMPILER_ID}" STREQUAL "Clang" OR "${CMAKE_CXX_COMPILER_ID}" STREQUAL "AppleClang")
+      if(NOT IOS_ENABLE_BITCODE)
+        list(APPEND LIBS "-undefined dynamic_lookup")
+      endif()
+    endif()
+    list(REVERSE libsInArgn)
+    target_link_libraries(${TARGET_NAME}
+      ${LIBS}
+      ${libsInArgn})
+
+  else() # LINUX
+    set(LIBS)
+
+    foreach(arg ${ARGN})
+      if(${arg} STREQUAL "ARCHIVE_START")
+        list(APPEND LIBS "-Wl,--whole-archive")
+      elseif(${arg} STREQUAL "ARCHIVE_END")
+        list(APPEND LIBS "-Wl,--no-whole-archive")
+      else()
+        list(APPEND LIBS ${arg})
+      endif()
+    endforeach()
+
+    target_link_libraries(${TARGET_NAME}
+      "-Wl,--start-group"
+      ${LIBS}
+      "-Wl,--end-group")
+  endif()
+endfunction()
diff --git a/cmake/version.cmake b/cmake/version.cmake
new file mode 100644
index 00000000000..8bcc4ffe725
--- /dev/null
+++ b/cmake/version.cmake
@@ -0,0 +1,66 @@
+# Get the latest git tag.
+set(PADDLE_VERSION $ENV{PADDLE_VERSION})
+set(tmp_version "HEAD")
+set(TAG_VERSION_REGEX "[0-9]+\\.[0-9]+\\.[0-9]+(\\.(a|b|rc)\\.[0-9]+)?")
+set(COMMIT_VERSION_REGEX "[0-9a-f]+[0-9a-f]+[0-9a-f]+[0-9a-f]+[0-9a-f]+")
+# set(LATEST_PADDLE_VERSION "latest")
+set(LATEST_PADDLE_VERSION "0.0.0")
+
+while ("${PADDLE_VERSION}" STREQUAL "")
+  # Check the current branch name
+  execute_process(
+    COMMAND ${GIT_EXECUTABLE} rev-parse --abbrev-ref ${tmp_version}
+    WORKING_DIRECTORY ${PADDLE_SOURCE_DIR}
+    OUTPUT_VARIABLE GIT_BRANCH_NAME
+    RESULT_VARIABLE GIT_BRANCH_RESULT
+    ERROR_QUIET OUTPUT_STRIP_TRAILING_WHITESPACE)
+  if (NOT ${GIT_BRANCH_RESULT})
+    execute_process(
+      COMMAND ${GIT_EXECUTABLE} describe --tags --abbrev=0 --always ${tmp_version}
+      WORKING_DIRECTORY ${PADDLE_SOURCE_DIR}
+      OUTPUT_VARIABLE GIT_TAG_NAME
+      RESULT_VARIABLE GIT_RESULT
+      ERROR_QUIET OUTPUT_STRIP_TRAILING_WHITESPACE)
+    if (NOT ${GIT_RESULT})
+      # Check whether the current branch is a release branch
+      if (${GIT_BRANCH_NAME} MATCHES "release/${TAG_VERSION_REGEX}")
+        # Check that the tag is a correct version
+        if (${GIT_TAG_NAME} MATCHES "${COMMIT_VERSION_REGEX}")
+          # 'git describe --always' returned a bare commit hash, i.e. no tag
+          # was found, so fall back to the default version
+          set(PADDLE_VERSION "${LATEST_PADDLE_VERSION}")
+        elseif (${GIT_TAG_NAME} MATCHES "v${TAG_VERSION_REGEX}")
+          string(REPLACE "v" "" PADDLE_VERSION ${GIT_TAG_NAME})
+        else() # otherwise, get the previous git tag name.
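+          # e.g. a tag "v1.5.0" becomes "v1.5.0~1": walk one tag back and
+          # retry from the top of the while loop.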
+ set(tmp_version "${GIT_TAG_NAME}~1") + endif() + else() + execute_process( + COMMAND ${GIT_EXECUTABLE} describe --exact-match --tags ${tmp_version} + WORKING_DIRECTORY ${PADDLE_SOURCE_DIR} + OUTPUT_VARIABLE GIT_EXACT_TAG_NAME + RESULT_VARIABLE GIT_EXACT_TAG_RESULT + ERROR_QUIET OUTPUT_STRIP_TRAILING_WHITESPACE) + if (NOT ${GIT_EXACT_TAG_NAME}) + # Check if current branch is tag branch + if (${GIT_EXACT_TAG_NAME} MATCHES "v${TAG_VERSION_REGEX}") + string(REPLACE "v" "" PADDLE_VERSION ${GIT_EXACT_TAG_NAME}) + else() + set(PADDLE_VERSION "${LATEST_PADDLE_VERSION}") + endif() + else() + # otherwise, we always set PADDLE_VERSION to "latest" + set(PADDLE_VERSION "${LATEST_PADDLE_VERSION}") + endif() + endif() + else() + set(PADDLE_VERSION "${LATEST_PADDLE_VERSION}") + message(WARNING "Cannot add paddle version from git tag") + endif() + else() + set(PADDLE_VERSION "${LATEST_PADDLE_VERSION}") + message(WARNING "Cannot add paddle version for wrong git branch result") + endif() +endwhile() + +add_definitions(-DPADDLE_VERSION=${PADDLE_VERSION}) +message(STATUS "Paddle version is ${PADDLE_VERSION}") diff --git a/lite/CMakeLists.txt b/lite/CMakeLists.txt new file mode 100644 index 00000000000..fcf3a8dcdf6 --- /dev/null +++ b/lite/CMakeLists.txt @@ -0,0 +1,358 @@ +include(lite) + +message(WARNING "Lite enabled!") +message(STATUS "LIGHT_FRAMEWORK:\t${LITE_WITH_LIGHT_WEIGHT_FRAMEWORK}") +message(STATUS "LITE_WITH_CUDA:\t${LITE_WITH_CUDA}") +message(STATUS "LITE_WITH_X86:\t${LITE_WITH_X86}") +message(STATUS "LITE_WITH_ARM:\t${LITE_WITH_ARM}") +message(STATUS "LITE_WITH_NPU:\t${LITE_WITH_NPU}") +message(STATUS "LITE_WITH_FPGA:\t${LITE_WITH_FPGA}") +message(STATUS "LITE_WITH_PROFILE:\t${LITE_WITH_PROFILE}") + +set(LITE_MODEL_DIR "${THIRD_PARTY_PATH}/install") + +set(LITE_ON_MOBILE ${LITE_WITH_LIGHT_WEIGHT_FRAMEWORK}) + +set(LITE_URL "http://paddle-inference-dist.bj.bcebos.com" CACHE STRING "inference download url") + +function(lite_download_and_uncompress INSTALL_DIR URL FILENAME) + message(STATUS "Download inference test stuff from ${URL}/${FILENAME}") + string(REGEX REPLACE "[-%.]" "_" FILENAME_EX ${FILENAME}) + set(EXTERNAL_PROJECT_NAME "extern_lite_download_${FILENAME_EX}") + set(UNPACK_DIR "${INSTALL_DIR}/src/${EXTERNAL_PROJECT_NAME}") + ExternalProject_Add( + ${EXTERNAL_PROJECT_NAME} + ${EXTERNAL_PROJECT_LOG_ARGS} + PREFIX ${INSTALL_DIR} + DOWNLOAD_COMMAND wget --no-check-certificate -q -O ${INSTALL_DIR}/${FILENAME} ${URL}/${FILENAME} && ${CMAKE_COMMAND} -E tar xzf ${INSTALL_DIR}/${FILENAME} + DOWNLOAD_DIR ${INSTALL_DIR} + DOWNLOAD_NO_PROGRESS 1 + CONFIGURE_COMMAND "" + BUILD_COMMAND "" + UPDATE_COMMAND "" + INSTALL_COMMAND "" + ) +endfunction() + +function (lite_deps TARGET) + set(options "") + set(oneValueArgs "") + set(multiValueArgs DEPS X86_DEPS CUDA_DEPS ARM_DEPS PROFILE_DEPS LIGHT_DEPS HVY_DEPS CL_DEPS FPGA_DEPS NPU_DEPS ARGS) + cmake_parse_arguments(lite_deps "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) + + set(deps ${lite_deps_DEPS}) + + if(LITE_WITH_X86) + foreach(var ${lite_deps_X86_DEPS}) + set(deps ${deps} ${var}) + endforeach(var) + endif() + + if(LITE_WITH_CUDA) + foreach(var ${lite_deps_CUDA_DEPS}) + set(deps ${deps} ${var}) + endforeach(var) + endif() + + if(LITE_WITH_ARM) + foreach(var ${lite_deps_ARM_DEPS}) + set(deps ${deps} ${var}) + endforeach(var) + endif() + + if(LITE_WITH_PROFILE) + foreach(var ${lite_deps_PROFILE_DEPS}) + set(deps ${deps} ${var}) + endforeach(var) + endif() + + if(LITE_WITH_LIGHT_WEIGHT_FRAMEWORK) + foreach(var ${lite_deps_LIGHT_DEPS}) + 
set(deps ${deps} ${var})
+    endforeach(var)
+  endif()
+
+  if (NOT LITE_WITH_LIGHT_WEIGHT_FRAMEWORK)
+    foreach(var ${lite_deps_HVY_DEPS})
+      set(deps ${deps} ${var})
+    endforeach(var)
+  endif()
+
+  if (LITE_WITH_OPENCL)
+    foreach(var ${lite_deps_CL_DEPS})
+      set(deps ${deps} ${var})
+    endforeach(var)
+  endif()
+
+  if (LITE_WITH_FPGA)
+    foreach(var ${lite_deps_FPGA_DEPS})
+      set(deps ${deps} ${var})
+    endforeach(var)
+  endif()
+
+  if (LITE_WITH_NPU)
+    foreach(var ${lite_deps_NPU_DEPS})
+      set(deps ${deps} ${var})
+    endforeach(var)
+  endif()
+
+  set(${TARGET} ${deps} PARENT_SCOPE)
+endfunction()
+
+
+# A fake target to include all the libraries and tests that the lite module depends on.
+add_custom_target(lite_compile_deps COMMAND echo 1)
+
+# Record the names of the lite libraries for later compilation. This name list lets
+# us avoid compiling the whole fluid project, which accelerates the build.
+set(offline_lib_registry_file "${CMAKE_BINARY_DIR}/lite_libs.txt")
+file(WRITE ${offline_lib_registry_file} "") # clean
+
+# cc_library with branch support.
+# The branches:
+#  X86_DEPS: works only when LITE_WITH_X86 is ON.
+#  CUDA_DEPS: LITE_WITH_CUDA
+#  ARM_DEPS: LITE_WITH_ARM
+#  PROFILE_DEPS: LITE_WITH_PROFILE
+#  LIGHT_DEPS: LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
+#  HVY_DEPS: NOT LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
+#  EXCLUDE_COMPILE_DEPS: if set, TARGET will not be included in lite_compile_deps
+function(lite_cc_library TARGET)
+  set(options SHARED shared STATIC static MODULE module)
+  set(oneValueArgs "")
+  set(multiValueArgs SRCS DEPS X86_DEPS CUDA_DEPS CL_DEPS NPU_DEPS ARM_DEPS FPGA_DEPS PROFILE_DEPS LIGHT_DEPS
+    HVY_DEPS EXCLUDE_COMPILE_DEPS ARGS)
+  cmake_parse_arguments(args "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
+
+  set(deps "")
+  lite_deps(deps
+    DEPS ${args_DEPS}
+    X86_DEPS ${args_X86_DEPS}
+    CUDA_DEPS ${args_CUDA_DEPS}
+    CL_DEPS ${args_CL_DEPS}
+    NPU_DEPS ${args_NPU_DEPS}
+    ARM_DEPS ${args_ARM_DEPS}
+    FPGA_DEPS ${args_FPGA_DEPS}
+    PROFILE_DEPS ${args_PROFILE_DEPS}
+    LIGHT_DEPS ${args_LIGHT_DEPS}
+    HVY_DEPS ${args_HVY_DEPS}
+    )
+
+  # NOTE: cmake_parse_arguments is case-sensitive, so test args_shared/args_module
+  # here (ARGS_shared/ARGS_module are never defined).
+  if (args_SHARED OR args_shared)
+    cc_library(${TARGET} SRCS ${args_SRCS} DEPS ${deps} ${args_DEPS} SHARED)
+  elseif (args_MODULE OR args_module)
+    add_library(${TARGET} MODULE ${args_SRCS})
+    add_dependencies(${TARGET} ${deps} ${args_DEPS})
+  else()
+    cc_library(${TARGET} SRCS ${args_SRCS} DEPS ${deps} ${args_DEPS})
+  endif()
+  target_compile_options(${TARGET} BEFORE PRIVATE -Wno-ignored-qualifiers)
+
+  # collect the targets that need to be compiled for lite
+  if (args_SRCS AND NOT args_EXCLUDE_COMPILE_DEPS)
+    add_dependencies(lite_compile_deps ${TARGET})
+  endif()
+
+  # register the library name.
+  file(APPEND ${offline_lib_registry_file} "${TARGET}\n")
+endfunction()
+
+function(lite_cc_binary TARGET)
+  set(options "")
+  set(oneValueArgs "")
+  set(multiValueArgs SRCS DEPS X86_DEPS CUDA_DEPS CL_DEPS ARM_DEPS FPGA_DEPS PROFILE_DEPS
+    LIGHT_DEPS HVY_DEPS EXCLUDE_COMPILE_DEPS ARGS)
+  cmake_parse_arguments(args "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
+
+  set(deps "")
+  lite_deps(deps
+    DEPS ${args_DEPS}
+    X86_DEPS ${args_X86_DEPS}
+    CUDA_DEPS ${args_CUDA_DEPS}
+    CL_DEPS ${args_CL_DEPS}
+    ARM_DEPS ${args_ARM_DEPS}
+    FPGA_DEPS ${args_FPGA_DEPS}
+    PROFILE_DEPS ${args_PROFILE_DEPS}
+    LIGHT_DEPS ${args_LIGHT_DEPS}
+    HVY_DEPS ${args_HVY_DEPS}
+    )
+  cc_binary(${TARGET} SRCS ${args_SRCS} DEPS ${deps} ${args_DEPS})
+  target_compile_options(${TARGET} BEFORE PRIVATE -Wno-ignored-qualifiers)
+  # collect the targets that need to be compiled for lite
+  if (NOT args_EXCLUDE_COMPILE_DEPS)
+    add_dependencies(lite_compile_deps ${TARGET})
+  endif()
+endfunction()
+
+# Add each unit-test name to a file for later offline manual testing.
+set(offline_test_registry_file "${CMAKE_BINARY_DIR}/lite_tests.txt")
+file(WRITE ${offline_test_registry_file} "") # clean
+# Test lite modules.
+
+function(lite_cc_test TARGET)
+  if(NOT WITH_TESTING)
+    return()
+  endif()
+  set(options "")
+  set(oneValueArgs "")
+  set(multiValueArgs SRCS DEPS X86_DEPS CUDA_DEPS CL_DEPS ARM_DEPS FPGA_DEPS PROFILE_DEPS
+    LIGHT_DEPS HVY_DEPS EXCLUDE_COMPILE_DEPS
+    ARGS)
+  cmake_parse_arguments(args "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
+
+  set(deps "")
+  lite_deps(deps
+    DEPS ${args_DEPS}
+    X86_DEPS ${args_X86_DEPS}
+    CUDA_DEPS ${args_CUDA_DEPS}
+    CL_DEPS ${args_CL_DEPS}
+    ARM_DEPS ${args_ARM_DEPS}
+    FPGA_DEPS ${args_FPGA_DEPS}
+    PROFILE_DEPS ${args_PROFILE_DEPS}
+    LIGHT_DEPS ${args_LIGHT_DEPS}
+    HVY_DEPS ${args_HVY_DEPS}
+    )
+  _lite_cc_test(${TARGET} SRCS ${args_SRCS} DEPS ${deps} ARGS ${args_ARGS})
+  target_compile_options(${TARGET} BEFORE PRIVATE -Wno-ignored-qualifiers)
+  file(APPEND ${offline_test_registry_file} "${TARGET}\n")
+
+  # collect the targets that need to be compiled for lite
+  if (NOT args_EXCLUDE_COMPILE_DEPS)
+    add_dependencies(lite_compile_deps ${TARGET})
+  endif()
+endfunction()
+
+add_subdirectory(utils)
+add_subdirectory(operators)
+add_subdirectory(kernels)
+add_subdirectory(npu)
+add_subdirectory(core)
+add_subdirectory(x86)
+add_subdirectory(arm)
+add_subdirectory(host)
+add_subdirectory(cuda)
+add_subdirectory(opencl)
+add_subdirectory(fpga)
+add_subdirectory(model_parser)
+add_subdirectory(api)
+add_subdirectory(fluid)
+if (NOT LITE_ON_TINY_PUBLISH)
+  add_subdirectory(tests)
+  add_subdirectory(tools)
+endif()
+if (LITE_WITH_LIGHT_WEIGHT_FRAMEWORK AND NOT LITE_ON_TINY_PUBLISH)
+  add_subdirectory(gen_code)
+endif()
+
+if (WITH_TESTING)
+  lite_download_and_uncompress(${LITE_MODEL_DIR} ${LITE_URL} "lite_naive_model.tar.gz")
+  if(LITE_WITH_LIGHT_WEIGHT_FRAMEWORK)
+    lite_download_and_uncompress(${LITE_MODEL_DIR} ${LITE_URL} "mobilenet_v1.tar.gz")
+    lite_download_and_uncompress(${LITE_MODEL_DIR} ${LITE_URL} "mobilenet_v2_relu.tar.gz")
+    lite_download_and_uncompress(${LITE_MODEL_DIR} ${LITE_URL} "resnet50.tar.gz")
+    lite_download_and_uncompress(${LITE_MODEL_DIR} ${LITE_URL} "inception_v4_simple.tar.gz")
+    lite_download_and_uncompress(${LITE_MODEL_DIR} ${LITE_URL} "MobileNetV1_quant.tar.gz")
+  endif()
+  if(NOT LITE_WITH_LIGHT_WEIGHT_FRAMEWORK)
+    lite_download_and_uncompress(${LITE_MODEL_DIR} ${LITE_URL} "GoogleNet_inference.tar.gz")
+    lite_download_and_uncompress(${LITE_MODEL_DIR} ${LITE_URL} "mobilenet_v1.tar.gz")
+        lite_download_and_uncompress(${LITE_MODEL_DIR} ${LITE_URL} "mobilenet_v2_relu.tar.gz")
+        lite_download_and_uncompress(${LITE_MODEL_DIR} ${LITE_URL} "resnet50.tar.gz")
+        lite_download_and_uncompress(${LITE_MODEL_DIR} ${LITE_URL} "inception_v4_simple.tar.gz")
+    endif()
+endif()
+
+if (LITE_WITH_LIGHT_WEIGHT_FRAMEWORK AND LITE_WITH_ARM)
+    # root directory for publishing
+    set(INFER_LITE_PUBLISH_ROOT "${CMAKE_BINARY_DIR}/inference_lite_lib.${ARM_TARGET_OS}.${ARM_TARGET_ARCH_ABI}")
+    if (LITE_WITH_OPENCL)
+        set(INFER_LITE_PUBLISH_ROOT "${INFER_LITE_PUBLISH_ROOT}.opencl")
+    endif(LITE_WITH_OPENCL)
+    message(STATUS "publish inference lib to ${INFER_LITE_PUBLISH_ROOT}")
+
+    # The final target for publishing the lite lib
+    add_custom_target(publish_inference)
+    if (NOT LITE_ON_TINY_PUBLISH)
+        # add cxx lib
+        add_custom_target(publish_inference_cxx_lib ${TARGET}
+            COMMAND mkdir -p "${INFER_LITE_PUBLISH_ROOT}/cxx/lib"
+            COMMAND mkdir -p "${INFER_LITE_PUBLISH_ROOT}/bin"
+            COMMAND mkdir -p "${INFER_LITE_PUBLISH_ROOT}/cxx/include"
+            COMMAND cp "${CMAKE_SOURCE_DIR}/lite/api/paddle_*.h" "${INFER_LITE_PUBLISH_ROOT}/cxx/include"
+            COMMAND cp "${CMAKE_BINARY_DIR}/libpaddle_api_full_bundled.a" "${INFER_LITE_PUBLISH_ROOT}/cxx/lib"
+            COMMAND cp "${CMAKE_BINARY_DIR}/libpaddle_api_light_bundled.a" "${INFER_LITE_PUBLISH_ROOT}/cxx/lib"
+            COMMAND cp "${CMAKE_BINARY_DIR}/lite/api/model_optimize_tool" "${INFER_LITE_PUBLISH_ROOT}/bin"
+            COMMAND cp "${CMAKE_BINARY_DIR}/lite/gen_code/paddle_code_generator" "${INFER_LITE_PUBLISH_ROOT}/bin"
+            COMMAND cp "${CMAKE_BINARY_DIR}/lite/api/test_model_bin" "${INFER_LITE_PUBLISH_ROOT}/bin"
+            )
+        add_dependencies(publish_inference_cxx_lib model_optimize_tool)
+        add_dependencies(publish_inference_cxx_lib paddle_code_generator)
+        add_dependencies(publish_inference_cxx_lib bundle_full_api)
+        add_dependencies(publish_inference_cxx_lib bundle_light_api)
+        add_dependencies(publish_inference_cxx_lib test_model_bin)
+        add_dependencies(publish_inference publish_inference_cxx_lib)
+        add_custom_command(TARGET publish_inference_cxx_lib POST_BUILD
+                COMMAND ${CMAKE_STRIP} "--strip-debug" ${INFER_LITE_PUBLISH_ROOT}/cxx/lib/*.a)
+    endif()
+
+    if (LITE_WITH_JAVA)
+        # add java lib
+        add_custom_target(publish_inference_java_lib ${TARGET}
+            COMMAND mkdir -p "${INFER_LITE_PUBLISH_ROOT}/java/so"
+            COMMAND mkdir -p "${INFER_LITE_PUBLISH_ROOT}/java/jar"
+            COMMAND cp "${CMAKE_BINARY_DIR}/lite/api/android/jni/native/libpaddle_lite_jni.so" "${INFER_LITE_PUBLISH_ROOT}/java/so"
+            COMMAND cp "${CMAKE_BINARY_DIR}/lite/api/android/jni/PaddlePredictor.jar" "${INFER_LITE_PUBLISH_ROOT}/java/jar"
+            COMMAND cp -r "${CMAKE_SOURCE_DIR}/lite/api/android/jni/src" "${INFER_LITE_PUBLISH_ROOT}/java"
+            )
+        add_dependencies(publish_inference_java_lib paddle_lite_jni PaddlePredictor)
+        add_dependencies(publish_inference publish_inference_java_lib)
+        add_custom_command(TARGET publish_inference_java_lib POST_BUILD
+                COMMAND ${CMAKE_STRIP} "-s" ${INFER_LITE_PUBLISH_ROOT}/java/so/libpaddle_lite_jni.so)
+    endif()
+
+    if ((ARM_TARGET_OS STREQUAL "android") AND (NOT LITE_WITH_OPENCL) AND
+        ((ARM_TARGET_ARCH_ABI STREQUAL armv7) OR (ARM_TARGET_ARCH_ABI STREQUAL armv8)))
+        if (NOT LITE_ON_TINY_PUBLISH)
+            # copy the cxx demo files
+            add_custom_target(publish_inference_android_cxx_demos ${TARGET}
+                COMMAND mkdir -p "${INFER_LITE_PUBLISH_ROOT}/third_party"
+                COMMAND mkdir -p "${INFER_LITE_PUBLISH_ROOT}/demo/cxx"
+                COMMAND mkdir -p "${INFER_LITE_PUBLISH_ROOT}/demo/cxx/include"
+                COMMAND cp -r "${CMAKE_BINARY_DIR}/third_party/install/gflags" "${INFER_LITE_PUBLISH_ROOT}/third_party"
+ COMMAND cp "${CMAKE_SOURCE_DIR}/lite/demo/cxx/Makefile.def" "${INFER_LITE_PUBLISH_ROOT}/demo/cxx" + COMMAND cp "${CMAKE_SOURCE_DIR}/lite/demo/cxx/README.md" "${INFER_LITE_PUBLISH_ROOT}/demo/cxx" + COMMAND cp -r "${CMAKE_SOURCE_DIR}/lite/demo/cxx/mobile_full" "${INFER_LITE_PUBLISH_ROOT}/demo/cxx" + COMMAND cp "${CMAKE_SOURCE_DIR}/lite/demo/cxx/makefiles/mobile_full/Makefile.${ARM_TARGET_OS}.${ARM_TARGET_ARCH_ABI}" "${INFER_LITE_PUBLISH_ROOT}/demo/cxx/mobile_full/Makefile" + COMMAND cp -r "${CMAKE_SOURCE_DIR}/lite/demo/cxx/mobile_light" "${INFER_LITE_PUBLISH_ROOT}/demo/cxx" + COMMAND cp "${CMAKE_SOURCE_DIR}/lite/demo/cxx/makefiles/mobile_light/Makefile.${ARM_TARGET_OS}.${ARM_TARGET_ARCH_ABI}" "${INFER_LITE_PUBLISH_ROOT}/demo/cxx/mobile_light/Makefile" + ) + add_dependencies(publish_inference_android_cxx_demos logging gflags) + add_dependencies(publish_inference_cxx_lib publish_inference_android_cxx_demos) + endif() + + if (LITE_WITH_JAVA) + # copy java mobile_light demo/lib + add_custom_target(publish_inference_android_java_demo ${TARGET} + COMMAND mkdir -p "${INFER_LITE_PUBLISH_ROOT}/demo/java" + COMMAND cp -r "${CMAKE_SOURCE_DIR}/lite/demo/java/android" "${INFER_LITE_PUBLISH_ROOT}/demo/java" + COMMAND cp "${CMAKE_SOURCE_DIR}/lite/demo/java/README.md" "${INFER_LITE_PUBLISH_ROOT}/demo/java" + COMMAND mkdir -p "${INFER_LITE_PUBLISH_ROOT}/demo/java/android/PaddlePredictor/app/libs" + COMMAND mkdir -p "${INFER_LITE_PUBLISH_ROOT}/demo/java/android/PaddlePredictor/app/src/main/jniLibs/arm7" + COMMAND mkdir -p "${INFER_LITE_PUBLISH_ROOT}/demo/java/android/PaddlePredictor/app/src/main/jniLibs/arm8" + COMMAND mkdir -p "${INFER_LITE_PUBLISH_ROOT}/demo/java/android/PaddlePredictor/app/src/main/jniLibs/arm64-v8a" + COMMAND mkdir -p "${INFER_LITE_PUBLISH_ROOT}/demo/java/android/PaddlePredictor/app/src/main/jniLibs/armeabi-v7a" + COMMAND mkdir -p "${INFER_LITE_PUBLISH_ROOT}/demo/java/android/PaddlePredictor/app/src/main/jniLibs/x86" + ) + add_dependencies(publish_inference_java_lib publish_inference_android_java_demo) + endif() + endif() + + if (LITE_WITH_OPENCL) + add_custom_target(publish_inference_opencl ${TARGET} + COMMAND mkdir -p "${INFER_LITE_PUBLISH_ROOT}/opencl" + COMMAND cp -r "${CMAKE_SOURCE_DIR}/lite/opencl/cl_kernel" "${INFER_LITE_PUBLISH_ROOT}/opencl" + ) + add_dependencies(publish_inference_cxx_lib publish_inference_opencl) + endif() +endif() diff --git a/lite/api/CMakeLists.txt b/lite/api/CMakeLists.txt new file mode 100644 index 00000000000..7dc74e15701 --- /dev/null +++ b/lite/api/CMakeLists.txt @@ -0,0 +1,223 @@ +if(LITE_WITH_LIGHT_WEIGHT_FRAMEWORK) + lite_cc_library(place SRCS paddle_place.cc DEPS logging) +else() + lite_cc_library(place SRCS paddle_place.cc DEPS glog) +endif(LITE_WITH_LIGHT_WEIGHT_FRAMEWORK) + +if (WITH_TESTING) + lite_cc_library(lite_api_test_helper SRCS lite_api_test_helper.cc + DEPS scope optimizer target_wrapper_host model_parser program + ${ops} ${host_kernels} + CUDA_DEPS kernels_cuda + X86_DEPS ${x86_kernels}) +endif() +if(LITE_WITH_FPGA) + set(light_api_deps ${light_api_deps} ${fpga_deps}) + set(cxx_api_deps ${cxx_api_deps} ${fpga_deps}) +endif() + +message(STATUS "get ops ${ops}") +message(STATUS "get Host kernels ${host_kernels}") +message(STATUS "get ARM kernels ${arm_kernels}") +message(STATUS "get NPU kernels ${npu_kernels}") +message(STATUS "get FPGA kernels ${fpga_kernels}") + +# for full api +if (NOT LITE_ON_TINY_PUBLISH) + set(cxx_api_deps + scope optimizer target_wrapper_host model_parser program) + if(LITE_WITH_CUDA) + set(cxx_api_deps 
${cxx_api_deps} kernels_cuda)
+        lite_cc_library(cxx_api_cuda SRCS cxx_api.cc DEPS ${cxx_api_deps} target_wrapper_cuda)
+        nv_test(test_cxx_api_cuda SRCS cxx_api_test.cc DEPS cxx_api_cuda)
+    endif()
+    lite_cc_library(cxx_api
+                    SRCS cxx_api.cc
+                    DEPS ${cxx_api_deps} ${ops} ${host_kernels} program
+                    X86_DEPS ${x86_kernels}
+                    ARM_DEPS ${arm_kernels}
+                    NPU_DEPS ${npu_kernels} ${npu_bridges} npu_pass
+                    CL_DEPS ${opencl_kernels}
+                    FPGA_DEPS ${fpga_kernels})
+endif()
+
+# for light api
+set(light_api_deps
+    scope target_wrapper_host model_parser program)
+if(LITE_WITH_CUDA)
+    set(light_api_deps ${light_api_deps} target_wrapper_cuda)
+endif()
+lite_cc_library(light_api SRCS light_api.cc
+        DEPS scope target_wrapper_host model_parser
+             ${light_api_deps} ${ops} ${host_kernels} program
+        CUDA_DEPS target_wrapper_cuda
+        X86_DEPS ${x86_kernels}
+        ARM_DEPS ${arm_kernels}
+        NPU_DEPS ${npu_kernels} ${npu_bridges} npu_pass
+        CL_DEPS ${opencl_kernels}
+        FPGA_DEPS ${fpga_kernels})
+
+include(ExternalProject)
+set(LITE_DEMO_INSTALL_DIR "${THIRD_PARTY_PATH}/inference_demo" CACHE STRING
+    "A path setting inference demo download directories.")
+
+if(WITH_TESTING)
+    lite_cc_test(test_cxx_api SRCS cxx_api_test.cc
+       DEPS cxx_api mir_passes lite_api_test_helper
+            ${ops} ${host_kernels}
+       X86_DEPS ${x86_kernels}
+       ARM_DEPS ${arm_kernels}
+       NPU_DEPS ${npu_kernels}
+       CL_DEPS ${opencl_kernels}
+       FPGA_DEPS ${fpga_kernels}
+       EXCLUDE_COMPILE_DEPS "ON"
+       ARGS --model_dir=${LITE_MODEL_DIR}/lite_naive_model
+            --optimized_model=${LITE_MODEL_DIR}/lite_naive_model_opt SERIAL)
+    add_dependencies(test_cxx_api extern_lite_download_lite_naive_model_tar_gz)
+    if(NOT LITE_WITH_LIGHT_WEIGHT_FRAMEWORK)
+        lite_cc_test(test_googlenet SRCS test_googlenet_lite.cc
+           DEPS cxx_api mir_passes lite_api_test_helper
+                ${ops} ${host_kernels} ${x86_kernels}
+           ARGS --model_dir=${LITE_MODEL_DIR}/googlenet)
+        add_dependencies(test_googlenet extern_lite_download_GoogleNet_inference_tar_gz)
+        lite_cc_test(test_mobilenetv1_lite_x86 SRCS test_mobilenetv1_lite_x86.cc
+           DEPS cxx_api mir_passes lite_api_test_helper
+                ${ops} ${host_kernels} ${x86_kernels}
+           ARGS --model_dir=${LITE_MODEL_DIR}/mobilenet_v1)
+        add_dependencies(test_mobilenetv1_lite_x86 extern_lite_download_mobilenet_v1_tar_gz)
+        lite_cc_test(test_mobilenetv2_lite_x86 SRCS test_mobilenetv2_lite_x86.cc
+           DEPS cxx_api mir_passes lite_api_test_helper
+                ${ops} ${host_kernels} ${x86_kernels}
+           ARGS --model_dir=${LITE_MODEL_DIR}/mobilenet_v2_relu)
+        add_dependencies(test_mobilenetv2_lite_x86 extern_lite_download_mobilenet_v2_relu_tar_gz)
+        lite_cc_test(test_inceptionv4_lite_x86 SRCS test_inceptionv4_lite_x86.cc
+           DEPS cxx_api mir_passes lite_api_test_helper
+                ${ops} ${host_kernels} ${x86_kernels}
+           ARGS --model_dir=${LITE_MODEL_DIR}/inception_v4_simple)
+        add_dependencies(test_inceptionv4_lite_x86 extern_lite_download_inception_v4_simple_tar_gz)
+    endif()
+endif()
+
+if(LITE_WITH_LIGHT_WEIGHT_FRAMEWORK AND WITH_TESTING)
+    set(lite_model_test_DEPS cxx_api mir_passes ${ops} ${host_kernels} ${arm_kernels} ${npu_kernels} ${fpga_kernels})
+
+    lite_cc_test(test_mobilenetv1_int8 SRCS mobilenetv1_int8_test.cc
+       DEPS ${lite_model_test_DEPS}
+       CL_DEPS ${opencl_kernels}
+       ARGS --cl_path=${CMAKE_SOURCE_DIR}/lite/opencl
+            --model_dir=${LITE_MODEL_DIR}/MobilenetV1_quant SERIAL)
+    add_dependencies(test_mobilenetv1_int8 extern_lite_download_MobileNetV1_quant_tar_gz)
+
+    lite_cc_test(test_mobilenetv1 SRCS mobilenetv1_test.cc
+       DEPS ${lite_model_test_DEPS}
+       CL_DEPS ${opencl_kernels}
+       NPU_DEPS ${npu_kernels} ${npu_bridges}
+       ARGS --cl_path=${CMAKE_SOURCE_DIR}/lite/opencl
+            --model_dir=${LITE_MODEL_DIR}/mobilenet_v1 SERIAL)
+    add_dependencies(test_mobilenetv1 extern_lite_download_mobilenet_v1_tar_gz)
+    set(LINK_FLAGS "-Wl,--version-script ${PADDLE_SOURCE_DIR}/lite/core/lite.map")
+    set_target_properties(test_mobilenetv1 PROPERTIES LINK_FLAGS "${LINK_FLAGS}")
+
+    lite_cc_test(test_mobilenetv2 SRCS mobilenetv2_test.cc
+       DEPS ${lite_model_test_DEPS}
+       CL_DEPS ${opencl_kernels}
+       ARGS --cl_path=${CMAKE_SOURCE_DIR}/lite/opencl
+            --model_dir=${LITE_MODEL_DIR}/mobilenet_v2_relu SERIAL)
+    add_dependencies(test_mobilenetv2 extern_lite_download_mobilenet_v2_relu_tar_gz)
+    set_target_properties(test_mobilenetv2 PROPERTIES LINK_FLAGS "${LINK_FLAGS}")
+
+    lite_cc_test(test_resnet50 SRCS resnet50_test.cc
+       DEPS ${lite_model_test_DEPS}
+       CL_DEPS ${opencl_kernels}
+       FPGA_DEPS ${fpga_kernels}
+       ARGS --cl_path=${CMAKE_SOURCE_DIR}/lite/opencl
+            --model_dir=${LITE_MODEL_DIR}/resnet50 SERIAL)
+    add_dependencies(test_resnet50 extern_lite_download_resnet50_tar_gz)
+
+    lite_cc_test(test_resnet50_fpga SRCS resnet50_test_fpga.cc
+       DEPS ${lite_model_test_DEPS}
+       CL_DEPS ${opencl_kernels}
+       FPGA_DEPS ${fpga_kernels})
+
+    lite_cc_test(test_inceptionv4 SRCS inceptionv4_test.cc
+       DEPS ${lite_model_test_DEPS}
+       CL_DEPS ${opencl_kernels}
+       ARGS --cl_path=${CMAKE_SOURCE_DIR}/lite/opencl
+            --model_dir=${LITE_MODEL_DIR}/inception_v4_simple SERIAL)
+    add_dependencies(test_inceptionv4 extern_lite_download_inception_v4_simple_tar_gz)
+#    lite_cc_test(test_ocr_attention SRCS ocr_attention_test.cc
+#       DEPS ${lite_model_test_DEPS})
+endif()
+
+# These tests need CLI arguments and are not supported in ARM CI.
+# TODO(Superjomn) support them later.
+lite_cc_test(test_light_api SRCS light_api_test.cc
+    DEPS light_api program mir_passes
+    CL_DEPS ${opencl_kernels}
+    FPGA_DEPS ${fpga_kernels}
+    ARGS --optimized_model=${LITE_MODEL_DIR}/lite_naive_model_opt SERIAL)
+
+lite_cc_test(test_apis SRCS apis_test.cc
+    DEPS cxx_api light_api ${ops}
+    CL_DEPS ${opencl_kernels}
+    X86_DEPS ${x86_kernels}
+    FPGA_DEPS ${fpga_kernels}
+    ARGS --model_dir=${LITE_MODEL_DIR}/lite_naive_model
+         --optimized_model=${LITE_MODEL_DIR}/lite_naive_model_opt SERIAL)
+
+lite_cc_library(paddle_api SRCS paddle_api.cc DEPS op_params tensor)
+
+#-----------------------------------------------------------------------------------------------------
+# The final inference library for just MobileConfig.
+lite_cc_library(paddle_api_light SRCS light_api_impl.cc DEPS light_api paddle_api)
+if (NOT LITE_ON_TINY_PUBLISH)
+    lite_cc_library(paddle_api_full SRCS cxx_api_impl.cc DEPS cxx_api paddle_api light_api
+       ${ops}
+       ARM_DEPS ${arm_kernels}
+       NPU_DEPS ${npu_kernels}
+       CL_DEPS ${opencl_kernels}
+       FPGA_DEPS ${fpga_kernels})
+    # The final inference library for both CxxConfig and MobileConfig.
+ bundle_static_library(paddle_api_full paddle_api_full_bundled bundle_full_api) +endif() +bundle_static_library(paddle_api_light paddle_api_light_bundled bundle_light_api) +#----------------------------------------------------------------------------------------------------- + +if (LITE_WITH_JAVA AND LITE_WITH_ARM) + add_subdirectory(android) +endif() + +if (LITE_ON_TINY_PUBLISH) + return() +endif() +lite_cc_test(test_paddle_api SRCS paddle_api_test.cc DEPS paddle_api_full paddle_api_light + ${ops} + ARM_DEPS ${arm_kernels} + NPU_DEPS ${npu_kernels} + CL_DEPS ${opencl_kernels} + X86_DEPS ${x86_kernels} + FPGA_DEPS ${fpga_kernels} + ARGS --model_dir=${LITE_MODEL_DIR}/lite_naive_model SERIAL) +if (WITH_TESTING) + add_dependencies(test_paddle_api extern_lite_download_lite_naive_model_tar_gz) +endif() + +# Some bins +if(NOT IOS) + lite_cc_binary(test_model_bin SRCS model_test.cc DEPS paddle_api_full paddle_api_light gflags + ${ops} + ARM_DEPS ${arm_kernels} + NPU_DEPS ${npu_kernels} + CL_DEPS ${opencl_kernels} + FPGA_DEPS ${fpga_kernels} + X86_DEPS ${x86_kernels}) +endif() + +#lite_cc_binary(cxx_api_bin SRCS cxx_api_bin.cc + #X86_DEPS operator + #DEPS light_api model_parser target_wrapper_host mir_passes + #ARM_DEPS ${arm_kernels}) NPU_DEPS ${npu_kernels}) + +lite_cc_binary(model_optimize_tool SRCS model_optimize_tool.cc + DEPS paddle_api_full gflags + CL_DEPS ${opencl_kernels}) diff --git a/lite/api/android/.gitignore b/lite/api/android/.gitignore new file mode 100644 index 00000000000..a1d6334395d --- /dev/null +++ b/lite/api/android/.gitignore @@ -0,0 +1,2 @@ +/bin/ +.classpath diff --git a/lite/api/android/CMakeLists.txt b/lite/api/android/CMakeLists.txt new file mode 100644 index 00000000000..7f31f7e9479 --- /dev/null +++ b/lite/api/android/CMakeLists.txt @@ -0,0 +1,5 @@ +if ((NOT LITE_WITH_JAVA) OR (NOT LITE_WITH_ARM)) + return() +endif() + +add_subdirectory(jni) diff --git a/lite/api/android/jni/.gitignore b/lite/api/android/jni/.gitignore new file mode 100644 index 00000000000..1299d2738c0 --- /dev/null +++ b/lite/api/android/jni/.gitignore @@ -0,0 +1,3 @@ +/PaddleListTest.class +/PaddleLite.class +/bin/ diff --git a/lite/api/android/jni/CMakeLists.txt b/lite/api/android/jni/CMakeLists.txt new file mode 100644 index 00000000000..c1337c8581e --- /dev/null +++ b/lite/api/android/jni/CMakeLists.txt @@ -0,0 +1,51 @@ +if ((NOT LITE_WITH_ARM) OR (NOT LITE_WITH_JAVA)) + return() +endif() + +include(UseJava) +find_package(Java REQUIRED) + +# We are only interested in finding jni.h: we do not care about extended JVM +# functionality or the AWT library. 
+set(JAVA_AWT_LIBRARY NotNeeded)
+set(JAVA_JVM_LIBRARY NotNeeded)
+set(JAVA_INCLUDE_PATH2 NotNeeded)
+set(JAVA_AWT_INCLUDE_PATH NotNeeded)
+find_package(JNI REQUIRED)
+
+# Generate PaddlePredictor.jar
+include_directories(${JNI_INCLUDE_DIRS})
+add_jar(PaddlePredictor
+        src/com/baidu/paddle/lite/ConfigBase.java
+        src/com/baidu/paddle/lite/CxxConfig.java
+        src/com/baidu/paddle/lite/MobileConfig.java
+        src/com/baidu/paddle/lite/PaddleLiteInitializer.java
+        src/com/baidu/paddle/lite/PaddlePredictor.java
+        src/com/baidu/paddle/lite/Place.java
+        src/com/baidu/paddle/lite/Tensor.java)
+get_target_property(_jarFile PaddlePredictor JAR_FILE)
+get_target_property(_classDir PaddlePredictor CLASSDIR)
+set(_stubDir "${CMAKE_CURRENT_BINARY_DIR}")
+
+# Generate native headers
+add_custom_target(
+    paddle_lite_jni_header ALL
+    COMMAND ${Java_JAVAH_EXECUTABLE} -verbose
+            -classpath ${_classDir}
+            -o "${CMAKE_BINARY_DIR}/lite/api/android/jni/native/paddle_lite_jni.h"
+            -jni
+            com.baidu.paddle.lite.PaddlePredictor
+    COMMAND ${Java_JAVAH_EXECUTABLE} -verbose
+            -classpath ${_classDir}
+            -o "${CMAKE_BINARY_DIR}/lite/api/android/jni/native/tensor_jni.h"
+            -jni
+            com.baidu.paddle.lite.Tensor
+    COMMAND ${Java_JAVAH_EXECUTABLE} -verbose
+            -classpath ${_classDir}
+            -o "${CMAKE_BINARY_DIR}/lite/api/android/jni/native/paddle_init_jni.h"
+            -jni
+            com.baidu.paddle.lite.PaddleLiteInitializer
+    DEPENDS PaddlePredictor
+)
+
+add_subdirectory(native)
diff --git a/lite/api/android/jni/native/CMakeLists.txt b/lite/api/android/jni/native/CMakeLists.txt
new file mode 100644
index 00000000000..0d9f466fbd6
--- /dev/null
+++ b/lite/api/android/jni/native/CMakeLists.txt
@@ -0,0 +1,32 @@
+# Generate paddle_lite_jni.so
+
+if (LITE_ON_TINY_PUBLISH)
+    set(CMAKE_CXX_FLAGS_RELEASE "-Os -DNDEBUG")
+    set(CMAKE_C_FLAGS_RELEASE "-Os -DNDEBUG")
+    set(lib_DEPS light_api paddle_api paddle_api_light)
+else()
+    set(lib_DEPS light_api cxx_api paddle_api_full paddle_api paddle_api_light)
+endif()
+
+include_directories(${JNI_INCLUDE_DIRS} ${_classDir} ${_stubDir})
+if (NOT LITE_ON_TINY_PUBLISH)
+    lite_cc_library(paddle_lite_jni MODULE
+                    SRCS paddle_lite_jni.cc tensor_jni.cc
+                    DEPS ${lib_DEPS}
+                    ARM_DEPS ${arm_kernels} NPU_DEPS ${npu_kernels})
+    # Unlike a static library, a module library has to link its targets in
+    # order to work as a single .so lib.
+    target_link_libraries(paddle_lite_jni ${lib_DEPS} ${arm_kernels} ${npu_kernels})
+else()
+    add_library(paddle_lite_jni SHARED "")
+    target_sources(paddle_lite_jni PUBLIC ${__lite_cc_files} paddle_lite_jni.cc tensor_jni.cc)
+    #add_dependencies(paddle_lite_jni ${lib_DEPS} ${arm_kernels} ${npu_kernels})
+endif()
+
+if (APPLE)
+    # macOS only accepts JNI libs ending with .jnilib or .dylib
+    set_target_properties(paddle_lite_jni PROPERTIES SUFFIX ".jnilib")
+elseif (WIN32)
+    # Windows only accepts JNI libs ending with .dll
+    set_target_properties(paddle_lite_jni PROPERTIES SUFFIX ".dll")
+endif (APPLE)
diff --git a/lite/api/android/jni/native/convert_util_jni.h b/lite/api/android/jni/native/convert_util_jni.h
new file mode 100644
index 00000000000..2524403a7f7
--- /dev/null
+++ b/lite/api/android/jni/native/convert_util_jni.h
@@ -0,0 +1,186 @@
+/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include <jni.h>
+#include <string>
+#include <vector>
+
+#include "lite/api/light_api.h"
+#include "lite/api/paddle_api.h"
+#include "lite/api/paddle_place.h"
+
+#ifndef PADDLE_FLUID_LITE_API_ANDROID_JNI_NATIVE_CONVERT_UTIL_JNI_H_
+#define PADDLE_FLUID_LITE_API_ANDROID_JNI_NATIVE_CONVERT_UTIL_JNI_H_
+
+namespace paddle {
+namespace lite_api {
+
+inline std::string jstring_to_cpp_string(JNIEnv *env, jstring jstr) {
+  // In Java, a unicode char is encoded using 2 bytes (utf16), so a jstring
+  // contains utf16 characters. std::string in C++ is essentially a string of
+  // bytes, not characters, so to pass a jstring from JNI to C++ we have to
+  // convert the utf16 data to bytes.
+  if (!jstr) {
+    return "";
+  }
+  const jclass stringClass = env->GetObjectClass(jstr);
+  const jmethodID getBytes =
+      env->GetMethodID(stringClass, "getBytes", "(Ljava/lang/String;)[B");
+  const jbyteArray stringJbytes = (jbyteArray)env->CallObjectMethod(
+      jstr, getBytes, env->NewStringUTF("UTF-8"));
+
+  size_t length = (size_t)env->GetArrayLength(stringJbytes);
+  jbyte *pBytes = env->GetByteArrayElements(stringJbytes, NULL);
+
+  std::string ret = std::string(reinterpret_cast<char *>(pBytes), length);
+  env->ReleaseByteArrayElements(stringJbytes, pBytes, JNI_ABORT);
+
+  env->DeleteLocalRef(stringJbytes);
+  env->DeleteLocalRef(stringClass);
+  return ret;
+}
+
+inline jfloatArray cpp_array_to_jfloatarray(JNIEnv *env,
+                                            const float *buf,
+                                            int64_t len) {
+  jfloatArray result = env->NewFloatArray(len);
+  env->SetFloatArrayRegion(result, 0, len, buf);
+  return result;
+}
+
+inline jintArray cpp_array_to_jintarray(JNIEnv *env,
+                                        const int *buf,
+                                        int64_t len) {
+  jintArray result = env->NewIntArray(len);
+  env->SetIntArrayRegion(result, 0, len, buf);
+  return result;
+}
+
+inline jbyteArray cpp_array_to_jbytearray(JNIEnv *env,
+                                          const int8_t *buf,
+                                          int64_t len) {
+  jbyteArray result = env->NewByteArray(len);
+  env->SetByteArrayRegion(result, 0, len, buf);
+  return result;
+}
+
+inline jlongArray int64_vector_to_jlongarray(JNIEnv *env,
+                                             const std::vector<int64_t> &vec) {
+  jlongArray result = env->NewLongArray(vec.size());
+  jlong *buf = new jlong[vec.size()];
+  for (size_t i = 0; i < vec.size(); ++i) {
+    buf[i] = (jlong)vec[i];
+  }
+  env->SetLongArrayRegion(result, 0, vec.size(), buf);
+  delete[] buf;
+  return result;
+}
+
+inline std::vector<int64_t> jlongarray_to_int64_vector(JNIEnv *env,
+                                                       jlongArray dims) {
+  int dim_size = env->GetArrayLength(dims);
+  jlong *dim_nums = env->GetLongArrayElements(dims, nullptr);
+  std::vector<int64_t> dim_vec(dim_nums, dim_nums + dim_size);
+  env->ReleaseLongArrayElements(dims, dim_nums, 0);
+  return dim_vec;
+}
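
The helpers above move strings and arrays across the JNI boundary in both directions. As a minimal sketch of how they compose, the hypothetical native method below (the class com.example.Demo and the method name are illustrative only, not part of this patch) converts an incoming jstring to std::string and returns a jfloatArray built from a C++ buffer:

#include <jni.h>
#include <string>
#include <vector>
// Assumes the conversion helpers above are visible via convert_util_jni.h.

// Hypothetical native method: returns the lengths of the '/'-separated
// components of a path, exercising both conversion directions.
extern "C" JNIEXPORT jfloatArray JNICALL
Java_com_example_Demo_pathComponentLengths(JNIEnv *env, jobject /*self*/,
                                           jstring jpath) {
  std::string path = paddle::lite_api::jstring_to_cpp_string(env, jpath);
  std::vector<float> lens;
  size_t start = 0;
  while (start <= path.size()) {
    size_t end = path.find('/', start);
    if (end == std::string::npos) end = path.size();
    lens.push_back(static_cast<float>(end - start));
    start = end + 1;
  }
  return paddle::lite_api::cpp_array_to_jfloatarray(
      env, lens.data(), static_cast<int64_t>(lens.size()));
}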
+/**
+ * Converts Java com.baidu.paddle.lite.Place to c++ paddle::lite_api::Place.
+ */
+inline Place jplace_to_cpp_place(JNIEnv *env, jobject java_place) {
+  jclass place_jclazz = env->GetObjectClass(java_place);
+
+  jmethodID target_method =
+      env->GetMethodID(place_jclazz, "getTargetInt", "()I");
+  jmethodID precision_method =
+      env->GetMethodID(place_jclazz, "getPrecisionInt", "()I");
+  jmethodID data_layout_method =
+      env->GetMethodID(place_jclazz, "getDataLayoutInt", "()I");
+  jmethodID device_method = env->GetMethodID(place_jclazz, "getDevice", "()I");
+
+  int target = env->CallIntMethod(java_place, target_method);
+  int precision = env->CallIntMethod(java_place, precision_method);
+  int data_layout = env->CallIntMethod(java_place, data_layout_method);
+  int device = env->CallIntMethod(java_place, device_method);
+
+  return Place(static_cast<TargetType>(target),
+               static_cast<PrecisionType>(precision),
+               static_cast<DataLayoutType>(data_layout),
+               device);
+}
+
+inline CxxConfig jcxxconfig_to_cpp_cxxconfig(JNIEnv *env, jobject jcxxconfig) {
+  jclass cxxconfig_jclazz = env->GetObjectClass(jcxxconfig);
+
+  jmethodID model_dir_method =
+      env->GetMethodID(cxxconfig_jclazz, "getModelDir", "()Ljava/lang/String;");
+  jmethodID preferred_place_method = env->GetMethodID(
+      cxxconfig_jclazz, "getPreferredPlace", "()Lcom/baidu/paddle/lite/Place;");
+  jmethodID valid_places_method = env->GetMethodID(
+      cxxconfig_jclazz, "getValidPlaces", "()[Lcom/baidu/paddle/lite/Place;");
+
+  CxxConfig config;
+
+  jstring java_model_dir =
+      (jstring)env->CallObjectMethod(jcxxconfig, model_dir_method);
+  if (java_model_dir != nullptr) {
+    std::string cpp_model_dir = jstring_to_cpp_string(env, java_model_dir);
+    config.set_model_dir(cpp_model_dir);
+  }
+
+  jobject java_preferred_place =
+      env->CallObjectMethod(jcxxconfig, preferred_place_method);
+  if (java_preferred_place != nullptr) {
+    Place cpp_preferred_place = jplace_to_cpp_place(env, java_preferred_place);
+    config.set_preferred_place(cpp_preferred_place);
+  }
+
+  jobject object_valid_places =
+      env->CallObjectMethod(jcxxconfig, valid_places_method);
+  jobjectArray *java_valid_places =
+      reinterpret_cast<jobjectArray *>(&object_valid_places);
+  if (java_valid_places != nullptr) {
+    int valid_place_count = env->GetArrayLength(*java_valid_places);
+    std::vector<Place> cpp_valid_places;
+    for (int i = 0; i < valid_place_count; ++i) {
+      jobject jplace = env->GetObjectArrayElement(*java_valid_places, i);
+      cpp_valid_places.push_back(jplace_to_cpp_place(env, jplace));
+    }
+    config.set_valid_places(cpp_valid_places);
+  }
+
+  return config;
+}
+
+inline MobileConfig jmobileconfig_to_cpp_mobileconfig(JNIEnv *env,
+                                                      jobject jmobileconfig) {
+  jclass mobileconfig_jclazz = env->GetObjectClass(jmobileconfig);
+
+  jmethodID model_dir_method = env->GetMethodID(
+      mobileconfig_jclazz, "getModelDir", "()Ljava/lang/String;");
+  MobileConfig config;
+
+  jstring java_model_dir =
+      (jstring)env->CallObjectMethod(jmobileconfig, model_dir_method);
+  if (java_model_dir != nullptr) {
+    std::string cpp_model_dir = jstring_to_cpp_string(env, java_model_dir);
+    config.set_model_dir(cpp_model_dir);
+  }
+  return config;
+}
+
+}  // namespace lite_api
+}  // namespace paddle
+
+#endif  // PADDLE_FLUID_LITE_API_ANDROID_JNI_NATIVE_CONVERT_UTIL_JNI_H_
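
The config converters above look up Java getters with GetMethodID, where the last argument is a JNI type signature: "()I" is a no-argument method returning int, "()Ljava/lang/String;" returns String, and "()[Lcom/baidu/paddle/lite/Place;" returns Place[]. A minimal sketch of such a guarded lookup, with a hypothetical getter name, assuming only standard JNI:

#include <jni.h>

// Hypothetical helper: calls a no-argument int getter by name. If the method
// is missing, GetMethodID returns nullptr and raises NoSuchMethodError, which
// we clear before falling back to a default.
static jint call_int_getter(JNIEnv *env, jobject obj, const char *name) {
  jclass clazz = env->GetObjectClass(obj);
  jmethodID mid = env->GetMethodID(clazz, name, "()I");
  if (mid == nullptr) {
    env->ExceptionClear();
    return 0;
  }
  return env->CallIntMethod(obj, mid);
}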
diff --git a/lite/api/android/jni/native/paddle_lite_jni.cc b/lite/api/android/jni/native/paddle_lite_jni.cc
new file mode 100644
index 00000000000..aa4ece68189
--- /dev/null
+++ b/lite/api/android/jni/native/paddle_lite_jni.cc
@@ -0,0 +1,164 @@
+/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "lite/api/android/jni/native/paddle_lite_jni.h"
+
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "lite/api/android/jni/native/convert_util_jni.h"
+#include "lite/api/light_api.h"
+#include "lite/api/paddle_api.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+namespace paddle {
+namespace lite_api {
+
+inline static std::shared_ptr<PaddlePredictor> *getPaddlePredictorPointer(
+    JNIEnv *env, jobject jpaddle_predictor) {
+  jclass jclazz = env->GetObjectClass(jpaddle_predictor);
+  jfieldID jfield = env->GetFieldID(jclazz, "cppPaddlePredictorPointer", "J");
+  jlong java_pointer = env->GetLongField(jpaddle_predictor, jfield);
+  std::shared_ptr<PaddlePredictor> *ptr =
+      reinterpret_cast<std::shared_ptr<PaddlePredictor> *>(java_pointer);
+  return ptr;
+}
+
+JNIEXPORT jboolean JNICALL Java_com_baidu_paddle_lite_PaddlePredictor_run(
+    JNIEnv *env, jobject jpaddle_predictor) {
+  std::shared_ptr<PaddlePredictor> *predictor =
+      getPaddlePredictorPointer(env, jpaddle_predictor);
+  if (predictor == nullptr || (*predictor == nullptr)) {
+    return JNI_FALSE;
+  }
+  (*predictor)->Run();
+  return JNI_TRUE;
+}
+
+JNIEXPORT jboolean JNICALL
+Java_com_baidu_paddle_lite_PaddlePredictor_saveOptimizedModel(
+    JNIEnv *env, jobject jpaddle_predictor, jstring model_dir) {
+  std::shared_ptr<PaddlePredictor> *predictor =
+      getPaddlePredictorPointer(env, jpaddle_predictor);
+  if (predictor == nullptr || (*predictor == nullptr)) {
+    return JNI_FALSE;
+  }
+  (*predictor)->SaveOptimizedModel(jstring_to_cpp_string(env, model_dir));
+  return JNI_TRUE;
+}
+
+JNIEXPORT jlong JNICALL
+Java_com_baidu_paddle_lite_PaddlePredictor_getInputCppTensorPointer(
+    JNIEnv *env, jobject jpaddle_predictor, jint offset) {
+  std::shared_ptr<PaddlePredictor> *predictor =
+      getPaddlePredictorPointer(env, jpaddle_predictor);
+  if (predictor == nullptr || (*predictor == nullptr)) {
+    return 0;
+  }
+  std::unique_ptr<Tensor> tensor =
+      (*predictor)->GetInput(static_cast<int>(offset));
+  std::unique_ptr<Tensor> *cpp_tensor_pointer =
+      new std::unique_ptr<Tensor>(std::move(tensor));
+  return reinterpret_cast<jlong>(cpp_tensor_pointer);
+}
+
+JNIEXPORT jlong JNICALL
+Java_com_baidu_paddle_lite_PaddlePredictor_getOutputCppTensorPointer(
+    JNIEnv *env, jobject jpaddle_predictor, jint offset) {
+  std::shared_ptr<PaddlePredictor> *predictor =
+      getPaddlePredictorPointer(env, jpaddle_predictor);
+  if (predictor == nullptr || (*predictor == nullptr)) {
+    return 0;
+  }
+  std::unique_ptr<const Tensor> tensor =
+      (*predictor)->GetOutput(static_cast<int>(offset));
+  std::unique_ptr<const Tensor> *cpp_tensor_pointer =
+      new std::unique_ptr<const Tensor>(std::move(tensor));
+  return reinterpret_cast<jlong>(cpp_tensor_pointer);
+}
+
+JNIEXPORT jlong JNICALL
+Java_com_baidu_paddle_lite_PaddlePredictor_getCppTensorPointerByName(
+    JNIEnv *env, jobject jpaddle_predictor, jstring name) {
+  std::string cpp_name = jstring_to_cpp_string(env, name);
+  std::shared_ptr<PaddlePredictor> *predictor =
+      getPaddlePredictorPointer(env, jpaddle_predictor);
+  if (predictor == nullptr || (*predictor == nullptr)) {
+    return 0;
+  }
+  std::unique_ptr<const Tensor> tensor = (*predictor)->GetTensor(cpp_name);
+  std::unique_ptr<const Tensor> *cpp_tensor_pointer =
+      new std::unique_ptr<const Tensor>(std::move(tensor));
+  return reinterpret_cast<jlong>(cpp_tensor_pointer);
+}
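
These entry points all follow one pattern: heap-allocate a smart pointer, hand its address to Java as a jlong, and reinterpret that jlong back on every call. A self-contained sketch of the pattern, with int64_t standing in for jlong and a stand-in Engine type instead of PaddlePredictor:

#include <cstdint>
#include <iostream>
#include <memory>

struct Engine {  // stand-in for PaddlePredictor
  void run() { std::cout << "running\n"; }
};

// Box a shared_ptr on the heap and expose its address as a 64-bit handle,
// which the Java side stores in a long field.
int64_t create_handle() {
  auto *box = new std::shared_ptr<Engine>(std::make_shared<Engine>());
  return reinterpret_cast<int64_t>(box);
}

void use_handle(int64_t handle) {
  auto *box = reinterpret_cast<std::shared_ptr<Engine> *>(handle);
  if (box && *box) (*box)->run();  // same null guards as the JNI code above
}

void delete_handle(int64_t handle) {
  auto *box = reinterpret_cast<std::shared_ptr<Engine> *>(handle);
  delete box;  // drops one reference; Engine dies with the last owner
}

int main() {
  int64_t h = create_handle();
  use_handle(h);
  delete_handle(h);
}

The indirection is what lets Java own the object's lifetime: the boxed smart pointer lives until the Java wrapper explicitly deletes it.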
+
+JNIEXPORT jlong JNICALL
+Java_com_baidu_paddle_lite_PaddlePredictor_newCppPaddlePredictor__Lcom_baidu_\
+paddle_lite_CxxConfig_2(JNIEnv *env,
+                        jobject jpaddle_predictor,
+                        jobject jcxxconfig) {
+#ifndef LITE_ON_TINY_PUBLISH
+  CxxConfig config = jcxxconfig_to_cpp_cxxconfig(env, jcxxconfig);
+  std::shared_ptr<PaddlePredictor> predictor =
+      paddle::lite_api::CreatePaddlePredictor(config);
+  if (predictor == nullptr) {
+    return 0;
+  }
+  std::shared_ptr<PaddlePredictor> *predictor_pointer =
+      new std::shared_ptr<PaddlePredictor>(predictor);
+  return reinterpret_cast<jlong>(predictor_pointer);
+#else
+  return 0;
+#endif
+}
+
+JNIEXPORT jlong JNICALL
+Java_com_baidu_paddle_lite_PaddlePredictor_newCppPaddlePredictor__Lcom_baidu_\
+paddle_lite_MobileConfig_2(JNIEnv *env,
+                           jobject jpaddle_predictor,
+                           jobject jmobileconfig) {
+  MobileConfig config = jmobileconfig_to_cpp_mobileconfig(env, jmobileconfig);
+  std::shared_ptr<PaddlePredictor> predictor =
+      paddle::lite_api::CreatePaddlePredictor(config);
+  if (predictor == nullptr) {
+    return 0;
+  }
+  std::shared_ptr<PaddlePredictor> *predictor_pointer =
+      new std::shared_ptr<PaddlePredictor>(predictor);
+  return reinterpret_cast<jlong>(predictor_pointer);
+}
+
+JNIEXPORT jboolean JNICALL
+Java_com_baidu_paddle_lite_PaddlePredictor_deleteCppPaddlePredictor(
+    JNIEnv *env, jobject jpaddle_predictor, jlong java_pointer) {
+  if (java_pointer == 0) {
+    return JNI_FALSE;
+  }
+  std::shared_ptr<PaddlePredictor> *ptr =
+      reinterpret_cast<std::shared_ptr<PaddlePredictor> *>(java_pointer);
+  ptr->reset();
+  delete ptr;
+  return JNI_TRUE;
+}
+
+}  // namespace lite_api
+}  // namespace paddle
+
+#ifdef __cplusplus
+}
+#endif
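
The long function names above are JNI's mangling for overloaded native methods: the symbol gains "__" plus the argument descriptor, with '/' escaped as '_', ';' as '_2', and '[' as '_3'. A compilable sketch restating the pattern for two signatures that appear in this patch:

#include <jni.h>

// newCppPaddlePredictor(CxxConfig) -> descriptor
// "(Lcom/baidu/paddle/lite/CxxConfig;)" becomes the suffix below.
extern "C" {
JNIEXPORT jlong JNICALL
Java_com_baidu_paddle_lite_PaddlePredictor_newCppPaddlePredictor__Lcom_baidu_paddle_lite_CxxConfig_2(
    JNIEnv *, jobject, jobject);

// nativeSetData(float[]) -> "_3F" encodes "[F", i.e. a float[] parameter.
JNIEXPORT jboolean JNICALL
Java_com_baidu_paddle_lite_Tensor_nativeSetData___3F(JNIEnv *, jobject,
                                                     jfloatArray);
}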
diff --git a/lite/api/android/jni/native/paddle_lite_jni.h b/lite/api/android/jni/native/paddle_lite_jni.h
new file mode 100644
index 00000000000..913e9a4c3a8
--- /dev/null
+++ b/lite/api/android/jni/native/paddle_lite_jni.h
@@ -0,0 +1,113 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+/* DO NOT EDIT THIS FILE - it is machine generated */
+#include <jni.h>
+/* Header for class com_baidu_paddle_lite_PaddlePredictor */
+#include "lite/api/paddle_lite_factory_helper.h"
+#include "lite/api/paddle_use_kernels.h"
+#include "lite/api/paddle_use_ops.h"
+#ifndef LITE_ON_TINY_PUBLISH
+#include "lite/api/paddle_use_passes.h"
+#endif
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+namespace paddle {
+namespace lite_api {
+
+/*
+ * Class:     com_baidu_paddle_lite_PaddlePredictor
+ * Method:    run
+ * Signature: ()Z
+ */
+JNIEXPORT jboolean JNICALL
+Java_com_baidu_paddle_lite_PaddlePredictor_run(JNIEnv *, jobject);
+
+/*
+ * Class:     com_baidu_paddle_lite_PaddlePredictor
+ * Method:    saveOptimizedModel
+ * Signature: (Ljava/lang/String;)Z
+ */
+JNIEXPORT jboolean JNICALL
+Java_com_baidu_paddle_lite_PaddlePredictor_saveOptimizedModel(JNIEnv *,
+                                                              jobject,
+                                                              jstring);
+
+/*
+ * Class:     com_baidu_paddle_lite_PaddlePredictor
+ * Method:    getInputCppTensorPointer
+ * Signature: (I)J
+ */
+JNIEXPORT jlong JNICALL
+Java_com_baidu_paddle_lite_PaddlePredictor_getInputCppTensorPointer(JNIEnv *,
+                                                                    jobject,
+                                                                    jint);
+
+/*
+ * Class:     com_baidu_paddle_lite_PaddlePredictor
+ * Method:    getOutputCppTensorPointer
+ * Signature: (I)J
+ */
+JNIEXPORT jlong JNICALL
+Java_com_baidu_paddle_lite_PaddlePredictor_getOutputCppTensorPointer(JNIEnv *,
+                                                                     jobject,
+                                                                     jint);
+
+/*
+ * Class:     com_baidu_paddle_lite_PaddlePredictor
+ * Method:    getCppTensorPointerByName
+ * Signature: (Ljava/lang/String;)J
+ */
+JNIEXPORT jlong JNICALL
+Java_com_baidu_paddle_lite_PaddlePredictor_getCppTensorPointerByName(JNIEnv *,
+                                                                     jobject,
+                                                                     jstring);
+
+/*
+ * Class:     com_baidu_paddle_lite_PaddlePredictor
+ * Method:    newCppPaddlePredictor
+ * Signature: (Lcom/baidu/paddle/lite/CxxConfig;)J
+ */
+JNIEXPORT jlong JNICALL
+Java_com_baidu_paddle_lite_PaddlePredictor_newCppPaddlePredictor__Lcom_baidu_\
+paddle_lite_CxxConfig_2(JNIEnv *, jobject, jobject);
+
+/*
+ * Class:     com_baidu_paddle_lite_PaddlePredictor
+ * Method:    newCppPaddlePredictor
+ * Signature: (Lcom/baidu/paddle/lite/MobileConfig;)J
+ */
+JNIEXPORT jlong JNICALL
+Java_com_baidu_paddle_lite_PaddlePredictor_newCppPaddlePredictor__Lcom_baidu_\
+paddle_lite_MobileConfig_2(JNIEnv *, jobject, jobject);
+
+/*
+ * Class:     com_baidu_paddle_lite_PaddlePredictor
+ * Method:    deleteCppPaddlePredictor
+ * Signature: (J)Z
+ */
+JNIEXPORT jboolean JNICALL
+Java_com_baidu_paddle_lite_PaddlePredictor_deleteCppPaddlePredictor(JNIEnv *,
+                                                                    jobject,
+                                                                    jlong);
+
+}  // namespace lite_api
+}  // namespace paddle
+
+#ifdef __cplusplus
+}
+#endif
diff --git a/lite/api/android/jni/native/tensor_jni.cc b/lite/api/android/jni/native/tensor_jni.cc
new file mode 100644
index 00000000000..59cafa19399
--- /dev/null
+++ b/lite/api/android/jni/native/tensor_jni.cc
@@ -0,0 +1,168 @@
+/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "lite/api/android/jni/native/tensor_jni.h"
+
+#include <memory>
+#include <vector>
+
+#include "lite/api/android/jni/native/convert_util_jni.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+namespace paddle {
+namespace lite_api {
+
+inline static int64_t product(const std::vector<int64_t> &vec) {
+  if (vec.empty()) {
+    return 0;
+  }
+  int64_t result = 1;
+  for (int64_t d : vec) {
+    result *= d;
+  }
+  return result;
+}
+
+inline static bool is_const_tensor(JNIEnv *env, jobject jtensor) {
+  jclass jclazz = env->GetObjectClass(jtensor);
+  jfieldID jfield = env->GetFieldID(jclazz, "readOnly", "Z");
+  jboolean read_only = env->GetBooleanField(jtensor, jfield);
+  return static_cast<bool>(read_only);
+}
+
+inline static std::unique_ptr<Tensor> *get_writable_tensor_pointer(
+    JNIEnv *env, jobject jtensor) {
+  jclass jclazz = env->GetObjectClass(jtensor);
+  jfieldID jfield = env->GetFieldID(jclazz, "cppTensorPointer", "J");
+  jlong java_pointer = env->GetLongField(jtensor, jfield);
+  std::unique_ptr<Tensor> *ptr =
+      reinterpret_cast<std::unique_ptr<Tensor> *>(java_pointer);
+  return ptr;
+}
+
+inline static std::unique_ptr<const Tensor> *get_read_only_tensor_pointer(
+    JNIEnv *env, jobject jtensor) {
+  jclass jclazz = env->GetObjectClass(jtensor);
+  jfieldID jfield = env->GetFieldID(jclazz, "cppTensorPointer", "J");
+  jlong java_pointer = env->GetLongField(jtensor, jfield);
+  std::unique_ptr<const Tensor> *ptr =
+      reinterpret_cast<std::unique_ptr<const Tensor> *>(java_pointer);
+  return ptr;
+}
+
+JNIEXPORT jboolean JNICALL Java_com_baidu_paddle_lite_Tensor_nativeResize(
+    JNIEnv *env, jobject jtensor, jlongArray dims) {
+  std::unique_ptr<Tensor> *tensor = get_writable_tensor_pointer(env, jtensor);
+  if (tensor == nullptr || (*tensor == nullptr)) {
+    return JNI_FALSE;
+  }
+  std::vector<int64_t> shape = jlongarray_to_int64_vector(env, dims);
+  (*tensor)->Resize(shape);
+  return JNI_TRUE;
+}
+
+JNIEXPORT jlongArray JNICALL
+Java_com_baidu_paddle_lite_Tensor_shape(JNIEnv *env, jobject jtensor) {
+  if (is_const_tensor(env, jtensor)) {
+    std::unique_ptr<const Tensor> *tensor =
+        get_read_only_tensor_pointer(env, jtensor);
+    std::vector<int64_t> shape = (*tensor)->shape();
+    return int64_vector_to_jlongarray(env, shape);
+  } else {
+    std::unique_ptr<Tensor> *tensor = get_writable_tensor_pointer(env, jtensor);
+    std::vector<int64_t> shape = (*tensor)->shape();
+    return int64_vector_to_jlongarray(env, shape);
+  }
+}
+
+JNIEXPORT jboolean JNICALL Java_com_baidu_paddle_lite_Tensor_nativeSetData___3F(
+    JNIEnv *env, jobject jtensor, jfloatArray buf) {
+  std::unique_ptr<Tensor> *tensor = get_writable_tensor_pointer(env, jtensor);
+  if (tensor == nullptr || (*tensor == nullptr)) {
+    return JNI_FALSE;
+  }
+  int64_t buf_size = (int64_t)env->GetArrayLength(buf);
+  if (buf_size != product((*tensor)->shape())) {
+    return JNI_FALSE;
+  }
+
+  float *input = (*tensor)->mutable_data<float>();
+  env->GetFloatArrayRegion(buf, 0, buf_size, input);
+  return JNI_TRUE;
+}
+
+JNIEXPORT jboolean JNICALL Java_com_baidu_paddle_lite_Tensor_nativeSetData___3B(
+    JNIEnv *env, jobject jtensor, jbyteArray buf) {
+  std::unique_ptr<Tensor> *tensor = get_writable_tensor_pointer(env, jtensor);
+  if (tensor == nullptr || (*tensor == nullptr)) {
+    return JNI_FALSE;
+  }
+  int64_t buf_size = (int64_t)env->GetArrayLength(buf);
+  if (buf_size != product((*tensor)->shape())) {
+    return JNI_FALSE;
+  }
+
+  int8_t *input = (*tensor)->mutable_data<int8_t>();
+  env->GetByteArrayRegion(buf, 0, buf_size, input);
+  return JNI_TRUE;
+}
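
The readOnly flag on the Java side decides whether the boxed pointer above is reinterpreted as unique_ptr<const Tensor> or unique_ptr<Tensor>, so only writable tensors expose mutating calls. A self-contained sketch of that const-dispatch idea, using a stand-in type (it mirrors the reinterpret trick used above and assumes, as the real code does, that both unique_ptr instantiations share one layout):

#include <iostream>
#include <memory>

struct FakeTensor {  // stand-in for the real Tensor
  int value = 0;
  int read() const { return value; }  // allowed through const FakeTensor
  void write(int v) { value = v; }    // requires a mutable FakeTensor
};

int main() {
  auto *box = new std::unique_ptr<FakeTensor>(new FakeTensor());
  auto handle = reinterpret_cast<long long>(box);  // what Java would store

  bool read_only = true;  // corresponds to the Java readOnly field
  if (read_only) {
    auto *cbox = reinterpret_cast<std::unique_ptr<const FakeTensor> *>(handle);
    std::cout << (*cbox)->read() << "\n";  // (*cbox)->write(1) won't compile
  } else {
    auto *mbox = reinterpret_cast<std::unique_ptr<FakeTensor> *>(handle);
    (*mbox)->write(1);
  }
  delete box;
}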
+
+JNIEXPORT jfloatArray JNICALL
+Java_com_baidu_paddle_lite_Tensor_getFloatData(JNIEnv *env, jobject jtensor) {
+  if (is_const_tensor(env, jtensor)) {
+    std::unique_ptr<const Tensor> *tensor =
+        get_read_only_tensor_pointer(env, jtensor);
+    return cpp_array_to_jfloatarray(
+        env, (*tensor)->data<float>(), product((*tensor)->shape()));
+  } else {
+    std::unique_ptr<Tensor> *tensor = get_writable_tensor_pointer(env, jtensor);
+    return cpp_array_to_jfloatarray(
+        env, (*tensor)->data<float>(), product((*tensor)->shape()));
+  }
+}
+
+JNIEXPORT jbyteArray JNICALL
+Java_com_baidu_paddle_lite_Tensor_getByteData(JNIEnv *env, jobject jtensor) {
+  if (is_const_tensor(env, jtensor)) {
+    std::unique_ptr<const Tensor> *tensor =
+        get_read_only_tensor_pointer(env, jtensor);
+    return cpp_array_to_jbytearray(
+        env, (*tensor)->data<int8_t>(), product((*tensor)->shape()));
+  } else {
+    std::unique_ptr<Tensor> *tensor = get_writable_tensor_pointer(env, jtensor);
+    return cpp_array_to_jbytearray(
+        env, (*tensor)->data<int8_t>(), product((*tensor)->shape()));
+  }
+}
+
+JNIEXPORT jboolean JNICALL Java_com_baidu_paddle_lite_Tensor_deleteCppTensor(
+    JNIEnv *env, jobject jtensor, jlong java_pointer) {
+  if (java_pointer == 0) {
+    return JNI_FALSE;
+  }
+  std::unique_ptr<Tensor> *ptr =
+      reinterpret_cast<std::unique_ptr<Tensor> *>(java_pointer);
+  ptr->reset();
+  delete ptr;
+  return JNI_TRUE;
+}
+
+}  // namespace lite_api
+}  // namespace paddle
+
+#ifdef __cplusplus
+}
+#endif
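
Before copying user data into a tensor, nativeSetData checks that the incoming buffer length equals the element count, that is, the product of the tensor's shape. A runnable sketch of that guard in isolation:

#include <cstdint>
#include <iostream>
#include <vector>

// Same reduction as the product() helper above: empty shape yields 0.
static int64_t product(const std::vector<int64_t> &vec) {
  if (vec.empty()) return 0;
  int64_t result = 1;
  for (int64_t d : vec) result *= d;
  return result;
}

int main() {
  std::vector<int64_t> shape = {100, 100};
  std::vector<float> buf(10000, 1.0f);
  // Mirrors the JNI check: reject the copy on any size mismatch instead of
  // silently reading or writing past the end of a buffer.
  if (static_cast<int64_t>(buf.size()) != product(shape)) {
    std::cout << "size mismatch, copy rejected\n";
    return 1;
  }
  std::cout << "sizes match, copy may proceed\n";
}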
diff --git a/lite/api/android/jni/native/tensor_jni.h b/lite/api/android/jni/native/tensor_jni.h
new file mode 100644
index 00000000000..34c35b6a76f
--- /dev/null
+++ b/lite/api/android/jni/native/tensor_jni.h
@@ -0,0 +1,90 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+/* DO NOT EDIT THIS FILE - it is machine generated */
+#include <jni.h>
+/* Header for class com_baidu_paddle_lite_Tensor */
+
+#ifndef PADDLE_FLUID_LITE_API_ANDROID_JNI_NATIVE_TENSOR_JNI_H_
+#define PADDLE_FLUID_LITE_API_ANDROID_JNI_NATIVE_TENSOR_JNI_H_
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+namespace paddle {
+namespace lite_api {
+
+/*
+ * Class:     com_baidu_paddle_lite_Tensor
+ * Method:    shape
+ * Signature: ()[J
+ */
+JNIEXPORT jlongArray JNICALL Java_com_baidu_paddle_lite_Tensor_shape(JNIEnv *,
+                                                                     jobject);
+
+/*
+ * Class:     com_baidu_paddle_lite_Tensor
+ * Method:    getFloatData
+ * Signature: ()[F
+ */
+JNIEXPORT jfloatArray JNICALL
+Java_com_baidu_paddle_lite_Tensor_getFloatData(JNIEnv *, jobject);
+
+/*
+ * Class:     com_baidu_paddle_lite_Tensor
+ * Method:    getByteData
+ * Signature: ()[B
+ */
+JNIEXPORT jbyteArray JNICALL
+Java_com_baidu_paddle_lite_Tensor_getByteData(JNIEnv *, jobject);
+
+/*
+ * Class:     com_baidu_paddle_lite_Tensor
+ * Method:    nativeResize
+ * Signature: ([J)Z
+ */
+JNIEXPORT jboolean JNICALL
+Java_com_baidu_paddle_lite_Tensor_nativeResize(JNIEnv *, jobject, jlongArray);
+
+/*
+ * Class:     com_baidu_paddle_lite_Tensor
+ * Method:    nativeSetData
+ * Signature: ([F)Z
+ */
+JNIEXPORT jboolean JNICALL Java_com_baidu_paddle_lite_Tensor_nativeSetData___3F(
+    JNIEnv *, jobject, jfloatArray);
+
+/*
+ * Class:     com_baidu_paddle_lite_Tensor
+ * Method:    nativeSetData
+ * Signature: ([B)Z
+ */
+JNIEXPORT jboolean JNICALL Java_com_baidu_paddle_lite_Tensor_nativeSetData___3B(
+    JNIEnv *, jobject, jbyteArray);
+
+/*
+ * Class:     com_baidu_paddle_lite_Tensor
+ * Method:    deleteCppTensor
+ * Signature: (J)Z
+ */
+JNIEXPORT jboolean JNICALL
+Java_com_baidu_paddle_lite_Tensor_deleteCppTensor(JNIEnv *, jobject, jlong);
+
+}  // namespace lite_api
+}  // namespace paddle
+
+#ifdef __cplusplus
+}
+#endif
+#endif  // PADDLE_FLUID_LITE_API_ANDROID_JNI_NATIVE_TENSOR_JNI_H_
diff --git a/lite/api/android/jni/src/com/baidu/paddle/lite/.gitignore b/lite/api/android/jni/src/com/baidu/paddle/lite/.gitignore
new file mode 100644
index 00000000000..870ec275e82
--- /dev/null
+++ b/lite/api/android/jni/src/com/baidu/paddle/lite/.gitignore
@@ -0,0 +1,2 @@
+/PaddleLite.class
+/PaddleLiteTest.class
diff --git a/lite/api/android/jni/src/com/baidu/paddle/lite/ConfigBase.java b/lite/api/android/jni/src/com/baidu/paddle/lite/ConfigBase.java
new file mode 100644
index 00000000000..51115b30167
--- /dev/null
+++ b/lite/api/android/jni/src/com/baidu/paddle/lite/ConfigBase.java
@@ -0,0 +1,31 @@
+/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+package com.baidu.paddle.lite;
+
+/**
+ * Base class for all configurations.
+ */
+public class ConfigBase {
+
+    protected String modelDir;
+
+    public String getModelDir() {
+        return modelDir;
+    }
+
+    public void setModelDir(String modelDir) {
+        this.modelDir = modelDir;
+    }
+
+}
diff --git a/lite/api/android/jni/src/com/baidu/paddle/lite/CxxConfig.java b/lite/api/android/jni/src/com/baidu/paddle/lite/CxxConfig.java
new file mode 100644
index 00000000000..906293c92fe
--- /dev/null
+++ b/lite/api/android/jni/src/com/baidu/paddle/lite/CxxConfig.java
@@ -0,0 +1,39 @@
+/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+package com.baidu.paddle.lite;
+
+/**
+ * CxxConfig is the configuration for the full-featured predictor.
+ */
+public class CxxConfig extends ConfigBase {
+
+    protected Place preferredPlace;
+    protected Place[] validPlaces;
+
+    public Place getPreferredPlace() {
+        return preferredPlace;
+    }
+
+    public void setPreferredPlace(Place preferredPlace) {
+        this.preferredPlace = preferredPlace;
+    }
+
+    public Place[] getValidPlaces() {
+        return validPlaces;
+    }
+
+    public void setValidPlaces(Place[] validPlaces) {
+        this.validPlaces = validPlaces;
+    }
+}
diff --git a/lite/api/android/jni/src/com/baidu/paddle/lite/MobileConfig.java b/lite/api/android/jni/src/com/baidu/paddle/lite/MobileConfig.java
new file mode 100644
index 00000000000..e80eaad9bb2
--- /dev/null
+++ b/lite/api/android/jni/src/com/baidu/paddle/lite/MobileConfig.java
@@ -0,0 +1,22 @@
+/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+package com.baidu.paddle.lite;
+
+/**
+ * MobileConfig is the config for the lightweight predictor; it skips IR
+ * optimization and other unnecessary stages.
+ */
+public class MobileConfig extends ConfigBase {
+    // Empty class
+}
diff --git a/lite/api/android/jni/src/com/baidu/paddle/lite/PaddleLiteInitializer.java b/lite/api/android/jni/src/com/baidu/paddle/lite/PaddleLiteInitializer.java
new file mode 100644
index 00000000000..876d7cebd44
--- /dev/null
+++ b/lite/api/android/jni/src/com/baidu/paddle/lite/PaddleLiteInitializer.java
@@ -0,0 +1,23 @@
+package com.baidu.paddle.lite;
+
+/**
+ * Initializer for PaddleLite. The initialization methods are called by package
+ * classes only; public users don't have to call them. Public users can read
+ * PaddleLite constants, such as the JNI lib name, from this class.
+ */
+public class PaddleLiteInitializer {
+
+    /** name of the C++ JNI lib */
+    public final static String JNI_LIB_NAME = "paddle_lite_jni";
+
+    /**
+     * Loads the C++ JNI lib. We only call it in our package, so it shouldn't be
+     * visible to public users.
+     *
+     * @return true if initialization succeeds.
+     */
+    protected static boolean init() {
+        System.loadLibrary(JNI_LIB_NAME);
+        return true;
+    }
+}
diff --git a/lite/api/android/jni/src/com/baidu/paddle/lite/PaddlePredictor.java b/lite/api/android/jni/src/com/baidu/paddle/lite/PaddlePredictor.java
new file mode 100644
index 00000000000..d022fd7d618
--- /dev/null
+++ b/lite/api/android/jni/src/com/baidu/paddle/lite/PaddlePredictor.java
@@ -0,0 +1,192 @@
+/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+package com.baidu.paddle.lite;
+
+/** Java Native Interface (JNI) class for Paddle Lite APIs */
+public class PaddlePredictor {
+
+    /**
+     * Java doesn't have pointers. To manage the life cycle of the underlying C++
+     * PaddlePredictor object, we use a long value to hold it.
+     */
+    private long cppPaddlePredictorPointer;
+
+    /**
+     * Constructor of a PaddlePredictor.
+     *
+     * @param config the input configuration.
+     */
+    public PaddlePredictor(ConfigBase config) {
+        init(config);
+    }
+
+    /**
+     * Creates a PaddlePredictor object.
+     *
+     * @param config the input configuration.
+     * @return the PaddlePredictor object, or null if creation failed.
+     */
+    public static PaddlePredictor createPaddlePredictor(ConfigBase config) {
+        PaddlePredictor predictor = new PaddlePredictor(config);
+        return predictor.cppPaddlePredictorPointer == 0L ? null : predictor;
+    }
+
+    /**
+     * Gets the offset-th input tensor.
+     *
+     * @param offset the index of the input tensor.
+     * @return the tensor, or null if it could not be fetched.
+     */
+    public Tensor getInput(int offset) {
+        long cppTensorPointer = getInputCppTensorPointer(offset);
+        return cppTensorPointer == 0 ? null : new Tensor(cppTensorPointer, /* readOnly = */ false, this);
+    }
+
+    /**
+     * Gets the offset-th output tensor.
+     *
+     * @param offset the index of the output tensor.
+     * @return the tensor, or null if it could not be fetched.
+     */
+    public Tensor getOutput(int offset) {
+        long cppTensorPointer = getOutputCppTensorPointer(offset);
+        return cppTensorPointer == 0 ? null : new Tensor(cppTensorPointer, /* readOnly = */ true, this);
+    }
+
+    /**
+     * Gets a tensor by name.
+     *
+     * @param name the name of the tensor.
+     * @return the tensor, or null if it could not be fetched.
+     */
+    public Tensor getTensor(String name) {
+        long cppTensorPointer = getCppTensorPointerByName(name);
+        return cppTensorPointer == 0 ? null : new Tensor(cppTensorPointer, /* readOnly = */ true, this);
+    }
+
+    /**
+     * Runs the PaddlePredictor.
+     *
+     * @return true if the run succeeds.
+     */
+    public native boolean run();
+
+    /**
+     * Saves the optimized model. It is available only for {@link CxxConfig}.
+     *
+     * @param modelDir the path to save the optimized model
+     * @return true if saved successfully, false otherwise.
+     */
+    public native boolean saveOptimizedModel(String modelDir);
+
+    /**
+     * Deletes the C++ PaddlePredictor pointer when the Java PaddlePredictor
+     * object is destroyed.
+     */
+    @Override
+    protected void finalize() throws Throwable {
+        clear();
+        super.finalize();
+    }
+
+    /**
+     * Creates a C++ PaddlePredictor object based on the configuration.
+     *
+     * @param config the input configuration
+     * @return true if created successfully
+     */
+    protected boolean init(ConfigBase config) {
+        if (config instanceof CxxConfig) {
+            cppPaddlePredictorPointer = newCppPaddlePredictor((CxxConfig) config);
+        } else if (config instanceof MobileConfig) {
+            cppPaddlePredictorPointer = newCppPaddlePredictor((MobileConfig) config);
+        } else {
+            throw new IllegalArgumentException("Not supported PaddleLite Config type");
+        }
+        return cppPaddlePredictorPointer != 0L;
+    }
+
+    /**
+     * Deletes the C++ PaddlePredictor pointer.
+     *
+     * @return true if the deletion succeeds
+     */
+    protected boolean clear() {
+        boolean result = false;
+        if (cppPaddlePredictorPointer != 0L) {
+            result = deleteCppPaddlePredictor(cppPaddlePredictorPointer);
+            cppPaddlePredictorPointer = 0L;
+        }
+        return result;
+    }
+
+    /**
+     * Gets the offset-th input tensor pointer on the C++ side.
+     *
+     * @param offset the index of the input tensor.
+     * @return a long value which is reinterpret_cast of the C++ pointer.
+     */
+    private native long getInputCppTensorPointer(int offset);
+
+    /**
+     * Gets the offset-th output tensor pointer on the C++ side.
+     *
+     * @param offset the index of the output tensor.
+     * @return a long value which is reinterpret_cast of the C++ pointer.
+     */
+    private native long getOutputCppTensorPointer(int offset);
+
+    /**
+     * Gets a tensor pointer on the C++ side by name.
+     *
+     * @param name the name of the tensor.
+     * @return a long value which is reinterpret_cast of the C++ pointer.
+     */
+    private native long getCppTensorPointerByName(String name);
+
+    /**
+     * Creates a new C++ PaddlePredictor object using CxxConfig, and returns the
+     * reinterpret_cast value of the C++ pointer which points to the C++
+     * PaddlePredictor.
+     *
+     * @param config the input configuration.
+     * @return a long value which is reinterpret_cast of the C++ pointer.
+     */
+    private native long newCppPaddlePredictor(CxxConfig config);
+
+    /**
+     * Creates a new C++ PaddlePredictor object using MobileConfig, and returns
+     * the reinterpret_cast value of the C++ pointer which points to the C++
+     * PaddlePredictor.
+     *
+     * @param config the input configuration.
+     * @return a long value which is reinterpret_cast of the C++ pointer.
+     */
+    private native long newCppPaddlePredictor(MobileConfig config);
+
+    /**
+     * Deletes the C++ PaddlePredictor object pointed to by the input pointer,
+     * which is represented by a long value.
+     *
+     * @param nativePointer a long value which is reinterpret_cast of the C++
+     *                      pointer.
+     * @return true if the deletion succeeds.
+     */
+    private native boolean deleteCppPaddlePredictor(long nativePointer);
+
+    /* Initializes at the beginning */
+    static {
+        PaddleLiteInitializer.init();
+    }
+}
diff --git a/lite/api/android/jni/src/com/baidu/paddle/lite/Place.java b/lite/api/android/jni/src/com/baidu/paddle/lite/Place.java
new file mode 100644
index 00000000000..598bb21bd48
--- /dev/null
+++ b/lite/api/android/jni/src/com/baidu/paddle/lite/Place.java
@@ -0,0 +1,148 @@
+/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +package com.baidu.paddle.lite; + +/** + * Place specifies the execution context of a Kernel or input/output for a + * kernel. It is used to make the analysis of the MIR more clear and accurate. + */ +public class Place { + + /** Place hardware target type. */ + public enum TargetType { + UNKNOWN(0), HOST(1), X86(2), CUDA(3), ARM(4), OPEN_CL(5), ANY(6); + + public final int value; + + private TargetType(int value) { + this.value = value; + } + } + + /** Place precision type */ + public enum PrecisionType { + UNKNOWN(0), FLOAT(1), INT8(2), INT32(3), ANY(4); + + public final int value; + + private PrecisionType(int value) { + this.value = value; + } + } + + /** Place data layout type */ + public enum DataLayoutType { + UNKNOWN(0), NCHW(1), ANY(2); + + public final int value; + + private DataLayoutType(int value) { + this.value = value; + } + } + + private TargetType target; + private PrecisionType precision; + private DataLayoutType layout; + private int device; + + public Place() { + target = TargetType.UNKNOWN; + precision = PrecisionType.UNKNOWN; + layout = DataLayoutType.UNKNOWN; + device = 0; + } + + public Place(TargetType target) { + this(target, PrecisionType.FLOAT); + } + + public Place(TargetType target, PrecisionType precision) { + this(target, precision, DataLayoutType.NCHW); + } + + public Place(TargetType target, PrecisionType precision, DataLayoutType layout) { + this(target, precision, layout, 0); + } + + public Place(TargetType target, PrecisionType precision, DataLayoutType layout, int device) { + this.target = target; + this.precision = precision; + this.layout = layout; + this.device = device; + } + + public boolean isValid() { + return target != TargetType.UNKNOWN && precision != PrecisionType.UNKNOWN && layout != DataLayoutType.UNKNOWN; + } + + public TargetType getTarget() { + return target; + } + + public void setTarget(TargetType target) { + this.target = target; + } + + public PrecisionType getPrecision() { + return precision; + } + + public void setPrecision(PrecisionType precision) { + this.precision = precision; + } + + public DataLayoutType getLayout() { + return layout; + } + + public void setLayout(DataLayoutType layout) { + this.layout = layout; + } + + public int getDevice() { + return device; + } + + public void setDevice(int device) { + this.device = device; + } + + /** + * Returns hardware target as enum int value. + * + * @return hardware target as enum int value + */ + public int getTargetInt() { + return target.value; + } + + /** + * Returns precision target as enum int value. + * + * @return precision as enum int value + */ + public int getPrecisionInt() { + return precision.value; + } + + /** + * Returns data layout as enum int value. 
+     *
+     * @return data layout as enum int value
+     */
+    public int getDataLayoutInt() {
+        return layout.value;
+    }
+}
diff --git a/lite/api/android/jni/src/com/baidu/paddle/lite/Tensor.java b/lite/api/android/jni/src/com/baidu/paddle/lite/Tensor.java
new file mode 100644
index 00000000000..ac78800bd2e
--- /dev/null
+++ b/lite/api/android/jni/src/com/baidu/paddle/lite/Tensor.java
@@ -0,0 +1,141 @@
+/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+package com.baidu.paddle.lite;
+
+/**
+ * Tensor provides the Java APIs for users to get or set the shape and the
+ * data of a tensor.
+ */
+public class Tensor {
+
+    /**
+     * Java doesn't have pointers. To maintain the life cycle of the underlying
+     * C++ Tensor object, we use a long value to hold its address.
+     */
+    private long cppTensorPointer;
+
+    /**
+     * Whether this tensor is read-only. This field is also used on the C++ side
+     * to know whether the C++ tensor pointer should be interpreted as a "Tensor"
+     * pointer or a "const Tensor" pointer.
+     */
+    private boolean readOnly;
+
+    /**
+     * Due to the different memory management of Java and C++: on the C++ side,
+     * if a user destroys the PaddlePredictor object, the tensor's memory is
+     * released, and a pointer operating on the released tensor causes undefined
+     * behavior. On the C++ side it is the user's responsibility to manage memory
+     * well, but in our Java code we have to prevent this case. We therefore make
+     * each {@link Tensor} keep a reference to its {@link PaddlePredictor}, which
+     * prevents the {@link PaddlePredictor} object from being collected by the
+     * JVM before the {@link Tensor}.
+     */
+    private PaddlePredictor predictor;
+
+    /**
+     * Accessed by this package only, to prevent public users from creating a
+     * Tensor incorrectly. A Tensor can be created only by
+     * {@link com.baidu.paddle.lite.PaddlePredictor}.
+     */
+    protected Tensor(long cppTensorPointer, boolean readOnly, PaddlePredictor predictor) {
+        this.cppTensorPointer = cppTensorPointer;
+        this.readOnly = readOnly;
+        this.predictor = predictor;
+    }
+
+    /** Deletes the C++ Tensor pointer when the Java Tensor object is destroyed */
+    protected void finalize() throws Throwable {
+        if (cppTensorPointer != 0L) {
+            deleteCppTensor(cppTensorPointer);
+            cppTensorPointer = 0L;
+        }
+        super.finalize();
+    }
+
+    /**
+     * @return whether this Tensor is read-only.
+     */
+    public boolean isReadOnly() {
+        return readOnly;
+    }
+
+    /**
+     * Resizes the tensor shape.
+     *
+     * @param dims the new shape as a long array.
+     * @return true if the resize succeeds.
+     */
+    public boolean resize(long[] dims) {
+        if (readOnly) {
+            return false;
+        }
+        return nativeResize(dims);
+    }
+
+    /**
+     * Sets the tensor float data.
+     *
+     * @param buf the float array buffer which will be copied into the tensor.
+     * @return true if the data is set successfully.
+     */
+    public boolean setData(float[] buf) {
+        if (readOnly) {
+            return false;
+        }
+        return nativeSetData(buf);
+    }
+
+    /**
+     * Sets the tensor byte data.
+     *
+     * @param buf the byte array buffer which will be copied into the tensor.
+     * @return true if the data is set successfully.
+     */
+    public boolean setData(byte[] buf) {
+        if (readOnly) {
+            return false;
+        }
+        return nativeSetData(buf);
+    }
+
+    /**
+     * @return shape of the tensor as a long array.
+     */
+    public native long[] shape();
+
+    /**
+     * @return the tensor data as a float array.
+     */
+    public native float[] getFloatData();
+
+    /**
+     * @return the tensor data as a byte array.
+     */
+    public native byte[] getByteData();
+
+    private native boolean nativeResize(long[] dims);
+
+    private native boolean nativeSetData(float[] buf);
+
+    private native boolean nativeSetData(byte[] buf);
+
+    /**
+     * Deletes the C++ Tensor object pointed to by the input pointer, which is
+     * represented as a long value.
+     *
+     * @param nativePointer a long value which is reinterpret_cast of the C++
+     *                      pointer.
+     * @return true if deletion succeeds.
+     */
+    private native boolean deleteCppTensor(long nativePointer);
+}
\ No newline at end of file
diff --git a/lite/api/android/jni/test/com/baidu/paddle/lite/PaddlePredictorTest.java b/lite/api/android/jni/test/com/baidu/paddle/lite/PaddlePredictorTest.java
new file mode 100644
index 00000000000..0af11efd28f
--- /dev/null
+++ b/lite/api/android/jni/test/com/baidu/paddle/lite/PaddlePredictorTest.java
@@ -0,0 +1,54 @@
+/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+package com.baidu.paddle.lite;
+
+import org.junit.jupiter.api.Test;
+
+import static org.junit.Assert.assertEquals;
+
+/**
+ * Deprecated test. We now use the Android demo's instrumented test instead.
+ *
+ * TODO: turn this into a plain Java unit test, so that the Android demo does
+ * not have to be launched for testing.
+ */
+class PaddlePredictorTest {
+
+    @Test
+    public void run_defaultModel() {
+        MobileConfig config = new MobileConfig();
+        config.setModelDir("");
+        PaddlePredictor predictor = PaddlePredictor.createPaddlePredictor(config);
+
+        float[] inputBuffer = new float[10000];
+        for (int i = 0; i < 10000; ++i) {
+            inputBuffer[i] = i;
+        }
+        long[] dims = { 100, 100 };
+
+        Tensor input = predictor.getInput(0);
+        input.resize(dims);
+        input.setData(inputBuffer);
+
+        predictor.run();
+
+        Tensor output = predictor.getOutput(0);
+        float[] outputBuffer = output.getFloatData();
+
+        assertEquals(outputBuffer.length, 50000);
+        assertEquals(outputBuffer[0], 50.2132f, 1e-3f);
+        assertEquals(outputBuffer[1], -28.8729f, 1e-3f);
+    }
+
+}
diff --git a/lite/api/apis_test.cc b/lite/api/apis_test.cc
new file mode 100644
index 00000000000..bb2b2e1b874
--- /dev/null
+++ b/lite/api/apis_test.cc
@@ -0,0 +1,112 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+/*
+ * We test multiple APIs here.
+ */
+#include <gflags/gflags.h>
+#include <gtest/gtest.h>
+#include <string>
+#include <vector>
+#include "lite/api/cxx_api.h"
+#include "lite/api/light_api.h"
+#include "lite/api/paddle_use_kernels.h"
+#include "lite/api/paddle_use_ops.h"
+#include "lite/api/paddle_use_passes.h"
+#include "lite/core/mir/pass_registry.h"
+
+DEFINE_string(model_dir, "", "");
+DEFINE_string(optimized_model, "", "");
+
+namespace paddle {
+namespace lite {
+
+void SetConstInput(lite::Tensor* x) {
+  x->Resize(DDim(std::vector<int64_t>({100, 100})));
+  auto* data = x->mutable_data<float>();
+  for (int i = 0; i < 100 * 100; i++) {
+    data[i] = i;
+  }
+}
+
+bool CompareTensors(const std::string& name,
+                    const Predictor& cxx_api,
+                    const LightPredictor& light_api) {
+  const auto* a = cxx_api.GetTensor(name);
+  const auto* b = light_api.GetTensor(name);
+  return TensorCompareWith(*a, *b);
+}
+
+TEST(CXXApi_LightApi, optim_model) {
+  lite::Predictor cxx_api;
+  std::vector<Place> valid_places({
+      Place{TARGET(kHost), PRECISION(kFloat)},
+      Place{TARGET(kX86), PRECISION(kFloat)},
+      Place{TARGET(kARM), PRECISION(kFloat)},  // Works on both X86 and ARM
+  });
+  // On ARM devices the preferred X86 target does not work, but the predictor
+  // can still select ARM kernels.
+  cxx_api.Build(
+      FLAGS_model_dir, Place{TARGET(kX86), PRECISION(kFloat)}, valid_places);
+  cxx_api.SaveModel(FLAGS_optimized_model);
+}
+
+TEST(CXXApi_LightApi, save_and_load_model) {
+  lite::Predictor cxx_api;
+  lite::LightPredictor light_api(FLAGS_optimized_model);
+
+  // CXXApi
+  {
+    std::vector<Place> valid_places({
+        Place{TARGET(kHost), PRECISION(kFloat)},
+        Place{TARGET(kX86), PRECISION(kFloat)},
+        Place{TARGET(kARM), PRECISION(kFloat)},  // Works on both X86 and ARM
+    });
+    // On ARM devices the preferred X86 target does not work, but the predictor
+    // can still select ARM kernels.
+    cxx_api.Build(
+        FLAGS_model_dir, Place{TARGET(kX86), PRECISION(kFloat)}, valid_places);
+
+    auto* x = cxx_api.GetInput(0);
+    SetConstInput(x);
+
+    cxx_api.Run();
+
+    LOG(INFO) << "Save optimized model to " << FLAGS_optimized_model;
+    cxx_api.SaveModel(FLAGS_optimized_model);
+  }
+
+  // LightApi
+  {
+    auto* x = light_api.GetInput(0);
+    SetConstInput(x);
+
+    light_api.Run();
+  }
+
+  const auto* cxx_out = cxx_api.GetOutput(0);
+  const auto* light_out = light_api.GetOutput(0);
+  ASSERT_TRUE(TensorCompareWith(*cxx_out, *light_out));
+
+  std::vector<std::string> tensors_with_order({
+      "a", "fc_0.w_0", "scale_0.tmp_0",
+  });
+
+  for (const auto& tensor_name : tensors_with_order) {
+    ASSERT_TRUE(CompareTensors(tensor_name, cxx_api, light_api));
+  }
+}
+
+}  // namespace lite
+}  // namespace paddle
diff --git a/lite/api/cxx_api.cc b/lite/api/cxx_api.cc
new file mode 100644
index 00000000000..36529ecf300
--- /dev/null
+++ b/lite/api/cxx_api.cc
@@ -0,0 +1,141 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
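+
+/*
+ * A minimal sketch of the call sequence this file implements; "model_dir"
+ * and the chosen places are placeholders:
+ *
+ *   lite::Predictor predictor;
+ *   std::vector<Place> valid_places({Place{TARGET(kHost), PRECISION(kFloat)},
+ *                                    Place{TARGET(kARM), PRECISION(kFloat)}});
+ *   predictor.Build("model_dir", Place{TARGET(kARM), PRECISION(kFloat)},
+ *                   valid_places);
+ *   auto* input = predictor.GetInput(0);
+ *   input->Resize(DDim(std::vector<int64_t>({1, 3, 224, 224})));
+ *   auto* data = input->mutable_data<float>();
+ *   // ... fill data ...
+ *   predictor.Run();
+ *   const lite::Tensor* output = predictor.GetOutput(0);
+ *   predictor.SaveModel("opt_model");  // optional: persist the optimized model
+ */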
+
+#include "lite/api/cxx_api.h"
+#include <string>
+#include <utility>
+#include <vector>
+#include "lite/utils/io.h"
+#ifdef LITE_WITH_NPU
+#include "lite/npu/npu_helper.h"
+#endif
+
+namespace paddle {
+namespace lite {
+
+void Predictor::SaveModel(const std::string &dir,
+                          lite_api::LiteModelType model_type) {
+  if (!program_) {
+    GenRuntimeProgram();
+  }
+  program_->SaveOpInfosToProgram(&program_desc_);
+  switch (model_type) {
+    case lite_api::LiteModelType::kProtobuf:
+      SaveModelPb(dir, *program_->exec_scope(), program_desc_);
+      break;
+    case lite_api::LiteModelType::kNaiveBuffer:
+      SaveModelNaive(dir, *program_->exec_scope(), program_desc_);
+      break;
+    default:
+      LOG(FATAL) << "Unknown model type";
+  }
+#ifdef LITE_WITH_NPU
+  for (auto name : npu::DeviceInfo::Global().AllClientNames()) {
+    // The NPU offline model is saved in the current dir, so just copy it to
+    // the destination dir.
+    CHECK_EQ(
+        system(string_format("cp -r %s %s", name.c_str(), dir.c_str()).c_str()),
+        0)
+        << "Failed to copy NPU model to " << dir;
+  }
+#endif
+}
+
+lite::Tensor *Predictor::GetInput(size_t offset) {
+  auto *_feed_list = exec_scope_->FindVar("feed");
+  CHECK(_feed_list) << "no feed variable in exec_scope";
+  auto *feed_list = _feed_list->GetMutable<std::vector<lite::Tensor>>();
+  if (offset >= feed_list->size()) {
+    feed_list->resize(offset + 1);
+  }
+  return &feed_list->at(offset);
+}
+
+const lite::Tensor *Predictor::GetOutput(size_t offset) const {
+  auto *_fetch_list = exec_scope_->FindVar("fetch");
+  CHECK(_fetch_list) << "no fetch variable in exec_scope";
+  auto &fetch_list = *_fetch_list->GetMutable<std::vector<lite::Tensor>>();
+  CHECK_LT(offset, fetch_list.size()) << "offset " << offset << " overflow";
+  return &fetch_list.at(offset);
+}
+
+const cpp::ProgramDesc &Predictor::program_desc() const {
+  return program_desc_;
+}
+const RuntimeProgram &Predictor::runtime_program() const { return *program_; }
+
+void Predictor::Build(const std::string &model_path,
+                      const Place &prefer_place,
+                      const std::vector<Place> &valid_places,
+                      const std::vector<std::string> &passes,
+                      lite_api::LiteModelType model_type) {
+  LOG(INFO) << "Load model from " << model_path;
+  switch (model_type) {
+    case lite_api::LiteModelType::kProtobuf:
+      LoadModelPb(model_path, scope_.get(), &program_desc_);
+      break;
+    case lite_api::LiteModelType::kNaiveBuffer:
+      LoadModelNaive(model_path, scope_.get(), &program_desc_);
+      break;
+    default:
+      LOG(FATAL) << "Unknown model type";
+  }
+  Build(program_desc_, prefer_place, valid_places, passes);
+}
+
+void Predictor::Build(const cpp::ProgramDesc &desc,
+                      const Place &prefer_place,
+                      const std::vector<Place> &valid_places,
+                      const std::vector<std::string> &passes) {
+  program_desc_ = desc;
+  Program program(desc, scope_, valid_places);
+  optimizer_.KernelPickPreferPlace(prefer_place);
+  core::KernelPickFactor factor;
+  factor.ConsiderTarget();
+  factor.ConsiderPrecision();
+  optimizer_.Run(std::move(program), valid_places, factor, passes);
+  exec_scope_ = optimizer_.exec_scope();
+}
+
+void Predictor::GenRuntimeProgram() {
+  program_ = optimizer_.GenRuntimeProgram();
+  CHECK_EQ(exec_scope_, program_->exec_scope());
+  program_generated_ = true;
+}
+
+void Predictor::GenNPURuntimeProgram() {
+  program_ = optimizer_.GenNPURuntimeProgram();
+  CHECK_EQ(exec_scope_, program_->exec_scope());
+  program_generated_ = true;
+}
+
+const lite::Tensor *Predictor::GetTensor(const std::string &name) const {
+  auto *var = exec_scope_->FindVar(name);
+  return &var->Get<lite::Tensor>();
+}
+
+#ifdef LITE_WITH_TRAIN
+void Predictor::FeedVars(const std::vector<lite::Tensor> &tensors) {
+  auto var = scope_->FindVar("feed");
+  auto &feed_list = *(var->GetMutable<std::vector<lite::Tensor>>());
+  feed_list.resize(tensors.size());
+
+  for (size_t i = 0; i < tensors.size(); ++i)
+    feed_list[i].ShareDataWith(tensors[i]);
+}
+#endif
+
+}  // namespace lite
+}  // namespace paddle
diff --git a/lite/api/cxx_api.h b/lite/api/cxx_api.h
new file mode 100644
index 00000000000..5d94a75bb12
--- /dev/null
+++ b/lite/api/cxx_api.h
@@ -0,0 +1,165 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+#include "lite/api/paddle_api.h"
+#include "lite/core/op_lite.h"
+#include "lite/core/optimizer.h"
+#include "lite/core/program.h"
+#include "lite/core/types.h"
+#include "lite/model_parser/model_parser.h"
+
+namespace paddle {
+namespace lite {
+
+/*
+ * Predictor for inference: given a model, it optimizes and then executes it.
+ */
+class LITE_API Predictor {
+ public:
+  // Create an empty predictor.
+  Predictor() { scope_ = std::make_shared<Scope>(); }
+  // Create a predictor with the weight variable scope set.
+  explicit Predictor(const std::shared_ptr<Scope>& root_scope)
+      : scope_(root_scope) {}
+
+  // Build from a model, with places set for hardware config.
+  void Build(
+      const std::string& model_path,
+      const Place& prefer_place,
+      const std::vector<Place>& valid_places,
+      const std::vector<std::string>& passes = {},
+      lite_api::LiteModelType model_type = lite_api::LiteModelType::kProtobuf);
+
+  void Build(const cpp::ProgramDesc& desc,
+             const Place& prefer_place,
+             const std::vector<Place>& valid_places,
+             const std::vector<std::string>& passes = {});
+
+  void GenRuntimeProgram();
+
+  void GenNPURuntimeProgram();
+
+  // Run the predictor for a single batch of data.
+  void Run() {
+    if (!program_generated_) {
+      GenRuntimeProgram();
+    }
+    program_->Run();
+    LOG(INFO) << "running";
+  }
+
+  // Get offset-th col of feed inputs.
+  lite::Tensor* GetInput(size_t offset);
+
+  // Get offset-th col of fetch results.
+  const lite::Tensor* GetOutput(size_t offset) const;
+
+  const cpp::ProgramDesc& program_desc() const;
+  const lite::Tensor* GetTensor(const std::string& name) const;
+  const RuntimeProgram& runtime_program() const;
+
+  // This method is disabled on mobile to avoid the extra dependencies it
+  // would require.
+  void SaveModel(
+      const std::string& dir,
+      lite_api::LiteModelType model_type = lite_api::LiteModelType::kProtobuf);
+
+#ifdef LITE_WITH_TRAIN
+  void Run(const std::vector<lite::Tensor>& tensors) {
+    FeedVars(tensors);
+    program_->Run();
+  }
+
+  void FeedVars(const std::vector<lite::Tensor>& tensors);
+#endif
+
+ private:
+  Optimizer optimizer_;
+  cpp::ProgramDesc program_desc_;
+  std::shared_ptr<Scope> scope_;
+  const Scope* exec_scope_;
+  std::unique_ptr<RuntimeProgram> program_;
+  bool program_generated_{false};
+};
+
+/*
+ * An executor for training.
+ *
+ * Usage:
+ *
+ * CXXTrainer trainer(...);
+ * trainer.RunStartupProgram(...);
+ * auto exe = BuildMainProgramExecutor(...);
+ *
+ * for (auto& epoch : epochs) {
+ *   auto* tensor0 = exe.GetInput(...);
+ *   // fill data for tensor0
+ *   exe.Run();
+ * }
+#ifdef LITE_WITH_X86
+class LITE_API CXXTrainer {
+ public:
+  CXXTrainer(const std::shared_ptr<Scope>& root_scope,
+             const Place& preferred_place,
+             const std::vector<Place>& valid_places)
+      : scope_(root_scope),
+        preferred_place_(preferred_place),
+        valid_places_(valid_places),
+        main_program_executor_(Predictor(scope_)) {}
+
+  // Build the RuntimeProgram cache for the main program. The cache will be
+  // run multiple times, once per epoch.
+  // NOTE Only the 0-th block can be executed currently.
+  Predictor& BuildMainProgramExecutor(const framework::proto::ProgramDesc& desc,
+                                      int block_id = 0) {
+    main_program_executor_.Build(desc, preferred_place_, valid_places_);
+    return main_program_executor_;
+  }
+
+#ifdef LITE_WITH_TRAIN
+  Predictor& BuildMainProgramExecutor(framework::ProgramDesc& desc) {  // NOLINT
+    return BuildMainProgramExecutor(*desc.Proto());
+  }
+
+  void RunStartupProgram(framework::ProgramDesc& desc) {  // NOLINT
+    RunStartupProgram(*desc.Proto());
+  }
+#endif
+
+  // Run the startup program. It just executes once, no cache needed.
+  void RunStartupProgram(const framework::proto::ProgramDesc& desc,
+                         int block_id = 0) {
+    Predictor exe(scope_);
+    exe.Build(desc, preferred_place_, valid_places_);
+    exe.Run();
+  }
+
+ private:
+  std::shared_ptr<Scope> scope_;
+
+  Place preferred_place_;
+  std::vector<Place> valid_places_;
+
+  // The training program.
+  Predictor main_program_executor_;
+};
+#endif
+*/
+
+}  // namespace lite
+}  // namespace paddle
diff --git a/lite/api/cxx_api_bin.cc b/lite/api/cxx_api_bin.cc
new file mode 100644
index 00000000000..836909556dc
--- /dev/null
+++ b/lite/api/cxx_api_bin.cc
@@ -0,0 +1,129 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
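+
+/*
+ * This binary drives lite::Predictor directly. A rough sketch of the same
+ * flow through the public API of paddle_api.h (the setter names
+ * set_model_dir / set_preferred_place / set_valid_places are assumed here):
+ *
+ *   lite_api::CxxConfig config;
+ *   config.set_model_dir(model_dir);
+ *   config.set_preferred_place(Place{TARGET(kARM), PRECISION(kInt8)});
+ *   config.set_valid_places({Place{TARGET(kHost), PRECISION(kFloat)},
+ *                            Place{TARGET(kARM), PRECISION(kFloat)},
+ *                            Place{TARGET(kARM), PRECISION(kInt8)}});
+ *   auto predictor =
+ *       lite_api::CreatePaddlePredictor<lite_api::CxxConfig>(config);
+ *   predictor->Run();
+ */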
+ +#include "lite/api/cxx_api.h" +#include // NOLINT +#include "lite/api/paddle_use_passes.h" +#include "lite/core/op_registry.h" + +namespace paddle { +namespace lite { + +using Time = decltype(std::chrono::high_resolution_clock::now()); +Time time() { return std::chrono::high_resolution_clock::now(); } +double time_diff(Time t1, Time t2) { + typedef std::chrono::microseconds ms; + auto diff = t2 - t1; + ms counter = std::chrono::duration_cast(diff); + return counter.count() / 1000.0; +} + +void Run(const char* model_dir, int repeat) { +#ifdef LITE_WITH_ARM + DeviceInfo::Init(); +#endif + lite::Predictor predictor; + std::vector valid_places({ + Place{TARGET(kHost), PRECISION(kFloat)}, + Place{TARGET(kARM), PRECISION(kFloat)}, + Place{TARGET(kARM), PRECISION(kInt8)}, + }); + + predictor.Build( + model_dir, Place{TARGET(kARM), PRECISION(kInt8)}, valid_places); + + auto* input_tensor = predictor.GetInput(0); + input_tensor->Resize(DDim(std::vector({1, 3, 224, 224}))); + auto* data = input_tensor->mutable_data(); + for (int i = 0; i < input_tensor->dims().production(); i++) { + data[i] = 1; + } + + auto time1 = time(); + for (int i = 0; i < repeat; i++) predictor.Run(); + auto time2 = time(); + std::cout << " predict cost: " << time_diff(time1, time2) / repeat << "ms" + << std::endl; + + auto* out = predictor.GetOutput(0); + LOG(INFO) << out << " memory size " << out->data_size(); + LOG(INFO) << "out " << out->data()[0]; + LOG(INFO) << "out " << out->data()[1]; + LOG(INFO) << "dims " << out->dims(); + LOG(INFO) << "out data size: " << out->data_size(); +} + +} // namespace lite +} // namespace paddle + +int main(int argc, char** argv) { + CHECK_EQ(argc, 3) << "usage: ./cmd "; + paddle::lite::Run(argv[1], std::stoi(argv[2])); + + return 0; +} + +USE_LITE_OP(mul); +USE_LITE_OP(fc); +USE_LITE_OP(scale); +USE_LITE_OP(feed); +USE_LITE_OP(fetch); +USE_LITE_OP(io_copy); +USE_LITE_OP(io_copy_once); + +USE_LITE_OP(conv2d); +USE_LITE_OP(batch_norm); +USE_LITE_OP(relu); +USE_LITE_OP(depthwise_conv2d); +USE_LITE_OP(pool2d); +USE_LITE_OP(elementwise_add); +USE_LITE_OP(softmax); +USE_LITE_OP(fake_quantize_moving_average_abs_max); +USE_LITE_OP(fake_dequantize_max_abs); + +USE_LITE_KERNEL(feed, kHost, kAny, kAny, def); +USE_LITE_KERNEL(fetch, kHost, kAny, kAny, def); +USE_LITE_OP(calib); + +#ifdef LITE_WITH_ARM +USE_LITE_KERNEL(fc, kARM, kFloat, kNCHW, def); +USE_LITE_KERNEL(fc, kARM, kInt8, kNCHW, int8out); +USE_LITE_KERNEL(fc, kARM, kInt8, kNCHW, fp32out); +USE_LITE_KERNEL(mul, kARM, kFloat, kNCHW, def); +USE_LITE_KERNEL(scale, kARM, kFloat, kNCHW, def); + +USE_LITE_KERNEL(conv2d, kARM, kFloat, kNCHW, def); +USE_LITE_KERNEL(conv2d, kARM, kInt8, kNCHW, int8_out); +USE_LITE_KERNEL(conv2d, kARM, kInt8, kNCHW, fp32_out); +USE_LITE_KERNEL(batch_norm, kARM, kFloat, kNCHW, def); +USE_LITE_KERNEL(relu, kARM, kFloat, kNCHW, def); +USE_LITE_KERNEL(depthwise_conv2d, kARM, kFloat, kNCHW, def); +USE_LITE_KERNEL(pool2d, kARM, kFloat, kNCHW, def); +USE_LITE_KERNEL(elementwise_add, kARM, kFloat, kNCHW, def); +USE_LITE_KERNEL(softmax, kARM, kFloat, kNCHW, def); + +USE_LITE_KERNEL(calib, kARM, kInt8, kNCHW, fp32_to_int8); +USE_LITE_KERNEL(calib, kARM, kInt8, kNCHW, int8_to_fp32); + +// USE_LITE_KERNEL(feed, kARM, kAny, kAny, def); +// USE_LITE_KERNEL(fetch, kARM, kAny, kAny, def); +#endif // LITE_WITH_ARM + +#ifdef LITE_WITH_CUDA +USE_LITE_KERNEL(mul, kCUDA, kFloat, kNCHW, def); +USE_LITE_KERNEL(io_copy, kCUDA, kAny, kAny, host_to_device); +USE_LITE_KERNEL(io_copy, kCUDA, kAny, kAny, device_to_host); 
+USE_LITE_KERNEL(io_copy_once, kCUDA, kAny, kAny, host_to_device); +USE_LITE_KERNEL(io_copy_once, kCUDA, kAny, kAny, device_to_host); +#endif diff --git a/lite/api/cxx_api_impl.cc b/lite/api/cxx_api_impl.cc new file mode 100644 index 00000000000..bf741ef5899 --- /dev/null +++ b/lite/api/cxx_api_impl.cc @@ -0,0 +1,90 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/api/cxx_api.h" +#include "lite/api/paddle_api.h" + +namespace paddle { +namespace lite { + +class CxxPaddleApiImpl : public lite_api::PaddlePredictor { + public: + CxxPaddleApiImpl(); + + /// Create a new predictor from a config. + void Init(const lite_api::CxxConfig &config); + + std::unique_ptr GetInput(int i) override; + + std::unique_ptr GetOutput(int i) const override; + + void Run() override; + + std::unique_ptr GetTensor( + const std::string &name) const override; + + void SaveOptimizedModel(const std::string &model_dir, + lite_api::LiteModelType model_type = + lite_api::LiteModelType::kProtobuf) override; + + private: + Predictor raw_predictor_; +}; + +CxxPaddleApiImpl::CxxPaddleApiImpl() {} + +void CxxPaddleApiImpl::Init(const lite_api::CxxConfig &config) { + auto places = config.valid_places(); + places.emplace_back(TARGET(kHost), PRECISION(kAny), DATALAYOUT(kAny)); + raw_predictor_.Build(config.model_dir(), config.preferred_place(), places); +} + +std::unique_ptr CxxPaddleApiImpl::GetInput(int i) { + auto *x = raw_predictor_.GetInput(i); + return std::unique_ptr(new lite_api::Tensor(x)); +} + +std::unique_ptr CxxPaddleApiImpl::GetOutput( + int i) const { + const auto *x = raw_predictor_.GetOutput(i); + return std::unique_ptr(new lite_api::Tensor(x)); +} + +void CxxPaddleApiImpl::Run() { raw_predictor_.Run(); } + +std::unique_ptr CxxPaddleApiImpl::GetTensor( + const std::string &name) const { + auto *x = raw_predictor_.GetTensor(name); + return std::unique_ptr(new lite_api::Tensor(x)); +} + +void CxxPaddleApiImpl::SaveOptimizedModel(const std::string &model_dir, + lite_api::LiteModelType model_type) { + raw_predictor_.SaveModel(model_dir, model_type); +} + +} // namespace lite + +namespace lite_api { + +template <> +std::shared_ptr CreatePaddlePredictor( + const CxxConfig &config) { + auto x = std::make_shared(); + x->Init(config); + return x; +} + +} // namespace lite_api +} // namespace paddle diff --git a/lite/api/cxx_api_test.cc b/lite/api/cxx_api_test.cc new file mode 100644 index 00000000000..ca483cb3283 --- /dev/null +++ b/lite/api/cxx_api_test.cc @@ -0,0 +1,149 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/api/cxx_api.h" +#include +#include +#include +#include "lite/api/lite_api_test_helper.h" +#include "lite/api/paddle_use_kernels.h" +#include "lite/api/paddle_use_ops.h" +#include "lite/api/paddle_use_passes.h" +#include "lite/core/op_registry.h" +#include "lite/core/tensor.h" + +// For training. +DEFINE_string(startup_program_path, "", ""); +DEFINE_string(main_program_path, "", ""); + +namespace paddle { +namespace lite { + +#ifndef LITE_WITH_LIGHT_WEIGHT_FRAMEWORK +TEST(CXXApi, test) { + const lite::Tensor* out = RunHvyModel(); + LOG(INFO) << out << " memory size " << out->data_size(); + for (int i = 0; i < 10; i++) { + LOG(INFO) << "out " << out->data()[i]; + } + LOG(INFO) << "dims " << out->dims(); + // LOG(INFO) << "out " << *out; +} + +TEST(CXXApi, save_model) { + lite::Predictor predictor; + std::vector valid_places({Place{TARGET(kHost), PRECISION(kFloat)}, + Place{TARGET(kX86), PRECISION(kFloat)}}); + predictor.Build( + FLAGS_model_dir, Place{TARGET(kCUDA), PRECISION(kFloat)}, valid_places); + + LOG(INFO) << "Save optimized model to " << FLAGS_optimized_model; + predictor.SaveModel(FLAGS_optimized_model, + lite_api::LiteModelType::kProtobuf); + predictor.SaveModel(FLAGS_optimized_model + ".naive", + lite_api::LiteModelType::kNaiveBuffer); +} + +/*TEST(CXXTrainer, train) { + Place prefer_place({TARGET(kHost), PRECISION(kFloat), DATALAYOUT(kNCHW)}); + std::vector valid_places({prefer_place}); + auto scope = std::make_shared(); + + CXXTrainer trainer(scope, prefer_place, valid_places); + + std::string main_program_pb, startup_program_pb; + ReadBinaryFile(FLAGS_main_program_path, &main_program_pb); + ReadBinaryFile(FLAGS_startup_program_path, &startup_program_pb); + framework::proto::ProgramDesc main_program_desc, startup_program_desc; + main_program_desc.ParseFromString(main_program_pb); + startup_program_desc.ParseFromString(startup_program_pb); + + // LOG(INFO) << main_program_desc.DebugString(); + + for (const auto& op : main_program_desc.blocks(0).ops()) { + LOG(INFO) << "get op " << op.type(); + } + + return; + + trainer.RunStartupProgram(startup_program_desc); + auto& exe = trainer.BuildMainProgramExecutor(main_program_desc); + auto* tensor0 = exe.GetInput(0); + tensor0->Resize(std::vector({100, 100})); + auto* data0 = tensor0->mutable_data(); + data0[0] = 0; + + exe.Run(); +}*/ +#endif // LITE_WITH_LIGHT_WEIGHT_FRAMEWORK + +#ifdef LITE_WITH_ARM +TEST(CXXApi, save_model) { + lite::Predictor predictor; + std::vector valid_places({Place{TARGET(kHost), PRECISION(kFloat)}, + Place{TARGET(kARM), PRECISION(kFloat)}}); + predictor.Build( + FLAGS_model_dir, Place{TARGET(kARM), PRECISION(kFloat)}, valid_places); + + LOG(INFO) << "Save optimized model to " << FLAGS_optimized_model; + predictor.SaveModel(FLAGS_optimized_model); + predictor.SaveModel(FLAGS_optimized_model + ".naive", + lite_api::LiteModelType::kNaiveBuffer); +} + +TEST(CXXApi, load_model_naive) { + lite::Predictor predictor; + std::vector valid_places({Place{TARGET(kHost), PRECISION(kFloat)}, + Place{TARGET(kARM), PRECISION(kFloat)}}); + predictor.Build(FLAGS_optimized_model + ".naive", + 
Place{TARGET(kARM), PRECISION(kFloat)}, + valid_places, + {}, + lite_api::LiteModelType::kNaiveBuffer); + + auto* input_tensor = predictor.GetInput(0); + input_tensor->Resize(std::vector({1, 100})); + auto* data = input_tensor->mutable_data(); + for (int i = 0; i < 100; i++) { + data[i] = 1; + } + + predictor.Run(); + + std::vector result({0.4350058, + -0.6048313, + -0.29346266, + 0.40377066, + -0.13400325, + 0.37114543, + -0.3407839, + 0.14574292, + 0.4104212, + 0.8938774}); + + auto* output_tensor = predictor.GetOutput(0); + auto output_shape = output_tensor->dims().Vectorize(); + ASSERT_EQ(output_shape.size(), 2); + ASSERT_EQ(output_shape[0], 1); + ASSERT_EQ(output_shape[1], 500); + + int step = 50; + for (int i = 0; i < result.size(); i += step) { + EXPECT_NEAR(output_tensor->data()[i], result[i], 1e-6); + } +} +#endif + +} // namespace lite +} // namespace paddle diff --git a/lite/api/efficientnet_b0_test.cc b/lite/api/efficientnet_b0_test.cc new file mode 100644 index 00000000000..14e5e956511 --- /dev/null +++ b/lite/api/efficientnet_b0_test.cc @@ -0,0 +1,102 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include "lite/api/cxx_api.h" +#include "lite/api/paddle_use_kernels.h" +#include "lite/api/paddle_use_ops.h" +#include "lite/api/paddle_use_passes.h" +#include "lite/api/test_helper.h" +#include "lite/core/op_registry.h" + +namespace paddle { +namespace lite { + +void TestModel(const std::vector &valid_places, + const Place &preferred_place) { + DeviceInfo::Init(); + DeviceInfo::Global().SetRunMode(LITE_POWER_HIGH, FLAGS_threads); + lite::Predictor predictor; + + predictor.Build(FLAGS_model_dir, preferred_place, valid_places); + + auto *input_tensor = predictor.GetInput(0); + input_tensor->Resize(DDim(std::vector({1, 3, 224, 224}))); + auto *data = input_tensor->mutable_data(); + auto item_size = input_tensor->dims().production(); + for (int i = 0; i < item_size; i++) { + data[i] = 1; + } + + for (int i = 0; i < FLAGS_warmup; ++i) { + predictor.Run(); + } + + auto start = GetCurrentUS(); + for (int i = 0; i < FLAGS_repeats; ++i) { + predictor.Run(); + } + + LOG(INFO) << "================== Speed Report ==================="; + LOG(INFO) << "Model: " << FLAGS_model_dir << ", threads num " << FLAGS_threads + << ", warmup: " << FLAGS_warmup << ", repeats: " << FLAGS_repeats + << ", spend " << (GetCurrentUS() - start) / FLAGS_repeats / 1000.0 + << " ms in average."; + + std::vector> results; + // i = 1 + results.emplace_back(std::vector( + {-0.6746618, -0.7119305, -0.053502668, -0.6767762, -0.07488631, + -1.1109267, 0.63711894, 0.5979086, -0.20651843, -0.49293622, + -0.7404337, -0.25586239, 2.244521, 0.8738271, 0.7193805, + -0.21894705, -0.90460795, 0.07160086, 0.54588217, 0.020132724})); + auto *out = predictor.GetOutput(0); + ASSERT_EQ(out->dims().size(), 2); + ASSERT_EQ(out->dims()[0], 1); + ASSERT_EQ(out->dims()[1], 1000); + + int step = 50; + for (int i = 0; i < 
results.size(); ++i) { + for (int j = 0; j < results[i].size(); ++j) { + EXPECT_NEAR(out->data()[j * step + (out->dims()[1] * i)], + results[i][j], + 2e-4); + } + } +} + +TEST(EfficientNetB0, test_arm) { + std::vector valid_places({ + Place{TARGET(kHost), PRECISION(kFloat)}, + Place{TARGET(kARM), PRECISION(kFloat)}, + // Place{TARGET(kOpenCL), PRECISION(kFloat)}, + }); + + TestModel(valid_places, Place({TARGET(kARM), PRECISION(kFloat)})); +} + +TEST(EfficientNetB0, test_opencl) { + std::vector valid_places({ + Place{TARGET(kHost), PRECISION(kFloat)}, + Place{TARGET(kARM), PRECISION(kFloat)}, + Place{TARGET(kOpenCL), PRECISION(kFloat)}, + }); + + TestModel(valid_places, Place({TARGET(kOpenCL), PRECISION(kFloat)})); +} + +} // namespace lite +} // namespace paddle diff --git a/lite/api/inceptionv4_test.cc b/lite/api/inceptionv4_test.cc new file mode 100644 index 00000000000..9b23a3ba4ef --- /dev/null +++ b/lite/api/inceptionv4_test.cc @@ -0,0 +1,91 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include "lite/api/cxx_api.h" +#include "lite/api/paddle_use_kernels.h" +#include "lite/api/paddle_use_ops.h" +#include "lite/api/paddle_use_passes.h" +#include "lite/api/test_helper.h" +#include "lite/core/op_registry.h" + +namespace paddle { +namespace lite { + +#ifdef LITE_WITH_ARM +TEST(InceptionV4, test) { + DeviceInfo::Init(); + DeviceInfo::Global().SetRunMode(LITE_POWER_HIGH, FLAGS_threads); + lite::Predictor predictor; + std::vector valid_places({Place{TARGET(kHost), PRECISION(kFloat)}, + Place{TARGET(kARM), PRECISION(kFloat)}}); + + predictor.Build( + FLAGS_model_dir, Place{TARGET(kARM), PRECISION(kFloat)}, valid_places); + + auto* input_tensor = predictor.GetInput(0); + input_tensor->Resize(DDim(std::vector({1, 3, 224, 224}))); + auto* data = input_tensor->mutable_data(); + auto item_size = input_tensor->dims().production(); + for (int i = 0; i < item_size; i++) { + data[i] = 1; + } + + for (int i = 0; i < FLAGS_warmup; ++i) { + predictor.Run(); + } + + auto start = GetCurrentUS(); + for (int i = 0; i < FLAGS_repeats; ++i) { + predictor.Run(); + } + + LOG(INFO) << "================== Speed Report ==================="; + LOG(INFO) << "Model: " << FLAGS_model_dir << ", threads num " << FLAGS_threads + << ", warmup: " << FLAGS_warmup << ", repeats: " << FLAGS_repeats + << ", spend " << (GetCurrentUS() - start) / FLAGS_repeats / 1000.0 + << " ms in average."; + + // std::vector results({0.00078033, 0.00083865, 0.00060029, 0.00057083, + // 0.00070094, 0.00080584, 0.00044525, 0.00074907, + // 0.00059774, 0.00063654}); + // + std::vector> results; + // i = 1 + results.emplace_back(std::vector( + {0.0011684548, 0.0010390386, 0.0011301535, 0.0010133048, + 0.0010259597, 0.0010982729, 0.00093195855, 0.0009141837, + 0.00096620916, 0.00089982944, 0.0010064574, 0.0010474789, + 0.0009782845, 0.0009230255, 0.0010548076, 0.0010974824, + 0.0010612885, 0.00089107914, 0.0010112736, 0.00097655767})); + 
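+  // The reference row above is a sparse sample of the 1000-class output: the
+  // check below reads every `step`-th score of output row i, i.e. element
+  // j * step + dims()[1] * i, and compares it with results[i][j].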
+  auto* out = predictor.GetOutput(0);
+  ASSERT_EQ(out->dims().size(), 2);
+  ASSERT_EQ(out->dims()[0], 1);
+  ASSERT_EQ(out->dims()[1], 1000);
+
+  int step = 50;
+  for (int i = 0; i < results.size(); ++i) {
+    for (int j = 0; j < results[i].size(); ++j) {
+      EXPECT_NEAR(out->data<float>()[j * step + (out->dims()[1] * i)],
+                  results[i][j],
+                  1e-6);
+    }
+  }
+}
+#endif
+
+}  // namespace lite
+}  // namespace paddle
diff --git a/lite/api/light_api.cc b/lite/api/light_api.cc
new file mode 100644
index 00000000000..7ebfe43323a
--- /dev/null
+++ b/lite/api/light_api.cc
@@ -0,0 +1,93 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lite/api/light_api.h"
+
+namespace paddle {
+namespace lite {
+
+void LightPredictor::Build(const std::string& model_dir,
+                           lite_api::LiteModelType model_type) {
+  cpp::ProgramDesc desc;
+  LOG(INFO) << "Load model from " << model_dir;
+  switch (model_type) {
+#ifndef LITE_ON_TINY_PUBLISH
+    case lite_api::LiteModelType::kProtobuf:
+      LoadModelPb(model_dir, scope_.get(), &desc);
+      break;
+#endif
+    case lite_api::LiteModelType::kNaiveBuffer:
+      LoadModelNaive(model_dir, scope_.get(), &desc);
+      break;
+    default:
+      LOG(FATAL) << "Unknown model type";
+  }
+  BuildRuntimeProgram(desc);
+}
+
+Tensor* LightPredictor::GetInput(size_t offset) {
+  auto* _feed_list = program_->exec_scope()->FindVar("feed");
+  CHECK(_feed_list) << "no feed variable in exec_scope";
+  auto* feed_list = _feed_list->GetMutable<std::vector<Tensor>>();
+  if (offset >= feed_list->size()) {
+    feed_list->resize(offset + 1);
+  }
+  return &feed_list->at(offset);
+}
+
+const Tensor* LightPredictor::GetOutput(size_t offset) {
+  auto* _fetch_list = program_->exec_scope()->FindVar("fetch");
+  CHECK(_fetch_list) << "no fetch variable in exec_scope";
+  auto& fetch_list = *_fetch_list->GetMutable<std::vector<Tensor>>();
+  CHECK_LT(offset, fetch_list.size()) << "offset " << offset << " overflow";
+  return &fetch_list.at(offset);
+}
+
+void LightPredictor::BuildRuntimeProgram(const cpp::ProgramDesc& prog) {
+  std::vector<Instruction> insts;
+  // 1. Create the ops first.
+  Program program(prog, scope_, {});
+
+  // 2. Create the instructions.
+
+  // Create the kernels of the target places, and filter out the specific
+  // kernel with the target alias.
+  for (auto& op : program.ops()) {
+    auto kernel_type = op->op_info()->GetAttr<std::string>(kKernelTypeAttr);
+    std::string op_type, alias;
+    Place place;
+    KernelBase::ParseKernelType(kernel_type, &op_type, &alias, &place);
+    auto kernels = op->CreateKernels({place});
+    // Filter out the matching kernel.
+    auto it = std::find_if(
+        kernels.begin(), kernels.end(), [&](std::unique_ptr<KernelBase>& it) {
+          return it->alias() == alias;
+        });
+    CHECK(it != kernels.end());
+    (*it)->SetContext(ContextScheduler::Global().NewContext((*it)->target()));
+    insts.emplace_back(op, std::move(*it));
+  }
+  program_.reset(new RuntimeProgram(std::move(insts)));
+  CHECK(program.exec_scope());
+  program_->set_exec_scope(program.exec_scope());
+}
+
+LightPredictor::LightPredictor(const std::string& model_dir,
+                               lite_api::LiteModelType model_type) {
+  scope_ = std::make_shared<Scope>();
+  Build(model_dir, model_type);
+}
+
+}  // namespace lite
+}  // namespace paddle
diff --git a/lite/api/light_api.h b/lite/api/light_api.h
new file mode 100644
index 00000000000..6d3a0bcebbc
--- /dev/null
+++ b/lite/api/light_api.h
@@ -0,0 +1,70 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+/*
+ * This file implements a light-weight API which can run on mobile. We limit
+ * the dependencies and the runtime computation complexity.
+ */
+#pragma once
+
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+#include "lite/api/paddle_api.h"
+#include "lite/core/context.h"
+#include "lite/core/program.h"
+#include "lite/core/tensor.h"
+#include "lite/core/types.h"
+#include "lite/model_parser/model_parser.h"
+
+namespace paddle {
+namespace lite {
+
+/*
+ * The lightweight predictor, mainly for mobile. It loads an optimized model
+ * and does not depend on the MIR or perform further optimization.
+ */
+class LITE_API LightPredictor {
+ public:
+  explicit LightPredictor(
+      const std::string& model_dir,
+      lite_api::LiteModelType model_type = lite_api::LiteModelType::kProtobuf);
+
+  void Run() { program_->Run(); }
+
+  // Get offset-th col of feed inputs.
+  Tensor* GetInput(size_t offset);
+
+  // Get offset-th col of fetch outputs.
+  const Tensor* GetOutput(size_t offset);
+
+  const lite::Tensor* GetTensor(const std::string& name) const {
+    auto* var = program_->exec_scope()->FindVar(name);
+    return &var->Get<lite::Tensor>();
+  }
+
+ private:
+  void Build(
+      const std::string& model_dir,
+      lite_api::LiteModelType model_type = lite_api::LiteModelType::kProtobuf);
+  void BuildRuntimeProgram(const cpp::ProgramDesc& prog);
+
+ private:
+  std::shared_ptr<Scope> scope_;
+  std::unique_ptr<RuntimeProgram> program_;
+};
+
+}  // namespace lite
+}  // namespace paddle
diff --git a/lite/api/light_api_impl.cc b/lite/api/light_api_impl.cc
new file mode 100644
index 00000000000..7020b9b0e82
--- /dev/null
+++ b/lite/api/light_api_impl.cc
@@ -0,0 +1,72 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lite/api/light_api.h"
+#include "lite/api/paddle_api.h"
+#include "lite/model_parser/model_parser.h"
+
+namespace paddle {
+namespace lite_api {
+
+class LightPredictorImpl : public PaddlePredictor {
+ public:
+  LightPredictorImpl() = default;
+
+  std::unique_ptr<Tensor> GetInput(int i) override;
+
+  std::unique_ptr<const Tensor> GetOutput(int i) const override;
+
+  void Run() override;
+
+  std::unique_ptr<const Tensor> GetTensor(
+      const std::string& name) const override;
+
+  void Init(const MobileConfig& config);
+
+ private:
+  std::unique_ptr<lite::LightPredictor> raw_predictor_;
+};
+
+void LightPredictorImpl::Init(const MobileConfig& config) {
+  // LightPredictor only supports the NaiveBuffer backend in the published lib.
+  raw_predictor_.reset(new lite::LightPredictor(config.model_dir(),
+                                                LiteModelType::kNaiveBuffer));
+}
+
+std::unique_ptr<Tensor> LightPredictorImpl::GetInput(int i) {
+  return std::unique_ptr<Tensor>(new Tensor(raw_predictor_->GetInput(i)));
+}
+
+std::unique_ptr<const Tensor> LightPredictorImpl::GetOutput(int i) const {
+  return std::unique_ptr<const Tensor>(
+      new Tensor(raw_predictor_->GetOutput(i)));
+}
+
+void LightPredictorImpl::Run() { raw_predictor_->Run(); }
+
+std::unique_ptr<const Tensor> LightPredictorImpl::GetTensor(
+    const std::string& name) const {
+  return std::unique_ptr<const Tensor>(
+      new Tensor(raw_predictor_->GetTensor(name)));
+}
+
+template <>
+std::shared_ptr<PaddlePredictor> CreatePaddlePredictor(
+    const MobileConfig& config) {
+  auto x = std::make_shared<LightPredictorImpl>();
+  x->Init(config);
+  return x;
+}
+
+}  // namespace lite_api
+}  // namespace paddle
diff --git a/lite/api/light_api_test.cc b/lite/api/light_api_test.cc
new file mode 100644
index 00000000000..6f565b518b5
--- /dev/null
+++ b/lite/api/light_api_test.cc
@@ -0,0 +1,51 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
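+
+/*
+ * This test drives lite::LightPredictor directly. Through the public API
+ * implemented in light_api_impl.cc above, the same flow is roughly (a
+ * sketch; set_model_dir on MobileConfig is assumed, inherited from
+ * ConfigBase):
+ *
+ *   lite_api::MobileConfig config;
+ *   config.set_model_dir("lite_naive_model");  // an optimized NaiveBuffer model
+ *   auto predictor =
+ *       lite_api::CreatePaddlePredictor<lite_api::MobileConfig>(config);
+ *   auto input = predictor->GetInput(0);
+ *   input->Resize({100, 100});
+ *   // ... fill input->mutable_data<float>() ...
+ *   predictor->Run();
+ *   auto output = predictor->GetOutput(0);
+ */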
+ +#include "lite/api/light_api.h" +#include +#include +#include "lite/api/paddle_use_kernels.h" +#include "lite/api/paddle_use_ops.h" +#include "lite/api/paddle_use_passes.h" + +DEFINE_string(optimized_model, "", ""); + +namespace paddle { +namespace lite { + +TEST(LightAPI, load) { + if (FLAGS_optimized_model.empty()) { + FLAGS_optimized_model = "lite_naive_model"; + } + LightPredictor predictor(FLAGS_optimized_model); + + auto* input_tensor = predictor.GetInput(0); + input_tensor->Resize(DDim(std::vector({100, 100}))); + auto* data = input_tensor->mutable_data(); + for (int i = 0; i < 100 * 100; i++) { + data[i] = i; + } + + predictor.Run(); + + const auto* output = predictor.GetOutput(0); + const float* raw_output = output->data(); + + for (int i = 0; i < 10; i++) { + LOG(INFO) << "out " << raw_output[i]; + } +} + +} // namespace lite +} // namespace paddle diff --git a/lite/api/lite_api_test_helper.cc b/lite/api/lite_api_test_helper.cc new file mode 100644 index 00000000000..1b3b157f9be --- /dev/null +++ b/lite/api/lite_api_test_helper.cc @@ -0,0 +1,60 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/api/lite_api_test_helper.h" +#include + +DEFINE_string(model_dir, "", ""); +DEFINE_string(optimized_model, "", ""); + +namespace paddle { +namespace lite { + +const lite::Tensor* RunHvyModel() { + lite::Predictor predictor; +#ifndef LITE_WITH_CUDA + std::vector valid_places({Place{TARGET(kHost), PRECISION(kFloat)}, + Place{TARGET(kX86), PRECISION(kFloat)}}); +#else + std::vector valid_places({ + Place{TARGET(kHost), PRECISION(kFloat), DATALAYOUT(kNCHW)}, + Place{TARGET(kCUDA), PRECISION(kFloat), DATALAYOUT(kNCHW)}, + Place{TARGET(kCUDA), PRECISION(kAny), DATALAYOUT(kNCHW)}, + Place{TARGET(kHost), PRECISION(kAny), DATALAYOUT(kNCHW)}, + Place{TARGET(kCUDA), PRECISION(kAny), DATALAYOUT(kAny)}, + Place{TARGET(kHost), PRECISION(kAny), DATALAYOUT(kAny)}, + }); +#endif + + predictor.Build(FLAGS_model_dir, + Place{TARGET(kX86), PRECISION(kFloat)}, // origin cuda + valid_places); + + auto* input_tensor = predictor.GetInput(0); + input_tensor->Resize(DDim(std::vector({100, 100}))); + auto* data = input_tensor->mutable_data(); + for (int i = 0; i < 100 * 100; i++) { + data[i] = i; + } + + // LOG(INFO) << "input " << *input_tensor; + + predictor.Run(); + + const auto* out = predictor.GetOutput(0); + return out; +} + +} // namespace lite +} // namespace paddle diff --git a/lite/api/lite_api_test_helper.h b/lite/api/lite_api_test_helper.h new file mode 100644 index 00000000000..ac3be77b10c --- /dev/null +++ b/lite/api/lite_api_test_helper.h @@ -0,0 +1,31 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include "lite/api/cxx_api.h" +#include "lite/core/op_registry.h" +#include "lite/core/tensor.h" + +DECLARE_string(model_dir); +DECLARE_string(optimized_model); + +namespace paddle { +namespace lite { + +const lite::Tensor* RunHvyModel(); + +} // namespace lite +} // namespace paddle diff --git a/lite/api/mobilenetv1_int8_test.cc b/lite/api/mobilenetv1_int8_test.cc new file mode 100644 index 00000000000..7a87e11819a --- /dev/null +++ b/lite/api/mobilenetv1_int8_test.cc @@ -0,0 +1,94 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include "lite/api/cxx_api.h" +#include "lite/api/paddle_use_kernels.h" +#include "lite/api/paddle_use_ops.h" +#include "lite/api/paddle_use_passes.h" +#include "lite/api/test_helper.h" +#include "lite/core/op_registry.h" + +namespace paddle { +namespace lite { + +void TestModel(const std::vector& valid_places, + const Place& preferred_place, + bool use_npu = false) { + DeviceInfo::Init(); + DeviceInfo::Global().SetRunMode(LITE_POWER_HIGH, FLAGS_threads); + lite::Predictor predictor; + + predictor.Build(FLAGS_model_dir, preferred_place, valid_places); + + auto* input_tensor = predictor.GetInput(0); + input_tensor->Resize(DDim(std::vector({1, 3, 224, 224}))); + auto* data = input_tensor->mutable_data(); + auto item_size = input_tensor->dims().production(); + for (int i = 0; i < item_size; i++) { + data[i] = 1; + } + + if (use_npu) { + predictor.GenNPURuntimeProgram(); + } + + for (int i = 0; i < FLAGS_warmup; ++i) { + predictor.Run(); + } + + auto start = GetCurrentUS(); + for (int i = 0; i < FLAGS_repeats; ++i) { + predictor.Run(); + } + + LOG(INFO) << "================== Speed Report ==================="; + LOG(INFO) << "Model: " << FLAGS_model_dir << ", threads num " << FLAGS_threads + << ", warmup: " << FLAGS_warmup << ", repeats: " << FLAGS_repeats + << ", spend " << (GetCurrentUS() - start) / FLAGS_repeats / 1000.0 + << " ms in average."; + + std::vector> results; + // i = 1 + results.emplace_back(std::vector( + {0.000227548, 0.000262385, 0.000260347, 0.000293865, 0.00025008})); + auto* out = predictor.GetOutput(0); + ASSERT_EQ(out->dims().size(), 2); + ASSERT_EQ(out->dims()[0], 1); + ASSERT_EQ(out->dims()[1], 1000); + + int step = 50; + for (int i = 0; i < results.size(); ++i) { + for (int j = 0; j < results[i].size(); ++j) { + EXPECT_NEAR(out->data()[j * step + (out->dims()[1] * i)], + results[i][j], + 1e-6); + } + } +} + +TEST(MobileNetV1, test_arm) { + std::vector valid_places({ + Place{TARGET(kHost), 
PRECISION(kFloat)}, + Place{TARGET(kARM), PRECISION(kFloat)}, + Place{TARGET(kARM), PRECISION(kInt8)}, + }); + + TestModel(valid_places, Place({TARGET(kARM), PRECISION(kInt8)})); +} + +} // namespace lite +} // namespace paddle diff --git a/lite/api/mobilenetv1_ssd_test.cc b/lite/api/mobilenetv1_ssd_test.cc new file mode 100644 index 00000000000..9f8ab462410 --- /dev/null +++ b/lite/api/mobilenetv1_ssd_test.cc @@ -0,0 +1,112 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include "lite/api/cxx_api.h" +#include "lite/api/paddle_use_kernels.h" +#include "lite/api/paddle_use_ops.h" +#include "lite/api/paddle_use_passes.h" +#include "lite/api/test_helper.h" +#include "lite/core/op_registry.h" + +namespace paddle { +namespace lite { + +#ifdef LITE_WITH_ARM +void TestModel(const std::vector& valid_places, + const Place& preferred_place) { + DeviceInfo::Init(); + DeviceInfo::Global().SetRunMode(LITE_POWER_HIGH, FLAGS_threads); + lite::Predictor predictor; + + predictor.Build(FLAGS_model_dir, preferred_place, valid_places); + + auto* input_tensor = predictor.GetInput(0); + input_tensor->Resize(DDim(std::vector({1, 3, 300, 300}))); + auto* data = input_tensor->mutable_data(); + auto item_size = input_tensor->dims().production(); + for (int i = 0; i < item_size; i++) { + data[i] = 1; + } + + for (int i = 0; i < FLAGS_warmup; ++i) { + predictor.Run(); + } + + auto start = GetCurrentUS(); + for (int i = 0; i < FLAGS_repeats; ++i) { + predictor.Run(); + } + + LOG(INFO) << "================== Speed Report ==================="; + LOG(INFO) << "Model: " << FLAGS_model_dir << ", threads num " << FLAGS_threads + << ", warmup: " << FLAGS_warmup << ", repeats: " << FLAGS_repeats + << ", spend " << (GetCurrentUS() - start) / FLAGS_repeats / 1000.0 + << " ms in average."; + + std::vector> results; + // i = 1 + results.emplace_back(std::vector( + {3, 0.042103, 0.00439525, 0.0234783, 1.01127, 0.990756})); + results.emplace_back(std::vector( + {5, 0.0145793, 0.00860882, 0.0344975, 1.01375, 1.00129})); + results.emplace_back(std::vector( + {8, 0.560059, 0.00439525, 0.0234783, 1.01127, 0.990756})); + results.emplace_back(std::vector( + {9, 0.0165109, -0.0020006, 0.0013622, 0.999179, 0.991846})); + results.emplace_back(std::vector( + {12, 0.0263337, -0.0020006, 0.0013622, 0.999179, 0.991846})); + results.emplace_back(std::vector( + {15, 0.0116742, 0.00580454, 0.0321349, 1.00545, 0.98476})); + results.emplace_back(std::vector( + {17, 0.0405541, 0.00860882, 0.0344975, 1.01375, 1.00129})); + results.emplace_back(std::vector( + {18, 0.0231487, -0.00245976, 0.00771075, 1.01654, 1.00395})); + results.emplace_back(std::vector( + {19, 0.0133921, 0.00860882, 0.0344975, 1.01375, 1.00129})); + results.emplace_back(std::vector( + {20, 0.039664, 0.00860882, 0.0344975, 1.01375, 1.00129})); + + auto* out = predictor.GetOutput(0); + ASSERT_EQ(out->dims().size(), 2); + ASSERT_EQ(out->dims()[0], 10); + 
ASSERT_EQ(out->dims()[1], 6); + ASSERT_EQ(out->lod().size(), 1); + ASSERT_EQ(out->lod()[0].size(), 2); + ASSERT_EQ(out->lod()[0][0], 0); + ASSERT_EQ(out->lod()[0][1], 10); + + for (int i = 0; i < results.size(); ++i) { + for (int j = 0; j < results[i].size(); ++j) { + EXPECT_NEAR( + out->data()[j + (out->dims()[1] * i)], results[i][j], 5e-6); + } + } +} + +TEST(MobileNetV1_SSD, test_arm) { + std::vector valid_places({ + Place{TARGET(kHost), PRECISION(kFloat)}, + Place{TARGET(kARM), PRECISION(kFloat)}, + }); + + TestModel(valid_places, Place({TARGET(kARM), PRECISION(kFloat)})); +} + +#endif // LITE_WITH_ARM + +} // namespace lite +} // namespace paddle diff --git a/lite/api/mobilenetv1_test.cc b/lite/api/mobilenetv1_test.cc new file mode 100644 index 00000000000..fb40ccf7c6e --- /dev/null +++ b/lite/api/mobilenetv1_test.cc @@ -0,0 +1,152 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include "lite/api/cxx_api.h" +#include "lite/api/paddle_use_kernels.h" +#include "lite/api/paddle_use_ops.h" +#include "lite/api/paddle_use_passes.h" +#include "lite/api/test_helper.h" +#include "lite/core/op_registry.h" + +DEFINE_string(optimized_model, "", "optimized_model"); + +namespace paddle { +namespace lite { + +void TestModel(const std::vector& valid_places, + const Place& preferred_place, + const std::string& model_dir = FLAGS_model_dir, + bool gen_npu = false, + bool save_model = false) { + DeviceInfo::Init(); + DeviceInfo::Global().SetRunMode(LITE_POWER_HIGH, FLAGS_threads); + lite::Predictor predictor; + + predictor.Build(model_dir, preferred_place, valid_places); + + auto* input_tensor = predictor.GetInput(0); + input_tensor->Resize(DDim(std::vector({1, 3, 224, 224}))); + auto* data = input_tensor->mutable_data(); + auto item_size = input_tensor->dims().production(); + for (int i = 0; i < item_size; i++) { + data[i] = 1; + } + + if (gen_npu) { + predictor.GenNPURuntimeProgram(); + } + + for (int i = 0; i < FLAGS_warmup; ++i) { + predictor.Run(); + } + + auto start = GetCurrentUS(); + for (int i = 0; i < FLAGS_repeats; ++i) { + predictor.Run(); + } + + if (save_model) { + LOG(INFO) << "Save optimized model to " << FLAGS_optimized_model; + predictor.SaveModel(FLAGS_optimized_model); + } + + LOG(INFO) << "================== Speed Report ==================="; + LOG(INFO) << "Model: " << model_dir << ", threads num " << FLAGS_threads + << ", warmup: " << FLAGS_warmup << ", repeats: " << FLAGS_repeats + << ", spend " << (GetCurrentUS() - start) / FLAGS_repeats / 1000.0 + << " ms in average."; + + std::vector> ref; + ref.emplace_back(std::vector( + {0.00019130898, 9.467885e-05, 0.00015971427, 0.0003650665, + 0.00026431272, 0.00060884043, 0.0002107942, 0.0015819625, + 0.0010323516, 0.00010079765, 0.00011006987, 0.0017364529, + 0.0048292773, 0.0013995157, 0.0018453331, 0.0002428986, + 0.00020211363, 0.00013668182, 0.0005855956, 0.00025901722})); + auto* out = predictor.GetOutput(0); + const auto* 
pdata = out->data<float>();
+  int step = 50;
+#ifdef LITE_WITH_NPU
+  ASSERT_EQ(out->dims().production(), 1000);
+  double eps = 0.1;
+  for (int i = 0; i < ref.size(); ++i) {
+    for (int j = 0; j < ref[i].size(); ++j) {
+      auto result = pdata[j * step + (out->dims()[1] * i)];
+      auto diff = std::fabs((result - ref[i][j]) / ref[i][j]);
+      VLOG(3) << diff;
+      EXPECT_LT(diff, eps);
+    }
+  }
+#else
+  ASSERT_EQ(out->dims().size(), 2);
+  ASSERT_EQ(out->dims()[0], 1);
+  ASSERT_EQ(out->dims()[1], 1000);
+  double eps = 1e-6;
+  for (int i = 0; i < ref.size(); ++i) {
+    for (int j = 0; j < ref[i].size(); ++j) {
+      auto result = pdata[j * step + (out->dims()[1] * i)];
+      EXPECT_NEAR(result, ref[i][j], eps);
+    }
+  }
+#endif
+}
+
+#ifdef LITE_WITH_NPU
+TEST(MobileNetV1, test_npu) {
+  std::vector<Place> valid_places({
+      Place{TARGET(kHost), PRECISION(kFloat)},
+      Place{TARGET(kARM), PRECISION(kFloat)},
+      Place{TARGET(kNPU), PRECISION(kFloat)},
+  });
+
+  TestModel(valid_places,
+            Place({TARGET(kARM), PRECISION(kFloat)}),
+            FLAGS_model_dir,
+            true /* gen_npu */,
+            true /* save_model */);
+
+  TestModel(valid_places,
+            Place({TARGET(kARM), PRECISION(kFloat)}),
+            FLAGS_optimized_model,
+            false /* gen_npu */,
+            false /* save_model */);
+}
+#endif  // LITE_WITH_NPU
+
+TEST(MobileNetV1, test_arm) {
+  std::vector<Place> valid_places({
+      Place{TARGET(kHost), PRECISION(kFloat)},
+      Place{TARGET(kARM), PRECISION(kFloat)},
+  });
+
+  TestModel(valid_places, Place({TARGET(kARM), PRECISION(kFloat)}));
+}
+
+#ifdef LITE_WITH_OPENCL
+TEST(MobileNetV1, test_opencl) {
+  std::vector<Place> valid_places({
+      Place{TARGET(kHost), PRECISION(kFloat)},
+      Place{TARGET(kARM), PRECISION(kFloat)},
+      Place{TARGET(kOpenCL), PRECISION(kFloat)},
+  });
+
+  TestModel(valid_places, Place({TARGET(kOpenCL), PRECISION(kFloat)}));
+}
+#endif  // LITE_WITH_OPENCL
+
+}  // namespace lite
+}  // namespace paddle
diff --git a/lite/api/mobilenetv1_yolov3_test.cc b/lite/api/mobilenetv1_yolov3_test.cc
new file mode 100644
index 00000000000..ec373fb115d
--- /dev/null
+++ b/lite/api/mobilenetv1_yolov3_test.cc
@@ -0,0 +1,119 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
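A note on the two verification branches in mobilenetv1_test.cc above: the NPU build compares each sampled output against the fp32 reference with a 10% relative tolerance (accelerator arithmetic drifts), while the CPU build demands near-exact agreement via EXPECT_NEAR. A minimal sketch of the relative check, where `ExpectNearRelative` is a hypothetical helper name and the zero-denominator guard is an added assumption (the test above divides by the reference directly):

```cpp
#include <algorithm>
#include <cmath>
#include <gtest/gtest.h>

// Hypothetical helper mirroring the NPU branch above: fail when the
// relative error against the fp32 reference exceeds max_rel_diff.
void ExpectNearRelative(float result, float ref, double max_rel_diff) {
  // Guard the division for near-zero references (an added assumption).
  const double denom =
      std::max(std::fabs(static_cast<double>(ref)), 1e-12);
  EXPECT_LT(std::fabs(result - ref) / denom, max_rel_diff);
}

// Usage equivalent to the NPU branch: ExpectNearRelative(pdata[k], ref_k, 0.1);
```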
+ +#include +#include +#include +#include "lite/api/cxx_api.h" +#include "lite/api/paddle_use_kernels.h" +#include "lite/api/paddle_use_ops.h" +#include "lite/api/paddle_use_passes.h" +#include "lite/api/test_helper.h" +#include "lite/core/op_registry.h" + +namespace paddle { +namespace lite { + +#ifdef LITE_WITH_ARM +void TestModel(const std::vector& valid_places, + const Place& preferred_place) { + DeviceInfo::Init(); + DeviceInfo::Global().SetRunMode(LITE_POWER_HIGH, FLAGS_threads); + lite::Predictor predictor; + + predictor.Build(FLAGS_model_dir, preferred_place, valid_places); + + auto* input_tensor = predictor.GetInput(0); + input_tensor->Resize(DDim(std::vector({1, 3, 608, 608}))); + auto* data = input_tensor->mutable_data(); + auto item_size = input_tensor->dims().production(); + for (int i = 0; i < item_size; i++) { + data[i] = 50; + } + + auto* img_size = predictor.GetInput(1); + img_size->Resize(DDim(std::vector({1, 2}))); + auto* size_data = img_size->mutable_data(); + size_data[0] = 608; + size_data[1] = 608; + + for (int i = 0; i < FLAGS_warmup; ++i) { + predictor.Run(); + } + + auto start = GetCurrentUS(); + for (int i = 0; i < FLAGS_repeats; ++i) { + predictor.Run(); + } + + LOG(INFO) << "================== Speed Report ==================="; + LOG(INFO) << "Model: " << FLAGS_model_dir << ", threads num " << FLAGS_threads + << ", warmup: " << FLAGS_warmup << ", repeats: " << FLAGS_repeats + << ", spend " << (GetCurrentUS() - start) / FLAGS_repeats / 1000.0 + << " ms in average."; + + std::vector> results; + // i = 1 + results.emplace_back(std::vector( + {0., 0.7803235, 577.7447, 592.5643, 582.15314, 597.3399})); + results.emplace_back(std::vector( + {0., 0.7643098, 473.50653, 592.58966, 478.26117, 597.2353})); + results.emplace_back(std::vector( + {0., 0.7614112, 593.06946, 591.99646, 598.64087, 597.553})); + results.emplace_back(std::vector( + {0., 0.7579255, 161.40321, 592.61694, 166.33885, 597.28406})); + results.emplace_back(std::vector( + {0., 0.7569634, 193.39563, 592.62164, 198.35269, 597.2968})); + results.emplace_back(std::vector( + {0., 0.7568337, 297.3981, 592.62024, 302.35202, 597.2969})); + results.emplace_back(std::vector( + {0., 0.7568283, 265.39816, 592.6203, 270.35214, 597.29694})); + results.emplace_back(std::vector( + {0., 0.74383223, 33.430492, 592.7017, 38.453976, 597.4267})); + results.emplace_back(std::vector( + {0., 0.66492873, 9.396143, 576.7084, 15.35708, 581.8059})); + results.emplace_back(std::vector( + {0., 0.6568178, 9.970305, 145.12535, 15.043035, 149.76646})); + + auto* out = predictor.GetOutput(0); + ASSERT_EQ(out->dims().size(), 2); + ASSERT_EQ(out->dims()[0], 100); + ASSERT_EQ(out->dims()[1], 6); + ASSERT_EQ(out->lod().size(), 1); + ASSERT_EQ(out->lod()[0].size(), 2); + ASSERT_EQ(out->lod()[0][0], 0); + ASSERT_EQ(out->lod()[0][1], 100); + + int skip = 10; + for (int i = 0; i < results.size(); i += skip) { + for (int j = 0; j < results[i].size(); ++j) { + EXPECT_NEAR( + out->data()[j + (out->dims()[1] * i)], results[i][j], 3e-6); + } + } +} + +TEST(MobileNetV1_YoloV3, test_arm) { + std::vector valid_places({ + Place{TARGET(kHost), PRECISION(kFloat)}, + Place{TARGET(kARM), PRECISION(kFloat)}, + }); + + TestModel(valid_places, Place({TARGET(kARM), PRECISION(kFloat)})); +} + +#endif // LITE_WITH_ARM + +} // namespace lite +} // namespace paddle diff --git a/lite/api/mobilenetv2_test.cc b/lite/api/mobilenetv2_test.cc new file mode 100644 index 00000000000..380d6a1fb58 --- /dev/null +++ b/lite/api/mobilenetv2_test.cc @@ -0,0 +1,154 @@ +// 
Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include "lite/api/cxx_api.h" +#include "lite/api/paddle_use_kernels.h" +#include "lite/api/paddle_use_ops.h" +#include "lite/api/paddle_use_passes.h" +#include "lite/api/test_helper.h" +#include "lite/core/op_registry.h" + +DEFINE_string(optimized_model, "", "optimized_model"); + +namespace paddle { +namespace lite { + +#ifdef LITE_WITH_ARM +void TestModel(const std::vector& valid_places, + const Place& preferred_place, + const std::string& model_dir = FLAGS_model_dir, + bool gen_npu = false, + bool save_model = false) { + DeviceInfo::Init(); + DeviceInfo::Global().SetRunMode(LITE_POWER_HIGH, FLAGS_threads); + lite::Predictor predictor; + + predictor.Build(model_dir, preferred_place, valid_places); + + auto* input_tensor = predictor.GetInput(0); + input_tensor->Resize(DDim(std::vector({1, 3, 224, 224}))); + auto* data = input_tensor->mutable_data(); + auto item_size = input_tensor->dims().production(); + for (int i = 0; i < item_size; i++) { + data[i] = 1; + } + + if (gen_npu) { + predictor.GenNPURuntimeProgram(); + } + + for (int i = 0; i < FLAGS_warmup; ++i) { + predictor.Run(); + } + + auto start = GetCurrentUS(); + for (int i = 0; i < FLAGS_repeats; ++i) { + predictor.Run(); + } + + if (save_model) { + LOG(INFO) << "Save optimized model to " << FLAGS_optimized_model; + predictor.SaveModel(FLAGS_optimized_model); + } + + LOG(INFO) << "================== Speed Report ==================="; + LOG(INFO) << "Model: " << model_dir << ", threads num " << FLAGS_threads + << ", warmup: " << FLAGS_warmup << ", repeats: " << FLAGS_repeats + << ", spend " << (GetCurrentUS() - start) / FLAGS_repeats / 1000.0 + << " ms in average."; + + std::vector> ref; + // i = 1 + ref.emplace_back(std::vector( + {0.00017082224, 5.699624e-05, 0.000260885, 0.00016412718, + 0.00034818667, 0.00015230637, 0.00032959113, 0.0014772735, + 0.0009059976, 9.5378724e-05, 5.386537e-05, 0.0006427285, + 0.0070957416, 0.0016094646, 0.0018807327, 0.00010506048, + 6.823785e-05, 0.00012269315, 0.0007806194, 0.00022354358})); + auto* out = predictor.GetOutput(0); + const auto* pdata = out->data(); + int step = 50; +#ifdef LITE_WITH_NPU + ASSERT_EQ(out->dims().production(), 1000); + double eps = 0.1; + for (int i = 0; i < ref.size(); ++i) { + for (int j = 0; j < ref[i].size(); ++j) { + auto result = pdata[j * step + (out->dims()[1] * i)]; + auto diff = std::fabs((result - ref[i][j]) / ref[i][j]); + VLOG(3) << diff; + EXPECT_LT(diff, eps); + } + } +#else + ASSERT_EQ(out->dims().size(), 2); + ASSERT_EQ(out->dims()[0], 1); + ASSERT_EQ(out->dims()[1], 1000); + for (int i = 0; i < ref.size(); ++i) { + for (int j = 0; j < ref[i].size(); ++j) { + EXPECT_NEAR(pdata[j * step + (out->dims()[1] * i)], ref[i][j], 1e-6); + } + } +#endif +} + +#ifdef LITE_WITH_NPU +TEST(MobileNetV2, test_npu) { + std::vector valid_places({ + Place{TARGET(kHost), PRECISION(kFloat)}, + Place{TARGET(kARM), 
PRECISION(kFloat)}, + Place{TARGET(kNPU), PRECISION(kFloat)}, + }); + + TestModel(valid_places, + Place({TARGET(kARM), PRECISION(kFloat)}), + FLAGS_model_dir, + true /* gen_npu */, + true /* save_model*/); + + TestModel(valid_places, + Place({TARGET(kARM), PRECISION(kFloat)}), + FLAGS_optimized_model, + false /* gen_npu */, + false /* save model */); +} +#endif // LITE_WITH_NPU + +TEST(MobileNetV2, test_arm) { + std::vector valid_places({ + Place{TARGET(kHost), PRECISION(kFloat)}, + Place{TARGET(kARM), PRECISION(kFloat)}, + }); + + TestModel(valid_places, Place({TARGET(kARM), PRECISION(kFloat)})); +} + +#ifdef LITE_WITH_OPENCL +TEST(MobileNetV2, test_opencl) { + std::vector valid_places({ + Place{TARGET(kHost), PRECISION(kFloat)}, + Place{TARGET(kARM), PRECISION(kFloat)}, + Place{TARGET(kOpenCL), PRECISION(kFloat)}, + }); + + TestModel(valid_places, Place({TARGET(kOpenCL), PRECISION(kFloat)})); +} +#endif // LITE_WITH_OPENCL + +#endif // LITE_WITH_ARM + +} // namespace lite +} // namespace paddle diff --git a/lite/api/model_optimize_tool.cc b/lite/api/model_optimize_tool.cc new file mode 100644 index 00000000000..cb29d1b8fd3 --- /dev/null +++ b/lite/api/model_optimize_tool.cc @@ -0,0 +1,93 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#ifdef PADDLE_WITH_TESTING +#include +#endif +#include "lite/api/paddle_api.h" +#include "lite/api/paddle_use_kernels.h" +#include "lite/api/paddle_use_ops.h" +#include "lite/api/paddle_use_passes.h" +#include "lite/utils/cp_logging.h" +#include "lite/utils/string.h" + +DEFINE_string(model_dir, "", "path of the model"); +DEFINE_string( + optimize_out_type, + "protobuf", + "store type of the output optimized model. 
protobuf/naive_buffer");
+DEFINE_string(optimize_out, "", "path of the output optimized model");
+DEFINE_string(valid_targets,
+              "arm",
+              "The targets this model is optimized for; should be one or "
+              "more of (arm, opencl, x86), split by spaces");
+DEFINE_bool(int8_mode, false, "Enable int8 quantization mode");
+
+namespace paddle {
+namespace lite_api {
+
+void Main() {
+  lite_api::CxxConfig config;
+  config.set_model_dir(FLAGS_model_dir);
+
+  std::vector<Place> valid_places;
+  auto target_reprs = lite::Split(FLAGS_valid_targets, " ");
+  for (auto& target_repr : target_reprs) {
+    if (target_repr == "arm") {
+      valid_places.emplace_back(TARGET(kARM));
+    } else if (target_repr == "opencl") {
+      valid_places.emplace_back(TARGET(kOpenCL));
+    } else if (target_repr == "x86") {
+      valid_places.emplace_back(TARGET(kX86));
+    } else {
+      LOG(FATAL) << lite::string_format(
+          "Wrong target '%s' found; please check the command-line flag "
+          "'valid_targets'",
+          target_repr.c_str());
+    }
+  }
+
+  CHECK(!valid_places.empty())
+      << "At least one target should be set; please set the command-line "
+         "flag 'valid_targets'";
+  if (FLAGS_int8_mode) {
+    LOG(WARNING) << "Int8 mode is only supported by the ARM target";
+    valid_places.push_back(Place{TARGET(kARM), PRECISION(kInt8)});
+    config.set_preferred_place(Place{TARGET(kARM), PRECISION(kInt8)});
+  }
+  config.set_valid_places(valid_places);
+
+  auto predictor = lite_api::CreatePaddlePredictor(config);
+
+  LiteModelType model_type;
+  if (FLAGS_optimize_out_type == "protobuf") {
+    model_type = LiteModelType::kProtobuf;
+  } else if (FLAGS_optimize_out_type == "naive_buffer") {
+    model_type = LiteModelType::kNaiveBuffer;
+  } else {
+    LOG(FATAL) << "Unsupported model type: " << FLAGS_optimize_out_type;
+  }
+
+  predictor->SaveOptimizedModel(FLAGS_optimize_out, model_type);
+}
+
+}  // namespace lite_api
+}  // namespace paddle
+
+int main(int argc, char** argv) {
+  google::ParseCommandLineFlags(&argc, &argv, false);
+  paddle::lite_api::Main();
+  return 0;
+}
diff --git a/lite/api/model_test.cc b/lite/api/model_test.cc
new file mode 100644
index 00000000000..cf350ee0742
--- /dev/null
+++ b/lite/api/model_test.cc
@@ -0,0 +1,181 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
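For orientation, the optimization pass that model_optimize_tool drives above is the same CxxConfig → SaveOptimizedModel flow available to any embedding application. A minimal sketch under assumed paths (`./mobilenet_v1` in, `./mobilenet_v1_opt` out) and an ARM-only target list:

```cpp
#include "lite/api/paddle_api.h"

// Roughly equivalent to: model_optimize_tool --model_dir=./mobilenet_v1
//   --valid_targets=arm --optimize_out_type=naive_buffer
//   --optimize_out=./mobilenet_v1_opt     (paths are illustrative)
void OptimizeForArm() {
  paddle::lite_api::CxxConfig config;
  config.set_model_dir("./mobilenet_v1");
  config.set_valid_places(
      {paddle::lite_api::Place{TARGET(kARM), PRECISION(kFloat)}});
  auto predictor = paddle::lite_api::CreatePaddlePredictor(config);
  // naive_buffer is the format the MobileConfig-based predictor loads.
  predictor->SaveOptimizedModel(
      "./mobilenet_v1_opt", paddle::lite_api::LiteModelType::kNaiveBuffer);
}
```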
+
+#include <gflags/gflags.h>
+#include <gtest/gtest.h>
+#include <vector>
+#include "lite/api/paddle_api.h"
+#include "lite/api/paddle_use_kernels.h"
+#include "lite/api/paddle_use_ops.h"
+#include "lite/api/paddle_use_passes.h"
+#include "lite/api/test_helper.h"
+#include "lite/core/cpu_info.h"
+#include "lite/utils/cp_logging.h"
+#include "lite/utils/string.h"
+
+DEFINE_string(input_shape,
+              "1,3,224,224",
+              "input shapes, separated by colon and comma");
+
+namespace paddle {
+namespace lite_api {
+
+void OutputOptModel(const std::string& load_model_dir,
+                    const std::string& save_optimized_model_dir,
+                    const std::vector<std::vector<int64_t>>& input_shapes) {
+  lite_api::CxxConfig config;
+  config.set_model_dir(load_model_dir);
+  config.set_preferred_place(Place{TARGET(kX86), PRECISION(kFloat)});
+  config.set_valid_places({
+      Place{TARGET(kX86), PRECISION(kFloat)},
+      Place{TARGET(kARM), PRECISION(kFloat)},
+  });
+  auto predictor = lite_api::CreatePaddlePredictor(config);
+
+  // delete old optimized model
+  int ret = system(
+      paddle::lite::string_format("rm -rf %s", save_optimized_model_dir.c_str())
+          .c_str());
+  if (ret == 0) {
+    LOG(INFO) << "delete old optimized model " << save_optimized_model_dir;
+  }
+  predictor->SaveOptimizedModel(save_optimized_model_dir,
+                                LiteModelType::kNaiveBuffer);
+  LOG(INFO) << "Load model from " << load_model_dir;
+  LOG(INFO) << "Save optimized model to " << save_optimized_model_dir;
+}
+
+#ifdef LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
+void Run(const std::vector<std::vector<int64_t>>& input_shapes,
+         const std::string& model_dir,
+         const int repeat,
+         const int thread_num,
+         const int warmup_times = 0) {
+#ifdef LITE_WITH_ARM
+  lite::DeviceInfo::Init();
+  lite::DeviceInfo::Global().SetRunMode(lite::LITE_POWER_HIGH, thread_num);
+#endif
+  lite_api::MobileConfig config;
+  config.set_model_dir(model_dir);
+
+  auto predictor = lite_api::CreatePaddlePredictor(config);
+
+  for (int j = 0; j < input_shapes.size(); ++j) {
+    auto input_tensor = predictor->GetInput(j);
+    input_tensor->Resize(input_shapes[j]);
+    auto input_data = input_tensor->mutable_data<float>();
+    int input_num = 1;
+    for (int i = 0; i < input_shapes[j].size(); ++i) {
+      input_num *= input_shapes[j][i];
+    }
+    for (int i = 0; i < input_num; ++i) {
+      input_data[i] = 1.f;
+    }
+  }
+
+  for (int i = 0; i < warmup_times; ++i) {
+    predictor->Run();
+  }
+
+  auto start = lite::GetCurrentUS();
+  for (int i = 0; i < repeat; ++i) {
+    predictor->Run();
+  }
+  auto end = lite::GetCurrentUS();
+
+  LOG(INFO) << "================== Speed Report ===================";
+  LOG(INFO) << "Model: " << model_dir << ", threads num " << thread_num
+            << ", warmup: " << warmup_times << ", repeats: " << repeat
+            << ", spend " << (end - start) / repeat / 1000.0
+            << " ms on average.";
+
+  auto output = predictor->GetOutput(0);
+  auto out = output->data<float>();
+  LOG(INFO) << "out " << out[0];
+  LOG(INFO) << "out " << out[1];
+  auto output_shape = output->shape();
+  int output_num = 1;
+  for (int i = 0; i < output_shape.size(); ++i) {
+    output_num *= output_shape[i];
+  }
+  LOG(INFO) << "output_num: " << output_num;
+}
+#endif
+
+}  // namespace lite_api
+}  // namespace paddle
+
+int main(int argc, char** argv) {
+  gflags::ParseCommandLineFlags(&argc, &argv, true);
+  if (FLAGS_model_dir == "") {
+    LOG(INFO) << "usage: "
+              << "--model_dir /path/to/your/model";
+    exit(0);
+  }
+  std::string save_optimized_model_dir = FLAGS_model_dir + "opt2";
+
+  auto split_string =
+      [](const std::string& str_in) -> std::vector<std::string> {
+    std::vector<std::string> str_out;
+    std::string tmp_str = str_in;
+    while (!tmp_str.empty()) {
+      size_t next_offset = tmp_str.find(":");
+      str_out.push_back(tmp_str.substr(0, next_offset));
+      if (next_offset == std::string::npos) {
+        break;
+      } else {
+        tmp_str = tmp_str.substr(next_offset + 1);
+      }
+    }
+    return str_out;
+  };
+
+  auto get_shape = [](const std::string& str_shape) -> std::vector<int64_t> {
+    std::vector<int64_t> shape;
+    std::string tmp_str = str_shape;
+    while (!tmp_str.empty()) {
+      int dim = atoi(tmp_str.data());
+      shape.push_back(dim);
+      size_t next_offset = tmp_str.find(",");
+      if (next_offset == std::string::npos) {
+        break;
+      } else {
+        tmp_str = tmp_str.substr(next_offset + 1);
+      }
+    }
+    return shape;
+  };
+
+  LOG(INFO) << "input shapes: " << FLAGS_input_shape;
+  std::vector<std::string> str_input_shapes = split_string(FLAGS_input_shape);
+  std::vector<std::vector<int64_t>> input_shapes;
+  for (int i = 0; i < str_input_shapes.size(); ++i) {
+    LOG(INFO) << "input shape: " << str_input_shapes[i];
+    input_shapes.push_back(get_shape(str_input_shapes[i]));
+  }
+
+  // Output optimized model
+  paddle::lite_api::OutputOptModel(
+      FLAGS_model_dir, save_optimized_model_dir, input_shapes);
+
+#ifdef LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
+  // Run inference using optimized model
+  paddle::lite_api::Run(input_shapes,
+                        save_optimized_model_dir,
+                        FLAGS_repeats,
+                        FLAGS_threads,
+                        FLAGS_warmup);
+#endif
+  return 0;
+}
diff --git a/lite/api/ocr_attention_test.cc b/lite/api/ocr_attention_test.cc
new file mode 100644
index 00000000000..26cdde3ea79
--- /dev/null
+++ b/lite/api/ocr_attention_test.cc
@@ -0,0 +1,115 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
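The two lambdas in model_test.cc above implement a small grammar: colons separate inputs, commas separate dimensions. A standalone sketch of the same parsing with std::getline, where the function names are assumptions and not part of the tool:

```cpp
#include <cstdint>
#include <sstream>
#include <string>
#include <vector>

// Parse "1,3,224,224" into {1, 3, 224, 224}.
std::vector<int64_t> ParseShape(const std::string& group) {
  std::vector<int64_t> shape;
  std::stringstream ss(group);
  std::string dim;
  while (std::getline(ss, dim, ',')) shape.push_back(std::stoll(dim));
  return shape;
}

// Parse "1,3,224,224:1,128" into {{1, 3, 224, 224}, {1, 128}}.
std::vector<std::vector<int64_t>> ParseShapes(const std::string& flag) {
  std::vector<std::vector<int64_t>> shapes;
  std::stringstream ss(flag);
  std::string group;
  while (std::getline(ss, group, ':')) shapes.push_back(ParseShape(group));
  return shapes;
}
```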
+
+#include <gflags/gflags.h>
+#include <gtest/gtest.h>
+#include <vector>
+#include "lite/api/cxx_api.h"
+#include "lite/api/paddle_use_kernels.h"
+#include "lite/api/paddle_use_ops.h"
+#include "lite/api/paddle_use_passes.h"
+#include "lite/api/test_helper.h"
+#include "lite/core/op_registry.h"
+
+namespace paddle {
+namespace lite {
+
+void TestModel(const std::vector<Place>& valid_places,
+               const Place& preferred_place,
+               bool use_npu = false) {
+  DeviceInfo::Init();
+  DeviceInfo::Global().SetRunMode(LITE_POWER_HIGH, FLAGS_threads);
+  lite::Predictor predictor;
+
+  predictor.Build(FLAGS_model_dir, preferred_place, valid_places);
+
+  auto* input_tensor = predictor.GetInput(0);
+  input_tensor->Resize(DDim(std::vector<int64_t>({1, 1, 48, 512})));
+  auto* data = input_tensor->mutable_data<float>();
+  auto item_size = input_tensor->dims().production();
+  for (int i = 0; i < item_size; i++) {
+    data[i] = 1;
+  }
+
+  auto* init_scores = predictor.GetInput(2);
+  init_scores->Resize(DDim(std::vector<int64_t>({1, 1})));
+  auto* data_scores = init_scores->mutable_data<float>();
+  // Zero-fill exactly the elements init_scores owns (1 x 1).
+  auto scores_size = init_scores->dims().production();
+  for (int i = 0; i < scores_size; i++) {
+    data_scores[i] = 0;
+  }
+  auto lod_scores = init_scores->mutable_lod();
+  std::vector<std::vector<uint64_t>> lod_s{{0, 1}, {0, 1}};
+  *lod_scores = lod_s;
+
+  auto* init_ids = predictor.GetInput(1);
+  init_ids->Resize(DDim(std::vector<int64_t>({1, 1})));
+  auto* data_ids = init_ids->mutable_data<float>();
+  auto ids_size = init_ids->dims().production();
+  for (int i = 0; i < ids_size; i++) {
+    data_ids[i] = 0;
+  }
+  auto lod_ids = init_ids->mutable_lod();
+  std::vector<std::vector<uint64_t>> lod_i{{0, 1}, {0, 1}};
+  *lod_ids = lod_i;
+
+  for (int i = 0; i < FLAGS_warmup; ++i) {
+    predictor.Run();
+  }
+
+  auto start = GetCurrentUS();
+  for (int i = 0; i < FLAGS_repeats; ++i) {
+    predictor.Run();
+  }
+
+  LOG(INFO) << "================== Speed Report ===================";
+  LOG(INFO) << "Model: " << FLAGS_model_dir << ", threads num " << FLAGS_threads
+            << ", warmup: " << FLAGS_warmup << ", repeats: " << FLAGS_repeats
+            << ", spend " << (GetCurrentUS() - start) / FLAGS_repeats / 1000.0
+            << " ms on average.";
+
+  // std::vector<std::vector<float>> results;
+  // // i = 1
+  // results.emplace_back(std::vector<float>(
+  //     {0.00019130898, 9.467885e-05, 0.00015971427, 0.0003650665,
+  //      0.00026431272, 0.00060884043, 0.0002107942, 0.0015819625,
+  //      0.0010323516, 0.00010079765, 0.00011006987, 0.0017364529,
+  //      0.0048292773, 0.0013995157, 0.0018453331, 0.0002428986,
+  //      0.00020211363, 0.00013668182, 0.0005855956, 0.00025901722}));
+  // auto* out = predictor.GetOutput(0);
+  // ASSERT_EQ(out->dims().size(), 2);
+  // ASSERT_EQ(out->dims()[0], 1);
+  // ASSERT_EQ(out->dims()[1], 1000);
+  //
+  // int step = 50;
+  // for (int i = 0; i < results.size(); ++i) {
+  //   for (int j = 0; j < results[i].size(); ++j) {
+  //     EXPECT_NEAR(out->data<float>()[j * step + (out->dims()[1] * i)],
+  //                 results[i][j],
+  //                 1e-6);
+  //   }
+  // }
+}
+
+TEST(OcrAttention, test_arm) {
+  std::vector<Place> valid_places({
+      Place{TARGET(kHost), PRECISION(kFloat)},
+      Place{TARGET(kARM), PRECISION(kFloat)},
+  });
+
+  TestModel(valid_places, Place({TARGET(kARM), PRECISION(kFloat)}));
+}
+
+}  // namespace lite
+}  // namespace paddle
diff --git a/lite/api/paddle_api.cc b/lite/api/paddle_api.cc
new file mode 100644
index 00000000000..fee4ebf6dce
--- /dev/null
+++ b/lite/api/paddle_api.cc
@@ -0,0 +1,73 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lite/api/paddle_api.h"
+#include "lite/core/tensor.h"
+
+namespace paddle {
+namespace lite_api {
+
+Tensor::Tensor(void *raw) : raw_tensor_(raw) {}
+
+// TODO(Superjomn) refine this by using another `const void* const_raw`;
+Tensor::Tensor(const void *raw) { raw_tensor_ = const_cast<void *>(raw); }
+
+lite::Tensor *tensor(void *x) { return static_cast<lite::Tensor *>(x); }
+const lite::Tensor *ctensor(void *x) {
+  return static_cast<const lite::Tensor *>(x);
+}
+
+void Tensor::Resize(const shape_t &shape) {
+  tensor(raw_tensor_)->Resize(shape);
+}
+
+template <>
+const float *Tensor::data<float>() const {
+  return ctensor(raw_tensor_)->data<float>();
+}
+template <>
+const int8_t *Tensor::data<int8_t>() const {
+  return ctensor(raw_tensor_)->data<int8_t>();
+}
+
+template <>
+float *Tensor::mutable_data<float>() const {
+  return tensor(raw_tensor_)->mutable_data<float>();
+}
+template <>
+int8_t *Tensor::mutable_data<int8_t>() const {
+  return tensor(raw_tensor_)->mutable_data<int8_t>();
+}
+
+shape_t Tensor::shape() const {
+  return ctensor(raw_tensor_)->dims().Vectorize();
+}
+
+lod_t Tensor::lod() const { return ctensor(raw_tensor_)->lod(); }
+
+void Tensor::SetLoD(const lod_t &lod) { tensor(raw_tensor_)->set_lod(lod); }
+
+void PaddlePredictor::SaveOptimizedModel(const std::string &model_dir,
+                                         LiteModelType model_type) {
+  LOG(FATAL) << "The SaveOptimizedModel API is only supported by the "
+                "CxxConfig predictor.";
+}
+
+template <typename ConfigT>
+std::shared_ptr<PaddlePredictor> CreatePaddlePredictor(const ConfigT &) {
+  return std::shared_ptr<PaddlePredictor>();
+}
+
+}  // namespace lite_api
+}  // namespace paddle
diff --git a/lite/api/paddle_api.h b/lite/api/paddle_api.h
new file mode 100644
index 00000000000..62df111e0ac
--- /dev/null
+++ b/lite/api/paddle_api.h
@@ -0,0 +1,121 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+/*
+ * This file defines PaddlePredictor, the API for lite. It supports multiple
+ * hardware targets, including ARM, X86, OpenCL, and CUDA.
+ */
+
+#ifndef PADDLE_LITE_API_H_  // NOLINT
+#define PADDLE_LITE_API_H_
+#include <memory>
+#include <string>
+#include <vector>
+#include "paddle_place.h"  // NOLINT
+
+namespace paddle {
+namespace lite_api {
+
+using shape_t = std::vector<int64_t>;
+using lod_t = std::vector<std::vector<uint64_t>>;
+
+enum class LiteModelType { kProtobuf = 0, kNaiveBuffer, UNK };
+
+struct LITE_API Tensor {
+  explicit Tensor(void* raw);
+  explicit Tensor(const void* raw);
+
+  void Resize(const shape_t& shape);
+
+  /// Readonly data.
+  template <typename T>
+  const T* data() const;
+
+  template <typename T>
+  T* mutable_data() const;
+
+  /// Shape of the tensor.
+  shape_t shape() const;
+
+  // LoD of the tensor
+  lod_t lod() const;
+
+  // Set LoD of the tensor
+  void SetLoD(const lod_t& lod);
+
+ private:
+  void* raw_tensor_;
+};
+
+/// The PaddlePredictor defines the basic interfaces for different kinds of
+/// predictors.
+class LITE_API PaddlePredictor {
+ public:
+  PaddlePredictor() = default;
+
+  /// Get i-th input.
+  virtual std::unique_ptr<Tensor> GetInput(int i) = 0;
+
+  /// Get i-th output.
+  virtual std::unique_ptr<const Tensor> GetOutput(int i) const = 0;
+
+  virtual void Run() = 0;
+
+  /// Get a readonly tensor; returns null if no tensor named `name` exists.
+  virtual std::unique_ptr<const Tensor> GetTensor(
+      const std::string& name) const = 0;
+
+  /// Persist the optimized model to disk. This API is only supported by
+  /// CxxConfig, and the persisted model can be reused by MobileConfig.
+  virtual void SaveOptimizedModel(
+      const std::string& model_dir,
+      LiteModelType model_type = LiteModelType::kProtobuf);
+
+  virtual ~PaddlePredictor() = default;
+};
+
+/// Base class for all the configs.
+class LITE_API ConfigBase {
+  std::string model_dir_;
+
+ public:
+  void set_model_dir(const std::string& x) { model_dir_ = x; }
+
+  const std::string& model_dir() const { return model_dir_; }
+};
+
+/// CxxConfig is the config for the full-featured predictor.
+class LITE_API CxxConfig : public ConfigBase {
+  Place preferred_place_;
+  std::vector<Place> valid_places_;
+
+ public:
+  void set_preferred_place(const Place& x) { preferred_place_ = x; }
+  void set_valid_places(const std::vector<Place>& x) { valid_places_ = x; }
+
+  const Place& preferred_place() const { return preferred_place_; }
+  const std::vector<Place>& valid_places() const { return valid_places_; }
+};
+
+/// MobileConfig is the config for the lightweight predictor; it skips
+/// IR optimization and other stages that are unnecessary at deployment time.
+class LITE_API MobileConfig : public ConfigBase {};
+
+template <typename ConfigT>
+std::shared_ptr<PaddlePredictor> CreatePaddlePredictor(const ConfigT&);
+
+}  // namespace lite_api
+}  // namespace paddle
+
+#endif  // NOLINT
diff --git a/lite/api/paddle_api_test.cc b/lite/api/paddle_api_test.cc
new file mode 100644
index 00000000000..cc1523f185b
--- /dev/null
+++ b/lite/api/paddle_api_test.cc
@@ -0,0 +1,88 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
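Because the interface above is pure virtual, callers can be written against PaddlePredictor without knowing whether a CxxConfig or a MobileConfig predictor backs it; the test that follows exercises both. A small generic sketch, where the helper name and the all-ones fill are illustrative:

```cpp
#include <cstdint>
#include "lite/api/paddle_api.h"

namespace demo {  // hypothetical consumer code, not part of the patch

// Fill the predictor's first input with ones and run a single inference.
void RunWithOnes(paddle::lite_api::PaddlePredictor* predictor,
                 const paddle::lite_api::shape_t& shape) {
  auto input = predictor->GetInput(0);
  input->Resize(shape);
  int64_t numel = 1;
  for (int64_t d : shape) numel *= d;
  float* data = input->mutable_data<float>();
  for (int64_t i = 0; i < numel; ++i) data[i] = 1.f;
  predictor->Run();
}

}  // namespace demo
```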
+ +#include "lite/api/paddle_api.h" +#include +#include +#include "lite/api/paddle_use_kernels.h" +#include "lite/api/paddle_use_ops.h" +#include "lite/api/paddle_use_passes.h" +#include "lite/utils/cp_logging.h" + +DEFINE_string(model_dir, "", ""); + +namespace paddle { +namespace lite_api { + +TEST(CxxApi, run) { + lite_api::CxxConfig config; + config.set_model_dir(FLAGS_model_dir); + config.set_preferred_place(Place{TARGET(kX86), PRECISION(kFloat)}); + config.set_valid_places({ + Place{TARGET(kX86), PRECISION(kFloat)}, + Place{TARGET(kARM), PRECISION(kFloat)}, + }); + + auto predictor = lite_api::CreatePaddlePredictor(config); + + auto input_tensor = predictor->GetInput(0); + input_tensor->Resize(std::vector({100, 100})); + auto* data = input_tensor->mutable_data(); + for (int i = 0; i < 100 * 100; i++) { + data[i] = i; + } + + predictor->Run(); + + auto output = predictor->GetOutput(0); + auto* out = output->data(); + LOG(INFO) << out[0]; + LOG(INFO) << out[1]; + + EXPECT_NEAR(out[0], 50.2132, 1e-3); + EXPECT_NEAR(out[1], -28.8729, 1e-3); + + predictor->SaveOptimizedModel(FLAGS_model_dir + ".opt2"); + predictor->SaveOptimizedModel(FLAGS_model_dir + ".opt2.naive", + LiteModelType::kNaiveBuffer); +} + +#ifdef LITE_WITH_LIGHT_WEIGHT_FRAMEWORK +TEST(LightApi, run) { + lite_api::MobileConfig config; + config.set_model_dir(FLAGS_model_dir + ".opt2.naive"); + + auto predictor = lite_api::CreatePaddlePredictor(config); + + auto input_tensor = predictor->GetInput(0); + input_tensor->Resize(std::vector({100, 100})); + auto* data = input_tensor->mutable_data(); + for (int i = 0; i < 100 * 100; i++) { + data[i] = i; + } + + predictor->Run(); + + auto output = predictor->GetOutput(0); + auto* out = output->data(); + LOG(INFO) << out[0]; + LOG(INFO) << out[1]; + + EXPECT_NEAR(out[0], 50.2132, 1e-3); + EXPECT_NEAR(out[1], -28.8729, 1e-3); +} +#endif + +} // namespace lite_api +} // namespace paddle diff --git a/lite/api/paddle_lite_factory_helper.h b/lite/api/paddle_lite_factory_helper.h new file mode 100644 index 00000000000..544cd0e3130 --- /dev/null +++ b/lite/api/paddle_lite_factory_helper.h @@ -0,0 +1,37 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +/* + * This file defines some MACROS that explicitly determine the op, kernel, mir + * passes used in the inference lib. 
+ */ +#pragma once + +#define USE_LITE_OP(op_type__) \ + extern int touch_op_##op_type__(); \ + int LITE_OP_REGISTER_FAKE(op_type__) __attribute__((unused)) = \ + touch_op_##op_type__(); + +#define USE_LITE_KERNEL(op_type__, target__, precision__, layout__, alias__) \ + extern int touch_##op_type__##target__##precision__##layout__##alias__(); \ + int op_type__##target__##precision__##layout__##alias__ \ + __attribute__((unused)) = \ + touch_##op_type__##target__##precision__##layout__##alias__(); + +#define USE_MIR_PASS(name__) \ + extern bool mir_pass_registry##name__##_fake(); \ + static bool mir_pass_usage##name__ __attribute__((unused)) = \ + mir_pass_registry##name__##_fake(); + +#define LITE_OP_REGISTER_FAKE(op_type__) op_type__##__registry__ diff --git a/lite/api/paddle_place.cc b/lite/api/paddle_place.cc new file mode 100644 index 00000000000..03252b5430c --- /dev/null +++ b/lite/api/paddle_place.cc @@ -0,0 +1,103 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/api/paddle_place.h" +#include "lite/utils/cp_logging.h" +#include "lite/utils/hash.h" +#include "lite/utils/replace_stl/stream.h" +#include "lite/utils/string.h" + +namespace paddle { +namespace lite_api { + +size_t Place::hash() const { + std::hash h; + size_t hash = h(static_cast(target)); + hash = lite::hash_combine(hash, static_cast(precision)); + hash = lite::hash_combine(hash, static_cast(layout)); + hash = lite::hash_combine(hash, static_cast(device)); + return hash; +} + +bool operator<(const Place& a, const Place& b) { + if (a.target != b.target) return a.target < b.target; + if (a.precision != b.precision) return a.precision < b.precision; + if (a.layout != b.layout) return a.layout < b.layout; + if (a.device != b.device) return a.device < b.device; + return false; +} + +std::string Place::DebugString() const { + STL::stringstream os; + os << TargetToStr(target) << "/" << PrecisionToStr(precision) << "/" + << DataLayoutToStr(layout); + return os.str(); +} + +const std::string& TargetToStr(TargetType target) { + static const std::string target2string[] = { + "unk", "host", "x86", "cuda", "arm", "opencl", "any", "fpga", "npu"}; + auto x = static_cast(target); + CHECK_LT(x, static_cast(TARGET(NUM))); + return target2string[x]; +} + +const std::string& PrecisionToStr(PrecisionType precision) { + static const std::string precision2string[] = { + "unk", "float", "int8_t", "int32_t", "any", "float16", "bool"}; + auto x = static_cast(precision); + CHECK_LT(x, static_cast(PRECISION(NUM))); + return precision2string[x]; +} + +const std::string& DataLayoutToStr(DataLayoutType layout) { + static const std::string datalayout2string[] = {"unk", "NCHW", "any", "NHWC"}; + auto x = static_cast(layout); + CHECK_LT(x, static_cast(DATALAYOUT(NUM))); + return datalayout2string[x]; +} + +const std::string& TargetRepr(TargetType target) { + static const std::string target2string[] = {"kUnk", + "kHost", + "kX86", + "kCUDA", + "kARM", + 
"kOpenCL", + "kAny", + "kFPGA", + "kNPU"}; + auto x = static_cast(target); + CHECK_LT(x, static_cast(TARGET(NUM))); + return target2string[x]; +} + +const std::string& PrecisionRepr(PrecisionType precision) { + static const std::string precision2string[] = { + "kUnk", "kFloat", "kInt8", "kInt32", "kAny", "kFP16", "kBool"}; + auto x = static_cast(precision); + CHECK_LT(x, static_cast(PRECISION(NUM))); + return precision2string[x]; +} + +const std::string& DataLayoutRepr(DataLayoutType layout) { + static const std::string datalayout2string[] = { + "kUnk", "kNCHW", "kAny", "kNHWC"}; + auto x = static_cast(layout); + CHECK_LT(x, static_cast(DATALAYOUT(NUM))); + return datalayout2string[x]; +} + +} // namespace lite_api +} // namespace paddle diff --git a/lite/api/paddle_place.h b/lite/api/paddle_place.h new file mode 100644 index 00000000000..4a75539d3a0 --- /dev/null +++ b/lite/api/paddle_place.h @@ -0,0 +1,153 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include + +// Generic helper definitions for shared library support +#if defined _WIN32 || defined __CYGWIN__ +#define PADDLE_LITE_HELPER_DLL_IMPORT __declspec(dllimport) +#define PADDLE_LITE_HELPER_DLL_EXPORT __declspec(dllexport) +#define PADDLE_LITE_HELPER_DLL_LOCAL +#else +#if __GNUC__ >= 4 +#define PADDLE_LITE_HELPER_DLL_IMPORT __attribute__((visibility("default"))) +#define PADDLE_LITE_HELPER_DLL_EXPORT __attribute__((visibility("default"))) +#else +#define PADDLE_LITE_HELPER_DLL_IMPORT +#define PADDLE_LITE_HELPER_DLL_EXPORT +#endif +#endif + +#ifdef LITE_ON_TINY_PUBLISH +#define LITE_API PADDLE_LITE_HELPER_DLL_EXPORT +#define LITE_API_IMPORT PADDLE_LITE_HELPER_DLL_IMPORT +#else +#define LITE_API +#define LITE_API_IMPORT +#endif + +namespace paddle { +namespace lite_api { + +enum class TargetType : int { + kUnk = 0, + kHost = 1, + kX86 = 2, + kCUDA = 3, + kARM = 4, + kOpenCL = 5, + kFPGA = 7, + kNPU = 8, + kAny = 6, // any target + NUM = 9, // number of fields. +}; +enum class PrecisionType : int { + kUnk = 0, + kFloat = 1, + kInt8 = 2, + kFP16 = 5, + kInt32 = 3, + kAny = 4, // any precision + kBool = 6, + NUM = 7, // number of fields. +}; +enum class DataLayoutType : int { + kUnk = 0, + kNCHW = 1, + kNHWC = 3, + kAny = 2, // any data layout + NUM = 4, // number of fields. 
+}; + +enum class ActivationType : int { + kIndentity = 0, + kRelu = 1, + kRelu6 = 2, + kPRelu = 3, + kLeakyRelu = 4, + kSigmoid = 5, + kTanh = 6, + kSwish = 7 +}; + +static size_t PrecisionTypeLength(PrecisionType type) { + switch (type) { + case PrecisionType::kFloat: + return 4; + case PrecisionType::kInt8: + return 1; + case PrecisionType::kInt32: + return 4; + case PrecisionType::kFP16: + return 2; + default: + return 4; + } +} + +#define TARGET(item__) paddle::lite_api::TargetType::item__ +#define PRECISION(item__) paddle::lite_api::PrecisionType::item__ +#define DATALAYOUT(item__) paddle::lite_api::DataLayoutType::item__ + +const std::string& TargetToStr(TargetType target); + +const std::string& PrecisionToStr(PrecisionType precision); + +const std::string& DataLayoutToStr(DataLayoutType layout); + +const std::string& TargetRepr(TargetType target); + +const std::string& PrecisionRepr(PrecisionType precision); + +const std::string& DataLayoutRepr(DataLayoutType layout); + +/* + * Place specifies the execution context of a Kernel or input/output for a + * kernel. It is used to make the analysis of the MIR more clear and accurate. + */ +struct LITE_API Place { + TargetType target{TARGET(kUnk)}; + PrecisionType precision{PRECISION(kUnk)}; + DataLayoutType layout{DATALAYOUT(kUnk)}; + int16_t device{0}; // device ID + + Place() = default; + Place(TargetType target, + PrecisionType precision = PRECISION(kFloat), + DataLayoutType layout = DATALAYOUT(kNCHW), + int16_t device = 0) + : target(target), precision(precision), layout(layout), device(device) {} + + bool is_valid() const { + return target != TARGET(kUnk) && precision != PRECISION(kUnk) && + layout != DATALAYOUT(kUnk); + } + + size_t hash() const; + + bool operator==(const Place& other) const { + return target == other.target && precision == other.precision && + layout == other.layout && device == other.device; + } + + bool operator!=(const Place& other) const { return !(*this == other); } + + friend bool operator<(const Place& a, const Place& b); + + std::string DebugString() const; +}; + +} // namespace lite_api +} // namespace paddle diff --git a/lite/api/paddle_use_kernels.h b/lite/api/paddle_use_kernels.h new file mode 100644 index 00000000000..d18a86a8a72 --- /dev/null +++ b/lite/api/paddle_use_kernels.h @@ -0,0 +1,180 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +/* + * ATTENTION this header file can only include in .cc file. 
+ */ + +#pragma once +#include "paddle_lite_factory_helper.h" // NOLINT +#ifndef LITE_WITH_FPGA +USE_LITE_KERNEL(feed, kHost, kAny, kAny, def); +USE_LITE_KERNEL(fetch, kHost, kAny, kAny, def); +USE_LITE_KERNEL(reshape, kHost, kAny, kAny, def); +USE_LITE_KERNEL(reshape2, kHost, kAny, kAny, def); +#else +USE_LITE_KERNEL(feed, kFPGA, kFP16, kNHWC, def); +USE_LITE_KERNEL(fetch, kFPGA, kFP16, kNHWC, def); +#endif + +#ifdef LITE_WITH_ARM +USE_LITE_KERNEL(fc, kARM, kFloat, kNCHW, def); +USE_LITE_KERNEL(mul, kARM, kFloat, kNCHW, def); +USE_LITE_KERNEL(scale, kARM, kFloat, kNCHW, def); +USE_LITE_KERNEL(softmax, kARM, kFloat, kNCHW, def); +USE_LITE_KERNEL(lrn, kARM, kFloat, kNCHW, def); +USE_LITE_KERNEL(decode_bboxes, kARM, kFloat, kNCHW, def); +USE_LITE_KERNEL(box_coder, kARM, kFloat, kNCHW, def); +USE_LITE_KERNEL(conv2d, kARM, kFloat, kNCHW, def); +USE_LITE_KERNEL(depthwise_conv2d, kARM, kFloat, kNCHW, def); +USE_LITE_KERNEL(elementwise_add, kARM, kFloat, kNCHW, def); +USE_LITE_KERNEL(elementwise_mul, kARM, kFloat, kNCHW, def); +USE_LITE_KERNEL(elementwise_max, kARM, kFloat, kNCHW, def); +USE_LITE_KERNEL(fusion_elementwise_add_activation, kARM, kFloat, kNCHW, def); +USE_LITE_KERNEL(fusion_elementwise_mul_activation, kARM, kFloat, kNCHW, def); +USE_LITE_KERNEL(fusion_elementwise_max_activation, kARM, kFloat, kNCHW, def); +USE_LITE_KERNEL(split, kARM, kFloat, kNCHW, def); +USE_LITE_KERNEL(dropout, kARM, kFloat, kNCHW, def); +USE_LITE_KERNEL(concat, kARM, kFloat, kNCHW, def); +USE_LITE_KERNEL(pool2d, kARM, kFloat, kNCHW, def); +USE_LITE_KERNEL(relu, kARM, kFloat, kNCHW, def); +USE_LITE_KERNEL(transpose, kARM, kFloat, kNCHW, def); +USE_LITE_KERNEL(transpose2, kARM, kFloat, kNCHW, def); +USE_LITE_KERNEL(batch_norm, kARM, kFloat, kNCHW, def); +USE_LITE_KERNEL(power, kARM, kFloat, kNCHW, def); +USE_LITE_KERNEL(shuffle_channel, kARM, kFloat, kNCHW, def); +USE_LITE_KERNEL(yolo_box, kARM, kFloat, kNCHW, def); +USE_LITE_KERNEL(argmax, kARM, kFloat, kNCHW, def); +USE_LITE_KERNEL(axpy, kARM, kFloat, kNCHW, def); +USE_LITE_KERNEL(leaky_relu, kARM, kFloat, kNCHW, def); +USE_LITE_KERNEL(relu_clipped, kARM, kFloat, kNCHW, def); +USE_LITE_KERNEL(prelu, kARM, kFloat, kNCHW, def); +USE_LITE_KERNEL(sigmoid, kARM, kFloat, kNCHW, def); +USE_LITE_KERNEL(tanh, kARM, kFloat, kNCHW, def); +USE_LITE_KERNEL(swish, kARM, kFloat, kNCHW, def); +USE_LITE_KERNEL(log, kARM, kFloat, kNCHW, def); +USE_LITE_KERNEL(conv2d_transpose, kARM, kFloat, kNCHW, def); +USE_LITE_KERNEL(pad2d, kARM, kFloat, kNCHW, def); +USE_LITE_KERNEL(prior_box, kARM, kFloat, kNCHW, def); +USE_LITE_KERNEL(density_prior_box, kARM, kFloat, kNCHW, def); +USE_LITE_KERNEL(negative, kARM, kFloat, kNCHW, def); +USE_LITE_KERNEL(crop, kARM, kFloat, kNCHW, def); + +USE_LITE_KERNEL(norm, kARM, kFloat, kNCHW, def); +USE_LITE_KERNEL(sequence_softmax, kARM, kFloat, kNCHW, def); +USE_LITE_KERNEL(im2sequence, kARM, kFloat, kNCHW, def); +USE_LITE_KERNEL(bilinear_interp, kARM, kFloat, kNCHW, def); +USE_LITE_KERNEL(nearest_interp, kARM, kFloat, kNCHW, def); +USE_LITE_KERNEL(logical_xor, kARM, kFloat, kNCHW, def); +USE_LITE_KERNEL(logical_and, kARM, kFloat, kNCHW, def); +USE_LITE_KERNEL(less_than, kARM, kFloat, kNCHW, def); +USE_LITE_KERNEL(top_k, kARM, kFloat, kNCHW, def); +USE_LITE_KERNEL(increment, kARM, kFloat, kNCHW, def); +USE_LITE_KERNEL(write_to_array, kARM, kFloat, kNCHW, def); +USE_LITE_KERNEL(read_from_array, kARM, kFloat, kNCHW, def); +USE_LITE_KERNEL(multiclass_nms, kARM, kFloat, kNCHW, def); +USE_LITE_KERNEL(reduce_max, kARM, kFloat, kNCHW, def); 
+USE_LITE_KERNEL(sequence_expand, kARM, kFloat, kNCHW, def); +USE_LITE_KERNEL(sequence_pool, kARM, kFloat, kNCHW, def); +USE_LITE_KERNEL(shape, kARM, kFloat, kNCHW, def); +USE_LITE_KERNEL(fill_constant, kARM, kFloat, kNCHW, def); +USE_LITE_KERNEL(cast, kARM, kFloat, kNCHW, def) +USE_LITE_KERNEL(slice, kARM, kFloat, kNCHW, def) + +USE_LITE_KERNEL(calib, kARM, kInt8, kNCHW, fp32_to_int8); +USE_LITE_KERNEL(calib, kARM, kInt8, kNCHW, int8_to_fp32); +USE_LITE_KERNEL(calib_once, kARM, kInt8, kNCHW, fp32_to_int8); +USE_LITE_KERNEL(calib_once, kARM, kInt8, kNCHW, int8_to_fp32); +USE_LITE_KERNEL(conv2d, kARM, kInt8, kNCHW, int8_out); +USE_LITE_KERNEL(conv2d, kARM, kInt8, kNCHW, fp32_out); +USE_LITE_KERNEL(fc, kARM, kInt8, kNCHW, int8out); +USE_LITE_KERNEL(fc, kARM, kInt8, kNCHW, fp32out); +USE_LITE_KERNEL(gru_unit, kARM, kFloat, kNCHW, def) +USE_LITE_KERNEL(gru, kARM, kFloat, kNCHW, def) +USE_LITE_KERNEL(beam_search_decode, kARM, kFloat, kNCHW, def) +USE_LITE_KERNEL(beam_search, kARM, kFloat, kNCHW, def) +USE_LITE_KERNEL(while, kARM, kFloat, kNCHW, def) +USE_LITE_KERNEL(lod_reset, kARM, kFloat, kNCHW, def) +USE_LITE_KERNEL(lookup_table, kARM, kFloat, kNCHW, def) +USE_LITE_KERNEL(is_empty, kARM, kFloat, kNCHW, def) +#endif + +#ifdef LITE_WITH_X86 +// NOTE all the X86 kernels are disabled temporarily for kernel are changed. +// USE_LITE_KERNEL(relu, kX86, kFloat, kNCHW, def); +// USE_LITE_KERNEL(mul, kX86, kFloat, kNCHW, def); +// USE_LITE_KERNEL(fc, kX86, kFloat, kNCHW, def); +USE_LITE_KERNEL(scale, kX86, kFloat, kNCHW, def); +// USE_LITE_KERNEL(fill_constant, kX86, kFloat, kNCHW, def); +// USE_LITE_KERNEL(square, kX86, kFloat, kNCHW, def); +// USE_LITE_KERNEL(elementwise_sub, kX86, kFloat, kNCHW, def); +// USE_LITE_KERNEL(elementwise_add, kX86, kFloat, kNCHW, def); +// USE_LITE_KERNEL(softmax, kX86, kFloat, kNCHW, def); +// USE_LITE_KERNEL(dropout, kX86, kFloat, kNCHW, def); +// USE_LITE_KERNEL(concat, kX86, kFloat, kNCHW, def); +// USE_LITE_KERNEL(conv2d, kX86, kFloat, kNCHW, def); +// USE_LITE_KERNEL(depthwise_conv2d, kX86, kFloat, kNCHW, def); +// USE_LITE_KERNEL(pool2d, kX86, kFloat, kNCHW, def); +// USE_LITE_KERNEL(batch_norm, kX86, kFloat, kNCHW, def); +#endif + +#ifdef LITE_WITH_CUDA +USE_LITE_KERNEL(mul, kCUDA, kFloat, kNCHW, def); +USE_LITE_KERNEL(io_copy, kCUDA, kAny, kAny, host_to_device); +USE_LITE_KERNEL(io_copy, kCUDA, kAny, kAny, device_to_host); +USE_LITE_KERNEL(io_copy_once, kCUDA, kAny, kAny, host_to_device); +USE_LITE_KERNEL(io_copy_once, kCUDA, kAny, kAny, device_to_host); +#endif + +#ifdef LITE_WITH_OPENCL +USE_LITE_KERNEL(io_copy, kOpenCL, kAny, kAny, host_to_device); +USE_LITE_KERNEL(io_copy, kOpenCL, kAny, kAny, device_to_host); +USE_LITE_KERNEL(io_copy_once, kOpenCL, kAny, kAny, host_to_device); +USE_LITE_KERNEL(io_copy_once, kOpenCL, kAny, kAny, device_to_host); + +USE_LITE_KERNEL(fc, kOpenCL, kFloat, kNCHW, def); +USE_LITE_KERNEL(mul, kOpenCL, kFloat, kNCHW, def); +USE_LITE_KERNEL(elementwise_add, kOpenCL, kFloat, kNCHW, def); +USE_LITE_KERNEL(fusion_elementwise_add_activation, kOpenCL, kFloat, kNCHW, def); +USE_LITE_KERNEL(pool2d, kOpenCL, kFloat, kNCHW, def); +USE_LITE_KERNEL(relu, kOpenCL, kFloat, kNCHW, def); +USE_LITE_KERNEL(depthwise_conv2d, kOpenCL, kFloat, kNCHW, def); +USE_LITE_KERNEL(conv2d, kOpenCL, kFloat, kNCHW, def); +#endif + +#ifdef LITE_WITH_NPU +USE_LITE_KERNEL(graph_op, kNPU, kFloat, kNCHW, def); +#endif +#ifdef LITE_WITH_FPGA +USE_LITE_KERNEL(relu, kFPGA, kFP16, kNHWC, def); +USE_LITE_KERNEL(conv2d, kFPGA, kFP16, kNHWC, def); 
+USE_LITE_KERNEL(elementwise_add, kFPGA, kFP16, kNHWC, def);
+USE_LITE_KERNEL(fusion_elementwise_add_activation, kFPGA, kFP16, kNHWC, def);
+USE_LITE_KERNEL(fc, kFPGA, kFP16, kNHWC, def);
+USE_LITE_KERNEL(pool2d, kFPGA, kFP16, kNHWC, def);
+USE_LITE_KERNEL(scale, kFPGA, kFP16, kNHWC, def);
+USE_LITE_KERNEL(softmax, kFPGA, kFP16, kNHWC, def);
+USE_LITE_KERNEL(io_copy, kFPGA, kAny, kAny, host_to_device);
+USE_LITE_KERNEL(io_copy, kFPGA, kAny, kAny, device_to_host);
+USE_LITE_KERNEL(io_copy_once, kFPGA, kAny, kAny, host_to_device_once);
+USE_LITE_KERNEL(io_copy_once, kFPGA, kAny, kAny, device_to_host_once);
+USE_LITE_KERNEL(calib, kFPGA, kFP16, kNHWC, fp32_to_fp16_fpga);
+USE_LITE_KERNEL(calib, kFPGA, kFP16, kNHWC, fp16_to_fp32_fpga);
+USE_LITE_KERNEL(calib_once, kFPGA, kFP16, kNHWC, fp32_to_fp16_fpga);
+USE_LITE_KERNEL(calib_once, kFPGA, kFP16, kNHWC, fp16_to_fp32_fpga);
+USE_LITE_KERNEL(layout, kFPGA, kAny, kNHWC, hwc_to_chw_fpga_fp16);
+USE_LITE_KERNEL(layout, kFPGA, kAny, kNHWC, chw_to_hwc_fpga_fp16);
+USE_LITE_KERNEL(layout_once, kFPGA, kAny, kNHWC, hwc_to_chw_fpga_fp16);
+USE_LITE_KERNEL(layout_once, kFPGA, kAny, kNHWC, chw_to_hwc_fpga_fp16);
+#endif
diff --git a/lite/api/paddle_use_ops.h b/lite/api/paddle_use_ops.h
new file mode 100644
index 00000000000..d11afb358bc
--- /dev/null
+++ b/lite/api/paddle_use_ops.h
@@ -0,0 +1,106 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+// ATTENTION: this header can only be included in a .cc file.
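To see why these headers exist, it helps to expand one line by hand. Applying the macros from paddle_lite_factory_helper.h, each `USE_LITE_OP(...)` in the list that follows becomes a reference to a "touch" symbol defined beside the op's registration, which stops the linker from discarding that object file in a static build:

```cpp
// Hand expansion of USE_LITE_OP(mul), per the macros in
// paddle_lite_factory_helper.h shown earlier in this patch:
extern int touch_op_mul();   // defined where the mul op registers itself
int mul__registry__ __attribute__((unused)) =
    touch_op_mul();          // the reference forces the symbol to be linked
```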
+ +#include "paddle_lite_factory_helper.h" // NOLINT + +USE_LITE_OP(mul); +USE_LITE_OP(fc); +USE_LITE_OP(relu); +USE_LITE_OP(scale); +USE_LITE_OP(feed); +USE_LITE_OP(lrn); +USE_LITE_OP(decode_bboxes); +USE_LITE_OP(box_coder); +USE_LITE_OP(fetch); +USE_LITE_OP(io_copy); +USE_LITE_OP(io_copy_once); +USE_LITE_OP(elementwise_add) +USE_LITE_OP(elementwise_sub) +USE_LITE_OP(elementwise_mul) +USE_LITE_OP(elementwise_max) +USE_LITE_OP(fusion_elementwise_add_activation) +USE_LITE_OP(fusion_elementwise_mul_activation) +USE_LITE_OP(fusion_elementwise_max_activation) +USE_LITE_OP(square) +USE_LITE_OP(softmax) +USE_LITE_OP(dropout) +USE_LITE_OP(concat) +USE_LITE_OP(conv2d) +USE_LITE_OP(depthwise_conv2d) +USE_LITE_OP(pool2d) +USE_LITE_OP(batch_norm) +USE_LITE_OP(fusion_elementwise_sub_activation) +USE_LITE_OP(transpose) +USE_LITE_OP(transpose2) +USE_LITE_OP(argmax) +USE_LITE_OP(axpy) +USE_LITE_OP(leaky_relu) +USE_LITE_OP(relu_clipped) +USE_LITE_OP(prelu) +USE_LITE_OP(sigmoid) +USE_LITE_OP(tanh) +USE_LITE_OP(swish) +USE_LITE_OP(log) +USE_LITE_OP(conv2d_transpose) +USE_LITE_OP(negative) +USE_LITE_OP(pad2d) +USE_LITE_OP(power) +USE_LITE_OP(shuffle_channel) +USE_LITE_OP(yolo_box) +USE_LITE_OP(bilinear_interp) +USE_LITE_OP(nearest_interp) + +USE_LITE_OP(crop) +USE_LITE_OP(prior_box) +USE_LITE_OP(density_prior_box) +USE_LITE_OP(reshape) +USE_LITE_OP(reshape2) +USE_LITE_OP(split) +USE_LITE_OP(fake_quantize_moving_average_abs_max); +USE_LITE_OP(fake_dequantize_max_abs); +USE_LITE_OP(calib); +USE_LITE_OP(calib_once); +USE_LITE_OP(norm); +USE_LITE_OP(layout); +USE_LITE_OP(layout_once); +USE_LITE_OP(im2sequence); +USE_LITE_OP(sequence_softmax); +USE_LITE_OP(logical_xor); +USE_LITE_OP(logical_and); +USE_LITE_OP(less_than); +USE_LITE_OP(top_k); +USE_LITE_OP(increment); +USE_LITE_OP(write_to_array); +USE_LITE_OP(read_from_array); +USE_LITE_OP(gru_unit) +USE_LITE_OP(gru) +USE_LITE_OP(beam_search_decode) +USE_LITE_OP(beam_search) +USE_LITE_OP(fill_constant) +USE_LITE_OP(while) +USE_LITE_OP(lod_reset) +USE_LITE_OP(lookup_table) +USE_LITE_OP(multiclass_nms) +USE_LITE_OP(graph_op) +USE_LITE_OP(sequence_expand) +USE_LITE_OP(sequence_pool) +USE_LITE_OP(reduce_max) +USE_LITE_OP(is_empty) +USE_LITE_OP(shape) +USE_LITE_OP(slice) +USE_LITE_OP(cast) diff --git a/lite/api/paddle_use_passes.h b/lite/api/paddle_use_passes.h new file mode 100644 index 00000000000..25eb103d9d9 --- /dev/null +++ b/lite/api/paddle_use_passes.h @@ -0,0 +1,38 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
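The three "use" headers in this patch are meant to be pulled into exactly one translation unit of the deploying binary, which is what the test files above do; the op and kernel hooks are non-static globals, so including them from a second .cc would collide at link time. A typical consumer, sketched under that assumption:

```cpp
// main.cc of a deployment target: pull in ops, kernels, and passes once.
#include "lite/api/paddle_api.h"
#include "lite/api/paddle_use_kernels.h"  // NOLINT
#include "lite/api/paddle_use_ops.h"      // NOLINT
#include "lite/api/paddle_use_passes.h"   // NOLINT
```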
+ +#pragma once +#include "paddle_lite_factory_helper.h" // NOLINT + +USE_MIR_PASS(demo); +USE_MIR_PASS(static_kernel_pick_pass); +USE_MIR_PASS(variable_place_inference_pass); +USE_MIR_PASS(type_target_cast_pass); +USE_MIR_PASS(generate_program_pass); +USE_MIR_PASS(subgraph_program_pass); + +USE_MIR_PASS(io_copy_kernel_pick_pass); +USE_MIR_PASS(argument_type_display_pass); +USE_MIR_PASS(runtime_context_assign_pass); +USE_MIR_PASS(graph_visualze); + +USE_MIR_PASS(lite_conv_bn_fuse_pass); +USE_MIR_PASS(lite_fc_fuse_pass); +USE_MIR_PASS(identity_scale_eliminate_pass); +USE_MIR_PASS(lite_conv_elementwise_fuse_pass); +USE_MIR_PASS(lite_conv_activation_fuse_pass); +USE_MIR_PASS(lite_elementwise_add_activation_fuse_pass); +USE_MIR_PASS(lite_quant_dequant_fuse_pass); +USE_MIR_PASS(type_precision_cast_pass); +USE_MIR_PASS(type_layout_cast_pass); diff --git a/lite/api/resnet18_test.cc b/lite/api/resnet18_test.cc new file mode 100644 index 00000000000..ad8248160c8 --- /dev/null +++ b/lite/api/resnet18_test.cc @@ -0,0 +1,87 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include "lite/api/cxx_api.h" +#include "lite/api/paddle_use_kernels.h" +#include "lite/api/paddle_use_ops.h" +#include "lite/api/paddle_use_passes.h" +#include "lite/api/test_helper.h" +#include "lite/core/op_registry.h" + +namespace paddle { +namespace lite { + +#ifdef LITE_WITH_ARM +TEST(ResNet18, test) { + DeviceInfo::Init(); + DeviceInfo::Global().SetRunMode(LITE_POWER_HIGH, FLAGS_threads); + lite::Predictor predictor; + std::vector valid_places({Place{TARGET(kHost), PRECISION(kFloat)}, + Place{TARGET(kARM), PRECISION(kFloat)}}); + + predictor.Build( + FLAGS_model_dir, Place{TARGET(kARM), PRECISION(kFloat)}, valid_places); + + auto* input_tensor = predictor.GetInput(0); + input_tensor->Resize(DDim(std::vector({1, 3, 224, 224}))); + auto* data = input_tensor->mutable_data(); + auto item_size = input_tensor->dims().production(); + for (int i = 0; i < item_size; i++) { + data[i] = 1; + } + + for (int i = 0; i < FLAGS_warmup; ++i) { + predictor.Run(); + } + + auto start = GetCurrentUS(); + for (int i = 0; i < FLAGS_repeats; ++i) { + predictor.Run(); + } + + LOG(INFO) << "================== Speed Report ==================="; + LOG(INFO) << "Model: " << FLAGS_model_dir << ", threads num " << FLAGS_threads + << ", warmup: " << FLAGS_warmup << ", repeats: " << FLAGS_repeats + << ", spend " << (GetCurrentUS() - start) / FLAGS_repeats / 1000.0 + << " ms in average."; + + std::vector> results; + // i = 1 + results.emplace_back(std::vector( + {0.00020891492, 0.00012855872, 0.00019274367, 0.00031139381, + 0.0003184143, 0.00022596598, 0.00025920002, 0.0006651449, + 0.0015664422, 0.0002835265, 0.0001418782, 0.0013916927, + 0.007779476, 0.0020724828, 0.0012296075, 0.00073855236, + 0.00014572912, 0.00025809053, 0.0004427299, 0.00042198936})); + auto* out = predictor.GetOutput(0); + ASSERT_EQ(out->dims().size(), 2); + 
+  ASSERT_EQ(out->dims()[0], 1);
+  ASSERT_EQ(out->dims()[1], 1000);
+
+  int step = 50;
+  for (int i = 0; i < results.size(); ++i) {
+    for (int j = 0; j < results[i].size(); ++j) {
+      EXPECT_NEAR(out->data<float>()[j * step + (out->dims()[1] * i)],
+                  results[i][j],
+                  1e-6);
+    }
+  }
+}
+#endif
+
+}  // namespace lite
+}  // namespace paddle
diff --git a/lite/api/resnet50_test.cc b/lite/api/resnet50_test.cc
new file mode 100644
index 00000000000..75404d173ff
--- /dev/null
+++ b/lite/api/resnet50_test.cc
@@ -0,0 +1,107 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <gflags/gflags.h>
+#include <gtest/gtest.h>
+#include <vector>
+#include "lite/api/cxx_api.h"
+#include "lite/api/paddle_use_kernels.h"
+#include "lite/api/paddle_use_ops.h"
+#include "lite/api/paddle_use_passes.h"
+#include "lite/api/test_helper.h"
+#include "lite/core/op_registry.h"
+
+namespace paddle {
+namespace lite {
+
+#ifdef LITE_WITH_ARM
+void TestModel(const std::vector<Place>& valid_places,
+               const Place& preferred_place) {
+  DeviceInfo::Init();
+  DeviceInfo::Global().SetRunMode(LITE_POWER_HIGH, FLAGS_threads);
+  lite::Predictor predictor;
+
+  predictor.Build(FLAGS_model_dir, preferred_place, valid_places);
+
+  auto* input_tensor = predictor.GetInput(0);
+  input_tensor->Resize(DDim(std::vector<int64_t>({1, 3, 224, 224})));
+  auto* data = input_tensor->mutable_data<float>();
+  auto item_size = input_tensor->dims().production();
+  for (int i = 0; i < item_size; i++) {
+    data[i] = 1;
+  }
+
+  for (int i = 0; i < FLAGS_warmup; ++i) {
+    predictor.Run();
+  }
+
+  auto start = GetCurrentUS();
+  for (int i = 0; i < FLAGS_repeats; ++i) {
+    predictor.Run();
+  }
+
+  LOG(INFO) << "================== Speed Report ===================";
+  LOG(INFO) << "Model: " << FLAGS_model_dir << ", threads num " << FLAGS_threads
+            << ", warmup: " << FLAGS_warmup << ", repeats: " << FLAGS_repeats
+            << ", spend " << (GetCurrentUS() - start) / FLAGS_repeats / 1000.0
+            << " ms in average.";
+
+  std::vector<std::vector<float>> results;
+  // i = 1
+  results.emplace_back(std::vector<float>(
+      {0.00024139918, 0.00020566184, 0.00022418296, 0.00041731037,
+       0.0005366107,  0.00016948722, 0.00028638865, 0.0009257241,
+       0.00072681636, 8.531815e-05,  0.0002129998,  0.0021168243,
+       0.006387163,   0.0037145028,  0.0012812682,  0.00045948103,
+       0.00013535398, 0.0002483765,  0.00076759676, 0.0002773295}));
+  auto* out = predictor.GetOutput(0);
+  ASSERT_EQ(out->dims().size(), 2);
+  ASSERT_EQ(out->dims()[0], 1);
+  ASSERT_EQ(out->dims()[1], 1000);
+
+  int step = 50;
+  for (int i = 0; i < results.size(); ++i) {
+    for (int j = 0; j < results[i].size(); ++j) {
+      EXPECT_NEAR(out->data<float>()[j * step + (out->dims()[1] * i)],
+                  results[i][j],
+                  1e-6);
+    }
+  }
+}
+
+TEST(ResNet50, test_arm) {
+  std::vector<Place> valid_places({
+      Place{TARGET(kHost), PRECISION(kFloat)},
+      Place{TARGET(kARM), PRECISION(kFloat)},
+  });
+
+  TestModel(valid_places, Place({TARGET(kARM), PRECISION(kFloat)}));
+}
+
+#ifdef LITE_WITH_OPENCL
+TEST(ResNet50, test_opencl) {
+  std::vector<Place> valid_places({
+      Place{TARGET(kHost), PRECISION(kFloat)},
+      Place{TARGET(kARM), PRECISION(kFloat)},
+      Place{TARGET(kOpenCL), PRECISION(kFloat)},
+  });
+
+  TestModel(valid_places, Place({TARGET(kOpenCL), PRECISION(kFloat)}));
+}
+#endif  // LITE_WITH_OPENCL
+
+#endif  // LITE_WITH_ARM
+
+}  // namespace lite
+}  // namespace paddle
diff --git a/lite/api/resnet50_test_fpga.cc b/lite/api/resnet50_test_fpga.cc
new file mode 100644
index 00000000000..8a689c276fd
--- /dev/null
+++ b/lite/api/resnet50_test_fpga.cc
@@ -0,0 +1,61 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <gflags/gflags.h>
+#include <gtest/gtest.h>
+#include <vector>
+#include "lite/api/cxx_api.h"
+#include "lite/api/paddle_use_kernels.h"
+#include "lite/api/paddle_use_ops.h"
+#include "lite/api/paddle_use_passes.h"
+#include "lite/api/test_helper.h"
+#include "lite/core/op_registry.h"
+
+namespace paddle {
+namespace lite {
+
+#ifdef LITE_WITH_FPGA
+TEST(ResNet50, test) {
+  lite::Predictor predictor;
+  std::vector<Place> valid_places(
+      {Place{TARGET(kFPGA), PRECISION(kFP16), DATALAYOUT(kNHWC)},
+       Place{TARGET(kHost), PRECISION(kFloat), DATALAYOUT(kNHWC)}});
+
+  predictor.Build(FLAGS_model_dir,
+                  Place{TARGET(kFPGA), PRECISION(kFP16), DATALAYOUT(kNHWC)},
+                  valid_places);
+
+  auto* input_tensor = predictor.GetInput(0);
+  input_tensor->Resize(DDim(std::vector<int64_t>({1, 3, 224, 224})));
+  auto* data = input_tensor->mutable_data<float>();
+  auto item_size = input_tensor->dims().production();
+  for (int i = 0; i < item_size; i++) {
+    data[i] = 1;
+  }
+
+  for (int i = 0; i < FLAGS_warmup; ++i) {
+    predictor.Run();
+  }
+
+  auto start = GetCurrentUS();
+  for (int i = 0; i < FLAGS_repeats; ++i) {
+    predictor.Run();
+  }
+
+  LOG(INFO) << "================== Speed Report ===================";
+}
+#endif
+
+}  // namespace lite
+}  // namespace paddle
diff --git a/lite/api/shufflenetv2_test.cc b/lite/api/shufflenetv2_test.cc
new file mode 100644
index 00000000000..e3b119ec7a3
--- /dev/null
+++ b/lite/api/shufflenetv2_test.cc
@@ -0,0 +1,92 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
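+//
+// The model tests in this directory share one flow (sketch of the code
+// below; the FLAGS_* definitions come from test_helper.h): build a
+// Predictor for FLAGS_model_dir, fill the input with ones, warm up, time
+// the runs, then compare a strided sample of the output against stored
+// reference values:
+//
+//   auto* in = predictor.GetInput(0);
+//   in->Resize(DDim(std::vector<int64_t>({1, 3, 224, 224})));
+//   std::fill_n(in->mutable_data<float>(), in->dims().production(), 1.f);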
+
+#include <gflags/gflags.h>
+#include <gtest/gtest.h>
+#include <vector>
+#include "lite/api/cxx_api.h"
+#include "lite/api/paddle_use_kernels.h"
+#include "lite/api/paddle_use_ops.h"
+#include "lite/api/paddle_use_passes.h"
+#include "lite/api/test_helper.h"
+#include "lite/core/op_registry.h"
+
+namespace paddle {
+namespace lite {
+
+void TestModel(const std::vector<Place>& valid_places,
+               const Place& preferred_place) {
+  DeviceInfo::Init();
+  DeviceInfo::Global().SetRunMode(LITE_POWER_HIGH, FLAGS_threads);
+  lite::Predictor predictor;
+
+  predictor.Build(FLAGS_model_dir, preferred_place, valid_places);
+
+  auto* input_tensor = predictor.GetInput(0);
+  input_tensor->Resize(DDim(std::vector<int64_t>({1, 3, 224, 224})));
+  auto* data = input_tensor->mutable_data<float>();
+  auto item_size = input_tensor->dims().production();
+  for (int i = 0; i < item_size; ++i) {
+    data[i] = 1;
+  }
+
+  for (int i = 0; i < FLAGS_warmup; ++i) {
+    predictor.Run();
+  }
+
+  auto start = GetCurrentUS();
+  for (int i = 0; i < FLAGS_repeats; ++i) {
+    predictor.Run();
+  }
+
+  LOG(INFO) << "================== Speed Report ===================";
+  LOG(INFO) << "Model: " << FLAGS_model_dir << ", threads num " << FLAGS_threads
+            << ", warmup: " << FLAGS_warmup << ", repeats: " << FLAGS_repeats
+            << ", spend " << (GetCurrentUS() - start) / FLAGS_repeats / 1000.0
+            << " ms in average.";
+
+  std::vector<std::vector<float>> results;
+  results.emplace_back(std::vector<float>(
+      {0.00020622103, 9.36264e-05,   0.0002608151,  0.0004974526,
+       0.00028529152, 9.3994095e-05, 0.00028626667, 0.0011567438,
+       0.00094107876, 8.8955254e-05, 4.1932417e-05, 0.00016469292,
+       0.006776762,   0.0028232741,  0.00024495262, 0.00022493803,
+       0.00015700555, 0.00013883937, 0.00093898486, 0.00018184447}));
+  auto* out = predictor.GetOutput(0);
+  ASSERT_EQ(out->dims().size(), 2);
+  ASSERT_EQ(out->dims()[0], 1);
+  ASSERT_EQ(out->dims()[1], 1000);
+
+  int step = 50;
+  for (int i = 0; i < results.size(); ++i) {
+    for (int j = 0; j < results[i].size(); ++j) {
+      EXPECT_NEAR(out->data<float>()[j * step + (out->dims()[1] * i)],
+                  results[i][j],
+                  1e-6);
+    }
+  }
+}
+
+TEST(ShuffleNetV2, test_arm) {
+  std::vector<Place> valid_places({
+      Place{TARGET(kHost), PRECISION(kFloat)},
+      Place{TARGET(kARM), PRECISION(kFloat)},
+      // Place{TARGET(kOpenCL), PRECISION(kFloat)},
+  });
+
+  TestModel(valid_places, Place({TARGET(kARM), PRECISION(kFloat)}));
+}
+
+}  // namespace lite
+}  // namespace paddle
diff --git a/lite/api/test_googlenet_lite.cc b/lite/api/test_googlenet_lite.cc
new file mode 100644
index 00000000000..ffe4d141b3b
--- /dev/null
+++ b/lite/api/test_googlenet_lite.cc
@@ -0,0 +1,80 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
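+//
+// The accuracy check below samples every 51st logit of the 1000-class
+// output (indices 0, 51, 102, ...) and compares against the stored
+// reference values with a 1e-5 tolerance.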
+
+#include <gflags/gflags.h>
+#include <gtest/gtest.h>
+#include <vector>
+#include "lite/api/cxx_api.h"
+#include "lite/api/lite_api_test_helper.h"
+#include "lite/api/paddle_use_kernels.h"
+#include "lite/api/paddle_use_ops.h"
+#include "lite/api/paddle_use_passes.h"
+#include "lite/core/op_registry.h"
+#include "lite/core/tensor.h"
+
+// for googlenet
+DEFINE_string(model_dir, "", "");
+
+namespace paddle {
+namespace lite {
+#ifdef LITE_WITH_X86
+TEST(CXXApi, test_lite_googlenet) {
+  lite::Predictor predictor;
+  std::vector<Place> valid_places({Place{TARGET(kHost), PRECISION(kFloat)},
+                                   Place{TARGET(kX86), PRECISION(kFloat)}});
+
+  predictor.Build(FLAGS_model_dir,
+                  Place{TARGET(kX86), PRECISION(kFloat)},
+                  valid_places);
+
+  auto* input_tensor = predictor.GetInput(0);
+  input_tensor->Resize(DDim(std::vector<int64_t>({1, 3, 224, 224})));
+  auto* data = input_tensor->mutable_data<float>();
+  for (int i = 0; i < input_tensor->dims().production(); i++) {
+    data[i] = 1;
+  }
+  predictor.Run();
+
+  auto* out = predictor.GetOutput(0);
+  std::vector<float> results(
+      {0.00034298553, 0.0008200012,  0.0005046297,  0.000839279,
+       0.00052616704, 0.0003447803,  0.0010877076,  0.00081762316,
+       0.0003941339,  0.0011430943,  0.0008892841,  0.00080191303,
+       0.0004442384,  0.000658702,   0.0026721435,  0.0013686896,
+       0.0005618166,  0.0006556497,  0.0006984528,  0.0014619455});
+  for (size_t i = 0; i < results.size(); ++i) {
+    EXPECT_NEAR(out->data<float>()[i * 51], results[i], 1e-5);
+  }
+  ASSERT_EQ(out->dims().size(), 2);
+  ASSERT_EQ(out->dims()[0], 1);
+  ASSERT_EQ(out->dims()[1], 1000);
+}
+#endif
+}  // namespace lite
+}  // namespace paddle
diff --git a/lite/api/test_helper.h b/lite/api/test_helper.h
new file mode 100644
index 00000000000..1a5ab31abd3
--- /dev/null
+++ b/lite/api/test_helper.h
@@ -0,0 +1,37 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <gflags/gflags.h>
+#include <sys/time.h>
+#include <time.h>
+
+// for eval
+DEFINE_string(model_dir, "", "model dir");
+DEFINE_int32(warmup, 0, "warmup times");
+DEFINE_int32(repeats, 1, "repeats times");
+DEFINE_int32(threads, 1, "threads num");
+
+namespace paddle {
+namespace lite {
+
+inline double GetCurrentUS() {
+  struct timeval time;
+  gettimeofday(&time, NULL);
+  return 1e+6 * time.tv_sec + time.tv_usec;
+}
+
+}  // namespace lite
+}  // namespace paddle
diff --git a/lite/api/test_inceptionv4_lite_x86.cc b/lite/api/test_inceptionv4_lite_x86.cc
new file mode 100644
index 00000000000..5b09fd8f489
--- /dev/null
+++ b/lite/api/test_inceptionv4_lite_x86.cc
@@ -0,0 +1,108 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <gflags/gflags.h>
+#include <gtest/gtest.h>
+#include <vector>
+#include "lite/api/cxx_api.h"
+#include "lite/api/lite_api_test_helper.h"
+#include "lite/api/paddle_use_kernels.h"
+#include "lite/api/paddle_use_ops.h"
+#include "lite/api/paddle_use_passes.h"
+#include "lite/api/test_helper.h"
+#include "lite/core/op_registry.h"
+#include "lite/core/tensor.h"
+
+namespace paddle {
+namespace lite {
+
+TEST(InceptionV4, test_inceptionv4_lite_x86) {
+  lite::Predictor predictor;
+  std::vector<Place> valid_places({Place{TARGET(kHost), PRECISION(kFloat)},
+                                   Place{TARGET(kX86), PRECISION(kFloat)}});
+
+  std::string model_dir = FLAGS_model_dir;
+  std::vector<std::string> passes({"static_kernel_pick_pass",
+                                   "variable_place_inference_pass",
+                                   "type_target_cast_pass",
+                                   "variable_place_inference_pass",
+                                   "io_copy_kernel_pick_pass",
+                                   "variable_place_inference_pass",
+                                   "runtime_context_assign_pass"});
+  predictor.Build(
+      model_dir, Place{TARGET(kX86), PRECISION(kFloat)}, valid_places, passes);
+
+  auto* input_tensor = predictor.GetInput(0);
+  input_tensor->Resize(DDim(std::vector<int64_t>({1, 3, 224, 224})));
+  auto* data = input_tensor->mutable_data<float>();
+  for (int i = 0; i < input_tensor->dims().production(); i++) {
+    data[i] = 1;
+  }
+
+  for (int i = 0; i < FLAGS_warmup; ++i) {
+    predictor.Run();
+  }
+
+  auto start = GetCurrentUS();
+  for (int i = 0; i < FLAGS_repeats; ++i) {
+    predictor.Run();
+  }
+
+  LOG(INFO) << "================== Speed Report ===================";
+  LOG(INFO) << "Model: " << FLAGS_model_dir << ", warmup: " << FLAGS_warmup
+            << ", repeats: " << FLAGS_repeats << ", spend "
+            << (GetCurrentUS() - start) / FLAGS_repeats / 1000.0
+            << " ms in average.";
+
+  std::vector<std::vector<float>> results;
+  // i = 1
+  results.emplace_back(std::vector<float>(
+      {0.0011684548,  0.0010390386,  0.0011301535,  0.0010133048,
+       0.0010259597,  0.0010982729,  0.00093195855, 0.0009141837,
+       0.00096620916, 0.00089982944, 0.0010064574,  0.0010474789,
+       0.0009782845,  0.0009230255,  0.0010548076,  0.0010974824,
+       0.0010612885,  0.00089107914, 0.0010112736,  0.00097655767}));
+
+  auto* out = predictor.GetOutput(0);
+  ASSERT_EQ(out->dims().size(), 2);
+  ASSERT_EQ(out->dims()[0], 1);
+  ASSERT_EQ(out->dims()[1], 1000);
+
+  int step = 50;
+  for (int i = 0; i < results.size(); ++i) {
+    for (int j = 0; j < results[i].size(); ++j) {
+      EXPECT_NEAR(out->data<float>()[j * step + (out->dims()[1] * i)],
+                  results[i][j],
+                  1e-6);
+    }
+  }
+}
+
+}  // namespace lite
+}  // namespace paddle
diff --git a/lite/api/test_mobilenetv1_lite_x86.cc b/lite/api/test_mobilenetv1_lite_x86.cc
new file mode 100644
index 00000000000..84afac598e4
--- /dev/null
+++ b/lite/api/test_mobilenetv1_lite_x86.cc
@@ -0,0 +1,105 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <gflags/gflags.h>
+#include <gtest/gtest.h>
+#include <vector>
+#include "lite/api/cxx_api.h"
+#include "lite/api/lite_api_test_helper.h"
+#include "lite/api/paddle_use_kernels.h"
+#include "lite/api/paddle_use_ops.h"
+#include "lite/api/paddle_use_passes.h"
+#include "lite/api/test_helper.h"
+#include "lite/core/op_registry.h"
+#include "lite/core/tensor.h"
+
+namespace paddle {
+namespace lite {
+
+TEST(Mobilenet_v1, test_mobilenetv1_lite_x86) {
+  lite::Predictor predictor;
+  std::vector<Place> valid_places({Place{TARGET(kHost), PRECISION(kFloat)},
+                                   Place{TARGET(kX86), PRECISION(kFloat)}});
+
+  std::string model_dir = FLAGS_model_dir;
+  std::vector<std::string> passes({"static_kernel_pick_pass",
+                                   "variable_place_inference_pass",
+                                   "type_target_cast_pass",
+                                   "variable_place_inference_pass",
+                                   "io_copy_kernel_pick_pass",
+                                   "variable_place_inference_pass",
+                                   "runtime_context_assign_pass"});
+  predictor.Build(
+      model_dir, Place{TARGET(kX86), PRECISION(kFloat)}, valid_places, passes);
+  auto* input_tensor = predictor.GetInput(0);
+  input_tensor->Resize(DDim(std::vector<int64_t>({1, 3, 224, 224})));
+  auto* data = input_tensor->mutable_data<float>();
+  for (int i = 0; i < input_tensor->dims().production(); i++) {
+    data[i] = 1;
+  }
+
+  for (int i = 0; i < FLAGS_warmup; ++i) {
+    predictor.Run();
+  }
+
+  auto start = GetCurrentUS();
+  for (int i = 0; i < FLAGS_repeats; ++i) {
+    predictor.Run();
+  }
+
+  LOG(INFO) << "================== Speed Report ===================";
+  LOG(INFO) << "Model: " << FLAGS_model_dir << ", warmup: " << FLAGS_warmup
+            << ", repeats: " << FLAGS_repeats << ", spend "
+            << (GetCurrentUS() - start) / FLAGS_repeats / 1000.0
+            << " ms in average.";
+
+  std::vector<std::vector<float>> results;
+  // i = 1
+  results.emplace_back(std::vector<float>(
+      {0.00019130898, 9.467885e-05,  0.00015971427, 0.0003650665,
+       0.00026431272, 0.00060884043, 0.0002107942,  0.0015819625,
+       0.0010323516,  0.00010079765, 0.00011006987, 0.0017364529,
+       0.0048292773,  0.0013995157,  0.0018453331,  0.0002428986,
+       0.00020211363, 0.00013668182, 0.0005855956,  0.00025901722}));
+  auto* out = predictor.GetOutput(0);
+  ASSERT_EQ(out->dims().size(), 2);
+  ASSERT_EQ(out->dims()[0], 1);
+  ASSERT_EQ(out->dims()[1], 1000);
+
+  int step = 50;
+  for (int i = 0; i < results.size(); ++i) {
+    for (int j = 0; j < results[i].size(); ++j) {
+      EXPECT_NEAR(out->data<float>()[j * step + (out->dims()[1] * i)],
+                  results[i][j],
+                  1e-6);
+    }
+  }
+}
+
+}  // namespace lite
+}  // namespace paddle
diff --git a/lite/api/test_mobilenetv2_lite_x86.cc b/lite/api/test_mobilenetv2_lite_x86.cc
new file mode 100644
index 00000000000..fe98ffae7cf
--- /dev/null
+++ b/lite/api/test_mobilenetv2_lite_x86.cc
@@ -0,0 +1,108 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
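+//
+// Like the other x86 tests, this one hands Build() an explicit, reduced
+// pass list (note: none of the lite_*_fuse_pass entries registered in
+// paddle_use_passes.h appear in it), i.e. the call shape is:
+//
+//   predictor.Build(model_dir, preferred_place, valid_places, passes);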
+
+#include <gflags/gflags.h>
+#include <gtest/gtest.h>
+#include <vector>
+#include "lite/api/cxx_api.h"
+#include "lite/api/lite_api_test_helper.h"
+#include "lite/api/paddle_use_kernels.h"
+#include "lite/api/paddle_use_ops.h"
+#include "lite/api/paddle_use_passes.h"
+#include "lite/api/test_helper.h"
+#include "lite/core/op_registry.h"
+#include "lite/core/tensor.h"
+
+namespace paddle {
+namespace lite {
+
+TEST(Mobilenet_v2, test_mobilenetv2_lite_x86) {
+  lite::Predictor predictor;
+  std::vector<Place> valid_places({Place{TARGET(kHost), PRECISION(kFloat)},
+                                   Place{TARGET(kX86), PRECISION(kFloat)}});
+
+  std::string model_dir = FLAGS_model_dir;
+  std::vector<std::string> passes({"static_kernel_pick_pass",
+                                   "variable_place_inference_pass",
+                                   "type_target_cast_pass",
+                                   "variable_place_inference_pass",
+                                   "io_copy_kernel_pick_pass",
+                                   "variable_place_inference_pass",
+                                   "runtime_context_assign_pass"});
+  predictor.Build(
+      model_dir, Place{TARGET(kX86), PRECISION(kFloat)}, valid_places, passes);
+
+  auto* input_tensor = predictor.GetInput(0);
+  input_tensor->Resize(DDim(std::vector<int64_t>({1, 3, 224, 224})));
+  auto* data = input_tensor->mutable_data<float>();
+  for (int i = 0; i < input_tensor->dims().production(); i++) {
+    data[i] = 1;
+  }
+
+  for (int i = 0; i < FLAGS_warmup; ++i) {
+    predictor.Run();
+  }
+
+  auto start = GetCurrentUS();
+  for (int i = 0; i < FLAGS_repeats; ++i) {
+    predictor.Run();
+  }
+
+  LOG(INFO) << "================== Speed Report ===================";
+  LOG(INFO) << "Model: " << FLAGS_model_dir << ", warmup: " << FLAGS_warmup
+            << ", repeats: " << FLAGS_repeats << ", spend "
+            << (GetCurrentUS() - start) / FLAGS_repeats / 1000.0
+            << " ms in average.";
+
+  std::vector<std::vector<float>> results;
+  // i = 1
+  results.emplace_back(std::vector<float>(
+      {0.00017082224, 5.699624e-05,  0.000260885,   0.00016412718,
+       0.00034818667, 0.00015230637, 0.00032959113, 0.0014772735,
+       0.0009059976,  9.5378724e-05, 5.386537e-05,  0.0006427285,
+       0.0070957416,  0.0016094646,  0.0018807327,  0.00010506048,
+       6.823785e-05,  0.00012269315, 0.0007806194,  0.00022354358}));
+  auto* out = predictor.GetOutput(0);
+  ASSERT_EQ(out->dims().size(), 2);
+  ASSERT_EQ(out->dims()[0], 1);
+  ASSERT_EQ(out->dims()[1], 1000);
+
+  int step = 50;
+  for (int i = 0; i < results.size(); ++i) {
+    for (int j = 0; j < results[i].size(); ++j) {
+      EXPECT_NEAR(out->data<float>()[j * step + (out->dims()[1] * i)],
+                  results[i][j],
+                  1e-6);
+    }
+  }
+}
+
+}  // namespace lite
+}  // namespace paddle
diff --git a/lite/api/unet_test.cc b/lite/api/unet_test.cc
new file mode 100644
index 00000000000..e1d8c9ec1e2
--- /dev/null
+++ b/lite/api/unet_test.cc
@@ -0,0 +1,103 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
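+//
+// Unlike the classification tests, this test expects an NCHW segmentation
+// output shaped {1, 21, H, W} (21 presumably being one channel per class)
+// and compares the first 100 floats densely (step = 1) rather than taking
+// a strided sample.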
+
+#include <gflags/gflags.h>
+#include <gtest/gtest.h>
+#include <vector>
+#include "lite/api/cxx_api.h"
+#include "lite/api/paddle_use_kernels.h"
+#include "lite/api/paddle_use_ops.h"
+#include "lite/api/paddle_use_passes.h"
+#include "lite/api/test_helper.h"
+#include "lite/core/op_registry.h"
+
+namespace paddle {
+namespace lite {
+
+#ifdef LITE_WITH_ARM
+TEST(unet, test) {
+  DeviceInfo::Init();
+  DeviceInfo::Global().SetRunMode(LITE_POWER_HIGH, FLAGS_threads);
+  lite::Predictor predictor;
+  std::vector<Place> valid_places({Place{TARGET(kHost), PRECISION(kFloat)},
+                                   Place{TARGET(kARM), PRECISION(kFloat)}});
+
+  predictor.Build(
+      FLAGS_model_dir, Place{TARGET(kARM), PRECISION(kFloat)}, valid_places);
+
+  auto* input_tensor = predictor.GetInput(0);
+  input_tensor->Resize(DDim(std::vector<int64_t>({1, 3, 512, 512})));
+  auto* data = input_tensor->mutable_data<float>();
+  auto item_size = input_tensor->dims().production();
+  for (int i = 0; i < item_size; i++) {
+    data[i] = 1;
+  }
+
+  for (int i = 0; i < FLAGS_warmup; ++i) {
+    predictor.Run();
+  }
+
+  auto start = GetCurrentUS();
+  for (int i = 0; i < FLAGS_repeats; ++i) {
+    predictor.Run();
+  }
+
+  LOG(INFO) << "================== Speed Report ===================";
+  LOG(INFO) << "Model: " << FLAGS_model_dir << ", threads num " << FLAGS_threads
+            << ", warmup: " << FLAGS_warmup << ", repeats: " << FLAGS_repeats
+            << ", spend " << (GetCurrentUS() - start) / FLAGS_repeats / 1000.0
+            << " ms in average.";
+
+  // std::vector<float> results({0.00078033, 0.00083865, 0.00060029, 0.00057083,
+  //                             0.00070094, 0.00080584, 0.00044525, 0.00074907,
+  //                             0.00059774, 0.00063654});
+  //
+  std::vector<std::vector<float>> results;
+  // i = 1
+  results.emplace_back(std::vector<float>(
+      {0.9134332,  0.9652493,  0.959906,   0.96601194, 0.9704161,  0.973321,
+       0.9763035,  0.9788776,  0.98090196, 0.9823532,  0.9830632,  0.98336476,
+       0.9837605,  0.98430413, 0.9848935,  0.9854547,  0.9858877,  0.9862335,
+       0.9865361,  0.9867324,  0.98686767, 0.9870094,  0.98710895, 0.98710257,
+       0.98703253, 0.98695105, 0.98681927, 0.98661137, 0.98637575, 0.98613656,
+       0.9858899,  0.98564225, 0.9853931,  0.9851323,  0.98487836, 0.9846578,
+       0.9844529,  0.9842441,  0.98405427, 0.9839205,  0.98382735, 0.98373055,
+       0.9836299,  0.9835474,  0.9834818,  0.9834427,  0.98343164, 0.9834163,
+       0.9833809,  0.9833255,  0.9832343,  0.9831207,  0.98302484, 0.9829579,
+       0.9829039,  0.98283756, 0.9827444,  0.98264474, 0.9825466,  0.98243505,
+       0.982312,   0.98218083, 0.98203814, 0.981895,   0.9817609,  0.9816264,
+       0.9814932,  0.9813706,  0.98124915, 0.9811211,  0.98099536, 0.9808748,
+       0.98075336, 0.9806301,  0.98050594, 0.98038554, 0.980272,   0.9801562,
+       0.9800356,  0.9799207,  0.9798147,  0.97971845, 0.97963905, 0.9795745,
+       0.9795107,  0.97943753, 0.9793595,  0.97928876, 0.97922987, 0.9791764,
+       0.97912955, 0.9790941,  0.9790663,  0.9790414,  0.9790204,  0.9790055,
+       0.97899526, 0.9789867,  0.9789797,  0.9789748}));
+  auto* out = predictor.GetOutput(0);
+  ASSERT_EQ(out->dims().size(), 4);
+  ASSERT_EQ(out->dims()[0], 1);
+  ASSERT_EQ(out->dims()[1], 21);
+
+  int step = 1;
+  for (int i = 0; i < results.size(); ++i) {
+    for (int j = 0; j < results[i].size(); ++j) {
+      EXPECT_NEAR(out->data<float>()[j * step + (out->dims()[1] * i)],
+                  results[i][j],
+                  1e-6);
+    }
+  }
+}
+#endif
+
+}  // namespace lite
+}  // namespace paddle
diff --git a/lite/arm/CMakeLists.txt b/lite/arm/CMakeLists.txt
new file mode 100644
index 00000000000..8abd04b5233
--- /dev/null
+++ b/lite/arm/CMakeLists.txt
@@ -0,0 +1,2 @@
+
+add_subdirectory(math)
diff --git a/lite/arm/math/CMakeLists.txt b/lite/arm/math/CMakeLists.txt
new file mode 100644
index 00000000000..9924425609d
--- /dev/null
+++ b/lite/arm/math/CMakeLists.txt
@@ -0,0 +1,109 @@
+if(NOT (LITE_WITH_LIGHT_WEIGHT_FRAMEWORK AND LITE_WITH_ARM))
+  return()
+endif()
+
+set(HAS_ARM_MATH_LIB_DIR OFF)
+# searches for a prebuilt library named "libmath_arm.${os}.${abi}.${lang}.a"
+if(ARM_MATH_LIB_DIR AND EXISTS "${ARM_MATH_LIB_DIR}")
+  set(arm_math_name "")
+  if(ARM_TARGET_OS STREQUAL "android")
+    if(ARM_TARGET_ARCH_ABI STREQUAL "armv8")
+      set(arm_math_name "math_arm.android.armv8")
+    elseif(ARM_TARGET_ARCH_ABI STREQUAL "armv7")
+      set(arm_math_name "math_arm.android.armv7")
+    endif()
+  endif()
+
+  if(ARM_TARGET_OS STREQUAL "armlinux")
+    if(ARM_TARGET_ARCH_ABI STREQUAL "armv8")
+      set(arm_math_name "math_arm.armlinux.armv8")
+    elseif(ARM_TARGET_ARCH_ABI STREQUAL "armv7")
+      set(arm_math_name "math_arm.armlinux.armv7")
+    endif()
+  endif()
+
+  if(ARM_TARGET_LANG STREQUAL "clang")
+    set(arm_math_name "${arm_math_name}.clang")
+  else()
+    set(arm_math_name "${arm_math_name}.gcc")
+  endif()
+
+  find_library(math_arm_file ${arm_math_name} ${ARM_MATH_LIB_DIR} NO_DEFAULT_PATH)
+  if(math_arm_file)
+    add_library(math_arm STATIC IMPORTED GLOBAL)
+    set_property(TARGET math_arm PROPERTY IMPORTED_LOCATION ${math_arm_file})
+    message(STATUS "ARM math library imported: ${math_arm_file}")
+    set(HAS_ARM_MATH_LIB_DIR ON)
+  else()
+    message(WARNING "Cannot find the ARM math library ${arm_math_name} in ${ARM_MATH_LIB_DIR}")
+  endif()
+endif()
+
+
+if (NOT HAS_ARM_MATH_LIB_DIR)
+  # TODO(xxx): separate these sources and drop the dependency on proto, eigen3
+  cc_library(math_arm SRCS
+    funcs.cc
+    packed_sgemm.cc
+    sgemm.cc
+    softmax.cc
+    scale.cc
+    pooling.cc
+    elementwise.cc
+    lrn.cc
+    decode_bboxes.cc
+    multiclass_nms.cc
+    concat.cc
+    sgemv.cc
+    type_trans.cc
+    box_coder.cc
+    conv_impl.cc
+    conv_direct_3x3s1.cc
+    conv_direct_3x3s2.cc
+    conv_direct.cc
+    conv_depthwise_3x3_int7.cc
+    conv_depthwise_3x3_int8.cc
+    conv_depthwise_5x5s1_int8.cc
+    conv_depthwise_3x3p0.cc
+    conv_depthwise_3x3p1.cc
+    conv_depthwise_5x5s1.cc
+    conv_depthwise_5x5s2.cc
+    conv_depthwise.cc
+    conv_gemmlike.cc
+    conv_winograd_3x3.cc
+    conv_winograd.cc
+    split.cc
+    shuffle_channel.cc
+    activation.cc
+    yolo_box.cc
+    dropout.cc
+    gemm_prepacked_int8.cc
+    gemv_arm_int8.cc
+    conv3x3s1_direct_int8.cc
+    conv3x3s2_direct_int8.cc
+    power.cc
+    interpolate.cc
+    argmax.cc
+    axpy.cc
+    fill_bias_relu.cc
+    col_im_transform.cc
+    im2sequence.cc
+    prior_box.cc
+    sequence_softmax.cc
+    norm.cc
+    topk.cc
+    increment.cc
+    pad2d.cc
+    negative.cc
+    beam_search.cc
+    reduce_max.cc
+    sequence_pool.cc
+    sequence_expand.cc
+    slice.cc
+    DEPS ${lite_kernel_deps})
+endif()
+
diff --git a/lite/arm/math/activation.cc b/lite/arm/math/activation.cc
new file mode 100644
index 00000000000..b5df8e793c7
--- /dev/null
+++ b/lite/arm/math/activation.cc
@@ -0,0 +1,638 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
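+//
+// All kernels in this file split the flat input across threads the same way
+// (sketch mirroring the code below; the ReLU-family kernels unroll by 16
+// floats per NEON iteration, the sigmoid/tanh/swish/log kernels by 4):
+//
+//   int nums_per_thread = size / threads;           // per-thread chunk
+//   int remain = size - threads * nums_per_thread;  // global tail
+//   int neon_loop_cnt = nums_per_thread >> 4;       // 16 floats / iter
+//   int neon_loop_remain = nums_per_thread - (neon_loop_cnt << 4);
+//
+// Thread i handles [i * nums_per_thread, (i + 1) * nums_per_thread); the
+// global remainder is processed serially after the parallel region.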
+ +#include "lite/arm/math/activation.h" +#include +#include "lite/arm/math/funcs.h" + +namespace paddle { +namespace lite { +namespace arm { +namespace math { + +template <> +void act_relu(const float* din, float* dout, int size, int threads) { + int nums_per_thread = size / threads; + int remain = size - threads * nums_per_thread; + int neon_loop_cnt = nums_per_thread >> 4; + int neon_loop_remain = nums_per_thread - (neon_loop_cnt << 4); + float32x4_t vzero = vdupq_n_f32(0.f); +#pragma omp parallel for + for (int i = 0; i < threads; ++i) { + const float* ptr_in_thread = din + i * nums_per_thread; + float* ptr_out_thread = dout + i * nums_per_thread; + int cnt = neon_loop_cnt; +#ifdef __aarch64__ + for (int num = 0; num < neon_loop_cnt; ++num) { + float32x4_t vr0 = vld1q_f32(ptr_in_thread); + ptr_in_thread += 4; + float32x4_t vr1 = vld1q_f32(ptr_in_thread); + ptr_in_thread += 4; + float32x4_t vr2 = vld1q_f32(ptr_in_thread); + ptr_in_thread += 4; + float32x4_t vr3 = vld1q_f32(ptr_in_thread); + ptr_in_thread += 4; + vr0 = vmaxq_f32(vr0, vzero); + vr1 = vmaxq_f32(vr1, vzero); + vr2 = vmaxq_f32(vr2, vzero); + vr3 = vmaxq_f32(vr3, vzero); + vst1q_f32(ptr_out_thread, vr0); + ptr_out_thread += 4; + vst1q_f32(ptr_out_thread, vr1); + ptr_out_thread += 4; + vst1q_f32(ptr_out_thread, vr2); + ptr_out_thread += 4; + vst1q_f32(ptr_out_thread, vr3); + ptr_out_thread += 4; + } + +#else + if (cnt > 0) { + asm volatile( + "1: @ loop header\n" + "vld1.32 {d0-d3}, [%[din]]! @ load din 0\n" + "vld1.32 {d4-d7}, [%[din]]! @ load din 0\n" + + "vmax.f32 q8, q0, %q[vzero] @ relu\n" + "vmax.f32 q9, q1, %q[vzero] @ relu\n" + "vmax.f32 q10, q2, %q[vzero] @ relu\n" + "vmax.f32 q11, q3, %q[vzero] @ relu\n" + + "vst1.32 {d16-d19}, [%[dout]]! @ store result, add pointer\n" + "vst1.32 {d20-d23}, [%[dout]]! @ store result, add pointer\n" + + "subs %[cnt], #1 @ loop count minus 1\n" + "bne 1b @ jump to main loop start " + "point\n" + : [dout] "+r"(ptr_out_thread), + [din] "+r"(ptr_in_thread), + [cnt] "+r"(cnt) + : [vzero] "w"(vzero) + : "cc", "memory", "q0", "q1", "q2", "q3", "q8", "q9", "q10", "q11"); + } +#endif + for (int j = 0; j < neon_loop_remain; ++j) { + ptr_out_thread[0] = ptr_in_thread[0] > 0.f ? ptr_in_thread[0] : 0.f; + ptr_in_thread++; + ptr_out_thread++; + } + } + float* out_ptr_remain = dout + threads * nums_per_thread; + const float* in_ptr_remain = din + threads * nums_per_thread; + for (int j = 0; j < remain; ++j) { + out_ptr_remain[0] = in_ptr_remain[0] > 0.f ? 
in_ptr_remain[0] : 0.f; + in_ptr_remain++; + out_ptr_remain++; + } +} + +template <> +void act_relu_neg(const float* din, + float* dout, + int size, + float negative_slope, + int threads) { + int nums_per_thread = size / threads; + int remain = size - threads * nums_per_thread; + int neon_loop_cnt = nums_per_thread >> 4; + int neon_loop_remain = nums_per_thread - (neon_loop_cnt << 4); + float32x4_t vzero = vdupq_n_f32(0.f); + float32x4_t valpha = vdupq_n_f32(negative_slope); +#pragma omp parallel for + for (int i = 0; i < threads; ++i) { + const float* ptr_in_thread = din + i * nums_per_thread; + float* ptr_out_thread = dout + i * nums_per_thread; + int cnt = neon_loop_cnt; +#ifdef __aarch64__ + for (int num = 0; num < neon_loop_cnt; ++num) { + float32x4_t vr0 = vld1q_f32(ptr_in_thread); + ptr_in_thread += 4; + float32x4_t vr1 = vld1q_f32(ptr_in_thread); + ptr_in_thread += 4; + float32x4_t vr2 = vld1q_f32(ptr_in_thread); + ptr_in_thread += 4; + float32x4_t vr3 = vld1q_f32(ptr_in_thread); + ptr_in_thread += 4; + + uint32x4_t vm0 = vcgeq_f32(vr0, vzero); + uint32x4_t vm1 = vcgeq_f32(vr1, vzero); + uint32x4_t vm2 = vcgeq_f32(vr2, vzero); + uint32x4_t vm3 = vcgeq_f32(vr3, vzero); + + float32x4_t vn0 = vmulq_f32(vr0, valpha); + float32x4_t vn1 = vmulq_f32(vr1, valpha); + float32x4_t vn2 = vmulq_f32(vr2, valpha); + float32x4_t vn3 = vmulq_f32(vr3, valpha); + + float32x4_t vo0 = vbslq_f32(vm0, vr0, vn0); + float32x4_t vo1 = vbslq_f32(vm1, vr1, vn1); + float32x4_t vo2 = vbslq_f32(vm2, vr2, vn2); + float32x4_t vo3 = vbslq_f32(vm3, vr3, vn3); + + vst1q_f32(ptr_out_thread, vo0); + ptr_out_thread += 4; + vst1q_f32(ptr_out_thread, vo1); + ptr_out_thread += 4; + vst1q_f32(ptr_out_thread, vo2); + ptr_out_thread += 4; + vst1q_f32(ptr_out_thread, vo3); + ptr_out_thread += 4; + } + +#else + if (cnt > 0) { + asm volatile( + "1: @ loop header\n" + "vld1.32 {d0-d3}, [%[din]]! @ load din 0\n" + "vld1.32 {d4-d7}, [%[din]]! @ load din 0\n" + + "vcge.f32 q8, q0, %q[vzero] @ get mask\n" + "vcge.f32 q9, q1, %q[vzero] @ get mask\n" + "vcge.f32 q10, q2, %q[vzero] @ get mask\n" + "vcge.f32 q11, q3, %q[vzero] @ get mask\n" + + "vmul.f32 q4, q0, %q[valpha] @ get neg data\n" + "vmul.f32 q5, q1, %q[valpha] @ get neg data\n" + "vmul.f32 q6, q2, %q[valpha] @ get neg data\n" + "vmul.f32 q7, q3, %q[valpha] @ get neg data\n" + + "vbit q4, q0, q8 @ bitsel, insert q0 to q4, " + "if q8 is 1\n" + "vbit q5, q1, q9 @ bitsel, insert q1 to q5, " + "if q9 is 1\n" + "vbit q6, q2, q10 @ bitsel, insert q2 to q6, " + "if q10 is 1\n" + "vbit q7, q3, q11 @ bitsel, insert q3 to q7, " + "if q11 is 1\n" + + "vst1.32 {d8-d11}, [%[dout]]! @ store result, add pointer\n" + "vst1.32 {d12-d15}, [%[dout]]! @ store result, add pointer\n" + + "subs %[cnt], #1 @ loop count minus 1\n" + "bne 1b @ jump to main loop start " + "point\n" + : [dout] "+r"(ptr_out_thread), + [din] "+r"(ptr_in_thread), + [cnt] "+r"(cnt) + : [vzero] "w"(vzero), [valpha] "w"(valpha) + : "cc", + "memory", + "q0", + "q1", + "q2", + "q3", + "q4", + "q5", + "q6", + "q7", + "q8", + "q9", + "q10", + "q11"); + } +#endif + for (int j = 0; j < neon_loop_remain; ++j) { + ptr_out_thread[0] = ptr_in_thread[0] > 0.f + ? ptr_in_thread[0] + : ptr_in_thread[0] * negative_slope; + ptr_in_thread++; + ptr_out_thread++; + } + } + float* out_ptr_remain = dout + threads * nums_per_thread; + const float* in_ptr_remain = din + threads * nums_per_thread; + for (int j = 0; j < remain; ++j) { + out_ptr_remain[0] = in_ptr_remain[0] > 0.f + ? 
in_ptr_remain[0] + : in_ptr_remain[0] * negative_slope; + in_ptr_remain++; + out_ptr_remain++; + } +} + +template <> +void act_clipped_relu( + const float* din, float* dout, int size, float coef, int threads) { + int nums_per_thread = size / threads; + int remain = size - threads * nums_per_thread; + int neon_loop_cnt = nums_per_thread >> 4; + int neon_loop_remain = nums_per_thread - (neon_loop_cnt << 4); + float32x4_t vzero = vdupq_n_f32(0.f); + float32x4_t vclip = vdupq_n_f32(coef); +#pragma omp parallel for + for (int i = 0; i < threads; ++i) { + const float* ptr_in_thread = din + i * nums_per_thread; + float* ptr_out_thread = dout + i * nums_per_thread; + int cnt = neon_loop_cnt; +#ifdef __aarch64__ + for (int num = 0; num < neon_loop_cnt; ++num) { + float32x4_t vr0 = vld1q_f32(ptr_in_thread); + ptr_in_thread += 4; + float32x4_t vr1 = vld1q_f32(ptr_in_thread); + ptr_in_thread += 4; + float32x4_t vr2 = vld1q_f32(ptr_in_thread); + ptr_in_thread += 4; + float32x4_t vr3 = vld1q_f32(ptr_in_thread); + ptr_in_thread += 4; + float32x4_t vt0 = vmaxq_f32(vr0, vzero); + float32x4_t vt1 = vmaxq_f32(vr1, vzero); + float32x4_t vt2 = vmaxq_f32(vr2, vzero); + float32x4_t vt3 = vmaxq_f32(vr3, vzero); + + float32x4_t vo0 = vminq_f32(vt0, vclip); + float32x4_t vo1 = vminq_f32(vt1, vclip); + float32x4_t vo2 = vminq_f32(vt2, vclip); + float32x4_t vo3 = vminq_f32(vt3, vclip); + + vst1q_f32(ptr_out_thread, vo0); + ptr_out_thread += 4; + vst1q_f32(ptr_out_thread, vo1); + ptr_out_thread += 4; + vst1q_f32(ptr_out_thread, vo2); + ptr_out_thread += 4; + vst1q_f32(ptr_out_thread, vo3); + ptr_out_thread += 4; + } +#else + if (cnt > 0) { + asm volatile( + "1: @ loop header\n" + "vld1.32 {d0-d3}, [%[din]]! @ load din 0\n" + "vld1.32 {d4-d7}, [%[din]]! @ load din 0\n" + + "vmax.f32 q8, q0, %q[vzero] @ relu\n" + "vmax.f32 q9, q1, %q[vzero] @ relu\n" + "vmax.f32 q10, q2, %q[vzero] @ relu\n" + "vmax.f32 q11, q3, %q[vzero] @ relu\n" + + "vmin.f32 q4, q8, %q[vclip] @ clip relu\n" + "vmin.f32 q5, q9, %q[vclip] @ clip relu\n" + "vmin.f32 q6, q10, %q[vclip] @ clip relu\n" + "vmin.f32 q7, q11, %q[vclip] @ clip relu\n" + + "vst1.32 {d8-d11}, [%[dout]]! @ store result, add pointer\n" + "vst1.32 {d12-d15}, [%[dout]]! @ store result, add pointer\n" + + "subs %[cnt], #1 @ loop count minus 1\n" + "bne 1b @ jump to main loop start " + "point\n" + : [dout] "+r"(ptr_out_thread), + [din] "+r"(ptr_in_thread), + [cnt] "+r"(cnt) + : [vzero] "w"(vzero), [vclip] "w"(vclip) + : "cc", + "memory", + "q0", + "q1", + "q2", + "q3", + "q4", + "q5", + "q6", + "q7", + "q8", + "q9", + "q10", + "q11"); + } +#endif + for (int j = 0; j < neon_loop_remain; ++j) { + ptr_out_thread[0] = ptr_in_thread[0] > 0.f ? ptr_in_thread[0] : 0.f; + ptr_out_thread[0] = ptr_out_thread[0] < coef ? ptr_out_thread[0] : coef; + ptr_in_thread++; + ptr_out_thread++; + } + } + float* out_ptr_remain = dout + threads * nums_per_thread; + const float* in_ptr_remain = din + threads * nums_per_thread; + for (int j = 0; j < remain; ++j) { + out_ptr_remain[0] = in_ptr_remain[0] > 0.f ? in_ptr_remain[0] : 0.f; + out_ptr_remain[0] = out_ptr_remain[0] < coef ? 
out_ptr_remain[0] : coef; + in_ptr_remain++; + out_ptr_remain++; + } +} + +template <> +void act_prelu(const float* din, + float* dout, + int outer_size, + int channel_size, + int inner_size, + std::string mode, + const float* alpha_data, + int threads) { + if (mode == "all" || mode == "channel") { + int stride_size = inner_size * channel_size; + int cnt = inner_size >> 4; + int remain = inner_size & 15; + float32x4_t vzero = vdupq_n_f32(0.f); + for (int n = 0; n < outer_size; n++) { + const float* data_in_batch = din + n * stride_size; + float* data_out_batch = dout + n * stride_size; +#pragma omp parallel for + for (int c = 0; c < channel_size; c++) { + const float* data_in_c = data_in_batch + c * inner_size; + float* data_out_c = data_out_batch + c * inner_size; + + float slope = mode == "all" ? alpha_data[0] : alpha_data[c]; + float32x4_t vslope = vdupq_n_f32(slope); +#ifdef __aarch64__ + for (int i = 0; i < cnt; ++i) { + float32x4_t vr0 = vld1q_f32(data_in_c); + float32x4_t vr1 = vld1q_f32(data_in_c + 4); + float32x4_t vr2 = vld1q_f32(data_in_c + 8); + float32x4_t vr3 = vld1q_f32(data_in_c + 12); + uint32x4_t vm0 = vcltq_f32(vr0, vzero); // vr0 <= vzero + uint32x4_t vm1 = vcltq_f32(vr1, vzero); // vr0 <= vzero + uint32x4_t vm2 = vcltq_f32(vr2, vzero); // vr0 <= vzero + uint32x4_t vm3 = vcltq_f32(vr3, vzero); // vr0 <= vzero + float32x4_t vo0 = vmulq_f32(vr0, vslope); // vr0 * vslope + float32x4_t vo1 = vmulq_f32(vr1, vslope); // vr0 * vslope + float32x4_t vo2 = vmulq_f32(vr2, vslope); // vr0 * vslope + float32x4_t vo3 = vmulq_f32(vr3, vslope); // vr0 * vslope + float32x4_t vos0 = vbslq_f32(vm0, vo0, vr0); + float32x4_t vos1 = vbslq_f32(vm1, vo1, vr1); + float32x4_t vos2 = vbslq_f32(vm2, vo2, vr2); + float32x4_t vos3 = vbslq_f32(vm3, vo3, vr3); + vst1q_f32(data_out_c, vos0); + vst1q_f32(data_out_c + 4, vos1); + vst1q_f32(data_out_c + 8, vos2); + vst1q_f32(data_out_c + 12, vos3); + data_in_c += 16; + data_out_c += 16; + } +#else + int cnt_loop = cnt; + if (cnt_loop > 0) { + asm volatile( + "vld1.32 {d0-d3}, [%[ptr_in]]! @ load " + "input to q0, q1\n" + "pld [%[ptr_in]] @ preload\n" + "pld [%[ptr_in], #64] @ preload\n" + "pld [%[ptr_in], #128] @ preload\n" + "pld [%[ptr_in], #192] @ preload\n" + "1: @main loop\n" + "vld1.32 {d4-d7}, [%[ptr_in]]! @ load input to " + "q2, q3\n" + "vclt.f32 q8, q0, %q[vzero] @vcle q0 <= " + "vzero\n" + "vclt.f32 q9, q1, %q[vzero] @vcle q1 <= " + "vzero\n" + "vmul.f32 q10, q0, %q[vslope] @vmul q0 * " + "vslope\n" + "vmul.f32 q11, q1, %q[vslope] @vmul q1 * " + "vslope\n" + + "vclt.f32 q12, q2, %q[vzero] @vcle q2 <= " + "vzero\n" + "vclt.f32 q13, q3, %q[vzero] @vcle q3 <= " + "vzero\n" + "vmul.f32 q14, q2, %q[vslope] @vmul q2 * " + "vslope\n" + "vmul.f32 q15, q3, %q[vslope] @vmul q3 * " + "vslope\n" + + "vbif.32 q10, q0, q8 @vbit q10, q0, " + "q8\n" + "vbif.32 q11, q1, q9 @vbit q11, q1, " + "q9\n" + "vbif.32 q14, q2, q12 @vbit q14, q2, " + "q12\n" + "vbif.32 q15, q3, q13 @vbit q15, q3, " + "q13\n" + + "subs %[cnt], #1 @subs nn, 1\n" + "vld1.32 {d0-d3}, [%[ptr_in]]! @ load input to " + "q0, q1\n" + + "vst1.f32 {d20-d23}, [%[dout]]! @store data\n" + "vst1.f32 {d28-d31}, [%[dout]]! 
@store data\n" + "bne 1b @bne nn\n" + "sub %[ptr_in], #32 @ ptr-32\n" + : [ptr_in] "+r"(data_in_c), + [cnt] "+r"(cnt_loop), + [dout] "+r"(data_out_c) + : [vzero] "w"(vzero), [vslope] "w"(vslope) + : "cc", + "memory", + "q0", + "q1", + "q2", + "q3", + "q8", + "q9", + "q10", + "q11", + "q12", + "q13", + "q14", + "q15"); + } +#endif // __aarch64__ + for (int i = remain; i > 0; i--) { + *(data_out_c++) = + data_in_c[0] > 0.f ? data_in_c[0] : data_in_c[0] * slope; + data_in_c++; + } + } + } + } else { // mode = element + int stride_size = inner_size * channel_size; + for (int n = 0; n < outer_size; n++) { + const float* data_in_batch = din + n * stride_size; + const float* data_alpha_batch = alpha_data + n * stride_size; + float* data_out_batch = dout + n * stride_size; + for (int c = 0; c < channel_size; c++) { + const float* data_in_c = data_in_batch + c * inner_size; + const float* data_alpha_c = data_alpha_batch + c * inner_size; + float* data_out_c = data_out_batch + c * inner_size; + for (int i = 0; i < inner_size; i++) { + data_out_c[0] = data_in_c[0] > 0.f ? data_in_c[0] + : data_in_c[0] * data_alpha_c[0]; + data_in_c++; + data_alpha_c++; + data_out_c++; + } + } + } + } +} + +template <> +void act_sigmoid(const float* din, float* dout, int size, int threads) { + int nums_per_thread = size / threads; + int remain = size - threads * nums_per_thread; + int neon_loop_cnt_dim4 = nums_per_thread >> 2; + int neon_loop_remain_dim4 = nums_per_thread - (neon_loop_cnt_dim4 << 2); + + float32x4_t vzero = vdupq_n_f32(0.f); +#pragma omp parallel for + for (int i = 0; i < threads; ++i) { + float32x4_t exp_vec = vdupq_n_f32(0.0f); + float32x4_t recip = vdupq_n_f32(0.0f); + const float* ptr_in_thread = din + i * nums_per_thread; + float* ptr_out_thread = dout + i * nums_per_thread; + for (int k = 0; k < neon_loop_cnt_dim4; ++k) { + exp_vec = exp_ps(vnegq_f32(vld1q_f32(ptr_in_thread))); + exp_vec = vaddq_f32(exp_vec, vdupq_n_f32(1.0f)); + recip = vrecpeq_f32(exp_vec); + recip = vmulq_f32(vrecpsq_f32(exp_vec, recip), recip); + recip = vmulq_f32(vrecpsq_f32(exp_vec, recip), recip); + vst1q_f32(ptr_out_thread, recip); + ptr_out_thread += 4; + ptr_in_thread += 4; + } + for (int j = 0; j < neon_loop_remain_dim4; ++j) { + ptr_out_thread[0] = 1.f / (1 + expf(-ptr_in_thread[0])); + ptr_in_thread++; + ptr_out_thread++; + } + } + float* ptr_out = dout + threads * nums_per_thread; + const float* ptr_in = din + threads * nums_per_thread; + for (int j = 0; j < remain; ++j) { + ptr_out[0] = 1.f / (1 + expf(-ptr_in[0])); + ptr_in++; + ptr_out++; + } +} + +// tanh : (exp(x) - exp(-x)) / (exp(x) + exp(-x)) +template <> +void act_tanh(const float* din, float* dout, int size, int threads) { + int nums_per_thread = size / threads; + int remain = size - threads * nums_per_thread; + int neon_loop_cnt_dim4 = nums_per_thread >> 2; + int neon_loop_remain_dim4 = nums_per_thread - (neon_loop_cnt_dim4 << 2); +#pragma omp parallel for + for (int i = 0; i < threads; ++i) { + float32x4_t exp_plus_vec = vdupq_n_f32(0.0f); + float32x4_t exp_minus_vec = vdupq_n_f32(0.0f); + float32x4_t exp_sum_vec = vdupq_n_f32(0.0f); + float32x4_t exp_diff_vec = vdupq_n_f32(0.0f); + float32x4_t recip = vdupq_n_f32(0.0f); + const float* ptr_in_thread = din + i * nums_per_thread; + float* ptr_out_thread = dout + i * nums_per_thread; + for (int k = 0; k < neon_loop_cnt_dim4; ++k) { + exp_plus_vec = exp_ps(vld1q_f32(ptr_in_thread)); + exp_minus_vec = exp_ps(vnegq_f32(vld1q_f32(ptr_in_thread))); + exp_sum_vec = vaddq_f32(exp_plus_vec, exp_minus_vec); + 
exp_diff_vec = vsubq_f32(exp_plus_vec, exp_minus_vec);
+      recip = div_ps(exp_diff_vec, exp_sum_vec);
+      vst1q_f32(ptr_out_thread, recip);
+      ptr_out_thread += 4;
+      ptr_in_thread += 4;
+    }
+    for (int j = 0; j < neon_loop_remain_dim4; ++j) {
+      ptr_out_thread[0] = (expf(ptr_in_thread[0]) - expf(-ptr_in_thread[0])) /
+                          (expf(ptr_in_thread[0]) + expf(-ptr_in_thread[0]));
+      ptr_in_thread++;
+      ptr_out_thread++;
+    }
+  }
+  float* ptr_out = dout + threads * nums_per_thread;
+  const float* ptr_in = din + threads * nums_per_thread;
+  for (int j = 0; j < remain; ++j) {
+    ptr_out[0] = (expf(ptr_in[0]) - expf(-ptr_in[0])) /
+                 (expf(ptr_in[0]) + expf(-ptr_in[0]));
+    ptr_in++;
+    ptr_out++;
+  }
+}
+
+// swish: x / (1 + exp(-(b * x)))
+template <>
+void act_swish<float>(
+    const float* din, float* dout, int size, float coef, int threads) {
+  int nums_per_thread = size / threads;
+  int remain = size - threads * nums_per_thread;
+  int neon_loop_cnt_dim4 = nums_per_thread >> 2;
+  int neon_loop_remain_dim4 = nums_per_thread - (neon_loop_cnt_dim4 << 2);
+  const float beta = coef;
+  float32x4_t vbeta = vdupq_n_f32(beta);
+  float32x4_t vone = vdupq_n_f32(1.f);
+#pragma omp parallel for
+  for (int i = 0; i < threads; ++i) {
+    const float* ptr_in_thread = din + i * nums_per_thread;
+    float* ptr_out_thread = dout + i * nums_per_thread;
+    for (int k = 0; k < neon_loop_cnt_dim4; ++k) {
+      float32x4_t va = vld1q_f32(ptr_in_thread);             // x
+      float32x4_t vb = vnegq_f32(vld1q_f32(ptr_in_thread));  // -x
+      float32x4_t vsum = vmulq_f32(vb, vbeta);
+      vsum = exp_ps(vsum);
+      float32x4_t vc = vaddq_f32(vone, vsum);
+      float32x4_t vrst = div_ps(va, vc);
+      vst1q_f32(ptr_out_thread, vrst);
+      ptr_out_thread += 4;
+      ptr_in_thread += 4;
+    }
+    for (int j = 0; j < neon_loop_remain_dim4; ++j) {
+      ptr_out_thread[0] =
+          ptr_in_thread[0] / (1.0 + expf(-ptr_in_thread[0] * beta));
+      ptr_in_thread++;
+      ptr_out_thread++;
+    }
+  }
+  float* ptr_out = dout + threads * nums_per_thread;
+  const float* ptr_in = din + threads * nums_per_thread;
+  for (int j = 0; j < remain; ++j) {
+    ptr_out[0] = ptr_in[0] / (1.0 + expf(-ptr_in[0] * beta));
+    ptr_in++;
+    ptr_out++;
+  }
+}
+
+template <>
+void act_log<float>(const float* din, float* dout, int size, int threads) {
+  int nums_per_thread = size / threads;
+  int remain = size - threads * nums_per_thread;
+  int neon_loop_cnt_dim4 = nums_per_thread >> 2;
+  int neon_loop_remain_dim4 = nums_per_thread - (neon_loop_cnt_dim4 << 2);
+#pragma omp parallel for
+  for (int i = 0; i < threads; ++i) {
+    float32x4_t log_vec = vdupq_n_f32(0.0f);
+    const float* ptr_in_thread = din + i * nums_per_thread;
+    float* ptr_out_thread = dout + i * nums_per_thread;
+    for (int k = 0; k < neon_loop_cnt_dim4; ++k) {
+      log_vec = log_ps(vld1q_f32(ptr_in_thread));
+      vst1q_f32(ptr_out_thread, log_vec);
+      ptr_out_thread += 4;
+      ptr_in_thread += 4;
+    }
+    for (int j = 0; j < neon_loop_remain_dim4; ++j) {
+      ptr_out_thread[0] = logf(ptr_in_thread[0]);
+      ptr_in_thread++;
+      ptr_out_thread++;
+    }
+  }
+  float* ptr_out = dout + threads * nums_per_thread;
+  const float* ptr_in = din + threads * nums_per_thread;
+  for (int j = 0; j < remain; ++j) {
+    ptr_out[0] = logf(ptr_in[0]);
+    ptr_in++;
+    ptr_out++;
+  }
+}
+
+}  // namespace math
+}  // namespace arm
+}  // namespace lite
+}  // namespace paddle
diff --git a/lite/arm/math/activation.h b/lite/arm/math/activation.h
new file mode 100644
index 00000000000..c22c963cc10
--- /dev/null
+++ b/lite/arm/math/activation.h
@@ -0,0 +1,57 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+#include <string>
+
+namespace paddle {
+namespace lite {
+namespace arm {
+namespace math {
+
+template <typename T>
+void act_relu(const T* din, T* dout, int size, int threads);
+
+template <typename T>
+void act_relu_neg(
+    const T* din, T* dout, int size, float negative_slope, int threads);
+
+template <typename T>
+void act_clipped_relu(const T* din, T* dout, int size, float coef, int threads);
+
+template <typename T>
+void act_prelu(const T* din,
+               T* dout,
+               int outer_size,
+               int channel_size,
+               int inner_size,
+               std::string mode,
+               const float* alpha_data,
+               int threads);
+
+template <typename T>
+void act_sigmoid(const T* din, T* dout, int size, int threads);
+
+template <typename T>
+void act_tanh(const T* din, T* dout, int size, int threads);
+
+template <typename T>
+void act_swish(const T* din, T* dout, int size, float coef, int threads);
+
+template <typename T>
+void act_log(const T* din, T* dout, int size, int threads);
+
+}  // namespace math
+}  // namespace arm
+}  // namespace lite
+}  // namespace paddle
diff --git a/lite/arm/math/argmax.cc b/lite/arm/math/argmax.cc
new file mode 100644
index 00000000000..878f807e1c2
--- /dev/null
+++ b/lite/arm/math/argmax.cc
@@ -0,0 +1,65 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
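+//
+// argmax_func below builds (value, index) pairs along `axis` and uses
+// std::partial_sort to move the single largest pair to the front; for the
+// top-1 case this is equivalent in effect to (sketch):
+//
+//   auto it = std::max_element(vec.begin(), vec.end());
+//   *out_ptr = static_cast<float>(it->second);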
+ +#include "lite/arm/math/argmax.h" +#include +#include +#include +#include +#include +#include +#include "lite/arm/math/funcs.h" + +namespace paddle { +namespace lite { +namespace arm { +namespace math { + +void argmax_func(const lite::Tensor *input, + const int axis, + lite::Tensor *output) { + auto input_ddim = input->dims(); + auto output_ddim = output->dims(); + + const int size = input_ddim[axis]; + const int in_channel = input_ddim.count(axis, input_ddim.size()); + const int out_channel = output_ddim.count(axis, output_ddim.size()); + const int in_stride = input_ddim.count(axis + 1, input_ddim.size()); + const int out_stride = input_ddim.count(0, axis); + + for (int n = 0; n < out_stride; n++) { + for (int k = 0; k < in_stride; k++) { + const float *in_ptr = input->data() + n * in_channel + k; + std::vector> vec; + vec.resize(size); + for (int i = 0; i < size; i++) { + vec[i] = std::make_pair(in_ptr[i * in_stride], i); + } + // sort + std::partial_sort(vec.begin(), + vec.begin() + 1, + vec.end(), + std::greater>()); + + // out + float *out_ptr = output->mutable_data() + n * out_channel + k; + *out_ptr = vec[0].second; + } + } +} + +} // namespace math +} // namespace arm +} // namespace lite +} // namespace paddle diff --git a/lite/arm/math/argmax.h b/lite/arm/math/argmax.h new file mode 100644 index 00000000000..c78cf2f7a8f --- /dev/null +++ b/lite/arm/math/argmax.h @@ -0,0 +1,35 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include +#include "lite/operators/op_params.h" +#include "lite/utils/cp_logging.h" + +namespace paddle { +namespace lite { +namespace arm { +namespace math { + +void argmax_func(const lite::Tensor* input, + const int axis, + lite::Tensor* output); + +} // namespace math +} // namespace arm +} // namespace lite +} // namespace paddle diff --git a/lite/arm/math/axpy.cc b/lite/arm/math/axpy.cc new file mode 100644 index 00000000000..ad4db7a2fa9 --- /dev/null +++ b/lite/arm/math/axpy.cc @@ -0,0 +1,203 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "lite/arm/math/axpy.h" +#include +#include +#include +#include "lite/arm/math/funcs.h" +#include "lite/arm/math/saturate.h" + +namespace paddle { +namespace lite { +namespace arm { +namespace math { + +void axpy_kernel_fp32(const float* scale, + const float* din, + const float* bias, + float* dout, + int num, + int channel, + int size, + int in_channel) { + int cnt = size >> 3; + int remain = size % 8; + for (int n = 0; n < num; n++) { + const float* din_ptr = din + n * in_channel; + const float* scale_ptr = scale + n * channel; + const float* bias_ptr = bias + n * in_channel; + float* dout_ptr = dout + n * in_channel; +#pragma omp parallel for + for (int c = 0; c < channel; c++) { + const float* din_ch_ptr = din_ptr + c * size; + const float* bias_ch_ptr = bias_ptr + c * size; + float* dout_ch_ptr = dout_ptr + c * size; + float32x4_t scale_val = vdupq_n_f32(scale_ptr[c]); + int col_cnt = cnt; + if (cnt > 0) { +#ifdef __aarch64__ + asm volatile( + "ld1 {v0.4s}, [%[din_ptr]], #16 \n" + "ld1 {v1.4s}, [%[bias_ptr]], #16 \n" + "1: \n" + "ld1 {v2.4s}, [%[din_ptr]], #16 \n" + "ld1 {v3.4s}, [%[bias_ptr]], #16 \n" + "fmul v4.4s , v0.4s, %[scale].4s \n" + "fmul v5.4s , v2.4s, %[scale].4s \n" + "fadd v4.4s, v4.4s, v1.4s \n" + "fadd v5.4s, v5.4s, v3.4s \n" + "ld1 {v0.4s}, [%[din_ptr]], #16 \n" + "ld1 {v1.4s}, [%[bias_ptr]], #16 \n" + "subs %[cnt], %[cnt], #1 \n" + "st1 {v4.4s}, [%[dout_ptr]], #16 \n" + "st1 {v5.4s}, [%[dout_ptr]], #16 \n" + "bne 1b \n" + : [din_ptr] "+r"(din_ch_ptr), + [bias_ptr] "+r"(bias_ch_ptr), + [dout_ptr] "+r"(dout_ch_ptr), + [cnt] "+r"(col_cnt) + : [scale] "w"(scale_val) + : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5"); +#else + asm volatile( + "vld1.32 {d2-d3}, [%[din_ptr]]! \n" + "vld1.32 {d4-d5}, [%[bias_ptr]]! \n" + "1: \n" + "vld1.32 {d6-d7}, [%[din_ptr]]! \n" + "vld1.32 {d8-d9}, [%[bias_ptr]]! \n" + "vmul.f32 q5, q1, %q[scale] \n" + "vmul.f32 q6, q3, %q[scale] \n" + "vadd.f32 q5, q5, q2 \n" + "vadd.f32 q6, q6, q4 \n" + "vld1.f32 {d2-d3}, [%[din_ptr]]! \n" + "vld1.f32 {d4-d5}, [%[bias_ptr]]! \n" + "subs %[cnt], #1 \n" + "vst1.32 {d10-d11}, [%[dout_ptr]]! \n" + "vst1.32 {d12-d13}, [%[dout_ptr]]! 
\n" + "bne 1b \n" + : [din_ptr] "+r"(din_ch_ptr), + [bias_ptr] "+r"(bias_ch_ptr), + [dout_ptr] "+r"(dout_ch_ptr), + [cnt] "+r"(col_cnt) + : [scale] "w"(scale_val) + : "cc", "memory", "q1", "q2", "q3", "q4", "q5", "q6"); +#endif + } + din_ch_ptr = din_ptr + c * size + cnt * 8; + bias_ch_ptr = bias_ptr + c * size + cnt * 8; + for (int i = 0; i < remain; i++) { + *dout_ch_ptr = (*din_ch_ptr) * scale_ptr[c] + (*bias_ch_ptr); + dout_ch_ptr++; + din_ch_ptr++; + bias_ch_ptr++; + } + } + } +} + +void axpy_kernel_int8(const int8_t* scale, + const int8_t* din, + const int8_t* bias, + int8_t* dout, + int num, + int channel, + int size, + int in_channel) { + int cnt = size >> 4; + int remain = size % 16; + for (int n = 0; n < num; n++) { + const int8_t* din_ptr = din + n * in_channel; + const int8_t* scale_ptr = scale + n * channel; + const int8_t* bias_ptr = bias + n * in_channel; + int8_t* dout_ptr = dout + n * in_channel; +#pragma omp parallel for + for (int c = 0; c < channel; c++) { + const int8_t* din_ch_ptr = din_ptr + c * size; + const int8_t* bias_ch_ptr = bias_ptr + c * size; + int8_t* dout_ch_ptr = dout_ptr + c * size; + int8x8_t scale_val = vdup_n_s8(scale_ptr[c]); + int col_cnt = cnt; + if (col_cnt > 0) { +#ifdef __aarch64__ + asm volatile( + "ld1 {v0.8b}, [%[din_ptr]], #8 \n" + "ld1 {v1.8b}, [%[bias_ptr]], #8 \n" + "1: \n" + "ld1 {v2.8b}, [%[din_ptr]], #8 \n" + "ld1 {v3.8b}, [%[bias_ptr]], #8 \n" + "smull v4.8h, v0.8b, %[scale].8b \n" + "smull v5.8h, v2.8b, %[scale].8b \n" + "saddw v4.8h, v4.8h, v1.8b \n" + "saddw v5.8h, v5.8h, v3.8b \n" + "ld1 {v0.8b}, [%[din_ptr]], #8 \n" + "ld1 {v1.8b}, [%[bias_ptr]], #8 \n" + "subs %[cnt], %[cnt], #1 \n" + // int16->int8 + "sqxtn v6.8b, v4.8h \n" + "sqxtn v7.8b, v5.8h \n" + "st1 {v6.8b}, [%[dout_ptr]], #8 \n" /* store c0r0*/ + "st1 {v7.8b}, [%[dout_ptr]], #8 \n" /* store c2r0*/ + "bne 1b \n" + : [din_ptr] "+r"(din_ch_ptr), + [bias_ptr] "+r"(bias_ch_ptr), + [dout_ptr] "+r"(dout_ch_ptr), + [cnt] "+r"(col_cnt) + : [scale] "w"(scale_val) + : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5"); +#else + asm volatile( + "vdup.s8 d0, %[scale] \n" + "vld1.8 {d2}, [%[din_ptr]]! \n" + "vld1.8 {d4}, [%[bias_ptr]]! \n" + "1: \n" + "vld1.8 {d3}, [%[din_ptr]]! \n" + "vld1.8 {d5}, [%[bias_ptr]]! \n" + "vmull.s8 q4, d2, d0 \n" + "vmull.s8 q5, d3, d0 \n" + "vaddw.s16 q4, q4, d4 \n" + "vaddw.s16 q5, q5, d5 \n" + "vld1.8 {d2}, [%[din_ptr]]! \n" + "vld1.8 {d4}, [%[bias_ptr]]! \n" + "subs %[cnt], #1 \n" + // int16->int8 + "vqmovn.s16 d12, q4 @ cnt to int8\n" + "vqmovn.s16 d13, q5 @ cnt to int8\n" + "vst1.32 {d12-d13}, [%[dout_ptr]]! \n" + "bne 1b \n" + : [din_ptr] "+r"(din_ch_ptr), + [bias_ptr] "+r"(bias_ch_ptr), + [dout_ptr] "+r"(dout_ch_ptr), + [cnt] "+r"(col_cnt) + : [scale] "r"(scale_val) + : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6"); +#endif + } + din_ch_ptr = din_ptr + c * size + cnt * 16; + bias_ch_ptr = bias_ptr + c * size + cnt * 16; + for (int i = 0; i < remain; i++) { + *dout_ch_ptr = saturate_cast( + roundf((*din_ch_ptr) * scale_ptr[c] + (*bias_ch_ptr))); + dout_ch_ptr++; + din_ch_ptr++; + bias_ch_ptr++; + } + } + } +} + +} // namespace math +} // namespace arm +} // namespace lite +} // namespace paddle diff --git a/lite/arm/math/axpy.h b/lite/arm/math/axpy.h new file mode 100644 index 00000000000..8245bf1d1a8 --- /dev/null +++ b/lite/arm/math/axpy.h @@ -0,0 +1,49 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. 
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <algorithm>
+#include <string>
+#include <vector>
+#include "lite/operators/op_params.h"
+#include "lite/utils/cp_logging.h"
+
+namespace paddle {
+namespace lite {
+namespace arm {
+namespace math {
+
+void axpy_kernel_fp32(const float* scale,
+                      const float* din,
+                      const float* bias,
+                      float* dout,
+                      int num,
+                      int channel,
+                      int size,
+                      int in_channel);
+
+void axpy_kernel_int8(const int8_t* scale,
+                      const int8_t* din,
+                      const int8_t* bias,
+                      int8_t* dout,
+                      int num,
+                      int channel,
+                      int size,
+                      int in_channel);
+
+} // namespace math
+} // namespace arm
+} // namespace lite
+} // namespace paddle
diff --git a/lite/arm/math/beam_search.cc b/lite/arm/math/beam_search.cc
new file mode 100644
index 00000000000..932db0e2c1f
--- /dev/null
+++ b/lite/arm/math/beam_search.cc
@@ -0,0 +1,271 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lite/arm/math/beam_search.h"
+#include <cmath>
+#include <numeric>
+#include <sstream>
+#include <vector>
+#include "lite/utils/cp_logging.h"
+
+namespace paddle {
+namespace lite {
+namespace arm {
+namespace math {
+/*
+ * The basic item used to sort the beam candidates.
+ */
+struct Item {
+  Item() {}
+  Item(size_t offset, size_t id, float score)
+      : offset(offset), id(id), score(score) {}
+  // offset in the higher lod level.
+  size_t offset;
+  // prefix id in the lower lod level.
+  // size_t prefix;
+  // the candidate id
+  size_t id;
+  // the corresponding score
+  float score;
+
+  inline bool operator<(const Item &in) const {
+    return (score < in.score) || ((score == in.score) && (offset < in.offset));
+  }
+
+  inline void operator=(const Item &in) {
+    offset = in.offset;
+    id = in.id;
+    score = in.score;
+  }
+
+  std::string ToString() {
+    std::ostringstream os;
+    os << "{";
+    os << "offset: " << offset << ", ";
+    os << "id: " << id << ", ";
+    os << "score: " << score << "";
+    os << "}";
+    return os.str();
+  }
+};
+
+/*
+ * Prune the source sentences whose branches are all finished; this step is
+ * optional. Pruning must happen one step later than finishing (thus pre_ids
+ * is needed here), since the end tokens must be written out first.
+ */
+void PruneEndBeams(const Tensor *pre_ids,
+                   const LoD &abs_lod,
+                   std::vector<std::vector<Item>> *items,
+                   size_t lod_level,
+                   int end_id) {
+  auto *pre_ids_data = pre_ids->data<int64_t>();
+  auto &high_level = abs_lod[lod_level];
+  for (size_t src_idx = 0; src_idx < high_level.size() - 1; ++src_idx) {
+    size_t src_prefix_start = high_level[src_idx];
+    size_t src_prefix_end = high_level[src_idx + 1];
+    bool finish_flag = true;
+    for (size_t offset = src_prefix_start; offset < src_prefix_end; offset++) {
+      for (auto &item : items->at(offset)) {
+        if (item.id != static_cast<size_t>(end_id) ||
+            pre_ids_data[offset] != end_id) {
+          finish_flag = false;
+          break;
+        }
+      }
+      if (!finish_flag) break;
+    }
+    if (finish_flag) {  // all branches of the beam (source sentence) end,
+                        // so prune this beam
+      for (size_t offset = src_prefix_start; offset < src_prefix_end; offset++)
+        items->at(offset).clear();
+    }
+  }
+}
+
+/*
+ * Transform the items into a map whose key is offset, value is the items.
+ * NOTE low performance.
+ */
+std::vector<std::vector<Item>> ToMap(
+    const std::vector<std::vector<Item>> &items, size_t element_num) {
+  std::vector<std::vector<Item>> result;
+  result.resize(element_num);
+  for (auto &entries : items) {
+    for (const auto &item : entries) {
+      result[item.offset].push_back(item);
+    }
+  }
+  return result;
+}
+
+void Insert(std::vector<Item> *top_beam_ptr,
+            const Item &item,
+            size_t beam_size) {
+  std::vector<Item> &top_beam = *top_beam_ptr;
+
+  size_t num_beams = top_beam.size();
+  if (num_beams < beam_size) {
+    top_beam.resize(num_beams + 1);
+    num_beams++;
+  } else {
+    if (item < top_beam[beam_size - 1]) {
+      return;
+    }
+  }
+
+  for (int k = static_cast<int>(num_beams) - 2; k >= 0; --k) {
+    if (top_beam[k] < item) {
+      top_beam[k + 1] = top_beam[k];
+    } else {
+      top_beam[k + 1] = item;
+      return;
+    }
+  }
+  top_beam[0] = item;
+}
+
+/*
+ * For each source, select top beam_size records.
+ */
+std::vector<std::vector<Item>> SelectTopBeamSizeItems(const Tensor *pre_ids,
+                                                      const Tensor *pre_scores,
+                                                      const Tensor *ids,
+                                                      const Tensor *scores,
+                                                      size_t lod_level,
+                                                      size_t beam_size,
+                                                      int end_id,
+                                                      bool is_accumulated) {
+  std::vector<std::vector<Item>> result;
+
+  // find the current candidates
+  // auto abs_lod = framework::ToAbsOffset(scores->lod());
+  auto abs_lod = scores->lod();
+  auto *pre_ids_data = pre_ids->data<int64_t>();
+  auto *pre_scores_data = pre_scores->data<float>();
+
+  auto *ids_data = ids ? ids->data<int64_t>() : nullptr;
+  auto *scores_data = scores->data<float>();
+
+  size_t num_seqs = abs_lod[lod_level].size() - 1;
+  size_t seq_width = 1;
+  for (int i = 1; i < scores->dims().size(); i++) {
+    seq_width *= scores->dims()[i];
+  }
+
+  for (size_t seq_id = 0; seq_id < num_seqs; ++seq_id) {
+    size_t seq_offset_start = abs_lod[lod_level][seq_id];
+    size_t seq_offset_end = abs_lod[lod_level][seq_id + 1];
+
+    std::vector<Item> top_beam;
+    top_beam.reserve(beam_size);
+
+    for (size_t offset = seq_offset_start; offset < seq_offset_end; ++offset) {
+      auto pre_id = pre_ids_data[offset];
+      auto pre_score = pre_scores_data[offset];
+      if (pre_id == end_id) {
+        // Allocate all probability mass to end_id for finished branches;
+        // the other candidate ids can be ignored.
+        Item item(offset, end_id, pre_score);
+        Insert(&top_beam, item, beam_size);
+      } else {
+        size_t index = offset * seq_width;
+        for (size_t d = 0; d < seq_width; d++, index++) {
+          int64_t id = ids_data ? ids_data[index] : static_cast<int64_t>(d);
+          float score = is_accumulated
+                            ? scores_data[index]
+                            : pre_score + std::log(scores_data[index]);
+          Item item(offset, id, score);
+          Insert(&top_beam, item, beam_size);
+        }
+      }
+    }
+
+    result.emplace_back(top_beam);
+  }
+  return result;
+}
+
+void beam_search(const Tensor *pre_ids,
+                 const Tensor *pre_scores,
+                 const Tensor *ids,
+                 const Tensor *scores,
+                 Tensor *selected_ids,
+                 Tensor *selected_scores,
+                 Tensor *parent_idx,
+                 int level,
+                 int beam_size,
+                 int end_id,
+                 bool is_accumulated,
+                 Context<TARGET(kARM)> *ctx) {
+  // auto abs_lod = framework::ToAbsOffset(scores->lod());
+  auto abs_lod = scores->lod();
+  auto &high_level = abs_lod[level];
+  auto items = SelectTopBeamSizeItems(pre_ids,
+                                      pre_scores,
+                                      ids,
+                                      scores,
+                                      level,
+                                      beam_size,
+                                      end_id,
+                                      is_accumulated);
+  auto selected_items = ToMap(items, high_level.back());
+
+  PruneEndBeams(pre_ids, abs_lod, &selected_items, level, end_id);
+  // calculate the output tensor's height
+  size_t num_instances = std::accumulate(
+      std::begin(selected_items),
+      std::end(selected_items),
+      0,
+      [](size_t a, std::vector<Item> &b) { return a + b.size(); });
+  // the output tensor shape should be [num_instances, 1]
+  auto dims = std::vector<int64_t>({static_cast<int64_t>(num_instances), 1});
+  selected_ids->Resize(dims);
+  selected_scores->Resize(dims);
+  if (parent_idx) {
+    parent_idx->Resize(dims);
+  }
+  auto *selected_ids_data = selected_ids->mutable_data<int64_t>();
+  auto *selected_scores_data = selected_scores->mutable_data<float>();
+  auto *parent_idx_data =
+      parent_idx ? parent_idx->mutable_data<int>() : nullptr;
+
+  // fill in data
+  std::vector<size_t> low_level;
+  size_t low_offset = 0;
+  for (auto &items : selected_items) {
+    low_level.push_back(low_offset);
+    for (auto &item : items) {
+      if (parent_idx) {
+        parent_idx_data[low_offset] = static_cast<int>(low_level.size() - 1);
+      }
+      selected_ids_data[low_offset] = item.id;
+      selected_scores_data[low_offset] = item.score;
+      low_offset++;
+    }
+  }
+  low_level.push_back(low_offset);
+
+  // fill lod
+  LoD lod(2);
+  lod[0].assign(high_level.begin(), high_level.end());
+  lod[1].assign(low_level.begin(), low_level.end());
+  *(selected_ids->mutable_lod()) = lod;
+  *(selected_scores->mutable_lod()) = lod;
+}
+
+} // namespace math
+} // namespace arm
+} // namespace lite
+} // namespace paddle
diff --git a/lite/arm/math/beam_search.h b/lite/arm/math/beam_search.h
new file mode 100644
index 00000000000..2f07175e35e
--- /dev/null
+++ b/lite/arm/math/beam_search.h
@@ -0,0 +1,41 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
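Insert() in beam_search.cc keeps each sequence's top_beam sorted in descending score order (ties broken toward the smaller offset), doing one insertion-sort step per candidate, which is O(beam_size) per insert. The same bounded top-k selection is often written with a min-heap; a sketch over bare scores that ignores the offset tie-break, with hypothetical names not taken from the patch:

// Hypothetical sketch: bounded top-k via a min-heap, equivalent to the
// sorted-array maintenance in Insert() when only scores matter.
#include <functional>
#include <queue>
#include <vector>

void push_top_k(std::priority_queue<float, std::vector<float>,
                                    std::greater<float>>* heap,
                float score, size_t beam_size) {
  if (heap->size() < beam_size) {
    heap->push(score);        // beam not full yet
  } else if (score > heap->top()) {
    heap->pop();              // evict the current worst candidate
    heap->push(score);
  }
}

For the small beam sizes typical of decoding, the patch's sorted array is the better choice: it keeps the best item at index 0 and avoids heap overhead.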
+
+#pragma once
+
+#include <vector>
+#include "lite/core/context.h"
+
+namespace paddle {
+namespace lite {
+namespace arm {
+namespace math {
+
+void beam_search(const Tensor* pre_ids,
+                 const Tensor* pre_scores,
+                 const Tensor* ids,
+                 const Tensor* scores,
+                 Tensor* selected_ids,
+                 Tensor* selected_scores,
+                 Tensor* parent_idx,
+                 int level,
+                 int beam_size,
+                 int end_id,
+                 bool is_accumulated,
+                 Context<TARGET(kARM)>* ctx);
+
+} // namespace math
+} // namespace arm
+} // namespace lite
+} // namespace paddle
diff --git a/lite/arm/math/box_coder.cc b/lite/arm/math/box_coder.cc
new file mode 100644
index 00000000000..9b3f32b56e2
--- /dev/null
+++ b/lite/arm/math/box_coder.cc
@@ -0,0 +1,92 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lite/arm/math/box_coder.h"
+#include "lite/arm/math/funcs.h"
+
+namespace paddle {
+namespace lite {
+namespace arm {
+namespace math {
+
+void box_coder(lite::Tensor* proposals,
+               const lite::Tensor* anchors,
+               const lite::Tensor* variances,
+               const lite::Tensor* bbox_deltas,
+               const std::string code_type,
+               bool box_normalized,
+               int axis) {
+  if (code_type == "decode_center_size") {
+    float normalized = !box_normalized ? 1.f : 0;
+
+    const float* anchor_data = anchors->data<float>();
+    const float* bbox_deltas_data = bbox_deltas->data<float>();
+    float* proposals_data = proposals->mutable_data<float>();
+    const float* variances_data = variances->data<float>();
+
+    int N = bbox_deltas->dims()[0];
+    int M = bbox_deltas->dims()[1];
+    int len = bbox_deltas->dims()[2];
+
+    for (int64_t row_id = 0; row_id < N; ++row_id) {
+      for (int64_t col_id = 0; col_id < M; ++col_id) {
+        size_t offset = row_id * M * len + col_id * len;
+        int prior_box_offset = axis == 0 ? col_id * len : row_id * len;
+        int var_offset = axis == 0 ? col_id * len : row_id * len;
+
+        auto anchor_data_tmp = anchor_data + prior_box_offset;
+        auto bbox_deltas_data_tmp = bbox_deltas_data + offset;
+        auto proposals_data_tmp = proposals_data + offset;
+
+        auto anchor_width =
+            anchor_data_tmp[2] - anchor_data_tmp[0] + normalized;
+        auto anchor_height =
+            anchor_data_tmp[3] - anchor_data_tmp[1] + normalized;
+        auto anchor_center_x = anchor_data_tmp[0] + 0.5 * anchor_width;
+        auto anchor_center_y = anchor_data_tmp[1] + 0.5 * anchor_height;
+
+        float bbox_center_x = 0, bbox_center_y = 0;
+        float bbox_width = 0, bbox_height = 0;
+
+        auto variances_data_tmp = variances_data + var_offset;
+
+        bbox_center_x =
+            variances_data_tmp[0] * bbox_deltas_data_tmp[0] * anchor_width +
+            anchor_center_x;
+        bbox_center_y =
+            variances_data_tmp[1] * bbox_deltas_data_tmp[1] * anchor_height +
+            anchor_center_y;
+        bbox_width = std::exp(variances_data_tmp[2] * bbox_deltas_data_tmp[2]) *
+                     anchor_width;
+        bbox_height =
+            std::exp(variances_data_tmp[3] * bbox_deltas_data_tmp[3]) *
+            anchor_height;
+
+        proposals_data_tmp[0] = bbox_center_x - bbox_width / 2;
+        proposals_data_tmp[1] = bbox_center_y - bbox_height / 2;
+        proposals_data_tmp[2] = bbox_center_x + bbox_width / 2 - normalized;
+        proposals_data_tmp[3] = bbox_center_y + bbox_height / 2 - normalized;
+      }
+    }
+  } else if (code_type == "encode_center_size") {
+    LOG(FATAL) << "not implemented type: " << code_type;
+  } else {
+    LOG(FATAL) << "not supported type: " << code_type;
+  }
+}
+
+} // namespace math
+} // namespace arm
+} // namespace lite
+} // namespace paddle
diff --git a/lite/arm/math/box_coder.h b/lite/arm/math/box_coder.h
new file mode 100644
index 00000000000..bbeb3e06184
--- /dev/null
+++ b/lite/arm/math/box_coder.h
@@ -0,0 +1,36 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <string>
+#include "lite/core/tensor.h"
+
+namespace paddle {
+namespace lite {
+namespace arm {
+namespace math {
+
+void box_coder(lite::Tensor* proposals,
+               const lite::Tensor* anchors,
+               const lite::Tensor* variances,
+               const lite::Tensor* bbox_deltas,
+               const std::string code_type,
+               bool box_normalized,
+               int axis);
+
+} // namespace math
+} // namespace arm
+} // namespace lite
+} // namespace paddle
diff --git a/lite/arm/math/col_im_transform.cc b/lite/arm/math/col_im_transform.cc
new file mode 100644
index 00000000000..d909d4247d8
--- /dev/null
+++ b/lite/arm/math/col_im_transform.cc
@@ -0,0 +1,75 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lite/arm/math/col_im_transform.h"
+#include <cstring>
+#include "lite/arm/math/funcs.h"
+
+namespace paddle {
+namespace lite {
+namespace arm {
+namespace math {
+
+inline bool is_a_ge_zero_and_a_lt_b(int a, int b) {
+  return static_cast<unsigned>(a) < static_cast<unsigned>(b);
+}
+
+template <>
+void col2im(const float* data_col,
+            const int channels,
+            const int height,
+            const int width,
+            const int kernel_h,
+            const int kernel_w,
+            const int pad_h,
+            const int pad_w,
+            const int stride_h,
+            const int stride_w,
+            const int dilation_h,
+            const int dilation_w,
+            float* data_im) {
+  memset(data_im, 0, height * width * channels * sizeof(float));
+  const int output_h =
+      (height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1;
+  const int output_w =
+      (width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1;
+  const int channel_size = height * width;
+  for (int channel = channels; channel--; data_im += channel_size) {
+    for (int kernel_row = 0; kernel_row < kernel_h; kernel_row++) {
+      for (int kernel_col = 0; kernel_col < kernel_w; kernel_col++) {
+        int input_row = -pad_h + kernel_row * dilation_h;
+        for (int output_rows = output_h; output_rows; output_rows--) {
+          if (!is_a_ge_zero_and_a_lt_b(input_row, height)) {
+            data_col += output_w;
+          } else {
+            int input_col = -pad_w + kernel_col * dilation_w;
+            for (int output_col = output_w; output_col; output_col--) {
+              if (is_a_ge_zero_and_a_lt_b(input_col, width)) {
+                data_im[input_row * width + input_col] += *data_col;
+              }
+              data_col++;
+              input_col += stride_w;
+            }
+          }
+          input_row += stride_h;
+        }
+      }
+    }
+  }
+}
+
+} // namespace math
+} // namespace arm
+} // namespace lite
+} // namespace paddle
diff --git a/lite/arm/math/col_im_transform.h b/lite/arm/math/col_im_transform.h
new file mode 100644
index 00000000000..8560679d7f4
--- /dev/null
+++ b/lite/arm/math/col_im_transform.h
@@ -0,0 +1,40 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+namespace paddle {
+namespace lite {
+namespace arm {
+namespace math {
+
+template <typename Dtype>
+void col2im(const Dtype* data_col,
+            const int channels,
+            const int height,
+            const int width,
+            const int kernel_h,
+            const int kernel_w,
+            const int pad_h,
+            const int pad_w,
+            const int stride_h,
+            const int stride_w,
+            const int dilation_h,
+            const int dilation_w,
+            Dtype* data_im);
+
+} // namespace math
+} // namespace arm
+} // namespace lite
+} // namespace paddle
diff --git a/lite/arm/math/concat.cc b/lite/arm/math/concat.cc
new file mode 100644
index 00000000000..8dd156e2622
--- /dev/null
+++ b/lite/arm/math/concat.cc
@@ -0,0 +1,60 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lite/arm/math/concat.h"
+#include <algorithm>
+#include <cstring>
+#include <vector>
+#include "lite/arm/math/funcs.h"
+
+namespace paddle {
+namespace lite {
+namespace arm {
+namespace math {
+
+void concat_func(const std::vector<lite::Tensor *> &input,
+                 const int axis,
+                 lite::Tensor *output) {
+  size_t num = input.size();
+  int rows = 1;
+  auto dim_0 = input[0]->dims();
+  for (int i = 0; i < axis; ++i) {
+    rows *= dim_0[i];
+  }
+  int out_rows = rows, out_cols = 0;
+
+  std::vector<int> input_cols(input.size());
+  for (int i = 0; i < num; ++i) {
+    int t_cols = input[i]->numel() / rows;
+    out_cols += t_cols;
+    input_cols[i] = t_cols;
+  }
+
+  // computation
+  for (int k = 0; k < out_rows; ++k) {
+    float *dst_ptr = output->mutable_data<float>() + k * out_cols;
+    int col_idx = 0;
+    for (int j = 0; j < num; ++j) {
+      int col_len = input_cols[j];
+      const float *src_ptr = input[j]->data<float>() + k * col_len;
+      std::memcpy(dst_ptr + col_idx, src_ptr, sizeof(float) * col_len);
+      col_idx += col_len;
+    }
+  }
+}
+
+} // namespace math
+} // namespace arm
+} // namespace lite
+} // namespace paddle
diff --git a/lite/arm/math/concat.h b/lite/arm/math/concat.h
new file mode 100644
index 00000000000..4c6159e9e09
--- /dev/null
+++ b/lite/arm/math/concat.h
@@ -0,0 +1,35 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <algorithm>
+#include <string>
+#include <vector>
+#include "lite/operators/op_params.h"
+#include "lite/utils/cp_logging.h"
+
+namespace paddle {
+namespace lite {
+namespace arm {
+namespace math {
+
+void concat_func(const std::vector<lite::Tensor *> &input,
+                 const int axis,
+                 lite::Tensor *output);
+
+} // namespace math
+} // namespace arm
+} // namespace lite
+} // namespace paddle
diff --git a/lite/arm/math/conv3x3s1_direct_int8.cc b/lite/arm/math/conv3x3s1_direct_int8.cc
new file mode 100644
index 00000000000..c34dffca29b
--- /dev/null
+++ b/lite/arm/math/conv3x3s1_direct_int8.cc
@@ -0,0 +1,806 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
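The direct int8 conv kernels that follow prepack padded input rows into the context workspace, compute tiles of 4 output channels by 2 output rows with smull/smlal (int8 multiplies widened to int16, accumulated in int32), and then write back with an optional fused scale/ReLU depending on out_type. A scalar sketch of what each accumulator holds, for a single output channel under zero padding; the names and the simplified output-size formula are assumptions for illustration, not taken from the patch:

// Hypothetical reference sketch: 3x3 stride-1 convolution with int8
// input/weights accumulated in int32, one output channel, zero padding.
#include <cstdint>

void conv3x3s1_ref(const int8_t* in, const int8_t* w, const int32_t* bias,
                   int32_t* out, int ch_in, int h, int wd, int pad) {
  int ho = h + 2 * pad - 2;   // output height for k=3, stride=1
  int wo = wd + 2 * pad - 2;  // output width
  for (int oh = 0; oh < ho; ++oh) {
    for (int ow = 0; ow < wo; ++ow) {
      int32_t acc = bias ? bias[0] : 0;  // bias pre-filled, as in the kernels
      for (int c = 0; c < ch_in; ++c) {
        for (int kh = 0; kh < 3; ++kh) {
          for (int kw = 0; kw < 3; ++kw) {
            int ih = oh - pad + kh, iw = ow - pad + kw;
            if (ih < 0 || ih >= h || iw < 0 || iw >= wd) continue;  // zero pad
            acc += static_cast<int32_t>(in[(c * h + ih) * wd + iw]) *
                   static_cast<int32_t>(w[(c * 3 + kh) * 3 + kw]);
          }
        }
      }
      out[oh * wo + ow] = acc;  // int32 result, before any scale/ReLU fusion
    }
  }
}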
+ +#include +#include "lite/arm/math/conv_block_utils.h" +#include "lite/arm/math/conv_impl.h" +#include "lite/core/context.h" +#include "lite/operators/op_params.h" +#ifdef ARM_WITH_OMP +#include +#endif + +namespace paddle { +namespace lite { +namespace arm { +namespace math { + +#ifdef __aarch64__ +void conv_3x3s1_direct_int8(const int8_t* din, + int32_t* dout, + int num, + int chout, + int hout, + int wout, + int chin, + int hin, + int win, + const int8_t* weights, + const int32_t* bias, + const operators::ConvParam& param, + Context* ctx, + PrecisionType out_type, + const float* scale) { + const int hin_r_block = 4; + const int hout_c_block = 4; // 8; + const int hout_r_block = 2; + + int stride_w = param.strides[1]; + int pad_w = param.paddings[1]; + int pad_h = param.paddings[0]; + bool flag_relu = param.fuse_relu; + bool flag_bias = (param.bias != nullptr); + + int wout_round = ((wout + 3) / 4) * 4; + int win_round = wout_round * stride_w + 4; + + int threads = ctx->threads(); + + int* tmp_work_space = ctx->workspace_data(); + int* ptr_zero = tmp_work_space; + memset(ptr_zero, 0, sizeof(int) * win_round); + int* ptr_write = ptr_zero + win_round; + + int in_len = win_round * chin; + int pre_in_size = hin_r_block * in_len; + int pre_out_size = hout_c_block * hout_r_block * wout_round; + + signed char* pre_din = reinterpret_cast(ptr_write + wout_round); + + int size_in_channel = win * hin; + int size_out_channel = wout * hout; + int w_stride = chin * 9; + + int ws = -pad_w; + int we = ws + win_round; + int w_loop = wout_round / 4; + + int size_out = wout_round * hout_c_block; + + // printf("win_round: %d, wout_round: %d, ws: %d, we: %d\n", win_round, + // wout_round, ws, we); + // here + for (int n = 0; n < num; ++n) { + const signed char* din_batch = + static_cast(din) + n * chin * size_in_channel; + signed char* dout_batch = + reinterpret_cast(dout) + + n * chout * size_out_channel * PrecisionTypeLength(out_type); + + for (int h = 0; h < hout; h += 2) { + int hs = h - pad_h; + int he = hs + 4; + // printf("hs: %d, he: %d, chin: %d, hin: %d, win: %d \n", hs, he, chin, + // hin, win); + prepack_input_nxw(din_batch, + pre_din, + 0, + chin, + hs, + he, + ws, + we, + chin, + win, + hin, + (signed char*)ptr_zero); + +#pragma omp parallel for num_threads(threads) + for (int c = 0; c < chout; c += hout_c_block) { +#ifdef ARM_WITH_OMP + int* pre_out = + reinterpret_cast(pre_din + (pre_in_size + 3) / 4 * 4) + + omp_get_thread_num() * pre_out_size; +#else + int* pre_out = + reinterpret_cast(pre_din + (pre_in_size + 3) / 4 * 4); +#endif + // printf("ptr_zero_int: %x, ptr_zero: %x, ptr_write: %x, pre_din: %x, + // pre_out: %x \n", ptr_zero_int, ptr_zero, ptr_write, pre_din, + // pre_out); + const signed char* inr0 = pre_din; + const signed char* inr1 = inr0 + in_len; + const signed char* inr2 = inr1 + in_len; + const signed char* inr3 = inr2 + in_len; + + const signed char* wc0 = + static_cast(weights) + c * w_stride; + + const int* bias_ptr = ptr_zero; + if (flag_bias) { + bias_ptr = static_cast(bias) + c; + } + // hout_r_block * wout_round * hout_c_block + fill_packed_bias_nxmw_int8( + bias_ptr, pre_out, hout_c_block, hout_r_block, wout_round); + + for (int i = 0; i < chin; ++i) { + const signed char* r0 = inr0; + const signed char* r1 = inr1; + const signed char* r2 = inr2; + const signed char* r3 = inr3; + + int* ptr_out0 = pre_out; + int* ptr_out1 = pre_out + size_out; + + int cnt = w_loop; + const signed char* ptr_wc0 = wc0; + + asm volatile( + "ldp q4, q5, [%[wc0]] \n" /* w4 w5 w6 w7 */ + 
"ldr q6, [%[wc0], #32] \n" /* w8 */ + "SXTL v11.8h, v4.8b \n" /* w to int16 */ + "SXTL2 v12.8h, v4.16b \n" /* w to int16 */ + "SXTL v13.8h, v5.8b \n" /* to int16 */ + "SXTL2 v14.8h, v5.16b \n" /* to int16 */ + "SXTL v15.8h, v6.8b \n" /* to int16 */ + "1: \n" /* main loop*/ + "ldr d0, [%[r0]] \n" /* load data din0-dinn7*/ + "SXTL v1.8h, v0.8b \n" /* to int16 */ + + /*output 1st row*/ + "smull v16.4s, v11.4h, v1.h[0] \n" /* */ + "smull v17.4s, v11.4h, v1.h[1] \n" /* */ + "smull v18.4s, v11.4h, v1.h[2] \n" /* */ + "smull v19.4s, v11.4h, v1.h[3] \n" /* */ + + "add %[r0], %[r0], #4\n" + + /*output 1st row*/ + "smlal2 v16.4s, v11.8h, v1.h[1] \n" /* */ + "smlal2 v17.4s, v11.8h, v1.h[2] \n" /* */ + "smlal2 v18.4s, v11.8h, v1.h[3] \n" /* */ + "smlal2 v19.4s, v11.8h, v1.h[4] \n" /* */ + + "ldr d0, [%[r1]] \n" /* load data */ + + /*output 1st row*/ + "smlal v16.4s, v12.4h, v1.h[2] \n" /* */ + "smlal v17.4s, v12.4h, v1.h[3] \n" /* */ + "SXTL v2.8h, v0.8b \n" /* to int16 */ + "smlal v18.4s, v12.4h, v1.h[4] \n" /* */ + "smlal v19.4s, v12.4h, v1.h[5] \n" /* */ + + "add %[r1], %[r1], #4 \n" + + /*output 1st row*/ + "smlal2 v16.4s, v12.8h, v2.h[0] \n" /* */ + "smlal2 v17.4s, v12.8h, v2.h[1] \n" /* */ + "smlal2 v18.4s, v12.8h, v2.h[2] \n" /* */ + "smlal2 v19.4s, v12.8h, v2.h[3] \n" /* */ + + /*output 1st row*/ + "smlal v16.4s, v13.4h, v2.h[1] \n" /* */ + "smlal v17.4s, v13.4h, v2.h[2] \n" /* */ + "smlal v18.4s, v13.4h, v2.h[3] \n" /* */ + "smlal v19.4s, v13.4h, v2.h[4] \n" /* */ + + /*output 1st row*/ + "smlal2 v16.4s, v13.8h, v2.h[2] \n" /* */ + "smlal2 v17.4s, v13.8h, v2.h[3] \n" /* */ + "smlal2 v18.4s, v13.8h, v2.h[4] \n" /* */ + "smlal2 v19.4s, v13.8h, v2.h[5] \n" /* */ + + /*output 2rd row*/ + "smull v24.4s, v11.4h, v2.h[0] \n" /* */ + "smull v25.4s, v11.4h, v2.h[1] \n" /* */ + "smull v26.4s, v11.4h, v2.h[2] \n" /* */ + "smull v27.4s, v11.4h, v2.h[3] \n" /* */ + + /*output 2rd row*/ + "smlal2 v24.4s, v11.8h, v2.h[1] \n" /* */ + "smlal2 v25.4s, v11.8h, v2.h[2] \n" /* */ + "smlal2 v26.4s, v11.8h, v2.h[3] \n" /* */ + "smlal2 v27.4s, v11.8h, v2.h[4] \n" /* */ + + "ldr d0, [%[r2]] \n" /* load data */ + + /*output 2rd row*/ + "smlal v24.4s, v12.4h, v2.h[2] \n" /* */ + "smlal v25.4s, v12.4h, v2.h[3] \n" /* */ + "SXTL v1.8h, v0.8b \n" /* to int16 */ + "smlal v26.4s, v12.4h, v2.h[4] \n" /* */ + "smlal v27.4s, v12.4h, v2.h[5] \n" /* */ + + /*output 1st row*/ + "smlal v16.4s, v14.4h, v1.h[0] \n" /* */ + "smlal v17.4s, v14.4h, v1.h[1] \n" /* */ + "smlal v18.4s, v14.4h, v1.h[2] \n" /* */ + "smlal v19.4s, v14.4h, v1.h[3] \n" /* */ + + "add %[r2], %[r2], #4 \n" + + /*output 1st row*/ + "smlal2 v16.4s, v14.8h, v1.h[1] \n" /* */ + "smlal2 v17.4s, v14.8h, v1.h[2] \n" /* */ + "smlal2 v18.4s, v14.8h, v1.h[3] \n" /* */ + "smlal2 v19.4s, v14.8h, v1.h[4] \n" /* */ + + "ldp q3, q4, [%[ptr_out0]] \n" + "ldp q5, q6, [%[ptr_out0], #32] \n" + + /*output 1st row*/ + "smlal v16.4s, v15.4h, v1.h[2] \n" /* */ + "smlal v17.4s, v15.4h, v1.h[3] \n" /* */ + "smlal v18.4s, v15.4h, v1.h[4] \n" /* */ + "smlal v19.4s, v15.4h, v1.h[5] \n" /* */ + + "ADD v3.4s, v16.4s, v3.4s \n" + "ADD v4.4s, v17.4s, v4.4s \n" + "ADD v5.4s, v18.4s, v5.4s \n" + "ADD v6.4s, v19.4s, v6.4s \n" + + "stp q3, q4, [%[ptr_out0]], #32 \n" /* save to + output*/ + "stp q5, q6, [%[ptr_out0]], #32 \n" /* save to + output*/ + + /*output 2rd row*/ + "smlal2 v24.4s, v12.8h, v1.h[0] \n" /* */ + "smlal2 v25.4s, v12.8h, v1.h[1] \n" /* */ + "smlal2 v26.4s, v12.8h, v1.h[2] \n" /* */ + "smlal2 v27.4s, v12.8h, v1.h[3] \n" /* */ + + /*output 2rd row*/ + "smlal v24.4s, v13.4h, v1.h[1] 
\n" /* */ + "smlal v25.4s, v13.4h, v1.h[2] \n" /* */ + "smlal v26.4s, v13.4h, v1.h[3] \n" /* */ + "smlal v27.4s, v13.4h, v1.h[4] \n" /* */ + + "ldr d0, [%[r3]] \n" /* load data */ + + /*output 2rd row*/ + "smlal2 v24.4s, v13.8h, v1.h[2] \n" /* */ + "smlal2 v25.4s, v13.8h, v1.h[3] \n" /* */ + "SXTL v2.8h, v0.8b \n" /* to int16 */ + "smlal2 v26.4s, v13.8h, v1.h[4] \n" /* */ + "smlal2 v27.4s, v13.8h, v1.h[5] \n" /* */ + + /*output 2rd row*/ + "smlal v24.4s, v14.4h, v2.h[0] \n" /* */ + "smlal v25.4s, v14.4h, v2.h[1] \n" /* */ + "smlal v26.4s, v14.4h, v2.h[2] \n" /* */ + "smlal v27.4s, v14.4h, v2.h[3] \n" /* */ + + "add %[r3], %[r3], #4 \n" + + /*output 2rd row*/ + "smlal2 v24.4s, v14.8h, v2.h[1] \n" /* */ + "smlal2 v25.4s, v14.8h, v2.h[2] \n" /* */ + "smlal2 v26.4s, v14.8h, v2.h[3] \n" /* */ + "smlal2 v27.4s, v14.8h, v2.h[4] \n" /* */ + + "ldp q3, q4, [%[ptr_out1]] \n" + "ldp q5, q6, [%[ptr_out1], #32] \n" + + "subs %w[cnt], %w[cnt], #1 \n" /* loop count -1 */ + + /*output 2rd row*/ + "smlal v24.4s, v15.4h, v2.h[2] \n" /* */ + "smlal v25.4s, v15.4h, v2.h[3] \n" /* */ + "smlal v26.4s, v15.4h, v2.h[4] \n" /* */ + "smlal v27.4s, v15.4h, v2.h[5] \n" /* */ + + "ADD v3.4s, v24.4s, v3.4s \n" + "ADD v4.4s, v25.4s, v4.4s \n" + "ADD v5.4s, v26.4s, v5.4s \n" + "ADD v6.4s, v27.4s, v6.4s \n" + + "stp q3, q4, [%[ptr_out1]], #32 \n" /* save to output*/ + "stp q5, q6, [%[ptr_out1]], #32 \n" /* save to output*/ + + "bne 1b \n" /* jump to main loop*/ + + : [cnt] "+r"(cnt), + [wc0] "+r"(ptr_wc0), + [r0] "+r"(r0), + [r1] "+r"(r1), + [r2] "+r"(r2), + [r3] "+r"(r3), + [ptr_out0] "+r"(ptr_out0), + [ptr_out1] "+r"(ptr_out1) + : + : "cc", + "memory", + "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v11", + "v12", + "v13", + "v14", + "v15", + "v16", + "v17", + "v18", + "v19", + "v24", + "v25", + "v26", + "v27" + + ); + + wc0 += 9 * hout_c_block; + inr0 += win_round; + inr1 += win_round; + inr2 += win_round; + inr3 += win_round; + } + if (out_type == PRECISION(kFloat)) { + write_to_output_c4_int32_1(pre_out, + reinterpret_cast(dout_batch), + hout_c_block, + hout_r_block, + c, + c + 4, + h, + h + 2, + 0, + wout_round, + chout, + hout, + wout, + flag_relu, + reinterpret_cast(ptr_write), + &scale[c], + out_type); + } else if (out_type == PRECISION(kInt8)) { + write_to_output_c4_int32_1(pre_out, + dout_batch, + hout_c_block, + hout_r_block, + c, + c + 4, + h, + h + 2, + 0, + wout_round, + chout, + hout, + wout, + flag_relu, + reinterpret_cast(ptr_write), + &scale[c], + out_type); + } else { // int32 + write_to_output_c4_int32(pre_out, + reinterpret_cast(dout_batch), + hout_c_block, + hout_r_block, + c, + c + 4, + h, + h + 2, + 0, + wout_round, + chout, + hout, + wout, + flag_relu, + ptr_write); + } + } + } + } +} + +#else + +void conv_3x3s1_direct_int8(const int8_t* din, + int32_t* dout, + int num, + int chout, + int hout, + int wout, + int chin, + int hin, + int win, + const int8_t* weights, + const int32_t* bias, + const operators::ConvParam& param, + Context* ctx, + PrecisionType out_type, + const float* scale) { + // printf("conv2_3x3s1_direct_int8 \n"); + + const int hin_r_block = 4; + const int hout_c_block = 4; // 8 + const int hout_r_block = 2; + + int stride_w = param.strides[1]; + int pad_w = param.paddings[1]; + int pad_h = param.paddings[0]; + bool flag_relu = param.fuse_relu; + bool flag_bias = (param.bias != nullptr); + + int wout_round = ((wout + 3) / 4) * 4; + int win_round = wout_round * stride_w + 4; + + int threads = ctx->threads(); + + int* tmp_work_space = ctx->workspace_data(); + int* 
ptr_zero = tmp_work_space; + memset(ptr_zero, 0, sizeof(int) * win_round); + int* ptr_write = ptr_zero + win_round; + + int in_len = win_round * chin; + int pre_in_size = hin_r_block * in_len; + int pre_out_size = hout_c_block * hout_r_block * wout_round; + + signed char* pre_din = reinterpret_cast(ptr_write + wout_round); + + int size_in_channel = win * hin; + int size_out_channel = wout * hout; + int w_stride = chin * 9; + + int ws = -pad_w; + int we = ws + win_round; + int w_loop = wout_round / 4; + + int size_out = wout_round * hout_c_block; + + // printf("win_round: %d, wout_round: %d, ws: %d, we: %d\n", win_round, + // wout_round, ws, we); + + for (int n = 0; n < num; ++n) { + const signed char* din_batch = + static_cast(din) + n * chin * size_in_channel; + signed char* dout_batch = + reinterpret_cast(dout) + + n * chout * size_out_channel * PrecisionTypeLength(out_type); + + for (int h = 0; h < hout; h += 2) { + int hs = h - pad_h; + int he = hs + 4; + // printf("hs: %d, he: %d, chin: %d, hin: %d, win: %d \n", hs, he, chin, + // hin, win); + prepack_input_nxw(din_batch, + pre_din, + 0, + chin, + hs, + he, + ws, + we, + chin, + win, + hin, + (signed char*)ptr_zero); + +#pragma omp parallel for num_threads(threads) + for (int c = 0; c < chout; c += hout_c_block) { // 4 +#ifdef ARM_WITH_OMP + int* pre_out = + reinterpret_cast(pre_din + (pre_in_size + 3) / 4 * 4) + + omp_get_thread_num() * pre_out_size; +#else + int* pre_out = + reinterpret_cast(pre_din + (pre_in_size + 3) / 4 * 4); +#endif + // printf("ptr_zero_int: %x, ptr_zero: %x, ptr_write: %x, pre_din: %x, + // pre_out: %x \n", ptr_zero_int, ptr_zero, ptr_write, pre_din, + // pre_out); + const signed char* inr0 = pre_din; + const signed char* inr1 = inr0 + in_len; + const signed char* inr2 = inr1 + in_len; + const signed char* inr3 = inr2 + in_len; + + const signed char* wc0 = + static_cast(weights) + c * w_stride; + + const int* bias_ptr = ptr_zero; + if (flag_bias) { + bias_ptr = static_cast(bias) + c; + } + // hout_r_block * wout_round * hout_c_block + fill_packed_bias_nxmw_int8( + bias_ptr, pre_out, hout_c_block, hout_r_block, wout_round); + + for (int i = 0; i < chin; ++i) { + const signed char* r0 = inr0; + const signed char* r1 = inr1; + const signed char* r2 = inr2; + const signed char* r3 = inr3; + + int* ptr_out0 = pre_out; + int* ptr_out1 = pre_out + size_out; + + int cnt = w_loop; + const signed char* ptr_wc = wc0; + + asm volatile( + "vld1.s8 {d0-d3}, [%[wc0]]! \n" /* wc0, wc1, wc2, wc3, wc4, + wc5, wc6, wc7*/ + "vld1.s8 {d4}, [%[wc0]]! 
\n" /* wc8 */ + "vmovl.s8 q3, d0 \n" /* q3 = w0, w1 */ + "vmovl.s8 q4, d1 \n" /* q4 = w2 ,w3 */ + "vmovl.s8 q5, d2 \n" /* q5 = w4, w5 */ + "vmovl.s8 q6, d3 \n" /* q6 = w6, w7 */ + "vmovl.s8 q7, d4 \n" /* q7 = w8 */ + + "1: \n" /* main loop*/ + "vld1.s32 {d0}, [%[r0]] \n" /* load data din0-dinn7*/ + "vmovl.s8 q0, d0 \n" /* movl d0 -> q0 */ + /*output 1st row*/ + "vmull.s16 q8, d6, d0[0] \n" /* q8 = w0 * r0[0] */ + "vmull.s16 q9, d6, d0[1] \n" /* q9 = w0 * r0[2] */ + "vmull.s16 q10, d6, d0[2] \n" /* q10 = w0 * r0[4] */ + "vmull.s16 q11, d6, d0[3] \n" /* q11 = w0 * r0[6] */ + + "add %[r0], #4 \n" + + /*output 1st row*/ + "vmlal.s16 q8, d7, d0[1] \n" /* q8 = w1 * r0[1] */ + "vmlal.s16 q9, d7, d0[2] \n" /* q9 = w1 * r0[2] */ + "vmlal.s16 q10, d7, d0[3] \n" /* q10 = w1 * r0[3] */ + "vmlal.s16 q11, d7, d1[0] \n" /* q11 = w1 * r0[4] */ + + "vld1.s32 {d2}, [%[r1]] \n" /* load input r1 -> d2 */ + "vmovl.s8 q1, d2 \n" /* movl d2 -> q1 */ + + /*output 1st row*/ + "vmlal.s16 q8, d8, d0[2] \n" /* q8 = w2 * r0[2] */ + "vmlal.s16 q9, d8, d0[3] \n" /* q9 = w2 * r0[3] */ + "vmlal.s16 q10, d8, d1[0] \n" /* q10 = w2 * r0[4] */ + "vmlal.s16 q11, d8, d1[1] \n" /* q11 = w2 * r0[5] */ + + /*output 1st row*/ + "vmlal.s16 q8, d9, d2[0] \n" /* */ + "vmlal.s16 q9, d9, d2[1] \n" /* */ + "vmlal.s16 q10, d9, d2[2] \n" /* */ + "vmlal.s16 q11, d9, d2[3] \n" /* */ + + "add %[r1], #4 \n" + + /*output 1st row*/ + "vmlal.s16 q8, d10, d2[1] \n" /* */ + "vmlal.s16 q9, d10, d2[2] \n" /* */ + "vmlal.s16 q10, d10, d2[3] \n" /* */ + "vmlal.s16 q11, d10, d3[0] \n" /* */ + + /*output 1st row*/ + "vmlal.s16 q8, d11, d2[2] \n" /* */ + "vmlal.s16 q9, d11, d2[3] \n" /* */ + "vmlal.s16 q10, d11, d3[0] \n" /* */ + "vmlal.s16 q11, d11, d3[1] \n" /* */ + + /*output 2rd row*/ + "vmull.s16 q12, d6, d2[0] \n" /* */ + "vmull.s16 q13, d6, d2[1] \n" /* */ + "vmull.s16 q14, d6, d2[2] \n" /* */ + "vmull.s16 q15, d6, d2[3] \n" /* */ + + "vld1.s32 {d0}, [%[r2]] \n" /* load input r2 -> d2 */ + "vmovl.s8 q0, d0 \n" /* movl d2 -> q1 */ + + /*output 2rd row*/ + "vmlal.s16 q12, d7, d2[1] \n" /* */ + "vmlal.s16 q13, d7, d2[2] \n" /* */ + "vmlal.s16 q14, d7, d2[3] \n" /* */ + "vmlal.s16 q15, d7, d3[0] \n" /* */ + + /*output 2rd row*/ + "vmlal.s16 q12, d8, d2[2] \n" /* */ + "vmlal.s16 q13, d8, d2[3] \n" /* */ + "vmlal.s16 q14, d8, d3[0] \n" /* */ + "vmlal.s16 q15, d8, d3[1] \n" /* */ + + "add %[r2], #4 \n" + + /*output 1st row*/ + "vmlal.s16 q8, d12, d0[0] \n" /* */ + "vmlal.s16 q9, d12, d0[1] \n" /* */ + "vmlal.s16 q10, d12, d0[2] \n" /* */ + "vmlal.s16 q11, d12, d0[3] \n" /* */ + + /*output 1st row*/ + "vmlal.s16 q8, d13, d0[1] \n" /* */ + "vmlal.s16 q9, d13, d0[2] \n" /* */ + "vmlal.s16 q10, d13, d0[3] \n" /* */ + "vmlal.s16 q11, d13, d1[0] \n" /* */ + + "vld1.32 {d2-d5}, [%[ptr_out0]] \n" /* load ptr_out -> q, q + */ + + /*output 1st row*/ + "vmlal.s16 q8, d14, d0[2] \n" /* */ + "vmlal.s16 q9, d14, d0[3] \n" /* */ + "vmlal.s16 q10, d14, d1[0] \n" /* */ + "vmlal.s16 q11, d14, d1[1] \n" /* */ + + /*load & store output 1st row*/ + "vadd.s32 q1, q8, q1 \n" /* out[0] += q8 */ + "vadd.s32 q2, q9, q2 \n" /* out[0] += q8 */ + "vst1.s32 {d2-d5}, [%[ptr_out0]]! 
\n" + + /*output 2rd row*/ + "vmlal.s16 q12, d9, d0[0] \n" /* */ + "vmlal.s16 q13, d9, d0[1] \n" /* */ + "vmlal.s16 q14, d9, d0[2] \n" /* */ + "vmlal.s16 q15, d9, d0[3] \n" /* */ + + "vld1.32 {d2-d5}, [%[ptr_out0]] \n" /* load ptr_out -> q2, q3 + */ + + /*output 2rd row */ + "vmlal.s16 q12, d10, d0[1] \n" /* */ + "vmlal.s16 q13, d10, d0[2] \n" /* */ + "vadd.s32 q1, q10, q1 \n" /* out[0] += q */ + "vadd.s32 q2, q11, q2 \n" /* out[1] += q */ + + "vmlal.s16 q14, d10, d0[3] \n" /* */ + "vst1.s32 {d2-d5}, [%[ptr_out0]]! \n" + "vmlal.s16 q15, d10, d1[0] \n" /* */ + + /*output 2rd row */ + "vmlal.s16 q12, d11, d0[2] \n" /* */ + "vmlal.s16 q13, d11, d0[3] \n" /* */ + + "vld1.s32 {d4}, [%[r3]] \n" /* load input r2 -> d2 + */ + "vmovl.s8 q2, d4 \n" /* movl d2 -> q2 */ + + "vmlal.s16 q14, d11, d1[0] \n" /* */ + "vmlal.s16 q15, d11, d1[1] \n" /* */ + + "add %[r3], #4 \n" + + /*output 2rd row */ + "vmlal.s16 q12, d12, d4[0] \n" /* */ + "vmlal.s16 q13, d12, d4[1] \n" /* */ + "vmlal.s16 q14, d12, d4[2] \n" /* */ + "vmlal.s16 q15, d12, d4[3] \n" /* */ + + "vld1.32 {d0-d3}, [%[ptr_out1]] \n" /* */ + + /*output 2rd row */ + "vmlal.s16 q12, d13, d4[1] \n" /* */ + "vmlal.s16 q13, d13, d4[2] \n" /* */ + "vmlal.s16 q14, d13, d4[3] \n" /* */ + "vmlal.s16 q15, d13, d5[0] \n" /* */ + + "subs %[cnt], #1 \n" + + /*output 2rd row */ + "vmlal.s16 q12, d14, d4[2] \n" /* */ + "vmlal.s16 q13, d14, d4[3] \n" /* */ + "vmlal.s16 q14, d14, d5[0] \n" /* */ + "vmlal.s16 q15, d14, d5[1] \n" /* */ + + /*output 2rd row*/ + "vadd.s32 q0, q12, q0 \n" /* */ + "vadd.s32 q1, q13, q1 \n" /* */ + "vst1.s32 {d0-d3}, [%[ptr_out1]]! \n" + + "vld1.32 {d0-d3}, [%[ptr_out1]] \n" /* */ + "vadd.s32 q0, q14, q0 \n" /* */ + "vadd.s32 q1, q15, q1 \n" /* */ + "vst1.s32 {d0-d3}, [%[ptr_out1]]! \n" + + "bne 1b \n" /* jump to main loop*/ + + : [cnt] "+r"(cnt), + [r0] "+r"(r0), + [r1] "+r"(r1), + [r2] "+r"(r2), + [r3] "+r"(r3), + [ptr_out0] "+r"(ptr_out0), + [ptr_out1] "+r"(ptr_out1), + [wc0] "+r"(ptr_wc) + : + : "cc", + "memory", + "q0", + "q1", + "q2", + "q3", + "q4", + "q5", + "q6", + "q7", + "q8", + "q9", + "q10", + "q11", + "q12", + "q13", + "q14", + "q15"); + + wc0 += 9 * hout_c_block; + inr0 += win_round; + inr1 += win_round; + inr2 += win_round; + inr3 += win_round; + } + + if (out_type == PRECISION(kFloat)) { + write_to_output_c4_int32_1(pre_out, + reinterpret_cast(dout_batch), + hout_c_block, + hout_r_block, + c, + c + 4, + h, + h + 2, + 0, + wout_round, + chout, + hout, + wout, + flag_relu, + reinterpret_cast(ptr_write), + &scale[c], + out_type); + } else if (out_type == PRECISION(kInt8)) { + write_to_output_c4_int32_1(pre_out, + dout_batch, + hout_c_block, + hout_r_block, + c, + c + 4, + h, + h + 2, + 0, + wout_round, + chout, + hout, + wout, + flag_relu, + reinterpret_cast(ptr_write), + &scale[c], + out_type); + } else { // int32 + write_to_output_c4_int32(pre_out, + reinterpret_cast(dout_batch), + hout_c_block, + hout_r_block, + c, + c + 4, + h, + h + 2, + 0, + wout_round, + chout, + hout, + wout, + flag_relu, + ptr_write); + } + } + } + } +} + +#endif // __aarch64__ + +} // namespace math +} // namespace arm +} // namespace lite +} // namespace paddle diff --git a/lite/arm/math/conv3x3s2_direct_int8.cc b/lite/arm/math/conv3x3s2_direct_int8.cc new file mode 100644 index 00000000000..a73f685283c --- /dev/null +++ b/lite/arm/math/conv3x3s2_direct_int8.cc @@ -0,0 +1,1081 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include "lite/arm/math/conv_block_utils.h" +#include "lite/arm/math/conv_impl.h" +#include "lite/core/context.h" +#include "lite/operators/op_params.h" +#ifdef ARM_WITH_OMP +#include +#endif + +namespace paddle { +namespace lite { +namespace arm { +namespace math { + +#ifdef __aarch64__ +int conv_3x3s2_direct_int8_c_num() { return 8; } +void conv_3x3s2_direct_int8(const int8_t* din, + int32_t* dout, + int num, + int chout, + int hout, + int wout, + int chin, + int hin, + int win, + const int8_t* weights, + const int32_t* bias, + const operators::ConvParam& param, + Context* ctx, + PrecisionType out_type, + const float* scale) { + //! 3x3s2 int8 convolution, implemented by direct algorithm + //! prepack input to tmp buffer + //! write output to tmp buffer + int threads = ctx->threads(); + int stride_w = param.strides[1]; + int pad_w = param.paddings[1]; + int pad_h = param.paddings[0]; + bool flag_relu = param.fuse_relu; + bool flag_bias = (param.bias != nullptr); + + //! set 2/3 l2 cache + int l2_size = ctx->llc_size() / 3 * 2; + const int hout_c_block = 8; + const int hout_r_kernel = 2; + const int wout_round = ((wout + 3) / 4) * 4; + const int win_round = wout_round * stride_w + 1; + + //! get h block + //! win_round * chin * hin_r_block * sizeof(int8_t) + wout_round * + //! hout_c_block * hout_r_block * threads * sizeof(int32_t)= l2_size + //! win_round = 2 * wout_round + 1 + //! hin_r_block = 2 * hout_r_block + 1 + int hout_r_block = + (l2_size - 2 * wout_round * chin - chin) / + ((4 * wout_round + 2) * chin + wout_round * hout_c_block * threads * 4); + hout_r_block = hout_r_block > hout ? hout : hout_r_block; + hout_r_block = (hout_r_block / hout_r_kernel) * hout_r_kernel; + hout_r_block = hout_r_block < hout_r_kernel ? hout_r_kernel : hout_r_block; + + const int hin_r_block = hout_r_block * 2 + 1; + + int8_t* tmp_work_space = ctx->workspace_data(); + int zero_size = chout > (win_round + 3) / 4 ? chout : (win_round + 3) / 4; + const int kZeroSize = zero_size; + int32_t ptr_zero[kZeroSize]; + memset(ptr_zero, 0, sizeof(int32_t) * zero_size); + const int kWoutRound = wout_round; + int32_t ptr_write[kWoutRound]; + + int in_len = win_round * chin; + int pre_in_size = hin_r_block * in_len; + int pre_out_size = hout_c_block * hout_r_block * wout_round; + + //! 
l2_cache start + int8_t* pre_din = tmp_work_space; + + int size_in_channel = win * hin; + int size_out_channel = wout * hout; + int w_stride = chin * 9; + + int ws = -pad_w; + int we = ws + win_round; + int w_loop = wout_round / 4; + + int out_row_stride = hout_c_block * wout_round; + + for (int n = 0; n < num; ++n) { + const int8_t* din_batch = din + n * chin * size_in_channel; + int8_t* dout_batch = + reinterpret_cast(dout) + + n * chout * size_out_channel * PrecisionTypeLength(out_type); + for (int h = 0; h < hout; h += hout_r_block) { + int h_kernel = hout_r_block; + if (h + hout_r_block > hout) { + h_kernel = hout - h; + } + int hs = h * 2 - pad_h; + int he = hs + h_kernel * 2 + 1; + prepack_input_nxw(din_batch, + pre_din, + 0, + chin, + hs, + he, + ws, + we, + chin, + win, + hin, + reinterpret_cast(ptr_zero)); + + const int8_t* cblock_inr0 = pre_din; + const int8_t* cblock_inr1 = cblock_inr0 + in_len; + const int8_t* cblock_inr2 = cblock_inr1 + in_len; + const int8_t* cblock_inr3 = cblock_inr2 + in_len; + const int8_t* cblock_inr4 = cblock_inr3 + in_len; + +#pragma omp parallel for num_threads(threads) + for (int c = 0; c < chout; c += hout_c_block) { +#ifdef ARM_WITH_OMP + int32_t* pre_out = + reinterpret_cast(pre_din + (pre_in_size + 3) / 4 * 4) + + omp_get_thread_num() * pre_out_size; +#else + int32_t* pre_out = + reinterpret_cast(pre_din + (pre_in_size + 3) / 4 * 4); +#endif + const int8_t* block_inr0 = cblock_inr0; + const int8_t* block_inr1 = cblock_inr1; + const int8_t* block_inr2 = cblock_inr2; + const int8_t* block_inr3 = cblock_inr3; + const int8_t* block_inr4 = cblock_inr4; + + const int8_t* weight_c = weights + c * w_stride; + const int32_t* bias_ptr = ptr_zero; + if (flag_bias) { + bias_ptr = bias + c; + } + + fill_packed_bias_nxmw_int8(bias_ptr, pre_out, 8, h_kernel, wout_round); + for (int hk = 0; hk < h_kernel; hk += hout_r_kernel) { + const int8_t* wc0 = weight_c; + + const int8_t* inr0 = block_inr0; + const int8_t* inr1 = block_inr1; + const int8_t* inr2 = block_inr2; + const int8_t* inr3 = block_inr3; + const int8_t* inr4 = block_inr4; + + int32_t* pre_out0 = pre_out + hk * out_row_stride; + int32_t* pre_out1 = pre_out0 + out_row_stride; + for (int i = 0; i < chin; ++i) { + int16x8_t v0 = vmovl_s8(vld1_s8(wc0)); // w0 + int16x8_t v1 = vmovl_s8(vld1_s8(wc0 + 8)); // w1 + int16x8_t v2 = vmovl_s8(vld1_s8(wc0 + 16)); // w2, + + int16x8_t v3 = vmovl_s8(vld1_s8(wc0 + 24)); // w3 + int16x8_t v4 = vmovl_s8(vld1_s8(wc0 + 32)); // w4 + int16x8_t v5 = vmovl_s8(vld1_s8(wc0 + 40)); // w5 + + int16x8_t v6 = vmovl_s8(vld1_s8(wc0 + 48)); // w6 + int16x8_t v7 = vmovl_s8(vld1_s8(wc0 + 56)); // w7 + int16x8_t v8 = vmovl_s8(vld1_s8(wc0 + 64)); // w8 + + const int8_t* r0 = inr0; + const int8_t* r1 = inr1; + const int8_t* r2 = inr2; + const int8_t* r3 = inr3; + const int8_t* r4 = inr4; + + int32_t* ptr_out0 = pre_out0; + int32_t* ptr_out1 = pre_out1; + int cnt = w_loop; + + asm volatile( + "ldr q0, [%[r0]], #8 \n" /* load input r0 */ + "ldr q1, [%[r2]], #8 \n" /* load input r2 */ + "sshll v0.8h, v0.8b, #0 \n" /* r0: int8 -> int16 */ + "sshll v1.8h, v1.8b, #0 \n" /* r1: int8 -> int16*/ + "1: \n" /* main loop */ + + /* r0, r2 mul w00 */ + "smull v4.4s, %[v0].4h, v0.h[0]\n" /* outr00 = v0 * r0[0] + */ + "smull2 v5.4s, %[v0].8h, v0.h[0]\n" /* outr00 = v0 * r0[0] + */ + "smull v6.4s, %[v0].4h, v0.h[2]\n" /* outr01 = v0 * r0[2] + */ + "smull2 v7.4s, %[v0].8h, v0.h[2]\n" /* outr00 = v0 * r0[0] + */ + "smull v8.4s, %[v0].4h, v0.h[4]\n" /* outr02 = v0 * r0[4] + */ + "smull2 v9.4s, %[v0].8h, 
v0.h[4]\n" /* outr00 = v0 * r0[0] + */ + "smull v10.4s, %[v0].4h, v0.h[6]\n" /* outr03 = v0 * r0[6] + */ + "smull2 v11.4s, %[v0].8h, v0.h[6]\n" /* outr00 = v0 * r0[0] + */ + + "smull v12.4s, %[v0].4h, v1.h[0]\n" /* outr10 = v0 * r2[0] + */ + "smull2 v13.4s, %[v0].8h, v1.h[0]\n" /* outr11 = v0 * r2[2] + */ + "smull v14.4s, %[v0].4h, v1.h[2]\n" /* outr12 = v0 * r2[4] + */ + "smull2 v15.4s, %[v0].8h, v1.h[2]\n" /* outr13 = v0 * r2[6] + */ + "smull v16.4s, %[v0].4h, v1.h[4]\n" /* outr10 = v0 * r2[0] + */ + "smull2 v17.4s, %[v0].8h, v1.h[4]\n" /* outr11 = v0 * r2[2] + */ + "smull v18.4s, %[v0].4h, v1.h[6]\n" /* outr12 = v0 * r2[4] + */ + "smull2 v19.4s, %[v0].8h, v1.h[6]\n" /* outr13 = v0 * r2[6] + */ + + /* r2, mul w06 */ + "smlal v4.4s, %[v6].4h, v1.h[0]\n" /* outr00 = v6 * r2[1] + */ + "smlal2 v5.4s, %[v6].8h, v1.h[0]\n" /* outr01 = v6 * r2[3] + */ + "smlal v6.4s, %[v6].4h, v1.h[2]\n" /* outr02 = v6 * r2[5] + */ + "smlal2 v7.4s, %[v6].8h, v1.h[2]\n" /* outr03 = v6 * r2[7] + */ + "smlal v8.4s, %[v6].4h, v1.h[4]\n" /* outr00 = v6 * r2[1] + */ + "smlal2 v9.4s, %[v6].8h, v1.h[4]\n" /* outr01 = v6 * r2[3] + */ + "smlal v10.4s, %[v6].4h, v1.h[6]\n" /* outr02 = v6 * r2[5] + */ + "smlal2 v11.4s, %[v6].8h, v1.h[6]\n" /* outr03 = v6 * r2[7] + */ + + "ldr q2, [%[r0]] \n" /* load r0, 9th + data,v10.s[0] */ + + /* r0, r2, mul w01 */ + "smlal v4.4s, %[v1].4h, v0.h[1]\n" /* outr00 = v0 * r0[0] + */ + "smlal2 v5.4s, %[v1].8h, v0.h[1]\n" /* outr00 = v0 * r0[0] + */ + "smlal v6.4s, %[v1].4h, v0.h[3]\n" /* outr01 = v0 * r0[2] + */ + "smlal2 v7.4s, %[v1].8h, v0.h[3]\n" /* outr00 = v0 * r0[0] + */ + "sshll v2.8h, v2.8b, #0 \n" /* r0: int8 -> int16 */ + "smlal v8.4s, %[v1].4h, v0.h[5]\n" /* outr02 = v0 * r0[4] + */ + "smlal2 v9.4s, %[v1].8h, v0.h[5]\n" /* outr00 = v0 * r0[0] + */ + "smlal v10.4s, %[v1].4h, v0.h[7]\n" /* outr03 = v0 * r0[6] + */ + "smlal2 v11.4s, %[v1].8h, v0.h[7]\n" /* outr00 = v0 * r0[0] + */ + + "smlal v12.4s, %[v1].4h, v1.h[1]\n" /* outr10 = v0 * r2[0] + */ + "smlal2 v13.4s, %[v1].8h, v1.h[1]\n" /* outr11 = v0 * r2[2] + */ + "smlal v14.4s, %[v1].4h, v1.h[3]\n" /* outr12 = v0 * r2[4] + */ + "smlal2 v15.4s, %[v1].8h, v1.h[3]\n" /* outr13 = v0 * r2[6] + */ + "smlal v16.4s, %[v1].4h, v1.h[5]\n" /* outr10 = v0 * r2[0] + */ + "smlal2 v17.4s, %[v1].8h, v1.h[5]\n" /* outr11 = v0 * r2[2] + */ + "smlal v18.4s, %[v1].4h, v1.h[7]\n" /* outr12 = v0 * r2[4] + */ + "smlal2 v19.4s, %[v1].8h, v1.h[7]\n" /* outr13 = v0 * r2[6] + */ + + /* r2, mul w07 */ + "smlal v4.4s, %[v7].4h, v1.h[1]\n" /* outr00 = v6 * r2[1] + */ + "smlal2 v5.4s, %[v7].8h, v1.h[1]\n" /* outr01 = v6 * r2[3] + */ + "smlal v6.4s, %[v7].4h, v1.h[3]\n" /* outr02 = v6 * r2[5] + */ + "smlal2 v7.4s, %[v7].8h, v1.h[3]\n" /* outr03 = v6 * r2[7] + */ + "smlal v8.4s, %[v7].4h, v1.h[5]\n" /* outr00 = v6 * r2[1] + */ + "smlal2 v9.4s, %[v7].8h, v1.h[5]\n" /* outr01 = v6 * r2[3] + */ + "smlal v10.4s, %[v7].4h, v1.h[7]\n" /* outr02 = v6 * r2[5] + */ + "smlal2 v11.4s, %[v7].8h, v1.h[7]\n" /* outr03 = v6 * r2[7] + */ + + "ldr q3, [%[r2]] \n" /* load r2, 9th + data,v11.s[0] */ + + /* r0, r2, mul w02 */ + "smlal v4.4s, %[v2].4h, v0.h[2]\n" /* outr00 = v0 * r0[0] + */ + "smlal2 v5.4s, %[v2].8h, v0.h[2]\n" /* outr00 = v0 * r0[0] + */ + "smlal v6.4s, %[v2].4h, v0.h[4]\n" /* outr01 = v0 * r0[2] + */ + "smlal2 v7.4s, %[v2].8h, v0.h[4]\n" /* outr00 = v0 * r0[0] + */ + "sshll v3.8h, v3.8b, #0 \n" /* r2: int8 -> int16*/ + "smlal v8.4s, %[v2].4h, v0.h[6]\n" /* outr02 = v0 * r0[4] + */ + "smlal2 v9.4s, %[v2].8h, v0.h[6]\n" /* outr00 = v0 * r0[0] + */ + "smlal v10.4s, 
%[v2].4h, v2.h[0]\n" /* outr03 = v0 * r0[6] + */ + "smlal2 v11.4s, %[v2].8h, v2.h[0]\n" /* outr00 = v0 * r0[0] + */ + + "ldr q0, [%[r1]], #8 \n" /* load input r1 */ + + "smlal v12.4s, %[v2].4h, v1.h[2]\n" /* outr10 = v0 * r2[0] + */ + "smlal2 v13.4s, %[v2].8h, v1.h[2]\n" /* outr11 = v0 * r2[2] + */ + "smlal v14.4s, %[v2].4h, v1.h[4]\n" /* outr12 = v0 * r2[4] + */ + "smlal2 v15.4s, %[v2].8h, v1.h[4]\n" /* outr13 = v0 * r2[6] + */ + "sshll v0.8h, v0.8b, #0 \n" /* r1 : int8 -> int16 */ + "smlal v16.4s, %[v2].4h, v1.h[6]\n" /* outr10 = v0 * r2[0] + */ + "smlal2 v17.4s, %[v2].8h, v1.h[6]\n" /* outr11 = v0 * r2[2] + */ + "smlal v18.4s, %[v2].4h, v3.h[0]\n" /* outr12 = v0 * r2[4] + */ + "smlal2 v19.4s, %[v2].8h, v3.h[0]\n" /* outr13 = v0 * r2[6] + */ + + /* r2, mul w08 */ + "smlal v4.4s, %[v8].4h, v1.h[2]\n" /* outr00 = v6 * r2[1] + */ + "smlal2 v5.4s, %[v8].8h, v1.h[2]\n" /* outr01 = v6 * r2[3] + */ + "smlal v6.4s, %[v8].4h, v1.h[4]\n" /* outr02 = v6 * r2[5] + */ + "smlal2 v7.4s, %[v8].8h, v1.h[4]\n" /* outr03 = v6 * r2[7] + */ + "smlal v8.4s, %[v8].4h, v1.h[6]\n" /* outr00 = v6 * r2[1] + */ + "smlal2 v9.4s, %[v8].8h, v1.h[6]\n" /* outr01 = v6 * r2[3] + */ + "smlal v10.4s, %[v8].4h, v3.h[0]\n" /* outr02 = v6 * r2[5] + */ + "smlal2 v11.4s, %[v8].8h, v3.h[0]\n" /* outr03 = v6 * r2[7] + */ + + "ldr q1, [%[r3]], #8 \n" /* load input r3 */ + + /* r1, r3, mul w03 */ + "smlal v4.4s, %[v3].4h, v0.h[0]\n" /* outr00 = v0 * r0[0] + */ + "smlal2 v5.4s, %[v3].8h, v0.h[0]\n" /* outr00 = v0 * r0[0] + */ + "smlal v6.4s, %[v3].4h, v0.h[2]\n" /* outr01 = v0 * r0[2] + */ + "smlal2 v7.4s, %[v3].8h, v0.h[2]\n" /* outr00 = v0 * r0[0] + */ + "sshll v1.8h, v1.8b, #0 \n" /* r3: int8 -> int16 */ + "smlal v8.4s, %[v3].4h, v0.h[4]\n" /* outr02 = v0 * r0[4] + */ + "smlal2 v9.4s, %[v3].8h, v0.h[4]\n" /* outr00 = v0 * r0[0] + */ + "smlal v10.4s, %[v3].4h, v0.h[6]\n" /* outr03 = v0 * r0[6] + */ + "smlal2 v11.4s, %[v3].8h, v0.h[6]\n" /* outr00 = v0 * r0[0] + */ + "ldr q2, [%[r1]] \n" /* load r1, 9th + data,v10.s[0] */ + + "smlal v12.4s, %[v3].4h, v1.h[0]\n" /* outr10 = v0 * r2[0] + */ + "smlal2 v13.4s, %[v3].8h, v1.h[0]\n" /* outr11 = v0 * r2[2] + */ + "smlal v14.4s, %[v3].4h, v1.h[2]\n" /* outr12 = v0 * r2[4] + */ + "smlal2 v15.4s, %[v3].8h, v1.h[2]\n" /* outr13 = v0 * r2[6] + */ + "ldr q3, [%[r3]] \n" /* load r3, 9th + data,v11.s[0] */ + "smlal v16.4s, %[v3].4h, v1.h[4]\n" /* outr10 = v0 * r2[0] + */ + "smlal2 v17.4s, %[v3].8h, v1.h[4]\n" /* outr11 = v0 * r2[2] + */ + "smlal v18.4s, %[v3].4h, v1.h[6]\n" /* outr12 = v0 * r2[4] + */ + "smlal2 v19.4s, %[v3].8h, v1.h[6]\n" /* outr13 = v0 * r2[6] + */ + "sshll v2.8h, v2.8b, #0 \n" /* r1 : int8 -> int16 */ + + /* r1, r3, mul w05 */ + "smlal v4.4s, %[v5].4h, v0.h[2]\n" /* outr00 = v0 * r0[0] + */ + "smlal2 v5.4s, %[v5].8h, v0.h[2]\n" /* outr00 = v0 * r0[0] + */ + "smlal v6.4s, %[v5].4h, v0.h[4]\n" /* outr01 = v0 * r0[2] + */ + "smlal2 v7.4s, %[v5].8h, v0.h[4]\n" /* outr00 = v0 * r0[0] + */ + "sshll v3.8h, v3.8b, #0 \n" /* r3 : int8 -> int16 */ + "smlal v8.4s, %[v5].4h, v0.h[6]\n" /* outr02 = v0 * r0[4] + */ + "smlal2 v9.4s, %[v5].8h, v0.h[6]\n" /* outr00 = v0 * r0[0] + */ + "smlal v10.4s, %[v5].4h, v2.h[0]\n" /* outr03 = v0 * r0[6] + */ + "smlal2 v11.4s, %[v5].8h, v2.h[0]\n" /* outr00 = v0 * r0[0] + */ + + "smlal v12.4s, %[v5].4h, v1.h[2]\n" /* outr10 = v0 * r2[0] + */ + "smlal2 v13.4s, %[v5].8h, v1.h[2]\n" /* outr11 = v0 * r2[2] + */ + "smlal v14.4s, %[v5].4h, v1.h[4]\n" /* outr12 = v0 * r2[4] + */ + "smlal2 v15.4s, %[v5].8h, v1.h[4]\n" /* outr13 = v0 * r2[6] + */ + "smlal v16.4s, 
%[v5].4h, v1.h[6]\n" /* outr10 = v0 * r2[0] + */ + "smlal2 v17.4s, %[v5].8h, v1.h[6]\n" /* outr11 = v0 * r2[2] + */ + "smlal v18.4s, %[v5].4h, v3.h[0]\n" /* outr12 = v0 * r2[4] + */ + "smlal2 v19.4s, %[v5].8h, v3.h[0]\n" /* outr13 = v0 * r2[6] + */ + + "subs %w[cnt], %w[cnt], #1 \n" /* loop count -1 */ + + /* r1, r3, mul w04 */ + "smlal v4.4s, %[v4].4h, v0.h[1]\n" /* outr00 = v0 * r0[0] + */ + "smlal2 v5.4s, %[v4].8h, v0.h[1]\n" /* outr00 = v0 * r0[0] + */ + "smlal v6.4s, %[v4].4h, v0.h[3]\n" /* outr01 = v0 * r0[2] + */ + "smlal2 v7.4s, %[v4].8h, v0.h[3]\n" /* outr00 = v0 * r0[0] + */ + "smlal v8.4s, %[v4].4h, v0.h[5]\n" /* outr02 = v0 * r0[4] + */ + "smlal2 v9.4s, %[v4].8h, v0.h[5]\n" /* outr00 = v0 * r0[0] + */ + "smlal v10.4s, %[v4].4h, v0.h[7]\n" /* outr03 = v0 * r0[6] + */ + "smlal2 v11.4s, %[v4].8h, v0.h[7]\n" /* outr00 = v0 * r0[0] + */ + + "ldr q0, [%[r4]], #8 \n" /* load input r4 */ + + "smlal v12.4s, %[v4].4h, v1.h[1]\n" /* outr10 = v0 * r2[0] + */ + "smlal2 v13.4s, %[v4].8h, v1.h[1]\n" /* outr11 = v0 * r2[2] + */ + "smlal v14.4s, %[v4].4h, v1.h[3]\n" /* outr12 = v0 * r2[4] + */ + "smlal2 v15.4s, %[v4].8h, v1.h[3]\n" /* outr13 = v0 * r2[6] + */ + "sshll v0.8h, v0.8b, #0 \n" /* r4 : int8 -> int16 */ + "smlal v16.4s, %[v4].4h, v1.h[5]\n" /* outr10 = v0 * r2[0] + */ + "smlal2 v17.4s, %[v4].8h, v1.h[5]\n" /* outr11 = v0 * r2[2] + */ + "smlal v18.4s, %[v4].4h, v1.h[7]\n" /* outr12 = v0 * r2[4] + */ + "smlal2 v19.4s, %[v4].8h, v1.h[7]\n" /* outr13 = v0 * r2[6] + */ + + "ldr q2, [%[r4]] \n" /* load r4, 9th + data,v10.s[0] */ + "sshll v2.8h, v2.8b, #0 \n" /* r4 : int8 -> int16 */ + + "ldp q1, q3, [%[ptr_out0]] \n" /* load ptr_out + 0 -> + q2, q3 */ + "ldp q20, q21, [%[ptr_out0], #32]\n" /* load ptr_out + 32 -> + q4, q5 */ + + "add v4.4s, v1.4s , v4.4s \n" /* v10 = outr00[0].low + + q2 */ + "add v5.4s, v3.4s , v5.4s \n" /* v11 = outr00[0].high + + q3 */ + "add v6.4s, v20.4s, v6.4s \n" /* v12 = outr01[0].low + + q4 */ + "add v7.4s, v21.4s, v7.4s \n" /* v13 = outr01[0].high + + q5 */ + + "ldp q1 , q3 , [%[ptr_out0], #64]\n" /* load ptr_out + 64 -> + q6, q7 */ + "ldp q20, q21, [%[ptr_out0], #96]\n" /* load ptr_out + 96 -> + q8, q9 */ + + "stp q4, q5 , [%[ptr_out0]], #32\n" /* store q10, q11 -> + ptr_out */ + "stp q6, q7 , [%[ptr_out0]], #32\n" /* store q10, q11 -> + ptr_out */ + + "add v8.4s , v1.4s , v8.4s \n" /* v10 = outr00[0].low + + q2 */ + "add v9.4s , v3.4s , v9.4s \n" /* v11 = outr00[0].high + + q3 */ + "add v10.4s, v20.4s, v10.4s \n" /* v12 = outr01[0].low + + q4 */ + "add v11.4s, v21.4s, v11.4s \n" /* v13 = outr01[0].high + + q5 */ + "stp q8, q9, [%[ptr_out0]], #32\n" /* store q14, q15 -> + ptr_out += 64 */ + "stp q10, q11, [%[ptr_out0]], #32\n" /* store q16, q17 -> + ptr_out += 96 */ + + /* r4, mul w08 */ + "smlal v12.4s, %[v8].4h, v0.h[2]\n" /* outr00 = v0 * r0[0] + */ + "smlal2 v13.4s, %[v8].8h, v0.h[2]\n" /* outr00 = v0 * r0[0] + */ + "smlal v14.4s, %[v8].4h, v0.h[4]\n" /* outr01 = v0 * r0[2] + */ + "smlal2 v15.4s, %[v8].8h, v0.h[4]\n" /* outr00 = v0 * r0[0] + */ + + "smlal v16.4s, %[v8].4h, v0.h[6]\n" /* outr02 = v0 * r0[4] + */ + "smlal2 v17.4s, %[v8].8h, v0.h[6]\n" /* outr00 = v0 * r0[0] + */ + "smlal v18.4s, %[v8].4h, v2.h[0]\n" /* outr03 = v0 * r0[6] + */ + "smlal2 v19.4s, %[v8].8h, v2.h[0]\n" /* outr00 = v0 * r0[0] + */ + + /* r4, mul w07 */ + "smlal v12.4s, %[v7].4h, v0.h[1]\n" /* outr00 = v0 * r0[0] + */ + "smlal2 v13.4s, %[v7].8h, v0.h[1]\n" /* outr00 = v0 * r0[0] + */ + "smlal v14.4s, %[v7].4h, v0.h[3]\n" /* outr01 = v0 * r0[2] + */ + "smlal2 v15.4s, %[v7].8h, 
v0.h[3]\n" /* outr00 = v0 * r0[0] + */ + + "ldr q1, [%[r2]], #8 \n" /* load input r2 */ + + "smlal v16.4s, %[v7].4h, v0.h[5]\n" /* outr02 = v0 * r0[4] + */ + "smlal2 v17.4s, %[v7].8h, v0.h[5]\n" /* outr00 = v0 * r0[0] + */ + "smlal v18.4s, %[v7].4h, v0.h[7]\n" /* outr03 = v0 * r0[6] + */ + "smlal2 v19.4s, %[v7].8h, v0.h[7]\n" /* outr00 = v0 * r0[0] + */ + + "sshll v1.8h, v1.8b, #0 \n" /* r2: int8 -> int16 + */ + + /* r4, mul w06 */ + "ldp q4, q5, [%[ptr_out1]] \n" /* load ptr_out + 0 -> + q2, q3 */ + + "smlal v12.4s, %[v6].4h, v0.h[0]\n" /* outr00 = v0 * r0[0] + */ + "smlal2 v13.4s, %[v6].8h, v0.h[0]\n" /* outr00 = v0 * r0[0] + */ + "smlal v14.4s, %[v6].4h, v0.h[2]\n" /* outr01 = v0 * r0[2] + */ + + "ldp q8, q9, [%[ptr_out1], #64]\n" /* load ptr_out + 64 -> + q6, q7 */ + + "smlal2 v15.4s, %[v6].8h, v0.h[2]\n" /* outr00 = v0 * r0[0] + */ + "smlal v16.4s, %[v6].4h, v0.h[4]\n" /* outr02 = v0 * r0[4] + */ + "smlal2 v17.4s, %[v6].8h, v0.h[4]\n" /* outr00 = v0 * r0[0] + */ + + "ldp q10, q11, [%[ptr_out1], #96]\n" /* load ptr_out + 96 -> + q8, q9 */ + + "smlal v18.4s, %[v6].4h, v0.h[6]\n" /* outr03 = v0 * r0[6] + */ + "smlal2 v19.4s, %[v6].8h, v0.h[6]\n" /* outr00 = v0 * r0[0] + */ + + "ldr q0, [%[r0]], #8 \n" /* load input r2 */ + "ldp q6, q7, [%[ptr_out1], #32]\n" /* load ptr_out + 32 -> + q4, q5 */ + + "sshll v0.8h, v0.8b, #0 \n" /* r0: int8 -> int16 */ + + /* store outr1 */ + "add v12.4s, v4.4s , v12.4s\n" /* v10 = outr10[0].low + q2 */ + "add v13.4s, v5.4s , v13.4s\n" /* v11 = outr10[0].high + q3 */ + "add v14.4s, v6.4s , v14.4s\n" /* v12 = outr11[0].low + q4 */ + "add v15.4s, v7.4s , v15.4s\n" /* v13 = outr11[0].high + q5 */ + + "stp q12, q13, [%[ptr_out1]], #32\n" /* store q10, q11 -> + ptr_out */ + + "add v16.4s, v8.4s , v16.4s\n" /* v14 = outr12[0].low + q6 */ + "add v17.4s, v9.4s , v17.4s\n" /* v15 = outr12[0].high + q7 */ + + "stp q14, q15, [%[ptr_out1]], #32\n" /* store q12, q13 -> + ptr_out += 32 */ + + "add v18.4s, v10.4s, v18.4s\n" /* v16 = outr13[0].low + q8 */ + "add v19.4s, v11.4s, v19.4s\n" /* v17 = outr13[0].high + q9 */ + + "stp q16, q17, [%[ptr_out1]], #32\n" /* store q14, q15 -> + ptr_out += 64 */ + "stp q18, q19, [%[ptr_out1]], #32\n" /* store q16, q17 -> + ptr_out += 96 */ + + "bne 1b \n" /* jump to main loop */ + + : [cnt] "+r"(cnt), + [r0] "+r"(r0), + [r1] "+r"(r1), + [r2] "+r"(r2), + [r3] "+r"(r3), + [r4] "+r"(r4), + [ptr_out0] "+r"(ptr_out0), + [ptr_out1] "+r"(ptr_out1) + : [v0] "w"(v0), + [v1] "w"(v1), + [v2] "w"(v2), + [v3] "w"(v3), + [v4] "w"(v4), + [v5] "w"(v5), + [v6] "w"(v6), + [v7] "w"(v7), + [v8] "w"(v8) + : "cc", + "memory", + "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15", + "v16", + "v17", + "v18", + "v19", + "v20", + "v21", + "v22"); + + wc0 += 9 * hout_c_block; + inr0 += win_round; + inr1 += win_round; + inr2 += win_round; + inr3 += win_round; + inr4 += win_round; + } + block_inr0 = block_inr4; + block_inr1 = block_inr0 + in_len; + block_inr2 = block_inr1 + in_len; + block_inr3 = block_inr2 + in_len; + block_inr4 = block_inr3 + in_len; + } + if (out_type == PRECISION(kFloat)) { + write_to_output_c8_int32_1(pre_out, + reinterpret_cast(dout_batch), + hout_c_block, + 2, + c, + c + hout_c_block, + h, + h + h_kernel, + 0, + wout_round, + chout, + hout, + wout, + flag_relu, + reinterpret_cast(ptr_write), + &scale[c], + out_type); + } else if (out_type == PRECISION(kInt8)) { + write_to_output_c8_int32_1(pre_out, + dout_batch, + hout_c_block, + 2, + c, + c + hout_c_block, + h, + h 
+ h_kernel, + 0, + wout_round, + chout, + hout, + wout, + flag_relu, + reinterpret_cast(ptr_write), + &scale[c], + out_type); + } else { + write_to_output_c8_int32(pre_out, + reinterpret_cast(dout_batch), + hout_c_block, + 2, + c, + c + hout_c_block, + h, + h + h_kernel, + 0, + wout_round, + chout, + hout, + wout, + flag_relu, + ptr_write); + } + } + } + } +} + +#else // __aarch64__ +int conv_3x3s2_direct_int8_c_num() { return 4; } +void conv_3x3s2_direct_int8(const int8_t* din, + int32_t* dout, + int num, + int chout, + int hout, + int wout, + int chin, + int hin, + int win, + const int8_t* weights, + const int32_t* bias, + const operators::ConvParam& param, + Context* ctx, + PrecisionType out_type, + const float* scale) { + //! 3x3s2 int8 convolution, implemented by direct algorithm + //! prepack input to tmp buffer + //! write output to tmp buffer + int threads = ctx->threads(); + int stride_w = param.strides[1]; + int pad_w = param.paddings[1]; + int pad_h = param.paddings[0]; + bool flag_relu = param.fuse_relu; + bool flag_bias = (param.bias != nullptr); + + //! set 2/3 l2 cache + int l2_size = ctx->llc_size() / 3 * 2; + const int hout_c_block = 4; + const int hout_r_kernel = 1; + const int wout_round = ((wout + 3) / 4) * 4; + const int win_round = wout_round * stride_w + 1; + + //! get h block + //! win_round * chin * hin_r_block * sizeof(int8_t) + wout_round * + //! hout_c_block * hout_r_block * threads * sizeof(int32_t)= l2_size + //! win_round = 2 * wout_round + 1 + //! hin_r_block = 2 * hout_r_block + 1 + int hout_r_block = + (l2_size - 2 * wout_round * chin - chin) / + ((4 * wout_round + 2) * chin + wout_round * hout_c_block * threads * 4); + hout_r_block = hout_r_block > hout ? hout : hout_r_block; + hout_r_block = (hout_r_block / hout_r_kernel) * hout_r_kernel; + hout_r_block = hout_r_block < hout_r_kernel ? hout_r_kernel : hout_r_block; + + const int hin_r_block = hout_r_block * 2 + 1; + + int8_t* tmp_work_space = ctx->workspace_data(); + int zero_size = chout > (win_round + 3) / 4 ? chout : (win_round + 3) / 4; + const int kZeroSize = zero_size; + int32_t ptr_zero[kZeroSize]; + memset(ptr_zero, 0, sizeof(int32_t) * zero_size); + const int kWoutRound = wout_round; + int32_t ptr_write[kWoutRound]; + + int in_len = win_round * chin; + int pre_in_size = hin_r_block * in_len; + int pre_out_size = hout_c_block * hout_r_block * wout_round; + + //! 
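+
+// The hout_r_block computation above is the closed-form solve of the L2
+// budget stated in the preceding comment: substitute
+// win_round = 2 * wout_round + 1 and hin_r_block = 2 * hout_r_block + 1 into
+//   win_round * chin * hin_r_block * sizeof(int8_t)
+//     + wout_round * hout_c_block * hout_r_block * threads * sizeof(int32_t)
+//   = l2_size
+// and isolate hout_r_block. A standalone sketch of that arithmetic plus the
+// clamping applied above (the helper name is local to this sketch, not part
+// of the source):
+static int solve_hout_r_block(int l2_size, int chin, int wout_round,
+                              int hout_c_block, int hout_r_kernel,
+                              int threads, int hout) {
+  // (2W+1)*chin*(2H+1) + 4*W*C*H*T = l2
+  //   => H * ((4W+2)*chin + 4*W*C*T) = l2 - 2W*chin - chin
+  int h = (l2_size - 2 * wout_round * chin - chin) /
+          ((4 * wout_round + 2) * chin +
+           wout_round * hout_c_block * threads * 4);
+  h = h > hout ? hout : h;                       // no more rows than exist
+  h = (h / hout_r_kernel) * hout_r_kernel;       // whole kernel rows only
+  return h < hout_r_kernel ? hout_r_kernel : h;  // at least one kernel row
+}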
l2_cache start + int8_t* pre_din = tmp_work_space; + + int size_in_channel = win * hin; + int size_out_channel = wout * hout; + int w_stride = chin * 9; + + int ws = -pad_w; + int we = ws + win_round; + int w_loop = wout_round / 4; + + int out_row_stride = hout_c_block * wout_round; + + for (int n = 0; n < num; ++n) { + const int8_t* din_batch = din + n * chin * size_in_channel; + int8_t* dout_batch = + reinterpret_cast(dout) + + n * chout * size_out_channel * PrecisionTypeLength(out_type); + for (int h = 0; h < hout; h += hout_r_block) { + int h_kernel = hout_r_block; + if (h + hout_r_block > hout) { + h_kernel = hout - h; + } + int hs = h * 2 - pad_h; + int he = hs + h_kernel * 2 + 1; + prepack_input_nxw(din_batch, + pre_din, + 0, + chin, + hs, + he, + ws, + we, + chin, + win, + hin, + reinterpret_cast(ptr_zero)); + + const int8_t* cblock_inr0 = pre_din; + const int8_t* cblock_inr1 = cblock_inr0 + in_len; + const int8_t* cblock_inr2 = cblock_inr1 + in_len; +#pragma omp parallel for num_threads(threads) + for (int c = 0; c < chout; c += hout_c_block) { +#ifdef ARM_WITH_OMP + int32_t* pre_out = + reinterpret_cast(pre_din + (pre_in_size + 3) / 4 * 4) + + omp_get_thread_num() * pre_out_size; +#else + int32_t* pre_out = + reinterpret_cast(pre_din + (pre_in_size + 3) / 4 * 4); +#endif + const int8_t* block_inr0 = cblock_inr0; + const int8_t* block_inr1 = cblock_inr1; + const int8_t* block_inr2 = cblock_inr2; + + const int8_t* weight_c = weights + c * w_stride; + const int32_t* bias_ptr = ptr_zero; + if (flag_bias) { + bias_ptr = bias + c; + } + + fill_packed_bias_nxmw_int8(bias_ptr, pre_out, 4, h_kernel, wout_round); + for (int hk = 0; hk < h_kernel; hk += hout_r_kernel) { + const int8_t* wc0 = weight_c; + + const int8_t* inr0 = block_inr0; + const int8_t* inr1 = block_inr1; + const int8_t* inr2 = block_inr2; + + int32_t* pre_out0 = pre_out + hk * out_row_stride; + for (int i = 0; i < chin; ++i) { + const int8_t* r0 = inr0; + const int8_t* r1 = inr1; + const int8_t* r2 = inr2; + + int32_t* ptr_out0 = pre_out0; + const signed char* ptr_wc0 = wc0; + int cnt = w_loop; + asm volatile( + "vld1.s32 {d0-d3}, [%[wc0]]! \n" /* w0-w7 */ + "vld1.s32 {d4}, [%[wc0]]! \n" /* w8 */ + "vmovl.s8 q3, d0 \n" /* q3 = w0, w1 */ + "vmovl.s8 q4, d1 \n" /* q4 = w2 ,w3 */ + "vmovl.s8 q5, d2 \n" /* q5 = w4, w5 */ + "vmovl.s8 q6, d3 \n" /* q6 = w6, w7 */ + "vmovl.s8 q7, d4 \n" /* q7 = w8 */ + "vld1.s32 {d0}, [%[r0]]! \n" /* load input r0 -> d0 */ + "vmovl.s8 q0, d0 \n" /* movl d0 -> q0 */ + "1: \n" /* main loop */ + + /* r0 mul w0 */ + "vmull.s16 q8, d6, d0[0] \n" /* q8 = w0 * r0[0] */ + "vmull.s16 q9, d6, d0[2] \n" /* q9 = w0 * r0[2] */ + "vmull.s16 q10, d6, d1[0] \n" /* q10 = w0 * r0[4] */ + "vmull.s16 q11, d6, d1[2] \n" /* q11 = w0 * r0[6] */ + + "vld1.s32 {d2}, [%[r1]]! 
\n" /* load input r1 -> d2 */ + "vmovl.s8 q1, d2 \n" /* movl d2 -> q1 */ + + /* r0 mul w1 */ + "vmlal.s16 q8, d7, d0[1] \n" /* q8 = w1 * r0[1] */ + "vmlal.s16 q9, d7, d0[3] \n" /* q9 = w1 * r0[3] */ + "vmlal.s16 q10, d7, d1[1] \n" /* q10 = w1 * r0[5] */ + "vmlal.s16 q11, d7, d1[3] \n" /* q11 = w1 * r0[7] */ + + "vld1.s32 {d4}, [%[r0]] \n" /* load r0[8] -> d4 */ + "vmovl.s8 q2 , d4 \n" /* movl d4 -> q2 */ + + /* r0 mul w2 */ + "vmlal.s16 q8, d8, d0[2] \n" /* q8 = w2 * r0[2] */ + "vmlal.s16 q9, d8, d1[0] \n" /* q9 = w2 * r0[4] */ + "vmlal.s16 q10, d8, d1[2] \n" /* q10 = w2 * r0[6] */ + "vmlal.s16 q11, d8, d4[0] \n" /* q11 = w2 * r0[8] */ + + "subs %[cnt], #1 \n" /* loop count -1 */ + + /* r1 mul w3 */ + "vmlal.s16 q8, d9, d2[0] \n" /* q8 = w3 * r1[0] */ + "vmlal.s16 q9, d9, d2[2] \n" /* q9 = w3 * r1[2] */ + "vmlal.s16 q10, d9, d3[0] \n" /* q10 = w3 * r1[4] */ + "vmlal.s16 q11, d9, d3[2] \n" /* q11 = w3 * r1[6] */ + + "vld1.s32 {d4}, [%[r2]]! \n" /* load input r2 -> d4*/ + "vmovl.s8 q2, d4 \n" /* movl d4 -> q2 */ + + /* r1 mul w4 */ + "vmlal.s16 q8, d10, d2[1] \n" /* q8 = w4 * r1[1] */ + "vmlal.s16 q9, d10, d2[3] \n" /* q9 = w4 * r1[3] */ + "vmlal.s16 q10, d10, d3[1] \n" /* q10 = w4 * r1[5] */ + "vmlal.s16 q11, d10, d3[3] \n" /* q11 = w4 * r1[7] */ + + "vld1.s32 {d0}, [%[r1]] \n" /* load r1[8] -> d0 */ + "vmovl.s8 q0, d0 \n" /* movl d0 -> q0 */ + + /* r1 mul w5 */ + "vmlal.s16 q8, d11, d2[2] \n" /* q8 = w5 * r1[2] */ + "vmlal.s16 q9, d11, d3[0] \n" /* q9 = w5 * r1[4] */ + "vmlal.s16 q10, d11, d3[2] \n" /* q10 = w5 * r1[6] */ + "vmlal.s16 q11, d11, d0[0] \n" /* q11 = w5 * r1[8] */ + + /* r2 mul w6 */ + "vmlal.s16 q8, d12, d4[0] \n" /* q8 = w6 * r2[0] */ + "vmlal.s16 q9, d12, d4[2] \n" /* q9 = w6 * r2[2] */ + "vmlal.s16 q10, d12, d5[0] \n" /* q10 = w6 * r2[4] */ + "vmlal.s16 q11, d12, d5[2] \n" /* q11 = w6 * r2[6] */ + + "vld1.s32 {d24-d27}, [%[ptr_out0]] \n" /* load output -> q12, + q13 */ + + /* r2 mul w7 */ + "vmlal.s16 q8, d13, d4[1] \n" /* q8 = w7 * r2[1] */ + "vmlal.s16 q9, d13, d4[3] \n" /* q9 = w7 * r2[3] */ + "vmlal.s16 q10, d13, d5[1] \n" /* q10 = w7 * r2[5] */ + "vmlal.s16 q11, d13, d5[3] \n" /* q11 = w7 * r2[7] */ + + "vld1.s32 {d0}, [%[r2]] \n" /* load r2[8] -> d0 */ + "vmovl.s8 q0, d0 \n" /* movl d0 -> q0 */ + + /* r2 mul w8 */ + "vmlal.s16 q8, d14, d4[2] \n" /* q8 = w8 * r2[2] */ + "vmlal.s16 q9, d14, d5[0] \n" /* q9 = w8 * r2[4] */ + "vmlal.s16 q10, d14, d5[2] \n" /* q10 = w8 * r2[6] */ + "vmlal.s16 q11, d14, d0[0] \n" /* q11 = w8 * r2[8] */ + + "vadd.s32 q12, q8, q12 \n" /* out[0] += q8 */ + "vadd.s32 q13, q9, q13 \n" /* out[1] += q9 */ + "vst1.s32 {d24-d27}, [%[ptr_out0]]! \n" /* store q12, q13 -> + output[0,1] */ + + "vld1.s32 {d0}, [%[r0]]! \n" /* load next input r0 -> d0*/ + "vmovl.s8 q0, d0 \n" /* movl d0 -> q0 */ + + "vld1.s32 {d28-d31}, [%[ptr_out0]] \n" /* load output[0,1] -> + q14, q15 */ + "vadd.s32 q14, q10, q14 \n" /* out[2] += q10 */ + "vadd.s32 q15, q11, q15 \n" /* out[3] += q11 */ + "vst1.s32 {d28-d31}, [%[ptr_out0]]! 
\n" /* store q14, q15 -> + output[2,3] */ + + "bne 1b \n" /* jump to main loop */ + + : [cnt] "+r"(cnt), + [r0] "+r"(r0), + [r1] "+r"(r1), + [r2] "+r"(r2), + [ptr_out0] "+r"(ptr_out0), + [wc0] "+r"(ptr_wc0) + : + : "cc", + "memory", + "q0", + "q1", + "q2", + "q3", + "q4", + "q5", + "q6", + "q7", + "q8", + "q9", + "q10", + "q11", + "q12", + "q13", + "q14", + "q15"); + wc0 += 9 * hout_c_block; + inr0 += win_round; + inr1 += win_round; + inr2 += win_round; + } + block_inr0 = block_inr2; + block_inr1 = block_inr0 + in_len; + block_inr2 = block_inr1 + in_len; + } + if (out_type == PRECISION(kFloat)) { + write_to_output_c4_int32_1(pre_out, + reinterpret_cast(dout_batch), + hout_c_block, + 1, + c, + c + hout_c_block, + h, + h + h_kernel, + 0, + wout_round, + chout, + hout, + wout, + flag_relu, + reinterpret_cast(ptr_write), + &scale[c], + out_type); + } else if (out_type == PRECISION(kInt8)) { + write_to_output_c4_int32_1(pre_out, + dout_batch, + hout_c_block, + 1, + c, + c + hout_c_block, + h, + h + h_kernel, + 0, + wout_round, + chout, + hout, + wout, + flag_relu, + reinterpret_cast(ptr_write), + &scale[c], + out_type); + } else { + write_to_output_c4_int32(pre_out, + reinterpret_cast(dout_batch), + hout_c_block, + 1, + c, + c + hout_c_block, + h, + h + h_kernel, + 0, + wout_round, + chout, + hout, + wout, + flag_relu, + ptr_write); + } + } + } + } +} +#endif // __aarch64__ + +} // namespace math +} // namespace arm +} // namespace lite +} // namespace paddle diff --git a/lite/arm/math/conv_block_utils.h b/lite/arm/math/conv_block_utils.h new file mode 100644 index 00000000000..8d8a3907125 --- /dev/null +++ b/lite/arm/math/conv_block_utils.h @@ -0,0 +1,4292 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include +#include +#include "lite/arm/math/saturate.h" +#include "lite/arm/math/type_trans.h" +#include "lite/core/target_wrapper.h" +#include "lite/utils/cp_logging.h" + +namespace paddle { +namespace lite { +namespace arm { +namespace math { + +#define LITEMAX(a, b) ((a) > (b) ? 
(a) : (b)) + +inline void fill_packed_biasc4(float* dout, const float* bias, int size) { + float32x4_t vb = vld1q_f32(bias); + int cnt = size / 4; + for (int i = 0; i < cnt; ++i) { + vst1q_f32(dout, vb); + dout += 4; + } +} + +/*preprocessing weights +* input weights: [chout, chin/ group, kh, kw] --> outputs weights: [chout / n, +* chin/ group, kh, kw, n] +*/ +template +static bool conv_trans_weights_numc(const dtype* din, + dtype* dout, + int chout, + int chin, + int n, + int kernel_size) { + if (n <= 0) { + LOG(ERROR) << "ch_n and hei_n are more than zero"; + return false; + } + int c_loop = chout / n; + int chout_round = (chout + n - 1) / n; + int win_stride = chin * kernel_size; + int wout_stride = n * win_stride; + int co = 0; + for (; co < c_loop; ++co) { + dtype* dout_c = dout + co * wout_stride; + const dtype* din_array[n]; + din_array[0] = din + co * wout_stride; + for (int i = 1; i < n; i++) { + din_array[i] = din_array[i - 1] + win_stride; + } + for (int ci = 0; ci < chin; ++ci) { + for (int k = 0; k < kernel_size; ++k) { + for (int i = 0; i < n; i++) { + *(dout_c++) = *(din_array[i]++); + } + } + } + } + // pad final chout + if (chout_round > c_loop) { + dtype* dout_c = dout + c_loop * wout_stride; + const dtype* din_array[n]; + din_array[0] = din + c_loop * wout_stride; + for (int i = 1; i < n; i++) { + din_array[i] = din_array[i - 1] + win_stride; + } + // deal remain + int cremain = chout_round * n - chout; + for (int i = 1; i <= cremain; i++) { + din_array[n - i] = din_array[0]; + } + for (int ci = 0; ci < chin; ++ci) { + for (int k = 0; k < kernel_size; ++k) { + for (int i = 0; i < n; i++) { + *(dout_c++) = *(din_array[i]++); + } + } + } + } + return true; +} +/*preprocessing inputs +* input din: [1, chin, he-hs, we - ws] --> outputs dout: [n, chin, 1, we - ws] +* n = he - hs +*/ +template +static bool prepack_input_nxw(const dtype* din, + dtype* dout, + int cs, + int ce, + int hs, + int he, + int ws, + int we, + int channel, + int width, + int height, + dtype* zero_ptr) { + int n = he - hs; + if (n <= 0) { + LOG(ERROR) << "hei_n is more than zero"; + return false; + } + int w0 = ws < 0 ? 0 : ws; + int w1 = we > width ? width : we; + + int size_w = we - ws; + int size_wc_len = size_w * channel; + int size_c = width * height; + + int valid_w = w1 - w0; + size_t valid_w_byte = valid_w * sizeof(dtype); + + dtype* out_array[n]; + out_array[0] = dout; + for (int i = 1; i < n; i++) { + out_array[i] = out_array[i - 1] + size_wc_len; + } + + for (int c = 0; c < channel; ++c) { + int j = 0; + // valid height + for (int i = hs; i < he; i++) { + // get address + const dtype* in_array; + if (i < 0 || i >= height) { + in_array = zero_ptr; + } else { + in_array = din + i * width; + } + + for (int w = ws; w < w0; ++w) { + *(out_array[j]++) = 0.f; + } + memcpy(out_array[j], in_array, valid_w_byte); + out_array[j] += valid_w; + for (int w = w1; w < we; ++w) { + *(out_array[j]++) = 0.f; + } + j++; + } + din += size_c; + } + return true; +} + +/*wirte result in outputs +* input din: [n, c, h, w], output dout: [n, c, h, w] +*/ +inline bool write_to_output_c1_fp32(const float* din, + float* dout, + int cs, + int ce, + int hs, + int he, + int ws, + int we, + int channel, + int height, + int width, + bool flag_relu, + float* trash_ptr) { + if (cs > channel) { + return true; + } + + const int c1 = 1; + const int w4 = 4; + + int size_c_out = width * height; + + float* doutc0r0 = dout + cs * size_c_out + hs * width + ws; + + const float* ptr_din = din; + + int size_h = (he > height ? 
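+
+// conv_trans_weights_numc above interleaves [chout, chin, kh, kw] weights
+// into [chout / n, chin, kh * kw, n], so a c-block kernel can load n output
+// channels per tap with a single vector load. The same transform restated
+// as a pure index mapping (name local to this sketch; the tail-block
+// padding the original performs when chout % n != 0 is omitted here):
+template <typename dtype>
+static void trans_weights_numc_ref(const dtype* src, dtype* dst,
+                                   int chout, int chin, int n, int ksize) {
+  for (int co = 0; co < chout; ++co) {
+    for (int ci = 0; ci < chin; ++ci) {
+      for (int k = 0; k < ksize; ++k) {
+        // [co, ci, k] -> [co / n, ci, k, co % n]
+        dst[(((co / n) * chin + ci) * ksize + k) * n + (co % n)] =
+            src[(co * chin + ci) * ksize + k];
+      }
+    }
+  }
+}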
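+
+// prepack_input_nxw above copies the window [hs, he) x [ws, we) of every
+// channel into a dense buffer, writing zeros wherever the window leaves the
+// image, so the conv kernels never branch on padding. Its per-row policy
+// reduced to a scalar sketch (name local to this sketch; the original gets
+// the same effect with a zero row plus memcpy):
+template <typename dtype>
+static void pack_row_ref(const dtype* row,  // nullptr when the source row
+                                            // lies outside [0, height)
+                         dtype* out, int ws, int we, int width) {
+  for (int w = ws; w < we; ++w) {
+    *out++ = (row != nullptr && w >= 0 && w < width) ? row[w] : dtype(0);
+  }
+}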
height : he) - hs; // size_h == hei_n + + int w_round = we - ws; + int cnt = (width - ws) / w4; + + for (int i = 0; i < size_h; i++) { + int size_w = i * width; + float* doutc0_ptr = doutc0r0 + size_w; // doutc0r0 + width; + const float* din_hei_ptr = ptr_din + i * w_round * c1; + if (cnt > 0) { + int cnt_loop = cnt; + if (flag_relu) { +#ifdef __aarch64__ + asm volatile( + "ldr q0, [%[ptr_din]], #16 \n" /* load data, c0r0, c0r1, c0r2, + c0r3 */ + "movi v20.4s, #0 \n" /* for relu */ + "1: \n" /* main loop*/ + "fmax v1.4s, v0.4s, v20.4s \n" /*relu*/ + "ldr q0, [%[ptr_din]], #16 \n" /* load data, c0r0, c0r1, c0r2, + c0r3 */ + "subs %w[cnt], %w[cnt], #1 \n" /* loop count -1*/ + "str q1, [%[doutc0r0]], #16 \n" /* store c0r0*/ + "bne 1b \n" /* jump to main loop*/ + : [doutc0r0] "+r"(doutc0_ptr), + [cnt] "+r"(cnt_loop), + [ptr_din] "+r"(din_hei_ptr) + : + : "v0", "v1", "v20"); +#else + asm volatile( + "vld1.32 {d0-d1}, [%[ptr_din]]! @ load data, c0r0, " + "c1r0, c0r1, c1r1, , c0r2, c1r2, c0r3, c1r3\n" + "vmov.u32 q15, #0 @ dump zero\n" + "1: @ main loop\n" + + "vmax.f32 q1, q0, q15 @ relu\n" + "vld1.32 {d0-d1}, [%[ptr_din]]! @ load data \n" + + "vst1.32 {d2-d3}, [%[doutc0r0]]! @ store result, add " + "pointer\n" + + "subs %[cnt], %[cnt], #1 @ loop count - 1\n" + + "bne 1b @ jump to main loop\n" + + : [doutc0r0] "+r"(doutc0_ptr), + [ptr_din] "+r"(din_hei_ptr), + [cnt] "+r"(cnt_loop) + : + : "q0", "q1", "q15"); +#endif + } else { +#ifdef __aarch64__ + asm volatile( + "ldr q0, [%[ptr_din]], #16 \n" /* load data, c0r0, c0r1, c0r2, + c0r3 */ + "1: \n" /* main loop*/ + "str q0, [%[doutc0r0]], #16 \n" /* store c2r0*/ + "subs %w[cnt], %w[cnt], #1 \n" /* loop count -1*/ + "ldr q0, [%[ptr_din]], #16 \n" /* load data, c0r0, c0r1, c0r2, + c0r3 */ + "bne 1b \n" /* jump to main loop*/ + + : [doutc0r0] "+r"(doutc0_ptr), + [cnt] "+r"(cnt_loop), + [ptr_din] "+r"(din_hei_ptr) + : + : "v0"); +#else + asm volatile( + "vld1.32 {d0-d1}, [%[ptr_din]]! @ load data, c0r0, " + "c0r1, c0r2, c0r3\n" + "1: @ main loop\n" + "vst1.32 {d0-d1}, [%[doutc0r0]]! @ store result, add " + "pointer\n" + "subs %[cnt], %[cnt], #1 @ loop count - 1\n" + "vld1.32 {d0-d1}, [%[ptr_din]]! @ load data \n" + "bne 1b @ jump to main loop\n" + + : [doutc0r0] "+r"(doutc0_ptr), + [ptr_din] "+r"(din_hei_ptr), + [cnt] "+r"(cnt_loop) + : + : "q0"); +#endif + } + } + if (we > width) { + int offset = i * w_round * c1 + c1 * w4 * cnt; + din_hei_ptr = ptr_din + offset; + int j = we - w4; + if (flag_relu) { + for (; j < width; ++j) { + *(doutc0_ptr++) = LITEMAX(din_hei_ptr[0], 0.f); + din_hei_ptr++; + } + } else { + for (; j < width; ++j) { + *(doutc0_ptr++) = *(din_hei_ptr++); + } + } + } + } + return true; +} + +/*wirte result in outputs +* input din: [n, c / 4, h, w * 4], output dout: [n, c, h, w] +*/ +inline bool write_to_output_c2_fp32(const float* din, + float* dout, + int cs, + int ce, + int hs, + int he, + int ws, + int we, + int channel, + int height, + int width, + bool flag_relu, + float* trash_ptr) { + if (cs > channel) { + return true; + } + + const int c2 = 2; + const int w4 = 4; + + // float trash_ptr[width]; + + int size_c_out = width * height; + + float* doutc0r0 = dout + cs * size_c_out + hs * width + ws; + float* doutc1r0 = doutc0r0 + size_c_out; + + const float* ptr_din = din; + + int size_h = (he > height ? 
height : he) - hs; // size_h == hei_n + + int w_round = we - ws; + int cnt = (width - ws) / w4; + + for (int i = 0; i < size_h; i++) { + int size_w = i * width; + float* doutc0_ptr = doutc0r0 + size_w; // doutc0r0 + width; + float* doutc1_ptr = doutc1r0 + size_w; + if (ce > channel) { + switch (ce - channel) { + case 1: + doutc1_ptr = trash_ptr; + default: + break; + } + } + const float* din_hei_ptr = ptr_din + i * w_round * c2; + if (cnt > 0) { + int cnt_loop = cnt; + if (flag_relu) { +#ifdef __aarch64__ + asm volatile( + "ldp q0, q1, [%[ptr_din]], #32 \n" /* load data, c0r0, c1r0, c0r1, + c1r1, , c0r2, c1r2, c0r3, + c1r3 */ + "movi v20.4s, #0 \n" /* for relu */ + "1: \n" /* main loop*/ + "trn1 v2.4s, v0.4s, v1.4s \n" /* trans q0, q1*/ + "trn2 v3.4s, v0.4s, v1.4s \n" /* trans q0, q1*/ + "ldp q0, q1, [%[ptr_din]], #32 \n" /* load data, c0r0, c1r0, c0r1, + c1r1, , c0r2, c1r2, c0r3, + c1r3 */ + "trn1 v4.2d, v2.2d, v3.2d \n" /* trans q8, q10*/ + "trn2 v5.2d, v2.2d, v3.2d \n" /* trans q8, q10*/ + + "fmax v2.4s, v4.4s, v20.4s \n" /*relu*/ + "fmax v3.4s, v5.4s, v20.4s \n" /*relu*/ + + "subs %w[cnt], %w[cnt], #1 \n" /* loop count -1*/ + + "str q2, [%[doutc0r0]], #16 \n" /* store c0r0*/ + "str q3, [%[doutc1r0]], #16 \n" /* store c2r0*/ + + "bne 1b \n" /* jump to main loop*/ + + : [doutc0r0] "+r"(doutc0_ptr), + [doutc1r0] "+r"(doutc1_ptr), + [cnt] "+r"(cnt_loop), + [ptr_din] "+r"(din_hei_ptr) + : + : "v0", "v1", "v2", "v3", "v4", "v5", "v20"); +#else + asm volatile( + "vld1.32 {d0-d3}, [%[ptr_din]]! @ load data, c0r0, " + "c1r0, c0r1, c1r1, , c0r2, c1r2, c0r3, c1r3\n" + "vmov.u32 q15, #0 @ dump zero\n" + "1: @ main loop\n" + "vtrn.32 d0, d1 @ trans data:c0r0, c0r1, " + "c1r0, c1r1 \n" + "vtrn.32 d2, d3 @ trans data:c0r2, c0r3, " + "c1r2, c1r3 \n" + + "vswp d1, d2 @ swap data\n" + + "vmax.f32 q0, q0, q15 @ relu\n" + "vmax.f32 q1, q1, q15 @ relu\n" + + "vst1.32 {d0-d1}, [%[doutc0r0]]! @ store result, add " + "pointer\n" + "vst1.32 {d2-d3}, [%[doutc1r0]]! @ store result, add " + "pointer\n" + + "subs %[cnt], %[cnt], #1 @ loop count - 1\n" + + "vld1.32 {d0-d3}, [%[ptr_din]]! @ load data \n" + + "bne 1b @ jump to main loop\n" + + : [doutc0r0] "+r"(doutc0_ptr), + [doutc1r0] "+r"(doutc1_ptr), + [ptr_din] "+r"(din_hei_ptr), + [cnt] "+r"(cnt_loop) + : + : "q0", "q1", "q2", "q3", "q15"); +#endif + } else { +#ifdef __aarch64__ + asm volatile( + "ldp q0, q1, [%[ptr_din]], #32 \n" /* load data, c0r0, c1r0, c0r1, + c1r1, , c0r2, c1r2, c0r3, + c1r3 */ + "1: \n" /* main loop*/ + "trn1 v2.4s, v0.4s, v1.4s \n" /* trans q0, q1*/ + "trn2 v3.4s, v0.4s, v1.4s \n" /* trans q0, q1*/ + "ldp q0, q1, [%[ptr_din]], #32 \n" /* load data, c0r0, c1r0, c0r1, + c1r1, , c0r2, c1r2, c0r3, + c1r3 */ + "trn1 v4.2d, v2.2d, v3.2d \n" /* trans q8, q10*/ + "trn2 v5.2d, v2.2d, v3.2d \n" /* trans q8, q10*/ + + "subs %w[cnt], %w[cnt], #1 \n" /* loop count -1*/ + + "str q4, [%[doutc0r0]], #16 \n" /* store c0r0*/ + "str q5, [%[doutc1r0]], #16 \n" /* store c2r0*/ + + "bne 1b \n" /* jump to main loop*/ + + : [doutc0r0] "+r"(doutc0_ptr), + [doutc1r0] "+r"(doutc1_ptr), + [cnt] "+r"(cnt_loop), + [ptr_din] "+r"(din_hei_ptr) + : + : "v0", "v1", "v2", "v3", "v4", "v5"); +#else + asm volatile( + "vld1.32 {d0-d3}, [%[ptr_din]]! @ load data, c0r0, " + "c1r0, c0r1, c1r1, , c0r2, c1r2, c0r3, c1r3\n" + "1: @ main loop\n" + "vtrn.32 d0, d1 @ trans data:c0r0, c0r1, " + "c1r0, c1r1 \n" + "vtrn.32 d2, d3 @ trans data:c0r2, c0r3, " + "c1r2, c1r3 \n" + + "vswp d1, d2 @ swap data\n" + + "vst1.32 {d0-d1}, [%[doutc0r0]]! 
@ store result, add " + "pointer\n" + "vst1.32 {d2-d3}, [%[doutc1r0]]! @ store result, add " + "pointer\n" + + "subs %[cnt], %[cnt], #1 @ loop count - 1\n" + + "vld1.32 {d0-d3}, [%[ptr_din]]! @ load data \n" + + "bne 1b @ jump to main loop\n" + + : [doutc0r0] "+r"(doutc0_ptr), + [doutc1r0] "+r"(doutc1_ptr), + [ptr_din] "+r"(din_hei_ptr), + [cnt] "+r"(cnt_loop) + : + : "q0", "q1", "q2", "q3", "q15"); +#endif + } + } + if (we > width) { + int offset = i * w_round * c2 + c2 * w4 * cnt; + din_hei_ptr = ptr_din + offset; + int j = we - w4; + if (flag_relu) { + for (; j < width; ++j) { + *(doutc0_ptr++) = LITEMAX(din_hei_ptr[0], 0.f); + *(doutc1_ptr++) = LITEMAX(din_hei_ptr[1], 0.f); + din_hei_ptr += 2; + } + } else { + for (; j < width; ++j) { + *(doutc0_ptr++) = *(din_hei_ptr++); + *(doutc1_ptr++) = *(din_hei_ptr++); + } + } + } + } + return true; +} + +/*wirte result in outputs +* input din: [n, c / 4, h, w * 4], output dout: [n, c, h, w] +*/ +inline bool write_to_output_c4_fp32(const float* din, + float* dout, + int cs, + int ce, + int hs, + int he, + int ws, + int we, + int channel, + int height, + int width, + bool flag_relu, + float* trash_ptr) { + const int c4 = 4; + const int w4 = 4; + const int w_round = we - ws; + const int ch_n = ce - cs; + if (ch_n != 4) { + LOG(ERROR) << "write_to_output_c4_fp32 ch_n must be equal 4 and hei_n is " + "more than zero"; + return false; + } + int size_c_out = width * height; + + float* doutc0r0 = dout + cs * size_c_out + hs * width + ws; + float* doutc1r0 = doutc0r0 + size_c_out; + float* doutc2r0 = doutc1r0 + size_c_out; + float* doutc3r0 = doutc2r0 + size_c_out; + + const float* ptr_din = din; + + int size_h = (he > height ? height : he) - hs; // size_h == hei_n + + int cnt = (width - ws) / w4; + + for (int i = 0; i < size_h; i++) { + int size_w = i * width; + float* doutc0_ptr = doutc0r0 + size_w; // doutc0r0 + width; + float* doutc1_ptr = doutc1r0 + size_w; + float* doutc2_ptr = doutc2r0 + size_w; + float* doutc3_ptr = doutc3r0 + size_w; + if (ce > channel) { + switch (ce - channel) { + case 3: + doutc1_ptr = trash_ptr; + case 2: + doutc2_ptr = trash_ptr; + case 1: + doutc3_ptr = trash_ptr; + default: + break; + } + } + const float* din_hei_ptr = ptr_din + i * w_round * ch_n; + if (cnt > 0) { + int cnt_loop = cnt; + if (flag_relu) { +#ifdef __aarch64__ + asm volatile( + "ldp q0, q1, [%[ptr_din]], #32 \n" /* load r00, r01 to q0, q1 */ + "ldp q2, q3, [%[ptr_din]], #32 \n" /* load r02, r03 to q2, q3 */ + "movi v20.4s, #0 \n" /* for relu */ + "1: \n" /* main loop*/ + "trn1 v8.4s, v0.4s, v1.4s \n" /* trans q0, q1*/ + "trn2 v9.4s, v0.4s, v1.4s \n" /* trans q0, q1*/ + "ldp q0, q1, [%[ptr_din]], #32 \n" /* load r00, r01 to q0, q1 */ + "trn1 v10.4s, v2.4s, v3.4s \n" /* trans q2, q3*/ + "trn2 v11.4s, v2.4s, v3.4s \n" /* trans q2, q3*/ + "ldp q2, q3, [%[ptr_din]], #32 \n" /* load r02, r03 to q2, q3 */ + "trn1 v16.2d, v8.2d, v10.2d \n" /* trans q8, q10*/ + "trn2 v17.2d, v8.2d, v10.2d \n" /* trans q8, q10*/ + "trn1 v18.2d, v9.2d, v11.2d \n" /* trans q9, q11*/ + "trn2 v19.2d, v9.2d, v11.2d \n" /* trans q9, q11*/ + "fmax v16.4s, v16.4s, v20.4s \n" /*relu*/ + "fmax v17.4s, v17.4s, v20.4s \n" /*relu*/ + "fmax v18.4s, v18.4s, v20.4s \n" /*relu*/ + "fmax v19.4s, v19.4s, v20.4s \n" /*relu*/ + "str q16, [%[doutc0r0]], #16 \n" /* store c0r0*/ + "str q17, [%[doutc2r0]], #16 \n" /* store c2r0*/ + "str q18, [%[doutc1r0]], #16 \n" /* store c1r0*/ + "str q19, [%[doutc3r0]], #16 \n" /* store c3r0*/ + + "subs %w[cnt], %w[cnt], #1 \n" /* loop count -1*/ + "bne 1b \n" /* jump to 
main loop*/ + + : [doutc0r0] "+r"(doutc0_ptr), + [doutc1r0] "+r"(doutc1_ptr), + [doutc2r0] "+r"(doutc2_ptr), + [doutc3r0] "+r"(doutc3_ptr), + [cnt] "+r"(cnt_loop), + [ptr_din] "+r"(din_hei_ptr) + : + : "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v16", + "v17", + "v18", + "v19", + "v20"); +#else + asm volatile( + "vld1.32 {d0-d3}, [%[ptr_din]]! @load data \n" + "vld1.32 {d4-d7}, [%[ptr_din]]! @load data \n" + "vmov.u32 q15, #0 @ dump zero\n" + "1: @ main loop\n" + "vtrn.32 q0, q1 @ trans data:c00c01c20c21 " + "\n" + "vtrn.32 q2, q3 @ trans data:c02c03c22c23 " + "\n" + + "vswp d1, d4 @ swap data\n" + "vswp d3, d6 @ swap data\n" + + "vmax.f32 q0, q0, q15 @ relu\n" + "vmax.f32 q1, q1, q15 @ relu\n" + "vmax.f32 q2, q2, q15 @ relu\n" + "vmax.f32 q3, q3, q15 @ relu\n" + + "vst1.32 {d0-d1}, [%[doutc0r0]]! @ store result, add pointer\n" + "vst1.32 {d2-d3}, [%[doutc1r0]]! @ store result, add pointer\n" + "vst1.32 {d4-d5}, [%[doutc2r0]]! @ store result, add pointer\n" + "vst1.32 {d6-d7}, [%[doutc3r0]]! @ store result, add pointer\n" + + "subs %[cnt], %[cnt], #1 @ loop count - 1\n" + + "vld1.32 {d0-d3}, [%[ptr_din]]! @load data \n" + "vld1.32 {d4-d7}, [%[ptr_din]]! @load data \n" + + "bne 1b @ jump to main loop\n" + + : [doutc0r0] "+r"(doutc0_ptr), + [doutc1r0] "+r"(doutc1_ptr), + [doutc2r0] "+r"(doutc2_ptr), + [doutc3r0] "+r"(doutc3_ptr), + [ptr_din] "+r"(din_hei_ptr), + [cnt] "+r"(cnt_loop) + : + : "q0", "q1", "q2", "q3", "q15"); +#endif + } else { +#ifdef __aarch64__ + asm volatile( + "ldp q0, q1, [%[ptr_din]], #32 \n" /* load r00, r01 to q0, q1 */ + "ldp q2, q3, [%[ptr_din]], #32 \n" /* load r02, r03 to q2, q3 */ + "1: \n" /* main loop*/ + "trn1 v8.4s, v0.4s, v1.4s \n" /* trans q0, q1*/ + "trn2 v9.4s, v0.4s, v1.4s \n" /* trans q0, q1*/ + "ldp q0, q1, [%[ptr_din]], #32 \n" /* load r00, r01 to q0, q1 */ + "trn1 v10.4s, v2.4s, v3.4s \n" /* trans q2, q3*/ + "trn2 v11.4s, v2.4s, v3.4s \n" /* trans q2, q3*/ + "ldp q2, q3, [%[ptr_din]], #32 \n" /* load r02, r03 to q2, q3 */ + "trn1 v16.2d, v8.2d, v10.2d \n" /* trans q8, q10*/ + "trn2 v17.2d, v8.2d, v10.2d \n" /* trans q8, q10*/ + "trn1 v18.2d, v9.2d, v11.2d \n" /* trans q9, q11*/ + "trn2 v19.2d, v9.2d, v11.2d \n" /* trans q9, q11*/ + "str q16, [%[doutc0r0]], #16 \n" /* store c0r0*/ + "str q17, [%[doutc2r0]], #16 \n" /* store c2r0*/ + "str q18, [%[doutc1r0]], #16 \n" /* store c1r0*/ + "str q19, [%[doutc3r0]], #16 \n" /* store c3r0*/ + + "subs %w[cnt], %w[cnt], #1 \n" /* loop count -1*/ + "bne 1b \n" /* jump to main loop*/ + + : [doutc0r0] "+r"(doutc0_ptr), + [doutc1r0] "+r"(doutc1_ptr), + [doutc2r0] "+r"(doutc2_ptr), + [doutc3r0] "+r"(doutc3_ptr), + [cnt] "+r"(cnt_loop), + [ptr_din] "+r"(din_hei_ptr) + : + : "v0", + "v1", + "v2", + "v3", + "v8", + "v9", + "v10", + "v11", + "v16", + "v17", + "v18", + "v19"); +#else + asm volatile( + "vld1.32 {d0-d3}, [%[ptr_din]]! @load data \n" + "vld1.32 {d4-d7}, [%[ptr_din]]! @load data \n" + "1: @ main loop\n" + "vtrn.32 q0, q1 @ trans data:c00c01c20c21 " + "\n" + "vtrn.32 q2, q3 @ trans data:c02c03c22c23 " + "\n" + + "vswp d1, d4 @ swap data\n" + "vswp d3, d6 @ swap data\n" + + "vst1.32 {d0-d1}, [%[doutc0r0]]! @ store result, add pointer\n" + "vst1.32 {d2-d3}, [%[doutc1r0]]! @ store result, add pointer\n" + "vst1.32 {d4-d5}, [%[doutc2r0]]! @ store result, add pointer\n" + "vst1.32 {d6-d7}, [%[doutc3r0]]! @ store result, add pointer\n" + + "subs %[cnt], %[cnt], #1 @ loop count - 1\n" + + "vld1.32 {d0-d3}, [%[ptr_din]]! 
@load data \n" + "vld1.32 {d4-d7}, [%[ptr_din]]! @load data \n" + + "bne 1b @ jump to main loop\n" + + : [doutc0r0] "+r"(doutc0_ptr), + [doutc1r0] "+r"(doutc1_ptr), + [doutc2r0] "+r"(doutc2_ptr), + [doutc3r0] "+r"(doutc3_ptr), + [ptr_din] "+r"(din_hei_ptr), + [cnt] "+r"(cnt_loop) + : + : "q0", "q1", "q2", "q3"); +#endif + } + } + if (we > width) { + int offset = i * w_round * c4 + c4 * w4 * cnt; + din_hei_ptr = ptr_din + offset; + int j = we - w4; + if (flag_relu) { + for (; j < width; ++j) { + *(doutc0_ptr++) = LITEMAX(din_hei_ptr[0], 0.f); + *(doutc1_ptr++) = LITEMAX(din_hei_ptr[1], 0.f); + *(doutc2_ptr++) = LITEMAX(din_hei_ptr[2], 0.f); + *(doutc3_ptr++) = LITEMAX(din_hei_ptr[3], 0.f); + din_hei_ptr += w4; + } + } else { + for (; j < width; ++j) { + *(doutc0_ptr++) = din_hei_ptr[0]; + *(doutc1_ptr++) = din_hei_ptr[1]; + *(doutc2_ptr++) = din_hei_ptr[2]; + *(doutc3_ptr++) = din_hei_ptr[3]; + din_hei_ptr += w4; + } + } + } + } + return true; +} + +/*wirte result in outputs +* input din: [n, c / 8, h, w * 8], output dout: [n, c, h, w] +*/ +inline bool write_to_output_c8_fp32(const float* din, + float* dout, + int ch_n, + int hei_n, + int cs, + int ce, + int hs, + int he, + int ws, + int we, + int channel, + int height, + int width, + bool flag_relu, + float* trash_ptr) { + if (ch_n != 8 || hei_n <= 0) { + LOG(ERROR) << "ch_n must be equal 8 and hei_n is more than zero"; + return false; + } + int size_c_out = width * height; + + float* doutc0r0 = dout + cs * size_c_out + hs * width + ws; + float* doutc1r0 = doutc0r0 + size_c_out; + float* doutc2r0 = doutc1r0 + size_c_out; + float* doutc3r0 = doutc2r0 + size_c_out; + float* doutc4r0 = doutc3r0 + size_c_out; + float* doutc5r0 = doutc4r0 + size_c_out; + float* doutc6r0 = doutc5r0 + size_c_out; + float* doutc7r0 = doutc6r0 + size_c_out; + + const float* ptr_din = din; + + int size_h = (he > height ? 
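+
+// The trn1/trn2 pairs in write_to_output_c4_fp32 above form a 4x4
+// transpose: four vectors of channel-interleaved data {c0,c1,c2,c3} for
+// consecutive x positions become one vector per channel. The same dance
+// with ACLE intrinsics, a sketch for aarch64 only (out[0..3] end up as
+// channels 0..3, matching the asm's store order c0, c2, c1, c3):
+#ifdef __aarch64__
+#include <arm_neon.h>
+static void transpose4x4_f32_ref(const float32x4_t in[4],
+                                 float32x4_t out[4]) {
+  float32x4_t t0 = vtrn1q_f32(in[0], in[1]);  // trn1 v8.4s
+  float32x4_t t1 = vtrn2q_f32(in[0], in[1]);  // trn2 v9.4s
+  float32x4_t t2 = vtrn1q_f32(in[2], in[3]);  // trn1 v10.4s
+  float32x4_t t3 = vtrn2q_f32(in[2], in[3]);  // trn2 v11.4s
+  float64x2_t d0 = vreinterpretq_f64_f32(t0);
+  float64x2_t d1 = vreinterpretq_f64_f32(t1);
+  float64x2_t d2 = vreinterpretq_f64_f32(t2);
+  float64x2_t d3 = vreinterpretq_f64_f32(t3);
+  out[0] = vreinterpretq_f32_f64(vtrn1q_f64(d0, d2));  // channel 0 row
+  out[2] = vreinterpretq_f32_f64(vtrn2q_f64(d0, d2));  // channel 2 row
+  out[1] = vreinterpretq_f32_f64(vtrn1q_f64(d1, d3));  // channel 1 row
+  out[3] = vreinterpretq_f32_f64(vtrn2q_f64(d1, d3));  // channel 3 row
+}
+#endif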
height : he) - hs; // size_h == hei_n + + int valid_w = we - ws; + int cnt = valid_w / 4; + + if (we > width) { + cnt--; + } + if (flag_relu) { + for (int i = 0; i < size_h; i++) { + int size_w = i * width; + float* doutc0_ptr = doutc0r0 + size_w; // doutc0r0 + width; + float* doutc1_ptr = doutc1r0 + size_w; + float* doutc2_ptr = doutc2r0 + size_w; + float* doutc3_ptr = doutc3r0 + size_w; + float* doutc4_ptr = doutc4r0 + size_w; + float* doutc5_ptr = doutc5r0 + size_w; + float* doutc6_ptr = doutc6r0 + size_w; + float* doutc7_ptr = doutc7r0 + size_w; + if (ce > channel) { + switch (ce - channel) { + case 7: + doutc1_ptr = trash_ptr; + case 6: + doutc2_ptr = trash_ptr; + case 5: + doutc3_ptr = trash_ptr; + case 4: + doutc4_ptr = trash_ptr; + case 3: + doutc5_ptr = trash_ptr; + case 2: + doutc6_ptr = trash_ptr; + case 1: + doutc7_ptr = trash_ptr; + default: + break; + } + } + ptr_din = din + i * valid_w * ch_n; + const float* din_hei_ptr = ptr_din; + if (cnt > 0) { + int cnt_loop = cnt; +#ifdef __aarch64__ + asm volatile( + "ldp q0, q1, [%[ptr_din]], #32 \n" /* load r00, r01 to q0, q1 */ + "ldp q2, q3, [%[ptr_din]], #32 \n" /* load r02, r03 to q2, q3 */ + "ldp q4, q5, [%[ptr_din]], #32 \n" /* load r00, r01 to q0, q1 */ + "ldp q6, q7, [%[ptr_din]], #32 \n" /* load r02, r03 to q2, q3 */ + "movi v20.4s, #0 \n" /* for relu */ + "1: \n" /* main loop*/ + "trn1 v8.4s, v0.4s, v2.4s \n" /* trans q0, q1*/ + "trn2 v9.4s, v0.4s, v2.4s \n" /* trans q0, q1*/ + "trn1 v10.4s, v1.4s, v3.4s \n" /* trans q2, q3*/ + "trn2 v11.4s, v1.4s, v3.4s \n" /* trans q2, q3*/ + "ldp q0, q1, [%[ptr_din]], #32 \n" /* load r00, r01 to q0, q1 */ + + "trn1 v12.4s, v4.4s, v6.4s \n" /* trans q0, q1*/ + "trn2 v13.4s, v4.4s, v6.4s \n" /* trans q0, q1*/ + "trn1 v14.4s, v5.4s, v7.4s \n" /* trans q2, q3*/ + "trn2 v15.4s, v5.4s, v7.4s \n" /* trans q2, q3*/ + "ldp q2, q3, [%[ptr_din]], #32 \n" /* load r02, r03 to q2, q3 */ + + "trn1 v16.2d, v8.2d, v12.2d \n" /* trans q8, q10 00 01 02 03*/ + "trn2 v17.2d, v8.2d, v12.2d \n" /* trans q8, q10 20 21 22 23*/ + "trn1 v18.2d, v9.2d, v13.2d \n" /* trans q9, q11 10 11 12 13*/ + "trn2 v19.2d, v9.2d, v13.2d \n" /* trans q9, q11 30 31 32 33*/ + "ldp q4, q5, [%[ptr_din]], #32 \n" /* load r00, r01 to q0, q1 */ + + "trn1 v8.2d, v10.2d, v14.2d \n" /* trans q8, q10 40 41 42 43*/ + "trn2 v9.2d, v10.2d, v14.2d \n" /* trans q8, q10 60 61 62 63*/ + "trn1 v12.2d, v11.2d, v15.2d \n" /* trans q9, q11 50 51 52 53*/ + "trn2 v13.2d, v11.2d, v15.2d \n" /* trans q9, q11 70 71 72 73*/ + "ldp q6, q7, [%[ptr_din]], #32 \n" /* load r02, r03 to q2, q3 */ + + "fmax v16.4s, v16.4s, v20.4s \n" /*relu*/ + "fmax v17.4s, v17.4s, v20.4s \n" /*relu*/ + "fmax v18.4s, v18.4s, v20.4s \n" /*relu*/ + "fmax v19.4s, v19.4s, v20.4s \n" /*relu*/ + + "fmax v8.4s, v8.4s, v20.4s \n" /*relu*/ + "fmax v9.4s, v9.4s, v20.4s \n" /*relu*/ + "fmax v12.4s, v12.4s, v20.4s \n" /*relu*/ + "fmax v13.4s, v13.4s, v20.4s \n" /*relu*/ + + "str q16, [%[doutc0r0]], #16 \n" /* store c0r0*/ + "str q17, [%[doutc2r0]], #16 \n" /* store c2r0*/ + "str q18, [%[doutc1r0]], #16 \n" /* store c1r0*/ + "str q19, [%[doutc3r0]], #16 \n" /* store c3r0*/ + + "subs %w[cnt], %w[cnt], #1 \n" /* loop count -1*/ + "str q8, [%[doutc4r0]], #16 \n" /* store c0r0*/ + "str q9, [%[doutc6r0]], #16 \n" /* store c2r0*/ + "str q12, [%[doutc5r0]], #16 \n" /* store c1r0*/ + "str q13, [%[doutc7r0]], #16 \n" /* store c3r0*/ + + "bne 1b \n" /* jump to main loop*/ + + : [doutc0r0] "+r"(doutc0_ptr), + [doutc1r0] "+r"(doutc1_ptr), + [doutc2r0] "+r"(doutc2_ptr), + [doutc3r0] "+r"(doutc3_ptr), + 
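+
+// The switch above (and its c2/c4 counterparts) relies on deliberate
+// fallthrough: when the channel block runs past `channel` (ce > channel),
+// the last (ce - channel) row pointers are redirected to a scratch buffer,
+// so the vector stores can always write a full block without a tail branch.
+// The same logic as a loop (names local to this sketch):
+template <typename dtype>
+static void redirect_overflow_channels(dtype* dout_ptr[], int ch_n, int ce,
+                                       int channel, dtype* trash_ptr) {
+  for (int overflow = ce - channel, i = 0; i < overflow; ++i) {
+    dout_ptr[ch_n - 1 - i] = trash_ptr;  // stores for channels >= `channel`
+                                         // land in trash_ptr and are dropped
+  }
+}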
[doutc4r0] "+r"(doutc4_ptr), + [doutc5r0] "+r"(doutc5_ptr), + [doutc6r0] "+r"(doutc6_ptr), + [doutc7r0] "+r"(doutc7_ptr), + [cnt] "+r"(cnt_loop), + [ptr_din] "+r"(din_hei_ptr) + : + : "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15", + "v16", + "v17", + "v18", + "v19", + "v20"); +#else + asm volatile( + "vld1.32 {d0-d3}, [%[ptr_din]]! @load data \n" + "vld1.32 {d4-d7}, [%[ptr_din]]! @load data \n" + "vld1.32 {d8-d11}, [%[ptr_din]]! @load data \n" + "vld1.32 {d12-d15}, [%[ptr_din]]! @load data \n" + "vmov.u32 q15, #0 @ dump zero\n" + "1: @ main loop\n" + "vtrn.32 q0, q2 @ trans q0, q2 \n" + "vtrn.32 q4, q6 @ trans q4, q6 \n" + "vswp.32 d1, d8 @ swap d1, d8 \n" + "vswp.32 d5, d12 @ swap d5, d12\n" + + "vtrn.32 q1, q3 @ trans q1, q3 \n" + "vtrn.32 q5, q7 @ trans q5, q7 \n" + "vswp.32 d3, d10 @ swap d3, d10\n" + "vswp.32 d7, d14 @ swap d7, d14\n" + + "vmax.f32 q0, q0, q15 @ relu\n" + "vmax.f32 q1, q1, q15 @ relu\n" + "vmax.f32 q2, q2, q15 @ relu\n" + "vmax.f32 q3, q3, q15 @ relu\n" + + "vmax.f32 q4, q4, q15 @ relu\n" + "vmax.f32 q5, q5, q15 @ relu\n" + "vmax.f32 q6, q6, q15 @ relu\n" + "vmax.f32 q7, q7, q15 @ relu\n" + + "subs %[cnt], %[cnt], #1 @ loop count - 1\n" + "vst1.32 {d0-d1}, [%[doutc0r0]]! @ store result, add " + "pointer\n" + "vst1.32 {d2-d3}, [%[doutc4r0]]! @ store result, add " + "pointer\n" + "vst1.32 {d4-d5}, [%[doutc1r0]]! @ store result, add " + "pointer\n" + "vst1.32 {d6-d7}, [%[doutc5r0]]! @ store result, add " + "pointer\n" + + "vld1.32 {d0-d3}, [%[ptr_din]]! @load data \n" + "vld1.32 {d4-d7}, [%[ptr_din]]! @load data \n" + + "vst1.32 {d8-d9}, [%[doutc2r0]]! @ store result, add " + "pointer\n" + "vst1.32 {d10-d11}, [%[doutc6r0]]! @ store result, add " + "pointer\n" + "vst1.32 {d12-d13}, [%[doutc3r0]]! @ store result, add " + "pointer\n" + "vst1.32 {d14-d15}, [%[doutc7r0]]! @ store result, add " + "pointer\n" + + "vld1.32 {d8-d11}, [%[ptr_din]]! @load data \n" + "vld1.32 {d12-d15}, [%[ptr_din]]! 
@load data \n" + + "bne 1b @ jump to main loop\n" + + : [doutc0r0] "+r"(doutc0_ptr), + [doutc1r0] "+r"(doutc1_ptr), + [doutc2r0] "+r"(doutc2_ptr), + [doutc3r0] "+r"(doutc3_ptr), + [doutc4r0] "+r"(doutc4_ptr), + [doutc5r0] "+r"(doutc5_ptr), + [doutc6r0] "+r"(doutc6_ptr), + [doutc7r0] "+r"(doutc7_ptr), + [ptr_din] "+r"(din_hei_ptr), + [cnt] "+r"(cnt_loop) + : + : "q0", "q1", "q2", "q3", "q4", "q15"); +#endif + } + if (we > width) { + int offset = 32 * (valid_w / 4 - 1); + din_hei_ptr = ptr_din + offset; + int i = we - 4; + for (; i < width; ++i) { + *(doutc0_ptr++) = LITEMAX(din_hei_ptr[0], 0.f); + *(doutc1_ptr++) = LITEMAX(din_hei_ptr[1], 0.f); + *(doutc2_ptr++) = LITEMAX(din_hei_ptr[2], 0.f); + *(doutc3_ptr++) = LITEMAX(din_hei_ptr[3], 0.f); + *(doutc4_ptr++) = LITEMAX(din_hei_ptr[4], 0.f); + *(doutc5_ptr++) = LITEMAX(din_hei_ptr[5], 0.f); + *(doutc6_ptr++) = LITEMAX(din_hei_ptr[6], 0.f); + *(doutc7_ptr++) = LITEMAX(din_hei_ptr[7], 0.f); + din_hei_ptr += 8; + } + } + } + } else { + for (int i = 0; i < size_h; i++) { + int size_w = i * width; + float* doutc0_ptr = doutc0r0 + size_w; // doutc0r0 + width; + float* doutc1_ptr = doutc1r0 + size_w; + float* doutc2_ptr = doutc2r0 + size_w; + float* doutc3_ptr = doutc3r0 + size_w; + float* doutc4_ptr = doutc4r0 + size_w; + float* doutc5_ptr = doutc5r0 + size_w; + float* doutc6_ptr = doutc6r0 + size_w; + float* doutc7_ptr = doutc7r0 + size_w; + if (ce > channel) { + switch (ce - channel) { + case 7: + doutc1_ptr = trash_ptr; + case 6: + doutc2_ptr = trash_ptr; + case 5: + doutc3_ptr = trash_ptr; + case 4: + doutc4_ptr = trash_ptr; + case 3: + doutc5_ptr = trash_ptr; + case 2: + doutc6_ptr = trash_ptr; + case 1: + doutc7_ptr = trash_ptr; + default: + break; + } + } + ptr_din = din + i * valid_w * ch_n; + const float* din_hei_ptr = ptr_din; + if (cnt > 0) { + int cnt_loop = cnt; +#ifdef __aarch64__ + asm volatile( + "ldp q0, q1, [%[ptr_din]], #32 \n" /* load r00, r01 to q0, q1 */ + "ldp q2, q3, [%[ptr_din]], #32 \n" /* load r02, r03 to q2, q3 */ + "ldp q4, q5, [%[ptr_din]], #32 \n" /* load r00, r01 to q0, q1 */ + "ldp q6, q7, [%[ptr_din]], #32 \n" /* load r02, r03 to q2, q3 */ + "1: \n" /* main loop*/ + "trn1 v8.4s, v0.4s, v2.4s \n" /* trans q0, q1*/ + "trn2 v9.4s, v0.4s, v2.4s \n" /* trans q0, q1*/ + "trn1 v10.4s, v1.4s, v3.4s \n" /* trans q2, q3*/ + "trn2 v11.4s, v1.4s, v3.4s \n" /* trans q2, q3*/ + "ldp q0, q1, [%[ptr_din]], #32 \n" /* load r00, r01 to q0, q1 */ + + "trn1 v12.4s, v4.4s, v6.4s \n" /* trans q0, q1*/ + "trn2 v13.4s, v4.4s, v6.4s \n" /* trans q0, q1*/ + "trn1 v14.4s, v5.4s, v7.4s \n" /* trans q2, q3*/ + "trn2 v15.4s, v5.4s, v7.4s \n" /* trans q2, q3*/ + "ldp q2, q3, [%[ptr_din]], #32 \n" /* load r02, r03 to q2, q3 */ + + "trn1 v16.2d, v8.2d, v12.2d \n" /* trans q8, q10 00 01 02 03*/ + "trn2 v17.2d, v8.2d, v12.2d \n" /* trans q8, q10 20 21 22 23*/ + "trn1 v18.2d, v9.2d, v13.2d \n" /* trans q9, q11 10 11 12 13*/ + "trn2 v19.2d, v9.2d, v13.2d \n" /* trans q9, q11 30 31 32 33*/ + "ldp q4, q5, [%[ptr_din]], #32 \n" /* load r00, r01 to q0, q1 */ + + "trn1 v8.2d, v10.2d, v14.2d \n" /* trans q8, q10 40 41 42 43*/ + "trn2 v9.2d, v10.2d, v14.2d \n" /* trans q8, q10 60 61 62 63*/ + "trn1 v12.2d, v11.2d, v15.2d \n" /* trans q9, q11 50 51 52 53*/ + "trn2 v13.2d, v11.2d, v15.2d \n" /* trans q9, q11 70 71 72 73*/ + "ldp q6, q7, [%[ptr_din]], #32 \n" /* load r02, r03 to q2, q3 */ + + "str q16, [%[doutc0r0]], #16 \n" /* store c0r0*/ + "str q17, [%[doutc2r0]], #16 \n" /* store c2r0*/ + "str q18, [%[doutc1r0]], #16 \n" /* store c1r0*/ + "str q19, 
[%[doutc3r0]], #16 \n" /* store c3r0*/ + + "subs %w[cnt], %w[cnt], #1 \n" /* loop count -1*/ + "str q8, [%[doutc4r0]], #16 \n" /* store c0r0*/ + "str q9, [%[doutc6r0]], #16 \n" /* store c2r0*/ + "str q12, [%[doutc5r0]], #16 \n" /* store c1r0*/ + "str q13, [%[doutc7r0]], #16 \n" /* store c3r0*/ + + "bne 1b \n" /* jump to main loop*/ + + : [doutc0r0] "+r"(doutc0_ptr), + [doutc1r0] "+r"(doutc1_ptr), + [doutc2r0] "+r"(doutc2_ptr), + [doutc3r0] "+r"(doutc3_ptr), + [doutc4r0] "+r"(doutc4_ptr), + [doutc5r0] "+r"(doutc5_ptr), + [doutc6r0] "+r"(doutc6_ptr), + [doutc7r0] "+r"(doutc7_ptr), + [cnt] "+r"(cnt_loop), + [ptr_din] "+r"(din_hei_ptr) + : + : "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15", + "v16", + "v17", + "v18", + "v19", + "v20"); +#else + asm volatile( + "vld1.32 {d0-d3}, [%[ptr_din]]! @load data \n" + "vld1.32 {d4-d7}, [%[ptr_din]]! @load data \n" + "vld1.32 {d8-d11}, [%[ptr_din]]! @load data \n" + "vld1.32 {d12-d15}, [%[ptr_din]]! @load data \n" + "1: @ main loop\n" + "vtrn.32 q0, q2 @ trans q0, q2 \n" + "vtrn.32 q4, q6 @ trans q4, q6 \n" + "vswp.32 d1, d8 @ swap d1, d8 \n" + "vswp.32 d5, d12 @ swap d5, d12\n" + + "vtrn.32 q1, q3 @ trans q1, q3 \n" + "vtrn.32 q5, q7 @ trans q5, q7 \n" + "vswp.32 d3, d10 @ swap d3, d10\n" + "vswp.32 d7, d14 @ swap d7, d14\n" + + "subs %[cnt], %[cnt], #1 @ loop count - 1\n" + + "vst1.32 {d0-d1}, [%[doutc0r0]]! @ store result, add " + "pointer\n" + "vst1.32 {d2-d3}, [%[doutc4r0]]! @ store result, add " + "pointer\n" + "vst1.32 {d4-d5}, [%[doutc1r0]]! @ store result, add " + "pointer\n" + "vst1.32 {d6-d7}, [%[doutc5r0]]! @ store result, add " + "pointer\n" + + "vld1.32 {d0-d3}, [%[ptr_din]]! @load data \n" + "vld1.32 {d4-d7}, [%[ptr_din]]! @load data \n" + + "vst1.32 {d8-d9}, [%[doutc2r0]]! @ store result, add " + "pointer\n" + "vst1.32 {d10-d11}, [%[doutc6r0]]! @ store result, add " + "pointer\n" + "vst1.32 {d12-d13}, [%[doutc3r0]]! @ store result, add " + "pointer\n" + "vst1.32 {d14-d15}, [%[doutc7r0]]! @ store result, add " + "pointer\n" + + "vld1.32 {d8-d11}, [%[ptr_din]]! @load data \n" + "vld1.32 {d12-d15}, [%[ptr_din]]! 
@load data \n" + + "bne 1b @ jump to main loop\n" + + : [doutc0r0] "+r"(doutc0_ptr), + [doutc1r0] "+r"(doutc1_ptr), + [doutc2r0] "+r"(doutc2_ptr), + [doutc3r0] "+r"(doutc3_ptr), + [doutc4r0] "+r"(doutc4_ptr), + [doutc5r0] "+r"(doutc5_ptr), + [doutc6r0] "+r"(doutc6_ptr), + [doutc7r0] "+r"(doutc7_ptr), + [ptr_din] "+r"(din_hei_ptr), + [cnt] "+r"(cnt_loop) + : + : "q0", "q1", "q2", "q3", "q4"); +#endif + } + if (we > width) { + int offset = 32 * (valid_w / 4 - 1); + din_hei_ptr = ptr_din + offset; + int i = we - 4; + for (; i < width; ++i) { + *(doutc0_ptr++) = din_hei_ptr[0]; + *(doutc1_ptr++) = din_hei_ptr[1]; + *(doutc2_ptr++) = din_hei_ptr[2]; + *(doutc3_ptr++) = din_hei_ptr[3]; + *(doutc4_ptr++) = din_hei_ptr[4]; + *(doutc5_ptr++) = din_hei_ptr[5]; + *(doutc6_ptr++) = din_hei_ptr[6]; + *(doutc7_ptr++) = din_hei_ptr[7]; + din_hei_ptr += 8; + } + } + } + } + return true; +} + +/*wirte result in outputs +* input din: [n, c / 4, h, w * 4], output dout: [n, c, h, w] +*/ +inline bool write_to_output_c4_int32(const int* din, + int* dout, + int ch_n, + int hei_n, + int cs, + int ce, + int hs, + int he, + int ws, + int we, + int channel, + int height, + int width, + bool flag_relu, + int* trash_ptr) { + if (ch_n != 4 || hei_n <= 0) { + LOG(ERROR) << "ch_n must be equal 4 and hei_n is more than zero"; + return false; + } + int size_c_out = width * height; + + int* doutc0r0 = dout + cs * size_c_out + hs * width + ws; + int* doutc1r0 = doutc0r0 + size_c_out; + int* doutc2r0 = doutc1r0 + size_c_out; + int* doutc3r0 = doutc2r0 + size_c_out; + + const int* ptr_din = din; + + int size_h = (he > height ? height : he) - hs; // size_h == hei_n + + int valid_w = we - ws; + int cnt = valid_w / 4; + + if (we > width) { + cnt--; + } + if (flag_relu) { + for (int i = 0; i < size_h; i++) { + int size_w = i * width; + int* doutc0_ptr = doutc0r0 + size_w; // doutc0r0 + width; + int* doutc1_ptr = doutc1r0 + size_w; + int* doutc2_ptr = doutc2r0 + size_w; + int* doutc3_ptr = doutc3r0 + size_w; + if (ce > channel) { + switch (ce - channel) { + case 3: + doutc1_ptr = trash_ptr; + case 2: + doutc2_ptr = trash_ptr; + case 1: + doutc3_ptr = trash_ptr; + default: + break; + } + } + ptr_din = din + i * valid_w * ch_n; + const int* din_hei_ptr = ptr_din; + if (cnt > 0) { + int cnt_loop = cnt; +#ifdef __aarch64__ + asm volatile( + "ldp q0, q1, [%[ptr_din]], #32 \n" /* load r00, r01 to q0, q1 */ + "ldp q2, q3, [%[ptr_din]], #32 \n" /* load r02, r03 to q2, q3 */ + "movi v20.4s, #0 \n" /* for relu */ + "1: \n" /* main loop*/ + "trn1 v8.4s, v0.4s, v1.4s \n" /* trans q0, q1*/ + "trn2 v9.4s, v0.4s, v1.4s \n" /* trans q0, q1*/ + "ldp q0, q1, [%[ptr_din]], #32 \n" /* load r00, r01 to q0, q1 */ + "trn1 v10.4s, v2.4s, v3.4s \n" /* trans q2, q3*/ + "trn2 v11.4s, v2.4s, v3.4s \n" /* trans q2, q3*/ + "ldp q2, q3, [%[ptr_din]], #32 \n" /* load r02, r03 to q2, q3 */ + "trn1 v16.2d, v8.2d, v10.2d \n" /* trans q8, q10*/ + "trn2 v17.2d, v8.2d, v10.2d \n" /* trans q8, q10*/ + "trn1 v18.2d, v9.2d, v11.2d \n" /* trans q9, q11*/ + "trn2 v19.2d, v9.2d, v11.2d \n" /* trans q9, q11*/ + "smax v16.4s, v16.4s, v20.4s \n" /* relu */ + "smax v17.4s, v17.4s, v20.4s \n" /* relu */ + "smax v18.4s, v18.4s, v20.4s \n" /* relu */ + "smax v19.4s, v19.4s, v20.4s \n" /* relu */ + "str q16, [%[doutc0r0]], #16 \n" /* store c0r0*/ + "str q17, [%[doutc2r0]], #16 \n" /* store c2r0*/ + "str q18, [%[doutc1r0]], #16 \n" /* store c1r0*/ + "str q19, [%[doutc3r0]], #16 \n" /* store c3r0*/ + + "subs %w[cnt], %w[cnt], #1 \n" /* loop count -1*/ + "bne 1b \n" /* jump to main 
loop*/ + + : [doutc0r0] "+r"(doutc0_ptr), + [doutc1r0] "+r"(doutc1_ptr), + [doutc2r0] "+r"(doutc2_ptr), + [doutc3r0] "+r"(doutc3_ptr), + [cnt] "+r"(cnt_loop), + [ptr_din] "+r"(din_hei_ptr) + : + : "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15", + "v16", + "v17", + "v18", + "v19", + "v20"); +#else + asm volatile( + "vld1.32 {d0-d3}, [%[ptr_din]]! @load data \n" + "vld1.32 {d4-d7}, [%[ptr_din]]! @load data \n" + "vmov.u32 q15, #0 @ dump zero\n" + "1: @ main loop\n" + "vtrn.32 q0, q1 @ trans q0, q1 \n" + "vtrn.32 q2, q3 @ trans q2, q3 \n" + "vswp.32 d1, d4 @ swap d1, d4 \n" + "vswp.32 d3, d6 @ swap d3, d6 \n" + + "vmax.s32 q0, q0, q15 @ relu\n" + "vmax.s32 q1, q1, q15 @ relu\n" + "vmax.s32 q2, q2, q15 @ relu\n" + "vmax.s32 q3, q3, q15 @ relu\n" + + "vst1.32 {d0-d1}, [%[doutc0r0]]! @ store result, add pointer\n" + "vst1.32 {d2-d3}, [%[doutc1r0]]! @ store result, add pointer\n" + "vst1.32 {d4-d5}, [%[doutc2r0]]! @ store result, add pointer\n" + "vst1.32 {d6-d7}, [%[doutc3r0]]! @ store result, add pointer\n" + + "subs %[cnt], %[cnt], #1 @ loop count - 1\n" + + "vld1.32 {d0-d3}, [%[ptr_din]]! @load data \n" + "vld1.32 {d4-d7}, [%[ptr_din]]! @load data \n" + + "bne 1b @ jump to main loop\n" + + : [doutc0r0] "+r"(doutc0_ptr), + [doutc1r0] "+r"(doutc1_ptr), + [doutc2r0] "+r"(doutc2_ptr), + [doutc3r0] "+r"(doutc3_ptr), + [ptr_din] "+r"(din_hei_ptr), + [cnt] "+r"(cnt_loop) + : + : "q0", "q1", "q2", "q3", "q4", "q15"); +#endif + } + if (we > width) { + int offset = 16 * (valid_w / 4 - 1); + din_hei_ptr = ptr_din + offset; + int i = we - 4; + for (; i < width; ++i) { + *(doutc0_ptr++) = LITEMAX(din_hei_ptr[0], 0); + *(doutc1_ptr++) = LITEMAX(din_hei_ptr[1], 0); + *(doutc2_ptr++) = LITEMAX(din_hei_ptr[2], 0); + *(doutc3_ptr++) = LITEMAX(din_hei_ptr[3], 0); + din_hei_ptr += 4; + } + } + } + } else { + for (int i = 0; i < size_h; i++) { + int size_w = i * width; + int* doutc0_ptr = doutc0r0 + size_w; // doutc0r0 + width; + int* doutc1_ptr = doutc1r0 + size_w; + int* doutc2_ptr = doutc2r0 + size_w; + int* doutc3_ptr = doutc3r0 + size_w; + if (ce > channel) { + switch (ce - channel) { + case 3: + doutc1_ptr = trash_ptr; + case 2: + doutc2_ptr = trash_ptr; + case 1: + doutc3_ptr = trash_ptr; + default: + break; + } + } + ptr_din = din + i * valid_w * ch_n; + const int* din_hei_ptr = ptr_din; + if (cnt > 0) { + int cnt_loop = cnt; +#ifdef __aarch64__ + asm volatile( + "ldp q0, q1, [%[ptr_din]], #32 \n" /* load r00, r01 to q0, q1 */ + "ldp q2, q3, [%[ptr_din]], #32 \n" /* load r02, r03 to q2, q3 */ + "1: \n" /* main loop*/ + "trn1 v8.4s, v0.4s, v1.4s \n" /* trans q0, q1*/ + "trn2 v9.4s, v0.4s, v1.4s \n" /* trans q0, q1*/ + "ldp q0, q1, [%[ptr_din]], #32 \n" /* load r00, r01 to q0, q1 */ + "trn1 v10.4s, v2.4s, v3.4s \n" /* trans q2, q3*/ + "trn2 v11.4s, v2.4s, v3.4s \n" /* trans q2, q3*/ + "ldp q2, q3, [%[ptr_din]], #32 \n" /* load r02, r03 to q2, q3 */ + "trn1 v16.2d, v8.2d, v10.2d \n" /* trans q8, q10*/ + "trn2 v17.2d, v8.2d, v10.2d \n" /* trans q8, q10*/ + "trn1 v18.2d, v9.2d, v11.2d \n" /* trans q9, q11*/ + "trn2 v19.2d, v9.2d, v11.2d \n" /* trans q9, q11*/ + + "str q16, [%[doutc0r0]], #16 \n" /* store c0r0*/ + "str q17, [%[doutc2r0]], #16 \n" /* store c2r0*/ + "str q18, [%[doutc1r0]], #16 \n" /* store c1r0*/ + "str q19, [%[doutc3r0]], #16 \n" /* store c3r0*/ + + "subs %w[cnt], %w[cnt], #1 \n" /* loop count -1*/ + "bne 1b \n" /* jump to main loop*/ + + : [doutc0r0] "+r"(doutc0_ptr), + [doutc1r0] "+r"(doutc1_ptr), + [doutc2r0] 
"+r"(doutc2_ptr), + [doutc3r0] "+r"(doutc3_ptr), + [cnt] "+r"(cnt_loop), + [ptr_din] "+r"(din_hei_ptr) + : + : "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15", + "v16", + "v17", + "v18", + "v19", + "v20"); +#else + asm volatile( + "vld1.32 {d0-d3}, [%[ptr_din]]! @load data \n" + "vld1.32 {d4-d7}, [%[ptr_din]]! @load data \n" + "1: @ main loop\n" + "vtrn.32 q0, q1 @ trans q0, q1\n" + "vtrn.32 q2, q3 @ trans q2, q3\n" + "vswp.32 d1, d4 @ swap d1, d4 \n" + "vswp.32 d3, d6 @ swap d3, d6 \n" + + "subs %[cnt], %[cnt], #1 @ loop count - 1\n" + "vst1.32 {d0-d1}, [%[doutc0r0]]! @ store result, add " + "pointer\n" + "vst1.32 {d2-d3}, [%[doutc1r0]]! @ store result, add " + "pointer\n" + "vst1.32 {d4-d5}, [%[doutc2r0]]! @ store result, add " + "pointer\n" + "vst1.32 {d6-d7}, [%[doutc3r0]]! @ store result, add " + "pointer\n" + + "vld1.32 {d0-d3}, [%[ptr_din]]! @load data \n" + "vld1.32 {d4-d7}, [%[ptr_din]]! @load data \n" + + "bne 1b @ jump to main loop\n" + + : [doutc0r0] "+r"(doutc0_ptr), + [doutc1r0] "+r"(doutc1_ptr), + [doutc2r0] "+r"(doutc2_ptr), + [doutc3r0] "+r"(doutc3_ptr), + [ptr_din] "+r"(din_hei_ptr), + [cnt] "+r"(cnt_loop) + : + : "q0", "q1", "q2", "q3", "q4", "q15"); +#endif + } + if (we > width) { + int offset = 16 * (valid_w / 4 - 1); + din_hei_ptr = ptr_din + offset; + int i = we - 4; + for (; i < width; ++i) { + *(doutc0_ptr++) = din_hei_ptr[0]; + *(doutc1_ptr++) = din_hei_ptr[1]; + *(doutc2_ptr++) = din_hei_ptr[2]; + *(doutc3_ptr++) = din_hei_ptr[3]; + din_hei_ptr += 4; + } + } + } + } + return true; +} + +/*wirte result in outputs --int8, fp32 +* input din: [n, c / 4, h, w * 4], output dout: [n, c, h, w] +*/ +template +inline bool write_to_output_c4_int32_1(const int* din, + dtype* dout, + int ch_n, + int hei_n, + int cs, + int ce, + int hs, + int he, + int ws, + int we, + int channel, + int height, + int width, + bool flag_relu, + dtype* trash_ptr, + const float* scale, + PrecisionType out_dtype) { + if (ch_n != 4 || hei_n <= 0) { + LOG(ERROR) << "ch_n must be equal 4 and hei_n is more than zero"; + return false; + } + int size_c_out = width * height; + + dtype* doutc0r0 = dout + cs * size_c_out + hs * width + ws; + dtype* doutc1r0 = doutc0r0 + size_c_out; + dtype* doutc2r0 = doutc1r0 + size_c_out; + dtype* doutc3r0 = doutc2r0 + size_c_out; + + const int* ptr_din = din; + + int size_h = (he > height ? 
height : he) - hs; // size_h == hei_n + + int valid_w = we - ws; + int cnt = valid_w / 4; + + float32x4_t w_scale = vld1q_f32(scale); + // float32x4_t vzero = vdupq_n_f32(0.f); + + if (we > width) { + cnt--; + } + if (out_dtype == PRECISION(kFloat)) { + // int32_to_fp32 + if (flag_relu) { + for (int i = 0; i < size_h; i++) { + int size_w = i * width; + dtype* doutc0_ptr = doutc0r0 + size_w; // doutc0r0 + width; + dtype* doutc1_ptr = doutc1r0 + size_w; + dtype* doutc2_ptr = doutc2r0 + size_w; + dtype* doutc3_ptr = doutc3r0 + size_w; + if (ce > channel) { + switch (ce - channel) { + case 3: + doutc1_ptr = trash_ptr; + case 2: + doutc2_ptr = trash_ptr; + case 1: + doutc3_ptr = trash_ptr; + default: + break; + } + } + ptr_din = din + i * valid_w * ch_n; + const int* din_hei_ptr = ptr_din; + if (cnt > 0) { + int cnt_loop = cnt; +#ifdef __aarch64__ + asm volatile( + "ldp q0, q1, [%[ptr_din]], #32 \n" /* load r00, r01 to q0, q1 */ + "ldp q2, q3, [%[ptr_din]], #32 \n" /* load r02, r03 to q2, q3 */ + "movi v20.4s, #0 \n" /* for relu */ + "1: \n" /* main loop*/ + "trn1 v8.4s, v0.4s, v1.4s \n" /* trans q0, q1*/ + "trn2 v9.4s, v0.4s, v1.4s \n" /* trans q0, q1*/ + "ldp q0, q1, [%[ptr_din]], #32 \n" /* load r00, r01 to q0, q1 */ + "trn1 v10.4s, v2.4s, v3.4s \n" /* trans q2, q3*/ + "trn2 v11.4s, v2.4s, v3.4s \n" /* trans q2, q3*/ + "ldp q2, q3, [%[ptr_din]], #32 \n" /* load r02, r03 to q2, q3 */ + "trn1 v16.2d, v8.2d, v10.2d \n" /* trans q8, q10*/ + "trn2 v17.2d, v8.2d, v10.2d \n" /* trans q8, q10*/ + "trn1 v18.2d, v9.2d, v11.2d \n" /* trans q9, q11*/ + "trn2 v19.2d, v9.2d, v11.2d \n" /* trans q9, q11*/ + "smax v16.4s, v16.4s, v20.4s \n" /* relu */ + "smax v17.4s, v17.4s, v20.4s \n" /* relu */ + "smax v18.4s, v18.4s, v20.4s \n" /* relu */ + "smax v19.4s, v19.4s, v20.4s \n" /* relu */ + // int32 --> fp32 + "scvtf v4.4s, v16.4s \n" + "scvtf v5.4s, v17.4s \n" + "scvtf v6.4s, v18.4s \n" + "scvtf v7.4s, v19.4s \n" + // mul + "fmul v16.4s, v4.4s, %[scale].s[0] \n" + "fmul v17.4s, v5.4s, %[scale].s[2] \n" + "fmul v18.4s, v6.4s, %[scale].s[1] \n" + "fmul v19.4s, v7.4s, %[scale].s[3] \n" + // res + "str q16, [%[doutc0r0]], #16 \n" /* store c0r0*/ + "str q17, [%[doutc2r0]], #16 \n" /* store c2r0*/ + "str q18, [%[doutc1r0]], #16 \n" /* store c1r0*/ + "str q19, [%[doutc3r0]], #16 \n" /* store c3r0*/ + + "subs %w[cnt], %w[cnt], #1 \n" /* loop count -1*/ + "bne 1b \n" /* jump to main loop*/ + + : [doutc0r0] "+r"(doutc0_ptr), + [doutc1r0] "+r"(doutc1_ptr), + [doutc2r0] "+r"(doutc2_ptr), + [doutc3r0] "+r"(doutc3_ptr), + [cnt] "+r"(cnt_loop), + [ptr_din] "+r"(din_hei_ptr) + : [scale] "w"(w_scale) + : "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15", + "v16", + "v17", + "v18", + "v19", + "v20"); +#else + asm volatile( + "vld1.32 {d4-d7}, [%[ptr_din]]! @load data \n" + "vld1.32 {d8-d11}, [%[ptr_din]]! 
@load data \n" + "vmov.u32 q15, #0 @ dump zero\n" + "1: @ main loop\n" + "vtrn.32 q2, q3 @ trans q0, q1 \n" + "vtrn.32 q4, q5 @ trans q2, q3 \n" + "vswp.32 d5, d8 @ swap d1, d4 \n" + "vswp.32 d7, d10 @ swap d3, d6 \n" + + "vmax.s32 q2, q2, q15 @ relu\n" + "vmax.s32 q3, q3, q15 @ relu\n" + "vmax.s32 q4, q4, q15 @ relu\n" + "vmax.s32 q5, q5, q15 @ relu\n" + + // int32-> fp32 + "vcvt.f32.s32 q6, q2 \n" + "vcvt.f32.s32 q7, q3 \n" + "vcvt.f32.s32 q8, q4 \n" + "vcvt.f32.s32 q9, q5 \n" + + // mul + "vmul.f32 q2, q6, %e[scale][0] \n" + "vmul.f32 q3, q7, %e[scale][1] \n" + "vmul.f32 q4, q8, %f[scale][0] \n" + "vmul.f32 q5, q9, %f[scale][1] \n" + + "vst1.32 {d4-d5}, [%[doutc0r0]]! @ store result, add " + "pointer\n" + "vst1.32 {d6-d7}, [%[doutc1r0]]! @ store result, add " + "pointer\n" + "vst1.32 {d8-d9}, [%[doutc2r0]]! @ store result, add " + "pointer\n" + "vst1.32 {d10-d11}, [%[doutc3r0]]! @ store result, add " + "pointer\n" + + "subs %[cnt], %[cnt], #1 @ loop count - 1\n" + + "vld1.32 {d4-d7}, [%[ptr_din]]! @load data \n" + "vld1.32 {d8-d11}, [%[ptr_din]]! @load data \n" + + "bne 1b @ jump to main loop\n" + + : [doutc0r0] "+r"(doutc0_ptr), + [doutc1r0] "+r"(doutc1_ptr), + [doutc2r0] "+r"(doutc2_ptr), + [doutc3r0] "+r"(doutc3_ptr), + [ptr_din] "+r"(din_hei_ptr), + [cnt] "+r"(cnt_loop) + : [scale] "w"(w_scale) + : "q2", + "q3", + "q4", + "q5", + "q6", + "q7", + "q8", + "q9", + "q10", + "q11", + "q12", + "q13", + "q14", + "q15"); +#endif + } + if (we > width) { + int offset = 16 * (valid_w / 4 - 1); + din_hei_ptr = ptr_din + offset; + int j = we - 4; + for (; j < width; ++j) { + *(doutc0_ptr++) = LITEMAX(din_hei_ptr[0] * scale[0], 0); + *(doutc1_ptr++) = LITEMAX(din_hei_ptr[1] * scale[1], 0); + *(doutc2_ptr++) = LITEMAX(din_hei_ptr[2] * scale[2], 0); + *(doutc3_ptr++) = LITEMAX(din_hei_ptr[3] * scale[3], 0); + din_hei_ptr += 4; + } + } + } + } else { + for (int i = 0; i < size_h; i++) { + int size_w = i * width; + dtype* doutc0_ptr = doutc0r0 + size_w; // doutc0r0 + width; + dtype* doutc1_ptr = doutc1r0 + size_w; + dtype* doutc2_ptr = doutc2r0 + size_w; + dtype* doutc3_ptr = doutc3r0 + size_w; + if (ce > channel) { + switch (ce - channel) { + case 3: + doutc1_ptr = trash_ptr; + case 2: + doutc2_ptr = trash_ptr; + case 1: + doutc3_ptr = trash_ptr; + default: + break; + } + } + ptr_din = din + i * valid_w * ch_n; + const int* din_hei_ptr = ptr_din; + if (cnt > 0) { + int cnt_loop = cnt; +#ifdef __aarch64__ + asm volatile( + "ldp q0, q1, [%[ptr_din]], #32 \n" /* load r00, r01 to q0, q1 */ + "ldp q2, q3, [%[ptr_din]], #32 \n" /* load r02, r03 to q2, q3 */ + "movi v20.4s, #0 \n" /* for relu */ + "1: \n" /* main loop*/ + "trn1 v8.4s, v0.4s, v1.4s \n" /* trans q0, q1*/ + "trn2 v9.4s, v0.4s, v1.4s \n" /* trans q0, q1*/ + "ldp q0, q1, [%[ptr_din]], #32 \n" /* load r00, r01 to q0, q1 */ + "trn1 v10.4s, v2.4s, v3.4s \n" /* trans q2, q3*/ + "trn2 v11.4s, v2.4s, v3.4s \n" /* trans q2, q3*/ + "ldp q2, q3, [%[ptr_din]], #32 \n" /* load r02, r03 to q2, q3 */ + "trn1 v16.2d, v8.2d, v10.2d \n" /* trans q8, q10*/ + "trn2 v17.2d, v8.2d, v10.2d \n" /* trans q8, q10*/ + "trn1 v18.2d, v9.2d, v11.2d \n" /* trans q9, q11*/ + "trn2 v19.2d, v9.2d, v11.2d \n" /* trans q9, q11*/ + // int32 --> fp32 + "scvtf v4.4s, v16.4s \n" + "scvtf v5.4s, v17.4s \n" + "scvtf v6.4s, v18.4s \n" + "scvtf v7.4s, v19.4s \n" + // mul + "fmul v16.4s, v4.4s, %[scale].s[0] \n" + "fmul v17.4s, v5.4s, %[scale].s[2] \n" + "fmul v18.4s, v6.4s, %[scale].s[1] \n" + "fmul v19.4s, v7.4s, %[scale].s[3] \n" + // res + "str q16, [%[doutc0r0]], #16 \n" /* store 
c0r0*/ + "str q17, [%[doutc2r0]], #16 \n" /* store c2r0*/ + "str q18, [%[doutc1r0]], #16 \n" /* store c1r0*/ + "str q19, [%[doutc3r0]], #16 \n" /* store c3r0*/ + + "subs %w[cnt], %w[cnt], #1 \n" /* loop count -1*/ + "bne 1b \n" /* jump to main loop*/ + + : [doutc0r0] "+r"(doutc0_ptr), + [doutc1r0] "+r"(doutc1_ptr), + [doutc2r0] "+r"(doutc2_ptr), + [doutc3r0] "+r"(doutc3_ptr), + [cnt] "+r"(cnt_loop), + [ptr_din] "+r"(din_hei_ptr) + : [scale] "w"(w_scale) + : "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15", + "v16", + "v17", + "v18", + "v19", + "v20"); +#else + asm volatile( + "vld1.32 {d4-d7}, [%[ptr_din]]! @load data \n" + "vld1.32 {d8-d11}, [%[ptr_din]]! @load data \n" + "vmov.u32 q15, #0 @ dump zero\n" + "1: @ main loop\n" + "vtrn.32 q2, q3 @ trans q0, q1 \n" + "vtrn.32 q4, q5 @ trans q2, q3 \n" + "vswp.32 d5, d8 @ swap d1, d4 \n" + "vswp.32 d7, d10 @ swap d3, d6 \n" + + // int32-> fp32 + "vcvt.f32.s32 q6, q2 \n" + "vcvt.f32.s32 q7, q3 \n" + "vcvt.f32.s32 q8, q4 \n" + "vcvt.f32.s32 q9, q5 \n" + + // mul + "vmul.f32 q2, q6, %e[scale][0] \n" + "vmul.f32 q3, q7, %e[scale][1] \n" + "vmul.f32 q4, q8, %f[scale][0] \n" + "vmul.f32 q5, q9, %f[scale][1] \n" + + "vst1.32 {d4-d5}, [%[doutc0r0]]! @ store result, add " + "pointer\n" + "vst1.32 {d6-d7}, [%[doutc1r0]]! @ store result, add " + "pointer\n" + "vst1.32 {d8-d9}, [%[doutc2r0]]! @ store result, add " + "pointer\n" + "vst1.32 {d10-d11}, [%[doutc3r0]]! @ store result, add " + "pointer\n" + + "subs %[cnt], %[cnt], #1 @ loop count - 1\n" + + "vld1.32 {d4-d7}, [%[ptr_din]]! @load data \n" + "vld1.32 {d8-d11}, [%[ptr_din]]! @load data \n" + + "bne 1b @ jump to main loop\n" + + : [doutc0r0] "+r"(doutc0_ptr), + [doutc1r0] "+r"(doutc1_ptr), + [doutc2r0] "+r"(doutc2_ptr), + [doutc3r0] "+r"(doutc3_ptr), + [ptr_din] "+r"(din_hei_ptr), + [cnt] "+r"(cnt_loop) + : [scale] "w"(w_scale) + : "q2", + "q3", + "q4", + "q5", + "q6", + "q7", + "q8", + "q9", + "q10", + "q11", + "q12", + "q13", + "q14", + "q15"); +#endif + } + if (we > width) { + int offset = 16 * (valid_w / 4 - 1); + din_hei_ptr = ptr_din + offset; + int j = we - 4; + for (; j < width; ++j) { + *(doutc0_ptr++) = din_hei_ptr[0] * scale[0]; + *(doutc1_ptr++) = din_hei_ptr[1] * scale[1]; + *(doutc2_ptr++) = din_hei_ptr[2] * scale[2]; + *(doutc3_ptr++) = din_hei_ptr[3] * scale[3]; + din_hei_ptr += 4; + } + } + } + } + + } else if (out_dtype == PRECISION(kInt8)) { + // int32_to_int8 + if (flag_relu) { + for (int i = 0; i < size_h; i++) { + int size_w = i * width; + dtype* doutc0_ptr = doutc0r0 + size_w; // doutc0r0 + width; + dtype* doutc1_ptr = doutc1r0 + size_w; + dtype* doutc2_ptr = doutc2r0 + size_w; + dtype* doutc3_ptr = doutc3r0 + size_w; + if (ce > channel) { + switch (ce - channel) { + case 3: + doutc1_ptr = trash_ptr; + case 2: + doutc2_ptr = trash_ptr; + case 1: + doutc3_ptr = trash_ptr; + default: + break; + } + } + ptr_din = din + i * valid_w * ch_n; + const int* din_hei_ptr = ptr_din; + if (cnt > 0) { + int cnt_loop = cnt; +#ifdef __aarch64__ + asm volatile( + "ldp q0, q1, [%[ptr_din]], #32 \n" /* load r00, r01 to q0, q1 */ + "ldp q2, q3, [%[ptr_din]], #32 \n" /* load r02, r03 to q2, q3 */ + "movi v20.4s, #0 \n" /* for relu */ + "1: \n" /* main loop*/ + "trn1 v8.4s, v0.4s, v1.4s \n" /* trans q0, q1*/ + "trn2 v9.4s, v0.4s, v1.4s \n" /* trans q0, q1*/ + "ldp q0, q1, [%[ptr_din]], #32 \n" /* load r00, r01 to q0, q1 */ + "trn1 v10.4s, v2.4s, v3.4s \n" /* trans q2, q3*/ + "trn2 v11.4s, v2.4s, v3.4s \n" /* trans q2, 
q3*/ + "ldp q2, q3, [%[ptr_din]], #32 \n" /* load r02, r03 to q2, q3 */ + "trn1 v16.2d, v8.2d, v10.2d \n" /* trans q8, q10*/ + "trn2 v17.2d, v8.2d, v10.2d \n" /* trans q8, q10*/ + "trn1 v18.2d, v9.2d, v11.2d \n" /* trans q9, q11*/ + "trn2 v19.2d, v9.2d, v11.2d \n" /* trans q9, q11*/ + "smax v16.4s, v16.4s, v20.4s \n" /* relu */ + "smax v17.4s, v17.4s, v20.4s \n" /* relu */ + "smax v18.4s, v18.4s, v20.4s \n" /* relu */ + "smax v19.4s, v19.4s, v20.4s \n" /* relu */ + // int32 --> fp32 + "scvtf v4.4s, v16.4s \n" + "scvtf v5.4s, v17.4s \n" + "scvtf v6.4s, v18.4s \n" + "scvtf v7.4s, v19.4s \n" + + // mul + "fmul v16.4s, v4.4s, %[scale].s[0] \n" + "fmul v17.4s, v5.4s, %[scale].s[2] \n" + "fmul v18.4s, v6.4s, %[scale].s[1] \n" + "fmul v19.4s, v7.4s, %[scale].s[3] \n" + + // fp32-int32 + "fcvtas v4.4s, v16.4s \n" + "fcvtas v5.4s, v17.4s \n" + "fcvtas v6.4s, v18.4s \n" + "fcvtas v7.4s, v19.4s \n" + + // int32-int16 + "sqxtn v8.4h, v4.4s \n" + "sqxtn v9.4h, v5.4s \n" + "sqxtn v10.4h, v6.4s \n" + "sqxtn v11.4h, v7.4s \n" + + "sqxtn v16.8b, v8.8h \n" + "sqxtn v17.8b, v9.8h \n" + "sqxtn v18.8b, v10.8h \n" + "sqxtn v19.8b, v11.8h \n" + // res + "str s16, [%[doutc0r0]], #4 \n" + "str s17, [%[doutc2r0]], #4 \n" + "str s18, [%[doutc1r0]], #4 \n" + "str s19, [%[doutc3r0]], #4 \n" + + "subs %w[cnt], %w[cnt], #1 \n" /* loop count -1*/ + "bne 1b \n" /* jump to main loop*/ + + : [doutc0r0] "+r"(doutc0_ptr), + [doutc1r0] "+r"(doutc1_ptr), + [doutc2r0] "+r"(doutc2_ptr), + [doutc3r0] "+r"(doutc3_ptr), + [cnt] "+r"(cnt_loop), + [ptr_din] "+r"(din_hei_ptr) + : [scale] "w"(w_scale) + : "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15", + "v16", + "v17", + "v18", + "v19", + "v20"); +#else + asm volatile( + "vld1.32 {d4-d7}, [%[ptr_din]]! @load data \n" + "vld1.32 {d8-d11}, [%[ptr_din]]! 
@load data \n" + "vmov.u32 q15, #0 @ dump zero\n" + "1: @ main loop\n" + "vtrn.32 q2, q3 @ trans q0, q1 \n" + "vtrn.32 q4, q5 @ trans q2, q3 \n" + "vswp.32 d5, d8 @ swap d1, d4 \n" + "vswp.32 d7, d10 @ swap d3, d6 \n" + + "vmax.s32 q2, q2, q15 @ relu\n" + "vmax.s32 q3, q3, q15 @ relu\n" + "vmax.s32 q4, q4, q15 @ relu\n" + "vmax.s32 q5, q5, q15 @ relu\n" + + // int32-> fp32 + "vcvt.f32.s32 q6, q2 \n" + "vcvt.f32.s32 q7, q3 \n" + "vcvt.f32.s32 q8, q4 \n" + "vcvt.f32.s32 q9, q5 \n" + + "vmov.f32 q2, #0.5 \n" + + // "vand.i32 q0, %q[vpoff], %q[vpoff] @ set offset, 0.5\n" + "vand.i32 q3, q2, q2 @ set offset, 0.5\n" + "vand.i32 q4, q2, q2 @ set offset, 0.5\n" + "vand.i32 q5, q2, q2 @ set offset, 0.5\n" + + "vcgt.f32 q10, q6, q15 @ get mask > 0, in0\n" + "vcgt.f32 q11, q7, q15 @ get mask > 0, in1\n" + "vcgt.f32 q12, q8, q15 @ get mask > 0, in2\n" + "vcgt.f32 q13, q9, q15 @ get mask > 0, in3\n" + + "vmov.f32 q15, #-0.5 \n" + + "vbif.f32 q2, q15, q10 @ get right offset\n" + "vbif.f32 q3, q15, q11 @ get right offset\n" + "vbif.f32 q4, q15, q12 @ get right offset\n" + "vbif.f32 q5, q15, q13 @ get right offset\n" + + "vmla.f32 q2, q6, %e[scale][0] @ mul scale\n" + "vmla.f32 q3, q7, %e[scale][1] @ mul scale\n" + "vmla.f32 q4, q8, %f[scale][0] @ mul scale\n" + "vmla.f32 q5, q9, %f[scale][1] @ mul scale\n" + + "vcvt.s32.f32 q6, q2 @ cvt to int32\n" + "vcvt.s32.f32 q7, q3 @ cvt to int32\n" + "vcvt.s32.f32 q8, q4 @ cvt to int32\n" + "vcvt.s32.f32 q9, q5 @ cvt to int32\n" + + "vqmovn.s32 d20, q6 @ cnt to int16\n" + "vqmovn.s32 d22, q7 @ cnt to int16\n" + "vqmovn.s32 d24, q8 @ cnt to int16\n" + "vqmovn.s32 d26, q9 @ cnt to int16\n" + + "vqmovn.s16 d8, q10 @ cnt to int8\n" + "vqmovn.s16 d9, q11 @ cnt to int8\n" + "vqmovn.s16 d10, q12 @ cnt to int8\n" + "vqmovn.s16 d11, q13 @ cnt to int8\n" + + "vst1.32 {d8[0]}, [%[doutc0r0]] @ write to output\n" + "vst1.32 {d9[0]}, [%[doutc1r0]] @ write to output\n" + "vst1.32 {d10[0]}, [%[doutc2r0]] @ write to output\n" + "vst1.32 {d11[0]}, [%[doutc3r0]] @ write to output\n" + + "add %[doutc0r0], #4 \n" + "add %[doutc1r0], #4 \n" + "add %[doutc2r0], #4 \n" + "add %[doutc3r0], #4 \n" + + "subs %[cnt], %[cnt], #1 @ loop count - 1\n" + "vmov.u32 q15, #0 @ dump zero\n" + + "vld1.32 {d4-d7}, [%[ptr_din]]! @load data \n" + "vld1.32 {d8-d11}, [%[ptr_din]]! 
@load data \n" + + "bne 1b @ jump to main loop\n" + + : [doutc0r0] "+r"(doutc0_ptr), + [doutc1r0] "+r"(doutc1_ptr), + [doutc2r0] "+r"(doutc2_ptr), + [doutc3r0] "+r"(doutc3_ptr), + [ptr_din] "+r"(din_hei_ptr), + [cnt] "+r"(cnt_loop) + : [scale] "w"(w_scale) + : "q2", + "q3", + "q4", + "q5", + "q6", + "q7", + "q8", + "q9", + "q10", + "q11", + "q12", + "q13", + "q14", + "q15"); +#endif + } + if (we > width) { + int offset = 16 * (valid_w / 4 - 1); + din_hei_ptr = ptr_din + offset; + int j = we - 4; + for (; j < width; ++j) { + *(doutc0_ptr++) = saturate_cast( + roundf(LITEMAX(din_hei_ptr[0], 0) * scale[0])); + *(doutc1_ptr++) = saturate_cast( + roundf(LITEMAX(din_hei_ptr[1], 0) * scale[1])); + *(doutc2_ptr++) = saturate_cast( + roundf(LITEMAX(din_hei_ptr[2], 0) * scale[2])); + *(doutc3_ptr++) = saturate_cast( + roundf(LITEMAX(din_hei_ptr[3], 0) * scale[3])); + din_hei_ptr += 4; + } + } + } + } else { + for (int i = 0; i < size_h; i++) { // size_h + int size_w = i * width; + dtype* doutc0_ptr = doutc0r0 + size_w; // doutc0r0 + width; + dtype* doutc1_ptr = doutc1r0 + size_w; + dtype* doutc2_ptr = doutc2r0 + size_w; + dtype* doutc3_ptr = doutc3r0 + size_w; + if (ce > channel) { + switch (ce - channel) { + case 3: + doutc1_ptr = trash_ptr; + case 2: + doutc2_ptr = trash_ptr; + case 1: + doutc3_ptr = trash_ptr; + default: + break; + } + } + ptr_din = din + i * valid_w * ch_n; + const int* din_hei_ptr = ptr_din; + if (cnt > 0) { + int cnt_loop = cnt; +#ifdef __aarch64__ + asm volatile( + "ldp q0, q1, [%[ptr_din]], #32 \n" /* load r00, r01 to q0, q1 */ + "ldp q2, q3, [%[ptr_din]], #32 \n" /* load r02, r03 to q2, q3 */ + "movi v20.4s, #0 \n" /* for relu */ + "1: \n" /* main loop*/ + "trn1 v8.4s, v0.4s, v1.4s \n" /* trans q0, q1*/ + "trn2 v9.4s, v0.4s, v1.4s \n" /* trans q0, q1*/ + "ldp q0, q1, [%[ptr_din]], #32 \n" /* load r00, r01 to q0, q1 */ + "trn1 v10.4s, v2.4s, v3.4s \n" /* trans q2, q3*/ + "trn2 v11.4s, v2.4s, v3.4s \n" /* trans q2, q3*/ + "ldp q2, q3, [%[ptr_din]], #32 \n" /* load r02, r03 to q2, q3 */ + "trn1 v16.2d, v8.2d, v10.2d \n" /* trans q8, q10*/ + "trn2 v17.2d, v8.2d, v10.2d \n" /* trans q8, q10*/ + "trn1 v18.2d, v9.2d, v11.2d \n" /* trans q9, q11*/ + "trn2 v19.2d, v9.2d, v11.2d \n" /* trans q9, q11*/ + // int32 --> fp32 + "scvtf v4.4s, v16.4s \n" + "scvtf v5.4s, v17.4s \n" + "scvtf v6.4s, v18.4s \n" + "scvtf v7.4s, v19.4s \n" + + // mul + "fmul v16.4s, v4.4s, %[scale].s[0] \n" + "fmul v17.4s, v5.4s, %[scale].s[2] \n" + "fmul v18.4s, v6.4s, %[scale].s[1] \n" + "fmul v19.4s, v7.4s, %[scale].s[3] \n" + + // fp32-int32 + "fcvtas v4.4s, v16.4s \n" + "fcvtas v5.4s, v17.4s \n" + "fcvtas v6.4s, v18.4s \n" + "fcvtas v7.4s, v19.4s \n" + + // int32-int16 + "sqxtn v8.4h, v4.4s \n" + "sqxtn v9.4h, v5.4s \n" + "sqxtn v10.4h, v6.4s \n" + "sqxtn v11.4h, v7.4s \n" + + "sqxtn v16.8b, v8.8h \n" + "sqxtn v17.8b, v9.8h \n" + "sqxtn v18.8b, v10.8h \n" + "sqxtn v19.8b, v11.8h \n" + // res + "str s16, [%[doutc0r0]], #4 \n" + "str s17, [%[doutc2r0]], #4 \n" + "str s18, [%[doutc1r0]], #4 \n" + "str s19, [%[doutc3r0]], #4 \n" + + "subs %w[cnt], %w[cnt], #1 \n" /* loop count -1*/ + "bne 1b \n" /* jump to main loop*/ + + : [doutc0r0] "+r"(doutc0_ptr), + [doutc1r0] "+r"(doutc1_ptr), + [doutc2r0] "+r"(doutc2_ptr), + [doutc3r0] "+r"(doutc3_ptr), + [cnt] "+r"(cnt_loop), + [ptr_din] "+r"(din_hei_ptr) + : [scale] "w"(w_scale) + : "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15", + "v16", + "v17", + "v18", + "v19", + "v20"); +#else + asm 
volatile( + "vld1.32 {d4-d7}, [%[ptr_din]]! @load data \n" + "vld1.32 {d8-d11}, [%[ptr_din]]! @load data \n" + "vmov.u32 q15, #0 @ dump zero\n" + "1: @ main loop\n" + "vtrn.32 q2, q3 @ trans q0, q1 \n" + "vtrn.32 q4, q5 @ trans q2, q3 \n" + "vswp.32 d5, d8 @ swap d1, d4 \n" + "vswp.32 d7, d10 @ swap d3, d6 \n" + + // int32-> fp32 + "vcvt.f32.s32 q6, q2 \n" + "vcvt.f32.s32 q7, q3 \n" + "vcvt.f32.s32 q8, q4 \n" + "vcvt.f32.s32 q9, q5 \n" + + "vmov.f32 q2, #0.5 \n" + + // "vand.i32 q0, %q[vpoff], %q[vpoff] @ set offset, 0.5\n" + "vand.i32 q3, q2, q2 @ set offset, 0.5\n" + "vand.i32 q4, q2, q2 @ set offset, 0.5\n" + "vand.i32 q5, q2, q2 @ set offset, 0.5\n" + + "vcgt.f32 q10, q6, q15 @ get mask > 0, in0\n" + "vcgt.f32 q11, q7, q15 @ get mask > 0, in1\n" + "vcgt.f32 q12, q8, q15 @ get mask > 0, in2\n" + "vcgt.f32 q13, q9, q15 @ get mask > 0, in3\n" + + "vmov.f32 q15, #-0.5 \n" + + "vbif.f32 q2, q15, q10 @ get right offset\n" + "vbif.f32 q3, q15, q11 @ get right offset\n" + "vbif.f32 q4, q15, q12 @ get right offset\n" + "vbif.f32 q5, q15, q13 @ get right offset\n" + + "vmla.f32 q2, q6, %e[scale][0] @ mul scale\n" + "vmla.f32 q3, q7, %e[scale][1] @ mul scale\n" + "vmla.f32 q4, q8, %f[scale][0] @ mul scale\n" + "vmla.f32 q5, q9, %f[scale][1] @ mul scale\n" + + "vcvt.s32.f32 q6, q2 @ cvt to int32\n" + "vcvt.s32.f32 q7, q3 @ cvt to int32\n" + "vcvt.s32.f32 q8, q4 @ cvt to int32\n" + "vcvt.s32.f32 q9, q5 @ cvt to int32\n" + + "vqmovn.s32 d20, q6 @ cnt to int16\n" + "vqmovn.s32 d22, q7 @ cnt to int16\n" + "vqmovn.s32 d24, q8 @ cnt to int16\n" + "vqmovn.s32 d26, q9 @ cnt to int16\n" + + "vqmovn.s16 d8, q10 @ cnt to int8\n" + "vqmovn.s16 d9, q11 @ cnt to int8\n" + "vqmovn.s16 d10, q12 @ cnt to int8\n" + "vqmovn.s16 d11, q13 @ cnt to int8\n" + + "vst1.32 {d8[0]}, [%[doutc0r0]] @ write to output\n" + "vst1.32 {d9[0]}, [%[doutc1r0]] @ write to output\n" + "vst1.32 {d10[0]}, [%[doutc2r0]] @ write to output\n" + "vst1.32 {d11[0]}, [%[doutc3r0]] @ write to output\n" + + "add %[doutc0r0], #4 \n" + "add %[doutc1r0], #4 \n" + "add %[doutc2r0], #4 \n" + "add %[doutc3r0], #4 \n" + + "subs %[cnt], %[cnt], #1 @ loop count - 1\n" + + "vld1.32 {d4-d7}, [%[ptr_din]]! @load data \n" + "vld1.32 {d8-d11}, [%[ptr_din]]! 
@load data \n" + "vmov.u32 q15, #0 @ dump zero\n" + + "bne 1b @ jump to main loop\n" + + : [doutc0r0] "+r"(doutc0_ptr), + [doutc1r0] "+r"(doutc1_ptr), + [doutc2r0] "+r"(doutc2_ptr), + [doutc3r0] "+r"(doutc3_ptr), + [ptr_din] "+r"(din_hei_ptr), + [cnt] "+r"(cnt_loop) + : [scale] "w"(w_scale) + : "q2", + "q3", + "q4", + "q5", + "q6", + "q7", + "q8", + "q9", + "q10", + "q11", + "q12", + "q13", + "q14", + "q15"); +#endif + } + if (we > width) { + int offset = 16 * (valid_w / 4 - 1); + din_hei_ptr = ptr_din + offset; + int j = we - 4; + for (; j < width; ++j) { + *(doutc0_ptr++) = + saturate_cast(roundf(din_hei_ptr[0] * scale[0])); + *(doutc1_ptr++) = + saturate_cast(roundf(din_hei_ptr[1] * scale[1])); + *(doutc2_ptr++) = + saturate_cast(roundf(din_hei_ptr[2] * scale[2])); + *(doutc3_ptr++) = + saturate_cast(roundf(din_hei_ptr[3] * scale[3])); + din_hei_ptr += 4; + } + } + } + } + } else { + LOG(ERROR) << "ERROR: unsupported input data type!!"; + return false; + } + return true; +} + +/*wirte result in outputs +* input din: [n, c / 8, h, w * 8], output dout: [n, c, h, w] +*/ +inline bool write_to_output_c8_int32(const int* din, + int* dout, + int ch_n, + int hei_n, + int cs, + int ce, + int hs, + int he, + int ws, + int we, + int channel, + int height, + int width, + bool flag_relu, + int* trash_ptr) { + if (ch_n != 8 || hei_n <= 0) { + LOG(ERROR) << "ch_n must be equal 8 and hei_n is more than zero"; + return false; + } + int size_c_out = width * height; + + int* doutc0r0 = dout + cs * size_c_out + hs * width + ws; + int* doutc1r0 = doutc0r0 + size_c_out; + int* doutc2r0 = doutc1r0 + size_c_out; + int* doutc3r0 = doutc2r0 + size_c_out; + int* doutc4r0 = doutc3r0 + size_c_out; + int* doutc5r0 = doutc4r0 + size_c_out; + int* doutc6r0 = doutc5r0 + size_c_out; + int* doutc7r0 = doutc6r0 + size_c_out; + + const int* ptr_din = din; + + int size_h = (he > height ? 
height : he) - hs; // size_h == hei_n + + int valid_w = we - ws; + int cnt = valid_w / 4; + + if (we > width) { + cnt--; + } + if (flag_relu) { + for (int i = 0; i < size_h; i++) { + int size_w = i * width; + int* doutc0_ptr = doutc0r0 + size_w; // doutc0r0 + width; + int* doutc1_ptr = doutc1r0 + size_w; + int* doutc2_ptr = doutc2r0 + size_w; + int* doutc3_ptr = doutc3r0 + size_w; + int* doutc4_ptr = doutc4r0 + size_w; + int* doutc5_ptr = doutc5r0 + size_w; + int* doutc6_ptr = doutc6r0 + size_w; + int* doutc7_ptr = doutc7r0 + size_w; + if (ce > channel) { + switch (ce - channel) { + case 7: + doutc1_ptr = trash_ptr; + case 6: + doutc2_ptr = trash_ptr; + case 5: + doutc3_ptr = trash_ptr; + case 4: + doutc4_ptr = trash_ptr; + case 3: + doutc5_ptr = trash_ptr; + case 2: + doutc6_ptr = trash_ptr; + case 1: + doutc7_ptr = trash_ptr; + default: + break; + } + } + ptr_din = din + i * valid_w * ch_n; + const int* din_hei_ptr = ptr_din; + if (cnt > 0) { + int cnt_loop = cnt; +#ifdef __aarch64__ + asm volatile( + "ldp q0, q1, [%[ptr_din]], #32 \n" /* load r00, r01 to q0, q1 */ + "ldp q2, q3, [%[ptr_din]], #32 \n" /* load r02, r03 to q2, q3 */ + "ldp q4, q5, [%[ptr_din]], #32 \n" /* load r00, r01 to q0, q1 */ + "ldp q6, q7, [%[ptr_din]], #32 \n" /* load r02, r03 to q2, q3 */ + "movi v20.4s, #0 \n" /* for relu */ + "1: \n" /* main loop*/ + "trn1 v8.4s, v0.4s, v2.4s \n" /* trans q0, q1*/ + "trn2 v9.4s, v0.4s, v2.4s \n" /* trans q0, q1*/ + "trn1 v10.4s, v1.4s, v3.4s \n" /* trans q2, q3*/ + "trn2 v11.4s, v1.4s, v3.4s \n" /* trans q2, q3*/ + "ldp q0, q1, [%[ptr_din]], #32 \n" /* load r00, r01 to q0, q1 */ + + "trn1 v12.4s, v4.4s, v6.4s \n" /* trans q0, q1*/ + "trn2 v13.4s, v4.4s, v6.4s \n" /* trans q0, q1*/ + "trn1 v14.4s, v5.4s, v7.4s \n" /* trans q2, q3*/ + "trn2 v15.4s, v5.4s, v7.4s \n" /* trans q2, q3*/ + "ldp q2, q3, [%[ptr_din]], #32 \n" /* load r02, r03 to q2, q3 */ + + "trn1 v16.2d, v8.2d, v12.2d \n" /* trans q8, q10 00 01 02 03*/ + "trn2 v17.2d, v8.2d, v12.2d \n" /* trans q8, q10 20 21 22 23*/ + "trn1 v18.2d, v9.2d, v13.2d \n" /* trans q9, q11 10 11 12 13*/ + "trn2 v19.2d, v9.2d, v13.2d \n" /* trans q9, q11 30 31 32 33*/ + "ldp q4, q5, [%[ptr_din]], #32 \n" /* load r00, r01 to q0, q1 */ + + "trn1 v8.2d, v10.2d, v14.2d \n" /* trans q8, q10 40 41 42 43*/ + "trn2 v9.2d, v10.2d, v14.2d \n" /* trans q8, q10 60 61 62 63*/ + "trn1 v12.2d, v11.2d, v15.2d \n" /* trans q9, q11 50 51 52 53*/ + "trn2 v13.2d, v11.2d, v15.2d \n" /* trans q9, q11 70 71 72 73*/ + "ldp q6, q7, [%[ptr_din]], #32 \n" /* load r02, r03 to q2, q3 */ + + "smax v16.4s, v16.4s, v20.4s \n" /*relu*/ + "smax v17.4s, v17.4s, v20.4s \n" /*relu*/ + "smax v18.4s, v18.4s, v20.4s \n" /*relu*/ + "smax v19.4s, v19.4s, v20.4s \n" /*relu*/ + + "smax v8.4s, v8.4s, v20.4s \n" /*relu*/ + "smax v9.4s, v9.4s, v20.4s \n" /*relu*/ + "smax v12.4s, v12.4s, v20.4s \n" /*relu*/ + "smax v13.4s, v13.4s, v20.4s \n" /*relu*/ + + "str q16, [%[doutc0r0]], #16 \n" /* store c0r0*/ + "str q17, [%[doutc2r0]], #16 \n" /* store c2r0*/ + "str q18, [%[doutc1r0]], #16 \n" /* store c1r0*/ + "str q19, [%[doutc3r0]], #16 \n" /* store c3r0*/ + + "subs %w[cnt], %w[cnt], #1 \n" /* loop count -1*/ + "str q8, [%[doutc4r0]], #16 \n" /* store c0r0*/ + "str q9, [%[doutc6r0]], #16 \n" /* store c2r0*/ + "str q12, [%[doutc5r0]], #16 \n" /* store c1r0*/ + "str q13, [%[doutc7r0]], #16 \n" /* store c3r0*/ + + "bne 1b \n" /* jump to main loop*/ + + : [doutc0r0] "+r"(doutc0_ptr), + [doutc1r0] "+r"(doutc1_ptr), + [doutc2r0] "+r"(doutc2_ptr), + [doutc3r0] "+r"(doutc3_ptr), + [doutc4r0] 
"+r"(doutc4_ptr), + [doutc5r0] "+r"(doutc5_ptr), + [doutc6r0] "+r"(doutc6_ptr), + [doutc7r0] "+r"(doutc7_ptr), + [cnt] "+r"(cnt_loop), + [ptr_din] "+r"(din_hei_ptr) + : + : "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15", + "v16", + "v17", + "v18", + "v19", + "v20"); +#else + asm volatile( + "vld1.32 {d0-d3}, [%[ptr_din]]! @load data \n" + "vld1.32 {d4-d7}, [%[ptr_din]]! @load data \n" + "vld1.32 {d8-d11}, [%[ptr_din]]! @load data \n" + "vld1.32 {d12-d15}, [%[ptr_din]]! @load data \n" + "vmov.s32 q15, #0 @ dump zero\n" + "1: @ main loop\n" + "vtrn.32 q0, q2 @ trans q0, q2 \n" + "vtrn.32 q4, q6 @ trans q4, q6 \n" + "vswp.32 d1, d8 @ swap d1, d8 \n" + "vswp.32 d5, d12 @ swap d5, d12\n" + + "vtrn.32 q1, q3 @ trans q1, q3 \n" + "vtrn.32 q5, q7 @ trans q5, q7 \n" + "vswp.32 d3, d10 @ swap d3, d10\n" + "vswp.32 d7, d14 @ swap d7, d14\n" + + "vmax.s32 q0, q0, q15 @ relu\n" + "vmax.s32 q1, q1, q15 @ relu\n" + "vmax.s32 q2, q2, q15 @ relu\n" + "vmax.s32 q3, q3, q15 @ relu\n" + + "vmax.s32 q4, q4, q15 @ relu\n" + "vmax.s32 q5, q5, q15 @ relu\n" + "vmax.s32 q6, q6, q15 @ relu\n" + "vmax.s32 q7, q7, q15 @ relu\n" + + "subs %[cnt], %[cnt], #1 @ loop count - 1\n" + + "vst1.32 {d0-d1}, [%[doutc0r0]]! @ store result, add pointer\n" + "vst1.32 {d2-d3}, [%[doutc4r0]]! @ store result, add pointer\n" + "vst1.32 {d4-d5}, [%[doutc1r0]]! @ store result, add pointer\n" + "vst1.32 {d6-d7}, [%[doutc5r0]]! @ store result, add pointer\n" + + "vld1.32 {d0-d3}, [%[ptr_din]]! @load data \n" + "vld1.32 {d4-d7}, [%[ptr_din]]! @load data \n" + + "vst1.32 {d8-d9}, [%[doutc2r0]]! @ store result, add pointer\n" + "vst1.32 {d10-d11}, [%[doutc6r0]]! @ store result, add pointer\n" + "vst1.32 {d12-d13}, [%[doutc3r0]]! @ store result, add pointer\n" + "vst1.32 {d14-d15}, [%[doutc7r0]]! @ store result, add pointer\n" + + "vld1.32 {d8-d11}, [%[ptr_din]]! @load data \n" + "vld1.32 {d12-d15}, [%[ptr_din]]! 
@load data \n" + + "bne 1b @ jump to main loop\n" + + : [doutc0r0] "+r"(doutc0_ptr), + [doutc1r0] "+r"(doutc1_ptr), + [doutc2r0] "+r"(doutc2_ptr), + [doutc3r0] "+r"(doutc3_ptr), + [doutc4r0] "+r"(doutc4_ptr), + [doutc5r0] "+r"(doutc5_ptr), + [doutc6r0] "+r"(doutc6_ptr), + [doutc7r0] "+r"(doutc7_ptr), + [ptr_din] "+r"(din_hei_ptr), + [cnt] "+r"(cnt_loop) + : + : "q0", "q1", "q2", "q3", "q4", "q15"); +#endif + } + if (we > width) { + int offset = 32 * (valid_w / 4 - 1); + din_hei_ptr = ptr_din + offset; + int i = we - 4; + for (; i < width; ++i) { + *(doutc0_ptr++) = LITEMAX(din_hei_ptr[0], 0); + *(doutc1_ptr++) = LITEMAX(din_hei_ptr[1], 0); + *(doutc2_ptr++) = LITEMAX(din_hei_ptr[2], 0); + *(doutc3_ptr++) = LITEMAX(din_hei_ptr[3], 0); + *(doutc4_ptr++) = LITEMAX(din_hei_ptr[4], 0); + *(doutc5_ptr++) = LITEMAX(din_hei_ptr[5], 0); + *(doutc6_ptr++) = LITEMAX(din_hei_ptr[6], 0); + *(doutc7_ptr++) = LITEMAX(din_hei_ptr[7], 0); + din_hei_ptr += 8; + } + } + } + } else { + for (int i = 0; i < size_h; i++) { + int size_w = i * width; + int* doutc0_ptr = doutc0r0 + size_w; // doutc0r0 + width; + int* doutc1_ptr = doutc1r0 + size_w; + int* doutc2_ptr = doutc2r0 + size_w; + int* doutc3_ptr = doutc3r0 + size_w; + int* doutc4_ptr = doutc4r0 + size_w; + int* doutc5_ptr = doutc5r0 + size_w; + int* doutc6_ptr = doutc6r0 + size_w; + int* doutc7_ptr = doutc7r0 + size_w; + if (ce > channel) { + switch (ce - channel) { + case 7: + doutc1_ptr = trash_ptr; + case 6: + doutc2_ptr = trash_ptr; + case 5: + doutc3_ptr = trash_ptr; + case 4: + doutc4_ptr = trash_ptr; + case 3: + doutc5_ptr = trash_ptr; + case 2: + doutc6_ptr = trash_ptr; + case 1: + doutc7_ptr = trash_ptr; + default: + break; + } + } + ptr_din = din + i * valid_w * ch_n; + const int* din_hei_ptr = ptr_din; + if (cnt > 0) { + int cnt_loop = cnt; +#ifdef __aarch64__ + asm volatile( + "ldp q0, q1, [%[ptr_din]], #32 \n" /* load r00, r01 to q0, q1 */ + "ldp q2, q3, [%[ptr_din]], #32 \n" /* load r02, r03 to q2, q3 */ + "ldp q4, q5, [%[ptr_din]], #32 \n" /* load r00, r01 to q0, q1 */ + "ldp q6, q7, [%[ptr_din]], #32 \n" /* load r02, r03 to q2, q3 */ + "1: \n" /* main loop*/ + "trn1 v8.4s, v0.4s, v2.4s \n" /* trans q0, q1*/ + "trn2 v9.4s, v0.4s, v2.4s \n" /* trans q0, q1*/ + "trn1 v10.4s, v1.4s, v3.4s \n" /* trans q2, q3*/ + "trn2 v11.4s, v1.4s, v3.4s \n" /* trans q2, q3*/ + "ldp q0, q1, [%[ptr_din]], #32 \n" /* load r00, r01 to q0, q1 */ + + "trn1 v12.4s, v4.4s, v6.4s \n" /* trans q0, q1*/ + "trn2 v13.4s, v4.4s, v6.4s \n" /* trans q0, q1*/ + "trn1 v14.4s, v5.4s, v7.4s \n" /* trans q2, q3*/ + "trn2 v15.4s, v5.4s, v7.4s \n" /* trans q2, q3*/ + "ldp q2, q3, [%[ptr_din]], #32 \n" /* load r02, r03 to q2, q3 */ + + "trn1 v16.2d, v8.2d, v12.2d \n" /* trans q8, q10 00 01 02 03*/ + "trn2 v17.2d, v8.2d, v12.2d \n" /* trans q8, q10 20 21 22 23*/ + "trn1 v18.2d, v9.2d, v13.2d \n" /* trans q9, q11 10 11 12 13*/ + "trn2 v19.2d, v9.2d, v13.2d \n" /* trans q9, q11 30 31 32 33*/ + "ldp q4, q5, [%[ptr_din]], #32 \n" /* load r00, r01 to q0, q1 */ + + "trn1 v8.2d, v10.2d, v14.2d \n" /* trans q8, q10 40 41 42 43*/ + "trn2 v9.2d, v10.2d, v14.2d \n" /* trans q8, q10 60 61 62 63*/ + "trn1 v12.2d, v11.2d, v15.2d \n" /* trans q9, q11 50 51 52 53*/ + "trn2 v13.2d, v11.2d, v15.2d \n" /* trans q9, q11 70 71 72 73*/ + "ldp q6, q7, [%[ptr_din]], #32 \n" /* load r02, r03 to q2, q3 */ + + "str q16, [%[doutc0r0]], #16 \n" /* store c0r0*/ + "str q17, [%[doutc2r0]], #16 \n" /* store c2r0*/ + "str q18, [%[doutc1r0]], #16 \n" /* store c1r0*/ + "str q19, [%[doutc3r0]], #16 \n" /* store c3r0*/ + 
+ "subs %w[cnt], %w[cnt], #1 \n" /* loop count -1*/ + "str q8, [%[doutc4r0]], #16 \n" /* store c0r0*/ + "str q9, [%[doutc6r0]], #16 \n" /* store c2r0*/ + "str q12, [%[doutc5r0]], #16 \n" /* store c1r0*/ + "str q13, [%[doutc7r0]], #16 \n" /* store c3r0*/ + + "bne 1b \n" /* jump to main loop*/ + + : [doutc0r0] "+r"(doutc0_ptr), + [doutc1r0] "+r"(doutc1_ptr), + [doutc2r0] "+r"(doutc2_ptr), + [doutc3r0] "+r"(doutc3_ptr), + [doutc4r0] "+r"(doutc4_ptr), + [doutc5r0] "+r"(doutc5_ptr), + [doutc6r0] "+r"(doutc6_ptr), + [doutc7r0] "+r"(doutc7_ptr), + [cnt] "+r"(cnt_loop), + [ptr_din] "+r"(din_hei_ptr) + : + : "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15", + "v16", + "v17", + "v18", + "v19", + "v20"); +#else + asm volatile( + "vld1.32 {d0-d3}, [%[ptr_din]]! @load data \n" + "vld1.32 {d4-d7}, [%[ptr_din]]! @load data \n" + "vld1.32 {d8-d11}, [%[ptr_din]]! @load data \n" + "vld1.32 {d12-d15}, [%[ptr_din]]! @load data \n" + "1: @ main loop\n" + "vtrn.32 q0, q2 @ trans q0, q2 \n" + "vtrn.32 q4, q6 @ trans q4, q6 \n" + "vswp.32 d1, d8 @ swap d1, d8 \n" + "vswp.32 d5, d12 @ swap d5, d12\n" + + "vtrn.32 q1, q3 @ trans q1, q3 \n" + "vtrn.32 q5, q7 @ trans q5, q7 \n" + "vswp.32 d3, d10 @ swap d3, d10\n" + "vswp.32 d7, d14 @ swap d7, d14\n" + + "subs %[cnt], %[cnt], #1 @ loop count - 1\n" + + "vst1.32 {d0-d1}, [%[doutc0r0]]! @ store result, add pointer\n" + "vst1.32 {d2-d3}, [%[doutc4r0]]! @ store result, add pointer\n" + "vst1.32 {d4-d5}, [%[doutc1r0]]! @ store result, add pointer\n" + "vst1.32 {d6-d7}, [%[doutc5r0]]! @ store result, add pointer\n" + + "vld1.32 {d0-d3}, [%[ptr_din]]! @load data \n" + "vld1.32 {d4-d7}, [%[ptr_din]]! @load data \n" + + "vst1.32 {d8-d9}, [%[doutc2r0]]! @ store result, add pointer\n" + "vst1.32 {d10-d11}, [%[doutc6r0]]! @ store result, add pointer\n" + "vst1.32 {d12-d13}, [%[doutc3r0]]! @ store result, add pointer\n" + "vst1.32 {d14-d15}, [%[doutc7r0]]! @ store result, add pointer\n" + + "vld1.32 {d8-d11}, [%[ptr_din]]! @load data \n" + "vld1.32 {d12-d15}, [%[ptr_din]]! 
@load data \n" + + "bne 1b @ jump to main loop\n" + + : [doutc0r0] "+r"(doutc0_ptr), + [doutc1r0] "+r"(doutc1_ptr), + [doutc2r0] "+r"(doutc2_ptr), + [doutc3r0] "+r"(doutc3_ptr), + [doutc4r0] "+r"(doutc4_ptr), + [doutc5r0] "+r"(doutc5_ptr), + [doutc6r0] "+r"(doutc6_ptr), + [doutc7r0] "+r"(doutc7_ptr), + [ptr_din] "+r"(din_hei_ptr), + [cnt] "+r"(cnt_loop) + : + : "q0", "q1", "q2", "q3", "q4", "q15"); +#endif + } + if (we > width) { + int offset = 32 * (valid_w / 4 - 1); + din_hei_ptr = ptr_din + offset; + int i = we - 4; + for (; i < width; ++i) { + *(doutc0_ptr++) = din_hei_ptr[0]; + *(doutc1_ptr++) = din_hei_ptr[1]; + *(doutc2_ptr++) = din_hei_ptr[2]; + *(doutc3_ptr++) = din_hei_ptr[3]; + *(doutc4_ptr++) = din_hei_ptr[4]; + *(doutc5_ptr++) = din_hei_ptr[5]; + *(doutc6_ptr++) = din_hei_ptr[6]; + *(doutc7_ptr++) = din_hei_ptr[7]; + din_hei_ptr += 8; + } + } + } + } + return true; +} + +/*wirte result in outputs--int8, fp32 +* input din: [n, c / 8, h, w * 8], output dout: [n, c, h, w] +*/ +template +static bool write_to_output_c8_int32_1(const int* din, + dtype* dout, + int ch_n, + int hei_n, + int cs, + int ce, + int hs, + int he, + int ws, + int we, + int channel, + int height, + int width, + bool flag_relu, + dtype* trash_ptr, + const float* scale, + PrecisionType out_dtype) { + if (ch_n != 8 || hei_n <= 0) { + LOG(ERROR) << "ch_n must be equal 8 and hei_n is more than zero"; + return false; + } + int size_c_out = width * height; + + dtype* doutc0r0 = dout + cs * size_c_out + hs * width + ws; + dtype* doutc1r0 = doutc0r0 + size_c_out; + dtype* doutc2r0 = doutc1r0 + size_c_out; + dtype* doutc3r0 = doutc2r0 + size_c_out; + dtype* doutc4r0 = doutc3r0 + size_c_out; + dtype* doutc5r0 = doutc4r0 + size_c_out; + dtype* doutc6r0 = doutc5r0 + size_c_out; + dtype* doutc7r0 = doutc6r0 + size_c_out; + + const int* ptr_din = din; + + int size_h = (he > height ? 
height : he) - hs; // size_h == hei_n + + int valid_w = we - ws; + int cnt = valid_w / 4; + + float32x4_t w_scale0 = vld1q_f32(scale); + float32x4_t w_scale1 = vld1q_f32(scale + 4); + + float32x4_t vzero = vdupq_n_f32(0.f); + + if (we > width) { + cnt--; + } + if (out_dtype == PRECISION(kFloat)) { + if (flag_relu) { + for (int i = 0; i < size_h; i++) { + int size_w = i * width; + dtype* doutc0_ptr = doutc0r0 + size_w; // doutc0r0 + width; + dtype* doutc1_ptr = doutc1r0 + size_w; + dtype* doutc2_ptr = doutc2r0 + size_w; + dtype* doutc3_ptr = doutc3r0 + size_w; + dtype* doutc4_ptr = doutc4r0 + size_w; + dtype* doutc5_ptr = doutc5r0 + size_w; + dtype* doutc6_ptr = doutc6r0 + size_w; + dtype* doutc7_ptr = doutc7r0 + size_w; + if (ce > channel) { + switch (ce - channel) { + case 7: + doutc1_ptr = trash_ptr; + case 6: + doutc2_ptr = trash_ptr; + case 5: + doutc3_ptr = trash_ptr; + case 4: + doutc4_ptr = trash_ptr; + case 3: + doutc5_ptr = trash_ptr; + case 2: + doutc6_ptr = trash_ptr; + case 1: + doutc7_ptr = trash_ptr; + default: + break; + } + } + ptr_din = din + i * valid_w * ch_n; + const int* din_hei_ptr = ptr_din; + if (cnt > 0) { + int cnt_loop = cnt; +#ifdef __aarch64__ + asm volatile( + "ldp q0, q1, [%[ptr_din]], #32 \n" /* load r00, r01 to q0, q1 */ + "ldp q2, q3, [%[ptr_din]], #32 \n" /* load r02, r03 to q2, q3 */ + "ldp q4, q5, [%[ptr_din]], #32 \n" /* load r00, r01 to q0, q1 */ + "ldp q6, q7, [%[ptr_din]], #32 \n" /* load r02, r03 to q2, q3 */ + "movi v20.4s, #0 \n" /* for relu */ + "1: \n" /* main loop*/ + "trn1 v8.4s, v0.4s, v2.4s \n" /* trans q0, q1*/ + "trn2 v9.4s, v0.4s, v2.4s \n" /* trans q0, q1*/ + "trn1 v10.4s, v1.4s, v3.4s \n" /* trans q2, q3*/ + "trn2 v11.4s, v1.4s, v3.4s \n" /* trans q2, q3*/ + "ldp q0, q1, [%[ptr_din]], #32 \n" /* load r00, r01 to q0, q1 */ + + "trn1 v12.4s, v4.4s, v6.4s \n" /* trans q0, q1*/ + "trn2 v13.4s, v4.4s, v6.4s \n" /* trans q0, q1*/ + "trn1 v14.4s, v5.4s, v7.4s \n" /* trans q2, q3*/ + "trn2 v15.4s, v5.4s, v7.4s \n" /* trans q2, q3*/ + "ldp q2, q3, [%[ptr_din]], #32 \n" /* load r02, r03 to q2, q3 */ + + "trn1 v16.2d, v8.2d, v12.2d \n" /* trans q8, q10 00 01 02 03*/ + "trn2 v17.2d, v8.2d, v12.2d \n" /* trans q8, q10 20 21 22 23*/ + "trn1 v18.2d, v9.2d, v13.2d \n" /* trans q9, q11 10 11 12 13*/ + "trn2 v19.2d, v9.2d, v13.2d \n" /* trans q9, q11 30 31 32 33*/ + "ldp q4, q5, [%[ptr_din]], #32 \n" /* load r00, r01 to q0, q1 */ + + "trn1 v8.2d, v10.2d, v14.2d \n" /* trans q8, q10 40 41 42 43*/ + "trn2 v9.2d, v10.2d, v14.2d \n" /* trans q8, q10 60 61 62 63*/ + "trn1 v12.2d, v11.2d, v15.2d \n" /* trans q9, q11 50 51 52 53*/ + "trn2 v13.2d, v11.2d, v15.2d \n" /* trans q9, q11 70 71 72 73*/ + "ldp q6, q7, [%[ptr_din]], #32 \n" /* load r02, r03 to q2, q3 */ + + "smax v16.4s, v16.4s, v20.4s \n" /*relu*/ + "smax v17.4s, v17.4s, v20.4s \n" /*relu*/ + "smax v18.4s, v18.4s, v20.4s \n" /*relu*/ + "smax v19.4s, v19.4s, v20.4s \n" /*relu*/ + + "smax v8.4s, v8.4s, v20.4s \n" /*relu*/ + "smax v9.4s, v9.4s, v20.4s \n" /*relu*/ + "smax v12.4s, v12.4s, v20.4s \n" /*relu*/ + "smax v13.4s, v13.4s, v20.4s \n" /*relu*/ + + // int32->fp32 + "scvtf v10.4s, v16.4s \n" + "scvtf v11.4s, v17.4s \n" + "scvtf v14.4s, v18.4s \n" + "scvtf v15.4s, v19.4s \n" + // mul + "fmul v16.4s, v10.4s, %[scale0].s[0] \n" + "fmul v17.4s, v11.4s, %[scale0].s[2] \n" + "fmul v18.4s, v14.4s, %[scale0].s[1] \n" + "fmul v19.4s, v15.4s, %[scale0].s[3] \n" + + "scvtf v10.4s, v8.4s \n" + "scvtf v11.4s, v9.4s \n" + "scvtf v14.4s, v12.4s \n" + "scvtf v15.4s, v13.4s \n" + + "str q16, [%[doutc0r0]], #16 \n" 
/* store c0r0*/ + "str q17, [%[doutc2r0]], #16 \n" /* store c2r0*/ + "str q18, [%[doutc1r0]], #16 \n" /* store c1r0*/ + "str q19, [%[doutc3r0]], #16 \n" /* store c3r0*/ + + // mul + "fmul v8.4s, v10.4s, %[scale1].s[0] \n" + "fmul v9.4s, v11.4s, %[scale1].s[2] \n" + "fmul v12.4s, v14.4s, %[scale1].s[1] \n" + "fmul v13.4s, v15.4s, %[scale1].s[3] \n" + + "subs %w[cnt], %w[cnt], #1 \n" /* loop count -1*/ + "str q8, [%[doutc4r0]], #16 \n" /* store c0r0*/ + "str q9, [%[doutc6r0]], #16 \n" /* store c2r0*/ + "str q12, [%[doutc5r0]], #16 \n" /* store c1r0*/ + "str q13, [%[doutc7r0]], #16 \n" /* store c3r0*/ + + "bne 1b \n" /* jump to main loop*/ + + : [doutc0r0] "+r"(doutc0_ptr), + [doutc1r0] "+r"(doutc1_ptr), + [doutc2r0] "+r"(doutc2_ptr), + [doutc3r0] "+r"(doutc3_ptr), + [doutc4r0] "+r"(doutc4_ptr), + [doutc5r0] "+r"(doutc5_ptr), + [doutc6r0] "+r"(doutc6_ptr), + [doutc7r0] "+r"(doutc7_ptr), + [cnt] "+r"(cnt_loop), + [ptr_din] "+r"(din_hei_ptr) + : [scale0] "w"(w_scale0), [scale1] "w"(w_scale1) + : "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15", + "v16", + "v17", + "v18", + "v19", + "v20"); +#else + asm volatile( + "vld1.32 {d0-d3}, [%[ptr_din]]! @load data \n" + "vld1.32 {d4-d7}, [%[ptr_din]]! @load data \n" + "vld1.32 {d8-d11}, [%[ptr_din]]! @load data \n" + "vld1.32 {d12-d15}, [%[ptr_din]]! @load data \n" + "vmov.s32 q15, #0 @ dump zero\n" + "1: @ main loop\n" + "vmax.s32 q0, q0, q15 @ relu\n" + "vmax.s32 q1, q1, q15 @ relu\n" + "vmax.s32 q2, q2, q15 @ relu\n" + "vmax.s32 q3, q3, q15 @ relu\n" + + "vmax.s32 q4, q4, q15 @ relu\n" + "vmax.s32 q5, q5, q15 @ relu\n" + "vmax.s32 q6, q6, q15 @ relu\n" + "vmax.s32 q7, q7, q15 @ relu\n" + + // int32-> fp32 + "vcvt.f32.s32 q8, q0 \n" + "vcvt.f32.s32 q9, q1 \n" + "vcvt.f32.s32 q10, q2 \n" + "vcvt.f32.s32 q11, q3 \n" + + // mul + "vmul.f32 q0, q8, %q[scale0] \n" + "vmul.f32 q1, q9, %q[scale1] \n" + "vmul.f32 q2, q10, %q[scale0] \n" + "vmul.f32 q3, q11, %q[scale1] \n" + + // int32-> fp32 + "vcvt.f32.s32 q8, q4 \n" + "vcvt.f32.s32 q9, q5 \n" + "vcvt.f32.s32 q10, q6 \n" + "vcvt.f32.s32 q11, q7 \n" + + // mul + "vmul.f32 q4, q8, %q[scale0] \n" + "vmul.f32 q5, q9, %q[scale1] \n" + "vmul.f32 q6, q10, %q[scale0] \n" + "vmul.f32 q7, q11, %q[scale1] \n" + + "vtrn.32 q0, q2 @ trans q0, q2 \n" + "vtrn.32 q4, q6 @ trans q4, q6 \n" + "vswp.32 d1, d8 @ swap d1, d8 \n" + "vswp.32 d5, d12 @ swap d5, d12\n" + + "vtrn.32 q1, q3 @ trans q1, q3 \n" + "vtrn.32 q5, q7 @ trans q5, q7 \n" + "vswp.32 d3, d10 @ swap d3, d10\n" + "vswp.32 d7, d14 @ swap d7, d14\n" + + "vst1.32 {d0-d1}, [%[doutc0r0]]! @ store result, add pointer\n" + "vst1.32 {d4-d5}, [%[doutc1r0]]! @ store result, add pointer\n" + "vst1.32 {d8-d9}, [%[doutc2r0]]! @ store result, add " + "pointer\n" + "vst1.32 {d12-d13}, [%[doutc3r0]]! @ store result, add " + "pointer\n" + + "vst1.32 {d2-d3}, [%[doutc4r0]]! @ store result, add pointer\n" + "vst1.32 {d6-d7}, [%[doutc5r0]]! @ store result, add pointer\n" + "vst1.32 {d10-d11}, [%[doutc6r0]]! @ store result, add " + "pointer\n" + "vst1.32 {d14-d15}, [%[doutc7r0]]! @ store result, add " + "pointer\n" + + "subs %[cnt], %[cnt], #1 @ loop count - 1\n" + + "vld1.32 {d0-d3}, [%[ptr_din]]! @load data \n" + "vld1.32 {d4-d7}, [%[ptr_din]]! @load data \n" + "vld1.32 {d8-d11}, [%[ptr_din]]! @load data \n" + "vld1.32 {d12-d15}, [%[ptr_din]]! 
@load data \n" + + "bne 1b @ jump to main loop\n" + + : [doutc0r0] "+r"(doutc0_ptr), + [doutc1r0] "+r"(doutc1_ptr), + [doutc2r0] "+r"(doutc2_ptr), + [doutc3r0] "+r"(doutc3_ptr), + [doutc4r0] "+r"(doutc4_ptr), + [doutc5r0] "+r"(doutc5_ptr), + [doutc6r0] "+r"(doutc6_ptr), + [doutc7r0] "+r"(doutc7_ptr), + [ptr_din] "+r"(din_hei_ptr), + [cnt] "+r"(cnt_loop) + : [scale0] "w"(w_scale0), [scale1] "w"(w_scale1) + : "q0", + "q1", + "q2", + "q3", + "q4", + "q5", + "q6", + "q7", + "q8", + "q9", + "q10", + "q11", + "q15"); +#endif + } + if (we > width) { + int offset = 32 * (valid_w / 4 - 1); + din_hei_ptr = ptr_din + offset; + int i = we - 4; + for (; i < width; ++i) { + *(doutc0_ptr++) = LITEMAX(din_hei_ptr[0] * scale[0], 0); + *(doutc1_ptr++) = LITEMAX(din_hei_ptr[1] * scale[1], 0); + *(doutc2_ptr++) = LITEMAX(din_hei_ptr[2] * scale[2], 0); + *(doutc3_ptr++) = LITEMAX(din_hei_ptr[3] * scale[3], 0); + *(doutc4_ptr++) = LITEMAX(din_hei_ptr[4] * scale[4], 0); + *(doutc5_ptr++) = LITEMAX(din_hei_ptr[5] * scale[5], 0); + *(doutc6_ptr++) = LITEMAX(din_hei_ptr[6] * scale[6], 0); + *(doutc7_ptr++) = LITEMAX(din_hei_ptr[7] * scale[7], 0); + din_hei_ptr += 8; + } + } + } + } else { + for (int i = 0; i < size_h; i++) { + int size_w = i * width; + dtype* doutc0_ptr = doutc0r0 + size_w; // doutc0r0 + width; + dtype* doutc1_ptr = doutc1r0 + size_w; + dtype* doutc2_ptr = doutc2r0 + size_w; + dtype* doutc3_ptr = doutc3r0 + size_w; + dtype* doutc4_ptr = doutc4r0 + size_w; + dtype* doutc5_ptr = doutc5r0 + size_w; + dtype* doutc6_ptr = doutc6r0 + size_w; + dtype* doutc7_ptr = doutc7r0 + size_w; + if (ce > channel) { + switch (ce - channel) { + case 7: + doutc1_ptr = trash_ptr; + case 6: + doutc2_ptr = trash_ptr; + case 5: + doutc3_ptr = trash_ptr; + case 4: + doutc4_ptr = trash_ptr; + case 3: + doutc5_ptr = trash_ptr; + case 2: + doutc6_ptr = trash_ptr; + case 1: + doutc7_ptr = trash_ptr; + default: + break; + } + } + ptr_din = din + i * valid_w * ch_n; + const int* din_hei_ptr = ptr_din; + if (cnt > 0) { + int cnt_loop = cnt; +#ifdef __aarch64__ + asm volatile( + "ldp q0, q1, [%[ptr_din]], #32 \n" /* load r00, r01 to q0, q1 */ + "ldp q2, q3, [%[ptr_din]], #32 \n" /* load r02, r03 to q2, q3 */ + "ldp q4, q5, [%[ptr_din]], #32 \n" /* load r00, r01 to q0, q1 */ + "ldp q6, q7, [%[ptr_din]], #32 \n" /* load r02, r03 to q2, q3 */ + "movi v20.4s, #0 \n" /* for relu */ + "1: \n" /* main loop*/ + "trn1 v8.4s, v0.4s, v2.4s \n" /* trans q0, q1*/ + "trn2 v9.4s, v0.4s, v2.4s \n" /* trans q0, q1*/ + "trn1 v10.4s, v1.4s, v3.4s \n" /* trans q2, q3*/ + "trn2 v11.4s, v1.4s, v3.4s \n" /* trans q2, q3*/ + "ldp q0, q1, [%[ptr_din]], #32 \n" /* load r00, r01 to q0, q1 */ + + "trn1 v12.4s, v4.4s, v6.4s \n" /* trans q0, q1*/ + "trn2 v13.4s, v4.4s, v6.4s \n" /* trans q0, q1*/ + "trn1 v14.4s, v5.4s, v7.4s \n" /* trans q2, q3*/ + "trn2 v15.4s, v5.4s, v7.4s \n" /* trans q2, q3*/ + "ldp q2, q3, [%[ptr_din]], #32 \n" /* load r02, r03 to q2, q3 */ + + "trn1 v16.2d, v8.2d, v12.2d \n" /* trans q8, q10 00 01 02 03*/ + "trn2 v17.2d, v8.2d, v12.2d \n" /* trans q8, q10 20 21 22 23*/ + "trn1 v18.2d, v9.2d, v13.2d \n" /* trans q9, q11 10 11 12 13*/ + "trn2 v19.2d, v9.2d, v13.2d \n" /* trans q9, q11 30 31 32 33*/ + "ldp q4, q5, [%[ptr_din]], #32 \n" /* load r00, r01 to q0, q1 */ + + "trn1 v8.2d, v10.2d, v14.2d \n" /* trans q8, q10 40 41 42 43*/ + "trn2 v9.2d, v10.2d, v14.2d \n" /* trans q8, q10 60 61 62 63*/ + "trn1 v12.2d, v11.2d, v15.2d \n" /* trans q9, q11 50 51 52 53*/ + "trn2 v13.2d, v11.2d, v15.2d \n" /* trans q9, q11 70 71 72 73*/ + "ldp q6, q7, 
[%[ptr_din]], #32 \n" /* load r02, r03 to q2, q3 */ + + // int32->fp32 + "scvtf v10.4s, v16.4s \n" + "scvtf v11.4s, v17.4s \n" + "scvtf v14.4s, v18.4s \n" + "scvtf v15.4s, v19.4s \n" + // mul + "fmul v16.4s, v10.4s, %[scale0].s[0] \n" + "fmul v17.4s, v11.4s, %[scale0].s[2] \n" + "fmul v18.4s, v14.4s, %[scale0].s[1] \n" + "fmul v19.4s, v15.4s, %[scale0].s[3] \n" + + "scvtf v10.4s, v8.4s \n" + "scvtf v11.4s, v9.4s \n" + "scvtf v14.4s, v12.4s \n" + "scvtf v15.4s, v13.4s \n" + + "str q16, [%[doutc0r0]], #16 \n" /* store c0r0*/ + "str q17, [%[doutc2r0]], #16 \n" /* store c2r0*/ + "str q18, [%[doutc1r0]], #16 \n" /* store c1r0*/ + "str q19, [%[doutc3r0]], #16 \n" /* store c3r0*/ + + // mul + "fmul v8.4s, v10.4s, %[scale1].s[0] \n" + "fmul v9.4s, v11.4s, %[scale1].s[2] \n" + "fmul v12.4s, v14.4s, %[scale1].s[1] \n" + "fmul v13.4s, v15.4s, %[scale1].s[3] \n" + + "subs %w[cnt], %w[cnt], #1 \n" /* loop count -1*/ + "str q8, [%[doutc4r0]], #16 \n" /* store c0r0*/ + "str q9, [%[doutc6r0]], #16 \n" /* store c2r0*/ + "str q12, [%[doutc5r0]], #16 \n" /* store c1r0*/ + "str q13, [%[doutc7r0]], #16 \n" /* store c3r0*/ + + "bne 1b \n" /* jump to main loop*/ + + : [doutc0r0] "+r"(doutc0_ptr), + [doutc1r0] "+r"(doutc1_ptr), + [doutc2r0] "+r"(doutc2_ptr), + [doutc3r0] "+r"(doutc3_ptr), + [doutc4r0] "+r"(doutc4_ptr), + [doutc5r0] "+r"(doutc5_ptr), + [doutc6r0] "+r"(doutc6_ptr), + [doutc7r0] "+r"(doutc7_ptr), + [cnt] "+r"(cnt_loop), + [ptr_din] "+r"(din_hei_ptr) + : [scale0] "w"(w_scale0), [scale1] "w"(w_scale1) + : "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15", + "v16", + "v17", + "v18", + "v19", + "v20"); +#else + asm volatile( + "vld1.32 {d0-d3}, [%[ptr_din]]! @load data \n" + "vld1.32 {d4-d7}, [%[ptr_din]]! @load data \n" + "vld1.32 {d8-d11}, [%[ptr_din]]! @load data \n" + "vld1.32 {d12-d15}, [%[ptr_din]]! @load data \n" + "vmov.s32 q15, #0 @ dump zero\n" + "1: @ main loop\n" + // int32-> fp32 + "vcvt.f32.s32 q8, q0 \n" + "vcvt.f32.s32 q9, q1 \n" + "vcvt.f32.s32 q10, q2 \n" + "vcvt.f32.s32 q11, q3 \n" + + // mul + "vmul.f32 q0, q8, %q[scale0] \n" + "vmul.f32 q1, q9, %q[scale1] \n" + "vmul.f32 q2, q10, %q[scale0] \n" + "vmul.f32 q3, q11, %q[scale1] \n" + + // int32-> fp32 + "vcvt.f32.s32 q8, q4 \n" + "vcvt.f32.s32 q9, q5 \n" + "vcvt.f32.s32 q10, q6 \n" + "vcvt.f32.s32 q11, q7 \n" + + // mul + "vmul.f32 q4, q8, %q[scale0] \n" + "vmul.f32 q5, q9, %q[scale1] \n" + "vmul.f32 q6, q10, %q[scale0] \n" + "vmul.f32 q7, q11, %q[scale1] \n" + + "vtrn.32 q0, q2 @ trans q0, q2 \n" + "vtrn.32 q4, q6 @ trans q4, q6 \n" + "vswp.32 d1, d8 @ swap d1, d8 \n" + "vswp.32 d5, d12 @ swap d5, d12\n" + + "vtrn.32 q1, q3 @ trans q1, q3 \n" + "vtrn.32 q5, q7 @ trans q5, q7 \n" + "vswp.32 d3, d10 @ swap d3, d10\n" + "vswp.32 d7, d14 @ swap d7, d14\n" + + "vst1.32 {d0-d1}, [%[doutc0r0]]! @ store result, add pointer\n" + "vst1.32 {d4-d5}, [%[doutc1r0]]! @ store result, add pointer\n" + "vst1.32 {d8-d9}, [%[doutc2r0]]! @ store result, add " + "pointer\n" + "vst1.32 {d12-d13}, [%[doutc3r0]]! @ store result, add " + "pointer\n" + + "vst1.32 {d2-d3}, [%[doutc4r0]]! @ store result, add pointer\n" + "vst1.32 {d6-d7}, [%[doutc5r0]]! @ store result, add pointer\n" + "vst1.32 {d10-d11}, [%[doutc6r0]]! @ store result, add " + "pointer\n" + "vst1.32 {d14-d15}, [%[doutc7r0]]! @ store result, add " + "pointer\n" + + "subs %[cnt], %[cnt], #1 @ loop count - 1\n" + + "vld1.32 {d0-d3}, [%[ptr_din]]! @load data \n" + "vld1.32 {d4-d7}, [%[ptr_din]]! 
@load data \n" + "vld1.32 {d8-d11}, [%[ptr_din]]! @load data \n" + "vld1.32 {d12-d15}, [%[ptr_din]]! @load data \n" + + "bne 1b @ jump to main loop\n" + + : [doutc0r0] "+r"(doutc0_ptr), + [doutc1r0] "+r"(doutc1_ptr), + [doutc2r0] "+r"(doutc2_ptr), + [doutc3r0] "+r"(doutc3_ptr), + [doutc4r0] "+r"(doutc4_ptr), + [doutc5r0] "+r"(doutc5_ptr), + [doutc6r0] "+r"(doutc6_ptr), + [doutc7r0] "+r"(doutc7_ptr), + [ptr_din] "+r"(din_hei_ptr), + [cnt] "+r"(cnt_loop) + : [scale0] "w"(w_scale0), [scale1] "w"(w_scale1) + : "q0", + "q1", + "q2", + "q3", + "q4", + "q5", + "q6", + "q7", + "q8", + "q9", + "q10", + "q11", + "q15"); +#endif + } + if (we > width) { + int offset = 32 * (valid_w / 4 - 1); + din_hei_ptr = ptr_din + offset; + int i = we - 4; + for (; i < width; ++i) { + *(doutc0_ptr++) = din_hei_ptr[0] * scale[0]; + *(doutc1_ptr++) = din_hei_ptr[1] * scale[1]; + *(doutc2_ptr++) = din_hei_ptr[2] * scale[2]; + *(doutc3_ptr++) = din_hei_ptr[3] * scale[3]; + *(doutc4_ptr++) = din_hei_ptr[4] * scale[4]; + *(doutc5_ptr++) = din_hei_ptr[5] * scale[5]; + *(doutc6_ptr++) = din_hei_ptr[6] * scale[6]; + *(doutc7_ptr++) = din_hei_ptr[7] * scale[7]; + din_hei_ptr += 8; + } + } + } + } + } else if (out_dtype == PRECISION(kInt8)) { + // int32_to_int8 + float32x4_t vpoff = vdupq_n_f32(0.5f); + float32x4_t vnoff = vdupq_n_f32(-0.5f); + if (flag_relu) { + for (int i = 0; i < size_h; i++) { + int size_w = i * width; + dtype* doutc0_ptr = doutc0r0 + size_w; // doutc0r0 + width; + dtype* doutc1_ptr = doutc1r0 + size_w; + dtype* doutc2_ptr = doutc2r0 + size_w; + dtype* doutc3_ptr = doutc3r0 + size_w; + dtype* doutc4_ptr = doutc4r0 + size_w; + dtype* doutc5_ptr = doutc5r0 + size_w; + dtype* doutc6_ptr = doutc6r0 + size_w; + dtype* doutc7_ptr = doutc7r0 + size_w; + if (ce > channel) { + switch (ce - channel) { + case 7: + doutc1_ptr = trash_ptr; + case 6: + doutc2_ptr = trash_ptr; + case 5: + doutc3_ptr = trash_ptr; + case 4: + doutc4_ptr = trash_ptr; + case 3: + doutc5_ptr = trash_ptr; + case 2: + doutc6_ptr = trash_ptr; + case 1: + doutc7_ptr = trash_ptr; + default: + break; + } + } + ptr_din = din + i * valid_w * ch_n; + const int* din_hei_ptr = ptr_din; + if (cnt > 0) { + int cnt_loop = cnt; +#ifdef __aarch64__ + asm volatile( + "ldp q0, q1, [%[ptr_din]], #32 \n" /* load r00, r01 to q0, q1 */ + "ldp q2, q3, [%[ptr_din]], #32 \n" /* load r02, r03 to q2, q3 */ + "ldp q4, q5, [%[ptr_din]], #32 \n" /* load r00, r01 to q0, q1 */ + "ldp q6, q7, [%[ptr_din]], #32 \n" /* load r02, r03 to q2, q3 */ + // "movi v20.4s, #0 \n" /* for relu */ + "1: \n" /* main loop*/ + "trn1 v8.4s, v0.4s, v2.4s \n" /* trans q0, q1*/ + "trn2 v9.4s, v0.4s, v2.4s \n" /* trans q0, q1*/ + "trn1 v10.4s, v1.4s, v3.4s \n" /* trans q2, q3*/ + "trn2 v11.4s, v1.4s, v3.4s \n" /* trans q2, q3*/ + "ldp q0, q1, [%[ptr_din]], #32 \n" /* load r00, r01 to q0, q1 */ + + "trn1 v12.4s, v4.4s, v6.4s \n" /* trans q0, q1*/ + "trn2 v13.4s, v4.4s, v6.4s \n" /* trans q0, q1*/ + "trn1 v14.4s, v5.4s, v7.4s \n" /* trans q2, q3*/ + "trn2 v15.4s, v5.4s, v7.4s \n" /* trans q2, q3*/ + "ldp q2, q3, [%[ptr_din]], #32 \n" /* load r02, r03 to q2, q3 */ + + "trn1 v16.2d, v8.2d, v12.2d \n" /* trans q8, q10 00 01 02 03*/ + "trn2 v17.2d, v8.2d, v12.2d \n" /* trans q8, q10 20 21 22 23*/ + "trn1 v18.2d, v9.2d, v13.2d \n" /* trans q9, q11 10 11 12 13*/ + "trn2 v19.2d, v9.2d, v13.2d \n" /* trans q9, q11 30 31 32 33*/ + "ldp q4, q5, [%[ptr_din]], #32 \n" /* load r00, r01 to q0, q1 */ + + "trn1 v8.2d, v10.2d, v14.2d \n" /* trans q8, q10 40 41 42 43*/ + "trn2 v9.2d, v10.2d, v14.2d \n" /* trans q8, 
q10 60 61 62 63*/ + "trn1 v12.2d, v11.2d, v15.2d \n" /* trans q9, q11 50 51 52 53*/ + "trn2 v13.2d, v11.2d, v15.2d \n" /* trans q9, q11 70 71 72 73*/ + "ldp q6, q7, [%[ptr_din]], #32 \n" /* load r02, r03 to q2, q3 */ + + "smax v16.4s, v16.4s, %[vzero].4s \n" /*relu*/ + "smax v17.4s, v17.4s, %[vzero].4s \n" /*relu*/ + "smax v18.4s, v18.4s, %[vzero].4s \n" /*relu*/ + "smax v19.4s, v19.4s, %[vzero].4s \n" /*relu*/ + + "smax v8.4s, v8.4s, %[vzero].4s \n" /*relu*/ + "smax v9.4s, v9.4s, %[vzero].4s \n" /*relu*/ + "smax v12.4s, v12.4s, %[vzero].4s \n" /*relu*/ + "smax v13.4s, v13.4s, %[vzero].4s \n" /*relu*/ + + // int32 --> fp32 + "scvtf v10.4s, v16.4s \n" + "scvtf v11.4s, v17.4s \n" + "scvtf v14.4s, v18.4s \n" + "scvtf v15.4s, v19.4s \n" + + "scvtf v20.4s, v8.4s \n" + "scvtf v21.4s, v9.4s \n" + "scvtf v22.4s, v12.4s \n" + "scvtf v23.4s, v13.4s \n" + + // mul + "fmul v16.4s, v10.4s, %[scale0].s[0] \n" + "fmul v17.4s, v11.4s, %[scale0].s[2] \n" + "fmul v18.4s, v14.4s, %[scale0].s[1] \n" + "fmul v19.4s, v15.4s, %[scale0].s[3] \n" + + "fmul v8.4s, v20.4s, %[scale1].s[0] \n" + "fmul v9.4s, v21.4s, %[scale1].s[2] \n" + "fmul v12.4s, v22.4s, %[scale1].s[1] \n" + "fmul v13.4s, v23.4s, %[scale1].s[3] \n" + + // fp32-int32 + "fcvtas v10.4s, v16.4s \n" + "fcvtas v11.4s, v17.4s \n" + "fcvtas v14.4s, v18.4s \n" + "fcvtas v15.4s, v19.4s \n" + + "fcvtas v20.4s, v8.4s \n" + "fcvtas v21.4s, v9.4s \n" + "fcvtas v22.4s, v12.4s \n" + "fcvtas v23.4s, v13.4s \n" + + // int32-int16 + "sqxtn v16.4h, v10.4s \n" + "sqxtn v17.4h, v11.4s \n" + "sqxtn v18.4h, v14.4s \n" + "sqxtn v19.4h, v15.4s \n" + + "sqxtn v8.4h, v20.4s \n" + "sqxtn v9.4h, v21.4s \n" + "sqxtn v12.4h, v22.4s \n" + "sqxtn v13.4h, v23.4s \n" + + // int16-int8 + "sqxtn v10.8b, v16.8h \n" + "sqxtn v11.8b, v17.8h \n" + "sqxtn v14.8b, v18.8h \n" + "sqxtn v15.8b, v19.8h \n" + + "sqxtn v20.8b, v8.8h \n" + "sqxtn v21.8b, v9.8h \n" + "sqxtn v22.8b, v12.8h \n" + "sqxtn v23.8b, v13.8h \n" + + "str s10, [%[doutc0r0]], #4 \n" /* store c0r0*/ + "str s11, [%[doutc2r0]], #4 \n" /* store c2r0*/ + "str s14, [%[doutc1r0]], #4 \n" /* store c1r0*/ + "str s15, [%[doutc3r0]], #4 \n" /* store c3r0*/ + + "subs %w[cnt], %w[cnt], #1 \n" /* loop count -1*/ + "str s20, [%[doutc4r0]], #4 \n" /* store c0r0*/ + "str s21, [%[doutc6r0]], #4 \n" /* store c2r0*/ + "str s22, [%[doutc5r0]], #4 \n" /* store c1r0*/ + "str s23, [%[doutc7r0]], #4 \n" /* store c3r0*/ + + "bne 1b \n" /* jump to main loop*/ + : [doutc0r0] "+r"(doutc0_ptr), + [doutc1r0] "+r"(doutc1_ptr), + [doutc2r0] "+r"(doutc2_ptr), + [doutc3r0] "+r"(doutc3_ptr), + [doutc4r0] "+r"(doutc4_ptr), + [doutc5r0] "+r"(doutc5_ptr), + [doutc6r0] "+r"(doutc6_ptr), + [doutc7r0] "+r"(doutc7_ptr), + [cnt] "+r"(cnt_loop), + [ptr_din] "+r"(din_hei_ptr) + : + [scale0] "w"(w_scale0), [scale1] "w"(w_scale1), [vzero] "w"(vzero) + : "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15", + "v16", + "v17", + "v18", + "v19", + "v20", + "v21", + "v22", + "v23"); +#else + asm volatile( + "vld1.32 {d8-d11}, [%[ptr_din]]! @load data \n" + "vld1.32 {d12-d15}, [%[ptr_din]]! 
@load data \n" + + "1: @ main loop\n" + "vmax.s32 q4, q4, %q[vzero] @ relu\n" + "vmax.s32 q5, q5, %q[vzero] @ relu\n" + "vmax.s32 q6, q6, %q[vzero] @ relu\n" + "vmax.s32 q7, q7, %q[vzero] @ relu\n" + + // int32-> fp32 + "vmov.f32 q15, #0.5 \n" + "vcvt.f32.s32 q8, q4 \n" + "vcvt.f32.s32 q9, q5 \n" + "vcvt.f32.s32 q10, q6 \n" + "vcvt.f32.s32 q11, q7 \n" + + "vand.i32 q4, q15, q15 @ set offset, 0.5\n" + "vand.i32 q5, q15, q15 @ set offset, 0.5\n" + "vand.i32 q6, q15, q15 @ set offset, 0.5\n" + "vand.i32 q7, q15, q15 @ set offset, 0.5\n" + + "vmov.f32 q15, #-0.5 \n" + + "vcgt.f32 q12, q8, %q[vzero] @ get mask > 0, in0\n" + "vcgt.f32 q13, q9, %q[vzero] @ get mask > 0, in0\n" + "vcgt.f32 q14, q10, %q[vzero] @ get mask > 0, in0\n" + "vcgt.f32 q3, q11, %q[vzero] @ get mask > 0, in0\n" + + "vbif.f32 q4, q15, q12 @ get right offset\n" + "vbif.f32 q5, q15, q13 @ get right offset\n" + "vbif.f32 q6, q15, q14 @ get right offset\n" + "vbif.f32 q7, q15, q3 @ get right offset\n" + + "vld1.32 {d24-d27}, [%[ptr_din]]! @load data \n" + "vld1.32 {d28-d29}, [%[ptr_din]]! @load data \n" + "vld1.32 {d6-d7}, [%[ptr_din]]! @load data \n" + + "vmla.f32 q4, q8, %q[scale0] @ mul scale\n" + "vmla.f32 q5, q9, %q[scale1] @ mul scale\n" + "vmla.f32 q6, q10, %q[scale0] @ mul scale\n" + "vmla.f32 q7, q11, %q[scale1] @ mul scale\n" + + "vmax.s32 q12, q12, %q[vzero] @ relu\n" + "vmax.s32 q13, q13, %q[vzero] @ relu\n" + "vmax.s32 q14, q14, %q[vzero] @ relu\n" + "vmax.s32 q3, q3, %q[vzero] @ relu\n" + + "vcvt.s32.f32 q8, q4 @ cvt to int32\n" + "vcvt.s32.f32 q9, q5 @ cvt to int32\n" + "vcvt.s32.f32 q10, q6 @ cvt to int32\n" + "vcvt.s32.f32 q11, q7 @ cvt to int32\n" + + "vqmovn.s32 d8, q8 @ cnt to int16\n" + "vqmovn.s32 d10, q9 @ cnt to int16\n" + "vqmovn.s32 d12, q10 @ cnt to int16\n" + "vqmovn.s32 d14, q11 @ cnt to int16\n" + + "vqmovn.s16 d16, q4 @ cnt to int8\n" + "vqmovn.s16 d17, q5 @ cnt to int8\n" + "vqmovn.s16 d18, q6 @ cnt to int8\n" + "vqmovn.s16 d19, q7 @ cnt to int8\n" + + "vmov.f32 q15, #0.5 \n" + + "vcvt.f32.s32 q4, q12 \n" + "vcvt.f32.s32 q5, q13 \n" + "vcvt.f32.s32 q6, q14 \n" + "vcvt.f32.s32 q7, q3 \n" + + "vand.i32 q12, q15, q15 @ set offset, 0.5\n" + "vand.i32 q13, q15, q15 @ set offset, 0.5\n" + "vand.i32 q14, q15, q15 @ set offset, 0.5\n" + "vand.i32 q3, q15, q15 @ set offset, 0.5\n" + + "vmov.f32 q15, #-0.5 \n" + + "vcgt.f32 q10, q4, %q[vzero] @ get mask > 0, in0\n" + "vcgt.f32 q11, q5, %q[vzero] @ get mask > 0, in0\n" + + "vbif.f32 q12, q15, q10 @ get right offset\n" + "vbif.f32 q13, q15, q11 @ get right offset\n" + + "vcgt.f32 q10, q6, %q[vzero] @ get mask > 0, in0\n" + "vcgt.f32 q11, q7, %q[vzero] @ get mask > 0, in0\n" + + "vbif.f32 q14, q15, q10 @ get right offset\n" + "vbif.f32 q3, q15, q11 @ get right offset\n" + + "vmla.f32 q12, q4, %q[scale0] @ mul scale\n" + "vmla.f32 q13, q5, %q[scale1] @ mul scale\n" + "vmla.f32 q14, q6, %q[scale0] @ mul scale\n" + "vmla.f32 q3, q7, %q[scale1] @ mul scale\n" + + "vcvt.s32.f32 q4, q12 @ cvt to int32\n" + "vcvt.s32.f32 q5, q13 @ cvt to int32\n" + "vcvt.s32.f32 q6, q14 @ cvt to int32\n" + "vcvt.s32.f32 q7, q3 @ cvt to int32\n" + + "vqmovn.s32 d24, q4 @ cnt to int16\n" + "vqmovn.s32 d26, q5 @ cnt to int16\n" + "vqmovn.s32 d28, q6 @ cnt to int16\n" + "vqmovn.s32 d6, q7 @ cnt to int16\n" + + "vqmovn.s16 d20, q12 @ cnt to int8\n" + "vqmovn.s16 d21, q13 @ cnt to int8\n" + "vqmovn.s16 d22, q14 @ cnt to int8\n" + "vqmovn.s16 d23, q3 @ cnt to int8\n" + + "vtrn.8 d16, d18 @ trans q0, q2 \n" + "vtrn.8 d20, d22 @ trans q4, q6 \n" + "vtrn.16 d16, d20 @ trans q0, q2 \n" + 
"vtrn.16 d18, d22 @ trans q4, q6 \n" + + "vtrn.8 d17, d19 @ trans q0, q2 \n" + "vtrn.8 d21, d23 @ trans q4, q6 \n" + "vtrn.16 d17, d21 @ trans q0, q2 \n" + "vtrn.16 d19, d23 @ trans q4, q6 \n" + + "vld1.32 {d8-d11}, [%[ptr_din]]! @load data \n" + "vld1.32 {d12-d15}, [%[ptr_din]]! @load data \n" + + "vst1.32 {d16[0]}, [%[doutc0r0]] @ store result, add " + "pointer\n" + "vst1.32 {d18[0]}, [%[doutc1r0]] @ store result, add " + "pointer\n" + "vst1.32 {d20[0]}, [%[doutc2r0]] @ store result, add " + "pointer\n" + "vst1.32 {d22[0]}, [%[doutc3r0]] @ store result, add " + "pointer\n" + + "vst1.32 {d17[0]}, [%[doutc4r0]] @ store result, add " + "pointer\n" + "vst1.32 {d19[0]}, [%[doutc5r0]] @ store result, add " + "pointer\n" + "vst1.32 {d21[0]}, [%[doutc6r0]] @ store result, add " + "pointer\n" + "vst1.32 {d23[0]}, [%[doutc7r0]] @ store result, add " + "pointer\n" + + "add %[doutc0r0], #4 @ add \n" + "add %[doutc1r0], #4 @ add \n" + "add %[doutc2r0], #4 @ add \n" + "add %[doutc3r0], #4 @ add \n" + + "subs %[cnt], %[cnt], #1 @ loop count - 1\n" + + "add %[doutc4r0], #4 @ add \n" + "add %[doutc5r0], #4 @ add \n" + "add %[doutc6r0], #4 @ add \n" + "add %[doutc7r0], #4 @ add \n" + "bne 1b @ jump to main loop\n" + + : [doutc0r0] "+r"(doutc0_ptr), + [doutc1r0] "+r"(doutc1_ptr), + [doutc2r0] "+r"(doutc2_ptr), + [doutc3r0] "+r"(doutc3_ptr), + [doutc4r0] "+r"(doutc4_ptr), + [doutc5r0] "+r"(doutc5_ptr), + [doutc6r0] "+r"(doutc6_ptr), + [doutc7r0] "+r"(doutc7_ptr), + [ptr_din] "+r"(din_hei_ptr), + [cnt] "+r"(cnt_loop) + : + [scale0] "w"(w_scale0), [scale1] "w"(w_scale1), [vzero] "w"(vzero) + : "q3", + "q4", + "q5", + "q6", + "q7", + "q8", + "q9", + "q10", + "q11", + "q12", + "q13", + "q14", + "q15"); +#endif + } + if (we > width) { + int offset = 32 * (valid_w / 4 - 1); + din_hei_ptr = ptr_din + offset; + int i = we - 4; + for (; i < width; ++i) { + *(doutc0_ptr++) = saturate_cast( + roundf(LITEMAX(din_hei_ptr[0] * scale[0], 0))); + *(doutc1_ptr++) = saturate_cast( + roundf(LITEMAX(din_hei_ptr[1] * scale[1], 0))); + *(doutc2_ptr++) = saturate_cast( + roundf(LITEMAX(din_hei_ptr[2] * scale[2], 0))); + *(doutc3_ptr++) = saturate_cast( + roundf(LITEMAX(din_hei_ptr[3] * scale[3], 0))); + *(doutc4_ptr++) = saturate_cast( + roundf(LITEMAX(din_hei_ptr[4] * scale[4], 0))); + *(doutc5_ptr++) = saturate_cast( + roundf(LITEMAX(din_hei_ptr[5] * scale[5], 0))); + *(doutc6_ptr++) = saturate_cast( + roundf(LITEMAX(din_hei_ptr[6] * scale[6], 0))); + *(doutc7_ptr++) = saturate_cast( + roundf(LITEMAX(din_hei_ptr[7] * scale[7], 0))); + din_hei_ptr += 8; + } + } + } + } else { + for (int i = 0; i < size_h; i++) { + int size_w = i * width; + dtype* doutc0_ptr = doutc0r0 + size_w; // doutc0r0 + width; + dtype* doutc1_ptr = doutc1r0 + size_w; + dtype* doutc2_ptr = doutc2r0 + size_w; + dtype* doutc3_ptr = doutc3r0 + size_w; + dtype* doutc4_ptr = doutc4r0 + size_w; + dtype* doutc5_ptr = doutc5r0 + size_w; + dtype* doutc6_ptr = doutc6r0 + size_w; + dtype* doutc7_ptr = doutc7r0 + size_w; + if (ce > channel) { + switch (ce - channel) { + case 7: + doutc1_ptr = trash_ptr; + case 6: + doutc2_ptr = trash_ptr; + case 5: + doutc3_ptr = trash_ptr; + case 4: + doutc4_ptr = trash_ptr; + case 3: + doutc5_ptr = trash_ptr; + case 2: + doutc6_ptr = trash_ptr; + case 1: + doutc7_ptr = trash_ptr; + default: + break; + } + } + ptr_din = din + i * valid_w * ch_n; + const int* din_hei_ptr = ptr_din; + if (cnt > 0) { + int cnt_loop = cnt; +#ifdef __aarch64__ + asm volatile( + "ldp q0, q1, [%[ptr_din]], #32 \n" /* load r00, r01 to q0, q1 */ + "ldp q2, q3, 
[%[ptr_din]], #32 \n" /* load r02, r03 to q2, q3 */ + "ldp q4, q5, [%[ptr_din]], #32 \n" /* load r00, r01 to q0, q1 */ + "ldp q6, q7, [%[ptr_din]], #32 \n" /* load r02, r03 to q2, q3 */ + // "movi v20.4s, #0 \n" /* for relu */ + "1: \n" /* main loop*/ + "trn1 v8.4s, v0.4s, v2.4s \n" /* trans q0, q1*/ + "trn2 v9.4s, v0.4s, v2.4s \n" /* trans q0, q1*/ + "trn1 v10.4s, v1.4s, v3.4s \n" /* trans q2, q3*/ + "trn2 v11.4s, v1.4s, v3.4s \n" /* trans q2, q3*/ + "ldp q0, q1, [%[ptr_din]], #32 \n" /* load r00, r01 to q0, q1 */ + + "trn1 v12.4s, v4.4s, v6.4s \n" /* trans q0, q1*/ + "trn2 v13.4s, v4.4s, v6.4s \n" /* trans q0, q1*/ + "trn1 v14.4s, v5.4s, v7.4s \n" /* trans q2, q3*/ + "trn2 v15.4s, v5.4s, v7.4s \n" /* trans q2, q3*/ + "ldp q2, q3, [%[ptr_din]], #32 \n" /* load r02, r03 to q2, q3 */ + + "trn1 v16.2d, v8.2d, v12.2d \n" /* trans q8, q10 00 01 02 03*/ + "trn2 v17.2d, v8.2d, v12.2d \n" /* trans q8, q10 20 21 22 23*/ + "trn1 v18.2d, v9.2d, v13.2d \n" /* trans q9, q11 10 11 12 13*/ + "trn2 v19.2d, v9.2d, v13.2d \n" /* trans q9, q11 30 31 32 33*/ + "ldp q4, q5, [%[ptr_din]], #32 \n" /* load r00, r01 to q0, q1 */ + + "trn1 v8.2d, v10.2d, v14.2d \n" /* trans q8, q10 40 41 42 43*/ + "trn2 v9.2d, v10.2d, v14.2d \n" /* trans q8, q10 60 61 62 63*/ + "trn1 v12.2d, v11.2d, v15.2d \n" /* trans q9, q11 50 51 52 53*/ + "trn2 v13.2d, v11.2d, v15.2d \n" /* trans q9, q11 70 71 72 73*/ + "ldp q6, q7, [%[ptr_din]], #32 \n" /* load r02, r03 to q2, q3 */ + + // int32 --> fp32 + "scvtf v10.4s, v16.4s \n" + "scvtf v11.4s, v17.4s \n" + "scvtf v14.4s, v18.4s \n" + "scvtf v15.4s, v19.4s \n" + + "scvtf v20.4s, v8.4s \n" + "scvtf v21.4s, v9.4s \n" + "scvtf v22.4s, v12.4s \n" + "scvtf v23.4s, v13.4s \n" + + // mul + "fmul v16.4s, v10.4s, %[scale0].s[0] \n" + "fmul v17.4s, v11.4s, %[scale0].s[2] \n" + "fmul v18.4s, v14.4s, %[scale0].s[1] \n" + "fmul v19.4s, v15.4s, %[scale0].s[3] \n" + + "fmul v8.4s, v20.4s, %[scale1].s[0] \n" + "fmul v9.4s, v21.4s, %[scale1].s[2] \n" + "fmul v12.4s, v22.4s, %[scale1].s[1] \n" + "fmul v13.4s, v23.4s, %[scale1].s[3] \n" + + // fp32-int32 + "fcvtas v10.4s, v16.4s \n" + "fcvtas v11.4s, v17.4s \n" + "fcvtas v14.4s, v18.4s \n" + "fcvtas v15.4s, v19.4s \n" + + "fcvtas v20.4s, v8.4s \n" + "fcvtas v21.4s, v9.4s \n" + "fcvtas v22.4s, v12.4s \n" + "fcvtas v23.4s, v13.4s \n" + + // int32-int16 + "sqxtn v16.4h, v10.4s \n" + "sqxtn v17.4h, v11.4s \n" + "sqxtn v18.4h, v14.4s \n" + "sqxtn v19.4h, v15.4s \n" + + "sqxtn v8.4h, v20.4s \n" + "sqxtn v9.4h, v21.4s \n" + "sqxtn v12.4h, v22.4s \n" + "sqxtn v13.4h, v23.4s \n" + + // int16-int8 + "sqxtn v10.8b, v16.8h \n" + "sqxtn v11.8b, v17.8h \n" + "sqxtn v14.8b, v18.8h \n" + "sqxtn v15.8b, v19.8h \n" + + "sqxtn v20.8b, v8.8h \n" + "sqxtn v21.8b, v9.8h \n" + "sqxtn v22.8b, v12.8h \n" + "sqxtn v23.8b, v13.8h \n" + + "str s10, [%[doutc0r0]], #4 \n" /* store c0r0*/ + "str s11, [%[doutc2r0]], #4 \n" /* store c2r0*/ + "str s14, [%[doutc1r0]], #4 \n" /* store c1r0*/ + "str s15, [%[doutc3r0]], #4 \n" /* store c3r0*/ + + "subs %w[cnt], %w[cnt], #1 \n" /* loop count -1*/ + "str s20, [%[doutc4r0]], #4 \n" /* store c0r0*/ + "str s21, [%[doutc6r0]], #4 \n" /* store c2r0*/ + "str s22, [%[doutc5r0]], #4 \n" /* store c1r0*/ + "str s23, [%[doutc7r0]], #4 \n" /* store c3r0*/ + + "bne 1b \n" /* jump to main loop*/ + + : [doutc0r0] "+r"(doutc0_ptr), + [doutc1r0] "+r"(doutc1_ptr), + [doutc2r0] "+r"(doutc2_ptr), + [doutc3r0] "+r"(doutc3_ptr), + [doutc4r0] "+r"(doutc4_ptr), + [doutc5r0] "+r"(doutc5_ptr), + [doutc6r0] "+r"(doutc6_ptr), + [doutc7r0] "+r"(doutc7_ptr), + [cnt] 
"+r"(cnt_loop), + [ptr_din] "+r"(din_hei_ptr) + : [scale0] "w"(w_scale0), [scale1] "w"(w_scale1) + : "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15", + "v16", + "v17", + "v18", + "v19", + "v20", + "v21", + "v22", + "v23"); +#else + asm volatile( + "vld1.32 {d8-d11}, [%[ptr_din]]! @load data \n" + "vld1.32 {d12-d15}, [%[ptr_din]]! @load data \n" + + "1: @ main loop\n" + // int32-> fp32 + "vmov.f32 q15, #0.5 \n" + "vcvt.f32.s32 q8, q4 \n" + "vcvt.f32.s32 q9, q5 \n" + "vcvt.f32.s32 q10, q6 \n" + "vcvt.f32.s32 q11, q7 \n" + + "vand.i32 q4, q15, q15 @ set offset, 0.5\n" + "vand.i32 q5, q4, q4 @ set offset, 0.5\n" + "vand.i32 q6, q4, q4 @ set offset, 0.5\n" + "vand.i32 q7, q4, q4 @ set offset, 0.5\n" + + "vmov.f32 q15, #-0.5 \n" + + "vcgt.f32 q12, q8, %q[vzero] @ get mask > 0, in0\n" + "vcgt.f32 q13, q9, %q[vzero] @ get mask > 0, in0\n" + "vcgt.f32 q14, q10, %q[vzero] @ get mask > 0, in0\n" + "vcgt.f32 q3, q11, %q[vzero] @ get mask > 0, in0\n" + + "vbif.f32 q4, q15, q12 @ get right offset\n" + "vbif.f32 q5, q15, q13 @ get right offset\n" + "vbif.f32 q6, q15, q14 @ get right offset\n" + "vbif.f32 q7, q15, q3 @ get right offset\n" + + "vld1.32 {d24-d27}, [%[ptr_din]]! @load data \n" + "vld1.32 {d28-d29}, [%[ptr_din]]! @load data \n" + "vld1.32 {d6-d7}, [%[ptr_din]]! @load data \n" + + "vmla.f32 q4, q8, %q[scale0] @ mul scale\n" + "vmla.f32 q5, q9, %q[scale1] @ mul scale\n" + "vmla.f32 q6, q10, %q[scale0] @ mul scale\n" + "vmla.f32 q7, q11, %q[scale1] @ mul scale\n" + + "vcvt.s32.f32 q8, q4 @ cvt to int32\n" + "vcvt.s32.f32 q9, q5 @ cvt to int32\n" + "vcvt.s32.f32 q10, q6 @ cvt to int32\n" + "vcvt.s32.f32 q11, q7 @ cvt to int32\n" + + "vqmovn.s32 d8, q8 @ cnt to int16\n" + "vqmovn.s32 d10, q9 @ cnt to int16\n" + "vqmovn.s32 d12, q10 @ cnt to int16\n" + "vqmovn.s32 d14, q11 @ cnt to int16\n" + + "vqmovn.s16 d16, q4 @ cnt to int8\n" + "vqmovn.s16 d17, q5 @ cnt to int8\n" + "vqmovn.s16 d18, q6 @ cnt to int8\n" + "vqmovn.s16 d19, q7 @ cnt to int8\n" + + "vmov.f32 q15, #0.5 \n" + + "vcvt.f32.s32 q4, q12 \n" + "vcvt.f32.s32 q5, q13 \n" + "vcvt.f32.s32 q6, q14 \n" + "vcvt.f32.s32 q7, q3 \n" + + "vand.i32 q12, q15, q15 @ set offset, 0.5\n" + "vand.i32 q13, q12, q12 @ set offset, 0.5\n" + "vand.i32 q14, q12, q12 @ set offset, 0.5\n" + "vand.i32 q3, q12, q12 @ set offset, 0.5\n" + + "vmov.f32 q15, #-0.5 \n" + + "vcgt.f32 q10, q4, %q[vzero] @ get mask > 0, in0\n" + "vcgt.f32 q11, q5, %q[vzero] @ get mask > 0, in0\n" + + "vbif.f32 q12, q15, q10 @ get right offset\n" + "vbif.f32 q13, q15, q11 @ get right offset\n" + + "vcgt.f32 q10, q6, %q[vzero] @ get mask > 0, in0\n" + "vcgt.f32 q11, q7, %q[vzero] @ get mask > 0, in0\n" + + "vbif.f32 q14, q15, q10 @ get right offset\n" + "vbif.f32 q3, q15, q11 @ get right offset\n" + + "vmla.f32 q12, q4, %q[scale0] @ mul scale\n" + "vmla.f32 q13, q5, %q[scale1] @ mul scale\n" + "vmla.f32 q14, q6, %q[scale0] @ mul scale\n" + "vmla.f32 q3, q7, %q[scale1] @ mul scale\n" + + "vcvt.s32.f32 q4, q12 @ cvt to int32\n" + "vcvt.s32.f32 q5, q13 @ cvt to int32\n" + "vcvt.s32.f32 q6, q14 @ cvt to int32\n" + "vcvt.s32.f32 q7, q3 @ cvt to int32\n" + + "vqmovn.s32 d24, q4 @ cnt to int16\n" + "vqmovn.s32 d26, q5 @ cnt to int16\n" + "vqmovn.s32 d28, q6 @ cnt to int16\n" + "vqmovn.s32 d6, q7 @ cnt to int16\n" + + "vqmovn.s16 d20, q12 @ cnt to int8\n" + "vqmovn.s16 d21, q13 @ cnt to int8\n" + "vqmovn.s16 d22, q14 @ cnt to int8\n" + "vqmovn.s16 d23, q3 @ cnt to int8\n" + + "vtrn.8 d16, d18 @ trans q0, q2 \n" + 
"vtrn.8 d20, d22 @ trans q4, q6 \n" + "vtrn.16 d16, d20 @ trans q0, q2 \n" + "vtrn.16 d18, d22 @ trans q4, q6 \n" + + "vtrn.8 d17, d19 @ trans q0, q2 \n" + "vtrn.8 d21, d23 @ trans q4, q6 \n" + "vtrn.16 d17, d21 @ trans q0, q2 \n" + "vtrn.16 d19, d23 @ trans q4, q6 \n" + + "vld1.32 {d8-d11}, [%[ptr_din]]! @load data \n" + "vld1.32 {d12-d15}, [%[ptr_din]]! @load data \n" + + "vst1.32 {d16[0]}, [%[doutc0r0]] @ store result, add " + "pointer\n" + "vst1.32 {d18[0]}, [%[doutc1r0]] @ store result, add " + "pointer\n" + "vst1.32 {d20[0]}, [%[doutc2r0]] @ store result, add " + "pointer\n" + "vst1.32 {d22[0]}, [%[doutc3r0]] @ store result, add " + "pointer\n" + + "vst1.32 {d17[0]}, [%[doutc4r0]] @ store result, add " + "pointer\n" + "vst1.32 {d19[0]}, [%[doutc5r0]] @ store result, add " + "pointer\n" + "vst1.32 {d21[0]}, [%[doutc6r0]] @ store result, add " + "pointer\n" + "vst1.32 {d23[0]}, [%[doutc7r0]] @ store result, add " + "pointer\n" + + "add %[doutc0r0], #4 @ add \n" + "add %[doutc1r0], #4 @ add \n" + "add %[doutc2r0], #4 @ add \n" + "add %[doutc3r0], #4 @ add \n" + + "subs %[cnt], %[cnt], #1 @ loop count - 1\n" + + "add %[doutc4r0], #4 @ add \n" + "add %[doutc5r0], #4 @ add \n" + "add %[doutc6r0], #4 @ add \n" + "add %[doutc7r0], #4 @ add \n" + "bne 1b @ jump to main loop\n" + + : [doutc0r0] "+r"(doutc0_ptr), + [doutc1r0] "+r"(doutc1_ptr), + [doutc2r0] "+r"(doutc2_ptr), + [doutc3r0] "+r"(doutc3_ptr), + [doutc4r0] "+r"(doutc4_ptr), + [doutc5r0] "+r"(doutc5_ptr), + [doutc6r0] "+r"(doutc6_ptr), + [doutc7r0] "+r"(doutc7_ptr), + [ptr_din] "+r"(din_hei_ptr), + [cnt] "+r"(cnt_loop) + : + [scale0] "w"(w_scale0), [scale1] "w"(w_scale1), [vzero] "w"(vzero) + : "q3", + "q4", + "q5", + "q6", + "q7", + "q8", + "q9", + "q10", + "q11", + "q12", + "q13", + "q14", + "q15"); +#endif + } + if (we > width) { + int offset = 32 * (valid_w / 4 - 1); + din_hei_ptr = ptr_din + offset; + int i = we - 4; + for (; i < width; ++i) { + *(doutc0_ptr++) = + saturate_cast(roundf(din_hei_ptr[0] * scale[0])); + *(doutc1_ptr++) = + saturate_cast(roundf(din_hei_ptr[1] * scale[1])); + *(doutc2_ptr++) = + saturate_cast(roundf(din_hei_ptr[2] * scale[2])); + *(doutc3_ptr++) = + saturate_cast(roundf(din_hei_ptr[3] * scale[3])); + *(doutc4_ptr++) = + saturate_cast(roundf(din_hei_ptr[4] * scale[4])); + *(doutc5_ptr++) = + saturate_cast(roundf(din_hei_ptr[5] * scale[5])); + *(doutc6_ptr++) = + saturate_cast(roundf(din_hei_ptr[6] * scale[6])); + *(doutc7_ptr++) = + saturate_cast(roundf(din_hei_ptr[7] * scale[7])); + din_hei_ptr += 8; + } + } + } + } + } else { + LOG(ERROR) << "ERROR: unsupported input data type!!"; + return false; + } + return true; +} + +/* +* din [n, hei_n, ch_n, w] +* dout [n, ch_n, hei_n, w] +*/ +template +static bool write_to_output_numc(const dtype* din, + dtype* dout, + int ch_n, + int hei_n, + int cs, + int ce, + int hs, + int he, + int ws, + int we, + int channel, + int height, + int width, + bool flag_relu, + dtype* trash_ptr) { + if (ch_n <= 0 || hei_n <= 0) { + LOG(ERROR) << "ch_n and hei_n are more than zero"; + return false; + } + int size_c_out = width * height; + + dtype* out_array[ch_n]; + out_array[0] = dout + cs * size_c_out + hs * width + ws; + + for (int i = 1; i < ch_n; i++) { + out_array[i] = out_array[i - 1] + size_c_out; + } + + const dtype* ptr_din = din; + + int cremain = ce - channel; + for (int i = 1; i <= cremain; i++) { + out_array[ch_n - i] = trash_ptr; + } + + int size_h = (he > height ? 
height : he) - hs;  // size_h == hei_n
+
+  int size_w = we - ws;
+
+  int size_c_in = ch_n * size_w;
+
+  size_t valid_w_byte = width * sizeof(dtype);
+
+  if (flag_relu) {
+    for (int h = 0; h < size_h; h++) {
+      const dtype* din_ptr = din + h * size_c_in;
+      for (int i = 0; i < ch_n; i++) {
+        dtype* dout_ptr = out_array[i] + h * width;
+        for (int k = 0; k < width; k++) {
+          *(dout_ptr++) = LITEMAX(din_ptr[k], 0);
+        }
+        din_ptr += size_w;
+      }
+    }
+  } else {
+    for (int h = 0; h < size_h; h++) {
+      const dtype* din_ptr = din + h * size_c_in;
+      for (int i = 0; i < ch_n; i++) {
+        dtype* dout_ptr = out_array[i] + h * width;
+        memcpy(dout_ptr, din_ptr, valid_w_byte);
+        din_ptr += size_w;
+      }
+    }
+  }
+  return true;
+}
+
+/// expects ch_n == ce - cs and hei_n == he - hs;
+/// channel / height / width describe the output tensor
+template <typename ditype, typename dotype>
+static bool write2_to_output_numc(const ditype* din,
+                                  dotype* dout,
+                                  int ch_n,
+                                  int hei_n,
+                                  int cs,
+                                  int ce,
+                                  int hs,
+                                  int he,
+                                  int ws,
+                                  int we,
+                                  int channel,
+                                  int height,
+                                  int width,
+                                  bool flag_relu,
+                                  dotype* trash_ptr,
+                                  float const* scales) {
+  // static_assert(std::is_same<dotype, float>::value, "just support float");
+
+  if (ch_n <= 0 || hei_n <= 0) {
+    LOG(ERROR) << "ch_n and hei_n must be greater than zero";
+    return false;
+  }
+
+  int size_c_out = width * height;
+
+  dotype* out_array[ch_n];
+  out_array[0] = dout + cs * size_c_out + hs * width + ws;
+
+  for (int i = 1; i < ch_n; i++) {
+    out_array[i] = out_array[i - 1] + size_c_out;
+  }
+
+  const ditype* ptr_din = din;
+
+  int cremain = ce - channel;
+  for (int i = 1; i <= cremain; i++) {
+    out_array[ch_n - i] = trash_ptr;
+  }
+
+  int size_h = (he > height ? height : he) - hs;  // size_h == hei_n
+
+  int size_w = we - ws;
+
+  int size_c_in = ch_n * size_w;
+
+  size_t valid_w_byte = width * sizeof(ditype);
+
+  if (flag_relu) {
+    for (int h = 0; h < size_h; h++) {
+      ditype const* din_ptr = din + h * size_c_in;
+      for (int i = 0; i < ch_n; i++) {
+        // note: this local per-channel scale shadows the `ws` (width start)
+        // argument above
+        float const ws = scales[(i + cs) % ch_n];
+        dotype* dout_ptr = out_array[i] + h * width;
+        for (int k = 0; k < width; k++) {
+          *(dout_ptr++) = LITEMAX(din_ptr[k] * ws, 0);
+        }
+        din_ptr += size_w;
+      }
+    }
+  } else {
+    for (int h = 0; h < size_h; h++) {
+      ditype const* din_ptr = din + h * size_c_in;
+      for (int i = 0; i < ch_n; i++) {
+        dotype* dout_ptr = out_array[i] + h * width;
+
+        float const* ws = &scales[(i + cs) % ch_n];
+        int32_to_dtype(din_ptr, dout_ptr, ws, 1, 1, width);
+
+        din_ptr += size_w;
+      }
+    }
+  }
+  return true;
+}
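For reference, every NEON block in the write-back code above implements the same scalar recipe as its remainder loop: read 8 channel values per pixel from the c8-interleaved buffer, multiply by the per-channel scale, optionally apply ReLU, then either store floats or round-and-saturate to int8 (fcvtas rounds to nearest, ties away from zero, matching roundf). A minimal scalar sketch, assuming the interleaved [h][w][8] input layout used here; write_c8_ref_fp32 and write_c8_ref_int8 are illustrative names, not library functions:

#include <algorithm>
#include <cmath>
#include <cstdint>

// float output with optional fused ReLU
static void write_c8_ref_fp32(const int32_t* din,   // interleaved [h][w][8]
                              float* dout,          // planar [8][h][w]
                              const float* scale,   // one scale per channel
                              int h, int w, bool relu) {
  for (int y = 0; y < h; ++y)
    for (int x = 0; x < w; ++x)
      for (int c = 0; c < 8; ++c) {
        float v = din[(y * w + x) * 8 + c] * scale[c];
        dout[(c * h + y) * w + x] = relu ? std::max(v, 0.f) : v;
      }
}

// int8 output: round to nearest (ties away from zero) and saturate,
// mirroring saturate_cast(roundf(...)) in the scalar tail loops
static void write_c8_ref_int8(const int32_t* din, int8_t* dout,
                              const float* scale, int h, int w, bool relu) {
  for (int y = 0; y < h; ++y)
    for (int x = 0; x < w; ++x)
      for (int c = 0; c < 8; ++c) {
        float v = din[(y * w + x) * 8 + c] * scale[c];
        if (relu) v = std::max(v, 0.f);
        float r = std::roundf(v);
        r = std::min(127.f, std::max(-128.f, r));
        dout[(c * h + y) * w + x] = static_cast<int8_t>(r);
      }
}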
+/**
+ * input din: nchwc(num)
+ */
+inline bool fill_packed_bias_nxmw_fp32(
+    const float* bias, float* dout, int ch_n, int hei_n, int wround) {
+  if (ch_n <= 0 || hei_n <= 0) {
+    LOG(ERROR) << "ch_n and hei_n must be greater than zero";
+    return false;
+  }
+  int cnt_ch = ch_n / 4;
+  int size = wround * ch_n;
+  for (int h = 0; h < hei_n; h++) {
+    float* dout_ptr = dout + h * size;
+    for (int i = 0; i < wround; i++) {
+      const float* bias_ptr = bias;
+      int j = 0;
+      for (; j < cnt_ch; j++) {
+        float32x4_t vb = vld1q_f32(bias_ptr);
+        bias_ptr += 4;
+
+        vst1q_f32(dout_ptr, vb);
+        dout_ptr += 4;
+      }
+      j = j * 4;
+      for (; j < ch_n; j++) {
+        *dout_ptr = *bias_ptr;
+        dout_ptr++;
+        bias_ptr++;
+      }
+    }
+  }
+  return true;  // the function is declared bool; this return was missing
+}
+
+inline bool fill_packed_bias_nxmw_int8(
+    const int* bias, int* dout, int ch_n, int hei_n, int wround) {
+  if (ch_n <= 0 || hei_n <= 0) {
+    LOG(ERROR) << "ch_n and hei_n must be greater than zero";
+    return false;
+  }
+  int cnt_ch = ch_n / 4;
+  int size = wround * ch_n;
+  for (int h = 0; h < hei_n; h++) {
+    int* dout_ptr = dout + h * size;
+    for (int i = 0; i < wround; i++) {
+      const int* bias_ptr = bias;
+      int j = 0;
+      for (; j < cnt_ch; j++) {
+        int32x4_t vb = vld1q_s32(bias_ptr);
+        bias_ptr += 4;
+
+        vst1q_s32(dout_ptr, vb);
+        dout_ptr += 4;
+      }
+      j = j * 4;
+      for (; j < ch_n; j++) {
+        *dout_ptr = *bias_ptr;
+        dout_ptr++;
+        bias_ptr++;
+      }
+    }
+  }
+  return true;
+}
+
+}  // namespace math
+}  // namespace arm
+}  // namespace lite
+}  // namespace paddle
diff --git a/lite/arm/math/conv_depthwise.cc b/lite/arm/math/conv_depthwise.cc
new file mode 100644
index 00000000000..f04de2178c5
--- /dev/null
+++ b/lite/arm/math/conv_depthwise.cc
@@ -0,0 +1,239 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lite/arm/math/conv_depthwise.h"
+#include "lite/arm/math/conv_block_utils.h"
+#include "lite/arm/math/conv_impl.h"
+
+namespace paddle {
+namespace lite {
+namespace arm {
+namespace math {
+
+template <>
+bool DepthwiseConv::create(const operators::ConvParam& param,
+                           ARMContext* ctx) {
+  this->ctx_ = ctx;
+  auto x_dims = param.x->dims();
+  auto w_dims = param.filter->dims();
+  auto o_dims = param.output->dims();
+
+  int iw = x_dims[3];  // nchw
+  int ic = x_dims[1];
+  int ow = o_dims[3];
+  int oc = o_dims[1];
+  int kw = w_dims[3];
+  int sw = param.strides[1];
+  // select dw conv kernel
+  if (kw == 3) {
+    VLOG(5) << "invoke 3x3 dw conv";
+    impl_ = conv_depthwise_3x3;
+  } else if (kw == 5) {
+    VLOG(5) << "invoke 5x5 dw conv";
+    this->ctx_->ExtendWorkspace((iw + ow) * sizeof(float));
+    impl_ = conv_depthwise_5x5;
+  } else {
+    LOG(ERROR) << "this type of dw conv is not implemented";
+    return false;
+  }
+  return true;
+}
+
+template <>
+bool DepthwiseConv::init(const operators::ConvParam& param,
+                         Context<TARGET(kARM)>* ctx) {
+  this->ctx_ = ctx;
+  return create(param, ctx);
+}
+
+template <>
+bool DepthwiseConv::run(const operators::ConvParam& param) {
+  // start timer
+  const auto* i_data = param.x->data<float>();
+  const auto* w_data = param.filter->data<float>();
+  const auto* b_data = param.bias ? param.bias->data<float>() : nullptr;
+  auto* o_data = param.output->mutable_data<float>();
+
+  auto x_dims = param.x->dims();
+  auto w_dims = param.filter->dims();
+  auto o_dims = param.output->dims();
+
+  int iw = x_dims[3];  // nchw
+  int ih = x_dims[2];
+  int ic = x_dims[1];
+  int bs = x_dims[0];
+  int oh = o_dims[2];
+  int ow = o_dims[3];
+  int oc = o_dims[1];
+
+  impl_(i_data,
+        o_data,
+        bs,
+        oc,
+        oh,
+        ow,
+        ic,
+        ih,
+        iw,
+        w_data,
+        b_data,
+        param,
+        this->ctx_);
+
+  // timer end
+  return true;
+}
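For reference, each kernel that impl_ can point to computes an ordinary per-channel (depthwise) convolution; only the tiling and instruction selection differ. A naive scalar sketch, assuming NCHW layout, a square k x k filter, stride s and zero padding p; depthwise_ref is an illustrative name, not a library function:

#include <algorithm>

static void depthwise_ref(const float* in, float* out, const float* w,
                          const float* bias, int n, int c, int hin, int win,
                          int hout, int wout, int k, int s, int p, bool relu) {
  for (int b = 0; b < n; ++b)
    for (int ch = 0; ch < c; ++ch)          // one filter plane per channel
      for (int oh = 0; oh < hout; ++oh)
        for (int ow = 0; ow < wout; ++ow) {
          float acc = bias ? bias[ch] : 0.f;
          for (int kh = 0; kh < k; ++kh)
            for (int kw = 0; kw < k; ++kw) {
              int ih = oh * s - p + kh;
              int iw = ow * s - p + kw;
              if (ih < 0 || ih >= hin || iw < 0 || iw >= win) continue;
              acc += in[((b * c + ch) * hin + ih) * win + iw] *
                     w[(ch * k + kh) * k + kw];
            }
          out[((b * c + ch) * hout + oh) * wout + ow] =
              relu ? std::max(acc, 0.f) : acc;
        }
}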
+template <PrecisionType Ptype_out>
+bool DepthwiseConvInt8<Ptype_out>::create(const operators::ConvParam& param,
+                                          ARMContext* ctx) {
+  this->ctx_ = ctx;
+  auto x_dims = param.x->dims();
+  auto w_dims = param.filter->dims();
+  auto o_dims = param.output->dims();
+
+  int ic = x_dims[1];
+  int ih = x_dims[2];
+  int iw = x_dims[3];  // nchw
+  int oc = o_dims[1];
+  int oh = o_dims[2];
+  int ow = o_dims[3];
+  int kw = w_dims[3];
+  int sw = param.strides[1];
+  w_scale_ = param.weight_scale;
+
+  //! select dw conv kernel
+  if (kw == 3) {
+    tmp_int32_out_.Resize(o_dims);
+    VLOG(5) << "invoke 3x3 depthwise int8 conv";
+    impl_ = conv_depthwise_3x3_int8;
+  } else if (kw == 5) {
+    //! fold the input scale (and, for int8 output, the output scale)
+    //! into the per-channel weight scales
+    if (Ptype_out == PRECISION(kFloat) || Ptype_out == PRECISION(kInt8)) {
+      CHECK_EQ(w_scale_.size(), oc) << "weight scale size must equal oc";
+      float input_scale = param.input_scale;
+      float output_scale = param.output_scale;
+      for (auto& ws : w_scale_) {
+        ws *= input_scale;
+        if (Ptype_out == PRECISION(kInt8)) {
+          ws /= output_scale;
+        }
+      }
+    }
+
+    const int wout_round = ((ow + 7) / 8) * 8;
+    const int win_round = wout_round * sw + 5 - 1;
+    const int hout_round = ((oh + 2) / 3) * 3;
+    const int hin_round = hout_round * sw + 5 - 1;
+    const int tmp_size_out = wout_round * hout_round;
+    const int tmp_size_in = win_round * hin_round;
+    const int tmp_size_io_bytes = tmp_size_in + tmp_size_out * sizeof(int);
+    const int tmp_row_io_bytes = win_round + wout_round * sizeof(int);
+    const int tmp_size_io_float =
+        (tmp_size_io_bytes + sizeof(float) - 1) / sizeof(float);
+    const int tmp_row_io_float =
+        (tmp_row_io_bytes + sizeof(float) - 1) / sizeof(float);
+    ctx_->ExtendWorkspace(
+        (ctx_->threads() * tmp_size_io_float + tmp_row_io_float) *
+        sizeof(float));
+    impl_ = conv_depthwise_5x5_int8;
+    VLOG(5) << "invoke 5x5 depthwise int8 conv";
+  } else {
+    LOG(ERROR) << "this type of depthwise int8 conv is not implemented";
+    return false;
+  }
+  return true;
+}
+
+template <PrecisionType Ptype_out>
+bool DepthwiseConvInt8<Ptype_out>::init(const operators::ConvParam& param,
+                                        Context<TARGET(kARM)>* ctx) {
+  this->ctx_ = ctx;
+  return create(param, ctx);
+}
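The scale folding in create() above means run() needs only one multiplier per channel: for float output the int32 accumulator is scaled by w_scale * input_scale, and for int8 output that product is additionally divided by output_scale. A small worked sketch with made-up scale values:

#include <cmath>
#include <cstdint>
#include <cstdio>

int main() {
  float w_scale = 0.02f;      // per-channel weight scale (made up)
  float input_scale = 0.05f;  // activation quantization scale (made up)
  float output_scale = 0.1f;  // output quantization scale (made up)

  // float output: acc_fp32 = acc_int32 * (w_scale * input_scale)
  float to_float = w_scale * input_scale;   // 0.001
  // int8 output: q_out = round(acc_int32 * to_float / output_scale)
  float to_int8 = to_float / output_scale;  // 0.01
  int32_t acc = 1234;  // an int32 accumulator from the conv kernel
  printf("fp32 out = %f, int8 out = %d\n",
         acc * to_float,                                  // 1.234
         static_cast<int>(std::lround(acc * to_int8)));   // 12
  return 0;
}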
+template <PrecisionType Ptype_out>
+bool DepthwiseConvInt8<Ptype_out>::run(const operators::ConvParam& param) {
+  const int8_t* i_data = param.x->data<int8_t>();
+  int32_t* o_data = nullptr;
+  const int8_t* w_data = param.filter->data<int8_t>();
+  const int32_t* b_data =
+      param.bias ? param.bias->data<int32_t>() : nullptr;
+
+  // LOG(INFO) << "input size: " << param.x->memory_size() << " "
+  //           << param.input_scale << " " << w_scale_.size();
+
+  auto x_dims = param.x->dims();
+  auto w_dims = param.filter->dims();
+  auto o_dims = param.output->dims();
+  int bs = x_dims[0];
+  int ic = x_dims[1];
+  int ih = x_dims[2];
+  int iw = x_dims[3];  // nchw
+  int oc = o_dims[1];
+  int oh = o_dims[2];
+  int ow = o_dims[3];
+  int kw = w_dims[3];
+  int sw = param.strides[1];
+
+  if (kw == 3 && Ptype_out != PRECISION(kInt32)) {
+    o_data = tmp_int32_out_.mutable_data<int32_t>();
+  } else if (kw == 5 || (kw == 3 && Ptype_out == PRECISION(kInt32))) {
+    o_data = param.output->mutable_data<int32_t>();
+  } else {
+    LOG(ERROR) << "this type of depthwise int8 conv is not implemented";
+    return false;
+  }
+
+  impl_(i_data,
+        o_data,
+        bs,
+        oc,
+        oh,
+        ow,
+        ic,
+        ih,
+        iw,
+        w_data,
+        b_data,
+        param,
+        this->ctx_,
+        Ptype_out,
+        w_scale_.data());
+
+  auto i_scale = param.input_scale;
+  auto o_scale = param.output_scale;
+  if (kw == 3) {
+    if (Ptype_out == PRECISION(kInt8)) {
+      trans_tensor_dtype<PRECISION(kInt32), PRECISION(kInt8)>(
+          &tmp_int32_out_, param.output, i_scale, o_scale, w_scale_);
+    } else if (Ptype_out == PRECISION(kFloat)) {
+      trans_tensor_dtype<PRECISION(kInt32), PRECISION(kFloat)>(
+          &tmp_int32_out_, param.output, i_scale, 1.f, w_scale_);
+    } else if (Ptype_out != PRECISION(kInt32)) {
+      LOG(ERROR) << "unsupported precision type!!";
+      return false;
+    }
+  }
+
+  return true;
+}
+
+template class DepthwiseConvInt8<PRECISION(kInt8)>;
+template class DepthwiseConvInt8<PRECISION(kFloat)>;
+template class DepthwiseConvInt8<PRECISION(kInt32)>;
+
+}  // namespace math
+}  // namespace arm
+}  // namespace lite
+}  // namespace paddle
diff --git a/lite/arm/math/conv_depthwise.h b/lite/arm/math/conv_depthwise.h
new file mode 100644
index 00000000000..15e3b36e305
--- /dev/null
+++ b/lite/arm/math/conv_depthwise.h
@@ -0,0 +1,100 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+ +#pragma once + +#include +#include +#include "lite/arm/math/conv_impl.h" +#include "lite/core/context.h" +#include "lite/core/target_wrapper.h" + +namespace paddle { +namespace lite { +namespace arm { +namespace math { + +template +class DepthwiseConv + : public ImplBase { + public: + typedef void (*conv_dw_impl)(const float* i_data, + float* o_data, + int bs, + int oc, + int oh, + int ow, + int ic, + int ih, + int kw, + const float* w_data, + const float* b_data, + const operators::ConvParam& param, + Context* ctx); + DepthwiseConv() = default; + ~DepthwiseConv() {} + + virtual bool init(const operators::ConvParam& param, + Context* ctx); + + virtual bool create(const operators::ConvParam& param, + Context* ctx); + + virtual bool run(const operators::ConvParam& param); + + private: + conv_dw_impl impl_{nullptr}; +}; + +template +class DepthwiseConvInt8 + : public ImplBase { + public: + typedef void (*conv_dw_int8_impl)(const int8_t* i_data, + int32_t* o_data, + int bs, + int oc, + int oh, + int ow, + int ic, + int ih, + int kw, + const int8_t* w_data, + const int32_t* b_data, + const operators::ConvParam& param, + Context* ctx, + PrecisionType out_type, + const float* scale); + + DepthwiseConvInt8() = default; + ~DepthwiseConvInt8() {} + + virtual bool init(const operators::ConvParam& param, + Context* ctx); + + virtual bool create(const operators::ConvParam& param, + Context* ctx); + + virtual bool run(const operators::ConvParam& param); + + private: + conv_dw_int8_impl impl_{nullptr}; + std::vector w_scale_; + Tensor tmp_int32_out_; +}; + +} // namespace math +} // namespace arm +} // namespace lite +} // namespace paddle diff --git a/lite/arm/math/conv_depthwise_3x3_int7.cc b/lite/arm/math/conv_depthwise_3x3_int7.cc new file mode 100644 index 00000000000..18dd2225ae6 --- /dev/null +++ b/lite/arm/math/conv_depthwise_3x3_int7.cc @@ -0,0 +1,5322 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include "lite/arm/math/conv_impl.h" +#include "lite/core/context.h" +#include "lite/operators/op_params.h" + +namespace paddle { +namespace lite { +namespace arm { +namespace math { + +void conv_depthwise_3x3s1p1_bias_int7(int* dout, + const signed char* din, + const signed char* weights, + const int* bias, + bool flag_bias, + const int num, + const int ch_in, + const int h_in, + const int w_in, + const int h_out, + const int w_out, + ARMContext* ctx); + +//! 
for input width <= 8 +void conv_depthwise_3x3s1p1_bias_s_int7(int* dout, + const signed char* din, + const signed char* weights, + const int* bias, + bool flag_bias, + const int num, + const int ch_in, + const int h_in, + const int w_in, + const int h_out, + const int w_out, + ARMContext* ctx); + +void conv_depthwise_3x3s2p1_bias_int7(int* dout, + const signed char* din, + const signed char* weights, + const int* bias, + bool flag_bias, + const int num, + const int ch_in, + const int h_in, + const int w_in, + const int h_out, + const int w_out, + ARMContext* ctx); + +//! for input width <= 8 +void conv_depthwise_3x3s2p1_bias_s_int7(int* dout, + const signed char* din, + const signed char* weights, + const int* bias, + bool flag_bias, + const int num, + const int ch_in, + const int h_in, + const int w_in, + const int h_out, + const int w_out, + ARMContext* ctx); + +void conv_depthwise_3x3s1p1_bias_relu_int7(int* dout, + const signed char* din, + const signed char* weights, + const int* bias, + bool flag_bias, + const int num, + const int ch_in, + const int h_in, + const int w_in, + const int h_out, + const int w_out, + ARMContext* ctx); + +//! for input width <= 4 +void conv_depthwise_3x3s1p1_bias_s_relu_int7(int* dout, + const signed char* din, + const signed char* weights, + const int* bias, + bool flag_bias, + const int num, + const int ch_in, + const int h_in, + const int w_in, + const int h_out, + const int w_out, + ARMContext* ctx); + +void conv_depthwise_3x3s2p1_bias_relu_int7(int* dout, + const signed char* din, + const signed char* weights, + const int* bias, + bool flag_bias, + const int num, + const int ch_in, + const int h_in, + const int w_in, + const int h_out, + const int w_out, + ARMContext* ctx); + +//! for input width <= 4 +void conv_depthwise_3x3s2p1_bias_s_relu_int7(int* dout, + const signed char* din, + const signed char* weights, + const int* bias, + bool flag_bias, + const int num, + const int ch_in, + const int h_in, + const int w_in, + const int h_out, + const int w_out, + ARMContext* ctx); + +void conv_depthwise_3x3_int7(const int8_t* din, + int32_t* dout, + int num, + int chout, + int hout, + int wout, + int chin, + int hin, + int win, + int8_t* weights, + const int32_t* bias, + const operators::ConvParam& param, + ARMContext* ctx, + PrecisionType out_type, + const float* scale) { + int w_in = win; + int h_in = hin; + int ch_in = chin; + + int w_out = wout; + int h_out = hout; + int ch_out = chout; + int stride_h = param.strides[0]; + bool flag_relu = param.fuse_relu; + bool flag_bias = param.bias != nullptr; + // if (param.activation_param.has_active) { + // if (param.activation_param.active == Active_relu || + // fabs(param.activation_param.negative_slope) > 1e-6f) { + // flag_relu = true; + // } + // } + //! only support stride = 1 or 2 + if (stride_h == 1) { + if (flag_relu) { + if (w_in > 8) { + conv_depthwise_3x3s1p1_bias_relu_int7(dout, + din, + weights, + bias, + flag_bias, + num, + ch_in, + h_in, + w_in, + h_out, + w_out, + ctx); + } else { + conv_depthwise_3x3s1p1_bias_s_relu_int7(dout, + din, + weights, + bias, + flag_bias, + num, + ch_in, + h_in, + w_in, + h_out, + w_out, + ctx); + } + } else { + if (w_in > 8) { + conv_depthwise_3x3s1p1_bias_int7(dout, + din, + weights, + bias, + flag_bias, + num, + ch_in, + h_in, + w_in, + h_out, + w_out, + ctx); + } else { + conv_depthwise_3x3s1p1_bias_s_int7(dout, + din, + weights, + bias, + flag_bias, + num, + ch_in, + h_in, + w_in, + h_out, + w_out, + ctx); + } + } + } else { //! 
stride = 2 + if (flag_relu) { + if (w_in > 16) { + conv_depthwise_3x3s2p1_bias_relu_int7(dout, + din, + weights, + bias, + flag_bias, + num, + ch_in, + h_in, + w_in, + h_out, + w_out, + ctx); + } else { + conv_depthwise_3x3s2p1_bias_s_relu_int7(dout, + din, + weights, + bias, + flag_bias, + num, + ch_in, + h_in, + w_in, + h_out, + w_out, + ctx); + } + } else { + if (w_in > 16) { + conv_depthwise_3x3s2p1_bias_int7(dout, + din, + weights, + bias, + flag_bias, + num, + ch_in, + h_in, + w_in, + h_out, + w_out, + ctx); + } else { + conv_depthwise_3x3s2p1_bias_s_int7(dout, + din, + weights, + bias, + flag_bias, + num, + ch_in, + h_in, + w_in, + h_out, + w_out, + ctx); + } + } + } +} +/** + * \brief depthwise convolution, kernel size 3x3, stride 1, pad 1, with bias, + * width > 4 + */ + +// 4line w_in > 8 +void conv_depthwise_3x3s1p1_bias_int7(int* dout, + const signed char* din, + const signed char* weights, + const int* bias, + bool flag_bias, + const int num, + const int ch_in, + const int h_in, + const int w_in, + const int h_out, + const int w_out, + ARMContext* ctx) { + // printf("3x3s1 mult height \n"); + //! pad is done implicit + const char zero[8] = {0, 0, 0, 0, 0, 0, 0, 0}; + const unsigned char right_pad_idx[16] = { + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}; + const unsigned int right_pad_rst[8] = {0, 1, 2, 3, 4, 5, 6, 7}; + + // printf("conv3x3_dw start \n"); + signed char* zero_ptr = ctx->workspace_data(); + memset(zero_ptr, 0, w_in * sizeof(signed char)); + int* write_ptr = + reinterpret_cast(ctx->workspace_data()) + w_in; + int size_in_channel = w_in * h_in; + int size_out_channel = w_out * h_out; + int w_stride = 9; + + int tile_w = (w_in + 7) >> 3; + int tile_h = (h_out + 1) >> 1; + int cnt_col = tile_w - 2; + + unsigned int size_pad_right = (unsigned int)(w_in - 7 - (cnt_col << 3)); + + int size_pad_bottom = h_out % 2; + + uint8x8_t vmask_rp1 = + vcgt_u8(vdup_n_u8(size_pad_right), vld1_u8(right_pad_idx)); + uint8x8_t vmask_rp2 = + vcgt_u8(vdup_n_u8(size_pad_right), vld1_u8(right_pad_idx + 8)); + + uint8x16_t vmask_rp = + vcgtq_u8(vdupq_n_u8(size_pad_right), vld1q_u8(right_pad_idx)); + // uint8x8_t vmask_rp2 = vcgt_u8(vdup_n_u8(size_pad_right), + // vld1_u8(right_pad_idx + 8)); + unsigned char vmask[16]; + vst1q_u8(vmask, vmask_rp); + + unsigned int rst_remain = (unsigned int)(w_out - ((cnt_col + 1) << 3)); + uint32x4_t vmask_result1 = + vcgtq_u32(vdupq_n_u32(rst_remain), vld1q_u32(right_pad_rst)); + uint32x4_t vmask_result2 = + vcgtq_u32(vdupq_n_u32(rst_remain), vld1q_u32(right_pad_rst + 4)); + + unsigned int rmask[8]; + vst1q_u32(rmask, vmask_result1); + vst1q_u32(rmask + 4, vmask_result2); + + int8x8_t vzero = vdup_n_s8(0); + int32x4_t vzero_32 = vdupq_n_s32(0); + + for (int n = 0; n < num; ++n) { + const signed char* din_batch = din + n * ch_in * size_in_channel; + int* dout_batch = dout + n * ch_in * size_out_channel; +#pragma omp parallel for + for (int c = 0; c < ch_in; c++) { + int* dout_ptr = dout_batch + c * size_out_channel; + + const signed char* din_ch_ptr = din_batch + c * size_in_channel; + + int bias_val = flag_bias ? 
bias[c] : 0; + + const signed char* wei_ptr = weights + c * w_stride; + +#ifdef __aarch64__ + int vbias[4] = {bias_val, bias_val, bias_val, bias_val}; + + int8x8_t wr00 = vdup_n_s8(wei_ptr[0]); + int8x8_t wr10 = vdup_n_s8(wei_ptr[3]); + int8x8_t wr20 = vdup_n_s8(wei_ptr[6]); + + int8x8_t wr01 = vdup_n_s8(wei_ptr[1]); + int8x8_t wr11 = vdup_n_s8(wei_ptr[4]); + int8x8_t wr21 = vdup_n_s8(wei_ptr[7]); + + int8x8_t wr02 = vdup_n_s8(wei_ptr[2]); + int8x8_t wr12 = vdup_n_s8(wei_ptr[5]); + int8x8_t wr22 = vdup_n_s8(wei_ptr[8]); +#endif + int* doutr0 = nullptr; + int* doutr1 = nullptr; + + const signed char* dr0 = din_ch_ptr; + const signed char* dr1 = dr0 + w_in; + const signed char* dr2 = dr1 + w_in; + const signed char* dr3 = dr2 + w_in; + + const signed char* din_ptr0 = nullptr; + const signed char* din_ptr1 = nullptr; + const signed char* din_ptr2 = nullptr; + const signed char* din_ptr3 = nullptr; + + for (int i = 0; i < h_in; i += 2) { + //! process top pad pad_h = 1 + din_ptr0 = dr0; + din_ptr1 = dr1; + din_ptr2 = dr2; + din_ptr3 = dr3; + + doutr0 = dout_ptr; + doutr1 = doutr0 + w_out; + unsigned int* rst_mask = rmask; + unsigned char* val_mask = vmask; + + if (i == 0) { + din_ptr0 = zero_ptr; + din_ptr1 = dr0; + din_ptr2 = dr1; + din_ptr3 = dr2; + dr0 = dr1; + dr1 = dr2; + dr2 = dr3; + dr3 = dr2 + w_in; + } else { + dr0 = dr2; + dr1 = dr3; + dr2 = dr1 + w_in; + dr3 = dr2 + w_in; + } + //! process bottom pad + if (i + 3 > h_in) { + switch (i + 3 - h_in) { + case 3: + din_ptr1 = zero_ptr; + case 2: + din_ptr2 = zero_ptr; + case 1: + din_ptr3 = zero_ptr; + default: + break; + } + } + //! process bottom remain + if (i + 2 > h_out) { + doutr1 = write_ptr; + } + int cnt = cnt_col; +#ifdef __aarch64__ + asm volatile( + "PRFM PLDL1KEEP, [%[din_ptr0]] \n" + "PRFM PLDL1KEEP, [%[din_ptr1]] \n" + "PRFM PLDL1KEEP, [%[din_ptr2]] \n" + "PRFM PLDL1KEEP, [%[din_ptr3]] \n" + "movi v21.4s, #0x0\n" /* out0 = 0 */ + // left + "ld1 {v0.8b}, [%[din_ptr0]], #8 \n" /* load + a00-a015 + to + q0*/ + "ld1 {v2.8b}, [%[din_ptr1]], #8 \n" /* load + a00-a015 + to + q0*/ + "ld1 {v1.8b}, [%[din_ptr0]] \n" /* load + a00-a015 to + q0*/ + "ld1 {v3.8b}, [%[din_ptr1]] \n" /* load + a00-a015 to + q0*/ + + "ld1 {v10.4s}, [%[bias_val]] \n" /* dup v10, bias */ + "ld1 {v11.4s}, [%[bias_val]] \n" /* dup v10, bias */ + "ld1 {v12.4s}, [%[bias_val]] \n" /* dup v10, bias */ + "ld1 {v13.4s}, [%[bias_val]] \n" /* dup v10, bias */ + + // r0 + "smull v18.8h, %[v1].8b, v0.8b \n" /* outr00 = 01234567 * w01 + */ + + "ext v4.8b, v21.8b, v0.8B, #7 \n" /* vext_s8(vzero, vinr0, 7); + 00123456 */ + "ext v5.8b, v0.8b, v1.8B, #1 \n" /* vext_s8(vinr0, vinr0_1, + 1); 12345678 */ + + "ld1 {v6.8b}, [%[din_ptr2]], #8 \n" /* load + a00-a015 + to + q0*/ + "ld1 {v8.8b}, [%[din_ptr3]], #8 \n" /* load + a00-a015 + to + q0*/ + + "smlal v18.8h, %[v0].8b, v4.8b\n" /* outr00 += 00123456 * w00 */ + + "ld1 {v7.8b}, [%[din_ptr2]] \n" /* load + a00-a015 + to q0*/ + "ld1 {v9.8b}, [%[din_ptr3]] \n" /* load + a00-a015 + to q0*/ + + "sub %[din_ptr0], %[din_ptr0], #1 \n" + "sub %[din_ptr1], %[din_ptr1], #1 \n" + + "smlal v18.8h, %[v2].8b, v5.8b\n" /* outr00 += 12345678 * w02 */ + + "ext v4.8b, v21.8b, v2.8b, #7 \n" /* vext_s8(vzero, vinr0, 7); + 00123456 */ + "ext v5.8b, v2.8b, v3.8b, #1 \n" /* vext_s8(vinr0, vinr0_1, + 1); 12345678 */ + + // r1 + "sub %[din_ptr2], %[din_ptr2], #1 \n" + "sub %[din_ptr3], %[din_ptr3], #1 \n" + + "smull v19.8h, %[v1].8b, v2.8b \n" /* outr10 += 01234567 * w11 + */ + "smlal v18.8h, %[v4].8b, v2.8b \n" /* outr00 += 01234567 * w11 + */ + + "ext 
v14.8b, v21.8b, v6.8b, #7 \n" /* vext_s8(vzero, vinr0, + 7); 00123456 */ + "ext v15.8b, v6.8b, v7.8b, #1 \n" /* vext_s8(vinr0, vinr0_1, + 1); 12345678 */ + + "smlal v19.8h, %[v0].8b, v4.8b \n" /* outr00 += 01234567 * w11 + */ + "smlal v18.8h, %[v3].8b, v4.8b \n" /* outr00 += 001234567 * w10 + */ + + "ld1 {v0.8b}, [%[din_ptr0]], #8 \n" /* load + a00-a015 + to + q0*/ + "ld1 {v2.8b}, [%[din_ptr1]], #8 \n" /* load + a00-a015 + to + q0*/ + + "smlal v19.8h, %[v2].8b, v5.8b \n" /* outr00 += 01234567 * w11 + */ + "smlal v18.8h, %[v5].8b, v5.8b \n" /* outr00 += 12345678 * w12 + */ + + // r2 + "ld1 {v1.8b}, [%[din_ptr0]] \n" /* load + a00-a015 to + q0*/ + "ld1 {v3.8b}, [%[din_ptr1]] \n" /* load + a00-a015 to + q0*/ + + "smlal v19.8h, %[v4].8b, v6.8b \n" /* outr10 += 01234567 * w11 + */ + "smlal v18.8h, %[v7].8b, v6.8b \n" /* outr00 += 01234567 * w11 + */ + + "ext v4.8b, v21.8b, v8.8b, #7 \n" /* vext_s8(vzero, vinr0, 7); + 00123456 */ + "ext v5.8b, v8.8b, v9.8b, #1 \n" /* vext_s8(vinr0, vinr0_1, + 1); 12345678 */ + + "smlal v19.8h, %[v3].8b, v14.8b \n" /* outr10 += 01234567 * w11 + */ + "smlal v18.8h, %[v6].8b, v14.8b \n" /* outr00 += 01234567 * w11 + */ + + "ld1 {v6.8b}, [%[din_ptr2]], #8 \n" /* load + a00-a015 + to + q0*/ + + "smlal v19.8h, %[v5].8b, v15.8b \n" /* outr10 += 01234567 * w11 + */ + + "saddw v10.4s, v10.4s, v18.4h \n" /* v10 += outr00.low*/ + "saddw2 v11.4s, v11.4s, v18.8h \n" /* v11 += outr00.high*/ + + "smull v18.8h, %[v8].8b, v15.8b \n" /* outr00 += 01234567 * w11 + */ + + // r3 + "smlal v19.8h, %[v7].8b, v8.8b \n" /* outr00 += 01234567 * w11 + */ + + "ld1 {v8.8b}, [%[din_ptr3]], #8 \n" /* load + a00-a015 + to + q0*/ + + "ld1 {v7.8b}, [%[din_ptr2]] \n" /* load + a00-a015 to + q0*/ + "ld1 {v9.8b}, [%[din_ptr3]] \n" /* load + a00-a015 to + q0*/ + + "smlal v19.8h, %[v6].8b, v4.8b \n" /* outr00 += 01234567 * + w11 */ + + "saddw v10.4s, v10.4s, v18.4h \n" /* v10 += outr00.low*/ + "saddw2 v11.4s, v11.4s, v18.8h \n" /* v11 += outr00.high*/ + + "stp q10, q11, [%[ptr_out0]], #32 \n" /* store q10, q11 -> + ptr_out */ + + "saddw v12.4s, v12.4s, v19.4h \n" /* v10 += outr00.low*/ + "saddw2 v13.4s, v13.4s, v19.8h \n" /* v11 += outr00.high*/ + + "smull v19.8h, %[v8].8b, v5.8b \n" /* outr00 += 01234567 * + w11 */ + + "ld1 {v10.4s}, [%[bias_val]] \n" /* dup v10, bias */ + "ld1 {v11.4s}, [%[bias_val]] \n" /* dup v10, bias */ + + "saddw v12.4s, v12.4s, v19.4h \n" /* v10 += outr00.low*/ + "saddw2 v13.4s, v13.4s, v19.8h \n" /* v11 += outr00.high*/ + + "stp q12, q13, [%[ptr_out1]], #32 \n" /* store q10, q11 -> + ptr_out */ + + "ld1 {v12.4s}, [%[bias_val]] \n" /* dup v10, bias */ + "ld1 {v13.4s}, [%[bias_val]] \n" /* dup v10, bias */ + + "cmp %[cnt], #1 \n" + "blt 3f \n" + // mid + "1: \n" + "ext v4.8b, v0.8B, v1.8b, #1 \n" /*12345678 */ + "ext v5.8b, v0.8b, v1.8B, #2 \n" /*23456789 */ + + // r0 + "smull v18.8h, %[v0].8b, v0.8b \n" /* outr00 = 01234567 * w00 + */ + + "ext v14.8b, v2.8B, v3.8b, #1 \n" /*12345678 */ + "ext v15.8b, v2.8b, v3.8B, #2 \n" /*23456789 */ + + "smlal v18.8h, %[v1].8b, v4.8b\n" /* outr00 += 12345678 * w01 */ + + "ext v16.8b, v6.8B, v7.8b, #1 \n" /*12345678 */ + "ext v17.8b, v6.8b, v7.8B, #2 \n" /*23456789 */ + + "smlal v18.8h, %[v2].8b, v5.8b\n" /* outr00 += 23456789 * w02 */ + + // r1 + "ext v4.8b, v8.8B, v9.8b, #1 \n" /*12345678 */ + "ext v5.8b, v8.8b, v9.8B, #2 \n" /*23456789 */ + + "smull v19.8h, %[v0].8b, v2.8b \n" /* outr00 = 01234567 * w00 + */ + "smlal v18.8h, %[v3].8b, v2.8b \n" /* outr00 = 01234567 * w00 + */ + + "ld1 {v0.8b}, [%[din_ptr0]], #8 \n" /* load + a00-a015 + to 
+ q0*/ + "ld1 {v2.8b}, [%[din_ptr1]], #8 \n" /* load + a00-a015 + to + q0*/ + + "smlal v19.8h, %[v1].8b, v14.8b\n" /* outr00 += 12345678 * w01 */ + "smlal v18.8h, %[v4].8b, v14.8b\n" /* outr00 += 12345678 * w01 */ + + "ld1 {v1.8b}, [%[din_ptr0]] \n" /* load + a00-a015 + to q0*/ + "ld1 {v3.8b}, [%[din_ptr1]] \n" /* load + a00-a015 + to q0*/ + + "smlal v19.8h, %[v2].8b, v15.8b\n" /* outr00 += 23456789 * w02 */ + "smlal v18.8h, %[v5].8b, v15.8b\n" /* outr00 += 12345678 * w01 */ + + // r2 + "smlal v19.8h, %[v3].8b, v6.8b \n" /* outr00 = 01234567 * w00 + */ + "smlal v18.8h, %[v6].8b, v6.8b \n" /* outr00 = 01234567 * w00 + */ + + "smlal v19.8h, %[v4].8b, v16.8b\n" /* outr00 += 12345678 * w01 */ + "smlal v18.8h, %[v7].8b, v16.8b\n" /* outr00 += 12345678 * w01 */ + + "smlal v19.8h, %[v5].8b, v17.8b\n" /* outr00 += 23456789 * w02 */ + + "saddw v10.4s, v10.4s, v18.4h \n" /* v10 += outr00.low*/ + "saddw2 v11.4s, v11.4s, v18.8h \n" /* v11 += outr00.high*/ + + "smull v18.8h, %[v8].8b, v17.8b\n" /* outr00 += 12345678 * w01 */ + + // r3 + "smlal v19.8h, %[v6].8b, v8.8b \n" /* outr00 = 01234567 * w00 + */ + + "ld1 {v6.8b}, [%[din_ptr2]], #8 \n" /* load + a00-a015 + to + q0*/ + "ld1 {v8.8b}, [%[din_ptr3]], #8 \n" /* load + a00-a015 + to + q0*/ + + "saddw v10.4s, v10.4s, v18.4h \n" /* v10 += outr00.low*/ + "saddw2 v11.4s, v11.4s, v18.8h \n" /* v11 += outr00.high*/ + + "smlal v19.8h, %[v7].8b, v4.8b\n" /* outr00 += 12345678 * w01 */ + + "ld1 {v7.8b}, [%[din_ptr2]] \n" /* load + a00-a015 + to q0*/ + "ld1 {v9.8b}, [%[din_ptr3]] \n" /* load + a00-a015 + to q0*/ + + "stp q10, q11, [%[ptr_out0]], #32 \n" /* store q10, q11 -> + ptr_out */ + + "saddw v12.4s, v12.4s, v19.4h \n" /* v10 += outr00.low*/ + "saddw2 v13.4s, v13.4s, v19.8h \n" /* v11 += outr00.high*/ + + "smull v19.8h, %[v8].8b, v5.8b\n" /* outr00 += 23456789 * w02 */ + + "ld1 {v10.4s}, [%[bias_val]] \n" /* dup v10, bias */ + "ld1 {v11.4s}, [%[bias_val]] \n" /* dup v10, bias */ + + "saddw v12.4s, v12.4s, v19.4h \n" /* v10 += outr00.low*/ + "saddw2 v13.4s, v13.4s, v19.8h \n" /* v11 += outr00.high*/ + + "subs %[cnt], %[cnt], #1 \n" + + "stp q12, q13, [%[ptr_out1]], #32 \n" /* store q10, q11 -> + ptr_out */ + + "ld1 {v12.4s}, [%[bias_val]] \n" /* dup v10, bias */ + "ld1 {v13.4s}, [%[bias_val]] \n" /* dup v10, bias */ + + "bne 1b \n" + // right + "3: \n" + "ld1 {v14.8b}, [%[vmask]], #8 \n" + "ld1 {v15.8b}, [%[vmask]] \n" + + "bif v0.8b, v21.8b, v14.8b \n" + "bif v1.8b, v21.8b, v15.8b \n" + "bif v2.8b, v21.8b, v14.8b \n" + "bif v3.8b, v21.8b, v15.8b \n" + + "ext v4.8b, v0.8b, v1.8b, #1 \n" + "ext v5.8b, v0.8b, v1.8b, #2 \n" + + // r0 + "smull v18.8h, %[v0].8b, v0.8b \n" /* outr00 = 01234567 * w00 + */ + + "ext v16.8b, v2.8b, v3.8b, #1 \n" + "ext v17.8b, v2.8b, v3.8b, #2 \n" + + "bif v6.8b, v21.8b, v14.8b \n" + "bif v7.8b, v21.8b, v15.8b \n" + + "smlal v18.8h, %[v1].8b, v4.8b \n" /* outr00 = 01234567 * w00 + */ + + "bif v8.8b, v21.8b, v14.8b \n" + "bif v9.8b, v21.8b, v15.8b \n" + + "ext v20.8b, v6.8b, v7.8b, #1 \n" + "ext v22.8b, v6.8b, v7.8b, #2 \n" + + "smlal v18.8h, %[v2].8b, v5.8b \n" /* outr00 = 01234567 * w00 + */ + + // r1 + "ext v4.8b, v8.8b, v9.8b, #1 \n" + "ext v5.8b, v8.8b, v9.8b, #2 \n" + + "smull v19.8h, %[v0].8b, v2.8b \n" /* outr00 = 01234567 * w00 + */ + "smlal v18.8h, %[v3].8b, v2.8b \n" /* outr00 = 01234567 * w00 + */ + + "ld1 {v14.4s}, [%[rmask]], #16 \n" + "ld1 {v15.4s}, [%[rmask]] \n" + + "smlal v19.8h, %[v1].8b, v16.8b \n" /* outr00 = 01234567 * w00 + */ + "smlal v18.8h, %[v4].8b, v16.8b \n" /* outr00 = 01234567 * w00 + */ + + "ld1 {v0.4s}, 
[%[ptr_out0]], #16 \n" + "ld1 {v2.4s}, [%[ptr_out1]], #16 \n" + + "smlal v19.8h, %[v2].8b, v17.8b \n" /* outr00 = 01234567 * w00 + */ + "smlal v18.8h, %[v5].8b, v17.8b \n" /* outr00 = 01234567 * w00 + */ + + "ld1 {v1.4s}, [%[ptr_out0]] \n" + "ld1 {v3.4s}, [%[ptr_out1]] \n" + + // r2 + "smlal v19.8h, %[v3].8b, v6.8b \n" /* outr00 = 01234567 * w00 + */ + "smlal v18.8h, %[v6].8b, v6.8b \n" /* outr00 = 01234567 * w00 + */ + + "sub %[ptr_out0], %[ptr_out0], #16 \n" + "sub %[ptr_out1], %[ptr_out1], #16 \n" + + "smlal v19.8h, %[v4].8b, v20.8b \n" /* outr00 = 01234567 * w00 + */ + "smlal v18.8h, %[v7].8b, v20.8b \n" /* outr00 = 01234567 * w00 + */ + + "smlal v19.8h, %[v5].8b, v22.8b \n" /* outr00 = 01234567 * w00 + */ + + "saddw v10.4s, v10.4s, v18.4h \n" /* v10 += outr00.low*/ + "saddw2 v11.4s, v11.4s, v18.8h \n" /* v11 += outr00.high*/ + + "smull v18.8h, %[v8].8b, v22.8b \n" /* outr00 = 01234567 * w00 + */ + + // r3 + "smlal v19.8h, %[v6].8b, v8.8b \n" /* outr00 = 01234567 * w00 + */ + + "saddw v10.4s, v10.4s, v18.4h \n" /* v10 += outr00.low*/ + "saddw2 v11.4s, v11.4s, v18.8h \n" /* v11 += outr00.high*/ + + "smlal v19.8h, %[v7].8b, v4.8b \n" /* outr00 = 01234567 * w00 + */ + + "bif v10.16b, v0.16b, v14.16b \n" + "bif v11.16b, v1.16b, v15.16b \n" + + "saddw v12.4s, v12.4s, v19.4h \n" /* v10 += outr00.low*/ + "saddw2 v13.4s, v13.4s, v19.8h \n" /* v11 += outr00.high*/ + + "smull v19.8h, %[v8].8b, v5.8b \n" /* outr00 = 01234567 * w00 + */ + + "stp q10, q11, [%[ptr_out0]], #32 \n" /* store q10, q11 -> + ptr_out */ + + "saddw v12.4s, v12.4s, v19.4h \n" /* v10 += outr00.low*/ + "saddw2 v13.4s, v13.4s, v19.8h \n" /* v11 += outr00.high*/ + + "bif v12.16b, v2.16b, v14.16b \n" + "bif v13.16b, v3.16b, v15.16b \n" + + "stp q12, q13, [%[ptr_out1]], #32 \n" /* store q10, q11 -> + ptr_out */ + + : [cnt] "+r"(cnt), + [din_ptr0] "+r"(din_ptr0), + [din_ptr1] "+r"(din_ptr1), + [din_ptr2] "+r"(din_ptr2), + [din_ptr3] "+r"(din_ptr3), + [ptr_out0] "+r"(doutr0), + [ptr_out1] "+r"(doutr1), + [vmask] "+r"(val_mask), + [rmask] "+r"(rst_mask) + : [v0] "w"(wr00), + [v1] "w"(wr01), + [v2] "w"(wr02), + [v3] "w"(wr10), + [bias_val] "r"(vbias), + [v4] "w"(wr11), + [v5] "w"(wr12), + [v6] "w"(wr20), + [v7] "w"(wr21), + [v8] "w"(wr22) + : "cc", + "memory", + "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15", + "v16", + "v17", + "v18", + "v19", + "v20", + "v21", + "v22"); +#else + // store weights + asm volatile("vld1.8 {d0-d1}, [%[wei_ptr]] \n" + : + : [wei_ptr] "r"(wei_ptr) + : "memory"); + asm volatile( + // left + "pld [%[din_ptr0]] @ preload data\n" + "pld [%[din_ptr1]] @ preload data\n" + "pld [%[din_ptr2]] @ preload data\n" + "pld [%[din_ptr3]] @ preload data\n" + "vdup.s8 d2, d0[0] @ d2 = w00, w00, w00, w00\n" + "vdup.s8 d3, d0[1] @ d3 = w01, w01, w01, w01\n" + "vdup.s8 d4, d0[2] @ d4 = w02, w02, w02, w02\n" + "vld1.8 {d12-d13}, [%[din_ptr0]] @ load din00= 0 1 2 3 4 5 6 7 " + "8 9\n" + "vmov.u32 d11, #0 @ zero\n" + // out0 + "vdup.32 q8, %[bias] @ and \n" // q8 = + // vbias + "vdup.32 q9, %[bias] @ and \n" // q9 = + // vbias + // out1 + "vdup.32 q10, %[bias] @ and \n" // q8 = + // vbias + "vdup.32 q11, %[bias] @ and \n" // q9 = + // vbias + + // r0 + "vmull.s8 q12, d12, d3 @ out0 = din0 * w01 \n" // q12 = d12 * w01 + "vext.8 d30, d11, d12, #7 @ ext \n" // d10 = 00123456 + "vext.8 d31, d12, d13, #1 @ ext \n" // d11 = 12345678 + + "vld1.8 {d12-d13}, [%[din_ptr1]] @ load din00= 0 1 2 3 4 5 6 7 " + "8 9\n" + "vld1.8 {d14-d15}, [%[din_ptr2]] @ load 
din00= 0 1 2 3 4 5 6 7 " + "8 9\n" + "vdup.s8 d5, d0[3] @ d5 = w10, w10, w00, w00\n" + "vdup.s8 d6, d0[4] @ d6 = w11, w11, w01, w01\n" + + "vmlal.s8 q12, d30, d2 @ out0 += din0 * w00 \n" // q12 += d10 * w00 + + "vdup.s8 d7, d0[5] @ d7 = w12, w12\n" + "add %[din_ptr0], #7 @add \n" + "add %[din_ptr1], #7 @add \n" + + "vmlal.s8 q12, d31, d4 @ out0 += din0 * w02 \n" // q12 += d11 * w02 + + // r1 + "vext.8 d30, d11, d12, #7 @ ext \n" // d10 = 00123456 + "vext.8 d31, d12, d13, #1 @ ext \n" // d11 = 12345678 + "vmull.s8 q13, d12, d3 @ out1 = din1 * w01 \n" // q13 = d12 * w01 + "vaddw.s16 q8, q8, d24 @addw \n" // out0 += + // vget_low_s16(out00) + "vaddw.s16 q9, q9, d25 @addw \n" // out0_1 += + // vget_high_s16(out00) + "vmull.s8 q12, d12, d6 @ out0 = din1 * w11 \n" // q12 = d12 * w11 + + "vld1.8 {d12-d13}, [%[din_ptr3]] @ load din00= 0 1 2 3 4 5 6 7 " + "8 9\n" + "vdup.s8 d8, d0[6] @ d8 = w20, w00, w00, w00\n" + "vdup.s8 d9, d0[7] @ d9 = w21, w01, w01, w01\n" + "vdup.s8 d10, d1[0] @ d10 = w22, w02, w02, w02\n" + + "vmlal.s8 q13, d30, d2 @ out1 += din1 * w00 \n" // q12 += d10 * w00 + "vmlal.s8 q12, d30, d5 @ out0 += din1 * w10 \n" // q12 += d10 * w00 + + "add %[din_ptr2], #7 @add \n" + "add %[din_ptr3], #7 @add \n" + + "vmlal.s8 q13, d31, d4 @ out1 += din1 * w02 \n" // q12 += d10 * w00 + "vmlal.s8 q12, d31, d7 @ out0 += din1 * w12 \n" // q12 += d10 * w00 + + // r2 + "vext.8 d30, d11, d14, #7 @ ext \n" // d10 = 00123456 + "vext.8 d31, d14, d15, #1 @ ext \n" // d11 = 12345678 + "vaddw.s16 q10, q10, d26 @addw \n" // out1 += + // vget_low_s16(out10) + "vaddw.s16 q11, q11, d27 @addw \n" // out1_1 += + // vget_high_s16(out10) + "vaddw.s16 q8, q8, d24 @addw \n" // out0 += + // vget_low_s16(out00) + "vaddw.s16 q9, q9, d25 @addw \n" // out0_1 += + // vget_high_s16(out00) + + "vmull.s8 q13, d14, d6 @ out1 = din2 * w11 \n" // q13 = d12 * w01 + "vmull.s8 q12, d14, d9 @ out1 = din2 * w21 \n" // q13 = d12 * w01 + + "vmlal.s8 q13, d30, d5 @ out1 += din2 * w10 \n" // q12 += d10 * w00 + "vmlal.s8 q12, d30, d8 @ out0 += din2 * w20 \n" // q12 += d10 * w00 + + "vmlal.s8 q13, d31, d7 @ out1 += din2 * w12 \n" // q12 += d10 * w00 + "vmlal.s8 q12, d31, d10 @ out0 += din2 * w22 \n" // q12 += d10 * w00 + + // r3 + "vext.8 d30, d11, d12, #7 @ ext \n" // d10 = 00123456 + "vext.8 d31, d12, d13, #1 @ ext \n" // d11 = 12345678 + "vaddw.s16 q10, q10, d26 @addw \n" // out1 += + // vget_low_s16(out10) + "vaddw.s16 q11, q11, d27 @addw \n" // out1_1 += + // vget_high_s16(out10) + "vaddw.s16 q8, q8, d24 @addw \n" // out0 += + // vget_low_s16(out00) + "vaddw.s16 q9, q9, d25 @addw \n" // out0_1 += + // vget_high_s16(out00) + + "vmull.s8 q13, d12, d9 @ out1 = din3 * w21 \n" // q13 = d12 * w01 + "pld [%[din_ptr0]] @ preload data\n" + "pld [%[din_ptr1]] @ preload data\n" + + "vmlal.s8 q13, d30, d8 @ out1 += din3 * w20 \n" // q13 += d10 * w00 + "pld [%[din_ptr2]] @ preload data\n" + "pld [%[din_ptr3]] @ preload data\n" + + "vst1.32 {d16-d17}, [%[dout_ptr1]]! @ store\n" + + "vmlal.s8 q13, d31, d10 @ out1 += din3 * w22 \n" // q12 += d10 * w00 + + "vst1.32 {d18-d19}, [%[dout_ptr1]]! @ store\n" + "vaddw.s16 q10, q10, d26 @addw \n" // out1 += + // vget_low_s16(out10) + "vaddw.s16 q11, q11, d27 @addw \n" // out1_1 += + // vget_high_s16(out10) + "vst1.32 {d20-d21}, [%[dout_ptr2]]! @ store\n" + "cmp %[cnt], #1 \n" + "vst1.32 {d22-d23}, [%[dout_ptr2]]! 
@ store\n" + "blt 1f \n" + + // mid + "2: \n" + "vld1.8 {d12-d13}, [%[din_ptr0]] @ load din00= 0 1 2 3 4 5 6 7 " + "8 9\n" + // out0 + "vdup.32 q8, %[bias] @ and \n" // q8 = + // vbias + "vdup.32 q9, %[bias] @ and \n" // q9 = + // vbias + // out1 + "vdup.32 q10, %[bias] @ and \n" // q8 = + // vbias + "vdup.32 q11, %[bias] @ and \n" // q9 = + // vbias + + // r0 + "vmull.s8 q12, d12, d2 @ out0 = din0 * w01 \n" // q12 = d12 * w01 + "vext.8 d30, d12, d13, #1 @ ext \n" // d10 = 12345678 + "vext.8 d31, d12, d13, #2 @ ext \n" // d11 = 23456789 + + "vld1.8 {d12-d13}, [%[din_ptr1]] @ load din00= 0 1 2 3 4 5 6 7 " + "8 9\n" + "vld1.8 {d14-d15}, [%[din_ptr2]] @ load din00= 0 1 2 3 4 5 6 7 " + "8 9\n" + + "vmlal.s8 q12, d30, d3 @ out0 += din0 * w00 \n" // q12 += d10 * w00 + + "add %[din_ptr0], #8 @add \n" + "add %[din_ptr1], #8 @add \n" + + "vmlal.s8 q12, d31, d4 @ out0 += din0 * w02 \n" // q12 += d11 * w02 + + // r1 + "vext.8 d30, d12, d13, #1 @ ext \n" // d10 = 00123456 + "vext.8 d31, d12, d13, #2 @ ext \n" // d11 = 12345678 + "vmull.s8 q13, d12, d2 @ out1 = din1 * w01 \n" // q13 = d12 * w01 + "vaddw.s16 q8, q8, d24 @addw \n" // out0 += + // vget_low_s16(out00) + "vaddw.s16 q9, q9, d25 @addw \n" // out0_1 += + // vget_high_s16(out00) + "vmull.s8 q12, d12, d5 @ out0 = din1 * w11 \n" // q12 = d12 * w11 + + "vld1.8 {d12-d13}, [%[din_ptr3]] @ load din00= 0 1 2 3 4 5 6 7 " + "8 9\n" + + "vmlal.s8 q13, d30, d3 @ out1 += din1 * w00 \n" // q12 += d10 * w00 + "vmlal.s8 q12, d30, d6 @ out0 += din1 * w10 \n" // q12 += d10 * w00 + + "add %[din_ptr2], #8 @add \n" + "add %[din_ptr3], #8 @add \n" + + "vmlal.s8 q13, d31, d4 @ out1 += din1 * w02 \n" // q12 += d10 * w00 + "vmlal.s8 q12, d31, d7 @ out0 += din1 * w12 \n" // q12 += d10 * w00 + + // r2 + "vext.8 d30, d14, d15, #1 @ ext \n" // d10 = 00123456 + "vext.8 d31, d14, d15, #2 @ ext \n" // d11 = 12345678 + "vaddw.s16 q10, q10, d26 @addw \n" // out1 += + // vget_low_s16(out10) + "vaddw.s16 q11, q11, d27 @addw \n" // out1_1 += + // vget_high_s16(out10) + "vaddw.s16 q8, q8, d24 @addw \n" // out0 += + // vget_low_s16(out00) + "vaddw.s16 q9, q9, d25 @addw \n" // out0_1 += + // vget_high_s16(out00) + + "vmull.s8 q13, d14, d5 @ out1 = din2 * w11 \n" // q13 = d12 * w01 + "vmull.s8 q12, d14, d8 @ out1 = din2 * w21 \n" // q13 = d12 * w01 + + "vmlal.s8 q13, d30, d6 @ out1 += din2 * w10 \n" // q12 += d10 * w00 + "vmlal.s8 q12, d30, d9 @ out0 += din2 * w20 \n" // q12 += d10 * w00 + + "vmlal.s8 q13, d31, d7 @ out1 += din2 * w12 \n" // q12 += d10 * w00 + "vmlal.s8 q12, d31, d10 @ out0 += din2 * w22 \n" // q12 += d10 * w00 + + // r3 + "vext.8 d30, d12, d13, #1 @ ext \n" // d10 = 00123456 + "vext.8 d31, d12, d13, #2 @ ext \n" // d11 = 12345678 + "vaddw.s16 q10, q10, d26 @addw \n" // out1 += + // vget_low_s16(out10) + "vaddw.s16 q11, q11, d27 @addw \n" // out1_1 += + // vget_high_s16(out10) + "vaddw.s16 q8, q8, d24 @addw \n" // out0 += + // vget_low_s16(out00) + "vaddw.s16 q9, q9, d25 @addw \n" // out0_1 += + // vget_high_s16(out00) + + "vmull.s8 q13, d12, d8 @ out1 = din3 * w21 \n" // q13 = d12 * w01 + "pld [%[din_ptr0]] @ preload data\n" + "pld [%[din_ptr1]] @ preload data\n" + + "vmlal.s8 q13, d30, d9 @ out1 += din3 * w20 \n" // q13 += d10 * w00 + "pld [%[din_ptr2]] @ preload data\n" + "pld [%[din_ptr3]] @ preload data\n" + + "vst1.32 {d16-d17}, [%[dout_ptr1]]! @ store\n" + + "vmlal.s8 q13, d31, d10 @ out1 += din3 * w22 \n" // q12 += d10 * w00 + + "vst1.32 {d18-d19}, [%[dout_ptr1]]! 
@ store\n" + "vaddw.s16 q10, q10, d26 @addw \n" // out1 += + // vget_low_s16(out10) + "vaddw.s16 q11, q11, d27 @addw \n" // out1_1 += + // vget_high_s16(out10) + "vst1.32 {d20-d21}, [%[dout_ptr2]]! @ store\n" + "subs %[cnt], #1 \n" + "vst1.32 {d22-d23}, [%[dout_ptr2]]! @ store\n" + "bne 2b \n" + // right + "1: \n" + "vld1.8 {d12-d13}, [%[din_ptr0]] @ load din00= 0 1 2 3 4 5 6 7 " + "8 9\n" + "vld1.8 {d28-d29}, [%[mask]] @ load din00= 0 1 2 3 4 5 6 7 " + "8 9\n" + // out0 + "vdup.32 q8, %[bias] @ and \n" // q8 = vbias + "vdup.32 q9, %[bias] @ and \n" // q9 = vbias + // out1 + "vdup.32 q10, %[bias] @ and \n" // q8 = vbias + "vdup.32 q11, %[bias] @ and \n" // q9 = vbias + + "vbif.8 d12, d11, d28 @ bit select, deal with right pad\n" + "vbif.8 d13, d11, d29 @ bit select, deal with right pad\n" + "vld1.8 {d14-d15}, [%[din_ptr1]] @ load din00= 0 1 2 3 4 5 6 7 " + "8 9\n" + + // r0 + "vmull.s8 q12, d12, d2 @ out0 = din0 * w00 \n" // q12 = d12 * w01 + "vext.8 d30, d12, d13, #1 @ ext \n" // d10 = 12345678 + "vext.8 d31, d12, d13, #2 @ ext \n" // d11 = 23456789 + + "vld1.8 {d12-d13}, [%[din_ptr2]] @ load din00= 0 1 2 3 4 5 6 7 " + "8 9\n" + "vbif.8 d14, d11, d28 @ bit select, deal with right pad\n" + "vbif.8 d15, d11, d29 @ bit select, deal with right pad\n" + + "vmlal.s8 q12, d30, d3 @ out0 += din0 * w01 \n" // q12 += d10 * w00 + + "vmlal.s8 q12, d31, d4 @ out0 += din0 * w02 \n" // q12 += d11 * w02 + + // r1 + "vext.8 d30, d14, d15, #1 @ ext \n" // d10 = 00123456 + "vext.8 d31, d14, d15, #2 @ ext \n" // d11 = 12345678 + + "vmull.s8 q13, d14, d2 @ out1 = din1 * w00 \n" // q13 = d12 * w01 + "vaddw.s16 q8, q8, d24 @addw \n" // out0 += + // vget_low_s16(out00) + "vaddw.s16 q9, q9, d25 @addw \n" // out0_1 += + // vget_high_s16(out00) + "vmull.s8 q12, d14, d5 @ out0 = din1 * w10 \n" // q12 = d12 * w11 + + "vld1.8 {d14-d15}, [%[din_ptr3]] @ load din00= 0 1 2 3 4 5 6 7 " + "8 9\n" + "vbif.8 d12, d11, d28 @ bit select, deal with " + "right pad\n" + "vbif.8 d13, d11, d29 @ bit select, deal with " + "right pad\n" + + "vmlal.s8 q13, d30, d3 @ out1 += din1 * w01 \n" // q12 += d10 * w00 + "vmlal.s8 q12, d30, d6 @ out0 += din1 * w11 \n" // q12 += d10 * w00 + + "vmlal.s8 q13, d31, d4 @ out1 += din1 * w02 \n" // q12 += d10 * w00 + "vmlal.s8 q12, d31, d7 @ out0 += din1 * w12 \n" // q12 += d10 * w00 + + // r2 + "vext.8 d30, d12, d13, #1 @ ext \n" // d10 = 00123456 + "vext.8 d31, d12, d13, #2 @ ext \n" // d11 = 12345678 + + "vaddw.s16 q10, q10, d26 @addw \n" // out1 += + // vget_low_s16(out10) + "vaddw.s16 q11, q11, d27 @addw \n" // out1_1 += + // vget_high_s16(out10) + "vaddw.s16 q8, q8, d24 @addw \n" // out0 += + // vget_low_s16(out00) + "vaddw.s16 q9, q9, d25 @addw \n" // out0_1 += + // vget_high_s16(out00) + + "vmull.s8 q13, d12, d5 @ out1 = din2 * w10 \n" // q13 = d12 * w01 + "vmull.s8 q12, d12, d8 @ out1 = din2 * w20 \n" // q13 = d12 * w01 + + "vbif.8 d14, d11, d28 @ bit select, deal with " + "right pad\n" + "vbif.8 d15, d11, d29 @ bit select, deal with " + "right pad\n" + + "vmlal.s8 q13, d30, d6 @ out1 += din2 * w10 \n" // q12 += d10 * w00 + "vmlal.s8 q12, d30, d9 @ out0 += din2 * w20 \n" // q12 += d10 * w00 + + "vld1.32 {d28-d29}, [%[dout_ptr1]]! @ load din00= 0 1 2 3 4 5 6 " + "7 8 9\n" + "vld1.32 {d12-d13}, [%[dout_ptr1]] @ load din00= 0 1 2 3 4 5 6 " + "7 8 9\n" + "vld1.32 {d2-d3}, [%[rs_mask]]! 
@ load din00= 0 1 2 3 4 5 6 7 " + "8 9\n" + "vld1.32 {d4-d5}, [%[rs_mask]] @ load din00= 0 1 2 3 4 5 6 7 8 " + "9\n" + + "vmlal.s8 q13, d31, d7 @ out1 += din2 * w12 \n" // q12 += d10 * w00 + "vmlal.s8 q12, d31, d10 @ out0 += din2 * w22 \n" // q12 += d10 * w00 + + // r3 + "vext.8 d30, d14, d15, #1 @ ext \n" // d10 = 00123456 + "vext.8 d31, d14, d15, #2 @ ext \n" // d11 = 12345678 + "vaddw.s16 q10, q10, d26 @addw \n" // out1 += + // vget_low_s16(out10) + "vaddw.s16 q11, q11, d27 @addw \n" // out1_1 += + // vget_high_s16(out10) + "vaddw.s16 q8, q8, d24 @addw \n" // out0 += + // vget_low_s16(out00) + "vaddw.s16 q9, q9, d25 @addw \n" // out0_1 += + // vget_high_s16(out00) + + "vmull.s8 q13, d14, d8 @ out1 = din3 * w20 \n" // q13 = d12 * w01 + "sub %[dout_ptr1], #16 @ sub \n" + "vld1.32 {d14-d15}, [%[dout_ptr2]]! @ load din00= 0 1 2 3 4 5 6 " + "7 8 9\n" + "vld1.32 {d24-d25}, [%[dout_ptr2]] @ load din00= 0 1 2 3 4 5 6 " + "7 8 9\n" + + "vmlal.s8 q13, d30, d9 @ out1 += din3 * w21 \n" // q13 += d10 * w00 + "vbif q8, q14, q1 @ bit select, deal with right " + "pad\n" + "vbif q9, q6, q2 @ bit select, deal with right " + "pad\n" + "sub %[dout_ptr2], #16 @ sub \n" + + "vmlal.s8 q13, d31, d10 @ out1 += din3 * w22 \n" // q12 += d10 * w00 + + "vst1.32 {d16-d17}, [%[dout_ptr1]]! @ store\n" + "vst1.32 {d18-d19}, [%[dout_ptr1]]! @ store\n" + "vaddw.s16 q10, q10, d26 @addw \n" // out1 += + // vget_low_s16(out10) + "vaddw.s16 q11, q11, d27 @addw \n" // out1_1 += + // vget_high_s16(out10) + + "vbif q10, q7, q1 @ bit select, deal with right pad\n" + "vbif q11, q12, q2 @ bit select, deal with right pad\n" + + "vst1.32 {d20-d21}, [%[dout_ptr2]]! @ store\n" + "vst1.32 {d22-d23}, [%[dout_ptr2]]! @ store\n" + + : [din_ptr0] "+r"(din_ptr0), + [din_ptr1] "+r"(din_ptr1), + [din_ptr2] "+r"(din_ptr2), + [din_ptr3] "+r"(din_ptr3), + [dout_ptr1] "+r"(doutr0), + [dout_ptr2] "+r"(doutr1), + [cnt] "+r"(cnt), + [bias] "+r"(bias_val), + [rs_mask] "+r"(rst_mask) + : [mask] "r"(vmask) + : "cc", + "memory", + "q0", + "q1", + "q2", + "q3", + "q4", + "q5", + "q6", + "q7", + "q8", + "q9", + "q10", + "q11", + "q12", + "q13", + "q14", + "q15"); +#endif + dout_ptr += 2 * w_out; + } + } + } +} + +// w_in <= 8 +void conv_depthwise_3x3s1p1_bias_s_int7(int* dout, + const signed char* din, + const signed char* weights, + const int* bias, + bool flag_bias, + const int num, + const int ch_in, + const int h_in, + const int w_in, + const int h_out, + const int w_out, + ARMContext* ctx) { + // printf("3x3s1 mult height \n"); + const char zero[8] = {0, 0, 0, 0, 0, 0, 0, 0}; + //! 
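single-output-block variant: w_in <= 8, so one masked 8-byte load covers the whole row; the two result rows are written to out_buf1/out_buf2 and copied to dout element by element after the asm block. + //!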
for 4x6 convolution window + const unsigned char right_pad_idx[8] = {0, 1, 2, 3, 4, 5, 6, 7}; + const unsigned int right_pad_rst[8] = {0, 1, 2, 3, 4, 5, 6, 7}; + + // printf("conv3x3_dw start \n"); + signed char* zero_ptr = ctx->workspace_data<signed char>(); + memset(zero_ptr, 0, w_in * sizeof(signed char)); + int* write_ptr = + reinterpret_cast<int*>(ctx->workspace_data<signed char>()) + w_in; + int size_in_channel = w_in * h_in; + int size_out_channel = w_out * h_out; + int w_stride = 9; + + int tile_h = (h_out + 1) >> 1; + + unsigned int size_pad_right = (unsigned int)(w_in); + + uint8x8_t vmask_rp = + vcgt_u8(vdup_n_u8(size_pad_right), vld1_u8(right_pad_idx)); + // uint8x8_t vmask_rp2 = vcgt_u8(vdup_n_u8(size_pad_right), + // vld1_u8(right_pad_idx + 8)); + unsigned char vmask[8]; + vst1_u8(vmask, vmask_rp); + + unsigned int rst_remain = (unsigned int)w_out; + uint32x4_t vmask_result1 = + vcgtq_u32(vdupq_n_u32(rst_remain), vld1q_u32(right_pad_rst)); + uint32x4_t vmask_result2 = + vcgtq_u32(vdupq_n_u32(rst_remain), vld1q_u32(right_pad_rst + 4)); + + unsigned int rmask[8]; + vst1q_u32(rmask, vmask_result1); + vst1q_u32(rmask + 4, vmask_result2); + + int8x8_t vzero = vdup_n_s8(0); + int32x4_t vzero_32 = vdupq_n_s32(0); + + for (int n = 0; n < num; ++n) { + const signed char* din_batch = din + n * ch_in * size_in_channel; + int* dout_batch = dout + n * ch_in * size_out_channel; +#pragma omp parallel for + for (int c = 0; c < ch_in; c++) { + int* dout_ptr = dout_batch + c * size_out_channel; + + const signed char* din_ch_ptr = din_batch + c * size_in_channel; + + int bias_val = flag_bias ? bias[c] : 0; + + const signed char* wei_ptr = weights + c * w_stride; +#ifdef __aarch64__ + int vbias[4] = {bias_val, bias_val, bias_val, bias_val}; + int8x8_t wr00 = vdup_n_s8(wei_ptr[0]); + int8x8_t wr10 = vdup_n_s8(wei_ptr[3]); + int8x8_t wr20 = vdup_n_s8(wei_ptr[6]); + + int8x8_t wr01 = vdup_n_s8(wei_ptr[1]); + int8x8_t wr11 = vdup_n_s8(wei_ptr[4]); + int8x8_t wr21 = vdup_n_s8(wei_ptr[7]); + + int8x8_t wr02 = vdup_n_s8(wei_ptr[2]); + int8x8_t wr12 = vdup_n_s8(wei_ptr[5]); + int8x8_t wr22 = vdup_n_s8(wei_ptr[8]); +#endif + int* doutr0 = nullptr; + int* doutr1 = nullptr; + + const signed char* dr0 = din_ch_ptr; + const signed char* dr1 = dr0 + w_in; + const signed char* dr2 = dr1 + w_in; + const signed char* dr3 = dr2 + w_in; + + const signed char* din_ptr0 = nullptr; + const signed char* din_ptr1 = nullptr; + const signed char* din_ptr2 = nullptr; + const signed char* din_ptr3 = nullptr; + + for (int i = 0; i < h_in; i += 2) { + //! process top pad pad_h = 1 + din_ptr0 = dr0; + din_ptr1 = dr1; + din_ptr2 = dr2; + din_ptr3 = dr3; + + doutr0 = dout_ptr; + doutr1 = doutr0 + w_out; + unsigned int* rst_mask = rmask; + + int out_buf1[8]; + int out_buf2[8]; + + if (i == 0) { + din_ptr0 = zero_ptr; + din_ptr1 = dr0; + din_ptr2 = dr1; + din_ptr3 = dr2; + dr0 = dr1; + dr1 = dr2; + dr2 = dr3; + dr3 = dr2 + w_in; + } else { + dr0 = dr2; + dr1 = dr3; + dr2 = dr1 + w_in; + dr3 = dr2 + w_in; + } + //! process bottom pad + if (i + 3 > h_in) { + switch (i + 3 - h_in) { + case 3: + din_ptr1 = zero_ptr; + case 2: + din_ptr2 = zero_ptr; + case 1: + din_ptr3 = zero_ptr; + default: + break; + } + } + //!
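rows are computed two at a time; when only one output row remains (odd h_out), the second row is redirected to the write_ptr scratch buffer so the stores stay unconditional. + //!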
process bottom remain + if (i + 2 > h_out) { + doutr1 = write_ptr; + } +#ifdef __aarch64__ + asm volatile( + "PRFM PLDL1KEEP, [%[din_ptr0]] \n" + "PRFM PLDL1KEEP, [%[din_ptr1]] \n" + "PRFM PLDL1KEEP, [%[din_ptr2]] \n" + "PRFM PLDL1KEEP, [%[din_ptr3]] \n" + "movi v21.4s, #0x0\n" /* out0 = 0 */ + // left + "ld1 {v4.8b}, [%[vmask]] \n" + "ld1 {v0.8b}, [%[din_ptr0]], #8 \n" /* load + a00-a015 + to + q0*/ + "ld1 {v1.8b}, [%[din_ptr1]], #8 \n" /* load + a00-a015 + to + q0*/ + "ld1 {v2.8b}, [%[din_ptr2]], #8 \n" /* load + a00-a015 + to + q0*/ + "ld1 {v3.8b}, [%[din_ptr3]], #8 \n" /* load + a00-a015 + to + q0*/ + + "bif v0.8b, v21.8b, v4.8b \n" + "bif v1.8b, v21.8b, v4.8b \n" + "bif v2.8b, v21.8b, v4.8b \n" + "bif v3.8b, v21.8b, v4.8b \n" + + "ext v6.8b, v21.8b, v0.8B, #7 \n" /* vext_s8(vzero, vinr0, 7); + 00123456 */ + "ext v7.8b, v0.8b, v21.8B, #1 \n" /* vext_s8(vinr0, vinr0_1, + 1); 12345678 */ + + "ld1 {v10.4s}, [%[vbias]] \n" + "ld1 {v11.4s}, [%[vbias]] \n" + + // r0 + "smull v18.8h, %[v1].8b, v0.8b \n" /* outr00 = 01234567 * w01 + */ + + "ext v8.8b, v21.8b, v1.8B, #7 \n" /* vext_s8(vzero, vinr0, 7); + 00123456 */ + "ext v9.8b, v1.8b, v21.8B, #1 \n" /* vext_s8(vinr0, vinr0_1, + 1); 12345678 */ + + "smlal v18.8h, %[v0].8b, v6.8b \n" /* outr00 = 01234567 * w00 + */ + + "ld1 {v12.4s}, [%[vbias]] \n" + "ld1 {v13.4s}, [%[vbias]] \n" + + "smlal v18.8h, %[v2].8b, v7.8b \n" /* outr00 = 01234567 * w00 + */ + + "ext v6.8b, v21.8b, v2.8B, #7 \n" /* vext_s8(vzero, vinr0, 7); + 00123456 */ + "ext v7.8b, v2.8b, v21.8B, #1 \n" /* vext_s8(vinr0, vinr0_1, + 1); 12345678 */ + + // r1 + "smull v19.8h, %[v1].8b, v1.8b \n" /* outr00 = 01234567 * w00 + */ + "smlal v18.8h, %[v4].8b, v1.8b \n" /* outr00 = 01234567 * w00 + */ + + // "ld1 {v14.4s}, [%[rmask]], #16 \n" + // "ld1 {v15.4s}, [%[rmask]] \n" + + "smlal v19.8h, %[v0].8b, v8.8b \n" /* outr00 = 01234567 * w00 + */ + "smlal v18.8h, %[v3].8b, v8.8b \n" /* outr00 = 01234567 * w00 + */ + + // "ld1 {v16.4s}, [%[ptr_out0]], #16 \n" + // "ld1 {v17.4s}, [%[ptr_out1]], #16 \n" + + "smlal v19.8h, %[v2].8b, v9.8b \n" /* outr00 = 01234567 * w00 + */ + "smlal v18.8h, %[v5].8b, v9.8b \n" /* outr00 = 01234567 * w00 + */ + + "ext v8.8b, v21.8b, v3.8B, #7 \n" /* vext_s8(vzero, vinr0, 7); + 00123456 */ + "ext v9.8b, v3.8b, v21.8B, #1 \n" // vext_s8(vinr0, vinr0_1, + // 1); 12345678 + + // "ld1 {v0.4s}, [%[ptr_out0]] \n" + // "ld1 {v1.4s}, [%[ptr_out1]] \n" + + // r2 + "smlal v19.8h, %[v4].8b, v2.8b \n" /* outr00 = 01234567 * w00 + */ + "smlal v18.8h, %[v7].8b, v2.8b \n" /* outr00 = 01234567 * w00 + */ + + // "sub %[ptr_out0], %[ptr_out0], #16 \n" + // "sub %[ptr_out1], %[ptr_out1], #16 \n" + + "smlal v19.8h, %[v3].8b, v6.8b \n" /* outr00 = 01234567 * w00 + */ + "smlal v18.8h, %[v6].8b, v6.8b \n" /* outr00 = 01234567 * w00 + */ + + "smlal v19.8h, %[v5].8b, v7.8b \n" /* outr00 = 01234567 * w00 + */ + + "saddw v10.4s, v10.4s, v18.4h \n" /* v10 += outr00.low*/ + "saddw2 v11.4s, v11.4s, v18.8h \n" /* v11 += outr00.high*/ + + "smull v18.8h, %[v8].8b, v7.8b \n" /* outr00 = 01234567 * w00 + */ + + // r3 + "smlal v19.8h, %[v7].8b, v3.8b \n" /* outr00 = 01234567 * w00 + */ + + "saddw v10.4s, v10.4s, v18.4h \n" /* v10 += outr00.low*/ + "saddw2 v11.4s, v11.4s, v18.8h \n" /* v11 += outr00.high*/ + + "smlal v19.8h, %[v6].8b, v8.8b \n" /* outr00 = 01234567 * w00 + */ + + // "bif v10.16b, v16.16b, v14.16b \n" + // "bif v11.16b, v0.16b, v15.16b \n" + + "saddw v12.4s, v12.4s, v19.4h \n" /* v10 += outr00.low*/ + "saddw2 v13.4s, v13.4s, v19.8h \n" /* v11 += outr00.high*/ + + "smull v19.8h, 
%[v8].8b, v9.8b \n" /* outr00 = 01234567 * w00 + */ + + "stp q10, q11, [%[ptr_out0]] \n" /* store q10, q11 -> ptr_out + */ + + "saddw v12.4s, v12.4s, v19.4h \n" /* v10 += outr00.low*/ + "saddw2 v13.4s, v13.4s, v19.8h \n" /* v11 += outr00.high*/ + + // "bif v12.16b, v17.16b, v14.16b \n" + // "bif v13.16b, v1.16b, v15.16b \n" + + "stp q12, q13, [%[ptr_out1]] \n" /* store q10, q11 -> ptr_out */ + + : [din_ptr0] "+r"(din_ptr0), + [din_ptr1] "+r"(din_ptr1), + [din_ptr2] "+r"(din_ptr2), + [din_ptr3] "+r"(din_ptr3), + [rmask] "+r"(rst_mask) + : [v0] "w"(wr00), + [v1] "w"(wr01), + [v2] "w"(wr02), + [v3] "w"(wr10), + [vbias] "r"(vbias), + [v4] "w"(wr11), + [v5] "w"(wr12), + [v6] "w"(wr20), + [v7] "w"(wr21), + [v8] "w"(wr22), + [vmask] "r"(vmask), + [ptr_out0] "r"(out_buf1), + [ptr_out1] "r"(out_buf2) + : "cc", + "memory", + "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15", + "v16", + "v17", + "v18", + "v19", + "v20", + "v21", + "v22"); +#else + // store weights + asm volatile("vld1.8 {d0-d1}, [%[wei_ptr]] \n" + : + : [wei_ptr] "r"(wei_ptr) + : "memory"); + asm volatile( + // left + "pld [%[din_ptr0]] @ preload data\n" + "pld [%[din_ptr1]] @ preload data\n" + "pld [%[din_ptr2]] @ preload data\n" + "pld [%[din_ptr3]] @ preload data\n" + "vld1.8 {d28}, [%[mask]] @ load din00= 0 1 2 3 4 5 6 7 8 9\n" + "vld1.8 {d12}, [%[din_ptr0]] @ load din00= 0 1 2 3 4 5 6 7 8 9\n" + "vld1.8 {d13}, [%[din_ptr1]] @ load din00= 0 1 2 3 4 5 6 7 8 9\n" + "vdup.s8 d2, d0[0] @ d2 = w00, w00, w00, w00\n" + "vdup.s8 d3, d0[1] @ d3 = w01, w01, w01, w01\n" + "vdup.s8 d4, d0[2] @ d4 = w02, w02, w02, w02\n" + + "vmov.u32 d11, #0 @ zero\n" + // out0 + "vdup.32 q8, %[bias] @ and \n" // q8 = + // vbias + "vdup.32 q9, %[bias] @ and \n" // q9 = + // vbias + + "vbif.8 d12, d11, d28 @ bit select, deal with right pad\n" + "vbif.8 d13, d11, d28 @ bit select, deal with right pad\n" + "vld1.8 {d14}, [%[din_ptr2]] @ load din00= 0 1 2 3 4 5 6 7 8 9\n" + "vld1.8 {d15}, [%[din_ptr3]] @ load din00= 0 1 2 3 4 5 6 7 8 9\n" + // out1 + "vdup.32 q10, %[bias] @ and \n" // q8 = + // vbias + "vdup.32 q11, %[bias] @ and \n" // q9 = + // vbias + + // r0 + "vmull.s8 q12, d12, d3 @ out0 = din0 * w01 \n" // q12 = d12 * w01 + "vext.8 d30, d11, d12, #7 @ ext \n" // d10 = 00123456 + "vext.8 d31, d12, d11, #1 @ ext \n" // d11 = 12345678 + + "vdup.s8 d5, d0[3] @ d5 = w10, w10, w00, w00\n" + "vdup.s8 d6, d0[4] @ d6 = w11, w11, w01, w01\n" + + "vmlal.s8 q12, d30, d2 @ out0 += din0 * w00 \n" // q12 += d10 * w00 + + "vdup.s8 d7, d0[5] @ d7 = w12, w12\n" + "vbif.8 d14, d11, d28 @ bit select, deal with right pad\n" + "vbif.8 d15, d11, d28 @ bit select, deal with right pad\n" + + "vmlal.s8 q12, d31, d4 @ out0 += din0 * w02 \n" // q12 += d11 * w02 + + // r1 + "vext.8 d30, d11, d13, #7 @ ext \n" // d10 = 00123456 + "vext.8 d31, d13, d11, #1 @ ext \n" // d11 = 12345678 + "vmull.s8 q13, d13, d3 @ out1 = din1 * w01 \n" // q13 = d12 * w01 + "vaddw.s16 q8, q8, d24 @addw \n" // out0 += + // vget_low_s16(out00) + "vaddw.s16 q9, q9, d25 @addw \n" // out0_1 += + // vget_high_s16(out00) + "vmull.s8 q12, d13, d6 @ out0 = din1 * w11 \n" // q12 = d12 * w11 + + "vdup.s8 d8, d0[6] @ d8 = w20, w00, w00, w00\n" + "vdup.s8 d9, d0[7] @ d9 = w21, w01, w01, w01\n" + + "vmlal.s8 q13, d30, d2 @ out1 += din1 * w00 \n" // q12 += d10 * w00 + "vmlal.s8 q12, d30, d5 @ out0 += din1 * w10 \n" // q12 += d10 * w00 + + "vdup.s8 d10, d1[0] @ d10 = w22, w02, w02, w02\n" + // "vld1.32 {d28-d29}, [%[dout_ptr1]]! 
@ load din00= 0 1 2 3 4 5 + // 6 7 8 9\n" "vld1.32 {d12-d13}, [%[dout_ptr1]] @ load din00= 0 + // 1 2 3 4 5 6 7 8 9\n" + + "vmlal.s8 q13, d31, d4 @ out1 += din1 * w02 \n" // q12 += d10 * w00 + "vmlal.s8 q12, d31, d7 @ out0 += din1 * w12 \n" // q12 += d10 * w00 + + // r2 + "vext.8 d30, d11, d14, #7 @ ext \n" // d10 = 00123456 + "vext.8 d31, d14, d11, #1 @ ext \n" // d11 = 12345678 + "vaddw.s16 q10, q10, d26 @addw \n" // out1 += + // vget_low_s16(out10) + "vaddw.s16 q11, q11, d27 @addw \n" // out1_1 += + // vget_high_s16(out10) + "vaddw.s16 q8, q8, d24 @addw \n" // out0 += + // vget_low_s16(out00) + "vaddw.s16 q9, q9, d25 @addw \n" // out0_1 += + // vget_high_s16(out00) + + "vmull.s8 q13, d14, d6 @ out1 = din2 * w11 \n" // q13 = d12 * w01 + "vmull.s8 q12, d14, d9 @ out1 = din2 * w21 \n" // q13 = d12 * w01 + + // "sub %[dout_ptr1], #16 @ sub \n" + "vmlal.s8 q13, d30, d5 @ out1 += din2 * w10 \n" // q12 += d10 * w00 + "vmlal.s8 q12, d30, d8 @ out0 += din2 * w20 \n" // q12 += d10 * w00 + + // "vld1.32 {d2-d3}, [%[rs_mask]]! @ load din00= 0 1 2 3 4 5 6 7 + // 8 9\n" "vld1.32 {d4-d5}, [%[rs_mask]] @ load din00= 0 1 2 3 4 + // 5 6 7 8 9\n" + + "vmlal.s8 q13, d31, d7 @ out1 += din2 * w12 \n" // q12 += d10 * w00 + "vmlal.s8 q12, d31, d10 @ out0 += din2 * w22 \n" // q12 += d10 * w00 + + // r3 + "vext.8 d30, d11, d15, #7 @ ext \n" // d10 = 00123456 + "vext.8 d31, d15, d11, #1 @ ext \n" // d11 = 12345678 + "vaddw.s16 q10, q10, d26 @addw \n" // out1 += + // vget_low_s16(out10) + "vaddw.s16 q11, q11, d27 @addw \n" // out1_1 += + // vget_high_s16(out10) + "vaddw.s16 q8, q8, d24 @addw \n" // out0 += + // vget_low_s16(out00) + "vaddw.s16 q9, q9, d25 @addw \n" // out0_1 += + // vget_high_s16(out00) + + "vmull.s8 q13, d15, d9 @ out1 = din3 * w21 \n" // q13 = d12 * w01 + + // "vld1.32 {d6-d7}, [%[dout_ptr2]]! @ load din00= 0 1 2 3 4 5 6 + // 7 8 9\n" "vld1.32 {d14-d15}, [%[dout_ptr2]] @ load din00= 0 1 + // 2 3 4 5 6 7 8 9\n" + + "vmlal.s8 q13, d30, d8 @ out1 += din3 * w20 \n" // q13 += d10 * w00 + + // "vbif q8, q14, q1 @ bit select, deal with right + // pad\n" "vbif q9, q6, q2 @ bit select, deal + // with right pad\n" + + "vmlal.s8 q13, d31, d10 @ out1 += din3 * w22 \n" // q12 += d10 * w00 + + // "sub %[dout_ptr2], #16 @ sub \n" + + "vst1.32 {d16-d19}, [%[dout_ptr1]] @ store\n" + // "vst1.32 {d18-d19}, [%[dout_ptr1]]! @ store\n" + + "vaddw.s16 q10, q10, d26 @addw \n" // out1 += + // vget_low_s16(out10) + "vaddw.s16 q11, q11, d27 @addw \n" // out1_1 += + // vget_high_s16(out10) + + // "vbif q10, q3, q1 @ bit select, deal with right + // pad\n" "vbif q11, q7, q2 @ bit select, deal + // with right pad\n" + + "vst1.32 {d20-d23}, [%[dout_ptr2]] @ store\n" + // "vst1.32 {d22-d23}, [%[dout_ptr2]]! 
@ store\n" + : [din_ptr0] "+r"(din_ptr0), + [din_ptr1] "+r"(din_ptr1), + [din_ptr2] "+r"(din_ptr2), + [din_ptr3] "+r"(din_ptr3), + [bias] "+r"(bias_val), + [rs_mask] "+r"(rst_mask) + : [mask] "r"(vmask), + [dout_ptr1] "r"(out_buf1), + [dout_ptr2] "r"(out_buf2) + : "cc", + "memory", + "q0", + "q1", + "q2", + "q3", + "q4", + "q5", + "q6", + "q7", + "q8", + "q9", + "q10", + "q11", + "q12", + "q13", + "q14", + "q15"); +#endif + for (int w = 0; w < w_out; ++w) { + *doutr0++ = out_buf1[w]; + *doutr1++ = out_buf2[w]; + } + dout_ptr += 2 * w_out; + } + } + } +} + +// 4line w_in > 16 +void conv_depthwise_3x3s2p1_bias_int7(int* dout, + const signed char* din, + const signed char* weights, + const int* bias, + bool flag_bias, + const int num, + const int ch_in, + const int h_in, + const int w_in, + const int h_out, + const int w_out, + ARMContext* ctx) { + // printf("3x3s2 mult height \n"); + //! pad is done implicit + const char zero[8] = {0, 0, 0, 0, 0, 0, 0, 0}; + //! for 4x6 convolution window + const unsigned char right_pad_idx[16] = { + 0, 2, 4, 6, 8, 10, 12, 14, 1, 3, 5, 7, 9, 11, 13, 15}; + const unsigned int right_pad_rst[8] = {0, 1, 2, 3, 4, 5, 6, 7}; + + // printf("conv3x3_dw start \n"); + signed char* zero_ptr = ctx->workspace_data<signed char>(); + memset(zero_ptr, 0, w_in * sizeof(signed char)); + int* write_ptr = + reinterpret_cast<int*>(ctx->workspace_data<signed char>()) + w_out; + int size_in_channel = w_in * h_in; + int size_out_channel = w_out * h_out; + int w_stride = 9; + + int tile_w = (w_in + 15) >> 4; + int cnt_col = tile_w - 2; + + unsigned int size_pad_right = (unsigned int)(w_in - 15 - (cnt_col << 4)); + if (size_pad_right == 17) { + size_pad_right = 0; + cnt_col++; + } + + uint8x8_t vmask_rp1 = + vcgt_u8(vdup_n_u8(size_pad_right), vld1_u8(right_pad_idx)); + uint8x8_t vmask_rp2 = + vcgt_u8(vdup_n_u8(size_pad_right), vld1_u8(right_pad_idx + 8)); + unsigned int rst_remain = (unsigned int)(w_out - ((cnt_col + 1) << 3)); + uint32x4_t vmask_result1 = + vcgtq_u32(vdupq_n_u32(rst_remain), vld1q_u32(right_pad_rst)); + uint32x4_t vmask_result2 = + vcgtq_u32(vdupq_n_u32(rst_remain), vld1q_u32(right_pad_rst + 4)); + + uint8x16_t vmask_rp = + vcgtq_u8(vdupq_n_u8(size_pad_right), vld1q_u8(right_pad_idx)); + unsigned char vmask[16]; + vst1q_u8(vmask, vmask_rp); + + unsigned int rmask[8]; + vst1q_u32(rmask, vmask_result1); + vst1q_u32(rmask + 4, vmask_result2); + + int8x8_t vzero = vdup_n_s8(0); + // printf("cnt_col: %d, rst_remain: %d, size_pad_right: %d\n", cnt_col, + // rst_remain, size_pad_right); + for (int n = 0; n < num; ++n) { + const signed char* din_batch = din + n * ch_in * size_in_channel; + int* dout_batch = dout + n * ch_in * size_out_channel; +#pragma omp parallel for + for (int c = 0; c < ch_in; c++) { + int* dout_ptr = dout_batch + c * size_out_channel; + + const signed char* din_ch_ptr = din_batch + c * size_in_channel; + + int bias_val = flag_bias ?
bias[c] : 0; + + const signed char* wei_ptr = weights + c * w_stride; +#ifdef __aarch64__ + int vbias[4] = {bias_val, bias_val, bias_val, bias_val}; + int8x8_t wr00 = vdup_n_s8(wei_ptr[0]); + int8x8_t wr10 = vdup_n_s8(wei_ptr[3]); + int8x8_t wr20 = vdup_n_s8(wei_ptr[6]); + + int8x8_t wr01 = vdup_n_s8(wei_ptr[1]); + int8x8_t wr11 = vdup_n_s8(wei_ptr[4]); + int8x8_t wr21 = vdup_n_s8(wei_ptr[7]); + + int8x8_t wr02 = vdup_n_s8(wei_ptr[2]); + int8x8_t wr12 = vdup_n_s8(wei_ptr[5]); + int8x8_t wr22 = vdup_n_s8(wei_ptr[8]); +#endif + + int* doutr0 = nullptr; + + const signed char* dr0 = din_ch_ptr; + const signed char* dr1 = dr0 + w_in; + const signed char* dr2 = dr1 + w_in; + + const signed char* din_ptr0 = nullptr; + const signed char* din_ptr1 = nullptr; + const signed char* din_ptr2 = nullptr; + + for (int i = 0; i < h_in; i += 2) { + //! process top pad pad_h = 1 + din_ptr0 = dr0; + din_ptr1 = dr1; + din_ptr2 = dr2; + + doutr0 = dout_ptr; + if (i == 0) { + din_ptr0 = zero_ptr; + din_ptr1 = dr0; + din_ptr2 = dr1; + dr0 = dr1; + dr1 = dr2; + dr2 = dr1 + w_in; + } else { + dr0 = dr2; + dr1 = dr0 + w_in; + dr2 = dr1 + w_in; + } + //! process bottom pad + if (i + 2 > h_in) { + switch (i + 2 - h_in) { + case 2: + din_ptr1 = zero_ptr; + case 1: + din_ptr2 = zero_ptr; + default: + break; + } + } +#ifdef __aarch64__ + int cnt = cnt_col; + unsigned char* val_mask = vmask; + asm volatile( + "PRFM PLDL1KEEP, [%[din_ptr0]] \n" + "PRFM PLDL1KEEP, [%[din_ptr1]] \n" + "PRFM PLDL1KEEP, [%[din_ptr2]] \n" + "movi v10.4s, #0x0\n" + // left + "ld2 {v0.8b - v1.8b}, [%[din_ptr0]] \n" /*load a00-a015 + to q0*/ + "ld2 {v2.8b - v3.8b}, [%[din_ptr1]] \n" /* load a00-a015 + to q0*/ + "ld2 {v4.8b - v5.8b}, [%[din_ptr2]] \n" /*load a00-a015 + to q0*/ + + "ld1 {v12.4s}, [%[bias_val]] \n" /* dup v10, bias*/ + "ld1 {v13.4s}, [%[bias_val]] \n" /* dup v10, bias */ + + "ext v6.8b, v10.8b, v1.8B, #7 \n" /* vext_s8(vzero, vinr0, 7); + 013579 */ + "ext v7.8b, v10.8b, v3.8B, #7 \n" /* vext_s8(vzero, vinr0, 7); + 013579 */ + "ext v8.8b, v10.8b, v5.8B, #7 \n" /* vext_s8(vzero, vinr0, 7); + 013579 */ + + // r0 + "smull v14.8h, %[v1].8b, v0.8b \n" /* outr00 = 02468 * w01 */ + "smull v15.8h, %[v2].8b, v1.8b\n" /* outr00 += 13579 * w02 */ + "smull v16.8h, %[v0].8b, v6.8b\n" /* outr00 += 013579 * w00 */ + + "add %[din_ptr0], %[din_ptr0], #15 \n" + "add %[din_ptr1], %[din_ptr1], #15 \n" + "add %[din_ptr2], %[din_ptr2], #15 \n" + + // r1 + "smlal v14.8h, %[v4].8b, v2.8b \n" /* outr00 = 02468 * w01 */ + "smlal v15.8h, %[v5].8b, v3.8b\n" /* outr00 += 13579 * w02 */ + "smlal v16.8h, %[v3].8b, v7.8b\n" /* outr00 += 013579 * w00 */ + + // r2 + "smlal v14.8h, %[v7].8b, v4.8b \n" /* outr00 = 02468 * w01 */ + "smlal v15.8h, %[v8].8b, v5.8b\n" /* outr00 += 13579 * w02 */ + "smlal v16.8h, %[v6].8b, v8.8b\n" /* outr00 += 013579 * w00 */ + + "ld2 {v0.8b - v1.8b}, [%[din_ptr0]], #16 \n" /*load + a00-a015 + to q0*/ + "ld2 {v2.8b - v3.8b}, [%[din_ptr1]], #16 \n" /* load + a00-a015 + to q0*/ + "ld2 {v4.8b - v5.8b}, [%[din_ptr2]], #16 \n" /*load + a00-a015 + to q0*/ + + "saddw v12.4s, v12.4s, v14.4h \n" /* v10 += outr00.low*/ + "saddw2 v13.4s, v13.4s, v14.8h \n" /* v11 += outr00.high*/ + + "saddw v12.4s, v12.4s, v15.4h \n" /* v10 += outr00.low*/ + "saddw2 v13.4s, v13.4s, v15.8h \n" /* v11 += outr00.high*/ + + "saddw v12.4s, v12.4s, v16.4h \n" /* v10 += outr00.low*/ + "saddw2 v13.4s, v13.4s, v16.8h \n" /* v11 += outr00.high*/ + + "stp q12, q13, [%[ptr_out0]], #32 \n" /* store q10, q11 -> + ptr_out */ + + "ld1 {v12.4s}, [%[bias_val]] \n" /* dup v10, bias */ + 
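// note: the accumulators v12/v13 restart from the bias before the next output block is accumulated +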
"ld1 {v13.4s}, [%[bias_val]] \n" /* dup v10, bias */ + + "cmp %[cnt], #1 \n" + "blt 3f \n" + // mid + "1: \n" + "ld1 {v6.8b}, [%[din_ptr0]] \n" /*load a00-a015 to q0*/ + "ld1 {v7.8b}, [%[din_ptr1]] \n" /*load a00-a015 to q0*/ + "ld1 {v8.8b}, [%[din_ptr2]] \n" /*load a00-a015 to q0*/ + + "ext v9.8b, v0.8b, v6.8B, #1 \n" /* vext_s8(vzero, vinr0, 7); + 246810 */ + "ext v11.8b, v2.8b, v7.8B, #1 \n" /* vext_s8(vzero, vinr0, 7); + 246810 */ + "ext v14.8b, v4.8b, v8.8B, #1 \n" /* vext_s8(vzero, vinr0, 7); + 246810 */ + + // r0 + "smull v6.8h, %[v0].8b, v0.8b \n" /* outr00 = 02468 * w00 */ + "smull v7.8h, %[v1].8b, v1.8b\n" /* outr00 += 13579 * w01 */ + "smull v8.8h, %[v2].8b, v9.8b\n" /* outr00 += 246810 * w02 */ + + // r1 + "smlal v6.8h, %[v3].8b, v2.8b \n" /* outr00 = 02468 * w00 */ + "smlal v7.8h, %[v4].8b, v3.8b\n" /* outr00 += 13579 * w01 */ + "smlal v8.8h, %[v5].8b, v11.8b\n" /* outr00 += 246810 * w02 */ + + // r2 + "smlal v6.8h, %[v6].8b, v4.8b \n" /* outr00 = 02468 * w00 */ + "smlal v7.8h, %[v7].8b, v5.8b\n" /* outr00 += 13579 * w01 */ + "smlal v8.8h, %[v8].8b, v14.8b\n" /* outr00 += 246810 * w02 */ + + "ld2 {v0.8b - v1.8b}, [%[din_ptr0]], #16 \n" /*load + a00-a015 + to q0*/ + "ld2 {v2.8b - v3.8b}, [%[din_ptr1]], #16 \n" /* load + a00-a015 + to q0*/ + "ld2 {v4.8b - v5.8b}, [%[din_ptr2]], #16 \n" /*load + a00-a015 + to q0*/ + + "saddw v12.4s, v12.4s, v6.4h \n" /* v10 += outr00.low*/ + "saddw2 v13.4s, v13.4s, v6.8h \n" /* v11 += outr00.high*/ + + "saddw v12.4s, v12.4s, v7.4h \n" /* v10 += outr00.low*/ + "saddw2 v13.4s, v13.4s, v7.8h \n" /* v11 += outr00.high*/ + + "saddw v12.4s, v12.4s, v8.4h \n" /* v10 += outr00.low*/ + "saddw2 v13.4s, v13.4s, v8.8h \n" /* v11 += outr00.high*/ + + "subs %[cnt], %[cnt], #1 \n" + + "stp q12, q13, [%[ptr_out0]], #32 \n" /* store q10, q11 -> + ptr_out */ + + "ld1 {v12.4s}, [%[bias_val]] \n" /* dup v10, bias */ + "ld1 {v13.4s}, [%[bias_val]] \n" /* dup v10, bias */ + "bne 1b \n" + // right + "3: \n" + "ld1 {v14.8b}, [%[vmask]], #8 \n" + "ld1 {v15.8b}, [%[vmask]] \n" + + "bif v0.8b, v10.8b, v14.8b \n" + "bif v1.8b, v10.8b, v15.8b \n" + "bif v2.8b, v10.8b, v14.8b \n" + "bif v3.8b, v10.8b, v15.8b \n" + "bif v4.8b, v10.8b, v14.8b \n" + "bif v5.8b, v10.8b, v15.8b \n" + + "ext v6.8b, v0.8b, v10.8B, #1 \n" /* vext_s8(vzero, vinr0, 7); + 2468.. */ + "ext v7.8b, v2.8b, v10.8B, #1 \n" /* vext_s8(vzero, vinr0, 7); + 2468..*/ + "ext v8.8b, v4.8b, v10.8B, #1 \n" /* vext_s8(vzero, vinr0, 7); + 2468.. 
*/ + + // r0 + "smull v14.8h, %[v0].8b, v0.8b \n" /* outr00 = 02468 * w00 */ + "smull v15.8h, %[v1].8b, v1.8b\n" /* outr00 += 13579 * w01 */ + "smull v16.8h, %[v2].8b, v6.8b\n" /* outr00 += 246810 * w02 */ + + // r1 + "smlal v14.8h, %[v3].8b, v2.8b \n" /* outr00 = 02468 * w00 */ + "smlal v15.8h, %[v4].8b, v3.8b\n" /* outr00 += 13579 * w01 */ + "smlal v16.8h, %[v5].8b, v7.8b\n" /* outr00 += 246810 * w02 */ + + // r2 + "smlal v14.8h, %[v6].8b, v4.8b \n" /* outr00 = 02468 * w00 */ + "smlal v15.8h, %[v7].8b, v5.8b\n" /* outr00 += 13579 * w01 */ + "smlal v16.8h, %[v8].8b, v8.8b\n" /* outr00 += 246810 * w02 */ + + "ldp q0, q1, [%[ptr_out0]] \n" /* dup v10, bias */ + "ldp q9, q11, [%[rst_mask]] \n" /* dup v10, bias */ + + "saddw v12.4s, v12.4s, v14.4h \n" /* v10 += outr00.low*/ + "saddw2 v13.4s, v13.4s, v14.8h \n" /* v11 += outr00.high*/ + + "saddw v12.4s, v12.4s, v15.4h \n" /* v10 += outr00.low*/ + "saddw2 v13.4s, v13.4s, v15.8h \n" /* v11 += outr00.high*/ + + "saddw v12.4s, v12.4s, v16.4h \n" /* v10 += outr00.low*/ + "saddw2 v13.4s, v13.4s, v16.8h \n" /* v11 += outr00.high*/ + + "bif v12.16b, v0.16b, v9.16b \n" + "bif v13.16b, v1.16b, v11.16b \n" + + "stp q12, q13, [%[ptr_out0]], #32 \n" /* store q10, q11 -> + ptr_out */ + + : [cnt] "+r"(cnt), + [din_ptr0] "+r"(din_ptr0), + [din_ptr1] "+r"(din_ptr1), + [din_ptr2] "+r"(din_ptr2), + [ptr_out0] "+r"(doutr0), + [vmask] "+r"(val_mask) + : [v0] "w"(wr00), + [v1] "w"(wr01), + [v2] "w"(wr02), + [v3] "w"(wr10), + [bias_val] "r"(vbias), + [v4] "w"(wr11), + [v5] "w"(wr12), + [v6] "w"(wr20), + [v7] "w"(wr21), + [v8] "w"(wr22), + [rst_mask] "r"(rmask) + : "cc", + "memory", + "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15", + "v16"); +#else + unsigned int* rst_mask = rmask; + int cnt = cnt_col; + // prefetch input + // store weights + asm volatile("vld1.8 {d0-d1}, [%[wei_ptr]] \n" + : + : [wei_ptr] "r"(wei_ptr) + : "memory"); + asm volatile( + // left + "pld [%[din_ptr0]] @ preload data\n" + "pld [%[din_ptr1]] @ preload data\n" + "pld [%[din_ptr2]] @ preload data\n" + "vdup.s8 d2, d0[0] @ d2 = w00, w00, w00, w00\n" + "vdup.s8 d3, d0[1] @ d3 = w01, w01, w01, w01\n" + "vdup.s8 d4, d0[2] @ d4 = w02, w02, w02, w02\n" + "vld2.8 {d12-d13}, [%[din_ptr0]] @ load din00= 0 2 4 6 8\n" // d10 = 0 2 4 6 + "vld2.8 {d14-d15}, [%[din_ptr1]] @ load din00= 0 2 4 6 8\n" // d12 = 0 2 4 6 + "vld2.8 {d16-d17}, [%[din_ptr2]] @ load din00= 0 2 4 6 8\n" // d14 = 0 2 4 6 + "vmov.u32 d11, #0 @ zero\n" + + "vdup.s8 d5, d0[3] @ d2 = w00, w00, w00, w00\n" + "vdup.s8 d6, d0[4] @ d3 = w01, w01, w01, w01\n" + "vdup.s8 d7, d0[5] @ d4 = w02, w02, w02, w02\n" + + "vext.8 d18, d11, d13, #7 @ ext \n" // d16 = -1 1 3 5 + "vext.8 d19, d11, d15, #7 @ ext \n" // d17 = -1 1 3 5 + "vext.8 d20, d11, d17, #7 @ ext \n" // d18 = -1 1 3 5 + + // r0 + "vmull.s8 q13, d12, d3 @ out0 = din0 * w01 \n" // q12 = d12 * w01 + "vmull.s8 q14, d13, d4 @ out1 = din0 * w02 \n" // q12 = d12 * w02 + "vmull.s8 q15, d18, d2 @ out2 = din0 * w00 \n" // q12 = d12 * w02 + + "vdup.s8 d8, d0[6] @ d2 = w00, w00, w00, w00\n" + "vdup.s8 d9, d0[7] @ d3 = w01, w01, w01, w01\n" + "vdup.s8 d10, d1[0] @ d4 = w02, w02, w02, w02\n" + + // r1 + "vmlal.s8 q13, d14, d6 @ out0 += din1 * w11 \n" // q12 = d12 * w11 + "vmlal.s8 q14, d15, d7 @ out1 += din1 * w12 \n" // q12 = d12 * w11 + "vmlal.s8 q15, d19, d5 @ out2 += din1 * w10 \n" // q12 = d12 * w11 + + // out0 + "vdup.32 q11, %[bias] @ and \n" // q8 = + // vbias + "vdup.32 q12, %[bias] @ and \n" // q9 = + // 
vbias + + // r2 + "vmlal.s8 q13, d16, d9 @ out0 += din1 * w21 \n" // q12 = d12 * w11 + "vmlal.s8 q14, d17, d10 @ out1 += din1 * w22 \n" // q12 = d12 * w11 + "vmlal.s8 q15, d20, d8 @ out2 += din1 * w20 \n" // q12 = d12 * w11 + + "add %[din_ptr0], #15 @add \n" + + "vaddw.s16 q11, q11, d26 @addw \n" // out1 += + // vget_low_s16(out10) + "vaddw.s16 q12, q12, d27 @addw \n" // out1_1 += + // vget_high_s16(out10) + "add %[din_ptr1], #15 @add \n" + + "vaddw.s16 q11, q11, d28 @addw \n" // out1 += + // vget_low_s16(out10) + "vaddw.s16 q12, q12, d29 @addw \n" // out1_1 += + // vget_high_s16(out10) + "add %[din_ptr2], #15 @add \n" + + "vaddw.s16 q11, q11, d30 @addw \n" // out1 += + // vget_low_s16(out10) + "vaddw.s16 q12, q12, d31 @addw \n" // out1_1 += + // vget_high_s16(out10) + + "pld [%[din_ptr0]] @ preload data\n" + "pld [%[din_ptr1]] @ preload data\n" + "pld [%[din_ptr2]] @ preload data\n" + + "vst1.32 {d22-d23}, [%[dout_ptr1]]! @ store\n" + "cmp %[cnt], #1 \n" + "vst1.32 {d24-d25}, [%[dout_ptr1]]! @ store\n" + "blt 1f \n" + + // mid + "2: \n" + "vld2.8 {d12-d13}, [%[din_ptr0]]! @ load din00= 0 2 4 6 8\n" // d10 = 0 2 4 6 + "vld2.8 {d14-d15}, [%[din_ptr1]]! @ load din00= 0 2 4 6 8\n" // d12 = 0 2 4 6 + "vld2.8 {d16-d17}, [%[din_ptr2]]! @ load din00= 0 2 4 6 8\n" // d14 = 0 2 4 6 + + "vld1.8 {d21}, [%[din_ptr0]] @ load din00= 16 17\n" // d10 = 0 2 + // 4 6 + "vld1.8 {d22}, [%[din_ptr1]] @ load din00= 16 17\n" // d12 = 0 2 + // 4 6 + "vld1.8 {d23}, [%[din_ptr2]] @ load din00= 16 17\n" // d14 = 0 2 + // 4 6 + + "vext.8 d18, d12, d21, #1 @ ext din00 = 2 4 6 8\n" // d16 = 2 + // 4 6 8 + "vext.8 d19, d14, d22, #1 @ ext \n" // d17 = 2 4 6 8 + "vext.8 d20, d16, d23, #1 @ ext \n" // d18 = 2 4 6 8 + + // r0 + "vmull.s8 q13, d12, d2 @ out0 = din0 * w00 \n" // q12 = 0 2 4 6 + "vmull.s8 q14, d13, d3 @ out1 = din0 * w01 \n" // q12 = 1 3 5 7 + "vmull.s8 q15, d18, d4 @ out2 = din0 * w02 \n" // q12 = 2 4 6 8 + + // out0 + "vdup.32 q11, %[bias] @ and \n" // q8 = + // vbias + "vdup.32 q12, %[bias] @ and \n" // q9 = + // vbias + + // r1 + "vmlal.s8 q13, d14, d5 @ out0 += din1 * w10 \n" // q12 = 0 2 4 6 + "vmlal.s8 q14, d15, d6 @ out1 += din1 * w11 \n" // q12 = 1 3 5 7 + "vmlal.s8 q15, d19, d7 @ out2 += din1 * w12 \n" // q12 = 2 4 6 8 + + // r2 + "vmlal.s8 q13, d16, d8 @ out0 += din1 * w20 \n" // q12 = 0 2 4 6 + "vmlal.s8 q14, d17, d9 @ out1 += din1 * w21 \n" // q12 = 1 3 5 7 + "vmlal.s8 q15, d20, d10 @ out2 += din1 * w22 \n" // q12 = 2 4 6 8 + + // "add %[din_ptr0], #16 @add \n" + + "vaddw.s16 q11, q11, d26 @addw \n" // out1 += + // vget_low_s16(out10) + "vaddw.s16 q12, q12, d27 @addw \n" // out1_1 += + // vget_high_s16(out10) + // "add %[din_ptr1], #16 @add \n" + + "vaddw.s16 q11, q11, d28 @addw \n" // out1 += + // vget_low_s16(out10) + "vaddw.s16 q12, q12, d29 @addw \n" // out1_1 += + // vget_high_s16(out10) + // "add %[din_ptr2], #16 @add \n" + + "vaddw.s16 q11, q11, d30 @addw \n" // out1 += + // vget_low_s16(out10) + "vaddw.s16 q12, q12, d31 @addw \n" // out1_1 += + // vget_high_s16(out10) + + "pld [%[din_ptr0]] @ preload data\n" + "pld [%[din_ptr1]] @ preload data\n" + "pld [%[din_ptr2]] @ preload data\n" + + "vst1.32 {d22-d23}, [%[dout_ptr1]]! @ store\n" + + "subs %[cnt], #1 \n" + "vst1.32 {d24-d25}, [%[dout_ptr1]]! @ store\n" + "bne 2b \n" + // right + "1: \n" + "cmp %[size_pad_right], #1 \n" + "blt 3f \n" + "vld2.8 {d12-d13}, [%[din_ptr0]]! @ load din00= 0 2 4 6 8\n" // d10 = 0 2 4 6 + "vld2.8 {d14-d15}, [%[din_ptr1]]! @ load din00= 0 2 4 6 8\n" // d12 = 0 2 4 6 + "vld2.8 {d16-d17}, [%[din_ptr2]]! 
@ load din00= 0 2 4 6 8\n" // d14 = 0 2 4 6 + "vld1.8 {d28-d29}, [%[mask]] @ load din00= 0 1 2 3 4 5 6 7 " + "8 9\n" + + // out0 + "vdup.32 q11, %[bias] @ and \n" // q8 = vbias + "vdup.32 q12, %[bias] @ and \n" // q9 = vbias + + "vbif.8 d12, d11, d28 @ bit select, deal with right pad\n" + "vbif.8 d13, d11, d29 @ bit select, deal with right pad\n" + + "vbif.8 d14, d11, d28 @ bit select, deal with right pad\n" + "vbif.8 d15, d11, d29 @ bit select, deal with right pad\n" + + "vbif.8 d16, d11, d28 @ bit select, deal with right pad\n" + "vbif.8 d17, d11, d29 @ bit select, deal with right pad\n" + + "vext.8 d18, d12, d11, #1 @ ext din00 = 2 4 6 8\n" // d16 = -1 + // 1 3 5 + "vext.8 d19, d14, d11, #1 @ ext \n" // d17 = -1 1 3 5 + "vext.8 d20, d16, d11, #1 @ ext \n" // d18 = -1 1 3 5 + + // r0 + "vmull.s8 q13, d12, d2 @ out0 = din0 * w00 \n" // q12 = 0 2 4 6 + "vmull.s8 q14, d13, d3 @ out1 = din0 * w01 \n" // q12 = 1 3 5 7 + "vmull.s8 q15, d18, d4 @ out2 = din0 * w02 \n" // q12 = 2 4 6 8 + + // r1 + "vmlal.s8 q13, d14, d5 @ out0 += din1 * w11 \n" // q12 = 0 2 4 6 + "vmlal.s8 q14, d15, d6 @ out1 += din1 * w12 \n" // q12 = 1 3 5 7 + "vmlal.s8 q15, d19, d7 @ out2 += din1 * w10 \n" // q12 = 2 4 6 8 + + "vld1.32 {d12-d13}, [%[dout_ptr1]]! @ load din00= 0 1 2 3 4 5 6 " + "7 8 9\n" + "vld1.32 {d14-d15}, [%[dout_ptr1]] @ load din00= 0 1 2 3 4 5 6 " + "7 8 9\n" + + // r2 + "vmlal.s8 q13, d16, d8 @ out0 += din1 * w11 \n" // q12 = 0 2 4 6 + "vmlal.s8 q14, d17, d9 @ out1 += din1 * w12 \n" // q12 = 1 3 5 7 + "vmlal.s8 q15, d20, d10 @ out2 += din1 * w10 \n" // q12 = 2 4 6 8 + + "vld1.32 {d2-d3}, [%[rs_mask]]! @ load din00= 0 1 2 3 4 5 6 7 " + "8 9\n" + "vld1.32 {d4-d5}, [%[rs_mask]] @ load din00= 0 1 2 3 4 5 6 7 8 " + "9\n" + + "vaddw.s16 q11, q11, d26 @addw \n" // out1 += + // vget_low_s16(out10) + "vaddw.s16 q12, q12, d27 @addw \n" // out1_1 += + // vget_high_s16(out10) + + "sub %[dout_ptr1], #16 @ sub \n" + + "vaddw.s16 q11, q11, d28 @addw \n" // out1 += + // vget_low_s16(out10) + "vaddw.s16 q12, q12, d29 @addw \n" // out1_1 += + // vget_high_s16(out10) + + "vaddw.s16 q11, q11, d30 @addw \n" // out1 += + // vget_low_s16(out10) + "vaddw.s16 q12, q12, d31 @addw \n" // out1_1 += + // vget_high_s16(out10) + + "vbif q11, q6, q1 @ bit select, deal with right pad\n" + "vbif q12, q7, q2 @ bit select, deal with right pad\n" + + "vst1.32 {d22-d23}, [%[dout_ptr1]]! @ store\n" + "vst1.32 {d24-d25}, [%[dout_ptr1]]! @ store\n" + "3: \n" + + : [din_ptr0] "+r"(din_ptr0), + [din_ptr1] "+r"(din_ptr1), + [din_ptr2] "+r"(din_ptr2), + [dout_ptr1] "+r"(doutr0), + [cnt] "+r"(cnt), + [bias] "+r"(bias_val), + [rs_mask] "+r"(rst_mask) + : [mask] "r"(vmask), [size_pad_right] "r"(size_pad_right) + : "cc", + "memory", + "q0", + "q1", + "q2", + "q3", + "q4", + "q5", + "q6", + "q7", + "q8", + "q9", + "q10", + "q11", + "q12", + "q13", + "q14", + "q15"); +#endif + dout_ptr += w_out; + } + } + } +} +// w_in <= 16 +void conv_depthwise_3x3s2p1_bias_s_int7(int* dout, + const signed char* din, + const signed char* weights, + const int* bias, + bool flag_bias, + const int num, + const int ch_in, + const int h_in, + const int w_in, + const int h_out, + const int w_out, + ARMContext* ctx) { + // printf("3x3s2 mult height \n"); + //! pad is done implicit + // const char zero[8] = {0, 0, 0, 0, 0, 0, 0, 0}; + //! 
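single-block stride-2 variant for w_in <= 16: vld2 de-interleaves even/odd input pixels, the masks zero the right-edge tail, and one pass fills out_buf1 with the whole output row. + //!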
for 4x6 convolution window + const unsigned char right_pad_idx[16] = { + 0, 2, 4, 6, 8, 10, 12, 14, 1, 3, 5, 7, 9, 11, 13, 15}; + const unsigned int right_pad_rst[8] = {0, 1, 2, 3, 4, 5, 6, 7}; + + // printf("conv3x3_dw start \n"); + signed char* zero_ptr = ctx->workspace_data<signed char>(); + memset(zero_ptr, 0, w_in * sizeof(signed char)); + int* write_ptr = + reinterpret_cast<int*>(ctx->workspace_data<signed char>()) + w_out; + int size_in_channel = w_in * h_in; + int size_out_channel = w_out * h_out; + int w_stride = 9; + + unsigned int size_pad_right = (unsigned int)(w_in); + + uint8x8_t vmask_rp1 = + vcgt_u8(vdup_n_u8(size_pad_right), vld1_u8(right_pad_idx)); + uint8x8_t vmask_rp2 = + vcgt_u8(vdup_n_u8(size_pad_right), vld1_u8(right_pad_idx + 8)); + unsigned int rst_remain = (unsigned int)w_out; + uint32x4_t vmask_result1 = + vcgtq_u32(vdupq_n_u32(rst_remain), vld1q_u32(right_pad_rst)); + uint32x4_t vmask_result2 = + vcgtq_u32(vdupq_n_u32(rst_remain), vld1q_u32(right_pad_rst + 4)); + + uint8x16_t vmask_rp = + vcgtq_u8(vdupq_n_u8(size_pad_right), vld1q_u8(right_pad_idx)); + unsigned char vmask[16]; + vst1q_u8(vmask, vmask_rp); + + unsigned int rmask[8]; + vst1q_u32(rmask, vmask_result1); + vst1q_u32(rmask + 4, vmask_result2); + + int8x8_t vzero = vdup_n_s8(0); + for (int n = 0; n < num; ++n) { + const signed char* din_batch = din + n * ch_in * size_in_channel; + int* dout_batch = dout + n * ch_in * size_out_channel; +#pragma omp parallel for + for (int c = 0; c < ch_in; c++) { + int* dout_ptr = dout_batch + c * size_out_channel; + + const signed char* din_ch_ptr = din_batch + c * size_in_channel; + + int bias_val = flag_bias ? bias[c] : 0; + + const signed char* wei_ptr = weights + c * w_stride; +#ifdef __aarch64__ + int vbias[4] = {bias_val, bias_val, bias_val, bias_val}; + + int8x8_t wr00 = vdup_n_s8(wei_ptr[0]); + int8x8_t wr10 = vdup_n_s8(wei_ptr[3]); + int8x8_t wr20 = vdup_n_s8(wei_ptr[6]); + + int8x8_t wr01 = vdup_n_s8(wei_ptr[1]); + int8x8_t wr11 = vdup_n_s8(wei_ptr[4]); + int8x8_t wr21 = vdup_n_s8(wei_ptr[7]); + + int8x8_t wr02 = vdup_n_s8(wei_ptr[2]); + int8x8_t wr12 = vdup_n_s8(wei_ptr[5]); + int8x8_t wr22 = vdup_n_s8(wei_ptr[8]); +#endif + int* doutr0 = nullptr; + + const signed char* dr0 = din_ch_ptr; + const signed char* dr1 = dr0 + w_in; + const signed char* dr2 = dr1 + w_in; + + const signed char* din_ptr0 = nullptr; + const signed char* din_ptr1 = nullptr; + const signed char* din_ptr2 = nullptr; + + for (int i = 0; i < h_in; i += 2) { + //! process top pad pad_h = 1 + din_ptr0 = dr0; + din_ptr1 = dr1; + din_ptr2 = dr2; + + doutr0 = dout_ptr; + + int out_buf1[8]; + if (i == 0) { + din_ptr0 = zero_ptr; + din_ptr1 = dr0; + din_ptr2 = dr1; + dr0 = dr1; + dr1 = dr2; + dr2 = dr1 + w_in; + } else { + dr0 = dr2; + dr1 = dr2 + w_in; + dr2 = dr1 + w_in; + } + //!
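rows read past the input bottom are swapped for zero_ptr, which implements the implicit zero padding (pad_h = 1). + //!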
process bottom pad + if (i + 2 > h_in) { + switch (i + 2 - h_in) { + case 2: + din_ptr1 = zero_ptr; + case 1: + din_ptr2 = zero_ptr; + default: + break; + } + } +#ifdef __aarch64__ + unsigned int* rst_mask = rmask; + unsigned char* val_mask = vmask; + asm volatile( + "PRFM PLDL1KEEP, [%[din_ptr0]] \n" + "PRFM PLDL1KEEP, [%[din_ptr1]] \n" + "PRFM PLDL1KEEP, [%[din_ptr2]] \n" + "movi v16.4s, #0x0\n" + // left + "ld1 {v10.8b}, [%[vmask]], #8 \n" + "ld1 {v11.8b}, [%[vmask]] \n" + "ld2 {v0.8b - v1.8b}, [%[din_ptr0]] \n" /*load a00-a015 + to q0*/ + "ld2 {v2.8b - v3.8b}, [%[din_ptr1]] \n" /* load a00-a015 + to q0*/ + "ld2 {v4.8b - v5.8b}, [%[din_ptr2]] \n" /*load a00-a015 + to q0*/ + + "bif v0.8b, v16.8b, v10.8b \n" + "bif v1.8b, v16.8b, v11.8b \n" + "bif v2.8b, v16.8b, v10.8b \n" + "bif v3.8b, v16.8b, v11.8b \n" + "bif v4.8b, v16.8b, v10.8b \n" + "bif v5.8b, v16.8b, v11.8b \n" + + "ld1 {v12.4s}, [%[bias_val]] \n" /* dup v10, bias*/ + "ld1 {v13.4s}, [%[bias_val]] \n" /* dup v10, bias */ + + "ext v6.8b, v16.8b, v1.8B, #7 \n" /* vext_s8(vzero, vinr0, 7); + 013579 */ + "ext v7.8b, v16.8b, v3.8B, #7 \n" /* vext_s8(vzero, vinr0, 7); + 013579 */ + "ext v8.8b, v16.8b, v5.8B, #7 \n" /* vext_s8(vzero, vinr0, 7); + 013579 */ + + // r0 + "smull v17.8h, %[v1].8b, v0.8b \n" /* outr00 = 02468 * w01 */ + "smull v18.8h, %[v2].8b, v1.8b\n" /* outr00 += 13579 * w02 */ + "smull v19.8h, %[v0].8b, v6.8b\n" /* outr00 += 013579 * w00 */ + + // "ldp q0, q1, [%[ptr_out0]] \n" /* dup v10, + // bias */ "ldp q10, q11, [%[rst_mask]] \n" /* + // dup v10, bias */ + + // r1 + "smlal v17.8h, %[v4].8b, v2.8b \n" /* outr00 = 02468 * w01 */ + "smlal v18.8h, %[v5].8b, v3.8b\n" /* outr00 += 13579 * w02 */ + "smlal v19.8h, %[v3].8b, v7.8b\n" /* outr00 += 013579 * w00 */ + + // r2 + "smlal v17.8h, %[v7].8b, v4.8b \n" /* outr00 = 02468 * w01 */ + "smlal v18.8h, %[v8].8b, v5.8b\n" /* outr00 += 13579 * w02 */ + "smlal v19.8h, %[v6].8b, v8.8b\n" /* outr00 += 013579 * w00 */ + + "saddw v12.4s, v12.4s, v17.4h \n" /* v10 += outr00.low*/ + "saddw2 v13.4s, v13.4s, v17.8h \n" /* v11 += outr00.high*/ + + "saddw v12.4s, v12.4s, v18.4h \n" /* v10 += outr00.low*/ + "saddw2 v13.4s, v13.4s, v18.8h \n" /* v11 += outr00.high*/ + + "saddw v12.4s, v12.4s, v19.4h \n" /* v10 += outr00.low*/ + "saddw2 v13.4s, v13.4s, v19.8h \n" /* v11 += outr00.high*/ + + // "bif v12.16b, v0.16b, v10.16b \n" + // "bif v13.16b, v1.16b, v11.16b \n" + + "stp q12, q13, [%[ptr_out0]] \n" /* store q10, q11 -> ptr_out + */ + : [din_ptr0] "+r"(din_ptr0), + [din_ptr1] "+r"(din_ptr1), + [din_ptr2] "+r"(din_ptr2), + [vmask] "+r"(val_mask) + : [v0] "w"(wr00), + [v1] "w"(wr01), + [v2] "w"(wr02), + [v3] "w"(wr10), + [bias_val] "r"(vbias), + [v4] "w"(wr11), + [v5] "w"(wr12), + [v6] "w"(wr20), + [v7] "w"(wr21), + [v8] "w"(wr22), + [rst_mask] "r"(rmask), + [ptr_out0] "r"(out_buf1) + : "cc", + "memory", + "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15", + "v16", + "v17", + "v18", + "v19", + "v20"); +#else + unsigned int* rst_mask = rmask; + // prefetch input + // store weights + asm volatile("vld1.8 {d0-d1}, [%[wei_ptr]] \n" + : + : [wei_ptr] "r"(wei_ptr) + : "memory"); + asm volatile( + // left + "pld [%[din_ptr0]] @ preload data\n" + "pld [%[din_ptr1]] @ preload data\n" + "pld [%[din_ptr2]] @ preload data\n" + "vdup.s8 d2, d0[0] @ d2 = w00, w00, w00, w00\n" + "vdup.s8 d3, d0[1] @ d3 = w01, w01, w01, w01\n" + "vdup.s8 d4, d0[2] @ d4 = w02, w02, w02, w02\n" + "vld2.8 {d12-d13}, [%[din_ptr0]] @ load din00= 0 2 4 
6 8\n" // d10 = 0 2 4 6 + "vld2.8 {d14-d15}, [%[din_ptr1]] @ load din00= 0 2 4 6 8\n" // d12 = 0 2 4 6 + "vld2.8 {d16-d17}, [%[din_ptr2]] @ load din00= 0 2 4 6 8\n" // d14 = 0 2 4 6 + "vld1.8 {d28-d29}, [%[mask]] @ load din00= 0 1 2 3 4 5 6 7 " + "8 9\n" + "vmov.u32 d11, #0 @ zero\n" + + "vdup.s8 d5, d0[3] @ d2 = w00, w00, w00, w00\n" + "vdup.s8 d6, d0[4] @ d3 = w01, w01, w01, w01\n" + "vdup.s8 d7, d0[5] @ d4 = w02, w02, w02, w02\n" + + "vbif.8 d12, d11, d28 @ bit select, deal with right pad\n" + "vbif.8 d13, d11, d29 @ bit select, deal with right pad\n" + + "vbif.8 d14, d11, d28 @ bit select, deal with right pad\n" + "vbif.8 d15, d11, d29 @ bit select, deal with right pad\n" + + "vbif.8 d16, d11, d28 @ bit select, deal with right pad\n" + "vbif.8 d17, d11, d29 @ bit select, deal with right pad\n" + + "vext.8 d18, d11, d13, #7 @ ext \n" // d16 = -1 1 3 5 + "vext.8 d19, d11, d15, #7 @ ext \n" // d17 = -1 1 3 5 + "vext.8 d20, d11, d17, #7 @ ext \n" // d18 = -1 1 3 5 + + // "pld [%[dout_ptr1]] @ preload data\n" + + // r0 + "vmull.s8 q13, d12, d3 @ out0 = din0 * w01 \n" // q12 = d12 * w01 + "vmull.s8 q14, d13, d4 @ out1 = din0 * w02 \n" // q12 = d12 * w02 + "vmull.s8 q15, d18, d2 @ out2 = din0 * w00 \n" // q12 = d12 * w02 + + "vdup.s8 d8, d0[6] @ d2 = w00, w00, w00, w00\n" + "vdup.s8 d9, d0[7] @ d3 = w01, w01, w01, w01\n" + "vdup.s8 d10, d1[0] @ d4 = w02, w02, w02, w02\n" + + // r1 + "vmlal.s8 q13, d14, d6 @ out0 += din1 * w11 \n" // q12 = d12 * w11 + "vmlal.s8 q14, d15, d7 @ out1 += din1 * w12 \n" // q12 = d12 * w11 + "vmlal.s8 q15, d19, d5 @ out2 += din1 * w10 \n" // q12 = d12 * w11 + + // "vld1.32 {d12-d13}, [%[dout_ptr1]]! @ load din00= 0 1 2 3 4 5 + // 6 7 8 9\n" "vld1.32 {d14-d15}, [%[dout_ptr1]] @ load din00= 0 + // 1 2 3 4 5 6 7 8 9\n" + + // out0 + "vdup.32 q11, %[bias] @ and \n" // q8 = + // vbias + "vdup.32 q12, %[bias] @ and \n" // q9 = + // vbias + + // r2 + "vmlal.s8 q13, d16, d9 @ out0 += din1 * w21 \n" // q12 = d12 * w11 + "vmlal.s8 q14, d17, d10 @ out1 += din1 * w22 \n" // q12 = d12 * w11 + "vmlal.s8 q15, d20, d8 @ out2 += din1 * w20 \n" // q12 = d12 * w11 + + // "vld1.32 {d2-d3}, [%[rs_mask]]! @ load din00= 0 1 2 3 4 5 6 7 + // 8 9\n" "vld1.32 {d4-d5}, [%[rs_mask]] @ load din00= 0 1 2 3 4 + // 5 6 7 8 9\n" + + // "sub %[dout_ptr1], #16 @ sub \n" + + "vaddw.s16 q11, q11, d26 @addw \n" // out1 += + // vget_low_s16(out10) + "vaddw.s16 q12, q12, d27 @addw \n" // out1_1 += + // vget_high_s16(out10) + + "vaddw.s16 q11, q11, d28 @addw \n" // out1 += + // vget_low_s16(out10) + "vaddw.s16 q12, q12, d29 @addw \n" // out1_1 += + // vget_high_s16(out10) + + "vaddw.s16 q11, q11, d30 @addw \n" // out1 += + // vget_low_s16(out10) + "vaddw.s16 q12, q12, d31 @addw \n" // out1_1 += + // vget_high_s16(out10) + + // "vbif q11, q6, q1 @ bit select, deal with right pad\n" + // "vbif q12, q7, q2 @ bit select, deal with right pad\n" + + "vst1.32 {d22-d25}, [%[dout_ptr1]] @ store\n" + // "vst1.32 {d24-d25}, [%[dout_ptr1]]! 
@ store\n" + : [din_ptr0] "+r"(din_ptr0), + [din_ptr1] "+r"(din_ptr1), + [din_ptr2] "+r"(din_ptr2), + [bias] "+r"(bias_val), + [rs_mask] "+r"(rst_mask) + : [mask] "r"(vmask), + [size_pad_right] "r"(size_pad_right), + [dout_ptr1] "r"(out_buf1) + : "cc", + "memory", + "q0", + "q1", + "q2", + "q3", + "q4", + "q5", + "q6", + "q7", + "q8", + "q9", + "q10", + "q11", + "q12", + "q13", + "q14", + "q15"); +#endif + for (int w = 0; w < w_out; ++w) { + *doutr0++ = out_buf1[w]; + } + dout_ptr += w_out; + } + } + } +} + +// relu +void conv_depthwise_3x3s1p1_bias_relu_int7(int* dout, + const signed char* din, + const signed char* weights, + const int* bias, + bool flag_bias, + const int num, + const int ch_in, + const int h_in, + const int w_in, + const int h_out, + const int w_out, + ARMContext* ctx) { + // printf("3x3s1 mult height \n"); + //! pad is done implicit + const char zero[8] = {0, 0, 0, 0, 0, 0, 0, 0}; + //! for 4x6 convolution window + const unsigned char right_pad_idx[16] = { + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}; + const unsigned int right_pad_rst[8] = {0, 1, 2, 3, 4, 5, 6, 7}; + + // printf("conv3x3_dw start \n"); + signed char* zero_ptr = ctx->workspace_data<signed char>(); + memset(zero_ptr, 0, w_in * sizeof(signed char)); + int* write_ptr = + reinterpret_cast<int*>(ctx->workspace_data<signed char>()) + w_in; + int size_in_channel = w_in * h_in; + int size_out_channel = w_out * h_out; + int w_stride = 9; + + int tile_w = (w_in + 7) >> 3; + int tile_h = (h_out + 1) >> 1; + int cnt_col = tile_w - 2; + + unsigned int size_pad_right = (unsigned int)(w_in - 7 - (cnt_col << 3)); + + int size_pad_bottom = h_out % 2; + + uint8x8_t vmask_rp1 = + vcgt_u8(vdup_n_u8(size_pad_right), vld1_u8(right_pad_idx)); + uint8x8_t vmask_rp2 = + vcgt_u8(vdup_n_u8(size_pad_right), vld1_u8(right_pad_idx + 8)); + unsigned int rst_remain = (unsigned int)(w_out - ((cnt_col + 1) << 3)); + uint32x4_t vmask_result1 = + vcgtq_u32(vdupq_n_u32(rst_remain), vld1q_u32(right_pad_rst)); + uint32x4_t vmask_result2 = + vcgtq_u32(vdupq_n_u32(rst_remain), vld1q_u32(right_pad_rst + 4)); + + int8x8_t vzero = vdup_n_s8(0); + int32x4_t vzero_32 = vdupq_n_s32(0); + + uint8x16_t vmask_rp = + vcgtq_u8(vdupq_n_u8(size_pad_right), vld1q_u8(right_pad_idx)); + // uint8x8_t vmask_rp2 = vcgt_u8(vdup_n_u8(size_pad_right), + // vld1_u8(right_pad_idx + 8)); + unsigned char vmask[16]; + vst1q_u8(vmask, vmask_rp); + + unsigned int rmask[8]; + vst1q_u32(rmask, vmask_result1); + vst1q_u32(rmask + 4, vmask_result2); + + for (int n = 0; n < num; ++n) { + const signed char* din_batch = din + n * ch_in * size_in_channel; + int* dout_batch = dout + n * ch_in * size_out_channel; +#pragma omp parallel for + for (int c = 0; c < ch_in; c++) { + int* dout_ptr = dout_batch + c * size_out_channel; + + const signed char* din_ch_ptr = din_batch + c * size_in_channel; + + int bias_val = flag_bias ?
bias[c] : 0; + + const signed char* wei_ptr = weights + c * w_stride; +#ifdef __aarch64__ + int vbias[4] = {bias_val, bias_val, bias_val, bias_val}; + int8x8_t wr00 = vdup_n_s8(wei_ptr[0]); + int8x8_t wr10 = vdup_n_s8(wei_ptr[3]); + int8x8_t wr20 = vdup_n_s8(wei_ptr[6]); + + int8x8_t wr01 = vdup_n_s8(wei_ptr[1]); + int8x8_t wr11 = vdup_n_s8(wei_ptr[4]); + int8x8_t wr21 = vdup_n_s8(wei_ptr[7]); + + int8x8_t wr02 = vdup_n_s8(wei_ptr[2]); + int8x8_t wr12 = vdup_n_s8(wei_ptr[5]); + int8x8_t wr22 = vdup_n_s8(wei_ptr[8]); +#endif + + int* doutr0 = nullptr; + int* doutr1 = nullptr; + + const signed char* dr0 = din_ch_ptr; + const signed char* dr1 = dr0 + w_in; + const signed char* dr2 = dr1 + w_in; + const signed char* dr3 = dr2 + w_in; + + const signed char* din_ptr0 = nullptr; + const signed char* din_ptr1 = nullptr; + const signed char* din_ptr2 = nullptr; + const signed char* din_ptr3 = nullptr; + + for (int i = 0; i < h_in; i += 2) { + //! process top pad pad_h = 1 + din_ptr0 = dr0; + din_ptr1 = dr1; + din_ptr2 = dr2; + din_ptr3 = dr3; + + doutr0 = dout_ptr; + doutr1 = doutr0 + w_out; + unsigned int* rst_mask = rmask; + unsigned char* val_mask = vmask; + if (i == 0) { + din_ptr0 = zero_ptr; + din_ptr1 = dr0; + din_ptr2 = dr1; + din_ptr3 = dr2; + dr0 = dr1; + dr1 = dr2; + dr2 = dr3; + dr3 = dr2 + w_in; + } else { + dr0 = dr2; + dr1 = dr3; + dr2 = dr1 + w_in; + dr3 = dr2 + w_in; + } + //! process bottom pad + if (i + 3 > h_in) { + switch (i + 3 - h_in) { + case 3: + din_ptr1 = zero_ptr; + case 2: + din_ptr2 = zero_ptr; + case 1: + din_ptr3 = zero_ptr; + default: + break; + } + } + //! process bottom remain + if (i + 2 > h_out) { + doutr1 = write_ptr; + } + int cnt = cnt_col; +#ifdef __aarch64__ + asm volatile( + "PRFM PLDL1KEEP, [%[din_ptr0]] \n" + "PRFM PLDL1KEEP, [%[din_ptr1]] \n" + "PRFM PLDL1KEEP, [%[din_ptr2]] \n" + "PRFM PLDL1KEEP, [%[din_ptr3]] \n" + "movi v21.4s, #0x0\n" /* out0 = 0 */ + // left + "ld1 {v0.8b}, [%[din_ptr0]], #8 \n" /* load + a00-a015 + to + q0*/ + "ld1 {v2.8b}, [%[din_ptr1]], #8 \n" /* load + a00-a015 + to + q0*/ + "ld1 {v1.8b}, [%[din_ptr0]] \n" /* load + a00-a015 to + q0*/ + "ld1 {v3.8b}, [%[din_ptr1]] \n" /* load + a00-a015 to + q0*/ + + "ld1 {v10.4s}, [%[bias_val]] \n" /* dup v10, bias */ + "ld1 {v11.4s}, [%[bias_val]] \n" /* dup v10, bias */ + "ld1 {v12.4s}, [%[bias_val]] \n" /* dup v10, bias */ + "ld1 {v13.4s}, [%[bias_val]] \n" /* dup v10, bias */ + + // r0 + "smull v18.8h, %[v1].8b, v0.8b \n" /* outr00 = 01234567 * w01 + */ + + "ext v4.8b, v21.8b, v0.8B, #7 \n" /* vext_s8(vzero, vinr0, 7); + 00123456 */ + "ext v5.8b, v0.8b, v1.8B, #1 \n" /* vext_s8(vinr0, vinr0_1, + 1); 12345678 */ + + "ld1 {v6.8b}, [%[din_ptr2]], #8 \n" /* load + a00-a015 + to + q0*/ + "ld1 {v8.8b}, [%[din_ptr3]], #8 \n" /* load + a00-a015 + to + q0*/ + + "smlal v18.8h, %[v0].8b, v4.8b\n" /* outr00 += 00123456 * w00 */ + + "ld1 {v7.8b}, [%[din_ptr2]] \n" /* load + a00-a015 + to q0*/ + "ld1 {v9.8b}, [%[din_ptr3]] \n" /* load + a00-a015 + to q0*/ + + "sub %[din_ptr0], %[din_ptr0], #1 \n" + "sub %[din_ptr1], %[din_ptr1], #1 \n" + + "smlal v18.8h, %[v2].8b, v5.8b\n" /* outr00 += 12345678 * w02 */ + + "ext v4.8b, v21.8b, v2.8b, #7 \n" /* vext_s8(vzero, vinr0, 7); + 00123456 */ + "ext v5.8b, v2.8b, v3.8b, #1 \n" /* vext_s8(vinr0, vinr0_1, + 1); 12345678 */ + + // r1 + "sub %[din_ptr2], %[din_ptr2], #1 \n" + "sub %[din_ptr3], %[din_ptr3], #1 \n" + + "smull v19.8h, %[v1].8b, v2.8b \n" /* outr10 += 01234567 * w11 + */ + "smlal v18.8h, %[v4].8b, v2.8b \n" /* outr00 += 01234567 * w11 + */ + + "ext 
v14.8b, v21.8b, v6.8b, #7 \n" /* vext_s8(vzero, vinr0, + 7); 00123456 */ + "ext v15.8b, v6.8b, v7.8b, #1 \n" /* vext_s8(vinr0, vinr0_1, + 1); 12345678 */ + + "smlal v19.8h, %[v0].8b, v4.8b \n" /* outr00 += 01234567 * w11 + */ + "smlal v18.8h, %[v3].8b, v4.8b \n" /* outr00 += 001234567 * w10 + */ + + "ld1 {v0.8b}, [%[din_ptr0]], #8 \n" /* load + a00-a015 + to + q0*/ + "ld1 {v2.8b}, [%[din_ptr1]], #8 \n" /* load + a00-a015 + to + q0*/ + + "smlal v19.8h, %[v2].8b, v5.8b \n" /* outr00 += 01234567 * w11 + */ + "smlal v18.8h, %[v5].8b, v5.8b \n" /* outr00 += 12345678 * w12 + */ + + // r2 + "ld1 {v1.8b}, [%[din_ptr0]] \n" /* load + a00-a015 to + q0*/ + "ld1 {v3.8b}, [%[din_ptr1]] \n" /* load + a00-a015 to + q0*/ + + "smlal v19.8h, %[v4].8b, v6.8b \n" /* outr10 += 01234567 * w11 + */ + "smlal v18.8h, %[v7].8b, v6.8b \n" /* outr00 += 01234567 * w11 + */ + + "ext v4.8b, v21.8b, v8.8b, #7 \n" /* vext_s8(vzero, vinr0, 7); + 00123456 */ + "ext v5.8b, v8.8b, v9.8b, #1 \n" /* vext_s8(vinr0, vinr0_1, + 1); 12345678 */ + + "smlal v19.8h, %[v3].8b, v14.8b \n" /* outr10 += 01234567 * w11 + */ + "smlal v18.8h, %[v6].8b, v14.8b \n" /* outr00 += 01234567 * w11 + */ + + "ld1 {v6.8b}, [%[din_ptr2]], #8 \n" /* load + a00-a015 + to + q0*/ + + "smlal v19.8h, %[v5].8b, v15.8b \n" /* outr10 += 01234567 * w11 + */ + + "saddw v10.4s, v10.4s, v18.4h \n" /* v10 += outr00.low*/ + "saddw2 v11.4s, v11.4s, v18.8h \n" /* v11 += outr00.high*/ + + "smull v18.8h, %[v8].8b, v15.8b \n" /* outr00 += 01234567 * w11 + */ + + // r3 + "smlal v19.8h, %[v7].8b, v8.8b \n" /* outr00 += 01234567 * w11 + */ + + "ld1 {v8.8b}, [%[din_ptr3]], #8 \n" /* load + a00-a015 + to + q0*/ + + "ld1 {v7.8b}, [%[din_ptr2]] \n" /* load + a00-a015 to + q0*/ + "ld1 {v9.8b}, [%[din_ptr3]] \n" /* load + a00-a015 to + q0*/ + + "saddw v10.4s, v10.4s, v18.4h \n" /* v10 += outr00.low*/ + "saddw2 v11.4s, v11.4s, v18.8h \n" /* v11 += outr00.high*/ + + "smlal v19.8h, %[v6].8b, v4.8b \n" /* outr00 += 01234567 * + w11 */ + + "smax v10.4s, v10.4s, v21.4s \n" /* relu*/ + "smax v11.4s, v11.4s, v21.4s \n" /* relu*/ + + "stp q10, q11, [%[ptr_out0]], #32 \n" /* store q10, q11 -> + ptr_out */ + + "saddw v12.4s, v12.4s, v19.4h \n" /* v10 += outr00.low*/ + "saddw2 v13.4s, v13.4s, v19.8h \n" /* v11 += outr00.high*/ + + "smull v19.8h, %[v8].8b, v5.8b \n" /* outr00 += 01234567 * + w11 */ + + "ld1 {v10.4s}, [%[bias_val]] \n" /* dup v10, bias */ + "ld1 {v11.4s}, [%[bias_val]] \n" /* dup v10, bias */ + + "saddw v12.4s, v12.4s, v19.4h \n" /* v10 += outr00.low*/ + "saddw2 v13.4s, v13.4s, v19.8h \n" /* v11 += outr00.high*/ + + "smax v12.4s, v12.4s, v21.4s \n" /* relu*/ + "smax v13.4s, v13.4s, v21.4s \n" /* relu*/ + + "stp q12, q13, [%[ptr_out1]], #32 \n" /* store q10, q11 -> + ptr_out */ + + "ld1 {v12.4s}, [%[bias_val]] \n" /* dup v10, bias */ + "ld1 {v13.4s}, [%[bias_val]] \n" /* dup v10, bias */ + + "cmp %[cnt], #1 \n" + "blt 3f \n" + // mid + "1: \n" + "ext v4.8b, v0.8B, v1.8b, #1 \n" /*12345678 */ + "ext v5.8b, v0.8b, v1.8B, #2 \n" /*23456789 */ + + // r0 + "smull v18.8h, %[v0].8b, v0.8b \n" /* outr00 = 01234567 * w00 + */ + + "ext v14.8b, v2.8B, v3.8b, #1 \n" /*12345678 */ + "ext v15.8b, v2.8b, v3.8B, #2 \n" /*23456789 */ + + "smlal v18.8h, %[v1].8b, v4.8b\n" /* outr00 += 12345678 * w01 */ + + "ext v16.8b, v6.8B, v7.8b, #1 \n" /*12345678 */ + "ext v17.8b, v6.8b, v7.8B, #2 \n" /*23456789 */ + + "smlal v18.8h, %[v2].8b, v5.8b\n" /* outr00 += 23456789 * w02 */ + + // r1 + "ext v4.8b, v8.8B, v9.8b, #1 \n" /*12345678 */ + "ext v5.8b, v8.8b, v9.8B, #2 \n" /*23456789 */ + + "smull v19.8h, 
%[v0].8b, v2.8b \n" /* outr00 = 01234567 * w00 + */ + "smlal v18.8h, %[v3].8b, v2.8b \n" /* outr00 = 01234567 * w00 + */ + + "ld1 {v0.8b}, [%[din_ptr0]], #8 \n" /* load + a00-a015 + to + q0*/ + "ld1 {v2.8b}, [%[din_ptr1]], #8 \n" /* load + a00-a015 + to + q0*/ + + "smlal v19.8h, %[v1].8b, v14.8b\n" /* outr00 += 12345678 * w01 */ + "smlal v18.8h, %[v4].8b, v14.8b\n" /* outr00 += 12345678 * w01 */ + + "ld1 {v1.8b}, [%[din_ptr0]] \n" /* load + a00-a015 + to q0*/ + "ld1 {v3.8b}, [%[din_ptr1]] \n" /* load + a00-a015 + to q0*/ + + "smlal v19.8h, %[v2].8b, v15.8b\n" /* outr00 += 23456789 * w02 */ + "smlal v18.8h, %[v5].8b, v15.8b\n" /* outr00 += 12345678 * w01 */ + + // r2 + "smlal v19.8h, %[v3].8b, v6.8b \n" /* outr00 = 01234567 * w00 + */ + "smlal v18.8h, %[v6].8b, v6.8b \n" /* outr00 = 01234567 * w00 + */ + + "smlal v19.8h, %[v4].8b, v16.8b\n" /* outr00 += 12345678 * w01 */ + "smlal v18.8h, %[v7].8b, v16.8b\n" /* outr00 += 12345678 * w01 */ + + "smlal v19.8h, %[v5].8b, v17.8b\n" /* outr00 += 23456789 * w02 */ + + "saddw v10.4s, v10.4s, v18.4h \n" /* v10 += outr00.low*/ + "saddw2 v11.4s, v11.4s, v18.8h \n" /* v11 += outr00.high*/ + + "smull v18.8h, %[v8].8b, v17.8b\n" /* outr00 += 12345678 * w01 */ + + // r3 + "smlal v19.8h, %[v6].8b, v8.8b \n" /* outr00 = 01234567 * w00 + */ + + "ld1 {v6.8b}, [%[din_ptr2]], #8 \n" /* load + a00-a015 + to + q0*/ + "ld1 {v8.8b}, [%[din_ptr3]], #8 \n" /* load + a00-a015 + to + q0*/ + + "saddw v10.4s, v10.4s, v18.4h \n" /* v10 += outr00.low*/ + "saddw2 v11.4s, v11.4s, v18.8h \n" /* v11 += outr00.high*/ + + "smlal v19.8h, %[v7].8b, v4.8b\n" /* outr00 += 12345678 * w01 */ + + "ld1 {v7.8b}, [%[din_ptr2]] \n" /* load + a00-a015 + to q0*/ + "ld1 {v9.8b}, [%[din_ptr3]] \n" /* load + a00-a015 + to q0*/ + + "smax v10.4s, v10.4s, v21.4s \n" /* relu*/ + "smax v11.4s, v11.4s, v21.4s \n" /* relu*/ + + "saddw v12.4s, v12.4s, v19.4h \n" /* v10 += outr00.low*/ + "saddw2 v13.4s, v13.4s, v19.8h \n" /* v11 += outr00.high*/ + + "smull v19.8h, %[v8].8b, v5.8b\n" /* outr00 += 23456789 * w02 */ + + "stp q10, q11, [%[ptr_out0]], #32 \n" /* store q10, q11 -> + ptr_out */ + + "ld1 {v10.4s}, [%[bias_val]] \n" /* dup v10, bias */ + "ld1 {v11.4s}, [%[bias_val]] \n" /* dup v10, bias */ + + "saddw v12.4s, v12.4s, v19.4h \n" /* v10 += outr00.low*/ + "saddw2 v13.4s, v13.4s, v19.8h \n" /* v11 += outr00.high*/ + + "subs %[cnt], %[cnt], #1 \n" + + "smax v12.4s, v12.4s, v21.4s \n" /* relu*/ + "smax v13.4s, v13.4s, v21.4s \n" /* relu*/ + + "stp q12, q13, [%[ptr_out1]], #32 \n" /* store q10, q11 -> + ptr_out */ + + "ld1 {v12.4s}, [%[bias_val]] \n" /* dup v10, bias */ + "ld1 {v13.4s}, [%[bias_val]] \n" /* dup v10, bias */ + + "bne 1b \n" + // right + "3: \n" + "ld1 {v14.8b}, [%[vmask]], #8 \n" + "ld1 {v15.8b}, [%[vmask]] \n" + + "bif v0.8b, v21.8b, v14.8b \n" + "bif v1.8b, v21.8b, v15.8b \n" + "bif v2.8b, v21.8b, v14.8b \n" + "bif v3.8b, v21.8b, v15.8b \n" + + "ext v4.8b, v0.8b, v1.8b, #1 \n" + "ext v5.8b, v0.8b, v1.8b, #2 \n" + + // r0 + "smull v18.8h, %[v0].8b, v0.8b \n" /* outr00 = 01234567 * w00 + */ + + "ext v16.8b, v2.8b, v3.8b, #1 \n" + "ext v17.8b, v2.8b, v3.8b, #2 \n" + + "bif v6.8b, v21.8b, v14.8b \n" + "bif v7.8b, v21.8b, v15.8b \n" + + "smlal v18.8h, %[v1].8b, v4.8b \n" /* outr00 = 01234567 * w00 + */ + + "bif v8.8b, v21.8b, v14.8b \n" + "bif v9.8b, v21.8b, v15.8b \n" + + "ext v20.8b, v6.8b, v7.8b, #1 \n" + "ext v22.8b, v6.8b, v7.8b, #2 \n" + + "smlal v18.8h, %[v2].8b, v5.8b \n" /* outr00 = 01234567 * w00 + */ + + // r1 + "ext v4.8b, v8.8b, v9.8b, #1 \n" + "ext v5.8b, v8.8b, v9.8b, #2 \n" + 
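+          // Note: this block shows the kernel's core pattern: "smull"
+          // widens int8 x int8 products into the int16 accumulators
+          // v18/v19, a few "smlal" steps add further 3x3 taps there, and
+          // "saddw"/"saddw2" drain the low/high halves into the int32
+          // accumulators v10-v13 before "smax" applies the fused ReLU.
+          // Accumulating several taps at 16 bits presumably relies on the
+          // inputs fitting in 7 bits (hence the *_int7 suffix of these
+          // kernels), so the partial sums cannot overflow int16.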
+ "smull v19.8h, %[v0].8b, v2.8b \n" /* outr00 = 01234567 * w00 + */ + "smlal v18.8h, %[v3].8b, v2.8b \n" /* outr00 = 01234567 * w00 + */ + + "ld1 {v14.4s}, [%[rmask]], #16 \n" + "ld1 {v15.4s}, [%[rmask]] \n" + + "smlal v19.8h, %[v1].8b, v16.8b \n" /* outr00 = 01234567 * w00 + */ + "smlal v18.8h, %[v4].8b, v16.8b \n" /* outr00 = 01234567 * w00 + */ + + "ld1 {v0.4s}, [%[ptr_out0]], #16 \n" + "ld1 {v2.4s}, [%[ptr_out1]], #16 \n" + + "smlal v19.8h, %[v2].8b, v17.8b \n" /* outr00 = 01234567 * w00 + */ + "smlal v18.8h, %[v5].8b, v17.8b \n" /* outr00 = 01234567 * w00 + */ + + "ld1 {v1.4s}, [%[ptr_out0]] \n" + "ld1 {v3.4s}, [%[ptr_out1]] \n" + + // r2 + "smlal v19.8h, %[v3].8b, v6.8b \n" /* outr00 = 01234567 * w00 + */ + "smlal v18.8h, %[v6].8b, v6.8b \n" /* outr00 = 01234567 * w00 + */ + + "sub %[ptr_out0], %[ptr_out0], #16 \n" + "sub %[ptr_out1], %[ptr_out1], #16 \n" + + "smlal v19.8h, %[v4].8b, v20.8b \n" /* outr00 = 01234567 * w00 + */ + "smlal v18.8h, %[v7].8b, v20.8b \n" /* outr00 = 01234567 * w00 + */ + + "smlal v19.8h, %[v5].8b, v22.8b \n" /* outr00 = 01234567 * w00 + */ + + "saddw v10.4s, v10.4s, v18.4h \n" /* v10 += outr00.low*/ + "saddw2 v11.4s, v11.4s, v18.8h \n" /* v11 += outr00.high*/ + + "smull v18.8h, %[v8].8b, v22.8b \n" /* outr00 = 01234567 * w00 + */ + + // r3 + "smlal v19.8h, %[v6].8b, v8.8b \n" /* outr00 = 01234567 * w00 + */ + + "saddw v10.4s, v10.4s, v18.4h \n" /* v10 += outr00.low*/ + "saddw2 v11.4s, v11.4s, v18.8h \n" /* v11 += outr00.high*/ + + "smlal v19.8h, %[v7].8b, v4.8b \n" /* outr00 = 01234567 * w00 + */ + + "smax v10.4s, v10.4s, v21.4s \n" /* relu*/ + "smax v11.4s, v11.4s, v21.4s \n" /* relu*/ + + "bif v10.16b, v0.16b, v14.16b \n" + "bif v11.16b, v1.16b, v15.16b \n" + + "saddw v12.4s, v12.4s, v19.4h \n" /* v10 += outr00.low*/ + "saddw2 v13.4s, v13.4s, v19.8h \n" /* v11 += outr00.high*/ + + "smull v19.8h, %[v8].8b, v5.8b \n" /* outr00 = 01234567 * w00 + */ + + "stp q10, q11, [%[ptr_out0]], #32 \n" /* store q10, q11 -> + ptr_out */ + + "saddw v12.4s, v12.4s, v19.4h \n" /* v10 += outr00.low*/ + "saddw2 v13.4s, v13.4s, v19.8h \n" /* v11 += outr00.high*/ + + "smax v12.4s, v12.4s, v21.4s \n" /* relu*/ + "smax v13.4s, v13.4s, v21.4s \n" /* relu*/ + + "bif v12.16b, v2.16b, v14.16b \n" + "bif v13.16b, v3.16b, v15.16b \n" + + "stp q12, q13, [%[ptr_out1]], #32 \n" /* store q10, q11 -> + ptr_out */ + + : [cnt] "+r"(cnt), + [din_ptr0] "+r"(din_ptr0), + [din_ptr1] "+r"(din_ptr1), + [din_ptr2] "+r"(din_ptr2), + [din_ptr3] "+r"(din_ptr3), + [ptr_out0] "+r"(doutr0), + [ptr_out1] "+r"(doutr1), + [vmask] "+r"(val_mask), + [rmask] "+r"(rst_mask) + : [v0] "w"(wr00), + [v1] "w"(wr01), + [v2] "w"(wr02), + [v3] "w"(wr10), + [bias_val] "r"(vbias), + [v4] "w"(wr11), + [v5] "w"(wr12), + [v6] "w"(wr20), + [v7] "w"(wr21), + [v8] "w"(wr22) + : "cc", + "memory", + "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15", + "v16", + "v17", + "v18", + "v19", + "v20", + "v21", + "v22"); +#else + // store weights + asm volatile("vld1.8 {d0-d1}, [%[wei_ptr]] \n" + : + : [wei_ptr] "r"(wei_ptr) + : "memory"); + asm volatile( + // left + "pld [%[din_ptr0]] @ preload data\n" + "pld [%[din_ptr1]] @ preload data\n" + "pld [%[din_ptr2]] @ preload data\n" + "pld [%[din_ptr3]] @ preload data\n" + "vdup.s8 d2, d0[0] @ d2 = w00, w00, w00, w00\n" + "vdup.s8 d3, d0[1] @ d3 = w01, w01, w01, w01\n" + "vdup.s8 d4, d0[2] @ d4 = w02, w02, w02, w02\n" + "vld1.8 {d12-d13}, [%[din_ptr0]] @ load din00= 0 1 2 3 4 5 6 7 " + "8 9\n" + "vmov.u32 d11, 
#0 @ zero\n" + // out0 + "vdup.32 q8, %[bias] @ and \n" // q8 = + // vbias + "vdup.32 q9, %[bias] @ and \n" // q9 = + // vbias + // out1 + "vdup.32 q10, %[bias] @ and \n" // q8 = + // vbias + "vdup.32 q11, %[bias] @ and \n" // q9 = + // vbias + + // r0 + "vmull.s8 q12, d12, d3 @ out0 = din0 * w01 \n" // q12 = d12 * w01 + "vext.8 d30, d11, d12, #7 @ ext \n" // d10 = 00123456 + "vext.8 d31, d12, d13, #1 @ ext \n" // d11 = 12345678 + + "vld1.8 {d12-d13}, [%[din_ptr1]] @ load din00= 0 1 2 3 4 5 6 7 " + "8 9\n" + "vld1.8 {d14-d15}, [%[din_ptr2]] @ load din00= 0 1 2 3 4 5 6 7 " + "8 9\n" + "vdup.s8 d5, d0[3] @ d5 = w10, w10, w00, w00\n" + "vdup.s8 d6, d0[4] @ d6 = w11, w11, w01, w01\n" + + "vmlal.s8 q12, d30, d2 @ out0 += din0 * w00 \n" // q12 += d10 * w00 + + "vdup.s8 d7, d0[5] @ d7 = w12, w12\n" + "add %[din_ptr0], #7 @add \n" + "add %[din_ptr1], #7 @add \n" + + "vmlal.s8 q12, d31, d4 @ out0 += din0 * w02 \n" // q12 += d11 * w02 + + // r1 + "vext.8 d30, d11, d12, #7 @ ext \n" // d10 = 00123456 + "vext.8 d31, d12, d13, #1 @ ext \n" // d11 = 12345678 + "vmull.s8 q13, d12, d3 @ out1 = din1 * w01 \n" // q13 = d12 * w01 + "vaddw.s16 q8, q8, d24 @addw \n" // out0 += + // vget_low_s16(out00) + "vaddw.s16 q9, q9, d25 @addw \n" // out0_1 += + // vget_high_s16(out00) + "vmull.s8 q12, d12, d6 @ out0 = din1 * w11 \n" // q12 = d12 * w11 + + "vld1.8 {d12-d13}, [%[din_ptr3]] @ load din00= 0 1 2 3 4 5 6 7 " + "8 9\n" + "vdup.s8 d8, d0[6] @ d8 = w20, w00, w00, w00\n" + "vdup.s8 d9, d0[7] @ d9 = w21, w01, w01, w01\n" + "vdup.s8 d10, d1[0] @ d10 = w22, w02, w02, w02\n" + + "vmlal.s8 q13, d30, d2 @ out1 += din1 * w00 \n" // q12 += d10 * w00 + "vmlal.s8 q12, d30, d5 @ out0 += din1 * w10 \n" // q12 += d10 * w00 + + "add %[din_ptr2], #7 @add \n" + "add %[din_ptr3], #7 @add \n" + + "vmlal.s8 q13, d31, d4 @ out1 += din1 * w02 \n" // q12 += d10 * w00 + "vmlal.s8 q12, d31, d7 @ out0 += din1 * w12 \n" // q12 += d10 * w00 + + // r2 + "vext.8 d30, d11, d14, #7 @ ext \n" // d10 = 00123456 + "vext.8 d31, d14, d15, #1 @ ext \n" // d11 = 12345678 + "vaddw.s16 q10, q10, d26 @addw \n" // out1 += + // vget_low_s16(out10) + "vaddw.s16 q11, q11, d27 @addw \n" // out1_1 += + // vget_high_s16(out10) + "vaddw.s16 q8, q8, d24 @addw \n" // out0 += + // vget_low_s16(out00) + "vaddw.s16 q9, q9, d25 @addw \n" // out0_1 += + // vget_high_s16(out00) + + "vmull.s8 q13, d14, d6 @ out1 = din2 * w11 \n" // q13 = d12 * w01 + "vmull.s8 q12, d14, d9 @ out1 = din2 * w21 \n" // q13 = d12 * w01 + + "vmlal.s8 q13, d30, d5 @ out1 += din2 * w10 \n" // q12 += d10 * w00 + "vmlal.s8 q12, d30, d8 @ out0 += din2 * w20 \n" // q12 += d10 * w00 + + "vmlal.s8 q13, d31, d7 @ out1 += din2 * w12 \n" // q12 += d10 * w00 + "vmlal.s8 q12, d31, d10 @ out0 += din2 * w22 \n" // q12 += d10 * w00 + + // r3 + "vext.8 d30, d11, d12, #7 @ ext \n" // d10 = 00123456 + "vext.8 d31, d12, d13, #1 @ ext \n" // d11 = 12345678 + "vmov.u32 q0, #0 @ mov \n" + "vaddw.s16 q10, q10, d26 @addw \n" // out1 += + // vget_low_s16(out10) + "vaddw.s16 q11, q11, d27 @addw \n" // out1_1 += + // vget_high_s16(out10) + "vaddw.s16 q8, q8, d24 @addw \n" // out0 += + // vget_low_s16(out00) + "vaddw.s16 q9, q9, d25 @addw \n" // out0_1 += + // vget_high_s16(out00) + + "vmull.s8 q13, d12, d9 @ out1 = din3 * w21 \n" // q13 = d12 * w01 + "pld [%[din_ptr0]] @ preload data\n" + "pld [%[din_ptr1]] @ preload data\n" + "vmax.s32 q8, q8, q0 @ max \n" + "vmax.s32 q9, q9, q0 @ max \n" + + "vmlal.s8 q13, d30, d8 @ out1 += din3 * w20 \n" // q13 += d10 * w00 + "pld [%[din_ptr2]] @ preload data\n" + "pld [%[din_ptr3]] @ 
preload data\n" + + "vst1.32 {d16-d17}, [%[dout_ptr1]]! @ store\n" + + "vmlal.s8 q13, d31, d10 @ out1 += din3 * w22 \n" // q12 += d10 * w00 + + "vst1.32 {d18-d19}, [%[dout_ptr1]]! @ store\n" + "vaddw.s16 q10, q10, d26 @addw \n" // out1 += + // vget_low_s16(out10) + "vaddw.s16 q11, q11, d27 @addw \n" // out1_1 += + // vget_high_s16(out10) + + "vmax.s32 q10, q10, q0 @ max \n" + "vmax.s32 q11, q11, q0 @ max \n" + + "vst1.32 {d20-d21}, [%[dout_ptr2]]! @ store\n" + "cmp %[cnt], #1 \n" + "vst1.32 {d22-d23}, [%[dout_ptr2]]! @ store\n" + "blt 1f \n" + + // mid + "2: \n" + "vld1.8 {d12-d13}, [%[din_ptr0]] @ load din00= 0 1 2 3 4 5 6 7 " + "8 9\n" + // out0 + "vdup.32 q8, %[bias] @ and \n" // q8 = + // vbias + "vdup.32 q9, %[bias] @ and \n" // q9 = + // vbias + // out1 + "vdup.32 q10, %[bias] @ and \n" // q8 = + // vbias + "vdup.32 q11, %[bias] @ and \n" // q9 = + // vbias + + // r0 + "vmull.s8 q12, d12, d2 @ out0 = din0 * w01 \n" // q12 = d12 * w01 + "vext.8 d30, d12, d13, #1 @ ext \n" // d10 = 12345678 + "vext.8 d31, d12, d13, #2 @ ext \n" // d11 = 23456789 + + "vld1.8 {d12-d13}, [%[din_ptr1]] @ load din00= 0 1 2 3 4 5 6 7 " + "8 9\n" + "vld1.8 {d14-d15}, [%[din_ptr2]] @ load din00= 0 1 2 3 4 5 6 7 " + "8 9\n" + + "vmlal.s8 q12, d30, d3 @ out0 += din0 * w00 \n" // q12 += d10 * w00 + + "add %[din_ptr0], #8 @add \n" + "add %[din_ptr1], #8 @add \n" + + "vmlal.s8 q12, d31, d4 @ out0 += din0 * w02 \n" // q12 += d11 * w02 + + // r1 + "vext.8 d30, d12, d13, #1 @ ext \n" // d10 = 00123456 + "vext.8 d31, d12, d13, #2 @ ext \n" // d11 = 12345678 + "vmull.s8 q13, d12, d2 @ out1 = din1 * w01 \n" // q13 = d12 * w01 + "vaddw.s16 q8, q8, d24 @addw \n" // out0 += + // vget_low_s16(out00) + "vaddw.s16 q9, q9, d25 @addw \n" // out0_1 += + // vget_high_s16(out00) + "vmull.s8 q12, d12, d5 @ out0 = din1 * w11 \n" // q12 = d12 * w11 + + "vld1.8 {d12-d13}, [%[din_ptr3]] @ load din00= 0 1 2 3 4 5 6 7 " + "8 9\n" + + "vmlal.s8 q13, d30, d3 @ out1 += din1 * w00 \n" // q12 += d10 * w00 + "vmlal.s8 q12, d30, d6 @ out0 += din1 * w10 \n" // q12 += d10 * w00 + + "add %[din_ptr2], #8 @add \n" + "add %[din_ptr3], #8 @add \n" + + "vmlal.s8 q13, d31, d4 @ out1 += din1 * w02 \n" // q12 += d10 * w00 + "vmlal.s8 q12, d31, d7 @ out0 += din1 * w12 \n" // q12 += d10 * w00 + + // r2 + "vext.8 d30, d14, d15, #1 @ ext \n" // d10 = 00123456 + "vext.8 d31, d14, d15, #2 @ ext \n" // d11 = 12345678 + "vaddw.s16 q10, q10, d26 @addw \n" // out1 += + // vget_low_s16(out10) + "vaddw.s16 q11, q11, d27 @addw \n" // out1_1 += + // vget_high_s16(out10) + "vaddw.s16 q8, q8, d24 @addw \n" // out0 += + // vget_low_s16(out00) + "vaddw.s16 q9, q9, d25 @addw \n" // out0_1 += + // vget_high_s16(out00) + + "vmull.s8 q13, d14, d5 @ out1 = din2 * w11 \n" // q13 = d12 * w01 + "vmull.s8 q12, d14, d8 @ out1 = din2 * w21 \n" // q13 = d12 * w01 + + "vmlal.s8 q13, d30, d6 @ out1 += din2 * w10 \n" // q12 += d10 * w00 + "vmlal.s8 q12, d30, d9 @ out0 += din2 * w20 \n" // q12 += d10 * w00 + + "vmlal.s8 q13, d31, d7 @ out1 += din2 * w12 \n" // q12 += d10 * w00 + "vmlal.s8 q12, d31, d10 @ out0 += din2 * w22 \n" // q12 += d10 * w00 + + // r3 + "vext.8 d30, d12, d13, #1 @ ext \n" // d10 = 00123456 + "vext.8 d31, d12, d13, #2 @ ext \n" // d11 = 12345678 + "vaddw.s16 q10, q10, d26 @addw \n" // out1 += + // vget_low_s16(out10) + "vaddw.s16 q11, q11, d27 @addw \n" // out1_1 += + // vget_high_s16(out10) + "vaddw.s16 q8, q8, d24 @addw \n" // out0 += + // vget_low_s16(out00) + "vaddw.s16 q9, q9, d25 @addw \n" // out0_1 += + // vget_high_s16(out00) + + "vmull.s8 q13, d12, d8 @ out1 = 
din3 * w21 \n" // q13 = d12 * w01 + "pld [%[din_ptr0]] @ preload data\n" + "pld [%[din_ptr1]] @ preload data\n" + "vmax.s32 q8, q8, q0 @ max \n" + "vmax.s32 q9, q9, q0 @ max \n" + + "vmlal.s8 q13, d30, d9 @ out1 += din3 * w20 \n" // q13 += d10 * w00 + "pld [%[din_ptr2]] @ preload data\n" + "pld [%[din_ptr3]] @ preload data\n" + + "vst1.32 {d16-d17}, [%[dout_ptr1]]! @ store\n" + + "vmlal.s8 q13, d31, d10 @ out1 += din3 * w22 \n" // q12 += d10 * w00 + + "vst1.32 {d18-d19}, [%[dout_ptr1]]! @ store\n" + "vaddw.s16 q10, q10, d26 @addw \n" // out1 += + // vget_low_s16(out10) + "vaddw.s16 q11, q11, d27 @addw \n" // out1_1 += + // vget_high_s16(out10) + + "vmax.s32 q10, q10, q0 @ max \n" + "vmax.s32 q11, q11, q0 @ max \n" + + "vst1.32 {d20-d21}, [%[dout_ptr2]]! @ store\n" + "subs %[cnt], #1 \n" + "vst1.32 {d22-d23}, [%[dout_ptr2]]! @ store\n" + "bne 2b \n" + // right + "1: \n" + "vld1.8 {d12-d13}, [%[din_ptr0]] @ load din00= 0 1 2 3 4 5 6 7 " + "8 9\n" + "vld1.8 {d28-d29}, [%[mask]] @ load din00= 0 1 2 3 4 5 6 7 " + "8 9\n" + // out0 + "vdup.32 q8, %[bias] @ and \n" // q8 = vbias + "vdup.32 q9, %[bias] @ and \n" // q9 = vbias + // out1 + "vdup.32 q10, %[bias] @ and \n" // q8 = vbias + "vdup.32 q11, %[bias] @ and \n" // q9 = vbias + + "vbif.8 d12, d11, d28 @ bit select, deal with right pad\n" + "vbif.8 d13, d11, d29 @ bit select, deal with right pad\n" + "vld1.8 {d14-d15}, [%[din_ptr1]] @ load din00= 0 1 2 3 4 5 6 7 " + "8 9\n" + + // r0 + "vmull.s8 q12, d12, d2 @ out0 = din0 * w00 \n" // q12 = d12 * w01 + "vext.8 d30, d12, d13, #1 @ ext \n" // d10 = 12345678 + "vext.8 d31, d12, d13, #2 @ ext \n" // d11 = 23456789 + + "vld1.8 {d12-d13}, [%[din_ptr2]] @ load din00= 0 1 2 3 4 5 6 7 " + "8 9\n" + "vbif.8 d14, d11, d28 @ bit select, deal with right pad\n" + "vbif.8 d15, d11, d29 @ bit select, deal with right pad\n" + + "vmlal.s8 q12, d30, d3 @ out0 += din0 * w01 \n" // q12 += d10 * w00 + + "vmlal.s8 q12, d31, d4 @ out0 += din0 * w02 \n" // q12 += d11 * w02 + + // r1 + "vext.8 d30, d14, d15, #1 @ ext \n" // d10 = 00123456 + "vext.8 d31, d14, d15, #2 @ ext \n" // d11 = 12345678 + + "vmull.s8 q13, d14, d2 @ out1 = din1 * w00 \n" // q13 = d12 * w01 + "vaddw.s16 q8, q8, d24 @addw \n" // out0 += + // vget_low_s16(out00) + "vaddw.s16 q9, q9, d25 @addw \n" // out0_1 += + // vget_high_s16(out00) + "vmull.s8 q12, d14, d5 @ out0 = din1 * w10 \n" // q12 = d12 * w11 + + "vld1.8 {d14-d15}, [%[din_ptr3]] @ load din00= 0 1 2 3 4 5 6 7 " + "8 9\n" + "vbif.8 d12, d11, d28 @ bit select, deal with " + "right pad\n" + "vbif.8 d13, d11, d29 @ bit select, deal with " + "right pad\n" + + "vmlal.s8 q13, d30, d3 @ out1 += din1 * w01 \n" // q12 += d10 * w00 + "vmlal.s8 q12, d30, d6 @ out0 += din1 * w11 \n" // q12 += d10 * w00 + + "vmlal.s8 q13, d31, d4 @ out1 += din1 * w02 \n" // q12 += d10 * w00 + "vmlal.s8 q12, d31, d7 @ out0 += din1 * w12 \n" // q12 += d10 * w00 + + // r2 + "vext.8 d30, d12, d13, #1 @ ext \n" // d10 = 00123456 + "vext.8 d31, d12, d13, #2 @ ext \n" // d11 = 12345678 + + "vaddw.s16 q10, q10, d26 @addw \n" // out1 += + // vget_low_s16(out10) + "vaddw.s16 q11, q11, d27 @addw \n" // out1_1 += + // vget_high_s16(out10) + "vaddw.s16 q8, q8, d24 @addw \n" // out0 += + // vget_low_s16(out00) + "vaddw.s16 q9, q9, d25 @addw \n" // out0_1 += + // vget_high_s16(out00) + + "vmull.s8 q13, d12, d5 @ out1 = din2 * w10 \n" // q13 = d12 * w01 + "vmull.s8 q12, d12, d8 @ out1 = din2 * w20 \n" // q13 = d12 * w01 + + "vbif.8 d14, d11, d28 @ bit select, deal with " + "right pad\n" + "vbif.8 d15, d11, d29 @ bit select, deal with " + 
"right pad\n" + + "vmlal.s8 q13, d30, d6 @ out1 += din2 * w10 \n" // q12 += d10 * w00 + "vmlal.s8 q12, d30, d9 @ out0 += din2 * w20 \n" // q12 += d10 * w00 + + "vld1.32 {d28-d29}, [%[dout_ptr1]]! @ load din00= 0 1 2 3 4 5 6 " + "7 8 9\n" + "vld1.32 {d12-d13}, [%[dout_ptr1]] @ load din00= 0 1 2 3 4 5 6 " + "7 8 9\n" + "vld1.32 {d2-d3}, [%[rs_mask]]! @ load din00= 0 1 2 3 4 5 6 7 " + "8 9\n" + "vld1.32 {d4-d5}, [%[rs_mask]] @ load din00= 0 1 2 3 4 5 6 7 8 " + "9\n" + + "vmlal.s8 q13, d31, d7 @ out1 += din2 * w12 \n" // q12 += d10 * w00 + "vmlal.s8 q12, d31, d10 @ out0 += din2 * w22 \n" // q12 += d10 * w00 + + // r3 + "vext.8 d30, d14, d15, #1 @ ext \n" // d10 = 00123456 + "vext.8 d31, d14, d15, #2 @ ext \n" // d11 = 12345678 + "vaddw.s16 q10, q10, d26 @addw \n" // out1 += + // vget_low_s16(out10) + "vaddw.s16 q11, q11, d27 @addw \n" // out1_1 += + // vget_high_s16(out10) + "vaddw.s16 q8, q8, d24 @addw \n" // out0 += + // vget_low_s16(out00) + "vaddw.s16 q9, q9, d25 @addw \n" // out0_1 += + // vget_high_s16(out00) + + "vmull.s8 q13, d14, d8 @ out1 = din3 * w20 \n" // q13 = d12 * w01 + "vld1.32 {d14-d15}, [%[dout_ptr2]]! @ load din00= 0 1 2 3 4 5 6 " + "7 8 9\n" + "vld1.32 {d24-d25}, [%[dout_ptr2]] @ load din00= 0 1 2 3 4 5 6 " + "7 8 9\n" + "vmax.s32 q8, q8, q0 @ max \n" + "vmax.s32 q9, q9, q0 @ max \n" + + "vmlal.s8 q13, d30, d9 @ out1 += din3 * w21 \n" // q13 += d10 * w00 + "vbif q8, q14, q1 @ bit select, deal with right " + "pad\n" + "vbif q9, q6, q2 @ bit select, deal with right " + "pad\n" + "sub %[dout_ptr1], #16 @ sub \n" + "sub %[dout_ptr2], #16 @ sub \n" + + "vmlal.s8 q13, d31, d10 @ out1 += din3 * w22 \n" // q12 += d10 * w00 + + "vst1.32 {d16-d17}, [%[dout_ptr1]]! @ store\n" + "vst1.32 {d18-d19}, [%[dout_ptr1]]! @ store\n" + "vaddw.s16 q10, q10, d26 @addw \n" // out1 += + // vget_low_s16(out10) + "vaddw.s16 q11, q11, d27 @addw \n" // out1_1 += + // vget_high_s16(out10) + + "vmax.s32 q10, q10, q0 @ max \n" + "vmax.s32 q11, q11, q0 @ max \n" + + "vbif q10, q7, q1 @ bit select, deal with right pad\n" + "vbif q11, q12, q2 @ bit select, deal with right pad\n" + + "vst1.32 {d20-d21}, [%[dout_ptr2]]! @ store\n" + "vst1.32 {d22-d23}, [%[dout_ptr2]]! @ store\n" + + : [din_ptr0] "+r"(din_ptr0), + [din_ptr1] "+r"(din_ptr1), + [din_ptr2] "+r"(din_ptr2), + [din_ptr3] "+r"(din_ptr3), + [dout_ptr1] "+r"(doutr0), + [dout_ptr2] "+r"(doutr1), + [cnt] "+r"(cnt), + [bias] "+r"(bias_val), + [rs_mask] "+r"(rst_mask) + : [mask] "r"(vmask) + : "cc", + "memory", + "q0", + "q1", + "q2", + "q3", + "q4", + "q5", + "q6", + "q7", + "q8", + "q9", + "q10", + "q11", + "q12", + "q13", + "q14", + "q15"); +#endif + dout_ptr += 2 * w_out; + } + } + } +} +// w_in <= 8 +void conv_depthwise_3x3s1p1_bias_s_relu_int7(int* dout, + const signed char* din, + const signed char* weights, + const int* bias, + bool flag_bias, + const int num, + const int ch_in, + const int h_in, + const int w_in, + const int h_out, + const int w_out, + ARMContext* ctx) { + // printf("3x3s1 mult height \n"); + //! pad is done implicit + const char zero[8] = {0, 0, 0, 0, 0, 0, 0, 0}; + //! 
for 4x6 convolution window
+  const unsigned char right_pad_idx[8] = {0, 1, 2, 3, 4, 5, 6, 7};
+  const unsigned int right_pad_rst[8] = {0, 1, 2, 3, 4, 5, 6, 7};
+
+  // printf("conv3x3_dw start \n");
+  signed char* zero_ptr = ctx->workspace_data<signed char>();
+  memset(zero_ptr, 0, w_in * sizeof(signed char));
+  int* write_ptr =
+      reinterpret_cast<int*>(ctx->workspace_data<signed char>()) + w_in;
+  int size_in_channel = w_in * h_in;
+  int size_out_channel = w_out * h_out;
+  int w_stride = 9;
+
+  int tile_h = (h_out + 3) >> 2;
+
+  unsigned int size_pad_right = (unsigned int)(w_in);
+
+  int size_pad_bottom = h_out % 4;
+
+  uint8x8_t vmask_rp =
+      vcgt_u8(vdup_n_u8(size_pad_right), vld1_u8(right_pad_idx));
+  unsigned int rst_remain = (unsigned int)w_out;
+  uint32x4_t vmask_result1 =
+      vcgtq_u32(vdupq_n_u32(rst_remain), vld1q_u32(right_pad_rst));
+  uint32x4_t vmask_result2 =
+      vcgtq_u32(vdupq_n_u32(rst_remain), vld1q_u32(right_pad_rst + 4));
+
+  unsigned char vmask[8];
+  vst1_u8(vmask, vmask_rp);
+
+  unsigned int rmask[8];
+  vst1q_u32(rmask, vmask_result1);
+  vst1q_u32(rmask + 4, vmask_result2);
+
+  int8x8_t vzero = vdup_n_s8(0);
+  int32x4_t vzero_32 = vdupq_n_s32(0);
+
+  for (int n = 0; n < num; ++n) {
+    const signed char* din_batch = din + n * ch_in * size_in_channel;
+    int* dout_batch = dout + n * ch_in * size_out_channel;
+#pragma omp parallel for
+    for (int c = 0; c < ch_in; c++) {
+      int* dout_ptr = dout_batch + c * size_out_channel;
+
+      const signed char* din_ch_ptr = din_batch + c * size_in_channel;
+
+      int bias_val = flag_bias ? bias[c] : 0;
+
+      const signed char* wei_ptr = weights + c * w_stride;
+#ifdef __aarch64__
+      int vbias[4] = {bias_val, bias_val, bias_val, bias_val};
+      int8x8_t wr00 = vdup_n_s8(wei_ptr[0]);
+      int8x8_t wr10 = vdup_n_s8(wei_ptr[3]);
+      int8x8_t wr20 = vdup_n_s8(wei_ptr[6]);
+
+      int8x8_t wr01 = vdup_n_s8(wei_ptr[1]);
+      int8x8_t wr11 = vdup_n_s8(wei_ptr[4]);
+      int8x8_t wr21 = vdup_n_s8(wei_ptr[7]);
+
+      int8x8_t wr02 = vdup_n_s8(wei_ptr[2]);
+      int8x8_t wr12 = vdup_n_s8(wei_ptr[5]);
+      int8x8_t wr22 = vdup_n_s8(wei_ptr[8]);
+#endif
+
+      int* doutr0 = nullptr;
+      int* doutr1 = nullptr;
+
+      const signed char* dr0 = din_ch_ptr;
+      const signed char* dr1 = dr0 + w_in;
+      const signed char* dr2 = dr1 + w_in;
+      const signed char* dr3 = dr2 + w_in;
+
+      const signed char* din_ptr0 = nullptr;
+      const signed char* din_ptr1 = nullptr;
+      const signed char* din_ptr2 = nullptr;
+      const signed char* din_ptr3 = nullptr;
+
+      for (int i = 0; i < h_in; i += 2) {
+        //! process top pad pad_h = 1
+        din_ptr0 = dr0;
+        din_ptr1 = dr1;
+        din_ptr2 = dr2;
+        din_ptr3 = dr3;
+
+        doutr0 = dout_ptr;
+        doutr1 = doutr0 + w_out;
+        unsigned int* rst_mask = rmask;
+        unsigned char* val_mask = vmask;
+
+        int out_buf1[8];
+        int out_buf2[8];
+
+        if (i == 0) {
+          din_ptr0 = zero_ptr;
+          din_ptr1 = dr0;
+          din_ptr2 = dr1;
+          din_ptr3 = dr2;
+          dr0 = dr1;
+          dr1 = dr2;
+          dr2 = dr3;
+          dr3 = dr2 + w_in;
+        } else {
+          dr0 = dr2;
+          dr1 = dr3;
+          dr2 = dr1 + w_in;
+          dr3 = dr2 + w_in;
+        }
+        //! process bottom pad
+        if (i + 3 > h_in) {
+          switch (i + 3 - h_in) {
+            case 3:
+              din_ptr1 = zero_ptr;
+            case 2:
+              din_ptr2 = zero_ptr;
+            case 1:
+              din_ptr3 = zero_ptr;
+            default:
+              break;
+          }
+        }
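+        // Note: two output rows are produced per iteration. Top padding
+        // is handled by substituting zero_ptr (a zeroed scratch row) for
+        // the first input row, bottom padding by the fall-through switch
+        // above; the asm writes the 8 lanes of each row into
+        // out_buf1/out_buf2, of which only the first w_out values are
+        // copied to the real output after the block.
+        //!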
process bottom remain + if (i + 2 > h_out) { + doutr1 = write_ptr; + } +#ifdef __aarch64__ + asm volatile( + "PRFM PLDL1KEEP, [%[din_ptr0]] \n" + "PRFM PLDL1KEEP, [%[din_ptr1]] \n" + "PRFM PLDL1KEEP, [%[din_ptr2]] \n" + "PRFM PLDL1KEEP, [%[din_ptr3]] \n" + "movi v21.4s, #0x0\n" /* out0 = 0 */ + // left + "ld1 {v4.8b}, [%[vmask]] \n" + "ld1 {v0.8b}, [%[din_ptr0]], #8 \n" /* load + a00-a015 + to + q0*/ + "ld1 {v1.8b}, [%[din_ptr1]], #8 \n" /* load + a00-a015 + to + q0*/ + "ld1 {v2.8b}, [%[din_ptr2]], #8 \n" /* load + a00-a015 + to + q0*/ + "ld1 {v3.8b}, [%[din_ptr3]], #8 \n" /* load + a00-a015 + to + q0*/ + + "bif v0.8b, v21.8b, v4.8b \n" + "bif v1.8b, v21.8b, v4.8b \n" + "bif v2.8b, v21.8b, v4.8b \n" + "bif v3.8b, v21.8b, v4.8b \n" + + "ext v6.8b, v21.8b, v0.8B, #7 \n" /* vext_s8(vzero, vinr0, 7); + 00123456 */ + "ext v7.8b, v0.8b, v21.8B, #1 \n" /* vext_s8(vinr0, vinr0_1, + 1); 12345678 */ + + "ld1 {v10.4s}, [%[vbias]] \n" + "ld1 {v11.4s}, [%[vbias]] \n" + + // r0 + "smull v18.8h, %[v1].8b, v0.8b \n" /* outr00 = 01234567 * w01 + */ + + "ext v8.8b, v21.8b, v1.8B, #7 \n" /* vext_s8(vzero, vinr0, 7); + 00123456 */ + "ext v9.8b, v1.8b, v21.8B, #1 \n" /* vext_s8(vinr0, vinr0_1, + 1); 12345678 */ + + "smlal v18.8h, %[v0].8b, v6.8b \n" /* outr00 = 01234567 * w00 + */ + + "ld1 {v12.4s}, [%[vbias]] \n" + "ld1 {v13.4s}, [%[vbias]] \n" + + "smlal v18.8h, %[v2].8b, v7.8b \n" /* outr00 = 01234567 * w00 + */ + + "ext v6.8b, v21.8b, v2.8B, #7 \n" /* vext_s8(vzero, vinr0, 7); + 00123456 */ + "ext v7.8b, v2.8b, v21.8B, #1 \n" /* vext_s8(vinr0, vinr0_1, + 1); 12345678 */ + + // r1 + "smull v19.8h, %[v1].8b, v1.8b \n" /* outr00 = 01234567 * w00 + */ + "smlal v18.8h, %[v4].8b, v1.8b \n" /* outr00 = 01234567 * w00 + */ + + // "ld1 {v14.4s}, [%[rmask]], #16 \n" + // "ld1 {v15.4s}, [%[rmask]] \n" + + "smlal v19.8h, %[v0].8b, v8.8b \n" /* outr00 = 01234567 * w00 + */ + "smlal v18.8h, %[v3].8b, v8.8b \n" /* outr00 = 01234567 * w00 + */ + + // "ld1 {v16.4s}, [%[ptr_out0]], #16 \n" + // "ld1 {v17.4s}, [%[ptr_out1]], #16 \n" + + "smlal v19.8h, %[v2].8b, v9.8b \n" /* outr00 = 01234567 * w00 + */ + "smlal v18.8h, %[v5].8b, v9.8b \n" /* outr00 = 01234567 * w00 + */ + + "ext v8.8b, v21.8b, v3.8B, #7 \n" /* vext_s8(vzero, vinr0, 7); + 00123456 */ + "ext v9.8b, v3.8b, v21.8B, #1 \n" // vext_s8(vinr0, vinr0_1, + // 1); 12345678 + + // "ld1 {v0.4s}, [%[ptr_out0]] \n" + // "ld1 {v1.4s}, [%[ptr_out1]] \n" + + // r2 + "smlal v19.8h, %[v4].8b, v2.8b \n" /* outr00 = 01234567 * w00 + */ + "smlal v18.8h, %[v7].8b, v2.8b \n" /* outr00 = 01234567 * w00 + */ + + // "sub %[ptr_out0], %[ptr_out0], #16 \n" + // "sub %[ptr_out1], %[ptr_out1], #16 \n" + + "smlal v19.8h, %[v3].8b, v6.8b \n" /* outr00 = 01234567 * w00 + */ + "smlal v18.8h, %[v6].8b, v6.8b \n" /* outr00 = 01234567 * w00 + */ + + "smlal v19.8h, %[v5].8b, v7.8b \n" /* outr00 = 01234567 * w00 + */ + + "saddw v10.4s, v10.4s, v18.4h \n" /* v10 += outr00.low*/ + "saddw2 v11.4s, v11.4s, v18.8h \n" /* v11 += outr00.high*/ + + "smull v18.8h, %[v8].8b, v7.8b \n" /* outr00 = 01234567 * w00 + */ + + // r3 + "smlal v19.8h, %[v7].8b, v3.8b \n" /* outr00 = 01234567 * w00 + */ + + "saddw v10.4s, v10.4s, v18.4h \n" /* v10 += outr00.low*/ + "saddw2 v11.4s, v11.4s, v18.8h \n" /* v11 += outr00.high*/ + + "smlal v19.8h, %[v6].8b, v8.8b \n" /* outr00 = 01234567 * w00 + */ + + "smax v10.4s, v10.4s, v21.4s \n" /* relu */ + "smax v11.4s, v11.4s, v21.4s \n" /* relu */ + + // "bif v10.16b, v16.16b, v14.16b \n" + // "bif v11.16b, v0.16b, v15.16b \n" + + "saddw v12.4s, v12.4s, v19.4h \n" /* v10 += 
outr00.low*/ + "saddw2 v13.4s, v13.4s, v19.8h \n" /* v11 += outr00.high*/ + + "smull v19.8h, %[v8].8b, v9.8b \n" /* outr00 = 01234567 * w00 + */ + + "stp q10, q11, [%[ptr_out0]] \n" /* store q10, q11 -> ptr_out */ + + "saddw v12.4s, v12.4s, v19.4h \n" /* v10 += outr00.low*/ + "saddw2 v13.4s, v13.4s, v19.8h \n" /* v11 += outr00.high*/ + + "smax v12.4s, v12.4s, v21.4s \n" /* relu */ + "smax v13.4s, v13.4s, v21.4s \n" /* relu */ + + // "bif v12.16b, v17.16b, v14.16b \n" + // "bif v13.16b, v1.16b, v15.16b \n" + + "stp q12, q13, [%[ptr_out1]] \n" /* store q10, q11 -> ptr_out + */ + + : [din_ptr0] "+r"(din_ptr0), + [din_ptr1] "+r"(din_ptr1), + [din_ptr2] "+r"(din_ptr2), + [din_ptr3] "+r"(din_ptr3), + [rmask] "+r"(rst_mask) + : [v0] "w"(wr00), + [v1] "w"(wr01), + [v2] "w"(wr02), + [v3] "w"(wr10), + [vbias] "r"(vbias), + [v4] "w"(wr11), + [v5] "w"(wr12), + [v6] "w"(wr20), + [v7] "w"(wr21), + [v8] "w"(wr22), + [vmask] "r"(vmask), + [ptr_out0] "r"(out_buf1), + [ptr_out1] "r"(out_buf2) + : "cc", + "memory", + "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15", + "v16", + "v17", + "v18", + "v19", + "v20", + "v21", + "v22"); +#else + // store weights + asm volatile("vld1.8 {d0-d1}, [%[wei_ptr]] \n" + : + : [wei_ptr] "r"(wei_ptr) + : "memory"); + asm volatile( + // left + "pld [%[din_ptr0]] @ preload data\n" + "pld [%[din_ptr1]] @ preload data\n" + "pld [%[din_ptr2]] @ preload data\n" + "pld [%[din_ptr3]] @ preload data\n" + "vld1.8 {d28}, [%[mask]] @ load din00= 0 1 2 3 4 5 6 7 8 9\n" + "vld1.8 {d12}, [%[din_ptr0]] @ load din00= 0 1 2 3 4 5 6 7 8 9\n" + "vld1.8 {d13}, [%[din_ptr1]] @ load din00= 0 1 2 3 4 5 6 7 8 9\n" + "vdup.s8 d2, d0[0] @ d2 = w00, w00, w00, w00\n" + "vdup.s8 d3, d0[1] @ d3 = w01, w01, w01, w01\n" + "vdup.s8 d4, d0[2] @ d4 = w02, w02, w02, w02\n" + + "vmov.u32 d11, #0 @ zero\n" + // out0 + "vdup.32 q8, %[bias] @ and \n" // q8 = + // vbias + "vdup.32 q9, %[bias] @ and \n" // q9 = + // vbias + + "vbif.8 d12, d11, d28 @ bit select, deal with right pad\n" + "vbif.8 d13, d11, d28 @ bit select, deal with right pad\n" + "vld1.8 {d14}, [%[din_ptr2]] @ load din00= 0 1 2 3 4 5 6 7 8 9\n" + "vld1.8 {d15}, [%[din_ptr3]] @ load din00= 0 1 2 3 4 5 6 7 8 9\n" + // out1 + "vdup.32 q10, %[bias] @ and \n" // q8 = + // vbias + "vdup.32 q11, %[bias] @ and \n" // q9 = + // vbias + + // r0 + "vmull.s8 q12, d12, d3 @ out0 = din0 * w01 \n" // q12 = d12 * w01 + "vext.8 d30, d11, d12, #7 @ ext \n" // d10 = 00123456 + "vext.8 d31, d12, d11, #1 @ ext \n" // d11 = 12345678 + + "vdup.s8 d5, d0[3] @ d5 = w10, w10, w00, w00\n" + "vdup.s8 d6, d0[4] @ d6 = w11, w11, w01, w01\n" + + "vmlal.s8 q12, d30, d2 @ out0 += din0 * w00 \n" // q12 += d10 * w00 + + "vdup.s8 d7, d0[5] @ d7 = w12, w12\n" + "vbif.8 d14, d11, d28 @ bit select, deal with right pad\n" + "vbif.8 d15, d11, d28 @ bit select, deal with right pad\n" + + "vmlal.s8 q12, d31, d4 @ out0 += din0 * w02 \n" // q12 += d11 * w02 + + // r1 + "vext.8 d30, d11, d13, #7 @ ext \n" // d10 = 00123456 + "vext.8 d31, d13, d11, #1 @ ext \n" // d11 = 12345678 + "vmull.s8 q13, d13, d3 @ out1 = din1 * w01 \n" // q13 = d12 * w01 + "vaddw.s16 q8, q8, d24 @addw \n" // out0 += + // vget_low_s16(out00) + "vaddw.s16 q9, q9, d25 @addw \n" // out0_1 += + // vget_high_s16(out00) + "vmull.s8 q12, d13, d6 @ out0 = din1 * w11 \n" // q12 = d12 * w11 + + "vdup.s8 d8, d0[6] @ d8 = w20, w00, w00, w00\n" + "vdup.s8 d9, d0[7] @ d9 = w21, w01, w01, w01\n" + + "vmlal.s8 q13, d30, d2 @ out1 += din1 * w00 \n" // q12 += 
d10 * w00 + "vmlal.s8 q12, d30, d5 @ out0 += din1 * w10 \n" // q12 += d10 * w00 + + "vdup.s8 d10, d1[0] @ d10 = w22, w02, w02, w02\n" + // "vld1.32 {d28-d29}, [%[dout_ptr1]]! @ load din00= 0 1 2 3 4 5 + // 6 7 8 9\n" "vld1.32 {d12-d13}, [%[dout_ptr1]] @ load din00= 0 + // 1 2 3 4 5 6 7 8 9\n" + + "vmlal.s8 q13, d31, d4 @ out1 += din1 * w02 \n" // q12 += d10 * w00 + "vmlal.s8 q12, d31, d7 @ out0 += din1 * w12 \n" // q12 += d10 * w00 + + // r2 + "vext.8 d30, d11, d14, #7 @ ext \n" // d10 = 00123456 + "vext.8 d31, d14, d11, #1 @ ext \n" // d11 = 12345678 + "vaddw.s16 q10, q10, d26 @addw \n" // out1 += + // vget_low_s16(out10) + "vaddw.s16 q11, q11, d27 @addw \n" // out1_1 += + // vget_high_s16(out10) + "vaddw.s16 q8, q8, d24 @addw \n" // out0 += + // vget_low_s16(out00) + "vaddw.s16 q9, q9, d25 @addw \n" // out0_1 += + // vget_high_s16(out00) + + "vmull.s8 q13, d14, d6 @ out1 = din2 * w11 \n" // q13 = d12 * w01 + "vmull.s8 q12, d14, d9 @ out1 = din2 * w21 \n" // q13 = d12 * w01 + + // "sub %[dout_ptr1], #16 @ sub \n" + "vmlal.s8 q13, d30, d5 @ out1 += din2 * w10 \n" // q12 += d10 * w00 + "vmlal.s8 q12, d30, d8 @ out0 += din2 * w20 \n" // q12 += d10 * w00 + + // "vld1.32 {d2-d3}, [%[rs_mask]]! @ load din00= 0 1 2 3 4 5 6 7 + // 8 9\n" "vld1.32 {d4-d5}, [%[rs_mask]] @ load din00= 0 1 2 3 4 + // 5 6 7 8 9\n" + + "vmlal.s8 q13, d31, d7 @ out1 += din2 * w12 \n" // q12 += d10 * w00 + "vmlal.s8 q12, d31, d10 @ out0 += din2 * w22 \n" // q12 += d10 * w00 + + // r3 + "vext.8 d30, d11, d15, #7 @ ext \n" // d10 = 00123456 + "vext.8 d31, d15, d11, #1 @ ext \n" // d11 = 12345678 + "vaddw.s16 q10, q10, d26 @addw \n" // out1 += + // vget_low_s16(out10) + "vaddw.s16 q11, q11, d27 @addw \n" // out1_1 += + // vget_high_s16(out10) + "vaddw.s16 q8, q8, d24 @addw \n" // out0 += + // vget_low_s16(out00) + "vaddw.s16 q9, q9, d25 @addw \n" // out0_1 += + // vget_high_s16(out00) + + "vmull.s8 q13, d15, d9 @ out1 = din3 * w21 \n" // q13 = d12 * w01 + + "vmov.u32 q0, #0 @ zero\n" + + // "vld1.32 {d6-d7}, [%[dout_ptr2]]! @ load din00= 0 1 2 3 4 5 6 + // 7 8 9\n" "vld1.32 {d14-d15}, [%[dout_ptr2]] @ load din00= 0 1 + // 2 3 4 5 6 7 8 9\n" + + "vmlal.s8 q13, d30, d8 @ out1 += din3 * w20 \n" // q13 += d10 * w00 + + "vmax.s32 q8, q8, q0 @ max \n" + "vmax.s32 q9, q9, q0 @ max \n" + + "vmlal.s8 q13, d31, d10 @ out1 += din3 * w22 \n" // q12 += d10 * w00 + + // "sub %[dout_ptr2], #16 @ sub \n" + // "vbif q8, q14, q1 @ bit select, deal with right + // pad\n" "vbif q9, q6, q2 @ bit select, deal + // with right pad\n" + + "vaddw.s16 q10, q10, d26 @addw \n" // out1 += + // vget_low_s16(out10) + "vaddw.s16 q11, q11, d27 @addw \n" // out1_1 += + // vget_high_s16(out10) + + "vst1.32 {d16-d19}, [%[dout_ptr1]] @ store\n" + // "vst1.32 {d18-d19}, [%[dout_ptr1]]! @ store\n" + + "vmax.s32 q10, q10, q0 @ max \n" + "vmax.s32 q11, q11, q0 @ max \n" + + // "vbif q10, q3, q1 @ bit select, deal with right + // pad\n" "vbif q11, q7, q2 @ bit select, deal + // with right pad\n" + + "vst1.32 {d20-d23}, [%[dout_ptr2]] @ store\n" + // "vst1.32 {d22-d23}, [%[dout_ptr2]]! 
@ store\n"
+          : [din_ptr0] "+r"(din_ptr0),
+            [din_ptr1] "+r"(din_ptr1),
+            [din_ptr2] "+r"(din_ptr2),
+            [din_ptr3] "+r"(din_ptr3),
+            [bias] "+r"(bias_val),
+            [rs_mask] "+r"(rst_mask)
+          : [mask] "r"(vmask),
+            [dout_ptr1] "r"(out_buf1),
+            [dout_ptr2] "r"(out_buf2)
+          : "cc",
+            "memory",
+            "q0",
+            "q1",
+            "q2",
+            "q3",
+            "q4",
+            "q5",
+            "q6",
+            "q7",
+            "q8",
+            "q9",
+            "q10",
+            "q11",
+            "q12",
+            "q13",
+            "q14",
+            "q15");
+#endif
+        for (int w = 0; w < w_out; ++w) {
+          *doutr0++ = out_buf1[w];
+          *doutr1++ = out_buf2[w];
+        }
+        dout_ptr += 2 * w_out;
+      }
+    }
+  }
+}
+
+// 1 line w_in > 16
+void conv_depthwise_3x3s2p1_bias_relu_int7(int* dout,
+                                           const signed char* din,
+                                           const signed char* weights,
+                                           const int* bias,
+                                           bool flag_bias,
+                                           const int num,
+                                           const int ch_in,
+                                           const int h_in,
+                                           const int w_in,
+                                           const int h_out,
+                                           const int w_out,
+                                           ARMContext* ctx) {
+  // printf("3x3s2 mult height \n");
+  //! pad is done implicit
+  //! for 4x6 convolution window
+  const unsigned char right_pad_idx[16] = {
+      0, 2, 4, 6, 8, 10, 12, 14, 1, 3, 5, 7, 9, 11, 13, 15};
+  const unsigned int right_pad_rst[8] = {0, 1, 2, 3, 4, 5, 6, 7};
+
+  // printf("conv3x3_dw start \n");
+  signed char* zero_ptr = ctx->workspace_data<signed char>();
+  memset(zero_ptr, 0, w_in * sizeof(signed char));
+  int* write_ptr =
+      reinterpret_cast<int*>(ctx->workspace_data<signed char>()) + w_out;
+  int size_in_channel = w_in * h_in;
+  int size_out_channel = w_out * h_out;
+  int w_stride = 9;
+
+  int tile_w = (w_in + 15) >> 4;
+  int cnt_col = tile_w - 2;
+
+  unsigned int size_pad_right = (unsigned int)(w_in - 15 - (cnt_col << 4));
+  if (size_pad_right == 17) {
+    size_pad_right = 0;
+    cnt_col++;
+  }
+
+  uint8x8_t vmask_rp1 =
+      vcgt_u8(vdup_n_u8(size_pad_right), vld1_u8(right_pad_idx));
+  uint8x8_t vmask_rp2 =
+      vcgt_u8(vdup_n_u8(size_pad_right), vld1_u8(right_pad_idx + 8));
+  unsigned int rst_remain = (unsigned int)(w_out - ((cnt_col + 1) << 3));
+  uint32x4_t vmask_result1 =
+      vcgtq_u32(vdupq_n_u32(rst_remain), vld1q_u32(right_pad_rst));
+  uint32x4_t vmask_result2 =
+      vcgtq_u32(vdupq_n_u32(rst_remain), vld1q_u32(right_pad_rst + 4));
+
+  int8x8_t vzero = vdup_n_s8(0);
+  int32x4_t vzero_32 = vdupq_n_s32(0);
+
+  uint8x16_t vmask_rp =
+      vcgtq_u8(vdupq_n_u8(size_pad_right), vld1q_u8(right_pad_idx));
+  unsigned char vmask[16];
+  vst1q_u8(vmask, vmask_rp);
+
+  unsigned int rmask[8];
+  vst1q_u32(rmask, vmask_result1);
+  vst1q_u32(rmask + 4, vmask_result2);
+
+  for (int n = 0; n < num; ++n) {
+    const signed char* din_batch = din + n * ch_in * size_in_channel;
+    int* dout_batch = dout + n * ch_in * size_out_channel;
+
+#pragma omp parallel for
+    for (int c = 0; c < ch_in; c++) {
+      int* dout_ptr = dout_batch + c * size_out_channel;
+
+      const signed char* din_ch_ptr = din_batch + c * size_in_channel;
+
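+      // Note on the column tiling computed above: the left block consumes
+      // 15 input columns (one lane is the implicit left pad), each
+      // main-loop iteration then consumes 16 columns and emits 8 outputs,
+      // and size_pad_right counts the tail columns handled by the masked
+      // right block. A raw value of 17 means the tail is itself a full
+      // block, so it is folded back into the main loop
+      // (size_pad_right = 0, cnt_col++).
+      int bias_val = flag_bias ?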
bias[c] : 0; + + const signed char* wei_ptr = weights + c * w_stride; +#ifdef __aarch64__ + int vbias[4] = {bias_val, bias_val, bias_val, bias_val}; + int8x8_t wr00 = vdup_n_s8(wei_ptr[0]); + int8x8_t wr10 = vdup_n_s8(wei_ptr[3]); + int8x8_t wr20 = vdup_n_s8(wei_ptr[6]); + + int8x8_t wr01 = vdup_n_s8(wei_ptr[1]); + int8x8_t wr11 = vdup_n_s8(wei_ptr[4]); + int8x8_t wr21 = vdup_n_s8(wei_ptr[7]); + + int8x8_t wr02 = vdup_n_s8(wei_ptr[2]); + int8x8_t wr12 = vdup_n_s8(wei_ptr[5]); + int8x8_t wr22 = vdup_n_s8(wei_ptr[8]); +#endif + + int* doutr0 = nullptr; + + const signed char* dr0 = din_ch_ptr; + const signed char* dr1 = dr0 + w_in; + const signed char* dr2 = dr1 + w_in; + + const signed char* din_ptr0 = nullptr; + const signed char* din_ptr1 = nullptr; + const signed char* din_ptr2 = nullptr; + + for (int i = 0; i < h_in; i += 2) { + //! process top pad pad_h = 1 + din_ptr0 = dr0; + din_ptr1 = dr1; + din_ptr2 = dr2; + + doutr0 = dout_ptr; + if (i == 0) { + din_ptr0 = zero_ptr; + din_ptr1 = dr0; + din_ptr2 = dr1; + dr0 = dr1; + dr1 = dr2; + dr2 = dr1 + w_in; + } else { + dr0 = dr2; + dr1 = dr0 + w_in; + dr2 = dr1 + w_in; + } + //! process bottom pad + if (i + 2 > h_in) { + switch (i + 2 - h_in) { + case 2: + din_ptr1 = zero_ptr; + case 1: + din_ptr2 = zero_ptr; + default: + break; + } + } + int cnt = cnt_col; +#ifdef __aarch64__ + unsigned char* val_mask = vmask; + asm volatile( + "PRFM PLDL1KEEP, [%[din_ptr0]] \n" + "PRFM PLDL1KEEP, [%[din_ptr1]] \n" + "PRFM PLDL1KEEP, [%[din_ptr2]] \n" + "movi v10.4s, #0x0\n" + // left + "ld2 {v0.8b - v1.8b}, [%[din_ptr0]] \n" /*load a00-a015 + to q0*/ + "ld2 {v2.8b - v3.8b}, [%[din_ptr1]] \n" /* load a00-a015 + to q0*/ + "ld2 {v4.8b - v5.8b}, [%[din_ptr2]] \n" /*load a00-a015 + to q0*/ + + "ld1 {v12.4s}, [%[bias_val]] \n" /* dup v10, bias*/ + "ld1 {v13.4s}, [%[bias_val]] \n" /* dup v10, bias */ + + "ext v6.8b, v10.8b, v1.8B, #7 \n" /* vext_s8(vzero, vinr0, 7); + 013579 */ + "ext v7.8b, v10.8b, v3.8B, #7 \n" /* vext_s8(vzero, vinr0, 7); + 013579 */ + "ext v8.8b, v10.8b, v5.8B, #7 \n" /* vext_s8(vzero, vinr0, 7); + 013579 */ + + // r0 + "smull v14.8h, %[v1].8b, v0.8b \n" /* outr00 = 02468 * w01 */ + "smull v15.8h, %[v2].8b, v1.8b\n" /* outr00 += 13579 * w02 */ + "smull v16.8h, %[v0].8b, v6.8b\n" /* outr00 += 013579 * w00 */ + + "add %[din_ptr0], %[din_ptr0], #15 \n" + "add %[din_ptr1], %[din_ptr1], #15 \n" + "add %[din_ptr2], %[din_ptr2], #15 \n" + + // r1 + "smlal v14.8h, %[v4].8b, v2.8b \n" /* outr00 = 02468 * w01 */ + "smlal v15.8h, %[v5].8b, v3.8b\n" /* outr00 += 13579 * w02 */ + "smlal v16.8h, %[v3].8b, v7.8b\n" /* outr00 += 013579 * w00 */ + + // r2 + "smlal v14.8h, %[v7].8b, v4.8b \n" /* outr00 = 02468 * w01 */ + "smlal v15.8h, %[v8].8b, v5.8b\n" /* outr00 += 13579 * w02 */ + "smlal v16.8h, %[v6].8b, v8.8b\n" /* outr00 += 013579 * w00 */ + + "ld2 {v0.8b - v1.8b}, [%[din_ptr0]], #16 \n" /*load + a00-a015 + to q0*/ + "ld2 {v2.8b - v3.8b}, [%[din_ptr1]], #16 \n" /* load + a00-a015 + to q0*/ + "ld2 {v4.8b - v5.8b}, [%[din_ptr2]], #16 \n" /*load + a00-a015 + to q0*/ + + "saddw v12.4s, v12.4s, v14.4h \n" /* v10 += outr00.low*/ + "saddw2 v13.4s, v13.4s, v14.8h \n" /* v11 += outr00.high*/ + + "saddw v12.4s, v12.4s, v15.4h \n" /* v10 += outr00.low*/ + "saddw2 v13.4s, v13.4s, v15.8h \n" /* v11 += outr00.high*/ + + "saddw v12.4s, v12.4s, v16.4h \n" /* v10 += outr00.low*/ + "saddw2 v13.4s, v13.4s, v16.8h \n" /* v11 += outr00.high*/ + + "smax v12.4s, v12.4s, v10.4s \n" /*relu*/ + "smax v13.4s, v13.4s, v10.4s \n" /*relu*/ + + "stp q12, q13, [%[ptr_out0]], #32 \n" /* 
store q10, q11 -> + ptr_out */ + + "ld1 {v12.4s}, [%[bias_val]] \n" /* dup v10, bias */ + "ld1 {v13.4s}, [%[bias_val]] \n" /* dup v10, bias */ + + "cmp %[cnt], #1 \n" + "blt 3f \n" + // mid + "1: \n" + "ld1 {v6.8b}, [%[din_ptr0]] \n" /*load a00-a015 to q0*/ + "ld1 {v7.8b}, [%[din_ptr1]] \n" /*load a00-a015 to q0*/ + "ld1 {v8.8b}, [%[din_ptr2]] \n" /*load a00-a015 to q0*/ + + "ext v9.8b, v0.8b, v6.8B, #1 \n" /* vext_s8(vzero, vinr0, 7); + 246810 */ + "ext v11.8b, v2.8b, v7.8B, #1 \n" /* vext_s8(vzero, vinr0, 7); + 246810 */ + "ext v14.8b, v4.8b, v8.8B, #1 \n" /* vext_s8(vzero, vinr0, 7); + 246810 */ + + // r0 + "smull v6.8h, %[v0].8b, v0.8b \n" /* outr00 = 02468 * w00 */ + "smull v7.8h, %[v1].8b, v1.8b\n" /* outr00 += 13579 * w01 */ + "smull v8.8h, %[v2].8b, v9.8b\n" /* outr00 += 246810 * w02 */ + + // r1 + "smlal v6.8h, %[v3].8b, v2.8b \n" /* outr00 = 02468 * w00 */ + "smlal v7.8h, %[v4].8b, v3.8b\n" /* outr00 += 13579 * w01 */ + "smlal v8.8h, %[v5].8b, v11.8b\n" /* outr00 += 246810 * w02 */ + + // r2 + "smlal v6.8h, %[v6].8b, v4.8b \n" /* outr00 = 02468 * w00 */ + "smlal v7.8h, %[v7].8b, v5.8b\n" /* outr00 += 13579 * w01 */ + "smlal v8.8h, %[v8].8b, v14.8b\n" /* outr00 += 246810 * w02 */ + + "ld2 {v0.8b - v1.8b}, [%[din_ptr0]], #16 \n" /*load + a00-a015 + to q0*/ + "ld2 {v2.8b - v3.8b}, [%[din_ptr1]], #16 \n" /* load + a00-a015 + to q0*/ + "ld2 {v4.8b - v5.8b}, [%[din_ptr2]], #16 \n" /*load + a00-a015 + to q0*/ + + "saddw v12.4s, v12.4s, v6.4h \n" /* v10 += outr00.low*/ + "saddw2 v13.4s, v13.4s, v6.8h \n" /* v11 += outr00.high*/ + + "saddw v12.4s, v12.4s, v7.4h \n" /* v10 += outr00.low*/ + "saddw2 v13.4s, v13.4s, v7.8h \n" /* v11 += outr00.high*/ + + "saddw v12.4s, v12.4s, v8.4h \n" /* v10 += outr00.low*/ + "saddw2 v13.4s, v13.4s, v8.8h \n" /* v11 += outr00.high*/ + + "smax v12.4s, v12.4s, v10.4s \n" /*relu*/ + "smax v13.4s, v13.4s, v10.4s \n" /*relu*/ + + "subs %[cnt], %[cnt], #1 \n" + + "stp q12, q13, [%[ptr_out0]], #32 \n" /* store q10, q11 -> + ptr_out */ + + "ld1 {v12.4s}, [%[bias_val]] \n" /* dup v10, bias */ + "ld1 {v13.4s}, [%[bias_val]] \n" /* dup v10, bias */ + "bne 1b \n" + // right + "3: \n" + "ld1 {v14.8b}, [%[vmask]], #8 \n" + "ld1 {v15.8b}, [%[vmask]] \n" + + "bif v0.8b, v10.8b, v14.8b \n" + "bif v1.8b, v10.8b, v15.8b \n" + "bif v2.8b, v10.8b, v14.8b \n" + "bif v3.8b, v10.8b, v15.8b \n" + "bif v4.8b, v10.8b, v14.8b \n" + "bif v5.8b, v10.8b, v15.8b \n" + + "ext v6.8b, v0.8b, v10.8B, #1 \n" /* vext_s8(vzero, vinr0, 7); + 2468.. */ + "ext v7.8b, v2.8b, v10.8B, #1 \n" /* vext_s8(vzero, vinr0, 7); + 2468..*/ + "ext v8.8b, v4.8b, v10.8B, #1 \n" /* vext_s8(vzero, vinr0, 7); + 2468.. 
*/ + + // r0 + "smull v14.8h, %[v0].8b, v0.8b \n" /* outr00 = 02468 * w00 */ + "smull v15.8h, %[v1].8b, v1.8b\n" /* outr00 += 13579 * w01 */ + "smull v16.8h, %[v2].8b, v6.8b\n" /* outr00 += 246810 * w02 */ + + // r1 + "smlal v14.8h, %[v3].8b, v2.8b \n" /* outr00 = 02468 * w00 */ + "smlal v15.8h, %[v4].8b, v3.8b\n" /* outr00 += 13579 * w01 */ + "smlal v16.8h, %[v5].8b, v7.8b\n" /* outr00 += 246810 * w02 */ + + // r2 + "smlal v14.8h, %[v6].8b, v4.8b \n" /* outr00 = 02468 * w00 */ + "smlal v15.8h, %[v7].8b, v5.8b\n" /* outr00 += 13579 * w01 */ + "smlal v16.8h, %[v8].8b, v8.8b\n" /* outr00 += 246810 * w02 */ + + "ldp q0, q1, [%[ptr_out0]] \n" /* dup v10, bias */ + "ldp q9, q11, [%[rst_mask]] \n" /* dup v10, bias */ + + "saddw v12.4s, v12.4s, v14.4h \n" /* v10 += outr00.low*/ + "saddw2 v13.4s, v13.4s, v14.8h \n" /* v11 += outr00.high*/ + + "saddw v12.4s, v12.4s, v15.4h \n" /* v10 += outr00.low*/ + "saddw2 v13.4s, v13.4s, v15.8h \n" /* v11 += outr00.high*/ + + "saddw v12.4s, v12.4s, v16.4h \n" /* v10 += outr00.low*/ + "saddw2 v13.4s, v13.4s, v16.8h \n" /* v11 += outr00.high*/ + + "smax v12.4s, v12.4s, v10.4s \n" /*relu*/ + "smax v13.4s, v13.4s, v10.4s \n" /*relu*/ + + "bif v12.16b, v0.16b, v9.16b \n" + "bif v13.16b, v1.16b, v11.16b \n" + + "stp q12, q13, [%[ptr_out0]], #32 \n" /* store q10, q11 -> + ptr_out */ + + : [cnt] "+r"(cnt), + [din_ptr0] "+r"(din_ptr0), + [din_ptr1] "+r"(din_ptr1), + [din_ptr2] "+r"(din_ptr2), + [ptr_out0] "+r"(doutr0), + [vmask] "+r"(val_mask) + : [v0] "w"(wr00), + [v1] "w"(wr01), + [v2] "w"(wr02), + [v3] "w"(wr10), + [bias_val] "r"(vbias), + [v4] "w"(wr11), + [v5] "w"(wr12), + [v6] "w"(wr20), + [v7] "w"(wr21), + [v8] "w"(wr22), + [rst_mask] "r"(rmask) + : "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15", + "v16"); +#else + unsigned int* rst_mask = rmask; + // prefetch input + // store weights + asm volatile("vld1.8 {d0-d1}, [%[wei_ptr]] \n" + : + : [wei_ptr] "r"(wei_ptr) + : "memory"); + asm volatile( + // left + "pld [%[din_ptr0]] @ preload data\n" + "pld [%[din_ptr1]] @ preload data\n" + "pld [%[din_ptr2]] @ preload data\n" + "vdup.s8 d2, d0[0] @ d2 = w00, w00, w00, w00\n" + "vdup.s8 d3, d0[1] @ d3 = w01, w01, w01, w01\n" + "vdup.s8 d4, d0[2] @ d4 = w02, w02, w02, w02\n" + "vld2.8 {d12-d13}, [%[din_ptr0]] @ load din00= 0 2 4 6 8\n" // d10 = 0 2 4 6 + "vld2.8 {d14-d15}, [%[din_ptr1]] @ load din00= 0 2 4 6 8\n" // d12 = 0 2 4 6 + "vld2.8 {d16-d17}, [%[din_ptr2]] @ load din00= 0 2 4 6 8\n" // d14 = 0 2 4 6 + "vmov.u32 d11, #0 @ zero\n" + + "vdup.s8 d5, d0[3] @ d2 = w00, w00, w00, w00\n" + "vdup.s8 d6, d0[4] @ d3 = w01, w01, w01, w01\n" + "vdup.s8 d7, d0[5] @ d4 = w02, w02, w02, w02\n" + + "vext.8 d18, d11, d13, #7 @ ext \n" // d16 = -1 1 3 5 + "vext.8 d19, d11, d15, #7 @ ext \n" // d17 = -1 1 3 5 + "vext.8 d20, d11, d17, #7 @ ext \n" // d18 = -1 1 3 5 + + // r0 + "vmull.s8 q13, d12, d3 @ out0 = din0 * w01 \n" // q12 = d12 * w01 + "vmull.s8 q14, d13, d4 @ out1 = din0 * w02 \n" // q12 = d12 * w02 + "vmull.s8 q15, d18, d2 @ out2 = din0 * w00 \n" // q12 = d12 * w02 + + "vdup.s8 d8, d0[6] @ d2 = w00, w00, w00, w00\n" + "vdup.s8 d9, d0[7] @ d3 = w01, w01, w01, w01\n" + "vdup.s8 d10, d1[0] @ d4 = w02, w02, w02, w02\n" + + // r1 + "vmlal.s8 q13, d14, d6 @ out0 += din1 * w11 \n" // q12 = d12 * w11 + "vmlal.s8 q14, d15, d7 @ out1 += din1 * w12 \n" // q12 = d12 * w11 + "vmlal.s8 q15, d19, d5 @ out2 += din1 * w10 \n" // q12 = d12 * w11 + + // out0 + "vdup.32 q11, %[bias] @ and \n" // q8 = + // vbias + 
"vdup.32 q12, %[bias] @ and \n" // q9 = + // vbias + + // r2 + "vmlal.s8 q13, d16, d9 @ out0 += din1 * w21 \n" // q12 = d12 * w11 + "vmlal.s8 q14, d17, d10 @ out1 += din1 * w22 \n" // q12 = d12 * w11 + "vmlal.s8 q15, d20, d8 @ out2 += din1 * w20 \n" // q12 = d12 * w11 + + "add %[din_ptr0], #15 @add \n" + + "vaddw.s16 q11, q11, d26 @addw \n" // out1 += + // vget_low_s16(out10) + "vaddw.s16 q12, q12, d27 @addw \n" // out1_1 += + // vget_high_s16(out10) + "vmov.u32 q8, #0 @ max \n" // max + "add %[din_ptr1], #15 @add \n" + + "vaddw.s16 q11, q11, d28 @addw \n" // out1 += + // vget_low_s16(out10) + "vaddw.s16 q12, q12, d29 @addw \n" // out1_1 += + // vget_high_s16(out10) + "add %[din_ptr2], #15 @add \n" + + "vaddw.s16 q11, q11, d30 @addw \n" // out1 += + // vget_low_s16(out10) + "vaddw.s16 q12, q12, d31 @addw \n" // out1_1 += + // vget_high_s16(out10) + + "pld [%[din_ptr0]] @ preload data\n" + "pld [%[din_ptr1]] @ preload data\n" + "pld [%[din_ptr2]] @ preload data\n" + + "vmax.s32 q11, q11, q8 @ max\n" + "vmax.s32 q12, q12, q8 @ max\n" + + "vst1.32 {d22-d23}, [%[dout_ptr1]]! @ store\n" + "cmp %[cnt], #1 \n" + "vst1.32 {d24-d25}, [%[dout_ptr1]]! @ store\n" + "blt 1f \n" + + // mid + "2: \n" + "vld2.8 {d12-d13}, [%[din_ptr0]]! @ load din00= 0 2 4 6 8\n" // d10 = 0 2 4 6 + "vld2.8 {d14-d15}, [%[din_ptr1]]! @ load din00= 0 2 4 6 8\n" // d12 = 0 2 4 6 + "vld2.8 {d16-d17}, [%[din_ptr2]]! @ load din00= 0 2 4 6 8\n" // d14 = 0 2 4 6 + + "vld1.8 {d21}, [%[din_ptr0]] @ load din00= 16 17\n" // d10 = 0 2 + // 4 6 + "vld1.8 {d22}, [%[din_ptr1]] @ load din00= 16 17\n" // d12 = 0 2 + // 4 6 + "vld1.8 {d23}, [%[din_ptr2]] @ load din00= 16 17\n" // d14 = 0 2 + // 4 6 + + "vext.8 d18, d12, d21, #1 @ ext din00 = 2 4 6 8\n" // d16 = 2 + // 4 6 8 + "vext.8 d19, d14, d22, #1 @ ext \n" // d17 = 2 4 6 8 + "vext.8 d20, d16, d23, #1 @ ext \n" // d18 = 2 4 6 8 + + // r0 + "vmull.s8 q13, d12, d2 @ out0 = din0 * w00 \n" // q12 = 0 2 4 6 + "vmull.s8 q14, d13, d3 @ out1 = din0 * w01 \n" // q12 = 1 3 5 7 + "vmull.s8 q15, d18, d4 @ out2 = din0 * w02 \n" // q12 = 2 4 6 8 + + // out0 + "vdup.32 q11, %[bias] @ and \n" // q8 = + // vbias + "vdup.32 q12, %[bias] @ and \n" // q9 = + // vbias + + // r1 + "vmlal.s8 q13, d14, d5 @ out0 += din1 * w10 \n" // q12 = 0 2 4 6 + "vmlal.s8 q14, d15, d6 @ out1 += din1 * w11 \n" // q12 = 1 3 5 7 + "vmlal.s8 q15, d19, d7 @ out2 += din1 * w12 \n" // q12 = 2 4 6 8 + + // r2 + "vmlal.s8 q13, d16, d8 @ out0 += din1 * w20 \n" // q12 = 0 2 4 6 + "vmlal.s8 q14, d17, d9 @ out1 += din1 * w21 \n" // q12 = 1 3 5 7 + "vmlal.s8 q15, d20, d10 @ out2 += din1 * w22 \n" // q12 = 2 4 6 8 + + // "add %[din_ptr0], #16 @add \n" + + "vaddw.s16 q11, q11, d26 @addw \n" // out1 += + // vget_low_s16(out10) + "vaddw.s16 q12, q12, d27 @addw \n" // out1_1 += + // vget_high_s16(out10) + // "add %[din_ptr1], #16 @add \n" + "vmov.u32 q8, #0 @ mov \n" + + "vaddw.s16 q11, q11, d28 @addw \n" // out1 += + // vget_low_s16(out10) + "vaddw.s16 q12, q12, d29 @addw \n" // out1_1 += + // vget_high_s16(out10) + // "add %[din_ptr2], #16 @add \n" + + "vaddw.s16 q11, q11, d30 @addw \n" // out1 += + // vget_low_s16(out10) + "vaddw.s16 q12, q12, d31 @addw \n" // out1_1 += + // vget_high_s16(out10) + + "pld [%[din_ptr0]] @ preload data\n" + "pld [%[din_ptr1]] @ preload data\n" + "pld [%[din_ptr2]] @ preload data\n" + + "vmax.s32 q11, q11, q8 @ max\n" + "vmax.s32 q12, q12, q8 @ max\n" + + "vst1.32 {d22-d23}, [%[dout_ptr1]]! @ store\n" + + "subs %[cnt], #1 \n" + "vst1.32 {d24-d25}, [%[dout_ptr1]]! 
@ store\n" + "bne 2b \n" + // right + "1: \n" + "cmp %[size_pad_right], #1 \n" + "blt 3f \n" + "vld2.8 {d12-d13}, [%[din_ptr0]]! @ load din00= 0 2 4 6 8\n" // d10 = 0 2 4 6 + "vld2.8 {d14-d15}, [%[din_ptr1]]! @ load din00= 0 2 4 6 8\n" // d12 = 0 2 4 6 + "vld2.8 {d16-d17}, [%[din_ptr2]]! @ load din00= 0 2 4 6 8\n" // d14 = 0 2 4 6 + "vld1.8 {d28-d29}, [%[mask]] @ load din00= 0 1 2 3 4 5 6 7 " + "8 9\n" + + // out0 + "vdup.32 q11, %[bias] @ and \n" // q8 = vbias + "vdup.32 q12, %[bias] @ and \n" // q9 = vbias + + "vbif.8 d12, d11, d28 @ bit select, deal with right pad\n" + "vbif.8 d13, d11, d29 @ bit select, deal with right pad\n" + + "vbif.8 d14, d11, d28 @ bit select, deal with right pad\n" + "vbif.8 d15, d11, d29 @ bit select, deal with right pad\n" + + "vbif.8 d16, d11, d28 @ bit select, deal with right pad\n" + "vbif.8 d17, d11, d29 @ bit select, deal with right pad\n" + + "vext.8 d18, d12, d11, #1 @ ext din00 = 2 4 6 8\n" // d16 = -1 + // 1 3 5 + "vext.8 d19, d14, d11, #1 @ ext \n" // d17 = -1 1 3 5 + "vext.8 d20, d16, d11, #1 @ ext \n" // d18 = -1 1 3 5 + + // r0 + "vmull.s8 q13, d12, d2 @ out0 = din0 * w00 \n" // q12 = 0 2 4 6 + "vmull.s8 q14, d13, d3 @ out1 = din0 * w01 \n" // q12 = 1 3 5 7 + "vmull.s8 q15, d18, d4 @ out2 = din0 * w02 \n" // q12 = 2 4 6 8 + + // r1 + "vmlal.s8 q13, d14, d5 @ out0 += din1 * w11 \n" // q12 = 0 2 4 6 + "vmlal.s8 q14, d15, d6 @ out1 += din1 * w12 \n" // q12 = 1 3 5 7 + "vmlal.s8 q15, d19, d7 @ out2 += din1 * w10 \n" // q12 = 2 4 6 8 + + "vld1.32 {d12-d13}, [%[dout_ptr1]]! @ load din00= 0 1 2 3 4 5 6 " + "7 8 9\n" + "vld1.32 {d14-d15}, [%[dout_ptr1]] @ load din00= 0 1 2 3 4 5 6 " + "7 8 9\n" + + // r2 + "vmlal.s8 q13, d16, d8 @ out0 += din1 * w11 \n" // q12 = 0 2 4 6 + "vmlal.s8 q14, d17, d9 @ out1 += din1 * w12 \n" // q12 = 1 3 5 7 + "vmlal.s8 q15, d20, d10 @ out2 += din1 * w10 \n" // q12 = 2 4 6 8 + + "vld1.32 {d2-d3}, [%[rs_mask]]! @ load din00= 0 1 2 3 4 5 6 7 " + "8 9\n" + "vld1.32 {d4-d5}, [%[rs_mask]] @ load din00= 0 1 2 3 4 5 6 7 8 " + "9\n" + + "vaddw.s16 q11, q11, d26 @addw \n" // out1 += + // vget_low_s16(out10) + "vaddw.s16 q12, q12, d27 @addw \n" // out1_1 += + // vget_high_s16(out10) + + "sub %[dout_ptr1], #16 @ sub \n" + "vmov.u32 q8, #0 @mov \n" + "vaddw.s16 q11, q11, d28 @addw \n" // out1 += + // vget_low_s16(out10) + "vaddw.s16 q12, q12, d29 @addw \n" // out1_1 += + // vget_high_s16(out10) + + "vaddw.s16 q11, q11, d30 @addw \n" // out1 += + // vget_low_s16(out10) + "vaddw.s16 q12, q12, d31 @addw \n" // out1_1 += + // vget_high_s16(out10) + + "vmax.s32 q11, q11, q8 @ max\n" + "vmax.s32 q12, q12, q8 @ max\n" + + "vbif q11, q6, q1 @ bit select, deal with right pad\n" + "vbif q12, q7, q2 @ bit select, deal with right pad\n" + + "vst1.32 {d22-d23}, [%[dout_ptr1]]! @ store\n" + "vst1.32 {d24-d25}, [%[dout_ptr1]]! 
@ store\n"
+          "3:                          \n"
+
+          : [din_ptr0] "+r"(din_ptr0),
+            [din_ptr1] "+r"(din_ptr1),
+            [din_ptr2] "+r"(din_ptr2),
+            [dout_ptr1] "+r"(doutr0),
+            [cnt] "+r"(cnt),
+            [bias] "+r"(bias_val),
+            [rs_mask] "+r"(rst_mask)
+          : [mask] "r"(vmask), [size_pad_right] "r"(size_pad_right)
+          : "cc",
+            "memory",
+            "q0",
+            "q1",
+            "q2",
+            "q3",
+            "q4",
+            "q5",
+            "q6",
+            "q7",
+            "q8",
+            "q9",
+            "q10",
+            "q11",
+            "q12",
+            "q13",
+            "q14",
+            "q15");
+#endif
+        dout_ptr += w_out;
+      }
+    }
+  }
+}
+// w_in <= 16
+void conv_depthwise_3x3s2p1_bias_s_relu_int7(int* dout,
+                                             const signed char* din,
+                                             const signed char* weights,
+                                             const int* bias,
+                                             bool flag_bias,
+                                             const int num,
+                                             const int ch_in,
+                                             const int h_in,
+                                             const int w_in,
+                                             const int h_out,
+                                             const int w_out,
+                                             ARMContext* ctx) {
+  // printf("3x3s2 mult height \n");
+  //! pad is done implicit
+  // const char zero[8] = {0, 0, 0, 0, 0, 0, 0, 0};
+  //! for 4x6 convolution window
+  const unsigned char right_pad_idx[16] = {
+      0, 2, 4, 6, 8, 10, 12, 14, 1, 3, 5, 7, 9, 11, 13, 15};
+  const unsigned int right_pad_rst[8] = {0, 1, 2, 3, 4, 5, 6, 7};
+
+  // printf("conv3x3_dw start \n");
+  signed char* zero_ptr = ctx->workspace_data<signed char>();
+  memset(zero_ptr, 0, w_in * sizeof(signed char));
+  int* write_ptr =
+      reinterpret_cast<int*>(ctx->workspace_data<signed char>()) + w_out;
+  int size_in_channel = w_in * h_in;
+  int size_out_channel = w_out * h_out;
+  int w_stride = 9;
+
+  unsigned int size_pad_right = (unsigned int)(w_in);
+
+  uint8x8_t vmask_rp1 =
+      vcgt_u8(vdup_n_u8(size_pad_right), vld1_u8(right_pad_idx));
+  uint8x8_t vmask_rp2 =
+      vcgt_u8(vdup_n_u8(size_pad_right), vld1_u8(right_pad_idx + 8));
+  unsigned int rst_remain = (unsigned int)w_out;
+  uint32x4_t vmask_result1 =
+      vcgtq_u32(vdupq_n_u32(rst_remain), vld1q_u32(right_pad_rst));
+  uint32x4_t vmask_result2 =
+      vcgtq_u32(vdupq_n_u32(rst_remain), vld1q_u32(right_pad_rst + 4));
+
+  uint8x16_t vmask_rp =
+      vcgtq_u8(vdupq_n_u8(size_pad_right), vld1q_u8(right_pad_idx));
+  unsigned char vmask[16];
+  vst1q_u8(vmask, vmask_rp);
+
+  unsigned int rmask[8];
+  vst1q_u32(rmask, vmask_result1);
+  vst1q_u32(rmask + 4, vmask_result2);
+  int8x8_t vzero = vdup_n_s8(0);
+  int32x4_t vzero_32 = vdupq_n_s32(0);
+
+  for (int n = 0; n < num; ++n) {
+    const signed char* din_batch = din + n * ch_in * size_in_channel;
+    int* dout_batch = dout + n * ch_in * size_out_channel;
+#pragma omp parallel for
+    for (int c = 0; c < ch_in; c++) {
+      int* dout_ptr = dout_batch + c * size_out_channel;
+
+      const signed char* din_ch_ptr = din_batch + c * size_in_channel;
+
+      int bias_val = flag_bias ? bias[c] : 0;
+
+      const signed char* wei_ptr = weights + c * w_stride;
+
+#ifdef __aarch64__
+      int vbias[4] = {bias_val, bias_val, bias_val, bias_val};
+      int8x8_t wr00 = vdup_n_s8(wei_ptr[0]);
+      int8x8_t wr10 = vdup_n_s8(wei_ptr[3]);
+      int8x8_t wr20 = vdup_n_s8(wei_ptr[6]);
+
+      int8x8_t wr01 = vdup_n_s8(wei_ptr[1]);
+      int8x8_t wr11 = vdup_n_s8(wei_ptr[4]);
+      int8x8_t wr21 = vdup_n_s8(wei_ptr[7]);
+
+      int8x8_t wr02 = vdup_n_s8(wei_ptr[2]);
+      int8x8_t wr12 = vdup_n_s8(wei_ptr[5]);
+      int8x8_t wr22 = vdup_n_s8(wei_ptr[8]);
+#endif
+
+      int* doutr0 = nullptr;
+
+      const signed char* dr0 = din_ch_ptr;
+      const signed char* dr1 = dr0 + w_in;
+      const signed char* dr2 = dr1 + w_in;
+
+      const signed char* din_ptr0 = nullptr;
+      const signed char* din_ptr1 = nullptr;
+      const signed char* din_ptr2 = nullptr;
+
+      for (int i = 0; i < h_in; i += 2) {
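+        // Note: small-width path (w_in <= 16): the whole row fits in one
+        // masked block, so vmask zeroes the lanes beyond w_in before the
+        // multiplies, the result lands in out_buf1, and only the first
+        // w_out values are copied to the output row after the asm.
+        //!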
process top pad pad_h = 1 + din_ptr0 = dr0; + din_ptr1 = dr1; + din_ptr2 = dr2; + + doutr0 = dout_ptr; + + int out_buf1[8]; + if (i == 0) { + din_ptr0 = zero_ptr; + din_ptr1 = dr0; + din_ptr2 = dr1; + dr0 = dr1; + dr1 = dr2; + dr2 = dr1 + w_in; + } else { + dr0 = dr2; + dr1 = dr2 + w_in; + dr2 = dr1 + w_in; + } + //! process bottom pad + if (i + 2 > h_in) { + switch (i + 2 - h_in) { + case 2: + din_ptr1 = zero_ptr; + case 1: + din_ptr2 = zero_ptr; + default: + break; + } + } +#ifdef __aarch64__ + unsigned int* rst_mask = rmask; + unsigned char* val_mask = vmask; + asm volatile( + "PRFM PLDL1KEEP, [%[din_ptr0]] \n" + "PRFM PLDL1KEEP, [%[din_ptr1]] \n" + "PRFM PLDL1KEEP, [%[din_ptr2]] \n" + "movi v16.4s, #0x0\n" + // left + "ld1 {v10.8b}, [%[vmask]], #8 \n" + "ld1 {v11.8b}, [%[vmask]] \n" + "ld2 {v0.8b - v1.8b}, [%[din_ptr0]] \n" /*load a00-a015 + to q0*/ + "ld2 {v2.8b - v3.8b}, [%[din_ptr1]] \n" /* load a00-a015 + to q0*/ + "ld2 {v4.8b - v5.8b}, [%[din_ptr2]] \n" /*load a00-a015 + to q0*/ + + "bif v0.8b, v16.8b, v10.8b \n" + "bif v1.8b, v16.8b, v11.8b \n" + "bif v2.8b, v16.8b, v10.8b \n" + "bif v3.8b, v16.8b, v11.8b \n" + "bif v4.8b, v16.8b, v10.8b \n" + "bif v5.8b, v16.8b, v11.8b \n" + + "ld1 {v12.4s}, [%[bias_val]] \n" /* dup v10, bias*/ + "ld1 {v13.4s}, [%[bias_val]] \n" /* dup v10, bias */ + + "ext v6.8b, v16.8b, v1.8B, #7 \n" /* vext_s8(vzero, vinr0, 7); + 013579 */ + "ext v7.8b, v16.8b, v3.8B, #7 \n" /* vext_s8(vzero, vinr0, 7); + 013579 */ + "ext v8.8b, v16.8b, v5.8B, #7 \n" /* vext_s8(vzero, vinr0, 7); + 013579 */ + + // r0 + "smull v17.8h, %[v1].8b, v0.8b \n" /* outr00 = 02468 * w01 */ + "smull v18.8h, %[v2].8b, v1.8b\n" /* outr00 += 13579 * w02 */ + "smull v19.8h, %[v0].8b, v6.8b\n" /* outr00 += 013579 * w00 */ + + // "ldp q0, q1, [%[ptr_out0]] \n" /* dup v10, + // bias */ "ldp q10, q11, [%[rst_mask]] \n" /* + // dup v10, bias */ + + // r1 + "smlal v17.8h, %[v4].8b, v2.8b \n" /* outr00 = 02468 * w01 */ + "smlal v18.8h, %[v5].8b, v3.8b\n" /* outr00 += 13579 * w02 */ + "smlal v19.8h, %[v3].8b, v7.8b\n" /* outr00 += 013579 * w00 */ + + // r2 + "smlal v17.8h, %[v7].8b, v4.8b \n" /* outr00 = 02468 * w01 */ + "smlal v18.8h, %[v8].8b, v5.8b\n" /* outr00 += 13579 * w02 */ + "smlal v19.8h, %[v6].8b, v8.8b\n" /* outr00 += 013579 * w00 */ + + "saddw v12.4s, v12.4s, v17.4h \n" /* v10 += outr00.low*/ + "saddw2 v13.4s, v13.4s, v17.8h \n" /* v11 += outr00.high*/ + + "saddw v12.4s, v12.4s, v18.4h \n" /* v10 += outr00.low*/ + "saddw2 v13.4s, v13.4s, v18.8h \n" /* v11 += outr00.high*/ + + "saddw v12.4s, v12.4s, v19.4h \n" /* v10 += outr00.low*/ + "saddw2 v13.4s, v13.4s, v19.8h \n" /* v11 += outr00.high*/ + + "smax v12.4s, v12.4s, v16.4s \n" /*relu*/ + "smax v13.4s, v13.4s, v16.4s \n" /*relu*/ + + // "bif v12.16b, v0.16b, v10.16b \n" + // "bif v13.16b, v1.16b, v11.16b \n" + + "stp q12, q13, [%[ptr_out0]] \n" /* store q10, q11 -> ptr_out + */ + : [din_ptr0] "+r"(din_ptr0), + [din_ptr1] "+r"(din_ptr1), + [din_ptr2] "+r"(din_ptr2), + [vmask] "+r"(val_mask) + : [v0] "w"(wr00), + [v1] "w"(wr01), + [v2] "w"(wr02), + [v3] "w"(wr10), + [bias_val] "r"(vbias), + [v4] "w"(wr11), + [v5] "w"(wr12), + [v6] "w"(wr20), + [v7] "w"(wr21), + [v8] "w"(wr22), + [rst_mask] "r"(rmask), + [ptr_out0] "r"(out_buf1) + : "cc", + "memory", + "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15", + "v16", + "v17", + "v18", + "v19", + "v20"); +#else + unsigned int* rst_mask = rmask; + // prefetch input + // store weights + asm volatile("vld1.8 {d0-d1}, 
[%[wei_ptr]] \n" + : + : [wei_ptr] "r"(wei_ptr) + : "memory"); + asm volatile( + // left + "pld [%[din_ptr0]] @ preload data\n" + "pld [%[din_ptr1]] @ preload data\n" + "pld [%[din_ptr2]] @ preload data\n" + "vdup.s8 d2, d0[0] @ d2 = w00, w00, w00, w00\n" + "vdup.s8 d3, d0[1] @ d3 = w01, w01, w01, w01\n" + "vdup.s8 d4, d0[2] @ d4 = w02, w02, w02, w02\n" + "vld2.8 {d12-d13}, [%[din_ptr0]] @ load din00= 0 2 4 6 8\n" // d10 = 0 2 4 6 + "vld2.8 {d14-d15}, [%[din_ptr1]] @ load din00= 0 2 4 6 8\n" // d12 = 0 2 4 6 + "vld2.8 {d16-d17}, [%[din_ptr2]] @ load din00= 0 2 4 6 8\n" // d14 = 0 2 4 6 + "vld1.8 {d28-d29}, [%[mask]] @ load din00= 0 1 2 3 4 5 6 7 " + "8 9\n" + "vmov.u32 d11, #0 @ zero\n" + + "vdup.s8 d5, d0[3] @ d2 = w00, w00, w00, w00\n" + "vdup.s8 d6, d0[4] @ d3 = w01, w01, w01, w01\n" + "vdup.s8 d7, d0[5] @ d4 = w02, w02, w02, w02\n" + + "vbif.8 d12, d11, d28 @ bit select, deal with right pad\n" + "vbif.8 d13, d11, d29 @ bit select, deal with right pad\n" + + "vbif.8 d14, d11, d28 @ bit select, deal with right pad\n" + "vbif.8 d15, d11, d29 @ bit select, deal with right pad\n" + + "vbif.8 d16, d11, d28 @ bit select, deal with right pad\n" + "vbif.8 d17, d11, d29 @ bit select, deal with right pad\n" + + "vext.8 d18, d11, d13, #7 @ ext \n" // d16 = -1 1 3 5 + "vext.8 d19, d11, d15, #7 @ ext \n" // d17 = -1 1 3 5 + "vext.8 d20, d11, d17, #7 @ ext \n" // d18 = -1 1 3 5 + + // "pld [%[dout_ptr1]] @ preload data\n" + + // r0 + "vmull.s8 q13, d12, d3 @ out0 = din0 * w01 \n" // q12 = d12 * w01 + "vmull.s8 q14, d13, d4 @ out1 = din0 * w02 \n" // q12 = d12 * w02 + "vmull.s8 q15, d18, d2 @ out2 = din0 * w00 \n" // q12 = d12 * w02 + + "vdup.s8 d8, d0[6] @ d2 = w00, w00, w00, w00\n" + "vdup.s8 d9, d0[7] @ d3 = w01, w01, w01, w01\n" + "vdup.s8 d10, d1[0] @ d4 = w02, w02, w02, w02\n" + + // r1 + "vmlal.s8 q13, d14, d6 @ out0 += din1 * w11 \n" // q12 = d12 * w11 + "vmlal.s8 q14, d15, d7 @ out1 += din1 * w12 \n" // q12 = d12 * w11 + "vmlal.s8 q15, d19, d5 @ out2 += din1 * w10 \n" // q12 = d12 * w11 + + // "vld1.32 {d12-d13}, [%[dout_ptr1]]! @ load din00= 0 1 2 3 4 5 + // 6 7 8 9\n" "vld1.32 {d14-d15}, [%[dout_ptr1]] @ load din00= 0 + // 1 2 3 4 5 6 7 8 9\n" + + // out0 + "vdup.32 q11, %[bias] @ and \n" // q8 = + // vbias + "vdup.32 q12, %[bias] @ and \n" // q9 = + // vbias + + // r2 + "vmlal.s8 q13, d16, d9 @ out0 += din1 * w21 \n" // q12 = d12 * w11 + "vmlal.s8 q14, d17, d10 @ out1 += din1 * w22 \n" // q12 = d12 * w11 + "vmlal.s8 q15, d20, d8 @ out2 += din1 * w20 \n" // q12 = d12 * w11 + + // "vld1.32 {d2-d3}, [%[rs_mask]]! @ load din00= 0 1 2 3 4 5 6 7 + // 8 9\n" "vld1.32 {d4-d5}, [%[rs_mask]] @ load din00= 0 1 2 3 4 + // 5 6 7 8 9\n" + + // "sub %[dout_ptr1], #16 @ sub \n" + + "vaddw.s16 q11, q11, d26 @addw \n" // out1 += + // vget_low_s16(out10) + "vaddw.s16 q12, q12, d27 @addw \n" // out1_1 += + // vget_high_s16(out10) + "vmov.u32 q8, #0 @ mov \n" + + "vaddw.s16 q11, q11, d28 @addw \n" // out1 += + // vget_low_s16(out10) + "vaddw.s16 q12, q12, d29 @addw \n" // out1_1 += + // vget_high_s16(out10) + + "vaddw.s16 q11, q11, d30 @addw \n" // out1 += + // vget_low_s16(out10) + "vaddw.s16 q12, q12, d31 @addw \n" // out1_1 += + // vget_high_s16(out10) + + "vmax.s32 q11, q11, q8 @ max\n" + "vmax.s32 q12, q12, q8 @ max\n" + + // "vbif q11, q6, q1 @ bit select, deal with right pad\n" + // "vbif q12, q7, q2 @ bit select, deal with right pad\n" + + "vst1.32 {d22-d25}, [%[dout_ptr1]] @ store\n" + // "vst1.32 {d24-d25}, [%[dout_ptr1]]! 
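The "vext.8 d18, d11, d13, #7" pattern above implements the implicit left pad: d11 is the zero register, so shifting it in from the left yields the window {pad, c1, c3, ...} without ever materializing a padded copy of the input row. A sketch of the same trick with intrinsics (left_pad_window is a hypothetical name):

#include <arm_neon.h>

// Build {0, c1, c3, ..., c13} from the odd-column half of an ld2 pair; the
// pad column enters through the zero register.
static int8x8_t left_pad_window(int8x8_t odd_cols) {
  int8x8_t zero = vdup_n_s8(0);
  return vext_s8(zero, odd_cols, 7);  // take zero[7], then odd_cols[0..6]
}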
@ store\n" + : [din_ptr0] "+r"(din_ptr0), + [din_ptr1] "+r"(din_ptr1), + [din_ptr2] "+r"(din_ptr2), + [bias] "+r"(bias_val), + [rs_mask] "+r"(rst_mask) + : [mask] "r"(vmask), + [size_pad_right] "r"(size_pad_right), + [dout_ptr1] "r"(out_buf1) + : "cc", + "memory", + "q0", + "q1", + "q2", + "q3", + "q4", + "q5", + "q6", + "q7", + "q8", + "q9", + "q10", + "q11", + "q12", + "q13", + "q14", + "q15"); +#endif + for (int w = 0; w < w_out; ++w) { + *doutr0++ = out_buf1[w]; + } + dout_ptr += w_out; + } + } + } +} + +} // namespace math +} // namespace arm +} // namespace lite +} // namespace paddle diff --git a/lite/arm/math/conv_depthwise_3x3_int8.cc b/lite/arm/math/conv_depthwise_3x3_int8.cc new file mode 100644 index 00000000000..f8b14359a61 --- /dev/null +++ b/lite/arm/math/conv_depthwise_3x3_int8.cc @@ -0,0 +1,5832 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include "lite/arm/math/conv_impl.h" +#include "lite/core/context.h" +#include "lite/operators/op_params.h" + +namespace paddle { +namespace lite { +namespace arm { +namespace math { + +void conv_depthwise_3x3s1p1_bias_int8(int* dout, + const signed char* din, + const signed char* weights, + const int* bias, + bool flag_bias, + const int num, + const int ch_in, + const int h_in, + const int w_in, + const int h_out, + const int w_out, + ARMContext* ctx); + +//! for input width <= 8 +void conv_depthwise_3x3s1p1_bias_s_int8(int* dout, + const signed char* din, + const signed char* weights, + const int* bias, + bool flag_bias, + const int num, + const int ch_in, + const int h_in, + const int w_in, + const int h_out, + const int w_out, + ARMContext* ctx); + +void conv_depthwise_3x3s2p1_bias_int8(int* dout, + const signed char* din, + const signed char* weights, + const int* bias, + bool flag_bias, + const int num, + const int ch_in, + const int h_in, + const int w_in, + const int h_out, + const int w_out, + ARMContext* ctx); + +//! for input width <= 8 +void conv_depthwise_3x3s2p1_bias_s_int8(int* dout, + const signed char* din, + const signed char* weights, + const int* bias, + bool flag_bias, + const int num, + const int ch_in, + const int h_in, + const int w_in, + const int h_out, + const int w_out, + ARMContext* ctx); + +void conv_depthwise_3x3s1p1_bias_relu_int8(int* dout, + const signed char* din, + const signed char* weights, + const int* bias, + bool flag_bias, + const int num, + const int ch_in, + const int h_in, + const int w_in, + const int h_out, + const int w_out, + ARMContext* ctx); + +//! 
for input width <= 4 +void conv_depthwise_3x3s1p1_bias_s_relu_int8(int* dout, + const signed char* din, + const signed char* weights, + const int* bias, + bool flag_bias, + const int num, + const int ch_in, + const int h_in, + const int w_in, + const int h_out, + const int w_out, + ARMContext* ctx); + +void conv_depthwise_3x3s2p1_bias_relu_int8(int* dout, + const signed char* din, + const signed char* weights, + const int* bias, + bool flag_bias, + const int num, + const int ch_in, + const int h_in, + const int w_in, + const int h_out, + const int w_out, + ARMContext* ctx); + +//! for input width <= 4 +void conv_depthwise_3x3s2p1_bias_s_relu_int8(int* dout, + const signed char* din, + const signed char* weights, + const int* bias, + bool flag_bias, + const int num, + const int ch_in, + const int h_in, + const int w_in, + const int h_out, + const int w_out, + ARMContext* ctx); + +void conv_depthwise_3x3_int8(const int8_t* din, + int32_t* dout, + int num, + int chout, + int hout, + int wout, + int chin, + int hin, + int win, + const int8_t* weights, + const int32_t* bias, + const operators::ConvParam& param, + ARMContext* ctx, + PrecisionType out_type, + const float* scale) { + int w_in = win; + int h_in = hin; + int ch_in = chin; + + int w_out = wout; + int h_out = hout; + int ch_out = chout; + int stride_h = param.strides[0]; + bool flag_relu = param.fuse_relu; + bool flag_bias = param.bias != nullptr; + // if (param.activation_param.has_active){ + // if (param.activation_param.active == Active_relu || + // fabs(param.activation_param.negative_slope) > 1e-6f){ + // flag_relu = true; + // } + // } + //! only support stride = 1 or 2 + if (stride_h == 1) { + if (flag_relu) { + if (w_in > 8) { + conv_depthwise_3x3s1p1_bias_relu_int8(dout, + din, + weights, + bias, + flag_bias, + num, + ch_in, + h_in, + w_in, + h_out, + w_out, + ctx); + } else { + conv_depthwise_3x3s1p1_bias_s_relu_int8(dout, + din, + weights, + bias, + flag_bias, + num, + ch_in, + h_in, + w_in, + h_out, + w_out, + ctx); + } + } else { + if (w_in > 8) { + conv_depthwise_3x3s1p1_bias_int8(dout, + din, + weights, + bias, + flag_bias, + num, + ch_in, + h_in, + w_in, + h_out, + w_out, + ctx); + } else { + conv_depthwise_3x3s1p1_bias_s_int8(dout, + din, + weights, + bias, + flag_bias, + num, + ch_in, + h_in, + w_in, + h_out, + w_out, + ctx); + } + } + } else { //! stride = 2 + if (flag_relu) { + if (w_in > 16) { + conv_depthwise_3x3s2p1_bias_relu_int8(dout, + din, + weights, + bias, + flag_bias, + num, + ch_in, + h_in, + w_in, + h_out, + w_out, + ctx); + } else { + conv_depthwise_3x3s2p1_bias_s_relu_int8(dout, + din, + weights, + bias, + flag_bias, + num, + ch_in, + h_in, + w_in, + h_out, + w_out, + ctx); + } + } else { + if (w_in > 16) { + conv_depthwise_3x3s2p1_bias_int8(dout, + din, + weights, + bias, + flag_bias, + num, + ch_in, + h_in, + w_in, + h_out, + w_out, + ctx); + } else { + conv_depthwise_3x3s2p1_bias_s_int8(dout, + din, + weights, + bias, + flag_bias, + num, + ch_in, + h_in, + w_in, + h_out, + w_out, + ctx); + } + } + } +} +/** + * \brief depthwise convolution, kernel size 3x3, stride 1, pad 1, with bias, + * width > 4 + */ + +// 4line w_in > 8 +void conv_depthwise_3x3s1p1_bias_int8(int* dout, + const signed char* din, + const signed char* weights, + const int* bias, + bool flag_bias, + const int num, + const int ch_in, + const int h_in, + const int w_in, + const int h_out, + const int w_out, + ARMContext* ctx) { + // printf("3x3s1 mult height \n"); + //! 
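Since the eight-way branch in conv_depthwise_3x3_int8 above is long, a compact restatement of the dispatch may help:

//   stride 1, w_in >  8 : conv_depthwise_3x3s1p1_bias[_relu]_int8
//   stride 1, w_in <= 8 : conv_depthwise_3x3s1p1_bias_s[_relu]_int8
//   stride 2, w_in > 16 : conv_depthwise_3x3s2p1_bias[_relu]_int8
//   stride 2, w_in <= 16: conv_depthwise_3x3s2p1_bias_s[_relu]_int8
//
// The "_s" kernels cover rows narrow enough that one masked 8-lane block
// spans the whole output row; only strides 1 and 2 (with pad-1 kernels) are
// handled here.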
pad is done implicit + const char zero[8] = {0, 0, 0, 0, 0, 0, 0, 0}; + const unsigned char right_pad_idx[16] = { + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}; + const unsigned int right_pad_rst[8] = {0, 1, 2, 3, 4, 5, 6, 7}; + + // printf("conv3x3_dw start \n"); + signed char* zero_ptr = ctx->workspace_data(); + memset(zero_ptr, 0, w_in * sizeof(signed char)); + int* write_ptr = + reinterpret_cast(ctx->workspace_data()) + w_in; + int size_in_channel = w_in * h_in; + int size_out_channel = w_out * h_out; + int w_stride = 9; + + int tile_w = (w_in + 7) >> 3; + int tile_h = (h_out + 1) >> 1; + int cnt_col = tile_w - 2; + + unsigned int size_pad_right = (unsigned int)(w_in - 7 - (cnt_col << 3)); + + int size_pad_bottom = h_out % 2; + + uint8x8_t vmask_rp1 = + vcgt_u8(vdup_n_u8(size_pad_right), vld1_u8(right_pad_idx)); + uint8x8_t vmask_rp2 = + vcgt_u8(vdup_n_u8(size_pad_right), vld1_u8(right_pad_idx + 8)); + + uint8x16_t vmask_rp = + vcgtq_u8(vdupq_n_u8(size_pad_right), vld1q_u8(right_pad_idx)); + // uint8x8_t vmask_rp2 = vcgt_u8(vdup_n_u8(size_pad_right), + // vld1_u8(right_pad_idx + 8)); + unsigned char vmask[16]; + vst1q_u8(vmask, vmask_rp); + + unsigned int rst_remain = (unsigned int)(w_out - ((cnt_col + 1) << 3)); + uint32x4_t vmask_result1 = + vcgtq_u32(vdupq_n_u32(rst_remain), vld1q_u32(right_pad_rst)); + uint32x4_t vmask_result2 = + vcgtq_u32(vdupq_n_u32(rst_remain), vld1q_u32(right_pad_rst + 4)); + + unsigned int rmask[8]; + vst1q_u32(rmask, vmask_result1); + vst1q_u32(rmask + 4, vmask_result2); + + int8x8_t vzero = vdup_n_s8(0); + int32x4_t vzero_32 = vdupq_n_s32(0); + + for (int n = 0; n < num; ++n) { + const signed char* din_batch = din + n * ch_in * size_in_channel; + int* dout_batch = dout + n * ch_in * size_out_channel; +#pragma omp parallel for + for (int c = 0; c < ch_in; c++) { + int* dout_ptr = dout_batch + c * size_out_channel; + + const signed char* din_ch_ptr = din_batch + c * size_in_channel; + + int bias_val = flag_bias ? bias[c] : 0; + + const signed char* wei_ptr = weights + c * w_stride; + +#ifdef __aarch64__ + int vbias[4] = {bias_val, bias_val, bias_val, bias_val}; + + int8x8_t wr00 = vdup_n_s8(wei_ptr[0]); + int8x8_t wr10 = vdup_n_s8(wei_ptr[3]); + int8x8_t wr20 = vdup_n_s8(wei_ptr[6]); + + int8x8_t wr01 = vdup_n_s8(wei_ptr[1]); + int8x8_t wr11 = vdup_n_s8(wei_ptr[4]); + int8x8_t wr21 = vdup_n_s8(wei_ptr[7]); + + int8x8_t wr02 = vdup_n_s8(wei_ptr[2]); + int8x8_t wr12 = vdup_n_s8(wei_ptr[5]); + int8x8_t wr22 = vdup_n_s8(wei_ptr[8]); +#endif + int* doutr0 = nullptr; + int* doutr1 = nullptr; + + const signed char* dr0 = din_ch_ptr; + const signed char* dr1 = dr0 + w_in; + const signed char* dr2 = dr1 + w_in; + const signed char* dr3 = dr2 + w_in; + + const signed char* din_ptr0 = nullptr; + const signed char* din_ptr1 = nullptr; + const signed char* din_ptr2 = nullptr; + const signed char* din_ptr3 = nullptr; + + for (int i = 0; i < h_in; i += 2) { + //! process top pad pad_h = 1 + din_ptr0 = dr0; + din_ptr1 = dr1; + din_ptr2 = dr2; + din_ptr3 = dr3; + + doutr0 = dout_ptr; + doutr1 = doutr0 + w_out; + unsigned int* rst_mask = rmask; + unsigned char* val_mask = vmask; + + if (i == 0) { + din_ptr0 = zero_ptr; + din_ptr1 = dr0; + din_ptr2 = dr1; + din_ptr3 = dr2; + dr0 = dr1; + dr1 = dr2; + dr2 = dr3; + dr3 = dr2 + w_in; + } else { + dr0 = dr2; + dr1 = dr3; + dr2 = dr1 + w_in; + dr3 = dr2 + w_in; + } + //! 
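The tiling arithmetic above determines how a row decomposes into blocks of 8 outputs: one left block that injects pad_w = 1 in-register (and so consumes only 7 input columns), cnt_col full middle blocks of 8 columns, and one masked right block fed by size_pad_right valid input bytes. The same computation as a standalone sketch, assuming w_in > 8 as the caller guarantees (S1Tiling/s1_tiling are hypothetical names):

struct S1Tiling {
  int cnt_col;
  unsigned int size_pad_right;
};

static S1Tiling s1_tiling(int w_in) {
  int tile_w = (w_in + 7) >> 3;  // 8 outputs per block at stride 1
  int cnt_col = tile_w - 2;      // full blocks besides "left" and "right"
  // valid input bytes feeding the final, masked block
  unsigned int size_pad_right = (unsigned int)(w_in - 7 - (cnt_col << 3));
  return {cnt_col, size_pad_right};
}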
process bottom pad + if (i + 3 > h_in) { + switch (i + 3 - h_in) { + case 3: + din_ptr1 = zero_ptr; + case 2: + din_ptr2 = zero_ptr; + case 1: + din_ptr3 = zero_ptr; + default: + break; + } + } + //! process bottom remain + if (i + 2 > h_out) { + doutr1 = write_ptr; + } + int cnt = cnt_col; +#ifdef __aarch64__ + asm volatile( + "PRFM PLDL1KEEP, [%[din_ptr0]] \n" + "PRFM PLDL1KEEP, [%[din_ptr1]] \n" + "PRFM PLDL1KEEP, [%[din_ptr2]] \n" + "PRFM PLDL1KEEP, [%[din_ptr3]] \n" + "movi v21.4s, #0x0\n" /* out0 = 0 */ + // left + "ld1 {v0.8b}, [%[din_ptr0]], #8 \n" /* load + a00-a015 + to + q0*/ + "ld1 {v2.8b}, [%[din_ptr1]], #8 \n" /* load + a00-a015 + to + q0*/ + "ld1 {v1.8b}, [%[din_ptr0]] \n" /* load + a00-a015 to + q0*/ + "ld1 {v3.8b}, [%[din_ptr1]] \n" /* load + a00-a015 to + q0*/ + + "ld1 {v10.4s}, [%[bias_val]] \n" /* dup v10, bias */ + "ld1 {v11.4s}, [%[bias_val]] \n" /* dup v10, bias */ + "ld1 {v12.4s}, [%[bias_val]] \n" /* dup v10, bias */ + "ld1 {v13.4s}, [%[bias_val]] \n" /* dup v10, bias */ + + // r0 + "smull v18.8h, %[v1].8b, v0.8b \n" /* outr00 = 01234567 * w01 + */ + + "ext v4.8b, v21.8b, v0.8B, #7 \n" /* vext_s8(vzero, vinr0, 7); + 00123456 */ + "ext v5.8b, v0.8b, v1.8B, #1 \n" /* vext_s8(vinr0, vinr0_1, + 1); 12345678 */ + + "ld1 {v6.8b}, [%[din_ptr2]], #8 \n" /* load + a00-a015 + to + q0*/ + "ld1 {v8.8b}, [%[din_ptr3]], #8 \n" /* load + a00-a015 + to + q0*/ + + "smlal v18.8h, %[v0].8b, v4.8b\n" /* outr00 += 00123456 * w00 */ + + "ld1 {v7.8b}, [%[din_ptr2]] \n" /* load + a00-a015 + to q0*/ + "ld1 {v9.8b}, [%[din_ptr3]] \n" /* load + a00-a015 + to q0*/ + + "sub %[din_ptr0], %[din_ptr0], #1 \n" + "sub %[din_ptr1], %[din_ptr1], #1 \n" + + "saddw v10.4s, v10.4s, v18.4h \n" /* v10 += outr00.low*/ + "saddw2 v11.4s, v11.4s, v18.8h \n" /* v11 += outr00.high*/ + + "smull v18.8h, %[v2].8b, v5.8b\n" /* outr00 += 12345678 * w02 */ + + "ext v4.8b, v21.8b, v2.8b, #7 \n" /* vext_s8(vzero, vinr0, 7); + 00123456 */ + "ext v5.8b, v2.8b, v3.8b, #1 \n" /* vext_s8(vinr0, vinr0_1, + 1); 12345678 */ + + // r1 + "sub %[din_ptr2], %[din_ptr2], #1 \n" + "sub %[din_ptr3], %[din_ptr3], #1 \n" + + "smull v19.8h, %[v1].8b, v2.8b \n" /* outr10 += 01234567 * w11 + */ + "smlal v18.8h, %[v4].8b, v2.8b \n" /* outr00 += 01234567 * w11 + */ + + "ext v14.8b, v21.8b, v6.8b, #7 \n" /* vext_s8(vzero, vinr0, + 7); 00123456 */ + "ext v15.8b, v6.8b, v7.8b, #1 \n" /* vext_s8(vinr0, vinr0_1, + 1); 12345678 */ + + "smlal v19.8h, %[v0].8b, v4.8b \n" /* outr00 += 01234567 * w11 + */ + + "saddw v10.4s, v10.4s, v18.4h \n" /* v10 += outr00.low*/ + "saddw2 v11.4s, v11.4s, v18.8h \n" /* v11 += outr00.high*/ + + "smull v18.8h, %[v3].8b, v4.8b \n" /* outr00 += 001234567 * w10 + */ + + "ld1 {v0.8b}, [%[din_ptr0]], #8 \n" /* load + a00-a015 + to + q0*/ + "ld1 {v2.8b}, [%[din_ptr1]], #8 \n" /* load + a00-a015 + to + q0*/ + + "saddw v12.4s, v12.4s, v19.4h \n" /* v10 += outr00.low*/ + "saddw2 v13.4s, v13.4s, v19.8h \n" /* v11 += outr00.high*/ + + "smull v19.8h, %[v2].8b, v5.8b \n" /* outr00 += 01234567 * w11 + */ + "smlal v18.8h, %[v5].8b, v5.8b \n" /* outr00 += 12345678 * w12 + */ + + // r2 + "ld1 {v1.8b}, [%[din_ptr0]] \n" /* load + a00-a015 to + q0*/ + "ld1 {v3.8b}, [%[din_ptr1]] \n" /* load + a00-a015 to + q0*/ + + "smlal v19.8h, %[v4].8b, v6.8b \n" /* outr10 += 01234567 * w11 + */ + + "saddw v10.4s, v10.4s, v18.4h \n" /* v10 += outr00.low*/ + "saddw2 v11.4s, v11.4s, v18.8h \n" /* v11 += outr00.high*/ + + "smull v18.8h, %[v7].8b, v6.8b \n" /* outr00 += 01234567 * w11 + */ + + "ext v4.8b, v21.8b, v8.8b, #7 \n" /* vext_s8(vzero, vinr0, 
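The switch above relies on deliberate case fallthrough: every window row that falls below the input is redirected to the shared zero row, and when h_out is odd the dead second output row of the last iteration is pointed at scratch (write_ptr) rather than guarding each store. An equivalent loop form of the row fix-up, for clarity (fix_bottom_pad is a hypothetical name; rows[r] feeds window row i - 1 + r under the pad_h = 1 convention):

static void fix_bottom_pad(const signed char* rows[4], int i, int h_in,
                           const signed char* zero_row) {
  for (int r = 1; r < 4; ++r) {
    if (i - 1 + r >= h_in) rows[r] = zero_row;  // pad rows read zeros
  }
}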
7); + 00123456 */ + "ext v5.8b, v8.8b, v9.8b, #1 \n" /* vext_s8(vinr0, vinr0_1, + 1); 12345678 */ + + "saddw v12.4s, v12.4s, v19.4h \n" /* v10 += outr00.low*/ + "saddw2 v13.4s, v13.4s, v19.8h \n" /* v11 += outr00.high*/ + + "smull v19.8h, %[v3].8b, v14.8b \n" /* outr10 += 01234567 * w11 + */ + "smlal v18.8h, %[v6].8b, v14.8b \n" /* outr00 += 01234567 * w11 + */ + + "ld1 {v6.8b}, [%[din_ptr2]], #8 \n" /* load + a00-a015 + to + q0*/ + + "smlal v19.8h, %[v5].8b, v15.8b \n" /* outr10 += 01234567 * w11 + */ + + "saddw v10.4s, v10.4s, v18.4h \n" /* v10 += outr00.low*/ + "saddw2 v11.4s, v11.4s, v18.8h \n" /* v11 += outr00.high*/ + + "smull v18.8h, %[v8].8b, v15.8b \n" /* outr00 += 01234567 * w11 + */ + + "saddw v12.4s, v12.4s, v19.4h \n" /* v10 += outr00.low*/ + "saddw2 v13.4s, v13.4s, v19.8h \n" /* v11 += outr00.high*/ + + // r3 + "smull v19.8h, %[v7].8b, v8.8b \n" /* outr00 += 01234567 * w11 + */ + + "ld1 {v8.8b}, [%[din_ptr3]], #8 \n" /* load + a00-a015 + to + q0*/ + + "ld1 {v7.8b}, [%[din_ptr2]] \n" /* load + a00-a015 to + q0*/ + "ld1 {v9.8b}, [%[din_ptr3]] \n" /* load + a00-a015 to + q0*/ + + "smlal v19.8h, %[v6].8b, v4.8b \n" /* outr00 += 01234567 * + w11 */ + + "saddw v10.4s, v10.4s, v18.4h \n" /* v10 += outr00.low*/ + "saddw2 v11.4s, v11.4s, v18.8h \n" /* v11 += outr00.high*/ + + "stp q10, q11, [%[ptr_out0]], #32 \n" /* store q10, q11 -> + ptr_out */ + + "saddw v12.4s, v12.4s, v19.4h \n" /* v10 += outr00.low*/ + "saddw2 v13.4s, v13.4s, v19.8h \n" /* v11 += outr00.high*/ + + "smull v19.8h, %[v8].8b, v5.8b \n" /* outr00 += 01234567 * + w11 */ + + "ld1 {v10.4s}, [%[bias_val]] \n" /* dup v10, bias */ + "ld1 {v11.4s}, [%[bias_val]] \n" /* dup v10, bias */ + + "saddw v12.4s, v12.4s, v19.4h \n" /* v10 += outr00.low*/ + "saddw2 v13.4s, v13.4s, v19.8h \n" /* v11 += outr00.high*/ + + "stp q12, q13, [%[ptr_out1]], #32 \n" /* store q10, q11 -> + ptr_out */ + + "ld1 {v12.4s}, [%[bias_val]] \n" /* dup v10, bias */ + "ld1 {v13.4s}, [%[bias_val]] \n" /* dup v10, bias */ + + "cmp %[cnt], #1 \n" + "blt 3f \n" + // mid + "1: \n" + "ext v4.8b, v0.8B, v1.8b, #1 \n" /*12345678 */ + "ext v5.8b, v0.8b, v1.8B, #2 \n" /*23456789 */ + + // r0 + "smull v18.8h, %[v0].8b, v0.8b \n" /* outr00 = 01234567 * w00 + */ + + "ext v14.8b, v2.8B, v3.8b, #1 \n" /*12345678 */ + "ext v15.8b, v2.8b, v3.8B, #2 \n" /*23456789 */ + + "smlal v18.8h, %[v1].8b, v4.8b\n" /* outr00 += 12345678 * w01 */ + + "ext v16.8b, v6.8B, v7.8b, #1 \n" /*12345678 */ + "ext v17.8b, v6.8b, v7.8B, #2 \n" /*23456789 */ + + "saddw v10.4s, v10.4s, v18.4h \n" /* v10 += outr00.low*/ + "saddw2 v11.4s, v11.4s, v18.8h \n" /* v11 += outr00.high*/ + + "smull v18.8h, %[v2].8b, v5.8b\n" /* outr00 += 23456789 * w02 */ + + // r1 + "ext v4.8b, v8.8B, v9.8b, #1 \n" /*12345678 */ + "ext v5.8b, v8.8b, v9.8B, #2 \n" /*23456789 */ + + "smull v19.8h, %[v0].8b, v2.8b \n" /* outr00 = 01234567 * w00 + */ + "smlal v18.8h, %[v3].8b, v2.8b \n" /* outr00 = 01234567 * w00 + */ + + "ld1 {v0.8b}, [%[din_ptr0]], #8 \n" /* load + a00-a015 + to + q0*/ + "ld1 {v2.8b}, [%[din_ptr1]], #8 \n" /* load + a00-a015 + to + q0*/ + + "smlal v19.8h, %[v1].8b, v14.8b\n" /* outr00 += 12345678 * w01 */ + + "saddw v10.4s, v10.4s, v18.4h \n" /* v10 += outr00.low*/ + "saddw2 v11.4s, v11.4s, v18.8h \n" /* v11 += outr00.high*/ + + "smull v18.8h, %[v4].8b, v14.8b\n" /* outr00 += 12345678 * w01 */ + + "ld1 {v1.8b}, [%[din_ptr0]] \n" /* load + a00-a015 + to q0*/ + "ld1 {v3.8b}, [%[din_ptr1]] \n" /* load + a00-a015 + to q0*/ + + "saddw v12.4s, v12.4s, v19.4h \n" /* v10 += outr00.low*/ + "saddw2 v13.4s, v13.4s, 
v19.8h \n" /* v11 += outr00.high*/ + + "smull v19.8h, %[v2].8b, v15.8b\n" /* outr00 += 23456789 * w02 */ + "smlal v18.8h, %[v5].8b, v15.8b\n" /* outr00 += 12345678 * w01 */ + + // r2 + "smlal v19.8h, %[v3].8b, v6.8b \n" /* outr00 = 01234567 * w00 + */ + + "saddw v10.4s, v10.4s, v18.4h \n" /* v10 += outr00.low*/ + "saddw2 v11.4s, v11.4s, v18.8h \n" /* v11 += outr00.high*/ + + "smull v18.8h, %[v6].8b, v6.8b \n" /* outr00 = 01234567 * w00 + */ + + "saddw v12.4s, v12.4s, v19.4h \n" /* v10 += outr00.low*/ + "saddw2 v13.4s, v13.4s, v19.8h \n" /* v11 += outr00.high*/ + + "smull v19.8h, %[v4].8b, v16.8b\n" /* outr00 += 12345678 * w01 */ + "smlal v18.8h, %[v7].8b, v16.8b\n" /* outr00 += 12345678 * w01 */ + + "smlal v19.8h, %[v5].8b, v17.8b\n" /* outr00 += 23456789 * w02 */ + + "saddw v10.4s, v10.4s, v18.4h \n" /* v10 += outr00.low*/ + "saddw2 v11.4s, v11.4s, v18.8h \n" /* v11 += outr00.high*/ + + "smull v18.8h, %[v8].8b, v17.8b\n" /* outr00 += 12345678 * w01 */ + + "saddw v12.4s, v12.4s, v19.4h \n" /* v10 += outr00.low*/ + "saddw2 v13.4s, v13.4s, v19.8h \n" /* v11 += outr00.high*/ + + // r3 + "smull v19.8h, %[v6].8b, v8.8b \n" /* outr00 = 01234567 * w00 + */ + + "ld1 {v6.8b}, [%[din_ptr2]], #8 \n" /* load + a00-a015 + to + q0*/ + "ld1 {v8.8b}, [%[din_ptr3]], #8 \n" /* load + a00-a015 + to + q0*/ + + "saddw v10.4s, v10.4s, v18.4h \n" /* v10 += outr00.low*/ + "saddw2 v11.4s, v11.4s, v18.8h \n" /* v11 += outr00.high*/ + + "smlal v19.8h, %[v7].8b, v4.8b\n" /* outr00 += 12345678 * w01 */ + + "ld1 {v7.8b}, [%[din_ptr2]] \n" /* load + a00-a015 + to q0*/ + "ld1 {v9.8b}, [%[din_ptr3]] \n" /* load + a00-a015 + to q0*/ + + "stp q10, q11, [%[ptr_out0]], #32 \n" /* store q10, q11 -> + ptr_out */ + + "saddw v12.4s, v12.4s, v19.4h \n" /* v10 += outr00.low*/ + "saddw2 v13.4s, v13.4s, v19.8h \n" /* v11 += outr00.high*/ + + "smull v19.8h, %[v8].8b, v5.8b\n" /* outr00 += 23456789 * w02 */ + + "ld1 {v10.4s}, [%[bias_val]] \n" /* dup v10, bias */ + "ld1 {v11.4s}, [%[bias_val]] \n" /* dup v10, bias */ + + "saddw v12.4s, v12.4s, v19.4h \n" /* v10 += outr00.low*/ + "saddw2 v13.4s, v13.4s, v19.8h \n" /* v11 += outr00.high*/ + + "subs %[cnt], %[cnt], #1 \n" + + "stp q12, q13, [%[ptr_out1]], #32 \n" /* store q10, q11 -> + ptr_out */ + + "ld1 {v12.4s}, [%[bias_val]] \n" /* dup v10, bias */ + "ld1 {v13.4s}, [%[bias_val]] \n" /* dup v10, bias */ + + "bne 1b \n" + // right + "3: \n" + "ld1 {v14.8b}, [%[vmask]], #8 \n" + "ld1 {v15.8b}, [%[vmask]] \n" + + "bif v0.8b, v21.8b, v14.8b \n" + "bif v1.8b, v21.8b, v15.8b \n" + "bif v2.8b, v21.8b, v14.8b \n" + "bif v3.8b, v21.8b, v15.8b \n" + + "ext v4.8b, v0.8b, v1.8b, #1 \n" + "ext v5.8b, v0.8b, v1.8b, #2 \n" + + // r0 + "smull v18.8h, %[v0].8b, v0.8b \n" /* outr00 = 01234567 * w00 + */ + + "ext v16.8b, v2.8b, v3.8b, #1 \n" + "ext v17.8b, v2.8b, v3.8b, #2 \n" + + "bif v6.8b, v21.8b, v14.8b \n" + "bif v7.8b, v21.8b, v15.8b \n" + + "smlal v18.8h, %[v1].8b, v4.8b \n" /* outr00 = 01234567 * w00 + */ + + "bif v8.8b, v21.8b, v14.8b \n" + "bif v9.8b, v21.8b, v15.8b \n" + + "ext v20.8b, v6.8b, v7.8b, #1 \n" + "ext v22.8b, v6.8b, v7.8b, #2 \n" + + "saddw v10.4s, v10.4s, v18.4h \n" /* v10 += outr00.low*/ + "saddw2 v11.4s, v11.4s, v18.8h \n" /* v11 += outr00.high*/ + + "smull v18.8h, %[v2].8b, v5.8b \n" /* outr00 = 01234567 * w00 + */ + + // r1 + "ext v4.8b, v8.8b, v9.8b, #1 \n" + "ext v5.8b, v8.8b, v9.8b, #2 \n" + + "smull v19.8h, %[v0].8b, v2.8b \n" /* outr00 = 01234567 * w00 + */ + "smlal v18.8h, %[v3].8b, v2.8b \n" /* outr00 = 01234567 * w00 + */ + + "ld1 {v14.4s}, [%[rmask]], #16 \n" + "ld1 
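In the middle blocks above, ext builds the three horizontal tap windows ([0..7], [1..8], [2..9]) from one 16-byte span, so each 8-wide output row slice needs only two loads per input row. A sketch with intrinsics (three_tap_row is a hypothetical name; d0/d1 hold 16 consecutive input bytes):

#include <arm_neon.h>

static int16x8_t three_tap_row(int8x8_t d0, int8x8_t d1, int8x8_t w0,
                               int8x8_t w1, int8x8_t w2) {
  int8x8_t x1 = vext_s8(d0, d1, 1);  // columns 1..8
  int8x8_t x2 = vext_s8(d0, d1, 2);  // columns 2..9
  int16x8_t acc = vmull_s8(d0, w0);  // columns 0..7 * tap 0
  acc = vmlal_s8(acc, x1, w1);       // + columns 1..8 * tap 1
  // NB: the real kernel folds into int32 before this point and widens the
  // third product separately (see the note on int16 headroom further below).
  acc = vmlal_s8(acc, x2, w2);
  return acc;
}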
{v15.4s}, [%[rmask]] \n" + + "smlal v19.8h, %[v1].8b, v16.8b \n" /* outr00 = 01234567 * w00 + */ + + "saddw v10.4s, v10.4s, v18.4h \n" /* v10 += outr00.low*/ + "saddw2 v11.4s, v11.4s, v18.8h \n" /* v11 += outr00.high*/ + + "smull v18.8h, %[v4].8b, v16.8b \n" /* outr00 = 01234567 * w00 + */ + + "ld1 {v0.4s}, [%[ptr_out0]], #16 \n" + "ld1 {v2.4s}, [%[ptr_out1]], #16 \n" + + "saddw v12.4s, v12.4s, v19.4h \n" /* v10 += outr00.low*/ + "saddw2 v13.4s, v13.4s, v19.8h \n" /* v11 += outr00.high*/ + + "smull v19.8h, %[v2].8b, v17.8b \n" /* outr00 = 01234567 * w00 + */ + "smlal v18.8h, %[v5].8b, v17.8b \n" /* outr00 = 01234567 * w00 + */ + + "ld1 {v1.4s}, [%[ptr_out0]] \n" + "ld1 {v3.4s}, [%[ptr_out1]] \n" + + // r2 + "smlal v19.8h, %[v3].8b, v6.8b \n" /* outr00 = 01234567 * w00 + */ + + "saddw v10.4s, v10.4s, v18.4h \n" /* v10 += outr00.low*/ + "saddw2 v11.4s, v11.4s, v18.8h \n" /* v11 += outr00.high*/ + + "smull v18.8h, %[v6].8b, v6.8b \n" /* outr00 = 01234567 * w00 + */ + + "sub %[ptr_out0], %[ptr_out0], #16 \n" + "sub %[ptr_out1], %[ptr_out1], #16 \n" + + "saddw v12.4s, v12.4s, v19.4h \n" /* v10 += outr00.low*/ + "saddw2 v13.4s, v13.4s, v19.8h \n" /* v11 += outr00.high*/ + + "smull v19.8h, %[v4].8b, v20.8b \n" /* outr00 = 01234567 * w00 + */ + "smlal v18.8h, %[v7].8b, v20.8b \n" /* outr00 = 01234567 * w00 + */ + + "smlal v19.8h, %[v5].8b, v22.8b \n" /* outr00 = 01234567 * w00 + */ + + "saddw v10.4s, v10.4s, v18.4h \n" /* v10 += outr00.low*/ + "saddw2 v11.4s, v11.4s, v18.8h \n" /* v11 += outr00.high*/ + + "smull v18.8h, %[v8].8b, v22.8b \n" /* outr00 = 01234567 * w00 + */ + + "saddw v12.4s, v12.4s, v19.4h \n" /* v10 += outr00.low*/ + "saddw2 v13.4s, v13.4s, v19.8h \n" /* v11 += outr00.high*/ + + // r3 + "smull v19.8h, %[v6].8b, v8.8b \n" /* outr00 = 01234567 * w00 + */ + + "saddw v10.4s, v10.4s, v18.4h \n" /* v10 += outr00.low*/ + "saddw2 v11.4s, v11.4s, v18.8h \n" /* v11 += outr00.high*/ + + "smlal v19.8h, %[v7].8b, v4.8b \n" /* outr00 = 01234567 * w00 + */ + + "bif v10.16b, v0.16b, v14.16b \n" + "bif v11.16b, v1.16b, v15.16b \n" + + "saddw v12.4s, v12.4s, v19.4h \n" /* v10 += outr00.low*/ + "saddw2 v13.4s, v13.4s, v19.8h \n" /* v11 += outr00.high*/ + + "smull v19.8h, %[v8].8b, v5.8b \n" /* outr00 = 01234567 * w00 + */ + + "stp q10, q11, [%[ptr_out0]], #32 \n" /* store q10, q11 -> + ptr_out */ + + "saddw v12.4s, v12.4s, v19.4h \n" /* v10 += outr00.low*/ + "saddw2 v13.4s, v13.4s, v19.8h \n" /* v11 += outr00.high*/ + + "bif v12.16b, v2.16b, v14.16b \n" + "bif v13.16b, v3.16b, v15.16b \n" + + "stp q12, q13, [%[ptr_out1]], #32 \n" /* store q10, q11 -> + ptr_out */ + + : [cnt] "+r"(cnt), + [din_ptr0] "+r"(din_ptr0), + [din_ptr1] "+r"(din_ptr1), + [din_ptr2] "+r"(din_ptr2), + [din_ptr3] "+r"(din_ptr3), + [ptr_out0] "+r"(doutr0), + [ptr_out1] "+r"(doutr1), + [vmask] "+r"(val_mask), + [rmask] "+r"(rst_mask) + : [v0] "w"(wr00), + [v1] "w"(wr01), + [v2] "w"(wr02), + [v3] "w"(wr10), + [bias_val] "r"(vbias), + [v4] "w"(wr11), + [v5] "w"(wr12), + [v6] "w"(wr20), + [v7] "w"(wr21), + [v8] "w"(wr22) + : "cc", + "memory", + "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15", + "v16", + "v17", + "v18", + "v19", + "v20", + "v21", + "v22"); +#else + // store weights + asm volatile("vld1.8 {d0-d1}, [%[wei_ptr]] \n" + : + : [wei_ptr] "r"(wei_ptr) + : "memory"); + asm volatile( + // left + "pld [%[din_ptr0]] @ preload data\n" + "pld [%[din_ptr1]] @ preload data\n" + "pld [%[din_ptr2]] @ preload data\n" + "pld [%[din_ptr3]] @ preload data\n" 
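// Register map for the ARMv7 path (descriptive only; matches the vdup/vmov
// setup below and the MAC chain that follows):
//   d0-d1   : all nine int8 weights, loaded once per channel
//   d2-d10  : per-tap broadcasts w00..w22 (vdup.s8 from d0/d1 lanes)
//   d11     : zero register for the implicit padding (vmov.u32 d11, #0)
//   d12-d15 : input rows, reloaded as the window advances
//   d28/d29 : right-pad byte mask; d30/d31 : ext'd shifted windows
//   q8-q11  : int32 accumulators (q8/q9 for row 0, q10/q11 for row 1)
//   q12/q13 : int16 widening products (vmull.s8 / vmlal.s8)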
+ "vdup.s8 d2, d0[0] @ d2 = w00, w00, w00, w00\n" + "vdup.s8 d3, d0[1] @ d3 = w01, w01, w01, w01\n" + "vdup.s8 d4, d0[2] @ d4 = w02, w02, w02, w02\n" + "vld1.8 {d12-d13}, [%[din_ptr0]] @ load din00= 0 1 2 3 4 5 6 7 " + "8 9\n" + "vmov.u32 d11, #0 @ zero\n" + // out0 + "vdup.32 q8, %[bias] @ and \n" // q8 = + // vbias + "vdup.32 q9, %[bias] @ and \n" // q9 = + // vbias + // out1 + "vdup.32 q10, %[bias] @ and \n" // q8 = + // vbias + "vdup.32 q11, %[bias] @ and \n" // q9 = + // vbias + + // r0 + "vmull.s8 q12, d12, d3 @ out0 = din0 * w01 \n" // q12 = d12 * w01 + "vext.8 d30, d11, d12, #7 @ ext \n" // d10 = 00123456 + "vext.8 d31, d12, d13, #1 @ ext \n" // d11 = 12345678 + + "vld1.8 {d12-d13}, [%[din_ptr1]] @ load din00= 0 1 2 3 4 5 6 7 " + "8 9\n" + "vld1.8 {d14-d15}, [%[din_ptr2]] @ load din00= 0 1 2 3 4 5 6 7 " + "8 9\n" + "vdup.s8 d5, d0[3] @ d5 = w10, w10, w00, w00\n" + "vdup.s8 d6, d0[4] @ d6 = w11, w11, w01, w01\n" + + "vmlal.s8 q12, d30, d2 @ out0 += din0 * w00 \n" // q12 += d10 * w00 + + "vdup.s8 d7, d0[5] @ d7 = w12, w12\n" + "add %[din_ptr0], #7 @add \n" + "add %[din_ptr1], #7 @add \n" + + "vaddw.s16 q8, q8, d24 @addw \n" // out0 += + // vget_low_s16(out00) + "vaddw.s16 q9, q9, d25 @addw \n" // out0_1 += + // vget_high_s16(out00) + + "vmull.s8 q12, d31, d4 @ out0 += din0 * w02 \n" // q12 += d11 * w02 + + // r1 + "vext.8 d30, d11, d12, #7 @ ext \n" // d10 = 00123456 + "vext.8 d31, d12, d13, #1 @ ext \n" // d11 = 12345678 + "vmull.s8 q13, d12, d3 @ out1 = din1 * w01 \n" // q13 = d12 * w01 + + "vmlal.s8 q12, d12, d6 @ out0 = din1 * w11 \n" // q12 = d12 * w11 + + "vld1.8 {d12-d13}, [%[din_ptr3]] @ load din00= 0 1 2 3 4 5 6 7 " + "8 9\n" + "vdup.s8 d8, d0[6] @ d8 = w20, w00, w00, w00\n" + "vdup.s8 d9, d0[7] @ d9 = w21, w01, w01, w01\n" + "vdup.s8 d10, d1[0] @ d10 = w22, w02, w02, w02\n" + + "vaddw.s16 q8, q8, d24 @addw \n" // out0 += + // vget_low_s16(out00) + "vaddw.s16 q9, q9, d25 @addw \n" // out0_1 += + // vget_high_s16(out00) + + "vmlal.s8 q13, d30, d2 @ out1 += din1 * w00 \n" // q12 += d10 * w00 + "vmull.s8 q12, d30, d5 @ out0 += din1 * w10 \n" // q12 += d10 * w00 + + "add %[din_ptr2], #7 @add \n" + "add %[din_ptr3], #7 @add \n" + + "vaddw.s16 q10, q10, d26 @addw \n" // out1 += + // vget_low_s16(out10) + "vaddw.s16 q11, q11, d27 @addw \n" // out1_1 += + // vget_high_s16(out10) + + "vmull.s8 q13, d31, d4 @ out1 += din1 * w02 \n" // q12 += d10 * w00 + "vmlal.s8 q12, d31, d7 @ out0 += din1 * w12 \n" // q12 += d10 * w00 + + // r2 + "vext.8 d30, d11, d14, #7 @ ext \n" // d10 = 00123456 + "vext.8 d31, d14, d15, #1 @ ext \n" // d11 = 12345678 + + "vaddw.s16 q8, q8, d24 @addw \n" // out0 += + // vget_low_s16(out00) + "vaddw.s16 q9, q9, d25 @addw \n" // out0_1 += + // vget_high_s16(out00) + + "vmlal.s8 q13, d14, d6 @ out1 = din2 * w11 \n" // q13 = d12 * w01 + "vmull.s8 q12, d14, d9 @ out1 = din2 * w21 \n" // q13 = d12 * w01 + + "vaddw.s16 q10, q10, d26 @addw \n" // out1 += + // vget_low_s16(out10) + "vaddw.s16 q11, q11, d27 @addw \n" // out1_1 += + // vget_high_s16(out10) + + "vmull.s8 q13, d30, d5 @ out1 += din2 * w10 \n" // q12 += d10 * w00 + "vmlal.s8 q12, d30, d8 @ out0 += din2 * w20 \n" // q12 += d10 * w00 + + "vaddw.s16 q8, q8, d24 @addw \n" // out0 += + // vget_low_s16(out00) + "vaddw.s16 q9, q9, d25 @addw \n" // out0_1 += + // vget_high_s16(out00) + + "vmlal.s8 q13, d31, d7 @ out1 += din2 * w12 \n" // q12 += d10 * w00 + "vmull.s8 q12, d31, d10 @ out0 += din2 * w22 \n" // q12 += d10 * w00 + + // r3 + "vext.8 d30, d11, d12, #7 @ ext \n" // d10 = 00123456 + "vext.8 d31, d12, d13, #1 @ 
ext \n" // d11 = 12345678 + "vaddw.s16 q10, q10, d26 @addw \n" // out1 += + // vget_low_s16(out10) + "vaddw.s16 q11, q11, d27 @addw \n" // out1_1 += + // vget_high_s16(out10) + "vaddw.s16 q8, q8, d24 @addw \n" // out0 += + // vget_low_s16(out00) + "vaddw.s16 q9, q9, d25 @addw \n" // out0_1 += + // vget_high_s16(out00) + + "vmull.s8 q13, d12, d9 @ out1 = din3 * w21 \n" // q13 = d12 * w01 + "pld [%[din_ptr0]] @ preload data\n" + "pld [%[din_ptr1]] @ preload data\n" + + "vmlal.s8 q13, d30, d8 @ out1 += din3 * w20 \n" // q13 += d10 * w00 + "pld [%[din_ptr2]] @ preload data\n" + "pld [%[din_ptr3]] @ preload data\n" + + "vaddw.s16 q10, q10, d26 @addw \n" // out1 += + // vget_low_s16(out10) + "vaddw.s16 q11, q11, d27 @addw \n" // out1_1 += + // vget_high_s16(out10) + + "vst1.32 {d16-d17}, [%[dout_ptr1]]! @ store\n" + + "vmull.s8 q13, d31, d10 @ out1 += din3 * w22 \n" // q12 += d10 * w00 + + "vst1.32 {d18-d19}, [%[dout_ptr1]]! @ store\n" + "vaddw.s16 q10, q10, d26 @addw \n" // out1 += + // vget_low_s16(out10) + "vaddw.s16 q11, q11, d27 @addw \n" // out1_1 += + // vget_high_s16(out10) + "vst1.32 {d20-d21}, [%[dout_ptr2]]! @ store\n" + "cmp %[cnt], #1 \n" + "vst1.32 {d22-d23}, [%[dout_ptr2]]! @ store\n" + "blt 1f \n" + + // mid + "2: \n" + "vld1.8 {d12-d13}, [%[din_ptr0]] @ load din00= 0 1 2 3 4 5 6 7 " + "8 9\n" + // out0 + "vdup.32 q8, %[bias] @ and \n" // q8 = + // vbias + "vdup.32 q9, %[bias] @ and \n" // q9 = + // vbias + // out1 + "vdup.32 q10, %[bias] @ and \n" // q8 = + // vbias + "vdup.32 q11, %[bias] @ and \n" // q9 = + // vbias + + // r0 + "vmull.s8 q12, d12, d2 @ out0 = din0 * w01 \n" // q12 = d12 * w01 + "vext.8 d30, d12, d13, #1 @ ext \n" // d10 = 12345678 + "vext.8 d31, d12, d13, #2 @ ext \n" // d11 = 23456789 + + "vld1.8 {d12-d13}, [%[din_ptr1]] @ load din00= 0 1 2 3 4 5 6 7 " + "8 9\n" + "vld1.8 {d14-d15}, [%[din_ptr2]] @ load din00= 0 1 2 3 4 5 6 7 " + "8 9\n" + + "vmlal.s8 q12, d30, d3 @ out0 += din0 * w00 \n" // q12 += d10 * w00 + + "add %[din_ptr0], #8 @add \n" + "add %[din_ptr1], #8 @add \n" + "vaddw.s16 q8, q8, d24 @addw \n" // out0 += + // vget_low_s16(out00) + "vaddw.s16 q9, q9, d25 @addw \n" // out0_1 += + // vget_high_s16(out00) + + "vmull.s8 q12, d31, d4 @ out0 += din0 * w02 \n" // q12 += d11 * w02 + + // r1 + "vext.8 d30, d12, d13, #1 @ ext \n" // d10 = 00123456 + "vext.8 d31, d12, d13, #2 @ ext \n" // d11 = 12345678 + "vmull.s8 q13, d12, d2 @ out1 = din1 * w01 \n" // q13 = d12 * w01 + + "vmlal.s8 q12, d12, d5 @ out0 = din1 * w11 \n" // q12 = d12 * w11 + + "vld1.8 {d12-d13}, [%[din_ptr3]] @ load din00= 0 1 2 3 4 5 6 7 " + "8 9\n" + + "vmlal.s8 q13, d30, d3 @ out1 += din1 * w00 \n" // q12 += d10 * w00 + + "vaddw.s16 q8, q8, d24 @addw \n" // out0 += + // vget_low_s16(out00) + "vaddw.s16 q9, q9, d25 @addw \n" // out0_1 += + // vget_high_s16(out00) + + "vmull.s8 q12, d30, d6 @ out0 += din1 * w10 \n" // q12 += d10 * w00 + + "add %[din_ptr2], #8 @add \n" + "add %[din_ptr3], #8 @add \n" + "vaddw.s16 q10, q10, d26 @addw \n" // out1 += + // vget_low_s16(out10) + "vaddw.s16 q11, q11, d27 @addw \n" // out1_1 += + // vget_high_s16(out10) + + "vmull.s8 q13, d31, d4 @ out1 += din1 * w02 \n" // q12 += d10 * w00 + "vmlal.s8 q12, d31, d7 @ out0 += din1 * w12 \n" // q12 += d10 * w00 + + // r2 + "vext.8 d30, d14, d15, #1 @ ext \n" // d10 = 00123456 + "vext.8 d31, d14, d15, #2 @ ext \n" // d11 = 12345678 + + "vaddw.s16 q8, q8, d24 @addw \n" // out0 += + // vget_low_s16(out00) + "vaddw.s16 q9, q9, d25 @addw \n" // out0_1 += + // vget_high_s16(out00) + + "vmlal.s8 q13, d14, d5 @ out1 = din2 * 
w11 \n" // q13 = d12 * w01 + "vmull.s8 q12, d14, d8 @ out1 = din2 * w21 \n" // q13 = d12 * w01 + + "vaddw.s16 q10, q10, d26 @addw \n" // out1 += + // vget_low_s16(out10) + "vaddw.s16 q11, q11, d27 @addw \n" // out1_1 += + // vget_high_s16(out10) + + "vmull.s8 q13, d30, d6 @ out1 += din2 * w10 \n" // q12 += d10 * w00 + "vmlal.s8 q12, d30, d9 @ out0 += din2 * w20 \n" // q12 += d10 * w00 + + "vmlal.s8 q13, d31, d7 @ out1 += din2 * w12 \n" // q12 += d10 * w00 + "vaddw.s16 q8, q8, d24 @addw \n" // out0 += + // vget_low_s16(out00) + "vaddw.s16 q9, q9, d25 @addw \n" // out0_1 += + // vget_high_s16(out00) + + "vmull.s8 q12, d31, d10 @ out0 += din2 * w22 \n" // q12 += d10 * w00 + + // r3 + "vext.8 d30, d12, d13, #1 @ ext \n" // d10 = 00123456 + "vext.8 d31, d12, d13, #2 @ ext \n" // d11 = 12345678 + "vaddw.s16 q10, q10, d26 @addw \n" // out1 += + // vget_low_s16(out10) + "vaddw.s16 q11, q11, d27 @addw \n" // out1_1 += + // vget_high_s16(out10) + "vaddw.s16 q8, q8, d24 @addw \n" // out0 += + // vget_low_s16(out00) + "vaddw.s16 q9, q9, d25 @addw \n" // out0_1 += + // vget_high_s16(out00) + + "vmull.s8 q13, d12, d8 @ out1 = din3 * w21 \n" // q13 = d12 * w01 + "pld [%[din_ptr0]] @ preload data\n" + "pld [%[din_ptr1]] @ preload data\n" + + "vmlal.s8 q13, d30, d9 @ out1 += din3 * w20 \n" // q13 += d10 * w00 + "pld [%[din_ptr2]] @ preload data\n" + "pld [%[din_ptr3]] @ preload data\n" + + "vst1.32 {d16-d17}, [%[dout_ptr1]]! @ store\n" + "vaddw.s16 q10, q10, d26 @addw \n" // out1 += + // vget_low_s16(out10) + "vaddw.s16 q11, q11, d27 @addw \n" // out1_1 += + // vget_high_s16(out10) + + "vmull.s8 q13, d31, d10 @ out1 += din3 * w22 \n" // q12 += d10 * w00 + + "vst1.32 {d18-d19}, [%[dout_ptr1]]! @ store\n" + "vaddw.s16 q10, q10, d26 @addw \n" // out1 += + // vget_low_s16(out10) + "vaddw.s16 q11, q11, d27 @addw \n" // out1_1 += + // vget_high_s16(out10) + "vst1.32 {d20-d21}, [%[dout_ptr2]]! @ store\n" + "subs %[cnt], #1 \n" + "vst1.32 {d22-d23}, [%[dout_ptr2]]! 
@ store\n" + "bne 2b \n" + // right + "1: \n" + "vld1.8 {d12-d13}, [%[din_ptr0]] @ load din00= 0 1 2 3 4 5 6 7 " + "8 9\n" + "vld1.8 {d28-d29}, [%[mask]] @ load din00= 0 1 2 3 4 5 6 7 " + "8 9\n" + // out0 + "vdup.32 q8, %[bias] @ and \n" // q8 = vbias + "vdup.32 q9, %[bias] @ and \n" // q9 = vbias + // out1 + "vdup.32 q10, %[bias] @ and \n" // q8 = vbias + "vdup.32 q11, %[bias] @ and \n" // q9 = vbias + + "vbif.8 d12, d11, d28 @ bit select, deal with right pad\n" + "vbif.8 d13, d11, d29 @ bit select, deal with right pad\n" + "vld1.8 {d14-d15}, [%[din_ptr1]] @ load din00= 0 1 2 3 4 5 6 7 " + "8 9\n" + + // r0 + "vmull.s8 q12, d12, d2 @ out0 = din0 * w00 \n" // q12 = d12 * w01 + "vext.8 d30, d12, d13, #1 @ ext \n" // d10 = 12345678 + "vext.8 d31, d12, d13, #2 @ ext \n" // d11 = 23456789 + + "vld1.8 {d12-d13}, [%[din_ptr2]] @ load din00= 0 1 2 3 4 5 6 7 " + "8 9\n" + "vbif.8 d14, d11, d28 @ bit select, deal with right pad\n" + "vbif.8 d15, d11, d29 @ bit select, deal with right pad\n" + + "vmlal.s8 q12, d30, d3 @ out0 += din0 * w01 \n" // q12 += d10 * w00 + + "vaddw.s16 q8, q8, d24 @addw \n" // out0 += + // vget_low_s16(out00) + "vaddw.s16 q9, q9, d25 @addw \n" // out0_1 += + // vget_high_s16(out00) + + "vmull.s8 q12, d31, d4 @ out0 += din0 * w02 \n" // q12 += d11 * w02 + + // r1 + "vext.8 d30, d14, d15, #1 @ ext \n" // d10 = 00123456 + "vext.8 d31, d14, d15, #2 @ ext \n" // d11 = 12345678 + + "vmull.s8 q13, d14, d2 @ out1 = din1 * w00 \n" // q13 = d12 * w01 + + "vmlal.s8 q12, d14, d5 @ out0 = din1 * w10 \n" // q12 = d12 * w11 + + "vld1.8 {d14-d15}, [%[din_ptr3]] @ load din00= 0 1 2 3 4 5 6 7 " + "8 9\n" + "vbif.8 d12, d11, d28 @ bit select, deal with " + "right pad\n" + "vbif.8 d13, d11, d29 @ bit select, deal with " + "right pad\n" + + "vaddw.s16 q8, q8, d24 @addw \n" // out0 += + // vget_low_s16(out00) + "vaddw.s16 q9, q9, d25 @addw \n" // out0_1 += + // vget_high_s16(out00) + + "vmlal.s8 q13, d30, d3 @ out1 += din1 * w01 \n" // q12 += d10 * w00 + "vmull.s8 q12, d30, d6 @ out0 += din1 * w11 \n" // q12 += d10 * w00 + + "vaddw.s16 q10, q10, d26 @addw \n" // out1 += + // vget_low_s16(out10) + "vaddw.s16 q11, q11, d27 @addw \n" // out1_1 += + // vget_high_s16(out10) + + "vmull.s8 q13, d31, d4 @ out1 += din1 * w02 \n" // q12 += d10 * w00 + "vmlal.s8 q12, d31, d7 @ out0 += din1 * w12 \n" // q12 += d10 * w00 + + // r2 + "vext.8 d30, d12, d13, #1 @ ext \n" // d10 = 00123456 + "vext.8 d31, d12, d13, #2 @ ext \n" // d11 = 12345678 + + "vaddw.s16 q8, q8, d24 @addw \n" // out0 += + // vget_low_s16(out00) + "vaddw.s16 q9, q9, d25 @addw \n" // out0_1 += + // vget_high_s16(out00) + + "vmlal.s8 q13, d12, d5 @ out1 = din2 * w10 \n" // q13 = d12 * w01 + "vmull.s8 q12, d12, d8 @ out1 = din2 * w20 \n" // q13 = d12 * w01 + + "vbif.8 d14, d11, d28 @ bit select, deal with " + "right pad\n" + "vbif.8 d15, d11, d29 @ bit select, deal with " + "right pad\n" + + "vaddw.s16 q10, q10, d26 @addw \n" // out1 += + // vget_low_s16(out10) + "vaddw.s16 q11, q11, d27 @addw \n" // out1_1 += + // vget_high_s16(out10) + + "vmull.s8 q13, d30, d6 @ out1 += din2 * w10 \n" // q12 += d10 * w00 + "vmlal.s8 q12, d30, d9 @ out0 += din2 * w20 \n" // q12 += d10 * w00 + + "vld1.32 {d28-d29}, [%[dout_ptr1]]! @ load din00= 0 1 2 3 4 5 6 " + "7 8 9\n" + "vld1.32 {d12-d13}, [%[dout_ptr1]] @ load din00= 0 1 2 3 4 5 6 " + "7 8 9\n" + "vld1.32 {d2-d3}, [%[rs_mask]]! 
@ load din00= 0 1 2 3 4 5 6 7 " + "8 9\n" + "vld1.32 {d4-d5}, [%[rs_mask]] @ load din00= 0 1 2 3 4 5 6 7 8 " + "9\n" + + "vaddw.s16 q8, q8, d24 @addw \n" // out0 += + // vget_low_s16(out00) + "vaddw.s16 q9, q9, d25 @addw \n" // out0_1 += + // vget_high_s16(out00) + + "vmlal.s8 q13, d31, d7 @ out1 += din2 * w12 \n" // q12 += d10 * w00 + "vmull.s8 q12, d31, d10 @ out0 += din2 * w22 \n" // q12 += d10 * w00 + + // r3 + "vext.8 d30, d14, d15, #1 @ ext \n" // d10 = 00123456 + "vext.8 d31, d14, d15, #2 @ ext \n" // d11 = 12345678 + "vaddw.s16 q10, q10, d26 @addw \n" // out1 += + // vget_low_s16(out10) + "vaddw.s16 q11, q11, d27 @addw \n" // out1_1 += + // vget_high_s16(out10) + "vaddw.s16 q8, q8, d24 @addw \n" // out0 += + // vget_low_s16(out00) + "vaddw.s16 q9, q9, d25 @addw \n" // out0_1 += + // vget_high_s16(out00) + + "vmull.s8 q13, d14, d8 @ out1 = din3 * w20 \n" // q13 = d12 * w01 + "sub %[dout_ptr1], #16 @ sub \n" + "vld1.32 {d14-d15}, [%[dout_ptr2]]! @ load din00= 0 1 2 3 4 5 6 " + "7 8 9\n" + "vld1.32 {d24-d25}, [%[dout_ptr2]] @ load din00= 0 1 2 3 4 5 6 " + "7 8 9\n" + + "vmlal.s8 q13, d30, d9 @ out1 += din3 * w21 \n" // q13 += d10 * w00 + "vbif q8, q14, q1 @ bit select, deal with right " + "pad\n" + "vbif q9, q6, q2 @ bit select, deal with right " + "pad\n" + "sub %[dout_ptr2], #16 @ sub \n" + + "vaddw.s16 q10, q10, d26 @addw \n" // out1 += + // vget_low_s16(out10) + "vaddw.s16 q11, q11, d27 @addw \n" // out1_1 += + // vget_high_s16(out10) + + "vmull.s8 q13, d31, d10 @ out1 += din3 * w22 \n" // q12 += d10 * w00 + + "vst1.32 {d16-d17}, [%[dout_ptr1]]! @ store\n" + "vst1.32 {d18-d19}, [%[dout_ptr1]]! @ store\n" + "vaddw.s16 q10, q10, d26 @addw \n" // out1 += + // vget_low_s16(out10) + "vaddw.s16 q11, q11, d27 @addw \n" // out1_1 += + // vget_high_s16(out10) + + "vbif q10, q7, q1 @ bit select, deal with right pad\n" + "vbif q11, q12, q2 @ bit select, deal with right pad\n" + + "vst1.32 {d20-d21}, [%[dout_ptr2]]! @ store\n" + "vst1.32 {d22-d23}, [%[dout_ptr2]]! @ store\n" + + : [din_ptr0] "+r"(din_ptr0), + [din_ptr1] "+r"(din_ptr1), + [din_ptr2] "+r"(din_ptr2), + [din_ptr3] "+r"(din_ptr3), + [dout_ptr1] "+r"(doutr0), + [dout_ptr2] "+r"(doutr1), + [cnt] "+r"(cnt), + [bias] "+r"(bias_val), + [rs_mask] "+r"(rst_mask) + : [mask] "r"(vmask) + : "cc", + "memory", + "q0", + "q1", + "q2", + "q3", + "q4", + "q5", + "q6", + "q7", + "q8", + "q9", + "q10", + "q11", + "q12", + "q13", + "q14", + "q15"); +#endif + dout_ptr += 2 * w_out; + } + } + } +} + +// w_in <= 8 +void conv_depthwise_3x3s1p1_bias_s_int8(int* dout, + const signed char* din, + const signed char* weights, + const int* bias, + bool flag_bias, + const int num, + const int ch_in, + const int h_in, + const int w_in, + const int h_out, + const int w_out, + ARMContext* ctx) { + // printf("3x3s1 mult height \n"); + const char zero[8] = {0, 0, 0, 0, 0, 0, 0, 0}; + //! 
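At the right edge of a row, the stores above are read-modify-write: the previously written output block is reloaded, the new lanes are merged under rmask with bif/vbif, and the merged vectors are written back, so lanes beyond w_out keep their prior contents. A scalar model of the merge (masked_store is a hypothetical name; rmask lanes are all-ones for valid output columns and zero past the row end):

static void masked_store(int* dst, const int src[4], const unsigned int mask[4]) {
  for (int k = 0; k < 4; ++k) {
    if (mask[k]) dst[k] = src[k];  // bif: insert src only where mask is set
  }
}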
for 4x6 convolution window + const unsigned char right_pad_idx[8] = {0, 1, 2, 3, 4, 5, 6, 7}; + const unsigned int right_pad_rst[8] = {0, 1, 2, 3, 4, 5, 6, 7}; + + // printf("conv3x3_dw start \n"); + signed char* zero_ptr = ctx->workspace_data(); + memset(zero_ptr, 0, w_in * sizeof(signed char)); + int* write_ptr = + reinterpret_cast(ctx->workspace_data()) + w_in; + int size_in_channel = w_in * h_in; + int size_out_channel = w_out * h_out; + int w_stride = 9; + + int tile_h = (h_out + 1) >> 1; + + unsigned int size_pad_right = (unsigned int)(w_in); + + uint8x8_t vmask_rp = + vcgt_u8(vdup_n_u8(size_pad_right), vld1_u8(right_pad_idx)); + // uint8x8_t vmask_rp2 = vcgt_u8(vdup_n_u8(size_pad_right), + // vld1_u8(right_pad_idx + 8)); + unsigned char vmask[8]; + vst1_u8(vmask, vmask_rp); + + unsigned int rst_remain = (unsigned int)w_out; + uint32x4_t vmask_result1 = + vcgtq_u32(vdupq_n_u32(rst_remain), vld1q_u32(right_pad_rst)); + uint32x4_t vmask_result2 = + vcgtq_u32(vdupq_n_u32(rst_remain), vld1q_u32(right_pad_rst + 4)); + + unsigned int rmask[8]; + vst1q_u32(rmask, vmask_result1); + vst1q_u32(rmask + 4, vmask_result2); + + int8x8_t vzero = vdup_n_s8(0); + int32x4_t vzero_32 = vdupq_n_s32(0); + + for (int n = 0; n < num; ++n) { + const signed char* din_batch = din + n * ch_in * size_in_channel; + int* dout_batch = dout + n * ch_in * size_out_channel; +#pragma omp parallel for + for (int c = 0; c < ch_in; c++) { + int* dout_ptr = dout_batch + c * size_out_channel; + + const signed char* din_ch_ptr = din_batch + c * size_in_channel; + + int bias_val = flag_bias ? bias[c] : 0; + + const signed char* wei_ptr = weights + c * w_stride; +#ifdef __aarch64__ + int vbias[4] = {bias_val, bias_val, bias_val, bias_val}; + int8x8_t wr00 = vdup_n_s8(wei_ptr[0]); + int8x8_t wr10 = vdup_n_s8(wei_ptr[3]); + int8x8_t wr20 = vdup_n_s8(wei_ptr[6]); + + int8x8_t wr01 = vdup_n_s8(wei_ptr[1]); + int8x8_t wr11 = vdup_n_s8(wei_ptr[4]); + int8x8_t wr21 = vdup_n_s8(wei_ptr[7]); + + int8x8_t wr02 = vdup_n_s8(wei_ptr[2]); + int8x8_t wr12 = vdup_n_s8(wei_ptr[5]); + int8x8_t wr22 = vdup_n_s8(wei_ptr[8]); +#endif + int* doutr0 = nullptr; + int* doutr1 = nullptr; + + const signed char* dr0 = din_ch_ptr; + const signed char* dr1 = dr0 + w_in; + const signed char* dr2 = dr1 + w_in; + const signed char* dr3 = dr2 + w_in; + + const signed char* din_ptr0 = nullptr; + const signed char* din_ptr1 = nullptr; + const signed char* din_ptr2 = nullptr; + const signed char* din_ptr3 = nullptr; + + for (int i = 0; i < h_in; i += 2) { + //! process top pad pad_h = 1 + din_ptr0 = dr0; + din_ptr1 = dr1; + din_ptr2 = dr2; + din_ptr3 = dr3; + + doutr0 = dout_ptr; + doutr1 = doutr0 + w_out; + unsigned int* rst_mask = rmask; + + int out_buf1[8]; + int out_buf2[8]; + int trash_buf[8] = {0, 0, 0, 0, 0, 0, 0, 0}; + + if (i == 0) { + din_ptr0 = zero_ptr; + din_ptr1 = dr0; + din_ptr2 = dr1; + din_ptr3 = dr2; + dr0 = dr1; + dr1 = dr2; + dr2 = dr3; + dr3 = dr2 + w_in; + } else { + dr0 = dr2; + dr1 = dr3; + dr2 = dr1 + w_in; + dr3 = dr2 + w_in; + } + //! process bottom pad + if (i + 3 > h_in) { + switch (i + 3 - h_in) { + case 3: + din_ptr1 = zero_ptr; + case 2: + din_ptr2 = zero_ptr; + case 1: + din_ptr3 = zero_ptr; + default: + break; + } + } + //! 
process bottom remain + if (i + 2 > h_out) { + doutr1 = trash_buf; + } +#ifdef __aarch64__ + asm volatile( + "PRFM PLDL1KEEP, [%[din_ptr0]] \n" + "PRFM PLDL1KEEP, [%[din_ptr1]] \n" + "PRFM PLDL1KEEP, [%[din_ptr2]] \n" + "PRFM PLDL1KEEP, [%[din_ptr3]] \n" + "movi v21.4s, #0x0\n" /* out0 = 0 */ + // left + "ld1 {v4.8b}, [%[vmask]] \n" + "ld1 {v0.8b}, [%[din_ptr0]], #8 \n" /* load + a00-a015 + to + q0*/ + "ld1 {v1.8b}, [%[din_ptr1]], #8 \n" /* load + a00-a015 + to + q0*/ + "ld1 {v2.8b}, [%[din_ptr2]], #8 \n" /* load + a00-a015 + to + q0*/ + "ld1 {v3.8b}, [%[din_ptr3]], #8 \n" /* load + a00-a015 + to + q0*/ + + "bif v0.8b, v21.8b, v4.8b \n" + "bif v1.8b, v21.8b, v4.8b \n" + "bif v2.8b, v21.8b, v4.8b \n" + "bif v3.8b, v21.8b, v4.8b \n" + + "ext v6.8b, v21.8b, v0.8B, #7 \n" /* vext_s8(vzero, vinr0, 7); + 00123456 */ + "ext v7.8b, v0.8b, v21.8B, #1 \n" /* vext_s8(vinr0, vinr0_1, + 1); 12345678 */ + + "ld1 {v10.4s}, [%[vbias]] \n" + "ld1 {v11.4s}, [%[vbias]] \n" + + // r0 + "smull v18.8h, %[v1].8b, v0.8b \n" /* outr00 = 01234567 * w01 + */ + + "ext v8.8b, v21.8b, v1.8B, #7 \n" /* vext_s8(vzero, vinr0, 7); + 00123456 */ + "ext v9.8b, v1.8b, v21.8B, #1 \n" /* vext_s8(vinr0, vinr0_1, + 1); 12345678 */ + + "smlal v18.8h, %[v0].8b, v6.8b \n" /* outr00 = 01234567 * w00 + */ + + "ld1 {v12.4s}, [%[vbias]] \n" + "ld1 {v13.4s}, [%[vbias]] \n" + + "saddw v10.4s, v10.4s, v18.4h \n" /* v10 += outr00.low*/ + "saddw2 v11.4s, v11.4s, v18.8h \n" /* v11 += outr00.high*/ + + "smull v18.8h, %[v2].8b, v7.8b \n" /* outr00 = 01234567 * w00 + */ + + "ext v6.8b, v21.8b, v2.8B, #7 \n" /* vext_s8(vzero, vinr0, 7); + 00123456 */ + "ext v7.8b, v2.8b, v21.8B, #1 \n" /* vext_s8(vinr0, vinr0_1, + 1); 12345678 */ + + // r1 + "smull v19.8h, %[v1].8b, v1.8b \n" /* outr00 = 01234567 * w00 + */ + "smlal v18.8h, %[v4].8b, v1.8b \n" /* outr00 = 01234567 * w00 + */ + + // "ld1 {v14.4s}, [%[rmask]], #16 \n" + // "ld1 {v15.4s}, [%[rmask]] \n" + + "smlal v19.8h, %[v0].8b, v8.8b \n" /* outr00 = 01234567 * w00 + */ + + "saddw v10.4s, v10.4s, v18.4h \n" /* v10 += outr00.low*/ + "saddw2 v11.4s, v11.4s, v18.8h \n" /* v11 += outr00.high*/ + + "smull v18.8h, %[v3].8b, v8.8b \n" /* outr00 = 01234567 * w00 + */ + + // "ld1 {v16.4s}, [%[ptr_out0]], #16 \n" + // "ld1 {v17.4s}, [%[ptr_out1]], #16 \n" + "saddw v12.4s, v12.4s, v19.4h \n" /* v10 += outr00.low*/ + "saddw2 v13.4s, v13.4s, v19.8h \n" /* v11 += outr00.high*/ + + "smull v19.8h, %[v2].8b, v9.8b \n" /* outr00 = 01234567 * w00 + */ + "smlal v18.8h, %[v5].8b, v9.8b \n" /* outr00 = 01234567 * w00 + */ + + "ext v8.8b, v21.8b, v3.8B, #7 \n" /* vext_s8(vzero, vinr0, 7); + 00123456 */ + "ext v9.8b, v3.8b, v21.8B, #1 \n" // vext_s8(vinr0, vinr0_1, + // 1); 12345678 + + // "ld1 {v0.4s}, [%[ptr_out0]] \n" + // "ld1 {v1.4s}, [%[ptr_out1]] \n" + + // r2 + "smlal v19.8h, %[v4].8b, v2.8b \n" /* outr00 = 01234567 * w00 + */ + + "saddw v10.4s, v10.4s, v18.4h \n" /* v10 += outr00.low*/ + "saddw2 v11.4s, v11.4s, v18.8h \n" /* v11 += outr00.high*/ + + "smull v18.8h, %[v7].8b, v2.8b \n" /* outr00 = 01234567 * w00 + */ + + // "sub %[ptr_out0], %[ptr_out0], #16 \n" + // "sub %[ptr_out1], %[ptr_out1], #16 \n" + + "saddw v12.4s, v12.4s, v19.4h \n" /* v10 += outr00.low*/ + "saddw2 v13.4s, v13.4s, v19.8h \n" /* v11 += outr00.high*/ + + "smull v19.8h, %[v3].8b, v6.8b \n" /* outr00 = 01234567 * w00 + */ + "smlal v18.8h, %[v6].8b, v6.8b \n" /* outr00 = 01234567 * w00 + */ + + "smlal v19.8h, %[v5].8b, v7.8b \n" /* outr00 = 01234567 * w00 + */ + + "saddw v10.4s, v10.4s, v18.4h \n" /* v10 += outr00.low*/ + "saddw2 v11.4s, 
v11.4s, v18.8h \n" /* v11 += outr00.high*/ + + "smull v18.8h, %[v8].8b, v7.8b \n" /* outr00 = 01234567 * w00 + */ + + "saddw v12.4s, v12.4s, v19.4h \n" /* v10 += outr00.low*/ + "saddw2 v13.4s, v13.4s, v19.8h \n" /* v11 += outr00.high*/ + + // r3 + "smull v19.8h, %[v7].8b, v3.8b \n" /* outr00 = 01234567 * w00 + */ + + "saddw v10.4s, v10.4s, v18.4h \n" /* v10 += outr00.low*/ + "saddw2 v11.4s, v11.4s, v18.8h \n" /* v11 += outr00.high*/ + + "smlal v19.8h, %[v6].8b, v8.8b \n" /* outr00 = 01234567 * w00 + */ + + // "bif v10.16b, v16.16b, v14.16b \n" + // "bif v11.16b, v0.16b, v15.16b \n" + + "saddw v12.4s, v12.4s, v19.4h \n" /* v10 += outr00.low*/ + "saddw2 v13.4s, v13.4s, v19.8h \n" /* v11 += outr00.high*/ + + "smull v19.8h, %[v8].8b, v9.8b \n" /* outr00 = 01234567 * w00 + */ + + "stp q10, q11, [%[ptr_out0]] \n" /* store q10, q11 -> ptr_out */ + + "saddw v12.4s, v12.4s, v19.4h \n" /* v10 += outr00.low*/ + "saddw2 v13.4s, v13.4s, v19.8h \n" /* v11 += outr00.high*/ + + // "bif v12.16b, v17.16b, v14.16b \n" + // "bif v13.16b, v1.16b, v15.16b \n" + + "stp q12, q13, [%[ptr_out1]] \n" /* store q10, q11 -> ptr_out */ + + : [din_ptr0] "+r"(din_ptr0), + [din_ptr1] "+r"(din_ptr1), + [din_ptr2] "+r"(din_ptr2), + [din_ptr3] "+r"(din_ptr3), + [rmask] "+r"(rst_mask) + : [v0] "w"(wr00), + [v1] "w"(wr01), + [v2] "w"(wr02), + [v3] "w"(wr10), + [vbias] "r"(vbias), + [v4] "w"(wr11), + [v5] "w"(wr12), + [v6] "w"(wr20), + [v7] "w"(wr21), + [v8] "w"(wr22), + [vmask] "r"(vmask), + [ptr_out0] "r"(out_buf1), + [ptr_out1] "r"(out_buf2) + : "cc", + "memory", + "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15", + "v16", + "v17", + "v18", + "v19", + "v20", + "v21", + "v22"); +#else + // store weights + asm volatile("vld1.8 {d0-d1}, [%[wei_ptr]] \n" + : + : [wei_ptr] "r"(wei_ptr) + : "memory"); + asm volatile( + // left + "pld [%[din_ptr0]] @ preload data\n" + "pld [%[din_ptr1]] @ preload data\n" + "pld [%[din_ptr2]] @ preload data\n" + "pld [%[din_ptr3]] @ preload data\n" + "vld1.8 {d28}, [%[mask]] @ load din00= 0 1 2 3 4 5 6 7 8 9\n" + "vld1.8 {d12}, [%[din_ptr0]] @ load din00= 0 1 2 3 4 5 6 7 8 9\n" + "vld1.8 {d13}, [%[din_ptr1]] @ load din00= 0 1 2 3 4 5 6 7 8 9\n" + "vdup.s8 d2, d0[0] @ d2 = w00, w00, w00, w00\n" + "vdup.s8 d3, d0[1] @ d3 = w01, w01, w01, w01\n" + "vdup.s8 d4, d0[2] @ d4 = w02, w02, w02, w02\n" + + "vmov.u32 d11, #0 @ zero\n" + // out0 + "vdup.32 q8, %[bias] @ and \n" // q8 = + // vbias + "vdup.32 q9, %[bias] @ and \n" // q9 = + // vbias + + "vbif.8 d12, d11, d28 @ bit select, deal with right pad\n" + "vbif.8 d13, d11, d28 @ bit select, deal with right pad\n" + "vld1.8 {d14}, [%[din_ptr2]] @ load din00= 0 1 2 3 4 5 6 7 8 9\n" + "vld1.8 {d15}, [%[din_ptr3]] @ load din00= 0 1 2 3 4 5 6 7 8 9\n" + // out1 + "vdup.32 q10, %[bias] @ and \n" // q8 = + // vbias + "vdup.32 q11, %[bias] @ and \n" // q9 = + // vbias + + // r0 + "vmull.s8 q12, d12, d3 @ out0 = din0 * w01 \n" // q12 = d12 * w01 + "vext.8 d30, d11, d12, #7 @ ext \n" // d10 = 00123456 + "vext.8 d31, d12, d11, #1 @ ext \n" // d11 = 12345678 + + "vdup.s8 d5, d0[3] @ d5 = w10, w10, w00, w00\n" + "vdup.s8 d6, d0[4] @ d6 = w11, w11, w01, w01\n" + + "vmlal.s8 q12, d30, d2 @ out0 += din0 * w00 \n" // q12 += d10 * w00 + + "vdup.s8 d7, d0[5] @ d7 = w12, w12\n" + "vbif.8 d14, d11, d28 @ bit select, deal with right pad\n" + "vbif.8 d15, d11, d28 @ bit select, deal with right pad\n" + + "vaddw.s16 q8, q8, d24 @addw \n" // out0 += + // vget_low_s16(out00) + "vaddw.s16 q9, q9, d25 
@addw \n" // out0_1 += + // vget_high_s16(out00) + + "vmull.s8 q12, d31, d4 @ out0 += din0 * w02 \n" // q12 += d11 * w02 + + // r1 + "vext.8 d30, d11, d13, #7 @ ext \n" // d10 = 00123456 + "vext.8 d31, d13, d11, #1 @ ext \n" // d11 = 12345678 + "vmull.s8 q13, d13, d3 @ out1 = din1 * w01 \n" // q13 = d12 * w01 + + "vmlal.s8 q12, d13, d6 @ out0 = din1 * w11 \n" // q12 = d12 * w11 + + "vdup.s8 d8, d0[6] @ d8 = w20, w00, w00, w00\n" + "vdup.s8 d9, d0[7] @ d9 = w21, w01, w01, w01\n" + "vaddw.s16 q8, q8, d24 @addw \n" // out0 += + // vget_low_s16(out00) + "vaddw.s16 q9, q9, d25 @addw \n" // out0_1 += + // vget_high_s16(out00) + + "vmlal.s8 q13, d30, d2 @ out1 += din1 * w00 \n" // q12 += d10 * w00 + "vmull.s8 q12, d30, d5 @ out0 += din1 * w10 \n" // q12 += d10 * w00 + + "vdup.s8 d10, d1[0] @ d10 = w22, w02, w02, w02\n" + // "vld1.32 {d28-d29}, [%[dout_ptr1]]! @ load din00= 0 1 2 3 4 5 + // 6 7 8 9\n" "vld1.32 {d12-d13}, [%[dout_ptr1]] @ load din00= 0 + // 1 2 3 4 5 6 7 8 9\n" + + "vaddw.s16 q10, q10, d26 @addw \n" // out1 += + // vget_low_s16(out10) + "vaddw.s16 q11, q11, d27 @addw \n" // out1_1 += + // vget_high_s16(out10) + + "vmull.s8 q13, d31, d4 @ out1 += din1 * w02 \n" // q12 += d10 * w00 + "vmlal.s8 q12, d31, d7 @ out0 += din1 * w12 \n" // q12 += d10 * w00 + + // r2 + "vext.8 d30, d11, d14, #7 @ ext \n" // d10 = 00123456 + "vext.8 d31, d14, d11, #1 @ ext \n" // d11 = 12345678 + + "vaddw.s16 q8, q8, d24 @addw \n" // out0 += + // vget_low_s16(out00) + "vaddw.s16 q9, q9, d25 @addw \n" // out0_1 += + // vget_high_s16(out00) + + "vmlal.s8 q13, d14, d6 @ out1 = din2 * w11 \n" // q13 = d12 * w01 + "vmull.s8 q12, d14, d9 @ out1 = din2 * w21 \n" // q13 = d12 * w01 + + // "sub %[dout_ptr1], #16 @ sub \n" + "vaddw.s16 q10, q10, d26 @addw \n" // out1 += + // vget_low_s16(out10) + "vaddw.s16 q11, q11, d27 @addw \n" // out1_1 += + // vget_high_s16(out10) + + "vmull.s8 q13, d30, d5 @ out1 += din2 * w10 \n" // q12 += d10 * w00 + "vmlal.s8 q12, d30, d8 @ out0 += din2 * w20 \n" // q12 += d10 * w00 + + // "vld1.32 {d2-d3}, [%[rs_mask]]! @ load din00= 0 1 2 3 4 5 6 7 + // 8 9\n" "vld1.32 {d4-d5}, [%[rs_mask]] @ load din00= 0 1 2 3 4 + // 5 6 7 8 9\n" + "vaddw.s16 q8, q8, d24 @addw \n" // out0 += + // vget_low_s16(out00) + "vaddw.s16 q9, q9, d25 @addw \n" // out0_1 += + // vget_high_s16(out00) + + "vmlal.s8 q13, d31, d7 @ out1 += din2 * w12 \n" // q12 += d10 * w00 + "vmull.s8 q12, d31, d10 @ out0 += din2 * w22 \n" // q12 += d10 * w00 + + // r3 + "vext.8 d30, d11, d15, #7 @ ext \n" // d10 = 00123456 + "vext.8 d31, d15, d11, #1 @ ext \n" // d11 = 12345678 + "vaddw.s16 q10, q10, d26 @addw \n" // out1 += + // vget_low_s16(out10) + "vaddw.s16 q11, q11, d27 @addw \n" // out1_1 += + // vget_high_s16(out10) + "vaddw.s16 q8, q8, d24 @addw \n" // out0 += + // vget_low_s16(out00) + "vaddw.s16 q9, q9, d25 @addw \n" // out0_1 += + // vget_high_s16(out00) + + "vmull.s8 q13, d15, d9 @ out1 = din3 * w21 \n" // q13 = d12 * w01 + + // "vld1.32 {d6-d7}, [%[dout_ptr2]]! 
@ load din00= 0 1 2 3 4 5 6 + // 7 8 9\n" "vld1.32 {d14-d15}, [%[dout_ptr2]] @ load din00= 0 1 + // 2 3 4 5 6 7 8 9\n" + + "vmlal.s8 q13, d30, d8 @ out1 += din3 * w20 \n" // q13 += d10 * w00 + + // "vbif q8, q14, q1 @ bit select, deal with right + // pad\n" "vbif q9, q6, q2 @ bit select, deal + // with right pad\n" + "vaddw.s16 q10, q10, d26 @addw \n" // out1 += + // vget_low_s16(out10) + "vaddw.s16 q11, q11, d27 @addw \n" // out1_1 += + // vget_high_s16(out10) + + "vmull.s8 q13, d31, d10 @ out1 += din3 * w22 \n" // q12 += d10 * w00 + + // "sub %[dout_ptr2], #16 @ sub \n" + + "vst1.32 {d16-d19}, [%[dout_ptr1]] @ store\n" + // "vst1.32 {d18-d19}, [%[dout_ptr1]]! @ store\n" + + "vaddw.s16 q10, q10, d26 @addw \n" // out1 += + // vget_low_s16(out10) + "vaddw.s16 q11, q11, d27 @addw \n" // out1_1 += + // vget_high_s16(out10) + + // "vbif q10, q3, q1 @ bit select, deal with right + // pad\n" "vbif q11, q7, q2 @ bit select, deal + // with right pad\n" + + "vst1.32 {d20-d23}, [%[dout_ptr2]] @ store\n" + // "vst1.32 {d22-d23}, [%[dout_ptr2]]! @ store\n" + : [din_ptr0] "+r"(din_ptr0), + [din_ptr1] "+r"(din_ptr1), + [din_ptr2] "+r"(din_ptr2), + [din_ptr3] "+r"(din_ptr3), + [bias] "+r"(bias_val), + [rs_mask] "+r"(rst_mask) + : [mask] "r"(vmask), + [dout_ptr1] "r"(out_buf1), + [dout_ptr2] "r"(out_buf2) + : "cc", + "memory", + "q0", + "q1", + "q2", + "q3", + "q4", + "q5", + "q6", + "q7", + "q8", + "q9", + "q10", + "q11", + "q12", + "q13", + "q14", + "q15"); +#endif + for (int w = 0; w < w_out; ++w) { + *doutr0++ = out_buf1[w]; + *doutr1++ = out_buf2[w]; + } + dout_ptr += 2 * w_out; + } + } + } +} + +// 4line w_in > 16 +void conv_depthwise_3x3s2p1_bias_int8(int* dout, + const signed char* din, + const signed char* weights, + const int* bias, + bool flag_bias, + const int num, + const int ch_in, + const int h_in, + const int w_in, + const int h_out, + const int w_out, + ARMContext* ctx) { + // printf("3x3s2 mult height \n"); + //! pad is done implicit + const char zero[8] = {0, 0, 0, 0, 0, 0, 0, 0}; + //! 
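Throughout these int8 kernels the multiply chain is capped at two 8-bit products per int16 lane: smull/vmull.s8 starts a fresh product vector, at most one smlal/vmlal.s8 is chained onto it, and saddw/vaddw.s16 then folds the pair into the int32 accumulators. Two products are safe for values in [-127, 127] (2 * 127 * 127 = 32258 < 32767), while a third could overflow int16; the int7 variants earlier in this patch chain three, presumably because 7-bit activations leave the extra headroom. A scalar model of one lane (mac3_widening is a hypothetical name):

#include <cstdint>

static int32_t mac3_widening(int32_t acc, const int8_t x[3], const int8_t w[3]) {
  int16_t pair = (int16_t)(x[0] * w[0]);  // smull: first int8 product
  pair = (int16_t)(pair + x[1] * w[1]);   // smlal: second product, still int16
  acc += pair;                            // saddw: fold the pair into int32
  acc += (int16_t)(x[2] * w[2]);          // third product widened on its own
  return acc;
}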
for 4x6 convolution window
+  const unsigned char right_pad_idx[16] = {
+      0, 2, 4, 6, 8, 10, 12, 14, 1, 3, 5, 7, 9, 11, 13, 15};
+  const unsigned int right_pad_rst[8] = {0, 1, 2, 3, 4, 5, 6, 7};
+
+  // printf("conv3x3_dw start \n");
+  signed char* zero_ptr = ctx->workspace_data<signed char>();
+  memset(zero_ptr, 0, w_in * sizeof(signed char));
+  int* write_ptr =
+      reinterpret_cast<int*>(ctx->workspace_data<signed char>()) + w_out;
+  int size_in_channel = w_in * h_in;
+  int size_out_channel = w_out * h_out;
+  int w_stride = 9;
+
+  int tile_w = (w_in + 15) >> 4;
+  int cnt_col = tile_w - 2;
+
+  unsigned int size_pad_right = (unsigned int)(w_in - 15 - (cnt_col << 4));
+  if (size_pad_right == 17) {
+    size_pad_right = 0;
+    cnt_col++;
+  }
+
+  uint8x8_t vmask_rp1 =
+      vcgt_u8(vdup_n_u8(size_pad_right), vld1_u8(right_pad_idx));
+  uint8x8_t vmask_rp2 =
+      vcgt_u8(vdup_n_u8(size_pad_right), vld1_u8(right_pad_idx + 8));
+  unsigned int rst_remain = (unsigned int)(w_out - ((cnt_col + 1) << 3));
+  uint32x4_t vmask_result1 =
+      vcgtq_u32(vdupq_n_u32(rst_remain), vld1q_u32(right_pad_rst));
+  uint32x4_t vmask_result2 =
+      vcgtq_u32(vdupq_n_u32(rst_remain), vld1q_u32(right_pad_rst + 4));
+
+  uint8x16_t vmask_rp =
+      vcgtq_u8(vdupq_n_u8(size_pad_right), vld1q_u8(right_pad_idx));
+  unsigned char vmask[16];
+  vst1q_u8(vmask, vmask_rp);
+
+  unsigned int rmask[8];
+  vst1q_u32(rmask, vmask_result1);
+  vst1q_u32(rmask + 4, vmask_result2);
+
+  int8x8_t vzero = vdup_n_s8(0);
+  // printf("cnt_col: %d, rst_remain: %d, size_pad_right: %d\n", cnt_col,
+  // rst_remain, size_pad_right);
+  for (int n = 0; n < num; ++n) {
+    const signed char* din_batch = din + n * ch_in * size_in_channel;
+    int* dout_batch = dout + n * ch_in * size_out_channel;
+#pragma omp parallel for
+    for (int c = 0; c < ch_in; c++) {
+      int* dout_ptr = dout_batch + c * size_out_channel;
+
+      const signed char* din_ch_ptr = din_batch + c * size_in_channel;
+
+      int bias_val = flag_bias ? bias[c] : 0;
+
+      const signed char* wei_ptr = weights + c * w_stride;
+#ifdef __aarch64__
+      int vbias[4] = {bias_val, bias_val, bias_val, bias_val};
+      int8x8_t wr00 = vdup_n_s8(wei_ptr[0]);
+      int8x8_t wr10 = vdup_n_s8(wei_ptr[3]);
+      int8x8_t wr20 = vdup_n_s8(wei_ptr[6]);
+
+      int8x8_t wr01 = vdup_n_s8(wei_ptr[1]);
+      int8x8_t wr11 = vdup_n_s8(wei_ptr[4]);
+      int8x8_t wr21 = vdup_n_s8(wei_ptr[7]);
+
+      int8x8_t wr02 = vdup_n_s8(wei_ptr[2]);
+      int8x8_t wr12 = vdup_n_s8(wei_ptr[5]);
+      int8x8_t wr22 = vdup_n_s8(wei_ptr[8]);
+#endif
+
+      int* doutr0 = nullptr;
+
+      const signed char* dr0 = din_ch_ptr;
+      const signed char* dr1 = dr0 + w_in;
+      const signed char* dr2 = dr1 + w_in;
+
+      const signed char* din_ptr0 = nullptr;
+      const signed char* din_ptr1 = nullptr;
+      const signed char* din_ptr2 = nullptr;
+
+      for (int i = 0; i < h_in; i += 2) {
+        //! process top pad pad_h = 1
+        din_ptr0 = dr0;
+        din_ptr1 = dr1;
+        din_ptr2 = dr2;
+
+        doutr0 = dout_ptr;
+        if (i == 0) {
+          din_ptr0 = zero_ptr;
+          din_ptr1 = dr0;
+          din_ptr2 = dr1;
+          dr0 = dr1;
+          dr1 = dr2;
+          dr2 = dr1 + w_in;
+        } else {
+          dr0 = dr2;
+          dr1 = dr0 + w_in;
+          dr2 = dr1 + w_in;
+        }
+        //!
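+        // NOTE: the switch below pads the bottom edge. The case fall-through
+        // is intentional: when the window hangs two rows past the input
+        // (i + 2 - h_in == 2), both din_ptr1 and din_ptr2 are redirected to
+        // zero_ptr; one row past, only din_ptr2 is. zero_ptr is the zeroed
+        // workspace row, so out-of-range rows contribute nothing.
+        //!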
process bottom pad + if (i + 2 > h_in) { + switch (i + 2 - h_in) { + case 2: + din_ptr1 = zero_ptr; + case 1: + din_ptr2 = zero_ptr; + default: + break; + } + } +#ifdef __aarch64__ + int cnt = cnt_col; + unsigned char* val_mask = vmask; + asm volatile( + "PRFM PLDL1KEEP, [%[din_ptr0]] \n" + "PRFM PLDL1KEEP, [%[din_ptr1]] \n" + "PRFM PLDL1KEEP, [%[din_ptr2]] \n" + "movi v10.4s, #0x0\n" + // left + "ld2 {v0.8b - v1.8b}, [%[din_ptr0]] \n" /*load a00-a015 + to q0*/ + "ld2 {v2.8b - v3.8b}, [%[din_ptr1]] \n" /* load a00-a015 + to q0*/ + "ld2 {v4.8b - v5.8b}, [%[din_ptr2]] \n" /*load a00-a015 + to q0*/ + + "ld1 {v12.4s}, [%[bias_val]] \n" /* dup v10, bias*/ + "ld1 {v13.4s}, [%[bias_val]] \n" /* dup v10, bias */ + + "ext v6.8b, v10.8b, v1.8B, #7 \n" /* vext_s8(vzero, vinr0, 7); + 013579 */ + "ext v7.8b, v10.8b, v3.8B, #7 \n" /* vext_s8(vzero, vinr0, 7); + 013579 */ + "ext v8.8b, v10.8b, v5.8B, #7 \n" /* vext_s8(vzero, vinr0, 7); + 013579 */ + + // r0 + "smull v14.8h, %[v1].8b, v0.8b \n" /* outr00 = 02468 * w01 */ + "smull v15.8h, %[v2].8b, v1.8b\n" /* outr00 += 13579 * w02 */ + "smull v16.8h, %[v0].8b, v6.8b\n" /* outr00 += 013579 * w00 */ + + "add %[din_ptr0], %[din_ptr0], #15 \n" + "add %[din_ptr1], %[din_ptr1], #15 \n" + "add %[din_ptr2], %[din_ptr2], #15 \n" + + // r1 + "smlal v14.8h, %[v4].8b, v2.8b \n" /* outr00 = 02468 * w01 */ + "smlal v15.8h, %[v5].8b, v3.8b\n" /* outr00 += 13579 * w02 */ + "smlal v16.8h, %[v3].8b, v7.8b\n" /* outr00 += 013579 * w00 */ + + "saddw v12.4s, v12.4s, v14.4h \n" /* v10 += outr00.low*/ + "saddw2 v13.4s, v13.4s, v14.8h \n" /* v11 += outr00.high*/ + + "saddw v12.4s, v12.4s, v15.4h \n" /* v10 += outr00.low*/ + "saddw2 v13.4s, v13.4s, v15.8h \n" /* v11 += outr00.high*/ + + "saddw v12.4s, v12.4s, v16.4h \n" /* v10 += outr00.low*/ + "saddw2 v13.4s, v13.4s, v16.8h \n" /* v11 += outr00.high*/ + + // r2 + "smull v14.8h, %[v7].8b, v4.8b \n" /* outr00 = 02468 * w01 */ + "smull v15.8h, %[v8].8b, v5.8b\n" /* outr00 += 13579 * w02 */ + "smull v16.8h, %[v6].8b, v8.8b\n" /* outr00 += 013579 * w00 */ + + "ld2 {v0.8b - v1.8b}, [%[din_ptr0]], #16 \n" /*load + a00-a015 + to q0*/ + "ld2 {v2.8b - v3.8b}, [%[din_ptr1]], #16 \n" /* load + a00-a015 + to q0*/ + "ld2 {v4.8b - v5.8b}, [%[din_ptr2]], #16 \n" /*load + a00-a015 + to q0*/ + + "saddw v12.4s, v12.4s, v14.4h \n" /* v10 += outr00.low*/ + "saddw2 v13.4s, v13.4s, v14.8h \n" /* v11 += outr00.high*/ + + "saddw v12.4s, v12.4s, v15.4h \n" /* v10 += outr00.low*/ + "saddw2 v13.4s, v13.4s, v15.8h \n" /* v11 += outr00.high*/ + + "saddw v12.4s, v12.4s, v16.4h \n" /* v10 += outr00.low*/ + "saddw2 v13.4s, v13.4s, v16.8h \n" /* v11 += outr00.high*/ + + "stp q12, q13, [%[ptr_out0]], #32 \n" /* store q10, q11 -> + ptr_out */ + + "ld1 {v12.4s}, [%[bias_val]] \n" /* dup v10, bias */ + "ld1 {v13.4s}, [%[bias_val]] \n" /* dup v10, bias */ + + "cmp %[cnt], #1 \n" + "blt 3f \n" + // mid + "1: \n" + "ld1 {v6.8b}, [%[din_ptr0]] \n" /*load a00-a015 to q0*/ + "ld1 {v7.8b}, [%[din_ptr1]] \n" /*load a00-a015 to q0*/ + "ld1 {v8.8b}, [%[din_ptr2]] \n" /*load a00-a015 to q0*/ + + "ext v9.8b, v0.8b, v6.8B, #1 \n" /* vext_s8(vzero, vinr0, 7); + 246810 */ + "ext v11.8b, v2.8b, v7.8B, #1 \n" /* vext_s8(vzero, vinr0, 7); + 246810 */ + "ext v14.8b, v4.8b, v8.8B, #1 \n" /* vext_s8(vzero, vinr0, 7); + 246810 */ + + // r0 + "smull v6.8h, %[v0].8b, v0.8b \n" /* outr00 = 02468 * w00 */ + "smull v7.8h, %[v1].8b, v1.8b\n" /* outr00 += 13579 * w01 */ + "smull v8.8h, %[v2].8b, v9.8b\n" /* outr00 += 246810 * w02 */ + + // r1 + "smlal v6.8h, %[v3].8b, v2.8b \n" /* outr00 = 
02468 * w00 */ + "smlal v7.8h, %[v4].8b, v3.8b\n" /* outr00 += 13579 * w01 */ + "smlal v8.8h, %[v5].8b, v11.8b\n" /* outr00 += 246810 * w02 */ + + "saddw v12.4s, v12.4s, v6.4h \n" /* v10 += outr00.low*/ + "saddw2 v13.4s, v13.4s, v6.8h \n" /* v11 += outr00.high*/ + + "saddw v12.4s, v12.4s, v7.4h \n" /* v10 += outr00.low*/ + "saddw2 v13.4s, v13.4s, v7.8h \n" /* v11 += outr00.high*/ + + "saddw v12.4s, v12.4s, v8.4h \n" /* v10 += outr00.low*/ + "saddw2 v13.4s, v13.4s, v8.8h \n" /* v11 += outr00.high*/ + + // r2 + "smull v6.8h, %[v6].8b, v4.8b \n" /* outr00 = 02468 * w00 */ + "smull v7.8h, %[v7].8b, v5.8b\n" /* outr00 += 13579 * w01 */ + "smull v8.8h, %[v8].8b, v14.8b\n" /* outr00 += 246810 * w02 */ + + "ld2 {v0.8b - v1.8b}, [%[din_ptr0]], #16 \n" /*load + a00-a015 + to q0*/ + "ld2 {v2.8b - v3.8b}, [%[din_ptr1]], #16 \n" /* load + a00-a015 + to q0*/ + "ld2 {v4.8b - v5.8b}, [%[din_ptr2]], #16 \n" /*load + a00-a015 + to q0*/ + + "saddw v12.4s, v12.4s, v6.4h \n" /* v10 += outr00.low*/ + "saddw2 v13.4s, v13.4s, v6.8h \n" /* v11 += outr00.high*/ + + "saddw v12.4s, v12.4s, v7.4h \n" /* v10 += outr00.low*/ + "saddw2 v13.4s, v13.4s, v7.8h \n" /* v11 += outr00.high*/ + + "saddw v12.4s, v12.4s, v8.4h \n" /* v10 += outr00.low*/ + "saddw2 v13.4s, v13.4s, v8.8h \n" /* v11 += outr00.high*/ + + "subs %[cnt], %[cnt], #1 \n" + + "stp q12, q13, [%[ptr_out0]], #32 \n" /* store q10, q11 -> + ptr_out */ + + "ld1 {v12.4s}, [%[bias_val]] \n" /* dup v10, bias */ + "ld1 {v13.4s}, [%[bias_val]] \n" /* dup v10, bias */ + "bne 1b \n" + // right + "3: \n" + "ld1 {v14.8b}, [%[vmask]], #8 \n" + "ld1 {v15.8b}, [%[vmask]] \n" + + "bif v0.8b, v10.8b, v14.8b \n" + "bif v1.8b, v10.8b, v15.8b \n" + "bif v2.8b, v10.8b, v14.8b \n" + "bif v3.8b, v10.8b, v15.8b \n" + "bif v4.8b, v10.8b, v14.8b \n" + "bif v5.8b, v10.8b, v15.8b \n" + + "ext v6.8b, v0.8b, v10.8B, #1 \n" /* vext_s8(vzero, vinr0, 7); + 2468.. */ + "ext v7.8b, v2.8b, v10.8B, #1 \n" /* vext_s8(vzero, vinr0, 7); + 2468..*/ + "ext v8.8b, v4.8b, v10.8B, #1 \n" /* vext_s8(vzero, vinr0, 7); + 2468.. 
*/ + + // r0 + "smull v14.8h, %[v0].8b, v0.8b \n" /* outr00 = 02468 * w00 */ + "smull v15.8h, %[v1].8b, v1.8b\n" /* outr00 += 13579 * w01 */ + "smull v16.8h, %[v2].8b, v6.8b\n" /* outr00 += 246810 * w02 */ + + // r1 + "smlal v14.8h, %[v3].8b, v2.8b \n" /* outr00 = 02468 * w00 */ + "smlal v15.8h, %[v4].8b, v3.8b\n" /* outr00 += 13579 * w01 */ + "smlal v16.8h, %[v5].8b, v7.8b\n" /* outr00 += 246810 * w02 */ + + "saddw v12.4s, v12.4s, v14.4h \n" /* v10 += outr00.low*/ + "saddw2 v13.4s, v13.4s, v14.8h \n" /* v11 += outr00.high*/ + + "saddw v12.4s, v12.4s, v15.4h \n" /* v10 += outr00.low*/ + "saddw2 v13.4s, v13.4s, v15.8h \n" /* v11 += outr00.high*/ + + "saddw v12.4s, v12.4s, v16.4h \n" /* v10 += outr00.low*/ + "saddw2 v13.4s, v13.4s, v16.8h \n" /* v11 += outr00.high*/ + + // r2 + "smull v14.8h, %[v6].8b, v4.8b \n" /* outr00 = 02468 * w00 */ + "smull v15.8h, %[v7].8b, v5.8b\n" /* outr00 += 13579 * w01 */ + "smull v16.8h, %[v8].8b, v8.8b\n" /* outr00 += 246810 * w02 */ + + "ldp q0, q1, [%[ptr_out0]] \n" /* dup v10, bias */ + "ldp q9, q11, [%[rst_mask]] \n" /* dup v10, bias */ + + "saddw v12.4s, v12.4s, v14.4h \n" /* v10 += outr00.low*/ + "saddw2 v13.4s, v13.4s, v14.8h \n" /* v11 += outr00.high*/ + + "saddw v12.4s, v12.4s, v15.4h \n" /* v10 += outr00.low*/ + "saddw2 v13.4s, v13.4s, v15.8h \n" /* v11 += outr00.high*/ + + "saddw v12.4s, v12.4s, v16.4h \n" /* v10 += outr00.low*/ + "saddw2 v13.4s, v13.4s, v16.8h \n" /* v11 += outr00.high*/ + + "bif v12.16b, v0.16b, v9.16b \n" + "bif v13.16b, v1.16b, v11.16b \n" + + "stp q12, q13, [%[ptr_out0]], #32 \n" /* store q10, q11 -> + ptr_out */ + + : [cnt] "+r"(cnt), + [din_ptr0] "+r"(din_ptr0), + [din_ptr1] "+r"(din_ptr1), + [din_ptr2] "+r"(din_ptr2), + [ptr_out0] "+r"(doutr0), + [vmask] "+r"(val_mask) + : [v0] "w"(wr00), + [v1] "w"(wr01), + [v2] "w"(wr02), + [v3] "w"(wr10), + [bias_val] "r"(vbias), + [v4] "w"(wr11), + [v5] "w"(wr12), + [v6] "w"(wr20), + [v7] "w"(wr21), + [v8] "w"(wr22), + [rst_mask] "r"(rmask) + : "cc", + "memory", + "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15", + "v16"); +#else + unsigned int* rst_mask = rmask; + int cnt = cnt_col; + // prefetch input + // store weights + asm volatile("vld1.8 {d0-d1}, [%[wei_ptr]] \n" + : + : [wei_ptr] "r"(wei_ptr) + : "memory"); + asm volatile( + // left + "pld [%[din_ptr0]] @ preload data\n" + "pld [%[din_ptr1]] @ preload data\n" + "pld [%[din_ptr2]] @ preload data\n" + "vdup.s8 d2, d0[0] @ d2 = w00, w00, w00, w00\n" + "vdup.s8 d3, d0[1] @ d3 = w01, w01, w01, w01\n" + "vdup.s8 d4, d0[2] @ d4 = w02, w02, w02, w02\n" + "vld2.8 {d12-d13}, [%[din_ptr0]] @ load din00= 0 2 4 6 8\n" // d10 = 0 2 4 6 + "vld2.8 {d14-d15}, [%[din_ptr1]] @ load din00= 0 2 4 6 8\n" // d12 = 0 2 4 6 + "vld2.8 {d16-d17}, [%[din_ptr2]] @ load din00= 0 2 4 6 8\n" // d14 = 0 2 4 6 + "vmov.u32 d11, #0 @ zero\n" + + "vdup.s8 d5, d0[3] @ d2 = w00, w00, w00, w00\n" + "vdup.s8 d6, d0[4] @ d3 = w01, w01, w01, w01\n" + "vdup.s8 d7, d0[5] @ d4 = w02, w02, w02, w02\n" + + "vext.8 d18, d11, d13, #7 @ ext \n" // d16 = -1 1 3 5 + "vext.8 d19, d11, d15, #7 @ ext \n" // d17 = -1 1 3 5 + "vext.8 d20, d11, d17, #7 @ ext \n" // d18 = -1 1 3 5 + + // r0 + "vmull.s8 q13, d12, d3 @ out0 = din0 * w01 \n" // q12 = d12 * w01 + "vmull.s8 q14, d13, d4 @ out1 = din0 * w02 \n" // q12 = d12 * w02 + "vmull.s8 q15, d18, d2 @ out2 = din0 * w00 \n" // q12 = d12 * w02 + + "vdup.s8 d8, d0[6] @ d2 = w00, w00, w00, w00\n" + "vdup.s8 d9, d0[7] @ d3 = w01, w01, w01, w01\n" + "vdup.s8 d10, 
d1[0] @ d4 = w02, w02, w02, w02\n" + + // out0 + "vdup.32 q11, %[bias] @ and \n" // q8 = + // vbias + "vdup.32 q12, %[bias] @ and \n" // q9 = + // vbias + + // r1 + "vmlal.s8 q13, d14, d6 @ out0 += din1 * w11 \n" // q12 = d12 * w11 + "vmlal.s8 q14, d15, d7 @ out1 += din1 * w12 \n" // q12 = d12 * w11 + "vmlal.s8 q15, d19, d5 @ out2 += din1 * w10 \n" // q12 = d12 * w11 + + "add %[din_ptr0], #15 @add \n" + + "vaddw.s16 q11, q11, d26 @addw \n" // out1 += + // vget_low_s16(out10) + "vaddw.s16 q12, q12, d27 @addw \n" // out1_1 += + // vget_high_s16(out10) + "add %[din_ptr1], #15 @add \n" + + "vaddw.s16 q11, q11, d28 @addw \n" // out1 += + // vget_low_s16(out10) + "vaddw.s16 q12, q12, d29 @addw \n" // out1_1 += + // vget_high_s16(out10) + "add %[din_ptr2], #15 @add \n" + + "vaddw.s16 q11, q11, d30 @addw \n" // out1 += + // vget_low_s16(out10) + "vaddw.s16 q12, q12, d31 @addw \n" // out1_1 += + // vget_high_s16(out10) + + // r2 + "vmull.s8 q13, d16, d9 @ out0 += din1 * w21 \n" // q12 = d12 * w11 + "vmull.s8 q14, d17, d10 @ out1 += din1 * w22 \n" // q12 = d12 * w11 + "vmull.s8 q15, d20, d8 @ out2 += din1 * w20 \n" // q12 = d12 * w11 + + "pld [%[din_ptr0]] @ preload data\n" + "pld [%[din_ptr1]] @ preload data\n" + "pld [%[din_ptr2]] @ preload data\n" + + "vaddw.s16 q11, q11, d26 @addw \n" // out1 += + // vget_low_s16(out10) + "vaddw.s16 q12, q12, d27 @addw \n" // out1_1 += + // vget_high_s16(out10) + + "vaddw.s16 q11, q11, d28 @addw \n" // out1 += + // vget_low_s16(out10) + "vaddw.s16 q12, q12, d29 @addw \n" // out1_1 += + // vget_high_s16(out10) + + "vaddw.s16 q11, q11, d30 @addw \n" // out1 += + // vget_low_s16(out10) + "vaddw.s16 q12, q12, d31 @addw \n" // out1_1 += + // vget_high_s16(out10) + + "vst1.32 {d22-d23}, [%[dout_ptr1]]! @ store\n" + "cmp %[cnt], #1 \n" + "vst1.32 {d24-d25}, [%[dout_ptr1]]! @ store\n" + "blt 1f \n" + + // mid + "2: \n" + "vld2.8 {d12-d13}, [%[din_ptr0]]! @ load din00= 0 2 4 6 8\n" // d10 = 0 2 4 6 + "vld2.8 {d14-d15}, [%[din_ptr1]]! @ load din00= 0 2 4 6 8\n" // d12 = 0 2 4 6 + "vld2.8 {d16-d17}, [%[din_ptr2]]! 
@ load din00= 0 2 4 6 8\n" // d14 = 0 2 4 6 + + "vld1.8 {d21}, [%[din_ptr0]] @ load din00= 16 17\n" // d10 = 0 2 + // 4 6 + "vld1.8 {d22}, [%[din_ptr1]] @ load din00= 16 17\n" // d12 = 0 2 + // 4 6 + "vld1.8 {d23}, [%[din_ptr2]] @ load din00= 16 17\n" // d14 = 0 2 + // 4 6 + + "vext.8 d18, d12, d21, #1 @ ext din00 = 2 4 6 8\n" // d16 = 2 + // 4 6 8 + "vext.8 d19, d14, d22, #1 @ ext \n" // d17 = 2 4 6 8 + "vext.8 d20, d16, d23, #1 @ ext \n" // d18 = 2 4 6 8 + + // r0 + "vmull.s8 q13, d12, d2 @ out0 = din0 * w00 \n" // q12 = 0 2 4 6 + "vmull.s8 q14, d13, d3 @ out1 = din0 * w01 \n" // q12 = 1 3 5 7 + "vmull.s8 q15, d18, d4 @ out2 = din0 * w02 \n" // q12 = 2 4 6 8 + + // out0 + "vdup.32 q11, %[bias] @ and \n" // q8 = + // vbias + "vdup.32 q12, %[bias] @ and \n" // q9 = + // vbias + + // r1 + "vmlal.s8 q13, d14, d5 @ out0 += din1 * w10 \n" // q12 = 0 2 4 6 + "vmlal.s8 q14, d15, d6 @ out1 += din1 * w11 \n" // q12 = 1 3 5 7 + "vmlal.s8 q15, d19, d7 @ out2 += din1 * w12 \n" // q12 = 2 4 6 8 + + "pld [%[din_ptr0]] @ preload data\n" + "pld [%[din_ptr1]] @ preload data\n" + "pld [%[din_ptr2]] @ preload data\n" + + "vaddw.s16 q11, q11, d26 @addw \n" // out1 += + // vget_low_s16(out10) + "vaddw.s16 q12, q12, d27 @addw \n" // out1_1 += + // vget_high_s16(out10) + + "vaddw.s16 q11, q11, d28 @addw \n" // out1 += + // vget_low_s16(out10) + "vaddw.s16 q12, q12, d29 @addw \n" // out1_1 += + // vget_high_s16(out10) + + "vaddw.s16 q11, q11, d30 @addw \n" // out1 += + // vget_low_s16(out10) + "vaddw.s16 q12, q12, d31 @addw \n" // out1_1 += + // vget_high_s16(out10) + + // r2 + "vmull.s8 q13, d16, d8 @ out0 += din1 * w20 \n" // q12 = 0 2 4 6 + "vmull.s8 q14, d17, d9 @ out1 += din1 * w21 \n" // q12 = 1 3 5 7 + "vmull.s8 q15, d20, d10 @ out2 += din1 * w22 \n" // q12 = 2 4 6 8 + + "vaddw.s16 q11, q11, d26 @addw \n" // out1 += + // vget_low_s16(out10) + "vaddw.s16 q12, q12, d27 @addw \n" // out1_1 += + // vget_high_s16(out10) + + "vaddw.s16 q11, q11, d28 @addw \n" // out1 += + // vget_low_s16(out10) + "vaddw.s16 q12, q12, d29 @addw \n" // out1_1 += + // vget_high_s16(out10) + + "vaddw.s16 q11, q11, d30 @addw \n" // out1 += + // vget_low_s16(out10) + "vaddw.s16 q12, q12, d31 @addw \n" // out1_1 += + // vget_high_s16(out10) + + "vst1.32 {d22-d23}, [%[dout_ptr1]]! @ store\n" + + "subs %[cnt], #1 \n" + "vst1.32 {d24-d25}, [%[dout_ptr1]]! @ store\n" + "bne 2b \n" + // right + "1: \n" + "cmp %[size_pad_right], #1 \n" + "blt 3f \n" + "vld2.8 {d12-d13}, [%[din_ptr0]]! @ load din00= 0 2 4 6 8\n" // d10 = 0 2 4 6 + "vld2.8 {d14-d15}, [%[din_ptr1]]! @ load din00= 0 2 4 6 8\n" // d12 = 0 2 4 6 + "vld2.8 {d16-d17}, [%[din_ptr2]]! 
@ load din00= 0 2 4 6 8\n" // d14 = 0 2 4 6 + "vld1.8 {d28-d29}, [%[mask]] @ load din00= 0 1 2 3 4 5 6 7 " + "8 9\n" + + // out0 + "vdup.32 q11, %[bias] @ and \n" // q8 = vbias + "vdup.32 q12, %[bias] @ and \n" // q9 = vbias + + "vbif.8 d12, d11, d28 @ bit select, deal with right pad\n" + "vbif.8 d13, d11, d29 @ bit select, deal with right pad\n" + + "vbif.8 d14, d11, d28 @ bit select, deal with right pad\n" + "vbif.8 d15, d11, d29 @ bit select, deal with right pad\n" + + "vbif.8 d16, d11, d28 @ bit select, deal with right pad\n" + "vbif.8 d17, d11, d29 @ bit select, deal with right pad\n" + + "vext.8 d18, d12, d11, #1 @ ext din00 = 2 4 6 8\n" // d16 = -1 + // 1 3 5 + "vext.8 d19, d14, d11, #1 @ ext \n" // d17 = -1 1 3 5 + "vext.8 d20, d16, d11, #1 @ ext \n" // d18 = -1 1 3 5 + + // r0 + "vmull.s8 q13, d12, d2 @ out0 = din0 * w00 \n" // q12 = 0 2 4 6 + "vmull.s8 q14, d13, d3 @ out1 = din0 * w01 \n" // q12 = 1 3 5 7 + "vmull.s8 q15, d18, d4 @ out2 = din0 * w02 \n" // q12 = 2 4 6 8 + + // r1 + "vmlal.s8 q13, d14, d5 @ out0 += din1 * w11 \n" // q12 = 0 2 4 6 + "vmlal.s8 q14, d15, d6 @ out1 += din1 * w12 \n" // q12 = 1 3 5 7 + "vmlal.s8 q15, d19, d7 @ out2 += din1 * w10 \n" // q12 = 2 4 6 8 + + "vld1.32 {d12-d13}, [%[dout_ptr1]]! @ load din00= 0 1 2 3 4 5 6 " + "7 8 9\n" + "vld1.32 {d14-d15}, [%[dout_ptr1]] @ load din00= 0 1 2 3 4 5 6 " + "7 8 9\n" + + "vaddw.s16 q11, q11, d26 @addw \n" // out1 += + // vget_low_s16(out10) + "vaddw.s16 q12, q12, d27 @addw \n" // out1_1 += + // vget_high_s16(out10) + + "sub %[dout_ptr1], #16 @ sub \n" + + "vaddw.s16 q11, q11, d28 @addw \n" // out1 += + // vget_low_s16(out10) + "vaddw.s16 q12, q12, d29 @addw \n" // out1_1 += + // vget_high_s16(out10) + + "vaddw.s16 q11, q11, d30 @addw \n" // out1 += + // vget_low_s16(out10) + "vaddw.s16 q12, q12, d31 @addw \n" // out1_1 += + // vget_high_s16(out10) + + // r2 + "vmull.s8 q13, d16, d8 @ out0 += din1 * w11 \n" // q12 = 0 2 4 6 + "vmull.s8 q14, d17, d9 @ out1 += din1 * w12 \n" // q12 = 1 3 5 7 + "vmull.s8 q15, d20, d10 @ out2 += din1 * w10 \n" // q12 = 2 4 6 8 + + "vld1.32 {d2-d3}, [%[rs_mask]]! @ load din00= 0 1 2 3 4 5 6 7 " + "8 9\n" + "vld1.32 {d4-d5}, [%[rs_mask]] @ load din00= 0 1 2 3 4 5 6 7 8 " + "9\n" + + "vaddw.s16 q11, q11, d26 @addw \n" // out1 += + // vget_low_s16(out10) + "vaddw.s16 q12, q12, d27 @addw \n" // out1_1 += + // vget_high_s16(out10) + + "vaddw.s16 q11, q11, d28 @addw \n" // out1 += + // vget_low_s16(out10) + "vaddw.s16 q12, q12, d29 @addw \n" // out1_1 += + // vget_high_s16(out10) + + "vaddw.s16 q11, q11, d30 @addw \n" // out1 += + // vget_low_s16(out10) + "vaddw.s16 q12, q12, d31 @addw \n" // out1_1 += + // vget_high_s16(out10) + + "vbif q11, q6, q1 @ bit select, deal with right pad\n" + "vbif q12, q7, q2 @ bit select, deal with right pad\n" + + "vst1.32 {d22-d23}, [%[dout_ptr1]]! @ store\n" + "vst1.32 {d24-d25}, [%[dout_ptr1]]! 
@ store\n" + "3: \n" + + : [din_ptr0] "+r"(din_ptr0), + [din_ptr1] "+r"(din_ptr1), + [din_ptr2] "+r"(din_ptr2), + [dout_ptr1] "+r"(doutr0), + [cnt] "+r"(cnt), + [bias] "+r"(bias_val), + [rs_mask] "+r"(rst_mask) + : [mask] "r"(vmask), [size_pad_right] "r"(size_pad_right) + : "cc", + "memory", + "q0", + "q1", + "q2", + "q3", + "q4", + "q5", + "q6", + "q7", + "q8", + "q9", + "q10", + "q11", + "q12", + "q13", + "q14", + "q15"); +#endif + dout_ptr += w_out; + } + } + } +} +// w_in <= 16 +void conv_depthwise_3x3s2p1_bias_s_int8(int* dout, + const signed char* din, + const signed char* weights, + const int* bias, + bool flag_bias, + const int num, + const int ch_in, + const int h_in, + const int w_in, + const int h_out, + const int w_out, + ARMContext* ctx) { + // printf("3x3s2 mult height \n"); + //! pad is done implicit + // const char zero[8] = {0, 0, 0, 0, 0, 0, 0, 0}; + //! for 4x6 convolution window + const unsigned char right_pad_idx[16] = { + 0, 2, 4, 6, 8, 10, 12, 14, 1, 3, 5, 7, 9, 11, 13, 15}; + const unsigned int right_pad_rst[8] = {0, 1, 2, 3, 4, 5, 6, 7}; + + // printf("conv3x3_dw start \n"); + signed char* zero_ptr = ctx->workspace_data(); + memset(zero_ptr, 0, w_in * sizeof(signed char)); + int* write_ptr = + reinterpret_cast(ctx->workspace_data()) + w_out; + int size_in_channel = w_in * h_in; + int size_out_channel = w_out * h_out; + int w_stride = 9; + + unsigned int size_pad_right = (unsigned int)(w_in); + + uint8x8_t vmask_rp1 = + vcgt_u8(vdup_n_u8(size_pad_right), vld1_u8(right_pad_idx)); + uint8x8_t vmask_rp2 = + vcgt_u8(vdup_n_u8(size_pad_right), vld1_u8(right_pad_idx + 8)); + unsigned int rst_remain = (unsigned int)w_out; + uint32x4_t vmask_result1 = + vcgtq_u32(vdupq_n_u32(rst_remain), vld1q_u32(right_pad_rst)); + uint32x4_t vmask_result2 = + vcgtq_u32(vdupq_n_u32(rst_remain), vld1q_u32(right_pad_rst + 4)); + + uint8x16_t vmask_rp = + vcgtq_u8(vdupq_n_u8(size_pad_right), vld1q_u8(right_pad_idx)); + unsigned char vmask[16]; + vst1q_u8(vmask, vmask_rp); + + unsigned int rmask[8]; + vst1q_u32(rmask, vmask_result1); + vst1q_u32(rmask + 4, vmask_result2); + + int8x8_t vzero = vdup_n_s8(0); + for (int n = 0; n < num; ++n) { + const signed char* din_batch = din + n * ch_in * size_in_channel; + int* dout_batch = dout + n * ch_in * size_out_channel; +#pragma omp parallel for + for (int c = 0; c < ch_in; c++) { + int* dout_ptr = dout_batch + c * size_out_channel; + + const signed char* din_ch_ptr = din_batch + c * size_in_channel; + + int bias_val = flag_bias ? bias[c] : 0; + + const signed char* wei_ptr = weights + c * w_stride; +#ifdef __aarch64__ + int vbias[4] = {bias_val, bias_val, bias_val, bias_val}; + + int8x8_t wr00 = vdup_n_s8(wei_ptr[0]); + int8x8_t wr10 = vdup_n_s8(wei_ptr[3]); + int8x8_t wr20 = vdup_n_s8(wei_ptr[6]); + + int8x8_t wr01 = vdup_n_s8(wei_ptr[1]); + int8x8_t wr11 = vdup_n_s8(wei_ptr[4]); + int8x8_t wr21 = vdup_n_s8(wei_ptr[7]); + + int8x8_t wr02 = vdup_n_s8(wei_ptr[2]); + int8x8_t wr12 = vdup_n_s8(wei_ptr[5]); + int8x8_t wr22 = vdup_n_s8(wei_ptr[8]); +#endif + int* doutr0 = nullptr; + + const signed char* dr0 = din_ch_ptr; + const signed char* dr1 = dr0 + w_in; + const signed char* dr2 = dr1 + w_in; + + const signed char* din_ptr0 = nullptr; + const signed char* din_ptr1 = nullptr; + const signed char* din_ptr2 = nullptr; + + for (int i = 0; i < h_in; i += 2) { + //! 
process top pad pad_h = 1 + din_ptr0 = dr0; + din_ptr1 = dr1; + din_ptr2 = dr2; + + doutr0 = dout_ptr; + int out_buf1[8]; + + if (i == 0) { + din_ptr0 = zero_ptr; + din_ptr1 = dr0; + din_ptr2 = dr1; + dr0 = dr1; + dr1 = dr2; + dr2 = dr1 + w_in; + } else { + dr0 = dr2; + dr1 = dr2 + w_in; + dr2 = dr1 + w_in; + } + //! process bottom pad + if (i + 2 > h_in) { + switch (i + 2 - h_in) { + case 2: + din_ptr1 = zero_ptr; + case 1: + din_ptr2 = zero_ptr; + default: + break; + } + } +#ifdef __aarch64__ + unsigned int* rst_mask = rmask; + unsigned char* val_mask = vmask; + asm volatile( + "PRFM PLDL1KEEP, [%[din_ptr0]] \n" + "PRFM PLDL1KEEP, [%[din_ptr1]] \n" + "PRFM PLDL1KEEP, [%[din_ptr2]] \n" + "movi v16.4s, #0x0\n" + // left + "ld1 {v10.8b}, [%[vmask]], #8 \n" + "ld1 {v11.8b}, [%[vmask]] \n" + "ld2 {v0.8b - v1.8b}, [%[din_ptr0]] \n" /*load a00-a015 + to q0*/ + "ld2 {v2.8b - v3.8b}, [%[din_ptr1]] \n" /* load a00-a015 + to q0*/ + "ld2 {v4.8b - v5.8b}, [%[din_ptr2]] \n" /*load a00-a015 + to q0*/ + + "bif v0.8b, v16.8b, v10.8b \n" + "bif v1.8b, v16.8b, v11.8b \n" + "bif v2.8b, v16.8b, v10.8b \n" + "bif v3.8b, v16.8b, v11.8b \n" + "bif v4.8b, v16.8b, v10.8b \n" + "bif v5.8b, v16.8b, v11.8b \n" + + "ld1 {v12.4s}, [%[bias_val]] \n" /* dup v10, bias*/ + "ld1 {v13.4s}, [%[bias_val]] \n" /* dup v10, bias */ + + "ext v6.8b, v16.8b, v1.8B, #7 \n" /* vext_s8(vzero, vinr0, 7); + 013579 */ + "ext v7.8b, v16.8b, v3.8B, #7 \n" /* vext_s8(vzero, vinr0, 7); + 013579 */ + "ext v8.8b, v16.8b, v5.8B, #7 \n" /* vext_s8(vzero, vinr0, 7); + 013579 */ + + // r0 + "smull v17.8h, %[v1].8b, v0.8b \n" /* outr00 = 02468 * w01 */ + "smull v18.8h, %[v2].8b, v1.8b\n" /* outr00 += 13579 * w02 */ + "smull v19.8h, %[v0].8b, v6.8b\n" /* outr00 += 013579 * w00 */ + + // "ldp q0, q1, [%[ptr_out0]] \n" /* dup v10, + // bias */ "ldp q10, q11, [%[rst_mask]] \n" /* + // dup v10, bias */ + + // r1 + "smlal v17.8h, %[v4].8b, v2.8b \n" /* outr00 = 02468 * w01 */ + "smlal v18.8h, %[v5].8b, v3.8b\n" /* outr00 += 13579 * w02 */ + "smlal v19.8h, %[v3].8b, v7.8b\n" /* outr00 += 013579 * w00 */ + + "saddw v12.4s, v12.4s, v17.4h \n" /* v10 += outr00.low*/ + "saddw2 v13.4s, v13.4s, v17.8h \n" /* v11 += outr00.high*/ + + "saddw v12.4s, v12.4s, v18.4h \n" /* v10 += outr00.low*/ + "saddw2 v13.4s, v13.4s, v18.8h \n" /* v11 += outr00.high*/ + + "saddw v12.4s, v12.4s, v19.4h \n" /* v10 += outr00.low*/ + "saddw2 v13.4s, v13.4s, v19.8h \n" /* v11 += outr00.high*/ + + // r2 + "smull v17.8h, %[v7].8b, v4.8b \n" /* outr00 = 02468 * w01 */ + "smull v18.8h, %[v8].8b, v5.8b\n" /* outr00 += 13579 * w02 */ + "smull v19.8h, %[v6].8b, v8.8b\n" /* outr00 += 013579 * w00 */ + + "saddw v12.4s, v12.4s, v17.4h \n" /* v10 += outr00.low*/ + "saddw2 v13.4s, v13.4s, v17.8h \n" /* v11 += outr00.high*/ + + "saddw v12.4s, v12.4s, v18.4h \n" /* v10 += outr00.low*/ + "saddw2 v13.4s, v13.4s, v18.8h \n" /* v11 += outr00.high*/ + + "saddw v12.4s, v12.4s, v19.4h \n" /* v10 += outr00.low*/ + "saddw2 v13.4s, v13.4s, v19.8h \n" /* v11 += outr00.high*/ + + // "bif v12.16b, v0.16b, v10.16b \n" + // "bif v13.16b, v1.16b, v11.16b \n" + + "stp q12, q13, [%[ptr_out0]] \n" /* store q10, q11 -> ptr_out + */ + : [din_ptr0] "+r"(din_ptr0), + [din_ptr1] "+r"(din_ptr1), + [din_ptr2] "+r"(din_ptr2), + [vmask] "+r"(val_mask) + : [v0] "w"(wr00), + [v1] "w"(wr01), + [v2] "w"(wr02), + [v3] "w"(wr10), + [bias_val] "r"(vbias), + [v4] "w"(wr11), + [v5] "w"(wr12), + [v6] "w"(wr20), + [v7] "w"(wr21), + [v8] "w"(wr22), + [rst_mask] "r"(rmask), + [ptr_out0] "r"(out_buf1) + : "cc", + "memory", + "v0", + "v1", + 
"v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15", + "v16", + "v17", + "v18", + "v19", + "v20"); +#else + unsigned int* rst_mask = rmask; + // prefetch input + // store weights + asm volatile("vld1.8 {d0-d1}, [%[wei_ptr]] \n" + : + : [wei_ptr] "r"(wei_ptr) + : "memory"); + asm volatile( + // left + "pld [%[din_ptr0]] @ preload data\n" + "pld [%[din_ptr1]] @ preload data\n" + "pld [%[din_ptr2]] @ preload data\n" + "vdup.s8 d2, d0[0] @ d2 = w00, w00, w00, w00\n" + "vdup.s8 d3, d0[1] @ d3 = w01, w01, w01, w01\n" + "vdup.s8 d4, d0[2] @ d4 = w02, w02, w02, w02\n" + "vld2.8 {d12-d13}, [%[din_ptr0]] @ load din00= 0 2 4 6 8\n" // d10 = 0 2 4 6 + "vld2.8 {d14-d15}, [%[din_ptr1]] @ load din00= 0 2 4 6 8\n" // d12 = 0 2 4 6 + "vld2.8 {d16-d17}, [%[din_ptr2]] @ load din00= 0 2 4 6 8\n" // d14 = 0 2 4 6 + "vld1.8 {d28-d29}, [%[mask]] @ load din00= 0 1 2 3 4 5 6 7 " + "8 9\n" + "vmov.u32 d11, #0 @ zero\n" + + "vdup.s8 d5, d0[3] @ d2 = w00, w00, w00, w00\n" + "vdup.s8 d6, d0[4] @ d3 = w01, w01, w01, w01\n" + "vdup.s8 d7, d0[5] @ d4 = w02, w02, w02, w02\n" + + "vbif.8 d12, d11, d28 @ bit select, deal with right pad\n" + "vbif.8 d13, d11, d29 @ bit select, deal with right pad\n" + + "vbif.8 d14, d11, d28 @ bit select, deal with right pad\n" + "vbif.8 d15, d11, d29 @ bit select, deal with right pad\n" + + "vbif.8 d16, d11, d28 @ bit select, deal with right pad\n" + "vbif.8 d17, d11, d29 @ bit select, deal with right pad\n" + + "vext.8 d18, d11, d13, #7 @ ext \n" // d16 = -1 1 3 5 + "vext.8 d19, d11, d15, #7 @ ext \n" // d17 = -1 1 3 5 + "vext.8 d20, d11, d17, #7 @ ext \n" // d18 = -1 1 3 5 + + // "pld [%[dout_ptr1]] @ preload data\n" + + // r0 + "vmull.s8 q13, d12, d3 @ out0 = din0 * w01 \n" // q12 = d12 * w01 + "vmull.s8 q14, d13, d4 @ out1 = din0 * w02 \n" // q12 = d12 * w02 + "vmull.s8 q15, d18, d2 @ out2 = din0 * w00 \n" // q12 = d12 * w02 + + "vdup.s8 d8, d0[6] @ d2 = w00, w00, w00, w00\n" + "vdup.s8 d9, d0[7] @ d3 = w01, w01, w01, w01\n" + "vdup.s8 d10, d1[0] @ d4 = w02, w02, w02, w02\n" + + // out0 + "vdup.32 q11, %[bias] @ and \n" // q8 = + // vbias + "vdup.32 q12, %[bias] @ and \n" // q9 = + // vbias + + // r1 + "vmlal.s8 q13, d14, d6 @ out0 += din1 * w11 \n" // q12 = d12 * w11 + "vmlal.s8 q14, d15, d7 @ out1 += din1 * w12 \n" // q12 = d12 * w11 + "vmlal.s8 q15, d19, d5 @ out2 += din1 * w10 \n" // q12 = d12 * w11 + + // "vld1.32 {d12-d13}, [%[dout_ptr1]]! @ load din00= 0 1 2 3 4 5 + // 6 7 8 9\n" "vld1.32 {d14-d15}, [%[dout_ptr1]] @ load din00= 0 + // 1 2 3 4 5 6 7 8 9\n" + + "vaddw.s16 q11, q11, d26 @addw \n" // out1 += + // vget_low_s16(out10) + "vaddw.s16 q12, q12, d27 @addw \n" // out1_1 += + // vget_high_s16(out10) + + "vaddw.s16 q11, q11, d28 @addw \n" // out1 += + // vget_low_s16(out10) + "vaddw.s16 q12, q12, d29 @addw \n" // out1_1 += + // vget_high_s16(out10) + + "vaddw.s16 q11, q11, d30 @addw \n" // out1 += + // vget_low_s16(out10) + "vaddw.s16 q12, q12, d31 @addw \n" // out1_1 += + // vget_high_s16(out10) + + // r2 + "vmull.s8 q13, d16, d9 @ out0 += din1 * w21 \n" // q12 = d12 * w11 + "vmull.s8 q14, d17, d10 @ out1 += din1 * w22 \n" // q12 = d12 * w11 + "vmull.s8 q15, d20, d8 @ out2 += din1 * w20 \n" // q12 = d12 * w11 + + // "vld1.32 {d2-d3}, [%[rs_mask]]! 
@ load din00= 0 1 2 3 4 5 6 7
+          // 8 9\n" "vld1.32 {d4-d5}, [%[rs_mask]]      @ load din00= 0 1 2 3 4
+          // 5 6 7 8 9\n"
+
+          // "sub %[dout_ptr1], #16                  @ sub \n"
+
+          "vaddw.s16 q11, q11, d26                 @addw \n"  // out1 +=
+          // vget_low_s16(out10)
+          "vaddw.s16 q12, q12, d27                 @addw \n"  // out1_1 +=
+          // vget_high_s16(out10)
+
+          "vaddw.s16 q11, q11, d28                 @addw \n"  // out1 +=
+          // vget_low_s16(out10)
+          "vaddw.s16 q12, q12, d29                 @addw \n"  // out1_1 +=
+          // vget_high_s16(out10)
+
+          "vaddw.s16 q11, q11, d30                 @addw \n"  // out1 +=
+          // vget_low_s16(out10)
+          "vaddw.s16 q12, q12, d31                 @addw \n"  // out1_1 +=
+          // vget_high_s16(out10)
+
+          // "vbif q11, q6, q1        @ bit select, deal with right pad\n"
+          // "vbif q12, q7, q2        @ bit select, deal with right pad\n"
+
+          "vst1.32 {d22-d25}, [%[dout_ptr1]]       @ store\n"
+          // "vst1.32 {d24-d25}, [%[dout_ptr1]]!    @ store\n"
+          : [din_ptr0] "+r"(din_ptr0),
+            [din_ptr1] "+r"(din_ptr1),
+            [din_ptr2] "+r"(din_ptr2),
+            [bias] "+r"(bias_val),
+            [rs_mask] "+r"(rst_mask)
+          : [mask] "r"(vmask),
+            [size_pad_right] "r"(size_pad_right),
+            [dout_ptr1] "r"(out_buf1)
+          : "cc",
+            "memory",
+            "q0",
+            "q1",
+            "q2",
+            "q3",
+            "q4",
+            "q5",
+            "q6",
+            "q7",
+            "q8",
+            "q9",
+            "q10",
+            "q11",
+            "q12",
+            "q13",
+            "q14",
+            "q15");
+#endif
+        for (int w = 0; w < w_out; ++w) {
+          *doutr0++ = out_buf1[w];
+        }
+        dout_ptr += w_out;
+      }
+    }
+  }
+}
+
+// relu
+void conv_depthwise_3x3s1p1_bias_relu_int8(int* dout,
+                                           const signed char* din,
+                                           const signed char* weights,
+                                           const int* bias,
+                                           bool flag_bias,
+                                           const int num,
+                                           const int ch_in,
+                                           const int h_in,
+                                           const int w_in,
+                                           const int h_out,
+                                           const int w_out,
+                                           ARMContext* ctx) {
+  // printf("3x3s1 mult height \n");
+  //! padding is handled implicitly
+  const char zero[8] = {0, 0, 0, 0, 0, 0, 0, 0};
+  //! for 4x6 convolution window
+  const unsigned char right_pad_idx[16] = {
+      0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15};
+  const unsigned int right_pad_rst[8] = {0, 1, 2, 3, 4, 5, 6, 7};
+
+  // printf("conv3x3_dw start \n");
+  signed char* zero_ptr = ctx->workspace_data<signed char>();
+  memset(zero_ptr, 0, w_in * sizeof(signed char));
+  int* write_ptr =
+      reinterpret_cast<int*>(ctx->workspace_data<signed char>()) + w_in;
+  int size_in_channel = w_in * h_in;
+  int size_out_channel = w_out * h_out;
+  int w_stride = 9;
+
+  int tile_w = (w_in + 7) >> 3;
+  int tile_h = (h_out + 1) >> 1;
+  int cnt_col = tile_w - 2;
+
+  unsigned int size_pad_right = (unsigned int)(w_in - 7 - (cnt_col << 3));
+
+  int size_pad_bottom = h_out % 2;
+
+  uint8x8_t vmask_rp1 =
+      vcgt_u8(vdup_n_u8(size_pad_right), vld1_u8(right_pad_idx));
+  uint8x8_t vmask_rp2 =
+      vcgt_u8(vdup_n_u8(size_pad_right), vld1_u8(right_pad_idx + 8));
+  unsigned int rst_remain = (unsigned int)(w_out - ((cnt_col + 1) << 3));
+  uint32x4_t vmask_result1 =
+      vcgtq_u32(vdupq_n_u32(rst_remain), vld1q_u32(right_pad_rst));
+  uint32x4_t vmask_result2 =
+      vcgtq_u32(vdupq_n_u32(rst_remain), vld1q_u32(right_pad_rst + 4));
+
+  int8x8_t vzero = vdup_n_s8(0);
+  int32x4_t vzero_32 = vdupq_n_s32(0);
+
+  uint8x16_t vmask_rp =
+      vcgtq_u8(vdupq_n_u8(size_pad_right), vld1q_u8(right_pad_idx));
+  // uint8x8_t vmask_rp2 = vcgt_u8(vdup_n_u8(size_pad_right),
+  // vld1_u8(right_pad_idx + 8));
+  unsigned char vmask[16];
+  vst1q_u8(vmask, vmask_rp);
+
+  unsigned int rmask[8];
+  vst1q_u32(rmask, vmask_result1);
+  vst1q_u32(rmask + 4, vmask_result2);
+
+  for (int n = 0; n < num; ++n) {
+    const signed char* din_batch = din + n * ch_in * size_in_channel;
+    int* dout_batch = dout + n * ch_in * size_out_channel;
+#pragma omp parallel for
+    for (int c = 0; c < ch_in; c++) {
+      int* dout_ptr =
dout_batch + c * size_out_channel; + + const signed char* din_ch_ptr = din_batch + c * size_in_channel; + + int bias_val = flag_bias ? bias[c] : 0; + + const signed char* wei_ptr = weights + c * w_stride; +#ifdef __aarch64__ + int vbias[4] = {bias_val, bias_val, bias_val, bias_val}; + int8x8_t wr00 = vdup_n_s8(wei_ptr[0]); + int8x8_t wr10 = vdup_n_s8(wei_ptr[3]); + int8x8_t wr20 = vdup_n_s8(wei_ptr[6]); + + int8x8_t wr01 = vdup_n_s8(wei_ptr[1]); + int8x8_t wr11 = vdup_n_s8(wei_ptr[4]); + int8x8_t wr21 = vdup_n_s8(wei_ptr[7]); + + int8x8_t wr02 = vdup_n_s8(wei_ptr[2]); + int8x8_t wr12 = vdup_n_s8(wei_ptr[5]); + int8x8_t wr22 = vdup_n_s8(wei_ptr[8]); +#endif + + int* doutr0 = nullptr; + int* doutr1 = nullptr; + + const signed char* dr0 = din_ch_ptr; + const signed char* dr1 = dr0 + w_in; + const signed char* dr2 = dr1 + w_in; + const signed char* dr3 = dr2 + w_in; + + const signed char* din_ptr0 = nullptr; + const signed char* din_ptr1 = nullptr; + const signed char* din_ptr2 = nullptr; + const signed char* din_ptr3 = nullptr; + + for (int i = 0; i < h_in; i += 2) { + //! process top pad pad_h = 1 + din_ptr0 = dr0; + din_ptr1 = dr1; + din_ptr2 = dr2; + din_ptr3 = dr3; + + doutr0 = dout_ptr; + doutr1 = doutr0 + w_out; + unsigned int* rst_mask = rmask; + unsigned char* val_mask = vmask; + + if (i == 0) { + din_ptr0 = zero_ptr; + din_ptr1 = dr0; + din_ptr2 = dr1; + din_ptr3 = dr2; + dr0 = dr1; + dr1 = dr2; + dr2 = dr3; + dr3 = dr2 + w_in; + } else { + dr0 = dr2; + dr1 = dr3; + dr2 = dr1 + w_in; + dr3 = dr2 + w_in; + } + //! process bottom pad + if (i + 3 > h_in) { + switch (i + 3 - h_in) { + case 3: + din_ptr1 = zero_ptr; + case 2: + din_ptr2 = zero_ptr; + case 1: + din_ptr3 = zero_ptr; + default: + break; + } + } + //! process bottom remain + if (i + 2 > h_out) { + doutr1 = write_ptr; + } + int cnt = cnt_col; +#ifdef __aarch64__ + asm volatile( + "PRFM PLDL1KEEP, [%[din_ptr0]] \n" + "PRFM PLDL1KEEP, [%[din_ptr1]] \n" + "PRFM PLDL1KEEP, [%[din_ptr2]] \n" + "PRFM PLDL1KEEP, [%[din_ptr3]] \n" + "movi v21.4s, #0x0\n" /* out0 = 0 */ + // left + "ld1 {v0.8b}, [%[din_ptr0]], #8 \n" /* load + a00-a015 + to + q0*/ + "ld1 {v2.8b}, [%[din_ptr1]], #8 \n" /* load + a00-a015 + to + q0*/ + "ld1 {v1.8b}, [%[din_ptr0]] \n" /* load + a00-a015 to + q0*/ + "ld1 {v3.8b}, [%[din_ptr1]] \n" /* load + a00-a015 to + q0*/ + + "ld1 {v10.4s}, [%[bias_val]] \n" /* dup v10, bias */ + "ld1 {v11.4s}, [%[bias_val]] \n" /* dup v10, bias */ + "ld1 {v12.4s}, [%[bias_val]] \n" /* dup v10, bias */ + "ld1 {v13.4s}, [%[bias_val]] \n" /* dup v10, bias */ + + // r0 + "smull v18.8h, %[v1].8b, v0.8b \n" /* outr00 = 01234567 * w01 + */ + + "ext v4.8b, v21.8b, v0.8B, #7 \n" /* vext_s8(vzero, vinr0, 7); + 00123456 */ + "ext v5.8b, v0.8b, v1.8B, #1 \n" /* vext_s8(vinr0, vinr0_1, + 1); 12345678 */ + + "ld1 {v6.8b}, [%[din_ptr2]], #8 \n" /* load + a00-a015 + to + q0*/ + "ld1 {v8.8b}, [%[din_ptr3]], #8 \n" /* load + a00-a015 + to + q0*/ + + "smlal v18.8h, %[v0].8b, v4.8b\n" /* outr00 += 00123456 * w00 */ + + "ld1 {v7.8b}, [%[din_ptr2]] \n" /* load + a00-a015 + to q0*/ + "ld1 {v9.8b}, [%[din_ptr3]] \n" /* load + a00-a015 + to q0*/ + + "sub %[din_ptr0], %[din_ptr0], #1 \n" + "sub %[din_ptr1], %[din_ptr1], #1 \n" + + "saddw v10.4s, v10.4s, v18.4h \n" /* v10 += outr00.low*/ + "saddw2 v11.4s, v11.4s, v18.8h \n" /* v11 += outr00.high*/ + + "smull v18.8h, %[v2].8b, v5.8b\n" /* outr00 += 12345678 * w02 */ + + "ext v4.8b, v21.8b, v2.8b, #7 \n" /* vext_s8(vzero, vinr0, 7); + 00123456 */ + "ext v5.8b, v2.8b, v3.8b, #1 \n" /* vext_s8(vinr0, vinr0_1, + 1); 
12345678 */ + + // r1 + "sub %[din_ptr2], %[din_ptr2], #1 \n" + "sub %[din_ptr3], %[din_ptr3], #1 \n" + + "smull v19.8h, %[v1].8b, v2.8b \n" /* outr10 += 01234567 * w11 + */ + "smlal v18.8h, %[v4].8b, v2.8b \n" /* outr00 += 01234567 * w11 + */ + + "ext v14.8b, v21.8b, v6.8b, #7 \n" /* vext_s8(vzero, vinr0, + 7); 00123456 */ + "ext v15.8b, v6.8b, v7.8b, #1 \n" /* vext_s8(vinr0, vinr0_1, + 1); 12345678 */ + + "smlal v19.8h, %[v0].8b, v4.8b \n" /* outr00 += 01234567 * w11 + */ + + "saddw v10.4s, v10.4s, v18.4h \n" /* v10 += outr00.low*/ + "saddw2 v11.4s, v11.4s, v18.8h \n" /* v11 += outr00.high*/ + "smull v18.8h, %[v3].8b, v4.8b \n" /* outr00 += 001234567 * w10 + */ + + "ld1 {v0.8b}, [%[din_ptr0]], #8 \n" /* load + a00-a015 + to + q0*/ + "ld1 {v2.8b}, [%[din_ptr1]], #8 \n" /* load + a00-a015 + to + q0*/ + + "saddw v12.4s, v12.4s, v19.4h \n" /* v10 += outr00.low*/ + "saddw2 v13.4s, v13.4s, v19.8h \n" /* v11 += outr00.high*/ + + "smull v19.8h, %[v2].8b, v5.8b \n" /* outr00 += 01234567 * w11 + */ + "smlal v18.8h, %[v5].8b, v5.8b \n" /* outr00 += 12345678 * w12 + */ + + // r2 + "ld1 {v1.8b}, [%[din_ptr0]] \n" /* load + a00-a015 to + q0*/ + "ld1 {v3.8b}, [%[din_ptr1]] \n" /* load + a00-a015 to + q0*/ + + "smlal v19.8h, %[v4].8b, v6.8b \n" /* outr10 += 01234567 * w11 + */ + + "saddw v10.4s, v10.4s, v18.4h \n" /* v10 += outr00.low*/ + "saddw2 v11.4s, v11.4s, v18.8h \n" /* v11 += outr00.high*/ + "smull v18.8h, %[v7].8b, v6.8b \n" /* outr00 += 01234567 * w11 + */ + + "ext v4.8b, v21.8b, v8.8b, #7 \n" /* vext_s8(vzero, vinr0, 7); + 00123456 */ + "ext v5.8b, v8.8b, v9.8b, #1 \n" /* vext_s8(vinr0, vinr0_1, + 1); 12345678 */ + + "saddw v12.4s, v12.4s, v19.4h \n" /* v10 += outr00.low*/ + "saddw2 v13.4s, v13.4s, v19.8h \n" /* v11 += outr00.high*/ + + "smull v19.8h, %[v3].8b, v14.8b \n" /* outr10 += 01234567 * w11 + */ + "smlal v18.8h, %[v6].8b, v14.8b \n" /* outr00 += 01234567 * w11 + */ + + "ld1 {v6.8b}, [%[din_ptr2]], #8 \n" /* load + a00-a015 + to + q0*/ + + "smlal v19.8h, %[v5].8b, v15.8b \n" /* outr10 += 01234567 * w11 + */ + + "saddw v10.4s, v10.4s, v18.4h \n" /* v10 += outr00.low*/ + "saddw2 v11.4s, v11.4s, v18.8h \n" /* v11 += outr00.high*/ + + "smull v18.8h, %[v8].8b, v15.8b \n" /* outr00 += 01234567 * w11 + */ + "saddw v12.4s, v12.4s, v19.4h \n" /* v10 += outr00.low*/ + "saddw2 v13.4s, v13.4s, v19.8h \n" /* v11 += outr00.high*/ + + // r3 + "smull v19.8h, %[v7].8b, v8.8b \n" /* outr00 += 01234567 * w11 + */ + + "ld1 {v8.8b}, [%[din_ptr3]], #8 \n" /* load + a00-a015 + to + q0*/ + + "ld1 {v7.8b}, [%[din_ptr2]] \n" /* load + a00-a015 to + q0*/ + "ld1 {v9.8b}, [%[din_ptr3]] \n" /* load + a00-a015 to + q0*/ + + "saddw v10.4s, v10.4s, v18.4h \n" /* v10 += outr00.low*/ + "saddw2 v11.4s, v11.4s, v18.8h \n" /* v11 += outr00.high*/ + + "smlal v19.8h, %[v6].8b, v4.8b \n" /* outr00 += 01234567 * + w11 */ + + "smax v10.4s, v10.4s, v21.4s \n" /* relu*/ + "smax v11.4s, v11.4s, v21.4s \n" /* relu*/ + + "stp q10, q11, [%[ptr_out0]], #32 \n" /* store q10, q11 -> + ptr_out */ + + "saddw v12.4s, v12.4s, v19.4h \n" /* v10 += outr00.low*/ + "saddw2 v13.4s, v13.4s, v19.8h \n" /* v11 += outr00.high*/ + + "smull v19.8h, %[v8].8b, v5.8b \n" /* outr00 += 01234567 * + w11 */ + + "ld1 {v10.4s}, [%[bias_val]] \n" /* dup v10, bias */ + "ld1 {v11.4s}, [%[bias_val]] \n" /* dup v10, bias */ + + "saddw v12.4s, v12.4s, v19.4h \n" /* v10 += outr00.low*/ + "saddw2 v13.4s, v13.4s, v19.8h \n" /* v11 += outr00.high*/ + + "smax v12.4s, v12.4s, v21.4s \n" /* relu*/ + "smax v13.4s, v13.4s, v21.4s \n" /* relu*/ + + "stp q12, q13, 
[%[ptr_out1]], #32 \n" /* store q10, q11 -> + ptr_out */ + + "ld1 {v12.4s}, [%[bias_val]] \n" /* dup v10, bias */ + "ld1 {v13.4s}, [%[bias_val]] \n" /* dup v10, bias */ + + "cmp %[cnt], #1 \n" + "blt 3f \n" + // mid + "1: \n" + "ext v4.8b, v0.8B, v1.8b, #1 \n" /*12345678 */ + "ext v5.8b, v0.8b, v1.8B, #2 \n" /*23456789 */ + + // r0 + "smull v18.8h, %[v0].8b, v0.8b \n" /* outr00 = 01234567 * w00 + */ + + "ext v14.8b, v2.8B, v3.8b, #1 \n" /*12345678 */ + "ext v15.8b, v2.8b, v3.8B, #2 \n" /*23456789 */ + + "smlal v18.8h, %[v1].8b, v4.8b\n" /* outr00 += 12345678 * w01 */ + + "ext v16.8b, v6.8B, v7.8b, #1 \n" /*12345678 */ + "ext v17.8b, v6.8b, v7.8B, #2 \n" /*23456789 */ + "saddw v10.4s, v10.4s, v18.4h \n" /* v10 += outr00.low*/ + "saddw2 v11.4s, v11.4s, v18.8h \n" /* v11 += outr00.high*/ + + "smull v18.8h, %[v2].8b, v5.8b\n" /* outr00 += 23456789 * w02 */ + + // r1 + "ext v4.8b, v8.8B, v9.8b, #1 \n" /*12345678 */ + "ext v5.8b, v8.8b, v9.8B, #2 \n" /*23456789 */ + + "smull v19.8h, %[v0].8b, v2.8b \n" /* outr00 = 01234567 * w00 + */ + "smlal v18.8h, %[v3].8b, v2.8b \n" /* outr00 = 01234567 * w00 + */ + + "ld1 {v0.8b}, [%[din_ptr0]], #8 \n" /* load + a00-a015 + to + q0*/ + "ld1 {v2.8b}, [%[din_ptr1]], #8 \n" /* load + a00-a015 + to + q0*/ + + "smlal v19.8h, %[v1].8b, v14.8b\n" /* outr00 += 12345678 * w01 */ + + "saddw v10.4s, v10.4s, v18.4h \n" /* v10 += outr00.low*/ + "saddw2 v11.4s, v11.4s, v18.8h \n" /* v11 += outr00.high*/ + + "smull v18.8h, %[v4].8b, v14.8b\n" /* outr00 += 12345678 * w01 */ + + "ld1 {v1.8b}, [%[din_ptr0]] \n" /* load + a00-a015 + to q0*/ + "ld1 {v3.8b}, [%[din_ptr1]] \n" /* load + a00-a015 + to q0*/ + "saddw v12.4s, v12.4s, v19.4h \n" /* v10 += outr00.low*/ + "saddw2 v13.4s, v13.4s, v19.8h \n" /* v11 += outr00.high*/ + + "smull v19.8h, %[v2].8b, v15.8b\n" /* outr00 += 23456789 * w02 */ + "smlal v18.8h, %[v5].8b, v15.8b\n" /* outr00 += 12345678 * w01 */ + + // r2 + "smlal v19.8h, %[v3].8b, v6.8b \n" /* outr00 = 01234567 * w00 + */ + + "saddw v10.4s, v10.4s, v18.4h \n" /* v10 += outr00.low*/ + "saddw2 v11.4s, v11.4s, v18.8h \n" /* v11 += outr00.high*/ + + "smull v18.8h, %[v6].8b, v6.8b \n" /* outr00 = 01234567 * w00 + */ + + "saddw v12.4s, v12.4s, v19.4h \n" /* v10 += outr00.low*/ + "saddw2 v13.4s, v13.4s, v19.8h \n" /* v11 += outr00.high*/ + + "smull v19.8h, %[v4].8b, v16.8b\n" /* outr00 += 12345678 * w01 */ + "smlal v18.8h, %[v7].8b, v16.8b\n" /* outr00 += 12345678 * w01 */ + + "smlal v19.8h, %[v5].8b, v17.8b\n" /* outr00 += 23456789 * w02 */ + + "saddw v10.4s, v10.4s, v18.4h \n" /* v10 += outr00.low*/ + "saddw2 v11.4s, v11.4s, v18.8h \n" /* v11 += outr00.high*/ + + "smull v18.8h, %[v8].8b, v17.8b\n" /* outr00 += 12345678 * w01 */ + "saddw v12.4s, v12.4s, v19.4h \n" /* v10 += outr00.low*/ + "saddw2 v13.4s, v13.4s, v19.8h \n" /* v11 += outr00.high*/ + + // r3 + "smull v19.8h, %[v6].8b, v8.8b \n" /* outr00 = 01234567 * w00 + */ + + "ld1 {v6.8b}, [%[din_ptr2]], #8 \n" /* load + a00-a015 + to + q0*/ + "ld1 {v8.8b}, [%[din_ptr3]], #8 \n" /* load + a00-a015 + to + q0*/ + + "saddw v10.4s, v10.4s, v18.4h \n" /* v10 += outr00.low*/ + "saddw2 v11.4s, v11.4s, v18.8h \n" /* v11 += outr00.high*/ + + "smlal v19.8h, %[v7].8b, v4.8b\n" /* outr00 += 12345678 * w01 */ + + "ld1 {v7.8b}, [%[din_ptr2]] \n" /* load + a00-a015 + to q0*/ + "ld1 {v9.8b}, [%[din_ptr3]] \n" /* load + a00-a015 + to q0*/ + + "smax v10.4s, v10.4s, v21.4s \n" /* relu*/ + "smax v11.4s, v11.4s, v21.4s \n" /* relu*/ + + "saddw v12.4s, v12.4s, v19.4h \n" /* v10 += outr00.low*/ + "saddw2 v13.4s, v13.4s, v19.8h \n" /* v11 += 
outr00.high*/ + + "smull v19.8h, %[v8].8b, v5.8b\n" /* outr00 += 23456789 * w02 */ + + "stp q10, q11, [%[ptr_out0]], #32 \n" /* store q10, q11 -> + ptr_out */ + + "ld1 {v10.4s}, [%[bias_val]] \n" /* dup v10, bias */ + "ld1 {v11.4s}, [%[bias_val]] \n" /* dup v10, bias */ + + "saddw v12.4s, v12.4s, v19.4h \n" /* v10 += outr00.low*/ + "saddw2 v13.4s, v13.4s, v19.8h \n" /* v11 += outr00.high*/ + + "subs %[cnt], %[cnt], #1 \n" + + "smax v12.4s, v12.4s, v21.4s \n" /* relu*/ + "smax v13.4s, v13.4s, v21.4s \n" /* relu*/ + + "stp q12, q13, [%[ptr_out1]], #32 \n" /* store q10, q11 -> + ptr_out */ + + "ld1 {v12.4s}, [%[bias_val]] \n" /* dup v10, bias */ + "ld1 {v13.4s}, [%[bias_val]] \n" /* dup v10, bias */ + + "bne 1b \n" + // right + "3: \n" + "ld1 {v14.8b}, [%[vmask]], #8 \n" + "ld1 {v15.8b}, [%[vmask]] \n" + + "bif v0.8b, v21.8b, v14.8b \n" + "bif v1.8b, v21.8b, v15.8b \n" + "bif v2.8b, v21.8b, v14.8b \n" + "bif v3.8b, v21.8b, v15.8b \n" + + "ext v4.8b, v0.8b, v1.8b, #1 \n" + "ext v5.8b, v0.8b, v1.8b, #2 \n" + + // r0 + "smull v18.8h, %[v0].8b, v0.8b \n" /* outr00 = 01234567 * w00 + */ + + "ext v16.8b, v2.8b, v3.8b, #1 \n" + "ext v17.8b, v2.8b, v3.8b, #2 \n" + + "bif v6.8b, v21.8b, v14.8b \n" + "bif v7.8b, v21.8b, v15.8b \n" + + "smlal v18.8h, %[v1].8b, v4.8b \n" /* outr00 = 01234567 * w00 + */ + + "bif v8.8b, v21.8b, v14.8b \n" + "bif v9.8b, v21.8b, v15.8b \n" + + "ext v20.8b, v6.8b, v7.8b, #1 \n" + "ext v22.8b, v6.8b, v7.8b, #2 \n" + + "saddw v10.4s, v10.4s, v18.4h \n" /* v10 += outr00.low*/ + "saddw2 v11.4s, v11.4s, v18.8h \n" /* v11 += outr00.high*/ + + "smull v18.8h, %[v2].8b, v5.8b \n" /* outr00 = 01234567 * w00 + */ + + // r1 + "ext v4.8b, v8.8b, v9.8b, #1 \n" + "ext v5.8b, v8.8b, v9.8b, #2 \n" + + "smull v19.8h, %[v0].8b, v2.8b \n" /* outr00 = 01234567 * w00 + */ + "smlal v18.8h, %[v3].8b, v2.8b \n" /* outr00 = 01234567 * w00 + */ + + "ld1 {v14.4s}, [%[rmask]], #16 \n" + "ld1 {v15.4s}, [%[rmask]] \n" + + "smlal v19.8h, %[v1].8b, v16.8b \n" /* outr00 = 01234567 * w00 + */ + + "saddw v10.4s, v10.4s, v18.4h \n" /* v10 += outr00.low*/ + "saddw2 v11.4s, v11.4s, v18.8h \n" /* v11 += outr00.high*/ + "smull v18.8h, %[v4].8b, v16.8b \n" /* outr00 = 01234567 * w00 + */ + + "ld1 {v0.4s}, [%[ptr_out0]], #16 \n" + "ld1 {v2.4s}, [%[ptr_out1]], #16 \n" + "saddw v12.4s, v12.4s, v19.4h \n" /* v10 += outr00.low*/ + "saddw2 v13.4s, v13.4s, v19.8h \n" /* v11 += outr00.high*/ + + "smull v19.8h, %[v2].8b, v17.8b \n" /* outr00 = 01234567 * w00 + */ + "smlal v18.8h, %[v5].8b, v17.8b \n" /* outr00 = 01234567 * w00 + */ + + "ld1 {v1.4s}, [%[ptr_out0]] \n" + "ld1 {v3.4s}, [%[ptr_out1]] \n" + + // r2 + "smlal v19.8h, %[v3].8b, v6.8b \n" /* outr00 = 01234567 * w00 + */ + + "saddw v10.4s, v10.4s, v18.4h \n" /* v10 += outr00.low*/ + "saddw2 v11.4s, v11.4s, v18.8h \n" /* v11 += outr00.high*/ + "smull v18.8h, %[v6].8b, v6.8b \n" /* outr00 = 01234567 * w00 + */ + + "sub %[ptr_out0], %[ptr_out0], #16 \n" + "sub %[ptr_out1], %[ptr_out1], #16 \n" + "saddw v12.4s, v12.4s, v19.4h \n" /* v10 += outr00.low*/ + "saddw2 v13.4s, v13.4s, v19.8h \n" /* v11 += outr00.high*/ + + "smull v19.8h, %[v4].8b, v20.8b \n" /* outr00 = 01234567 * w00 + */ + "smlal v18.8h, %[v7].8b, v20.8b \n" /* outr00 = 01234567 * w00 + */ + + "smlal v19.8h, %[v5].8b, v22.8b \n" /* outr00 = 01234567 * w00 + */ + + "saddw v10.4s, v10.4s, v18.4h \n" /* v10 += outr00.low*/ + "saddw2 v11.4s, v11.4s, v18.8h \n" /* v11 += outr00.high*/ + + "smull v18.8h, %[v8].8b, v22.8b \n" /* outr00 = 01234567 * w00 + */ + "saddw v12.4s, v12.4s, v19.4h \n" /* v10 += outr00.low*/ + 
"saddw2 v13.4s, v13.4s, v19.8h \n" /* v11 += outr00.high*/ + + // r3 + "smull v19.8h, %[v6].8b, v8.8b \n" /* outr00 = 01234567 * w00 + */ + + "saddw v10.4s, v10.4s, v18.4h \n" /* v10 += outr00.low*/ + "saddw2 v11.4s, v11.4s, v18.8h \n" /* v11 += outr00.high*/ + + "smlal v19.8h, %[v7].8b, v4.8b \n" /* outr00 = 01234567 * w00 + */ + + "smax v10.4s, v10.4s, v21.4s \n" /* relu*/ + "smax v11.4s, v11.4s, v21.4s \n" /* relu*/ + + "bif v10.16b, v0.16b, v14.16b \n" + "bif v11.16b, v1.16b, v15.16b \n" + + "saddw v12.4s, v12.4s, v19.4h \n" /* v10 += outr00.low*/ + "saddw2 v13.4s, v13.4s, v19.8h \n" /* v11 += outr00.high*/ + + "smull v19.8h, %[v8].8b, v5.8b \n" /* outr00 = 01234567 * w00 + */ + + "stp q10, q11, [%[ptr_out0]], #32 \n" /* store q10, q11 -> + ptr_out */ + + "saddw v12.4s, v12.4s, v19.4h \n" /* v10 += outr00.low*/ + "saddw2 v13.4s, v13.4s, v19.8h \n" /* v11 += outr00.high*/ + + "smax v12.4s, v12.4s, v21.4s \n" /* relu*/ + "smax v13.4s, v13.4s, v21.4s \n" /* relu*/ + + "bif v12.16b, v2.16b, v14.16b \n" + "bif v13.16b, v3.16b, v15.16b \n" + + "stp q12, q13, [%[ptr_out1]], #32 \n" /* store q10, q11 -> + ptr_out */ + + : [cnt] "+r"(cnt), + [din_ptr0] "+r"(din_ptr0), + [din_ptr1] "+r"(din_ptr1), + [din_ptr2] "+r"(din_ptr2), + [din_ptr3] "+r"(din_ptr3), + [ptr_out0] "+r"(doutr0), + [ptr_out1] "+r"(doutr1), + [vmask] "+r"(val_mask), + [rmask] "+r"(rst_mask) + : [v0] "w"(wr00), + [v1] "w"(wr01), + [v2] "w"(wr02), + [v3] "w"(wr10), + [bias_val] "r"(vbias), + [v4] "w"(wr11), + [v5] "w"(wr12), + [v6] "w"(wr20), + [v7] "w"(wr21), + [v8] "w"(wr22) + : "cc", + "memory", + "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15", + "v16", + "v17", + "v18", + "v19", + "v20", + "v21", + "v22"); +#else + // store weights + asm volatile("vld1.8 {d0-d1}, [%[wei_ptr]] \n" + : + : [wei_ptr] "r"(wei_ptr) + : "memory"); + asm volatile( + // left + "pld [%[din_ptr0]] @ preload data\n" + "pld [%[din_ptr1]] @ preload data\n" + "pld [%[din_ptr2]] @ preload data\n" + "pld [%[din_ptr3]] @ preload data\n" + "vdup.s8 d2, d0[0] @ d2 = w00, w00, w00, w00\n" + "vdup.s8 d3, d0[1] @ d3 = w01, w01, w01, w01\n" + "vdup.s8 d4, d0[2] @ d4 = w02, w02, w02, w02\n" + "vld1.8 {d12-d13}, [%[din_ptr0]] @ load din00= 0 1 2 3 4 5 6 7 " + "8 9\n" + "vmov.u32 d11, #0 @ zero\n" + // out0 + "vdup.32 q8, %[bias] @ and \n" // q8 = + // vbias + "vdup.32 q9, %[bias] @ and \n" // q9 = + // vbias + // out1 + "vdup.32 q10, %[bias] @ and \n" // q8 = + // vbias + "vdup.32 q11, %[bias] @ and \n" // q9 = + // vbias + + // r0 + "vmull.s8 q12, d12, d3 @ out0 = din0 * w01 \n" // q12 = d12 * w01 + "vext.8 d30, d11, d12, #7 @ ext \n" // d10 = 00123456 + "vext.8 d31, d12, d13, #1 @ ext \n" // d11 = 12345678 + + "vld1.8 {d12-d13}, [%[din_ptr1]] @ load din00= 0 1 2 3 4 5 6 7 " + "8 9\n" + "vld1.8 {d14-d15}, [%[din_ptr2]] @ load din00= 0 1 2 3 4 5 6 7 " + "8 9\n" + "vdup.s8 d5, d0[3] @ d5 = w10, w10, w00, w00\n" + "vdup.s8 d6, d0[4] @ d6 = w11, w11, w01, w01\n" + + "vmlal.s8 q12, d30, d2 @ out0 += din0 * w00 \n" // q12 += d10 * w00 + + "vdup.s8 d7, d0[5] @ d7 = w12, w12\n" + "add %[din_ptr0], #7 @add \n" + "add %[din_ptr1], #7 @add \n" + "vaddw.s16 q8, q8, d24 @addw \n" // out0 += + // vget_low_s16(out00) + "vaddw.s16 q9, q9, d25 @addw \n" // out0_1 += + // vget_high_s16(out00) + + "vmull.s8 q12, d31, d4 @ out0 += din0 * w02 \n" // q12 += d11 * w02 + + // r1 + "vext.8 d30, d11, d12, #7 @ ext \n" // d10 = 00123456 + "vext.8 d31, d12, d13, #1 @ ext \n" // d11 = 12345678 + "vmull.s8 q13, 
d12, d3 @ out1 = din1 * w01 \n" // q13 = d12 * w01 + + "vmlal.s8 q12, d12, d6 @ out0 = din1 * w11 \n" // q12 = d12 * w11 + + "vld1.8 {d12-d13}, [%[din_ptr3]] @ load din00= 0 1 2 3 4 5 6 7 " + "8 9\n" + "vdup.s8 d8, d0[6] @ d8 = w20, w00, w00, w00\n" + "vdup.s8 d9, d0[7] @ d9 = w21, w01, w01, w01\n" + "vdup.s8 d10, d1[0] @ d10 = w22, w02, w02, w02\n" + + "vaddw.s16 q8, q8, d24 @addw \n" // out0 += + // vget_low_s16(out00) + "vaddw.s16 q9, q9, d25 @addw \n" // out0_1 += + // vget_high_s16(out00) + + "vmlal.s8 q13, d30, d2 @ out1 += din1 * w00 \n" // q12 += d10 * w00 + "vmull.s8 q12, d30, d5 @ out0 += din1 * w10 \n" // q12 += d10 * w00 + + "add %[din_ptr2], #7 @add \n" + "add %[din_ptr3], #7 @add \n" + "vaddw.s16 q10, q10, d26 @addw \n" // out1 += + // vget_low_s16(out10) + "vaddw.s16 q11, q11, d27 @addw \n" // out1_1 += + // vget_high_s16(out10) + + "vmull.s8 q13, d31, d4 @ out1 += din1 * w02 \n" // q12 += d10 * w00 + "vmlal.s8 q12, d31, d7 @ out0 += din1 * w12 \n" // q12 += d10 * w00 + + // r2 + "vext.8 d30, d11, d14, #7 @ ext \n" // d10 = 00123456 + "vext.8 d31, d14, d15, #1 @ ext \n" // d11 = 12345678 + + "vaddw.s16 q8, q8, d24 @addw \n" // out0 += + // vget_low_s16(out00) + "vaddw.s16 q9, q9, d25 @addw \n" // out0_1 += + // vget_high_s16(out00) + + "vmlal.s8 q13, d14, d6 @ out1 = din2 * w11 \n" // q13 = d12 * w01 + "vmull.s8 q12, d14, d9 @ out1 = din2 * w21 \n" // q13 = d12 * w01 + + "vaddw.s16 q10, q10, d26 @addw \n" // out1 += + // vget_low_s16(out10) + "vaddw.s16 q11, q11, d27 @addw \n" // out1_1 += + // vget_high_s16(out10) + + "vmull.s8 q13, d30, d5 @ out1 += din2 * w10 \n" // q12 += d10 * w00 + "vmlal.s8 q12, d30, d8 @ out0 += din2 * w20 \n" // q12 += d10 * w00 + + "vaddw.s16 q8, q8, d24 @addw \n" // out0 += + // vget_low_s16(out00) + "vaddw.s16 q9, q9, d25 @addw \n" // out0_1 += + // vget_high_s16(out00) + + "vmlal.s8 q13, d31, d7 @ out1 += din2 * w12 \n" // q12 += d10 * w00 + "vmull.s8 q12, d31, d10 @ out0 += din2 * w22 \n" // q12 += d10 * w00 + + // r3 + "vext.8 d30, d11, d12, #7 @ ext \n" // d10 = 00123456 + "vext.8 d31, d12, d13, #1 @ ext \n" // d11 = 12345678 + "vmov.u32 q0, #0 @ mov \n" + "vaddw.s16 q10, q10, d26 @addw \n" // out1 += + // vget_low_s16(out10) + "vaddw.s16 q11, q11, d27 @addw \n" // out1_1 += + // vget_high_s16(out10) + "vaddw.s16 q8, q8, d24 @addw \n" // out0 += + // vget_low_s16(out00) + "vaddw.s16 q9, q9, d25 @addw \n" // out0_1 += + // vget_high_s16(out00) + + "vmull.s8 q13, d12, d9 @ out1 = din3 * w21 \n" // q13 = d12 * w01 + "pld [%[din_ptr0]] @ preload data\n" + "pld [%[din_ptr1]] @ preload data\n" + "vmax.s32 q8, q8, q0 @ max \n" + "vmax.s32 q9, q9, q0 @ max \n" + + "vmlal.s8 q13, d30, d8 @ out1 += din3 * w20 \n" // q13 += d10 * w00 + "pld [%[din_ptr2]] @ preload data\n" + "pld [%[din_ptr3]] @ preload data\n" + + "vst1.32 {d16-d17}, [%[dout_ptr1]]! @ store\n" + "vaddw.s16 q10, q10, d26 @addw \n" // out1 += + // vget_low_s16(out10) + "vaddw.s16 q11, q11, d27 @addw \n" // out1_1 += + // vget_high_s16(out10) + + "vmull.s8 q13, d31, d10 @ out1 += din3 * w22 \n" // q12 += d10 * w00 + + "vst1.32 {d18-d19}, [%[dout_ptr1]]! @ store\n" + "vaddw.s16 q10, q10, d26 @addw \n" // out1 += + // vget_low_s16(out10) + "vaddw.s16 q11, q11, d27 @addw \n" // out1_1 += + // vget_high_s16(out10) + + "vmax.s32 q10, q10, q0 @ max \n" + "vmax.s32 q11, q11, q0 @ max \n" + + "vst1.32 {d20-d21}, [%[dout_ptr2]]! @ store\n" + "cmp %[cnt], #1 \n" + "vst1.32 {d22-d23}, [%[dout_ptr2]]! 
@ store\n" + "blt 1f \n" + + // mid + "2: \n" + "vld1.8 {d12-d13}, [%[din_ptr0]] @ load din00= 0 1 2 3 4 5 6 7 " + "8 9\n" + // out0 + "vdup.32 q8, %[bias] @ and \n" // q8 = + // vbias + "vdup.32 q9, %[bias] @ and \n" // q9 = + // vbias + // out1 + "vdup.32 q10, %[bias] @ and \n" // q8 = + // vbias + "vdup.32 q11, %[bias] @ and \n" // q9 = + // vbias + + // r0 + "vmull.s8 q12, d12, d2 @ out0 = din0 * w01 \n" // q12 = d12 * w01 + "vext.8 d30, d12, d13, #1 @ ext \n" // d10 = 12345678 + "vext.8 d31, d12, d13, #2 @ ext \n" // d11 = 23456789 + + "vld1.8 {d12-d13}, [%[din_ptr1]] @ load din00= 0 1 2 3 4 5 6 7 " + "8 9\n" + "vld1.8 {d14-d15}, [%[din_ptr2]] @ load din00= 0 1 2 3 4 5 6 7 " + "8 9\n" + + "vmlal.s8 q12, d30, d3 @ out0 += din0 * w00 \n" // q12 += d10 * w00 + + "add %[din_ptr0], #8 @add \n" + "add %[din_ptr1], #8 @add \n" + "vaddw.s16 q8, q8, d24 @addw \n" // out0 += + // vget_low_s16(out00) + "vaddw.s16 q9, q9, d25 @addw \n" // out0_1 += + // vget_high_s16(out00) + + "vmull.s8 q12, d31, d4 @ out0 += din0 * w02 \n" // q12 += d11 * w02 + + // r1 + "vext.8 d30, d12, d13, #1 @ ext \n" // d10 = 00123456 + "vext.8 d31, d12, d13, #2 @ ext \n" // d11 = 12345678 + "vmull.s8 q13, d12, d2 @ out1 = din1 * w01 \n" // q13 = d12 * w01 + + "vmlal.s8 q12, d12, d5 @ out0 = din1 * w11 \n" // q12 = d12 * w11 + + "vld1.8 {d12-d13}, [%[din_ptr3]] @ load din00= 0 1 2 3 4 5 6 7 " + "8 9\n" + + "vmlal.s8 q13, d30, d3 @ out1 += din1 * w00 \n" // q12 += d10 * w00 + "vaddw.s16 q8, q8, d24 @addw \n" // out0 += + // vget_low_s16(out00) + "vaddw.s16 q9, q9, d25 @addw \n" // out0_1 += + // vget_high_s16(out00) + + "vmull.s8 q12, d30, d6 @ out0 += din1 * w10 \n" // q12 += d10 * w00 + + "add %[din_ptr2], #8 @add \n" + "add %[din_ptr3], #8 @add \n" + "vaddw.s16 q10, q10, d26 @addw \n" // out1 += + // vget_low_s16(out10) + "vaddw.s16 q11, q11, d27 @addw \n" // out1_1 += + // vget_high_s16(out10) + + "vmull.s8 q13, d31, d4 @ out1 += din1 * w02 \n" // q12 += d10 * w00 + "vmlal.s8 q12, d31, d7 @ out0 += din1 * w12 \n" // q12 += d10 * w00 + + // r2 + "vext.8 d30, d14, d15, #1 @ ext \n" // d10 = 00123456 + "vext.8 d31, d14, d15, #2 @ ext \n" // d11 = 12345678 + + "vaddw.s16 q8, q8, d24 @addw \n" // out0 += + // vget_low_s16(out00) + "vaddw.s16 q9, q9, d25 @addw \n" // out0_1 += + // vget_high_s16(out00) + + "vmlal.s8 q13, d14, d5 @ out1 = din2 * w11 \n" // q13 = d12 * w01 + "vmull.s8 q12, d14, d8 @ out1 = din2 * w21 \n" // q13 = d12 * w01 + + "vaddw.s16 q10, q10, d26 @addw \n" // out1 += + // vget_low_s16(out10) + "vaddw.s16 q11, q11, d27 @addw \n" // out1_1 += + // vget_high_s16(out10) + + "vmull.s8 q13, d30, d6 @ out1 += din2 * w10 \n" // q12 += d10 * w00 + "vmlal.s8 q12, d30, d9 @ out0 += din2 * w20 \n" // q12 += d10 * w00 + + "vmlal.s8 q13, d31, d7 @ out1 += din2 * w12 \n" // q12 += d10 * w00 + "vaddw.s16 q8, q8, d24 @addw \n" // out0 += + // vget_low_s16(out00) + "vaddw.s16 q9, q9, d25 @addw \n" // out0_1 += + // vget_high_s16(out00) + + "vmull.s8 q12, d31, d10 @ out0 += din2 * w22 \n" // q12 += d10 * w00 + + // r3 + "vext.8 d30, d12, d13, #1 @ ext \n" // d10 = 00123456 + "vext.8 d31, d12, d13, #2 @ ext \n" // d11 = 12345678 + "vaddw.s16 q10, q10, d26 @addw \n" // out1 += + // vget_low_s16(out10) + "vaddw.s16 q11, q11, d27 @addw \n" // out1_1 += + // vget_high_s16(out10) + "vaddw.s16 q8, q8, d24 @addw \n" // out0 += + // vget_low_s16(out00) + "vaddw.s16 q9, q9, d25 @addw \n" // out0_1 += + // vget_high_s16(out00) + + "vmull.s8 q13, d12, d8 @ out1 = din3 * w21 \n" // q13 = d12 * w01 + "pld [%[din_ptr0]] @ preload data\n" 
+ "pld [%[din_ptr1]] @ preload data\n" + "vmax.s32 q8, q8, q0 @ max \n" + "vmax.s32 q9, q9, q0 @ max \n" + + "vmlal.s8 q13, d30, d9 @ out1 += din3 * w20 \n" // q13 += d10 * w00 + "pld [%[din_ptr2]] @ preload data\n" + "pld [%[din_ptr3]] @ preload data\n" + + "vst1.32 {d16-d17}, [%[dout_ptr1]]! @ store\n" + "vaddw.s16 q10, q10, d26 @addw \n" // out1 += + // vget_low_s16(out10) + "vaddw.s16 q11, q11, d27 @addw \n" // out1_1 += + // vget_high_s16(out10) + + "vmull.s8 q13, d31, d10 @ out1 += din3 * w22 \n" // q12 += d10 * w00 + + "vst1.32 {d18-d19}, [%[dout_ptr1]]! @ store\n" + "vaddw.s16 q10, q10, d26 @addw \n" // out1 += + // vget_low_s16(out10) + "vaddw.s16 q11, q11, d27 @addw \n" // out1_1 += + // vget_high_s16(out10) + + "vmax.s32 q10, q10, q0 @ max \n" + "vmax.s32 q11, q11, q0 @ max \n" + + "vst1.32 {d20-d21}, [%[dout_ptr2]]! @ store\n" + "subs %[cnt], #1 \n" + "vst1.32 {d22-d23}, [%[dout_ptr2]]! @ store\n" + "bne 2b \n" + // right + "1: \n" + "vld1.8 {d12-d13}, [%[din_ptr0]] @ load din00= 0 1 2 3 4 5 6 7 " + "8 9\n" + "vld1.8 {d28-d29}, [%[mask]] @ load din00= 0 1 2 3 4 5 6 7 " + "8 9\n" + // out0 + "vdup.32 q8, %[bias] @ and \n" // q8 = vbias + "vdup.32 q9, %[bias] @ and \n" // q9 = vbias + // out1 + "vdup.32 q10, %[bias] @ and \n" // q8 = vbias + "vdup.32 q11, %[bias] @ and \n" // q9 = vbias + + "vbif.8 d12, d11, d28 @ bit select, deal with right pad\n" + "vbif.8 d13, d11, d29 @ bit select, deal with right pad\n" + "vld1.8 {d14-d15}, [%[din_ptr1]] @ load din00= 0 1 2 3 4 5 6 7 " + "8 9\n" + + // r0 + "vmull.s8 q12, d12, d2 @ out0 = din0 * w00 \n" // q12 = d12 * w01 + "vext.8 d30, d12, d13, #1 @ ext \n" // d10 = 12345678 + "vext.8 d31, d12, d13, #2 @ ext \n" // d11 = 23456789 + + "vld1.8 {d12-d13}, [%[din_ptr2]] @ load din00= 0 1 2 3 4 5 6 7 " + "8 9\n" + "vbif.8 d14, d11, d28 @ bit select, deal with right pad\n" + "vbif.8 d15, d11, d29 @ bit select, deal with right pad\n" + + "vmlal.s8 q12, d30, d3 @ out0 += din0 * w01 \n" // q12 += d10 * w00 + + "vaddw.s16 q8, q8, d24 @addw \n" // out0 += + // vget_low_s16(out00) + "vaddw.s16 q9, q9, d25 @addw \n" // out0_1 += + // vget_high_s16(out00) + + "vmull.s8 q12, d31, d4 @ out0 += din0 * w02 \n" // q12 += d11 * w02 + + // r1 + "vext.8 d30, d14, d15, #1 @ ext \n" // d10 = 00123456 + "vext.8 d31, d14, d15, #2 @ ext \n" // d11 = 12345678 + + "vmull.s8 q13, d14, d2 @ out1 = din1 * w00 \n" // q13 = d12 * w01 + + "vmlal.s8 q12, d14, d5 @ out0 = din1 * w10 \n" // q12 = d12 * w11 + + "vld1.8 {d14-d15}, [%[din_ptr3]] @ load din00= 0 1 2 3 4 5 6 7 " + "8 9\n" + "vbif.8 d12, d11, d28 @ bit select, deal with " + "right pad\n" + "vbif.8 d13, d11, d29 @ bit select, deal with " + "right pad\n" + + "vaddw.s16 q8, q8, d24 @addw \n" // out0 += + // vget_low_s16(out00) + "vaddw.s16 q9, q9, d25 @addw \n" // out0_1 += + // vget_high_s16(out00) + + "vmlal.s8 q13, d30, d3 @ out1 += din1 * w01 \n" // q12 += d10 * w00 + "vmull.s8 q12, d30, d6 @ out0 += din1 * w11 \n" // q12 += d10 * w00 + + "vaddw.s16 q10, q10, d26 @addw \n" // out1 += + // vget_low_s16(out10) + "vaddw.s16 q11, q11, d27 @addw \n" // out1_1 += + // vget_high_s16(out10) + + "vmull.s8 q13, d31, d4 @ out1 += din1 * w02 \n" // q12 += d10 * w00 + "vmlal.s8 q12, d31, d7 @ out0 += din1 * w12 \n" // q12 += d10 * w00 + + // r2 + "vext.8 d30, d12, d13, #1 @ ext \n" // d10 = 00123456 + "vext.8 d31, d12, d13, #2 @ ext \n" // d11 = 12345678 + + "vaddw.s16 q8, q8, d24 @addw \n" // out0 += + // vget_low_s16(out00) + "vaddw.s16 q9, q9, d25 @addw \n" // out0_1 += + // vget_high_s16(out00) + + "vmlal.s8 q13, d12, d5 @ 
out1 = din2 * w10 \n" // q13 = d12 * w01 + "vmull.s8 q12, d12, d8 @ out1 = din2 * w20 \n" // q13 = d12 * w01 + + "vbif.8 d14, d11, d28 @ bit select, deal with " + "right pad\n" + "vbif.8 d15, d11, d29 @ bit select, deal with " + "right pad\n" + + "vaddw.s16 q10, q10, d26 @addw \n" // out1 += + // vget_low_s16(out10) + "vaddw.s16 q11, q11, d27 @addw \n" // out1_1 += + // vget_high_s16(out10) + + "vmull.s8 q13, d30, d6 @ out1 += din2 * w10 \n" // q12 += d10 * w00 + "vmlal.s8 q12, d30, d9 @ out0 += din2 * w20 \n" // q12 += d10 * w00 + + "vld1.32 {d28-d29}, [%[dout_ptr1]]! @ load din00= 0 1 2 3 4 5 6 " + "7 8 9\n" + "vld1.32 {d12-d13}, [%[dout_ptr1]] @ load din00= 0 1 2 3 4 5 6 " + "7 8 9\n" + "vld1.32 {d2-d3}, [%[rs_mask]]! @ load din00= 0 1 2 3 4 5 6 7 " + "8 9\n" + "vld1.32 {d4-d5}, [%[rs_mask]] @ load din00= 0 1 2 3 4 5 6 7 8 " + "9\n" + "vaddw.s16 q8, q8, d24 @addw \n" // out0 += + // vget_low_s16(out00) + "vaddw.s16 q9, q9, d25 @addw \n" // out0_1 += + // vget_high_s16(out00) + + "vmlal.s8 q13, d31, d7 @ out1 += din2 * w12 \n" // q12 += d10 * w00 + "vmull.s8 q12, d31, d10 @ out0 += din2 * w22 \n" // q12 += d10 * w00 + + // r3 + "vext.8 d30, d14, d15, #1 @ ext \n" // d10 = 00123456 + "vext.8 d31, d14, d15, #2 @ ext \n" // d11 = 12345678 + "vaddw.s16 q10, q10, d26 @addw \n" // out1 += + // vget_low_s16(out10) + "vaddw.s16 q11, q11, d27 @addw \n" // out1_1 += + // vget_high_s16(out10) + "vaddw.s16 q8, q8, d24 @addw \n" // out0 += + // vget_low_s16(out00) + "vaddw.s16 q9, q9, d25 @addw \n" // out0_1 += + // vget_high_s16(out00) + + "vmull.s8 q13, d14, d8 @ out1 = din3 * w20 \n" // q13 = d12 * w01 + "vld1.32 {d14-d15}, [%[dout_ptr2]]! @ load din00= 0 1 2 3 4 5 6 " + "7 8 9\n" + "vld1.32 {d24-d25}, [%[dout_ptr2]] @ load din00= 0 1 2 3 4 5 6 " + "7 8 9\n" + "vmax.s32 q8, q8, q0 @ max \n" + "vmax.s32 q9, q9, q0 @ max \n" + + "vmlal.s8 q13, d30, d9 @ out1 += din3 * w21 \n" // q13 += d10 * w00 + "vbif q8, q14, q1 @ bit select, deal with right " + "pad\n" + "vbif q9, q6, q2 @ bit select, deal with right " + "pad\n" + "sub %[dout_ptr1], #16 @ sub \n" + "sub %[dout_ptr2], #16 @ sub \n" + + "vaddw.s16 q10, q10, d26 @addw \n" // out1 += + // vget_low_s16(out10) + "vaddw.s16 q11, q11, d27 @addw \n" // out1_1 += + // vget_high_s16(out10) + + "vmull.s8 q13, d31, d10 @ out1 += din3 * w22 \n" // q12 += d10 * w00 + + "vst1.32 {d16-d17}, [%[dout_ptr1]]! @ store\n" + "vst1.32 {d18-d19}, [%[dout_ptr1]]! @ store\n" + "vaddw.s16 q10, q10, d26 @addw \n" // out1 += + // vget_low_s16(out10) + "vaddw.s16 q11, q11, d27 @addw \n" // out1_1 += + // vget_high_s16(out10) + + "vmax.s32 q10, q10, q0 @ max \n" + "vmax.s32 q11, q11, q0 @ max \n" + + "vbif q10, q7, q1 @ bit select, deal with right pad\n" + "vbif q11, q12, q2 @ bit select, deal with right pad\n" + + "vst1.32 {d20-d21}, [%[dout_ptr2]]! @ store\n" + "vst1.32 {d22-d23}, [%[dout_ptr2]]! 
@ store\n"
+
+      : [din_ptr0] "+r"(din_ptr0),
+        [din_ptr1] "+r"(din_ptr1),
+        [din_ptr2] "+r"(din_ptr2),
+        [din_ptr3] "+r"(din_ptr3),
+        [dout_ptr1] "+r"(doutr0),
+        [dout_ptr2] "+r"(doutr1),
+        [cnt] "+r"(cnt),
+        [bias] "+r"(bias_val),
+        [rs_mask] "+r"(rst_mask)
+      : [mask] "r"(vmask)
+      : "cc",
+        "memory",
+        "q0",
+        "q1",
+        "q2",
+        "q3",
+        "q4",
+        "q5",
+        "q6",
+        "q7",
+        "q8",
+        "q9",
+        "q10",
+        "q11",
+        "q12",
+        "q13",
+        "q14",
+        "q15");
+#endif
+        dout_ptr += 2 * w_out;
+      }
+    }
+  }
+}
+// w_in <= 8
+void conv_depthwise_3x3s1p1_bias_s_relu_int8(int* dout,
+                                             const signed char* din,
+                                             const signed char* weights,
+                                             const int* bias,
+                                             bool flag_bias,
+                                             const int num,
+                                             const int ch_in,
+                                             const int h_in,
+                                             const int w_in,
+                                             const int h_out,
+                                             const int w_out,
+                                             ARMContext* ctx) {
+  //! pad is done implicit
+  const char zero[8] = {0, 0, 0, 0, 0, 0, 0, 0};
+  //! for 4x6 convolution window
+  const unsigned char right_pad_idx[8] = {0, 1, 2, 3, 4, 5, 6, 7};
+  const unsigned int right_pad_rst[8] = {0, 1, 2, 3, 4, 5, 6, 7};
+
+  // printf("conv3x3_dw start \n");
+  signed char* zero_ptr = ctx->workspace_data<signed char>();
+  memset(zero_ptr, 0, w_in * sizeof(signed char));
+  int* write_ptr =
+      reinterpret_cast<int*>(ctx->workspace_data<signed char>()) + w_in;
+  int size_in_channel = w_in * h_in;
+  int size_out_channel = w_out * h_out;
+  int w_stride = 9;
+
+  int tile_h = (h_out + 3) >> 2;
+
+  unsigned int size_pad_right = (unsigned int)(w_in);
+
+  int size_pad_bottom = h_out % 4;
+
+  uint8x8_t vmask_rp =
+      vcgt_u8(vdup_n_u8(size_pad_right), vld1_u8(right_pad_idx));
+  unsigned int rst_remain = (unsigned int)w_out;
+  uint32x4_t vmask_result1 =
+      vcgtq_u32(vdupq_n_u32(rst_remain), vld1q_u32(right_pad_rst));
+  uint32x4_t vmask_result2 =
+      vcgtq_u32(vdupq_n_u32(rst_remain), vld1q_u32(right_pad_rst + 4));
+
+  unsigned char vmask[8];
+  vst1_u8(vmask, vmask_rp);
+
+  unsigned int rmask[8];
+  vst1q_u32(rmask, vmask_result1);
+  vst1q_u32(rmask + 4, vmask_result2);
+
+  int8x8_t vzero = vdup_n_s8(0);
+  int32x4_t vzero_32 = vdupq_n_s32(0);
+
+  for (int n = 0; n < num; ++n) {
+    const signed char* din_batch = din + n * ch_in * size_in_channel;
+    int* dout_batch = dout + n * ch_in * size_out_channel;
+#pragma omp parallel for
+    for (int c = 0; c < ch_in; c++) {
+      int* dout_ptr = dout_batch + c * size_out_channel;
+
+      const signed char* din_ch_ptr = din_batch + c * size_in_channel;
+
+      int bias_val = flag_bias ? bias[c] : 0;
+
+      const signed char* wei_ptr = weights + c * w_stride;
+#ifdef __aarch64__
+      int vbias[4] = {bias_val, bias_val, bias_val, bias_val};
+      int8x8_t wr00 = vdup_n_s8(wei_ptr[0]);
+      int8x8_t wr10 = vdup_n_s8(wei_ptr[3]);
+      int8x8_t wr20 = vdup_n_s8(wei_ptr[6]);
+
+      int8x8_t wr01 = vdup_n_s8(wei_ptr[1]);
+      int8x8_t wr11 = vdup_n_s8(wei_ptr[4]);
+      int8x8_t wr21 = vdup_n_s8(wei_ptr[7]);
+
+      int8x8_t wr02 = vdup_n_s8(wei_ptr[2]);
+      int8x8_t wr12 = vdup_n_s8(wei_ptr[5]);
+      int8x8_t wr22 = vdup_n_s8(wei_ptr[8]);
+#endif
+
+      int* doutr0 = nullptr;
+      int* doutr1 = nullptr;
+
+      const signed char* dr0 = din_ch_ptr;
+      const signed char* dr1 = dr0 + w_in;
+      const signed char* dr2 = dr1 + w_in;
+      const signed char* dr3 = dr2 + w_in;
+
+      const signed char* din_ptr0 = nullptr;
+      const signed char* din_ptr1 = nullptr;
+      const signed char* din_ptr2 = nullptr;
+      const signed char* din_ptr3 = nullptr;
+
+      for (int i = 0; i < h_in; i += 2) {
+        //!
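+        // NOTE (illustrative sketch, not part of the original patch): in
+        // this w_in <= 8 path the whole input row fits in one 8-byte
+        // vector, so the right-pad mask prepared above behaves like this
+        // hypothetical scalar loop:
+        //   for (int lane = 0; lane < 8; ++lane) {
+        //     vmask[lane] = (lane < (int)w_in) ? 0xFF : 0x00;
+        //   }
+        // the vbif/bif selects below keep input bytes where the mask is
+        // set and substitute zeros over the right padding.
+        //!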
process top pad pad_h = 1 + din_ptr0 = dr0; + din_ptr1 = dr1; + din_ptr2 = dr2; + din_ptr3 = dr3; + + doutr0 = dout_ptr; + doutr1 = doutr0 + w_out; + int out_buf1[8]; + int out_buf2[8]; + int trash_buf[8]; + + unsigned int* rst_mask = rmask; + unsigned char* val_mask = vmask; + + if (i == 0) { + din_ptr0 = zero_ptr; + din_ptr1 = dr0; + din_ptr2 = dr1; + din_ptr3 = dr2; + dr0 = dr1; + dr1 = dr2; + dr2 = dr3; + dr3 = dr2 + w_in; + } else { + dr0 = dr2; + dr1 = dr3; + dr2 = dr1 + w_in; + dr3 = dr2 + w_in; + } + //! process bottom pad + if (i + 3 > h_in) { + switch (i + 3 - h_in) { + case 3: + din_ptr1 = zero_ptr; + case 2: + din_ptr2 = zero_ptr; + case 1: + din_ptr3 = zero_ptr; + default: + break; + } + } + //! process bottom remain + if (i + 2 > h_out) { + doutr1 = trash_buf; + } +#ifdef __aarch64__ + asm volatile( + "PRFM PLDL1KEEP, [%[din_ptr0]] \n" + "PRFM PLDL1KEEP, [%[din_ptr1]] \n" + "PRFM PLDL1KEEP, [%[din_ptr2]] \n" + "PRFM PLDL1KEEP, [%[din_ptr3]] \n" + "movi v21.4s, #0x0\n" /* out0 = 0 */ + // left + "ld1 {v4.8b}, [%[vmask]] \n" + "ld1 {v0.8b}, [%[din_ptr0]], #8 \n" /* load + a00-a015 + to + q0*/ + "ld1 {v1.8b}, [%[din_ptr1]], #8 \n" /* load + a00-a015 + to + q0*/ + "ld1 {v2.8b}, [%[din_ptr2]], #8 \n" /* load + a00-a015 + to + q0*/ + "ld1 {v3.8b}, [%[din_ptr3]], #8 \n" /* load + a00-a015 + to + q0*/ + + "bif v0.8b, v21.8b, v4.8b \n" + "bif v1.8b, v21.8b, v4.8b \n" + "bif v2.8b, v21.8b, v4.8b \n" + "bif v3.8b, v21.8b, v4.8b \n" + + "ext v6.8b, v21.8b, v0.8B, #7 \n" /* vext_s8(vzero, vinr0, 7); + 00123456 */ + "ext v7.8b, v0.8b, v21.8B, #1 \n" /* vext_s8(vinr0, vinr0_1, + 1); 12345678 */ + + "ld1 {v10.4s}, [%[vbias]] \n" + "ld1 {v11.4s}, [%[vbias]] \n" + + // r0 + "smull v18.8h, %[v1].8b, v0.8b \n" /* outr00 = 01234567 * w01 + */ + + "ext v8.8b, v21.8b, v1.8B, #7 \n" /* vext_s8(vzero, vinr0, 7); + 00123456 */ + "ext v9.8b, v1.8b, v21.8B, #1 \n" /* vext_s8(vinr0, vinr0_1, + 1); 12345678 */ + + "smlal v18.8h, %[v0].8b, v6.8b \n" /* outr00 = 01234567 * w00 + */ + + "ld1 {v12.4s}, [%[vbias]] \n" + "ld1 {v13.4s}, [%[vbias]] \n" + "saddw v10.4s, v10.4s, v18.4h \n" /* v10 += outr00.low*/ + "saddw2 v11.4s, v11.4s, v18.8h \n" /* v11 += outr00.high*/ + + "smull v18.8h, %[v2].8b, v7.8b \n" /* outr00 = 01234567 * w00 + */ + + "ext v6.8b, v21.8b, v2.8B, #7 \n" /* vext_s8(vzero, vinr0, 7); + 00123456 */ + "ext v7.8b, v2.8b, v21.8B, #1 \n" /* vext_s8(vinr0, vinr0_1, + 1); 12345678 */ + + // r1 + "smull v19.8h, %[v1].8b, v1.8b \n" /* outr00 = 01234567 * w00 + */ + "smlal v18.8h, %[v4].8b, v1.8b \n" /* outr00 = 01234567 * w00 + */ + + // "ld1 {v14.4s}, [%[rmask]], #16 \n" + // "ld1 {v15.4s}, [%[rmask]] \n" + + "smlal v19.8h, %[v0].8b, v8.8b \n" /* outr00 = 01234567 * w00 + */ + + "saddw v10.4s, v10.4s, v18.4h \n" /* v10 += outr00.low*/ + "saddw2 v11.4s, v11.4s, v18.8h \n" /* v11 += outr00.high*/ + + "smull v18.8h, %[v3].8b, v8.8b \n" /* outr00 = 01234567 * w00 + */ + + // "ld1 {v16.4s}, [%[ptr_out0]], #16 \n" + // "ld1 {v17.4s}, [%[ptr_out1]], #16 \n" + "saddw v12.4s, v12.4s, v19.4h \n" /* v10 += outr00.low*/ + "saddw2 v13.4s, v13.4s, v19.8h \n" /* v11 += outr00.high*/ + + "smull v19.8h, %[v2].8b, v9.8b \n" /* outr00 = 01234567 * w00 + */ + "smlal v18.8h, %[v5].8b, v9.8b \n" /* outr00 = 01234567 * w00 + */ + + "ext v8.8b, v21.8b, v3.8B, #7 \n" /* vext_s8(vzero, vinr0, 7); + 00123456 */ + "ext v9.8b, v3.8b, v21.8B, #1 \n" // vext_s8(vinr0, vinr0_1, + // 1); 12345678 + + // "ld1 {v0.4s}, [%[ptr_out0]] \n" + // "ld1 {v1.4s}, [%[ptr_out1]] \n" + + // r2 + "smlal v19.8h, %[v4].8b, v2.8b \n" /* outr00 = 
01234567 * w00 + */ + + "saddw v10.4s, v10.4s, v18.4h \n" /* v10 += outr00.low*/ + "saddw2 v11.4s, v11.4s, v18.8h \n" /* v11 += outr00.high*/ + + "smull v18.8h, %[v7].8b, v2.8b \n" /* outr00 = 01234567 * w00 + */ + + // "sub %[ptr_out0], %[ptr_out0], #16 \n" + // "sub %[ptr_out1], %[ptr_out1], #16 \n" + "saddw v12.4s, v12.4s, v19.4h \n" /* v10 += outr00.low*/ + "saddw2 v13.4s, v13.4s, v19.8h \n" /* v11 += outr00.high*/ + + "smull v19.8h, %[v3].8b, v6.8b \n" /* outr00 = 01234567 * w00 + */ + "smlal v18.8h, %[v6].8b, v6.8b \n" /* outr00 = 01234567 * w00 + */ + + "smlal v19.8h, %[v5].8b, v7.8b \n" /* outr00 = 01234567 * w00 + */ + + "saddw v10.4s, v10.4s, v18.4h \n" /* v10 += outr00.low*/ + "saddw2 v11.4s, v11.4s, v18.8h \n" /* v11 += outr00.high*/ + + "smull v18.8h, %[v8].8b, v7.8b \n" /* outr00 = 01234567 * w00 + */ + "saddw v12.4s, v12.4s, v19.4h \n" /* v10 += outr00.low*/ + "saddw2 v13.4s, v13.4s, v19.8h \n" /* v11 += outr00.high*/ + + // r3 + "smull v19.8h, %[v7].8b, v3.8b \n" /* outr00 = 01234567 * w00 + */ + + "saddw v10.4s, v10.4s, v18.4h \n" /* v10 += outr00.low*/ + "saddw2 v11.4s, v11.4s, v18.8h \n" /* v11 += outr00.high*/ + + "smlal v19.8h, %[v6].8b, v8.8b \n" /* outr00 = 01234567 * w00 + */ + + "smax v10.4s, v10.4s, v21.4s \n" /* relu */ + "smax v11.4s, v11.4s, v21.4s \n" /* relu */ + + // "bif v10.16b, v16.16b, v14.16b \n" + // "bif v11.16b, v0.16b, v15.16b \n" + + "saddw v12.4s, v12.4s, v19.4h \n" /* v10 += outr00.low*/ + "saddw2 v13.4s, v13.4s, v19.8h \n" /* v11 += outr00.high*/ + + "smull v19.8h, %[v8].8b, v9.8b \n" /* outr00 = 01234567 * w00 + */ + + "stp q10, q11, [%[ptr_out0]] \n" /* store q10, q11 -> ptr_out */ + + "saddw v12.4s, v12.4s, v19.4h \n" /* v10 += outr00.low*/ + "saddw2 v13.4s, v13.4s, v19.8h \n" /* v11 += outr00.high*/ + + "smax v12.4s, v12.4s, v21.4s \n" /* relu */ + "smax v13.4s, v13.4s, v21.4s \n" /* relu */ + + // "bif v12.16b, v17.16b, v14.16b \n" + // "bif v13.16b, v1.16b, v15.16b \n" + + "stp q12, q13, [%[ptr_out1]] \n" /* store q10, q11 -> ptr_out */ + + : [din_ptr0] "+r"(din_ptr0), + [din_ptr1] "+r"(din_ptr1), + [din_ptr2] "+r"(din_ptr2), + [din_ptr3] "+r"(din_ptr3), + [rmask] "+r"(rst_mask) + : [v0] "w"(wr00), + [v1] "w"(wr01), + [v2] "w"(wr02), + [v3] "w"(wr10), + [vbias] "r"(vbias), + [v4] "w"(wr11), + [v5] "w"(wr12), + [v6] "w"(wr20), + [v7] "w"(wr21), + [v8] "w"(wr22), + [vmask] "r"(vmask), + [ptr_out0] "r"(out_buf1), + [ptr_out1] "r"(out_buf2) + : "cc", + "memory", + "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15", + "v16", + "v17", + "v18", + "v19", + "v20", + "v21", + "v22"); +#else + // store weights + asm volatile("vld1.8 {d0-d1}, [%[wei_ptr]] \n" + : + : [wei_ptr] "r"(wei_ptr) + : "memory"); + asm volatile( + // left + "pld [%[din_ptr0]] @ preload data\n" + "pld [%[din_ptr1]] @ preload data\n" + "pld [%[din_ptr2]] @ preload data\n" + "pld [%[din_ptr3]] @ preload data\n" + "vld1.8 {d28}, [%[mask]] @ load din00= 0 1 2 3 4 5 6 7 8 9\n" + "vld1.8 {d12}, [%[din_ptr0]] @ load din00= 0 1 2 3 4 5 6 7 8 9\n" + "vld1.8 {d13}, [%[din_ptr1]] @ load din00= 0 1 2 3 4 5 6 7 8 9\n" + "vdup.s8 d2, d0[0] @ d2 = w00, w00, w00, w00\n" + "vdup.s8 d3, d0[1] @ d3 = w01, w01, w01, w01\n" + "vdup.s8 d4, d0[2] @ d4 = w02, w02, w02, w02\n" + + "vmov.u32 d11, #0 @ zero\n" + // out0 + "vdup.32 q8, %[bias] @ and \n" // q8 = + // vbias + "vdup.32 q9, %[bias] @ and \n" // q9 = + // vbias + + "vbif.8 d12, d11, d28 @ bit select, deal with right pad\n" + "vbif.8 d13, d11, d28 @ bit select, deal 
with right pad\n" + "vld1.8 {d14}, [%[din_ptr2]] @ load din00= 0 1 2 3 4 5 6 7 8 9\n" + "vld1.8 {d15}, [%[din_ptr3]] @ load din00= 0 1 2 3 4 5 6 7 8 9\n" + // out1 + "vdup.32 q10, %[bias] @ and \n" // q8 = + // vbias + "vdup.32 q11, %[bias] @ and \n" // q9 = + // vbias + + // r0 + "vmull.s8 q12, d12, d3 @ out0 = din0 * w01 \n" // q12 = d12 * w01 + "vext.8 d30, d11, d12, #7 @ ext \n" // d10 = 00123456 + "vext.8 d31, d12, d11, #1 @ ext \n" // d11 = 12345678 + + "vdup.s8 d5, d0[3] @ d5 = w10, w10, w00, w00\n" + "vdup.s8 d6, d0[4] @ d6 = w11, w11, w01, w01\n" + + "vmlal.s8 q12, d30, d2 @ out0 += din0 * w00 \n" // q12 += d10 * w00 + + "vdup.s8 d7, d0[5] @ d7 = w12, w12\n" + "vbif.8 d14, d11, d28 @ bit select, deal with right pad\n" + "vbif.8 d15, d11, d28 @ bit select, deal with right pad\n" + "vaddw.s16 q8, q8, d24 @addw \n" // out0 += + // vget_low_s16(out00) + "vaddw.s16 q9, q9, d25 @addw \n" // out0_1 += + // vget_high_s16(out00) + + "vmull.s8 q12, d31, d4 @ out0 += din0 * w02 \n" // q12 += d11 * w02 + + // r1 + "vext.8 d30, d11, d13, #7 @ ext \n" // d10 = 00123456 + "vext.8 d31, d13, d11, #1 @ ext \n" // d11 = 12345678 + "vmull.s8 q13, d13, d3 @ out1 = din1 * w01 \n" // q13 = d12 * w01 + + "vmlal.s8 q12, d13, d6 @ out0 = din1 * w11 \n" // q12 = d12 * w11 + + "vdup.s8 d8, d0[6] @ d8 = w20, w00, w00, w00\n" + "vdup.s8 d9, d0[7] @ d9 = w21, w01, w01, w01\n" + + "vaddw.s16 q8, q8, d24 @addw \n" // out0 += + // vget_low_s16(out00) + "vaddw.s16 q9, q9, d25 @addw \n" // out0_1 += + // vget_high_s16(out00) + + "vmlal.s8 q13, d30, d2 @ out1 += din1 * w00 \n" // q12 += d10 * w00 + "vmull.s8 q12, d30, d5 @ out0 += din1 * w10 \n" // q12 += d10 * w00 + + "vdup.s8 d10, d1[0] @ d10 = w22, w02, w02, w02\n" + // "vld1.32 {d28-d29}, [%[dout_ptr1]]! @ load din00= 0 1 2 3 4 5 + // 6 7 8 9\n" "vld1.32 {d12-d13}, [%[dout_ptr1]] @ load din00= 0 + // 1 2 3 4 5 6 7 8 9\n" + + "vaddw.s16 q10, q10, d26 @addw \n" // out1 += + // vget_low_s16(out10) + "vaddw.s16 q11, q11, d27 @addw \n" // out1_1 += + // vget_high_s16(out10) + + "vmull.s8 q13, d31, d4 @ out1 += din1 * w02 \n" // q12 += d10 * w00 + "vmlal.s8 q12, d31, d7 @ out0 += din1 * w12 \n" // q12 += d10 * w00 + + // r2 + "vext.8 d30, d11, d14, #7 @ ext \n" // d10 = 00123456 + "vext.8 d31, d14, d11, #1 @ ext \n" // d11 = 12345678 + "vaddw.s16 q8, q8, d24 @addw \n" // out0 += + // vget_low_s16(out00) + "vaddw.s16 q9, q9, d25 @addw \n" // out0_1 += + // vget_high_s16(out00) + + "vmlal.s8 q13, d14, d6 @ out1 = din2 * w11 \n" // q13 = d12 * w01 + "vmull.s8 q12, d14, d9 @ out1 = din2 * w21 \n" // q13 = d12 * w01 + + // "sub %[dout_ptr1], #16 @ sub \n" + + "vaddw.s16 q10, q10, d26 @addw \n" // out1 += + // vget_low_s16(out10) + "vaddw.s16 q11, q11, d27 @addw \n" // out1_1 += + // vget_high_s16(out10) + + "vmull.s8 q13, d30, d5 @ out1 += din2 * w10 \n" // q12 += d10 * w00 + "vmlal.s8 q12, d30, d8 @ out0 += din2 * w20 \n" // q12 += d10 * w00 + + // "vld1.32 {d2-d3}, [%[rs_mask]]! 
@ load din00= 0 1 2 3 4 5 6 7 + // 8 9\n" "vld1.32 {d4-d5}, [%[rs_mask]] @ load din00= 0 1 2 3 4 + // 5 6 7 8 9\n" + "vaddw.s16 q8, q8, d24 @addw \n" // out0 += + // vget_low_s16(out00) + "vaddw.s16 q9, q9, d25 @addw \n" // out0_1 += + // vget_high_s16(out00) + + "vmlal.s8 q13, d31, d7 @ out1 += din2 * w12 \n" // q12 += d10 * w00 + "vmull.s8 q12, d31, d10 @ out0 += din2 * w22 \n" // q12 += d10 * w00 + + // r3 + "vext.8 d30, d11, d15, #7 @ ext \n" // d10 = 00123456 + "vext.8 d31, d15, d11, #1 @ ext \n" // d11 = 12345678 + "vaddw.s16 q10, q10, d26 @addw \n" // out1 += + // vget_low_s16(out10) + "vaddw.s16 q11, q11, d27 @addw \n" // out1_1 += + // vget_high_s16(out10) + "vaddw.s16 q8, q8, d24 @addw \n" // out0 += + // vget_low_s16(out00) + "vaddw.s16 q9, q9, d25 @addw \n" // out0_1 += + // vget_high_s16(out00) + + "vmull.s8 q13, d15, d9 @ out1 = din3 * w21 \n" // q13 = d12 * w01 + + "vmov.u32 q0, #0 @ zero\n" + + // "vld1.32 {d6-d7}, [%[dout_ptr2]]! @ load din00= 0 1 2 3 4 5 6 + // 7 8 9\n" "vld1.32 {d14-d15}, [%[dout_ptr2]] @ load din00= 0 1 + // 2 3 4 5 6 7 8 9\n" + + "vmlal.s8 q13, d30, d8 @ out1 += din3 * w20 \n" // q13 += d10 * w00 + + "vmax.s32 q8, q8, q0 @ max \n" + "vmax.s32 q9, q9, q0 @ max \n" + + "vaddw.s16 q10, q10, d26 @addw \n" // out1 += + // vget_low_s16(out10) + "vaddw.s16 q11, q11, d27 @addw \n" // out1_1 += + // vget_high_s16(out10) + + "vmull.s8 q13, d31, d10 @ out1 += din3 * w22 \n" // q12 += d10 * w00 + + // "sub %[dout_ptr2], #16 @ sub \n" + // "vbif q8, q14, q1 @ bit select, deal with right + // pad\n" "vbif q9, q6, q2 @ bit select, deal + // with right pad\n" + + "vaddw.s16 q10, q10, d26 @addw \n" // out1 += + // vget_low_s16(out10) + "vaddw.s16 q11, q11, d27 @addw \n" // out1_1 += + // vget_high_s16(out10) + + "vst1.32 {d16-d19}, [%[dout_ptr1]] @ store\n" + // "vst1.32 {d18-d19}, [%[dout_ptr1]]! @ store\n" + + "vmax.s32 q10, q10, q0 @ max \n" + "vmax.s32 q11, q11, q0 @ max \n" + + // "vbif q10, q3, q1 @ bit select, deal with right + // pad\n" "vbif q11, q7, q2 @ bit select, deal + // with right pad\n" + + "vst1.32 {d20-d23}, [%[dout_ptr2]] @ store\n" + // "vst1.32 {d22-d23}, [%[dout_ptr2]]! @ store\n" + : [din_ptr0] "+r"(din_ptr0), + [din_ptr1] "+r"(din_ptr1), + [din_ptr2] "+r"(din_ptr2), + [din_ptr3] "+r"(din_ptr3), + [bias] "+r"(bias_val), + [rs_mask] "+r"(rst_mask) + : [mask] "r"(vmask), + [dout_ptr1] "r"(out_buf1), + [dout_ptr2] "r"(out_buf2) + : "cc", + "memory", + "q0", + "q1", + "q2", + "q3", + "q4", + "q5", + "q6", + "q7", + "q8", + "q9", + "q10", + "q11", + "q12", + "q13", + "q14", + "q15"); +#endif + for (int w = 0; w < w_out; ++w) { + *doutr0++ = out_buf1[w]; + *doutr1++ = out_buf2[w]; + } + dout_ptr += 2 * w_out; + } + } + } +} + +// 1 line w_in > 16 +void conv_depthwise_3x3s2p1_bias_relu_int8(int* dout, + const signed char* din, + const signed char* weights, + const int* bias, + bool flag_bias, + const int num, + const int ch_in, + const int h_in, + const int w_in, + const int h_out, + const int w_out, + ARMContext* ctx) { + // printf("3x3s2 mult height \n"); + //! pad is done implicit + //! 
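+  // NOTE (illustrative, not part of the original patch): with stride 2,
+  // vld2.8/ld2 de-interleaves 16 input bytes into 8 even and 8 odd lanes,
+  // so each column tile consumes 16 inputs and emits 8 outputs. A worked
+  // example of the tiling arithmetic computed below, assuming w_in = 33:
+  //   tile_w = (33 + 15) >> 4 = 3;  cnt_col = tile_w - 2 = 1;
+  //   size_pad_right = 33 - 15 - (1 << 4) = 2;
+  // i.e. one "left" tile, one full "mid" tile and a masked "right" tile.
+  //!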
for 4x6 convolution window
+  const unsigned char right_pad_idx[16] = {
+      0, 2, 4, 6, 8, 10, 12, 14, 1, 3, 5, 7, 9, 11, 13, 15};
+  const unsigned int right_pad_rst[8] = {0, 1, 2, 3, 4, 5, 6, 7};
+
+  // printf("conv3x3_dw start \n");
+  signed char* zero_ptr = ctx->workspace_data<signed char>();
+  memset(zero_ptr, 0, w_in * sizeof(signed char));
+  int* write_ptr =
+      reinterpret_cast<int*>(ctx->workspace_data<signed char>()) + w_out;
+  int size_in_channel = w_in * h_in;
+  int size_out_channel = w_out * h_out;
+  int w_stride = 9;
+
+  int tile_w = (w_in + 15) >> 4;
+  int cnt_col = tile_w - 2;
+
+  unsigned int size_pad_right = (unsigned int)(w_in - 15 - (cnt_col << 4));
+  if (size_pad_right == 17) {
+    size_pad_right = 0;
+    cnt_col++;
+  }
+
+  uint8x8_t vmask_rp1 =
+      vcgt_u8(vdup_n_u8(size_pad_right), vld1_u8(right_pad_idx));
+  uint8x8_t vmask_rp2 =
+      vcgt_u8(vdup_n_u8(size_pad_right), vld1_u8(right_pad_idx + 8));
+  unsigned int rst_remain = (unsigned int)(w_out - ((cnt_col + 1) << 3));
+  uint32x4_t vmask_result1 =
+      vcgtq_u32(vdupq_n_u32(rst_remain), vld1q_u32(right_pad_rst));
+  uint32x4_t vmask_result2 =
+      vcgtq_u32(vdupq_n_u32(rst_remain), vld1q_u32(right_pad_rst + 4));
+
+  int8x8_t vzero = vdup_n_s8(0);
+  int32x4_t vzero_32 = vdupq_n_s32(0);
+
+  uint8x16_t vmask_rp =
+      vcgtq_u8(vdupq_n_u8(size_pad_right), vld1q_u8(right_pad_idx));
+  unsigned char vmask[16];
+  vst1q_u8(vmask, vmask_rp);
+
+  unsigned int rmask[8];
+  vst1q_u32(rmask, vmask_result1);
+  vst1q_u32(rmask + 4, vmask_result2);
+
+  for (int n = 0; n < num; ++n) {
+    const signed char* din_batch = din + n * ch_in * size_in_channel;
+    int* dout_batch = dout + n * ch_in * size_out_channel;
+
+#pragma omp parallel for
+    for (int c = 0; c < ch_in; c++) {
+      int* dout_ptr = dout_batch + c * size_out_channel;
+
+      const signed char* din_ch_ptr = din_batch + c * size_in_channel;
+
+      int bias_val = flag_bias ? bias[c] : 0;
+
+      const signed char* wei_ptr = weights + c * w_stride;
+#ifdef __aarch64__
+      int vbias[4] = {bias_val, bias_val, bias_val, bias_val};
+      int8x8_t wr00 = vdup_n_s8(wei_ptr[0]);
+      int8x8_t wr10 = vdup_n_s8(wei_ptr[3]);
+      int8x8_t wr20 = vdup_n_s8(wei_ptr[6]);
+
+      int8x8_t wr01 = vdup_n_s8(wei_ptr[1]);
+      int8x8_t wr11 = vdup_n_s8(wei_ptr[4]);
+      int8x8_t wr21 = vdup_n_s8(wei_ptr[7]);
+
+      int8x8_t wr02 = vdup_n_s8(wei_ptr[2]);
+      int8x8_t wr12 = vdup_n_s8(wei_ptr[5]);
+      int8x8_t wr22 = vdup_n_s8(wei_ptr[8]);
+#endif
+
+      int* doutr0 = nullptr;
+
+      const signed char* dr0 = din_ch_ptr;
+      const signed char* dr1 = dr0 + w_in;
+      const signed char* dr2 = dr1 + w_in;
+
+      const signed char* din_ptr0 = nullptr;
+      const signed char* din_ptr1 = nullptr;
+      const signed char* din_ptr2 = nullptr;
+
+      for (int i = 0; i < h_in; i += 2) {
+        //! process top pad pad_h = 1
+        din_ptr0 = dr0;
+        din_ptr1 = dr1;
+        din_ptr2 = dr2;
+
+        doutr0 = dout_ptr;
+        if (i == 0) {
+          din_ptr0 = zero_ptr;
+          din_ptr1 = dr0;
+          din_ptr2 = dr1;
+          dr0 = dr1;
+          dr1 = dr2;
+          dr2 = dr1 + w_in;
+        } else {
+          dr0 = dr2;
+          dr1 = dr0 + w_in;
+          dr2 = dr1 + w_in;
+        }
+        //!
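+        // NOTE: the bottom-pad switch below relies on intentional case
+        // fallthrough; an equivalent plain-if sketch (not in the patch):
+        //   if (i >= h_in) din_ptr1 = zero_ptr;      // case 2
+        //   if (i + 1 >= h_in) din_ptr2 = zero_ptr;  // cases 2 and 1
+        // rows past the bottom edge are redirected to the zero row, so
+        // the assembly never branches on the image border.
+        //!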
process bottom pad + if (i + 2 > h_in) { + switch (i + 2 - h_in) { + case 2: + din_ptr1 = zero_ptr; + case 1: + din_ptr2 = zero_ptr; + default: + break; + } + } + int cnt = cnt_col; +#ifdef __aarch64__ + unsigned char* val_mask = vmask; + asm volatile( + "PRFM PLDL1KEEP, [%[din_ptr0]] \n" + "PRFM PLDL1KEEP, [%[din_ptr1]] \n" + "PRFM PLDL1KEEP, [%[din_ptr2]] \n" + "movi v10.4s, #0x0\n" + // left + "ld2 {v0.8b - v1.8b}, [%[din_ptr0]] \n" /*load a00-a015 + to q0*/ + "ld2 {v2.8b - v3.8b}, [%[din_ptr1]] \n" /* load a00-a015 + to q0*/ + "ld2 {v4.8b - v5.8b}, [%[din_ptr2]] \n" /*load a00-a015 + to q0*/ + + "ld1 {v12.4s}, [%[bias_val]] \n" /* dup v10, bias*/ + "ld1 {v13.4s}, [%[bias_val]] \n" /* dup v10, bias */ + + "ext v6.8b, v10.8b, v1.8B, #7 \n" /* vext_s8(vzero, vinr0, 7); + 013579 */ + "ext v7.8b, v10.8b, v3.8B, #7 \n" /* vext_s8(vzero, vinr0, 7); + 013579 */ + "ext v8.8b, v10.8b, v5.8B, #7 \n" /* vext_s8(vzero, vinr0, 7); + 013579 */ + + // r0 + "smull v14.8h, %[v1].8b, v0.8b \n" /* outr00 = 02468 * w01 */ + "smull v15.8h, %[v2].8b, v1.8b\n" /* outr00 += 13579 * w02 */ + "smull v16.8h, %[v0].8b, v6.8b\n" /* outr00 += 013579 * w00 */ + + "add %[din_ptr0], %[din_ptr0], #15 \n" + "add %[din_ptr1], %[din_ptr1], #15 \n" + "add %[din_ptr2], %[din_ptr2], #15 \n" + + // r1 + "smlal v14.8h, %[v4].8b, v2.8b \n" /* outr00 = 02468 * w01 */ + "smlal v15.8h, %[v5].8b, v3.8b\n" /* outr00 += 13579 * w02 */ + "smlal v16.8h, %[v3].8b, v7.8b\n" /* outr00 += 013579 * w00 */ + + "saddw v12.4s, v12.4s, v14.4h \n" /* v10 += outr00.low*/ + "saddw2 v13.4s, v13.4s, v14.8h \n" /* v11 += outr00.high*/ + + "saddw v12.4s, v12.4s, v15.4h \n" /* v10 += outr00.low*/ + "saddw2 v13.4s, v13.4s, v15.8h \n" /* v11 += outr00.high*/ + + "saddw v12.4s, v12.4s, v16.4h \n" /* v10 += outr00.low*/ + "saddw2 v13.4s, v13.4s, v16.8h \n" /* v11 += outr00.high*/ + + // r2 + "smull v14.8h, %[v7].8b, v4.8b \n" /* outr00 = 02468 * w01 */ + "smull v15.8h, %[v8].8b, v5.8b\n" /* outr00 += 13579 * w02 */ + "smull v16.8h, %[v6].8b, v8.8b\n" /* outr00 += 013579 * w00 */ + + "ld2 {v0.8b - v1.8b}, [%[din_ptr0]], #16 \n" /*load + a00-a015 + to q0*/ + "ld2 {v2.8b - v3.8b}, [%[din_ptr1]], #16 \n" /* load + a00-a015 + to q0*/ + "ld2 {v4.8b - v5.8b}, [%[din_ptr2]], #16 \n" /*load + a00-a015 + to q0*/ + + "saddw v12.4s, v12.4s, v14.4h \n" /* v10 += outr00.low*/ + "saddw2 v13.4s, v13.4s, v14.8h \n" /* v11 += outr00.high*/ + + "saddw v12.4s, v12.4s, v15.4h \n" /* v10 += outr00.low*/ + "saddw2 v13.4s, v13.4s, v15.8h \n" /* v11 += outr00.high*/ + + "saddw v12.4s, v12.4s, v16.4h \n" /* v10 += outr00.low*/ + "saddw2 v13.4s, v13.4s, v16.8h \n" /* v11 += outr00.high*/ + + "smax v12.4s, v12.4s, v10.4s \n" /*relu*/ + "smax v13.4s, v13.4s, v10.4s \n" /*relu*/ + + "stp q12, q13, [%[ptr_out0]], #32 \n" /* store q10, q11 -> + ptr_out */ + + "ld1 {v12.4s}, [%[bias_val]] \n" /* dup v10, bias */ + "ld1 {v13.4s}, [%[bias_val]] \n" /* dup v10, bias */ + + "cmp %[cnt], #1 \n" + "blt 3f \n" + // mid + "1: \n" + "ld1 {v6.8b}, [%[din_ptr0]] \n" /*load a00-a015 to q0*/ + "ld1 {v7.8b}, [%[din_ptr1]] \n" /*load a00-a015 to q0*/ + "ld1 {v8.8b}, [%[din_ptr2]] \n" /*load a00-a015 to q0*/ + + "ext v9.8b, v0.8b, v6.8B, #1 \n" /* vext_s8(vzero, vinr0, 7); + 246810 */ + "ext v11.8b, v2.8b, v7.8B, #1 \n" /* vext_s8(vzero, vinr0, 7); + 246810 */ + "ext v14.8b, v4.8b, v8.8B, #1 \n" /* vext_s8(vzero, vinr0, 7); + 246810 */ + + // r0 + "smull v6.8h, %[v0].8b, v0.8b \n" /* outr00 = 02468 * w00 */ + "smull v7.8h, %[v1].8b, v1.8b\n" /* outr00 += 13579 * w01 */ + "smull v8.8h, %[v2].8b, v9.8b\n" 
/* outr00 += 246810 * w02 */ + + // r1 + "smlal v6.8h, %[v3].8b, v2.8b \n" /* outr00 = 02468 * w00 */ + "smlal v7.8h, %[v4].8b, v3.8b\n" /* outr00 += 13579 * w01 */ + "smlal v8.8h, %[v5].8b, v11.8b\n" /* outr00 += 246810 * w02 */ + + "saddw v12.4s, v12.4s, v6.4h \n" /* v10 += outr00.low*/ + "saddw2 v13.4s, v13.4s, v6.8h \n" /* v11 += outr00.high*/ + + "saddw v12.4s, v12.4s, v7.4h \n" /* v10 += outr00.low*/ + "saddw2 v13.4s, v13.4s, v7.8h \n" /* v11 += outr00.high*/ + + "saddw v12.4s, v12.4s, v8.4h \n" /* v10 += outr00.low*/ + "saddw2 v13.4s, v13.4s, v8.8h \n" /* v11 += outr00.high*/ + + // r2 + "smull v6.8h, %[v6].8b, v4.8b \n" /* outr00 = 02468 * w00 */ + "smull v7.8h, %[v7].8b, v5.8b\n" /* outr00 += 13579 * w01 */ + "smull v8.8h, %[v8].8b, v14.8b\n" /* outr00 += 246810 * w02 */ + + "ld2 {v0.8b - v1.8b}, [%[din_ptr0]], #16 \n" /*load + a00-a015 + to q0*/ + "ld2 {v2.8b - v3.8b}, [%[din_ptr1]], #16 \n" /* load + a00-a015 + to q0*/ + "ld2 {v4.8b - v5.8b}, [%[din_ptr2]], #16 \n" /*load + a00-a015 + to q0*/ + + "saddw v12.4s, v12.4s, v6.4h \n" /* v10 += outr00.low*/ + "saddw2 v13.4s, v13.4s, v6.8h \n" /* v11 += outr00.high*/ + + "saddw v12.4s, v12.4s, v7.4h \n" /* v10 += outr00.low*/ + "saddw2 v13.4s, v13.4s, v7.8h \n" /* v11 += outr00.high*/ + + "saddw v12.4s, v12.4s, v8.4h \n" /* v10 += outr00.low*/ + "saddw2 v13.4s, v13.4s, v8.8h \n" /* v11 += outr00.high*/ + + "smax v12.4s, v12.4s, v10.4s \n" /*relu*/ + "smax v13.4s, v13.4s, v10.4s \n" /*relu*/ + + "subs %[cnt], %[cnt], #1 \n" + + "stp q12, q13, [%[ptr_out0]], #32 \n" /* store q10, q11 -> + ptr_out */ + + "ld1 {v12.4s}, [%[bias_val]] \n" /* dup v10, bias */ + "ld1 {v13.4s}, [%[bias_val]] \n" /* dup v10, bias */ + "bne 1b \n" + // right + "3: \n" + "ld1 {v14.8b}, [%[vmask]], #8 \n" + "ld1 {v15.8b}, [%[vmask]] \n" + + "bif v0.8b, v10.8b, v14.8b \n" + "bif v1.8b, v10.8b, v15.8b \n" + "bif v2.8b, v10.8b, v14.8b \n" + "bif v3.8b, v10.8b, v15.8b \n" + "bif v4.8b, v10.8b, v14.8b \n" + "bif v5.8b, v10.8b, v15.8b \n" + + "ext v6.8b, v0.8b, v10.8B, #1 \n" /* vext_s8(vzero, vinr0, 7); + 2468.. */ + "ext v7.8b, v2.8b, v10.8B, #1 \n" /* vext_s8(vzero, vinr0, 7); + 2468..*/ + "ext v8.8b, v4.8b, v10.8B, #1 \n" /* vext_s8(vzero, vinr0, 7); + 2468.. 
*/ + + // r0 + "smull v14.8h, %[v0].8b, v0.8b \n" /* outr00 = 02468 * w00 */ + "smull v15.8h, %[v1].8b, v1.8b\n" /* outr00 += 13579 * w01 */ + "smull v16.8h, %[v2].8b, v6.8b\n" /* outr00 += 246810 * w02 */ + + // r1 + "smlal v14.8h, %[v3].8b, v2.8b \n" /* outr00 = 02468 * w00 */ + "smlal v15.8h, %[v4].8b, v3.8b\n" /* outr00 += 13579 * w01 */ + "smlal v16.8h, %[v5].8b, v7.8b\n" /* outr00 += 246810 * w02 */ + + "saddw v12.4s, v12.4s, v14.4h \n" /* v10 += outr00.low*/ + "saddw2 v13.4s, v13.4s, v14.8h \n" /* v11 += outr00.high*/ + + "saddw v12.4s, v12.4s, v15.4h \n" /* v10 += outr00.low*/ + "saddw2 v13.4s, v13.4s, v15.8h \n" /* v11 += outr00.high*/ + + "saddw v12.4s, v12.4s, v16.4h \n" /* v10 += outr00.low*/ + "saddw2 v13.4s, v13.4s, v16.8h \n" /* v11 += outr00.high*/ + + // r2 + "smull v14.8h, %[v6].8b, v4.8b \n" /* outr00 = 02468 * w00 */ + "smull v15.8h, %[v7].8b, v5.8b\n" /* outr00 += 13579 * w01 */ + "smull v16.8h, %[v8].8b, v8.8b\n" /* outr00 += 246810 * w02 */ + + "ldp q0, q1, [%[ptr_out0]] \n" /* dup v10, bias */ + "ldp q9, q11, [%[rst_mask]] \n" /* dup v10, bias */ + + "saddw v12.4s, v12.4s, v14.4h \n" /* v10 += outr00.low*/ + "saddw2 v13.4s, v13.4s, v14.8h \n" /* v11 += outr00.high*/ + + "saddw v12.4s, v12.4s, v15.4h \n" /* v10 += outr00.low*/ + "saddw2 v13.4s, v13.4s, v15.8h \n" /* v11 += outr00.high*/ + + "saddw v12.4s, v12.4s, v16.4h \n" /* v10 += outr00.low*/ + "saddw2 v13.4s, v13.4s, v16.8h \n" /* v11 += outr00.high*/ + + "smax v12.4s, v12.4s, v10.4s \n" /*relu*/ + "smax v13.4s, v13.4s, v10.4s \n" /*relu*/ + + "bif v12.16b, v0.16b, v9.16b \n" + "bif v13.16b, v1.16b, v11.16b \n" + + "stp q12, q13, [%[ptr_out0]], #32 \n" /* store q10, q11 -> + ptr_out */ + + : [cnt] "+r"(cnt), + [din_ptr0] "+r"(din_ptr0), + [din_ptr1] "+r"(din_ptr1), + [din_ptr2] "+r"(din_ptr2), + [ptr_out0] "+r"(doutr0), + [vmask] "+r"(val_mask) + : [v0] "w"(wr00), + [v1] "w"(wr01), + [v2] "w"(wr02), + [v3] "w"(wr10), + [bias_val] "r"(vbias), + [v4] "w"(wr11), + [v5] "w"(wr12), + [v6] "w"(wr20), + [v7] "w"(wr21), + [v8] "w"(wr22), + [rst_mask] "r"(rmask) + : "cc", + "memory", + "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15", + "v16"); +#else + unsigned int* rst_mask = rmask; + // prefetch input + // store weights + asm volatile("vld1.8 {d0-d1}, [%[wei_ptr]] \n" + : + : [wei_ptr] "r"(wei_ptr) + : "memory"); + asm volatile( + // left + "pld [%[din_ptr0]] @ preload data\n" + "pld [%[din_ptr1]] @ preload data\n" + "pld [%[din_ptr2]] @ preload data\n" + "vdup.s8 d2, d0[0] @ d2 = w00, w00, w00, w00\n" + "vdup.s8 d3, d0[1] @ d3 = w01, w01, w01, w01\n" + "vdup.s8 d4, d0[2] @ d4 = w02, w02, w02, w02\n" + "vld2.8 {d12-d13}, [%[din_ptr0]] @ load din00= 0 2 4 6 8\n" // d10 = 0 2 4 6 + "vld2.8 {d14-d15}, [%[din_ptr1]] @ load din00= 0 2 4 6 8\n" // d12 = 0 2 4 6 + "vld2.8 {d16-d17}, [%[din_ptr2]] @ load din00= 0 2 4 6 8\n" // d14 = 0 2 4 6 + "vmov.u32 d11, #0 @ zero\n" + + "vdup.s8 d5, d0[3] @ d2 = w00, w00, w00, w00\n" + "vdup.s8 d6, d0[4] @ d3 = w01, w01, w01, w01\n" + "vdup.s8 d7, d0[5] @ d4 = w02, w02, w02, w02\n" + + "vext.8 d18, d11, d13, #7 @ ext \n" // d16 = -1 1 3 5 + "vext.8 d19, d11, d15, #7 @ ext \n" // d17 = -1 1 3 5 + "vext.8 d20, d11, d17, #7 @ ext \n" // d18 = -1 1 3 5 + + // r0 + "vmull.s8 q13, d12, d3 @ out0 = din0 * w01 \n" // q12 = d12 * w01 + "vmull.s8 q14, d13, d4 @ out1 = din0 * w02 \n" // q12 = d12 * w02 + "vmull.s8 q15, d18, d2 @ out2 = din0 * w00 \n" // q12 = d12 * w02 + + "vdup.s8 d8, d0[6] @ d2 = w00, w00, w00, 
w00\n" + "vdup.s8 d9, d0[7] @ d3 = w01, w01, w01, w01\n" + "vdup.s8 d10, d1[0] @ d4 = w02, w02, w02, w02\n" + + // out0 + "vdup.32 q11, %[bias] @ and \n" // q8 = + // vbias + "vdup.32 q12, %[bias] @ and \n" // q9 = + // vbias + + // r1 + "vmlal.s8 q13, d14, d6 @ out0 += din1 * w11 \n" // q12 = d12 * w11 + "vmlal.s8 q14, d15, d7 @ out1 += din1 * w12 \n" // q12 = d12 * w11 + "vmlal.s8 q15, d19, d5 @ out2 += din1 * w10 \n" // q12 = d12 * w11 + + "add %[din_ptr0], #15 @add \n" + + "vaddw.s16 q11, q11, d26 @addw \n" // out1 += + // vget_low_s16(out10) + "vaddw.s16 q12, q12, d27 @addw \n" // out1_1 += + // vget_high_s16(out10) + + "add %[din_ptr1], #15 @add \n" + + "vaddw.s16 q11, q11, d28 @addw \n" // out1 += + // vget_low_s16(out10) + "vaddw.s16 q12, q12, d29 @addw \n" // out1_1 += + // vget_high_s16(out10) + + "add %[din_ptr2], #15 @add \n" + + "vaddw.s16 q11, q11, d30 @addw \n" // out1 += + // vget_low_s16(out10) + "vaddw.s16 q12, q12, d31 @addw \n" // out1_1 += + // vget_high_s16(out10) + + // r2 + "vmull.s8 q13, d16, d9 @ out0 += din1 * w21 \n" // q12 = d12 * w11 + "vmull.s8 q14, d17, d10 @ out1 += din1 * w22 \n" // q12 = d12 * w11 + "vmull.s8 q15, d20, d8 @ out2 += din1 * w20 \n" // q12 = d12 * w11 + + "pld [%[din_ptr0]] @ preload data\n" + "pld [%[din_ptr1]] @ preload data\n" + "pld [%[din_ptr2]] @ preload data\n" + + "vaddw.s16 q11, q11, d26 @addw \n" // out1 += + // vget_low_s16(out10) + "vaddw.s16 q12, q12, d27 @addw \n" // out1_1 += + // vget_high_s16(out10) + + "vmov.u32 q8, #0 @ max \n" // max + + "vaddw.s16 q11, q11, d28 @addw \n" // out1 += + // vget_low_s16(out10) + "vaddw.s16 q12, q12, d29 @addw \n" // out1_1 += + // vget_high_s16(out10) + + "vaddw.s16 q11, q11, d30 @addw \n" // out1 += + // vget_low_s16(out10) + "vaddw.s16 q12, q12, d31 @addw \n" // out1_1 += + // vget_high_s16(out10) + + "vmax.s32 q11, q11, q8 @ max\n" + "vmax.s32 q12, q12, q8 @ max\n" + + "vst1.32 {d22-d23}, [%[dout_ptr1]]! @ store\n" + "cmp %[cnt], #1 \n" + "vst1.32 {d24-d25}, [%[dout_ptr1]]! @ store\n" + "blt 1f \n" + + // mid + "2: \n" + "vld2.8 {d12-d13}, [%[din_ptr0]]! @ load din00= 0 2 4 6 8\n" // d10 = 0 2 4 6 + "vld2.8 {d14-d15}, [%[din_ptr1]]! @ load din00= 0 2 4 6 8\n" // d12 = 0 2 4 6 + "vld2.8 {d16-d17}, [%[din_ptr2]]! 
@ load din00= 0 2 4 6 8\n" // d14 = 0 2 4 6 + + "vld1.8 {d21}, [%[din_ptr0]] @ load din00= 16 17\n" // d10 = 0 2 + // 4 6 + "vld1.8 {d22}, [%[din_ptr1]] @ load din00= 16 17\n" // d12 = 0 2 + // 4 6 + "vld1.8 {d23}, [%[din_ptr2]] @ load din00= 16 17\n" // d14 = 0 2 + // 4 6 + + "vext.8 d18, d12, d21, #1 @ ext din00 = 2 4 6 8\n" // d16 = 2 + // 4 6 8 + "vext.8 d19, d14, d22, #1 @ ext \n" // d17 = 2 4 6 8 + "vext.8 d20, d16, d23, #1 @ ext \n" // d18 = 2 4 6 8 + + // r0 + "vmull.s8 q13, d12, d2 @ out0 = din0 * w00 \n" // q12 = 0 2 4 6 + "vmull.s8 q14, d13, d3 @ out1 = din0 * w01 \n" // q12 = 1 3 5 7 + "vmull.s8 q15, d18, d4 @ out2 = din0 * w02 \n" // q12 = 2 4 6 8 + + // out0 + "vdup.32 q11, %[bias] @ and \n" // q8 = + // vbias + "vdup.32 q12, %[bias] @ and \n" // q9 = + // vbias + + // r1 + "vmlal.s8 q13, d14, d5 @ out0 += din1 * w10 \n" // q12 = 0 2 4 6 + "vmlal.s8 q14, d15, d6 @ out1 += din1 * w11 \n" // q12 = 1 3 5 7 + "vmlal.s8 q15, d19, d7 @ out2 += din1 * w12 \n" // q12 = 2 4 6 8 + + "vaddw.s16 q11, q11, d26 @addw \n" // out1 += + // vget_low_s16(out10) + "vaddw.s16 q12, q12, d27 @addw \n" // out1_1 += + // vget_high_s16(out10) + + "vaddw.s16 q11, q11, d28 @addw \n" // out1 += + // vget_low_s16(out10) + "vaddw.s16 q12, q12, d29 @addw \n" // out1_1 += + // vget_high_s16(out10) + + "vaddw.s16 q11, q11, d30 @addw \n" // out1 += + // vget_low_s16(out10) + "vaddw.s16 q12, q12, d31 @addw \n" // out1_1 += + // vget_high_s16(out10) + + // r2 + "vmull.s8 q13, d16, d8 @ out0 += din1 * w20 \n" // q12 = 0 2 4 6 + "vmull.s8 q14, d17, d9 @ out1 += din1 * w21 \n" // q12 = 1 3 5 7 + "vmull.s8 q15, d20, d10 @ out2 += din1 * w22 \n" // q12 = 2 4 6 8 + + "vaddw.s16 q11, q11, d26 @addw \n" // out1 += + // vget_low_s16(out10) + "vaddw.s16 q12, q12, d27 @addw \n" // out1_1 += + // vget_high_s16(out10) + "vmov.u32 q8, #0 @ mov \n" + + "vaddw.s16 q11, q11, d28 @addw \n" // out1 += + // vget_low_s16(out10) + "vaddw.s16 q12, q12, d29 @addw \n" // out1_1 += + // vget_high_s16(out10) + + "vaddw.s16 q11, q11, d30 @addw \n" // out1 += + // vget_low_s16(out10) + "vaddw.s16 q12, q12, d31 @addw \n" // out1_1 += + // vget_high_s16(out10) + + "pld [%[din_ptr0]] @ preload data\n" + "pld [%[din_ptr1]] @ preload data\n" + "pld [%[din_ptr2]] @ preload data\n" + + "vmax.s32 q11, q11, q8 @ max\n" + "vmax.s32 q12, q12, q8 @ max\n" + + "vst1.32 {d22-d23}, [%[dout_ptr1]]! @ store\n" + + "subs %[cnt], #1 \n" + "vst1.32 {d24-d25}, [%[dout_ptr1]]! @ store\n" + "bne 2b \n" + // right + "1: \n" + "cmp %[size_pad_right], #1 \n" + "blt 3f \n" + "vld2.8 {d12-d13}, [%[din_ptr0]]! @ load din00= 0 2 4 6 8\n" // d10 = 0 2 4 6 + "vld2.8 {d14-d15}, [%[din_ptr1]]! @ load din00= 0 2 4 6 8\n" // d12 = 0 2 4 6 + "vld2.8 {d16-d17}, [%[din_ptr2]]! 
@ load din00= 0 2 4 6 8\n" // d14 = 0 2 4 6 + "vld1.8 {d28-d29}, [%[mask]] @ load din00= 0 1 2 3 4 5 6 7 " + "8 9\n" + + // out0 + "vdup.32 q11, %[bias] @ and \n" // q8 = vbias + "vdup.32 q12, %[bias] @ and \n" // q9 = vbias + + "vbif.8 d12, d11, d28 @ bit select, deal with right pad\n" + "vbif.8 d13, d11, d29 @ bit select, deal with right pad\n" + + "vbif.8 d14, d11, d28 @ bit select, deal with right pad\n" + "vbif.8 d15, d11, d29 @ bit select, deal with right pad\n" + + "vbif.8 d16, d11, d28 @ bit select, deal with right pad\n" + "vbif.8 d17, d11, d29 @ bit select, deal with right pad\n" + + "vext.8 d18, d12, d11, #1 @ ext din00 = 2 4 6 8\n" // d16 = -1 + // 1 3 5 + "vext.8 d19, d14, d11, #1 @ ext \n" // d17 = -1 1 3 5 + "vext.8 d20, d16, d11, #1 @ ext \n" // d18 = -1 1 3 5 + + // r0 + "vmull.s8 q13, d12, d2 @ out0 = din0 * w00 \n" // q12 = 0 2 4 6 + "vmull.s8 q14, d13, d3 @ out1 = din0 * w01 \n" // q12 = 1 3 5 7 + "vmull.s8 q15, d18, d4 @ out2 = din0 * w02 \n" // q12 = 2 4 6 8 + + // r1 + "vmlal.s8 q13, d14, d5 @ out0 += din1 * w11 \n" // q12 = 0 2 4 6 + "vmlal.s8 q14, d15, d6 @ out1 += din1 * w12 \n" // q12 = 1 3 5 7 + "vmlal.s8 q15, d19, d7 @ out2 += din1 * w10 \n" // q12 = 2 4 6 8 + + "vld1.32 {d12-d13}, [%[dout_ptr1]]! @ load din00= 0 1 2 3 4 5 6 " + "7 8 9\n" + "vld1.32 {d14-d15}, [%[dout_ptr1]] @ load din00= 0 1 2 3 4 5 6 " + "7 8 9\n" + + "vaddw.s16 q11, q11, d26 @addw \n" // out1 += + // vget_low_s16(out10) + "vaddw.s16 q12, q12, d27 @addw \n" // out1_1 += + // vget_high_s16(out10) + + "vaddw.s16 q11, q11, d28 @addw \n" // out1 += + // vget_low_s16(out10) + "vaddw.s16 q12, q12, d29 @addw \n" // out1_1 += + // vget_high_s16(out10) + + "vaddw.s16 q11, q11, d30 @addw \n" // out1 += + // vget_low_s16(out10) + "vaddw.s16 q12, q12, d31 @addw \n" // out1_1 += + // vget_high_s16(out10) + + // r2 + "vmull.s8 q13, d16, d8 @ out0 += din1 * w11 \n" // q12 = 0 2 4 6 + "vmull.s8 q14, d17, d9 @ out1 += din1 * w12 \n" // q12 = 1 3 5 7 + "vmull.s8 q15, d20, d10 @ out2 += din1 * w10 \n" // q12 = 2 4 6 8 + + "vld1.32 {d2-d3}, [%[rs_mask]]! @ load din00= 0 1 2 3 4 5 6 7 " + "8 9\n" + "vld1.32 {d4-d5}, [%[rs_mask]] @ load din00= 0 1 2 3 4 5 6 7 8 " + "9\n" + + "vaddw.s16 q11, q11, d26 @addw \n" // out1 += + // vget_low_s16(out10) + "vaddw.s16 q12, q12, d27 @addw \n" // out1_1 += + // vget_high_s16(out10) + + "sub %[dout_ptr1], #16 @ sub \n" + "vmov.u32 q8, #0 @mov \n" + "vaddw.s16 q11, q11, d28 @addw \n" // out1 += + // vget_low_s16(out10) + "vaddw.s16 q12, q12, d29 @addw \n" // out1_1 += + // vget_high_s16(out10) + + "vaddw.s16 q11, q11, d30 @addw \n" // out1 += + // vget_low_s16(out10) + "vaddw.s16 q12, q12, d31 @addw \n" // out1_1 += + // vget_high_s16(out10) + + "vmax.s32 q11, q11, q8 @ max\n" + "vmax.s32 q12, q12, q8 @ max\n" + + "vbif q11, q6, q1 @ bit select, deal with right pad\n" + "vbif q12, q7, q2 @ bit select, deal with right pad\n" + + "vst1.32 {d22-d23}, [%[dout_ptr1]]! @ store\n" + "vst1.32 {d24-d25}, [%[dout_ptr1]]! 
@ store\n"
+      "3:                                        \n"
+
+      : [din_ptr0] "+r"(din_ptr0),
+        [din_ptr1] "+r"(din_ptr1),
+        [din_ptr2] "+r"(din_ptr2),
+        [dout_ptr1] "+r"(doutr0),
+        [cnt] "+r"(cnt),
+        [bias] "+r"(bias_val),
+        [rs_mask] "+r"(rst_mask)
+      : [mask] "r"(vmask), [size_pad_right] "r"(size_pad_right)
+      : "cc",
+        "memory",
+        "q0",
+        "q1",
+        "q2",
+        "q3",
+        "q4",
+        "q5",
+        "q6",
+        "q7",
+        "q8",
+        "q9",
+        "q10",
+        "q11",
+        "q12",
+        "q13",
+        "q14",
+        "q15");
+#endif
+        dout_ptr += w_out;
+      }
+    }
+  }
+}
+// w_in <= 16
+void conv_depthwise_3x3s2p1_bias_s_relu_int8(int* dout,
+                                             const signed char* din,
+                                             const signed char* weights,
+                                             const int* bias,
+                                             bool flag_bias,
+                                             const int num,
+                                             const int ch_in,
+                                             const int h_in,
+                                             const int w_in,
+                                             const int h_out,
+                                             const int w_out,
+                                             ARMContext* ctx) {
+  // printf("3x3s2 mult height \n");
+  //! pad is done implicit
+  // const char zero[8] = {0, 0, 0, 0, 0, 0, 0, 0};
+  //! for 4x6 convolution window
+  const unsigned char right_pad_idx[16] = {
+      0, 2, 4, 6, 8, 10, 12, 14, 1, 3, 5, 7, 9, 11, 13, 15};
+  const unsigned int right_pad_rst[8] = {0, 1, 2, 3, 4, 5, 6, 7};
+
+  // printf("conv3x3_dw start \n");
+  signed char* zero_ptr = ctx->workspace_data<signed char>();
+  memset(zero_ptr, 0, w_in * sizeof(signed char));
+  int* write_ptr =
+      reinterpret_cast<int*>(ctx->workspace_data<signed char>()) + w_out;
+  int size_in_channel = w_in * h_in;
+  int size_out_channel = w_out * h_out;
+  int w_stride = 9;
+
+  unsigned int size_pad_right = (unsigned int)(w_in);
+
+  uint8x8_t vmask_rp1 =
+      vcgt_u8(vdup_n_u8(size_pad_right), vld1_u8(right_pad_idx));
+  uint8x8_t vmask_rp2 =
+      vcgt_u8(vdup_n_u8(size_pad_right), vld1_u8(right_pad_idx + 8));
+  unsigned int rst_remain = (unsigned int)w_out;
+  uint32x4_t vmask_result1 =
+      vcgtq_u32(vdupq_n_u32(rst_remain), vld1q_u32(right_pad_rst));
+  uint32x4_t vmask_result2 =
+      vcgtq_u32(vdupq_n_u32(rst_remain), vld1q_u32(right_pad_rst + 4));
+
+  uint8x16_t vmask_rp =
+      vcgtq_u8(vdupq_n_u8(size_pad_right), vld1q_u8(right_pad_idx));
+  unsigned char vmask[16];
+  vst1q_u8(vmask, vmask_rp);
+
+  unsigned int rmask[8];
+  vst1q_u32(rmask, vmask_result1);
+  vst1q_u32(rmask + 4, vmask_result2);
+  int8x8_t vzero = vdup_n_s8(0);
+  int32x4_t vzero_32 = vdupq_n_s32(0);
+
+  for (int n = 0; n < num; ++n) {
+    const signed char* din_batch = din + n * ch_in * size_in_channel;
+    int* dout_batch = dout + n * ch_in * size_out_channel;
+#pragma omp parallel for
+    for (int c = 0; c < ch_in; c++) {
+      int* dout_ptr = dout_batch + c * size_out_channel;
+
+      const signed char* din_ch_ptr = din_batch + c * size_in_channel;
+
+      int bias_val = flag_bias ? bias[c] : 0;
+
+      const signed char* wei_ptr = weights + c * w_stride;
+
+#ifdef __aarch64__
+      int vbias[4] = {bias_val, bias_val, bias_val, bias_val};
+      int8x8_t wr00 = vdup_n_s8(wei_ptr[0]);
+      int8x8_t wr10 = vdup_n_s8(wei_ptr[3]);
+      int8x8_t wr20 = vdup_n_s8(wei_ptr[6]);
+
+      int8x8_t wr01 = vdup_n_s8(wei_ptr[1]);
+      int8x8_t wr11 = vdup_n_s8(wei_ptr[4]);
+      int8x8_t wr21 = vdup_n_s8(wei_ptr[7]);
+
+      int8x8_t wr02 = vdup_n_s8(wei_ptr[2]);
+      int8x8_t wr12 = vdup_n_s8(wei_ptr[5]);
+      int8x8_t wr22 = vdup_n_s8(wei_ptr[8]);
+#endif
+
+      int* doutr0 = nullptr;
+
+      const signed char* dr0 = din_ch_ptr;
+      const signed char* dr1 = dr0 + w_in;
+      const signed char* dr2 = dr1 + w_in;
+
+      const signed char* din_ptr0 = nullptr;
+      const signed char* din_ptr1 = nullptr;
+      const signed char* din_ptr2 = nullptr;
+
+      for (int i = 0; i < h_in; i += 2) {
+        //!
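+        // NOTE (sketch, name hypothetical): pad_h = 1 is realized by
+        // pointer substitution instead of copying a padded image; on the
+        // first iteration the row "above" the image reads from the
+        // zero-filled scratch row:
+        //   const signed char* row_above = (i == 0) ? zero_ptr : dr0;
+        // which is what the assignments below do with din_ptr0.
+        //!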
process top pad pad_h = 1 + din_ptr0 = dr0; + din_ptr1 = dr1; + din_ptr2 = dr2; + + doutr0 = dout_ptr; + + int out_buf1[8]; + if (i == 0) { + din_ptr0 = zero_ptr; + din_ptr1 = dr0; + din_ptr2 = dr1; + dr0 = dr1; + dr1 = dr2; + dr2 = dr1 + w_in; + } else { + dr0 = dr2; + dr1 = dr2 + w_in; + dr2 = dr1 + w_in; + } + //! process bottom pad + if (i + 2 > h_in) { + switch (i + 2 - h_in) { + case 2: + din_ptr1 = zero_ptr; + case 1: + din_ptr2 = zero_ptr; + default: + break; + } + } +#ifdef __aarch64__ + unsigned int* rst_mask = rmask; + unsigned char* val_mask = vmask; + asm volatile( + "PRFM PLDL1KEEP, [%[din_ptr0]] \n" + "PRFM PLDL1KEEP, [%[din_ptr1]] \n" + "PRFM PLDL1KEEP, [%[din_ptr2]] \n" + "movi v16.4s, #0x0\n" + // left + "ld1 {v10.8b}, [%[vmask]], #8 \n" + "ld1 {v11.8b}, [%[vmask]] \n" + "ld2 {v0.8b - v1.8b}, [%[din_ptr0]] \n" /*load a00-a015 + to q0*/ + "ld2 {v2.8b - v3.8b}, [%[din_ptr1]] \n" /* load a00-a015 + to q0*/ + "ld2 {v4.8b - v5.8b}, [%[din_ptr2]] \n" /*load a00-a015 + to q0*/ + + "bif v0.8b, v16.8b, v10.8b \n" + "bif v1.8b, v16.8b, v11.8b \n" + "bif v2.8b, v16.8b, v10.8b \n" + "bif v3.8b, v16.8b, v11.8b \n" + "bif v4.8b, v16.8b, v10.8b \n" + "bif v5.8b, v16.8b, v11.8b \n" + + "ld1 {v12.4s}, [%[bias_val]] \n" /* dup v10, bias*/ + "ld1 {v13.4s}, [%[bias_val]] \n" /* dup v10, bias */ + + "ext v6.8b, v16.8b, v1.8B, #7 \n" /* vext_s8(vzero, vinr0, 7); + 013579 */ + "ext v7.8b, v16.8b, v3.8B, #7 \n" /* vext_s8(vzero, vinr0, 7); + 013579 */ + "ext v8.8b, v16.8b, v5.8B, #7 \n" /* vext_s8(vzero, vinr0, 7); + 013579 */ + + // r0 + "smull v17.8h, %[v1].8b, v0.8b \n" /* outr00 = 02468 * w01 */ + "smull v18.8h, %[v2].8b, v1.8b\n" /* outr00 += 13579 * w02 */ + "smull v19.8h, %[v0].8b, v6.8b\n" /* outr00 += 013579 * w00 */ + + // "ldp q0, q1, [%[ptr_out0]] \n" /* dup v10, + // bias */ "ldp q10, q11, [%[rst_mask]] \n" /* + // dup v10, bias */ + + // r1 + "smlal v17.8h, %[v4].8b, v2.8b \n" /* outr00 = 02468 * w01 */ + "smlal v18.8h, %[v5].8b, v3.8b\n" /* outr00 += 13579 * w02 */ + "smlal v19.8h, %[v3].8b, v7.8b\n" /* outr00 += 013579 * w00 */ + + "saddw v12.4s, v12.4s, v17.4h \n" /* v10 += outr00.low*/ + "saddw2 v13.4s, v13.4s, v17.8h \n" /* v11 += outr00.high*/ + + "saddw v12.4s, v12.4s, v18.4h \n" /* v10 += outr00.low*/ + "saddw2 v13.4s, v13.4s, v18.8h \n" /* v11 += outr00.high*/ + + "saddw v12.4s, v12.4s, v19.4h \n" /* v10 += outr00.low*/ + "saddw2 v13.4s, v13.4s, v19.8h \n" /* v11 += outr00.high*/ + + // r2 + "smull v17.8h, %[v7].8b, v4.8b \n" /* outr00 = 02468 * w01 */ + "smull v18.8h, %[v8].8b, v5.8b\n" /* outr00 += 13579 * w02 */ + "smull v19.8h, %[v6].8b, v8.8b\n" /* outr00 += 013579 * w00 */ + + "saddw v12.4s, v12.4s, v17.4h \n" /* v10 += outr00.low*/ + "saddw2 v13.4s, v13.4s, v17.8h \n" /* v11 += outr00.high*/ + + "saddw v12.4s, v12.4s, v18.4h \n" /* v10 += outr00.low*/ + "saddw2 v13.4s, v13.4s, v18.8h \n" /* v11 += outr00.high*/ + + "saddw v12.4s, v12.4s, v19.4h \n" /* v10 += outr00.low*/ + "saddw2 v13.4s, v13.4s, v19.8h \n" /* v11 += outr00.high*/ + + "smax v12.4s, v12.4s, v16.4s \n" /*relu*/ + "smax v13.4s, v13.4s, v16.4s \n" /*relu*/ + + // "bif v12.16b, v0.16b, v10.16b \n" + // "bif v13.16b, v1.16b, v11.16b \n" + + "stp q12, q13, [%[ptr_out0]] \n" /* store q10, q11 -> ptr_out + */ + : [din_ptr0] "+r"(din_ptr0), + [din_ptr1] "+r"(din_ptr1), + [din_ptr2] "+r"(din_ptr2), + [vmask] "+r"(val_mask) + : [v0] "w"(wr00), + [v1] "w"(wr01), + [v2] "w"(wr02), + [v3] "w"(wr10), + [bias_val] "r"(vbias), + [v4] "w"(wr11), + [v5] "w"(wr12), + [v6] "w"(wr20), + [v7] "w"(wr21), + [v8] "w"(wr22), + 
[rst_mask] "r"(rmask), + [ptr_out0] "r"(out_buf1) + : "cc", + "memory", + "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15", + "v16", + "v17", + "v18", + "v19", + "v20"); + +#else + unsigned int* rst_mask = rmask; + // prefetch input + // store weights + asm volatile("vld1.8 {d0-d1}, [%[wei_ptr]] \n" + : + : [wei_ptr] "r"(wei_ptr) + : "memory"); + asm volatile( + // left + "pld [%[din_ptr0]] @ preload data\n" + "pld [%[din_ptr1]] @ preload data\n" + "pld [%[din_ptr2]] @ preload data\n" + "vdup.s8 d2, d0[0] @ d2 = w00, w00, w00, w00\n" + "vdup.s8 d3, d0[1] @ d3 = w01, w01, w01, w01\n" + "vdup.s8 d4, d0[2] @ d4 = w02, w02, w02, w02\n" + "vld2.8 {d12-d13}, [%[din_ptr0]] @ load din00= 0 2 4 6 8\n" // d10 = 0 2 4 6 + "vld2.8 {d14-d15}, [%[din_ptr1]] @ load din00= 0 2 4 6 8\n" // d12 = 0 2 4 6 + "vld2.8 {d16-d17}, [%[din_ptr2]] @ load din00= 0 2 4 6 8\n" // d14 = 0 2 4 6 + "vld1.8 {d28-d29}, [%[mask]] @ load din00= 0 1 2 3 4 5 6 7 " + "8 9\n" + "vmov.u32 d11, #0 @ zero\n" + + "vdup.s8 d5, d0[3] @ d2 = w00, w00, w00, w00\n" + "vdup.s8 d6, d0[4] @ d3 = w01, w01, w01, w01\n" + "vdup.s8 d7, d0[5] @ d4 = w02, w02, w02, w02\n" + + "vbif.8 d12, d11, d28 @ bit select, deal with right pad\n" + "vbif.8 d13, d11, d29 @ bit select, deal with right pad\n" + + "vbif.8 d14, d11, d28 @ bit select, deal with right pad\n" + "vbif.8 d15, d11, d29 @ bit select, deal with right pad\n" + + "vbif.8 d16, d11, d28 @ bit select, deal with right pad\n" + "vbif.8 d17, d11, d29 @ bit select, deal with right pad\n" + + "vext.8 d18, d11, d13, #7 @ ext \n" // d16 = -1 1 3 5 + "vext.8 d19, d11, d15, #7 @ ext \n" // d17 = -1 1 3 5 + "vext.8 d20, d11, d17, #7 @ ext \n" // d18 = -1 1 3 5 + + // "pld [%[dout_ptr1]] @ preload data\n" + + // r0 + "vmull.s8 q13, d12, d3 @ out0 = din0 * w01 \n" // q12 = d12 * w01 + "vmull.s8 q14, d13, d4 @ out1 = din0 * w02 \n" // q12 = d12 * w02 + "vmull.s8 q15, d18, d2 @ out2 = din0 * w00 \n" // q12 = d12 * w02 + + "vdup.s8 d8, d0[6] @ d2 = w00, w00, w00, w00\n" + "vdup.s8 d9, d0[7] @ d3 = w01, w01, w01, w01\n" + "vdup.s8 d10, d1[0] @ d4 = w02, w02, w02, w02\n" + + // out0 + "vdup.32 q11, %[bias] @ and \n" // q8 = + // vbias + "vdup.32 q12, %[bias] @ and \n" // q9 = + // vbias + + // r1 + "vmlal.s8 q13, d14, d6 @ out0 += din1 * w11 \n" // q12 = d12 * w11 + "vmlal.s8 q14, d15, d7 @ out1 += din1 * w12 \n" // q12 = d12 * w11 + "vmlal.s8 q15, d19, d5 @ out2 += din1 * w10 \n" // q12 = d12 * w11 + + // "vld1.32 {d12-d13}, [%[dout_ptr1]]! @ load din00= 0 1 2 3 4 5 + // 6 7 8 9\n" "vld1.32 {d14-d15}, [%[dout_ptr1]] @ load din00= 0 + // 1 2 3 4 5 6 7 8 9\n" + + "vaddw.s16 q11, q11, d26 @addw \n" // out1 += + // vget_low_s16(out10) + "vaddw.s16 q12, q12, d27 @addw \n" // out1_1 += + // vget_high_s16(out10) + + "vaddw.s16 q11, q11, d28 @addw \n" // out1 += + // vget_low_s16(out10) + "vaddw.s16 q12, q12, d29 @addw \n" // out1_1 += + // vget_high_s16(out10) + + "vaddw.s16 q11, q11, d30 @addw \n" // out1 += + // vget_low_s16(out10) + "vaddw.s16 q12, q12, d31 @addw \n" // out1_1 += + // vget_high_s16(out10) + + // r2 + "vmull.s8 q13, d16, d9 @ out0 += din1 * w21 \n" // q12 = d12 * w11 + "vmull.s8 q14, d17, d10 @ out1 += din1 * w22 \n" // q12 = d12 * w11 + "vmull.s8 q15, d20, d8 @ out2 += din1 * w20 \n" // q12 = d12 * w11 + + // "vld1.32 {d2-d3}, [%[rs_mask]]! 
@ load din00= 0 1 2 3 4 5 6 7
+      // 8 9\n" "vld1.32 {d4-d5}, [%[rs_mask]] @ load din00= 0 1 2 3 4
+      // 5 6 7 8 9\n"
+
+      // "sub %[dout_ptr1], #16 @ sub \n"
+
+      "vaddw.s16 q11, q11, d26 @addw \n"  // out1 +=
+                                          // vget_low_s16(out10)
+      "vaddw.s16 q12, q12, d27 @addw \n"  // out1_1 +=
+                                          // vget_high_s16(out10)
+      "vmov.u32 q8, #0 @ mov \n"
+
+      "vaddw.s16 q11, q11, d28 @addw \n"  // out1 +=
+                                          // vget_low_s16(out10)
+      "vaddw.s16 q12, q12, d29 @addw \n"  // out1_1 +=
+                                          // vget_high_s16(out10)
+
+      "vaddw.s16 q11, q11, d30 @addw \n"  // out1 +=
+                                          // vget_low_s16(out10)
+      "vaddw.s16 q12, q12, d31 @addw \n"  // out1_1 +=
+                                          // vget_high_s16(out10)
+
+      "vmax.s32 q11, q11, q8 @ max\n"
+      "vmax.s32 q12, q12, q8 @ max\n"
+
+      // "vbif q11, q6, q1 @ bit select, deal with right pad\n"
+      // "vbif q12, q7, q2 @ bit select, deal with right pad\n"
+
+      "vst1.32 {d22-d25}, [%[dout_ptr1]] @ store\n"
+      // "vst1.32 {d24-d25}, [%[dout_ptr1]]! @ store\n"
+      : [din_ptr0] "+r"(din_ptr0),
+        [din_ptr1] "+r"(din_ptr1),
+        [din_ptr2] "+r"(din_ptr2),
+        [bias] "+r"(bias_val),
+        [rs_mask] "+r"(rst_mask)
+      : [mask] "r"(vmask),
+        [size_pad_right] "r"(size_pad_right),
+        [dout_ptr1] "r"(out_buf1)
+      : "cc",
+        "memory",
+        "q0",
+        "q1",
+        "q2",
+        "q3",
+        "q4",
+        "q5",
+        "q6",
+        "q7",
+        "q8",
+        "q9",
+        "q10",
+        "q11",
+        "q12",
+        "q13",
+        "q14",
+        "q15");
+#endif
+      for (int w = 0; w < w_out; ++w) {
+        *doutr0++ = out_buf1[w];
+      }
+      dout_ptr += w_out;
+    }
+  }
+}
+
+}  // namespace math
+}  // namespace arm
+}  // namespace lite
+}  // namespace paddle
diff --git a/lite/arm/math/conv_depthwise_3x3p0.cc b/lite/arm/math/conv_depthwise_3x3p0.cc
new file mode 100644
index 00000000000..9eb45514d23
--- /dev/null
+++ b/lite/arm/math/conv_depthwise_3x3p0.cc
@@ -0,0 +1,4178 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lite/arm/math/conv_depthwise.h"
+#include <arm_neon.h>
+
+namespace paddle {
+namespace lite {
+namespace arm {
+namespace math {
+
+void conv_depthwise_3x3s1p0_bias(float* dout,
+                                 const float* din,
+                                 const float* weights,
+                                 const float* bias,
+                                 bool flag_bias,
+                                 const int num,
+                                 const int ch_in,
+                                 const int h_in,
+                                 const int w_in,
+                                 const int h_out,
+                                 const int w_out,
+                                 ARMContext* ctx);
+
+//! for input width <= 4
+void conv_depthwise_3x3s1p0_bias_s(float* dout,
+                                   const float* din,
+                                   const float* weights,
+                                   const float* bias,
+                                   bool flag_bias,
+                                   const int num,
+                                   const int ch_in,
+                                   const int h_in,
+                                   const int w_in,
+                                   const int h_out,
+                                   const int w_out,
+                                   ARMContext* ctx);
+
+void conv_depthwise_3x3s2p0_bias(float* dout,
+                                 const float* din,
+                                 const float* weights,
+                                 const float* bias,
+                                 bool flag_bias,
+                                 const int num,
+                                 const int ch_in,
+                                 const int h_in,
+                                 const int w_in,
+                                 const int h_out,
+                                 const int w_out,
+                                 ARMContext* ctx);
+
+//!
for input width <= 4 +void conv_depthwise_3x3s2p0_bias_s(float* dout, + const float* din, + const float* weights, + const float* bias, + bool flag_bias, + const int num, + const int ch_in, + const int h_in, + const int w_in, + const int h_out, + const int w_out, + ARMContext* ctx); + +void conv_depthwise_3x3s1p0_bias_relu(float* dout, + const float* din, + const float* weights, + const float* bias, + bool flag_bias, + const int num, + const int ch_in, + const int h_in, + const int w_in, + const int h_out, + const int w_out, + ARMContext* ctx); + +//! for input width <= 4 +void conv_depthwise_3x3s1p0_bias_s_relu(float* dout, + const float* din, + const float* weights, + const float* bias, + bool flag_bias, + const int num, + const int ch_in, + const int h_in, + const int w_in, + const int h_out, + const int w_out, + ARMContext* ctx); + +void conv_depthwise_3x3s2p0_bias_relu(float* dout, + const float* din, + const float* weights, + const float* bias, + bool flag_bias, + const int num, + const int ch_in, + const int h_in, + const int w_in, + const int h_out, + const int w_out, + ARMContext* ctx); + +//! for input width <= 4 +void conv_depthwise_3x3s2p0_bias_s_relu(float* dout, + const float* din, + const float* weights, + const float* bias, + bool flag_bias, + const int num, + const int ch_in, + const int h_in, + const int w_in, + const int h_out, + const int w_out, + ARMContext* ctx); + +void conv_depthwise_3x3p0(const float* din, + float* dout, + int num, + int ch_out, + int h_out, + int w_out, + int ch_in, + int h_in, + int w_in, + const float* weights, + const float* bias, + int stride, + bool flag_bias, + bool flag_relu, + ARMContext* ctx) { + if (stride == 1) { + if (flag_relu) { + if (w_in > 5) { + conv_depthwise_3x3s1p0_bias_relu(dout, + din, + weights, + bias, + flag_bias, + num, + ch_in, + h_in, + w_in, + h_out, + w_out, + ctx); + } else { + conv_depthwise_3x3s1p0_bias_s_relu(dout, + din, + weights, + bias, + flag_bias, + num, + ch_in, + h_in, + w_in, + h_out, + w_out, + ctx); + } + } else { + if (w_in > 5) { + conv_depthwise_3x3s1p0_bias(dout, + din, + weights, + bias, + flag_bias, + num, + ch_in, + h_in, + w_in, + h_out, + w_out, + ctx); + } else { + conv_depthwise_3x3s1p0_bias_s(dout, + din, + weights, + bias, + flag_bias, + num, + ch_in, + h_in, + w_in, + h_out, + w_out, + ctx); + } + } + } else { //! stride = 2 + if (flag_relu) { + if (w_in > 8) { + conv_depthwise_3x3s2p0_bias_relu(dout, + din, + weights, + bias, + flag_bias, + num, + ch_in, + h_in, + w_in, + h_out, + w_out, + ctx); + } else { + conv_depthwise_3x3s2p0_bias_s_relu(dout, + din, + weights, + bias, + flag_bias, + num, + ch_in, + h_in, + w_in, + h_out, + w_out, + ctx); + } + } else { + if (w_in > 8) { + conv_depthwise_3x3s2p0_bias(dout, + din, + weights, + bias, + flag_bias, + num, + ch_in, + h_in, + w_in, + h_out, + w_out, + ctx); + } else { + conv_depthwise_3x3s2p0_bias_s(dout, + din, + weights, + bias, + flag_bias, + num, + ch_in, + h_in, + w_in, + h_out, + w_out, + ctx); + } + } + } +} +/** + * \brief depthwise convolution, kernel size 3x3, stride 1, pad 1, with bias, + * width > 4 + */ +// 4line +void conv_depthwise_3x3s1p0_bias(float* dout, + const float* din, + const float* weights, + const float* bias, + bool flag_bias, + const int num, + const int ch_in, + const int h_in, + const int w_in, + const int h_out, + const int w_out, + ARMContext* ctx) { + //! pad is done implicit + const float zero[8] = {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f}; + //! 
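+  // NOTE (illustrative, not part of the original patch): this "4line"
+  // kernel emits four output rows per iteration; a 3x3 window at stride 1
+  // with no padding means output row r needs input rows r, r+1 and r+2,
+  // so each iteration reads six consecutive input rows:
+  //   out rows i .. i+3  <-  in rows i .. i+5   (dr0 .. dr5 below)
+  //!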
for 4x6 convolution window + const unsigned int right_pad_idx[8] = {5, 4, 3, 2, 1, 0, 0, 0}; + + float* zero_ptr = ctx->workspace_data<float>(); + memset(zero_ptr, 0, w_in * sizeof(float)); + float* write_ptr = zero_ptr + w_in; + + int size_in_channel = w_in * h_in; + int size_out_channel = w_out * h_out; + int w_stride = 9; + + int tile_w = w_out >> 2; + int remain = w_out % 4; + + unsigned int size_pad_right = (unsigned int)(6 + (tile_w << 2) - w_in); + const int remain_idx[4] = {0, 1, 2, 3}; + + uint32x4_t vmask_rp1 = + vcgeq_u32(vld1q_u32(right_pad_idx), vdupq_n_u32(size_pad_right)); + uint32x4_t vmask_rp2 = + vcgeq_u32(vld1q_u32(right_pad_idx + 4), vdupq_n_u32(size_pad_right)); + uint32x4_t vmask_result = + vcgtq_s32(vdupq_n_s32(remain), vld1q_s32(remain_idx)); + + unsigned int vmask[8]; + vst1q_u32(vmask, vmask_rp1); + vst1q_u32(vmask + 4, vmask_rp2); + + unsigned int rmask[4]; + vst1q_u32(rmask, vmask_result); + + float32x4_t vzero = vdupq_n_f32(0.f); + + for (int n = 0; n < num; ++n) { + const float* din_batch = din + n * ch_in * size_in_channel; + float* dout_batch = dout + n * ch_in * size_out_channel; +#pragma omp parallel for +#ifdef __aarch64__ + for (int c = 0; c < ch_in; c++) { + float* dout_ptr = dout_batch + c * size_out_channel; + + const float* din_ch_ptr = din_batch + c * size_in_channel; + + float bias_val = flag_bias ? bias[c] : 0.f; + float vbias[4] = {bias_val, bias_val, bias_val, bias_val}; + + const float* wei_ptr = weights + c * w_stride; + + float32x4_t wr0 = vld1q_f32(wei_ptr); + float32x4_t wr1 = vld1q_f32(wei_ptr + 3); + float32x4_t wr2 = vld1q_f32(wei_ptr + 6); + // wr0 = vsetq_lane_f32(0.f, wr0, 3); + // wr1 = vsetq_lane_f32(0.f, wr1, 3); + // wr2 = vsetq_lane_f32(0.f, wr2, 3); + + float* doutr0 = dout_ptr; + float* doutr1 = doutr0 + w_out; + float* doutr2 = doutr1 + w_out; + float* doutr3 = doutr2 + w_out; + + const float* dr0 = din_ch_ptr; + const float* dr1 = dr0 + w_in; + const float* dr2 = dr1 + w_in; + const float* dr3 = dr2 + w_in; + const float* dr4 = dr3 + w_in; + const float* dr5 = dr4 + w_in; + + const float* din_ptr0 = dr0; + const float* din_ptr1 = dr1; + const float* din_ptr2 = dr2; + const float* din_ptr3 = dr3; + const float* din_ptr4 = dr4; + const float* din_ptr5 = dr5; + + for (int i = 0; i < h_out; i += 4) { + //! set input row pointers, no top pad for p0 + din_ptr0 = dr0; + din_ptr1 = dr1; + din_ptr2 = dr2; + din_ptr3 = dr3; + din_ptr4 = dr4; + din_ptr5 = dr5; + + doutr0 = dout_ptr; + doutr1 = doutr0 + w_out; + doutr2 = doutr1 + w_out; + doutr3 = doutr2 + w_out; + + dr0 = dr4; + dr1 = dr5; + dr2 = dr1 + w_in; + dr3 = dr2 + w_in; + dr4 = dr3 + w_in; + dr5 = dr4 + w_in; + + //! process bottom pad + if (i + 5 >= h_in) { + switch (i + 5 - h_in) { + case 5: + din_ptr1 = zero_ptr; + case 4: + din_ptr2 = zero_ptr; + case 3: + din_ptr3 = zero_ptr; + case 2: + din_ptr4 = zero_ptr; + case 1: + din_ptr5 = zero_ptr; + case 0: + din_ptr5 = zero_ptr; + default: + break; + } + } + //!
process bottom remain + if (i + 4 > h_out) { + switch (i + 4 - h_out) { + case 3: + doutr1 = write_ptr; + case 2: + doutr2 = write_ptr; + case 1: + doutr3 = write_ptr; + default: + break; + } + } + + int cnt = tile_w; + asm volatile( + "PRFM PLDL1KEEP, [%[din_ptr0]] \n" + "PRFM PLDL1KEEP, [%[din_ptr1]] \n" + "PRFM PLDL1KEEP, [%[din_ptr2]] \n" + "PRFM PLDL1KEEP, [%[din_ptr3]] \n" + "PRFM PLDL1KEEP, [%[din_ptr4]] \n" + "PRFM PLDL1KEEP, [%[din_ptr5]] \n" + "movi v21.4s, #0x0\n" /* out0 = 0 */ + + "ld1 {v0.4s}, [%[din_ptr0]], #16 \n" /*vld1q_f32(din_ptr0)*/ + "ld1 {v2.4s}, [%[din_ptr1]], #16 \n" /*vld1q_f32(din_ptr0)*/ + "ld1 {v4.4s}, [%[din_ptr2]], #16 \n" /*vld1q_f32(din_ptr0)*/ + "ld1 {v6.4s}, [%[din_ptr3]], #16 \n" /*vld1q_f32(din_ptr0)*/ + + "ld1 {v1.4s}, [%[din_ptr0]] \n" /*vld1q_f32(din_ptr0)*/ + "ld1 {v3.4s}, [%[din_ptr1]] \n" /*vld1q_f32(din_ptr0)*/ + "ld1 {v5.4s}, [%[din_ptr2]] \n" /*vld1q_f32(din_ptr0)*/ + "ld1 {v7.4s}, [%[din_ptr3]] \n" /*vld1q_f32(din_ptr0)*/ + + "ld1 {v12.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ + "ld1 {v13.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ + "ld1 {v14.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ + "ld1 {v15.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ + + "ext v16.16b, v0.16b, v1.16b, #4 \n" /* v16 = 1234 */ + "ext v17.16b, v0.16b, v1.16b, #8 \n" /* v17 = 2345 */ + + // mid + // "cmp %[cnt], #1 \n" + // "blt 5f \n" + "4: \n" + // r0 + "fmla v12.4s , v0.4s, %[w0].s[0]\n" /* outr00 += din0_0123 * + w0[0]*/ + + "ld1 {v8.4s}, [%[din_ptr4]], #16 \n" /*vld1q_f32(din_ptr0)*/ + "ld1 {v10.4s}, [%[din_ptr5]], #16 \n" /*vld1q_f32(din_ptr0)*/ + "ld1 {v0.4s}, [%[din_ptr0]], #16 \n" /*vld1q_f32(din_ptr0)*/ + + "fmla v12.4s , v16.4s, %[w0].s[1]\n" /* outr00 += din0_1234 * + w0[1]*/ + + "ld1 {v9.4s}, [%[din_ptr4]] \n" /*vld1q_f32(din_ptr0)*/ + "ld1 {v11.4s}, [%[din_ptr5]] \n" /*vld1q_f32(din_ptr0)*/ + "ld1 {v1.4s}, [%[din_ptr0]] \n" /*vld1q_f32(din_ptr0)*/ + + "fmla v12.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din0_2345 * + w0[2]*/ + + "ext v16.16b, v2.16b, v3.16b, #4 \n" /* v16 = 1234*/ + "ext v17.16b, v2.16b, v3.16b, #8 \n" /* v16 = 2345 */ + + // r1 + "fmla v13.4s , v2.4s, %[w0].s[0]\n" /* outr00 += din0_0123 * + w0[0]*/ + "fmla v12.4s , v2.4s, %[w1].s[0]\n" /* outr00 += din0_0123 * + w0[0]*/ + + "ld1 {v2.4s}, [%[din_ptr1]], #16 \n" /*vld1q_f32(din_ptr0)*/ + + "fmla v13.4s , v16.4s, %[w0].s[1]\n" /* outr00 += din0_1234 * + w0[1]*/ + "fmla v12.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 * + w0[1]*/ + + "ld1 {v3.4s}, [%[din_ptr1]] \n" /*vld1q_f32(din_ptr0)*/ + + "fmla v13.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din0_2345 * + w0[2]*/ + "fmla v12.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 * + w0[2]*/ + + "ext v16.16b, v4.16b, v5.16b, #4 \n" /* v16 = 1234*/ + "ext v17.16b, v4.16b, v5.16b, #8 \n" /* v16 = 2345 */ + + // r2 + "fmla v14.4s , v4.4s, %[w0].s[0]\n" /* outr00 += din0_0123 * + w0[0]*/ + "fmla v13.4s , v4.4s, %[w1].s[0]\n" /* outr00 += din0_0123 * + w0[0]*/ + "fmla v12.4s , v4.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * + w0[0]*/ + + "ld1 {v4.4s}, [%[din_ptr2]], #16 \n" /*vld1q_f32(din_ptr0)*/ + + "fmla v14.4s , v16.4s, %[w0].s[1]\n" /* outr00 += din0_1234 * + w0[1]*/ + "fmla v13.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 * + w0[1]*/ + "fmla v12.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * + w0[1]*/ + + "ld1 {v5.4s}, [%[din_ptr2]] \n" /*vld1q_f32(din_ptr0)*/ + + "fmla v14.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din0_2345 * + w0[2]*/ + "fmla v13.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 * + w0[2]*/ + "fmla 
v12.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * + w0[2]*/ + + "ext v16.16b, v6.16b, v7.16b, #4 \n" /* v16 = 1234*/ + "ext v17.16b, v6.16b, v7.16b, #8 \n" /* v16 = 2345 */ + + // r3 + "fmla v15.4s , v6.4s, %[w0].s[0]\n" /* outr00 += din0_0123 * + w0[0]*/ + "fmla v14.4s , v6.4s, %[w1].s[0]\n" /* outr00 += din0_0123 * + w0[0]*/ + "fmla v13.4s , v6.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * + w0[0]*/ + + "ld1 {v6.4s}, [%[din_ptr3]], #16 \n" /*vld1q_f32(din_ptr0)*/ + "st1 {v12.4s}, [%[doutr0]], #16 \n" + + "fmla v15.4s , v16.4s, %[w0].s[1]\n" /* outr00 += din0_1234 * + w0[1]*/ + "fmla v14.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 * + w0[1]*/ + "fmla v13.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * + w0[1]*/ + + "ld1 {v7.4s}, [%[din_ptr3]] \n" /*vld1q_f32(din_ptr0)*/ + "ld1 {v12.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ + + "fmla v15.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din0_2345 * + w0[2]*/ + "fmla v14.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 * + w0[2]*/ + "fmla v13.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * + w0[2]*/ + + "ext v16.16b, v8.16b, v9.16b, #4 \n" /* v16 = 1234*/ + "ext v17.16b, v8.16b, v9.16b, #8 \n" /* v16 = 2345 */ + + // r4 + "fmla v15.4s , v8.4s, %[w1].s[0]\n" /* outr00 += din0_0123 * + w0[0]*/ + "fmla v14.4s , v8.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * + w0[0]*/ + + "st1 {v13.4s}, [%[doutr1]], #16 \n" + + "fmla v15.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 * + w0[1]*/ + "fmla v14.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * + w0[1]*/ + + "ld1 {v13.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ + + "fmla v15.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 * + w0[2]*/ + "fmla v14.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * + w0[2]*/ + + "ext v16.16b, v10.16b, v11.16b, #4 \n" /* v16 = 1234*/ + "ext v17.16b, v10.16b, v11.16b, #8 \n" /* v16 = 2345 */ + + // r5 + "fmla v15.4s , v10.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * + w0[0]*/ + + "st1 {v14.4s}, [%[doutr2]], #16 \n" + + "fmla v15.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * + w0[1]*/ + + "ld1 {v14.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ + + "fmla v15.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * + w0[2]*/ + + "ext v16.16b, v0.16b, v1.16b, #4 \n" /* v16 = 1234*/ + "ext v17.16b, v0.16b, v1.16b, #8 \n" /* v16 = 2345 */ + + "subs %[cnt], %[cnt], #1 \n" + + "st1 {v15.4s}, [%[doutr3]], #16 \n" + "ld1 {v15.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ + + "bne 4b \n" + + // right + "5: \n" + "cmp %[remain], #1 \n" + "blt 0f \n" + "ld1 {v18.4s, v19.4s}, [%[vmask]] \n" + "ld1 {v22.4s}, [%[doutr0]] \n" + "ld1 {v23.4s}, [%[doutr1]] \n" + "ld1 {v24.4s}, [%[doutr2]] \n" + "ld1 {v25.4s}, [%[doutr3]] \n" + + "bif v0.16b, %[vzero].16b, v18.16b \n" + "bif v1.16b, %[vzero].16b, v19.16b \n" + "bif v2.16b, %[vzero].16b, v18.16b \n" + "bif v3.16b, %[vzero].16b, v19.16b \n" + + "ext v16.16b, v0.16b, v1.16b, #4 \n" /* v16 = 1234*/ + "ext v17.16b, v0.16b, v1.16b, #8 \n" /* v16 = 2345 */ + "ld1 {v8.4s}, [%[din_ptr4]], #16 \n" /*vld1q_f32(din_ptr0)*/ + "ld1 {v10.4s}, [%[din_ptr5]], #16 \n" /*vld1q_f32(din_ptr0)*/ + "ld1 {v9.4s}, [%[din_ptr4]] \n" /*vld1q_f32(din_ptr0)*/ + "ld1 {v11.4s}, [%[din_ptr5]] \n" /*vld1q_f32(din_ptr0)*/ + + // r0 + "fmla v12.4s, v0.4s, %[w0].s[0]\n" /* outr00 += din0_0123 * + w0[0]*/ + + "bif v4.16b, %[vzero].16b, v18.16b \n" + "bif v5.16b, %[vzero].16b, v19.16b \n" + "bif v6.16b, %[vzero].16b, v18.16b \n" + "bif v7.16b, %[vzero].16b, v19.16b \n" + + "fmla v12.4s, v16.4s, %[w0].s[1]\n" /* outr00 += din0_1234 * + w0[1]*/ + + 
"bif v8.16b, %[vzero].16b, v18.16b \n" + "bif v9.16b, %[vzero].16b, v19.16b \n" + "bif v10.16b, %[vzero].16b, v18.16b \n" + "bif v11.16b, %[vzero].16b, v19.16b \n" + + "fmla v12.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din0_2345 * + w0[2]*/ + + "ext v16.16b, v2.16b, v3.16b, #4 \n" /* v16 = 1234*/ + "ext v17.16b, v2.16b, v3.16b, #8 \n" /* v16 = 2345 */ + "ld1 {v18.4s}, [%[rmask]] \n" + + // r1 + "fmla v13.4s , v2.4s, %[w0].s[0]\n" /* outr00 += din0_0123 * + w0[0]*/ + "fmla v12.4s , v2.4s, %[w1].s[0]\n" /* outr00 += din0_0123 * + w0[0]*/ + + "fmla v13.4s , v16.4s, %[w0].s[1]\n" /* outr00 += din0_1234 * + w0[1]*/ + "fmla v12.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 * + w0[1]*/ + + "fmla v13.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din0_2345 * + w0[2]*/ + "fmla v12.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 * + w0[2]*/ + + "ext v16.16b, v4.16b, v5.16b, #4 \n" /* v16 = 1234*/ + "ext v17.16b, v4.16b, v5.16b, #8 \n" /* v16 = 2345 */ + + // r2 + "fmla v14.4s , v4.4s, %[w0].s[0]\n" /* outr00 += din0_0123 * + w0[0]*/ + "fmla v13.4s , v4.4s, %[w1].s[0]\n" /* outr00 += din0_0123 * + w0[0]*/ + "fmla v12.4s , v4.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * + w0[0]*/ + + "fmla v14.4s , v16.4s, %[w0].s[1]\n" /* outr00 += din0_1234 * + w0[1]*/ + "fmla v13.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 * + w0[1]*/ + "fmla v12.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * + w0[1]*/ + + "fmla v14.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din0_2345 * + w0[2]*/ + "fmla v13.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 * + w0[2]*/ + "fmla v12.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * + w0[2]*/ + + "ext v16.16b, v6.16b, v7.16b, #4 \n" /* v16 = 1234*/ + "ext v17.16b, v6.16b, v7.16b, #8 \n" /* v16 = 2345 */ + + // r3 + "fmla v15.4s , v6.4s, %[w0].s[0]\n" /* outr00 += din0_0123 * + w0[0]*/ + "fmla v14.4s , v6.4s, %[w1].s[0]\n" /* outr00 += din0_0123 * + w0[0]*/ + "fmla v13.4s , v6.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * + w0[0]*/ + + "bif v12.16b, v22.16b, v18.16b \n" + + "fmla v15.4s , v16.4s, %[w0].s[1]\n" /* outr00 += din0_1234 * + w0[1]*/ + "fmla v14.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 * + w0[1]*/ + "fmla v13.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * + w0[1]*/ + + "st1 {v12.4s}, [%[doutr0]], #16 \n" + + "fmla v15.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din0_2345 * + w0[2]*/ + "fmla v14.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 * + w0[2]*/ + "fmla v13.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * + w0[2]*/ + + "ext v16.16b, v8.16b, v9.16b, #4 \n" /* v16 = 1234*/ + "ext v17.16b, v8.16b, v9.16b, #8 \n" /* v16 = 2345 */ + + // r3 + "fmla v15.4s , v8.4s, %[w1].s[0]\n" /* outr00 += din0_0123 * + w0[0]*/ + "fmla v14.4s , v8.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * + w0[0]*/ + + "bif v13.16b, v23.16b, v18.16b \n" + + "fmla v15.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 * + w0[1]*/ + "fmla v14.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * + w0[1]*/ + + "st1 {v13.4s}, [%[doutr1]], #16 \n" + + "fmla v15.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 * + w0[2]*/ + "fmla v14.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * + w0[2]*/ + + "ext v16.16b, v10.16b, v11.16b, #4 \n" /* v16 = 1234*/ + "ext v17.16b, v10.16b, v11.16b, #8 \n" /* v16 = 2345 */ + + // r3 + "fmla v15.4s , v10.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * + w0[0]*/ + + "bif v14.16b, v24.16b, v18.16b \n" + + "fmla v15.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * + w0[1]*/ + + "st1 {v14.4s}, [%[doutr2]], #16 \n" + + "fmla v15.4s , v17.4s, %[w2].s[2]\n" /* outr00 
+= din0_2345 * + w0[2]*/ + + "bif v15.16b, v25.16b, v18.16b \n" + + "st1 {v15.4s}, [%[doutr3]], #16 \n" + // end + "0: \n" + : [cnt] "+r"(cnt), + [din_ptr0] "+r"(din_ptr0), + [din_ptr1] "+r"(din_ptr1), + [din_ptr2] "+r"(din_ptr2), + [din_ptr3] "+r"(din_ptr3), + [din_ptr4] "+r"(din_ptr4), + [din_ptr5] "+r"(din_ptr5), + [doutr0] "+r"(doutr0), + [doutr1] "+r"(doutr1), + [doutr2] "+r"(doutr2), + [doutr3] "+r"(doutr3) + : [w0] "w"(wr0), + [w1] "w"(wr1), + [w2] "w"(wr2), + [bias_val] "r"(vbias), + [vmask] "r"(vmask), + [rmask] "r"(rmask), + [vzero] "w"(vzero), + [remain] "r"(remain) + : "cc", + "memory", + "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15", + "v16", + "v17", + "v18", + "v19", + "v20", + "v21", + "v22", + "v23", + "v24", + "v25"); + dout_ptr = dout_ptr + 4 * w_out; + } + } +#else + for (int i = 0; i < ch_in; ++i) { + const float* din_channel = din_batch + i * size_in_channel; + + const float* weight_ptr = weights + i * 9; + float32x4_t wr0 = vld1q_f32(weight_ptr); + float32x4_t wr1 = vld1q_f32(weight_ptr + 3); + float32x4_t wr2 = vld1q_f32(weight_ptr + 6); + float bias_val = flag_bias ? bias[i] : 0.f; + + float* dout_channel = dout_batch + i * size_out_channel; + + const float* dr0 = din_channel; + const float* dr1 = dr0 + w_in; + const float* dr2 = dr1 + w_in; + const float* dr3 = dr2 + w_in; + + const float* din0_ptr = nullptr; + const float* din1_ptr = nullptr; + const float* din2_ptr = nullptr; + const float* din3_ptr = nullptr; + + float* doutr0 = nullptr; + float* doutr1 = nullptr; + + float* ptr_zero = const_cast<float*>(zero); + + for (int i = 0; i < h_out; i += 2) { + din0_ptr = dr0; + din1_ptr = dr1; + din2_ptr = dr2; + din3_ptr = dr3; + + doutr0 = dout_channel; + doutr1 = dout_channel + w_out; + + dr0 = dr2; + dr1 = dr3; + dr2 = dr1 + w_in; + dr3 = dr2 + w_in; + //! process bottom pad + if (i + 3 >= h_in) { + switch (i + 3 - h_in) { + case 3: + din1_ptr = zero_ptr; + case 2: + din2_ptr = zero_ptr; + case 1: + din3_ptr = zero_ptr; + case 0: + din3_ptr = zero_ptr; + default: + break; + } + } + //! process bottom remain + if (i + 2 > h_out) { + doutr1 = write_ptr; + } + int cnt = tile_w; + unsigned int* rmask_ptr = rmask; + unsigned int* vmask_ptr = vmask; + asm volatile( + "pld [%[din0_ptr]] @ preload data\n" + "pld [%[din1_ptr]] @ preload data\n" + "pld [%[din2_ptr]] @ preload data\n" + "pld [%[din3_ptr]] @ preload data\n" + + "vld1.32 {d16-d17}, [%[din0_ptr]]! @ load din r0\n" + "vld1.32 {d20-d21}, [%[din1_ptr]]! @ load din r1\n" + "vld1.32 {d24-d25}, [%[din2_ptr]]! @ load din r2\n" + "vld1.32 {d28-d29}, [%[din3_ptr]]! @ load din r3\n" + "vld1.32 {d18}, [%[din0_ptr]] @ load din r0\n" + "vld1.32 {d22}, [%[din1_ptr]] @ load din r0\n" + "vld1.32 {d26}, [%[din2_ptr]] @ load din r0\n" + "vld1.32 {d30}, [%[din3_ptr]] @ load din r0\n" + + "vdup.32 q4, %[bias_val] @ and \n" // q4 + // = + // vbias + "vdup.32 q5, %[bias_val] @ and \n" // q5 + // = + // vbias + + "vext.32 q6, q8, q9, #1 @ 1234\n" + "vext.32 q7, q8, q9, #2 @ 2345\n" + // mid + "1: @ right pad entry\n" + // r0 + "vmla.f32 q4, q8, %e[wr0][0] @ q4 += 0123 * wr0[0]\n" + + "pld [%[din0_ptr]] @ preload data\n" + "pld [%[din1_ptr]] @ preload data\n" + "pld [%[din2_ptr]] @ preload data\n" + "pld [%[din3_ptr]] @ preload data\n" + + "vmla.f32 q4, q6, %e[wr0][1] @ q4 += 1234 * wr0[1]\n" + + "vld1.32 {d16-d17}, [%[din0_ptr]]! 
@ load din r0\n" + + "vmla.f32 q4, q7, %f[wr0][0] @ q4 += 2345 * wr0[2]\n" + + "vld1.32 {d18}, [%[din0_ptr]] @ load din r0\n" + + "vext.32 q6, q10, q11, #1 @ 1234\n" + "vext.32 q7, q10, q11, #2 @ 2345\n" + + // r1 + "vmla.f32 q5, q10, %e[wr0][0] @ q4 += 1234 * wr0[0]\n" + "vmla.f32 q4, q10, %e[wr1][0] @ q4 += 1234 * wr0[1]\n" + + "vld1.32 {d20-d21}, [%[din1_ptr]]! @ load din r0\n" + + "vmla.f32 q5, q6, %e[wr0][1] @ q4 += 1234 * wr0[1]\n" + "vmla.f32 q4, q6, %e[wr1][1] @ q4 += 1234 * wr0[1]\n" + + "vld1.32 {d22}, [%[din1_ptr]] @ load din r0\n" + + "vmla.f32 q5, q7, %f[wr0][0] @ q4 += 1234 * wr0[1]\n" + "vmla.f32 q4, q7, %f[wr1][0] @ q4 += 1234 * wr0[1]\n" + + "vext.32 q6, q12, q13, #1 @ 1234\n" + "vext.32 q7, q12, q13, #2 @ 2345\n" + + // r2 + "vmla.f32 q5, q12, %e[wr1][0] @ q4 += 1234 * wr0[0]\n" + "vmla.f32 q4, q12, %e[wr2][0] @ q4 += 1234 * wr0[1]\n" + + "vld1.32 {d24-d25}, [%[din2_ptr]]! @ load din r0\n" + + "vmla.f32 q5, q6, %e[wr1][1] @ q4 += 1234 * wr0[1]\n" + "vmla.f32 q4, q6, %e[wr2][1] @ q4 += 1234 * wr0[1]\n" + + "vld1.32 {d26}, [%[din2_ptr]] @ load din r0\n" + + "vmla.f32 q5, q7, %f[wr1][0] @ q4 += 1234 * wr0[1]\n" + "vmla.f32 q4, q7, %f[wr2][0] @ q4 += 1234 * wr0[1]\n" + + "vext.32 q6, q14, q15, #1 @ 1234\n" + "vext.32 q7, q14, q15, #2 @ 2345\n" + + // r3 + "vmla.f32 q5, q14, %e[wr2][0] @ q4 += 0123 * wr0[0]\n" + + "vld1.32 {d28-d29}, [%[din3_ptr]]! @ load din r0\n" + "vst1.32 {d8-d9}, [%[dout_ptr1]]! @ store result, add pointer\n" + + "vmla.f32 q5, q6, %e[wr2][1] @ q4 += 1234 * wr0[1]\n" + + "vld1.32 {d30}, [%[din3_ptr]] @ load din r0\n" + "vdup.32 q4, %[bias_val] @ and \n" // q4 + // = + // vbias + + "vmla.f32 q5, q7, %f[wr2][0] @ q4 += 2345 * wr0[2]\n" + + "vext.32 q6, q8, q9, #1 @ 1234\n" + "vext.32 q7, q8, q9, #2 @ 2345\n" + + "vst1.32 {d10-d11}, [%[dout_ptr2]]! @ store result, add " + "pointer\n" + + "subs %[cnt], #1 @ loop count minus 1\n" + + "vdup.32 q5, %[bias_val] @ and \n" // q4 + // = + // vbias + + "bne 1b @ jump to main loop start " + "point\n" + + // right + "3: @ right pad entry\n" + "cmp %[remain], #1 @ check whether has " + "mid cols\n" + "blt 0f @ jump to main loop start " + "point\n" + "vld1.32 {d19}, [%[vmask]]! @ load din r0\n" + "vld1.32 {d23}, [%[vmask]]! @ load din r0\n" + + "vld1.32 {d27}, [%[vmask]]! @ load din r0\n" + "vld1.32 {d31}, [%[vmask]]! 
@ load din r0\n" + + "vbif d16, %e[vzero], d19 @ bit select, deal with " + "right pad\n" + "vbif d17, %e[vzero], d23 @ bit select, deal with " + "right pad\n" + "vbif d18, %e[vzero], d27 @ bit select, deal with " + "right pad\n" + + "vbif d20, %e[vzero], d19 @ bit select, deal with " + "right pad\n" + "vbif d21, %e[vzero], d23 @ bit select, deal with " + "right pad\n" + "vbif d22, %e[vzero], d27 @ bit select, deal with " + "right pad\n" + + "vext.32 q6, q8, q9, #1 @ 1234\n" + "vext.32 q7, q8, q9, #2 @ 2345\n" + + // r0 + "vmla.f32 q4, q8, %e[wr0][0] @ q4 += 0123 * wr0[0]\n" + + "vbif d24, %e[vzero], d19 @ bit select, deal with " + "right pad\n" + "vbif d25, %e[vzero], d23 @ bit select, deal with " + "right pad\n" + "vbif d26, %e[vzero], d27 @ bit select, deal with " + "right pad\n" + + "vmla.f32 q4, q6, %e[wr0][1] @ q4 += 1234 * wr0[1]\n" + + "vbif d28, %e[vzero], d19 @ bit select, deal with " + "right pad\n" + "vbif d29, %e[vzero], d23 @ bit select, deal with " + "right pad\n" + "vbif d30, %e[vzero], d27 @ bit select, deal with " + "right pad\n" + + "vmla.f32 q4, q7, %f[wr0][0] @ q4 += 2345 * wr0[2]\n" + + "vext.32 q6, q10, q11, #1 @ 1234\n" + "vext.32 q7, q10, q11, #2 @ 2345\n" + + // r1 + "vmla.f32 q5, q10, %e[wr0][0] @ q4 += 1234 * wr0[0]\n" + "vmla.f32 q4, q10, %e[wr1][0] @ q4 += 1234 * wr0[1]\n" + + "vld1.32 {d19}, [%[rmask]]! @ load din r0\n" + "vld1.32 {d23}, [%[rmask]]! @ load din r0\n" + + "vmla.f32 q5, q6, %e[wr0][1] @ q4 += 1234 * wr0[1]\n" + "vmla.f32 q4, q6, %e[wr1][1] @ q4 += 1234 * wr0[1]\n" + + "vld1.32 {d16-d17}, [%[dout_ptr1]] @ load din r0\n" + "vld1.32 {d20-d21}, [%[dout_ptr2]] @ load din r0\n" + + "vmla.f32 q5, q7, %f[wr0][0] @ q4 += 1234 * wr0[1]\n" + "vmla.f32 q4, q7, %f[wr1][0] @ q4 += 1234 * wr0[1]\n" + + "vext.32 q6, q12, q13, #1 @ 1234\n" + "vext.32 q7, q12, q13, #2 @ 2345\n" + + // r2 + "vmla.f32 q5, q12, %e[wr1][0] @ q4 += 1234 * wr0[0]\n" + "vmla.f32 q4, q12, %e[wr2][0] @ q4 += 1234 * wr0[1]\n" + + "vmla.f32 q5, q6, %e[wr1][1] @ q4 += 1234 * wr0[1]\n" + "vmla.f32 q4, q6, %e[wr2][1] @ q4 += 1234 * wr0[1]\n" + + "vmla.f32 q5, q7, %f[wr1][0] @ q4 += 1234 * wr0[1]\n" + "vmla.f32 q4, q7, %f[wr2][0] @ q4 += 1234 * wr0[1]\n" + + "vext.32 q6, q14, q15, #1 @ 1234\n" + "vext.32 q7, q14, q15, #2 @ 2345\n" + + // r3 + "vmla.f32 q5, q14, %e[wr2][0] @ q4 += 0123 * wr0[0]\n" + + "vbif d8, d16, d19 @ bit select, deal with right pad\n" + "vbif d9, d17, d23 @ bit select, deal with right pad\n" + + "vmla.f32 q5, q6, %e[wr2][1] @ q4 += 1234 * wr0[1]\n" + + "vst1.32 {d8-d9}, [%[dout_ptr1]]! @ store result, add pointer\n" + + "vmla.f32 q5, q7, %f[wr2][0] @ q4 += 2345 * wr0[2]\n" + + "vbif d10, d20, d19 @ bit select, deal with right " + "pad\n" + "vbif d11, d21, d23 @ bit select, deal with right " + "pad\n" + + "vst1.32 {d10-d11}, [%[dout_ptr2]]! @ store result, add " + "pointer\n" + "0: \n" + + : [dout_ptr1] "+r"(doutr0), + [dout_ptr2] "+r"(doutr1), + [din0_ptr] "+r"(din0_ptr), + [din1_ptr] "+r"(din1_ptr), + [din2_ptr] "+r"(din2_ptr), + [din3_ptr] "+r"(din3_ptr), + [cnt] "+r"(cnt), + [rmask] "+r"(rmask_ptr), + [vmask] "+r"(vmask_ptr) + : [wr0] "w"(wr0), + [wr1] "w"(wr1), + [wr2] "w"(wr2), + [bias_val] "r"(bias_val), + [vzero] "w"(vzero), + [remain] "r"(remain) + : "cc", + "memory", + "q4", + "q5", + "q6", + "q7", + "q8", + "q9", + "q10", + "q11", + "q12", + "q13", + "q14", + "q15"); + dout_channel += 2 * w_out; + } //! 
end of processing mid rows + } +#endif + } +} + +/** + * \brief depthwise convolution kernel 3x3, stride 2 + */ +// w_in > 7 +void conv_depthwise_3x3s2p0_bias(float* dout, + const float* din, + const float* weights, + const float* bias, + bool flag_bias, + const int num, + const int ch_in, + const int h_in, + const int w_in, + const int h_out, + const int w_out, + ARMContext* ctx) { + int right_pad_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7}; + int out_pad_idx[4] = {0, 1, 2, 3}; + + int tile_w = w_out >> 2; + int cnt_remain = w_out % 4; + + unsigned int size_right_remain = (unsigned int)(w_in - (tile_w << 3)); + + uint32x4_t vmask_rp1 = vcgtq_s32(vdupq_n_s32(size_right_remain), + vld1q_s32(right_pad_idx)); // 0 2 4 6 + uint32x4_t vmask_rp2 = vcgtq_s32(vdupq_n_s32(size_right_remain), + vld1q_s32(right_pad_idx + 4)); // 1 3 5 7 + uint32x4_t wmask = + vcgtq_s32(vdupq_n_s32(cnt_remain), vld1q_s32(out_pad_idx)); // 0 1 2 3 + int size_in_channel = w_in * h_in; + int size_out_channel = w_out * h_out; + + float* zero_ptr = ctx->workspace_data<float>(); + memset(zero_ptr, 0, w_in * sizeof(float)); + float* write_ptr = zero_ptr + w_in; + + unsigned int dmask[12]; + + vst1q_u32(dmask, vmask_rp1); + vst1q_u32(dmask + 4, vmask_rp2); + vst1q_u32(dmask + 8, wmask); + + for (int n = 0; n < num; ++n) { + const float* din_batch = din + n * ch_in * size_in_channel; + float* dout_batch = dout + n * ch_in * size_out_channel; +#pragma omp parallel for + for (int i = 0; i < ch_in; ++i) { + const float* din_channel = din_batch + i * size_in_channel; + float* dout_channel = dout_batch + i * size_out_channel; + + const float* weight_ptr = weights + i * 9; + float32x4_t wr0 = vld1q_f32(weight_ptr); + float32x4_t wr1 = vld1q_f32(weight_ptr + 3); + float32x4_t wr2 = vld1q_f32(weight_ptr + 6); + + float32x4_t vzero = vdupq_n_f32(0.f); + + float32x4_t wbias; + float bias_c = 0.f; + if (flag_bias) { + wbias = vdupq_n_f32(bias[i]); + bias_c = bias[i]; + } else { + wbias = vdupq_n_f32(0.f); + } + + const float* dr0 = din_channel; + const float* dr1 = dr0 + w_in; + const float* dr2 = dr1 + w_in; + const float* dr3 = dr2 + w_in; + const float* dr4 = dr3 + w_in; + + const float* din0_ptr = dr0; + const float* din1_ptr = dr1; + const float* din2_ptr = dr2; + const float* din3_ptr = dr3; + const float* din4_ptr = dr4; + + float* doutr0 = dout_channel; + float* doutr0_ptr = nullptr; + float* doutr1_ptr = nullptr; + +#ifdef __aarch64__ + for (int i = 0; i < h_out; i += 2) { + din0_ptr = dr0; + din1_ptr = dr1; + din2_ptr = dr2; + din3_ptr = dr3; + din4_ptr = dr4; + + doutr0_ptr = doutr0; + doutr1_ptr = doutr0 + w_out; + + dr0 = dr4; + dr1 = dr0 + w_in; + dr2 = dr1 + w_in; + dr3 = dr2 + w_in; + dr4 = dr3 + w_in; + + //! process bottom pad + if (i + 4 >= h_in) { + switch (i + 4 - h_in) { + case 4: + din1_ptr = zero_ptr; + case 3: + din2_ptr = zero_ptr; + case 2: + din3_ptr = zero_ptr; + case 1: + din4_ptr = zero_ptr; + case 0: + din4_ptr = zero_ptr; + default: + break; + } + } + //! process output pad + if (i + 2 > h_out) { + doutr1_ptr = write_ptr; + } + int cnt = tile_w; + asm volatile( + // top + // Load up 12 elements (3 vectors) from each of the 5 input rows. 
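+ // The stride-2 kernel below leans on "ld2", which de-interleaves each + // input row on load: v0/v2/v4/v6/v8 receive the even columns {0,2,4,6} + // (the w*0 taps) and v1/v3/v5/v7/v9 the odd columns {1,3,5,7} (the w*1 + // taps); an "ext" against the look-ahead vector then forms {2,4,6,8} + // for the w*2 taps, so each 3x3 row at stride 2 reduces to three vector + // multiplies/multiply-accumulates per output vector.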
+ "0: \n" + "prfm pldl1keep, [%[inptr0]] \n" + "prfm pldl1keep, [%[inptr1]] \n" + "prfm pldl1keep, [%[inptr2]] \n" + "prfm pldl1keep, [%[inptr3]] \n" + "prfm pldl1keep, [%[inptr4]] \n" + "ld2 {v0.4s, v1.4s}, [%[inptr0]], #32 \n" // v0={0,2,4,6} + // v1={1,3,5,7} + "ld2 {v2.4s, v3.4s}, [%[inptr1]], #32 \n" + "ld2 {v4.4s, v5.4s}, [%[inptr2]], #32 \n" + "ld2 {v6.4s, v7.4s}, [%[inptr3]], #32 \n" + "ld2 {v8.4s, v9.4s}, [%[inptr4]], #32 \n" + + "and v16.16b, %[vbias].16b, %[vbias].16b \n" // v10 = vbias + "and v17.16b, %[vbias].16b, %[vbias].16b \n" // v16 = vbias + + "ld1 {v15.4s}, [%[inptr0]] \n" + "ld1 {v18.4s}, [%[inptr1]] \n" + "ld1 {v19.4s}, [%[inptr2]] \n" + "ld1 {v20.4s}, [%[inptr3]] \n" + "ld1 {v21.4s}, [%[inptr4]] \n" + + "ext v10.16b, v0.16b, v15.16b, #4 \n" // v10 = {2,4,6,8} + // mid + "2: \n" + // r0 + "fmul v11.4s, v0.4s, %[w0].s[0] \n" // {0,2,4,6} * w00 + "fmul v12.4s, v1.4s, %[w0].s[1] \n" // {1,3,5,7} * w01 + "fmla v16.4s, v10.4s, %[w0].s[2] \n" // {2,4,6,8} * w02 + + "ext v10.16b, v2.16b, v18.16b, #4 \n" // v10 = {2,4,6,8} + "ld2 {v0.4s, v1.4s}, [%[inptr0]], #32 \n" // v0={0,2,4,6} + // v1={1,3,5,7} + + // r1 + "fmla v11.4s, v2.4s, %[w1].s[0] \n" // {0,2,4,6} * w00 + "fmla v12.4s, v3.4s, %[w1].s[1] \n" // {1,3,5,7} * w01 + "fmla v16.4s, v10.4s, %[w1].s[2] \n" // {2,4,6,8} * w02 + + "ext v10.16b, v4.16b, v19.16b, #4 \n" // v10 = {2,4,6,8} + + "ld2 {v2.4s, v3.4s}, [%[inptr1]], #32 \n" + + // r2 + "fmul v13.4s, v4.4s, %[w0].s[0] \n" // {0,2,4,6} * w00 + "fmla v11.4s, v4.4s, %[w2].s[0] \n" // {0,2,4,6} * w00 + + "fmul v14.4s, v5.4s, %[w0].s[1] \n" // {1,3,5,7} * w01 + "fmla v12.4s, v5.4s, %[w2].s[1] \n" // {1,3,5,7} * w01 + + "fmla v17.4s, v10.4s, %[w0].s[2] \n" // {2,4,6,8} * w02 + "fmla v16.4s, v10.4s, %[w2].s[2] \n" // {2,4,6,8} * w02 + + "ext v10.16b, v6.16b, v20.16b, #4 \n" // v10 = {2,4,6,8} + + "ld2 {v4.4s, v5.4s}, [%[inptr2]], #32 \n" + + // r3 + "fmla v13.4s, v6.4s, %[w1].s[0] \n" // {0,2,4,6} * w00 + "fmla v14.4s, v7.4s, %[w1].s[1] \n" // {1,3,5,7} * w01 + "fmla v17.4s, v10.4s, %[w1].s[2] \n" // {2,4,6,8} * w02 + + "ext v10.16b, v8.16b, v21.16b, #4 \n" // v10 = {2,4,6,8} + + "ld2 {v6.4s, v7.4s}, [%[inptr3]], #32 \n" + + "fadd v16.4s, v16.4s, v11.4s \n" + "fadd v16.4s, v16.4s, v12.4s \n" + + // r4 + "fmla v13.4s, v8.4s, %[w2].s[0] \n" // {0,2,4,6} * w00 + "fmla v14.4s, v9.4s, %[w2].s[1] \n" // {1,3,5,7} * w01 + "fmla v17.4s, v10.4s, %[w2].s[2] \n" // {2,4,6,8} * w02 + + "ld2 {v8.4s, v9.4s}, [%[inptr4]], #32 \n" + "ld1 {v15.4s}, [%[inptr0]] \n" + "ld1 {v18.4s}, [%[inptr1]] \n" + "st1 {v16.4s}, [%[outptr0]], #16 \n" + + "fadd v17.4s, v17.4s, v13.4s \n" + + "ld1 {v19.4s}, [%[inptr2]] \n" + "ld1 {v20.4s}, [%[inptr3]] \n" + "ld1 {v21.4s}, [%[inptr4]] \n" + + "fadd v17.4s, v17.4s, v14.4s \n" + + "ext v10.16b, v0.16b, v15.16b, #4 \n" // v10 = {2,4,6,8} + "and v16.16b, %[vbias].16b, %[vbias].16b \n" // v10 = vbias + "subs %[cnt], %[cnt], #1 \n" + + "st1 {v17.4s}, [%[outptr1]], #16 \n" + + "and v17.16b, %[vbias].16b, %[vbias].16b \n" // v16 = vbias + + "bne 2b \n" + + // right + "1: \n" + "cmp %[remain], #1 \n" + "blt 4f \n" + "3: \n" + "bif v0.16b, %[vzero].16b, %[mask1].16b \n" // pipei + "bif v1.16b, %[vzero].16b, %[mask2].16b \n" // pipei + + "bif v2.16b, %[vzero].16b, %[mask1].16b \n" // pipei + "bif v3.16b, %[vzero].16b, %[mask2].16b \n" // pipei + + "bif v4.16b, %[vzero].16b, %[mask1].16b \n" // pipei + "bif v5.16b, %[vzero].16b, %[mask2].16b \n" // pipei + + "ext v10.16b, v0.16b, %[vzero].16b, #4 \n" // v10 = {2,4,6,8} + + "bif v6.16b, %[vzero].16b, %[mask1].16b \n" // 
pipei + "bif v7.16b, %[vzero].16b, %[mask2].16b \n" // pipei + + // r0 + "fmul v11.4s, v0.4s, %[w0].s[0] \n" // {0,2,4,6} * w00 + "fmul v12.4s, v1.4s, %[w0].s[1] \n" // {1,3,5,7} * w01 + "fmla v16.4s, v10.4s, %[w0].s[2] \n" // {2,4,6,8} * w02 + + "ext v10.16b, v2.16b, %[vzero].16b, #4 \n" // v10 = {2,4,6,8} + "bif v8.16b, %[vzero].16b, %[mask1].16b \n" // pipei + "bif v9.16b, %[vzero].16b, %[mask2].16b \n" // pipei + + // r1 + "fmla v11.4s, v2.4s, %[w1].s[0] \n" // {0,2,4,6} * w00 + "fmla v12.4s, v3.4s, %[w1].s[1] \n" // {1,3,5,7} * w01 + "fmla v16.4s, v10.4s, %[w1].s[2] \n" // {2,4,6,8} * w02 + + "ext v10.16b, v4.16b, %[vzero].16b, #4 \n" // v10 = {2,4,6,8} + + // r2 + "fmul v13.4s, v4.4s, %[w0].s[0] \n" // {0,2,4,6} * w00 + "fmla v11.4s, v4.4s, %[w2].s[0] \n" // {0,2,4,6} * w00 + + "fmul v14.4s, v5.4s, %[w0].s[1] \n" // {1,3,5,7} * w01 + "fmla v12.4s, v5.4s, %[w2].s[1] \n" // {1,3,5,7} * w01 + + "fmla v17.4s, v10.4s, %[w0].s[2] \n" // {2,4,6,8} * w02 + "fmla v16.4s, v10.4s, %[w2].s[2] \n" // {2,4,6,8} * w02 + + "ext v10.16b, v6.16b, %[vzero].16b, #4 \n" // v10 = {2,4,6,8} + + // r3 + "fmla v13.4s, v6.4s, %[w1].s[0] \n" // {0,2,4,6} * w00 + "fmla v14.4s, v7.4s, %[w1].s[1] \n" // {1,3,5,7} * w01 + "fmla v17.4s, v10.4s, %[w1].s[2] \n" // {2,4,6,8} * w02 + + "ext v10.16b, v8.16b, %[vzero].16b, #4 \n" // v10 = {2,4,6,8} + "ld1 {v0.4s}, [%[outptr0]] \n" + + "fadd v16.4s, v16.4s, v11.4s \n" + "fadd v16.4s, v16.4s, v12.4s \n" + "ld1 {v1.4s}, [%[outptr1]] \n" + + // r4 + "fmla v13.4s, v8.4s, %[w2].s[0] \n" // {0,2,4,6} * w00 + "fmla v14.4s, v9.4s, %[w2].s[1] \n" // {1,3,5,7} * w01 + "fmla v17.4s, v10.4s, %[w2].s[2] \n" // {2,4,6,8} * w02 + + "bif v16.16b, v0.16b, %[wmask].16b \n" // pipei + + "fadd v17.4s, v17.4s, v13.4s \n" + + "st1 {v16.4s}, [%[outptr0]], #16 \n" + + "fadd v17.4s, v17.4s, v14.4s \n" + + "bif v17.16b, v1.16b, %[wmask].16b \n" // pipei + + "st1 {v17.4s}, [%[outptr1]], #16 \n" + "4: \n" + : [inptr0] "+r"(din0_ptr), + [inptr1] "+r"(din1_ptr), + [inptr2] "+r"(din2_ptr), + [inptr3] "+r"(din3_ptr), + [inptr4] "+r"(din4_ptr), + [outptr0] "+r"(doutr0_ptr), + [outptr1] "+r"(doutr1_ptr), + [cnt] "+r"(cnt) + : [vzero] "w"(vzero), + [w0] "w"(wr0), + [w1] "w"(wr1), + [w2] "w"(wr2), + [remain] "r"(cnt_remain), + [mask1] "w"(vmask_rp1), + [mask2] "w"(vmask_rp2), + [wmask] "w"(wmask), + [vbias] "w"(wbias) + : "cc", + "memory", + "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15", + "v16", + "v17", + "v18", + "v19", + "v20", + "v21"); + doutr0 = doutr0 + 2 * w_out; + } +#else + for (int i = 0; i < h_out; i++) { + din0_ptr = dr0; + din1_ptr = dr1; + din2_ptr = dr2; + + doutr0_ptr = doutr0; + + dr0 = dr2; + dr1 = dr0 + w_in; + dr2 = dr1 + w_in; + + //! process bottom pad + if (i + 2 > h_in) { + switch (i + 2 - h_in) { + case 2: + din1_ptr = zero_ptr; + case 1: + din2_ptr = zero_ptr; + default: + break; + } + } + int cnt = tile_w; + unsigned int* mask_ptr = dmask; + asm volatile( + // Load up 12 elements (3 vectors) from each of 8 sources. + "0: \n" + "vmov.u32 q9, #0 \n" + "vld2.32 {d20-d23}, [%[din0_ptr]]! @ load din r1\n" + "vld2.32 {d24-d27}, [%[din1_ptr]]! @ load din r1\n" + "vld2.32 {d28-d31}, [%[din2_ptr]]! 
@ load din r1\n" + "pld [%[din0_ptr]] @ preload data\n" + "pld [%[din1_ptr]] @ preload data\n" + "pld [%[din2_ptr]] @ preload data\n" + + "vld1.32 {d16}, [%[din0_ptr]] @ load din r0\n" // q2={8,10,12,14} + + "vdup.32 q3, %[bias] @ and \n" // q10 = + // vbias + // mid + "2: \n" + "vext.32 q6, q10, q8, #1 @ shift left 1 \n" // q6 = {2,4,6,8} + "vld1.32 {d16}, [%[din1_ptr]] @ load din r1\n" // q2={8,10,12,14} + + "vmul.f32 q4, q10, %e[wr0][0] @ mul weight 0, " + "out0\n" // q0 * w00 + "vmul.f32 q5, q11, %e[wr0][1] @ mul weight 0, " + "out0\n" // q1 * w01 + "vmla.f32 q3, q6, %f[wr0][0] @ mul weight 0, " + "out0\n" // q6 * w02 + + "vext.32 q7, q12, q8, #1 @ shift left 1 \n" // q6 = {2,4,6,8} + "vld1.32 {d16}, [%[din2_ptr]] @ load din r1\n" // q2={8,10,12,14} + + "vld2.32 {d20-d23}, [%[din0_ptr]]! @ load din r0\n" // v0={0,2,4,6} v1={1,3,5,7} + + "vmla.f32 q4, q12, %e[wr1][0] @ mul weight 1, " + "out0\n" // q0 * w00 + "vmla.f32 q5, q13, %e[wr1][1] @ mul weight 1, " + "out0\n" // q1 * w01 + "vmla.f32 q3, q7, %f[wr1][0] @ mul weight 1, " + "out0\n" // q6 * w02 + + "vext.32 q6, q14, q8, #1 @ shift left 1 \n" // q6 = {2,4,6,8} + + "vld2.32 {d24-d27}, [%[din1_ptr]]! @ load din r1\n" // v0={0,2,4,6} v1={1,3,5,7} + + "vmla.f32 q4, q14, %e[wr2][0] @ mul weight 2, " + "out0\n" // q0 * w00 + "vmla.f32 q5, q15, %e[wr2][1] @ mul weight 2, " + "out0\n" // q1 * w01 + "vmla.f32 q3, q6, %f[wr2][0] @ mul weight 2, " + "out0\n" // q6 * w02 + + "vld2.32 {d28-d31}, [%[din2_ptr]]! @ load din r2\n" // v4={0,2,4,6} v5={1,3,5,7} + + "vadd.f32 q3, q3, q4 @ add \n" + "vadd.f32 q3, q3, q5 @ add \n" + + "subs %[cnt], #1 \n" + + "vld1.32 {d16}, [%[din0_ptr]] @ load din r0\n" // q2={8,10,12,14} + + "vst1.32 {d6-d7}, [%[outptr]]! \n" + + "vdup.32 q3, %[bias] @ and \n" // q10 = + // vbias + "bne 2b \n" + + // right + "1: \n" + "cmp %[remain], #1 \n" + "blt 3f \n" + + "vld1.f32 {d12-d15}, [%[mask_ptr]]! @ load mask\n" + + "vbif q10, q9, q6 @ bit select, deal " + "with right pad\n" + "vbif q11, q9, q7 @ bit select, deal " + "with right pad\n" + "vbif q12, q9, q6 @ bit select, deal " + "with right pad\n" + "vbif q13, q9, q7 @ bit select, deal " + "with right pad\n" + "vbif q14, q9, q6 @ bit select, deal " + "with right pad\n" + "vbif q15, q9, q7 @ bit select, deal " + "with right pad\n" + + "vext.32 q6, q10, q9, #1 @ shift left 1 \n" // q6 = {2,4,6,8} + "vext.32 q7, q12, q9, #1 @ shift left 1 \n" // q6 = {2,4,6,8} + + "vmul.f32 q4, q10, %e[wr0][0] @ mul weight 0, " + "out0\n" // q0 * w00 + "vmul.f32 q5, q11, %e[wr0][1] @ mul weight 0, " + "out0\n" // q1 * w01 + "vmla.f32 q3, q6, %f[wr0][0] @ mul weight 0, " + "out0\n" // q6 * w02 + + "vext.32 q6, q14, q9, #1 @ shift left 1 \n" // q6 = {2,4,6,8} + "vld1.f32 {d20-d21}, [%[outptr]] @ load output\n" + + "vmla.f32 q4, q12, %e[wr1][0] @ mul weight 1, " + "out0\n" // q0 * w00 + "vmla.f32 q5, q13, %e[wr1][1] @ mul weight 1, " + "out0\n" // q1 * w01 + "vmla.f32 q3, q7, %f[wr1][0] @ mul weight 1, " + "out0\n" // q6 * w02 + + "vld1.f32 {d22-d23}, [%[mask_ptr]] @ load mask\n" + + "vmla.f32 q4, q14, %e[wr2][0] @ mul weight 2, " + "out0\n" // q0 * w00 + "vmla.f32 q5, q15, %e[wr2][1] @ mul weight 2, " + "out0\n" // q1 * w01 + "vmla.f32 q3, q6, %f[wr2][0] @ mul weight 2, " + "out0\n" // q6 * w02 + + "vadd.f32 q3, q3, q4 @ add \n" + "vadd.f32 q3, q3, q5 @ add \n" + + "vbif.f32 q3, q10, q11 @ write mask\n" + + "vst1.32 {d6-d7}, [%[outptr]]! 
\n" + "3: \n" + : [din0_ptr] "+r"(din0_ptr), + [din1_ptr] "+r"(din1_ptr), + [din2_ptr] "+r"(din2_ptr), + [outptr] "+r"(doutr0_ptr), + [cnt] "+r"(cnt), + [mask_ptr] "+r"(mask_ptr) + : [remain] "r"(cnt_remain), + [wr0] "w"(wr0), + [wr1] "w"(wr1), + [wr2] "w"(wr2), + [bias] "r"(bias_c) + : "cc", + "memory", + "q3", + "q4", + "q5", + "q6", + "q7", + "q8", + "q9", + "q10", + "q11", + "q12", + "q13", + "q14", + "q15"); + + doutr0 = doutr0 + w_out; + } +#endif + } + } +} + +// 4line +void conv_depthwise_3x3s1p0_bias_relu(float* dout, + const float* din, + const float* weights, + const float* bias, + bool flag_bias, + const int num, + const int ch_in, + const int h_in, + const int w_in, + const int h_out, + const int w_out, + ARMContext* ctx) { + //! pad is done implicit + const float zero[8] = {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f}; + //! for 4x6 convolution window + const unsigned int right_pad_idx[8] = {5, 4, 3, 2, 1, 0, 0, 0}; + + float* zero_ptr = ctx->workspace_data(); + memset(zero_ptr, 0, w_in * sizeof(float)); + float* write_ptr = zero_ptr + w_in; + + int size_in_channel = w_in * h_in; + int size_out_channel = w_out * h_out; + int w_stride = 9; + + int tile_w = w_out >> 2; + int remain = w_out % 4; + + unsigned int size_pad_right = (unsigned int)(6 + (tile_w << 2) - w_in); + const int remian_idx[4] = {0, 1, 2, 3}; + + uint32x4_t vmask_rp1 = + vcgeq_u32(vld1q_u32(right_pad_idx), vdupq_n_u32(size_pad_right)); + uint32x4_t vmask_rp2 = + vcgeq_u32(vld1q_u32(right_pad_idx + 4), vdupq_n_u32(size_pad_right)); + uint32x4_t vmask_result = + vcgtq_s32(vdupq_n_s32(remain), vld1q_s32(remian_idx)); + + unsigned int vmask[8]; + vst1q_u32(vmask, vmask_rp1); + vst1q_u32(vmask + 4, vmask_rp2); + + unsigned int rmask[4]; + vst1q_u32(rmask, vmask_result); + + float32x4_t vzero = vdupq_n_f32(0.f); + + for (int n = 0; n < num; ++n) { + const float* din_batch = din + n * ch_in * size_in_channel; + float* dout_batch = dout + n * ch_in * size_out_channel; +#pragma omp parallel for +#ifdef __aarch64__ + for (int c = 0; c < ch_in; c++) { + float* dout_ptr = dout_batch + c * size_out_channel; + + const float* din_ch_ptr = din_batch + c * size_in_channel; + + float bias_val = flag_bias ? bias[c] : 0.f; + float vbias[4] = {bias_val, bias_val, bias_val, bias_val}; + + const float* wei_ptr = weights + c * w_stride; + + float32x4_t wr0 = vld1q_f32(wei_ptr); + float32x4_t wr1 = vld1q_f32(wei_ptr + 3); + float32x4_t wr2 = vld1q_f32(wei_ptr + 6); + // wr0 = vsetq_lane_f32(0.f, wr0, 3); + // wr1 = vsetq_lane_f32(0.f, wr1, 3); + // wr2 = vsetq_lane_f32(0.f, wr2, 3); + + float* doutr0 = dout_ptr; + float* doutr1 = doutr0 + w_out; + float* doutr2 = doutr1 + w_out; + float* doutr3 = doutr2 + w_out; + + const float* dr0 = din_ch_ptr; + const float* dr1 = dr0 + w_in; + const float* dr2 = dr1 + w_in; + const float* dr3 = dr2 + w_in; + const float* dr4 = dr3 + w_in; + const float* dr5 = dr4 + w_in; + + const float* din_ptr0 = dr0; + const float* din_ptr1 = dr1; + const float* din_ptr2 = dr2; + const float* din_ptr3 = dr3; + const float* din_ptr4 = dr4; + const float* din_ptr5 = dr5; + + for (int i = 0; i < h_out; i += 4) { + //! process top pad pad_h = 1 + din_ptr0 = dr0; + din_ptr1 = dr1; + din_ptr2 = dr2; + din_ptr3 = dr3; + din_ptr4 = dr4; + din_ptr5 = dr5; + + doutr0 = dout_ptr; + doutr1 = doutr0 + w_out; + doutr2 = doutr1 + w_out; + doutr3 = doutr2 + w_out; + + dr0 = dr4; + dr1 = dr5; + dr2 = dr1 + w_in; + dr3 = dr2 + w_in; + dr4 = dr3 + w_in; + dr5 = dr4 + w_in; + + //! 
process bottom pad + if (i + 5 >= h_in) { + switch (i + 5 - h_in) { + case 5: + din_ptr1 = zero_ptr; + case 4: + din_ptr2 = zero_ptr; + case 3: + din_ptr3 = zero_ptr; + case 2: + din_ptr4 = zero_ptr; + case 1: + din_ptr5 = zero_ptr; + case 0: + din_ptr5 = zero_ptr; + default: + break; + } + } + //! process bottom remain + if (i + 4 > h_out) { + switch (i + 4 - h_out) { + case 3: + doutr1 = write_ptr; + case 2: + doutr2 = write_ptr; + case 1: + doutr3 = write_ptr; + default: + break; + } + } + + int cnt = tile_w; + asm volatile( + "PRFM PLDL1KEEP, [%[din_ptr0]] \n" + "PRFM PLDL1KEEP, [%[din_ptr1]] \n" + "PRFM PLDL1KEEP, [%[din_ptr2]] \n" + "PRFM PLDL1KEEP, [%[din_ptr3]] \n" + "PRFM PLDL1KEEP, [%[din_ptr4]] \n" + "PRFM PLDL1KEEP, [%[din_ptr5]] \n" + "movi v21.4s, #0x0\n" /* out0 = 0 */ + + "ld1 {v0.4s}, [%[din_ptr0]], #16 \n" /*vld1q_f32(din_ptr0)*/ + "ld1 {v2.4s}, [%[din_ptr1]], #16 \n" /*vld1q_f32(din_ptr0)*/ + "ld1 {v4.4s}, [%[din_ptr2]], #16 \n" /*vld1q_f32(din_ptr0)*/ + "ld1 {v6.4s}, [%[din_ptr3]], #16 \n" /*vld1q_f32(din_ptr0)*/ + + "ld1 {v1.4s}, [%[din_ptr0]] \n" /*vld1q_f32(din_ptr0)*/ + "ld1 {v3.4s}, [%[din_ptr1]] \n" /*vld1q_f32(din_ptr0)*/ + "ld1 {v5.4s}, [%[din_ptr2]] \n" /*vld1q_f32(din_ptr0)*/ + "ld1 {v7.4s}, [%[din_ptr3]] \n" /*vld1q_f32(din_ptr0)*/ + + "ld1 {v12.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ + "ld1 {v13.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ + "ld1 {v14.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ + "ld1 {v15.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ + + "ext v16.16b, v0.16b, v1.16b, #4 \n" /* v16 = 1234 */ + "ext v17.16b, v0.16b, v1.16b, #8 \n" /* v17 = 2345 */ + + // mid + "4: \n" + // r0 + "fmla v12.4s , v0.4s, %[w0].s[0]\n" /* outr00 += din0_0123 * + w0[0]*/ + + "ld1 {v8.4s}, [%[din_ptr4]], #16 \n" /*vld1q_f32(din_ptr0)*/ + "ld1 {v10.4s}, [%[din_ptr5]], #16 \n" /*vld1q_f32(din_ptr0)*/ + "ld1 {v0.4s}, [%[din_ptr0]], #16 \n" /*vld1q_f32(din_ptr0)*/ + + "fmla v12.4s , v16.4s, %[w0].s[1]\n" /* outr00 += din0_1234 * + w0[1]*/ + + "ld1 {v9.4s}, [%[din_ptr4]] \n" /*vld1q_f32(din_ptr0)*/ + "ld1 {v11.4s}, [%[din_ptr5]] \n" /*vld1q_f32(din_ptr0)*/ + "ld1 {v1.4s}, [%[din_ptr0]] \n" /*vld1q_f32(din_ptr0)*/ + + "fmla v12.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din0_2345 * + w0[2]*/ + + "ext v16.16b, v2.16b, v3.16b, #4 \n" /* v16 = 1234*/ + "ext v17.16b, v2.16b, v3.16b, #8 \n" /* v16 = 2345 */ + + // r1 + "fmla v13.4s , v2.4s, %[w0].s[0]\n" /* outr00 += din0_0123 * + w0[0]*/ + "fmla v12.4s , v2.4s, %[w1].s[0]\n" /* outr00 += din0_0123 * + w0[0]*/ + + "ld1 {v2.4s}, [%[din_ptr1]], #16 \n" /*vld1q_f32(din_ptr0)*/ + + "fmla v13.4s , v16.4s, %[w0].s[1]\n" /* outr00 += din0_1234 * + w0[1]*/ + "fmla v12.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 * + w0[1]*/ + + "ld1 {v3.4s}, [%[din_ptr1]] \n" /*vld1q_f32(din_ptr0)*/ + + "fmla v13.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din0_2345 * + w0[2]*/ + "fmla v12.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 * + w0[2]*/ + + "ext v16.16b, v4.16b, v5.16b, #4 \n" /* v16 = 1234*/ + "ext v17.16b, v4.16b, v5.16b, #8 \n" /* v16 = 2345 */ + + // r2 + "fmla v14.4s , v4.4s, %[w0].s[0]\n" /* outr00 += din0_0123 * + w0[0]*/ + "fmla v13.4s , v4.4s, %[w1].s[0]\n" /* outr00 += din0_0123 * + w0[0]*/ + "fmla v12.4s , v4.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * + w0[0]*/ + + "ld1 {v4.4s}, [%[din_ptr2]], #16 \n" /*vld1q_f32(din_ptr0)*/ + + "fmla v14.4s , v16.4s, %[w0].s[1]\n" /* outr00 += din0_1234 * + w0[1]*/ + "fmla v13.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 * + w0[1]*/ + "fmla v12.4s , v16.4s, 
%[w2].s[1]\n" /* outr00 += din0_1234 * + w0[1]*/ + + "ld1 {v5.4s}, [%[din_ptr2]] \n" /*vld1q_f32(din_ptr0)*/ + + "fmla v14.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din0_2345 * + w0[2]*/ + "fmla v13.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 * + w0[2]*/ + "fmla v12.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * + w0[2]*/ + + "ext v16.16b, v6.16b, v7.16b, #4 \n" /* v16 = 1234*/ + "ext v17.16b, v6.16b, v7.16b, #8 \n" /* v16 = 2345 */ + + // r3 + "fmla v15.4s , v6.4s, %[w0].s[0]\n" /* outr00 += din0_0123 * + w0[0]*/ + "fmla v14.4s , v6.4s, %[w1].s[0]\n" /* outr00 += din0_0123 * + w0[0]*/ + "fmla v13.4s , v6.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * + w0[0]*/ + + "fmax v12.4s, v12.4s, %[vzero].4s \n" /* relu */ + + "ld1 {v6.4s}, [%[din_ptr3]], #16 \n" /*vld1q_f32(din_ptr0)*/ + + "fmla v15.4s , v16.4s, %[w0].s[1]\n" /* outr00 += din0_1234 * + w0[1]*/ + "fmla v14.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 * + w0[1]*/ + "fmla v13.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * + w0[1]*/ + + "st1 {v12.4s}, [%[doutr0]], #16 \n" + "ld1 {v7.4s}, [%[din_ptr3]] \n" /*vld1q_f32(din_ptr0)*/ + + "fmla v15.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din0_2345 * + w0[2]*/ + "fmla v14.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 * + w0[2]*/ + "fmla v13.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * + w0[2]*/ + + "ext v16.16b, v8.16b, v9.16b, #4 \n" /* v16 = 1234*/ + "ext v17.16b, v8.16b, v9.16b, #8 \n" /* v16 = 2345 */ + "ld1 {v12.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ + + // r4 + "fmla v15.4s , v8.4s, %[w1].s[0]\n" /* outr00 += din0_0123 * + w0[0]*/ + "fmla v14.4s , v8.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * + w0[0]*/ + + "fmax v13.4s, v13.4s, %[vzero].4s \n" /* relu */ + + "fmla v15.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 * + w0[1]*/ + "fmla v14.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * + w0[1]*/ + + "st1 {v13.4s}, [%[doutr1]], #16 \n" + + "fmla v15.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 * + w0[2]*/ + "fmla v14.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * + w0[2]*/ + + "ext v16.16b, v10.16b, v11.16b, #4 \n" /* v16 = 1234*/ + "ext v17.16b, v10.16b, v11.16b, #8 \n" /* v16 = 2345 */ + "ld1 {v13.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ + // r5 + "fmla v15.4s , v10.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * + w0[0]*/ + + "fmax v14.4s, v14.4s, %[vzero].4s \n" /* relu */ + + "fmla v15.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * + w0[1]*/ + + "st1 {v14.4s}, [%[doutr2]], #16 \n" + + "fmla v15.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * + w0[2]*/ + + "ext v16.16b, v0.16b, v1.16b, #4 \n" /* v16 = 1234*/ + "ext v17.16b, v0.16b, v1.16b, #8 \n" /* v16 = 2345 */ + "ld1 {v14.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ + "fmax v15.4s, v15.4s, %[vzero].4s \n" /* relu */ + + "subs %[cnt], %[cnt], #1 \n" + + "st1 {v15.4s}, [%[doutr3]], #16 \n" + "ld1 {v15.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ + + "bne 4b \n" + + // right + "5: \n" + "cmp %[remain], #1 \n" + "blt 0f \n" + "ld1 {v18.4s, v19.4s}, [%[vmask]] \n" + "ld1 {v22.4s}, [%[doutr0]] \n" + "ld1 {v23.4s}, [%[doutr1]] \n" + "ld1 {v24.4s}, [%[doutr2]] \n" + "ld1 {v25.4s}, [%[doutr3]] \n" + + "bif v0.16b, %[vzero].16b, v18.16b \n" + "bif v1.16b, %[vzero].16b, v19.16b \n" + "bif v2.16b, %[vzero].16b, v18.16b \n" + "bif v3.16b, %[vzero].16b, v19.16b \n" + + "ext v16.16b, v0.16b, v1.16b, #4 \n" /* v16 = 1234*/ + "ext v17.16b, v0.16b, v1.16b, #8 \n" /* v16 = 2345 */ + "ld1 {v8.4s}, [%[din_ptr4]], #16 \n" /*vld1q_f32(din_ptr0)*/ + "ld1 {v10.4s}, [%[din_ptr5]], #16 
\n" /*vld1q_f32(din_ptr0)*/ + "ld1 {v9.4s}, [%[din_ptr4]] \n" /*vld1q_f32(din_ptr0)*/ + "ld1 {v11.4s}, [%[din_ptr5]] \n" /*vld1q_f32(din_ptr0)*/ + + // r0 + "fmla v12.4s, v0.4s, %[w0].s[0]\n" /* outr00 += din0_0123 * + w0[0]*/ + + "bif v4.16b, %[vzero].16b, v18.16b \n" + "bif v5.16b, %[vzero].16b, v19.16b \n" + "bif v6.16b, %[vzero].16b, v18.16b \n" + "bif v7.16b, %[vzero].16b, v19.16b \n" + + "fmla v12.4s, v16.4s, %[w0].s[1]\n" /* outr00 += din0_1234 * + w0[1]*/ + + "bif v8.16b, %[vzero].16b, v18.16b \n" + "bif v9.16b, %[vzero].16b, v19.16b \n" + "bif v10.16b, %[vzero].16b, v18.16b \n" + "bif v11.16b, %[vzero].16b, v19.16b \n" + + "fmla v12.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din0_2345 * + w0[2]*/ + + "ext v16.16b, v2.16b, v3.16b, #4 \n" /* v16 = 1234*/ + "ext v17.16b, v2.16b, v3.16b, #8 \n" /* v16 = 2345 */ + "ld1 {v18.4s}, [%[rmask]] \n" + + // r1 + "fmla v13.4s , v2.4s, %[w0].s[0]\n" /* outr00 += din0_0123 * + w0[0]*/ + "fmla v12.4s , v2.4s, %[w1].s[0]\n" /* outr00 += din0_0123 * + w0[0]*/ + + "fmla v13.4s , v16.4s, %[w0].s[1]\n" /* outr00 += din0_1234 * + w0[1]*/ + "fmla v12.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 * + w0[1]*/ + + "fmla v13.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din0_2345 * + w0[2]*/ + "fmla v12.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 * + w0[2]*/ + + "ext v16.16b, v4.16b, v5.16b, #4 \n" /* v16 = 1234*/ + "ext v17.16b, v4.16b, v5.16b, #8 \n" /* v16 = 2345 */ + + // r2 + "fmla v14.4s , v4.4s, %[w0].s[0]\n" /* outr00 += din0_0123 * + w0[0]*/ + "fmla v13.4s , v4.4s, %[w1].s[0]\n" /* outr00 += din0_0123 * + w0[0]*/ + "fmla v12.4s , v4.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * + w0[0]*/ + + "fmla v14.4s , v16.4s, %[w0].s[1]\n" /* outr00 += din0_1234 * + w0[1]*/ + "fmla v13.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 * + w0[1]*/ + "fmla v12.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * + w0[1]*/ + + "fmla v14.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din0_2345 * + w0[2]*/ + "fmla v13.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 * + w0[2]*/ + "fmla v12.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * + w0[2]*/ + + "ext v16.16b, v6.16b, v7.16b, #4 \n" /* v16 = 1234*/ + "ext v17.16b, v6.16b, v7.16b, #8 \n" /* v16 = 2345 */ + + // r3 + "fmla v15.4s , v6.4s, %[w0].s[0]\n" /* outr00 += din0_0123 * + w0[0]*/ + "fmla v14.4s , v6.4s, %[w1].s[0]\n" /* outr00 += din0_0123 * + w0[0]*/ + "fmla v13.4s , v6.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * + w0[0]*/ + + "fmax v12.4s, v12.4s, %[vzero].4s \n" /* relu */ + + "fmla v15.4s , v16.4s, %[w0].s[1]\n" /* outr00 += din0_1234 * + w0[1]*/ + "fmla v14.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 * + w0[1]*/ + "fmla v13.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * + w0[1]*/ + + "bif v12.16b, v22.16b, v18.16b \n" + + "fmla v15.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din0_2345 * + w0[2]*/ + "fmla v14.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 * + w0[2]*/ + "fmla v13.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * + w0[2]*/ + + "ext v16.16b, v8.16b, v9.16b, #4 \n" /* v16 = 1234*/ + "ext v17.16b, v8.16b, v9.16b, #8 \n" /* v16 = 2345 */ + "st1 {v12.4s}, [%[doutr0]], #16 \n" + + // r3 + "fmla v15.4s , v8.4s, %[w1].s[0]\n" /* outr00 += din0_0123 * + w0[0]*/ + "fmla v14.4s , v8.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * + w0[0]*/ + + "fmax v13.4s, v13.4s, %[vzero].4s \n" /* relu */ + + "fmla v15.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 * + w0[1]*/ + "fmla v14.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * + w0[1]*/ + + "bif v13.16b, v23.16b, v18.16b \n" + + "fmla v15.4s , 
v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 * + w0[2]*/ + "fmla v14.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * + w0[2]*/ + "st1 {v13.4s}, [%[doutr1]], #16 \n" + + "ext v16.16b, v10.16b, v11.16b, #4 \n" /* v16 = 1234*/ + "ext v17.16b, v10.16b, v11.16b, #8 \n" /* v16 = 2345 */ + + // r3 + "fmla v15.4s , v10.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * + w0[0]*/ + + "fmax v14.4s, v14.4s, %[vzero].4s \n" /* relu */ + + "fmla v15.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * + w0[1]*/ + + "bif v14.16b, v24.16b, v18.16b \n" + + "fmla v15.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * + w0[2]*/ + + "st1 {v14.4s}, [%[doutr2]], #16 \n" + + "fmax v15.4s, v15.4s, %[vzero].4s \n" /* relu */ + + "bif v15.16b, v25.16b, v18.16b \n" + + "st1 {v15.4s}, [%[doutr3]], #16 \n" + // end + "0: \n" + : [cnt] "+r"(cnt), + [din_ptr0] "+r"(din_ptr0), + [din_ptr1] "+r"(din_ptr1), + [din_ptr2] "+r"(din_ptr2), + [din_ptr3] "+r"(din_ptr3), + [din_ptr4] "+r"(din_ptr4), + [din_ptr5] "+r"(din_ptr5), + [doutr0] "+r"(doutr0), + [doutr1] "+r"(doutr1), + [doutr2] "+r"(doutr2), + [doutr3] "+r"(doutr3) + : [w0] "w"(wr0), + [w1] "w"(wr1), + [w2] "w"(wr2), + [bias_val] "r"(vbias), + [vmask] "r"(vmask), + [rmask] "r"(rmask), + [vzero] "w"(vzero), + [remain] "r"(remain) + : "cc", + "memory", + "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15", + "v16", + "v17", + "v18", + "v19", + "v20", + "v21", + "v22", + "v23", + "v24", + "v25"); + dout_ptr = dout_ptr + 4 * w_out; + } + } +#else + for (int i = 0; i < ch_in; ++i) { + const float* din_channel = din_batch + i * size_in_channel; + + const float* weight_ptr = weights + i * 9; + float32x4_t wr0 = vld1q_f32(weight_ptr); + float32x4_t wr1 = vld1q_f32(weight_ptr + 3); + float32x4_t wr2 = vld1q_f32(weight_ptr + 6); + float bias_val = flag_bias ? bias[i] : 0.f; + + float* dout_channel = dout_batch + i * size_out_channel; + + const float* dr0 = din_channel; + const float* dr1 = dr0 + w_in; + const float* dr2 = dr1 + w_in; + const float* dr3 = dr2 + w_in; + + const float* din0_ptr = nullptr; + const float* din1_ptr = nullptr; + const float* din2_ptr = nullptr; + const float* din3_ptr = nullptr; + + float* doutr0 = nullptr; + float* doutr1 = nullptr; + + float* ptr_zero = const_cast<float*>(zero); + + for (int i = 0; i < h_out; i += 2) { + //! set input row pointers, no top pad for p0 + din0_ptr = dr0; + din1_ptr = dr1; + din2_ptr = dr2; + din3_ptr = dr3; + + doutr0 = dout_channel; + doutr1 = dout_channel + w_out; + + dr0 = dr2; + dr1 = dr3; + dr2 = dr1 + w_in; + dr3 = dr2 + w_in; + //! process bottom pad + if (i + 3 >= h_in) { + switch (i + 3 - h_in) { + case 3: + din1_ptr = zero_ptr; + case 2: + din2_ptr = zero_ptr; + case 1: + din3_ptr = zero_ptr; + case 0: + din3_ptr = zero_ptr; + default: + break; + } + } + //! process bottom remain + if (i + 2 > h_out) { + doutr1 = write_ptr; + } + int cnt = tile_w; + unsigned int* rmask_ptr = rmask; + unsigned int* vmask_ptr = vmask; + asm volatile( + "pld [%[din0_ptr]] @ preload data\n" + "pld [%[din1_ptr]] @ preload data\n" + "pld [%[din2_ptr]] @ preload data\n" + "pld [%[din3_ptr]] @ preload data\n" + + "vld1.32 {d16-d17}, [%[din0_ptr]]! @ load din r0\n" + "vld1.32 {d20-d21}, [%[din1_ptr]]! @ load din r1\n" + "vld1.32 {d24-d25}, [%[din2_ptr]]! @ load din r2\n" + "vld1.32 {d28-d29}, [%[din3_ptr]]! 
@ load din r3\n" + "vld1.32 {d18}, [%[din0_ptr]] @ load din r0\n" + "vld1.32 {d22}, [%[din1_ptr]] @ load din r0\n" + "vld1.32 {d26}, [%[din2_ptr]] @ load din r0\n" + "vld1.32 {d30}, [%[din3_ptr]] @ load din r0\n" + + "vdup.32 q4, %[bias_val] @ and \n" // q4 + // = + // vbias + "vdup.32 q5, %[bias_val] @ and \n" // q5 + // = + // vbias + + "vext.32 q6, q8, q9, #1 @ 1234\n" + "vext.32 q7, q8, q9, #2 @ 2345\n" + + // mid + "1: @ right pad entry\n" + // r0 + "vmla.f32 q4, q8, %e[wr0][0] @ q4 += 0123 * wr0[0]\n" + + "pld [%[din0_ptr]] @ preload data\n" + "pld [%[din1_ptr]] @ preload data\n" + "pld [%[din2_ptr]] @ preload data\n" + "pld [%[din3_ptr]] @ preload data\n" + + "vmla.f32 q4, q6, %e[wr0][1] @ q4 += 1234 * wr0[1]\n" + + "vld1.32 {d16-d17}, [%[din0_ptr]]! @ load din r0\n" + + "vmla.f32 q4, q7, %f[wr0][0] @ q4 += 2345 * wr0[2]\n" + + "vld1.32 {d18}, [%[din0_ptr]] @ load din r0\n" + + "vext.32 q6, q10, q11, #1 @ 1234\n" + "vext.32 q7, q10, q11, #2 @ 2345\n" + + // r1 + "vmla.f32 q5, q10, %e[wr0][0] @ q4 += 1234 * wr0[0]\n" + "vmla.f32 q4, q10, %e[wr1][0] @ q4 += 1234 * wr0[1]\n" + + "vld1.32 {d20-d21}, [%[din1_ptr]]! @ load din r0\n" + + "vmla.f32 q5, q6, %e[wr0][1] @ q4 += 1234 * wr0[1]\n" + "vmla.f32 q4, q6, %e[wr1][1] @ q4 += 1234 * wr0[1]\n" + + "vld1.32 {d22}, [%[din1_ptr]] @ load din r0\n" + + "vmla.f32 q5, q7, %f[wr0][0] @ q4 += 1234 * wr0[1]\n" + "vmla.f32 q4, q7, %f[wr1][0] @ q4 += 1234 * wr0[1]\n" + + "vext.32 q6, q12, q13, #1 @ 1234\n" + "vext.32 q7, q12, q13, #2 @ 2345\n" + + // r2 + "vmla.f32 q5, q12, %e[wr1][0] @ q4 += 1234 * wr0[0]\n" + "vmla.f32 q4, q12, %e[wr2][0] @ q4 += 1234 * wr0[1]\n" + + "vld1.32 {d24-d25}, [%[din2_ptr]]! @ load din r0\n" + + "vmla.f32 q5, q6, %e[wr1][1] @ q4 += 1234 * wr0[1]\n" + "vmla.f32 q4, q6, %e[wr2][1] @ q4 += 1234 * wr0[1]\n" + + "vld1.32 {d26}, [%[din2_ptr]] @ load din r0\n" + + "vmla.f32 q5, q7, %f[wr1][0] @ q4 += 1234 * wr0[1]\n" + "vmla.f32 q4, q7, %f[wr2][0] @ q4 += 1234 * wr0[1]\n" + + "vext.32 q6, q14, q15, #1 @ 1234\n" + "vext.32 q7, q14, q15, #2 @ 2345\n" + + // r3 + "vmla.f32 q5, q14, %e[wr2][0] @ q4 += 0123 * wr0[0]\n" + + "vld1.32 {d28-d29}, [%[din3_ptr]]! @ load din r0\n" + "vmax.f32 q4, q4, %q[vzero] @ relu \n" + + "vmla.f32 q5, q6, %e[wr2][1] @ q4 += 1234 * wr0[1]\n" + + "vld1.32 {d30}, [%[din3_ptr]] @ load din r0\n" + "vst1.32 {d8-d9}, [%[dout_ptr1]]! @ store result, add pointer\n" + + "vmla.f32 q5, q7, %f[wr2][0] @ q4 += 2345 * wr0[2]\n" + + "vext.32 q6, q8, q9, #1 @ 1234\n" + "vext.32 q7, q8, q9, #2 @ 2345\n" + "vmax.f32 q5, q5, %q[vzero] @ relu \n" + + "vdup.32 q4, %[bias_val] @ and \n" // q4 + // = + // vbias + + "vst1.32 {d10-d11}, [%[dout_ptr2]]! @ store result, add " + "pointer\n" + + "subs %[cnt], #1 @ loop count minus 1\n" + + "vdup.32 q5, %[bias_val] @ and \n" // q4 + // = + // vbias + + "bne 1b @ jump to main loop start " + "point\n" + + // right + "3: @ right pad entry\n" + "cmp %[remain], #1 @ check whether has " + "mid cols\n" + "blt 0f @ jump to main loop start " + "point\n" + "vld1.32 {d19}, [%[vmask]]! @ load din r0\n" + "vld1.32 {d23}, [%[vmask]]! @ load din r0\n" + + "vld1.32 {d27}, [%[vmask]]! @ load din r0\n" + "vld1.32 {d31}, [%[vmask]]! 
@ load din r0\n" + + "vbif d16, %e[vzero], d19 @ bit select, deal with " + "right pad\n" + "vbif d17, %e[vzero], d23 @ bit select, deal with " + "right pad\n" + "vbif d18, %e[vzero], d27 @ bit select, deal with " + "right pad\n" + + "vbif d20, %e[vzero], d19 @ bit select, deal with " + "right pad\n" + "vbif d21, %e[vzero], d23 @ bit select, deal with " + "right pad\n" + "vbif d22, %e[vzero], d27 @ bit select, deal with " + "right pad\n" + + "vext.32 q6, q8, q9, #1 @ 1234\n" + "vext.32 q7, q8, q9, #2 @ 2345\n" + + // r0 + "vmla.f32 q4, q8, %e[wr0][0] @ q4 += 0123 * wr0[0]\n" + + "vbif d24, %e[vzero], d19 @ bit select, deal with " + "right pad\n" + "vbif d25, %e[vzero], d23 @ bit select, deal with " + "right pad\n" + "vbif d26, %e[vzero], d27 @ bit select, deal with " + "right pad\n" + + "vmla.f32 q4, q6, %e[wr0][1] @ q4 += 1234 * wr0[1]\n" + + "vbif d28, %e[vzero], d19 @ bit select, deal with " + "right pad\n" + "vbif d29, %e[vzero], d23 @ bit select, deal with " + "right pad\n" + "vbif d30, %e[vzero], d27 @ bit select, deal with " + "right pad\n" + + "vmla.f32 q4, q7, %f[wr0][0] @ q4 += 2345 * wr0[2]\n" + + "vext.32 q6, q10, q11, #1 @ 1234\n" + "vext.32 q7, q10, q11, #2 @ 2345\n" + + // r1 + "vmla.f32 q5, q10, %e[wr0][0] @ q4 += 1234 * wr0[0]\n" + "vmla.f32 q4, q10, %e[wr1][0] @ q4 += 1234 * wr0[1]\n" + + "vld1.32 {d19}, [%[rmask]]! @ load din r0\n" + "vld1.32 {d23}, [%[rmask]]! @ load din r0\n" + + "vmla.f32 q5, q6, %e[wr0][1] @ q4 += 1234 * wr0[1]\n" + "vmla.f32 q4, q6, %e[wr1][1] @ q4 += 1234 * wr0[1]\n" + + "vld1.32 {d16-d17}, [%[dout_ptr1]] @ load din r0\n" + "vld1.32 {d20-d21}, [%[dout_ptr2]] @ load din r0\n" + + "vmla.f32 q5, q7, %f[wr0][0] @ q4 += 1234 * wr0[1]\n" + "vmla.f32 q4, q7, %f[wr1][0] @ q4 += 1234 * wr0[1]\n" + + "vext.32 q6, q12, q13, #1 @ 1234\n" + "vext.32 q7, q12, q13, #2 @ 2345\n" + + // r2 + "vmla.f32 q5, q12, %e[wr1][0] @ q4 += 1234 * wr0[0]\n" + "vmla.f32 q4, q12, %e[wr2][0] @ q4 += 1234 * wr0[1]\n" + + "vmla.f32 q5, q6, %e[wr1][1] @ q4 += 1234 * wr0[1]\n" + "vmla.f32 q4, q6, %e[wr2][1] @ q4 += 1234 * wr0[1]\n" + + "vmla.f32 q5, q7, %f[wr1][0] @ q4 += 1234 * wr0[1]\n" + "vmla.f32 q4, q7, %f[wr2][0] @ q4 += 1234 * wr0[1]\n" + + "vext.32 q6, q14, q15, #1 @ 1234\n" + "vext.32 q7, q14, q15, #2 @ 2345\n" + + // r3 + "vmla.f32 q5, q14, %e[wr2][0] @ q4 += 0123 * wr0[0]\n" + + "vmax.f32 q4, q4, %q[vzero] @ relu \n" + + "vmla.f32 q5, q6, %e[wr2][1] @ q4 += 1234 * wr0[1]\n" + + "vbif d8, d16, d19 @ bit select, deal with right pad\n" + "vbif d9, d17, d23 @ bit select, deal with right pad\n" + + "vmla.f32 q5, q7, %f[wr2][0] @ q4 += 2345 * wr0[2]\n" + + "vst1.32 {d8-d9}, [%[dout_ptr1]]! @ store result, add pointer\n" + + "vmax.f32 q5, q5, %q[vzero] @ relu \n" + + "vbif d10, d20, d19 @ bit select, deal with right " + "pad\n" + "vbif d11, d21, d23 @ bit select, deal with right " + "pad\n" + + "vst1.32 {d10-d11}, [%[dout_ptr2]]! @ store result, add " + "pointer\n" + "0: \n" + + : [dout_ptr1] "+r"(doutr0), + [dout_ptr2] "+r"(doutr1), + [din0_ptr] "+r"(din0_ptr), + [din1_ptr] "+r"(din1_ptr), + [din2_ptr] "+r"(din2_ptr), + [din3_ptr] "+r"(din3_ptr), + [cnt] "+r"(cnt), + [rmask] "+r"(rmask_ptr), + [vmask] "+r"(vmask_ptr) + : [wr0] "w"(wr0), + [wr1] "w"(wr1), + [wr2] "w"(wr2), + [bias_val] "r"(bias_val), + [vzero] "w"(vzero), + [remain] "r"(remain) + : "cc", + "memory", + "q4", + "q5", + "q6", + "q7", + "q8", + "q9", + "q10", + "q11", + "q12", + "q13", + "q14", + "q15"); + dout_channel += 2 * w_out; + } //! 
end of processing mid rows + } +#endif + } +} +/** + * \brief depthwise convolution kernel 3x3, stride 2, with relu + */ +// w_in > 7 +void conv_depthwise_3x3s2p0_bias_relu(float* dout, + const float* din, + const float* weights, + const float* bias, + bool flag_bias, + const int num, + const int ch_in, + const int h_in, + const int w_in, + const int h_out, + const int w_out, + ARMContext* ctx) { + int right_pad_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7}; + int out_pad_idx[4] = {0, 1, 2, 3}; + + int tile_w = w_out >> 2; + int cnt_remain = w_out % 4; + + unsigned int size_right_remain = (unsigned int)(w_in - (tile_w << 3)); + + uint32x4_t vmask_rp1 = vcgtq_s32(vdupq_n_s32(size_right_remain), + vld1q_s32(right_pad_idx)); // 0 2 4 6 + uint32x4_t vmask_rp2 = vcgtq_s32(vdupq_n_s32(size_right_remain), + vld1q_s32(right_pad_idx + 4)); // 1 3 5 7 + uint32x4_t wmask = + vcgtq_s32(vdupq_n_s32(cnt_remain), vld1q_s32(out_pad_idx)); // 0 1 2 3 + int size_in_channel = w_in * h_in; + int size_out_channel = w_out * h_out; + + float* zero_ptr = ctx->workspace_data<float>(); + memset(zero_ptr, 0, w_in * sizeof(float)); + float* write_ptr = zero_ptr + w_in; + + unsigned int dmask[12]; + + vst1q_u32(dmask, vmask_rp1); + vst1q_u32(dmask + 4, vmask_rp2); + vst1q_u32(dmask + 8, wmask); + + for (int n = 0; n < num; ++n) { + const float* din_batch = din + n * ch_in * size_in_channel; + float* dout_batch = dout + n * ch_in * size_out_channel; +#pragma omp parallel for + for (int i = 0; i < ch_in; ++i) { + const float* din_channel = din_batch + i * size_in_channel; + float* dout_channel = dout_batch + i * size_out_channel; + + const float* weight_ptr = weights + i * 9; + float32x4_t wr0 = vld1q_f32(weight_ptr); + float32x4_t wr1 = vld1q_f32(weight_ptr + 3); + float32x4_t wr2 = vld1q_f32(weight_ptr + 6); + + float32x4_t vzero = vdupq_n_f32(0.f); + + float32x4_t wbias; + float bias_c = 0.f; + if (flag_bias) { + wbias = vdupq_n_f32(bias[i]); + bias_c = bias[i]; + } else { + wbias = vdupq_n_f32(0.f); + } + + const float* dr0 = din_channel; + const float* dr1 = dr0 + w_in; + const float* dr2 = dr1 + w_in; + const float* dr3 = dr2 + w_in; + const float* dr4 = dr3 + w_in; + + const float* din0_ptr = dr0; + const float* din1_ptr = dr1; + const float* din2_ptr = dr2; + const float* din3_ptr = dr3; + const float* din4_ptr = dr4; + + float* doutr0 = dout_channel; + float* doutr0_ptr = nullptr; + float* doutr1_ptr = nullptr; + +#ifdef __aarch64__ + for (int i = 0; i < h_out; i += 2) { + din0_ptr = dr0; + din1_ptr = dr1; + din2_ptr = dr2; + din3_ptr = dr3; + din4_ptr = dr4; + + doutr0_ptr = doutr0; + doutr1_ptr = doutr0 + w_out; + + dr0 = dr4; + dr1 = dr0 + w_in; + dr2 = dr1 + w_in; + dr3 = dr2 + w_in; + dr4 = dr3 + w_in; + + //! process bottom pad + if (i + 4 >= h_in) { + switch (i + 4 - h_in) { + case 4: + din1_ptr = zero_ptr; + case 3: + din2_ptr = zero_ptr; + case 2: + din3_ptr = zero_ptr; + case 1: + din4_ptr = zero_ptr; + case 0: + din4_ptr = zero_ptr; + default: + break; + } + } + //! process output pad + if (i + 2 > h_out) { + doutr1_ptr = write_ptr; + } + int cnt = tile_w; + asm volatile( + // top + // Load up 12 elements (3 vectors) from each of 8 sources.
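+ /* Editor's note: a sketch of what the stride-2 NEON loop below computes.
+    Each `ld2` deinterleaves an input row into even columns {0,2,4,6} and odd
+    columns {1,3,5,7}; output element x consumes inputs {2x, 2x+1, 2x+2}, so
+    the three taps of one kernel row map onto the even vector, the odd vector,
+    and the even vector shifted by one lane (the `ext` against the next load).
+    A minimal AArch64 intrinsics equivalent of one kernel row's contribution;
+    illustrative only, not part of the original patch (`din` and `acc` are
+    hypothetical names, `wr0` and `bias_c` come from the enclosing scope):
+
+      float32x4x2_t d = vld2q_f32(din);               // d.val[0]={0,2,4,6}, d.val[1]={1,3,5,7}
+      float32x4_t nxt = vld1q_f32(din + 8);           // next block; lane 0 supplies column 8
+      float32x4_t ev2 = vextq_f32(d.val[0], nxt, 1);  // {2,4,6,8}
+      float32x4_t acc = vdupq_n_f32(bias_c);          // start from the bias
+      acc = vfmaq_laneq_f32(acc, d.val[0], wr0, 0);   // w00 * {0,2,4,6}
+      acc = vfmaq_laneq_f32(acc, d.val[1], wr0, 1);   // w01 * {1,3,5,7}
+      acc = vfmaq_laneq_f32(acc, ev2, wr0, 2);        // w02 * {2,4,6,8}
+      acc = vmaxq_f32(acc, vdupq_n_f32(0.f));         // fused relu
+ */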
+ "0: \n" + "prfm pldl1keep, [%[inptr0]] \n" + "prfm pldl1keep, [%[inptr1]] \n" + "prfm pldl1keep, [%[inptr2]] \n" + "prfm pldl1keep, [%[inptr3]] \n" + "prfm pldl1keep, [%[inptr4]] \n" + "ld2 {v0.4s, v1.4s}, [%[inptr0]], #32 \n" // v0={0,2,4,6} + // v1={1,3,5,7} + "ld2 {v2.4s, v3.4s}, [%[inptr1]], #32 \n" + "ld2 {v4.4s, v5.4s}, [%[inptr2]], #32 \n" + "ld2 {v6.4s, v7.4s}, [%[inptr3]], #32 \n" + "ld2 {v8.4s, v9.4s}, [%[inptr4]], #32 \n" + + "and v16.16b, %[vbias].16b, %[vbias].16b \n" // v10 = vbias + "and v17.16b, %[vbias].16b, %[vbias].16b \n" // v16 = vbias + + "ld1 {v15.4s}, [%[inptr0]] \n" + "ld1 {v18.4s}, [%[inptr1]] \n" + "ld1 {v19.4s}, [%[inptr2]] \n" + "ld1 {v20.4s}, [%[inptr3]] \n" + "ld1 {v21.4s}, [%[inptr4]] \n" + + "ext v10.16b, v0.16b, v15.16b, #4 \n" // v10 = {2,4,6,8} + // mid + "2: \n" + // r0 + "fmul v11.4s, v0.4s, %[w0].s[0] \n" // {0,2,4,6} * w00 + "fmul v12.4s, v1.4s, %[w0].s[1] \n" // {1,3,5,7} * w01 + "fmla v16.4s, v10.4s, %[w0].s[2] \n" // {2,4,6,8} * w02 + + "ext v10.16b, v2.16b, v18.16b, #4 \n" // v10 = {2,4,6,8} + "ld2 {v0.4s, v1.4s}, [%[inptr0]], #32 \n" // v0={0,2,4,6} + // v1={1,3,5,7} + + // r1 + "fmla v11.4s, v2.4s, %[w1].s[0] \n" // {0,2,4,6} * w00 + "fmla v12.4s, v3.4s, %[w1].s[1] \n" // {1,3,5,7} * w01 + "fmla v16.4s, v10.4s, %[w1].s[2] \n" // {2,4,6,8} * w02 + + "ext v10.16b, v4.16b, v19.16b, #4 \n" // v10 = {2,4,6,8} + + "ld2 {v2.4s, v3.4s}, [%[inptr1]], #32 \n" + + // r2 + "fmul v13.4s, v4.4s, %[w0].s[0] \n" // {0,2,4,6} * w00 + "fmla v11.4s, v4.4s, %[w2].s[0] \n" // {0,2,4,6} * w00 + + "fmul v14.4s, v5.4s, %[w0].s[1] \n" // {1,3,5,7} * w01 + "fmla v12.4s, v5.4s, %[w2].s[1] \n" // {1,3,5,7} * w01 + + "fmla v17.4s, v10.4s, %[w0].s[2] \n" // {2,4,6,8} * w02 + "fmla v16.4s, v10.4s, %[w2].s[2] \n" // {2,4,6,8} * w02 + + "ext v10.16b, v6.16b, v20.16b, #4 \n" // v10 = {2,4,6,8} + + "ld2 {v4.4s, v5.4s}, [%[inptr2]], #32 \n" + + // r3 + "fmla v13.4s, v6.4s, %[w1].s[0] \n" // {0,2,4,6} * w00 + "fmla v14.4s, v7.4s, %[w1].s[1] \n" // {1,3,5,7} * w01 + "fmla v17.4s, v10.4s, %[w1].s[2] \n" // {2,4,6,8} * w02 + + "ext v10.16b, v8.16b, v21.16b, #4 \n" // v10 = {2,4,6,8} + + "ld2 {v6.4s, v7.4s}, [%[inptr3]], #32 \n" + + "fadd v16.4s, v16.4s, v11.4s \n" + "fadd v16.4s, v16.4s, v12.4s \n" + + // r4 + "fmla v13.4s, v8.4s, %[w2].s[0] \n" // {0,2,4,6} * w00 + "fmla v14.4s, v9.4s, %[w2].s[1] \n" // {1,3,5,7} * w01 + "fmla v17.4s, v10.4s, %[w2].s[2] \n" // {2,4,6,8} * w02 + + "ld2 {v8.4s, v9.4s}, [%[inptr4]], #32 \n" + "ld1 {v15.4s}, [%[inptr0]] \n" + "ld1 {v18.4s}, [%[inptr1]] \n" + "fmax v16.4s, v16.4s, %[vzero].4s \n" /* relu */ + + "fadd v17.4s, v17.4s, v13.4s \n" + + "ld1 {v19.4s}, [%[inptr2]] \n" + "ld1 {v20.4s}, [%[inptr3]] \n" + "ld1 {v21.4s}, [%[inptr4]] \n" + "st1 {v16.4s}, [%[outptr0]], #16 \n" + + "fadd v17.4s, v17.4s, v14.4s \n" + + "ext v10.16b, v0.16b, v15.16b, #4 \n" // v10 = {2,4,6,8} + "and v16.16b, %[vbias].16b, %[vbias].16b \n" // v10 = vbias + "fmax v17.4s, v17.4s, %[vzero].4s \n" /* relu */ + + "subs %[cnt], %[cnt], #1 \n" + + "st1 {v17.4s}, [%[outptr1]], #16 \n" + + "and v17.16b, %[vbias].16b, %[vbias].16b \n" // v16 = vbias + + "bne 2b \n" + + // right + "1: \n" + "cmp %[remain], #1 \n" + "blt 4f \n" + "3: \n" + "bif v0.16b, %[vzero].16b, %[mask1].16b \n" // pipei + "bif v1.16b, %[vzero].16b, %[mask2].16b \n" // pipei + + "bif v2.16b, %[vzero].16b, %[mask1].16b \n" // pipei + "bif v3.16b, %[vzero].16b, %[mask2].16b \n" // pipei + + "bif v4.16b, %[vzero].16b, %[mask1].16b \n" // pipei + "bif v5.16b, %[vzero].16b, %[mask2].16b \n" // pipei + + "ext 
v10.16b, v0.16b, %[vzero].16b, #4 \n" // v10 = {2,4,6,8} + + "bif v6.16b, %[vzero].16b, %[mask1].16b \n" // pipei + "bif v7.16b, %[vzero].16b, %[mask2].16b \n" // pipei + + // r0 + "fmul v11.4s, v0.4s, %[w0].s[0] \n" // {0,2,4,6} * w00 + "fmul v12.4s, v1.4s, %[w0].s[1] \n" // {1,3,5,7} * w01 + "fmla v16.4s, v10.4s, %[w0].s[2] \n" // {2,4,6,8} * w02 + + "ext v10.16b, v2.16b, %[vzero].16b, #4 \n" // v10 = {2,4,6,8} + "bif v8.16b, %[vzero].16b, %[mask1].16b \n" // pipei + "bif v9.16b, %[vzero].16b, %[mask2].16b \n" // pipei + + // r1 + "fmla v11.4s, v2.4s, %[w1].s[0] \n" // {0,2,4,6} * w00 + "fmla v12.4s, v3.4s, %[w1].s[1] \n" // {1,3,5,7} * w01 + "fmla v16.4s, v10.4s, %[w1].s[2] \n" // {2,4,6,8} * w02 + + "ext v10.16b, v4.16b, %[vzero].16b, #4 \n" // v10 = {2,4,6,8} + + // r2 + "fmul v13.4s, v4.4s, %[w0].s[0] \n" // {0,2,4,6} * w00 + "fmla v11.4s, v4.4s, %[w2].s[0] \n" // {0,2,4,6} * w00 + + "fmul v14.4s, v5.4s, %[w0].s[1] \n" // {1,3,5,7} * w01 + "fmla v12.4s, v5.4s, %[w2].s[1] \n" // {1,3,5,7} * w01 + + "fmla v17.4s, v10.4s, %[w0].s[2] \n" // {2,4,6,8} * w02 + "fmla v16.4s, v10.4s, %[w2].s[2] \n" // {2,4,6,8} * w02 + + "ext v10.16b, v6.16b, %[vzero].16b, #4 \n" // v10 = {2,4,6,8} + + // r3 + "fmla v13.4s, v6.4s, %[w1].s[0] \n" // {0,2,4,6} * w00 + "fmla v14.4s, v7.4s, %[w1].s[1] \n" // {1,3,5,7} * w01 + "fmla v17.4s, v10.4s, %[w1].s[2] \n" // {2,4,6,8} * w02 + + "ext v10.16b, v8.16b, %[vzero].16b, #4 \n" // v10 = {2,4,6,8} + "ld1 {v0.4s}, [%[outptr0]] \n" + + "fadd v16.4s, v16.4s, v11.4s \n" + "fadd v16.4s, v16.4s, v12.4s \n" + "ld1 {v1.4s}, [%[outptr1]] \n" + + // r4 + "fmla v13.4s, v8.4s, %[w2].s[0] \n" // {0,2,4,6} * w00 + "fmla v14.4s, v9.4s, %[w2].s[1] \n" // {1,3,5,7} * w01 + "fmla v17.4s, v10.4s, %[w2].s[2] \n" // {2,4,6,8} * w02 + + "fmax v16.4s, v16.4s, %[vzero].4s \n" /* relu */ + + "fadd v17.4s, v17.4s, v13.4s \n" + + "bif v16.16b, v0.16b, %[wmask].16b \n" // pipei + + "fadd v17.4s, v17.4s, v14.4s \n" + + "st1 {v16.4s}, [%[outptr0]], #16 \n" + + "fmax v17.4s, v17.4s, %[vzero].4s \n" /* relu */ + + "bif v17.16b, v1.16b, %[wmask].16b \n" // pipei + + "st1 {v17.4s}, [%[outptr1]], #16 \n" + "4: \n" + : [inptr0] "+r"(din0_ptr), + [inptr1] "+r"(din1_ptr), + [inptr2] "+r"(din2_ptr), + [inptr3] "+r"(din3_ptr), + [inptr4] "+r"(din4_ptr), + [outptr0] "+r"(doutr0_ptr), + [outptr1] "+r"(doutr1_ptr), + [cnt] "+r"(cnt) + : [vzero] "w"(vzero), + [w0] "w"(wr0), + [w1] "w"(wr1), + [w2] "w"(wr2), + [remain] "r"(cnt_remain), + [mask1] "w"(vmask_rp1), + [mask2] "w"(vmask_rp2), + [wmask] "w"(wmask), + [vbias] "w"(wbias) + : "cc", + "memory", + "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15", + "v16", + "v17", + "v18", + "v19", + "v20", + "v21"); + doutr0 = doutr0 + 2 * w_out; + } +#else + for (int i = 0; i < h_out; i++) { + din0_ptr = dr0; + din1_ptr = dr1; + din2_ptr = dr2; + + doutr0_ptr = doutr0; + + dr0 = dr2; + dr1 = dr0 + w_in; + dr2 = dr1 + w_in; + + //! process bottom pad + if (i + 2 > h_in) { + switch (i + 2 - h_in) { + case 2: + din1_ptr = zero_ptr; + case 1: + din2_ptr = zero_ptr; + default: + break; + } + } + int cnt = tile_w; + unsigned int* mask_ptr = dmask; + asm volatile( + // Load up 12 elements (3 vectors) from each of 8 sources. + "0: \n" + "vmov.u32 q9, #0 \n" + "vld2.32 {d20-d23}, [%[din0_ptr]]! @ load din r1\n" + "vld2.32 {d24-d27}, [%[din1_ptr]]! @ load din r1\n" + "vld2.32 {d28-d31}, [%[din2_ptr]]! 
@ load din r1\n" + "pld [%[din0_ptr]] @ preload data\n" + "pld [%[din1_ptr]] @ preload data\n" + "pld [%[din2_ptr]] @ preload data\n" + + "vld1.32 {d16}, [%[din0_ptr]] @ load din r0\n" // q2={8,10,12,14} + + "vdup.32 q3, %[bias] @ and \n" // q10 = + // vbias + // mid + "2: \n" + "vext.32 q6, q10, q8, #1 @ shift left 1 \n" // q6 = {2,4,6,8} + "vld1.32 {d16}, [%[din1_ptr]] @ load din r1\n" // q2={8,10,12,14} + + "vmul.f32 q4, q10, %e[wr0][0] @ mul weight 0, " + "out0\n" // q0 * w00 + "vmul.f32 q5, q11, %e[wr0][1] @ mul weight 0, " + "out0\n" // q1 * w01 + "vmla.f32 q3, q6, %f[wr0][0] @ mul weight 0, " + "out0\n" // q6 * w02 + + "vext.32 q7, q12, q8, #1 @ shift left 1 \n" // q6 = {2,4,6,8} + "vld1.32 {d16}, [%[din2_ptr]] @ load din r1\n" // q2={8,10,12,14} + + "vld2.32 {d20-d23}, [%[din0_ptr]]! @ load din r0\n" // v0={0,2,4,6} v1={1,3,5,7} + + "vmla.f32 q4, q12, %e[wr1][0] @ mul weight 1, " + "out0\n" // q0 * w00 + "vmla.f32 q5, q13, %e[wr1][1] @ mul weight 1, " + "out0\n" // q1 * w01 + "vmla.f32 q3, q7, %f[wr1][0] @ mul weight 1, " + "out0\n" // q6 * w02 + + "vext.32 q6, q14, q8, #1 @ shift left 1 \n" // q6 = {2,4,6,8} + + "vld2.32 {d24-d27}, [%[din1_ptr]]! @ load din r1\n" // v0={0,2,4,6} v1={1,3,5,7} + + "vmla.f32 q4, q14, %e[wr2][0] @ mul weight 2, " + "out0\n" // q0 * w00 + "vmla.f32 q5, q15, %e[wr2][1] @ mul weight 2, " + "out0\n" // q1 * w01 + "vmla.f32 q3, q6, %f[wr2][0] @ mul weight 2, " + "out0\n" // q6 * w02 + + "vld2.32 {d28-d31}, [%[din2_ptr]]! @ load din r2\n" // v4={0,2,4,6} v5={1,3,5,7} + + "vadd.f32 q3, q3, q4 @ add \n" + "vadd.f32 q3, q3, q5 @ add \n" + + "subs %[cnt], #1 \n" + "vmax.f32 q3, q3, q9 @ relu \n" + + "vld1.32 {d16}, [%[din0_ptr]] @ load din r0\n" // q2={8,10,12,14} + + "vst1.32 {d6-d7}, [%[outptr]]! \n" + + "vdup.32 q3, %[bias] @ and \n" // q10 = + // vbias + "bne 2b \n" + + // right + "1: \n" + "cmp %[remain], #1 \n" + "blt 3f \n" + + "vld1.f32 {d12-d15}, [%[mask_ptr]]! @ load mask\n" + + "vbif q10, q9, q6 @ bit select, deal " + "with right pad\n" + "vbif q11, q9, q7 @ bit select, deal " + "with right pad\n" + "vbif q12, q9, q6 @ bit select, deal " + "with right pad\n" + "vbif q13, q9, q7 @ bit select, deal " + "with right pad\n" + "vbif q14, q9, q6 @ bit select, deal " + "with right pad\n" + "vbif q15, q9, q7 @ bit select, deal " + "with right pad\n" + + "vext.32 q6, q10, q9, #1 @ shift left 1 \n" // q6 = {2,4,6,8} + "vext.32 q7, q12, q9, #1 @ shift left 1 \n" // q6 = {2,4,6,8} + + "vmul.f32 q4, q10, %e[wr0][0] @ mul weight 0, " + "out0\n" // q0 * w00 + "vmul.f32 q5, q11, %e[wr0][1] @ mul weight 0, " + "out0\n" // q1 * w01 + "vmla.f32 q3, q6, %f[wr0][0] @ mul weight 0, " + "out0\n" // q6 * w02 + + "vext.32 q6, q14, q9, #1 @ shift left 1 \n" // q6 = {2,4,6,8} + "vld1.f32 {d20-d21}, [%[outptr]] @ load output\n" + + "vmla.f32 q4, q12, %e[wr1][0] @ mul weight 1, " + "out0\n" // q0 * w00 + "vmla.f32 q5, q13, %e[wr1][1] @ mul weight 1, " + "out0\n" // q1 * w01 + "vmla.f32 q3, q7, %f[wr1][0] @ mul weight 1, " + "out0\n" // q6 * w02 + + "vld1.f32 {d22-d23}, [%[mask_ptr]] @ load mask\n" + + "vmla.f32 q4, q14, %e[wr2][0] @ mul weight 2, " + "out0\n" // q0 * w00 + "vmla.f32 q5, q15, %e[wr2][1] @ mul weight 2, " + "out0\n" // q1 * w01 + "vmla.f32 q3, q6, %f[wr2][0] @ mul weight 2, " + "out0\n" // q6 * w02 + + "vadd.f32 q3, q3, q4 @ add \n" + "vadd.f32 q3, q3, q5 @ add \n" + + "vmax.f32 q3, q3, q9 @ relu \n" + + "vbif.f32 q3, q10, q11 @ write mask\n" + + "vst1.32 {d6-d7}, [%[outptr]]! 
\n" + "3: \n" + : [din0_ptr] "+r"(din0_ptr), + [din1_ptr] "+r"(din1_ptr), + [din2_ptr] "+r"(din2_ptr), + [outptr] "+r"(doutr0_ptr), + [cnt] "+r"(cnt), + [mask_ptr] "+r"(mask_ptr) + : [remain] "r"(cnt_remain), + [wr0] "w"(wr0), + [wr1] "w"(wr1), + [wr2] "w"(wr2), + [bias] "r"(bias_c) + : "cc", + "memory", + "q3", + "q4", + "q5", + "q6", + "q7", + "q8", + "q9", + "q10", + "q11", + "q12", + "q13", + "q14", + "q15"); + + doutr0 = doutr0 + w_out; + } +#endif + } + } +} +/** + * \brief depthwise convolution, kernel size 3x3, stride 1, pad 1, with bias, + * width <= 4 + */ +void conv_depthwise_3x3s1p0_bias_s(float* dout, + const float* din, + const float* weights, + const float* bias, + bool flag_bias, + const int num, + const int ch_in, + const int h_in, + const int w_in, + const int h_out, + const int w_out, + ARMContext* ctx) { + //! 3x3s1 convolution, implemented by direct algorithm + //! pad is done implicit + //! for 4x6 convolution window + const int right_pad_idx[8] = {5, 4, 3, 2, 1, 0, 0, 0}; + const float zero_ptr[4] = {0.f, 0.f, 0.f, 0.f}; + + float32x4_t vzero = vdupq_n_f32(0.f); + uint32x4_t vmask_rp1 = + vcgeq_s32(vld1q_s32(right_pad_idx), vdupq_n_s32(6 - w_in)); + uint32x4_t vmask_rp2 = + vcgeq_s32(vld1q_s32(right_pad_idx + 4), vdupq_n_s32(6 - w_in)); + + unsigned int vmask[8]; + vst1q_u32(vmask, vmask_rp1); + vst1q_u32(vmask + 4, vmask_rp2); + + int size_in_channel = w_in * h_in; + int size_out_channel = w_out * h_out; + for (int n = 0; n < num; ++n) { + const float* din_batch = din + n * ch_in * size_in_channel; + float* dout_batch = dout + n * ch_in * size_out_channel; +#pragma omp parallel for + for (int i = 0; i < ch_in; ++i) { + float* dout_channel = dout_batch + i * size_out_channel; + const float* din_channel = din_batch + i * size_in_channel; + const float* weight_ptr = weights + i * 9; + float32x4_t wr0 = vld1q_f32(weight_ptr); + float32x4_t wr1 = vld1q_f32(weight_ptr + 3); + float32x4_t wr2 = vld1q_f32(weight_ptr + 6); + float32x4_t wbias; + if (flag_bias) { + wbias = vdupq_n_f32(bias[i]); + } else { + wbias = vdupq_n_f32(0.f); + } + + float out_buf1[4]; + float out_buf2[4]; + float trash_buf[4]; + + float* doutr0 = dout_channel; + float* doutr1 = dout_channel + w_out; + + for (int j = 0; j < h_out; j += 2) { + const float* dr0 = din_channel + j * w_in; + const float* dr1 = dr0 + w_in; + const float* dr2 = dr1 + w_in; + const float* dr3 = dr2 + w_in; + + doutr0 = dout_channel + j * w_out; + doutr1 = doutr0 + w_out; + + if (j + 3 >= h_in) { + switch (j + 3 - h_in) { + case 3: + dr1 = zero_ptr; + case 2: + dr2 = zero_ptr; + case 1: + dr3 = zero_ptr; + doutr1 = trash_buf; + case 0: + dr3 = zero_ptr; + doutr1 = trash_buf; + default: + break; + } + } +#ifdef __aarch64__ + asm volatile( + "prfm pldl1keep, [%[din0]]\n" + "prfm pldl1keep, [%[din1]]\n" + "prfm pldl1keep, [%[din2]]\n" + "prfm pldl1keep, [%[din3]]\n" + + "ld1 {v0.4s, v1.4s}, [%[din0]]\n" + "ld1 {v2.4s, v3.4s}, [%[din1]]\n" + "ld1 {v4.4s, v5.4s}, [%[din2]]\n" + "ld1 {v6.4s, v7.4s}, [%[din3]]\n" + + "bif v0.16b, %[zero].16b, %[mask1].16b\n" // d0_1234 + "bif v1.16b, %[zero].16b, %[mask2].16b\n" // d0_1234 + + "bif v2.16b, %[zero].16b, %[mask1].16b\n" // d1_1234 + "bif v3.16b, %[zero].16b, %[mask2].16b\n" // d1_1234 + + "bif v4.16b, %[zero].16b, %[mask1].16b\n" // d2_1234 + "bif v5.16b, %[zero].16b, %[mask2].16b\n" // d2_1234 + + "bif v6.16b, %[zero].16b, %[mask1].16b\n" // d3_1234 + "bif v7.16b, %[zero].16b, %[mask2].16b\n" // d3_1234 + + "ext v8.16b, v0.16b, v1.16b, #4\n" // d1_2345 + "ext v9.16b, v0.16b, v1.16b, 
#8\n" // d1_3450 + + "and v12.16b, %[vbias].16b, %[vbias].16b \n" // v12 = vbias + "and v13.16b, %[vbias].16b, %[vbias].16b \n" // v13 = vbias + + // r0 + "fmul v10.4s, v0.4s, %[wr0].s[0]\n" // d0_1234 * w0[0] + "fmul v11.4s, v8.4s, %[wr0].s[1]\n" // d1_2345 * w0[1] + "fmla v12.4s, v9.4s, %[wr0].s[2]\n" // d0_3456 * w0[2] + + "ext v8.16b, v2.16b, v3.16b, #4\n" // d1_2345 + "ext v9.16b, v2.16b, v3.16b, #8\n" // d1_3450 + + // r1 + "fmul v14.4s, v2.4s, %[wr0].s[0]\n" // d0_1234 * w0[0] + "fmla v10.4s, v2.4s, %[wr1].s[0]\n" // d0_1234 * w0[0] + + "fmul v15.4s, v8.4s, %[wr0].s[1]\n" // d1_2345 * w0[1] + "fmla v11.4s, v8.4s, %[wr1].s[1]\n" // d1_2345 * w0[1] + + "fmla v13.4s, v9.4s, %[wr0].s[2]\n" // d0_3456 * w0[2] + "fmla v12.4s, v9.4s, %[wr1].s[2]\n" // d0_3456 * w0[2] + + "ext v8.16b, v4.16b, v5.16b, #4\n" // d1_2345 + "ext v9.16b, v4.16b, v5.16b, #8\n" // d1_3450 + + // r2 + "fmla v14.4s, v4.4s, %[wr1].s[0]\n" // d0_1234 * w0[0] + "fmla v10.4s, v4.4s, %[wr2].s[0]\n" // d0_1234 * w0[0] + + "fmla v15.4s, v8.4s, %[wr1].s[1]\n" // d1_2345 * w0[1] + "fmla v11.4s, v8.4s, %[wr2].s[1]\n" // d1_2345 * w0[1] + + "fmla v13.4s, v9.4s, %[wr1].s[2]\n" // d0_3456 * w0[2] + "fmla v12.4s, v9.4s, %[wr2].s[2]\n" // d0_3456 * w0[2] + + "ext v8.16b, v6.16b, v7.16b, #4\n" // d1_2345 + "ext v9.16b, v6.16b, v7.16b, #8\n" // d1_3450 + + // r3 + "fmla v14.4s, v6.4s, %[wr2].s[0]\n" // d0_1234 * w0[0] + + "fmla v15.4s, v8.4s, %[wr2].s[1]\n" // d1_2345 * w0[1] + + "fadd v12.4s, v12.4s, v10.4s\n" + + "fmla v13.4s, v9.4s, %[wr2].s[2]\n" // d0_3456 * w0[2] + + "fadd v12.4s, v12.4s, v11.4s\n" // out1 + "fadd v13.4s, v13.4s, v14.4s\n" // out2 + "fadd v13.4s, v13.4s, v15.4s\n" // out2 + + "prfm pldl1keep, [%[out1]]\n" + "prfm pldl1keep, [%[out2]]\n" + + "st1 {v12.4s}, [%[out1]]\n" + "st1 {v13.4s}, [%[out2]]\n" + : [din0] "+r"(dr0), + [din1] "+r"(dr1), + [din2] "+r"(dr2), + [din3] "+r"(dr3) + : [wr0] "w"(wr0), + [wr1] "w"(wr1), + [wr2] "w"(wr2), + [vbias] "w"(wbias), + [mask1] "w"(vmask_rp1), + [mask2] "w"(vmask_rp2), + [zero] "w"(vzero), + [out1] "r"(out_buf1), + [out2] "r"(out_buf2) + : "cc", + "memory", + "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15"); +#else + unsigned int* vmask_ptr = vmask; + float bias_val = flag_bias ? bias[i] : 0.f; + asm volatile( + "pld [%[din0]]\n" + "pld [%[din1]]\n" + "pld [%[din2]]\n" + "pld [%[din3]]\n" + + "vld1.32 {d16-d18}, [%[din0]] @ load din r0\n" + "vld1.32 {d20-d22}, [%[din1]] @ load din r1\n" + "vld1.32 {d24-d26}, [%[din2]] @ load din r2\n" + "vld1.32 {d28-d30}, [%[din3]] @ load din r3\n" + + "vdup.32 q4, %[bias_val] @ and \n" // q4 + // = + // vbias + "vdup.32 q5, %[bias_val] @ and \n" // q5 + // = + // vbias + + "vld1.32 {d19}, [%[vmask]]! @ load din r0\n" + "vld1.32 {d23}, [%[vmask]]! @ load din r0\n" + + "vld1.32 {d27}, [%[vmask]]! 
@ load din r0\n" + + "vbif d16, %e[vzero], d19 @ bit select, deal with " + "right pad\n" + "vbif d20, %e[vzero], d19 @ bit select, deal with " + "right pad\n" + + "vbif d17, %e[vzero], d23 @ bit select, deal with " + "right pad\n" + "vbif d21, %e[vzero], d23 @ bit select, deal with " + "right pad\n" + + "vbif d18, %e[vzero], d27 @ bit select, deal with " + "right pad\n" + "vbif d22, %e[vzero], d27 @ bit select, deal with " + "right pad\n" + + "vext.32 q6, q8, q9, #1 @ 1234\n" + "vext.32 q7, q8, q9, #2 @ 2345\n" + + // r0 + "vmla.f32 q4, q8, %e[wr0][0] @ q4 += 0123 * wr0[0]\n" + + "vbif d24, %e[vzero], d19 @ bit select, deal with " + "right pad\n" + "vbif d25, %e[vzero], d23 @ bit select, deal with " + "right pad\n" + "vbif d26, %e[vzero], d27 @ bit select, deal with " + "right pad\n" + + "vmla.f32 q4, q6, %e[wr0][1] @ q4 += 1234 * wr0[1]\n" + + "vbif d28, %e[vzero], d19 @ bit select, deal with " + "right pad\n" + "vbif d29, %e[vzero], d23 @ bit select, deal with " + "right pad\n" + "vbif d30, %e[vzero], d27 @ bit select, deal with " + "right pad\n" + + "vmla.f32 q4, q7, %f[wr0][0] @ q4 += 2345 * wr0[2]\n" + + "vext.32 q6, q10, q11, #1 @ 1234\n" + "vext.32 q7, q10, q11, #2 @ 2345\n" + + // r1 + "vmla.f32 q5, q10, %e[wr0][0] @ q4 += 1234 * wr0[0]\n" + "vmla.f32 q4, q10, %e[wr1][0] @ q4 += 1234 * wr0[1]\n" + + "vmul.f32 q8, q6, %e[wr0][1] @ q4 += 1234 * wr0[1]\n" + "vmul.f32 q10, q6, %e[wr1][1] @ q4 += 1234 * wr0[1]\n" + + "vmul.f32 q9, q7, %f[wr0][0] @ q4 += 1234 * wr0[1]\n" + "vmul.f32 q11, q7, %f[wr1][0] @ q4 += 1234 * wr0[1]\n" + + "vext.32 q6, q12, q13, #1 @ 1234\n" + "vext.32 q7, q12, q13, #2 @ 2345\n" + + // r2 + "vmla.f32 q5, q12, %e[wr1][0] @ q4 += 1234 * wr0[0]\n" + "vmla.f32 q4, q12, %e[wr2][0] @ q4 += 1234 * wr0[1]\n" + + "vmla.f32 q8, q6, %e[wr1][1] @ q4 += 1234 * wr0[1]\n" + "vmla.f32 q10, q6, %e[wr2][1] @ q4 += 1234 * wr0[1]\n" + + "vmla.f32 q9, q7, %f[wr1][0] @ q4 += 1234 * wr0[1]\n" + "vmla.f32 q11, q7, %f[wr2][0] @ q4 += 1234 * wr0[1]\n" + + "vext.32 q6, q14, q15, #1 @ 1234\n" + "vext.32 q7, q14, q15, #2 @ 2345\n" + + // r3 + "vmla.f32 q5, q14, %e[wr2][0] @ q4 += 0123 * wr0[0]\n" + + "vmla.f32 q8, q6, %e[wr2][1] @ q4 += 1234 * wr0[1]\n" + "vadd.f32 q4, q4, q10 @ q4 += q10 \n" + + "pld [%[out1]]\n" + "pld [%[out2]]\n" + + "vmla.f32 q9, q7, %f[wr2][0] @ q4 += 2345 * wr0[2]\n" + "vadd.f32 q4, q4, q11 @ q4 += q10 \n" + + "vadd.f32 q5, q5, q8 @ q4 += q10 \n" + "vadd.f32 q5, q5, q9 @ q4 += q10 \n" + + "vst1.32 {d8-d9}, [%[out1]] @ store result, add pointer\n" + "vst1.32 {d10-d11}, [%[out2]] @ store result, add pointer\n" + + : [din0] "+r"(dr0), + [din1] "+r"(dr1), + [din2] "+r"(dr2), + [din3] "+r"(dr3), + [vmask] "+r"(vmask_ptr) + : [wr0] "w"(wr0), + [wr1] "w"(wr1), + [wr2] "w"(wr2), + [vzero] "w"(vzero), + [bias_val] "r"(bias_val), + [out1] "r"(out_buf1), + [out2] "r"(out_buf2) + : "cc", + "memory", + "q4", + "q5", + "q6", + "q7", + "q8", + "q9", + "q10", + "q11", + "q12", + "q13", + "q14", + "q15"); +#endif // __aarch64__ + for (int w = 0; w < w_out; ++w) { + *doutr0++ = out_buf1[w]; + *doutr1++ = out_buf2[w]; + } + } // end of processing heights + } // end of processing channels + } // end of processing batchs +} +/** + * \brief depthwise convolution kernel 3x3, stride 2, width <= 4 + */ + +void conv_depthwise_3x3s2p0_bias_s(float* dout, + const float* din, + const float* weights, + const float* bias, + bool flag_bias, + const int num, + const int ch_in, + const int h_in, + const int w_in, + const int h_out, + const int w_out, + ARMContext* ctx) { + int right_pad_idx[8] = {0, 
2, 4, 6, 1, 3, 5, 7}; + int out_pad_idx[4] = {0, 1, 2, 3}; + float zeros[8] = {0.0f}; + + uint32x4_t vmask_rp1 = + vcgtq_s32(vdupq_n_s32(w_in), vld1q_s32(right_pad_idx)); // 0 2 4 6 + uint32x4_t vmask_rp2 = + vcgtq_s32(vdupq_n_s32(w_in), vld1q_s32(right_pad_idx + 4)); // 1 3 5 7 + + int size_in_channel = w_in * h_in; + int size_out_channel = w_out * h_out; + + unsigned int dmask[8]; + vst1q_u32(dmask, vmask_rp1); + vst1q_u32(dmask + 4, vmask_rp2); + + for (int n = 0; n < num; ++n) { + const float* din_batch = din + n * ch_in * size_in_channel; + float* dout_batch = dout + n * ch_in * size_out_channel; +#pragma omp parallel for + for (int i = 0; i < ch_in; ++i) { + const float* din_channel = din_batch + i * size_in_channel; + float* dout_channel = dout_batch + i * size_out_channel; + + const float* weight_ptr = weights + i * 9; + float32x4_t wr0 = vld1q_f32(weight_ptr); + float32x4_t wr1 = vld1q_f32(weight_ptr + 3); + float32x4_t wr2 = vld1q_f32(weight_ptr + 6); + + float bias_c = 0.f; + + if (flag_bias) { + bias_c = bias[i]; + } + float32x4_t vbias = vdupq_n_f32(bias_c); + float out_buf[4]; + const float* dr0 = din_channel; + const float* dr1 = dr0 + w_in; + const float* dr2 = dr1 + w_in; + for (int j = 0; j < h_out; ++j) { + const float* din0_ptr = dr0; + const float* din1_ptr = dr1; + const float* din2_ptr = dr2; + + dr0 = dr2; + dr1 = dr0 + w_in; + dr2 = dr1 + w_in; + + unsigned int* mask_ptr = dmask; +#ifdef __aarch64__ + asm volatile( + // Load up 12 elements (3 vectors) from each of 8 sources. + "movi v9.4s, #0 \n" + "ld1 {v6.4s, v7.4s}, [%[mask_ptr]], #32 \n" + + "ld2 {v10.4s, v11.4s}, [%[din0_ptr]], #32 \n" // v10={0,2,4,6} + // v11={1,3,5,7} + "ld2 {v12.4s, v13.4s}, [%[din1_ptr]], #32 \n" // v13={0,2,4,6} + // v12={1,3,5,7} + "ld2 {v14.4s, v15.4s}, [%[din2_ptr]], #32 \n" // v14={0,2,4,6} + // v15={1,3,5,7} + "and v4.16b, %[bias].16b, %[bias].16b \n" // v10 = vbias + + "bif v10.16b, v9.16b, v6.16b \n" + "bif v11.16b, v9.16b, v7.16b \n" + "bif v12.16b, v9.16b, v6.16b \n" + "bif v13.16b, v9.16b, v7.16b \n" + "bif v14.16b, v9.16b, v6.16b \n" + "bif v15.16b, v9.16b, v7.16b \n" + + "ext v6.16b, v10.16b, v9.16b, #4 \n" // v6 = + // {2,4,6,8} + "ext v7.16b, v12.16b, v9.16b, #4 \n" // v6 = + // {2,4,6,8} + "ext v8.16b, v14.16b, v9.16b, #4 \n" // v6 = + // {2,4,6,8} + + "fmla v4.4s, v10.4s, %[wr0].s[0] \n" // 0246 * w00 + "fmul v5.4s, v11.4s, %[wr0].s[1] \n" // 1357 * w01 + "fmul v16.4s, v6.4s, %[wr0].s[2] \n" // 2468 * w02 + + "fmla v4.4s, v12.4s, %[wr1].s[0] \n" // v12 * w11 + "fmla v5.4s, v13.4s, %[wr1].s[1] \n" // v13 * w12 + "fmla v16.4s, v7.4s, %[wr1].s[2] \n" // v7 * w10 + + "fmla v4.4s, v14.4s, %[wr2].s[0] \n" // v14 * w20 + "fmla v5.4s, v15.4s, %[wr2].s[1] \n" // v15 * w21 + "fmla v16.4s, v8.4s, %[wr2].s[2] \n" // v8 * w22 + + "fadd v4.4s, v4.4s, v5.4s \n" + "fadd v4.4s, v4.4s, v16.4s \n" + + // "fadd v4.4s, v4.4s, %[bias].4s \n" + "st1 {v4.4s}, [%[out]] \n" + : [din0_ptr] "+r"(din0_ptr), + [din1_ptr] "+r"(din1_ptr), + [din2_ptr] "+r"(din2_ptr), + [mask_ptr] "+r"(mask_ptr) + : [wr0] "w"(wr0), + [wr1] "w"(wr1), + [wr2] "w"(wr2), + [bias] "w"(vbias), + [out] "r"(out_buf) + : "cc", + "memory", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15", + "v16"); + +#else + asm volatile( + // Load up 12 elements (3 vectors) from each of 8 sources. + "vmov.u32 q9, #0 \n" + "vld1.f32 {d12-d15}, [%[mask_ptr]] @ load mask\n" + "vdup.32 q3, %[bias] @ and \n" // q3 = + // vbias + + "vld2.32 {d20-d23}, [%[din0_ptr]]! 
@ load din r0\n" // q10={0,2,4,6} q11={1,3,5,7} + "vld2.32 {d24-d27}, [%[din1_ptr]]! @ load din r1\n" // q13={0,2,4,6} q12={1,3,5,7} + "vld2.32 {d28-d31}, [%[din2_ptr]]! @ load din r2\n" // q14={0,2,4,6} q15={1,3,5,7} + + "vbif q10, q9, q6 @ bit select, deal " + "with right pad\n" + "vbif q11, q9, q7 @ bit select, deal " + "with right pad\n" + "vbif q12, q9, q6 @ bit select, deal " + "with right pad\n" + "vbif q13, q9, q7 @ bit select, deal " + "with right pad\n" + "vbif q14, q9, q6 @ bit select, deal " + "with right pad\n" + "vbif q15, q9, q7 @ bit select, deal " + "with right pad\n" + + "vext.32 q6, q10, q9, #1 @ shift left 1 \n" // q6 = {2,4,6,0} + "vext.32 q7, q12, q9, #1 @ shift left 1 \n" // q7 = {2,4,6,0} + "vext.32 q8, q14, q9, #1 @ shift left 1 \n" // q8 = {2,4,6,0} + + "vmul.f32 q4, q10, %e[wr0][0] @ mul weight 0, " + "out0\n" // {0,2,4,6} + "vmul.f32 q5, q11, %e[wr0][1] @ mul weight 0, " + "out0\n" // {1,3,5,7} + "vmla.f32 q3, q6, %f[wr0][0] @ mul weight 0, " + "out0\n" // {2,4,6,0} + + "vmla.f32 q4, q12, %e[wr1][0] @ mul weight 1, " + "out0\n" // q12 * w11 + "vmla.f32 q5, q13, %e[wr1][1] @ mul weight 1, " + "out0\n" // q13 * w12 + "vmla.f32 q3, q7, %f[wr1][0] @ mul weight 1, " + "out0\n" // q7 * w10 + + "vmla.f32 q4, q14, %e[wr2][0] @ mul weight 2, " + "out0\n" // q14 * w20 + "vmla.f32 q5, q15, %e[wr2][1] @ mul weight 2, " + "out0\n" // q15 * w21 + "vmla.f32 q3, q8, %f[wr2][0] @ mul weight 2, " + "out0\n" // q8 * w22 + + "vadd.f32 q3, q3, q4 @ add \n" + "vadd.f32 q3, q3, q5 @ add \n" + + "vst1.32 {d6-d7}, [%[out]] \n" + : [din0_ptr] "+r"(din0_ptr), + [din1_ptr] "+r"(din1_ptr), + [din2_ptr] "+r"(din2_ptr) + : [wr0] "w"(wr0), + [wr1] "w"(wr1), + [wr2] "w"(wr2), + [bias] "r"(bias_c), + [out] "r"(out_buf), + [mask_ptr] "r"(dmask) + : "cc", + "memory", + "q3", + "q4", + "q5", + "q6", + "q7", + "q8", + "q9", + "q10", + "q11", + "q12", + "q13", + "q14", + "q15"); +#endif // __aarch64__ + for (int w = 0; w < w_out; ++w) { + *dout_channel++ = out_buf[w]; + } + } + } + } +} +/** + * \brief depthwise convolution, kernel size 3x3, stride 1, pad 1, with bias, + * width <= 4 + */ +void conv_depthwise_3x3s1p0_bias_s_relu(float* dout, + const float* din, + const float* weights, + const float* bias, + bool flag_bias, + const int num, + const int ch_in, + const int h_in, + const int w_in, + const int h_out, + const int w_out, + ARMContext* ctx) { + //! 3x3s1 convolution, implemented by direct algorithm + //! pad is done implicit + //! 
for 4x6 convolution window + const int right_pad_idx[8] = {5, 4, 3, 2, 1, 0, 0, 0}; + const float zero_ptr[4] = {0.f, 0.f, 0.f, 0.f}; + + float32x4_t vzero = vdupq_n_f32(0.f); + uint32x4_t vmask_rp1 = + vcgeq_s32(vld1q_s32(right_pad_idx), vdupq_n_s32(6 - w_in)); + uint32x4_t vmask_rp2 = + vcgeq_s32(vld1q_s32(right_pad_idx + 4), vdupq_n_s32(6 - w_in)); + + unsigned int vmask[8]; + vst1q_u32(vmask, vmask_rp1); + vst1q_u32(vmask + 4, vmask_rp2); + + int size_in_channel = w_in * h_in; + int size_out_channel = w_out * h_out; + for (int n = 0; n < num; ++n) { + const float* din_batch = din + n * ch_in * size_in_channel; + float* dout_batch = dout + n * ch_in * size_out_channel; +#pragma omp parallel for + for (int i = 0; i < ch_in; ++i) { + float* dout_channel = dout_batch + i * size_out_channel; + const float* din_channel = din_batch + i * size_in_channel; + const float* weight_ptr = weights + i * 9; + float32x4_t wr0 = vld1q_f32(weight_ptr); + float32x4_t wr1 = vld1q_f32(weight_ptr + 3); + float32x4_t wr2 = vld1q_f32(weight_ptr + 6); + float32x4_t wbias; + if (flag_bias) { + wbias = vdupq_n_f32(bias[i]); + } else { + wbias = vdupq_n_f32(0.f); + } + + float out_buf1[4]; + float out_buf2[4]; + float trash_buf[4]; + + float* doutr0 = dout_channel; + float* doutr1 = dout_channel + w_out; + + for (int j = 0; j < h_out; j += 2) { + const float* dr0 = din_channel + j * w_in; + const float* dr1 = dr0 + w_in; + const float* dr2 = dr1 + w_in; + const float* dr3 = dr2 + w_in; + + doutr0 = dout_channel + j * w_out; + doutr1 = doutr0 + w_out; + + if (j + 3 >= h_in) { + switch (j + 3 - h_in) { + case 3: + dr1 = zero_ptr; + case 2: + dr2 = zero_ptr; + case 1: + dr3 = zero_ptr; + doutr1 = trash_buf; + case 0: + dr3 = zero_ptr; + doutr1 = trash_buf; + default: + break; + } + } +#ifdef __aarch64__ + asm volatile( + "prfm pldl1keep, [%[din0]]\n" + "prfm pldl1keep, [%[din1]]\n" + "prfm pldl1keep, [%[din2]]\n" + "prfm pldl1keep, [%[din3]]\n" + + "ld1 {v0.4s, v1.4s}, [%[din0]]\n" + "ld1 {v2.4s, v3.4s}, [%[din1]]\n" + "ld1 {v4.4s, v5.4s}, [%[din2]]\n" + "ld1 {v6.4s, v7.4s}, [%[din3]]\n" + + "bif v0.16b, %[zero].16b, %[mask1].16b\n" // d0_1234 + "bif v1.16b, %[zero].16b, %[mask2].16b\n" // d0_1234 + + "bif v2.16b, %[zero].16b, %[mask1].16b\n" // d1_1234 + "bif v3.16b, %[zero].16b, %[mask2].16b\n" // d1_1234 + + "bif v4.16b, %[zero].16b, %[mask1].16b\n" // d2_1234 + "bif v5.16b, %[zero].16b, %[mask2].16b\n" // d2_1234 + + "bif v6.16b, %[zero].16b, %[mask1].16b\n" // d3_1234 + "bif v7.16b, %[zero].16b, %[mask2].16b\n" // d3_1234 + + "ext v8.16b, v0.16b, v1.16b, #4\n" // d1_2345 + "ext v9.16b, v0.16b, v1.16b, #8\n" // d1_3450 + + "and v12.16b, %[vbias].16b, %[vbias].16b \n" // v12 = vbias + "and v13.16b, %[vbias].16b, %[vbias].16b \n" // v13 = vbias + + // r0 + "fmul v10.4s, v0.4s, %[wr0].s[0]\n" // d0_1234 * w0[0] + "fmul v11.4s, v8.4s, %[wr0].s[1]\n" // d1_2345 * w0[1] + "fmla v12.4s, v9.4s, %[wr0].s[2]\n" // d0_3456 * w0[2] + + "ext v8.16b, v2.16b, v3.16b, #4\n" // d1_2345 + "ext v9.16b, v2.16b, v3.16b, #8\n" // d1_3450 + + // r1 + "fmul v14.4s, v2.4s, %[wr0].s[0]\n" // d0_1234 * w0[0] + "fmla v10.4s, v2.4s, %[wr1].s[0]\n" // d0_1234 * w0[0] + + "fmul v15.4s, v8.4s, %[wr0].s[1]\n" // d1_2345 * w0[1] + "fmla v11.4s, v8.4s, %[wr1].s[1]\n" // d1_2345 * w0[1] + + "fmla v13.4s, v9.4s, %[wr0].s[2]\n" // d0_3456 * w0[2] + "fmla v12.4s, v9.4s, %[wr1].s[2]\n" // d0_3456 * w0[2] + + "ext v8.16b, v4.16b, v5.16b, #4\n" // d1_2345 + "ext v9.16b, v4.16b, v5.16b, #8\n" // d1_3450 + + // r2 + "fmla v14.4s, v4.4s, %[wr1].s[0]\n" // 
d0_1234 * w0[0] + "fmla v10.4s, v4.4s, %[wr2].s[0]\n" // d0_1234 * w0[0] + + "fmla v15.4s, v8.4s, %[wr1].s[1]\n" // d1_2345 * w0[1] + "fmla v11.4s, v8.4s, %[wr2].s[1]\n" // d1_2345 * w0[1] + + "fmla v13.4s, v9.4s, %[wr1].s[2]\n" // d0_3456 * w0[2] + "fmla v12.4s, v9.4s, %[wr2].s[2]\n" // d0_3456 * w0[2] + + "ext v8.16b, v6.16b, v7.16b, #4\n" // d1_2345 + "ext v9.16b, v6.16b, v7.16b, #8\n" // d1_3450 + + // r3 + "fmla v14.4s, v6.4s, %[wr2].s[0]\n" // d0_1234 * w0[0] + + "fmla v15.4s, v8.4s, %[wr2].s[1]\n" // d1_2345 * w0[1] + + "fadd v12.4s, v12.4s, v10.4s\n" + + "fmla v13.4s, v9.4s, %[wr2].s[2]\n" // d0_3456 * w0[2] + + "fadd v12.4s, v12.4s, v11.4s\n" // out1 + "fadd v13.4s, v13.4s, v14.4s\n" // out2 + "fadd v13.4s, v13.4s, v15.4s\n" // out2 + + "prfm pldl1keep, [%[out1]]\n" + "prfm pldl1keep, [%[out2]]\n" + "fmax v12.4s, v12.4s, %[zero].4s \n" + "fmax v13.4s, v13.4s, %[zero].4s \n" + + "st1 {v12.4s}, [%[out1]]\n" + "st1 {v13.4s}, [%[out2]]\n" + : [din0] "+r"(dr0), + [din1] "+r"(dr1), + [din2] "+r"(dr2), + [din3] "+r"(dr3) + : [wr0] "w"(wr0), + [wr1] "w"(wr1), + [wr2] "w"(wr2), + [vbias] "w"(wbias), + [mask1] "w"(vmask_rp1), + [mask2] "w"(vmask_rp2), + [zero] "w"(vzero), + [out1] "r"(out_buf1), + [out2] "r"(out_buf2) + : "cc", + "memory", + "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15"); +#else + unsigned int* vmask_ptr = vmask; + float bias_val = flag_bias ? bias[i] : 0.f; + asm volatile( + "pld [%[din0]]\n" + "pld [%[din1]]\n" + "pld [%[din2]]\n" + "pld [%[din3]]\n" + + "vld1.32 {d16-d18}, [%[din0]] @ load din r0\n" + "vld1.32 {d20-d22}, [%[din1]] @ load din r1\n" + "vld1.32 {d24-d26}, [%[din2]] @ load din r2\n" + "vld1.32 {d28-d30}, [%[din3]] @ load din r3\n" + + "vdup.32 q4, %[bias_val] @ and \n" // q4 + // = + // vbias + "vdup.32 q5, %[bias_val] @ and \n" // q5 + // = + // vbias + + "vld1.32 {d19}, [%[vmask]]! @ load din r0\n" + "vld1.32 {d23}, [%[vmask]]! @ load din r0\n" + + "vld1.32 {d27}, [%[vmask]]! 
@ load din r0\n" + + "vbif d16, %e[vzero], d19 @ bit select, deal with " + "right pad\n" + "vbif d20, %e[vzero], d19 @ bit select, deal with " + "right pad\n" + + "vbif d17, %e[vzero], d23 @ bit select, deal with " + "right pad\n" + "vbif d21, %e[vzero], d23 @ bit select, deal with " + "right pad\n" + + "vbif d18, %e[vzero], d27 @ bit select, deal with " + "right pad\n" + "vbif d22, %e[vzero], d27 @ bit select, deal with " + "right pad\n" + + "vext.32 q6, q8, q9, #1 @ 1234\n" + "vext.32 q7, q8, q9, #2 @ 2345\n" + + // r0 + "vmla.f32 q4, q8, %e[wr0][0] @ q4 += 0123 * wr0[0]\n" + + "vbif d24, %e[vzero], d19 @ bit select, deal with " + "right pad\n" + "vbif d25, %e[vzero], d23 @ bit select, deal with " + "right pad\n" + "vbif d26, %e[vzero], d27 @ bit select, deal with " + "right pad\n" + + "vmla.f32 q4, q6, %e[wr0][1] @ q4 += 1234 * wr0[1]\n" + + "vbif d28, %e[vzero], d19 @ bit select, deal with " + "right pad\n" + "vbif d29, %e[vzero], d23 @ bit select, deal with " + "right pad\n" + "vbif d30, %e[vzero], d27 @ bit select, deal with " + "right pad\n" + + "vmla.f32 q4, q7, %f[wr0][0] @ q4 += 2345 * wr0[2]\n" + + "vext.32 q6, q10, q11, #1 @ 1234\n" + "vext.32 q7, q10, q11, #2 @ 2345\n" + + // r1 + "vmla.f32 q5, q10, %e[wr0][0] @ q4 += 1234 * wr0[0]\n" + "vmla.f32 q4, q10, %e[wr1][0] @ q4 += 1234 * wr0[1]\n" + + "vmul.f32 q8, q6, %e[wr0][1] @ q4 += 1234 * wr0[1]\n" + "vmul.f32 q10, q6, %e[wr1][1] @ q4 += 1234 * wr0[1]\n" + + "vmul.f32 q9, q7, %f[wr0][0] @ q4 += 1234 * wr0[1]\n" + "vmul.f32 q11, q7, %f[wr1][0] @ q4 += 1234 * wr0[1]\n" + + "vext.32 q6, q12, q13, #1 @ 1234\n" + "vext.32 q7, q12, q13, #2 @ 2345\n" + + // r2 + "vmla.f32 q5, q12, %e[wr1][0] @ q4 += 1234 * wr0[0]\n" + "vmla.f32 q4, q12, %e[wr2][0] @ q4 += 1234 * wr0[1]\n" + + "vmla.f32 q8, q6, %e[wr1][1] @ q4 += 1234 * wr0[1]\n" + "vmla.f32 q10, q6, %e[wr2][1] @ q4 += 1234 * wr0[1]\n" + + "vmla.f32 q9, q7, %f[wr1][0] @ q4 += 1234 * wr0[1]\n" + "vmla.f32 q11, q7, %f[wr2][0] @ q4 += 1234 * wr0[1]\n" + + "vext.32 q6, q14, q15, #1 @ 1234\n" + "vext.32 q7, q14, q15, #2 @ 2345\n" + + // r3 + "vmla.f32 q5, q14, %e[wr2][0] @ q4 += 0123 * wr0[0]\n" + + "vmla.f32 q8, q6, %e[wr2][1] @ q4 += 1234 * wr0[1]\n" + "vadd.f32 q4, q4, q10 @ q4 += q10 \n" + + "pld [%[out1]]\n" + "pld [%[out2]]\n" + + "vmla.f32 q9, q7, %f[wr2][0] @ q4 += 2345 * wr0[2]\n" + "vadd.f32 q4, q4, q11 @ q4 += q10 \n" + + "vadd.f32 q5, q5, q8 @ q4 += q10 \n" + "vadd.f32 q5, q5, q9 @ q4 += q10 \n" + "vmax.f32 q4, q4, %q[vzero] @ relu \n" + "vmax.f32 q5, q5, %q[vzero] @ relu \n" + + "vst1.32 {d8-d9}, [%[out1]] @ store result, add pointer\n" + "vst1.32 {d10-d11}, [%[out2]] @ store result, add pointer\n" + + : [din0] "+r"(dr0), + [din1] "+r"(dr1), + [din2] "+r"(dr2), + [din3] "+r"(dr3), + [vmask] "+r"(vmask_ptr) + : [wr0] "w"(wr0), + [wr1] "w"(wr1), + [wr2] "w"(wr2), + [vzero] "w"(vzero), + [bias_val] "r"(bias_val), + [out1] "r"(out_buf1), + [out2] "r"(out_buf2) + : "cc", + "memory", + "q4", + "q5", + "q6", + "q7", + "q8", + "q9", + "q10", + "q11", + "q12", + "q13", + "q14", + "q15"); +#endif // __aarch64__ + for (int w = 0; w < w_out; ++w) { + *doutr0++ = out_buf1[w]; + *doutr1++ = out_buf2[w]; + } + // doutr0 = doutr1; + // doutr1 += w_out; + } // end of processing heights + } // end of processing channels + } // end of processing batchs +} + +/** + * \brief depthwise convolution kernel 3x3, stride 2, width <= 7 + */ +void conv_depthwise_3x3s2p0_bias_s_relu(float* dout, + const float* din, + const float* weights, + const float* bias, + bool flag_bias, + const int num, + const int 
ch_in, + const int h_in, + const int w_in, + const int h_out, + const int w_out, + ARMContext* ctx) { + int right_pad_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7}; + int out_pad_idx[4] = {0, 1, 2, 3}; + float zeros[8] = {0.0f}; + + uint32x4_t vmask_rp1 = + vcgtq_s32(vdupq_n_s32(w_in), vld1q_s32(right_pad_idx)); // 0 2 4 6 + uint32x4_t vmask_rp2 = + vcgtq_s32(vdupq_n_s32(w_in), vld1q_s32(right_pad_idx + 4)); // 1 3 5 7 + + int size_in_channel = w_in * h_in; + int size_out_channel = w_out * h_out; + + unsigned int dmask[8]; + vst1q_u32(dmask, vmask_rp1); + vst1q_u32(dmask + 4, vmask_rp2); + + for (int n = 0; n < num; ++n) { + const float* din_batch = din + n * ch_in * size_in_channel; + float* dout_batch = dout + n * ch_in * size_out_channel; +#pragma omp parallel for + for (int i = 0; i < ch_in; ++i) { + const float* din_channel = din_batch + i * size_in_channel; + float* dout_channel = dout_batch + i * size_out_channel; + + const float* weight_ptr = weights + i * 9; + float32x4_t wr0 = vld1q_f32(weight_ptr); + float32x4_t wr1 = vld1q_f32(weight_ptr + 3); + float32x4_t wr2 = vld1q_f32(weight_ptr + 6); + + float bias_c = 0.f; + + if (flag_bias) { + bias_c = bias[i]; + } + float32x4_t vbias = vdupq_n_f32(bias_c); + float out_buf[4]; + const float* dr0 = din_channel; + const float* dr1 = dr0 + w_in; + const float* dr2 = dr1 + w_in; + for (int j = 0; j < h_out; ++j) { + const float* din0_ptr = dr0; + const float* din1_ptr = dr1; + const float* din2_ptr = dr2; + + dr0 = dr2; + dr1 = dr0 + w_in; + dr2 = dr1 + w_in; + + unsigned int* mask_ptr = dmask; +#ifdef __aarch64__ + asm volatile( + // Load up 12 elements (3 vectors) from each of 8 sources. + "movi v9.4s, #0 \n" + "ld1 {v6.4s, v7.4s}, [%[mask_ptr]] \n" + + "ld2 {v10.4s, v11.4s}, [%[din0_ptr]], #32 \n" // v10={0,2,4,6} + // v11={1,3,5,7} + "ld2 {v12.4s, v13.4s}, [%[din1_ptr]], #32 \n" // v13={0,2,4,6} + // v12={1,3,5,7} + "ld2 {v14.4s, v15.4s}, [%[din2_ptr]], #32 \n" // v14={0,2,4,6} + // v15={1,3,5,7} + "and v4.16b, %[bias].16b, %[bias].16b \n" // v10 = vbias + + "bif v10.16b, v9.16b, v6.16b \n" + "bif v11.16b, v9.16b, v7.16b \n" + "bif v12.16b, v9.16b, v6.16b \n" + "bif v13.16b, v9.16b, v7.16b \n" + "bif v14.16b, v9.16b, v6.16b \n" + "bif v15.16b, v9.16b, v7.16b \n" + + "ext v6.16b, v10.16b, v9.16b, #4 \n" // v6 = + // {2,4,6,8} + "ext v7.16b, v12.16b, v9.16b, #4 \n" // v6 = + // {2,4,6,8} + "ext v8.16b, v14.16b, v9.16b, #4 \n" // v6 = + // {2,4,6,8} + + "fmla v4.4s, v10.4s, %[wr0].s[0] \n" // 0246 * w00 + "fmul v5.4s, v11.4s, %[wr0].s[1] \n" // 1357 * w01 + "fmul v16.4s, v6.4s, %[wr0].s[2] \n" // 2468 * w02 + + "fmla v4.4s, v12.4s, %[wr1].s[0] \n" // v12 * w11 + "fmla v5.4s, v13.4s, %[wr1].s[1] \n" // v13 * w12 + "fmla v16.4s, v7.4s, %[wr1].s[2] \n" // v7 * w10 + + "fmla v4.4s, v14.4s, %[wr2].s[0] \n" // v14 * w20 + "fmla v5.4s, v15.4s, %[wr2].s[1] \n" // v15 * w21 + "fmla v16.4s, v8.4s, %[wr2].s[2] \n" // v8 * w22 + + "fadd v4.4s, v4.4s, v5.4s \n" + "fadd v4.4s, v4.4s, v16.4s \n" + "fmax v4.4s, v4.4s, v9.4s \n" + + // "fadd v4.4s, v4.4s, %[bias].4s \n" + "st1 {v4.4s}, [%[out]] \n" + : [din0_ptr] "+r"(din0_ptr), + [din1_ptr] "+r"(din1_ptr), + [din2_ptr] "+r"(din2_ptr) + : [wr0] "w"(wr0), + [wr1] "w"(wr1), + [wr2] "w"(wr2), + [bias] "w"(vbias), + [out] "r"(out_buf), + [mask_ptr] "r"(mask_ptr) + : "cc", + "memory", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15", + "v16"); + +#else + asm volatile( + // Load up 12 elements (3 vectors) from each of 8 sources. 
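+ /* Editor's note: this narrow-width variant evaluates a whole output row in
+    a single masked pass: lanes whose source column is >= w_in are zeroed by
+    `vbif` against the masks built from right_pad_idx, all four lanes land in
+    out_buf, and only the first w_out values are copied out afterwards. A
+    scalar reference of the same computation; a sketch for clarity, not part
+    of the original patch (`rows` is a hypothetical alias for the three row
+    pointers din0_ptr/din1_ptr/din2_ptr):
+
+      for (int x = 0; x < w_out; ++x) {   // one stride-2, pad-0 output row
+        float acc = bias_c;
+        for (int k = 0; k < 3; ++k) {     // three kernel rows
+          for (int j = 0; j < 3; ++j) {   // three taps per row
+            int col = 2 * x + j;          // zero-pad past the right edge
+            acc += (col < w_in ? rows[k][col] : 0.f) * weight_ptr[3 * k + j];
+          }
+        }
+        out_buf[x] = fmaxf(acc, 0.f);     // fused relu
+      }
+ */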
+ "vmov.u32 q9, #0 \n" + "vld1.f32 {d12-d15}, [%[mask_ptr]] @ load mask\n" + "vdup.32 q3, %[bias] @ and \n" // q3 = + // vbias + + "vld2.32 {d20-d23}, [%[din0_ptr]]! @ load din r0\n" // q10={0,2,4,6} q11={1,3,5,7} + "vld2.32 {d24-d27}, [%[din1_ptr]]! @ load din r1\n" // q13={0,2,4,6} q12={1,3,5,7} + "vld2.32 {d28-d31}, [%[din2_ptr]]! @ load din r2\n" // q14={0,2,4,6} q15={1,3,5,7} + + "vbif q10, q9, q6 @ bit select, deal " + "with right pad\n" + "vbif q11, q9, q7 @ bit select, deal " + "with right pad\n" + "vbif q12, q9, q6 @ bit select, deal " + "with right pad\n" + "vbif q13, q9, q7 @ bit select, deal " + "with right pad\n" + "vbif q14, q9, q6 @ bit select, deal " + "with right pad\n" + "vbif q15, q9, q7 @ bit select, deal " + "with right pad\n" + + "vext.32 q6, q10, q9, #1 @ shift left 1 \n" // q6 = {2,4,6,0} + "vext.32 q7, q12, q9, #1 @ shift left 1 \n" // q7 = {2,4,6,0} + "vext.32 q8, q14, q9, #1 @ shift left 1 \n" // q8 = {2,4,6,0} + + "vmul.f32 q4, q10, %e[wr0][0] @ mul weight 0, " + "out0\n" // {0,2,4,6} + "vmul.f32 q5, q11, %e[wr0][1] @ mul weight 0, " + "out0\n" // {1,3,5,7} + "vmla.f32 q3, q6, %f[wr0][0] @ mul weight 0, " + "out0\n" // {2,4,6,0} + + "vmla.f32 q4, q12, %e[wr1][0] @ mul weight 1, " + "out0\n" // q12 * w11 + "vmla.f32 q5, q13, %e[wr1][1] @ mul weight 1, " + "out0\n" // q13 * w12 + "vmla.f32 q3, q7, %f[wr1][0] @ mul weight 1, " + "out0\n" // q7 * w10 + + "vmla.f32 q4, q14, %e[wr2][0] @ mul weight 2, " + "out0\n" // q14 * w20 + "vmla.f32 q5, q15, %e[wr2][1] @ mul weight 2, " + "out0\n" // q15 * w21 + "vmla.f32 q3, q8, %f[wr2][0] @ mul weight 2, " + "out0\n" // q8 * w22 + + "vadd.f32 q3, q3, q4 @ add \n" + "vadd.f32 q3, q3, q5 @ add \n" + + "vmax.f32 q3, q3, q9 @ relu \n" + + "vst1.32 {d6-d7}, [%[out]] \n" + : [din0_ptr] "+r"(din0_ptr), + [din1_ptr] "+r"(din1_ptr), + [din2_ptr] "+r"(din2_ptr) + : [wr0] "w"(wr0), + [wr1] "w"(wr1), + [wr2] "w"(wr2), + [bias] "r"(bias_c), + [out] "r"(out_buf), + [mask_ptr] "r"(mask_ptr) + : "cc", + "memory", + "q3", + "q4", + "q5", + "q6", + "q7", + "q8", + "q9", + "q10", + "q11", + "q12", + "q13", + "q14", + "q15"); +#endif // __aarch64__ + for (int w = 0; w < w_out; ++w) { + *dout_channel++ = out_buf[w]; + } + } + } + } +} + +} // namespace math +} // namespace arm +} // namespace lite +} // namespace paddle diff --git a/lite/arm/math/conv_depthwise_3x3p1.cc b/lite/arm/math/conv_depthwise_3x3p1.cc new file mode 100644 index 00000000000..86b2075bad7 --- /dev/null +++ b/lite/arm/math/conv_depthwise_3x3p1.cc @@ -0,0 +1,4850 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/arm/math/conv_depthwise.h" +#include + +namespace paddle { +namespace lite { +namespace arm { +namespace math { + +void conv_depthwise_3x3s1p1_bias(float* dout, + const float* din, + const float* weights, + const float* bias, + bool flag_bias, + const int num, + const int ch_in, + const int h_in, + const int w_in, + const int h_out, + const int w_out, + ARMContext* ctx); + +//! 
for input width <= 4 +void conv_depthwise_3x3s1p1_bias_s(float* dout, + const float* din, + const float* weights, + const float* bias, + bool flag_bias, + const int num, + const int ch_in, + const int h_in, + const int w_in, + const int h_out, + const int w_out, + ARMContext* ctx); + +void conv_depthwise_3x3s2p1_bias(float* dout, + const float* din, + const float* weights, + const float* bias, + bool flag_bias, + const int num, + const int ch_in, + const int h_in, + const int w_in, + const int h_out, + const int w_out, + ARMContext* ctx); + +//! for input width <= 4 +void conv_depthwise_3x3s2p1_bias_s(float* dout, + const float* din, + const float* weights, + const float* bias, + bool flag_bias, + const int num, + const int ch_in, + const int h_in, + const int w_in, + const int h_out, + const int w_out, + ARMContext* ctx); + +void conv_depthwise_3x3s1p1_bias_relu(float* dout, + const float* din, + const float* weights, + const float* bias, + bool flag_bias, + const int num, + const int ch_in, + const int h_in, + const int w_in, + const int h_out, + const int w_out, + ARMContext* ctx); + +//! for input width <= 4 +void conv_depthwise_3x3s1p1_bias_s_relu(float* dout, + const float* din, + const float* weights, + const float* bias, + bool flag_bias, + const int num, + const int ch_in, + const int h_in, + const int w_in, + const int h_out, + const int w_out, + ARMContext* ctx); + +void conv_depthwise_3x3s2p1_bias_relu(float* dout, + const float* din, + const float* weights, + const float* bias, + bool flag_bias, + const int num, + const int ch_in, + const int h_in, + const int w_in, + const int h_out, + const int w_out, + ARMContext* ctx); + +//! for input width <= 4 +void conv_depthwise_3x3s2p1_bias_s_relu(float* dout, + const float* din, + const float* weights, + const float* bias, + bool flag_bias, + const int num, + const int ch_in, + const int h_in, + const int w_in, + const int h_out, + const int w_out, + ARMContext* ctx); + +void conv_depthwise_3x3p1(const float* din, + float* dout, + int num, + int ch_out, + int h_out, + int w_out, + int ch_in, + int h_in, + int w_in, + const float* weights, + const float* bias, + int stride, + bool flag_bias, + bool flag_relu, + ARMContext* ctx) { + if (stride == 1) { + if (flag_relu) { + if (w_in > 4) { + conv_depthwise_3x3s1p1_bias_relu(dout, + din, + weights, + bias, + flag_bias, + num, + ch_in, + h_in, + w_in, + h_out, + w_out, + ctx); + } else { + conv_depthwise_3x3s1p1_bias_s_relu(dout, + din, + weights, + bias, + flag_bias, + num, + ch_in, + h_in, + w_in, + h_out, + w_out, + ctx); + } + } else { + if (w_in > 4) { + conv_depthwise_3x3s1p1_bias(dout, + din, + weights, + bias, + flag_bias, + num, + ch_in, + h_in, + w_in, + h_out, + w_out, + ctx); + } else { + conv_depthwise_3x3s1p1_bias_s(dout, + din, + weights, + bias, + flag_bias, + num, + ch_in, + h_in, + w_in, + h_out, + w_out, + ctx); + } + } + } else { //! 
stride = 2 + if (flag_relu) { + if (w_in > 7) { + conv_depthwise_3x3s2p1_bias_relu(dout, + din, + weights, + bias, + flag_bias, + num, + ch_in, + h_in, + w_in, + h_out, + w_out, + ctx); + } else { + conv_depthwise_3x3s2p1_bias_s_relu(dout, + din, + weights, + bias, + flag_bias, + num, + ch_in, + h_in, + w_in, + h_out, + w_out, + ctx); + } + } else { + if (w_in > 7) { + conv_depthwise_3x3s2p1_bias(dout, + din, + weights, + bias, + flag_bias, + num, + ch_in, + h_in, + w_in, + h_out, + w_out, + ctx); + } else { + conv_depthwise_3x3s2p1_bias_s(dout, + din, + weights, + bias, + flag_bias, + num, + ch_in, + h_in, + w_in, + h_out, + w_out, + ctx); + } + } + } +} +/** + * \brief depthwise convolution, kernel size 3x3, stride 1, pad 1, with bias, + * width > 4 + */ +// 4line +void conv_depthwise_3x3s1p1_bias(float* dout, + const float* din, + const float* weights, + const float* bias, + bool flag_bias, + const int num, + const int ch_in, + const int h_in, + const int w_in, + const int h_out, + const int w_out, + ARMContext* ctx) { + //! pad is done implicit + const float zero[8] = {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f}; + //! for 4x6 convolution window + const unsigned int right_pad_idx[8] = {5, 4, 3, 2, 1, 0, 0, 0}; + + float* zero_ptr = ctx->workspace_data<float>(); + memset(zero_ptr, 0, w_in * sizeof(float)); + float* write_ptr = zero_ptr + w_in; + + // printf("conv3x3_dw start \n"); + + int size_in_channel = w_in * h_in; + int size_out_channel = w_out * h_out; + int w_stride = 9; + + int tile_w = (w_in + 3) >> 2; + int cnt_col = tile_w - 2; + + unsigned int size_pad_right = (unsigned int)(1 + (tile_w << 2) - w_in); + + uint32x4_t vmask_rp1 = + vcgeq_u32(vld1q_u32(right_pad_idx), vdupq_n_u32(size_pad_right)); + uint32x4_t vmask_rp2 = + vcgeq_u32(vld1q_u32(right_pad_idx + 4), vdupq_n_u32(size_pad_right)); + uint32x4_t vmask_result = + vcgtq_u32(vld1q_u32(right_pad_idx), vdupq_n_u32(size_pad_right)); + + unsigned int vmask[8]; + vst1q_u32(vmask, vmask_rp1); + vst1q_u32(vmask + 4, vmask_rp2); + + unsigned int rmask[4]; + vst1q_u32(rmask, vmask_result); + + float32x4_t vzero = vdupq_n_f32(0.f); + + for (int n = 0; n < num; ++n) { + const float* din_batch = din + n * ch_in * size_in_channel; + float* dout_batch = dout + n * ch_in * size_out_channel; +#pragma omp parallel for +#ifdef __aarch64__ + for (int c = 0; c < ch_in; c++) { + float* dout_ptr = dout_batch + c * size_out_channel; + + const float* din_ch_ptr = din_batch + c * size_in_channel; + + float bias_val = flag_bias ? bias[c] : 0.f; + float vbias[4] = {bias_val, bias_val, bias_val, bias_val}; + + const float* wei_ptr = weights + c * w_stride; + + float32x4_t wr0 = vld1q_f32(wei_ptr); + float32x4_t wr1 = vld1q_f32(wei_ptr + 3); + float32x4_t wr2 = vld1q_f32(wei_ptr + 6); + + float* doutr0 = dout_ptr; + float* doutr1 = doutr0 + w_out; + float* doutr2 = doutr1 + w_out; + float* doutr3 = doutr2 + w_out; + + const float* dr0 = din_ch_ptr; + const float* dr1 = dr0 + w_in; + const float* dr2 = dr1 + w_in; + const float* dr3 = dr2 + w_in; + const float* dr4 = dr3 + w_in; + const float* dr5 = dr4 + w_in; + + const float* din_ptr0 = dr0; + const float* din_ptr1 = dr1; + const float* din_ptr2 = dr2; + const float* din_ptr3 = dr3; + const float* din_ptr4 = dr4; + const float* din_ptr5 = dr5; + + for (int i = 0; i < h_in; i += 4) { + //! 
process top pad pad_h = 1 + din_ptr0 = dr0; + din_ptr1 = dr1; + din_ptr2 = dr2; + din_ptr3 = dr3; + din_ptr4 = dr4; + din_ptr5 = dr5; + + doutr0 = dout_ptr; + doutr1 = doutr0 + w_out; + doutr2 = doutr1 + w_out; + doutr3 = doutr2 + w_out; + if (i == 0) { + din_ptr0 = zero_ptr; + din_ptr1 = dr0; + din_ptr2 = dr1; + din_ptr3 = dr2; + din_ptr4 = dr3; + din_ptr5 = dr4; + dr0 = dr3; + dr1 = dr4; + dr2 = dr5; + } else { + dr0 = dr4; + dr1 = dr5; + dr2 = dr1 + w_in; + } + dr3 = dr2 + w_in; + dr4 = dr3 + w_in; + dr5 = dr4 + w_in; + + //! process bottom pad + if (i + 5 > h_in) { + switch (i + 5 - h_in) { + case 5: + din_ptr1 = zero_ptr; + case 4: + din_ptr2 = zero_ptr; + case 3: + din_ptr3 = zero_ptr; + case 2: + din_ptr4 = zero_ptr; + case 1: + din_ptr5 = zero_ptr; + default: + break; + } + } + //! process bottom remain + if (i + 4 > h_out) { + switch (i + 4 - h_out) { + case 3: + doutr1 = write_ptr; + case 2: + doutr2 = write_ptr; + case 1: + doutr3 = write_ptr; + default: + break; + } + } + + int cnt = cnt_col; + asm volatile( + "PRFM PLDL1KEEP, [%[din_ptr0]] \n" + "PRFM PLDL1KEEP, [%[din_ptr1]] \n" + "PRFM PLDL1KEEP, [%[din_ptr2]] \n" + "PRFM PLDL1KEEP, [%[din_ptr3]] \n" + "PRFM PLDL1KEEP, [%[din_ptr4]] \n" + "PRFM PLDL1KEEP, [%[din_ptr5]] \n" + "movi v21.4s, #0x0\n" /* out0 = 0 */ + + "ld1 {v0.4s}, [%[din_ptr0]], #16 \n" /*vld1q_f32(din_ptr0)*/ + "ld1 {v2.4s}, [%[din_ptr1]], #16 \n" /*vld1q_f32(din_ptr0)*/ + "ld1 {v4.4s}, [%[din_ptr2]], #16 \n" /*vld1q_f32(din_ptr0)*/ + "ld1 {v6.4s}, [%[din_ptr3]], #16 \n" /*vld1q_f32(din_ptr0)*/ + + "ld1 {v1.4s}, [%[din_ptr0]] \n" /*vld1q_f32(din_ptr0)*/ + "ld1 {v3.4s}, [%[din_ptr1]] \n" /*vld1q_f32(din_ptr0)*/ + "ld1 {v5.4s}, [%[din_ptr2]] \n" /*vld1q_f32(din_ptr0)*/ + "ld1 {v7.4s}, [%[din_ptr3]] \n" /*vld1q_f32(din_ptr0)*/ + + "ld1 {v12.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ + "ld1 {v13.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ + "ld1 {v14.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ + "ld1 {v15.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ + + "ext v16.16b, %[vzero].16b, v0.16b, #12 \n" /* v16 = 00123*/ + "ext v17.16b, v0.16b, v1.16b, #4 \n" /* v16 = 1234 */ + + // left + // r0 + "fmla v12.4s, v0.4s, %[w0].s[1]\n" /* outr00 += din0_0123 * + w0[1]*/ + + "ld1 {v8.4s}, [%[din_ptr4]], #16 \n" /*vld1q_f32(din_ptr0)*/ + "ld1 {v10.4s}, [%[din_ptr5]], #16 \n" /*vld1q_f32(din_ptr0)*/ + "sub %[din_ptr0], %[din_ptr0], #4 \n" /* din_ptr0-- */ + "sub %[din_ptr1], %[din_ptr1], #4 \n" /* din_ptr0-- */ + + "fmla v12.4s , v16.4s, %[w0].s[0]\n" /* outr00 += din0_0012 * + w0[0]*/ + + "ld1 {v9.4s}, [%[din_ptr4]] \n" /*vld1q_f32(din_ptr0)*/ + "ld1 {v11.4s}, [%[din_ptr5]] \n" /*vld1q_f32(din_ptr0)*/ + "sub %[din_ptr2], %[din_ptr2], #4 \n" /* din_ptr0-- */ + "sub %[din_ptr3], %[din_ptr3], #4 \n" /* din_ptr0-- */ + + "fmla v12.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din0_1234 * + w0[2]*/ + + "ext v16.16b, %[vzero].16b, v2.16b, #12 \n" /* v16 = 00123*/ + "ext v17.16b, v2.16b, v3.16b, #4 \n" /* v16 = 1234 */ + + // r1 + "fmla v13.4s , v2.4s, %[w0].s[1]\n" /* outr00 += din1_0123 * + w0[1]*/ + "fmla v12.4s , v2.4s, %[w1].s[1]\n" /* outr00 += din1_0123 * + w1[1]*/ + "sub %[din_ptr4], %[din_ptr4], #4 \n" /* din_ptr0-- */ + "sub %[din_ptr5], %[din_ptr5], #4 \n" /* din_ptr0-- */ + + "fmla v13.4s , v16.4s, %[w0].s[0]\n" /* outr00 += din1_0123 * + w0[1]*/ + "fmla v12.4s , v16.4s, %[w1].s[0]\n" /* outr00 += din1_0123 * + w1[1]*/ + + "ld1 {v0.4s}, [%[din_ptr0]], #16 \n" /*vld1q_f32(din_ptr0)*/ + "ld1 {v2.4s}, [%[din_ptr1]], #16 \n" /*vld1q_f32(din_ptr0)*/ + + "fmla 
v13.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din1_0123 * + w0[1]*/ + "fmla v12.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din1_0123 * + w1[1]*/ + + "ext v16.16b, %[vzero].16b, v4.16b, #12 \n" /* v16 = 00123*/ + "ext v17.16b, v4.16b, v5.16b, #4 \n" /* v16 = 1234 */ + + // r2 + "fmla v14.4s , v4.4s, %[w0].s[1]\n" /* outr00 += din2_0123 * + w0[1]*/ + "fmla v13.4s , v4.4s, %[w1].s[1]\n" /* outr00 += din2_0123 * + w1[1]*/ + "fmla v12.4s , v4.4s, %[w2].s[1]\n" /* outr00 += din2_0123 * + w2[1]*/ + + "ld1 {v1.4s}, [%[din_ptr0]] \n" /*vld1q_f32(din_ptr0)*/ + "ld1 {v3.4s}, [%[din_ptr1]] \n" /*vld1q_f32(din_ptr0)*/ + + "fmla v14.4s , v16.4s, %[w0].s[0]\n" /* outr00 += din2_0123 * + w0[1]*/ + "fmla v13.4s , v16.4s, %[w1].s[0]\n" /* outr00 += din2_0123 * + w0[1]*/ + "fmla v12.4s , v16.4s, %[w2].s[0]\n" /* outr00 += din2_0123 * + w1[1]*/ + + "ld1 {v4.4s}, [%[din_ptr2]], #16 \n" /*vld1q_f32(din_ptr0)*/ + + "fmla v14.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din1_0123 * + w0[1]*/ + "fmla v13.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din1_0123 * + w0[1]*/ + "fmla v12.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din1_0123 * + w1[1]*/ + + "ext v16.16b, %[vzero].16b, v6.16b, #12 \n" /* v16 = 00123*/ + "ext v17.16b, v6.16b, v7.16b, #4 \n" /* v16 = 1234 */ + + // r3 + "fmla v15.4s , v6.4s, %[w0].s[1]\n" /*outr00 += din2_0123 * + w0[1]*/ + "fmla v14.4s , v6.4s, %[w1].s[1]\n" /* outr00 += din2_0123 * + w1[1]*/ + "fmla v13.4s , v6.4s, %[w2].s[1]\n" /* outr00 += din2_0123 * + w2[1]*/ + + "ld1 {v6.4s}, [%[din_ptr3]], #16 \n" /*vld1q_f32(din_ptr0)*/ + + "fmla v15.4s , v16.4s, %[w0].s[0]\n" /* outr00 += din2_0123 * + w0[1]*/ + "fmla v14.4s , v16.4s, %[w1].s[0]\n" /* outr00 += din2_0123 * + w0[1]*/ + "fmla v13.4s , v16.4s, %[w2].s[0]\n" /* outr00 += din2_0123 * + w1[1]*/ + + "ld1 {v5.4s}, [%[din_ptr2]] \n" /*vld1q_f32(din_ptr0)*/ + "ld1 {v7.4s}, [%[din_ptr3]] \n" /*vld1q_f32(din_ptr0)*/ + + "fmla v15.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din1_0123 * + w0[1]*/ + "fmla v14.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din1_0123 * + w0[1]*/ + "fmla v13.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din1_0123 * + w1[1]*/ + + "ext v16.16b, %[vzero].16b, v8.16b, #12 \n" /* v16 = 00123*/ + "ext v17.16b, v8.16b, v9.16b, #4 \n" /* v16 = 1234 */ + + // r4 + "fmla v15.4s , v8.4s, %[w1].s[1]\n" /* outr00 += din2_0123 * + w1[1]*/ + "fmla v14.4s , v8.4s, %[w2].s[1]\n" /* outr00 += din2_0123 * + w2[1]*/ + + "st1 {v12.4s}, [%[doutr0]], #16 \n" /* vst1q_f32() */ + "st1 {v13.4s}, [%[doutr1]], #16 \n" /* vst1q_f32() */ + "ld1 {v8.4s}, [%[din_ptr4]], #16 \n" /*vld1q_f32(din_ptr0)*/ + + "fmla v15.4s , v16.4s, %[w1].s[0]\n" /* outr00 += din2_0123 * + w0[1]*/ + "fmla v14.4s , v16.4s, %[w2].s[0]\n" /* outr00 += din2_0123 * + w1[1]*/ + + "ld1 {v9.4s}, [%[din_ptr4]] \n" /*vld1q_f32(din_ptr0)*/ + "ld1 {v12.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ + "ld1 {v13.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ + + "fmla v15.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din1_0123 * + w0[1]*/ + "fmla v14.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din1_0123 * + w1[1]*/ + + "ext v16.16b, %[vzero].16b, v10.16b, #12 \n" /* v16 = 00123*/ + "ext v17.16b, v10.16b, v11.16b, #4 \n" /* v16 = 1234 */ + + // r5 + "fmla v15.4s , v10.4s, %[w2].s[1]\n" /* outr00 += din2_0123 * + w1[1]*/ + + "st1 {v14.4s}, [%[doutr2]], #16 \n" /* vst1q_f32() */ + "ld1 {v10.4s}, [%[din_ptr5]], #16 \n" /*vld1q_f32(din_ptr0)*/ + + "fmla v15.4s , v16.4s, %[w2].s[0]\n" /* outr00 += din2_0123 * + w0[1]*/ + + "ld1 {v11.4s}, [%[din_ptr5]] \n" /*vld1q_f32(din_ptr0)*/ + "ld1 {v14.4s}, [%[bias_val]] \n" 
/*vdupq_n_f32(bias_val)*/ + + "fmla v15.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din1_0123 * + w0[1]*/ + + "ext v16.16b, v0.16b, v1.16b, #4 \n" /* v16 = 1234*/ + "ext v17.16b, v0.16b, v1.16b, #8 \n" /* v16 = 2345 */ + + "st1 {v15.4s}, [%[doutr3]], #16 \n" /* vst1q_f32() */ + "cmp %[cnt], #1 \n" + "ld1 {v15.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ + + "blt 3f \n" + // mid + "1: \n" + // r0 + "fmla v12.4s , v0.4s, %[w0].s[0]\n" /* outr00 += din0_0123 * + w0[0]*/ + + "ld1 {v0.4s}, [%[din_ptr0]], #16 \n" /*vld1q_f32(din_ptr0)*/ + + "fmla v12.4s , v16.4s, %[w0].s[1]\n" /* outr00 += din0_1234 * + w0[1]*/ + + "ld1 {v1.4s}, [%[din_ptr0]] \n" /*vld1q_f32(din_ptr0)*/ + + "fmla v12.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din0_2345 * + w0[2]*/ + + "ext v16.16b, v2.16b, v3.16b, #4 \n" /* v16 = 1234*/ + "ext v17.16b, v2.16b, v3.16b, #8 \n" /* v16 = 2345 */ + + // r1 + "fmla v13.4s , v2.4s, %[w0].s[0]\n" /* outr00 += din0_0123 * + w0[0]*/ + "fmla v12.4s , v2.4s, %[w1].s[0]\n" /* outr00 += din0_0123 * + w0[0]*/ + + "ld1 {v2.4s}, [%[din_ptr1]], #16 \n" /*vld1q_f32(din_ptr0)*/ + + "fmla v13.4s , v16.4s, %[w0].s[1]\n" /* outr00 += din0_1234 * + w0[1]*/ + "fmla v12.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 * + w0[1]*/ + + "ld1 {v3.4s}, [%[din_ptr1]] \n" /*vld1q_f32(din_ptr0)*/ + + "fmla v13.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din0_2345 * + w0[2]*/ + "fmla v12.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 * + w0[2]*/ + + "ext v16.16b, v4.16b, v5.16b, #4 \n" /* v16 = 1234*/ + "ext v17.16b, v4.16b, v5.16b, #8 \n" /* v16 = 2345 */ + + // r2 + "fmla v14.4s , v4.4s, %[w0].s[0]\n" /* outr00 += din0_0123 * + w0[0]*/ + "fmla v13.4s , v4.4s, %[w1].s[0]\n" /* outr00 += din0_0123 * + w0[0]*/ + "fmla v12.4s , v4.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * + w0[0]*/ + + "ld1 {v4.4s}, [%[din_ptr2]], #16 \n" /*vld1q_f32(din_ptr0)*/ + + "fmla v14.4s , v16.4s, %[w0].s[1]\n" /* outr00 += din0_1234 * + w0[1]*/ + "fmla v13.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 * + w0[1]*/ + "fmla v12.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * + w0[1]*/ + + "ld1 {v5.4s}, [%[din_ptr2]] \n" /*vld1q_f32(din_ptr0)*/ + + "fmla v14.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din0_2345 * + w0[2]*/ + "fmla v13.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 * + w0[2]*/ + "fmla v12.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * + w0[2]*/ + + "ext v16.16b, v6.16b, v7.16b, #4 \n" /* v16 = 1234*/ + "ext v17.16b, v6.16b, v7.16b, #8 \n" /* v16 = 2345 */ + + // r3 + "fmla v15.4s , v6.4s, %[w0].s[0]\n" /* outr00 += din0_0123 * + w0[0]*/ + "fmla v14.4s , v6.4s, %[w1].s[0]\n" /* outr00 += din0_0123 * + w0[0]*/ + "fmla v13.4s , v6.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * + w0[0]*/ + + "ld1 {v6.4s}, [%[din_ptr3]], #16 \n" /*vld1q_f32(din_ptr0)*/ + "st1 {v12.4s}, [%[doutr0]], #16 \n" + + "fmla v15.4s , v16.4s, %[w0].s[1]\n" /* outr00 += din0_1234 * + w0[1]*/ + "fmla v14.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 * + w0[1]*/ + "fmla v13.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * + w0[1]*/ + + "ld1 {v7.4s}, [%[din_ptr3]] \n" /*vld1q_f32(din_ptr0)*/ + "ld1 {v12.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ + + "fmla v15.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din0_2345 * + w0[2]*/ + "fmla v14.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 * + w0[2]*/ + "fmla v13.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * + w0[2]*/ + + "ext v16.16b, v8.16b, v9.16b, #4 \n" /* v16 = 1234*/ + "ext v17.16b, v8.16b, v9.16b, #8 \n" /* v16 = 2345 */ + + // r3 + "fmla v15.4s , v8.4s, %[w1].s[0]\n" /* outr00 += 
din0_0123 * + w0[0]*/ + "fmla v14.4s , v8.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * + w0[0]*/ + + "ld1 {v8.4s}, [%[din_ptr4]], #16 \n" /*vld1q_f32(din_ptr0)*/ + "st1 {v13.4s}, [%[doutr1]], #16 \n" + + "fmla v15.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 * + w0[1]*/ + "fmla v14.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * + w0[1]*/ + + "ld1 {v9.4s}, [%[din_ptr4]] \n" /*vld1q_f32(din_ptr0)*/ + "ld1 {v13.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ + + "fmla v15.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 * + w0[2]*/ + "fmla v14.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * + w0[2]*/ + + "ext v16.16b, v10.16b, v11.16b, #4 \n" /* v16 = 1234*/ + "ext v17.16b, v10.16b, v11.16b, #8 \n" /* v16 = 2345 */ + + // r3 + "fmla v15.4s , v10.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * + w0[0]*/ + + "ld1 {v10.4s}, [%[din_ptr5]], #16 \n" /*vld1q_f32(din_ptr0)*/ + "st1 {v14.4s}, [%[doutr2]], #16 \n" + + "fmla v15.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * + w0[1]*/ + + "ld1 {v11.4s}, [%[din_ptr5]] \n" /*vld1q_f32(din_ptr0)*/ + "ld1 {v14.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ + + "fmla v15.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * + w0[2]*/ + + "ext v16.16b, v0.16b, v1.16b, #4 \n" /* v16 = 1234*/ + "ext v17.16b, v0.16b, v1.16b, #8 \n" /* v16 = 2345 */ + + "subs %[cnt], %[cnt], #1 \n" + + "st1 {v15.4s}, [%[doutr3]], #16 \n" + "ld1 {v15.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ + + "bne 1b \n" + + // right + "3: \n" + "ld1 {v18.4s, v19.4s}, [%[vmask]] \n" + "ld1 {v22.4s}, [%[doutr0]] \n" + "ld1 {v23.4s}, [%[doutr1]] \n" + "ld1 {v24.4s}, [%[doutr2]] \n" + "ld1 {v25.4s}, [%[doutr3]] \n" + + "bif v0.16b, %[vzero].16b, v18.16b \n" + "bif v1.16b, %[vzero].16b, v19.16b \n" + "bif v2.16b, %[vzero].16b, v18.16b \n" + "bif v3.16b, %[vzero].16b, v19.16b \n" + + "bif v4.16b, %[vzero].16b, v18.16b \n" + "bif v5.16b, %[vzero].16b, v19.16b \n" + "bif v6.16b, %[vzero].16b, v18.16b \n" + "bif v7.16b, %[vzero].16b, v19.16b \n" + + "ext v16.16b, v0.16b, v1.16b, #4 \n" /* v16 = 1234*/ + "ext v17.16b, v0.16b, v1.16b, #8 \n" /* v16 = 2345 */ + + // r0 + "fmla v12.4s, v0.4s, %[w0].s[0]\n" /* outr00 += din0_0123 * + w0[0]*/ + + "bif v8.16b, %[vzero].16b, v18.16b \n" + "bif v9.16b, %[vzero].16b, v19.16b \n" + "bif v10.16b, %[vzero].16b, v18.16b \n" + "bif v11.16b, %[vzero].16b, v19.16b \n" + + "fmla v12.4s, v16.4s, %[w0].s[1]\n" /* outr00 += din0_1234 * + w0[1]*/ + + "ld1 {v18.4s}, [%[rmask]] \n" + + "fmla v12.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din0_2345 * + w0[2]*/ + + "ext v16.16b, v2.16b, v3.16b, #4 \n" /* v16 = 1234*/ + "ext v17.16b, v2.16b, v3.16b, #8 \n" /* v16 = 2345 */ + + // r1 + "fmla v13.4s , v2.4s, %[w0].s[0]\n" /* outr00 += din0_0123 * + w0[0]*/ + "fmla v12.4s , v2.4s, %[w1].s[0]\n" /* outr00 += din0_0123 * + w0[0]*/ + + "fmla v13.4s , v16.4s, %[w0].s[1]\n" /* outr00 += din0_1234 * + w0[1]*/ + "fmla v12.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 * + w0[1]*/ + + "fmla v13.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din0_2345 * + w0[2]*/ + "fmla v12.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 * + w0[2]*/ + + "ext v16.16b, v4.16b, v5.16b, #4 \n" /* v16 = 1234*/ + "ext v17.16b, v4.16b, v5.16b, #8 \n" /* v16 = 2345 */ + + // r2 + "fmla v14.4s , v4.4s, %[w0].s[0]\n" /* outr00 += din0_0123 * + w0[0]*/ + "fmla v13.4s , v4.4s, %[w1].s[0]\n" /* outr00 += din0_0123 * + w0[0]*/ + "fmla v12.4s , v4.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * + w0[0]*/ + + "fmla v14.4s , v16.4s, %[w0].s[1]\n" /* outr00 += din0_1234 * + w0[1]*/ + "fmla v13.4s , v16.4s, 
%[w1].s[1]\n" /* outr00 += din0_1234 * + w0[1]*/ + "fmla v12.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * + w0[1]*/ + + "fmla v14.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din0_2345 * + w0[2]*/ + "fmla v13.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 * + w0[2]*/ + "fmla v12.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * + w0[2]*/ + + "ext v16.16b, v6.16b, v7.16b, #4 \n" /* v16 = 1234*/ + "ext v17.16b, v6.16b, v7.16b, #8 \n" /* v16 = 2345 */ + + // r3 + "fmla v15.4s , v6.4s, %[w0].s[0]\n" /* outr00 += din0_0123 * + w0[0]*/ + "fmla v14.4s , v6.4s, %[w1].s[0]\n" /* outr00 += din0_0123 * + w0[0]*/ + "fmla v13.4s , v6.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * + w0[0]*/ + + "bif v12.16b, v22.16b, v18.16b \n" + + "fmla v15.4s , v16.4s, %[w0].s[1]\n" /* outr00 += din0_1234 * + w0[1]*/ + "fmla v14.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 * + w0[1]*/ + "fmla v13.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * + w0[1]*/ + + "st1 {v12.4s}, [%[doutr0]], #16 \n" + + "fmla v15.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din0_2345 * + w0[2]*/ + "fmla v14.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 * + w0[2]*/ + "fmla v13.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * + w0[2]*/ + + "ext v16.16b, v8.16b, v9.16b, #4 \n" /* v16 = 1234*/ + "ext v17.16b, v8.16b, v9.16b, #8 \n" /* v16 = 2345 */ + + // r3 + "fmla v15.4s , v8.4s, %[w1].s[0]\n" /* outr00 += din0_0123 * + w0[0]*/ + "fmla v14.4s , v8.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * + w0[0]*/ + + "bif v13.16b, v23.16b, v18.16b \n" + + "fmla v15.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 * + w0[1]*/ + "fmla v14.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * + w0[1]*/ + + "st1 {v13.4s}, [%[doutr1]], #16 \n" + + "fmla v15.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 * + w0[2]*/ + "fmla v14.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * + w0[2]*/ + + "ext v16.16b, v10.16b, v11.16b, #4 \n" /* v16 = 1234*/ + "ext v17.16b, v10.16b, v11.16b, #8 \n" /* v16 = 2345 */ + + // r3 + "fmla v15.4s , v10.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * + w0[0]*/ + + "bif v14.16b, v24.16b, v18.16b \n" + + "fmla v15.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * + w0[1]*/ + + "st1 {v14.4s}, [%[doutr2]], #16 \n" + + "fmla v15.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * + w0[2]*/ + + "bif v15.16b, v25.16b, v18.16b \n" + + "st1 {v15.4s}, [%[doutr3]], #16 \n" + : [cnt] "+r"(cnt), + [din_ptr0] "+r"(din_ptr0), + [din_ptr1] "+r"(din_ptr1), + [din_ptr2] "+r"(din_ptr2), + [din_ptr3] "+r"(din_ptr3), + [din_ptr4] "+r"(din_ptr4), + [din_ptr5] "+r"(din_ptr5), + [doutr0] "+r"(doutr0), + [doutr1] "+r"(doutr1), + [doutr2] "+r"(doutr2), + [doutr3] "+r"(doutr3) + : [w0] "w"(wr0), + [w1] "w"(wr1), + [w2] "w"(wr2), + [bias_val] "r"(vbias), + [vmask] "r"(vmask), + [rmask] "r"(rmask), + [vzero] "w"(vzero) + : "cc", + "memory", + "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15", + "v16", + "v17", + "v18", + "v19", + "v20", + "v21", + "v22", + "v23", + "v24", + "v25"); + dout_ptr = dout_ptr + 4 * w_out; + } + } +#else + for (int i = 0; i < ch_in; ++i) { + const float* din_channel = din_batch + i * size_in_channel; + + const float* weight_ptr = weights + i * 9; + float32x4_t wr0 = vld1q_f32(weight_ptr); + float32x4_t wr1 = vld1q_f32(weight_ptr + 3); + float32x4_t wr2 = vld1q_f32(weight_ptr + 6); + float bias_val = flag_bias ? 
bias[i] : 0.f; + + float* dout_channel = dout_batch + i * size_out_channel; + + const float* dr0 = din_channel; + const float* dr1 = dr0 + w_in; + const float* dr2 = dr1 + w_in; + const float* dr3 = dr2 + w_in; + + const float* din0_ptr = nullptr; + const float* din1_ptr = nullptr; + const float* din2_ptr = nullptr; + const float* din3_ptr = nullptr; + + float* doutr0 = nullptr; + float* doutr1 = nullptr; + + float* ptr_zero = const_cast(zero); + + for (int i = 0; i < h_in; i += 2) { + //! process top pad pad_h = 1 + din0_ptr = dr0; + din1_ptr = dr1; + din2_ptr = dr2; + din3_ptr = dr3; + + doutr0 = dout_channel; + doutr1 = dout_channel + w_out; + // unsigned int* rst_mask = rmask; + + if (i == 0) { + din0_ptr = zero_ptr; + din1_ptr = dr0; + din2_ptr = dr1; + din3_ptr = dr2; + dr0 = dr1; + dr1 = dr2; + dr2 = dr3; + dr3 = dr2 + w_in; + } else { + dr0 = dr2; + dr1 = dr3; + dr2 = dr1 + w_in; + dr3 = dr2 + w_in; + } + //! process bottom pad + if (i + 3 > h_in) { + switch (i + 3 - h_in) { + case 3: + din1_ptr = zero_ptr; + case 2: + din2_ptr = zero_ptr; + case 1: + din3_ptr = zero_ptr; + default: + break; + } + } + //! process bottom remain + if (i + 2 > h_out) { + doutr1 = write_ptr; + } + int cnt = cnt_col; + unsigned int* rmask_ptr = rmask; + unsigned int* vmask_ptr = vmask; + asm volatile( + "pld [%[din0_ptr]] @ preload data\n" + "pld [%[din1_ptr]] @ preload data\n" + "pld [%[din2_ptr]] @ preload data\n" + "pld [%[din3_ptr]] @ preload data\n" + + "vld1.32 {d16-d18}, [%[din0_ptr]]! @ load din r0\n" + "vld1.32 {d20-d22}, [%[din1_ptr]]! @ load din r1\n" + "vld1.32 {d24-d26}, [%[din2_ptr]]! @ load din r2\n" + "vld1.32 {d28-d30}, [%[din3_ptr]]! @ load din r3\n" + + "vdup.32 q4, %[bias_val] @ and \n" // q4 + // = + // vbias + "vdup.32 q5, %[bias_val] @ and \n" // q5 + // = + // vbias + + "vext.32 q6, %q[vzero], q8, #3 @ 0012\n" + "vext.32 q7, q8, q9, #1 @ 1234\n" + + // left + // r0 + "vmla.f32 q4, q8, %e[wr0][1] @ q4 += 1234 * wr0[1]\n" + + "sub %[din0_ptr], #12 @ 1pad + 2 float data overlap\n" + "sub %[din1_ptr], #12 @ 1pad + 2 float data overlap\n" + "sub %[din2_ptr], #12 @ 1pad + 2 float data overlap\n" + "sub %[din3_ptr], #12 @ 1pad + 2 float data overlap\n" + + "vmla.f32 q4, q6, %e[wr0][0] @ q4 += 1234 * wr0[0]\n" + + "pld [%[din0_ptr]] @ preload data\n" + "pld [%[din1_ptr]] @ preload data\n" + "pld [%[din2_ptr]] @ preload data\n" + "pld [%[din3_ptr]] @ preload data\n" + + "vmla.f32 q4, q7, %f[wr0][0] @ q4 += 1234 * wr0[2]\n" + + "vext.32 q6, %q[vzero], q10, #3 @ 0012\n" + "vext.32 q7, q10, q11, #1 @ 1234\n" + + // r1 + "vmla.f32 q5, q10, %e[wr0][1] @ q4 += 1234 * wr0[1]\n" + "vmla.f32 q4, q10, %e[wr1][1] @ q4 += 1234 * wr0[1]\n" + + "vld1.32 {d16-d17}, [%[din0_ptr]]! @ load din r0\n" + "vld1.32 {d20-d21}, [%[din1_ptr]]! @ load din r0\n" + + "vmla.f32 q5, q6, %e[wr0][0] @ q4 += 1234 * wr0[0]\n" + "vmla.f32 q4, q6, %e[wr1][0] @ q4 += 1234 * wr0[0]\n" + + "vld1.32 {d18}, [%[din0_ptr]] @ load din r0\n" + "vld1.32 {d22}, [%[din1_ptr]] @ load din r0\n" + + "vmla.f32 q5, q7, %f[wr0][0] @ q4 += 1234 * wr0[2]\n" + "vmla.f32 q4, q7, %f[wr1][0] @ q4 += 1234 * wr0[2]\n" + + "vext.32 q6, %q[vzero], q12, #3 @ 0012\n" + "vext.32 q7, q12, q13, #1 @ 1234\n" + + // r2 + "vmla.f32 q5, q12, %e[wr1][1] @ q4 += 1234 * wr0[1]\n" + "vmla.f32 q4, q12, %e[wr2][1] @ q4 += 1234 * wr0[1]\n" + + "vld1.32 {d24-d25}, [%[din2_ptr]]! 
@ load din r0\n" + + "vmla.f32 q5, q6, %e[wr1][0] @ q4 += 1234 * wr0[0]\n" + "vmla.f32 q4, q6, %e[wr2][0] @ q4 += 1234 * wr0[0]\n" + + "vld1.32 {d26}, [%[din2_ptr]] @ load din r0\n" + + "vmla.f32 q5, q7, %f[wr1][0] @ q4 += 1234 * wr0[2]\n" + "vmla.f32 q4, q7, %f[wr2][0] @ q4 += 1234 * wr0[2]\n" + + "vext.32 q6, %q[vzero], q14, #3 @ 0012\n" + "vext.32 q7, q14, q15, #1 @ 1234\n" + + // r3 + "vmla.f32 q5, q14, %e[wr2][1] @ q4 += 1234 * wr0[1]\n" + + "vld1.32 {d28-d29}, [%[din3_ptr]]! @ load din r0\n" + "vst1.32 {d8-d9}, [%[dout_ptr1]]! @ store result, add pointer\n" + + "vmla.f32 q5, q6, %e[wr2][0] @ q4 += 1234 * wr0[0]\n" + + "vld1.32 {d30}, [%[din3_ptr]] @ load din r0\n" + "vdup.32 q4, %[bias_val] @ and \n" // q4 + // = + // vbias + + "vmla.f32 q5, q7, %f[wr2][0] @ q4 += 1234 * wr0[2]\n" + + "vext.32 q6, q8, q9, #1 @ 1234\n" + "vext.32 q7, q8, q9, #2 @ 2345\n" + "cmp %[cnt], #1 @ check whether has " + "mid cols\n" + + "vst1.32 {d10-d11}, [%[dout_ptr2]]! @ store result, add " + "pointer\n" + + "vdup.32 q5, %[bias_val] @ and \n" // q5 + // = + // vbias + "blt 3f @ jump to main loop start " + "point\n" + + // mid + "1: @ right pad entry\n" + // r0 + "vmla.f32 q4, q8, %e[wr0][0] @ q4 += 0123 * wr0[0]\n" + + "pld [%[din0_ptr]] @ preload data\n" + "pld [%[din1_ptr]] @ preload data\n" + "pld [%[din2_ptr]] @ preload data\n" + "pld [%[din3_ptr]] @ preload data\n" + + "vmla.f32 q4, q6, %e[wr0][1] @ q4 += 1234 * wr0[1]\n" + + "vld1.32 {d16-d17}, [%[din0_ptr]]! @ load din r0\n" + + "vmla.f32 q4, q7, %f[wr0][0] @ q4 += 2345 * wr0[2]\n" + + "vld1.32 {d18}, [%[din0_ptr]] @ load din r0\n" + + "vext.32 q6, q10, q11, #1 @ 1234\n" + "vext.32 q7, q10, q11, #2 @ 2345\n" + + // r1 + "vmla.f32 q5, q10, %e[wr0][0] @ q4 += 1234 * wr0[0]\n" + "vmla.f32 q4, q10, %e[wr1][0] @ q4 += 1234 * wr0[1]\n" + + "vld1.32 {d20-d21}, [%[din1_ptr]]! @ load din r0\n" + + "vmla.f32 q5, q6, %e[wr0][1] @ q4 += 1234 * wr0[1]\n" + "vmla.f32 q4, q6, %e[wr1][1] @ q4 += 1234 * wr0[1]\n" + + "vld1.32 {d22}, [%[din1_ptr]] @ load din r0\n" + + "vmla.f32 q5, q7, %f[wr0][0] @ q4 += 1234 * wr0[1]\n" + "vmla.f32 q4, q7, %f[wr1][0] @ q4 += 1234 * wr0[1]\n" + + "vext.32 q6, q12, q13, #1 @ 1234\n" + "vext.32 q7, q12, q13, #2 @ 2345\n" + + // r2 + "vmla.f32 q5, q12, %e[wr1][0] @ q4 += 1234 * wr0[0]\n" + "vmla.f32 q4, q12, %e[wr2][0] @ q4 += 1234 * wr0[1]\n" + + "vld1.32 {d24-d25}, [%[din2_ptr]]! @ load din r0\n" + + "vmla.f32 q5, q6, %e[wr1][1] @ q4 += 1234 * wr0[1]\n" + "vmla.f32 q4, q6, %e[wr2][1] @ q4 += 1234 * wr0[1]\n" + + "vld1.32 {d26}, [%[din2_ptr]] @ load din r0\n" + + "vmla.f32 q5, q7, %f[wr1][0] @ q4 += 1234 * wr0[1]\n" + "vmla.f32 q4, q7, %f[wr2][0] @ q4 += 1234 * wr0[1]\n" + + "vext.32 q6, q14, q15, #1 @ 1234\n" + "vext.32 q7, q14, q15, #2 @ 2345\n" + + // r3 + "vmla.f32 q5, q14, %e[wr2][0] @ q4 += 0123 * wr0[0]\n" + + "vld1.32 {d28-d29}, [%[din3_ptr]]! @ load din r0\n" + "vst1.32 {d8-d9}, [%[dout_ptr1]]! @ store result, add pointer\n" + + "vmla.f32 q5, q6, %e[wr2][1] @ q4 += 1234 * wr0[1]\n" + + "vld1.32 {d30}, [%[din3_ptr]] @ load din r0\n" + "vdup.32 q4, %[bias_val] @ and \n" // q4 + // = + // vbias + + "vmla.f32 q5, q7, %f[wr2][0] @ q4 += 2345 * wr0[2]\n" + + "vext.32 q6, q8, q9, #1 @ 1234\n" + "vext.32 q7, q8, q9, #2 @ 2345\n" + + "vst1.32 {d10-d11}, [%[dout_ptr2]]! @ store result, add " + "pointer\n" + + "subs %[cnt], #1 @ loop count minus 1\n" + + "vdup.32 q5, %[bias_val] @ and \n" // q4 + // = + // vbias + + "bne 1b @ jump to main loop start " + "point\n" + + // right + "3: @ right pad entry\n" + "vld1.32 {d19}, [%[vmask]]! 
@ load din r0\n" + "vld1.32 {d23}, [%[vmask]]! @ load din r0\n" + + "vld1.32 {d27}, [%[vmask]]! @ load din r0\n" + "vld1.32 {d31}, [%[vmask]]! @ load din r0\n" + + "vbif d16, %e[vzero], d19 @ bit select, deal with " + "right pad\n" + "vbif d17, %e[vzero], d23 @ bit select, deal with " + "right pad\n" + "vbif d18, %e[vzero], d27 @ bit select, deal with " + "right pad\n" + + "vbif d20, %e[vzero], d19 @ bit select, deal with " + "right pad\n" + "vbif d21, %e[vzero], d23 @ bit select, deal with " + "right pad\n" + "vbif d22, %e[vzero], d27 @ bit select, deal with " + "right pad\n" + + "vext.32 q6, q8, q9, #1 @ 1234\n" + "vext.32 q7, q8, q9, #2 @ 2345\n" + + // r0 + "vmla.f32 q4, q8, %e[wr0][0] @ q4 += 0123 * wr0[0]\n" + + "vbif d24, %e[vzero], d19 @ bit select, deal with " + "right pad\n" + "vbif d25, %e[vzero], d23 @ bit select, deal with " + "right pad\n" + "vbif d26, %e[vzero], d27 @ bit select, deal with " + "right pad\n" + + "vmla.f32 q4, q6, %e[wr0][1] @ q4 += 1234 * wr0[1]\n" + + "vbif d28, %e[vzero], d19 @ bit select, deal with " + "right pad\n" + "vbif d29, %e[vzero], d23 @ bit select, deal with " + "right pad\n" + "vbif d30, %e[vzero], d27 @ bit select, deal with " + "right pad\n" + + "vmla.f32 q4, q7, %f[wr0][0] @ q4 += 2345 * wr0[2]\n" + + "vext.32 q6, q10, q11, #1 @ 1234\n" + "vext.32 q7, q10, q11, #2 @ 2345\n" + + // r1 + "vmla.f32 q5, q10, %e[wr0][0] @ q4 += 1234 * wr0[0]\n" + "vmla.f32 q4, q10, %e[wr1][0] @ q4 += 1234 * wr0[1]\n" + + "vld1.32 {d19}, [%[rmask]]! @ load din r0\n" + "vld1.32 {d23}, [%[rmask]]! @ load din r0\n" + + "vmla.f32 q5, q6, %e[wr0][1] @ q4 += 1234 * wr0[1]\n" + "vmla.f32 q4, q6, %e[wr1][1] @ q4 += 1234 * wr0[1]\n" + + "vld1.32 {d16-d17}, [%[dout_ptr1]] @ load din r0\n" + "vld1.32 {d20-d21}, [%[dout_ptr2]] @ load din r0\n" + + "vmla.f32 q5, q7, %f[wr0][0] @ q4 += 1234 * wr0[1]\n" + "vmla.f32 q4, q7, %f[wr1][0] @ q4 += 1234 * wr0[1]\n" + + "vext.32 q6, q12, q13, #1 @ 1234\n" + "vext.32 q7, q12, q13, #2 @ 2345\n" + + // r2 + "vmla.f32 q5, q12, %e[wr1][0] @ q4 += 1234 * wr0[0]\n" + "vmla.f32 q4, q12, %e[wr2][0] @ q4 += 1234 * wr0[1]\n" + + "vmla.f32 q5, q6, %e[wr1][1] @ q4 += 1234 * wr0[1]\n" + "vmla.f32 q4, q6, %e[wr2][1] @ q4 += 1234 * wr0[1]\n" + + "vmla.f32 q5, q7, %f[wr1][0] @ q4 += 1234 * wr0[1]\n" + "vmla.f32 q4, q7, %f[wr2][0] @ q4 += 1234 * wr0[1]\n" + + "vext.32 q6, q14, q15, #1 @ 1234\n" + "vext.32 q7, q14, q15, #2 @ 2345\n" + + // r3 + "vmla.f32 q5, q14, %e[wr2][0] @ q4 += 0123 * wr0[0]\n" + + "vbif d8, d16, d19 @ bit select, deal with right pad\n" + "vbif d9, d17, d23 @ bit select, deal with right pad\n" + + "vmla.f32 q5, q6, %e[wr2][1] @ q4 += 1234 * wr0[1]\n" + + "vst1.32 {d8-d9}, [%[dout_ptr1]]! @ store result, add pointer\n" + + "vmla.f32 q5, q7, %f[wr2][0] @ q4 += 2345 * wr0[2]\n" + + "vbif d10, d20, d19 @ bit select, deal with right " + "pad\n" + "vbif d11, d21, d23 @ bit select, deal with right " + "pad\n" + + "vst1.32 {d10-d11}, [%[dout_ptr2]]! @ store result, add " + "pointer\n" + + : [dout_ptr1] "+r"(doutr0), + [dout_ptr2] "+r"(doutr1), + [din0_ptr] "+r"(din0_ptr), + [din1_ptr] "+r"(din1_ptr), + [din2_ptr] "+r"(din2_ptr), + [din3_ptr] "+r"(din3_ptr), + [cnt] "+r"(cnt), + [rmask] "+r"(rmask_ptr), + [vmask] "+r"(vmask_ptr) + : [wr0] "w"(wr0), + [wr1] "w"(wr1), + [wr2] "w"(wr2), + [bias_val] "r"(bias_val), + [vzero] "w"(vzero) + : "cc", + "memory", + "q4", + "q5", + "q6", + "q7", + "q8", + "q9", + "q10", + "q11", + "q12", + "q13", + "q14", + "q15"); + dout_channel += 2 * w_out; + } //! 
end of processing mid rows
+  }
+#endif
+  }
+}
+
+/**
+ * \brief depthwise convolution kernel 3x3, stride 2
+ */
+// w_in > 7
+void conv_depthwise_3x3s2p1_bias(float* dout,
+                                 const float* din,
+                                 const float* weights,
+                                 const float* bias,
+                                 bool flag_bias,
+                                 const int num,
+                                 const int ch_in,
+                                 const int h_in,
+                                 const int w_in,
+                                 const int h_out,
+                                 const int w_out,
+                                 ARMContext* ctx) {
+  int right_pad_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7};
+  int out_pad_idx[4] = {0, 1, 2, 3};
+  int size_pad_bottom = h_out * 2 - h_in;
+
+  int cnt_col = (w_out >> 2) - 2;
+  int size_right_remain = w_in - (7 + cnt_col * 8);
+  if (size_right_remain >= 9) {
+    cnt_col++;
+    size_right_remain -= 8;
+  }
+  int cnt_remain = (size_right_remain == 8) ? 4 : (w_out % 4);
+
+  int size_right_pad = w_out * 2 - w_in;
+
+  uint32x4_t vmask_rp1 = vcgtq_s32(vdupq_n_s32(size_right_remain),
+                                   vld1q_s32(right_pad_idx));  // 0 2 4 6
+  uint32x4_t vmask_rp2 = vcgtq_s32(vdupq_n_s32(size_right_remain),
+                                   vld1q_s32(right_pad_idx + 4));  // 1 3 5 7
+  uint32x4_t wmask =
+      vcgtq_s32(vdupq_n_s32(cnt_remain), vld1q_s32(out_pad_idx));  // 0 1 2 3
+  int size_in_channel = w_in * h_in;
+  int size_out_channel = w_out * h_out;
+
+  float* zero_ptr = ctx->workspace_data<float>();
+  memset(zero_ptr, 0, w_in * sizeof(float));
+  float* write_ptr = zero_ptr + w_in;
+
+  unsigned int dmask[12];
+
+  vst1q_u32(dmask, vmask_rp1);
+  vst1q_u32(dmask + 4, vmask_rp2);
+  vst1q_u32(dmask + 8, wmask);
+
+  for (int n = 0; n < num; ++n) {
+    const float* din_batch = din + n * ch_in * size_in_channel;
+    float* dout_batch = dout + n * ch_in * size_out_channel;
+#pragma omp parallel for
+    for (int i = 0; i < ch_in; ++i) {
+      const float* din_channel = din_batch + i * size_in_channel;
+      float* dout_channel = dout_batch + i * size_out_channel;
+
+      const float* weight_ptr = weights + i * 9;
+      float32x4_t wr0 = vld1q_f32(weight_ptr);
+      float32x4_t wr1 = vld1q_f32(weight_ptr + 3);
+      float32x4_t wr2 = vld1q_f32(weight_ptr + 6);
+
+      float32x4_t vzero = vdupq_n_f32(0.f);
+
+      float32x4_t wbias;
+      float bias_c = 0.f;
+      if (flag_bias) {
+        wbias = vdupq_n_f32(bias[i]);
+        bias_c = bias[i];
+      } else {
+        wbias = vdupq_n_f32(0.f);
+      }
+
+      const float* dr0 = din_channel;
+      const float* dr1 = dr0 + w_in;
+      const float* dr2 = dr1 + w_in;
+      const float* dr3 = dr2 + w_in;
+      const float* dr4 = dr3 + w_in;
+
+      const float* din0_ptr = dr0;
+      const float* din1_ptr = dr1;
+      const float* din2_ptr = dr2;
+      const float* din3_ptr = dr3;
+      const float* din4_ptr = dr4;
+
+      float* doutr0 = dout_channel;
+      float* doutr0_ptr = nullptr;
+      float* doutr1_ptr = nullptr;
+
+#ifdef __aarch64__
+      for (int i = 0; i < h_in; i += 4) {
+        din0_ptr = dr0;
+        din1_ptr = dr1;
+        din2_ptr = dr2;
+        din3_ptr = dr3;
+        din4_ptr = dr4;
+
+        doutr0_ptr = doutr0;
+        doutr1_ptr = doutr0 + w_out;
+
+        if (i == 0) {
+          din0_ptr = zero_ptr;
+          din1_ptr = dr0;
+          din2_ptr = dr1;
+          din3_ptr = dr2;
+          din4_ptr = dr3;
+          dr0 = dr3;
+          dr1 = dr4;
+        } else {
+          dr0 = dr4;
+          dr1 = dr0 + w_in;
+        }
+        dr2 = dr1 + w_in;
+        dr3 = dr2 + w_in;
+        dr4 = dr3 + w_in;
+
+        //! process bottom pad
+        if (i + 4 > h_in) {
+          switch (i + 4 - h_in) {  // intentional fall-through: every row
+            case 4:                // past the bottom edge reads the zero line
+              din1_ptr = zero_ptr;
+            case 3:
+              din2_ptr = zero_ptr;
+            case 2:
+              din3_ptr = zero_ptr;
+            case 1:
+              din4_ptr = zero_ptr;
+            default:
+              break;
+          }
+        }
+        //! process output pad
+        if (i / 2 + 2 > h_out) {
+          doutr1_ptr = write_ptr;
+        }
+        int cnt = cnt_col;
+        asm volatile(
+            // top
+            // Load up 12 elements (3 vectors) from each of 8 sources.
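+            // Scalar sketch of the stride-2 scheme (assuming pad = 1):
+            // ld2 deinterleaves a row into even lanes {0,2,4,6} and odd
+            // lanes {1,3,5,7}; ext against vzero yields {pad,1,3,5}, so the
+            // three taps of one output vector in this left-edge block are
+            //   out[x] = bias + w[r][0] * in[r][2x - 1]
+            //                 + w[r][1] * in[r][2x]
+            //                 + w[r][2] * in[r][2x + 1]
+            // summed over the rows r the window covers (in[r][-1] is the
+            // pad). Two output rows stay live in v16/v17 per pass.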
+ "0: \n" + "prfm pldl1keep, [%[inptr0]] \n" + "prfm pldl1keep, [%[inptr1]] \n" + "prfm pldl1keep, [%[inptr2]] \n" + "prfm pldl1keep, [%[inptr3]] \n" + "prfm pldl1keep, [%[inptr4]] \n" + "ld2 {v0.4s, v1.4s}, [%[inptr0]], #32 \n" // v0={0,2,4,6} + // v1={1,3,5,7} + "ld2 {v2.4s, v3.4s}, [%[inptr1]], #32 \n" + "ld2 {v4.4s, v5.4s}, [%[inptr2]], #32 \n" + "ld2 {v6.4s, v7.4s}, [%[inptr3]], #32 \n" + "ld2 {v8.4s, v9.4s}, [%[inptr4]], #32 \n" + + "and v16.16b, %[vbias].16b, %[vbias].16b \n" // v10 = vbias + "and v17.16b, %[vbias].16b, %[vbias].16b \n" // v16 = vbias + + "ext v10.16b, %[vzero].16b, v1.16b, #12 \n" // v10 = {0,1,3,5} + + // r0 + "fmul v11.4s, v0.4s, %[w0].s[1] \n" // {0,2,4,6} * w01 + "fmul v12.4s, v1.4s, %[w0].s[2] \n" // {1,3,5,7} * w02 + "fmla v16.4s, v10.4s, %[w0].s[0] \n" // {0,1,3,5} * w00 + + "ext v10.16b, %[vzero].16b, v3.16b, #12 \n" // v10 = {0,1,3,5} + + "sub %[inptr0], %[inptr0], #4 \n" + "sub %[inptr1], %[inptr1], #4 \n" + + // r1 + "fmla v11.4s, v2.4s, %[w1].s[1] \n" // {0,2,4,6} * w01 + "fmla v12.4s, v3.4s, %[w1].s[2] \n" // {1,3,5,7} * w02 + "fmla v16.4s, v10.4s, %[w1].s[0] \n" // {0,1,3,5} * w00 + + "ext v10.16b, %[vzero].16b, v5.16b, #12 \n" // v10 = {0,1,3,5} + + "sub %[inptr2], %[inptr2], #4 \n" + "sub %[inptr3], %[inptr3], #4 \n" + + // r2 + "fmul v13.4s, v4.4s, %[w0].s[1] \n" // {0,2,4,6} * w01 + "fmla v11.4s, v4.4s, %[w2].s[1] \n" // {0,2,4,6} * w01 + + "fmul v14.4s, v5.4s, %[w0].s[2] \n" // {1,3,5,7} * w02 + "fmla v12.4s, v5.4s, %[w2].s[2] \n" // {1,3,5,7} * w02 + + "fmla v17.4s, v10.4s, %[w0].s[0] \n" // {0,1,3,5} * w00 + "fmla v16.4s, v10.4s, %[w2].s[0] \n" // {0,1,3,5} * w00 + + "ext v10.16b, %[vzero].16b, v7.16b, #12 \n" // v10 = {0,1,3,5} + + "sub %[inptr4], %[inptr4], #4 \n" + + // r3 + "fmla v13.4s, v6.4s, %[w1].s[1] \n" // {0,2,4,6} * w01 + "fmla v14.4s, v7.4s, %[w1].s[2] \n" // {1,3,5,7} * w02 + "fmla v17.4s, v10.4s, %[w1].s[0] \n" // {0,1,3,5} * w00 + + "ext v10.16b, %[vzero].16b, v9.16b, #12 \n" // v10 = {0,1,3,5} + "fadd v16.4s, v16.4s, v11.4s \n" + "fadd v16.4s, v16.4s, v12.4s \n" + + // r4 + "fmla v13.4s, v8.4s, %[w2].s[1] \n" // {0,2,4,6} * w01 + "fmla v14.4s, v9.4s, %[w2].s[2] \n" // {1,3,5,7} * w02 + "fmla v17.4s, v10.4s, %[w2].s[0] \n" // {0,1,3,5} * w00 + + "st1 {v16.4s}, [%[outptr0]], #16 \n" + + "ld2 {v0.4s, v1.4s}, [%[inptr0]], #32 \n" // v0={0,2,4,6} + // v1={1,3,5,7} + "ld2 {v2.4s, v3.4s}, [%[inptr1]], #32 \n" + "ld2 {v4.4s, v5.4s}, [%[inptr2]], #32 \n" + + "fadd v17.4s, v17.4s, v13.4s \n" + + "ld2 {v6.4s, v7.4s}, [%[inptr3]], #32 \n" + "ld2 {v8.4s, v9.4s}, [%[inptr4]], #32 \n" + "ld1 {v15.4s}, [%[inptr0]] \n" + "and v16.16b, %[vbias].16b, %[vbias].16b \n" // v10 = vbias + + "fadd v17.4s, v17.4s, v14.4s \n" + + "ld1 {v18.4s}, [%[inptr1]] \n" + "ld1 {v19.4s}, [%[inptr2]] \n" + + "ext v10.16b, v0.16b, v15.16b, #4 \n" // v10 = {2,4,6,8} + + "ld1 {v20.4s}, [%[inptr3]] \n" + "ld1 {v21.4s}, [%[inptr4]] \n" + + "st1 {v17.4s}, [%[outptr1]], #16 \n" + + "cmp %[cnt], #1 \n" + + "and v17.16b, %[vbias].16b, %[vbias].16b \n" // v16 = vbias + + "blt 1f \n" + // mid + "2: \n" + // r0 + "fmul v11.4s, v0.4s, %[w0].s[0] \n" // {0,2,4,6} * w00 + "fmul v12.4s, v1.4s, %[w0].s[1] \n" // {1,3,5,7} * w01 + "fmla v16.4s, v10.4s, %[w0].s[2] \n" // {2,4,6,8} * w02 + + "ext v10.16b, v2.16b, v18.16b, #4 \n" // v10 = {2,4,6,8} + "ld2 {v0.4s, v1.4s}, [%[inptr0]], #32 \n" // v0={0,2,4,6} + // v1={1,3,5,7} + + // r1 + "fmla v11.4s, v2.4s, %[w1].s[0] \n" // {0,2,4,6} * w00 + "fmla v12.4s, v3.4s, %[w1].s[1] \n" // {1,3,5,7} * w01 + "fmla v16.4s, v10.4s, %[w1].s[2] \n" 
// {2,4,6,8} * w02 + + "ext v10.16b, v4.16b, v19.16b, #4 \n" // v10 = {2,4,6,8} + + "ld2 {v2.4s, v3.4s}, [%[inptr1]], #32 \n" + + // r2 + "fmul v13.4s, v4.4s, %[w0].s[0] \n" // {0,2,4,6} * w00 + "fmla v11.4s, v4.4s, %[w2].s[0] \n" // {0,2,4,6} * w00 + + "fmul v14.4s, v5.4s, %[w0].s[1] \n" // {1,3,5,7} * w01 + "fmla v12.4s, v5.4s, %[w2].s[1] \n" // {1,3,5,7} * w01 + + "fmla v17.4s, v10.4s, %[w0].s[2] \n" // {2,4,6,8} * w02 + "fmla v16.4s, v10.4s, %[w2].s[2] \n" // {2,4,6,8} * w02 + + "ext v10.16b, v6.16b, v20.16b, #4 \n" // v10 = {2,4,6,8} + + "ld2 {v4.4s, v5.4s}, [%[inptr2]], #32 \n" + + // r3 + "fmla v13.4s, v6.4s, %[w1].s[0] \n" // {0,2,4,6} * w00 + "fmla v14.4s, v7.4s, %[w1].s[1] \n" // {1,3,5,7} * w01 + "fmla v17.4s, v10.4s, %[w1].s[2] \n" // {2,4,6,8} * w02 + + "ext v10.16b, v8.16b, v21.16b, #4 \n" // v10 = {2,4,6,8} + + "ld2 {v6.4s, v7.4s}, [%[inptr3]], #32 \n" + + "fadd v16.4s, v16.4s, v11.4s \n" + "fadd v16.4s, v16.4s, v12.4s \n" + + // r4 + "fmla v13.4s, v8.4s, %[w2].s[0] \n" // {0,2,4,6} * w00 + "fmla v14.4s, v9.4s, %[w2].s[1] \n" // {1,3,5,7} * w01 + "fmla v17.4s, v10.4s, %[w2].s[2] \n" // {2,4,6,8} * w02 + + "ld2 {v8.4s, v9.4s}, [%[inptr4]], #32 \n" + "ld1 {v15.4s}, [%[inptr0]] \n" + "ld1 {v18.4s}, [%[inptr1]] \n" + "st1 {v16.4s}, [%[outptr0]], #16 \n" + + "fadd v17.4s, v17.4s, v13.4s \n" + + "ld1 {v19.4s}, [%[inptr2]] \n" + "ld1 {v20.4s}, [%[inptr3]] \n" + "ld1 {v21.4s}, [%[inptr4]] \n" + + "fadd v17.4s, v17.4s, v14.4s \n" + + "ext v10.16b, v0.16b, v15.16b, #4 \n" // v10 = {2,4,6,8} + "and v16.16b, %[vbias].16b, %[vbias].16b \n" // v10 = vbias + "subs %[cnt], %[cnt], #1 \n" + + "st1 {v17.4s}, [%[outptr1]], #16 \n" + + "and v17.16b, %[vbias].16b, %[vbias].16b \n" // v16 = vbias + + "bne 2b \n" + + // right + "1: \n" + "cmp %[remain], #1 \n" + "blt 4f \n" + "3: \n" + "bif v0.16b, %[vzero].16b, %[mask1].16b \n" // pipei + "bif v1.16b, %[vzero].16b, %[mask2].16b \n" // pipei + + "bif v2.16b, %[vzero].16b, %[mask1].16b \n" // pipei + "bif v3.16b, %[vzero].16b, %[mask2].16b \n" // pipei + + "bif v4.16b, %[vzero].16b, %[mask1].16b \n" // pipei + "bif v5.16b, %[vzero].16b, %[mask2].16b \n" // pipei + + "ext v10.16b, v0.16b, %[vzero].16b, #4 \n" // v10 = {2,4,6,8} + + "bif v6.16b, %[vzero].16b, %[mask1].16b \n" // pipei + "bif v7.16b, %[vzero].16b, %[mask2].16b \n" // pipei + + // r0 + "fmul v11.4s, v0.4s, %[w0].s[0] \n" // {0,2,4,6} * w00 + "fmul v12.4s, v1.4s, %[w0].s[1] \n" // {1,3,5,7} * w01 + "fmla v16.4s, v10.4s, %[w0].s[2] \n" // {2,4,6,8} * w02 + + "ext v10.16b, v2.16b, %[vzero].16b, #4 \n" // v10 = {2,4,6,8} + "bif v8.16b, %[vzero].16b, %[mask1].16b \n" // pipei + "bif v9.16b, %[vzero].16b, %[mask2].16b \n" // pipei + + // r1 + "fmla v11.4s, v2.4s, %[w1].s[0] \n" // {0,2,4,6} * w00 + "fmla v12.4s, v3.4s, %[w1].s[1] \n" // {1,3,5,7} * w01 + "fmla v16.4s, v10.4s, %[w1].s[2] \n" // {2,4,6,8} * w02 + + "ext v10.16b, v4.16b, %[vzero].16b, #4 \n" // v10 = {2,4,6,8} + + // r2 + "fmul v13.4s, v4.4s, %[w0].s[0] \n" // {0,2,4,6} * w00 + "fmla v11.4s, v4.4s, %[w2].s[0] \n" // {0,2,4,6} * w00 + + "fmul v14.4s, v5.4s, %[w0].s[1] \n" // {1,3,5,7} * w01 + "fmla v12.4s, v5.4s, %[w2].s[1] \n" // {1,3,5,7} * w01 + + "fmla v17.4s, v10.4s, %[w0].s[2] \n" // {2,4,6,8} * w02 + "fmla v16.4s, v10.4s, %[w2].s[2] \n" // {2,4,6,8} * w02 + + "ext v10.16b, v6.16b, %[vzero].16b, #4 \n" // v10 = {2,4,6,8} + + // r3 + "fmla v13.4s, v6.4s, %[w1].s[0] \n" // {0,2,4,6} * w00 + "fmla v14.4s, v7.4s, %[w1].s[1] \n" // {1,3,5,7} * w01 + "fmla v17.4s, v10.4s, %[w1].s[2] \n" // {2,4,6,8} * w02 + + "ext v10.16b, 
v8.16b, %[vzero].16b, #4 \n" // v10 = {2,4,6,8} + "ld1 {v0.4s}, [%[outptr0]] \n" + + "fadd v16.4s, v16.4s, v11.4s \n" + "fadd v16.4s, v16.4s, v12.4s \n" + "ld1 {v1.4s}, [%[outptr1]] \n" + + // r4 + "fmla v13.4s, v8.4s, %[w2].s[0] \n" // {0,2,4,6} * w00 + "fmla v14.4s, v9.4s, %[w2].s[1] \n" // {1,3,5,7} * w01 + "fmla v17.4s, v10.4s, %[w2].s[2] \n" // {2,4,6,8} * w02 + + "bif v16.16b, v0.16b, %[wmask].16b \n" // pipei + + "fadd v17.4s, v17.4s, v13.4s \n" + + "st1 {v16.4s}, [%[outptr0]], #16 \n" + + "fadd v17.4s, v17.4s, v14.4s \n" + + "bif v17.16b, v1.16b, %[wmask].16b \n" // pipei + + "st1 {v17.4s}, [%[outptr1]], #16 \n" + "4: \n" + : [inptr0] "+r"(din0_ptr), + [inptr1] "+r"(din1_ptr), + [inptr2] "+r"(din2_ptr), + [inptr3] "+r"(din3_ptr), + [inptr4] "+r"(din4_ptr), + [outptr0] "+r"(doutr0_ptr), + [outptr1] "+r"(doutr1_ptr), + [cnt] "+r"(cnt) + : [vzero] "w"(vzero), + [w0] "w"(wr0), + [w1] "w"(wr1), + [w2] "w"(wr2), + [remain] "r"(cnt_remain), + [mask1] "w"(vmask_rp1), + [mask2] "w"(vmask_rp2), + [wmask] "w"(wmask), + [vbias] "w"(wbias) + : "cc", + "memory", + "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15", + "v16", + "v17", + "v18", + "v19", + "v20", + "v21"); + doutr0 = doutr0 + 2 * w_out; + } +#else + for (int i = 0; i < h_in; i += 2) { + din0_ptr = dr0; + din1_ptr = dr1; + din2_ptr = dr2; + + doutr0_ptr = doutr0; + + if (i == 0) { + din0_ptr = zero_ptr; + din1_ptr = dr0; + din2_ptr = dr1; + dr0 = dr1; + dr1 = dr2; + dr2 = dr1 + w_in; + } else { + dr0 = dr2; + dr1 = dr0 + w_in; + dr2 = dr1 + w_in; + } + + //! process bottom pad + if (i + 2 > h_in) { + switch (i + 2 - h_in) { + case 2: + din1_ptr = zero_ptr; + case 1: + din2_ptr = zero_ptr; + default: + break; + } + } + int cnt = cnt_col; + unsigned int* mask_ptr = dmask; + asm volatile( + // top + // Load up 12 elements (3 vectors) from each of 8 sources. + "0: \n" + "vmov.u32 q9, #0 \n" + "vld2.32 {d20-d23}, [%[din0_ptr]]! @ load din r1\n" // v11={0,2,4,6} v12={1,3,5,7}, q10, q11 + "vld2.32 {d24-d27}, [%[din1_ptr]]! @ load din r1\n" // v11={0,2,4,6} v12={1,3,5,7}, q12, q13 + "vld2.32 {d28-d31}, [%[din2_ptr]]! @ load din r1\n" // v13={0,2,4,6} v14={1,3,5,7}, q14, q15 + "pld [%[din0_ptr]] @ preload data\n" + "pld [%[din1_ptr]] @ preload data\n" + "pld [%[din2_ptr]] @ preload data\n" + + "vdup.32 q3, %[bias] @ and \n" // q10 = + // vbias + + "vext.32 q6, q9, q11, #3 @ shift right 1 " + "data\n" // q2 = {0,1,3,5} + "vext.32 q7, q9, q13, #3 @ shift right 1 " + "data\n" // q6 = {0,1,3,5} + "vext.32 q8, q9, q15, #3 @ shift right 1 " + "data\n" // q6 = {0,1,3,5} + + "vmul.f32 q4, q10, %e[wr0][1] @ mul weight 1, " + "out0\n" // q11 * w01 + "vmul.f32 q5, q11, %f[wr0][0] @ mul weight 1, " + "out0\n" // q12 * w02 + "vmla.f32 q3, q6, %e[wr0][0] @ mul weight 1, " + "out0\n" // q6 * w00 + + "sub %[din0_ptr], #4 @ inpitr0 - 1\n" + "sub %[din1_ptr], #4 @ inpitr1 - 1\n" + "sub %[din2_ptr], #4 @ inpitr2 - 1\n" + + "vld2.32 {d20-d23}, [%[din0_ptr]]! @ load din r0\n" // v0={0,2,4,6} v1={1,3,5,7} + + "vmla.f32 q4, q12, %e[wr1][1] @ mul weight 1, " + "out0\n" // q11 * w01 + "vmla.f32 q5, q13, %f[wr1][0] @ mul weight 1, " + "out0\n" // q12 * w02 + "vmla.f32 q3, q7, %e[wr1][0] @ mul weight 1, " + "out0\n" // q6 * w00 + + "vld2.32 {d24-d27}, [%[din1_ptr]]! 
@ load din r1\n" // v4={0,2,4,6} v5={1,3,5,7} + + "vmla.f32 q4, q14, %e[wr2][1] @ mul weight 1, " + "out1\n" // q0 * w01 + "vmla.f32 q5, q15, %f[wr2][0] @ mul weight 1, " + "out1\n" // q1 * w02 + "vmla.f32 q3, q8, %e[wr2][0] @ mul weight 1, " + "out1\n" // q2 * w00 + + "vld2.32 {d28-d31}, [%[din2_ptr]]! @ load din r1\n" // v4={0,2,4,6} v5={1,3,5,7} + + "vadd.f32 q3, q3, q4 @ add \n" + "vadd.f32 q3, q3, q5 @ add \n" + + "vst1.32 {d6-d7}, [%[outptr]]! \n" + "cmp %[cnt], #1 \n" + "blt 1f \n" + // mid + "2: \n" + "vld1.32 {d16}, [%[din0_ptr]] @ load din r0\n" // q2={8,10,12,14} + "vdup.32 q3, %[bias] @ and \n" // q10 = + // vbias + "vext.32 q6, q10, q8, #1 @ shift left 1 \n" // q6 = {2,4,6,8} + "vld1.32 {d16}, [%[din1_ptr]] @ load din r1\n" // q2={8,10,12,14} + + "vmul.f32 q4, q10, %e[wr0][0] @ mul weight 0, " + "out0\n" // q0 * w00 + "vmul.f32 q5, q11, %e[wr0][1] @ mul weight 0, " + "out0\n" // q1 * w01 + "vmla.f32 q3, q6, %f[wr0][0] @ mul weight 0, " + "out0\n" // q6 * w02 + + "vext.32 q7, q12, q8, #1 @ shift left 1 \n" // q6 = {2,4,6,8} + "vld1.32 {d16}, [%[din2_ptr]] @ load din r1\n" // q2={8,10,12,14} + + "vld2.32 {d20-d23}, [%[din0_ptr]]! @ load din r0\n" // v0={0,2,4,6} v1={1,3,5,7} + + "vmla.f32 q4, q12, %e[wr1][0] @ mul weight 1, " + "out0\n" // q0 * w00 + "vmla.f32 q5, q13, %e[wr1][1] @ mul weight 1, " + "out0\n" // q1 * w01 + "vmla.f32 q3, q7, %f[wr1][0] @ mul weight 1, " + "out0\n" // q6 * w02 + + "vext.32 q6, q14, q8, #1 @ shift left 1 \n" // q6 = {2,4,6,8} + + "vld2.32 {d24-d27}, [%[din1_ptr]]! @ load din r1\n" // v0={0,2,4,6} v1={1,3,5,7} + + "vmla.f32 q4, q14, %e[wr2][0] @ mul weight 2, " + "out0\n" // q0 * w00 + "vmla.f32 q5, q15, %e[wr2][1] @ mul weight 2, " + "out0\n" // q1 * w01 + "vmla.f32 q3, q6, %f[wr2][0] @ mul weight 2, " + "out0\n" // q6 * w02 + + "vld2.32 {d28-d31}, [%[din2_ptr]]! @ load din r2\n" // v4={0,2,4,6} v5={1,3,5,7} + + "vadd.f32 q3, q3, q4 @ add \n" + "vadd.f32 q3, q3, q5 @ add \n" + + "subs %[cnt], #1 \n" + + "vst1.32 {d6-d7}, [%[outptr]]! \n" + "bne 2b \n" + + // right + "1: \n" + "cmp %[remain], #1 \n" + "blt 3f \n" + + "vld1.f32 {d12-d15}, [%[mask_ptr]]! 
@ load mask\n" + "vdup.32 q3, %[bias] @ and \n" // q10 = + // vbias + + "vbif q10, q9, q6 @ bit select, deal " + "with right pad\n" + "vbif q11, q9, q7 @ bit select, deal " + "with right pad\n" + "vbif q12, q9, q6 @ bit select, deal " + "with right pad\n" + "vbif q13, q9, q7 @ bit select, deal " + "with right pad\n" + "vbif q14, q9, q6 @ bit select, deal " + "with right pad\n" + "vbif q15, q9, q7 @ bit select, deal " + "with right pad\n" + + "vext.32 q6, q10, q9, #1 @ shift left 1 \n" // q6 = {2,4,6,8} + "vext.32 q7, q12, q9, #1 @ shift left 1 \n" // q6 = {2,4,6,8} + + "vmul.f32 q4, q10, %e[wr0][0] @ mul weight 0, " + "out0\n" // q0 * w00 + "vmul.f32 q5, q11, %e[wr0][1] @ mul weight 0, " + "out0\n" // q1 * w01 + "vmla.f32 q3, q6, %f[wr0][0] @ mul weight 0, " + "out0\n" // q6 * w02 + + "vext.32 q6, q14, q9, #1 @ shift left 1 \n" // q6 = {2,4,6,8} + "vld1.f32 {d20-d21}, [%[outptr]] @ load output\n" + + "vmla.f32 q4, q12, %e[wr1][0] @ mul weight 1, " + "out0\n" // q0 * w00 + "vmla.f32 q5, q13, %e[wr1][1] @ mul weight 1, " + "out0\n" // q1 * w01 + "vmla.f32 q3, q7, %f[wr1][0] @ mul weight 1, " + "out0\n" // q6 * w02 + + "vld1.f32 {d22-d23}, [%[mask_ptr]] @ load mask\n" + + "vmla.f32 q4, q14, %e[wr2][0] @ mul weight 2, " + "out0\n" // q0 * w00 + "vmla.f32 q5, q15, %e[wr2][1] @ mul weight 2, " + "out0\n" // q1 * w01 + "vmla.f32 q3, q6, %f[wr2][0] @ mul weight 2, " + "out0\n" // q6 * w02 + + "vadd.f32 q3, q3, q4 @ add \n" + "vadd.f32 q3, q3, q5 @ add \n" + + "vbif.f32 q3, q10, q11 @ write mask\n" + + "vst1.32 {d6-d7}, [%[outptr]]! \n" + "3: \n" + : [din0_ptr] "+r"(din0_ptr), + [din1_ptr] "+r"(din1_ptr), + [din2_ptr] "+r"(din2_ptr), + [outptr] "+r"(doutr0_ptr), + [cnt] "+r"(cnt), + [mask_ptr] "+r"(mask_ptr) + : [remain] "r"(cnt_remain), + [wr0] "w"(wr0), + [wr1] "w"(wr1), + [wr2] "w"(wr2), + [bias] "r"(bias_c) + : "cc", + "memory", + "q3", + "q4", + "q5", + "q6", + "q7", + "q8", + "q9", + "q10", + "q11", + "q12", + "q13", + "q14", + "q15"); + + doutr0 = doutr0 + w_out; + } +#endif + } + } +} + +// 4line +void conv_depthwise_3x3s1p1_bias_relu(float* dout, + const float* din, + const float* weights, + const float* bias, + bool flag_bias, + const int num, + const int ch_in, + const int h_in, + const int w_in, + const int h_out, + const int w_out, + ARMContext* ctx) { + //! pad is done implicit + const float zero[8] = {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f}; + //! 
for 4x6 convolution window
+  const unsigned int right_pad_idx[8] = {5, 4, 3, 2, 1, 0, 0, 0};
+
+  // printf("conv3x3_dw start \n");
+
+  int size_in_channel = w_in * h_in;
+  int size_out_channel = w_out * h_out;
+  int w_stride = 9;
+
+  int tile_w = (w_in + 3) >> 2;
+  int tile_h = (h_in + 3) >> 2;
+  int cnt_col = tile_w - 2;
+  float* zero_ptr = ctx->workspace_data<float>();
+  memset(zero_ptr, 0, w_in * sizeof(float));
+  float* write_ptr = zero_ptr + w_in;
+
+  unsigned int size_pad_right = (unsigned int)(1 + (tile_w << 2) - w_in);
+  int size_pad_bottom = (unsigned int)(1 + (tile_h << 2) - h_in);
+
+  uint32x4_t vmask_rp1 =
+      vcgeq_u32(vld1q_u32(right_pad_idx), vdupq_n_u32(size_pad_right));
+  uint32x4_t vmask_rp2 =
+      vcgeq_u32(vld1q_u32(right_pad_idx + 4), vdupq_n_u32(size_pad_right));
+  uint32x4_t vmask_result =
+      vcgtq_u32(vld1q_u32(right_pad_idx), vdupq_n_u32(size_pad_right));
+
+  unsigned int vmask[8];
+  vst1q_u32(vmask, vmask_rp1);
+  vst1q_u32(vmask + 4, vmask_rp2);
+
+  unsigned int rmask[4];
+  vst1q_u32(rmask, vmask_result);
+
+  float32x4_t vzero = vdupq_n_f32(0.f);
+
+  for (int n = 0; n < num; ++n) {
+    const float* din_batch = din + n * ch_in * size_in_channel;
+    float* dout_batch = dout + n * ch_in * size_out_channel;
+#pragma omp parallel for
+#ifdef __aarch64__
+    for (int c = 0; c < ch_in; c++) {
+      float* dout_ptr = dout_batch + c * size_out_channel;
+
+      const float* din_ch_ptr = din_batch + c * size_in_channel;
+
+      float bias_val = flag_bias ? bias[c] : 0.f;
+      float vbias[4] = {bias_val, bias_val, bias_val, bias_val};
+
+      const float* wei_ptr = weights + c * w_stride;
+
+      float32x4_t wr0 = vld1q_f32(wei_ptr);
+      float32x4_t wr1 = vld1q_f32(wei_ptr + 3);
+      float32x4_t wr2 = vld1q_f32(wei_ptr + 6);
+
+      float* doutr0 = dout_ptr;
+      float* doutr1 = doutr0 + w_out;
+      float* doutr2 = doutr1 + w_out;
+      float* doutr3 = doutr2 + w_out;
+
+      const float* dr0 = din_ch_ptr;
+      const float* dr1 = dr0 + w_in;
+      const float* dr2 = dr1 + w_in;
+      const float* dr3 = dr2 + w_in;
+      const float* dr4 = dr3 + w_in;
+      const float* dr5 = dr4 + w_in;
+
+      const float* din_ptr0 = dr0;
+      const float* din_ptr1 = dr1;
+      const float* din_ptr2 = dr2;
+      const float* din_ptr3 = dr3;
+      const float* din_ptr4 = dr4;
+      const float* din_ptr5 = dr5;
+
+      for (int i = 0; i < h_in; i += 4) {
+        //! process top pad pad_h = 1
+        din_ptr0 = dr0;
+        din_ptr1 = dr1;
+        din_ptr2 = dr2;
+        din_ptr3 = dr3;
+        din_ptr4 = dr4;
+        din_ptr5 = dr5;
+
+        doutr0 = dout_ptr;
+        doutr1 = doutr0 + w_out;
+        doutr2 = doutr1 + w_out;
+        doutr3 = doutr2 + w_out;
+        if (i == 0) {
+          din_ptr0 = zero_ptr;
+          din_ptr1 = dr0;
+          din_ptr2 = dr1;
+          din_ptr3 = dr2;
+          din_ptr4 = dr3;
+          din_ptr5 = dr4;
+          dr0 = dr3;
+          dr1 = dr4;
+          dr2 = dr5;
+        } else {
+          dr0 = dr4;
+          dr1 = dr5;
+          dr2 = dr1 + w_in;
+        }
+        dr3 = dr2 + w_in;
+        dr4 = dr3 + w_in;
+        dr5 = dr4 + w_in;
+
+        //! process bottom pad
+        if (i + 5 > h_in) {
+          switch (i + 5 - h_in) {  // intentional fall-through
+            case 5:
+              din_ptr1 = zero_ptr;
+            case 4:
+              din_ptr2 = zero_ptr;
+            case 3:
+              din_ptr3 = zero_ptr;
+            case 2:
+              din_ptr4 = zero_ptr;
+            case 1:
+              din_ptr5 = zero_ptr;
+            default:
+              break;
+          }
+        }
+        //!
process bottom remain + if (i + 4 > h_out) { + switch (i + 4 - h_out) { + case 3: + doutr1 = write_ptr; + case 2: + doutr2 = write_ptr; + case 1: + doutr3 = write_ptr; + default: + break; + } + } + + int cnt = cnt_col; + asm volatile( + "PRFM PLDL1KEEP, [%[din_ptr0]] \n" + "PRFM PLDL1KEEP, [%[din_ptr1]] \n" + "PRFM PLDL1KEEP, [%[din_ptr2]] \n" + "PRFM PLDL1KEEP, [%[din_ptr3]] \n" + "PRFM PLDL1KEEP, [%[din_ptr4]] \n" + "PRFM PLDL1KEEP, [%[din_ptr5]] \n" + "movi v21.4s, #0x0\n" /* out0 = 0 */ + + "ld1 {v0.4s}, [%[din_ptr0]], #16 \n" /*vld1q_f32(din_ptr0)*/ + "ld1 {v2.4s}, [%[din_ptr1]], #16 \n" /*vld1q_f32(din_ptr0)*/ + "ld1 {v4.4s}, [%[din_ptr2]], #16 \n" /*vld1q_f32(din_ptr0)*/ + "ld1 {v6.4s}, [%[din_ptr3]], #16 \n" /*vld1q_f32(din_ptr0)*/ + + "ld1 {v1.4s}, [%[din_ptr0]] \n" /*vld1q_f32(din_ptr0)*/ + "ld1 {v3.4s}, [%[din_ptr1]] \n" /*vld1q_f32(din_ptr0)*/ + "ld1 {v5.4s}, [%[din_ptr2]] \n" /*vld1q_f32(din_ptr0)*/ + "ld1 {v7.4s}, [%[din_ptr3]] \n" /*vld1q_f32(din_ptr0)*/ + + "ld1 {v12.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ + "ld1 {v13.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ + "ld1 {v14.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ + "ld1 {v15.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ + + "ext v16.16b, %[vzero].16b, v0.16b, #12 \n" /* v16 = 00123*/ + "ext v17.16b, v0.16b, v1.16b, #4 \n" /* v16 = 1234 */ + + // left + // r0 + "fmla v12.4s, v0.4s, %[w0].s[1]\n" /* outr00 += din0_0123 * + w0[1]*/ + + "ld1 {v8.4s}, [%[din_ptr4]], #16 \n" /*vld1q_f32(din_ptr0)*/ + "ld1 {v10.4s}, [%[din_ptr5]], #16 \n" /*vld1q_f32(din_ptr0)*/ + "sub %[din_ptr0], %[din_ptr0], #4 \n" /* din_ptr0-- */ + "sub %[din_ptr1], %[din_ptr1], #4 \n" /* din_ptr0-- */ + + "fmla v12.4s , v16.4s, %[w0].s[0]\n" /* outr00 += din0_0012 * + w0[0]*/ + + "ld1 {v9.4s}, [%[din_ptr4]] \n" /*vld1q_f32(din_ptr0)*/ + "ld1 {v11.4s}, [%[din_ptr5]] \n" /*vld1q_f32(din_ptr0)*/ + "sub %[din_ptr2], %[din_ptr2], #4 \n" /* din_ptr0-- */ + "sub %[din_ptr3], %[din_ptr3], #4 \n" /* din_ptr0-- */ + + "fmla v12.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din0_1234 * + w0[2]*/ + + "ext v16.16b, %[vzero].16b, v2.16b, #12 \n" /* v16 = 00123*/ + "ext v17.16b, v2.16b, v3.16b, #4 \n" /* v16 = 1234 */ + + // r1 + "fmla v13.4s , v2.4s, %[w0].s[1]\n" /* outr00 += din1_0123 * + w0[1]*/ + "fmla v12.4s , v2.4s, %[w1].s[1]\n" /* outr00 += din1_0123 * + w1[1]*/ + "sub %[din_ptr4], %[din_ptr4], #4 \n" /* din_ptr0-- */ + "sub %[din_ptr5], %[din_ptr5], #4 \n" /* din_ptr0-- */ + + "fmla v13.4s , v16.4s, %[w0].s[0]\n" /* outr00 += din1_0123 * + w0[1]*/ + "fmla v12.4s , v16.4s, %[w1].s[0]\n" /* outr00 += din1_0123 * + w1[1]*/ + + "ld1 {v0.4s}, [%[din_ptr0]], #16 \n" /*vld1q_f32(din_ptr0)*/ + "ld1 {v2.4s}, [%[din_ptr1]], #16 \n" /*vld1q_f32(din_ptr0)*/ + + "fmla v13.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din1_0123 * + w0[1]*/ + "fmla v12.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din1_0123 * + w1[1]*/ + + "ext v16.16b, %[vzero].16b, v4.16b, #12 \n" /* v16 = 00123*/ + "ext v17.16b, v4.16b, v5.16b, #4 \n" /* v16 = 1234 */ + + // r2 + "fmla v14.4s , v4.4s, %[w0].s[1]\n" /* outr00 += din2_0123 * + w0[1]*/ + "fmla v13.4s , v4.4s, %[w1].s[1]\n" /* outr00 += din2_0123 * + w1[1]*/ + "fmla v12.4s , v4.4s, %[w2].s[1]\n" /* outr00 += din2_0123 * + w2[1]*/ + + "ld1 {v1.4s}, [%[din_ptr0]] \n" /*vld1q_f32(din_ptr0)*/ + "ld1 {v3.4s}, [%[din_ptr1]] \n" /*vld1q_f32(din_ptr0)*/ + + "fmla v14.4s , v16.4s, %[w0].s[0]\n" /* outr00 += din2_0123 * + w0[1]*/ + "fmla v13.4s , v16.4s, %[w1].s[0]\n" /* outr00 += din2_0123 * + w0[1]*/ + "fmla v12.4s , v16.4s, %[w2].s[0]\n" /* 
outr00 += din2_0123 * + w1[1]*/ + + "ld1 {v4.4s}, [%[din_ptr2]], #16 \n" /*vld1q_f32(din_ptr0)*/ + + "fmla v14.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din1_0123 * + w0[1]*/ + "fmla v13.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din1_0123 * + w0[1]*/ + "fmla v12.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din1_0123 * + w1[1]*/ + + "ext v16.16b, %[vzero].16b, v6.16b, #12 \n" /* v16 = 00123*/ + "ext v17.16b, v6.16b, v7.16b, #4 \n" /* v16 = 1234 */ + + // r3 + "fmla v15.4s , v6.4s, %[w0].s[1]\n" /*outr00 += din2_0123 * + w0[1]*/ + "fmla v14.4s , v6.4s, %[w1].s[1]\n" /* outr00 += din2_0123 * + w1[1]*/ + "fmla v13.4s , v6.4s, %[w2].s[1]\n" /* outr00 += din2_0123 * + w2[1]*/ + + "ld1 {v6.4s}, [%[din_ptr3]], #16 \n" /*vld1q_f32(din_ptr0)*/ + + "fmla v15.4s , v16.4s, %[w0].s[0]\n" /* outr00 += din2_0123 * + w0[1]*/ + "fmla v14.4s , v16.4s, %[w1].s[0]\n" /* outr00 += din2_0123 * + w0[1]*/ + "fmla v13.4s , v16.4s, %[w2].s[0]\n" /* outr00 += din2_0123 * + w1[1]*/ + + "ld1 {v5.4s}, [%[din_ptr2]] \n" /*vld1q_f32(din_ptr0)*/ + "ld1 {v7.4s}, [%[din_ptr3]] \n" /*vld1q_f32(din_ptr0)*/ + + "fmla v15.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din1_0123 * + w0[1]*/ + "fmla v14.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din1_0123 * + w0[1]*/ + "fmla v13.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din1_0123 * + w1[1]*/ + + "ext v16.16b, %[vzero].16b, v8.16b, #12 \n" /* v16 = 00123*/ + "ext v17.16b, v8.16b, v9.16b, #4 \n" /* v16 = 1234 */ + + // r4 + "fmla v15.4s , v8.4s, %[w1].s[1]\n" /* outr00 += din2_0123 * + w1[1]*/ + "fmla v14.4s , v8.4s, %[w2].s[1]\n" /* outr00 += din2_0123 * + w2[1]*/ + + "fmax v12.4s, v12.4s, %[vzero].4s \n" /*relu*/ + "fmax v13.4s, v13.4s, %[vzero].4s \n" /*relu*/ + + "ld1 {v8.4s}, [%[din_ptr4]], #16 \n" /*vld1q_f32(din_ptr0)*/ + + "fmla v15.4s , v16.4s, %[w1].s[0]\n" /* outr00 += din2_0123 * + w0[1]*/ + "fmla v14.4s , v16.4s, %[w2].s[0]\n" /* outr00 += din2_0123 * + w1[1]*/ + + "st1 {v12.4s}, [%[doutr0]], #16 \n" /* vst1q_f32() */ + "st1 {v13.4s}, [%[doutr1]], #16 \n" /* vst1q_f32() */ + + "ld1 {v9.4s}, [%[din_ptr4]] \n" /*vld1q_f32(din_ptr0)*/ + + "fmla v15.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din1_0123 * + w0[1]*/ + "fmla v14.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din1_0123 * + w1[1]*/ + + "ext v16.16b, %[vzero].16b, v10.16b, #12 \n" /* v16 = 00123*/ + "ext v17.16b, v10.16b, v11.16b, #4 \n" /* v16 = 1234 */ + "ld1 {v12.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ + "ld1 {v13.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ + + // r5 + "fmla v15.4s , v10.4s, %[w2].s[1]\n" /* outr00 += din2_0123 * + w1[1]*/ + + "fmax v14.4s, v14.4s, %[vzero].4s \n" /*relu*/ + + "ld1 {v10.4s}, [%[din_ptr5]], #16 \n" /*vld1q_f32(din_ptr0)*/ + + "fmla v15.4s , v16.4s, %[w2].s[0]\n" /* outr00 += din2_0123 * + w0[1]*/ + + "st1 {v14.4s}, [%[doutr2]], #16 \n" /* vst1q_f32() */ + + "ld1 {v11.4s}, [%[din_ptr5]] \n" /*vld1q_f32(din_ptr0)*/ + + "fmla v15.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din1_0123 * + w0[1]*/ + + "ld1 {v14.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ + + "ext v16.16b, v0.16b, v1.16b, #4 \n" /* v16 = 1234*/ + "ext v17.16b, v0.16b, v1.16b, #8 \n" /* v16 = 2345 */ + + "fmax v15.4s, v15.4s, %[vzero].4s \n" /*relu*/ + + "st1 {v15.4s}, [%[doutr3]], #16 \n" /* vst1q_f32() */ + "cmp %[cnt], #1 \n" + "ld1 {v15.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ + + "blt 3f \n" + // mid + "1: \n" + // r0 + "fmla v12.4s , v0.4s, %[w0].s[0]\n" /* outr00 += din0_0123 * + w0[0]*/ + + "ld1 {v0.4s}, [%[din_ptr0]], #16 \n" /*vld1q_f32(din_ptr0)*/ + + "fmla v12.4s , v16.4s, %[w0].s[1]\n" /* outr00 += din0_1234 * + 
w0[1]*/ + + "ld1 {v1.4s}, [%[din_ptr0]] \n" /*vld1q_f32(din_ptr0)*/ + + "fmla v12.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din0_2345 * + w0[2]*/ + + "ext v16.16b, v2.16b, v3.16b, #4 \n" /* v16 = 1234*/ + "ext v17.16b, v2.16b, v3.16b, #8 \n" /* v16 = 2345 */ + + // r1 + "fmla v13.4s , v2.4s, %[w0].s[0]\n" /* outr00 += din0_0123 * + w0[0]*/ + "fmla v12.4s , v2.4s, %[w1].s[0]\n" /* outr00 += din0_0123 * + w0[0]*/ + + "ld1 {v2.4s}, [%[din_ptr1]], #16 \n" /*vld1q_f32(din_ptr0)*/ + + "fmla v13.4s , v16.4s, %[w0].s[1]\n" /* outr00 += din0_1234 * + w0[1]*/ + "fmla v12.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 * + w0[1]*/ + + "ld1 {v3.4s}, [%[din_ptr1]] \n" /*vld1q_f32(din_ptr0)*/ + + "fmla v13.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din0_2345 * + w0[2]*/ + "fmla v12.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 * + w0[2]*/ + + "ext v16.16b, v4.16b, v5.16b, #4 \n" /* v16 = 1234*/ + "ext v17.16b, v4.16b, v5.16b, #8 \n" /* v16 = 2345 */ + + // r2 + "fmla v14.4s , v4.4s, %[w0].s[0]\n" /* outr00 += din0_0123 * + w0[0]*/ + "fmla v13.4s , v4.4s, %[w1].s[0]\n" /* outr00 += din0_0123 * + w0[0]*/ + "fmla v12.4s , v4.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * + w0[0]*/ + + "ld1 {v4.4s}, [%[din_ptr2]], #16 \n" /*vld1q_f32(din_ptr0)*/ + + "fmla v14.4s , v16.4s, %[w0].s[1]\n" /* outr00 += din0_1234 * + w0[1]*/ + "fmla v13.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 * + w0[1]*/ + "fmla v12.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * + w0[1]*/ + + "ld1 {v5.4s}, [%[din_ptr2]] \n" /*vld1q_f32(din_ptr0)*/ + + "fmla v14.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din0_2345 * + w0[2]*/ + "fmla v13.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 * + w0[2]*/ + "fmla v12.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * + w0[2]*/ + + "ext v16.16b, v6.16b, v7.16b, #4 \n" /* v16 = 1234*/ + "ext v17.16b, v6.16b, v7.16b, #8 \n" /* v16 = 2345 */ + + // r3 + "fmla v15.4s , v6.4s, %[w0].s[0]\n" /* outr00 += din0_0123 * + w0[0]*/ + "fmla v14.4s , v6.4s, %[w1].s[0]\n" /* outr00 += din0_0123 * + w0[0]*/ + "fmla v13.4s , v6.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * + w0[0]*/ + + "ld1 {v6.4s}, [%[din_ptr3]], #16 \n" /*vld1q_f32(din_ptr0)*/ + "fmax v12.4s, v12.4s, %[vzero].4s \n" /*relu*/ + + "fmla v15.4s , v16.4s, %[w0].s[1]\n" /* outr00 += din0_1234 * + w0[1]*/ + "fmla v14.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 * + w0[1]*/ + "fmla v13.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * + w0[1]*/ + + "st1 {v12.4s}, [%[doutr0]], #16 \n" + + "ld1 {v7.4s}, [%[din_ptr3]] \n" /*vld1q_f32(din_ptr0)*/ + "ld1 {v12.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ + + "fmla v15.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din0_2345 * + w0[2]*/ + "fmla v14.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 * + w0[2]*/ + "fmla v13.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * + w0[2]*/ + + "ext v16.16b, v8.16b, v9.16b, #4 \n" /* v16 = 1234*/ + "ext v17.16b, v8.16b, v9.16b, #8 \n" /* v16 = 2345 */ + + // r3 + "fmla v15.4s , v8.4s, %[w1].s[0]\n" /* outr00 += din0_0123 * + w0[0]*/ + "fmla v14.4s , v8.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * + w0[0]*/ + + "ld1 {v8.4s}, [%[din_ptr4]], #16 \n" /*vld1q_f32(din_ptr0)*/ + "fmax v13.4s, v13.4s, %[vzero].4s \n" /*relu*/ + + "fmla v15.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 * + w0[1]*/ + "fmla v14.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * + w0[1]*/ + + "st1 {v13.4s}, [%[doutr1]], #16 \n" + + "ld1 {v9.4s}, [%[din_ptr4]] \n" /*vld1q_f32(din_ptr0)*/ + "ld1 {v13.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ + + "fmla v15.4s , v17.4s, %[w1].s[2]\n" 
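+            // The key difference of this relu variant from
+            // conv_depthwise_3x3s1p1_bias: each finished accumulator is
+            // clamped in-register,
+            //   fmax vN.4s, vN.4s, vzero.4s
+            // immediately before its store, fusing max(x, 0.f) into the
+            // convolution pass instead of a second sweep over the output.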
/* outr00 += din0_2345 * + w0[2]*/ + "fmla v14.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * + w0[2]*/ + + "ext v16.16b, v10.16b, v11.16b, #4 \n" /* v16 = 1234*/ + "ext v17.16b, v10.16b, v11.16b, #8 \n" /* v16 = 2345 */ + + // r3 + "fmla v15.4s , v10.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * + w0[0]*/ + + "ld1 {v10.4s}, [%[din_ptr5]], #16 \n" /*vld1q_f32(din_ptr0)*/ + "fmax v14.4s, v14.4s, %[vzero].4s \n" /*relu*/ + + "fmla v15.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * + w0[1]*/ + + "st1 {v14.4s}, [%[doutr2]], #16 \n" + + "ld1 {v11.4s}, [%[din_ptr5]] \n" /*vld1q_f32(din_ptr0)*/ + "ld1 {v14.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ + + "fmla v15.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * + w0[2]*/ + + "ext v16.16b, v0.16b, v1.16b, #4 \n" /* v16 = 1234*/ + "ext v17.16b, v0.16b, v1.16b, #8 \n" /* v16 = 2345 */ + + "subs %[cnt], %[cnt], #1 \n" + + "fmax v15.4s, v15.4s, %[vzero].4s \n" /*relu*/ + + "st1 {v15.4s}, [%[doutr3]], #16 \n" + "ld1 {v15.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ + + "bne 1b \n" + + // right + "3: \n" + "ld1 {v18.4s, v19.4s}, [%[vmask]] \n" + "ld1 {v22.4s}, [%[doutr0]] \n" + "ld1 {v23.4s}, [%[doutr1]] \n" + "ld1 {v24.4s}, [%[doutr2]] \n" + "ld1 {v25.4s}, [%[doutr3]] \n" + + "bif v0.16b, %[vzero].16b, v18.16b \n" + "bif v1.16b, %[vzero].16b, v19.16b \n" + "bif v2.16b, %[vzero].16b, v18.16b \n" + "bif v3.16b, %[vzero].16b, v19.16b \n" + + "bif v4.16b, %[vzero].16b, v18.16b \n" + "bif v5.16b, %[vzero].16b, v19.16b \n" + "bif v6.16b, %[vzero].16b, v18.16b \n" + "bif v7.16b, %[vzero].16b, v19.16b \n" + + "ext v16.16b, v0.16b, v1.16b, #4 \n" /* v16 = 1234*/ + "ext v17.16b, v0.16b, v1.16b, #8 \n" /* v16 = 2345 */ + + // r0 + "fmla v12.4s, v0.4s, %[w0].s[0]\n" /* outr00 += din0_0123 * + w0[0]*/ + + "bif v8.16b, %[vzero].16b, v18.16b \n" + "bif v9.16b, %[vzero].16b, v19.16b \n" + "bif v10.16b, %[vzero].16b, v18.16b \n" + "bif v11.16b, %[vzero].16b, v19.16b \n" + + "fmla v12.4s, v16.4s, %[w0].s[1]\n" /* outr00 += din0_1234 * + w0[1]*/ + + "ld1 {v18.4s}, [%[rmask]] \n" + + "fmla v12.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din0_2345 * + w0[2]*/ + + "ext v16.16b, v2.16b, v3.16b, #4 \n" /* v16 = 1234*/ + "ext v17.16b, v2.16b, v3.16b, #8 \n" /* v16 = 2345 */ + + // r1 + "fmla v13.4s , v2.4s, %[w0].s[0]\n" /* outr00 += din0_0123 * + w0[0]*/ + "fmla v12.4s , v2.4s, %[w1].s[0]\n" /* outr00 += din0_0123 * + w0[0]*/ + + "fmla v13.4s , v16.4s, %[w0].s[1]\n" /* outr00 += din0_1234 * + w0[1]*/ + "fmla v12.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 * + w0[1]*/ + + "fmla v13.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din0_2345 * + w0[2]*/ + "fmla v12.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 * + w0[2]*/ + + "ext v16.16b, v4.16b, v5.16b, #4 \n" /* v16 = 1234*/ + "ext v17.16b, v4.16b, v5.16b, #8 \n" /* v16 = 2345 */ + + // r2 + "fmla v14.4s , v4.4s, %[w0].s[0]\n" /* outr00 += din0_0123 * + w0[0]*/ + "fmla v13.4s , v4.4s, %[w1].s[0]\n" /* outr00 += din0_0123 * + w0[0]*/ + "fmla v12.4s , v4.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * + w0[0]*/ + + "fmla v14.4s , v16.4s, %[w0].s[1]\n" /* outr00 += din0_1234 * + w0[1]*/ + "fmla v13.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 * + w0[1]*/ + "fmla v12.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * + w0[1]*/ + + "fmla v14.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din0_2345 * + w0[2]*/ + "fmla v13.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 * + w0[2]*/ + "fmla v12.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * + w0[2]*/ + + "ext v16.16b, v6.16b, v7.16b, #4 \n" /* v16 = 1234*/ + "ext 
v17.16b, v6.16b, v7.16b, #8 \n" /* v16 = 2345 */ + + // r3 + "fmla v15.4s , v6.4s, %[w0].s[0]\n" /* outr00 += din0_0123 * + w0[0]*/ + "fmla v14.4s , v6.4s, %[w1].s[0]\n" /* outr00 += din0_0123 * + w0[0]*/ + "fmla v13.4s , v6.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * + w0[0]*/ + + "fmax v12.4s, v12.4s, %[vzero].4s \n" /*relu*/ + + "fmla v15.4s , v16.4s, %[w0].s[1]\n" /* outr00 += din0_1234 * + w0[1]*/ + "fmla v14.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 * + w0[1]*/ + "fmla v13.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * + w0[1]*/ + + "bif v12.16b, v22.16b, v18.16b \n" + + "fmla v15.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din0_2345 * + w0[2]*/ + "fmla v14.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 * + w0[2]*/ + "fmla v13.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * + w0[2]*/ + + "ext v16.16b, v8.16b, v9.16b, #4 \n" /* v16 = 1234*/ + "ext v17.16b, v8.16b, v9.16b, #8 \n" /* v16 = 2345 */ + + // r3 + "fmla v15.4s , v8.4s, %[w1].s[0]\n" /* outr00 += din0_0123 * + w0[0]*/ + "fmla v14.4s , v8.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * + w0[0]*/ + + "st1 {v12.4s}, [%[doutr0]], #16 \n" + "fmax v13.4s, v13.4s, %[vzero].4s \n" /*relu*/ + + "fmla v15.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 * + w0[1]*/ + "fmla v14.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * + w0[1]*/ + + "bif v13.16b, v23.16b, v18.16b \n" + + "fmla v15.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 * + w0[2]*/ + "fmla v14.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * + w0[2]*/ + + "ext v16.16b, v10.16b, v11.16b, #4 \n" /* v16 = 1234*/ + "ext v17.16b, v10.16b, v11.16b, #8 \n" /* v16 = 2345 */ + + "st1 {v13.4s}, [%[doutr1]], #16 \n" + + // r3 + "fmla v15.4s , v10.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * + w0[0]*/ + + "fmax v14.4s, v14.4s, %[vzero].4s \n" /*relu*/ + + "fmla v15.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * + w0[1]*/ + + "bif v14.16b, v24.16b, v18.16b \n" + + "fmla v15.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * + w0[2]*/ + + "st1 {v14.4s}, [%[doutr2]], #16 \n" + + "fmax v15.4s, v15.4s, %[vzero].4s \n" /*relu*/ + + "bif v15.16b, v25.16b, v18.16b \n" + + "st1 {v15.4s}, [%[doutr3]], #16 \n" + : [cnt] "+r"(cnt), + [din_ptr0] "+r"(din_ptr0), + [din_ptr1] "+r"(din_ptr1), + [din_ptr2] "+r"(din_ptr2), + [din_ptr3] "+r"(din_ptr3), + [din_ptr4] "+r"(din_ptr4), + [din_ptr5] "+r"(din_ptr5), + [doutr0] "+r"(doutr0), + [doutr1] "+r"(doutr1), + [doutr2] "+r"(doutr2), + [doutr3] "+r"(doutr3) + : [w0] "w"(wr0), + [w1] "w"(wr1), + [w2] "w"(wr2), + [bias_val] "r"(vbias), + [vmask] "r"(vmask), + [rmask] "r"(rmask), + [vzero] "w"(vzero) + : "cc", + "memory", + "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15", + "v16", + "v17", + "v18", + "v19", + "v20", + "v21", + "v22", + "v23", + "v24", + "v25"); + dout_ptr = dout_ptr + 4 * w_out; + } + } +#else + for (int i = 0; i < ch_in; ++i) { + const float* din_channel = din_batch + i * size_in_channel; + + const float* weight_ptr = weights + i * 9; + float32x4_t wr0 = vld1q_f32(weight_ptr); + float32x4_t wr1 = vld1q_f32(weight_ptr + 3); + float32x4_t wr2 = vld1q_f32(weight_ptr + 6); + float bias_val = flag_bias ? 
bias[i] : 0.f; + + float* dout_channel = dout_batch + i * size_out_channel; + + const float* dr0 = din_channel; + const float* dr1 = dr0 + w_in; + const float* dr2 = dr1 + w_in; + const float* dr3 = dr2 + w_in; + + const float* din0_ptr = nullptr; + const float* din1_ptr = nullptr; + const float* din2_ptr = nullptr; + const float* din3_ptr = nullptr; + + float* doutr0 = nullptr; + float* doutr1 = nullptr; + + float* ptr_zero = const_cast(zero); + + for (int i = 0; i < h_in; i += 2) { + //! process top pad pad_h = 1 + din0_ptr = dr0; + din1_ptr = dr1; + din2_ptr = dr2; + din3_ptr = dr3; + + doutr0 = dout_channel; + doutr1 = dout_channel + w_out; + // unsigned int* rst_mask = rmask; + + if (i == 0) { + din0_ptr = zero_ptr; + din1_ptr = dr0; + din2_ptr = dr1; + din3_ptr = dr2; + dr0 = dr1; + dr1 = dr2; + dr2 = dr3; + dr3 = dr2 + w_in; + } else { + dr0 = dr2; + dr1 = dr3; + dr2 = dr1 + w_in; + dr3 = dr2 + w_in; + } + //! process bottom pad + if (i + 3 > h_in) { + switch (i + 3 - h_in) { + case 3: + din1_ptr = zero_ptr; + case 2: + din2_ptr = zero_ptr; + case 1: + din3_ptr = zero_ptr; + default: + break; + } + } + //! process bottom remain + if (i + 2 > h_out) { + doutr1 = write_ptr; + } + int cnt = cnt_col; + unsigned int* rmask_ptr = rmask; + unsigned int* vmask_ptr = vmask; + asm volatile( + "pld [%[din0_ptr]] @ preload data\n" + "pld [%[din1_ptr]] @ preload data\n" + "pld [%[din2_ptr]] @ preload data\n" + "pld [%[din3_ptr]] @ preload data\n" + + "vld1.32 {d16-d18}, [%[din0_ptr]]! @ load din r0\n" + "vld1.32 {d20-d22}, [%[din1_ptr]]! @ load din r1\n" + "vld1.32 {d24-d26}, [%[din2_ptr]]! @ load din r2\n" + "vld1.32 {d28-d30}, [%[din3_ptr]]! @ load din r3\n" + + "vdup.32 q4, %[bias_val] @ and \n" // q4 + // = + // vbias + "vdup.32 q5, %[bias_val] @ and \n" // q5 + // = + // vbias + + "vext.32 q6, %q[vzero], q8, #3 @ 0012\n" + "vext.32 q7, q8, q9, #1 @ 1234\n" + + // left + // r0 + "vmla.f32 q4, q8, %e[wr0][1] @ q4 += 1234 * wr0[1]\n" + + "sub %[din0_ptr], #12 @ 1pad + 2 float data overlap\n" + "sub %[din1_ptr], #12 @ 1pad + 2 float data overlap\n" + "sub %[din2_ptr], #12 @ 1pad + 2 float data overlap\n" + "sub %[din3_ptr], #12 @ 1pad + 2 float data overlap\n" + + "vmla.f32 q4, q6, %e[wr0][0] @ q4 += 1234 * wr0[0]\n" + + "pld [%[din0_ptr]] @ preload data\n" + "pld [%[din1_ptr]] @ preload data\n" + "pld [%[din2_ptr]] @ preload data\n" + "pld [%[din3_ptr]] @ preload data\n" + + "vmla.f32 q4, q7, %f[wr0][0] @ q4 += 1234 * wr0[2]\n" + + "vext.32 q6, %q[vzero], q10, #3 @ 0012\n" + "vext.32 q7, q10, q11, #1 @ 1234\n" + + // r1 + "vmla.f32 q5, q10, %e[wr0][1] @ q4 += 1234 * wr0[1]\n" + "vmla.f32 q4, q10, %e[wr1][1] @ q4 += 1234 * wr0[1]\n" + + "vld1.32 {d16-d17}, [%[din0_ptr]]! @ load din r0\n" + "vld1.32 {d20-d21}, [%[din1_ptr]]! @ load din r0\n" + + "vmla.f32 q5, q6, %e[wr0][0] @ q4 += 1234 * wr0[0]\n" + "vmla.f32 q4, q6, %e[wr1][0] @ q4 += 1234 * wr0[0]\n" + + "vld1.32 {d18}, [%[din0_ptr]] @ load din r0\n" + "vld1.32 {d22}, [%[din1_ptr]] @ load din r0\n" + + "vmla.f32 q5, q7, %f[wr0][0] @ q4 += 1234 * wr0[2]\n" + "vmla.f32 q4, q7, %f[wr1][0] @ q4 += 1234 * wr0[2]\n" + + "vext.32 q6, %q[vzero], q12, #3 @ 0012\n" + "vext.32 q7, q12, q13, #1 @ 1234\n" + + // r2 + "vmla.f32 q5, q12, %e[wr1][1] @ q4 += 1234 * wr0[1]\n" + "vmla.f32 q4, q12, %e[wr2][1] @ q4 += 1234 * wr0[1]\n" + + "vld1.32 {d24-d25}, [%[din2_ptr]]! 
@ load din r0\n" + + "vmla.f32 q5, q6, %e[wr1][0] @ q4 += 1234 * wr0[0]\n" + "vmla.f32 q4, q6, %e[wr2][0] @ q4 += 1234 * wr0[0]\n" + + "vld1.32 {d26}, [%[din2_ptr]] @ load din r0\n" + + "vmla.f32 q5, q7, %f[wr1][0] @ q4 += 1234 * wr0[2]\n" + "vmla.f32 q4, q7, %f[wr2][0] @ q4 += 1234 * wr0[2]\n" + + "vext.32 q6, %q[vzero], q14, #3 @ 0012\n" + "vext.32 q7, q14, q15, #1 @ 1234\n" + + // r3 + "vmla.f32 q5, q14, %e[wr2][1] @ q4 += 1234 * wr0[1]\n" + + "vld1.32 {d28-d29}, [%[din3_ptr]]! @ load din r0\n" + "vmax.f32 q4, q4, %q[vzero] @ relu \n" + + "vmla.f32 q5, q6, %e[wr2][0] @ q4 += 1234 * wr0[0]\n" + + "vld1.32 {d30}, [%[din3_ptr]] @ load din r0\n" + "vst1.32 {d8-d9}, [%[dout_ptr1]]! @ store result, add pointer\n" + + "vmla.f32 q5, q7, %f[wr2][0] @ q4 += 1234 * wr0[2]\n" + + "vext.32 q6, q8, q9, #1 @ 1234\n" + "vext.32 q7, q8, q9, #2 @ 2345\n" + "vdup.32 q4, %[bias_val] @ and \n" // q4 + // = + // vbias + + "vmax.f32 q5, q5, %q[vzero] @ relu \n" + + "cmp %[cnt], #1 @ check whether has " + "mid cols\n" + + "vst1.32 {d10-d11}, [%[dout_ptr2]]! @ store result, add " + "pointer\n" + + "vdup.32 q5, %[bias_val] @ and \n" // q5 + // = + // vbias + "blt 3f @ jump to main loop start " + "point\n" + + // mid + "1: @ right pad entry\n" + // r0 + "vmla.f32 q4, q8, %e[wr0][0] @ q4 += 0123 * wr0[0]\n" + + "pld [%[din0_ptr]] @ preload data\n" + "pld [%[din1_ptr]] @ preload data\n" + "pld [%[din2_ptr]] @ preload data\n" + "pld [%[din3_ptr]] @ preload data\n" + + "vmla.f32 q4, q6, %e[wr0][1] @ q4 += 1234 * wr0[1]\n" + + "vld1.32 {d16-d17}, [%[din0_ptr]]! @ load din r0\n" + + "vmla.f32 q4, q7, %f[wr0][0] @ q4 += 2345 * wr0[2]\n" + + "vld1.32 {d18}, [%[din0_ptr]] @ load din r0\n" + + "vext.32 q6, q10, q11, #1 @ 1234\n" + "vext.32 q7, q10, q11, #2 @ 2345\n" + + // r1 + "vmla.f32 q5, q10, %e[wr0][0] @ q4 += 1234 * wr0[0]\n" + "vmla.f32 q4, q10, %e[wr1][0] @ q4 += 1234 * wr0[1]\n" + + "vld1.32 {d20-d21}, [%[din1_ptr]]! @ load din r0\n" + + "vmla.f32 q5, q6, %e[wr0][1] @ q4 += 1234 * wr0[1]\n" + "vmla.f32 q4, q6, %e[wr1][1] @ q4 += 1234 * wr0[1]\n" + + "vld1.32 {d22}, [%[din1_ptr]] @ load din r0\n" + + "vmla.f32 q5, q7, %f[wr0][0] @ q4 += 1234 * wr0[1]\n" + "vmla.f32 q4, q7, %f[wr1][0] @ q4 += 1234 * wr0[1]\n" + + "vext.32 q6, q12, q13, #1 @ 1234\n" + "vext.32 q7, q12, q13, #2 @ 2345\n" + + // r2 + "vmla.f32 q5, q12, %e[wr1][0] @ q4 += 1234 * wr0[0]\n" + "vmla.f32 q4, q12, %e[wr2][0] @ q4 += 1234 * wr0[1]\n" + + "vld1.32 {d24-d25}, [%[din2_ptr]]! @ load din r0\n" + + "vmla.f32 q5, q6, %e[wr1][1] @ q4 += 1234 * wr0[1]\n" + "vmla.f32 q4, q6, %e[wr2][1] @ q4 += 1234 * wr0[1]\n" + + "vld1.32 {d26}, [%[din2_ptr]] @ load din r0\n" + + "vmla.f32 q5, q7, %f[wr1][0] @ q4 += 1234 * wr0[1]\n" + "vmla.f32 q4, q7, %f[wr2][0] @ q4 += 1234 * wr0[1]\n" + + "vext.32 q6, q14, q15, #1 @ 1234\n" + "vext.32 q7, q14, q15, #2 @ 2345\n" + + // r3 + "vmla.f32 q5, q14, %e[wr2][0] @ q4 += 0123 * wr0[0]\n" + + "vld1.32 {d28-d29}, [%[din3_ptr]]! @ load din r0\n" + "vmax.f32 q4, q4, %q[vzero] @ relu \n" + + "vmla.f32 q5, q6, %e[wr2][1] @ q4 += 1234 * wr0[1]\n" + + "vld1.32 {d30}, [%[din3_ptr]] @ load din r0\n" + "vst1.32 {d8-d9}, [%[dout_ptr1]]! @ store result, add pointer\n" + + "vmla.f32 q5, q7, %f[wr2][0] @ q4 += 2345 * wr0[2]\n" + + "vext.32 q6, q8, q9, #1 @ 1234\n" + "vext.32 q7, q8, q9, #2 @ 2345\n" + "vdup.32 q4, %[bias_val] @ and \n" // q4 + // = + // vbias + + "vmax.f32 q5, q5, %q[vzero] @ relu \n" + + "vst1.32 {d10-d11}, [%[dout_ptr2]]! 
@ store result, add " + "pointer\n" + + "subs %[cnt], #1 @ loop count minus 1\n" + + "vdup.32 q5, %[bias_val] @ and \n" // q4 + // = + // vbias + + "bne 1b @ jump to main loop start " + "point\n" + + // right + "3: @ right pad entry\n" + "vld1.32 {d19}, [%[vmask]]! @ load din r0\n" + "vld1.32 {d23}, [%[vmask]]! @ load din r0\n" + + "vld1.32 {d27}, [%[vmask]]! @ load din r0\n" + "vld1.32 {d31}, [%[vmask]]! @ load din r0\n" + + "vbif d16, %e[vzero], d19 @ bit select, deal with " + "right pad\n" + "vbif d17, %e[vzero], d23 @ bit select, deal with " + "right pad\n" + "vbif d18, %e[vzero], d27 @ bit select, deal with " + "right pad\n" + + "vbif d20, %e[vzero], d19 @ bit select, deal with " + "right pad\n" + "vbif d21, %e[vzero], d23 @ bit select, deal with " + "right pad\n" + "vbif d22, %e[vzero], d27 @ bit select, deal with " + "right pad\n" + + "vext.32 q6, q8, q9, #1 @ 1234\n" + "vext.32 q7, q8, q9, #2 @ 2345\n" + + // r0 + "vmla.f32 q4, q8, %e[wr0][0] @ q4 += 0123 * wr0[0]\n" + + "vbif d24, %e[vzero], d19 @ bit select, deal with " + "right pad\n" + "vbif d25, %e[vzero], d23 @ bit select, deal with " + "right pad\n" + "vbif d26, %e[vzero], d27 @ bit select, deal with " + "right pad\n" + + "vmla.f32 q4, q6, %e[wr0][1] @ q4 += 1234 * wr0[1]\n" + + "vbif d28, %e[vzero], d19 @ bit select, deal with " + "right pad\n" + "vbif d29, %e[vzero], d23 @ bit select, deal with " + "right pad\n" + "vbif d30, %e[vzero], d27 @ bit select, deal with " + "right pad\n" + + "vmla.f32 q4, q7, %f[wr0][0] @ q4 += 2345 * wr0[2]\n" + + "vext.32 q6, q10, q11, #1 @ 1234\n" + "vext.32 q7, q10, q11, #2 @ 2345\n" + + // r1 + "vmla.f32 q5, q10, %e[wr0][0] @ q4 += 1234 * wr0[0]\n" + "vmla.f32 q4, q10, %e[wr1][0] @ q4 += 1234 * wr0[1]\n" + + "vld1.32 {d19}, [%[rmask]]! @ load din r0\n" + "vld1.32 {d23}, [%[rmask]]! @ load din r0\n" + + "vmla.f32 q5, q6, %e[wr0][1] @ q4 += 1234 * wr0[1]\n" + "vmla.f32 q4, q6, %e[wr1][1] @ q4 += 1234 * wr0[1]\n" + + "vld1.32 {d16-d17}, [%[dout_ptr1]] @ load din r0\n" + "vld1.32 {d20-d21}, [%[dout_ptr2]] @ load din r0\n" + + "vmla.f32 q5, q7, %f[wr0][0] @ q4 += 1234 * wr0[1]\n" + "vmla.f32 q4, q7, %f[wr1][0] @ q4 += 1234 * wr0[1]\n" + + "vext.32 q6, q12, q13, #1 @ 1234\n" + "vext.32 q7, q12, q13, #2 @ 2345\n" + + // r2 + "vmla.f32 q5, q12, %e[wr1][0] @ q4 += 1234 * wr0[0]\n" + "vmla.f32 q4, q12, %e[wr2][0] @ q4 += 1234 * wr0[1]\n" + + "vmla.f32 q5, q6, %e[wr1][1] @ q4 += 1234 * wr0[1]\n" + "vmla.f32 q4, q6, %e[wr2][1] @ q4 += 1234 * wr0[1]\n" + + "vmla.f32 q5, q7, %f[wr1][0] @ q4 += 1234 * wr0[1]\n" + "vmla.f32 q4, q7, %f[wr2][0] @ q4 += 1234 * wr0[1]\n" + + "vext.32 q6, q14, q15, #1 @ 1234\n" + "vext.32 q7, q14, q15, #2 @ 2345\n" + + // r3 + "vmla.f32 q5, q14, %e[wr2][0] @ q4 += 0123 * wr0[0]\n" + + "vmax.f32 q4, q4, %q[vzero] @ relu \n" + + "vmla.f32 q5, q6, %e[wr2][1] @ q4 += 1234 * wr0[1]\n" + + "vbif d8, d16, d19 @ bit select, deal with right pad\n" + "vbif d9, d17, d23 @ bit select, deal with right pad\n" + + "vmla.f32 q5, q7, %f[wr2][0] @ q4 += 2345 * wr0[2]\n" + "vst1.32 {d8-d9}, [%[dout_ptr1]]! @ store result, add pointer\n" + + "vmax.f32 q5, q5, %q[vzero] @ relu \n" + + "vbif d10, d20, d19 @ bit select, deal with right " + "pad\n" + "vbif d11, d21, d23 @ bit select, deal with right " + "pad\n" + + "vst1.32 {d10-d11}, [%[dout_ptr2]]! 
@ store result, add " + "pointer\n" + + : [dout_ptr1] "+r"(doutr0), + [dout_ptr2] "+r"(doutr1), + [din0_ptr] "+r"(din0_ptr), + [din1_ptr] "+r"(din1_ptr), + [din2_ptr] "+r"(din2_ptr), + [din3_ptr] "+r"(din3_ptr), + [cnt] "+r"(cnt), + [rmask] "+r"(rmask_ptr), + [vmask] "+r"(vmask_ptr) + : [wr0] "w"(wr0), + [wr1] "w"(wr1), + [wr2] "w"(wr2), + [bias_val] "r"(bias_val), + [vzero] "w"(vzero) + : "cc", + "memory", + "q4", + "q5", + "q6", + "q7", + "q8", + "q9", + "q10", + "q11", + "q12", + "q13", + "q14", + "q15"); + dout_channel += 2 * w_out; + } //! end of processing mid rows + } +#endif + } +} +/** + * \brief depthwise convolution kernel 3x3, stride 2, with relu + */ +// w_in > 7 +void conv_depthwise_3x3s2p1_bias_relu(float* dout, + const float* din, + const float* weights, + const float* bias, + bool flag_bias, + const int num, + const int ch_in, + const int h_in, + const int w_in, + const int h_out, + const int w_out, + ARMContext* ctx) { + int right_pad_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7}; + int out_pad_idx[4] = {0, 1, 2, 3}; + int size_pad_bottom = h_out * 2 - h_in; + + int cnt_col = (w_out >> 2) - 2; + int size_right_remain = w_in - (7 + cnt_col * 8); + if (size_right_remain >= 9) { + cnt_col++; + size_right_remain -= 8; + } + int cnt_remain = (size_right_remain == 8) ? 4 : (w_out % 4); // + + int size_right_pad = w_out * 2 - w_in; + + uint32x4_t vmask_rp1 = vcgtq_s32(vdupq_n_s32(size_right_remain), + vld1q_s32(right_pad_idx)); // 0 2 4 6 + uint32x4_t vmask_rp2 = vcgtq_s32(vdupq_n_s32(size_right_remain), + vld1q_s32(right_pad_idx + 4)); // 1 3 5 7 + uint32x4_t wmask = + vcgtq_s32(vdupq_n_s32(cnt_remain), vld1q_s32(out_pad_idx)); // 0 1 2 3 + int size_in_channel = w_in * h_in; + int size_out_channel = w_out * h_out; + + float* zero_ptr = ctx->workspace_data<float>(); + memset(zero_ptr, 0, w_in * sizeof(float)); + float* write_ptr = zero_ptr + w_in; + + unsigned int dmask[12]; + + vst1q_u32(dmask, vmask_rp1); + vst1q_u32(dmask + 4, vmask_rp2); + vst1q_u32(dmask + 8, wmask); + + for (int n = 0; n < num; ++n) { + const float* din_batch = din + n * ch_in * size_in_channel; + float* dout_batch = dout + n * ch_in * size_out_channel; +#pragma omp parallel for + for (int i = 0; i < ch_in; ++i) { + const float* din_channel = din_batch + i * size_in_channel; + float* dout_channel = dout_batch + i * size_out_channel; + + const float* weight_ptr = weights + i * 9; + float32x4_t wr0 = vld1q_f32(weight_ptr); + float32x4_t wr1 = vld1q_f32(weight_ptr + 3); + float32x4_t wr2 = vld1q_f32(weight_ptr + 6); + + float32x4_t vzero = vdupq_n_f32(0.f); + + float32x4_t wbias; + float bias_c = 0.f; + if (flag_bias) { + wbias = vdupq_n_f32(bias[i]); + bias_c = bias[i]; + } else { + wbias = vdupq_n_f32(0.f); + } + + const float* dr0 = din_channel; + const float* dr1 = dr0 + w_in; + const float* dr2 = dr1 + w_in; + const float* dr3 = dr2 + w_in; + const float* dr4 = dr3 + w_in; + + const float* din0_ptr = dr0; + const float* din1_ptr = dr1; + const float* din2_ptr = dr2; + const float* din3_ptr = dr3; + const float* din4_ptr = dr4; + + float* doutr0 = dout_channel; + float* doutr0_ptr = nullptr; + float* doutr1_ptr = nullptr; + +#ifdef __aarch64__ + for (int i = 0; i < h_in; i += 4) { + din0_ptr = dr0; + din1_ptr = dr1; + din2_ptr = dr2; + din3_ptr = dr3; + din4_ptr = dr4; + + doutr0_ptr = doutr0; + doutr1_ptr = doutr0 + w_out; + + if (i == 0) { + din0_ptr = zero_ptr; + din1_ptr = dr0; + din2_ptr = dr1; + din3_ptr = dr2; + din4_ptr = dr3; + dr0 = dr3; + dr1 = dr4; + } else { + dr0 = dr4; + dr1 = dr0 + w_in; + } + dr2 = dr1 + 
w_in; + dr3 = dr2 + w_in; + dr4 = dr3 + w_in; + + //! process bottom pad + if (i + 4 > h_in) { + switch (i + 4 - h_in) { + case 4: + din1_ptr = zero_ptr; + case 3: + din2_ptr = zero_ptr; + case 2: + din3_ptr = zero_ptr; + case 1: + din4_ptr = zero_ptr; + default: + break; + } + } + //! process output pad + if (i / 2 + 2 > h_out) { + doutr1_ptr = write_ptr; + } + int cnt = cnt_col; + asm volatile( + // top + // Load up 12 elements (3 vectors) from each of 8 sources. + "0: \n" + "prfm pldl1keep, [%[inptr0]] \n" + "prfm pldl1keep, [%[inptr1]] \n" + "prfm pldl1keep, [%[inptr2]] \n" + "prfm pldl1keep, [%[inptr3]] \n" + "prfm pldl1keep, [%[inptr4]] \n" + "ld2 {v0.4s, v1.4s}, [%[inptr0]], #32 \n" // v0={0,2,4,6} + // v1={1,3,5,7} + "ld2 {v2.4s, v3.4s}, [%[inptr1]], #32 \n" + "ld2 {v4.4s, v5.4s}, [%[inptr2]], #32 \n" + "ld2 {v6.4s, v7.4s}, [%[inptr3]], #32 \n" + "ld2 {v8.4s, v9.4s}, [%[inptr4]], #32 \n" + + "and v16.16b, %[vbias].16b, %[vbias].16b \n" // v10 = vbias + "and v17.16b, %[vbias].16b, %[vbias].16b \n" // v16 = vbias + + "ext v10.16b, %[vzero].16b, v1.16b, #12 \n" // v10 = {0,1,3,5} + + // r0 + "fmul v11.4s, v0.4s, %[w0].s[1] \n" // {0,2,4,6} * w01 + "fmul v12.4s, v1.4s, %[w0].s[2] \n" // {1,3,5,7} * w02 + "fmla v16.4s, v10.4s, %[w0].s[0] \n" // {0,1,3,5} * w00 + + "ext v10.16b, %[vzero].16b, v3.16b, #12 \n" // v10 = {0,1,3,5} + + "sub %[inptr0], %[inptr0], #4 \n" + "sub %[inptr1], %[inptr1], #4 \n" + + // r1 + "fmla v11.4s, v2.4s, %[w1].s[1] \n" // {0,2,4,6} * w01 + "fmla v12.4s, v3.4s, %[w1].s[2] \n" // {1,3,5,7} * w02 + "fmla v16.4s, v10.4s, %[w1].s[0] \n" // {0,1,3,5} * w00 + + "ext v10.16b, %[vzero].16b, v5.16b, #12 \n" // v10 = {0,1,3,5} + + "sub %[inptr2], %[inptr2], #4 \n" + "sub %[inptr3], %[inptr3], #4 \n" + + // r2 + "fmul v13.4s, v4.4s, %[w0].s[1] \n" // {0,2,4,6} * w01 + "fmla v11.4s, v4.4s, %[w2].s[1] \n" // {0,2,4,6} * w01 + + "fmul v14.4s, v5.4s, %[w0].s[2] \n" // {1,3,5,7} * w02 + "fmla v12.4s, v5.4s, %[w2].s[2] \n" // {1,3,5,7} * w02 + + "fmla v17.4s, v10.4s, %[w0].s[0] \n" // {0,1,3,5} * w00 + "fmla v16.4s, v10.4s, %[w2].s[0] \n" // {0,1,3,5} * w00 + + "ext v10.16b, %[vzero].16b, v7.16b, #12 \n" // v10 = {0,1,3,5} + + "sub %[inptr4], %[inptr4], #4 \n" + + // r3 + "fmla v13.4s, v6.4s, %[w1].s[1] \n" // {0,2,4,6} * w01 + "fmla v14.4s, v7.4s, %[w1].s[2] \n" // {1,3,5,7} * w02 + "fmla v17.4s, v10.4s, %[w1].s[0] \n" // {0,1,3,5} * w00 + + "ext v10.16b, %[vzero].16b, v9.16b, #12 \n" // v10 = {0,1,3,5} + "fadd v16.4s, v16.4s, v11.4s \n" + "fadd v16.4s, v16.4s, v12.4s \n" + + // r4 + "fmla v13.4s, v8.4s, %[w2].s[1] \n" // {0,2,4,6} * w01 + "fmla v14.4s, v9.4s, %[w2].s[2] \n" // {1,3,5,7} * w02 + "fmla v17.4s, v10.4s, %[w2].s[0] \n" // {0,1,3,5} * w00 + + "fmax v16.4s, v16.4s, %[vzero].4s \n" /* relu */ + + "ld2 {v0.4s, v1.4s}, [%[inptr0]], #32 \n" // v0={0,2,4,6} + // v1={1,3,5,7} + "ld2 {v2.4s, v3.4s}, [%[inptr1]], #32 \n" + "ld2 {v4.4s, v5.4s}, [%[inptr2]], #32 \n" + + "fadd v17.4s, v17.4s, v13.4s \n" + + "st1 {v16.4s}, [%[outptr0]], #16 \n" + + "ld2 {v6.4s, v7.4s}, [%[inptr3]], #32 \n" + "ld2 {v8.4s, v9.4s}, [%[inptr4]], #32 \n" + "ld1 {v15.4s}, [%[inptr0]] \n" + + "fadd v17.4s, v17.4s, v14.4s \n" + + "and v16.16b, %[vbias].16b, %[vbias].16b \n" // v10 = vbias + + "ld1 {v18.4s}, [%[inptr1]] \n" + "ld1 {v19.4s}, [%[inptr2]] \n" + + "ext v10.16b, v0.16b, v15.16b, #4 \n" // v10 = {2,4,6,8} + + "fmax v17.4s, v17.4s, %[vzero].4s \n" /* relu */ + + "ld1 {v20.4s}, [%[inptr3]] \n" + "ld1 {v21.4s}, [%[inptr4]] \n" + + "st1 {v17.4s}, [%[outptr1]], #16 \n" + + "cmp %[cnt], #1 \n" + + 
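+          /* Stride-2 scheme for the main loop below: "ld2" de-interleaves
+           * each input row into even columns {0,2,4,6} and odd columns
+           * {1,3,5,7}; an "ext" against the next load supplies {2,4,6,8},
+           * so the three taps of each kernel row cost three fmla
+           * instructions per output vector. */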
"and v17.16b, %[vbias].16b, %[vbias].16b \n" // v16 = vbias + + "blt 1f \n" + // mid + "2: \n" + // r0 + "fmul v11.4s, v0.4s, %[w0].s[0] \n" // {0,2,4,6} * w00 + "fmul v12.4s, v1.4s, %[w0].s[1] \n" // {1,3,5,7} * w01 + "fmla v16.4s, v10.4s, %[w0].s[2] \n" // {2,4,6,8} * w02 + + "ext v10.16b, v2.16b, v18.16b, #4 \n" // v10 = {2,4,6,8} + "ld2 {v0.4s, v1.4s}, [%[inptr0]], #32 \n" // v0={0,2,4,6} + // v1={1,3,5,7} + + // r1 + "fmla v11.4s, v2.4s, %[w1].s[0] \n" // {0,2,4,6} * w00 + "fmla v12.4s, v3.4s, %[w1].s[1] \n" // {1,3,5,7} * w01 + "fmla v16.4s, v10.4s, %[w1].s[2] \n" // {2,4,6,8} * w02 + + "ext v10.16b, v4.16b, v19.16b, #4 \n" // v10 = {2,4,6,8} + + "ld2 {v2.4s, v3.4s}, [%[inptr1]], #32 \n" + + // r2 + "fmul v13.4s, v4.4s, %[w0].s[0] \n" // {0,2,4,6} * w00 + "fmla v11.4s, v4.4s, %[w2].s[0] \n" // {0,2,4,6} * w00 + + "fmul v14.4s, v5.4s, %[w0].s[1] \n" // {1,3,5,7} * w01 + "fmla v12.4s, v5.4s, %[w2].s[1] \n" // {1,3,5,7} * w01 + + "fmla v17.4s, v10.4s, %[w0].s[2] \n" // {2,4,6,8} * w02 + "fmla v16.4s, v10.4s, %[w2].s[2] \n" // {2,4,6,8} * w02 + + "ext v10.16b, v6.16b, v20.16b, #4 \n" // v10 = {2,4,6,8} + + "ld2 {v4.4s, v5.4s}, [%[inptr2]], #32 \n" + + // r3 + "fmla v13.4s, v6.4s, %[w1].s[0] \n" // {0,2,4,6} * w00 + "fmla v14.4s, v7.4s, %[w1].s[1] \n" // {1,3,5,7} * w01 + "fmla v17.4s, v10.4s, %[w1].s[2] \n" // {2,4,6,8} * w02 + + "ext v10.16b, v8.16b, v21.16b, #4 \n" // v10 = {2,4,6,8} + + "ld2 {v6.4s, v7.4s}, [%[inptr3]], #32 \n" + + "fadd v16.4s, v16.4s, v11.4s \n" + "fadd v16.4s, v16.4s, v12.4s \n" + + // r4 + "fmla v13.4s, v8.4s, %[w2].s[0] \n" // {0,2,4,6} * w00 + "fmla v14.4s, v9.4s, %[w2].s[1] \n" // {1,3,5,7} * w01 + "fmla v17.4s, v10.4s, %[w2].s[2] \n" // {2,4,6,8} * w02 + + "ld2 {v8.4s, v9.4s}, [%[inptr4]], #32 \n" + "ld1 {v15.4s}, [%[inptr0]] \n" + "ld1 {v18.4s}, [%[inptr1]] \n" + "fmax v16.4s, v16.4s, %[vzero].4s \n" /* relu */ + + "fadd v17.4s, v17.4s, v13.4s \n" + + "ld1 {v19.4s}, [%[inptr2]] \n" + "ld1 {v20.4s}, [%[inptr3]] \n" + "ld1 {v21.4s}, [%[inptr4]] \n" + + "st1 {v16.4s}, [%[outptr0]], #16 \n" + + "fadd v17.4s, v17.4s, v14.4s \n" + + "ext v10.16b, v0.16b, v15.16b, #4 \n" // v10 = {2,4,6,8} + "and v16.16b, %[vbias].16b, %[vbias].16b \n" // v10 = vbias + "subs %[cnt], %[cnt], #1 \n" + + "fmax v17.4s, v17.4s, %[vzero].4s \n" /* relu */ + + "st1 {v17.4s}, [%[outptr1]], #16 \n" + + "and v17.16b, %[vbias].16b, %[vbias].16b \n" // v16 = vbias + + "bne 2b \n" + + // right + "1: \n" + "cmp %[remain], #1 \n" + "blt 4f \n" + "3: \n" + "bif v0.16b, %[vzero].16b, %[mask1].16b \n" // pipei + "bif v1.16b, %[vzero].16b, %[mask2].16b \n" // pipei + + "bif v2.16b, %[vzero].16b, %[mask1].16b \n" // pipei + "bif v3.16b, %[vzero].16b, %[mask2].16b \n" // pipei + + "bif v4.16b, %[vzero].16b, %[mask1].16b \n" // pipei + "bif v5.16b, %[vzero].16b, %[mask2].16b \n" // pipei + + "ext v10.16b, v0.16b, %[vzero].16b, #4 \n" // v10 = {2,4,6,8} + + "bif v6.16b, %[vzero].16b, %[mask1].16b \n" // pipei + "bif v7.16b, %[vzero].16b, %[mask2].16b \n" // pipei + + // r0 + "fmul v11.4s, v0.4s, %[w0].s[0] \n" // {0,2,4,6} * w00 + "fmul v12.4s, v1.4s, %[w0].s[1] \n" // {1,3,5,7} * w01 + "fmla v16.4s, v10.4s, %[w0].s[2] \n" // {2,4,6,8} * w02 + + "ext v10.16b, v2.16b, %[vzero].16b, #4 \n" // v10 = {2,4,6,8} + "bif v8.16b, %[vzero].16b, %[mask1].16b \n" // pipei + "bif v9.16b, %[vzero].16b, %[mask2].16b \n" // pipei + + // r1 + "fmla v11.4s, v2.4s, %[w1].s[0] \n" // {0,2,4,6} * w00 + "fmla v12.4s, v3.4s, %[w1].s[1] \n" // {1,3,5,7} * w01 + "fmla v16.4s, v10.4s, %[w1].s[2] \n" // {2,4,6,8} * w02 + + "ext 
v10.16b, v4.16b, %[vzero].16b, #4 \n" // v10 = {2,4,6,8} + + // r2 + "fmul v13.4s, v4.4s, %[w0].s[0] \n" // {0,2,4,6} * w00 + "fmla v11.4s, v4.4s, %[w2].s[0] \n" // {0,2,4,6} * w00 + + "fmul v14.4s, v5.4s, %[w0].s[1] \n" // {1,3,5,7} * w01 + "fmla v12.4s, v5.4s, %[w2].s[1] \n" // {1,3,5,7} * w01 + + "fmla v17.4s, v10.4s, %[w0].s[2] \n" // {2,4,6,8} * w02 + "fmla v16.4s, v10.4s, %[w2].s[2] \n" // {2,4,6,8} * w02 + + "ext v10.16b, v6.16b, %[vzero].16b, #4 \n" // v10 = {2,4,6,8} + + // r3 + "fmla v13.4s, v6.4s, %[w1].s[0] \n" // {0,2,4,6} * w00 + "fmla v14.4s, v7.4s, %[w1].s[1] \n" // {1,3,5,7} * w01 + "fmla v17.4s, v10.4s, %[w1].s[2] \n" // {2,4,6,8} * w02 + + "ext v10.16b, v8.16b, %[vzero].16b, #4 \n" // v10 = {2,4,6,8} + "ld1 {v0.4s}, [%[outptr0]] \n" + + "fadd v16.4s, v16.4s, v11.4s \n" + "fadd v16.4s, v16.4s, v12.4s \n" + "ld1 {v1.4s}, [%[outptr1]] \n" + + // r4 + "fmla v13.4s, v8.4s, %[w2].s[0] \n" // {0,2,4,6} * w00 + "fmla v14.4s, v9.4s, %[w2].s[1] \n" // {1,3,5,7} * w01 + "fmla v17.4s, v10.4s, %[w2].s[2] \n" // {2,4,6,8} * w02 + + "fmax v16.4s, v16.4s, %[vzero].4s \n" /* relu */ + + "fadd v17.4s, v17.4s, v13.4s \n" + + "bif v16.16b, v0.16b, %[wmask].16b \n" // pipei + + "fadd v17.4s, v17.4s, v14.4s \n" + + "st1 {v16.4s}, [%[outptr0]], #16 \n" + + "fmax v17.4s, v17.4s, %[vzero].4s \n" /* relu */ + + "bif v17.16b, v1.16b, %[wmask].16b \n" // pipei + + "st1 {v17.4s}, [%[outptr1]], #16 \n" + "4: \n" + : [inptr0] "+r"(din0_ptr), + [inptr1] "+r"(din1_ptr), + [inptr2] "+r"(din2_ptr), + [inptr3] "+r"(din3_ptr), + [inptr4] "+r"(din4_ptr), + [outptr0] "+r"(doutr0_ptr), + [outptr1] "+r"(doutr1_ptr), + [cnt] "+r"(cnt) + : [vzero] "w"(vzero), + [w0] "w"(wr0), + [w1] "w"(wr1), + [w2] "w"(wr2), + [remain] "r"(cnt_remain), + [mask1] "w"(vmask_rp1), + [mask2] "w"(vmask_rp2), + [wmask] "w"(wmask), + [vbias] "w"(wbias) + : "cc", + "memory", + "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15", + "v16", + "v17", + "v18", + "v19", + "v20", + "v21"); + doutr0 = doutr0 + 2 * w_out; + } +#else + + for (int i = 0; i < h_in; i += 2) { + din0_ptr = dr0; + din1_ptr = dr1; + din2_ptr = dr2; + + doutr0_ptr = doutr0; + + if (i == 0) { + din0_ptr = zero_ptr; + din1_ptr = dr0; + din2_ptr = dr1; + dr0 = dr1; + dr1 = dr2; + dr2 = dr1 + w_in; + } else { + dr0 = dr2; + dr1 = dr0 + w_in; + dr2 = dr1 + w_in; + } + + //! process bottom pad + if (i + 2 > h_in) { + switch (i + 2 - h_in) { + case 2: + din1_ptr = zero_ptr; + case 1: + din2_ptr = zero_ptr; + default: + break; + } + } + int cnt = cnt_col; + + unsigned int* mask_ptr = dmask; + asm volatile( + // top + // Load up 12 elements (3 vectors) from each of 8 sources. + "0: \n" + "vmov.u32 q9, #0 \n" + "vld2.32 {d20-d23}, [%[din0_ptr]]! @ load din r1\n" // v11={0,2,4,6} v12={1,3,5,7}, q10, q11 + "vld2.32 {d24-d27}, [%[din1_ptr]]! @ load din r1\n" // v11={0,2,4,6} v12={1,3,5,7}, q12, q13 + "vld2.32 {d28-d31}, [%[din2_ptr]]! 
@ load din r1\n" // v13={0,2,4,6} v14={1,3,5,7}, q14, q15 + "pld [%[din0_ptr]] @ preload data\n" + "pld [%[din1_ptr]] @ preload data\n" + "pld [%[din2_ptr]] @ preload data\n" + + "vdup.32 q3, %[bias] @ and \n" // q10 = + // vbias + + "vext.32 q6, q9, q11, #3 @ shift right 1 " + "data\n" // q2 = {0,1,3,5} + "vext.32 q7, q9, q13, #3 @ shift right 1 " + "data\n" // q6 = {0,1,3,5} + "vext.32 q8, q9, q15, #3 @ shift right 1 " + "data\n" // q6 = {0,1,3,5} + + "vmul.f32 q4, q10, %e[wr0][1] @ mul weight 1, " + "out0\n" // q11 * w01 + "vmul.f32 q5, q11, %f[wr0][0] @ mul weight 1, " + "out0\n" // q12 * w02 + "vmla.f32 q3, q6, %e[wr0][0] @ mul weight 1, " + "out0\n" // q6 * w00 + + "sub %[din0_ptr], #4 @ inpitr0 - 1\n" + "sub %[din1_ptr], #4 @ inpitr1 - 1\n" + "sub %[din2_ptr], #4 @ inpitr2 - 1\n" + + "vld2.32 {d20-d23}, [%[din0_ptr]]! @ load din r0\n" // v0={0,2,4,6} v1={1,3,5,7} + + "vmla.f32 q4, q12, %e[wr1][1] @ mul weight 1, " + "out0\n" // q11 * w01 + "vmla.f32 q5, q13, %f[wr1][0] @ mul weight 1, " + "out0\n" // q12 * w02 + "vmla.f32 q3, q7, %e[wr1][0] @ mul weight 1, " + "out0\n" // q6 * w00 + + "vld2.32 {d24-d27}, [%[din1_ptr]]! @ load din r1\n" // v4={0,2,4,6} v5={1,3,5,7} + + "vmla.f32 q4, q14, %e[wr2][1] @ mul weight 1, " + "out1\n" // q0 * w01 + "vmla.f32 q5, q15, %f[wr2][0] @ mul weight 1, " + "out1\n" // q1 * w02 + "vmla.f32 q3, q8, %e[wr2][0] @ mul weight 1, " + "out1\n" // q2 * w00 + + "vld2.32 {d28-d31}, [%[din2_ptr]]! @ load din r1\n" // v4={0,2,4,6} v5={1,3,5,7} + + "vadd.f32 q3, q3, q4 @ add \n" + "vadd.f32 q3, q3, q5 @ add \n" + + "vmax.f32 q3, q3, q9 @ relu \n" + + "vst1.32 {d6-d7}, [%[outptr]]! \n" + "cmp %[cnt], #1 \n" + "blt 1f \n" + // mid + "2: \n" + "vld1.32 {d16}, [%[din0_ptr]] @ load din r0\n" // q2={8,10,12,14} + "vdup.32 q3, %[bias] @ and \n" // q10 = + // vbias + "vext.32 q6, q10, q8, #1 @ shift left 1 \n" // q6 = {2,4,6,8} + "vld1.32 {d16}, [%[din1_ptr]] @ load din r1\n" // q2={8,10,12,14} + + "vmul.f32 q4, q10, %e[wr0][0] @ mul weight 0, " + "out0\n" // q0 * w00 + "vmul.f32 q5, q11, %e[wr0][1] @ mul weight 0, " + "out0\n" // q1 * w01 + "vmla.f32 q3, q6, %f[wr0][0] @ mul weight 0, " + "out0\n" // q6 * w02 + + "vext.32 q7, q12, q8, #1 @ shift left 1 \n" // q6 = {2,4,6,8} + "vld1.32 {d16}, [%[din2_ptr]] @ load din r1\n" // q2={8,10,12,14} + + "vld2.32 {d20-d23}, [%[din0_ptr]]! @ load din r0\n" // v0={0,2,4,6} v1={1,3,5,7} + + "vmla.f32 q4, q12, %e[wr1][0] @ mul weight 1, " + "out0\n" // q0 * w00 + "vmla.f32 q5, q13, %e[wr1][1] @ mul weight 1, " + "out0\n" // q1 * w01 + "vmla.f32 q3, q7, %f[wr1][0] @ mul weight 1, " + "out0\n" // q6 * w02 + + "vext.32 q6, q14, q8, #1 @ shift left 1 \n" // q6 = {2,4,6,8} + + "vld2.32 {d24-d27}, [%[din1_ptr]]! @ load din r1\n" // v0={0,2,4,6} v1={1,3,5,7} + + "vmla.f32 q4, q14, %e[wr2][0] @ mul weight 2, " + "out0\n" // q0 * w00 + "vmla.f32 q5, q15, %e[wr2][1] @ mul weight 2, " + "out0\n" // q1 * w01 + "vmla.f32 q3, q6, %f[wr2][0] @ mul weight 2, " + "out0\n" // q6 * w02 + + "vld2.32 {d28-d31}, [%[din2_ptr]]! @ load din r2\n" // v4={0,2,4,6} v5={1,3,5,7} + + "vadd.f32 q3, q3, q4 @ add \n" + "vadd.f32 q3, q3, q5 @ add \n" + + "vmax.f32 q3, q3, q9 @ relu \n" + + "subs %[cnt], #1 \n" + + "vst1.32 {d6-d7}, [%[outptr]]! \n" + "bne 2b \n" + + // right + "1: \n" + "cmp %[remain], #1 \n" + "blt 3f \n" + + "vld1.f32 {d12-d15}, [%[mask_ptr]]! 
@ load mask\n" + "vdup.32 q3, %[bias] @ and \n" // q10 = + // vbias + + "vbif q10, q9, q6 @ bit select, deal " + "with right pad\n" + "vbif q11, q9, q7 @ bit select, deal " + "with right pad\n" + "vbif q12, q9, q6 @ bit select, deal " + "with right pad\n" + "vbif q13, q9, q7 @ bit select, deal " + "with right pad\n" + "vbif q14, q9, q6 @ bit select, deal " + "with right pad\n" + "vbif q15, q9, q7 @ bit select, deal " + "with right pad\n" + + "vext.32 q6, q10, q9, #1 @ shift left 1 \n" // q6 = {2,4,6,8} + "vext.32 q7, q12, q9, #1 @ shift left 1 \n" // q6 = {2,4,6,8} + + "vmul.f32 q4, q10, %e[wr0][0] @ mul weight 0, " + "out0\n" // q0 * w00 + "vmul.f32 q5, q11, %e[wr0][1] @ mul weight 0, " + "out0\n" // q1 * w01 + "vmla.f32 q3, q6, %f[wr0][0] @ mul weight 0, " + "out0\n" // q6 * w02 + + "vext.32 q6, q14, q9, #1 @ shift left 1 \n" // q6 = {2,4,6,8} + "vld1.f32 {d20-d21}, [%[outptr]] @ load output\n" + + "vmla.f32 q4, q12, %e[wr1][0] @ mul weight 1, " + "out0\n" // q0 * w00 + "vmla.f32 q5, q13, %e[wr1][1] @ mul weight 1, " + "out0\n" // q1 * w01 + "vmla.f32 q3, q7, %f[wr1][0] @ mul weight 1, " + "out0\n" // q6 * w02 + + "vld1.f32 {d22-d23}, [%[mask_ptr]] @ load mask\n" + + "vmla.f32 q4, q14, %e[wr2][0] @ mul weight 2, " + "out0\n" // q0 * w00 + "vmla.f32 q5, q15, %e[wr2][1] @ mul weight 2, " + "out0\n" // q1 * w01 + "vmla.f32 q3, q6, %f[wr2][0] @ mul weight 2, " + "out0\n" // q6 * w02 + + "vadd.f32 q3, q3, q4 @ add \n" + "vadd.f32 q3, q3, q5 @ add \n" + + "vmax.f32 q3, q3, q9 @ relu \n" + + "vbif.f32 q3, q10, q11 @ write mask\n" + + "vst1.32 {d6-d7}, [%[outptr]]! \n" + "3: \n" + : [din0_ptr] "+r"(din0_ptr), + [din1_ptr] "+r"(din1_ptr), + [din2_ptr] "+r"(din2_ptr), + [outptr] "+r"(doutr0_ptr), + [cnt] "+r"(cnt), + [mask_ptr] "+r"(mask_ptr) + : [remain] "r"(cnt_remain), + [wr0] "w"(wr0), + [wr1] "w"(wr1), + [wr2] "w"(wr2), + [bias] "r"(bias_c) + : "cc", + "memory", + "q3", + "q4", + "q5", + "q6", + "q7", + "q8", + "q9", + "q10", + "q11", + "q12", + "q13", + "q14", + "q15"); + + doutr0 = doutr0 + w_out; + } +#endif + } + } +} +/** + * \brief depthwise convolution, kernel size 3x3, stride 1, pad 1, with bias, + * width <= 4 + */ +void conv_depthwise_3x3s1p1_bias_s(float* dout, + const float* din, + const float* weights, + const float* bias, + bool flag_bias, + const int num, + const int ch_in, + const int h_in, + const int w_in, + const int h_out, + const int w_out, + ARMContext* ctx) { + //! 3x3s1 convolution, implemented by direct algorithm + //! pad is done implicit + //! 
for 4x6 convolution window + const int right_pad_idx[4] = {3, 2, 1, 0}; + const float zero[4] = {0.f, 0.f, 0.f, 0.f}; + + float32x4_t vzero = vdupq_n_f32(0.f); + uint32x4_t vmask_rp = + vcgeq_s32(vld1q_s32(right_pad_idx), vdupq_n_s32(4 - w_in)); + int size_in_channel = w_in * h_in; + int size_out_channel = w_out * h_out; + for (int n = 0; n < num; ++n) { + const float* din_batch = din + n * ch_in * size_in_channel; + float* dout_batch = dout + n * ch_in * size_out_channel; +#pragma omp parallel for + for (int i = 0; i < ch_in; ++i) { + float* dout_channel = dout_batch + i * size_out_channel; + const float* din_channel = din_batch + i * size_in_channel; + const float* weight_ptr = weights + i * 9; + float32x4_t wr0 = vld1q_f32(weight_ptr); + float32x4_t wr1 = vld1q_f32(weight_ptr + 3); + float32x4_t wr2 = vld1q_f32(weight_ptr + 6); + float32x4_t wbias; + if (flag_bias) { + wbias = vdupq_n_f32(bias[i]); + } else { + wbias = vdupq_n_f32(0.f); + } + + int hs = -1; + int he = 3; + + float out_buf1[4]; + float out_buf2[4]; + float trash_buf[4]; + + int h_cnt = (h_out + 1) >> 1; + float* doutr0 = dout_channel; + float* doutr1 = dout_channel + w_out; + + for (int j = 0; j < h_cnt; ++j) { + const float* dr0 = din_channel + hs * w_in; + const float* dr1 = dr0 + w_in; + const float* dr2 = dr1 + w_in; + const float* dr3 = dr2 + w_in; + + if (hs == -1) { + dr0 = zero; + } + + switch (he - h_in) { + case 2: + dr2 = zero; + doutr1 = trash_buf; + case 1: + dr3 = zero; + default: + break; + } +#ifdef __aarch64__ + asm volatile( + "prfm pldl1keep, [%[din0]]\n" + "prfm pldl1keep, [%[din1]]\n" + "prfm pldl1keep, [%[din2]]\n" + "prfm pldl1keep, [%[din3]]\n" + + "ld1 {v0.4s}, [%[din0]], #16\n" + "ld1 {v1.4s}, [%[din1]], #16\n" + "ld1 {v2.4s}, [%[din2]], #16\n" + "ld1 {v3.4s}, [%[din3]], #16\n" + + "bif v0.16b, %[zero].16b, %[mask].16b\n" // d0_1234 + "bif v1.16b, %[zero].16b, %[mask].16b\n" // d1_1234 + "bif v2.16b, %[zero].16b, %[mask].16b\n" // d2_1234 + "bif v3.16b, %[zero].16b, %[mask].16b\n" // d3_1234 + + "ext v4.16b, %[zero].16b, v0.16b, #12\n" // d0_0123 + "ext v5.16b, %[zero].16b, v1.16b, #12\n" // d1_0123 + "ext v6.16b, %[zero].16b, v2.16b, #12\n" // d2_0123 + "ext v7.16b, %[zero].16b, v3.16b, #12\n" // d3_0123 + + "ext v8.16b, v0.16b, %[zero].16b, #4\n" // d0_2340 + "ext v9.16b, v1.16b, %[zero].16b, #4\n" // d1_2340 + "ext v10.16b, v2.16b, %[zero].16b, #4\n" // d2_2340 + "ext v11.16b, v3.16b, %[zero].16b, #4\n" // d3_2340 + + "fmul v12.4s, v0.4s, %[wr0].s[1]\n" + "fmul v13.4s, v1.4s, %[wr0].s[1]\n" + + "fmul v14.4s, v1.4s, %[wr1].s[1]\n" + "fmul v15.4s, v2.4s, %[wr1].s[1]\n" + + "fmul v16.4s, v2.4s, %[wr2].s[1]\n" + "fmul v17.4s, v3.4s, %[wr2].s[1]\n" + + "fmla v12.4s, v4.4s, %[wr0].s[0]\n" + "fmla v13.4s, v5.4s, %[wr0].s[0]\n" + + "fmla v14.4s, v5.4s, %[wr1].s[0]\n" + "fmla v15.4s, v6.4s, %[wr1].s[0]\n" + + "fmla v16.4s, v6.4s, %[wr2].s[0]\n" + "fmla v17.4s, v7.4s, %[wr2].s[0]\n" + + "fmla v12.4s, v8.4s, %[wr0].s[2]\n" + "fmla v13.4s, v9.4s, %[wr0].s[2]\n" + + "fmla v14.4s, v9.4s, %[wr1].s[2]\n" + "fmla v15.4s, v10.4s, %[wr1].s[2]\n" + + "fmla v16.4s, v10.4s, %[wr2].s[2]\n" + "fmla v17.4s, v11.4s, %[wr2].s[2]\n" + + "fadd v12.4s, v12.4s, v14.4s\n" + "fadd v12.4s, v12.4s, v16.4s\n" + + "fadd v13.4s, v13.4s, v15.4s\n" // out1 + "fadd v13.4s, v13.4s, v17.4s\n" // out2 + + "fadd v12.4s, v12.4s, %[bias].4s\n" // out1 add bias + "fadd v13.4s, v13.4s, %[bias].4s\n" // out2 add bias + + "prfm pldl1keep, [%[out1]]\n" + "prfm pldl1keep, [%[out2]]\n" + + "st1 {v12.4s}, [%[out1]]\n" + "st1 {v13.4s}, 
[%[out2]]\n" + + : [din0] "+r"(dr0), + [din1] "+r"(dr1), + [din2] "+r"(dr2), + [din3] "+r"(dr3) + : [wr0] "w"(wr0), + [wr1] "w"(wr1), + [wr2] "w"(wr2), + [zero] "w"(vzero), + [mask] "w"(vmask_rp), + [bias] "w"(wbias), + [out1] "r"(out_buf1), + [out2] "r"(out_buf2) + : "cc", + "memory", + "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15", + "v16", + "v17"); +#else + asm volatile( + "pld [%[din0]]\n" + "pld [%[din1]]\n" + "pld [%[din2]]\n" + "pld [%[din3]]\n" + + "vld1.32 {d12-d13}, [%[din0]]!\n" + "vld1.32 {d14-d15}, [%[din1]]!\n" + "vld1.32 {d16-d17}, [%[din2]]!\n" + "vld1.32 {d18-d19}, [%[din3]]!\n" + + "vbif q6, %q[zero], %q[mask]\n" // d0_1234 + "vbif q7, %q[zero], %q[mask]\n" // d1_1234 + "vbif q8, %q[zero], %q[mask]\n" // d2_1234 + "vbif q9, %q[zero], %q[mask]\n" // d3_1234 + + "vmul.f32 q14, q6, %e[wr0][1]\n" + "vmul.f32 q15, q7, %e[wr0][1]\n" + + "vmla.f32 q14, q7, %e[wr1][1]\n" + "vmla.f32 q15, q8, %e[wr1][1]\n" + + "vmla.f32 q14, q8, %e[wr2][1]\n" + "vmla.f32 q15, q9, %e[wr2][1]\n" + + "vext.32 q10, %q[zero], q6, #3\n" // d0_0123 + "vext.32 q11, %q[zero], q7, #3\n" // d1_0123 + "vext.32 q12, %q[zero], q8, #3\n" // d2_0123 + "vext.32 q13, %q[zero], q9, #3\n" // d3_0123 + + "vmla.f32 q14, q10, %e[wr0][0]\n" + "vmla.f32 q15, q11, %e[wr0][0]\n" + + "vmla.f32 q14, q11, %e[wr1][0]\n" + "vmla.f32 q15, q12, %e[wr1][0]\n" + + "vmla.f32 q14, q12, %e[wr2][0]\n" + "vmla.f32 q15, q13, %e[wr2][0]\n" + + "vext.32 q10, q6, %q[zero], #1\n" // d0_2340 + "vext.32 q11, q7, %q[zero], #1\n" // d1_2340 + "vext.32 q12, q8, %q[zero], #1\n" // d2_2340 + "vext.32 q13, q9, %q[zero], #1\n" // d3_2340 + + "vmla.f32 q14, q10, %f[wr0][0]\n" + "vmla.f32 q15, q11, %f[wr0][0]\n" + + "vmla.f32 q14, q11, %f[wr1][0]\n" + "vmla.f32 q15, q12, %f[wr1][0]\n" + + "vmla.f32 q14, q12, %f[wr2][0]\n" // out1 + "vmla.f32 q15, q13, %f[wr2][0]\n" // out2 + + "vadd.f32 q14, q14, %q[bias]\n" // out1 add bias + "vadd.f32 q15, q15, %q[bias]\n" // out2 add bias + + "pld [%[out1]]\n" + "pld [%[out2]]\n" + + "vst1.32 {d28-d29}, [%[out1]]\n" + "vst1.32 {d30-d31}, [%[out2]]\n" + + : [din0] "+r"(dr0), + [din1] "+r"(dr1), + [din2] "+r"(dr2), + [din3] "+r"(dr3) + : [wr0] "w"(wr0), + [wr1] "w"(wr1), + [wr2] "w"(wr2), + [zero] "w"(vzero), + [mask] "w"(vmask_rp), + [bias] "w"(wbias), + [out1] "r"(out_buf1), + [out2] "r"(out_buf2) + : "cc", + "memory", + "q6", + "q7", + "q8", + "q9", + "q10", + "q11", + "q12", + "q13", + "q14", + "q15"); +#endif // __aarch64__ + for (int w = 0; w < w_out; ++w) { + *doutr0++ = out_buf1[w]; + *doutr1++ = out_buf2[w]; + } + doutr0 = doutr1; + doutr1 += w_out; + hs += 2; + he += 2; + } // end of processing heights + } // end of processing channels + } // end of processing batchs +} +/** + * \brief depthwise convolution kernel 3x3, stride 2, width <= 4 + */ + +void conv_depthwise_3x3s2p1_bias_s(float* dout, + const float* din, + const float* weights, + const float* bias, + bool flag_bias, + const int num, + const int ch_in, + const int h_in, + const int w_in, + const int h_out, + const int w_out, + ARMContext* ctx) { + int right_pad_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7}; + int out_pad_idx[4] = {0, 1, 2, 3}; + float zeros[8] = {0.0f}; + + uint32x4_t vmask_rp1 = + vcgtq_s32(vdupq_n_s32(w_in), vld1q_s32(right_pad_idx)); // 0 2 4 6 + uint32x4_t vmask_rp2 = + vcgtq_s32(vdupq_n_s32(w_in), vld1q_s32(right_pad_idx + 4)); // 1 3 5 7 + + int size_in_channel = w_in * h_in; + int size_out_channel = w_out * h_out; + + unsigned int dmask[8]; + 
vst1q_u32(dmask, vmask_rp1); + vst1q_u32(dmask + 4, vmask_rp2); + + for (int n = 0; n < num; ++n) { + const float* din_batch = din + n * ch_in * size_in_channel; + float* dout_batch = dout + n * ch_in * size_out_channel; +#pragma omp parallel for + for (int i = 0; i < ch_in; ++i) { + const float* din_channel = din_batch + i * size_in_channel; + float* dout_channel = dout_batch + i * size_out_channel; + + const float* weight_ptr = weights + i * 9; + float32x4_t wr0 = vld1q_f32(weight_ptr); + float32x4_t wr1 = vld1q_f32(weight_ptr + 3); + float32x4_t wr2 = vld1q_f32(weight_ptr + 6); + + float bias_c = 0.f; + + if (flag_bias) { + bias_c = bias[i]; + } + float32x4_t vbias = vdupq_n_f32(bias_c); + int hs = -1; + int he = 2; + float out_buf[4]; + for (int j = 0; j < h_out; ++j) { + const float* dr0 = din_channel + hs * w_in; + const float* dr1 = dr0 + w_in; + const float* dr2 = dr1 + w_in; + if (hs == -1) { + dr0 = zeros; + } + if (he > h_in) { + dr2 = zeros; + } + const float* din0_ptr = dr0; + const float* din1_ptr = dr1; + const float* din2_ptr = dr2; + + unsigned int* mask_ptr = dmask; +#ifdef __aarch64__ + asm volatile( + // Load up 12 elements (3 vectors) from each of 8 sources. + "movi v9.4s, #0 \n" + "ld1 {v6.4s, v7.4s}, [%[mask_ptr]], #32 \n" + + "ld2 {v10.4s, v11.4s}, [%[din0_ptr]], #32 \n" // v10={0,2,4,6} + // v11={1,3,5,7} + "ld2 {v12.4s, v13.4s}, [%[din1_ptr]], #32 \n" // v13={0,2,4,6} + // v12={1,3,5,7} + "ld2 {v14.4s, v15.4s}, [%[din2_ptr]], #32 \n" // v14={0,2,4,6} + // v15={1,3,5,7} + + "bif v10.16b, v9.16b, v6.16b \n" + "bif v11.16b, v9.16b, v7.16b \n" + "bif v12.16b, v9.16b, v6.16b \n" + "bif v13.16b, v9.16b, v7.16b \n" + "bif v14.16b, v9.16b, v6.16b \n" + "bif v15.16b, v9.16b, v7.16b \n" + + "ext v6.16b, v9.16b, v11.16b, #12 \n" // v6 = + // {0,1,3,5} + "ext v7.16b, v9.16b, v13.16b, #12 \n" // v7 = + // {0,1,3,5} + "ext v8.16b, v9.16b, v15.16b, #12 \n" // v8 = + // {0,1,3,5} + + "fmul v4.4s, v10.4s, %[wr0].s[1] \n" // v10 * w01 + "fmul v5.4s, v11.4s, %[wr0].s[2] \n" // v11 * w02 + "fmul v6.4s, v6.4s, %[wr0].s[0] \n" // v6 * w00 + + "fmla v4.4s, v12.4s, %[wr1].s[1] \n" // v12 * w11 + "fmla v5.4s, v13.4s, %[wr1].s[2] \n" // v13 * w12 + "fmla v6.4s, v7.4s, %[wr1].s[0] \n" // v7 * w10 + + "fmla v4.4s, v14.4s, %[wr2].s[1] \n" // v14 * w20 + "fmla v5.4s, v15.4s, %[wr2].s[2] \n" // v15 * w21 + "fmla v6.4s, v8.4s, %[wr2].s[0] \n" // v8 * w22 + + "fadd v4.4s, v4.4s, v5.4s \n" + "fadd v4.4s, v4.4s, v6.4s \n" + + "fadd v4.4s, v4.4s, %[bias].4s \n" + + "st1 {v4.4s}, [%[out]] \n" + : [din0_ptr] "+r"(din0_ptr), + [din1_ptr] "+r"(din1_ptr), + [din2_ptr] "+r"(din2_ptr), + [mask_ptr] "+r"(mask_ptr) + : [wr0] "w"(wr0), + [wr1] "w"(wr1), + [wr2] "w"(wr2), + [bias] "w"(vbias), + [out] "r"(out_buf) + : "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15"); + +#else + asm volatile( + // Load up 12 elements (3 vectors) from each of 8 sources. + "vmov.u32 q9, #0 \n" + "vld1.f32 {d12-d15}, [%[mask_ptr]]! @ load mask\n" + "vdup.32 q3, %[bias] @ and \n" // q3 = + // vbias + + "vld2.32 {d20-d23}, [%[din0_ptr]]! @ load din r0\n" // q10={0,2,4,6} q11={1,3,5,7} + "vld2.32 {d24-d27}, [%[din1_ptr]]! @ load din r1\n" // q13={0,2,4,6} q12={1,3,5,7} + "vld2.32 {d28-d31}, [%[din2_ptr]]! 
@ load din r2\n" // q14={0,2,4,6} q15={1,3,5,7} + + "vbif q10, q9, q6 @ bit select, deal " + "with right pad\n" + "vbif q11, q9, q7 @ bit select, deal " + "with right pad\n" + "vbif q12, q9, q6 @ bit select, deal " + "with right pad\n" + "vbif q13, q9, q7 @ bit select, deal " + "with right pad\n" + "vbif q14, q9, q6 @ bit select, deal " + "with right pad\n" + "vbif q15, q9, q7 @ bit select, deal " + "with right pad\n" + + "vext.32 q6, q9, q11, #3 @ shift left 1 \n" // q6 = {0,1,3,5} + "vext.32 q7, q9, q13, #3 @ shift left 1 \n" // q7 = {0,1,3,5} + "vext.32 q8, q9, q15, #3 @ shift left 1 \n" // q8 = {0,1,3,5} + + "vmul.f32 q4, q10, %e[wr0][1] @ mul weight 0, " + "out0\n" // q10 * w01 + "vmul.f32 q5, q11, %f[wr0][0] @ mul weight 0, " + "out0\n" // q11 * w02 + "vmla.f32 q3, q6, %e[wr0][0] @ mul weight 0, " + "out0\n" // q6 * w00 + + "vmla.f32 q4, q12, %e[wr1][1] @ mul weight 1, " + "out0\n" // q12 * w11 + "vmla.f32 q5, q13, %f[wr1][0] @ mul weight 1, " + "out0\n" // q13 * w12 + "vmla.f32 q3, q7, %e[wr1][0] @ mul weight 1, " + "out0\n" // q7 * w10 + + "vmla.f32 q4, q14, %e[wr2][1] @ mul weight 2, " + "out0\n" // q14 * w20 + "vmla.f32 q5, q15, %f[wr2][0] @ mul weight 2, " + "out0\n" // q15 * w21 + "vmla.f32 q3, q8, %e[wr2][0] @ mul weight 2, " + "out0\n" // q8 * w22 + + "vadd.f32 q3, q3, q4 @ add \n" + "vadd.f32 q3, q3, q5 @ add \n" + + "vst1.32 {d6-d7}, [%[out]] \n" + : [din0_ptr] "+r"(din0_ptr), + [din1_ptr] "+r"(din1_ptr), + [din2_ptr] "+r"(din2_ptr), + [mask_ptr] "+r"(mask_ptr) + : [wr0] "w"(wr0), + [wr1] "w"(wr1), + [wr2] "w"(wr2), + [bias] "r"(bias_c), + [out] "r"(out_buf) + : "cc", + "memory", + "q3", + "q4", + "q5", + "q6", + "q7", + "q8", + "q9", + "q10", + "q11", + "q12", + "q13", + "q14", + "q15"); +#endif // __aarch64__ + for (int w = 0; w < w_out; ++w) { + *dout_channel++ = out_buf[w]; + } + hs += 2; + he += 2; + } + } + } +} +/** + * \brief depthwise convolution, kernel size 3x3, stride 1, pad 1, with bias, + * width <= 4 + */ +void conv_depthwise_3x3s1p1_bias_s_relu(float* dout, + const float* din, + const float* weights, + const float* bias, + bool flag_bias, + const int num, + const int ch_in, + const int h_in, + const int w_in, + const int h_out, + const int w_out, + ARMContext* ctx) { + //! 3x3s1 convolution, implemented by direct algorithm + //! pad is done implicit + //! 
for 4x6 convolution window + const int right_pad_idx[4] = {3, 2, 1, 0}; + const float zero[4] = {0.f, 0.f, 0.f, 0.f}; + + float32x4_t vzero = vdupq_n_f32(0.f); + uint32x4_t vmask_rp = + vcgeq_s32(vld1q_s32(right_pad_idx), vdupq_n_s32(4 - w_in)); + int size_in_channel = w_in * h_in; + int size_out_channel = w_out * h_out; + for (int n = 0; n < num; ++n) { + const float* din_batch = din + n * ch_in * size_in_channel; + float* dout_batch = dout + n * ch_in * size_out_channel; +#pragma omp parallel for + for (int i = 0; i < ch_in; ++i) { + float* dout_channel = dout_batch + i * size_out_channel; + const float* din_channel = din_batch + i * size_in_channel; + const float* weight_ptr = weights + i * 9; + float32x4_t wr0 = vld1q_f32(weight_ptr); + float32x4_t wr1 = vld1q_f32(weight_ptr + 3); + float32x4_t wr2 = vld1q_f32(weight_ptr + 6); + float32x4_t wbias; + if (flag_bias) { + wbias = vdupq_n_f32(bias[i]); + } else { + wbias = vdupq_n_f32(0.f); + } + + int hs = -1; + int he = 3; + + float out_buf1[4]; + float out_buf2[4]; + float trash_buf[4]; + + int h_cnt = (h_out + 1) >> 1; + float* doutr0 = dout_channel; + float* doutr1 = dout_channel + w_out; + + for (int j = 0; j < h_cnt; ++j) { + const float* dr0 = din_channel + hs * w_in; + const float* dr1 = dr0 + w_in; + const float* dr2 = dr1 + w_in; + const float* dr3 = dr2 + w_in; + + if (hs == -1) { + dr0 = zero; + } + + switch (he - h_in) { + case 2: + dr2 = zero; + doutr1 = trash_buf; + case 1: + dr3 = zero; + default: + break; + } +#ifdef __aarch64__ + asm volatile( + "prfm pldl1keep, [%[din0]]\n" + "prfm pldl1keep, [%[din1]]\n" + "prfm pldl1keep, [%[din2]]\n" + "prfm pldl1keep, [%[din3]]\n" + + "ld1 {v0.4s}, [%[din0]], #16\n" + "ld1 {v1.4s}, [%[din1]], #16\n" + "ld1 {v2.4s}, [%[din2]], #16\n" + "ld1 {v3.4s}, [%[din3]], #16\n" + + "bif v0.16b, %[zero].16b, %[mask].16b\n" // d0_1234 + "bif v1.16b, %[zero].16b, %[mask].16b\n" // d1_1234 + "bif v2.16b, %[zero].16b, %[mask].16b\n" // d2_1234 + "bif v3.16b, %[zero].16b, %[mask].16b\n" // d3_1234 + + "ext v4.16b, %[zero].16b, v0.16b, #12\n" // d0_0123 + "ext v5.16b, %[zero].16b, v1.16b, #12\n" // d1_0123 + "ext v6.16b, %[zero].16b, v2.16b, #12\n" // d2_0123 + "ext v7.16b, %[zero].16b, v3.16b, #12\n" // d3_0123 + + "ext v8.16b, v0.16b, %[zero].16b, #4\n" // d0_2340 + "ext v9.16b, v1.16b, %[zero].16b, #4\n" // d1_2340 + "ext v10.16b, v2.16b, %[zero].16b, #4\n" // d2_2340 + "ext v11.16b, v3.16b, %[zero].16b, #4\n" // d3_2340 + + "fmul v12.4s, v0.4s, %[wr0].s[1]\n" + "fmul v13.4s, v1.4s, %[wr0].s[1]\n" + + "fmul v14.4s, v1.4s, %[wr1].s[1]\n" + "fmul v15.4s, v2.4s, %[wr1].s[1]\n" + + "fmul v16.4s, v2.4s, %[wr2].s[1]\n" + "fmul v17.4s, v3.4s, %[wr2].s[1]\n" + + "fmla v12.4s, v4.4s, %[wr0].s[0]\n" + "fmla v13.4s, v5.4s, %[wr0].s[0]\n" + + "fmla v14.4s, v5.4s, %[wr1].s[0]\n" + "fmla v15.4s, v6.4s, %[wr1].s[0]\n" + + "fmla v16.4s, v6.4s, %[wr2].s[0]\n" + "fmla v17.4s, v7.4s, %[wr2].s[0]\n" + + "fmla v12.4s, v8.4s, %[wr0].s[2]\n" + "fmla v13.4s, v9.4s, %[wr0].s[2]\n" + + "fmla v14.4s, v9.4s, %[wr1].s[2]\n" + "fmla v15.4s, v10.4s, %[wr1].s[2]\n" + + "fmla v16.4s, v10.4s, %[wr2].s[2]\n" + "fmla v17.4s, v11.4s, %[wr2].s[2]\n" + + "fadd v12.4s, v12.4s, v14.4s\n" + "fadd v12.4s, v12.4s, v16.4s\n" + + "fadd v13.4s, v13.4s, v15.4s\n" // out1 + "fadd v13.4s, v13.4s, v17.4s\n" // out2 + + "fadd v12.4s, v12.4s, %[bias].4s\n" // out1 add bias + "fadd v13.4s, v13.4s, %[bias].4s\n" // out2 add bias + + "prfm pldl1keep, [%[out1]]\n" + "prfm pldl1keep, [%[out2]]\n" + + "fmax v12.4s, v12.4s, %[zero].4s\n" // out1 -> relu + 
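+          /* ReLU is fused into the epilogue: both accumulator rows are
+           * clamped against the zero register right before the stores,
+           * so no second pass over the output buffer is needed. */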
"fmax v13.4s, v13.4s, %[zero].4s\n" // out2 -> relu + + "st1 {v12.4s}, [%[out1]]\n" + "st1 {v13.4s}, [%[out2]]\n" + + : [din0] "+r"(dr0), + [din1] "+r"(dr1), + [din2] "+r"(dr2), + [din3] "+r"(dr3) + : [wr0] "w"(wr0), + [wr1] "w"(wr1), + [wr2] "w"(wr2), + [zero] "w"(vzero), + [mask] "w"(vmask_rp), + [bias] "w"(wbias), + [out1] "r"(out_buf1), + [out2] "r"(out_buf2) + : "cc", + "memory", + "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15", + "v16", + "v17"); +#else + asm volatile( + "pld [%[din0]]\n" + "pld [%[din1]]\n" + "pld [%[din2]]\n" + "pld [%[din3]]\n" + + "vld1.32 {d12-d13}, [%[din0]]!\n" + "vld1.32 {d14-d15}, [%[din1]]!\n" + "vld1.32 {d16-d17}, [%[din2]]!\n" + "vld1.32 {d18-d19}, [%[din3]]!\n" + + "vbif q6, %q[zero], %q[mask]\n" // d0_1234 + "vbif q7, %q[zero], %q[mask]\n" // d1_1234 + "vbif q8, %q[zero], %q[mask]\n" // d2_1234 + "vbif q9, %q[zero], %q[mask]\n" // d3_1234 + + "vmul.f32 q14, q6, %e[wr0][1]\n" + "vmul.f32 q15, q7, %e[wr0][1]\n" + + "vmla.f32 q14, q7, %e[wr1][1]\n" + "vmla.f32 q15, q8, %e[wr1][1]\n" + + "vmla.f32 q14, q8, %e[wr2][1]\n" + "vmla.f32 q15, q9, %e[wr2][1]\n" + + "vext.32 q10, %q[zero], q6, #3\n" // d0_0123 + "vext.32 q11, %q[zero], q7, #3\n" // d1_0123 + "vext.32 q12, %q[zero], q8, #3\n" // d2_0123 + "vext.32 q13, %q[zero], q9, #3\n" // d3_0123 + + "vmla.f32 q14, q10, %e[wr0][0]\n" + "vmla.f32 q15, q11, %e[wr0][0]\n" + + "vmla.f32 q14, q11, %e[wr1][0]\n" + "vmla.f32 q15, q12, %e[wr1][0]\n" + + "vmla.f32 q14, q12, %e[wr2][0]\n" + "vmla.f32 q15, q13, %e[wr2][0]\n" + + "vext.32 q10, q6, %q[zero], #1\n" // d0_2340 + "vext.32 q11, q7, %q[zero], #1\n" // d1_2340 + "vext.32 q12, q8, %q[zero], #1\n" // d2_2340 + "vext.32 q13, q9, %q[zero], #1\n" // d3_2340 + + "vmla.f32 q14, q10, %f[wr0][0]\n" + "vmla.f32 q15, q11, %f[wr0][0]\n" + + "vmla.f32 q14, q11, %f[wr1][0]\n" + "vmla.f32 q15, q12, %f[wr1][0]\n" + + "vmla.f32 q14, q12, %f[wr2][0]\n" // out1 + "vmla.f32 q15, q13, %f[wr2][0]\n" // out2 + + "vadd.f32 q14, q14, %q[bias]\n" // out1 add bias + "vadd.f32 q15, q15, %q[bias]\n" // out2 add bias + + "pld [%[out1]]\n" + "pld [%[out2]]\n" + + "vmax.f32 q14, q14, %q[zero]\n" // out1 -> relu + "vmax.f32 q15, q15, %q[zero]\n" // out2 -> relu + + "vst1.32 {d28-d29}, [%[out1]]\n" + "vst1.32 {d30-d31}, [%[out2]]\n" + + : [din0] "+r"(dr0), + [din1] "+r"(dr1), + [din2] "+r"(dr2), + [din3] "+r"(dr3) + : [wr0] "w"(wr0), + [wr1] "w"(wr1), + [wr2] "w"(wr2), + [zero] "w"(vzero), + [mask] "w"(vmask_rp), + [bias] "w"(wbias), + [out1] "r"(out_buf1), + [out2] "r"(out_buf2) + : "cc", + "memory", + "q6", + "q7", + "q8", + "q9", + "q10", + "q11", + "q12", + "q13", + "q14", + "q15"); +#endif // __aarch64__ + for (int w = 0; w < w_out; ++w) { + *doutr0++ = out_buf1[w]; + *doutr1++ = out_buf2[w]; + } + doutr0 = doutr1; + doutr1 += w_out; + hs += 2; + he += 2; + } // end of processing heights + } // end of processing channels + } // end of processing batchs +} + +/** + * \brief depthwise convolution kernel 3x3, stride 2, width <= 7 + */ +void conv_depthwise_3x3s2p1_bias_s_relu(float* dout, + const float* din, + const float* weights, + const float* bias, + bool flag_bias, + const int num, + const int ch_in, + const int h_in, + const int w_in, + const int h_out, + const int w_out, + ARMContext* ctx) { + int right_pad_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7}; + int out_pad_idx[4] = {0, 1, 2, 3}; + float zeros[8] = {0.0f}; + + uint32x4_t vmask_rp1 = + vcgtq_s32(vdupq_n_s32(w_in), vld1q_s32(right_pad_idx)); // 0 2 4 6 + 
uint32x4_t vmask_rp2 = + vcgtq_s32(vdupq_n_s32(w_in), vld1q_s32(right_pad_idx + 4)); // 1 3 5 7 + + int size_in_channel = w_in * h_in; + int size_out_channel = w_out * h_out; + + unsigned int dmask[8]; + vst1q_u32(dmask, vmask_rp1); + vst1q_u32(dmask + 4, vmask_rp2); + + for (int n = 0; n < num; ++n) { + const float* din_batch = din + n * ch_in * size_in_channel; + float* dout_batch = dout + n * ch_in * size_out_channel; +#pragma omp parallel for + for (int i = 0; i < ch_in; ++i) { + const float* din_channel = din_batch + i * size_in_channel; + float* dout_channel = dout_batch + i * size_out_channel; + + const float* weight_ptr = weights + i * 9; + float32x4_t wr0 = vld1q_f32(weight_ptr); + float32x4_t wr1 = vld1q_f32(weight_ptr + 3); + float32x4_t wr2 = vld1q_f32(weight_ptr + 6); + + float bias_c = 0.f; + + if (flag_bias) { + bias_c = bias[i]; + } + float32x4_t vbias = vdupq_n_f32(bias_c); + int hs = -1; + int he = 2; + float out_buf[4]; + for (int j = 0; j < h_out; ++j) { + const float* dr0 = din_channel + hs * w_in; + const float* dr1 = dr0 + w_in; + const float* dr2 = dr1 + w_in; + if (hs == -1) { + dr0 = zeros; + } + if (he > h_in) { + dr2 = zeros; + } + const float* din0_ptr = dr0; + const float* din1_ptr = dr1; + const float* din2_ptr = dr2; + + unsigned int* mask_ptr = dmask; +#ifdef __aarch64__ + asm volatile( + // Load up 12 elements (3 vectors) from each of 8 sources. + "movi v9.4s, #0 \n" + "ld1 {v6.4s, v7.4s}, [%[mask_ptr]], #32 \n" + + "ld2 {v10.4s, v11.4s}, [%[din0_ptr]], #32 \n" // v10={0,2,4,6} + // v11={1,3,5,7} + "ld2 {v12.4s, v13.4s}, [%[din1_ptr]], #32 \n" // v13={0,2,4,6} + // v12={1,3,5,7} + "ld2 {v14.4s, v15.4s}, [%[din2_ptr]], #32 \n" // v14={0,2,4,6} + // v15={1,3,5,7} + + "bif v10.16b, v9.16b, v6.16b \n" + "bif v11.16b, v9.16b, v7.16b \n" + "bif v12.16b, v9.16b, v6.16b \n" + "bif v13.16b, v9.16b, v7.16b \n" + "bif v14.16b, v9.16b, v6.16b \n" + "bif v15.16b, v9.16b, v7.16b \n" + + "ext v6.16b, v9.16b, v11.16b, #12 \n" // v6 = + // {0,1,3,5} + "ext v7.16b, v9.16b, v13.16b, #12 \n" // v7 = + // {0,1,3,5} + "ext v8.16b, v9.16b, v15.16b, #12 \n" // v8 = + // {0,1,3,5} + + "fmul v4.4s, v10.4s, %[wr0].s[1] \n" // v10 * w01 + "fmul v5.4s, v11.4s, %[wr0].s[2] \n" // v11 * w02 + "fmul v6.4s, v6.4s, %[wr0].s[0] \n" // v6 * w00 + + "fmla v4.4s, v12.4s, %[wr1].s[1] \n" // v12 * w11 + "fmla v5.4s, v13.4s, %[wr1].s[2] \n" // v13 * w12 + "fmla v6.4s, v7.4s, %[wr1].s[0] \n" // v7 * w10 + + "fmla v4.4s, v14.4s, %[wr2].s[1] \n" // v14 * w20 + "fmla v5.4s, v15.4s, %[wr2].s[2] \n" // v15 * w21 + "fmla v6.4s, v8.4s, %[wr2].s[0] \n" // v8 * w22 + + "fadd v4.4s, v4.4s, v5.4s \n" + "fadd v4.4s, v4.4s, v6.4s \n" + + "fadd v4.4s, v4.4s, %[bias].4s \n" // out add bias + "fmax v4.4s, v4.4s, v9.4s \n" + + "st1 {v4.4s}, [%[out]] \n" + : [din0_ptr] "+r"(din0_ptr), + [din1_ptr] "+r"(din1_ptr), + [din2_ptr] "+r"(din2_ptr), + [mask_ptr] "+r"(mask_ptr) + : [wr0] "w"(wr0), + [wr1] "w"(wr1), + [wr2] "w"(wr2), + [bias] "w"(vbias), + [out] "r"(out_buf) + : "cc", + "memory", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15"); + +#else + asm volatile( + // Load up 12 elements (3 vectors) from each of 8 sources. + "vmov.u32 q9, #0 \n" + "vld1.f32 {d12-d15}, [%[mask_ptr]]! @ load mask\n" + "vdup.32 q3, %[bias] @ and \n" // q3 = + // vbias + + "vld2.32 {d20-d23}, [%[din0_ptr]]! @ load din r0\n" // q10={0,2,4,6} q11={1,3,5,7} + "vld2.32 {d24-d27}, [%[din1_ptr]]! @ load din r1\n" // q13={0,2,4,6} q12={1,3,5,7} + "vld2.32 {d28-d31}, [%[din2_ptr]]! 
@ load din r2\n" // q14={0,2,4,6} q15={1,3,5,7} + + "vbif q10, q9, q6 @ bit select, deal " + "with right pad\n" + "vbif q11, q9, q7 @ bit select, deal " + "with right pad\n" + "vbif q12, q9, q6 @ bit select, deal " + "with right pad\n" + "vbif q13, q9, q7 @ bit select, deal " + "with right pad\n" + "vbif q14, q9, q6 @ bit select, deal " + "with right pad\n" + "vbif q15, q9, q7 @ bit select, deal " + "with right pad\n" + + "vext.32 q6, q9, q11, #3 @ shift left 1 \n" // q6 = {0,1,3,5} + "vext.32 q7, q9, q13, #3 @ shift left 1 \n" // q7 = {0,1,3,5} + "vext.32 q8, q9, q15, #3 @ shift left 1 \n" // q8 = {0,1,3,5} + + "vmul.f32 q4, q10, %e[wr0][1] @ mul weight 0, " + "out0\n" // q10 * w01 + "vmul.f32 q5, q11, %f[wr0][0] @ mul weight 0, " + "out0\n" // q11 * w02 + "vmla.f32 q3, q6, %e[wr0][0] @ mul weight 0, " + "out0\n" // q6 * w00 + + "vmla.f32 q4, q12, %e[wr1][1] @ mul weight 1, " + "out0\n" // q12 * w11 + "vmla.f32 q5, q13, %f[wr1][0] @ mul weight 1, " + "out0\n" // q13 * w12 + "vmla.f32 q3, q7, %e[wr1][0] @ mul weight 1, " + "out0\n" // q7 * w10 + + "vmla.f32 q4, q14, %e[wr2][1] @ mul weight 2, " + "out0\n" // q14 * w20 + "vmla.f32 q5, q15, %f[wr2][0] @ mul weight 2, " + "out0\n" // q15 * w21 + "vmla.f32 q3, q8, %e[wr2][0] @ mul weight 2, " + "out0\n" // q8 * w22 + + "vadd.f32 q3, q3, q4 @ add \n" + "vadd.f32 q3, q3, q5 @ add \n" + + "vmax.f32 q3, q3, q9 @ relu\n" + + "vst1.32 {d6-d7}, [%[out]] \n" + : [din0_ptr] "+r"(din0_ptr), + [din1_ptr] "+r"(din1_ptr), + [din2_ptr] "+r"(din2_ptr), + [mask_ptr] "+r"(mask_ptr) + : [wr0] "w"(wr0), + [wr1] "w"(wr1), + [wr2] "w"(wr2), + [bias] "r"(bias_c), + [out] "r"(out_buf) + : "cc", + "memory", + "q3", + "q4", + "q5", + "q6", + "q7", + "q8", + "q9", + "q10", + "q11", + "q12", + "q13", + "q14", + "q15"); +#endif // __aarch64__ + for (int w = 0; w < w_out; ++w) { + *dout_channel++ = out_buf[w]; + } + hs += 2; + he += 2; + } + } + } +} + +} // namespace math +} // namespace arm +} // namespace lite +} // namespace paddle diff --git a/lite/arm/math/conv_depthwise_5x5s1.cc b/lite/arm/math/conv_depthwise_5x5s1.cc new file mode 100644 index 00000000000..ac0fa08c8a2 --- /dev/null +++ b/lite/arm/math/conv_depthwise_5x5s1.cc @@ -0,0 +1,9615 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/arm/math/conv_depthwise.h" +#include + +namespace paddle { +namespace lite { +namespace arm { +namespace math { + +//! weights layout +//! *-----------------------*-----* +//! w0 <-- | W0 W1 W2 W3 | W4 | +//! *-----------------------* | +//! w1 <-- | W5 W6 W7 W8 | W9 | +//! *-----------------------* | --> w5 +//! w2 <-- | W10 W11 W12 W13 | W14 | +//! *-----------------------* | +//! w3 <-- | W15 W16 W17 W18 | W19 | +//! *-----------------------*-----* +//! w4 <-- | W20 W21 W22 W23 | W24 | --> w6[0] +//! 
*-----------------------*-----* + +void conv_depthwise_5x5s1_impl(const float* din, + float* dout, + int num, + int ch_out, + int h_out, + int w_out, + int ch_in, + int h_in, + int w_in, + const float* weights, + const float* bias, + int pad, + bool flag_bias, + bool flag_relu, + ARMContext* ctx); + +void conv_depthwise_5x5s1_small_impl(const float* din, + float* dout, + int num, + int ch_out, + int h_out, + int w_out, + int ch_in, + int h_in, + int w_in, + const float* weights, + const float* bias, + int pad, + bool flag_bias, + bool flag_relu, + ARMContext* ctx); + +void conv_depthwise_5x5s1_relu_impl(const float* din, + float* dout, + int num, + int ch_out, + int h_out, + int w_out, + int ch_in, + int h_in, + int w_in, + const float* weights, + const float* bias, + int pad, + bool flag_bias, + bool flag_relu, + ARMContext* ctx); + +void conv_depthwise_5x5s1_small_relu_impl(const float* din, + float* dout, + int num, + int ch_out, + int h_out, + int w_out, + int ch_in, + int h_in, + int w_in, + const float* weights, + const float* bias, + int pad, + bool flag_bias, + bool flag_relu, + ARMContext* ctx); + +static float* prepad_input( + const float* input, int num, int ch_in, int h_in, int w_in, int pad) { + int h_new = h_in + 2 * pad; + int w_new = w_in + 2 * pad; + float* new_input = + static_cast<float*>(malloc(h_new * w_new * ch_in * num * sizeof(float))); + float* new_input_ptr = new_input; + for (int c = 0; c < num * ch_in; ++c) { + memset(new_input_ptr, 0x00, w_new * pad * sizeof(float)); + new_input_ptr += w_new * pad; + for (int i = 0; i < h_in; ++i) { + memset(new_input_ptr, 0x00, pad * sizeof(float)); + new_input_ptr += pad; + memcpy(new_input_ptr, input, w_in * sizeof(float)); + new_input_ptr += w_in; + input += w_in; + memset(new_input_ptr, 0x00, pad * sizeof(float)); + new_input_ptr += pad; + } + memset(new_input_ptr, 0x00, w_new * pad * sizeof(float)); + new_input_ptr += w_new * pad; + } + return new_input; +} + +#ifdef __aarch64__ + +//! kernel for one out without extracting data mid +//! deal with four lines out +void compute_one_out_without_extract(const float* din0, + const float* din1, + const float* din2, + const float* din3, + const float* din4, + const float* din5, + const float* din6, + const float* din7, + float* dout0, + float* dout1, + float* dout2, + float* dout3, + float32x4_t w0, + float32x4_t w1, + float32x4_t w2, + float32x4_t w3, + float32x4_t w4, + float32x4_t w5, + float32x4_t w6, + const float* bias) { + //! din0 - din7: 0-4 v8-v15 + //! din0 - din7: 5 v20, v21 + //! 
dout0 - dout3: v16-v19 + asm volatile( + "ld1 {v8.4s}, [%[din0]], #16 \n" + "ld1 {v9.4s}, [%[din1]], #16 \n" + "ld1 {v10.4s}, [%[din2]], #16 \n" + "ld1 {v11.4s}, [%[din3]], #16 \n" + "ld1 {v12.4s}, [%[din4]], #16 \n" + "ld1 {v13.4s}, [%[din5]], #16 \n" + + // in row0 + "fmul v16.4s, %[w0].4s, v8.4s \n" + "fmul v17.4s, %[w0].4s, v9.4s \n" + "fmul v18.4s, %[w0].4s, v10.4s \n" + "fmul v19.4s, %[w0].4s, v11.4s \n" + + "ld1 {v14.4s}, [%[din6]], #16 \n" + "ld1 {v15.4s}, [%[din7]], #16 \n" + + // in row1 + "fmla v16.4s, %[w1].4s, v9.4s \n" + "fmla v17.4s, %[w1].4s, v10.4s \n" + "fmla v18.4s, %[w1].4s, v11.4s \n" + "fmla v19.4s, %[w1].4s, v12.4s \n" + + "ld1 {v20.s}[0], [%[din0]] \n" + "ld1 {v21.s}[0], [%[din4]] \n" + "ld1 {v20.s}[1], [%[din1]] \n" + "ld1 {v21.s}[1], [%[din5]] \n" + + // in row2 + "fmla v16.4s, %[w2].4s, v10.4s \n" + "fmla v17.4s, %[w2].4s, v11.4s \n" + "fmla v18.4s, %[w2].4s, v12.4s \n" + "fmla v19.4s, %[w2].4s, v13.4s \n" + + "ld1 {v20.s}[2], [%[din2]] \n" + "ld1 {v21.s}[2], [%[din6]] \n" + "ld1 {v20.s}[3], [%[din3]] \n" + "ld1 {v21.s}[3], [%[din7]] \n" + + // in row3 + "fmla v16.4s, %[w3].4s, v11.4s \n" + "fmla v17.4s, %[w3].4s, v12.4s \n" + "fmla v18.4s, %[w3].4s, v13.4s \n" + "fmla v19.4s, %[w3].4s, v14.4s \n" + + // in row4 + "fmla v16.4s, %[w4].4s, v12.4s \n" + "fmla v17.4s, %[w4].4s, v13.4s \n" + "fmla v18.4s, %[w4].4s, v14.4s \n" + "fmla v19.4s, %[w4].4s, v15.4s \n" + + // ext + "ext v22.16b, v20.16b, v21.16b, #4 \n" // 1 2 3 4 + "ext v23.16b, v20.16b, v21.16b, #8 \n" // 2 3 4 5 + "ext v24.16b, v20.16b, v21.16b, #12 \n" // 3 4 5 6 + + // in col5 + "fmla v16.4s, %[w5].4s, v20.4s \n" + "fmla v17.4s, %[w5].4s, v22.4s \n" + "fmla v18.4s, %[w5].4s, v23.4s \n" + "fmla v19.4s, %[w5].4s, v24.4s \n" + + "ld1 {v31.4s}, [%[bias]] \n" + + // add to out register v25 + "faddp v25.4s, v16.4s, v17.4s \n" + "faddp v26.4s, v18.4s, v19.4s \n" + "faddp v25.4s, v25.4s, v26.4s \n" + + // in[24] * w6[0] + "fmla v25.4s, v21.4s, %[w6].s[0]\n" + "fadd v25.4s, v25.4s, v31.4s \n" + + // write output + "st1 {v25.s}[0], [%[dout0]] \n" + "st1 {v25.s}[1], [%[dout1]] \n" + "st1 {v25.s}[2], [%[dout2]] \n" + "st1 {v25.s}[3], [%[dout3]] \n" + : [din0] "+r"(din0), + [din1] "+r"(din1), + [din2] "+r"(din2), + [din3] "+r"(din3), + [din4] "+r"(din4), + [din5] "+r"(din5), + [din6] "+r"(din6), + [din7] "+r"(din7) + : [dout0] "r"(dout0), + [dout1] "r"(dout1), + [dout2] "r"(dout2), + [dout3] "r"(dout3), + [w0] "w"(w0), + [w1] "w"(w1), + [w2] "w"(w2), + [w3] "w"(w3), + [w4] "w"(w4), + [w5] "w"(w5), + [w6] "w"(w6), + [bias] "r"(bias) + : "memory", + "v8", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15", + "v16", + "v17", + "v18", + "v19", + "v20", + "v21", + "v22", + "v23", + "v24", + "v25", + "v26", + "v31"); +} + +//! kernel for one out without extracting data mid +//! deal with four lines out +void compute_one_out_without_extract_relu(const float* din0, + const float* din1, + const float* din2, + const float* din3, + const float* din4, + const float* din5, + const float* din6, + const float* din7, + float* dout0, + float* dout1, + float* dout2, + float* dout3, + float32x4_t w0, + float32x4_t w1, + float32x4_t w2, + float32x4_t w3, + float32x4_t w4, + float32x4_t w5, + float32x4_t w6, + const float* bias) { + //! din0 - din7: 0-4 v8-v15 + //! din0 - din7: 5 v20, v21 + //! 
dout0 - dout3: v16-v19 + asm volatile( + "ld1 {v8.4s}, [%[din0]], #16 \n" + "ld1 {v9.4s}, [%[din1]], #16 \n" + "ld1 {v10.4s}, [%[din2]], #16 \n" + "ld1 {v11.4s}, [%[din3]], #16 \n" + "ld1 {v12.4s}, [%[din4]], #16 \n" + "ld1 {v13.4s}, [%[din5]], #16 \n" + + // in row0 + "fmul v16.4s, %[w0].4s, v8.4s \n" + "fmul v17.4s, %[w0].4s, v9.4s \n" + "fmul v18.4s, %[w0].4s, v10.4s \n" + "fmul v19.4s, %[w0].4s, v11.4s \n" + + "ld1 {v14.4s}, [%[din6]], #16 \n" + "ld1 {v15.4s}, [%[din7]], #16 \n" + + // in row1 + "fmla v16.4s, %[w1].4s, v9.4s \n" + "fmla v17.4s, %[w1].4s, v10.4s \n" + "fmla v18.4s, %[w1].4s, v11.4s \n" + "fmla v19.4s, %[w1].4s, v12.4s \n" + + "ld1 {v20.s}[0], [%[din0]] \n" + "ld1 {v21.s}[0], [%[din4]] \n" + "ld1 {v20.s}[1], [%[din1]] \n" + "ld1 {v21.s}[1], [%[din5]] \n" + + // in row2 + "fmla v16.4s, %[w2].4s, v10.4s \n" + "fmla v17.4s, %[w2].4s, v11.4s \n" + "fmla v18.4s, %[w2].4s, v12.4s \n" + "fmla v19.4s, %[w2].4s, v13.4s \n" + + "ld1 {v20.s}[2], [%[din2]] \n" + "ld1 {v21.s}[2], [%[din6]] \n" + "ld1 {v20.s}[3], [%[din3]] \n" + "ld1 {v21.s}[3], [%[din7]] \n" + + // in row3 + "fmla v16.4s, %[w3].4s, v11.4s \n" + "fmla v17.4s, %[w3].4s, v12.4s \n" + "fmla v18.4s, %[w3].4s, v13.4s \n" + "fmla v19.4s, %[w3].4s, v14.4s \n" + + // in row4 + "fmla v16.4s, %[w4].4s, v12.4s \n" + "fmla v17.4s, %[w4].4s, v13.4s \n" + "fmla v18.4s, %[w4].4s, v14.4s \n" + "fmla v19.4s, %[w4].4s, v15.4s \n" + + // ext + "ext v22.16b, v20.16b, v21.16b, #4 \n" // 1 2 3 4 + "ext v23.16b, v20.16b, v21.16b, #8 \n" // 2 3 4 5 + "ext v24.16b, v20.16b, v21.16b, #12 \n" // 3 4 5 6 + + // in col5 + "fmla v16.4s, %[w5].4s, v20.4s \n" + "fmla v17.4s, %[w5].4s, v22.4s \n" + "fmla v18.4s, %[w5].4s, v23.4s \n" + "fmla v19.4s, %[w5].4s, v24.4s \n" + + "ld1 {v31.4s}, [%[bias]] \n" + "movi v30.4s, #0 \n" + + // add to out register v25 + "faddp v25.4s, v16.4s, v17.4s \n" + "faddp v26.4s, v18.4s, v19.4s \n" + "faddp v25.4s, v25.4s, v26.4s \n" + + // in[24] * w6[0] + "fmla v25.4s, v21.4s, %[w6].s[0] \n" + "fadd v25.4s, v25.4s, v31.4s \n" + "fmax v25.4s, v25.4s, v30.4s \n" + + // write output + "st1 {v25.s}[0], [%[dout0]] \n" + "st1 {v25.s}[1], [%[dout1]] \n" + "st1 {v25.s}[2], [%[dout2]] \n" + "st1 {v25.s}[3], [%[dout3]] \n" + + : [din0] "+r"(din0), + [din1] "+r"(din1), + [din2] "+r"(din2), + [din3] "+r"(din3), + [din4] "+r"(din4), + [din5] "+r"(din5), + [din6] "+r"(din6), + [din7] "+r"(din7) + : [dout0] "r"(dout0), + [dout1] "r"(dout1), + [dout2] "r"(dout2), + [dout3] "r"(dout3), + [w0] "w"(w0), + [w1] "w"(w1), + [w2] "w"(w2), + [w3] "w"(w3), + [w4] "w"(w4), + [w5] "w"(w5), + [w6] "w"(w6), + [bias] "r"(bias) + : "memory", + "v8", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15", + "v16", + "v17", + "v18", + "v19", + "v20", + "v21", + "v22", + "v23", + "v24", + "v25", + "v26", + "v30", + "v31"); +} + +//! kernel for one out with extracting data pre +//! deal with four lines out +//! need extra load weights +void compute_one_out_extract_pre(const float* din0, + const float* din1, + const float* din2, + const float* din3, + const float* din4, + const float* din5, + const float* din6, + const float* din7, + float* dout0, + float* dout1, + float* dout2, + float* dout3, + const float* weights, + const float* bias) { + //! din0 - din7: 0-4 v8-v15 + //! dout0 - dout3: v16-v19 + //! 
weights: v0-v4 + asm volatile( + // load weights + "add %[wh], %[wh], #4 \n" + "ldr q0, [%[wh]], #20 \n" + "ldr q1, [%[wh]], #20 \n" + "ldr q2, [%[wh]], #20 \n" + "ldr q3, [%[wh]], #20 \n" + "ldr q4, [%[wh]], #20 \n" + + "ld1 {v31.4s}, [%[bias]] \n" + "ld1 {v8.4s}, [%[din0]], #16 \n" + "ld1 {v9.4s}, [%[din1]], #16 \n" + "ld1 {v10.4s}, [%[din2]], #16 \n" + "ld1 {v11.4s}, [%[din3]], #16 \n" + "ld1 {v12.4s}, [%[din4]], #16 \n" + "ld1 {v13.4s}, [%[din5]], #16 \n" + + // in row0 + "fmul v16.4s, v0.4s, v8.4s \n" + "fmul v17.4s, v0.4s, v9.4s \n" + "fmul v18.4s, v0.4s, v10.4s \n" + "fmul v19.4s, v0.4s, v11.4s \n" + + "ld1 {v14.4s}, [%[din6]], #16 \n" + "ld1 {v15.4s}, [%[din7]], #16 \n" + + // in row1 + "fmla v16.4s, v1.4s, v9.4s \n" + "fmla v17.4s, v1.4s, v10.4s \n" + "fmla v18.4s, v1.4s, v11.4s \n" + "fmla v19.4s, v1.4s, v12.4s \n" + + // in row2 + "fmla v16.4s, v2.4s, v10.4s \n" + "fmla v17.4s, v2.4s, v11.4s \n" + "fmla v18.4s, v2.4s, v12.4s \n" + "fmla v19.4s, v2.4s, v13.4s \n" + + // in row3 + "fmla v16.4s, v3.4s, v11.4s \n" + "fmla v17.4s, v3.4s, v12.4s \n" + "fmla v18.4s, v3.4s, v13.4s \n" + "fmla v19.4s, v3.4s, v14.4s \n" + + // in row4 + "fmla v16.4s, v4.4s, v12.4s \n" + "fmla v17.4s, v4.4s, v13.4s \n" + "fmla v18.4s, v4.4s, v14.4s \n" + "fmla v19.4s, v4.4s, v15.4s \n" + + // add to out register v25 + "faddp v25.4s, v16.4s, v17.4s \n" + "faddp v26.4s, v18.4s, v19.4s \n" + "faddp v25.4s, v25.4s, v26.4s \n" + "fadd v25.4s, v25.4s, v31.4s \n" + + // write output + "st1 {v25.s}[0], [%[dout0]] \n" + "st1 {v25.s}[1], [%[dout1]] \n" + "st1 {v25.s}[2], [%[dout2]] \n" + "st1 {v25.s}[3], [%[dout3]] \n" + + : [din0] "+r"(din0), + [din1] "+r"(din1), + [din2] "+r"(din2), + [din3] "+r"(din3), + [din4] "+r"(din4), + [din5] "+r"(din5), + [din6] "+r"(din6), + [din7] "+r"(din7), + [wh] "+r"(weights) + : [dout0] "r"(dout0), + [dout1] "r"(dout1), + [dout2] "r"(dout2), + [dout3] "r"(dout3), + [bias] "r"(bias) + : "memory", + "v0", + "v1", + "v2", + "v3", + "v4", + "v8", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15", + "v16", + "v17", + "v18", + "v19", + "v25", + "v26", + "v31"); +} + +//! kernel for one out with extracting data pre +//! deal with four lines out +//! need extra load weights +void compute_one_out_extract_pre_relu(const float* din0, + const float* din1, + const float* din2, + const float* din3, + const float* din4, + const float* din5, + const float* din6, + const float* din7, + float* dout0, + float* dout1, + float* dout2, + float* dout3, + const float* weights, + const float* bias) { + //! din0 - din7: 0-4 v8-v15 + //! dout0 - dout3: v16-v19 + //! 
weights: v0-v4 + asm volatile( + // load weights + "add %[wh], %[wh], #4 \n" + "ldr q0, [%[wh]], #20 \n" + "ldr q1, [%[wh]], #20 \n" + "ldr q2, [%[wh]], #20 \n" + "ldr q3, [%[wh]], #20 \n" + "ldr q4, [%[wh]], #20 \n" + + "ld1 {v8.4s}, [%[din0]], #16 \n" + "ld1 {v9.4s}, [%[din1]], #16 \n" + "ld1 {v10.4s}, [%[din2]], #16 \n" + "ld1 {v11.4s}, [%[din3]], #16 \n" + "ld1 {v12.4s}, [%[din4]], #16 \n" + "ld1 {v13.4s}, [%[din5]], #16 \n" + + // in row0 + "fmul v16.4s, v0.4s, v8.4s \n" + "fmul v17.4s, v0.4s, v9.4s \n" + "fmul v18.4s, v0.4s, v10.4s \n" + "fmul v19.4s, v0.4s, v11.4s \n" + + "ld1 {v14.4s}, [%[din6]], #16 \n" + "ld1 {v15.4s}, [%[din7]], #16 \n" + + // in row1 + "fmla v16.4s, v1.4s, v9.4s \n" + "fmla v17.4s, v1.4s, v10.4s \n" + "fmla v18.4s, v1.4s, v11.4s \n" + "fmla v19.4s, v1.4s, v12.4s \n" + + // in row2 + "fmla v16.4s, v2.4s, v10.4s \n" + "fmla v17.4s, v2.4s, v11.4s \n" + "fmla v18.4s, v2.4s, v12.4s \n" + "fmla v19.4s, v2.4s, v13.4s \n" + + // in row3 + "fmla v16.4s, v3.4s, v11.4s \n" + "fmla v17.4s, v3.4s, v12.4s \n" + "fmla v18.4s, v3.4s, v13.4s \n" + "fmla v19.4s, v3.4s, v14.4s \n" + + "ld1 {v31.4s}, [%[bias]] \n" + "movi v30.4s, #0 \n" + + // in row4 + "fmla v16.4s, v4.4s, v12.4s \n" + "fmla v17.4s, v4.4s, v13.4s \n" + "fmla v18.4s, v4.4s, v14.4s \n" + "fmla v19.4s, v4.4s, v15.4s \n" + + // add to out register v25 + "faddp v25.4s, v16.4s, v17.4s \n" + "faddp v26.4s, v18.4s, v19.4s \n" + "faddp v25.4s, v25.4s, v26.4s \n" + "fadd v25.4s, v25.4s, v31.4s \n" + "fmax v25.4s, v25.4s, v30.4s \n" + + // write output + "st1 {v25.s}[0], [%[dout0]] \n" + "st1 {v25.s}[1], [%[dout1]] \n" + "st1 {v25.s}[2], [%[dout2]] \n" + "st1 {v25.s}[3], [%[dout3]] \n" + + : [din0] "+r"(din0), + [din1] "+r"(din1), + [din2] "+r"(din2), + [din3] "+r"(din3), + [din4] "+r"(din4), + [din5] "+r"(din5), + [din6] "+r"(din6), + [din7] "+r"(din7), + [wh] "+r"(weights) + : [dout0] "r"(dout0), + [dout1] "r"(dout1), + [dout2] "r"(dout2), + [dout3] "r"(dout3), + [bias] "r"(bias) + : "memory", + "v0", + "v1", + "v2", + "v3", + "v4", + "v8", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15", + "v16", + "v17", + "v18", + "v19", + "v25", + "v26", + "v30", + "v31"); +} + +//! kernel for one out with extracting data post +//! deal with four lines out +void compute_one_out_extract_post(const float* din0, + const float* din1, + const float* din2, + const float* din3, + const float* din4, + const float* din5, + const float* din6, + const float* din7, + float* dout0, + float* dout1, + float* dout2, + float* dout3, + float32x4_t w0, + float32x4_t w1, + float32x4_t w2, + float32x4_t w3, + float32x4_t w4, + const float* bias) { + //! din0 - din7: 0-4 v8-v15 + //! 
dout0 - dout3: v16-v19 + asm volatile( + "ld1 {v31.4s}, [%[bias]] \n" + "ld1 {v8.4s}, [%[din0]], #16 \n" + "ld1 {v9.4s}, [%[din1]], #16 \n" + "ld1 {v10.4s}, [%[din2]], #16 \n" + "ld1 {v11.4s}, [%[din3]], #16 \n" + "ld1 {v12.4s}, [%[din4]], #16 \n" + "ld1 {v13.4s}, [%[din5]], #16 \n" + + // in row0 + "fmul v16.4s, %[w0].4s, v8.4s \n" + "fmul v17.4s, %[w0].4s, v9.4s \n" + "fmul v18.4s, %[w0].4s, v10.4s \n" + "fmul v19.4s, %[w0].4s, v11.4s \n" + + "ld1 {v14.4s}, [%[din6]], #16 \n" + "ld1 {v15.4s}, [%[din7]], #16 \n" + + // in row1 + "fmla v16.4s, %[w1].4s, v9.4s \n" + "fmla v17.4s, %[w1].4s, v10.4s \n" + "fmla v18.4s, %[w1].4s, v11.4s \n" + "fmla v19.4s, %[w1].4s, v12.4s \n" + + // in row2 + "fmla v16.4s, %[w2].4s, v10.4s \n" + "fmla v17.4s, %[w2].4s, v11.4s \n" + "fmla v18.4s, %[w2].4s, v12.4s \n" + "fmla v19.4s, %[w2].4s, v13.4s \n" + + // in row3 + "fmla v16.4s, %[w3].4s, v11.4s \n" + "fmla v17.4s, %[w3].4s, v12.4s \n" + "fmla v18.4s, %[w3].4s, v13.4s \n" + "fmla v19.4s, %[w3].4s, v14.4s \n" + + // in row4 + "fmla v16.4s, %[w4].4s, v12.4s \n" + "fmla v17.4s, %[w4].4s, v13.4s \n" + "fmla v18.4s, %[w4].4s, v14.4s \n" + "fmla v19.4s, %[w4].4s, v15.4s \n" + + // add to out register v25 + "faddp v25.4s, v16.4s, v17.4s \n" + "faddp v26.4s, v18.4s, v19.4s \n" + "faddp v25.4s, v25.4s, v26.4s \n" + "fadd v25.4s, v25.4s, v31.4s \n" + + // write output + "st1 {v25.s}[0], [%[dout0]] \n" + "st1 {v25.s}[1], [%[dout1]] \n" + "st1 {v25.s}[2], [%[dout2]] \n" + "st1 {v25.s}[3], [%[dout3]] \n" + + : [din0] "+r"(din0), + [din1] "+r"(din1), + [din2] "+r"(din2), + [din3] "+r"(din3), + [din4] "+r"(din4), + [din5] "+r"(din5), + [din6] "+r"(din6), + [din7] "+r"(din7) + : [dout0] "r"(dout0), + [dout1] "r"(dout1), + [dout2] "r"(dout2), + [dout3] "r"(dout3), + [w0] "w"(w0), + [w1] "w"(w1), + [w2] "w"(w2), + [w3] "w"(w3), + [w4] "w"(w4), + [bias] "r"(bias) + : "memory", + "v8", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15", + "v16", + "v17", + "v18", + "v19", + "v25", + "v26", + "v31"); +} + +//! kernel for one out with extracting data post +//! deal with four lines out +void compute_one_out_extract_post_relu(const float* din0, + const float* din1, + const float* din2, + const float* din3, + const float* din4, + const float* din5, + const float* din6, + const float* din7, + float* dout0, + float* dout1, + float* dout2, + float* dout3, + float32x4_t w0, + float32x4_t w1, + float32x4_t w2, + float32x4_t w3, + float32x4_t w4, + const float* bias) { + //! din0 - din7: 0-4 v8-v15 + //! 
dout0 - dout3: v16-v19 + asm volatile( + "ld1 {v8.4s}, [%[din0]], #16 \n" + "ld1 {v9.4s}, [%[din1]], #16 \n" + "ld1 {v10.4s}, [%[din2]], #16 \n" + "ld1 {v11.4s}, [%[din3]], #16 \n" + "ld1 {v12.4s}, [%[din4]], #16 \n" + "ld1 {v13.4s}, [%[din5]], #16 \n" + + // in row0 + "fmul v16.4s, %[w0].4s, v8.4s \n" + "fmul v17.4s, %[w0].4s, v9.4s \n" + "fmul v18.4s, %[w0].4s, v10.4s \n" + "fmul v19.4s, %[w0].4s, v11.4s \n" + + "ld1 {v14.4s}, [%[din6]], #16 \n" + "ld1 {v15.4s}, [%[din7]], #16 \n" + + // in row1 + "fmla v16.4s, %[w1].4s, v9.4s \n" + "fmla v17.4s, %[w1].4s, v10.4s \n" + "fmla v18.4s, %[w1].4s, v11.4s \n" + "fmla v19.4s, %[w1].4s, v12.4s \n" + + // in row2 + "fmla v16.4s, %[w2].4s, v10.4s \n" + "fmla v17.4s, %[w2].4s, v11.4s \n" + "fmla v18.4s, %[w2].4s, v12.4s \n" + "fmla v19.4s, %[w2].4s, v13.4s \n" + + // in row3 + "fmla v16.4s, %[w3].4s, v11.4s \n" + "fmla v17.4s, %[w3].4s, v12.4s \n" + "fmla v18.4s, %[w3].4s, v13.4s \n" + "fmla v19.4s, %[w3].4s, v14.4s \n" + + "ld1 {v31.4s}, [%[bias]] \n" + "movi v30.4s, #0 \n" + + // in row4 + "fmla v16.4s, %[w4].4s, v12.4s \n" + "fmla v17.4s, %[w4].4s, v13.4s \n" + "fmla v18.4s, %[w4].4s, v14.4s \n" + "fmla v19.4s, %[w4].4s, v15.4s \n" + + // add to out register v25 + "faddp v25.4s, v16.4s, v17.4s \n" + "faddp v26.4s, v18.4s, v19.4s \n" + "faddp v25.4s, v25.4s, v26.4s \n" + "fadd v25.4s, v25.4s, v31.4s \n" + "fmax v25.4s, v25.4s, v30.4s \n" + + // write output + "st1 {v25.s}[0], [%[dout0]] \n" + "st1 {v25.s}[1], [%[dout1]] \n" + "st1 {v25.s}[2], [%[dout2]] \n" + "st1 {v25.s}[3], [%[dout3]] \n" + + : [din0] "+r"(din0), + [din1] "+r"(din1), + [din2] "+r"(din2), + [din3] "+r"(din3), + [din4] "+r"(din4), + [din5] "+r"(din5), + [din6] "+r"(din6), + [din7] "+r"(din7) + : [dout0] "r"(dout0), + [dout1] "r"(dout1), + [dout2] "r"(dout2), + [dout3] "r"(dout3), + [w0] "w"(w0), + [w1] "w"(w1), + [w2] "w"(w2), + [w3] "w"(w3), + [w4] "w"(w4), + [bias] "r"(bias) + : "memory", + "v8", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15", + "v16", + "v17", + "v18", + "v19", + "v25", + "v26", + "v30", + "v31"); +} + +//! kernel for two out with extracting data pre +//! deal with four lines out +//! need extra load weights +void compute_two_out_extract_pre(const float* din0, + const float* din1, + const float* din2, + const float* din3, + const float* din4, + const float* din5, + const float* din6, + const float* din7, + float* dout0, + float* dout1, + float* dout2, + float* dout3, + const float* weights, + const float* bias) { + //! din0 - din7: 0-4 v8-v15 + //! dout0 - dout3: v16-v19 + //! 
weights: v0-v4 + asm volatile( + // load weights + "movi v31.4s, #0 \n" + "add %[wh], %[wh], #4 \n" + "ldr q0, [%[wh]], #20 \n" // 1, 2, 3, 4 + "ldr q1, [%[wh]], #20 \n" // 6, 7, 8, 9 + "ldr q2, [%[wh]], #20 \n" // 11, 12, 13, 14 + "ldr q3, [%[wh]], #20 \n" // 16, 17, 18, 19 + "ldr q4, [%[wh]], #20 \n" // 21, 22, 23, 24 + + // load inputs + "ld1 {v20.4s}, [%[bias]] \n" + "ld1 {v8.4s}, [%[din0]], #16 \n" + "ld1 {v9.4s}, [%[din1]], #16 \n" + "ld1 {v10.4s}, [%[din2]], #16 \n" + "ld1 {v11.4s}, [%[din3]], #16 \n" + "ld1 {v12.4s}, [%[din4]], #16 \n" + "ld1 {v13.4s}, [%[din5]], #16 \n" + + // in row0 + "fmul v16.4s, v0.4s, v8.4s \n" + "fmul v17.4s, v0.4s, v9.4s \n" + "fmul v18.4s, v0.4s, v10.4s \n" + "fmul v19.4s, v0.4s, v11.4s \n" + + "ld1 {v14.4s}, [%[din6]], #16 \n" + "ld1 {v15.4s}, [%[din7]], #16 \n" + + // in row1 + "fmla v16.4s, v1.4s, v9.4s \n" + "fmla v17.4s, v1.4s, v10.4s \n" + "fmla v18.4s, v1.4s, v11.4s \n" + "fmla v19.4s, v1.4s, v12.4s \n" + + // in row2 + "fmla v16.4s, v2.4s, v10.4s \n" + "fmla v17.4s, v2.4s, v11.4s \n" + "fmla v18.4s, v2.4s, v12.4s \n" + "fmla v19.4s, v2.4s, v13.4s \n" + + // in row3 + "fmla v16.4s, v3.4s, v11.4s \n" + "fmla v17.4s, v3.4s, v12.4s \n" + "fmla v18.4s, v3.4s, v13.4s \n" + "fmla v19.4s, v3.4s, v14.4s \n" + + // in row4 + "fmla v16.4s, v4.4s, v12.4s \n" + "fmla v17.4s, v4.4s, v13.4s \n" + "fmla v18.4s, v4.4s, v14.4s \n" + "fmla v19.4s, v4.4s, v15.4s \n" + + // add to out register v5 + "faddp v5.4s, v16.4s, v17.4s \n" + "faddp v6.4s, v18.4s, v19.4s \n" + "faddp v5.4s, v5.4s, v6.4s \n" + + // ext weights + "ext v0.16b, v0.16b, v31.16b, #4 \n" // 2, 3, 4 + "ext v1.16b, v1.16b, v31.16b, #4 \n" // 7, 8, 9 + "ext v2.16b, v2.16b, v31.16b, #4 \n" // 12, 13, 14 + "ext v3.16b, v3.16b, v31.16b, #4 \n" // 17, 18, 19 + "ext v4.16b, v4.16b, v31.16b, #4 \n" // 22, 23, 24 + + // in row0 + "fmul v16.4s, v0.4s, v8.4s \n" + "fmul v17.4s, v0.4s, v9.4s \n" + "fmul v18.4s, v0.4s, v10.4s \n" + "fmul v19.4s, v0.4s, v11.4s \n" + + // in row1 + "fmla v16.4s, v1.4s, v9.4s \n" + "fmla v17.4s, v1.4s, v10.4s \n" + "fmla v18.4s, v1.4s, v11.4s \n" + "fmla v19.4s, v1.4s, v12.4s \n" + + // in row2 + "fmla v16.4s, v2.4s, v10.4s \n" + "fmla v17.4s, v2.4s, v11.4s \n" + "fmla v18.4s, v2.4s, v12.4s \n" + "fmla v19.4s, v2.4s, v13.4s \n" + + // in row3 + "fmla v16.4s, v3.4s, v11.4s \n" + "fmla v17.4s, v3.4s, v12.4s \n" + "fmla v18.4s, v3.4s, v13.4s \n" + "fmla v19.4s, v3.4s, v14.4s \n" + + // in row4 + "fmla v16.4s, v4.4s, v12.4s \n" + "fmla v17.4s, v4.4s, v13.4s \n" + "fmla v18.4s, v4.4s, v14.4s \n" + "fmla v19.4s, v4.4s, v15.4s \n" + + // add to out register v7 + "faddp v7.4s, v16.4s, v17.4s \n" + "faddp v8.4s, v18.4s, v19.4s \n" + "faddp v7.4s, v7.4s, v8.4s \n" + + // zip + "zip1 v6.4s, v7.4s, v5.4s \n" + "zip2 v8.4s, v7.4s, v5.4s \n" + "fadd v6.4s, v6.4s, v20.4s \n" + "fadd v8.4s, v8.4s, v20.4s \n" + "ext v7.16b, v6.16b, v31.16b, #8 \n" + "ext v9.16b, v8.16b, v31.16b, #8 \n" + + // write output + "str d6, [%[dout0]] \n" + "str d7, [%[dout1]] \n" + "str d8, [%[dout2]] \n" + "str d9, [%[dout3]] \n" + + : [din0] "+r"(din0), + [din1] "+r"(din1), + [din2] "+r"(din2), + [din3] "+r"(din3), + [din4] "+r"(din4), + [din5] "+r"(din5), + [din6] "+r"(din6), + [din7] "+r"(din7), + [wh] "+r"(weights) + : [dout0] "r"(dout0), + [dout1] "r"(dout1), + [dout2] "r"(dout2), + [dout3] "r"(dout3), + [bias] "r"(bias) + : "memory", + "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15", + "v16", + "v17", + "v18", + "v19", + "v20", + 
"v31"); +} + +//! kernel for two out with extracting data pre +//! deal with four lines out +//! need extra load weights +void compute_two_out_extract_pre_relu(const float* din0, + const float* din1, + const float* din2, + const float* din3, + const float* din4, + const float* din5, + const float* din6, + const float* din7, + float* dout0, + float* dout1, + float* dout2, + float* dout3, + const float* weights, + const float* bias) { + //! din0 - din7: 0-4 v8-v15 + //! dout0 - dout3: v16-v19 + //! weights: v0-v4 + asm volatile( + // load weights + "movi v31.4s, #0 \n" + "add %[wh], %[wh], #4 \n" + "ldr q0, [%[wh]], #20 \n" // 1, 2, 3, 4 + "ldr q1, [%[wh]], #20 \n" // 6, 7, 8, 9 + "ldr q2, [%[wh]], #20 \n" // 11, 12, 13, 14 + "ldr q3, [%[wh]], #20 \n" // 16, 17, 18, 19 + "ldr q4, [%[wh]], #20 \n" // 21, 22, 23, 24 + + // load inputs + "ld1 {v20.4s}, [%[bias]] \n" + "ld1 {v8.4s}, [%[din0]], #16 \n" + "ld1 {v9.4s}, [%[din1]], #16 \n" + "ld1 {v10.4s}, [%[din2]], #16 \n" + "ld1 {v11.4s}, [%[din3]], #16 \n" + "ld1 {v12.4s}, [%[din4]], #16 \n" + "ld1 {v13.4s}, [%[din5]], #16 \n" + + // in row0 + "fmul v16.4s, v0.4s, v8.4s \n" + "fmul v17.4s, v0.4s, v9.4s \n" + "fmul v18.4s, v0.4s, v10.4s \n" + "fmul v19.4s, v0.4s, v11.4s \n" + + "ld1 {v14.4s}, [%[din6]], #16 \n" + "ld1 {v15.4s}, [%[din7]], #16 \n" + + // in row1 + "fmla v16.4s, v1.4s, v9.4s \n" + "fmla v17.4s, v1.4s, v10.4s \n" + "fmla v18.4s, v1.4s, v11.4s \n" + "fmla v19.4s, v1.4s, v12.4s \n" + + // in row2 + "fmla v16.4s, v2.4s, v10.4s \n" + "fmla v17.4s, v2.4s, v11.4s \n" + "fmla v18.4s, v2.4s, v12.4s \n" + "fmla v19.4s, v2.4s, v13.4s \n" + + // in row3 + "fmla v16.4s, v3.4s, v11.4s \n" + "fmla v17.4s, v3.4s, v12.4s \n" + "fmla v18.4s, v3.4s, v13.4s \n" + "fmla v19.4s, v3.4s, v14.4s \n" + + // in row4 + "fmla v16.4s, v4.4s, v12.4s \n" + "fmla v17.4s, v4.4s, v13.4s \n" + "fmla v18.4s, v4.4s, v14.4s \n" + "fmla v19.4s, v4.4s, v15.4s \n" + + // add to out register v5 + "faddp v5.4s, v16.4s, v17.4s \n" + "faddp v6.4s, v18.4s, v19.4s \n" + "faddp v5.4s, v5.4s, v6.4s \n" + + // ext weights + "ext v0.16b, v0.16b, v31.16b, #4 \n" // 2, 3, 4 + "ext v1.16b, v1.16b, v31.16b, #4 \n" // 7, 8, 9 + "ext v2.16b, v2.16b, v31.16b, #4 \n" // 12, 13, 14 + "ext v3.16b, v3.16b, v31.16b, #4 \n" // 17, 18, 19 + "ext v4.16b, v4.16b, v31.16b, #4 \n" // 22, 23, 24 + + // in row0 + "fmul v16.4s, v0.4s, v8.4s \n" + "fmul v17.4s, v0.4s, v9.4s \n" + "fmul v18.4s, v0.4s, v10.4s \n" + "fmul v19.4s, v0.4s, v11.4s \n" + + // in row1 + "fmla v16.4s, v1.4s, v9.4s \n" + "fmla v17.4s, v1.4s, v10.4s \n" + "fmla v18.4s, v1.4s, v11.4s \n" + "fmla v19.4s, v1.4s, v12.4s \n" + + // in row2 + "fmla v16.4s, v2.4s, v10.4s \n" + "fmla v17.4s, v2.4s, v11.4s \n" + "fmla v18.4s, v2.4s, v12.4s \n" + "fmla v19.4s, v2.4s, v13.4s \n" + + // in row3 + "fmla v16.4s, v3.4s, v11.4s \n" + "fmla v17.4s, v3.4s, v12.4s \n" + "fmla v18.4s, v3.4s, v13.4s \n" + "fmla v19.4s, v3.4s, v14.4s \n" + + // in row4 + "fmla v16.4s, v4.4s, v12.4s \n" + "fmla v17.4s, v4.4s, v13.4s \n" + "fmla v18.4s, v4.4s, v14.4s \n" + "fmla v19.4s, v4.4s, v15.4s \n" + + // add to out register v7 + "faddp v7.4s, v16.4s, v17.4s \n" + "faddp v8.4s, v18.4s, v19.4s \n" + "faddp v7.4s, v7.4s, v8.4s \n" + + // zip + "zip1 v6.4s, v7.4s, v5.4s \n" + "zip2 v8.4s, v7.4s, v5.4s \n" + + // add bias + "fadd v6.4s, v6.4s, v20.4s \n" + "fadd v8.4s, v8.4s, v20.4s \n" + + // relu + "fmax v6.4s, v6.4s, v31.4s \n" + "fmax v8.4s, v8.4s, v31.4s \n" + + "ext v7.16b, v6.16b, v31.16b, #8 \n" + "ext v9.16b, v8.16b, v31.16b, #8 \n" + + // write output + "str 
d6, [%[dout0]] \n" + "str d7, [%[dout1]] \n" + "str d8, [%[dout2]] \n" + "str d9, [%[dout3]] \n" + + : [din0] "+r"(din0), + [din1] "+r"(din1), + [din2] "+r"(din2), + [din3] "+r"(din3), + [din4] "+r"(din4), + [din5] "+r"(din5), + [din6] "+r"(din6), + [din7] "+r"(din7), + [wh] "+r"(weights) + : [dout0] "r"(dout0), + [dout1] "r"(dout1), + [dout2] "r"(dout2), + [dout3] "r"(dout3), + [bias] "r"(bias) + : "memory", + "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15", + "v16", + "v17", + "v18", + "v19", + "v20", + "v31"); +} + +//! kernel for two out with extracting data post +//! deal with four lines out +void compute_two_out_extract_post(const float* din0, + const float* din1, + const float* din2, + const float* din3, + const float* din4, + const float* din5, + const float* din6, + const float* din7, + float* dout0, + float* dout1, + float* dout2, + float* dout3, + float32x4_t w0, + float32x4_t w1, + float32x4_t w2, + float32x4_t w3, + float32x4_t w4, + const float* bias) { + //! din0 - din7: 0-4 v8-v15 + //! dout0 - dout3: v16-v19 + asm volatile( + "movi v31.4s, #0 \n" + + // load inputs + "ld1 {v20.4s}, [%[bias]] \n" + "ld1 {v8.4s}, [%[din0]], #16 \n" + "ld1 {v9.4s}, [%[din1]], #16 \n" + "ld1 {v10.4s}, [%[din2]], #16 \n" + "ld1 {v11.4s}, [%[din3]], #16 \n" + "ld1 {v12.4s}, [%[din4]], #16 \n" + "ld1 {v13.4s}, [%[din5]], #16 \n" + + // in row0 + "fmul v16.4s, %[w0].4s, v8.4s \n" + "fmul v17.4s, %[w0].4s, v9.4s \n" + "fmul v18.4s, %[w0].4s, v10.4s \n" + "fmul v19.4s, %[w0].4s, v11.4s \n" + + "ld1 {v14.4s}, [%[din6]], #16 \n" + "ld1 {v15.4s}, [%[din7]], #16 \n" + + // in row1 + "fmla v16.4s, %[w1].4s, v9.4s \n" + "fmla v17.4s, %[w1].4s, v10.4s \n" + "fmla v18.4s, %[w1].4s, v11.4s \n" + "fmla v19.4s, %[w1].4s, v12.4s \n" + + // in row2 + "fmla v16.4s, %[w2].4s, v10.4s \n" + "fmla v17.4s, %[w2].4s, v11.4s \n" + "fmla v18.4s, %[w2].4s, v12.4s \n" + "fmla v19.4s, %[w2].4s, v13.4s \n" + + // in row3 + "fmla v16.4s, %[w3].4s, v11.4s \n" + "fmla v17.4s, %[w3].4s, v12.4s \n" + "fmla v18.4s, %[w3].4s, v13.4s \n" + "fmla v19.4s, %[w3].4s, v14.4s \n" + + // in row4 + "fmla v16.4s, %[w4].4s, v12.4s \n" + "fmla v17.4s, %[w4].4s, v13.4s \n" + "fmla v18.4s, %[w4].4s, v14.4s \n" + "fmla v19.4s, %[w4].4s, v15.4s \n" + + // add to out register v5 + "faddp v5.4s, v16.4s, v17.4s \n" + "faddp v6.4s, v18.4s, v19.4s \n" + "faddp v5.4s, v5.4s, v6.4s \n" + + // ext input + "ext v8.16b, v8.16b, v31.16b, #4 \n" + "ext v9.16b, v9.16b, v31.16b, #4 \n" + "ext v10.16b, v10.16b, v31.16b, #4 \n" + "ext v11.16b, v11.16b, v31.16b, #4 \n" + "ext v12.16b, v12.16b, v31.16b, #4 \n" + "ext v13.16b, v13.16b, v31.16b, #4 \n" + "ext v14.16b, v14.16b, v31.16b, #4 \n" + "ext v15.16b, v15.16b, v31.16b, #4 \n" + + // in row0 + "fmul v16.4s, %[w0].4s, v8.4s \n" + "fmul v17.4s, %[w0].4s, v9.4s \n" + "fmul v18.4s, %[w0].4s, v10.4s \n" + "fmul v19.4s, %[w0].4s, v11.4s \n" + + // in row1 + "fmla v16.4s, %[w1].4s, v9.4s \n" + "fmla v17.4s, %[w1].4s, v10.4s \n" + "fmla v18.4s, %[w1].4s, v11.4s \n" + "fmla v19.4s, %[w1].4s, v12.4s \n" + + // in row2 + "fmla v16.4s, %[w2].4s, v10.4s \n" + "fmla v17.4s, %[w2].4s, v11.4s \n" + "fmla v18.4s, %[w2].4s, v12.4s \n" + "fmla v19.4s, %[w2].4s, v13.4s \n" + + // in row3 + "fmla v16.4s, %[w3].4s, v11.4s \n" + "fmla v17.4s, %[w3].4s, v12.4s \n" + "fmla v18.4s, %[w3].4s, v13.4s \n" + "fmla v19.4s, %[w3].4s, v14.4s \n" + + // in row4 + "fmla v16.4s, %[w4].4s, v12.4s \n" + "fmla v17.4s, %[w4].4s, v13.4s \n" + "fmla v18.4s, %[w4].4s, 
v14.4s \n" + "fmla v19.4s, %[w4].4s, v15.4s \n" + + // add to out register v7 + "faddp v7.4s, v16.4s, v17.4s \n" + "faddp v8.4s, v18.4s, v19.4s \n" + "faddp v7.4s, v7.4s, v8.4s \n" + + // zip + "zip1 v6.4s, v5.4s, v7.4s \n" + "zip2 v8.4s, v5.4s, v7.4s \n" + "fadd v6.4s, v6.4s, v20.4s \n" + "fadd v8.4s, v8.4s, v20.4s \n" + "ext v7.16b, v6.16b, v31.16b, #8 \n" + "ext v9.16b, v8.16b, v31.16b, #8 \n" + + // write output + "str d6, [%[dout0]] \n" + "str d7, [%[dout1]] \n" + "str d8, [%[dout2]] \n" + "str d9, [%[dout3]] \n" + + : [din0] "+r"(din0), + [din1] "+r"(din1), + [din2] "+r"(din2), + [din3] "+r"(din3), + [din4] "+r"(din4), + [din5] "+r"(din5), + [din6] "+r"(din6), + [din7] "+r"(din7) + : [dout0] "r"(dout0), + [dout1] "r"(dout1), + [dout2] "r"(dout2), + [dout3] "r"(dout3), + [w0] "w"(w0), + [w1] "w"(w1), + [w2] "w"(w2), + [w3] "w"(w3), + [w4] "w"(w4), + [bias] "r"(bias) + : "memory", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15", + "v16", + "v17", + "v18", + "v19", + "v20", + "v31"); +} + +//! kernel for two out with extracting data post +//! deal with four lines out +void compute_two_out_extract_post_relu(const float* din0, + const float* din1, + const float* din2, + const float* din3, + const float* din4, + const float* din5, + const float* din6, + const float* din7, + float* dout0, + float* dout1, + float* dout2, + float* dout3, + float32x4_t w0, + float32x4_t w1, + float32x4_t w2, + float32x4_t w3, + float32x4_t w4, + const float* bias) { + //! din0 - din7: 0-4 v8-v15 + //! dout0 - dout3: v16-v19 + asm volatile( + "movi v31.4s, #0 \n" + + // load inputs + "ld1 {v20.4s}, [%[bias]] \n" + "ld1 {v8.4s}, [%[din0]], #16 \n" + "ld1 {v9.4s}, [%[din1]], #16 \n" + "ld1 {v10.4s}, [%[din2]], #16 \n" + "ld1 {v11.4s}, [%[din3]], #16 \n" + "ld1 {v12.4s}, [%[din4]], #16 \n" + "ld1 {v13.4s}, [%[din5]], #16 \n" + + // in row0 + "fmul v16.4s, %[w0].4s, v8.4s \n" + "fmul v17.4s, %[w0].4s, v9.4s \n" + "fmul v18.4s, %[w0].4s, v10.4s \n" + "fmul v19.4s, %[w0].4s, v11.4s \n" + + "ld1 {v14.4s}, [%[din6]], #16 \n" + "ld1 {v15.4s}, [%[din7]], #16 \n" + + // in row1 + "fmla v16.4s, %[w1].4s, v9.4s \n" + "fmla v17.4s, %[w1].4s, v10.4s \n" + "fmla v18.4s, %[w1].4s, v11.4s \n" + "fmla v19.4s, %[w1].4s, v12.4s \n" + + // in row2 + "fmla v16.4s, %[w2].4s, v10.4s \n" + "fmla v17.4s, %[w2].4s, v11.4s \n" + "fmla v18.4s, %[w2].4s, v12.4s \n" + "fmla v19.4s, %[w2].4s, v13.4s \n" + + // in row3 + "fmla v16.4s, %[w3].4s, v11.4s \n" + "fmla v17.4s, %[w3].4s, v12.4s \n" + "fmla v18.4s, %[w3].4s, v13.4s \n" + "fmla v19.4s, %[w3].4s, v14.4s \n" + + // in row4 + "fmla v16.4s, %[w4].4s, v12.4s \n" + "fmla v17.4s, %[w4].4s, v13.4s \n" + "fmla v18.4s, %[w4].4s, v14.4s \n" + "fmla v19.4s, %[w4].4s, v15.4s \n" + + // add to out register v5 + "faddp v5.4s, v16.4s, v17.4s \n" + "faddp v6.4s, v18.4s, v19.4s \n" + "faddp v5.4s, v5.4s, v6.4s \n" + + // ext input + "ext v8.16b, v8.16b, v31.16b, #4 \n" + "ext v9.16b, v9.16b, v31.16b, #4 \n" + "ext v10.16b, v10.16b, v31.16b, #4 \n" + "ext v11.16b, v11.16b, v31.16b, #4 \n" + "ext v12.16b, v12.16b, v31.16b, #4 \n" + "ext v13.16b, v13.16b, v31.16b, #4 \n" + "ext v14.16b, v14.16b, v31.16b, #4 \n" + "ext v15.16b, v15.16b, v31.16b, #4 \n" + + // in row0 + "fmul v16.4s, %[w0].4s, v8.4s \n" + "fmul v17.4s, %[w0].4s, v9.4s \n" + "fmul v18.4s, %[w0].4s, v10.4s \n" + "fmul v19.4s, %[w0].4s, v11.4s \n" + + // in row1 + "fmla v16.4s, %[w1].4s, v9.4s \n" + "fmla v17.4s, %[w1].4s, v10.4s \n" + "fmla v18.4s, %[w1].4s, v11.4s \n" + "fmla v19.4s, %[w1].4s, 
v12.4s \n" + + // in row2 + "fmla v16.4s, %[w2].4s, v10.4s \n" + "fmla v17.4s, %[w2].4s, v11.4s \n" + "fmla v18.4s, %[w2].4s, v12.4s \n" + "fmla v19.4s, %[w2].4s, v13.4s \n" + + // in row3 + "fmla v16.4s, %[w3].4s, v11.4s \n" + "fmla v17.4s, %[w3].4s, v12.4s \n" + "fmla v18.4s, %[w3].4s, v13.4s \n" + "fmla v19.4s, %[w3].4s, v14.4s \n" + + // in row4 + "fmla v16.4s, %[w4].4s, v12.4s \n" + "fmla v17.4s, %[w4].4s, v13.4s \n" + "fmla v18.4s, %[w4].4s, v14.4s \n" + "fmla v19.4s, %[w4].4s, v15.4s \n" + + // add to out register v7 + "faddp v7.4s, v16.4s, v17.4s \n" + "faddp v8.4s, v18.4s, v19.4s \n" + "faddp v7.4s, v7.4s, v8.4s \n" + + // zip + "zip1 v6.4s, v5.4s, v7.4s \n" + "zip2 v8.4s, v5.4s, v7.4s \n" + + // add bias + "fadd v6.4s, v6.4s, v20.4s \n" + "fadd v8.4s, v8.4s, v20.4s \n" + + // relu + "fmax v6.4s, v6.4s, v31.4s \n" + "fmax v8.4s, v8.4s, v31.4s \n" + "ext v7.16b, v6.16b, v31.16b, #8 \n" + "ext v9.16b, v8.16b, v31.16b, #8 \n" + + // write output + "str d6, [%[dout0]] \n" + "str d7, [%[dout1]] \n" + "str d8, [%[dout2]] \n" + "str d9, [%[dout3]] \n" + + : [din0] "+r"(din0), + [din1] "+r"(din1), + [din2] "+r"(din2), + [din3] "+r"(din3), + [din4] "+r"(din4), + [din5] "+r"(din5), + [din6] "+r"(din6), + [din7] "+r"(din7) + : [dout0] "r"(dout0), + [dout1] "r"(dout1), + [dout2] "r"(dout2), + [dout3] "r"(dout3), + [w0] "w"(w0), + [w1] "w"(w1), + [w2] "w"(w2), + [w3] "w"(w3), + [w4] "w"(w4), + [bias] "r"(bias) + : "memory", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15", + "v16", + "v17", + "v18", + "v19", + "v20", + "v31"); +} + +//! kernel for three out with extracting data pre +//! deal with four lines out +//! need extra load weights +void compute_three_out_extract_pre(const float* din0, + const float* din1, + const float* din2, + const float* din3, + const float* din4, + const float* din5, + const float* din6, + const float* din7, + float* dout0, + float* dout1, + float* dout2, + float* dout3, + const float* weights, + const float* bias) { + //! din0 - din7: 0-4 v8-v15 + //! dout0 - dout3: v16-v19 + //! 
weights: v0-v4 + asm volatile( + // load weights + "movi v31.4s, #0 \n" + "add %[wh], %[wh], #4 \n" + "ldr q0, [%[wh]], #20 \n" // 1, 2, 3, 4 + "ldr q1, [%[wh]], #20 \n" // 6, 7, 8, 9 + "ldr q2, [%[wh]], #20 \n" // 11, 12, 13, 14 + "ldr q3, [%[wh]], #20 \n" // 16, 17, 18, 19 + "ldr q4, [%[wh]], #20 \n" // 21, 22, 23, 24 + + // load inputs + "ld1 {v20.4s}, [%[bias]] \n" + "ld1 {v8.4s}, [%[din0]], #16 \n" + "ld1 {v9.4s}, [%[din1]], #16 \n" + "ld1 {v10.4s}, [%[din2]], #16 \n" + "ld1 {v11.4s}, [%[din3]], #16 \n" + "ld1 {v12.4s}, [%[din4]], #16 \n" + "ld1 {v13.4s}, [%[din5]], #16 \n" + + // in row0 + "fmul v16.4s, v0.4s, v8.4s \n" + "fmul v17.4s, v0.4s, v9.4s \n" + "fmul v18.4s, v0.4s, v10.4s \n" + "fmul v19.4s, v0.4s, v11.4s \n" + + "ld1 {v14.4s}, [%[din6]], #16 \n" + "ld1 {v15.4s}, [%[din7]], #16 \n" + + // in row1 + "fmla v16.4s, v1.4s, v9.4s \n" + "fmla v17.4s, v1.4s, v10.4s \n" + "fmla v18.4s, v1.4s, v11.4s \n" + "fmla v19.4s, v1.4s, v12.4s \n" + + // in row2 + "fmla v16.4s, v2.4s, v10.4s \n" + "fmla v17.4s, v2.4s, v11.4s \n" + "fmla v18.4s, v2.4s, v12.4s \n" + "fmla v19.4s, v2.4s, v13.4s \n" + + // in row3 + "fmla v16.4s, v3.4s, v11.4s \n" + "fmla v17.4s, v3.4s, v12.4s \n" + "fmla v18.4s, v3.4s, v13.4s \n" + "fmla v19.4s, v3.4s, v14.4s \n" + + // in row4 + "fmla v16.4s, v4.4s, v12.4s \n" + "fmla v17.4s, v4.4s, v13.4s \n" + "fmla v18.4s, v4.4s, v14.4s \n" + "fmla v19.4s, v4.4s, v15.4s \n" + + // add to out register v5 + "faddp v5.4s, v16.4s, v17.4s \n" + "faddp v6.4s, v18.4s, v19.4s \n" + "faddp v5.4s, v5.4s, v6.4s \n" + + // ext weights + "ext v0.16b, v0.16b, v31.16b, #4 \n" // 2, 3, 4 + "ext v1.16b, v1.16b, v31.16b, #4 \n" // 7, 8, 9 + "ext v2.16b, v2.16b, v31.16b, #4 \n" // 12, 13, 14 + "ext v3.16b, v3.16b, v31.16b, #4 \n" // 17, 18, 19 + "ext v4.16b, v4.16b, v31.16b, #4 \n" // 22, 23, 24 + + // in row0 + "fmul v16.4s, v0.4s, v8.4s \n" + "fmul v17.4s, v0.4s, v9.4s \n" + "fmul v18.4s, v0.4s, v10.4s \n" + "fmul v19.4s, v0.4s, v11.4s \n" + + // in row1 + "fmla v16.4s, v1.4s, v9.4s \n" + "fmla v17.4s, v1.4s, v10.4s \n" + "fmla v18.4s, v1.4s, v11.4s \n" + "fmla v19.4s, v1.4s, v12.4s \n" + + // in row2 + "fmla v16.4s, v2.4s, v10.4s \n" + "fmla v17.4s, v2.4s, v11.4s \n" + "fmla v18.4s, v2.4s, v12.4s \n" + "fmla v19.4s, v2.4s, v13.4s \n" + + // in row3 + "fmla v16.4s, v3.4s, v11.4s \n" + "fmla v17.4s, v3.4s, v12.4s \n" + "fmla v18.4s, v3.4s, v13.4s \n" + "fmla v19.4s, v3.4s, v14.4s \n" + + // in row4 + "fmla v16.4s, v4.4s, v12.4s \n" + "fmla v17.4s, v4.4s, v13.4s \n" + "fmla v18.4s, v4.4s, v14.4s \n" + "fmla v19.4s, v4.4s, v15.4s \n" + + // add to out register v7 + "faddp v7.4s, v16.4s, v17.4s \n" + "faddp v6.4s, v18.4s, v19.4s \n" + "faddp v7.4s, v7.4s, v6.4s \n" + + // ext weights + "ext v0.16b, v0.16b, v31.16b, #4 \n" // 3, 4 + "ext v1.16b, v1.16b, v31.16b, #4 \n" // 8, 9 + "ext v2.16b, v2.16b, v31.16b, #4 \n" // 13, 14 + "ext v3.16b, v3.16b, v31.16b, #4 \n" // 18, 19 + "ext v4.16b, v4.16b, v31.16b, #4 \n" // 23, 24 + + // in row0 + "fmul v16.4s, v0.4s, v8.4s \n" + "fmul v17.4s, v0.4s, v9.4s \n" + "fmul v18.4s, v0.4s, v10.4s \n" + "fmul v19.4s, v0.4s, v11.4s \n" + + // in row1 + "fmla v16.4s, v1.4s, v9.4s \n" + "fmla v17.4s, v1.4s, v10.4s \n" + "fmla v18.4s, v1.4s, v11.4s \n" + "fmla v19.4s, v1.4s, v12.4s \n" + + // in row2 + "fmla v16.4s, v2.4s, v10.4s \n" + "fmla v17.4s, v2.4s, v11.4s \n" + "fmla v18.4s, v2.4s, v12.4s \n" + "fmla v19.4s, v2.4s, v13.4s \n" + + // in row3 + "fmla v16.4s, v3.4s, v11.4s \n" + "fmla v17.4s, v3.4s, v12.4s \n" + "fmla v18.4s, v3.4s, v13.4s \n" + "fmla v19.4s, 
v3.4s, v14.4s \n" + + // in row4 + "fmla v16.4s, v4.4s, v12.4s \n" + "fmla v17.4s, v4.4s, v13.4s \n" + "fmla v18.4s, v4.4s, v14.4s \n" + "fmla v19.4s, v4.4s, v15.4s \n" + + // add to out register v25 + "faddp v25.4s, v16.4s, v17.4s \n" + "faddp v26.4s, v18.4s, v19.4s \n" + "faddp v25.4s, v25.4s, v26.4s \n" + "fadd v25.4s, v25.4s, v20.4s \n" + + // zip + "zip1 v6.4s, v7.4s, v5.4s \n" + "zip2 v8.4s, v7.4s, v5.4s \n" + "fadd v6.4s, v6.4s, v20.4s \n" + "fadd v8.4s, v8.4s, v20.4s \n" + "ext v7.16b, v6.16b, v31.16b, #8 \n" + "ext v9.16b, v8.16b, v31.16b, #8 \n" + + // write output + "st1 {v25.s}[0], [%[dout0]], #4 \n" + "st1 {v25.s}[1], [%[dout1]], #4 \n" + "st1 {v25.s}[2], [%[dout2]], #4 \n" + "st1 {v25.s}[3], [%[dout3]], #4 \n" + + "str d6, [%[dout0]] \n" + "str d7, [%[dout1]] \n" + "str d8, [%[dout2]] \n" + "str d9, [%[dout3]] \n" + + : [dout0] "+r"(dout0), + [dout1] "+r"(dout1), + [dout2] "+r"(dout2), + [dout3] "+r"(dout3), + [din0] "+r"(din0), + [din1] "+r"(din1), + [din2] "+r"(din2), + [din3] "+r"(din3), + [din4] "+r"(din4), + [din5] "+r"(din5), + [din6] "+r"(din6), + [din7] "+r"(din7), + [wh] "+r"(weights) + : [bias] "r"(bias) + : "memory", + "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15", + "v16", + "v17", + "v18", + "v19", + "v20", + "v25", + "v26", + "v31"); +} + +//! kernel for three out with extracting data pre +//! deal with four lines out +//! need extra load weights +void compute_three_out_extract_pre_relu(const float* din0, + const float* din1, + const float* din2, + const float* din3, + const float* din4, + const float* din5, + const float* din6, + const float* din7, + float* dout0, + float* dout1, + float* dout2, + float* dout3, + const float* weights, + const float* bias) { + //! din0 - din7: 0-4 v8-v15 + //! dout0 - dout3: v16-v19 + //! 
weights: v0-v4 + asm volatile( + // load weights + "movi v31.4s, #0 \n" + "add %[wh], %[wh], #4 \n" + "ldr q0, [%[wh]], #20 \n" // 1, 2, 3, 4 + "ldr q1, [%[wh]], #20 \n" // 6, 7, 8, 9 + "ldr q2, [%[wh]], #20 \n" // 11, 12, 13, 14 + "ldr q3, [%[wh]], #20 \n" // 16, 17, 18, 19 + "ldr q4, [%[wh]], #20 \n" // 21, 22, 23, 24 + + // load inputs + "ld1 {v20.4s}, [%[bias]] \n" + "ld1 {v8.4s}, [%[din0]], #16 \n" + "ld1 {v9.4s}, [%[din1]], #16 \n" + "ld1 {v10.4s}, [%[din2]], #16 \n" + "ld1 {v11.4s}, [%[din3]], #16 \n" + "ld1 {v12.4s}, [%[din4]], #16 \n" + "ld1 {v13.4s}, [%[din5]], #16 \n" + + // in row0 + "fmul v16.4s, v0.4s, v8.4s \n" + "fmul v17.4s, v0.4s, v9.4s \n" + "fmul v18.4s, v0.4s, v10.4s \n" + "fmul v19.4s, v0.4s, v11.4s \n" + + "ld1 {v14.4s}, [%[din6]], #16 \n" + "ld1 {v15.4s}, [%[din7]], #16 \n" + + // in row1 + "fmla v16.4s, v1.4s, v9.4s \n" + "fmla v17.4s, v1.4s, v10.4s \n" + "fmla v18.4s, v1.4s, v11.4s \n" + "fmla v19.4s, v1.4s, v12.4s \n" + + // in row2 + "fmla v16.4s, v2.4s, v10.4s \n" + "fmla v17.4s, v2.4s, v11.4s \n" + "fmla v18.4s, v2.4s, v12.4s \n" + "fmla v19.4s, v2.4s, v13.4s \n" + + // in row3 + "fmla v16.4s, v3.4s, v11.4s \n" + "fmla v17.4s, v3.4s, v12.4s \n" + "fmla v18.4s, v3.4s, v13.4s \n" + "fmla v19.4s, v3.4s, v14.4s \n" + + // in row4 + "fmla v16.4s, v4.4s, v12.4s \n" + "fmla v17.4s, v4.4s, v13.4s \n" + "fmla v18.4s, v4.4s, v14.4s \n" + "fmla v19.4s, v4.4s, v15.4s \n" + + // add to out register v5 + "faddp v5.4s, v16.4s, v17.4s \n" + "faddp v6.4s, v18.4s, v19.4s \n" + "faddp v5.4s, v5.4s, v6.4s \n" + + // ext weights + "ext v0.16b, v0.16b, v31.16b, #4 \n" // 2, 3, 4 + "ext v1.16b, v1.16b, v31.16b, #4 \n" // 7, 8, 9 + "ext v2.16b, v2.16b, v31.16b, #4 \n" // 12, 13, 14 + "ext v3.16b, v3.16b, v31.16b, #4 \n" // 17, 18, 19 + "ext v4.16b, v4.16b, v31.16b, #4 \n" // 22, 23, 24 + + // in row0 + "fmul v16.4s, v0.4s, v8.4s \n" + "fmul v17.4s, v0.4s, v9.4s \n" + "fmul v18.4s, v0.4s, v10.4s \n" + "fmul v19.4s, v0.4s, v11.4s \n" + + // in row1 + "fmla v16.4s, v1.4s, v9.4s \n" + "fmla v17.4s, v1.4s, v10.4s \n" + "fmla v18.4s, v1.4s, v11.4s \n" + "fmla v19.4s, v1.4s, v12.4s \n" + + // in row2 + "fmla v16.4s, v2.4s, v10.4s \n" + "fmla v17.4s, v2.4s, v11.4s \n" + "fmla v18.4s, v2.4s, v12.4s \n" + "fmla v19.4s, v2.4s, v13.4s \n" + + // in row3 + "fmla v16.4s, v3.4s, v11.4s \n" + "fmla v17.4s, v3.4s, v12.4s \n" + "fmla v18.4s, v3.4s, v13.4s \n" + "fmla v19.4s, v3.4s, v14.4s \n" + + // in row4 + "fmla v16.4s, v4.4s, v12.4s \n" + "fmla v17.4s, v4.4s, v13.4s \n" + "fmla v18.4s, v4.4s, v14.4s \n" + "fmla v19.4s, v4.4s, v15.4s \n" + + // add to out register v7 + "faddp v7.4s, v16.4s, v17.4s \n" + "faddp v6.4s, v18.4s, v19.4s \n" + "faddp v7.4s, v7.4s, v6.4s \n" + + // ext weights + "ext v0.16b, v0.16b, v31.16b, #4 \n" // 3, 4 + "ext v1.16b, v1.16b, v31.16b, #4 \n" // 8, 9 + "ext v2.16b, v2.16b, v31.16b, #4 \n" // 13, 14 + "ext v3.16b, v3.16b, v31.16b, #4 \n" // 18, 19 + "ext v4.16b, v4.16b, v31.16b, #4 \n" // 23, 24 + + // in row0 + "fmul v16.4s, v0.4s, v8.4s \n" + "fmul v17.4s, v0.4s, v9.4s \n" + "fmul v18.4s, v0.4s, v10.4s \n" + "fmul v19.4s, v0.4s, v11.4s \n" + + // in row1 + "fmla v16.4s, v1.4s, v9.4s \n" + "fmla v17.4s, v1.4s, v10.4s \n" + "fmla v18.4s, v1.4s, v11.4s \n" + "fmla v19.4s, v1.4s, v12.4s \n" + + // in row2 + "fmla v16.4s, v2.4s, v10.4s \n" + "fmla v17.4s, v2.4s, v11.4s \n" + "fmla v18.4s, v2.4s, v12.4s \n" + "fmla v19.4s, v2.4s, v13.4s \n" + + // in row3 + "fmla v16.4s, v3.4s, v11.4s \n" + "fmla v17.4s, v3.4s, v12.4s \n" + "fmla v18.4s, v3.4s, v13.4s \n" + "fmla v19.4s, 
v3.4s, v14.4s \n" + + // in row4 + "fmla v16.4s, v4.4s, v12.4s \n" + "fmla v17.4s, v4.4s, v13.4s \n" + "fmla v18.4s, v4.4s, v14.4s \n" + "fmla v19.4s, v4.4s, v15.4s \n" + + // add to out register v25 + "faddp v25.4s, v16.4s, v17.4s \n" + "faddp v26.4s, v18.4s, v19.4s \n" + "faddp v25.4s, v25.4s, v26.4s \n" + "fadd v25.4s, v25.4s, v20.4s \n" + "fmax v25.4s, v25.4s, v31.4s \n" + + // zip + "zip1 v6.4s, v7.4s, v5.4s \n" + "zip2 v8.4s, v7.4s, v5.4s \n" + + // add bias + "fadd v6.4s, v6.4s, v20.4s \n" + "fadd v8.4s, v8.4s, v20.4s \n" + + // relu + "fmax v6.4s, v6.4s, v31.4s \n" + "fmax v8.4s, v8.4s, v31.4s \n" + + "ext v7.16b, v6.16b, v31.16b, #8 \n" + "ext v9.16b, v8.16b, v31.16b, #8 \n" + + // write output + "st1 {v25.s}[0], [%[dout0]], #4 \n" + "st1 {v25.s}[1], [%[dout1]], #4 \n" + "st1 {v25.s}[2], [%[dout2]], #4 \n" + "st1 {v25.s}[3], [%[dout3]], #4 \n" + + "str d6, [%[dout0]] \n" + "str d7, [%[dout1]] \n" + "str d8, [%[dout2]] \n" + "str d9, [%[dout3]] \n" + + : [dout0] "+r"(dout0), + [dout1] "+r"(dout1), + [dout2] "+r"(dout2), + [dout3] "+r"(dout3), + [din0] "+r"(din0), + [din1] "+r"(din1), + [din2] "+r"(din2), + [din3] "+r"(din3), + [din4] "+r"(din4), + [din5] "+r"(din5), + [din6] "+r"(din6), + [din7] "+r"(din7), + [wh] "+r"(weights) + : [bias] "r"(bias) + : "memory", + "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15", + "v16", + "v17", + "v18", + "v19", + "v20", + "v25", + "v26", + "v31"); +} + +//! kernel for three out with extracting data post +//! deal with four lines out +void compute_three_out_extract_post(const float* din0, + const float* din1, + const float* din2, + const float* din3, + const float* din4, + const float* din5, + const float* din6, + const float* din7, + float* dout0, + float* dout1, + float* dout2, + float* dout3, + float32x4_t w0, + float32x4_t w1, + float32x4_t w2, + float32x4_t w3, + float32x4_t w4, + const float* bias) { + //! din0 - din7: 0-4 v8-v15 + //! 
dout0 - dout3: v6, v8, v25 + asm volatile( + "movi v31.4s, #0 \n" + // load inputs + "ld1 {v20.4s}, [%[bias]] \n" + "ld1 {v8.4s}, [%[din0]], #16 \n" + "ld1 {v9.4s}, [%[din1]], #16 \n" + "ld1 {v10.4s}, [%[din2]], #16 \n" + "ld1 {v11.4s}, [%[din3]], #16 \n" + "ld1 {v12.4s}, [%[din4]], #16 \n" + "ld1 {v13.4s}, [%[din5]], #16 \n" + + // in row0 + "fmul v16.4s, %[w0].4s, v8.4s \n" + "fmul v17.4s, %[w0].4s, v9.4s \n" + "fmul v18.4s, %[w0].4s, v10.4s \n" + "fmul v19.4s, %[w0].4s, v11.4s \n" + + "ld1 {v14.4s}, [%[din6]], #16 \n" + "ld1 {v15.4s}, [%[din7]], #16 \n" + + // in row1 + "fmla v16.4s, %[w1].4s, v9.4s \n" + "fmla v17.4s, %[w1].4s, v10.4s \n" + "fmla v18.4s, %[w1].4s, v11.4s \n" + "fmla v19.4s, %[w1].4s, v12.4s \n" + + // in row2 + "fmla v16.4s, %[w2].4s, v10.4s \n" + "fmla v17.4s, %[w2].4s, v11.4s \n" + "fmla v18.4s, %[w2].4s, v12.4s \n" + "fmla v19.4s, %[w2].4s, v13.4s \n" + + // in row3 + "fmla v16.4s, %[w3].4s, v11.4s \n" + "fmla v17.4s, %[w3].4s, v12.4s \n" + "fmla v18.4s, %[w3].4s, v13.4s \n" + "fmla v19.4s, %[w3].4s, v14.4s \n" + + // in row4 + "fmla v16.4s, %[w4].4s, v12.4s \n" + "fmla v17.4s, %[w4].4s, v13.4s \n" + "fmla v18.4s, %[w4].4s, v14.4s \n" + "fmla v19.4s, %[w4].4s, v15.4s \n" + + // add to out register v5 + "faddp v5.4s, v16.4s, v17.4s \n" + "faddp v6.4s, v18.4s, v19.4s \n" + "faddp v5.4s, v5.4s, v6.4s \n" + + // ext input + "ext v8.16b, v8.16b, v31.16b, #4 \n" + "ext v9.16b, v9.16b, v31.16b, #4 \n" + "ext v10.16b, v10.16b, v31.16b, #4 \n" + "ext v11.16b, v11.16b, v31.16b, #4 \n" + "ext v12.16b, v12.16b, v31.16b, #4 \n" + "ext v13.16b, v13.16b, v31.16b, #4 \n" + "ext v14.16b, v14.16b, v31.16b, #4 \n" + "ext v15.16b, v15.16b, v31.16b, #4 \n" + + // in row0 + "fmul v16.4s, %[w0].4s, v8.4s \n" + "fmul v17.4s, %[w0].4s, v9.4s \n" + "fmul v18.4s, %[w0].4s, v10.4s \n" + "fmul v19.4s, %[w0].4s, v11.4s \n" + + // in row1 + "fmla v16.4s, %[w1].4s, v9.4s \n" + "fmla v17.4s, %[w1].4s, v10.4s \n" + "fmla v18.4s, %[w1].4s, v11.4s \n" + "fmla v19.4s, %[w1].4s, v12.4s \n" + + // in row2 + "fmla v16.4s, %[w2].4s, v10.4s \n" + "fmla v17.4s, %[w2].4s, v11.4s \n" + "fmla v18.4s, %[w2].4s, v12.4s \n" + "fmla v19.4s, %[w2].4s, v13.4s \n" + + // in row3 + "fmla v16.4s, %[w3].4s, v11.4s \n" + "fmla v17.4s, %[w3].4s, v12.4s \n" + "fmla v18.4s, %[w3].4s, v13.4s \n" + "fmla v19.4s, %[w3].4s, v14.4s \n" + + // in row4 + "fmla v16.4s, %[w4].4s, v12.4s \n" + "fmla v17.4s, %[w4].4s, v13.4s \n" + "fmla v18.4s, %[w4].4s, v14.4s \n" + "fmla v19.4s, %[w4].4s, v15.4s \n" + + // add to out register v7 + "faddp v7.4s, v16.4s, v17.4s \n" + "faddp v6.4s, v18.4s, v19.4s \n" + "faddp v7.4s, v7.4s, v6.4s \n" + + // ext input + "ext v8.16b, v8.16b, v31.16b, #4 \n" + "ext v9.16b, v9.16b, v31.16b, #4 \n" + "ext v10.16b, v10.16b, v31.16b, #4 \n" + "ext v11.16b, v11.16b, v31.16b, #4 \n" + "ext v12.16b, v12.16b, v31.16b, #4 \n" + "ext v13.16b, v13.16b, v31.16b, #4 \n" + "ext v14.16b, v14.16b, v31.16b, #4 \n" + "ext v15.16b, v15.16b, v31.16b, #4 \n" + + // in row0 + "fmul v16.4s, %[w0].4s, v8.4s \n" + "fmul v17.4s, %[w0].4s, v9.4s \n" + "fmul v18.4s, %[w0].4s, v10.4s \n" + "fmul v19.4s, %[w0].4s, v11.4s \n" + + // in row1 + "fmla v16.4s, %[w1].4s, v9.4s \n" + "fmla v17.4s, %[w1].4s, v10.4s \n" + "fmla v18.4s, %[w1].4s, v11.4s \n" + "fmla v19.4s, %[w1].4s, v12.4s \n" + + // in row2 + "fmla v16.4s, %[w2].4s, v10.4s \n" + "fmla v17.4s, %[w2].4s, v11.4s \n" + "fmla v18.4s, %[w2].4s, v12.4s \n" + "fmla v19.4s, %[w2].4s, v13.4s \n" + + // in row3 + "fmla v16.4s, %[w3].4s, v11.4s \n" + "fmla v17.4s, %[w3].4s, v12.4s \n" + 
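+          // v8-v15 were shifted left one lane by the "ext input" block above
+          // (zeros from v31 fill the tail), so this second pass accumulates
+          // the four-lane window starting at column 1; the faddp sequence
+          // below collapses each row's partial products into one output lane.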
"fmla v18.4s, %[w3].4s, v13.4s \n" + "fmla v19.4s, %[w3].4s, v14.4s \n" + + // in row4 + "fmla v16.4s, %[w4].4s, v12.4s \n" + "fmla v17.4s, %[w4].4s, v13.4s \n" + "fmla v18.4s, %[w4].4s, v14.4s \n" + "fmla v19.4s, %[w4].4s, v15.4s \n" + + // add to out register v25 + "faddp v25.4s, v16.4s, v17.4s \n" + "faddp v26.4s, v18.4s, v19.4s \n" + "faddp v25.4s, v25.4s, v26.4s \n" + "fadd v25.4s, v25.4s, v20.4s \n" + + // zip + "zip1 v6.4s, v5.4s, v7.4s \n" + "zip2 v8.4s, v5.4s, v7.4s \n" + "fadd v6.4s, v6.4s, v20.4s \n" + "fadd v8.4s, v8.4s, v20.4s \n" + "ext v7.16b, v6.16b, v31.16b, #8 \n" + "ext v9.16b, v8.16b, v31.16b, #8 \n" + + // write output + "str d6, [%[dout0]], #8 \n" + "str d7, [%[dout1]], #8 \n" + "str d8, [%[dout2]], #8 \n" + "str d9, [%[dout3]], #8 \n" + + "st1 {v25.s}[0], [%[dout0]] \n" + "st1 {v25.s}[1], [%[dout1]] \n" + "st1 {v25.s}[2], [%[dout2]] \n" + "st1 {v25.s}[3], [%[dout3]] \n" + + : [dout0] "+r"(dout0), + [dout1] "+r"(dout1), + [dout2] "+r"(dout2), + [dout3] "+r"(dout3), + [din0] "+r"(din0), + [din1] "+r"(din1), + [din2] "+r"(din2), + [din3] "+r"(din3), + [din4] "+r"(din4), + [din5] "+r"(din5), + [din6] "+r"(din6), + [din7] "+r"(din7) + : [w0] "w"(w0), + [w1] "w"(w1), + [w2] "w"(w2), + [w3] "w"(w3), + [w4] "w"(w4), + [bias] "r"(bias) + : "memory", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15", + "v16", + "v17", + "v18", + "v19", + "v20", + "v25", + "v26", + "v31"); +} + +//! kernel for three out with extracting data post +//! deal with four lines out +void compute_three_out_extract_post_relu(const float* din0, + const float* din1, + const float* din2, + const float* din3, + const float* din4, + const float* din5, + const float* din6, + const float* din7, + float* dout0, + float* dout1, + float* dout2, + float* dout3, + float32x4_t w0, + float32x4_t w1, + float32x4_t w2, + float32x4_t w3, + float32x4_t w4, + const float* bias) { + //! din0 - din7: 0-4 v8-v15 + //! 
dout0 - dout3: v6, v8, v25 + asm volatile( + "movi v31.4s, #0 \n" + + // load inputs + "ld1 {v20.4s}, [%[bias]] \n" + "ld1 {v8.4s}, [%[din0]], #16 \n" + "ld1 {v9.4s}, [%[din1]], #16 \n" + "ld1 {v10.4s}, [%[din2]], #16 \n" + "ld1 {v11.4s}, [%[din3]], #16 \n" + "ld1 {v12.4s}, [%[din4]], #16 \n" + "ld1 {v13.4s}, [%[din5]], #16 \n" + + // in row0 + "fmul v16.4s, %[w0].4s, v8.4s \n" + "fmul v17.4s, %[w0].4s, v9.4s \n" + "fmul v18.4s, %[w0].4s, v10.4s \n" + "fmul v19.4s, %[w0].4s, v11.4s \n" + + "ld1 {v14.4s}, [%[din6]], #16 \n" + "ld1 {v15.4s}, [%[din7]], #16 \n" + + // in row1 + "fmla v16.4s, %[w1].4s, v9.4s \n" + "fmla v17.4s, %[w1].4s, v10.4s \n" + "fmla v18.4s, %[w1].4s, v11.4s \n" + "fmla v19.4s, %[w1].4s, v12.4s \n" + + // in row2 + "fmla v16.4s, %[w2].4s, v10.4s \n" + "fmla v17.4s, %[w2].4s, v11.4s \n" + "fmla v18.4s, %[w2].4s, v12.4s \n" + "fmla v19.4s, %[w2].4s, v13.4s \n" + + // in row3 + "fmla v16.4s, %[w3].4s, v11.4s \n" + "fmla v17.4s, %[w3].4s, v12.4s \n" + "fmla v18.4s, %[w3].4s, v13.4s \n" + "fmla v19.4s, %[w3].4s, v14.4s \n" + + // in row4 + "fmla v16.4s, %[w4].4s, v12.4s \n" + "fmla v17.4s, %[w4].4s, v13.4s \n" + "fmla v18.4s, %[w4].4s, v14.4s \n" + "fmla v19.4s, %[w4].4s, v15.4s \n" + + // add to out register v5 + "faddp v5.4s, v16.4s, v17.4s \n" + "faddp v6.4s, v18.4s, v19.4s \n" + "faddp v5.4s, v5.4s, v6.4s \n" + + // ext input + "ext v8.16b, v8.16b, v31.16b, #4 \n" + "ext v9.16b, v9.16b, v31.16b, #4 \n" + "ext v10.16b, v10.16b, v31.16b, #4 \n" + "ext v11.16b, v11.16b, v31.16b, #4 \n" + "ext v12.16b, v12.16b, v31.16b, #4 \n" + "ext v13.16b, v13.16b, v31.16b, #4 \n" + "ext v14.16b, v14.16b, v31.16b, #4 \n" + "ext v15.16b, v15.16b, v31.16b, #4 \n" + + // in row0 + "fmul v16.4s, %[w0].4s, v8.4s \n" + "fmul v17.4s, %[w0].4s, v9.4s \n" + "fmul v18.4s, %[w0].4s, v10.4s \n" + "fmul v19.4s, %[w0].4s, v11.4s \n" + + // in row1 + "fmla v16.4s, %[w1].4s, v9.4s \n" + "fmla v17.4s, %[w1].4s, v10.4s \n" + "fmla v18.4s, %[w1].4s, v11.4s \n" + "fmla v19.4s, %[w1].4s, v12.4s \n" + + // in row2 + "fmla v16.4s, %[w2].4s, v10.4s \n" + "fmla v17.4s, %[w2].4s, v11.4s \n" + "fmla v18.4s, %[w2].4s, v12.4s \n" + "fmla v19.4s, %[w2].4s, v13.4s \n" + + // in row3 + "fmla v16.4s, %[w3].4s, v11.4s \n" + "fmla v17.4s, %[w3].4s, v12.4s \n" + "fmla v18.4s, %[w3].4s, v13.4s \n" + "fmla v19.4s, %[w3].4s, v14.4s \n" + + // in row4 + "fmla v16.4s, %[w4].4s, v12.4s \n" + "fmla v17.4s, %[w4].4s, v13.4s \n" + "fmla v18.4s, %[w4].4s, v14.4s \n" + "fmla v19.4s, %[w4].4s, v15.4s \n" + + // add to out register v7 + "faddp v7.4s, v16.4s, v17.4s \n" + "faddp v6.4s, v18.4s, v19.4s \n" + "faddp v7.4s, v7.4s, v6.4s \n" + + // ext input + "ext v8.16b, v8.16b, v31.16b, #4 \n" + "ext v9.16b, v9.16b, v31.16b, #4 \n" + "ext v10.16b, v10.16b, v31.16b, #4 \n" + "ext v11.16b, v11.16b, v31.16b, #4 \n" + "ext v12.16b, v12.16b, v31.16b, #4 \n" + "ext v13.16b, v13.16b, v31.16b, #4 \n" + "ext v14.16b, v14.16b, v31.16b, #4 \n" + "ext v15.16b, v15.16b, v31.16b, #4 \n" + + // in row0 + "fmul v16.4s, %[w0].4s, v8.4s \n" + "fmul v17.4s, %[w0].4s, v9.4s \n" + "fmul v18.4s, %[w0].4s, v10.4s \n" + "fmul v19.4s, %[w0].4s, v11.4s \n" + + // in row1 + "fmla v16.4s, %[w1].4s, v9.4s \n" + "fmla v17.4s, %[w1].4s, v10.4s \n" + "fmla v18.4s, %[w1].4s, v11.4s \n" + "fmla v19.4s, %[w1].4s, v12.4s \n" + + // in row2 + "fmla v16.4s, %[w2].4s, v10.4s \n" + "fmla v17.4s, %[w2].4s, v11.4s \n" + "fmla v18.4s, %[w2].4s, v12.4s \n" + "fmla v19.4s, %[w2].4s, v13.4s \n" + + // in row3 + "fmla v16.4s, %[w3].4s, v11.4s \n" + "fmla v17.4s, %[w3].4s, v12.4s \n" + 
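+          // same shifted-window accumulation as compute_three_out_extract_post
+          // above; the relu variant differs only in the trailing fmax against
+          // the zero register v31 applied after the bias add.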
"fmla v18.4s, %[w3].4s, v13.4s \n" + "fmla v19.4s, %[w3].4s, v14.4s \n" + + // in row4 + "fmla v16.4s, %[w4].4s, v12.4s \n" + "fmla v17.4s, %[w4].4s, v13.4s \n" + "fmla v18.4s, %[w4].4s, v14.4s \n" + "fmla v19.4s, %[w4].4s, v15.4s \n" + + // add to out register v25 + "faddp v25.4s, v16.4s, v17.4s \n" + "faddp v26.4s, v18.4s, v19.4s \n" + "faddp v25.4s, v25.4s, v26.4s \n" + "fadd v25.4s, v25.4s, v20.4s \n" + "fmax v25.4s, v25.4s, v31.4s \n" + + // zip + "zip1 v6.4s, v5.4s, v7.4s \n" + "zip2 v8.4s, v5.4s, v7.4s \n" + + // add bias + "fadd v6.4s, v6.4s, v20.4s \n" + "fadd v8.4s, v8.4s, v20.4s \n" + + // relu + "fmax v6.4s, v6.4s, v31.4s \n" + "fmax v8.4s, v8.4s, v31.4s \n" + + "ext v7.16b, v6.16b, v31.16b, #8 \n" + "ext v9.16b, v8.16b, v31.16b, #8 \n" + + // write output + "str d6, [%[dout0]], #8 \n" + "str d7, [%[dout1]], #8 \n" + "str d8, [%[dout2]], #8 \n" + "str d9, [%[dout3]], #8 \n" + + "st1 {v25.s}[0], [%[dout0]] \n" + "st1 {v25.s}[1], [%[dout1]] \n" + "st1 {v25.s}[2], [%[dout2]] \n" + "st1 {v25.s}[3], [%[dout3]] \n" + + : [dout0] "+r"(dout0), + [dout1] "+r"(dout1), + [dout2] "+r"(dout2), + [dout3] "+r"(dout3), + [din0] "+r"(din0), + [din1] "+r"(din1), + [din2] "+r"(din2), + [din3] "+r"(din3), + [din4] "+r"(din4), + [din5] "+r"(din5), + [din6] "+r"(din6), + [din7] "+r"(din7) + : [w0] "w"(w0), + [w1] "w"(w1), + [w2] "w"(w2), + [w3] "w"(w3), + [w4] "w"(w4), + [bias] "r"(bias) + : "memory", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15", + "v16", + "v17", + "v18", + "v19", + "v20", + "v25", + "v26", + "v31"); +} + +//! kernel for four out with extracting data pre +//! deal with four lines out +//! need extra load weights +void compute_four_out_extract_pre(const float* din0, + const float* din1, + const float* din2, + const float* din3, + const float* din4, + const float* din5, + const float* din6, + const float* din7, + float* dout0, + float* dout1, + float* dout2, + float* dout3, + const float* weights, + const float* bias) { + //! din0 - din7: 0-4 v8-v15 + //! dout0 - dout3: v0-v3 + //! 
weights: v0-v4, v5, v6 + asm volatile( + // load weights + "movi v31.4s, #0 \n" + "mov x0, #20 \n" + "add %[wh], %[wh], #4 \n" + "ldr q0, [%[wh]], #20 \n" // 1, 2, 3, 4 + "ldr q1, [%[wh]], #20 \n" // 6, 7, 8, 9 + "ldr q2, [%[wh]], #20 \n" // 11, 12, 13, 14 + "ldr q3, [%[wh]], #20 \n" // 16, 17, 18, 19 + "ldr q4, [%[wh]] \n" // 21, 22, 23, 24 + "sub %[wh], %[wh], #68 \n" + + // load inputs + "ld1 {v8.4s}, [%[din0]] \n" + "ld1 {v9.4s}, [%[din1]] \n" + "ld1 {v10.4s}, [%[din2]] \n" + "ld1 {v11.4s}, [%[din3]] \n" + "ld1 {v12.4s}, [%[din4]] \n" + "ld1 {v13.4s}, [%[din5]] \n" + + // in row0 + "fmul v16.4s, v0.4s, v8.4s \n" + "fmul v17.4s, v0.4s, v9.4s \n" + "fmul v18.4s, v0.4s, v10.4s \n" + "fmul v19.4s, v0.4s, v11.4s \n" + + "ld1 {v14.4s}, [%[din6]] \n" + "ld1 {v15.4s}, [%[din7]] \n" + + // in row1 + "fmla v16.4s, v1.4s, v9.4s \n" + "fmla v17.4s, v1.4s, v10.4s \n" + "fmla v18.4s, v1.4s, v11.4s \n" + "fmla v19.4s, v1.4s, v12.4s \n" + + // in row2 + "fmla v16.4s, v2.4s, v10.4s \n" + "fmla v17.4s, v2.4s, v11.4s \n" + "fmla v18.4s, v2.4s, v12.4s \n" + "fmla v19.4s, v2.4s, v13.4s \n" + + // in row3 + "fmla v16.4s, v3.4s, v11.4s \n" + "fmla v17.4s, v3.4s, v12.4s \n" + "fmla v18.4s, v3.4s, v13.4s \n" + "fmla v19.4s, v3.4s, v14.4s \n" + + // in row4 + "fmla v16.4s, v4.4s, v12.4s \n" + "fmla v17.4s, v4.4s, v13.4s \n" + "fmla v18.4s, v4.4s, v14.4s \n" + "fmla v19.4s, v4.4s, v15.4s \n" + + // add to out register v25 + "faddp v25.4s, v16.4s, v17.4s \n" + "faddp v26.4s, v18.4s, v19.4s \n" + "faddp v25.4s, v25.4s, v26.4s \n" + + // load weights col5 + "ld1 {v5.s}[0], [%[wh]], x0 \n" + "ld1 {v5.s}[1], [%[wh]], x0 \n" + "ld1 {v5.s}[2], [%[wh]], x0 \n" + "ld1 {v5.s}[3], [%[wh]], x0 \n" + "ld1 {v6.s}[0], [%[wh]] \n" + + // ext weights + "ext v0.16b, v0.16b, v31.16b, #4 \n" // 2, 3, 4 + "ext v1.16b, v1.16b, v31.16b, #4 \n" // 7, 8, 9 + "ext v2.16b, v2.16b, v31.16b, #4 \n" // 12, 13, 14 + "ext v3.16b, v3.16b, v31.16b, #4 \n" // 17, 18, 19 + "ext v4.16b, v4.16b, v31.16b, #4 \n" // 22, 23, 24 + + // in row0 + "fmul v16.4s, v0.4s, v8.4s \n" + "fmul v17.4s, v0.4s, v9.4s \n" + "fmul v18.4s, v0.4s, v10.4s \n" + "fmul v19.4s, v0.4s, v11.4s \n" + + // in row1 + "fmla v16.4s, v1.4s, v9.4s \n" + "fmla v17.4s, v1.4s, v10.4s \n" + "fmla v18.4s, v1.4s, v11.4s \n" + "fmla v19.4s, v1.4s, v12.4s \n" + + // in row2 + "fmla v16.4s, v2.4s, v10.4s \n" + "fmla v17.4s, v2.4s, v11.4s \n" + "fmla v18.4s, v2.4s, v12.4s \n" + "fmla v19.4s, v2.4s, v13.4s \n" + + // in row3 + "fmla v16.4s, v3.4s, v11.4s \n" + "fmla v17.4s, v3.4s, v12.4s \n" + "fmla v18.4s, v3.4s, v13.4s \n" + "fmla v19.4s, v3.4s, v14.4s \n" + + // in row4 + "fmla v16.4s, v4.4s, v12.4s \n" + "fmla v17.4s, v4.4s, v13.4s \n" + "fmla v18.4s, v4.4s, v14.4s \n" + "fmla v19.4s, v4.4s, v15.4s \n" + + // add to out register v27 + "faddp v27.4s, v16.4s, v17.4s \n" + "faddp v26.4s, v18.4s, v19.4s \n" + "faddp v27.4s, v27.4s, v26.4s \n" + + // load in col5 + "ld1 {v20.s}[0], [%[din0]] \n" + "ld1 {v20.s}[1], [%[din1]] \n" + "ld1 {v20.s}[2], [%[din2]] \n" + "ld1 {v20.s}[3], [%[din3]] \n" + + // ext weights + "ext v0.16b, v0.16b, v31.16b, #4 \n" // 3, 4 + "ext v1.16b, v1.16b, v31.16b, #4 \n" // 8, 9 + "ext v2.16b, v2.16b, v31.16b, #4 \n" // 13, 14 + "ext v3.16b, v3.16b, v31.16b, #4 \n" // 18, 19 + "ext v4.16b, v4.16b, v31.16b, #4 \n" // 23, 24 + + "ld1 {v21.s}[0], [%[din4]] \n" + "ld1 {v21.s}[1], [%[din5]] \n" + "ld1 {v21.s}[2], [%[din6]] \n" + "ld1 {v21.s}[3], [%[din7]] \n" + + // in row0 + "fmul v16.4s, v0.4s, v8.4s \n" + "fmul v17.4s, v0.4s, v9.4s \n" + "fmul v18.4s, v0.4s, v10.4s \n" + 
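+          // here the weights (v0-v4), not the inputs, have been shifted twice
+          // by the "ext weights" blocks, so this pass pairs the last weight
+          // lanes with the first input columns; this realizes the left-padding
+          // overlap of the first output columns without reloading the input,
+          // while v5/v6 hold the separately loaded column-5 weights.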
"fmul v19.4s, v0.4s, v11.4s \n" + + // in row1 + "fmla v16.4s, v1.4s, v9.4s \n" + "fmla v17.4s, v1.4s, v10.4s \n" + "fmla v18.4s, v1.4s, v11.4s \n" + "fmla v19.4s, v1.4s, v12.4s \n" + + // in row2 + "fmla v16.4s, v2.4s, v10.4s \n" + "fmla v17.4s, v2.4s, v11.4s \n" + "fmla v18.4s, v2.4s, v12.4s \n" + "fmla v19.4s, v2.4s, v13.4s \n" + + // in row3 + "fmla v16.4s, v3.4s, v11.4s \n" + "fmla v17.4s, v3.4s, v12.4s \n" + "fmla v18.4s, v3.4s, v13.4s \n" + "fmla v19.4s, v3.4s, v14.4s \n" + + // in row4 + "fmla v16.4s, v4.4s, v12.4s \n" + "fmla v17.4s, v4.4s, v13.4s \n" + "fmla v18.4s, v4.4s, v14.4s \n" + "fmla v19.4s, v4.4s, v15.4s \n" + + // add to out register v26 + "faddp v26.4s, v16.4s, v17.4s \n" + "faddp v28.4s, v18.4s, v19.4s \n" + "faddp v26.4s, v26.4s, v28.4s \n" + + // ext input col5 + "ext v22.16b, v20.16b, v21.16b, #4 \n" + "ext v23.16b, v20.16b, v21.16b, #8 \n" + "ext v24.16b, v20.16b, v21.16b, #12 \n" + + // in col5 + "fmul v16.4s, v5.4s, v20.4s \n" + "fmul v17.4s, v5.4s, v22.4s \n" + "fmul v18.4s, v5.4s, v23.4s \n" + "fmul v19.4s, v5.4s, v24.4s \n" + + // add to out register v28 + "faddp v28.4s, v16.4s, v17.4s \n" + "faddp v29.4s, v18.4s, v19.4s \n" + "faddp v28.4s, v28.4s, v29.4s \n" + "fmla v28.4s, v21.4s, v6.s[0] \n" + + "ld1 {v8.4s}, [%[bias]] \n" + + // zip + "zip1 v0.4s, v28.4s, v26.4s \n" + "zip2 v2.4s, v28.4s, v26.4s \n" + "zip1 v4.4s, v27.4s, v25.4s \n" + "zip2 v6.4s, v27.4s, v25.4s \n" + + "fadd v0.4s, v0.4s, v8.4s \n" + "fadd v2.4s, v2.4s, v8.4s \n" + "fadd v4.4s, v4.4s, v8.4s \n" + "fadd v6.4s, v6.4s, v8.4s \n" + + "ext v1.16b, v0.16b, v31.16b, #8 \n" + "ext v3.16b, v2.16b, v31.16b, #8 \n" + "ext v5.16b, v4.16b, v31.16b, #8 \n" + "ext v7.16b, v6.16b, v31.16b, #8 \n" + + // write output + "str d0, [%[dout0]], #8 \n" + "str d1, [%[dout1]], #8 \n" + "str d2, [%[dout2]], #8 \n" + "str d3, [%[dout3]], #8 \n" + + "str d4, [%[dout0]] \n" + "str d5, [%[dout1]] \n" + "str d6, [%[dout2]] \n" + "str d7, [%[dout3]] \n" + + : [dout0] "+r"(dout0), + [dout1] "+r"(dout1), + [dout2] "+r"(dout2), + [dout3] "+r"(dout3), + [din0] "+r"(din0), + [din1] "+r"(din1), + [din2] "+r"(din2), + [din3] "+r"(din3), + [din4] "+r"(din4), + [din5] "+r"(din5), + [din6] "+r"(din6), + [din7] "+r"(din7), + [wh] "+r"(weights) + : [bias] "r"(bias) + : "memory", + "x0", + "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15", + "v16", + "v17", + "v18", + "v19", + "v20", + "v21", + "v22", + "v23", + "v24", + "v25", + "v26", + "v27", + "v28", + "v29", + "v31"); +} + +//! kernel for four out with extracting data pre +//! deal with four lines out +//! need extra load weights +void compute_four_out_extract_pre_relu(const float* din0, + const float* din1, + const float* din2, + const float* din3, + const float* din4, + const float* din5, + const float* din6, + const float* din7, + float* dout0, + float* dout1, + float* dout2, + float* dout3, + const float* weights, + const float* bias) { + //! din0 - din7: 0-4 v8-v15 + //! dout0 - dout3: v0-v3 + //! 
weights: v0-v4, v5, v6 + asm volatile( + // load weights + "movi v31.4s, #0 \n" + "mov x0, #20 \n" + "add %[wh], %[wh], #4 \n" + "ldr q0, [%[wh]], #20 \n" // 1, 2, 3, 4 + "ldr q1, [%[wh]], #20 \n" // 6, 7, 8, 9 + "ldr q2, [%[wh]], #20 \n" // 11, 12, 13, 14 + "ldr q3, [%[wh]], #20 \n" // 16, 17, 18, 19 + "ldr q4, [%[wh]] \n" // 21, 22, 23, 24 + "sub %[wh], %[wh], #68 \n" + + // load inputs + "ld1 {v8.4s}, [%[din0]] \n" + "ld1 {v9.4s}, [%[din1]] \n" + "ld1 {v10.4s}, [%[din2]] \n" + "ld1 {v11.4s}, [%[din3]] \n" + "ld1 {v12.4s}, [%[din4]] \n" + "ld1 {v13.4s}, [%[din5]] \n" + + // in row0 + "fmul v16.4s, v0.4s, v8.4s \n" + "fmul v17.4s, v0.4s, v9.4s \n" + "fmul v18.4s, v0.4s, v10.4s \n" + "fmul v19.4s, v0.4s, v11.4s \n" + + "ld1 {v14.4s}, [%[din6]] \n" + "ld1 {v15.4s}, [%[din7]] \n" + + // in row1 + "fmla v16.4s, v1.4s, v9.4s \n" + "fmla v17.4s, v1.4s, v10.4s \n" + "fmla v18.4s, v1.4s, v11.4s \n" + "fmla v19.4s, v1.4s, v12.4s \n" + + // in row2 + "fmla v16.4s, v2.4s, v10.4s \n" + "fmla v17.4s, v2.4s, v11.4s \n" + "fmla v18.4s, v2.4s, v12.4s \n" + "fmla v19.4s, v2.4s, v13.4s \n" + + // in row3 + "fmla v16.4s, v3.4s, v11.4s \n" + "fmla v17.4s, v3.4s, v12.4s \n" + "fmla v18.4s, v3.4s, v13.4s \n" + "fmla v19.4s, v3.4s, v14.4s \n" + + // in row4 + "fmla v16.4s, v4.4s, v12.4s \n" + "fmla v17.4s, v4.4s, v13.4s \n" + "fmla v18.4s, v4.4s, v14.4s \n" + "fmla v19.4s, v4.4s, v15.4s \n" + + // add to out register v25 + "faddp v25.4s, v16.4s, v17.4s \n" + "faddp v26.4s, v18.4s, v19.4s \n" + "faddp v25.4s, v25.4s, v26.4s \n" + + // load weights col5 + "ld1 {v5.s}[0], [%[wh]], x0 \n" + "ld1 {v5.s}[1], [%[wh]], x0 \n" + "ld1 {v5.s}[2], [%[wh]], x0 \n" + "ld1 {v5.s}[3], [%[wh]], x0 \n" + "ld1 {v6.s}[0], [%[wh]] \n" + + // ext weights + "ext v0.16b, v0.16b, v31.16b, #4 \n" // 2, 3, 4 + "ext v1.16b, v1.16b, v31.16b, #4 \n" // 7, 8, 9 + "ext v2.16b, v2.16b, v31.16b, #4 \n" // 12, 13, 14 + "ext v3.16b, v3.16b, v31.16b, #4 \n" // 17, 18, 19 + "ext v4.16b, v4.16b, v31.16b, #4 \n" // 22, 23, 24 + + // in row0 + "fmul v16.4s, v0.4s, v8.4s \n" + "fmul v17.4s, v0.4s, v9.4s \n" + "fmul v18.4s, v0.4s, v10.4s \n" + "fmul v19.4s, v0.4s, v11.4s \n" + + // in row1 + "fmla v16.4s, v1.4s, v9.4s \n" + "fmla v17.4s, v1.4s, v10.4s \n" + "fmla v18.4s, v1.4s, v11.4s \n" + "fmla v19.4s, v1.4s, v12.4s \n" + + // in row2 + "fmla v16.4s, v2.4s, v10.4s \n" + "fmla v17.4s, v2.4s, v11.4s \n" + "fmla v18.4s, v2.4s, v12.4s \n" + "fmla v19.4s, v2.4s, v13.4s \n" + + // in row3 + "fmla v16.4s, v3.4s, v11.4s \n" + "fmla v17.4s, v3.4s, v12.4s \n" + "fmla v18.4s, v3.4s, v13.4s \n" + "fmla v19.4s, v3.4s, v14.4s \n" + + // in row4 + "fmla v16.4s, v4.4s, v12.4s \n" + "fmla v17.4s, v4.4s, v13.4s \n" + "fmla v18.4s, v4.4s, v14.4s \n" + "fmla v19.4s, v4.4s, v15.4s \n" + + // add to out register v27 + "faddp v27.4s, v16.4s, v17.4s \n" + "faddp v26.4s, v18.4s, v19.4s \n" + "faddp v27.4s, v27.4s, v26.4s \n" + + // load in col5 + "ld1 {v20.s}[0], [%[din0]] \n" + "ld1 {v20.s}[1], [%[din1]] \n" + "ld1 {v20.s}[2], [%[din2]] \n" + "ld1 {v20.s}[3], [%[din3]] \n" + + // ext weights + "ext v0.16b, v0.16b, v31.16b, #4 \n" // 3, 4 + "ext v1.16b, v1.16b, v31.16b, #4 \n" // 8, 9 + "ext v2.16b, v2.16b, v31.16b, #4 \n" // 13, 14 + "ext v3.16b, v3.16b, v31.16b, #4 \n" // 18, 19 + "ext v4.16b, v4.16b, v31.16b, #4 \n" // 23, 24 + + "ld1 {v21.s}[0], [%[din4]] \n" + "ld1 {v21.s}[1], [%[din5]] \n" + "ld1 {v21.s}[2], [%[din6]] \n" + "ld1 {v21.s}[3], [%[din7]] \n" + + // in row0 + "fmul v16.4s, v0.4s, v8.4s \n" + "fmul v17.4s, v0.4s, v9.4s \n" + "fmul v18.4s, v0.4s, v10.4s \n" + 
"fmul v19.4s, v0.4s, v11.4s \n" + + // in row1 + "fmla v16.4s, v1.4s, v9.4s \n" + "fmla v17.4s, v1.4s, v10.4s \n" + "fmla v18.4s, v1.4s, v11.4s \n" + "fmla v19.4s, v1.4s, v12.4s \n" + + // in row2 + "fmla v16.4s, v2.4s, v10.4s \n" + "fmla v17.4s, v2.4s, v11.4s \n" + "fmla v18.4s, v2.4s, v12.4s \n" + "fmla v19.4s, v2.4s, v13.4s \n" + + // in row3 + "fmla v16.4s, v3.4s, v11.4s \n" + "fmla v17.4s, v3.4s, v12.4s \n" + "fmla v18.4s, v3.4s, v13.4s \n" + "fmla v19.4s, v3.4s, v14.4s \n" + + // in row4 + "fmla v16.4s, v4.4s, v12.4s \n" + "fmla v17.4s, v4.4s, v13.4s \n" + "fmla v18.4s, v4.4s, v14.4s \n" + "fmla v19.4s, v4.4s, v15.4s \n" + + // add to out register v26 + "faddp v26.4s, v16.4s, v17.4s \n" + "faddp v28.4s, v18.4s, v19.4s \n" + "faddp v26.4s, v26.4s, v28.4s \n" + + // ext input col5 + "ext v22.16b, v20.16b, v21.16b, #4 \n" + "ext v23.16b, v20.16b, v21.16b, #8 \n" + "ext v24.16b, v20.16b, v21.16b, #12 \n" + + // in col5 + "fmul v16.4s, v5.4s, v20.4s \n" + "fmul v17.4s, v5.4s, v22.4s \n" + "fmul v18.4s, v5.4s, v23.4s \n" + "fmul v19.4s, v5.4s, v24.4s \n" + + // add to out register v28 + "faddp v28.4s, v16.4s, v17.4s \n" + "faddp v29.4s, v18.4s, v19.4s \n" + "faddp v28.4s, v28.4s, v29.4s \n" + "fmla v28.4s, v21.4s, v6.s[0] \n" + + "ld1 {v8.4s}, [%[bias]] \n" + + // zip + "zip1 v0.4s, v28.4s, v26.4s \n" + "zip2 v2.4s, v28.4s, v26.4s \n" + "zip1 v4.4s, v27.4s, v25.4s \n" + "zip2 v6.4s, v27.4s, v25.4s \n" + + // add bias + "fadd v0.4s, v0.4s, v8.4s \n" + "fadd v2.4s, v2.4s, v8.4s \n" + "fadd v4.4s, v4.4s, v8.4s \n" + "fadd v6.4s, v6.4s, v8.4s \n" + + // relu + "fmax v0.4s, v0.4s, v31.4s \n" + "fmax v2.4s, v2.4s, v31.4s \n" + "fmax v4.4s, v4.4s, v31.4s \n" + "fmax v6.4s, v6.4s, v31.4s \n" + + "ext v1.16b, v0.16b, v31.16b, #8 \n" + "ext v3.16b, v2.16b, v31.16b, #8 \n" + "ext v5.16b, v4.16b, v31.16b, #8 \n" + "ext v7.16b, v6.16b, v31.16b, #8 \n" + + // write output + "str d0, [%[dout0]], #8 \n" + "str d1, [%[dout1]], #8 \n" + "str d2, [%[dout2]], #8 \n" + "str d3, [%[dout3]], #8 \n" + + "str d4, [%[dout0]] \n" + "str d5, [%[dout1]] \n" + "str d6, [%[dout2]] \n" + "str d7, [%[dout3]] \n" + + : [dout0] "+r"(dout0), + [dout1] "+r"(dout1), + [dout2] "+r"(dout2), + [dout3] "+r"(dout3), + [din0] "+r"(din0), + [din1] "+r"(din1), + [din2] "+r"(din2), + [din3] "+r"(din3), + [din4] "+r"(din4), + [din5] "+r"(din5), + [din6] "+r"(din6), + [din7] "+r"(din7), + [wh] "+r"(weights) + : [bias] "r"(bias) + : "memory", + "x0", + "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15", + "v16", + "v17", + "v18", + "v19", + "v20", + "v21", + "v22", + "v23", + "v24", + "v25", + "v26", + "v27", + "v28", + "v29", + "v31"); +} + +//! kernel for four out with extracting data post +//! deal with four lines out +void compute_four_out_extract_post(const float* din0, + const float* din1, + const float* din2, + const float* din3, + const float* din4, + const float* din5, + const float* din6, + const float* din7, + float* dout0, + float* dout1, + float* dout2, + float* dout3, + float32x4_t w0, + float32x4_t w1, + float32x4_t w2, + float32x4_t w3, + float32x4_t w4, + const float* bias) { + //! din0 - din7: 0-4 v8-v15 + //! 
dout0 - dout3: v0-v3 + const int64_t s_12 = 12; + const float* doutl[4] = {dout0, dout1, dout2, dout3}; + void* doutl_ptr = reinterpret_cast(doutl); + asm volatile( + "movi v31.4s, #0 \n" + "ldp x0, x1, [%[doutl]], #16 \n" + "ldp x2, x3, [%[doutl]] \n" + + // load inputs + "ld1 {v8.4s}, [%[din0]], %[s_12] \n" + "ld1 {v9.4s}, [%[din1]], %[s_12] \n" + "ld1 {v10.4s}, [%[din2]], %[s_12] \n" + "ld1 {v11.4s}, [%[din3]], %[s_12] \n" + "ld1 {v12.4s}, [%[din4]], %[s_12] \n" + "ld1 {v13.4s}, [%[din5]], %[s_12] \n" + + // in row0 + "fmul v16.4s, %[w0].4s, v8.4s \n" + "fmul v17.4s, %[w0].4s, v9.4s \n" + "fmul v18.4s, %[w0].4s, v10.4s \n" + "fmul v19.4s, %[w0].4s, v11.4s \n" + + "ld1 {v14.4s}, [%[din6]], %[s_12] \n" + "ld1 {v15.4s}, [%[din7]], %[s_12] \n" + + // in row1 + "fmla v16.4s, %[w1].4s, v9.4s \n" + "fmla v17.4s, %[w1].4s, v10.4s \n" + "fmla v18.4s, %[w1].4s, v11.4s \n" + "fmla v19.4s, %[w1].4s, v12.4s \n" + + // in row2 + "fmla v16.4s, %[w2].4s, v10.4s \n" + "fmla v17.4s, %[w2].4s, v11.4s \n" + "fmla v18.4s, %[w2].4s, v12.4s \n" + "fmla v19.4s, %[w2].4s, v13.4s \n" + + // in row3 + "fmla v16.4s, %[w3].4s, v11.4s \n" + "fmla v17.4s, %[w3].4s, v12.4s \n" + "fmla v18.4s, %[w3].4s, v13.4s \n" + "fmla v19.4s, %[w3].4s, v14.4s \n" + + // in row4 + "fmla v16.4s, %[w4].4s, v12.4s \n" + "fmla v17.4s, %[w4].4s, v13.4s \n" + "fmla v18.4s, %[w4].4s, v14.4s \n" + "fmla v19.4s, %[w4].4s, v15.4s \n" + + // add to out register v25 + "faddp v25.4s, v16.4s, v17.4s \n" + "faddp v26.4s, v18.4s, v19.4s \n" + "faddp v25.4s, v25.4s, v26.4s \n" + + // load input col5 + "ld1 {v20.s}[0], [%[din0]] \n" + "ld1 {v20.s}[1], [%[din1]] \n" + "ld1 {v20.s}[2], [%[din2]] \n" + "ld1 {v20.s}[3], [%[din3]] \n" + + // ext input + "ext v8.16b, v8.16b, v31.16b, #4 \n" + "ext v9.16b, v9.16b, v31.16b, #4 \n" + "ext v10.16b, v10.16b, v31.16b, #4 \n" + "ext v11.16b, v11.16b, v31.16b, #4 \n" + "ext v12.16b, v12.16b, v31.16b, #4 \n" + "ext v13.16b, v13.16b, v31.16b, #4 \n" + "ext v14.16b, v14.16b, v31.16b, #4 \n" + "ext v15.16b, v15.16b, v31.16b, #4 \n" + + // load input col5 + "ld1 {v21.s}[0], [%[din4]] \n" + "ld1 {v21.s}[1], [%[din5]] \n" + "ld1 {v21.s}[2], [%[din6]] \n" + "ld1 {v21.s}[3], [%[din7]] \n" + + // in row0 + "fmul v16.4s, %[w0].4s, v8.4s \n" + "fmul v17.4s, %[w0].4s, v9.4s \n" + "fmul v18.4s, %[w0].4s, v10.4s \n" + "fmul v19.4s, %[w0].4s, v11.4s \n" + + // in row1 + "fmla v16.4s, %[w1].4s, v9.4s \n" + "fmla v17.4s, %[w1].4s, v10.4s \n" + "fmla v18.4s, %[w1].4s, v11.4s \n" + "fmla v19.4s, %[w1].4s, v12.4s \n" + + // in row2 + "fmla v16.4s, %[w2].4s, v10.4s \n" + "fmla v17.4s, %[w2].4s, v11.4s \n" + "fmla v18.4s, %[w2].4s, v12.4s \n" + "fmla v19.4s, %[w2].4s, v13.4s \n" + + // in row3 + "fmla v16.4s, %[w3].4s, v11.4s \n" + "fmla v17.4s, %[w3].4s, v12.4s \n" + "fmla v18.4s, %[w3].4s, v13.4s \n" + "fmla v19.4s, %[w3].4s, v14.4s \n" + + // in row4 + "fmla v16.4s, %[w4].4s, v12.4s \n" + "fmla v17.4s, %[w4].4s, v13.4s \n" + "fmla v18.4s, %[w4].4s, v14.4s \n" + "fmla v19.4s, %[w4].4s, v15.4s \n" + + // add to out register v27 + "faddp v27.4s, v16.4s, v17.4s \n" + "faddp v26.4s, v18.4s, v19.4s \n" + "faddp v27.4s, v27.4s, v26.4s \n" + + // ext input + "ext v8.16b, v8.16b, v31.16b, #4 \n" + "ext v9.16b, v9.16b, v31.16b, #4 \n" + "ext v10.16b, v10.16b, v31.16b, #4 \n" + "ext v11.16b, v11.16b, v31.16b, #4 \n" + "ext v12.16b, v12.16b, v31.16b, #4 \n" + "ext v13.16b, v13.16b, v31.16b, #4 \n" + "ext v14.16b, v14.16b, v31.16b, #4 \n" + "ext v15.16b, v15.16b, v31.16b, #4 \n" + + // in row0 + "fmul v16.4s, %[w0].4s, v8.4s \n" + "fmul v17.4s, 
%[w0].4s, v9.4s \n" + "fmul v18.4s, %[w0].4s, v10.4s \n" + "fmul v19.4s, %[w0].4s, v11.4s \n" + + // in row1 + "fmla v16.4s, %[w1].4s, v9.4s \n" + "fmla v17.4s, %[w1].4s, v10.4s \n" + "fmla v18.4s, %[w1].4s, v11.4s \n" + "fmla v19.4s, %[w1].4s, v12.4s \n" + + // in row2 + "fmla v16.4s, %[w2].4s, v10.4s \n" + "fmla v17.4s, %[w2].4s, v11.4s \n" + "fmla v18.4s, %[w2].4s, v12.4s \n" + "fmla v19.4s, %[w2].4s, v13.4s \n" + + // in row3 + "fmla v16.4s, %[w3].4s, v11.4s \n" + "fmla v17.4s, %[w3].4s, v12.4s \n" + "fmla v18.4s, %[w3].4s, v13.4s \n" + "fmla v19.4s, %[w3].4s, v14.4s \n" + + // in row4 + "fmla v16.4s, %[w4].4s, v12.4s \n" + "fmla v17.4s, %[w4].4s, v13.4s \n" + "fmla v18.4s, %[w4].4s, v14.4s \n" + "fmla v19.4s, %[w4].4s, v15.4s \n" + + // add to out register v26 + "faddp v26.4s, v16.4s, v17.4s \n" + "faddp v28.4s, v18.4s, v19.4s \n" + "faddp v26.4s, v26.4s, v28.4s \n" + + // ext input col5 + "ext v8.16b, v20.16b, v21.16b, #4 \n" + "ext v9.16b, v20.16b, v21.16b, #8 \n" + "ext v10.16b, v20.16b, v21.16b, #12 \n" + + // ext weights col0 + "ins v5.s[0], %[w0].s[0] \n" + "ins v5.s[1], %[w1].s[0] \n" + "ins v5.s[2], %[w2].s[0] \n" + "ins v5.s[3], %[w3].s[0] \n" + + // in col5 + "fmul v16.4s, v5.4s, v20.4s \n" + "fmul v17.4s, v5.4s, v8.4s \n" + "fmul v18.4s, v5.4s, v9.4s \n" + "fmul v19.4s, v5.4s, v10.4s \n" + + // add to out register v28 + "faddp v28.4s, v16.4s, v17.4s \n" + "faddp v29.4s, v18.4s, v19.4s \n" + "faddp v28.4s, v28.4s, v29.4s \n" + "fmla v28.4s, v21.4s, %[w4].s[0] \n" + + "ld1 {v8.4s}, [%[bias]] \n" + + // zip + "zip1 v0.4s, v25.4s, v27.4s \n" + "zip2 v2.4s, v25.4s, v27.4s \n" + "zip1 v4.4s, v26.4s, v28.4s \n" + "zip2 v6.4s, v26.4s, v28.4s \n" + + "fadd v0.4s, v0.4s, v8.4s \n" + "fadd v2.4s, v2.4s, v8.4s \n" + "fadd v4.4s, v4.4s, v8.4s \n" + "fadd v6.4s, v6.4s, v8.4s \n" + + "ext v1.16b, v0.16b, v31.16b, #8 \n" + "ext v3.16b, v2.16b, v31.16b, #8 \n" + "ext v5.16b, v4.16b, v31.16b, #8 \n" + "ext v7.16b, v6.16b, v31.16b, #8 \n" + + // write output + "str d0, [x0], #8 \n" + "str d1, [x1], #8 \n" + "str d2, [x2], #8 \n" + "str d3, [x3], #8 \n" + + "str d4, [x0] \n" + "str d5, [x1] \n" + "str d6, [x2] \n" + "str d7, [x3] \n" + + : [din0] "+r"(din0), + [din1] "+r"(din1), + [din2] "+r"(din2), + [din3] "+r"(din3), + [din4] "+r"(din4), + [din5] "+r"(din5), + [din6] "+r"(din6), + [din7] "+r"(din7), + [doutl] "+r"(doutl_ptr) + : [s_12] "r"(s_12), + [w0] "w"(w0), + [w1] "w"(w1), + [w2] "w"(w2), + [w3] "w"(w3), + [w4] "w"(w4), + [bias] "r"(bias) + : "memory", + "x0", + "x1", + "x2", + "x3", + "v0", + "v1", + "v2", + "v3", + "v5", + "v7", + "v8", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15", + "v16", + "v17", + "v18", + "v19", + "v20", + "v21", + "v25", + "v26", + "v27", + "v28", + "v29", + "v31"); +} + +//! kernel for four out with extracting data post +//! deal with four lines out +void compute_four_out_extract_post_relu(const float* din0, + const float* din1, + const float* din2, + const float* din3, + const float* din4, + const float* din5, + const float* din6, + const float* din7, + float* dout0, + float* dout1, + float* dout2, + float* dout3, + float32x4_t w0, + float32x4_t w1, + float32x4_t w2, + float32x4_t w3, + float32x4_t w4, + const float* bias) { + //! din0 - din7: 0-4 v8-v15 + //! 
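
// The zip1/zip2 + ext + paired "str d" epilogue above is a 4x4
// transpose: each faddp result vector carries one border column
// (one lane per output row), while the stores need row-major data.
// Scalar picture (illustrative sketch only):
static void transpose_4x4(const float col_major[4][4],  // [border col][out row]
                          float row_major[4][4]) {      // [out row][border col]
  for (int r = 0; r < 4; ++r)
    for (int c = 0; c < 4; ++c)
      row_major[r][c] = col_major[c][r];
}
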
dout0 - dout3: v0-v3 + const int64_t s_12 = 12; + const float* doutl[4] = {dout0, dout1, dout2, dout3}; + void* doutl_ptr = reinterpret_cast(doutl); + asm volatile( + "movi v31.4s, #0 \n" + "ldp x0, x1, [%[doutl]], #16 \n" + "ldp x2, x3, [%[doutl]] \n" + + // load inputs + "ld1 {v8.4s}, [%[din0]], %[s_12] \n" + "ld1 {v9.4s}, [%[din1]], %[s_12] \n" + "ld1 {v10.4s}, [%[din2]], %[s_12] \n" + "ld1 {v11.4s}, [%[din3]], %[s_12] \n" + "ld1 {v12.4s}, [%[din4]], %[s_12] \n" + "ld1 {v13.4s}, [%[din5]], %[s_12] \n" + + // in row0 + "fmul v16.4s, %[w0].4s, v8.4s \n" + "fmul v17.4s, %[w0].4s, v9.4s \n" + "fmul v18.4s, %[w0].4s, v10.4s \n" + "fmul v19.4s, %[w0].4s, v11.4s \n" + + "ld1 {v14.4s}, [%[din6]], %[s_12] \n" + "ld1 {v15.4s}, [%[din7]], %[s_12] \n" + + // in row1 + "fmla v16.4s, %[w1].4s, v9.4s \n" + "fmla v17.4s, %[w1].4s, v10.4s \n" + "fmla v18.4s, %[w1].4s, v11.4s \n" + "fmla v19.4s, %[w1].4s, v12.4s \n" + + // in row2 + "fmla v16.4s, %[w2].4s, v10.4s \n" + "fmla v17.4s, %[w2].4s, v11.4s \n" + "fmla v18.4s, %[w2].4s, v12.4s \n" + "fmla v19.4s, %[w2].4s, v13.4s \n" + + // in row3 + "fmla v16.4s, %[w3].4s, v11.4s \n" + "fmla v17.4s, %[w3].4s, v12.4s \n" + "fmla v18.4s, %[w3].4s, v13.4s \n" + "fmla v19.4s, %[w3].4s, v14.4s \n" + + // in row4 + "fmla v16.4s, %[w4].4s, v12.4s \n" + "fmla v17.4s, %[w4].4s, v13.4s \n" + "fmla v18.4s, %[w4].4s, v14.4s \n" + "fmla v19.4s, %[w4].4s, v15.4s \n" + + // add to out register v25 + "faddp v25.4s, v16.4s, v17.4s \n" + "faddp v26.4s, v18.4s, v19.4s \n" + "faddp v25.4s, v25.4s, v26.4s \n" + + // load input col5 + "ld1 {v20.s}[0], [%[din0]] \n" + "ld1 {v20.s}[1], [%[din1]] \n" + "ld1 {v20.s}[2], [%[din2]] \n" + "ld1 {v20.s}[3], [%[din3]] \n" + + // ext input + "ext v8.16b, v8.16b, v31.16b, #4 \n" + "ext v9.16b, v9.16b, v31.16b, #4 \n" + "ext v10.16b, v10.16b, v31.16b, #4 \n" + "ext v11.16b, v11.16b, v31.16b, #4 \n" + "ext v12.16b, v12.16b, v31.16b, #4 \n" + "ext v13.16b, v13.16b, v31.16b, #4 \n" + "ext v14.16b, v14.16b, v31.16b, #4 \n" + "ext v15.16b, v15.16b, v31.16b, #4 \n" + + // load input col5 + "ld1 {v21.s}[0], [%[din4]] \n" + "ld1 {v21.s}[1], [%[din5]] \n" + "ld1 {v21.s}[2], [%[din6]] \n" + "ld1 {v21.s}[3], [%[din7]] \n" + + // in row0 + "fmul v16.4s, %[w0].4s, v8.4s \n" + "fmul v17.4s, %[w0].4s, v9.4s \n" + "fmul v18.4s, %[w0].4s, v10.4s \n" + "fmul v19.4s, %[w0].4s, v11.4s \n" + + // in row1 + "fmla v16.4s, %[w1].4s, v9.4s \n" + "fmla v17.4s, %[w1].4s, v10.4s \n" + "fmla v18.4s, %[w1].4s, v11.4s \n" + "fmla v19.4s, %[w1].4s, v12.4s \n" + + // in row2 + "fmla v16.4s, %[w2].4s, v10.4s \n" + "fmla v17.4s, %[w2].4s, v11.4s \n" + "fmla v18.4s, %[w2].4s, v12.4s \n" + "fmla v19.4s, %[w2].4s, v13.4s \n" + + // in row3 + "fmla v16.4s, %[w3].4s, v11.4s \n" + "fmla v17.4s, %[w3].4s, v12.4s \n" + "fmla v18.4s, %[w3].4s, v13.4s \n" + "fmla v19.4s, %[w3].4s, v14.4s \n" + + // in row4 + "fmla v16.4s, %[w4].4s, v12.4s \n" + "fmla v17.4s, %[w4].4s, v13.4s \n" + "fmla v18.4s, %[w4].4s, v14.4s \n" + "fmla v19.4s, %[w4].4s, v15.4s \n" + + // add to out register v27 + "faddp v27.4s, v16.4s, v17.4s \n" + "faddp v26.4s, v18.4s, v19.4s \n" + "faddp v27.4s, v27.4s, v26.4s \n" + + // ext input + "ext v8.16b, v8.16b, v31.16b, #4 \n" + "ext v9.16b, v9.16b, v31.16b, #4 \n" + "ext v10.16b, v10.16b, v31.16b, #4 \n" + "ext v11.16b, v11.16b, v31.16b, #4 \n" + "ext v12.16b, v12.16b, v31.16b, #4 \n" + "ext v13.16b, v13.16b, v31.16b, #4 \n" + "ext v14.16b, v14.16b, v31.16b, #4 \n" + "ext v15.16b, v15.16b, v31.16b, #4 \n" + + // in row0 + "fmul v16.4s, %[w0].4s, v8.4s \n" + "fmul v17.4s, 
%[w0].4s, v9.4s \n" + "fmul v18.4s, %[w0].4s, v10.4s \n" + "fmul v19.4s, %[w0].4s, v11.4s \n" + + // in row1 + "fmla v16.4s, %[w1].4s, v9.4s \n" + "fmla v17.4s, %[w1].4s, v10.4s \n" + "fmla v18.4s, %[w1].4s, v11.4s \n" + "fmla v19.4s, %[w1].4s, v12.4s \n" + + // in row2 + "fmla v16.4s, %[w2].4s, v10.4s \n" + "fmla v17.4s, %[w2].4s, v11.4s \n" + "fmla v18.4s, %[w2].4s, v12.4s \n" + "fmla v19.4s, %[w2].4s, v13.4s \n" + + // in row3 + "fmla v16.4s, %[w3].4s, v11.4s \n" + "fmla v17.4s, %[w3].4s, v12.4s \n" + "fmla v18.4s, %[w3].4s, v13.4s \n" + "fmla v19.4s, %[w3].4s, v14.4s \n" + + // in row4 + "fmla v16.4s, %[w4].4s, v12.4s \n" + "fmla v17.4s, %[w4].4s, v13.4s \n" + "fmla v18.4s, %[w4].4s, v14.4s \n" + "fmla v19.4s, %[w4].4s, v15.4s \n" + + // add to out register v26 + "faddp v26.4s, v16.4s, v17.4s \n" + "faddp v28.4s, v18.4s, v19.4s \n" + "faddp v26.4s, v26.4s, v28.4s \n" + + // ext input col5 + "ext v8.16b, v20.16b, v21.16b, #4 \n" + "ext v9.16b, v20.16b, v21.16b, #8 \n" + "ext v10.16b, v20.16b, v21.16b, #12 \n" + + // ext weights col0 + "ins v5.s[0], %[w0].s[0] \n" + "ins v5.s[1], %[w1].s[0] \n" + "ins v5.s[2], %[w2].s[0] \n" + "ins v5.s[3], %[w3].s[0] \n" + + // in col5 + "fmul v16.4s, v5.4s, v20.4s \n" + "fmul v17.4s, v5.4s, v8.4s \n" + "fmul v18.4s, v5.4s, v9.4s \n" + "fmul v19.4s, v5.4s, v10.4s \n" + + // add to out register v28 + "faddp v28.4s, v16.4s, v17.4s \n" + "faddp v29.4s, v18.4s, v19.4s \n" + "faddp v28.4s, v28.4s, v29.4s \n" + "fmla v28.4s, v21.4s, %[w4].s[0] \n" + + "ld1 {v8.4s}, [%[bias]] \n" + + // zip + "zip1 v0.4s, v25.4s, v27.4s \n" + "zip2 v2.4s, v25.4s, v27.4s \n" + "zip1 v4.4s, v26.4s, v28.4s \n" + "zip2 v6.4s, v26.4s, v28.4s \n" + + // add bias + "fadd v0.4s, v0.4s, v8.4s \n" + "fadd v2.4s, v2.4s, v8.4s \n" + "fadd v4.4s, v4.4s, v8.4s \n" + "fadd v6.4s, v6.4s, v8.4s \n" + + // relu + "fmax v0.4s, v0.4s, v31.4s \n" + "fmax v2.4s, v2.4s, v31.4s \n" + "fmax v4.4s, v4.4s, v31.4s \n" + "fmax v6.4s, v6.4s, v31.4s \n" + + "ext v1.16b, v0.16b, v31.16b, #8 \n" + "ext v3.16b, v2.16b, v31.16b, #8 \n" + "ext v5.16b, v4.16b, v31.16b, #8 \n" + "ext v7.16b, v6.16b, v31.16b, #8 \n" + + // write output + "str d0, [x0], #8 \n" + "str d1, [x1], #8 \n" + "str d2, [x2], #8 \n" + "str d3, [x3], #8 \n" + + "str d4, [x0] \n" + "str d5, [x1] \n" + "str d6, [x2] \n" + "str d7, [x3] \n" + + : [din0] "+r"(din0), + [din1] "+r"(din1), + [din2] "+r"(din2), + [din3] "+r"(din3), + [din4] "+r"(din4), + [din5] "+r"(din5), + [din6] "+r"(din6), + [din7] "+r"(din7), + [doutl] "+r"(doutl_ptr) + : [s_12] "r"(s_12), + [w0] "w"(w0), + [w1] "w"(w1), + [w2] "w"(w2), + [w3] "w"(w3), + [w4] "w"(w4), + [bias] "r"(bias) + : "memory", + "x0", + "x1", + "x2", + "x3", + "v0", + "v1", + "v2", + "v3", + "v5", + "v7", + "v8", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15", + "v16", + "v17", + "v18", + "v19", + "v20", + "v21", + "v25", + "v26", + "v27", + "v28", + "v29", + "v31"); +} + +void conv_depthwise_5x5s1_impl(const float* din, + float* dout, + int num, + int ch_out, + int h_out, + int w_out, + int ch_in, + int h_in, + int w_in, + const float* weights, + const float* bias, + int pad, + bool flag_bias, + bool flag_relu, + ARMContext* ctx) { + float* zero_ptr = ctx->workspace_data(); + memset(zero_ptr, 0, w_in * sizeof(float)); + float* write_ptr = zero_ptr + w_in; + int pad_new = pad > 4 ? 
4 : pad; + int pad_0 = pad - pad_new; + int h_out_new = h_out - 2 * pad_0; + int mid_out = w_out - 2 * pad; + int mid_cnt = mid_out >> 2; + int mid_remain = mid_out - (mid_cnt << 2); + int pad_cnt = pad_0 >> 2; + int pad_remain = pad_0 - (pad_cnt << 2); + int bias_cnt = (w_out * pad_0) >> 2; + int bias_remain = (w_out * pad_0) - (bias_cnt << 2); + int in_spatial_size = w_in * h_in; + int out_spatial_size = w_out * h_out; + int weights_saptial_size = 25; + + for (int n = 0; n < num; ++n) { + const float* din_batch = din + n * in_spatial_size * ch_in; + float* dout_batch = dout + n * out_spatial_size * ch_out; +#pragma omp parallel for + for (int c = 0; c < ch_in; ++c) { + const float* din_ch = din_batch + c * in_spatial_size; + float* dout_ch = dout_batch + c * out_spatial_size; + float bias_c = flag_bias ? bias[c] : 0.f; + float vbias[4] = {bias_c, bias_c, bias_c, bias_c}; + float32x4_t vbias_c = vdupq_n_f32(bias_c); + if (flag_bias) { + //! deal with h_out pad_0 line with bias + for (int i = 0; i < bias_cnt; ++i) { + vst1q_f32(dout_ch, vbias_c); + dout_ch += 4; + } + for (int i = 0; i < bias_remain; ++i) { + *dout_ch++ = bias_c; + } + } else { + //! deal with h_out pad_0 line without bias + for (int i = 0; i < pad_0; ++i) { + memset(dout_ch, 0x00, w_out * sizeof(float)); + dout_ch += w_out; + } + } + const float* din_list[8]; + const float* dinl[8]; + //! set din ptr with zero buffer + for (int i = 0; i < pad_new; ++i) { + din_list[i] = zero_ptr; + } + //! set din ptr with input data + for (int i = pad_new; i < 8; ++i) { + din_list[i] = din_ch; + din_ch += w_in; + } + + //! every h loop, deal with 4 line output + float* dout0 = dout_ch; + float* dout1 = dout0 + w_out; + float* dout2 = dout1 + w_out; + float* dout3 = dout2 + w_out; + + //! load weights to neon register + const float* weights_c = weights + c * weights_saptial_size; + + float32x4_t w5; + float32x4_t w6; + float32x4_t w0 = vld1q_f32(weights_c); + float32x4_t w1 = vld1q_f32(weights_c + 5); + float32x4_t w2 = vld1q_f32(weights_c + 10); + float32x4_t w3 = vld1q_f32(weights_c + 15); + float32x4_t w4 = vld1q_f32(weights_c + 20); + w5 = vsetq_lane_f32(weights_c[4], w5, 0); + w5 = vsetq_lane_f32(weights_c[9], w5, 1); + w5 = vsetq_lane_f32(weights_c[14], w5, 2); + w5 = vsetq_lane_f32(weights_c[19], w5, 3); + w6 = vsetq_lane_f32(weights_c[24], w6, 0); + + //! h loop + for (int h = 0; h < h_out_new; h += 4) { + //! (h - pad_new) + 7 > h_in - 1 + if (h + 8 - pad_new > h_in) { + switch (h + 8 - pad_new - h_in) { + case 7: + din_list[1] = zero_ptr; + case 6: + din_list[2] = zero_ptr; + case 5: + din_list[3] = zero_ptr; + case 4: + din_list[4] = zero_ptr; + case 3: + din_list[5] = zero_ptr; + case 2: + din_list[6] = zero_ptr; + case 1: + din_list[7] = zero_ptr; + default: + break; + } + } + if (h + 4 > h_out_new) { + switch (h + 4 - h_out_new) { + case 3: + dout1 = write_ptr; + case 2: + dout2 = write_ptr; + case 1: + dout3 = write_ptr; + default: + break; + } + } + + //! every h loop, deal with 8 line input + dinl[0] = din_list[0]; + dinl[1] = din_list[1]; + dinl[2] = din_list[2]; + dinl[3] = din_list[3]; + dinl[4] = din_list[4]; + dinl[5] = din_list[5]; + dinl[6] = din_list[6]; + dinl[7] = din_list[7]; + + const float* weights_ptr = weights_c; + float* dout_ptr0 = dout0; + float* dout_ptr1 = dout1; + float* dout_ptr2 = dout2; + float* dout_ptr3 = dout3; + if (flag_bias) { + //! 
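
// The decomposition driving this function, restated as a self-contained
// helper (a sketch mirroring the arithmetic above; not part of the patch).
// pad splits into pad_new (<= 4, handled by the edge-extract kernels,
// since a 5x5 stride-1 window overhangs at most 4 columns) and pad_0
// (windows lying entirely over zero padding, whose output is just the
// bias, or 0 without bias):
struct Dw5x5Geometry {
  int pad_new, pad_0, h_out_new, mid_out, mid_cnt, mid_remain;
};
static Dw5x5Geometry dw5x5_geometry(int pad, int h_out, int w_out) {
  Dw5x5Geometry g;
  g.pad_new = pad > 4 ? 4 : pad;
  g.pad_0 = pad - g.pad_new;          // all-zero border rows/columns
  g.h_out_new = h_out - 2 * g.pad_0;  // rows that are actually convolved
  g.mid_out = w_out - 2 * pad;        // interior columns with full windows
  g.mid_cnt = g.mid_out >> 2;         // interior processed 4 columns at a time
  g.mid_remain = g.mid_out - (g.mid_cnt << 2);
  return g;
}
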
deal with w_out pad_0 column pre with bias + for (int i = 0; i < pad_cnt; i++) { + vst1q_f32(dout_ptr0, vbias_c); + vst1q_f32(dout_ptr1, vbias_c); + vst1q_f32(dout_ptr2, vbias_c); + vst1q_f32(dout_ptr3, vbias_c); + dout_ptr0 += 4; + dout_ptr1 += 4; + dout_ptr2 += 4; + dout_ptr3 += 4; + } + for (int i = 0; i < pad_remain; ++i) { + *dout_ptr0++ = bias_c; + *dout_ptr1++ = bias_c; + *dout_ptr2++ = bias_c; + *dout_ptr3++ = bias_c; + } + } else { + //! deal with w_out pad_0 column pre without bias + memset(dout_ptr0, 0x00, pad_0 * sizeof(float)); + memset(dout_ptr1, 0x00, pad_0 * sizeof(float)); + memset(dout_ptr2, 0x00, pad_0 * sizeof(float)); + memset(dout_ptr3, 0x00, pad_0 * sizeof(float)); + dout_ptr0 += pad_0; + dout_ptr1 += pad_0; + dout_ptr2 += pad_0; + dout_ptr3 += pad_0; + } + //! deal with w_out pad_new column pre + switch (pad_new) { + case 4: + compute_four_out_extract_pre(dinl[0], + dinl[1], + dinl[2], + dinl[3], + dinl[4], + dinl[5], + dinl[6], + dinl[7], + dout_ptr0, + dout_ptr1, + dout_ptr2, + dout_ptr3, + weights_ptr, + vbias); + dout_ptr0 += 4; + dout_ptr1 += 4; + dout_ptr2 += 4; + dout_ptr3 += 4; + break; + case 3: + compute_three_out_extract_pre(dinl[0], + dinl[1], + dinl[2], + dinl[3], + dinl[4], + dinl[5], + dinl[6], + dinl[7], + dout_ptr0, + dout_ptr1, + dout_ptr2, + dout_ptr3, + weights_ptr, + vbias); + dout_ptr0 += 3; + dout_ptr1 += 3; + dout_ptr2 += 3; + dout_ptr3 += 3; + break; + case 2: + compute_two_out_extract_pre(dinl[0], + dinl[1], + dinl[2], + dinl[3], + dinl[4], + dinl[5], + dinl[6], + dinl[7], + dout_ptr0, + dout_ptr1, + dout_ptr2, + dout_ptr3, + weights_ptr, + vbias); + dout_ptr0 += 2; + dout_ptr1 += 2; + dout_ptr2 += 2; + dout_ptr3 += 2; + break; + case 1: + compute_one_out_extract_pre(dinl[0], + dinl[1], + dinl[2], + dinl[3], + dinl[4], + dinl[5], + dinl[6], + dinl[7], + dout_ptr0, + dout_ptr1, + dout_ptr2, + dout_ptr3, + weights_ptr, + vbias); + dout_ptr0 += 1; + dout_ptr1 += 1; + dout_ptr2 += 1; + dout_ptr3 += 1; + break; + } + //! mid loop + if (mid_cnt > 0) { + void* dinl_ptr = reinterpret_cast(dinl); + int mid_loop = mid_cnt; + asm volatile( + //! din: v7-v14 + //! dout: v15-v18 + "mov x0, #0 \n" + "mov x1, #4 \n" + "ldp x2, x3, [%[dinl]], #16 \n" + "ldp x4, x5, [%[dinl]], #16 \n" + "ldp x6, x7, [%[dinl]], #16 \n" + "ldp x8, x9, [%[dinl]], #16 \n" + + "ld1 {v7.4s} , [x2], x1 \n" + "ld1 {v8.4s} , [x3], x1 \n" + "ld1 {v9.4s} , [x4], x1 \n" + "ld1 {v10.4s}, [x5], x1 \n" + "ld1 {v11.4s}, [x6], x1 \n" + "ld1 {v12.4s}, [x7], x1 \n" + "ld1 {v13.4s}, [x8], x1 \n" + "ld1 {v14.4s}, [x9], x1 \n" + + //! load bias + "ld1 {v19.4s}, [%[bias]] \n" + + "1: \n" + //! add bias to output + "mov v15.16b, v19.16b \n" + "mov v16.16b, v19.16b \n" + "mov v17.16b, v19.16b \n" + "mov v18.16b, v19.16b \n" + + //! 
loop cnt is even, prefetch 64 Byte to l1 cache + "cmp x0, #1 \n" + "bne 2f \n" + "mov x0, #0 \n" + "prfm pldl1keep, [x2] \n" + "prfm pldl1keep, [x3] \n" + "prfm pldl1keep, [x4] \n" + "prfm pldl1keep, [x5] \n" + "prfm pldl1keep, [x6] \n" + "prfm pldl1keep, [x7] \n" + "prfm pldl1keep, [x8] \n" + "prfm pldl1keep, [x9] \n" + + "2: \n" + // weights col 0 + "fmla v15.4s, v7.4s , %[w0].s[0] \n" + "fmla v16.4s, v8.4s , %[w0].s[0] \n" + "fmla v17.4s, v9.4s , %[w0].s[0] \n" + "fmla v18.4s, v10.4s, %[w0].s[0] \n" + + "fmla v15.4s, v8.4s , %[w1].s[0] \n" + "fmla v16.4s, v9.4s , %[w1].s[0] \n" + "fmla v17.4s, v10.4s, %[w1].s[0] \n" + "fmla v18.4s, v11.4s, %[w1].s[0] \n" + + "ld1 {v7.4s}, [x2], x1 \n" + "ld1 {v8.4s}, [x3], x1 \n" + + "fmla v15.4s, v9.4s , %[w2].s[0] \n" + "fmla v16.4s, v10.4s, %[w2].s[0] \n" + "fmla v17.4s, v11.4s, %[w2].s[0] \n" + "fmla v18.4s, v12.4s, %[w2].s[0] \n" + + "fmla v15.4s, v10.4s, %[w3].s[0] \n" + "fmla v16.4s, v11.4s, %[w3].s[0] \n" + "fmla v17.4s, v12.4s, %[w3].s[0] \n" + "fmla v18.4s, v13.4s, %[w3].s[0] \n" + + "ld1 {v9.4s} , [x4], x1 \n" + "ld1 {v10.4s}, [x5], x1 \n" + + "fmla v15.4s, v11.4s, %[w4].s[0] \n" + "fmla v16.4s, v12.4s, %[w4].s[0] \n" + "fmla v17.4s, v13.4s, %[w4].s[0] \n" + "fmla v18.4s, v14.4s, %[w4].s[0] \n" + + "ld1 {v11.4s}, [x6], x1 \n" + "ld1 {v12.4s}, [x7], x1 \n" + + // weights col 1 + "fmla v15.4s, v7.4s , %[w0].s[1] \n" + "fmla v16.4s, v8.4s , %[w0].s[1] \n" + "fmla v17.4s, v9.4s , %[w0].s[1] \n" + "fmla v18.4s, v10.4s, %[w0].s[1] \n" + + "ld1 {v13.4s}, [x8], x1 \n" + "ld1 {v14.4s}, [x9], x1 \n" + + "fmla v15.4s, v8.4s , %[w1].s[1] \n" + "fmla v16.4s, v9.4s , %[w1].s[1] \n" + "fmla v17.4s, v10.4s, %[w1].s[1] \n" + "fmla v18.4s, v11.4s, %[w1].s[1] \n" + + "ld1 {v7.4s}, [x2], x1 \n" + "ld1 {v8.4s}, [x3], x1 \n" + + "fmla v15.4s, v9.4s , %[w2].s[1] \n" + "fmla v16.4s, v10.4s, %[w2].s[1] \n" + "fmla v17.4s, v11.4s, %[w2].s[1] \n" + "fmla v18.4s, v12.4s, %[w2].s[1] \n" + + "fmla v15.4s, v10.4s, %[w3].s[1] \n" + "fmla v16.4s, v11.4s, %[w3].s[1] \n" + "fmla v17.4s, v12.4s, %[w3].s[1] \n" + "fmla v18.4s, v13.4s, %[w3].s[1] \n" + + "ld1 {v9.4s} , [x4], x1 \n" + "ld1 {v10.4s}, [x5], x1 \n" + + "fmla v15.4s, v11.4s, %[w4].s[1] \n" + "fmla v16.4s, v12.4s, %[w4].s[1] \n" + "fmla v17.4s, v13.4s, %[w4].s[1] \n" + "fmla v18.4s, v14.4s, %[w4].s[1] \n" + + "ld1 {v11.4s}, [x6], x1 \n" + "ld1 {v12.4s}, [x7], x1 \n" + + // weights col 2 + "fmla v15.4s, v7.4s , %[w0].s[2] \n" + "fmla v16.4s, v8.4s , %[w0].s[2] \n" + "fmla v17.4s, v9.4s , %[w0].s[2] \n" + "fmla v18.4s, v10.4s, %[w0].s[2] \n" + + "ld1 {v13.4s}, [x8], x1 \n" + "ld1 {v14.4s}, [x9], x1 \n" + + "fmla v15.4s, v8.4s , %[w1].s[2] \n" + "fmla v16.4s, v9.4s , %[w1].s[2] \n" + "fmla v17.4s, v10.4s, %[w1].s[2] \n" + "fmla v18.4s, v11.4s, %[w1].s[2] \n" + + "ld1 {v7.4s}, [x2], x1 \n" + "ld1 {v8.4s}, [x3], x1 \n" + + "fmla v15.4s, v9.4s , %[w2].s[2] \n" + "fmla v16.4s, v10.4s, %[w2].s[2] \n" + "fmla v17.4s, v11.4s, %[w2].s[2] \n" + "fmla v18.4s, v12.4s, %[w2].s[2] \n" + + "fmla v15.4s, v10.4s, %[w3].s[2] \n" + "fmla v16.4s, v11.4s, %[w3].s[2] \n" + "fmla v17.4s, v12.4s, %[w3].s[2] \n" + "fmla v18.4s, v13.4s, %[w3].s[2] \n" + + "ld1 {v9.4s} , [x4], x1 \n" + "ld1 {v10.4s}, [x5], x1 \n" + + "fmla v15.4s, v11.4s, %[w4].s[2] \n" + "fmla v16.4s, v12.4s, %[w4].s[2] \n" + "fmla v17.4s, v13.4s, %[w4].s[2] \n" + "fmla v18.4s, v14.4s, %[w4].s[2] \n" + + "ld1 {v11.4s}, [x6], x1 \n" + "ld1 {v12.4s}, [x7], x1 \n" + + // weights col 3 + "fmla v15.4s, v7.4s , %[w0].s[3] \n" + "fmla v16.4s, v8.4s , %[w0].s[3] \n" + "fmla v17.4s, v9.4s 
, %[w0].s[3] \n" + "fmla v18.4s, v10.4s, %[w0].s[3] \n" + + "ld1 {v13.4s}, [x8], x1 \n" + "ld1 {v14.4s}, [x9], x1 \n" + + "fmla v15.4s, v8.4s , %[w1].s[3] \n" + "fmla v16.4s, v9.4s , %[w1].s[3] \n" + "fmla v17.4s, v10.4s, %[w1].s[3] \n" + "fmla v18.4s, v11.4s, %[w1].s[3] \n" + + "ld1 {v7.4s}, [x2], x1 \n" + "ld1 {v8.4s}, [x3], x1 \n" + + "fmla v15.4s, v9.4s , %[w2].s[3] \n" + "fmla v16.4s, v10.4s, %[w2].s[3] \n" + "fmla v17.4s, v11.4s, %[w2].s[3] \n" + "fmla v18.4s, v12.4s, %[w2].s[3] \n" + + "fmla v15.4s, v10.4s, %[w3].s[3] \n" + "fmla v16.4s, v11.4s, %[w3].s[3] \n" + "fmla v17.4s, v12.4s, %[w3].s[3] \n" + "fmla v18.4s, v13.4s, %[w3].s[3] \n" + + "ld1 {v9.4s} , [x4], x1 \n" + "ld1 {v10.4s}, [x5], x1 \n" + + "fmla v15.4s, v11.4s, %[w4].s[3] \n" + "fmla v16.4s, v12.4s, %[w4].s[3] \n" + "fmla v17.4s, v13.4s, %[w4].s[3] \n" + "fmla v18.4s, v14.4s, %[w4].s[3] \n" + + "ld1 {v11.4s}, [x6], x1 \n" + "ld1 {v12.4s}, [x7], x1 \n" + + // weights col 4 + "fmla v15.4s, v7.4s, %[w5].s[0] \n" + "fmla v16.4s, v8.4s, %[w5].s[0] \n" + "fmla v17.4s, v9.4s, %[w5].s[0] \n" + "fmla v18.4s, v10.4s, %[w5].s[0] \n" + + "ld1 {v13.4s}, [x8], x1 \n" + "ld1 {v14.4s}, [x9], x1 \n" + + "fmla v15.4s, v8.4s, %[w5].s[1] \n" + "fmla v16.4s, v9.4s, %[w5].s[1] \n" + "fmla v17.4s, v10.4s, %[w5].s[1] \n" + "fmla v18.4s, v11.4s, %[w5].s[1] \n" + + "fmla v15.4s, v9.4s , %[w5].s[2] \n" + "fmla v16.4s, v10.4s, %[w5].s[2] \n" + "fmla v17.4s, v11.4s, %[w5].s[2] \n" + "fmla v18.4s, v12.4s, %[w5].s[2] \n" + + "fmla v15.4s, v10.4s, %[w5].s[3] \n" + "fmla v16.4s, v11.4s, %[w5].s[3] \n" + "fmla v17.4s, v12.4s, %[w5].s[3] \n" + "fmla v18.4s, v13.4s, %[w5].s[3] \n" + + "fmla v15.4s, v11.4s, %[w6].s[0] \n" + "fmla v16.4s, v12.4s, %[w6].s[0] \n" + "fmla v17.4s, v13.4s, %[w6].s[0] \n" + "fmla v18.4s, v14.4s, %[w6].s[0] \n" + + "st1 {v15.4s}, [%[dout0]], #16 \n" + "st1 {v16.4s}, [%[dout1]], #16 \n" + "st1 {v17.4s}, [%[dout2]], #16 \n" + "st1 {v18.4s}, [%[dout3]], #16 \n" + + "subs %w[cnt], %w[cnt], #1 \n" + "add x0, x0, #1 \n" + "bne 1b \n" + + : [dout0] "+r"(dout_ptr0), + [dout1] "+r"(dout_ptr1), + [dout2] "+r"(dout_ptr2), + [dout3] "+r"(dout_ptr3), + [cnt] "+r"(mid_loop), + [dinl] "+r"(dinl_ptr) + : [w0] "w"(w0), + [w1] "w"(w1), + [w2] "w"(w2), + [w3] "w"(w3), + [w4] "w"(w4), + [w5] "w"(w5), + [w6] "w"(w6), + [bias] "r"(vbias) + : "cc", + "memory", + "x0", + "x1", + "x2", + "x3", + "x4", + "x5", + "x6", + "x7", + "x8", + "x9", + "v7", + "v8", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15", + "v16", + "v17", + "v18", + "v19"); + } + dinl[0] += 4 * mid_cnt; + dinl[1] += 4 * mid_cnt; + dinl[2] += 4 * mid_cnt; + dinl[3] += 4 * mid_cnt; + dinl[4] += 4 * mid_cnt; + dinl[5] += 4 * mid_cnt; + dinl[6] += 4 * mid_cnt; + dinl[7] += 4 * mid_cnt; + //! deal with mid remain + for (int i = 0; i < mid_remain; ++i) { + compute_one_out_without_extract(dinl[0], + dinl[1], + dinl[2], + dinl[3], + dinl[4], + dinl[5], + dinl[6], + dinl[7], + dout_ptr0, + dout_ptr1, + dout_ptr2, + dout_ptr3, + w0, + w1, + w2, + w3, + w4, + w5, + w6, + vbias); + dinl[0]++; + dinl[1]++; + dinl[2]++; + dinl[3]++; + dinl[4]++; + dinl[5]++; + dinl[6]++; + dinl[7]++; + + dout_ptr0++; + dout_ptr1++; + dout_ptr2++; + dout_ptr3++; + } + //! 
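
// Reference for the interior path just above (the 4x4 asm tile plus the
// compute_one_out_without_extract remainder): each interior output is a
// plain 5x5 dot product over input rows in[r]..in[r+4]. Scalar sketch
// with ad hoc names (illustrative only):
static void dw5x5_interior_ref(const float* in[8],  // 8 consecutive input rows
                               const float w[25], float bias,
                               float* out[4], int x) {  // 4x4 tile at column x
  for (int r = 0; r < 4; ++r) {
    for (int c = 0; c < 4; ++c) {
      float acc = bias;
      for (int kh = 0; kh < 5; ++kh)
        for (int kw = 0; kw < 5; ++kw)
          acc += w[kh * 5 + kw] * in[r + kh][x + c + kw];
      out[r][c] = acc;
    }
  }
}
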
deal with w_out pad_new column post + switch (pad_new) { + case 4: + compute_four_out_extract_post(dinl[0], + dinl[1], + dinl[2], + dinl[3], + dinl[4], + dinl[5], + dinl[6], + dinl[7], + dout_ptr0, + dout_ptr1, + dout_ptr2, + dout_ptr3, + w0, + w1, + w2, + w3, + w4, + vbias); + dout_ptr0 += 4; + dout_ptr1 += 4; + dout_ptr2 += 4; + dout_ptr3 += 4; + break; + case 3: + compute_three_out_extract_post(dinl[0], + dinl[1], + dinl[2], + dinl[3], + dinl[4], + dinl[5], + dinl[6], + dinl[7], + dout_ptr0, + dout_ptr1, + dout_ptr2, + dout_ptr3, + w0, + w1, + w2, + w3, + w4, + vbias); + dout_ptr0 += 3; + dout_ptr1 += 3; + dout_ptr2 += 3; + dout_ptr3 += 3; + break; + case 2: + compute_two_out_extract_post(dinl[0], + dinl[1], + dinl[2], + dinl[3], + dinl[4], + dinl[5], + dinl[6], + dinl[7], + dout_ptr0, + dout_ptr1, + dout_ptr2, + dout_ptr3, + w0, + w1, + w2, + w3, + w4, + vbias); + dout_ptr0 += 2; + dout_ptr1 += 2; + dout_ptr2 += 2; + dout_ptr3 += 2; + break; + case 1: + compute_one_out_extract_post(dinl[0], + dinl[1], + dinl[2], + dinl[3], + dinl[4], + dinl[5], + dinl[6], + dinl[7], + dout_ptr0, + dout_ptr1, + dout_ptr2, + dout_ptr3, + w0, + w1, + w2, + w3, + w4, + vbias); + dout_ptr0 += 1; + dout_ptr1 += 1; + dout_ptr2 += 1; + dout_ptr3 += 1; + break; + } + + if (flag_bias) { + //! deal with w_out pad_0 column post with bias + memcpy(dout_ptr0, dout0, pad_0 * sizeof(float)); + memcpy(dout_ptr1, dout1, pad_0 * sizeof(float)); + memcpy(dout_ptr2, dout2, pad_0 * sizeof(float)); + memcpy(dout_ptr3, dout3, pad_0 * sizeof(float)); + } else { + //! deal with w_out pad_0 column post without bias + memset(dout_ptr0, 0x00, pad_0 * sizeof(float)); + memset(dout_ptr1, 0x00, pad_0 * sizeof(float)); + memset(dout_ptr2, 0x00, pad_0 * sizeof(float)); + memset(dout_ptr3, 0x00, pad_0 * sizeof(float)); + } + + din_list[0] = din_list[4]; + din_list[1] = din_list[5]; + din_list[2] = din_list[6]; + din_list[3] = din_list[7]; + din_list[4] = din_list[3] + w_in; + din_list[5] = din_list[4] + w_in; + din_list[6] = din_list[5] + w_in; + din_list[7] = din_list[6] + w_in; + + dout0 = dout3 + w_out; + dout1 = dout0 + w_out; + dout2 = dout1 + w_out; + dout3 = dout2 + w_out; + } + float* dout_pad_end = dout_ch + h_out_new * w_out; + if (flag_bias) { + //! deal with h_out pad_0 line with bias + memcpy(reinterpret_cast(dout_pad_end), + dout_ch - pad_0 * w_out, + pad_0 * w_out * sizeof(float)); + } else { + //! deal with h_out pad_0 line without bias + memset(reinterpret_cast(dout_pad_end), + 0x00, + pad_0 * w_out * sizeof(float)); + } + } + } +} + +void conv_depthwise_5x5s1_relu_impl(const float* din, + float* dout, + int num, + int ch_out, + int h_out, + int w_out, + int ch_in, + int h_in, + int w_in, + const float* weights, + const float* bias, + int pad, + bool flag_bias, + bool flag_relu, + ARMContext* ctx) { + float* zero_ptr = ctx->workspace_data(); + memset(zero_ptr, 0, w_in * sizeof(float)); + float* write_ptr = zero_ptr + w_in; + int pad_new = pad > 4 ? 
4 : pad; + int pad_0 = pad - pad_new; + int h_out_new = h_out - 2 * pad_0; + int mid_out = w_out - 2 * pad; + int mid_cnt = mid_out >> 2; + int mid_remain = mid_out - (mid_cnt << 2); + int pad_cnt = pad_0 >> 2; + int pad_remain = pad_0 - (pad_cnt << 2); + int bias_cnt = (w_out * pad_0) >> 2; + int bias_remain = (w_out * pad_0) - (bias_cnt << 2); + int in_spatial_size = w_in * h_in; + int out_spatial_size = w_out * h_out; + int weights_saptial_size = 25; + + for (int n = 0; n < num; ++n) { + const float* din_batch = din + n * in_spatial_size * ch_in; + float* dout_batch = dout + n * out_spatial_size * ch_out; +#pragma omp parallel for + for (int c = 0; c < ch_in; ++c) { + const float* din_ch = din_batch + c * in_spatial_size; + float* dout_ch = dout_batch + c * out_spatial_size; + float bias_c = flag_bias ? bias[c] : 0.f; + float bias_relu = bias_c > 0.f ? bias_c : 0.f; + float vbias[4] = {bias_c, bias_c, bias_c, bias_c}; + float32x4_t vbias_c = vdupq_n_f32(bias_relu); + if (flag_bias) { + //! deal with h_out pad_0 line with bias + for (int i = 0; i < bias_cnt; ++i) { + vst1q_f32(dout_ch, vbias_c); + dout_ch += 4; + } + for (int i = 0; i < bias_remain; ++i) { + *dout_ch++ = bias_relu; + } + } else { + //! deal with h_out pad_0 line without bias + for (int i = 0; i < pad_0; ++i) { + memset(dout_ch, 0x00, w_out * sizeof(float)); + dout_ch += w_out; + } + } + const float* din_list[8]; + const float* dinl[8]; + //! set din ptr with zero buffer + for (int i = 0; i < pad_new; ++i) { + din_list[i] = zero_ptr; + } + //! set din ptr with input data + for (int i = pad_new; i < 8; ++i) { + din_list[i] = din_ch; + din_ch += w_in; + } + + //! every h loop, deal with 4 line output + float* dout0 = dout_ch; + float* dout1 = dout0 + w_out; + float* dout2 = dout1 + w_out; + float* dout3 = dout2 + w_out; + + //! load weights to neon register + const float* weights_c = weights + c * weights_saptial_size; + + float32x4_t w5; + float32x4_t w6; + float32x4_t w0 = vld1q_f32(weights_c); + float32x4_t w1 = vld1q_f32(weights_c + 5); + float32x4_t w2 = vld1q_f32(weights_c + 10); + float32x4_t w3 = vld1q_f32(weights_c + 15); + float32x4_t w4 = vld1q_f32(weights_c + 20); + w5 = vsetq_lane_f32(weights_c[4], w5, 0); + w5 = vsetq_lane_f32(weights_c[9], w5, 1); + w5 = vsetq_lane_f32(weights_c[14], w5, 2); + w5 = vsetq_lane_f32(weights_c[19], w5, 3); + w6 = vsetq_lane_f32(weights_c[24], w6, 0); + + //! h loop + for (int h = 0; h < h_out_new; h += 4) { + //! (h - pad_new) + 7 > h_in - 1 + if (h + 8 - pad_new > h_in) { + switch (h + 8 - pad_new - h_in) { + case 7: + din_list[1] = zero_ptr; + case 6: + din_list[2] = zero_ptr; + case 5: + din_list[3] = zero_ptr; + case 4: + din_list[4] = zero_ptr; + case 3: + din_list[5] = zero_ptr; + case 2: + din_list[6] = zero_ptr; + case 1: + din_list[7] = zero_ptr; + default: + break; + } + } + if (h + 4 > h_out_new) { + switch (h + 4 - h_out_new) { + case 3: + dout1 = write_ptr; + case 2: + dout2 = write_ptr; + case 1: + dout3 = write_ptr; + default: + break; + } + } + + //! every h loop, deal with 8 line input + dinl[0] = din_list[0]; + dinl[1] = din_list[1]; + dinl[2] = din_list[2]; + dinl[3] = din_list[3]; + dinl[4] = din_list[4]; + dinl[5] = din_list[5]; + dinl[6] = din_list[6]; + dinl[7] = din_list[7]; + + const float* weights_ptr = weights_c; + float* dout_ptr0 = dout0; + float* dout_ptr1 = dout1; + float* dout_ptr2 = dout2; + float* dout_ptr3 = dout3; + if (flag_bias) { + //! 
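
// Why this variant pre-clamps the border fill: on the pad_0 border the
// convolution sum is exactly zero, so relu(0 + bias) reduces to
// max(bias, 0), which bias_relu above precomputes once per channel.
// The identity, spelled out (illustrative sketch):
#include <algorithm>
static inline float border_value(float bias_c, bool flag_relu) {
  // equals relu(conv_sum + bias_c) whenever conv_sum == 0
  return flag_relu ? std::max(bias_c, 0.f) : bias_c;
}
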
deal with w_out pad_0 column pre with bias + for (int i = 0; i < pad_cnt; i++) { + vst1q_f32(dout_ptr0, vbias_c); + vst1q_f32(dout_ptr1, vbias_c); + vst1q_f32(dout_ptr2, vbias_c); + vst1q_f32(dout_ptr3, vbias_c); + dout_ptr0 += 4; + dout_ptr1 += 4; + dout_ptr2 += 4; + dout_ptr3 += 4; + } + for (int i = 0; i < pad_remain; ++i) { + *dout_ptr0++ = bias_relu; + *dout_ptr1++ = bias_relu; + *dout_ptr2++ = bias_relu; + *dout_ptr3++ = bias_relu; + } + } else { + //! deal with w_out pad_0 column pre without bias + memset(dout_ptr0, 0x00, pad_0 * sizeof(float)); + memset(dout_ptr1, 0x00, pad_0 * sizeof(float)); + memset(dout_ptr2, 0x00, pad_0 * sizeof(float)); + memset(dout_ptr3, 0x00, pad_0 * sizeof(float)); + dout_ptr0 += pad_0; + dout_ptr1 += pad_0; + dout_ptr2 += pad_0; + dout_ptr3 += pad_0; + } + //! deal with w_out pad_new column pre + switch (pad_new) { + case 4: + compute_four_out_extract_pre_relu(dinl[0], + dinl[1], + dinl[2], + dinl[3], + dinl[4], + dinl[5], + dinl[6], + dinl[7], + dout_ptr0, + dout_ptr1, + dout_ptr2, + dout_ptr3, + weights_ptr, + vbias); + dout_ptr0 += 4; + dout_ptr1 += 4; + dout_ptr2 += 4; + dout_ptr3 += 4; + break; + case 3: + compute_three_out_extract_pre_relu(dinl[0], + dinl[1], + dinl[2], + dinl[3], + dinl[4], + dinl[5], + dinl[6], + dinl[7], + dout_ptr0, + dout_ptr1, + dout_ptr2, + dout_ptr3, + weights_ptr, + vbias); + dout_ptr0 += 3; + dout_ptr1 += 3; + dout_ptr2 += 3; + dout_ptr3 += 3; + break; + case 2: + compute_two_out_extract_pre_relu(dinl[0], + dinl[1], + dinl[2], + dinl[3], + dinl[4], + dinl[5], + dinl[6], + dinl[7], + dout_ptr0, + dout_ptr1, + dout_ptr2, + dout_ptr3, + weights_ptr, + vbias); + dout_ptr0 += 2; + dout_ptr1 += 2; + dout_ptr2 += 2; + dout_ptr3 += 2; + break; + case 1: + compute_one_out_extract_pre_relu(dinl[0], + dinl[1], + dinl[2], + dinl[3], + dinl[4], + dinl[5], + dinl[6], + dinl[7], + dout_ptr0, + dout_ptr1, + dout_ptr2, + dout_ptr3, + weights_ptr, + vbias); + dout_ptr0 += 1; + dout_ptr1 += 1; + dout_ptr2 += 1; + dout_ptr3 += 1; + break; + } + //! mid loop + if (mid_cnt > 0) { + void* dinl_ptr = reinterpret_cast(dinl); + int mid_loop = mid_cnt; + asm volatile( + //! din: v7-v14 + //! dout: v15-v18 + "mov x0, #0 \n" + "mov x1, #4 \n" + "movi v31.4s, #0 \n" + "ldp x2, x3, [%[dinl]], #16 \n" + "ldp x4, x5, [%[dinl]], #16 \n" + "ldp x6, x7, [%[dinl]], #16 \n" + "ldp x8, x9, [%[dinl]], #16 \n" + + "ld1 {v7.4s} , [x2], x1 \n" + "ld1 {v8.4s} , [x3], x1 \n" + "ld1 {v9.4s} , [x4], x1 \n" + "ld1 {v10.4s}, [x5], x1 \n" + "ld1 {v11.4s}, [x6], x1 \n" + "ld1 {v12.4s}, [x7], x1 \n" + "ld1 {v13.4s}, [x8], x1 \n" + "ld1 {v14.4s}, [x9], x1 \n" + + //! load bias + "ld1 {v19.4s}, [%[bias]] \n" + + "1: \n" + //! add bias to output + "mov v15.16b, v19.16b \n" + "mov v16.16b, v19.16b \n" + "mov v17.16b, v19.16b \n" + "mov v18.16b, v19.16b \n" + + //! 
loop cnt is even, prefetch 64 Byte to l1 cache + "cmp x0, #1 \n" + "bne 2f \n" + "mov x0, #0 \n" + "prfm pldl1keep, [x2] \n" + "prfm pldl1keep, [x3] \n" + "prfm pldl1keep, [x4] \n" + "prfm pldl1keep, [x5] \n" + "prfm pldl1keep, [x6] \n" + "prfm pldl1keep, [x7] \n" + "prfm pldl1keep, [x8] \n" + "prfm pldl1keep, [x9] \n" + + "2: \n" + // weights col 0 + "fmla v15.4s, v7.4s , %[w0].s[0] \n" + "fmla v16.4s, v8.4s , %[w0].s[0] \n" + "fmla v17.4s, v9.4s , %[w0].s[0] \n" + "fmla v18.4s, v10.4s, %[w0].s[0] \n" + + "fmla v15.4s, v8.4s , %[w1].s[0] \n" + "fmla v16.4s, v9.4s , %[w1].s[0] \n" + "fmla v17.4s, v10.4s, %[w1].s[0] \n" + "fmla v18.4s, v11.4s, %[w1].s[0] \n" + + "ld1 {v7.4s}, [x2], x1 \n" + "ld1 {v8.4s}, [x3], x1 \n" + + "fmla v15.4s, v9.4s , %[w2].s[0] \n" + "fmla v16.4s, v10.4s, %[w2].s[0] \n" + "fmla v17.4s, v11.4s, %[w2].s[0] \n" + "fmla v18.4s, v12.4s, %[w2].s[0] \n" + + "fmla v15.4s, v10.4s, %[w3].s[0] \n" + "fmla v16.4s, v11.4s, %[w3].s[0] \n" + "fmla v17.4s, v12.4s, %[w3].s[0] \n" + "fmla v18.4s, v13.4s, %[w3].s[0] \n" + + "ld1 {v9.4s} , [x4], x1 \n" + "ld1 {v10.4s}, [x5], x1 \n" + + "fmla v15.4s, v11.4s, %[w4].s[0] \n" + "fmla v16.4s, v12.4s, %[w4].s[0] \n" + "fmla v17.4s, v13.4s, %[w4].s[0] \n" + "fmla v18.4s, v14.4s, %[w4].s[0] \n" + + "ld1 {v11.4s}, [x6], x1 \n" + "ld1 {v12.4s}, [x7], x1 \n" + + // weights col 1 + "fmla v15.4s, v7.4s , %[w0].s[1] \n" + "fmla v16.4s, v8.4s , %[w0].s[1] \n" + "fmla v17.4s, v9.4s , %[w0].s[1] \n" + "fmla v18.4s, v10.4s, %[w0].s[1] \n" + + "ld1 {v13.4s}, [x8], x1 \n" + "ld1 {v14.4s}, [x9], x1 \n" + + "fmla v15.4s, v8.4s , %[w1].s[1] \n" + "fmla v16.4s, v9.4s , %[w1].s[1] \n" + "fmla v17.4s, v10.4s, %[w1].s[1] \n" + "fmla v18.4s, v11.4s, %[w1].s[1] \n" + + "ld1 {v7.4s}, [x2], x1 \n" + "ld1 {v8.4s}, [x3], x1 \n" + + "fmla v15.4s, v9.4s , %[w2].s[1] \n" + "fmla v16.4s, v10.4s, %[w2].s[1] \n" + "fmla v17.4s, v11.4s, %[w2].s[1] \n" + "fmla v18.4s, v12.4s, %[w2].s[1] \n" + + "fmla v15.4s, v10.4s, %[w3].s[1] \n" + "fmla v16.4s, v11.4s, %[w3].s[1] \n" + "fmla v17.4s, v12.4s, %[w3].s[1] \n" + "fmla v18.4s, v13.4s, %[w3].s[1] \n" + + "ld1 {v9.4s} , [x4], x1 \n" + "ld1 {v10.4s}, [x5], x1 \n" + + "fmla v15.4s, v11.4s, %[w4].s[1] \n" + "fmla v16.4s, v12.4s, %[w4].s[1] \n" + "fmla v17.4s, v13.4s, %[w4].s[1] \n" + "fmla v18.4s, v14.4s, %[w4].s[1] \n" + + "ld1 {v11.4s}, [x6], x1 \n" + "ld1 {v12.4s}, [x7], x1 \n" + + // weights col 2 + "fmla v15.4s, v7.4s , %[w0].s[2] \n" + "fmla v16.4s, v8.4s , %[w0].s[2] \n" + "fmla v17.4s, v9.4s , %[w0].s[2] \n" + "fmla v18.4s, v10.4s, %[w0].s[2] \n" + + "ld1 {v13.4s}, [x8], x1 \n" + "ld1 {v14.4s}, [x9], x1 \n" + + "fmla v15.4s, v8.4s , %[w1].s[2] \n" + "fmla v16.4s, v9.4s , %[w1].s[2] \n" + "fmla v17.4s, v10.4s, %[w1].s[2] \n" + "fmla v18.4s, v11.4s, %[w1].s[2] \n" + + "ld1 {v7.4s}, [x2], x1 \n" + "ld1 {v8.4s}, [x3], x1 \n" + + "fmla v15.4s, v9.4s , %[w2].s[2] \n" + "fmla v16.4s, v10.4s, %[w2].s[2] \n" + "fmla v17.4s, v11.4s, %[w2].s[2] \n" + "fmla v18.4s, v12.4s, %[w2].s[2] \n" + + "fmla v15.4s, v10.4s, %[w3].s[2] \n" + "fmla v16.4s, v11.4s, %[w3].s[2] \n" + "fmla v17.4s, v12.4s, %[w3].s[2] \n" + "fmla v18.4s, v13.4s, %[w3].s[2] \n" + + "ld1 {v9.4s} , [x4], x1 \n" + "ld1 {v10.4s}, [x5], x1 \n" + + "fmla v15.4s, v11.4s, %[w4].s[2] \n" + "fmla v16.4s, v12.4s, %[w4].s[2] \n" + "fmla v17.4s, v13.4s, %[w4].s[2] \n" + "fmla v18.4s, v14.4s, %[w4].s[2] \n" + + "ld1 {v11.4s}, [x6], x1 \n" + "ld1 {v12.4s}, [x7], x1 \n" + + // weights col 3 + "fmla v15.4s, v7.4s , %[w0].s[3] \n" + "fmla v16.4s, v8.4s , %[w0].s[3] \n" + "fmla v17.4s, v9.4s 
, %[w0].s[3] \n" + "fmla v18.4s, v10.4s, %[w0].s[3] \n" + + "ld1 {v13.4s}, [x8], x1 \n" + "ld1 {v14.4s}, [x9], x1 \n" + + "fmla v15.4s, v8.4s , %[w1].s[3] \n" + "fmla v16.4s, v9.4s , %[w1].s[3] \n" + "fmla v17.4s, v10.4s, %[w1].s[3] \n" + "fmla v18.4s, v11.4s, %[w1].s[3] \n" + + "ld1 {v7.4s}, [x2], x1 \n" + "ld1 {v8.4s}, [x3], x1 \n" + + "fmla v15.4s, v9.4s , %[w2].s[3] \n" + "fmla v16.4s, v10.4s, %[w2].s[3] \n" + "fmla v17.4s, v11.4s, %[w2].s[3] \n" + "fmla v18.4s, v12.4s, %[w2].s[3] \n" + + "fmla v15.4s, v10.4s, %[w3].s[3] \n" + "fmla v16.4s, v11.4s, %[w3].s[3] \n" + "fmla v17.4s, v12.4s, %[w3].s[3] \n" + "fmla v18.4s, v13.4s, %[w3].s[3] \n" + + "ld1 {v9.4s} , [x4], x1 \n" + "ld1 {v10.4s}, [x5], x1 \n" + + "fmla v15.4s, v11.4s, %[w4].s[3] \n" + "fmla v16.4s, v12.4s, %[w4].s[3] \n" + "fmla v17.4s, v13.4s, %[w4].s[3] \n" + "fmla v18.4s, v14.4s, %[w4].s[3] \n" + + "ld1 {v11.4s}, [x6], x1 \n" + "ld1 {v12.4s}, [x7], x1 \n" + + // weights col 4 + "fmla v15.4s, v7.4s, %[w5].s[0] \n" + "fmla v16.4s, v8.4s, %[w5].s[0] \n" + "fmla v17.4s, v9.4s, %[w5].s[0] \n" + "fmla v18.4s, v10.4s, %[w5].s[0] \n" + + "ld1 {v13.4s}, [x8], x1 \n" + "ld1 {v14.4s}, [x9], x1 \n" + + "fmla v15.4s, v8.4s, %[w5].s[1] \n" + "fmla v16.4s, v9.4s, %[w5].s[1] \n" + "fmla v17.4s, v10.4s, %[w5].s[1] \n" + "fmla v18.4s, v11.4s, %[w5].s[1] \n" + + "fmla v15.4s, v9.4s , %[w5].s[2] \n" + "fmla v16.4s, v10.4s, %[w5].s[2] \n" + "fmla v17.4s, v11.4s, %[w5].s[2] \n" + "fmla v18.4s, v12.4s, %[w5].s[2] \n" + + "fmla v15.4s, v10.4s, %[w5].s[3] \n" + "fmla v16.4s, v11.4s, %[w5].s[3] \n" + "fmla v17.4s, v12.4s, %[w5].s[3] \n" + "fmla v18.4s, v13.4s, %[w5].s[3] \n" + + "fmla v15.4s, v11.4s, %[w6].s[0] \n" + "fmla v16.4s, v12.4s, %[w6].s[0] \n" + "fmla v17.4s, v13.4s, %[w6].s[0] \n" + "fmla v18.4s, v14.4s, %[w6].s[0] \n" + + "fmax v15.4s, v15.4s, v31.4s \n" + "fmax v16.4s, v16.4s, v31.4s \n" + "fmax v17.4s, v17.4s, v31.4s \n" + "fmax v18.4s, v18.4s, v31.4s \n" + + "st1 {v15.4s}, [%[dout0]], #16 \n" + "st1 {v16.4s}, [%[dout1]], #16 \n" + "st1 {v17.4s}, [%[dout2]], #16 \n" + "st1 {v18.4s}, [%[dout3]], #16 \n" + + "subs %w[cnt], %w[cnt], #1 \n" + "add x0, x0, #1 \n" + "bne 1b \n" + + : [dout0] "+r"(dout_ptr0), + [dout1] "+r"(dout_ptr1), + [dout2] "+r"(dout_ptr2), + [dout3] "+r"(dout_ptr3), + [cnt] "+r"(mid_loop), + [dinl] "+r"(dinl_ptr) + : [w0] "w"(w0), + [w1] "w"(w1), + [w2] "w"(w2), + [w3] "w"(w3), + [w4] "w"(w4), + [w5] "w"(w5), + [w6] "w"(w6), + [bias] "r"(vbias) + : "cc", + "memory", + "x0", + "x1", + "x2", + "x3", + "x4", + "x5", + "x6", + "x7", + "x8", + "x9", + "v7", + "v8", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15", + "v16", + "v17", + "v18", + "v19", + "v31"); + } + dinl[0] += 4 * mid_cnt; + dinl[1] += 4 * mid_cnt; + dinl[2] += 4 * mid_cnt; + dinl[3] += 4 * mid_cnt; + dinl[4] += 4 * mid_cnt; + dinl[5] += 4 * mid_cnt; + dinl[6] += 4 * mid_cnt; + dinl[7] += 4 * mid_cnt; + //! deal with mid remain + for (int i = 0; i < mid_remain; ++i) { + compute_one_out_without_extract_relu(dinl[0], + dinl[1], + dinl[2], + dinl[3], + dinl[4], + dinl[5], + dinl[6], + dinl[7], + dout_ptr0, + dout_ptr1, + dout_ptr2, + dout_ptr3, + w0, + w1, + w2, + w3, + w4, + w5, + w6, + vbias); + dinl[0]++; + dinl[1]++; + dinl[2]++; + dinl[3]++; + dinl[4]++; + dinl[5]++; + dinl[6]++; + dinl[7]++; + + dout_ptr0++; + dout_ptr1++; + dout_ptr2++; + dout_ptr3++; + } + //! 
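
// The x0 counter in the asm above toggles between 0 and 1 so the eight
// prfm pldl1keep instructions fire only on alternate iterations (the
// source comments this as prefetching one 64-byte line to L1). The same
// pattern with the GCC/Clang builtin (portable illustrative sketch):
static void prefetch_rows_every_other(const float* rows[8], int iter) {
  if (iter & 1) {
    for (int i = 0; i < 8; ++i)
      __builtin_prefetch(rows[i], /*rw=*/0, /*locality=*/3);  // ~ prfm pldl1keep
  }
}
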
deal with w_out pad_new column post + switch (pad_new) { + case 4: + compute_four_out_extract_post_relu(dinl[0], + dinl[1], + dinl[2], + dinl[3], + dinl[4], + dinl[5], + dinl[6], + dinl[7], + dout_ptr0, + dout_ptr1, + dout_ptr2, + dout_ptr3, + w0, + w1, + w2, + w3, + w4, + vbias); + dout_ptr0 += 4; + dout_ptr1 += 4; + dout_ptr2 += 4; + dout_ptr3 += 4; + break; + case 3: + compute_three_out_extract_post_relu(dinl[0], + dinl[1], + dinl[2], + dinl[3], + dinl[4], + dinl[5], + dinl[6], + dinl[7], + dout_ptr0, + dout_ptr1, + dout_ptr2, + dout_ptr3, + w0, + w1, + w2, + w3, + w4, + vbias); + dout_ptr0 += 3; + dout_ptr1 += 3; + dout_ptr2 += 3; + dout_ptr3 += 3; + break; + case 2: + compute_two_out_extract_post_relu(dinl[0], + dinl[1], + dinl[2], + dinl[3], + dinl[4], + dinl[5], + dinl[6], + dinl[7], + dout_ptr0, + dout_ptr1, + dout_ptr2, + dout_ptr3, + w0, + w1, + w2, + w3, + w4, + vbias); + dout_ptr0 += 2; + dout_ptr1 += 2; + dout_ptr2 += 2; + dout_ptr3 += 2; + break; + case 1: + compute_one_out_extract_post_relu(dinl[0], + dinl[1], + dinl[2], + dinl[3], + dinl[4], + dinl[5], + dinl[6], + dinl[7], + dout_ptr0, + dout_ptr1, + dout_ptr2, + dout_ptr3, + w0, + w1, + w2, + w3, + w4, + vbias); + dout_ptr0 += 1; + dout_ptr1 += 1; + dout_ptr2 += 1; + dout_ptr3 += 1; + break; + } + + if (flag_bias) { + //! deal with w_out pad_0 column post with bias + memcpy(dout_ptr0, dout0, pad_0 * sizeof(float)); + memcpy(dout_ptr1, dout1, pad_0 * sizeof(float)); + memcpy(dout_ptr2, dout2, pad_0 * sizeof(float)); + memcpy(dout_ptr3, dout3, pad_0 * sizeof(float)); + } else { + //! deal with w_out pad_0 column post without bias + memset(dout_ptr0, 0x00, pad_0 * sizeof(float)); + memset(dout_ptr1, 0x00, pad_0 * sizeof(float)); + memset(dout_ptr2, 0x00, pad_0 * sizeof(float)); + memset(dout_ptr3, 0x00, pad_0 * sizeof(float)); + } + + din_list[0] = din_list[4]; + din_list[1] = din_list[5]; + din_list[2] = din_list[6]; + din_list[3] = din_list[7]; + din_list[4] = din_list[3] + w_in; + din_list[5] = din_list[4] + w_in; + din_list[6] = din_list[5] + w_in; + din_list[7] = din_list[6] + w_in; + + dout0 = dout3 + w_out; + dout1 = dout0 + w_out; + dout2 = dout1 + w_out; + dout3 = dout2 + w_out; + } + float* dout_pad_end = dout_ch + h_out_new * w_out; + if (flag_bias) { + //! deal with h_out pad_0 line with bias + memcpy(reinterpret_cast(dout_pad_end), + dout_ch - pad_0 * w_out, + pad_0 * w_out * sizeof(float)); + } else { + //! deal with h_out pad_0 line without bias + memset(reinterpret_cast(dout_pad_end), + 0x00, + pad_0 * w_out * sizeof(float)); + } + } + } +} + +void conv_depthwise_5x5s1_small_impl(const float* din, + float* dout, + int num, + int ch_out, + int h_out, + int w_out, + int ch_in, + int h_in, + int w_in, + const float* weights, + const float* bias, + int pad, + bool flag_bias, + bool flag_relu, + ARMContext* ctx) { + int pad_new = pad > 4 ? 
4 : pad; + int pad_0 = pad - pad_new; + int h_in_new = h_in + 2 * pad_new; + int w_in_new = w_in + 2 * pad_new; + int h_out_new = h_out - 2 * pad_0; + int w_out_new = w_out - 2 * pad_0; + float zero_ptr[w_in_new + w_out]; + memset(zero_ptr, 0, w_in_new * sizeof(float)); + float* write_ptr = zero_ptr + w_in_new; + int pad_cnt = pad_0 >> 2; + int pad_remain = pad_0 - (pad_cnt << 2); + int bias_cnt = (w_out * pad_0) >> 2; + int bias_remain = (w_out * pad_0) - (bias_cnt << 2); + int in_spatial_size = w_in_new * h_in_new; + int out_spatial_size = w_out * h_out; + int weights_saptial_size = 25; + + float* din_new = prepad_input(din, num, ch_in, h_in, w_in, pad_new); + for (int n = 0; n < num; ++n) { + const float* din_batch = din_new + n * in_spatial_size * ch_in; + float* dout_batch = dout + n * out_spatial_size * ch_out; +#pragma omp parallel for + for (int c = 0; c < ch_in; ++c) { + const float* din_ch = din_batch + c * in_spatial_size; + float* dout_ch = dout_batch + c * out_spatial_size; + float bias_c = flag_bias ? bias[c] : 0.f; + float vbias[4] = {bias_c, bias_c, bias_c, bias_c}; + float32x4_t vbias_c = vdupq_n_f32(bias_c); + if (flag_bias) { + //! deal with h_out pad_0 line with bias + for (int i = 0; i < bias_cnt; ++i) { + vst1q_f32(dout_ch, vbias_c); + dout_ch += 4; + } + for (int i = 0; i < bias_remain; ++i) { + *dout_ch++ = bias_c; + } + } else { + //! deal with h_out pad_0 line without bias + for (int i = 0; i < pad_0; ++i) { + memset(dout_ch, 0x00, w_out * sizeof(float)); + dout_ch += w_out; + } + } + //! every h loop, deal with 8 line input + const float* din0 = din_ch; + const float* din1 = din0 + w_in_new; + const float* din2 = din1 + w_in_new; + const float* din3 = din2 + w_in_new; + const float* din4 = din3 + w_in_new; + const float* din5 = din4 + w_in_new; + const float* din6 = din5 + w_in_new; + const float* din7 = din6 + w_in_new; + //! every h loop, deal with 4 line output + float* dout0 = dout_ch; + float* dout1 = dout0 + w_out; + float* dout2 = dout1 + w_out; + float* dout3 = dout2 + w_out; + + //! load weights to neon register + const float* weights_c = weights + c * weights_saptial_size; + + float32x4_t w5; + float32x4_t w6; + float32x4_t w0 = vld1q_f32(weights_c); + float32x4_t w1 = vld1q_f32(weights_c + 5); + float32x4_t w2 = vld1q_f32(weights_c + 10); + float32x4_t w3 = vld1q_f32(weights_c + 15); + float32x4_t w4 = vld1q_f32(weights_c + 20); + w5 = vsetq_lane_f32(weights_c[4], w5, 0); + w5 = vsetq_lane_f32(weights_c[9], w5, 1); + w5 = vsetq_lane_f32(weights_c[14], w5, 2); + w5 = vsetq_lane_f32(weights_c[19], w5, 3); + w6 = vsetq_lane_f32(weights_c[24], w6, 0); + //! h loop + for (int h = 0; h < h_out_new; h += 4) { + //! 
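
// prepad_input (defined elsewhere in this file) materializes a
// zero-padded copy of the input so this "small" path needs no
// edge-extract kernels at all. Its body is not shown here; a
// hypothetical helper of the same shape might look like this
// (assumes NCHW layout and a heap buffer freed by the caller, matching
// the free(din_new) at the end of this function):
#include <cstdlib>
#include <cstring>
static float* prepad_copy_ref(const float* din, int num, int ch,
                              int h, int w, int pad) {
  int h_new = h + 2 * pad;
  int w_new = w + 2 * pad;
  // calloc zero-fills, which provides the padding border
  float* out = static_cast<float*>(
      calloc(static_cast<std::size_t>(num) * ch * h_new * w_new,
             sizeof(float)));
  for (int nc = 0; nc < num * ch; ++nc)
    for (int r = 0; r < h; ++r)
      memcpy(out + (static_cast<std::size_t>(nc) * h_new + r + pad) * w_new + pad,
             din + (static_cast<std::size_t>(nc) * h + r) * w,
             w * sizeof(float));
  return out;
}
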
(h - pad_new) + 7 > h_in - 1 + if (h + 8 > h_in_new) { + switch (h + 8 - h_in_new) { + case 7: + din1 = zero_ptr; + case 6: + din2 = zero_ptr; + case 5: + din3 = zero_ptr; + case 4: + din4 = zero_ptr; + case 3: + din5 = zero_ptr; + case 2: + din6 = zero_ptr; + case 1: + din7 = zero_ptr; + default: + break; + } + } + if (h + 4 > h_out_new) { + switch (h + 4 - h_out_new) { + case 3: + dout1 = write_ptr; + case 2: + dout2 = write_ptr; + case 1: + dout3 = write_ptr; + default: + break; + } + } + const float* din_ptr0 = din0; + const float* din_ptr1 = din1; + const float* din_ptr2 = din2; + const float* din_ptr3 = din3; + const float* din_ptr4 = din4; + const float* din_ptr5 = din5; + const float* din_ptr6 = din6; + const float* din_ptr7 = din7; + + const float* weights_ptr = weights_c; + float* dout_ptr0 = dout0; + float* dout_ptr1 = dout1; + float* dout_ptr2 = dout2; + float* dout_ptr3 = dout3; + + if (flag_bias) { + //! deal with w_out pad_0 column pre with bias + for (int i = 0; i < pad_cnt; i++) { + vst1q_f32(dout_ptr0, vbias_c); + vst1q_f32(dout_ptr1, vbias_c); + vst1q_f32(dout_ptr2, vbias_c); + vst1q_f32(dout_ptr3, vbias_c); + dout_ptr0 += 4; + dout_ptr1 += 4; + dout_ptr2 += 4; + dout_ptr3 += 4; + } + for (int i = 0; i < pad_remain; ++i) { + *dout_ptr0++ = bias_c; + *dout_ptr1++ = bias_c; + *dout_ptr2++ = bias_c; + *dout_ptr3++ = bias_c; + } + } else { + //! deal with w_out pad_0 column pre without bias + memset(dout_ptr0, 0x00, pad_0 * sizeof(float)); + memset(dout_ptr1, 0x00, pad_0 * sizeof(float)); + memset(dout_ptr2, 0x00, pad_0 * sizeof(float)); + memset(dout_ptr3, 0x00, pad_0 * sizeof(float)); + dout_ptr0 += pad_0; + dout_ptr1 += pad_0; + dout_ptr2 += pad_0; + dout_ptr3 += pad_0; + } + //! mid loop + for (int i = 0; i < w_out_new; ++i) { + compute_one_out_without_extract(din_ptr0, + din_ptr1, + din_ptr2, + din_ptr3, + din_ptr4, + din_ptr5, + din_ptr6, + din_ptr7, + dout_ptr0, + dout_ptr1, + dout_ptr2, + dout_ptr3, + w0, + w1, + w2, + w3, + w4, + w5, + w6, + vbias); + din_ptr0++; + din_ptr1++; + din_ptr2++; + din_ptr3++; + din_ptr4++; + din_ptr5++; + din_ptr6++; + din_ptr7++; + + dout_ptr0++; + dout_ptr1++; + dout_ptr2++; + dout_ptr3++; + } + if (flag_bias) { + //! deal with w_out pad_0 column post with bias + memcpy(dout_ptr0, dout0, pad_0 * sizeof(float)); + memcpy(dout_ptr1, dout1, pad_0 * sizeof(float)); + memcpy(dout_ptr2, dout2, pad_0 * sizeof(float)); + memcpy(dout_ptr3, dout3, pad_0 * sizeof(float)); + } else { + //! deal with w_out pad_0 column post without bias + memset(dout_ptr0, 0x00, pad_0 * sizeof(float)); + memset(dout_ptr1, 0x00, pad_0 * sizeof(float)); + memset(dout_ptr2, 0x00, pad_0 * sizeof(float)); + memset(dout_ptr3, 0x00, pad_0 * sizeof(float)); + } + + din0 = din4; + din1 = din5; + din2 = din6; + din3 = din7; + din4 = din3 + w_in_new; + din5 = din4 + w_in_new; + din6 = din5 + w_in_new; + din7 = din6 + w_in_new; + + dout0 = dout3 + w_out; + dout1 = dout0 + w_out; + dout2 = dout1 + w_out; + dout3 = dout2 + w_out; + } + float* dout_pad_end = dout_ch + h_out_new * w_out; + if (flag_bias) { + //! deal with h_out pad_0 line with bias + memcpy(reinterpret_cast(dout_pad_end), + dout_ch - pad_0 * w_out, + pad_0 * w_out * sizeof(float)); + } else { + //! 
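 the bottom pad_0 border rows mirror the top ones: with bias they hold
+      //! bias_c (copied from the already-filled top border rows), otherwise
+      //! they are zeroed. Per-channel output layout after this step:
+      //!   [ pad_0 rows of bias/0 | h_out_new computed rows | pad_0 rows ]
+      //! 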
deal with h_out pad_0 line without bias + memset(reinterpret_cast(dout_pad_end), + 0x00, + pad_0 * w_out * sizeof(float)); + } + } + } + free(din_new); +} + +void conv_depthwise_5x5s1_small_relu_impl(const float* din, + float* dout, + int num, + int ch_out, + int h_out, + int w_out, + int ch_in, + int h_in, + int w_in, + const float* weights, + const float* bias, + int pad, + bool flag_bias, + bool flag_relu, + ARMContext* ctx) { + int pad_new = pad > 4 ? 4 : pad; + int pad_0 = pad - pad_new; + int h_in_new = h_in + 2 * pad_new; + int w_in_new = w_in + 2 * pad_new; + float zero_ptr[w_in_new + w_out]; + memset(zero_ptr, 0, w_in_new * sizeof(float)); + float* write_ptr = zero_ptr + w_in_new; + int h_out_new = h_out - 2 * pad_0; + int w_out_new = w_out - 2 * pad_0; + int pad_cnt = pad_0 >> 2; + int pad_remain = pad_0 - (pad_cnt << 2); + int bias_cnt = (w_out * pad_0) >> 2; + int bias_remain = (w_out * pad_0) - (bias_cnt << 2); + int in_spatial_size = w_in_new * h_in_new; + int out_spatial_size = w_out * h_out; + int weights_saptial_size = 25; + + float* din_new = prepad_input(din, num, ch_in, h_in, w_in, pad_new); + for (int n = 0; n < num; ++n) { + const float* din_batch = din_new + n * in_spatial_size * ch_in; + float* dout_batch = dout + n * out_spatial_size * ch_out; +#pragma omp parallel for + for (int c = 0; c < ch_in; ++c) { + const float* din_ch = din_batch + c * in_spatial_size; + float* dout_ch = dout_batch + c * out_spatial_size; + float bias_c = flag_bias ? bias[c] : 0.f; + float bias_relu = bias_c > 0.f ? bias_c : 0.f; + float vbias[4] = {bias_c, bias_c, bias_c, bias_c}; + float32x4_t vbias_c = vdupq_n_f32(bias_relu); + if (flag_bias) { + //! deal with h_out pad_0 line with bias + for (int i = 0; i < bias_cnt; ++i) { + vst1q_f32(dout_ch, vbias_c); + dout_ch += 4; + } + for (int i = 0; i < bias_remain; ++i) { + *dout_ch++ = bias_relu; + } + } else { + //! deal with h_out pad_0 line without bias + for (int i = 0; i < pad_0; ++i) { + memset(dout_ch, 0x00, w_out * sizeof(float)); + dout_ch += w_out; + } + } + + //! every h loop, deal with 8 line input + const float* din0 = din_ch; + const float* din1 = din0 + w_in_new; + const float* din2 = din1 + w_in_new; + const float* din3 = din2 + w_in_new; + const float* din4 = din3 + w_in_new; + const float* din5 = din4 + w_in_new; + const float* din6 = din5 + w_in_new; + const float* din7 = din6 + w_in_new; + //! every h loop, deal with 4 line output + float* dout0 = dout_ch; + float* dout1 = dout0 + w_out; + float* dout2 = dout1 + w_out; + float* dout3 = dout2 + w_out; + + //! load weights to neon register + const float* weights_c = weights + c * weights_saptial_size; + + float32x4_t w5; + float32x4_t w6; + float32x4_t w0 = vld1q_f32(weights_c); + float32x4_t w1 = vld1q_f32(weights_c + 5); + float32x4_t w2 = vld1q_f32(weights_c + 10); + float32x4_t w3 = vld1q_f32(weights_c + 15); + float32x4_t w4 = vld1q_f32(weights_c + 20); + w5 = vsetq_lane_f32(weights_c[4], w5, 0); + w5 = vsetq_lane_f32(weights_c[9], w5, 1); + w5 = vsetq_lane_f32(weights_c[14], w5, 2); + w5 = vsetq_lane_f32(weights_c[19], w5, 3); + w6 = vsetq_lane_f32(weights_c[24], w6, 0); + + //! h loop + for (int h = 0; h < h_out_new; h += 4) { + //! 
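 boundary clamping in this loop is identical to the non-ReLU variant
+        //! above; note the borders were pre-filled with
+        //! bias_relu = max(bias_c, 0.f), since ReLU also applies to outputs
+        //! that see only padding, while interior points are clamped inside
+        //! the fused *_relu compute kernels.
+        //! 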
(h - pad_new) + 7 > h_in - 1 + if (h + 8 > h_in_new) { + switch (h + 8 - h_in_new) { + case 7: + din1 = zero_ptr; + case 6: + din2 = zero_ptr; + case 5: + din3 = zero_ptr; + case 4: + din4 = zero_ptr; + case 3: + din5 = zero_ptr; + case 2: + din6 = zero_ptr; + case 1: + din7 = zero_ptr; + default: + break; + } + } + if (h + 4 > h_out_new) { + switch (h + 4 - h_out_new) { + case 3: + dout1 = write_ptr; + case 2: + dout2 = write_ptr; + case 1: + dout3 = write_ptr; + default: + break; + } + } + const float* din_ptr0 = din0; + const float* din_ptr1 = din1; + const float* din_ptr2 = din2; + const float* din_ptr3 = din3; + const float* din_ptr4 = din4; + const float* din_ptr5 = din5; + const float* din_ptr6 = din6; + const float* din_ptr7 = din7; + + float* dout_ptr0 = dout0; + float* dout_ptr1 = dout1; + float* dout_ptr2 = dout2; + float* dout_ptr3 = dout3; + + if (flag_bias) { + //! deal with w_out pad_0 column pre with bias + for (int i = 0; i < pad_cnt; i++) { + vst1q_f32(dout_ptr0, vbias_c); + vst1q_f32(dout_ptr1, vbias_c); + vst1q_f32(dout_ptr2, vbias_c); + vst1q_f32(dout_ptr3, vbias_c); + dout_ptr0 += 4; + dout_ptr1 += 4; + dout_ptr2 += 4; + dout_ptr3 += 4; + } + for (int i = 0; i < pad_remain; ++i) { + *dout_ptr0++ = bias_relu; + *dout_ptr1++ = bias_relu; + *dout_ptr2++ = bias_relu; + *dout_ptr3++ = bias_relu; + } + } else { + //! deal with w_out pad_0 column pre without bias + memset(dout_ptr0, 0x00, pad_0 * sizeof(float)); + memset(dout_ptr1, 0x00, pad_0 * sizeof(float)); + memset(dout_ptr2, 0x00, pad_0 * sizeof(float)); + memset(dout_ptr3, 0x00, pad_0 * sizeof(float)); + dout_ptr0 += pad_0; + dout_ptr1 += pad_0; + dout_ptr2 += pad_0; + dout_ptr3 += pad_0; + } + + //! mid loop + for (int i = 0; i < w_out_new; ++i) { + compute_one_out_without_extract_relu(din_ptr0, + din_ptr1, + din_ptr2, + din_ptr3, + din_ptr4, + din_ptr5, + din_ptr6, + din_ptr7, + dout_ptr0, + dout_ptr1, + dout_ptr2, + dout_ptr3, + w0, + w1, + w2, + w3, + w4, + w5, + w6, + vbias); + din_ptr0++; + din_ptr1++; + din_ptr2++; + din_ptr3++; + din_ptr4++; + din_ptr5++; + din_ptr6++; + din_ptr7++; + + dout_ptr0++; + dout_ptr1++; + dout_ptr2++; + dout_ptr3++; + } + + if (flag_bias) { + //! deal with w_out pad_0 column post with bias + memcpy(dout_ptr0, dout0, pad_0 * sizeof(float)); + memcpy(dout_ptr1, dout1, pad_0 * sizeof(float)); + memcpy(dout_ptr2, dout2, pad_0 * sizeof(float)); + memcpy(dout_ptr3, dout3, pad_0 * sizeof(float)); + } else { + //! deal with w_out pad_0 column post without bias + memset(dout_ptr0, 0x00, pad_0 * sizeof(float)); + memset(dout_ptr1, 0x00, pad_0 * sizeof(float)); + memset(dout_ptr2, 0x00, pad_0 * sizeof(float)); + memset(dout_ptr3, 0x00, pad_0 * sizeof(float)); + } + + din0 = din4; + din1 = din5; + din2 = din6; + din3 = din7; + din4 = din3 + w_in_new; + din5 = din4 + w_in_new; + din6 = din5 + w_in_new; + din7 = din6 + w_in_new; + + dout0 = dout3 + w_out; + dout1 = dout0 + w_out; + dout2 = dout1 + w_out; + dout3 = dout2 + w_out; + } + float* dout_pad_end = dout_ch + h_out_new * w_out; + if (flag_bias) { + //! deal with h_out pad_0 line with bias + memcpy(reinterpret_cast(dout_pad_end), + dout_ch - pad_0 * w_out, + pad_0 * w_out * sizeof(float)); + } else { + //! deal with h_out pad_0 line without bias + memset(reinterpret_cast(dout_pad_end), + 0x00, + pad_0 * w_out * sizeof(float)); + } + } + } + free(din_new); +} + +#else + +//! kernel for one out without extracting data mid +//! 
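 deal with two lines out
+
+//! NOTE: each ARMv7 kernel below evaluates the 5x5 window at one (or a few)
+//! column positions for two adjacent output rows at once: five vmul/vmla
+//! steps multiply the weight rows against sliding input vectors, vpadd
+//! folds the lanes down to scalars, and bias (plus ReLU in the _relu
+//! variants) is applied before the store. A minimal scalar reference of
+//! the single-output case -- hypothetical names, for illustration only:
+static inline void compute_one_out_ref(const float* rows[6],
+                                       float* out0,
+                                       float* out1,
+                                       const float* w,      // 25 weights
+                                       const float* bias) { // bias[0]==bias[1]
+  float s0 = bias[0];
+  float s1 = bias[1];
+  for (int i = 0; i < 5; ++i) {
+    for (int j = 0; j < 5; ++j) {
+      s0 += w[i * 5 + j] * rows[i][j];      // window starting at row 0
+      s1 += w[i * 5 + j] * rows[i + 1][j];  // same window, one row down
+    }
+  }
+  *out0 = s0;
+  *out1 = s1;
+}
+
+//! kernel for one out without extracting data mid
+//! 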
deal with two lines out +void compute_one_out_without_extract(const float* din0, + const float* din1, + const float* din2, + const float* din3, + const float* din4, + const float* din5, + float* dout0, + float* dout1, + const float* weights, + const float* bias) { + asm volatile( + "mov r0, #20 \n" + "vld1.32 {d0-d1}, [%[wh]], r0 \n" + "vld1.32 {d2-d3}, [%[wh]], r0 \n" + + "vld1.32 {d4-d5}, [%[din0]]! \n" + "vld1.32 {d6-d7}, [%[din1]]! \n" + "vld1.32 {d8-d9}, [%[din2]]! \n" + "vld1.32 {d10-d11}, [%[din3]]! \n" + "vld1.32 {d12-d13}, [%[din4]]! \n" + "vld1.32 {d14-d15}, [%[din5]]! \n" + + // weights r0 + "vmul.f32 q9, q0, q2 \n" + "vmul.f32 q10, q0, q3 \n" + + "vld1.32 {d0-d1}, [%[wh]], r0 \n" + + // weights r1 + "vmla.f32 q9, q1, q3 \n" + "vmla.f32 q10, q1, q4 \n" + + "vld1.32 {d2-d3}, [%[wh]], r0 \n" + "vld1.32 {d6[0]}, [%[din0]] \n" + "vld1.32 {d6[1]}, [%[din1]] \n" + "vld1.32 {d7[0]}, [%[din2]] \n" + "vld1.32 {d7[1]}, [%[din3]] \n" + + // weights r2 + "vmla.f32 q9, q0, q4 \n" + "vmla.f32 q10, q0, q5 \n" + + "vld1.32 {d8[0]}, [%[din4]] \n" + "vld1.32 {d8[1]}, [%[din5]] \n" + + "vld1.32 {d0-d1}, [%[wh]] \n" + + // weights r3 + "vmla.f32 q9, q1, q5 \n" + "vmla.f32 q10, q1, q6 \n" + + // weights col4 + "sub %[wh], #64 \n" + "vld1.32 {d4[0]}, [%[wh]], r0 \n" + "vld1.32 {d4[1]}, [%[wh]], r0 \n" + "vld1.32 {d5[0]}, [%[wh]], r0 \n" + "vld1.32 {d5[1]}, [%[wh]], r0 \n" + + // weights r4 + "vmla.f32 q9, q0, q6 \n" + "vmla.f32 q10, q0, q7 \n" + + "vext.32 q5, q3, q4, #1 \n" + + "vmla.f32 q9, q2, q3 \n" + "vmla.f32 q10, q2, q5 \n" + + "vld1.32 {d4[0]}, [%[wh]] \n" + "vld1.32 {d6}, [%[bias]] \n" + + "vpadd.f32 d18, d18, d19 \n" + "vpadd.f32 d19, d20, d21 \n" + "vpadd.f32 d18, d18, d19 \n" + + "vmla.f32 d18, d8, d4[0] \n" + + // add bias + "vadd.f32 d18, d18, d6 \n" + + "vst1.32 {d18[0]}, [%[dout0]] \n" + "vst1.32 {d18[1]}, [%[dout1]] \n" + + : [din0] "+r"(din0), + [din1] "+r"(din1), + [din2] "+r"(din2), + [din3] "+r"(din3), + [din4] "+r"(din4), + [din5] "+r"(din5), + [wh] "+r"(weights) + : [dout0] "r"(dout0), [dout1] "r"(dout1), [bias] "r"(bias) + : "memory", + "r0", + "q0", + "q1", + "q2", + "q3", + "q4", + "q5", + "q6", + "q7", + "q8", + "q9", + "q10", + "q11"); +} + +//! kernel for one out without extracting data mid +//! deal with two lines out +void compute_one_out_without_extract_relu(const float* din0, + const float* din1, + const float* din2, + const float* din3, + const float* din4, + const float* din5, + float* dout0, + float* dout1, + const float* weights, + const float* bias) { + asm volatile( + "mov r0, #20 \n" + "vmov.i32 q15, #0x0 \n" + "vld1.32 {d0-d1}, [%[wh]], r0 \n" + "vld1.32 {d2-d3}, [%[wh]], r0 \n" + + "vld1.32 {d4-d5}, [%[din0]]! \n" + "vld1.32 {d6-d7}, [%[din1]]! \n" + "vld1.32 {d8-d9}, [%[din2]]! \n" + "vld1.32 {d10-d11}, [%[din3]]! \n" + "vld1.32 {d12-d13}, [%[din4]]! \n" + "vld1.32 {d14-d15}, [%[din5]]! 
\n" + + // weights r0 + "vmul.f32 q9, q0, q2 \n" + "vmul.f32 q10, q0, q3 \n" + + "vld1.32 {d0-d1}, [%[wh]], r0 \n" + + // weights r1 + "vmla.f32 q9, q1, q3 \n" + "vmla.f32 q10, q1, q4 \n" + + "vld1.32 {d2-d3}, [%[wh]], r0 \n" + "vld1.32 {d6[0]}, [%[din0]] \n" + "vld1.32 {d6[1]}, [%[din1]] \n" + "vld1.32 {d7[0]}, [%[din2]] \n" + "vld1.32 {d7[1]}, [%[din3]] \n" + + // weights r2 + "vmla.f32 q9, q0, q4 \n" + "vmla.f32 q10, q0, q5 \n" + + "vld1.32 {d8[0]}, [%[din4]] \n" + "vld1.32 {d8[1]}, [%[din5]] \n" + + "vld1.32 {d0-d1}, [%[wh]] \n" + + // weights r3 + "vmla.f32 q9, q1, q5 \n" + "vmla.f32 q10, q1, q6 \n" + + // weights col4 + "sub %[wh], #64 \n" + "vld1.32 {d4[0]}, [%[wh]], r0 \n" + "vld1.32 {d4[1]}, [%[wh]], r0 \n" + "vld1.32 {d5[0]}, [%[wh]], r0 \n" + "vld1.32 {d5[1]}, [%[wh]], r0 \n" + + // weights r4 + "vmla.f32 q9, q0, q6 \n" + "vmla.f32 q10, q0, q7 \n" + + "vext.32 q5, q3, q4, #1 \n" + + "vmla.f32 q9, q2, q3 \n" + "vmla.f32 q10, q2, q5 \n" + + "vld1.32 {d4[0]}, [%[wh]] \n" + "vld1.32 {d6}, [%[bias]] \n" + + "vpadd.f32 d18, d18, d19 \n" + "vpadd.f32 d19, d20, d21 \n" + "vpadd.f32 d18, d18, d19 \n" + + "vmla.f32 d18, d8, d4[0] \n" + + // add bias + "vadd.f32 d18, d18, d6 \n" + + // relu + "vmax.f32 d18, d18, d30 \n" + + "vst1.32 {d18[0]}, [%[dout0]] \n" + "vst1.32 {d18[1]}, [%[dout1]] \n" + + : [din0] "+r"(din0), + [din1] "+r"(din1), + [din2] "+r"(din2), + [din3] "+r"(din3), + [din4] "+r"(din4), + [din5] "+r"(din5), + [wh] "+r"(weights) + : [dout0] "r"(dout0), [dout1] "r"(dout1), [bias] "r"(bias) + : "memory", + "r0", + "q0", + "q1", + "q2", + "q3", + "q4", + "q5", + "q6", + "q7", + "q8", + "q9", + "q10", + "q11", + "q15"); +} + +//! kernel for one out without extracting data pre +//! deal with two lines out +void compute_one_out_extract_pre(const float* din0, + const float* din1, + const float* din2, + const float* din3, + const float* din4, + const float* din5, + float* dout0, + float* dout1, + const float* weights, + const float* bias) { + asm volatile( + "mov r0, #20 \n" + "add %[wh], #4 \n" + "vld1.32 {d0-d1}, [%[wh]], r0 \n" + "vld1.32 {d2-d3}, [%[wh]], r0 \n" + + "vld1.32 {d4-d5}, [%[din0]]! \n" + "vld1.32 {d6-d7}, [%[din1]]! \n" + "vld1.32 {d8-d9}, [%[din2]]! \n" + "vld1.32 {d10-d11}, [%[din3]]! \n" + "vld1.32 {d12-d13}, [%[din4]]! \n" + "vld1.32 {d14-d15}, [%[din5]]! \n" + + // weights r0 + "vmul.f32 q9, q0, q2 \n" + "vmul.f32 q10, q0, q3 \n" + + "vld1.32 {d0-d1}, [%[wh]], r0 \n" + + // weights r1 + "vmla.f32 q9, q1, q3 \n" + "vmla.f32 q10, q1, q4 \n" + + "vld1.32 {d2-d3}, [%[wh]], r0 \n" + + // weights r2 + "vmla.f32 q9, q0, q4 \n" + "vmla.f32 q10, q0, q5 \n" + + "vld1.32 {d0-d1}, [%[wh]] \n" + + // weights r3 + "vmla.f32 q9, q1, q5 \n" + "vmla.f32 q10, q1, q6 \n" + + // weights r4 + "vmla.f32 q9, q0, q6 \n" + "vmla.f32 q10, q0, q7 \n" + + // load bias + "vld1.32 {d0}, [%[bias]] \n" + + "vpadd.f32 d18, d18, d19 \n" + "vpadd.f32 d19, d20, d21 \n" + "vpadd.f32 d18, d18, d19 \n" + + // add bias + "vadd.f32 d18, d18, d0 \n" + + "vst1.32 {d18[0]}, [%[dout0]] \n" + "vst1.32 {d18[1]}, [%[dout1]] \n" + + : [din0] "+r"(din0), + [din1] "+r"(din1), + [din2] "+r"(din2), + [din3] "+r"(din3), + [din4] "+r"(din4), + [din5] "+r"(din5), + [wh] "+r"(weights) + : [dout0] "r"(dout0), [dout1] "r"(dout1), [bias] "r"(bias) + : "memory", + "r0", + "q0", + "q1", + "q2", + "q3", + "q4", + "q5", + "q6", + "q7", + "q8", + "q9", + "q10", + "q11"); +} + +//! kernel for one out without extracting data pre +//! 
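 deal with two lines out
+
+//! NOTE: the extract_pre kernels compute left-border outputs whose 5x5
+//! window starts before column 0. Rather than reading a padded input, they
+//! advance the weight pointer by one float ("add %[wh], #4"), so each
+//! stride-20 row load picks up weight columns 1..4, applied to the first
+//! input columns. A scalar sketch of the single-output case (hypothetical
+//! helper, illustration only):
+static inline void one_out_extract_pre_ref(const float* rows[5],
+                                           const float* w,  // 25 weights
+                                           float bias,
+                                           float* out) {
+  float s = bias;
+  for (int i = 0; i < 5; ++i) {
+    for (int j = 1; j < 5; ++j) {  // weight column 0 falls off the image
+      s += w[i * 5 + j] * rows[i][j - 1];
+    }
+  }
+  *out = s;
+}
+
+//! kernel for one out with extracting data pre
+//! 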
deal with two lines out +void compute_one_out_extract_pre_relu(const float* din0, + const float* din1, + const float* din2, + const float* din3, + const float* din4, + const float* din5, + float* dout0, + float* dout1, + const float* weights, + const float* bias) { + asm volatile( + "mov r0, #20 \n" + "add %[wh], #4 \n" + "vmov.i32 q15, #0x0 \n" + "vld1.32 {d0-d1}, [%[wh]], r0 \n" + "vld1.32 {d2-d3}, [%[wh]], r0 \n" + + "vld1.32 {d4-d5}, [%[din0]]! \n" + "vld1.32 {d6-d7}, [%[din1]]! \n" + "vld1.32 {d8-d9}, [%[din2]]! \n" + "vld1.32 {d10-d11}, [%[din3]]! \n" + "vld1.32 {d12-d13}, [%[din4]]! \n" + "vld1.32 {d14-d15}, [%[din5]]! \n" + + // weights r0 + "vmul.f32 q9, q0, q2 \n" + "vmul.f32 q10, q0, q3 \n" + + "vld1.32 {d0-d1}, [%[wh]], r0 \n" + + // weights r1 + "vmla.f32 q9, q1, q3 \n" + "vmla.f32 q10, q1, q4 \n" + + "vld1.32 {d2-d3}, [%[wh]], r0 \n" + + // weights r2 + "vmla.f32 q9, q0, q4 \n" + "vmla.f32 q10, q0, q5 \n" + + "vld1.32 {d0-d1}, [%[wh]] \n" + + // weights r3 + "vmla.f32 q9, q1, q5 \n" + "vmla.f32 q10, q1, q6 \n" + + // weights r4 + "vmla.f32 q9, q0, q6 \n" + "vmla.f32 q10, q0, q7 \n" + + // load bias + "vld1.32 {d0}, [%[bias]] \n" + + "vpadd.f32 d18, d18, d19 \n" + "vpadd.f32 d19, d20, d21 \n" + "vpadd.f32 d18, d18, d19 \n" + + // add bias + "vadd.f32 d18, d18, d0 \n" + + // relu + "vmax.f32 d18, d18, d30 \n" + "vst1.32 {d18[0]}, [%[dout0]] \n" + "vst1.32 {d18[1]}, [%[dout1]] \n" + + : [din0] "+r"(din0), + [din1] "+r"(din1), + [din2] "+r"(din2), + [din3] "+r"(din3), + [din4] "+r"(din4), + [din5] "+r"(din5), + [wh] "+r"(weights) + : [dout0] "r"(dout0), [dout1] "r"(dout1), [bias] "r"(bias) + : "memory", + "r0", + "q0", + "q1", + "q2", + "q3", + "q4", + "q5", + "q6", + "q7", + "q8", + "q9", + "q10", + "q11", + "q15"); +} + +//! kernel for one out with extracting data post +//! deal with two lines out +void compute_one_out_extract_post(const float* din0, + const float* din1, + const float* din2, + const float* din3, + const float* din4, + const float* din5, + float* dout0, + float* dout1, + const float* weights, + const float* bias) { + asm volatile( + "mov r0, #20 \n" + "vld1.32 {d0-d1}, [%[wh]], r0 \n" + "vld1.32 {d2-d3}, [%[wh]], r0 \n" + + "vld1.32 {d4-d5}, [%[din0]]! \n" + "vld1.32 {d6-d7}, [%[din1]]! \n" + "vld1.32 {d8-d9}, [%[din2]]! \n" + "vld1.32 {d10-d11}, [%[din3]]! \n" + "vld1.32 {d12-d13}, [%[din4]]! \n" + "vld1.32 {d14-d15}, [%[din5]]! \n" + + // weights r0 + "vmul.f32 q9, q0, q2 \n" + "vmul.f32 q10, q0, q3 \n" + + "vld1.32 {d0-d1}, [%[wh]], r0 \n" + + // weights r1 + "vmla.f32 q9, q1, q3 \n" + "vmla.f32 q10, q1, q4 \n" + + "vld1.32 {d2-d3}, [%[wh]], r0 \n" + + // weights r2 + "vmla.f32 q9, q0, q4 \n" + "vmla.f32 q10, q0, q5 \n" + + "vld1.32 {d0-d1}, [%[wh]] \n" + + // weights r3 + "vmla.f32 q9, q1, q5 \n" + "vmla.f32 q10, q1, q6 \n" + + // weights r4 + "vmla.f32 q9, q0, q6 \n" + "vmla.f32 q10, q0, q7 \n" + + "vld1.32 {d0}, [%[bias]] \n" + + "vpadd.f32 d18, d18, d19 \n" + "vpadd.f32 d19, d20, d21 \n" + "vpadd.f32 d18, d18, d19 \n" + + // add bias + "vadd.f32 d18, d18, d0 \n" + + "vst1.32 {d18[0]}, [%[dout0]] \n" + "vst1.32 {d18[1]}, [%[dout1]] \n" + + : [din0] "+r"(din0), + [din1] "+r"(din1), + [din2] "+r"(din2), + [din3] "+r"(din3), + [din4] "+r"(din4), + [din5] "+r"(din5), + [wh] "+r"(weights) + : [dout0] "r"(dout0), [dout1] "r"(dout1), [bias] "r"(bias) + : "memory", + "r0", + "q0", + "q1", + "q2", + "q3", + "q4", + "q5", + "q6", + "q7", + "q8", + "q9", + "q10", + "q11"); +} + +//! kernel for one out with extracting data post +//! 
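 deal with two lines out
+
+//! NOTE: extract_post mirrors extract_pre at the right border: the window
+//! runs past the last input column, so the rightmost weight column is
+//! dropped and the remaining 4x5 block is applied to the last four valid
+//! input columns, i.e. out = bias + sum over i, 0 <= j < 4 of
+//! w[i][j] * in[i][j], with in pointing at the right edge. The two-,
+//! three- and four-out variants below peel further border columns by
+//! shifting a zero lane into the inputs with vext.
+
+//! kernel for one out with extracting data post
+//! 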
deal with two lines out +void compute_one_out_extract_post_relu(const float* din0, + const float* din1, + const float* din2, + const float* din3, + const float* din4, + const float* din5, + float* dout0, + float* dout1, + const float* weights, + const float* bias) { + asm volatile( + "mov r0, #20 \n" + "vmov.i32 q15, #0x0 \n" + "vld1.32 {d0-d1}, [%[wh]], r0 \n" + "vld1.32 {d2-d3}, [%[wh]], r0 \n" + + "vld1.32 {d4-d5}, [%[din0]]! \n" + "vld1.32 {d6-d7}, [%[din1]]! \n" + "vld1.32 {d8-d9}, [%[din2]]! \n" + "vld1.32 {d10-d11}, [%[din3]]! \n" + "vld1.32 {d12-d13}, [%[din4]]! \n" + "vld1.32 {d14-d15}, [%[din5]]! \n" + + // weights r0 + "vmul.f32 q9, q0, q2 \n" + "vmul.f32 q10, q0, q3 \n" + + "vld1.32 {d0-d1}, [%[wh]], r0 \n" + + // weights r1 + "vmla.f32 q9, q1, q3 \n" + "vmla.f32 q10, q1, q4 \n" + + "vld1.32 {d2-d3}, [%[wh]], r0 \n" + + // weights r2 + "vmla.f32 q9, q0, q4 \n" + "vmla.f32 q10, q0, q5 \n" + + "vld1.32 {d0-d1}, [%[wh]] \n" + + // weights r3 + "vmla.f32 q9, q1, q5 \n" + "vmla.f32 q10, q1, q6 \n" + + // weights r4 + "vmla.f32 q9, q0, q6 \n" + "vmla.f32 q10, q0, q7 \n" + + "vld1.32 {d0}, [%[bias]] \n" + + "vpadd.f32 d18, d18, d19 \n" + "vpadd.f32 d19, d20, d21 \n" + "vpadd.f32 d18, d18, d19 \n" + + // add bias + "vadd.f32 d18, d18, d0 \n" + + // relu + "vmax.f32 d18, d18, d30 \n" + + "vst1.32 {d18[0]}, [%[dout0]] \n" + "vst1.32 {d18[1]}, [%[dout1]] \n" + + : [din0] "+r"(din0), + [din1] "+r"(din1), + [din2] "+r"(din2), + [din3] "+r"(din3), + [din4] "+r"(din4), + [din5] "+r"(din5), + [wh] "+r"(weights) + : [dout0] "r"(dout0), [dout1] "r"(dout1), [bias] "r"(bias) + : "memory", + "r0", + "q0", + "q1", + "q2", + "q3", + "q4", + "q5", + "q6", + "q7", + "q8", + "q9", + "q10", + "q11", + "q15"); +} + +//! kernel for two out with extracting data pre +//! deal with two lines out +void compute_two_out_extract_pre(const float* din0, + const float* din1, + const float* din2, + const float* din3, + const float* din4, + const float* din5, + float* dout0, + float* dout1, + const float* weights, + const float* bias) { + asm volatile( + "mov r0, #20 \n" + "mov r1, #0 \n" + "add %[wh], #8 \n" + "vld1.32 {d0-d1}, [%[wh]], r0 \n" + "vld1.32 {d2-d3}, [%[wh]], r0 \n" + + "vmov.32 d1[1], r1 \n" + "vmov.32 d3[1], r1 \n" + + "vld1.32 {d4-d5}, [%[din0]]! \n" + "vld1.32 {d6-d7}, [%[din1]]! \n" + "vld1.32 {d8-d9}, [%[din2]]! \n" + "vld1.32 {d10-d11}, [%[din3]]! \n" + "vld1.32 {d12-d13}, [%[din4]]! \n" + "vld1.32 {d14-d15}, [%[din5]]! 
\n" + + // weights r0 + "vmul.f32 q9, q0, q2 \n" + "vmul.f32 q10, q0, q3 \n" + + "vld1.32 {d24-d25}, [%[wh]], r0 \n" + "vmov.32 d25[1], r1 \n" + + // weights r1 + "vmla.f32 q9, q1, q3 \n" + "vmla.f32 q10, q1, q4 \n" + + "vld1.32 {d26-d27}, [%[wh]], r0 \n" + "vmov.32 d27[1], r1 \n" + + // weights r2 + "vmla.f32 q9, q12, q4 \n" + "vmla.f32 q10, q12, q5 \n" + + "vld1.32 {d28-d29}, [%[wh]]\n" + "vmov.32 d29[1], r1 \n" + + // weights r3 + "vmla.f32 q9, q13, q5 \n" + "vmla.f32 q10, q13, q6 \n" + + // weights r4 + "vmla.f32 q9, q14, q6 \n" + "vmla.f32 q10, q14, q7 \n" + + "sub %[wh], #84 \n" + "vld1.32 {d0-d1}, [%[wh]], r0 \n" + "vld1.32 {d2-d3}, [%[wh]], r0 \n" + "vld1.32 {d24-d25}, [%[wh]], r0 \n" + + "vpadd.f32 d22, d18, d19 \n" + "vpadd.f32 d23, d20, d21 \n" + + "vld1.32 {d26-d27}, [%[wh]], r0 \n" + "vld1.32 {d28-d29}, [%[wh]]\n" + + "vpadd.f32 d22, d22, d23 \n" + + // weights r0 + "vmul.f32 q9, q0, q2 \n" + "vmul.f32 q10, q0, q3 \n" + + // weights r1 + "vmla.f32 q9, q1, q3 \n" + "vmla.f32 q10, q1, q4 \n" + + // weights r2 + "vmla.f32 q9, q12, q4 \n" + "vmla.f32 q10, q12, q5 \n" + + // weights r3 + "vmla.f32 q9, q13, q5 \n" + "vmla.f32 q10, q13, q6 \n" + + // weights r4 + "vmla.f32 q9, q14, q6 \n" + "vmla.f32 q10, q14, q7 \n" + "vld1.32 {d30-d31}, [%[bias]] \n" + + "vpadd.f32 d18, d18, d19 \n" + "vpadd.f32 d19, d20, d21 \n" + "vpadd.f32 d23, d18, d19 \n" + + // trn out neon register + "vtrn.32 d22, d23 \n" + + // add bias + "vadd.f32 q11, q11, q15 \n" + + // store result + "vst1.32 {d22}, [%[dout0]] \n" + "vst1.32 {d23}, [%[dout1]] \n" + + : [din0] "+r"(din0), + [din1] "+r"(din1), + [din2] "+r"(din2), + [din3] "+r"(din3), + [din4] "+r"(din4), + [din5] "+r"(din5), + [wh] "+r"(weights) + : [dout0] "r"(dout0), [dout1] "r"(dout1), [bias] "r"(bias) + : "memory", + "r0", + "r1", + "q0", + "q1", + "q2", + "q3", + "q4", + "q5", + "q6", + "q7", + "q8", + "q9", + "q10", + "q11", + "q12", + "q13", + "q14", + "q15"); +} + +//! kernel for two out with extracting data pre +//! deal with two lines out +void compute_two_out_extract_pre_relu(const float* din0, + const float* din1, + const float* din2, + const float* din3, + const float* din4, + const float* din5, + float* dout0, + float* dout1, + const float* weights, + const float* bias) { + asm volatile( + "mov r0, #20 \n" + "mov r1, #0 \n" + "add %[wh], #8 \n" + "vld1.32 {d0-d1}, [%[wh]], r0 \n" + "vld1.32 {d2-d3}, [%[wh]], r0 \n" + + "vmov.32 d1[1], r1 \n" + "vmov.32 d3[1], r1 \n" + + "vld1.32 {d4-d5}, [%[din0]]! \n" + "vld1.32 {d6-d7}, [%[din1]]! \n" + "vld1.32 {d8-d9}, [%[din2]]! \n" + "vld1.32 {d10-d11}, [%[din3]]! \n" + "vld1.32 {d12-d13}, [%[din4]]! \n" + "vld1.32 {d14-d15}, [%[din5]]! 
\n" + + // weights r0 + "vmul.f32 q9, q0, q2 \n" + "vmul.f32 q10, q0, q3 \n" + + "vld1.32 {d24-d25}, [%[wh]], r0 \n" + "vmov.32 d25[1], r1 \n" + + // weights r1 + "vmla.f32 q9, q1, q3 \n" + "vmla.f32 q10, q1, q4 \n" + + "vld1.32 {d26-d27}, [%[wh]], r0 \n" + "vmov.32 d27[1], r1 \n" + + // weights r2 + "vmla.f32 q9, q12, q4 \n" + "vmla.f32 q10, q12, q5 \n" + + "vld1.32 {d28-d29}, [%[wh]]\n" + "vmov.32 d29[1], r1 \n" + + // weights r3 + "vmla.f32 q9, q13, q5 \n" + "vmla.f32 q10, q13, q6 \n" + + // weights r4 + "vmla.f32 q9, q14, q6 \n" + "vmla.f32 q10, q14, q7 \n" + + "sub %[wh], #84 \n" + "vld1.32 {d0-d1}, [%[wh]], r0 \n" + "vld1.32 {d2-d3}, [%[wh]], r0 \n" + "vld1.32 {d24-d25}, [%[wh]], r0 \n" + + "vpadd.f32 d22, d18, d19 \n" + "vpadd.f32 d23, d20, d21 \n" + + "vld1.32 {d26-d27}, [%[wh]], r0 \n" + "vld1.32 {d28-d29}, [%[wh]]\n" + + "vpadd.f32 d22, d22, d23 \n" + + // weights r0 + "vmul.f32 q9, q0, q2 \n" + "vmul.f32 q10, q0, q3 \n" + + // weights r1 + "vmla.f32 q9, q1, q3 \n" + "vmla.f32 q10, q1, q4 \n" + + // weights r2 + "vmla.f32 q9, q12, q4 \n" + "vmla.f32 q10, q12, q5 \n" + + // weights r3 + "vmla.f32 q9, q13, q5 \n" + "vmla.f32 q10, q13, q6 \n" + + // weights r4 + "vmla.f32 q9, q14, q6 \n" + "vmla.f32 q10, q14, q7 \n" + "vld1.32 {d30-d31}, [%[bias]] \n" + + "vpadd.f32 d18, d18, d19 \n" + "vpadd.f32 d19, d20, d21 \n" + "vpadd.f32 d23, d18, d19 \n" + "vmov.i32 q9, #0x0 \n" + + // trn out neon register + "vtrn.32 d22, d23 \n" + + // add bias + "vadd.f32 q11, q11, q15 \n" + + // relu + "vmax.f32 q11, q11, q9 \n" + // store result + "vst1.32 {d22}, [%[dout0]] \n" + "vst1.32 {d23}, [%[dout1]] \n" + + : [din0] "+r"(din0), + [din1] "+r"(din1), + [din2] "+r"(din2), + [din3] "+r"(din3), + [din4] "+r"(din4), + [din5] "+r"(din5), + [wh] "+r"(weights) + : [dout0] "r"(dout0), [dout1] "r"(dout1), [bias] "r"(bias) + : "memory", + "r0", + "r1", + "q0", + "q1", + "q2", + "q3", + "q4", + "q5", + "q6", + "q7", + "q8", + "q9", + "q10", + "q11", + "q12", + "q13", + "q14", + "q15"); +} + +//! kernel for two out with extracting data post +//! deal with two lines out +void compute_two_out_extract_post(const float* din0, + const float* din1, + const float* din2, + const float* din3, + const float* din4, + const float* din5, + float* dout0, + float* dout1, + const float* weights, + const float* bias) { + asm volatile( + "mov r0, #20 \n" + "vld1.32 {d0-d1}, [%[wh]], r0 \n" + "vld1.32 {d2-d3}, [%[wh]], r0 \n" + + "vld1.32 {d4-d5}, [%[din0]]! \n" + "vld1.32 {d6-d7}, [%[din1]]! \n" + "vld1.32 {d8-d9}, [%[din2]]! \n" + "vld1.32 {d10-d11}, [%[din3]]! \n" + "vld1.32 {d12-d13}, [%[din4]]! \n" + "vld1.32 {d14-d15}, [%[din5]]! \n" + + //! out zero + // weights r0 + "vmul.f32 q9, q0, q2 \n" + "vmul.f32 q10, q0, q3 \n" + + "vld1.32 {d24-d25}, [%[wh]], r0 \n" + + // weights r1 + "vmla.f32 q9, q1, q3 \n" + "vmla.f32 q10, q1, q4 \n" + + "vld1.32 {d26-d27}, [%[wh]], r0 \n" + + // weights r2 + "vmla.f32 q9, q12, q4 \n" + "vmla.f32 q10, q12, q5 \n" + + "vld1.32 {d28-d29}, [%[wh]]\n" + + // weights r3 + "vmla.f32 q9, q13, q5 \n" + "vmla.f32 q10, q13, q6 \n" + + // weights r4 + "vmla.f32 q9, q14, q6 \n" + "vmla.f32 q10, q14, q7 \n" + + "vpadd.f32 d22, d18, d19 \n" + "vpadd.f32 d23, d20, d21 \n" + "vpadd.f32 d22, d22, d23 \n" + + "vmov.f32 q15, #0.0 \n" + "vext.32 q2, q2, q15, #1 \n" + "vext.32 q3, q3, q15, #1 \n" + "vext.32 q4, q4, q15, #1 \n" + "vext.32 q5, q5, q15, #1 \n" + "vext.32 q6, q6, q15, #1 \n" + "vext.32 q7, q7, q15, #1 \n" + "vext.32 q8, q8, q15, #1 \n" + + //! 
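 the vext block above shifts every input vector one lane left and feeds
+      // a zero into the vacated lane, so the same 4-wide weight block now
+      // sees the window one column further right, where one column hangs
+      // off the image; re-running the vmul/vmla sequence below therefore
+      // yields the next border output at no extra load cost.
+      //! 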
out one + // weights r0 + "vmul.f32 q9, q0, q2 \n" + "vmul.f32 q10, q0, q3 \n" + + // weights r1 + "vmla.f32 q9, q1, q3 \n" + "vmla.f32 q10, q1, q4 \n" + + // weights r2 + "vmla.f32 q9, q12, q4 \n" + "vmla.f32 q10, q12, q5 \n" + + // weights r3 + "vmla.f32 q9, q13, q5 \n" + "vmla.f32 q10, q13, q6 \n" + + // weights r4 + "vmla.f32 q9, q14, q6 \n" + "vmla.f32 q10, q14, q7 \n" + "vld1.32 {d30-d31}, [%[bias]] \n" + + "vpadd.f32 d18, d18, d19 \n" + "vpadd.f32 d19, d20, d21 \n" + "vpadd.f32 d23, d18, d19 \n" + + // trn out neon register + "vtrn.32 d22, d23 \n" + + // add bias + "vadd.f32 q11, q11, q15 \n" + + // store result + "vst1.32 {d22}, [%[dout0]] \n" + "vst1.32 {d23}, [%[dout1]] \n" + + : [din0] "+r"(din0), + [din1] "+r"(din1), + [din2] "+r"(din2), + [din3] "+r"(din3), + [din4] "+r"(din4), + [din5] "+r"(din5), + [wh] "+r"(weights) + : [dout0] "r"(dout0), [dout1] "r"(dout1), [bias] "r"(bias) + : "memory", + "r0", + "q0", + "q1", + "q2", + "q3", + "q4", + "q5", + "q6", + "q7", + "q8", + "q9", + "q10", + "q11", + "q12", + "q13", + "q14", + "q15"); +} + +//! kernel for two out with extracting data post +//! deal with two lines out +void compute_two_out_extract_post_relu(const float* din0, + const float* din1, + const float* din2, + const float* din3, + const float* din4, + const float* din5, + float* dout0, + float* dout1, + const float* weights, + const float* bias) { + asm volatile( + "mov r0, #20 \n" + "vld1.32 {d0-d1}, [%[wh]], r0 \n" + "vld1.32 {d2-d3}, [%[wh]], r0 \n" + + "vld1.32 {d4-d5}, [%[din0]]! \n" + "vld1.32 {d6-d7}, [%[din1]]! \n" + "vld1.32 {d8-d9}, [%[din2]]! \n" + "vld1.32 {d10-d11}, [%[din3]]! \n" + "vld1.32 {d12-d13}, [%[din4]]! \n" + "vld1.32 {d14-d15}, [%[din5]]! \n" + + //! out zero + // weights r0 + "vmul.f32 q9, q0, q2 \n" + "vmul.f32 q10, q0, q3 \n" + + "vld1.32 {d24-d25}, [%[wh]], r0 \n" + + // weights r1 + "vmla.f32 q9, q1, q3 \n" + "vmla.f32 q10, q1, q4 \n" + + "vld1.32 {d26-d27}, [%[wh]], r0 \n" + + // weights r2 + "vmla.f32 q9, q12, q4 \n" + "vmla.f32 q10, q12, q5 \n" + + "vld1.32 {d28-d29}, [%[wh]]\n" + + // weights r3 + "vmla.f32 q9, q13, q5 \n" + "vmla.f32 q10, q13, q6 \n" + + // weights r4 + "vmla.f32 q9, q14, q6 \n" + "vmla.f32 q10, q14, q7 \n" + + "vpadd.f32 d22, d18, d19 \n" + "vpadd.f32 d23, d20, d21 \n" + "vpadd.f32 d22, d22, d23 \n" + + "vmov.f32 q15, #0.0 \n" + "vext.32 q2, q2, q15, #1 \n" + "vext.32 q3, q3, q15, #1 \n" + "vext.32 q4, q4, q15, #1 \n" + "vext.32 q5, q5, q15, #1 \n" + "vext.32 q6, q6, q15, #1 \n" + "vext.32 q7, q7, q15, #1 \n" + "vext.32 q8, q8, q15, #1 \n" + + //! 
out one + // weights r0 + "vmul.f32 q9, q0, q2 \n" + "vmul.f32 q10, q0, q3 \n" + + // weights r1 + "vmla.f32 q9, q1, q3 \n" + "vmla.f32 q10, q1, q4 \n" + + // weights r2 + "vmla.f32 q9, q12, q4 \n" + "vmla.f32 q10, q12, q5 \n" + + // weights r3 + "vmla.f32 q9, q13, q5 \n" + "vmla.f32 q10, q13, q6 \n" + + // weights r4 + "vmla.f32 q9, q14, q6 \n" + "vmla.f32 q10, q14, q7 \n" + "vld1.32 {d30-d31}, [%[bias]] \n" + + "vpadd.f32 d18, d18, d19 \n" + "vpadd.f32 d19, d20, d21 \n" + "vpadd.f32 d23, d18, d19 \n" + "vmov.i32 q9, #0x0 \n" + + // trn out neon register + "vtrn.32 d22, d23 \n" + + // add bias + "vadd.f32 q11, q11, q15 \n" + + // relu + "vmax.f32 q11, q11, q9 \n" + + // store result + "vst1.32 {d22}, [%[dout0]] \n" + "vst1.32 {d23}, [%[dout1]] \n" + + : [din0] "+r"(din0), + [din1] "+r"(din1), + [din2] "+r"(din2), + [din3] "+r"(din3), + [din4] "+r"(din4), + [din5] "+r"(din5), + [wh] "+r"(weights) + : [dout0] "r"(dout0), [dout1] "r"(dout1), [bias] "r"(bias) + : "memory", + "r0", + "q0", + "q1", + "q2", + "q3", + "q4", + "q5", + "q6", + "q7", + "q8", + "q9", + "q10", + "q11", + "q12", + "q13", + "q14", + "q15"); +} + +//! kernel for three out with extracting data pre +//! deal with two lines out +void compute_three_out_extract_pre(const float* din0, + const float* din1, + const float* din2, + const float* din3, + const float* din4, + const float* din5, + float* dout0, + float* dout1, + const float* weights, + const float* bias) { + asm volatile( + "mov r0, #20 \n" + "add %[wh], #12 \n" + "vld1.32 {d0}, [%[wh]], r0 \n" + "vld1.32 {d2}, [%[wh]], r0 \n" + + "vld1.32 {d4-d5}, [%[din0]] \n" + "vld1.32 {d6-d7}, [%[din1]] \n" + "vld1.32 {d8-d9}, [%[din2]] \n" + "vld1.32 {d10-d11}, [%[din3]] \n" + "vld1.32 {d12-d13}, [%[din4]] \n" + "vld1.32 {d14-d15}, [%[din5]] \n" + + //! out zero + // weights r0 + "vmul.f32 d18, d0, d4 \n" + "vmul.f32 d20, d0, d6 \n" + + "vld1.32 {d24}, [%[wh]], r0 \n" + + // weights r1 + "vmla.f32 d18, d2, d6 \n" + "vmla.f32 d20, d2, d8 \n" + + "vld1.32 {d26}, [%[wh]], r0 \n" + + // weights r2 + "vmla.f32 d18, d24, d8 \n" + "vmla.f32 d20, d24, d10 \n" + + "vld1.32 {d28}, [%[wh]] \n" + + // weights r3 + "vmla.f32 d18, d26, d10 \n" + "vmla.f32 d20, d26, d12 \n" + + // load bias + "vld1.32 {d30-d31}, [%[bias]] \n" + + // weights r4 + "vmla.f32 d18, d28, d12 \n" + "vmla.f32 d20, d28, d14 \n" + "vpadd.f32 d22, d18, d20 \n" + + //! 
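 "out zero" above is the outermost left-border output: its window
+      // overlaps the image by only two columns, so each weight row was
+      // loaded as a two-float d register holding (w[i][3], w[i][4]) and
+      // multiplied against the first two input columns before vpadd
+      // folded the partial sums.
+      //! 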
out one + "mov r1, #0 \n" + "sub %[wh], #84 \n" + "vld1.32 {d0-d1}, [%[wh]], r0 \n" + "vld1.32 {d2-d3}, [%[wh]], r0 \n" + + "vmov.32 d1[1], r1 \n" + "vmov.32 d3[1], r1 \n" + + // weights r0 + "vmul.f32 q9, q0, q2 \n" + "vmul.f32 q10, q0, q3 \n" + + "vld1.32 {d24-d25}, [%[wh]], r0 \n" + "vmov.32 d25[1], r1 \n" + + // weights r1 + "vmla.f32 q9, q1, q3 \n" + "vmla.f32 q10, q1, q4 \n" + + "vld1.32 {d26-d27}, [%[wh]], r0 \n" + "vmov.32 d27[1], r1 \n" + + // weights r2 + "vmla.f32 q9, q12, q4 \n" + "vmla.f32 q10, q12, q5 \n" + + "vld1.32 {d28-d29}, [%[wh]]\n" + "vmov.32 d29[1], r1 \n" + + // weights r3 + "vmla.f32 q9, q13, q5 \n" + "vmla.f32 q10, q13, q6 \n" + + // weights r4 + "vmla.f32 q9, q14, q6 \n" + "vmla.f32 q10, q14, q7 \n" + + "sub %[wh], #84 \n" + "vld1.32 {d0-d1}, [%[wh]], r0 \n" + "vld1.32 {d2-d3}, [%[wh]], r0 \n" + "vld1.32 {d24-d25}, [%[wh]], r0 \n" + + "vpadd.f32 d18, d18, d19 \n" + "vpadd.f32 d19, d20, d21 \n" + + "vld1.32 {d26-d27}, [%[wh]], r0 \n" + "vld1.32 {d28-d29}, [%[wh]]\n" + + "vpadd.f32 d23, d18, d19 \n" + + // trn out neon register + "vtrn.32 d22, d23 \n" + + // add bias + "vadd.f32 q11, q11, q15 \n" + + // store result + "vst1.32 {d22}, [%[dout0]]! \n" + "vst1.32 {d23}, [%[dout1]]! \n" + + //! out two + // weights r0 + "vmul.f32 q9, q0, q2 \n" + "vmul.f32 q10, q0, q3 \n" + + // weights r1 + "vmla.f32 q9, q1, q3 \n" + "vmla.f32 q10, q1, q4 \n" + + // weights r2 + "vmla.f32 q9, q12, q4 \n" + "vmla.f32 q10, q12, q5 \n" + + // weights r3 + "vmla.f32 q9, q13, q5 \n" + "vmla.f32 q10, q13, q6 \n" + + // weights r4 + "vmla.f32 q9, q14, q6 \n" + "vmla.f32 q10, q14, q7 \n" + + "vpadd.f32 d18, d18, d19 \n" + "vpadd.f32 d19, d20, d21 \n" + "vpadd.f32 d18, d18, d19 \n" + + // add bias + "vadd.f32 d18, d18, d30 \n" + + // store result + "vst1.32 {d18[0]}, [%[dout0]] \n" + "vst1.32 {d18[1]}, [%[dout1]] \n" + + : [dout0] "+r"(dout0), [dout1] "+r"(dout1), [wh] "+r"(weights) + : [din0] "r"(din0), + [din1] "r"(din1), + [din2] "r"(din2), + [din3] "r"(din3), + [din4] "r"(din4), + [din5] "r"(din5), + [bias] "r"(bias) + : "memory", + "r0", + "r1", + "q0", + "q1", + "q2", + "q3", + "q4", + "q5", + "q6", + "q7", + "q8", + "q9", + "q10", + "q11", + "q12", + "q13", + "q14", + "q15"); +} + +//! kernel for three out with extracting data pre +//! deal with two lines out +void compute_three_out_extract_pre_relu(const float* din0, + const float* din1, + const float* din2, + const float* din3, + const float* din4, + const float* din5, + float* dout0, + float* dout1, + const float* weights, + const float* bias) { + asm volatile( + "mov r0, #20 \n" + "add %[wh], #12 \n" + "vld1.32 {d0}, [%[wh]], r0 \n" + "vld1.32 {d2}, [%[wh]], r0 \n" + + "vld1.32 {d4-d5}, [%[din0]] \n" + "vld1.32 {d6-d7}, [%[din1]] \n" + "vld1.32 {d8-d9}, [%[din2]] \n" + "vld1.32 {d10-d11}, [%[din3]] \n" + "vld1.32 {d12-d13}, [%[din4]] \n" + "vld1.32 {d14-d15}, [%[din5]] \n" + + //! out zero + // weights r0 + "vmul.f32 d18, d0, d4 \n" + "vmul.f32 d20, d0, d6 \n" + + "vld1.32 {d24}, [%[wh]], r0 \n" + + // weights r1 + "vmla.f32 d18, d2, d6 \n" + "vmla.f32 d20, d2, d8 \n" + + "vld1.32 {d26}, [%[wh]], r0 \n" + + // weights r2 + "vmla.f32 d18, d24, d8 \n" + "vmla.f32 d20, d24, d10 \n" + + "vld1.32 {d28}, [%[wh]] \n" + + // weights r3 + "vmla.f32 d18, d26, d10 \n" + "vmla.f32 d20, d26, d12 \n" + + // load bias + "vld1.32 {d30-d31}, [%[bias]] \n" + + // weights r4 + "vmla.f32 d18, d28, d12 \n" + "vmla.f32 d20, d28, d14 \n" + "vpadd.f32 d22, d18, d20 \n" + + //! 
out one + "mov r1, #0 \n" + "sub %[wh], #84 \n" + "vld1.32 {d0-d1}, [%[wh]], r0 \n" + "vld1.32 {d2-d3}, [%[wh]], r0 \n" + + "vmov.32 d1[1], r1 \n" + "vmov.32 d3[1], r1 \n" + + // weights r0 + "vmul.f32 q9, q0, q2 \n" + "vmul.f32 q10, q0, q3 \n" + + "vld1.32 {d24-d25}, [%[wh]], r0 \n" + "vmov.32 d25[1], r1 \n" + + // weights r1 + "vmla.f32 q9, q1, q3 \n" + "vmla.f32 q10, q1, q4 \n" + + "vld1.32 {d26-d27}, [%[wh]], r0 \n" + "vmov.32 d27[1], r1 \n" + + // weights r2 + "vmla.f32 q9, q12, q4 \n" + "vmla.f32 q10, q12, q5 \n" + + "vld1.32 {d28-d29}, [%[wh]]\n" + "vmov.32 d29[1], r1 \n" + + // weights r3 + "vmla.f32 q9, q13, q5 \n" + "vmla.f32 q10, q13, q6 \n" + + // weights r4 + "vmla.f32 q9, q14, q6 \n" + "vmla.f32 q10, q14, q7 \n" + + "sub %[wh], #84 \n" + "vld1.32 {d0-d1}, [%[wh]], r0 \n" + "vld1.32 {d2-d3}, [%[wh]], r0 \n" + "vld1.32 {d24-d25}, [%[wh]], r0 \n" + + "vpadd.f32 d18, d18, d19 \n" + "vpadd.f32 d19, d20, d21 \n" + + "vld1.32 {d26-d27}, [%[wh]], r0 \n" + "vld1.32 {d28-d29}, [%[wh]]\n" + + "vpadd.f32 d23, d18, d19 \n" + "vmov.i32 q8, #0x0 \n" + + // trn out neon register + "vtrn.32 d22, d23 \n" + + // add bias + "vadd.f32 q11, q11, q15 \n" + + // relu + "vmax.f32 q11, q11, q8 \n" + + // store result + "vst1.32 {d22}, [%[dout0]]! \n" + "vst1.32 {d23}, [%[dout1]]! \n" + + //! out two + // weights r0 + "vmul.f32 q9, q0, q2 \n" + "vmul.f32 q10, q0, q3 \n" + + // weights r1 + "vmla.f32 q9, q1, q3 \n" + "vmla.f32 q10, q1, q4 \n" + + // weights r2 + "vmla.f32 q9, q12, q4 \n" + "vmla.f32 q10, q12, q5 \n" + + // weights r3 + "vmla.f32 q9, q13, q5 \n" + "vmla.f32 q10, q13, q6 \n" + + // weights r4 + "vmla.f32 q9, q14, q6 \n" + "vmla.f32 q10, q14, q7 \n" + + "vpadd.f32 d18, d18, d19 \n" + "vpadd.f32 d19, d20, d21 \n" + "vpadd.f32 d18, d18, d19 \n" + + // add bias + "vadd.f32 d18, d18, d30 \n" + + // relu + "vmax.f32 d18, d18, d16 \n" + + // store result + "vst1.32 {d18[0]}, [%[dout0]] \n" + "vst1.32 {d18[1]}, [%[dout1]] \n" + + : [dout0] "+r"(dout0), [dout1] "+r"(dout1), [wh] "+r"(weights) + : [din0] "r"(din0), + [din1] "r"(din1), + [din2] "r"(din2), + [din3] "r"(din3), + [din4] "r"(din4), + [din5] "r"(din5), + [bias] "r"(bias) + : "memory", + "r0", + "r1", + "q0", + "q1", + "q2", + "q3", + "q4", + "q5", + "q6", + "q7", + "q8", + "q9", + "q10", + "q11", + "q12", + "q13", + "q14", + "q15"); +} + +//! kernel for three out with extracting data post +//! deal with two lines out +void compute_three_out_extract_post(const float* din0, + const float* din1, + const float* din2, + const float* din3, + const float* din4, + const float* din5, + float* dout0, + float* dout1, + const float* weights, + const float* bias) { + asm volatile( + "mov r0, #20 \n" + "vld1.32 {d0-d1}, [%[wh]], r0 \n" + "vld1.32 {d2-d3}, [%[wh]], r0 \n" + + "vld1.32 {d4-d5}, [%[din0]] \n" + "vld1.32 {d6-d7}, [%[din1]] \n" + "vld1.32 {d8-d9}, [%[din2]] \n" + "vld1.32 {d10-d11}, [%[din3]] \n" + "vld1.32 {d12-d13}, [%[din4]] \n" + "vld1.32 {d14-d15}, [%[din5]] \n" + + //! 
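 the block below folds two of the three border outputs into one pass:
+      // q9/q10 accumulate the 4-wide product for output 0 while d16/d17
+      // accumulate a 2-wide product against the inputs shifted two columns
+      // (d5, d7, ...) for output 2; output 1 then comes from a second pass
+      // over the vext-shifted inputs, as in the two-out kernel above.
+      //! 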
out zero && two + + // weights r0 + "vmul.f32 q9, q0, q2 \n" + "vmul.f32 q10, q0, q3 \n" + "vmul.f32 d16, d0, d5 \n" + "vmul.f32 d17, d0, d7 \n" + + "vld1.32 {d24-d25}, [%[wh]], r0 \n" + + // weights r1 + "vmla.f32 q9, q1, q3 \n" + "vmla.f32 q10, q1, q4 \n" + "vmla.f32 d16, d2, d7 \n" + "vmla.f32 d17, d2, d9 \n" + + "vld1.32 {d26-d27}, [%[wh]], r0 \n" + + // weights r2 + "vmla.f32 q9, q12, q4 \n" + "vmla.f32 q10, q12, q5 \n" + "vmla.f32 d16, d24, d9 \n" + "vmla.f32 d17, d24, d11 \n" + + "vld1.32 {d28-d29}, [%[wh]]\n" + + // weights r3 + "vmla.f32 q9, q13, q5 \n" + "vmla.f32 q10, q13, q6 \n" + "vmla.f32 d16, d26, d11 \n" + "vmla.f32 d17, d26, d13 \n" + + // weights r4 + "vmla.f32 q9, q14, q6 \n" + "vmla.f32 q10, q14, q7 \n" + "vmla.f32 d16, d28, d13 \n" + "vmla.f32 d17, d28, d15 \n" + + "vpadd.f32 d18, d18, d19 \n" + "vpadd.f32 d19, d20, d21 \n" + "vpadd.f32 d16, d16, d17 \n" + "vpadd.f32 d22, d18, d19 \n" + + "vmov.f32 q15, #0.0 \n" + "vext.32 q2, q2, q15, #1 \n" + "vext.32 q3, q3, q15, #1 \n" + "vext.32 q4, q4, q15, #1 \n" + "vext.32 q5, q5, q15, #1 \n" + "vext.32 q6, q6, q15, #1 \n" + "vext.32 q7, q7, q15, #1 \n" + + //! out one + // weights r0 + "vmul.f32 q9, q0, q2 \n" + "vmul.f32 q10, q0, q3 \n" + + // weights r1 + "vmla.f32 q9, q1, q3 \n" + "vmla.f32 q10, q1, q4 \n" + + // weights r2 + "vmla.f32 q9, q12, q4 \n" + "vmla.f32 q10, q12, q5 \n" + + // weights r3 + "vmla.f32 q9, q13, q5 \n" + "vmla.f32 q10, q13, q6 \n" + + // weights r4 + "vmla.f32 q9, q14, q6 \n" + "vmla.f32 q10, q14, q7 \n" + + // load bias + "vld1.32 {d30-d31}, [%[bias]] \n" + + "vpadd.f32 d18, d18, d19 \n" + "vpadd.f32 d19, d20, d21 \n" + "vpadd.f32 d23, d18, d19 \n" + "vmov.i32 q9, #0x0 \n" + + // trn out neon register + "vtrn.32 d22, d23 \n" + + // add bias + "vadd.f32 q11, q11, q15 \n" + "vadd.f32 d16, d16, d30 \n" + + "vst1.32 {d22}, [%[dout0]]! \n" + "vst1.32 {d23}, [%[dout1]]! \n" + "vst1.32 {d16[0]}, [%[dout0]]! \n" + "vst1.32 {d16[1]}, [%[dout1]]! \n" + + : [dout0] "+r"(dout0), [dout1] "+r"(dout1), [wh] "+r"(weights) + : [din0] "r"(din0), + [din1] "r"(din1), + [din2] "r"(din2), + [din3] "r"(din3), + [din4] "r"(din4), + [din5] "r"(din5), + [bias] "r"(bias) + : "memory", + "r0", + "q0", + "q1", + "q2", + "q3", + "q4", + "q5", + "q6", + "q7", + "q8", + "q9", + "q10", + "q11", + "q12", + "q13", + "q14", + "q15"); +} + +//! kernel for three out with extracting data post +//! deal with two lines out +void compute_three_out_extract_post_relu(const float* din0, + const float* din1, + const float* din2, + const float* din3, + const float* din4, + const float* din5, + float* dout0, + float* dout1, + const float* weights, + const float* bias) { + asm volatile( + "mov r0, #20 \n" + "vld1.32 {d0-d1}, [%[wh]], r0 \n" + "vld1.32 {d2-d3}, [%[wh]], r0 \n" + + "vld1.32 {d4-d5}, [%[din0]] \n" + "vld1.32 {d6-d7}, [%[din1]] \n" + "vld1.32 {d8-d9}, [%[din2]] \n" + "vld1.32 {d10-d11}, [%[din3]] \n" + "vld1.32 {d12-d13}, [%[din4]] \n" + "vld1.32 {d14-d15}, [%[din5]] \n" + + //! 
out zero && two + + // weights r0 + "vmul.f32 q9, q0, q2 \n" + "vmul.f32 q10, q0, q3 \n" + "vmul.f32 d16, d0, d5 \n" + "vmul.f32 d17, d0, d7 \n" + + "vld1.32 {d24-d25}, [%[wh]], r0 \n" + + // weights r1 + "vmla.f32 q9, q1, q3 \n" + "vmla.f32 q10, q1, q4 \n" + "vmla.f32 d16, d2, d7 \n" + "vmla.f32 d17, d2, d9 \n" + + "vld1.32 {d26-d27}, [%[wh]], r0 \n" + + // weights r2 + "vmla.f32 q9, q12, q4 \n" + "vmla.f32 q10, q12, q5 \n" + "vmla.f32 d16, d24, d9 \n" + "vmla.f32 d17, d24, d11 \n" + + "vld1.32 {d28-d29}, [%[wh]]\n" + + // weights r3 + "vmla.f32 q9, q13, q5 \n" + "vmla.f32 q10, q13, q6 \n" + "vmla.f32 d16, d26, d11 \n" + "vmla.f32 d17, d26, d13 \n" + + // weights r4 + "vmla.f32 q9, q14, q6 \n" + "vmla.f32 q10, q14, q7 \n" + "vmla.f32 d16, d28, d13 \n" + "vmla.f32 d17, d28, d15 \n" + + "vpadd.f32 d18, d18, d19 \n" + "vpadd.f32 d19, d20, d21 \n" + "vpadd.f32 d16, d16, d17 \n" + "vpadd.f32 d22, d18, d19 \n" + + "vmov.f32 q15, #0.0 \n" + "vext.32 q2, q2, q15, #1 \n" + "vext.32 q3, q3, q15, #1 \n" + "vext.32 q4, q4, q15, #1 \n" + "vext.32 q5, q5, q15, #1 \n" + "vext.32 q6, q6, q15, #1 \n" + "vext.32 q7, q7, q15, #1 \n" + + //! out one + // weights r0 + "vmul.f32 q9, q0, q2 \n" + "vmul.f32 q10, q0, q3 \n" + + // weights r1 + "vmla.f32 q9, q1, q3 \n" + "vmla.f32 q10, q1, q4 \n" + + // weights r2 + "vmla.f32 q9, q12, q4 \n" + "vmla.f32 q10, q12, q5 \n" + + // weights r3 + "vmla.f32 q9, q13, q5 \n" + "vmla.f32 q10, q13, q6 \n" + + // weights r4 + "vmla.f32 q9, q14, q6 \n" + "vmla.f32 q10, q14, q7 \n" + + // load bias + "vld1.32 {d30-d31}, [%[bias]] \n" + + "vpadd.f32 d18, d18, d19 \n" + "vpadd.f32 d19, d20, d21 \n" + "vpadd.f32 d23, d18, d19 \n" + "vmov.i32 q9, #0x0 \n" + + // trn out neon register + "vtrn.32 d22, d23 \n" + + // add bias + "vadd.f32 q11, q11, q15 \n" + "vadd.f32 d16, d16, d30 \n" + + // relu + "vmax.f32 q11, q11, q9 \n" + "vmax.f32 d16, d16, d18 \n" + + "vst1.32 {d22}, [%[dout0]]! \n" + "vst1.32 {d23}, [%[dout1]]! \n" + "vst1.32 {d16[0]}, [%[dout0]]! \n" + "vst1.32 {d16[1]}, [%[dout1]]! \n" + + : [dout0] "+r"(dout0), [dout1] "+r"(dout1), [wh] "+r"(weights) + : [din0] "r"(din0), + [din1] "r"(din1), + [din2] "r"(din2), + [din3] "r"(din3), + [din4] "r"(din4), + [din5] "r"(din5), + [bias] "r"(bias) + : "memory", + "r0", + "q0", + "q1", + "q2", + "q3", + "q4", + "q5", + "q6", + "q7", + "q8", + "q9", + "q10", + "q11", + "q12", + "q13", + "q14", + "q15"); +} + +//! kernel for four out with extracting data pre +//! deal with two lines out +void compute_four_out_extract_pre(const float* din0, + const float* din1, + const float* din2, + const float* din3, + const float* din4, + const float* din5, + float* dout0, + float* dout1, + const float* weights, + const float* bias) { + asm volatile( + "mov r0, #20 \n" + "add %[wh], #16 \n" + + //! 
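 the first block below handles the extreme-left output, which overlaps
+      // the image by a single column: the last weight column (w[i][4],
+      // gathered lane by lane using the 20-byte row stride set up above) is
+      // dotted against input column 0 of rows 0..3 and 1..4, and the corner
+      // weight w[4][4] times rows 4/5 is folded in with a final vmla.
+      //! 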
out zero + // load input + "vld1.32 {d4[0]}, [%[din0]] \n" + "vld1.32 {d4[1]}, [%[din1]] \n" + "vld1.32 {d5[0]}, [%[din2]] \n" + "vld1.32 {d5[1]}, [%[din3]] \n" + "vld1.32 {d6[0]}, [%[din4]] \n" + "vld1.32 {d6[1]}, [%[din5]] \n" + + "vext.32 q4, q2, q3, #1 \n" + + // load weights + "vld1.32 d0[0], [%[wh]], r0 \n" + "vld1.32 d0[1], [%[wh]], r0 \n" + "vld1.32 d1[0], [%[wh]], r0 \n" + "vld1.32 d1[1], [%[wh]], r0 \n" + "vld1.32 d2[0], [%[wh]]\n" + + "vmul.f32 q9, q0, q2 \n" + "vmul.f32 q10, q0, q4 \n" + + "vld1.32 {d30-d31}, [%[bias]] \n" + + "vpadd.f32 d18, d18, d19 \n" + "vpadd.f32 d19, d20, d21 \n" + "vpadd.f32 d22, d18, d19 \n" + + "vmla.f32 d22, d6, d2[0] \n" + + "sub %[wh], #84 \n" + "vld1.32 {d0}, [%[wh]], r0 \n" + "vld1.32 {d2}, [%[wh]], r0 \n" + + "vld1.32 {d4-d5}, [%[din0]] \n" + "vld1.32 {d6-d7}, [%[din1]] \n" + "vld1.32 {d8-d9}, [%[din2]] \n" + "vld1.32 {d10-d11}, [%[din3]] \n" + "vld1.32 {d12-d13}, [%[din4]] \n" + "vld1.32 {d14-d15}, [%[din5]] \n" + + //! out one + // weights r0 + "vmul.f32 d18, d0, d4 \n" + "vmul.f32 d20, d0, d6 \n" + + "vld1.32 {d24}, [%[wh]], r0 \n" + + // weights r1 + "vmla.f32 d18, d2, d6 \n" + "vmla.f32 d20, d2, d8 \n" + + "vld1.32 {d26}, [%[wh]], r0 \n" + + // weights r2 + "vmla.f32 d18, d24, d8 \n" + "vmla.f32 d20, d24, d10 \n" + + "vld1.32 {d28}, [%[wh]] \n" + + // weights r3 + "vmla.f32 d18, d26, d10 \n" + "vmla.f32 d20, d26, d12 \n" + + // weights r4 + "vmla.f32 d18, d28, d12 \n" + "vmla.f32 d20, d28, d14 \n" + + "vpadd.f32 d23, d18, d20 \n" + + // trn out neon register + "vtrn.32 d22, d23 \n" + + // add bias + "vadd.f32 q11, q11, q15 \n" + + // store result + "vst1.32 {d22}, [%[dout0]]! \n" + "vst1.32 {d23}, [%[dout1]]! \n" + + //! out two + "mov r1, #0 \n" + "sub %[wh], #84 \n" + "vld1.32 {d0-d1}, [%[wh]], r0 \n" + "vld1.32 {d2-d3}, [%[wh]], r0 \n" + + "vmov.32 d1[1], r1 \n" + "vmov.32 d3[1], r1 \n" + + // weights r0 + "vmul.f32 q9, q0, q2 \n" + "vmul.f32 q10, q0, q3 \n" + + "vld1.32 {d24-d25}, [%[wh]], r0 \n" + "vmov.32 d25[1], r1 \n" + + // weights r1 + "vmla.f32 q9, q1, q3 \n" + "vmla.f32 q10, q1, q4 \n" + + "vld1.32 {d26-d27}, [%[wh]], r0 \n" + "vmov.32 d27[1], r1 \n" + + // weights r2 + "vmla.f32 q9, q12, q4 \n" + "vmla.f32 q10, q12, q5 \n" + + "vld1.32 {d28-d29}, [%[wh]]\n" + "vmov.32 d29[1], r1 \n" + + // weights r3 + "vmla.f32 q9, q13, q5 \n" + "vmla.f32 q10, q13, q6 \n" + + // weights r4 + "vmla.f32 q9, q14, q6 \n" + "vmla.f32 q10, q14, q7 \n" + + "sub %[wh], #84 \n" + "vld1.32 {d0-d1}, [%[wh]], r0 \n" + "vld1.32 {d2-d3}, [%[wh]], r0 \n" + "vld1.32 {d24-d25}, [%[wh]], r0 \n" + + "vpadd.f32 d18, d18, d19 \n" + "vpadd.f32 d19, d20, d21 \n" + + "vld1.32 {d26-d27}, [%[wh]], r0 \n" + "vld1.32 {d28-d29}, [%[wh]]\n" + + "vpadd.f32 d22, d18, d19 \n" + + //! 
out three + // weights r0 + "vmul.f32 q9, q0, q2 \n" + "vmul.f32 q10, q0, q3 \n" + + // weights r1 + "vmla.f32 q9, q1, q3 \n" + "vmla.f32 q10, q1, q4 \n" + + // weights r2 + "vmla.f32 q9, q12, q4 \n" + "vmla.f32 q10, q12, q5 \n" + + // weights r3 + "vmla.f32 q9, q13, q5 \n" + "vmla.f32 q10, q13, q6 \n" + + // weights r4 + "vmla.f32 q9, q14, q6 \n" + "vmla.f32 q10, q14, q7 \n" + + "vpadd.f32 d18, d18, d19 \n" + "vpadd.f32 d19, d20, d21 \n" + "vpadd.f32 d23, d18, d19 \n" + + // trn out neon register + "vtrn.32 d22, d23 \n" + + // add bias + "vadd.f32 q11, q11, q15 \n" + + // store result + "vst1.32 {d22}, [%[dout0]] \n" + "vst1.32 {d23}, [%[dout1]] \n" + + : [dout0] "+r"(dout0), [dout1] "+r"(dout1), [wh] "+r"(weights) + : [din0] "r"(din0), + [din1] "r"(din1), + [din2] "r"(din2), + [din3] "r"(din3), + [din4] "r"(din4), + [din5] "r"(din5), + [bias] "r"(bias) + : "memory", + "r0", + "r1", + "q0", + "q1", + "q2", + "q3", + "q4", + "q5", + "q6", + "q7", + "q8", + "q9", + "q10", + "q11", + "q12", + "q13", + "q14", + "q15"); +} + +//! kernel for four out with extracting data pre +//! deal with two lines out +void compute_four_out_extract_pre_relu(const float* din0, + const float* din1, + const float* din2, + const float* din3, + const float* din4, + const float* din5, + float* dout0, + float* dout1, + const float* weights, + const float* bias) { + asm volatile( + "mov r0, #20 \n" + "add %[wh], #16 \n" + + //! out zero + // load input + "vld1.32 {d4[0]}, [%[din0]] \n" + "vld1.32 {d4[1]}, [%[din1]] \n" + "vld1.32 {d5[0]}, [%[din2]] \n" + "vld1.32 {d5[1]}, [%[din3]] \n" + "vld1.32 {d6[0]}, [%[din4]] \n" + "vld1.32 {d6[1]}, [%[din5]] \n" + + "vext.32 q4, q2, q3, #1 \n" + + // load weights + "vld1.32 d0[0], [%[wh]], r0 \n" + "vld1.32 d0[1], [%[wh]], r0 \n" + "vld1.32 d1[0], [%[wh]], r0 \n" + "vld1.32 d1[1], [%[wh]], r0 \n" + "vld1.32 d2[0], [%[wh]]\n" + + "vmul.f32 q9, q0, q2 \n" + "vmul.f32 q10, q0, q4 \n" + + "vld1.32 {d30-d31}, [%[bias]] \n" + + "vpadd.f32 d18, d18, d19 \n" + "vpadd.f32 d19, d20, d21 \n" + "vpadd.f32 d22, d18, d19 \n" + + "vmla.f32 d22, d6, d2[0] \n" + + "sub %[wh], #84 \n" + "vld1.32 {d0}, [%[wh]], r0 \n" + "vld1.32 {d2}, [%[wh]], r0 \n" + + "vld1.32 {d4-d5}, [%[din0]] \n" + "vld1.32 {d6-d7}, [%[din1]] \n" + "vld1.32 {d8-d9}, [%[din2]] \n" + "vld1.32 {d10-d11}, [%[din3]] \n" + "vld1.32 {d12-d13}, [%[din4]] \n" + "vld1.32 {d14-d15}, [%[din5]] \n" + + //! out one + // weights r0 + "vmul.f32 d18, d0, d4 \n" + "vmul.f32 d20, d0, d6 \n" + + "vld1.32 {d24}, [%[wh]], r0 \n" + + // weights r1 + "vmla.f32 d18, d2, d6 \n" + "vmla.f32 d20, d2, d8 \n" + + "vld1.32 {d26}, [%[wh]], r0 \n" + + // weights r2 + "vmla.f32 d18, d24, d8 \n" + "vmla.f32 d20, d24, d10 \n" + + "vld1.32 {d28}, [%[wh]] \n" + + // weights r3 + "vmla.f32 d18, d26, d10 \n" + "vmla.f32 d20, d26, d12 \n" + + // weights r4 + "vmla.f32 d18, d28, d12 \n" + "vmla.f32 d20, d28, d14 \n" + + "vpadd.f32 d23, d18, d20 \n" + "vmov.i32 q8, #0x0 \n" + + // trn out neon register + "vtrn.32 d22, d23 \n" + + // add bias + "vadd.f32 q11, q11, q15 \n" + + // relu + "vmax.f32 q11, q11, q8 \n" + + // store result + "vst1.32 {d22}, [%[dout0]]! \n" + "vst1.32 {d23}, [%[dout1]]! \n" + + //! 
out two + "mov r1, #0 \n" + "sub %[wh], #84 \n" + "vld1.32 {d0-d1}, [%[wh]], r0 \n" + "vld1.32 {d2-d3}, [%[wh]], r0 \n" + + "vmov.32 d1[1], r1 \n" + "vmov.32 d3[1], r1 \n" + + // weights r0 + "vmul.f32 q9, q0, q2 \n" + "vmul.f32 q10, q0, q3 \n" + + "vld1.32 {d24-d25}, [%[wh]], r0 \n" + "vmov.32 d25[1], r1 \n" + + // weights r1 + "vmla.f32 q9, q1, q3 \n" + "vmla.f32 q10, q1, q4 \n" + + "vld1.32 {d26-d27}, [%[wh]], r0 \n" + "vmov.32 d27[1], r1 \n" + + // weights r2 + "vmla.f32 q9, q12, q4 \n" + "vmla.f32 q10, q12, q5 \n" + + "vld1.32 {d28-d29}, [%[wh]]\n" + "vmov.32 d29[1], r1 \n" + + // weights r3 + "vmla.f32 q9, q13, q5 \n" + "vmla.f32 q10, q13, q6 \n" + + // weights r4 + "vmla.f32 q9, q14, q6 \n" + "vmla.f32 q10, q14, q7 \n" + + "sub %[wh], #84 \n" + "vld1.32 {d0-d1}, [%[wh]], r0 \n" + "vld1.32 {d2-d3}, [%[wh]], r0 \n" + "vld1.32 {d24-d25}, [%[wh]], r0 \n" + + "vpadd.f32 d18, d18, d19 \n" + "vpadd.f32 d19, d20, d21 \n" + + "vld1.32 {d26-d27}, [%[wh]], r0 \n" + "vld1.32 {d28-d29}, [%[wh]]\n" + + "vpadd.f32 d22, d18, d19 \n" + + //! out three + // weights r0 + "vmul.f32 q9, q0, q2 \n" + "vmul.f32 q10, q0, q3 \n" + + // weights r1 + "vmla.f32 q9, q1, q3 \n" + "vmla.f32 q10, q1, q4 \n" + + // weights r2 + "vmla.f32 q9, q12, q4 \n" + "vmla.f32 q10, q12, q5 \n" + + // weights r3 + "vmla.f32 q9, q13, q5 \n" + "vmla.f32 q10, q13, q6 \n" + + // weights r4 + "vmla.f32 q9, q14, q6 \n" + "vmla.f32 q10, q14, q7 \n" + + "vpadd.f32 d18, d18, d19 \n" + "vpadd.f32 d19, d20, d21 \n" + "vpadd.f32 d23, d18, d19 \n" + + // trn out neon register + "vtrn.32 d22, d23 \n" + + // add bias + "vadd.f32 q11, q11, q15 \n" + + // relu + "vmax.f32 q11, q11, q8 \n" + + // store result + "vst1.32 {d22}, [%[dout0]] \n" + "vst1.32 {d23}, [%[dout1]] \n" + + : [dout0] "+r"(dout0), [dout1] "+r"(dout1), [wh] "+r"(weights) + : [din0] "r"(din0), + [din1] "r"(din1), + [din2] "r"(din2), + [din3] "r"(din3), + [din4] "r"(din4), + [din5] "r"(din5), + [bias] "r"(bias) + : "memory", + "r0", + "r1", + "q0", + "q1", + "q2", + "q3", + "q4", + "q5", + "q6", + "q7", + "q8", + "q9", + "q10", + "q11", + "q12", + "q13", + "q14", + "q15"); +} + +//! kernel for three out with extracting data post +//! deal with two lines out +void compute_four_out_extract_post(const float* din0, + const float* din1, + const float* din2, + const float* din3, + const float* din4, + const float* din5, + float* dout0, + float* dout1, + const float* weights, + const float* bias) { + asm volatile( + "mov r0, #20 \n" + "mov r1, #12 \n" + "vld1.32 {d0-d1}, [%[wh]], r0 \n" + "vld1.32 {d2-d3}, [%[wh]], r0 \n" + + "vld1.32 {d4-d5}, [%[din0]], r1 \n" + "vld1.32 {d6-d7}, [%[din1]], r1 \n" + "vld1.32 {d8-d9}, [%[din2]], r1 \n" + "vld1.32 {d10-d11}, [%[din3]], r1 \n" + "vld1.32 {d12-d13}, [%[din4]], r1 \n" + "vld1.32 {d14-d15}, [%[din5]], r1 \n" + + //! 
out zero && two + // weights r0 + "vmul.f32 q9, q0, q2 \n" + "vmul.f32 q10, q0, q3 \n" + "vmul.f32 d16, d0, d5 \n" + "vmul.f32 d17, d0, d7 \n" + + "vld1.32 {d24-d25}, [%[wh]], r0 \n" + + // weights r1 + "vmla.f32 q9, q1, q3 \n" + "vmla.f32 q10, q1, q4 \n" + "vmla.f32 d16, d2, d7 \n" + "vmla.f32 d17, d2, d9 \n" + + "vld1.32 {d26-d27}, [%[wh]], r0 \n" + + // weights r2 + "vmla.f32 q9, q12, q4 \n" + "vmla.f32 q10, q12, q5 \n" + "vmla.f32 d16, d24, d9 \n" + "vmla.f32 d17, d24, d11 \n" + + "vld1.32 {d28-d29}, [%[wh]] \n" + + // weights r3 + "vmla.f32 q9, q13, q5 \n" + "vmla.f32 q10, q13, q6 \n" + "vmla.f32 d16, d26, d11 \n" + "vmla.f32 d17, d26, d13 \n" + + // weights r4 + "vmla.f32 q9, q14, q6 \n" + "vmla.f32 q10, q14, q7 \n" + "vmla.f32 d16, d28, d13 \n" + "vmla.f32 d17, d28, d15 \n" + + "vpadd.f32 d18, d18, d19 \n" + "vpadd.f32 d19, d20, d21 \n" + "vpadd.f32 d16, d16, d17 \n" + "vpadd.f32 d22, d18, d19 \n" + + //! out one + "vmov.f32 q15, #0.0 \n" + "vext.32 q2, q2, q15, #1 \n" + "vext.32 q3, q3, q15, #1 \n" + "vext.32 q4, q4, q15, #1 \n" + "vext.32 q5, q5, q15, #1 \n" + "vext.32 q6, q6, q15, #1 \n" + "vext.32 q7, q7, q15, #1 \n" + + // weights r0 + "vmul.f32 q9, q0, q2 \n" + "vmul.f32 q10, q0, q3 \n" + + // weights r1 + "vmla.f32 q9, q1, q3 \n" + "vmla.f32 q10, q1, q4 \n" + + // weights r2 + "vmla.f32 q9, q12, q4 \n" + "vmla.f32 q10, q12, q5 \n" + + // weights r3 + "vmla.f32 q9, q13, q5 \n" + "vmla.f32 q10, q13, q6 \n" + + // weights r4 + "vmla.f32 q9, q14, q6 \n" + "vmla.f32 q10, q14, q7 \n" + + "vld1.32 {d30-d31}, [%[bias]] \n" + + "vpadd.f32 d18, d18, d19 \n" + "vpadd.f32 d19, d20, d21 \n" + "vpadd.f32 d23, d18, d19 \n" + + // trn out neon register + "vtrn.32 d22, d23 \n" + + // add bias + "vadd.f32 q11, q11, q15 \n" + + // store result + "vst1.32 {d22}, [%[dout0]]! \n" + "vst1.32 {d23}, [%[dout1]]! \n" + + //! out three + "sub %[wh], #80 \n" + "vld1.32 {d4[0]}, [%[din0]] \n" + "vld1.32 {d4[1]}, [%[din1]] \n" + "vld1.32 {d5[0]}, [%[din2]] \n" + "vld1.32 {d5[1]}, [%[din3]] \n" + "vld1.32 {d6[0]}, [%[din4]] \n" + "vld1.32 {d6[1]}, [%[din5]] \n" + + "vext.32 q4, q2, q3, #1 \n" + + "vld1.32 {d0[0]}, [%[wh]], r0 \n" + "vld1.32 {d0[1]}, [%[wh]], r0 \n" + "vld1.32 {d1[0]}, [%[wh]], r0 \n" + "vld1.32 {d1[1]}, [%[wh]], r0 \n" + "vld1.32 {d2[0]}, [%[wh]] \n" + + "vmul.f32 q9, q0, q2 \n" + "vmul.f32 q10, q0, q4 \n" + + "vpadd.f32 d18, d18, d19 \n" + "vpadd.f32 d20, d20, d21 \n" + "vpadd.f32 d17, d18, d20 \n" + + "vmla.f32 d17, d6, d2[0] \n" + + // trn out neon register + "vtrn.32 d16, d17 \n" + + // add bias + "vadd.f32 q8, q8, q15 \n" + + // store result + "vst1.32 {d16}, [%[dout0]] \n" + "vst1.32 {d17}, [%[dout1]] \n" + + : [dout0] "+r"(dout0), + [dout1] "+r"(dout1), + [din0] "+r"(din0), + [din1] "+r"(din1), + [din2] "+r"(din2), + [din3] "+r"(din3), + [din4] "+r"(din4), + [din5] "+r"(din5), + [wh] "+r"(weights) + : [bias] "r"(bias) + : "memory", + "r0", + "r1", + "q0", + "q1", + "q2", + "q3", + "q4", + "q5", + "q6", + "q7", + "q8", + "q9", + "q10", + "q11", + "q12", + "q13", + "q14", + "q15"); +} + +//! kernel for three out with extracting data post +//! 
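 deal with two lines out
+
+//! NOTE: like compute_four_out_extract_post above, the kernel below emits
+//! the last four right-border outputs of each row pair; it differs only in
+//! clamping the results against a zero register (vmax) before the stores,
+//! i.e. out = max(acc + bias, 0.f) -- the fused-ReLU pattern shared by all
+//! _relu variants in this file.
+
+//! kernel for four out with extracting data post
+//! 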
deal with two lines out +void compute_four_out_extract_post_relu(const float* din0, + const float* din1, + const float* din2, + const float* din3, + const float* din4, + const float* din5, + float* dout0, + float* dout1, + const float* weights, + const float* bias) { + asm volatile( + "mov r0, #20 \n" + "mov r1, #12 \n" + "vld1.32 {d0-d1}, [%[wh]], r0 \n" + "vld1.32 {d2-d3}, [%[wh]], r0 \n" + + "vld1.32 {d4-d5}, [%[din0]], r1 \n" + "vld1.32 {d6-d7}, [%[din1]], r1 \n" + "vld1.32 {d8-d9}, [%[din2]], r1 \n" + "vld1.32 {d10-d11}, [%[din3]], r1 \n" + "vld1.32 {d12-d13}, [%[din4]], r1 \n" + "vld1.32 {d14-d15}, [%[din5]], r1 \n" + + //! out zero && two + // weights r0 + "vmul.f32 q9, q0, q2 \n" + "vmul.f32 q10, q0, q3 \n" + "vmul.f32 d16, d0, d5 \n" + "vmul.f32 d17, d0, d7 \n" + + "vld1.32 {d24-d25}, [%[wh]], r0 \n" + + // weights r1 + "vmla.f32 q9, q1, q3 \n" + "vmla.f32 q10, q1, q4 \n" + "vmla.f32 d16, d2, d7 \n" + "vmla.f32 d17, d2, d9 \n" + + "vld1.32 {d26-d27}, [%[wh]], r0 \n" + + // weights r2 + "vmla.f32 q9, q12, q4 \n" + "vmla.f32 q10, q12, q5 \n" + "vmla.f32 d16, d24, d9 \n" + "vmla.f32 d17, d24, d11 \n" + + "vld1.32 {d28-d29}, [%[wh]] \n" + + // weights r3 + "vmla.f32 q9, q13, q5 \n" + "vmla.f32 q10, q13, q6 \n" + "vmla.f32 d16, d26, d11 \n" + "vmla.f32 d17, d26, d13 \n" + + // weights r4 + "vmla.f32 q9, q14, q6 \n" + "vmla.f32 q10, q14, q7 \n" + "vmla.f32 d16, d28, d13 \n" + "vmla.f32 d17, d28, d15 \n" + + "vpadd.f32 d18, d18, d19 \n" + "vpadd.f32 d19, d20, d21 \n" + "vpadd.f32 d16, d16, d17 \n" + "vpadd.f32 d22, d18, d19 \n" + + //! out one + "vmov.f32 q15, #0.0 \n" + "vext.32 q2, q2, q15, #1 \n" + "vext.32 q3, q3, q15, #1 \n" + "vext.32 q4, q4, q15, #1 \n" + "vext.32 q5, q5, q15, #1 \n" + "vext.32 q6, q6, q15, #1 \n" + "vext.32 q7, q7, q15, #1 \n" + + // weights r0 + "vmul.f32 q9, q0, q2 \n" + "vmul.f32 q10, q0, q3 \n" + + // weights r1 + "vmla.f32 q9, q1, q3 \n" + "vmla.f32 q10, q1, q4 \n" + + // weights r2 + "vmla.f32 q9, q12, q4 \n" + "vmla.f32 q10, q12, q5 \n" + + // weights r3 + "vmla.f32 q9, q13, q5 \n" + "vmla.f32 q10, q13, q6 \n" + + // weights r4 + "vmla.f32 q9, q14, q6 \n" + "vmla.f32 q10, q14, q7 \n" + + "vld1.32 {d30-d31}, [%[bias]] \n" + + "vpadd.f32 d18, d18, d19 \n" + "vpadd.f32 d19, d20, d21 \n" + "vpadd.f32 d23, d18, d19 \n" + "vmov.i32 q5, #0x0 \n" + + // trn out neon register + "vtrn.32 d22, d23 \n" + + // add bias + "vadd.f32 q11, q11, q15 \n" + + // relu + "vmax.f32 q11, q11, q5 \n" + + // store result + "vst1.32 {d22}, [%[dout0]]! \n" + "vst1.32 {d23}, [%[dout1]]! \n" + + //! 
out three + "sub %[wh], #80 \n" + "vld1.32 {d4[0]}, [%[din0]] \n" + "vld1.32 {d4[1]}, [%[din1]] \n" + "vld1.32 {d5[0]}, [%[din2]] \n" + "vld1.32 {d5[1]}, [%[din3]] \n" + "vld1.32 {d6[0]}, [%[din4]] \n" + "vld1.32 {d6[1]}, [%[din5]] \n" + + "vext.32 q4, q2, q3, #1 \n" + + "vld1.32 {d0[0]}, [%[wh]], r0 \n" + "vld1.32 {d0[1]}, [%[wh]], r0 \n" + "vld1.32 {d1[0]}, [%[wh]], r0 \n" + "vld1.32 {d1[1]}, [%[wh]], r0 \n" + "vld1.32 {d2[0]}, [%[wh]] \n" + + "vmul.f32 q9, q0, q2 \n" + "vmul.f32 q10, q0, q4 \n" + + "vpadd.f32 d18, d18, d19 \n" + "vpadd.f32 d20, d20, d21 \n" + "vpadd.f32 d17, d18, d20 \n" + + "vmla.f32 d17, d6, d2[0] \n" + + // trn out neon register + "vtrn.32 d16, d17 \n" + + // add bias + "vadd.f32 q8, q8, q15 \n" + + // relu + "vmax.f32 q8, q8, q5 \n" + + // store result + "vst1.32 {d16}, [%[dout0]] \n" + "vst1.32 {d17}, [%[dout1]] \n" + + : [dout0] "+r"(dout0), + [dout1] "+r"(dout1), + [din0] "+r"(din0), + [din1] "+r"(din1), + [din2] "+r"(din2), + [din3] "+r"(din3), + [din4] "+r"(din4), + [din5] "+r"(din5), + [wh] "+r"(weights) + : [bias] "r"(bias) + : "memory", + "r0", + "r1", + "q0", + "q1", + "q2", + "q3", + "q4", + "q5", + "q6", + "q7", + "q8", + "q9", + "q10", + "q11", + "q12", + "q13", + "q14", + "q15"); +} + +void conv_depthwise_5x5s1_impl(const float* din, + float* dout, + int num, + int ch_out, + int h_out, + int w_out, + int ch_in, + int h_in, + int w_in, + const float* weights, + const float* bias, + int pad, + bool flag_bias, + bool flag_relu, + ARMContext* ctx) { + float* zero_ptr = ctx->workspace_data(); + memset(zero_ptr, 0, w_in * sizeof(float)); + float* write_ptr = zero_ptr + w_in; + int pad_new = pad > 4 ? 4 : pad; + int pad_0 = pad - pad_new; + int h_out_new = h_out - 2 * pad_0; + int mid_out = w_out - 2 * pad; + int mid_cnt = mid_out >> 2; + int mid_remain = mid_out - (mid_cnt << 2); + int pad_cnt = pad_0 >> 2; + int pad_remain = pad_0 - (pad_cnt << 2); + int bias_cnt = (w_out * pad_0) >> 2; + int bias_remain = (w_out * pad_0) - (bias_cnt << 2); + int in_spatial_size = w_in * h_in; + int out_spatial_size = w_out * h_out; + int weights_saptial_size = 25; + + for (int n = 0; n < num; ++n) { + const float* din_batch = din + n * in_spatial_size * ch_in; + float* dout_batch = dout + n * out_spatial_size * ch_out; +#pragma omp parallel for + for (int c = 0; c < ch_in; ++c) { + const float* din_ch = din_batch + c * in_spatial_size; + float* dout_ch = dout_batch + c * out_spatial_size; + float bias_c = flag_bias ? bias[c] : 0.f; + float vbias[4] = {bias_c, bias_c, bias_c, bias_c}; + float32x4_t vbias_c = vdupq_n_f32(bias_c); + if (flag_bias) { + //! deal with h_out pad_0 line with bias + for (int i = 0; i < bias_cnt; ++i) { + vst1q_f32(dout_ch, vbias_c); + dout_ch += 4; + } + for (int i = 0; i < bias_remain; ++i) { + *dout_ch++ = bias_c; + } + } else { + //! deal with h_out pad_0 line without bias + for (int i = 0; i < pad_0; ++i) { + memset(dout_ch, 0x00, w_out * sizeof(float)); + dout_ch += w_out; + } + } + const float* din_list[6]; + //! set din ptr with zero buffer + for (int i = 0; i < pad_new; ++i) { + din_list[i] = zero_ptr; + } + //! set din ptr with input data + for (int i = pad_new; i < 6; ++i) { + din_list[i] = din_ch; + din_ch += w_in; + } + //! every h loop, deal with 6 line input + const float* din0 = din_list[0]; + const float* din1 = din_list[1]; + const float* din2 = din_list[2]; + const float* din3 = din_list[3]; + const float* din4 = din_list[4]; + const float* din5 = din_list[5]; + + //! 
every h loop, deal with 2 line output + float* dout0 = dout_ch; + float* dout1 = dout0 + w_out; + + //! load weights to neon register + const float* weights_c = weights + c * weights_saptial_size; + + //! h loop + for (int h = 0; h < h_out_new; h += 2) { + //! (h - pad_new) + 7 > h_in - 1 + if (h + 6 - pad_new > h_in) { + switch (h + 6 - pad_new - h_in) { + case 5: + din1 = zero_ptr; + case 4: + din2 = zero_ptr; + case 3: + din3 = zero_ptr; + case 2: + din4 = zero_ptr; + case 1: + din5 = zero_ptr; + default: + break; + } + } + if (h + 2 > h_out_new) { + dout1 = write_ptr; + } + const float* din_ptr0 = din0; + const float* din_ptr1 = din1; + const float* din_ptr2 = din2; + const float* din_ptr3 = din3; + const float* din_ptr4 = din4; + const float* din_ptr5 = din5; + + float* dout_ptr0 = dout0; + float* dout_ptr1 = dout1; + if (flag_bias) { + //! deal with w_out pad_0 column pre with bias + for (int i = 0; i < pad_cnt; i++) { + vst1q_f32(dout_ptr0, vbias_c); + vst1q_f32(dout_ptr1, vbias_c); + dout_ptr0 += 4; + dout_ptr1 += 4; + } + for (int i = 0; i < pad_remain; ++i) { + *dout_ptr0++ = bias_c; + *dout_ptr1++ = bias_c; + } + } else { + //! deal with w_out pad_0 column pre without bias + memset(dout_ptr0, 0x00, pad_0 * sizeof(float)); + memset(dout_ptr1, 0x00, pad_0 * sizeof(float)); + dout_ptr0 += pad_0; + dout_ptr1 += pad_0; + } + + //! deal with w_out pad_new column pre + switch (pad_new) { + case 4: + compute_four_out_extract_pre(din_ptr0, + din_ptr1, + din_ptr2, + din_ptr3, + din_ptr4, + din_ptr5, + dout_ptr0, + dout_ptr1, + weights_c, + vbias); + dout_ptr0 += 4; + dout_ptr1 += 4; + break; + case 3: + compute_three_out_extract_pre(din_ptr0, + din_ptr1, + din_ptr2, + din_ptr3, + din_ptr4, + din_ptr5, + dout_ptr0, + dout_ptr1, + weights_c, + vbias); + dout_ptr0 += 3; + dout_ptr1 += 3; + break; + case 2: + compute_two_out_extract_pre(din_ptr0, + din_ptr1, + din_ptr2, + din_ptr3, + din_ptr4, + din_ptr5, + dout_ptr0, + dout_ptr1, + weights_c, + vbias); + dout_ptr0 += 2; + dout_ptr1 += 2; + break; + case 1: + compute_one_out_extract_pre(din_ptr0, + din_ptr1, + din_ptr2, + din_ptr3, + din_ptr4, + din_ptr5, + dout_ptr0, + dout_ptr1, + weights_c, + vbias); + dout_ptr0 += 1; + dout_ptr1 += 1; + break; + } + + //! mid loop + if (mid_cnt > 0) { + int mid_loop = mid_cnt; + const float* weights_ptr = weights_c; + asm volatile( + //! din: q7-q12 + //! dout: q13, q14 + "mov r1, #20 \n" + //! load weights + "vld1.32 {d0-d1}, [%[wh]], r1 \n" + "vld1.32 {d2-d3}, [%[wh]], r1 \n" + "vld1.32 {d4-d5}, [%[wh]], r1 \n" + "vld1.32 {d6-d7}, [%[wh]], r1 \n" + "vld1.32 {d8-d9}, [%[wh]] \n" + + "sub %[wh], #64 \n" + "vld1.32 {d10[0]}, [%[wh]], r1 \n" + "vld1.32 {d10[1]}, [%[wh]], r1 \n" + "vld1.32 {d11[0]}, [%[wh]], r1 \n" + "vld1.32 {d11[1]}, [%[wh]], r1 \n" + "vld1.32 {d12[0]}, [%[wh]] \n" + + //! load input + "mov r1, #4 \n" + "vld1.32 {d14-d15}, [%[din0]], r1 \n" + "vld1.32 {d16-d17}, [%[din1]], r1 \n" + "vld1.32 {d18-d19}, [%[din2]], r1 \n" + "vld1.32 {d20-d21}, [%[din3]], r1 \n" + "vld1.32 {d22-d23}, [%[din4]], r1 \n" + "vld1.32 {d24-d25}, [%[din5]], r1 \n" + + //! load bias + "vld1.32 {d30-d31}, [%[bias]] \n" + + "1: \n" + //! 
add bias to output + "vmov.32 q13, q15 \n" + "vmov.32 q14, q15 \n" + + "pld [%[din0]] \n" + "pld [%[din1]] \n" + "pld [%[din2]] \n" + "pld [%[din3]] \n" + "pld [%[din4]] \n" + "pld [%[din5]] \n" + + // weights col 0 + "vmla.f32 q13, q7, d0[0] \n" + "vmla.f32 q14, q8, d0[0] \n" + + "vmla.f32 q13, q8, d2[0] \n" + "vmla.f32 q14, q9, d2[0] \n" + + "vld1.32 {d14-d15}, [%[din0]], r1 \n" + "vld1.32 {d16-d17}, [%[din1]], r1 \n" + + "vmla.f32 q13, q9, d4[0] \n" + "vmla.f32 q14, q10, d4[0] \n" + + "vmla.f32 q13, q10, d6[0] \n" + "vmla.f32 q14, q11, d6[0] \n" + + "vld1.32 {d18-d19}, [%[din2]], r1 \n" + "vld1.32 {d20-d21}, [%[din3]], r1 \n" + + "vmla.f32 q13, q11, d8[0] \n" + "vmla.f32 q14, q12, d8[0] \n" + + "vld1.32 {d22-d23}, [%[din4]], r1 \n" + "vld1.32 {d24-d25}, [%[din5]], r1 \n" + + // weights col 1 + "vmla.f32 q13, q7, d0[1] \n" + "vmla.f32 q14, q8, d0[1] \n" + + "vmla.f32 q13, q8, d2[1] \n" + "vmla.f32 q14, q9, d2[1] \n" + + "vld1.32 {d14-d15}, [%[din0]], r1 \n" + "vld1.32 {d16-d17}, [%[din1]], r1 \n" + + "vmla.f32 q13, q9, d4[1] \n" + "vmla.f32 q14, q10, d4[1] \n" + + "vmla.f32 q13, q10, d6[1] \n" + "vmla.f32 q14, q11, d6[1] \n" + + "vld1.32 {d18-d19}, [%[din2]], r1 \n" + "vld1.32 {d20-d21}, [%[din3]], r1 \n" + + "vmla.f32 q13, q11, d8[1] \n" + "vmla.f32 q14, q12, d8[1] \n" + + "vld1.32 {d22-d23}, [%[din4]], r1 \n" + "vld1.32 {d24-d25}, [%[din5]], r1 \n" + + // weights col 2 + "vmla.f32 q13, q7, d1[0] \n" + "vmla.f32 q14, q8, d1[0] \n" + + "vmla.f32 q13, q8, d3[0] \n" + "vmla.f32 q14, q9, d3[0] \n" + + "vld1.32 {d14-d15}, [%[din0]], r1 \n" + "vld1.32 {d16-d17}, [%[din1]], r1 \n" + + "vmla.f32 q13, q9, d5[0] \n" + "vmla.f32 q14, q10, d5[0] \n" + + "vmla.f32 q13, q10, d7[0] \n" + "vmla.f32 q14, q11, d7[0] \n" + + "vld1.32 {d18-d19}, [%[din2]], r1 \n" + "vld1.32 {d20-d21}, [%[din3]], r1 \n" + + "vmla.f32 q13, q11, d9[0] \n" + "vmla.f32 q14, q12, d9[0] \n" + + "vld1.32 {d22-d23}, [%[din4]], r1 \n" + "vld1.32 {d24-d25}, [%[din5]], r1 \n" + + // weights col 3 + "vmla.f32 q13, q7, d1[1] \n" + "vmla.f32 q14, q8, d1[1] \n" + + "vmla.f32 q13, q8, d3[1] \n" + "vmla.f32 q14, q9, d3[1] \n" + + "vld1.32 {d14-d15}, [%[din0]], r1 \n" + "vld1.32 {d16-d17}, [%[din1]], r1 \n" + + "vmla.f32 q13, q9, d5[1] \n" + "vmla.f32 q14, q10, d5[1] \n" + + "vmla.f32 q13, q10, d7[1] \n" + "vmla.f32 q14, q11, d7[1] \n" + + "vld1.32 {d18-d19}, [%[din2]], r1 \n" + "vld1.32 {d20-d21}, [%[din3]], r1 \n" + + "vmla.f32 q13, q11, d9[1] \n" + "vmla.f32 q14, q12, d9[1] \n" + + "vld1.32 {d22-d23}, [%[din4]], r1 \n" + "vld1.32 {d24-d25}, [%[din5]], r1 \n" + + // weights col 4 + "vmla.f32 q13, q7, d10[0] \n" + "vmla.f32 q14, q8, d10[0] \n" + + "vmla.f32 q13, q8, d10[1] \n" + "vmla.f32 q14, q9, d10[1] \n" + + "vmla.f32 q13, q9, d11[0] \n" + "vmla.f32 q14, q10, d11[0] \n" + + "vmla.f32 q13, q10, d11[1] \n" + "vmla.f32 q14, q11, d11[1] \n" + + "vmla.f32 q13, q11, d12[0] \n" + "vmla.f32 q14, q12, d12[0] \n" + + // store reslult + "vst1.32 {d26-d27}, [%[out0]]! \n" + "vst1.32 {d28-d29}, [%[out1]]! 
\n" + + "subs %[cnt], #1 \n" + "bne 1b \n" + + "sub %[din0], r1 \n" + "sub %[din1], r1 \n" + "sub %[din2], r1 \n" + "sub %[din3], r1 \n" + "sub %[din4], r1 \n" + "sub %[din5], r1 \n" + + : [din0] "+r"(din_ptr0), + [din1] "+r"(din_ptr1), + [din2] "+r"(din_ptr2), + [din3] "+r"(din_ptr3), + [din4] "+r"(din_ptr4), + [din5] "+r"(din_ptr5), + [out0] "+r"(dout_ptr0), + [out1] "+r"(dout_ptr1), + [wh] "+r"(weights_ptr), + [cnt] "+r"(mid_loop) + : [bias] "r"(vbias) + : "cc", + "memory", + "r1", + "q0", + "q1", + "q2", + "q3", + "q4", + "q5", + "q6", + "q7", + "q8", + "q9", + "q10", + "q11", + "q12", + "q13", + "q14", + "q15"); + } + //! deal with mid remain + for (int i = 0; i < mid_remain; ++i) { + compute_one_out_without_extract(din_ptr0, + din_ptr1, + din_ptr2, + din_ptr3, + din_ptr4, + din_ptr5, + dout_ptr0, + dout_ptr1, + weights_c, + vbias); + din_ptr0++; + din_ptr1++; + din_ptr2++; + din_ptr3++; + din_ptr4++; + din_ptr5++; + + dout_ptr0++; + dout_ptr1++; + } + //! deal with w_out pad_new column post + switch (pad_new) { + case 4: + compute_four_out_extract_post(din_ptr0, + din_ptr1, + din_ptr2, + din_ptr3, + din_ptr4, + din_ptr5, + dout_ptr0, + dout_ptr1, + weights_c, + vbias); + dout_ptr0 += 4; + dout_ptr1 += 4; + break; + case 3: + compute_three_out_extract_post(din_ptr0, + din_ptr1, + din_ptr2, + din_ptr3, + din_ptr4, + din_ptr5, + dout_ptr0, + dout_ptr1, + weights_c, + vbias); + dout_ptr0 += 3; + dout_ptr1 += 3; + break; + case 2: + compute_two_out_extract_post(din_ptr0, + din_ptr1, + din_ptr2, + din_ptr3, + din_ptr4, + din_ptr5, + dout_ptr0, + dout_ptr1, + weights_c, + vbias); + dout_ptr0 += 2; + dout_ptr1 += 2; + break; + case 1: + compute_one_out_extract_post(din_ptr0, + din_ptr1, + din_ptr2, + din_ptr3, + din_ptr4, + din_ptr5, + dout_ptr0, + dout_ptr1, + weights_c, + vbias); + dout_ptr0 += 1; + dout_ptr1 += 1; + break; + } + + if (flag_bias) { + //! deal with w_out pad_0 column post with bias + memcpy(dout_ptr0, dout0, pad_0 * sizeof(float)); + memcpy(dout_ptr1, dout1, pad_0 * sizeof(float)); + } else { + //! deal with w_out pad_0 column post without bias + memset(dout_ptr0, 0x00, pad_0 * sizeof(float)); + memset(dout_ptr1, 0x00, pad_0 * sizeof(float)); + } + + din0 = din2; + din1 = din3; + din2 = din4; + din3 = din5; + din4 = din3 + w_in; + din5 = din4 + w_in; + + dout0 = dout1 + w_out; + dout1 = dout0 + w_out; + } + float* dout_pad_end = dout_ch + h_out_new * w_out; + if (flag_bias) { + //! deal with h_out pad_0 line with bias + memcpy(reinterpret_cast(dout_pad_end), + dout_ch - pad_0 * w_out, + pad_0 * w_out * sizeof(float)); + } else { + //! deal with h_out pad_0 line without bias + memset(reinterpret_cast(dout_pad_end), + 0x00, + pad_0 * w_out * sizeof(float)); + } + } + } +} + +void conv_depthwise_5x5s1_relu_impl(const float* din, + float* dout, + int num, + int ch_out, + int h_out, + int w_out, + int ch_in, + int h_in, + int w_in, + const float* weights, + const float* bias, + int pad, + bool flag_bias, + bool flag_relu, + ARMContext* ctx) { + float* zero_ptr = ctx->workspace_data(); + memset(zero_ptr, 0, w_in * sizeof(float)); + float* write_ptr = zero_ptr + w_in; + int pad_new = pad > 4 ? 
4 : pad; + int pad_0 = pad - pad_new; + int h_out_new = h_out - 2 * pad_0; + int mid_out = w_out - 2 * pad; + int mid_cnt = mid_out >> 2; + int mid_remain = mid_out - (mid_cnt << 2); + int pad_cnt = pad_0 >> 2; + int pad_remain = pad_0 - (pad_cnt << 2); + int bias_cnt = (w_out * pad_0) >> 2; + int bias_remain = (w_out * pad_0) - (bias_cnt << 2); + int in_spatial_size = w_in * h_in; + int out_spatial_size = w_out * h_out; + int weights_saptial_size = 25; + + for (int n = 0; n < num; ++n) { + const float* din_batch = din + n * in_spatial_size * ch_in; + float* dout_batch = dout + n * out_spatial_size * ch_out; +#pragma omp parallel for + for (int c = 0; c < ch_in; ++c) { + const float* din_ch = din_batch + c * in_spatial_size; + float* dout_ch = dout_batch + c * out_spatial_size; + float bias_c = flag_bias ? bias[c] : 0.f; + float bias_relu = bias_c > 0.f ? bias_c : 0.f; + float vbias[4] = {bias_c, bias_c, bias_c, bias_c}; + float32x4_t vbias_c = vdupq_n_f32(bias_relu); + if (flag_bias) { + //! deal with h_out pad_0 line with bias + for (int i = 0; i < bias_cnt; ++i) { + vst1q_f32(dout_ch, vbias_c); + dout_ch += 4; + } + for (int i = 0; i < bias_remain; ++i) { + *dout_ch++ = bias_relu; + } + } else { + //! deal with h_out pad_0 line without bias + for (int i = 0; i < pad_0; ++i) { + memset(dout_ch, 0x00, w_out * sizeof(float)); + dout_ch += w_out; + } + } + const float* din_list[6]; + //! set din ptr with zero buffer + for (int i = 0; i < pad_new; ++i) { + din_list[i] = zero_ptr; + } + //! set din ptr with input data + for (int i = pad_new; i < 6; ++i) { + din_list[i] = din_ch; + din_ch += w_in; + } + //! every h loop, deal with 6 line input + const float* din0 = din_list[0]; + const float* din1 = din_list[1]; + const float* din2 = din_list[2]; + const float* din3 = din_list[3]; + const float* din4 = din_list[4]; + const float* din5 = din_list[5]; + + //! every h loop, deal with 2 line output + float* dout0 = dout_ch; + float* dout1 = dout0 + w_out; + + //! load weights to neon register + const float* weights_c = weights + c * weights_saptial_size; + + //! h loop + for (int h = 0; h < h_out_new; h += 2) { + //! (h - pad_new) + 7 > h_in - 1 + if (h + 6 - pad_new > h_in) { + switch (h + 6 - pad_new - h_in) { + case 5: + din1 = zero_ptr; + case 4: + din2 = zero_ptr; + case 3: + din3 = zero_ptr; + case 2: + din4 = zero_ptr; + case 1: + din5 = zero_ptr; + default: + break; + } + } + if (h + 2 > h_out_new) { + dout1 = write_ptr; + } + const float* din_ptr0 = din0; + const float* din_ptr1 = din1; + const float* din_ptr2 = din2; + const float* din_ptr3 = din3; + const float* din_ptr4 = din4; + const float* din_ptr5 = din5; + + float* dout_ptr0 = dout0; + float* dout_ptr1 = dout1; + if (flag_bias) { + //! deal with w_out pad_0 column pre with bias + for (int i = 0; i < pad_cnt; i++) { + vst1q_f32(dout_ptr0, vbias_c); + vst1q_f32(dout_ptr1, vbias_c); + dout_ptr0 += 4; + dout_ptr1 += 4; + } + for (int i = 0; i < pad_remain; ++i) { + *dout_ptr0++ = bias_relu; + *dout_ptr1++ = bias_relu; + } + } else { + //! deal with w_out pad_0 column pre without bias + memset(dout_ptr0, 0x00, pad_0 * sizeof(float)); + memset(dout_ptr1, 0x00, pad_0 * sizeof(float)); + dout_ptr0 += pad_0; + dout_ptr1 += pad_0; + } + + //! 
deal with w_out pad_new column pre + switch (pad_new) { + case 4: + compute_four_out_extract_pre_relu(din_ptr0, + din_ptr1, + din_ptr2, + din_ptr3, + din_ptr4, + din_ptr5, + dout_ptr0, + dout_ptr1, + weights_c, + vbias); + dout_ptr0 += 4; + dout_ptr1 += 4; + break; + case 3: + compute_three_out_extract_pre_relu(din_ptr0, + din_ptr1, + din_ptr2, + din_ptr3, + din_ptr4, + din_ptr5, + dout_ptr0, + dout_ptr1, + weights_c, + vbias); + dout_ptr0 += 3; + dout_ptr1 += 3; + break; + case 2: + compute_two_out_extract_pre_relu(din_ptr0, + din_ptr1, + din_ptr2, + din_ptr3, + din_ptr4, + din_ptr5, + dout_ptr0, + dout_ptr1, + weights_c, + vbias); + dout_ptr0 += 2; + dout_ptr1 += 2; + break; + case 1: + compute_one_out_extract_pre_relu(din_ptr0, + din_ptr1, + din_ptr2, + din_ptr3, + din_ptr4, + din_ptr5, + dout_ptr0, + dout_ptr1, + weights_c, + vbias); + dout_ptr0 += 1; + dout_ptr1 += 1; + break; + } + + //! mid loop + if (mid_cnt > 0) { + int mid_loop = mid_cnt; + const float* weights_ptr = weights_c; + asm volatile( + //! din: q7-q12 + //! dout: q13, q14 + "mov r1, #20 \n" + "vmov.i32 q15, #0x0 \n" + //! load weights + "vld1.32 {d0-d1}, [%[wh]], r1 \n" + "vld1.32 {d2-d3}, [%[wh]], r1 \n" + "vld1.32 {d4-d5}, [%[wh]], r1 \n" + "vld1.32 {d6-d7}, [%[wh]], r1 \n" + "vld1.32 {d8-d9}, [%[wh]] \n" + + "sub %[wh], #64 \n" + "vld1.32 {d10[0]}, [%[wh]], r1 \n" + "vld1.32 {d10[1]}, [%[wh]], r1 \n" + "vld1.32 {d11[0]}, [%[wh]], r1 \n" + "vld1.32 {d11[1]}, [%[wh]], r1 \n" + "vld1.32 {d12[0]}, [%[wh]] \n" + + //! load input + "mov r1, #4 \n" + "vld1.32 {d14-d15}, [%[din0]], r1 \n" + "vld1.32 {d16-d17}, [%[din1]], r1 \n" + "vld1.32 {d18-d19}, [%[din2]], r1 \n" + "vld1.32 {d20-d21}, [%[din3]], r1 \n" + "vld1.32 {d22-d23}, [%[din4]], r1 \n" + "vld1.32 {d24-d25}, [%[din5]], r1 \n" + + "1: \n" + + //! 
load bias to output + "vld1.32 {d26-d27}, [%[bias]] \n" + "vld1.32 {d28-d29}, [%[bias]] \n" + + "pld [%[din0]] \n" + "pld [%[din1]] \n" + "pld [%[din2]] \n" + "pld [%[din3]] \n" + "pld [%[din4]] \n" + "pld [%[din5]] \n" + + // weights col 0 + "vmla.f32 q13, q7, d0[0] \n" + "vmla.f32 q14, q8, d0[0] \n" + + "vmla.f32 q13, q8, d2[0] \n" + "vmla.f32 q14, q9, d2[0] \n" + + "vld1.32 {d14-d15}, [%[din0]], r1 \n" + "vld1.32 {d16-d17}, [%[din1]], r1 \n" + + "vmla.f32 q13, q9, d4[0] \n" + "vmla.f32 q14, q10, d4[0] \n" + + "vmla.f32 q13, q10, d6[0] \n" + "vmla.f32 q14, q11, d6[0] \n" + + "vld1.32 {d18-d19}, [%[din2]], r1 \n" + "vld1.32 {d20-d21}, [%[din3]], r1 \n" + + "vmla.f32 q13, q11, d8[0] \n" + "vmla.f32 q14, q12, d8[0] \n" + + "vld1.32 {d22-d23}, [%[din4]], r1 \n" + "vld1.32 {d24-d25}, [%[din5]], r1 \n" + + // weights col 1 + "vmla.f32 q13, q7, d0[1] \n" + "vmla.f32 q14, q8, d0[1] \n" + + "vmla.f32 q13, q8, d2[1] \n" + "vmla.f32 q14, q9, d2[1] \n" + + "vld1.32 {d14-d15}, [%[din0]], r1 \n" + "vld1.32 {d16-d17}, [%[din1]], r1 \n" + + "vmla.f32 q13, q9, d4[1] \n" + "vmla.f32 q14, q10, d4[1] \n" + + "vmla.f32 q13, q10, d6[1] \n" + "vmla.f32 q14, q11, d6[1] \n" + + "vld1.32 {d18-d19}, [%[din2]], r1 \n" + "vld1.32 {d20-d21}, [%[din3]], r1 \n" + + "vmla.f32 q13, q11, d8[1] \n" + "vmla.f32 q14, q12, d8[1] \n" + + "vld1.32 {d22-d23}, [%[din4]], r1 \n" + "vld1.32 {d24-d25}, [%[din5]], r1 \n" + + // weights col 2 + "vmla.f32 q13, q7, d1[0] \n" + "vmla.f32 q14, q8, d1[0] \n" + + "vmla.f32 q13, q8, d3[0] \n" + "vmla.f32 q14, q9, d3[0] \n" + + "vld1.32 {d14-d15}, [%[din0]], r1 \n" + "vld1.32 {d16-d17}, [%[din1]], r1 \n" + + "vmla.f32 q13, q9, d5[0] \n" + "vmla.f32 q14, q10, d5[0] \n" + + "vmla.f32 q13, q10, d7[0] \n" + "vmla.f32 q14, q11, d7[0] \n" + + "vld1.32 {d18-d19}, [%[din2]], r1 \n" + "vld1.32 {d20-d21}, [%[din3]], r1 \n" + + "vmla.f32 q13, q11, d9[0] \n" + "vmla.f32 q14, q12, d9[0] \n" + + "vld1.32 {d22-d23}, [%[din4]], r1 \n" + "vld1.32 {d24-d25}, [%[din5]], r1 \n" + + // weights col 3 + "vmla.f32 q13, q7, d1[1] \n" + "vmla.f32 q14, q8, d1[1] \n" + + "vmla.f32 q13, q8, d3[1] \n" + "vmla.f32 q14, q9, d3[1] \n" + + "vld1.32 {d14-d15}, [%[din0]], r1 \n" + "vld1.32 {d16-d17}, [%[din1]], r1 \n" + + "vmla.f32 q13, q9, d5[1] \n" + "vmla.f32 q14, q10, d5[1] \n" + + "vmla.f32 q13, q10, d7[1] \n" + "vmla.f32 q14, q11, d7[1] \n" + + "vld1.32 {d18-d19}, [%[din2]], r1 \n" + "vld1.32 {d20-d21}, [%[din3]], r1 \n" + + "vmla.f32 q13, q11, d9[1] \n" + "vmla.f32 q14, q12, d9[1] \n" + + "vld1.32 {d22-d23}, [%[din4]], r1 \n" + "vld1.32 {d24-d25}, [%[din5]], r1 \n" + + // weights col 4 + "vmla.f32 q13, q7, d10[0] \n" + "vmla.f32 q14, q8, d10[0] \n" + + "vmla.f32 q13, q8, d10[1] \n" + "vmla.f32 q14, q9, d10[1] \n" + + "vmla.f32 q13, q9, d11[0] \n" + "vmla.f32 q14, q10, d11[0] \n" + + "vmla.f32 q13, q10, d11[1] \n" + "vmla.f32 q14, q11, d11[1] \n" + + "vmla.f32 q13, q11, d12[0] \n" + "vmla.f32 q14, q12, d12[0] \n" + + // relu + "vmax.f32 q13, q13, q15 \n" + "vmax.f32 q14, q14, q15 \n" + + // store result + "vst1.32 {d26-d27}, [%[out0]]! \n" + "vst1.32 {d28-d29}, [%[out1]]! 
\n" + + "subs %[cnt], #1 \n" + "bne 1b \n" + + "sub %[din0], r1 \n" + "sub %[din1], r1 \n" + "sub %[din2], r1 \n" + "sub %[din3], r1 \n" + "sub %[din4], r1 \n" + "sub %[din5], r1 \n" + + : [din0] "+r"(din_ptr0), + [din1] "+r"(din_ptr1), + [din2] "+r"(din_ptr2), + [din3] "+r"(din_ptr3), + [din4] "+r"(din_ptr4), + [din5] "+r"(din_ptr5), + [out0] "+r"(dout_ptr0), + [out1] "+r"(dout_ptr1), + [wh] "+r"(weights_ptr), + [cnt] "+r"(mid_loop) + : [bias] "r"(vbias) + : "cc", + "memory", + "r1", + "q0", + "q1", + "q2", + "q3", + "q4", + "q5", + "q6", + "q7", + "q8", + "q9", + "q10", + "q11", + "q12", + "q13", + "q14", + "q15"); + } + //! deal with mid remain + for (int i = 0; i < mid_remain; ++i) { + compute_one_out_without_extract_relu(din_ptr0, + din_ptr1, + din_ptr2, + din_ptr3, + din_ptr4, + din_ptr5, + dout_ptr0, + dout_ptr1, + weights_c, + vbias); + din_ptr0++; + din_ptr1++; + din_ptr2++; + din_ptr3++; + din_ptr4++; + din_ptr5++; + + dout_ptr0++; + dout_ptr1++; + } + //! deal with w_out pad_new column post + switch (pad_new) { + case 4: + compute_four_out_extract_post_relu(din_ptr0, + din_ptr1, + din_ptr2, + din_ptr3, + din_ptr4, + din_ptr5, + dout_ptr0, + dout_ptr1, + weights_c, + vbias); + dout_ptr0 += 4; + dout_ptr1 += 4; + break; + case 3: + compute_three_out_extract_post_relu(din_ptr0, + din_ptr1, + din_ptr2, + din_ptr3, + din_ptr4, + din_ptr5, + dout_ptr0, + dout_ptr1, + weights_c, + vbias); + dout_ptr0 += 3; + dout_ptr1 += 3; + break; + case 2: + compute_two_out_extract_post_relu(din_ptr0, + din_ptr1, + din_ptr2, + din_ptr3, + din_ptr4, + din_ptr5, + dout_ptr0, + dout_ptr1, + weights_c, + vbias); + dout_ptr0 += 2; + dout_ptr1 += 2; + break; + case 1: + compute_one_out_extract_post_relu(din_ptr0, + din_ptr1, + din_ptr2, + din_ptr3, + din_ptr4, + din_ptr5, + dout_ptr0, + dout_ptr1, + weights_c, + vbias); + dout_ptr0 += 1; + dout_ptr1 += 1; + break; + } + + if (flag_bias) { + //! deal with w_out pad_0 column post with bias + memcpy(dout_ptr0, dout0, pad_0 * sizeof(float)); + memcpy(dout_ptr1, dout1, pad_0 * sizeof(float)); + } else { + //! deal with w_out pad_0 column post without bias + memset(dout_ptr0, 0x00, pad_0 * sizeof(float)); + memset(dout_ptr1, 0x00, pad_0 * sizeof(float)); + } + + din0 = din2; + din1 = din3; + din2 = din4; + din3 = din5; + din4 = din3 + w_in; + din5 = din4 + w_in; + + dout0 = dout1 + w_out; + dout1 = dout0 + w_out; + } + float* dout_pad_end = dout_ch + h_out_new * w_out; + if (flag_bias) { + //! deal with h_out pad_0 line with bias + memcpy(reinterpret_cast(dout_pad_end), + dout_ch - pad_0 * w_out, + pad_0 * w_out * sizeof(float)); + } else { + //! deal with h_out pad_0 line without bias + memset(reinterpret_cast(dout_pad_end), + 0x00, + pad_0 * w_out * sizeof(float)); + } + } + } +} + +void conv_depthwise_5x5s1_small_impl(const float* din, + float* dout, + int num, + int ch_out, + int h_out, + int w_out, + int ch_in, + int h_in, + int w_in, + const float* weights, + const float* bias, + int pad, + bool flag_bias, + bool flag_relu, + ARMContext* ctx) { + int pad_new = pad > 4 ? 
4 : pad; + int pad_0 = pad - pad_new; + int h_in_new = h_in + 2 * pad_new; + int w_in_new = w_in + 2 * pad_new; + int h_out_new = h_out - 2 * pad_0; + int w_out_new = w_out - 2 * pad_0; + float zero_ptr[w_in_new + w_out]; + memset(zero_ptr, 0, w_in_new * sizeof(float)); + float* write_ptr = zero_ptr + w_in_new; + int pad_cnt = pad_0 >> 2; + int pad_remain = pad_0 - (pad_cnt << 2); + int bias_cnt = (w_out * pad_0) >> 2; + int bias_remain = (w_out * pad_0) - (bias_cnt << 2); + int in_spatial_size = w_in_new * h_in_new; + int out_spatial_size = w_out * h_out; + int weights_saptial_size = 25; + + float* din_new = prepad_input(din, num, ch_in, h_in, w_in, pad_new); + for (int n = 0; n < num; ++n) { + const float* din_batch = din_new + n * in_spatial_size * ch_in; + float* dout_batch = dout + n * out_spatial_size * ch_out; +#pragma omp parallel for + for (int c = 0; c < ch_in; ++c) { + const float* din_ch = din_batch + c * in_spatial_size; + float* dout_ch = dout_batch + c * out_spatial_size; + float bias_c = flag_bias ? bias[c] : 0.f; + float vbias[4] = {bias_c, bias_c, bias_c, bias_c}; + float32x4_t vbias_c = vdupq_n_f32(bias_c); + if (flag_bias) { + //! deal with h_out pad_0 line with bias + for (int i = 0; i < bias_cnt; ++i) { + vst1q_f32(dout_ch, vbias_c); + dout_ch += 4; + } + for (int i = 0; i < bias_remain; ++i) { + *dout_ch++ = bias_c; + } + } else { + //! deal with h_out pad_0 line without bias + for (int i = 0; i < pad_0; ++i) { + memset(dout_ch, 0x00, w_out * sizeof(float)); + dout_ch += w_out; + } + } + //! every h loop, deal with 6 line input + const float* din0 = din_ch; + const float* din1 = din0 + w_in_new; + const float* din2 = din1 + w_in_new; + const float* din3 = din2 + w_in_new; + const float* din4 = din3 + w_in_new; + const float* din5 = din4 + w_in_new; + //! every h loop, deal with 2 line output + float* dout0 = dout_ch; + float* dout1 = dout0 + w_out; + + const float* weights_c = weights + c * weights_saptial_size; + + //! h loop + for (int h = 0; h < h_out_new; h += 2) { + //! (h - pad_new) + 6 > h_in - 1 + if (h + 6 > h_in_new) { + switch (h + 6 - h_in_new) { + case 5: + din1 = zero_ptr; + case 4: + din2 = zero_ptr; + case 3: + din3 = zero_ptr; + case 2: + din4 = zero_ptr; + case 1: + din5 = zero_ptr; + default: + break; + } + } + if (h + 2 > h_out_new) { + dout1 = write_ptr; + } + const float* din_ptr0 = din0; + const float* din_ptr1 = din1; + const float* din_ptr2 = din2; + const float* din_ptr3 = din3; + const float* din_ptr4 = din4; + const float* din_ptr5 = din5; + + float* dout_ptr0 = dout0; + float* dout_ptr1 = dout1; + + if (flag_bias) { + //! deal with w_out pad_0 column pre with bias + for (int i = 0; i < pad_cnt; i++) { + vst1q_f32(dout_ptr0, vbias_c); + vst1q_f32(dout_ptr1, vbias_c); + dout_ptr0 += 4; + dout_ptr1 += 4; + } + for (int i = 0; i < pad_remain; ++i) { + *dout_ptr0++ = bias_c; + *dout_ptr1++ = bias_c; + } + } else { + //! deal with w_out pad_0 column pre without bias + memset(dout_ptr0, 0x00, pad_0 * sizeof(float)); + memset(dout_ptr1, 0x00, pad_0 * sizeof(float)); + dout_ptr0 += pad_0; + dout_ptr1 += pad_0; + } + //! mid loop + for (int i = 0; i < w_out_new; ++i) { + compute_one_out_without_extract(din_ptr0, + din_ptr1, + din_ptr2, + din_ptr3, + din_ptr4, + din_ptr5, + dout_ptr0, + dout_ptr1, + weights_c, + vbias); + din_ptr0++; + din_ptr1++; + din_ptr2++; + din_ptr3++; + din_ptr4++; + din_ptr5++; + + dout_ptr0++; + dout_ptr1++; + } + if (flag_bias) { + //! 
deal with w_out pad_0 column post with bias + memcpy(dout_ptr0, dout0, pad_0 * sizeof(float)); + memcpy(dout_ptr1, dout1, pad_0 * sizeof(float)); + } else { + //! deal with w_out pad_0 column post without bias + memset(dout_ptr0, 0x00, pad_0 * sizeof(float)); + memset(dout_ptr1, 0x00, pad_0 * sizeof(float)); + } + + din0 = din2; + din1 = din3; + din2 = din4; + din3 = din5; + din4 = din3 + w_in_new; + din5 = din4 + w_in_new; + + dout0 = dout1 + w_out; + dout1 = dout0 + w_out; + } + float* dout_pad_end = dout_ch + h_out_new * w_out; + if (flag_bias) { + //! deal with h_out pad_0 line with bias + memcpy(reinterpret_cast<void*>(dout_pad_end), + dout_ch - pad_0 * w_out, + pad_0 * w_out * sizeof(float)); + } else { + //! deal with h_out pad_0 line without bias + memset(reinterpret_cast<void*>(dout_pad_end), + 0x00, + pad_0 * w_out * sizeof(float)); + } + } + } + free(din_new); +} + +void conv_depthwise_5x5s1_small_relu_impl(const float* din, + float* dout, + int num, + int ch_out, + int h_out, + int w_out, + int ch_in, + int h_in, + int w_in, + const float* weights, + const float* bias, + int pad, + bool flag_bias, + bool flag_relu, + ARMContext* ctx) { + int pad_new = pad > 4 ? 4 : pad; + int pad_0 = pad - pad_new; + int h_in_new = h_in + 2 * pad_new; + int w_in_new = w_in + 2 * pad_new; + int h_out_new = h_out - 2 * pad_0; + int w_out_new = w_out - 2 * pad_0; + float zero_ptr[w_in_new + w_out]; + memset(zero_ptr, 0, w_in_new * sizeof(float)); + float* write_ptr = zero_ptr + w_in_new; + int pad_cnt = pad_0 >> 2; + int pad_remain = pad_0 - (pad_cnt << 2); + int bias_cnt = (w_out * pad_0) >> 2; + int bias_remain = (w_out * pad_0) - (bias_cnt << 2); + int in_spatial_size = w_in_new * h_in_new; + int out_spatial_size = w_out * h_out; + int weights_saptial_size = 25; + + float* din_new = prepad_input(din, num, ch_in, h_in, w_in, pad_new); + for (int n = 0; n < num; ++n) { + const float* din_batch = din_new + n * in_spatial_size * ch_in; + float* dout_batch = dout + n * out_spatial_size * ch_out; +#pragma omp parallel for + for (int c = 0; c < ch_in; ++c) { + const float* din_ch = din_batch + c * in_spatial_size; + float* dout_ch = dout_batch + c * out_spatial_size; + float bias_c = flag_bias ? bias[c] : 0.f; + float bias_relu = bias_c > 0.f ? bias_c : 0.f; + float vbias[4] = {bias_c, bias_c, bias_c, bias_c}; + float32x4_t vbias_c = vdupq_n_f32(bias_relu); + if (flag_bias) { + //! deal with h_out pad_0 line with bias + for (int i = 0; i < bias_cnt; ++i) { + vst1q_f32(dout_ch, vbias_c); + dout_ch += 4; + } + for (int i = 0; i < bias_remain; ++i) { + *dout_ch++ = bias_relu; + } + } else { + //! deal with h_out pad_0 line without bias + for (int i = 0; i < pad_0; ++i) { + memset(dout_ch, 0x00, w_out * sizeof(float)); + dout_ch += w_out; + } + } + //! every h loop, deal with 6 line input + const float* din0 = din_ch; + const float* din1 = din0 + w_in_new; + const float* din2 = din1 + w_in_new; + const float* din3 = din2 + w_in_new; + const float* din4 = din3 + w_in_new; + const float* din5 = din4 + w_in_new; + //! every h loop, deal with 2 line output + float* dout0 = dout_ch; + float* dout1 = dout0 + w_out; + + const float* weights_c = weights + c * weights_saptial_size; + + //! h loop + for (int h = 0; h < h_out_new; h += 2) { + //!
(h - pad_new) + 6 > h_in - 1 + if (h + 6 > h_in_new) { + switch (h + 6 - h_in_new) { + case 5: + din1 = zero_ptr; + case 4: + din2 = zero_ptr; + case 3: + din3 = zero_ptr; + case 2: + din4 = zero_ptr; + case 1: + din5 = zero_ptr; + default: + break; + } + } + if (h + 2 > h_out_new) { + dout1 = write_ptr; + } + const float* din_ptr0 = din0; + const float* din_ptr1 = din1; + const float* din_ptr2 = din2; + const float* din_ptr3 = din3; + const float* din_ptr4 = din4; + const float* din_ptr5 = din5; + + const float* weights_ptr = weights_c; + float* dout_ptr0 = dout0; + float* dout_ptr1 = dout1; + + if (flag_bias) { + //! deal with w_out pad_0 column pre with bias + for (int i = 0; i < pad_cnt; i++) { + vst1q_f32(dout_ptr0, vbias_c); + vst1q_f32(dout_ptr1, vbias_c); + dout_ptr0 += 4; + dout_ptr1 += 4; + } + for (int i = 0; i < pad_remain; ++i) { + *dout_ptr0++ = bias_relu; + *dout_ptr1++ = bias_relu; + } + } else { + //! deal with w_out pad_0 column pre without bias + memset(dout_ptr0, 0x00, pad_0 * sizeof(float)); + memset(dout_ptr1, 0x00, pad_0 * sizeof(float)); + dout_ptr0 += pad_0; + dout_ptr1 += pad_0; + } + //! mid loop + for (int i = 0; i < w_out_new; ++i) { + compute_one_out_without_extract_relu(din_ptr0, + din_ptr1, + din_ptr2, + din_ptr3, + din_ptr4, + din_ptr5, + dout_ptr0, + dout_ptr1, + weights_c, + vbias); + din_ptr0++; + din_ptr1++; + din_ptr2++; + din_ptr3++; + din_ptr4++; + din_ptr5++; + + dout_ptr0++; + dout_ptr1++; + } + if (flag_bias) { + //! deal with w_out pad_0 column post with bias + memcpy(dout_ptr0, dout0, pad_0 * sizeof(float)); + memcpy(dout_ptr1, dout1, pad_0 * sizeof(float)); + } else { + //! deal with w_out pad_0 column post without bias + memset(dout_ptr0, 0x00, pad_0 * sizeof(float)); + memset(dout_ptr1, 0x00, pad_0 * sizeof(float)); + } + + din0 = din2; + din1 = din3; + din2 = din4; + din3 = din5; + din4 = din3 + w_in_new; + din5 = din4 + w_in_new; + + dout0 = dout1 + w_out; + dout1 = dout0 + w_out; + } + float* dout_pad_end = dout_ch + h_out_new * w_out; + if (flag_bias) { + //! deal with h_out pad_0 line with bias + memcpy(reinterpret_cast<void*>(dout_pad_end), + dout_ch - pad_0 * w_out, + pad_0 * w_out * sizeof(float)); + } else { + //!
deal with h_out pad_0 line without bias + memset(reinterpret_cast<void*>(dout_pad_end), + 0x00, + pad_0 * w_out * sizeof(float)); + } + } + } + free(din_new); +} +#endif // __aarch64__ + +void conv_depthwise_5x5s1(const float* din, + float* dout, + int num, + int chout, + int hout, + int wout, + int chin, + int hin, + int win, + const float* weights, + const float* bias, + int pad, + bool flag_bias, + bool flag_relu, + ARMContext* ctx) { + if (win < 4) { + if (flag_relu) { + conv_depthwise_5x5s1_small_relu_impl(din, + dout, + num, + chout, + hout, + wout, + chin, + hin, + win, + weights, + bias, + pad, + flag_bias, + flag_relu, + ctx); + } else { + conv_depthwise_5x5s1_small_impl(din, + dout, + num, + chout, + hout, + wout, + chin, + hin, + win, + weights, + bias, + pad, + flag_bias, + flag_relu, + ctx); + } + } else { + if (flag_relu) { + conv_depthwise_5x5s1_relu_impl(din, + dout, + num, + chout, + hout, + wout, + chin, + hin, + win, + weights, + bias, + pad, + flag_bias, + flag_relu, + ctx); + } else { + conv_depthwise_5x5s1_impl(din, + dout, + num, + chout, + hout, + wout, + chin, + hin, + win, + weights, + bias, + pad, + flag_bias, + flag_relu, + ctx); + } + } +} + +} // namespace math +} // namespace arm +} // namespace lite +} // namespace paddle diff --git a/lite/arm/math/conv_depthwise_5x5s1_int8.cc b/lite/arm/math/conv_depthwise_5x5s1_int8.cc new file mode 100644 index 00000000000..47563c542c7 --- /dev/null +++ b/lite/arm/math/conv_depthwise_5x5s1_int8.cc @@ -0,0 +1,618 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License.
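A note on the accumulation discipline used throughout the int8 kernel in this file: each int8-by-int8 product fits in int16, and at most two such products are summed in an int16 lane (one vmull_s8 followed by one vmlal_s8) before the lane is widened into an int32 accumulator with vaddw_s16, because a third product could already overflow int16 (3 * 127 * 127 = 48387 > 32767). That is why every 5-tap row below is issued as two mull/mlal/vaddw groups plus a final mull/vaddw for the fifth tap. A minimal scalar sketch of the same discipline follows; the helper name and signature are illustrative and appear nowhere in the patch:

#include <cstdint>

// Scalar model of the widening int8 accumulation: pairs of products are
// formed in int16 (as vmull_s8 + vmlal_s8 do), then widened into the int32
// accumulator (as vaddw_s16 does). Illustrative only.
static int32_t dw5x5s1_ref_at(const int8_t* in, int w_in, const int8_t* wei,
                              int32_t bias, int y, int x) {
  int32_t acc = bias;  // the voutr* registers start from the bias
  for (int ky = 0; ky < 5; ++ky) {
    const int8_t* row = in + (y + ky) * w_in + x;
    const int8_t* wrow = wei + ky * 5;
    for (int kx = 0; kx < 5; kx += 2) {
      // |127 * 127| = 16129, so the sum of two products stays within 32767
      int16_t pair = static_cast<int16_t>(row[kx] * wrow[kx]);
      if (kx + 1 < 5) {
        pair = static_cast<int16_t>(pair + row[kx + 1] * wrow[kx + 1]);
      }
      acc += pair;  // widening add, the scalar analogue of vaddw_s16
    }
  }
  return acc;
}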
+ +#include <arm_neon.h> +#include "lite/arm/math/conv_block_utils.h" +#include "lite/arm/math/conv_impl.h" +#include "lite/core/context.h" +#include "lite/operators/op_params.h" +#ifdef ARM_WITH_OMP +#include <omp.h> +#endif + +namespace paddle { +namespace lite { +namespace arm { +namespace math { + +void conv_depthwise_5x5s1_int8(int32_t* dout, + const int8_t* din, + const int8_t* weights, + const int* bias, + bool flag_bias, + bool flag_relu, + const int num, + const int chin, + const int hin, + const int win, + const int hout, + const int wout, + ARMContext* ctx, + PrecisionType out_type, + const float* scale); + +void conv_depthwise_5x5_int8(const int8_t* din, + int32_t* dout, + int num, + int chout, + int hout, + int wout, + int chin, + int hin, + int win, + const int8_t* weights, + const int32_t* bias, + const operators::ConvParam& param, + ARMContext* ctx, + PrecisionType out_type, + const float* scale) { + int stride_h = param.strides[0]; + bool flag_relu = param.fuse_relu; + bool flag_bias = param.bias != nullptr; + // if (param.activation_param.has_active){ + // if (param.activation_param.active == Active_relu || + // fabs(param.activation_param.negative_slope) > 1e-6f){ + // flag_relu = true; + // } + // } + if (stride_h == 1) { +#ifdef __aarch64__ + conv_depthwise_5x5s1_int8(dout, + din, + weights, + bias, + flag_bias, + flag_relu, + num, + chin, + hin, + win, + hout, + wout, + ctx, + out_type, + scale); +#else + + LOG(FATAL) << "5x5 dw conv for armv7 has not been implemented"; +#endif + } +} + +/** + * \brief depthwise convolution, kernel size 5x5, stride 1, pad 2, with bias, + * width > 4 + */ +// 2 line +#ifdef __aarch64__ + +template <typename Dtype> +inline void prefetch(const Dtype* din) { +#ifdef __aarch64__ + asm volatile("PRFM PLDL1KEEP, [%[din]] \n" : : [din] "r"(din) : "memory"); +#else + asm volatile("pld [%[din]] \n" : : [din] "r"(din) : "memory"); +#endif +} + +void conv_depthwise_5x5s1_int8( + int32_t* dout, + const int8_t* din, + const int8_t* weights, + const int32_t* bias, + bool flag_bias, + bool flag_relu, + const int num, + const int chin, + const int hin, + const int win, + const int hout, + const int wout, + ARMContext* ctx, + PrecisionType od_type, + float const* scales) { /// scale_size = channel-out + + // printf("5*5 multiply\n"); + int size_in_channel = win * hin; + int size_out_channel = wout * hout; + int w_stride = 5 * 5; + + static int const stride_w = 1; + int const stride_h = stride_w; + int const chout = chin; + int const pad_w = 2; + int const pad_h = pad_w; + + int const wout_round = ((wout + 7) / 8) * 8; + int const win_round = wout_round * stride_w + 5 - 1; + int const hout_round = ((hout + 2) / 3) * 3; + int const hin_round = hout_round * stride_h + 5 - 1; + int const tile_h = hout_round / 3; + int const tile_w = wout_round / 8; + + int const pre_in_size = hin_round * win_round; + int const pre_out_size = hout_round * wout_round; + int const pre_io_size = pre_in_size + pre_out_size * sizeof(int); + + int const hs = -pad_h; + int const he = hs + hin_round; + int const ws = -pad_w; + int const we = ws + win_round; + + // signed char* tmp_work_space = new signed char [1024*5]; + signed char* tmp_work_space = ctx->workspace_data<signed char>(); + signed char* ptr_zero = tmp_work_space; + int* ptr_write = reinterpret_cast<int*>(ptr_zero + win_round); + signed char* pre_data = + reinterpret_cast<signed char*>(ptr_write + wout_round); + + memset(ptr_zero, 0, win_round * sizeof(signed char)); + + for (int n = 0; n < num; ++n) { + signed char const* din_batch = din + n * chin * size_in_channel; + int* dout_batch = dout + n * chout
* size_out_channel; + + // #pragma omp parallel for + for (int c = 0; c < chout; c++) { +#ifdef ARM_WITH_OMP + int const thno = omp_get_thread_num(); +#else + int const thno = 0; +#endif + signed char const* din_channel = din_batch + c * size_in_channel; + signed char* pre_din = pre_data + thno * pre_io_size; + int* pre_out = reinterpret_cast(pre_din + pre_in_size); + int* dout_ptr = pre_out; + + prepack_input_nxw(din_channel, + pre_din, + c, + c + 1, + hs, + he, + ws, + we, + 1, + win, + hin, + ptr_zero); + + signed char const* wei_ptr = weights + c * w_stride; + int bias_val = flag_bias ? bias[c] : 0.f; + + int8x8_t wr00 = vdup_n_s8(wei_ptr[0 * 5 + 0]); + int8x8_t wr01 = vdup_n_s8(wei_ptr[0 * 5 + 1]); + int8x8_t wr02 = vdup_n_s8(wei_ptr[0 * 5 + 2]); + int8x8_t wr03 = vdup_n_s8(wei_ptr[0 * 5 + 3]); + int8x8_t wr04 = vdup_n_s8(wei_ptr[0 * 5 + 4]); + + int8x8_t wr10 = vdup_n_s8(wei_ptr[1 * 5 + 0]); + int8x8_t wr11 = vdup_n_s8(wei_ptr[1 * 5 + 1]); + int8x8_t wr12 = vdup_n_s8(wei_ptr[1 * 5 + 2]); + int8x8_t wr13 = vdup_n_s8(wei_ptr[1 * 5 + 3]); + int8x8_t wr14 = vdup_n_s8(wei_ptr[1 * 5 + 4]); + + int8x8_t wr20 = vdup_n_s8(wei_ptr[2 * 5 + 0]); + int8x8_t wr21 = vdup_n_s8(wei_ptr[2 * 5 + 1]); + int8x8_t wr22 = vdup_n_s8(wei_ptr[2 * 5 + 2]); + int8x8_t wr23 = vdup_n_s8(wei_ptr[2 * 5 + 3]); + int8x8_t wr24 = vdup_n_s8(wei_ptr[2 * 5 + 4]); + + int8x8_t wr30 = vdup_n_s8(wei_ptr[3 * 5 + 0]); + int8x8_t wr31 = vdup_n_s8(wei_ptr[3 * 5 + 1]); + int8x8_t wr32 = vdup_n_s8(wei_ptr[3 * 5 + 2]); + int8x8_t wr33 = vdup_n_s8(wei_ptr[3 * 5 + 3]); + int8x8_t wr34 = vdup_n_s8(wei_ptr[3 * 5 + 4]); + + int8x8_t wr40 = vdup_n_s8(wei_ptr[4 * 5 + 0]); + int8x8_t wr41 = vdup_n_s8(wei_ptr[4 * 5 + 1]); + int8x8_t wr42 = vdup_n_s8(wei_ptr[4 * 5 + 2]); + int8x8_t wr43 = vdup_n_s8(wei_ptr[4 * 5 + 3]); + int8x8_t wr44 = vdup_n_s8(wei_ptr[4 * 5 + 4]); + + int* doutr0 = nullptr; + int* doutr1 = nullptr; + int* doutr2 = nullptr; + + signed char const* dr0 = pre_din; + signed char const* dr1 = dr0 + win_round; + signed char const* dr2 = dr1 + win_round; + signed char const* dr3 = dr2 + win_round; + signed char const* dr4 = dr3 + win_round; + signed char const* dr5 = dr4 + win_round; + signed char const* dr6 = dr5 + win_round; + + signed char const* din_ptr0 = nullptr; + signed char const* din_ptr1 = nullptr; + signed char const* din_ptr2 = nullptr; + signed char const* din_ptr3 = nullptr; + signed char const* din_ptr4 = nullptr; + signed char const* din_ptr5 = nullptr; + signed char const* din_ptr6 = nullptr; + + for (int h = 0; h < tile_h; h++) { + // printf("c:%d h:%d\n", c, h); + doutr0 = dout_ptr; + doutr1 = doutr0 + wout_round; + doutr2 = doutr1 + wout_round; + + din_ptr0 = dr0; + din_ptr1 = dr1; + din_ptr2 = dr2; + din_ptr3 = dr3; + din_ptr4 = dr4; + din_ptr5 = dr5; + din_ptr6 = dr6; + + prefetch(doutr0); + prefetch(doutr1); + prefetch(doutr2); + prefetch(din_ptr0); + prefetch(din_ptr1); + prefetch(din_ptr2); + prefetch(din_ptr3); + prefetch(din_ptr4); + prefetch(din_ptr5); + prefetch(din_ptr6); + + for (int j = 0; j < tile_w; ++j) { + // printf("j:%d\n", j); + int32x4_t voutr00 = vdupq_n_s32(bias_val); + int32x4_t voutr01 = vdupq_n_s32(bias_val); + int32x4_t voutr10 = vdupq_n_s32(bias_val); + int32x4_t voutr11 = vdupq_n_s32(bias_val); + int32x4_t voutr20 = vdupq_n_s32(bias_val); + int32x4_t voutr21 = vdupq_n_s32(bias_val); + + // din data + int8x8_t vinr00 = vld1_s8(din_ptr0 + 0); + int8x8_t vinr01 = vld1_s8(din_ptr0 + 8); + int8x8_t vinr10 = vld1_s8(din_ptr1 + 0); + int8x8_t vinr11 = vld1_s8(din_ptr1 + 8); + int8x8_t 
vinr20 = vld1_s8(din_ptr2 + 0); + int8x8_t vinr21 = vld1_s8(din_ptr2 + 8); + int8x8_t vinr30 = vld1_s8(din_ptr3 + 0); + int8x8_t vinr31 = vld1_s8(din_ptr3 + 8); + int8x8_t vinr40 = vld1_s8(din_ptr4 + 0); + int8x8_t vinr41 = vld1_s8(din_ptr4 + 8); + int8x8_t vinr50 = vld1_s8(din_ptr5 + 0); + int8x8_t vinr51 = vld1_s8(din_ptr5 + 8); + int8x8_t vinr60 = vld1_s8(din_ptr6 + 0); + int8x8_t vinr61 = vld1_s8(din_ptr6 + 8); + + /// the first row + // r0 + int8x8_t vtmp1 = vext_s8(vinr00, vinr01, 1); // 12345678 + int8x8_t vtmp2 = vext_s8(vinr00, vinr01, 2); // 2345678 + int8x8_t vtmp3 = vext_s8(vinr00, vinr01, 3); // 345678 + int8x8_t vtmp4 = vext_s8(vinr00, vinr01, 4); // 45678 + + int16x8_t tvoutr0 = vmull_s8(vinr00, wr00); + tvoutr0 = vmlal_s8(tvoutr0, vtmp1, wr01); + voutr00 = vaddw_s16(voutr00, vget_low_s16(tvoutr0)); + voutr01 = vaddw_s16(voutr01, vget_high_s16(tvoutr0)); + tvoutr0 = vmull_s8(vtmp2, wr02); + tvoutr0 = vmlal_s8(tvoutr0, vtmp3, wr03); + voutr00 = vaddw_s16(voutr00, vget_low_s16(tvoutr0)); + voutr01 = vaddw_s16(voutr01, vget_high_s16(tvoutr0)); + tvoutr0 = vmull_s8(vtmp4, wr04); + voutr00 = vaddw_s16(voutr00, vget_low_s16(tvoutr0)); + voutr01 = vaddw_s16(voutr01, vget_high_s16(tvoutr0)); + + // r1 + vtmp1 = vext_s8(vinr10, vinr11, 1); // 12345678 + vtmp2 = vext_s8(vinr10, vinr11, 2); // 2345678 + vtmp3 = vext_s8(vinr10, vinr11, 3); // 345678 + vtmp4 = vext_s8(vinr10, vinr11, 4); // 45678 + + tvoutr0 = vmull_s8(vinr10, wr10); + tvoutr0 = vmlal_s8(tvoutr0, vtmp1, wr11); + voutr00 = vaddw_s16(voutr00, vget_low_s16(tvoutr0)); + voutr01 = vaddw_s16(voutr01, vget_high_s16(tvoutr0)); + tvoutr0 = vmull_s8(vtmp2, wr12); + tvoutr0 = vmlal_s8(tvoutr0, vtmp3, wr13); + voutr00 = vaddw_s16(voutr00, vget_low_s16(tvoutr0)); + voutr01 = vaddw_s16(voutr01, vget_high_s16(tvoutr0)); + tvoutr0 = vmull_s8(vtmp4, wr14); + voutr00 = vaddw_s16(voutr00, vget_low_s16(tvoutr0)); + voutr01 = vaddw_s16(voutr01, vget_high_s16(tvoutr0)); + + int16x8_t tvoutr1 = vmull_s8(vinr10, wr00); + tvoutr1 = vmlal_s8(tvoutr1, vtmp1, wr01); + voutr10 = vaddw_s16(voutr10, vget_low_s16(tvoutr1)); + voutr11 = vaddw_s16(voutr11, vget_high_s16(tvoutr1)); + tvoutr1 = vmull_s8(vtmp2, wr02); + tvoutr1 = vmlal_s8(tvoutr1, vtmp3, wr03); + voutr10 = vaddw_s16(voutr10, vget_low_s16(tvoutr1)); + voutr11 = vaddw_s16(voutr11, vget_high_s16(tvoutr1)); + tvoutr1 = vmull_s8(vtmp4, wr04); + voutr10 = vaddw_s16(voutr10, vget_low_s16(tvoutr1)); + voutr11 = vaddw_s16(voutr11, vget_high_s16(tvoutr1)); + + // r2 + vtmp1 = vext_s8(vinr20, vinr21, 1); // 12345678 + vtmp2 = vext_s8(vinr20, vinr21, 2); // 2345678 + vtmp3 = vext_s8(vinr20, vinr21, 3); // 345678 + vtmp4 = vext_s8(vinr20, vinr21, 4); // 45678 + + tvoutr0 = vmull_s8(vinr20, wr20); + tvoutr0 = vmlal_s8(tvoutr0, vtmp1, wr21); + voutr00 = vaddw_s16(voutr00, vget_low_s16(tvoutr0)); + voutr01 = vaddw_s16(voutr01, vget_high_s16(tvoutr0)); + tvoutr0 = vmull_s8(vtmp2, wr22); + tvoutr0 = vmlal_s8(tvoutr0, vtmp3, wr23); + voutr00 = vaddw_s16(voutr00, vget_low_s16(tvoutr0)); + voutr01 = vaddw_s16(voutr01, vget_high_s16(tvoutr0)); + tvoutr0 = vmull_s8(vtmp4, wr24); + voutr00 = vaddw_s16(voutr00, vget_low_s16(tvoutr0)); + voutr01 = vaddw_s16(voutr01, vget_high_s16(tvoutr0)); + + tvoutr1 = vmull_s8(vinr20, wr10); + tvoutr1 = vmlal_s8(tvoutr1, vtmp1, wr11); + voutr10 = vaddw_s16(voutr10, vget_low_s16(tvoutr1)); + voutr11 = vaddw_s16(voutr11, vget_high_s16(tvoutr1)); + tvoutr1 = vmull_s8(vtmp2, wr12); + tvoutr1 = vmlal_s8(tvoutr1, vtmp3, wr13); + voutr10 = vaddw_s16(voutr10, vget_low_s16(tvoutr1)); + 
voutr11 = vaddw_s16(voutr11, vget_high_s16(tvoutr1)); + tvoutr1 = vmull_s8(vtmp4, wr14); + voutr10 = vaddw_s16(voutr10, vget_low_s16(tvoutr1)); + voutr11 = vaddw_s16(voutr11, vget_high_s16(tvoutr1)); + + int16x8_t tvoutr2 = vmull_s8(vinr20, wr00); + tvoutr2 = vmlal_s8(tvoutr2, vtmp1, wr01); + voutr20 = vaddw_s16(voutr20, vget_low_s16(tvoutr2)); + voutr21 = vaddw_s16(voutr21, vget_high_s16(tvoutr2)); + tvoutr2 = vmull_s8(vtmp2, wr02); + tvoutr2 = vmlal_s8(tvoutr2, vtmp3, wr03); + voutr20 = vaddw_s16(voutr20, vget_low_s16(tvoutr2)); + voutr21 = vaddw_s16(voutr21, vget_high_s16(tvoutr2)); + tvoutr2 = vmull_s8(vtmp4, wr04); + voutr20 = vaddw_s16(voutr20, vget_low_s16(tvoutr2)); + voutr21 = vaddw_s16(voutr21, vget_high_s16(tvoutr2)); + + // r3 + vtmp1 = vext_s8(vinr30, vinr31, 1); // 12345678 + vtmp2 = vext_s8(vinr30, vinr31, 2); // 2345678 + vtmp3 = vext_s8(vinr30, vinr31, 3); // 345678 + vtmp4 = vext_s8(vinr30, vinr31, 4); // 45678 + + tvoutr0 = vmull_s8(vinr30, wr30); + tvoutr0 = vmlal_s8(tvoutr0, vtmp1, wr31); + voutr00 = vaddw_s16(voutr00, vget_low_s16(tvoutr0)); + voutr01 = vaddw_s16(voutr01, vget_high_s16(tvoutr0)); + tvoutr0 = vmull_s8(vtmp2, wr32); + tvoutr0 = vmlal_s8(tvoutr0, vtmp3, wr33); + voutr00 = vaddw_s16(voutr00, vget_low_s16(tvoutr0)); + voutr01 = vaddw_s16(voutr01, vget_high_s16(tvoutr0)); + tvoutr0 = vmull_s8(vtmp4, wr34); + voutr00 = vaddw_s16(voutr00, vget_low_s16(tvoutr0)); + voutr01 = vaddw_s16(voutr01, vget_high_s16(tvoutr0)); + + tvoutr1 = vmull_s8(vinr30, wr20); + tvoutr1 = vmlal_s8(tvoutr1, vtmp1, wr21); + voutr10 = vaddw_s16(voutr10, vget_low_s16(tvoutr1)); + voutr11 = vaddw_s16(voutr11, vget_high_s16(tvoutr1)); + tvoutr1 = vmull_s8(vtmp2, wr22); + tvoutr1 = vmlal_s8(tvoutr1, vtmp3, wr23); + voutr10 = vaddw_s16(voutr10, vget_low_s16(tvoutr1)); + voutr11 = vaddw_s16(voutr11, vget_high_s16(tvoutr1)); + tvoutr1 = vmull_s8(vtmp4, wr24); + voutr10 = vaddw_s16(voutr10, vget_low_s16(tvoutr1)); + voutr11 = vaddw_s16(voutr11, vget_high_s16(tvoutr1)); + + tvoutr2 = vmull_s8(vinr30, wr10); + tvoutr2 = vmlal_s8(tvoutr2, vtmp1, wr11); + voutr20 = vaddw_s16(voutr20, vget_low_s16(tvoutr2)); + voutr21 = vaddw_s16(voutr21, vget_high_s16(tvoutr2)); + tvoutr2 = vmull_s8(vtmp2, wr12); + tvoutr2 = vmlal_s8(tvoutr2, vtmp3, wr13); + voutr20 = vaddw_s16(voutr20, vget_low_s16(tvoutr2)); + voutr21 = vaddw_s16(voutr21, vget_high_s16(tvoutr2)); + tvoutr2 = vmull_s8(vtmp4, wr14); + voutr20 = vaddw_s16(voutr20, vget_low_s16(tvoutr2)); + voutr21 = vaddw_s16(voutr21, vget_high_s16(tvoutr2)); + + // r4 + vtmp1 = vext_s8(vinr40, vinr41, 1); // 12345678 + vtmp2 = vext_s8(vinr40, vinr41, 2); // 2345678 + vtmp3 = vext_s8(vinr40, vinr41, 3); // 345678 + vtmp4 = vext_s8(vinr40, vinr41, 4); // 45678 + + tvoutr0 = vmull_s8(vinr40, wr40); + tvoutr0 = vmlal_s8(tvoutr0, vtmp1, wr41); + voutr00 = vaddw_s16(voutr00, vget_low_s16(tvoutr0)); + voutr01 = vaddw_s16(voutr01, vget_high_s16(tvoutr0)); + tvoutr0 = vmull_s8(vtmp2, wr42); + tvoutr0 = vmlal_s8(tvoutr0, vtmp3, wr43); + voutr00 = vaddw_s16(voutr00, vget_low_s16(tvoutr0)); + voutr01 = vaddw_s16(voutr01, vget_high_s16(tvoutr0)); + tvoutr0 = vmull_s8(vtmp4, wr44); + voutr00 = vaddw_s16(voutr00, vget_low_s16(tvoutr0)); + voutr01 = vaddw_s16(voutr01, vget_high_s16(tvoutr0)); + + tvoutr1 = vmull_s8(vinr40, wr30); + tvoutr1 = vmlal_s8(tvoutr1, vtmp1, wr31); + voutr10 = vaddw_s16(voutr10, vget_low_s16(tvoutr1)); + voutr11 = vaddw_s16(voutr11, vget_high_s16(tvoutr1)); + tvoutr1 = vmull_s8(vtmp2, wr32); + tvoutr1 = vmlal_s8(tvoutr1, vtmp3, wr33); + voutr10 = 
vaddw_s16(voutr10, vget_low_s16(tvoutr1)); + voutr11 = vaddw_s16(voutr11, vget_high_s16(tvoutr1)); + tvoutr1 = vmull_s8(vtmp4, wr34); + voutr10 = vaddw_s16(voutr10, vget_low_s16(tvoutr1)); + voutr11 = vaddw_s16(voutr11, vget_high_s16(tvoutr1)); + + tvoutr2 = vmull_s8(vinr40, wr20); + tvoutr2 = vmlal_s8(tvoutr2, vtmp1, wr21); + voutr20 = vaddw_s16(voutr20, vget_low_s16(tvoutr2)); + voutr21 = vaddw_s16(voutr21, vget_high_s16(tvoutr2)); + tvoutr2 = vmull_s8(vtmp2, wr22); + tvoutr2 = vmlal_s8(tvoutr2, vtmp3, wr23); + voutr20 = vaddw_s16(voutr20, vget_low_s16(tvoutr2)); + voutr21 = vaddw_s16(voutr21, vget_high_s16(tvoutr2)); + tvoutr2 = vmull_s8(vtmp4, wr24); + voutr20 = vaddw_s16(voutr20, vget_low_s16(tvoutr2)); + voutr21 = vaddw_s16(voutr21, vget_high_s16(tvoutr2)); + + // r5 + vtmp1 = vext_s8(vinr50, vinr51, 1); // 12345678 + vtmp2 = vext_s8(vinr50, vinr51, 2); // 2345678 + vtmp3 = vext_s8(vinr50, vinr51, 3); // 345678 + vtmp4 = vext_s8(vinr50, vinr51, 4); // 45678 + + tvoutr1 = vmull_s8(vinr50, wr40); + tvoutr1 = vmlal_s8(tvoutr1, vtmp1, wr41); + voutr10 = vaddw_s16(voutr10, vget_low_s16(tvoutr1)); + voutr11 = vaddw_s16(voutr11, vget_high_s16(tvoutr1)); + tvoutr1 = vmull_s8(vtmp2, wr42); + tvoutr1 = vmlal_s8(tvoutr1, vtmp3, wr43); + voutr10 = vaddw_s16(voutr10, vget_low_s16(tvoutr1)); + voutr11 = vaddw_s16(voutr11, vget_high_s16(tvoutr1)); + tvoutr1 = vmull_s8(vtmp4, wr44); + voutr10 = vaddw_s16(voutr10, vget_low_s16(tvoutr1)); + voutr11 = vaddw_s16(voutr11, vget_high_s16(tvoutr1)); + + tvoutr2 = vmull_s8(vinr50, wr30); + tvoutr2 = vmlal_s8(tvoutr2, vtmp1, wr31); + voutr20 = vaddw_s16(voutr20, vget_low_s16(tvoutr2)); + voutr21 = vaddw_s16(voutr21, vget_high_s16(tvoutr2)); + tvoutr2 = vmull_s8(vtmp2, wr32); + tvoutr2 = vmlal_s8(tvoutr2, vtmp3, wr33); + voutr20 = vaddw_s16(voutr20, vget_low_s16(tvoutr2)); + voutr21 = vaddw_s16(voutr21, vget_high_s16(tvoutr2)); + tvoutr2 = vmull_s8(vtmp4, wr34); + voutr20 = vaddw_s16(voutr20, vget_low_s16(tvoutr2)); + voutr21 = vaddw_s16(voutr21, vget_high_s16(tvoutr2)); + + // r6 + vtmp1 = vext_s8(vinr60, vinr61, 1); // 12345678 + vtmp2 = vext_s8(vinr60, vinr61, 2); // 2345678 + vtmp3 = vext_s8(vinr60, vinr61, 3); // 345678 + vtmp4 = vext_s8(vinr60, vinr61, 4); // 45678 + + tvoutr2 = vmull_s8(vinr60, wr40); + tvoutr2 = vmlal_s8(tvoutr2, vtmp1, wr41); + voutr20 = vaddw_s16(voutr20, vget_low_s16(tvoutr2)); + voutr21 = vaddw_s16(voutr21, vget_high_s16(tvoutr2)); + tvoutr2 = vmull_s8(vtmp2, wr42); + tvoutr2 = vmlal_s8(tvoutr2, vtmp3, wr43); + voutr20 = vaddw_s16(voutr20, vget_low_s16(tvoutr2)); + voutr21 = vaddw_s16(voutr21, vget_high_s16(tvoutr2)); + tvoutr2 = vmull_s8(vtmp4, wr44); + voutr20 = vaddw_s16(voutr20, vget_low_s16(tvoutr2)); + voutr21 = vaddw_s16(voutr21, vget_high_s16(tvoutr2)); + + /// data shift 8 bytes + din_ptr0 += 8; + din_ptr1 += 8; + din_ptr2 += 8; + din_ptr3 += 8; + din_ptr4 += 8; + din_ptr5 += 8; + din_ptr6 += 8; + + /// store + vst1q_s32(doutr0, voutr00); + vst1q_s32(doutr1, voutr10); + vst1q_s32(doutr2, voutr20); + doutr0 += 4; + doutr1 += 4; + doutr2 += 4; + vst1q_s32(doutr0, voutr01); + vst1q_s32(doutr1, voutr11); + vst1q_s32(doutr2, voutr21); + doutr0 += 4; + doutr1 += 4; + doutr2 += 4; + } /// end of tile_w + + dr0 = dr3; + dr1 = dr4; + dr2 = dr5; + dr3 = dr6; + dr4 = dr3 + win_round; + dr5 = dr4 + win_round; + dr6 = dr5 + win_round; + + dout_ptr = dout_ptr + 3 * wout_round; + } /// end of tile_h + + if (scales == 0) { + write_to_output_numc(pre_out, + dout_batch, + 1, + hout_round, + c, + c + 1, + 0, + hout, + 0, + wout_round, + 
chout, + hout, + wout, + flag_relu, + ptr_write); + } else if (od_type == PRECISION(kFloat)) { + write2_to_output_numc(pre_out, + reinterpret_cast<float*>(dout_batch), + 1, + hout_round, + c, + c + 1, + 0, + hout, + 0, + wout_round, + chout, + hout, + wout, + flag_relu, + reinterpret_cast<float*>(ptr_write), + scales); + } else if (od_type == PRECISION(kInt8)) { + write2_to_output_numc(pre_out, + reinterpret_cast<signed char*>(dout_batch), + 1, + hout_round, + c, + c + 1, + 0, + hout, + 0, + wout_round, + chout, + hout, + wout, + flag_relu, + reinterpret_cast<signed char*>(ptr_write), + scales); + } + // else if (od_type == AK_INT32) { + // write2_to_output_numc(pre_out, (int*)dout_batch, 1, hout_round, c, + // c+1, + // 0, hout, 0, wout_round, chout, hout, wout, flag_relu, + // (int*)ptr_write, scales); + // } + } /// end of chout + } /// end of batch num +} + +#endif // __aarch64__ + +} // namespace math +} // namespace arm +} // namespace lite +} // namespace paddle diff --git a/lite/arm/math/conv_depthwise_5x5s2.cc b/lite/arm/math/conv_depthwise_5x5s2.cc new file mode 100644 index 00000000000..17ac1d87e50 --- /dev/null +++ b/lite/arm/math/conv_depthwise_5x5s2.cc @@ -0,0 +1,3746 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License.
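The write-back branches at the end of the int8 kernel above keep quantization out of the hot loop: accumulation always happens in an int32 scratch buffer (pre_out), and only the final copy decides the output type. No scale pointer means raw int32 output, PRECISION(kFloat) means dequantize by the per-channel scale, and PRECISION(kInt8) means dequantize and re-quantize. A per-element sketch of the two conversions follows; round-to-nearest and the [-127, 127] clamp are assumptions here, the shipped behavior being whatever write2_to_output_numc implements:

#include <algorithm>
#include <cmath>
#include <cstdint>

// Illustrative per-element write-back: 'acc' is one int32 accumulator and
// 'scale' the per-channel dequantization factor.
static float acc_to_float(int32_t acc, float scale, bool relu) {
  float v = static_cast<float>(acc) * scale;  // od_type == PRECISION(kFloat)
  return relu ? std::max(v, 0.f) : v;
}

static int8_t acc_to_int8(int32_t acc, float scale, bool relu) {
  float v = static_cast<float>(acc) * scale;  // od_type == PRECISION(kInt8)
  if (relu) v = std::max(v, 0.f);
  v = std::round(v);                          // assumed rounding mode
  v = std::min(std::max(v, -127.f), 127.f);   // assumed saturation range
  return static_cast<int8_t>(v);
}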
+
+#include "lite/arm/math/conv_depthwise.h"
+#include <arm_neon.h>
+
+namespace paddle {
+namespace lite {
+namespace arm {
+namespace math {
+
+void conv_depthwise_5x5s2p2(const float* din,
+                            float* dout,
+                            int num,
+                            int ch_out,
+                            int h_out,
+                            int w_out,
+                            int ch_in,
+                            int h_in,
+                            int w_in,
+                            const float* weights,
+                            const float* bias,
+                            bool flag_bias,
+                            bool flag_relu,
+                            ARMContext* ctx);
+
+void conv_depthwise_5x5s2p2_relu(const float* din,
+                                 float* dout,
+                                 int num,
+                                 int ch_out,
+                                 int h_out,
+                                 int w_out,
+                                 int ch_in,
+                                 int h_in,
+                                 int w_in,
+                                 const float* weights,
+                                 const float* bias,
+                                 bool flag_bias,
+                                 bool flag_relu,
+                                 ARMContext* ctx);
+
+void conv_depthwise_5x5s2p2_s(const float* din,
+                              float* dout,
+                              int num,
+                              int ch_out,
+                              int h_out,
+                              int w_out,
+                              int ch_in,
+                              int h_in,
+                              int w_in,
+                              const float* weights,
+                              const float* bias,
+                              bool flag_bias,
+                              bool flag_relu,
+                              ARMContext* ctx);
+
+void conv_depthwise_5x5s2p2_relu_s(const float* din,
+                                   float* dout,
+                                   int num,
+                                   int ch_out,
+                                   int h_out,
+                                   int w_out,
+                                   int ch_in,
+                                   int h_in,
+                                   int w_in,
+                                   const float* weights,
+                                   const float* bias,
+                                   bool flag_bias,
+                                   bool flag_relu,
+                                   ARMContext* ctx);
+
+void conv_depthwise_5x5s2(const float* din,
+                          float* dout,
+                          int num,
+                          int chout,
+                          int hout,
+                          int wout,
+                          int chin,
+                          int hin,
+                          int win,
+                          const float* weights,
+                          const float* bias,
+                          int pad,
+                          bool flag_bias,
+                          bool flag_relu,
+                          ARMContext* ctx) {
+  if (pad == 2) {
+    if (win >= 9) {
+      if (flag_relu) {
+        conv_depthwise_5x5s2p2_relu(din,
+                                    dout,
+                                    num,
+                                    chout,
+                                    hout,
+                                    wout,
+                                    chin,
+                                    hin,
+                                    win,
+                                    weights,
+                                    bias,
+                                    flag_bias,
+                                    flag_relu,
+                                    ctx);
+      } else {
+        conv_depthwise_5x5s2p2(din,
+                               dout,
+                               num,
+                               chout,
+                               hout,
+                               wout,
+                               chin,
+                               hin,
+                               win,
+                               weights,
+                               bias,
+                               flag_bias,
+                               flag_relu,
+                               ctx);
+      }
+    } else {
+      if (flag_relu) {
+        conv_depthwise_5x5s2p2_relu_s(din,
+                                      dout,
+                                      num,
+                                      chout,
+                                      hout,
+                                      wout,
+                                      chin,
+                                      hin,
+                                      win,
+                                      weights,
+                                      bias,
+                                      flag_bias,
+                                      flag_relu,
+                                      ctx);
+      } else {
+        conv_depthwise_5x5s2p2_s(din,
+                                 dout,
+                                 num,
+                                 chout,
+                                 hout,
+                                 wout,
+                                 chin,
+                                 hin,
+                                 win,
+                                 weights,
+                                 bias,
+                                 flag_bias,
+                                 flag_relu,
+                                 ctx);
+      }
+    }
+  }
+}
+
+#ifdef __aarch64__
+
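+// NOTE(editor): dispatch summary for conv_depthwise_5x5s2 above, read off
+// the if/else ladder (editorial addition):
+//   pad == 2, win >= 9,  flag_relu  -> conv_depthwise_5x5s2p2_relu
+//   pad == 2, win >= 9, !flag_relu  -> conv_depthwise_5x5s2p2
+//   pad == 2, win <  9,  flag_relu  -> conv_depthwise_5x5s2p2_relu_s
+//   pad == 2, win <  9, !flag_relu  -> conv_depthwise_5x5s2p2_s
+// Any other pad value is a silent no-op, so callers must route only
+// pad == 2 shapes here.
+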
+//! larger depthwise, win >= 9;
+void conv_depthwise_5x5s2p2(const float* din,
+                            float* dout,
+                            int num,
+                            int ch_out,
+                            int h_out,
+                            int w_out,
+                            int ch_in,
+                            int h_in,
+                            int w_in,
+                            const float* weights,
+                            const float* bias,
+                            bool flag_bias,
+                            bool flag_relu,
+                            ARMContext* ctx) {
+  CHECK_GE(w_in, 9) << "only support win >= 9";
+  int w_out_round = (w_out + 3) / 4 * 4;
+  int cnt = (w_out_round - 4) / 4;
+  int mid_cnt = cnt - 1;
+  int right_start = cnt * 2 * 4 - 2;
+  int mask_cnt = 12 - (w_in - right_start);
+  int mask[12];
+  memset(mask, 0xff, 12 * sizeof(int));
+  for (int i = 0; i < mask_cnt; ++i) {
+    mask[11 - i] = 0;
+  }
+  float* zero_ptr = ctx->workspace_data<float>();
+  memset(zero_ptr, 0, w_in * sizeof(float));
+  float* write_ptr = zero_ptr + w_in;
+  int in_spatial_size = w_in * h_in;
+  int out_spatial_size = w_out * h_out;
+  int weights_spatial_size = 25;
+
+  for (int n = 0; n < num; ++n) {
+    const float* din_batch = din + n * in_spatial_size * ch_in;
+    float* dout_batch = dout + n * out_spatial_size * ch_out;
+#pragma omp parallel for
+    for (int c = 0; c < ch_in; ++c) {
+      const float* din_ch = din_batch + c * in_spatial_size;
+      float* dout_ch = dout_batch + c * out_spatial_size;
+      const float* din0 = zero_ptr;
+      const float* din1 = zero_ptr;
+      const float* din2 = din_ch;
+      const float* din3 = din2 + w_in;
+      const float* din4 = din3 + w_in;
+      const float* din5 = din4 + w_in;
+      const float* din6 = din5 + w_in;
+
+      float out_buf0[4];
+      float out_buf1[4];
+      float* dout0 = dout_ch;
+      float* dout1 = dout0 + w_out;
+
+      const float* weights_c = weights + c * weights_spatial_size;
+      for (int h = 0; h < h_out; h += 2) {
+        //! (h * 2 - 2) + 6 > h_in - 1
+        if (h * 2 + 5 > h_in) {
+          switch (h * 2 + 5 - h_in) {
+            case 6:
+              din1 = zero_ptr;
+            case 5:
+              din2 = zero_ptr;
+            case 4:
+              din3 = zero_ptr;
+            case 3:
+              din4 = zero_ptr;
+            case 2:
+              din5 = zero_ptr;
+            case 1:
+              din6 = zero_ptr;
+            default:
+              break;
+          }
+        }
+        if (h + 2 > h_out) {
+          switch (h + 2 - h_out) {
+            case 1:
+              dout1 = write_ptr;
+            default:
+              break;
+          }
+        }
+        const float* din_ptr0 = din0;
+        const float* din_ptr1 = din1;
+        const float* din_ptr2 = din2;
+        const float* din_ptr3 = din3;
+        const float* din_ptr4 = din4;
+        const float* din_ptr5 = din5;
+        const float* din_ptr6 = din6;
+
+        const float* weights_ptr = weights_c;
+        float* dout_ptr0 = dout0;
+        float* dout_ptr1 = dout1;
+
+        float bias_c = 0.f;
+        if (flag_bias) {
+          bias_c = bias[c];
+        }
+        float vbias[4] = {bias_c, bias_c, bias_c, bias_c};
+        int* mask_ptr = mask;
+        int loop = mid_cnt;
+        const int s_8 = 8;
+        const int s_16 = 16;
+
+        //! in r0, r1/r4, r2/r5, r3/r6: x 0 2 4 -- v8 v13 v18 v23
+        //! in r0, r1/r4, r2/r5, r3/r6: x 1 3 5 -- v9 v14 v19 v24
+        //! in r0, r1/r4, r2/r5, r3/r6: 0 2 4 6 -- v6 v11 v16 v21
+        //! in r0, r1/r4, r2/r5, r3/r6: 1 3 5 7 -- v7 v12 v17 v22
+        //! in r0, r1/r4, r2/r5, r3/r6: 2 4 6 8 -- v10 v15 v20 v25
+        //!
out r0, r1 -- v26, v27 + asm volatile( + "movi v31.4s, #0x0\n" + "prfm pldl1keep, [%[din_ptr0]] \n" + "prfm pldl1keep, [%[din_ptr1]] \n" + "prfm pldl1keep, [%[din_ptr2]] \n" + "prfm pldl1keep, [%[din_ptr3]] \n" + "prfm pldl1keep, [%[din_ptr4]] \n" + "prfm pldl1keep, [%[din_ptr5]] \n" + "prfm pldl1keep, [%[din_ptr6]] \n" + "prfm pldl1keep, [%[weights]] \n" + "prfm pldl1keep, [%[mask]] \n" + // left + "ld2 {v6.4s, v7.4s}, [%[din_ptr0]], #32 \n" // r0 v6: 0 + // 2 4 6, + // v7: 1 3 + // 5 7 + "ext v8.16b, v31.16b, v6.16b, #12 \n" // r0 v8: x + // 0 2 4 + "ld2 {v11.4s, v12.4s}, [%[din_ptr1]], #32 \n" // r1 v11: + // 0 2 4 6, + // v12: 1 3 + // 5 7 + "ext v9.16b, v31.16b, v7.16b, #12 \n" // r0 v9: x + // 1 3 5 + "ld1 {v0.4s, v1.4s}, [%[weights]], #32 \n" // load + // weights + // 0-7 + "ext v10.16b, v6.16b, v31.16b, #4 \n" + "ld1 {v10.s}[3], [%[din_ptr0]] \n" // r0 v10: + // 2 4 6 8 + "sub %[din_ptr0], %[din_ptr0], #8 \n" + "ext v13.16b, v31.16b, v11.16b, #12 \n" // r1 v13: + // x 0 2 4 + "ld2 {v16.4s, v17.4s}, [%[din_ptr2]], #32 \n" // r2 v16: + // 0 2 4 6, + // v17: 1 3 + // 5 7 + "ext v14.16b, v31.16b, v12.16b, #12 \n" // r1 v14: + // x 1 3 5 + "ld1 {v2.4s, v3.4s}, [%[weights]], #32 \n" // load + // weights + // 8-15 + "ext v15.16b, v11.16b, v31.16b, #4 \n" + "ld1 {v15.s}[3], [%[din_ptr1]] \n" // r1 v15: + // 2 4 6 + "sub %[din_ptr1], %[din_ptr1], #8 \n" + "ext v18.16b, v31.16b, v16.16b, #12 \n" // r2 v18: + // x 0 2 4 + "ld1 {v4.4s, v5.4s}, [%[weights]], #32 \n" // load + // weights + // 16-23 + "ext v19.16b, v31.16b, v17.16b, #12 \n" // r2 v19: + // x 1 3 5 + "ld2 {v21.4s, v22.4s}, [%[din_ptr3]], #32 \n" // r3 v21: + // 0 2 4 6, + // v22: 1 3 + // 5 7 + "ext v20.16b, v16.16b, v31.16b, #4 \n" + "ld1 {v20.s}[3], [%[din_ptr2]] \n" // r2 v20: + // 2 4 6 8 + "sub %[din_ptr2], %[din_ptr2], #8 \n" + "ext v23.16b, v31.16b, v21.16b, #12 \n" // r3 v23: + // x 0 2 4 + "ld1 {v30.4s}, [%[weights]] \n" // load + // weights + // 24 + "ext v24.16b, v31.16b, v22.16b, #12 \n" // r3 v24: + // x 1 3 5 + "ld1 {v26.4s}, [%[vbias]] \n" // load + // bias to + // out_r0 + "ext v25.16b, v21.16b, v31.16b, #4 \n" + "ld1 {v25.s}[3], [%[din_ptr3]] \n" // r2 v25: + // 2 4 6 8 + "sub %[din_ptr3], %[din_ptr3], #8 \n" + "mov v27.16b, v26.16b \n" // load + // bias to + // out_r1 + "mov v28.16b, v31.16b \n" // load + // zero to + // out_r0 + "mov v29.16b, v31.16b \n" // load + // zero to + // out_r1 + + "fmla v26.4s, v8.4s, v0.s[0] \n" // out r0: + // w0 + "fmla v28.4s, v9.4s, v0.s[1] \n" // out r0: + // w1 + "fmla v26.4s, v6.4s, v0.s[2] \n" // out r0: + // w2 + "fmla v28.4s, v7.4s, v0.s[3] \n" // out r0: + // w3 + + "ld2 {v8.4s, v9.4s}, [%[din_ptr0]], %[s_8] \n" // next r0 + // v8: 0 2 + // 4 6, v9: + // 1 3 5 7 + + "fmla v26.4s, v10.4s, v1.s[0] \n" // out r0: + // w4 + "fmla v28.4s, v13.4s, v1.s[1] \n" // out r0: + // w5 + "fmla v26.4s, v14.4s, v1.s[2] \n" // out r0: + // w6 + "fmla v28.4s, v11.4s, v1.s[3] \n" // out r0: + // w7 + + "ld2 {v6.4s, v7.4s}, [%[din_ptr0]], %[s_8] \n" // next r0 + // v6: 2 4 + // 6 8, v7: + // 3 5 7 9 + + "fmla v26.4s, v12.4s, v2.s[0] \n" // out r0: + // w8 + "fmla v28.4s, v15.4s, v2.s[1] \n" // out r0: + // w9 + "fmla v26.4s, v18.4s, v2.s[2] \n" // out r0: + // w10 + "fmla v28.4s, v19.4s, v2.s[3] \n" // out r0: + // w11 + + "ld2 {v10.4s, v11.4s}, [%[din_ptr0]], %[s_16] \n" // next r0 + // v10: 4 6 + // 8 10, + // v11: + // trash + // register + + "fmla v26.4s, v16.4s, v3.s[0] \n" // out r0: + // w12 + "fmla v28.4s, v17.4s, v3.s[1] \n" // out r0: + // w13 + "fmla v26.4s, v20.4s, v3.s[2] \n" // 
out r0: + // w14 + "fmla v28.4s, v23.4s, v3.s[3] \n" // out r0: + // w15 + "prfm pldl1keep, [%[din_ptr0]] \n" + + "ld2 {v11.4s, v12.4s}, [%[din_ptr4]], #32 \n" // r4 v11: + // 0 2 4 6, + // v12: 1 3 + // 5 7 + + "fmla v26.4s, v24.4s, v4.s[0] \n" // out r0: + // w16 + "fmla v28.4s, v21.4s, v4.s[1] \n" // out r0: + // w17 + + "ext v13.16b, v31.16b, v11.16b, #12 \n" // r4 v13: + // x 0 2 4 + "ext v14.16b, v31.16b, v12.16b, #12 \n" // r4 v14: + // x 1 3 5 + "ext v15.16b, v11.16b, v31.16b, #4 \n" + + "fmla v26.4s, v22.4s, v4.s[2] \n" // out r0: + // w18 + "fmla v28.4s, v25.4s, v4.s[3] \n" // out r0: + // w19 + + "ld1 {v15.s}[3], [%[din_ptr4]] \n" // r4 v15: + // 2 4 6 + + "fmla v27.4s, v18.4s, v0.s[0] \n" // out r1: + // w0 + "fmla v29.4s, v19.4s, v0.s[1] \n" // out r1: + // w1 + + "sub %[din_ptr4], %[din_ptr4], #8 \n" + + "fmla v27.4s, v16.4s, v0.s[2] \n" // out r1: + // w2 + "fmla v29.4s, v17.4s, v0.s[3] \n" // out r1: + // w3 + "fmla v27.4s, v20.4s, v1.s[0] \n" // out r1: + // w4 + "fmla v29.4s, v23.4s, v1.s[1] \n" // out r1: + // w5 + + "ld2 {v16.4s, v17.4s}, [%[din_ptr5]], #32 \n" // r5 v16: + // 0 2 4 6, + // v17: 1 3 + // 5 7 + + "fmla v27.4s, v24.4s, v1.s[2] \n" // out r1: + // w6 + "fmla v29.4s, v21.4s, v1.s[3] \n" // out r1: + // w7 + + "ext v18.16b, v31.16b, v16.16b, #12 \n" // r5 v18: + // x 0 2 4 + "ext v19.16b, v31.16b, v17.16b, #12 \n" // r5 v19: + // x 1 3 5 + "ext v20.16b, v16.16b, v31.16b, #4 \n" + + "fmla v27.4s, v22.4s, v2.s[0] \n" // out r1: + // w8 + "fmla v29.4s, v25.4s, v2.s[1] \n" // out r1: + // w9 + + "ld1 {v20.s}[3], [%[din_ptr5]] \n" // r5 v20: + // 2 4 6 + "ld2 {v21.4s, v22.4s}, [%[din_ptr6]], #32 \n" // r6 v21: + // 0 2 4 6, + // v22: 1 3 + // 5 7 + + "ext v23.16b, v31.16b, v21.16b, #12 \n" // r6 v23: + // x 0 2 4 + "ext v24.16b, v31.16b, v22.16b, #12 \n" // r6 v24: + // x 1 3 5 + "ext v25.16b, v21.16b, v31.16b, #4 \n" + "sub %[din_ptr5], %[din_ptr5], #8 \n" + + "fmla v26.4s, v11.4s, v5.s[2] \n" // out r0: + // w22 + "fmla v28.4s, v12.4s, v5.s[3] \n" // out r0: + // w23 + + "ld1 {v25.s}[3], [%[din_ptr6]] \n" // r6 v25: + // 2 4 6 + + "fmla v26.4s, v13.4s, v5.s[0] \n" // out r0: + // w20 + "fmla v28.4s, v14.4s, v5.s[1] \n" // out r0: + // w21 + + "sub %[din_ptr6], %[din_ptr6], #8 \n" + + "fmla v26.4s, v15.4s, v30.s[0] \n" // out r0: + // w24 + "fmla v27.4s, v13.4s, v2.s[2] \n" // out r1: + // w10 + + "fadd v26.4s, v26.4s, v28.4s \n" + "fmla v29.4s, v14.4s, v2.s[3] \n" // out r1: + // w11 + + "ld2 {v13.4s, v14.4s}, [%[din_ptr1]], %[s_8] \n" // next r1 + // v13: 0 2 + // 4 6, + // v14: 1 3 + // 5 7 + "fmla v27.4s, v11.4s, v3.s[0] \n" // out r1: + // w12 + "fmla v29.4s, v12.4s, v3.s[1] \n" // out r1: + // w13 + + "st1 {v26.4s}, [%[dout_ptr0]], %[s_16] \n" // store + // output + // r0 + "ld2 {v11.4s, v12.4s}, [%[din_ptr1]], %[s_8] \n" // next r1 + // v11: 2 4 + // 6 8, + // v12: 3 5 + // 7 9 + + "fmla v27.4s, v15.4s, v3.s[2] \n" // out r1: + // w14 + "fmla v29.4s, v16.4s, v4.s[1] \n" // out r1: + // w17 + "fmla v27.4s, v18.4s, v3.s[3] \n" // out r1: + // w15 + "fmla v29.4s, v19.4s, v4.s[0] \n" // out r1: + // w16 + + "ld2 {v15.4s, v16.4s}, [%[din_ptr1]], %[s_16] \n" // next r1 + // v15: 4 6 + // 8 10, + // v16: + // trash + // register + + "fmla v27.4s, v17.4s, v4.s[2] \n" // out r1: + // w18 + "fmla v29.4s, v20.4s, v4.s[3] \n" // out r1: + // w19 + + "ld2 {v18.4s, v19.4s}, [%[din_ptr2]], %[s_8] \n" // next r2 + // v18: 0 2 + // 4 6, + // v19: 1 3 + // 5 7 + "ld2 {v16.4s, v17.4s}, [%[din_ptr2]], %[s_8] \n" // next r2 + // v16: 2 4 + // 6 8, + // v11: 3 5 + // 7 9 + + 
"fmla v27.4s, v23.4s, v5.s[0] \n" // out r1: + // w20 + "fmla v29.4s, v21.4s, v5.s[2] \n" // out r1: + // w22 + "fmla v27.4s, v24.4s, v5.s[1] \n" // out r1: + // w21 + "fmla v29.4s, v22.4s, v5.s[3] \n" // out r1: + // w23 + + "ld2 {v20.4s, v21.4s}, [%[din_ptr2]], %[s_16] \n" // next r2 + // v20: 4 6 + // 8 10, + // v21: + // trash + // register + "ld2 {v23.4s, v24.4s}, [%[din_ptr3]], %[s_8] \n" // next r3 + // v23: 0 2 + // 4 6, + // v24: 1 3 + // 5 7 + + "fmla v27.4s, v25.4s, v30.s[0] \n" // out r1: + // w24 + + "ld2 {v21.4s, v22.4s}, [%[din_ptr3]], %[s_8] \n" // next r3 + // v21: 2 4 + // 6 8, + // v22: 3 5 + // 7 9 + "ld2 {v25.4s, v26.4s}, [%[din_ptr3]], %[s_16] \n" // next r3 + // v25: 4 6 + // 8 10, + // v26: + // trash + // register + + "fadd v27.4s, v27.4s, v29.4s \n" + "cmp %w[mid_cnt], #1 \n" + + "prfm pldl1keep, [%[din_ptr1]] \n" + "prfm pldl1keep, [%[din_ptr2]] \n" + "prfm pldl1keep, [%[din_ptr3]] \n" + + "st1 {v27.4s}, [%[dout_ptr1]], #16 \n" + "blt 2f \n" + + // mid loop + "1: \n" + "ld1 {v26.4s}, [%[vbias]] \n" + "mov v27.16b, v26.16b \n" + "mov v28.16b, v31.16b \n" + "mov v29.16b, v31.16b \n" + + // out_r0 r0-r3 + "fmla v26.4s, v8.4s, v0.s[0] \n" + "fmla v28.4s, v9.4s, v0.s[1] \n" + "fmla v26.4s, v6.4s, v0.s[2] \n" + "fmla v28.4s, v7.4s, v0.s[3] \n" + + "ld2 {v8.4s, v9.4s}, [%[din_ptr0]], %[s_8] \n" + + "fmla v26.4s, v10.4s, v1.s[0] \n" + "fmla v28.4s, v11.4s, v1.s[3] \n" + + "ld2 {v6.4s, v7.4s}, [%[din_ptr0]], %[s_8] \n" + + "fmla v26.4s, v14.4s, v1.s[2] \n" + "fmla v28.4s, v13.4s, v1.s[1] \n" + + "ld2 {v10.4s, v11.4s}, [%[din_ptr0]], %[s_16] \n" + "prfm pldl1keep, [%[din_ptr0]] \n" + + "fmla v26.4s, v12.4s, v2.s[0] \n" + "fmla v28.4s, v15.4s, v2.s[1] \n" + + "ld2 {v13.4s, v14.4s}, [%[din_ptr4]], %[s_8] \n" + + "fmla v26.4s, v16.4s, v3.s[0] \n" + "fmla v27.4s, v16.4s, v0.s[2] \n" + + "ld2 {v11.4s, v12.4s}, [%[din_ptr4]], %[s_8] \n" + + "fmla v28.4s, v19.4s, v2.s[3] \n" + "fmla v29.4s, v19.4s, v0.s[1] \n" + + "ld2 {v15.4s, v16.4s}, [%[din_ptr4]], %[s_16] \n" + "prfm pldl1keep, [%[din_ptr4]] \n" + + "fmla v26.4s, v18.4s, v2.s[2] \n" + "fmla v27.4s, v18.4s, v0.s[0] \n" + + "fmla v28.4s, v17.4s, v3.s[1] \n" + "fmla v29.4s, v17.4s, v0.s[3] \n" + + "ld2 {v18.4s, v19.4s}, [%[din_ptr5]], %[s_8] \n" + + "fmla v26.4s, v20.4s, v3.s[2] \n" + "fmla v27.4s, v20.4s, v1.s[0] \n" + + "ld2 {v16.4s, v17.4s}, [%[din_ptr5]], %[s_8] \n" + + "fmla v29.4s, v21.4s, v1.s[3] \n" + "fmla v28.4s, v21.4s, v4.s[1] \n" + "fmla v28.4s, v23.4s, v3.s[3] \n" + "fmla v29.4s, v23.4s, v1.s[1] \n" + + "ld2 {v20.4s, v21.4s}, [%[din_ptr5]], %[s_16] \n" + "prfm pldl1keep, [%[din_ptr5]] \n" + + "fmla v26.4s, v24.4s, v4.s[0] \n" + "fmla v27.4s, v24.4s, v1.s[2] \n" + + "ld2 {v23.4s, v24.4s}, [%[din_ptr6]], %[s_8] \n" + + "fmla v27.4s, v22.4s, v2.s[0] \n" + "fmla v26.4s, v22.4s, v4.s[2] \n" + + "fmla v28.4s, v25.4s, v4.s[3] \n" + "fmla v29.4s, v25.4s, v2.s[1] \n" + + "ld2 {v21.4s, v22.4s}, [%[din_ptr6]], %[s_8] \n" + "fadd v28.4s, v26.4s, v28.4s \n" + + "ld2 {v25.4s, v26.4s}, [%[din_ptr6]], %[s_16] \n" + "mov v26.16b, v31.16b \n" + "prfm pldl1keep, [%[din_ptr6]] \n" + + "fmla v26.4s, v13.4s, v5.s[0] \n" + "fmla v28.4s, v14.4s, v5.s[1] \n" + "fmla v27.4s, v13.4s, v2.s[2] \n" + "fmla v29.4s, v14.4s, v2.s[3] \n" + + "ld2 {v13.4s, v14.4s}, [%[din_ptr1]], %[s_8] \n" + + "fmla v26.4s, v11.4s, v5.s[2] \n" + "fmla v28.4s, v12.4s, v5.s[3] \n" + "fmla v27.4s, v11.4s, v3.s[0] \n" + "fmla v29.4s, v12.4s, v3.s[1] \n" + + "ld2 {v11.4s, v12.4s}, [%[din_ptr1]], %[s_8] \n" + + "fmla v26.4s, v15.4s, v30.s[0] \n" + "fmla v27.4s, v15.4s, 
v3.s[2] \n" + "fmla v29.4s, v16.4s, v4.s[1] \n" + "fmla v27.4s, v17.4s, v4.s[2] \n" + + "ld2 {v15.4s, v16.4s}, [%[din_ptr1]], %[s_16] \n" + "prfm pldl1keep, [%[din_ptr1]] \n" + + "fmla v29.4s, v18.4s, v3.s[3] \n" + "fmla v27.4s, v19.4s, v4.s[0] \n" + + "ld2 {v18.4s, v19.4s}, [%[din_ptr2]], %[s_8] \n" + + "fmla v29.4s, v20.4s, v4.s[3] \n" + + "ld2 {v16.4s, v17.4s}, [%[din_ptr2]], %[s_8] \n" + + "fmla v27.4s, v23.4s, v5.s[0] \n" + "fmla v27.4s, v21.4s, v5.s[2] \n" + + "ld2 {v20.4s, v21.4s}, [%[din_ptr2]], %[s_16] \n" + + "fmla v29.4s, v24.4s, v5.s[1] \n" + + "ld2 {v23.4s, v24.4s}, [%[din_ptr3]], %[s_8] \n" + "prfm pldl1keep, [%[din_ptr2]] \n" + + "fmla v29.4s, v22.4s, v5.s[3] \n" + + "ld2 {v21.4s, v22.4s}, [%[din_ptr3]], %[s_8] \n" + + "fmla v27.4s, v25.4s, v30.s[0] \n" + + "fadd v26.4s, v26.4s, v28.4s \n" + + "prfm pldl1keep, [%[din_ptr3]] \n" + + "fadd v27.4s, v27.4s, v29.4s \n" + + "st1 {v26.4s}, [%[dout_ptr0]], #16 \n" + "st1 {v27.4s}, [%[dout_ptr1]], #16 \n" + + "ld2 {v25.4s, v26.4s}, [%[din_ptr3]], %[s_16] \n" + "subs %w[mid_cnt], %w[mid_cnt], #1 \n" + "bne 1b \n" + + "2: \n" + "ld2 {v26.4s, v27.4s}, [%[mask]], %[s_8] \n" + "ld2 {v28.4s, v29.4s}, [%[mask]], %[s_8] \n" + "bif v8.16b, v31.16b, v26.16b \n" + "bif v9.16b, v31.16b, v27.16b \n" + "bif v6.16b, v31.16b, v28.16b \n" + "bif v7.16b, v31.16b, v29.16b \n" + + "bif v13.16b, v31.16b, v26.16b \n" + "bif v14.16b, v31.16b, v27.16b \n" + "bif v11.16b, v31.16b, v28.16b \n" + "bif v12.16b, v31.16b, v29.16b \n" + + "bif v18.16b, v31.16b, v26.16b \n" + "bif v19.16b, v31.16b, v27.16b \n" + "bif v16.16b, v31.16b, v28.16b \n" + "bif v17.16b, v31.16b, v29.16b \n" + + "bif v23.16b, v31.16b, v26.16b \n" + "bif v24.16b, v31.16b, v27.16b \n" + "bif v21.16b, v31.16b, v28.16b \n" + "bif v22.16b, v31.16b, v29.16b \n" + + "ld2 {v28.4s, v29.4s}, [%[mask]] \n" + "ld1 {v26.4s}, [%[vbias]] \n" + "mov v29.16b, v31.16b \n" + + "bif v10.16b, v31.16b, v28.16b \n" + "bif v15.16b, v31.16b, v28.16b \n" + + "mov v27.16b, v26.16b \n" + + "bif v20.16b, v31.16b, v28.16b \n" + "bif v25.16b, v31.16b, v28.16b \n" + "mov v28.16b, v31.16b \n" + + "fmla v26.4s, v8.4s, v0.s[0] \n" + "fmla v28.4s, v9.4s, v0.s[1] \n" + "fmla v26.4s, v6.4s, v0.s[2] \n" + "fmla v28.4s, v7.4s, v0.s[3] \n" + + "fmla v26.4s, v10.4s, v1.s[0] \n" + "fmla v28.4s, v13.4s, v1.s[1] \n" + "fmla v26.4s, v14.4s, v1.s[2] \n" + "fmla v28.4s, v11.4s, v1.s[3] \n" + + "sub %[mask], %[mask], #16 \n" + "ld2 {v6.4s, v7.4s}, [%[mask]], %[s_8] \n" + "ld2 {v8.4s, v9.4s}, [%[mask]], %[s_8] \n" + "ld2 {v10.4s, v11.4s}, [%[mask]] \n" + + "fmla v26.4s, v12.4s, v2.s[0] \n" + "fmla v28.4s, v15.4s, v2.s[1] \n" + + "ld2 {v13.4s, v14.4s}, [%[din_ptr4]], %[s_8] \n" + + "fmla v26.4s, v16.4s, v3.s[0] \n" + "fmla v28.4s, v17.4s, v3.s[1] \n" + + "ld2 {v11.4s, v12.4s}, [%[din_ptr4]], %[s_8] \n" + + "fmla v27.4s, v16.4s, v0.s[2] \n" + "fmla v29.4s, v17.4s, v0.s[3] \n" + + "ld2 {v15.4s, v16.4s}, [%[din_ptr4]] \n" + + "fmla v26.4s, v18.4s, v2.s[2] \n" + "fmla v28.4s, v19.4s, v2.s[3] \n" + "fmla v27.4s, v18.4s, v0.s[0] \n" + "fmla v29.4s, v19.4s, v0.s[1] \n" + + "bif v13.16b, v31.16b, v6.16b \n" + "bif v14.16b, v31.16b, v7.16b \n" + "bif v11.16b, v31.16b, v8.16b \n" + "bif v12.16b, v31.16b, v9.16b \n" + "bif v15.16b, v31.16b, v10.16b \n" + + "ld2 {v18.4s, v19.4s}, [%[din_ptr5]], %[s_8] \n" + + "fmla v26.4s, v20.4s, v3.s[2] \n" + "fmla v27.4s, v20.4s, v1.s[0] \n" + + "ld2 {v16.4s, v17.4s}, [%[din_ptr5]], %[s_8] \n" + + "fmla v29.4s, v21.4s, v1.s[3] \n" + "fmla v28.4s, v21.4s, v4.s[1] \n" + + "ld2 {v20.4s, v21.4s}, [%[din_ptr5]] \n" + + 
"fmla v28.4s, v23.4s, v3.s[3] \n" + "fmla v29.4s, v23.4s, v1.s[1] \n" + "fmla v27.4s, v24.4s, v1.s[2] \n" + "fmla v26.4s, v24.4s, v4.s[0] \n" + + "bif v18.16b, v31.16b, v6.16b \n" + "bif v19.16b, v31.16b, v7.16b \n" + "bif v16.16b, v31.16b, v8.16b \n" + "bif v17.16b, v31.16b, v9.16b \n" + "bif v20.16b, v31.16b, v10.16b \n" + + "ld2 {v23.4s, v24.4s}, [%[din_ptr6]], %[s_8] \n" + + "fmla v27.4s, v22.4s, v2.s[0] \n" + "fmla v26.4s, v22.4s, v4.s[2] \n" + + "ld2 {v21.4s, v22.4s}, [%[din_ptr6]], %[s_8] \n" + + "fmla v28.4s, v25.4s, v4.s[3] \n" + "fmla v29.4s, v25.4s, v2.s[1] \n" + "fadd v28.4s, v28.4s, v26.4s \n" + + "ld2 {v25.4s, v26.4s}, [%[din_ptr6]] \n" + "mov v26.16b, v31.16b \n" + + "bif v23.16b, v31.16b, v6.16b \n" + "bif v24.16b, v31.16b, v7.16b \n" + "bif v21.16b, v31.16b, v8.16b \n" + "bif v22.16b, v31.16b, v9.16b \n" + "bif v25.16b, v31.16b, v10.16b \n" + + "fmla v26.4s, v13.4s, v5.s[0] \n" + "fmla v28.4s, v14.4s, v5.s[1] \n" + "fmla v26.4s, v11.4s, v5.s[2] \n" + "fmla v28.4s, v12.4s, v5.s[3] \n" + "fmla v26.4s, v15.4s, v30.s[0] \n" + + "fmla v27.4s, v13.4s, v2.s[2] \n" + "fmla v29.4s, v14.4s, v2.s[3] \n" + "fmla v27.4s, v11.4s, v3.s[0] \n" + "fmla v29.4s, v12.4s, v3.s[1] \n" + + "fadd v26.4s, v26.4s, v28.4s \n" + "fmla v27.4s, v15.4s, v3.s[2] \n" + "fmla v29.4s, v18.4s, v3.s[3] \n" + "fmla v27.4s, v19.4s, v4.s[0] \n" + "fmla v29.4s, v16.4s, v4.s[1] \n" + + "st1 {v26.4s}, [%[out_buf0]] \n" + "fmla v27.4s, v17.4s, v4.s[2] \n" + "fmla v29.4s, v20.4s, v4.s[3] \n" + "fmla v27.4s, v23.4s, v5.s[0] \n" + "fmla v29.4s, v24.4s, v5.s[1] \n" + + "fmla v27.4s, v21.4s, v5.s[2] \n" + "fmla v29.4s, v22.4s, v5.s[3] \n" + "fmla v27.4s, v25.4s, v30.s[0] \n" + "fadd v27.4s, v27.4s, v29.4s \n" + + "st1 {v27.4s}, [%[out_buf1]] \n" + + : [dout_ptr0] "+r"(dout_ptr0), + [dout_ptr1] "+r"(dout_ptr1), + [mid_cnt] "+r"(loop), + [din_ptr0] "+r"(din_ptr0), + [din_ptr1] "+r"(din_ptr1), + [din_ptr2] "+r"(din_ptr2), + [din_ptr3] "+r"(din_ptr3), + [din_ptr4] "+r"(din_ptr4), + [din_ptr5] "+r"(din_ptr5), + [din_ptr6] "+r"(din_ptr6), + [mask] "+r"(mask_ptr), + [weights] "+r"(weights_ptr) + : [vbias] "r"(vbias), + [out_buf0] "r"(out_buf0), + [out_buf1] "r"(out_buf1), + [s_8] "r"(s_8), + [s_16] "r"(s_16) + : "memory", + "cc", + "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15", + "v16", + "v17", + "v18", + "v19", + "v20", + "v21", + "v22", + "v23", + "v24", + "v25", + "v26", + "v27", + "v28", + "v29", + "v30", + "v31"); + + int remain_cnt = w_out - (mid_cnt + 1) * 4; + for (int i = 0; i < remain_cnt; ++i) { + dout_ptr0[i] = out_buf0[i]; + dout_ptr1[i] = out_buf1[i]; + } + din0 = din4; + din1 = din5; + din2 = din6; + din3 = din6 + w_in; + din4 = din3 + w_in; + din5 = din4 + w_in; + din6 = din5 + w_in; + dout0 = dout1 + w_out; + dout1 = dout0 + w_out; + } + } + } +} + +//! 
+//! larger depthwise, win >= 9;
+void conv_depthwise_5x5s2p2_relu(const float* din,
+                                 float* dout,
+                                 int num,
+                                 int ch_out,
+                                 int h_out,
+                                 int w_out,
+                                 int ch_in,
+                                 int h_in,
+                                 int w_in,
+                                 const float* weights,
+                                 const float* bias,
+                                 bool flag_bias,
+                                 bool flag_relu,
+                                 ARMContext* ctx) {
+  CHECK_GE(w_in, 9) << "only support win >= 9";
+  int w_out_round = (w_out + 3) / 4 * 4;
+  int cnt = (w_out_round - 4) / 4;
+  int mid_cnt = cnt - 1;
+  int right_start = cnt * 2 * 4 - 2;
+  int mask_cnt = 12 - (w_in - right_start);
+  int mask[12];
+  memset(mask, 0xff, 12 * sizeof(int));
+  for (int i = 0; i < mask_cnt; ++i) {
+    mask[11 - i] = 0;
+  }
+  float* zero_ptr = ctx->workspace_data<float>();
+  memset(zero_ptr, 0, w_in * sizeof(float));
+  float* write_ptr = zero_ptr + w_in;
+  int in_spatial_size = w_in * h_in;
+  int out_spatial_size = w_out * h_out;
+  int weights_spatial_size = 25;
+
+  for (int n = 0; n < num; ++n) {
+    const float* din_batch = din + n * in_spatial_size * ch_in;
+    float* dout_batch = dout + n * out_spatial_size * ch_out;
+
+#pragma omp parallel for
+    for (int c = 0; c < ch_in; ++c) {
+      const float* din_ch = din_batch + c * in_spatial_size;
+      float* dout_ch = dout_batch + c * out_spatial_size;
+      const float* din0 = zero_ptr;
+      const float* din1 = zero_ptr;
+      const float* din2 = din_ch;
+      const float* din3 = din2 + w_in;
+      const float* din4 = din3 + w_in;
+      const float* din5 = din4 + w_in;
+      const float* din6 = din5 + w_in;
+
+      float out_buf0[4];
+      float out_buf1[4];
+      float* dout0 = dout_ch;
+      float* dout1 = dout0 + w_out;
+
+      const float* weights_c = weights + c * weights_spatial_size;
+      for (int h = 0; h < h_out; h += 2) {
+        //! (h * 2 - 2) + 6 > h_in - 1
+        if (h * 2 + 5 > h_in) {
+          switch (h * 2 + 5 - h_in) {
+            case 6:
+              din1 = zero_ptr;
+            case 5:
+              din2 = zero_ptr;
+            case 4:
+              din3 = zero_ptr;
+            case 3:
+              din4 = zero_ptr;
+            case 2:
+              din5 = zero_ptr;
+            case 1:
+              din6 = zero_ptr;
+            default:
+              break;
+          }
+        }
+        if (h + 2 > h_out) {
+          switch (h + 2 - h_out) {
+            case 1:
+              dout1 = write_ptr;
+            default:
+              break;
+          }
+        }
+        const float* din_ptr0 = din0;
+        const float* din_ptr1 = din1;
+        const float* din_ptr2 = din2;
+        const float* din_ptr3 = din3;
+        const float* din_ptr4 = din4;
+        const float* din_ptr5 = din5;
+        const float* din_ptr6 = din6;
+
+        const float* weights_ptr = weights_c;
+        float* dout_ptr0 = dout0;
+        float* dout_ptr1 = dout1;
+
+        float bias_c = 0.f;
+        if (flag_bias) {
+          bias_c = bias[c];
+        }
+        float vbias[4] = {bias_c, bias_c, bias_c, bias_c};
+        int* mask_ptr = mask;
+        int loop = mid_cnt;
+        const int s_8 = 8;
+        const int s_16 = 16;
+
+        //! in r0, r1/r4, r2/r5, r3/r6: x 0 2 4 -- v8 v13 v18 v23
+        //! in r0, r1/r4, r2/r5, r3/r6: x 1 3 5 -- v9 v14 v19 v24
+        //! in r0, r1/r4, r2/r5, r3/r6: 0 2 4 6 -- v6 v11 v16 v21
+        //! in r0, r1/r4, r2/r5, r3/r6: 1 3 5 7 -- v7 v12 v17 v22
+        //! in r0, r1/r4, r2/r5, r3/r6: 2 4 6 8 -- v10 v15 v20 v25
+        //!
out r0, r1 -- v26, v27 + asm volatile( + "movi v31.4s, #0x0\n" + "prfm pldl1keep, [%[din_ptr0]] \n" + "prfm pldl1keep, [%[din_ptr1]] \n" + "prfm pldl1keep, [%[din_ptr2]] \n" + "prfm pldl1keep, [%[din_ptr3]] \n" + "prfm pldl1keep, [%[din_ptr4]] \n" + "prfm pldl1keep, [%[din_ptr5]] \n" + "prfm pldl1keep, [%[din_ptr6]] \n" + "prfm pldl1keep, [%[weights]] \n" + "prfm pldl1keep, [%[mask]] \n" + // left + "ld2 {v6.4s, v7.4s}, [%[din_ptr0]], #32 \n" // r0 v6: 0 + // 2 4 6, + // v7: 1 3 + // 5 7 + "ext v8.16b, v31.16b, v6.16b, #12 \n" // r0 v8: x + // 0 2 4 + "ld2 {v11.4s, v12.4s}, [%[din_ptr1]], #32 \n" // r1 v11: + // 0 2 4 6, + // v12: 1 3 + // 5 7 + "ext v9.16b, v31.16b, v7.16b, #12 \n" // r0 v9: x + // 1 3 5 + "ld1 {v0.4s, v1.4s}, [%[weights]], #32 \n" // load + // weights + // 0-7 + "ext v10.16b, v6.16b, v31.16b, #4 \n" + "ld1 {v10.s}[3], [%[din_ptr0]] \n" // r0 v10: + // 2 4 6 8 + "sub %[din_ptr0], %[din_ptr0], #8 \n" + "ext v13.16b, v31.16b, v11.16b, #12 \n" // r1 v13: + // x 0 2 4 + "ld2 {v16.4s, v17.4s}, [%[din_ptr2]], #32 \n" // r2 v16: + // 0 2 4 6, + // v17: 1 3 + // 5 7 + "ext v14.16b, v31.16b, v12.16b, #12 \n" // r1 v14: + // x 1 3 5 + "ld1 {v2.4s, v3.4s}, [%[weights]], #32 \n" // load + // weights + // 8-15 + "ext v15.16b, v11.16b, v31.16b, #4 \n" + "ld1 {v15.s}[3], [%[din_ptr1]] \n" // r1 v15: + // 2 4 6 + "sub %[din_ptr1], %[din_ptr1], #8 \n" + "ext v18.16b, v31.16b, v16.16b, #12 \n" // r2 v18: + // x 0 2 4 + "ld1 {v4.4s, v5.4s}, [%[weights]], #32 \n" // load + // weights + // 16-23 + "ext v19.16b, v31.16b, v17.16b, #12 \n" // r2 v19: + // x 1 3 5 + "ld2 {v21.4s, v22.4s}, [%[din_ptr3]], #32 \n" // r3 v21: + // 0 2 4 6, + // v22: 1 3 + // 5 7 + "ext v20.16b, v16.16b, v31.16b, #4 \n" + "ld1 {v20.s}[3], [%[din_ptr2]] \n" // r2 v20: + // 2 4 6 8 + "sub %[din_ptr2], %[din_ptr2], #8 \n" + "ext v23.16b, v31.16b, v21.16b, #12 \n" // r3 v23: + // x 0 2 4 + "ld1 {v30.4s}, [%[weights]] \n" // load + // weights + // 24 + "ext v24.16b, v31.16b, v22.16b, #12 \n" // r3 v24: + // x 1 3 5 + "ld1 {v26.4s}, [%[vbias]] \n" // load + // bias to + // out_r0 + "ext v25.16b, v21.16b, v31.16b, #4 \n" + "ld1 {v25.s}[3], [%[din_ptr3]] \n" // r2 v25: + // 2 4 6 8 + "sub %[din_ptr3], %[din_ptr3], #8 \n" + "mov v27.16b, v26.16b \n" // load + // bias to + // out_r1 + "mov v28.16b, v31.16b \n" // load + // zero to + // out_r0 + "mov v29.16b, v31.16b \n" // load + // zero to + // out_r1 + + "fmla v26.4s, v8.4s, v0.s[0] \n" // out r0: + // w0 + "fmla v28.4s, v9.4s, v0.s[1] \n" // out r0: + // w1 + "fmla v26.4s, v6.4s, v0.s[2] \n" // out r0: + // w2 + "fmla v28.4s, v7.4s, v0.s[3] \n" // out r0: + // w3 + + "ld2 {v8.4s, v9.4s}, [%[din_ptr0]], %[s_8] \n" // next r0 + // v8: 0 2 + // 4 6, v9: + // 1 3 5 7 + + "fmla v26.4s, v10.4s, v1.s[0] \n" // out r0: + // w4 + "fmla v28.4s, v13.4s, v1.s[1] \n" // out r0: + // w5 + "fmla v26.4s, v14.4s, v1.s[2] \n" // out r0: + // w6 + "fmla v28.4s, v11.4s, v1.s[3] \n" // out r0: + // w7 + + "ld2 {v6.4s, v7.4s}, [%[din_ptr0]], %[s_8] \n" // next r0 + // v6: 2 4 + // 6 8, v7: + // 3 5 7 9 + + "fmla v26.4s, v12.4s, v2.s[0] \n" // out r0: + // w8 + "fmla v28.4s, v15.4s, v2.s[1] \n" // out r0: + // w9 + "fmla v26.4s, v18.4s, v2.s[2] \n" // out r0: + // w10 + "fmla v28.4s, v19.4s, v2.s[3] \n" // out r0: + // w11 + + "ld2 {v10.4s, v11.4s}, [%[din_ptr0]], %[s_16] \n" // next r0 + // v10: 4 6 + // 8 10, + // v11: + // trash + // register + + "fmla v26.4s, v16.4s, v3.s[0] \n" // out r0: + // w12 + "fmla v28.4s, v17.4s, v3.s[1] \n" // out r0: + // w13 + "fmla v26.4s, v20.4s, v3.s[2] \n" // 
out r0: + // w14 + "fmla v28.4s, v23.4s, v3.s[3] \n" // out r0: + // w15 + "prfm pldl1keep, [%[din_ptr0]] \n" + + "ld2 {v11.4s, v12.4s}, [%[din_ptr4]], #32 \n" // r4 v11: + // 0 2 4 6, + // v12: 1 3 + // 5 7 + + "fmla v26.4s, v24.4s, v4.s[0] \n" // out r0: + // w16 + "fmla v28.4s, v21.4s, v4.s[1] \n" // out r0: + // w17 + + "ext v13.16b, v31.16b, v11.16b, #12 \n" // r4 v13: + // x 0 2 4 + "ext v14.16b, v31.16b, v12.16b, #12 \n" // r4 v14: + // x 1 3 5 + "ext v15.16b, v11.16b, v31.16b, #4 \n" + + "fmla v26.4s, v22.4s, v4.s[2] \n" // out r0: + // w18 + "fmla v28.4s, v25.4s, v4.s[3] \n" // out r0: + // w19 + + "ld1 {v15.s}[3], [%[din_ptr4]] \n" // r4 v15: + // 2 4 6 + + "fmla v27.4s, v18.4s, v0.s[0] \n" // out r1: + // w0 + "fmla v29.4s, v19.4s, v0.s[1] \n" // out r1: + // w1 + + "sub %[din_ptr4], %[din_ptr4], #8 \n" + + "fmla v27.4s, v16.4s, v0.s[2] \n" // out r1: + // w2 + "fmla v29.4s, v17.4s, v0.s[3] \n" // out r1: + // w3 + "fmla v27.4s, v20.4s, v1.s[0] \n" // out r1: + // w4 + "fmla v29.4s, v23.4s, v1.s[1] \n" // out r1: + // w5 + + "ld2 {v16.4s, v17.4s}, [%[din_ptr5]], #32 \n" // r5 v16: + // 0 2 4 6, + // v17: 1 3 + // 5 7 + + "fmla v27.4s, v24.4s, v1.s[2] \n" // out r1: + // w6 + "fmla v29.4s, v21.4s, v1.s[3] \n" // out r1: + // w7 + + "ext v18.16b, v31.16b, v16.16b, #12 \n" // r5 v18: + // x 0 2 4 + "ext v19.16b, v31.16b, v17.16b, #12 \n" // r5 v19: + // x 1 3 5 + "ext v20.16b, v16.16b, v31.16b, #4 \n" + + "fmla v27.4s, v22.4s, v2.s[0] \n" // out r1: + // w8 + "fmla v29.4s, v25.4s, v2.s[1] \n" // out r1: + // w9 + + "ld1 {v20.s}[3], [%[din_ptr5]] \n" // r5 v20: + // 2 4 6 + "ld2 {v21.4s, v22.4s}, [%[din_ptr6]], #32 \n" // r6 v21: + // 0 2 4 6, + // v22: 1 3 + // 5 7 + + "ext v23.16b, v31.16b, v21.16b, #12 \n" // r6 v23: + // x 0 2 4 + "ext v24.16b, v31.16b, v22.16b, #12 \n" // r6 v24: + // x 1 3 5 + "ext v25.16b, v21.16b, v31.16b, #4 \n" + "sub %[din_ptr5], %[din_ptr5], #8 \n" + + "fmla v26.4s, v11.4s, v5.s[2] \n" // out r0: + // w22 + "fmla v28.4s, v12.4s, v5.s[3] \n" // out r0: + // w23 + + "ld1 {v25.s}[3], [%[din_ptr6]] \n" // r6 v25: + // 2 4 6 + + "fmla v26.4s, v13.4s, v5.s[0] \n" // out r0: + // w20 + "fmla v28.4s, v14.4s, v5.s[1] \n" // out r0: + // w21 + + "sub %[din_ptr6], %[din_ptr6], #8 \n" + + "fmla v26.4s, v15.4s, v30.s[0] \n" // out r0: + // w24 + "fmla v27.4s, v13.4s, v2.s[2] \n" // out r1: + // w10 + + "fadd v26.4s, v26.4s, v28.4s \n" + "fmla v29.4s, v14.4s, v2.s[3] \n" // out r1: + // w11 + "fmax v26.4s, v26.4s, v31.4s \n" + + "ld2 {v13.4s, v14.4s}, [%[din_ptr1]], %[s_8] \n" // next r1 + // v13: 0 2 + // 4 6, + // v14: 1 3 + // 5 7 + "fmla v27.4s, v11.4s, v3.s[0] \n" // out r1: + // w12 + "fmla v29.4s, v12.4s, v3.s[1] \n" // out r1: + // w13 + + "st1 {v26.4s}, [%[dout_ptr0]], %[s_16] \n" // store + // output + // r0 + "ld2 {v11.4s, v12.4s}, [%[din_ptr1]], %[s_8] \n" // next r1 + // v11: 2 4 + // 6 8, + // v12: 3 5 + // 7 9 + + "fmla v27.4s, v15.4s, v3.s[2] \n" // out r1: + // w14 + "fmla v29.4s, v16.4s, v4.s[1] \n" // out r1: + // w17 + "fmla v27.4s, v18.4s, v3.s[3] \n" // out r1: + // w15 + "fmla v29.4s, v19.4s, v4.s[0] \n" // out r1: + // w16 + + "ld2 {v15.4s, v16.4s}, [%[din_ptr1]], %[s_16] \n" // next r1 + // v15: 4 6 + // 8 10, + // v16: + // trash + // register + + "fmla v27.4s, v17.4s, v4.s[2] \n" // out r1: + // w18 + "fmla v29.4s, v20.4s, v4.s[3] \n" // out r1: + // w19 + + "ld2 {v18.4s, v19.4s}, [%[din_ptr2]], %[s_8] \n" // next r2 + // v18: 0 2 + // 4 6, + // v19: 1 3 + // 5 7 + "ld2 {v16.4s, v17.4s}, [%[din_ptr2]], %[s_8] \n" // next r2 + // v16: 2 4 + 
// 6 8, + // v11: 3 5 + // 7 9 + + "fmla v27.4s, v23.4s, v5.s[0] \n" // out r1: + // w20 + "fmla v29.4s, v21.4s, v5.s[2] \n" // out r1: + // w22 + "fmla v27.4s, v24.4s, v5.s[1] \n" // out r1: + // w21 + "fmla v29.4s, v22.4s, v5.s[3] \n" // out r1: + // w23 + + "ld2 {v20.4s, v21.4s}, [%[din_ptr2]], %[s_16] \n" // next r2 + // v20: 4 6 + // 8 10, + // v21: + // trash + // register + "ld2 {v23.4s, v24.4s}, [%[din_ptr3]], %[s_8] \n" // next r3 + // v23: 0 2 + // 4 6, + // v24: 1 3 + // 5 7 + + "fmla v27.4s, v25.4s, v30.s[0] \n" // out r1: + // w24 + + "ld2 {v21.4s, v22.4s}, [%[din_ptr3]], %[s_8] \n" // next r3 + // v21: 2 4 + // 6 8, + // v22: 3 5 + // 7 9 + "ld2 {v25.4s, v26.4s}, [%[din_ptr3]], %[s_16] \n" // next r3 + // v25: 4 6 + // 8 10, + // v26: + // trash + // register + + "fadd v27.4s, v27.4s, v29.4s \n" + "fmax v27.4s, v27.4s, v31.4s \n" + "cmp %w[mid_cnt], #1 \n" + "prfm pldl1keep, [%[din_ptr1]] \n" + "prfm pldl1keep, [%[din_ptr2]] \n" + "prfm pldl1keep, [%[din_ptr3]] \n" + "st1 {v27.4s}, [%[dout_ptr1]], #16 \n" + "blt 2f \n" + + // mid loop + "1: \n" + "ld1 {v26.4s}, [%[vbias]] \n" + "mov v27.16b, v26.16b \n" + "mov v28.16b, v31.16b \n" + "mov v29.16b, v31.16b \n" + + // out_r0 r0-r3 + "fmla v26.4s, v8.4s, v0.s[0] \n" + "fmla v28.4s, v9.4s, v0.s[1] \n" + "fmla v26.4s, v6.4s, v0.s[2] \n" + "fmla v28.4s, v7.4s, v0.s[3] \n" + + "ld2 {v8.4s, v9.4s}, [%[din_ptr0]], %[s_8] \n" + + "fmla v26.4s, v10.4s, v1.s[0] \n" + "fmla v28.4s, v11.4s, v1.s[3] \n" + + "ld2 {v6.4s, v7.4s}, [%[din_ptr0]], %[s_8] \n" + + "fmla v26.4s, v14.4s, v1.s[2] \n" + "fmla v28.4s, v13.4s, v1.s[1] \n" + + "ld2 {v10.4s, v11.4s}, [%[din_ptr0]], %[s_16] \n" + "prfm pldl1keep, [%[din_ptr0]] \n" + + "fmla v26.4s, v12.4s, v2.s[0] \n" + "fmla v28.4s, v15.4s, v2.s[1] \n" + + "ld2 {v13.4s, v14.4s}, [%[din_ptr4]], %[s_8] \n" + + "fmla v26.4s, v16.4s, v3.s[0] \n" + "fmla v27.4s, v16.4s, v0.s[2] \n" + + "ld2 {v11.4s, v12.4s}, [%[din_ptr4]], %[s_8] \n" + + "fmla v28.4s, v19.4s, v2.s[3] \n" + "fmla v29.4s, v19.4s, v0.s[1] \n" + + "ld2 {v15.4s, v16.4s}, [%[din_ptr4]], %[s_16] \n" + "prfm pldl1keep, [%[din_ptr4]] \n" + + "fmla v26.4s, v18.4s, v2.s[2] \n" + "fmla v27.4s, v18.4s, v0.s[0] \n" + + "fmla v28.4s, v17.4s, v3.s[1] \n" + "fmla v29.4s, v17.4s, v0.s[3] \n" + + "ld2 {v18.4s, v19.4s}, [%[din_ptr5]], %[s_8] \n" + + "fmla v26.4s, v20.4s, v3.s[2] \n" + "fmla v27.4s, v20.4s, v1.s[0] \n" + + "ld2 {v16.4s, v17.4s}, [%[din_ptr5]], %[s_8] \n" + + "fmla v29.4s, v21.4s, v1.s[3] \n" + "fmla v28.4s, v21.4s, v4.s[1] \n" + "fmla v28.4s, v23.4s, v3.s[3] \n" + "fmla v29.4s, v23.4s, v1.s[1] \n" + + "ld2 {v20.4s, v21.4s}, [%[din_ptr5]], %[s_16] \n" + "prfm pldl1keep, [%[din_ptr5]] \n" + + "fmla v26.4s, v24.4s, v4.s[0] \n" + "fmla v27.4s, v24.4s, v1.s[2] \n" + + "ld2 {v23.4s, v24.4s}, [%[din_ptr6]], %[s_8] \n" + + "fmla v27.4s, v22.4s, v2.s[0] \n" + "fmla v26.4s, v22.4s, v4.s[2] \n" + + "fmla v28.4s, v25.4s, v4.s[3] \n" + "fmla v29.4s, v25.4s, v2.s[1] \n" + + "ld2 {v21.4s, v22.4s}, [%[din_ptr6]], %[s_8] \n" + "fadd v28.4s, v26.4s, v28.4s \n" + + "ld2 {v25.4s, v26.4s}, [%[din_ptr6]], %[s_16] \n" + "mov v26.16b, v31.16b \n" + "prfm pldl1keep, [%[din_ptr6]] \n" + + "fmla v26.4s, v13.4s, v5.s[0] \n" + "fmla v28.4s, v14.4s, v5.s[1] \n" + "fmla v27.4s, v13.4s, v2.s[2] \n" + "fmla v29.4s, v14.4s, v2.s[3] \n" + + "ld2 {v13.4s, v14.4s}, [%[din_ptr1]], %[s_8] \n" + + "fmla v26.4s, v11.4s, v5.s[2] \n" + "fmla v28.4s, v12.4s, v5.s[3] \n" + "fmla v27.4s, v11.4s, v3.s[0] \n" + "fmla v29.4s, v12.4s, v3.s[1] \n" + + "ld2 {v11.4s, v12.4s}, [%[din_ptr1]], %[s_8] 
\n" + + "fmla v26.4s, v15.4s, v30.s[0] \n" + "fmla v27.4s, v15.4s, v3.s[2] \n" + "fmla v29.4s, v16.4s, v4.s[1] \n" + "fmla v27.4s, v17.4s, v4.s[2] \n" + + "ld2 {v15.4s, v16.4s}, [%[din_ptr1]], %[s_16] \n" + "prfm pldl1keep, [%[din_ptr1]] \n" + + "fmla v29.4s, v18.4s, v3.s[3] \n" + "fmla v27.4s, v19.4s, v4.s[0] \n" + + "ld2 {v18.4s, v19.4s}, [%[din_ptr2]], %[s_8] \n" + + "fmla v29.4s, v20.4s, v4.s[3] \n" + + "ld2 {v16.4s, v17.4s}, [%[din_ptr2]], %[s_8] \n" + + "fmla v27.4s, v23.4s, v5.s[0] \n" + "fmla v27.4s, v21.4s, v5.s[2] \n" + + "ld2 {v20.4s, v21.4s}, [%[din_ptr2]], %[s_16] \n" + + "fmla v29.4s, v24.4s, v5.s[1] \n" + + "ld2 {v23.4s, v24.4s}, [%[din_ptr3]], %[s_8] \n" + "prfm pldl1keep, [%[din_ptr2]] \n" + + "fmla v29.4s, v22.4s, v5.s[3] \n" + + "ld2 {v21.4s, v22.4s}, [%[din_ptr3]], %[s_8] \n" + + "fmla v27.4s, v25.4s, v30.s[0] \n" + + "fadd v26.4s, v26.4s, v28.4s \n" + "fadd v27.4s, v27.4s, v29.4s \n" + "fmax v26.4s, v26.4s, v31.4s \n" + "fmax v27.4s, v27.4s, v31.4s \n" + + "prfm pldl1keep, [%[din_ptr3]] \n" + "st1 {v26.4s}, [%[dout_ptr0]], #16 \n" + "st1 {v27.4s}, [%[dout_ptr1]], #16 \n" + + "ld2 {v25.4s, v26.4s}, [%[din_ptr3]], %[s_16] \n" + "subs %w[mid_cnt], %w[mid_cnt], #1 \n" + "bne 1b \n" + + "2: \n" + "ld2 {v26.4s, v27.4s}, [%[mask]], %[s_8] \n" + "ld2 {v28.4s, v29.4s}, [%[mask]], %[s_8] \n" + "bif v8.16b, v31.16b, v26.16b \n" + "bif v9.16b, v31.16b, v27.16b \n" + "bif v6.16b, v31.16b, v28.16b \n" + "bif v7.16b, v31.16b, v29.16b \n" + + "bif v13.16b, v31.16b, v26.16b \n" + "bif v14.16b, v31.16b, v27.16b \n" + "bif v11.16b, v31.16b, v28.16b \n" + "bif v12.16b, v31.16b, v29.16b \n" + + "bif v18.16b, v31.16b, v26.16b \n" + "bif v19.16b, v31.16b, v27.16b \n" + "bif v16.16b, v31.16b, v28.16b \n" + "bif v17.16b, v31.16b, v29.16b \n" + + "bif v23.16b, v31.16b, v26.16b \n" + "bif v24.16b, v31.16b, v27.16b \n" + "bif v21.16b, v31.16b, v28.16b \n" + "bif v22.16b, v31.16b, v29.16b \n" + + "ld2 {v28.4s, v29.4s}, [%[mask]] \n" + "ld1 {v26.4s}, [%[vbias]] \n" + "mov v29.16b, v31.16b \n" + + "bif v10.16b, v31.16b, v28.16b \n" + "bif v15.16b, v31.16b, v28.16b \n" + + "mov v27.16b, v26.16b \n" + + "bif v20.16b, v31.16b, v28.16b \n" + "bif v25.16b, v31.16b, v28.16b \n" + "mov v28.16b, v31.16b \n" + + "fmla v26.4s, v8.4s, v0.s[0] \n" + "fmla v28.4s, v9.4s, v0.s[1] \n" + "fmla v26.4s, v6.4s, v0.s[2] \n" + "fmla v28.4s, v7.4s, v0.s[3] \n" + + "fmla v26.4s, v10.4s, v1.s[0] \n" + "fmla v28.4s, v13.4s, v1.s[1] \n" + "fmla v26.4s, v14.4s, v1.s[2] \n" + "fmla v28.4s, v11.4s, v1.s[3] \n" + + "sub %[mask], %[mask], #16 \n" + "ld2 {v6.4s, v7.4s}, [%[mask]], %[s_8] \n" + "ld2 {v8.4s, v9.4s}, [%[mask]], %[s_8] \n" + "ld2 {v10.4s, v11.4s}, [%[mask]] \n" + + "fmla v26.4s, v12.4s, v2.s[0] \n" + "fmla v28.4s, v15.4s, v2.s[1] \n" + + "ld2 {v13.4s, v14.4s}, [%[din_ptr4]], %[s_8] \n" + + "fmla v26.4s, v16.4s, v3.s[0] \n" + "fmla v28.4s, v17.4s, v3.s[1] \n" + + "ld2 {v11.4s, v12.4s}, [%[din_ptr4]], %[s_8] \n" + + "fmla v27.4s, v16.4s, v0.s[2] \n" + "fmla v29.4s, v17.4s, v0.s[3] \n" + + "ld2 {v15.4s, v16.4s}, [%[din_ptr4]] \n" + + "fmla v26.4s, v18.4s, v2.s[2] \n" + "fmla v28.4s, v19.4s, v2.s[3] \n" + "fmla v27.4s, v18.4s, v0.s[0] \n" + "fmla v29.4s, v19.4s, v0.s[1] \n" + + "bif v13.16b, v31.16b, v6.16b \n" + "bif v14.16b, v31.16b, v7.16b \n" + "bif v11.16b, v31.16b, v8.16b \n" + "bif v12.16b, v31.16b, v9.16b \n" + "bif v15.16b, v31.16b, v10.16b \n" + + "ld2 {v18.4s, v19.4s}, [%[din_ptr5]], %[s_8] \n" + + "fmla v26.4s, v20.4s, v3.s[2] \n" + "fmla v27.4s, v20.4s, v1.s[0] \n" + + "ld2 {v16.4s, v17.4s}, [%[din_ptr5]], 
%[s_8] \n" + + "fmla v29.4s, v21.4s, v1.s[3] \n" + "fmla v28.4s, v21.4s, v4.s[1] \n" + + "ld2 {v20.4s, v21.4s}, [%[din_ptr5]] \n" + + "fmla v28.4s, v23.4s, v3.s[3] \n" + "fmla v29.4s, v23.4s, v1.s[1] \n" + "fmla v27.4s, v24.4s, v1.s[2] \n" + "fmla v26.4s, v24.4s, v4.s[0] \n" + + "bif v18.16b, v31.16b, v6.16b \n" + "bif v19.16b, v31.16b, v7.16b \n" + "bif v16.16b, v31.16b, v8.16b \n" + "bif v17.16b, v31.16b, v9.16b \n" + "bif v20.16b, v31.16b, v10.16b \n" + + "ld2 {v23.4s, v24.4s}, [%[din_ptr6]], %[s_8] \n" + + "fmla v27.4s, v22.4s, v2.s[0] \n" + "fmla v26.4s, v22.4s, v4.s[2] \n" + + "ld2 {v21.4s, v22.4s}, [%[din_ptr6]], %[s_8] \n" + + "fmla v28.4s, v25.4s, v4.s[3] \n" + "fmla v29.4s, v25.4s, v2.s[1] \n" + "fadd v28.4s, v28.4s, v26.4s \n" + + "ld2 {v25.4s, v26.4s}, [%[din_ptr6]] \n" + "mov v26.16b, v31.16b \n" + + "bif v23.16b, v31.16b, v6.16b \n" + "bif v24.16b, v31.16b, v7.16b \n" + "bif v21.16b, v31.16b, v8.16b \n" + "bif v22.16b, v31.16b, v9.16b \n" + "bif v25.16b, v31.16b, v10.16b \n" + + "fmla v26.4s, v13.4s, v5.s[0] \n" + "fmla v28.4s, v14.4s, v5.s[1] \n" + "fmla v26.4s, v11.4s, v5.s[2] \n" + "fmla v28.4s, v12.4s, v5.s[3] \n" + "fmla v26.4s, v15.4s, v30.s[0] \n" + + "fmla v27.4s, v13.4s, v2.s[2] \n" + "fmla v29.4s, v14.4s, v2.s[3] \n" + "fmla v27.4s, v11.4s, v3.s[0] \n" + "fmla v29.4s, v12.4s, v3.s[1] \n" + + "fadd v26.4s, v26.4s, v28.4s \n" + "fmla v27.4s, v15.4s, v3.s[2] \n" + "fmla v29.4s, v18.4s, v3.s[3] \n" + "fmla v27.4s, v19.4s, v4.s[0] \n" + "fmla v29.4s, v16.4s, v4.s[1] \n" + + "fmax v26.4s, v26.4s, v31.4s \n" + "fmla v27.4s, v17.4s, v4.s[2] \n" + "fmla v29.4s, v20.4s, v4.s[3] \n" + "fmla v27.4s, v23.4s, v5.s[0] \n" + "fmla v29.4s, v24.4s, v5.s[1] \n" + + "st1 {v26.4s}, [%[out_buf0]] \n" + "fmla v27.4s, v21.4s, v5.s[2] \n" + "fmla v29.4s, v22.4s, v5.s[3] \n" + "fmla v27.4s, v25.4s, v30.s[0] \n" + "fadd v27.4s, v27.4s, v29.4s \n" + + "fmax v27.4s, v27.4s, v31.4s \n" + "st1 {v27.4s}, [%[out_buf1]] \n" + + : [dout_ptr0] "+r"(dout_ptr0), + [dout_ptr1] "+r"(dout_ptr1), + [mid_cnt] "+r"(loop), + [din_ptr0] "+r"(din_ptr0), + [din_ptr1] "+r"(din_ptr1), + [din_ptr2] "+r"(din_ptr2), + [din_ptr3] "+r"(din_ptr3), + [din_ptr4] "+r"(din_ptr4), + [din_ptr5] "+r"(din_ptr5), + [din_ptr6] "+r"(din_ptr6), + [mask] "+r"(mask_ptr), + [weights] "+r"(weights_ptr) + : [vbias] "r"(vbias), + [out_buf0] "r"(out_buf0), + [out_buf1] "r"(out_buf1), + [s_8] "r"(s_8), + [s_16] "r"(s_16) + : "memory", + "cc", + "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15", + "v16", + "v17", + "v18", + "v19", + "v20", + "v21", + "v22", + "v23", + "v24", + "v25", + "v26", + "v27", + "v28", + "v29", + "v30", + "v31"); + + int remain_cnt = w_out - (mid_cnt + 1) * 4; + for (int i = 0; i < remain_cnt; ++i) { + dout_ptr0[i] = out_buf0[i]; + dout_ptr1[i] = out_buf1[i]; + } + din0 = din4; + din1 = din5; + din2 = din6; + din3 = din6 + w_in; + din4 = din3 + w_in; + din5 = din4 + w_in; + din6 = din5 + w_in; + dout0 = dout1 + w_out; + dout1 = dout0 + w_out; + } + } + } +} + +//! 
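(editor's aside: see the note below on how this variant differs)
+// NOTE(editor): this _relu kernel matches conv_depthwise_5x5s2p2 above
+// instruction for instruction, except that each accumulated output vector
+// is clamped with "fmax vN.4s, vN.4s, v31.4s" (v31 holds zero) before the
+// store, fusing the ReLU into the epilogue. Per output lane (illustration
+// only):
+#if 0
+float out = acc_even + acc_odd;        // the two partial sums, e.g. v26 + v28
+*dout_lane = (out > 0.f) ? out : 0.f;  // the fmax against the zero register
+#endif
+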
+//! small depthwise, win < 9;
+void conv_depthwise_5x5s2p2_s(const float* din,
+                              float* dout,
+                              int num,
+                              int ch_out,
+                              int h_out,
+                              int w_out,
+                              int ch_in,
+                              int h_in,
+                              int w_in,
+                              const float* weights,
+                              const float* bias,
+                              bool flag_bias,
+                              bool flag_relu,
+                              ARMContext* ctx) {
+  CHECK_LT(w_in, 9) << "only support win < 9";
+  int w_out_round = (w_out + 3) / 4 * 4;
+  int mask_cnt = 12 - w_in - 2;
+  int mask[12];
+  memset(mask, 0xff, 12 * sizeof(int));
+  for (int i = 0; i < mask_cnt; ++i) {
+    mask[11 - i] = 0;
+  }
+  float* zero_ptr = ctx->workspace_data<float>();
+  memset(zero_ptr, 0, w_in * sizeof(float));
+  int in_spatial_size = w_in * h_in;
+  int out_spatial_size = w_out * h_out;
+  int weights_spatial_size = 25;
+
+  for (int n = 0; n < num; ++n) {
+    const float* din_batch = din + n * in_spatial_size * ch_in;
+    float* dout_batch = dout + n * out_spatial_size * ch_out;
+#pragma omp parallel for
+    for (int c = 0; c < ch_in; ++c) {
+      const float* din_ch = din_batch + c * in_spatial_size;
+      float* dout_ch = dout_batch + c * out_spatial_size;
+      const float* din0 = zero_ptr;
+      const float* din1 = zero_ptr;
+      const float* din2 = din_ch;
+      const float* din3 = din2 + w_in;
+      const float* din4 = din3 + w_in;
+
+      float out_buf0[4];
+      float out_buf1[4];
+      float* dout0 = dout_ch;
+      float* dout1 = dout0 + w_out;
+
+      const float* weights_c = weights + c * weights_spatial_size;
+      for (int h = 0; h < h_out; h += 1) {
+        //! (h * 2 - 2) + 4 > h_in - 1
+        if (h * 2 + 3 > h_in) {
+          switch (h * 2 + 3 - h_in) {
+            case 4:
+              din1 = zero_ptr;
+            case 3:
+              din2 = zero_ptr;
+            case 2:
+              din3 = zero_ptr;
+            case 1:
+              din4 = zero_ptr;
+            default:
+              break;
+          }
+        }
+
+        const float* din_ptr0 = din0;
+        const float* din_ptr1 = din1;
+        const float* din_ptr2 = din2;
+        const float* din_ptr3 = din3;
+        const float* din_ptr4 = din4;
+
+        const float* weights_ptr = weights_c;
+        float* dout_ptr0 = dout0;
+
+        float bias_c = 0.f;
+        if (flag_bias) {
+          bias_c = bias[c];
+        }
+        float vbias[4] = {bias_c, bias_c, bias_c, bias_c};
+        int* mask_ptr = mask;
+        const int s_8 = 8;
+        //! in r0/r4, r1, r2, r3: x 0 2 4 -- v8 v13 v18 v23 v28
+        //! in r0/r4, r1, r2, r3: x 1 3 5 -- v9 v14 v19 v24 v29
+        //! in r0/r4, r1, r2, r3: 0 2 4 6 -- v6 v11 v16 v21 v26
+        //! in r0/r4, r1, r2, r3: 1 3 5 7 -- v7 v12 v17 v22 v27
+        //! in r0/r4, r1, r2, r3: 2 4 6 8 -- v10 v15 v20 v25 v30
+        //! out r0 -- v4
+        asm volatile(
+            "movi v31.4s, #0x0\n"
+            "prfm pldl1keep, [%[din_ptr0]] \n"
+            "prfm pldl1keep, [%[din_ptr1]] \n"
+            "prfm pldl1keep, [%[din_ptr2]] \n"
+            "prfm pldl1keep, [%[din_ptr3]] \n"
+            "prfm pldl1keep, [%[din_ptr4]] \n"
+            "prfm pldl1keep, [%[weights]] \n"
+            "prfm pldl1keep, [%[mask]] \n"
+
+            //! load mask
+            "ld2 {v0.4s, v1.4s}, [%[mask]], %[s_8] \n"
+            "ld2 {v2.4s, v3.4s}, [%[mask]], %[s_8] \n"
+            "ld2 {v4.4s, v5.4s}, [%[mask]] \n"
+
+            //!
load and extract input + "ld2 {v6.4s, v7.4s}, [%[din_ptr0]], #32 \n" + "ld2 {v11.4s, v12.4s}, [%[din_ptr1]], #32 \n" + "ld2 {v16.4s, v17.4s}, [%[din_ptr2]], #32 \n" + "ld2 {v21.4s, v22.4s}, [%[din_ptr3]], #32 \n" + "ld2 {v26.4s, v27.4s}, [%[din_ptr4]], #32 \n" + + "ext v8.16b, v31.16b, v6.16b, #12 \n" + "ext v9.16b, v31.16b, v7.16b, #12 \n" + "ext v13.16b, v31.16b, v11.16b, #12 \n" + "ext v14.16b, v31.16b, v12.16b, #12 \n" + + "ext v18.16b, v31.16b, v16.16b, #12 \n" + "ext v19.16b, v31.16b, v17.16b, #12 \n" + "ext v23.16b, v31.16b, v21.16b, #12 \n" + "ext v24.16b, v31.16b, v22.16b, #12 \n" + "ext v28.16b, v31.16b, v26.16b, #12 \n" + "ext v29.16b, v31.16b, v27.16b, #12 \n" + + "ext v10.16b, v6.16b, v31.16b, #4 \n" + "ext v15.16b, v11.16b, v31.16b, #4 \n" + "ext v20.16b, v16.16b, v31.16b, #4 \n" + "ext v25.16b, v21.16b, v31.16b, #4 \n" + "ext v30.16b, v26.16b, v31.16b, #4 \n" + + "bif v8.16b, v31.16b, v0.16b \n" + "bif v9.16b, v31.16b, v1.16b \n" + "bif v6.16b, v31.16b, v2.16b \n" + "bif v7.16b, v31.16b, v3.16b \n" + + "bif v13.16b, v31.16b, v0.16b \n" + "bif v14.16b, v31.16b, v1.16b \n" + "bif v11.16b, v31.16b, v2.16b \n" + "bif v12.16b, v31.16b, v3.16b \n" + + "bif v18.16b, v31.16b, v0.16b \n" + "bif v19.16b, v31.16b, v1.16b \n" + "bif v16.16b, v31.16b, v2.16b \n" + "bif v17.16b, v31.16b, v3.16b \n" + + "ld1 {v10.s}[3], [%[din_ptr0]] \n" + "ld1 {v15.s}[3], [%[din_ptr1]] \n" + "ld1 {v20.s}[3], [%[din_ptr2]] \n" + "ld1 {v25.s}[3], [%[din_ptr3]] \n" + "ld1 {v30.s}[3], [%[din_ptr4]] \n" + + "bif v23.16b, v31.16b, v0.16b \n" + "bif v24.16b, v31.16b, v1.16b \n" + "bif v21.16b, v31.16b, v2.16b \n" + "bif v22.16b, v31.16b, v3.16b \n" + + "bif v28.16b, v31.16b, v0.16b \n" + "bif v29.16b, v31.16b, v1.16b \n" + "bif v26.16b, v31.16b, v2.16b \n" + "bif v27.16b, v31.16b, v3.16b \n" + + "bif v10.16b, v31.16b, v4.16b \n" + "bif v15.16b, v31.16b, v4.16b \n" + "bif v20.16b, v31.16b, v4.16b \n" + "bif v25.16b, v31.16b, v4.16b \n" + "bif v30.16b, v31.16b, v4.16b \n" + + "ld1 {v4.4s}, [%[vbias]] \n" + "mov v5.16b, v31.16b \n" + + "ld1 {v0.4s, v1.4s}, [%[weights]], #32 \n" // load weights 0-7 + "ld1 {v2.4s, v3.4s}, [%[weights]], #32 \n" // load weights 8-15 + + //! 
compute + "fmla v4.4s, v8.4s, v0.s[0] \n" // out r0: w0 + "fmla v5.4s, v9.4s, v0.s[1] \n" // out r0: w1 + "fmla v4.4s, v6.4s, v0.s[2] \n" // out r0: w2 + "fmla v5.4s, v7.4s, v0.s[3] \n" // out r0: w3 + + "fmla v4.4s, v10.4s, v1.s[0] \n" // out r0: w4 + "fmla v5.4s, v13.4s, v1.s[1] \n" // out r0: w5 + "fmla v4.4s, v14.4s, v1.s[2] \n" // out r0: w6 + "fmla v5.4s, v11.4s, v1.s[3] \n" // out r0: w7 + + "ld1 {v6.4s, v7.4s}, [%[weights]], #32 \n" // load weights 16-23 + "ld1 {v8.s}[0], [%[weights]] \n" // load weights 24 + + "fmla v4.4s, v12.4s, v2.s[0] \n" // out r0: w8 + "fmla v5.4s, v15.4s, v2.s[1] \n" // out r0: w9 + "fmla v4.4s, v18.4s, v2.s[2] \n" // out r0: w10 + "fmla v5.4s, v19.4s, v2.s[3] \n" // out r0: w11 + + "fmla v4.4s, v16.4s, v3.s[0] \n" // out r0: w12 + "fmla v5.4s, v17.4s, v3.s[1] \n" // out r0: w13 + "fmla v4.4s, v20.4s, v3.s[2] \n" // out r0: w14 + "fmla v5.4s, v23.4s, v3.s[3] \n" // out r0: w15 + + "fmla v4.4s, v24.4s, v6.s[0] \n" // out r0: w16 + "fmla v5.4s, v21.4s, v6.s[1] \n" // out r0: w17 + "fmla v4.4s, v22.4s, v6.s[2] \n" // out r0: w18 + "fmla v5.4s, v25.4s, v6.s[3] \n" // out r0: w19 + + "fmla v4.4s, v28.4s, v7.s[0] \n" // out r0: w20 + "fmla v5.4s, v29.4s, v7.s[1] \n" // out r0: w21 + "fmla v4.4s, v26.4s, v7.s[2] \n" // out r0: w22 + "fmla v5.4s, v27.4s, v7.s[3] \n" // out r0: w23 + "fmla v4.4s, v30.4s, v8.s[0] \n" // out r0: w24 + + "fadd v4.4s, v4.4s, v5.4s \n" // add out to v4 + "st1 {v4.4s}, [%[out_buf0]] \n" + + : [dout_ptr0] "+r"(dout_ptr0), + [din_ptr0] "+r"(din_ptr0), + [din_ptr1] "+r"(din_ptr1), + [din_ptr2] "+r"(din_ptr2), + [din_ptr3] "+r"(din_ptr3), + [din_ptr4] "+r"(din_ptr4), + [mask] "+r"(mask_ptr), + [weights] "+r"(weights_ptr) + : [vbias] "r"(vbias), + [out_buf0] "r"(out_buf0), + [out_buf1] "r"(out_buf1), + [s_8] "r"(s_8) + : "memory", + "cc", + "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15", + "v16", + "v17", + "v18", + "v19", + "v20", + "v21", + "v22", + "v23", + "v24", + "v25", + "v26", + "v27", + "v28", + "v29", + "v30", + "v31"); + for (int i = 0; i < w_out; ++i) { + dout_ptr0[i] = out_buf0[i]; + } + din0 = din2; + din1 = din3; + din2 = din4; + din3 = din2 + w_in; + din4 = din3 + w_in; + dout0 += w_out; + } + } + } +} + +//! 
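(editor's aside on the small-width tail copy above)
+// NOTE(editor): with w_in < 9 and stride 2 / pad 2 the whole output row fits
+// in a single 4-lane vector, so the kernel computes into the stack buffer
+// out_buf0 and copies just w_out floats back, avoiding any out-of-bounds
+// vector store. The bound, worked out (editorial illustration):
+#if 0
+// output width for k=5, s=2, p=2:  w_out = (w_in + 2*2 - 5) / 2 + 1
+// w_in <= 8  =>  w_out <= (8 - 1) / 2 + 1 = 4, i.e. one NEON vector.
+#endif
+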
+//! small depthwise, win < 9;
+void conv_depthwise_5x5s2p2_relu_s(const float* din,
+                                   float* dout,
+                                   int num,
+                                   int ch_out,
+                                   int h_out,
+                                   int w_out,
+                                   int ch_in,
+                                   int h_in,
+                                   int w_in,
+                                   const float* weights,
+                                   const float* bias,
+                                   bool flag_bias,
+                                   bool flag_relu,
+                                   ARMContext* ctx) {
+  CHECK_LT(w_in, 9) << "only support win < 9";
+  int w_out_round = (w_out + 3) / 4 * 4;
+  int mask_cnt = 12 - w_in - 2;
+  int mask[12];
+  memset(mask, 0xff, 12 * sizeof(int));
+  for (int i = 0; i < mask_cnt; ++i) {
+    mask[11 - i] = 0;
+  }
+  float* zero_ptr = ctx->workspace_data<float>();
+  memset(zero_ptr, 0, w_in * sizeof(float));
+  int in_spatial_size = w_in * h_in;
+  int out_spatial_size = w_out * h_out;
+  int weights_spatial_size = 25;
+
+  for (int n = 0; n < num; ++n) {
+    const float* din_batch = din + n * in_spatial_size * ch_in;
+    float* dout_batch = dout + n * out_spatial_size * ch_out;
+#pragma omp parallel for
+    for (int c = 0; c < ch_in; ++c) {
+      const float* din_ch = din_batch + c * in_spatial_size;
+      float* dout_ch = dout_batch + c * out_spatial_size;
+      const float* din0 = zero_ptr;
+      const float* din1 = zero_ptr;
+      const float* din2 = din_ch;
+      const float* din3 = din2 + w_in;
+      const float* din4 = din3 + w_in;
+
+      float out_buf0[4];
+      float out_buf1[4];
+      float* dout0 = dout_ch;
+      float* dout1 = dout0 + w_out;
+
+      const float* weights_c = weights + c * weights_spatial_size;
+      for (int h = 0; h < h_out; h += 1) {
+        //! (h * 2 - 2) + 4 > h_in - 1
+        if (h * 2 + 3 > h_in) {
+          switch (h * 2 + 3 - h_in) {
+            case 4:
+              din1 = zero_ptr;
+            case 3:
+              din2 = zero_ptr;
+            case 2:
+              din3 = zero_ptr;
+            case 1:
+              din4 = zero_ptr;
+            default:
+              break;
+          }
+        }
+        const float* din_ptr0 = din0;
+        const float* din_ptr1 = din1;
+        const float* din_ptr2 = din2;
+        const float* din_ptr3 = din3;
+        const float* din_ptr4 = din4;
+
+        const float* weights_ptr = weights_c;
+        float* dout_ptr0 = dout0;
+
+        float bias_c = 0.f;
+        if (flag_bias) {
+          bias_c = bias[c];
+        }
+        float vbias[4] = {bias_c, bias_c, bias_c, bias_c};
+        int* mask_ptr = mask;
+        const int s_8 = 8;
+        //! in r0/r4, r1, r2, r3: x 0 2 4 -- v8 v13 v18 v23 v28
+        //! in r0/r4, r1, r2, r3: x 1 3 5 -- v9 v14 v19 v24 v29
+        //! in r0/r4, r1, r2, r3: 0 2 4 6 -- v6 v11 v16 v21 v26
+        //! in r0/r4, r1, r2, r3: 1 3 5 7 -- v7 v12 v17 v22 v27
+        //! in r0/r4, r1, r2, r3: 2 4 6 8 -- v10 v15 v20 v25 v30
+        //! out r0 -- v4
+        asm volatile(
+            "movi v31.4s, #0x0\n"
+            "prfm pldl1keep, [%[din_ptr0]] \n"
+            "prfm pldl1keep, [%[din_ptr1]] \n"
+            "prfm pldl1keep, [%[din_ptr2]] \n"
+            "prfm pldl1keep, [%[din_ptr3]] \n"
+            "prfm pldl1keep, [%[din_ptr4]] \n"
+            "prfm pldl1keep, [%[weights]] \n"
+            "prfm pldl1keep, [%[mask]] \n"
+
+            //! load mask
+            "ld2 {v0.4s, v1.4s}, [%[mask]], %[s_8] \n"
+            "ld2 {v2.4s, v3.4s}, [%[mask]], %[s_8] \n"
+            "ld2 {v4.4s, v5.4s}, [%[mask]] \n"
+
+            //!
load and extract input + "ld2 {v6.4s, v7.4s}, [%[din_ptr0]], #32 \n" + "ld2 {v11.4s, v12.4s}, [%[din_ptr1]], #32 \n" + "ld2 {v16.4s, v17.4s}, [%[din_ptr2]], #32 \n" + "ld2 {v21.4s, v22.4s}, [%[din_ptr3]], #32 \n" + "ld2 {v26.4s, v27.4s}, [%[din_ptr4]], #32 \n" + + "ext v8.16b, v31.16b, v6.16b, #12 \n" + "ext v9.16b, v31.16b, v7.16b, #12 \n" + "ext v13.16b, v31.16b, v11.16b, #12 \n" + "ext v14.16b, v31.16b, v12.16b, #12 \n" + + "ext v18.16b, v31.16b, v16.16b, #12 \n" + "ext v19.16b, v31.16b, v17.16b, #12 \n" + "ext v23.16b, v31.16b, v21.16b, #12 \n" + "ext v24.16b, v31.16b, v22.16b, #12 \n" + "ext v28.16b, v31.16b, v26.16b, #12 \n" + "ext v29.16b, v31.16b, v27.16b, #12 \n" + + "ext v10.16b, v6.16b, v31.16b, #4 \n" + "ext v15.16b, v11.16b, v31.16b, #4 \n" + "ext v20.16b, v16.16b, v31.16b, #4 \n" + "ext v25.16b, v21.16b, v31.16b, #4 \n" + "ext v30.16b, v26.16b, v31.16b, #4 \n" + + "bif v8.16b, v31.16b, v0.16b \n" + "bif v9.16b, v31.16b, v1.16b \n" + "bif v6.16b, v31.16b, v2.16b \n" + "bif v7.16b, v31.16b, v3.16b \n" + + "bif v13.16b, v31.16b, v0.16b \n" + "bif v14.16b, v31.16b, v1.16b \n" + "bif v11.16b, v31.16b, v2.16b \n" + "bif v12.16b, v31.16b, v3.16b \n" + + "bif v18.16b, v31.16b, v0.16b \n" + "bif v19.16b, v31.16b, v1.16b \n" + "bif v16.16b, v31.16b, v2.16b \n" + "bif v17.16b, v31.16b, v3.16b \n" + + "ld1 {v10.s}[3], [%[din_ptr0]] \n" + "ld1 {v15.s}[3], [%[din_ptr1]] \n" + "ld1 {v20.s}[3], [%[din_ptr2]] \n" + "ld1 {v25.s}[3], [%[din_ptr3]] \n" + "ld1 {v30.s}[3], [%[din_ptr4]] \n" + + "bif v23.16b, v31.16b, v0.16b \n" + "bif v24.16b, v31.16b, v1.16b \n" + "bif v21.16b, v31.16b, v2.16b \n" + "bif v22.16b, v31.16b, v3.16b \n" + + "bif v28.16b, v31.16b, v0.16b \n" + "bif v29.16b, v31.16b, v1.16b \n" + "bif v26.16b, v31.16b, v2.16b \n" + "bif v27.16b, v31.16b, v3.16b \n" + + "bif v10.16b, v31.16b, v4.16b \n" + "bif v15.16b, v31.16b, v4.16b \n" + "bif v20.16b, v31.16b, v4.16b \n" + "bif v25.16b, v31.16b, v4.16b \n" + "bif v30.16b, v31.16b, v4.16b \n" + + "ld1 {v4.4s}, [%[vbias]] \n" + "mov v5.16b, v31.16b \n" + + "ld1 {v0.4s, v1.4s}, [%[weights]], #32 \n" // load weights 0-7 + "ld1 {v2.4s, v3.4s}, [%[weights]], #32 \n" // load weights 8-15 + + //! 
compute + "fmla v4.4s, v8.4s, v0.s[0] \n" // out r0: w0 + "fmla v5.4s, v9.4s, v0.s[1] \n" // out r0: w1 + "fmla v4.4s, v6.4s, v0.s[2] \n" // out r0: w2 + "fmla v5.4s, v7.4s, v0.s[3] \n" // out r0: w3 + + "fmla v4.4s, v10.4s, v1.s[0] \n" // out r0: w4 + "fmla v5.4s, v13.4s, v1.s[1] \n" // out r0: w5 + "fmla v4.4s, v14.4s, v1.s[2] \n" // out r0: w6 + "fmla v5.4s, v11.4s, v1.s[3] \n" // out r0: w7 + + "ld1 {v6.4s, v7.4s}, [%[weights]], #32 \n" // load weights 16-23 + "ld1 {v8.s}[0], [%[weights]] \n" // load weights 24 + + "fmla v4.4s, v12.4s, v2.s[0] \n" // out r0: w8 + "fmla v5.4s, v15.4s, v2.s[1] \n" // out r0: w9 + "fmla v4.4s, v18.4s, v2.s[2] \n" // out r0: w10 + "fmla v5.4s, v19.4s, v2.s[3] \n" // out r0: w11 + + "fmla v4.4s, v16.4s, v3.s[0] \n" // out r0: w12 + "fmla v5.4s, v17.4s, v3.s[1] \n" // out r0: w13 + "fmla v4.4s, v20.4s, v3.s[2] \n" // out r0: w14 + "fmla v5.4s, v23.4s, v3.s[3] \n" // out r0: w15 + + "fmla v4.4s, v24.4s, v6.s[0] \n" // out r0: w16 + "fmla v5.4s, v21.4s, v6.s[1] \n" // out r0: w17 + "fmla v4.4s, v22.4s, v6.s[2] \n" // out r0: w18 + "fmla v5.4s, v25.4s, v6.s[3] \n" // out r0: w19 + + "fmla v4.4s, v28.4s, v7.s[0] \n" // out r0: w20 + "fmla v5.4s, v29.4s, v7.s[1] \n" // out r0: w21 + "fmla v4.4s, v26.4s, v7.s[2] \n" // out r0: w22 + "fmla v5.4s, v27.4s, v7.s[3] \n" // out r0: w23 + "fmla v4.4s, v30.4s, v8.s[0] \n" // out r0: w24 + + "fadd v4.4s, v4.4s, v5.4s \n" // add out to v4 + "fmax v4.4s, v4.4s, v31.4s \n" + "st1 {v4.4s}, [%[out_buf0]] \n" + + : [dout_ptr0] "+r"(dout_ptr0), + [din_ptr0] "+r"(din_ptr0), + [din_ptr1] "+r"(din_ptr1), + [din_ptr2] "+r"(din_ptr2), + [din_ptr3] "+r"(din_ptr3), + [din_ptr4] "+r"(din_ptr4), + [mask] "+r"(mask_ptr), + [weights] "+r"(weights_ptr) + : [vbias] "r"(vbias), + [out_buf0] "r"(out_buf0), + [out_buf1] "r"(out_buf1), + [s_8] "r"(s_8) + : "memory", + "cc", + "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15", + "v16", + "v17", + "v18", + "v19", + "v20", + "v21", + "v22", + "v23", + "v24", + "v25", + "v26", + "v27", + "v28", + "v29", + "v30", + "v31"); + for (int i = 0; i < w_out; ++i) { + dout_ptr0[i] = out_buf0[i]; + } + din0 = din2; + din1 = din3; + din2 = din4; + din3 = din2 + w_in; + din4 = din3 + w_in; + dout0 += w_out; + } + } + } +} + +#else + +//! 
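(editor's aside before the 32-bit ARM path)
+// NOTE(editor): the armv7 kernels below differ from the aarch64 ones above
+// in two ways visible in the code: the first 24 filter taps stay resident in
+// q-registers, preloaded with NEON intrinsics, and each iteration produces
+// one output row (h += 1) instead of two. The preload pattern used below
+// (illustration):
+#if 0
+float32x4_t w0 = vld1q_f32(weights_c);       // taps  0..3
+// ... w1..w4 load taps 4..19 in steps of 4 ...
+float32x4_t w5 = vld1q_f32(weights_c + 20);  // taps 20..23
+const float* weights_ptr = weights_c + 24;   // tap 24 is loaded inside asm
+#endif
+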
larger depthwise, win >= 9;
+void conv_depthwise_5x5s2p2(const float* din,
+ float* dout,
+ int num,
+ int ch_out,
+ int h_out,
+ int w_out,
+ int ch_in,
+ int h_in,
+ int w_in,
+ const float* weights,
+ const float* bias,
+ bool flag_bias,
+ bool flag_relu,
+ ARMContext* ctx) {
+ // printf("invoke 5x5s2p2 armv7\n");
+ CHECK_GE(w_in, 9) << "only support win >= 9";
+ int w_out_round = (w_out + 3) / 4 * 4;
+ int cnt = (w_out_round - 4) / 4;
+ int mid_cnt = cnt - 1;
+ int right_start = cnt * 2 * 4 - 2;
+ int mask_cnt = 12 - (w_in - right_start);
+ int mask[12];
+ memset(mask, 0xff, 12 * sizeof(int));
+ for (int i = 0; i < mask_cnt; ++i) {
+ mask[11 - i] = 0;
+ }
+ float* zero_ptr = ctx->workspace_data<float>();
+ memset(zero_ptr, 0, w_in * sizeof(float));
+ int in_spatial_size = w_in * h_in;
+ int out_spatial_size = w_out * h_out;
+ int weights_spatial_size = 25;
+
+ for (int n = 0; n < num; ++n) {
+ const float* din_batch = din + n * in_spatial_size * ch_in;
+ float* dout_batch = dout + n * out_spatial_size * ch_out;
+#pragma omp parallel for
+ for (int c = 0; c < ch_in; ++c) {
+ const float* din_ch = din_batch + c * in_spatial_size;
+ float* dout_ch = dout_batch + c * out_spatial_size;
+ const float* din0 = zero_ptr;
+ const float* din1 = zero_ptr;
+ const float* din2 = din_ch;
+ const float* din3 = din2 + w_in;
+ const float* din4 = din3 + w_in;
+
+ float out_buf0[4];
+ float* dout0 = dout_ch;
+
+ const float* weights_c = weights + c * weights_spatial_size;
+ float32x4_t w0 = vld1q_f32(weights_c);
+ float32x4_t w1 = vld1q_f32(weights_c + 4);
+ float32x4_t w2 = vld1q_f32(weights_c + 8);
+ float32x4_t w3 = vld1q_f32(weights_c + 12);
+ float32x4_t w4 = vld1q_f32(weights_c + 16);
+ float32x4_t w5 = vld1q_f32(weights_c + 20);
+ for (int h = 0; h < h_out; h += 1) {
+ //! (h * 2 - 2) + 4 > h_in - 1
+ if (h * 2 + 3 > h_in) {
+ switch (h * 2 + 3 - h_in) {
+ case 4:
+ din1 = zero_ptr;  // fall through
+ case 3:
+ din2 = zero_ptr;  // fall through
+ case 2:
+ din3 = zero_ptr;  // fall through
+ case 1:
+ din4 = zero_ptr;
+ default:
+ break;
+ }
+ }
+ const float* din_ptr0 = din0;
+ const float* din_ptr1 = din1;
+ const float* din_ptr2 = din2;
+ const float* din_ptr3 = din3;
+ const float* din_ptr4 = din4;
+
+ const float* weights_ptr = weights_c + 24;
+ float* dout_ptr0 = dout0;
+
+ float bias_c = 0.f;
+ if (flag_bias) {
+ bias_c = bias[c];
+ }
+ float vbias[4] = {bias_c, bias_c, bias_c, bias_c};
+ int* mask_ptr = mask;
+ int loop = mid_cnt;
+ const int s_8 = 8;
+ const int s_16 = 16;
+
+ asm volatile(
+ "vmov.i32 q15, #0x0 \n"
+ "pld [%[din_ptr0]] \n"
+ "pld [%[din_ptr1]] \n"
+ "pld [%[din_ptr2]] \n"
+ "pld [%[din_ptr3]] \n"
+ "pld [%[din_ptr4]] \n"
+ "pld [%[mask]] \n"
+
+ // left
+ "vld2.32 {d16-d19}, [%[din_ptr0]]! \n"
+ "vld1.32 {d26-d29}, [%[vbias]] \n"
+ "vext.32 q6, q15, q8, #3 \n"
+ "vext.32 q7, q15, q9, #3 \n"
+ "vext.32 q10, q8, q15, #1 \n"
+ "vmov.32 q14, q15 \n"
+
+ // r0
+ "vmla.f32 q13, q8, %f[w0][0] \n"
+ "vmla.f32 q14, q9, %f[w0][1] \n"
+
+ "vld1.32 {d21[1]}, [%[din_ptr0]] \n"
+ "vld2.32 {d16-d19}, [%[din_ptr1]]! \n"
+ "sub %[din_ptr0], #8 \n"
+
+ "vmla.f32 q13, q6, %e[w0][0] \n"
+ "vmla.f32 q14, q7, %e[w0][1] \n"
+ "vmla.f32 q13, q10, %e[w1][0] \n"
+
+ "vext.32 q6, q15, q8, #3 \n"
+ "vext.32 q7, q15, q9, #3 \n"
+ "vext.32 q10, q8, q15, #1 \n"
+
+ // r1
+ "vmla.f32 q13, q8, %f[w1][1] \n"
+ "vmla.f32 q14, q9, %e[w2][0] \n"
+
+ "vld1.32 {d21[1]}, [%[din_ptr1]] \n"
+ "vld2.32 {d16-d19}, [%[din_ptr2]]!
\n" + "sub %[din_ptr1], #8 \n" + + "vmla.f32 q13, q6, %e[w1][1] \n" + "vmla.f32 q14, q7, %f[w1][0] \n" + "vmla.f32 q13, q10, %e[w2][1] \n" + + "vext.32 q6, q15, q8, #3 \n" + "vext.32 q7, q15, q9, #3 \n" + "vext.32 q10, q8, q15, #1 \n" + + // r2 + "vmla.f32 q13, q8, %e[w3][0] \n" + "vmla.f32 q14, q9, %e[w3][1] \n" + + "vld1.32 {d21[1]}, [%[din_ptr2]] \n" + "vld2.32 {d16-d19}, [%[din_ptr3]]! \n" + "sub %[din_ptr2], #8 \n" + + "vmla.f32 q13, q6, %f[w2][0] \n" + "vmla.f32 q14, q7, %f[w2][1] \n" + "vmla.f32 q13, q10, %f[w3][0] \n" + + "vext.32 q6, q15, q8, #3 \n" + "vext.32 q7, q15, q9, #3 \n" + "vext.32 q10, q8, q15, #1 \n" + + // r3 + "vmla.f32 q13, q8, %e[w4][1] \n" + "vmla.f32 q14, q9, %f[w4][0] \n" + + "vld1.32 {d21[1]}, [%[din_ptr3]] \n" + "vld2.32 {d16-d19}, [%[din_ptr4]]! \n" + "sub %[din_ptr3], #8 \n" + + "vmla.f32 q13, q6, %f[w3][1] \n" + "vmla.f32 q14, q7, %e[w4][0] \n" + "vmla.f32 q13, q10, %f[w4][1] \n" + + "vext.32 q6, q15, q8, #3 \n" + "vext.32 q7, q15, q9, #3 \n" + "vext.32 q10, q8, q15, #1 \n" + + // r4 + "vmla.f32 q13, q6, %e[w5][0] \n" + "vmla.f32 q14, q7, %e[w5][1] \n" + + "vld1.32 {d21[1]}, [%[din_ptr4]] \n" + "vld2.32 {d12-d15}, [%[din_ptr0]], %[s_8] \n" + "sub %[din_ptr4], #8 \n" + + "vmla.f32 q13, q8, %f[w5][0] \n" + "vmla.f32 q14, q9, %f[w5][1] \n" + + "vld2.32 {d16-d19}, [%[din_ptr0]], %[s_8] \n" + + "vmov.32 q12, %q[w0] \n" + "vld1.32 {%e[w0][0]}, [%[weights]] \n" + "vmla.f32 q13, q10, %e[w0][0] \n" + "vadd.f32 q13, q13, q14 \n" + "vmov.32 %q[w0], q12 \n" + "cmp %[mid_cnt], #1 \n" + "vld2.32 {d20-d23}, [%[din_ptr0]], %[s_16] \n" + "vst1.32 {d26-d27}, [%[dout_ptr0]]! \n" + "pld [%[din_ptr0]] \n" + "blt 2f \n" + + // mid + "1: \n" + "vld1.32 {d26-d27}, [%[vbias]] \n" + "vmov.32 q14, q15 \n" + + // r0 + "vmla.f32 q13, q6, %e[w0][0] \n" + "vmla.f32 q14, q7, %e[w0][1] \n" + + "vld2.32 {d12-d15}, [%[din_ptr1]], %[s_8] \n" + + "vmla.f32 q13, q8, %f[w0][0] \n" + "vmla.f32 q14, q9, %f[w0][1] \n" + + "vld2.32 {d16-d19}, [%[din_ptr1]], %[s_8] \n" + + "vmla.f32 q13, q10, %e[w1][0] \n" + + "vld2.32 {d20-d23}, [%[din_ptr1]], %[s_16] \n" + + // r1 + "vmla.f32 q13, q6, %e[w1][1] \n" + "vmla.f32 q14, q7, %f[w1][0] \n" + "pld [%[din_ptr1]] \n" + + "vld2.32 {d12-d15}, [%[din_ptr2]], %[s_8] \n" + + "vmla.f32 q13, q8, %f[w1][1] \n" + "vmla.f32 q14, q9, %e[w2][0] \n" + + "vld2.32 {d16-d19}, [%[din_ptr2]], %[s_8] \n" + + "vmla.f32 q13, q10, %e[w2][1] \n" + + "vld2.32 {d20-d23}, [%[din_ptr2]], %[s_16] \n" + + // r2 + "vmla.f32 q13, q6, %f[w2][0] \n" + "vmla.f32 q14, q7, %f[w2][1] \n" + "pld [%[din_ptr2]] \n" + + "vld2.32 {d12-d15}, [%[din_ptr3]], %[s_8] \n" + + "vmla.f32 q13, q8, %e[w3][0] \n" + "vmla.f32 q14, q9, %e[w3][1] \n" + + "vld2.32 {d16-d19}, [%[din_ptr3]], %[s_8] \n" + + "vmla.f32 q13, q10, %f[w3][0] \n" + + "vld2.32 {d20-d23}, [%[din_ptr3]], %[s_16] \n" + + // r3 + "vmla.f32 q13, q6, %f[w3][1] \n" + "vmla.f32 q14, q7, %e[w4][0] \n" + "pld [%[din_ptr3]] \n" + + "vld2.32 {d12-d15}, [%[din_ptr4]], %[s_8] \n" + + "vmla.f32 q13, q8, %e[w4][1] \n" + "vmla.f32 q14, q9, %f[w4][0] \n" + + "vld2.32 {d16-d19}, [%[din_ptr4]], %[s_8] \n" + + "vmla.f32 q13, q10, %f[w4][1] \n" + + "vld2.32 {d20-d23}, [%[din_ptr4]], %[s_16] \n" + + // r4 + "vmla.f32 q13, q6, %e[w5][0] \n" + "vmla.f32 q14, q7, %e[w5][1] \n" + "pld [%[din_ptr4]] \n" + + "vld2.32 {d12-d15}, [%[din_ptr0]], %[s_8] \n" + "vld1.32 {%e[w0][0]}, [%[weights]] \n" + + "vmla.f32 q13, q8, %f[w5][0] \n" + "vmla.f32 q14, q9, %f[w5][1] \n" + + "vld2.32 {d16-d19}, [%[din_ptr0]], %[s_8] \n" + + "vmla.f32 q13, q10, %e[w0][0] \n" + + "vld2.32 
{d20-d23}, [%[din_ptr0]], %[s_16] \n" + + "vmov.32 %q[w0], q12 \n" + "vadd.f32 q13, q13, q14 \n" + "subs %[mid_cnt], #1 \n" + "vst1.32 {d26-d27}, [%[dout_ptr0]]! \n" + "bne 1b \n" + + "2: \n" + "vld2.32 {d22-d25}, [%[mask]], %[s_8] \n" + "vld1.32 {d26-d27}, [%[vbias]] \n" + "vmov.32 q14, q15 \n" + + // r0 + "vbif.32 q6, q15, q11 \n" + "vbif.32 q7, q15, q12 \n" + "vmla.f32 q13, q6, %e[w0][0] \n" + "vmla.f32 q14, q7, %e[w0][1] \n" + + "vld2.32 {d22-d25}, [%[mask]], %[s_8] \n" + "vld2.32 {d12-d15}, [%[din_ptr1]], %[s_8] \n" + "vbif.32 q8, q15, q11 \n" + "vbif.32 q9, q15, q12 \n" + "vmla.f32 q13, q8, %f[w0][0] \n" + "vmla.f32 q14, q9, %f[w0][1] \n" + + "vld2.32 {d22-d25}, [%[mask]] \n" + "sub %[mask], #16 \n" + "vld2.32 {d16-d19}, [%[din_ptr1]], %[s_8] \n" + "vbif.32 q10, q15, q11 \n" + "vmla.f32 q13, q10, %e[w1][0] \n" + + // r1 + "vld2.32 {d20-d23}, [%[din_ptr1]] \n" + "vld2.32 {d22-d25}, [%[mask]], %[s_8] \n" + "vbif.32 q6, q15, q11 \n" + "vbif.32 q7, q15, q12 \n" + "vmla.f32 q13, q6, %e[w1][1] \n" + "vmla.f32 q14, q7, %f[w1][0] \n" + + "vld2.32 {d22-d25}, [%[mask]], %[s_8] \n" + "vld2.32 {d12-d15}, [%[din_ptr2]], %[s_8] \n" + "vbif.32 q8, q15, q11 \n" + "vbif.32 q9, q15, q12 \n" + "vmla.f32 q13, q8, %f[w1][1] \n" + "vmla.f32 q14, q9, %e[w2][0] \n" + + "vld2.32 {d22-d25}, [%[mask]] \n" + "sub %[mask], #16 \n" + "vld2.32 {d16-d19}, [%[din_ptr2]], %[s_8] \n" + "vbif.32 q10, q15, q11 \n" + "vmla.f32 q13, q10, %e[w2][1] \n" + + // r2 + "vld2.32 {d20-d23}, [%[din_ptr2]] \n" + "vld2.32 {d22-d25}, [%[mask]], %[s_8] \n" + "vbif.32 q6, q15, q11 \n" + "vbif.32 q7, q15, q12 \n" + "vmla.f32 q13, q6, %f[w2][0] \n" + "vmla.f32 q14, q7, %f[w2][1] \n" + + "vld2.32 {d12-d15}, [%[din_ptr3]], %[s_8] \n" + "vld2.32 {d22-d25}, [%[mask]], %[s_8] \n" + "vbif.32 q8, q15, q11 \n" + "vbif.32 q9, q15, q12 \n" + "vmla.f32 q13, q8, %e[w3][0] \n" + "vmla.f32 q14, q9, %e[w3][1] \n" + + "vld2.32 {d22-d25}, [%[mask]] \n" + "sub %[mask], #16 \n" + "vld2.32 {d16-d19}, [%[din_ptr3]], %[s_8] \n" + "vbif.32 q10, q15, q11 \n" + "vmla.f32 q13, q10, %f[w3][0] \n" + + // r3 + "vld2.32 {d20-d23}, [%[din_ptr3]] \n" + "vld2.32 {d22-d25}, [%[mask]], %[s_8] \n" + "vbif.32 q6, q15, q11 \n" + "vbif.32 q7, q15, q12 \n" + "vmla.f32 q13, q6, %f[w3][1] \n" + "vmla.f32 q14, q7, %e[w4][0] \n" + + "vld2.32 {d22-d25}, [%[mask]], %[s_8] \n" + "vld2.32 {d12-d15}, [%[din_ptr4]], %[s_8] \n" + "vbif.32 q8, q15, q11 \n" + "vbif.32 q9, q15, q12 \n" + "vmla.f32 q13, q8, %e[w4][1] \n" + "vmla.f32 q14, q9, %f[w4][0] \n" + + "vld2.32 {d22-d25}, [%[mask]] \n" + "sub %[mask], #16 \n" + "vld2.32 {d16-d19}, [%[din_ptr4]], %[s_8] \n" + "vbif.32 q10, q15, q11 \n" + "vmla.f32 q13, q10, %f[w4][1] \n" + + // r4 + "vld2.32 {d20-d23}, [%[din_ptr4]] \n" + "vld2.32 {d22-d25}, [%[mask]], %[s_8] \n" + "vbif.32 q6, q15, q11 \n" + "vbif.32 q7, q15, q12 \n" + "vmla.f32 q13, q6, %e[w5][0] \n" + "vmla.f32 q14, q7, %e[w5][1] \n" + + "vld2.32 {d22-d25}, [%[mask]], %[s_8] \n" + "vld1.32 {d12[0]}, [%[weights]] \n" + "vbif.32 q8, q15, q11 \n" + "vbif.32 q9, q15, q12 \n" + "vmla.f32 q13, q8, %f[w5][0] \n" + "vmla.f32 q14, q9, %f[w5][1] \n" + + "vld2.32 {d22-d25}, [%[mask]] \n" + "vbif.32 q10, q15, q11 \n" + "vmla.f32 q13, q10, d12[0] \n" + + "vadd.f32 q13, q13, q14 \n" + "vst1.32 {d26-d27}, [%[out_buf0]] \n" + + : [dout_ptr0] "+r"(dout_ptr0), + [mid_cnt] "+r"(loop), + [din_ptr0] "+r"(din_ptr0), + [din_ptr1] "+r"(din_ptr1), + [din_ptr2] "+r"(din_ptr2), + [din_ptr3] "+r"(din_ptr3), + [din_ptr4] "+r"(din_ptr4), + [mask] "+r"(mask_ptr), + [weights] "+r"(weights_ptr) + : [w0] "w"(w0), + 
[w1] "w"(w1), + [w2] "w"(w2), + [w3] "w"(w3), + [w4] "w"(w4), + [w5] "w"(w5), + [vbias] "r"(vbias), + [out_buf0] "r"(out_buf0), + [s_8] "r"(s_8), + [s_16] "r"(s_16) + : "memory", + "cc", + "q6", + "q7", + "q8", + "q9", + "q10", + "q11", + "q12", + "q13", + "q14", + "q15"); + + int remain_cnt = w_out - (mid_cnt + 1) * 4; + for (int i = 0; i < remain_cnt; ++i) { + dout_ptr0[i] = out_buf0[i]; + } + + din0 = din2; + din1 = din3; + din2 = din4; + din3 = din2 + w_in; + din4 = din3 + w_in; + dout0 += w_out; + } + } + } +} + +//! larger depthwise, win >= 9; +void conv_depthwise_5x5s2p2_relu(const float* din, + float* dout, + int num, + int ch_out, + int h_out, + int w_out, + int ch_in, + int h_in, + int w_in, + const float* weights, + const float* bias, + bool flag_bias, + bool flag_relu, + ARMContext* ctx) { + // printf("invoke 5x5s2p2 armv7\n"); + CHECK_GE(w_in, 9) << "only support win >= 9"; + int w_out_round = (w_out + 3) / 4 * 4; + int cnt = (w_out_round - 4) / 4; + int mid_cnt = cnt - 1; + int right_start = cnt * 2 * 4 - 2; + int mask_cnt = 12 - (w_in - right_start); + int mask[12]; + memset(mask, 0xff, 12 * sizeof(int)); + for (int i = 0; i < mask_cnt; ++i) { + mask[11 - i] = 0; + } + float* zero_ptr = ctx->workspace_data(); + memset(zero_ptr, 0, w_in * sizeof(float)); + int in_spatial_size = w_in * h_in; + int out_spatial_size = w_out * h_out; + int weights_saptial_size = 25; + + for (int n = 0; n < num; ++n) { + const float* din_batch = din + n * in_spatial_size * ch_in; + float* dout_batch = dout + n * out_spatial_size * ch_out; +#pragma omp parallel for + for (int c = 0; c < ch_in; ++c) { + const float* din_ch = din_batch + c * in_spatial_size; + float* dout_ch = dout_batch + c * out_spatial_size; + const float* din0 = zero_ptr; + const float* din1 = zero_ptr; + const float* din2 = din_ch; + const float* din3 = din2 + w_in; + const float* din4 = din3 + w_in; + + float out_buf0[4]; + float* dout0 = dout_ch; + + const float* weights_c = weights + c * weights_saptial_size; + float32x4_t w0 = vld1q_f32(weights_c); + float32x4_t w1 = vld1q_f32(weights_c + 4); + float32x4_t w2 = vld1q_f32(weights_c + 8); + float32x4_t w3 = vld1q_f32(weights_c + 12); + float32x4_t w4 = vld1q_f32(weights_c + 16); + float32x4_t w5 = vld1q_f32(weights_c + 20); + for (int h = 0; h < h_out; h += 1) { + //! (h * 2 - 2) + 4 > h_in - 1 + if (h * 2 + 3 > h_in) { + switch (h * 2 + 3 - h_in) { + case 4: + din1 = zero_ptr; + case 3: + din2 = zero_ptr; + case 2: + din3 = zero_ptr; + case 1: + din4 = zero_ptr; + default: + break; + } + } + const float* din_ptr0 = din0; + const float* din_ptr1 = din1; + const float* din_ptr2 = din2; + const float* din_ptr3 = din3; + const float* din_ptr4 = din4; + + const float* weights_ptr = weights_c + 24; + float* dout_ptr0 = dout0; + + float bias_c = 0.f; + if (flag_bias) { + bias_c = bias[c]; + } + float vbias[4] = {bias_c, bias_c, bias_c, bias_c}; + int* mask_ptr = mask; + int loop = mid_cnt; + const int s_8 = 8; + const int s_16 = 16; + + asm volatile( + "vmov.i32 q15, #0x0 \n" + "pld [%[din_ptr0]] \n" + "pld [%[din_ptr1]] \n" + "pld [%[din_ptr2]] \n" + "pld [%[din_ptr3]] \n" + "pld [%[din_ptr4]] \n" + "pld [%[mask]] \n" + + // left + "vld2.32 {d16-d19}, [%[din_ptr0]]! \n" + "vld1.32 {d26-d29}, [%[vbias]] \n" + "vext.32 q6, q15, q8, #3 \n" + "vext.32 q7, q15, q9, #3 \n" + "vext.32 q10, q8, q15, #1 \n" + "vmov.32 q14, q15 \n" + + // r0 + "vmla.f32 q13, q8, %f[w0][0] \n" + "vmla.f32 q14, q9, %f[w0][1] \n" + + "vld1.32 {d21[1]}, [%[din_ptr0]] \n" + "vld2.32 {d16-d19}, [%[din_ptr1]]! 
\n" + "sub %[din_ptr0], #8 \n" + + "vmla.f32 q13, q6, %e[w0][0] \n" + "vmla.f32 q14, q7, %e[w0][1] \n" + "vmla.f32 q13, q10, %e[w1][0] \n" + + "vext.32 q6, q15, q8, #3 \n" + "vext.32 q7, q15, q9, #3 \n" + "vext.32 q10, q8, q15, #1 \n" + + // r1 + "vmla.f32 q13, q8, %f[w1][1] \n" + "vmla.f32 q14, q9, %e[w2][0] \n" + + "vld1.32 {d21[1]}, [%[din_ptr1]] \n" + "vld2.32 {d16-d19}, [%[din_ptr2]]! \n" + "sub %[din_ptr1], #8 \n" + + "vmla.f32 q13, q6, %e[w1][1] \n" + "vmla.f32 q14, q7, %f[w1][0] \n" + "vmla.f32 q13, q10, %e[w2][1] \n" + + "vext.32 q6, q15, q8, #3 \n" + "vext.32 q7, q15, q9, #3 \n" + "vext.32 q10, q8, q15, #1 \n" + + // r2 + "vmla.f32 q13, q8, %e[w3][0] \n" + "vmla.f32 q14, q9, %e[w3][1] \n" + + "vld1.32 {d21[1]}, [%[din_ptr2]] \n" + "vld2.32 {d16-d19}, [%[din_ptr3]]! \n" + "sub %[din_ptr2], #8 \n" + + "vmla.f32 q13, q6, %f[w2][0] \n" + "vmla.f32 q14, q7, %f[w2][1] \n" + "vmla.f32 q13, q10, %f[w3][0] \n" + + "vext.32 q6, q15, q8, #3 \n" + "vext.32 q7, q15, q9, #3 \n" + "vext.32 q10, q8, q15, #1 \n" + + // r3 + "vmla.f32 q13, q8, %e[w4][1] \n" + "vmla.f32 q14, q9, %f[w4][0] \n" + + "vld1.32 {d21[1]}, [%[din_ptr3]] \n" + "vld2.32 {d16-d19}, [%[din_ptr4]]! \n" + "sub %[din_ptr3], #8 \n" + + "vmla.f32 q13, q6, %f[w3][1] \n" + "vmla.f32 q14, q7, %e[w4][0] \n" + "vmla.f32 q13, q10, %f[w4][1] \n" + + "vext.32 q6, q15, q8, #3 \n" + "vext.32 q7, q15, q9, #3 \n" + "vext.32 q10, q8, q15, #1 \n" + + // r4 + "vmla.f32 q13, q6, %e[w5][0] \n" + "vmla.f32 q14, q7, %e[w5][1] \n" + + "vld1.32 {d21[1]}, [%[din_ptr4]] \n" + "vld2.32 {d12-d15}, [%[din_ptr0]], %[s_8] \n" + "sub %[din_ptr4], #8 \n" + + "vmla.f32 q13, q8, %f[w5][0] \n" + "vmla.f32 q14, q9, %f[w5][1] \n" + + "vld2.32 {d16-d19}, [%[din_ptr0]], %[s_8] \n" + + "vmov.32 q12, %q[w0] \n" + "vld1.32 {%e[w0][0]}, [%[weights]] \n" + "vmla.f32 q13, q10, %e[w0][0] \n" + "vadd.f32 q13, q13, q14 \n" + "vmov.f32 %q[w0], q12 \n" + "vmax.f32 q13, q13, q15 \n" + "cmp %[mid_cnt], #1 \n" + "vld2.32 {d20-d23}, [%[din_ptr0]], %[s_16] \n" + "vst1.32 {d26-d27}, [%[dout_ptr0]]! 
\n" + "pld [%[din_ptr0]] \n" + "blt 2f \n" + + // mid + "1: \n" + "vld1.32 {d26-d27}, [%[vbias]] \n" + "vmov.32 q14, q15 \n" + + // r0 + "vmla.f32 q13, q6, %e[w0][0] \n" + "vmla.f32 q14, q7, %e[w0][1] \n" + + "vld2.32 {d12-d15}, [%[din_ptr1]], %[s_8] \n" + + "vmla.f32 q13, q8, %f[w0][0] \n" + "vmla.f32 q14, q9, %f[w0][1] \n" + + "vld2.32 {d16-d19}, [%[din_ptr1]], %[s_8] \n" + + "vmla.f32 q13, q10, %e[w1][0] \n" + + "vld2.32 {d20-d23}, [%[din_ptr1]], %[s_16] \n" + + // r1 + "vmla.f32 q13, q6, %e[w1][1] \n" + "vmla.f32 q14, q7, %f[w1][0] \n" + "pld [%[din_ptr1]] \n" + + "vld2.32 {d12-d15}, [%[din_ptr2]], %[s_8] \n" + + "vmla.f32 q13, q8, %f[w1][1] \n" + "vmla.f32 q14, q9, %e[w2][0] \n" + + "vld2.32 {d16-d19}, [%[din_ptr2]], %[s_8] \n" + + "vmla.f32 q13, q10, %e[w2][1] \n" + + "vld2.32 {d20-d23}, [%[din_ptr2]], %[s_16] \n" + + // r2 + "vmla.f32 q13, q6, %f[w2][0] \n" + "vmla.f32 q14, q7, %f[w2][1] \n" + "pld [%[din_ptr2]] \n" + + "vld2.32 {d12-d15}, [%[din_ptr3]], %[s_8] \n" + + "vmla.f32 q13, q8, %e[w3][0] \n" + "vmla.f32 q14, q9, %e[w3][1] \n" + + "vld2.32 {d16-d19}, [%[din_ptr3]], %[s_8] \n" + + "vmla.f32 q13, q10, %f[w3][0] \n" + + "vld2.32 {d20-d23}, [%[din_ptr3]], %[s_16] \n" + + // r3 + "vmla.f32 q13, q6, %f[w3][1] \n" + "vmla.f32 q14, q7, %e[w4][0] \n" + "pld [%[din_ptr3]] \n" + + "vld2.32 {d12-d15}, [%[din_ptr4]], %[s_8] \n" + + "vmla.f32 q13, q8, %e[w4][1] \n" + "vmla.f32 q14, q9, %f[w4][0] \n" + + "vld2.32 {d16-d19}, [%[din_ptr4]], %[s_8] \n" + + "vmla.f32 q13, q10, %f[w4][1] \n" + + "vld2.32 {d20-d23}, [%[din_ptr4]], %[s_16] \n" + + // r4 + "vmla.f32 q13, q6, %e[w5][0] \n" + "vmla.f32 q14, q7, %e[w5][1] \n" + "pld [%[din_ptr4]] \n" + + "vld2.32 {d12-d15}, [%[din_ptr0]], %[s_8] \n" + "vld1.32 {%e[w0][0]}, [%[weights]] \n" + + "vmla.f32 q13, q8, %f[w5][0] \n" + "vmla.f32 q14, q9, %f[w5][1] \n" + + "vld2.32 {d16-d19}, [%[din_ptr0]], %[s_8] \n" + + "vmla.f32 q13, q10, %e[w0][0] \n" + + "vld2.32 {d20-d23}, [%[din_ptr0]], %[s_16] \n" + + "vmov.32 %q[w0], q12 \n" + "vadd.f32 q13, q13, q14 \n" + "vmax.f32 q13, q13, q15 \n" + "subs %[mid_cnt], #1 \n" + "vst1.32 {d26-d27}, [%[dout_ptr0]]! 
\n" + "bne 1b \n" + + "2: \n" + "vld2.32 {d22-d25}, [%[mask]], %[s_8] \n" + "vld1.32 {d26-d27}, [%[vbias]] \n" + "vmov.32 q14, q15 \n" + + // r0 + "vbif.32 q6, q15, q11 \n" + "vbif.32 q7, q15, q12 \n" + "vmla.f32 q13, q6, %e[w0][0] \n" + "vmla.f32 q14, q7, %e[w0][1] \n" + + "vld2.32 {d22-d25}, [%[mask]], %[s_8] \n" + "vld2.32 {d12-d15}, [%[din_ptr1]], %[s_8] \n" + "vbif.32 q8, q15, q11 \n" + "vbif.32 q9, q15, q12 \n" + "vmla.f32 q13, q8, %f[w0][0] \n" + "vmla.f32 q14, q9, %f[w0][1] \n" + + "vld2.32 {d22-d25}, [%[mask]] \n" + "sub %[mask], #16 \n" + "vld2.32 {d16-d19}, [%[din_ptr1]], %[s_8] \n" + "vbif.32 q10, q15, q11 \n" + "vmla.f32 q13, q10, %e[w1][0] \n" + + // r1 + "vld2.32 {d20-d23}, [%[din_ptr1]] \n" + "vld2.32 {d22-d25}, [%[mask]], %[s_8] \n" + "vbif.32 q6, q15, q11 \n" + "vbif.32 q7, q15, q12 \n" + "vmla.f32 q13, q6, %e[w1][1] \n" + "vmla.f32 q14, q7, %f[w1][0] \n" + + "vld2.32 {d22-d25}, [%[mask]], %[s_8] \n" + "vld2.32 {d12-d15}, [%[din_ptr2]], %[s_8] \n" + "vbif.32 q8, q15, q11 \n" + "vbif.32 q9, q15, q12 \n" + "vmla.f32 q13, q8, %f[w1][1] \n" + "vmla.f32 q14, q9, %e[w2][0] \n" + + "vld2.32 {d22-d25}, [%[mask]] \n" + "sub %[mask], #16 \n" + "vld2.32 {d16-d19}, [%[din_ptr2]], %[s_8] \n" + "vbif.32 q10, q15, q11 \n" + "vmla.f32 q13, q10, %e[w2][1] \n" + + // r2 + "vld2.32 {d20-d23}, [%[din_ptr2]] \n" + "vld2.32 {d22-d25}, [%[mask]], %[s_8] \n" + "vbif.32 q6, q15, q11 \n" + "vbif.32 q7, q15, q12 \n" + "vmla.f32 q13, q6, %f[w2][0] \n" + "vmla.f32 q14, q7, %f[w2][1] \n" + + "vld2.32 {d12-d15}, [%[din_ptr3]], %[s_8] \n" + "vld2.32 {d22-d25}, [%[mask]], %[s_8] \n" + "vbif.32 q8, q15, q11 \n" + "vbif.32 q9, q15, q12 \n" + "vmla.f32 q13, q8, %e[w3][0] \n" + "vmla.f32 q14, q9, %e[w3][1] \n" + + "vld2.32 {d22-d25}, [%[mask]] \n" + "sub %[mask], #16 \n" + "vld2.32 {d16-d19}, [%[din_ptr3]], %[s_8] \n" + "vbif.32 q10, q15, q11 \n" + "vmla.f32 q13, q10, %f[w3][0] \n" + + // r3 + "vld2.32 {d20-d23}, [%[din_ptr3]] \n" + "vld2.32 {d22-d25}, [%[mask]], %[s_8] \n" + "vbif.32 q6, q15, q11 \n" + "vbif.32 q7, q15, q12 \n" + "vmla.f32 q13, q6, %f[w3][1] \n" + "vmla.f32 q14, q7, %e[w4][0] \n" + + "vld2.32 {d22-d25}, [%[mask]], %[s_8] \n" + "vld2.32 {d12-d15}, [%[din_ptr4]], %[s_8] \n" + "vbif.32 q8, q15, q11 \n" + "vbif.32 q9, q15, q12 \n" + "vmla.f32 q13, q8, %e[w4][1] \n" + "vmla.f32 q14, q9, %f[w4][0] \n" + + "vld2.32 {d22-d25}, [%[mask]] \n" + "sub %[mask], #16 \n" + "vld2.32 {d16-d19}, [%[din_ptr4]], %[s_8] \n" + "vbif.32 q10, q15, q11 \n" + "vmla.f32 q13, q10, %f[w4][1] \n" + + // r4 + "vld2.32 {d20-d23}, [%[din_ptr4]] \n" + "vld2.32 {d22-d25}, [%[mask]], %[s_8] \n" + "vbif.32 q6, q15, q11 \n" + "vbif.32 q7, q15, q12 \n" + "vmla.f32 q13, q6, %e[w5][0] \n" + "vmla.f32 q14, q7, %e[w5][1] \n" + + "vld2.32 {d22-d25}, [%[mask]], %[s_8] \n" + "vld1.32 {d12[0]}, [%[weights]] \n" + "vbif.32 q8, q15, q11 \n" + "vbif.32 q9, q15, q12 \n" + "vmla.f32 q13, q8, %f[w5][0] \n" + "vmla.f32 q14, q9, %f[w5][1] \n" + + "vld2.32 {d22-d25}, [%[mask]] \n" + "vbif.32 q10, q15, q11 \n" + "vmla.f32 q13, q10, d12[0] \n" + + "vadd.f32 q13, q13, q14 \n" + "vmax.f32 q13, q13, q15 \n" + "vst1.32 {d26-d27}, [%[out_buf0]] \n" + + : [dout_ptr0] "+r"(dout_ptr0), + [mid_cnt] "+r"(loop), + [din_ptr0] "+r"(din_ptr0), + [din_ptr1] "+r"(din_ptr1), + [din_ptr2] "+r"(din_ptr2), + [din_ptr3] "+r"(din_ptr3), + [din_ptr4] "+r"(din_ptr4), + [mask] "+r"(mask_ptr), + [weights] "+r"(weights_ptr) + : [w0] "w"(w0), + [w1] "w"(w1), + [w2] "w"(w2), + [w3] "w"(w3), + [w4] "w"(w4), + [w5] "w"(w5), + [vbias] "r"(vbias), + [out_buf0] "r"(out_buf0), + 
[s_8] "r"(s_8), + [s_16] "r"(s_16) + : "memory", + "cc", + "q6", + "q7", + "q8", + "q9", + "q10", + "q11", + "q12", + "q13", + "q14", + "q15"); + + int remain_cnt = w_out - (mid_cnt + 1) * 4; + for (int i = 0; i < remain_cnt; ++i) { + dout_ptr0[i] = out_buf0[i]; + } + + din0 = din2; + din1 = din3; + din2 = din4; + din3 = din2 + w_in; + din4 = din3 + w_in; + dout0 += w_out; + } + } + } +} + +//! small depthwise, win < 9; +void conv_depthwise_5x5s2p2_s(const float* din, + float* dout, + int num, + int ch_out, + int h_out, + int w_out, + int ch_in, + int h_in, + int w_in, + const float* weights, + const float* bias, + bool flag_bias, + bool flag_relu, + ARMContext* ctx) { + CHECK_LT(w_in, 9) << "only support win < 9"; + int w_out_round = (w_out + 3) / 4 * 4; + int mask_cnt = 12 - w_in - 2; + int mask[12]; + memset(mask, 0xff, 12 * sizeof(int)); + for (int i = 0; i < mask_cnt; ++i) { + mask[11 - i] = 0; + } + float* zero_ptr = ctx->workspace_data(); + memset(zero_ptr, 0, w_in * sizeof(float)); + int in_spatial_size = w_in * h_in; + int out_spatial_size = w_out * h_out; + int weights_saptial_size = 25; + + for (int n = 0; n < num; ++n) { + const float* din_batch = din + n * in_spatial_size * ch_in; + float* dout_batch = dout + n * out_spatial_size * ch_out; +#pragma omp parallel for + for (int c = 0; c < ch_in; ++c) { + const float* din_ch = din_batch + c * in_spatial_size; + float* dout_ch = dout_batch + c * out_spatial_size; + const float* din0 = zero_ptr; + const float* din1 = zero_ptr; + const float* din2 = din_ch; + const float* din3 = din2 + w_in; + const float* din4 = din3 + w_in; + + float out_buf0[4]; + float out_buf1[4]; + float* dout0 = dout_ch; + float* dout1 = dout0 + w_out; + + const float* weights_c = weights + c * weights_saptial_size; + float32x4_t w0 = vld1q_f32(weights_c); + float32x4_t w1 = vld1q_f32(weights_c + 4); + float32x4_t w2 = vld1q_f32(weights_c + 8); + float32x4_t w3 = vld1q_f32(weights_c + 12); + float32x4_t w4 = vld1q_f32(weights_c + 16); + float32x4_t w5 = vld1q_f32(weights_c + 20); + for (int h = 0; h < h_out; h += 1) { + //! (h * 2 - 2) + 4 > h_in - 1 + if (h * 2 + 3 > h_in) { + switch (h * 2 + 3 - h_in) { + case 4: + din1 = zero_ptr; + case 3: + din2 = zero_ptr; + case 2: + din3 = zero_ptr; + case 1: + din4 = zero_ptr; + default: + break; + } + } + const float* din_ptr0 = din0; + const float* din_ptr1 = din1; + const float* din_ptr2 = din2; + const float* din_ptr3 = din3; + const float* din_ptr4 = din4; + + const float* weights_ptr = weights_c + 24; + float* dout_ptr0 = dout0; + + float bias_c = 0.f; + if (flag_bias) { + bias_c = bias[c]; + } + float vbias[4] = {bias_c, bias_c, bias_c, bias_c}; + int* mask_ptr = mask; + const int s_8 = 8; + + asm volatile( + "vmov.i32 q15, #0x0 \n" + "pld [%[din_ptr0]] \n" + "pld [%[din_ptr1]] \n" + "pld [%[din_ptr2]] \n" + "pld [%[din_ptr3]] \n" + "pld [%[din_ptr4]] \n" + "vld1.32 {d26-d27}, [%[vbias]] \n" + "vmov.32 q14, q15 \n" + "vld2.32 {d16-d19}, [%[din_ptr0]]! \n" + + // r0 + "vld2.32 {d22-d25}, [%[mask]], %[s_8] \n" + "vext.32 q6, q15, q8, #3 \n" + "vext.32 q7, q15, q9, #3 \n" + "vext.32 q10, q8, q15, #1 \n" + "vld1.32 {d21[1]}, [%[din_ptr0]] \n" + + "vbif.32 q6, q15, q11 \n" + "vbif.32 q7, q15, q12 \n" + "vmla.f32 q13, q6, %e[w0][0] \n" + "vmla.f32 q14, q7, %e[w0][1] \n" + + "vld2.32 {d22-d25}, [%[mask]], %[s_8] \n" + "vbif.32 q8, q15, q11 \n" + "vbif.32 q9, q15, q12 \n" + "vmla.f32 q13, q8, %f[w0][0] \n" + "vmla.f32 q14, q9, %f[w0][1] \n" + + "vld2.32 {d22-d25}, [%[mask]] \n" + "vld2.32 {d16-d19}, [%[din_ptr1]]! 
\n" + "sub %[mask], #16 \n" + "vbif.32 q10, q15, q11 \n" + "vmla.f32 q13, q10, %e[w1][0] \n" + + // r1 + "vld2.32 {d22-d25}, [%[mask]], %[s_8] \n" + "vext.32 q6, q15, q8, #3 \n" + "vext.32 q7, q15, q9, #3 \n" + "vext.32 q10, q8, q15, #1 \n" + "vld1.32 {d21[1]}, [%[din_ptr1]] \n" + + "vbif.32 q6, q15, q11 \n" + "vbif.32 q7, q15, q12 \n" + "vmla.f32 q14, q6, %e[w1][1] \n" + "vmla.f32 q13, q7, %f[w1][0] \n" + + "vld2.32 {d22-d25}, [%[mask]], %[s_8] \n" + "vbif.32 q8, q15, q11 \n" + "vbif.32 q9, q15, q12 \n" + "vmla.f32 q14, q8, %f[w1][1] \n" + "vmla.f32 q13, q9, %e[w2][0] \n" + + "vld2.32 {d22-d25}, [%[mask]] \n" + "vld2.32 {d16-d19}, [%[din_ptr2]]! \n" + "sub %[mask], #16 \n" + "vbif.32 q10, q15, q11 \n" + "vmla.f32 q14, q10, %e[w2][1] \n" + + // r2 + "vld2.32 {d22-d25}, [%[mask]], %[s_8] \n" + "vext.32 q6, q15, q8, #3 \n" + "vext.32 q7, q15, q9, #3 \n" + "vext.32 q10, q8, q15, #1 \n" + "vld1.32 {d21[1]}, [%[din_ptr2]] \n" + + "vbif.32 q6, q15, q11 \n" + "vbif.32 q7, q15, q12 \n" + "vmla.f32 q13, q6, %f[w2][0] \n" + "vmla.f32 q14, q7, %f[w2][1] \n" + + "vld2.32 {d22-d25}, [%[mask]], %[s_8] \n" + "vbif.32 q8, q15, q11 \n" + "vbif.32 q9, q15, q12 \n" + "vmla.f32 q13, q8, %e[w3][0] \n" + "vmla.f32 q14, q9, %e[w3][1] \n" + + "vld2.32 {d22-d25}, [%[mask]] \n" + "vld2.32 {d16-d19}, [%[din_ptr3]]! \n" + "sub %[mask], #16 \n" + "vbif.32 q10, q15, q11 \n" + "vmla.f32 q13, q10, %f[w3][0] \n" + + // r3 + "vld2.32 {d22-d25}, [%[mask]], %[s_8] \n" + "vext.32 q6, q15, q8, #3 \n" + "vext.32 q7, q15, q9, #3 \n" + "vext.32 q10, q8, q15, #1 \n" + "vld1.32 {d21[1]}, [%[din_ptr3]] \n" + + "vbif.32 q6, q15, q11 \n" + "vbif.32 q7, q15, q12 \n" + "vmla.f32 q14, q6, %f[w3][1] \n" + "vmla.f32 q13, q7, %e[w4][0] \n" + + "vld2.32 {d22-d25}, [%[mask]], %[s_8] \n" + "vbif.32 q8, q15, q11 \n" + "vbif.32 q9, q15, q12 \n" + "vmla.f32 q14, q8, %e[w4][1] \n" + "vmla.f32 q13, q9, %f[w4][0] \n" + + "vld2.32 {d22-d25}, [%[mask]] \n" + "vld2.32 {d16-d19}, [%[din_ptr4]]! \n" + "sub %[mask], #16 \n" + "vbif.32 q10, q15, q11 \n" + "vmla.f32 q14, q10, %f[w4][1] \n" + + // r4 + "vld2.32 {d22-d25}, [%[mask]], %[s_8] \n" + "vext.32 q6, q15, q8, #3 \n" + "vext.32 q7, q15, q9, #3 \n" + "vext.32 q10, q8, q15, #1 \n" + "vld1.32 {d21[1]}, [%[din_ptr4]] \n" + + "vbif.32 q6, q15, q11 \n" + "vbif.32 q7, q15, q12 \n" + "vmla.f32 q13, q6, %e[w5][0] \n" + "vmla.f32 q14, q7, %e[w5][1] \n" + + "vld2.32 {d22-d25}, [%[mask]], %[s_8] \n" + "vld1.32 {d12[0]}, [%[weights]] \n" + "vbif.32 q8, q15, q11 \n" + "vbif.32 q9, q15, q12 \n" + "vmla.f32 q13, q8, %f[w5][0] \n" + "vmla.f32 q14, q9, %f[w5][1] \n" + + "vld2.32 {d22-d25}, [%[mask]] \n" + "vbif.32 q10, q15, q11 \n" + "vmla.f32 q13, q10, d12[0] \n" + + "vadd.f32 q13, q13, q14 \n" + "vst1.32 {d26-d27}, [%[out_buf0]] \n" + + : [dout_ptr0] "+r"(dout_ptr0), + [din_ptr0] "+r"(din_ptr0), + [din_ptr1] "+r"(din_ptr1), + [din_ptr2] "+r"(din_ptr2), + [din_ptr3] "+r"(din_ptr3), + [din_ptr4] "+r"(din_ptr4), + [mask] "+r"(mask_ptr), + [weights] "+r"(weights_ptr) + : [vbias] "r"(vbias), + [out_buf0] "r"(out_buf0), + [s_8] "r"(s_8), + [w0] "w"(w0), + [w1] "w"(w1), + [w2] "w"(w2), + [w3] "w"(w3), + [w4] "w"(w4), + [w5] "w"(w5) + : "memory", + "cc", + "q6", + "q7", + "q8", + "q9", + "q10", + "q11", + "q12", + "q13", + "q14", + "q15"); + for (int i = 0; i < w_out; ++i) { + dout_ptr0[i] = out_buf0[i]; + } + din0 = din2; + din1 = din3; + din2 = din4; + din3 = din2 + w_in; + din4 = din3 + w_in; + dout0 += w_out; + } + } + } +} + +//! 
small depthwise, win < 9;
+void conv_depthwise_5x5s2p2_relu_s(const float* din,
+ float* dout,
+ int num,
+ int ch_out,
+ int h_out,
+ int w_out,
+ int ch_in,
+ int h_in,
+ int w_in,
+ const float* weights,
+ const float* bias,
+ bool flag_bias,
+ bool flag_relu,
+ ARMContext* ctx) {
+ CHECK_LT(w_in, 9) << "only support win < 9";
+ int w_out_round = (w_out + 3) / 4 * 4;
+ int mask_cnt = 12 - w_in - 2;
+ int mask[12];
+ memset(mask, 0xff, 12 * sizeof(int));
+ for (int i = 0; i < mask_cnt; ++i) {
+ mask[11 - i] = 0;
+ }
+ float* zero_ptr = ctx->workspace_data<float>();
+ memset(zero_ptr, 0, w_in * sizeof(float));
+ int in_spatial_size = w_in * h_in;
+ int out_spatial_size = w_out * h_out;
+ int weights_spatial_size = 25;
+
+ for (int n = 0; n < num; ++n) {
+ const float* din_batch = din + n * in_spatial_size * ch_in;
+ float* dout_batch = dout + n * out_spatial_size * ch_out;
+#pragma omp parallel for
+ for (int c = 0; c < ch_in; ++c) {
+ const float* din_ch = din_batch + c * in_spatial_size;
+ float* dout_ch = dout_batch + c * out_spatial_size;
+ const float* din0 = zero_ptr;
+ const float* din1 = zero_ptr;
+ const float* din2 = din_ch;
+ const float* din3 = din2 + w_in;
+ const float* din4 = din3 + w_in;
+
+ float out_buf0[4];
+ float out_buf1[4];
+ float* dout0 = dout_ch;
+ float* dout1 = dout0 + w_out;
+
+ const float* weights_c = weights + c * weights_spatial_size;
+ float32x4_t w0 = vld1q_f32(weights_c);
+ float32x4_t w1 = vld1q_f32(weights_c + 4);
+ float32x4_t w2 = vld1q_f32(weights_c + 8);
+ float32x4_t w3 = vld1q_f32(weights_c + 12);
+ float32x4_t w4 = vld1q_f32(weights_c + 16);
+ float32x4_t w5 = vld1q_f32(weights_c + 20);
+ for (int h = 0; h < h_out; h += 1) {
+ //! (h * 2 - 2) + 4 > h_in - 1
+ if (h * 2 + 3 > h_in) {
+ switch (h * 2 + 3 - h_in) {
+ case 4:
+ din1 = zero_ptr;  // fall through
+ case 3:
+ din2 = zero_ptr;  // fall through
+ case 2:
+ din3 = zero_ptr;  // fall through
+ case 1:
+ din4 = zero_ptr;
+ default:
+ break;
+ }
+ }
+ const float* din_ptr0 = din0;
+ const float* din_ptr1 = din1;
+ const float* din_ptr2 = din2;
+ const float* din_ptr3 = din3;
+ const float* din_ptr4 = din4;
+
+ const float* weights_ptr = weights_c + 24;
+ float* dout_ptr0 = dout0;
+
+ float bias_c = 0.f;
+ if (flag_bias) {
+ bias_c = bias[c];
+ }
+ float vbias[4] = {bias_c, bias_c, bias_c, bias_c};
+ int* mask_ptr = mask;
+ const int s_8 = 8;
+
+ asm volatile(
+ "vmov.i32 q15, #0x0 \n"
+ "pld [%[din_ptr0]] \n"
+ "pld [%[din_ptr1]] \n"
+ "pld [%[din_ptr2]] \n"
+ "pld [%[din_ptr3]] \n"
+ "pld [%[din_ptr4]] \n"
+ "vld1.32 {d26-d27}, [%[vbias]] \n"
+ "vmov.32 q14, q15 \n"
+ "vld2.32 {d16-d19}, [%[din_ptr0]]! \n"
+
+ // r0
+ "vld2.32 {d22-d25}, [%[mask]], %[s_8] \n"
+ "vext.32 q6, q15, q8, #3 \n"
+ "vext.32 q7, q15, q9, #3 \n"
+ "vext.32 q10, q8, q15, #1 \n"
+ "vld1.32 {d21[1]}, [%[din_ptr0]] \n"
+
+ "vbif.32 q6, q15, q11 \n"
+ "vbif.32 q7, q15, q12 \n"
+ "vmla.f32 q13, q6, %e[w0][0] \n"
+ "vmla.f32 q14, q7, %e[w0][1] \n"
+
+ "vld2.32 {d22-d25}, [%[mask]], %[s_8] \n"
+ "vbif.32 q8, q15, q11 \n"
+ "vbif.32 q9, q15, q12 \n"
+ "vmla.f32 q13, q8, %f[w0][0] \n"
+ "vmla.f32 q14, q9, %f[w0][1] \n"
+
+ "vld2.32 {d22-d25}, [%[mask]] \n"
+ "vld2.32 {d16-d19}, [%[din_ptr1]]!
\n" + "sub %[mask], #16 \n" + "vbif.32 q10, q15, q11 \n" + "vmla.f32 q13, q10, %e[w1][0] \n" + + // r1 + "vld2.32 {d22-d25}, [%[mask]], %[s_8] \n" + "vext.32 q6, q15, q8, #3 \n" + "vext.32 q7, q15, q9, #3 \n" + "vext.32 q10, q8, q15, #1 \n" + "vld1.32 {d21[1]}, [%[din_ptr1]] \n" + + "vbif.32 q6, q15, q11 \n" + "vbif.32 q7, q15, q12 \n" + "vmla.f32 q14, q6, %e[w1][1] \n" + "vmla.f32 q13, q7, %f[w1][0] \n" + + "vld2.32 {d22-d25}, [%[mask]], %[s_8] \n" + "vbif.32 q8, q15, q11 \n" + "vbif.32 q9, q15, q12 \n" + "vmla.f32 q14, q8, %f[w1][1] \n" + "vmla.f32 q13, q9, %e[w2][0] \n" + + "vld2.32 {d22-d25}, [%[mask]] \n" + "vld2.32 {d16-d19}, [%[din_ptr2]]! \n" + "sub %[mask], #16 \n" + "vbif.32 q10, q15, q11 \n" + "vmla.f32 q14, q10, %e[w2][1] \n" + + // r2 + "vld2.32 {d22-d25}, [%[mask]], %[s_8] \n" + "vext.32 q6, q15, q8, #3 \n" + "vext.32 q7, q15, q9, #3 \n" + "vext.32 q10, q8, q15, #1 \n" + "vld1.32 {d21[1]}, [%[din_ptr2]] \n" + + "vbif.32 q6, q15, q11 \n" + "vbif.32 q7, q15, q12 \n" + "vmla.f32 q13, q6, %f[w2][0] \n" + "vmla.f32 q14, q7, %f[w2][1] \n" + + "vld2.32 {d22-d25}, [%[mask]], %[s_8] \n" + "vbif.32 q8, q15, q11 \n" + "vbif.32 q9, q15, q12 \n" + "vmla.f32 q13, q8, %e[w3][0] \n" + "vmla.f32 q14, q9, %e[w3][1] \n" + + "vld2.32 {d22-d25}, [%[mask]] \n" + "vld2.32 {d16-d19}, [%[din_ptr3]]! \n" + "sub %[mask], #16 \n" + "vbif.32 q10, q15, q11 \n" + "vmla.f32 q13, q10, %f[w3][0] \n" + + // r3 + "vld2.32 {d22-d25}, [%[mask]], %[s_8] \n" + "vext.32 q6, q15, q8, #3 \n" + "vext.32 q7, q15, q9, #3 \n" + "vext.32 q10, q8, q15, #1 \n" + "vld1.32 {d21[1]}, [%[din_ptr3]] \n" + + "vbif.32 q6, q15, q11 \n" + "vbif.32 q7, q15, q12 \n" + "vmla.f32 q14, q6, %f[w3][1] \n" + "vmla.f32 q13, q7, %e[w4][0] \n" + + "vld2.32 {d22-d25}, [%[mask]], %[s_8] \n" + "vbif.32 q8, q15, q11 \n" + "vbif.32 q9, q15, q12 \n" + "vmla.f32 q14, q8, %e[w4][1] \n" + "vmla.f32 q13, q9, %f[w4][0] \n" + + "vld2.32 {d22-d25}, [%[mask]] \n" + "vld2.32 {d16-d19}, [%[din_ptr4]]! 
\n" + "sub %[mask], #16 \n" + "vbif.32 q10, q15, q11 \n" + "vmla.f32 q14, q10, %f[w4][1] \n" + + // r4 + "vld2.32 {d22-d25}, [%[mask]], %[s_8] \n" + "vext.32 q6, q15, q8, #3 \n" + "vext.32 q7, q15, q9, #3 \n" + "vext.32 q10, q8, q15, #1 \n" + "vld1.32 {d21[1]}, [%[din_ptr4]] \n" + + "vbif.32 q6, q15, q11 \n" + "vbif.32 q7, q15, q12 \n" + "vmla.f32 q13, q6, %e[w5][0] \n" + "vmla.f32 q14, q7, %e[w5][1] \n" + + "vld2.32 {d22-d25}, [%[mask]], %[s_8] \n" + "vld1.32 {d12[0]}, [%[weights]] \n" + "vbif.32 q8, q15, q11 \n" + "vbif.32 q9, q15, q12 \n" + "vmla.f32 q13, q8, %f[w5][0] \n" + "vmla.f32 q14, q9, %f[w5][1] \n" + + "vld2.32 {d22-d25}, [%[mask]] \n" + "vbif.32 q10, q15, q11 \n" + "vmla.f32 q13, q10, d12[0] \n" + + "vadd.f32 q13, q13, q14 \n" + "vmax.f32 q13, q13, q15 \n" + "vst1.32 {d26-d27}, [%[out_buf0]] \n" + + : [dout_ptr0] "+r"(dout_ptr0), + [din_ptr0] "+r"(din_ptr0), + [din_ptr1] "+r"(din_ptr1), + [din_ptr2] "+r"(din_ptr2), + [din_ptr3] "+r"(din_ptr3), + [din_ptr4] "+r"(din_ptr4), + [mask] "+r"(mask_ptr), + [weights] "+r"(weights_ptr) + : [vbias] "r"(vbias), + [out_buf0] "r"(out_buf0), + [s_8] "r"(s_8), + [w0] "w"(w0), + [w1] "w"(w1), + [w2] "w"(w2), + [w3] "w"(w3), + [w4] "w"(w4), + [w5] "w"(w5) + : "memory", + "cc", + "q6", + "q7", + "q8", + "q9", + "q10", + "q11", + "q12", + "q13", + "q14", + "q15"); + for (int i = 0; i < w_out; ++i) { + dout_ptr0[i] = out_buf0[i]; + } + din0 = din2; + din1 = din3; + din2 = din4; + din3 = din2 + w_in; + din4 = din3 + w_in; + dout0 += w_out; + } + } + } +} +#endif // __aarch64__ + +} // namespace math +} // namespace arm +} // namespace lite +} // namespace paddle diff --git a/lite/arm/math/conv_direct.cc b/lite/arm/math/conv_direct.cc new file mode 100644 index 00000000000..5e61aa4367e --- /dev/null +++ b/lite/arm/math/conv_direct.cc @@ -0,0 +1,242 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+
+#include "lite/arm/math/conv_direct.h"
+#include "lite/arm/math/conv_block_utils.h"
+#include "lite/arm/math/conv_impl.h"
+
+namespace paddle {
+namespace lite {
+namespace arm {
+namespace math {
+
+template <>
+bool DirectConv<PRECISION(kFloat), PRECISION(kFloat)>::create(
+ const operators::ConvParam& param, ARMContext* ctx) {
+ this->ctx_ = ctx;
+ auto x_dims = param.x->dims();
+ auto w_dims = param.filter->dims();
+ auto o_dims = param.output->dims();
+
+ int iw = x_dims[3]; // nchw
+ int ic = x_dims[1];
+ int ow = o_dims[3];
+ int oc = o_dims[1];
+ int kw = w_dims[3];
+ int sw = param.strides[1];
+ // select direct conv kernel
+ const auto* w_data = param.filter->data<float>();
+ if (kw == 3 && sw == 1) {
+ VLOG(5) << "invoke 3x3s1 direct conv";
+ impl_ = conv_3x3s1_direct_fp32;
+
+ constexpr int cblock = 4;
+ int cround = (oc + cblock - 1) / cblock * cblock;
+ weights_trans_.Resize({cround, ic, kw, kw});
+ float* transed_w_data = weights_trans_.mutable_data<float>();
+
+ conv_trans_weights_numc(w_data, transed_w_data, oc, ic, cblock, kw * kw);
+ is_weights_transed_ = true;
+ } else if (kw == 3 && sw == 2) {
+ VLOG(5) << "invoke 3x3s2 direct conv";
+ impl_ = conv_3x3s2_direct_fp32;
+
+ constexpr int cblock = 4;
+ int cround = (oc + cblock - 1) / cblock * cblock;
+ weights_trans_.Resize({cround, ic, kw, kw});
+ float* transed_w_data = weights_trans_.mutable_data<float>();
+ conv_trans_weights_numc(w_data, transed_w_data, oc, ic, cblock, kw * kw);
+ is_weights_transed_ = true;
+ } else {
+ LOG(ERROR) << "this type of direct conv is not implemented";
+ return false;
+ }
+ return true;
+}
+
+template <>
+bool DirectConv<PRECISION(kFloat), PRECISION(kFloat)>::init(
+ const operators::ConvParam& param, Context<TARGET(kARM)>* ctx) {
+ this->ctx_ = ctx;
+ return create(param, ctx);
+}
+
+template <>
+bool DirectConv<PRECISION(kFloat), PRECISION(kFloat)>::run(
+ const operators::ConvParam& param) {
+ // start timer
+ const auto* i_data = param.x->data<float>();
+ const auto* w_data = param.filter->data<float>();
+ const auto* b_data = param.bias ? param.bias->data<float>() : nullptr;
+ auto* o_data = param.output->mutable_data<float>();
+
+ if (is_weights_transed_ == true) {
+ w_data = weights_trans_.data<float>();
+ }
+ auto x_dims = param.x->dims();
+ auto w_dims = param.filter->dims();
+ auto o_dims = param.output->dims();
+
+ int iw = x_dims[3]; // nchw
+ int ih = x_dims[2];
+ int ic = x_dims[1];
+ int bs = x_dims[0];
+ int oh = o_dims[2];
+ int ow = o_dims[3];
+ int oc = o_dims[1];
+
+ impl_(i_data,
+ o_data,
+ bs,
+ oc,
+ oh,
+ ow,
+ ic,
+ ih,
+ iw,
+ w_data,
+ b_data,
+ param,
+ this->ctx_);
+
+ // timer end
+ return true;
+}
+
+template <PrecisionType Ptype_out>
+bool DirectConvInt8<Ptype_out>::create(const operators::ConvParam& param,
+ ARMContext* ctx) {
+ this->ctx_ = ctx;
+ auto x_dims = param.x->dims();
+ auto w_dims = param.filter->dims();
+ auto o_dims = param.output->dims();
+
+ int iw = x_dims[3]; // nchw
+ int ic = x_dims[1];
+ int ow = o_dims[3];
+ int oc = o_dims[1];
+ int kw = w_dims[3];
+ int sw = param.strides[1];
+ // select direct conv kernel
+ w_scale_ = param.weight_scale;
+ //! update weights scale
+ const auto* w_data = param.filter->data<int8_t>();
+ if (Ptype_out == PRECISION(kInt8) || Ptype_out == PRECISION(kFloat)) {
+ CHECK_EQ(this->w_scale_.size(), oc) << "weights scale size must be chout";
+ float input_scale = param.input_scale;
+ for (auto& w_s : w_scale_) {
+ w_s *= input_scale;
+ if (Ptype_out == PRECISION(kInt8)) {
+ w_s /= param.output_scale;
+ }
+ }
+ }
+ if (kw == 3 && sw == 1) {
+ VLOG(5) << "invoke 3x3s1 direct conv";
+ impl_int8_ = conv_3x3s1_direct_int8;
+
+ constexpr int cblock = 4;
+ int inpad = 4;
+ int cround = (oc + cblock - 1) / cblock * cblock;
+ weights_trans_.Resize({cround, ic, kw, kw});
+ int8_t* transed_w_data = weights_trans_.mutable_data<int8_t>();
+ conv_trans_weights_numc(w_data, transed_w_data, oc, ic, cblock, kw * kw);
+
+ int wout_round = ((ow + 3) / 4) * 4;
+ int win_round = wout_round * sw + inpad;
+ int row_out = 2;
+ int row_in = 4;
+ int tmp_size_out = wout_round * row_out * cblock;
+ int in_len = win_round * ic;
+ int tmp_size_in = row_in * in_len;
+ ctx_->ExtendWorkspace(ctx_->threads() * tmp_size_out +
+ (tmp_size_in + 3) / 4 * 4 + wout_round + win_round);
+ is_weights_transed_ = true;
+
+ } else if (kw == 3 && sw == 2) {
+ VLOG(5) << "invoke 3x3s2 direct conv";
+ impl_int8_ = conv_3x3s2_direct_int8;
+
+ // constexpr int cblock = 4;
+ int cblock = conv_3x3s2_direct_int8_c_num();
+ int cround = (oc + cblock - 1) / cblock * cblock;
+ weights_trans_.Resize({cround, ic, kw, kw});
+ int8_t* transed_w_data = weights_trans_.mutable_data<int8_t>();
+ conv_trans_weights_numc(w_data, transed_w_data, oc, ic, cblock, kw * kw);
+ is_weights_transed_ = true;
+
+ } else {
+ LOG(ERROR) << "this type of direct conv is not implemented";
+ return false;
+ }
+ return true;
+}
+
+template <PrecisionType Ptype_out>
+bool DirectConvInt8<Ptype_out>::init(const operators::ConvParam& param,
+ Context<TARGET(kARM)>* ctx) {
+ this->ctx_ = ctx;
+ return create(param, ctx);
+}
+
+template <PrecisionType Ptype_out>
+bool DirectConvInt8<Ptype_out>::run(const operators::ConvParam& param) {
+ // start timer
+ const auto* i_data = param.x->data<int8_t>();
+ const auto* w_data = param.filter->data<int8_t>();
+ const auto* b_data = param.bias ? param.bias->data<int32_t>() : nullptr;
+ auto* o_data = param.output->mutable_data<int32_t>();
+ if (is_weights_transed_ == true) {
+ w_data = weights_trans_.data<int8_t>();
+ }
+ auto x_dims = param.x->dims();
+ auto w_dims = param.filter->dims();
+ auto o_dims = param.output->dims();
+
+ int iw = x_dims[3]; // nchw
+ int ih = x_dims[2];
+ int ic = x_dims[1];
+ int bs = x_dims[0];
+ int oh = o_dims[2];
+ int ow = o_dims[3];
+ int oc = o_dims[1];
+
+ impl_int8_(i_data,
+ o_data,
+ bs,
+ oc,
+ oh,
+ ow,
+ ic,
+ ih,
+ iw,
+ w_data,
+ b_data,
+ param,
+ this->ctx_,
+ Ptype_out,
+ w_scale_.data());
+
+ // Modified from int32 for debug convenience
+ if (Ptype_out == PRECISION(kInt8)) param.output->mutable_data<int8_t>();
+ return true;
+}
+
+template class DirectConvInt8<PRECISION(kInt8)>;
+template class DirectConvInt8<PRECISION(kFloat)>;
+template class DirectConvInt8<PRECISION(kInt32)>;
+
+} // namespace math
+} // namespace arm
+} // namespace lite
+} // namespace paddle
diff --git a/lite/arm/math/conv_direct.h b/lite/arm/math/conv_direct.h
new file mode 100644
index 00000000000..b2883cd7e99
--- /dev/null
+++ b/lite/arm/math/conv_direct.h
@@ -0,0 +1,107 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <string>
+#include <vector>
+#include "lite/arm/math/conv_impl.h"
+#include "lite/core/context.h"
+#include "lite/core/target_wrapper.h"
+
+namespace paddle {
+namespace lite {
+namespace arm {
+namespace math {
+
+template <PrecisionType Ptype, PrecisionType OutType>
+class DirectConv : public ImplBase<TARGET(kARM), Ptype, operators::ConvParam> {
+ public:
+ typedef void (*conv_direct_impl)(const float* din,
+ float* dout,
+ int num,
+ int chout,
+ int hout,
+ int wout,
+ int chin,
+ int hin,
+ int win,
+ const float* weights,
+ const float* bias,
+ const operators::ConvParam& param,
+ Context<TARGET(kARM)>* ctx);
+
+ DirectConv() = default;
+ ~DirectConv() {}
+
+ virtual bool init(const operators::ConvParam& param,
+ Context<TARGET(kARM)>* ctx);
+
+ virtual bool create(const operators::ConvParam& param,
+ Context<TARGET(kARM)>* ctx);
+
+ virtual bool run(const operators::ConvParam& param);
+
+ protected:
+ bool is_weights_transed_{false};
+ Tensor weights_trans_;
+ Tensor _tmp_out;
+
+ private:
+ conv_direct_impl impl_{nullptr};
+};
+
+template <PrecisionType Ptype_out>
+class DirectConvInt8
+    : public ImplBase<TARGET(kARM), PRECISION(kInt8), operators::ConvParam> {
+ public:
+ typedef void (*conv_direct_int8_impl)(const int8_t* din,
+ int32_t* dout,
+ int num,
+ int chout,
+ int hout,
+ int wout,
+ int chin,
+ int hin,
+ int win,
+ const int8_t* weights,
+ const int32_t* bias,
+ const operators::ConvParam& param,
+ Context<TARGET(kARM)>* ctx,
+ PrecisionType out_type,
+ const float* scale);
+
+ DirectConvInt8() = default;
+ ~DirectConvInt8() {}
+
+ virtual bool init(const operators::ConvParam& param,
+ Context<TARGET(kARM)>* ctx);
+
+ virtual bool create(const operators::ConvParam& param,
+ Context<TARGET(kARM)>* ctx);
+
+ virtual bool run(const operators::ConvParam& param);
+
+ private:
+ bool is_weights_transed_{false};
+ Tensor weights_trans_;
+ Tensor _tmp_out;
+ conv_direct_int8_impl impl_int8_{nullptr};
+ std::vector<float> w_scale_;
+};
+
+} // namespace math
+} // namespace arm
+} // namespace lite
+} // namespace paddle
diff --git a/lite/arm/math/conv_direct_3x3s1.cc b/lite/arm/math/conv_direct_3x3s1.cc
new file mode 100644
index 00000000000..d1973705ecd
--- /dev/null
+++ b/lite/arm/math/conv_direct_3x3s1.cc
@@ -0,0 +1,1067 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
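+//
+// Blocking scheme (editorial summary of the code below): each inner-loop
+// iteration of conv_3x3s1_direct_fp32 produces a tile of 4 output channels x
+// 2 output rows x 4 output columns (hout_c_block / hout_r_kernel /
+// wout_block); hout_r_block is derived from the L2 cache size so that the
+// pre-packed input rows plus the per-thread output buffers stay
+// cache-resident.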
+
+#include <arm_neon.h>
+#include "lite/arm/math/conv_block_utils.h"
+#include "lite/arm/math/conv_impl.h"
+#include "lite/core/context.h"
+#include "lite/operators/op_params.h"
+#ifdef ARM_WITH_OMP
+#include <omp.h>
+#endif
+
+namespace paddle {
+namespace lite {
+namespace arm {
+namespace math {
+
+void conv_3x3s1_direct_fp32(const float* i_data,
+ float* o_data,
+ int bs,
+ int oc,
+ int oh,
+ int ow,
+ int ic,
+ int ih,
+ int win,
+ const float* weights,
+ const float* bias,
+ const operators::ConvParam& param,
+ ARMContext* ctx) {
+ const int threads = ctx->threads();
+ int l2_size = ctx->llc_size() / sizeof(float);
+
+ const int pad_h = param.paddings[0];
+ const int pad_w = param.paddings[1];
+ const int hout_c_block = 4;
+ const int hout_r_kernel = 2;
+ const int wout_block = 4;
+ const int wout_round = ((ow + wout_block - 1) / wout_block) * wout_block;
+ const int win_round = wout_round + 2;
+ bool flag_relu = param.fuse_relu;
+ bool flag_bias = param.bias != nullptr;
+ // if (param.activation_param.has_active) {
+ // if (param.activation_param.active == Active_relu &&
+ // fabs(param.activation_param.negative_slope) < 1e-6f) {
+ // flag_relu = true;
+ // }
+ // }
+ int hout_r_block = (l2_size - 2 * win_round * ic) /
+ (win_round * ic + hout_c_block * wout_round * threads);
+ hout_r_block = hout_r_block > oh ? oh : hout_r_block;
+ hout_r_block = (hout_r_block / hout_r_kernel) * hout_r_kernel;
+ hout_r_block = hout_r_block < hout_r_kernel ? hout_r_kernel : hout_r_block;
+
+ const int hin_r_block = hout_r_block + 2;
+
+ float* tmp_work_space = ctx->workspace_data<float>();
+ float ptr_zero[win_round]; // NOLINT
+ memset(ptr_zero, 0, sizeof(float) * win_round);
+ float ptr_write[wout_round]; // NOLINT
+
+ int in_len = win_round * ic;
+ int pre_in_size = hin_r_block * in_len;
+ int pre_out_size = hout_c_block * hout_r_block * wout_round;
+
+ float* pre_din = tmp_work_space;
+
+ int size_in_channel = win * ih;
+ int size_out_channel = ow * oh;
+ int w_stride = ic * 9; // kernel_w * kernel_h
+ int w_stride_chin = hout_c_block * 9; // kernel_w * kernel_h * hout_c_block
+
+ int ws = -pad_w;
+ int we = ws + win_round;
+ int w_loop = wout_round / 4;
+
+ int c_remain = oc - (oc / hout_c_block) * hout_c_block;
+ int c_round_down = (oc / hout_c_block) * hout_c_block;
+
+ int out_row_stride = hout_c_block * wout_round;
+ for (int n = 0; n < bs; ++n) {
+ const float* din_batch = i_data + n * ic * size_in_channel;
+ float* dout_batch = o_data + n * oc * size_out_channel;
+ for (int h = 0; h < oh; h += hout_r_block) {
+ int h_kernel = hout_r_block;
+ if (h + hout_r_block > oh) {
+ h_kernel = oh - h;
+ }
+ int hs = h - pad_h;
+ int he = hs + h_kernel + 2;
+ prepack_input_nxw(
+ din_batch, pre_din, 0, ic, hs, he, ws, we, ic, win, ih, ptr_zero);
+#pragma omp parallel for num_threads(threads)
+ for (int c = 0; c < oc - (hout_c_block - 1); c += hout_c_block) {
+#ifdef ARM_WITH_OMP
+ float* pre_out =
+ pre_din + pre_in_size + omp_get_thread_num() * pre_out_size;
+#else
+ float* pre_out = pre_din + pre_in_size;
+#endif
+ const float* block_inr0 = pre_din;
+ const float* block_inr1 = block_inr0 + in_len;
+ const float* block_inr2 = block_inr1 + in_len;
+ const float* block_inr3 = block_inr2 + in_len;
+
+ const float* weight_c = weights + c * w_stride;
+ const float* bias_ptr = ptr_zero;
+ if (flag_bias) {
+ bias_ptr = bias + c;
+ }
+ fill_packed_biasc4(
+ pre_out, bias_ptr, wout_round * hout_c_block * h_kernel);
+
+ for (int hk = 0; hk < h_kernel; hk += hout_r_kernel) {
+ const float* wc0 = weight_c;
+
+ const float* inr0 = block_inr0;
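+ // inr0 is the first of the four pre-packed input rows consumed per
+ // iteration; inr1..inr3 below follow at in_len stride (a 3x3 stride-1
+ // kernel needs 4 input rows to produce 2 output rows).
+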
const float* inr1 = block_inr1; + const float* inr2 = block_inr2; + const float* inr3 = block_inr3; + + float* pre_out0 = pre_out + hk * out_row_stride; + float* pre_out1 = pre_out0 + out_row_stride; +#ifdef __aarch64__ + for (int i = 0; i < ic; ++i) { + float* ptr_out0 = pre_out0; + float* ptr_out1 = pre_out1; + + float32x4_t w0 = vld1q_f32(wc0); // w0, v23 + float32x4_t w1 = vld1q_f32(wc0 + 4); // w1, v24 + float32x4_t w2 = vld1q_f32(wc0 + 8); // w2, v25 + float32x4_t w3 = vld1q_f32(wc0 + 12); // w3, v26 + float32x4_t w4 = vld1q_f32(wc0 + 16); // w4, v27 + float32x4_t w5 = vld1q_f32(wc0 + 20); // w5, v28 + float32x4_t w6 = vld1q_f32(wc0 + 24); // w6, v29 + float32x4_t w7 = vld1q_f32(wc0 + 28); // w7, v30 + float32x4_t w8 = vld1q_f32(wc0 + 32); // w8, v31 + + const float* r0 = inr0; + const float* r1 = inr1; + const float* r2 = inr2; + const float* r3 = inr3; + + int cnt = w_loop; + asm volatile( + "ldp q15, q16, [%[ptr_out0]] \n" /* load outr00, + outr01*/ + "ldp q17, q18, [%[ptr_out0], #32]\n" /* load outr02, outr03*/ + "ldp q19, q20, [%[ptr_out1]] \n" /* load outr10, outr11*/ + "ldp q21, q22, [%[ptr_out1], #32]\n" /* load outr10, outr11*/ + "ldp q0, q1, [%[r0]], #16 \n" /* load input r0*/ + "ldp q2, q3, [%[r1]], #16 \n" /* load input r1*/ + "2: \n" /* main loop*/ + /* r0, r1, mul w0, get out r0, r1 */ + "fmla v15.4s , %[w0].4s, v0.s[0]\n" /* outr00 = w0 * r0[0]*/ + "fmla v16.4s , %[w0].4s, v0.s[1]\n" /* outr01 = w0 * r0[1]*/ + "fmla v17.4s , %[w0].4s, v0.s[2]\n" /* outr02 = w0 * r0[2]*/ + "fmla v18.4s , %[w0].4s, v0.s[3]\n" /* outr03 = w0 * r0[3]*/ + "fmla v19.4s , %[w0].4s, v2.s[0]\n" /* outr10 = w0 * r1[0]*/ + "fmla v20.4s , %[w0].4s, v2.s[1]\n" /* outr11 = w0 * r1[1]*/ + "fmla v21.4s , %[w0].4s, v2.s[2]\n" /* outr12 = w0 * r1[2]*/ + "fmla v22.4s , %[w0].4s, v2.s[3]\n" /* outr13 = w0 * r1[3]*/ + + /* r0, r1, mul w1, get out r0, r1 */ + "fmla v15.4s , %[w1].4s, v0.s[1]\n" /* outr00 = w1 * r0[1]*/ + "fmla v16.4s , %[w1].4s, v0.s[2]\n" /* outr01 = w1 * r0[2]*/ + "fmla v17.4s , %[w1].4s, v0.s[3]\n" /* outr02 = w1 * r0[3]*/ + "fmla v18.4s , %[w1].4s, v1.s[0]\n" /* outr03 = w1 * r0[4]*/ + "fmla v19.4s , %[w1].4s, v2.s[1]\n" /* outr10 = w1 * r1[1]*/ + "fmla v20.4s , %[w1].4s, v2.s[2]\n" /* outr11 = w1 * r1[2]*/ + "fmla v21.4s , %[w1].4s, v2.s[3]\n" /* outr12 = w1 * r1[3]*/ + "fmla v22.4s , %[w1].4s, v3.s[0]\n" /* outr13 = w1 * r1[4]*/ + + "ldp q4, q5, [%[r2]], #16 \n" /* load input r2*/ + + /* r0, r1, mul w2, get out r0, r1 */ + "fmla v15.4s , %[w2].4s, v0.s[2]\n" /* outr00 = w2 * r0[2]*/ + "fmla v16.4s , %[w2].4s, v0.s[3]\n" /* outr01 = w2 * r0[3]*/ + "fmla v17.4s , %[w2].4s, v1.s[0]\n" /* outr02 = w2 * r0[0]*/ + "fmla v18.4s , %[w2].4s, v1.s[1]\n" /* outr03 = w2 * r0[1]*/ + "fmla v19.4s , %[w2].4s, v2.s[2]\n" /* outr10 = w2 * r1[2]*/ + "fmla v20.4s , %[w2].4s, v2.s[3]\n" /* outr11 = w2 * r1[3]*/ + "fmla v21.4s , %[w2].4s, v3.s[0]\n" /* outr12 = w2 * r1[0]*/ + "fmla v22.4s , %[w2].4s, v3.s[1]\n" /* outr13 = w2 * r1[1]*/ + + /* r1, r2, mul w3, get out r0, r1 */ + "fmla v15.4s , %[w3].4s, v2.s[0]\n" /* outr00 = w3 * r1[0]*/ + "fmla v16.4s , %[w3].4s, v2.s[1]\n" /* outr01 = w3 * r1[1]*/ + "fmla v17.4s , %[w3].4s, v2.s[2]\n" /* outr02 = w3 * r1[2]*/ + "fmla v18.4s , %[w3].4s, v2.s[3]\n" /* outr03 = w3 * r1[3]*/ + "fmla v19.4s , %[w3].4s, v4.s[0]\n" /* outr10 = w3 * r2[0]*/ + "fmla v20.4s , %[w3].4s, v4.s[1]\n" /* outr11 = w3 * r2[1]*/ + "fmla v21.4s , %[w3].4s, v4.s[2]\n" /* outr12 = w3 * r2[2]*/ + "fmla v22.4s , %[w3].4s, v4.s[3]\n" /* outr13 = w3 * r2[3]*/ + + "ldp q0, q1, [%[r0]], #16 \n" 
/* load next input r0*/ + + /* r1, r2, mul w4, get out r0, r1 */ + "fmla v15.4s , %[w4].4s, v2.s[1]\n" /* outr00 = w4 * r1[1]*/ + "fmla v16.4s , %[w4].4s, v2.s[2]\n" /* outr01 = w4 * r1[2]*/ + "fmla v17.4s , %[w4].4s, v2.s[3]\n" /* outr02 = w4 * r1[3]*/ + "fmla v18.4s , %[w4].4s, v3.s[0]\n" /* outr03 = w4 * r1[4]*/ + "fmla v19.4s , %[w4].4s, v4.s[1]\n" /* outr10 = w4 * r2[1]*/ + "fmla v20.4s , %[w4].4s, v4.s[2]\n" /* outr11 = w4 * r2[2]*/ + "fmla v21.4s , %[w4].4s, v4.s[3]\n" /* outr12 = w4 * r2[3]*/ + "fmla v22.4s , %[w4].4s, v5.s[0]\n" /* outr13 = w4 * r2[4]*/ + + "ldp q6, q7, [%[r3]], #16 \n" /* load input r3*/ + + /* r1, r2, mul w5, get out r0, r1 */ + "fmla v15.4s , %[w5].4s, v2.s[2]\n" /* outr00 = w5 * r1[2]*/ + "fmla v16.4s , %[w5].4s, v2.s[3]\n" /* outr01 = w5 * r1[3]*/ + "fmla v17.4s , %[w5].4s, v3.s[0]\n" /* outr02 = w5 * r1[0]*/ + "fmla v18.4s , %[w5].4s, v3.s[1]\n" /* outr03 = w5 * r1[1]*/ + "fmla v19.4s , %[w5].4s, v4.s[2]\n" /* outr10 = w5 * r2[2]*/ + "fmla v20.4s , %[w5].4s, v4.s[3]\n" /* outr11 = w5 * r2[3]*/ + "fmla v21.4s , %[w5].4s, v5.s[0]\n" /* outr12 = w5 * r2[0]*/ + "fmla v22.4s , %[w5].4s, v5.s[1]\n" /* outr13 = w5 * r2[1]*/ + + /* r2, r3, mul w6, get out r0, r1 */ + "fmla v15.4s , %[w6].4s, v4.s[0]\n" /* outr00 = w6 * r2[0]*/ + "fmla v16.4s , %[w6].4s, v4.s[1]\n" /* outr01 = w6 * r2[1]*/ + "fmla v17.4s , %[w6].4s, v4.s[2]\n" /* outr02 = w6 * r2[2]*/ + "fmla v18.4s , %[w6].4s, v4.s[3]\n" /* outr03 = w6 * r2[3]*/ + "fmla v19.4s , %[w6].4s, v6.s[0]\n" /* outr10 = w6 * r3[0]*/ + "fmla v20.4s , %[w6].4s, v6.s[1]\n" /* outr11 = w6 * r3[1]*/ + "fmla v21.4s , %[w6].4s, v6.s[2]\n" /* outr12 = w6 * r3[2]*/ + "fmla v22.4s , %[w6].4s, v6.s[3]\n" /* outr13 = w6 * r3[3]*/ + + "ldp q2, q3, [%[r1]], #16 \n" /* load next input r1*/ + + /* r2, r3, mul w7, get out r0, r1 */ + "fmla v15.4s , %[w7].4s, v4.s[1]\n" /* outr00 = w7 * r2[1]*/ + "fmla v16.4s , %[w7].4s, v4.s[2]\n" /* outr01 = w7 * r2[2]*/ + "fmla v17.4s , %[w7].4s, v4.s[3]\n" /* outr02 = w7 * r2[3]*/ + "fmla v18.4s , %[w7].4s, v5.s[0]\n" /* outr03 = w7 * r2[4]*/ + "fmla v19.4s , %[w7].4s, v6.s[1]\n" /* outr10 = w7 * r3[1]*/ + "fmla v20.4s , %[w7].4s, v6.s[2]\n" /* outr11 = w7 * r3[2]*/ + "fmla v21.4s , %[w7].4s, v6.s[3]\n" /* outr12 = w7 * r3[3]*/ + "fmla v22.4s , %[w7].4s, v7.s[0]\n" /* outr13 = w7 * r3[4]*/ + + "subs %w[cnt], %w[cnt], #1 \n" /*loop count -1*/ + + /* r2, r3, mul w8, get out r0, r1 */ + "fmla v15.4s , %[w8].4s, v4.s[2]\n" /* outr00 = w8 * r2[2]*/ + "fmla v16.4s , %[w8].4s, v4.s[3]\n" /* outr01 = w8 * r2[3]*/ + "fmla v17.4s , %[w8].4s, v5.s[0]\n" /* outr02 = w8 * r2[0]*/ + "fmla v18.4s , %[w8].4s, v5.s[1]\n" /* outr03 = w8 * r2[1]*/ + + "stp q15, q16, [%[ptr_out0]], #32\n" /* save outr00, outr01*/ + "fmla v19.4s , %[w8].4s, v6.s[2]\n" /* outr10 = w8 * r3[2]*/ + "stp q17, q18, [%[ptr_out0]], #32\n" /* save outr02, outr03*/ + "fmla v20.4s , %[w8].4s, v6.s[3]\n" /* outr11 = w8 * r3[3]*/ + "ldp q15, q16, [%[ptr_out0]] \n" /* load outr00, outr01*/ + "fmla v21.4s , %[w8].4s, v7.s[0]\n" /* outr12 = w8 * r3[0]*/ + "ldp q17, q18, [%[ptr_out0], #32]\n" /* load outr02, outr03*/ + "fmla v22.4s , %[w8].4s, v7.s[1]\n" /* outr13 = w8 * r3[1]*/ + "stp q19, q20, [%[ptr_out1]], #32\n" /* save outr10, outr11*/ + "stp q21, q22, [%[ptr_out1]], #32\n" /* save outr12, outr13*/ + "ldp q19, q20, [%[ptr_out1]] \n" /* load outr10, outr11*/ + "ldp q21, q22, [%[ptr_out1], #32]\n" /* load outr12, outr13*/ + "bne 2b \n" /* jump to main loop*/ + + : [cnt] "+r"(cnt), + [r0] "+r"(r0), + [r1] "+r"(r1), + [r2] "+r"(r2), + [r3] "+r"(r3), + 
[ptr_out0] "+r"(ptr_out0), + [ptr_out1] "+r"(ptr_out1) + : [w0] "w"(w0), + [w1] "w"(w1), + [w2] "w"(w2), + [w3] "w"(w3), + [w4] "w"(w4), + [w5] "w"(w5), + [w6] "w"(w6), + [w7] "w"(w7), + [w8] "w"(w8) + : "cc", + "memory", + "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v15", + "v16", + "v17", + "v18", + "v19", + "v20", + "v21", + "v22"); + + wc0 += 9 * hout_c_block; + inr0 += win_round; + inr1 += win_round; + inr2 += win_round; + inr3 += win_round; + } +#else // not __aarch64__ + for (int i = 0; i < ic; ++i) { + const float* wc0 = weight_c + i * w_stride_chin; + + float* ptr_out0 = pre_out0; + float* ptr_out1 = pre_out1; + + const float* r0 = inr0; + const float* r1 = inr1; + const float* r2 = inr2; + const float* r3 = inr3; + + int cnt = w_loop; + asm volatile( + "vld1.32 {d16-d19}, [%[ptr_out0]]! @ " + "load outr0, w0, w1, c0~c3\n" + "vld1.32 {d20-d23}, [%[ptr_out0]] @ load " + "outr0, w2, w3, c0~c3\n" + + /* load weights */ + "vld1.32 {d10-d13}, [%[wc0]]! @ load w0, " + "w1, to q5, q6\n" + "vld1.32 {d14-d15}, [%[wc0]]! @ load w2, " + "to q7\n" + + /* load r0, r1 */ + "vld1.32 {d0-d1}, [%[r0]]! @ load r0, " + "4 float\n" + "vld1.32 {d2}, [%[r0]] @ load r0, " + "2 float\n" + + "sub %[ptr_out0], %[ptr_out0], #32 @ ptr_out0 " + "- 32, to start address\n" + + /* main loop */ + "0: @ main " + "loop\n" + /* mul r0 with w0, w1, w2, get out r0 */ + "vld1.32 {d24-d27}, [%[ptr_out1]]! @ load " + "outr1, w0, w1, c0~c3\n" + "vmla.f32 q8, q5, d0[0] @ w0 * " + "inr00\n" + "vld1.32 {d28-d31}, [%[ptr_out1]] @ load " + "outr1, w2, w3, c0~c3\n" + "vmla.f32 q9, q5, d0[1] @ w0 * " + "inr01\n" + "vmla.f32 q10, q5, d1[0] @ w0 * " + "inr02\n" + "vmla.f32 q11, q5, d1[1] @ w0 * " + "inr03\n" + "vld1.32 {d3-d4}, [%[r1]]! @ load r1, " + "4 float\n" + "vmla.f32 q8, q6, d0[1] @ w1 * " + "inr01\n" + "vmla.f32 q9, q6, d1[0] @ w1 * " + "inr02\n" + "vmla.f32 q10, q6, d1[1] @ w1 * " + "inr03\n" + "vmla.f32 q11, q6, d2[0] @ w1 * " + "inr04\n" + "vld1.32 {d5}, [%[r1]] @ load r0, " + "2 float\n" + "vmla.f32 q8, q7, d1[0] @ w2 * " + "inr02\n" + "vmla.f32 q9, q7, d1[1] @ w2 * " + "inr03\n" + "vmla.f32 q10, q7, d2[0] @ w2 * " + "inr04\n" + "vmla.f32 q11, q7, d2[1] @ w2 * " + "inr05\n" + + "sub %[ptr_out1], %[ptr_out1], #32 @ ptr_out1 " + "- 32, to start address\n" + + /* mul r1 with w0, w1, w2, get out r1 */ + "vmla.f32 q12, q5, d3[0] @ w0 * " + "inr10\n" + "vmla.f32 q13, q5, d3[1] @ w0 * " + "inr11\n" + "vmla.f32 q14, q5, d4[0] @ w0 * " + "inr12\n" + "vmla.f32 q15, q5, d4[1] @ w0 * " + "inr13\n" + "vmla.f32 q12, q6, d3[1] @ w1 * " + "inr11\n" + "vmla.f32 q13, q6, d4[0] @ w1 * " + "inr12\n" + "vmla.f32 q14, q6, d4[1] @ w1 * " + "inr13\n" + "vmla.f32 q15, q6, d5[0] @ w1 * " + "inr14\n" + "vld1.32 {d10-d13}, [%[wc0]]! @ load w3, " + "w4, to q5, q6\n" + "vmla.f32 q12, q7, d4[0] @ w2 * " + "inr12\n" + "vmla.f32 q13, q7, d4[1] @ w2 * " + "inr13\n" + "vmla.f32 q14, q7, d5[0] @ w2 * " + "inr14\n" + "vmla.f32 q15, q7, d5[1] @ w2 * " + "inr15\n" + "vld1.32 {d14-d15}, [%[wc0]]! @ load w5, " + "to q7\n" + + /* mul r1 with w3, w4, w5, get out r0 */ + "vmla.f32 q8, q5, d3[0] @ w3 * " + "inr10\n" + "vmla.f32 q9, q5, d3[1] @ w3 * " + "inr11\n" + "vmla.f32 q10, q5, d4[0] @ w3 * " + "inr12\n" + "vmla.f32 q11, q5, d4[1] @ w3 * " + "inr13\n" + "vld1.32 {d0-d1}, [%[r2]]! 
@ load r2, "
+            "4 float\n"
+            "vmla.f32 q8, q6, d3[1] @ w4 * "
+            "inr11\n"
+            "vmla.f32 q9, q6, d4[0] @ w4 * "
+            "inr12\n"
+            "vmla.f32 q10, q6, d4[1] @ w4 * "
+            "inr13\n"
+            "vmla.f32 q11, q6, d5[0] @ w4 * "
+            "inr14\n"
+            "vld1.32 {d2}, [%[r2]] @ load r2, "
+            "2 float\n"
+            "vmla.f32 q8, q7, d4[0] @ w5 * "
+            "inr12\n"
+            "vmla.f32 q9, q7, d4[1] @ w5 * "
+            "inr13\n"
+            "vmla.f32 q10, q7, d5[0] @ w5 * "
+            "inr14\n"
+            "vmla.f32 q11, q7, d5[1] @ w5 * "
+            "inr15\n"
+
+            /* mul r2 with w3, w4, w5, get out r1 */
+            "vmla.f32 q12, q5, d0[0] @ w3 * "
+            "inr20\n"
+            "vmla.f32 q13, q5, d0[1] @ w3 * "
+            "inr21\n"
+            "vmla.f32 q14, q5, d1[0] @ w3 * "
+            "inr22\n"
+            "vmla.f32 q15, q5, d1[1] @ w3 * "
+            "inr23\n"
+            "vmla.f32 q12, q6, d0[1] @ w4 * "
+            "inr21\n"
+            "vmla.f32 q13, q6, d1[0] @ w4 * "
+            "inr22\n"
+            "vmla.f32 q14, q6, d1[1] @ w4 * "
+            "inr23\n"
+            "vmla.f32 q15, q6, d2[0] @ w4 * "
+            "inr24\n"
+            "vld1.32 {d10-d13}, [%[wc0]]! @ load w6, "
+            "w7, to q5, q6\n"
+            "vmla.f32 q12, q7, d1[0] @ w5 * "
+            "inr22\n"
+            "vmla.f32 q13, q7, d1[1] @ w5 * "
+            "inr23\n"
+            "vmla.f32 q14, q7, d2[0] @ w5 * "
+            "inr24\n"
+            "vmla.f32 q15, q7, d2[1] @ w5 * "
+            "inr25\n"
+            "vld1.32 {d14-d15}, [%[wc0]]! @ load w8, "
+            "to q7\n"
+
+            "sub %[wc0], %[wc0], #144 @ wc0 - "
+            "144 to start address\n"
+
+            /* mul r2 with w6, w7, w8, get out r0 */
+            "vmla.f32 q8, q5, d0[0] @ w6 * "
+            "inr20\n"
+            "vmla.f32 q9, q5, d0[1] @ w6 * "
+            "inr21\n"
+            "vld1.32 {d3-d4}, [%[r3]]! @ load r3, "
+            "4 float\n"
+            "vmla.f32 q10, q5, d1[0] @ w6 * "
+            "inr22\n"
+            "vmla.f32 q11, q5, d1[1] @ w6 * "
+            "inr23\n"
+            "vmla.f32 q8, q6, d0[1] @ w7 * "
+            "inr21\n"
+            "vmla.f32 q9, q6, d1[0] @ w7 * "
+            "inr22\n"
+            "vld1.32 {d5}, [%[r3]] @ load r3, "
+            "2 float\n"
+            "vmla.f32 q10, q6, d1[1] @ w7 * "
+            "inr23\n"
+            "vmla.f32 q11, q6, d2[0] @ w7 * "
+            "inr24\n"
+            "vmla.f32 q8, q7, d1[0] @ w8 * "
+            "inr22\n"
+            "vmla.f32 q9, q7, d1[1] @ w8 * "
+            "inr23\n"
+            "vld1.32 {d0-d1}, [%[r0]]! @ load r0, "
+            "4 float\n"
+            "vmla.f32 q10, q7, d2[0] @ w8 * "
+            "inr24\n"
+            "vmla.f32 q11, q7, d2[1] @ w8 * "
+            "inr25\n"
+            "vld1.32 {d2}, [%[r0]] @ load r0, "
+            "2 float\n"
+
+            /* mul r3 with w6, w7, w8, get out r1 */
+            "vmla.f32 q12, q5, d3[0] @ w6 * "
+            "inr30\n"
+            "vmla.f32 q13, q5, d3[1] @ w6 * "
+            "inr31\n"
+            "vst1.32 {d16-d19}, [%[ptr_out0]]! @ save "
+            "r00, r01, c0~c3\n"
+            "vmla.f32 q14, q5, d4[0] @ w6 * "
+            "inr32\n"
+            "vmla.f32 q15, q5, d4[1] @ w6 * "
+            "inr33\n"
+            "vst1.32 {d20-d23}, [%[ptr_out0]]! @ save "
+            "r02, r03, c0~c3\n"
+            "vmla.f32 q12, q6, d3[1] @ w7 * "
+            "inr31\n"
+            "vmla.f32 q13, q6, d4[0] @ w7 * "
+            "inr32\n"
+            "vld1.32 {d16-d19}, [%[ptr_out0]]! @ load "
+            "outr0, w0, w1, c0~c3\n"
+            "vmla.f32 q14, q6, d4[1] @ w7 * "
+            "inr33\n"
+            "vmla.f32 q15, q6, d5[0] @ w7 * "
+            "inr34\n"
+            "vld1.32 {d10-d13}, [%[wc0]]! @ load w0, "
+            "w1, to q5, q6\n"
+            "vmla.f32 q12, q7, d4[0] @ w8 * "
+            "inr32\n"
+            "vmla.f32 q13, q7, d4[1] @ w8 * "
+            "inr33\n"
+            "vld1.32 {d20-d23}, [%[ptr_out0]] @ load "
+            "outr0, w2, w3, c0~c3\n"
+            "vmla.f32 q14, q7, d5[0] @ w8 * "
+            "inr34\n"
+            "vmla.f32 q15, q7, d5[1] @ w8 * "
+            "inr35\n"
+
+            "vst1.32 {d24-d27}, [%[ptr_out1]]! @ save "
+            "r10, r11, c0~c3\n"
+            "vst1.32 {d28-d31}, [%[ptr_out1]]! @ save "
+            "r12, r13, c0~c3\n"
+            "vld1.32 {d14-d15}, [%[wc0]]!
@ load w2, " + "to q7\n" + + "sub %[ptr_out0], %[ptr_out0], #32 @ ptr_out0 " + "- 32, to start address\n" + + "subs %[cnt], #1 @ loop " + "count--\n" + "bne 0b @ jump to " + "main loop\n" + + : [cnt] "+r"(cnt), + [r0] "+r"(r0), + [r1] "+r"(r1), + [r2] "+r"(r2), + [r3] "+r"(r3), + [ptr_out0] "+r"(ptr_out0), + [ptr_out1] "+r"(ptr_out1), + [wc0] "+r"(wc0) + : + : "cc", + "memory", + "q0", + "q1", + "q2", + "q3", + "q4", + "q5", + "q6", + "q7", + "q8", + "q9", + "q10", + "q11", + "q12", + "q13", + "q14", + "q15"); + + inr0 += win_round; + inr1 += win_round; + inr2 += win_round; + inr3 += win_round; + } +#endif // __aarch64__ + block_inr0 = block_inr2; + block_inr1 = block_inr3; + block_inr2 = block_inr1 + in_len; + block_inr3 = block_inr2 + in_len; + } + write_to_output_c4_fp32(pre_out, + dout_batch, + c, + c + hout_c_block, + h, + h + h_kernel, + 0, + wout_round, + oc, + oh, + ow, + flag_relu, + ptr_write); + } + const float* weight_remain_ptr = weights + c_round_down * w_stride; +#pragma omp parallel for num_threads(threads) + for (int c = 0; c < c_remain; ++c) { +#ifdef ARM_WITH_OMP + float* pre_out = + pre_din + pre_in_size + omp_get_thread_num() * pre_out_size; +#else + float* pre_out = pre_din + pre_in_size; +#endif + + int c_idx = c_round_down + c; + + int h_kernel = hout_r_block; + if (h + hout_r_block > oh) { + h_kernel = oh - h; + } + + const float* block_inr0 = pre_din; + const float* block_inr1 = block_inr0 + in_len; + const float* block_inr2 = block_inr1 + in_len; + const float* block_inr3 = block_inr2 + in_len; + + const float* bias_ptr = ptr_zero; + if (flag_bias) { + bias_ptr = bias + c_idx; + } + fill_bias(pre_out, bias_ptr, 1, wout_round * h_kernel); + + for (int hk = 0; hk < h_kernel; hk += hout_r_kernel) { + const float* wc0 = weight_remain_ptr; + + const float* inr0 = block_inr0; + const float* inr1 = block_inr1; + const float* inr2 = block_inr2; + const float* inr3 = block_inr3; + + float* pre_out0 = pre_out + hk * wout_round; + float* pre_out1 = pre_out0 + wout_round; +#ifdef __aarch64__ + for (int i = 0; i < ic; ++i) { + float* ptr_out0 = pre_out0; + float* ptr_out1 = pre_out1; + + float32x4_t w0 = vdupq_n_f32(wc0[c]); // w0, v23 + float32x4_t w1 = vdupq_n_f32(wc0[4 + c]); // w1, v24 + float32x4_t w2 = vdupq_n_f32(wc0[8 + c]); // w2, v25 + float32x4_t w3 = vdupq_n_f32(wc0[12 + c]); // w3, v26 + float32x4_t w4 = vdupq_n_f32(wc0[16 + c]); // w4, v27 + float32x4_t w5 = vdupq_n_f32(wc0[20 + c]); // w5, v28 + float32x4_t w6 = vdupq_n_f32(wc0[24 + c]); // w6, v29 + float32x4_t w7 = vdupq_n_f32(wc0[28 + c]); // w7, v30 + float32x4_t w8 = vdupq_n_f32(wc0[32 + c]); // w8, v31 + + const float* r0 = inr0; + const float* r1 = inr1; + const float* r2 = inr2; + const float* r3 = inr3; + + int cnt = w_loop; + asm volatile( + "ldr q21, [%[ptr_out0]] \n" /* load outr0, + w0~w3*/ + "ldr q22, [%[ptr_out1]] \n" /* load outr1, w0~w3*/ + "ldp q0, q1, [%[r0]], #16 \n" /* load input r0*/ + "ldp q2, q3, [%[r1]], #16 \n" /* load input r1*/ + "ldp q4, q5, [%[r2]], #16 \n" /* load input r2*/ + "ldp q6, q7, [%[r3]], #16 \n" /* load input r3*/ + "2: \n" /* main loop*/ + + "fmla v21.4s , %[w0].4s, v0.4s \n" /* outr0 = w0 * r0*/ + "fmla v22.4s , %[w0].4s, v2.4s \n" /* outr1 = w0 * r1*/ + + "ext v8.16b, v0.16b, v1.16b, #4 \n" /* shift r0 left 1*/ + "ext v10.16b, v2.16b, v3.16b, #4 \n" /* shift r1 left 1*/ + "ext v9.16b, v0.16b, v1.16b, #8 \n" /* shift r0 left 2*/ + "ext v11.16b, v2.16b, v3.16b, #8 \n" /* shift r1 left 2*/ + + "ldp q0, q1, [%[r0]], #16 \n" /* load input r0*/ + + "fmla v21.4s , %[w1].4s, 
v8.4s \n" /* outr0 = w1 * r0*/
+            "fmla v22.4s , %[w1].4s, v10.4s \n" /* outr1 = w1 * r1*/
+
+            "fmla v21.4s , %[w2].4s, v9.4s \n" /* outr0 = w2 * r0*/
+            "fmla v22.4s , %[w2].4s, v11.4s \n" /* outr1 = w2 * r1*/
+
+            "fmla v21.4s , %[w3].4s, v2.4s \n" /* outr0 = w3 * r1*/
+            "fmla v22.4s , %[w3].4s, v4.4s \n" /* outr1 = w3 * r2*/
+
+            "ext v12.16b, v4.16b, v5.16b, #4\n" /* shift r2 left 1*/
+            "ext v14.16b, v6.16b, v7.16b, #4\n" /* shift r3 left 1*/
+            "ext v13.16b, v4.16b, v5.16b, #8\n" /* shift r2 left 2*/
+            "ext v15.16b, v6.16b, v7.16b, #8\n" /* shift r3 left 2*/
+
+            "fmla v21.4s , %[w4].4s, v10.4s \n" /* outr0 = w4 * r1*/
+            "fmla v22.4s , %[w4].4s, v12.4s \n" /* outr1 = w4 * r2*/
+
+            "fmla v21.4s , %[w5].4s, v11.4s \n" /* outr0 = w5 * r1*/
+            "fmla v22.4s , %[w5].4s, v13.4s \n" /* outr1 = w5 * r2*/
+
+            "ldp q2, q3, [%[r1]], #16 \n" /* load input r1*/
+
+            "fmla v21.4s , %[w6].4s, v4.4s \n" /* outr0 = w6 * r2*/
+            "fmla v22.4s , %[w6].4s, v6.4s \n" /* outr1 = w6 * r3*/
+
+            "ldp q4, q5, [%[r2]], #16 \n" /* load input r2*/
+
+            "fmla v21.4s , %[w7].4s, v12.4s \n" /* outr0 = w7 * r2*/
+            "fmla v22.4s , %[w7].4s, v14.4s \n" /* outr1 = w7 * r3*/
+
+            "ldp q6, q7, [%[r3]], #16 \n" /* load input r3*/
+
+            "fmla v21.4s , %[w8].4s, v13.4s \n" /* outr0 = w8 * r2*/
+            "fmla v22.4s , %[w8].4s, v15.4s \n" /* outr1 = w8 * r3*/
+
+            "str q21, [%[ptr_out0]], #16 \n" /*write output r0*/
+            "str q22, [%[ptr_out1]], #16 \n" /*write output r1*/
+
+            "subs %w[cnt], %w[cnt], #1 \n" /*loop count -1*/
+
+            "ldr q21, [%[ptr_out0]] \n" /* load outr0, w0~w3*/
+            "ldr q22, [%[ptr_out1]] \n" /* load outr1, w0~w3*/
+
+            "bne 2b \n" /* jump to main loop*/
+
+            : [cnt] "+r"(cnt),
+              [r0] "+r"(r0),
+              [r1] "+r"(r1),
+              [r2] "+r"(r2),
+              [r3] "+r"(r3),
+              [ptr_out0] "+r"(ptr_out0),
+              [ptr_out1] "+r"(ptr_out1)
+            : [w0] "w"(w0),
+              [w1] "w"(w1),
+              [w2] "w"(w2),
+              [w3] "w"(w3),
+              [w4] "w"(w4),
+              [w5] "w"(w5),
+              [w6] "w"(w6),
+              [w7] "w"(w7),
+              [w8] "w"(w8)
+            : "cc",
+              "memory",
+              "v0",
+              "v1",
+              "v2",
+              "v3",
+              "v4",
+              "v5",
+              "v6",
+              "v7",
+              "v8",
+              "v9",
+              "v10",
+              "v11",
+              "v12",
+              "v13",
+              "v14",
+              "v15",
+              "v21",
+              "v22");
+
+        wc0 += 9 * hout_c_block;
+        inr0 += win_round;
+        inr1 += win_round;
+        inr2 += win_round;
+        inr3 += win_round;
+      }
+#else  // not __aarch64__
+      for (int i = 0; i < ic; ++i) {
+        float* ptr_out0 = pre_out0;
+        float* ptr_out1 = pre_out1;
+
+        //! get valid weights of current output channel
+        float w_tmp[10] = {wc0[c],
+                           wc0[c + 4],
+                           wc0[c + 8],
+                           wc0[c + 12],
+                           wc0[c + 16],
+                           wc0[c + 20],
+                           wc0[c + 24],
+                           wc0[c + 28],
+                           wc0[c + 32],
+                           0.f};
+        float32x4_t w0 = vld1q_f32(w_tmp);      // w0, w1, w2, q0
+        float32x4_t w1 = vld1q_f32(w_tmp + 3);  // w3, w4, w5, q1
+        float32x4_t w2 = vld1q_f32(w_tmp + 6);  // w6, w7, w8, q2
+
+        const float* r0 = inr0;
+        const float* r1 = inr1;
+        const float* r2 = inr2;
+        const float* r3 = inr3;
+        int cnt = w_loop / 2;
+        if (cnt > 0) {
+          asm volatile(
+              "vld1.32 {d24-d27}, [%[ptr_out0]] @ "
+              "load or00, or01\n"
+              "vld1.32 {d6-d9}, [%[r0]]!
@ load r0, 8 " + "float\n" + "vld1.32 {d10}, [%[r0]] @ load r0, 2 " + "float\n" + /* main loop */ + "0: @ main loop\n" + /* r0 * w0, w1, w2, get out r0*/ + "vld1.32 {d28-d31}, [%[ptr_out1]] @ load or10, " + "or11\n" + "vext.32 q8, q3, q4, #1 @ r0, shift " + "left 1, get 1, 2, 3, 4\n" + "vext.32 q9, q4, q5, #1 @ r0, shift " + "left 1, get 5, 6, 7, 8\n" + "vmla.f32 q12, q3, %e[w0][0] @ w00 * r0, " + "0, 1, 2, 3\n" + "vmla.f32 q13, q4, %e[w0][0] @ w00 * r0, " + "4, 5, 6, 7\n" + "vext.32 q10, q3, q4, #2 @ r0, shift " + "left 2, get 2, 3, 4, 5\n" + "vext.32 q11, q4, q5, #2 @ r0, shift " + "left 2, get 6, 7, 8, 9\n" + "vmla.f32 q12, q8, %e[w0][1] @ w01 * r0, " + "1, 2, 3, 4\n" + "vmla.f32 q13, q9, %e[w0][1] @ w01 * r0, " + "5, 6, 7, 8\n" + "vld1.32 {d6-d9}, [%[r1]]! @ load r1, 8 " + "float\n" + "vmla.f32 q12, q10, %f[w0][0] @ w02 * r0, " + "2, 3, 4, 5\n" + "vmla.f32 q13, q11, %f[w0][0] @ w02 * r0, " + "6, 7, 8, 9\n" + "vld1.32 {d10}, [%[r1]] @ load r1, 2 " + "float\n" + + /* r1 * w3, w4, w5, get out r0*/ + /* r1 * w0, w1, w2, get out r1*/ + "vmla.f32 q12, q3, %e[w1][0] @ w10 * r1, " + "0, 1, 2, 3\n" + "vmla.f32 q13, q4, %e[w1][0] @ w10 * r1, " + "4, 5, 6, 7\n" + "vext.32 q8, q3, q4, #1 @ r1, shift " + "left 1, get 1, 2, 3, 4\n" + "vext.32 q9, q4, q5, #1 @ r1, shift " + "left 1, get 5, 6, 7, 8\n" + "vmla.f32 q14, q3, %e[w0][0] @ w00 * r1, " + "0, 1, 2, 3\n" + "vmla.f32 q15, q4, %e[w0][0] @ w00 * r1, " + "4, 5, 6, 7\n" + "vext.32 q10, q3, q4, #2 @ r1, shift " + "left 2, get 2, 3, 4, 5\n" + "vext.32 q11, q4, q5, #2 @ r1, shift " + "left 2, get 6, 7, 8, 9\n" + "vmla.f32 q12, q8, %e[w1][1] @ w11 * r1, " + "1, 2, 3, 4\n" + "vmla.f32 q13, q9, %e[w1][1] @ w11 * r1, " + "5, 6, 7, 8\n" + "vmla.f32 q14, q8, %e[w0][1] @ w01 * r1, " + "1, 2, 3, 4\n" + "vmla.f32 q15, q9, %e[w0][1] @ w01 * r1, " + "5, 6, 7, 8\n" + "vld1.32 {d6-d9}, [%[r2]]! @ load r2, 8 " + "float\n" + "vmla.f32 q12, q10, %f[w1][0] @ w12 * r1, " + "2, 3, 4, 5\n" + "vmla.f32 q13, q11, %f[w1][0] @ w12 * r1, " + "6, 7, 8, 9\n" + "vmla.f32 q14, q10, %f[w0][0] @ w02 * r1, " + "2, 3, 4, 5\n" + "vmla.f32 q15, q11, %f[w0][0] @ w02 * r1, " + "6, 7, 8, 9\n" + "vld1.32 {d10}, [%[r2]] @ load r2, 2 " + "float\n" + + /* r2 * w6, w7, w8, get out r0*/ + /* r2 * w3, w4, w5, get out r1*/ + "vmla.f32 q12, q3, %e[w2][0] @ w20 * r2, " + "0, 1, 2, 3\n" + "vmla.f32 q13, q4, %e[w2][0] @ w20 * r2, " + "4, 5, 6, 7\n" + "vext.32 q8, q3, q4, #1 @ r2, shift " + "left 1, get 1, 2, 3, 4\n" + "vext.32 q9, q4, q5, #1 @ r2, shift " + "left 1, get 5, 6, 7, 8\n" + "vmla.f32 q14, q3, %e[w1][0] @ w10 * r2, " + "0, 1, 2, 3\n" + "vmla.f32 q15, q4, %e[w1][0] @ w10 * r2, " + "4, 5, 6, 7\n" + "vext.32 q10, q3, q4, #2 @ r2, shift " + "left 2, get 2, 3, 4, 5\n" + "vext.32 q11, q4, q5, #2 @ r2, shift " + "left 2, get 6, 7, 8, 9\n" + "vmla.f32 q12, q8, %e[w2][1] @ w21 * r2, " + "1, 2, 3, 4\n" + "vmla.f32 q13, q9, %e[w2][1] @ w21 * r2, " + "5, 6, 7, 8\n" + "vmla.f32 q14, q8, %e[w1][1] @ w11 * r2, " + "1, 2, 3, 4\n" + "vmla.f32 q15, q9, %e[w1][1] @ w11 * r2, " + "5, 6, 7, 8\n" + "vld1.32 {d6-d9}, [%[r3]]! 
@ load r3, 8 "
+              "float\n"
+              "vmla.f32 q12, q10, %f[w2][0] @ w22 * r2, "
+              "2, 3, 4, 5\n"
+              "vmla.f32 q13, q11, %f[w2][0] @ w22 * r2, "
+              "6, 7, 8, 9\n"
+              "vmla.f32 q14, q10, %f[w1][0] @ w12 * r2, "
+              "2, 3, 4, 5\n"
+              "vmla.f32 q15, q11, %f[w1][0] @ w12 * r2, "
+              "6, 7, 8, 9\n"
+              "vld1.32 {d10}, [%[r3]] @ load r3, 2 "
+              "float\n"
+
+              /* r3 * w6, w7, w8, get out r1*/
+              "vext.32 q8, q3, q4, #1 @ r3, shift "
+              "left 1, get 1, 2, 3, 4\n"
+              "vext.32 q9, q4, q5, #1 @ r3, shift "
+              "left 1, get 5, 6, 7, 8\n"
+              "vmla.f32 q14, q3, %e[w2][0] @ w20 * r3, "
+              "0, 1, 2, 3\n"
+              "vmla.f32 q15, q4, %e[w2][0] @ w20 * r3, "
+              "4, 5, 6, 7\n"
+              "vst1.32 {d24-d27}, [%[ptr_out0]]! @ save or00, "
+              "or01\n"
+              "vext.32 q10, q3, q4, #2 @ r3, shift "
+              "left 2, get 2, 3, 4, 5\n"
+              "vext.32 q11, q4, q5, #2 @ r3, shift "
+              "left 2, get 6, 7, 8, 9\n"
+              "vmla.f32 q14, q8, %e[w2][1] @ w21 * r3, "
+              "1, 2, 3, 4\n"
+              "vmla.f32 q15, q9, %e[w2][1] @ w21 * r3, "
+              "5, 6, 7, 8\n"
+              "vld1.32 {d24-d27}, [%[ptr_out0]] @ load or00, "
+              "or01\n"
+              "vld1.32 {d6-d9}, [%[r0]]! @ load r0, 8 "
+              "float\n"
+              "vmla.f32 q14, q10, %f[w2][0] @ w22 * r3, "
+              "2, 3, 4, 5\n"
+              "vmla.f32 q15, q11, %f[w2][0] @ w22 * r3, "
+              "6, 7, 8, 9\n"
+              "vld1.32 {d10}, [%[r0]] @ load r0, 2 "
+              "float\n"
+              "vst1.32 {d28-d31}, [%[ptr_out1]]! @ save or10, "
+              "or11\n"
+
+              "subs %[cnt], #1 @ loop count "
+              "-1\n"
+              "bne 0b @ jump to "
+              "main loop\n"
+
+              : [cnt] "+r"(cnt),
+                [r0] "+r"(r0),
+                [r1] "+r"(r1),
+                [r2] "+r"(r2),
+                [r3] "+r"(r3),
+                [ptr_out0] "+r"(ptr_out0),
+                [ptr_out1] "+r"(ptr_out1)
+              : [w0] "w"(w0), [w1] "w"(w1), [w2] "w"(w2)
+              : "cc",
+                "memory",
+                "q3",
+                "q4",
+                "q5",
+                "q6",
+                "q7",
+                "q8",
+                "q9",
+                "q10",
+                "q11",
+                "q12",
+                "q13",
+                "q14",
+                "q15");
+          r0 -= 8;
+        }
+        //! deal with remain ow
+        if (w_loop & 1) {
+          ptr_out0[0] +=
+              r0[0] * w_tmp[0] + r0[1] * w_tmp[1] + r0[2] * w_tmp[2] +
+              r1[0] * w_tmp[3] + r1[1] * w_tmp[4] + r1[2] * w_tmp[5] +
+              r2[0] * w_tmp[6] + r2[1] * w_tmp[7] + r2[2] * w_tmp[8];
+
+          ptr_out0[1] +=
+              r0[1] * w_tmp[0] + r0[2] * w_tmp[1] + r0[3] * w_tmp[2] +
+              r1[1] * w_tmp[3] + r1[2] * w_tmp[4] + r1[3] * w_tmp[5] +
+              r2[1] * w_tmp[6] + r2[2] * w_tmp[7] + r2[3] * w_tmp[8];
+
+          ptr_out0[2] +=
+              r0[2] * w_tmp[0] + r0[3] * w_tmp[1] + r0[4] * w_tmp[2] +
+              r1[2] * w_tmp[3] + r1[3] * w_tmp[4] + r1[4] * w_tmp[5] +
+              r2[2] * w_tmp[6] + r2[3] * w_tmp[7] + r2[4] * w_tmp[8];
+
+          ptr_out0[3] +=
+              r0[3] * w_tmp[0] + r0[4] * w_tmp[1] + r0[5] * w_tmp[2] +
+              r1[3] * w_tmp[3] + r1[4] * w_tmp[4] + r1[5] * w_tmp[5] +
+              r2[3] * w_tmp[6] + r2[4] * w_tmp[7] + r2[5] * w_tmp[8];
+
+          ptr_out1[0] +=
+              r1[0] * w_tmp[0] + r1[1] * w_tmp[1] + r1[2] * w_tmp[2] +
+              r2[0] * w_tmp[3] + r2[1] * w_tmp[4] + r2[2] * w_tmp[5] +
+              r3[0] * w_tmp[6] + r3[1] * w_tmp[7] + r3[2] * w_tmp[8];
+
+          ptr_out1[1] +=
+              r1[1] * w_tmp[0] + r1[2] * w_tmp[1] + r1[3] * w_tmp[2] +
+              r2[1] * w_tmp[3] + r2[2] * w_tmp[4] + r2[3] * w_tmp[5] +
+              r3[1] * w_tmp[6] + r3[2] * w_tmp[7] + r3[3] * w_tmp[8];
+
+          ptr_out1[2] +=
+              r1[2] * w_tmp[0] + r1[3] * w_tmp[1] + r1[4] * w_tmp[2] +
+              r2[2] * w_tmp[3] + r2[3] * w_tmp[4] + r2[4] * w_tmp[5] +
+              r3[2] * w_tmp[6] + r3[3] * w_tmp[7] + r3[4] * w_tmp[8];
+
+          ptr_out1[3] +=
+              r1[3] * w_tmp[0] + r1[4] * w_tmp[1] + r1[5] * w_tmp[2] +
+              r2[3] * w_tmp[3] + r2[4] * w_tmp[4] + r2[5] * w_tmp[5] +
+              r3[3] * w_tmp[6] + r3[4] * w_tmp[7] + r3[5] * w_tmp[8];
+        }
+
+        wc0 += 36;
+        inr0 += win_round;
+        inr1 += win_round;
+        inr2 += win_round;
+        inr3 += win_round;
+      }
+#endif  // __aarch64__
+      block_inr0 = block_inr2;
+      block_inr1 = block_inr3;
+      block_inr2 = block_inr1
+ in_len;
+      block_inr3 = block_inr2 + in_len;
+    }
+    write_to_output_c1_fp32(pre_out,
+                            dout_batch,
+                            c_idx,
+                            c_idx + 1,
+                            h,
+                            h + h_kernel,
+                            0,
+                            wout_round,
+                            oc,
+                            oh,
+                            ow,
+                            flag_relu,
+                            ptr_write);
+  }
+ }
+ }
+}
+
+}  // namespace math
+}  // namespace arm
+}  // namespace lite
+}  // namespace paddle
diff --git a/lite/arm/math/conv_direct_3x3s2.cc b/lite/arm/math/conv_direct_3x3s2.cc
new file mode 100644
index 00000000000..b048f61877d
--- /dev/null
+++ b/lite/arm/math/conv_direct_3x3s2.cc
@@ -0,0 +1,1209 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lite/arm/math/conv_block_utils.h"
+#include "lite/arm/math/conv_impl.h"
+#include "lite/core/context.h"
+#ifdef ARM_WITH_OMP
+#include <omp.h>
+#endif
+
+namespace paddle {
+namespace lite {
+namespace arm {
+namespace math {
+
+void conv_3x3s2_direct_fp32(const float* i_data,
+                            float* o_data,
+                            int bs,
+                            int oc,
+                            int oh,
+                            int ow,
+                            int ic,
+                            int ih,
+                            int win,
+                            const float* weights,
+                            const float* bias,
+                            const operators::ConvParam& param,
+                            ARMContext* ctx) {
+  //! 3x3s2 convolution, implemented by direct algorithm
+  //! prepack input to tmp buffer
+  //! write output to tmp buffer
+  const int threads = ctx->threads();
+  int l2_size = ctx->llc_size() / sizeof(float);
+  const int pad_w = param.paddings[1];
+  const int pad_h = param.paddings[0];
+  const int hout_c_block = 4;
+  const int hout_r_kernel = 2;
+  const int wout_block = 4;
+  const int wout_round = ((ow + wout_block - 1) / wout_block) * wout_block;
+  const int win_round = wout_round * 2 /*stride_w*/ + 1;
+  bool flag_relu = param.fuse_relu;
+  bool flag_bias = param.bias != nullptr;
+  // if (param.activation_param.has_active) {
+  //   if (param.activation_param.active == Active_relu &&
+  //       fabs(param.activation_param.negative_slope) < 1e-6f) {
+  //     flag_relu = true;
+  //   }
+  // }
+  //! get h block
+  //! win_round * ic * hin_r_block + wout_round * hout_c_block * hout_r_block
+  //! * threads = l2_size
+  //! win_round = 2 * wout_round + 1
+  //! hin_r_block = 2 * hout_r_block + 1
+  int hout_r_block =
+      (l2_size - 2 * wout_round * ic - ic) /
+      ((4 * wout_round + 2) * ic + wout_round * hout_c_block * threads);
+  hout_r_block = hout_r_block > oh ? oh : hout_r_block;
+  hout_r_block = (hout_r_block / hout_r_kernel) * hout_r_kernel;
+  hout_r_block = hout_r_block < hout_r_kernel ? hout_r_kernel : hout_r_block;
+
+  const int hin_r_block = hout_r_block * 2 /*stride_h*/ + 1;
+
+  float* tmp_work_space = ctx->workspace_data<float>();
+  float ptr_zero[win_round];  // NOLINT
+  memset(ptr_zero, 0, sizeof(float) * win_round);
+  float ptr_write[wout_round];  // NOLINT
+
+  int in_len = win_round * ic;
+  int pre_in_size = hin_r_block * in_len;
+  int pre_out_size = hout_c_block * hout_r_block * wout_round;
+
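For reference, the `hout_r_block` expression above is the closed-form solution of the working-set budget spelled out in the `//!` comments: the prepacked input block plus the per-thread output tiles must fit in L2. The standalone C++ sketch below restates that arithmetic under the same assumptions (`hout_c_block` fixed at 4); the helper name is hypothetical and the snippet is an illustration, not part of the patch.

#include <algorithm>

// Solve win_round * ic * hin_r_block + wout_round * 4 * h * threads <= l2_size
// with win_round = 2 * wout_round + 1 and hin_r_block = 2 * h + 1, then clamp
// h to [hout_r_kernel, oh] in whole kernel steps -- exactly what the function
// above does inline.
int solve_hout_r_block(int l2_size, int ic, int wout_round, int threads,
                       int oh, int hout_r_kernel) {
  int h = (l2_size - 2 * wout_round * ic - ic) /
          ((4 * wout_round + 2) * ic + wout_round * 4 * threads);
  h = std::min(h, oh);                      // no taller than the output
  h = (h / hout_r_kernel) * hout_r_kernel;  // whole kernel steps only
  return std::max(h, hout_r_kernel);        // at least one step
}

+  //!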
l2_cache start + float* pre_din = tmp_work_space; + + int size_in_channel = win * ih; + int size_out_channel = ow * oh; + int w_stride = ic * 9; /*kernel_w * kernel_h*/ + int w_stride_chin = hout_c_block * 9; // kernel_w * kernel_h * + + int ws = -pad_w; + int we = ws + win_round; + int w_loop = wout_round / 4; + + int c_remain = oc - (oc / hout_c_block) * hout_c_block; + int c_round_down = (oc / hout_c_block) * hout_c_block; + + int out_row_stride = hout_c_block * wout_round; + + for (int n = 0; n < bs; ++n) { + const float* din_batch = i_data + n * ic * size_in_channel; + float* dout_batch = o_data + n * oc * size_out_channel; + for (int h = 0; h < oh; h += hout_r_block) { + int h_kernel = hout_r_block; + if (h + hout_r_block > oh) { + h_kernel = oh - h; + } + + int hs = h * 2 /*stride_h*/ - pad_h; + int he = hs + h_kernel * 2 /*stride_h*/ + 1; + + prepack_input_nxw( + din_batch, pre_din, 0, ic, hs, he, ws, we, ic, win, ih, ptr_zero); + + const float* cblock_inr0 = pre_din; + const float* cblock_inr1 = cblock_inr0 + in_len; + const float* cblock_inr2 = cblock_inr1 + in_len; + const float* cblock_inr3 = cblock_inr2 + in_len; + const float* cblock_inr4 = cblock_inr3 + in_len; + +#pragma omp parallel for num_threads(threads) + for (int c = 0; c < c_round_down; c += hout_c_block) { +#ifdef ARM_WITH_OMP + float* pre_out = + pre_din + pre_in_size + omp_get_thread_num() * pre_out_size; +#else + float* pre_out = pre_din + pre_in_size; +#endif + const float* block_inr0 = cblock_inr0; + const float* block_inr1 = cblock_inr1; + const float* block_inr2 = cblock_inr2; + const float* block_inr3 = cblock_inr3; + const float* block_inr4 = cblock_inr4; + + const float* weight_c = weights + c * w_stride; + const float* bias_ptr = ptr_zero; + if (flag_bias) { + bias_ptr = bias + c; + } + fill_packed_biasc4( + pre_out, bias_ptr, wout_round * hout_c_block * h_kernel); + + for (int hk = 0; hk < h_kernel; hk += hout_r_kernel) { + const float* wc0 = weight_c; + + const float* inr0 = block_inr0; + const float* inr1 = block_inr1; + const float* inr2 = block_inr2; + const float* inr3 = block_inr3; + const float* inr4 = block_inr4; + + float* pre_out0 = pre_out + hk * out_row_stride; + float* pre_out1 = pre_out0 + out_row_stride; +#ifdef __aarch64__ + for (int i = 0; i < ic; ++i) { + float* ptr_out0 = pre_out0; + float* ptr_out1 = pre_out1; + + float32x4_t w0 = vld1q_f32(wc0); // w0, v23 + float32x4_t w1 = vld1q_f32(wc0 + 4); // w1, v24 + float32x4_t w2 = vld1q_f32(wc0 + 8); // w2, v25 + float32x4_t w3 = vld1q_f32(wc0 + 12); // w3, v26 + float32x4_t w4 = vld1q_f32(wc0 + 16); // w4, v27 + float32x4_t w5 = vld1q_f32(wc0 + 20); // w5, v28 + float32x4_t w6 = vld1q_f32(wc0 + 24); // w6, v29 + float32x4_t w7 = vld1q_f32(wc0 + 28); // w7, v30 + float32x4_t w8 = vld1q_f32(wc0 + 32); // w8, v31 + + const float* r0 = inr0; + const float* r1 = inr1; + const float* r2 = inr2; + const float* r3 = inr3; + const float* r4 = inr4; + + int cnt = w_loop; + asm volatile( + "ldp q15, q16, [%[ptr_out0]] \n" /* load outr00, + outr01*/ + "ldp q17, q18, [%[ptr_out0], #32]\n" /* load outr02, outr03*/ + + "ldp q0, q1, [%[r0]], #32 \n" /* load input r0*/ + "ldr d10, [%[r0]] \n" /* load input r0, 9th + element*/ + "ldp q4, q5, [%[r2]], #32 \n" /* load input r2*/ + "ldr d12, [%[r2]] \n" /* load input r2, 9th + element*/ + "2: \n" /* main loop*/ + /* r0, r2, mul w0, get out r0, r1 */ + "ldp q19, q20, [%[ptr_out1]] \n" /* load outr10, outr11*/ + "ldp q21, q22, [%[ptr_out1], #32]\n" /* load outr12, outr13*/ + "fmla v15.4s , %[w0].4s, v0.s[0]\n" 
/* outr00 = w0 * r0[0]*/ + "fmla v16.4s , %[w0].4s, v0.s[2]\n" /* outr01 = w0 * r0[2]*/ + "fmla v17.4s , %[w0].4s, v1.s[0]\n" /* outr02 = w0 * r0[4]*/ + "fmla v18.4s , %[w0].4s, v1.s[2]\n" /* outr03 = w0 * r0[6]*/ + "fmla v19.4s , %[w0].4s, v4.s[0]\n" /* outr10 = w0 * r2[0]*/ + "fmla v20.4s , %[w0].4s, v4.s[2]\n" /* outr11 = w0 * r2[2]*/ + "fmla v21.4s , %[w0].4s, v5.s[0]\n" /* outr12 = w0 * r2[4]*/ + "fmla v22.4s , %[w0].4s, v5.s[2]\n" /* outr13 = w0 * r2[6]*/ + + "ldp q2, q3, [%[r1]], #32 \n" /* load input r1*/ + + /* r2 mul w6, get out r0*/ + "fmla v15.4s , %[w6].4s, v4.s[0]\n" /* outr00 = w6 * r2[0]*/ + "fmla v16.4s , %[w6].4s, v4.s[2]\n" /* outr01 = w6 * r2[2]*/ + "fmla v17.4s , %[w6].4s, v5.s[0]\n" /* outr02 = w6 * r2[4]*/ + "fmla v18.4s , %[w6].4s, v5.s[2]\n" /* outr03 = w6 * r2[6]*/ + + "ldr d11, [%[r1]] \n" /* load input r1, 9th + element*/ + + /* r0, r2, mul w1, get out r0, r1 */ + "fmla v15.4s , %[w1].4s, v0.s[1]\n" /* outr00 = w1 * r0[1]*/ + "fmla v16.4s , %[w1].4s, v0.s[3]\n" /* outr01 = w1 * r0[3]*/ + "fmla v17.4s , %[w1].4s, v1.s[1]\n" /* outr02 = w1 * r0[5]*/ + "fmla v18.4s , %[w1].4s, v1.s[3]\n" /* outr03 = w1 * r0[7]*/ + "fmla v19.4s , %[w1].4s, v4.s[1]\n" /* outr10 = w1 * r2[1]*/ + "fmla v20.4s , %[w1].4s, v4.s[3]\n" /* outr11 = w1 * r2[3]*/ + "fmla v21.4s , %[w1].4s, v5.s[1]\n" /* outr12 = w1 * r2[5]*/ + "fmla v22.4s , %[w1].4s, v5.s[3]\n" /* outr13 = w1 * r2[7]*/ + + "ldp q6, q7, [%[r3]], #32 \n" /* load input r3*/ + + /* r2 mul w7, get out r0 */ + "fmla v15.4s , %[w7].4s, v4.s[1]\n" /* outr00 = w7 * r2[1]*/ + "fmla v16.4s , %[w7].4s, v4.s[3]\n" /* outr01 = w7 * r2[3]*/ + "fmla v17.4s , %[w7].4s, v5.s[1]\n" /* outr02 = w7 * r2[5]*/ + "fmla v18.4s , %[w7].4s, v5.s[3]\n" /* outr03 = w7 * r2[7]*/ + + "ldr d13, [%[r3]] \n" /* load input r3, 9th + element*/ + + /* r0, r2, mul w2, get out r0, r1 */ + "fmla v15.4s , %[w2].4s, v0.s[2]\n" /* outr00 = w2 * r0[2]*/ + "fmla v16.4s , %[w2].4s, v1.s[0]\n" /* outr01 = w2 * r0[4]*/ + "fmla v17.4s , %[w2].4s, v1.s[2]\n" /* outr02 = w2 * r0[6]*/ + "fmla v18.4s , %[w2].4s, v10.s[0]\n" /* outr03 = w2 * + r0[8]*/ + "fmla v19.4s , %[w2].4s, v4.s[2]\n" /* outr10 = w2 * r2[2]*/ + "fmla v20.4s , %[w2].4s, v5.s[0]\n" /* outr11 = w2 * r2[4]*/ + "fmla v21.4s , %[w2].4s, v5.s[2]\n" /* outr12 = w2 * r2[6]*/ + "fmla v22.4s , %[w2].4s, v12.s[0]\n" /* outr13 = w2 * + r2[8]*/ + + "ldp q8, q9, [%[r4]], #32 \n" /* load input r4*/ + + /* r2, mul w8, get out r0 */ + "fmla v15.4s , %[w8].4s, v4.s[2]\n" /* outr00 = w8 * r2[2]*/ + "fmla v16.4s , %[w8].4s, v5.s[0]\n" /* outr01 = w8 * r2[4]*/ + "fmla v17.4s , %[w8].4s, v5.s[2]\n" /* outr02 = w8 * r2[6]*/ + "fmla v18.4s , %[w8].4s, v12.s[0]\n" /* outr03 = w8 * + r2[8]*/ + + "ldr d14, [%[r4]] \n" /* load input r4, 9th + element*/ + + /* r1, r3, mul w3, get out r0, r1 */ + "fmla v15.4s , %[w3].4s, v2.s[0]\n" /* outr00 = w3 * r1[0]*/ + "fmla v16.4s , %[w3].4s, v2.s[2]\n" /* outr01 = w3 * r1[2]*/ + "fmla v17.4s , %[w3].4s, v3.s[0]\n" /* outr02 = w3 * r1[4]*/ + "fmla v18.4s , %[w3].4s, v3.s[2]\n" /* outr03 = w3 * r1[6]*/ + "fmla v19.4s , %[w3].4s, v6.s[0]\n" /* outr10 = w3 * r3[0]*/ + "fmla v20.4s , %[w3].4s, v6.s[2]\n" /* outr11 = w3 * r3[2]*/ + "fmla v21.4s , %[w3].4s, v7.s[0]\n" /* outr12 = w3 * r3[4]*/ + "fmla v22.4s , %[w3].4s, v7.s[2]\n" /* outr13 = w3 * r3[6]*/ + + "ldp q0, q1, [%[r0]], #32 \n" /* load input r0*/ + + /* r1, r3, mul w4, get out r0, r1 */ + "fmla v15.4s , %[w4].4s, v2.s[1]\n" /* outr00 = w4 * r1[1]*/ + "fmla v16.4s , %[w4].4s, v2.s[3]\n" /* outr01 = w4 * r1[3]*/ + "fmla v17.4s , %[w4].4s, 
v3.s[1]\n" /* outr02 = w4 * r1[5]*/ + "fmla v18.4s , %[w4].4s, v3.s[3]\n" /* outr03 = w4 * r1[7]*/ + "fmla v19.4s , %[w4].4s, v6.s[1]\n" /* outr10 = w4 * r3[1]*/ + "fmla v20.4s , %[w4].4s, v6.s[3]\n" /* outr11 = w4 * r3[3]*/ + "fmla v21.4s , %[w4].4s, v7.s[1]\n" /* outr12 = w4 * r3[5]*/ + "fmla v22.4s , %[w4].4s, v7.s[3]\n" /* outr13 = w4 * r3[7]*/ + + "ldr d10, [%[r0]] \n" /* load input r0, 9th + element*/ + + /* r1, r3, mul w5, get out r0, r1 */ + "fmla v15.4s , %[w5].4s, v2.s[2]\n" /* outr00 = w5 * r1[2]*/ + "fmla v16.4s , %[w5].4s, v3.s[0]\n" /* outr01 = w5 * r1[4]*/ + "fmla v17.4s , %[w5].4s, v3.s[2]\n" /* outr02 = w5 * r1[6]*/ + "fmla v18.4s , %[w5].4s, v11.s[0]\n" /* outr03 = w5 * + r1[8]*/ + + "ldp q4, q5, [%[r2]], #32 \n" /* load input r2*/ + "stp q15, q16, [%[ptr_out0]], #32\n" /* save outr00, outr01*/ + + "fmla v19.4s , %[w5].4s, v6.s[2]\n" /* outr10 = w5 * r3[2]*/ + "fmla v20.4s , %[w5].4s, v7.s[0]\n" /* outr11 = w5 * r3[4]*/ + "fmla v21.4s , %[w5].4s, v7.s[2]\n" /* outr12 = w5 * r3[6]*/ + "fmla v22.4s , %[w5].4s, v13.s[0]\n" /* outr13 = w5 * + r3[8]*/ + + "ldr d12, [%[r2]] \n" /* load input r2, 9th + element*/ + "stp q17, q18, [%[ptr_out0]], #32\n" /* save outr02, outr03*/ + + /* r4, mul w6, get out r1 */ + "fmla v19.4s , %[w6].4s, v8.s[0]\n" /* outr10 = w6 * r4[0]*/ + "fmla v20.4s , %[w6].4s, v8.s[2]\n" /* outr11 = w6 * r4[2]*/ + "fmla v21.4s , %[w6].4s, v9.s[0]\n" /* outr12 = w6 * r4[4]*/ + "fmla v22.4s , %[w6].4s, v9.s[2]\n" /* outr13 = w6 * r4[6]*/ + + "ldp q15, q16, [%[ptr_out0]] \n" /* load outr00, outr01*/ + + /* r4, mul w7, get out r1 */ + "fmla v19.4s , %[w7].4s, v8.s[1]\n" /* outr10 = w7 * r4[1]*/ + "fmla v20.4s , %[w7].4s, v8.s[3]\n" /* outr11 = w7 * r4[3]*/ + "fmla v21.4s , %[w7].4s, v9.s[1]\n" /* outr12 = w7 * r4[5]*/ + "fmla v22.4s , %[w7].4s, v9.s[3]\n" /* outr13 = w7 * r4[7]*/ + + "ldp q17, q18, [%[ptr_out0], #32]\n" /* load outr02, outr03*/ + + /* r4, mul w8, get out r1 */ + "fmla v19.4s , %[w8].4s, v8.s[2]\n" /* outr10 = w8 * r4[2]*/ + "fmla v20.4s , %[w8].4s, v9.s[0]\n" /* outr11 = w8 * r4[4]*/ + "fmla v21.4s , %[w8].4s, v9.s[2]\n" /* outr12 = w8 * r4[6]*/ + "fmla v22.4s , %[w8].4s, v14.s[0]\n" /* outr13 = w8 * + r4[8]*/ + + "subs %w[cnt], %w[cnt], #1 \n" /*loop count -1*/ + + "stp q19, q20, [%[ptr_out1]], #32\n" /* save outr10, outr11*/ + "stp q21, q22, [%[ptr_out1]], #32\n" /* save outr12, outr13*/ + + "bne 2b \n" /* jump to main loop*/ + + : [cnt] "+r"(cnt), + [r0] "+r"(r0), + [r1] "+r"(r1), + [r2] "+r"(r2), + [r3] "+r"(r3), + [r4] "+r"(r4), + [ptr_out0] "+r"(ptr_out0), + [ptr_out1] "+r"(ptr_out1) + : [w0] "w"(w0), + [w1] "w"(w1), + [w2] "w"(w2), + [w3] "w"(w3), + [w4] "w"(w4), + [w5] "w"(w5), + [w6] "w"(w6), + [w7] "w"(w7), + [w8] "w"(w8) + : "cc", + "memory", + "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15", + "v16", + "v17", + "v18", + "v19", + "v20", + "v21", + "v22"); + + wc0 += 9 * hout_c_block; + inr0 += win_round; + inr1 += win_round; + inr2 += win_round; + inr3 += win_round; + inr4 += win_round; + } +#else // not __aarch64__ + for (int i = 0; i < ic; ++i) { + const float* wc0 = weight_c + i * w_stride_chin; + + float* ptr_out0 = pre_out0; + float* ptr_out1 = pre_out1; + + const float* r0 = inr0; + const float* r1 = inr1; + const float* r2 = inr2; + const float* r3 = inr3; + const float* r4 = inr4; + + int cnt = w_loop; + asm volatile( + "vld1.32 {d16-d19}, [%[ptr_out0]]! 
@ " + "load outr0, w0, w1, c0~c3\n" + "vld1.32 {d20-d23}, [%[ptr_out0]] @ load " + "outr0, w2, w3, c0~c3\n" + + /* load weights */ + "vld1.32 {d10-d13}, [%[wc0]]! @ load w0, " + "w1, to q5, q6\n" + "vld1.32 {d14-d15}, [%[wc0]]! @ load w2, " + "to q7\n" + + /* load r0, r2 */ + "vld1.32 {d0-d3}, [%[r0]]! @ load r0, " + "8 float\n" + "vld1.32 {d8}, [%[r0]] @ load r0, " + "9th float\n" + + "sub %[ptr_out0], %[ptr_out0], #32 @ ptr_out0 " + "- 32, to start address\n" + + /* main loop */ + "0: @ main " + "loop\n" + /* mul r0, with w0, w1, w2 */ + "vld1.32 {d24-d27}, [%[ptr_out1]]! @ load " + "outr1, w0, w1, c0~c3\n" + "vmla.f32 q8, q5, d0[0] @ w0 * " + "inr00\n" + "vld1.32 {d28-d31}, [%[ptr_out1]] @ load " + "outr1, w2, w3, c0~c3\n" + "vmla.f32 q9, q5, d1[0] @ w0 * " + "inr02\n" + "vmla.f32 q10, q5, d2[0] @ w0 * " + "inr04\n" + "vmla.f32 q11, q5, d3[0] @ w0 * " + "inr06\n" + "vld1.32 {d4-d7}, [%[r2]]! @ load r2, " + "8 float\n" + "vmla.f32 q8, q6, d0[1] @ w1 * " + "inr01\n" + "vmla.f32 q9, q6, d1[1] @ w1 * " + "inr03\n" + "vmla.f32 q10, q6, d2[1] @ w1 * " + "inr05\n" + "vmla.f32 q11, q6, d3[1] @ w1 * " + "inr07\n" + "vld1.32 {d9}, [%[r2]] @ load r2, " + "9th float\n" + "vmla.f32 q8, q7, d1[0] @ w2 * " + "inr02\n" + "vmla.f32 q9, q7, d2[0] @ w2 * " + "inr04\n" + "vmla.f32 q10, q7, d3[0] @ w2 * " + "inr06\n" + "vmla.f32 q11, q7, d8[0] @ w2 * " + "inr08\n" + + "sub %[r2], %[r2], #32 @ r2 - 32, " + "load r2 twice\n" + + /* mul r2, with w0, w1, w2 */ + "vld1.32 {d0-d3}, [%[r1]]! @ load r1, " + "8 float\n" + "vmla.f32 q12, q5, d4[0] @ w0 * " + "inr20\n" + "vmla.f32 q13, q5, d5[0] @ w0 * " + "inr22\n" + "vmla.f32 q14, q5, d6[0] @ w0 * " + "inr24\n" + "vmla.f32 q15, q5, d7[0] @ w0 * " + "inr26\n" + "vld1.32 {d8}, [%[r1]] @ load r1, " + "9th float\n" + "vmla.f32 q12, q6, d4[1] @ w1 * " + "inr21\n" + "vmla.f32 q13, q6, d5[1] @ w1 * " + "inr23\n" + "vmla.f32 q14, q6, d6[1] @ w1 * " + "inr25\n" + "vmla.f32 q15, q6, d7[1] @ w1 * " + "inr27\n" + "vld1.32 {d10-d13}, [%[wc0]]! @ load w3, " + "w4, to q5, q6\n" + "vmla.f32 q12, q7, d5[0] @ w2 * " + "inr22\n" + "vmla.f32 q13, q7, d6[0] @ w2 * " + "inr24\n" + "vmla.f32 q14, q7, d7[0] @ w2 * " + "inr26\n" + "vmla.f32 q15, q7, d9[0] @ w2 * " + "inr28\n" + "vld1.32 {d14-d15}, [%[wc0]]! @ load w5, " + "to q7\n" + + /* mul r1, with w3, w4, w5 */ + "vmla.f32 q8, q5, d0[0] @ w3 * " + "inr10\n" + "vmla.f32 q9, q5, d1[0] @ w3 * " + "inr12\n" + "vmla.f32 q10, q5, d2[0] @ w3 * " + "inr14\n" + "vmla.f32 q11, q5, d3[0] @ w3 * " + "inr16\n" + "vld1.32 {d4-d7}, [%[r3]]! @ load r3, " + "8 float\n" + "vmla.f32 q8, q6, d0[1] @ w4 * " + "inr11\n" + "vmla.f32 q9, q6, d1[1] @ w4 * " + "inr13\n" + "vmla.f32 q10, q6, d2[1] @ w4 * " + "inr15\n" + "vmla.f32 q11, q6, d3[1] @ w4 * " + "inr17\n" + "vld1.32 {d9}, [%[r3]] @ load r3, " + "9th float\n" + "vmla.f32 q8, q7, d1[0] @ w5 * " + "inr12\n" + "vmla.f32 q9, q7, d2[0] @ w5 * " + "inr14\n" + "vmla.f32 q10, q7, d3[0] @ w5 * " + "inr16\n" + "vmla.f32 q11, q7, d8[0] @ w5 * " + "inr18\n" + + "sub %[ptr_out1], %[ptr_out1], #32 @ ptr_out1 " + "- 32, to start address\n" + + /* mul r3, with w3, w4, w5 */ + "vld1.32 {d0-d3}, [%[r2]]! 
@ load r2, " + "8 float\n" + "vmla.f32 q12, q5, d4[0] @ w3 * " + "inr30\n" + "vmla.f32 q13, q5, d5[0] @ w3 * " + "inr32\n" + "vmla.f32 q14, q5, d6[0] @ w3 * " + "inr34\n" + "vmla.f32 q15, q5, d7[0] @ w3 * " + "inr36\n" + "vld1.32 {d8}, [%[r2]] @ load r2, " + "9th float\n" + "vmla.f32 q12, q6, d4[1] @ w4 * " + "inr31\n" + "vmla.f32 q13, q6, d5[1] @ w4 * " + "inr33\n" + "vmla.f32 q14, q6, d6[1] @ w4 * " + "inr35\n" + "vmla.f32 q15, q6, d7[1] @ w4 * " + "inr37\n" + "vld1.32 {d10-d13}, [%[wc0]]! @ load w6, " + "w7, to q5, q6\n" + "vmla.f32 q12, q7, d5[0] @ w5 * " + "inr32\n" + "vmla.f32 q13, q7, d6[0] @ w5 * " + "inr34\n" + "vmla.f32 q14, q7, d7[0] @ w5 * " + "inr36\n" + "vmla.f32 q15, q7, d9[0] @ w5 * " + "inr38\n" + "vld1.32 {d14-d15}, [%[wc0]]! @ load w8, " + "to q7\n" + + /* mul r2, with w6, w7, w8 */ + "vmla.f32 q8, q5, d0[0] @ w6 * " + "inr20\n" + "vmla.f32 q9, q5, d1[0] @ w6 * " + "inr22\n" + "vmla.f32 q10, q5, d2[0] @ w6 * " + "inr24\n" + "vmla.f32 q11, q5, d3[0] @ w6 * " + "inr26\n" + "vld1.32 {d4-d7}, [%[r4]]! @ load r4, " + "8 float\n" + "vmla.f32 q8, q6, d0[1] @ w7 * " + "inr21\n" + "vmla.f32 q9, q6, d1[1] @ w7 * " + "inr23\n" + "vmla.f32 q10, q6, d2[1] @ w7 * " + "inr25\n" + "vmla.f32 q11, q6, d3[1] @ w7 * " + "inr27\n" + "vld1.32 {d9}, [%[r4]] @ load r4, " + "9th float\n" + "vmla.f32 q8, q7, d1[0] @ w8 * " + "inr22\n" + "vmla.f32 q9, q7, d2[0] @ w8 * " + "inr24\n" + "vmla.f32 q10, q7, d3[0] @ w8 * " + "inr26\n" + "vmla.f32 q11, q7, d8[0] @ w8 * " + "inr28\n" + + "sub %[wc0], %[wc0], #144 @ wc0 - " + "144 to start address\n" + + /* mul r4, with w6, w7, w8 */ + "vld1.32 {d0-d3}, [%[r0]]! @ load r0, " + "8 float\n" + "vmla.f32 q12, q5, d4[0] @ w3 * " + "inr40\n" + "vst1.32 {d16-d19}, [%[ptr_out0]]! @ save " + "r00, r01, c0~c3\n" + "vmla.f32 q13, q5, d5[0] @ w3 * " + "inr42\n" + "vst1.32 {d20-d23}, [%[ptr_out0]]! @ save " + "r02, r03, c0~c3\n" + "vmla.f32 q14, q5, d6[0] @ w3 * " + "inr44\n" + "vmla.f32 q15, q5, d7[0] @ w3 * " + "inr46\n" + "vld1.32 {d8}, [%[r0]] @ load " + "r0, 9th float\n" + "vmla.f32 q12, q6, d4[1] @ w4 * " + "inr41\n" + "vmla.f32 q13, q6, d5[1] @ w4 * " + "inr43\n" + "vmla.f32 q14, q6, d6[1] @ w4 * " + "inr45\n" + "vmla.f32 q15, q6, d7[1] @ w4 * " + "inr47\n" + "vld1.32 {d10-d13}, [%[wc0]]! @ load w0, " + "w1, to q5, q6\n" + "vmla.f32 q12, q7, d5[0] @ w5 * " + "inr42\n" + "vmla.f32 q13, q7, d6[0] @ w5 * " + "inr44\n" + "vmla.f32 q14, q7, d7[0] @ w5 * " + "inr46\n" + "vmla.f32 q15, q7, d9[0] @ w5 * " + "inr48\n" + "vld1.32 {d14-d15}, [%[wc0]]! @ load w2, " + "to q7\n" + + "vst1.32 {d24-d27}, [%[ptr_out1]]! @ save " + "r10, r11, c0~c3\n" + "vst1.32 {d28-d31}, [%[ptr_out1]]! @ save " + "r12, r13, c0~c3\n" + + "vld1.32 {d16-d19}, [%[ptr_out0]]! 
@ load " + "outr0, w0, w1, c0~c3\n" + "vld1.32 {d20-d23}, [%[ptr_out0]] @ load " + "outr0, w2, w3, c0~c3\n" + + "sub %[ptr_out0], %[ptr_out0], #32 @ ptr_out0 " + "- 32, to start address\n" + + "subs %[cnt], #1 @ loop " + "count--\n" + "bne 0b @ jump to " + "main loop\n" + + : [cnt] "+r"(cnt), + [r0] "+r"(r0), + [r1] "+r"(r1), + [r2] "+r"(r2), + [r3] "+r"(r3), + [r4] "+r"(r4), + [ptr_out0] "+r"(ptr_out0), + [ptr_out1] "+r"(ptr_out1), + [wc0] "+r"(wc0) + : + : "cc", + "memory", + "q0", + "q1", + "q2", + "q3", + "q4", + "q5", + "q6", + "q7", + "q8", + "q9", + "q10", + "q11", + "q12", + "q13", + "q14", + "q15"); + + inr0 += win_round; + inr1 += win_round; + inr2 += win_round; + inr3 += win_round; + inr4 += win_round; + } +#endif // __aarch64__ + block_inr0 = block_inr4; + block_inr1 = block_inr0 + in_len; + block_inr2 = block_inr1 + in_len; + block_inr3 = block_inr2 + in_len; + block_inr4 = block_inr3 + in_len; + } + + write_to_output_c4_fp32(pre_out, + dout_batch, + c, + c + hout_c_block, + h, + h + h_kernel, + 0, + wout_round, + oc, + oh, + ow, + flag_relu, + ptr_write); + } + +#pragma omp parallel for num_threads(threads) + for (int c = 0; c < c_remain; ++c) { +#ifdef ARM_WITH_OMP + float* pre_out = + pre_din + pre_in_size + omp_get_thread_num() * pre_out_size; +#else + float* pre_out = pre_din + pre_in_size; +#endif + + const float* block_inr0 = cblock_inr0; + const float* block_inr1 = cblock_inr1; + const float* block_inr2 = cblock_inr2; + const float* block_inr3 = cblock_inr3; + const float* block_inr4 = cblock_inr4; + + //! get weights ptr of remained + const float* weight_c = weights + c_round_down * w_stride; + + //! fill bias to one channel + const float* bias_ptr = ptr_zero; + if (flag_bias) { + bias_ptr = bias + c_round_down + c; + } + fill_bias(pre_out, bias_ptr, 1, wout_round * h_kernel); + + for (int hk = 0; hk < h_kernel; hk += hout_r_kernel) { + const float* wc0 = weight_c; + + const float* inr0 = block_inr0; + const float* inr1 = block_inr1; + const float* inr2 = block_inr2; + const float* inr3 = block_inr3; + const float* inr4 = block_inr4; + + float* pre_out0 = pre_out + hk * wout_round; + float* pre_out1 = pre_out0 + wout_round; +#ifdef __aarch64__ + for (int i = 0; i < ic; ++i) { + float* ptr_out0 = pre_out0; + float* ptr_out1 = pre_out1; + + //! 
get valid weights of current output channel + float32x4_t w0 = vdupq_n_f32(wc0[c]); // w0, v23 + float32x4_t w1 = vdupq_n_f32(wc0[c + 4]); // w1, v24 + float32x4_t w2 = vdupq_n_f32(wc0[c + 8]); // w2, v25 + float32x4_t w3 = vdupq_n_f32(wc0[c + 12]); // w3, v26 + float32x4_t w4 = vdupq_n_f32(wc0[c + 16]); // w4, v27 + float32x4_t w5 = vdupq_n_f32(wc0[c + 20]); // w5, v28 + float32x4_t w6 = vdupq_n_f32(wc0[c + 24]); // w6, v29 + float32x4_t w7 = vdupq_n_f32(wc0[c + 28]); // w7, v30 + float32x4_t w8 = vdupq_n_f32(wc0[c + 32]); // w8, v31 + + const float* r0 = inr0; + const float* r1 = inr1; + const float* r2 = inr2; + const float* r3 = inr3; + const float* r4 = inr4; + + int cnt = w_loop; + asm volatile( + "ldr q21, [%[ptr_out0]] \n" /* load outr00, + outr01, + outr02, + outr03*/ + + "ld2 {v0.4s, v1.4s}, [%[r0]], #32 \n" /* load input r0*/ + "ldr d10, [%[r0]] \n" /* load input r0, 9th + element*/ + "ld2 {v4.4s, v5.4s}, [%[r2]], #32 \n" /* load input r2*/ + "ldr d12, [%[r2]] \n" /* load input r2, 9th + element*/ + "2: \n" /* main loop*/ + /* r0, r2, mul w0, get out r0, r1 */ + "ldr q22, [%[ptr_out1]] \n" /* load outr10, outr11, + outr12, outr13*/ + + "fmla v21.4s , %[w0].4s, v0.4s \n" /* outr0 = w0 * r0[0, 2, + 4, 6]*/ + "fmla v22.4s , %[w0].4s, v4.4s \n" /* outr1 = w0 * r2[0, 2, + 4, 6]*/ + + "ld2 {v2.4s, v3.4s}, [%[r1]], #32 \n" /* load input r1*/ + + /* r2 mul w6, get out r0*/ + "fmla v21.4s , %[w6].4s, v4.4s \n" /* outr0 = w6 * r2[0, 2, + 4, 6]*/ + "ldr d11, [%[r1]] \n" /* load input r1, 9th + element*/ + + /* shift left 1 */ + "ext v15.16b, v0.16b, v10.16b, #4\n" /* shift left r0 1*/ + "ext v16.16b, v4.16b, v12.16b, #4\n" /* shift left r2 1*/ + + /* r0, r2, mul w1, get out r0, r1 */ + "fmla v21.4s , %[w1].4s, v1.4s \n" /* outr0 = w1 * r0[1, 3, + 5, 7]*/ + "fmla v22.4s , %[w1].4s, v5.4s \n" /* outr1 = w1 * r2[1, 3, + 5, 7]*/ + + "ld2 {v6.4s, v7.4s}, [%[r3]], #32 \n" /* load input r3*/ + + /* r2 mul w7, get out r0 */ + "fmla v21.4s , %[w7].4s, v5.4s \n" /* outr00 = w7 * r2[1, + 3, 5, 7]*/ + + "ldr d13, [%[r3]] \n" /* load input r3, 9th + element*/ + + /* r0, r2, mul w2, get out r0, r1 */ + "fmla v21.4s , %[w2].4s, v15.4s \n" /* outr0 = w2 * r0[2, 4, + 6, 8]*/ + "fmla v22.4s , %[w2].4s, v16.4s \n" /* outr1 = w2 * r2[2, 4, + 6, 8]*/ + + "ld2 {v8.4s, v9.4s}, [%[r4]], #32 \n" /* load input r4*/ + + /* r2, mul w8, get out r0 */ + "fmla v21.4s , %[w8].4s, v16.4s \n" /* outr00 = w8 * r2[2, + 4, 6, 8]*/ + + "ldr d14, [%[r4]] \n" /* load input r4, 9th + element*/ + + /* r1, r3, mul w3, get out r0, r1 */ + "fmla v21.4s , %[w3].4s, v2.4s \n" /* outr0 = w3 * r1[0, 2, + 4, 6]*/ + "fmla v22.4s , %[w3].4s, v6.4s \n" /* outr1 = w3 * r3[0, 2, + 4, 6]*/ + + /* shift left 1 */ + "ext v15.16b, v2.16b, v11.16b, #4\n" /* shift left r1 1*/ + "ext v16.16b, v6.16b, v13.16b, #4\n" /* shift left r3 1*/ + + "ld2 {v0.4s, v1.4s}, [%[r0]], #32 \n" /* load input r0*/ + + /* r1, r3, mul w4, get out r0, r1 */ + "fmla v21.4s , %[w4].4s, v3.4s \n" /* outr0 = w4 * r1[1, 3, + 5, 7]*/ + "fmla v22.4s , %[w4].4s, v7.4s \n" /* outr1 = w4 * r3[1, 3, + 5, 7]*/ + + "ldr d10, [%[r0]] \n" /* load input r0, 9th + element*/ + + /* r1, r3, mul w5, get out r0, r1 */ + "fmla v21.4s , %[w5].4s, v15.4s \n" /* outr0 = w5 * r1[2]*/ + "fmla v22.4s , %[w5].4s, v16.4s \n" /* outr1 = w5 * r1[4]*/ + + "ld2 {v4.4s, v5.4s}, [%[r2]], #32 \n" /* load input r2*/ + "ldr d12, [%[r2]] \n" /* load input r2, 9th + element*/ + "str q21, [%[ptr_out0]], #16 \n" /* save outr00, outr01*/ + + /* r4, mul w6, get out r1 */ + "fmla v22.4s , %[w6].4s, v8.4s \n" /* outr1 
= w6 * r4[0, 2, + 4, 6]*/ + + "ext v15.16b, v8.16b, v14.16b, #4\n" /* shift left r1 1*/ + "ldr q21, [%[ptr_out0]] \n" /* load outr0*/ + + /* r4, mul w7, get out r1 */ + "fmla v22.4s , %[w7].4s, v9.4s \n" /* outr1 = w7 * r4[1, 3, + 5, 7]*/ + + /* r4, mul w8, get out r1 */ + "fmla v22.4s , %[w8].4s, v15.4s \n" /* outr1 = w8 * r4[2, 4, + 6, 8]*/ + + "subs %w[cnt], %w[cnt], #1 \n" /*loop count -1*/ + "str q22, [%[ptr_out1]], #16 \n" /* save outr1*/ + "bne 2b \n" /* jump to main loop*/ + + : [cnt] "+r"(cnt), + [r0] "+r"(r0), + [r1] "+r"(r1), + [r2] "+r"(r2), + [r3] "+r"(r3), + [r4] "+r"(r4), + [ptr_out0] "+r"(ptr_out0), + [ptr_out1] "+r"(ptr_out1) + : [w0] "w"(w0), + [w1] "w"(w1), + [w2] "w"(w2), + [w3] "w"(w3), + [w4] "w"(w4), + [w5] "w"(w5), + [w6] "w"(w6), + [w7] "w"(w7), + [w8] "w"(w8) + : "cc", + "memory", + "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15", + "v16", + "v21", + "v22"); + + wc0 += 36; + inr0 += win_round; + inr1 += win_round; + inr2 += win_round; + inr3 += win_round; + inr4 += win_round; + } +#else // not __aarch64__ + for (int i = 0; i < ic; ++i) { + float* ptr_out0 = pre_out0; + float* ptr_out1 = pre_out1; + + //! get valid weights of current output channel + float w_tmp[12] = {wc0[c], + wc0[c + 4], + wc0[c + 8], + 0.f, + wc0[c + 12], + wc0[c + 16], + wc0[c + 20], + 0.f, + wc0[c + 24], + wc0[c + 28], + wc0[c + 32], + 0.f}; + float32x4_t w0 = vld1q_f32(w_tmp); // w0, w1, w2, q0 + float32x4_t w1 = vld1q_f32(w_tmp + 4); // w3, w4, w5, q1 + float32x4_t w2 = vld1q_f32(w_tmp + 8); // w6, w7, w8, q2 + + const float* r0 = inr0; + const float* r1 = inr1; + const float* r2 = inr2; + const float* r3 = inr3; + const float* r4 = inr4; + + int cnt = w_loop / 2; + if (cnt > 0) { + asm volatile( + /* main loop */ + "0: @ " + "main loop\n" + "vld1.32 {d24-d27}, [%[ptr_out0]] @ load or00, " + "or01\n" + "vld1.32 {d28-d31}, [%[ptr_out1]] @ load or10, " + "or11\n" + "vld2.32 {d6-d9}, [%[r2]]! @ load r2, 8 " + "float, interleave\n" + "vld2.32 {d10-d13}, [%[r2]]! @ load r2, 8 " + "float, interleave\n" + "vld1.32 {d22}, [%[r2]] @ load 16th " + "float\n" + + /* r2 * w2, r2 * w0, get or0, or1 */ + "vmla.f32 q12, q4, %e[w2][1] @ w21 * r2, " + "1, 3, 5, 7\n" + "vmla.f32 q13, q6, %e[w2][1] @ w21 * r2, " + "9, 11, 13, 15\n" + "vld2.32 {d14-d17}, [%[r0]]! @ load r0, 8 " + "float, interleave\n" + "vmla.f32 q14, q4, %e[w0][1] @ w01 * r2, " + "1, 3, 5, 7\n" + "vmla.f32 q15, q6, %e[w0][1] @ w01 * r2, " + "9, 11, 13, 15\n" + + "vext.32 q4, q3, q5, #1 @ r2, shift " + "left 1, get 2, 4, 6, 8\n" + "vext.32 q6, q5, q11, #1 @ r2, shift " + "left 1, get 10, 12, 14, 16\n" + + "vmla.f32 q12, q3, %e[w2][0] @ w20 * r2, " + "0, 2, 4, 6\n" + "vmla.f32 q13, q5, %e[w2][0] @ w20 * r2, " + "8, 10, 12, 14\n" + "vld2.32 {d18-d21}, [%[r0]]! @ load r0, 8 " + "float, interleave\n" + "vmla.f32 q14, q3, %e[w0][0] @ w00 * r2, " + "0, 2, 4, 6\n" + "vmla.f32 q15, q5, %e[w0][0] @ w00 * r2, " + "8, 10, 12, 14\n" + + "vld1.32 {d22}, [%[r0]] @ load 16th " + "float\n" + + "vmla.f32 q12, q4, %f[w2][0] @ w22 * r2, " + "2, 4, 6, 8\n" + "vmla.f32 q14, q4, %f[w0][0] @ w02 * r2, " + "2, 4, 6, 8\n" + "vld2.32 {d6-d9}, [%[r3]]! @ load r3, 8 " + "float, interleave\n" + "vmla.f32 q13, q6, %f[w2][0] @ w22 * r2, " + "10, 12, 14, 16\n" + "vmla.f32 q15, q6, %f[w0][0] @ w02 * r2, " + "10, 12, 14, 16\n" + "vld2.32 {d10-d13}, [%[r3]]! 
@ load r3, 8 " + "float, interleave\n" + + /* r0 * w0, get or0, r3 * w1, get or1*/ + "vmla.f32 q12, q8, %e[w0][1] @ w01 * r0, " + "1, 3, 5, 7\n" + "vmla.f32 q13, q10, %e[w0][1] @ w01 * r0, " + "9, 11, 13, 15\n" + "vext.32 q8, q7, q9, #1 @ r0, shift " + "left 1, get 2, 4, 6, 8\n" + "vext.32 q10, q9, q11, #1 @ r0, shift " + "left 1, get 10, 12, 14, 16\n" + "vld1.32 {d22}, [%[r3]] @ load 16th " + "float\n" + "vmla.f32 q14, q4, %e[w1][1] @ w11 * r3, " + "1, 3, 5, 7\n" + "vmla.f32 q15, q6, %e[w1][1] @ w11 * r3, " + "9, 11, 13, 15\n" + + "vmla.f32 q12, q7, %e[w0][0] @ w00 * r0, " + "0, 2, 4, 6\n" + "vmla.f32 q13, q9, %e[w0][0] @ w00 * r0, " + "8, 10, 12, 14\n" + "vext.32 q4, q3, q5, #1 @ r3, shift " + "left 1, get 2, 4, 6, 8\n" + "vext.32 q6, q5, q11, #1 @ r3, shift " + "left 1, get 10, 12, 14, 16\n" + "vmla.f32 q14, q3, %e[w1][0] @ w10 * r3, " + "0, 2, 4, 6\n" + "vmla.f32 q15, q5, %e[w1][0] @ w10 * r3, " + "8, 10, 12, 14\n" + + "vmla.f32 q12, q8, %f[w0][0] @ w02 * r0, " + "2, 4, 6, 8\n" + "vld2.32 {d14-d17}, [%[r1]]! @ load r1, 8 " + "float, interleave\n" + "vmla.f32 q13, q10,%f[w0][0] @ w02 * r0, " + "10, 12, 14, 16\n" + "vld2.32 {d18-d21}, [%[r1]]! @ load r1, 8 " + "float, interleave\n" + "vmla.f32 q14, q4, %f[w1][0] @ w12 * r3, " + "2, 4, 6, 8\n" + "vld2.32 {d6-d9}, [%[r4]]! @ load r4, 8 " + "float, interleave\n" + "vmla.f32 q15, q6, %f[w1][0] @ w12 * r3, " + "10, 12, 14, 16\n" + "vld2.32 {d10-d13}, [%[r4]]! @ load r4, 8 " + "float, interleave\n" + + "vld1.32 {d22}, [%[r1]] @ load 16th " + "float\n" + + /* r1 * w1, get or0, r4 * w2, get or1 */ + "vmla.f32 q12, q8, %e[w1][1] @ w11 * r1, " + "1, 3, 5, 7\n" + "vmla.f32 q13, q10, %e[w1][1] @ w11 * r1, " + "9, 11, 13, 15\n" + "vext.32 q8, q7, q9, #1 @ r1, shift " + "left 1, get 2, 4, 6, 8\n" + "vext.32 q10, q9, q11, #1 @ r1, shift " + "left 1, get 10, 12, 14, 16\n" + "vmla.f32 q14, q4, %e[w2][1] @ w21 * r4, " + "1, 3, 5, 7\n" + "vmla.f32 q15, q6, %e[w2][1] @ w21 * r4, " + "9, 11, 13, 15\n" + "vld1.32 {d22}, [%[r4]] @ load 16th " + "float\n" + + "vmla.f32 q12, q7, %e[w1][0] @ w10 * r1, " + "0, 2, 4, 6\n" + "vmla.f32 q13, q9, %e[w1][0] @ w10 * r1, " + "8, 10, 12, 14\n" + "vext.32 q4, q3, q5, #1 @ r1, shift " + "left 1, get 2, 4, 6, 8\n" + "vext.32 q6, q5, q11, #1 @ r1, shift " + "left 1, get 10, 12, 14, 16\n" + "vmla.f32 q14, q3, %e[w2][0] @ w20 * r4, " + "0, 2, 4, 6\n" + "vmla.f32 q15, q5, %e[w2][0] @ w20 * r4, " + "8, 10, 12, 14\n" + + "vmla.f32 q12, q8, %f[w1][0] @ w12 * r1, " + "2, 4, 6, 8\n" + "vmla.f32 q13, q10, %f[w1][0] @ w12 * r1, " + "10, 12, 14, 16\n" + "vmla.f32 q14, q4, %f[w2][0] @ w22 * r4, " + "2, 4, 6, 8\n" + "vmla.f32 q15, q6, %f[w2][0] @ w22 * r4, " + "10, 12, 14, 16\n" + + "vst1.32 {d24-d27}, [%[ptr_out0]]! @ save or0\n" + "vst1.32 {d28-d31}, [%[ptr_out1]]! @ save or0\n" + + "subs %[cnt], #1 @loop count " + "-1\n" + "bne 0b @ jump to " + "main loop\n" + + : [cnt] "+r"(cnt), + [r0] "+r"(r0), + [r1] "+r"(r1), + [r2] "+r"(r2), + [r3] "+r"(r3), + [r4] "+r"(r4), + [ptr_out0] "+r"(ptr_out0), + [ptr_out1] "+r"(ptr_out1) + : [w0] "w"(w0), [w1] "w"(w1), [w2] "w"(w2) + : "cc", + "memory", + "q3", + "q4", + "q5", + "q6", + "q7", + "q8", + "q9", + "q10", + "q11", + "q12", + "q13", + "q14", + "q15"); + } + //! 
deal with remain ow + if (w_loop & 1) { + ptr_out0[0] += + r0[0] * w_tmp[0] + r0[1] * w_tmp[1] + r0[2] * w_tmp[2] + + r1[0] * w_tmp[4] + r1[1] * w_tmp[5] + r1[2] * w_tmp[6] + + r2[0] * w_tmp[8] + r2[1] * w_tmp[9] + r2[2] * w_tmp[10]; + + ptr_out0[1] += + r0[2] * w_tmp[0] + r0[3] * w_tmp[1] + r0[4] * w_tmp[2] + + r1[2] * w_tmp[4] + r1[3] * w_tmp[5] + r1[4] * w_tmp[6] + + r2[2] * w_tmp[8] + r2[3] * w_tmp[9] + r2[4] * w_tmp[10]; + + ptr_out0[2] += + r0[4] * w_tmp[0] + r0[5] * w_tmp[1] + r0[6] * w_tmp[2] + + r1[4] * w_tmp[4] + r1[5] * w_tmp[5] + r1[6] * w_tmp[6] + + r2[4] * w_tmp[8] + r2[5] * w_tmp[9] + r2[6] * w_tmp[10]; + + ptr_out0[3] += + r0[6] * w_tmp[0] + r0[7] * w_tmp[1] + r0[8] * w_tmp[2] + + r1[6] * w_tmp[4] + r1[7] * w_tmp[5] + r1[8] * w_tmp[6] + + r2[6] * w_tmp[8] + r2[7] * w_tmp[9] + r2[8] * w_tmp[10]; + + ptr_out1[0] += + r2[0] * w_tmp[0] + r2[1] * w_tmp[1] + r2[2] * w_tmp[2] + + r3[0] * w_tmp[4] + r3[1] * w_tmp[5] + r3[2] * w_tmp[6] + + r4[0] * w_tmp[8] + r4[1] * w_tmp[9] + r4[2] * w_tmp[10]; + + ptr_out1[1] += + r2[2] * w_tmp[0] + r2[3] * w_tmp[1] + r2[4] * w_tmp[2] + + r3[2] * w_tmp[4] + r3[3] * w_tmp[5] + r3[4] * w_tmp[6] + + r4[2] * w_tmp[8] + r4[3] * w_tmp[9] + r4[4] * w_tmp[10]; + + ptr_out1[2] += + r2[4] * w_tmp[0] + r2[5] * w_tmp[1] + r2[6] * w_tmp[2] + + r3[4] * w_tmp[4] + r3[5] * w_tmp[5] + r3[6] * w_tmp[6] + + r4[4] * w_tmp[8] + r4[5] * w_tmp[9] + r4[6] * w_tmp[10]; + + ptr_out1[3] += + r2[6] * w_tmp[0] + r2[7] * w_tmp[1] + r2[8] * w_tmp[2] + + r3[6] * w_tmp[4] + r3[7] * w_tmp[5] + r3[8] * w_tmp[6] + + r4[6] * w_tmp[8] + r4[7] * w_tmp[9] + r4[8] * w_tmp[10]; + } + + wc0 += 36; + inr0 += win_round; + inr1 += win_round; + inr2 += win_round; + inr3 += win_round; + inr4 += win_round; + } +#endif // __aarch64__ + block_inr0 = block_inr4; + block_inr1 = block_inr0 + in_len; + block_inr2 = block_inr1 + in_len; + block_inr3 = block_inr2 + in_len; + block_inr4 = block_inr3 + in_len; + } + write_to_output_c1_fp32(pre_out, + dout_batch, + c + c_round_down, + c + c_round_down + 1, + h, + h + h_kernel, + 0, + wout_round, + oc, + oh, + ow, + flag_relu, + ptr_write); + } + } + } +} + +} // namespace math +} // namespace arm +} // namespace lite +} // namespace paddle diff --git a/lite/arm/math/conv_gemmlike.cc b/lite/arm/math/conv_gemmlike.cc new file mode 100644 index 00000000000..f78c366d0f9 --- /dev/null +++ b/lite/arm/math/conv_gemmlike.cc @@ -0,0 +1,285 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
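The gemm-like implementation that follows maps grouped convolution onto a per-group GEMM with m = oc / groups, k = ic * kh * kw / groups and n = oh * ow, either feeding the input directly (the 1x1, stride-1, pad-0 case) or through an im2col buffer. As a reading aid, here is a naive, unoptimized sketch of that im2col + GEMM mapping (single group, float, no dilation; all names are local to this sketch and it is not part of the patch):

#include <vector>

// Reference im2col + GEMM convolution: correct but slow; the real code below
// replaces both loops with packed, vectorized kernels.
void conv_im2col_gemm_ref(const float* din, float* dout, int ic, int ih, int iw,
                          int oc, int oh, int ow, int kh, int kw,
                          int sh, int sw, int ph, int pw, const float* weights) {
  const int k = ic * kh * kw, n = oh * ow;
  std::vector<float> col(static_cast<size_t>(k) * n, 0.f);  // k rows, n columns
  for (int c = 0; c < ic; ++c)
    for (int fy = 0; fy < kh; ++fy)
      for (int fx = 0; fx < kw; ++fx) {
        const int row = (c * kh + fy) * kw + fx;
        for (int oy = 0; oy < oh; ++oy)
          for (int ox = 0; ox < ow; ++ox) {
            const int y = oy * sh - ph + fy, x = ox * sw - pw + fx;
            if (y >= 0 && y < ih && x >= 0 && x < iw)
              col[static_cast<size_t>(row) * n + oy * ow + ox] =
                  din[(c * ih + y) * iw + x];
          }
      }
  // GEMM: dout[oc, n] = weights[oc, k] * col[k, n]
  for (int m = 0; m < oc; ++m)
    for (int j = 0; j < n; ++j) {
      float acc = 0.f;
      for (int kk = 0; kk < k; ++kk)
        acc += weights[m * k + kk] * col[static_cast<size_t>(kk) * n + j];
      dout[m * n + j] = acc;
    }
}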
+
+#include "lite/arm/math/conv_gemmlike.h"
+#include <vector>
+#include "lite/arm/math/gemm_prepacked_int8.h"
+#include "lite/arm/math/packed_sgemm.h"
+
+namespace paddle {
+namespace lite {
+namespace arm {
+namespace math {
+
+/********************* Gemmlike Conv Precision Is Float ***********************/
+template <>
+bool GemmLikeConv<PRECISION(kFloat)>::create(const operators::ConvParam& param,
+                                             ARMContext* ctx) {
+  this->ctx_ = ctx;
+  auto x_dims = param.x->dims();
+  auto w_dims = param.filter->dims();
+  auto o_dims = param.output->dims();
+
+  int iw = x_dims[3];  // nchw
+  int ih = x_dims[2];
+  int ic = x_dims[1];
+  int ow = o_dims[3];
+  int oh = o_dims[2];
+  int oc = o_dims[1];
+  int kw = w_dims[3];
+  int kh = w_dims[2];
+  int sw = param.strides[1];
+  int sh = param.strides[0];
+  int pw = param.paddings[1];
+  int ph = param.paddings[0];
+  int dw = param.dilations[1];
+  int dh = param.dilations[0];
+
+  int m = oc / param.groups;
+  int k = ic * kh * kw / param.groups;
+  int n = oh * ow;
+  bool kps_equal = (pw == ph) && (sw == sh) && (kw == kh);
+  bool ks_equal = (sw == sh) && (kw == kh);
+  //! select conv gemmlike kernel
+  if (kw == 1 && sw == 1 && pw == 0 && kps_equal) {
+    //! 1x1s1p0 gemmlike conv
+    impl_ = conv1x1s1_gemm;
+  } else {
+    //! otherwise case
+    if (kw == 3 && sw == 1 && n > 1 && ks_equal) {
+      idx_data_.Resize({1, 1, 1, n * kh * kw});
+      int* idx_out = idx_data_.mutable_data<int>();
+      for (int i = 0; i < oh; ++i) {
+        for (int j = 0; j < ow; ++j) {
+          compute_offset(idx_out, i, j, kh, kw, ih, iw, ph, pw, dh, dw);
+          idx_out += kh * kw;
+        }
+      }
+    }
+    //! im2col gemmlike conv
+    impl_ = conv_im2col_gemm;
+    this->ctx_->ExtendWorkspace(k * n * sizeof(float));
+  }
+
+  if (n > 1) {
+    int hblock = get_hblock(this->ctx_->arch());
+    int m_roundup = hblock * ((m + hblock - 1) / hblock);
+    int group_size_round_up = ((m_roundup * k + 15) / 16) * 16;
+    float* w_trans_ptr = nullptr;
+    weights_trans_.Resize({1, 1, 1, group_size_round_up * param.groups});
+    w_trans_ptr = weights_trans_.mutable_data<float>();
+    const auto* w_data = param.filter->data<float>();
+    for (int g = 0; g < param.groups; ++g) {
+      const float* weights_group = w_data + g * m * k;
+      float* weights_trans_ptr = w_trans_ptr + g * group_size_round_up;
+      prepackA(weights_trans_ptr,
+               weights_group,
+               1.f,
+               k,
+               0,
+               m,
+               0,
+               k,
+               false,
+               this->ctx_);
+    }
+    is_weights_transed_ = true;
+  }
+  return true;
+}
+
+template <>
+bool GemmLikeConv<PRECISION(kFloat)>::init(const operators::ConvParam& param,
+                                           ARMContext* ctx) {
+  this->ctx_ = ctx;
+  return create(param, ctx);
+}
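The `n > 1` branch of `create()` above repacks the filter once so that every subsequent GEMM call reads full, aligned A-panels. A worked example of the rounding arithmetic, with hypothetical sizes (`get_hblock()` is CPU-dependent; 6 is just an illustrative panel height):

#include <cstdio>

int main() {
  const int m = 32, k = 144, hblock = 6;  // hypothetical: 32 rows, ic*kh*kw/g = 144
  const int m_roundup = hblock * ((m + hblock - 1) / hblock);        // 36: panels stay full
  const int group_size_round_up = ((m_roundup * k + 15) / 16) * 16;  // 5184 floats, 16-aligned
  std::printf("m_roundup=%d floats_per_group=%d\n", m_roundup, group_size_round_up);
  return 0;
}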
+
+template <>
+bool GemmLikeConv<PRECISION(kFloat)>::run(const operators::ConvParam& param) {
+  // start timer
+  const auto* i_data = param.x->data<float>();
+  const auto* w_data = param.filter->data<float>();
+  const auto* b_data = param.bias ? param.bias->data<float>() : nullptr;
+  auto* o_data = param.output->mutable_data<float>();
+  const int* idx_data = idx_data_.mutable_data<int>();
+
+  if (is_weights_transed_) {
+    w_data = weights_trans_.data<float>();
+  }
+  auto x_dims = param.x->dims();
+  auto w_dims = param.filter->dims();
+  auto o_dims = param.output->dims();
+
+  int iw = x_dims[3];  // nchw
+  int ih = x_dims[2];
+  int ic = x_dims[1];
+  int bs = x_dims[0];
+  int oh = o_dims[2];
+  int ow = o_dims[3];
+  int oc = o_dims[1];
+
+  impl_(i_data,
+        o_data,
+        bs,
+        oc,
+        oh,
+        ow,
+        ic,
+        ih,
+        iw,
+        w_data,
+        b_data,
+        param,
+        this->ctx_,
+        idx_data);
+
+  // timer end
+  return true;
+}
+
+/********************* Gemmlike Conv Precision Is Int8 ************************/
+template <PrecisionType Ptype_out>
+bool GemmLikeConvInt8<Ptype_out>::create(const operators::ConvParam& param,
+                                         ARMContext* ctx) {
+  this->ctx_ = ctx;
+  auto x_dims = param.x->dims();
+  auto w_dims = param.filter->dims();
+  auto o_dims = param.output->dims();
+
+  int iw = x_dims[3];  // nchw
+  int ih = x_dims[2];
+  int ic = x_dims[1];
+  int ow = o_dims[3];
+  int oh = o_dims[2];
+  int oc = o_dims[1];
+  int kw = w_dims[3];
+  int kh = w_dims[2];
+  int sw = param.strides[1];
+  int sh = param.strides[0];
+  int pw = param.paddings[1];
+  int ph = param.paddings[0];
+  int dw = param.dilations[1];
+  int dh = param.dilations[0];
+
+  int m = oc / param.groups;
+  int k = ic * kh * kw / param.groups;
+  int n = oh * ow;
+  w_scale_ = param.weight_scale;
+  //! update weights scale
+  if (Ptype_out == PRECISION(kInt8) || Ptype_out == PRECISION(kFloat)) {
+    CHECK_EQ(this->w_scale_.size(), oc) << "weights scale size must be chout";
+    float input_scale = param.input_scale;
+    for (auto& w_s : w_scale_) {
+      w_s *= input_scale;
+      if (Ptype_out == PRECISION(kInt8)) {
+        w_s /= param.output_scale;
+      }
+    }
+  }
+
+  bool kps_equal = (pw == ph) && (sw == sh) && (kw == kh);
+  bool ks_equal = (sw == sh) && (kw == kh);
+  //! select conv gemmlike kernel
+  if (kw == 1 && sw == 1 && pw == 0 && kps_equal) {
+    //! 1x1s1p0 gemmlike conv
+    impl_int8_ = conv1x1s1_gemm_int8;
+  } else {
+    //! otherwise case
+    if (kw == 3 && sw == 1 && n > 1 && ks_equal) {
+      idx_data_.Resize({1, 1, 1, n * kh * kw});
+      int* idx_out = idx_data_.mutable_data<int>();
+      for (int i = 0; i < oh; ++i) {
+        for (int j = 0; j < ow; ++j) {
+          compute_offset(idx_out, i, j, kh, kw, ih, iw, ph, pw, dh, dw);
+          idx_out += kh * kw;
+        }
+      }
+    }
+    //! im2col gemmlike conv
+    impl_int8_ = conv_im2col_gemm_int8;
+    this->ctx_->ExtendWorkspace(k * n);
+  }
+
+  if (n > 1) {
+    prepackA_int8(&this->weights_trans_,
+                  *param.filter,
+                  m,
+                  k,
+                  param.groups,
+                  false,
+                  this->ctx_);
+    this->is_weights_transed_ = true;
+  }
+  return true;
+}
+
+template <PrecisionType Ptype_out>
+bool GemmLikeConvInt8<Ptype_out>::init(const operators::ConvParam& param,
+                                       ARMContext* ctx) {
+  this->ctx_ = ctx;
+  return create(param, ctx);
+}
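In the int8 `create()` above, the per-channel weight scales are folded with the input scale, and for int8 output additionally divided by the output scale, so `run()` can turn each int32 GEMM accumulator into an output value with a single multiply. A minimal sketch of that requantization step, assuming symmetric quantization (hypothetical helper, not part of the patch):

#include <algorithm>
#include <cmath>
#include <cstdint>

// fused_scale = input_scale * weight_scale[channel] (/ output_scale for int8 out),
// exactly the product built up by the loop over w_scale_ above.
inline int8_t requantize_to_int8(int32_t acc, float fused_scale) {
  float v = std::roundf(acc * fused_scale);   // int32 accumulator -> real value
  v = std::min(127.f, std::max(-127.f, v));   // saturate to the int8 range
  return static_cast<int8_t>(v);
}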
param.bias->data() : nullptr; + auto* o_data = param.output->mutable_data(); + const int32_t* idx_data = idx_data_.mutable_data(); + + if (this->is_weights_transed_ == true) { + w_data = this->weights_trans_.template data(); + } + auto x_dims = param.x->dims(); + auto w_dims = param.filter->dims(); + auto o_dims = param.output->dims(); + + int iw = x_dims[3]; // nchw + int ih = x_dims[2]; + int ic = x_dims[1]; + int bs = x_dims[0]; + int oh = o_dims[2]; + int ow = o_dims[3]; + int oc = o_dims[1]; + + impl_int8_(i_data, + o_data, + bs, + oc, + oh, + ow, + ic, + ih, + iw, + w_data, + b_data, + param, + this->ctx_, + Ptype_out, + this->w_scale_.data(), + idx_data); + + return true; +} + +template class GemmLikeConvInt8; +template class GemmLikeConvInt8; +template class GemmLikeConvInt8; + +} // namespace math +} // namespace arm +} // namespace lite +} // namespace paddle diff --git a/lite/arm/math/conv_gemmlike.h b/lite/arm/math/conv_gemmlike.h new file mode 100644 index 00000000000..872af2b7cbf --- /dev/null +++ b/lite/arm/math/conv_gemmlike.h @@ -0,0 +1,108 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include "lite/arm/math/conv_impl.h" +#include "lite/core/context.h" +#include "lite/core/target_wrapper.h" + +namespace paddle { +namespace lite { +namespace arm { +namespace math { + +template +class GemmLikeConv + : public ImplBase { + public: + typedef void (*conv_im2col_gemm_impl)(const float* din, + float* dout, + int num, + int chout, + int hout, + int wout, + int chin, + int hin, + int win, + const float* weights, + const float* bias, + const operators::ConvParam& param, + ARMContext* ctx, + const int* idx_ptr); + + GemmLikeConv() = default; + ~GemmLikeConv() {} + + virtual bool init(const operators::ConvParam& param, ARMContext* ctx) { + LOG(FATAL) << "GemmLikeConv::init() not implemented."; + } + + virtual bool create(const operators::ConvParam& param, ARMContext* ctx) { + LOG(FATAL) << "GemmLikeConv::create() not implemented."; + } + + virtual bool run(const operators::ConvParam& param) { + LOG(FATAL) << "GemmLikeConv::run() not implemented."; + } + + protected: + bool is_weights_transed_{false}; + Tensor idx_data_; + Tensor weights_trans_; + + private: + conv_im2col_gemm_impl impl_{nullptr}; +}; + +template +class GemmLikeConvInt8 : public GemmLikeConv { + public: + typedef void (*conv_im2col_gemm_int8_impl)(const int8_t* din, + int32_t* dout, + int num, + int chout, + int hout, + int wout, + int chin, + int hin, + int win, + const int8_t* weights, + const int32_t* bias, + const operators::ConvParam& param, + ARMContext* ctx, + PrecisionType out_type, + const float* scale, + const int* idx_ptr); + + GemmLikeConvInt8() = default; + ~GemmLikeConvInt8() {} + + virtual bool init(const operators::ConvParam& param, ARMContext* ctx); + + virtual bool create(const operators::ConvParam& param, ARMContext* ctx); + + virtual bool run(const operators::ConvParam& param); + + 
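+  // Usage sketch (illustrative only; assumes a fully populated
+  // operators::ConvParam `param` and a live ARMContext `ctx`, neither of
+  // which is constructed here):
+  //
+  //   GemmLikeConvInt8<PRECISION(kInt8)> conv;
+  //   conv.init(param, ctx);  // selects 1x1s1 or im2col kernel, packs weights
+  //   conv.run(param);        // executes the chosen GEMM-based convolution
+  //
+  // init() simply delegates to create(), so call it again only when shapes,
+  // strides, or quantization scales change.
+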
private: + conv_im2col_gemm_int8_impl impl_int8_{nullptr}; + std::vector w_scale_; +}; + +} // namespace math +} // namespace arm +} // namespace lite +} // namespace paddle diff --git a/lite/arm/math/conv_impl.cc b/lite/arm/math/conv_impl.cc new file mode 100644 index 00000000000..7b4ac255680 --- /dev/null +++ b/lite/arm/math/conv_impl.cc @@ -0,0 +1,900 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// #include "saber/funcs/impl/arm/neon/impl/conv_arm_depthwise.h" +// #include "saber/funcs/impl/arm/neon/impl/conv_arm_impl.h" +// #include "saber/funcs/impl/arm/neon/impl/gemm_prepacked_int8.h" +// #include "saber/funcs/impl/arm/neon/impl/gemv_arm_int8.h" +// #include "saber/funcs/impl/arm/neon/impl/sgemv_arm.h" + +#include "lite/arm/math/conv_impl.h" +#include +#include "lite/arm/math/gemm_prepacked_int8.h" +#include "lite/arm/math/gemv_arm_int8.h" +#include "lite/arm/math/packed_sgemm.h" +#include "lite/arm/math/sgemv.h" +#include "lite/core/context.h" +#include "lite/core/target_wrapper.h" +#include "lite/operators/op_params.h" + +namespace paddle { +namespace lite { +namespace arm { +namespace math { + +/** + * \brief neon implementation to add bias + * @param tensor + * @param bias + * @param channel + * @param channel_size + */ +void fill_bias(float* tensor, + const float* bias, + int channel, + int channel_size) { + if (tensor == nullptr) { + return; + } + float* data = tensor; + + for (int j = 0; j < channel; ++j) { + float32x4_t vdata = vdupq_n_f32(bias[j]); + int i = 0; + for (; i < channel_size - 3; i += 4) { + vst1q_f32(data + i, vdata); + } + for (; i < channel_size; i++) { + data[i] = bias[j]; + } + data += channel_size; + } +} + +void fill_bias_int8(int* tensor, + const int* bias, + int channel, + int channel_size) { + if (tensor == nullptr) { + return; + } + int* data = tensor; + for (int j = 0; j < channel; ++j) { + int32x4_t vdata = vdupq_n_s32(bias[j]); + int i = 0; + for (; i < channel_size - 3; i += 4) { + vst1q_s32(data + i, vdata); + } + for (; i < channel_size; i++) { + data[i] = bias[j]; + } + data += channel_size; + } +} + +/** + * \brief inline funcs used in im2col + * @param a + * @param b + * @return + */ +inline bool is_a_ge_zero_and_a_lt_b(int a, int b) { + return static_cast(a) < static_cast(b); +} + +/** + * \brief normal im2col function for gemm conv + * @tparam dtype + * @param data_im + * @param channels + * @param height + * @param width + * @param kernel_size + * @param pad + * @param stride + * @param data_col + */ +template +void im2col(const Dtype* data_im, + const int channels, + const int height, + const int width, + const int kernel_h, + const int kernel_w, + const int pad_h, + const int pad_w, + const int stride_h, + const int stride_w, + const int dilation_h, + const int dilation_w, + Dtype* data_col) { + const int output_h = + (height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1; + const int output_w = + (width + 2 * pad_w - (dilation_w * 
(kernel_w - 1) + 1)) / stride_w + 1; + const int channel_size = height * width; + for (int channel = channels; channel--; data_im += channel_size) { + for (int kernel_row = 0; kernel_row < kernel_h; kernel_row++) { + for (int kernel_col = 0; kernel_col < kernel_w; kernel_col++) { + int input_row = -pad_h + kernel_row * dilation_h; + for (int output_rows = output_h; output_rows; output_rows--) { + if (!is_a_ge_zero_and_a_lt_b(input_row, height)) { + for (int output_cols = output_w; output_cols; output_cols--) { + *(data_col++) = 0; + } + } else { + int input_col = -pad_w + kernel_col * dilation_w; + for (int output_col = output_w; output_col; output_col--) { + if (is_a_ge_zero_and_a_lt_b(input_col, width)) { + *(data_col++) = data_im[input_row * width + input_col]; + } else { + *(data_col++) = 0; + } + input_col += stride_w; + } + } + input_row += stride_h; + } + } + } + } +} +void compute_offset(int* idx_out, + int h, + int w, + int kernel_h, + int kernel_w, + int height, + int width, + int pad_h, + int pad_w, + int dilation_h, + int dilation_w) { + int idx_h[kernel_h]; // NOLINT + int idx_w[kernel_w]; // NOLINT + for (int i = 0; i < kernel_h; ++i) { + idx_h[i] = h - pad_h + i * dilation_h; + } + for (int i = 0; i < kernel_w; ++i) { + idx_w[i] = w - pad_w + i * dilation_w; + } + for (int k_h = 0; k_h < kernel_h; ++k_h) { + for (int k_w = 0; k_w < kernel_w; ++k_w) { + idx_out[k_h * kernel_w + k_w] = + (idx_h[k_h] >= 0 && idx_w[k_w] >= 0 && idx_h[k_h] < height && + idx_w[k_w] < width) + ? idx_h[k_h] * width + idx_w[k_w] + : -1; + } + } +} +template +void im2col3x3(const Dtype* data_im, + const int channels, + const int height, + const int width, + const int kernel_h, + const int kernel_w, + const int pad_h, + const int pad_w, + const int stride_h, + const int stride_w, + const int dilation_h, + const int dilation_w, + Dtype* data_col, + const int* idx) { + const int output_h = + (height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1; + const int output_w = + (width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1; + int kernel_stride = kernel_h * kernel_w; + int in_channel_stride = height * width; + const int* idx_out = idx; + Dtype* data_col_ptr = data_col; + + bool flag_continue = false; + if (dilation_h == 1 && dilation_w == 1) { + flag_continue = true; + } + + for (int o = 0; o < output_h * output_w; o += 1) { + const Dtype* data_im_ptr = data_im; + + // int* idx_out_d = idx_out; + + int idx_out_d0 = idx_out[0]; + int idx_out_d1 = idx_out[1]; + int idx_out_d2 = idx_out[2]; + int idx_out_d3 = idx_out[3]; + int idx_out_d4 = idx_out[4]; + int idx_out_d5 = idx_out[5]; + int idx_out_d6 = idx_out[6]; + int idx_out_d7 = idx_out[7]; + int idx_out_d8 = idx_out[8]; + + for (int i = 0; i < channels; i += 1) { + if (idx_out_d0 >= 0 && idx_out_d2 >= 0 && idx_out_d6 >= 0 && + idx_out_d8 >= 0) { + if (flag_continue) { + memcpy( + data_col_ptr, data_im_ptr + idx_out_d0, kernel_w * sizeof(Dtype)); + memcpy(data_col_ptr + kernel_w, + data_im_ptr + idx_out_d3, + kernel_w * sizeof(Dtype)); + memcpy(data_col_ptr + kernel_w + kernel_w, + data_im_ptr + idx_out_d6, + kernel_w * sizeof(Dtype)); + } else { + data_col_ptr[0] = data_im_ptr[idx_out_d0]; + data_col_ptr[1] = data_im_ptr[idx_out_d1]; + data_col_ptr[2] = data_im_ptr[idx_out_d2]; + data_col_ptr[3] = data_im_ptr[idx_out_d3]; + data_col_ptr[4] = data_im_ptr[idx_out_d4]; + data_col_ptr[5] = data_im_ptr[idx_out_d5]; + data_col_ptr[6] = data_im_ptr[idx_out_d6]; + data_col_ptr[7] = data_im_ptr[idx_out_d7]; + data_col_ptr[8] = 
data_im_ptr[idx_out_d8]; + } + } else { + data_col_ptr[0] = (idx_out_d0 < 0) ? 0 : data_im_ptr[idx_out_d0]; + data_col_ptr[1] = (idx_out_d1 < 0) ? 0 : data_im_ptr[idx_out_d1]; + data_col_ptr[2] = (idx_out_d2 < 0) ? 0 : data_im_ptr[idx_out_d2]; + data_col_ptr[3] = (idx_out_d3 < 0) ? 0 : data_im_ptr[idx_out_d3]; + data_col_ptr[4] = (idx_out_d4 < 0) ? 0 : data_im_ptr[idx_out_d4]; + data_col_ptr[5] = (idx_out_d5 < 0) ? 0 : data_im_ptr[idx_out_d5]; + data_col_ptr[6] = (idx_out_d6 < 0) ? 0 : data_im_ptr[idx_out_d6]; + data_col_ptr[7] = (idx_out_d7 < 0) ? 0 : data_im_ptr[idx_out_d7]; + data_col_ptr[8] = (idx_out_d8 < 0) ? 0 : data_im_ptr[idx_out_d8]; + } + data_im_ptr += height * width; + data_col_ptr += kernel_stride; + } + // data_col_ptr += channels * kernel_stride; + // idx_out += kernel_stride * 2; + idx_out += kernel_stride; + } +} + +/** + * \brief convolution function for kernel size 1x1, stride size 1, gemm + * implementation + */ +void conv1x1s1_gemm(const float* i_data, + float* o_data, + int num, + int oc, + int oh, + int ow, + int ic, + int ih, + int win, + const float* weights, + const float* bias, + const operators::ConvParam& param, + ARMContext* ctx, + const int* idx_ptr) { + int channel_size_out = ow * oh; + int channel_size_in = win * ih; + + const int group = param.groups; + const int m = oc / group; + const int n = oh * ow; + const int k = ic / group; + + bool flag_relu = param.fuse_relu; + bool flag_bias = param.bias != nullptr; + // if (param.activation_param.has_active) { + // if (param.activation_param.active == Active_relu && + // fabs(param.activation_param.negative_slope) < 1e-6f) { + // flag_relu = true; + // } + // } + int hblock = get_hblock(ctx->arch()); + int m_roundup = hblock * ((m + hblock - 1) / hblock); + int weights_size_per_group = m * k; + if (n > 1) { + weights_size_per_group = ((m_roundup * k + 15) / 16) * 16; + } + + // int weights_size_per_group = m_roundup * k;//oc * ic / (group * + // group); + //! use gemv when the output channel size = 1 + for (int b = 0; b < num; ++b) { + // dC + for (int g = 0; g < group; ++g) { + float* dout_group = + static_cast(o_data) + (b * oc + g * m) * channel_size_out; + const float* din_group = static_cast(i_data) + + (b * ic + g * k) * channel_size_in; + const float* weights_group = + static_cast(weights) + g * weights_size_per_group; + const float* bias_group = static_cast(bias) + g * m; + + if (n == 1) { + sgemv(weights_group, + din_group, + dout_group, + false, + m, + k, + flag_bias, + bias_group, + flag_relu); + } else { + sgemm_prepack(false, + m, + n, + k, + weights_group, + din_group, + n, + 0.f, + dout_group, + n, + bias_group, + flag_bias, + flag_relu, + ctx); + } + } + } +} + +void conv1x1s1_gemm_int8(const int8_t* i_data, + int32_t* o_data, + int num, + int oc, + int oh, + int ow, + int ic, + int ih, + int win, + const int8_t* weights, + const int32_t* bias, + const operators::ConvParam& param, + ARMContext* ctx, + PrecisionType out_type, + const float* scale, + const int32_t* idx_ptr) { + int group = param.groups; + int channel_size_out = ow * oh; + int channel_size_in = win * ih; + const int m = oc / group; + const int n = oh * ow; + const int k = ic / group; + int hblock = get_hblock_int8(ctx); + int k_roundup = ROUNDUP(k, KBLOCK_INT8); + int m_roundup = ROUNDUP(m, hblock); + int weights_size_per_group = m * k; + if (n > 1) { + weights_size_per_group = ((m_roundup * k_roundup + 15) / 16) * 16; + } + bool flag_relu = param.fuse_relu; + bool flag_bias = param.bias != nullptr; + //! 
use gemv when the output channel size = 1 + for (int b = 0; b < num; ++b) { + // dC + for (int g = 0; g < group; ++g) { + signed char* dout_group = + reinterpret_cast(o_data) + + (b * oc + g * m) * channel_size_out * PrecisionTypeLength(out_type); + const int8_t* din_group = i_data + (b * ic + g * k) * channel_size_in; + const int8_t* weights_group = weights + g * weights_size_per_group; + const int* bias_group = bias + g * m; + const float* scale_group = scale + g * m; + if (n == 1) { + if (out_type == PRECISION(kFloat)) { + gemv_int8(weights_group, + din_group, + reinterpret_cast(dout_group), + false, + m, + k, + scale_group, + flag_bias, + bias_group, + flag_relu); + } else if (out_type == PRECISION(kInt8)) { // int8 + gemv_int8(weights_group, + din_group, + dout_group, + false, + m, + k, + scale_group, + flag_bias, + bias_group, + flag_relu); + } else { + gemv_int8(weights_group, + din_group, + reinterpret_cast(dout_group), + false, + m, + k, + scale_group, + flag_bias, + bias_group, + flag_relu); + } + } else { + if (out_type == PRECISION(kFloat)) { + gemm_prepack_int8(weights_group, + din_group, + bias_group, + reinterpret_cast(dout_group), + m, + n, + k, + flag_bias, + flag_relu, + false, + scale_group, + ctx); + } else if (out_type == PRECISION(kInt8)) { // int8 + gemm_prepack_int8(weights_group, + din_group, + bias_group, + dout_group, + m, + n, + k, + flag_bias, + flag_relu, + false, + scale_group, + ctx); + } else { + gemm_prepack_int8(weights_group, + din_group, + bias_group, + reinterpret_cast(dout_group), + m, + n, + k, + flag_bias, + flag_relu, + false, + scale_group, + ctx); + } + } + } + } +} + +/** + * \brief convolution function for kernel size 3x3, stride size 2, gemm + * implementation + */ +void conv_im2col_gemm(const float* i_data, + float* o_data, + int num, + int oc, + int oh, + int ow, + int ic, + int ih, + int win, + const float* weights, + const float* bias, + const operators::ConvParam& param, + ARMContext* ctx, + const int* idx_ptr) { + const int group = param.groups; + auto filter_dims = param.filter->dims(); + const int kernel_h = filter_dims[2]; + const int kernel_w = filter_dims[3]; // nchw + const int m = oc / group; + const int n = oh * ow; + const int k = ic * kernel_h * kernel_w / group; + const int chin_per_group = ic / group; + int channel_size_out = ow * oh; + int channel_size_in = win * ih; + bool flag_relu = param.fuse_relu; + bool flag_bias = param.bias != nullptr; + // if (param.activation_param.has_active) { + // if (param.activation_param.active == Active_relu && + // fabs(param.activation_param.negative_slope) < 1e-6f) { + // flag_relu = true; + // } + // } + int hblock = get_hblock(ctx->arch()); + int m_roundup = hblock * ((m + hblock - 1) / hblock); + int weights_size_per_group = m * k; + if (n > 1) { + weights_size_per_group = ((m_roundup * k + 15) / 16) * 16; + } + + bool flag_im2col2 = (kernel_h == 3 && kernel_w == 3 && + param.strides[0] == 1 && param.strides[1] == 1 && n > 1); + + float* tmp_work_space = + ctx->workspace_data() + ctx->llc_size() / sizeof(float); + + //! 
use gemv when the output channel size = 1 + for (int b = 0; b < num; ++b) { + // dC + for (int g = 0; g < group; ++g) { + float* dout_group = o_data + (b * oc + g * m) * channel_size_out; + const float* din_group = + i_data + (b * ic + g * chin_per_group) * channel_size_in; + const float* weights_group = weights + g * weights_size_per_group; + const float* bias_group = bias + g * m; + float* dB = tmp_work_space; + + if (flag_im2col2) { + im2col3x3(din_group, + chin_per_group, + ih, + win, + kernel_h, + kernel_w, + param.paddings[0], + param.paddings[1], + param.strides[0], + param.strides[1], + param.dilations[0], + param.dilations[1], + dB, + idx_ptr); + } else { + im2col(din_group, + chin_per_group, + ih, + win, + kernel_h, + kernel_w, + param.paddings[0], + param.paddings[1], + param.strides[0], + param.strides[1], + param.dilations[0], + param.dilations[1], + dB); + } + if (n == 1) { + sgemv(weights_group, + dB, + dout_group, + false, + m, + k, + flag_bias, + bias_group, + flag_relu); + } else { + int ldb = n; + if (flag_im2col2) { + ldb = k; + } + sgemm_prepack(flag_im2col2, + m, + n, + k, + weights_group, + dB, + ldb, + 0.f, + dout_group, + n, + bias_group, + flag_bias, + flag_relu, + ctx); + } + } + } +} + +void conv_im2col_gemm_int8(const int8_t* i_data, + int32_t* o_data, + int num, + int oc, + int oh, + int ow, + int ic, + int ih, + int win, + const int8_t* weights, + const int32_t* bias, + const operators::ConvParam& param, + ARMContext* ctx, + PrecisionType out_type, + const float* scale, + const int32_t* idx_ptr) { + int group = param.groups; + auto filter_dims = param.filter->dims(); + int kernel_h = filter_dims[2]; + int kernel_w = filter_dims[3]; + int stride_h = param.strides[0]; + int stride_w = param.strides[1]; + int dila_h = param.dilations[0]; + int dila_w = param.dilations[1]; + int pad_h = param.paddings[0]; + int pad_w = param.paddings[1]; + const int m = oc / group; + const int n = oh * ow; + const int k = ic * kernel_h * kernel_w / group; + const int chin_per_group = ic / group; + int channel_size_out = ow * oh; + int channel_size_in = win * ih; + bool flag_relu = param.fuse_relu; + bool flag_bias = param.bias != nullptr; + + int hblock = get_hblock_int8(ctx); + int k_roundup = ROUNDUP(k, KBLOCK_INT8); + int m_roundup = ROUNDUP(m, hblock); + int weights_size_per_group = m * k; + if (n > 1) { + weights_size_per_group = ((m_roundup * k_roundup + 15) / 16) * 16; + } + + bool flag_im2col2 = (kernel_h == 3 && kernel_w == 3 && stride_h == 1 && + stride_w == 1 && n > 1); + + int8_t* tmp_work_space = + ctx->workspace_data() + ctx->llc_size() / sizeof(int8_t); + + //! 
use gemv when the output channel size = 1 + for (int b = 0; b < num; ++b) { + // dC + for (int g = 0; g < group; ++g) { + signed char* dout_group = + reinterpret_cast(o_data) + + (b * oc + g * m) * channel_size_out * PrecisionTypeLength(out_type); + const int8_t* din_group = static_cast(i_data) + + (b * ic + g * chin_per_group) * channel_size_in; + const int8_t* weights_group = + static_cast(weights) + g * weights_size_per_group; + const int* bias_group = static_cast(bias) + g * m; + int8_t* dB = tmp_work_space; + const float* scale_group = scale + g * m; + + if (flag_im2col2) { + im2col3x3(din_group, + chin_per_group, + ih, + win, + kernel_h, + kernel_w, + pad_h, + pad_w, + stride_h, + stride_w, + dila_h, + dila_w, + dB, + idx_ptr); + + } else { + im2col(din_group, + chin_per_group, + ih, + win, + kernel_h, + kernel_w, + pad_h, + pad_w, + stride_h, + stride_w, + dila_h, + dila_w, + dB); + } + if (n == 1) { + if (out_type == PRECISION(kFloat)) { + gemv_int8(weights_group, + dB, + reinterpret_cast(dout_group), + false, + m, + k, + scale_group, + flag_bias, + bias_group, + flag_relu); + } else if (out_type == PRECISION(kInt8)) { // int8 + gemv_int8(weights_group, + dB, + dout_group, + false, + m, + k, + scale_group, + flag_bias, + bias_group, + flag_relu); + } else { + gemv_int8(weights_group, + dB, + reinterpret_cast(dout_group), + false, + m, + k, + scale_group, + flag_bias, + bias_group, + flag_relu); + } + } else { + if (out_type == PRECISION(kFloat)) { + gemm_prepack_int8(weights_group, + dB, + bias_group, + reinterpret_cast(dout_group), + m, + n, + k, + flag_bias, + flag_relu, + flag_im2col2, + scale_group, + ctx); + } else if (out_type == PRECISION(kInt8)) { // int8 + gemm_prepack_int8(weights_group, + dB, + bias_group, + dout_group, + m, + n, + k, + flag_bias, + flag_relu, + flag_im2col2, + scale_group, + ctx); + } else { + gemm_prepack_int8(weights_group, + dB, + bias_group, + reinterpret_cast(dout_group), + m, + n, + k, + flag_bias, + flag_relu, + flag_im2col2, + scale_group, + ctx); + } + } + } + } +} + +void conv_depthwise_3x3(const float* i_data, + float* o_data, + int num, + int oc, + int oh, + int ow, + int ic, + int ih, + int win, + const float* weights, + const float* bias, + const operators::ConvParam& param, + ARMContext* ctx) { + int pad = param.paddings[1]; + int stride = param.strides[1]; + bool flag_relu = param.fuse_relu; + bool flag_bias = param.bias != nullptr; + // if (param.activation_param.has_active) { + // if (param.activation_param.active == Active_relu && + // fabs(param.activation_param.negative_slope) < 1e-6f) { + // flag_relu = true; + // } + // } + if (pad == 1) { + conv_depthwise_3x3p1(i_data, + o_data, + num, + oc, + oh, + ow, + ic, + ih, + win, + weights, + bias, + stride, + flag_bias, + flag_relu, + ctx); + } else if (pad == 0 && ih > 2) { + conv_depthwise_3x3p0(i_data, + o_data, + num, + oc, + oh, + ow, + ic, + ih, + win, + weights, + bias, + stride, + flag_bias, + flag_relu, + ctx); + } else { + LOG(FATAL) << "unsupport this type 3x3 dw conv"; + } +} + +void conv_depthwise_5x5(const float* i_data, + float* o_data, + int num, + int oc, + int oh, + int ow, + int ic, + int ih, + int win, + const float* weights, + const float* bias, + const operators::ConvParam& param, + ARMContext* ctx) { + int pad = param.paddings[1]; + int stride = param.strides[1]; + bool flag_relu = param.fuse_relu; + bool flag_bias = param.bias != nullptr; + // if (param.activation_param.has_active && + // fabs(param.activation_param.negative_slope) < 1e-6f) { + // if 
(param.activation_param.active == Active_relu) { + // flag_relu = true; + // } + // } + if (pad == 2 && stride == 2) { + conv_depthwise_5x5s2(i_data, + o_data, + num, + oc, + oh, + ow, + ic, + ih, + win, + weights, + bias, + pad, + flag_bias, + flag_relu, + ctx); + } else if (stride == 1) { + conv_depthwise_5x5s1(i_data, + o_data, + num, + oc, + oh, + ow, + ic, + ih, + win, + weights, + bias, + pad, + flag_bias, + flag_relu, + ctx); + } else { + LOG(FATAL) << "unsupport this type 5x5 dw conv"; + } +} + +} // namespace math +} // namespace arm +} // namespace lite +} // namespace paddle diff --git a/lite/arm/math/conv_impl.h b/lite/arm/math/conv_impl.h new file mode 100644 index 00000000000..38d799bb4c9 --- /dev/null +++ b/lite/arm/math/conv_impl.h @@ -0,0 +1,423 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "lite/core/context.h" +#include "lite/core/target_wrapper.h" +#include "lite/operators/op_params.h" + +namespace paddle { +namespace lite { +namespace arm { +namespace math { + +// TODO(TJ): move to somewhere else common +template +class ImplBase { + public: + ImplBase() {} + virtual ~ImplBase() {} + + virtual bool create(const Param& param, Context* ctx) { return false; } + + virtual bool init(const Param& param, Context* ctx) { return false; } + + virtual bool run(const Param& param) { return false; } + // void set_op_name(const char* name){_op_name = name;} + // const char* get_op_name() { return _op_name.c_str();} + + protected: + Param* param_; + Context* ctx_; +}; + +void conv_3x3s1_direct_fp32(const float* din, + float* dout, + int num, + int chout, + int hout, + int wout, + int chin, + int hin, + int win, + const float* weights, + const float* bias, + const operators::ConvParam& param, + Context* ctx); + +void conv_3x3s1_direct_int8(const int8_t* din, + int32_t* dout, + int num, + int chout, + int hout, + int wout, + int chin, + int hin, + int win, + const int8_t* weights, + const int32_t* bias, + const operators::ConvParam& param, + Context* ctx, + PrecisionType out_type, + const float* scale); + +void conv_3x3s1_direct_int7(const int8_t* din, + int32_t* dout, + int num, + int chout, + int hout, + int wout, + int chin, + int hin, + int win, + const int8_t* weights, + const int32_t* bias, + const operators::ConvParam& param, + Context* ctx, + PrecisionType out_type, + const float* scale); + +void conv_3x3s2_direct_fp32(const float* din, + float* dout, + int num, + int chout, + int hout, + int wout, + int chin, + int hin, + int win, + const float* weights, + const float* bias, + const operators::ConvParam& param, + Context* ctx); + +int conv_3x3s2_direct_int8_c_num(); + +void conv_3x3s2_direct_int8(const int8_t* din, + int32_t* dout, + int num, + int chout, + int hout, + int wout, + int chin, + int hin, + int win, + const int8_t* weights, + const int32_t* bias, + const operators::ConvParam& param, + Context* ctx, + PrecisionType out_type, + const float* scale); + +void 
conv_1x5s1_direct(const void* din, + void* dout, + int num, + int chout, + int hout, + int wout, + int chin, + int hin, + int win, + const void* weights, + const void* bias, + int group, + int kernel_w, + int kernel_h, + int stride_w, + int stride_h, + int dila_w, + int dila_h, + int pad_w, + int pad_h, + bool flag_bias, + bool flag_relu, + Context& ctx, + void* work_space, + const void* idx_ptr); + +void conv_5x1s1_direct(const void* din, + void* dout, + int num, + int chout, + int hout, + int wout, + int chin, + int hin, + int win, + const void* weights, + const void* bias, + int group, + int kernel_w, + int kernel_h, + int stride_w, + int stride_h, + int dila_w, + int dila_h, + int pad_w, + int pad_h, + bool flag_bias, + bool flag_relu, + Context& ctx, + void* work_space, + const void* idx_ptr); + +void conv1x1s1_gemm(const float* din, + float* dout, + int num, + int chout, + int hout, + int wout, + int chin, + int hin, + int win, + const float* weights, + const float* bias, + const operators::ConvParam& param, + Context* ctx, + const int* idx_ptr); + +void conv1x1s1_gemm_int8(const int8_t* din, + int32_t* dout, + int num, + int chout, + int hout, + int wout, + int chin, + int hin, + int win, + const int8_t* weights, + const int32_t* bias, + const operators::ConvParam& param, + Context* ctx, + PrecisionType out_type, + const float* scale, + const int32_t* idx_ptr); + +void conv_im2col_gemm(const float* din, + float* dout, + int num, + int chout, + int hout, + int wout, + int chin, + int hin, + int win, + const float* weights, + const float* bias, + const operators::ConvParam& param, + Context* ctx, + const int* idx_ptr); + +void conv_im2col_gemm_int8(const int8_t* din, + int32_t* dout, + int num, + int chout, + int hout, + int wout, + int chin, + int hin, + int win, + const int8_t* weights, + const int32_t* bias, + const operators::ConvParam& param, + Context* ctx, + PrecisionType out_type, + const float* scale, + const int32_t* idx_ptr); + +/** + * \brief depthwise convolution, kernel size 3x3, stride 1, pad 1, with bias + */ + +void conv_depthwise_3x3p0(const float* din, + float* dout, + int num, + int ch_out, + int h_out, + int w_out, + int ch_in, + int h_in, + int w_in, + const float* weights, + const float* bias, + int stride, + bool flag_bias, + bool flag_relu, + ARMContext* ctx); + +void conv_depthwise_3x3p1(const float* din, + float* dout, + int num, + int ch_out, + int h_out, + int w_out, + int ch_in, + int h_in, + int w_in, + const float* weights, + const float* bias, + int stride, + bool flag_bias, + bool flag_relu, + ARMContext* ctx); + +void conv_depthwise_5x5s1(const float* din, + float* dout, + int num, + int chout, + int hout, + int wout, + int chin, + int hin, + int win, + const float* weights, + const float* bias, + int pad, + bool flag_bias, + bool flag_relu, + ARMContext* ctx); + +void conv_depthwise_5x5s2(const float* din, + float* dout, + int num, + int chout, + int hout, + int wout, + int chin, + int hin, + int win, + const float* weights, + const float* bias, + int pad, + bool flag_bias, + bool flag_relu, + ARMContext* ctx); + +void conv_depthwise_3x3(const float* din, + float* dout, + int num, + int chout, + int hout, + int wout, + int chin, + int hin, + int win, + const float* weights, + const float* bias, + const operators::ConvParam& param, + Context* ctx); + +void conv_depthwise_3x3_int8(const int8_t* din, + int32_t* dout, + int num, + int chout, + int hout, + int wout, + int chin, + int hin, + int win, + const int8_t* weights, + const int32_t* bias, + const 
operators::ConvParam& param, + Context* ctx, + PrecisionType out_type, + const float* scale); + +void conv_depthwise_3x3_int7(const int8_t* din, + int32_t* dout, + int num, + int chout, + int hout, + int wout, + int chin, + int hin, + int win, + int8_t* weights, + const int32_t* bias, + const operators::ConvParam& param, + Context* ctx, + PrecisionType out_type, + const float* scale); + +void conv_depthwise_5x5(const float* din, + float* dout, + int num, + int chout, + int hout, + int wout, + int chin, + int hin, + int win, + const float* weights, + const float* bias, + const operators::ConvParam& param, + Context* ctx); + +void conv_depthwise_5x5_int8(const int8_t* din, + int32_t* dout, + int num, + int chout, + int hout, + int wout, + int chin, + int hin, + int win, + const int8_t* weights, + const int32_t* bias, + const operators::ConvParam& param, + Context* ctx, + PrecisionType out_type, + const float* scale); + +void conv_winograd3x3(const float* din, + float* dout, + int num, + int chout, + int hout, + int wout, + int chin, + int hin, + int win, + const float* weights, + const float* bias, + const operators::ConvParam& param, + Context* ctx); + +void winograd_transform_weights( + void* dout, const void* din, int ch_out, int ch_in, void* work_space); + +void compute_offset(int* idx_out, + int h, + int w, + int kernel_h, + int kernel_w, + int height, + int width, + int pad_h, + int pad_w, + int dilation_h, + int dilation_w); + +void fill_bias(float* tensor, const float* bias, int channel, int channel_size); + +void fill_bias_int8(int* tensor, + const int* bias, + int channel, + int channel_size); + +} // namespace math +} // namespace arm +} // namespace lite +} // namespace paddle diff --git a/lite/arm/math/conv_winograd.cc b/lite/arm/math/conv_winograd.cc new file mode 100644 index 00000000000..ac2f2aeab46 --- /dev/null +++ b/lite/arm/math/conv_winograd.cc @@ -0,0 +1,141 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/arm/math/conv_winograd.h" +#include +#include "lite/arm/math/conv_impl.h" +#include "lite/arm/math/packed_sgemm.h" + +namespace paddle { +namespace lite { +namespace arm { +namespace math { + +template <> +bool WinogradConv::create(const operators::ConvParam& param, + ARMContext* ctx) { + this->ctx_ = ctx; + auto x_dims = param.x->dims(); + auto w_dims = param.filter->dims(); + auto o_dims = param.output->dims(); + + int iw = x_dims[3]; // nchw + int ic = x_dims[1]; + int ow = o_dims[3]; + int oh = o_dims[2]; + int oc = o_dims[1]; + int kw = w_dims[3]; + int sw = param.strides[1]; + if (kw == 3) { + is_weights_transed_ = true; + int tile_w = (ow + 5) / 6; + int tile_h = (oh + 5) / 6; + int size_tile = tile_h * tile_w; + int size_trans_channel = 8 * 8 * size_tile; + int max_ch = ic > oc ? 
ic : oc;
+
+    const int m_wino = oc;
+    const int n_wino = size_tile;
+    int hblock = get_hblock(this->ctx_->arch());
+    int m_round = hblock * ((m_wino + hblock - 1) / hblock);
+    weights_trans_.Resize({1, 1, 1, 8 * 8 * m_round * ic});
+    this->ctx_->ExtendWorkspace((size_trans_channel * max_ch * 2 + n_wino) *
+                                sizeof(float));
+    auto weights_wino =
+        static_cast<float*>(malloc(sizeof(float) * 8 * 8 * oc * ic));
+    void* trans_tmp_ptr = malloc(sizeof(float) * 8 * 8 * oc * ic);
+    if (weights_wino && trans_tmp_ptr) {
+      winograd_transform_weights(
+          weights_wino, param.filter->data<float>(), oc, ic, trans_tmp_ptr);
+      auto weights_trans = weights_trans_.mutable_data<float>();
+      for (int i = 0; i < 64; ++i) {
+        float* packed_weights = weights_trans + i * m_round * ic;
+        const float* weights_wino_ptr = weights_wino + i * oc * ic;
+        prepackA(packed_weights,
+                 weights_wino_ptr,
+                 1.f,
+                 ic,
+                 0,
+                 m_wino,
+                 0,
+                 ic,
+                 false,
+                 this->ctx_);
+      }
+      impl_ = conv_winograd3x3;
+      free(trans_tmp_ptr);
+      free(weights_wino);
+      return true;
+    }
+    free(trans_tmp_ptr);
+    free(weights_wino);
+  } else {
+    LOG(ERROR) << "this type of winograd conv is not implemented";
+  }
+  return false;
+}
+
+template <>
+bool WinogradConv<PRECISION(kFloat)>::init(const operators::ConvParam& param,
+                                           ARMContext* ctx) {
+  this->ctx_ = ctx;
+  return create(param, ctx);
+}
+
+template <>
+bool WinogradConv<PRECISION(kFloat)>::run(const operators::ConvParam& param) {
+  // start timer
+  const auto* i_data = param.x->data<float>();
+  const auto* w_data = param.filter->data<float>();
+  const auto* b_data = param.bias ? param.bias->data<float>() : nullptr;
+  auto* o_data = param.output->mutable_data<float>();
+
+  if (is_weights_transed_) {
+    w_data = weights_trans_.data<float>();
+  }
+
+  auto x_dims = param.x->dims();
+  auto w_dims = param.filter->dims();
+  auto o_dims = param.output->dims();
+
+  int iw = x_dims[3];  // nchw
+  int ih = x_dims[2];
+  int ic = x_dims[1];
+  int bs = x_dims[0];
+  int oh = o_dims[2];
+  int ow = o_dims[3];
+  int oc = o_dims[1];
+
+  impl_(i_data,
+        o_data,
+        bs,
+        oc,
+        oh,
+        ow,
+        ic,
+        ih,
+        iw,
+        w_data,
+        b_data,
+        param,
+        this->ctx_);
+
+  // timer end
+  return true;
+}
+
+}  // namespace math
+}  // namespace arm
+}  // namespace lite
+}  // namespace paddle
diff --git a/lite/arm/math/conv_winograd.h b/lite/arm/math/conv_winograd.h
new file mode 100644
index 00000000000..6533727a939
--- /dev/null
+++ b/lite/arm/math/conv_winograd.h
@@ -0,0 +1,65 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
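+
+// Tiling note (a sketch derived from the create() logic above): this header
+// declares the F(6x6, 3x3) Winograd path, where every 8x8 input tile yields
+// a 6x6 output tile, so the output is covered by
+//
+//   int tile_w = (wout + 5) / 6;
+//   int tile_h = (hout + 5) / 6;
+//
+// tiles, each holding 8 * 8 = 64 transformed elements per channel. For an
+// illustrative 112x112 output this gives tile_w = tile_h = 19, i.e. 361 tiles.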
+ +#pragma once + +#include +#include "lite/arm/math/conv_impl.h" +#include "lite/core/context.h" +#include "lite/core/target_wrapper.h" + +namespace paddle { +namespace lite { +namespace arm { +namespace math { + +template +class WinogradConv + : public ImplBase { + public: + typedef void (*conv_winograd_impl)(const float* din, + float* dout, + int num, + int chout, + int hout, + int wout, + int chin, + int hin, + int win, + const float* weights, + const float* bias, + const operators::ConvParam& param, + Context* ctx); + + WinogradConv() = default; + ~WinogradConv() {} + + virtual bool init(const operators::ConvParam& param, + Context* ctx); + + virtual bool create(const operators::ConvParam& param, + Context* ctx); + + virtual bool run(const operators::ConvParam& param); + + private: + conv_winograd_impl impl_{nullptr}; + bool is_weights_transed_{false}; + Tensor weights_trans_; +}; + +} // namespace math +} // namespace arm +} // namespace lite +} // namespace paddle diff --git a/lite/arm/math/conv_winograd_3x3.cc b/lite/arm/math/conv_winograd_3x3.cc new file mode 100644 index 00000000000..30b029b42a2 --- /dev/null +++ b/lite/arm/math/conv_winograd_3x3.cc @@ -0,0 +1,479 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/arm/math/conv_impl.h" +#include "lite/arm/math/packed_sgemm.h" + +namespace paddle { +namespace lite { +namespace arm { +namespace math { + +void transpose(float* data_out, const float* data_in, int w_in, int h_in); +void transform_input_f6x6(float* dout, const float* din); +void transform_output_f6x6(float* output, const float* din, float bias); +void conv_winograd3x3(const float* din, + float* dout, + int num, + int chout, + int hout, + int wout, + int chin, + int hin, + int win, + const float* weights, + const float* bias, + const operators::ConvParam& param, + ARMContext* ctx) { + int threads = ctx->threads(); + + const int pad_h = param.paddings[0]; + const int pad_w = param.paddings[1]; + int size_in_channel = win * hin; + int size_out_channel = wout * hout; + bool flag_relu = param.fuse_relu; + bool flag_bias = param.bias != nullptr; + + //! transform input + int tile_w = (wout + 5) / 6; + int tile_h = (hout + 5) / 6; + int size_tile = tile_h * tile_w; + int size_trans_channel = 8 * 8 * size_tile; + int max_ch = chin > chout ? chin : chout; + + int m = chout; + int n = size_tile; + int k = chin; + + float* tmp_work_space = + ctx->workspace_data() + ctx->llc_size() / sizeof(float); + + //! tmp data buffer for input transform + float* tmp_data1 = tmp_work_space; + //! tmp data buffer for dot mul + float* tmp_data2 = tmp_data1 + size_trans_channel * max_ch; + + for (int i = 0; i < num; ++i) { + const float* din_batch = din + i * chin * size_in_channel; + float* dout_batch = dout + i * chout * size_out_channel; + +//! 
transform input Bt * data * B +#pragma omp parallel for num_threads(threads) + for (int j = 0; j < chin; ++j) { + const float* din_channel = din_batch + j * size_in_channel; + float* data_trans_channel = tmp_data1 + j * size_trans_channel; + + for (int h = 0; h < tile_h; h++) { + for (int w = 0; w < tile_w; w++) { + //! prepare data 8x8 + //! row 8 + float data_in_tmp[8][8] = {0.f}; + // memset(data_in_tmp[0], 0, sizeof(float) * 64); + for (int j = 0; j < 8; ++j) { + int start_row = h * 6 + j - pad_h; + if (start_row >= 0 && start_row < hin) { + for (int k = 0; k < 8; ++k) { + int start_col = w * 6 + k - pad_w; + if (start_col >= 0 && start_col < win) { + data_in_tmp[j][k] = din_channel[start_row * win + start_col]; + } + } + } + } + transform_input_f6x6(data_trans_channel, data_in_tmp[0]); + data_trans_channel += 64; + } + } + } + //! end of transform input + + //////////////////////////////////////////////////////////////////////////////// + //! dot mul + //! transpose input, convert from ch_in * tile_h * tile_w * 64 to + //! 64 * ch_in * tile_h * tile_w + int hblock = get_hblock(ctx->arch()); + int m_round = hblock * ((chout + hblock - 1) / hblock); + int stride_a = m_round * chin; + int stride_b = chin * size_tile; + int stride_c = chout * size_tile; + transpose(tmp_data2, tmp_data1, 64, stride_b); + + //! gemm + // #pragma omp parallel for + for (int l = 0; l < 64; ++l) { + const float* ptr_a = weights + l * stride_a; + const float* ptr_b = tmp_data2 + l * stride_b; + float* ptr_c = tmp_data1 + l * stride_c; + sgemm_prepack(false, + chout, + size_tile, + chin, + ptr_a, + ptr_b, + size_tile, + 0.f, + ptr_c, + size_tile, + nullptr, + false, + false, + ctx); + } + + //! transpose output, convert from 64 * ch_out * tile_h * tile_w to + //! ch_out * tile_h * tile_w * 64 + transpose(tmp_data2, tmp_data1, stride_c, 64); +//! end of dot mul + +/////////////////////////////////////////////////////////////////////////////// +//! transform output +#pragma omp parallel for + for (int i = 0; i < chout; ++i) { + float bias_value = flag_bias ? bias[i] : 0.f; + float* dout_tmp = tmp_data2 + i * size_trans_channel; + float* dout_channel = dout_batch + i * size_out_channel; + + for (int h = 0; h < tile_h; ++h) { + for (int w = 0; w < tile_w; ++w) { + float out_tmp[6][6]; + + transform_output_f6x6(out_tmp[0], dout_tmp, bias_value); + dout_tmp += 64; + + for (int j = 0; j < 6; ++j) { + int end_row = h * 6 + j; + if (end_row < hout) { + for (int k = 0; k < 6; ++k) { + int end_col = w * 6 + k; + if (end_col < wout) { + if (flag_relu) { + dout_channel[end_row * wout + end_col] = + out_tmp[j][k] > 0.f ? out_tmp[j][k] : 0.f; + } else { + dout_channel[end_row * wout + end_col] = out_tmp[j][k]; + } + } + } + } + } + } + } + } + //! 
end of transform output + } +} + +/** + * \brief transpose with arm neon optimization + * @param data_out + * @param data_in + * @param w_in + * @param h_in + */ +void transpose(float* data_out, const float* data_in, int w_in, int h_in) { + int nw = w_in >> 2; + int nh = h_in >> 2; + int size_in = w_in * h_in; + + float* ptr_out = data_out; + const float* ptr_in = data_in; +#pragma omp parallel for + for (int h = 0; h < nh; h++) { + const float* ptr_din_row = ptr_in + h * 4 * w_in; + for (int w = 0; w < nw; w++) { + float* data_out_ptr = ptr_out + w * 4 * h_in + h * 4; + const float* din0 = ptr_din_row; + const float* din1 = din0 + w_in; + const float* din2 = din1 + w_in; + const float* din3 = din2 + w_in; + + float* dout0 = data_out_ptr; + float* dout1 = dout0 + h_in; + float* dout2 = dout1 + h_in; + float* dout3 = dout2 + h_in; +#ifdef __aarch64__ + asm("ldr q0, [%[in0]] \n" /*load input 0*/ + "ldr q1, [%[in1]] \n" + "ldr q2, [%[in2]] \n" + "ldr q3, [%[in3]] \n" + "trn1 v4.4s, v0.4s, v1.4s \n" + "trn2 v5.4s, v0.4s, v1.4s \n" + "trn1 v6.4s, v2.4s, v3.4s \n" + "trn2 v7.4s, v2.4s, v3.4s \n" + "trn1 v8.2d, v4.2d, v6.2d \n" + "trn1 v9.2d, v5.2d, v7.2d \n" + "trn2 v10.2d, v4.2d, v6.2d \n" + "trn2 v11.2d, v5.2d, v7.2d \n" + "str q8, [%[out0]] \n" + "str q9, [%[out1]] \n" + "str q10, [%[out2]] \n" + "str q11, [%[out3]] \n" + : + : [out0] "r"(dout0), + [out1] "r"(dout1), + [out2] "r"(dout2), + [out3] "r"(dout3), + [in0] "r"(din0), + [in1] "r"(din1), + [in2] "r"(din2), + [in3] "r"(din3) + : "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v11"); +#else + asm("vld1.32 {d0, d1}, [%[in0]] \n" + "vld1.32 {d2, d3}, [%[in1]] \n" + "vld1.32 {d4, d5}, [%[in2]] \n" + "vld1.32 {d6, d7}, [%[in3]] \n" + "vtrn.32 q0, q1 \n" + "vtrn.32 q2, q3 \n" + "vswp d1, d4 \n" + "vswp d3, d6 \n" + "vst1.32 {d0, d1}, [%[out0]] \n" + "vst1.32 {d2, d3}, [%[out1]] \n" + "vst1.32 {d4, d5}, [%[out2]] \n" + "vst1.32 {d6, d7}, [%[out3]] \n" + : + : [out0] "r"(dout0), + [out1] "r"(dout1), + [out2] "r"(dout2), + [out3] "r"(dout3), + [in0] "r"(din0), + [in1] "r"(din1), + [in2] "r"(din2), + [in3] "r"(din3) + : "q0", "q1", "q2", "q3"); +#endif + ptr_din_row += 4; + } + } + // remian + for (int h = 0; h < h_in; h++) { + for (int w = nw * 4; w < w_in; w++) { + const float* data_in_ptr = ptr_in + h * w_in + w; + float* data_out_ptr = ptr_out + w * h_in + h; + *data_out_ptr = *data_in_ptr; + } + } + for (int w = 0; w < w_in; w++) { + for (int h = nh * 4; h < h_in; h++) { + const float* data_in_ptr = ptr_in + h * w_in + w; + float* data_out_ptr = ptr_out + w * h_in + h; + *data_out_ptr = *data_in_ptr; + } + } +} + +/** + * \brief winograd transform conv3x3 weights, f63 + * this is done in op initialization or creation, only do once + * dout = G * g * GT, where G is the transform coeff, g is the input weights + * @param dout + * @param din + * @param ch_out + * @param ch_in + * @param work_space + */ +void winograd_transform_weights( + void* dout, const void* din, int ch_out, int ch_in, void* work_space) { + const float coeff[8][3] = {{1.0f, 0.0f, 0.0f}, + {-2.0f / 9, -2.0f / 9, -2.0f / 9}, + {-2.0f / 9, 2.0f / 9, -2.0f / 9}, + {1.0f / 90, 1.0f / 45, 2.0f / 45}, + {1.0f / 90, -1.0f / 45, 2.0f / 45}, + {32.0f / 45, 16.0f / 45, 8.0f / 45}, + {32.0f / 45, -16.0f / 45, 8.0f / 45}, + {0.0f, 0.0f, 1.0f}}; + + float* ptr_out = static_cast(work_space); + + for (int i = 0; i < ch_out; i++) { + for (int j = 0; j < ch_in; j++) { + const float* kernel0 = + static_cast(din) + (i * ch_in + j) * 9; + float* 
ptr_channel = ptr_out + (i * ch_in + j) * 64; + + //! transform kernel, transposed + const float* k0 = kernel0; + const float* k1 = kernel0 + 3; + const float* k2 = kernel0 + 6; + + //! h + float tmp[8][3]; + for (int i = 0; i < 8; i++) { + tmp[i][0] = + k0[0] * coeff[i][0] + k0[1] * coeff[i][1] + k0[2] * coeff[i][2]; + tmp[i][1] = + k1[0] * coeff[i][0] + k1[1] * coeff[i][1] + k1[2] * coeff[i][2]; + tmp[i][2] = + k2[0] * coeff[i][0] + k2[1] * coeff[i][1] + k2[2] * coeff[i][2]; + } + + //! v + for (int j = 0; j < 8; j++) { + float* tmpp = &tmp[j][0]; + for (int i = 0; i < 8; i++) { + ptr_channel[j * 8 + i] = tmpp[0] * coeff[i][0] + + tmpp[1] * coeff[i][1] + + tmpp[2] * coeff[i][2]; + } + } + } + } + transpose(static_cast(dout), ptr_out, 64, ch_out * ch_in); +} + +/** + * \brief winograd conv, transform input, f6x3 + * dout = BT * d * B, whrer B is the transform + * BT = 1 0 -21/4 0 21/4 0 -1 0 + * 0 1 1 -17/4 -17/4 1 1 0 + * 0 -1 1 17/4 -17/4 -1 1 0 + * 0 1/2 1/4 -5/2 -5/4 2 1 0 + * 0 -1/2 1/4 5/2 -5/4 -2 1 0 + * 0 2 4 -5/2 -5 1/2 1 0 + * 0 -2 4 5/2 -5 -1/2 1 0 + * 0 -1 0 21/4 0 -21/4 0 1 + * @param dout + * @param din + */ +void transform_input_f6x6(float* dout, const float* din) { + float tmp[8][8]; + //! BT * d + for (int m = 0; m < 8; m++) { + tmp[0][m] = din[0] - din[6] + (din[4] - din[2]) * 5.25f; + tmp[7][m] = din[7] - din[1] + (din[3] - din[5]) * 5.25f; + + float tmp12a = din[2] + din[6] - din[4] * 4.25f; + float tmp12b = din[1] + din[5] - din[3] * 4.25f; + + tmp[1][m] = tmp12a + tmp12b; + tmp[2][m] = tmp12a - tmp12b; + + float tmp34a = din[6] + din[2] * 0.25f - din[4] * 1.25f; + float tmp34b = din[1] * 0.5f - din[3] * 2.5f + din[5] * 2.f; + + tmp[3][m] = tmp34a + tmp34b; + tmp[4][m] = tmp34a - tmp34b; + + float tmp56a = din[6] + (din[2] - din[4] * 1.25f) * 4.f; + float tmp56b = din[1] * 2.f - din[3] * 2.5f + din[5] * 0.5f; + + tmp[5][m] = tmp56a + tmp56b; + tmp[6][m] = tmp56a - tmp56b; + + din += 8; + } + + for (int m = 0; m < 8; m++) { + const float* tmp0 = tmp[m]; + + dout[0] = tmp0[0] - tmp0[6] + (tmp0[4] - tmp0[2]) * 5.25f; + dout[7] = tmp0[7] - tmp0[1] + (tmp0[3] - tmp0[5]) * 5.25f; + + float tmp12a = tmp0[2] + tmp0[6] - tmp0[4] * 4.25f; + float tmp12b = tmp0[1] + tmp0[5] - tmp0[3] * 4.25f; + + dout[1] = tmp12a + tmp12b; + dout[2] = tmp12a - tmp12b; + + float tmp34a = tmp0[6] + tmp0[2] * 0.25f - tmp0[4] * 1.25f; + float tmp34b = tmp0[1] * 0.5f - tmp0[3] * 2.5f + tmp0[5] * 2.f; + + dout[3] = tmp34a + tmp34b; + dout[4] = tmp34a - tmp34b; + + float tmp56a = tmp0[6] + (tmp0[2] - tmp0[4] * 1.25f) * 4.f; + float tmp56b = tmp0[1] * 2.f - tmp0[3] * 2.5f + tmp0[5] * 0.5f; + + dout[5] = tmp56a + tmp56b; + dout[6] = tmp56a - tmp56b; + + dout += 8; + } +} + +/** + * \brief winograd conv, transform output, f63 + * out = AT * din * A + * AT = 1 1 1 1 1 1 1 0 + * 0 1 -1 2 -2 1/2 -1/2 0 + * 0 1 1 4 4 1/4 1/4 0 + * 0 1 -1 8 -8 1/8 -1/8 0 + * 0 1 1 16 16 1/16 1/16 0 + * 0 1 -1 32 -32 1/32 -1/32 1 + * @param output + * @param din + * @param bias + */ +void transform_output_f6x6(float* output, const float* din, float bias) { + float tmp[6][8]; + for (int m = 0; m < 8; m++) { + float tmp024a = din[1] + din[2]; + float tmp135a = din[1] - din[2]; + + float tmp024b = din[3] + din[4]; + float tmp135b = din[3] - din[4]; + + float tmp024c = din[5] + din[6]; + float tmp135c = din[5] - din[6]; + + tmp[0][m] = din[0] + tmp024a + tmp024b + tmp024c; + tmp[2][m] = tmp024a + tmp024b * 4 + tmp024c * 0.25f; + tmp[4][m] = tmp024a + tmp024b * 16 + tmp024c * 0.0625f; + + tmp[1][m] = tmp135a + tmp135b * 2 + tmp135c 
* 0.5f; + tmp[3][m] = tmp135a + tmp135b * 8 + tmp135c * 0.125f; + tmp[5][m] = din[7] + tmp135a + tmp135b * 32 + tmp135c * 0.03125f; + + din += 8; + } + + for (int m = 0; m < 6; m++) { + const float* tmp0 = tmp[m]; + + float tmp024a = tmp0[1] + tmp0[2]; + float tmp135a = tmp0[1] - tmp0[2]; + + float tmp024b = tmp0[3] + tmp0[4]; + float tmp135b = tmp0[3] - tmp0[4]; + + float tmp024c = tmp0[5] + tmp0[6]; + float tmp135c = tmp0[5] - tmp0[6]; + + output[0] = bias + tmp0[0] + tmp024a + tmp024b + tmp024c; + output[2] = bias + tmp024a + tmp024b * 4 + tmp024c * 0.25f; + output[4] = bias + tmp024a + tmp024b * 16 + tmp024c * 0.0625f; + + output[1] = bias + tmp135a + tmp135b * 2 + tmp135c * 0.5f; + output[3] = bias + tmp135a + tmp135b * 8 + tmp135c * 0.125f; + output[5] = bias + tmp0[7] + tmp135a + tmp135b * 32 + tmp135c * 0.03125f; + + output += 6; + } +} + +} // namespace math +} // namespace arm +} // namespace lite +} // namespace paddle diff --git a/lite/arm/math/decode_bboxes.cc b/lite/arm/math/decode_bboxes.cc new file mode 100644 index 00000000000..2c6c7f186e6 --- /dev/null +++ b/lite/arm/math/decode_bboxes.cc @@ -0,0 +1,651 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/arm/math/decode_bboxes.h" +#include "lite/arm/math/funcs.h" + +namespace paddle { +namespace lite { +namespace arm { +namespace math { + +template +void decode_bbox_corner_variance_kernel(const int batch_num, + const T* loc_data, + const T* prior_data, + const T* variance, + const int num_priors, + const bool share_location, + const int num_loc_classes, + const int background_label_id, + T* bbox_data); + +template +void decode_bbox_corner_no_variance_kernel(const int batch_num, + const T* loc_data, + const T* prior_data, + const T* variance, + const int num_priors, + const bool share_location, + const int num_loc_classes, + const int background_label_id, + T* bbox_data); + +template +void decode_bbox_center_variance_kernel(const int batch_num, + const T* loc_data, + const T* prior_data, + const T* variance, + const int num_priors, + const bool share_location, + const int num_loc_classes, + const int background_label_id, + T* bbox_data); + +template +void decode_bbox_center_no_variance_kernel(const int batch_num, + const float* loc_data, + const float* prior_data, + const float* variance, + const int num_priors, + const bool share_location, + const int num_loc_classes, + const int background_label_id, + float* bbox_data); + +template +void decode_bbox_corner_size_variance_kernel(const int batch_num, + const T* loc_data, + const T* prior_data, + const T* variance, + const int num_priors, + const bool share_location, + const int num_loc_classes, + const int background_label_id, + T* bbox_data); + +template +void decode_bbox_corner_size_no_variance_kernel(const int batch_num, + const T* loc_data, + const T* prior_data, + const T* variance, + const int num_priors, + const bool share_location, + const int num_loc_classes, + const int 
background_label_id, + T* bbox_data); + +template <> +void decode_bbox_corner_variance_kernel(const int batch_num, + const float* loc_data, + const float* prior_data, + const float* variance, + const int num_priors, + const bool share_location, + const int num_loc_classes, + const int background_label_id, + float* bbox_data) { + if (!share_location) { + CHECK_EQ(share_location, true) + << "ERROR: decode boxes without share_location is unimplemented\n"; + return; + } + + int cnt = num_priors / 4; + int len_batch = num_priors * 4; + + for (int n = 0; n < batch_num; ++n) { + const float* ptr_loc_batch = loc_data + n * len_batch; + float* ptr_bbox_batch = bbox_data + n * len_batch; +#pragma omp parallel for + for (int i = 0; i < cnt; ++i) { + int idx = i * 16; + const float* ptr_loc = ptr_loc_batch + idx; + const float* ptr_prior = prior_data + idx; + float* ptr_bbox = ptr_bbox_batch + idx; + + float32x4_t vloc1 = vld1q_f32(ptr_loc); + float32x4_t vloc2 = vld1q_f32(ptr_loc + 4); + float32x4_t vloc3 = vld1q_f32(ptr_loc + 8); + float32x4_t vloc4 = vld1q_f32(ptr_loc + 12); + + float32x4_t vprior1 = vld1q_f32(ptr_prior); + float32x4_t vprior2 = vld1q_f32(ptr_prior + 4); + float32x4_t vprior3 = vld1q_f32(ptr_prior + 8); + float32x4_t vprior4 = vld1q_f32(ptr_prior + 12); + + vst1q_f32(ptr_bbox, vaddq_f32(vloc1, vprior1)); + vst1q_f32(ptr_bbox + 4, vaddq_f32(vloc2, vprior2)); + vst1q_f32(ptr_bbox + 8, vaddq_f32(vloc3, vprior3)); + vst1q_f32(ptr_bbox + 12, vaddq_f32(vloc4, vprior4)); + } +#pragma omp parallel for + for (int i = cnt * 4; i < num_priors; i++) { + int idx = i * 4; + float32x4_t vloc = vld1q_f32(ptr_loc_batch + idx); + float32x4_t vprior = vld1q_f32(prior_data + idx); + vst1q_f32(ptr_bbox_batch + idx, vaddq_f32(vloc, vprior)); + } + } +} + +template <> +void decode_bbox_corner_no_variance_kernel(const int batch_num, + const float* loc_data, + const float* prior_data, + const float* variance, + const int num_priors, + const bool share_location, + const int num_loc_classes, + const int background_label_id, + float* bbox_data) { + if (!share_location) { + CHECK_EQ(share_location, true) + << "ERROR: decode boxes without share_location is unimplemented\n"; + return; + } + + int cnt = num_priors / 4; + int len_batch = num_priors * 4; + + for (int n = 0; n < batch_num; ++n) { + const float* ptr_loc_batch = loc_data + n * len_batch; + float* ptr_bbox_batch = bbox_data + n * len_batch; + +#pragma omp parallel for + for (int i = 0; i < cnt; ++i) { + int idx = i * 16; + const float* ptr_loc = ptr_loc_batch + idx; + const float* ptr_prior = prior_data + idx; + const float* ptr_var = variance + idx; + float* ptr_bbox = ptr_bbox_batch + idx; + + float32x4_t vloc1 = vld1q_f32(ptr_loc); + float32x4_t vprior1 = vld1q_f32(ptr_prior); + float32x4_t vvar1 = vld1q_f32(ptr_var); + float32x4_t vout1 = vmulq_f32(vloc1, vvar1); + + float32x4_t vloc2 = vld1q_f32(ptr_loc + 4); + float32x4_t vprior2 = vld1q_f32(ptr_prior + 4); + float32x4_t vvar2 = vld1q_f32(ptr_var + 4); + float32x4_t vout2 = vmulq_f32(vloc2, vvar2); + + float32x4_t vloc3 = vld1q_f32(ptr_loc + 8); + float32x4_t vprior3 = vld1q_f32(ptr_prior + 8); + float32x4_t vvar3 = vld1q_f32(ptr_var + 8); + float32x4_t vout3 = vmulq_f32(vloc3, vvar3); + + float32x4_t vloc4 = vld1q_f32(ptr_loc + 12); + float32x4_t vprior4 = vld1q_f32(ptr_prior + 12); + float32x4_t vvar4 = vld1q_f32(ptr_var + 12); + float32x4_t vout4 = vmulq_f32(vloc4, vvar4); + + vst1q_f32(ptr_bbox, vaddq_f32(vout1, vprior1)); + vst1q_f32(ptr_bbox + 4, vaddq_f32(vout2, vprior2)); + 
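+      // The remaining two quads of this 16-wide block follow the same
+      // corner decode as above: bbox = loc * variance + prior, element-wise.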
vst1q_f32(ptr_bbox + 8, vaddq_f32(vout3, vprior3)); + vst1q_f32(ptr_bbox + 12, vaddq_f32(vout4, vprior4)); + } + + for (int i = cnt * 4; i < num_priors; i++) { + int idx = i * 4; + float32x4_t vloc = vld1q_f32(ptr_loc_batch + idx); + float32x4_t vprior = vld1q_f32(prior_data + idx); + float32x4_t vvar = vld1q_f32(variance + idx); + float32x4_t vout = vmulq_f32(vloc, vvar); + vst1q_f32(ptr_bbox_batch + idx, vaddq_f32(vout, vprior)); + } + } +} + +template <> +void decode_bbox_center_variance_kernel(const int batch_num, + const float* loc_data, + const float* prior_data, + const float* variance, + const int num_priors, + const bool share_location, + const int num_loc_classes, + const int background_label_id, + float* bbox_data) { + if (!share_location) { + CHECK_EQ(share_location, true) + << "ERROR: decode boxes without share_location is unimplemented\n"; + return; + } + + int cnt = num_priors / 4; + //! vprior 0: xmin, 1: ymin, 2: xmax, 3: ymax + //! vloc 0: xmin, 1: ymin, 2: xmax, 3: ymax + //! vvar + float32x4_t vhalf = vdupq_n_f32(0.5f); + + int len_batch = num_priors * 4; + + for (int n = 0; n < batch_num; ++n) { + const float* ptr_loc_batch = loc_data + n * len_batch; + float* ptr_bbox_batch = bbox_data + n * len_batch; + +#pragma omp parallel for + for (int i = 0; i < cnt; ++i) { + int idx = i * 16; + const float* ptr_loc = ptr_loc_batch + idx; + const float* ptr_prior = prior_data + idx; + float* ptr_bbox = ptr_bbox_batch + idx; + + float32x4x4_t vprior = vld4q_f32(ptr_prior); + float32x4x4_t vloc = vld4q_f32(ptr_loc); + float32x4_t vprior_width = vsubq_f32(vprior.val[2], vprior.val[0]); + float32x4_t vprior_height = vsubq_f32(vprior.val[3], vprior.val[1]); + float32x4_t vprior_cx = + vmulq_f32(vaddq_f32(vprior.val[0], vprior.val[2]), vhalf); + float32x4_t vprior_cy = + vmulq_f32(vaddq_f32(vprior.val[1], vprior.val[3]), vhalf); + + float32x4_t vdec_bbx_cx = + vaddq_f32(vmulq_f32(vloc.val[0], vprior_width), vprior_cx); + float32x4_t vdec_bbx_cy = + vaddq_f32(vmulq_f32(vloc.val[1], vprior_height), vprior_cy); + float32x4_t vdec_bbx_w = exp_ps(vloc.val[2]); + float32x4_t vdec_bbx_h = exp_ps(vloc.val[3]); + vprior_width = vmulq_f32(vprior_width, vhalf); + vprior_height = vmulq_f32(vprior_height, vhalf); + vdec_bbx_w = vmulq_f32(vdec_bbx_w, vprior_width); + vdec_bbx_h = vmulq_f32(vdec_bbx_h, vprior_height); + + vloc.val[0] = vsubq_f32(vdec_bbx_cx, vdec_bbx_w); + vloc.val[1] = vsubq_f32(vdec_bbx_cy, vdec_bbx_h); + vloc.val[2] = vaddq_f32(vdec_bbx_cx, vdec_bbx_w); + vloc.val[3] = vaddq_f32(vdec_bbx_cy, vdec_bbx_h); + + vst4q_f32(ptr_bbox, vloc); + } +#pragma omp parallel for + for (int i = cnt * 4; i < num_priors; i++) { + int idx = i * 4; + float p_xmin = prior_data[idx]; + float p_ymin = prior_data[idx + 1]; + float p_xmax = prior_data[idx + 2]; + float p_ymax = prior_data[idx + 3]; + float prior_width = p_xmax - p_xmin; + float prior_height = p_ymax - p_ymin; + float prior_center_x = (p_xmin + p_xmax) / 2.f; + float prior_center_y = (p_ymin + p_ymax) / 2.f; + + float xmin = ptr_loc_batch[idx]; + float ymin = ptr_loc_batch[idx + 1]; + float xmax = ptr_loc_batch[idx + 2]; + float ymax = ptr_loc_batch[idx + 3]; + + //! variance is encoded in target, we simply need to restore the offset +//! predictions.
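+      // In symbols, with pw/ph/pcx/pcy the prior's width, height and center:
+      //   cx = loc_x * pw + pcx          cy = loc_y * ph + pcy
+      //   w  = exp(loc_w) * pw           h  = exp(loc_h) * ph
+      // and the stored corner box is (cx - w/2, cy - h/2, cx + w/2, cy + h/2).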
+ float decode_bbox_center_x = xmin * prior_width + prior_center_x; + float decode_bbox_center_y = ymin * prior_height + prior_center_y; + float decode_bbox_width = expf(xmax) * prior_width; + float decode_bbox_height = expf(ymax) * prior_height; + + ptr_bbox_batch[idx] = decode_bbox_center_x - decode_bbox_width / 2.f; + ptr_bbox_batch[idx + 1] = decode_bbox_center_y - decode_bbox_height / 2.f; + ptr_bbox_batch[idx + 2] = decode_bbox_center_x + decode_bbox_width / 2.f; + ptr_bbox_batch[idx + 3] = decode_bbox_center_y + decode_bbox_height / 2.f; + } + } +} + +template <> +void decode_bbox_center_no_variance_kernel(const int batch_num, + const float* loc_data, + const float* prior_data, + const float* variance, + const int num_priors, + const bool share_location, + const int num_loc_classes, + const int background_label_id, + float* bbox_data) { + if (!share_location) { + CHECK_EQ(share_location, true) + << "ERROR: decode boxes without share_location is unimplemented\n"; + return; + } + + int cnt = num_priors / 4; + //! vprior 0: xmin, 1: ymin, 2: xmax, 3: ymax + //! vloc 0: xmin, 1: ymin, 2: xmax, 3: ymax + //! vvar + float32x4_t vhalf = vdupq_n_f32(0.5f); + + int len_batch = num_priors * 4; + + for (int n = 0; n < batch_num; ++n) { + const float* ptr_loc_batch = loc_data + n * len_batch; + float* ptr_bbox_batch = bbox_data + n * len_batch; + +#pragma omp parallel for + for (int i = 0; i < cnt; ++i) { + int idx = i * 16; + + const float* ptr_loc = ptr_loc_batch + idx; + const float* ptr_prior = prior_data + idx; + const float* ptr_var = variance + idx; + float* ptr_bbox = ptr_bbox_batch + idx; + + float32x4x4_t vprior = vld4q_f32(ptr_prior); + float32x4x4_t vloc = vld4q_f32(ptr_loc); + float32x4x4_t vvar = vld4q_f32(ptr_var); + float32x4_t vprior_width = vsubq_f32(vprior.val[2], vprior.val[0]); + float32x4_t vprior_height = vsubq_f32(vprior.val[3], vprior.val[1]); + float32x4_t vprior_cx = + vmulq_f32(vaddq_f32(vprior.val[0], vprior.val[2]), vhalf); + float32x4_t vprior_cy = + vmulq_f32(vaddq_f32(vprior.val[1], vprior.val[3]), vhalf); + + vloc.val[0] = vmulq_f32(vloc.val[0], vvar.val[0]); + vloc.val[1] = vmulq_f32(vloc.val[1], vvar.val[1]); + vloc.val[2] = vmulq_f32(vloc.val[2], vvar.val[2]); + vloc.val[3] = vmulq_f32(vloc.val[3], vvar.val[3]); + + float32x4_t vdec_bbx_cx = + vaddq_f32(vmulq_f32(vloc.val[0], vprior_width), vprior_cx); + float32x4_t vdec_bbx_cy = + vaddq_f32(vmulq_f32(vloc.val[1], vprior_height), vprior_cy); + float32x4_t vdec_bbx_w = exp_ps(vloc.val[2]); + float32x4_t vdec_bbx_h = exp_ps(vloc.val[3]); + vprior_width = vmulq_f32(vprior_width, vhalf); + vprior_height = vmulq_f32(vprior_height, vhalf); + vdec_bbx_w = vmulq_f32(vdec_bbx_w, vprior_width); + vdec_bbx_h = vmulq_f32(vdec_bbx_h, vprior_height); + + vloc.val[0] = vsubq_f32(vdec_bbx_cx, vdec_bbx_w); + vloc.val[1] = vsubq_f32(vdec_bbx_cy, vdec_bbx_h); + vloc.val[2] = vaddq_f32(vdec_bbx_cx, vdec_bbx_w); + vloc.val[3] = vaddq_f32(vdec_bbx_cy, vdec_bbx_h); + + vst4q_f32(ptr_bbox, vloc); + } + +#pragma omp parallel for + for (int i = cnt * 4; i < num_priors; i++) { + int idx = i * 4; + float p_xmin = prior_data[idx]; + float p_ymin = prior_data[idx + 1]; + float p_xmax = prior_data[idx + 2]; + float p_ymax = prior_data[idx + 3]; + float prior_width = p_xmax - p_xmin; + float prior_height = p_ymax - p_ymin; + float prior_center_x = (p_xmin + p_xmax) / 2.f; + float prior_center_y = (p_ymin + p_ymax) / 2.f; + + float xmin = ptr_loc_batch[idx]; + float ymin = ptr_loc_batch[idx + 1]; + float xmax = ptr_loc_batch[idx + 2]; + 
float ymax = ptr_loc_batch[idx + 3]; + + //! variance is not encoded in target here, so apply the stored variance +//! to the offset predictions before decoding. + float decode_bbox_center_x = + variance[idx] * xmin * prior_width + prior_center_x; + float decode_bbox_center_y = + variance[idx + 1] * ymin * prior_height + prior_center_y; + float decode_bbox_width = expf(variance[idx + 2] * xmax) * prior_width; + float decode_bbox_height = expf(variance[idx + 3] * ymax) * prior_height; + + ptr_bbox_batch[idx] = decode_bbox_center_x - decode_bbox_width / 2.f; + ptr_bbox_batch[idx + 1] = decode_bbox_center_y - decode_bbox_height / 2.f; + ptr_bbox_batch[idx + 2] = decode_bbox_center_x + decode_bbox_width / 2.f; + ptr_bbox_batch[idx + 3] = decode_bbox_center_y + decode_bbox_height / 2.f; + } + } +} + +template <> +void decode_bbox_corner_size_variance_kernel( + const int batch_num, + const float* loc_data, + const float* prior_data, + const float* variance, + const int num_priors, + const bool share_location, + const int num_loc_classes, + const int background_label_id, + float* bbox_data) { + if (!share_location) { + CHECK_EQ(share_location, true) + << "ERROR: decode boxes without share_location is unimplemented\n"; + return; + } + + int cnt = num_priors / 4; + //! vprior 0: xmin, 1: ymin, 2: xmax, 3: ymax + //! bbx + + int len_batch = num_priors * 4; + + for (int n = 0; n < batch_num; ++n) { + const float* ptr_loc_batch = loc_data + n * len_batch; + float* ptr_bbox_batch = bbox_data + n * len_batch; + +#pragma omp parallel for + for (int i = 0; i < cnt; ++i) { + int idx = i * 16; + + const float* ptr_loc = ptr_loc_batch + idx; + const float* ptr_prior = prior_data + idx; + const float* ptr_var = variance + idx; + float* ptr_bbox = ptr_bbox_batch + idx; + + float32x4x4_t vprior = vld4q_f32(ptr_prior); + float32x4x4_t vloc = vld4q_f32(ptr_loc); + + float32x4_t vprior_width = vsubq_f32(vprior.val[2], vprior.val[0]); + float32x4_t vprior_height = vsubq_f32(vprior.val[3], vprior.val[1]); + + float32x4x4_t vbbx; + vbbx.val[0] = vmulq_f32(vloc.val[0], vprior_width); + vbbx.val[1] = vmulq_f32(vloc.val[1], vprior_height); + vbbx.val[2] = vmulq_f32(vloc.val[2], vprior_width); + vbbx.val[3] = vmulq_f32(vloc.val[3], vprior_height); + + vbbx.val[0] = vaddq_f32(vprior.val[0], vbbx.val[0]); + vbbx.val[1] = vaddq_f32(vprior.val[1], vbbx.val[1]); + vbbx.val[2] = vaddq_f32(vprior.val[2], vbbx.val[2]); + vbbx.val[3] = vaddq_f32(vprior.val[3], vbbx.val[3]); + + vst4q_f32(ptr_bbox, vbbx); + } + +#pragma omp parallel for + for (int i = cnt * 4; i < num_priors; i++) { + int idx = i * 4; + float p_xmin = prior_data[idx]; + float p_ymin = prior_data[idx + 1]; + float p_xmax = prior_data[idx + 2]; + float p_ymax = prior_data[idx + 3]; + float prior_width = p_xmax - p_xmin; + float prior_height = p_ymax - p_ymin; + + ptr_bbox_batch[idx] = p_xmin + ptr_loc_batch[idx] * prior_width; + ptr_bbox_batch[idx + 1] = p_ymin + ptr_loc_batch[idx + 1] * prior_height; + ptr_bbox_batch[idx + 2] = p_xmax + ptr_loc_batch[idx + 2] * prior_width; + ptr_bbox_batch[idx + 3] = p_ymax + ptr_loc_batch[idx + 3] * prior_height; + } + } +} + +template <> +void decode_bbox_corner_size_no_variance_kernel( + const int batch_num, + const float* loc_data, + const float* prior_data, + const float* variance, + const int num_priors, + const bool share_location, + const int num_loc_classes, + const int background_label_id, + float* bbox_data) { + if (!share_location) { + CHECK_EQ(share_location, true) + << "ERROR: decode boxes without share_location is unimplemented\n"; + return; + }
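+  // corner_size with explicit variance: each corner offset is scaled by its
+  // own variance and by the prior's extent before being added, e.g.
+  //   xmin = p_xmin + loc_xmin * var_xmin * prior_width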
+ + int cnt = num_priors / 4; + //! vprior 0: xmin, 1: ymin, 2: xmax, 3: ymax + //! bbx + + int len_batch = num_priors * 4; + + for (int n = 0; n < batch_num; ++n) { + const float* ptr_loc_batch = loc_data + n * len_batch; + float* ptr_bbox_batch = bbox_data + n * len_batch; + +#pragma omp parallel for + for (int i = 0; i < cnt; ++i) { + int idx = i * 16; + + const float* ptr_loc = ptr_loc_batch + idx; + const float* ptr_prior = prior_data + idx; + const float* ptr_var = variance + idx; + float* ptr_bbox = ptr_bbox_batch + idx; + + float32x4x4_t vprior = vld4q_f32(ptr_prior); + float32x4x4_t vloc = vld4q_f32(ptr_loc); + + float32x4_t vprior_width = vsubq_f32(vprior.val[2], vprior.val[0]); + float32x4_t vprior_height = vsubq_f32(vprior.val[3], vprior.val[1]); + + float32x4x4_t vbbx; + vbbx.val[0] = vmulq_f32(vloc.val[0], vprior_width); + vbbx.val[1] = vmulq_f32(vloc.val[1], vprior_height); + vbbx.val[2] = vmulq_f32(vloc.val[2], vprior_width); + vbbx.val[3] = vmulq_f32(vloc.val[3], vprior_height); + + vloc = vld4q_f32(ptr_var); + vbbx.val[0] = vmulq_f32(vbbx.val[0], vloc.val[0]); + vbbx.val[1] = vmulq_f32(vbbx.val[1], vloc.val[1]); + vbbx.val[2] = vmulq_f32(vbbx.val[2], vloc.val[2]); + vbbx.val[3] = vmulq_f32(vbbx.val[3], vloc.val[3]); + + vbbx.val[0] = vaddq_f32(vprior.val[0], vbbx.val[0]); + vbbx.val[1] = vaddq_f32(vprior.val[1], vbbx.val[1]); + vbbx.val[2] = vaddq_f32(vprior.val[2], vbbx.val[2]); + vbbx.val[3] = vaddq_f32(vprior.val[3], vbbx.val[3]); + + vst4q_f32(ptr_bbox, vbbx); + } +#pragma omp parallel for + for (int i = cnt * 4; i < num_priors; i++) { + int idx = i * 4; + float p_xmin = prior_data[idx]; + float p_ymin = prior_data[idx + 1]; + float p_xmax = prior_data[idx + 2]; + float p_ymax = prior_data[idx + 3]; + float prior_width = p_xmax - p_xmin; + float prior_height = p_ymax - p_ymin; + + ptr_bbox_batch[idx] = + p_xmin + ptr_loc_batch[idx] * variance[idx] * prior_width; + ptr_bbox_batch[idx + 1] = + p_ymin + ptr_loc_batch[idx + 1] * variance[idx + 1] * prior_height; + ptr_bbox_batch[idx + 2] = + p_xmax + ptr_loc_batch[idx + 2] * variance[idx + 2] * prior_width; + ptr_bbox_batch[idx + 3] = + p_ymax + ptr_loc_batch[idx + 3] * variance[idx + 3] * prior_height; + } + } +} + +template <> +void decode_bboxes(const int batch_num, + const float* loc_data, + const float* prior_data, + const std::string code_type, + const bool variance_encoded_in_target, + const int num_priors, + const bool share_location, + const int num_loc_classes, + const int background_label_id, + float* bbox_data) { + const float* variance_data = prior_data + 4 * num_priors; + if (code_type == "corner") { + if (variance_encoded_in_target) { + decode_bbox_corner_variance_kernel(batch_num, + loc_data, + prior_data, + variance_data, + num_priors, + share_location, + num_loc_classes, + background_label_id, + bbox_data); + } else { + decode_bbox_corner_no_variance_kernel(batch_num, + loc_data, + prior_data, + variance_data, + num_priors, + share_location, + num_loc_classes, + background_label_id, + bbox_data); + } + } else if (code_type == "center_size") { + if (variance_encoded_in_target) { + decode_bbox_center_variance_kernel(batch_num, + loc_data, + prior_data, + variance_data, + num_priors, + share_location, + num_loc_classes, + background_label_id, + bbox_data); + } else { + decode_bbox_center_no_variance_kernel(batch_num, + loc_data, + prior_data, + variance_data, + num_priors, + share_location, + num_loc_classes, + background_label_id, + bbox_data); + } + } else if (code_type == "corner_size") { + if 
(variance_encoded_in_target) { + decode_bbox_corner_size_variance_kernel(batch_num, + loc_data, + prior_data, + variance_data, + num_priors, + share_location, + num_loc_classes, + background_label_id, + bbox_data); + } else { + decode_bbox_corner_size_no_variance_kernel(batch_num, + loc_data, + prior_data, + variance_data, + num_priors, + share_location, + num_loc_classes, + background_label_id, + bbox_data); + } + } +} + +} // namespace math +} // namespace arm +} // namespace lite +} // namespace paddle diff --git a/lite/arm/math/decode_bboxes.h b/lite/arm/math/decode_bboxes.h new file mode 100644 index 00000000000..f18bfe64200 --- /dev/null +++ b/lite/arm/math/decode_bboxes.h @@ -0,0 +1,39 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include <string> + +namespace paddle { +namespace lite { +namespace arm { +namespace math { + +template <typename T> +void decode_bboxes(const int batch_num, + const T* loc_data, + const T* prior_data, + const std::string code_type, + const bool variance_encoded_in_target, + const int num_priors, + const bool share_location, + const int num_loc_classes, + const int background_label_id, + T* bbox_data); + +} // namespace math +} // namespace arm +} // namespace lite +} // namespace paddle diff --git a/lite/arm/math/dot_toolchain_support.h b/lite/arm/math/dot_toolchain_support.h new file mode 100644 index 00000000000..8342ffee199 --- /dev/null +++ b/lite/arm/math/dot_toolchain_support.h @@ -0,0 +1,196 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// This file is modified according to +// https://github.com/ARM-software/ComputeLibrary +// * Copyright (c) 2017-2018 ARM Limited. +// * +// * SPDX-License-Identifier: MIT +// * +// * Permission is hereby granted, free of charge, to any person obtaining a +// copy +// * of this software and associated documentation files (the "Software"), to +// * deal in the Software without restriction, including without limitation the +// * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or +// * sell copies of the Software, and to permit persons to whom the Software is +// * furnished to do so, subject to the following conditions: +// * +// * The above copyright notice and this permission notice shall be included in +// all +// * copies or substantial portions of the Software.
+// * +// * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +// FROM, +// * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE +// * SOFTWARE. + +#pragma once + +#define _DECLARE_SDOT_ELEMENT \ + ".altmacro\n" \ + ".macro sdot opd:req, opn:req, opm:req\n" \ + "local vd, vn, vm, h, l\n" \ + ".irp " \ + "reg,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25," \ + "26,27,28,29,30,31\n" \ + ".ifeqs \"\\opd\",\"v\\reg\\.4s\"\n" \ + ".set vd,\\reg\n" \ + ".endif\n" \ + ".ifeqs \"\\opn\",\"v\\reg\\.16b\"\n" \ + ".set vn,\\reg\n" \ + ".endif\n" \ + ".irp idx,0,1,2,3\n" \ + ".ifeqs \"\\opm\",\"v\\reg\\.4b[\\idx\\]\"\n" \ + ".set vm,\\reg\n" \ + ".set h,\\idx / 2\n" \ + ".set l,\\idx %% 2\n" \ + ".endif\n" \ + ".endr\n" \ + ".endr\n" \ + ".ifndef vd\n" \ + ".error \"Bad operand \\opd\"\n" \ + ".exitm\n" \ + ".endif\n" \ + ".ifndef vn\n" \ + ".error \"Bad operand \\opn\"\n" \ + ".exitm\n" \ + ".endif\n" \ + ".ifndef vm\n" \ + ".error \"Bad operand \\opm\"\n" \ + ".exitm\n" \ + ".endif\n" \ + ".ifndef h\n" \ + ".error \"Bad operand \\opm\"\n" \ + ".exitm\n" \ + ".endif\n" \ + ".ifndef l\n" \ + ".error \"Bad operand \\opm\"\n" \ + ".exitm\n" \ + ".endif\n" \ + ".int 0x4f80e000 | vd | (vn << 5) | (vm << 16) | (l << 21) | (h << 11)\n" \ + ".endm\n" + +#define _DECLARE_SDOT_VECTOR \ + ".altmacro\n" \ + ".macro sdot opd:req, opn:req, opm:req\n" \ + "local vd, vn, vm\n" \ + ".irp " \ + "reg,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25," \ + "26,27,28,29,30,31\n" \ + ".ifeqs \"\\opd\",\"v\\reg\\.4s\"\n" \ + ".set vd,\\reg\n" \ + ".endif\n" \ + ".ifeqs \"\\opn\",\"v\\reg\\.16b\"\n" \ + ".set vn,\\reg\n" \ + ".endif\n" \ + ".ifeqs \"\\opm\",\"v\\reg\\.16b\"\n" \ + ".set vm,\\reg\n" \ + ".endif\n" \ + ".endr\n" \ + ".endr\n" \ + ".ifndef vd\n" \ + ".error \"Bad operand \\opd\"\n" \ + ".exitm\n" \ + ".endif\n" \ + ".ifndef vn\n" \ + ".error \"Bad operand \\opn\"\n" \ + ".exitm\n" \ + ".endif\n" \ + ".ifndef vm\n" \ + ".error \"Bad operand \\opm\"\n" \ + ".exitm\n" \ + ".endif\n" \ + ".int 0x4e809400 | vd | (vn << 5) | (vm << 16)\n" \ + ".endm\n" + +#define _DECLARE_SDOT_VECTOR_2s \ + ".altmacro\n" \ + ".macro sdot opd:req, opn:req, opm:req\n" \ + "local vd, vn, vm\n" \ + ".irp " \ + "reg,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25," \ + "26,27,28,29,30,31\n" \ + ".ifeqs \"\\opd\",\"v\\reg\\.2s\"\n" \ + ".set vd,\\reg\n" \ + ".endif\n" \ + ".ifeqs \"\\opn\",\"v\\reg\\.8b\"\n" \ + ".set vn,\\reg\n" \ + ".endif\n" \ + ".ifeqs \"\\opm\",\"v\\reg\\.8b\"\n" \ + ".set vm,\\reg\n" \ + ".endif\n" \ + ".endr\n" \ + ".endr\n" \ + ".ifndef vd\n" \ + ".error \"Bad operand \\opd\"\n" \ + ".exitm\n" \ + ".endif\n" \ + ".ifndef vn\n" \ + ".error \"Bad operand \\opn\"\n" \ + ".exitm\n" \ + ".endif\n" \ + ".ifndef vm\n" \ + ".error \"Bad operand \\opm\"\n" \ + ".exitm\n" \ + ".endif\n" \ + ".int 0x0e809400 | vd | (vn << 5) | (vm << 16)\n" \ + ".endm\n" + +#define _DECLARE_SDOT_ELEMENT_2s \ + ".altmacro\n" \ + ".macro sdot opd:req, opn:req, opm:req\n" \ + "local vd, vn, vm, h, l\n" \ + ".irp " \ + "reg,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25," \ + "26,27,28,29,30,31\n" \ + 
".ifeqs \"\\opd\",\"v\\reg\\.2s\"\n" \ + ".set vd,\\reg\n" \ + ".endif\n" \ + ".ifeqs \"\\opn\",\"v\\reg\\.8b\"\n" \ + ".set vn,\\reg\n" \ + ".endif\n" \ + ".irp idx,0,1,2,3\n" \ + ".ifeqs \"\\opm\",\"v\\reg\\.4b[\\idx\\]\"\n" \ + ".set vm,\\reg\n" \ + ".set h,\\idx / 2\n" \ + ".set l,\\idx %% 2\n" \ + ".endif\n" \ + ".endr\n" \ + ".endr\n" \ + ".ifndef vd\n" \ + ".error \"Bad operand \\opd\"\n" \ + ".exitm\n" \ + ".endif\n" \ + ".ifndef vn\n" \ + ".error \"Bad operand \\opn\"\n" \ + ".exitm\n" \ + ".endif\n" \ + ".ifndef vm\n" \ + ".error \"Bad operand \\opm\"\n" \ + ".exitm\n" \ + ".endif\n" \ + ".ifndef h\n" \ + ".error \"Bad operand \\opm\"\n" \ + ".exitm\n" \ + ".endif\n" \ + ".ifndef l\n" \ + ".error \"Bad operand \\opm\"\n" \ + ".exitm\n" \ + ".endif\n" \ + ".int 0x0f80e000 | vd | (vn << 5) | (vm << 16) | (l << 21) | (h << 11)\n" \ + ".endm\n" diff --git a/lite/arm/math/dropout.cc b/lite/arm/math/dropout.cc new file mode 100644 index 00000000000..1944dbb882e --- /dev/null +++ b/lite/arm/math/dropout.cc @@ -0,0 +1,93 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/arm/math/dropout.h" +#include "lite/arm/math/funcs.h" + +namespace paddle { +namespace lite { +namespace arm { +namespace math { + +template <> +void dropout_down(const float* din, float* dout, int num, float prob) { + const float scale = 1.0f - prob; + int cnt = num >> 4; + int remain = num % 16; + float32x4_t vscale = vdupq_n_f32(scale); +#pragma omp parallel for + for (int i = 0; i < cnt; i++) { + const float* din_ptr = din + (i << 4); + float* dout_ptr = dout + (i << 4); + + float32x4_t din0 = vld1q_f32(din_ptr); + float32x4_t din1 = vld1q_f32(din_ptr + 4); + float32x4_t din2 = vld1q_f32(din_ptr + 8); + float32x4_t din3 = vld1q_f32(din_ptr + 12); + + float32x4_t vmul0 = vmulq_f32(din0, vscale); + float32x4_t vmul1 = vmulq_f32(din1, vscale); + float32x4_t vmul2 = vmulq_f32(din2, vscale); + float32x4_t vmul3 = vmulq_f32(din3, vscale); + + vst1q_f32(dout_ptr, vmul0); + vst1q_f32(dout_ptr + 4, vmul1); + vst1q_f32(dout_ptr + 8, vmul2); + vst1q_f32(dout_ptr + 12, vmul3); + } + if (remain > 0) { + const float* din_ptr = din + (cnt << 4); + float* dout_ptr = dout + (cnt << 4); + for (int i = 0; i < remain; i++) { + *dout_ptr = *din_ptr * scale; + dout_ptr++; + din_ptr++; + } + } +} + +template <> +void dropout_up(const float* din, float* dout, int num) { + int cnt = num >> 4; + int remain = num % 16; +#pragma omp parallel for + for (int i = 0; i < cnt; i++) { + const float* din_ptr = din + (i << 4); + float* dout_ptr = dout + (i << 4); + + float32x4_t din0 = vld1q_f32(din_ptr); + float32x4_t din1 = vld1q_f32(din_ptr + 4); + float32x4_t din2 = vld1q_f32(din_ptr + 8); + float32x4_t din3 = vld1q_f32(din_ptr + 12); + + vst1q_f32(dout_ptr, din0); + vst1q_f32(dout_ptr + 4, din1); + vst1q_f32(dout_ptr + 8, din2); + vst1q_f32(dout_ptr + 12, din3); + } + if (remain > 0) { + const float* din_ptr = din + (cnt << 4); + float* dout_ptr = dout + 
(cnt << 4); + for (int i = 0; i < remain; i++) { + *dout_ptr = *din_ptr; + dout_ptr++; + din_ptr++; + } + } +} + +} // namespace math +} // namespace arm +} // namespace lite +} // namespace paddle diff --git a/lite/arm/math/dropout.h b/lite/arm/math/dropout.h new file mode 100644 index 00000000000..df2be016de9 --- /dev/null +++ b/lite/arm/math/dropout.h @@ -0,0 +1,32 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include + +namespace paddle { +namespace lite { +namespace arm { +namespace math { + +template <typename T> +void dropout_down(const T* din, T* dout, int num, float prob); + +template <typename T> +void dropout_up(const T* din, T* dout, int num); + +} // namespace math +} // namespace arm +} // namespace lite +} // namespace paddle diff --git a/lite/arm/math/elementwise.cc b/lite/arm/math/elementwise.cc new file mode 100644 index 00000000000..19155c264b5 --- /dev/null +++ b/lite/arm/math/elementwise.cc @@ -0,0 +1,758 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License.
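+// All float kernels below share one scheme: the main loop consumes 16
+// elements per iteration as four float32x4_t registers, and a scalar loop
+// handles the num % 16 tail; the *_broadcast variants additionally step the
+// tail down through 8- and 4-wide NEON blocks before going scalar.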
+ +#include "lite/arm/math/elementwise.h" +#include +#include "lite/arm/math/funcs.h" + +namespace paddle { +namespace lite { +namespace arm { +namespace math { + +template <> +void elementwise_add(const float* dinx, + const float* diny, + float* dout, + int num) { + int cnt = num >> 4; + int remain = num % 16; +#pragma omp parallel for + for (int i = 0; i < cnt; i++) { + const float* dinx_ptr = dinx + (i << 4); + const float* diny_ptr = diny + (i << 4); + float* dout_ptr = dout + (i << 4); + + float32x4_t dinx0 = vld1q_f32(dinx_ptr); + float32x4_t dinx1 = vld1q_f32(dinx_ptr + 4); + float32x4_t dinx2 = vld1q_f32(dinx_ptr + 8); + float32x4_t dinx3 = vld1q_f32(dinx_ptr + 12); + + float32x4_t diny0 = vld1q_f32(diny_ptr); + float32x4_t diny1 = vld1q_f32(diny_ptr + 4); + float32x4_t diny2 = vld1q_f32(diny_ptr + 8); + float32x4_t diny3 = vld1q_f32(diny_ptr + 12); + + dinx0 = vaddq_f32(dinx0, diny0); + dinx1 = vaddq_f32(dinx1, diny1); + dinx2 = vaddq_f32(dinx2, diny2); + dinx3 = vaddq_f32(dinx3, diny3); + + vst1q_f32(dout_ptr, dinx0); + vst1q_f32(dout_ptr + 4, dinx1); + vst1q_f32(dout_ptr + 8, dinx2); + vst1q_f32(dout_ptr + 12, dinx3); + } + if (remain > 0) { + const float* dinx_ptr = dinx + (cnt << 4); + const float* diny_ptr = diny + (cnt << 4); + float* dout_ptr = dout + (cnt << 4); + for (int i = 0; i < remain; i++) { + *dout_ptr = *dinx_ptr + *diny_ptr; + dout_ptr++; + dinx_ptr++; + diny_ptr++; + } + } +} + +template <> +void elementwise_add_relu(const float* dinx, + const float* diny, + float* dout, + int num) { + int cnt = num >> 4; + int remain = num % 16; + float32x4_t vzero = vdupq_n_f32(0.f); +#pragma omp parallel for + for (int i = 0; i < cnt; i++) { + const float* dinx_ptr = dinx + (i << 4); + const float* diny_ptr = diny + (i << 4); + float* dout_ptr = dout + (i << 4); + + float32x4_t dinx0 = vld1q_f32(dinx_ptr); + float32x4_t dinx1 = vld1q_f32(dinx_ptr + 4); + float32x4_t dinx2 = vld1q_f32(dinx_ptr + 8); + float32x4_t dinx3 = vld1q_f32(dinx_ptr + 12); + + float32x4_t diny0 = vld1q_f32(diny_ptr); + float32x4_t diny1 = vld1q_f32(diny_ptr + 4); + float32x4_t diny2 = vld1q_f32(diny_ptr + 8); + float32x4_t diny3 = vld1q_f32(diny_ptr + 12); + + dinx0 = vaddq_f32(dinx0, diny0); + dinx1 = vaddq_f32(dinx1, diny1); + dinx2 = vaddq_f32(dinx2, diny2); + dinx3 = vaddq_f32(dinx3, diny3); + + // relu + dinx0 = vmaxq_f32(dinx0, vzero); + dinx1 = vmaxq_f32(dinx1, vzero); + dinx2 = vmaxq_f32(dinx2, vzero); + dinx3 = vmaxq_f32(dinx3, vzero); + + vst1q_f32(dout_ptr, dinx0); + vst1q_f32(dout_ptr + 4, dinx1); + vst1q_f32(dout_ptr + 8, dinx2); + vst1q_f32(dout_ptr + 12, dinx3); + } + if (remain > 0) { + const float* dinx_ptr = dinx + (cnt << 4); + const float* diny_ptr = diny + (cnt << 4); + float* dout_ptr = dout + (cnt << 4); + for (int i = 0; i < remain; i++) { + float tmp = *dinx_ptr + *diny_ptr; + *dout_ptr = tmp > 0.f ? 
tmp : 0.f; + dout_ptr++; + dinx_ptr++; + diny_ptr++; + } + } +} + +template <> +void elementwise_add_broadcast(const float* dinx, + const float* diny, + float* dout, + int batch, + int channels, + int num) { +#pragma omp parallel for collapse(2) + for (int i = 0; i < batch; ++i) { + for (int j = 0; j < channels; ++j) { + int offset = (i * channels + j) * num; + const float* din_ptr = dinx + offset; + const float diny_data = diny[j]; + float* dout_ptr = dout + offset; + + int cnt = num >> 4; + int remain = num % 16; + float32x4_t rb = vdupq_n_f32(diny_data); + for (int k = 0; k < cnt; ++k) { + float32x4_t din0 = vld1q_f32(din_ptr); + float32x4_t din1 = vld1q_f32(din_ptr + 4); + float32x4_t din2 = vld1q_f32(din_ptr + 8); + float32x4_t din3 = vld1q_f32(din_ptr + 12); + + din0 = vaddq_f32(din0, rb); + din1 = vaddq_f32(din1, rb); + din2 = vaddq_f32(din2, rb); + din3 = vaddq_f32(din3, rb); + + vst1q_f32(dout_ptr, din0); + vst1q_f32(dout_ptr + 4, din1); + vst1q_f32(dout_ptr + 8, din2); + vst1q_f32(dout_ptr + 12, din3); + din_ptr += 16; + dout_ptr += 16; + } + if (remain >= 8) { + float32x4_t din0 = vld1q_f32(din_ptr); + float32x4_t din1 = vld1q_f32(din_ptr + 4); + din0 = vaddq_f32(din0, rb); + din1 = vaddq_f32(din1, rb); + vst1q_f32(dout_ptr, din0); + vst1q_f32(dout_ptr + 4, din1); + din_ptr += 8; + dout_ptr += 8; + remain -= 8; + } + if (remain >= 4) { + float32x4_t din0 = vld1q_f32(din_ptr); + din0 = vaddq_f32(din0, rb); + vst1q_f32(dout_ptr, din0); + din_ptr += 4; + dout_ptr += 4; + remain -= 4; + } + if (remain > 0) { + for (int p = 0; p < remain; p++) { + *dout_ptr = *din_ptr + diny_data; + dout_ptr++; + din_ptr++; + } + } + } + } +} + +template <> +void elementwise_add_relu_broadcast(const float* dinx, + const float* diny, + float* dout, + int batch, + int channels, + int num) { + float32x4_t vzero = vdupq_n_f32(0.f); +#pragma omp parallel for collapse(2) + for (int i = 0; i < batch; ++i) { + for (int j = 0; j < channels; ++j) { + int offset = (i * channels + j) * num; + const float* din_ptr = dinx + offset; + const float diny_data = diny[j]; + float* dout_ptr = dout + offset; + + int cnt = num >> 4; + int remain = num % 16; + float32x4_t rb = vdupq_n_f32(diny_data); + for (int k = 0; k < cnt; ++k) { + float32x4_t din0 = vld1q_f32(din_ptr); + float32x4_t din1 = vld1q_f32(din_ptr + 4); + float32x4_t din2 = vld1q_f32(din_ptr + 8); + float32x4_t din3 = vld1q_f32(din_ptr + 12); + + din0 = vaddq_f32(din0, rb); + din1 = vaddq_f32(din1, rb); + din2 = vaddq_f32(din2, rb); + din3 = vaddq_f32(din3, rb); + + // relu + din0 = vmaxq_f32(din0, vzero); + din1 = vmaxq_f32(din1, vzero); + din2 = vmaxq_f32(din2, vzero); + din3 = vmaxq_f32(din3, vzero); + + vst1q_f32(dout_ptr, din0); + vst1q_f32(dout_ptr + 4, din1); + vst1q_f32(dout_ptr + 8, din2); + vst1q_f32(dout_ptr + 12, din3); + din_ptr += 16; + dout_ptr += 16; + } + if (remain >= 8) { + float32x4_t din0 = vld1q_f32(din_ptr); + float32x4_t din1 = vld1q_f32(din_ptr + 4); + din0 = vaddq_f32(din0, rb); + din1 = vaddq_f32(din1, rb); + // relu + din0 = vmaxq_f32(din0, vzero); + din1 = vmaxq_f32(din1, vzero); + vst1q_f32(dout_ptr, din0); + vst1q_f32(dout_ptr + 4, din1); + din_ptr += 8; + dout_ptr += 8; + remain -= 8; + } + if (remain >= 4) { + float32x4_t din0 = vld1q_f32(din_ptr); + din0 = vaddq_f32(din0, rb); + // relu + din0 = vmaxq_f32(din0, vzero); + vst1q_f32(dout_ptr, din0); + din_ptr += 4; + dout_ptr += 4; + remain -= 4; + } + if (remain > 0) { + for (int p = 0; p < remain; p++) { + float tmp = *din_ptr + diny_data; + *dout_ptr = tmp > 0.f ? 
tmp : 0.f; + dout_ptr++; + din_ptr++; + } + } + } + } +} + +template <> +void elementwise_mul(const float* dinx, + const float* diny, + float* dout, + int num) { + int cnt = num >> 4; + int remain = num % 16; +#pragma omp parallel for + for (int i = 0; i < cnt; ++i) { + const float* dinx_ptr = dinx + (i << 4); + const float* diny_ptr = diny + (i << 4); + float* dout_ptr = dout + (i << 4); + + float32x4_t dinx0 = vld1q_f32(dinx_ptr); + float32x4_t dinx1 = vld1q_f32(dinx_ptr + 4); + float32x4_t dinx2 = vld1q_f32(dinx_ptr + 8); + float32x4_t dinx3 = vld1q_f32(dinx_ptr + 12); + + float32x4_t diny0 = vld1q_f32(diny_ptr); + float32x4_t diny1 = vld1q_f32(diny_ptr + 4); + float32x4_t diny2 = vld1q_f32(diny_ptr + 8); + float32x4_t diny3 = vld1q_f32(diny_ptr + 12); + + dinx0 = vmulq_f32(dinx0, diny0); + dinx1 = vmulq_f32(dinx1, diny1); + dinx2 = vmulq_f32(dinx2, diny2); + dinx3 = vmulq_f32(dinx3, diny3); + + vst1q_f32(dout_ptr, dinx0); + vst1q_f32(dout_ptr + 4, dinx1); + vst1q_f32(dout_ptr + 8, dinx2); + vst1q_f32(dout_ptr + 12, dinx3); + } + if (remain > 0) { + const float* dinx_ptr = dinx + (cnt << 4); + const float* diny_ptr = diny + (cnt << 4); + float* dout_ptr = dout + (cnt << 4); + for (int i = 0; i < remain; i++) { + *dout_ptr = *dinx_ptr * *diny_ptr; + dout_ptr++; + dinx_ptr++; + diny_ptr++; + } + } +} + +template <> +void elementwise_mul_relu(const float* dinx, + const float* diny, + float* dout, + int num) { + int cnt = num >> 4; + int remain = num % 16; + float32x4_t vzero = vdupq_n_f32(0.f); +#pragma omp parallel for + for (int i = 0; i < cnt; ++i) { + const float* dinx_ptr = dinx + (i << 4); + const float* diny_ptr = diny + (i << 4); + float* dout_ptr = dout + (i << 4); + + float32x4_t dinx0 = vld1q_f32(dinx_ptr); + float32x4_t dinx1 = vld1q_f32(dinx_ptr + 4); + float32x4_t dinx2 = vld1q_f32(dinx_ptr + 8); + float32x4_t dinx3 = vld1q_f32(dinx_ptr + 12); + + float32x4_t diny0 = vld1q_f32(diny_ptr); + float32x4_t diny1 = vld1q_f32(diny_ptr + 4); + float32x4_t diny2 = vld1q_f32(diny_ptr + 8); + float32x4_t diny3 = vld1q_f32(diny_ptr + 12); + + dinx0 = vmulq_f32(dinx0, diny0); + dinx1 = vmulq_f32(dinx1, diny1); + dinx2 = vmulq_f32(dinx2, diny2); + dinx3 = vmulq_f32(dinx3, diny3); + + // relu + dinx0 = vmaxq_f32(dinx0, vzero); + dinx1 = vmaxq_f32(dinx1, vzero); + dinx2 = vmaxq_f32(dinx2, vzero); + dinx3 = vmaxq_f32(dinx3, vzero); + + vst1q_f32(dout_ptr, dinx0); + vst1q_f32(dout_ptr + 4, dinx1); + vst1q_f32(dout_ptr + 8, dinx2); + vst1q_f32(dout_ptr + 12, dinx3); + } + if (remain > 0) { + const float* dinx_ptr = dinx + (cnt << 4); + const float* diny_ptr = diny + (cnt << 4); + float* dout_ptr = dout + (cnt << 4); + for (int i = 0; i < remain; i++) { + float tmp = *dinx_ptr * *diny_ptr; + *dout_ptr = tmp > 0.f ? 
tmp : 0.f; + dout_ptr++; + dinx_ptr++; + diny_ptr++; + } + } +} + +template <> +void elementwise_mul_broadcast(const float* dinx, + const float* diny, + float* dout, + int batch, + int channels, + int num) { +#pragma omp parallel for collapse(2) + for (int i = 0; i < batch; ++i) { + for (int j = 0; j < channels; ++j) { + int offset = (i * channels + j) * num; + const float* din_ptr = dinx + offset; + const float diny_data = diny[j]; + float* dout_ptr = dout + offset; + + int cnt = num >> 4; + int remain = num % 16; + float32x4_t rb = vdupq_n_f32(diny_data); + for (int k = 0; k < cnt; ++k) { + float32x4_t din0 = vld1q_f32(din_ptr); + float32x4_t din1 = vld1q_f32(din_ptr + 4); + float32x4_t din2 = vld1q_f32(din_ptr + 8); + float32x4_t din3 = vld1q_f32(din_ptr + 12); + + din0 = vmulq_f32(din0, rb); + din1 = vmulq_f32(din1, rb); + din2 = vmulq_f32(din2, rb); + din3 = vmulq_f32(din3, rb); + + vst1q_f32(dout_ptr, din0); + vst1q_f32(dout_ptr + 4, din1); + vst1q_f32(dout_ptr + 8, din2); + vst1q_f32(dout_ptr + 12, din3); + + din_ptr += 16; + dout_ptr += 16; + } + if (remain >= 8) { + float32x4_t din0 = vld1q_f32(din_ptr); + float32x4_t din1 = vld1q_f32(din_ptr + 4); + din0 = vmulq_f32(din0, rb); + din1 = vmulq_f32(din1, rb); + vst1q_f32(dout_ptr, din0); + vst1q_f32(dout_ptr + 4, din1); + din_ptr += 8; + dout_ptr += 8; + remain -= 8; + } + if (remain >= 4) { + float32x4_t din0 = vld1q_f32(din_ptr); + din0 = vmulq_f32(din0, rb); + vst1q_f32(dout_ptr, din0); + din_ptr += 4; + dout_ptr += 4; + remain -= 4; + } + if (remain > 0) { + for (int p = 0; p < remain; ++p) { + *dout_ptr = *din_ptr * diny_data; + dout_ptr++; + din_ptr++; + } + } + } + } +} + +template <> +void elementwise_mul_relu_broadcast(const float* dinx, + const float* diny, + float* dout, + int batch, + int channels, + int num) { + float32x4_t vzero = vdupq_n_f32(0.f); +#pragma omp parallel for collapse(2) + for (int i = 0; i < batch; ++i) { + for (int j = 0; j < channels; ++j) { + int offset = (i * channels + j) * num; + const float* din_ptr = dinx + offset; + const float diny_data = diny[j]; + float* dout_ptr = dout + offset; + + int cnt = num >> 4; + int remain = num % 16; + float32x4_t rb = vdupq_n_f32(diny_data); + for (int k = 0; k < cnt; ++k) { + float32x4_t din0 = vld1q_f32(din_ptr); + float32x4_t din1 = vld1q_f32(din_ptr + 4); + float32x4_t din2 = vld1q_f32(din_ptr + 8); + float32x4_t din3 = vld1q_f32(din_ptr + 12); + + din0 = vmulq_f32(din0, rb); + din1 = vmulq_f32(din1, rb); + din2 = vmulq_f32(din2, rb); + din3 = vmulq_f32(din3, rb); + + // relu + din0 = vmaxq_f32(din0, vzero); + din1 = vmaxq_f32(din1, vzero); + din2 = vmaxq_f32(din2, vzero); + din3 = vmaxq_f32(din3, vzero); + + vst1q_f32(dout_ptr, din0); + vst1q_f32(dout_ptr + 4, din1); + vst1q_f32(dout_ptr + 8, din2); + vst1q_f32(dout_ptr + 12, din3); + din_ptr += 16; + dout_ptr += 16; + } + if (remain >= 8) { + float32x4_t din0 = vld1q_f32(din_ptr); + float32x4_t din1 = vld1q_f32(din_ptr + 4); + din0 = vmulq_f32(din0, rb); + din1 = vmulq_f32(din1, rb); + // relu + din0 = vmaxq_f32(din0, vzero); + din1 = vmaxq_f32(din1, vzero); + vst1q_f32(dout_ptr, din0); + vst1q_f32(dout_ptr + 4, din1); + din_ptr += 8; + dout_ptr += 8; + remain -= 8; + } + if (remain >= 4) { + float32x4_t din0 = vld1q_f32(din_ptr); + din0 = vmulq_f32(din0, rb); + // relu + din0 = vmaxq_f32(din0, vzero); + vst1q_f32(dout_ptr, din0); + din_ptr += 4; + dout_ptr += 4; + remain -= 4; + } + if (remain > 0) { + for (int p = 0; p < remain; ++p) { + float tmp = *din_ptr * diny_data; + *dout_ptr = tmp > 0.f ? 
tmp : 0.f; + dout_ptr++; + din_ptr++; + } + } + } + } +} + +template <> +void elementwise_max(const float* dinx, + const float* diny, + float* dout, + int num) { + int cnt = num >> 4; + int remain = num % 16; +#pragma omp parallel for + for (int i = 0; i < cnt; ++i) { + const float* dinx_ptr = dinx + (i << 4); + const float* diny_ptr = diny + (i << 4); + float* dout_ptr = dout + (i << 4); + + float32x4_t dinx0 = vld1q_f32(dinx_ptr); + float32x4_t dinx1 = vld1q_f32(dinx_ptr + 4); + float32x4_t dinx2 = vld1q_f32(dinx_ptr + 8); + float32x4_t dinx3 = vld1q_f32(dinx_ptr + 12); + + float32x4_t diny0 = vld1q_f32(diny_ptr); + float32x4_t diny1 = vld1q_f32(diny_ptr + 4); + float32x4_t diny2 = vld1q_f32(diny_ptr + 8); + float32x4_t diny3 = vld1q_f32(diny_ptr + 12); + + dinx0 = vmaxq_f32(dinx0, diny0); + dinx1 = vmaxq_f32(dinx1, diny1); + dinx2 = vmaxq_f32(dinx2, diny2); + dinx3 = vmaxq_f32(dinx3, diny3); + + vst1q_f32(dout_ptr, dinx0); + vst1q_f32(dout_ptr + 4, dinx1); + vst1q_f32(dout_ptr + 8, dinx2); + vst1q_f32(dout_ptr + 12, dinx3); + } + if (remain > 0) { + const float* dinx_ptr = dinx + (cnt << 4); + const float* diny_ptr = diny + (cnt << 4); + float* dout_ptr = dout + (cnt << 4); + for (int i = 0; i < remain; ++i) { + *(dout_ptr++) = std::max(*(dinx_ptr++), *(diny_ptr++)); + } + } +} + +template <> +void elementwise_max_relu(const float* dinx, + const float* diny, + float* dout, + int num) { + int cnt = num >> 4; + int remain = num % 16; + float32x4_t vzero = vdupq_n_f32(0.f); +#pragma omp parallel for + for (int i = 0; i < cnt; ++i) { + const float* dinx_ptr = dinx + (i << 4); + const float* diny_ptr = diny + (i << 4); + float* dout_ptr = dout + (i << 4); + + float32x4_t dinx0 = vld1q_f32(dinx_ptr); + float32x4_t dinx1 = vld1q_f32(dinx_ptr + 4); + float32x4_t dinx2 = vld1q_f32(dinx_ptr + 8); + float32x4_t dinx3 = vld1q_f32(dinx_ptr + 12); + + float32x4_t diny0 = vld1q_f32(diny_ptr); + float32x4_t diny1 = vld1q_f32(diny_ptr + 4); + float32x4_t diny2 = vld1q_f32(diny_ptr + 8); + float32x4_t diny3 = vld1q_f32(diny_ptr + 12); + + dinx0 = vmaxq_f32(dinx0, diny0); + dinx1 = vmaxq_f32(dinx1, diny1); + dinx2 = vmaxq_f32(dinx2, diny2); + dinx3 = vmaxq_f32(dinx3, diny3); + + // relu + dinx0 = vmaxq_f32(dinx0, vzero); + dinx1 = vmaxq_f32(dinx1, vzero); + dinx2 = vmaxq_f32(dinx2, vzero); + dinx3 = vmaxq_f32(dinx3, vzero); + + vst1q_f32(dout_ptr, dinx0); + vst1q_f32(dout_ptr + 4, dinx1); + vst1q_f32(dout_ptr + 8, dinx2); + vst1q_f32(dout_ptr + 12, dinx3); + } + if (remain > 0) { + const float* dinx_ptr = dinx + (cnt << 4); + const float* diny_ptr = diny + (cnt << 4); + float* dout_ptr = dout + (cnt << 4); + for (int i = 0; i < remain; ++i) { + float tmp = std::max(*(dinx_ptr++), *(diny_ptr++)); + *(dout_ptr++) = tmp > 0.f ? 
tmp : 0.f; + } + } +} + +template <> +void elementwise_max_broadcast(const float* dinx, + const float* diny, + float* dout, + int batch, + int channels, + int num) { +#pragma omp parallel for collapse(2) + for (int i = 0; i < batch; ++i) { + for (int j = 0; j < channels; ++j) { + int offset = (i * channels + j) * num; + const float* din_ptr = dinx + offset; + const float diny_data = diny[j]; + float* dout_ptr = dout + offset; + + int cnt = num >> 4; + int remain = num % 16; + float32x4_t rb = vdupq_n_f32(diny_data); + for (int k = 0; k < cnt; ++k) { + float32x4_t din0 = vld1q_f32(din_ptr); + float32x4_t din1 = vld1q_f32(din_ptr + 4); + float32x4_t din2 = vld1q_f32(din_ptr + 8); + float32x4_t din3 = vld1q_f32(din_ptr + 12); + + din0 = vmaxq_f32(din0, rb); + din1 = vmaxq_f32(din1, rb); + din2 = vmaxq_f32(din2, rb); + din3 = vmaxq_f32(din3, rb); + + vst1q_f32(dout_ptr, din0); + vst1q_f32(dout_ptr + 4, din1); + vst1q_f32(dout_ptr + 8, din2); + vst1q_f32(dout_ptr + 12, din3); + + din_ptr += 16; + dout_ptr += 16; + } + if (remain >= 8) { + float32x4_t din0 = vld1q_f32(din_ptr); + float32x4_t din1 = vld1q_f32(din_ptr + 4); + din0 = vmaxq_f32(din0, rb); + din1 = vmaxq_f32(din1, rb); + vst1q_f32(dout_ptr, din0); + vst1q_f32(dout_ptr + 4, din1); + din_ptr += 8; + dout_ptr += 8; + remain -= 8; + } + if (remain >= 4) { + float32x4_t din0 = vld1q_f32(din_ptr); + din0 = vmaxq_f32(din0, rb); + vst1q_f32(dout_ptr, din0); + din_ptr += 4; + dout_ptr += 4; + remain -= 4; + } + if (remain > 0) { + for (int p = 0; p < remain; ++p) { + *dout_ptr = std::max(*din_ptr, diny_data); + dout_ptr++; + din_ptr++; + } + } + } + } +} + +template <> +void elementwise_max_relu_broadcast(const float* dinx, + const float* diny, + float* dout, + int batch, + int channels, + int num) { + float32x4_t vzero = vdupq_n_f32(0.f); +#pragma omp parallel for collapse(2) + for (int i = 0; i < batch; ++i) { + for (int j = 0; j < channels; ++j) { + int offset = (i * channels + j) * num; + const float* din_ptr = dinx + offset; + const float diny_data = diny[j]; + float* dout_ptr = dout + offset; + + int cnt = num >> 4; + int remain = num % 16; + float32x4_t rb = vdupq_n_f32(diny_data); + for (int k = 0; k < cnt; ++k) { + float32x4_t din0 = vld1q_f32(din_ptr); + float32x4_t din1 = vld1q_f32(din_ptr + 4); + float32x4_t din2 = vld1q_f32(din_ptr + 8); + float32x4_t din3 = vld1q_f32(din_ptr + 12); + + din0 = vmaxq_f32(din0, rb); + din1 = vmaxq_f32(din1, rb); + din2 = vmaxq_f32(din2, rb); + din3 = vmaxq_f32(din3, rb); + + // relu + din0 = vmaxq_f32(din0, vzero); + din1 = vmaxq_f32(din1, vzero); + din2 = vmaxq_f32(din2, vzero); + din3 = vmaxq_f32(din3, vzero); + + vst1q_f32(dout_ptr, din0); + vst1q_f32(dout_ptr + 4, din1); + vst1q_f32(dout_ptr + 8, din2); + vst1q_f32(dout_ptr + 12, din3); + din_ptr += 16; + dout_ptr += 16; + } + if (remain >= 8) { + float32x4_t din0 = vld1q_f32(din_ptr); + float32x4_t din1 = vld1q_f32(din_ptr + 4); + din0 = vmaxq_f32(din0, rb); + din1 = vmaxq_f32(din1, rb); + // relu + din0 = vmaxq_f32(din0, vzero); + din1 = vmaxq_f32(din1, vzero); + vst1q_f32(dout_ptr, din0); + vst1q_f32(dout_ptr + 4, din1); + din_ptr += 8; + dout_ptr += 8; + remain -= 8; + } + if (remain >= 4) { + float32x4_t din0 = vld1q_f32(din_ptr); + din0 = vmaxq_f32(din0, rb); + // relu + din0 = vmaxq_f32(din0, vzero); + vst1q_f32(dout_ptr, din0); + din_ptr += 4; + dout_ptr += 4; + remain -= 4; + } + if (remain > 0) { + for (int p = 0; p < remain; ++p) { + float tmp = std::max(*din_ptr, diny_data); + *dout_ptr = tmp > 0.f ? 
tmp : 0.f; + dout_ptr++; + din_ptr++; + } + } + } + } +} + +} // namespace math +} // namespace arm +} // namespace lite +} // namespace paddle diff --git a/lite/arm/math/elementwise.h b/lite/arm/math/elementwise.h new file mode 100644 index 00000000000..e4772fb919e --- /dev/null +++ b/lite/arm/math/elementwise.h @@ -0,0 +1,67 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +namespace paddle { +namespace lite { +namespace arm { +namespace math { + +template <typename T> +void elementwise_add(const T* dinx, const T* diny, T* dout, int num); + +template <typename T> +void elementwise_add_relu(const T* dinx, const T* diny, T* dout, int num); + +template <typename T> +void elementwise_add_broadcast( + const T* dinx, const T* diny, T* dout, int batch, int channels, int num); + +template <typename T> +void elementwise_add_relu_broadcast( + const T* dinx, const T* diny, T* dout, int batch, int channels, int num); + +template <typename T> +void elementwise_mul(const T* dinx, const T* diny, T* dout, int num); + +template <typename T> +void elementwise_mul_relu(const T* dinx, const T* diny, T* dout, int num); + +template <typename T> +void elementwise_mul_broadcast( + const T* dinx, const T* diny, T* dout, int batch, int channels, int num); + +template <typename T> +void elementwise_mul_relu_broadcast( + const T* dinx, const T* diny, T* dout, int batch, int channels, int num); + +template <typename T> +void elementwise_max(const T* dinx, const T* diny, T* dout, int num); + +template <typename T> +void elementwise_max_relu(const T* dinx, const T* diny, T* dout, int num); + +template <typename T> +void elementwise_max_broadcast( + const T* dinx, const T* diny, T* dout, int batch, int channels, int num); + +template <typename T> +void elementwise_max_relu_broadcast( + const T* dinx, const T* diny, T* dout, int batch, int channels, int num); + +} // namespace math +} // namespace arm +} // namespace lite +} // namespace paddle diff --git a/lite/arm/math/fill_bias_relu.cc b/lite/arm/math/fill_bias_relu.cc new file mode 100644 index 00000000000..b4cf2d876a2 --- /dev/null +++ b/lite/arm/math/fill_bias_relu.cc @@ -0,0 +1,122 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License.
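+// Fused bias-add + ReLU: the per-channel bias is broadcast into a NEON
+// register, added 4 floats (or 8 ints) at a time, and clamped against zero
+// with vmaxq when flag_relu is set; the scalar tail applies the same
+// x = max(x + bias, 0) update.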
+ +#include "lite/arm/math/fill_bias_relu.h" +#include +#include "lite/arm/math/funcs.h" + +namespace paddle { +namespace lite { +namespace arm { +namespace math { + +template <> +void fill_bias_relu(float* tensor, + const float* bias, + int channel, + int channel_size, + bool flag_bias, + bool flag_relu) { + float* data = tensor; + if (flag_relu) { + for (int j = 0; j < channel; ++j) { + float bias_data = flag_bias ? bias[j] : 0.f; + float32x4_t vbias = vdupq_n_f32(bias_data); + float32x4_t vzero = vdupq_n_f32(0.f); + int i = 0; + for (; i < channel_size - 3; i += 4) { + float32x4_t vdata = vld1q_f32(&data[i]); + vdata = vaddq_f32(vdata, vbias); + float32x4_t vmax = vmaxq_f32(vdata, vzero); + vst1q_f32(data + i, vmax); + } + for (; i < channel_size; i++) { + data[i] += bias_data; + data[i] = data[i] > 0 ? data[i] : 0.f; + } + data += channel_size; + } + } else { + for (int j = 0; j < channel; ++j) { + float bias_data = flag_bias ? bias[j] : 0.f; + float32x4_t vbias = vdupq_n_f32(bias_data); + int i = 0; + for (; i < channel_size - 3; i += 4) { + float32x4_t vdata = vld1q_f32(&data[i]); + vdata = vaddq_f32(vdata, vbias); + vst1q_f32(data + i, vdata); + } + for (; i < channel_size; i++) { + data[i] += bias_data; + } + data += channel_size; + } + } +} + +template <> +void fill_bias_relu(int* tensor, + const int* bias, + int channel, + int channel_size, + bool flag_bias, + bool flag_relu) { + int* data = tensor; + if (flag_relu) { + for (int j = 0; j < channel; ++j) { + int bias_data = flag_bias ? bias[j] : 0; + int32x4_t vbias = vdupq_n_s32(bias_data); + int32x4_t vzero = vdupq_n_s32(0); + int i = 0; + for (; i < channel_size - 7; i += 8) { + int32x4_t vdata1 = vld1q_s32(data + i); + int32x4_t vdata2 = vld1q_s32(data + i + 4); + vdata1 = vaddq_s32(vdata1, vbias); + vdata2 = vaddq_s32(vdata2, vbias); + int32x4_t vmax1 = vmaxq_s32(vdata1, vzero); + int32x4_t vmax2 = vmaxq_s32(vdata2, vzero); + vst1q_s32(data + i, vmax1); + vst1q_s32(data + i + 4, vmax2); + } + for (; i < channel_size; i++) { + data[i] += bias_data; + data[i] = data[i] > 0 ? data[i] : 0; + } + data += channel_size; + } + } else { + for (int j = 0; j < channel; ++j) { + int bias_data = flag_bias ? bias[j] : 0; + int32x4_t vbias = vdupq_n_s32(bias_data); + int i = 0; + for (; i < channel_size - 7; i += 8) { + int32x4_t vdata1 = vld1q_s32(data + i); + int32x4_t vdata2 = vld1q_s32(data + i + 4); + vdata1 = vaddq_s32(vdata1, vbias); + vdata2 = vaddq_s32(vdata2, vbias); + vst1q_s32(data + i, vdata1); + vst1q_s32(data + i + 4, vdata2); + } + for (; i < channel_size; i++) { + data[i] += bias_data; + } + data += channel_size; + } + } +} + +} // namespace math +} // namespace arm +} // namespace lite +} // namespace paddle diff --git a/lite/arm/math/fill_bias_relu.h b/lite/arm/math/fill_bias_relu.h new file mode 100644 index 00000000000..254d6d43be8 --- /dev/null +++ b/lite/arm/math/fill_bias_relu.h @@ -0,0 +1,44 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include +#include "lite/core/op_lite.h" + +namespace paddle { +namespace lite { +namespace arm { +namespace math { + +/** + * \brief neon implementation to add bias and relu + * @param tensor output data, updated in place + * @param bias per-channel bias values + * @param channel number of channels + * @param channel_size number of elements per channel + * @param flag_bias whether to add the bias + * @param flag_relu whether to apply relu after the bias add + */ +template <typename Dtype> +void fill_bias_relu(Dtype* tensor, + const Dtype* bias, + int channel, + int channel_size, + bool flag_bias, + bool flag_relu); + +} // namespace math +} // namespace arm +} // namespace lite +} // namespace paddle diff --git a/lite/arm/math/funcs.cc b/lite/arm/math/funcs.cc new file mode 100644 index 00000000000..edc5fe7fdcc --- /dev/null +++ b/lite/arm/math/funcs.cc @@ -0,0 +1,153 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/arm/math/funcs.h" +#include <arm_neon.h> + +namespace paddle { +namespace lite { +namespace arm { +namespace math { + +template <> +void fill_bias_fc(float *out, const float *bias, int num, int channel) { + int cnt = channel >> 4; + int remain = channel & 15; + + for (int j = 0; j < num; ++j) { + const float *ptr_bias = bias; + float *ptr_out = out + j * channel; + + float32x4_t vout1; + float32x4_t vout2; + float32x4_t vout3; + float32x4_t vout4; + + for (int i = 0; i < cnt; ++i) { + float32x4_t vin1 = vld1q_f32(ptr_out); + float32x4_t vb1 = vld1q_f32(ptr_bias); + + float32x4_t vin2 = vld1q_f32(ptr_out + 4); + float32x4_t vb2 = vld1q_f32(ptr_bias + 4); + + float32x4_t vin3 = vld1q_f32(ptr_out + 8); + float32x4_t vb3 = vld1q_f32(ptr_bias + 8); + + float32x4_t vin4 = vld1q_f32(ptr_out + 12); + float32x4_t vb4 = vld1q_f32(ptr_bias + 12); + + vout1 = vaddq_f32(vin1, vb1); + vout2 = vaddq_f32(vin2, vb2); + vout3 = vaddq_f32(vin3, vb3); + vout4 = vaddq_f32(vin4, vb4); + + vst1q_f32(ptr_out, vout1); + vst1q_f32(ptr_out + 4, vout2); + vst1q_f32(ptr_out + 8, vout3); + vst1q_f32(ptr_out + 12, vout4); + + ptr_out += 16; + ptr_bias += 16; + } +#if 0 + if (cnt > 0) { + asm( + "1: \n" + "vld1.32 {d0-d1}, [%[ptr_out]] @ load data\n" + "vld1.32 {d2-d3}, [%[ptr_bias]]! @ load data\n" + "vadd.f32 q2, q0, q1 @ add bias\n" + "vst1.32 {d4-d5}, [%[ptr_out]]! 
@ store result\n" + "subs %[cnt], #1 @ loop count -1\n" + "bne 1b @ jump to main loop\n" + :[ptr_out] "+r"(ptr_out), [ptr_bias] "+r"(ptr_bias), \ + [cnt] "+r"(cnt) + : + :"q0", "q1", "q2" + ); + } +#endif + for (int i = 0; i < remain; ++i) { + *(ptr_out++) += *(ptr_bias++); + } + } +} + +template <> +void fill_bias_fc(int *out, const int *bias, int num, int channel) { + int cnt = channel >> 4; + int remain = channel & 15; + + for (int j = 0; j < num; ++j) { + const int *ptr_bias = bias; + int *ptr_out = out + j * channel; + + int32x4_t vout1; + int32x4_t vout2; + int32x4_t vout3; + int32x4_t vout4; + + for (int i = 0; i < cnt; ++i) { + int32x4_t vin1 = vld1q_s32(ptr_out); + int32x4_t vb1 = vld1q_s32(ptr_bias); + + int32x4_t vin2 = vld1q_s32(ptr_out + 4); + int32x4_t vb2 = vld1q_s32(ptr_bias + 4); + + int32x4_t vin3 = vld1q_s32(ptr_out + 8); + int32x4_t vb3 = vld1q_s32(ptr_bias + 8); + + int32x4_t vin4 = vld1q_s32(ptr_out + 12); + int32x4_t vb4 = vld1q_s32(ptr_bias + 12); + + vout1 = vaddq_s32(vin1, vb1); + vout2 = vaddq_s32(vin2, vb2); + vout3 = vaddq_s32(vin3, vb3); + vout4 = vaddq_s32(vin4, vb4); + + vst1q_s32(ptr_out, vout1); + vst1q_s32(ptr_out + 4, vout2); + vst1q_s32(ptr_out + 8, vout3); + vst1q_s32(ptr_out + 12, vout4); + + ptr_out += 16; + ptr_bias += 16; + } + +#if 0 + if (cnt > 0) { + asm( + "1: \n" + "vld1.32 {d0-d1}, [%[ptr_out]] @ load data\n" + "vld1.32 {d2-d3}, [%[ptr_bias]]! @ load data\n" + "vadd.s32 q2, q0, q1 @ add bias\n" + "vst1.32 {d4-d5}, [%[ptr_out]]! @ store result\n" + "subs %[cnt], #1 @ loop count -1\n" + "bne 1b @ jump to main loop\n" + :[ptr_out] "+r"(ptr_out), [ptr_bias] "+r"(ptr_bias), \ + [cnt] "+r"(cnt) + : + :"q0", "q1", "q2" + ); + } +#endif + for (int i = 0; i < remain; ++i) { + *(ptr_out++) += *(ptr_bias++); + } + } +} + +} // namespace math +} // namespace arm +} // namespace lite +} // namespace paddle diff --git a/lite/arm/math/funcs.h b/lite/arm/math/funcs.h new file mode 100644 index 00000000000..2cfb95e3666 --- /dev/null +++ b/lite/arm/math/funcs.h @@ -0,0 +1,424 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
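+// Besides aggregating the per-op headers, this header defines NEON
+// transcendentals (log/exp/sin/cos) in the cephes style: range-reduce the
+// argument, evaluate a fixed-degree polynomial, then undo the reduction
+// using the split-precision c_cephes_* constants below.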
+
+#pragma once
+
+#include <arm_neon.h>
+#include <algorithm>
+#include <cmath>
+
+#include "lite/arm/math/activation.h"
+#include "lite/arm/math/argmax.h"
+#include "lite/arm/math/axpy.h"
+#include "lite/arm/math/beam_search.h"
+#include "lite/arm/math/box_coder.h"
+#include "lite/arm/math/col_im_transform.h"
+#include "lite/arm/math/concat.h"
+#include "lite/arm/math/conv_depthwise.h"
+#include "lite/arm/math/conv_direct.h"
+#include "lite/arm/math/conv_gemmlike.h"
+#include "lite/arm/math/conv_winograd.h"
+#include "lite/arm/math/decode_bboxes.h"
+#include "lite/arm/math/dropout.h"
+#include "lite/arm/math/elementwise.h"
+#include "lite/arm/math/fill_bias_relu.h"
+#include "lite/arm/math/im2sequence.h"
+#include "lite/arm/math/increment.h"
+#include "lite/arm/math/interpolate.h"
+#include "lite/arm/math/lrn.h"
+#include "lite/arm/math/multiclass_nms.h"
+#include "lite/arm/math/negative.h"
+#include "lite/arm/math/norm.h"
+#include "lite/arm/math/packed_sgemm.h"
+#include "lite/arm/math/pad2d.h"
+#include "lite/arm/math/pooling.h"
+#include "lite/arm/math/power.h"
+#include "lite/arm/math/prior_box.h"
+#include "lite/arm/math/reduce_max.h"
+#include "lite/arm/math/scale.h"
+#include "lite/arm/math/sequence_expand.h"
+#include "lite/arm/math/sequence_pool.h"
+#include "lite/arm/math/sequence_softmax.h"
+#include "lite/arm/math/sgemm.h"
+#include "lite/arm/math/sgemv.h"
+#include "lite/arm/math/shuffle_channel.h"
+#include "lite/arm/math/slice.h"
+#include "lite/arm/math/softmax.h"
+#include "lite/arm/math/split.h"
+#include "lite/arm/math/topk.h"
+#include "lite/arm/math/yolo_box.h"
+
+namespace paddle {
+namespace lite {
+namespace arm {
+namespace math {
+
+#define c_inv_mant_mask ~0x7f800000u
+#define c_cephes_SQRTHF 0.707106781186547524
+#define c_cephes_log_p0 7.0376836292E-2
+#define c_cephes_log_p1 -1.1514610310E-1
+#define c_cephes_log_p2 1.1676998740E-1
+#define c_cephes_log_p3 -1.2420140846E-1
+#define c_cephes_log_p4 +1.4249322787E-1
+#define c_cephes_log_p5 -1.6668057665E-1
+#define c_cephes_log_p6 +2.0000714765E-1
+#define c_cephes_log_p7 -2.4999993993E-1
+#define c_cephes_log_p8 +3.3333331174E-1
+#define c_cephes_log_q1 -2.12194440e-4
+#define c_cephes_log_q2 0.693359375
+
+// natural logarithm computed for 4 simultaneous floats
+// return NaN for x <= 0
+inline float32x4_t log_ps(float32x4_t x) {
+  float32x4_t one = vdupq_n_f32(1);
+
+  x = vmaxq_f32(x, vdupq_n_f32(0));  // force flush to zero on denormal values
+  uint32x4_t invalid_mask = vcleq_f32(x, vdupq_n_f32(0));
+
+  int32x4_t ux = vreinterpretq_s32_f32(x);
+
+  int32x4_t emm0 = vshrq_n_s32(ux, 23);
+
+  // keep only the fractional part
+  ux = vandq_s32(ux, vdupq_n_s32(c_inv_mant_mask));
+  ux = vorrq_s32(ux, vreinterpretq_s32_f32(vdupq_n_f32(0.5f)));
+  x = vreinterpretq_f32_s32(ux);
+
+  emm0 = vsubq_s32(emm0, vdupq_n_s32(0x7f));
+  float32x4_t e = vcvtq_f32_s32(emm0);
+
+  e = vaddq_f32(e, one);
+
+  // part2:
+  //   if( x < SQRTHF ) {
+  //     e -= 1;
+  //     x = x + x - 1.0;
+  //   } else {
+  //     x = x - 1.0;
+  //   }
+  uint32x4_t mask = vcltq_f32(x, vdupq_n_f32(c_cephes_SQRTHF));
+  float32x4_t tmp =
+      vreinterpretq_f32_u32(vandq_u32(vreinterpretq_u32_f32(x), mask));
+  x = vsubq_f32(x, one);
+  e = vsubq_f32(
+      e, vreinterpretq_f32_u32(vandq_u32(vreinterpretq_u32_f32(one), mask)));
+  x = vaddq_f32(x, tmp);
+
+  float32x4_t z = vmulq_f32(x, x);
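// Note: the Horner chain below evaluates the cephes minimax polynomial for
// log(1 + m). At this point x holds the shifted mantissa minus one, e the
// unbiased exponent, and z = x^2, so the result is assembled as
//   log(input) = x + x^3 * P(x) - z / 2 + e * (c_cephes_log_q1 + c_cephes_log_q2)
// where q1 + q2 together approximate log(2), split in two pieces to
// preserve precision.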
+  float32x4_t y = vdupq_n_f32(c_cephes_log_p0);
+  y = vmulq_f32(y, x);
+  y = vaddq_f32(y, vdupq_n_f32(c_cephes_log_p1));
+  y = vmulq_f32(y, x);
+  y = vaddq_f32(y, vdupq_n_f32(c_cephes_log_p2));
+  y = vmulq_f32(y, x);
+  y = vaddq_f32(y, vdupq_n_f32(c_cephes_log_p3));
+  y = vmulq_f32(y, x);
+  y = vaddq_f32(y, vdupq_n_f32(c_cephes_log_p4));
+  y = vmulq_f32(y, x);
+  y = vaddq_f32(y, vdupq_n_f32(c_cephes_log_p5));
+  y = vmulq_f32(y, x);
+  y = vaddq_f32(y, vdupq_n_f32(c_cephes_log_p6));
+  y = vmulq_f32(y, x);
+  y = vaddq_f32(y, vdupq_n_f32(c_cephes_log_p7));
+  y = vmulq_f32(y, x);
+  y = vaddq_f32(y, vdupq_n_f32(c_cephes_log_p8));
+  y = vmulq_f32(y, x);
+
+  y = vmulq_f32(y, z);
+
+  tmp = vmulq_f32(e, vdupq_n_f32(c_cephes_log_q1));
+  y = vaddq_f32(y, tmp);
+
+  tmp = vmulq_f32(z, vdupq_n_f32(0.5f));
+  y = vsubq_f32(y, tmp);
+
+  tmp = vmulq_f32(e, vdupq_n_f32(c_cephes_log_q2));
+  x = vaddq_f32(x, y);
+  x = vaddq_f32(x, tmp);
+  x = vreinterpretq_f32_u32(vorrq_u32(
+      vreinterpretq_u32_f32(x), invalid_mask));  // negative arg will be NAN
+  return x;
+}
+
+#define c_exp_hi 88.3762626647949f
+#define c_exp_lo -88.3762626647949f
+
+#define c_cephes_LOG2EF 1.44269504088896341
+#define c_cephes_exp_C1 0.693359375
+#define c_cephes_exp_C2 -2.12194440e-4
+
+#define c_cephes_exp_p0 1.9875691500E-4
+#define c_cephes_exp_p1 1.3981999507E-3
+#define c_cephes_exp_p2 8.3334519073E-3
+#define c_cephes_exp_p3 4.1665795894E-2
+#define c_cephes_exp_p4 1.6666665459E-1
+#define c_cephes_exp_p5 5.0000001201E-1
+
+// exp() computed for 4 floats at once
+inline float32x4_t exp_ps(float32x4_t x) {
+  float32x4_t tmp, fx;
+
+  float32x4_t one = vdupq_n_f32(1);
+  x = vminq_f32(x, vdupq_n_f32(c_exp_hi));
+  x = vmaxq_f32(x, vdupq_n_f32(c_exp_lo));
+
+  // express exp(x) as exp(g + n*log(2))
+  fx = vmlaq_f32(vdupq_n_f32(0.5f), x, vdupq_n_f32(c_cephes_LOG2EF));
+
+  // perform a floorf
+  tmp = vcvtq_f32_s32(vcvtq_s32_f32(fx));
+
+  // if greater, subtract 1
+  uint32x4_t mask = vcgtq_f32(tmp, fx);
+  mask = vandq_u32(mask, vreinterpretq_u32_f32(one));
+
+  fx = vsubq_f32(tmp, vreinterpretq_f32_u32(mask));
+
+  tmp = vmulq_f32(fx, vdupq_n_f32(c_cephes_exp_C1));
+  float32x4_t z = vmulq_f32(fx, vdupq_n_f32(c_cephes_exp_C2));
+  x = vsubq_f32(x, tmp);
+  x = vsubq_f32(x, z);
+
+  static const float cephes_exp_p[6] = {c_cephes_exp_p0,
+                                        c_cephes_exp_p1,
+                                        c_cephes_exp_p2,
+                                        c_cephes_exp_p3,
+                                        c_cephes_exp_p4,
+                                        c_cephes_exp_p5};
+  float32x4_t y = vld1q_dup_f32(cephes_exp_p + 0);
+  float32x4_t c1 = vld1q_dup_f32(cephes_exp_p + 1);
+  float32x4_t c2 = vld1q_dup_f32(cephes_exp_p + 2);
+  float32x4_t c3 = vld1q_dup_f32(cephes_exp_p + 3);
+  float32x4_t c4 = vld1q_dup_f32(cephes_exp_p + 4);
+  float32x4_t c5 = vld1q_dup_f32(cephes_exp_p + 5);
+
+  y = vmulq_f32(y, x);
+  z = vmulq_f32(x, x);
+
+  y = vaddq_f32(y, c1);
+  y = vmulq_f32(y, x);
+  y = vaddq_f32(y, c2);
+  y = vmulq_f32(y, x);
+  y = vaddq_f32(y, c3);
+  y = vmulq_f32(y, x);
+  y = vaddq_f32(y, c4);
+  y = vmulq_f32(y, x);
+  y = vaddq_f32(y, c5);
+
+  y = vmulq_f32(y, z);
+  y = vaddq_f32(y, x);
+  y = vaddq_f32(y, one);
+
+  // build 2^n
+  int32x4_t mm;
+  mm = vcvtq_s32_f32(fx);
+  mm = vaddq_s32(mm, vdupq_n_s32(0x7f));
+  mm = vshlq_n_s32(mm, 23);
+  float32x4_t pow2n = vreinterpretq_f32_s32(mm);
+
+  y = vmulq_f32(y, pow2n);
+  return y;
+}
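As a quick sanity check, every lane of exp_ps should track scalar expf closely over the clamped input range. A minimal sketch, assuming a NEON-enabled toolchain and this header (funcs.h) on the include path; the test values are illustrative only:

#include <arm_neon.h>
#include <cmath>
#include <cstdio>
#include "lite/arm/math/funcs.h"  // brings in exp_ps / log_ps

int main() {
  float in[4] = {-3.f, 0.1f, 1.f, 7.5f};
  float out[4];
  vst1q_f32(out, paddle::lite::arm::math::exp_ps(vld1q_f32(in)));
  for (int i = 0; i < 4; ++i) {
    // each lane should agree with scalar expf to roughly 1e-6 relative error
    printf("exp_ps(%g) = %g, expf = %g\n", in[i], out[i], std::exp(in[i]));
  }
  return 0;
}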
+
+#define c_minus_cephes_DP1 -0.78515625
+#define c_minus_cephes_DP2 -2.4187564849853515625e-4
+#define c_minus_cephes_DP3 -3.77489497744594108e-8
+#define c_sincof_p0 -1.9515295891E-4
+#define c_sincof_p1 8.3321608736E-3
+#define c_sincof_p2 -1.6666654611E-1
+#define c_coscof_p0 2.443315711809948E-005
+#define c_coscof_p1 -1.388731625493765E-003
+#define c_coscof_p2 4.166664568298827E-002
+#define c_cephes_FOPI 1.27323954473516  // 4 / M_PI
+
+// evaluation of 4 sines & cosines at once.
+//
+// The code is the exact rewriting of the cephes sinf function.
+// Precision is excellent as long as x < 8192 (I did not bother to
+// take into account the special handling they have for greater values
+// -- it does not return garbage for arguments over 8192, though, but
+// the extra precision is missing).
+//
+// Note that it is such that sinf((float)M_PI) = 8.74e-8, which is the
+// surprising but correct result.
+//
+// Note also that when you compute sin(x), cos(x) is available at
+// almost no extra cost, so both sin_ps and cos_ps make use of
+// sincos_ps.
+inline void sincos_ps(float32x4_t x, float32x4_t *ysin, float32x4_t *ycos) {
+  // any x
+  float32x4_t xmm1, xmm2, xmm3, y;
+
+  uint32x4_t emm2;
+
+  uint32x4_t sign_mask_sin, sign_mask_cos;
+  sign_mask_sin = vcltq_f32(x, vdupq_n_f32(0));
+  x = vabsq_f32(x);
+
+  // scale by 4/Pi
+  y = vmulq_f32(x, vdupq_n_f32(c_cephes_FOPI));
+
+  // store the integer part of y in mm0
+  emm2 = vcvtq_u32_f32(y);
+  // j=(j+1) & (~1) (see the cephes sources)
+  emm2 = vaddq_u32(emm2, vdupq_n_u32(1));
+  emm2 = vandq_u32(emm2, vdupq_n_u32(~1));
+  y = vcvtq_f32_u32(emm2);
+
+  // get the polynomial selection mask:
+  // there is one polynomial for 0 <= x <= Pi/4
+  // and another one for Pi/4 < x <= Pi/2
+  uint32x4_t poly_mask = vtstq_u32(emm2, vdupq_n_u32(2));
+
+  // the magic pass: "extended precision modular arithmetic"
+  // x = ((x - y * DP1) - y * DP2) - y * DP3
+  xmm1 = vmulq_n_f32(y, c_minus_cephes_DP1);
+  xmm2 = vmulq_n_f32(y, c_minus_cephes_DP2);
+  xmm3 = vmulq_n_f32(y, c_minus_cephes_DP3);
+  x = vaddq_f32(x, xmm1);
+  x = vaddq_f32(x, xmm2);
+  x = vaddq_f32(x, xmm3);
+
+  sign_mask_sin = veorq_u32(sign_mask_sin, vtstq_u32(emm2, vdupq_n_u32(4)));
+  sign_mask_cos = vtstq_u32(vsubq_u32(emm2, vdupq_n_u32(2)), vdupq_n_u32(4));
+
+  // evaluate the first polynomial (0 <= x <= Pi/4) in y1,
+  // and the second polynomial (Pi/4 <= x <= Pi/2) in y2
+  float32x4_t z = vmulq_f32(x, x);
+  float32x4_t y1, y2;
+
+  y1 = vmulq_n_f32(z, c_coscof_p0);
+  y2 = vmulq_n_f32(z, c_sincof_p0);
+  y1 = vaddq_f32(y1, vdupq_n_f32(c_coscof_p1));
+  y2 = vaddq_f32(y2, vdupq_n_f32(c_sincof_p1));
+  y1 = vmulq_f32(y1, z);
+  y2 = vmulq_f32(y2, z);
+  y1 = vaddq_f32(y1, vdupq_n_f32(c_coscof_p2));
+  y2 = vaddq_f32(y2, vdupq_n_f32(c_sincof_p2));
+  y1 = vmulq_f32(y1, z);
+  y2 = vmulq_f32(y2, z);
+  y1 = vmulq_f32(y1, z);
+  y2 = vmulq_f32(y2, x);
+  y1 = vsubq_f32(y1, vmulq_f32(z, vdupq_n_f32(0.5f)));
+  y2 = vaddq_f32(y2, x);
+  y1 = vaddq_f32(y1, vdupq_n_f32(1));
+
+  // select the correct result from the two polynomials
+  float32x4_t ys = vbslq_f32(poly_mask, y1, y2);
+  float32x4_t yc = vbslq_f32(poly_mask, y2, y1);
+  *ysin = vbslq_f32(sign_mask_sin, vnegq_f32(ys), ys);
+  *ycos = vbslq_f32(sign_mask_cos, yc, vnegq_f32(yc));
+}
+
+template <typename T>
+void fill_bias_fc(T *tensor, const T *bias, int num, int channel);
+
+template <lite_api::ActivationType Act>
+inline float32x4_t vactive_f32(const float32x4_t &x) {
+  return x;
+}
+
+template <>
+inline float32x4_t vactive_f32<lite_api::ActivationType::kRelu>(
+    const float32x4_t &x) {
+  float32x4_t __zero = vdupq_n_f32(0.f);
+  return vmaxq_f32(x, __zero);
+}
+
+template <>
+inline float32x4_t vactive_f32<lite_api::ActivationType::kRelu6>(
+    const float32x4_t &x) {
+  float32x4_t __zero = vdupq_n_f32(0.f);
+  float32x4_t __six = vdupq_n_f32(6.f);
+  return vminq_f32(vmaxq_f32(x, __zero), __six);
+}
+
+template <>
+inline float32x4_t vactive_f32<lite_api::ActivationType::kSigmoid>(
+    const float32x4_t &x) {
+  float32x4_t __one = vdupq_n_f32(1.f);
+  float32x4_t __x = vnegq_f32(x);
+  __x = exp_ps(__x);
+  __x = vaddq_f32(__x, __one);
+  float32x4_t __out = vrecpeq_f32(__x);
+  return vmulq_f32(vrecpsq_f32(__x, __out), __out);
+}
+
+template <>
+inline float32x4_t vactive_f32<lite_api::ActivationType::kTanh>(
+    const float32x4_t &x) {
+  float32x4_t __one = vdupq_n_f32(1.f);
+  float32x4_t __x = vmulq_n_f32(x, -2.f);
+  __x = exp_ps(__x);
+  __x = vaddq_f32(__x, __one);
+  float32x4_t __out = vrecpeq_f32(__x);
+  __out = vmulq_f32(vrecpsq_f32(__x, __out), __out);
+  __out = vmulq_n_f32(__out, 2.f);
+  return vsubq_f32(__out, __one);
+}
+
+template <lite_api::ActivationType Act>
+inline float active_f32(const float &x) {
+  return x;
+}
+
+template <>
+inline float active_f32<lite_api::ActivationType::kRelu>(const float &x) {
+  return std::max(x, 0.f);
+}
+
+template <>
+inline float active_f32<lite_api::ActivationType::kRelu6>(const float &x) {
+  return std::min(std::max(x, 0.f), 6.f);
+}
+
+template <>
+inline float active_f32<lite_api::ActivationType::kSigmoid>(const float &x) {
+  return 1.f / (1.f + exp(-x));
+}
+
+template <>
+inline float active_f32<lite_api::ActivationType::kTanh>(const float &x) {
+  return 2.f / (1.f + exp(-2.f * x)) - 1.f;
+}
+
+} // namespace math
+} // namespace arm
+} // namespace lite
+} // namespace paddle
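The kSigmoid and kTanh specializations above avoid a full division: vrecpeq_f32 gives a rough reciprocal estimate, which one vrecpsq_f32 Newton-Raphson step then refines. A standalone sketch of that idiom (the name reciprocal_nr and the test values are illustrative, not part of the patch):

#include <arm_neon.h>
#include <cstdio>

// One Newton-Raphson step: r1 = r0 * (2 - d * r0) converges toward 1/d.
static inline float32x4_t reciprocal_nr(float32x4_t d) {
  float32x4_t r0 = vrecpeq_f32(d);           // rough estimate of 1/d
  return vmulq_f32(vrecpsq_f32(d, r0), r0);  // vrecpsq_f32 computes (2 - d*r0)
}

int main() {
  float in[4] = {1.f, 2.f, 4.f, 8.f};
  float out[4];
  vst1q_f32(out, reciprocal_nr(vld1q_f32(in)));
  for (int i = 0; i < 4; ++i) printf("1/%g ~= %g\n", in[i], out[i]);
  return 0;
}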
diff --git a/lite/arm/math/gemm_prepacked_int8.cc b/lite/arm/math/gemm_prepacked_int8.cc
new file mode 100644
index 00000000000..ded4e97c016
--- /dev/null
+++ b/lite/arm/math/gemm_prepacked_int8.cc
@@ -0,0 +1,3942 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lite/arm/math/gemm_prepacked_int8.h"
+#include <arm_neon.h>
+#include "lite/arm/math/dot_toolchain_support.h"
+
+namespace paddle {
+namespace lite {
+namespace arm {
+namespace math {
+
+void prepackA_m4k2x2_int8(int8_t* out,
+                          const int8_t* in,
+                          int ldin,
+                          int m0,
+                          int mmax,
+                          int k0,
+                          int kmax);
+
+void prepackA_m4k2x2_trans_int8(int8_t* out,
+                                const int8_t* in,
+                                int ldin,
+                                int m0,
+                                int mmax,
+                                int k0,
+                                int kmax);
+
+void packb_int8(int8_t* out,
+                const int8_t* in,
+                int ldin,
+                int k0,
+                int kmax,
+                int n0,
+                int nmax,
+                const int8_t* zerobuf);
+
+void packb_trans_int8(int8_t* out,
+                      const int8_t* in,
+                      int ldin,
+                      int k0,
+                      int kmax,
+                      int n0,
+                      int nmax,
+                      const int8_t* zerobuf);
+
+#ifdef WITH_ARM_DOTPROD
+void prepackA_m8k4_int8(int8_t* out,
+                        const int8_t* in,
+                        int ldin,
+                        int m0,
+                        int mmax,
+                        int k0,
+                        int kmax);
+
+void prepackA_m8k4_trans_int8(int8_t* out,
+                              const int8_t* in,
+                              int ldin,
+                              int m0,
+                              int mmax,
+                              int k0,
+                              int kmax);
+
+void packb_sdot_int8(int8_t* out,
+                     const int8_t* in,
+                     int ldin,
+                     int k0,
+                     int kmax,
+                     int n0,
+                     int nmax);
+
+void packb_sdot_trans_int8(int8_t* out,
+                           const int8_t* in,
+                           int ldin,
+                           int k0,
+                           int kmax,
+                           int n0,
+                           int nmax);
+#endif
+
+void prepackA_int8(void* out,
+                   const void* in,
+                   int ldin,
+                   int m0,
+                   int mmax,
+                   int k0,
+                   int kmax,
+                   bool is_trans,
+                   ARMContext* ctx) {
+#if defined(__aarch64__) && defined(WITH_ARM_DOTPROD)
+  if (is_trans) {
+    if (ctx->has_dot()) {
+      prepackA_m8k4_trans_int8(static_cast<int8_t*>(out),
+                               static_cast<const int8_t*>(in),
+                               ldin,
+                               m0,
+                               mmax,
+                               k0,
+                               kmax);
+    } else {
+      prepackA_m4k2x2_trans_int8(static_cast<int8_t*>(out),
+                                 static_cast<const int8_t*>(in),
+                                 ldin,
+                                 m0,
+                                 mmax,
+                                 k0,
+                                 kmax);
+    }
+  } else {
+    if (ctx->has_dot()) {
+      prepackA_m8k4_int8(static_cast<int8_t*>(out),
+                         static_cast<const int8_t*>(in),
+                         ldin,
+                         m0,
+                         mmax,
+                         k0,
+                         kmax);
+    } else {
+      prepackA_m4k2x2_int8(static_cast<int8_t*>(out),
+                           static_cast<const int8_t*>(in),
+                           ldin,
+                           m0,
+                           mmax,
+                           k0,
+                           kmax);
+    }
+  }
+#else
+  if (is_trans) {
+    prepackA_m4k2x2_trans_int8(static_cast<int8_t*>(out),
+                               static_cast<const int8_t*>(in),
+                               ldin,
+                               m0,
+                               mmax,
+                               k0,
+                               kmax);
+  } else {
+    prepackA_m4k2x2_int8(static_cast<int8_t*>(out),
+                         static_cast<const int8_t*>(in),
+                         ldin,
+                         m0,
+                         mmax,
+                         k0,
+                         kmax);
+  }
+#endif
+}
+
+void prepackA_int8(TensorLite* tout,
+                   const TensorLite& tin,
+                   int m,
+                   int k,
+                   int group,
+                   bool is_trans,
+                   ARMContext* ctx) {
+  int hblock = get_hblock_int8(ctx);
+  int m_roundup = ROUNDUP(m, hblock);
+  // round up to 128 bits
+  int kup = ROUNDUP(k, KBLOCK_INT8);
+  int group_size_round_up = ((m_roundup * kup + 15) / 16) * 16;
+
+  if (tout->numel() < group_size_round_up * group) {
+    tout->Resize({1, 1, 1, group_size_round_up * group});
+  }
+  int lda = k;
+  if (is_trans) {
+    lda = m;
+  }
+  for (int g = 0; g < group; ++g) {
+    const char* weights_group = tin.data<char>() + g * m * k;
+    char* weights_trans_ptr =
+        tout->mutable_data<char>() + g * group_size_round_up;
+    prepackA_int8(
+        weights_trans_ptr, weights_group, lda, 0, m, 0, k, is_trans, ctx);
+  }
+}
+
+template <typename Dtype>
+inline void gemm_int8_kernel(const int8_t* a_ptr,
+                             const int8_t*& b_ptr,  // NOLINT
+                             const int32_t* bias,
+                             Dtype*& c_ptr0,  // NOLINT
+                             Dtype*& c_ptr1,  // NOLINT
+                             Dtype*& c_ptr2, // 
NOLINT + Dtype*& c_ptr3, // NOLINT + const float* scale, + bool is_relu, + int k, + int rem); +#ifdef __aarch64__ +#define GEMM_INT8_KERNEL \ + "ld1 {v0.16b}, [%[a_ptr]],#16\n" /* load a to q0, q1 */ \ + "ld1 {v4.16b, v5.16b}, [%[b_ptr]],#32\n" /* load b to q4, q5 */ \ + "ld1 {v6.16b, v7.16b}, [%[b_ptr]],#32\n" /* load b to q6, q7 */ \ + "ldr q8, [%[bias]]\n" /* load bias */ \ + "ext v9.16b, v8.16b, v8.16b, #4\n" /* shift left 1s */ \ + "ext v10.16b, v8.16b, v8.16b, #8\n" /* shift left 2s */ \ + "ext v11.16b, v8.16b, v8.16b, #12\n" /* shift left 3s */ \ + "and v16.16b, v8.16b, v8.16b\n" /* set bias0 to out00 */ \ + "and v17.16b, v9.16b, v9.16b\n" /* set bias0 to out01 */ \ + "prfm pldl1keep, [%[a_ptr], #64]\n" /* preload a*/ \ + "and v18.16b, v10.16b, v10.16b\n" /* set bias0 to out02 */ \ + "and v19.16b, v11.16b, v11.16b\n" /* set bias0 to out03 */ \ + "prfm pldl1keep, [%[b_ptr], #64]\n" /* preload b*/ \ + "and v20.16b, v8.16b, v8.16b\n" /* set bias0 to out10 */ \ + "and v21.16b, v9.16b, v9.16b\n" /* set bias0 to out11 */ \ + "prfm pldl1keep, [%[a_ptr], #128]\n" /* preload a*/ \ + "and v22.16b, v10.16b, v10.16b\n" /* set bias0 to out12 */ \ + "and v23.16b, v11.16b, v11.16b\n" /* set bias0 to out13 */ \ + "prfm pldl1keep, [%[b_ptr], #128]\n" /* preload b*/ \ + "and v24.16b, v8.16b, v8.16b\n" /* set bias0 to out20 */ \ + "and v25.16b, v9.16b, v9.16b\n" /* set bias0 to out21 */ \ + "prfm pldl1keep, [%[a_ptr], #192]\n" /* preload a*/ \ + "and v26.16b, v10.16b, v10.16b\n" /* set bias0 to out22 */ \ + "and v27.16b, v11.16b, v11.16b\n" /* set bias0 to out23 */ \ + "prfm pldl1keep, [%[b_ptr], #192]\n" /* preload b*/ \ + "and v28.16b, v8.16b, v8.16b\n" /* set bias0 to out30 */ \ + "and v29.16b, v9.16b, v9.16b\n" /* set bias0 to out31 */ \ + "prfm pldl1keep, [%[b_ptr], #256]\n" /* preload b*/ \ + "and v30.16b, v10.16b, v10.16b\n" /* set bias0 to out32 */ \ + "and v31.16b, v11.16b, v11.16b\n" /* set bias0 to out33 */ \ + "ext v1.16b, v0.16b, v0.16b, #2\n" /* shift left 2bytes */ \ + "ins v1.h[3], v0.h[0]\n" /* insert element */ \ + "ins v1.h[7], v0.h[4]\n" /* insert element */ \ + "rev64 v2.4s, v0.4s\n" /* get low: 22,33,00,11; hi: 66,77,44,55 */ \ + "rev64 v3.4s, v1.4s\n" /* get low: 33,00,11,22; hi: 77,44,55,66 */ \ + "prfm pldl1keep, [%[b_ptr], #320]\n" /* preload a*/ \ + "prfm pldl1keep, [%[b_ptr], #384]\n" /* preload b*/ \ + "cbz %w[k], 3f\n" /* if k = 0, jump to remains */ /* 1st b0, b1 */ \ + "smull v8.8h, v0.8b, v4.8b\n" /* a0 * b0 = c00 */ \ + "smull v12.8h, v0.8b, v5.8b\n" /* a0 * b1 = c01 */ \ + "smull v9.8h, v1.8b, v4.8b\n" /* a1 * b0 = c10 */ \ + "smull v13.8h, v1.8b, v5.8b\n" /* a1 * b1 = c11 */ \ + "smull v10.8h, v2.8b, v4.8b\n" /* a2 * b0 = c20 */ \ + "smull v14.8h, v2.8b, v5.8b\n" /* a2 * b1 = c21 */ \ + "smull v11.8h, v3.8b, v4.8b\n" /* a3 * b0 = c30 */ \ + "smull v15.8h, v3.8b, v5.8b\n" /* a3 * b1 = c31 */ \ + "subs %w[k], %w[k], #1\n" /* loop count -1 */ /* 2nd b0, b1 */ \ + "smlal2 v8.8h, v0.16b, v4.16b\n" /* a0 * b0 = c00 */ \ + "smlal2 v12.8h, v0.16b, v5.16b\n" /* a0 * b1 = c01 */ \ + "smlal2 v9.8h, v1.16b, v4.16b\n" /* a1 * b0 = c10 */ \ + "smlal2 v13.8h, v1.16b, v5.16b\n" /* a1 * b1 = c11 */ \ + "smlal2 v10.8h, v2.16b, v4.16b\n" /* a2 * b0 = c20 */ \ + "smlal2 v14.8h, v2.16b, v5.16b\n" /* a2 * b1 = c21 */ \ + "smlal2 v11.8h, v3.16b, v4.16b\n" /* a3 * b0 = c30 */ \ + "smlal2 v15.8h, v3.16b, v5.16b\n" /* a3 * b1 = c31 */ \ + "beq 8f\n" /* skip main loop */ /* main loop*/ \ + "0:\n" /* main loop */ \ + "ld1 {v4.16b, v5.16b}, [%[b_ptr]],#32\n" /* load b to q4, q5 */ \ + "sadalp 
v16.4s, v8.8h\n" /* pairwise accumulate to int32, out00 */ \ + "smull v8.8h, v0.8b, v6.8b\n" /* a0 * b2 = c02 */ \ + "sadalp v20.4s, v12.8h\n" /* pairwise accumulate to int32, out01 */ \ + "smull v12.8h, v0.8b, v7.8b\n" /* a0 * b3 = c03 */ \ + "sadalp v17.4s, v9.8h\n" /* pairwise accumulate to int32, out10 */ \ + "smull v9.8h, v1.8b, v6.8b\n" /* a1 * b2 = c12 */ \ + "sadalp v21.4s, v13.8h\n" /* pairwise accumulate to int32, out11 */ \ + "smull v13.8h, v1.8b, v7.8b\n" /* a1 * b3 = c13 */ \ + "sadalp v18.4s, v10.8h\n" /* pairwise accumulate to int32, out20 */ \ + "smull v10.8h, v2.8b, v6.8b\n" /* a2 * b2 = c22 */ \ + "sadalp v22.4s, v14.8h\n" /* pairwise accumulate to int32, out21 */ \ + "smull v14.8h, v2.8b, v7.8b\n" /* a2 * b3 = c23 */ \ + "sadalp v19.4s, v11.8h\n" /* pairwise accumulate to int32, out30 */ \ + "smlal2 v8.8h, v0.16b, v6.16b\n" /* a0 * b2 = c02 */ \ + "smlal2 v12.8h, v0.16b, v7.16b\n" /* a0 * b3 = c03 */ \ + "ld1 {v0.16b}, [%[a_ptr]],#16\n" /* load a to q0, q1 */ \ + "smull v11.8h, v3.8b, v6.8b\n" /* a3 * b2 = c32 */ \ + "sadalp v23.4s, v15.8h\n" /* pairwise accumulate to int32, out31 */ \ + "smull v15.8h, v3.8b, v7.8b\n" /* a3 * b3 = c33 */ /* 2nd b2, b3 */ \ + "smlal2 v9.8h, v1.16b, v6.16b\n" /* a1 * b2 = c12 */ \ + "smlal2 v13.8h, v1.16b, v7.16b\n" /* a1 * b3 = c13 */ \ + "smlal2 v10.8h, v2.16b, v6.16b\n" /* a2 * b2 = c22 */ \ + "ext v1.16b, v0.16b, v0.16b, #2\n" /* shift left 2bytes*/ \ + "ins v1.h[3], v0.h[0]\n" /* insert element */ \ + "ins v1.h[7], v0.h[4]\n" /* insert element */ \ + "smlal2 v14.8h, v2.16b, v7.16b\n" /* a2 * b3 = c23 */ \ + "smlal2 v11.8h, v3.16b, v6.16b\n" /* a3 * b2 = c32 */ \ + "smlal2 v15.8h, v3.16b, v7.16b\n" /* a3 * b3 = c33 */ /* pre-process a */ \ + "rev64 v2.4s, v0.4s\n" /* get low: 22,33,00,11; hi: 66,77,44,55 */ \ + "rev64 v3.4s, v1.4s\n" /* get low: 33,00,11,22; hi: 77,44,55,66 */ \ + "ld1 {v6.16b, v7.16b}, [%[b_ptr]],#32\n" /* load b to q6, q7 */ \ + "sadalp v24.4s, v8.8h\n" /* pairwise accumulate to int32, out02 */ \ + "smull v8.8h, v0.8b, v4.8b\n" /* a0 * b0 = c00 */ \ + "sadalp v28.4s, v12.8h\n" /* pairwise accumulate to int32, out03 */ \ + "smull v12.8h, v0.8b, v5.8b\n" /* a0 * b1 = c01 */ \ + "sadalp v25.4s, v9.8h\n" /* pairwise accumulate to int32, out12 */ \ + "smull v9.8h, v1.8b, v4.8b\n" /* a1 * b0 = c00 */ \ + "sadalp v29.4s, v13.8h\n" /* pairwise accumulate to int32, out13 */ \ + "smull v13.8h, v1.8b, v5.8b\n" /* a1 * b1 = c01 */ \ + "sadalp v26.4s, v10.8h\n" /* pairwise accumulate to int32, out22 */ \ + "smull v10.8h, v2.8b, v4.8b\n" /* a2 * b0 = c00 */ \ + "sadalp v30.4s, v14.8h\n" /* pairwise accumulate to int32, out23 */ \ + "smull v14.8h, v2.8b, v5.8b\n" /* a2 * b1 = c01 */ \ + "sadalp v27.4s, v11.8h\n" /* pairwise accumulate to int32, out32 */ \ + "smull v11.8h, v3.8b, v4.8b\n" /* a3 * b0 = c00 */ \ + "sadalp v31.4s, v15.8h\n" /* pairwise accumulate to int32, out33 */ \ + "smull v15.8h, v3.8b, v5.8b\n" /* a3 * b1 = c01 */ \ + "subs %w[k], %w[k], #1\n" /* loop count -1 */ /* 2nd b0, b1 */ \ + "smlal2 v8.8h, v0.16b, v4.16b\n" /* a0 * b0 = c00 */ \ + "smlal2 v12.8h, v0.16b, v5.16b\n" /* a0 * b1 = c01 */ \ + "smlal2 v9.8h, v1.16b, v4.16b\n" /* a1 * b0 = c10 */ \ + "smlal2 v13.8h, v1.16b, v5.16b\n" /* a1 * b1 = c11 */ \ + "smlal2 v10.8h, v2.16b, v4.16b\n" /* a2 * b0 = c20 */ \ + "smlal2 v14.8h, v2.16b, v5.16b\n" /* a2 * b1 = c21 */ \ + "smlal2 v11.8h, v3.16b, v4.16b\n" /* a3 * b0 = c30 */ \ + "smlal2 v15.8h, v3.16b, v5.16b\n" /* a3 * b1 = c31 */ \ + "bgt 0b\n" /* jump to main loop */ \ + "8:\n" /* finish main loop */ /* 1st 
b2, b3 */ \ + "sadalp v16.4s, v8.8h\n" /* pairwise accumulate to int32, out00 */ \ + "smull v8.8h, v0.8b, v6.8b\n" /* a0 * b0 = c02 */ \ + "sadalp v20.4s, v12.8h\n" /* pairwise accumulate to int32, out01 */ \ + "smull v12.8h, v0.8b, v7.8b\n" /* a0 * b1 = c03 */ \ + "sadalp v17.4s, v9.8h\n" /* pairwise accumulate to int32, out10 */ \ + "smull v9.8h, v1.8b, v6.8b\n" /* a1 * b0 = c12 */ \ + "sadalp v21.4s, v13.8h\n" /* pairwise accumulate to int32, out11 */ \ + "smull v13.8h, v1.8b, v7.8b\n" /* a1 * b1 = c13 */ \ + "sadalp v18.4s, v10.8h\n" /* pairwise accumulate to int32, out20 */ \ + "smull v10.8h, v2.8b, v6.8b\n" /* a2 * b0 = c22 */ \ + "sadalp v22.4s, v14.8h\n" /* pairwise accumulate to int32, out21 */ \ + "smull v14.8h, v2.8b, v7.8b\n" /* a2 * b1 = c23 */ \ + "sadalp v19.4s, v11.8h\n" /* pairwise accumulate to int32, out30 */ \ + "smull v11.8h, v3.8b, v6.8b\n" /* a3 * b0 = c32 */ \ + "sadalp v23.4s, v15.8h\n" /* pairwise accumulate to int32, out31 */ \ + "smull v15.8h, v3.8b, v7.8b\n" /* a3 * b1 = c33 */ /* 2nd b2, b3 */ \ + "smlal2 v8.8h, v0.16b, v6.16b\n" /* a0 * b0 = c02 */ \ + "smlal2 v12.8h, v0.16b, v7.16b\n" /* a0 * b1 = c03 */ \ + "smlal2 v9.8h, v1.16b, v6.16b\n" /* a1 * b0 = c12 */ \ + "smlal2 v13.8h, v1.16b, v7.16b\n" /* a1 * b1 = c23 */ \ + "smlal2 v10.8h, v2.16b, v6.16b\n" /* a2 * b0 = c13 */ \ + "smlal2 v14.8h, v2.16b, v7.16b\n" /* a2 * b1 = c32 */ \ + "smlal2 v11.8h, v3.16b, v6.16b\n" /* a3 * b0 = c22 */ \ + "smlal2 v15.8h, v3.16b, v7.16b\n" /* a3 * b1 = c33 */ \ + "cbz %w[rem], 5f\n" /* skip remain */ \ + "ld1 {v0.8b}, [%[a_ptr]]\n" /* load a to q0, final */ \ + "ld1 {v4.16b, v5.16b}, [%[b_ptr]],#32\n" /* load b to q4, q5 */ \ + "ld1 {v6.16b, v7.16b}, [%[b_ptr]],#32\n" /* load b to q6, q7 */ \ + "5:\n" /* no remain */ \ + "sadalp v24.4s, v8.8h\n" /* pairwise accumulate to int32, out02 */ \ + "sadalp v28.4s, v12.8h\n" /* pairwise accumulate to int32, out03 */ \ + "sadalp v25.4s, v9.8h\n" /* pairwise accumulate to int32, out12 */ \ + "sadalp v29.4s, v13.8h\n" /* pairwise accumulate to int32, out13 */ \ + "sadalp v26.4s, v10.8h\n" /* pairwise accumulate to int32, out22 */ \ + "sadalp v30.4s, v14.8h\n" /* pairwise accumulate to int32, out23 */ \ + "sadalp v27.4s, v11.8h\n" /* pairwise accumulate to int32, out32 */ \ + "sadalp v31.4s, v15.8h\n" /* pairwise accumulate to int32, out33 */ \ + "3: \n" /* process remains */ \ + "cbz %w[rem], 7f\n" /* skip remain */ /* process remain k */ \ + "4: \n" /* remain = 1, 2 */ \ + "ext v1.8b, v0.8b, v0.8b, #2\n" /* shift left 2bytes */ \ + "ext v2.8b, v0.8b, v0.8b, #4\n" /* shift left 4bytes */ \ + "ext v3.8b, v0.8b, v0.8b, #6\n" /* shift left 6bytes */ /* 1st b0, b1 */ \ + "smull v8.8h, v0.8b, v4.8b\n" /* a0 * b0 = c00 */ \ + "smull v12.8h, v0.8b, v5.8b\n" /* a0 * b1 = c01 */ \ + "smull v9.8h, v1.8b, v4.8b\n" /* a1 * b0 = c10 */ \ + "smull v13.8h, v1.8b, v5.8b\n" /* a1 * b1 = c11 */ \ + "smull v10.8h, v2.8b, v4.8b\n" /* a2 * b0 = c20 */ \ + "smull v14.8h, v2.8b, v5.8b\n" /* a2 * b1 = c21 */ \ + "smull v11.8h, v3.8b, v4.8b\n" /* a3 * b0 = c30 */ \ + "smull v15.8h, v3.8b, v5.8b\n" /* a3 * b1 = c31 */ /* 1st b2, b3 */ \ + "sadalp v16.4s, v8.8h\n" /* pairwise accumulate to int32, out00 */ \ + "smull v8.8h, v0.8b, v6.8b\n" /* a0 * b0 = c02 */ \ + "sadalp v20.4s, v12.8h\n" /* pairwise accumulate to int32, out01 */ \ + "smull v12.8h, v0.8b, v7.8b\n" /* a0 * b1 = c03 */ \ + "sadalp v17.4s, v9.8h\n" /* pairwise accumulate to int32, out10 */ \ + "smull v9.8h, v1.8b, v6.8b\n" /* a1 * b0 = c12 */ \ + "sadalp v21.4s, v13.8h\n" /* pairwise accumulate to 
int32, out11 */ \ + "smull v13.8h, v1.8b, v7.8b\n" /* a1 * b1 = c13 */ \ + "sadalp v18.4s, v10.8h\n" /* pairwise accumulate to int32, out20 */ \ + "smull v10.8h, v2.8b, v6.8b\n" /* a2 * b0 = c22 */ \ + "sadalp v22.4s, v14.8h\n" /* pairwise accumulate to int32, out21 */ \ + "smull v14.8h, v2.8b, v7.8b\n" /* a2 * b1 = c23 */ \ + "sadalp v19.4s, v11.8h\n" /* pairwise accumulate to int32, out30 */ \ + "smull v11.8h, v3.8b, v6.8b\n" /* a3 * b0 = c32 */ \ + "sadalp v23.4s, v15.8h\n" /* pairwise accumulate to int32, out31 */ \ + "smull v15.8h, v3.8b, v7.8b\n" /* a3 * b1 = c33 */ \ + "sadalp v24.4s, v8.8h\n" /* pairwise accumulate to int32, out02 */ \ + "sadalp v28.4s, v12.8h\n" /* pairwise accumulate to int32, out03 */ \ + "sadalp v25.4s, v9.8h\n" /* pairwise accumulate to int32, out12 */ \ + "sadalp v29.4s, v13.8h\n" /* pairwise accumulate to int32, out13 */ \ + "sadalp v26.4s, v10.8h\n" /* pairwise accumulate to int32, out22 */ \ + "sadalp v30.4s, v14.8h\n" /* pairwise accumulate to int32, out23 */ \ + "sadalp v27.4s, v11.8h\n" /* pairwise accumulate to int32, out32 */ \ + "sadalp v31.4s, v15.8h\n" /* pairwise accumulate to int32, out33 */ \ + "7: \n" /* do relu */ /* do relu */ \ + "cbz %w[is_relu], 9f\n" /* not relu, jump to unpack */ \ + "movi v0.4s, #0\n" /* for relu */ \ + "smax v16.4s, v16.4s, v0.4s\n" /* relu */ \ + "smax v17.4s, v17.4s, v0.4s\n" /* relu */ \ + "smax v18.4s, v18.4s, v0.4s\n" /* relu */ \ + "smax v19.4s, v19.4s, v0.4s\n" /* relu */ \ + "smax v20.4s, v20.4s, v0.4s\n" /* relu */ \ + "smax v21.4s, v21.4s, v0.4s\n" /* relu */ \ + "smax v22.4s, v22.4s, v0.4s\n" /* relu */ \ + "smax v23.4s, v23.4s, v0.4s\n" /* relu */ \ + "smax v24.4s, v24.4s, v0.4s\n" /* relu */ \ + "smax v25.4s, v25.4s, v0.4s\n" /* relu */ \ + "smax v26.4s, v26.4s, v0.4s\n" /* relu */ \ + "smax v27.4s, v27.4s, v0.4s\n" /* relu */ \ + "smax v28.4s, v28.4s, v0.4s\n" /* relu */ \ + "smax v29.4s, v29.4s, v0.4s\n" /* relu */ \ + "smax v30.4s, v30.4s, v0.4s\n" /* relu */ \ + "smax v31.4s, v31.4s, v0.4s\n" /* relu */ /* unpack the result */ \ + "9:\n" /* unpack */ /* trans 1 */ \ + "trn1 v0.4s, v16.4s, v17.4s\n" /* get a0,b0, a2,b2 */ \ + "trn2 v1.4s, v16.4s, v17.4s\n" /* get a1,b1, a3,b3 */ \ + "trn1 v2.4s, v18.4s, v19.4s\n" /* get c0,d0, c2,c2 */ \ + "trn2 v3.4s, v18.4s, v19.4s\n" /* get c1,d1, c3,d3 */ \ + "trn1 v4.4s, v20.4s, v21.4s\n" \ + "trn2 v5.4s, v20.4s, v21.4s\n" \ + "trn1 v6.4s, v22.4s, v23.4s\n" \ + "trn2 v7.4s, v22.4s, v23.4s\n" \ + "trn1 v8.4s, v24.4s, v25.4s\n" \ + "trn2 v9.4s, v24.4s, v25.4s\n" \ + "trn1 v10.4s, v26.4s, v27.4s\n" \ + "trn2 v11.4s, v26.4s, v27.4s\n" \ + "trn1 v12.4s, v28.4s, v29.4s\n" \ + "trn2 v13.4s, v28.4s, v29.4s\n" \ + "trn1 v14.4s, v30.4s, v31.4s\n" \ + "trn2 v15.4s, v30.4s, v31.4s\n" /* trans 2 */ \ + "trn1 v16.2d, v0.2d, v2.2d\n" /* get a0,b0, c0,d0 */ \ + "trn2 v18.2d, v0.2d, v2.2d\n" /* get a2,b2, c2,d2 */ \ + "trn1 v17.2d, v1.2d, v3.2d\n" /* get a1,b1, c1,d1 */ \ + "trn2 v19.2d, v1.2d, v3.2d\n" /* get a3,b3, c3,d3 */ \ + "trn1 v20.2d, v4.2d, v6.2d\n" \ + "trn2 v22.2d, v4.2d, v6.2d\n" \ + "trn1 v21.2d, v5.2d, v7.2d\n" \ + "trn2 v23.2d, v5.2d, v7.2d\n" \ + "trn1 v24.2d, v8.2d, v10.2d\n" \ + "trn2 v26.2d, v8.2d, v10.2d\n" \ + "trn1 v25.2d, v9.2d, v11.2d\n" \ + "trn2 v27.2d, v9.2d, v11.2d\n" \ + "trn1 v28.2d, v12.2d, v14.2d\n" \ + "trn2 v30.2d, v12.2d, v14.2d\n" \ + "trn1 v29.2d, v13.2d, v15.2d\n" \ + "trn2 v31.2d, v13.2d, v15.2d\n" /* shift */ \ + "ext v17.16b, v17.16b, v17.16b, #12\n" /* circular shift left 1 */ \ + "ext v18.16b, v18.16b, v18.16b, #8\n" /* circular shift 
left 2 */ \ + "ext v19.16b, v19.16b, v19.16b, #4\n" /* circular shift left 3 */ \ + "ext v21.16b, v21.16b, v21.16b, #12\n" /* circular shift left 1 */ \ + "ext v22.16b, v22.16b, v22.16b, #8\n" /* circular shift left 2 */ \ + "ext v23.16b, v23.16b, v23.16b, #4\n" /* circular shift left 3 */ \ + "ext v25.16b, v25.16b, v25.16b, #12\n" /* circular shift left 1 */ \ + "ext v26.16b, v26.16b, v26.16b, #8\n" /* circular shift left 2 */ \ + "ext v27.16b, v27.16b, v27.16b, #4\n" /* circular shift left 3 */ \ + "ext v29.16b, v29.16b, v29.16b, #12\n" /* circular shift left 1 */ \ + "ext v30.16b, v30.16b, v30.16b, #8\n" /* circular shift left 2 */ \ + "ext v31.16b, v31.16b, v31.16b, #4\n" /* circular shift left 3 */ \ + "trn1 v0.4s, v16.4s, v17.4s\n" /* get a0,b0, a2,b2 */ \ + "trn2 v1.4s, v16.4s, v17.4s\n" /* get a1,b1, a3,b3 */ \ + "trn1 v2.4s, v18.4s, v19.4s\n" /* get c0,d0, c2,c2 */ \ + "trn2 v3.4s, v18.4s, v19.4s\n" /* get c1,d1, c3,d3 */ \ + "trn1 v4.4s, v20.4s, v21.4s\n" \ + "trn2 v5.4s, v20.4s, v21.4s\n" \ + "trn1 v6.4s, v22.4s, v23.4s\n" \ + "trn2 v7.4s, v22.4s, v23.4s\n" \ + "trn1 v8.4s, v24.4s, v25.4s\n" \ + "trn2 v9.4s, v24.4s, v25.4s\n" \ + "trn1 v10.4s, v26.4s, v27.4s\n" \ + "trn2 v11.4s, v26.4s, v27.4s\n" \ + "trn1 v12.4s, v28.4s, v29.4s\n" \ + "trn2 v13.4s, v28.4s, v29.4s\n" \ + "trn1 v14.4s, v30.4s, v31.4s\n" \ + "trn2 v15.4s, v30.4s, v31.4s\n" /* trans 2 */ \ + "trn1 v16.2d, v0.2d, v2.2d\n" /* get a0,b0, c0,d0 */ \ + "trn2 v24.2d, v0.2d, v2.2d\n" /* get a2,b2, c2,d2 */ \ + "trn1 v20.2d, v1.2d, v3.2d\n" /* get a1,b1, c1,d1 */ \ + "trn2 v28.2d, v1.2d, v3.2d\n" /* get a3,b3, c3,d3 */ \ + "trn1 v17.2d, v4.2d, v6.2d\n" \ + "trn2 v25.2d, v4.2d, v6.2d\n" \ + "trn1 v21.2d, v5.2d, v7.2d\n" \ + "trn2 v29.2d, v5.2d, v7.2d\n" \ + "trn1 v18.2d, v8.2d, v10.2d\n" \ + "trn2 v26.2d, v8.2d, v10.2d\n" \ + "trn1 v22.2d, v9.2d, v11.2d\n" \ + "trn2 v30.2d, v9.2d, v11.2d\n" \ + "trn1 v19.2d, v12.2d, v14.2d\n" \ + "trn2 v27.2d, v12.2d, v14.2d\n" \ + "trn1 v23.2d, v13.2d, v15.2d\n" \ + "trn2 v31.2d, v13.2d, v15.2d\n" + +// clang-format off +#define GEMM_INT8_INT32_OUT \ + /* store */ \ + "st1 {v16.4s, v17.4s, v18.4s, v19.4s}, [%[c_ptr0]], #64\n" \ + "st1 {v20.4s, v21.4s, v22.4s, v23.4s}, [%[c_ptr1]], #64\n" \ + "st1 {v24.4s, v25.4s, v26.4s, v27.4s}, [%[c_ptr2]], #64\n" \ + "st1 {v28.4s, v29.4s, v30.4s, v31.4s}, [%[c_ptr3]], #64\n" +// clang-format on + +#define GEMM_INT8_FP32_OUT \ + /* store */ \ + "ldr q15, [%[scale]]\n" /* load scale */ \ + "scvtf v0.4s , v16.4s\n" /* 00, convert to fp32 */ \ + "scvtf v1.4s , v17.4s\n" /* 01, convert to fp32 */ \ + "scvtf v2.4s , v18.4s\n" /* 02, convert to fp32 */ \ + "scvtf v3.4s , v19.4s\n" /* 03, convert to fp32 */ \ + "scvtf v4.4s , v20.4s\n" /* 10, convert to fp32 */ \ + "scvtf v5.4s , v21.4s\n" /* 11, convert to fp32 */ \ + "scvtf v6.4s , v22.4s\n" /* 12, convert to fp32 */ \ + "scvtf v7.4s , v23.4s\n" /* 13, convert to fp32 */ \ + "fmul v16.4s, v0.4s, v15.s[0]\n" /* 00, mul scale to get final result */ \ + "fmul v17.4s, v1.4s, v15.s[0]\n" /* 01, mul scale to get final result */ \ + "fmul v18.4s, v2.4s, v15.s[0]\n" /* 02, mul scale to get final result */ \ + "fmul v19.4s, v3.4s, v15.s[0]\n" /* 03, mul scale to get final result */ \ + "fmul v20.4s, v4.4s, v15.s[1]\n" /* 10, mul scale to get final result */ \ + "fmul v21.4s, v5.4s, v15.s[1]\n" /* 11, mul scale to get final result */ \ + "fmul v22.4s, v6.4s, v15.s[1]\n" /* 12, mul scale to get final result */ \ + "fmul v23.4s, v7.4s, v15.s[1]\n" /* 13, mul scale to get final result */ \ + "scvtf v0.4s , v24.4s\n" /* 
20, convert to fp32 */ \ + "scvtf v1.4s , v25.4s\n" /* 21, convert to fp32 */ \ + "stp q16, q17, [%[c_ptr0]], #32\n" /* write r0, 0,1 */ \ + "scvtf v2.4s , v26.4s\n" /* 22, convert to fp32 */ \ + "scvtf v3.4s , v27.4s\n" /* 23, convert to fp32 */ \ + "stp q18, q19, [%[c_ptr0]], #32\n" /* write r0, 2,3 */ \ + "scvtf v4.4s , v28.4s\n" /* 30, convert to fp32 */ \ + "scvtf v5.4s , v29.4s\n" /* 31, convert to fp32 */ \ + "stp q20, q21, [%[c_ptr1]], #32\n" /* write r1, 0,1 */ \ + "scvtf v6.4s , v30.4s\n" /* 32, convert to fp32 */ \ + "scvtf v7.4s , v31.4s\n" /* 33, convert to fp32 */ \ + "stp q22, q23, [%[c_ptr1]], #32\n" /* write r1, 2,3 */ \ + "fmul v24.4s, v0.4s, v15.s[2]\n" /* 20, mul scale to get final result */ \ + "fmul v25.4s, v1.4s, v15.s[2]\n" /* 21, mul scale to get final result */ \ + "fmul v26.4s, v2.4s, v15.s[2]\n" /* 22, mul scale to get final result */ \ + "fmul v27.4s, v3.4s, v15.s[2]\n" /* 23, mul scale to get final result */ \ + "fmul v28.4s, v4.4s, v15.s[3]\n" /* 30, mul scale to get final result */ \ + "fmul v29.4s, v5.4s, v15.s[3]\n" /* 31, mul scale to get final result */ \ + "stp q24, q25, [%[c_ptr2]], #32\n" /* write r2, 2,3 */ \ + "fmul v30.4s, v6.4s, v15.s[3]\n" /* 32, mul scale to get final result */ \ + "stp q26, q27, [%[c_ptr2]], #32\n" /* write r2, 2,3 */ \ + "fmul v31.4s, v7.4s, v15.s[3]\n" /* 33, mul scale to get final result */ \ + "stp q28, q29, [%[c_ptr3]], #32\n" /* write r3, 2,3 */ \ + "stp q30, q31, [%[c_ptr3]], #32\n" /* write r3, 2,3 */ + +#define GEMM_INT8_INT8_OUT \ + /* store */ \ + "ldr q15, [%[scale]]\n" /* load scale */ \ + "scvtf v0.4s , v16.4s\n" /* 00, convert to fp32 */ \ + "scvtf v1.4s , v17.4s\n" /* 01, convert to fp32 */ \ + "scvtf v2.4s , v18.4s\n" /* 02, convert to fp32 */ \ + "scvtf v3.4s , v19.4s\n" /* 03, convert to fp32 */ \ + "scvtf v4.4s , v20.4s\n" /* 10, convert to fp32 */ \ + "scvtf v5.4s , v21.4s\n" /* 11, convert to fp32 */ \ + "scvtf v6.4s , v22.4s\n" /* 12, convert to fp32 */ \ + "scvtf v7.4s , v23.4s\n" /* 13, convert to fp32 */ \ + "fmul v16.4s, v0.4s, v15.s[0]\n" /* 00, mul scale to get final result */ \ + "fmul v17.4s, v1.4s, v15.s[0]\n" /* 01, mul scale to get final result */ \ + "fmul v18.4s, v2.4s, v15.s[0]\n" /* 02, mul scale to get final result */ \ + "fmul v19.4s, v3.4s, v15.s[0]\n" /* 03, mul scale to get final result */ \ + "fmul v20.4s, v4.4s, v15.s[1]\n" /* 20, mul scale to get final result */ \ + "fmul v21.4s, v5.4s, v15.s[1]\n" /* 21, mul scale to get final result */ \ + "fmul v22.4s, v6.4s, v15.s[1]\n" /* 22, mul scale to get final result */ \ + "fmul v23.4s, v7.4s, v15.s[1]\n" /* 23, mul scale to get final result */ \ + "scvtf v0.4s , v24.4s\n" /* 20, convert to fp32 */ \ + "scvtf v1.4s , v25.4s\n" /* 21, convert to fp32 */ \ + "scvtf v2.4s , v26.4s\n" /* 22, convert to fp32 */ \ + "scvtf v3.4s , v27.4s\n" /* 23, convert to fp32 */ \ + "scvtf v4.4s , v28.4s\n" /* 30, convert to fp32 */ \ + "scvtf v5.4s , v29.4s\n" /* 31, convert to fp32 */ \ + "scvtf v6.4s , v30.4s\n" /* 32, convert to fp32 */ \ + "scvtf v7.4s , v31.4s\n" /* 33, convert to fp32 */ \ + "fmul v24.4s, v0.4s, v15.s[2]\n" /* 20, mul scale to get final result */ \ + "fmul v25.4s, v1.4s, v15.s[2]\n" /* 21, mul scale to get final result */ \ + "fmul v26.4s, v2.4s, v15.s[2]\n" /* 22, mul scale to get final result */ \ + "fmul v27.4s, v3.4s, v15.s[2]\n" /* 23, mul scale to get final result */ \ + "fmul v28.4s, v4.4s, v15.s[3]\n" /* 30, mul scale to get final result */ \ + "fmul v29.4s, v5.4s, v15.s[3]\n" /* 31, mul scale to get final result */ \ + "fmul 
v30.4s, v6.4s, v15.s[3]\n" /* 32, mul scale to get final result */ \ + "fmul v31.4s, v7.4s, v15.s[3]\n" /* 33, mul scale to get final result */ \ + "fcvtas v0.4s, v16.4s\n" /* 00, cvt to int */ \ + "fcvtas v1.4s, v17.4s\n" /* 01, cvt to int */ \ + "fcvtas v2.4s, v18.4s\n" /* 02, cvt to int */ \ + "fcvtas v3.4s, v19.4s\n" /* 03, cvt to int */ \ + "fcvtas v4.4s, v20.4s\n" /* 10, cvt to int */ \ + "fcvtas v5.4s, v21.4s\n" /* 11, cvt to int */ \ + "fcvtas v6.4s, v22.4s\n" /* 12, cvt to int */ \ + "fcvtas v7.4s, v23.4s\n" /* 13, cvt to int */ \ + "sqxtn v16.4h, v0.4s\n" /* 00, cvt int32 to int16 */ \ + "fcvtas v8.4s, v24.4s\n" /* 20, cvt to int */ \ + "sqxtn2 v16.8h, v1.4s\n" /* 01, cvt int32 to int16 */ \ + "fcvtas v9.4s, v25.4s\n" /* 21, cvt to int */ \ + "sqxtn v17.4h, v2.4s\n" /* 02, cvt int32 to int16 */ \ + "fcvtas v10.4s, v26.4s\n" /* 22, cvt to int */ \ + "sqxtn2 v17.8h, v3.4s\n" /* 03, cvt int32 to int16 */ \ + "fcvtas v11.4s, v27.4s\n" /* 23, cvt to int */ \ + "sqxtn v18.4h, v4.4s\n" /* 10, cvt int32 to int16 */ \ + "fcvtas v12.4s, v28.4s\n" /* 30, cvt to int */ \ + "sqxtn2 v18.8h, v5.4s\n" /* 11, cvt int32 to int16 */ \ + "fcvtas v13.4s, v29.4s\n" /* 31, cvt to int */ \ + "sqxtn v19.4h, v6.4s\n" /* 12, cvt int32 to int16 */ \ + "fcvtas v14.4s, v30.4s\n" /* 32, cvt to int */ \ + "sqxtn2 v19.8h, v7.4s\n" /* 13, cvt int32 to int16 */ \ + "fcvtas v15.4s, v31.4s\n" /* 33, cvt to int */ \ + "sqxtn v0.8b, v16.8h\n" /* 00, 01, cvt int16 to int8 */ \ + "sqxtn2 v0.16b, v17.8h\n" /* 02, 03, cvt int16 to int8 */ \ + "sqxtn v1.8b, v18.8h\n" /* 10, 11, cvt int16 to int8 */ \ + "sqxtn2 v1.16b, v19.8h\n" /* 12, 13, cvt int16 to int8 */ \ + "sqxtn v20.4h, v8.4s\n" /* 20, cvt int32 to int16 */ \ + "sqxtn2 v20.8h, v9.4s\n" /* 21, cvt int32 to int16 */ \ + "sqxtn v21.4h, v10.4s\n" /* 22, cvt int32 to int16 */ \ + "sqxtn2 v21.8h, v11.4s\n" /* 23, cvt int32 to int16 */ \ + "sqxtn v22.4h, v12.4s\n" /* 30, cvt int32 to int16 */ \ + "sqxtn2 v22.8h, v13.4s\n" /* 31, cvt int32 to int16 */ \ + "sqxtn v23.4h, v14.4s\n" /* 32, cvt int32 to int16 */ \ + "sqxtn2 v23.8h, v15.4s\n" /* 33, cvt int32 to int16 */ \ + "sqxtn v2.8b, v20.8h\n" /* 20, 21, cvt int16 to int8 */ \ + "sqxtn2 v2.16b, v21.8h\n" /* 22, 23, cvt int16 to int8 */ \ + "sqxtn v3.8b, v22.8h\n" /* 30, 31, cvt int16 to int8 */ \ + "sqxtn2 v3.16b, v23.8h\n" /* 32, 33, cvt int16 to int8 */ \ + "str q0, [%[c_ptr0]], #16\n" /* write r0 */ \ + "str q1, [%[c_ptr1]], #16\n" /* write r1 */ \ + "str q2, [%[c_ptr2]], #16\n" /* write r2 */ \ + "str q3, [%[c_ptr3]], #16\n" /* write r3 */ + +template <> +inline void gemm_int8_kernel(const int8_t* a_ptr, + const int8_t*& b_ptr, // NOLINT + const int32_t* bias, + int32_t*& c_ptr0, // NOLINT + int32_t*& c_ptr1, // NOLINT + int32_t*& c_ptr2, // NOLINT + int32_t*& c_ptr3, // NOLINT + const float* scale, // NOLINT + bool is_relu, // NOLINT + int k, + int rem) { + asm volatile(GEMM_INT8_KERNEL GEMM_INT8_INT32_OUT + : [a_ptr] "+r"(a_ptr), + [b_ptr] "+r"(b_ptr), + [c_ptr0] "+r"(c_ptr0), + [c_ptr1] "+r"(c_ptr1), + [c_ptr2] "+r"(c_ptr2), + [c_ptr3] "+r"(c_ptr3), + [k] "+r"(k) + : [is_relu] "r"(is_relu), [bias] "r"(bias), [rem] "r"(rem) + : "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15", + "v16", + "v17", + "v18", + "v19", + "v20", + "v21", + "v22", + "v23", + "v24", + "v25", + "v26", + "v27", + "v28", + "v29", + "v30", + "v31", + "cc"); +} +template <> +inline void gemm_int8_kernel(const int8_t* a_ptr, + const int8_t*& b_ptr, // NOLINT + const 
int32_t* bias, + float*& c_ptr0, // NOLINT + float*& c_ptr1, // NOLINT + float*& c_ptr2, // NOLINT + float*& c_ptr3, // NOLINT + const float* scale, + bool is_relu, + int k, + int rem) { + asm volatile(GEMM_INT8_KERNEL GEMM_INT8_FP32_OUT + : [a_ptr] "+r"(a_ptr), + [b_ptr] "+r"(b_ptr), + [c_ptr0] "+r"(c_ptr0), + [c_ptr1] "+r"(c_ptr1), + [c_ptr2] "+r"(c_ptr2), + [c_ptr3] "+r"(c_ptr3), + [k] "+r"(k) + : [is_relu] "r"(is_relu), + [bias] "r"(bias), + [rem] "r"(rem), + [scale] "r"(scale) + : "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15", + "v16", + "v17", + "v18", + "v19", + "v20", + "v21", + "v22", + "v23", + "v24", + "v25", + "v26", + "v27", + "v28", + "v29", + "v30", + "v31", + "cc"); +} + +template <> +inline void gemm_int8_kernel(const int8_t* a_ptr, + const int8_t*& b_ptr, // NOLINT + const int32_t* bias, + int8_t*& c_ptr0, // NOLINT + int8_t*& c_ptr1, // NOLINT + int8_t*& c_ptr2, // NOLINT + int8_t*& c_ptr3, // NOLINT + const float* scale, + bool is_relu, + int k, + int rem) { + asm volatile(GEMM_INT8_KERNEL GEMM_INT8_INT8_OUT + : [a_ptr] "+r"(a_ptr), + [b_ptr] "+r"(b_ptr), + [c_ptr0] "+r"(c_ptr0), + [c_ptr1] "+r"(c_ptr1), + [c_ptr2] "+r"(c_ptr2), + [c_ptr3] "+r"(c_ptr3), + [k] "+r"(k) + : [is_relu] "r"(is_relu), + [bias] "r"(bias), + [rem] "r"(rem), + [scale] "r"(scale) + : "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15", + "v16", + "v17", + "v18", + "v19", + "v20", + "v21", + "v22", + "v23", + "v24", + "v25", + "v26", + "v27", + "v28", + "v29", + "v30", + "v31", + "cc"); +} + +#ifdef WITH_ARM_DOTPROD +template +inline void sgemm_sdot_int8_kernel(const int8_t* a_ptr, + const int8_t*& b_ptr, // NOLINT + const int32_t* bias, + Dtype*& c_ptr0, // NOLINT + Dtype*& c_ptr1, // NOLINT + Dtype*& c_ptr2, // NOLINT + Dtype*& c_ptr3, // NOLINT + Dtype*& c_ptr4, // NOLINT + Dtype*& c_ptr5, // NOLINT + Dtype*& c_ptr6, // NOLINT + Dtype*& c_ptr7, // NOLINT + const float32_t* scale, + bool is_relu, + int k, + int rem); + +#define GEMM_SDOT_INT8_KERNEL \ + "ldp q2, q3, [%[bias_ptr]]\n" /* load bias to q2, q3*/ \ + "ldp q0, q1, [%[a_ptr]], #32\n" /* load a00,a01 to q0, q1*/ \ + "ldp q4, q5, [%[b_ptr]], #32\n" /* load b0, b1 to q4, q5*/ \ + "dup v8.4s, v2.s[0]\n" /* out0 = 0 */ \ + "dup v9.4s, v2.s[0]\n" /* out1 = 0*/ \ + "dup v10.4s, v2.s[0]\n" /* out2 = 0*/ \ + "dup v11.4s, v2.s[1]\n" /* out3 = 0*/ \ + "dup v12.4s, v2.s[1]\n" /* out4 = 0*/ \ + "prfm pldl1keep, [%[b_ptr], #64]\n" /* preload b*/ \ + "dup v13.4s, v2.s[1]\n" /* out5 = 0*/ \ + "prfm pldl1keep, [%[a_ptr], #64]\n" /* preload a*/ \ + "dup v14.4s, v2.s[2]\n" /* out6 = 0*/ \ + "prfm pldl1keep, [%[b_ptr], #128]\n" /* preload b*/ \ + "dup v15.4s, v2.s[2]\n" /* out7 = 0*/ \ + "prfm pldl1keep, [%[a_ptr], #128]\n" /* preload a*/ \ + "dup v16.4s, v2.s[2]\n" /* out8 = 0*/ \ + "prfm pldl1keep, [%[b_ptr], #192]\n" /* preload b*/ \ + "dup v17.4s, v2.s[3]\n" /* out9 = 0*/ \ + "prfm pldl1keep, [%[b_ptr], #256]\n" /* preload b*/ \ + "dup v18.4s, v2.s[3]\n" /* out10 = 0*/ \ + "prfm pldl1keep, [%[a_ptr], #192]\n" /* preload a*/ \ + "dup v19.4s, v2.s[3]\n" /* out11 = 0*/ \ + "prfm pldl1keep, [%[b_ptr], #320]\n" /* preload b*/ \ + "dup v20.4s, v3.s[0]\n" /* out12 = 0*/ \ + "prfm pldl1keep, [%[a_ptr], #256]\n" /* preload a*/ \ + "dup v21.4s, v3.s[0]\n" /* out13 = 0*/ \ + "prfm pldl1keep, [%[b_ptr], #384]\n" /* preload b*/ \ + "dup v22.4s, v3.s[0]\n" /* out14 = 0*/ \ + "dup v23.4s, v3.s[1]\n" /* 
out15 = 0*/ \ + "dup v24.4s, v3.s[1]\n" /* out16 = 0*/ \ + "dup v25.4s, v3.s[1]\n" /* out17 = 0*/ \ + "dup v26.4s, v3.s[2]\n" /* out18 = 0*/ \ + "dup v27.4s, v3.s[2]\n" /* out19 = 0*/ \ + "dup v28.4s, v3.s[2]\n" /* out20 = 0*/ \ + "dup v29.4s, v3.s[3]\n" /* out21 = 0*/ \ + "dup v30.4s, v3.s[3]\n" /* out22 = 0*/ \ + "dup v31.4s, v3.s[3]\n" /* out23 = 0*/ \ + "cbz %w[k], 2f\n" /* check loop count > 0 */ \ + "1:\n" /* main loop */ \ + "sdot v8.4s , v4.16b, v0.4b[0]\n" /* out0 = b0 * a00[0], b0 = q4 */ \ + "sdot v11.4s , v4.16b, v0.4b[1]\n" /* out1 = b0 * a00[1], b0 = q4 */ \ + "ldp q6, q7, [%[b_ptr]], #32\n" /* load b2, b0 to q6, q7 */ \ + "sdot v14.4s, v4.16b, v0.4b[2]\n" /* out2 = b0 * a00[2], b0 = q4 */ \ + "sdot v17.4s, v4.16b, v0.4b[3]\n" /* out3 = b0 * a00[3], b0 = q4 */ \ + "ldp q2, q3, [%[a_ptr]], #32\n" /* load a10, a11 to q3, q4 */ \ + "sdot v20.4s, v4.16b, v1.4b[0]\n" /* out4 = b0 * a01[0], b0 = q4 */ \ + "sdot v23.4s, v4.16b, v1.4b[1]\n" /* out5 = b0 * a01[1], b0 = q4 */ \ + "sdot v26.4s, v4.16b, v1.4b[2]\n" /* out6 = b0 * a01[2], b0 = q4 */ \ + "sdot v29.4s, v4.16b, v1.4b[3]\n" /* out7 = b0 * a01[3], b0 = q4 */ \ + "sdot v9.4s, v5.16b, v0.4b[0]\n" /* out8 = b1 * a00[0], b1 = q5 */ \ + "sdot v12.4s, v5.16b, v0.4b[1]\n" /* out9 = b1 * a00[1], b1 = q5 */ \ + "sdot v15.4s, v5.16b, v0.4b[2]\n" /* out10 = b1 * a00[2], b1 = q5*/ \ + "sdot v18.4s, v5.16b, v0.4b[3]\n" /* out11 = b1 * a00[3], b1 = q5*/ \ + "sdot v21.4s, v5.16b, v1.4b[0]\n" /* out12 = b1 * a01[0], b1 = q5*/ \ + "sdot v24.4s, v5.16b, v1.4b[1]\n" /* out13 = b1 * a01[1], b1 = q5*/ \ + "sdot v27.4s, v5.16b, v1.4b[2]\n" /* out14 = b1 * a01[2], b1 = q5*/ \ + "sdot v30.4s, v5.16b, v1.4b[3]\n" /* out15 = b1 * a01[3], b1 = q5*/ \ + "ldp q4, q5, [%[b_ptr]], #32\n" /* load b1, b2 to q4, q5 */ \ + "sdot v10.4s, v6.16b, v0.4b[0]\n" /* out16 = b2 * a00[0], b2 = q6*/ \ + "sdot v13.4s, v6.16b, v0.4b[1]\n" /* out17 = b2 * a00[1], b2 = q6*/ \ + "prfm pldl1keep, [%[b_ptr], #384]\n" \ + "sdot v16.4s, v6.16b, v0.4b[2]\n" /* out18 = b2 * a00[2], b2 = q6*/ \ + "sdot v19.4s, v6.16b, v0.4b[3]\n" /* out19 = b2 * a00[3], b2 = q6*/ \ + "sdot v22.4s, v6.16b, v1.4b[0]\n" /* out20 = b2 * a00[0], b2 = q6*/ \ + "sdot v25.4s, v6.16b, v1.4b[1]\n" /* out21 = b2 * a00[1], b2 = q6*/ \ + "sdot v28.4s, v6.16b, v1.4b[2]\n" /* out22 = b2 * a00[2], b2 = q6*/ \ + "sdot v31.4s, v6.16b, v1.4b[3]\n" /* out23 = b2 * a00[3], b2 = q6*/ \ + "ldp q0, q1, [%[a_ptr]], #32\n" /* load a00, a01 to q0, q1 */ \ + "sdot v8.4s , v7.16b, v2.4b[0]\n" /* out0 = b0 * a10[0], b0 = q7 */ \ + "sdot v11.4s , v7.16b, v2.4b[1]\n" /* out1 = b0 * a10[1], b0 = q7 */ \ + "sdot v14.4s, v7.16b, v2.4b[2]\n" /* out2 = b0 * a10[2], b0 = q7 */ \ + "prfm pldl1keep, [%[a_ptr], #256]\n" \ + "sdot v17.4s, v7.16b, v2.4b[3]\n" /* out3 = b0 * a10[3], b0 = q7 */ \ + "sdot v20.4s, v7.16b, v3.4b[0]\n" /* out4 = b0 * a11[0], b0 = q7 */ \ + "sdot v23.4s, v7.16b, v3.4b[1]\n" /* out5 = b0 * a11[1], b0 = q7 */ \ + "sdot v26.4s, v7.16b, v3.4b[2]\n" /* out6 = b0 * a11[2], b0 = q7 */ \ + "sdot v29.4s, v7.16b, v3.4b[3]\n" /* out7 = b0 * a11[3], b0 = q7 */ \ + "ldp q6, q7, [%[b_ptr]], #32\n" /* load b0, b1 to q6, q7 */ \ + "sdot v9.4s, v4.16b, v2.4b[0]\n" /* out8 = b0 * a10[0], b1 = q4 */ \ + "sdot v12.4s, v4.16b, v2.4b[1]\n" /* out9 = b0 * a10[1], b1 = q4 */ \ + "sdot v15.4s, v4.16b, v2.4b[2]\n" /* out10 = b1 * a10[2], b1 = q4*/ \ + "sdot v18.4s, v4.16b, v2.4b[3]\n" /* out11 = b1 * a10[3], b1 = q4*/ \ + "sdot v21.4s, v4.16b, v3.4b[0]\n" /* out12 = b1 * a10[0], b1 = q4*/ \ + "sdot v24.4s, v4.16b, v3.4b[1]\n" /* out13 = b1 
* a10[1], b1 = q4*/ \ + "sdot v27.4s, v4.16b, v3.4b[2]\n" /* out14 = b1 * a10[2], b1 = q4*/ \ + "sdot v30.4s, v4.16b, v3.4b[3]\n" /* out15 = b1 * a10[3], b1 = q4*/ \ + "sdot v10.4s, v5.16b, v2.4b[0]\n" /* out16 = b2 * a10[0], b2 = q5*/ \ + "sdot v13.4s, v5.16b, v2.4b[1]\n" /* out17 = b2 * a10[0], b2 = q5*/ \ + "sdot v16.4s, v5.16b, v2.4b[2]\n" /* out18 = b2 * a10[0], b2 = q5*/ \ + "sdot v19.4s, v5.16b, v2.4b[3]\n" /* out19 = b2 * a10[0], b2 = q5*/ \ + "sdot v22.4s, v5.16b, v3.4b[0]\n" /* out20 = b2 * a10[0], b2 = q5*/ \ + "sdot v25.4s, v5.16b, v3.4b[1]\n" /* out21 = b2 * a10[0], b2 = q5*/ \ + "sdot v28.4s, v5.16b, v3.4b[2]\n" /* out22 = b2 * a10[0], b2 = q5*/ \ + "sdot v31.4s, v5.16b, v3.4b[3]\n" /* out23 = b2 * a10[0], b2 = q5*/ \ + "ldp q4, q5, [%[b_ptr]], #32\n" /* load b2, b0 to q4, q5 */ \ + "sdot v8.4s , v6.16b, v0.4b[0]\n" /* out0 = b0 * a00[0], b0 = q6 */ \ + "sdot v11.4s , v6.16b, v0.4b[1]\n" /* out1 = b0 * a00[1], b0 = q6 */ \ + "ldp q2, q3, [%[a_ptr]], #32\n" /* load a10, a11 to q3, q4*/ \ + "sdot v14.4s, v6.16b, v0.4b[2]\n" /* out2 = b0 * a00[2], b0 = q6*/ \ + "sdot v17.4s, v6.16b, v0.4b[3]\n" /* out3 = b0 * a00[3], b0 = q6*/ \ + "sdot v20.4s, v6.16b, v1.4b[0]\n" /* out4 = b0 * a01[0], b0 = q6*/ \ + "sdot v23.4s, v6.16b, v1.4b[1]\n" /* out5 = b0 * a01[1], b0 = q6*/ \ + "sdot v26.4s, v6.16b, v1.4b[2]\n" /* out6 = b0 * a01[2], b0 = q6*/ \ + "sdot v29.4s, v6.16b, v1.4b[3]\n" /* out7 = b0 * a01[3], b0 = q6*/ \ + "sdot v9.4s, v7.16b, v0.4b[0]\n" /* out8 = b1 * a00[0], b1 = q7*/ \ + "sdot v12.4s, v7.16b, v0.4b[1]\n" /* out9 = b1 * a00[1], b1 = q7*/ \ + "prfm pldl1keep, [%[b_ptr], #384]\n" \ + "sdot v15.4s, v7.16b, v0.4b[2]\n" /* out10 = b1 * a00[2], b1 = q7*/ \ + "sdot v18.4s, v7.16b, v0.4b[3]\n" /* out11 = b1 * a00[3], b1 = q7*/ \ + "sdot v21.4s, v7.16b, v1.4b[0]\n" /* out12 = b1 * a01[0], b1 = q7*/ \ + "sdot v24.4s, v7.16b, v1.4b[1]\n" /* out13 = b1 * a01[1], b1 = q7*/ \ + "sdot v27.4s, v7.16b, v1.4b[2]\n" /* out14 = b1 * a01[2], b1 = q7*/ \ + "sdot v30.4s, v7.16b, v1.4b[3]\n" /* out15 = b1 * a01[3], b1 = q7*/ \ + "ldp q6, q7, [%[b_ptr]], #32\n" /* load b1, b2 to q6, q7*/ \ + "sdot v10.4s, v4.16b, v0.4b[0]\n" /* out16 = b2 * a00[0], b2 = q4*/ \ + "sdot v13.4s, v4.16b, v0.4b[1]\n" /* out17 = b2 * a00[1], b2 = q4*/ \ + "sdot v16.4s, v4.16b, v0.4b[2]\n" /* out18 = b2 * a00[2], b2 = q4*/ \ + "sdot v19.4s, v4.16b, v0.4b[3]\n" /* out19 = b2 * a00[3], b2 = q4*/ \ + "sdot v22.4s, v4.16b, v1.4b[0]\n" /* out20 = b2 * a00[0], b2 = q4*/ \ + "sdot v25.4s, v4.16b, v1.4b[1]\n" /* out21 = b2 * a00[1], b2 = q4*/ \ + "sdot v28.4s, v4.16b, v1.4b[2]\n" /* out22 = b2 * a00[2], b2 = q4*/ \ + "sdot v31.4s, v4.16b, v1.4b[3]\n" /* out23 = b2 * a00[3], b2 = q4*/ \ + "ldp q0, q1, [%[a_ptr]], #32\n" /* load a00, a01 */ /* unrool 3*/ \ + "sdot v8.4s , v5.16b, v2.4b[0]\n" /* out0 = b0 * a10[0], b0 = q5*/ \ + "sdot v11.4s , v5.16b, v2.4b[1]\n" /* out1 = b0 * a10[1], b0 = q5*/ \ + "sdot v14.4s, v5.16b, v2.4b[2]\n" /* out2 = b0 * a10[2], b0 = q5*/ \ + "sdot v17.4s, v5.16b, v2.4b[3]\n" /* out3 = b0 * a10[3], b0 = q5*/ \ + "sdot v20.4s, v5.16b, v3.4b[0]\n" /* out4 = b0 * a11[0], b0 = q5*/ \ + "sdot v23.4s, v5.16b, v3.4b[1]\n" /* out5 = b0 * a11[1], b0 = q5*/ \ + "sdot v26.4s, v5.16b, v3.4b[2]\n" /* out6 = b0 * a11[2], b0 = q5*/ \ + "sdot v29.4s, v5.16b, v3.4b[3]\n" /* out7 = b0 * a11[3], b0 = q5*/ \ + "ldp q4, q5, [%[b_ptr]], #32\n" /* load b0, b1 to q4, q5*/ \ + "sdot v9.4s, v6.16b, v2.4b[0]\n" /* out8 = b0 * a10[0], b1 = q6*/ \ + "sdot v12.4s, v6.16b, v2.4b[1]\n" /* out9 = b0 * a10[1], b1 = q6*/ \ + "prfm 
pldl1keep, [%[a_ptr], #256]\n" \ + "sdot v15.4s, v6.16b, v2.4b[2]\n" /* out10 = b1 * a10[2], b1 = q6*/ \ + "sdot v18.4s, v6.16b, v2.4b[3]\n" /* out11 = b1 * a10[3], b1 = q6*/ \ + "sdot v21.4s, v6.16b, v3.4b[0]\n" /* out12 = b1 * a10[0], b1 = q6*/ \ + "sdot v24.4s, v6.16b, v3.4b[1]\n" /* out13 = b1 * a10[1], b1 = q6*/ \ + "sdot v27.4s, v6.16b, v3.4b[2]\n" /* out14 = b1 * a10[2], b1 = q6*/ \ + "prfm pldl1keep, [%[b_ptr], #384]\n" \ + "sdot v30.4s, v6.16b, v3.4b[3]\n" /* out15 = b1 * a10[3], b1 = q6*/ \ + "sdot v10.4s, v7.16b, v2.4b[0]\n" /* out16 = b2 * a10[0], b2 = q7*/ \ + "sdot v13.4s, v7.16b, v2.4b[1]\n" /* out17 = b2 * a10[0], b2 = q7*/ \ + "sdot v16.4s, v7.16b, v2.4b[2]\n" /* out18 = b2 * a10[0], b2 = q7*/ \ + "sdot v19.4s, v7.16b, v2.4b[3]\n" /* out19 = b2 * a10[0], b2 = q7*/ \ + "sdot v22.4s, v7.16b, v3.4b[0]\n" /* out20 = b2 * a10[0], b2 = q7*/ \ + "sdot v25.4s, v7.16b, v3.4b[1]\n" /* out21 = b2 * a10[0], b2 = q7*/ \ + "subs %w[k], %w[k], #1\n" /* loop count - 1*/ \ + "sdot v28.4s, v7.16b, v3.4b[2]\n" /* out22 = b2 * a10[0], b2 = q7*/ \ + "sdot v31.4s, v7.16b, v3.4b[3]\n" /* out23 = b2 * a10[0], b2 = q7*/ \ + "bne 1b\n" \ + "2:\n" /* process tail*/ \ + "subs %w[tail], %w[tail], #1\n" /* tail--*/ \ + "beq 3f\n" \ + "sdot v8.4s , v4.16b, v0.4b[0]\n" /* out0 = b0 * a00[0], b0 = q4*/ \ + "sdot v11.4s , v4.16b, v0.4b[1]\n" /* out1 = b0 * a00[1], b0 = q4*/ \ + "ldp q6, q7, [%[b_ptr]], #32\n" /* load b2, b0 to q6, q7*/ \ + "sdot v14.4s, v4.16b, v0.4b[2]\n" /* out2 = b0 * a00[2], b0 = q4*/ \ + "sdot v17.4s, v4.16b, v0.4b[3]\n" /* out3 = b0 * a00[3], b0 = q4*/ \ + "ldp q2, q3, [%[a_ptr]], #32\n" /* load a10, a11 to q2, q3*/ \ + "sdot v20.4s, v4.16b, v1.4b[0]\n" /* out4 = b0 * a01[0], b0 = q4*/ \ + "sdot v23.4s, v4.16b, v1.4b[1]\n" /* out5 = b0 * a01[1], b0 = q4*/ \ + "sdot v26.4s, v4.16b, v1.4b[2]\n" /* out6 = b0 * a01[2], b0 = q4*/ \ + "sdot v29.4s, v4.16b, v1.4b[3]\n" /* out7 = b0 * a01[3], b0 = q4*/ \ + "subs %w[tail], %w[tail], #1\n" /* tail--*/ \ + "sdot v9.4s, v5.16b, v0.4b[0]\n" /* out8 = b1 * a00[0], b1 = q5*/ \ + "sdot v12.4s, v5.16b, v0.4b[1]\n" /* out9 = b1 * a00[1], b1 = q5*/ \ + "sdot v15.4s, v5.16b, v0.4b[2]\n" /* out10 = b1 * a00[2], b1 = q5*/ \ + "sdot v18.4s, v5.16b, v0.4b[3]\n" /* out11 = b1 * a00[3], b1 = q5*/ \ + "sdot v21.4s, v5.16b, v1.4b[0]\n" /* out12 = b1 * a01[0], b1 = q5*/ \ + "sdot v24.4s, v5.16b, v1.4b[1]\n" /* out13 = b1 * a01[1], b1 = q5*/ \ + "sdot v27.4s, v5.16b, v1.4b[2]\n" /* out14 = b1 * a01[2], b1 = q5*/ \ + "sdot v30.4s, v5.16b, v1.4b[3]\n" /* out15 = b1 * a01[3], b1 = q5*/ \ + "ldp q4, q5, [%[b_ptr]], #32\n" /* load b1, b2 to q4, q5*/ \ + "sdot v10.4s, v6.16b, v0.4b[0]\n" /* out16 = b2 * a00[0], b2 = q6*/ \ + "sdot v13.4s, v6.16b, v0.4b[1]\n" /* out17 = b2 * a00[1], b2 = q6*/ \ + "sdot v16.4s, v6.16b, v0.4b[2]\n" /* out18 = b2 * a00[2], b2 = q6*/ \ + "sdot v19.4s, v6.16b, v0.4b[3]\n" /* out19 = b2 * a00[3], b2 = q6*/ \ + "sdot v22.4s, v6.16b, v1.4b[0]\n" /* out20 = b2 * a00[0], b2 = q6*/ \ + "sdot v25.4s, v6.16b, v1.4b[1]\n" /* out21 = b2 * a00[1], b2 = q6*/ \ + "sdot v28.4s, v6.16b, v1.4b[2]\n" /* out22 = b2 * a00[2], b2 = q6*/ \ + "sdot v31.4s, v6.16b, v1.4b[3]\n" /* out23 = b2 * a00[3], b2 = q6*/ \ + "beq 4f\n" /*jump to tail = 2*/ /* unrool 1, tail > 2*/ \ + "ldp q0, q1, [%[a_ptr]], #32\n" /* load a00, a01 to q0, q1*/ \ + "sdot v8.4s , v7.16b, v2.4b[0]\n" /* out0 = b0 * a10[0], b0 = q7*/ \ + "sdot v11.4s , v7.16b, v2.4b[1]\n" /* out1 = b0 * a10[1], b0 = q7*/ \ + "sdot v14.4s, v7.16b, v2.4b[2]\n" /* out2 = b0 * a10[2], b0 = q7*/ \ + "sdot v17.4s, 
v7.16b, v2.4b[3]\n" /* out3 = b0 * a10[3], b0 = q7*/ \
+  "sdot v20.4s, v7.16b, v3.4b[0]\n" /* out4 = b0 * a11[0], b0 = q7*/ \
+  "sdot v23.4s, v7.16b, v3.4b[1]\n" /* out5 = b0 * a11[1], b0 = q7*/ \
+  "sdot v26.4s, v7.16b, v3.4b[2]\n" /* out6 = b0 * a11[2], b0 = q7*/ \
+  "sdot v29.4s, v7.16b, v3.4b[3]\n" /* out7 = b0 * a11[3], b0 = q7*/ \
+  "ldp q6, q7, [%[b_ptr]], #32\n"   /* load b0, b1 to q6, q7*/ \
+  "sdot v9.4s, v4.16b, v2.4b[0]\n"  /* out8 = b1 * a10[0], b1 = q4*/ \
+  "sdot v12.4s, v4.16b, v2.4b[1]\n" /* out9 = b1 * a10[1], b1 = q4*/ \
+  "sdot v15.4s, v4.16b, v2.4b[2]\n" /* out10 = b1 * a10[2], b1 = q4*/ \
+  "sdot v18.4s, v4.16b, v2.4b[3]\n" /* out11 = b1 * a10[3], b1 = q4*/ \
+  "sdot v21.4s, v4.16b, v3.4b[0]\n" /* out12 = b1 * a11[0], b1 = q4*/ \
+  "sdot v24.4s, v4.16b, v3.4b[1]\n" /* out13 = b1 * a11[1], b1 = q4*/ \
+  "sdot v27.4s, v4.16b, v3.4b[2]\n" /* out14 = b1 * a11[2], b1 = q4*/ \
+  "sdot v30.4s, v4.16b, v3.4b[3]\n" /* out15 = b1 * a11[3], b1 = q4*/ \
+  "subs %w[tail], %w[tail], #1\n"   /* tail--*/ \
+  "sdot v10.4s, v5.16b, v2.4b[0]\n" /* out16 = b2 * a10[0], b2 = q5*/ \
+  "sdot v13.4s, v5.16b, v2.4b[1]\n" /* out17 = b2 * a10[1], b2 = q5*/ \
+  "sdot v16.4s, v5.16b, v2.4b[2]\n" /* out18 = b2 * a10[2], b2 = q5*/ \
+  "sdot v19.4s, v5.16b, v2.4b[3]\n" /* out19 = b2 * a10[3], b2 = q5*/ \
+  "sdot v22.4s, v5.16b, v3.4b[0]\n" /* out20 = b2 * a11[0], b2 = q5*/ \
+  "sdot v25.4s, v5.16b, v3.4b[1]\n" /* out21 = b2 * a11[1], b2 = q5*/ \
+  "sdot v28.4s, v5.16b, v3.4b[2]\n" /* out22 = b2 * a11[2], b2 = q5*/ \
+  "sdot v31.4s, v5.16b, v3.4b[3]\n" /* out23 = b2 * a11[3], b2 = q5*/ \
+  "beq 5f\n" /* jump to tail = 3*/                /* unroll 2, tail = 4*/ \
+  "ldp q4, q5, [%[b_ptr]], #32\n"   /* load b2, b0 to q4, q5*/ \
+  "sdot v8.4s , v6.16b, v0.4b[0]\n" /* out0 = b0 * a00[0], b0 = q6*/ \
+  "sdot v11.4s , v6.16b, v0.4b[1]\n" /* out1 = b0 * a00[1], b0 = q6*/ \
+  "ldp q2, q3, [%[a_ptr]], #32\n"   /* load a10, a11 to q2, q3*/ \
+  "sdot v14.4s, v6.16b, v0.4b[2]\n" /* out2 = b0 * a00[2], b0 = q6*/ \
+  "sdot v17.4s, v6.16b, v0.4b[3]\n" /* out3 = b0 * a00[3], b0 = q6*/ \
+  "sdot v20.4s, v6.16b, v1.4b[0]\n" /* out4 = b0 * a01[0], b0 = q6*/ \
+  "sdot v23.4s, v6.16b, v1.4b[1]\n" /* out5 = b0 * a01[1], b0 = q6*/ \
+  "sdot v26.4s, v6.16b, v1.4b[2]\n" /* out6 = b0 * a01[2], b0 = q6*/ \
+  "sdot v29.4s, v6.16b, v1.4b[3]\n" /* out7 = b0 * a01[3], b0 = q6*/ \
+  "sdot v9.4s, v7.16b, v0.4b[0]\n"  /* out8 = b1 * a00[0], b1 = q7*/ \
+  "sdot v12.4s, v7.16b, v0.4b[1]\n" /* out9 = b1 * a00[1], b1 = q7*/ \
+  "sdot v15.4s, v7.16b, v0.4b[2]\n" /* out10 = b1 * a00[2], b1 = q7*/ \
+  "sdot v18.4s, v7.16b, v0.4b[3]\n" /* out11 = b1 * a00[3], b1 = q7*/ \
+  "sdot v21.4s, v7.16b, v1.4b[0]\n" /* out12 = b1 * a01[0], b1 = q7*/ \
+  "sdot v24.4s, v7.16b, v1.4b[1]\n" /* out13 = b1 * a01[1], b1 = q7*/ \
+  "sdot v27.4s, v7.16b, v1.4b[2]\n" /* out14 = b1 * a01[2], b1 = q7*/ \
+  "sdot v30.4s, v7.16b, v1.4b[3]\n" /* out15 = b1 * a01[3], b1 = q7*/ \
+  "ldp q6, q7, [%[b_ptr]], #32\n"   /* load b1, b2 to q6, q7*/ \
+  "sdot v10.4s, v4.16b, v0.4b[0]\n" /* out16 = b2 * a00[0], b2 = q4*/ \
+  "sdot v13.4s, v4.16b, v0.4b[1]\n" /* out17 = b2 * a00[1], b2 = q4*/ \
+  "sdot v16.4s, v4.16b, v0.4b[2]\n" /* out18 = b2 * a00[2], b2 = q4*/ \
+  "sdot v19.4s, v4.16b, v0.4b[3]\n" /* out19 = b2 * a00[3], b2 = q4*/ \
+  "sdot v22.4s, v4.16b, v1.4b[0]\n" /* out20 = b2 * a01[0], b2 = q4*/ \
+  "sdot v25.4s, v4.16b, v1.4b[1]\n" /* out21 = b2 * a01[1], b2 = q4*/ \
+  "sdot v28.4s, v4.16b, v1.4b[2]\n" /* out22 = b2 * a01[2], b2 = q4*/ \
+  "sdot v31.4s, v4.16b, v1.4b[3]\n" /* out23 = b2 * a01[3], b2 = q4*/ \
+  "sdot v8.4s , v5.16b, v2.4b[0]\n" /* out0 = b0 * a10[0], b0 = q5*/ \
+  "sdot v11.4s , v5.16b, v2.4b[1]\n" /* out1 = b0 * a10[1], b0 = q5*/ \
+  "sdot v14.4s, v5.16b, v2.4b[2]\n" /* out2 = b0 * a10[2], b0 = q5*/ \
+  "sdot v17.4s, v5.16b, v2.4b[3]\n" /* out3 = b0 * a10[3], b0 = q5*/ \
+  "sdot v20.4s, v5.16b, v3.4b[0]\n" /* out4 = b0 * a11[0], b0 = q5*/ \
+  "sdot v23.4s, v5.16b, v3.4b[1]\n" /* out5 = b0 * a11[1], b0 = q5*/ \
+  "sdot v26.4s, v5.16b, v3.4b[2]\n" /* out6 = b0 * a11[2], b0 = q5*/ \
+  "sdot v29.4s, v5.16b, v3.4b[3]\n" /* out7 = b0 * a11[3], b0 = q5*/ \
+  "sdot v9.4s, v6.16b, v2.4b[0]\n"  /* out8 = b1 * a10[0], b1 = q6*/ \
+  "sdot v12.4s, v6.16b, v2.4b[1]\n" /* out9 = b1 * a10[1], b1 = q6*/ \
+  "sdot v15.4s, v6.16b, v2.4b[2]\n" /* out10 = b1 * a10[2], b1 = q6*/ \
+  "sdot v18.4s, v6.16b, v2.4b[3]\n" /* out11 = b1 * a10[3], b1 = q6*/ \
+  "sdot v21.4s, v6.16b, v3.4b[0]\n" /* out12 = b1 * a11[0], b1 = q6*/ \
+  "sdot v24.4s, v6.16b, v3.4b[1]\n" /* out13 = b1 * a11[1], b1 = q6*/ \
+  "sdot v27.4s, v6.16b, v3.4b[2]\n" /* out14 = b1 * a11[2], b1 = q6*/ \
+  "sdot v30.4s, v6.16b, v3.4b[3]\n" /* out15 = b1 * a11[3], b1 = q6*/ \
+  "sdot v10.4s, v7.16b, v2.4b[0]\n" /* out16 = b2 * a10[0], b2 = q7*/ \
+  "sdot v13.4s, v7.16b, v2.4b[1]\n" /* out17 = b2 * a10[1], b2 = q7*/ \
+  "sdot v16.4s, v7.16b, v2.4b[2]\n" /* out18 = b2 * a10[2], b2 = q7*/ \
+  "sdot v19.4s, v7.16b, v2.4b[3]\n" /* out19 = b2 * a10[3], b2 = q7*/ \
+  "sdot v22.4s, v7.16b, v3.4b[0]\n" /* out20 = b2 * a11[0], b2 = q7*/ \
+  "sdot v25.4s, v7.16b, v3.4b[1]\n" /* out21 = b2 * a11[1], b2 = q7*/ \
+  "sdot v28.4s, v7.16b, v3.4b[2]\n" /* out22 = b2 * a11[2], b2 = q7*/ \
+  "sdot v31.4s, v7.16b, v3.4b[3]\n" /* out23 = b2 * a11[3], b2 = q7*/ \
+  "b 11f\n" /* tails==1 final tail*/ \
+  "3: \n"   /* tail=1*/ \
+  "ldr q6, [%[b_ptr]], #16\n" /* load b2 to q6*/ \
+  "sdot v8.4s , v4.16b, v0.4b[0]\n" /* out0 = b0 * a00[0], b0 = q4*/ \
+  "sdot v11.4s , v4.16b, v0.4b[1]\n" /* out1 = b0 * a00[1], b0 = q4*/ \
+  "sdot v14.4s, v4.16b, v0.4b[2]\n" /* out2 = b0 * a00[2], b0 = q4*/ \
+  "sdot v17.4s, v4.16b, v0.4b[3]\n" /* out3 = b0 * a00[3], b0 = q4*/ \
+  "sdot v20.4s, v4.16b, v1.4b[0]\n" /* out4 = b0 * a01[0], b0 = q4*/ \
+  "sdot v23.4s, v4.16b, v1.4b[1]\n" /* out5 = b0 * a01[1], b0 = q4*/ \
+  "sdot v26.4s, v4.16b, v1.4b[2]\n" /* out6 = b0 * a01[2], b0 = q4*/ \
+  "sdot v29.4s, v4.16b, v1.4b[3]\n" /* out7 = b0 * a01[3], b0 = q4*/ \
+  "sdot v9.4s, v5.16b, v0.4b[0]\n"  /* out8 = b1 * a00[0], b1 = q5*/ \
+  "sdot v12.4s, v5.16b, v0.4b[1]\n" /* out9 = b1 * a00[1], b1 = q5*/ \
+  "sdot v15.4s, v5.16b, v0.4b[2]\n" /* out10 = b1 * a00[2], b1 = q5*/ \
+  "sdot v18.4s, v5.16b, v0.4b[3]\n" /* out11 = b1 * a00[3], b1 = q5*/ \
+  "sdot v21.4s, v5.16b, v1.4b[0]\n" /* out12 = b1 * a01[0], b1 = q5*/ \
+  "sdot v24.4s, v5.16b, v1.4b[1]\n" /* out13 = b1 * a01[1], b1 = q5*/ \
+  "sdot v27.4s, v5.16b, v1.4b[2]\n" /* out14 = b1 * a01[2], b1 = q5*/ \
+  "sdot v30.4s, v5.16b, v1.4b[3]\n" /* out15 = b1 * a01[3], b1 = q5*/ \
+  "sdot v10.4s, v6.16b, v0.4b[0]\n" /* out16 = b2 * a00[0], b2 = q6*/ \
+  "sdot v13.4s, v6.16b, v0.4b[1]\n" /* out17 = b2 * a00[1], b2 = q6*/ \
+  "sdot v16.4s, v6.16b, v0.4b[2]\n" /* out18 = b2 * a00[2], b2 = q6*/ \
+  "sdot v19.4s, v6.16b, v0.4b[3]\n" /* out19 = b2 * a00[3], b2 = q6*/ \
+  "sdot v22.4s, v6.16b, v1.4b[0]\n" /* out20 = b2 * a01[0], b2 = q6*/ \
+  "sdot v25.4s, v6.16b, v1.4b[1]\n" /* out21 = b2 * a01[1], b2 = q6*/ \
+  "sdot v28.4s, v6.16b, v1.4b[2]\n" /* out22 = b2 * a01[2], b2 = q6*/ \
+  "sdot v31.4s, v6.16b, v1.4b[3]\n" /* out23 = b2 * a01[3], b2 = q6*/ \
+  "b 11f\n" /* tails==2 final tail*/ \
+  "4:\n" /* tail = 2*/ \
+  "sdot v8.4s , v7.16b, v2.4b[0]\n" /* out0 = b0 * a10[0], b0 = q7*/ \
+  "sdot v11.4s , v7.16b, v2.4b[1]\n" /* out1 = b0 * a10[1], b0 = q7*/ \
+  "sdot v14.4s, v7.16b, v2.4b[2]\n" /* out2 = b0 * a10[2], b0 = q7*/ \
+  "sdot v17.4s, v7.16b, v2.4b[3]\n" /* out3 = b0 * a10[3], b0 = q7*/ \
+  "sdot v20.4s, v7.16b, v3.4b[0]\n" /* out4 = b0 * a11[0], b0 = q7*/ \
+  "sdot v23.4s, v7.16b, v3.4b[1]\n" /* out5 = b0 * a11[1], b0 = q7*/ \
+  "sdot v26.4s, v7.16b, v3.4b[2]\n" /* out6 = b0 * a11[2], b0 = q7*/ \
+  "sdot v29.4s, v7.16b, v3.4b[3]\n" /* out7 = b0 * a11[3], b0 = q7*/ \
+  "sdot v9.4s, v4.16b, v2.4b[0]\n"  /* out8 = b1 * a10[0], b1 = q4*/ \
+  "sdot v12.4s, v4.16b, v2.4b[1]\n" /* out9 = b1 * a10[1], b1 = q4*/ \
+  "sdot v15.4s, v4.16b, v2.4b[2]\n" /* out10 = b1 * a10[2], b1 = q4*/ \
+  "sdot v18.4s, v4.16b, v2.4b[3]\n" /* out11 = b1 * a10[3], b1 = q4*/ \
+  "sdot v21.4s, v4.16b, v3.4b[0]\n" /* out12 = b1 * a11[0], b1 = q4*/ \
+  "sdot v24.4s, v4.16b, v3.4b[1]\n" /* out13 = b1 * a11[1], b1 = q4*/ \
+  "sdot v27.4s, v4.16b, v3.4b[2]\n" /* out14 = b1 * a11[2], b1 = q4*/ \
+  "sdot v30.4s, v4.16b, v3.4b[3]\n" /* out15 = b1 * a11[3], b1 = q4*/ \
+  "sdot v10.4s, v5.16b, v2.4b[0]\n" /* out16 = b2 * a10[0], b2 = q5*/ \
+  "sdot v13.4s, v5.16b, v2.4b[1]\n" /* out17 = b2 * a10[1], b2 = q5*/ \
+  "sdot v16.4s, v5.16b, v2.4b[2]\n" /* out18 = b2 * a10[2], b2 = q5*/ \
+  "sdot v19.4s, v5.16b, v2.4b[3]\n" /* out19 = b2 * a10[3], b2 = q5*/ \
+  "sdot v22.4s, v5.16b, v3.4b[0]\n" /* out20 = b2 * a11[0], b2 = q5*/ \
+  "sdot v25.4s, v5.16b, v3.4b[1]\n" /* out21 = b2 * a11[1], b2 = q5*/ \
+  "sdot v28.4s, v5.16b, v3.4b[2]\n" /* out22 = b2 * a11[2], b2 = q5*/ \
+  "sdot v31.4s, v5.16b, v3.4b[3]\n" /* out23 = b2 * a11[3], b2 = q5*/ \
+  "b 11f\n" /* tails==3 final tail*/ \
+  "5:\n" /* tail = 3*/ \
+  "ldr q4, [%[b_ptr]], #16\n" /* load b2 to q4*/ \
+  "sdot v8.4s , v6.16b, v0.4b[0]\n" /* out0 = b0 * a00[0], b0 = q6*/ \
+  "sdot v11.4s , v6.16b, v0.4b[1]\n" /* out1 = b0 * a00[1], b0 = q6*/ \
+  "sdot v14.4s, v6.16b, v0.4b[2]\n" /* out2 = b0 * a00[2], b0 = q6*/ \
+  "sdot v17.4s, v6.16b, v0.4b[3]\n" /* out3 = b0 * a00[3], b0 = q6*/ \
+  "sdot v20.4s, v6.16b, v1.4b[0]\n" /* out4 = b0 * a01[0], b0 = q6*/ \
+  "sdot v23.4s, v6.16b, v1.4b[1]\n" /* out5 = b0 * a01[1], b0 = q6*/ \
+  "sdot v26.4s, v6.16b, v1.4b[2]\n" /* out6 = b0 * a01[2], b0 = q6*/ \
+  "sdot v29.4s, v6.16b, v1.4b[3]\n" /* out7 = b0 * a01[3], b0 = q6*/ \
+  "sdot v9.4s, v7.16b, v0.4b[0]\n"  /* out8 = b1 * a00[0], b1 = q7*/ \
+  "sdot v12.4s, v7.16b, v0.4b[1]\n" /* out9 = b1 * a00[1], b1 = q7*/ \
+  "sdot v15.4s, v7.16b, v0.4b[2]\n" /* out10 = b1 * a00[2], b1 = q7*/ \
+  "sdot v18.4s, v7.16b, v0.4b[3]\n" /* out11 = b1 * a00[3], b1 = q7*/ \
+  "sdot v21.4s, v7.16b, v1.4b[0]\n" /* out12 = b1 * a01[0], b1 = q7*/ \
+  "sdot v24.4s, v7.16b, v1.4b[1]\n" /* out13 = b1 * a01[1], b1 = q7*/ \
+  "sdot v27.4s, v7.16b, v1.4b[2]\n" /* out14 = b1 * a01[2], b1 = q7*/ \
+  "sdot v30.4s, v7.16b, v1.4b[3]\n" /* out15 = b1 * a01[3], b1 = q7*/ \
+  "sdot v10.4s, v4.16b, v0.4b[0]\n" /* out16 = b2 * a00[0], b2 = q4*/ \
+  "sdot v13.4s, v4.16b, v0.4b[1]\n" /* out17 = b2 * a00[1], b2 = q4*/ \
+  "sdot v16.4s, v4.16b, v0.4b[2]\n" /* out18 = b2 * a00[2], b2 = q4*/ \
+  "sdot v19.4s, v4.16b, v0.4b[3]\n" /* out19 = b2 * a00[3], b2 = q4*/ \
+  "sdot v22.4s, v4.16b, v1.4b[0]\n" /* out20 = b2 * a01[0], b2 = q4*/ \
+  "sdot v25.4s, v4.16b, v1.4b[1]\n" /* out21 = b2 * a01[1], b2 = q4*/ \
+  "sdot v28.4s, v4.16b, v1.4b[2]\n" /* out22 = b2 * a01[2], b2 = q4*/ \
+  "sdot v31.4s, v4.16b, v1.4b[3]\n" /* out23 = b2 * a01[3], b2 = q4*/ \
+  "11: \n" /* check if relu */ \
+  "cbz %w[relu], 12f\n" /* skip relu */ \
+  "movi v2.4s, #0\n" /* for relu*/ \
+  "smax v8.4s, v8.4s, v2.4s\n"   /* relu*/ \
+  "smax v9.4s, v9.4s, v2.4s\n"   /* relu*/ \
+  "smax v10.4s, v10.4s, v2.4s\n" /* relu*/ \
+  "smax v11.4s, v11.4s, v2.4s\n" /* relu*/ \
+  "smax v12.4s, v12.4s, v2.4s\n" /* relu*/ \
+  "smax v13.4s, v13.4s, v2.4s\n" /* relu*/ \
+  "smax v14.4s, v14.4s, v2.4s\n" /* relu*/ \
+  "smax v15.4s, v15.4s, v2.4s\n" /* relu*/ \
+  "smax v16.4s, v16.4s, v2.4s\n" /* relu*/ \
+  "smax v17.4s, v17.4s, v2.4s\n" /* relu*/ \
+  "smax v18.4s, v18.4s, v2.4s\n" /* relu*/ \
+  "smax v19.4s, v19.4s, v2.4s\n" /* relu*/ \
+  "smax v20.4s, v20.4s, v2.4s\n" /* relu*/ \
+  "smax v21.4s, v21.4s, v2.4s\n" /* relu*/ \
+  "smax v22.4s, v22.4s, v2.4s\n" /* relu*/ \
+  "smax v23.4s, v23.4s, v2.4s\n" /* relu*/ \
+  "smax v24.4s, v24.4s, v2.4s\n" /* relu*/ \
+  "smax v25.4s, v25.4s, v2.4s\n" /* relu*/ \
+  "smax v26.4s, v26.4s, v2.4s\n" /* relu*/ \
+  "smax v27.4s, v27.4s, v2.4s\n" /* relu*/ \
+  "smax v28.4s, v28.4s, v2.4s\n" /* relu*/ \
+  "smax v29.4s, v29.4s, v2.4s\n" /* relu*/ \
+  "smax v30.4s, v30.4s, v2.4s\n" /* relu*/ \
+  "smax v31.4s, v31.4s, v2.4s\n" /* relu*/ \
+  "12: \n"
+
+#define GEMM_SDOT_INT32_OUT \
+  "st1 {v8.4s, v9.4s, v10.4s},[%[c_ptr0]], #48\n"   /* store r0 */ \
+  "st1 {v11.4s, v12.4s, v13.4s},[%[c_ptr1]], #48\n" /* store r1 */ \
+  "st1 {v14.4s, v15.4s, v16.4s},[%[c_ptr2]], #48\n" /* store r2 */ \
+  "st1 {v17.4s, v18.4s, v19.4s},[%[c_ptr3]], #48\n" /* store r3 */ \
+  "st1 {v20.4s, v21.4s, v22.4s},[%[c_ptr4]], #48\n" /* store r4 */ \
+  "st1 {v23.4s, v24.4s, v25.4s},[%[c_ptr5]], #48\n" /* store r5 */ \
+  "st1 {v26.4s, v27.4s, v28.4s},[%[c_ptr6]], #48\n" /* store r6 */ \
+  "st1 {v29.4s, v30.4s, v31.4s},[%[c_ptr7]], #48\n" /* store r7 */
+
+#define GEMM_SDOT_FP32_OUT \
+  "ldp q0, q1, [%[scale]]\n"      /* load scale */ \
+  "scvtf v2.4s , v8.4s\n"         /* 00, convert to fp32 */ \
+  "scvtf v3.4s , v9.4s\n"         /* 01, convert to fp32 */ \
+  "scvtf v4.4s , v10.4s\n"        /* 02, convert to fp32 */ \
+  "scvtf v5.4s , v11.4s\n"        /* 03, convert to fp32 */ \
+  "scvtf v6.4s , v12.4s\n"        /* 00, convert to fp32 */ \
+  "scvtf v7.4s , v13.4s\n"        /* 00, convert to fp32 */ \
+  "fmul v8.4s, v2.4s, v0.s[0]\n"  /* 00, mul scale to get final */ \
+  "fmul v9.4s, v3.4s, v0.s[0]\n"  /* 00, mul scale to get final */ \
+  "fmul v10.4s, v4.4s, v0.s[0]\n" /* 00, mul scale to get final */ \
+  "fmul v11.4s, v5.4s, v0.s[1]\n" /* 00, mul scale to get final */ \
+  "fmul v12.4s, v6.4s, v0.s[1]\n" /* 00, mul scale to get final */ \
+  "fmul v13.4s, v7.4s, v0.s[1]\n" /* 00, mul scale to get final */ \
+  "scvtf v2.4s , v14.4s\n"        /* 00, convert to fp32 */ \
+  "scvtf v3.4s , v15.4s\n"        /* 01, convert to fp32 */ \
+  "scvtf v4.4s , v16.4s\n"        /* 02, convert to fp32 */ \
+  "scvtf v5.4s , v17.4s\n"        /* 03, convert to fp32 */ \
+  "scvtf v6.4s , v18.4s\n"        /* 00, convert to fp32 */ \
+  "scvtf v7.4s , v19.4s\n"        /* 00, convert to fp32 */ \
+  "st1 {v8.4s, v9.4s, v10.4s},[%[c_ptr0]], #48\n"   /* store r0 */ \
+  "st1 {v11.4s, v12.4s, v13.4s},[%[c_ptr1]], #48\n" /* store r1 */ \
+  "fmul v14.4s, v2.4s, v0.s[2]\n" /* 00, mul scale to get final */ \
+  "fmul v15.4s, v3.4s, v0.s[2]\n" /* 00, mul scale to get final */ \
+  "fmul v16.4s, v4.4s, v0.s[2]\n" /* 00, mul scale to get final */ \
+  "fmul v17.4s, v5.4s, v0.s[3]\n" /* 00, mul scale to get final */ \
+  "fmul v18.4s, v6.4s, v0.s[3]\n" /* 00, mul scale to get final */ \
+  "fmul v19.4s, v7.4s, v0.s[3]\n" /* 00, mul scale to get final
*/ \ + "scvtf v2.4s , v20.4s\n" /* 00, convert to fp32 */ \ + "scvtf v3.4s , v21.4s\n" /* 01, convert to fp32 */ \ + "scvtf v4.4s , v22.4s\n" /* 02, convert to fp32 */ \ + "scvtf v5.4s , v23.4s\n" /* 03, convert to fp32 */ \ + "scvtf v6.4s , v24.4s\n" /* 00, convert to fp32 */ \ + "scvtf v7.4s , v25.4s\n" /* 00, convert to fp32 */ \ + "st1 {v14.4s, v15.4s, v16.4s},[%[c_ptr2]], #48\n" /* store r2 */ \ + "st1 {v17.4s, v18.4s, v19.4s},[%[c_ptr3]], #48\n" /* store r3 */ \ + "fmul v20.4s, v2.4s, v1.s[0]\n" /* 00, mul scale to get final */ \ + "fmul v21.4s, v3.4s, v1.s[0]\n" /* 00, mul scale to get final */ \ + "fmul v22.4s, v4.4s, v1.s[0]\n" /* 00, mul scale to get final */ \ + "fmul v23.4s, v5.4s, v1.s[1]\n" /* 00, mul scale to get final */ \ + "fmul v24.4s, v6.4s, v1.s[1]\n" /* 00, mul scale to get final */ \ + "fmul v25.4s, v7.4s, v1.s[1]\n" /* 00, mul scale to get final */ \ + "scvtf v2.4s , v26.4s\n" /* 00, convert to fp32 */ \ + "scvtf v3.4s , v27.4s\n" /* 01, convert to fp32 */ \ + "scvtf v4.4s , v28.4s\n" /* 02, convert to fp32 */ \ + "scvtf v5.4s , v29.4s\n" /* 03, convert to fp32 */ \ + "scvtf v6.4s , v30.4s\n" /* 00, convert to fp32 */ \ + "scvtf v7.4s , v31.4s\n" /* 00, convert to fp32 */ \ + "st1 {v20.4s, v21.4s, v22.4s},[%[c_ptr4]], #48\n" /* store r4 */ \ + "st1 {v23.4s, v24.4s, v25.4s},[%[c_ptr5]], #48\n" /* store r5 */ \ + "fmul v26.4s, v2.4s, v1.s[2]\n" /* 00, mul scale to get final */ \ + "fmul v27.4s, v3.4s, v1.s[2]\n" /* 00, mul scale to get final */ \ + "fmul v28.4s, v4.4s, v1.s[2]\n" /* 00, mul scale to get final */ \ + "fmul v29.4s, v5.4s, v1.s[3]\n" /* 00, mul scale to get final */ \ + "fmul v30.4s, v6.4s, v1.s[3]\n" /* 00, mul scale to get final */ \ + "fmul v31.4s, v7.4s, v1.s[3]\n" /* 00, mul scale to get final */ \ + "st1 {v26.4s, v27.4s, v28.4s},[%[c_ptr6]], #48\n" /* store r6 */ \ + "st1 {v29.4s, v30.4s, v31.4s},[%[c_ptr7]], #48\n" /* store r7 */ + +#define GEMM_SDOT_INT8_OUT \ + "ldp q0, q1, [%[scale]]\n" /* load scale */ \ + "scvtf v2.4s , v8.4s\n" /* 00, convert to fp32 */ \ + "scvtf v3.4s , v9.4s\n" /* 01, convert to fp32 */ \ + "scvtf v4.4s , v10.4s\n" /* 02, convert to fp32 */ \ + "scvtf v5.4s , v11.4s\n" /* 03, convert to fp32 */ \ + "scvtf v6.4s , v12.4s\n" /* 00, convert to fp32 */ \ + "scvtf v7.4s , v13.4s\n" /* 00, convert to fp32 */ \ + "fmul v8.4s, v2.4s, v0.s[0]\n" /* 00, mul scale to get final*/ \ + "fmul v9.4s, v3.4s, v0.s[0]\n" /* 00, mul scale to get final*/ \ + "fmul v10.4s, v4.4s, v0.s[0]\n" /* 00, mul scale to get final*/ \ + "fmul v11.4s, v5.4s, v0.s[1]\n" /* 00, mul scale to get final*/ \ + "fmul v12.4s, v6.4s, v0.s[1]\n" /* 00, mul scale to get final*/ \ + "fmul v13.4s, v7.4s, v0.s[1]\n" /* 00, mul scale to get final*/ \ + "scvtf v2.4s , v14.4s\n" /* 00, convert to fp32 */ \ + "scvtf v3.4s , v15.4s\n" /* 01, convert to fp32 */ \ + "scvtf v4.4s , v16.4s\n" /* 02, convert to fp32 */ \ + "scvtf v5.4s , v17.4s\n" /* 03, convert to fp32 */ \ + "scvtf v6.4s , v18.4s\n" /* 00, convert to fp32 */ \ + "scvtf v7.4s , v19.4s\n" /* 00, convert to fp32 */ \ + "fmul v14.4s, v2.4s, v0.s[2]\n" /* 00, mul scale to get final*/ \ + "fmul v15.4s, v3.4s, v0.s[2]\n" /* 00, mul scale to get final*/ \ + "fmul v16.4s, v4.4s, v0.s[2]\n" /* 00, mul scale to get final*/ \ + "fmul v17.4s, v5.4s, v0.s[3]\n" /* 00, mul scale to get final*/ \ + "fmul v18.4s, v6.4s, v0.s[3]\n" /* 00, mul scale to get final*/ \ + "fmul v19.4s, v7.4s, v0.s[3]\n" /* 00, mul scale to get final*/ \ + "scvtf v2.4s , v20.4s\n" /* 00, convert to fp32 */ \ + "scvtf v3.4s , v21.4s\n" /* 01, convert 
to fp32 */ \ + "scvtf v4.4s , v22.4s\n" /* 02, convert to fp32 */ \ + "scvtf v5.4s , v23.4s\n" /* 03, convert to fp32 */ \ + "scvtf v6.4s , v24.4s\n" /* 00, convert to fp32 */ \ + "scvtf v7.4s , v25.4s\n" /* 00, convert to fp32 */ \ + "fmul v20.4s, v2.4s, v1.s[0]\n" /* 00, mul scale to get final*/ \ + "fmul v21.4s, v3.4s, v1.s[0]\n" /* 00, mul scale to get final*/ \ + "fmul v22.4s, v4.4s, v1.s[0]\n" /* 00, mul scale to get final*/ \ + "fmul v23.4s, v5.4s, v1.s[1]\n" /* 00, mul scale to get final*/ \ + "fmul v24.4s, v6.4s, v1.s[1]\n" /* 00, mul scale to get final*/ \ + "fmul v25.4s, v7.4s, v1.s[1]\n" /* 00, mul scale to get final*/ \ + "scvtf v2.4s , v26.4s\n" /* 00, convert to fp32 */ \ + "scvtf v3.4s , v27.4s\n" /* 01, convert to fp32 */ \ + "scvtf v4.4s , v28.4s\n" /* 02, convert to fp32 */ \ + "scvtf v5.4s , v29.4s\n" /* 03, convert to fp32 */ \ + "scvtf v6.4s , v30.4s\n" /* 00, convert to fp32 */ \ + "scvtf v7.4s , v31.4s\n" /* 00, convert to fp32 */ \ + "fmul v26.4s, v2.4s, v1.s[2]\n" /* 00, mul scale to get final*/ \ + "fmul v27.4s, v3.4s, v1.s[2]\n" /* 00, mul scale to get final*/ \ + "fmul v28.4s, v4.4s, v1.s[2]\n" /* 00, mul scale to get final*/ \ + "fmul v29.4s, v5.4s, v1.s[3]\n" /* 00, mul scale to get final*/ \ + "fmul v30.4s, v6.4s, v1.s[3]\n" /* 00, mul scale to get final*/ \ + "fmul v31.4s, v7.4s, v1.s[3]\n" /* 00, mul scale to get final*/ \ + "fcvtas v0.4s, v8.4s\n" /* 00, cvt to int */ \ + "fcvtas v1.4s, v9.4s\n" /* 00, cvt to int */ \ + "fcvtas v2.4s, v10.4s\n" /* 00, cvt to int */ \ + "fcvtas v3.4s, v11.4s\n" /* 00, cvt to int */ \ + "fcvtas v4.4s, v12.4s\n" /* 00, cvt to int */ \ + "fcvtas v5.4s, v13.4s\n" /* 00, cvt to int */ \ + "sqxtn v8.4h, v0.4s\n" /* 00, cvt int32 to int16 */ \ + "sqxtn2 v8.8h, v1.4s\n" /* 00, cvt int32 to int16 */ \ + "sqxtn v9.4h, v2.4s\n" /* 00, cvt int32 to int16 */ \ + "fcvtas v0.4s, v14.4s\n" /* 00, cvt to int */ \ + "fcvtas v1.4s, v15.4s\n" /* 00, cvt to int */ \ + "fcvtas v2.4s, v16.4s\n" /* 00, cvt to int */ \ + "sqxtn v11.4h, v3.4s\n" /* 00, cvt int32 to int16 */ \ + "sqxtn2 v11.8h, v4.4s\n" /* 00, cvt int32 to int16 */ \ + "sqxtn v12.4h, v5.4s\n" /* 00, cvt int32 to int16 */ \ + "fcvtas v3.4s, v17.4s\n" /* 00, cvt to int */ \ + "fcvtas v4.4s, v18.4s\n" /* 00, cvt to int */ \ + "fcvtas v5.4s, v19.4s\n" /* 00, cvt to int */ \ + "sqxtn v14.4h, v0.4s\n" /* 00, cvt int32 to int16 */ \ + "sqxtn2 v14.8h, v1.4s\n" /* 00, cvt int32 to int16 */ \ + "sqxtn v15.4h, v2.4s\n" /* 00, cvt int32 to int16 */ \ + "fcvtas v0.4s, v20.4s\n" /* 00, cvt to int */ \ + "fcvtas v1.4s, v21.4s\n" /* 00, cvt to int */ \ + "fcvtas v2.4s, v22.4s\n" /* 00, cvt to int */ \ + "sqxtn v17.4h, v3.4s\n" /* 00, cvt int32 to int16 */ \ + "sqxtn2 v17.8h, v4.4s\n" /* 00, cvt int32 to int16 */ \ + "sqxtn v18.4h, v5.4s\n" /* 00, cvt int32 to int16 */ \ + "fcvtas v3.4s, v23.4s\n" /* 00, cvt to int */ \ + "fcvtas v4.4s, v24.4s\n" /* 00, cvt to int */ \ + "fcvtas v5.4s, v25.4s\n" /* 00, cvt to int */ \ + "sqxtn v20.4h, v0.4s\n" /* 00, cvt int32 to int16 */ \ + "sqxtn2 v20.8h, v1.4s\n" /* 00, cvt int32 to int16 */ \ + "sqxtn v21.4h, v2.4s\n" /* 00, cvt int32 to int16 */ \ + "fcvtas v0.4s, v26.4s\n" /* 00, cvt to int */ \ + "fcvtas v1.4s, v27.4s\n" /* 00, cvt to int */ \ + "fcvtas v2.4s, v28.4s\n" /* 00, cvt to int */ \ + "sqxtn v23.4h, v3.4s\n" /* 00, cvt int32 to int16 */ \ + "sqxtn2 v23.8h, v4.4s\n" /* 00, cvt int32 to int16 */ \ + "sqxtn v24.4h, v5.4s\n" /* 00, cvt int32 to int16 */ \ + "fcvtas v3.4s, v29.4s\n" /* 00, cvt to int */ \ + "fcvtas v4.4s, v30.4s\n" /* 00, cvt to int */ 
\ + "fcvtas v5.4s, v31.4s\n" /* 00, cvt to int */ \ + "sqxtn v26.4h, v0.4s\n" /* 00, cvt int32 to int16 */ \ + "sqxtn2 v26.8h, v1.4s\n" /* 00, cvt int32 to int16 */ \ + "sqxtn v27.4h, v2.4s\n" /* 00, cvt int32 to int16 */ \ + "sqxtn v29.4h, v3.4s\n" /* 00, cvt int32 to int16 */ \ + "sqxtn2 v29.8h, v4.4s\n" /* 00, cvt int32 to int16 */ \ + "sqxtn v30.4h, v5.4s\n" /* 00, cvt int32 to int16 */ \ + "sqxtn v4.8b, v8.8h\n" /* 00, 01, cvt int16 to int8 */ \ + "sqxtn v0.8b, v9.8h\n" /* 00, 01, cvt int16 to int8 */ \ + "sqxtn v5.8b, v11.8h\n" /* 00, 01, cvt int16 to int8 */ \ + "sqxtn v1.8b, v12.8h\n" /* 00, 01, cvt int16 to int8 */ \ + "sqxtn v6.8b, v14.8h\n" /* 00, 01, cvt int16 to int8 */ \ + "sqxtn v2.8b, v15.8h\n" /* 00, 01, cvt int16 to int8 */ \ + "sqxtn v7.8b, v17.8h\n" /* 00, 01, cvt int16 to int8 */ \ + "sqxtn v3.8b, v18.8h\n" /* 00, 01, cvt int16 to int8 */ \ + "sqxtn v16.8b, v20.8h\n" /* 00, 01, cvt int16 to int8 */ \ + "sqxtn v15.8b, v21.8h\n" /* 00, 01, cvt int16 to int8 */ \ + "sqxtn v20.8b, v23.8h\n" /* 00, 01, cvt int16 to int8 */ \ + "sqxtn v17.8b, v24.8h\n" /* 00, 01, cvt int16 to int8 */ \ + "sqxtn v24.8b, v26.8h\n" /* 00, 01, cvt int16 to int8 */ \ + "sqxtn v18.8b, v27.8h\n" /* 00, 01, cvt int16 to int8 */ \ + "sqxtn v28.8b, v29.8h\n" /* 00, 01, cvt int16 to int8 */ \ + "sqxtn v19.8b, v30.8h\n" /* 00, 01, cvt int16 to int8 */ \ + "st1 {v4.8b},[%[c_ptr0]], #8\n" /* store r0 */ \ + "st1 {v5.8b},[%[c_ptr1]], #8\n" /* store r0 */ \ + "st1 {v6.8b},[%[c_ptr2]], #8\n" /* store r0 */ \ + "st1 {v7.8b},[%[c_ptr3]], #8\n" /* store r0 */ \ + "st1 {v16.8b},[%[c_ptr4]], #8\n" /* store r0 */ \ + "st1 {v20.8b},[%[c_ptr5]], #8\n" /* store r0 */ \ + "st1 {v24.8b},[%[c_ptr6]], #8\n" /* store r0 */ \ + "st1 {v28.8b},[%[c_ptr7]], #8\n" /* store r0 */ \ + "str s0,[%[c_ptr0]], #4\n" /* store r0 */ \ + "str s1,[%[c_ptr1]], #4\n" /* store r0 */ \ + "str s2,[%[c_ptr2]], #4\n" /* store r0 */ \ + "str s3,[%[c_ptr3]], #4\n" /* store r0 */ \ + "str s15,[%[c_ptr4]], #4\n" /* store r0 */ \ + "str s17,[%[c_ptr5]], #4\n" /* store r0 */ \ + "str s18,[%[c_ptr6]], #4\n" /* store r0 */ \ + "str s19,[%[c_ptr7]], #4\n" /* store r0 */ + +template <> +inline void sgemm_sdot_int8_kernel(const int8_t* a_ptr, + const int8_t*& b_ptr, // NOLINT + const int32_t* bias, + int32_t*& c_ptr0, // NOLINT + int32_t*& c_ptr1, // NOLINT + int32_t*& c_ptr2, // NOLINT + int32_t*& c_ptr3, // NOLINT + int32_t*& c_ptr4, // NOLINT + int32_t*& c_ptr5, // NOLINT + int32_t*& c_ptr6, // NOLINT + int32_t*& c_ptr7, // NOLINT + const float32_t* scale, + bool is_relu, + int k, + int tail) { + asm volatile(_DECLARE_SDOT_ELEMENT GEMM_SDOT_INT8_KERNEL GEMM_SDOT_INT32_OUT + : [a_ptr] "+r"(a_ptr), + [b_ptr] "+r"(b_ptr), + [k] "+r"(k), + [tail] "+r"(tail), + [c_ptr0] "+r"(c_ptr0), + [c_ptr1] "+r"(c_ptr1), + [c_ptr2] "+r"(c_ptr2), + [c_ptr3] "+r"(c_ptr3), + [c_ptr4] "+r"(c_ptr4), + [c_ptr5] "+r"(c_ptr5), + [c_ptr6] "+r"(c_ptr6), + [c_ptr7] "+r"(c_ptr7) + : [bias_ptr] "r"(bias), [scale] "r"(scale), [relu] "r"(is_relu) + : "cc", + "memory", + "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15", + "v16", + "v17", + "v18", + "v19", + "v20", + "v21", + "v22", + "v23", + "v24", + "v25", + "v26", + "v27", + "v28", + "v29", + "v30", + "v31"); +} +template <> +inline void sgemm_sdot_int8_kernel(const int8_t* a_ptr, + const int8_t*& b_ptr, // NOLINT + const int32_t* bias, + float32_t*& c_ptr0, // NOLINT + float32_t*& c_ptr1, // NOLINT + float32_t*& c_ptr2, // NOLINT + float32_t*& c_ptr3, 
// NOLINT + float32_t*& c_ptr4, // NOLINT + float32_t*& c_ptr5, // NOLINT + float32_t*& c_ptr6, // NOLINT + float32_t*& c_ptr7, // NOLINT + const float32_t* scale, + bool is_relu, + int k, + int tail) { + asm volatile(GEMM_SDOT_INT8_KERNEL GEMM_SDOT_FP32_OUT + : [a_ptr] "+r"(a_ptr), + [b_ptr] "+r"(b_ptr), + [k] "+r"(k), + [tail] "+r"(tail), + [c_ptr0] "+r"(c_ptr0), + [c_ptr1] "+r"(c_ptr1), + [c_ptr2] "+r"(c_ptr2), + [c_ptr3] "+r"(c_ptr3), + [c_ptr4] "+r"(c_ptr4), + [c_ptr5] "+r"(c_ptr5), + [c_ptr6] "+r"(c_ptr6), + [c_ptr7] "+r"(c_ptr7) + : [bias_ptr] "r"(bias), [scale] "r"(scale), [relu] "r"(is_relu) + : "cc", + "memory", + "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15", + "v16", + "v17", + "v18", + "v19", + "v20", + "v21", + "v22", + "v23", + "v24", + "v25", + "v26", + "v27", + "v28", + "v29", + "v30", + "v31"); +} +template <> +inline void sgemm_sdot_int8_kernel(const int8_t* a_ptr, + const int8_t*& b_ptr, // NOLINT + const int32_t* bias, + int8_t*& c_ptr0, // NOLINT + int8_t*& c_ptr1, // NOLINT + int8_t*& c_ptr2, // NOLINT + int8_t*& c_ptr3, // NOLINT + int8_t*& c_ptr4, // NOLINT + int8_t*& c_ptr5, // NOLINT + int8_t*& c_ptr6, // NOLINT + int8_t*& c_ptr7, // NOLINT + const float32_t* scale, + bool is_relu, + int k, + int tail) { + asm volatile(GEMM_SDOT_INT8_KERNEL GEMM_SDOT_INT8_OUT + : [a_ptr] "+r"(a_ptr), + [b_ptr] "+r"(b_ptr), + [k] "+r"(k), + [tail] "+r"(tail), + [c_ptr0] "+r"(c_ptr0), + [c_ptr1] "+r"(c_ptr1), + [c_ptr2] "+r"(c_ptr2), + [c_ptr3] "+r"(c_ptr3), + [c_ptr4] "+r"(c_ptr4), + [c_ptr5] "+r"(c_ptr5), + [c_ptr6] "+r"(c_ptr6), + [c_ptr7] "+r"(c_ptr7) + : [bias_ptr] "r"(bias), [scale] "r"(scale), [relu] "r"(is_relu) + : "cc", + "memory", + "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15", + "v16", + "v17", + "v18", + "v19", + "v20", + "v21", + "v22", + "v23", + "v24", + "v25", + "v26", + "v27", + "v28", + "v29", + "v30", + "v31"); +} +#endif + +#else // armv7 +// clang-format off +#define GEMM_INT8_KERNEL \ + "vld1.8 {d0-d1}, [%[a_ptr]: 128]!\n" /* load 4x2x2 int8, A, k2x2 */ \ + "vld1.8 {d4-d7}, [%[b_ptr]: 128]!\n" /* load 8x2x2 int8, B, k2x2 */ \ + "vld1.8 {d8-d9}, [%[bias]]\n" /* load int32x4 bias */ \ + "vext.8 q5, q4, q4, #4\n" /* bias shift 1 int32 */ \ + "vext.8 q6, q4, q4, #8\n" /* bias shift 2 int32 */ \ + "vext.8 q7, q4, q4, #12\n" /* bias shift 3 int32 */ \ + "pld [%[a_ptr]]\n" /* preload A */ \ + "vand q8, q4, q4\n" /* set bias to out00 */ \ + "vand q9, q4, q4\n" /* set bias to out01 */ \ + "pld [%[b_ptr]]\n" /* preload B */ \ + "vand q10, q5, q5\n" /* set bias to out10 */ \ + "vand q11, q5, q5\n" /* set bias to out11 */ \ + "pld [%[b_ptr], #64]\n" /* preload B */ \ + "vand q12, q6, q6\n" /* set bias to out20 */ \ + "vand q13, q6, q6\n" /* set bias to out21 */ \ + "pld [%[b_ptr], #128]\n" /* preload B */ \ + "vand q14, q7, q7\n" /* set bias to out30 */ \ + "vand q15, q7, q7\n" /* set bias to out31 */ \ + "pld [%[a_ptr], #64]\n" /* preload A */ \ + "vext.8 d2, d0, d0, #2\n" /* shift left circular by 2byte */ \ + "vext.8 d3, d1, d1, #2\n" /* shift left circular by 2byte */ \ + "pld [%[b_ptr], #192]\n" /* preload b */ \ + "pld [%[b_ptr], #256]\n" /* preload b */ \ + "pld [%[a_ptr], #128]\n" /* preload a */ \ + "cmp %[k], #0\n" /* check main loop count */ \ + "beq 3f\n" /* if k = 0, jump to remains */ /* 1st r0, r1 */ \ + "vmull.s8 q4, d0, d4\n" /* a0 * b0 = c00 */ \ + "vmull.s8 q5, d0, d5\n" /* a0 * b1 = 
c01 */ \ + "vmull.s8 q6, d2, d4\n" /* a1 * b0 = c10 */ \ + "vmull.s8 q7, d2, d5\n" /* a1 * b1 = c11 */ \ + "subs %[k], %[k], #1\n" /* loop count -1 */ /* 2nd r0, r1 */ \ + "vmlal.s8 q4, d1, d6\n" /* a0 * b0 = c00 */ \ + "vmlal.s8 q5, d1, d7\n" /* a0 * b1 = c01 */ \ + "vrev64.32 q0, q0\n" /* shift left circular by 4byte */ \ + "vmlal.s8 q6, d3, d6\n" /* a1 * b0 = c10 */ \ + "vmlal.s8 q7, d3, d7\n" /* a1 * b1 = c11 */ \ + "vrev64.32 q1, q1\n" /* shift left circular by 4byte */ \ + "beq 8f\n" /* skip main loop */ /* main loop*/ \ + "0:\n" /* main loop */ /* 1st r2, r3 */ \ + "vpadal.s16 q8, q4\n" /* pair add and accumulate to int32, c00 */ \ + "vmull.s8 q4, d0, d4\n" /* a2 * b0 = c20 */ \ + "vpadal.s16 q9, q5\n" /* pair add and accumulate to int32, c01 */ \ + "vmull.s8 q5, d0, d5\n" /* a2 * b1 = c21 */ \ + "vpadal.s16 q10,q6\n" /* pair add and accumulate to int32, c10 */ \ + "vmull.s8 q6, d2, d4\n" /* a3 * b0 = c30 */ \ + "vpadal.s16 q11,q7\n" /* pair add and accumulate to int32, c11 */ \ + "vmull.s8 q7, d2, d5\n" /* a3 * b1 = c31 */ \ + "vld1.8 {d4-d5}, [%[b_ptr]: 128]!\n" /* load 4x2x2 int8, B, k2x2 */ \ + "vmlal.s8 q4, d1, d6\n" /* a0 * b0 = c00 */ \ + "vmlal.s8 q5, d1, d7\n" /* a0 * b1 = c01 */ \ + "vld1.8 {d0-d1}, [%[a_ptr]: 128]!\n" /* load 4x2x2 int8, A, k2x2 */ \ + "vmlal.s8 q6, d3, d6\n" /* a1 * b0 = c10 */ \ + "vmlal.s8 q7, d3, d7\n" /* a1 * b1 = c11 */ \ + "vld1.8 {d6-d7}, [%[b_ptr]: 128]!\n" /* load 4x2x2 int8, B, k2x2 */ \ + "vext.8 d2, d0, d0, #2\n" /* shift left circular by 2byte */ \ + "vext.8 d3, d1, d1, #2\n" /* shift left circular by 2byte */ \ + "vpadal.s16 q12,q4\n" /* pair add and accumulate to int32, c20 */ \ + "vmull.s8 q4, d0, d4\n" /* a0 * b0 = c00 */ \ + "vpadal.s16 q13,q5\n" /* pair add and accumulate to int32, c21 */ \ + "vmull.s8 q5, d0, d5\n" /* a0 * b1 = c01 */ \ + "vpadal.s16 q14,q6\n" /* pair add and accumulate to int32, c30 */ \ + "vmull.s8 q6, d2, d4\n" /* a1 * b0 = c10 */ \ + "vpadal.s16 q15,q7\n" /* pair add and accumulate to int32, c31 */ \ + "vmull.s8 q7, d2, d5\n" /* a1 * b1 = c11 */ \ + "subs %[k], %[k], #1\n" /* loop count -1 */ /* 2nd r0, r1 */ \ + "vmlal.s8 q4, d1, d6\n" /* a0 * b0 = c00 */ \ + "vmlal.s8 q5, d1, d7\n" /* a0 * b1 = c01 */ \ + "vrev64.32 q0, q0\n" /* shift left circular by 2 */ \ + "vmlal.s8 q6, d3, d6\n" /* a1 * b0 = c10 */ \ + "vmlal.s8 q7, d3, d7\n" /* a1 * b1 = c11 */ \ + "vrev64.32 q1, q1\n" /* shift left circular by 2 */ \ + "bgt 0b\n" /* jump to main loop */ \ + "8:\n" /* end of main loop */ /* 1st r2, r3 */ \ + "vpadal.s16 q8, q4\n" /* pair add and accumulate to int32, c00 */ \ + "vmull.s8 q4, d0, d4\n" /* a2 * b0 = c20 */ \ + "vpadal.s16 q9, q5\n" /* pair add and accumulate to int32, c01 */ \ + "vmull.s8 q5, d0, d5\n" /* a2 * b1 = c21 */ \ + "vpadal.s16 q10,q6\n" /* pair add and accumulate to int32, c10 */ \ + "vmull.s8 q6, d2, d4\n" /* a3 * b0 = c30 */ \ + "vpadal.s16 q11,q7\n" /* pair add and accumulate to int32, c11 */ \ + "vmull.s8 q7, d2, d5\n" /* a3 * b1 = c31 */ /* 2nd r2, r3 */ \ + "vmlal.s8 q4, d1, d6\n" /* a0 * b0 = c20 */ \ + "vmlal.s8 q5, d1, d7\n" /* a0 * b1 = c21 */ \ + "vmlal.s8 q6, d3, d6\n" /* a1 * b0 = c30 */ \ + "vmlal.s8 q7, d3, d7\n" /* a1 * b1 = c31 */ \ + "cmp %[rem], #0\n" /* skip remain */ \ + "beq 5f\n" \ + "mov r0, #32\n" /* address offset */ \ + "vld1.8 {d0}, [%[a_ptr]]\n" /* load a to d0, final */ \ + "vld1.8 {d4-d5}, [%[b_ptr]], r0\n" /* load b to d4, d5 */ \ + "5:\n" /* skip rem */ \ + "vpadal.s16 q12, q4\n" /* pair add and accumulate to int32, c20 */ \ + "vpadal.s16 q13, q5\n" /* pair add and 
accumulate to int32, c21 */ \ + "vpadal.s16 q14, q6\n" /* pair add and accumulate to int32, c30 */ \ + "vpadal.s16 q15, q7\n" /* pair add and accumulate to int32, c31 */ \ + "3:\n" /* process remain k */ \ + "cmp %[rem], #0\n" /* skip remain */ \ + "beq 7f\n" /* process remain k */ \ + "vext.8 d1, d0, d0, #2\n" /* shift left 2bytes */ \ + "vext.8 d2, d0, d0, #4\n" /* shift left 4bytes */ \ + "vext.8 d3, d0, d0, #6\n" /* shift left 6bytes */ /* 1st r0, r1 */ \ + "vmull.s8 q4, d0, d4\n" /* a0 * b0 = c00 */ \ + "vmull.s8 q5, d0, d5\n" /* a0 * b1 = c01 */ \ + "vmull.s8 q6, d1, d4\n" /* a1 * b0 = c10 */ \ + "vmull.s8 q7, d1, d5\n" /* a1 * b1 = c11 */ /* 1st r2, r3 */ \ + "vpadal.s16 q8, q4\n" /* pair add and accumulate to int32, c00 */ \ + "vmull.s8 q4, d2, d4\n" /* a2 * b0 = c20 */ \ + "vpadal.s16 q9, q5\n" /* pair add and accumulate to int32, c01 */ \ + "vmull.s8 q5, d2, d5\n" /* a2 * b1 = c21 */ \ + "vpadal.s16 q10,q6\n" /* pair add and accumulate to int32, c10 */ \ + "vmull.s8 q6, d3, d4\n" /* a3 * b0 = c30 */ \ + "vpadal.s16 q11,q7\n" /* pair add and accumulate to int32, c11 */ \ + "vmull.s8 q7, d3, d5\n" /* a3 * b1 = c31 */ \ + "vpadal.s16 q12, q4\n" /* pair add and accumulate to int32, c20 */ \ + "vpadal.s16 q13, q5\n" /* pair add and accumulate to int32, c21 */ \ + "vpadal.s16 q14, q6\n" /* pair add and accumulate to int32, c30 */ \ + "vpadal.s16 q15, q7\n" /* pair add and accumulate to int32, c31 */ \ + "7: \n" /* do relu */ /* do relu */ \ + "cmp %[is_relu], #0\n" /* skip relu */ \ + "beq 9f\n" /* skip relu */ \ + "vmov.i32 q0, #0\n" /* for relu */ \ + "vmax.s32 q8, q8, q0\n" /* relu */ \ + "vmax.s32 q9, q9, q0\n" /* relu */ \ + "vmax.s32 q10,q10, q0\n" /* relu */ \ + "vmax.s32 q11,q11, q0\n" /* relu */ \ + "vmax.s32 q12,q12, q0\n" /* relu */ \ + "vmax.s32 q13,q13, q0\n" /* relu */ \ + "vmax.s32 q14,q14, q0\n" /* relu */ \ + "vmax.s32 q15,q15, q0\n" /* relu */ /* unpack the result */ \ + "9:\n" /* unpack */ /* trans 1 */ \ + "vtrn.32 q8, q10\n" /* get q8 */ \ + "vtrn.32 q12, q14\n" /* get q12 */ \ + "vtrn.32 q9, q11\n" /* get q9 */ \ + "vtrn.32 q13, q15\n" /* get q13*/ \ + "vswp d17, d24\n" /* get q8*/ \ + "vswp d21, d28\n" /* get q10 */ \ + "vswp d19, d26\n" /* get q9 */ \ + "vswp d23, d30\n" /* get q11 */ \ + "vext.8 q0, q10, q10, #12\n" /* circular shift left 1 q0 */ \ + "vext.8 q2, q12, q12, #8\n" /* circular shift left 2 q2 */ \ + "vext.8 q4, q14, q14, #4\n" /* circular shift left 3 q4 */ \ + "vext.8 q1, q11, q11, #12\n" /* circular shift left 1 q1 */ \ + "vext.8 q3, q13, q13, #8\n" /* circular shift left 2 q3 */ \ + "vext.8 q5, q15, q15, #4\n" /* circular shift left 3 q5 */ \ + "vtrn.32 q8, q0\n" /* get q8 */ \ + "vtrn.32 q2, q4\n" /* get q2 */ \ + "vtrn.32 q9, q1\n" /* get q9 */ \ + "vtrn.32 q3, q5\n" /* get q3 */ /* trans 2 */ \ + "vswp d17, d4\n" /* get q8 */ \ + "vswp d1, d8\n" /* get q0: a1*/ \ + "vswp d19, d6\n" /* get q9: */ \ + "vswp d3, d10\n" /* get q1: a3b3 */ + +// clang-format off + +#define GEMM_INT8_INT32_OUT \ + /* write output */ \ + "vst1.32 {d16-d19}, [%[c_ptr0]]!\n" /* write outr0 */ \ + "vst1.32 {d0-d3}, [%[c_ptr1]]!\n" /* write outr1 */ \ + "vst1.32 {d4-d7}, [%[c_ptr2]]!\n" /* write outr2 */ \ + "vst1.32 {d8-d11}, [%[c_ptr3]]!\n" /* write outr3 */ + +#define GEMM_INT8_FP32_OUT \ + /* write output */ \ + "vld1.32 {d12-d13}, [%[scale]]\n" /* load scale */ \ + "vcvt.f32.s32 q10, q8\n" /* r00, cvt int32 to fp32*/ \ + "vcvt.f32.s32 q11, q9\n" /* r01, cvt int32 to fp32*/ \ + "vcvt.f32.s32 q12, q0\n" /* r10, cvt int32 to fp32*/ \ + "vcvt.f32.s32 q13, q1\n" /* r11, 
cvt int32 to fp32*/ \ + "vmul.f32 q8, q10, d12[0]\n" /* r00, mul scale to get final result */ \ + "vmul.f32 q9, q11, d12[0]\n" /* r01, mul scale to get final result */ \ + "vmul.f32 q0, q12, d12[1]\n" /* r10, mul scale to get final result */ \ + "vmul.f32 q1, q13, d12[1]\n" /* r11, mul scale to get final result */ \ + "vcvt.f32.s32 q10, q2\n" /* r20, cvt int32 to fp32*/ \ + "vcvt.f32.s32 q11, q3\n" /* r21, cvt int32 to fp32*/ \ + "vcvt.f32.s32 q12, q4\n" /* r30, cvt int32 to fp32*/ \ + "vcvt.f32.s32 q13, q5\n" /* r31, cvt int32 to fp32*/ \ + "vst1.32 {d16-d19}, [%[c_ptr0]]!\n" /* write r0, float32x4 x2 */ \ + "vmul.f32 q2, q10, d13[0]\n" /* r20, mul scale to get final result */ \ + "vmul.f32 q3, q11, d13[0]\n" /* r21, mul scale to get final result */ \ + "vst1.32 {d0-d3}, [%[c_ptr1]]!\n" /* write r1, float32x4 x2 */ \ + "vmul.f32 q4, q12, d13[1]\n" /* r30, mul scale to get final result */ \ + "vmul.f32 q5, q13, d13[1]\n" /* r31, mul scale to get final result */ \ + "vst1.32 {d4-d7}, [%[c_ptr2]]!\n" /* write r2, float32x4 x2 */ \ + "vst1.32 {d8-d11}, [%[c_ptr3]]!\n" /* write r3, float32x4 x2 */ + +#define GEMM_INT8_INT8_OUT \ + /* write output */ \ + "vld1.32 {d12-d13}, [%[scale]]\n" /* load scale */ \ + "vmov.f32 q7, #-0.5\n" /* neg offset */ \ + "vcvt.f32.s32 q10, q8\n" /* r00, cvt int32 to fp32*/ \ + "vcvt.f32.s32 q11, q9\n" /* r01, cvt int32 to fp32*/ \ + "vcvt.f32.s32 q12, q0\n" /* r10, cvt int32 to fp32*/ \ + "vcvt.f32.s32 q13, q1\n" /* r11, cvt int32 to fp32*/ \ + "vmov.f32 q8, #0.5\n" /* pos offset */ \ + "vmov.f32 q9, #0.5\n" /* pos offset */ \ + "vmov.f32 q0, #0.5\n" /* pos offset */ \ + "vmov.f32 q1, #0.5\n" /* pos offset */ \ + "vcgt.f32 q14, q10, #0\n" /* get pos mask */ \ + "vcgt.f32 q15, q11, #0\n" /* get pos mask */ \ + "vbif.f32 q8, q7, q14\n" /* get right offset */ \ + "vbif.f32 q9, q7, q15\n" /* get right offset */ \ + "vcgt.f32 q14, q12, #0\n" /* get pos mask */ \ + "vcgt.f32 q15, q13, #0\n" /* get pos mask */ \ + "vbif.f32 q0, q7, q14\n" /* get right offset */ \ + "vbif.f32 q1, q7, q15\n" /* get right offset */ \ + "vmla.f32 q8, q10, d12[0]\n" /* r00, mul scale to get final result */ \ + "vmla.f32 q9, q11, d12[0]\n" /* r01, mul scale to get final result */ \ + "vmla.f32 q0, q12, d12[1]\n" /* r10, mul scale to get final result */ \ + "vmla.f32 q1, q13, d12[1]\n" /* r11, mul scale to get final result */ \ + "vcvt.f32.s32 q10, q2\n" /* r20, cvt int32 to fp32*/ \ + "vcvt.f32.s32 q11, q3\n" /* r21, cvt int32 to fp32*/ \ + "vcvt.f32.s32 q12, q4\n" /* r30, cvt int32 to fp32*/ \ + "vcvt.f32.s32 q13, q5\n" /* r31, cvt int32 to fp32*/ \ + "vmov.f32 q2, #0.5\n" /* pos offset */ \ + "vmov.f32 q3, #0.5\n" /* pos offset */ \ + "vmov.f32 q4, #0.5\n" /* pos offset */ \ + "vmov.f32 q5, #0.5\n" /* pos offset */ \ + "vcgt.f32 q14, q10, #0\n" /* get pos mask */ \ + "vcgt.f32 q15, q11, #0\n" /* get pos mask */ \ + "vbif.f32 q2, q7, q14\n" /* get right offset */ \ + "vbif.f32 q3, q7, q15\n" /* get right offset */ \ + "vcgt.f32 q14, q12, #0\n" /* get pos mask */ \ + "vcgt.f32 q15, q13, #0\n" /* get pos mask */ \ + "vbif.f32 q4, q7, q14\n" /* get right offset */ \ + "vbif.f32 q5, q7, q15\n" /* get right offset */ \ + "vmla.f32 q2, q10, d13[0]\n" /* r20, mul scale to get final result */ \ + "vmla.f32 q3, q11, d13[0]\n" /* r21, mul scale to get final result */ \ + "vmla.f32 q4, q12, d13[1]\n" /* r30, mul scale to get final result */ \ + "vmla.f32 q5, q13, d13[1]\n" /* r31, mul scale to get final result */ \ + "vcvt.s32.f32 q6, q8\n" /* r00, fp32->int32 */ \ + "vcvt.s32.f32 q7, q9\n" /* r01, 
fp32->int32 */ \
+  "vcvt.s32.f32 q10, q0\n" /* r10, fp32->int32 */ \
+  "vcvt.s32.f32 q11, q1\n" /* r11, fp32->int32 */ \
+  "vcvt.s32.f32 q12, q2\n" /* r20, fp32->int32 */ \
+  "vcvt.s32.f32 q13, q3\n" /* r21, fp32->int32 */ \
+  "vcvt.s32.f32 q14, q4\n" /* r30, fp32->int32 */ \
+  "vcvt.s32.f32 q15, q5\n" /* r31, fp32->int32 */ \
+  "vqmovn.s32 d0, q6\n"  /* r00, int32 -> int16 */ \
+  "vqmovn.s32 d1, q7\n"  /* r01, int32 -> int16 */ \
+  "vqmovn.s32 d2, q10\n" /* r10, int32 -> int16 */ \
+  "vqmovn.s32 d3, q11\n" /* r11, int32 -> int16 */ \
+  "vqmovn.s32 d4, q12\n" /* r20, int32 -> int16 */ \
+  "vqmovn.s32 d5, q13\n" /* r21, int32 -> int16 */ \
+  "vqmovn.s32 d6, q14\n" /* r30, int32 -> int16 */ \
+  "vqmovn.s32 d7, q15\n" /* r31, int32 -> int16 */ \
+  "vqmovn.s16 d8, q0\n"  /* 0, int16 -> int8 */ \
+  "vqmovn.s16 d9, q1\n"  /* 1, int16 -> int8 */ \
+  "vqmovn.s16 d10, q2\n" /* 2, int16 -> int8 */ \
+  "vqmovn.s16 d11, q3\n" /* 3, int16 -> int8 */ \
+  "vst1.32 {d8}, [%[c_ptr0]]!\n"  /* write r0*/ \
+  "vst1.32 {d9}, [%[c_ptr1]]!\n"  /* write r1*/ \
+  "vst1.32 {d10}, [%[c_ptr2]]!\n" /* write r2*/ \
+  "vst1.32 {d11}, [%[c_ptr3]]!\n" /* write r3*/
+
+template <>
+inline void gemm_int8_kernel(const int8_t* a_ptr, const int8_t*& b_ptr,  // NOLINT
+                             const int32_t* bias, int32_t*& c_ptr0,  // NOLINT
+                             int32_t*& c_ptr1, int32_t*& c_ptr2,  // NOLINT
+                             int32_t*& c_ptr3, const float* scale, bool is_relu,  // NOLINT
+                             int k, int rem) {
+  asm volatile(GEMM_INT8_KERNEL GEMM_INT8_INT32_OUT
+               : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr),
+                 [c_ptr0] "+r"(c_ptr0), [c_ptr1] "+r"(c_ptr1),
+                 [c_ptr2] "+r"(c_ptr2), [c_ptr3] "+r"(c_ptr3), [k] "+r"(k)
+               : [is_relu] "r"(is_relu), [bias] "r"(bias), [rem] "r"(rem)
+               : "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9",
+                 "q10", "q11", "q12", "q13", "q14", "q15", "r0", "cc");
+}
+
+template <>
+inline void gemm_int8_kernel(const int8_t* a_ptr, const int8_t*& b_ptr,  // NOLINT
+                             const int32_t* bias, float*& c_ptr0,  // NOLINT
+                             float*& c_ptr1, float*& c_ptr2, float*& c_ptr3,  // NOLINT
+                             const float* scale, bool is_relu, int k, int rem) {
+  asm volatile(GEMM_INT8_KERNEL GEMM_INT8_FP32_OUT
+               : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr),
+                 [c_ptr0] "+r"(c_ptr0), [c_ptr1] "+r"(c_ptr1),
+                 [c_ptr2] "+r"(c_ptr2), [c_ptr3] "+r"(c_ptr3), [k] "+r"(k)
+               : [is_relu] "r"(is_relu), [bias] "r"(bias), [rem] "r"(rem),
+                 [scale] "r"(scale)
+               : "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9",
+                 "q10", "q11", "q12", "q13", "q14", "q15", "r0", "cc");
+}
+
+template <>
+inline void gemm_int8_kernel(const int8_t* a_ptr, const int8_t*& b_ptr,  // NOLINT
+                             const int32_t* bias, int8_t*& c_ptr0,  // NOLINT
+                             int8_t*& c_ptr1, int8_t*& c_ptr2, int8_t*& c_ptr3,  // NOLINT
+                             const float* scale, bool is_relu, int k, int rem) {
+  asm volatile(GEMM_INT8_KERNEL GEMM_INT8_INT8_OUT
+               : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr),
+                 [c_ptr0] "+r"(c_ptr0), [c_ptr1] "+r"(c_ptr1),
+                 [c_ptr2] "+r"(c_ptr2), [c_ptr3] "+r"(c_ptr3), [k] "+r"(k)
+               : [is_relu] "r"(is_relu), [bias] "r"(bias), [rem] "r"(rem),
+                 [scale] "r"(scale)
+               : "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9",
+                 "q10", "q11", "q12", "q13", "q14", "q15", "r0", "cc");
+}
+#endif  // __aarch64__ // NOLINT
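[Editor's note: the int8->int8 output path above (GEMM_INT8_INT8_OUT) rounds by selecting a +0.5 or -0.5 offset per lane (vcgt + vbif), fusing offset + acc * scale with vmla, and then using the truncating vcvt.s32.f32 -- armv7 NEON has no round-to-nearest float->int convert. The snippet below is a hedged scalar model of that requantization step, not code from this patch; the helper name requant_to_int8 is invented for illustration, and a positive per-row scale is assumed.]

    #include <cstdint>

    // Per-lane model of GEMM_INT8_INT8_OUT: the sign of the accumulator picks
    // the offset, offset + acc * scale is truncated toward zero (which makes
    // the whole thing round half away from zero for positive scales), and the
    // result is saturated to int8 the way vqmovn.s32/vqmovn.s16 would.
    inline int8_t requant_to_int8(int32_t acc, float scale) {
      float offset = (acc > 0) ? 0.5f : -0.5f;
      int32_t r = static_cast<int32_t>(offset + acc * scale);  // truncation
      if (r > 127) r = 127;    // vqmovn saturation
      if (r < -128) r = -128;
      return static_cast<int8_t>(r);
    }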
+
+// gemm wrapper
+template <typename Dtype>
+void gemm_prepack_oth_int8(const int8_t* A_packed,
+                           const int8_t* B,
+                           const int* bias,
+                           Dtype* C,
+                           int M,
+                           int N,
+                           int K,
+                           bool is_bias,
+                           bool is_relu,
+                           bool is_transB,
+                           const float* scale,
+                           ARMContext* ctx) {
+  const int KUP = ROUNDUP(K, KBLOCK_INT8);
+  size_t llc_size = ctx->llc_size() / 4;
+  auto workspace = ctx->workspace_data<int8_t>();
+  int threads = ctx->threads();
+  int x_block = llc_size / (sizeof(int8_t) * (KUP + MBLOCK_INT8_OTH));
+  x_block /= NBLOCK_INT8_OTH;
+  x_block *= NBLOCK_INT8_OTH;
+  int x_num = (N + (x_block - 1)) / x_block;
+  x_block = (N + x_num - 1) / x_num;
+  x_block = (x_block + NBLOCK_INT8_OTH - 1) / NBLOCK_INT8_OTH;
+  x_block *= NBLOCK_INT8_OTH;
+  int k = K / KBLOCK_INT8;
+  int k_rem = K & (KBLOCK_INT8 - 1);
+  if (k_rem > KBLOCK_INT8 / 2) {
+    k_rem = 0;
+    k += 1;
+  }
+  int n_rem = N & (NBLOCK_INT8_OTH - 1);
+
+  auto* b_tmp = static_cast<int8_t*>(workspace);
+
+  auto* zerobuf = static_cast<int8_t*>(
+      malloc(x_block * (sizeof(int8_t) + sizeof(Dtype))));
+  memset(zerobuf, 0, x_block * sizeof(int8_t));
+  auto* trash_ptr =
+      reinterpret_cast<Dtype*>(zerobuf + x_block * sizeof(int8_t));
+
+  //! the A panel is pre-packed outside of this gemm wrapper
+
+  for (unsigned int x0 = 0; x0 < N; x0 += x_block) {
+    unsigned int xmax = x0 + x_block;
+    bool flag_rem = false;
+    if (xmax >= N) {
+      xmax = N;
+      flag_rem = n_rem > 0;
+    }
+    int bblocks = (xmax - x0 + NBLOCK_INT8_OTH - 1) / NBLOCK_INT8_OTH;
+    //! load bpanel
+    int8_t* b_pannel = b_tmp;
+    if (is_transB) {
+      packb_trans_int8(b_pannel, B, K, 0, K, x0, xmax, zerobuf);
+    } else {
+      packb_int8(b_pannel, B, N, 0, K, x0, xmax, zerobuf);
+    }
+
+#pragma omp parallel for num_threads(threads)
+    for (unsigned int y = 0; y < M; y += MBLOCK_INT8_OTH) {
+      Dtype out0[NBLOCK_INT8_OTH] = {0};
+      Dtype out1[NBLOCK_INT8_OTH] = {0};
+      Dtype out2[NBLOCK_INT8_OTH] = {0};
+      Dtype out3[NBLOCK_INT8_OTH] = {0};
+      Dtype* c_ptr0 = C + y * N + x0;
+      Dtype* c_ptr1 = c_ptr0 + N;
+      Dtype* c_ptr2 = c_ptr1 + N;
+      Dtype* c_ptr3 = c_ptr2 + N;
+      Dtype* tmp0 = nullptr;
+      Dtype* tmp1 = nullptr;
+      Dtype* tmp2 = nullptr;
+      Dtype* tmp3 = nullptr;
+      float32_t scale_local[4];
+      int32_t bias_local[4] = {0, 0, 0, 0};
+      if (is_bias) {
+        bias_local[0] = bias[y];
+        bias_local[1] = bias[y + 1];
+        bias_local[2] = bias[y + 2];
+        bias_local[3] = bias[y + 3];
+      }
+      if (scale) {
+        scale_local[0] = scale[y];
+        scale_local[1] = scale[y + 1];
+        scale_local[2] = scale[y + 2];
+        scale_local[3] = scale[y + 3];
+      }
+      if (y + MBLOCK_INT8_OTH > M) {
+        switch (y + MBLOCK_INT8_OTH - M) {
+          case 3:
+            c_ptr1 = trash_ptr;
+          case 2:
+            c_ptr2 = trash_ptr;
+          case 1:
+            c_ptr3 = trash_ptr;
+          default:
+            break;
+        }
+      }
+      const int8_t* a_ptr_l = A_packed + y * KUP;
+      const int8_t* b_ptr = b_pannel;
+      for (int xb = 0; xb < bblocks; xb++) {
+        if (flag_rem && (xb == bblocks - 1)) {
+          tmp0 = c_ptr0;
+          tmp1 = c_ptr1;
+          tmp2 = c_ptr2;
+          tmp3 = c_ptr3;
+          c_ptr0 = out0;
+          c_ptr1 = out1;
+          c_ptr2 = out2;
+          c_ptr3 = out3;
+        }
+        gemm_int8_kernel(a_ptr_l, b_ptr, bias_local,
+                         c_ptr0, c_ptr1, c_ptr2, c_ptr3,
+                         scale_local, is_relu, k, k_rem);
+        if (flag_rem && (xb == bblocks - 1)) {
+          for (int i = 0; i < n_rem; ++i) {
+            *(tmp0++) = out0[i];
+            *(tmp1++) = out1[i];
+            *(tmp2++) = out2[i];
+            *(tmp3++) = out3[i];
+          }
+        }
+      }
+    }
+  }
+  free(zerobuf);
+}
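[Editor's note: a hedged sketch of the cache-blocking arithmetic in gemm_prepack_oth_int8 above, presumably sized so the packed B panel shares the LLC with the A panel and output tiles (hence the /4). compute_x_block and the k* constants are illustrative stand-ins, assuming KBLOCK_INT8 == 4, MBLOCK_INT8_OTH == 4 and NBLOCK_INT8_OTH == 8, and that the cache quarter holds at least one tile.]

    #include <cstddef>

    constexpr int kKBlock = 4;  // assumed KBLOCK_INT8
    constexpr int kMBlock = 4;  // assumed MBLOCK_INT8_OTH
    constexpr int kNBlock = 8;  // assumed NBLOCK_INT8_OTH

    // Size the N-panel so one packed B block (kup x x_block int8) fits in a
    // quarter of the last-level cache, then split N into near-equal panels
    // re-rounded up to whole register tiles.
    int compute_x_block(int N, int K, size_t llc_quarter) {
      int kup = (K + kKBlock - 1) / kKBlock * kKBlock;  // ROUNDUP(K, KBLOCK_INT8)
      int x_block =
          static_cast<int>(llc_quarter / (sizeof(int8_t) * (kup + kMBlock)));
      x_block = x_block / kNBlock * kNBlock;       // whole register tiles
      int x_num = (N + x_block - 1) / x_block;     // number of panels
      x_block = (N + x_num - 1) / x_num;           // near-equal panel width
      x_block = (x_block + kNBlock - 1) / kNBlock * kNBlock;
      return x_block;
    }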
+
+/***********************************************************************/
+// prepack A according to gemm kernel
+// A block size: (<4x2>x1) x2, with unroll=2 can be described as below:
+// origin A data:
+// A_origin(no trans, m x k):
+// r0: ==> a0, b0, c0, d0, e0, f0, g0, h0
+// r1: ==> a1, b1, c1, d1, e1, f1, g1, h1
+// r2: ==> a2, b2, c2, d2, e2, f2, g2, h2
+// r3: ==> a3, b3, c3, d3, e3, f3, g3, h3
+// packed A
+// a0,b0, a1,b1, a2,b2, a3,b3;
+// c0,d0, c1,d1, c2,d2, c3,d3;
+// e0,f0, e1,f1, e2,f2, e3,f3;
+// g0,h0, g1,h1, g2,h2, g3,h3;
+/***********************************************************************/
+void prepackA_m4k2x2_int8(int8_t* out, const int8_t* in, const int ldin,
+                          const int m0, const int mmax, const int k0,
+                          const int kmax) {
+  int y_len = mmax - m0;
+  int x_len = kmax - k0;
+  int x_len_roundup = ROUNDUP(x_len, KBLOCK_INT8);
+  auto zerobuff = static_cast<int8_t*>(malloc(x_len_roundup * sizeof(char)));
+  memset(zerobuff, 0, sizeof(char) * x_len_roundup);
+
+  const int8_t* inptr = in + m0 * ldin + k0;
+  uint8_t remain = static_cast<uint8_t>(x_len & (KBLOCK_INT8 - 1));
+
+#pragma omp parallel for
+  for (int y = 0; y < y_len; y += MBLOCK_INT8_OTH) {
+    const int8_t* ptr0 = inptr + y * ldin;
+    const int8_t* ptr1 = ptr0 + ldin;
+    const int8_t* ptr2 = ptr1 + ldin;
+    const int8_t* ptr3 = ptr2 + ldin;
+    //! rows that run past the real size are redirected to the zero buffer
+    if ((y + MBLOCK_INT8_OTH) > y_len) {
+      switch ((y + MBLOCK_INT8_OTH) - y_len) {
+        case 3:
+          ptr1 = zerobuff;
+        case 2:
+          ptr2 = zerobuff;
+        case 1:
+          ptr3 = zerobuff;
+        default:
+          break;
+      }
+    }
+    int8_t* ptr_out = out + y * x_len_roundup;
+    int i = 0;
+    for (; i < x_len + 1 - 2 * KBLOCK_INT8; i += 2 * KBLOCK_INT8) {
+#ifdef __aarch64__
+      asm volatile(
+          "ld1 {v0.8b}, [%[ptr0]], #8\n" /* load r0, 8 int8 */
+          "ld1 {v1.8b}, [%[ptr1]], #8\n" /* load r1, 8 int8 */
+          "ld1 {v2.8b}, [%[ptr2]], #8\n" /* load r2, 8 int8 */
+          "ld1 {v3.8b}, [%[ptr3]], #8\n" /* load r3, 8 int8 */
+          "trn1 v4.4h, v0.4h, v1.4h\n" /* get a0,b0, a1,b1, e0,f0, e1,f1 */
+          "trn2 v5.4h, v0.4h, v1.4h\n" /* get c0,d0, c1,d1, g0,h0, g1,h1 */
+          "trn1 v6.4h, v2.4h, v3.4h\n" /* get a2,b2, a3,b3, e2,f2, e3,f3 */
+          "trn2 v7.4h, v2.4h, v3.4h\n" /* get c2,d2, c3,d3, g2,h2, g3,h3 */
+          "trn1 v0.2s, v4.2s, v6.2s\n" /* get a0,b0, a1,b1, a2,b2, a3,b3 */
+          "trn2 v2.2s, v4.2s, v6.2s\n" /* get e0,f0, e1,f1, e2,f2, e3,f3 */
+          "trn1 v1.2s, v5.2s, v7.2s\n" /* get c0,d0, c1,d1, c2,d2, c3,d3 */
+          "trn2 v3.2s, v5.2s, v7.2s\n" /* get g0,h0, g1,h1, g2,h2, g3,h3 */
+          "st1 {v0.8b, v1.8b, v2.8b, v3.8b}, [%[ptr_out]], #32\n" /* write
+                                                                     out*/
+          : [ptr_out] "+r"(ptr_out), [ptr0] "+r"(ptr0), [ptr1] "+r"(ptr1),
+            [ptr2] "+r"(ptr2), [ptr3] "+r"(ptr3)
+          :
+          : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "cc", "memory");
+#else  // armv7
+      asm volatile(
+          "vld1.8 {d0}, [%[ptr0]]!\n" /* load r0, 8 int8,
+                                         a0,b0,c0,d0,e0,f0,g0,h0 */
+          "vld1.8 {d1}, [%[ptr1]]!\n" /* load r1, 8 int8,
+                                         a1,b1,c1,d1,e1,f1,g1,h1 */
+          "vld1.8 {d2}, [%[ptr2]]!\n" /* load r2, 8 int8,
+                                         a2,b2,c2,d2,e2,f2,g2,h2 */
+          "vld1.8 {d3}, [%[ptr3]]!\n" /* load r3, 8 int8,
+                                         a3,b3,c3,d3,e3,f3,g3,h3 */
+          "vtrn.16 d0, d1\n" /* trans, d0: a0,b0,a1,b1, e0,f0,e1,f1; d1:
+                                c0,d0,c1,d1, g0,h0,g1,h1 */
+          "vtrn.16 d2, d3\n" /* trans, d2: a2,b2,a3,b3, e2,f2,e3,f3; d3:
+                                c2,d2,c3,d3, g2,h2,g3,h3 */
+          "vtrn.32 d0, d2\n" /* trans, d0: a0,b0,a1,b1, a2,b2,a3,b3; d2:
+                                e0,f0,e1,f1, e2,f2,e3,f3 */
+          "vtrn.32 d1, d3\n" /* trans, d1: c0,d0,c1,d1, c2,d2,c3,d3; d3:
+                                g0,h0,g1,h1, g2,h2,g3,h3 */
+          "vst1.32 {d0-d3}, [%[outptr]]!\n" /* write to output ptr */
+          : [outptr] "+r"(ptr_out), [ptr0] "+r"(ptr0), [ptr1] "+r"(ptr1),
+            [ptr2] "+r"(ptr2), [ptr3] "+r"(ptr3)
+          :
+          : "q0", "q1", "cc", "memory");
+#endif  // __aarch64__ // NOLINT
+    }
+    if (i + KBLOCK_INT8 <= x_len) {
+      ptr_out[0] = ptr0[0];
+      ptr_out[1] = ptr0[1];
+      ptr_out[2] = ptr1[0];
+      ptr_out[3] = ptr1[1];
+      ptr_out[4] = ptr2[0];
+      ptr_out[5] = ptr2[1];
+      ptr_out[6] = ptr3[0];
+      ptr_out[7] = ptr3[1];
+      // unroll
+      ptr_out[8] = ptr0[2];
+      ptr_out[9] = ptr0[3];
+      ptr_out[10] = ptr1[2];
+      ptr_out[11] = ptr1[3];
+      ptr_out[12] = ptr2[2];
+      ptr_out[13] = ptr2[3];
+      ptr_out[14] = ptr3[2];
+      ptr_out[15] = ptr3[3];
+      ptr_out += 16;
+      ptr0 += 4;
+      ptr1 += 4;
+      ptr2 += 4;
+      ptr3 += 4;
+    }
+    switch (remain) {
+      case 0:
+        break;
+      case 1:
+        ptr_out[0] = ptr0[0];
+        ptr_out[1] = 0;
+        ptr_out[2] = ptr1[0];
+        ptr_out[3] = 0;
+        ptr_out[4] = ptr2[0];
+        ptr_out[5] = 0;
+        ptr_out[6] = ptr3[0];
+        ptr_out[7] = 0;
+        // unroll
+        ptr_out[8] = 0;
+        ptr_out[9] = 0;
+        ptr_out[10] = 0;
+        ptr_out[11] = 0;
+        ptr_out[12] = 0;
+        ptr_out[13] = 0;
+        ptr_out[14] = 0;
+        ptr_out[15] = 0;
+        ptr_out += 16;
+        break;
+      case 2:
+        ptr_out[0] = ptr0[0];
+        ptr_out[1] = ptr0[1];
+        ptr_out[2] = ptr1[0];
+        ptr_out[3] = ptr1[1];
+        ptr_out[4] = ptr2[0];
+        ptr_out[5] = ptr2[1];
+        ptr_out[6] = ptr3[0];
+        ptr_out[7] = ptr3[1];
+        // unroll
+        ptr_out[8] = 0;
+        ptr_out[9] = 0;
+        ptr_out[10] = 0;
+        ptr_out[11] = 0;
+        ptr_out[12] = 0;
+        ptr_out[13] = 0;
+        ptr_out[14] = 0;
+        ptr_out[15] = 0;
+        ptr_out += 16;
+        break;
+      case 3:
+        ptr_out[0] = ptr0[0];
+        ptr_out[1] = ptr0[1];
+        ptr_out[2] = ptr1[0];
+        ptr_out[3] = ptr1[1];
+        ptr_out[4] = ptr2[0];
+        ptr_out[5] = ptr2[1];
+        ptr_out[6] = ptr3[0];
+        ptr_out[7] = ptr3[1];
+        // unroll
+        ptr_out[8] = ptr0[2];
+        ptr_out[9] = 0;
+        ptr_out[10] = ptr1[2];
+        ptr_out[11] = 0;
+        ptr_out[12] = ptr2[2];
+        ptr_out[13] = 0;
+        ptr_out[14] = ptr3[2];
+        ptr_out[15] = 0;
+        ptr_out += 16;
+        break;
+      default:
+        break;
+    }
+  }
+  free(zerobuff);
+}
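[Editor's note: a plain scalar reference of the m4k2x2 packing that the NEON trn/vtrn sequences above implement; prepackA_m4k2_scalar is illustrative only, assuming KBLOCK_INT8 == 4 (k is consumed in 2x2 pairs) and 4-row tiles, which matches the loops above.]

    #include <cstdint>

    // For each group of 4 rows, emit k in pairs:
    //   row0[k],row0[k+1], row1[k],row1[k+1], row2[k],row2[k+1], row3[k],row3[k+1]
    // for k = 0, 2, 4, ..., zero-padding where rows or k run out -- exactly the
    // "a0,b0, a1,b1, a2,b2, a3,b3; c0,d0, ..." layout documented above.
    void prepackA_m4k2_scalar(int8_t* out, const int8_t* in, int ldin,
                              int m, int k) {
      int k_up = (k + 3) / 4 * 4;  // ROUNDUP(k, KBLOCK_INT8)
      for (int y = 0; y < m; y += 4) {
        for (int kk = 0; kk < k_up; kk += 2) {
          for (int r = 0; r < 4; ++r) {
            for (int j = 0; j < 2; ++j) {
              int row = y + r;
              int col = kk + j;
              *out++ = (row < m && col < k) ? in[row * ldin + col] : 0;
            }
          }
        }
      }
    }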
+
+/***************************************************************************/
+// prepack A according to gemm kernel
+// A block size: <4x2>x2, unroll x4, can be described as below:
+// origin A data:
+// A_origin(trans, k x m):
+// k unroll 2, a0=k0,k1
+// r0: ==> a0, a1, a2, a3 .... a12, a13, a14, a15
+// r1: ==> b0, b1, b2, b3 .... b12, b13, b14, b15
+// r2: ==> c0, c1, c2, c3 .... c12, c13, c14, c15
+// r3: ==> d0, d1, d2, d3 .... d12, d13, d14, d15
+// packed A:
+// a0,b0, a1,b1, a2,b2, a3,b3;
+// c0,d0, c1,d1, c2,d2, c3,d3;----block0
+// a4,b4, a5,b5, a6,b6, a7,b7;
+// c4,d4, c5,d5, c6,d6, c7,d7;----block1
+// a8,b8, a9,b9, a10,b10, a11,b11;
+// c8,d8, c9,d9, c10,d10, c11,d11;----block2
+// a12,b12, a13,b13, a14,b14, a15,b15;
+// c12,d12, c13,d13, c14,d14, c15,d15;----block3
+/***************************************************************************/
+void prepackA_m4k2x2_trans_int8(int8_t* out, const int8_t* in, const int ldin,
+                                const int m0, const int mmax, const int k0,
+                                const int kmax) {
+  int xlen = mmax - m0;
+  int ylen = kmax - k0;
+  int ylen_roundup = ROUNDUP(ylen, KBLOCK_INT8);
+  int xlen_roundup = ROUNDUP(xlen, MBLOCK_INT8_OTH);
+
+  const int MUNROLL = 4;
+  int mcnt = xlen / (MUNROLL * MBLOCK_INT8_OTH);
+  int x_rem = xlen & (MUNROLL * MBLOCK_INT8_OTH - 1);
+  int m_rem = (x_rem + MBLOCK_INT8_OTH - 1) / MBLOCK_INT8_OTH;
+
+  const uint8_t mask_buffer[16] = {0, 1, 2, 3, 4, 5, 6, 7,
+                                   8, 9, 10, 11, 12, 13, 14, 15};
+  int8x16_t vzero = vdupq_n_s8(0);
+  uint8x16_t vmask = vcltq_u8(vld1q_u8(mask_buffer), vdupq_n_u8(x_rem));
+
+  int stride_out = ylen_roundup * MBLOCK_INT8_OTH;
+
+  int8_t* zerobuf = static_cast<int8_t*>(malloc(xlen_roundup));
+  memset(zerobuf, 0, xlen_roundup);
+
+  const int8_t* inr = in + ldin * k0 + m0;
+#pragma omp parallel for
+  for (int y = 0; y < ylen; y += KBLOCK_INT8) {
+    const int8_t* ptr0 = inr + y * ldin;
+    const int8_t* ptr1 = ptr0 + ldin;
+    const int8_t* ptr2 = ptr1 + ldin;
+    const int8_t* ptr3 = ptr2 + ldin;
+    int8_t* ptr_out = out + MBLOCK_INT8_OTH * y;
+    if (y + KBLOCK_INT8 > ylen) {
+      switch (y + KBLOCK_INT8 - ylen) {
+        case 3:
+          ptr1 = zerobuf;
+        case 2:
+          ptr2 = zerobuf;
+        case 1:
+          ptr3 = zerobuf;
+        default:
+          break;
+      }
+    }
+    int k = mcnt;
+    int rem = m_rem;
+#ifdef __aarch64__
+    asm volatile(
+        "ld1 {v0.16b}, [%[ptr0]], #16\n" /* load r0 */
+        "ld1 {v1.16b}, [%[ptr1]], #16\n" /* load r1
*/ + "ld1 {v2.16b}, [%[ptr2]], #16\n" /* load r2 */ + "ld1 {v3.16b}, [%[ptr3]], #16\n" /* load r3 */ + "cbz %w[k], 1f\n" /* jump to remain */ + "0:\n" /* main loop */ + /* trans 16b */ + "trn1 v4.16b, v0.16b, v1.16b\n" /* get a0,b0, a2,b2, a4,b4, a6,b6, + a8,b8, a10,b10, a12,b12, a14,b14 */ + "trn2 v5.16b, v0.16b, v1.16b\n" /* get a1,b1, a3,b3, a5,b5, a7,b7, + a9,b9, a11,b11, a13,b13, a15,b15 */ + "trn1 v6.16b, v2.16b, v3.16b\n" /* get c0,d0, c2,d2, c4,d4, c6,d6, + c8,d8, c10,d10, c12,d12, c14,d14 */ + "trn2 v7.16b, v2.16b, v3.16b\n" /* get c1,d1, c3,d3, c5,d5, c7,d7, + c9,d9, c11,d11, c13,d13, c15,d15 */ + "ld1 {v0.16b}, [%[ptr0]], #16\n" /* load r0 */ + "ld1 {v1.16b}, [%[ptr1]], #16\n" /* load r1 */ + "subs %w[k], %w[k], #1\n" /* loop cnt -1 */ + /* trans 8h */ + "trn1 v8.8h, v4.8h, v5.8h\n" /* get a0,b0, a1,b1, a4,b4, a5,b5, a8,b8, + a9,b9, a12,b12, a13,b13 */ + "trn2 v9.8h, v4.8h, v5.8h\n" /* get a2,b2, a3,b3, a6,b6, a7,b7, + a10,b10, a11,b11, a14,b14, a15,b15 */ + "trn1 v10.8h, v6.8h, v7.8h\n" /* get c0,d0, c1,d1, c4,d4, c5,d5, + c8,d8, c9,d9, c12,d12, c13,d13 */ + "trn2 v11.8h, v6.8h, v7.8h\n" /* get c2,d2, c3,d3, c6,d6, c7,d7, + c10,d10, c11,d11, c14,d14, c15,d15 */ + /* trans 4s */ + "ld1 {v2.16b}, [%[ptr2]], #16\n" /* load r2 */ + "trn1 v4.4s, v8.4s, v9.4s\n" /* get a0,b0, a1,b1, a2,b2, a3,b3, a8,b8, + a9,b9, a10,b10, a11,b11 */ + "trn2 v5.4s, v8.4s, v9.4s\n" /* get a4,b4, a5,b5, a6,b6, a7,b7, + a12,b12, a13,b13, a14,b14, a15,b15 */ + "trn1 v6.4s, v10.4s, v11.4s\n" /* get c0,d0, c1,d1, c2,d2, c3,d3, + c8,d8, c9,d9, c10,d10, c11,d11 */ + "trn2 v7.4s, v10.4s, v11.4s\n" /* get c4,d4, c5,d5, c6,d6, c7,d7, + c12,d12, c13,d13, c14,d14, c15,d15 + */ + /* trans 2d */ + "ld1 {v3.16b}, [%[ptr3]], #16\n" /* load r3 */ + "trn1 v8.2d, v4.2d, v6.2d\n" /* get a0,b0, a1,b1, a2,b2, a3,b3, c0,d0, + c1,d1, c2,d2, c3,d3 */ + "trn1 v9.2d, v5.2d, v7.2d\n" /* get a4,b4, a5,b5, a6,b6, a7,b7, c4,d4, + c5,d5, c6,d6, c7,d7 */ + "trn2 v10.2d, v4.2d, v6.2d\n" /* get a8,b8, a9,b9, a10,b10, a11,b11, + c8,d8, c9,d9, c10,d10, c11,d11 */ + "trn2 v11.2d, v5.2d, v7.2d\n" /* get a12,b12, a13,b13, a14,b14, + a15,b15, c12,d12, c13,d13, c14,d14, + c15,d15 */ + "st1 {v8.16b}, [%[ptr_out]], %[stride]\n" /* write block0, address + + stride */ + "st1 {v9.16b}, [%[ptr_out]], %[stride]\n" /* write block1, address + + stride */ + "st1 {v10.16b}, [%[ptr_out]], %[stride]\n" /* write block2, address + + stride */ + "st1 {v11.16b}, [%[ptr_out]], %[stride]\n" /* write block3, address + + stride */ + "bgt 0b\n" /* jump to main loop */ + "1:\n" /* process remain */ + "cbz %w[rem], 2f\n" /* skip to remain */ + /* bit select */ + "bif v0.16b, %[vzero].16b, %[mask].16b\n" /* pad 0 */ + "bif v1.16b, %[vzero].16b, %[mask].16b\n" /* pad 0 */ + "bif v2.16b, %[vzero].16b, %[mask].16b\n" /* pad 0 */ + "bif v3.16b, %[vzero].16b, %[mask].16b\n" /* pad 0 */ + /* trans 16b */ + "trn1 v4.16b, v0.16b, v1.16b\n" /* get a0,b0, a2,b2, a4,b4, a6,b6, + a8,b8, a10,b10, a12,b12, a14,b14 */ + "trn2 v5.16b, v0.16b, v1.16b\n" /* get a1,b1, a3,b3, a5,b5, a7,b7, + a9,b9, a11,b11, a13,b13, a15,b15 */ + "trn1 v6.16b, v2.16b, v3.16b\n" /* get c0,d0, c2,d2, c4,d4, c6,d6, + c8,d8, c10,d10, c12,d12, c14,d14 */ + "trn2 v7.16b, v2.16b, v3.16b\n" /* get c1,d1, c3,d3, c5,d5, c7,d7, + c9,d9, c11,d11, c13,d13, c15,d15 */ + /* trans 8h */ + "trn1 v8.8h, v4.8h, v5.8h\n" /* get a0,b0, a1,b1, a4,b4, a5,b5, a8,b8, + a9,b9, a12,b12, a13,b13 */ + "trn2 v9.8h, v4.8h, v5.8h\n" /* get a2,b2, a3,b3, a6,b6, a7,b7, + a10,b10, a11,b11, a14,b14, a15,b15 */ + "trn1 v10.8h, v6.8h, 
v7.8h\n" /* get c0,d0, c1,d1, c4,d4, c5,d5, + c8,d8, c9,d9, c12,d12, c13,d13 */ + "trn2 v11.8h, v6.8h, v7.8h\n" /* get c2,d2, c3,d3, c6,d6, c7,d7, + c10,d10, c11,d11, c14,d14, c15,d15 */ + /* trans 4s */ + "trn1 v4.4s, v8.4s, v9.4s\n" /* get a0,b0, a1,b1, a2,b2, a3,b3, a8,b8, + a9,b9, a10,b10, a11,b11 */ + "trn2 v5.4s, v8.4s, v9.4s\n" /* get a4,b4, a5,b5, a6,b6, a7,b7, + a12,b12, a13,b13, a14,b14, a15,b15 */ + "trn1 v6.4s, v10.4s, v11.4s\n" /* get c0,d0, c1,d1, c2,d2, c3,d3, + c8,d8, c9,d9, c10,d10, c11,d11 */ + "trn2 v7.4s, v10.4s, v11.4s\n" /* get c4,d4, c5,d5, c6,d6, c7,d7, + c12,d12, c13,d13, c14,d14, c15,d15 + */ + /* trans 2d */ + "trn1 v8.2d, v4.2d, v6.2d\n" /* get a0,b0, a1,b1, a2,b2, a3,b3, c0,d0, + c1,d1, c2,d2, c3,d3 */ + "trn1 v9.2d, v5.2d, v7.2d\n" /* get a4,b4, a5,b5, a6,b6, a7,b7, c4,d4, + c5,d5, c6,d6, c7,d7 */ + "trn2 v10.2d, v4.2d, v6.2d\n" /* get a8,b8, a9,b9, a10,b10, a11,b11, + c8,d8, c9,d9, c10,d10, c11,d11 */ + "trn2 v11.2d, v5.2d, v7.2d\n" /* get a12,b12, a13,b13, a14,b14, + a15,b15, c12,d12, c13,d13, c14,d14, + c15,d15 */ + /* check remain size */ + "subs %w[rem], %w[rem], #1\n" /* check remain num */ + "st1 {v8.16b}, [%[ptr_out]], %[stride]\n" /* write 0 */ + "beq 2f\n" /* remain = 1 */ + "subs %w[rem], %w[rem], #1\n" /* check remain num */ + "st1 {v9.16b}, [%[ptr_out]], %[stride]\n" /* write 1 */ + "beq 2f\n" /* remain = 2 */ + "subs %w[rem], %w[rem], #1\n" /* check remain num */ + "st1 {v10.16b}, [%[ptr_out]], %[stride]\n" /* write 2 */ + "beq 2f\n" /* remain = 3 */ + "st1 {v11.16b}, [%[ptr_out]]\n" /* write 3 */ + /* end */ + "2:\n" /* end */ + : [ptr0] "+r"(ptr0), [ptr1] "+r"(ptr1), [ptr2] "+r"(ptr2), + [ptr3] "+r"(ptr3), [k] "+r"(k), [rem] "+r"(rem), + [ptr_out] "+r"(ptr_out) + : [mask] "w"(vmask), [vzero] "w"(vzero), [stride] "r"(stride_out) + : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", + "v11", "cc"); +#else // armv7 + asm volatile( + "vld1.8 {d0-d1}, [%[ptr0]]!\n" /* load r0 */ + "vld1.8 {d2-d3}, [%[ptr1]]!\n" /* load r1 */ + "vld1.8 {d4-d5}, [%[ptr2]]!\n" /* load r2 */ + "vld1.8 {d6-d7}, [%[ptr3]]!\n" /* load r3 */ + "cmp %[k], #0\n" /* check main loop */ + "beq 1f\n" /* jump to remain */ + "0:\n" /* main loop */ + /* trans 16b */ + "vtrn.8 q0, q1\n" /* get q0: a0,b0, a2,b2, a4,b4, a6,b6, a8,b8, a10,b10, + a12,b12, a14,b14; q1: a1,b1, a3,b3, a5,b5, a7,b7, + a9,b9, a11,b11, a13,b13, a15,b15 */ + "vtrn.8 q2, q3\n" /* get q2: c0,d0, c2,d2, c4,d4, c6,d6, c8,d8, c10,d10, + c12,d12, c14,d14; q3: c0,d0, c2,d2, c4,d4, c6,d6, + c8,d8, c10,d10, c12,d12, c14,d14 */ + "subs %[k], %[k], #1\n" /* loop cnt -1 */ + /* trans 8h */ + "vtrn.16 q0, q1\n" /* get q0: a0,b0, a1,b1, a4,b4, a5,b5, a8,b8, + a9,b9, a12,b12, a13,b13; q1: a2,b2, a3,b3, + a6,b6, a7,b7, a10,b10, a11,b11, a14,b14, + a15,b15 */ + "vtrn.16 q2, q3\n" /* get q2: c0,d0, c1,d1, c4,d4, c5,d5, c8,d8, + c9,d9, c12,d12, c13,d13; q3: c2,d2, c3,d3, + c6,d6, c7,d7, c10,d10, c11,d11, c14,d14, + c15,d15 */ + /* trans 4s */ + "vtrn.32 q0, q1\n" /* get q0: a0,b0, a1,b1, a2,b2, a3,b3, a8,b8, + a9,b9, a10,b10, a11,b11; q1: a4,b4, a5,b5, + a6,b6, a7,b7, a12,b12, a13,b13, a14,b14, + a15,b15 */ + "vtrn.32 q2, q3\n" /* get q2: c0,d0, c1,d1, c2,d2, c3,d3, c8,d8, + c9,d9, c10,d10, c11,d11; q3: c4,d4, c5,d5, + c6,d6, c7,d7, c12,d12, c13,d13, c14,d14, + c15,d15 */ + /* trans 2d */ + "vswp d1, d4\n" /* get q0: a0,b0, a1,b1, a2,b2, a3,b3, c0,d0, c1,d1, + c2,d2, c3,d3; q2: a8,b8, a9,b9, a10,b10, a11,b11, + c8,d8, c9,d9, c10,d10, c11,d11 */ + "vswp d3, d6\n" /* get q1: a4,b4, a5,b5, a6,b6, a7,b7, c4,d4, 
c5,d5, + c6,d6, c7,d7; q3: a12,b12, a13,b13, a14,b14, + a15,b15, c12,d12, c13,d13, c14,d14, c15,d15 */ + "vst1.8 {d0-d1}, [%[ptr_out]], %[stride]\n" /* write block0, address + + stride */ + "vst1.8 {d2-d3}, [%[ptr_out]], %[stride]\n" /* write block1, address + + stride */ + "vst1.8 {d4-d5}, [%[ptr_out]], %[stride]\n" /* write block2, address + + stride */ + "vst1.8 {d6-d7}, [%[ptr_out]], %[stride]\n" /* write block3, address + + stride */ + "vld1.8 {d0-d1}, [%[ptr0]]!\n" /* load r0 */ + "vld1.8 {d2-d3}, [%[ptr1]]!\n" /* load r1 */ + "vld1.8 {d4-d5}, [%[ptr2]]!\n" /* load r2 */ + "vld1.8 {d6-d7}, [%[ptr3]]!\n" /* load r3 */ + "bgt 0b\n" /* jump to main loop */ + "1:\n" /* process remain */ + "cmp %[rem], #0\n" /* check remain */ + "beq 2f\n" /* skip to remain */ + /* bit select */ + "vbif q0, %q[vzero], %q[mask]\n" /* pad 0 */ + "vbif q1, %q[vzero], %q[mask]\n" /* pad 0 */ + "vbif q2, %q[vzero], %q[mask]\n" /* pad 0 */ + "vbif q3, %q[vzero], %q[mask]\n" /* pad 0 */ + /* trans 16b */ + "vtrn.8 q0, q1\n" /* get q0: a0,b0, a2,b2, a4,b4, a6,b6, a8,b8, a10,b10, + a12,b12, a14,b14; q1: a1,b1, a3,b3, a5,b5, a7,b7, + a9,b9, a11,b11, a13,b13, a15,b15 */ + "vtrn.8 q2, q3\n" /* get q2: c0,d0, c2,d2, c4,d4, c6,d6, c8,d8, c10,d10, + c12,d12, c14,d14; q3: c0,d0, c2,d2, c4,d4, c6,d6, + c8,d8, c10,d10, c12,d12, c14,d14 */ + /* trans 8h */ + "vtrn.16 q0, q1\n" /* get q0: a0,b0, a1,b1, a4,b4, a5,b5, a8,b8, + a9,b9, a12,b12, a13,b13; q1: a2,b2, a3,b3, + a6,b6, a7,b7, a10,b10, a11,b11, a14,b14, + a15,b15 */ + "vtrn.16 q2, q3\n" /* get q2: c0,d0, c1,d1, c4,d4, c5,d5, c8,d8, + c9,d9, c12,d12, c13,d13; q3: c2,d2, c3,d3, + c6,d6, c7,d7, c10,d10, c11,d11, c14,d14, + c15,d15 */ + /* trans 4s */ + "vtrn.32 q0, q1\n" /* get q0: a0,b0, a1,b1, a2,b2, a3,b3, a8,b8, + a9,b9, a10,b10, a11,b11; q1: a4,b4, a5,b5, + a6,b6, a7,b7, a12,b12, a13,b13, a14,b14, + a15,b15 */ + "vtrn.32 q2, q3\n" /* get q2: c0,d0, c1,d1, c2,d2, c3,d3, c8,d8, + c9,d9, c10,d10, c11,d11; q3: c4,d4, c5,d5, + c6,d6, c7,d7, c12,d12, c13,d13, c14,d14, + c15,d15 */ + /* trans 2d */ + "vswp d1, d4\n" /* get q0: a0,b0, a1,b1, a2,b2, a3,b3, c0,d0, c1,d1, + c2,d2, c3,d3; q2: a8,b8, a9,b9, a10,b10, a11,b11, + c8,d8, c9,d9, c10,d10, c11,d11 */ + "vswp d3, d6\n" /* get q1: a4,b4, a5,b5, a6,b6, a7,b7, c4,d4, c5,d5, + c6,d6, c7,d7; q3: a12,b12, a13,b13, a14,b14, + a15,b15, c12,d12, c13,d13, c14,d14, c15,d15 */ + /* check remain size */ + "subs %[rem], %[rem], #1\n" /* check remain num */ + "vst1.8 {d0-d1}, [%[ptr_out]], %[stride]\n" /* write 0 */ + "beq 2f\n" /* remain = 1 */ + "subs %[rem], %[rem], #1\n" /* check remain num */ + "vst1.8 {d2-d3}, [%[ptr_out]], %[stride]\n" /* write 1 */ + "beq 2f\n" /* remain = 2 */ + "subs %[rem], %[rem], #1\n" /* check remain num */ + "vst1.8 {d4-d5}, [%[ptr_out]], %[stride]\n" /* write 2 */ + "beq 2f\n" /* remain = 3 */ + "vst1.8 {d6-d7}, [%[ptr_out]], %[stride]\n" /* write 3 */ + /* end */ + "2:\n" /* end */ + : [ptr0] "+r"(ptr0), [ptr1] "+r"(ptr1), [ptr2] "+r"(ptr2), + [ptr3] "+r"(ptr3), [k] "+r"(k), [rem] "+r"(rem), + [ptr_out] "+r"(ptr_out) + : [mask] "w"(vmask), [vzero] "w"(vzero), [stride] "r"(stride_out) + : "q0", "q1", "q2", "q3", "cc"); +#endif //__aarch64__ // NOLINT + } + free(zerobuf); +} + +/**************************************************************************/ +// for armv8 +// prepack B according to gemm kernel +// B block size: (<4x2>x4) x2, can be described as below: +// origin B data: +// B_origin(no trans, k x n): +// r0: ==> a0, a1, a2, a3 .... a12, a13, a14, a15 +// r1: ==> b0, b1, b2, b3 .... 
b12, b13, b14, b15 +// r2: ==> c0, c1, c2, c3 .... c12, c13, c14, c15 +// r3: ==> d0, d1, d2, d3 .... d12, d13, d14, d15 +// packed B: +// a0,b0, a1,b1, a2,b2, a3,b3; +// c0,d0, c1,d1, c2,d2, c3,d3; +// . +// . +// . +// a12,b12, a13,b13, a14,b14, a15,b15; +// c12,d12, c13,d13, c14,d14, c15,d15; +// for armv7 +// prepack B according to gemm kernel +// B block size: (<4x2>x4) x2, can be described as below: +// origin B data: +// B_origin(no trans, k x n): +// r0: ==> a0, a1, a2, a3, a4, a5, a6, a7 +// r1: ==> b0, b1, b2, b3, b4, b5, b6, b7 +// r2: ==> c0, c1, c2, c3, c4, c5, c6, c7 +// r3: ==> d0, d1, d2, d3, d4, d5, d6, d7 +// packed B: +// a0,b0, a1,b1, a2,b2, a3,b3; +// a4,b4, a5,b5, a6,b6, a7,b7; +// c0,d0, c1,d1, c2,d2, c3,d3; +// c4,d4, c5,d5, c6,d6, c7,d7; +/***************************************************************************/ +void packb_int8(int8_t* out, const int8_t* in, const int ldin, const int k0, + const int kmax, const int n0, const int nmax, + const int8_t* zerobuf) { + const int8_t* inptr = in + k0 * ldin + n0; + const uint8_t mask_buffer[16] = {0, 1, 2, 3, 4, 5, 6, 7, + 8, 9, 10, 11, 12, 13, 14, 15}; + int x_len = nmax - n0; + int y_len = kmax - k0; + int kup = ROUNDUP(y_len, KBLOCK_INT8); + int kcnt = x_len / NBLOCK_INT8_OTH; + int rem = x_len & (NBLOCK_INT8_OTH - 1); + int stride_out = NBLOCK_INT8_OTH * kup; + + int8x16_t vzero = vdupq_n_s8(0); + uint8x16_t vmask = vcltq_u8(vld1q_u8(mask_buffer), vdupq_n_u8(rem)); +#pragma omp parallel for + for (int y = 0; y < y_len; y += KBLOCK_INT8) { + const int8_t* ptr0 = inptr + y * ldin; + const int8_t* ptr1 = ptr0 + ldin; + const int8_t* ptr2 = ptr1 + ldin; + const int8_t* ptr3 = ptr2 + ldin; + if (y + KBLOCK_INT8 > y_len) { + switch (y + KBLOCK_INT8 - y_len) { + case 3: + ptr1 = zerobuf; + case 2: + ptr2 = zerobuf; + case 1: + ptr3 = zerobuf; + default: + break; + } + } + int8_t* outptr_row_col = out + y * NBLOCK_INT8_OTH; + int k = kcnt; +#ifdef __aarch64__ + asm volatile( + "ld1 {v0.16b}, [%[ptr0]], #16\n" /* load r0 */ + "ld1 {v1.16b}, [%[ptr1]], #16\n" /* load r1 */ + "ld1 {v2.16b}, [%[ptr2]], #16\n" /* load r2 */ + "ld1 {v3.16b}, [%[ptr3]], #16\n" /* load r3 */ + "cbz %w[k], 1f\n" /* jump to remain */ + "0:\n" /* main loop */ + /* trans 16b */ + "trn1 v4.16b, v0.16b, v1.16b\n" /* get a0,b0, a2,b2, a4,b4, a6,b6, + a8,b8, a10,b10, a12,b12, a14,b14 */ + "trn2 v5.16b, v0.16b, v1.16b\n" /* get a1,b1, a3,b3, a5,b5, a7,b7, + a9,b9, a11,b11, a13,b13, a15,b15 */ + "trn1 v6.16b, v2.16b, v3.16b\n" /* get c0,d0, c2,d2, c4,d4, c6,d6, + c8,d8, c10,d10, c12,d12, c14,d14 */ + "trn2 v7.16b, v2.16b, v3.16b\n" /* get c1,d1, c3,d3, c5,d5, c7,d7, + c9,d9, c11,d11, c13,d13, c15,d15 */ + "ld1 {v0.16b}, [%[ptr0]], #16\n" /* load r0 */ + "ld1 {v1.16b}, [%[ptr1]], #16\n" /* load r1 */ + "subs %w[k], %w[k], #1\n" /* loop cnt -1 */ + /* trans 8h */ + "trn1 v8.8h, v4.8h, v5.8h\n" /* get a0,b0, a1,b1, a4,b4, a5,b5, a8,b8, + a9,b9, a12,b12, a13,b13 */ + "trn2 v9.8h, v4.8h, v5.8h\n" /* get a2,b2, a3,b3, a6,b6, a7,b7, + a10,b10, a11,b11, a14,b14, a15,b15 */ + "trn1 v10.8h, v6.8h, v7.8h\n" /* get c0,d0, c1,d1, c4,d4, c5,d5, + c8,d8, c9,d9, c12,d12, c13,d13 */ + "trn2 v11.8h, v6.8h, v7.8h\n" /* get c2,d2, c3,d3, c6,d6, c7,d7, + c10,d10, c11,d11, c14,d14, c15,d15 */ + /* trans 4s */ + "ld1 {v2.16b}, [%[ptr2]], #16\n" /* load r2 */ + "trn1 v4.4s, v8.4s, v9.4s\n" /* get a0,b0, a1,b1, a2,b2, a3,b3, a8,b8, + a9,b9, a10,b10, a11,b11 */ + "trn2 v5.4s, v8.4s, v9.4s\n" /* get a4,b4, a5,b5, a6,b6, a7,b7, + a12,b12, a13,b13, a14,b14, a15,b15 */ + "trn1 v6.4s, 
v10.4s, v11.4s\n" /* get c0,d0, c1,d1, c2,d2, c3,d3, + c8,d8, c9,d9, c10,d10, c11,d11 */ + "trn2 v7.4s, v10.4s, v11.4s\n" /* get c4,d4, c5,d5, c6,d6, c7,d7, + c12,d12, c13,d13, c14,d14, c15,d15 + */ + /* trans 2d */ + "ld1 {v3.16b}, [%[ptr3]], #16\n" /* load r3 */ + "trn1 v8.2d, v4.2d, v6.2d\n" /* get a0,b0, a1,b1, a2,b2, a3,b3, c0,d0, + c1,d1, c2,d2, c3,d3 */ + "trn2 v10.2d, v4.2d, v6.2d\n" /* get a8,b8, a9,b9, a10,b10, a11,b11, + c8,d8, c9,d9, c10,d10, c11,d11 */ + "trn1 v9.2d, v5.2d, v7.2d\n" /* get a4,b4, a5,b5, a6,b6, a7,b7, c4,d4, + c5,d5, c6,d6, c7,d7 */ + "trn2 v11.2d, v5.2d, v7.2d\n" /* get a12,b12, a13,b13, a14,b14, + a15,b15, c12,d12, c13,d13, c14,d14, + c15,d15 */ + "st1 {v8.16b, v9.16b, v10.16b, v11.16b}, [%[ptr_out]], %[stride]\n" + "bgt 0b\n" /* jump to main loop */ + "1:\n" /* process remain */ + "cbz %w[rem], 2f\n" /* jump to remain */ + /* bit select */ + "bif v0.16b, %[vzero].16b, %[mask].16b\n" /* pad 0 */ + "bif v1.16b, %[vzero].16b, %[mask].16b\n" /* pad 0 */ + "bif v2.16b, %[vzero].16b, %[mask].16b\n" /* pad 0 */ + "bif v3.16b, %[vzero].16b, %[mask].16b\n" /* pad 0 */ + /* trans 16b */ + "trn1 v4.16b, v0.16b, v1.16b\n" /* get a0,b0, a2,b2, a4,b4, a6,b6, + a8,b8, a10,b10, a12,b12, a14,b14 */ + "trn2 v5.16b, v0.16b, v1.16b\n" /* get a1,b1, a3,b3, a5,b5, a7,b7, + a9,b9, a11,b11, a13,b13, a15,b15 */ + "trn1 v6.16b, v2.16b, v3.16b\n" /* get c0,d0, c2,d2, c4,d4, c6,d6, + c8,d8, c10,d10, c12,d12, c14,d14 */ + "trn2 v7.16b, v2.16b, v3.16b\n" /* get c1,d1, c3,d3, c5,d5, c7,d7, + c9,d9, c11,d11, c13,d13, c15,d15 */ + /* trans 8h */ + "trn1 v8.8h, v4.8h, v5.8h\n" /* get a0,b0, a1,b1, a4,b4, a5,b5, a8,b8, + a9,b9, a12,b12, a13,b13 */ + "trn2 v9.8h, v4.8h, v5.8h\n" /* get a2,b2, a3,b3, a6,b6, a7,b7, + a10,b10, a11,b11, a14,b14, a15,b15 */ + "trn1 v10.8h, v6.8h, v7.8h\n" /* get c0,d0, c1,d1, c4,d4, c5,d5, + c8,d8, c9,d9, c12,d12, c13,d13 */ + "trn2 v11.8h, v6.8h, v7.8h\n" /* get c2,d2, c3,d3, c6,d6, c7,d7, + c10,d10, c11,d11, c14,d14, c15,d15 */ + /* trans 4s */ + "trn1 v4.4s, v8.4s, v9.4s\n" /* get a0,b0, a1,b1, a2,b2, a3,b3, a8,b8, + a9,b9, a10,b10, a11,b11 */ + "trn2 v5.4s, v8.4s, v9.4s\n" /* get a4,b4, a5,b5, a6,b6, a7,b7, + a12,b12, a13,b13, a14,b14, a15,b15 */ + "trn1 v6.4s, v10.4s, v11.4s\n" /* get c0,d0, c1,d1, c2,d2, c3,d3, + c8,d8, c9,d9, c10,d10, c11,d11 */ + "trn2 v7.4s, v10.4s, v11.4s\n" /* get c4,d4, c5,d5, c6,d6, c7,d7, + c12,d12, c13,d13, c14,d14, c15,d15 + */ + /* trans 2d */ + "trn1 v8.2d, v4.2d, v6.2d\n" /* get a0,b0, a1,b1, a2,b2, a3,b3, c0,d0, + c1,d1, c2,d2, c3,d3 */ + "trn2 v10.2d, v4.2d, v6.2d\n" /* get a8,b8, a9,b9, a10,b10, a11,b11, + c8,d8, c9,d9, c10,d10, c11,d11 */ + "trn1 v9.2d, v5.2d, v7.2d\n" /* get a4,b4, a5,b5, a6,b6, a7,b7, c4,d4, + c5,d5, c6,d6, c7,d7 */ + "trn2 v11.2d, v5.2d, v7.2d\n" /* get a12,b12, a13,b13, a14,b14, + a15,b15, c12,d12, c13,d13, c14,d14, + c15,d15 */ + "st1 {v8.16b, v9.16b, v10.16b, v11.16b}, [%[ptr_out]]\n" /* save to + memory + */ + /* end */ + "2:\n" /* end */ + : [ptr0] "+r"(ptr0), [ptr1] "+r"(ptr1), [ptr2] "+r"(ptr2), + [ptr3] "+r"(ptr3), [k] "+r"(k), [ptr_out] "+r"(outptr_row_col) + : [rem] "r"(rem), [mask] "w"(vmask), [vzero] "w"(vzero), + [stride] "r"(stride_out) + : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", + "v11", "cc"); +#else // armv7 + asm volatile( + "vld1.8 {d0}, [%[ptr0]]!\n" /* load r0, a0,a1,a2,a3,a4,a5,a6,a7 */ + "vld1.8 {d1}, [%[ptr1]]!\n" /* load r1, b0,b1,b2,b3,b4,b5,b6,b7 */ + "vld1.8 {d2}, [%[ptr2]]!\n" /* load r2, c0,c1,c2,c3,c4,c5,c6,c7 */ + "vld1.8 {d3}, [%[ptr3]]!\n" /* 
load r3, d0,d1,d2,d3,d4,d5,d6,d7 */ + "cmp %[k], #0\n" /* check main loop count */ + "beq 1f\n" /* jump to remain */ + "0:\n" /* main loop */ + /* trans 8b */ + "vtrn.8 d0, d1\n" /* get d0: a0,b0, a2,b2, a4,b4, a6,b6; d1: a1,b1, + a3,b3, a5,b5, a7,b7 */ + "vtrn.8 d2, d3\n" /* get d2: c0,d0, c2,d2, c4,d4, c6,d6; d3: c1,d1, + c3,d3, c5,d5, c7,d7 */ + /* trans 4h */ + "vtrn.16 d0, d1\n" /* get d0: a0,b0, a1,b1, a4,b4, a5,b5; d1: a2,b2, + a3,b3, a6,b6, a7,b7 */ + "vtrn.16 d2, d3\n" /* get d2: c0,d0, c1,d1, c4,d4, c5,d5; d3: c2,d2, + c3,d3, c6,d6, c7,d7 */ + "subs %[k], %[k], #1\n" /* loop - 1 */ + /* trans 2s */ + "vtrn.32 d0, d1\n" /* get d0: a0,b0, a1,b1, a2,b2, a3,b3; d1: a4,b4, + a5,b5, a6,b6, a7,b7 */ + "vtrn.32 d2, d3\n" /* get d2: c0,d0, c1,d1, c2,d2, c3,d3; d3: c4,d4, + c5,d5, c6,d6, c7,d7 */ + "vst1.8 {d0-d3}, [%[ptr_out]], %[stride]\n" /* save to memory */ + "vld1.8 {d0}, [%[ptr0]]!\n" /* load r0, a0,a1,a2,a3,a4,a5,a6,a7 */ + "vld1.8 {d1}, [%[ptr1]]!\n" /* load r1, b0,b1,b2,b3,b4,b5,b6,b7 */ + "vld1.8 {d2}, [%[ptr2]]!\n" /* load r2, c0,c1,c2,c3,c4,c5,c6,c7 */ + "vld1.8 {d3}, [%[ptr3]]!\n" /* load r3, d0,d1,d2,d3,d4,d5,d6,d7 */ + "bgt 0b\n" /* jump to main loop */ + "1:\n" /* process remain */ + "cmp %[rem], #0\n" /* check remain size */ + "beq 2f\n" /* jump to end */ + /* bit select */ + "vbif d0, %e[vzero], %e[mask]\n" /* pad 0 */ + "vbif d1, %e[vzero], %e[mask]\n" /* pad 0 */ + "vbif d2, %e[vzero], %e[mask]\n" /* pad 0 */ + "vbif d3, %e[vzero], %e[mask]\n" /* pad 0 */ + /* trans 8b */ + "vtrn.8 d0, d1\n" /* get d0: a0,b0, a2,b2, a4,b4, a6,b6; d1: a1,b1, + a3,b3, a5,b5, a7,b7 */ + "vtrn.8 d2, d3\n" /* get d2: c0,d0, c2,d2, c4,d4, c6,d6; d3: c1,d1, + c3,d3, c5,d5, c7,d7 */ + /* trans 4h */ + "vtrn.16 d0, d1\n" /* get d0: a0,b0, a1,b1, a4,b4, a5,b5; d1: a2,b2, + a3,b3, a6,b6, a7,b7 */ + "vtrn.16 d2, d3\n" /* get d2: c0,d0, c1,d1, c4,d4, c5,d5; d3: c2,d2, + c3,d3, c6,d6, c7,d7 */ + /* trans 2s */ + "vtrn.32 d0, d1\n" /* get d0: a0,b0, a1,b1, a2,b2, a3,b3; d1: a4,b4, + a5,b5, a6,b6, a7,b7 */ + "vtrn.32 d2, d3\n" /* get d2: c0,d0, c1,d1, c2,d2, c3,d3; d3: c4,d4, + c5,d5, c6,d6, c7,d7 */ + "vst1.8 {d0-d3}, [%[ptr_out]]\n" /* save to memory */ + /* end */ + "2:\n" /* end */ + : [ptr0] "+r"(ptr0), [ptr1] "+r"(ptr1), [ptr2] "+r"(ptr2), + [ptr3] "+r"(ptr3), [k] "+r"(k), [ptr_out] "+r"(outptr_row_col) + : [rem] "r"(rem), [mask] "w"(vmask), [vzero] "w"(vzero), + [stride] "r"(stride_out) + : "q0", "q1", "cc"); +#endif //__aarch64__ // NOLINT + } +} + +/************************************************************************/ +// prepack B according to gemm kernel +// origin B data: +// B_origin(transpose, n x k: +// k unroll 2, a0=k0,k1 +// r0: ==> a0, a1, a2, a3, a4, a5, a6, a7 +// r1: ==> b0, b1, b2, b3, b4, b5, b6, b7 +// r2: ==> c0, c1, c2, c3, c4, c5, c6, c7 +// r3: ==> d0, d1, d2, d3, d4, d5, d6, d7 +// r4: ==> e0, e1, e2, e3, e4, e5, e6, e7 +// r5: ==> f0, f1, f2, f3, f4, f5, f6, f7 +// r6: ==> g0, g1, g2, g3, g4, g5, g6, g7 +// r7: ==> h0, h1, h2, h3, h4, h5, h6, h7 +// for armv8: +// B block size: (<4x2>x4) x2, can be described as below: +// packed B: +// a0,b0, c0,d0, a1,b1, c1,d1; +// e0,f0, g0,h0, e1,f1, g1,h1;--block0, address+64 +// . +// . +// . +// a6,b6, c6,d6, a7,b7, c7,d7; +// e6,f6, g6,h6, e7,f7, g7,h7;--block3, address+64 +// for armv7: +// B block size: (<8x2>x1) x2, can be described as below: +// packed B: +// a0,b0, c0,d0, e0,f0, g0,h0; +// a1,b1, c1,d1, e1,f1, g1,h1;--block0, address+32 +// . +// . +// . 
+// a6,b6, c6,d6, e6,f6, g6,h6; +// a7,b7, c7,d7, e7,f7, g7,h7;--block3, address+32 +/*******************************************************************/ +void packb_trans_int8(int8_t* out, const int8_t* in, const int ldin, + const int k0, const int kmax, const int n0, + const int nmax, const int8_t* zerobuf) { + const int KUNROLL = 4; + const int NUNROLL = 8; + const int RATIO = NBLOCK_INT8_OTH / NUNROLL; + const int8_t* inptr = in + n0 * ldin + k0; + const uint8_t mask_buffer[16] = {0, 1, 2, 3, 4, 5, 6, 7, + 8, 9, 10, 11, 12, 13, 14, 15}; + int y_len = nmax - n0; + int x_len = kmax - k0; + int yup = ROUNDUP(y_len, NBLOCK_INT8_OTH); + const int kup = ROUNDUP(x_len, KBLOCK_INT8); + const int KSTRIDE = KBLOCK_INT8 * KUNROLL; + int kcnt = x_len / KSTRIDE; + int x_rem = (x_len & (KSTRIDE - 1)); + int k_rem = (x_rem + KBLOCK_INT8 - 1) / KBLOCK_INT8; + const int stride_inner = KBLOCK_INT8 * NUNROLL; + const int stride_outer = kup * NBLOCK_INT8_OTH; + const int ncnt = yup / NUNROLL; + + int8x16_t vzero = vdupq_n_s8(0); + uint8x16_t vmask = vcltq_u8(vld1q_u8(mask_buffer), vdupq_n_u8(x_rem)); + +#pragma omp parallel for + for (int y = 0; y < ncnt; y++) { + int idx = y * NUNROLL; + const int8_t* ptr0 = inptr + idx * ldin; + const int8_t* ptr1 = ptr0 + ldin; + const int8_t* ptr2 = ptr1 + ldin; + const int8_t* ptr3 = ptr2 + ldin; + const int8_t* ptr4 = ptr3 + ldin; + const int8_t* ptr5 = ptr4 + ldin; + const int8_t* ptr6 = ptr5 + ldin; + const int8_t* ptr7 = ptr6 + ldin; + // only for ratio = 0 or 1 + int8_t* ptr_out = + out + (y & (RATIO - 1)) * stride_inner + (y / RATIO) * stride_outer; + if (idx + NUNROLL > y_len) { + switch (idx + NUNROLL - y_len) { + case 8: + ptr0 = zerobuf; + case 7: + ptr1 = zerobuf; + case 6: + ptr2 = zerobuf; + case 5: + ptr3 = zerobuf; + case 4: + ptr4 = zerobuf; + case 3: + ptr5 = zerobuf; + case 2: + ptr6 = zerobuf; + case 1: + ptr7 = zerobuf; + default: + break; + } + } + int k = kcnt; + int rem = k_rem; +#ifdef __aarch64__ + asm volatile( + "cbz %w[k], 1f\n" /* skip main loop */ + /* main loop */ + "0:\n" /* main loop */ + "ld1 {v0.16b}, [%[ptr0]], #16\n" /* load n0, k0~k15 */ + "ld1 {v1.16b}, [%[ptr1]], #16\n" /* load n1, k0~k15 */ + "ld1 {v2.16b}, [%[ptr2]], #16\n" /* load n2, k0~k15 */ + "ld1 {v3.16b}, [%[ptr3]], #16\n" /* load n3, k0~k15 */ + "ld1 {v4.16b}, [%[ptr4]], #16\n" /* load n4, k0~k15 */ + "ld1 {v5.16b}, [%[ptr5]], #16\n" /* load n5, k0~k15 */ + "ld1 {v6.16b}, [%[ptr6]], #16\n" /* load n6, k0~k15 */ + "ld1 {v7.16b}, [%[ptr7]], #16\n" /* load n7, k0~k15 */ + /* trans, 8h */ + "trn1 v8.8h, v0.8h, v1.8h\n" /* trans, zip n0,n1 */ + "trn2 v9.8h, v0.8h, v1.8h\n" /* trans, zip n0,n1 */ + "trn1 v10.8h, v2.8h, v3.8h\n" /* trans, zip n2,n3 */ + "trn2 v11.8h, v2.8h, v3.8h\n" /* trans, zip n2,n3 */ + "trn1 v12.8h, v4.8h, v5.8h\n" /* trans, zip n4,n5 */ + "trn2 v13.8h, v4.8h, v5.8h\n" /* trans, zip n4,n5 */ + "trn1 v14.8h, v6.8h, v7.8h\n" /* trans, zip n6,n7 */ + "trn2 v15.8h, v6.8h, v7.8h\n" /* trans, zip n6,n7 */ + /* trans, 4s */ + "trn1 v16.4s, v8.4s, v10.4s\n" /* trans, block 0 */ + "trn2 v17.4s, v8.4s, v10.4s\n" /* trans, block 0 */ + "trn1 v18.4s, v9.4s, v11.4s\n" /* trans, block 0 */ + "trn2 v19.4s, v9.4s, v11.4s\n" /* trans, block 0 */ + "trn1 v20.4s, v12.4s, v14.4s\n" /* trans, block 1 */ + "trn2 v21.4s, v12.4s, v14.4s\n" /* trans, block 1 */ + "trn1 v22.4s, v13.4s, v15.4s\n" /* trans, block 1 */ + "trn2 v23.4s, v13.4s, v15.4s\n" /* trans, block 1 */ + "subs %w[k], %w[k], #1\n" /* loop count -1 */ + /* trans, 2d */ + "trn1 v8.2d, v16.2d, v18.2d\n" /* trans, 
block 0, out0 */ + "trn1 v9.2d, v20.2d, v22.2d\n" /* trans, block 1, out0 */ + "trn1 v10.2d, v17.2d, v19.2d\n" /* trans, block 0, out1 */ + "trn1 v11.2d, v21.2d, v23.2d\n" /* trans, block 1, out1 */ + "trn2 v12.2d, v16.2d, v18.2d\n" /* trans, block 0, out2 */ + "trn2 v13.2d, v20.2d, v22.2d\n" /* trans, block 1, out2 */ + "trn2 v14.2d, v17.2d, v19.2d\n" /* trans, block 0, out3 */ + "trn2 v15.2d, v21.2d, v23.2d\n" /* trans, block 1, out3 */ + /* store result */ + "stp q8, q9, [%[ptr_out]],#64\n" /* write 0 */ + "stp q10, q11, [%[ptr_out]],#64\n" /* write 1 */ + "stp q12, q13, [%[ptr_out]],#64\n" /* write 2 */ + "stp q14, q15, [%[ptr_out]],#64\n" /* write 3 */ + "bgt 0b\n" /* jump to main loop */ + /* process remain */ + "1:\n" /* process remains */ + "cbz %w[rem], 2f\n" /* no remain, jump to end */ + "ld1 {v0.16b}, [%[ptr0]]\n" /* load n0, k0~k15 */ + "ld1 {v1.16b}, [%[ptr1]]\n" /* load n1, k0~k15 */ + "ld1 {v2.16b}, [%[ptr2]]\n" /* load n2, k0~k15 */ + "ld1 {v3.16b}, [%[ptr3]]\n" /* load n3, k0~k15 */ + "ld1 {v4.16b}, [%[ptr4]]\n" /* load n4, k0~k15 */ + "ld1 {v5.16b}, [%[ptr5]]\n" /* load n5, k0~k15 */ + "ld1 {v6.16b}, [%[ptr6]]\n" /* load n6, k0~k15 */ + "ld1 {v7.16b}, [%[ptr7]]\n" /* load n7, k0~k15 */ + /* bit select */ + "bif v0.16b, %[vzero].16b, %[mask].16b\n" /* pad 0 */ + "bif v1.16b, %[vzero].16b, %[mask].16b\n" /* pad 0 */ + "bif v2.16b, %[vzero].16b, %[mask].16b\n" /* pad 0 */ + "bif v3.16b, %[vzero].16b, %[mask].16b\n" /* pad 0 */ + "bif v4.16b, %[vzero].16b, %[mask].16b\n" /* pad 0 */ + "bif v5.16b, %[vzero].16b, %[mask].16b\n" /* pad 0 */ + "bif v6.16b, %[vzero].16b, %[mask].16b\n" /* pad 0 */ + "bif v7.16b, %[vzero].16b, %[mask].16b\n" /* pad 0 */ + /* trans, 8h */ + "trn1 v8.8h, v0.8h, v1.8h\n" /* trans, zip n0,n1 */ + "trn2 v9.8h, v0.8h, v1.8h\n" /* trans, zip n0,n1 */ + "trn1 v10.8h, v2.8h, v3.8h\n" /* trans, zip n2,n3 */ + "trn2 v11.8h, v2.8h, v3.8h\n" /* trans, zip n2,n3 */ + "trn1 v12.8h, v4.8h, v5.8h\n" /* trans, zip n4,n5 */ + "trn2 v13.8h, v4.8h, v5.8h\n" /* trans, zip n4,n5 */ + "trn1 v14.8h, v6.8h, v7.8h\n" /* trans, zip n6,n7 */ + "trn2 v15.8h, v6.8h, v7.8h\n" /* trans, zip n6,n7 */ + /* trans, 4s */ + "trn1 v16.4s, v8.4s, v10.4s\n" /* trans, block 0 */ + "trn2 v17.4s, v8.4s, v10.4s\n" /* trans, block 0 */ + "trn1 v18.4s, v9.4s, v11.4s\n" /* trans, block 0 */ + "trn2 v19.4s, v9.4s, v11.4s\n" /* trans, block 0 */ + "trn1 v20.4s, v12.4s, v14.4s\n" /* trans, block 1 */ + "trn2 v21.4s, v12.4s, v14.4s\n" /* trans, block 1 */ + "trn1 v22.4s, v13.4s, v15.4s\n" /* trans, block 1 */ + "trn2 v23.4s, v13.4s, v15.4s\n" /* trans, block 1 */ + /* trans, 2d */ + "trn1 v8.2d, v16.2d, v18.2d\n" /* trans, block 0, out0 */ + "trn1 v9.2d, v20.2d, v22.2d\n" /* trans, block 1, out0 */ + "trn1 v10.2d, v17.2d, v19.2d\n" /* trans, block 0, out1 */ + "trn1 v11.2d, v21.2d, v23.2d\n" /* trans, block 1, out1 */ + "trn2 v12.2d, v16.2d, v18.2d\n" /* trans, block 0, out2 */ + "trn2 v13.2d, v20.2d, v22.2d\n" /* trans, block 1, out2 */ + "trn2 v14.2d, v17.2d, v19.2d\n" /* trans, block 0, out3 */ + "trn2 v15.2d, v21.2d, v23.2d\n" /* trans, block 1, out3 */ + /* check remain size */ + "subs %w[rem], %w[rem], #1\n" /* check remain num */ + "stp q8, q9, [%[ptr_out]],#64\n" /* write 0 */ + "beq 2f\n" /* remain = 1 */ + "subs %w[rem], %w[rem], #1\n" /* check remain num */ + "stp q10, q11, [%[ptr_out]],#64\n" /* write 1 */ + "beq 2f\n" /* remain = 2 */ + "subs %w[rem], %w[rem], #1\n" /* check remain num */ + "stp q12, q13, [%[ptr_out]],#64\n" /* write 2 */ + "beq 2f\n" /* remain = 3 */ + "stp q14, q15, 
[%[ptr_out]]\n" /* write 3 */ + /* end */ + "2:\n" /* end */ + : [ptr0] "+r"(ptr0), [ptr1] "+r"(ptr1), [ptr2] "+r"(ptr2), + [ptr3] "+r"(ptr3), [ptr4] "+r"(ptr4), [ptr5] "+r"(ptr5), + [ptr6] "+r"(ptr6), [ptr7] "+r"(ptr7), [ptr_out] "+r"(ptr_out), + [k] "+r"(k), [rem] "+r"(rem) + : [mask] "w"(vmask), [vzero] "w"(vzero) + : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", + "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", + "v21", "v22", "v23", "cc"); +#else // armv7 + asm volatile( + "cmp %[k], #0\n" /* check main loop */ + "beq 1f\n" /* skip main loop */ + /* main loop */ + "0:\n" /* main loop */ + "vld1.8 {d0-d1}, [%[ptr0]]!\n" /* load n0, a0~a7 */ + "vld1.8 {d2-d3}, [%[ptr1]]!\n" /* load n1, b0~b7 */ + "vld1.8 {d4-d5}, [%[ptr2]]!\n" /* load n2, c0~c7 */ + "vld1.8 {d6-d7}, [%[ptr3]]!\n" /* load n3, d0~d7 */ + "vld1.8 {d8-d9}, [%[ptr4]]!\n" /* load n4, e0~e7 */ + "vld1.8 {d10-d11}, [%[ptr5]]!\n" /* load n5, f0~f7 */ + "vld1.8 {d12-d13}, [%[ptr6]]!\n" /* load n6, g0~g7 */ + "vld1.8 {d14-d15}, [%[ptr7]]!\n" /* load n7, h0~h7 */ + /* trans, 8h */ + "vtrn.16 q0, q1\n" /* trans, zip n0,n1, q0: a0b0,a2b2, a4b4,a6b6, q1: + a1b1,a3b3, a5b5,a7b7 */ + "vtrn.16 q2, q3\n" /* trans, zip n2,n3, q2: c0d0,c2d2, c4d4,c6d6, q3: + c1d1,c3d3, c5d5,c7d7 */ + "vtrn.16 q4, q5\n" /* trans, zip n4,n5, q4: e0f0,e2f2, e4f4,e6f6, q5: + e1f1,e3f3, e5f5,e7f7 */ + "vtrn.16 q6, q7\n" /* trans, zip n6,n7, q6: g0h0,g2h2, g4h4,g6h6, q7: + g1h1,g3h3, g5h5,g7h7 */ + /* trans, 4s */ + "vtrn.32 q0, q2\n" /* trans, q0: a0b0,c0d0, a4b4,c4d4, q2: a2b2,c2d2, + a6b6,c6d6 */ + "vtrn.32 q1, q3\n" /* trans, q1: a1b1,c1d1, a5b5,c5d5, q3: a3b3,c3d3, + a7b7,c7d7 */ + "vtrn.32 q4, q6\n" /* trans, q4: e0f0,g0h0, e4f4,g4h4, q6: e2f2,g2h2, + e6f6,g6h6 */ + "vtrn.32 q5, q7\n" /* trans, q5: e1f1,g1h1, e5f5,g5h5, q7: e3f3,g3h3, + e7f7,g7h7 */ + "subs %[k], %[k], #1\n" /* loop count -1 */ + /* trans, 2d */ + "vswp d1, d8\n" /* q0: a0b0,c0d0, e0f0,g0h0, q4: a4b4,c4d4, e4f4,g4h4 + */ + "vswp d3, d10\n" /* q1: a1b1,c1d1, e1f1,g1h1, q5: a5b5,c5d5, e5f5,g5h5 + */ + "vswp d5, d12\n" /* q2: a2b2,c2d2, e2f2,g2h2, q6: a6b6,c6d6, e6f6,g6h6 + */ + "vswp d7, d14\n" /* q3: a3b3,c3d3, e3f3,g3h3, q7: a7b7,c7d7, e7f7,g7h7 + */ + /* store result */ + "vst1.8 {d0-d3}, [%[ptr_out]]!\n" /* write 0 */ + "vst1.8 {d4-d7}, [%[ptr_out]]!\n" /* write 1 */ + "vst1.8 {d8-d11}, [%[ptr_out]]!\n" /* write 2 */ + "vst1.8 {d12-d15}, [%[ptr_out]]!\n" /* write 3 */ + "bgt 0b\n" /* jump to main loop */ + /* process remain */ + "1:\n" /* process remains */ + "cmp %[rem], #0\n" /* check remain */ + "beq 2f\n" /* no remain, jump to end */ + "vld1.8 {d0-d1}, [%[ptr0]]!\n" /* load n0, a0~a7 */ + "vld1.8 {d2-d3}, [%[ptr1]]!\n" /* load n1, b0~b7 */ + "vld1.8 {d4-d5}, [%[ptr2]]!\n" /* load n2, c0~c7 */ + "vld1.8 {d6-d7}, [%[ptr3]]!\n" /* load n3, d0~d7 */ + "vld1.8 {d8-d9}, [%[ptr4]]!\n" /* load n4, e0~e7 */ + "vld1.8 {d10-d11}, [%[ptr5]]!\n" /* load n5, f0~f7 */ + "vld1.8 {d12-d13}, [%[ptr6]]!\n" /* load n6, g0~g7 */ + "vld1.8 {d14-d15}, [%[ptr7]]!\n" /* load n7, h0~h7 */ + /* bit select */ + "vbif q0, %q[vzero], %q[mask]\n" /* pad 0 */ + "vbif q1, %q[vzero], %q[mask]\n" /* pad 0 */ + "vbif q2, %q[vzero], %q[mask]\n" /* pad 0 */ + "vbif q3, %q[vzero], %q[mask]\n" /* pad 0 */ + "vbif q4, %q[vzero], %q[mask]\n" /* pad 0 */ + "vbif q5, %q[vzero], %q[mask]\n" /* pad 0 */ + "vbif q6, %q[vzero], %q[mask]\n" /* pad 0 */ + "vbif q7, %q[vzero], %q[mask]\n" /* pad 0 */ + /* trans, 8h */ + "vtrn.16 q0, q1\n" /* trans, zip n0,n1, q0: a0b0,a2b2, a4b4,a6b6, q1: + 
a1b1,a3b3, a5b5,a7b7 */
+        "vtrn.16    q2, q3\n" /* trans, zip n2,n3, q2: c0d0,c2d2, c4d4,c6d6, q3:
+                                 c1d1,c3d3, c5d5,c7d7 */
+        "vtrn.16    q4, q5\n" /* trans, zip n4,n5, q4: e0f0,e2f2, e4f4,e6f6, q5:
+                                 e1f1,e3f3, e5f5,e7f7 */
+        "vtrn.16    q6, q7\n" /* trans, zip n6,n7, q6: g0h0,g2h2, g4h4,g6h6, q7:
+                                 g1h1,g3h3, g5h5,g7h7 */
+        /* trans, 4s */
+        "vtrn.32    q0, q2\n" /* trans, q0: a0b0,c0d0, a4b4,c4d4, q2: a2b2,c2d2,
+                                 a6b6,c6d6 */
+        "vtrn.32    q1, q3\n" /* trans, q1: a1b1,c1d1, a5b5,c5d5, q3: a3b3,c3d3,
+                                 a7b7,c7d7 */
+        "vtrn.32    q4, q6\n" /* trans, q4: e0f0,g0h0, e4f4,g4h4, q6: e2f2,g2h2,
+                                 e6f6,g6h6 */
+        "vtrn.32    q5, q7\n" /* trans, q5: e1f1,g1h1, e5f5,g5h5, q7: e3f3,g3h3,
+                                 e7f7,g7h7 */
+        /* trans, 2d */
+        "vswp   d1, d8\n"  /* q0: a0b0,c0d0, e0f0,g0h0, q4: a4b4,c4d4, e4f4,g4h4 */
+        "vswp   d3, d10\n" /* q1: a1b1,c1d1, e1f1,g1h1, q5: a5b5,c5d5, e5f5,g5h5 */
+        "vswp   d5, d12\n" /* q2: a2b2,c2d2, e2f2,g2h2, q6: a6b6,c6d6, e6f6,g6h6 */
+        "vswp   d7, d14\n" /* q3: a3b3,c3d3, e3f3,g3h3, q7: a7b7,c7d7, e7f7,g7h7 */
+        /* check remain size */
+        "subs   %[rem], %[rem], #1\n"       /* check remain num */
+        "vst1.8 {d0-d3},   [%[ptr_out]]!\n" /* write 0 */
+        "beq    2f\n"                       /* remain = 1 */
+        "subs   %[rem], %[rem], #1\n"       /* check remain num */
+        "vst1.8 {d4-d7},   [%[ptr_out]]!\n" /* write 1 */
+        "beq    2f\n"                       /* remain = 2 */
+        "subs   %[rem], %[rem], #1\n"       /* check remain num */
+        "vst1.8 {d8-d11},  [%[ptr_out]]!\n" /* write 2 */
+        "beq    2f\n"                       /* remain = 3 */
+        "vst1.8 {d12-d15}, [%[ptr_out]]!\n" /* write 3 */
+        /* end */
+        "2:\n"                              /* end */
+        : [ptr0] "+r"(ptr0), [ptr1] "+r"(ptr1), [ptr2] "+r"(ptr2),
+          [ptr3] "+r"(ptr3), [ptr4] "+r"(ptr4), [ptr5] "+r"(ptr5),
+          [ptr6] "+r"(ptr6), [ptr7] "+r"(ptr7), [ptr_out] "+r"(ptr_out),
+          [k] "+r"(k), [rem] "+r"(rem)
+        : [mask] "w"(vmask), [vzero] "w"(vzero)
+        : "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "cc");
+#endif  //__aarch64__ // NOLINT
+  }
+}
+
+#if defined(__aarch64__) && defined(WITH_ARM_DOTPROD)
+
+template <typename Dtype>
+void gemm_prepack_sdot_int8(const int8_t* A_packed,
+                            const int8_t* B,
+                            const int* bias,
+                            Dtype* C,
+                            int M,
+                            int N,
+                            int K,
+                            bool is_bias,
+                            bool is_relu,
+                            bool is_transB,
+                            const float* scale,
+                            ARMContext* ctx) {
+  size_t llc_size = ctx->llc_size() / 4;
+  auto workspace = ctx->workspace_data<int8_t>();
+  //! MBLOCK_INT8_DOT * x (result) + MBLOCK_INT8_DOT * k (A) + x * k (B) = l2
+  int x_block = (llc_size - (MBLOCK_INT8_DOT * K)) /
+                (sizeof(int8_t) * (K + MBLOCK_INT8_DOT));
+  x_block /= NBLOCK_INT8_DOT;
+  x_block *= NBLOCK_INT8_DOT;
+  int x_num = (N + (x_block - 1)) / x_block;
+  x_block = (N + x_num - 1) / x_num;
+  x_block = (x_block + NBLOCK_INT8_DOT - 1) / NBLOCK_INT8_DOT;
+  x_block *= NBLOCK_INT8_DOT;
+  x_block = x_block < NBLOCK_INT8_DOT ? NBLOCK_INT8_DOT : x_block;
+
+  int kup = ROUNDUP(K, KBLOCK_INT8);
+  // unroll 2 loop
+  int tail_pre = ((kup / 4) & (KBLOCK_INT8 - 1));
+  int k_pre = (((kup / 4) + KBLOCK_INT8 - 1) / KBLOCK_INT8) - 1;
+
+  bool flag_p_remain = false;
+  int remain = 0;
+
+  //! the A panel is pre-computed outside this gemm
+  for (unsigned int x0 = 0; x0 < N; x0 += x_block) {
+    unsigned int xmax = x0 + x_block;
+    if (xmax > N) {
+      xmax = N;
+    }
+    int bblocks = (xmax - x0 + NBLOCK_INT8_DOT - 1) / NBLOCK_INT8_DOT;
+    remain = xmax - x0 - (bblocks - 1) * NBLOCK_INT8_DOT;
+    if (remain > 0) {
+      flag_p_remain = true;
+    }
+    //! load b panel
+    auto b_pannel = static_cast<int8_t*>(workspace);
+    if (!is_transB) {
+      // K * N
+      packb_sdot_int8(b_pannel, B, N, 0, K, x0, xmax);
+    } else {
+      // N X K
+      packb_sdot_trans_int8(b_pannel, B, K, 0, K, x0, xmax);
+    }
+#pragma omp parallel for
+    for (unsigned int y = 0; y < M; y += MBLOCK_INT8_DOT) {
+      unsigned int ymax = y + MBLOCK_INT8_DOT;
+      if (ymax > M) {
+        ymax = M;
+      }
+
+      int32_t bias_local[8] = {0, 0, 0, 0, 0, 0, 0, 0};
+      if (is_bias) {
+        bias_local[0] = bias[y];
+        bias_local[1] = bias[y + 1];
+        bias_local[2] = bias[y + 2];
+        bias_local[3] = bias[y + 3];
+        bias_local[4] = bias[y + 4];
+        bias_local[5] = bias[y + 5];
+        bias_local[6] = bias[y + 6];
+        bias_local[7] = bias[y + 7];
+      }
+      float32_t scale_local[8];
+      if (scale) {
+        scale_local[0] = scale[y];
+        scale_local[1] = scale[y + 1];
+        scale_local[2] = scale[y + 2];
+        scale_local[3] = scale[y + 3];
+        scale_local[4] = scale[y + 4];
+        scale_local[5] = scale[y + 5];
+        scale_local[6] = scale[y + 6];
+        scale_local[7] = scale[y + 7];
+      }
+
+      Dtype cout0[NBLOCK_INT8_DOT];
+      Dtype cout1[NBLOCK_INT8_DOT];
+      Dtype cout2[NBLOCK_INT8_DOT];
+      Dtype cout3[NBLOCK_INT8_DOT];
+      Dtype cout4[NBLOCK_INT8_DOT];
+      Dtype cout5[NBLOCK_INT8_DOT];
+      Dtype cout6[NBLOCK_INT8_DOT];
+      Dtype cout7[NBLOCK_INT8_DOT];
+
+      Dtype* c_ptr0 = C + y * N + x0;
+      Dtype* c_ptr1 = c_ptr0 + N;
+      Dtype* c_ptr2 = c_ptr1 + N;
+      Dtype* c_ptr3 = c_ptr2 + N;
+      Dtype* c_ptr4 = c_ptr3 + N;
+      Dtype* c_ptr5 = c_ptr4 + N;
+      Dtype* c_ptr6 = c_ptr5 + N;
+      Dtype* c_ptr7 = c_ptr6 + N;
+
+      Dtype* pout0 = c_ptr0;
+      Dtype* pout1 = c_ptr1;
+      Dtype* pout2 = c_ptr2;
+      Dtype* pout3 = c_ptr3;
+      Dtype* pout4 = c_ptr4;
+      Dtype* pout5 = c_ptr5;
+      Dtype* pout6 = c_ptr6;
+      Dtype* pout7 = c_ptr7;
+
+      // const int8_t *a_ptr_l = A_packed + y * K;
+      const int8_t* a_ptr_l = A_packed + y * kup;
+      const int8_t* b_ptr = b_pannel;
+      for (int xb = 0; xb < bblocks; xb++) {
+        if ((y + 7) >= ymax) {
+          switch ((y + 7) - ymax) {
+            case 6:
+              c_ptr1 = cout1;
+            case 5:
+              c_ptr2 = cout2;
+            case 4:
+              c_ptr3 = cout3;
+            case 3:
+              c_ptr4 = cout4;
+            case 2:
+              c_ptr5 = cout5;
+            case 1:
+              c_ptr6 = cout6;
+            case 0:
+              c_ptr7 = cout7;
+            default:
+              break;
+          }
+        }
+        if (flag_p_remain && (xb == bblocks - 1)) {
+          pout0 = c_ptr0;
+          pout1 = c_ptr1;
+          pout2 = c_ptr2;
+          pout3 = c_ptr3;
+          pout4 = c_ptr4;
+          pout5 = c_ptr5;
+          pout6 = c_ptr6;
+          pout7 = c_ptr7;
+
+          c_ptr0 = cout0;
+          c_ptr1 = cout1;
+          c_ptr2 = cout2;
+          c_ptr3 = cout3;
+          c_ptr4 = cout4;
+          c_ptr5 = cout5;
+          c_ptr6 = cout6;
+          c_ptr7 = cout7;
+        }
+        const int8_t* a_ptr = a_ptr_l;
+        int tail = tail_pre;
+        int k = k_pre;
+        sgemm_sdot_int8_kernel(a_ptr, b_ptr, bias_local, c_ptr0, c_ptr1,
+                               c_ptr2, c_ptr3, c_ptr4, c_ptr5, c_ptr6,
+                               c_ptr7, scale_local, is_relu, k, tail);
+        if (flag_p_remain && (xb == bblocks - 1)) {
+          for (int i = 0; i < remain; ++i) {
+            *pout0++ = cout0[i];
+            *pout1++ = cout1[i];
+            *pout2++ = cout2[i];
+            *pout3++ = cout3[i];
+            *pout4++ = cout4[i];
+            *pout5++ = cout5[i];
+            *pout6++ = cout6[i];
+            *pout7++ = cout7[i];
+          }
+        }
+      }
+    }
+  }
+}
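+
+/**************************************************************************/
+// Reference semantics (editorial sketch only, not compiled or called):
+// a plain scalar equivalent of what the packed sdot kernel above computes.
+// It assumes row-major A (M x K) and B (K x N) instead of the packed
+// panels, and assumes the write-back order used by the NEON kernel
+// (bias, then relu, then the per-row scale for float/int8 outputs;
+// int32 outputs keep the raw accumulator):
+//   for (int m = 0; m < M; ++m) {
+//     for (int n = 0; n < N; ++n) {
+//       int32_t acc = is_bias ? bias[m] : 0;
+//       for (int k = 0; k < K; ++k) {
+//         acc += static_cast<int32_t>(A[m * K + k]) * B[k * N + n];
+//       }
+//       if (is_relu && acc < 0) acc = 0;
+//       // int32 out: C[m * N + n] = acc;
+//       // float/int8 out: C[m * N + n] = Dtype(acc * scale[m]);
+//     }
+//   }
+/**************************************************************************/
+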
+void prepackA_m8k4_int8(int8_t* out,
+                        const int8_t* in,
+                        const int ldin,
+                        const int m0,
+                        const int mmax,
+                        const int k0,
+                        const int kmax) {
+  int x_len = (kmax - k0);
+  int8_t zerobuff[x_len];  // NOLINT
+  memset(zerobuff, 0, sizeof(int8_t) * x_len);
+
+  int8_t* dout = out;
+  const int8_t* inptr = in;
+  int kup = ROUNDUP(x_len, KBLOCK_INT8);
+  int stride = kup * 8;
+  int remain = x_len % 4;
+#pragma omp parallel for
+  for (int y = m0; y < mmax; y += 8) {
+    int8_t* outptr = dout + stride * (y - m0) / 8;
+    const int8_t* inptr_row[8];
+    inptr_row[0] = inptr + y * ldin + k0;
+    for (int i = 1; i < 8; i++) {
+      inptr_row[i] = inptr_row[i - 1] + ldin;
+    }
+    //! cope with row indices exceeding the real size: point them at the
+    //! zero buffer
+    if ((y + 7) >= mmax) {
+      switch ((y + 7) - mmax) {
+        case 6:
+          inptr_row[1] = zerobuff;
+        case 5:
+          inptr_row[2] = zerobuff;
+        case 4:
+          inptr_row[3] = zerobuff;
+        case 3:
+          inptr_row[4] = zerobuff;
+        case 2:
+          inptr_row[5] = zerobuff;
+        case 1:
+          inptr_row[6] = zerobuff;
+        case 0:
+          inptr_row[7] = zerobuff;
+        default:
+          break;
+      }
+    }
+    asm volatile(
+        "prfm   pldl1keep, [%[ptr0]]        \n"
+        "prfm   pldl1keep, [%[ptr0], #64]   \n"
+        "prfm   pldl1keep, [%[ptr1]]        \n"
+        "prfm   pldl1keep, [%[ptr1], #64]   \n"
+        "prfm   pldl1keep, [%[ptr2]]        \n"
+        "prfm   pldl1keep, [%[ptr2], #64]   \n"
+        "prfm   pldl1keep, [%[ptr3]]        \n"
+        "prfm   pldl1keep, [%[ptr3], #64]   \n"
+        "prfm   pldl1keep, [%[ptr4]]        \n"
+        "prfm   pldl1keep, [%[ptr4], #64]   \n"
+        "prfm   pldl1keep, [%[ptr5]]        \n"
+        "prfm   pldl1keep, [%[ptr5], #64]   \n"
+        "prfm   pldl1keep, [%[ptr6]]        \n"
+        "prfm   pldl1keep, [%[ptr6], #64]   \n"
+        "prfm   pldl1keep, [%[ptr7]]        \n"
+        "prfm   pldl1keep, [%[ptr7], #64]   \n"
+        :
+        : [ptr0] "r"(inptr_row[0]), [ptr1] "r"(inptr_row[1]),
+          [ptr2] "r"(inptr_row[2]), [ptr3] "r"(inptr_row[3]),
+          [ptr4] "r"(inptr_row[4]), [ptr5] "r"(inptr_row[5]),
+          [ptr6] "r"(inptr_row[6]), [ptr7] "r"(inptr_row[7])
+        : "memory");
+
+    int x = x_len;
+
+    for (; x > 7; x -= 8) {
+      asm volatile(
+          "ld1 {v0.8b}, [%[inptr0]], #8 \n"  // v0 = a0a1a2a3a4a5a6a7
+          "ld1 {v1.8b}, [%[inptr1]], #8 \n"  // v1 = b0b1b2b3b4b5b6b7
+          "ld1 {v2.8b}, [%[inptr2]], #8 \n"  // v2 = c0c1c2c3c4c5c6c7
+          "ld1 {v3.8b}, [%[inptr3]], #8 \n"  // v3 = d0d1d2d3d4d5d6d7
+
+          "ld1 {v4.8b}, [%[inptr4]], #8 \n"  // v4 = e0e1e2e3e4e5e6e7
+          "ld1 {v5.8b}, [%[inptr5]], #8 \n"  // v5 = f0f1f2f3f4f5f6f7
+          "ld1 {v6.8b}, [%[inptr6]], #8 \n"  // v6 = g0g1g2g3g4g5g6g7
+          "ld1 {v7.8b}, [%[inptr7]], #8 \n"  // v7 = h0h1h2h3h4h5h6h7
+
+          "trn1 v8.2s,  v0.2s, v1.2s \n"  // v8  = a0a1a2a3b0b1b2b3
+          "trn2 v9.2s,  v0.2s, v1.2s \n"  // v9  = a4a5a6a7b4b5b6b7
+          "trn1 v10.2s, v2.2s, v3.2s \n"  // v10 = c0c1c2c3d0d1d2d3
+          "trn2 v11.2s, v2.2s, v3.2s \n"  // v11 = c4c5c6c7d4d5d6d7
+
+          "trn1 v12.2s, v4.2s, v5.2s \n"  // v12 = e0e1e2e3f0f1f2f3
+          "trn2 v13.2s, v4.2s, v5.2s \n"  // v13 = e4e5e6e7f4f5f6f7
+          "trn1 v14.2s, v6.2s, v7.2s \n"  // v14 = g0g1g2g3h0h1h2h3
+          "trn2 v15.2s, v6.2s, v7.2s \n"  // v15 = g4g5g6g7h4h5h6h7
+
+          "st1 {v8.2s},  [%[outptr]], #8\n"
+          "st1 {v10.2s}, [%[outptr]], #8\n"
+          "st1 {v12.2s}, [%[outptr]], #8\n"
+          "st1 {v14.2s}, [%[outptr]], #8\n"
+
+          "st1 {v9.2s},  [%[outptr]], #8\n"
+          "st1 {v11.2s}, [%[outptr]], #8\n"
+          "st1 {v13.2s}, [%[outptr]], #8\n"
+          "st1 {v15.2s}, [%[outptr]], #8\n"
+
+          : [inptr0] "+r"(inptr_row[0]), [inptr1] "+r"(inptr_row[1]),
+            [inptr2] "+r"(inptr_row[2]), [inptr3] "+r"(inptr_row[3]),
+            [inptr4] "+r"(inptr_row[4]), [inptr5] "+r"(inptr_row[5]),
+            [inptr6] "+r"(inptr_row[6]), [inptr7] "+r"(inptr_row[7]),
+            [outptr] "+r"(outptr)
+          :
+          : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9",
+            "v10", "v11", "v12", "v13", "v14", "v15", "v16", "cc", "memory");
+    }
+    if (x >= 4) {
+      asm volatile(
+          "mov x1, #4 \n"
+          "ld1 {v0.8b}, [%[inptr0]], x1 \n"  // v0 = a0a1a2a3a4a5a6a7
+          "ld1 {v1.8b}, [%[inptr1]], x1 \n"  // v1 = b0b1b2b3b4b5b6b7
+          "ld1 {v2.8b}, [%[inptr2]], x1 \n"  // v2 = c0c1c2c3c4c5c6c7
+          "ld1 {v3.8b}, [%[inptr3]], x1 \n"  // v3 = d0d1d2d3d4d5d6d7
+
+          "ld1 {v4.8b}, [%[inptr4]], x1 \n"  // v4 = e0e1e2e3e4e5e6e7
+          "ld1 {v5.8b}, [%[inptr5]], x1 \n"  // v5 = f0f1f2f3f4f5f6f7
+          "ld1 {v6.8b}, [%[inptr6]], x1 \n"  // v6 = g0g1g2g3g4g5g6g7
+          "ld1 {v7.8b}, [%[inptr7]], x1 \n"  // v7 = h0h1h2h3h4h5h6h7
+
+          "trn1 v8.2s,  v0.2s, v1.2s \n"  // v8  = a0a1a2a3b0b1b2b3
+          "trn1 v10.2s, v2.2s, v3.2s \n"  // v10 = c0c1c2c3d0d1d2d3
+
+          "trn1 v12.2s, v4.2s, v5.2s \n"  // v12 = e0e1e2e3f0f1f2f3
+          "trn1 v14.2s, v6.2s, v7.2s \n"  // v14 = g0g1g2g3h0h1h2h3
+
+          "st1 {v8.2s},  [%[outptr]], #8\n"
+          "st1 {v10.2s}, [%[outptr]], #8\n"
+
+          "st1 {v12.2s}, [%[outptr]], #8\n"
+          "st1 {v14.2s}, [%[outptr]], #8\n"
+
+          : [inptr0] "+r"(inptr_row[0]), [inptr1] "+r"(inptr_row[1]),
+            [inptr2] "+r"(inptr_row[2]), [inptr3] "+r"(inptr_row[3]),
+            [inptr4] "+r"(inptr_row[4]), [inptr5] "+r"(inptr_row[5]),
+            [inptr6] "+r"(inptr_row[6]), [inptr7] "+r"(inptr_row[7]),
+            [outptr] "+r"(outptr)
+          :
+          : "x1", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9",
+            "v10", "v11", "v12", "v13", "v14", "v15", "v16", "cc", "memory");
+      x -= 4;
+    }
+    if (x > 0) {
+      for (int i = 0; i < 8; i++) {
+        for (int j = x; j > 0; j--) {
+          *outptr++ = *inptr_row[i]++;
+        }
+        for (int j = 0; j < 4 - remain; j++) {
+          *outptr++ = 0;
+        }
+      }
+    }
+  }
+}
+
+void prepackA_m8k4_trans_int8(int8_t* out,
+                              const int8_t* in,
+                              const int ldin,
+                              const int m0,
+                              const int mmax,
+                              const int k0,
+                              const int kmax) {
+  int8_t* outptr = out;
+  const int8_t* inptr = in + k0 * ldin + m0;
+  int x_len = mmax - m0;
+  int y_len = kmax - k0;
+  int right_remain = x_len % 8;
+  int kup = ROUNDUP(y_len, KBLOCK_INT8);
+
+  int stride_out = 8 * kup;
+  int8_t zerobuff[x_len];  // NOLINT
+  memset(zerobuff, 0, sizeof(int8_t) * x_len);
+
+#pragma omp parallel for
+  for (int y = 0; y < y_len; y += 4) {
+    const int8_t* inptr0 = inptr + y * ldin;
+    const int8_t* inptr1 = inptr0 + ldin;
+    const int8_t* inptr2 = inptr1 + ldin;
+    const int8_t* inptr3 = inptr2 + ldin;
+
+    if (y + 4 > y_len) {
+      switch (y + 4 - y_len) {
+        case 3:
+          inptr1 = zerobuff;
+        case 2:
+          inptr2 = zerobuff;
+        case 1:
+          inptr3 = zerobuff;
+        default:
+          break;
+      }
+    }
+    asm volatile(
+        "prfm   pldl1keep, [%[ptr0]]        \n"
+        "prfm   pldl1keep, [%[ptr0], #64]   \n"
+        "prfm   pldl1keep, [%[ptr1]]        \n"
+        "prfm   pldl1keep, [%[ptr1], #64]   \n"
+        "prfm   pldl1keep, [%[ptr2]]        \n"
+        "prfm   pldl1keep, [%[ptr2], #64]   \n"
+        "prfm   pldl1keep, [%[ptr3]]        \n"
+        "prfm   pldl1keep, [%[ptr3], #64]   \n"
+        :
+        : [ptr0] "r"(inptr0), [ptr1] "r"(inptr1), [ptr2] "r"(inptr2),
+          [ptr3] "r"(inptr3)
+        : "memory");
+
+    int8_t* outptr_row = outptr + y * 8;
+    int x = 0;
+    for (; x < x_len - 7; x += 8) {
+      int8_t* out0 = outptr_row;
+      asm volatile(
+          "ld1 {v0.8b}, [%[inptr0]], #8 \n"  // v0 = a0a1a2a3a4a5a6a7
+          "ld1 {v1.8b}, [%[inptr1]], #8 \n"  // v1 = b0b1b2b3b4b5b6b7
+          "ld1 {v2.8b}, [%[inptr2]], #8 \n"  // v2 = c0c1c2c3c4c5c6c7
+          "ld1 {v3.8b}, [%[inptr3]], #8 \n"  // v3 = d0d1d2d3d4d5d6d7
+
+          "trn1 v4.8b, v0.8b, v1.8b \n"  // v4 = a0b0a2b2a4b4a6b6
+          "trn2 v5.8b, v0.8b, v1.8b \n"  // v5 = a1b1a3b3a5b5a7b7
+          "trn1 v6.8b, v2.8b, v3.8b \n"  // v6 = c0d0c2d2c4d4c6d6
+          "trn2 v7.8b, v2.8b, v3.8b \n"  // v7 = c1d1c3d3c5d5c7d7
+
+          "trn1 v0.4h, v4.4h, v6.4h \n"  // v0 = a0b0c0d0a4b4c4d4
+          "trn2 v1.4h, v4.4h, v6.4h \n"  // v1 = a2b2c2d2a6b6c6d6
+          "trn1 v2.4h, v5.4h, v7.4h \n"  // v2 = a1b1c1d1a5b5c5d5
+          "trn2 v3.4h, v5.4h, v7.4h \n"  // v3 = a3b3c3d3a7b7c7d7
+
+          "trn1 v4.2s, v0.2s, v2.2s \n"  // v4 = a0b0c0d0a1b1c1d1
+          "trn2 v5.2s, v0.2s, v2.2s \n"  // v5 = a4b4c4d4a5b5c5d5
+          "trn1 v6.2s, v1.2s, v3.2s \n"  // v6 = a2b2c2d2a3b3c3d3
+          "trn2 v7.2s, v1.2s, v3.2s \n"  // v7 = a6b6c6d6a7b7c7d7
+
+          "st1 {v4.2s}, [%[outr]], #8\n"
+          "st1 {v6.2s}, [%[outr]], #8\n"
+          "st1 {v5.2s}, [%[outr]], #8\n"
+          "st1 {v7.2s}, [%[outr]], #8\n"
+          : [inptr0] "+r"(inptr0), [inptr1]
"+r"(inptr1), + [inptr2] "+r"(inptr2), [inptr3] "+r"(inptr3), + [outr] "+r"(out0) + : + : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", + "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", + "cc", "memory" + ); + outptr_row += stride_out; + } + if (right_remain > 0) { + int8_t *out0 = outptr_row; + for (; x < x_len; x++) { + *out0++ = *inptr0++; + *out0++ = *inptr1++; + *out0++ = *inptr2++; + *out0++ = *inptr3++; + } + for (int i = 0; i < 8 - right_remain; i++) { + *out0++ = 0; + *out0++ = 0; + *out0++ = 0; + *out0++ = 0; + } + } + } +} + +void packb_sdot_int8(int8_t* out, + const int8_t* in, + const int ldin, + const int k0, + const int kmax, + const int n0, + const int nmax) { + int y_len = kmax - k0; + int x_len = nmax - n0; + int kup = ROUNDUP(y_len, KBLOCK_INT8); // 4k + int8_t zerobuff[x_len]; //NOLINT + memset(zerobuff, 0, sizeof(int8_t) * x_len); + int8_t *outptr = out; + const int8_t *inptr = in + k0 * ldin + n0; + + int stride_out = 12 * kup; + // int stride_y = 48; + int remain = x_len % 12; + + // data B is not transposed, transpose B to k * 12 +#pragma omp parallel for + for (int y = 0; y < y_len; y += 4) { + // cope with row index exceed real size, set to zero + const int8_t *inptr0 = inptr + y * ldin; + const int8_t *inptr1 = inptr0 + ldin; + const int8_t *inptr2 = inptr1 + ldin; + const int8_t *inptr3 = inptr2 + ldin; + if (y + 4 > y_len) { + switch (y + 4 - y_len) { + case 3: + inptr1 = zerobuff; + case 2: + inptr2 = zerobuff; + case 1: + inptr3 = zerobuff; + default: + break; + } + } + asm volatile( + "prfm pldl1keep, [%[inptr0]] \n" + "prfm pldl1keep, [%[inptr0], #64] \n" + "prfm pldl1keep, [%[inptr1]] \n" + "prfm pldl1keep, [%[inptr1], #64] \n" + "prfm pldl1keep, [%[inptr2]] \n" + "prfm pldl1keep, [%[inptr2], #64] \n" + "prfm pldl1keep, [%[inptr3]] \n" + "prfm pldl1keep, [%[inptr3], #64] \n" + : + :[inptr0] "r"(inptr0), [inptr1] "r"(inptr1), + [inptr2] "r"(inptr2), [inptr3] "r"(inptr3) + :"memory" + ); + int8_t* outptr_row = outptr + y * 12; + int x = 0; + for (; x < x_len - 11; x += 12) { + int8_t *out0 = outptr_row; + asm volatile ( + "mov x1, #4 \n" + "ld1 {v0.8b}, [%[inptr0]], #8 \n" // v0 = a0a1a2a3a4a5a6a7 + "ld1 {v1.8b}, [%[inptr1]], #8 \n" // v0 = b0b1b2b3b4b5b6b7 + "ld1 {v2.8b}, [%[inptr2]], #8 \n" // v0 = c0c1c2c3c4c5c6c7 + "ld1 {v3.8b}, [%[inptr3]], #8 \n" // v0 = d0d1d2d3d4d5d6d7 + + "ld1 {v8.8b}, [%[inptr0]] \n" // v0 = a8a9a10a11 + "ld1 {v9.8b}, [%[inptr1]] \n" // v0 = b8b9b10b11 + "ld1 {v10.8b}, [%[inptr2]] \n" // v0 = c8c9c10c11 + "ld1 {v11.8b}, [%[inptr3]] \n" // v0 = d8d9d10d11 + + "trn1 v4.8b, v0.8b, v1.8b \n" // v4 = a0b0a2b2a4b4a6b6 + "trn2 v5.8b, v0.8b, v1.8b \n" // v4 = a1b1a3b3a5b5a7b7 + "trn1 v6.8b, v2.8b, v3.8b \n" // v4 = c0d0c2d2a4b4a6b6 + "trn2 v7.8b, v2.8b, v3.8b \n" // v4 = c1d1c3d3a5b5a7b7 + + "trn1 v12.8b, v8.8b, v9.8b \n" // v4 = a8b8a10b10a4b4a6b6 + "trn2 v13.8b, v8.8b, v9.8b \n" // v4 = a9b9a11b11a5b5a7b7 + "trn1 v14.8b, v10.8b, v11.8b \n" // v4 = c8d8c10d10a4b4a6b6 + "trn2 v15.8b, v10.8b, v11.8b \n" // v4 = c9d9c11d11a5b5a7b7 + + "trn1 v0.4h, v4.4h, v6.4h \n" // v4 = a0b0c0d0a4b4c4d4 + "trn2 v1.4h, v4.4h, v6.4h \n" // v4 = a2b2c2d2a6b6c6d6 + "trn1 v2.4h, v5.4h, v7.4h \n" // v4 = a1b1c1d1a5b5c5d5 + "trn2 v3.4h, v5.4h, v7.4h \n" // v4 = a3b3c3d3a7b7c7d7 + + "trn1 v8.4h, v12.4h, v14.4h \n" // v4 = a8b8c8d8 + "trn2 v9.4h, v12.4h, v14.4h \n" // v4 = a10b10c10d10 + "trn1 v10.4h, v13.4h, v15.4h \n" // v4 = a9b9c9d9 + "trn2 v11.4h, v13.4h, v15.4h \n" // v4 = a11b11c11d11 + + "trn1 v4.2s, v0.2s, v2.2s \n" //v4 =a0b0c0d0a1b1c1d1 + 
"trn2 v5.2s, v0.2s, v2.2s \n" //v4 =a4b4c4d4a5b5c5d5 + "trn1 v6.2s, v1.2s, v3.2s \n" //v4 =a2b2c2d2a3b3c3d3 + "trn2 v7.2s, v1.2s, v3.2s \n" //v4 =a6b6c6d6a7b7c7d7 + + "trn1 v0.2s, v8.2s, v10.2s \n" //v4 =a8b8c8d8a9b9c9d9 + "trn1 v1.2s, v9.2s, v11.2s \n" //v4 =a10b10c10d10a11b11c11d11 + + "st1 {v4.2s}, [%[outr]], #8\n" + "st1 {v6.2s}, [%[outr]], #8\n" + "add %[inptr0], %[inptr0], #4\n" + "add %[inptr1], %[inptr1], #4\n" + "st1 {v5.2s}, [%[outr]], #8\n" + "st1 {v7.2s}, [%[outr]], #8\n" + "add %[inptr2], %[inptr2], #4\n" + "add %[inptr3], %[inptr3], #4\n" + "st1 {v0.2s}, [%[outr]], #8\n" + "st1 {v1.2s}, [%[outr]], #8\n" + : [inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1), + [inptr2] "+r"(inptr2), [inptr3] "+r"(inptr3), + [outr] "+r"(out0) + : + : "x1", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", + "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", + "v16", "cc", "memory" + ); + outptr_row += stride_out; + } + int8_t* out0 = outptr_row; // outptr + stride_out + y * remain; + for (; x < x_len; x++) { + *out0++ = *inptr0++; + *out0++ = *inptr1++; + *out0++ = *inptr2++; + *out0++ = *inptr3++; + } + for (int i = 0; i < 12 - remain; i++) { + *out0++ = 0; + *out0++ = 0; + *out0++ = 0; + *out0++ = 0; + } + } +} + +void packb_sdot_trans_int8(int8_t* out, + const int8_t* in, + const int ldin, + const int k0, + const int kmax, + const int n0, + const int nmax) { + int8_t *outptr = out; + const int8_t *inptr = in + n0 * ldin + k0; + int y_len = nmax - n0; + int x_len = kmax - k0; + + int kup = ROUNDUP(x_len, KBLOCK_INT8); // 4 + + int8_t zerobuff[kup]; //NOLINT + memset(zerobuff, 0, sizeof(int8_t) * kup); + + int stride_y = 48; + int stride_out = kup; + + int remain = x_len % 8; + +#pragma omp parallel for + for (int y = 0; y < y_len; y += 12) { + const int8_t *inptr_row[12]; + inptr_row[0] = inptr + y * ldin; + for (int i = 1; i < 12; i++) { + inptr_row[i] = inptr_row[i - 1] + ldin; + } + if (y + 12 > y_len) { + for (int i = y + 12 - y_len; i > 0; i--) { + // inptr_row[12 - i] = zero_ptr[12 - i - 1]; + inptr_row[12 - i] = zerobuff; + } + } + asm volatile( + "prfm pldl1keep, [%[ptr0]] \n" + "prfm pldl1keep, [%[ptr1]] \n" + "prfm pldl1keep, [%[ptr2]] \n" + "prfm pldl1keep, [%[ptr3]] \n" + "prfm pldl1keep, [%[ptr4]] \n" + "prfm pldl1keep, [%[ptr5]] \n" + "prfm pldl1keep, [%[ptr6]] \n" + "prfm pldl1keep, [%[ptr7]] \n" + "prfm pldl1keep, [%[ptr8]] \n" + "prfm pldl1keep, [%[ptr9]] \n" + "prfm pldl1keep, [%[ptr10]] \n" + "prfm pldl1keep, [%[ptr11]] \n" + : + :[ptr0] "r"(inptr_row[0]), [ptr1] "r"(inptr_row[1]), + [ptr2] "r"(inptr_row[2]), [ptr3] "r"(inptr_row[3]), + [ptr4] "r"(inptr_row[4]), [ptr5] "r"(inptr_row[5]), + [ptr6] "r"(inptr_row[6]), [ptr7] "r"(inptr_row[7]), + [ptr8] "r"(inptr_row[8]), [ptr9] "r"(inptr_row[9]), + [ptr10] "r"(inptr_row[10]), [ptr11] "r"(inptr_row[11]) + :"memory" + ); + int right_remain = remain; + int8_t *outptr_row = outptr + y * stride_out; + for (int x = 0; x < x_len - 7; x += 8) { + int8_t *out0 = outptr_row; + int8_t *out1 = out0 + stride_y; + asm volatile( + "ld1 {v0.8b}, [%[inptr0]], #8 \n" // q0=A0A1A2A3A4A5A6A7 + "ld1 {v1.8b}, [%[inptr1]], #8 \n" // q0=B0b1b2b3A4A5A6A7 + "ld1 {v2.8b}, [%[inptr2]], #8 \n" // q0=c0c1c2c3A4A5A6A7 + "ld1 {v3.8b}, [%[inptr3]], #8 \n" // q0=d0d1d2d3A4A5A6A7 + + "ld1 {v4.8b}, [%[inptr4]], #8 \n" // q0=A0A1A2A3A4A5A6A7 + "ld1 {v5.8b}, [%[inptr5]], #8 \n" // q0=B0b1b2b3A4A5A6A7 + "ld1 {v6.8b}, [%[inptr6]], #8 \n" // q0=c0c1c2c3A4A5A6A7 + "ld1 {v7.8b}, [%[inptr7]], #8 \n" // q0=d0d1d2d3A4A5A6A7 + + "trn1 v8.2s, v0.2s, v1.2s \n" 
//v0=a0a1a2a3'b0b1b2b3 -00 01 + "trn2 v12.2s, v0.2s, v1.2s \n" //v0=a4a5a6a7'b4b5b6b7 - 10 11 + "trn1 v9.2s, v2.2s, v3.2s \n" //v0=c0c1a2a3'd0b1b2b3 -02 03 + "trn2 v13.2s, v2.2s, v3.2s \n" //v0=c4a5a6a7'c4b5b6b7 - 12 13 + + "ld1 {v0.8b}, [%[inptr8]], #8 \n" // q0=A0A1A2A3A4A5A6A7 + "ld1 {v1.8b}, [%[inptr9]], #8 \n" // q0=B0b1b2b3A4A5A6A7 + "ld1 {v2.8b}, [%[inptr10]], #8 \n" // q0=c0c1c2c3A4A5A6A7 + "ld1 {v3.8b}, [%[inptr11]], #8 \n" // q0=d0d1d2d3A4A5A6A7 + + "st1 {v8.8b}, [%[outptr_row0]], #8 \n" + "st1 {v12.8b}, [%[outptr_row1]], #8 \n" + "st1 {v9.8b}, [%[outptr_row0]], #8 \n" + "st1 {v13.8b}, [%[outptr_row1]], #8 \n" + + "trn1 v10.2s, v4.2s, v5.2s \n" //v0=a0b0a0b0'a4b4a4b4 -04 05 + "trn2 v14.2s, v4.2s, v5.2s \n" //v0=a2b2a2b2'a6b6a6b6 -14 15 + "trn1 v11.2s, v6.2s, v7.2s \n" //v0=a0b0a0b0'a4b4a4b4 -06 07 + "trn2 v15.2s, v6.2s, v7.2s \n" //v0=a2b2a2b2'a6b6a6b6 -16 17 + + "trn1 v4.2s, v0.2s, v1.2s \n" //v0=a0b0a0b0'a4b4a4b4 -08 09 + "trn2 v5.2s, v0.2s, v1.2s \n" //v0=a2b2a2b2'a6b6a6b6 -18 19 + "trn1 v6.2s, v2.2s, v3.2s \n" //v0=a0b0a0b0'a4b4a4b4 -010 011 + "trn2 v7.2s, v2.2s, v3.2s \n" //v0=a2b2a2b2'a6b6a6b6 -110 111 + + "st1 {v10.8b}, [%[outptr_row0]], #8 \n" + "st1 {v14.8b}, [%[outptr_row1]], #8 \n" + "st1 {v11.8b}, [%[outptr_row0]], #8 \n" + "st1 {v15.8b}, [%[outptr_row1]], #8 \n" + + "st1 {v4.8b}, [%[outptr_row0]], #8 \n" + "st1 {v5.8b}, [%[outptr_row1]], #8 \n" + "st1 {v6.8b}, [%[outptr_row0]], #8 \n" + "st1 {v7.8b}, [%[outptr_row1]], #8 \n" + : [inptr0] "+r"(inptr_row[0]), [inptr1] "+r"(inptr_row[1]), + [inptr2] "+r"(inptr_row[2]), [inptr3] "+r"(inptr_row[3]), + [inptr4] "+r"(inptr_row[4]), [inptr5] "+r"(inptr_row[5]), + [inptr6] "+r"(inptr_row[6]), [inptr7] "+r"(inptr_row[7]), + [inptr8] "+r"(inptr_row[8]), [inptr9] "+r"(inptr_row[9]), + [inptr10] "+r"(inptr_row[10]), [inptr11] "+r"(inptr_row[11]), + [outptr_row0] "+r"(out0), [outptr_row1] "+r"(out1) + : + : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", + "v10", "v11", "v12", "v13", "v14", "v15", "v16", "cc", "memory" + ); + outptr_row += 96; + } + int8_t *out0 = outptr_row; + if (right_remain >= 4) { + asm volatile( + "mov x1, #4 \n" + "ld1 {v0.8b}, [%[inptr0]], x1 \n" // q0=A0A1A2A3A4A5A6A7 + "ld1 {v1.8b}, [%[inptr1]], x1 \n" // q0=B0b1b2b3A4A5A6A7 + "ld1 {v2.8b}, [%[inptr2]], x1 \n" // q0=c0c1c2c3A4A5A6A7 + "ld1 {v3.8b}, [%[inptr3]], x1 \n" // q0=d0d1d2d3A4A5A6A7 + + "ld1 {v4.8b}, [%[inptr4]], x1 \n" // q0=A0A1A2A3A4A5A6A7 + "ld1 {v5.8b}, [%[inptr5]], x1 \n" // q0=B0b1b2b3A4A5A6A7 + "ld1 {v6.8b}, [%[inptr6]], x1 \n" // q0=c0c1c2c3A4A5A6A7 + "ld1 {v7.8b}, [%[inptr7]], x1 \n" // q0=d0d1d2d3A4A5A6A7 + + "trn1 v8.2s, v0.2s, v1.2s \n" //v0=a0a1a2a3'b0b1b2b3 -00 01 + "trn1 v9.2s, v2.2s, v3.2s \n" //v0=c0c1a2a3'd0b1b2b3 -02 03 + + "ld1 {v12.8b}, [%[inptr8]], x1 \n" // q0=A0A1A2A3A4A5A6A7 + "ld1 {v13.8b}, [%[inptr9]], x1 \n" // q0=B0b1b2b3A4A5A6A7 + "ld1 {v14.8b}, [%[inptr10]], x1 \n" // q0=c0c1c2c3A4A5A6A7 + "ld1 {v15.8b}, [%[inptr11]], x1 \n" // q0=d0d1d2d3A4A5A6A7 + + "trn1 v10.2s, v4.2s, v5.2s \n" //v0=a0b0a0b0'a4b4a4b4 -04 05 + "trn1 v11.2s, v6.2s, v7.2s \n" //v0=a0b0a0b0'a4b4a4b4 -06 07 + + "trn1 v4.2s, v12.2s, v13.2s \n" //v0=a0b0a0b0'a4b4a4b4 -08 09 + "trn1 v6.2s, v14.2s, v15.2s \n" //v0=a0b0a0b0'a4b4a4b4 -010 011 + + "st1 {v8.8b}, [%[outptr_row0]], #8 \n" + "st1 {v9.8b}, [%[outptr_row0]], #8 \n" + "st1 {v10.8b}, [%[outptr_row0]], #8 \n" + "st1 {v11.8b}, [%[outptr_row0]], #8 \n" + "st1 {v4.8b}, [%[outptr_row0]], #8 \n" + "st1 {v6.8b}, [%[outptr_row0]], #8 \n" + : [inptr0] "+r"(inptr_row[0]), [inptr1] 
"+r"(inptr_row[1]), + [inptr2] "+r"(inptr_row[2]), [inptr3] "+r"(inptr_row[3]), + [inptr4] "+r"(inptr_row[4]), [inptr5] "+r"(inptr_row[5]), + [inptr6] "+r"(inptr_row[6]), [inptr7] "+r"(inptr_row[7]), + [inptr8] "+r"(inptr_row[8]), [inptr9] "+r"(inptr_row[9]), + [inptr10] "+r"(inptr_row[10]), [inptr11] "+r"(inptr_row[11]), \ + [outptr_row0] "+r"(out0) + : + : "x1", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", + "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "cc", "memory" + ); + right_remain -= 4; + } + if (right_remain > 0) { + for (int i = 0; i < 12; i++) { + for (int x = 0; x < right_remain; x++) { + *out0++ = *inptr_row[i]++; + } + for (int x = 0; x < 4 - right_remain; x++) { + *out0++ = 0; + } + } + } + } +} +#endif //dotprod //NOLINT + +template <> +void gemm_prepack_int8(const int8_t* A_packed, + const int8_t* B, + const int* bias, + float32_t* C, + int M, + int N, + int K, + bool is_bias, + bool is_relu, + bool is_transB, + const float* scale, + ARMContext* ctx) { +#if defined(__aarch64__) && defined(WITH_ARM_DOTPROD) + if (ctx->has_dot()) { + gemm_prepack_sdot_int8(A_packed, + B, bias, C, M, N, K, is_bias, is_relu, + is_transB, scale, ctx); + } else { + gemm_prepack_oth_int8(A_packed, B, + bias, C, M, N, K, is_bias, is_relu, + is_transB, scale, ctx); + } +#else + gemm_prepack_oth_int8(A_packed, B, + bias, C, M, N, K, is_bias, is_relu, + is_transB, scale, ctx); +#endif +} + +template <> +void gemm_prepack_int8(const int8_t* A_packed, + const int8_t* B, + const int* bias, + int8_t* C, + int M, + int N, + int K, + bool is_bias, + bool is_relu, + bool is_transB, + const float* scale, + ARMContext* ctx) { +#if defined(__aarch64__) && defined(WITH_ARM_DOTPROD) + if (ctx->has_dot()) { + gemm_prepack_sdot_int8(A_packed, B, bias, + C, M, N, K, is_bias, is_relu, + is_transB, scale, ctx); + } else { + gemm_prepack_oth_int8(A_packed, B, bias, + C, M, N, K, is_bias, is_relu, + is_transB, scale, ctx); + } +#else + gemm_prepack_oth_int8(A_packed, B, bias, + C, M, N, K, is_bias, is_relu, + is_transB, scale, ctx); +#endif +} + +template <> +void gemm_prepack_int8(const int8_t* A_packed, + const int8_t* B, + const int* bias, + int32_t* C, + int M, + int N, + int K, + bool is_bias, + bool is_relu, + bool is_transB, + const float* scale, + ARMContext* ctx) { +#if defined(__aarch64__) && defined(WITH_ARM_DOTPROD) + if (ctx->has_dot()) { + gemm_prepack_sdot_int8(A_packed, B, + bias, C, M, N, K, is_bias, is_relu, + is_transB, scale, ctx); + } else { + gemm_prepack_oth_int8(A_packed, B, + bias, C, M, N, K, is_bias, is_relu, + is_transB, scale, ctx); + } +#else + gemm_prepack_oth_int8(A_packed, B, bias, + C, M, N, K, is_bias, is_relu, is_transB, scale, ctx); +#endif +} + +} // namespace math +} // namespace arm +} // namespace lite +} // namespace paddle diff --git a/lite/arm/math/gemm_prepacked_int8.h b/lite/arm/math/gemm_prepacked_int8.h new file mode 100644 index 00000000000..6d02fdabef2 --- /dev/null +++ b/lite/arm/math/gemm_prepacked_int8.h @@ -0,0 +1,94 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+#include <cmath>
+#include "lite/core/context.h"
+#include "lite/core/cpu_info.h"
+#include "lite/core/tensor.h"
+
+namespace paddle {
+namespace lite {
+namespace arm {
+namespace math {
+
+const int KBLOCK_INT8 = 4;
+#ifdef __aarch64__
+// for int7/int8 gemm
+// const int HBLOCK = 4;
+// const int NBLOCK = 16;
+const int MBLOCK_INT8_OTH = 4;
+const int NBLOCK_INT8_OTH = 16;
+
+const int MBLOCK_INT8_DOT = 8;
+const int NBLOCK_INT8_DOT = 12;
+
+inline int get_hblock_int8(const ARMContext* ctx) {
+#ifdef WITH_ARM_DOTPROD
+  if (ctx->has_dot()) {
+    return MBLOCK_INT8_DOT;
+  } else {
+    return MBLOCK_INT8_OTH;
+  }
+#else
+  return MBLOCK_INT8_OTH;
+#endif
+}
+#else
+// const int HBLOCK = 4;
+// const int WBLOCK = 8;
+const int MBLOCK_INT8_OTH = 4;
+const int NBLOCK_INT8_OTH = 8;
+
+inline int get_hblock_int8(const ARMContext* ctx) { return 4; }
+#endif  // __aarch64__
+
+void prepackA_int8(void* out,
+                   const void* in,
+                   int ldin,
+                   int m0,
+                   int mmax,
+                   int k0,
+                   int kmax,
+                   bool is_trans,
+                   ARMContext* ctx);
+
+void prepackA_int8(TensorLite* tout,
+                   const TensorLite& tin,
+                   int m,
+                   int k,
+                   int group,
+                   bool is_trans,
+                   ARMContext* ctx);
+
+template <typename dtype>
+void gemm_prepack_int8(const int8_t* A_packed,
+                       const int8_t* B,
+                       const int* bias,
+                       dtype* C,
+                       int M,
+                       int N,
+                       int K,
+                       bool is_bias,
+                       bool is_relu,
+                       bool is_transB,
+                       const float* scale,
+                       ARMContext* ctx);
+
+#define ROUNDUP(a, b) ((((a) + (b)-1) / (b)) * (b))
+
+}  // namespace math
+}  // namespace arm
+}  // namespace lite
+}  // namespace paddle
diff --git a/lite/arm/math/gemv_arm_int8.cc b/lite/arm/math/gemv_arm_int8.cc
new file mode 100644
index 00000000000..568ee0a9d9b
--- /dev/null
+++ b/lite/arm/math/gemv_arm_int8.cc
@@ -0,0 +1,480 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lite/arm/math/gemv_arm_int8.h"
+#include <arm_neon.h>
+#include "lite/arm/math/saturate.h"
+
+namespace paddle {
+namespace lite {
+namespace arm {
+namespace math {
+
+template <typename dtype>
+inline void write_gemv_out(const int* in, dtype* out, const float* scale);
+
+template <>
+inline void write_gemv_out(const int* in, int* out, const float* scale) {
+  out[0] = in[0];
+}
+
+template <>
+inline void write_gemv_out(const int* in, float* out, const float* scale) {
+  out[0] = in[0] * scale[0];
+}
+
+template <>
+inline void write_gemv_out(const int* in,
+                           signed char* out,
+                           const float* scale) {
+  out[0] = saturate_cast<signed char>(roundf(in[0] * scale[0]));
+}
+
+template <typename dtype>
+bool gemv_int8(const int8_t* A,
+               const int8_t* x,
+               dtype* y,
+               bool transA,
+               int M,
+               int N,
+               const float* scale,
+               bool is_bias,
+               const int* bias,
+               bool is_relu) {
+  if (transA) {
+    LOG(ERROR) << "ERROR: sgemv, transA is not supported now";
+    return false;
+  }
+  dtype* data_out = y;
+  const int8_t* data_in = x;
+  const int8_t* weights_ptr = A;
+  int cnt = N >> 4;
+  int tail = N & 15;
+  int flag_bias = is_bias ? 1 : 0;
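+
+  // In effect, the NEON paths below implement this scalar loop (editorial
+  // sketch only, not compiled): per output row, an int32 dot product plus
+  // optional bias, then optional relu, then write_gemv_out applies the
+  // per-row scale / rounding for the requested output type:
+  //   for (int m = 0; m < M; ++m) {
+  //     int acc = is_bias ? bias[m] : 0;
+  //     for (int n = 0; n < N; ++n) {
+  //       acc += static_cast<int>(A[m * N + n]) * x[n];
+  //     }
+  //     if (is_relu && acc < 0) acc = 0;
+  //     write_gemv_out(&acc, y + m, scale + m);
+  //   }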
+
+#ifdef __aarch64__
+  int out_cnt = M >> 3;
+#pragma omp parallel for
+  for (int j = 0; j < out_cnt; j++) {
+    int out_idx = j * 8;
+    dtype* out_ptr = data_out + out_idx;
+    const float* scale_ptr = scale + out_idx;
+    int ptr_out[8] = {0, 0, 0, 0, 0, 0, 0, 0};
+    const int8_t* ptr_in = data_in;
+    const int8_t* ptr_w0 = weights_ptr + (N * out_idx);
+    const int8_t* ptr_w1 = ptr_w0 + N;
+    const int8_t* ptr_w2 = ptr_w1 + N;
+    const int8_t* ptr_w3 = ptr_w2 + N;
+    const int8_t* ptr_w4 = ptr_w3 + N;
+    const int8_t* ptr_w5 = ptr_w4 + N;
+    const int8_t* ptr_w6 = ptr_w5 + N;
+    const int8_t* ptr_w7 = ptr_w6 + N;
+    const int* bias_ptr = is_bias ? (bias + out_idx) : nullptr;
+    int cnt_loop = cnt;
+    asm volatile(
+        "prfm  pldl1keep, [%[in]]   \n" /* preload din */
+        "prfm  pldl1keep, [%[w0]]   \n" /* preload w0 */
+        "prfm  pldl1keep, [%[w1]]   \n" /* preload w1 */
+        "prfm  pldl1keep, [%[w2]]   \n" /* preload w2 */
+        "prfm  pldl1keep, [%[w3]]   \n" /* preload w3 */
+        "prfm  pldl1keep, [%[w4]]   \n" /* preload w4 */
+        "prfm  pldl1keep, [%[w5]]   \n" /* preload w5 */
+        "prfm  pldl1keep, [%[w6]]   \n" /* preload w6 */
+        "prfm  pldl1keep, [%[w7]]   \n" /* preload w7 */
+        "movi  v0.4s, #0            \n" /* set out0 to 0 */
+        "movi  v1.4s, #0            \n" /* set out1 to 0 */
+        "movi  v2.4s, #0            \n" /* set out2 to 0 */
+        "movi  v3.4s, #0            \n" /* set out3 to 0 */
+        "movi  v4.4s, #0            \n" /* set out4 to 0 */
+        "movi  v5.4s, #0            \n" /* set out5 to 0 */
+        "movi  v6.4s, #0            \n" /* set out6 to 0 */
+        "movi  v7.4s, #0            \n" /* set out7 to 0 */
+        /* check main loop */
+        "cmp   %w[cnt], #1          \n" /* check whether has main loop */
+        "blt   2f                   \n" /* jump to tail */
+        /* main loop */
+        "1:                         \n" /* main loop */
+        "ldr   q8, [%[in]], #16     \n" /* load input, 16 int8 */
+        "ldr   q9, [%[w0]], #16     \n" /* load w0, 16 int8 */
+        "ldr   q10, [%[w1]], #16    \n" /* load w1, 16 int8 */
+        "ldr   q11, [%[w2]], #16    \n" /* load w2, 16 int8 */
+        "ldr   q12, [%[w3]], #16    \n" /* load w3, 16 int8 */
+        "ldr   q13, [%[w4]], #16    \n" /* load w4, 16 int8 */
+        "ldr   q14, [%[w5]], #16    \n" /* load w5, 16 int8 */
+        "ldr   q15, [%[w6]], #16    \n" /* load w6, 16 int8 */
+        "ldr   q16, [%[w7]], #16    \n" /* load w7, 16 int8 */
+        /* mul, lower 8 int8 * int8 = int16 */
+        "smull v18.8h, v8.8b, v9.8b \n" /* mul in * w0, low, 8 int8 */
+        "smull v19.8h, v8.8b, v10.8b\n" /* mul in * w1, low, 8 int8 */
+        "smull v20.8h, v8.8b, v11.8b\n" /* mul in * w2, low, 8 int8 */
+        "smull v21.8h, v8.8b, v12.8b\n" /* mul in * w3, low, 8 int8 */
+        "smull v22.8h, v8.8b, v13.8b\n" /* mul in * w4, low, 8 int8 */
+        "smull v23.8h, v8.8b, v14.8b\n" /* mul in * w5, low, 8 int8 */
+        "smull v24.8h, v8.8b, v15.8b\n" /* mul in * w6, low, 8 int8 */
+        "smull v25.8h, v8.8b, v16.8b\n" /* mul in * w7, low, 8 int8 */
+        /* mul, higher 8 int8 * int8 + int16 = int16 */
+        "smlal2 v18.8h,v8.16b,v9.16b \n" /* mul in * w0, high, 8 int8 */
+        "smlal2 v19.8h,v8.16b,v10.16b\n" /* mul in * w1, high, 8 int8 */
+        "smlal2 v20.8h,v8.16b,v11.16b\n" /* mul in * w2, high, 8 int8 */
+        "smlal2 v21.8h,v8.16b,v12.16b\n" /* mul in * w3, high, 8 int8 */
+        "smlal2 v22.8h,v8.16b,v13.16b\n" /* mul in * w4, high, 8 int8 */
+        "smlal2 v23.8h,v8.16b,v14.16b\n" /* mul in * w5, high, 8 int8 */
+        "smlal2 v24.8h,v8.16b,v15.16b\n" /* mul in * w6, high, 8 int8 */
+        "smlal2 v25.8h,v8.16b,v16.16b\n" /* mul in * w7, high, 8 int8 */
+        "subs %w[cnt], %w[cnt], #1  \n" /* sub main loop count */
+        /* add int16 to int32 */
+        "sadalp  v0.4s, v18.8h      \n" /* pair acc, 8 int16 -> 4 int32 */
+        "sadalp  v1.4s, v19.8h      \n" /* pair acc, 8 int16 -> 4 int32 */
+        "sadalp  v2.4s, v20.8h      \n" /* pair acc, 8 int16 -> 4 int32 */
+        "sadalp  v3.4s,
v21.8h \n" /* pair acc, 8 int16 -> 4 int32 */ + "sadalp v4.4s, v22.8h \n" /* pair acc, 8 int16 -> 4 int32 */ + "sadalp v5.4s, v23.8h \n" /* pair acc, 8 int16 -> 4 int32 */ + "sadalp v6.4s, v24.8h \n" /* pair acc, 8 int16 -> 4 int32 */ + "sadalp v7.4s, v25.8h \n" /* pair acc, 8 int16 -> 4 int32 */ + "bne 1b \n" /* jump to main loop */ + /* pair add to final result */ + "2: \n" /* reduce to scale */ + "addp v8.4s , v0.4s , v1.4s \n" /* pair add to 4 int32*/ + "addp v9.4s , v2.4s , v3.4s \n" /* pair add to 4 int32*/ + "addp v10.4s, v4.4s , v5.4s \n" /* pair add to 4 int32*/ + "addp v11.4s, v6.4s , v7.4s \n" /* pair add to 4 int32*/ + + "addp v12.4s, v8.4s , v9.4s \n" /* pair add to 4 int32*/ + "addp v13.4s, v10.4s, v11.4s \n" /* pair add to 4 int32*/ + + "cmp %w[bias], #1 \n" /* check whether has bias */ + "blt 0f \n" /* jump to tail */ + "ldp q8, q9, [%[bias_ptr]]\n" /* load bias to q8, q9*/ + "add v12.4s, v12.4s, v8.4s \n" /* add bias */ + "add v13.4s, v13.4s, v9.4s \n" /* add bias */ + "0: \n" /* end of add bias */ + + /* write to output */ + "stp q12, q13, [%[out]] \n" /* save result */ + : [in] "+r"(ptr_in), + [w0] "+r"(ptr_w0), + [w1] "+r"(ptr_w1), + [w2] "+r"(ptr_w2), + [w3] "+r"(ptr_w3), + [w4] "+r"(ptr_w4), + [w5] "+r"(ptr_w5), + [w6] "+r"(ptr_w6), + [w7] "+r"(ptr_w7), + [cnt] "+r"(cnt_loop) + : [out] "r"(ptr_out), [bias_ptr] "r"(bias_ptr), [bias] "r"(flag_bias) + : "cc", + "memory", + "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15", + "v16", + "v17", + "v18", + "v19", + "v20", + "v21", + "v22", + "v23", + "v24", + "v25"); + for (int i = 0; i < tail; ++i) { + ptr_out[0] += ptr_in[i] * ptr_w0[i]; + ptr_out[1] += ptr_in[i] * ptr_w1[i]; + ptr_out[2] += ptr_in[i] * ptr_w2[i]; + ptr_out[3] += ptr_in[i] * ptr_w3[i]; + ptr_out[4] += ptr_in[i] * ptr_w4[i]; + ptr_out[5] += ptr_in[i] * ptr_w5[i]; + ptr_out[6] += ptr_in[i] * ptr_w6[i]; + ptr_out[7] += ptr_in[i] * ptr_w7[i]; + } + if (is_relu) { + ptr_out[0] = ptr_out[0] > 0 ? ptr_out[0] : 0; + ptr_out[1] = ptr_out[1] > 0 ? ptr_out[1] : 0; + ptr_out[2] = ptr_out[2] > 0 ? ptr_out[2] : 0; + ptr_out[3] = ptr_out[3] > 0 ? ptr_out[3] : 0; + ptr_out[4] = ptr_out[4] > 0 ? ptr_out[4] : 0; + ptr_out[5] = ptr_out[5] > 0 ? ptr_out[5] : 0; + ptr_out[6] = ptr_out[6] > 0 ? ptr_out[6] : 0; + ptr_out[7] = ptr_out[7] > 0 ? ptr_out[7] : 0; + } + + write_gemv_out(ptr_out, out_ptr, scale_ptr); + write_gemv_out(ptr_out + 1, out_ptr + 1, scale_ptr + 1); + write_gemv_out(ptr_out + 2, out_ptr + 2, scale_ptr + 2); + write_gemv_out(ptr_out + 3, out_ptr + 3, scale_ptr + 3); + write_gemv_out(ptr_out + 4, out_ptr + 4, scale_ptr + 4); + write_gemv_out(ptr_out + 5, out_ptr + 5, scale_ptr + 5); + write_gemv_out(ptr_out + 6, out_ptr + 6, scale_ptr + 6); + write_gemv_out(ptr_out + 7, out_ptr + 7, scale_ptr + 7); + } + +//! deal with remains +#pragma omp parallel for + for (int j = out_cnt * 8; j < M; j++) { + // int *ptr_out = data_out + j; + dtype* out_ptr = data_out + j; + const float* scale_ptr = scale + j; + int ptr_out[1] = {0}; + const int8_t* ptr_in = data_in; + const int8_t* ptr_w0 = weights_ptr + (N * j); + int cnt_loop = cnt; + int bias0 = is_bias ? 
bias[j] : 0; + asm volatile( + "prfm pldl1keep, [%[in]] \n" /* preload din */ + "prfm pldl1keep, [%[w0]] \n" /* preload w0 */ + "movi v0.4s, #0 \n" /* set out0 to 0 */ + "fmov s0, %w[bias0] \n" /* set bias */ + /* check main loop */ + "cmp %w[cnt], #1 \n" /* check whether has main loop */ + "blt 2f \n" /* jump to tail */ + /* main loop */ + "1: \n" /* main loop */ + "ldr q8, [%[in]], #16 \n" /* load input, 16 int8 */ + "ldr q9, [%[w0]], #16 \n" /* load w0, 16 int8 */ + /* mul, lower 8 int8 * int8 = int16 */ + "smull v18.8h, v8.8b, v9.8b \n" /* mul in * w0, low, 8 int8 */ + "subs %w[cnt], %w[cnt], #1 \n" /* sub main loop count */ + /* mul, higher 8 int8 * int8 + int16 = int16 */ + "smlal2 v18.8h,v8.16b,v9.16b \n" /* mul in * w0, high, 8 int8 */ + /* add int16 to int32 */ + "sadalp v0.4s, v18.8h \n" /* pair acc, 8 int16 -> 4 int32 */ + "bne 1b \n" /* jump to main loop */ + /* pair add to final result */ + "2: \n" /* reduce to scale */ + "addv s8, v0.4s \n" /* reduction to out0 */ + /* write to output */ + "str s8, [%[out]] \n" /* save result */ + : [in] "+r"(ptr_in), [w0] "+r"(ptr_w0), [cnt] "+r"(cnt_loop) + : [out] "r"(ptr_out), [bias0] "r"(bias0) + : "cc", "memory", "v0", "v8", "v9", "v18"); + for (int i = 0; i < tail; ++i) { + ptr_out[0] += ptr_in[i] * ptr_w0[i]; + } + if (is_relu) { + ptr_out[0] = ptr_out[0] > 0 ? ptr_out[0] : 0; + } + write_gemv_out(ptr_out, out_ptr, scale_ptr); + } +#else //__aarch64__ // NOLINT + int out_cnt = M >> 2; +#pragma omp parallel for + for (int j = 0; j < out_cnt; j++) { + int out_idx = j * 4; + dtype* out_ptr = data_out + out_idx; + const float* scale_ptr = scale + out_idx; + int ptr_out[4] = {0, 0, 0, 0}; + const int8_t* ptr_in = data_in; + const int8_t* ptr_w0 = weights_ptr + (N * out_idx); + const int8_t* ptr_w1 = ptr_w0 + N; + const int8_t* ptr_w2 = ptr_w1 + N; + const int8_t* ptr_w3 = ptr_w2 + N; + int cnt_loop = cnt; + int bias0 = is_bias ? bias[out_idx] : 0; + int bias1 = is_bias ? bias[out_idx + 1] : 0; + int bias2 = is_bias ? bias[out_idx + 2] : 0; + int bias3 = is_bias ? bias[out_idx + 3] : 0; + asm volatile( + "pld [%[in]] @ preload cache line, input\n" + "pld [%[w0]] @ preload cache line, weights r0\n" + "pld [%[w1]] @ preload cache line, weights r1\n" + "pld [%[w2]] @ preload cache line, weights r2\n" + "pld [%[w3]] @ preload cache line, weights r3\n" + "vmov.u32 q0, #0 @ set q0 to 0\n" + "vmov.u32 q1, #0 @ set q1 to 0\n" + "vmov.u32 q2, #0 @ set q2 to 0\n" + "vmov.u32 q3, #0 @ set q3 to 0\n" + "vmov s0, %[bias0] @ set q0 to bias0\n" + "vmov s4, %[bias1] @ set q1 to bias1\n" + "vmov s8, %[bias2] @ set q2 to bias2\n" + "vmov s12,%[bias3] @ set q3 to bias3\n" + // "vld1.32 {d20-d21}, %[bias] @ load bias data" + "cmp %[cnt], #1 @ check whether has main loop\n" + "blt 2f @ jump to pair add\n" + /* main loop */ + "1: @ main loop\n" + "vld1.8 {d8-d9}, [%[in]]! @ load input, q4\n" + "vld1.8 {d12-d13}, [%[w0]]! @ load weights r0, q6\n" + "vld1.8 {d14-d15}, [%[w1]]! @ load weights r1, q7\n" + "vld1.8 {d16-d17}, [%[w2]]! @ load weights r2, q8\n" + "vld1.8 {d18-d19}, [%[w3]]! 
@ load weights r3, q9\n" + /* mul, int8 * int8 = int16 */ + "vmull.s8 q12, d8, d12 @ mul add\n" + "vmull.s8 q13, d8, d14 @ mul add\n" + "vmull.s8 q14, d8, d16 @ mul add\n" + "vmull.s8 q15, d8, d18 @ mul add\n" + /* mla, int8 * int8 + int16 = int16 */ + "vmlal.s8 q12, d9, d13 @ mul add\n" + "vmlal.s8 q13, d9, d15 @ mul add\n" + "vmlal.s8 q14, d9, d17 @ mul add\n" + "vmlal.s8 q15, d9, d19 @ mul add\n" + /* pacc, int16 + int32 = int32 */ + "vpadal.s16 q0, q12 @ pair acc\n" + "vpadal.s16 q1, q13 @ pair acc\n" + "vpadal.s16 q2, q14 @ pair acc\n" + "vpadal.s16 q3, q15 @ pair acc\n" + "subs %[cnt], #1 @ sub loop count \n" + /* check loop end */ + "bne 1b @ jump to main loop\n" + /* pair add to final result */ + "2: @ pair add \n" + "vpadd.s32 d8, d0, d1 @ pair add, first step\n" + "vpadd.s32 d9, d2, d3 @ pair add, first step\n" + "vpadd.s32 d10, d4, d5 @ pair add, first step\n" + "vpadd.s32 d11, d6, d7 @ pair add, first step\n" + "vpadd.s32 d0, d8, d9 @ pair add, second step\n" + "vpadd.s32 d1, d10, d11 @ pair add, second step\n" + /* write output */ + "vst1.32 {d0-d1}, [%[out]] @ save result\n" + : [in] "+r"(ptr_in), + [w0] "+r"(ptr_w0), + [w1] "+r"(ptr_w1), + [w2] "+r"(ptr_w2), + [w3] "+r"(ptr_w3), + [cnt] "+r"(cnt_loop) + : [bias0] "r"(bias0), + [bias1] "r"(bias1), + [bias2] "r"(bias2), + [bias3] "r"(bias3), + [out] "r"(ptr_out) + : "cc", + "memory", + "q0", + "q1", + "q2", + "q3", + "q4", + "q5", + "q6", + "q7", + "q8", + "q9", + "q12", + "q13", + "q14", + "q15"); + for (int i = 0; i < tail; ++i) { + ptr_out[0] += ptr_in[i] * ptr_w0[i]; + ptr_out[1] += ptr_in[i] * ptr_w1[i]; + ptr_out[2] += ptr_in[i] * ptr_w2[i]; + ptr_out[3] += ptr_in[i] * ptr_w3[i]; + } + if (is_relu) { + ptr_out[0] = ptr_out[0] > 0 ? ptr_out[0] : 0; + ptr_out[1] = ptr_out[1] > 0 ? ptr_out[1] : 0; + ptr_out[2] = ptr_out[2] > 0 ? ptr_out[2] : 0; + ptr_out[3] = ptr_out[3] > 0 ? ptr_out[3] : 0; + } + write_gemv_out(ptr_out, out_ptr, scale_ptr); + write_gemv_out(ptr_out + 1, out_ptr + 1, scale_ptr + 1); + write_gemv_out(ptr_out + 2, out_ptr + 2, scale_ptr + 2); + write_gemv_out(ptr_out + 3, out_ptr + 3, scale_ptr + 3); + } +//! deal with remains +#pragma omp parallel for + for (int j = out_cnt * 4; j < M; j++) { + dtype* out_ptr = data_out + j; + const float* scale_ptr = scale + j; + int ptr_out[1] = {0}; + const int8_t* ptr_in = data_in; + const int8_t* ptr_w0 = weights_ptr + (N * j); + int cnt_loop = cnt; + int bias0 = is_bias ? bias[j] : 0; + asm volatile( + "pld [%[in]] @ preload cache line, " + "input\n" + "pld [%[w0]] @ preload cache line, weights r0\n" + "vmov.u32 q0, #0 @ set q0 to 0\n" + "vmov s0, %[bias0] @ set q0 to bias0\n" + "cmp %[cnt], #1 @ check whether has main loop\n" + "blt 2f @ jump to tail\n" + /* main loop */ + "1: @ main loop\n" + "vld1.8 {d24-d25}, [%[in]]! @ load input, q12\n" + "vld1.8 {d28-d29}, [%[w0]]! 
@ load weights q14\n"
+        /* mull, int8 * int8 = int16 */
+        "vmull.s8 q1, d24, d28            @ mul add\n"
+        "vmlal.s8 q1, d25, d29            @ mul add\n"
+        "subs %[cnt], #1                  @ sub loop count\n"
+        /* pacc, int16 + int32 = int32 */
+        "vpadal.s16 q0, q1                @ pair acc\n"
+        "bne 1b                           @ jump to main loop\n"
+        /* pair add to final result */
+        "2:                               @ end processing\n"
+        "vpadd.s32 d2, d0, d1             @ pair add, first step\n"
+        "vpadd.s32 d0, d2, d2             @ pair add, final step\n"
+        /* write output */
+        "vst1.32 {d0[0]}, [%[out]]        @ save result\n"
+        : [in] "+r"(ptr_in), [w0] "+r"(ptr_w0), [cnt] "+r"(cnt_loop)
+        : [bias0] "r"(bias0), [out] "r"(ptr_out)
+        /* d24/d25 live in q12 and d28/d29 in q14, so q14 (not q13) must be
+           listed in the clobber list */
+        : "cc", "memory", "q0", "q1", "q12", "q14");
+    for (int i = 0; i < tail; ++i) {
+      ptr_out[0] += ptr_in[i] * ptr_w0[i];
+    }
+    if (is_relu) {
+      ptr_out[0] = ptr_out[0] > 0 ? ptr_out[0] : 0;
+    }
+    write_gemv_out(ptr_out, out_ptr, scale_ptr);
+  }
+#endif  //__aarch64__ // NOLINT
+  return true;
+}
+
+template bool gemv_int8(const int8_t* A,
+                        const int8_t* x,
+                        float* y,
+                        bool transA,
+                        int M,
+                        int N,
+                        const float* scale,
+                        bool is_bias,
+                        const int* bias,
+                        bool is_relu);
+template bool gemv_int8(const int8_t* A,
+                        const int8_t* x,
+                        int* y,
+                        bool transA,
+                        int M,
+                        int N,
+                        const float* scale,
+                        bool is_bias,
+                        const int* bias,
+                        bool is_relu);
+template bool gemv_int8(const int8_t* A,
+                        const int8_t* x,
+                        signed char* y,
+                        bool transA,
+                        int M,
+                        int N,
+                        const float* scale,
+                        bool is_bias,
+                        const int* bias,
+                        bool is_relu);
+
+}  // namespace math
+}  // namespace arm
+}  // namespace lite
+}  // namespace paddle
diff --git a/lite/arm/math/gemv_arm_int8.h b/lite/arm/math/gemv_arm_int8.h
new file mode 100644
index 00000000000..50818c741a4
--- /dev/null
+++ b/lite/arm/math/gemv_arm_int8.h
@@ -0,0 +1,40 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+#include
+#include "lite/core/cpu_info.h"
+
+namespace paddle {
+namespace lite {
+namespace arm {
+namespace math {
+
+// FIXME: currently only supports transA = false
+template <typename dtype>
+bool gemv_int8(const int8_t* A,
+               const int8_t* x,
+               dtype* y,
+               bool transA,
+               int M,
+               int N,
+               const float* scale,
+               bool is_bias = false,
+               const int* bias = nullptr,
+               bool is_relu = false);
+
+}  // namespace math
+}  // namespace arm
+}  // namespace lite
+}  // namespace paddle
diff --git a/lite/arm/math/gru_utils.h b/lite/arm/math/gru_utils.h
new file mode 100644
index 00000000000..47e5d62aa1d
--- /dev/null
+++ b/lite/arm/math/gru_utils.h
@@ -0,0 +1,434 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "lite/arm/math/sgemm.h" + +namespace paddle { +namespace lite { +namespace arm { +namespace math { + +template +struct GRUMetaValue { + T* gate_weight; + T* state_weight; + T* gate_value; + T* reset_output_value; + T* output_value; + T* prev_out_value; +}; + +template +inline void gru_add_with_bias( + const Dtype* din, const Dtype* bias, Dtype* dout, int batch, int size); + +template <> +inline void gru_add_with_bias( + const float* din, const float* bias, float* dout, int batch, int size) { +#pragma omp parallel for + for (int i = 0; i < batch; ++i) { + int j = 0; + auto din_batch = din + i * size; + auto dout_batch = dout + i * size; + float32x4_t vb0 = vld1q_f32(bias); + float32x4_t vin0 = vld1q_f32(din_batch); + float32x4_t vout0; + float32x4_t vout1; + float32x4_t vin1; + float32x4_t vb1; + for (; j < size - 7; j += 8) { + vin1 = vld1q_f32(din_batch + j + 4); + vb1 = vld1q_f32(bias + j + 4); + vout0 = vaddq_f32(vb0, vin0); + vout1 = vaddq_f32(vb1, vin1); + vb0 = vld1q_f32(bias + j + 8); + vin0 = vld1q_f32(din_batch + j + 8); + vst1q_f32(dout_batch + j, vout0); + vst1q_f32(dout_batch + j + 4, vout1); + } + for (; j < size; ++j) { + dout_batch[j] = din_batch[j] + bias[j]; + } + } +} + +template +static void gru_unit_reset_act_impl(float* updata_gate, + int stride_update, + float* reset_gate, + int stride_reset, + const float* hidden_prev, + int stride_hidden_prev, + float* reset_hidden_prev, + int stride_reset_hidden_prev, + int frame_size, + int batch_size) { +#pragma omp parallel for + for (int b = 0; b < batch_size; ++b) { + float32x4_t vpre0 = vdupq_n_f32(0.f); + float32x4_t vpre1 = vdupq_n_f32(0.f); + float prev = 0.f; + int i = 0; + for (; i < frame_size - 7; i += 8) { + float32x4_t vu0 = vld1q_f32(updata_gate + i); + float32x4_t vu1 = vld1q_f32(updata_gate + i + 4); + float32x4_t vr0 = vld1q_f32(reset_gate + i); + float32x4_t vr1 = vld1q_f32(reset_gate + i + 4); + + float32x4_t vau0 = lite::arm::math::vactive_f32(vu0); + float32x4_t vau1 = lite::arm::math::vactive_f32(vu1); + + if (hidden_prev) { + vpre0 = vld1q_f32(hidden_prev + i); + vpre1 = vld1q_f32(hidden_prev + i + 4); + } + + float32x4_t var0 = lite::arm::math::vactive_f32(vr0); + float32x4_t var1 = lite::arm::math::vactive_f32(vr1); + + vst1q_f32(updata_gate + i, vau0); + vst1q_f32(updata_gate + i + 4, vau1); + + float32x4_t vres0 = vmulq_f32(vpre0, var0); + float32x4_t vres1 = vmulq_f32(vpre1, var1); + + vst1q_f32(reset_gate + i, var0); + vst1q_f32(reset_gate + i + 4, var1); + vst1q_f32(reset_hidden_prev + i, vres0); + vst1q_f32(reset_hidden_prev + i + 4, vres1); + } + + for (; i < frame_size; ++i) { + updata_gate[i] = lite::arm::math::active_f32(updata_gate[i]); + reset_gate[i] = lite::arm::math::active_f32(reset_gate[i]); + if (hidden_prev) { + prev = hidden_prev[i]; + } + reset_hidden_prev[i] = reset_gate[i] * prev; + } + + updata_gate += stride_update; + reset_gate += stride_reset; + if (hidden_prev) { + hidden_prev += stride_hidden_prev; + } + reset_hidden_prev += stride_reset_hidden_prev; + } +} + +template +static void gru_unit_out_act_impl(bool origin_mode, + float* updata_gate, + int stride_update, + float* cell_state, + int stride_cell_state, + const float* hidden_prev, + int stride_hidden_prev, + float* hidden, + int stride_hidden, + int frame_size, + int batch_size) { +#pragma omp parallel for + for (int b = 0; b < batch_size; ++b) { + float32x4_t vpre0 = vdupq_n_f32(0.f); + 
float32x4_t vpre1 = vdupq_n_f32(0.f); + float prev = 0.f; + int i = 0; + if (origin_mode) { + for (; i < frame_size - 7; i += 8) { + float32x4_t vc0 = vld1q_f32(cell_state + i); + float32x4_t vc1 = vld1q_f32(cell_state + i + 4); + float32x4_t vu0 = vld1q_f32(updata_gate + i); + float32x4_t vu1 = vld1q_f32(updata_gate + i + 4); + + float32x4_t vac0 = lite::arm::math::vactive_f32(vc0); + float32x4_t vac1 = lite::arm::math::vactive_f32(vc1); + if (hidden_prev) { + vpre0 = vld1q_f32(hidden_prev + i); + vpre1 = vld1q_f32(hidden_prev + i + 4); + } + + float32x4_t vh0 = vmlsq_f32(vac0, vu0, vac0); + float32x4_t vh1 = vmlsq_f32(vac1, vu1, vac1); + + vst1q_f32(cell_state + i, vac0); + vst1q_f32(cell_state + i + 4, vac1); + + vh0 = vmlaq_f32(vh0, vu0, vpre0); + vh1 = vmlaq_f32(vh1, vu1, vpre1); + + vst1q_f32(hidden + i, vh0); + vst1q_f32(hidden + i + 4, vh1); + } + + for (; i < frame_size; ++i) { + if (hidden_prev) { + prev = hidden_prev[i]; + } + cell_state[i] = lite::arm::math::active_f32(cell_state[i]); + hidden[i] = + cell_state[i] * (1.f - updata_gate[i]) + updata_gate[i] * prev; + } + } else { + for (; i < frame_size - 7; i += 8) { + float32x4_t vc0 = vld1q_f32(cell_state + i); + float32x4_t vc1 = vld1q_f32(cell_state + i + 4); + float32x4_t vu0 = vld1q_f32(updata_gate + i); + float32x4_t vu1 = vld1q_f32(updata_gate + i + 4); + + float32x4_t vac0 = lite::arm::math::vactive_f32(vc0); + float32x4_t vac1 = lite::arm::math::vactive_f32(vc1); + + if (hidden_prev) { + vpre0 = vld1q_f32(hidden_prev + i); + vpre1 = vld1q_f32(hidden_prev + i + 4); + } + + float32x4_t vh0 = vmlsq_f32(vpre0, vpre0, vu0); + float32x4_t vh1 = vmlsq_f32(vpre1, vpre1, vu1); + + vst1q_f32(cell_state + i, vac0); + vst1q_f32(cell_state + i + 4, vac1); + + vh0 = vmlaq_f32(vh0, vu0, vac0); + vh1 = vmlaq_f32(vh1, vu1, vac1); + + vst1q_f32(hidden + i, vh0); + vst1q_f32(hidden + i + 4, vh1); + } + + for (; i < frame_size; ++i) { + cell_state[i] = lite::arm::math::active_f32(cell_state[i]); + if (hidden_prev) { + prev = hidden_prev[i]; + } + hidden[i] = + prev * (1.f - updata_gate[i]) + updata_gate[i] * cell_state[i]; + } + } + updata_gate += stride_update; + cell_state += stride_cell_state; + if (hidden_prev) { + hidden_prev += stride_hidden_prev; + } + hidden += stride_hidden; + } +} + +inline void gru_unit_reset_act(lite_api::ActivationType act_type, + GRUMetaValue value, + int frame_size, + int batch_size) { + auto updata_gate = value.gate_value; + auto reset_gate = value.gate_value + frame_size; + auto hidden_prev = value.prev_out_value; + auto reset_hidden_prev = value.reset_output_value; + int stride_update = 3 * frame_size; + int stride_reset = 3 * frame_size; + int stride_hidden_prev = frame_size; + int stride_reset_hidden_prev = frame_size; + + switch (act_type) { + case lite_api::ActivationType::kIndentity: + gru_unit_reset_act_impl( + updata_gate, + stride_update, + reset_gate, + stride_reset, + hidden_prev, + stride_hidden_prev, + reset_hidden_prev, + stride_reset_hidden_prev, + frame_size, + batch_size); + break; + case lite_api::ActivationType::kTanh: + gru_unit_reset_act_impl( + updata_gate, + stride_update, + reset_gate, + stride_reset, + hidden_prev, + stride_hidden_prev, + reset_hidden_prev, + stride_reset_hidden_prev, + frame_size, + batch_size); + break; + case lite_api::ActivationType::kSigmoid: + gru_unit_reset_act_impl( + updata_gate, + stride_update, + reset_gate, + stride_reset, + hidden_prev, + stride_hidden_prev, + reset_hidden_prev, + stride_reset_hidden_prev, + frame_size, + batch_size); + break; + case 
lite_api::ActivationType::kRelu: + gru_unit_reset_act_impl( + updata_gate, + stride_update, + reset_gate, + stride_reset, + hidden_prev, + stride_hidden_prev, + reset_hidden_prev, + stride_reset_hidden_prev, + frame_size, + batch_size); + break; + default: + break; + } +} + +inline void gru_unit_out_act(lite_api::ActivationType act_type, + bool origin_mode, + GRUMetaValue value, + int frame_size, + int batch_size) { + auto updata_gate = value.gate_value; + auto cell_state = value.gate_value + 2 * frame_size; + auto hidden_prev = value.prev_out_value; + auto hidden = value.output_value; + + int stride_update = 3 * frame_size; + int stride_cell_state = 3 * frame_size; + int stride_hidden_prev = frame_size; + int stride_hidden = frame_size; + + switch (act_type) { + case lite_api::ActivationType::kIndentity: + gru_unit_out_act_impl( + origin_mode, + updata_gate, + stride_update, + cell_state, + stride_cell_state, + hidden_prev, + stride_hidden_prev, + hidden, + stride_hidden, + frame_size, + batch_size); + break; + case lite_api::ActivationType::kTanh: + gru_unit_out_act_impl(origin_mode, + updata_gate, + stride_update, + cell_state, + stride_cell_state, + hidden_prev, + stride_hidden_prev, + hidden, + stride_hidden, + frame_size, + batch_size); + break; + case lite_api::ActivationType::kSigmoid: + gru_unit_out_act_impl( + origin_mode, + updata_gate, + stride_update, + cell_state, + stride_cell_state, + hidden_prev, + stride_hidden_prev, + hidden, + stride_hidden, + frame_size, + batch_size); + break; + case lite_api::ActivationType::kRelu: + gru_unit_out_act_impl(origin_mode, + updata_gate, + stride_update, + cell_state, + stride_cell_state, + hidden_prev, + stride_hidden_prev, + hidden, + stride_hidden, + frame_size, + batch_size); + break; + default: + break; + } +} + +template +struct GRUUnitFunctor { + static void compute(GRUMetaValue value, + int frame_size, + int batch_size, + const lite_api::ActivationType active_node, + const lite_api::ActivationType active_gate, + bool origin_mode, + ARMContext* ctx) { + if (value.prev_out_value) { + sgemm(false, + false, + batch_size, + frame_size * 2, + frame_size, + 1.f, + value.prev_out_value, + frame_size, + value.gate_weight, + frame_size * 2, + 1.f, + value.gate_value, + frame_size * 3, + nullptr, + false, + false, + ctx); + } + gru_unit_reset_act(active_gate, value, frame_size, batch_size); + + if (value.prev_out_value) { + sgemm(false, + false, + batch_size, + frame_size, + frame_size, + 1.f, + value.reset_output_value, + frame_size, + value.state_weight, + frame_size, + 1.f, + value.gate_value + frame_size * 2, + frame_size * 3, + nullptr, + false, + false, + ctx); + } + + gru_unit_out_act(active_node, origin_mode, value, frame_size, batch_size); + } +}; + +} // namespace math +} // namespace arm +} // namespace lite +} // namespace paddle diff --git a/lite/arm/math/im2sequence.cc b/lite/arm/math/im2sequence.cc new file mode 100644 index 00000000000..046339ccd70 --- /dev/null +++ b/lite/arm/math/im2sequence.cc @@ -0,0 +1,72 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/arm/math/im2sequence.h" +#include +#include "lite/utils/cp_logging.h" + +namespace paddle { +namespace lite { +namespace arm { +namespace math { + +void im2sequence(const float* input, + const int input_c, + const int input_h, + const int input_w, + const int kernel_h, + const int kernel_w, + const int pad_top, + const int pad_bottom, + const int pad_left, + const int pad_right, + const int stride_h, + const int stride_w, + const int out_h, + const int out_w, + float* out, + Context* ctx) { + int window_size = kernel_h * kernel_w; + int out_rows = out_h * out_w; + int out_cols = input_c * window_size; + int H_pad = input_h + pad_top + pad_bottom; + int W_pad = input_w + pad_left + pad_right; + for (int h_id = 0; h_id < out_h; h_id++) { + for (int w_id = 0; w_id < out_w; w_id++) { + // consider dilation. + int start_h = h_id * stride_h - pad_top; + int start_w = w_id * stride_w - pad_left; + for (int c_id = 0; c_id < input_c; c_id++) { + for (int k_h_id = 0; k_h_id < kernel_h; k_h_id++) { + int in_h_id = start_h + k_h_id; + bool exceed_flag = (in_h_id < 0) || (in_h_id >= H_pad); + int out_start_id = + (h_id * out_w + w_id) * out_cols + c_id * window_size; + for (int k_w_id = 0; k_w_id < kernel_w; k_w_id++) { + int in_w_id = start_w + k_w_id; + exceed_flag = exceed_flag || (in_w_id < 0) || (in_w_id >= W_pad); + int input_id = (c_id * input_h + in_h_id) * input_w + in_w_id; + int out_id = out_start_id + k_h_id * kernel_w + k_w_id; + out[out_id] = exceed_flag ? 0.f : input[input_id]; + } + } + } + } + } +} + +} // namespace math +} // namespace arm +} // namespace lite +} // namespace paddle diff --git a/lite/arm/math/im2sequence.h b/lite/arm/math/im2sequence.h new file mode 100644 index 00000000000..5fd06c26088 --- /dev/null +++ b/lite/arm/math/im2sequence.h @@ -0,0 +1,44 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
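For reference, the sliding-window arithmetic that the im2sequence kernel above relies on can be checked with a small standalone program; conv_out_size and the sample shapes below are illustrative only, not part of this patch. Each output row holds one kernel window flattened to input_c * kernel_h * kernel_w values, with out-of-range (padded) taps written as 0.f.

#include <cstdio>

// Same relation the kernel assumes: windows are placed every `stride`
// pixels over the padded extent (in + pad0 + pad1).
static int conv_out_size(int in, int pad0, int pad1, int kernel, int stride) {
  return (in + pad0 + pad1 - kernel) / stride + 1;
}

int main() {
  int input_h = 48, input_w = 160, input_c = 8;
  int kernel_h = 8, kernel_w = 8, stride_h = 8, stride_w = 8;
  int out_h = conv_out_size(input_h, 0, 0, kernel_h, stride_h);  // 6
  int out_w = conv_out_size(input_w, 0, 0, kernel_w, stride_w);  // 20
  // im2sequence then emits a matrix of shape:
  //   rows = out_h * out_w, cols = input_c * kernel_h * kernel_w
  std::printf("rows=%d cols=%d\n",
              out_h * out_w, input_c * kernel_h * kernel_w);
  return 0;
}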
+ +#pragma once + +#include +#include "lite/core/context.h" + +namespace paddle { +namespace lite { +namespace arm { +namespace math { +void im2sequence(const float* input, + const int input_c, + const int input_h, + const int input_w, + const int kernel_h, + const int kernel_w, + const int pad_top, + const int pad_bottom, + const int pad_left, + const int pad_right, + const int stride_h, + const int stride_w, + const int out_h, + const int out_w, + float* out, + Context* ctx); + +} // namespace math +} // namespace arm +} // namespace lite +} // namespace paddle diff --git a/lite/arm/math/increment.cc b/lite/arm/math/increment.cc new file mode 100644 index 00000000000..d44cb38ac99 --- /dev/null +++ b/lite/arm/math/increment.cc @@ -0,0 +1,37 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/arm/math/increment.h" +#include +#include +#include "lite/utils/cp_logging.h" + +namespace paddle { +namespace lite { +namespace arm { +namespace math { +void increment(const int* input, + const int n, + const float step, + int* out, + Context* ctx) { + for (int i = 0; i < n; i++) { + out[i] = input[i] + step; + } +} + +} // namespace math +} // namespace arm +} // namespace lite +} // namespace paddle diff --git a/lite/arm/math/increment.h b/lite/arm/math/increment.h new file mode 100644 index 00000000000..80aec628854 --- /dev/null +++ b/lite/arm/math/increment.h @@ -0,0 +1,33 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include "lite/core/context.h" + +namespace paddle { +namespace lite { +namespace arm { +namespace math { +void increment(const int* input, + const int n, + const float step, + int* out, + Context* ctx); + +} // namespace math +} // namespace arm +} // namespace lite +} // namespace paddle diff --git a/lite/arm/math/interpolate.cc b/lite/arm/math/interpolate.cc new file mode 100644 index 00000000000..8a4a07d6008 --- /dev/null +++ b/lite/arm/math/interpolate.cc @@ -0,0 +1,514 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/arm/math/interpolate.h" +#include +#include +#include "lite/arm/math/funcs.h" + +namespace paddle { +namespace lite { +namespace arm { +namespace math { + +void bilinear_interp(const float* src, + int w_in, + int h_in, + float* dst, + int w_out, + int h_out, + float scale_x, + float scale_y, + bool with_align) { + int* buf = new int[w_out + h_out + w_out * 2 + h_out * 2]; + + int* xofs = buf; + int* yofs = buf + w_out; + + float* alpha = reinterpret_cast(buf + w_out + h_out); + float* beta = reinterpret_cast(buf + w_out + h_out + w_out * 2); + + float fx = 0.0f; + float fy = 0.0f; + int sx = 0; + int sy = 0; + if (with_align) { + scale_x = static_cast(w_in - 1) / (w_out - 1); + scale_y = static_cast(h_in - 1) / (h_out - 1); + // calculate x axis coordinate + for (int dx = 0; dx < w_out; dx++) { + fx = dx * scale_x; + sx = static_cast(fx); + fx -= sx; + xofs[dx] = sx; + alpha[dx * 2] = 1.f - fx; + alpha[dx * 2 + 1] = fx; + } + // calculate y axis coordinate + for (int dy = 0; dy < h_out; dy++) { + fy = dy * scale_y; + sy = static_cast(fy); + fy -= sy; + yofs[dy] = sy; + beta[dy * 2] = 1.f - fy; + beta[dy * 2 + 1] = fy; + } + } else { + scale_x = static_cast(w_in / w_out); + scale_y = static_cast(h_in / h_out); + // calculate x axis coordinate + for (int dx = 0; dx < w_out; dx++) { + fx = scale_x * (dx + 0.5f) - 0.5f; + fx = fx < 0 ? 0.f : fx; + sx = static_cast(fx); + fx -= sx; + xofs[dx] = sx; + alpha[dx * 2] = 1.f - fx; + alpha[dx * 2 + 1] = fx; + } + // calculate y axis coordinate + for (int dy = 0; dy < h_out; dy++) { + fy = scale_y * (dy + 0.5f) - 0.5f; + fy = fy < 0 ? 
0.f : fy; + sy = static_cast(fy); + fy -= sy; + yofs[dy] = sy; + beta[dy * 2] = 1.f - fy; + beta[dy * 2 + 1] = fy; + } + } + float* rowsbuf0 = new float[w_out]; + float* rowsbuf1 = new float[w_out]; + float* rows0 = rowsbuf0; + float* rows1 = rowsbuf1; + // output w , h boundary + int w_bound = w_out; + int h_bound = h_out; + if (with_align) { + w_bound = ceil((w_in - 1) / scale_x); + h_bound = ceil((h_in - 1) / scale_y); + } else { + w_bound = ceil((w_in - 0.5f) / scale_x - 0.5f); + h_bound = ceil((h_in - 0.5f) / scale_y - 0.5f); + } + // h_bound loop + for (int dy = 0; dy < h_bound; dy++) { + int sy = yofs[dy]; + + const float* s0 = src + sy * w_in; + const float* s1 = src + (sy + 1) * w_in; + + const float* alphap = alpha; + float* rows0p = rows0; + float* rows1p = rows1; + + int dx = 0; + // w_bound loop + for (; dx + 1 < w_bound; dx += 2) { + int sx = xofs[dx]; + int sxn = xofs[dx + 1]; + const float* s0p = s0 + sx; + const float* s1p = s1 + sx; + const float* s0np = s0 + sxn; + const float* s1np = s1 + sxn; + + float32x4_t _a = vld1q_f32(alphap); + float32x2_t _s0 = vld1_f32(s0p); + float32x2_t _s1 = vld1_f32(s1p); + float32x2_t _s0n = vld1_f32(s0np); + float32x2_t _s1n = vld1_f32(s1np); + + float32x4_t _s0s0n = vcombine_f32(_s0, _s0n); + float32x4_t _ms0 = vmulq_f32(_s0s0n, _a); + float32x4_t _s1s1n = vcombine_f32(_s1, _s1n); + float32x4_t _ms1 = vmulq_f32(_s1s1n, _a); + + float32x2_t _rows0 = vpadd_f32(vget_low_f32(_ms0), vget_high_f32(_ms0)); + vst1_f32(rows0p + dx, _rows0); + float32x2_t _rows1 = vpadd_f32(vget_low_f32(_ms1), vget_high_f32(_ms1)); + vst1_f32(rows1p + dx, _rows1); + + alphap += 4; + } + // w_bound remain loop + for (; dx < w_bound; dx++) { + int sx = xofs[dx]; + const float* s0p = s0 + sx; + const float* s1p = s1 + sx; + + float a0 = alphap[0]; + float a1 = alphap[1]; + rows0p[dx] = s0p[0] * a0 + s0p[1] * a1; + rows1p[dx] = s1p[0] * a0 + s1p[1] * a1; + + alphap += 2; + } + + const float buffer1[2] = {*(src + sy * w_in + w_in - 1), + *(src + sy * w_in + w_in - 1)}; + const float buffer2[2] = {*(src + (sy + 1) * w_in + w_in - 1), + *(src + (sy + 1) * w_in + w_in - 1)}; + // w_bound - w_out loop + for (; dx + 1 < w_out; dx += 2) { + const float* s0p = buffer1; + const float* s1p = buffer2; + const float* s0np = buffer1; + const float* s1np = buffer2; + + float32x4_t _a = vld1q_f32(alphap); + float32x2_t _s0 = vld1_f32(s0p); + float32x2_t _s1 = vld1_f32(s1p); + float32x2_t _s0n = vld1_f32(s0np); + float32x2_t _s1n = vld1_f32(s1np); + + float32x4_t _s0s0n = vcombine_f32(_s0, _s0n); + float32x4_t _ms0 = vmulq_f32(_s0s0n, _a); + float32x4_t _s1s1n = vcombine_f32(_s1, _s1n); + float32x4_t _ms1 = vmulq_f32(_s1s1n, _a); + + float32x2_t _rows0 = vpadd_f32(vget_low_f32(_ms0), vget_high_f32(_ms0)); + vst1_f32(rows0p + dx, _rows0); + float32x2_t _rows1 = vpadd_f32(vget_low_f32(_ms1), vget_high_f32(_ms1)); + vst1_f32(rows1p + dx, _rows1); + + alphap += 4; + } + // w_bound - w_out remain loop + for (; dx < w_out; dx++) { + const float* s0p = buffer1; + const float* s1p = buffer2; + + float a0 = alphap[0]; + float a1 = alphap[1]; + rows0p[dx] = s0p[0] * a0 + s0p[1] * a1; + rows1p[dx] = s1p[0] * a0 + s1p[1] * a1; + + alphap += 2; + } + + float b0 = beta[0]; + float b1 = beta[1]; + + float* dp = dst + dy * w_out; + + int nn = w_out >> 3; + int remain = w_out - (nn << 3); + +#ifdef __aarch64__ + float32x4_t _b0 = vdupq_n_f32(b0); + float32x4_t _b1 = vdupq_n_f32(b1); + // calculate and store results + for (; nn > 0; nn--) { + float32x4_t _rows0 = vld1q_f32(rows0p); + float32x4_t _d = 
vmulq_f32(_rows0, _b0); + float32x4_t _rows1 = vld1q_f32(rows1p); + _d = vmlaq_f32(_d, _rows1, _b1); + + float32x4_t _rows0n = vld1q_f32(rows0p + 4); + float32x4_t _rows1n = vld1q_f32(rows1p + 4); + + float32x4_t _dn = vmulq_f32(_rows0n, _b0); + vst1q_f32(dp, _d); + _dn = vmlaq_f32(_dn, _rows1n, _b1); + vst1q_f32(dp + 4, _dn); + + dp += 8; + rows0p += 8; + rows1p += 8; + } + +#else + if (nn > 0) { + asm volatile( + "vdup.32 q0, %[b0] @dup b0 to q1\n" + "vdup.32 q1, %[b1] @dup b1 to q0\n" + "1: \n" + "vld1.32 {d4-d5}, [%[rows0p]]! @loads rows0p to q2\n" + "vld1.32 {d6-d7}, [%[rows1p]]! @loads rows0p to q3\n" + "vmul.f32 q2, q2, q0 @mul\n" + "vmla.f32 q2, q3, q1 @mul add\n" + "vst1.32 {d4-d5}, [%[out]]! @store out to q2 \n" + "pld [%[rows0p]] @preload rows0p\n" + + "vld1.32 {d4-d5}, [%[rows0p]]! @loads rows0p to q2\n" + "vld1.32 {d6-d7}, [%[rows1p]]! @load rows1p to q3\n" + "vmul.f32 q2, q2, q0 @mul\n" + "vmla.f32 q2, q3, q1 @mul add\n" + "vst1.32 {d4-d5}, [%[out]]! @store out to q2 \n" + "pld [%[rows1p]] @preload rows1p\n" + "subs %[loopc], #1 @loop count minus #1\n" + "bne 1b @jump to 1\n" + : [rows0p] "+r"(rows0p), + [rows1p] "+r"(rows1p), + [out] "+r"(dp), + [loopc] "+r"(nn) + : [b0] "r"(b0), [b1] "r"(b1) + : "cc", "memory", "q0", "q1", "q2", "q3"); + } +#endif + // calculate and store remain resluts + for (; remain; --remain) { + *dp++ = *rows0p++ * b0 + *rows1p++ * b1; + } + beta += 2; + } + + // h_bound - h_out loop + for (int dy = h_bound; dy < h_out; dy++) { + int sy = h_in - 1; + const float* s0 = src + sy * w_in; + const float* s1 = s0; + const float* alphap = alpha; + float* rows0p = rows0; + float* rows1p = rows1; + + int dx = 0; + // w_bound loop + for (; dx + 1 < w_bound; dx += 2) { + int sx = xofs[dx]; + int sxn = xofs[dx + 1]; + const float* s0p = s0 + sx; + const float* s1p = s1 + sx; + const float* s0np = s0 + sxn; + const float* s1np = s1 + sxn; + + float32x4_t _a = vld1q_f32(alphap); + float32x2_t _s0 = vld1_f32(s0p); + float32x2_t _s1 = vld1_f32(s1p); + float32x2_t _s0n = vld1_f32(s0np); + float32x2_t _s1n = vld1_f32(s1np); + + float32x4_t _s0s0n = vcombine_f32(_s0, _s0n); + float32x4_t _ms0 = vmulq_f32(_s0s0n, _a); + float32x4_t _s1s1n = vcombine_f32(_s1, _s1n); + float32x4_t _ms1 = vmulq_f32(_s1s1n, _a); + + float32x2_t _rows0 = vpadd_f32(vget_low_f32(_ms0), vget_high_f32(_ms0)); + vst1_f32(rows0p + dx, _rows0); + float32x2_t _rows1 = vpadd_f32(vget_low_f32(_ms1), vget_high_f32(_ms1)); + vst1_f32(rows1p + dx, _rows1); + + alphap += 4; + } + // w_bound remain loop + for (; dx < w_bound; dx++) { + int sx = xofs[dx]; + const float* s0p = s0 + sx; + float a0 = alphap[0]; + float a1 = alphap[1]; + rows0p[dx] = s0p[0] * a0 + s0p[1] * a1; + rows1p[dx] = rows0p[dx]; + + alphap += 2; + } + + const float buffer1[2] = {*(src + sy * w_in + w_in - 1), + *(src + sy * w_in + w_in - 1)}; + // w_bound - w_out loop + for (; dx + 1 < w_out; dx += 2) { + const float* s0p = buffer1; + const float* s1p = buffer1; + const float* s0np = buffer1; + const float* s1np = buffer1; + + float32x4_t _a = vld1q_f32(alphap); + float32x2_t _s0 = vld1_f32(s0p); + float32x2_t _s1 = vld1_f32(s1p); + float32x2_t _s0n = vld1_f32(s0np); + float32x2_t _s1n = vld1_f32(s1np); + + float32x4_t _s0s0n = vcombine_f32(_s0, _s0n); + float32x4_t _ms0 = vmulq_f32(_s0s0n, _a); + float32x4_t _s1s1n = vcombine_f32(_s1, _s1n); + float32x4_t _ms1 = vmulq_f32(_s1s1n, _a); + + float32x2_t _rows0 = vpadd_f32(vget_low_f32(_ms0), vget_high_f32(_ms0)); + vst1_f32(rows0p + dx, _rows0); + float32x2_t _rows1 = 
vpadd_f32(vget_low_f32(_ms1), vget_high_f32(_ms1)); + vst1_f32(rows1p + dx, _rows1); + + alphap += 4; + } + // w_bound - wout remain loop + for (; dx < w_out; dx++) { + const float* s0p = buffer1; + float a0 = alphap[0]; + float a1 = alphap[1]; + rows0p[dx] = s0p[0] * a0 + s0p[1] * a1; + rows1p[dx] = rows0p[dx]; + alphap += 2; + } + + float b0 = beta[0]; + float b1 = beta[1]; + + float* dp = dst + dy * w_out; + + int nn = w_out >> 3; + int remain = w_out - (nn << 3); + +#ifdef __aarch64__ + float32x4_t _b0 = vdupq_n_f32(b0); + float32x4_t _b1 = vdupq_n_f32(b1); + // calculate and store results + for (; nn > 0; nn--) { + float32x4_t _rows0 = vld1q_f32(rows0p); + float32x4_t _d = vmulq_f32(_rows0, _b0); + float32x4_t _rows1 = vld1q_f32(rows1p); + _d = vmlaq_f32(_d, _rows1, _b1); + + float32x4_t _rows0n = vld1q_f32(rows0p + 4); + float32x4_t _rows1n = vld1q_f32(rows1p + 4); + + float32x4_t _dn = vmulq_f32(_rows0n, _b0); + vst1q_f32(dp, _d); + _dn = vmlaq_f32(_dn, _rows1n, _b1); + vst1q_f32(dp + 4, _dn); + + dp += 8; + rows0p += 8; + rows1p += 8; + } + +#else + if (nn > 0) { + asm volatile( + "vdup.32 q0, %[b0] @dup b0 to q1\n" + "vdup.32 q1, %[b1] @dup b1 to q0\n" + "1: \n" + "vld1.32 {d4-d5}, [%[rows0p]]! @loads rows0p to q2\n" + "vld1.32 {d6-d7}, [%[rows1p]]! @loads rows0p to q3\n" + "vmul.f32 q2, q2, q0 @mul\n" + "vmla.f32 q2, q3, q1 @mul add\n" + "vst1.32 {d4-d5}, [%[out]]! @store out to q2 \n" + "pld [%[rows0p]] @preload rows0p\n" + + "vld1.32 {d4-d5}, [%[rows0p]]! @loads rows0p to q2\n" + "vld1.32 {d6-d7}, [%[rows1p]]! @load rows1p to q3\n" + "vmul.f32 q2, q2, q0 @mul\n" + "vmla.f32 q2, q3, q1 @mul add\n" + "vst1.32 {d4-d5}, [%[out]]! @store out to q2 \n" + "pld [%[rows1p]] @preload rows1p\n" + "subs %[loopc], #1 @loop count minus #1\n" + "bne 1b @jump to 1\n" + : [rows0p] "+r"(rows0p), + [rows1p] "+r"(rows1p), + [out] "+r"(dp), + [loopc] "+r"(nn) + : [b0] "r"(b0), [b1] "r"(b1) + : "cc", "memory", "q0", "q1", "q2", "q3"); + } +#endif + // calculate and store remain results + for (; remain; --remain) { + *dp++ = *rows0p++ * b0 + *rows1p++ * b1; + } + + beta += 2; + } + delete[] buf; + delete[] rowsbuf0; + delete[] rowsbuf1; +} + +void nearest_interp(const float* src, + int w_in, + int h_in, + float* dst, + int w_out, + int h_out, + float scale_x, + float scale_y, + bool with_align) { + float scale_w_new = (with_align) + ? (static_cast(w_in - 1) / (w_out - 1)) + : (static_cast(w_in) / (w_out)); + float scale_h_new = (with_align) + ? (static_cast(h_in - 1) / (h_out - 1)) + : (static_cast(h_in) / (h_out)); + +#pragma omp parallel for collapse(2) schedule(static) + for (int h = 0; h < h_out; ++h) { + for (int w = 0; w < w_out; ++w) { + int near_x = static_cast(scale_w_new * w + 0.5); + int near_y = static_cast(scale_h_new * h + 0.5); + near_x = near_x < 0 ? 0 : near_x; + near_y = near_y < 0 ? 
0 : near_y; + dst[h * w_out + w] = src[near_y * w_in + near_x]; + } + } +} + +void interpolate(lite::Tensor* X, + lite::Tensor* OutSize, + lite::Tensor* Out, + int out_height, + int out_width, + float height_scale, + float width_scale, + bool with_align, + std::string interpolate_type) { + if (out_width > 0 && out_height > 0) { + height_scale = static_cast(out_height / X->dims()[2]); + width_scale = static_cast(out_width / X->dims()[3]); + } + if (OutSize != nullptr) { + auto OutSize_data = OutSize->data(); + int h_out = OutSize_data[0]; // HW + int w_out = OutSize_data[1]; // HW + int num_cout = Out->dims()[0]; + int c_cout = Out->dims()[1]; + Out->Resize({num_cout, c_cout, h_out, w_out}); + } + + float* dout = Out->mutable_data(); + const float* din = X->data(); + int out_num = Out->dims()[0]; + int out_c = Out->dims()[1]; + int count = out_num * out_c; + int in_h = X->dims()[2]; + int in_w = X->dims()[3]; + int out_h = Out->dims()[2]; + int out_w = Out->dims()[3]; + int spatial_in = in_h * in_w; + int spatial_out = out_h * out_w; + for (int i = 0; i < count; ++i) { + if ("Bilinear" == interpolate_type) { + bilinear_interp(din + spatial_in * i, + in_w, + in_h, + dout + spatial_out * i, + out_w, + out_h, + 1.f / width_scale, + 1.f / height_scale, + with_align); + } else if ("Nearest" == interpolate_type) { + nearest_interp(din + spatial_in * i, + in_w, + in_h, + dout + spatial_out * i, + out_w, + out_h, + 1.f / width_scale, + 1.f / height_scale, + with_align); + } + } +} + +} /* namespace math */ +} /* namespace arm */ +} /* namespace lite */ +} /* namespace paddle */ diff --git a/lite/arm/math/interpolate.h b/lite/arm/math/interpolate.h new file mode 100644 index 00000000000..be250f6a5e7 --- /dev/null +++ b/lite/arm/math/interpolate.h @@ -0,0 +1,58 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include +#include +#include "lite/core/tensor.h" + +namespace paddle { +namespace lite { +namespace arm { +namespace math { + +void bilinear_interp(const float* src, + int w_in, + int h_in, + float* dst, + int w_out, + int h_out, + float scale_x, + float scale_y, + bool with_align); + +void nearest_interp(const float* src, + int w_in, + int h_in, + float* dst, + int w_out, + int h_out, + float scale_x, + float scale_y, + bool with_align); + +void interpolate(lite::Tensor* X, + lite::Tensor* OutSize, + lite::Tensor* Out, + int out_height, + int out_width, + float height_scale, + float width_scale, + bool with_align, + std::string interpolate_type); + +} /* namespace math */ +} /* namespace arm */ +} /* namespace lite */ +} /* namespace paddle */ diff --git a/lite/arm/math/lrn.cc b/lite/arm/math/lrn.cc new file mode 100644 index 00000000000..4297c221eab --- /dev/null +++ b/lite/arm/math/lrn.cc @@ -0,0 +1,101 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/arm/math/lrn.h" +#include "lite/arm/math/funcs.h" + +namespace paddle { +namespace lite { +namespace arm { +namespace math { + +template <> +void compute_across_channels(const float* din, + float* dout, + int num, + int channel, + int h, + int w, + int local_size, + float alpha, + float beta, + float k) { + int channel_size = h * w; + int cnt = channel_size / 4; + int remain = channel_size % 4; + int pre_pad = (local_size - 1) / 2; + int post_pad = local_size - pre_pad - 1; + float32x4_t k_val = vdupq_n_f32(k); + float32x4_t alpha_val = vdupq_n_f32(alpha); + float32x4_t beta_val = vdupq_n_f32(-beta); + for (int n = 0; n < num; ++n) { + const float* din_ptr = din + n * channel * channel_size; + float* dout_ptr = dout + n * channel * channel_size; + for (int c = 0; c < channel; ++c) { + const float* din_ch_ptr = din_ptr + c * channel_size; + float* dout_ch_ptr = dout_ptr + c * channel_size; + int cs = (c - pre_pad) < 0 ? 0 : (c - pre_pad); + int ce = (c + post_pad) >= channel ? channel : (c + pre_pad + 1); + for (int i = 0; i < cnt; ++i) { + int idx = i * 4; + float32x4_t sum = vdupq_n_f32(0.f); + float32x4_t din = vld1q_f32(din_ch_ptr); + for (int k = cs; k < ce; ++k) { + float32x4_t v0 = vld1q_f32(&din_ptr[k * channel_size + idx]); + sum = vmlaq_f32(sum, v0, v0); + } + sum = vmulq_f32(sum, alpha_val); + sum = vaddq_f32(sum, k_val); + float32x4_t res0 = pow_ps(sum, beta_val); + float32x4_t res1 = vmulq_f32(din, res0); + vst1q_f32(dout_ch_ptr, res1); + dout_ch_ptr += 4; + din_ch_ptr += 4; + } + int idx = cnt * 4; + for (int i = 0; i < remain; ++i) { + float sum = 0.0; + for (int k = cs; k < ce; ++k) { + sum += + din_ptr[k * channel_size + idx] * din_ptr[k * channel_size + idx]; + } + sum = k + sum * alpha; + dout_ch_ptr[0] = din_ch_ptr[0] * pow(sum, -beta); + dout_ch_ptr++; + din_ch_ptr++; + idx++; + } + } + } +} + +template <> +void compute_within_channels(const float* din, + float* dout, + int num, + int channel, + int h, + int w, + int local_size, + float alpha, + float beta, + float k) { + LOG(ERROR) << "unsupported method!!"; + return; +} + +} // namespace math +} // namespace arm +} // namespace lite +} // namespace paddle diff --git a/lite/arm/math/lrn.h b/lite/arm/math/lrn.h new file mode 100644 index 00000000000..03551231893 --- /dev/null +++ b/lite/arm/math/lrn.h @@ -0,0 +1,49 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
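The across-channels path above computes out[c] = in[c] * (k + alpha * sum of in[c']^2)^(-beta) over a window of local_size channels centered at c. A scalar reference is handy for unit-checking the NEON kernel; lrn_across_channels_ref is a hypothetical helper, not part of this patch, and it assumes an odd local_size so the window extends pre_pad channels on each side, matching the kernel's index math.

#include <cmath>
#include <cstdio>
#include <vector>

// One output element of across-channel LRN at a fixed spatial position,
// given the values of all channels at that position.
float lrn_across_channels_ref(const std::vector<float>& in_at_pixel,
                              int c, int local_size,
                              float alpha, float beta, float k) {
  int channels = static_cast<int>(in_at_pixel.size());
  int pre_pad = (local_size - 1) / 2;
  int cs = c - pre_pad < 0 ? 0 : c - pre_pad;
  int ce = c + pre_pad + 1 > channels ? channels : c + pre_pad + 1;
  float sum = 0.f;
  for (int cc = cs; cc < ce; ++cc) sum += in_at_pixel[cc] * in_at_pixel[cc];
  return in_at_pixel[c] * std::pow(k + alpha * sum, -beta);
}

int main() {
  std::vector<float> px = {0.5f, 1.0f, -0.5f, 0.25f};  // one pixel, 4 channels
  std::printf("%f\n", lrn_across_channels_ref(px, 1, 3, 1e-4f, 0.75f, 2.f));
  return 0;
}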
+ +#pragma once + +namespace paddle { +namespace lite { +namespace arm { +namespace math { + +template +void compute_across_channels(const T* din, + T* dout, + int num, + int channel, + int h, + int w, + int local_size, + float alpha, + float beta, + float k); + +template +void compute_within_channels(const T* din, + T* dout, + int num, + int channel, + int h, + int w, + int local_size, + float alpha, + float beta, + float k); + +} // namespace math +} // namespace arm +} // namespace lite +} // namespace paddle diff --git a/lite/arm/math/multiclass_nms.cc b/lite/arm/math/multiclass_nms.cc new file mode 100644 index 00000000000..3baeb2d8443 --- /dev/null +++ b/lite/arm/math/multiclass_nms.cc @@ -0,0 +1,299 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/arm/math/multiclass_nms.h" +#include "lite/arm/math/funcs.h" + +namespace paddle { +namespace lite { +namespace arm { +namespace math { + +template +static bool sort_score_pair_descend(const std::pair& pair1, + const std::pair& pair2) { + return pair1.first > pair2.first; +} + +template +void get_max_score_index(const dtype* scores, + int num, + float threshold, + int top_k, + std::vector>* score_index_vec) { + //! Generate index score pairs. + for (int i = 0; i < num; ++i) { + if (scores[i] > threshold) { + score_index_vec->push_back(std::make_pair(scores[i], i)); + } + } + + //! Sort the score pair according to the scores in descending order + std::stable_sort(score_index_vec->begin(), + score_index_vec->end(), + sort_score_pair_descend); + + //! Keep top_k scores if needed. + if (top_k > -1 && top_k < score_index_vec->size()) { + score_index_vec->resize(top_k); + } +} + +template +dtype bbox_size(const dtype* bbox, bool normalized = true) { + if (bbox[2] < bbox[0] || bbox[3] < bbox[1]) { + // If bbox is invalid (e.g. xmax < xmin or ymax < ymin), return 0. + return dtype(0.); + } else { + const dtype width = bbox[2] - bbox[0]; + const dtype height = bbox[3] - bbox[1]; + + if (normalized) { + return width * height; + } else { + // If bbox is not within range [0, 1]. 
+ return (width + 1) * (height + 1); + } + } +} + +template +dtype jaccard_overlap(const dtype* bbox1, const dtype* bbox2) { + if (bbox2[0] > bbox1[2] || bbox2[2] < bbox1[0] || bbox2[1] > bbox1[3] || + bbox2[3] < bbox1[1]) { + return dtype(0.); + } else { + const dtype inter_xmin = std::max(bbox1[0], bbox2[0]); + const dtype inter_ymin = std::max(bbox1[1], bbox2[1]); + const dtype inter_xmax = std::min(bbox1[2], bbox2[2]); + const dtype inter_ymax = std::min(bbox1[3], bbox2[3]); + + const dtype inter_width = inter_xmax - inter_xmin; + const dtype inter_height = inter_ymax - inter_ymin; + const dtype inter_size = inter_width * inter_height; + + const dtype bbox1_size = bbox_size(bbox1); + const dtype bbox2_size = bbox_size(bbox2); + + return inter_size / (bbox1_size + bbox2_size - inter_size); + } +} + +template +void apply_nms_fast(const dtype* bboxes, + const dtype* scores, + int num, + float score_threshold, + float nms_threshold, + float eta, + int top_k, + std::vector* indices) { + // Get top_k scores (with corresponding indices). + std::vector> score_index_vec; + get_max_score_index(scores, num, score_threshold, top_k, &score_index_vec); + + // Do nms. + float adaptive_threshold = nms_threshold; + indices->clear(); + + while (score_index_vec.size() != 0) { + const int idx = score_index_vec.front().second; + bool keep = true; + + for (int k = 0; k < indices->size(); ++k) { + if (keep) { + const int kept_idx = (*indices)[k]; + float overlap = + jaccard_overlap(bboxes + idx * 4, bboxes + kept_idx * 4); + keep = overlap <= adaptive_threshold; + } else { + break; + } + } + + if (keep) { + indices->push_back(idx); + } + + score_index_vec.erase(score_index_vec.begin()); + + if (keep && eta < 1 && adaptive_threshold > 0.5) { + adaptive_threshold *= eta; + } + } +} + +template +void multiclass_nms(const dtype* bbox_cpu_data, + const dtype* conf_cpu_data, + std::vector* result, + const std::vector& priors, + int class_num, + int background_id, + int keep_topk, + int nms_topk, + float conf_thresh, + float nms_thresh, + float nms_eta, + bool share_location) { + int num_kept = 0; + std::vector>> all_indices; + int64_t conf_offset = 0; + int64_t bbox_offset = 0; + for (int i = 0; i < priors.size(); ++i) { + std::map> indices; + int num_det = 0; + int num_priors = priors[i]; + + int conf_idx = class_num * conf_offset; + int bbox_idx = + share_location ? bbox_offset * 4 : bbox_offset * 4 * class_num; + + for (int c = 0; c < class_num; ++c) { + if (c == background_id) { + // Ignore background class + continue; + } + + const dtype* cur_conf_data = conf_cpu_data + conf_idx + c * num_priors; + const dtype* cur_bbox_data = bbox_cpu_data + bbox_idx; + + if (!share_location) { + cur_bbox_data += c * num_priors * 4; + } + + apply_nms_fast(cur_bbox_data, + cur_conf_data, + num_priors, + conf_thresh, + nms_thresh, + nms_eta, + nms_topk, + &(indices[c])); + num_det += indices[c].size(); + } + + if (keep_topk > -1 && num_det > keep_topk) { + std::vector>> score_index_pairs; + + for (auto it = indices.begin(); it != indices.end(); ++it) { + int label = it->first; + const std::vector& label_indices = it->second; + + for (int j = 0; j < label_indices.size(); ++j) { + int idx = label_indices[j]; + float score = conf_cpu_data[conf_idx + label * num_priors + idx]; + score_index_pairs.push_back( + std::make_pair(score, std::make_pair(label, idx))); + } + } + + // Keep top k results per image. 
+      std::stable_sort(score_index_pairs.begin(),
+                       score_index_pairs.end(),
+                       sort_score_pair_descend<std::pair<int, int>>);
+      score_index_pairs.resize(keep_topk);
+      // Store the new indices.
+      std::map<int, std::vector<int>> new_indices;
+
+      for (int j = 0; j < score_index_pairs.size(); ++j) {
+        int label = score_index_pairs[j].second.first;
+        int idx = score_index_pairs[j].second.second;
+        new_indices[label].push_back(idx);
+      }
+
+      all_indices.push_back(new_indices);
+      num_kept += keep_topk;
+    } else {
+      all_indices.push_back(indices);
+      num_kept += num_det;
+    }
+    conf_offset += num_priors;
+    bbox_offset += num_priors;
+  }
+
+  if (num_kept == 0) {
+    (*result).clear();
+    return;
+  } else {
+    (*result).resize(num_kept * 7);
+  }
+
+  int count = 0;
+
+  conf_offset = 0;
+  bbox_offset = 0;
+  for (int i = 0; i < priors.size(); ++i) {
+    int num_priors = priors[i];
+    int conf_idx = class_num * conf_offset;
+    int bbox_idx =
+        share_location ? bbox_offset * 4 : bbox_offset * 4 * class_num;
+
+    for (auto it = all_indices[i].begin(); it != all_indices[i].end(); ++it) {
+      int label = it->first;
+      std::vector<int>& indices = it->second;
+      const dtype* cur_conf_data =
+          conf_cpu_data + conf_idx + label * num_priors;
+      const dtype* cur_bbox_data = bbox_cpu_data + bbox_idx;
+
+      if (!share_location) {
+        cur_bbox_data += label * num_priors * 4;
+      }
+
+      for (int j = 0; j < indices.size(); ++j) {
+        int idx = indices[j];
+        (*result)[count * 7] = i;
+        (*result)[count * 7 + 1] = label;
+        (*result)[count * 7 + 2] = cur_conf_data[idx];
+
+        for (int k = 0; k < 4; ++k) {
+          (*result)[count * 7 + 3 + k] = cur_bbox_data[idx * 4 + k];
+        }
+
+        ++count;
+      }
+    }
+    conf_offset += num_priors;
+    bbox_offset += num_priors;
+  }
+}
+
+template float jaccard_overlap(const float* bbox1, const float* bbox2);
+
+template void apply_nms_fast(const float* bboxes,
+                             const float* scores,
+                             int num,
+                             float score_threshold,
+                             float nms_threshold,
+                             float eta,
+                             int top_k,
+                             std::vector<int>* indices);
+
+template void multiclass_nms(const float* bbox_cpu_data,
+                             const float* conf_cpu_data,
+                             std::vector<float>* result,
+                             const std::vector<int>& priors,
+                             int class_num,
+                             int background_id,
+                             int keep_topk,
+                             int nms_topk,
+                             float conf_thresh,
+                             float nms_thresh,
+                             float nms_eta,
+                             bool share_location);
+
+}  // namespace math
+}  // namespace arm
+}  // namespace lite
+}  // namespace paddle
diff --git a/lite/arm/math/multiclass_nms.h b/lite/arm/math/multiclass_nms.h
new file mode 100644
index 00000000000..a5f39b64620
--- /dev/null
+++ b/lite/arm/math/multiclass_nms.h
@@ -0,0 +1,45 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
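apply_nms_fast() above keeps a candidate box only while its Jaccard overlap (IoU) with every already-kept box stays at or below the adaptive threshold. Below is a worked scalar example of that overlap computation; iou_ref and the sample boxes are illustrative, not part of this patch, and boundary handling is simplified to treat touching boxes as non-overlapping.

#include <algorithm>
#include <cstdio>

// Boxes are normalized [xmin, ymin, xmax, ymax], as in the kernel above.
float iou_ref(const float* a, const float* b) {
  float ix0 = std::max(a[0], b[0]), iy0 = std::max(a[1], b[1]);
  float ix1 = std::min(a[2], b[2]), iy1 = std::min(a[3], b[3]);
  if (ix1 <= ix0 || iy1 <= iy0) return 0.f;  // disjoint or touching
  float inter = (ix1 - ix0) * (iy1 - iy0);
  float sa = (a[2] - a[0]) * (a[3] - a[1]);
  float sb = (b[2] - b[0]) * (b[3] - b[1]);
  return inter / (sa + sb - inter);
}

int main() {
  float b1[4] = {0.1f, 0.1f, 0.5f, 0.5f};
  float b2[4] = {0.3f, 0.3f, 0.7f, 0.7f};
  // inter = 0.2*0.2 = 0.04; union = 0.16 + 0.16 - 0.04 = 0.28; IoU = 0.1429
  std::printf("IoU = %.4f\n", iou_ref(b1, b2));
  return 0;
}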
+
+#pragma once
+
+#include
+#include
+#include
+#include
+#include
+
+namespace paddle {
+namespace lite {
+namespace arm {
+namespace math {
+
+template <typename dtype>
+void multiclass_nms(const dtype* bbox_cpu_data,
+                    const dtype* conf_cpu_data,
+                    std::vector<dtype>* result,
+                    const std::vector<int>& priors,
+                    int class_num,
+                    int background_id,
+                    int keep_topk,
+                    int nms_topk,
+                    float conf_thresh,
+                    float nms_thresh,
+                    float nms_eta,
+                    bool share_location);
+
+}  // namespace math
+}  // namespace arm
+}  // namespace lite
+}  // namespace paddle
diff --git a/lite/arm/math/negative.cc b/lite/arm/math/negative.cc
new file mode 100644
index 00000000000..5da2b8e2855
--- /dev/null
+++ b/lite/arm/math/negative.cc
@@ -0,0 +1,37 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lite/arm/math/negative.h"
+#include
+#include
+#include
+#include "lite/arm/math/funcs.h"
+
+namespace paddle {
+namespace lite {
+namespace arm {
+namespace math {
+
+template <>
+void negative_func<float>(const float* din, float* dout, int num) {
+  for (int i = 0; i < num; i++) {
+    dout[i] = -din[i];
+  }
+}
+
+}  // namespace math
+}  // namespace arm
+}  // namespace lite
+}  // namespace paddle
diff --git a/lite/arm/math/negative.h b/lite/arm/math/negative.h
new file mode 100644
index 00000000000..9a5648743da
--- /dev/null
+++ b/lite/arm/math/negative.h
@@ -0,0 +1,33 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include
+#include
+#include
+#include "lite/operators/op_params.h"
+#include "lite/utils/cp_logging.h"
+
+namespace paddle {
+namespace lite {
+namespace arm {
+namespace math {
+
+template <typename T>
+void negative_func(const T* din, T* dout, int num);
+}  // namespace math
+}  // namespace arm
+}  // namespace lite
+}  // namespace paddle
diff --git a/lite/arm/math/norm.cc b/lite/arm/math/norm.cc
new file mode 100644
index 00000000000..978a147c59e
--- /dev/null
+++ b/lite/arm/math/norm.cc
@@ -0,0 +1,52 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lite/arm/math/norm.h"
+#include
+#include
+#include "lite/utils/cp_logging.h"
+
+namespace paddle {
+namespace lite {
+namespace arm {
+namespace math {
+
+void norm(const float* input,
+          const int pre_n,
+          const int n,
+          const int post_n,
+          const float epsilon,
+          float* out,
+          Context<TARGET(kARM)>* ctx) {
+  for (int i = 0; i < pre_n; i++) {
+    for (int k = 0; k < post_n; k++) {
+      float sum = epsilon;
+      const float* in_tmp = input + i * n * post_n + k;
+      for (int j = 0; j < n; j++) {
+        sum += in_tmp[j * post_n] * in_tmp[j * post_n];
+      }
+      sum = std::sqrt(sum);
+      float* out_tmp = out + i * n * post_n + k;
+      for (int j = 0; j < n; j++) {
+        out_tmp[j * post_n] = in_tmp[j * post_n] / sum;
+      }
+    }
+  }
+}
+
+}  // namespace math
+}  // namespace arm
+}  // namespace lite
+}  // namespace paddle
diff --git a/lite/arm/math/norm.h b/lite/arm/math/norm.h
new file mode 100644
index 00000000000..503d2c5af48
--- /dev/null
+++ b/lite/arm/math/norm.h
@@ -0,0 +1,35 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include
+#include "lite/core/context.h"
+
+namespace paddle {
+namespace lite {
+namespace arm {
+namespace math {
+void norm(const float* input,
+          const int pre_n,
+          const int n,
+          const int post_n,
+          const float epsilon,
+          float* out,
+          Context<TARGET(kARM)>* ctx);
+
+}  // namespace math
+}  // namespace arm
+}  // namespace lite
+}  // namespace paddle
diff --git a/lite/arm/math/packed_sgemm.cc b/lite/arm/math/packed_sgemm.cc
new file mode 100644
index 00000000000..bcfb0e2a9f7
--- /dev/null
+++ b/lite/arm/math/packed_sgemm.cc
@@ -0,0 +1,3454 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+
+// http://www.apache.org/licenses/LICENSE-2.0
+
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
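Before the packed SGEMM implementation, a note on the (pre_n, n, post_n) arguments of norm() above: they encode the axis being L2-normalized. For a tensor of shape [d0, ..., dk] normalized along axis a, pre_n is the product of the dims before a, n is d[a], and post_n is the product of the dims after a. A minimal sketch; norm_ref, main, and the sample shape are illustrative, not part of this patch.

#include <cmath>
#include <cstdio>
#include <vector>

// Reference for the (pre_n, n, post_n) indexing used by norm() above.
void norm_ref(const float* in, int pre_n, int n, int post_n,
              float eps, float* out) {
  for (int i = 0; i < pre_n; ++i) {
    for (int k = 0; k < post_n; ++k) {
      const float* x = in + i * n * post_n + k;
      float sum = eps;
      for (int j = 0; j < n; ++j) sum += x[j * post_n] * x[j * post_n];
      float inv = 1.f / std::sqrt(sum);
      float* y = out + i * n * post_n + k;
      for (int j = 0; j < n; ++j) y[j * post_n] = x[j * post_n] * inv;
    }
  }
}

int main() {
  // Shape [N=1, C=3, H*W=2], normalizing along C: pre_n=1, n=3, post_n=2.
  std::vector<float> x = {1, 0, 2, 0, 2, 0}, y(6);
  norm_ref(x.data(), 1, 3, 2, 1e-6f, y.data());
  std::printf("%.3f %.3f %.3f\n", y[0], y[2], y[4]);  // 0.333 0.667 0.667
  return 0;
}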
+
+#include "lite/arm/math/packed_sgemm.h"
+#include <arm_neon.h>
+
+namespace paddle {
+namespace lite {
+namespace arm {
+namespace math {
+
+#ifdef __aarch64__
+void prepackA_8x12(float *out,
+                   const float *in,
+                   float alpha,
+                   int ldin,
+                   int m0,
+                   int mmax,
+                   int k0,
+                   int kmax);
+
+void prepackA_trans_8x12(float *out,
+                         const float *in,
+                         float alpha,
+                         int ldin,
+                         int m0,
+                         int mmax,
+                         int k0,
+                         int kmax);
+
+void sgemm_prepacked_8x12(bool is_transB,
+                          int M,
+                          int N,
+                          int K,
+                          const float *A_packed,
+                          const float *B,
+                          int ldb,
+                          float beta,
+                          float *C,
+                          int ldc,
+                          const float *bias,
+                          bool has_bias,
+                          bool has_relu,
+                          ARMContext *ctx);
+#else
+// for kA72
+void prepackA_6x8(float *out,
+                  const float *in,
+                  float alpha,
+                  int ldin,
+                  int m0,
+                  int mmax,
+                  int k0,
+                  int kmax);
+
+void prepackA_trans_6x8(float *out,
+                        const float *in,
+                        float alpha,
+                        int ldin,
+                        int m0,
+                        int mmax,
+                        int k0,
+                        int kmax);
+// for kA73
+void prepackA_4x8(float *out,
+                  const float *in,
+                  float alpha,
+                  int ldin,
+                  int m0,
+                  int mmax,
+                  int k0,
+                  int kmax);
+
+void prepackA_trans_4x8(float *out,
+                        const float *in,
+                        float alpha,
+                        int ldin,
+                        int m0,
+                        int mmax,
+                        int k0,
+                        int kmax);
+
+// for kA72, 6x8
+void sgemm_prepacked_6x8(bool is_transB,
+                         int M,
+                         int N,
+                         int K,
+                         const float *A_packed,
+                         const float *B,
+                         int ldb,
+                         float beta,
+                         float *C,
+                         int ldc,
+                         const float *bias,
+                         bool has_bias,
+                         bool has_relu,
+                         ARMContext *ctx);
+// for kA73, 4x8
+void sgemm_prepacked_4x8(bool is_transB,
+                         int M,
+                         int N,
+                         int K,
+                         const float *A_packed,
+                         const float *B,
+                         int ldb,
+                         float beta,
+                         float *C,
+                         int ldc,
+                         const float *bias,
+                         bool has_bias,
+                         bool has_relu,
+                         ARMContext *ctx);
+#endif  // __aarch64__
+
+/**
+ * \brief input data is not transposed
+ * for arm-v7a, transform data to block x k x 6 layout
+ * for arm-v8a, transform data to block x k x 8 layout
+ */
+void prepackA(float *out,
+              const float *in,
+              float alpha,
+              int ldin,
+              int m0,
+              int mmax,
+              int k0,
+              int kmax,
+              bool is_trans,
+              ARMContext *ctx) {
+#ifdef __aarch64__
+  if (is_trans) {
+    prepackA_trans_8x12(out, in, alpha, ldin, m0, mmax, k0, kmax);
+  } else {
+    prepackA_8x12(out, in, alpha, ldin, m0, mmax, k0, kmax);
+  }
+#else
+  if (ctx->arch() == kA73) {
+    if (is_trans) {
+      prepackA_trans_4x8(out, in, alpha, ldin, m0, mmax, k0, kmax);
+    } else {
+      prepackA_4x8(out, in, alpha, ldin, m0, mmax, k0, kmax);
+    }
+  } else {
+    if (is_trans) {
+      prepackA_trans_6x8(out, in, alpha, ldin, m0, mmax, k0, kmax);
+    } else {
+      prepackA_6x8(out, in, alpha, ldin, m0, mmax, k0, kmax);
+    }
+  }
+#endif
+}
+
+void prepackA(TensorLite *tout,
+              const TensorLite &tin,
+              float alpha,
+              int m,
+              int k,
+              int group,
+              bool is_trans,
+              ARMContext *ctx) {
+  int hblock = get_hblock(ctx->arch());
+  int m_roundup = hblock * ((m + hblock - 1) / hblock);
+  int group_size_round_up = ((m_roundup * k + 15) / 16) * 16;
+  if (tout->numel() < group_size_round_up * group) {
+    tout->Resize({group_size_round_up * group});
+  }
+  int lda = k;
+  if (is_trans) {
+    lda = m;
+  }
+  for (int g = 0; g < group; ++g) {
+    const float *weights_group = tin.data<float>() + g * m * k;
+    float *weights_trans_ptr =
+        tout->mutable_data<float>() + g * group_size_round_up;
+    prepackA(weights_trans_ptr,
+             weights_group,
+             alpha,
+             lda,
+             0,
+             m,
+             0,
+             k,
+             is_trans,
+             ctx);
+  }
+}
+
+/// a: m*k  b: k*n  c: m*n
+void sgemm_prepack(bool is_transB,
+                   int M,
+                   int N,
+                   int K,
+                   const float *A_packed,
+                   const float *B,
+                   int ldb,
+                   float beta,
+                   float *C,
+                   int ldc,
+                   const float *bias,
+                   bool has_bias,
+                   bool has_relu,
ARMContext *ctx) { +#ifdef __aarch64__ + sgemm_prepacked_8x12(is_transB, + M, + N, + K, + A_packed, + B, + ldb, + beta, + C, + ldc, + bias, + has_bias, + has_relu, + ctx); +#else // armv7 + if (ctx->arch() == kA73) { + sgemm_prepacked_4x8(is_transB, + M, + N, + K, + A_packed, + B, + ldb, + beta, + C, + ldc, + bias, + has_bias, + has_relu, + ctx); + } else { + sgemm_prepacked_6x8(is_transB, + M, + N, + K, + A_packed, + B, + ldb, + beta, + C, + ldc, + bias, + has_bias, + has_relu, + ctx); + } +#endif // arm64 +} + +#ifdef __aarch64__ +void prepackA_8x12(float *dout, + const float *inptr, + float alpha, + int ldin, + int m0, + int mmax, + int k0, + int kmax) { + int x_len = kmax - k0; + int stride = x_len * 8; + float zerobuff[x_len]; // NOLINT + memset(zerobuff, 0, sizeof(float) * x_len); + bool has_alpha = fabsf(alpha - 1.f) > 1e-8f; + +#pragma omp parallel for + for (int y = m0; y < mmax; y += 8) { + float *outptr = dout + stride * (y - m0) / 8; + + const float *inptr0 = inptr + y * ldin + k0; + const float *inptr1 = inptr0 + ldin; + const float *inptr2 = inptr1 + ldin; + const float *inptr3 = inptr2 + ldin; + const float *inptr4 = inptr3 + ldin; + const float *inptr5 = inptr4 + ldin; + const float *inptr6 = inptr5 + ldin; + const float *inptr7 = inptr6 + ldin; + + asm volatile( + "prfm pldl1keep, [%[ptr0]] \n" + "prfm pldl1keep, [%[ptr0], #64] \n" + "prfm pldl1keep, [%[ptr1]] \n" + "prfm pldl1keep, [%[ptr1], #64] \n" + "prfm pldl1keep, [%[ptr2]] \n" + "prfm pldl1keep, [%[ptr2], #64] \n" + "prfm pldl1keep, [%[ptr3]] \n" + "prfm pldl1keep, [%[ptr3], #64] \n" + "prfm pldl1keep, [%[ptr4]] \n" + "prfm pldl1keep, [%[ptr4], #64] \n" + "prfm pldl1keep, [%[ptr5]] \n" + "prfm pldl1keep, [%[ptr5], #64] \n" + "prfm pldl1keep, [%[ptr6]] \n" + "prfm pldl1keep, [%[ptr6], #64] \n" + "prfm pldl1keep, [%[ptr7]] \n" + "prfm pldl1keep, [%[ptr7], #64] \n" + : + : [ptr0] "r"(inptr0), + [ptr1] "r"(inptr1), + [ptr2] "r"(inptr2), + [ptr3] "r"(inptr3), + [ptr4] "r"(inptr4), + [ptr5] "r"(inptr5), + [ptr6] "r"(inptr6), + [ptr7] "r"(inptr7) + : "memory"); + + int x = x_len; + //! cope with row index exceed real size, set to zero buffer + if ((y + 7) >= mmax) { + switch ((y + 7) - mmax) { + case 6: + inptr1 = zerobuff; + case 5: + inptr2 = zerobuff; + case 4: + inptr3 = zerobuff; + case 3: + inptr4 = zerobuff; + case 2: + inptr5 = zerobuff; + case 1: + inptr6 = zerobuff; + case 0: + inptr7 = zerobuff; + default: + break; + } + } + for (; x > 7; x -= 8) { + asm volatile( + "cbz %w[has_alpha], 0f\n" /* check alpha == 1.f? 
*/ + "dup v31.4s, %w[alpha]\n" /* alpha to vector */ + "ldp q0, q1, [%[inptr0]], #32\n" /* load r0, a0~a7 */ + "ldp q2, q3, [%[inptr1]], #32\n" /* load r1, b0~b7 */ + "fmul v0.4s, v31.4s, v0.4s\n" /* mul alpha */ + "fmul v1.4s, v31.4s, v1.4s\n" /* mul alpha */ + "ldp q4, q5, [%[inptr2]], #32\n" /* load r2, c0~c7 */ + "fmul v2.4s, v31.4s, v2.4s\n" /* mul alpha */ + "fmul v3.4s, v31.4s, v3.4s\n" /* mul alpha */ + "ldp q6, q7, [%[inptr3]], #32\n" /* load r3, d0~d7 */ + "fmul v4.4s, v31.4s, v4.4s\n" /* mul alpha */ + "fmul v5.4s, v31.4s, v5.4s\n" /* mul alpha */ + "ldp q8, q9, [%[inptr4]], #32\n" /* load r4, e0~e7 */ + "fmul v6.4s, v31.4s, v6.4s\n" /* mul alpha */ + "fmul v7.4s, v31.4s, v7.4s\n" /* mul alpha */ + "ldp q10, q11, [%[inptr5]], #32\n" /* load r5, f0~f7 */ + "fmul v8.4s, v31.4s, v8.4s\n" /* mul alpha */ + "fmul v9.4s, v31.4s, v9.4s\n" /* mul alpha */ + "ldp q12, q13, [%[inptr6]], #32\n" /* load r6, g0~g7 */ + "fmul v10.4s, v31.4s, v10.4s\n" /* mul alpha */ + "fmul v11.4s, v31.4s, v11.4s\n" /* mul alpha */ + "ldp q14, q15, [%[inptr7]], #32\n" /* load r7, h0~h7 */ + "fmul v12.4s, v31.4s, v12.4s\n" /* mul alpha */ + "fmul v13.4s, v31.4s, v13.4s\n" /* mul alpha */ + "fmul v14.4s, v31.4s, v14.4s\n" /* mul alpha */ + "fmul v15.4s, v31.4s, v15.4s\n" /* mul alpha */ + "b 1f\n" /* to main process */ + "0: \n" /* alpha == 1 */ + "ldp q0, q1, [%[inptr0]], #32\n" /* load r0, a0~a7 */ + "ldp q2, q3, [%[inptr1]], #32\n" /* load r1, b0~b7 */ + "ldp q4, q5, [%[inptr2]], #32\n" /* load r2, c0~c7 */ + "ldp q6, q7, [%[inptr3]], #32\n" /* load r3, d0~d7 */ + "ldp q8, q9, [%[inptr4]], #32\n" /* load r4, e0~e7 */ + "ldp q10, q11, [%[inptr5]], #32\n" /* load r5, f0~f7 */ + "ldp q12, q13, [%[inptr6]], #32\n" /* load r6, g0~g7 */ + "ldp q14, q15, [%[inptr7]], #32\n" /* load r7, h0~h7 */ + "1: \n" /* main process */ + "trn1 v16.4s, v0.4s, v2.4s\n" /* a0b0a2b2*/ + "trn2 v17.4s, v0.4s, v2.4s\n" /* a1b1a3b3*/ + "trn1 v18.4s, v1.4s, v3.4s\n" /* a4b4a6b6*/ + "trn2 v19.4s, v1.4s, v3.4s\n" /* a5b5a7b7*/ + + "trn1 v20.4s, v4.4s, v6.4s\n" /* c0d0c2d2*/ + "trn2 v21.4s, v4.4s, v6.4s\n" /* c1d1c3d3*/ + "trn1 v22.4s, v5.4s, v7.4s\n" /* c4d4c6d6*/ + "trn2 v23.4s, v5.4s, v7.4s\n" /* c5d5c7d7*/ + + "trn1 v24.4s, v8.4s, v10.4s\n" /* e0f0e2f2*/ + "trn2 v25.4s, v8.4s, v10.4s\n" /* e1f1e3f3*/ + "trn1 v26.4s, v9.4s, v11.4s\n" /* e4f4e6f6*/ + "trn2 v27.4s, v9.4s, v11.4s\n" /* e5f5e7f7*/ + + "trn1 v28.4s, v12.4s, v14.4s\n" /* g0h0g2h2*/ + "trn2 v29.4s, v12.4s, v14.4s\n" /* g1h1g3h3*/ + "trn1 v30.4s, v13.4s, v15.4s\n" /* g4h4g6h6*/ + "trn2 v31.4s, v13.4s, v15.4s\n" /* g5h5g7h7*/ + + "trn1 v0.2d, v16.2d, v20.2d\n" /* a0b0c0d0 */ + "trn1 v1.2d, v24.2d, v28.2d\n" /* e0f0g0h0 */ + "trn1 v2.2d, v17.2d, v21.2d\n" /* a1b1c1d1 */ + "trn1 v3.2d, v25.2d, v29.2d\n" /* e1b1c1d1 */ + + "trn2 v4.2d, v16.2d, v20.2d\n" /* a2b2c2d2 */ + "trn2 v5.2d, v24.2d, v28.2d\n" /* e2f2g2h2 */ + "stp q0, q1, [%[outptr]], #32\n" /* save q0, q1, a0~h0*/ + "trn2 v6.2d, v17.2d, v21.2d\n" /* a3b3c3d3 */ + "trn2 v7.2d, v25.2d, v29.2d\n" /* e3f3g3h3 */ + "stp q2, q3, [%[outptr]], #32\n" /* save q2, q3, a1~h1*/ + + "trn1 v8.2d, v18.2d, v22.2d\n" /* a4b4c4d4 */ + "trn1 v9.2d, v26.2d, v30.2d\n" /* e4f4g4h4 */ + "stp q4, q5, [%[outptr]], #32\n" /* save q4, q5, a2~h2*/ + "trn1 v10.2d, v19.2d, v23.2d\n" /* a5b5c5d5 */ + "trn1 v11.2d, v27.2d, v31.2d\n" /* e5f5g5h5 */ + "stp q6, q7, [%[outptr]], #32\n" /* save q6, q7, a3~h3*/ + + "trn2 v12.2d, v18.2d, v22.2d\n" /* a6b6c6d6 */ + "trn2 v13.2d, v26.2d, v30.2d\n" /* e6f6g6h6 */ + "stp q8, q9, [%[outptr]], #32\n" /* save q8, q9, 
a4~h4*/ + "trn2 v14.2d, v19.2d, v23.2d\n" /* a7b7c7d7 */ + "trn2 v15.2d, v27.2d, v31.2d\n" /* e7f7g7h7 */ + "stp q10, q11, [%[outptr]], #32\n" /* save q10, q11, a5~h5*/ + + "stp q12, q13, [%[outptr]], #32\n" /* save q12, q13, a6~h6*/ + "stp q14, q15, [%[outptr]], #32\n" /* save q14, q15, a7~h7*/ + : [inptr0] "+r"(inptr0), + [inptr1] "+r"(inptr1), + [inptr2] "+r"(inptr2), + [inptr3] "+r"(inptr3), + [inptr4] "+r"(inptr4), + [inptr5] "+r"(inptr5), + [inptr6] "+r"(inptr6), + [inptr7] "+r"(inptr7), + [outptr] "+r"(outptr) + : [alpha] "r"(alpha), [has_alpha] "r"(has_alpha) + : "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15", + "v16", + "v17", + "v18", + "v19", + "v20", + "v21", + "v22", + "v23", + "v24", + "v25", + "v26", + "v27", + "v28", + "v29", + "v30", + "v31", + "cc", + "memory"); + } + + for (; x > 0; x--) { + if (has_alpha) { + *outptr++ = *inptr0++ * alpha; + *outptr++ = *inptr1++ * alpha; + *outptr++ = *inptr2++ * alpha; + *outptr++ = *inptr3++ * alpha; + *outptr++ = *inptr4++ * alpha; + *outptr++ = *inptr5++ * alpha; + *outptr++ = *inptr6++ * alpha; + *outptr++ = *inptr7++ * alpha; + } else { + *outptr++ = *inptr0++; + *outptr++ = *inptr1++; + *outptr++ = *inptr2++; + *outptr++ = *inptr3++; + *outptr++ = *inptr4++; + *outptr++ = *inptr5++; + *outptr++ = *inptr6++; + *outptr++ = *inptr7++; + } + } + } +} + +void prepackA_trans_8x12(float *outptr, + const float *in, + float alpha, + int ldin, + int m0, + int mmax, + int k0, + int kmax) { + auto inptr = in + k0 * ldin + m0; + uint32_t mask_buffer[8] = {0, 1, 2, 3, 4, 5, 6, 7}; + int x_len = mmax - m0; + int y_len = kmax - k0; + int right_remain = x_len - 8 * (x_len / 8); + int stride_out = 8 * y_len; + + float32x4_t vzero = vdupq_n_f32(0.f); + uint32x4_t vmask1 = + vcltq_u32(vld1q_u32(mask_buffer), vdupq_n_u32(right_remain)); + uint32x4_t vmask2 = + vcltq_u32(vld1q_u32(mask_buffer + 4), vdupq_n_u32(right_remain)); + + bool has_alpha = fabsf(alpha - 1.f) > 1e-8f; + float32x4_t valpha = vdupq_n_f32(alpha); + +#pragma omp parallel for + for (int y = 0; y < y_len - 3; y += 4) { + const float *ptr0 = inptr + y * ldin; + const float *ptr1 = ptr0 + ldin; + const float *ptr2 = ptr1 + ldin; + const float *ptr3 = ptr2 + ldin; + + asm volatile( + "prfm pldl1keep, [%[ptr0]] \n" + "prfm pldl1keep, [%[ptr0], #64] \n" + "prfm pldl1keep, [%[ptr1]] \n" + "prfm pldl1keep, [%[ptr1], #64] \n" + "prfm pldl1keep, [%[ptr2]] \n" + "prfm pldl1keep, [%[ptr2], #64] \n" + "prfm pldl1keep, [%[ptr3]] \n" + "prfm pldl1keep, [%[ptr3], #64] \n" + : + : [ptr0] "r"(ptr0), [ptr1] "r"(ptr1), [ptr2] "r"(ptr2), [ptr3] "r"(ptr3) + : "memory"); + + float *outptr_row_col = outptr + y * 8; + int i = 0; + for (; i < x_len - 7; i += 8) { + float32x4_t vr00 = vld1q_f32(ptr0); + float32x4_t vr01 = vld1q_f32(ptr0 + 4); + float32x4_t vr10 = vld1q_f32(ptr1); + float32x4_t vr11 = vld1q_f32(ptr1 + 4); + float32x4_t vr20 = vld1q_f32(ptr2); + float32x4_t vr21 = vld1q_f32(ptr2 + 4); + float32x4_t vr30 = vld1q_f32(ptr3); + float32x4_t vr31 = vld1q_f32(ptr3 + 4); + if (has_alpha) { + vr00 = vmulq_f32(vr00, valpha); + vr01 = vmulq_f32(vr01, valpha); + vr10 = vmulq_f32(vr10, valpha); + vr11 = vmulq_f32(vr11, valpha); + vr20 = vmulq_f32(vr20, valpha); + vr21 = vmulq_f32(vr21, valpha); + vr30 = vmulq_f32(vr30, valpha); + vr31 = vmulq_f32(vr31, valpha); + } + + vst1q_f32(outptr_row_col, vr00); + vst1q_f32(outptr_row_col + 4, vr01); + vst1q_f32(outptr_row_col + 8, vr10); + vst1q_f32(outptr_row_col + 12, vr11); + 
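+      // (illustrative note) rows 0..3 of this k-strip land at float offsets
+      // 0, 8, 16 and 24 of one 8-wide A panel; the panels for successive
+      // 8-column strips sit stride_out (= 8 * y_len) floats apart.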
vst1q_f32(outptr_row_col + 16, vr20); + vst1q_f32(outptr_row_col + 20, vr21); + vst1q_f32(outptr_row_col + 24, vr30); + vst1q_f32(outptr_row_col + 28, vr31); + + ptr0 += 8; + ptr1 += 8; + ptr2 += 8; + ptr3 += 8; + + outptr_row_col += stride_out; + } + if (right_remain > 0) { + float32x4_t vr00 = vld1q_f32(ptr0); + float32x4_t vr01 = vld1q_f32(ptr0 + 4); + float32x4_t vr10 = vld1q_f32(ptr1); + float32x4_t vr11 = vld1q_f32(ptr1 + 4); + float32x4_t vr20 = vld1q_f32(ptr2); + float32x4_t vr21 = vld1q_f32(ptr2 + 4); + float32x4_t vr30 = vld1q_f32(ptr3); + float32x4_t vr31 = vld1q_f32(ptr3 + 4); + + if (has_alpha) { + vr00 = vmulq_f32(vr00, valpha); + vr01 = vmulq_f32(vr01, valpha); + vr10 = vmulq_f32(vr10, valpha); + vr11 = vmulq_f32(vr11, valpha); + vr20 = vmulq_f32(vr20, valpha); + vr21 = vmulq_f32(vr21, valpha); + vr30 = vmulq_f32(vr30, valpha); + vr31 = vmulq_f32(vr31, valpha); + } + + float32x4_t vr00_1 = vbslq_f32(vmask1, vr00, vzero); + float32x4_t vr01_1 = vbslq_f32(vmask2, vr01, vzero); + float32x4_t vr10_1 = vbslq_f32(vmask1, vr10, vzero); + float32x4_t vr11_1 = vbslq_f32(vmask2, vr11, vzero); + float32x4_t vr20_1 = vbslq_f32(vmask1, vr20, vzero); + float32x4_t vr21_1 = vbslq_f32(vmask2, vr21, vzero); + float32x4_t vr30_1 = vbslq_f32(vmask1, vr30, vzero); + float32x4_t vr31_1 = vbslq_f32(vmask2, vr31, vzero); + + vst1q_f32(outptr_row_col, vr00_1); + vst1q_f32(outptr_row_col + 4, vr01_1); + vst1q_f32(outptr_row_col + 8, vr10_1); + vst1q_f32(outptr_row_col + 12, vr11_1); + vst1q_f32(outptr_row_col + 16, vr20_1); + vst1q_f32(outptr_row_col + 20, vr21_1); + vst1q_f32(outptr_row_col + 24, vr30_1); + vst1q_f32(outptr_row_col + 28, vr31_1); + } + } + +#pragma omp parallel for + for (int y = 4 * (y_len / 4); y < y_len; ++y) { + const float *ptr0 = inptr + y * ldin; + float *outptr_row_col = outptr + y * 8; + int i = 0; + for (; i < x_len - 7; i += 8) { + float32x4_t vr0 = vld1q_f32(ptr0); + float32x4_t vr1 = vld1q_f32(ptr0 + 4); + if (has_alpha) { + vr0 = vmulq_f32(vr0, valpha); + vr1 = vmulq_f32(vr1, valpha); + } + vst1q_f32(outptr_row_col, vr0); + vst1q_f32(outptr_row_col + 4, vr1); + + ptr0 += 8; + + outptr_row_col += stride_out; + } + if (right_remain > 0) { + float32x4_t vr0 = vld1q_f32(ptr0); + float32x4_t vr1 = vld1q_f32(ptr0 + 4); + + if (has_alpha) { + vr0 = vmulq_f32(vr0, valpha); + vr1 = vmulq_f32(vr1, valpha); + } + + float32x4_t vr0_1 = vbslq_f32(vmask1, vr0, vzero); + float32x4_t vr1_1 = vbslq_f32(vmask2, vr1, vzero); + + vst1q_f32(outptr_row_col, vr0_1); + vst1q_f32(outptr_row_col + 4, vr1_1); + } + } +} + +#else // __aarch64__ +void prepackA_6x8(float* outptr, + const float* inptr, + float alpha, + int ldin, + int m0, + int mmax, + int k0, + int kmax) { + int x_len = kmax - k0; + float zerobuff[x_len]; // NOLINT + memset(zerobuff, 0, sizeof(float) * x_len); + + bool has_alpha = fabsf(alpha - 1.f) > 1e-8f; + float32x4_t valpha = vdupq_n_f32(alpha); + + for (int y = m0; y < mmax; y += 6) { + const float* inptr0 = inptr + y * ldin + k0; + const float* inptr1 = inptr0 + ldin; + const float* inptr2 = inptr1 + ldin; + const float* inptr3 = inptr2 + ldin; + const float* inptr4 = inptr3 + ldin; + const float* inptr5 = inptr4 + ldin; + + int x = x_len; + if ((y + 5) >= mmax) { + switch ((y + 5) - mmax) { + case 4: + inptr1 = zerobuff; + case 3: + inptr2 = zerobuff; + case 2: + inptr3 = zerobuff; + case 1: + inptr4 = zerobuff; + case 0: + inptr5 = zerobuff; + default: + break; + } + } + + for (; x > 7; x -= 8) { + asm volatile( + "vld4.32 {d0-d3}, [%[inptr0]]! 
@ zip load r0, " + "q0,q1=r00,r04,r01,r05,r02,r06,r03,r07\n" + "vld4.32 {d4-d7}, [%[inptr1]]! @ zip load r1, " + "q2,q3=r10,r14,r11,r15,r12,r16,r13,r17\n" + "vld4.32 {d8-d11}, [%[inptr2]]! @ zip load r2, " + "q4,q5=r20,r24,r21,r25,r22,r26,r23,r27\n" + "vld4.32 {d12-d15}, [%[inptr3]]! @ zip load r3, " + "q6,q7=r30,r34,r31,r35,r32,r36,r33,r37\n" + "vld4.32 {d16-d19}, [%[inptr4]]! @ zip load r4, " + "q8,q9=r40,r44,r41,r45,r42,r46,r43,r47\n" + "vld4.32 {d20-d23}, [%[inptr5]]! @ zip load r5, " + "q10,q11=r50,r54,r51,r55,r52,r56,r53,r57\n" + "cmp %[has_alpha], #0\n" + "beq 0f\n" /* check whether alpha == 1? */ + "vmul.f32 q0, q0, %q[alpha]\n" /* mul alpha */ + "vmul.f32 q1, q1, %q[alpha]\n" /* mul alpha */ + "vmul.f32 q2, q2, %q[alpha]\n" /* mul alpha */ + "vmul.f32 q3, q3, %q[alpha]\n" /* mul alpha */ + "vmul.f32 q4, q4, %q[alpha]\n" /* mul alpha */ + "vmul.f32 q5, q5, %q[alpha]\n" /* mul alpha */ + "vmul.f32 q6, q6, %q[alpha]\n" /* mul alpha */ + "vmul.f32 q7, q7, %q[alpha]\n" /* mul alpha */ + "vmul.f32 q8, q8, %q[alpha]\n" /* mul alpha */ + "vmul.f32 q9, q9, %q[alpha]\n" /* mul alpha */ + "vmul.f32 q10, q10, %q[alpha]\n" /* mul alpha */ + "vmul.f32 q11, q11, %q[alpha]\n" /* mul alpha */ + "0: \n" + "vtrn.32 q0, q2 @ trans data: q0=r00,r10,r01,r11; " + "q2=r04,r14,r05,r15\n" + "vtrn.32 q4, q6 @ trans data: q4=r20,r30,r21,r31; " + "q6=r24,r34,r25,r35\n" + "vtrn.32 q8, q10 @ trans data: q8=r40,r50,r41,r51; " + "q10=r44,r54,r45,r55\n" + + "vswp d1, d8 @ swap d1, d8, q0=r00,r10,r20,r30; " + "q4=r01,r11,r21,r31\n" + "vst1.32 {d0-d1}, [%[outptr]]! @ write q0:r00,r10,r20,r30\n" + "vst1.32 {d16}, [%[outptr]]! @ write d16(q8,low),r40,r50\n" + "vst1.32 {d8-d9}, [%[outptr]]! @ write q4:r01,r11,r21,r31\n" + "vst1.32 {d17}, [%[outptr]]! @ write d16(q8,high),r41,r51\n" + + "vtrn.32 q1, q3 @ trans data: q1=r02,r12,r03,r13; " + "q3=r06,r16,r07,r17\n" + "vtrn.32 q5, q7 @ trans data: q5=r22,r32,r23,r33; " + "q7=r26,r36,r27,r37\n" + "vtrn.32 q9, q11 @ trans data: q9=r42,r52,r43,r53; " + "q11=r46,r56,r47,r57\n" + + "vswp d3, d10 @ swap d3, d10, " + "q1=r02,r12,r22,r32; q5=r03,r13,r23,r33\n" + "vst1.32 {d2-d3}, [%[outptr]]! @ write q1:r02,r12,r22,r32\n" + "vst1.32 {d18}, [%[outptr]]! @ write d18(q9,low),r42,r52\n" + "vst1.32 {d10-d11},[%[outptr]]! @ write q5:r03,r13,r23,r33\n" + "vst1.32 {d19}, [%[outptr]]! @ write d19(q9,high),r43,r53\n" + + "vswp d5, d12 @ swap d5, d12,q2=r04,r14,r24,r34; " + "q6=r05,r15,r25,r35\n" + "vst1.32 {d4-d5}, [%[outptr]]! @ write q2:r04,r14,r24,r34\n" + "vst1.32 {d20}, [%[outptr]]! @ write d20(q10,low),r44,r54\n" + "vst1.32 {d12-d13},[%[outptr]]! @ write q6:r05,r15,r25,r35\n" + "vst1.32 {d21}, [%[outptr]]! @ write d21(q10,high),r45,r55\n" + + "vswp d7, d14 @ swap d7, d14, " + "q3=r06,r16,r26,r36; q7=r07,r17,r27,r37\n" + "vst1.32 {d6-d7}, [%[outptr]]! @ write q3:r06,r16,r26,r36\n" + "vst1.32 {d22}, [%[outptr]]! @ write d22(q11,low),r46,r56\n" + "vst1.32 {d14-d15},[%[outptr]]! @ write q7:r07,r17,r27,r37\n" + "vst1.32 {d23}, [%[outptr]]! 
@ write d23(q11,high),r47,r57\n" + : [inptr0] "+r"(inptr0), + [inptr1] "+r"(inptr1), + [inptr2] "+r"(inptr2), + [inptr3] "+r"(inptr3), + [inptr4] "+r"(inptr4), + [inptr5] "+r"(inptr5), + [outptr] "+r"(outptr) + : [has_alpha] "r"(has_alpha), [alpha] "w"(valpha) + : "q0", + "q1", + "q2", + "q3", + "q4", + "q5", + "q6", + "q7", + "q8", + "q9", + "q10", + "q11", + "q15", + "cc", + "memory"); + } + + for (; x > 0; x--) { + if (has_alpha) { + *outptr++ = *inptr0++ * alpha; + *outptr++ = *inptr1++ * alpha; + *outptr++ = *inptr2++ * alpha; + *outptr++ = *inptr3++ * alpha; + *outptr++ = *inptr4++ * alpha; + *outptr++ = *inptr5++ * alpha; + } else { + *outptr++ = *inptr0++; + *outptr++ = *inptr1++; + *outptr++ = *inptr2++; + *outptr++ = *inptr3++; + *outptr++ = *inptr4++; + *outptr++ = *inptr5++; + } + } + } +} + +void prepackA_trans_6x8(float* outptr, + const float* in, + float alpha, + int ldin, + int m0, + int mmax, + int k0, + int kmax) { + auto inptr = in + k0 * ldin + m0; + + bool has_alpha = fabsf(alpha - 1.f) > 1e-8f; + float32x4_t valpha = vdupq_n_f32(alpha); + + uint32_t mask_buffer[8] = {0, 1, 2, 3, 4, 5, 6, 7}; + int x_len = mmax - m0; + int y_len = kmax - k0; + int right_remain = x_len - 6 * (x_len / 6); + int right_pad = 6 - right_remain; + if (right_remain == 0) { + right_pad = 0; + } + + float* outptr_row = outptr; + int stride_out = 6 * y_len; + + float32x4_t vzero = vdupq_n_f32(0.f); + uint32x4_t vmask1 = + vcltq_u32(vld1q_u32(mask_buffer), vdupq_n_u32(right_remain)); + uint32x4_t vmask2 = + vcltq_u32(vld1q_u32(mask_buffer + 4), vdupq_n_u32(right_remain)); + +#pragma omp parallel for + for (int y = 0; y < y_len - 3; y += 4) { + const float* ptr0 = inptr + y * ldin; + const float* ptr1 = ptr0 + ldin; + const float* ptr2 = ptr1 + ldin; + const float* ptr3 = ptr2 + ldin; + + float* outptr_row_col = outptr_row + y * 6; + int i = 0; + for (; i < x_len - 5; i += 6) { + float* ptr_out = outptr_row_col; + asm volatile( + "vld1.32 {d0-d2}, [%[ptr0]]! @ load r0, 6 elements\n" + "vld1.32 {d4-d6}, [%[ptr1]]! @ load r1, 6 elements\n" + "vld1.32 {d8-d10}, [%[ptr2]]! @ load r2, 6 elements\n" + "vld1.32 {d12-d14}, [%[ptr3]]! @ load r3, 6 elements\n" + "cmp %[has_alpha], #0\n" + "beq 0f\n" /* check whether alpha == 1? */ + "vmul.f32 q0, q0, %q[alpha]\n" /* mul alpha */ + "vmul.f32 d2, d2, %e[alpha]\n" /* mul alpha */ + "vmul.f32 q2, q2, %q[alpha]\n" /* mul alpha */ + "vmul.f32 d6, d6, %e[alpha]\n" /* mul alpha */ + "vmul.f32 q4, q4, %q[alpha]\n" /* mul alpha */ + "vmul.f32 d10, d10, %e[alpha]\n" /* mul alpha */ + "vmul.f32 q6, q6, %q[alpha]\n" /* mul alpha */ + "vmul.f32 d14, d14, %e[alpha]\n" /* mul alpha */ + "0: \n" + "vst1.32 {d0-d2}, [%[outptr]]! @ write to output ptr\n" + "vst1.32 {d4-d6}, [%[outptr]]! @ write to output ptr\n" + "vst1.32 {d8-d10}, [%[outptr]]! @ write to output ptr\n" + "vst1.32 {d12-d14}, [%[outptr]]! @ write to output ptr\n" + : [outptr] "+r"(ptr_out), + [ptr0] "+r"(ptr0), + [ptr1] "+r"(ptr1), + [ptr2] "+r"(ptr2), + [ptr3] "+r"(ptr3) + : [has_alpha] "r"(has_alpha), [alpha] "w"(valpha) + : "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "cc", "memory"); + outptr_row_col += stride_out; + } + if (right_pad > 0) { + float* ptr_out = outptr_row_col; + asm volatile( + "vld1.32 {d0-d2}, [%[ptr0]]! @ load r0, 6 elements\n" + "vld1.32 {d4-d6}, [%[ptr1]]! @ load r1, 6 elements\n" + "vld1.32 {d8-d10}, [%[ptr2]]! @ load r2, 8 elements\n" + "vld1.32 {d12-d14}, [%[ptr3]]! @ load r3, 8 elements\n" + "cmp %[has_alpha], #0\n" + "beq 0f\n" /* check whether alpha == 1? 
*/ + "vmul.f32 q0, q0, %q[alpha]\n" /* mul alpha */ + "vmul.f32 d2, d2, %e[alpha]\n" /* mul alpha */ + "vmul.f32 q2, q2, %q[alpha]\n" /* mul alpha */ + "vmul.f32 d6, d6, %e[alpha]\n" /* mul alpha */ + "vmul.f32 q4, q4, %q[alpha]\n" /* mul alpha */ + "vmul.f32 d10, d10, %e[alpha]\n" /* mul alpha */ + "vmul.f32 q6, q6, %q[alpha]\n" /* mul alpha */ + "vmul.f32 d14, d14, %e[alpha]\n" /* mul alpha */ + "0: \n" + "vbif q0, %q[vzero], %q[vmask1] @ bit select, pad zero\n" + "vbif d2, %e[vzero], %e[vmask2] @ bit select, pad zero\n" + "vbif q2, %q[vzero], %q[vmask1] @ bit select, pad zero\n" + "vbif d6, %e[vzero], %e[vmask2] @ bit select, pad zero\n" + "vst1.32 {d0-d2}, [%[outptr]]! @ write to output ptr\n" + "vst1.32 {d4-d6}, [%[outptr]]! @ write to output ptr\n" + "vbif q4, %q[vzero], %q[vmask1] @ bit select, pad zero\n" + "vbif d10, %e[vzero], %e[vmask2] @ bit select, pad zero\n" + "vbif q6, %q[vzero], %q[vmask1] @ bit select, pad zero\n" + "vbif d14, %e[vzero], %e[vmask2] @ bit select, pad zero\n" + "vst1.32 {d8-d10}, [%[outptr]]! @ write to output ptr\n" + "vst1.32 {d12-d14}, [%[outptr]]! @ write to output ptr\n" + : [outptr] "+r"(ptr_out), + [ptr0] "+r"(ptr0), + [ptr1] "+r"(ptr1), + [ptr2] "+r"(ptr2), + [ptr3] "+r"(ptr3) + : [vmask1] "w"(vmask1), + [vmask2] "w"(vmask2), + [vzero] "w"(vzero), + [has_alpha] "r"(has_alpha), + [alpha] "w"(valpha) + : "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "cc", "memory"); + } + } + +#pragma omp parallel for + for (int y = 4 * (y_len / 4); y < y_len; ++y) { + const float* ptr0 = inptr + y * ldin; + float* outptr_row_col = outptr_row + y * 6; + int i = 0; + for (; i < x_len - 5; i += 6) { + float* ptr_out = outptr_row_col; + asm volatile( + "vld1.32 {d0-d2}, [%[ptr0]]! @ load r0, 6 elements\n" + "cmp %[has_alpha], #0\n" + "beq 0f\n" /* check whether alpha == 1? */ + "vmul.f32 q0, q0, %q[alpha]\n" /* mul alpha */ + "vmul.f32 d2, d2, %e[alpha]\n" /* mul alpha */ + "0: \n" + "vst1.32 {d0-d2}, [%[outptr]]! @ write to output ptr\n" + : [ptr0] "+r"(ptr0), [outptr] "+r"(ptr_out) + : [has_alpha] "r"(has_alpha), [alpha] "w"(valpha) + : "q0", "q1", "cc", "memory"); + outptr_row_col += stride_out; + } + if (right_pad > 0) { + float* ptr_out = outptr_row_col; + asm volatile( + "vld1.32 {d0-d2}, [%[ptr0]]! @ load r0, 6 elements\n" + "cmp %[has_alpha], #0\n" + "beq 0f\n" /* check whether alpha == 1? */ + "vmul.f32 q0, q0, %q[alpha]\n" /* mul alpha */ + "vmul.f32 d2, d2, %e[alpha]\n" /* mul alpha */ + "0: \n" + "vbif q0, %q[vzero], %q[vmask1] @ bit select, pad zero\n" + "vbif d2, %e[vzero], %e[vmask2] @ bit select, pad zero\n" + "vst1.32 {d0-d2}, [%[outptr]]! 
@ write to output ptr\n" + : [ptr0] "+r"(ptr0), [outptr] "+r"(ptr_out) + : [vmask1] "w"(vmask1), + [vmask2] "w"(vmask2), + [vzero] "w"(vzero), + [has_alpha] "r"(has_alpha), + [alpha] "w"(valpha) + : "q0", "q1", "cc", "memory"); + } + } +} + +void prepackA_4x8(float* outptr, + const float* inptr, + float alpha, + int ldin, + int m0, + int mmax, + int k0, + int kmax) { + int x_len = kmax - k0; + float zerobuff[x_len]; // NOLINT + memset(zerobuff, 0, sizeof(float) * x_len); + + bool has_alpha = fabsf(alpha - 1.f) > 1e-8f; + float32x4_t valpha = vdupq_n_f32(alpha); + + for (int y = m0; y < mmax; y += 4) { + const float* inptr0 = inptr + y * ldin + k0; + const float* inptr1 = inptr0 + ldin; + const float* inptr2 = inptr1 + ldin; + const float* inptr3 = inptr2 + ldin; + + int x = x_len; + if ((y + 3) >= mmax) { + switch ((y + 3) - mmax) { + case 2: + inptr1 = zerobuff; + case 1: + inptr2 = zerobuff; + case 0: + inptr3 = zerobuff; + default: + break; + } + } + + for (; x > 7; x -= 8) { + asm volatile( + "vld4.32 {d0-d3}, [%[inptr0]]! @ zip load r0, " + "q0,q1=r00,r04,r01,r05,r02,r06,r03,r07\n" + "vld4.32 {d4-d7}, [%[inptr1]]! @ zip load r1, " + "q2,q3=r10,r14,r11,r15,r12,r16,r13,r17\n" + "vld4.32 {d8-d11}, [%[inptr2]]! @ zip load r2, " + "q4,q5=r20,r24,r21,r25,r22,r26,r23,r27\n" + "vld4.32 {d12-d15}, [%[inptr3]]! @ zip load r3, " + "q6,q7=r30,r34,r31,r35,r32,r36,r33,r37\n" + "cmp %[has_alpha], #0\n" + "beq 0f\n" /* check whether alpha == 1? */ + "vmul.f32 q0, q0, %q[alpha]\n" /* mul alpha */ + "vmul.f32 q1, q1, %q[alpha]\n" /* mul alpha */ + "vmul.f32 q2, q2, %q[alpha]\n" /* mul alpha */ + "vmul.f32 q3, q3, %q[alpha]\n" /* mul alpha */ + "vmul.f32 q4, q4, %q[alpha]\n" /* mul alpha */ + "vmul.f32 q5, q5, %q[alpha]\n" /* mul alpha */ + "vmul.f32 q6, q6, %q[alpha]\n" /* mul alpha */ + "vmul.f32 q7, q7, %q[alpha]\n" /* mul alpha */ + "0: \n" + "vtrn.32 q0, q2 @ trans data: q0=r00,r10,r01,r11; " + "q2=r04,r14,r05,r15\n" + "vtrn.32 q4, q6 @ trans data: q4=r20,r30,r21,r31; " + "q6=r24,r34,r25,r35\n" + + "vswp d1, d8 @ swap d1, d8, q0=r00,r10,r20,r30; " + "q4=r01,r11,r21,r31\n" + "vst1.32 {d0-d1}, [%[outptr]]! @ write q0:r00,r10,r20,r30\n" + "vst1.32 {d8-d9}, [%[outptr]]! @ write q4:r01,r11,r21,r31\n" + + "vtrn.32 q1, q3 @ trans data: q1=r02,r12,r03,r13; " + "q3=r06,r16,r07,r17\n" + "vtrn.32 q5, q7 @ trans data: q5=r22,r32,r23,r33; " + "q7=r26,r36,r27,r37\n" + + "vswp d3, d10 @ swap d3, d10, " + "q1=r02,r12,r22,r32; q5=r03,r13,r23,r33\n" + "vst1.32 {d2-d3}, [%[outptr]]! @ write q1:r02,r12,r22,r32\n" + "vst1.32 {d10-d11},[%[outptr]]! @ write q5:r03,r13,r23,r33\n" + + "vswp d5, d12 @ swap d5, d12,q2=r04,r14,r24,r34; " + "q6=r05,r15,r25,r35\n" + "vst1.32 {d4-d5}, [%[outptr]]! @ write q2:r04,r14,r24,r34\n" + "vst1.32 {d12-d13},[%[outptr]]! @ write q6:r05,r15,r25,r35\n" + + "vswp d7, d14 @ swap d7, d14, " + "q3=r06,r16,r26,r36; q7=r07,r17,r27,r37\n" + "vst1.32 {d6-d7}, [%[outptr]]! @ write q3:r06,r16,r26,r36\n" + "vst1.32 {d14-d15},[%[outptr]]! 
@ write q7:r07,r17,r27,r37\n" + : [inptr0] "+r"(inptr0), + [inptr1] "+r"(inptr1), + [inptr2] "+r"(inptr2), + [inptr3] "+r"(inptr3), + [outptr] "+r"(outptr) + : [has_alpha] "r"(has_alpha), [alpha] "w"(valpha) + : "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "cc", "memory"); + } + + for (; x > 0; x--) { + if (has_alpha) { + *outptr++ = *inptr0++ * alpha; + *outptr++ = *inptr1++ * alpha; + *outptr++ = *inptr2++ * alpha; + *outptr++ = *inptr3++ * alpha; + } else { + *outptr++ = *inptr0++; + *outptr++ = *inptr1++; + *outptr++ = *inptr2++; + *outptr++ = *inptr3++; + } + } + } +} + +void prepackA_trans_4x8(float* outptr, + const float* in, + float alpha, + int ldin, + int m0, + int mmax, + int k0, + int kmax) { + auto inptr = in + k0 * ldin + m0; + bool has_alpha = fabsf(alpha - 1.f) > 1e-8f; + float32x4_t valpha = vdupq_n_f32(alpha); + + uint32_t mask_buffer[8] = {0, 1, 2, 3, 4, 5, 6, 7}; + int x_len = mmax - m0; + int y_len = kmax - k0; + int right_remain = x_len - 4 * (x_len / 4); + int right_pad = 4 - right_remain; + if (right_remain == 0) { + right_pad = 0; + } + + int stride_out = 4 * y_len; + + float32x4_t vzero = vdupq_n_f32(0.f); + uint32x4_t vmask1 = + vcltq_u32(vld1q_u32(mask_buffer), vdupq_n_u32(right_remain)); + +#pragma omp parallel for + for (int y = 0; y < y_len - 3; y += 4) { + const float* ptr0 = inptr + y * ldin; + const float* ptr1 = ptr0 + ldin; + const float* ptr2 = ptr1 + ldin; + const float* ptr3 = ptr2 + ldin; + + float* outptr_row_col = outptr + y * 4; + int i = 0; + for (; i < x_len - 3; i += 4) { + float* ptr_out = outptr_row_col; + asm volatile( + "vld1.32 {d0-d1}, [%[ptr0]]! @ load r0, 4 elements\n" + "vld1.32 {d2-d3}, [%[ptr1]]! @ load r1, 4 elements\n" + "vld1.32 {d4-d5}, [%[ptr2]]! @ load r2, 4 elements\n" + "vld1.32 {d6-d7}, [%[ptr3]]! @ load r3, 4 elements\n" + "cmp %[has_alpha], #0\n" + "beq 0f\n" /* check whether alpha == 1? */ + "vmul.f32 q0, q0, %q[alpha]\n" /* mul alpha */ + "vmul.f32 q1, q1, %q[alpha]\n" /* mul alpha */ + "vmul.f32 q2, q2, %q[alpha]\n" /* mul alpha */ + "vmul.f32 q3, q3, %q[alpha]\n" /* mul alpha */ + "0: \n" + "vst1.32 {d0-d1}, [%[outptr]]! @ write to output ptr\n" + "vst1.32 {d2-d3}, [%[outptr]]! @ write to output ptr\n" + "vst1.32 {d4-d5}, [%[outptr]]! @ write to output ptr\n" + "vst1.32 {d6-d7}, [%[outptr]]! @ write to output ptr\n" + : [outptr] "+r"(ptr_out), + [ptr0] "+r"(ptr0), + [ptr1] "+r"(ptr1), + [ptr2] "+r"(ptr2), + [ptr3] "+r"(ptr3) + : [has_alpha] "r"(has_alpha), [alpha] "w"(valpha) + : "q0", "q1", "q2", "q3", "cc", "memory"); + outptr_row_col += stride_out; + } + if (right_pad > 0) { + float* ptr_out = outptr_row_col; + asm volatile( + "vld1.32 {d0-d1}, [%[ptr0]]! @ load r0, 4 elements\n" + "vld1.32 {d2-d3}, [%[ptr1]]! @ load r1, 4 elements\n" + "vld1.32 {d4-d5}, [%[ptr2]]! @ load r2, 4 elements\n" + "vld1.32 {d6-d7}, [%[ptr3]]! @ load r3, 4 elements\n" + "cmp %[has_alpha], #0\n" + "beq 0f\n" /* check whether alpha == 1? */ + "vmul.f32 q0, q0, %q[alpha]\n" /* mul alpha */ + "vmul.f32 q1, q1, %q[alpha]\n" /* mul alpha */ + "vmul.f32 q2, q2, %q[alpha]\n" /* mul alpha */ + "vmul.f32 q3, q3, %q[alpha]\n" /* mul alpha */ + "0: \n" + "vbif q0, %q[vzero], %q[vmask1] @ bit select, pad zero\n" + "vbif q1, %q[vzero], %q[vmask1] @ bit select, pad zero\n" + "vbif q2, %q[vzero], %q[vmask1] @ bit select, pad zero\n" + "vbif q3, %q[vzero], %q[vmask1] @ bit select, pad zero\n" + "vst1.32 {d0-d1}, [%[outptr]]! @ write to output ptr\n" + "vst1.32 {d2-d3}, [%[outptr]]! @ write to output ptr\n" + "vst1.32 {d4-d5}, [%[outptr]]! 
@ write to output ptr\n" + "vst1.32 {d6-d7}, [%[outptr]]! @ write to output ptr\n" + : [outptr] "+r"(ptr_out), + [ptr0] "+r"(ptr0), + [ptr1] "+r"(ptr1), + [ptr2] "+r"(ptr2), + [ptr3] "+r"(ptr3) + : [vmask1] "w"(vmask1), + [vzero] "w"(vzero), + [has_alpha] "r"(has_alpha), + [alpha] "w"(valpha) + : "q0", "q1", "q2", "q3", "cc", "memory"); + } + } + +#pragma omp parallel for + for (int y = 4 * (y_len / 4); y < y_len; ++y) { + const float* ptr0 = inptr + y * ldin; + float* outptr_row_col = outptr + y * 4; + int i = 0; + for (; i < x_len - 3; i += 4) { + float* ptr_out = outptr_row_col; + asm volatile( + "vld1.32 {d0-d1}, [%[ptr0]]! @ load r0, 4 elements\n" + "cmp %[has_alpha], #0\n" + "beq 0f\n" /* check whether alpha == 1? */ + "vmul.f32 q0, q0, %q[alpha]\n" /* mul alpha */ + "0: \n" + "vst1.32 {d0-d1}, [%[outptr]]! @ write to output ptr\n" + : [ptr0] "+r"(ptr0), [outptr] "+r"(ptr_out) + : [has_alpha] "r"(has_alpha), [alpha] "w"(valpha) + : "q0", "q1", "cc", "memory"); + outptr_row_col += stride_out; + } + if (right_pad > 0) { + float* ptr_out = outptr_row_col; + asm volatile( + "vld1.32 {d0-d1}, [%[ptr0]]! @ load r0, 4 elements\n" + "cmp %[has_alpha], #0\n" + "beq 0f\n" /* check whether alpha == 1? */ + "vmul.f32 q0, q0, %q[alpha]\n" /* mul alpha */ + "0: \n" + "vbif q0, %q[vzero], %q[vmask1] @ bit select, pad zero\n" + "vst1.32 {d0-d1}, [%[outptr]]! @ write to output ptr\n" + : [ptr0] "+r"(ptr0), [outptr] "+r"(ptr_out) + : [vmask1] "w"(vmask1), + [vzero] "w"(vzero), + [has_alpha] "r"(has_alpha), + [alpha] "w"(valpha) + : "q0", "q1", "cc", "memory"); + } + } +} + +#endif // __aarch64__ + +/** +* \brief input data is transpose +* for arm-v7a, transform data to block x k x 8 layout +* for arm-v8a, transform data to block x k x 12 layout +*/ +#ifdef __aarch64__ +void loadb( + float *out, const float *in, int ldin, int k0, int kmax, int n0, int nmax) { + auto outptr = reinterpret_cast(out); + auto inptr = reinterpret_cast(in) + k0 * ldin + n0; + uint32_t mask_buffer[12] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}; + int x_len = nmax - n0; + int y_len = kmax - k0; + int right_remain = x_len - 12 * (x_len / 12); + int right_pad = 12 - right_remain; + + uint32_t *outptr_row = outptr; + int stride_out = 12 * y_len; + + uint32x4_t vzero = vdupq_n_u32(0); + uint32x4_t vmask1 = + vcltq_u32(vld1q_u32(mask_buffer), vdupq_n_u32(right_remain)); + uint32x4_t vmask2 = + vcltq_u32(vld1q_u32(mask_buffer + 4), vdupq_n_u32(right_remain)); + uint32x4_t vmask3 = + vcltq_u32(vld1q_u32(mask_buffer + 8), vdupq_n_u32(right_remain)); + +#pragma omp parallel for + for (int y = 0; y < y_len - 3; y += 4) { + const uint32_t *ptr0 = inptr + y * ldin; + const uint32_t *ptr1 = ptr0 + ldin; + const uint32_t *ptr2 = ptr1 + ldin; + const uint32_t *ptr3 = ptr2 + ldin; + asm volatile( + "prfm pldl1keep, [%[ptr0]] \n" + "prfm pldl1keep, [%[ptr0], #64] \n" + "prfm pldl1keep, [%[ptr1]] \n" + "prfm pldl1keep, [%[ptr1], #64] \n" + "prfm pldl1keep, [%[ptr2]] \n" + "prfm pldl1keep, [%[ptr2], #64] \n" + "prfm pldl1keep, [%[ptr3]] \n" + "prfm pldl1keep, [%[ptr3], #64] \n" + : + : [ptr0] "r"(ptr0), [ptr1] "r"(ptr1), [ptr2] "r"(ptr2), [ptr3] "r"(ptr3) + : "memory"); + + uint32_t *outptr_row_col = outptr_row + y * 12; + + int i = 0; + for (; i < x_len - 11; i += 12) { + uint32x4_t vr00 = vld1q_u32(ptr0); + uint32x4_t vr01 = vld1q_u32(ptr0 + 4); + uint32x4_t vr02 = vld1q_u32(ptr0 + 8); + + uint32x4_t vr10 = vld1q_u32(ptr1); + uint32x4_t vr11 = vld1q_u32(ptr1 + 4); + uint32x4_t vr12 = vld1q_u32(ptr1 + 8); + + vst1q_u32(outptr_row_col, vr00); + 
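+      // (illustrative note) the next row's loads are interleaved with the
+      // previous row's stores to hide load latency; rows 0..3 start at
+      // offsets 0, 12, 24 and 36 of the 12-wide B panel.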
vst1q_u32(outptr_row_col + 4, vr01); + vst1q_u32(outptr_row_col + 8, vr02); + + uint32x4_t vr20 = vld1q_u32(ptr2); + uint32x4_t vr21 = vld1q_u32(ptr2 + 4); + uint32x4_t vr22 = vld1q_u32(ptr2 + 8); + + vst1q_u32(outptr_row_col + 12, vr10); + vst1q_u32(outptr_row_col + 16, vr11); + vst1q_u32(outptr_row_col + 20, vr12); + + uint32x4_t vr30 = vld1q_u32(ptr3); + uint32x4_t vr31 = vld1q_u32(ptr3 + 4); + uint32x4_t vr32 = vld1q_u32(ptr3 + 8); + + vst1q_u32(outptr_row_col + 24, vr20); + vst1q_u32(outptr_row_col + 28, vr21); + vst1q_u32(outptr_row_col + 32, vr22); + + vst1q_u32(outptr_row_col + 36, vr30); + vst1q_u32(outptr_row_col + 40, vr31); + vst1q_u32(outptr_row_col + 44, vr32); + + ptr0 += 12; + ptr1 += 12; + ptr2 += 12; + ptr3 += 12; + + outptr_row_col += stride_out; + } + if (right_remain > 0) { + uint32x4_t vr00 = vld1q_u32(ptr0); + uint32x4_t vr01 = vld1q_u32(ptr0 + 4); + uint32x4_t vr02 = vld1q_u32(ptr0 + 8); + + uint32x4_t vr10 = vld1q_u32(ptr1); + uint32x4_t vr11 = vld1q_u32(ptr1 + 4); + uint32x4_t vr12 = vld1q_u32(ptr1 + 8); + + uint32x4_t vr00_1 = vbslq_u32(vmask1, vr00, vzero); + uint32x4_t vr01_1 = vbslq_u32(vmask2, vr01, vzero); + uint32x4_t vr02_1 = vbslq_u32(vmask3, vr02, vzero); + + uint32x4_t vr20 = vld1q_u32(ptr2); + uint32x4_t vr21 = vld1q_u32(ptr2 + 4); + uint32x4_t vr22 = vld1q_u32(ptr2 + 8); + + vst1q_u32(outptr_row_col, vr00_1); + vst1q_u32(outptr_row_col + 4, vr01_1); + vst1q_u32(outptr_row_col + 8, vr02_1); + + uint32x4_t vr10_1 = vbslq_u32(vmask1, vr10, vzero); + uint32x4_t vr11_1 = vbslq_u32(vmask2, vr11, vzero); + uint32x4_t vr12_1 = vbslq_u32(vmask3, vr12, vzero); + + uint32x4_t vr30 = vld1q_u32(ptr3); + uint32x4_t vr31 = vld1q_u32(ptr3 + 4); + uint32x4_t vr32 = vld1q_u32(ptr3 + 8); + + vst1q_u32(outptr_row_col + 12, vr10_1); + vst1q_u32(outptr_row_col + 16, vr11_1); + vst1q_u32(outptr_row_col + 20, vr12_1); + + uint32x4_t vr20_1 = vbslq_u32(vmask1, vr20, vzero); + uint32x4_t vr21_1 = vbslq_u32(vmask2, vr21, vzero); + uint32x4_t vr22_1 = vbslq_u32(vmask3, vr22, vzero); + + uint32x4_t vr30_1 = vbslq_u32(vmask1, vr30, vzero); + uint32x4_t vr31_1 = vbslq_u32(vmask2, vr31, vzero); + uint32x4_t vr32_1 = vbslq_u32(vmask3, vr32, vzero); + + vst1q_u32(outptr_row_col + 24, vr20_1); + vst1q_u32(outptr_row_col + 28, vr21_1); + vst1q_u32(outptr_row_col + 32, vr22_1); + + vst1q_u32(outptr_row_col + 36, vr30_1); + vst1q_u32(outptr_row_col + 40, vr31_1); + vst1q_u32(outptr_row_col + 44, vr32_1); + } + } + +#pragma omp parallel for + for (int y = 4 * (y_len / 4); y < y_len; ++y) { + const uint32_t *ptr0 = inptr + y * ldin; + uint32_t *outptr_row_col = outptr_row + y * 12; + + int i = 0; + for (; i < x_len - 11; i += 12) { + uint32x4_t vr0 = vld1q_u32(ptr0); + uint32x4_t vr1 = vld1q_u32(ptr0 + 4); + uint32x4_t vr2 = vld1q_u32(ptr0 + 8); + vst1q_u32(outptr_row_col, vr0); + vst1q_u32(outptr_row_col + 4, vr1); + vst1q_u32(outptr_row_col + 8, vr2); + + ptr0 += 12; + + outptr_row_col += stride_out; + } + if (right_remain > 0) { + uint32x4_t vr0 = vld1q_u32(ptr0); + uint32x4_t vr1 = vld1q_u32(ptr0 + 4); + uint32x4_t vr2 = vld1q_u32(ptr0 + 8); + + uint32x4_t vr0_1 = vbslq_u32(vmask1, vr0, vzero); + uint32x4_t vr1_1 = vbslq_u32(vmask2, vr1, vzero); + uint32x4_t vr2_1 = vbslq_u32(vmask3, vr2, vzero); + + vst1q_u32(outptr_row_col, vr0_1); + vst1q_u32(outptr_row_col + 4, vr1_1); + vst1q_u32(outptr_row_col + 8, vr2_1); + } + } +} + +void loadb_trans( + float *out, const float *in, int ldin, int k0, int kmax, int n0, int nmax) { + int x_len = kmax - k0; + uint32_t zerobuff[x_len]; // NOLINT + 
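+  // (illustrative note) zerobuff stands in for B rows past nmax, so the
+  // 12-row transpose below can run unconditionally; out-of-range row
+  // pointers are redirected to it by the switch further down.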
memset(zerobuff, 0, sizeof(uint32_t) * x_len); + auto outptr = reinterpret_cast(out); + auto inptr = reinterpret_cast(in); + + //! data B is not transposed, transpose B to k * 12 + for (int y = n0; y < nmax; y += 12) { + const uint32_t *inptr0 = inptr + y * ldin + k0; + const uint32_t *inptr1 = inptr0 + ldin; + const uint32_t *inptr2 = inptr1 + ldin; + const uint32_t *inptr3 = inptr2 + ldin; + const uint32_t *inptr4 = inptr3 + ldin; + const uint32_t *inptr5 = inptr4 + ldin; + const uint32_t *inptr6 = inptr5 + ldin; + const uint32_t *inptr7 = inptr6 + ldin; + const uint32_t *inptr8 = inptr7 + ldin; + const uint32_t *inptr9 = inptr8 + ldin; + const uint32_t *inptr10 = inptr9 + ldin; + const uint32_t *inptr11 = inptr10 + ldin; + + asm volatile( + "prfm pldl1keep, [%[ptr0]] \n" + "prfm pldl1keep, [%[ptr0], #64] \n" + "prfm pldl1keep, [%[ptr1]] \n" + "prfm pldl1keep, [%[ptr1], #64] \n" + "prfm pldl1keep, [%[ptr2]] \n" + "prfm pldl1keep, [%[ptr2], #64] \n" + "prfm pldl1keep, [%[ptr3]] \n" + "prfm pldl1keep, [%[ptr3], #64] \n" + "prfm pldl1keep, [%[ptr4]] \n" + "prfm pldl1keep, [%[ptr4], #64] \n" + "prfm pldl1keep, [%[ptr5]] \n" + "prfm pldl1keep, [%[ptr5], #64] \n" + "prfm pldl1keep, [%[ptr6]] \n" + "prfm pldl1keep, [%[ptr6], #64] \n" + "prfm pldl1keep, [%[ptr7]] \n" + "prfm pldl1keep, [%[ptr7], #64] \n" + "prfm pldl1keep, [%[ptr8]] \n" + "prfm pldl1keep, [%[ptr8], #64] \n" + "prfm pldl1keep, [%[ptr9]] \n" + "prfm pldl1keep, [%[ptr9], #64] \n" + "prfm pldl1keep, [%[ptr10]] \n" + "prfm pldl1keep, [%[ptr10], #64] \n" + "prfm pldl1keep, [%[ptr11]] \n" + "prfm pldl1keep, [%[ptr11], #64] \n" + : + : [ptr0] "r"(inptr0), + [ptr1] "r"(inptr1), + [ptr2] "r"(inptr2), + [ptr3] "r"(inptr3), + [ptr4] "r"(inptr4), + [ptr5] "r"(inptr5), + [ptr6] "r"(inptr6), + [ptr7] "r"(inptr7), + [ptr8] "r"(inptr8), + [ptr9] "r"(inptr9), + [ptr10] "r"(inptr10), + [ptr11] "r"(inptr11) + : "memory"); + + int x = x_len; + + //! 
cope with row index exceed real size, set to zero buffer + if ((y + 11) >= nmax) { + switch ((y + 11) - nmax) { + case 10: + inptr1 = zerobuff; + case 9: + inptr2 = zerobuff; + case 8: + inptr3 = zerobuff; + case 7: + inptr4 = zerobuff; + case 6: + inptr5 = zerobuff; + case 5: + inptr6 = zerobuff; + case 4: + inptr7 = zerobuff; + case 3: + inptr8 = zerobuff; + case 2: + inptr9 = zerobuff; + case 1: + inptr10 = zerobuff; + case 0: + inptr11 = zerobuff; + default: + break; + } + } + for (; x > 7; x -= 8) { + asm volatile( + "ldp q0, q1, [%[inptr0]], #32\n" /* r0, a0~a7 */ + "ldp q2, q3, [%[inptr1]], #32\n" /* r1, b0~b7 */ + "ldp q4, q5, [%[inptr2]], #32\n" /* r2, c0~c7 */ + "ldp q6, q7, [%[inptr3]], #32\n" /* r3, d0~d7 */ + + "zip1 v16.4s, v0.4s, v4.4s\n" /* a0c0a1c1 */ + "zip1 v17.4s, v2.4s, v6.4s\n" /* b0d0b1d1 */ + "prfm pldl1keep, [%[inptr0], #128] \n" + + "ldp q8, q9, [%[inptr4]], #32\n" /* r4, e0~e7 */ + "ldp q10, q11, [%[inptr5]], #32\n" /* r5, f0~f7 */ + "ldp q12, q13, [%[inptr6]], #32\n" /* r6, g0~g7 */ + "ldp q14, q15, [%[inptr7]], #32\n" /* r7, h0~h7 */ + + "zip1 v18.4s, v8.4s, v12.4s\n" /* e0g0e1g1 */ + "zip1 v19.4s, v10.4s, v14.4s\n" /* f0h0f1h1 */ + "prfm pldl1keep, [%[inptr1], #128]\n" + "zip1 v20.4s, v16.4s, v17.4s\n" /* a0b0c0d0 */ + "zip1 v21.4s, v18.4s, v19.4s\n" /* e0f0g0h0 */ + "prfm pldl1keep, [%[inptr2], #128]\n" + "zip2 v22.4s, v16.4s, v17.4s\n" /* a1b1c1d1 */ + "zip2 v23.4s, v18.4s, v19.4s\n" /* e1f1g1h1 */ + + "ldp q24, q25, [%[inptr8]], #32\n" /* r8, i0~i7 */ + "ldp q26, q27, [%[inptr9]], #32\n" /* r9, j0~j7 */ + "ldp q28, q29, [%[inptr10]], #32\n" /* r10, k0~k7 */ + "ldp q30, q31, [%[inptr11]], #32\n" /* r11, l0~l7 */ + + "stp q20, q21, [%[outptr]], #32\n" /* save a0~h0 */ + "prfm pldl1keep, [%[inptr3], #128]\n" + + "zip1 v16.4s, v24.4s, v28.4s\n" /* i0k0i1k1 */ + "zip1 v17.4s, v26.4s, v30.4s\n" /* j0l0j1l1 */ + "prfm pldl1keep, [%[inptr4], #128]\n" + "zip1 v18.4s, v16.4s, v17.4s\n" /* i0j0k0l0 */ + "zip2 v19.4s, v16.4s, v17.4s\n" /* i1j1k1l1 */ + "prfm pldl1keep, [%[inptr5], #128]\n" + "zip2 v16.4s, v0.4s, v4.4s\n" /* a2c2a3c3 */ + "zip2 v17.4s, v2.4s, v6.4s\n" /* b2d2b3d3 */ + + "str q18, [%[outptr]], #16\n" /* save j0~l0 */ + "stp q22, q23, [%[outptr]], #32\n" /* save a1~h1 */ + "str q19, [%[outptr]], #16\n" /* save j1~l1 */ + + "zip2 v18.4s, v8.4s, v12.4s\n" /* e2g2e3g3 */ + "zip2 v19.4s, v10.4s, v14.4s\n" /* f2h2f3h3 */ + "prfm pldl1keep, [%[inptr6], #128]\n" + "zip1 v20.4s, v16.4s, v17.4s\n" /* a2b2c2d2 */ + "zip1 v21.4s, v18.4s, v19.4s\n" /* e2f2g2h2 */ + "prfm pldl1keep, [%[inptr7], #128]\n" + "zip2 v22.4s, v16.4s, v17.4s\n" /* a3b3c3d3 */ + "zip2 v23.4s, v18.4s, v19.4s\n" /* e3f3g3h3 */ + "prfm pldl1keep, [%[inptr8], #128]\n" + "zip2 v16.4s, v24.4s, v28.4s\n" /* i2k2i3k3 */ + "zip2 v17.4s, v26.4s, v30.4s\n" /* j2l2j3l3 */ + + "stp q20, q21, [%[outptr]], #32\n" /* save a2~h2 */ + + "zip1 v18.4s, v16.4s, v17.4s\n" /* i2j2k2l2 */ + "zip2 v19.4s, v16.4s, v17.4s\n" /* i3j3k3l3 */ + "prfm pldl1keep, [%[inptr9], #128]\n" + "zip1 v16.4s, v1.4s, v5.4s\n" /* a4c4a5c5 */ + "zip1 v17.4s, v3.4s, v7.4s\n" /* b4d4b5d5 */ + + "str q18, [%[outptr]], #16\n" /* save i2~l2 */ + "stp q22, q23, [%[outptr]], #32\n" /* save a3~h3 */ + "str q19, [%[outptr]], #16\n" /* save i3~l3 */ + + "zip1 v18.4s, v9.4s, v13.4s\n" /* e4g4e5g5 */ + "zip1 v19.4s, v11.4s, v15.4s\n" /* f4h4f5h5 */ + "prfm pldl1keep, [%[inptr10], #128]\n" + "zip1 v20.4s, v16.4s, v17.4s\n" /* a4b4c4d4 */ + "zip1 v21.4s, v18.4s, v19.4s\n" /* e4f4g4h4 */ + "prfm pldl1keep, [%[inptr11], #128]\n" + "zip2 v22.4s, v16.4s, 
v17.4s\n" /* a5b5c5d5 */ + "zip2 v23.4s, v18.4s, v19.4s\n" /* e5f5g5h5 */ + "zip1 v16.4s, v25.4s, v29.4s\n" /* i4k4i5k5 */ + "zip1 v17.4s, v27.4s, v31.4s\n" /* j4l4j5l5 */ + + "stp q20, q21, [%[outptr]], #32\n" /* save a4~h4 */ + + "zip1 v18.4s, v16.4s, v17.4s\n" /* i4j4k4l4 */ + "zip2 v19.4s, v16.4s, v17.4s\n" /* i5j5k5l5 */ + "zip2 v16.4s, v1.4s, v5.4s\n" /* a6c6a7c7 */ + "zip2 v17.4s, v3.4s, v7.4s\n" /* b6d6b7d7 */ + + "str q18, [%[outptr]], #16\n" /* save i4~l4 */ + "stp q22, q23, [%[outptr]], #32\n" /* save a5~h5 */ + "str q19, [%[outptr]], #16\n" /* save i5~l5 */ + + "zip2 v18.4s, v9.4s, v13.4s\n" /* e6g6e7g7 */ + "zip2 v19.4s, v11.4s, v15.4s\n" /* f6h6f7h7 */ + "zip1 v20.4s, v16.4s, v17.4s\n" /* a6b6c6d6 */ + "zip1 v21.4s, v18.4s, v19.4s\n" /* e6f6g6h6 */ + "zip2 v22.4s, v16.4s, v17.4s\n" /* a7b7c7d7 */ + "zip2 v23.4s, v18.4s, v19.4s\n" /* e7f7g7h7 */ + "zip2 v16.4s, v25.4s, v29.4s\n" /* i6k6i7k7 */ + "zip2 v17.4s, v27.4s, v31.4s\n" /* j6l6j7l7 */ + + "stp q20, q21, [%[outptr]], #32\n" /* save a6~h6 */ + + "zip1 v18.4s, v16.4s, v17.4s\n" /* i6j6k6l6 */ + "zip2 v19.4s, v16.4s, v17.4s\n" /* i7j7k7l7 */ + + "str q18, [%[outptr]], #16\n" /* save i6~l6 */ + "stp q22, q23, [%[outptr]], #32\n" /* save a7~h7 */ + "str q19, [%[outptr]], #16\n" /* save i7~l7 */ + : [inptr0] "+r"(inptr0), + [inptr1] "+r"(inptr1), + [inptr2] "+r"(inptr2), + [inptr3] "+r"(inptr3), + [inptr4] "+r"(inptr4), + [inptr5] "+r"(inptr5), + [inptr6] "+r"(inptr6), + [inptr7] "+r"(inptr7), + [inptr8] "+r"(inptr8), + [inptr9] "+r"(inptr9), + [inptr10] "+r"(inptr10), + [inptr11] "+r"(inptr11), + [outptr] "+r"(outptr) + : + : "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15", + "v16", + "v17", + "v18", + "v19", + "v20", + "v21", + "v22", + "v23", + "v24", + "v25", + "v26", + "v27", + "v28", + "v29", + "v30", + "v31", + "cc", + "memory"); + } + + for (; x > 0; x--) { + *outptr++ = *inptr0++; + *outptr++ = *inptr1++; + *outptr++ = *inptr2++; + *outptr++ = *inptr3++; + *outptr++ = *inptr4++; + *outptr++ = *inptr5++; + *outptr++ = *inptr6++; + *outptr++ = *inptr7++; + *outptr++ = *inptr8++; + *outptr++ = *inptr9++; + *outptr++ = *inptr10++; + *outptr++ = *inptr11++; + } + } +} + +#else // __aarch64__ +void loadb( + float* out, const float* in, int ldin, int k0, int kmax, int n0, int nmax) { + auto outptr = reinterpret_cast(out); + auto inptr = reinterpret_cast(in) + k0 * ldin + n0; + uint32_t mask_buffer[8] = {0, 1, 2, 3, 4, 5, 6, 7}; + int x_len = nmax - n0; + int y_len = kmax - k0; + int right_remain = x_len - 8 * (x_len / 8); + int right_pad = 8 - right_remain; + + uint32_t* outptr_row = outptr; + int stride_out = 8 * y_len; + + uint32x4_t vzero = vdupq_n_u32(0); + uint32x4_t vmask1 = + vcltq_u32(vld1q_u32(mask_buffer), vdupq_n_u32(right_remain)); + uint32x4_t vmask2 = + vcltq_u32(vld1q_u32(mask_buffer + 4), vdupq_n_u32(right_remain)); + +#pragma omp parallel for + for (int y = 0; y < y_len - 3; y += 4) { + const uint32_t* ptr0 = inptr + y * ldin; + const uint32_t* ptr1 = ptr0 + ldin; + const uint32_t* ptr2 = ptr1 + ldin; + const uint32_t* ptr3 = ptr2 + ldin; + uint32_t* outptr_row_col = outptr_row + y * 8; + int i = 0; + for (; i < x_len - 7; i += 8) { + uint32_t* ptr_out = outptr_row_col; + asm volatile( + "vld1.32 {d0-d3}, [%[ptr0]]! @ load r0, 8 elements\n" + "vld1.32 {d4-d7}, [%[ptr1]]! @ load r1, 8 elements\n" + "vst1.32 {d0-d3}, [%[outptr]]! @ write to output ptr\n" + "vst1.32 {d4-d7}, [%[outptr]]! 
@ write to output ptr\n" + + "vld1.32 {d0-d3}, [%[ptr2]]! @ load r2, 8 elements\n" + "vld1.32 {d4-d7}, [%[ptr3]]! @ load r3, 8 elements\n" + "vst1.32 {d0-d3}, [%[outptr]]! @ write to output ptr\n" + "vst1.32 {d4-d7}, [%[outptr]]! @ write to output ptr\n" + : [outptr] "+r"(ptr_out), + [ptr0] "+r"(ptr0), + [ptr1] "+r"(ptr1), + [ptr2] "+r"(ptr2), + [ptr3] "+r"(ptr3) + : + : "q0", "q1", "q2", "q3", "cc", "memory"); + outptr_row_col += stride_out; + } + if (right_remain > 0) { + uint32_t* ptr_out = outptr_row_col; + asm volatile( + "vld1.32 {d0-d3}, [%[ptr0]]! @ load r0, 8 elements\n" + "vld1.32 {d4-d7}, [%[ptr1]]! @ load r1, 8 elements\n" + "vbif q0, %q[vzero], %q[vmask1] @ bit select, pad zero\n" + "vbif q1, %q[vzero], %q[vmask2] @ bit select, pad zero\n" + //"vst1.32 {d0-d3}, [%[outptr]]! @ write to output ptr\n" + "vbif q2, %q[vzero], %q[vmask1] @ bit select, pad zero\n" + "vbif q3, %q[vzero], %q[vmask2] @ bit select, pad zero\n" + "vst1.32 {d0-d3}, [%[outptr]]! @ write to output ptr\n" + "vst1.32 {d4-d7}, [%[outptr]]! @ write to output ptr\n" + + "vld1.32 {d0-d3}, [%[ptr2]]! @ load r2, 8 elements\n" + "vld1.32 {d4-d7}, [%[ptr3]]! @ load r3, 8 elements\n" + "vbif q0, %q[vzero], %q[vmask1] @ bit select, pad zero\n" + "vbif q1, %q[vzero], %q[vmask2] @ bit select, pad zero\n" + //"vst1.32 {d0-d3}, [%[outptr]]! @ write to output ptr\n" + "vbif q2, %q[vzero], %q[vmask1] @ bit select, pad zero\n" + "vbif q3, %q[vzero], %q[vmask2] @ bit select, pad zero\n" + "vst1.32 {d0-d3}, [%[outptr]]! @ write to output ptr\n" + "vst1.32 {d4-d7}, [%[outptr]]! @ write to output ptr\n" + : [outptr] "+r"(ptr_out), + [ptr0] "+r"(ptr0), + [ptr1] "+r"(ptr1), + [ptr2] "+r"(ptr2), + [ptr3] "+r"(ptr3) + : [vmask1] "w"(vmask1), [vmask2] "w"(vmask2), [vzero] "w"(vzero) + : "q0", "q1", "q2", "q3", "cc", "memory"); + } + } +#pragma omp parallel for + for (int y = 4 * (y_len / 4); y < y_len; ++y) { + const uint32_t* ptr0 = inptr + y * ldin; + uint32_t* outptr_row_col = outptr_row + y * 8; + int i = 0; + for (; i < x_len - 7; i += 8) { + uint32_t* ptr_out = outptr_row_col; + asm volatile( + "vld1.32 {d0-d3}, [%[ptr0]]! @ load r0, 8 elements\n" + "vst1.32 {d0-d3}, [%[outptr]]! @ write to output ptr\n" + : [ptr0] "+r"(ptr0), [outptr] "+r"(ptr_out) + : + : "q0", "q1", "cc", "memory"); + outptr_row_col += stride_out; + } + if (right_remain > 0) { + uint32_t* ptr_out = outptr_row_col; + asm volatile( + "vld1.32 {d0-d3}, [%[ptr0]]! @ load r0, 8 elements\n" + "vbif q0, %q[vzero], %q[vmask1] @ bit select, pad zero\n" + "vbif q1, %q[vzero], %q[vmask2] @ bit select, pad zero\n" + "vst1.32 {d0-d3}, [%[outptr]]! @ write to output ptr\n" + : [ptr0] "+r"(ptr0), [outptr] "+r"(ptr_out) + : [vmask1] "w"(vmask1), [vmask2] "w"(vmask2), [vzero] "w"(vzero) + : "q0", "q1", "cc", "memory"); + } + } +} + +void loadb_trans( + float* out, const float* in, int ldin, int k0, int kmax, int n0, int nmax) { + int x_len = kmax - k0; + uint32_t zerobuff[x_len]; // NOLINT + memset(zerobuff, 0, sizeof(uint32_t) * x_len); + + auto outptr = reinterpret_cast(out); + auto inptr = reinterpret_cast(in); + //! data B is not transposed, transpose B to k * 8 + for (int y = n0; y < nmax; y += 8) { + const uint32_t* inptr0 = inptr + y * ldin + k0; + const uint32_t* inptr1 = inptr0 + ldin; + const uint32_t* inptr2 = inptr1 + ldin; + const uint32_t* inptr3 = inptr2 + ldin; + const uint32_t* inptr4 = inptr3 + ldin; + const uint32_t* inptr5 = inptr4 + ldin; + const uint32_t* inptr6 = inptr5 + ldin; + const uint32_t* inptr7 = inptr6 + ldin; + + int x = x_len; + + //! 
cope with row index exceed real size, set to zero buffer + if ((y + 7) >= nmax) { + switch ((y + 7) - nmax) { + case 6: + inptr1 = zerobuff; + case 5: + inptr2 = zerobuff; + case 4: + inptr3 = zerobuff; + case 3: + inptr4 = zerobuff; + case 2: + inptr5 = zerobuff; + case 1: + inptr6 = zerobuff; + case 0: + inptr7 = zerobuff; + default: + break; + } + } + + for (; x > 7; x -= 8) { + //! zip load 8 elements (2 neon Q registers) from each of 8 rows + asm volatile( + "vld4.32 {d0-d3}, [%[inptr0]]! @ zip load r0, " + "q0,q1=r00,r04,r01,r05,r02,r06,r03,r07\n" + "vld4.32 {d4-d7}, [%[inptr1]]! @ zip load r1, " + "q2,q3=r10,r14,r11,r15,r12,r16,r13,r17\n" + "vtrn.32 q0, q2 @ trans data: q0=r00,r10,r01,r11; " + "q2=r04,r14,r05,r15\n" + "vst1.32 {d0}, [%[outptr]]! @ write d0(q0,low),r00,r10\n" + + "vld4.32 {d8-d11}, [%[inptr2]]! @ zip load r2, " + "q4,q5=r20,r24,r21,r25,r22,r26,r23,r27\n" + "vld4.32 {d12-d15}, [%[inptr3]]! @ zip load r3, " + "q6,q7=r30,r34,r31,r35,r32,r36,r33,r37\n" + "vtrn.32 q4, q6 @ trans data: q4=r20,r30,r21,r31; " + "q6=r24,r34,r25,r35\n" + "vst1.32 {d8}, [%[outptr]]! @ write d8(q4,low),r20,r30\n" + + "vld4.32 {d16-d19}, [%[inptr4]]! @ zip load r4, " + "q8,q9=r40,r44,r41,r45,r42,r46,r43,r47\n" + "vld4.32 {d20-d23}, [%[inptr5]]! @ zip load r5, " + "q10,q11=r50,r54,r51,r55,r52,r56,r53,r57\n" + "vtrn.32 q8, q10 @ trans data: q8=r40,r50,r41,r51; " + "q10=r44,r54,r45,r55\n" + "vst1.32 {d16}, [%[outptr]]! @ write d16(q8,low),r40,r50\n" + + "vld4.32 {d24-d27}, [%[inptr6]]! @ zip load r6, " + "q12,q13=r60,r64,r61,r65,r62,r66,r63,r67\n" + "vld4.32 {d28-d31}, [%[inptr7]]! @ zip load r7, " + "q14,q15=r70,r74,r71,r75,r72,r76,r73,r77\n" + "vtrn.32 q12, q14 @ trans data:q12=r60,r70,r61,r71; " + "q14=r64,r74,r65,r75\n" + "vst1.32 {d24}, [%[outptr]]! @ write d24(q8,low),r60,r70\n" + + //"pld [%[inptr0], #128] @ preload r0 data to cache, fill + // pipeline\n" + "vst1.32 {d1}, [%[outptr]]! @ write d1(q0,high),r01,r11\n" + "vst1.32 {d9}, [%[outptr]]! @ write d9(q4,high),r21,r31\n" + "vst1.32 {d17}, [%[outptr]]! @ write d17(q8,high),r41,r51\n" + "vst1.32 {d25}, [%[outptr]]! @ write d25(q12,high),r61,r71\n" + + "vtrn.32 q1, q3 @ trans data: q1=r02,r12,r03,r13; " + "q3=r06,r16,r07,r17\n" + "vst1.32 {d2}, [%[outptr]]! @ write d2(q1,low),r02,r12\n" + "vtrn.32 q5, q7 @ trans data: q5=r22,r32,r23,r33; " + "q7=r26,r36,r27,r37\n" + "vst1.32 {d10}, [%[outptr]]! @ write d10(q5,low),r22,r32\n" + "vtrn.32 q9, q11 @ trans data: q9=r42,r52,r43,r53; " + "q11=r46,r56,r47,r57\n" + "vst1.32 {d18}, [%[outptr]]! @ write d18(q9,low),r42,r52\n" + "vtrn.32 q13, q15 @ trans data:q13=r62,r72,r63,r73; " + "q15=r66,r76,r67,r77\n" + "vst1.32 {d26}, [%[outptr]]! @ write d18(q9,low),r62,r72\n" + + //"pld [%[inptr1], #128] @ preload r1 data to cache, fill + // pipeline\n" + "vst1.32 {d3}, [%[outptr]]! @ write d3(q1,high),r03,r13\n" + "vst1.32 {d11}, [%[outptr]]! @ write d11(q5,high),r23,r33\n" + "vst1.32 {d19}, [%[outptr]]! @ write d19(q9,high),r43,r53\n" + "vst1.32 {d27}, [%[outptr]]! @ write d27(q13,high),r63,r73\n" + + //"pld [%[inptr2], #128] @ preload r2 data to cache, fill + // pipeline\n" + "vst1.32 {d4}, [%[outptr]]! @ write d4(q2,low),r04,r14\n" + "vst1.32 {d12}, [%[outptr]]! @ write d12(q6,low),r24,r34\n" + "vst1.32 {d20}, [%[outptr]]! @ write d20(q10,low),r44,r54\n" + "vst1.32 {d28}, [%[outptr]]! @ write d28(q14,low),r64,r74\n" + + //"pld [%[inptr3], #128] @ preload r3 data to cache, fill + // pipeline\n" + "vst1.32 {d5}, [%[outptr]]! @ write d5(q2,high),r05,r15\n" + "vst1.32 {d13}, [%[outptr]]! 
@ write d13(q6,high),r25,r35\n" + "vst1.32 {d21}, [%[outptr]]! @ write d21(q10,high),r45,r55\n" + "vst1.32 {d29}, [%[outptr]]! @ write d29(q14,high),r65,r75\n" + + //"pld [%[inptr4], #128] @ preload r4 data to cache, fill + // pipeline\n" + "vst1.32 {d6}, [%[outptr]]! @ write d6(q3,low),r06,r16\n" + "vst1.32 {d14}, [%[outptr]]! @ write d14(q7,low),r26,r36\n" + "vst1.32 {d22}, [%[outptr]]! @ write d22(q11,low),r46,r56\n" + "vst1.32 {d30}, [%[outptr]]! @ write d30(q15,low),r66,r76\n" + + //"pld [%[inptr5], #128] @ preload r5 data to cache, fill + // pipeline\n" + "vst1.32 {d7}, [%[outptr]]! @ write d7(q3,high),r07,r17\n" + "vst1.32 {d15}, [%[outptr]]! @ write d15(q7,high),r27,r37\n" + "vst1.32 {d23}, [%[outptr]]! @ write d23(q11,high),r47,r57\n" + "vst1.32 {d31}, [%[outptr]]! @ write d31(q15,high),r67,r77\n" + : [inptr0] "+r"(inptr0), + [inptr1] "+r"(inptr1), + [inptr2] "+r"(inptr2), + [inptr3] "+r"(inptr3), + [inptr4] "+r"(inptr4), + [inptr5] "+r"(inptr5), + [inptr6] "+r"(inptr6), + [inptr7] "+r"(inptr7), + [outptr] "+r"(outptr) + : + : "q0", + "q1", + "q2", + "q3", + "q4", + "q5", + "q6", + "q7", + "q8", + "q9", + "q10", + "q11", + "q12", + "q13", + "q14", + "q15", + "cc", + "memory"); + } + + for (; x > 0; x--) { + *outptr++ = *inptr0++; + *outptr++ = *inptr1++; + *outptr++ = *inptr2++; + *outptr++ = *inptr3++; + *outptr++ = *inptr4++; + *outptr++ = *inptr5++; + *outptr++ = *inptr6++; + *outptr++ = *inptr7++; + } + } +} + +#endif // __aarch64__ + +#ifdef __aarch64__ +void sgemm_prepacked_8x12(bool is_transB, + int M, + int N, + int K, + const float *A_packed, + const float *B, + int ldb, + float beta, + float *C, + int ldc, + const float *bias, + bool has_bias, + bool has_relu, + ARMContext *ctx) { + size_t l2_cache = ctx->llc_size() > 0 ? ctx->llc_size() : 512 * 1024; + auto workspace = ctx->workspace_data(); + int threads = ctx->threads(); + //! MBLOCK * x (result) + MBLOCK * k (A) + x * k (B) = l2 + int x_block = (l2_cache - (MBLOCK * K)) / (sizeof(float) * (K + MBLOCK)); + x_block /= NBLOCK; + x_block *= NBLOCK; + int x_num = (N + (x_block - 1)) / x_block; + x_block = (N + x_num - 1) / x_num; + x_block = (x_block + NBLOCK - 1) / NBLOCK; + x_block *= NBLOCK; + x_block = x_block < NBLOCK ? NBLOCK : x_block; + + // unroll 2 loop + int tail_pre = (K & (KBLOCK - 1)); + int k_pre = ((K + KBLOCK - 1) / KBLOCK) - 1; + + bool flag_p_remain = false; + int remain = 0; + + int has_beta = fabsf(beta) > 1e-8f ? 1 : 0; + + //! apanel is pre_compute outside gemm + for (unsigned int x0 = 0; x0 < N; x0 += x_block) { + unsigned int xmax = x0 + x_block; + if (xmax > N) { + xmax = N; + } + int bblocks = (xmax - x0 + NBLOCK - 1) / NBLOCK; + remain = xmax - x0 - (bblocks - 1) * NBLOCK; + if (remain > 0) { + flag_p_remain = true; + } + //! 
load bpanel + float *b_pannel = workspace; + if (is_transB) { + loadb_trans(b_pannel, B, ldb, 0, K, x0, xmax); + } else { + loadb(b_pannel, B, ldb, 0, K, x0, xmax); + } +#pragma omp parallel for num_threads(threads) + for (unsigned int y = 0; y < M; y += MBLOCK) { + unsigned int ymax = y + MBLOCK; + if (ymax > M) { + ymax = M; + } + + float bias_local[8] = {0}; + if (has_bias) { + bias_local[0] = bias[y]; + bias_local[1] = bias[y + 1]; + bias_local[2] = bias[y + 2]; + bias_local[3] = bias[y + 3]; + bias_local[4] = bias[y + 4]; + bias_local[5] = bias[y + 5]; + bias_local[6] = bias[y + 6]; + bias_local[7] = bias[y + 7]; + } + + float cout0[NBLOCK]; + float cout1[NBLOCK]; + float cout2[NBLOCK]; + float cout3[NBLOCK]; + float cout4[NBLOCK]; + float cout5[NBLOCK]; + float cout6[NBLOCK]; + float cout7[NBLOCK]; + + float *c_ptr0 = C + y * ldc + x0; + float *c_ptr1 = c_ptr0 + ldc; + float *c_ptr2 = c_ptr1 + ldc; + float *c_ptr3 = c_ptr2 + ldc; + float *c_ptr4 = c_ptr3 + ldc; + float *c_ptr5 = c_ptr4 + ldc; + float *c_ptr6 = c_ptr5 + ldc; + float *c_ptr7 = c_ptr6 + ldc; + + float *pout0 = c_ptr0; + float *pout1 = c_ptr1; + float *pout2 = c_ptr2; + float *pout3 = c_ptr3; + float *pout4 = c_ptr4; + float *pout5 = c_ptr5; + float *pout6 = c_ptr6; + float *pout7 = c_ptr7; + + const float *a_ptr_l = A_packed + y * K; + const float *b_ptr = b_pannel; + for (int xb = 0; xb < bblocks; xb++) { + if ((y + 7) >= ymax) { + switch ((y + 7) - ymax) { + case 6: + c_ptr1 = cout1; + case 5: + c_ptr2 = cout2; + case 4: + c_ptr3 = cout3; + case 3: + c_ptr4 = cout4; + case 2: + c_ptr5 = cout5; + case 1: + c_ptr6 = cout6; + case 0: + c_ptr7 = cout7; + default: + break; + } + } + if (flag_p_remain && (xb == bblocks - 1)) { + pout0 = c_ptr0; + pout1 = c_ptr1; + pout2 = c_ptr2; + pout3 = c_ptr3; + pout4 = c_ptr4; + pout5 = c_ptr5; + pout6 = c_ptr6; + pout7 = c_ptr7; + + c_ptr0 = cout0; + c_ptr1 = cout1; + c_ptr2 = cout2; + c_ptr3 = cout3; + c_ptr4 = cout4; + c_ptr5 = cout5; + c_ptr6 = cout6; + c_ptr7 = cout7; + if (has_beta) { + for (int i = 0; i < remain; ++i) { + cout0[i] = pout0[i]; + cout1[i] = pout1[i]; + cout2[i] = pout2[i]; + cout3[i] = pout3[i]; + cout4[i] = pout4[i]; + cout5[i] = pout5[i]; + cout6[i] = pout6[i]; + cout7[i] = pout7[i]; + } + } + } + const float *a_ptr = a_ptr_l; + int tail = tail_pre; + int k = k_pre; + + asm volatile( + "prfm pldl1keep, [%[a_ptr]]\n" /* preload a*/ + "ldp q2, q3, [%[bias_ptr]]\n" /* load bias to q2, q3*/ + "dup v8.4s, v2.s[0]\n" /* out0 = 0 */ + "dup v9.4s, v2.s[0]\n" /* out1 = 0*/ + "dup v10.4s, v2.s[0]\n" /* out2 = 0*/ + "prfm pldl1keep, [%[b_ptr]]\n" /* preload b*/ + "dup v11.4s, v2.s[1]\n" /* out3 = 0*/ + "dup v12.4s, v2.s[1]\n" /* out4 = 0*/ + "prfm pldl1keep, [%[b_ptr], #64]\n" /* preload b*/ + "dup v13.4s, v2.s[1]\n" /* out5 = 0*/ + "prfm pldl1keep, [%[a_ptr], #64]\n" /* preload a*/ + "dup v14.4s, v2.s[2]\n" /* out6 = 0*/ + "prfm pldl1keep, [%[b_ptr], #128]\n" /* preload b*/ + "dup v15.4s, v2.s[2]\n" /* out7 = 0*/ + "prfm pldl1keep, [%[a_ptr], #128]\n" /* preload a*/ + "dup v16.4s, v2.s[2]\n" /* out8 = 0*/ + "prfm pldl1keep, [%[b_ptr], #192]\n" /* preload b*/ + "dup v17.4s, v2.s[3]\n" /* out9 = 0*/ + "prfm pldl1keep, [%[b_ptr], #256]\n" /* preload b*/ + "dup v18.4s, v2.s[3]\n" /* out10 = 0*/ + "prfm pldl1keep, [%[a_ptr], #192]\n" /* preload a*/ + "dup v19.4s, v2.s[3]\n" /* out11 = 0*/ + "prfm pldl1keep, [%[b_ptr], #320]\n" /* preload b*/ + "dup v20.4s, v3.s[0]\n" /* out12 = 0*/ + "prfm pldl1keep, [%[a_ptr], #256]\n" /* preload a*/ + "dup v21.4s, v3.s[0]\n" /* out13 = 0*/ + 
"prfm pldl1keep, [%[b_ptr], #384]\n" /* preload b*/ + "dup v22.4s, v3.s[0]\n" /* out14 = 0*/ + "dup v23.4s, v3.s[1]\n" /* out15 = 0*/ + "dup v24.4s, v3.s[1]\n" /* out16 = 0*/ + "dup v25.4s, v3.s[1]\n" /* out17 = 0*/ + "dup v26.4s, v3.s[2]\n" /* out18 = 0*/ + "dup v27.4s, v3.s[2]\n" /* out19 = 0*/ + "dup v28.4s, v3.s[2]\n" /* out20 = 0*/ + "dup v29.4s, v3.s[3]\n" /* out21 = 0*/ + "dup v30.4s, v3.s[3]\n" /* out22 = 0*/ + "dup v31.4s, v3.s[3]\n" /* out23 = 0*/ + "cbz %w[has_beta], 0f\n" /* check beta == 0? */ + /* process beta */ + "dup v7.4s, %w[beta]\n" /* beta to vector */ + "ld1 {v0.4s, v1.4s, v2.4s}, [%[c_ptr0]]\n" /* load output r0 */ + "ld1 {v3.4s, v4.4s, v5.4s}, [%[c_ptr1]]\n" /* load output r1 */ + "fmla v8.4s, v0.4s, v7.4s\n" /* cr00 += beta * c_r00*/ + "fmla v9.4s, v1.4s, v7.4s\n" /* cr01 += beta * c_r01*/ + "fmla v10.4s, v2.4s, v7.4s\n" /* cr02 += beta * c_r02*/ + "ld1 {v0.4s, v1.4s, v2.4s}, [%[c_ptr2]]\n" /* load output r2*/ + "fmla v11.4s, v3.4s, v7.4s\n" /* cr10 += beta * c_r10*/ + "fmla v12.4s, v4.4s, v7.4s\n" /* cr11 += beta * c_r11*/ + "fmla v13.4s, v5.4s, v7.4s\n" /* cr12 += beta * c_r12*/ + "ld1 {v3.4s, v4.4s, v5.4s}, [%[c_ptr3]]\n" /* load output r3*/ + "fmla v14.4s, v0.4s, v7.4s\n" /* cr20 += beta * c_r20*/ + "fmla v15.4s, v1.4s, v7.4s\n" /* cr21 += beta * c_r21*/ + "fmla v16.4s, v2.4s, v7.4s\n" /* cr22 += beta * c_r22*/ + "ld1 {v0.4s, v1.4s, v2.4s}, [%[c_ptr4]]\n" /* load output r4*/ + "fmla v17.4s, v3.4s, v7.4s\n" /* cr30 += beta * c_r30*/ + "fmla v18.4s, v4.4s, v7.4s\n" /* cr31 += beta * c_r31*/ + "fmla v19.4s, v5.4s, v7.4s\n" /* cr32 += beta * c_r32*/ + "ld1 {v3.4s, v4.4s, v5.4s}, [%[c_ptr5]]\n" /* load output r5*/ + "fmla v20.4s, v0.4s, v7.4s\n" /* cr40 += beta * c_r40*/ + "fmla v21.4s, v1.4s, v7.4s\n" /* cr41 += beta * c_r41*/ + "fmla v22.4s, v2.4s, v7.4s\n" /* cr42 += beta * c_r42*/ + "ld1 {v0.4s, v1.4s, v2.4s}, [%[c_ptr6]]\n" /* load output r6*/ + "fmla v23.4s, v3.4s, v7.4s\n" /* cr50 += beta * c_r50*/ + "fmla v24.4s, v4.4s, v7.4s\n" /* cr51 += beta * c_r51*/ + "fmla v25.4s, v5.4s, v7.4s\n" /* cr52 += beta * c_r52*/ + "ld1 {v3.4s, v4.4s, v5.4s}, [%[c_ptr7]]\n" /* load output r7*/ + "fmla v26.4s, v0.4s, v7.4s\n" /* cr60 += beta * c_r60*/ + "fmla v27.4s, v1.4s, v7.4s\n" /* cr61 += beta * c_r61*/ + "fmla v28.4s, v2.4s, v7.4s\n" /* cr62 += beta * c_r62*/ + "fmla v29.4s, v3.4s, v7.4s\n" /* cr70 += beta * c_r70*/ + "fmla v30.4s, v4.4s, v7.4s\n" /* cr71 += beta * c_r71*/ + "fmla v31.4s, v5.4s, v7.4s\n" /* cr72 += beta * c_r72*/ + "0: \n" /* check loop count */ + "ldp q0, q1, [%[a_ptr]], #32\n" /* load a00,a01 to q0, q1*/ + "ldp q4, q5, [%[b_ptr]], #32\n" /* load b0, b1 to q4, q5*/ + "cbz %w[k], 2f\n" /* check loop count > 0 */ + /* main loop */ + /* unrool 0*/ + "1:\n" /* main loop */ + "fmla v8.4s , v4.4s, v0.s[0]\n" /* out0 = b0 * a00[0], b0 =q4 */ + "fmla v11.4s , v4.4s, v0.s[1]\n" /* out1 = b0 * a00[1], b0 =q4 + */ + "ldp q6, q7, [%[b_ptr]], #32\n" /* load b2, b0 to q6, q7 */ + "fmla v14.4s, v4.4s, v0.s[2]\n" /* out2 = b0 * a00[2], b0 =q4 */ + "fmla v17.4s, v4.4s, v0.s[3]\n" /* out3 = b0 * a00[3], b0 =q4 */ + "ldp q2, q3, [%[a_ptr]], #32\n" /* load a10, a11 to q3, q4 */ + "fmla v20.4s, v4.4s, v1.s[0]\n" /* out4 = b0 * a01[0], b0 =q4 */ + "fmla v23.4s, v4.4s, v1.s[1]\n" /* out5 = b0 * a01[1], b0 =q4 */ + "fmla v26.4s, v4.4s, v1.s[2]\n" /* out6 = b0 * a01[2], b0 =q4 */ + "fmla v29.4s, v4.4s, v1.s[3]\n" /* out7 = b0 * a01[3], b0 =q4 */ + + "fmla v9.4s, v5.4s, v0.s[0]\n" /* out8 = b1 * a00[0], b1 =q5 */ + "fmla v12.4s, v5.4s, v0.s[1]\n" /* out9 = b1 * a00[1], b1 =q5 */ 
+ "fmla v15.4s, v5.4s, v0.s[2]\n" /* out10 = b1 * a00[2], b1 =q5*/ + "fmla v18.4s, v5.4s, v0.s[3]\n" /* out11 = b1 * a00[3], b1 =q5*/ + "fmla v21.4s, v5.4s, v1.s[0]\n" /* out12 = b1 * a01[0], b1 =q5*/ + "fmla v24.4s, v5.4s, v1.s[1]\n" /* out13 = b1 * a01[1], b1 =q5*/ + "fmla v27.4s, v5.4s, v1.s[2]\n" /* out14 = b1 * a01[2], b1 =q5*/ + "fmla v30.4s, v5.4s, v1.s[3]\n" /* out15 = b1 * a01[3], b1 =q5*/ + + "ldp q4, q5, [%[b_ptr]], #32\n" /* load b1, b2 to q4, q5 */ + + "fmla v10.4s, v6.4s, v0.s[0]\n" /* out16 = b2 * a00[0], b2 =q6*/ + "fmla v13.4s, v6.4s, v0.s[1]\n" /* out17 = b2 * a00[1], b2 =q6*/ + "prfm pldl1keep, [%[b_ptr], #384]\n" + "fmla v16.4s, v6.4s, v0.s[2]\n" /* out18 = b2 * a00[2], b2 =q6*/ + "fmla v19.4s, v6.4s, v0.s[3]\n" /* out19 = b2 * a00[3], b2 =q6*/ + "fmla v22.4s, v6.4s, v1.s[0]\n" /* out20 = b2 * a00[0], b2 =q6*/ + "fmla v25.4s, v6.4s, v1.s[1]\n" /* out21 = b2 * a00[1], b2 =q6*/ + "fmla v28.4s, v6.4s, v1.s[2]\n" /* out22 = b2 * a00[2], b2 =q6*/ + "fmla v31.4s, v6.4s, v1.s[3]\n" /* out23 = b2 * a00[3], b2 =q6*/ + + "ldp q0, q1, [%[a_ptr]], #32\n" /* load a00, a01 to q0, q1 */ + + /* unrool 1 */ + "fmla v8.4s , v7.4s, v2.s[0]\n" /* out0 = b0 * a10[0], b0 =q7 */ + "fmla v11.4s , v7.4s, v2.s[1]\n" /* out1 = b0 * a10[1], b0 =q7 */ + "fmla v14.4s, v7.4s, v2.s[2]\n" /* out2 = b0 * a10[2], b0 =q7 */ + "prfm pldl1keep, [%[a_ptr], #256]\n" + "fmla v17.4s, v7.4s, v2.s[3]\n" /* out3 = b0 * a10[3], b0 =q7 */ + "fmla v20.4s, v7.4s, v3.s[0]\n" /* out4 = b0 * a11[0], b0 =q7 */ + "fmla v23.4s, v7.4s, v3.s[1]\n" /* out5 = b0 * a11[1], b0 = q7*/ + "fmla v26.4s, v7.4s, v3.s[2]\n" /* out6 = b0 * a11[2], b0 =q7 */ + "fmla v29.4s, v7.4s, v3.s[3]\n" /* out7 = b0 * a11[3], b0 =q7 */ + + "ldp q6, q7, [%[b_ptr]], #32\n" /* load b0, b1 to q6, q7 */ + + "fmla v9.4s, v4.4s, v2.s[0]\n" /* out8 = b0 * a10[0], b1 =q4 */ + "fmla v12.4s, v4.4s, v2.s[1]\n" /* out9 = b0 * a10[1], b1 =q4 */ + "fmla v15.4s, v4.4s, v2.s[2]\n" /* out10 = b1 * a10[2], b1 =q4*/ + "fmla v18.4s, v4.4s, v2.s[3]\n" /* out11 = b1 * a10[3], b1 =q4*/ + "fmla v21.4s, v4.4s, v3.s[0]\n" /* out12 = b1 * a10[0], b1 =q4*/ + "fmla v24.4s, v4.4s, v3.s[1]\n" /* out13 = b1 * a10[1], b1 =q4*/ + "fmla v27.4s, v4.4s, v3.s[2]\n" /* out14 = b1 * a10[2], b1 =q4*/ + "fmla v30.4s, v4.4s, v3.s[3]\n" /* out15 = b1 * a10[3], b1 =q4*/ + + "fmla v10.4s, v5.4s, v2.s[0]\n" /* out16 = b2 * a10[0], b2 =q5*/ + "fmla v13.4s, v5.4s, v2.s[1]\n" /* out17 = b2 * a10[0], b2 =q5*/ + "fmla v16.4s, v5.4s, v2.s[2]\n" /* out18 = b2 * a10[0], b2 =q5*/ + "fmla v19.4s, v5.4s, v2.s[3]\n" /* out19 = b2 * a10[0], b2 =q5*/ + "fmla v22.4s, v5.4s, v3.s[0]\n" /* out20 = b2 * a10[0], b2 =q5*/ + "fmla v25.4s, v5.4s, v3.s[1]\n" /* out21 = b2 * a10[0], b2 =q5*/ + "fmla v28.4s, v5.4s, v3.s[2]\n" /* out22 = b2 * a10[0], b2 =q5*/ + "fmla v31.4s, v5.4s, v3.s[3]\n" /* out23 = b2 * a10[0], b2 =q5*/ + "ldp q4, q5, [%[b_ptr]], #32\n" /* load b2, b0 to q4, q5 */ + /* unrool 2*/ + "fmla v8.4s , v6.4s, v0.s[0]\n" /* out0 = b0 * a00[0], b0 =q6 */ + "fmla v11.4s , v6.4s, v0.s[1]\n" /* out1 = b0 * a00[1], b0 =q6 + */ + "ldp q2, q3, [%[a_ptr]], #32\n" /* load a10, a11 to q3, q4*/ + "fmla v14.4s, v6.4s, v0.s[2]\n" /* out2 = b0 * a00[2], b0 =q6*/ + "fmla v17.4s, v6.4s, v0.s[3]\n" /* out3 = b0 * a00[3], b0 =q6*/ + "fmla v20.4s, v6.4s, v1.s[0]\n" /* out4 = b0 * a01[0], b0 =q6*/ + "fmla v23.4s, v6.4s, v1.s[1]\n" /* out5 = b0 * a01[1], b0 =q6*/ + "fmla v26.4s, v6.4s, v1.s[2]\n" /* out6 = b0 * a01[2], b0 =q6*/ + "fmla v29.4s, v6.4s, v1.s[3]\n" /* out7 = b0 * a01[3], b0 =q6*/ + "fmla v9.4s, v7.4s, v0.s[0]\n" /* 
out8 = b1 * a00[0], b1 =q7*/ + "fmla v12.4s, v7.4s, v0.s[1]\n" /* out9 = b1 * a00[1], b1 =q7*/ + "prfm pldl1keep, [%[b_ptr], #384]\n" + "fmla v15.4s, v7.4s, v0.s[2]\n" /* out10 = b1 * a00[2], b1 =q7*/ + "fmla v18.4s, v7.4s, v0.s[3]\n" /* out11 = b1 * a00[3], b1 =q7*/ + "fmla v21.4s, v7.4s, v1.s[0]\n" /* out12 = b1 * a01[0], b1 =q7*/ + "fmla v24.4s, v7.4s, v1.s[1]\n" /* out13 = b1 * a01[1], b1 =q7*/ + "fmla v27.4s, v7.4s, v1.s[2]\n" /* out14 = b1 * a01[2], b1 =q7*/ + "fmla v30.4s, v7.4s, v1.s[3]\n" /* out15 = b1 * a01[3], b1 =q7*/ + + "ldp q6, q7, [%[b_ptr]], #32\n" /* load b1, b2 to q6, q7*/ + + "fmla v10.4s, v4.4s, v0.s[0]\n" /* out16 = b2 * a00[0], b2 =q4*/ + "fmla v13.4s, v4.4s, v0.s[1]\n" /* out17 = b2 * a00[1], b2 =q4*/ + "fmla v16.4s, v4.4s, v0.s[2]\n" /* out18 = b2 * a00[2], b2 =q4*/ + "fmla v19.4s, v4.4s, v0.s[3]\n" /* out19 = b2 * a00[3], b2 =q4*/ + "fmla v22.4s, v4.4s, v1.s[0]\n" /* out20 = b2 * a00[0], b2 =q4*/ + "fmla v25.4s, v4.4s, v1.s[1]\n" /* out21 = b2 * a00[1], b2 =q4*/ + "fmla v28.4s, v4.4s, v1.s[2]\n" /* out22 = b2 * a00[2], b2 =q4*/ + "fmla v31.4s, v4.4s, v1.s[3]\n" /* out23 = b2 * a00[3], b2 =q4*/ + "ldp q0, q1, [%[a_ptr]], #32\n" /* load a00, a01 to q0, q1*/ + /* unrool 3*/ + "fmla v8.4s , v5.4s, v2.s[0]\n" /* out0 = b0 * a10[0], b0 =q5*/ + "fmla v11.4s , v5.4s, v2.s[1]\n" /* out1 = b0 * a10[1], b0 =q5*/ + "fmla v14.4s, v5.4s, v2.s[2]\n" /* out2 = b0 * a10[2], b0 =q5*/ + "fmla v17.4s, v5.4s, v2.s[3]\n" /* out3 = b0 * a10[3], b0 =q5*/ + "fmla v20.4s, v5.4s, v3.s[0]\n" /* out4 = b0 * a11[0], b0 =q5*/ + "fmla v23.4s, v5.4s, v3.s[1]\n" /* out5 = b0 * a11[1], b0 =q5*/ + "fmla v26.4s, v5.4s, v3.s[2]\n" /* out6 = b0 * a11[2], b0 =q5*/ + "fmla v29.4s, v5.4s, v3.s[3]\n" /* out7 = b0 * a11[3], b0 =q5*/ + "ldp q4, q5, [%[b_ptr]], #32\n" /* load b0, b1 to q4, q5*/ + "fmla v9.4s, v6.4s, v2.s[0]\n" /* out8 = b0 * a10[0], b1 =q6*/ + "fmla v12.4s, v6.4s, v2.s[1]\n" /* out9 = b0 * a10[1], b1 =q6*/ + "prfm pldl1keep, [%[a_ptr], #256]\n" + "fmla v15.4s, v6.4s, v2.s[2]\n" /* out10 = b1 * a10[2], b1 =q6*/ + "fmla v18.4s, v6.4s, v2.s[3]\n" /* out11 = b1 * a10[3], b1 =q6*/ + "fmla v21.4s, v6.4s, v3.s[0]\n" /* out12 = b1 * a10[0], b1 =q6*/ + "fmla v24.4s, v6.4s, v3.s[1]\n" /* out13 = b1 * a10[1], b1 =q6*/ + "fmla v27.4s, v6.4s, v3.s[2]\n" /* out14 = b1 * a10[2], b1 =q6*/ + "prfm pldl1keep, [%[b_ptr], #384]\n" + "fmla v30.4s, v6.4s, v3.s[3]\n" /* out15 = b1 * a10[3], b1 =q6*/ + "fmla v10.4s, v7.4s, v2.s[0]\n" /* out16 = b2 * a10[0], b2 =q7*/ + "fmla v13.4s, v7.4s, v2.s[1]\n" /* out17 = b2 * a10[0], b2 =q7*/ + "fmla v16.4s, v7.4s, v2.s[2]\n" /* out18 = b2 * a10[0], b2 =q7*/ + "fmla v19.4s, v7.4s, v2.s[3]\n" /* out19 = b2 * a10[0], b2 =q7*/ + "fmla v22.4s, v7.4s, v3.s[0]\n" /* out20 = b2 * a10[0], b2 =q7*/ + "fmla v25.4s, v7.4s, v3.s[1]\n" /* out21 = b2 * a10[0], b2 =q7*/ + "subs %w[k], %w[k], #1\n" /* loop count - 1*/ + "fmla v28.4s, v7.4s, v3.s[2]\n" /* out22 = b2 * a10[0], b2 =q7*/ + "fmla v31.4s, v7.4s, v3.s[3]\n" /* out23 = b2 * a10[0], b2 =q7*/ + "bne 1b\n" + "2:\n" /* process tail*/ + "subs %w[tail], %w[tail], #1\n" /* tail--*/ + "beq 3f\n" /*jump to tail = 1*/ + /* final unrool 0*/ + /* unrool 0, tail > 1*/ + "fmla v8.4s , v4.4s, v0.s[0]\n" /* out0 = b0 * a00[0], b0 =q4*/ + "fmla v11.4s , v4.4s, v0.s[1]\n" /* out1 = b0 * a00[1], b0 =q4*/ + "ldp q6, q7, [%[b_ptr]], #32\n" /* load b2, b0 to q6, q7*/ + "fmla v14.4s, v4.4s, v0.s[2]\n" /* out2 = b0 * a00[2], b0 =q4*/ + "fmla v17.4s, v4.4s, v0.s[3]\n" /* out3 = b0 * a00[3], b0 =q4*/ + "ldp q2, q3, [%[a_ptr]], #32\n" /* load a10, a11 to q2, 
q3*/ + "fmla v20.4s, v4.4s, v1.s[0]\n" /* out4 = b0 * a01[0], b0 =q4*/ + "fmla v23.4s, v4.4s, v1.s[1]\n" /* out5 = b0 * a01[1], b0 =q4*/ + "fmla v26.4s, v4.4s, v1.s[2]\n" /* out6 = b0 * a01[2], b0 =q4*/ + "fmla v29.4s, v4.4s, v1.s[3]\n" /* out7 = b0 * a01[3], b0 =q4*/ + "subs %w[tail], %w[tail], #1\n" /* tail--*/ + "fmla v9.4s, v5.4s, v0.s[0]\n" /* out8 = b1 * a00[0], b1 =q5*/ + "fmla v12.4s, v5.4s, v0.s[1]\n" /* out9 = b1 * a00[1], b1 =q5*/ + "fmla v15.4s, v5.4s, v0.s[2]\n" /* out10 = b1 * a00[2], b1 =q5*/ + "fmla v18.4s, v5.4s, v0.s[3]\n" /* out11 = b1 * a00[3], b1 =q5*/ + "fmla v21.4s, v5.4s, v1.s[0]\n" /* out12 = b1 * a01[0], b1 =q5*/ + "fmla v24.4s, v5.4s, v1.s[1]\n" /* out13 = b1 * a01[1], b1 =q5*/ + "fmla v27.4s, v5.4s, v1.s[2]\n" /* out14 = b1 * a01[2], b1 =q5*/ + "fmla v30.4s, v5.4s, v1.s[3]\n" /* out15 = b1 * a01[3], b1 =q5*/ + "ldp q4, q5, [%[b_ptr]], #32\n" /* load b1, b2 to q4, q5*/ + "fmla v10.4s, v6.4s, v0.s[0]\n" /* out16 = b2 * a00[0], b2 =q6*/ + "fmla v13.4s, v6.4s, v0.s[1]\n" /* out17 = b2 * a00[1], b2 =q6*/ + "fmla v16.4s, v6.4s, v0.s[2]\n" /* out18 = b2 * a00[2], b2 =q6*/ + "fmla v19.4s, v6.4s, v0.s[3]\n" /* out19 = b2 * a00[3], b2 =q6*/ + "fmla v22.4s, v6.4s, v1.s[0]\n" /* out20 = b2 * a00[0], b2 =q6*/ + "fmla v25.4s, v6.4s, v1.s[1]\n" /* out21 = b2 * a00[1], b2 =q6*/ + "fmla v28.4s, v6.4s, v1.s[2]\n" /* out22 = b2 * a00[2], b2 =q6*/ + "fmla v31.4s, v6.4s, v1.s[3]\n" /* out23 = b2 * a00[3], b2 =q6*/ + "beq 4f\n" /*jump to tail = 2*/ + /* unrool 1, tail > 2*/ + "ldp q0, q1, [%[a_ptr]], #32\n" /* load a00, a01 to q0, q1*/ + "fmla v8.4s , v7.4s, v2.s[0]\n" /* out0 = b0 * a10[0], b0 =q7*/ + "fmla v11.4s , v7.4s, v2.s[1]\n" /* out1 = b0 * a10[1], b0 =q7*/ + "fmla v14.4s, v7.4s, v2.s[2]\n" /* out2 = b0 * a10[2], b0 =q7*/ + "fmla v17.4s, v7.4s, v2.s[3]\n" /* out3 = b0 * a10[3], b0 =q7*/ + "fmla v20.4s, v7.4s, v3.s[0]\n" /* out4 = b0 * a11[0], b0 =q7*/ + "fmla v23.4s, v7.4s, v3.s[1]\n" /* out5 = b0 * a11[1], b0 =q7*/ + "fmla v26.4s, v7.4s, v3.s[2]\n" /* out6 = b0 * a11[2], b0 =q7*/ + "fmla v29.4s, v7.4s, v3.s[3]\n" /* out7 = b0 * a11[3], b0 =q7*/ + "ldp q6, q7, [%[b_ptr]], #32\n" /* load b0, b1 to q6, q7*/ + "fmla v9.4s, v4.4s, v2.s[0]\n" /* out8 = b0 * a10[0], b1 =q4*/ + "fmla v12.4s, v4.4s, v2.s[1]\n" /* out9 = b0 * a10[1], b1 =q4*/ + "fmla v15.4s, v4.4s, v2.s[2]\n" /* out10 = b1 * a10[2], b1 =q4*/ + "fmla v18.4s, v4.4s, v2.s[3]\n" /* out11 = b1 * a10[3], b1 =q4*/ + "fmla v21.4s, v4.4s, v3.s[0]\n" /* out12 = b1 * a10[0], b1 =q4*/ + "fmla v24.4s, v4.4s, v3.s[1]\n" /* out13 = b1 * a10[1], b1 =q4*/ + "fmla v27.4s, v4.4s, v3.s[2]\n" /* out14 = b1 * a10[2], b1 =q4*/ + "fmla v30.4s, v4.4s, v3.s[3]\n" /* out15 = b1 * a10[3], b1 =q4*/ + "subs %w[tail], %w[tail], #1\n" /* tail--*/ + "fmla v10.4s, v5.4s, v2.s[0]\n" /* out16 = b2 * a10[0], b2 =q5*/ + "fmla v13.4s, v5.4s, v2.s[1]\n" /* out17 = b2 * a10[0], b2 =q5*/ + "fmla v16.4s, v5.4s, v2.s[2]\n" /* out18 = b2 * a10[0], b2 =q5*/ + "fmla v19.4s, v5.4s, v2.s[3]\n" /* out19 = b2 * a10[0], b2 =q5*/ + "fmla v22.4s, v5.4s, v3.s[0]\n" /* out20 = b2 * a10[0], b2 =q5*/ + "fmla v25.4s, v5.4s, v3.s[1]\n" /* out21 = b2 * a10[0], b2 =q5*/ + "fmla v28.4s, v5.4s, v3.s[2]\n" /* out22 = b2 * a10[0], b2 =q5*/ + "fmla v31.4s, v5.4s, v3.s[3]\n" /* out23 = b2 * a10[0], b2 =q5*/ + "beq 5f\n" /*jump to tail = 3*/ + /* unrool 2, tail = 4*/ + "ldp q4, q5, [%[b_ptr]], #32\n" /* load b2, b0 to q4, q5*/ + "fmla v8.4s , v6.4s, v0.s[0]\n" /* out0 = b0 * a00[0], b0 =q6*/ + "fmla v11.4s , v6.4s, v0.s[1]\n" /* out1 = b0 * a00[1], b0 =q6*/ + "ldp q2, q3, [%[a_ptr]], 
#32\n" /* load a10, a11 to q3, q4*/ + "fmla v14.4s, v6.4s, v0.s[2]\n" /* out2 = b0 * a00[2], b0 =q6*/ + "fmla v17.4s, v6.4s, v0.s[3]\n" /* out3 = b0 * a00[3], b0 =q6*/ + "fmla v20.4s, v6.4s, v1.s[0]\n" /* out4 = b0 * a01[0], b0 =q6*/ + "fmla v23.4s, v6.4s, v1.s[1]\n" /* out5 = b0 * a01[1], b0 =q6*/ + "fmla v26.4s, v6.4s, v1.s[2]\n" /* out6 = b0 * a01[2], b0 =q6*/ + "fmla v29.4s, v6.4s, v1.s[3]\n" /* out7 = b0 * a01[3], b0 =q6*/ + "fmla v9.4s, v7.4s, v0.s[0]\n" /* out8 = b1 * a00[0], b1 =q7*/ + "fmla v12.4s, v7.4s, v0.s[1]\n" /* out9 = b1 * a00[1], b1 =q7*/ + "fmla v15.4s, v7.4s, v0.s[2]\n" /* out10 = b1 * a00[2], b1 =q7*/ + "fmla v18.4s, v7.4s, v0.s[3]\n" /* out11 = b1 * a00[3], b1 =q7*/ + "fmla v21.4s, v7.4s, v1.s[0]\n" /* out12 = b1 * a01[0], b1 =q7*/ + "fmla v24.4s, v7.4s, v1.s[1]\n" /* out13 = b1 * a01[1], b1 =q7*/ + "fmla v27.4s, v7.4s, v1.s[2]\n" /* out14 = b1 * a01[2], b1 =q7*/ + "fmla v30.4s, v7.4s, v1.s[3]\n" /* out15 = b1 * a01[3], b1 =q7*/ + "ldp q6, q7, [%[b_ptr]], #32\n" /* load b1, b2 to q6, q7*/ + "fmla v10.4s, v4.4s, v0.s[0]\n" /* out16 = b2 * a00[0], b2 =q4*/ + "fmla v13.4s, v4.4s, v0.s[1]\n" /* out17 = b2 * a00[1], b2 =q4*/ + "fmla v16.4s, v4.4s, v0.s[2]\n" /* out18 = b2 * a00[2], b2 =q4*/ + "fmla v19.4s, v4.4s, v0.s[3]\n" /* out19 = b2 * a00[3], b2 =q4*/ + "fmla v22.4s, v4.4s, v1.s[0]\n" /* out20 = b2 * a00[0], b2 =q4*/ + "fmla v25.4s, v4.4s, v1.s[1]\n" /* out21 = b2 * a00[1], b2 =q4*/ + "fmla v28.4s, v4.4s, v1.s[2]\n" /* out22 = b2 * a00[2], b2 =q4*/ + "fmla v31.4s, v4.4s, v1.s[3]\n" /* out23 = b2 * a00[3], b2 =q4*/ + /* unrool 3, tail = 4*/ + "fmla v8.4s , v5.4s, v2.s[0]\n" /* out0 = b0 * a10[0], b0 =q5*/ + "fmla v11.4s , v5.4s, v2.s[1]\n" /* out1 = b0 * a10[1], b0 =q5*/ + "fmla v14.4s, v5.4s, v2.s[2]\n" /* out2 = b0 * a10[2], b0 =q5*/ + "fmla v17.4s, v5.4s, v2.s[3]\n" /* out3 = b0 * a10[3], b0 =q5*/ + "fmla v20.4s, v5.4s, v3.s[0]\n" /* out4 = b0 * a11[0], b0 =q5*/ + "fmla v23.4s, v5.4s, v3.s[1]\n" /* out5 = b0 * a11[1], b0 =q5*/ + "fmla v26.4s, v5.4s, v3.s[2]\n" /* out6 = b0 * a11[2], b0 =q5*/ + "fmla v29.4s, v5.4s, v3.s[3]\n" /* out7 = b0 * a11[3], b0 =q5*/ + "fmla v9.4s, v6.4s, v2.s[0]\n" /* out8 = b0 * a10[0], b1 =q6*/ + "fmla v12.4s, v6.4s, v2.s[1]\n" /* out9 = b1 * a10[1], b1 =q6*/ + "fmla v15.4s, v6.4s, v2.s[2]\n" /* out10 = b1 * a10[2], b1 =q6*/ + "fmla v18.4s, v6.4s, v2.s[3]\n" /* out11 = b1 * a10[3], b1 =q6*/ + "fmla v21.4s, v6.4s, v3.s[0]\n" /* out12 = b1 * a10[0], b1 =q6*/ + "fmla v24.4s, v6.4s, v3.s[1]\n" /* out13 = b1 * a10[1], b1 =q6*/ + "fmla v27.4s, v6.4s, v3.s[2]\n" /* out14 = b1 * a10[2], b1 =q6*/ + "fmla v30.4s, v6.4s, v3.s[3]\n" /* out15 = b1 * a10[3], b1 =q6*/ + "fmla v10.4s, v7.4s, v2.s[0]\n" /* out16 = b2 * a10[0], b2 =q7*/ + "fmla v13.4s, v7.4s, v2.s[1]\n" /* out17 = b2 * a10[0], b2 =q7*/ + "fmla v16.4s, v7.4s, v2.s[2]\n" /* out18 = b2 * a10[0], b2 =q7*/ + "fmla v19.4s, v7.4s, v2.s[3]\n" /* out19 = b2 * a10[0], b2 =q7*/ + "fmla v22.4s, v7.4s, v3.s[0]\n" /* out20 = b2 * a10[0], b2 =q7*/ + "fmla v25.4s, v7.4s, v3.s[1]\n" /* out21 = b2 * a10[0], b2 =q7*/ + "fmla v28.4s, v7.4s, v3.s[2]\n" /* out22 = b2 * a10[0], b2 =q7*/ + "fmla v31.4s, v7.4s, v3.s[3]\n" /* out23 = b2 * a10[0], b2 =q7*/ + "b 11f\n" + /* tails==1 final tail*/ + "3: \n" /* tail=1*/ + "ldr q6, [%[b_ptr]], #16\n" /* load b2 to q6*/ + "fmla v8.4s , v4.4s, v0.s[0]\n" /* out0 = b0 * a10[0], b0 =q5*/ + "fmla v11.4s , v4.4s, v0.s[1]\n" /* out1 = b0 * a10[1], b0 =q5*/ + "fmla v14.4s, v4.4s, v0.s[2]\n" /* out2 = b0 * a10[2], b0 =q5*/ + "fmla v17.4s, v4.4s, v0.s[3]\n" /* out3 = b0 * a10[3], b0 
=q5*/ + "fmla v20.4s, v4.4s, v1.s[0]\n" /* out4 = b0 * a11[0], b0 =q5*/ + "fmla v23.4s, v4.4s, v1.s[1]\n" /* out5 = b0 * a11[1], b0 = q5*/ + "fmla v26.4s, v4.4s, v1.s[2]\n" /* out6 = b0 * a11[2], b0 =q5*/ + "fmla v29.4s, v4.4s, v1.s[3]\n" /* out7 = b0 * a11[3], b0 =q5*/ + "fmla v9.4s, v5.4s, v0.s[0]\n" /* out8 = b0 * a10[0], b1 =q6*/ + "fmla v12.4s, v5.4s, v0.s[1]\n" /* out9 = b1 * a10[1], b1 =q6*/ + "fmla v15.4s, v5.4s, v0.s[2]\n" /* out10 = b1 * a10[2], b1 =q6*/ + "fmla v18.4s, v5.4s, v0.s[3]\n" /* out11 = b1 * a10[3], b1 =q6*/ + "fmla v21.4s, v5.4s, v1.s[0]\n" /* out12 = b1 * a10[0], b1 =q6*/ + "fmla v24.4s, v5.4s, v1.s[1]\n" /* out13 = b1 * a10[1], b1 =q6*/ + "fmla v27.4s, v5.4s, v1.s[2]\n" /* out14 = b1 * a10[2], b1 =q6*/ + "fmla v30.4s, v5.4s, v1.s[3]\n" /* out15 = b1 * a10[3], b1 =q6*/ + "fmla v10.4s, v6.4s, v0.s[0]\n" /* out16 = b2 * a10[0], b2 =q7*/ + "fmla v13.4s, v6.4s, v0.s[1]\n" /* out17 = b2 * a10[0], b2 =q7*/ + "fmla v16.4s, v6.4s, v0.s[2]\n" /* out18 = b2 * a10[0], b2 =q7*/ + "fmla v19.4s, v6.4s, v0.s[3]\n" /* out19 = b2 * a10[0], b2 =q7*/ + "fmla v22.4s, v6.4s, v1.s[0]\n" /* out20 = b2 * a10[0], b2 =q7*/ + "fmla v25.4s, v6.4s, v1.s[1]\n" /* out21 = b2 * a10[0], b2 =q7*/ + "fmla v28.4s, v6.4s, v1.s[2]\n" /* out22 = b2 * a10[0], b2 =q7*/ + "fmla v31.4s, v6.4s, v1.s[3]\n" /* out23 = b2 * a10[0], b2 =q7*/ + "b 11f\n" + /* tails==2 final tail*/ + "4:\n" /* tail = 2*/ + "fmla v8.4s , v7.4s, v2.s[0]\n" /* out0 = b0 * a10[0], b0 =q5*/ + "fmla v11.4s , v7.4s, v2.s[1]\n" /* out1 = b0 * a10[1], b0 =q5*/ + "fmla v14.4s, v7.4s, v2.s[2]\n" /* out2 = b0 * a10[2], b0 =q5*/ + "fmla v17.4s, v7.4s, v2.s[3]\n" /* out3 = b0 * a10[3], b0 =q5*/ + "fmla v20.4s, v7.4s, v3.s[0]\n" /* out4 = b0 * a11[0], b0 =q5*/ + "fmla v23.4s, v7.4s, v3.s[1]\n" /* out5 = b0 * a11[1], b0 = q5*/ + "fmla v26.4s, v7.4s, v3.s[2]\n" /* out6 = b0 * a11[2], b0 =q5*/ + "fmla v29.4s, v7.4s, v3.s[3]\n" /* out7 = b0 * a11[3], b0 =q5*/ + "fmla v9.4s, v4.4s, v2.s[0]\n" /* out8 = b0 * a10[0], b1 =q6*/ + "fmla v12.4s, v4.4s, v2.s[1]\n" /* out9 = b1 * a10[1], b1 =q6*/ + "fmla v15.4s, v4.4s, v2.s[2]\n" /* out10 = b1 * a10[2], b1 =q6*/ + "fmla v18.4s, v4.4s, v2.s[3]\n" /* out11 = b1 * a10[3], b1 =q6*/ + "fmla v21.4s, v4.4s, v3.s[0]\n" /* out12 = b1 * a10[0], b1 =q6*/ + "fmla v24.4s, v4.4s, v3.s[1]\n" /* out13 = b1 * a10[1], b1 =q6*/ + "fmla v27.4s, v4.4s, v3.s[2]\n" /* out14 = b1 * a10[2], b1 =q6*/ + "fmla v30.4s, v4.4s, v3.s[3]\n" /* out15 = b1 * a10[3], b1 =q6*/ + "fmla v10.4s, v5.4s, v2.s[0]\n" /* out16 = b2 * a10[0], b2 =q7*/ + "fmla v13.4s, v5.4s, v2.s[1]\n" /* out17 = b2 * a10[0], b2 =q7*/ + "fmla v16.4s, v5.4s, v2.s[2]\n" /* out18 = b2 * a10[0], b2 =q7*/ + "fmla v19.4s, v5.4s, v2.s[3]\n" /* out19 = b2 * a10[0], b2 =q7*/ + "fmla v22.4s, v5.4s, v3.s[0]\n" /* out20 = b2 * a10[0], b2 =q7*/ + "fmla v25.4s, v5.4s, v3.s[1]\n" /* out21 = b2 * a10[0], b2 =q7*/ + "fmla v28.4s, v5.4s, v3.s[2]\n" /* out22 = b2 * a10[0], b2 =q7*/ + "fmla v31.4s, v5.4s, v3.s[3]\n" /* out23 = b2 * a10[0], b2 =q7*/ + "b 11f\n" + /* tails==3 final tail*/ + "5:\n" /* tail = 3*/ + "ldr q4, [%[b_ptr]], #16\n" /* load b2, b0 to q4*/ + "fmla v8.4s , v6.4s, v0.s[0]\n" /* out0 = b0 * a10[0], b0 =q5*/ + "fmla v11.4s , v6.4s, v0.s[1]\n" /* out1 = b0 * a10[1], b0 =q5*/ + "fmla v14.4s, v6.4s, v0.s[2]\n" /* out2 = b0 * a10[2], b0 =q5*/ + "fmla v17.4s, v6.4s, v0.s[3]\n" /* out3 = b0 * a10[3], b0 =q5*/ + "fmla v20.4s, v6.4s, v1.s[0]\n" /* out4 = b0 * a11[0], b0 =q5*/ + "fmla v23.4s, v6.4s, v1.s[1]\n" /* out5 = b0 * a11[1], b0 = q5*/ + "fmla v26.4s, v6.4s, v1.s[2]\n" /* 
out6 = b0 * a11[2], b0 =q5*/ + "fmla v29.4s, v6.4s, v1.s[3]\n" /* out7 = b0 * a11[3], b0 =q5*/ + "fmla v9.4s, v7.4s, v0.s[0]\n" /* out8 = b0 * a10[0], b1 =q6*/ + "fmla v12.4s, v7.4s, v0.s[1]\n" /* out9 = b1 * a10[1], b1 =q6*/ + "fmla v15.4s, v7.4s, v0.s[2]\n" /* out10 = b1 * a10[2], b1 =q6*/ + "fmla v18.4s, v7.4s, v0.s[3]\n" /* out11 = b1 * a10[3], b1 =q6*/ + "fmla v21.4s, v7.4s, v1.s[0]\n" /* out12 = b1 * a10[0], b1 =q6*/ + "fmla v24.4s, v7.4s, v1.s[1]\n" /* out13 = b1 * a10[1], b1 =q6*/ + "fmla v27.4s, v7.4s, v1.s[2]\n" /* out14 = b1 * a10[2], b1 =q6*/ + "fmla v30.4s, v7.4s, v1.s[3]\n" /* out15 = b1 * a10[3], b1 =q6*/ + "fmla v10.4s, v4.4s, v0.s[0]\n" /* out16 = b2 * a10[0], b2 =q7*/ + "fmla v13.4s, v4.4s, v0.s[1]\n" /* out17 = b2 * a10[0], b2 =q7*/ + "fmla v16.4s, v4.4s, v0.s[2]\n" /* out18 = b2 * a10[0], b2 =q7*/ + "fmla v19.4s, v4.4s, v0.s[3]\n" /* out19 = b2 * a10[0], b2 =q7*/ + "fmla v22.4s, v4.4s, v1.s[0]\n" /* out20 = b2 * a10[0], b2 =q7*/ + "fmla v25.4s, v4.4s, v1.s[1]\n" /* out21 = b2 * a10[0], b2 =q7*/ + "fmla v28.4s, v4.4s, v1.s[2]\n" /* out22 = b2 * a10[0], b2 =q7*/ + "fmla v31.4s, v4.4s, v1.s[3]\n" /* out23 = b2 * a10[0], b2 =q7*/ + "11: \n" /* check if relu */ + "cbz %w[relu], 12f\n" /* skip relu */ + "movi v2.4s, #0\n" /* for relu*/ + "fmax v8.4s, v8.4s, v2.4s\n" /* relu*/ + "fmax v9.4s, v9.4s, v2.4s\n" /* relu*/ + "fmax v10.4s, v10.4s, v2.4s\n" /* relu*/ + "fmax v11.4s, v11.4s, v2.4s\n" /* relu*/ + "fmax v12.4s, v12.4s, v2.4s\n" /* relu*/ + "fmax v13.4s, v13.4s, v2.4s\n" /* relu*/ + "fmax v14.4s, v14.4s, v2.4s\n" /* relu*/ + "fmax v15.4s, v15.4s, v2.4s\n" /* relu*/ + "fmax v16.4s,v16.4s,v2.4s\n" /* relu*/ + "fmax v17.4s,v17.4s,v2.4s\n" /* relu*/ + "fmax v18.4s, v18.4s, v2.4s\n" /* relu*/ + "fmax v19.4s, v19.4s, v2.4s\n" /* relu*/ + "fmax v20.4s, v20.4s, v2.4s\n" /* relu*/ + "fmax v21.4s, v21.4s, v2.4s\n" /* relu*/ + "fmax v22.4s, v22.4s, v2.4s\n" /* relu*/ + "fmax v23.4s, v23.4s, v2.4s\n" /* relu*/ + "fmax v24.4s,v24.4s,v2.4s\n" /* relu*/ + "fmax v25.4s,v25.4s,v2.4s\n" /* relu*/ + "fmax v26.4s, v26.4s, v2.4s\n" /* relu*/ + "fmax v27.4s, v27.4s, v2.4s\n" /* relu*/ + "fmax v28.4s, v28.4s, v2.4s\n" /* relu*/ + "fmax v29.4s, v29.4s, v2.4s\n" /* relu*/ + "fmax v30.4s, v30.4s, v2.4s\n" /* relu*/ + "fmax v31.4s, v31.4s, v2.4s\n" /* relu*/ + "12: \n" + "st1 {v8.4s, v9.4s, v10.4s},[%[c_ptr0]], #48\n" /* store r0 */ + "st1 {v11.4s, v12.4s, v13.4s},[%[c_ptr1]], #48\n" /* store r1 */ + "st1 {v14.4s, v15.4s, v16.4s},[%[c_ptr2]], #48\n" /* store r2 */ + "st1 {v17.4s, v18.4s, v19.4s},[%[c_ptr3]], #48\n" /* store r3 */ + "st1 {v20.4s, v21.4s, v22.4s},[%[c_ptr4]], #48\n" /* store r4 */ + "st1 {v23.4s, v24.4s, v25.4s},[%[c_ptr5]], #48\n" /* store r5 */ + "st1 {v26.4s, v27.4s, v28.4s},[%[c_ptr6]], #48\n" /* store r6 */ + "st1 {v29.4s, v30.4s, v31.4s},[%[c_ptr7]], #48\n" /* store r7 */ + + : [a_ptr] "+r"(a_ptr), + [b_ptr] "+r"(b_ptr), + [k] "+r"(k), + [tail] "+r"(tail), + [c_ptr0] "+r"(c_ptr0), + [c_ptr1] "+r"(c_ptr1), + [c_ptr2] "+r"(c_ptr2), + [c_ptr3] "+r"(c_ptr3), + [c_ptr4] "+r"(c_ptr4), + [c_ptr5] "+r"(c_ptr5), + [c_ptr6] "+r"(c_ptr6), + [c_ptr7] "+r"(c_ptr7) + : [bias_ptr] "r"(bias_local), + [relu] "r"(has_relu), + [has_beta] "r"(has_beta), + [beta] "r"(beta) + : "cc", + "memory", + "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15", + "v16", + "v17", + "v18", + "v19", + "v20", + "v21", + "v22", + "v23", + "v24", + "v25", + "v26", + "v27", + "v28", + "v29", + "v30", + "v31"); + if (flag_p_remain && (xb == 
bblocks - 1)) {
+          for (int i = 0; i < remain; ++i) {
+            *pout0++ = cout0[i];
+            *pout1++ = cout1[i];
+            *pout2++ = cout2[i];
+            *pout3++ = cout3[i];
+            *pout4++ = cout4[i];
+            *pout5++ = cout5[i];
+            *pout6++ = cout6[i];
+            *pout7++ = cout7[i];
+          }
+        }
+      }
+    }
+  }
+}
+#else  // __aarch64__
+/**
+ * \brief sgemm with a-block = 6 and b-block = 8: each inner iteration
+ * computes a 6x8 tile of the output
+ * @param A packed A panel (6-row by K blocks)
+ * @param B source B matrix, ldb elements per row
+ * @param C output matrix, ldc elements per row
+ * @param M rows of C
+ * @param N columns of C
+ * @param K reduction depth
+ * @param threads number of OpenMP threads
+ * @param workspace per-context scratch buffer holding the packed B panel
+ */
+void sgemm_prepacked_6x8(bool is_transB,
+                         int M,
+                         int N,
+                         int K,
+                         const float* A_packed,
+                         const float* B,
+                         int ldb,
+                         float beta,
+                         float* C,
+                         int ldc,
+                         const float* bias,
+                         bool has_bias,
+                         bool has_relu,
+                         ARMContext* ctx) {
+  size_t l2_cache = ctx->llc_size() > 0 ? ctx->llc_size() : 512 * 1024;
+  auto* workspace = ctx->workspace_data<float>();
+  int threads = ctx->threads();
+  //! MBLOCK * x (result) + MBLOCK * k (A) + x * k (B) = l2
+  int x_block =
+      (l2_cache - (MBLOCK_OTH * K)) / (sizeof(float) * (K + MBLOCK_OTH));
+  x_block /= NBLOCK;
+  x_block *= NBLOCK;
+  int x_num = (N + (x_block - 1)) / x_block;
+  x_block = (N + x_num - 1) / x_num;
+  x_block = (x_block + NBLOCK - 1) / NBLOCK;
+  x_block *= NBLOCK;
+  x_block = x_block < NBLOCK ? NBLOCK : x_block;
+
+  int k_pre = ((K + KBLOCK - 1) / KBLOCK) - 1;
+  int tail_pre = (K & (KBLOCK - 1));
+  if (tail_pre == 0) {
+    tail_pre = KBLOCK;
+  }
+
+  bool flag_p_remain = false;
+  int remain = 0;
+
+  int has_beta = fabsf(beta) > 1e-8f ? 1 : 0;
+
+  //! the A panel is prepacked outside this gemm
+  for (unsigned int x0 = 0; x0 < N; x0 += x_block) {
+    unsigned int xmax = x0 + x_block;
+    if (xmax > N) {
+      xmax = N;
+    }
+    int bblocks = (xmax - x0 + NBLOCK - 1) / NBLOCK;
+    remain = xmax - x0 - (bblocks - 1) * NBLOCK;
+    if (remain > 0) {
+      flag_p_remain = true;
+    }
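The `k_pre` / `tail_pre` pair above splits K into KBLOCK-sized chunks: `k_pre` full iterations run through the unrolled main loop and a leftover of 1..KBLOCK steps goes through the branchy tail code (a zero remainder is remapped to a full KBLOCK tail). A small self-checking C++ sketch of the split, with a hypothetical `main` that is not part of the patch:

#include <cassert>
#include <cstdio>

int main() {
  const int KBLOCK = 4;
  for (int K = 1; K <= 16; ++K) {
    int k_pre = ((K + KBLOCK - 1) / KBLOCK) - 1;  // full unrolled iterations
    int tail_pre = K & (KBLOCK - 1);              // leftover depth, 0..3
    if (tail_pre == 0) tail_pre = KBLOCK;         // tail of KBLOCK, never 0
    assert(k_pre * KBLOCK + tail_pre == K);       // every k consumed exactly once
    printf("K=%2d -> %d main iterations + tail of %d\n", K, k_pre, tail_pre);
  }
  return 0;
}

+    //! load the B panel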
+    auto b_pannel = static_cast<float*>(workspace);
+    if (is_transB) {
+      loadb_trans(b_pannel, B, ldb, 0, K, x0, xmax);
+    } else {
+      loadb(b_pannel, B, ldb, 0, K, x0, xmax);
+    }
+#pragma omp parallel for num_threads(threads)
+    for (unsigned int y = 0; y < M; y += MBLOCK_OTH) {
+      unsigned int ymax = y + MBLOCK_OTH;
+      if (ymax > M) {
+        ymax = M;
+      }
+      float* c_ptr0 = C + y * ldc + x0;
+      float* c_ptr1 = c_ptr0 + ldc;
+      float* c_ptr2 = c_ptr1 + ldc;
+      float* c_ptr3 = c_ptr2 + ldc;
+      float* c_ptr4 = c_ptr3 + ldc;
+      float* c_ptr5 = c_ptr4 + ldc;
+
+      float* pout0 = c_ptr0;
+      float* pout1 = c_ptr1;
+      float* pout2 = c_ptr2;
+      float* pout3 = c_ptr3;
+      float* pout4 = c_ptr4;
+      float* pout5 = c_ptr5;
+
+      float bias_local[6] = {0};
+      if (has_bias) {
+        bias_local[0] = bias[y];
+        bias_local[1] = bias[y + 1];
+        bias_local[2] = bias[y + 2];
+        bias_local[3] = bias[y + 3];
+        bias_local[4] = bias[y + 4];
+        bias_local[5] = bias[y + 5];
+      }
+
+      float cout0[NBLOCK];
+      float cout1[NBLOCK];
+      float cout2[NBLOCK];
+      float cout3[NBLOCK];
+      float cout4[NBLOCK];
+      float cout5[NBLOCK];
+
+      const float* a_ptr_l = A_packed + y * K;
+      const float* b_ptr = b_pannel;
+      for (int xb = 0; xb < bblocks; xb++) {
+        if ((y + 5) >= ymax) {
+          switch ((y + 5) - ymax) {
+            case 4:
+              c_ptr1 = cout1;
+            case 3:
+              c_ptr2 = cout2;
+            case 2:
+              c_ptr3 = cout3;
+            case 1:
+              c_ptr4 = cout4;
+            case 0:
+              c_ptr5 = cout5;
+            default:
+              break;
+          }
+        }
+        if (flag_p_remain && (xb == bblocks - 1)) {
+          pout0 = c_ptr0;
+          pout1 = c_ptr1;
+          pout2 = c_ptr2;
+          pout3 = c_ptr3;
+          pout4 = c_ptr4;
+          pout5 = c_ptr5;
+
+          c_ptr0 = cout0;
+          c_ptr1 = cout1;
+          c_ptr2 = cout2;
+          c_ptr3 = cout3;
+          c_ptr4 = cout4;
+          c_ptr5 = cout5;
+          if (has_beta) {
+            for (int i = 0; i < remain; ++i) {
+              cout0[i] = pout0[i];
+              cout1[i] = pout1[i];
+              cout2[i] = pout2[i];
+              cout3[i] = pout3[i];
+              cout4[i] = pout4[i];
+              cout5[i] = pout5[i];
+            }
+          }
+        }
+        const float* a_ptr = a_ptr_l;
+        int tails = tail_pre;
+        int k = k_pre;
+        asm volatile(
+            // sgemm 6x8
+            "vld1.32 {d2-d4}, [%[bias_ptr]] @ load bias 6 elements\n"
+            "pld [%[a_ptr]]                 @ preload a\n"
+            "vdup.i32 q12, d4[0]            @ out40 = bias[4]\n"
+            "pld [%[b_ptr]]                 @ preload b\n"
+            "vdup.i32 q13, d4[0]            @ out41 = bias[4]\n"
+            "pld [%[a_ptr], #64]            @ preload a\n"
+            "vdup.i32 q14, d4[1]            @ out50 = bias[5]\n"
+            "pld [%[b_ptr], #64]            @ preload b\n"
+            "vdup.i32 q15, d4[1]            @ out51 = bias[5]\n"
+            "pld [%[a_ptr], #128]           @ preload a\n"
+            "vdup.i32 q4, d2[0]             @ out00 = bias[0]\n"
+            "pld [%[b_ptr], #128]           @ preload b\n"
+            "vdup.i32 q5, d2[0]             @ out01 = bias[0]\n"
+            "vdup.i32 q6, d2[1]             @ out10 = bias[1]\n"
+            "pld [%[a_ptr], #192]           @ preload a\n"
+            "vdup.i32 q7, d2[1]             @ out11 = bias[1]\n"
+            "pld [%[b_ptr], #192]           @ preload b\n"
+            "vdup.i32 q8, d3[0]             @ out20 = bias[2]\n"
+            "pld [%[a_ptr], #256]           @ preload a\n"
+            "vdup.i32 q9, d3[0]             @ out21 = bias[2]\n"
+            "pld [%[b_ptr], #256]           @ preload b\n"
+            "vdup.i32 q10, d3[1]            @ out30 = bias[3]\n"
+            "pld [%[b_ptr], #320]           @ preload b\n"
+            "vdup.i32 q11, d3[1]            @ out31 = bias[3]\n"
+            "pld [%[b_ptr], #384]           @ preload b\n"
+            "cmp %[has_beta], #0\n"
+            "beq 11f\n" /* check beta == 0?
*/ + /* process beta */ + "vdup.32 q3, %[beta]\n" /* beta to vector */ + "vld1.32 {d0-d3}, [%[c_ptr0]]\n" /* load output r0 */ + "vmla.f32 q4, q0, q3\n" /* cr00 += beta * c_r00 */ + "vmla.f32 q5, q1, q3\n" /* cr01 += beta * c_r01 */ + "vld1.32 {d0-d3}, [%[c_ptr1]]\n" /* load output r1 */ + "vmla.f32 q6, q0, q3\n" /* cr10 += beta * c_r10 */ + "vmla.f32 q7, q1, q3\n" /* cr11 += beta * c_r11 */ + "vld1.32 {d0-d3}, [%[c_ptr2]]\n" /* load output r2 */ + "vmla.f32 q8, q0, q3\n" /* cr20 += beta * c_r20 */ + "vmla.f32 q9, q1, q3\n" /* cr21 += beta * c_r21 */ + "vld1.32 {d0-d3}, [%[c_ptr3]]\n" /* load output r3 */ + "vmla.f32 q10, q0, q3\n" /* cr30 += beta * c_r30 */ + "vmla.f32 q11, q1, q3\n" /* cr31 += beta * c_r31 */ + "vld1.32 {d0-d3}, [%[c_ptr4]]\n" /* load output r4 */ + "vmla.f32 q12, q0, q3\n" /* cr40 += beta * c_r40 */ + "vmla.f32 q13, q1, q3\n" /* cr41 += beta * c_r41 */ + "vld1.32 {d0-d3}, [%[c_ptr5]]\n" /* load output r5 */ + "vmla.f32 q14, q0, q3\n" /* cr50 += beta * c_r50 */ + "vmla.f32 q15, q1, q3\n" /* cr51 += beta * c_r51 */ + "11: \n" /* check loop count */ + "vld1.32 {d0-d1}, [%[a_ptr] :64]! @ load a0~a3\n" + "vld1.32 {d4-d5}, [%[b_ptr] :128]! @ load b1\n" + "cmp %[k], #0 @ check weather k is bigger than " + "0\n" + "beq 0f @ jump to tail\n" + "1: @ main loop for k\n" + /* Unroll 0*/ + "vld1.32 {d2-d3}, [%[a_ptr] :64]! @ load a4, a5, and next a0, " + "a1\n" + "vmla.f32 q4, q2, d0[0] @ out0 += b1 * a0\n" + "vld1.32 {d6-d7}, [%[b_ptr] :128]! @ load b2\n" + "vmla.f32 q6, q2, d0[1] @ out1 += b1 * a1\n" + "vmla.f32 q8, q2, d1[0] @ out2 += b1 * a2\n" + "vmla.f32 q10, q2, d1[1] @ out3 += b1 * a3\n" + "vmla.f32 q12, q2, d2[0] @ out4 += b1 * a4\n" + "vmla.f32 q14, q2, d2[1] @ out5 += b1 * a5\n" + "vld1.32 {d4-d5}, [%[b_ptr] :128]! @ load b1\n" + "vmla.f32 q5, q3, d0[0] @ out6 += b2 * a0\n" + "vmla.f32 q7, q3, d0[1] @ out7 += b2 * a1\n" + "vmla.f32 q9, q3, d1[0] @ out8 += b2 * a2\n" + "vmla.f32 q11, q3, d1[1] @ out9 += b2 * a3\n" + "vld1.32 {d0-d1}, [%[a_ptr] :64]! @ load a2~a5\n" + "vmla.f32 q13, q3, d2[0] @ out10 += b2 * a4\n" + "vmla.f32 q15, q3, d2[1] @ out11 += b2 * a5\n" + "vld1.32 {d6-d7}, [%[b_ptr] :128]! @ load b2\n" + /* Unroll 1 */ + "vmla.f32 q4, q2, d3[0] @ out0 += b1 * a0\n" + "vmla.f32 q6, q2, d3[1] @ out1 += b1 * a1\n" + /*"pld [%[a_ptr], #64] @ preload a\n"*/ + "vmla.f32 q8, q2, d0[0] @ out2 += b1 * a2\n" + "vmla.f32 q10, q2, d0[1] @ out3 += b1 * a3\n" + /*"pld [%[b_ptr], #192]\n"*/ + "vmla.f32 q12, q2, d1[0] @ out4 += b1 * a4\n" + "vmla.f32 q14, q2, d1[1] @ out5 += b1 * a5\n" + "vld1.32 {d4-d5}, [%[b_ptr] :128]! @ load b1\n" + "vmla.f32 q5, q3, d3[0] @ out6 += b2 * a0\n" + "vmla.f32 q7, q3, d3[1] @ out7 += b2 * a1\n" + "vld1.32 {d2-d3}, [%[a_ptr] :64]! @ load a0~a3\n" + "vmla.f32 q9, q3, d0[0] @ out8 += b2 * a2\n" + "vmla.f32 q11, q3, d0[1] @ out9 += b2 * a3\n" + "vmla.f32 q13, q3, d1[0] @ out10 += b2 * a4\n" + "vmla.f32 q15, q3, d1[1] @ out11 += b2 * a5\n" + "vld1.32 {d0-d1}, [%[a_ptr] :64]! @ load a4, a5, a0, a1\n" + /* Unroll 2 */ + "vmla.f32 q4, q2, d2[0] @ out0 += b1 * a0\n" + "vld1.32 {d6-d7}, [%[b_ptr] :128]! @ load b2\n" + "vmla.f32 q6, q2, d2[1] @ out1 += b1 * a1\n" + "vmla.f32 q8, q2, d3[0] @ out2 += b1 * a2\n" + "vmla.f32 q10, q2, d3[1] @ out3 += b1 * a3\n" + /*"pld [%[a_ptr], #240] @ preload\n"*/ + "vmla.f32 q12, q2, d0[0] @ out4 += b1 * a4\n" + "vmla.f32 q14, q2, d0[1] @ out5 += b1 * a5\n" + "vld1.32 {d4-d5}, [%[b_ptr] :128]! 
@ load b1\n" + "vmla.f32 q5, q3, d2[0] @ out6 += b2 * a0\n" + "vmla.f32 q7, q3, d2[1] @ out7 += b2 * a1\n" + /*"pld [%[b_ptr], #208]\n"*/ + "vmla.f32 q9, q3, d3[0] @ out8 += b2 * a2\n" + "vmla.f32 q11, q3, d3[1] @ out9 += b2 * a3\n" + "vld1.32 {d2-d3}, [%[a_ptr] :64]! @ load a2~a5\n" + "vmla.f32 q13, q3, d0[0] @ out10 += b2 * a4\n" + "vmla.f32 q15, q3, d0[1] @ out11 += b2 * a5\n" + "vld1.32 {d6-d7}, [%[b_ptr] :128]! @ load b2\n" + /* Unroll 3 */ + "vmla.f32 q4, q2, d1[0] @ out0 += b1 * a0\n" + "vmla.f32 q6, q2, d1[1] @ out1 += b1 * a1\n" + "vmla.f32 q8, q2, d2[0] @ out2 += b1 * a2\n" + "vmla.f32 q10, q2, d2[1] @ out3 += b1 * a3\n" + "vmla.f32 q12, q2, d3[0] @ out4 += b1 * a4\n" + "vmla.f32 q14, q2, d3[1] @ out5 += b1 * a5\n" + "vld1.32 {d4-d5}, [%[b_ptr] :128]! @ load b1\n" + "vmla.f32 q5, q3, d1[0] @ out6 += b2 * a0\n" + "vmla.f32 q7, q3, d1[1] @ out7 += b2 * a1\n" + "vld1.32 {d0-d1}, [%[a_ptr] :64]! @ load a0~a3\n" + "vmla.f32 q9, q3, d2[0] @ out8 += b2 * a2\n" + "vmla.f32 q11, q3, d2[1] @ out9 += b2 * a3\n" + "subs %[k], %[k], #1 @ k--\n" + "vmla.f32 q13, q3, d3[0] @ out10 += b2 * a4\n" + "vmla.f32 q15, q3, d3[1] @ out11 += b2 * a5\n" + "bne 1b @ jump to main loop\n" + "0: @ process tail\n" + "subs %[tails], %[tails], #1 @ tail--\n" + "beq 3f @ jump to tail = 1\n" + /* Unroll 0*/ + "vld1.32 {d6-d7}, [%[b_ptr] :128]! @ load b2\n" + "vmla.f32 q4, q2, d0[0] @ out0 += b1 * a0\n" + "vld1.32 {d2-d3}, [%[a_ptr] :64]! @ load a4,5, a0, a1\n" + "vmla.f32 q6, q2, d0[1] @ out1 += b1 * a1\n" + "vmla.f32 q8, q2, d1[0] @ out2 += b1 * a2\n" + "vmla.f32 q10, q2, d1[1] @ out3 += b1 * a3\n" + "vmla.f32 q12, q2, d2[0] @ out4 += b1 * a4\n" + "subs %[tails], %[tails], #1 @ tail--\n" + "vmla.f32 q14, q2, d2[1] @ out5 += b1 * a5\n" + "vld1.32 {d4-d5}, [%[b_ptr] :128]! @ load b1\n" + "vmla.f32 q5, q3, d0[0] @ out6 += b2 * a0\n" + "vmla.f32 q7, q3, d0[1] @ out7 += b2 * a1\n" + "vmla.f32 q9, q3, d1[0] @ out8 += b2 * a2\n" + "vmla.f32 q11, q3, d1[1] @ out9 += b2 * a3\n" + "vld1.32 {d0-d1}, [%[a_ptr] :64]! @ load a2~a5\n" + "vmla.f32 q13, q3, d2[0] @ out10 += b2 * a4\n" + "vmla.f32 q15, q3, d2[1] @ out11 += b2 * a5\n" + "vld1.32 {d6-d7}, [%[b_ptr] :128]! @ load b2\n" + "beq 4f @ jump to tail==2\n" + /* Unroll 1*/ + "vmla.f32 q4, q2, d3[0] @ out0 += b1 * a0\n" + "vmla.f32 q6, q2, d3[1] @ out1 += b1 * a1\n" + "subs %[tails], %[tails], #1 @ tail--\n" + "vmla.f32 q8, q2, d0[0] @ out2 += b1 * a2\n" + "vmla.f32 q10, q2, d0[1] @ out3 += b1 * a3\n" + "vmla.f32 q12, q2, d1[0] @ out4 += b1 * a4\n" + "vmla.f32 q14, q2, d1[1] @ out5 += b1 * a5\n" + "vld1.32 {d4-d5}, [%[b_ptr] :128]! @ load b1\n" + "vmla.f32 q5, q3, d3[0] @ out6 += b2 * a0\n" + "vmla.f32 q7, q3, d3[1] @ out7 += b2 * a1\n" + "vld1.32 {d2-d3}, [%[a_ptr] :64]! @ load a0~a3\n" + "vmla.f32 q9, q3, d0[0] @ out8 += b2 * a2\n" + "vmla.f32 q11, q3, d0[1] @ out9 += b2 * a3\n" + "vmla.f32 q13, q3, d1[0] @ out10 += b2 * a4\n" + "vmla.f32 q15, q3, d1[1] @ out11 += b2 * a5\n" + "vld1.32 {d6-d7}, [%[b_ptr] :128]! @ load b2\n" + "beq 5f @ jump to tail==3\n" + /* Unroll 2 */ + "vld1.32 {d0-d1}, [%[a_ptr] :64]! @ load a4,a5, a0,a1\n" + "vmla.f32 q4, q2, d2[0] @ out0 += b1 * a0\n" + "vmla.f32 q6, q2, d2[1] @ out1 += b1 * a1\n" + "vmla.f32 q8, q2, d3[0] @ out2 += b1 * a2\n" + "vmla.f32 q10, q2, d3[1] @ out3 += b1 * a3\n" + "vmla.f32 q12, q2, d0[0] @ out4 += b1 * a4\n" + "vmla.f32 q14, q2, d0[1] @ out5 += b1 * a5\n" + "vld1.32 {d4-d5}, [%[b_ptr] :128]! 
@ load b1\n" + "vmla.f32 q5, q3, d2[0] @ out6 += b2 * a0\n" + "vmla.f32 q7, q3, d2[1] @ out7 += b2 * a1\n" + "vmla.f32 q9, q3, d3[0] @ out8 += b2 * a2\n" + "vmla.f32 q11, q3, d3[1] @ out9 += b2 * a3\n" + "vld1.32 {d2-d3}, [%[a_ptr] :64]! @ load a2~a5\n" + "vmla.f32 q13, q3, d0[0] @ out10 += b2 * a4\n" + "vmla.f32 q15, q3, d0[1] @ out11 += b2 * a5\n" + "vld1.32 {d6-d7}, [%[b_ptr] :128]! @ load b2\n" + /* Unroll 3*/ + "vmla.f32 q4, q2, d1[0] @ out0 += b1 * a0\n" + "vmla.f32 q6, q2, d1[1] @ out1 += b1 * a1\n" + "vmla.f32 q8, q2, d2[0] @ out2 += b1 * a2\n" + "vmla.f32 q10, q2, d2[1] @ out3 += b1 * a3\n" + "vmla.f32 q12, q2, d3[0] @ out4 += b1 * a4\n" + "vmla.f32 q14, q2, d3[1] @ out5 += b1 * a5\n" + "vmla.f32 q5, q3, d1[0] @ out6 += b2 * a0\n" + "vmla.f32 q7, q3, d1[1] @ out7 += b2 * a1\n" + "vmla.f32 q9, q3, d2[0] @ out8 += b2 * a2\n" + "vmla.f32 q11, q3, d2[1] @ out9 += b2 * a3\n" + "vmla.f32 q13, q3, d3[0] @ out10 += b2 * a4\n" + "vmla.f32 q15, q3, d3[1] @ out11 += b2 * a5\n" + "b 2f\n" + /* tails==1 final tail*/ + "3: @ tail=1\n" + "vmla.f32 q4, q2, d0[0] @ out0 += b1 * a0\n" + "vld1.32 {d2}, [%[a_ptr] :64]! @ load a4,a5\n" + "vmla.f32 q6, q2, d0[1] @ out1 += b1 * a1\n" + "vld1.32 {d6-d7}, [%[b_ptr] :128]! @ load b2\n" + "vmla.f32 q8, q2, d1[0] @ out2 += b1 * a2\n" + "vmla.f32 q10, q2, d1[1] @ out3 += b1 * a3\n" + "vmla.f32 q12, q2, d2[0] @ out4 += b1 * a4\n" + "vmla.f32 q14, q2, d2[1] @ out5 += b1 * a5\n" + "vmla.f32 q5, q3, d0[0] @ out6 += b2 * a0\n" + "vmla.f32 q7, q3, d0[1] @ out7 += b2 * a1\n" + "vmla.f32 q9, q3, d1[0] @ out8 += b2 * a2\n" + "vmla.f32 q11, q3, d1[1] @ out9 += b2 * a3\n" + "vmla.f32 q13, q3, d2[0] @ out10 += b2 * a4\n" + "vmla.f32 q15, q3, d2[1] @ out11 += b2 * a5\n" + "b 2f @ jump to end\n" + /* tails==2 final tail*/ + "4: @ tail == 2\n" + "vmla.f32 q4, q2, d3[0] @ out0 += b1 * a0\n" + "vmla.f32 q6, q2, d3[1] @ out1 += b1 * a1\n" + "vmla.f32 q8, q2, d0[0] @ out2 += b1 * a2\n" + "vmla.f32 q10, q2, d0[1] @ out3 += b1 * a3\n" + "vmla.f32 q12, q2, d1[0] @ out4 += b1 * a4\n" + "vmla.f32 q14, q2, d1[1] @ out5 += b1 * a5\n" + "vmla.f32 q5, q3, d3[0] @ out6 += b2 * a0\n" + "vmla.f32 q7, q3, d3[1] @ out7 += b2 * a1\n" + "vmla.f32 q9, q3, d0[0] @ out8 += b2 * a2\n" + "vmla.f32 q11, q3, d0[1] @ out9 += b2 * a3\n" + "vmla.f32 q13, q3, d1[0] @ out10 += b2 * a4\n" + "vmla.f32 q15, q3, d1[1] @ out11 += b2 * a5\n" + "b 2f @ jump to end\n" + /* tails==3 final tail*/ + "5: @ tail=3\n" + "vmla.f32 q4, q2, d2[0] @ out0 += b1 * a0\n" + "vld1.32 {d0}, [%[a_ptr] :64]! 
@ load a4,a5\n"
+            "vmla.f32 q6, q2, d2[1]   @ out1 += b1 * a1\n"
+            "vmla.f32 q8, q2, d3[0]   @ out2 += b1 * a2\n"
+            "vmla.f32 q10, q2, d3[1]  @ out3 += b1 * a3\n"
+            "vmla.f32 q12, q2, d0[0]  @ out4 += b1 * a4\n"
+            "vmla.f32 q14, q2, d0[1]  @ out5 += b1 * a5\n"
+            "vmla.f32 q5, q3, d2[0]   @ out6 += b2 * a0\n"
+            "vmla.f32 q7, q3, d2[1]   @ out7 += b2 * a1\n"
+            "vmla.f32 q9, q3, d3[0]   @ out8 += b2 * a2\n"
+            "vmla.f32 q11, q3, d3[1]  @ out9 += b2 * a3\n"
+            "vmla.f32 q13, q3, d0[0]  @ out10 += b2 * a4\n"
+            "vmla.f32 q15, q3, d0[1]  @ out11 += b2 * a5\n"
+            "2:                       @ check relu\n"
+            "cmp %[relu], #0          @ check if has relu\n"
+            "ble 6f                   @ skip relu if relu <= 0\n"
+            "vmov.u32 q0, #0          @ for relu\n"
+            "vmax.f32 q4, q4, q0      @ for relu\n"
+            "vmax.f32 q5, q5, q0      @ for relu\n"
+            "vmax.f32 q6, q6, q0      @ for relu\n"
+            "vmax.f32 q7, q7, q0      @ for relu\n"
+            "vmax.f32 q8, q8, q0      @ for relu\n"
+            "vmax.f32 q9, q9, q0      @ for relu\n"
+            "vmax.f32 q10, q10, q0    @ for relu\n"
+            "vmax.f32 q11, q11, q0    @ for relu\n"
+            "vmax.f32 q12, q12, q0    @ for relu\n"
+            "vmax.f32 q13, q13, q0    @ for relu\n"
+            "vmax.f32 q14, q14, q0    @ for relu\n"
+            "vmax.f32 q15, q15, q0    @ for relu\n"
+            "6:                       @ store result\n"
+            "vst1.32 {d8-d11}, [%[c_ptr0]]!   @ store r0\n"
+            "vst1.32 {d12-d15}, [%[c_ptr1]]!  @ store r1\n"
+            "vst1.32 {d16-d19}, [%[c_ptr2]]!  @ store r2\n"
+            "vst1.32 {d20-d23}, [%[c_ptr3]]!  @ store r3\n"
+            "vst1.32 {d24-d27}, [%[c_ptr4]]!  @ store r4\n"
+            "vst1.32 {d28-d31}, [%[c_ptr5]]!  @ store r5\n"
+            : [a_ptr] "+r"(a_ptr),
+              [b_ptr] "+r"(b_ptr),
+              [c_ptr0] "+r"(c_ptr0),
+              [c_ptr1] "+r"(c_ptr1),
+              [c_ptr2] "+r"(c_ptr2),
+              [c_ptr3] "+r"(c_ptr3),
+              [c_ptr4] "+r"(c_ptr4),
+              [c_ptr5] "+r"(c_ptr5),
+              [k] "+r"(k),
+              [tails] "+r"(tails)
+            : [bias_ptr] "r"(bias_local),
+              [relu] "r"(has_relu),
+              [has_beta] "r"(has_beta),
+              [beta] "r"(beta)
+            : "q0",
+              "q1",
+              "q2",
+              "q3",
+              "q4",
+              "q5",
+              "q6",
+              "q7",
+              "q8",
+              "q9",
+              "q10",
+              "q11",
+              "q12",
+              "q13",
+              "q14",
+              "q15",
+              "cc",
+              "memory");
+
+        if (flag_p_remain && (xb == bblocks - 1)) {
+          for (int i = 0; i < remain; ++i) {
+            *pout0++ = cout0[i];
+            *pout1++ = cout1[i];
+            *pout2++ = cout2[i];
+            *pout3++ = cout3[i];
+            *pout4++ = cout4[i];
+            *pout5++ = cout5[i];
+          }
+        }
+      }
+    }
+  }
+}
+
+void sgemm_prepacked_4x8(bool is_transB,
+                         int M,
+                         int N,
+                         int K,
+                         const float* A_packed,
+                         const float* B,
+                         int ldb,
+                         float beta,
+                         float* C,
+                         int ldc,
+                         const float* bias,
+                         bool has_bias,
+                         bool has_relu,
+                         ARMContext* ctx) {
+  size_t l2_cache = ctx->llc_size() > 0 ? ctx->llc_size() : 512 * 1024;
+  auto* workspace = ctx->workspace_data<float>();
+  int threads = ctx->threads();
+  //! MBLOCK * x (result) + MBLOCK * k (A) + x * k (B) = l2
+  int x_block =
+      (l2_cache - (MBLOCK_A73 * K)) / (sizeof(float) * (K + MBLOCK_A73));
+  x_block /= NBLOCK;
+  x_block *= NBLOCK;
+  int x_num = (N + (x_block - 1)) / x_block;
+  x_block = (N + x_num - 1) / x_num;
+  x_block = (x_block + NBLOCK - 1) / NBLOCK;
+  x_block *= NBLOCK;
+  x_block = x_block < NBLOCK ? NBLOCK : x_block;
+
+  int k_pre = ((K + KBLOCK - 1) / KBLOCK) - 1;
+  int tail_pre = (K & (KBLOCK - 1));
+  if (tail_pre == 0) {
+    tail_pre = KBLOCK;
+  }
+
+  bool flag_p_remain = false;
+  int remain = 0;
+
+  int has_beta = fabsf(beta) > 1e-8f ? 1 : 0;
+
+  //! the A panel is prepacked outside this gemm
+  for (unsigned int x0 = 0; x0 < N; x0 += x_block) {
+    unsigned int xmax = x0 + x_block;
+    if (xmax > N) {
+      xmax = N;
+    }
+    int bblocks = (xmax - x0 + NBLOCK - 1) / NBLOCK;
+    remain = xmax - x0 - (bblocks - 1) * NBLOCK;
+    if (remain > 0) {
+      flag_p_remain = true;
+    }
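One point worth spelling out: the micro-kernel always stores a full MBLOCK x NBLOCK tile, so when fewer rows or columns remain, the output pointers are first redirected to stack scratch tiles (the `cout*` arrays) and only the valid elements are copied out afterwards. A scalar model of that remainder handling, with hypothetical names, not part of the patch:

#include <cstdio>

int main() {
  const int NBLOCK = 8;
  float C[5] = {0};                   // pretend only 5 output columns remain
  int remain = 5;
  float cout0[NBLOCK];                // scratch tile row on the stack
  for (int i = 0; i < NBLOCK; ++i) {  // micro-kernel: full-width store
    cout0[i] = 1.0f + i;
  }
  for (int i = 0; i < remain; ++i) {  // copy back only the valid columns
    C[i] = cout0[i];
  }
  printf("C[4] = %f (no out-of-bounds store past column 4)\n", C[4]);
  return 0;
}

+    //! load the B panel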
+    auto b_pannel = static_cast<float*>(workspace);
+    if (is_transB) {
+      loadb_trans(b_pannel, B, ldb, 0, K, x0, xmax);
+    } else {
+      loadb(b_pannel, B, ldb, 0, K, x0, xmax);
+    }
+#pragma omp parallel for num_threads(threads)
+    for (unsigned int y = 0; y < M; y += MBLOCK_A73) {
+      unsigned int ymax = y + MBLOCK_A73;
+      if (ymax > M) {
+        ymax = M;
+      }
+
+      float cout0[NBLOCK];
+      float cout1[NBLOCK];
+      float cout2[NBLOCK];
+      float cout3[NBLOCK];
+
+      float bias_local[4] = {0};
+      if (has_bias) {
+        bias_local[0] = bias[y];
+        bias_local[1] = bias[y + 1];
+        bias_local[2] = bias[y + 2];
+        bias_local[3] = bias[y + 3];
+      }
+
+      float* c_ptr0 = C + y * ldc + x0;
+      float* c_ptr1 = c_ptr0 + ldc;
+      float* c_ptr2 = c_ptr1 + ldc;
+      float* c_ptr3 = c_ptr2 + ldc;
+
+      float* pout0 = c_ptr0;
+      float* pout1 = c_ptr1;
+      float* pout2 = c_ptr2;
+      float* pout3 = c_ptr3;
+
+      const float* a_ptr_l = A_packed + y * K;
+      const float* b_ptr = b_pannel;
+      for (int xb = 0; xb < bblocks; xb++) {
+        if ((y + 3) >= ymax) {
+          switch ((y + 3) - ymax) {
+            case 2:
+              c_ptr1 = cout1;
+            case 1:
+              c_ptr2 = cout1;
+            case 0:
+              c_ptr3 = cout1;
+            default:
+              break;
+          }
+        }
+        if (flag_p_remain && (xb == bblocks - 1)) {
+          pout0 = c_ptr0;
+          pout1 = c_ptr1;
+          pout2 = c_ptr2;
+          pout3 = c_ptr3;
+
+          c_ptr0 = cout0;
+          c_ptr1 = cout1;
+          c_ptr2 = cout2;
+          c_ptr3 = cout3;
+
+          if (has_beta) {
+            for (int i = 0; i < remain; ++i) {
+              cout0[i] = pout0[i];
+              cout1[i] = pout1[i];
+              cout2[i] = pout2[i];
+              cout3[i] = pout3[i];
+            }
+          }
+        }
+        const float* a_ptr = a_ptr_l;
+        int tails = tail_pre;
+        int k = k_pre;
+        asm volatile(
+            "vld1.32 {d4-d5}, [%[bias_ptr]] @ load bias\n"
+            "vdup.32 q8, d4[0]              @ add bias to out00\n"
+            "pld [%[a_ptr]]                 @ preload a, 64byte\n"
+            "vdup.32 q9, d4[0]              @ add bias to out01\n"
+            "pld [%[b_ptr]]                 @ preload b\n"
+            "vdup.32 q10, d4[1]             @ add bias to out10\n"
+            "pld [%[a_ptr], #64]            @ preload a\n"
+            "vdup.32 q11, d4[1]             @ add bias to out11\n"
+            "vdup.32 q12, d5[0]             @ add bias to out20\n"
+            "pld [%[b_ptr], #64]            @ preload b\n"
+            "vdup.32 q13, d5[0]             @ add bias to out21\n"
+            "pld [%[a_ptr], #128]           @ preload a\n"
+            "vdup.32 q14, d5[1]             @ add bias to out30\n"
+            "pld [%[b_ptr], #128]           @ preload b\n"
+            "vdup.32 q15, d5[1]             @ add bias to out31\n"
+            "pld [%[b_ptr], #192]           @ preload b\n"
+            "cmp %[has_beta], #0\n"
+            "beq 11f\n" /* check beta == 0? */
+            /* process beta */
+            "vdup.32 q4, %[beta]\n"           /* beta to vector */
+            "vld1.32 {d0-d3}, [%[c_ptr0]]\n"  /* load output r0 */
+            "vld1.32 {d4-d7}, [%[c_ptr1]]\n"  /* load output r1 */
+            "vmla.f32 q8, q0, q4\n"           /* cr00 += beta * c_r00 */
+            "vmla.f32 q9, q1, q4\n"           /* cr01 += beta * c_r01 */
+            "vld1.32 {d0-d3}, [%[c_ptr2]]\n"  /* load output r2 */
+            "vmla.f32 q10, q2, q4\n"          /* cr10 += beta * c_r10 */
+            "vmla.f32 q11, q3, q4\n"          /* cr11 += beta * c_r11 */
+            "vld1.32 {d4-d7}, [%[c_ptr3]]\n"  /* load output r3 */
+            "vmla.f32 q12, q0, q4\n"          /* cr20 += beta * c_r20 */
+            "vmla.f32 q13, q1, q4\n"          /* cr21 += beta * c_r21 */
+            "vmla.f32 q14, q2, q4\n"          /* cr30 += beta * c_r30 */
+            "vmla.f32 q15, q3, q4\n"          /* cr31 += beta * c_r31 */
+            "11: \n"                          /* check loop count */
+            "vld1.32 {d0-d3}, [%[a_ptr] :128]!  @ load a0~a3\n"
+            "vld1.32 {d8-d11}, [%[b_ptr] :128]! @ load b1\n"
+            "cmp %[k], #0                       @ check whether k > 0\n"
+            "beq 0f                             @ jump to tail\n"
+            "1:                                 @ main loop for k\n"
+            /* Unroll 0 */
+            "vld1.32 {d12-d15}, [%[b_ptr] :128]! @ load next b1, b2\n"
+            "vmla.f32 q8, q4, d0[0]  @ out0 += b1 * a0\n"
+            "vld1.32 {d4-d7}, [%[a_ptr] :128]! 
@ load next 2xa0~a3\n" + "vmla.f32 q10, q4, d0[1] @ out1 += b1 * a1\n" + "vmla.f32 q12, q4, d1[0] @ out2 += b1 * a2\n" + "vmla.f32 q14, q4, d1[1] @ out3 += b1 * a3\n" + "vmla.f32 q9, q5, d0[0] @ out4 += b2 * a0\n" + "vmla.f32 q11, q5, d0[1] @ out5 += b2 * a1\n" + "vmla.f32 q13, q5, d1[0] @ out6 += b2 * a2\n" + "vmla.f32 q15, q5, d1[1] @ out7 += b2 * a3\n" + "vld1.32 {d8-d11}, [%[b_ptr] :128]! @ load next b1, b2\n" + /* Unroll 1 */ + "vmla.f32 q8, q6, d2[0] @ out0 += b1 * a0\n" + "pld [%[b_ptr], #64] @ preload b\n" + "vmla.f32 q10, q6, d2[1] @ out1 += b1 * a1\n" + "vmla.f32 q12, q6, d3[0] @ out2 += b1 * a2\n" + "vmla.f32 q14, q6, d3[1] @ out3 += b1 * a3\n" + "vmla.f32 q9, q7, d2[0] @ out6 += b2 * a0\n" + "vmla.f32 q11, q7, d2[1] @ out7 += b2 * a1\n" + "vmla.f32 q13, q7, d3[0] @ out8 += b2 * a2\n" + "vmla.f32 q15, q7, d3[1] @ out9 += b2 * a3\n" + "vld1.32 {d12-d15}, [%[b_ptr] :128]! @ load next b1,b2\n" + /* Unroll 2 */ + "vmla.f32 q8, q4, d4[0] @ out0 += b1 * a0\n" + "vld1.32 {d0-d3}, [%[a_ptr] :128]! @ load next a0~a3\n" + "vmla.f32 q10, q4, d4[1] @ out1 += b1 * a1\n" + "vmla.f32 q12, q4, d5[0] @ out2 += b1 * a2\n" + "vmla.f32 q14, q4, d5[1] @ out3 += b1 * a3\n" + "vmla.f32 q9, q5, d4[0] @ out4 += b2 * a0\n" + "vmla.f32 q11, q5, d4[1] @ out5 += b2 * a1\n" + "vmla.f32 q13, q5, d5[0] @ out6 += b2 * a2\n" + "vmla.f32 q15, q5, d5[1] @ out7 += b2 * a3\n" + "vld1.32 {d8-d11}, [%[b_ptr] :128]! @ load next b1, b2\n" + /* Unroll 3 */ + "vmla.f32 q8, q6, d6[0] @ out0 += b1 * a0\n" + "pld [%[a_ptr], #64] @ preload a\n" + "vmla.f32 q10, q6, d6[1] @ out1 += b1 * a1\n" + "vmla.f32 q12, q6, d7[0] @ out2 += b1 * a2\n" + "vmla.f32 q14, q6, d7[1] @ out3 += b1 * a3\n" + "vmla.f32 q9, q7, d6[0] @ out4 += b2 * a0\n" + "vmla.f32 q11, q7, d6[1] @ out5 += b2 * a1\n" + "vmla.f32 q13, q7, d7[0] @ out6 += b2 * a2\n" + "vmla.f32 q15, q7, d7[1] @ out7 += b2 * a3\n" + "subs %[k], %[k], #1 @ k--\n" + "bne 1b @ jump to main loop\n" + "0: @ process tail\n" + "subs %[tails], %[tails], #1 @ tail--\n" + "beq 3f @ jump to tail = 1\n" + /* Unroll 0*/ + "vld1.32 {d12-d15}, [%[b_ptr] :128]! @ load next b1, b2\n" + "vmla.f32 q8, q4, d0[0] @ out0 += b1 * a0\n" + "vmla.f32 q10, q4, d0[1] @ out1 += b1 * a1\n" + "subs %[tails], %[tails], #1 @ tail--\n" + "vmla.f32 q12, q4, d1[0] @ out2 += b1 * a2\n" + "vmla.f32 q14, q4, d1[1] @ out3 += b1 * a3\n" + "vmla.f32 q9, q5, d0[0] @ out4 += b2 * a0\n" + "vmla.f32 q11, q5, d0[1] @ out5 += b2 * a1\n" + "vmla.f32 q13, q5, d1[0] @ out6 += b2 * a2\n" + "vmla.f32 q15, q5, d1[1] @ out7 += b2 * a3\n" + "beq 4f @ jump to tail==2\n" + /* Unroll 1 */ + "vld1.32 {d8-d11}, [%[b_ptr] :128]! @ load next b1, b2\n" + "vmla.f32 q8, q6, d2[0] @ out0 += b1 * a0\n" + "vld1.32 {d4-d7}, [%[a_ptr] :128]! @ load next 2xa0~a3\n" + "vmla.f32 q10, q6, d2[1] @ out1 += b1 * a1\n" + "subs %[tails], %[tails], #1 @ tail--\n" + "vmla.f32 q12, q6, d3[0] @ out2 += b1 * a2\n" + "vmla.f32 q14, q6, d3[1] @ out3 += b1 * a3\n" + "vmla.f32 q9, q7, d2[0] @ out6 += b2 * a0\n" + "vmla.f32 q11, q7, d2[1] @ out7 += b2 * a1\n" + "vmla.f32 q13, q7, d3[0] @ out8 += b2 * a2\n" + "vmla.f32 q15, q7, d3[1] @ out9 += b2 * a3\n" + "beq 5f @ jump to tail==3\n" + /* Unroll 2 */ + "vld1.32 {d12-d15}, [%[b_ptr] :128]! 
@ load next b1,b2\n" + "vmla.f32 q8, q4, d4[0] @ out0 += b1 * a0\n" + "vmla.f32 q10, q4, d4[1] @ out1 += b1 * a1\n" + "vmla.f32 q12, q4, d5[0] @ out2 += b1 * a2\n" + "vmla.f32 q14, q4, d5[1] @ out3 += b1 * a3\n" + "vmla.f32 q9, q5, d4[0] @ out4 += b2 * a0\n" + "vmla.f32 q11, q5, d4[1] @ out5 += b2 * a1\n" + "vmla.f32 q13, q5, d5[0] @ out6 += b2 * a2\n" + "vmla.f32 q15, q5, d5[1] @ out7 += b2 * a3\n" + /* Unroll 3 */ + "vmla.f32 q8, q6, d6[0] @ out0 += b1 * a0\n" + "vmla.f32 q10, q6, d6[1] @ out1 += b1 * a1\n" + "vmla.f32 q12, q6, d7[0] @ out2 += b1 * a2\n" + "vmla.f32 q14, q6, d7[1] @ out3 += b1 * a3\n" + "vmla.f32 q9, q7, d6[0] @ out4 += b2 * a0\n" + "vmla.f32 q11, q7, d6[1] @ out5 += b2 * a1\n" + "vmla.f32 q13, q7, d7[0] @ out6 += b2 * a2\n" + "vmla.f32 q15, q7, d7[1] @ out7 += b2 * a3\n" + "b 2f\n" + /* tails==1 final tail */ + "3: @ tail=1\n" + "vmla.f32 q8, q4, d0[0] @ out0 += b1 * a0\n" + "vmla.f32 q10, q4, d0[1] @ out1 += b1 * a1\n" + "vmla.f32 q12, q4, d1[0] @ out2 += b1 * a2\n" + "vmla.f32 q14, q4, d1[1] @ out3 += b1 * a3\n" + "vmla.f32 q9, q5, d0[0] @ out4 += b2 * a0\n" + "vmla.f32 q11, q5, d0[1] @ out5 += b2 * a1\n" + "vmla.f32 q13, q5, d1[0] @ out6 += b2 * a2\n" + "vmla.f32 q15, q5, d1[1] @ out7 += b2 * a3\n" + /*aptr - 16 */ + "sub %[a_ptr], %[a_ptr], #16 @ tail--\n" + "b 2f @ jump to end\n" + /* tails==2 final tail*/ + "4: @ tail == 2\n" + "vmla.f32 q8, q6, d2[0] @ out0 += b1 * a0\n" + "vmla.f32 q10, q6, d2[1] @ out1 += b1 * a1\n" + "vmla.f32 q12, q6, d3[0] @ out2 += b1 * a2\n" + "vmla.f32 q14, q6, d3[1] @ out3 += b1 * a3\n" + "vmla.f32 q9, q7, d2[0] @ out4 += b2 * a0\n" + "vmla.f32 q11, q7, d2[1] @ out5 += b2 * a1\n" + "vmla.f32 q13, q7, d3[0] @ out6 += b2 * a2\n" + "vmla.f32 q15, q7, d3[1] @ out7 += b2 * a3\n" + "b 2f @ jump to end\n" + /* tails==3 final tail*/ + "5: @ tail=3\n" + "vmla.f32 q8, q4, d4[0] @ out0 += b1 * a0\n" + "vmla.f32 q10, q4, d4[1] @ out1 += b1 * a1\n" + "vmla.f32 q12, q4, d5[0] @ out2 += b1 * a2\n" + "vmla.f32 q14, q4, d5[1] @ out3 += b1 * a3\n" + "vmla.f32 q9, q5, d4[0] @ out4 += b2 * a0\n" + "vmla.f32 q11, q5, d4[1] @ out5 += b2 * a1\n" + "vmla.f32 q13, q5, d5[0] @ out6 += b2 * a2\n" + "vmla.f32 q15, q5, d5[1] @ out7 += b2 * a3\n" + /*aptr - 16*/ + "sub %[a_ptr], %[a_ptr], #16 @ tail--\n" + "2: @ check relu\n" + "cmp %[relu], #0 @ check if has relu\n" + "ble 6f @ skip relu if relu <= 0\n" + "vmov.u32 q0, #0 @ for relu\n" + "vmax.f32 q8, q8, q0 @ for relu\n" + "vmax.f32 q9, q9, q0 @ for relu\n" + "vmax.f32 q10, q10, q0 @ for relu\n" + "vmax.f32 q11, q11, q0 @ for relu\n" + "vmax.f32 q12, q12, q0 @ for relu\n" + "vmax.f32 q13, q13, q0 @ for relu\n" + "vmax.f32 q14, q14, q0 @ for relu\n" + "vmax.f32 q15, q15, q0 @ for relu\n" + "6: @ store result\n" + "vst1.32 {d16-d19}, [%[c_ptr0]]! @ store r0\n" + "vst1.32 {d20-d23}, [%[c_ptr1]]! @ store r1\n" + "vst1.32 {d24-d27}, [%[c_ptr2]]! @ store r2\n" + "vst1.32 {d28-d31}, [%[c_ptr3]]! 
@ store r3\n"
+            : [a_ptr] "+r"(a_ptr),
+              [b_ptr] "+r"(b_ptr),
+              [c_ptr0] "+r"(c_ptr0),
+              [c_ptr1] "+r"(c_ptr1),
+              [c_ptr2] "+r"(c_ptr2),
+              [c_ptr3] "+r"(c_ptr3),
+              [k] "+r"(k),
+              [tails] "+r"(tails)
+            : [bias_ptr] "r"(bias_local),
+              [relu] "r"(has_relu),
+              [has_beta] "r"(has_beta),
+              [beta] "r"(beta)
+            : "q0",
+              "q1",
+              "q2",
+              "q3",
+              "q4",
+              "q5",
+              "q6",
+              "q7",
+              "q8",
+              "q9",
+              "q10",
+              "q11",
+              "q12",
+              "q13",
+              "q14",
+              "q15",
+              "cc",
+              "memory");
+
+        if (flag_p_remain && (xb == bblocks - 1)) {
+          for (int i = 0; i < remain; ++i) {
+            *pout0++ = cout0[i];
+            *pout1++ = cout1[i];
+            *pout2++ = cout2[i];
+            *pout3++ = cout3[i];
+          }
+        }
+      }
+    }
+  }
+}
+#endif  // __aarch64__
+
+}  // namespace math
+}  // namespace arm
+}  // namespace lite
+}  // namespace paddle
diff --git a/lite/arm/math/packed_sgemm.h b/lite/arm/math/packed_sgemm.h
new file mode 100644
index 00000000000..978ddc97601
--- /dev/null
+++ b/lite/arm/math/packed_sgemm.h
@@ -0,0 +1,84 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <cmath>
+#include "lite/core/context.h"
+#include "lite/core/cpu_info.h"
+#include "lite/core/tensor.h"
+
+namespace paddle {
+namespace lite {
+namespace arm {
+namespace math {
+
+#ifdef __aarch64__
+constexpr int MBLOCK = 8;
+constexpr int NBLOCK = 12;
+constexpr int KBLOCK = 4;
+inline int get_hblock(ARMArch arch) { return MBLOCK; }
+#else
+constexpr int MBLOCK_A73 = 4;
+constexpr int MBLOCK_OTH = 6;
+constexpr int NBLOCK = 8;
+constexpr int KBLOCK = 4;
+inline int get_hblock(ARMArch arch) {
+  if (arch == kA73) {
+    return MBLOCK_A73;
+  } else {
+    return MBLOCK_OTH;
+  }
+}
+#endif  // __aarch64__
+
+void prepackA(float* out,
+              const float* in,
+              float alpha,
+              int ldin,
+              int m0,
+              int mmax,
+              int k0,
+              int kmax,
+              bool is_trans,
+              ARMContext* ctx);
+
+void prepackA(TensorLite* tout,
+              const TensorLite& tin,
+              float alpha,
+              int m,
+              int k,
+              int group,
+              bool is_trans,
+              ARMContext* ctx);
+
+void sgemm_prepack(bool is_transB,
+                   int M,
+                   int N,
+                   int K,
+                   const float* A_packed,
+                   const float* B,
+                   int ldb,
+                   float beta,
+                   float* C,
+                   int ldc,
+                   const float* bias,
+                   bool has_bias,
+                   bool has_relu,
+                   ARMContext* ctx);
+
+}  // namespace math
+}  // namespace arm
+}  // namespace lite
+}  // namespace paddle
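For orientation, a hedged usage sketch of the API this header declares, modeled on how the kernels above consume it: pack A once, then call the prepacked sgemm. The shapes, alpha = 1.f, beta = 0.f, and the `ctx->arch()` accessor are assumptions for illustration; the context (including workspace sizing) is assumed to be set up elsewhere, as the tests under lite/api do.

#include <vector>
#include "lite/arm/math/packed_sgemm.h"

void run_sgemm(const float* A, const float* B, float* C,
               int M, int N, int K, paddle::lite::ARMContext* ctx) {
  using namespace paddle::lite::arm::math;
  int hblock = get_hblock(ctx->arch());           // assumed arch() accessor
  int m_round = ((M + hblock - 1) / hblock) * hblock;
  std::vector<float> packed_A(m_round * K);       // A panel, padded to hblock
  prepackA(packed_A.data(), A, 1.f, K /*ldin*/, 0, M, 0, K,
           false /*is_trans*/, ctx);
  sgemm_prepack(false /*is_transB*/, M, N, K, packed_A.data(), B, N /*ldb*/,
                0.f /*beta*/, C, N /*ldc*/, nullptr /*bias*/,
                false /*has_bias*/, false /*has_relu*/, ctx);
}

Packing A once and reusing it across calls is the point of the two-step API: the pack cost is amortized when the same weights multiply many activations.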
+// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/arm/math/pad2d.h" +#include <arm_neon.h> +#include <string.h> +#include <vector> +#include "lite/arm/math/funcs.h" + +namespace paddle { +namespace lite { +namespace arm { +namespace math { + +void pad_constant(const float* din, + float* dout, + int n, + int c, + int h, + int w, + const int pad_top, + const int pad_bottom, + const int pad_left, + const int pad_right, + const float pad_value) { + int h_in = h - pad_top - pad_bottom; + int w_in = w - pad_left - pad_right; + int spatial_size_out = w * h; + int spatial_size_in = h_in * w_in; +#pragma omp parallel for + for (int s = 0; s < n * c; ++s) { + const float* din_s = din + s * spatial_size_in; + float* dout_s = dout + s * spatial_size_out; + int top_loop = (w * pad_top) >> 3; + int top_loop_remain = (w * pad_top) & 7; + float32x4_t vpad_value = vdupq_n_f32(pad_value); + // process top + for (int i = 0; i < top_loop; ++i) { + vst1q_f32(dout_s, vpad_value); + vst1q_f32(dout_s + 4, vpad_value); + dout_s += 8; + } + for (int i = 0; i < top_loop_remain; ++i) { + *dout_s++ = pad_value; + } + // process med + int left_loop = pad_left >> 2; + int left_loop_remain = pad_left & 3; + int med_loop = w_in >> 3; + int med_loop_remain = w_in & 7; + for (int i = 0; i < left_loop; ++i) { + vst1q_f32(dout_s, vpad_value); + dout_s += 4; + } + + for (int i = 0; i < left_loop_remain; ++i) { + *dout_s++ = pad_value; + } + + for (int i = 0; i < med_loop; ++i) { + float32x4_t val = vld1q_f32(din_s); + float32x4_t val1 = vld1q_f32(din_s + 4); + vst1q_f32(dout_s, val); + vst1q_f32(dout_s + 4, val1); + dout_s += 8; + din_s += 8; + } + for (int i = 0; i < med_loop_remain; ++i) { + float val = *din_s++; + *dout_s++ = val; + } + + int loop = (pad_right + pad_left) >> 2; + int loop_remain = (pad_right + pad_left) & 3; + for (int j = 0; j < h_in - 1; ++j) { + for (int i = 0; i < loop; ++i) { + vst1q_f32(dout_s, vpad_value); + dout_s += 4; + } + + for (int i = 0; i < loop_remain; ++i) { + *dout_s++ = pad_value; + } + + for (int i = 0; i < med_loop; ++i) { + float32x4_t val = vld1q_f32(din_s); + float32x4_t val1 = vld1q_f32(din_s + 4); + vst1q_f32(dout_s, val); + vst1q_f32(dout_s + 4, val1); + dout_s += 8; + din_s += 8; + } + + for (int i = 0; i < med_loop_remain; ++i) { + *dout_s++ = *din_s++; + } + } + int right_loop = pad_right >> 2; + int right_loop_remain = pad_right & 3; + + for (int i = 0; i < right_loop; ++i) { + vst1q_f32(dout_s, vpad_value); + dout_s += 4; + } + + for (int i = 0; i < right_loop_remain; ++i) { + *dout_s++ = pad_value; + } + // process bottom + int bottom_loop = (pad_bottom * w) >> 3; + int bottom_loop_remain = (pad_bottom * w) & 7; + for (int i = 0; i < bottom_loop; ++i) { + vst1q_f32(dout_s, vpad_value); + vst1q_f32(dout_s + 4, vpad_value); + dout_s += 8; + } + for (int i = 0; i < bottom_loop_remain; ++i) { + *dout_s++ = pad_value; + } + } +} + +void pad_edge(const float* din, + float* dout, + int n, + int c, + int h, + int w, + const int pad_top, + const int pad_bottom, + const int pad_left, + const int pad_right, + const float pad_value) { + int h_in = h - pad_top - pad_bottom; + int w_in = w - pad_left - pad_right; + int spatial_size_out = w * h; + int spatial_size_in = h_in * w_in; +#pragma omp parallel for + for (int s = 0; s < n * c; ++s) { + const float* din_s = din + s * spatial_size_in; + float* dout_s = dout + s * spatial_size_out; + + // process med + int left_loop = pad_left >> 2; + int right_loop = pad_right >> 2; + int med_loop = w_in >> 
3; + int med_loop_remain = w_in & 7; + int left_loop_remain = pad_left & 3; + int right_loop_remain = pad_right & 3; + float* dout_med = dout_s + w * pad_top; + for (int j = 0; j < h_in; ++j) { + float edge_val = din_s[0]; + float32x4_t vedge = vdupq_n_f32(edge_val); + for (int i = 0; i < left_loop; ++i) { + vst1q_f32(dout_med, vedge); + dout_med += 4; + } + for (int i = 0; i < left_loop_remain; ++i) { + *dout_med++ = edge_val; + } + for (int i = 0; i < med_loop; ++i) { + float32x4_t val = vld1q_f32(din_s); + float32x4_t val1 = vld1q_f32(din_s + 4); + vst1q_f32(dout_med, val); + vst1q_f32(dout_med + 4, val1); + din_s += 8; + dout_med += 8; + } + for (int i = 0; i < med_loop_remain; ++i) { + *dout_med++ = *din_s++; + } + edge_val = din_s[-1]; + vedge = vdupq_n_f32(edge_val); + for (int i = 0; i < right_loop; ++i) { + vst1q_f32(dout_med, vedge); + dout_med += 4; + } + for (int i = 0; i < right_loop_remain; ++i) { + *dout_med++ = edge_val; + } + } + + // process bottom + float* dout_bottom = dout_med; + for (int i = 0; i < pad_bottom; ++i) { + memcpy(dout_bottom, dout_s + w * (pad_top + h_in - 1), w * sizeof(float)); + dout_bottom += w; + } + + // process top + float* dout_top = dout_s; + for (int i = 0; i < pad_top; ++i) { + memcpy(dout_top, dout_s + w * pad_top, w * sizeof(float)); + dout_top += w; + } + } +} + +void pad_reflect(const float* din, + float* dout, + int n, + int c, + int h, + int w, + const int pad_top, + const int pad_bottom, + const int pad_left, + const int pad_right, + const float pad_value) { + int h_in = h - pad_top - pad_bottom; + int w_in = w - pad_left - pad_right; + int spatial_size_out = w * h; + int spatial_size_in = h_in * w_in; +#pragma omp parallel for + for (int s = 0; s < n * c; ++s) { + const float* din_s = din + s * spatial_size_in; + float* dout_s = dout + s * spatial_size_out; + + // process med + int left_loop = pad_left >> 2; + int right_loop = pad_right >> 2; + int med_loop = w_in >> 3; + int med_loop_remain = w_in & 7; + int left_loop_remain = pad_left & 3; + int right_loop_remain = pad_right & 3; + float* dout_med = dout_s + w * pad_top; + for (int j = 0; j < h_in; ++j) { +#ifdef __aarch64__ + for (int i = 0; i < left_loop; ++i) { + float32x4_t val = vld1q_f32(din_s + left_loop_remain + + ((left_loop - i - 1) << 2) + 1); + val = vrev64q_f32(val); + float32x2_t low = vget_low_f32(val); + float32x2_t high = vget_high_f32(val); + float32x2_t tmp = low; + low = high; + high = tmp; + float32x4_t val1 = vcombine_f32(low, high); + vst1q_f32(dout_med, val1); + dout_med += 4; + } +#else + const float* din_s_ptr = + din_s + left_loop_remain + ((left_loop - 1) << 2) + 1; + int cnt = left_loop; + if (cnt > 0) { + asm volatile( + "1: \n" + "vld1.32 {d0-d1}, [%[din_s]] \n" + "subs %[cnt], #1 \n" + "sub %[din_s], #16 \n" + "vrev64.32 q1, q0 \n" + "vswp d2, d3 \n" + "vst1.32 {d2-d3}, [%[dout_med]]!\n" + "bne 1b \n" + : + [din_s] "+r"(din_s_ptr), [dout_med] "+r"(dout_med), [cnt] "+r"(cnt) + : + : "cc", "memory", "q0", "q1"); + } +#endif // __aarch64__ + for (int i = 0; i < left_loop_remain; ++i) { + *dout_med++ = *(din_s + left_loop_remain - i); + } + for (int i = 0; i < med_loop; ++i) { + float32x4_t val = vld1q_f32(din_s); + float32x4_t val1 = vld1q_f32(din_s + 4); + vst1q_f32(dout_med, val); + vst1q_f32(dout_med + 4, val1); + din_s += 8; + dout_med += 8; + } + for (int i = 0; i < med_loop_remain; ++i) { + *dout_med++ = *din_s++; + } +#ifdef __aarch64__ + for (int i = 0; i < right_loop; ++i) { + float32x4_t val = vld1q_f32(din_s - ((i + 1) << 2) - 1); + val = 
vrev64q_f32(val); + float32x2_t low = vget_low_f32(val); + float32x2_t high = vget_high_f32(val); + float32x2_t tmp = low; + low = high; + high = tmp; + float32x4_t val1 = vcombine_f32(low, high); + vst1q_f32(dout_med, val1); + dout_med += 4; + } +#else + din_s_ptr = din_s - 5; + cnt = right_loop; + if (cnt > 0) { + asm volatile( + "1: \n" + "vld1.32 {d0-d1}, [%[din_s]] \n" + "subs %[cnt], #1 \n" + "sub %[din_s], #16 \n" + "vrev64.32 q1, q0 \n" + "vswp d2, d3 \n" + "vst1.32 {d2-d3}, [%[dout_med]]!\n" + "bne 1b \n" + : + [din_s] "+r"(din_s_ptr), [dout_med] "+r"(dout_med), [cnt] "+r"(cnt) + : + : "cc", "memory", "q0", "q1"); + } +#endif // __aarch64__ + const float* remain = din_s - (right_loop << 2) - 2; + for (int i = 0; i < right_loop_remain; ++i) { + *dout_med++ = *remain--; + } + } + + // process bottom + float* dout_bottom = dout_med; + float* dout_bottom_reflect = dout_med - (w << 1); + for (int i = 0; i < pad_bottom; ++i) { + memcpy(dout_bottom, dout_bottom_reflect, w * sizeof(float)); + dout_bottom += w; + dout_bottom_reflect -= w; + } + + // process top + float* dout_top = dout_s; + float* dout_top_reflect = dout_s + w * (pad_top << 1); + for (int i = 0; i < pad_top; ++i) { + memcpy(dout_top, dout_top_reflect, w * sizeof(float)); + dout_top += w; + dout_top_reflect -= w; + } + } +} + +// void pad2d_func(const lite::Tensor *input,lite::Tensor *output) +void pad2d_func(const lite::Tensor* input, + lite::Tensor* output, + int _mode, + std::vector<int> _pad_h, + std::vector<int> _pad_w, + float _pad_value) { + float* dout = output->mutable_data<float>(); // modified by zhiqiang + const float* din = input->data<float>(); // modified by zhiqiang + + auto output_dims = output->dims(); + // nchw + int on = output_dims[0]; + int oc = output_dims[1]; + int oh = output_dims[2]; + int ow = output_dims[3]; + ///////////////////////////// + /* _mode is the PadMode: + typedef enum{ + PAD_CONSTANT = 0, + PAD_EDGE = 1, + PAD_REFLECT = 2, + } PadMode; */ + ///////////////////////// + if (_mode == 0) { + pad_constant(din, + dout, + on, + oc, + oh, + ow, + _pad_h[0], + _pad_h[1], + _pad_w[0], + _pad_w[1], + _pad_value); + } else if (_mode == 1) { + pad_edge(din, + dout, + on, + oc, + oh, + ow, + _pad_h[0], + _pad_h[1], + _pad_w[0], + _pad_w[1], + _pad_value); + } else if (_mode == 2) { + pad_reflect(din, + dout, + on, + oc, + oh, + ow, + _pad_h[0], + _pad_h[1], + _pad_w[0], + _pad_w[1], + _pad_value); + } else { + LOG(ERROR) << "ERROR: unknown pad mode " << _mode; + } +} + +} // namespace math +} // namespace arm +} // namespace lite +} // namespace paddle diff --git a/lite/arm/math/pad2d.h b/lite/arm/math/pad2d.h new file mode 100644 index 00000000000..08c5c8c1a24 --- /dev/null +++ b/lite/arm/math/pad2d.h @@ -0,0 +1,71 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License.
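For auditing the NEON paths in pad2d.cc above, a scalar reference helps: each output pixel takes either the pad value (constant mode), the clamped border pixel (edge mode), or the mirrored interior pixel (reflect mode). The sketch below is a minimal, hypothetical pad2d_reference, not part of this patch, assuming NCHW float tensors and pads smaller than the input height/width:

#include <algorithm>

// Scalar reference for pad2d (hypothetical helper, for testing only).
// mode: 0 = constant, 1 = edge (clamp to border), 2 = reflect (mirror
// about the border without repeating it, which matches the +1/-2 index
// offsets used by the NEON reflect kernels above).
void pad2d_reference(const float* in, float* out, int n, int c,
                     int h_in, int w_in, int pad_top, int pad_left,
                     int h_out, int w_out, int mode, float value) {
  for (int s = 0; s < n * c; ++s) {
    const float* src = in + s * h_in * w_in;
    float* dst = out + s * h_out * w_out;
    for (int oh = 0; oh < h_out; ++oh) {
      for (int ow = 0; ow < w_out; ++ow) {
        int ih = oh - pad_top;
        int iw = ow - pad_left;
        if (mode == 0 && (ih < 0 || ih >= h_in || iw < 0 || iw >= w_in)) {
          dst[oh * w_out + ow] = value;  // constant pad outside the input
          continue;
        }
        if (mode == 1) {  // edge: clamp indices to the valid range
          ih = std::min(std::max(ih, 0), h_in - 1);
          iw = std::min(std::max(iw, 0), w_in - 1);
        } else if (mode == 2) {  // reflect: mirror indices off the border
          if (ih < 0) ih = -ih;
          if (iw < 0) iw = -iw;
          if (ih >= h_in) ih = 2 * h_in - 2 - ih;
          if (iw >= w_in) iw = 2 * w_in - 2 - iw;
        }
        dst[oh * w_out + ow] = src[ih * w_in + iw];
      }
    }
  }
}

Diffing pad_reflect (or pad_edge/pad_constant) against this reference over random shapes is a cheap way to validate the assembly paths, since the vectorized kernels only differ from the scalar mapping in traversal order.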
+ +#pragma once + +#include <algorithm> +#include <string> +#include <vector> +#include "lite/operators/op_params.h" +#include "lite/utils/cp_logging.h" + +namespace paddle { +namespace lite { +namespace arm { +namespace math { + +void pad_constant(const float* din, + float* dout, + int n, + int c, + int h, + int w, + const int pad_top, + const int pad_bottom, + const int pad_left, + const int pad_right, + const float pad_value); +void pad_edge(const float* din, + float* dout, + int n, + int c, + int h, + int w, + const int pad_top, + const int pad_bottom, + const int pad_left, + const int pad_right, + const float pad_value); +void pad_reflect(const float* din, + float* dout, + int n, + int c, + int h, + int w, + const int pad_top, + const int pad_bottom, + const int pad_left, + const int pad_right, + const float pad_value); +void pad2d_func(const lite::Tensor* input, + lite::Tensor* output, + int _mode, + std::vector<int> _pad_h, + std::vector<int> _pad_w, + float _pad_value); + +} // namespace math +} // namespace arm +} // namespace lite +} // namespace paddle diff --git a/lite/arm/math/pooling.cc b/lite/arm/math/pooling.cc new file mode 100644 index 00000000000..d90da0edf05 --- /dev/null +++ b/lite/arm/math/pooling.cc @@ -0,0 +1,3173 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/arm/math/pooling.h" +#include <algorithm> +#include <limits> +#include "lite/arm/math/funcs.h" + +namespace paddle { +namespace lite { +namespace arm { +namespace math { + +void pooling_basic(const float* din, + float* dout, + int num, + int chout, + int hout, + int wout, + int chin, + int hin, + int win, + const std::vector<int>& ksize, + const std::vector<int>& strides, + const std::vector<int>& paddings, + bool global_pooling, + bool exclusive, + bool adaptive, + bool ceil_mode, + bool use_quantizer, + const std::string& pooling_type) { + // no need to pad input tensor, border is zero pad inside this function + int kernel_h = ksize[0]; + int kernel_w = ksize[1]; + int stride_h = strides[0]; + int stride_w = strides[1]; + int pad_h = paddings[0]; + int pad_w = paddings[1]; + int size_channel_in = win * hin; + int size_channel_out = wout * hout; + if (global_pooling) { + if (pooling_type == "max") { // Pooling_max + for (int n = 0; n < num; ++n) { + float* dout_batch = dout + n * chout * size_channel_out; + const float* din_batch = din + n * chin * size_channel_in; +#pragma omp parallel for + for (int c = 0; c < chout; ++c) { + const float* din_ch = din_batch + c * size_channel_in; // in address + float tmp1 = din_ch[0]; + for (int i = 0; i < size_channel_in; ++i) { + float tmp2 = din_ch[i]; + tmp1 = tmp1 > tmp2 ? 
tmp1 : tmp2; + } + dout_batch[c] = tmp1; + } + } + } else if (pooling_type == "avg") { + // Pooling_average_include_padding + // Pooling_average_exclude_padding + for (int n = 0; n < num; ++n) { + float* dout_batch = dout + n * chout * size_channel_out; + const float* din_batch = din + n * chin * size_channel_in; +#pragma omp parallel for + for (int c = 0; c < chout; ++c) { + const float* din_ch = din_batch + c * size_channel_in; // in address + float sum = 0.f; + for (int i = 0; i < size_channel_in; ++i) { + sum += din_ch[i]; + } + dout_batch[c] = sum / size_channel_in; + } + } + } else { + LOG(FATAL) << "unsupported pooling type: " << pooling_type; + } + } else { + if (pooling_type == "max") { + // Pooling_max + for (int n = 0; n < num; ++n) { + float* dout_ch = dout + n * chout * size_channel_out; + const float* din_batch = din + n * chin * size_channel_in; +#pragma omp parallel for + for (int c = 0; c < chout; c++) { + float* dout_row = dout_ch + c * size_channel_out; + const float* din_ch = din_batch + c * size_channel_in; + for (int i = 0; i < hout; i++) { + for (int j = 0; j < wout; j++) { + int hstart = i * stride_h - pad_h; + int wstart = j * stride_w - pad_w; + int hend = std::min(hstart + kernel_h, hin + pad_h); + int wend = std::min(wstart + kernel_w, win + pad_w); + hstart = std::max(hstart, 0); + wstart = std::max(wstart, 0); + hend = std::min(hend, hin); + wend = std::min(wend, win); + int pool_size = (hend - hstart) * (wend - wstart); + if (pool_size == 0) continue; + float tmp1 = din_ch[hstart * win + wstart]; + for (int h = hstart; h < hend; ++h) { + for (int w = wstart; w < wend; ++w) { + float tmp2 = din_ch[h * win + w]; + tmp1 = tmp1 > tmp2 ? tmp1 : tmp2; + } + } + dout_row[j] = tmp1; + } + dout_row += wout; + } + } + } + } else if (pooling_type == "avg") { + if (exclusive) { + // Pooling_average_exclude_padding + for (int n = 0; n < num; ++n) { + float* dout_ch = dout + n * chout * size_channel_out; + const float* din_batch = din + n * chin * size_channel_in; +#pragma omp parallel for + for (int c = 0; c < chout; c++) { + float* dout_row = dout_ch + c * size_channel_out; + const float* din_ch = din_batch + c * size_channel_in; + for (int i = 0; i < hout; i++) { + for (int j = 0; j < wout; j++) { + int hstart = i * stride_h - pad_h; + int wstart = j * stride_w - pad_w; + int hend = std::min(hstart + kernel_h, hin + pad_h); + int wend = std::min(wstart + kernel_w, win + pad_w); + hstart = std::max(hstart, 0); + wstart = std::max(wstart, 0); + hend = std::min(hend, hin); + wend = std::min(wend, win); + int pool_size = (hend - hstart) * (wend - wstart); + if (pool_size == 0) continue; + float sum = 0.f; + for (int h = hstart; h < hend; ++h) { + for (int w = wstart; w < wend; ++w) { + sum += din_ch[h * win + w]; + } + } + dout_row[j] = sum / pool_size; + } + dout_row += wout; + } + } + } + } else { // Pooling_average_include_padding + for (int n = 0; n < num; ++n) { + float* dout_ch = dout + n * chout * size_channel_out; + const float* din_batch = din + n * chin * size_channel_in; +#pragma omp parallel for + for (int c = 0; c < chout; c++) { + float* dout_row = dout_ch + c * size_channel_out; + const float* din_ch = din_batch + c * size_channel_in; + for (int i = 0; i < hout; i++) { + for (int j = 0; j < wout; j++) { + int hstart = i * stride_h - pad_h; + int wstart = j * stride_w - pad_w; + int hend = std::min(hstart + kernel_h, hin + pad_h); + int wend = std::min(wstart + kernel_w, win + pad_w); + hstart = std::max(hstart, 0); + wstart = std::max(wstart, 0); + hend = 
std::min(hend, hin); + wend = std::min(wend, win); + int pool_size = (hend - hstart) * (wend - wstart); + if (pool_size == 0) continue; + float sum = 0.f; + for (int h = hstart; h < hend; ++h) { + for (int w = wstart; w < wend; ++w) { + sum += din_ch[h * win + w]; + } + } + dout_row[j] = sum / (kernel_w * kernel_h); + } + dout_row += wout; + } + } + } + } + } else { + LOG(FATAL) << "unsupported pooling type: " << pooling_type; + } + } +} + +void pooling_global_max(const float* din, + float* dout, + int num, + int chout, + int hout, + int wout, + int chin, + int hin, + int win) { + int size_channel_in = win * hin; + int cnt = size_channel_in / 8; + for (int n = 0; n < num; ++n) { + float* dout_batch = dout + n * chout; + const float* din_batch = din + n * chin * size_channel_in; +#pragma omp parallel for + for (int c = 0; c < chout; ++c) { + const float* din_ch = din_batch + c * size_channel_in; + int i = 0; + float minval = std::numeric_limits<float>::lowest(); + float32x4_t vmax = vdupq_n_f32(minval); +#ifdef __aarch64__ + for (; i < cnt; i++) { + float32x4_t vdin1 = vld1q_f32(din_ch); + vmax = vmaxq_f32(vdin1, vmax); + float32x4_t vdin2 = vld1q_f32(din_ch + 4); + vmax = vmaxq_f32(vmax, vdin2); + din_ch += 8; + } +#else + int cnt_num = cnt; + if (cnt_num > 0) { + asm volatile( + "max_loop: @main loop\n" + "vld1.f32 {d0-d1}, [%[din_ch]]! @load q1,din_ch\n" + "vmax.f32 %q[vmax], %q[vmax], q0 @max vmax,vmax,din_ch\n" + "vld1.f32 {d2-d3}, [%[din_ch]]! @load 2nd 4 data\n" + "vmax.f32 %q[vmax], %q[vmax], q1 @compare 2nd 4 datas\n" + "subs %[cnt_num], #1 @cnt_num--\n" + "bne max_loop @bne cnt_num\n" + : [din_ch] "+r"(din_ch), [cnt_num] "+r"(cnt_num), [vmax] "+w"(vmax) + : + : "cc", "memory", "q0", "q1"); + } +#endif // __aarch64__ + float32x2_t vmax_tmp = vmax_f32(vget_low_f32(vmax), vget_high_f32(vmax)); + float tmp1 = vget_lane_f32(vmax_tmp, 0); + float tmp2 = vget_lane_f32(vmax_tmp, 1); + float max_tmp = tmp1 > tmp2 ? tmp1 : tmp2; + for (i = cnt * 8; i < size_channel_in; ++i) { + /* code */ + max_tmp = max_tmp > din_ch[0] ? max_tmp : din_ch[0]; + din_ch++; + } + dout_batch[c] = max_tmp; + } + } +} + +void pooling_global_avg(const float* din, + float* dout, + int num, + int chout, + int hout, + int wout, + int chin, + int hin, + int win) { + int size_channel_in = win * hin; + int cnt = size_channel_in / 4; + for (int n = 0; n < num; ++n) { + float* dout_batch = dout + n * chout; + const float* din_batch = din + n * chin * size_channel_in; +#pragma omp parallel for + for (int c = 0; c < chout; c++) { + const float* din_ch = din_batch + c * size_channel_in; // in address + int i = 0; + float32x4_t vsum = vdupq_n_f32(0.0f); +#ifdef __aarch64__ + for (; i < cnt; i++) { + vsum = vaddq_f32(vld1q_f32(din_ch), vsum); + din_ch += 4; + } +#else + int cnt_num = cnt; + if (cnt_num > 0) { + asm volatile( + "add_loop: @main loop\n" + "vld1.f32 {d0-d1}, [%[din_ch]]! 
@load q1,din_ch\n" + "vadd.f32 %q[vsum], %q[vsum], q0 @add vmax,vmax, din_ch\n" + "subs %[cnt_num], #1 @cnt_num--\n" + "bne add_loop @bne num\n" + : [din_ch] "+r"(din_ch), [cnt_num] "+r"(cnt_num), [vsum] "+w"(vsum) + : + : "cc", "memory", "q0"); + } +#endif // __aarch64__ + float32x2_t vsum_tmp = vadd_f32(vget_low_f32(vsum), vget_high_f32(vsum)); + float sum = vget_lane_f32(vsum_tmp, 0) + vget_lane_f32(vsum_tmp, 1); + for (i = cnt * 4; i < size_channel_in; i++) { + sum += din_ch[0]; + din_ch++; + } + dout_batch[c] = sum / size_channel_in; + } + } +} + +void pooling2x2s2_max(const float* din, + float* dout, + int num, + int chout, + int hout, + int wout, + int chin, + int hin, + int win) { + int kernel = 2; + int stride = 2; + int padding = 0; + int size_channel_out = wout * hout; + int size_channel_in = win * hin; + + int w_needed = (wout << 1); + int h_needed = (hout << 1); + int w_limit = w_needed > win ? win : w_needed; + int h_limit = h_needed > hin ? hin : h_needed; + int w_even = (w_limit >> 1) << 1; + int h_even = (h_limit >> 1) << 1; + int w_unroll_size = (w_even >> 3) << 3; + // int w_unroll_remain = w_even - w_unroll_size; + int w_in_2 = win << 1; + for (int n = 0; n < num; ++n) { + float* dout_batch = dout + n * chout * size_channel_out; + const float* din_batch = din + n * chin * size_channel_in; +#pragma omp parallel for + for (int c = 0; c < chout; c++) { + float* dout_ch = dout_batch + c * size_channel_out; + const float* din_ch = din_batch + c * size_channel_in; + const float* r0 = din_ch; + const float* r1 = r0 + win; + int h = 0; + for (; h < h_even; h += 2) { + int w = 0; +#ifdef __aarch64__ + for (; w < w_unroll_size; w += 8) { + float32x4_t dr00 = vld1q_f32(&r0[w]); + float32x4_t dr01 = vld1q_f32(&r0[w + 4]); + float32x4_t dr10 = vld1q_f32(&r1[w]); + float32x4_t dr11 = vld1q_f32(&r1[w + 4]); + float32x4_t dmax1 = vmaxq_f32(dr00, dr10); + float32x4_t dmax2 = vmaxq_f32(dr01, dr11); +#ifdef __aarch64__ + float32x4_t dmax = vpmaxq_f32(dmax1, dmax2); +#else + float32x2_t dmaxl = + vpmax_f32(vget_low_f32(dmax1), vget_high_f32(dmax1)); + float32x2_t dmaxh = + vpmax_f32(vget_low_f32(dmax2), vget_high_f32(dmax2)); + float32x4_t dmax = vcombine_f32(dmaxl, dmaxh); +#endif + vst1q_f32(&dout_ch[w >> 1], dmax); + } +#else + float* dr_out = dout_ch; + const float* dr0 = r0; + const float* dr1 = r1; + int cnt_num = w_unroll_size >> 3; + if (cnt_num > 0) { + asm volatile( + "s2_max_loop: @main loop\n" + "vld1.f32 {d0-d3}, [%[dr0]]! @load q0,dr0\n" + "vld1.f32 {d4-d7}, [%[dr1]]! @load q1,dr1\n" + "vmax.f32 q0, q0, q2 @max q0,q0,q2\n" + "vmax.f32 q1, q1, q3 @max q1,q1,q2\n" + "vpmax.f32 d4, d0, d1 @max d4,d0,d1\n" + "vpmax.f32 d5, d2, d3 @max d5,d2,d3\n" + "vst1.f32 {d4-d5}, [%[dr_out]]! 
@vst1 q2,dr_out\n" + "subs %[cnt_num], #1 @cnt_num--\n" + "bne s2_max_loop @bne cnt_num\n" + : [dr0] "+r"(dr0), + [dr1] "+r"(dr1), + [dr_out] "+r"(dr_out), + [cnt_num] "+r"(cnt_num) + : + : "cc", "memory", "q0", "q1", "q2", "q3"); + } + w = w_unroll_size; +#endif // __aarch64__ + for (; w < w_even; w += 2) { + dout_ch[w >> 1] = + std::max(std::max(r0[w], r0[w + 1]), std::max(r1[w], r1[w + 1])); + } + for (; w < w_limit; ++w) { // run 0 or 1 time + dout_ch[w >> 1] = std::max(r0[w], r1[w]); + } + r0 += w_in_2; // << 1; + r1 += w_in_2; // << 1; + dout_ch += wout; + } + // process remain row (odd, last row) + for (; h < h_limit; h++) { // run 0 or 1 time + int w = 0; +#ifdef __aarch64__ + for (; w < w_unroll_size; w += 8) { + float32x4_t dr00 = vld1q_f32(&r0[w]); + float32x4_t dr01 = vld1q_f32(&r0[w + 4]); +#ifdef __aarch64__ + float32x4_t dmax = vpmaxq_f32(dr00, dr01); +#else + float32x2_t dmaxl = + vpmax_f32(vget_low_f32(dr00), vget_high_f32(dr00)); + float32x2_t dmaxh = + vpmax_f32(vget_low_f32(dr01), vget_high_f32(dr01)); + float32x4_t dmax = vcombine_f32(dmaxl, dmaxh); +#endif + vst1q_f32(&dout_ch[w >> 1], dmax); + } +#else + float* dr_out = dout_ch; + const float* dr0 = r0; + int cnt_num = w_unroll_size >> 3; + if (cnt_num > 0) { + asm volatile( + "s2_max_loop1: @main loop\n" + "vld1.f32 {d0-d3}, [%[dr0]]! @load q0,dr0\n" + "vpmax.f32 d4, d0, d1 @max d4,d0,d1\n" + "vpmax.f32 d5, d2, d3 @max d5,d2,d3\n" + "vst1.f32 {d4-d5}, [%[dr_out]]! @vst1 q2,dr_out\n" + "subs %[cnt_num], #1 @cnt_num--\n" + "bne s2_max_loop1 @bne cnt_num\n" + : [dr0] "+r"(dr0), [dr_out] "+r"(dr_out), [cnt_num] "+r"(cnt_num) + : + : "cc", "memory", "q0", "q1", "q2"); + } + w = w_unroll_size; +#endif // __aarch64__ + for (; w < w_even; w += 2) { + dout_ch[w >> 1] = std::max(r0[w], r0[w + 1]); + } + for (; w < w_limit; ++w) { // run 0 or 1 time + dout_ch[w >> 1] = r0[w]; + } + } + } + } +} + +void pooling2x2s2_avg(const float* din, + float* dout, + int num, + int chout, + int hout, + int wout, + int chin, + int hin, + int win, + bool exclusive) { + int kernel = 2; + int stride = 2; + int padding = 0; + int size_channel_out = wout * hout; + int size_channel_in = win * hin; + + int w_needed = (wout << 1); + int h_needed = (hout << 1); + int w_limit = w_needed > win ? win : w_needed; + int h_limit = h_needed > hin ? hin : h_needed; + int w_even = (w_limit >> 1) << 1; + int h_even = (h_limit >> 1) << 1; + int w_unroll_size = (w_even >> 3) << 3; + // int w_unroll_remain = w_even - w_unroll_size; + int w_in_2 = win << 1; + const float coef = 1.f / 4.f; + const float coef_1 = exclusive ? 1.f : coef; + const float coef_2 = exclusive ? 
1.f / 2.f : coef; + float32x4_t vcoef = vdupq_n_f32(coef); + float32x4_t vcoef_1 = vdupq_n_f32(coef_1); + float32x4_t vcoef_2 = vdupq_n_f32(coef_2); + for (int n = 0; n < num; ++n) { + float* dout_batch = dout + n * chout * size_channel_out; + const float* din_batch = din + n * chin * size_channel_in; +#pragma omp parallel for + for (int c = 0; c < chout; c++) { + float* dout_ch = dout_batch + c * size_channel_out; + const float* din_ch = din_batch + c * size_channel_in; + const float* r0 = din_ch; + const float* r1 = r0 + win; + int h = 0; + for (; h < h_even; h += 2) { + int w = 0; +#ifdef __aarch64__ + for (; w < w_unroll_size; w += 8) { + float32x4_t dr00 = vld1q_f32(&r0[w]); + float32x4_t dr01 = vld1q_f32(&r0[w + 4]); + float32x4_t dr10 = vld1q_f32(&r1[w]); + float32x4_t dr11 = vld1q_f32(&r1[w + 4]); + float32x4_t dsum1 = vaddq_f32(dr00, dr10); + float32x4_t dsum2 = vaddq_f32(dr01, dr11); +#ifdef __aarch64__ + float32x4_t dsum = vpaddq_f32(dsum1, dsum2); +#else + float32x2_t dsuml = + vpadd_f32(vget_low_f32(dsum1), vget_high_f32(dsum1)); + float32x2_t dsumh = + vpadd_f32(vget_low_f32(dsum2), vget_high_f32(dsum2)); + float32x4_t dsum = vcombine_f32(dsuml, dsumh); +#endif + float32x4_t res = vmulq_f32(dsum, vcoef); + vst1q_f32(&dout_ch[w >> 1], res); + } +#else + float* dr_out = dout_ch; + const float* dr0 = r0; + const float* dr1 = r1; + int cnt_num = w_unroll_size >> 3; + if (cnt_num > 0) { + asm volatile( + "1: @main loop\n" + "vld1.f32 {d0-d3}, [%[dr0]]! @load q0,dr0\n" + "vld1.f32 {d4-d7}, [%[dr1]]! @load q1,dr1\n" + "vadd.f32 q0, q0, q2 @add q0,q0,q2\n" + "vadd.f32 q1, q1, q3 @add q1,q1,q2\n" + "vpadd.f32 d4, d0, d1 @add d4,d0,d1\n" + "vpadd.f32 d5, d2, d3 @add d5,d2,d3\n" + "vmul.f32 q2, q2, %q[vcoef] @mul q2,q2,vcoef\n" + "vst1.f32 {d4-d5}, [%[dr_out]]! @vst1 q2,dr_out\n" + "subs %[cnt_num], #1 @cnt_num--\n" + "bne 1b @bne cnt_num\n" + : [dr0] "+r"(dr0), + [dr1] "+r"(dr1), + [dr_out] "+r"(dr_out), + [vcoef] "+w"(vcoef), + [cnt_num] "+r"(cnt_num) + : "r"(dr0), "r"(dr1), "r"(dr_out), "r"(cnt_num), "w"(vcoef) + : "cc", "memory", "q0", "q1", "q2", "q3"); + } + w = w_unroll_size; +#endif // __aarch64__ + for (; w < w_even; w += 2) { + dout_ch[w >> 1] = (r0[w] + r0[w + 1] + r1[w] + r1[w + 1]) * coef; + } + for (; w < w_limit; ++w) { // run 0 or 1 time + dout_ch[w >> 1] = (r0[w] + r1[w]) * coef_2; + } + r0 += w_in_2; // << 1; + r1 += w_in_2; // << 1; + dout_ch += wout; + } + // process remain row (odd, last row) + for (; h < h_limit; h++) { // run 0 or 1 time + int w = 0; +#ifdef __aarch64__ + for (; w < w_unroll_size; w += 8) { + float32x4_t dr00 = vld1q_f32(&r0[w]); + float32x4_t dr01 = vld1q_f32(&r0[w + 4]); +#ifdef __aarch64__ + float32x4_t dsum = vpaddq_f32(dr00, dr01); +#else + float32x2_t dsuml = + vpadd_f32(vget_low_f32(dr00), vget_high_f32(dr00)); + float32x2_t dsumh = + vpadd_f32(vget_low_f32(dr01), vget_high_f32(dr01)); + float32x4_t dsum = vcombine_f32(dsuml, dsumh); +#endif + float32x4_t res = vmulq_f32(dsum, vcoef_2); + vst1q_f32(&dout_ch[w >> 1], res); + } +#else + float* dr_out = dout_ch; + const float* dr0 = r0; + int cnt_num = w_unroll_size >> 3; + if (cnt_num > 0) { + asm volatile( + "1: @main loop\n" + "vld1.f32 {d0-d3}, [%[dr0]]! @load q0,dr0\n" + "vpadd.f32 d4, d0, d1 @add d4,d0,d1\n" + "vpadd.f32 d5, d2, d3 @add d5,d2,d3\n" + "vmul.f32 q2, q2, %q[vcoef_2] @mul q2,q2,vcoef_2\n" + "vst1.f32 {d4-d5}, [%[dr_out]]! 
@vst1 q2,dr_out\n" + "subs %[cnt_num], #1 @cnt_num--\n" + "bne 1b @bne cnt_num\n" + : [dr0] "+r"(dr0), + [dr_out] "+r"(dr_out), + [vcoef_2] "+w"(vcoef_2), + [cnt_num] "+r"(cnt_num) + : "r"(dr0), "r"(dr_out), "r"(cnt_num), "w"(vcoef_2) + : "cc", "memory", "q0", "q1", "q2"); + } + w = w_unroll_size; +#endif // __aarch64__ + for (; w < w_even; w += 2) { + dout_ch[w >> 1] = (r0[w] + r0[w + 1]) * coef_2; + } + for (; w < w_limit; ++w) { // run 0 or 1 time + dout_ch[w >> 1] = r0[w] * coef_1; + } + } + } + } +} + +void pooling3x3s1p1_max(const float* din, + float* dout, + int num, + int chout, + int hout, + int wout, + int chin, + int hin, + int win) { + int kernel = 3; + int stride = 1; + int padding = 1; + int size_channel_out = wout * hout; + int size_channel_in = win * hin; + + int w_unroll_size = ((win - 2) >> 2) << 2; + int w_unroll_remain = win - 2 - w_unroll_size; + const float minval = std::numeric_limits<float>::lowest(); + for (int n = 0; n < num; ++n) { + float* dout_batch = dout + n * chout * size_channel_out; + const float* din_batch = din + n * chin * size_channel_in; +#pragma omp parallel for + for (int c = 0; c < chout; c++) { + float* dout_ch = dout_batch + c * size_channel_out; + const float* din_ch = din_batch + c * size_channel_in; + const float* r0 = din_ch; + const float* r1 = r0 + win; + const float* r2 = r1 + win; + int cnt_num = w_unroll_size >> 2; // w_unroll_size / 4 + float* dr_out = dout_ch; + const float* dr0 = r0; + const float* dr1 = r1; + const float* dr2 = r2; + int w = 0; + int cnt = 1; + // left + dout_ch[0] = std::max(std::max(r0[0], r0[1]), std::max(r1[0], r1[1])); +// first row with zero pad +#ifdef __aarch64__ + for (; w < w_unroll_size; w += 4) { + float32x4_t vr0_1234 = vld1q_f32(&r0[w]); + float32x4_t vr1_1234 = vld1q_f32(&r1[w]); + float32x4_t vr0_5678 = vld1q_f32(&r0[w + 4]); + float32x4_t vr1_5678 = vld1q_f32(&r1[w + 4]); + float32x4_t vmax_1234 = vmaxq_f32(vr0_1234, vr1_1234); + float32x4_t vmax_5678 = vmaxq_f32(vr0_5678, vr1_5678); + + float32x4_t vmax_2345 = vextq_f32(vmax_1234, vmax_5678, 1); + float32x4_t vmax_3456 = vextq_f32(vmax_1234, vmax_5678, 2); + float32x2_t vmax_12_34 = + vpmax_f32(vget_low_f32(vmax_1234), vget_high_f32(vmax_1234)); + float32x2_t vmax_23_45 = + vpmax_f32(vget_low_f32(vmax_2345), vget_high_f32(vmax_2345)); + float32x2_t vmax_34_56 = + vpmax_f32(vget_low_f32(vmax_3456), vget_high_f32(vmax_3456)); + float32x2_t vmax_123_345 = vmax_f32(vmax_12_34, vmax_23_45); + float32x2_t vmax_234_456 = vmax_f32(vmax_23_45, vmax_34_56); + float32x4_t vmax = vdupq_n_f32(vget_lane_f32(vmax_123_345, 0)); + vmax = vsetq_lane_f32(vget_lane_f32(vmax_234_456, 0), vmax, 1); + vmax = vsetq_lane_f32(vget_lane_f32(vmax_123_345, 1), vmax, 2); + vmax = vsetq_lane_f32(vget_lane_f32(vmax_234_456, 1), vmax, 3); + vst1q_f32(&dout_ch[cnt], vmax); + cnt += 4; + } + +#else + dr_out = dr_out + 1; + if (cnt_num > 0) { + asm volatile( + "1: @main loop\n" + "vld1.f32 {d0-d1}, [%[dr0]]! @load d0-d5,dr0\n" + "vld1.f32 {d4-d5}, [%[dr1]]! @load d4-d7,dr1\n" + "vld1.f32 {d2}, [%[dr0]]! @load d0-d5,dr0\n" + "vld1.f32 {d6}, [%[dr1]]! 
@load d4-d7,dr1\n" + "vmax.f32 q5, q0, q2 @max r0_1234,r1_1234\n" + "vmax.f32 d12, d2, d6 @max r0_5678,r1_5678\n" + //"vmov.f32 s7,s6 @mov s7,s6\n" + "vext.f32 q0, q5, q6, #1 @vext max_2345\n" + "vext.f32 q2, q5, q6, #2 @vext max_3456\n" + "vpmax.f32 d2, d10, d11 @pmax d4,max_1234,max_1234\n" + "vpmax.f32 d3, d0, d1 @pmax d4,max_2345,max_2345\n" + "vpmax.f32 d6, d4, d5 @pmax d6,max_3456,max_3456\n" + "vmax.f32 d8, d2, d3 @max d2,vmax_12_34,vmax_23_45\n" + "vmax.f32 d9, d3, d6 @max d2,vmax_23_45,vmax_34_56\n" + "sub %[dr0], #8 @sub w,8\n" + "sub %[dr1], #8 @sub w,8\n" + // swap + "vmov.f32 s0, s17 @mov\n" + "vmov.f32 s17, s18 @mov\n" + "vmov.f32 s18, s0 @mov\n" + "subs %[cnt_num], #1 @subs cnt_num,#1\n" + "vst1.f32 d8, [%[dr_out]]! @vst1 d0,dr_out\n" + "vst1.f32 d9, [%[dr_out]]! @vst1 d0,dr_out\n" + "bne 1b @bne s1_max_loop\n" + : [dr0] "+r"(dr0), + [dr1] "+r"(dr1), + [dr_out] "+r"(dr_out), + [cnt_num] "+r"(cnt_num) + : "r"(dr0), "r"(dr1), "r"(dr_out), "r"(cnt_num) + : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6"); + } + +#endif + // remain + w = w_unroll_size; + for (int j = 0; j < w_unroll_remain; j++) { + float tmp_max = std::max(r0[j + w], r1[j + w]); + tmp_max = std::max(tmp_max, std::max(r0[j + w + 1], r1[j + w + 1])); + tmp_max = std::max(tmp_max, std::max(r0[j + w + 2], r1[j + w + 2])); + dout_ch[j + w + 1] = tmp_max; + } + // right + float tmp = std::max(r0[win - 2], r1[win - 2]); + tmp = std::max(tmp, std::max(r0[win - 1], r1[win - 1])); + dout_ch[wout - 1] = tmp; + + // r0 = r1; + // r1 = r0 + w_in; + // r2 = r1 + w_in; + dout_ch += wout; + int h = 0; + for (; h < hin - 2; h += 1) { + // deal with left pad + float maxr0 = std::max(r0[0], r0[1]); + float maxr1 = std::max(r1[0], r1[1]); + float maxr2 = std::max(r2[0], r2[1]); + dout_ch[0] = std::max(std::max(maxr0, maxr1), maxr2); +#ifdef __aarch64__ + w = 0; + cnt = 1; + for (; w < w_unroll_size; w += 4) { + float32x4_t vr0_1234 = vld1q_f32(&r0[w]); + float32x4_t vr1_1234 = vld1q_f32(&r1[w]); + float32x4_t vr2_1234 = vld1q_f32(&r2[w]); + float32x4_t vr0_5678 = vld1q_f32(&r0[w + 4]); + float32x4_t vr1_5678 = vld1q_f32(&r1[w + 4]); + float32x4_t vr2_5678 = vld1q_f32(&r2[w + 4]); + float32x4_t vmax_1234 = vmaxq_f32(vr0_1234, vr1_1234); + vmax_1234 = vmaxq_f32(vmax_1234, vr2_1234); + float32x4_t vmax_5678 = vmaxq_f32(vr0_5678, vr1_5678); + vmax_5678 = vmaxq_f32(vmax_5678, vr2_5678); + + float32x4_t vmax_2345 = vextq_f32(vmax_1234, vmax_5678, 1); + float32x4_t vmax_3456 = vextq_f32(vmax_1234, vmax_5678, 2); + float32x2_t vmax_12_34 = + vpmax_f32(vget_low_f32(vmax_1234), vget_high_f32(vmax_1234)); + float32x2_t vmax_23_45 = + vpmax_f32(vget_low_f32(vmax_2345), vget_high_f32(vmax_2345)); + float32x2_t vmax_34_56 = + vpmax_f32(vget_low_f32(vmax_3456), vget_high_f32(vmax_3456)); + float32x2_t vmax_123_345 = vmax_f32(vmax_12_34, vmax_23_45); + float32x2_t vmax_234_456 = vmax_f32(vmax_23_45, vmax_34_56); + float32x4_t vmax = vdupq_n_f32(vget_lane_f32(vmax_123_345, 0)); + vmax = vsetq_lane_f32(vget_lane_f32(vmax_234_456, 0), vmax, 1); + vmax = vsetq_lane_f32(vget_lane_f32(vmax_123_345, 1), vmax, 2); + vmax = vsetq_lane_f32(vget_lane_f32(vmax_234_456, 1), vmax, 3); + vst1q_f32(&dout_ch[cnt], vmax); + cnt += 4; + } +#else + dr_out = dout_ch + 1; + dr0 = r0; + dr1 = r1; + dr2 = r2; + cnt_num = w_unroll_size >> 2; + if (cnt_num > 0) { + asm volatile( + "1: @main loop\n" + "vld1.f32 {d0-d1}, [%[dr0]]! @load d0-d5,dr0\n" + "vld1.f32 {d4-d5}, [%[dr1]]! @load d4-d7,dr1\n" + "vld1.f32 {d8-d9}, [%[dr2]]! 
@load d4-d7,dr1\n" + "vld1.f32 {d2}, [%[dr0]]! @load d0-d5,dr0\n" + "vld1.f32 {d6}, [%[dr1]]! @load d4-d7,dr1\n" + "vld1.f32 {d10}, [%[dr2]]! @load d4-d7, dr1\n" + "vmax.f32 q7, q0, q2 @max r0_1234,r1_1234\n" + "vmax.f32 d16, d2, d6 @max r0_5678,r1_5678\n" + "vmax.f32 q3, q7, q4 @max r0_1234,r1_1234\n" + "vmax.f32 d12, d16, d10 @max r0_5678,r1_5678\n" + //"vmov.f32 s7,s6 @mov s7,s6\n" + "vext.f32 q0, q3, q6, #1 @vext max_2345\n" + "vext.f32 q2, q3, q6, #2 @vext max_3456\n" + "vpmax.f32 d2, d6, d7 @pmax d4,max_1234,max_1234\n" + "vpmax.f32 d3, d0, d1 @pmax d4,max_2345,max_2345\n" + "vpmax.f32 d6, d4, d5 @pmax d6,max_3456,max_3456\n" + "vmax.f32 d8, d2, d3 @max d2,vmax_12_34,vmax_23_45\n" + "vmax.f32 d9, d3, d6 @max d2,vmax_23_45,vmax_34_56\n" + "sub %[dr0], #8 @sub w,8\n" + "sub %[dr1], #8 @sub w,8\n" + "sub %[dr2], #8 @sub w,8\n" + // swap + "vmov.f32 s0, s17 @mov\n" + "vmov.f32 s17, s18 @mov\n" + "vmov.f32 s18, s0 @mov\n" + "subs %[cnt_num], #1 @subs cnt_num,#1\n" + "vst1.f32 d8, [%[dr_out]]! @vst1 d0,dr_out\n" + "vst1.f32 d9, [%[dr_out]]! @vst1 d0,dr_out\n" + "bne 1b @bne s1_max_loop\n" + : [dr0] "+r"(dr0), + [dr1] "+r"(dr1), + [dr2] "+r"(dr2), + [dr_out] "+r"(dr_out), + [cnt_num] "+r"(cnt_num) + : "r"(dr0), "r"(dr1), "r"(dr_out), "r"(cnt_num) + : "cc", + "memory", + "q0", + "q1", + "q2", + "q3", + "q4", + "q5", + "q6", + "q7", + "q8"); + } +#endif + // remain + w = w_unroll_size; + for (int j = 0; j < w_unroll_remain; j++) { + float tmp_max = std::max(r0[j + w], r1[j + w]); + tmp_max = std::max(tmp_max, std::max(r0[j + w + 1], r1[j + w + 1])); + tmp_max = std::max(tmp_max, std::max(r0[j + w + 2], r1[j + w + 2])); + tmp_max = std::max(tmp_max, std::max(r2[j + w], r2[j + w + 1])); + tmp_max = std::max(tmp_max, r2[j + w + 2]); + dout_ch[j + w + 1] = tmp_max; + } + // right + tmp = std::max(r0[win - 2], r1[win - 2]); + tmp = std::max(tmp, std::max(r0[win - 1], r1[win - 1])); + tmp = std::max(tmp, std::max(r2[win - 2], r2[win - 1])); + dout_ch[wout - 1] = tmp; + + r0 = r1; + r1 = r2; + r2 = r1 + win; + dout_ch += wout; + } + + // the last two line + float maxr0 = std::max(r0[0], r0[1]); + float maxr1 = std::max(r1[0], r1[1]); + dout_ch[0] = std::max(maxr0, maxr1); +#ifdef __aarch64__ + w = 0; + cnt = 1; + for (; w < w_unroll_size; w += 4) { + float32x4_t vr0_1234 = vld1q_f32(&r0[w]); + float32x4_t vr1_1234 = vld1q_f32(&r1[w]); + float32x4_t vr0_5678 = vld1q_f32(&r0[w + 4]); + float32x4_t vr1_5678 = vld1q_f32(&r1[w + 4]); + float32x4_t vmax_1234 = vmaxq_f32(vr0_1234, vr1_1234); + float32x4_t vmax_5678 = vmaxq_f32(vr0_5678, vr1_5678); + + float32x4_t vmax_2345 = vextq_f32(vmax_1234, vmax_5678, 1); + float32x4_t vmax_3456 = vextq_f32(vmax_1234, vmax_5678, 2); + float32x2_t vmax_12_34 = + vpmax_f32(vget_low_f32(vmax_1234), vget_high_f32(vmax_1234)); + float32x2_t vmax_23_45 = + vpmax_f32(vget_low_f32(vmax_2345), vget_high_f32(vmax_2345)); + float32x2_t vmax_34_56 = + vpmax_f32(vget_low_f32(vmax_3456), vget_high_f32(vmax_3456)); + float32x2_t vmax_123_345 = vmax_f32(vmax_12_34, vmax_23_45); + float32x2_t vmax_234_456 = vmax_f32(vmax_23_45, vmax_34_56); + float32x4_t vmax = vdupq_n_f32(vget_lane_f32(vmax_123_345, 0)); + vmax = vsetq_lane_f32(vget_lane_f32(vmax_234_456, 0), vmax, 1); + vmax = vsetq_lane_f32(vget_lane_f32(vmax_123_345, 1), vmax, 2); + vmax = vsetq_lane_f32(vget_lane_f32(vmax_234_456, 1), vmax, 3); + vst1q_f32(&dout_ch[cnt], vmax); + cnt += 4; + } +#else + dr_out = dout_ch + 1; + dr0 = r0; + dr1 = r1; + cnt_num = w_unroll_size >> 2; + if (cnt_num > 0) { + asm volatile( + "1: @main 
loop\n" + "vld1.f32 {d0-d1}, [%[dr0]]! @load d0-d5,dr0\n" + "vld1.f32 {d4-d5}, [%[dr1]]! @load d4-d7,dr1\n" + "vld1.f32 {d2}, [%[dr0]]! @load d0-d5,dr0\n" + "vld1.f32 {d6}, [%[dr1]]! @load d4-d7,dr1\n" + "vmax.f32 q5, q0, q2 @max r0_1234,r1_1234\n" + "vmax.f32 d12, d2, d6 @max r0_5678,r1_5678\n" + //"vmov.f32 s7,s6 @mov s7,s6\n" + "vext.f32 q0, q5, q6, #1 @vext max_2345\n" + "vext.f32 q2, q5, q6, #2 @vext max_3456\n" + "vpmax.f32 d2, d10, d11 @pmax d4,max_1234,max_1234\n" + "vpmax.f32 d3, d0, d1 @pmax d4,max_2345,max_2345\n" + "vpmax.f32 d6, d4, d5 @pmax d6,max_3456,max_3456\n" + "vmax.f32 d8, d2, d3 @max d2,vmax_12_34,vmax_23_45\n" + "vmax.f32 d9, d3, d6 @max d2,vmax_23_45,vmax_34_56\n" + "sub %[dr0], #8 @sub w,8\n" + "sub %[dr1], #8 @sub w,8\n" + // swap + "vmov.f32 s0, s17 @mov\n" + "vmov.f32 s17, s18 @mov\n" + "vmov.f32 s18, s0 @mov\n" + "subs %[cnt_num], #1 @subs cnt_num,#1\n" + "vst1.f32 d8, [%[dr_out]]! @vst1 d0,dr_out\n" + "vst1.f32 d9, [%[dr_out]]! @vst1 d0,dr_out\n" + "bne 1b @bne s1_max_loop\n" + : [dr0] "+r"(dr0), + [dr1] "+r"(dr1), + [dr_out] "+r"(dr_out), + [cnt_num] "+r"(cnt_num) + : "r"(dr0), "r"(dr1), "r"(dr_out), "r"(cnt_num) + : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6"); + } +#endif + // remain + w = w_unroll_size; + for (int j = 0; j < w_unroll_remain; j++) { + float tmp_max = std::max(r0[j + w], r1[j + w]); + tmp_max = std::max(tmp_max, std::max(r0[j + w + 1], r1[j + w + 1])); + tmp_max = std::max(tmp_max, std::max(r0[j + w + 2], r1[j + w + 2])); + dout_ch[j + w + 1] = tmp_max; + } + tmp = std::max(r0[win - 2], r1[win - 2]); + tmp = std::max(tmp, std::max(r0[win - 1], r1[win - 1])); + dout_ch[wout - 1] = tmp; + } + } +} + +void pooling3x3s1p1_avg(const float* din, + float* dout, + int num, + int chout, + int hout, + int wout, + int chin, + int hin, + int win, + bool exclusive) { + int kernel = 3; + int stride = 1; + int padding = 1; + int size_channel_out = wout * hout; + int size_channel_in = win * hin; + + int w_unroll_size = ((win - 2) >> 2) << 2; + int w_unroll_remain = win - 2 - w_unroll_size; + const float coef = 1.f / 9.f; + const float coef_2 = exclusive ? 1.f / 2.f : coef; + const float coef_4 = exclusive ? 1.f / 4.f : coef; + const float coef_6 = exclusive ? 
1.f / 6.f : coef; + float32x4_t vcoef = vdupq_n_f32(coef); + float32x4_t vcoef_2 = vdupq_n_f32(coef_2); + float32x4_t vcoef_4 = vdupq_n_f32(coef_4); + float32x4_t vcoef_6 = vdupq_n_f32(coef_6); + for (int n = 0; n < num; ++n) { + float* dout_batch = dout + n * chout * size_channel_out; + const float* din_batch = din + n * chin * size_channel_in; +#pragma omp parallel for + for (int c = 0; c < chout; c++) { + float* dout_ch = dout_batch + c * size_channel_out; + const float* din_ch = din_batch + c * size_channel_in; + const float* r0 = din_ch; + const float* r1 = r0 + win; + const float* r2 = r1 + win; + int cnt_num = w_unroll_size >> 2; // w_unroll_size / 4 + float* dr_out = dout_ch; + const float* dr0 = r0; + const float* dr1 = r1; + const float* dr2 = r2; + int w = 0; + int cnt = 1; + // left + dout_ch[0] = (r0[0] + r0[1] + r1[0] + r1[1]) * coef_4; +// first row with zero pad +#ifdef __aarch64__ + for (; w < w_unroll_size; w += 4) { + float32x4_t vr0_1234 = vld1q_f32(&r0[w]); + float32x4_t vr1_1234 = vld1q_f32(&r1[w]); + float32x4_t vr0_5678 = vld1q_f32(&r0[w + 4]); + float32x4_t vr1_5678 = vld1q_f32(&r1[w + 4]); + float32x4_t vsum_1234 = vaddq_f32(vr0_1234, vr1_1234); + float32x4_t vsum_5678 = vaddq_f32(vr0_5678, vr1_5678); + + float32x4_t vsum_2345 = vextq_f32(vsum_1234, vsum_5678, 1); + float32x4_t vsum_3456 = vextq_f32(vsum_1234, vsum_5678, 2); + float32x4_t vsum = vaddq_f32(vsum_1234, vsum_2345); + vsum = vaddq_f32(vsum, vsum_3456); + vsum = vmulq_f32(vsum, vcoef_6); + vst1q_f32(&dout_ch[cnt], vsum); + cnt += 4; + } +#else + dr_out = dr_out + 1; + if (cnt_num > 0) { + asm volatile( + "1: @main loop\n" + "vld1.f32 {d0-d1}, [%[dr0]]! @load d0-d5,dr0\n" + "vld1.f32 {d4-d5}, [%[dr1]]! @load d4-d7,dr1\n" + "vld1.f32 {d2}, [%[dr0]]! @load d0-d5,dr0\n" + "vld1.f32 {d6}, [%[dr1]]! @load d4-d7,dr1\n" + "vadd.f32 q5, q0, q2 @max r0_1234,r1_1234\n" + "vadd.f32 d12, d2, d6 @max r0_5678,r1_5678\n" + //"vmov.f32 s7,s6 @mov s7,s6\n" + "vext.f32 q0, q5, q6, #1 @vext max_2345\n" + "vext.f32 q2, q5, q6, #2 @vext max_3456\n" + "vadd.f32 q1, q5, q0 @add 1234+2345\n" + "vadd.f32 q1, q1, q2 @add + 3456\n" + "vmul.f32 q4, q1, %q[vcoef_6] @mul * 1/9.f\n" + "sub %[dr0], #8 @sub w,8\n" + "sub %[dr1], #8 @sub w,8\n" + "subs %[cnt_num], #1 @subs cnt_num,#1\n" + "vst1.f32 d8, [%[dr_out]]! @vst1 d0,dr_out\n" + "vst1.f32 d9, [%[dr_out]]! 
@vst1 d0,dr_out\n" + "bne 1b @bne s1_max_loop\n" + : [dr0] "+r"(dr0), + [dr1] "+r"(dr1), + [dr_out] "+r"(dr_out), + [cnt_num] "+r"(cnt_num), + [vcoef_6] "+w"(vcoef_6) + : "r"(dr0), "r"(dr1), "r"(dr_out), "r"(cnt_num) + : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6"); + } + +#endif + // remain + w = w_unroll_size; + for (int j = 0; j < w_unroll_remain; j++) { + float tmp_sum = r0[j + w] + r1[j + w]; + tmp_sum += (r0[j + w + 1] + r1[j + w + 1]); + tmp_sum += (r0[j + w + 2] + r1[j + w + 2]); + dout_ch[j + w + 1] = tmp_sum * coef_6; + } + // right + float tmp = r0[win - 2] + r1[win - 2]; + tmp += (r0[win - 1] + r1[win - 1]); + dout_ch[wout - 1] = tmp * coef_4; + + // r0 = r1; + // r1 = r0 + w_in; + // r2 = r1 + w_in; + dout_ch += wout; + int h = 0; + for (; h < hin - 2; h += 1) { + // deal with left pad + float maxr0 = r0[0] + r0[1]; + float maxr1 = r1[0] + r1[1]; + float maxr2 = r2[0] + r2[1]; + dout_ch[0] = (maxr0 + maxr1 + maxr2) * coef_6; +#ifdef __aarch64__ + w = 0; + cnt = 1; + for (; w < w_unroll_size; w += 4) { + float32x4_t vr0_1234 = vld1q_f32(&r0[w]); + float32x4_t vr1_1234 = vld1q_f32(&r1[w]); + float32x4_t vr2_1234 = vld1q_f32(&r2[w]); + float32x4_t vr0_5678 = vld1q_f32(&r0[w + 4]); + float32x4_t vr1_5678 = vld1q_f32(&r1[w + 4]); + float32x4_t vr2_5678 = vld1q_f32(&r2[w + 4]); + float32x4_t vsum_1234 = vaddq_f32(vr0_1234, vr1_1234); + vsum_1234 = vaddq_f32(vsum_1234, vr2_1234); + float32x4_t vsum_5678 = vaddq_f32(vr0_5678, vr1_5678); + vsum_5678 = vaddq_f32(vsum_5678, vr2_5678); + + float32x4_t vsum_2345 = vextq_f32(vsum_1234, vsum_5678, 1); + float32x4_t vsum_3456 = vextq_f32(vsum_1234, vsum_5678, 2); + float32x4_t vsum = vaddq_f32(vsum_1234, vsum_2345); + vsum = vaddq_f32(vsum, vsum_3456); + vsum = vmulq_f32(vsum, vcoef); + vst1q_f32(&dout_ch[cnt], vsum); + cnt += 4; + } +#else + dr_out = dout_ch + 1; + dr0 = r0; + dr1 = r1; + dr2 = r2; + cnt_num = w_unroll_size >> 2; + if (cnt_num > 0) { + asm volatile( + "1: @main loop\n" + "vld1.f32 {d0-d1}, [%[dr0]]! @load d0-d5,dr0\n" + "vld1.f32 {d4-d5}, [%[dr1]]! @load d4-d7,dr1\n" + "vld1.f32 {d8-d9}, [%[dr2]]! @load d4-d7,dr1\n" + "vld1.f32 {d2}, [%[dr0]]! @load d0-d5,dr0\n" + "vld1.f32 {d6}, [%[dr1]]! @load d4-d7,dr1\n" + "vld1.f32 {d10}, [%[dr2]]! @load d4-d7,dr1\n" + "vadd.f32 q7, q0, q2 @max r0_1234,r1_1234\n" + "vadd.f32 d16, d2, d6 @max r0_5678,r1_5678\n" + "vadd.f32 q3, q7, q4 @max r0_1234,r1_1234\n" + "vadd.f32 d12, d16, d10 @max r0_5678,r1_5678\n" + //"vmov.f32 s7,s6 @mov s7,s6\n" + "vext.f32 q0, q3, q6, #1 @vext max_2345\n" + "vext.f32 q2, q3, q6, #2 @vext max_3456\n" + "vadd.f32 q1, q3, q0 @add 1234+2345\n" + "vadd.f32 q1, q1, q2 @add+3456\n" + "vmul.f32 q4, q1, %q[vcoef] @mul*1/9.f\n" + "sub %[dr0], #8 @sub w,8\n" + "sub %[dr1], #8 @sub w,8\n" + "sub %[dr2], #8 @sub w,8\n" + "subs %[cnt_num], #1 @subs cnt_num,#1\n" + "vst1.f32 d8, [%[dr_out]]! @vst1 d0,dr_out\n" + "vst1.f32 d9, [%[dr_out]]! 
@vst1 d0,dr_out\n" + "bne 1b @bne s1_max_loop\n" + : [dr0] "+r"(dr0), + [dr1] "+r"(dr1), + [dr2] "+r"(dr2), + [dr_out] "+r"(dr_out), + [cnt_num] "+r"(cnt_num), + [vcoef] "+w"(vcoef) + : "r"(dr0), "r"(dr1), "r"(dr_out), "r"(cnt_num) + : "cc", + "memory", + "q0", + "q1", + "q2", + "q3", + "q4", + "q5", + "q6", + "q7", + "q8"); + } +#endif + // remain + w = w_unroll_size; + for (int j = 0; j < w_unroll_remain; j++) { + float tmp_sum = r0[j + w] + r1[j + w]; + tmp_sum += (r0[j + w + 1] + r1[j + w + 1]); + tmp_sum += (r0[j + w + 2] + r1[j + w + 2]); + tmp_sum += (r2[j + w + 1] + r2[j + w + 2]); + tmp_sum += r2[j + w]; + dout_ch[j + w + 1] = tmp_sum * coef; + } + // right + tmp = r0[win - 2] + r1[win - 2]; + tmp += (r0[win - 1] + r1[win - 1]); + tmp += (r2[win - 2] + r2[win - 1]); + dout_ch[wout - 1] = tmp * coef_6; + + r0 = r1; + r1 = r2; + r2 = r1 + win; + dout_ch += wout; + } + + // last line + float maxr0 = (r0[0] + r0[1]); + float maxr1 = (r1[0] + r1[1]); + dout_ch[0] = (maxr0 + maxr1) * coef_4; +#ifdef __aarch64__ + w = 0; + cnt = 1; + for (; w < w_unroll_size; w += 4) { + float32x4_t vr0_1234 = vld1q_f32(&r0[w]); + float32x4_t vr1_1234 = vld1q_f32(&r1[w]); + float32x4_t vr0_5678 = vld1q_f32(&r0[w + 4]); + float32x4_t vr1_5678 = vld1q_f32(&r1[w + 4]); + float32x4_t vsum_1234 = vaddq_f32(vr0_1234, vr1_1234); + float32x4_t vsum_5678 = vaddq_f32(vr0_5678, vr1_5678); + + float32x4_t vsum_2345 = vextq_f32(vsum_1234, vsum_5678, 1); + float32x4_t vsum_3456 = vextq_f32(vsum_1234, vsum_5678, 2); + float32x4_t vsum = vaddq_f32(vsum_1234, vsum_2345); + vsum = vaddq_f32(vsum, vsum_3456); + vsum = vmulq_f32(vsum, vcoef_6); + vst1q_f32(&dout_ch[cnt], vsum); + cnt += 4; + } +#else + dr_out = dout_ch + 1; + dr0 = r0; + dr1 = r1; + cnt_num = w_unroll_size >> 2; + if (cnt_num > 0) { + asm volatile( + "1: @main loop\n" + "vld1.f32 {d0-d1}, [%[dr0]]! @load d0-d5,dr0\n" + "vld1.f32 {d4-d5}, [%[dr1]]! @load d4-d7,dr1\n" + "vld1.f32 {d2}, [%[dr0]]! @load d0-d5,dr0\n" + "vld1.f32 {d6}, [%[dr1]]! @load d4-d7,dr1\n" + "vadd.f32 q5, q0, q2 @max r0_1234,r1_1234\n" + "vadd.f32 d12, d2, d6 @max r0_5678,r1_5678\n" + //"vmov.f32 s7,s6 @mov s7,s6\n" + "vext.f32 q0, q5, q6, #1 @vext max_2345\n" + "vext.f32 q2, q5, q6, #2 @vext max_3456\n" + "vadd.f32 q1, q5, q0 @add 1234+2345\n" + "vadd.f32 q1, q1, q2 @add + 3456\n" + "vmul.f32 q4, q1, %q[vcoef_6] @mul * 1/9.f\n" + "sub %[dr0], #8 @sub w,8\n" + "sub %[dr1], #8 @sub w,8\n" + "subs %[cnt_num], #1 @subs cnt_num,#1\n" + "vst1.f32 d8, [%[dr_out]]! @vst1 d0,dr_out\n" + "vst1.f32 d9, [%[dr_out]]! 
@vst1 d0,dr_out\n" + "bne 1b @bne s1_max_loop\n" + : [dr0] "+r"(dr0), + [dr1] "+r"(dr1), + [dr_out] "+r"(dr_out), + [cnt_num] "+r"(cnt_num), + [vcoef_6] "+w"(vcoef_6) + : "r"(dr0), "r"(dr1), "r"(dr_out), "r"(cnt_num) + : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6"); + } +#endif + // remain + w = w_unroll_size; + for (int j = 0; j < w_unroll_remain; j++) { + float tmp_sum = r0[j + w] + r1[j + w]; + tmp_sum += (r0[j + w + 1] + r1[j + w + 1]); + tmp_sum += (r0[j + w + 2] + r1[j + w + 2]); + dout_ch[j + w + 1] = tmp_sum * coef_6; + } + // right + tmp = r0[win - 2] + r1[win - 2]; + tmp += (r0[win - 1] + r1[win - 1]); + dout_ch[wout - 1] = tmp * coef_4; + } + } +} + +void pooling3x3s2p1_max(const float* din, + float* dout, + int num, + int chout, + int hout, + int wout, + int chin, + int hin, + int win) { + int kernel = 3; + int stride = 2; + int padding = 1; + int size_channel_out = wout * hout; + int size_channel_in = win * hin; + + int w_needed = (wout << 1) + 1; + int h_needed = (hout << 1) + 1; + int w_limit = w_needed > win ? win : w_needed; + int h_limit = h_needed > hin ? hin : h_needed; + int w_even = (w_limit >> 1) << 1; + int h_even = (h_limit >> 1) << 1; + int w_unroll_size = ((w_even - 1) >> 3) << 3; + int w_unroll_remain = w_even - 1 - w_unroll_size; + int w_remain = w_needed - w_limit - padding; + int h_remain = h_needed - h_limit - padding; + int w_in_2 = win << 1; + float minval = std::numeric_limits<float>::lowest(); + for (int n = 0; n < num; ++n) { + float* dout_batch = dout + n * chout * size_channel_out; + const float* din_batch = din + n * chin * size_channel_in; +#pragma omp parallel for + for (int c = 0; c < chout; c++) { + float* dout_ch = dout_batch + c * size_channel_out; + const float* din_ch = din_batch + c * size_channel_in; + const float* r0 = din_ch; + const float* r1 = r0 + win; + const float* r2 = r1 + win; + int cnt_num = w_unroll_size >> 3; + int cnt_num_remain = w_unroll_remain >> 1; + float* dr_out = dout_ch; + const float* dr0 = r0; + const float* dr1 = r1; + const float* dr2 = r2; + int w = 1; + int cnt = 1; + dout_ch[0] = std::max(std::max(r0[0], r0[1]), std::max(r1[0], r1[1])); +// first row with zero pad +#if __aarch64__ + for (; w < w_unroll_size; w += 8) { + float32x4_t vr0_1234 = vld1q_f32(&r0[w]); + float32x4_t vr0_5678 = vld1q_f32(&r0[w + 4]); + float32x4_t vr0_9101112 = vld1q_f32(&r0[w + 8]); + float32x4_t vr1_1234 = vld1q_f32(&r1[w]); + float32x4_t vr1_5678 = vld1q_f32(&r1[w + 4]); + float32x4_t vr1_9101112 = vld1q_f32(&r1[w + 8]); + float32x4_t vmax_1234 = vmaxq_f32(vr0_1234, vr1_1234); + float32x4_t vmax_5678 = vmaxq_f32(vr0_5678, vr1_5678); + float32x4_t vmax_9101112 = vmaxq_f32(vr0_9101112, vr1_9101112); + float32x4_t vmax_2345 = vextq_f32(vmax_1234, vmax_5678, 1); + float32x4_t vmax_6789 = vextq_f32(vmax_5678, vmax_9101112, 1); + float32x2_t vmax_12_34 = + vpmax_f32(vget_low_f32(vmax_1234), vget_high_f32(vmax_1234)); + float32x2_t vmax_23_45 = + vpmax_f32(vget_low_f32(vmax_2345), vget_high_f32(vmax_2345)); + float32x2_t vmax_56_78 = + vpmax_f32(vget_low_f32(vmax_5678), vget_high_f32(vmax_5678)); + float32x2_t vmax_67_89 = + vpmax_f32(vget_low_f32(vmax_6789), vget_high_f32(vmax_6789)); + float32x2_t vmax_123_345 = vmax_f32(vmax_12_34, vmax_23_45); + float32x2_t vmax_567_789 = vmax_f32(vmax_56_78, vmax_67_89); + vst1_f32(&dout_ch[cnt], vmax_123_345); + vst1_f32(&dout_ch[cnt + 2], vmax_567_789); + cnt += 4; + } + for (; w < w_even - 1; w += 2) { + float32x4_t vr0 = vld1q_f32(&r0[w]); + float32x4_t vr1 = vld1q_f32(&r1[w]); + vr0 = 
vsetq_lane_f32(minval, vr0, 3); + vr1 = vsetq_lane_f32(minval, vr1, 3); + float32x4_t vmax1 = vmaxq_f32(vr0, vr1); + float32x2_t vmax2 = + vpmax_f32(vget_low_f32(vmax1), vget_high_f32(vmax1)); + vmax2 = vpmax_f32(vmax2, vmax2); + dout_ch[cnt] = vget_lane_f32(vmax2, 0); + cnt++; + } +#else + dr0 = dr0 + 1; + dr1 = dr1 + 1; + dr_out = dr_out + 1; + // LOG(INFO) << "cnt_num: " << cnt_num << " cnt_num_remain: " << + // cnt_num_remain; + if (cnt_num > 0 || cnt_num_remain > 0) { + asm volatile( + "cmp %[cnt_num], #0 @cmp cnt_num,0\n" + "ble 3f @ble exit\n" + "1: @main loop\n" + "vld1.f32 {d0-d3}, [%[dr0]]! @load d0-d5,dr0\n" + "vld1.f32 {d6-d9}, [%[dr1]]! @load d4-d7,dr1\n" + "vld1.f32 {d4-d5}, [%[dr0]]! @load d0-d5,dr0\n" + "vld1.f32 {d10-d11}, [%[dr1]]! @load d4-d7,dr1\n" + "vmax.f32 q6, q0, q3 @max r0_1234,r1_1234\n" + "vmax.f32 q7, q1, q4 @max r0_5678,r1_5678\n" + "vmax.f32 q8, q2, q5 @max r0_9101112,r1_9101112\n" + //"vmov.f32 s7,s6 @mov s7,s6\n" + "vext.f32 q0, q6, q7, #1 @vext max_2345\n" + "vext.f32 q1, q7, q8, #1 @vext max_6789\n" + "vpmax.f32 d4, d12, d13 @pmax d4,vmax_1234,vmax_1234\n" + "vpmax.f32 d6, d14, d15 @pmax d6,vmax_5678,vmax_5678\n" + "vpmax.f32 d5, d0, d1 @pmax d5,vmax_2345,vmax_2345\n" + "vpmax.f32 d7, d2, d3 @pmax d7,vmax_6789,vmax_6789\n" + "vmax.f32 d8, d4, d5 @max d2,vmax_12_34,vmax_23_45\n" + "vmax.f32 d9, d6, d7 @max d2,vmax_56_78,vmax_67_89\n" + "sub %[dr0], #16 @add w,8\n" + "sub %[dr1], #16 @add w, 8\n" + "vst1.f32 d8, [%[dr_out]]! @vst1 d0,dr_out\n" + "vst1.f32 d9, [%[dr_out]]! @vst1 d0,dr_out\n" + "subs %[cnt_num], #1 @subs cnt_num, #1\n" + "bne 1b @bne s3_max_loop\n" + "3: @loop \n" + "cmp %[cnt_num_remain], #0 @cmp cnt_num,0\n" + "ble 4f @ble exit\n" + "2: @main loop\n" + "vld1.f32 {d0-d1}, [%[dr0]]! @load d0-d1,dr0\n" + "vld1.f32 {d2-d3}, [%[dr1]]! @load d2-d3,dr1\n" + "vmov.f32 s3,s2 @movs3,s2\n" + "vmov.f32 s7,s6 @movs7,s6\n" + "vmax.f32 q0, q0, q1 @max q0,q0,q1\n" + "vpmax.f32 d0, d0, d1 @pmax d0,d0,d1\n" + "vpmax.f32 d0, d0, d0 @pmax d0,d0,d0\n" + "vst1.f32 d0[0], [%[dr_out]]! 
@vst d0[0],dr_out\n" + "sub %[dr0], #8 @add w,6\n" + "sub %[dr1], #8 @add w,6\n" + "subs %[cnt_num_remain], #1 @subs cnt_num,#1\n" + "bne 2b @bne s3_max_loop_1\n" + "4: @exit\n" + : [dr0] "+r"(dr0), + [dr1] "+r"(dr1), + [dr_out] "+r"(dr_out), + [cnt_num] "+r"(cnt_num), + [cnt_num_remain] "+r"(cnt_num_remain) + : "r"(dr0), "r"(dr1), "r"(dr_out), "r"(cnt_num), "r"(cnt_num_remain) + : "cc", + "memory", + "q0", + "q1", + "q2", + "q3", + "q4", + "q5", + "q6", + "q7", + "q8", + "q9"); + } +#endif + // int w = w_even - 1; + if (w_remain > 0) { + // deal with right pad + int wstart = (w_even >> 1) * stride - padding; + int wend = std::min(std::min(wstart + kernel, win + padding), win); + float tmp = r0[wstart]; // std::numeric_limits::min(); + for (int i = wstart; i < wend; i++) { // only run 1 or 2 times + tmp = std::max(tmp, std::max(r0[i], r1[i])); + } + dout_ch[w_even >> 1] = tmp; + // cnt ++; + } + + r0 = r1; + r1 = r0 + win; + r2 = r1 + win; + dout_ch += wout; + int h = 2; + for (; h < h_even; h += 2) { + // deal with left pad + float maxr0 = std::max(r0[0], r0[1]); + float maxr1 = std::max(r1[0], r1[1]); + float maxr2 = std::max(r2[0], r2[1]); + dout_ch[0] = std::max(std::max(maxr0, maxr1), maxr2); +#if __aarch64__ + w = 1; + cnt = 1; + for (; w < w_unroll_size; w += 8) { + float32x4_t vr0_1234 = vld1q_f32(&r0[w]); + float32x4_t vr0_5678 = vld1q_f32(&r0[w + 4]); + float32x4_t vr0_9101112 = vld1q_f32(&r0[w + 8]); + float32x4_t vr1_1234 = vld1q_f32(&r1[w]); + float32x4_t vr1_5678 = vld1q_f32(&r1[w + 4]); + float32x4_t vr1_9101112 = vld1q_f32(&r1[w + 8]); + float32x4_t vr2_1234 = vld1q_f32(&r2[w]); + float32x4_t vr2_5678 = vld1q_f32(&r2[w + 4]); + float32x4_t vr2_9101112 = vld1q_f32(&r2[w + 8]); + float32x4_t vmax_1234 = vmaxq_f32(vr0_1234, vr1_1234); + vmax_1234 = vmaxq_f32(vmax_1234, vr2_1234); + float32x4_t vmax_5678 = vmaxq_f32(vr0_5678, vr1_5678); + vmax_5678 = vmaxq_f32(vmax_5678, vr2_5678); + float32x4_t vmax_9101112 = vmaxq_f32(vr0_9101112, vr1_9101112); + vmax_9101112 = vmaxq_f32(vmax_9101112, vr2_9101112); + float32x4_t vmax_2345 = vextq_f32(vmax_1234, vmax_5678, 1); + float32x4_t vmax_6789 = vextq_f32(vmax_5678, vmax_9101112, 1); + float32x2_t vmax_12_34 = + vpmax_f32(vget_low_f32(vmax_1234), vget_high_f32(vmax_1234)); + float32x2_t vmax_23_45 = + vpmax_f32(vget_low_f32(vmax_2345), vget_high_f32(vmax_2345)); + float32x2_t vmax_56_78 = + vpmax_f32(vget_low_f32(vmax_5678), vget_high_f32(vmax_5678)); + float32x2_t vmax_67_89 = + vpmax_f32(vget_low_f32(vmax_6789), vget_high_f32(vmax_6789)); + float32x2_t vmax_123_345 = vmax_f32(vmax_12_34, vmax_23_45); + float32x2_t vmax_567_789 = vmax_f32(vmax_56_78, vmax_67_89); + vst1_f32(&dout_ch[cnt], vmax_123_345); + vst1_f32(&dout_ch[cnt + 2], vmax_567_789); + cnt += 4; + } + for (; w < w_even - 1; w += 2) { + float32x4_t vr0 = vld1q_f32(&r0[w]); + float32x4_t vr1 = vld1q_f32(&r1[w]); + float32x4_t vr2 = vld1q_f32(&r2[w]); + vr0 = vsetq_lane_f32(minval, vr0, 3); + vr1 = vsetq_lane_f32(minval, vr1, 3); + vr2 = vsetq_lane_f32(minval, vr2, 3); + float32x4_t vmax1 = vmaxq_f32(vr0, vr1); + vmax1 = vmaxq_f32(vmax1, vr2); + float32x2_t vmax2 = + vpmax_f32(vget_low_f32(vmax1), vget_high_f32(vmax1)); + float32x2_t vmax = vpmax_f32(vmax2, vmax2); + dout_ch[cnt] = vget_lane_f32(vmax, 0); + cnt++; + } +#else + dr_out = dout_ch + 1; + dr0 = (r0 + 1); + dr1 = (r1 + 1); + dr2 = (r2 + 1); + cnt_num = w_unroll_size >> 3; + cnt_num_remain = w_unroll_remain >> 1; + if (cnt_num > 0 || cnt_num_remain > 0) { + asm volatile( + "cmp %[cnt_num], #0 @cmp cnt_num,0\n" + 
"ble 3f @ble exit\n" + "1: @main loop\n" + "vld1.f32 {d0-d3}, [%[dr0]]! @load d0-d5,dr0\n" + "vld1.f32 {d6-d9}, [%[dr1]]! @load d4-d7,dr1\n" + "vld1.f32 {d12-d15}, [%[dr2]]! @load d4-d7,dr1\n" + "vld1.f32 {d4-d5}, [%[dr0]]! @load d0-d5,dr0\n" + "vld1.f32 {d10-d11}, [%[dr1]]! @load d4-d7,dr1\n" + "vld1.f32 {d16-d17}, [%[dr2]]! @load d4-d7,dr1\n" + "vmax.f32 q9, q0, q3 @max q0,q0,q2\n" + "vmax.f32 q10, q1, q4 @max q1,q1,q3\n" + "vmax.f32 q11, q2, q5 @max q1,q1,q3\n" + "vmax.f32 q0, q9, q6 @max q0,q0,q2 1234\n" + "vmax.f32 q3, q10, q7 @max q1,q1,q3 5678\n" + "vmax.f32 q1, q11, q8 @max q1,q1,q3 9101112\n" + //"vmov.f32 s7,s6 @mov s7, s6\n" + "vext.f32 q4, q0, q3, #1 @vext 2345\n" + "vext.f32 q2, q3, q1, #1 @vext 6789\n" + "vpmax.f32 d10, d0, d1 @pmax d10,vmax_1234,vmax_1234\n" + "vpmax.f32 d12, d6, d7 @pmax d12,vmax_5678,vmax_5678\n" + "vpmax.f32 d11, d8, d9 @pmax d11,vmax_2345,vmax_2345\n" + "vpmax.f32 d13, d4, d5 @pmax d13,vmax_6789,vmax_6789\n" + "vmax.f32 d0, d10, d11 @pmax d0,vmax_12_34,vmax_23_45\n" + "vmax.f32 d1, d12, d13 @pmax d1,vmax_56_78,vmax_67_89\n" + "sub %[dr0], #16 @add w,8\n" + "sub %[dr1], #16 @add w,8\n" + "sub %[dr2], #16 @add w,8\n" + "vst1.f32 d0, [%[dr_out]]! @vst1 d0,dr_out\n" + "vst1.f32 d1, [%[dr_out]]! @vst1 d0,dr_out\n" + "subs %[cnt_num], #1 @subs cnt_num,#1\n" + "bne 1b @bne s3_max_loop_mid\n" + "3: @loop \n" + "cmp %[cnt_num_remain], #0 @cmp cnt_num,0\n" + "ble 4f @ble exit1\n" + "2: @mid loop\n" + "vld1.f32 {d0-d1}, [%[dr0]]! @load d0-d1,dr0\n" + "vld1.f32 {d2-d3}, [%[dr1]]! @load d2-d3,dr1\n" + "vld1.f32 {d4-d5}, [%[dr2]]! @load d2-d3,dr1\n" + "vmov.f32 s3,s2 @movs3,s2\n" + "vmov.f32 s7,s6 @movs7,s6\n" + "vmov.f32 s11,s10 @movs11,s10\n" + "vmax.f32 q0, q0, q1 @max q0,q0,q1\n" + "vmax.f32 q0, q0, q2 @max q0,q0,q2\n" + "vpmax.f32 d0, d0, d1 @pmax d0,d0,d1\n" + "vpmax.f32 d0, d0, d0 @pmax d0, d0,d0\n" + "vst1.f32 d0[0], [%[dr_out]]! 
@vst d0[0],dr_out\n" + "sub %[dr0], #8 @add w,6\n" + "sub %[dr1], #8 @add w,6\n" + "sub %[dr2], #8 @add w,6\n" + "subs %[cnt_num_remain], #1 @subs cnt_num,#1\n" + "bne 2b @bne s3_max_loop_mid_1\n" + "4: @exit\n" + : [dr0] "+r"(dr0), + [dr1] "+r"(dr1), + [dr2] "+r"(dr2), + [dr_out] "+r"(dr_out), + [cnt_num] "+r"(cnt_num), + [cnt_num_remain] "+r"(cnt_num_remain) + : "r"(dr0), + "r"(dr1), + "r"(dr2), + "r"(dr_out), + "r"(cnt_num), + "r"(cnt_num_remain) + : "cc", + "memory", + "q0", + "q1", + "q2", + "q3", + "q4", + "q5", + "q6", + "q7", + "q8", + "q9", + "q10", + "q11", + "q12"); + } +#endif + if (w_remain > 0) { + // deal with right pad + int wstart = (w_even >> 1) * stride - padding; + int wend = std::min(std::min(wstart + kernel, win + padding), win); + float tmp = r0[wstart]; // std::numeric_limits<float>::min(); + for (int i = wstart; i < wend; i++) { + tmp = std::max(tmp, std::max(r0[i], r1[i])); + tmp = std::max(tmp, r2[i]); + } + dout_ch[w_even >> 1] = tmp; + // cnt ++; + } + r0 = r2; + r1 = r0 + win; + r2 = r1 + win; + dout_ch += wout; + } + + if (h_remain > 0) { + // deal with bottom pad + // first row with zero pad + int hstart = (h >> 1) * stride - padding; + int hend = std::min(std::min(hstart + kernel, hin + padding), hin); + if (hstart == hend - 1) { // only one line + dout_ch[0] = std::max(r0[0], r0[1]); +#if __aarch64__ + w = 1; + cnt = 1; + for (; w < w_unroll_size; w += 8) { + float32x4_t vmax_1234 = vld1q_f32(&r0[w]); + float32x4_t vmax_5678 = vld1q_f32(&r0[w + 4]); + float32x4_t vmax_9101112 = vld1q_f32(&r0[w + 8]); + float32x4_t vmax_2345 = vextq_f32(vmax_1234, vmax_5678, 1); + float32x4_t vmax_6789 = vextq_f32(vmax_5678, vmax_9101112, 1); + float32x2_t vmax_12_34 = + vpmax_f32(vget_low_f32(vmax_1234), vget_high_f32(vmax_1234)); + float32x2_t vmax_23_45 = + vpmax_f32(vget_low_f32(vmax_2345), vget_high_f32(vmax_2345)); + float32x2_t vmax_56_78 = + vpmax_f32(vget_low_f32(vmax_5678), vget_high_f32(vmax_5678)); + float32x2_t vmax_67_89 = + vpmax_f32(vget_low_f32(vmax_6789), vget_high_f32(vmax_6789)); + float32x2_t vmax_123_345 = vmax_f32(vmax_12_34, vmax_23_45); + float32x2_t vmax_567_789 = vmax_f32(vmax_56_78, vmax_67_89); + vst1_f32(&dout_ch[cnt], vmax_123_345); + vst1_f32(&dout_ch[cnt + 2], vmax_567_789); + cnt += 4; + } + for (; w < w_even - 1; w += 2) { + float32x4_t vr0 = vld1q_f32(&r0[w]); + vr0 = vsetq_lane_f32(minval, vr0, 3); + float32x2_t vmax = vpmax_f32(vget_low_f32(vr0), vget_high_f32(vr0)); + vmax = vpmax_f32(vmax, vmax); + dout_ch[cnt] = vget_lane_f32(vmax, 0); + cnt++; + } +#else + dr_out = dout_ch + 1; + dr0 = (r0 + 1); + cnt_num = w_unroll_size >> 3; + cnt_num_remain = w_unroll_remain >> 1; + // LOG(INFO) << "cnt_num: " << cnt_num << " cnt_num_remain: " << + // cnt_num_remain; + if (cnt_num > 0 || cnt_num_remain > 0) { + asm volatile( + "cmp %[cnt_num], #0 @cmp cnt_num,0\n" + "ble 3f @ble exit\n" + "1: @main loop\n" + "vld1.f32 {d0-d3}, [%[dr0]]! @load d0-d3,dr0\n" + "vld1.f32 {d4-d5}, [%[dr0]]! @load d0-d3,dr0\n" + "vext.f32 q4, q0, q1, #1 @vmax_2345\n" + "vext.f32 q5, q1, q2, #1 @vmax_6789\n" + "vpmax.f32 d12, d0, d1 @vmax_12_34\n" + "vpmax.f32 d14, d2, d3 @vmax_56_78\n" + "vpmax.f32 d13, d8, d9 @vmax_23_45\n" + "vpmax.f32 d15, d10, d11 @vmax_67_89\n" + "vmax.f32 d0, d12, d13 @12_34,23_45\n" + "vmax.f32 d1, d14, d15 @56_78,67_89\n" + "sub %[dr0], #16 @add w,6\n" + "vst1.f32 d0, [%[dr_out]]! @vst1 d0,dr_out\n" + "vst1.f32 d1, [%[dr_out]]!
@vst1 d0,dr_out\n" + "subs %[cnt_num], #1 @subs cnt_num,#1\n" + "bne 1b @bne s3_max_loop_bot\n" + "3: @loop \n" + "cmp %[cnt_num_remain], #0 @cmp cnt_num,0\n" + "ble 4f @ble exit\n" + "2: @bot loop\n" + "vld1.f32 {d0-d1}, [%[dr0]]! @load d0-d1,dr0\n" + "vmov.f32 s3,s2 @movs3, s2\n" + "vpmax.f32 d0, d0, d1 @pmax d0,d0,d1\n" + "vpmax.f32 d0, d0, d0 @pmax d0,d0,d0\n" + "vst1.f32 d0[0], [%[dr_out]]! @vst d0[0],dr_out\n" + "sub %[dr0], #8 @add w,2\n" + "subs %[cnt_num_remain], #1 @subs cnt_num,#1\n" + "bne 2b @bne s3_max_loop_bot_1\n" + "4: @exit\n" + : [dr0] "+r"(dr0), + [dr1] "+r"(dr1), + [dr_out] "+r"(dr_out), + [cnt_num] "+r"(cnt_num), + [cnt_num_remain] "+r"(cnt_num_remain) + : "r"(dr0), + "r"(dr1), + "r"(dr_out), + "r"(cnt_num), + "r"(cnt_num_remain) + : "cc", + "memory", + "q0", + "q1", + "q2", + "q3", + "q4", + "q5", + "q6", + "q7", + "q8"); + } +#endif + if (w_remain > 0) { + // deal with right pad + int wstart = (w_even >> 1) * stride - padding; + int wend = std::min(std::min(wstart + kernel, win + padding), win); + float tmp = r0[wstart]; // std::numeric_limits::min(); + for (int i = wstart; i < wend; i++) { + tmp = std::max(tmp, r0[i]); + } + dout_ch[w_even >> 1] = tmp; + } + } else { // two lines + dout_ch[0] = std::max(std::max(r0[0], r0[1]), std::max(r1[0], r1[1])); +#ifdef __aarch64__ + w = 1; + cnt = 1; + for (; w < w_unroll_size; w += 8) { + float32x4_t vr0_1234 = vld1q_f32(&r0[w]); + float32x4_t vr0_5678 = vld1q_f32(&r0[w + 4]); + float32x4_t vr0_9101112 = vld1q_f32(&r0[w + 8]); + float32x4_t vr1_1234 = vld1q_f32(&r1[w]); + float32x4_t vr1_5678 = vld1q_f32(&r1[w + 4]); + float32x4_t vr1_9101112 = vld1q_f32(&r1[w + 8]); + float32x4_t vmax_1234 = vmaxq_f32(vr0_1234, vr1_1234); + float32x4_t vmax_5678 = vmaxq_f32(vr0_5678, vr1_5678); + float32x4_t vmax_9101112 = vmaxq_f32(vr0_9101112, vr1_9101112); + float32x4_t vmax_2345 = vextq_f32(vmax_1234, vmax_5678, 1); + float32x4_t vmax_6789 = vextq_f32(vmax_5678, vmax_9101112, 1); + float32x2_t vmax_12_34 = + vpmax_f32(vget_low_f32(vmax_1234), vget_high_f32(vmax_1234)); + float32x2_t vmax_23_45 = + vpmax_f32(vget_low_f32(vmax_2345), vget_high_f32(vmax_2345)); + float32x2_t vmax_56_78 = + vpmax_f32(vget_low_f32(vmax_5678), vget_high_f32(vmax_5678)); + float32x2_t vmax_67_89 = + vpmax_f32(vget_low_f32(vmax_6789), vget_high_f32(vmax_6789)); + float32x2_t vmax_123_345 = vmax_f32(vmax_12_34, vmax_23_45); + float32x2_t vmax_567_789 = vmax_f32(vmax_56_78, vmax_67_89); + vst1_f32(&dout_ch[cnt], vmax_123_345); + vst1_f32(&dout_ch[cnt + 2], vmax_567_789); + cnt += 4; + } + for (; w < w_even - 1; w += 2) { + float32x4_t vr0 = vld1q_f32(&r0[w]); + float32x4_t vr1 = vld1q_f32(&r1[w]); + vr0 = vsetq_lane_f32(minval, vr0, 3); + vr1 = vsetq_lane_f32(minval, vr1, 3); + float32x4_t vmax1 = vmaxq_f32(vr0, vr1); + float32x2_t vmax2 = + vpmax_f32(vget_low_f32(vmax1), vget_high_f32(vmax1)); + vmax2 = vpmax_f32(vmax2, vmax2); + dout_ch[cnt] = vget_lane_f32(vmax2, 0); + cnt++; + } +#else + dr_out = dout_ch + 1; + dr0 = (r0 + 1); + dr1 = (r1 + 1); + cnt_num = w_unroll_size >> 3; + cnt_num_remain = w_unroll_remain >> 1; + // LOG(INFO) << "cnt_num: " << cnt_num << " cnt_num_remain: " << + // cnt_num_remain; + if (cnt_num > 0 || cnt_num_remain > 0) { + asm volatile( + "cmp %[cnt_num], #0 @cmp cnt_num,0\n" + "ble 3f @ble exit\n" + "1: @main loop\n" + "vld1.f32 {d0-d3}, [%[dr0]]! @load d0-d5,dr0\n" + "vld1.f32 {d6-d9}, [%[dr1]]! @load d4-d7,dr1\n" + "vld1.f32 {d4-d5}, [%[dr0]]! @load d0-d3,dr0\n" + "vld1.f32 {d10-d11}, [%[dr1]]! 
@load d4-d7,dr1\n" + "vmax.f32 q6, q0, q3 @max q0,q0,q2 1234\n" + "vmax.f32 q7, q1, q4 @max q1,q1,q3 5678\n" + "vmax.f32 q8, q2, q5 @max q1,q1,q3 9101112\n" + //"vmov.f32 s7,s6 @mov s7, s6\n" + "vext.f32 q0, q6, q7, #1 @vext q0,2345\n" + "vext.f32 q1, q7, q8, #1 @vext q1,6789\n" + "vpmax.f32 d4, d12, d13 @pmax " + "d4,vmax_1234,vmax_1234\n" + "vpmax.f32 d6, d14, d15 @pmax " + "d6,vmax_5678,vmax_5678\n" + "vpmax.f32 d5, d0, d1 @pmax " + "d5,vmax_2345,vmax_2345\n" + "vpmax.f32 d7, d2, d3 @pmax " + "d7,vmax_6789,vmax_6789\n" + "vmax.f32 d8, d4, d5 @max " + "d2,vmax_12_34,vmax_23_45\n" + "vmax.f32 d9, d6, d7 @max " + "d2,vmax_56_78,vmax_67_89\n" + "sub %[dr0], #16 @add w,8\n" + "sub %[dr1], #16 @add w,8\n" + "vst1.f32 d8, [%[dr_out]]! @vst1 d0,dr_out\n" + "vst1.f32 d9, [%[dr_out]]! @vst1 d0,dr_out\n" + "subs %[cnt_num], #1 @subs cnt_num,#1\n" + "bne 1b @bne s3_max_loop_bot\n" + "3: @loop \n" + "cmp %[cnt_num_remain], #0 @cmp cnt_num,0\n" + "ble 4f @ble exit\n" + "2: @bot loop\n" + "vld1.f32 {d0-d1}, [%[dr0]]! @load d0-d1,dr0\n" + "vld1.f32 {d2-d3}, [%[dr1]]! @load d2-d3,dr1\n" + "vmov.f32 s3,s2 @movs3, s2\n" + "vmov.f32 s7,s6 @movs7, s6\n" + "vmax.f32 q0, q0, q1 @max q0,q0,q1\n" + "vpmax.f32 d0, d0, d1 @pmax d0,d0,d1\n" + "vpmax.f32 d0, d0, d0 @pmax d0,d0,d0\n" + "vst1.f32 d0[0], [%[dr_out]]! @vst d0[0],dr_out\n" + "sub %[dr0], #8 @add w,6\n" + "sub %[dr1], #8 @add w,6\n" + "subs %[cnt_num_remain], #1 @subs cnt_num,#1\n" + "bne 2b @bne s3_max_loop_bot_1\n" + "4: @exit\n" + : [dr0] "+r"(dr0), + [dr1] "+r"(dr1), + [dr_out] "+r"(dr_out), + [cnt_num] "+r"(cnt_num), + [cnt_num_remain] "+r"(cnt_num_remain) + : "r"(dr0), + "r"(dr1), + "r"(dr_out), + "r"(cnt_num), + "r"(cnt_num_remain) + : "cc", + "memory", + "q0", + "q1", + "q2", + "q3", + "q4", + "q5", + "q6", + "q7", + "q8", + "q9"); + } +#endif + if (w_remain > 0) { + // deal with right pad + int wstart = (w_even >> 1) * stride - padding; + int wend = std::min(std::min(wstart + kernel, win + padding), win); + float tmp = r0[wstart]; // std::numeric_limits::min(); + for (int i = wstart; i < wend; i++) { // only run 1 or 2 times + tmp = std::max(tmp, std::max(r0[i], r1[i])); + } + dout_ch[w_even >> 1] = tmp; + } + } + } + } + } +} + +void pooling3x3s2p1_avg(const float* din, + float* dout, + int num, + int chout, + int hout, + int wout, + int chin, + int hin, + int win, + bool exclusive) { + int kernel = 3; + int stride = 2; + int padding = 1; + int size_channel_out = wout * hout; + int size_channel_in = win * hin; + + int w_needed = (wout << 1) + 1; + int h_needed = (hout << 1) + 1; + int w_limit = w_needed > win ? win : w_needed; + int h_limit = h_needed > hin ? hin : h_needed; + int w_even = (w_limit >> 1) << 1; + int h_even = (h_limit >> 1) << 1; + int w_unroll_size = ((w_even - 1) >> 3) << 3; + int w_unroll_remain = w_even - 1 - w_unroll_size; + int w_remain = w_needed - w_limit - padding; + int h_remain = h_needed - h_limit - padding; + int w_in_2 = win << 1; + const float coef = 1.f / 9.f; + const float coef_1 = exclusive ? 1.f : coef; + const float coef_2 = exclusive ? 1.f / 2.f : coef; + const float coef_3 = exclusive ? 1.f / 3.f : coef; + const float coef_4 = exclusive ? 1.f / 4.f : coef; + const float coef_6 = exclusive ? 
1.f / 6.f : coef; + float32x4_t vcoef = vdupq_n_f32(coef); + float32x4_t vcoef_1 = vdupq_n_f32(coef_1); + float32x4_t vcoef_2 = vdupq_n_f32(coef_2); + float32x4_t vcoef_3 = vdupq_n_f32(coef_3); + float32x4_t vcoef_4 = vdupq_n_f32(coef_4); + float32x4_t vcoef_6 = vdupq_n_f32(coef_6); + for (int n = 0; n < num; ++n) { + float* dout_batch = dout + n * chout * size_channel_out; + const float* din_batch = din + n * chin * size_channel_in; +#pragma omp parallel for + for (int c = 0; c < chout; c++) { + float* dout_ch = dout_batch + c * size_channel_out; + const float* din_ch = din_batch + c * size_channel_in; + const float* r0 = din_ch; + const float* r1 = r0 + win; + const float* r2 = r1 + win; + int cnt_num = w_unroll_size >> 3; + int cnt_num_remain = w_unroll_remain >> 1; + float* dr_out = dout_ch; + const float* dr0 = r0; + const float* dr1 = r1; + const float* dr2 = r2; + int w = 1; + int cnt = 1; + float32x4_t vzero = vdupq_n_f32(0.f); + dout_ch[0] = (r0[0] + r0[1] + r1[0] + r1[1]) * coef_4; +// first row with zero pad +#ifdef __aarch64__ + for (; w < w_unroll_size; w += 8) { + float32x4_t vr0_1234 = vld1q_f32(&r0[w]); + float32x4_t vr0_5678 = vld1q_f32(&r0[w + 4]); + float32x4_t vr0_9101112 = vld1q_f32(&r0[w + 8]); + float32x4_t vr1_1234 = vld1q_f32(&r1[w]); + float32x4_t vr1_5678 = vld1q_f32(&r1[w + 4]); + float32x4_t vr1_9101112 = vld1q_f32(&r1[w + 8]); + float32x4_t vsum_1234 = vaddq_f32(vr0_1234, vr1_1234); + float32x4_t vsum_5678 = vaddq_f32(vr0_5678, vr1_5678); + float32x4_t vsum_9101112 = vaddq_f32(vr0_9101112, vr1_9101112); + + float32x4_t vsum_2345 = vextq_f32(vsum_1234, vsum_5678, 1); + float32x4_t vsum_3456 = vextq_f32(vsum_1234, vsum_5678, 2); + float32x4_t vsum_4567 = vextq_f32(vsum_1234, vsum_5678, 3); + float32x4_t vsum_6789 = vextq_f32(vsum_5678, vsum_9101112, 1); + float32x4_t vsum_123_345 = vaddq_f32(vsum_1234, vsum_2345); + vsum_123_345 = vaddq_f32(vsum_123_345, vsum_3456); + float32x4_t vsum_567_789 = vaddq_f32(vsum_4567, vsum_5678); + vsum_567_789 = vaddq_f32(vsum_567_789, vsum_6789); + vsum_123_345 = + vsetq_lane_f32(vgetq_lane_f32(vsum_123_345, 2), vsum_123_345, 1); + vsum_123_345 = + vsetq_lane_f32(vgetq_lane_f32(vsum_567_789, 1), vsum_123_345, 2); + vsum_123_345 = + vsetq_lane_f32(vgetq_lane_f32(vsum_567_789, 3), vsum_123_345, 3); + float32x4_t vrst = vmulq_f32(vsum_123_345, vcoef_6); + vst1q_f32(&dout_ch[cnt], vrst); + cnt += 4; + } + for (; w < w_even - 1; w += 2) { + float32x4_t vr0 = vld1q_f32(&r0[w]); + float32x4_t vr1 = vld1q_f32(&r1[w]); + vr0 = vsetq_lane_f32(0.f, vr0, 3); + vr1 = vsetq_lane_f32(0.f, vr1, 3); + float32x4_t vsum1 = vaddq_f32(vr0, vr1); + float32x2_t vsum2 = + vpadd_f32(vget_low_f32(vsum1), vget_high_f32(vsum1)); + vsum2 = vpadd_f32(vsum2, vsum2); + float32x2_t vrst = vmul_f32(vsum2, vget_low_f32(vcoef_6)); + dout_ch[cnt] = vget_lane_f32(vrst, 0); + cnt++; + } +#else + dr0 = dr0 + 1; + dr1 = dr1 + 1; + dr_out = dr_out + 1; + // LOG(INFO) << "cnt_num: " << cnt_num << " cnt_num_remain: " << + // cnt_num_remain; + if (cnt_num > 0 || cnt_num_remain > 0) { + asm volatile( + "cmp %[cnt_num], #0 @cmp cnt_num,0\n" + "ble 3f @ble exit\n" + "1: @main loop\n" + "vld1.f32 {d0-d3}, [%[dr0]]! @load d0-d5,dr0\n" + "vld1.f32 {d6-d9}, [%[dr1]]! @load d4-d7,dr1\n" + "vld1.f32 {d4-d5}, [%[dr0]]! @load d0-d5,dr0\n" + "vld1.f32 {d10-d11}, [%[dr1]]! 
@load d4-d7,dr1\n" + "vadd.f32 q6, q0, q3 @max r0_1234,r1_1234\n" + "vadd.f32 q7, q1, q4 @max r0_5678,r1_5678\n" + "vadd.f32 q8, q2, q5 @max r0_9101112,r1_9101112\n" + //"vmov.f32 s7,s6 @mov s7, s6\n" + "vext.f32 q0, q6, q7, #1 @vext max_2345\n" + "vext.f32 q1, q6, q7, #3 @vext max_4567\n" + "vext.f32 q2, q6, q7, #2 @vext max_3456\n" + "vext.f32 q3, q7, q8, #1 @vext max_6789\n" + "vadd.f32 q4, q6, q0 @add 1234, 2345\n" + "vadd.f32 q5, q7, q1 @add 5678, 4567\n" + "vadd.f32 q4, q4, q2 @add 3456, sum1\n" + "vadd.f32 q5, q5, q3 @add 6789, sum2\n" + "vmov.f32 s17, s18 @mov\n" + "vmov.f32 s18, s21 @mov\n" + "vmov.f32 s19, s23 @mov\n" + "vmul.f32 q4, q4, %q[vcoef_6] @mul\n" + "sub %[dr0], #16 @add w,8\n" + "sub %[dr1], #16 @add w,8\n" + "subs %[cnt_num], #1 @subs cnt_num,#1\n" + "vst1.f32 d8, [%[dr_out]]! @vst1 d0,dr_out\n" + "vst1.f32 d9, [%[dr_out]]! @vst1 d0,dr_out\n" + "bne 1b @bne s3_max_loop\n" + "3: @loop\n" + "cmp %[cnt_num_remain], #0 @cnt_num_remain<=0\n" + "ble 4f @ble exit\n" + "2: @main loop\n" + "vld1.f32 {d0-d1}, [%[dr0]]! @load d0-d1,dr0\n" + "vld1.f32 {d2-d3}, [%[dr1]]! @load d2-d3,dr1\n" + "vext.f32 q0, %q[vzero], q0, #3 @ext v0_0123\n" + "vext.f32 q1, %q[vzero], q1, #3 @ext v1_0123\n" + "vadd.f32 q0, q0, q1 @add q0,q0,q1\n" + "vpadd.f32 d0, d0, d1 @padd d0,d0,d1\n" + "vpadd.f32 d0, d0, d0 @padd d0, d0,d0\n" + "vmul.f32 d0, d0, %e[vcoef_6] @mul\n" + "sub %[dr0], #8 @add w,6\n" + "sub %[dr1], #8 @add w,6\n" + "subs %[cnt_num_remain], #1 @subs cnt_num,#1\n" + "vst1.f32 d0[0], [%[dr_out]]! @vst d0[0],dr_out\n" + "bne 2b @bne s3_max_loop_1\n" + "4: @exit\n" + : [dr0] "+r"(dr0), + [dr1] "+r"(dr1), + [dr_out] "+r"(dr_out), + [cnt_num] "+r"(cnt_num), + [cnt_num_remain] "+r"(cnt_num_remain), + [vcoef_6] "+w"(vcoef_6), + [vzero] "+w"(vzero) + : "r"(dr0), "r"(dr1), "r"(dr_out), "r"(cnt_num), "r"(cnt_num_remain) + : "cc", + "memory", + "q0", + "q1", + "q2", + "q3", + "q4", + "q5", + "q6", + "q7", + "q8", + "q9"); + } +#endif + // int w = w_even - 1; + if (w_remain > 0) { + // deal with right pad + int wstart = (w_even >> 1) * stride - padding; + int wend = std::min(std::min(wstart + kernel, win + padding), win); + float tmp1 = 0.f; // std::numeric_limits::min(); + float tmp2 = exclusive ? 
1.0f / (2.f * (wend - wstart)) : coef; + for (int i = wstart; i < wend; i++) { // only run 1 or 2 times + tmp1 += (r0[i] + r1[i]); + } + dout_ch[w_even >> 1] = tmp1 * tmp2; + // cnt ++; + } + + r0 = r1; + r1 = r0 + win; + r2 = r1 + win; + dout_ch += wout; + int h = 2; + for (; h < h_even; h += 2) { + // deal with left pad + float sum0 = r0[0] + r0[1]; + float sum1 = r1[0] + r1[1]; + float sum2 = r2[0] + r2[1]; + dout_ch[0] = (sum0 + sum1 + sum2) * coef_6; +#ifdef __aarch64__ + w = 1; + cnt = 1; + for (; w < w_unroll_size; w += 8) { + float32x4_t vr0_1234 = vld1q_f32(&r0[w]); + float32x4_t vr0_5678 = vld1q_f32(&r0[w + 4]); + float32x4_t vr0_9101112 = vld1q_f32(&r0[w + 8]); + float32x4_t vr1_1234 = vld1q_f32(&r1[w]); + float32x4_t vr1_5678 = vld1q_f32(&r1[w + 4]); + float32x4_t vr1_9101112 = vld1q_f32(&r1[w + 8]); + float32x4_t vr2_1234 = vld1q_f32(&r2[w]); + float32x4_t vr2_5678 = vld1q_f32(&r2[w + 4]); + float32x4_t vr2_9101112 = vld1q_f32(&r2[w + 8]); + float32x4_t vsum_1234 = vaddq_f32(vr0_1234, vr1_1234); + float32x4_t vsum_5678 = vaddq_f32(vr0_5678, vr1_5678); + float32x4_t vsum_9101112 = vaddq_f32(vr0_9101112, vr1_9101112); + vsum_1234 = vaddq_f32(vsum_1234, vr2_1234); + vsum_5678 = vaddq_f32(vsum_5678, vr2_5678); + vsum_9101112 = vaddq_f32(vsum_9101112, vr2_9101112); + + float32x4_t vsum_2345 = vextq_f32(vsum_1234, vsum_5678, 1); + float32x4_t vsum_3456 = vextq_f32(vsum_1234, vsum_5678, 2); + float32x4_t vsum_4567 = vextq_f32(vsum_1234, vsum_5678, 3); + float32x4_t vsum_6789 = vextq_f32(vsum_5678, vsum_9101112, 1); + float32x4_t vsum_123_345 = vaddq_f32(vsum_1234, vsum_2345); + vsum_123_345 = vaddq_f32(vsum_123_345, vsum_3456); + float32x4_t vsum_567_789 = vaddq_f32(vsum_4567, vsum_5678); + vsum_567_789 = vaddq_f32(vsum_567_789, vsum_6789); + vsum_123_345 = + vsetq_lane_f32(vgetq_lane_f32(vsum_123_345, 2), vsum_123_345, 1); + vsum_123_345 = + vsetq_lane_f32(vgetq_lane_f32(vsum_567_789, 1), vsum_123_345, 2); + vsum_123_345 = + vsetq_lane_f32(vgetq_lane_f32(vsum_567_789, 3), vsum_123_345, 3); + float32x4_t vrst = vmulq_f32(vsum_123_345, vcoef); + vst1q_f32(&dout_ch[cnt], vrst); + cnt += 4; + } + for (; w < w_even - 1; w += 2) { + float32x4_t vr0 = vld1q_f32(&r0[w]); + float32x4_t vr1 = vld1q_f32(&r1[w]); + float32x4_t vr2 = vld1q_f32(&r2[w]); + vr0 = vsetq_lane_f32(0.f, vr0, 3); + vr1 = vsetq_lane_f32(0.f, vr1, 3); + vr2 = vsetq_lane_f32(0.f, vr2, 3); + float32x4_t vsum1 = vaddq_f32(vr0, vr1); + vsum1 = vaddq_f32(vsum1, vr2); + float32x2_t vsum2 = + vpadd_f32(vget_low_f32(vsum1), vget_high_f32(vsum1)); + float32x2_t vsum = vpadd_f32(vsum2, vsum2); + dout_ch[cnt] = vget_lane_f32(vsum, 0) * coef; + cnt++; + } +#else + dr_out = dout_ch + 1; + dr0 = (r0 + 1); + dr1 = (r1 + 1); + dr2 = (r2 + 1); + cnt_num = w_unroll_size >> 3; + cnt_num_remain = w_unroll_remain >> 1; + if (cnt_num > 0 || cnt_num_remain > 0) { + asm volatile( + "cmp %[cnt_num], #0 @cmp cnt_num,0\n" + "ble 3f @ble exit\n" + "1: @main loop\n" + "vld1.f32 {d0-d3}, [%[dr0]]! @load d0-d5, " + "dr0\n" + "vld1.f32 {d6-d9}, [%[dr1]]! @load d4-d7,dr1\n" + "vld1.f32 {d12-d15}, [%[dr2]]! @load d4-d7,dr1\n" + "vld1.f32 {d4-d5}, [%[dr0]]! @load d0-d5,dr0\n" + "vld1.f32 {d10-d11}, [%[dr1]]! @load d4-d7,dr1\n" + "vld1.f32 {d16-d17}, [%[dr2]]! 
@load d4-d7,dr1\n" + "vadd.f32 q9, q0, q3 @max q0,q0,q2\n" + "vadd.f32 q10, q1, q4 @max q1,q1,q3\n" + "vadd.f32 q11, q2, q5 @max q1,q1,q3\n" + "vadd.f32 q6, q9, q6 @max q0,q0,q2 1234\n" + "vadd.f32 q7, q10, q7 @max q1,q1,q3 5678\n" + "vadd.f32 q8, q11, q8 @max q1,q1,q3 9101112\n" + //"vmov.f32 s7,s6 @mov s7, s6\n" + "vext.f32 q0, q6, q7, #1 @vext max_2345\n" + "vext.f32 q1, q6, q7, #3 @vext max_4567\n" + "vext.f32 q2, q6, q7, #2 @vext max_3456\n" + "vext.f32 q3, q7, q8, #1 @vext max_6789\n" + "vadd.f32 q4, q6, q0 @add 1234,2345\n" + "vadd.f32 q5, q7, q1 @add 5678,4567\n" + "vadd.f32 q4, q4, q2 @add 3456,sum1\n" + "vadd.f32 q5, q5, q3 @add 6789,sum2\n" + "vmov.f32 s17, s18 @mov\n" + "vmov.f32 s18, s21 @mov\n" + "vmov.f32 s19, s23 @mov\n" + "vmul.f32 q4, q4, %q[vcoef] @mul\n" + "sub %[dr0], #16 @add w,8\n" + "sub %[dr1], #16 @add w,8\n" + "sub %[dr2], #16 @add w, 8\n" + "subs %[cnt_num], #1 @subs cnt_num,#1\n" + "vst1.f32 d8, [%[dr_out]]! @vst1 d0,dr_out\n" + "vst1.f32 d9, [%[dr_out]]! @vst1 d0,dr_out\n" + "bne 1b @bne s3_max_loop_mid\n" + "3: @loop\n" + "cmp %[cnt_num_remain], #0 @cnt_num_remain<=0\n" + "ble 4f @ble exit1\n" + "2: @mid loop\n" + "vld1.f32 {d0-d1}, [%[dr0]]! @load d0-d1,dr0\n" + "vld1.f32 {d2-d3}, [%[dr1]]! @load d2-d3,dr1\n" + "vld1.f32 {d4-d5}, [%[dr2]]! @load d2-d3,dr1\n" + "vext.f32 q0, %q[vzero], q0, #3 @ext v0_0123\n" + "vext.f32 q1, %q[vzero], q1, #3 @ext v1_0123\n" + "vext.f32 q2, %q[vzero], q2, #3 @ext v1_0123\n" + "vadd.f32 q0, q0, q1 @add q0,q0,q1\n" + "vadd.f32 q0, q0, q2 @add q0,q0,q1\n" + "vpadd.f32 d0, d0, d1 @padd d0,d0,d1\n" + "vpadd.f32 d0, d0, d0 @padd d0,d0,d0\n" + "vmul.f32 d0, d0, %e[vcoef] @mul\n" + "sub %[dr0], #8 @add w,6\n" + "sub %[dr1], #8 @add w,6\n" + "sub %[dr2], #8 @add w,6\n" + "subs %[cnt_num_remain], #1 @cnt_num_remain--\n" + "vst1.f32 d0[0], [%[dr_out]]! @vst d0[0],dr_out\n" + "bne 2b @bne s3_max_loop_mid_1\n" + "4: @exit\n" + : [dr0] "+r"(dr0), + [dr1] "+r"(dr1), + [dr2] "+r"(dr2), + [dr_out] "+r"(dr_out), + [cnt_num] "+r"(cnt_num), + [cnt_num_remain] "+r"(cnt_num_remain), + [vcoef] "+w"(vcoef), + [vzero] "+w"(vzero) + : "r"(dr0), + "r"(dr1), + "r"(dr2), + "r"(dr_out), + "r"(cnt_num), + "r"(cnt_num_remain) + : "cc", + "memory", + "q0", + "q1", + "q2", + "q3", + "q4", + "q5", + "q6", + "q7", + "q8", + "q9", + "q10", + "q11", + "q12"); + } +#endif + if (w_remain > 0) { + // deal with right pad + int wstart = (w_even >> 1) * stride - padding; + int wend = std::min(std::min(wstart + kernel, win + padding), win); + float tmp1 = 0.f; + float tmp2 = exclusive ? 
1.0f / (3.f * (wend - wstart)) : coef; + for (int i = wstart; i < wend; i++) { + tmp1 += (r0[i] + r1[i] + r2[i]); + } + dout_ch[w_even >> 1] = tmp1 * tmp2; + // cnt ++; + } + r0 = r2; + r1 = r0 + win; + r2 = r1 + win; + dout_ch += wout; + } + + if (h_remain > 0) { + // deal with bottom pad + // first row with zero pad + int hstart = (h >> 1) * stride - padding; + int hend = std::min(std::min(hstart + kernel, hin + padding), hin); + if (hstart == hend - 1) { // only one line + dout_ch[0] = (r0[0] + r0[1]) * coef_2; +#ifdef __aarch64__ + w = 1; + cnt = 1; + for (; w < w_unroll_size; w += 8) { + float32x4_t vsum_1234 = vld1q_f32(&r0[w]); + float32x4_t vsum_5678 = vld1q_f32(&r0[w + 4]); + float32x4_t vsum_9101112 = vld1q_f32(&r0[w + 8]); + + float32x4_t vsum_2345 = vextq_f32(vsum_1234, vsum_5678, 1); + float32x4_t vsum_3456 = vextq_f32(vsum_1234, vsum_5678, 2); + float32x4_t vsum_4567 = vextq_f32(vsum_1234, vsum_5678, 3); + float32x4_t vsum_6789 = vextq_f32(vsum_5678, vsum_9101112, 1); + float32x4_t vsum_123_345 = vaddq_f32(vsum_1234, vsum_2345); + vsum_123_345 = vaddq_f32(vsum_123_345, vsum_3456); + float32x4_t vsum_567_789 = vaddq_f32(vsum_4567, vsum_5678); + vsum_567_789 = vaddq_f32(vsum_567_789, vsum_6789); + vsum_123_345 = vsetq_lane_f32( + vgetq_lane_f32(vsum_123_345, 2), vsum_123_345, 1); + vsum_123_345 = vsetq_lane_f32( + vgetq_lane_f32(vsum_567_789, 1), vsum_123_345, 2); + vsum_123_345 = vsetq_lane_f32( + vgetq_lane_f32(vsum_567_789, 3), vsum_123_345, 3); + float32x4_t vrst = vmulq_f32(vsum_123_345, vcoef_3); + vst1q_f32(&dout_ch[cnt], vrst); + cnt += 4; + } + for (; w < w_even - 1; w += 2) { + float32x4_t vr0 = vld1q_f32(&r0[w]); + vr0 = vsetq_lane_f32(0.f, vr0, 3); + float32x2_t vsum = vpadd_f32(vget_low_f32(vr0), vget_high_f32(vr0)); + vsum = vpadd_f32(vsum, vsum); + dout_ch[cnt] = vget_lane_f32(vsum, 0) * coef_3; + cnt++; + } +#else + dr_out = dout_ch + 1; + dr0 = (r0 + 1); + cnt_num = w_unroll_size >> 3; + cnt_num_remain = w_unroll_remain >> 1; + if (cnt_num > 0 || cnt_num_remain > 0) { + asm volatile( + "cmp %[cnt_num], #0 @cmp cnt_num,0\n" + "ble 3f @ble exit\n" + "1: @main loop\n" + "vld1.f32 {d12-d15}, [%[dr0]]! @load d0-d3,dr0\n" + "vld1.f32 {d16-d17}, [%[dr0]]! @load d0-d3,dr0\n" + "vext.f32 q0, q6, q7, #1 @vext max_2345\n" + "vext.f32 q1, q6, q7, #3 @vext max_4567\n" + "vext.f32 q2, q6, q7, #2 @vext max_3456\n" + "vext.f32 q3, q7, q8, #1 @vext max_6789\n" + "vadd.f32 q4, q6, q0 @add 1234,2345\n" + "vadd.f32 q5, q7, q1 @add 5678,4567\n" + "vadd.f32 q4, q4, q2 @add 3456,sum1\n" + "vadd.f32 q5, q5, q3 @add 6789,sum2\n" + "vmov.f32 s17, s18 @mov\n" + "vmov.f32 s18, s21 @mov\n" + "vmov.f32 s19, s23 @mov\n" + "vmul.f32 q4, q4, %q[vcoef_3] @mul\n" + "sub %[dr0], #16 @add w,6\n" + "subs %[cnt_num], #1 @subs cnt_num,#1\n" + "vst1.f32 d8, [%[dr_out]]! @vst1 d0,dr_out\n" + "vst1.f32 d9, [%[dr_out]]! @vst1 d0,dr_out\n" + "bne 1b @bne s3_max_loop_bot\n" + "3: @loop\n" + "cmp %[cnt_num_remain], #0 @cnt_num_remain<=0\n" + "ble 4f @ble exit\n" + "2: @bot loop\n" + "vld1.f32 {d0-d1}, [%[dr0]]! @load d0-d1,dr0\n" + "vext.f32 q0, %q[vzero], q0, #3 @ext v0_0123\n" + "vpadd.f32 d0, d0, d1 @padd d0,d0,d1\n" + "vpadd.f32 d0, d0, d0 @padd d0,d0,d0\n" + "vmul.f32 d0, d0, %e[vcoef_3] @mul\n" + "sub %[dr0], #8 @add w,2\n" + "subs %[cnt_num_remain], #1 @cnt_num_remain--\n" + "vst1.f32 d0[0], [%[dr_out]]! 
@vst d0[0],dr_out\n" + "bne 2b @bne s3_max_loop_bot_1\n" + "4: @exit\n" + : [dr0] "+r"(dr0), + [dr1] "+r"(dr1), + [dr_out] "+r"(dr_out), + [cnt_num] "+r"(cnt_num), + [cnt_num_remain] "+r"(cnt_num_remain), + [vcoef_3] "+w"(vcoef_3), + [vzero] "+w"(vzero) + : "r"(dr0), + "r"(dr1), + "r"(dr_out), + "r"(cnt_num), + "r"(cnt_num_remain) + : "cc", + "memory", + "q0", + "q1", + "q2", + "q3", + "q4", + "q5", + "q6", + "q7", + "q8"); + } +#endif + if (w_remain > 0) { + // deal with right pad + int wstart = (w_even >> 1) * stride - padding; + int wend = std::min(std::min(wstart + kernel, win + padding), win); + float tmp1 = 0.f; + float tmp2 = exclusive ? 1.0f / (1.f * (wend - wstart)) : coef; + for (int i = wstart; i < wend; i++) { + tmp1 += r0[i]; + } + dout_ch[w_even >> 1] = tmp1 * tmp2; + } + } else { // two lines + dout_ch[0] = (r0[0] + r0[1] + r1[0] + r1[1]) * coef_4; +#ifdef __aarch64__ + w = 1; + cnt = 1; + for (; w < w_unroll_size; w += 8) { + float32x4_t vr0_1234 = vld1q_f32(&r0[w]); + float32x4_t vr0_5678 = vld1q_f32(&r0[w + 4]); + float32x4_t vr0_9101112 = vld1q_f32(&r0[w + 8]); + float32x4_t vr1_1234 = vld1q_f32(&r1[w]); + float32x4_t vr1_5678 = vld1q_f32(&r1[w + 4]); + float32x4_t vr1_9101112 = vld1q_f32(&r1[w + 8]); + + float32x4_t vsum_1234 = vaddq_f32(vr0_1234, vr1_1234); + float32x4_t vsum_5678 = vaddq_f32(vr0_5678, vr1_5678); + float32x4_t vsum_9101112 = vaddq_f32(vr0_9101112, vr1_9101112); + float32x4_t vsum_2345 = vextq_f32(vsum_1234, vsum_5678, 1); + float32x4_t vsum_3456 = vextq_f32(vsum_1234, vsum_5678, 2); + float32x4_t vsum_4567 = vextq_f32(vsum_1234, vsum_5678, 3); + float32x4_t vsum_6789 = vextq_f32(vsum_5678, vsum_9101112, 1); + float32x4_t vsum_123_345 = vaddq_f32(vsum_1234, vsum_2345); + vsum_123_345 = vaddq_f32(vsum_123_345, vsum_3456); + float32x4_t vsum_567_789 = vaddq_f32(vsum_4567, vsum_5678); + vsum_567_789 = vaddq_f32(vsum_567_789, vsum_6789); + vsum_123_345 = vsetq_lane_f32( + vgetq_lane_f32(vsum_123_345, 2), vsum_123_345, 1); + vsum_123_345 = vsetq_lane_f32( + vgetq_lane_f32(vsum_567_789, 1), vsum_123_345, 2); + vsum_123_345 = vsetq_lane_f32( + vgetq_lane_f32(vsum_567_789, 3), vsum_123_345, 3); + float32x4_t vrst = vmulq_f32(vsum_123_345, vcoef_6); + vst1q_f32(&dout_ch[cnt], vrst); + cnt += 4; + } + for (; w < w_even - 1; w += 2) { + float32x4_t vr0 = vld1q_f32(&r0[w]); + float32x4_t vr1 = vld1q_f32(&r1[w]); + vr0 = vsetq_lane_f32(0.f, vr0, 3); + vr1 = vsetq_lane_f32(0.f, vr1, 3); + float32x4_t vsum1 = vaddq_f32(vr0, vr1); + float32x2_t vsum2 = + vpadd_f32(vget_low_f32(vsum1), vget_high_f32(vsum1)); + vsum2 = vpadd_f32(vsum2, vsum2); + float32x2_t vrst = vmul_f32(vsum2, vget_low_f32(vcoef_6)); + dout_ch[cnt] = vget_lane_f32(vrst, 0); + cnt++; + } +#else + dr_out = dout_ch + 1; + dr0 = (r0 + 1); + dr1 = (r1 + 1); + cnt_num = w_unroll_size >> 3; + cnt_num_remain = w_unroll_remain >> 1; + if (cnt_num > 0 || cnt_num_remain > 0) { + asm volatile( + "cmp %[cnt_num], #0 @cmp cnt_num,0\n" + "ble 3f @ble exit\n" + "1: @main loop\n" + "vld1.f32 {d0-d3}, [%[dr0]]! @load d0-d5,dr0\n" + "vld1.f32 {d6-d9}, [%[dr1]]! @load d4-d7,dr1\n" + "vld1.f32 {d4-d5}, [%[dr0]]! @load d0-d3,dr0\n" + "vld1.f32 {d10-d11}, [%[dr1]]! 
@load d4-d7,dr1\n" + "vadd.f32 q6, q0, q3 @add q0,q0,q2 1234\n" + "vadd.f32 q7, q1, q4 @add q1,q1,q3 5678\n" + "vadd.f32 q8, q2, q5 @add q1,q1,q3 9101112\n" + //"vmov.f32 s7,s6 @mov s7,s6\n" + "vext.f32 q0, q6, q7, #1 @vext max_2345\n" + "vext.f32 q1, q6, q7, #3 @vext max_4567\n" + "vext.f32 q2, q6, q7, #2 @vext max_3456\n" + "vext.f32 q3, q7, q8, #1 @vext max_6789\n" + "vadd.f32 q4, q6, q0 @add 1234,2345\n" + "vadd.f32 q5, q7, q1 @add 5678,4567\n" + "vadd.f32 q4, q4, q2 @add 3456,sum1\n" + "vadd.f32 q5, q5, q3 @add 6789,sum2\n" + "vmov.f32 s17, s18 @mov\n" + "vmov.f32 s18, s21 @mov\n" + "vmov.f32 s19, s23 @mov\n" + "vmul.f32 q4, q4, %q[vcoef_6] @mul\n" + "sub %[dr0], #16 @add w,8\n" + "sub %[dr1], #16 @add w,8\n" + "subs %[cnt_num], #1 @subs cnt_num,#1\n" + "vst1.f32 d8, [%[dr_out]]! @vst1 d0,dr_out\n" + "vst1.f32 d9, [%[dr_out]]! @vst1 d0, dr_out\n" + "bne 1b @bne s3_max_loop_bot\n" + "3: @loop\n" + "cmp %[cnt_num_remain], #0 @cnt_num_remain<=0\n" + "ble 4f @ble exit\n" + "2: @bot loop\n" + "vld1.f32 {d0-d1}, [%[dr0]]! @load d0-d1,dr0\n" + "vld1.f32 {d2-d3}, [%[dr1]]! @load d2-d3,dr1\n" + "vext.f32 q0, %q[vzero], q0, #3 @ext v0_0123\n" + "vext.f32 q1, %q[vzero], q1, #3 @ext v1_0123\n" + "vadd.f32 q0, q0, q1 @add q0,q0,q1\n" + "vpadd.f32 d0, d0, d1 @padd d0,d0,d1\n" + "vpadd.f32 d0, d0, d0 @padd d0,d0,d0\n" + "vmul.f32 d0, d0, %e[vcoef_6] @mul\n" + "sub %[dr0], #8 @add w,6\n" + "sub %[dr1], #8 @add w,6\n" + "subs %[cnt_num_remain], #1 @cnt_num_remain--\n" + "vst1.f32 d0[0], [%[dr_out]]! @vst d0[0],dr_out\n" + "bne 2b @bne s3_max_loop_bot_1\n" + "4: @exit\n" + : [dr0] "+r"(dr0), + [dr1] "+r"(dr1), + [dr_out] "+r"(dr_out), + [cnt_num] "+r"(cnt_num), + [cnt_num_remain] "+r"(cnt_num_remain), + [vcoef_6] "+w"(vcoef_6), + [vzero] "+w"(vzero) + : "r"(dr0), + "r"(dr1), + "r"(dr_out), + "r"(cnt_num), + "r"(cnt_num_remain) + : "cc", + "memory", + "q0", + "q1", + "q2", + "q3", + "q4", + "q5", + "q6", + "q7", + "q8", + "q9"); + } +#endif + if (w_remain > 0) { + // deal with right pad + int wstart = (w_even >> 1) * stride - padding; + int wend = std::min(std::min(wstart + kernel, win + padding), win); + float tmp1 = 0.f; + float tmp2 = exclusive ? 1.0f / (2.f * (wend - wstart)) : coef; + for (int i = wstart; i < wend; i++) { // only run 1 or 2 times + tmp1 += (r0[i] + r1[i]); + } + dout_ch[w_even >> 1] = tmp1 * tmp2; + } + } + } + } + } +} + +void pooling3x3s2p0_max(const float* din, + float* dout, + int num, + int chout, + int hout, + int wout, + int chin, + int hin, + int win) { + int kernel = 3; + int stride = 2; + int padding = 0; + int size_channel_out = wout * hout; + int size_channel_in = win * hin; + + int w_needed = (wout << 1) + 1; + int h_needed = (hout << 1) + 1; + int w_limit = w_needed > win ? win : w_needed; + int h_limit = h_needed > hin ? 
hin : h_needed; + int w_even = ((w_limit - 1) >> 1) << 1; + int h_even = ((h_limit - 1) >> 1) << 1; + int w_unroll_size = (w_even >> 3) << 3; + int w_unroll_remain = w_even - w_unroll_size; + int w_remain = w_needed - w_limit; + int h_remain = h_needed - h_limit; + int w_in_2 = win << 1; + float minval = std::numeric_limits<float>::lowest(); + for (int n = 0; n < num; ++n) { + float* dout_batch = dout + n * chout * size_channel_out; + const float* din_batch = din + n * chin * size_channel_in; +#pragma omp parallel for + for (int c = 0; c < chout; c++) { + float* dout_ch = dout_batch + c * size_channel_out; + const float* din_ch = din_batch + c * size_channel_in; + const float* r0 = din_ch; + const float* r1 = r0 + win; + const float* r2 = r1 + win; + // w = w_in - 8; + float* dr_out = dout_ch; + const float* dr0 = r0; + const float* dr1 = r1; + const float* dr2 = r2; + int w = 0; + int cnt = 0; + // dout_ch[0] = std::max(std::max(r0[0], r0[1]), std::max(r1[0], + // r1[1])); + // first row with zero pad + // r0 = r1; + // r1 = r0 + w_in; + // r2 = r1 + w_in; + // dout_channel += w_out; + int h = 0; + for (; h < h_even; h += 2) { + // deal with left pad + float maxr0 = std::max(r0[0], r0[1]); + float maxr1 = std::max(r1[0], r1[1]); + float maxr2 = std::max(r2[0], r2[1]); +// dout_ch[0] = std::max(std::max(maxr0, maxr1), maxr2); +#ifdef __aarch64__ + w = 0; + cnt = 0; + for (; w < w_unroll_size; w += 8) { + float32x4_t vr0_1234 = vld1q_f32(&r0[w]); + float32x4_t vr0_5678 = vld1q_f32(&r0[w + 4]); + float32x4_t vr0_9101112 = vld1q_f32(&r0[w + 8]); + float32x4_t vr1_1234 = vld1q_f32(&r1[w]); + float32x4_t vr1_5678 = vld1q_f32(&r1[w + 4]); + float32x4_t vr1_9101112 = vld1q_f32(&r1[w + 8]); + float32x4_t vr2_1234 = vld1q_f32(&r2[w]); + float32x4_t vr2_5678 = vld1q_f32(&r2[w + 4]); + float32x4_t vr2_9101112 = vld1q_f32(&r2[w + 8]); + float32x4_t vmax_1234 = vmaxq_f32(vr0_1234, vr1_1234); + vmax_1234 = vmaxq_f32(vmax_1234, vr2_1234); + float32x4_t vmax_5678 = vmaxq_f32(vr0_5678, vr1_5678); + vmax_5678 = vmaxq_f32(vmax_5678, vr2_5678); + float32x4_t vmax_9101112 = vmaxq_f32(vr0_9101112, vr1_9101112); + vmax_9101112 = vmaxq_f32(vmax_9101112, vr2_9101112); + float32x4_t vmax_2345 = vextq_f32(vmax_1234, vmax_5678, 1); + float32x4_t vmax_6789 = vextq_f32(vmax_5678, vmax_9101112, 1); + float32x2_t vmax_12_34 = + vpmax_f32(vget_low_f32(vmax_1234), vget_high_f32(vmax_1234)); + float32x2_t vmax_23_45 = + vpmax_f32(vget_low_f32(vmax_2345), vget_high_f32(vmax_2345)); + float32x2_t vmax_56_78 = + vpmax_f32(vget_low_f32(vmax_5678), vget_high_f32(vmax_5678)); + float32x2_t vmax_67_89 = + vpmax_f32(vget_low_f32(vmax_6789), vget_high_f32(vmax_6789)); + float32x2_t vmax_123_345 = vmax_f32(vmax_12_34, vmax_23_45); + float32x2_t vmax_567_789 = vmax_f32(vmax_56_78, vmax_67_89); + vst1_f32(&dout_ch[cnt], vmax_123_345); + vst1_f32(&dout_ch[cnt + 2], vmax_567_789); + cnt += 4; + } + for (; w < w_even; w += 2) { + float32x4_t vr0 = vld1q_f32(&r0[w]); + float32x4_t vr1 = vld1q_f32(&r1[w]); + float32x4_t vr2 = vld1q_f32(&r2[w]); + vr0 = vsetq_lane_f32(minval, vr0, 3); + vr1 = vsetq_lane_f32(minval, vr1, 3); + vr2 = vsetq_lane_f32(minval, vr2, 3); + float32x4_t vmax1 = vmaxq_f32(vr0, vr1); + vmax1 = vmaxq_f32(vmax1, vr2); + float32x2_t vmax2 = + vpmax_f32(vget_low_f32(vmax1), vget_high_f32(vmax1)); + float32x2_t vmax = vpmax_f32(vmax2, vmax2); + dout_ch[cnt] = vget_lane_f32(vmax, 0); + cnt++; + } +#else + dr_out = dout_ch; // + 1; + dr0 = r0; // (r0 + 1); + dr1 = r1; // (r1 + 1); + dr2 = r2; // (r2 + 1); + int cnt_num =
w_unroll_size >> 3; + int cnt_num_remain = w_unroll_remain >> 1; + if (cnt_num > 0 || cnt_num_remain > 0) { + asm volatile( + "cmp %[cnt_num], #0 @cmp cnt_num,0\n" + "ble 3f @ble exit\n" + "1: @main loop\n" + "vld1.f32 {d0-d3}, [%[dr0]]! @load d0-d5,dr0\n" + "vld1.f32 {d6-d9}, [%[dr1]]! @load d4-d7,dr1\n" + "vld1.f32 {d12-d15}, [%[dr2]]! @load d4-d7,dr1\n" + "vld1.f32 {d4}, [%[dr0]]! @load d0-d5,dr0\n" + "vld1.f32 {d10}, [%[dr1]]! @load d4-d7,dr1\n" + "vld1.f32 {d16}, [%[dr2]]! @load d4-d7,dr1\n" + "vmax.f32 q9, q0, q3 @max q0,q0,q2\n" + "vmax.f32 q10, q1, q4 @max q1,q1,q3\n" + "vmax.f32 d22, d4, d10 @max q1,q1,q3\n" + "vmax.f32 q0, q9, q6 @max q0,q0,q2 1234\n" + "vmax.f32 q3, q10, q7 @max q1,q1,q3 5678\n" + "vmax.f32 d2, d22, d16 @max q1,q1,q3 9101112\n" + //"vmov.f32 s7,s6 @mov s7, s6\n" + "vext.f32 q4, q0, q3, #1 @vext 2345\n" + "vext.f32 q2, q3, q1, #1 @vext 6789\n" + "vpmax.f32 d10, d0, d1 @pmax " + "d10,vmax_1234,vmax_1234\n" + "vpmax.f32 d12, d6, d7 @pmax " + "d12,vmax_5678,vmax_5678\n" + "vpmax.f32 d11, d8, d9 @pmax " + "d11,vmax_2345,vmax_2345\n" + "vpmax.f32 d13, d4, d5 @pmax " + "d13,vmax_6789,vmax_6789\n" + "vmax.f32 d0, d10, d11 @pmax " + "d0,vmax_12_34,vmax_23_45\n" + "vmax.f32 d1, d12, d13 @pmax " + "d1,vmax_56_78,vmax_67_89\n" + "sub %[dr0], #8 @add w,8\n" + "sub %[dr1], #8 @add w,8\n" + "sub %[dr2], #8 @add w,8\n" + "vst1.f32 d0, [%[dr_out]]! @vst1 d0,dr_out\n" + "vst1.f32 d1, [%[dr_out]]! @vst1 d0,dr_out\n" + "subs %[cnt_num], #1 @cnt_num--\n" + "bne 1b @bne s3_max_loop_mid\n" + "3: @loop\n" + "cmp %[cnt_num_remain], #0 @cmp cnt_num_remain,0\n" + "ble 4f @ble exit1\n" + "2: @mid loop\n" + "vld1.f32 {d0-d1}, [%[dr0]]! @load d0-d1,dr0\n" + "vld1.f32 {d2-d3}, [%[dr1]]! @load d2-d3,dr1\n" + "vld1.f32 {d4-d5}, [%[dr2]]! @load d2-d3,dr1\n" + "vmov.f32 s3,s2 @movs3,s2\n" + "vmov.f32 s7,s6 @movs7,s6\n" + "vmov.f32 s11,s10 @movs11,s10\n" + "vmax.f32 q0, q0, q1 @max q0,q0,q1\n" + "vmax.f32 q0, q0, q2 @max q0,q0,q2\n" + "vpmax.f32 d0, d0, d1 @pmax d0,d0,d1\n" + "vpmax.f32 d0, d0, d0 @pmax d0,d0,d0\n" + "vst1.f32 d0[0], [%[dr_out]]! 
@vst d0[0],dr_out\n" + "sub %[dr0], #8 @add w,6\n" + "sub %[dr1], #8 @add w,6\n" + "sub %[dr2], #8 @add w,6\n" + "subs %[cnt_num_remain], #1 @cnt_num_remain--\n" + "bne 2b @bne s3_max_loop_mid_1\n" + "4: @exit\n" + : [dr0] "+r"(dr0), + [dr1] "+r"(dr1), + [dr2] "+r"(dr2), + [dr_out] "+r"(dr_out), + [cnt_num] "+r"(cnt_num), + [cnt_num_remain] "+r"(cnt_num_remain) + : "r"(dr0), + "r"(dr1), + "r"(dr2), + "r"(dr_out), + "r"(cnt_num), + "r"(cnt_num_remain) + : "cc", + "memory", + "q0", + "q1", + "q2", + "q3", + "q4", + "q5", + "q6", + "q7", + "q8", + "q9", + "q10", + "q11", + "q12"); + } +#endif + if (w_remain > 0) { + // deal with right pad + int wstart = (w_even >> 1) * stride - padding; + int wend = std::min(std::min(wstart + kernel, win + padding), win); + float tmp = r0[wstart]; // std::numeric_limits::min(); + for (int i = wstart; i < wend; i++) { + tmp = std::max(tmp, std::max(r0[i], r1[i])); + tmp = std::max(tmp, r2[i]); + } + dout_ch[w_even >> 1] = tmp; + // cnt ++; + } + r0 = r2; + r1 = r0 + win; + r2 = r1 + win; + dout_ch += wout; + } + + if (h_remain > 0) { +// deal with bottom pad +// first row with zero pad +// int hstart = (h >> 1) * stride_h - pad_h; +// int hend = std::min(std::min(hstart + kernel_h, hin + pad_h), hin); +// dout_ch[0] = std::max(std::max(r0[0], r0[1]), std::max(r1[0], +// r1[1])); +#ifdef __aarch64__ + w = 0; + cnt = 0; + for (; w < w_unroll_size; w += 8) { + float32x4_t vr0_1234 = vld1q_f32(&r0[w]); + float32x4_t vr0_5678 = vld1q_f32(&r0[w + 4]); + float32x4_t vr0_9101112 = vld1q_f32(&r0[w + 8]); + float32x4_t vr1_1234 = vld1q_f32(&r1[w]); + float32x4_t vr1_5678 = vld1q_f32(&r1[w + 4]); + float32x4_t vr1_9101112 = vld1q_f32(&r1[w + 8]); + float32x4_t vmax_1234 = vmaxq_f32(vr0_1234, vr1_1234); + float32x4_t vmax_5678 = vmaxq_f32(vr0_5678, vr1_5678); + float32x4_t vmax_9101112 = vmaxq_f32(vr0_9101112, vr1_9101112); + float32x4_t vmax_2345 = vextq_f32(vmax_1234, vmax_5678, 1); + float32x4_t vmax_6789 = vextq_f32(vmax_5678, vmax_9101112, 1); + float32x2_t vmax_12_34 = + vpmax_f32(vget_low_f32(vmax_1234), vget_high_f32(vmax_1234)); + float32x2_t vmax_23_45 = + vpmax_f32(vget_low_f32(vmax_2345), vget_high_f32(vmax_2345)); + float32x2_t vmax_56_78 = + vpmax_f32(vget_low_f32(vmax_5678), vget_high_f32(vmax_5678)); + float32x2_t vmax_67_89 = + vpmax_f32(vget_low_f32(vmax_6789), vget_high_f32(vmax_6789)); + float32x2_t vmax_123_345 = vmax_f32(vmax_12_34, vmax_23_45); + float32x2_t vmax_567_789 = vmax_f32(vmax_56_78, vmax_67_89); + vst1_f32(&dout_ch[cnt], vmax_123_345); + vst1_f32(&dout_ch[cnt + 2], vmax_567_789); + cnt += 4; + } + for (; w < w_even; w += 2) { + float32x4_t vr0 = vld1q_f32(&r0[w]); + float32x4_t vr1 = vld1q_f32(&r1[w]); + vr0 = vsetq_lane_f32(minval, vr0, 3); + vr1 = vsetq_lane_f32(minval, vr1, 3); + float32x4_t vmax1 = vmaxq_f32(vr0, vr1); + float32x2_t vmax2 = + vpmax_f32(vget_low_f32(vmax1), vget_high_f32(vmax1)); + vmax2 = vpmax_f32(vmax2, vmax2); + dout_ch[cnt] = vget_lane_f32(vmax2, 0); + cnt++; + } +#else + dr_out = dout_ch; // + 1; + dr0 = r0; // (r0 + 1); + dr1 = r1; // (r1 + 1); + int cnt_num = w_unroll_size >> 3; + int cnt_num_remain = w_unroll_remain >> 1; + if (cnt_num > 0 || cnt_num_remain > 0) { + asm volatile( + "cmp %[cnt_num], #0 @cmp cnt_num,0\n" + "ble 3f @ble exit\n" + "1: @main loop\n" + "vld1.f32 {d0-d3}, [%[dr0]]! @load d0-d5,dr0\n" + "vld1.f32 {d6-d9}, [%[dr1]]! @load d4-d7,dr1\n" + "vld1.f32 {d4}, [%[dr0]]! @load d0-d3,dr0\n" + "vld1.f32 {d10}, [%[dr1]]! 
@load d4-d7,dr1\n" + "vmax.f32 q6, q0, q3 @max q0,q0,q2 1234\n" + "vmax.f32 q7, q1, q4 @max q1,q1,q3 5678\n" + "vmax.f32 d16, d4, d10 @max q1,q1,q3 9101112\n" + //"vmov.f32 s7,s6 @mov s7,s6\n" + "vext.f32 q0, q6, q7, #1 @vext q0,2345\n" + "vext.f32 q1, q7, q8, #1 @vext q1,6789\n" + "vpmax.f32 d4, d12, d13 @pmax " + "d4,vmax_1234,vmax_1234\n" + "vpmax.f32 d6, d14, d15 @pmax " + "d6,vmax_5678,vmax_5678\n" + "vpmax.f32 d5, d0, d1 @pmax " + "d5,vmax_2345,vmax_2345\n" + "vpmax.f32 d7, d2, d3 @pmax " + "d7,vmax_6789,vmax_6789\n" + "vmax.f32 d8, d4, d5 @max " + "d2,vmax_12_34,vmax_23_45\n" + "vmax.f32 d9, d6, d7 @max " + "d2,vmax_56_78,vmax_67_89\n" + "sub %[dr0], #8 @add w,8\n" + "sub %[dr1], #8 @add w,8\n" + "vst1.f32 d8, [%[dr_out]]! @vst1 d0,dr_out\n" + "vst1.f32 d9, [%[dr_out]]! @vst1 d0,dr_out\n" + "subs %[cnt_num], #1 @subs cnt_num,#1\n" + "bne 1b @bne s3_max_loop_bot\n" + "3: @loop \n" + "cmp %[cnt_num_remain], #0 @cmp cnt_num_remain,0\n" + "ble 4f @ble exit\n" + "2: @bot loop\n" + "vld1.f32 {d0-d1}, [%[dr0]]! @load d0-d1,dr0\n" + "vld1.f32 {d2-d3}, [%[dr1]]! @load d2-d3,dr1\n" + "vmov.f32 s3,s2 @movs3,s2\n" + "vmov.f32 s7,s6 @movs7,s6\n" + "vmax.f32 q0, q0, q1 @max q0,q0,q1\n" + "vpmax.f32 d0, d0, d1 @pmax d0,d0,d1\n" + "vpmax.f32 d0, d0, d0 @pmax d0,d0,d0\n" + "vst1.f32 d0[0], [%[dr_out]]! @vst d0[0],dr_out\n" + "sub %[dr0], #8 @add w,6\n" + "sub %[dr1], #8 @add w,6\n" + "subs %[cnt_num_remain], #1 @cnt_num_remain--\n" + "bne 2b @bne s3_max_loop_bot_1\n" + "4: @exit\n" + : [dr0] "+r"(dr0), + [dr1] "+r"(dr1), + [dr_out] "+r"(dr_out), + [cnt_num] "+r"(cnt_num), + [cnt_num_remain] "+r"(cnt_num_remain) + : "r"(dr0), + "r"(dr1), + "r"(dr_out), + "r"(cnt_num), + "r"(cnt_num_remain) + : "cc", + "memory", + "q0", + "q1", + "q2", + "q3", + "q4", + "q5", + "q6", + "q7", + "q8", + "q9"); + } +#endif + if (w_remain > 0) { + // deal with right pad + int wstart = (w_even >> 1) * stride - padding; + int wend = std::min(std::min(wstart + kernel, win + padding), win); + float tmp = r0[wstart]; // std::numeric_limits::min(); + for (int i = wstart; i < wend; i++) { // only run 1 or 2 times + tmp = std::max(tmp, std::max(r0[i], r1[i])); + } + dout_ch[w_even >> 1] = tmp; + } + } + } + } +} + +void pooling3x3s2p0_avg(const float* din, + float* dout, + int num, + int chout, + int hout, + int wout, + int chin, + int hin, + int win, + bool exclusive) { + int kernel = 3; + int stride = 2; + int padding = 0; + int size_channel_out = wout * hout; + int size_channel_in = win * hin; + + int w_needed = (wout << 1) + 1; + int h_needed = (hout << 1) + 1; + int w_limit = w_needed > win ? win : w_needed; + int h_limit = h_needed > hin ? hin : h_needed; + int w_even = ((w_limit - 1) >> 1) << 1; + int h_even = ((h_limit - 1) >> 1) << 1; + int w_unroll_size = (w_even >> 3) << 3; + int w_unroll_remain = w_even - w_unroll_size; + int w_remain = w_needed - w_limit; + int h_remain = h_needed - h_limit; + int w_in_2 = win << 1; + const float coef = 1.f / 9.f; + const float coef_6 = exclusive ? 
1.f / 6.f : coef; + float32x4_t vcoef = vdupq_n_f32(coef); + float32x4_t vcoef_6 = vdupq_n_f32(coef_6); + for (int n = 0; n < num; ++n) { + float* dout_batch = dout + n * chout * size_channel_out; + const float* din_batch = din + n * chin * size_channel_in; +#pragma omp parallel for + for (int c = 0; c < chout; c++) { + float* dout_ch = dout_batch + c * size_channel_out; + const float* din_ch = din_batch + c * size_channel_in; + const float* r0 = din_ch; + const float* r1 = r0 + win; + const float* r2 = r1 + win; + // w = w_in - 8; + float* dr_out = dout_ch; + const float* dr0 = r0; + const float* dr1 = r1; + const float* dr2 = r2; + + float32x4_t vzero = vdupq_n_f32(0.f); + + int h = 0; + for (; h < h_even; h += 2) { + // LOG(INFO) << "h: " << h << ", dr0: " << r0 << ", dr1: " << r1 + // << ", dr2: " << r2; +#ifdef __aarch64__ + int w = 0; + int cnt = 0; + for (; w < w_unroll_size; w += 8) { + float32x4_t vr0_1234 = vld1q_f32(&r0[w]); + float32x4_t vr0_5678 = vld1q_f32(&r0[w + 4]); + float32x4_t vr0_9101112 = vld1q_f32(&r0[w + 8]); + float32x4_t vr1_1234 = vld1q_f32(&r1[w]); + float32x4_t vr1_5678 = vld1q_f32(&r1[w + 4]); + float32x4_t vr1_9101112 = vld1q_f32(&r1[w + 8]); + float32x4_t vr2_1234 = vld1q_f32(&r2[w]); + float32x4_t vr2_5678 = vld1q_f32(&r2[w + 4]); + float32x4_t vr2_9101112 = vld1q_f32(&r2[w + 8]); + float32x4_t vsum_1234 = vaddq_f32(vr0_1234, vr1_1234); + float32x4_t vsum_5678 = vaddq_f32(vr0_5678, vr1_5678); + float32x4_t vsum_9101112 = vaddq_f32(vr0_9101112, vr1_9101112); + vsum_1234 = vaddq_f32(vsum_1234, vr2_1234); + vsum_5678 = vaddq_f32(vsum_5678, vr2_5678); + vsum_9101112 = vaddq_f32(vsum_9101112, vr2_9101112); + float32x4_t vsum_2345 = vextq_f32(vsum_1234, vsum_5678, 1); + float32x4_t vsum_3456 = vextq_f32(vsum_1234, vsum_5678, 2); + float32x4_t vsum_4567 = vextq_f32(vsum_1234, vsum_5678, 3); + float32x4_t vsum_6789 = vextq_f32(vsum_5678, vsum_9101112, 1); + float32x4_t vsum_123_345 = vaddq_f32(vsum_1234, vsum_2345); + vsum_123_345 = vaddq_f32(vsum_123_345, vsum_3456); + float32x4_t vsum_567_789 = vaddq_f32(vsum_4567, vsum_5678); + vsum_567_789 = vaddq_f32(vsum_567_789, vsum_6789); + vsum_123_345 = + vsetq_lane_f32(vgetq_lane_f32(vsum_123_345, 2), vsum_123_345, 1); + vsum_123_345 = + vsetq_lane_f32(vgetq_lane_f32(vsum_567_789, 1), vsum_123_345, 2); + vsum_123_345 = + vsetq_lane_f32(vgetq_lane_f32(vsum_567_789, 3), vsum_123_345, 3); + float32x4_t vrst = vmulq_f32(vsum_123_345, vcoef); + vst1q_f32(&dout_ch[cnt], vrst); + cnt += 4; + } + for (; w < w_even; w += 2) { + float32x4_t vr0 = vld1q_f32(&r0[w]); + float32x4_t vr1 = vld1q_f32(&r1[w]); + float32x4_t vr2 = vld1q_f32(&r2[w]); + vr0 = vsetq_lane_f32(0.f, vr0, 3); + vr1 = vsetq_lane_f32(0.f, vr1, 3); + vr2 = vsetq_lane_f32(0.f, vr2, 3); + float32x4_t vsum1 = vaddq_f32(vr0, vr1); + vsum1 = vaddq_f32(vsum1, vr2); + float32x2_t vsum2 = + vpadd_f32(vget_low_f32(vsum1), vget_high_f32(vsum1)); + vsum2 = vpadd_f32(vsum2, vsum2); + dout_ch[cnt] = vget_lane_f32(vsum2, 0) * coef; + cnt++; + } +#else + dr_out = dout_ch; + dr0 = r0; + dr1 = r1; + dr2 = r2; + int cnt_num = w_unroll_size >> 3; + int cnt_num_remain = w_unroll_remain >> 1; + // LOG(INFO) << "cnt_num: " << cnt_num << " cnt_num_remain: " << + // cnt_num_remain; + if (cnt_num > 0 || cnt_num_remain > 0) { + asm volatile( + "cmp %[cnt_num], #0 @cmp cnt_num, 0\n" + "ble 3f @ble exit\n" + "s3_ave_loop_mid_p0: @main loop\n" + "vld1.f32 {d0-d3}, [%[dr0]]! @load d0-d5, dr0\n" + "vld1.f32 {d6-d9}, [%[dr1]]! @load d4-d7, dr1\n" + "vld1.f32 {d12-d15}, [%[dr2]]! @load d4-d7, dr2\n" + "vld1.f32 {d4}, [%[dr0]]! @load d0-d5, dr0\n" + "vld1.f32 {d10}, [%[dr1]]! @load d4-d7, dr1\n" + "vld1.f32 {d16}, [%[dr2]]! @load d4-d7, dr2\n" + "vadd.f32 q9, q0, q3 @max q0,q0,q2\n" + "vadd.f32 q10, q1, q4 @max q1,q1,q3\n" + "vadd.f32 d22, d4, d10 @max q1,q1,q3\n" + "vadd.f32 q6, q9, q6 @max q0,q0,q2 1234\n" + "vadd.f32 q7, q10, q7 @max q1,q1,q3 5678\n" + "vadd.f32 d16, d22, d16 @max q1,q1,q3 9101112\n" + //"vmov.f32 s7,s6 @mov s7, s6\n" + "vext.f32 q0, q6, q7, #1 @vext max_2345\n" + "vext.f32 q1, q6, q7, #3 @vext max_4567\n" + "vext.f32 q2, q6, q7, #2 @vext max_3456\n" + "vext.f32 q3, q7, q8, #1 @vext max_6789\n" + "vadd.f32 q4, q6, q0 @add 1234, 2345\n" + "vadd.f32 q5, q7, q1 @add 5678, 4567\n" + "vadd.f32 q4, q4, q2 @add 3456, sum1\n" + "vadd.f32 q5, q5, q3 @add 6789, sum2\n" + "vmov.f32 s17, s18 @mov\n" + "vmov.f32 s18, s21 @mov\n" + "vmov.f32 s19, s23 @mov\n" + "vmul.f32 q4, q4, %q[vcoef] @mul\n" + "sub %[dr0], #8 @add w,8\n" + "sub %[dr1], #8 @add w,8\n" + "sub %[dr2], #8 @add w,8\n" + "subs %[cnt_num], #1 @cnt_num--\n" + "vst1.f32 d8, [%[dr_out]]! @vst1 d0,dr_out\n" + "vst1.f32 d9, [%[dr_out]]! @vst1 d0,dr_out\n" + "bne s3_ave_loop_mid_p0 @bne s3_max_loop_mid\n" + "3: @loop\n" + "cmp %[cnt_num_remain], #0 @cmp cnt_num_remain,0\n" + "ble 4f @ble exit1\n" + "s3_ave_loop_mid_1_p0: @mid loop\n" + "vld1.f32 {d0-d1}, [%[dr0]]! @load d0-d1,dr0\n" + "vld1.f32 {d2-d3}, [%[dr1]]! @load d2-d3,dr1\n" + "vld1.f32 {d4-d5}, [%[dr2]]! @load d2-d3,dr1\n" + "vext.f32 q0, %q[vzero], q0, #3 @ext v0_0123\n" + "vext.f32 q1, %q[vzero], q1, #3 @ext v1_0123\n" + "vext.f32 q2, %q[vzero], q2, #3 @ext v1_0123\n" + "vadd.f32 q0, q0, q1 @add q0,q0,q1\n" + "vadd.f32 q0, q0, q2 @add q0,q0,q1\n" + "vpadd.f32 d0, d0, d1 @padd d0,d0,d1\n" + "vpadd.f32 d0, d0, d0 @padd d0,d0,d0\n" + "vmul.f32 d0, d0, %e[vcoef] @mul\n" + "sub %[dr0], #8 @add w,6\n" + "sub %[dr1], #8 @add w,6\n" + "sub %[dr2], #8 @add w,6\n" + "subs %[cnt_num_remain], #1 @cnt_num_remain--\n" + "vst1.f32 d0[0], [%[dr_out]]!
@vst d0[0],dr_out\n" + "bne s3_ave_loop_mid_1_p0 @bne s3_max_loop_mid_1\n" + "4: @exit\n" + : [dr0] "+r"(dr0), + [dr1] "+r"(dr1), + [dr2] "+r"(dr2), + [dr_out] "+r"(dr_out), + [cnt_num] "+r"(cnt_num), + [cnt_num_remain] "+r"(cnt_num_remain), + [vcoef] "+w"(vcoef), + [vzero] "+w"(vzero) + : "r"(dr0), + "r"(dr1), + "r"(dr2), + "r"(dr_out), + "r"(cnt_num), + "r"(cnt_num_remain) + : "q0", + "q1", + "q2", + "q3", + "q4", + "q5", + "q6", + "q7", + "q8", + "q9", + "q10", + "q11", + "q12"); + } +#endif + if (w_remain > 0) { + // deal with right pad + int wstart = (w_even >> 1) * stride - padding; + int wend = std::min(std::min(wstart + kernel, win + padding), win); + float tmp1 = 0.f; + float tmp2 = exclusive ? 1.0f / (3.f * (wend - wstart)) : coef; + for (int i = wstart; i < wend; i++) { + tmp1 += (r0[i] + r1[i] + r2[i]); + } + dout_ch[w_even >> 1] = tmp1 * tmp2; + // cnt ++; + } + r0 = r2; + r1 = r0 + win; + r2 = r1 + win; + dout_ch += wout; + } + + if (h_remain > 0) { +// deal with bottom pad +// first row with zero pad +// int hstart = (h >> 1) * stride_h - pad_h; +// int hend = std::min(std::min(hstart + kernel_h, hin + padding_h), +// hin); data_out_channel[0] =(r0[0] + r0[1] + r0[2] + r1[0] + r1[1] + +// r1[2]) / 9.f; +#ifdef __aarch64__ + int w = 0; + int cnt = 0; + for (; w < w_unroll_size; w += 8) { + float32x4_t vr0_1234 = vld1q_f32(&r0[w]); + float32x4_t vr0_5678 = vld1q_f32(&r0[w + 4]); + float32x4_t vr0_9101112 = vld1q_f32(&r0[w + 8]); + float32x4_t vr1_1234 = vld1q_f32(&r1[w]); + float32x4_t vr1_5678 = vld1q_f32(&r1[w + 4]); + float32x4_t vr1_9101112 = vld1q_f32(&r1[w + 8]); + + float32x4_t vsum_1234 = vaddq_f32(vr0_1234, vr1_1234); + float32x4_t vsum_5678 = vaddq_f32(vr0_5678, vr1_5678); + float32x4_t vsum_9101112 = vaddq_f32(vr0_9101112, vr1_9101112); + float32x4_t vsum_2345 = vextq_f32(vsum_1234, vsum_5678, 1); + float32x4_t vsum_3456 = vextq_f32(vsum_1234, vsum_5678, 2); + float32x4_t vsum_4567 = vextq_f32(vsum_1234, vsum_5678, 3); + float32x4_t vsum_6789 = vextq_f32(vsum_5678, vsum_9101112, 1); + float32x4_t vsum_123_345 = vaddq_f32(vsum_1234, vsum_2345); + vsum_123_345 = vaddq_f32(vsum_123_345, vsum_3456); + float32x4_t vsum_567_789 = vaddq_f32(vsum_4567, vsum_5678); + vsum_567_789 = vaddq_f32(vsum_567_789, vsum_6789); + vsum_123_345 = + vsetq_lane_f32(vgetq_lane_f32(vsum_123_345, 2), vsum_123_345, 1); + vsum_123_345 = + vsetq_lane_f32(vgetq_lane_f32(vsum_567_789, 1), vsum_123_345, 2); + vsum_123_345 = + vsetq_lane_f32(vgetq_lane_f32(vsum_567_789, 3), vsum_123_345, 3); + float32x4_t vrst = vmulq_f32(vsum_123_345, vcoef_6); + vst1q_f32(&dout_ch[cnt], vrst); + cnt += 4; + } + for (; w < w_even; w += 2) { + float32x4_t vr0 = vld1q_f32(&r0[w]); + float32x4_t vr1 = vld1q_f32(&r1[w]); + vr0 = vsetq_lane_f32(0.f, vr0, 3); + vr1 = vsetq_lane_f32(0.f, vr1, 3); + float32x4_t vsum1 = vaddq_f32(vr0, vr1); + float32x2_t vsum2 = + vpadd_f32(vget_low_f32(vsum1), vget_high_f32(vsum1)); + vsum2 = vpadd_f32(vsum2, vsum2); + float32x2_t vrst = vmul_f32(vsum2, vget_low_f32(vcoef_6)); + dout_ch[cnt] = vget_lane_f32(vrst, 0); + cnt++; + } +#else + dr_out = dout_ch; // + 1; + dr0 = r0; // (r0 + 1); + dr1 = r1; // (r1 + 1); + int cnt_num = w_unroll_size >> 3; + int cnt_num_remain = w_unroll_remain >> 1; + // LOG(INFO) << "cnt_num: " << cnt_num << " cnt_num_remain: " << + // cnt_num_remain; + if (cnt_num > 0 || cnt_num_remain > 0) { + asm volatile( + "cmp %[cnt_num], #0 @cmp cnt_num,0\n" + "ble 2f @ble exit\n" + "1: @main loop\n" + "vld1.f32 {d0-d3}, [%[dr0]]! 
@load d0-d5,dr0\n" + "vld1.f32 {d6-d9}, [%[dr1]]! @load d4-d7,dr1\n" + "vld1.f32 {d4}, [%[dr0]]! @load d0-d3,dr0\n" + "vld1.f32 {d10}, [%[dr1]]! @load d4-d7,dr1\n" + "vadd.f32 q6, q0, q3 @max q0,q0,q2 1234\n" + "vadd.f32 q7, q1, q4 @max q1,q1,q3 5678\n" + "vadd.f32 d16, d4, d10 @max q1,q1,q3 9101112\n" + //"vmov.f32 s7,s6 @mov s7, s6\n" + "vext.f32 q0, q6, q7, #1 @vext max_2345\n" + "vext.f32 q1, q6, q7, #3 @vext max_4567\n" + "vext.f32 q2, q6, q7, #2 @vext max_3456\n" + "vext.f32 q3, q7, q8, #1 @vext max_6789\n" + "vadd.f32 q4, q6, q0 @add 1234,2345\n" + "vadd.f32 q5, q7, q1 @add 5678,4567\n" + "vadd.f32 q4, q4, q2 @add 3456,sum1\n" + "vadd.f32 q5, q5, q3 @add 6789,sum2\n" + "vmov.f32 s17, s18 @mov\n" + "vmov.f32 s18, s21 @mov\n" + "vmov.f32 s19, s23 @mov\n" + "vmul.f32 q4, q4, %q[vcoef_6] @mul\n" + "sub %[dr0], #8 @add w,8\n" + "sub %[dr1], #8 @add w,8\n" + "subs %[cnt_num], #1 @cnt_num--\n" + "vst1.f32 d8, [%[dr_out]]! @vst1 d0,dr_out\n" + "vst1.f32 d9, [%[dr_out]]! @vst1 d0,dr_out\n" + "bne 1b @bne s3_max_loop_bot\n" + "2: @loop\n" + "cmp %[cnt_num_remain], #0 @cmp cnt_num_remain, 0\n" + "ble 3f @ble exit\n" + "4: @bot loop\n" + "vld1.f32 {d0-d1}, [%[dr0]]! @load d0-d1,dr0\n" + "vld1.f32 {d2-d3}, [%[dr1]]! @load d2-d3,dr1\n" + "vext.f32 q0, %q[vzero], q0, #3 @ext v0_0123\n" + "vext.f32 q1, %q[vzero], q1, #3 @ext v1_0123\n" + "vadd.f32 q0, q0, q1 @add q0,q0,q1\n" + "vpadd.f32 d0, d0, d1 @padd d0,d0,d1\n" + "vpadd.f32 d0, d0, d0 @padd d0,d0,d0\n" + "vmul.f32 d0, d0, %e[vcoef_6] @mul\n" + "sub %[dr0], #8 @add w,6\n" + "sub %[dr1], #8 @add w,6\n" + "subs %[cnt_num_remain], #1 @cnt_num_remain--\n" + "vst1.f32 d0[0], [%[dr_out]]! @vst d0[0],dr_out\n" + "bne 4b @bne s3_max_loop_bot_1\n" + "3: @exit\n" + : [dr0] "+r"(dr0), + [dr1] "+r"(dr1), + [dr_out] "+r"(dr_out), + [cnt_num] "+r"(cnt_num), + [cnt_num_remain] "+r"(cnt_num_remain), + [vcoef_6] "+w"(vcoef_6), + [vzero] "+w"(vzero) + : "r"(dr0), + "r"(dr1), + "r"(dr_out), + "r"(cnt_num), + "r"(cnt_num_remain) + : "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9"); + } + +#endif + if (w_remain > 0) { + // deal with right pad + int wstart = (w_even >> 1) * stride - padding; + int wend = std::min(std::min(wstart + kernel, win + padding), win); + float tmp1 = 0.f; + float tmp2 = exclusive ? 1.0f / (2.f * (wend - wstart)) : coef; + for (int i = wstart; i < wend; i++) { // only run 1 or 2 times + tmp1 += (r0[i] + r1[i]); + } + dout_ch[w_even >> 1] = tmp1 * tmp2; + } + } + } + } +} + +} // namespace math +} // namespace arm +} // namespace lite +} // namespace paddle diff --git a/lite/arm/math/pooling.h b/lite/arm/math/pooling.h new file mode 100644 index 00000000000..8fc9e0c4e01 --- /dev/null +++ b/lite/arm/math/pooling.h @@ -0,0 +1,154 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
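+// +// Naming note: din/dout are NCHW feature maps with shapes {num, chin, hin, +// win} and {num, chout, hout, wout}. The specialized 3x3 stride-2 kernels +// in pooling.cc derive the input span they must read as +// w_needed = 2 * wout + 1 (and h_needed = 2 * hout + 1), clip it to the +// real input extent, and vectorize over the even part (w_even/h_even); +// w_remain > 0 or h_remain > 0 means one trailing output still needs a +// scalar right-/bottom-padding pass.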
+ +#pragma once + +#include <algorithm> +#include <string> +#include <vector> +#include "lite/utils/cp_logging.h" + +namespace paddle { +namespace lite { +namespace arm { +namespace math { + +// !pooling fp32 Op +void pooling_basic(const float* din, + float* dout, + int num, + int chout, + int hout, + int wout, + int chin, + int hin, + int win, + const std::vector<int>& ksize, + const std::vector<int>& strides, + const std::vector<int>& paddings, + bool global_pooling, + bool exclusive, + bool adaptive, + bool ceil_mode, + bool use_quantizer, + const std::string& pooling_type); + +void pooling_global_max(const float* din, + float* dout, + int num, + int chout, + int hout, + int wout, + int chin, + int hin, + int win); + +void pooling_global_avg(const float* din, + float* dout, + int num, + int chout, + int hout, + int wout, + int chin, + int hin, + int win); + +void pooling2x2s2_max(const float* din, + float* dout, + int num, + int chout, + int hout, + int wout, + int chin, + int hin, + int win); + +void pooling2x2s2_avg(const float* din, + float* dout, + int num, + int chout, + int hout, + int wout, + int chin, + int hin, + int win, + bool exclusive); + +void pooling3x3s1p1_max(const float* din, + float* dout, + int num, + int chout, + int hout, + int wout, + int chin, + int hin, + int win); + +void pooling3x3s1p1_avg(const float* din, + float* dout, + int num, + int chout, + int hout, + int wout, + int chin, + int hin, + int win, + bool exclusive); + +void pooling3x3s2p1_max(const float* din, + float* dout, + int num, + int chout, + int hout, + int wout, + int chin, + int hin, + int win); + +void pooling3x3s2p1_avg(const float* din, + float* dout, + int num, + int chout, + int hout, + int wout, + int chin, + int hin, + int win, + bool exclusive); + +void pooling3x3s2p0_max(const float* din, + float* dout, + int num, + int chout, + int hout, + int wout, + int chin, + int hin, + int win); + +void pooling3x3s2p0_avg(const float* din, + float* dout, + int num, + int chout, + int hout, + int wout, + int chin, + int hin, + int win, + bool exclusive); + +} // namespace math +} // namespace arm +} // namespace lite +} // namespace paddle diff --git a/lite/arm/math/power.cc b/lite/arm/math/power.cc new file mode 100644 index 00000000000..33a5a64ede1 --- /dev/null +++ b/lite/arm/math/power.cc @@ -0,0 +1,96 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License.
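+// +// power computes dout[i] = pow(scale_ * din[i] + shift_, power_) over num +// floats; identity factors (scale_ == 1, shift_ == 0, power_ == 1) are +// detected once up front and skipped. The NEON path handles 16 floats per +// iteration; the scalar tail is equivalent to: +// +// for (int j = 0; j < num % 16; ++j) { +// dout[j] = std::pow(din[j] * scale_ + shift_, power_); +// }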
+ +#include "lite/arm/math/power.h" +#include <cmath> +#include "lite/arm/math/funcs.h" + +namespace paddle { +namespace lite { +namespace arm { +namespace math { + +template <> +void power<float>(const float* din, + float* dout, + const int num, + float scale_, + float shift_, + float power_) { + int cnt = num >> 4; + int remain = num % 16; + bool _do_power = true; + bool _do_scale = true; + bool _do_shift = true; + if (fabsf(power_ - 1.f) < 1e-6f) { + _do_power = false; + } + if (fabsf(scale_ - 1.f) < 1e-6f) { + _do_scale = false; + } + if (fabsf(shift_ - 0.f) < 1e-6f) { + _do_shift = false; + } + float32x4_t vscale = vdupq_n_f32(scale_); + float32x4_t vshift = vdupq_n_f32(shift_); + float32x4_t vpower = vdupq_n_f32(power_); + // Each iteration indexes din/dout from the loop counter instead of + // advancing shared pointers, so the loop is safe under OpenMP. +#pragma omp parallel for + for (int nums = 0; nums < cnt; ++nums) { + const float* ptr_in = din + (nums << 4); + float* ptr_out = dout + (nums << 4); + float32x4_t vr0 = vld1q_f32(ptr_in); + float32x4_t vr1 = vld1q_f32(ptr_in + 4); + float32x4_t vr2 = vld1q_f32(ptr_in + 8); + float32x4_t vr3 = vld1q_f32(ptr_in + 12); + if (_do_scale) { + vr0 = vmulq_f32(vr0, vscale); + vr1 = vmulq_f32(vr1, vscale); + vr2 = vmulq_f32(vr2, vscale); + vr3 = vmulq_f32(vr3, vscale); + } + if (_do_shift) { + vr0 = vaddq_f32(vr0, vshift); + vr1 = vaddq_f32(vr1, vshift); + vr2 = vaddq_f32(vr2, vshift); + vr3 = vaddq_f32(vr3, vshift); + } + if (_do_power) { + vr0 = pow_ps(vr0, vpower); + vr1 = pow_ps(vr1, vpower); + vr2 = pow_ps(vr2, vpower); + vr3 = pow_ps(vr3, vpower); + } + vst1q_f32(ptr_out, vr0); + vst1q_f32(ptr_out + 4, vr1); + vst1q_f32(ptr_out + 8, vr2); + vst1q_f32(ptr_out + 12, vr3); + } + const float* ptr_in = din + (cnt << 4); + float* ptr_out = dout + (cnt << 4); + for (int j = 0; j < remain; ++j) { + ptr_out[0] = std::pow((ptr_in[0] * scale_ + shift_), power_); + ptr_in++; + ptr_out++; + } +} + +} /* namespace math */ +} /* namespace arm */ +} /* namespace lite */ +} /* namespace paddle */ diff --git a/lite/arm/math/power.h b/lite/arm/math/power.h new file mode 100644 index 00000000000..7b9074918d2 --- /dev/null +++ b/lite/arm/math/power.h @@ -0,0 +1,33 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +namespace paddle { +namespace lite { +namespace arm { +namespace math { + +template <typename T> +void power(const T* din, + T* dout, + const int num, + float scale_, + float shift_, + float power_); + +} /* namespace math */ +} /* namespace arm */ +} /* namespace lite */ +} /* namespace paddle */ diff --git a/lite/arm/math/prior_box.cc b/lite/arm/math/prior_box.cc new file mode 100644 index 00000000000..e6f455e72a2 --- /dev/null +++ b/lite/arm/math/prior_box.cc @@ -0,0 +1,364 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/arm/math/prior_box.h" +#include +#include "lite/arm/math/funcs.h" + +namespace paddle { +namespace lite { +namespace arm { +namespace math { + +const int MALLOC_ALIGN = 64; + +void* fast_malloc(size_t size) { + size_t offset = sizeof(void*) + MALLOC_ALIGN - 1; + char* p = static_cast(malloc(offset + size)); + + if (!p) { + return nullptr; + } + + void* r = reinterpret_cast(reinterpret_cast(p + offset) & + (~(MALLOC_ALIGN - 1))); + static_cast(r)[-1] = p; + memset(r, 0, size); + return r; +} + +void fast_free(void* ptr) { + if (ptr) { + free(static_cast(ptr)[-1]); + } +} + +void density_prior_box(const lite::Tensor* input, + const lite::Tensor* image, + lite::Tensor** boxes, + lite::Tensor** variances, + const std::vector& min_size_, + const std::vector& fixed_size_, + const std::vector& fixed_ratio_, + const std::vector& density_size_, + const std::vector& max_size_, + const std::vector& aspect_ratio_, + const std::vector& variance_, + int img_w_, + int img_h_, + float step_w_, + float step_h_, + float offset_, + int prior_num_, + bool is_flip_, + bool is_clip_, + const std::vector& order_) { + // compute output shape + int win1 = input->dims()[3]; + int hin1 = input->dims()[2]; + DDim shape_out({hin1, win1, prior_num_, 4}); + (*boxes)->Resize(shape_out); + (*variances)->Resize(shape_out); + + float* _cpu_data = (*boxes)->mutable_data(); + float* _variance_data = (*variances)->mutable_data(); + + const int width = win1; + const int height = hin1; + int img_width = img_w_; + int img_height = img_h_; + if (img_width == 0 || img_height == 0) { + img_width = image->dims()[3]; + img_height = image->dims()[2]; + } + + float step_w = step_w_; + float step_h = step_h_; + if (step_w == 0 || step_h == 0) { + step_w = static_cast(img_width) / width; + step_h = static_cast(img_height) / height; + } + + float offset = offset_; + int step_average = static_cast((step_w + step_h) * 0.5); // add + int channel_size = height * width * prior_num_ * 4; + int idx = 0; + for (int h = 0; h < height; ++h) { + for (int w = 0; w < width; ++w) { + float center_x = (w + offset) * step_w; + float center_y = (h + offset) * step_h; + float box_width; + float box_height; + if (fixed_size_.size() > 0) { + // add + for (int s = 0; s < fixed_size_.size(); ++s) { + int fixed_size = fixed_size_[s]; + int com_idx = 0; + box_width = fixed_size; + box_height = fixed_size; + + if (fixed_ratio_.size() > 0) { + for (int r = 0; r < fixed_ratio_.size(); ++r) { + float ar = fixed_ratio_[r]; + int density = density_size_[s]; + int shift = step_average / density; + float box_width_ratio = fixed_size_[s] * sqrt(ar); + float box_height_ratio = fixed_size_[s] / sqrt(ar); + + for (int p = 0; p < density; ++p) { + for (int c = 0; c < density; ++c) { + float center_x_temp = + center_x - step_average / 2.0f + shift / 2.f + c * shift; + float center_y_temp = + center_y - step_average / 2.0f + shift / 2.f + p * shift; + // xmin + _cpu_data[idx++] = + (center_x_temp - box_width_ratio / 2.f) / img_width >= 0 + ? 
(center_x_temp - box_width_ratio / 2.f) / img_width + : 0; + // ymin + _cpu_data[idx++] = + (center_y_temp - box_height_ratio / 2.f) / img_height >= 0 + ? (center_y_temp - box_height_ratio / 2.f) / + img_height + : 0; + // xmax + _cpu_data[idx++] = + (center_x_temp + box_width_ratio / 2.f) / img_width <= 1 + ? (center_x_temp + box_width_ratio / 2.f) / img_width + : 1; + // ymax + _cpu_data[idx++] = + (center_y_temp + box_height_ratio / 2.f) / img_height <= 1 + ? (center_y_temp + box_height_ratio / 2.f) / + img_height + : 1; + } + } + } + } else { + // this code for density anchor box + if (density_size_.size() > 0) { + CHECK_EQ(fixed_size_.size(), density_size_.size()) + << "fixed_size_ should be same with density_size_"; + int density = density_size_[s]; + int shift = fixed_size_[s] / density; + + for (int r = 0; r < density; ++r) { + for (int c = 0; c < density; ++c) { + float center_x_temp = + center_x - fixed_size / 2.f + shift / 2.f + c * shift; + float center_y_temp = + center_y - fixed_size / 2.f + shift / 2.f + r * shift; + // xmin + _cpu_data[idx++] = + (center_x_temp - box_width / 2.f) / img_width >= 0 + ? (center_x_temp - box_width / 2.f) / img_width + : 0; + // ymin + _cpu_data[idx++] = + (center_y_temp - box_height / 2.f) / img_height >= 0 + ? (center_y_temp - box_height / 2.f) / img_height + : 0; + // xmax + _cpu_data[idx++] = + (center_x_temp + box_width / 2.f) / img_width <= 1 + ? (center_x_temp + box_width / 2.f) / img_width + : 1; + // ymax + _cpu_data[idx++] = + (center_y_temp + box_height / 2.f) / img_height <= 1 + ? (center_y_temp + box_height / 2.f) / img_height + : 1; + } + } + } + + // rest of priors: will never come here!!! + for (int r = 0; r < aspect_ratio_.size(); ++r) { + float ar = aspect_ratio_[r]; + + if (fabs(ar - 1.) < 1e-6) { + continue; + } + + int density = density_size_[s]; + int shift = fixed_size_[s] / density; + float box_width_ratio = fixed_size_[s] * sqrt(ar); + float box_height_ratio = fixed_size_[s] / sqrt(ar); + + for (int p = 0; p < density; ++p) { + for (int c = 0; c < density; ++c) { + float center_x_temp = + center_x - fixed_size / 2.f + shift / 2.f + c * shift; + float center_y_temp = + center_y - fixed_size / 2.f + shift / 2.f + p * shift; + // xmin + _cpu_data[idx++] = + (center_x_temp - box_width_ratio / 2.f) / img_width >= 0 + ? (center_x_temp - box_width_ratio / 2.f) / img_width + : 0; + // ymin + _cpu_data[idx++] = + (center_y_temp - box_height_ratio / 2.f) / img_height >= 0 + ? (center_y_temp - box_height_ratio / 2.f) / + img_height + : 0; + // xmax + _cpu_data[idx++] = + (center_x_temp + box_width_ratio / 2.f) / img_width <= 1 + ? (center_x_temp + box_width_ratio / 2.f) / img_width + : 1; + // ymax + _cpu_data[idx++] = + (center_y_temp + box_height_ratio / 2.f) / img_height <= 1 + ? (center_y_temp + box_height_ratio / 2.f) / + img_height + : 1; + } + } + } + } + } + } else { + float* min_buf = + reinterpret_cast(fast_malloc(sizeof(float) * 4)); + float* max_buf = + reinterpret_cast(fast_malloc(sizeof(float) * 4)); + float* com_buf = reinterpret_cast( + fast_malloc(sizeof(float) * aspect_ratio_.size() * 4)); + + for (int s = 0; s < min_size_.size(); ++s) { + int min_idx = 0; + int max_idx = 0; + int com_idx = 0; + int min_size = min_size_[s]; + // first prior: aspect_ratio = 1, size = min_size + box_width = box_height = min_size; + //! xmin + min_buf[min_idx++] = (center_x - box_width / 2.f) / img_width; + //! ymin + min_buf[min_idx++] = (center_y - box_height / 2.f) / img_height; + //! 
xmax + min_buf[min_idx++] = (center_x + box_width / 2.f) / img_width; + //! ymax + min_buf[min_idx++] = (center_y + box_height / 2.f) / img_height; + + if (max_size_.size() > 0) { + int max_size = max_size_[s]; + //! second prior: aspect_ratio = 1, size = sqrt(min_size * max_size) + box_width = box_height = sqrtf(min_size * max_size); + //! xmin + max_buf[max_idx++] = (center_x - box_width / 2.f) / img_width; + //! ymin + max_buf[max_idx++] = (center_y - box_height / 2.f) / img_height; + //! xmax + max_buf[max_idx++] = (center_x + box_width / 2.f) / img_width; + //! ymax + max_buf[max_idx++] = (center_y + box_height / 2.f) / img_height; + } + + //! rest of priors + for (int r = 0; r < aspect_ratio_.size(); ++r) { + float ar = aspect_ratio_[r]; + if (fabs(ar - 1.) < 1e-6) { + continue; + } + box_width = min_size * sqrt(ar); + box_height = min_size / sqrt(ar); + //! xmin + com_buf[com_idx++] = (center_x - box_width / 2.f) / img_width; + //! ymin + com_buf[com_idx++] = (center_y - box_height / 2.f) / img_height; + //! xmax + com_buf[com_idx++] = (center_x + box_width / 2.f) / img_width; + //! ymax + com_buf[com_idx++] = (center_y + box_height / 2.f) / img_height; + } + memcpy(_cpu_data + idx, min_buf, sizeof(float) * min_idx); + idx += min_idx; + memcpy(_cpu_data + idx, com_buf, sizeof(float) * com_idx); + idx += com_idx; + memcpy(_cpu_data + idx, max_buf, sizeof(float) * max_idx); + idx += max_idx; + } + fast_free(min_buf); + fast_free(max_buf); + fast_free(com_buf); + } + } + } + //! clip the prior's coordinate such that it is within [0, 1] + if (is_clip_) { + for (int d = 0; d < channel_size; ++d) { + _cpu_data[d] = std::min(std::max(_cpu_data[d], 0.f), 1.f); + } + } + //! set the variance. + int count = 0; + for (int h = 0; h < height; ++h) { + for (int w = 0; w < width; ++w) { + for (int i = 0; i < prior_num_; ++i) { + for (int j = 0; j < 4; ++j) { + _variance_data[count] = variance_[j]; + ++count; + } + } + } + } +} + +void prior_box(const lite::Tensor* input, + const lite::Tensor* image, + lite::Tensor** boxes, + lite::Tensor** variances, + const std::vector& min_size, + const std::vector& max_size, + const std::vector& aspect_ratio, + const std::vector& variance, + int img_w, + int img_h, + float step_w, + float step_h, + float offset, + int prior_num, + bool is_flip, + bool is_clip, + const std::vector& order) { + density_prior_box(input, + image, + boxes, + variances, + min_size, + std::vector(), + std::vector(), + std::vector(), + max_size, + aspect_ratio, + variance, + img_w, + img_h, + step_w, + step_h, + offset, + prior_num, + is_flip, + is_clip, + order); +} + +} // namespace math +} // namespace arm +} // namespace lite +} // namespace paddle diff --git a/lite/arm/math/prior_box.h b/lite/arm/math/prior_box.h new file mode 100644 index 00000000000..59efb2ab002 --- /dev/null +++ b/lite/arm/math/prior_box.h @@ -0,0 +1,68 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
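The `fast_malloc`/`fast_free` pair used above over-allocates, rounds the user pointer up to a 64-byte boundary, and stashes the raw `malloc` pointer in the slot just before the aligned block so it can be recovered on free. A standalone restatement of that scheme (identifier names are mine; the patch version additionally zero-fills the block):

#include <cstdint>
#include <cstdlib>

static const int kMallocAlign = 64;

void* aligned_malloc(size_t size) {
  // Reserve room for the back-pointer plus worst-case alignment slack.
  size_t offset = sizeof(void*) + kMallocAlign - 1;
  char* p = static_cast<char*>(malloc(offset + size));
  if (!p) return nullptr;
  // Round up to the alignment boundary; the back-pointer slot stays below r.
  void* r = reinterpret_cast<void*>(reinterpret_cast<uintptr_t>(p + offset) &
                                    ~static_cast<uintptr_t>(kMallocAlign - 1));
  static_cast<void**>(r)[-1] = p;  // remember the raw pointer for free()
  return r;
}

void aligned_free(void* ptr) {
  if (ptr) free(static_cast<void**>(ptr)[-1]);
}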
+ +#pragma once + +#include +#include +#include "lite/core/op_lite.h" + +namespace paddle { +namespace lite { +namespace arm { +namespace math { + +void density_prior_box(const lite::Tensor* input, + const lite::Tensor* image, + lite::Tensor** boxes, + lite::Tensor** variances, + const std::vector& min_size_, + const std::vector& fixed_size_, + const std::vector& fixed_ratio_, + const std::vector& density_size_, + const std::vector& max_size_, + const std::vector& aspect_ratio_, + const std::vector& variance_, + int img_w_, + int img_h_, + float step_w_, + float step_h_, + float offset_, + int prior_num_, + bool is_flip_, + bool is_clip_, + const std::vector& order_); + +void prior_box(const lite::Tensor* input, + const lite::Tensor* image, + lite::Tensor** boxes, + lite::Tensor** variances, + const std::vector& min_size, + const std::vector& max_size, + const std::vector& aspect_ratio, + const std::vector& variance, + int img_w, + int img_h, + float step_w, + float step_h, + float offset, + int prior_num, + bool is_flip, + bool is_clip, + const std::vector& order); + +} // namespace math +} // namespace arm +} // namespace lite +} // namespace paddle diff --git a/lite/arm/math/reduce_max.cc b/lite/arm/math/reduce_max.cc new file mode 100644 index 00000000000..7175a6709ba --- /dev/null +++ b/lite/arm/math/reduce_max.cc @@ -0,0 +1,207 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "lite/arm/math/reduce_max.h" +#include "lite/arm/math/funcs.h" +#include "lite/core/tensor.h" + +namespace paddle { +namespace lite { +namespace arm { +namespace math { + +template <> +void reduce_n(const float* src, + float* dst, + int num_in, + int channel_in, + int height_in, + int width_in) { + int hw_size = height_in * width_in; + int chw_size = channel_in * hw_size; + int data_index, src_index, src_index0; + for (int c = 0; c < channel_in; ++c) { + for (int h = 0; h < height_in; ++h) { + for (int w = 0; w < width_in; ++w) { + data_index = c * hw_size + h * width_in + w; + dst[data_index] = src[data_index]; + for (int n = 1; n < num_in; ++n) { + src_index = n * chw_size + data_index; + dst[data_index] = dst[data_index] > src[src_index] ? dst[data_index] + : src[src_index]; + } + } + } + } +} + +template <> +void reduce_c(const float* src, + float* dst, + int num_in, + int channel_in, + int height_in, + int width_in) { + int hw_size = height_in * width_in; + int chw_size = hw_size * channel_in; + int data_index, src_index0, src_index; + for (int n = 0; n < num_in; ++n) { + for (int h = 0; h < height_in; ++h) { + for (int w = 0; w < width_in; ++w) { + data_index = n * hw_size + h * width_in + w; + src_index0 = n * chw_size + h * width_in + w; + dst[data_index] = src[src_index0]; + for (int c = 1; c < channel_in; ++c) { + src_index = src_index0 + c * hw_size; + dst[data_index] = dst[data_index] > src[src_index] ? 
dst[data_index] + : src[src_index]; + } + } + } + } +} + +template <> +void reduce_h(const float* src, + float* dst, + int num_in, + int channel_in, + int height_in, + int width_in) { + int cw_size = channel_in * width_in; + int chw_size = cw_size * height_in; + int hw_size = height_in * width_in; + int data_index, src_index, src_index0; + for (int n = 0; n < num_in; ++n) { + for (int c = 0; c < channel_in; ++c) { + for (int w = 0; w < width_in; ++w) { + data_index = n * cw_size + c * width_in + w; + src_index0 = n * chw_size + c * hw_size + w; + dst[data_index] = src[src_index0]; + for (int h = 1; h < height_in; ++h) { + src_index = src_index0 + h * width_in; + dst[data_index] = dst[data_index] > src[src_index] ? dst[data_index] + : src[src_index]; + } + } + } + } +} + +template <> +void reduce_w(const float* src, + float* dst, + int num_in, + int channel_in, + int height_in, + int width_in) { + int ch_size = channel_in * height_in; + int hw_size = height_in * width_in; + int chw_size = ch_size * width_in; + int data_index = 0; + int src_index0 = 0; + int src_index = 0; + for (int n = 0; n < num_in; ++n) { + for (int c = 0; c < channel_in; ++c) { + for (int h = 0; h < height_in; ++h) { + data_index = n * ch_size + c * height_in + h; + src_index0 = n * chw_size + c * hw_size + h * width_in; + dst[data_index] = src[src_index0]; + for (int w = 1; w < width_in; ++w) { + src_index = src_index0 + w; + dst[data_index] = dst[data_index] > src[src_index] ? dst[data_index] + : src[src_index]; + } + } + } + } +} + +template <> +void reduce_all(const float* src, + float* dst, + int num_in, + int channel_in, + int height_in, + int width_in) { + float max = src[0]; + int src_index; + int n_id, c_id; + for (int n = 0; n < num_in; ++n) { + n_id = n * channel_in * height_in * width_in; + for (int c = 0; c < channel_in; ++c) { + c_id = c * height_in * width_in; + for (int h = 0; h < height_in; ++h) { + for (int w = 0; w < width_in; ++w) { + src_index = n_id + c_id + h * width_in + w; + max = src[src_index] > max ? src[src_index] : max; + } + } + } + } + dst[0] = max; +} + +template <> +void reduce_nc(const float* src, + float* dst, + int num_in, + int channel_in, + int height_in, + int width_in) { + // reduce n first. 
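+  // reduce_nc composes two single-axis reductions: a max over N yields a
+  // [1, C, H, W] intermediate, and a max over C then yields [1, 1, H, W].
+  // E.g. for N = C = 2, H = W = 1 and src = {1, 4, 3, 2}:
+  // reduce_n -> {3, 4}, then reduce_c -> {4}.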
+ DDimLite ddimA({1, channel_in, height_in, width_in}); + lite::Tensor tensor_tmp; + tensor_tmp.Resize(ddimA); + float* tmp_out = tensor_tmp.mutable_data(); + reduce_n(src, tmp_out, num_in, channel_in, height_in, width_in); + reduce_c(tmp_out, dst, 1, channel_in, height_in, width_in); +} + +template <> +void reduce_ch(const float* src, + float* dst, + int num_in, + int channel_in, + int height_in, + int width_in) { + // reduce c first + DDimLite ddimA({num_in, 1, height_in, width_in}); + lite::Tensor tensor_tmp; + tensor_tmp.Resize(ddimA); + float* tmp_out = tensor_tmp.mutable_data(); + reduce_c(src, tmp_out, num_in, channel_in, height_in, width_in); + reduce_h(tmp_out, dst, num_in, 1, height_in, width_in); +} + +template <> +void reduce_hw(const float* src, + float* dst, + int num_in, + int channel_in, + int height_in, + int width_in) { + // reduce h first + DDimLite ddimA({num_in, channel_in, 1, width_in}); + lite::Tensor tensor_tmp; + tensor_tmp.Resize(ddimA); + float* tmp_out = tensor_tmp.mutable_data(); + reduce_h(src, tmp_out, num_in, channel_in, height_in, width_in); + reduce_w(tmp_out, dst, num_in, channel_in, 1, width_in); +} + +} // namespace math +} // namespace arm +} // namespace lite +} // namespace paddle diff --git a/lite/arm/math/reduce_max.h b/lite/arm/math/reduce_max.h new file mode 100644 index 00000000000..dab96261822 --- /dev/null +++ b/lite/arm/math/reduce_max.h @@ -0,0 +1,89 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +namespace paddle { +namespace lite { +namespace arm { +namespace math { + +template +void reduce_n(const T* src, + T* dst, + int num_in, + int channel_in, + int height_in, + int width_in); + +template +void reduce_c(const T* src, + T* dst, + int num_in, + int channel_in, + int height_in, + int width_in); + +template +void reduce_h(const T* src, + T* dst, + int num_in, + int channel_in, + int height_in, + int width_in); + +template +void reduce_w(const T* src, + T* dst, + int num_in, + int channel_in, + int height_in, + int width_in); + +template +void reduce_nc(const T* src, + T* dst, + int num_in, + int channel_in, + int height_in, + int width_in); + +template +void reduce_ch(const T* src, + T* dst, + int num_in, + int channel_in, + int height_in, + int width_in); + +template +void reduce_hw(const T* src, + T* dst, + int num_in, + int channel_in, + int height_in, + int width_in); + +template +void reduce_all(const T* src, + T* dst, + int num_in, + int channel_in, + int height_in, + int width_in); + +} // namespace math +} // namespace arm +} // namespace lite +} // namespace paddle diff --git a/lite/arm/math/saturate.h b/lite/arm/math/saturate.h new file mode 100644 index 00000000000..833f0f5c1c3 --- /dev/null +++ b/lite/arm/math/saturate.h @@ -0,0 +1,320 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include + +namespace paddle { +namespace lite { +namespace arm { +namespace math { + +template +static inline _Tp saturate_cast(uint8_t v) { + return _Tp(v); +} +/** @overload */ +template +static inline _Tp saturate_cast(int8_t v) { + return _Tp(v); +} +/** @overload */ +template +static inline _Tp saturate_cast(uint16_t v) { + return _Tp(v); +} +/** @overload */ +template +static inline _Tp saturate_cast(int16_t v) { + return _Tp(v); +} +/** @overload */ +template +static inline _Tp saturate_cast(uint32_t v) { + return _Tp(v); +} +/** @overload */ +template +static inline _Tp saturate_cast(int32_t v) { + return _Tp(v); +} +/** @overload */ +template +static inline _Tp saturate_cast(float v) { + return _Tp(v); +} +/** @overload */ +template +static inline _Tp saturate_cast(double v) { + return _Tp(v); +} +/** @overload */ +template +static inline _Tp saturate_cast(int64_t v) { + return _Tp(v); +} +/** @overload */ +template +static inline _Tp saturate_cast(uint64_t v) { + return _Tp(v); +} + +template <> +inline uint8_t saturate_cast(int8_t v) { + return static_cast(std::max(static_cast(v), 0)); +} + +template <> +inline uint8_t saturate_cast(uint16_t v) { + return static_cast(std::min((unsigned)v, (unsigned)UCHAR_MAX)); +} + +template <> +inline uint8_t saturate_cast(int v) { + return static_cast( + ((unsigned)v <= UCHAR_MAX ? v : v > 0 ? UCHAR_MAX : 0)); +} + +template <> +inline uint8_t saturate_cast(int16_t v) { + return saturate_cast(static_cast(v)); +} + +template <> +inline uint8_t saturate_cast(unsigned v) { + return static_cast(std::min(v, (unsigned)UCHAR_MAX)); +} +template <> +inline uint8_t saturate_cast(float v) { + int iv = static_cast(roundf(v)); + return saturate_cast(iv); +} +template <> +inline uint8_t saturate_cast(double v) { + int iv = static_cast(round(v)); + return saturate_cast(iv); +} +template <> +inline uint8_t saturate_cast(int64_t v) { + return static_cast( + ((uint64_t)v <= (uint64_t)UCHAR_MAX ? v : v > 0 ? UCHAR_MAX : 0)); +} +template <> +inline uint8_t saturate_cast(uint64_t v) { + return static_cast(std::min(v, (uint64_t)UCHAR_MAX)); +} + +template <> +inline int8_t saturate_cast(uint8_t v) { + return static_cast(std::min(static_cast(v), SCHAR_MAX)); +} +template <> +inline int8_t saturate_cast(uint16_t v) { + return static_cast(std::min((unsigned)v, (unsigned)SCHAR_MAX)); +} +template <> +inline int8_t saturate_cast(int v) { + return static_cast(((unsigned)(v - SCHAR_MIN) <= (unsigned)UCHAR_MAX + ? v + : v > 0 ? 
SCHAR_MAX : SCHAR_MIN)); +} +template <> +inline int8_t saturate_cast(int16_t v) { + return saturate_cast(static_cast(v)); +} +template <> +inline int8_t saturate_cast(unsigned v) { + return static_cast(std::min(v, (unsigned)SCHAR_MAX)); +} +template <> +inline int8_t saturate_cast(float v) { + int iv = static_cast(roundf(v)); + return saturate_cast(iv); +} +template <> +inline int8_t saturate_cast(double v) { + int iv = static_cast(round(v)); + return saturate_cast(iv); +} +template <> +inline int8_t saturate_cast(int64_t v) { + return static_cast( + ((uint64_t)(static_cast(v) - SCHAR_MIN) <= (uint64_t)UCHAR_MAX + ? v + : v > 0 ? SCHAR_MAX : SCHAR_MIN)); +} +template <> +inline int8_t saturate_cast(uint64_t v) { + return static_cast(std::min(v, (uint64_t)SCHAR_MAX)); +} + +template <> +inline uint16_t saturate_cast(int8_t v) { + return static_cast(std::max(static_cast(v), 0)); +} + +template <> +inline uint16_t saturate_cast(int16_t v) { + return static_cast(std::max(static_cast(v), 0)); +} +template <> +inline uint16_t saturate_cast(int v) { + return static_cast( + (unsigned)v <= (unsigned)USHRT_MAX ? v : v > 0 ? USHRT_MAX : 0); +} +template <> +inline uint16_t saturate_cast(unsigned v) { + return static_cast(std::min(v, (unsigned)USHRT_MAX)); +} +template <> +inline uint16_t saturate_cast(float v) { + int iv = static_cast(roundf(v)); + return saturate_cast(iv); +} +template <> +inline uint16_t saturate_cast(double v) { + int iv = static_cast(round(v)); + return saturate_cast(iv); +} +template <> +inline uint16_t saturate_cast(int64_t v) { + return static_cast( + (uint64_t)v <= (uint64_t)USHRT_MAX ? v : v > 0 ? USHRT_MAX : 0); +} +template <> +inline uint16_t saturate_cast(uint64_t v) { + return static_cast(std::min(v, (uint64_t)USHRT_MAX)); +} + +template <> +inline int16_t saturate_cast(uint16_t v) { + return static_cast(std::min(static_cast(v), SHRT_MAX)); +} +template <> +inline int16_t saturate_cast(int v) { + return static_cast((unsigned)(v - SHRT_MIN) <= (unsigned)USHRT_MAX + ? v + : v > 0 ? SHRT_MAX : SHRT_MIN); +} +template <> +inline int16_t saturate_cast(unsigned v) { + return (int16_t)std::min(v, (unsigned)SHRT_MAX); +} +template <> +inline int16_t saturate_cast(float v) { + int iv = static_cast(roundf(v)); + return saturate_cast(iv); +} +template <> +inline int16_t saturate_cast(double v) { + int iv = static_cast(round(v)); + return saturate_cast(iv); +} +template <> +inline int16_t saturate_cast(int64_t v) { + return static_cast((uint64_t)((int64_t)v - SHRT_MIN) <= + (uint64_t)USHRT_MAX + ? v + : v > 0 ? SHRT_MAX : SHRT_MIN); +} +template <> +inline int16_t saturate_cast(uint64_t v) { + return static_cast(std::min(v, (uint64_t)SHRT_MAX)); +} + +template <> +inline int saturate_cast(unsigned v) { + return static_cast(std::min(v, (unsigned)INT_MAX)); +} +template <> +inline int saturate_cast(int64_t v) { + return static_cast((uint64_t)(v - INT_MIN) <= (uint64_t)UINT_MAX + ? v + : v > 0 ? 
INT_MAX : INT_MIN); +} +template <> +inline int saturate_cast(uint64_t v) { + return static_cast(std::min(v, (uint64_t)INT_MAX)); +} +template <> +inline int saturate_cast(float v) { + return static_cast(roundf(v)); +} +template <> +inline int saturate_cast(double v) { + return static_cast(round(v)); +} + +template <> +inline unsigned saturate_cast(int8_t v) { + return static_cast(std::max(v, static_cast(0))); +} +template <> +inline unsigned saturate_cast(int16_t v) { + return static_cast(std::max(v, (int16_t)0)); +} +template <> +inline unsigned saturate_cast(int v) { + return static_cast(std::max(v, static_cast(0))); +} +template <> +inline unsigned saturate_cast(int64_t v) { + return static_cast( + (uint64_t)v <= (uint64_t)UINT_MAX ? v : v > 0 ? UINT_MAX : 0); +} +template <> +inline unsigned saturate_cast(uint64_t v) { + return static_cast(std::min(v, (uint64_t)UINT_MAX)); +} +// we intentionally do not clip negative numbers, to make -1 become 0xffffffff +// etc. +template <> +inline unsigned saturate_cast(float v) { + return static_cast(roundf(v)); +} +template <> +inline unsigned saturate_cast(double v) { + return static_cast(round(v)); +} + +template <> +inline uint64_t saturate_cast(int8_t v) { + return static_cast(std::max(v, static_cast(0))); +} + +template <> +inline uint64_t saturate_cast(int16_t v) { + return static_cast(std::max(v, (int16_t)0)); +} +template <> +inline uint64_t saturate_cast(int v) { + return static_cast(std::max(v, static_cast(0))); +} +template <> +inline uint64_t saturate_cast(int64_t v) { + return static_cast(std::max(v, (int64_t)0)); +} + +template <> +inline int64_t saturate_cast(uint64_t v) { + return static_cast(std::min(v, (uint64_t)LLONG_MAX)); +} + +} // namespace math +} // namespace arm +} // namespace lite +} // namespace paddle diff --git a/lite/arm/math/scale.cc b/lite/arm/math/scale.cc new file mode 100644 index 00000000000..23036a7e1d1 --- /dev/null +++ b/lite/arm/math/scale.cc @@ -0,0 +1,177 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
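The specializations above all follow the same pattern: round floating-point inputs to the nearest integer, then clamp to the destination type's range. A few illustrative expectations (assuming this header is included; `<cassert>` supplies `assert`):

#include <cassert>

using paddle::lite::arm::math::saturate_cast;

void saturate_cast_examples() {
  assert(saturate_cast<uint8_t>(300) == 255);      // clamps above UCHAR_MAX
  assert(saturate_cast<uint8_t>(-5) == 0);         // clamps below zero
  assert(saturate_cast<int8_t>(200) == 127);       // clamps to SCHAR_MAX
  assert(saturate_cast<int8_t>(-1.6f) == -2);      // round to nearest, then clamp
  assert(saturate_cast<int16_t>(40000) == 32767);  // clamps to SHRT_MAX
}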
+ +#include "lite/arm/math/scale.h" +#include "lite/arm/math/funcs.h" + +namespace paddle { +namespace lite { +namespace arm { +namespace math { + +template <> +void scale( + const float* din, float* dout, int num, float scale, float bias) { + int cnt = num >> 4; + int remain = num % 16; + float32x4_t vscale = vdupq_n_f32(scale); + float32x4_t vbias = vdupq_n_f32(bias); +#pragma omp parallel for + for (int i = 0; i < cnt; i++) { + const float* din_ptr = din + (i << 4); + float* dout_ptr = dout + (i << 4); + + float32x4_t din0 = vld1q_f32(din_ptr); + float32x4_t din1 = vld1q_f32(din_ptr + 4); + float32x4_t din2 = vld1q_f32(din_ptr + 8); + float32x4_t din3 = vld1q_f32(din_ptr + 12); + + float32x4_t vsum1 = vmlaq_f32(vbias, din0, vscale); + float32x4_t vsum2 = vmlaq_f32(vbias, din1, vscale); + float32x4_t vsum3 = vmlaq_f32(vbias, din2, vscale); + float32x4_t vsum4 = vmlaq_f32(vbias, din3, vscale); + + vst1q_f32(dout_ptr, vsum1); + vst1q_f32(dout_ptr + 4, vsum2); + vst1q_f32(dout_ptr + 8, vsum3); + vst1q_f32(dout_ptr + 12, vsum4); + } + if (remain > 0) { + const float* din_ptr = din + (cnt << 4); + float* dout_ptr = dout + (cnt << 4); + for (int i = 0; i < remain; i++) { + *dout_ptr = *din_ptr * scale + bias; + dout_ptr++; + din_ptr++; + } + } +} + +template <> +void scale(const float* din, + float* dout, + int outer_dim, + int scale_dim, + int inner_dim, + const float* scale_data, + const float* bias_data) { + int cnt = inner_dim >> 4; + int remain = inner_dim % 16; + int size = inner_dim * scale_dim; + for (int n = 0; n < outer_dim; n++) { + const float* din_ptr_n = din + n * size; + float* dout_ptr_n = dout + n * size; +#pragma omp parallel for + for (int i = 0; i < scale_dim; i++) { + const float* din_ptr = din_ptr_n + i * inner_dim; + float* dout_ptr = dout_ptr_n + i * inner_dim; + float scale = scale_data[i]; + float32x4_t vscale = vdupq_n_f32(scale); + float bias = bias_data[i]; + float32x4_t vbias = vdupq_n_f32(bias); + for (int j = 0; j < cnt; j++) { + float32x4_t din0 = vld1q_f32(din_ptr); + float32x4_t din1 = vld1q_f32(din_ptr + 4); + float32x4_t din2 = vld1q_f32(din_ptr + 8); + float32x4_t din3 = vld1q_f32(din_ptr + 12); + + float32x4_t vsum1 = vmlaq_f32(vbias, din0, vscale); + float32x4_t vsum2 = vmlaq_f32(vbias, din1, vscale); + float32x4_t vsum3 = vmlaq_f32(vbias, din2, vscale); + float32x4_t vsum4 = vmlaq_f32(vbias, din3, vscale); + + din_ptr += 16; + vst1q_f32(dout_ptr, vsum1); + vst1q_f32(dout_ptr + 4, vsum2); + vst1q_f32(dout_ptr + 8, vsum3); + vst1q_f32(dout_ptr + 12, vsum4); + + dout_ptr += 16; + } + for (int j = 0; j < remain; j++) { + *dout_ptr = *din_ptr * scale + bias; + dout_ptr++; + din_ptr++; + } + } + } +} + +template <> +void scale(const float* din, + float* dout, + int outer_dim, + int scale_dim, + const float* scale_data, + const float* bias_data) { + int cnt = scale_dim >> 4; + int remain = scale_dim % 16; + for (int n = 0; n < outer_dim; n++) { + const float* din_ptr_n = din + n * scale_dim; + float* dout_ptr_n = dout + n * scale_dim; +#pragma omp parallel for + for (int i = 0; i < cnt; i++) { + int idx = i << 4; + const float* din_ptr = din_ptr_n + idx; + const float* scale_ptr = scale_data + idx; + const float* bias_ptr = bias_data + idx; + float* dout_ptr = dout_ptr_n + idx; + + float32x4_t din0 = vld1q_f32(din_ptr); + float32x4_t vscale0 = vld1q_f32(scale_ptr); + float32x4_t vbias0 = vld1q_f32(bias_ptr); + + float32x4_t din1 = vld1q_f32(din_ptr + 4); + float32x4_t vscale1 = vld1q_f32(scale_ptr + 4); + float32x4_t vbias1 = vld1q_f32(bias_ptr + 4); + + 
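+      // The din2/din3 loads below are deliberately interleaved with the
+      // multiply-adds and stores so memory accesses overlap computation.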
float32x4_t din2 = vld1q_f32(din_ptr + 8); + float32x4_t vscale2 = vld1q_f32(scale_ptr + 8); + float32x4_t vbias2 = vld1q_f32(bias_ptr + 8); + + float32x4_t vsum1 = vmlaq_f32(vbias0, din0, vscale0); + float32x4_t vsum2 = vmlaq_f32(vbias1, din1, vscale1); + + float32x4_t din3 = vld1q_f32(din_ptr + 12); + float32x4_t vscale3 = vld1q_f32(scale_ptr + 12); + float32x4_t vbias3 = vld1q_f32(bias_ptr + 12); + + vst1q_f32(dout_ptr, vsum1); + vst1q_f32(dout_ptr + 4, vsum2); + + float32x4_t vsum3 = vmlaq_f32(vbias2, din2, vscale2); + float32x4_t vsum4 = vmlaq_f32(vbias3, din3, vscale3); + + vst1q_f32(dout_ptr + 8, vsum3); + vst1q_f32(dout_ptr + 12, vsum4); + } + int idx = cnt << 4; + const float* din_ptr = din_ptr_n + idx; + float* dout_ptr = dout_ptr_n + idx; + const float* scale_ptr = scale_data + idx; + const float* bias_ptr = bias_data + idx; + for (int j = 0; j < remain; j++) { + *dout_ptr = *din_ptr * (*scale_ptr) + (*bias_ptr); + dout_ptr++; + din_ptr++; + scale_ptr++; + bias_ptr++; + } + } +} + +} // namespace math +} // namespace arm +} // namespace lite +} // namespace paddle diff --git a/lite/arm/math/scale.h b/lite/arm/math/scale.h new file mode 100644 index 00000000000..a86528c9df1 --- /dev/null +++ b/lite/arm/math/scale.h @@ -0,0 +1,45 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +namespace paddle { +namespace lite { +namespace arm { +namespace math { + +template +void scale(const T* din, T* dout, int num, float scale, float bias); + +template +void scale(const T* din, + T* dout, + int outer_dim, + int scale_dim, + int inner_dim, + const float* scale_data, + const float* bias_data); + +template +void scale(const T* din, + T* dout, + int outer_dim, + int scale_dim, + const float* scale_data, + const float* bias_data); + +} // namespace math +} // namespace arm +} // namespace lite +} // namespace paddle diff --git a/lite/arm/math/sequence2batch.h b/lite/arm/math/sequence2batch.h new file mode 100644 index 00000000000..d982ad66676 --- /dev/null +++ b/lite/arm/math/sequence2batch.h @@ -0,0 +1,210 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
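For reference, a minimal use of the first `scale` overload declared above, which computes dout[i] = din[i] * scale + bias elementwise (values are illustrative):

float din[6] = {0.f, 1.f, 2.f, 3.f, 4.f, 5.f};
float dout[6];
paddle::lite::arm::math::scale<float>(din, dout, 6, /*scale=*/2.f, /*bias=*/1.f);
// dout is now {1, 3, 5, 7, 9, 11}.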
+ +#pragma once + +#include <algorithm> +#include <vector> +#include "lite/core/tensor.h" + +namespace paddle { +namespace lite { +namespace arm { +namespace math { + +template <typename T> +class CopyMatrixRowsFunctor { + public: + // If is_src_index is true, + // copy the indexed rows of input src to the output dst. + // If is_src_index is false, + // copy the input src to the indexed rows of output dst. + // The indexed rows are based on the input index. + void operator()(const Tensor& src, + std::vector<uint64_t> index_lod, + Tensor* dst, + bool is_src_index) { + auto index = index_lod.data(); + auto src_dims = src.dims(); + auto dst_dims = dst->dims(); + CHECK_EQ(src_dims.size(), 2UL) << "The src must be matrix with rank 2."; + CHECK_EQ(dst_dims.size(), 2UL) << "The dst must be matrix with rank 2."; + CHECK_EQ(src_dims[1], dst_dims[1]) + << "The width of src and dst must be same."; + auto height = dst_dims[0]; + auto width = dst_dims[1]; + auto* src_data = src.data<T>(); + auto* dst_data = dst->mutable_data<T>(); + const int sz = width * sizeof(T); + if (is_src_index) { + for (int i = 0; i < height; ++i) { + TargetCopy(TARGET(kARM), + dst_data + i * width, + src_data + index[i] * width, + sz); + } + } else { + for (int i = 0; i < height; ++i) { + TargetCopy(TARGET(kARM), + dst_data + index[i] * width, + src_data + i * width, + sz); + } + } + } +}; + +template <typename T> +class LoDTensor2BatchFunctor { + // Calculate the length of each sequence and + // sort sequence index by the length. + // example: sequences = {s0, s1, s2} + // s0: 0 0 0 0, s1: 1 1 1 1 1, s2: 2 2 2 + // seq_info[3] = {(4, 5, 1), (0, 4, 0), (9, 3, 2)} + // + struct SeqInfo { + SeqInfo(int start, int length, int seq_idx) + : start(start), length(length), seq_idx(seq_idx) {} + int start; + int length; + int seq_idx; + }; + + public: + void operator()(const Tensor& lod_tensor, + Tensor* batch, + bool is_cal_batch_lod, + bool is_reverse = false) const { + if (!is_cal_batch_lod) { + auto lods = batch->lod(); + CHECK_GT(lods.size(), 2UL) + << "The LoD of LoDTensor should include at least 2-level " + "sequence information."; + CHECK_EQ(lods[1].size(), static_cast<size_t>(lod_tensor.dims()[0])) + << "The LoD information should be consistent with the dims."; + CopyMatrixRowsFunctor<T> to_batch; + to_batch(lod_tensor, lods[1], batch, true); + return; + } + + auto lods = lod_tensor.lod(); + CHECK_EQ(lods.size(), 1UL) << "Only support one level sequence now."; + + const auto& lod = lods[0]; + + std::vector<SeqInfo> seq_info; + for (size_t seq_id = 0; seq_id < lod.size() - 1; ++seq_id) { + int length = lod[seq_id + 1] - lod[seq_id]; + seq_info.emplace_back(lod[seq_id], length, seq_id); + } + + std::sort(seq_info.begin(), seq_info.end(), [](SeqInfo a, SeqInfo b) { + return a.length > b.length; + }); + + // Calculate the start position of each batch. + // example: sequences = {s0, s1, s2} + // s0: 0 0 0 0, s1: 1 1 1 1 1, s2: 2 2 2 + // max_seqlen = 5, + // batchIndex = {b0, b1, b2, b3, b4} + // b0: 1 0 2, b1: 1 0 2, b2: 1 0 2, b3: 1 0, b4: 1 + // batch_start_positions[6] = {0, 3, 6, 9, 11, 12} + // batch_start_positions[0] = 0 + // batch_start_positions[1] = len(b0) + // batch_start_positions[2] = len(b0) + len(b1) + // ... + // seq2batch_idx[12] = {4, 0, 9, + // 5, 1, 10, + // 6, 2, 11, + // 7, 3, + // 8} + // seq_order = {1, 0, 2}, the sort order. + // where 1 is the second sequence, + // 0 is the first sequence, + // 2 is the third sequence. + // The max_seqlen represents batch size after rearranging the + // input LodTensor. It is also the maximum length of input sequence.
+ + LoD batch_lods; + batch_lods.emplace_back(std::vector<uint64_t>{0}); + batch_lods.emplace_back(std::vector<uint64_t>{0}); + batch_lods.emplace_back(std::vector<uint64_t>{0}); + + // batch_lods[0] is the start positions for batch LoDTensor + int max_seqlen = seq_info[0].length; + batch_lods[0].resize(static_cast<size_t>(max_seqlen + 1)); + // batch_lods[1] is the raw index in the input LoDTensor + batch_lods[1].resize(static_cast<size_t>(lod_tensor.dims()[0])); + // batch_lods[2] is the sort order for the input LoDTensor. + batch_lods[2].resize(seq_info.size()); + + auto batch_starts = batch_lods[0].data(); + auto seq2batch_idx = batch_lods[1].data(); + batch_starts[0] = 0; + for (int n = 0; n < max_seqlen; n++) { + auto batch_id = static_cast<int>(batch_starts[n]); + for (size_t i = 0; i < seq_info.size(); ++i) { + int seq_len = seq_info[i].length; + int start = seq_info[i].start; + if (n < seq_len) { + seq2batch_idx[batch_id] = + is_reverse ? start + seq_len - 1 - n : start + n; + batch_id++; + } else { + break; + } + } + batch_starts[n + 1] = static_cast<uint64_t>(batch_id); + } + auto seq_order = batch_lods[2].data(); + for (size_t i = 0; i < seq_info.size(); ++i) { + seq_order[i] = seq_info[i].seq_idx; + } + *(batch->mutable_lod()) = batch_lods; + + CopyMatrixRowsFunctor<T> to_batch; + to_batch(lod_tensor, batch_lods[1], batch, true); + } +}; + +template <typename T> +class Batch2LoDTensorFunctor { + public: + void operator()(const Tensor& batch, Tensor* lod_tensor) const { + auto in_lod = batch.lod(); + CHECK_GT(in_lod.size(), 2UL) + << "The LoD of LoDTensor should include at least 2-level " + "sequence information."; + CHECK_EQ(in_lod[1].size(), static_cast<size_t>(lod_tensor->dims()[0])) + << "The LoD information should be consistent with the dims."; + CopyMatrixRowsFunctor<T> to_seq; + to_seq(batch, in_lod[1], lod_tensor, false); + } +}; + +template <typename T> +inline void ReorderInitState(const Tensor& src, + const std::vector<uint64_t>& index_lod, + Tensor* dst, + bool indexed_src) { + CopyMatrixRowsFunctor<T> row_shuffle; + dst->Resize(src.dims()); + dst->mutable_data<T>(); + row_shuffle(src, index_lod, dst, indexed_src); +} + +} // namespace math +} // namespace arm +} // namespace lite +} // namespace paddle diff --git a/lite/arm/math/sequence_expand.cc b/lite/arm/math/sequence_expand.cc new file mode 100644 index 00000000000..0048ad74e30 --- /dev/null +++ b/lite/arm/math/sequence_expand.cc @@ -0,0 +1,63 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License.
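A small standalone sketch (the helper name is mine) of how the batch start positions above fall out of the sorted sequence lengths; for the documented example lengths {4, 5, 3} it reproduces {0, 3, 6, 9, 11, 12}:

#include <algorithm>
#include <functional>
#include <vector>

std::vector<int> batch_starts_for(std::vector<int> lengths) {
  std::sort(lengths.begin(), lengths.end(), std::greater<int>());
  std::vector<int> starts{0};
  for (int n = 0; n < lengths[0]; ++n) {
    int cnt = 0;  // how many sequences still have an n-th step
    for (int len : lengths) cnt += (n < len) ? 1 : 0;
    starts.push_back(starts.back() + cnt);
  }
  return starts;  // {0, 3, 6, 9, 11, 12} for lengths {4, 5, 3}
}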
+ +#include "lite/arm/math/sequence_expand.h" +#include +#include +#include "lite/core/tensor.h" + +namespace paddle { +namespace lite { +namespace arm { +namespace math { + +template <> +void SequenceExpandImpl(const float* x_data, + const LoD& x_lod, + int width, + const std::vector& ref_lod, + lite::Tensor* output) { + float* output_data = output->mutable_data(); + if (x_lod.size() == 0) { + for (int i = 0; i < ref_lod.size() - 1; i++) { + for (int j = ref_lod[i]; j < ref_lod[i + 1]; j++) { + memcpy( + output_data + j * width, x_data + i * width, sizeof(float) * width); + } + } + (output->mutable_lod())->push_back(ref_lod); + } else { + std::vector out_lod; + out_lod.push_back(0); + uint64_t out_offset = 0; + uint64_t len = 0; + for (int i = 0; i < ref_lod.size() - 1; i++) { + auto x_seq_len = x_lod[0][i + 1] - x_lod[0][i]; + for (int j = ref_lod[i]; j < ref_lod[i + 1]; j++) { + memcpy(output_data + out_offset * width, + x_data + len * width, + width * sizeof(float) * x_seq_len); + out_offset += x_seq_len; + out_lod.push_back(out_offset); + } + len += x_seq_len; + } + (output->mutable_lod())->push_back(out_lod); + } +} + +} // namespace math +} // namespace arm +} // namespace lite +} // namespace paddle diff --git a/lite/arm/math/sequence_expand.h b/lite/arm/math/sequence_expand.h new file mode 100644 index 00000000000..d3b19a4c626 --- /dev/null +++ b/lite/arm/math/sequence_expand.h @@ -0,0 +1,35 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include "lite/core/tensor.h" + +#pragma once + +namespace paddle { +namespace lite { +namespace arm { +namespace math { + +template +void SequenceExpandImpl(const T* x_data, + const LoD& x_lod, + int width, + const std::vector& ref_lod, + lite::Tensor* output); + +} // namespace math +} // namespace arm +} // namespace lite +} // namespace paddle diff --git a/lite/arm/math/sequence_pool.cc b/lite/arm/math/sequence_pool.cc new file mode 100644 index 00000000000..a0372101041 --- /dev/null +++ b/lite/arm/math/sequence_pool.cc @@ -0,0 +1,224 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
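To make the expansion rule of SequenceExpandImpl above concrete: with x rows {r0, r1, r2, r3}, x_lod = {{0, 1, 4}} and ref_lod = {0, 2, 5}, the first sequence (r0) is repeated twice and the second (r1 r2 r3) three times, so the output rows are r0 r0 r1 r2 r3 r1 r2 r3 r1 r2 r3 and the output LoD becomes {0, 1, 2, 5, 8, 11}.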
+ +#include "lite/arm/math/sequence_pool.h" +#include +#include +#include +#include +#include "lite/arm/math/funcs.h" +#include "lite/core/op_registry.h" +#include "lite/core/tensor.h" +#include "lite/core/type_system.h" + +namespace paddle { +namespace lite { +namespace arm { +namespace math { + +template <> +void seq_pool_sum(const float* din, + float* dout, + const std::vector lod, + int64_t width) { + for (int i = 0; i < static_cast(lod.size()) - 1; ++i) { + const float* din_ptr = din + lod[i] * width; + float* dout_ptr = dout + i * width; + int64_t height = static_cast(lod[i + 1] - lod[i]); + if (width == 1) { + float sum = 0.f; + for (int h = 0; h < height; ++h) { + sum += din_ptr[h]; + } + *dout_ptr = sum; + } else { + memcpy(dout_ptr, din_ptr, width * sizeof(float)); + din_ptr += width; + height = height - 1; + for (int h = 0; h < height; h++) { + for (int w = 0; w < width; ++w) { + dout_ptr[w] += din_ptr[w]; + } + din_ptr += width; + } + } + } +} + +template <> +void seq_pool_average(const float* din, + float* dout, + const std::vector lod, + int64_t width) { + for (int i = 0; i < static_cast(lod.size()) - 1; ++i) { + const float* din_ptr = din + lod[i] * width; + float* dout_ptr = dout + i * width; + int64_t height = static_cast(lod[i + 1] - lod[i]); + if (height > 0) { + if (width == 1) { + float sum = 0.f; + for (int h = 0; h < height; ++h) { + sum += din_ptr[h]; + } + *dout_ptr = sum / height; + } else { + memcpy(dout_ptr, din_ptr, width * sizeof(float)); + din_ptr += width; + int remain_h = height - 1; + for (int h = 0; h < remain_h; h++) { + for (int w = 0; w < width; ++w) { + dout_ptr[w] += din_ptr[w]; + } + din_ptr += width; + } + for (int w = 0; w < width; ++w) { + dout_ptr[w] /= height; + } + } + } + } +} + +template <> +void seq_pool_sqrt(const float* din, + float* dout, + const std::vector lod, + int64_t width) { + for (int i = 0; i < static_cast(lod.size()) - 1; ++i) { + const float* din_ptr = din + lod[i] * width; + float* dout_ptr = dout + i * width; + int64_t height = static_cast(lod[i + 1] - lod[i]); + if (height > 0) { + float sqrt_len = sqrtf(height); + if (width == 1) { + float sum = 0.f; + for (int h = 0; h < height; ++h) { + sum += din_ptr[h]; + } + *dout_ptr = sum / sqrt_len; + } else { + memcpy(dout_ptr, din_ptr, width * sizeof(float)); + din_ptr += width; + int remain_h = height - 1; + for (int h = 0; h < remain_h; h++) { + for (int w = 0; w < width; ++w) { + dout_ptr[w] += din_ptr[w]; + } + din_ptr += width; + } + for (int w = 0; w < width; ++w) { + dout_ptr[w] /= sqrt_len; + } + } + } + } +} + +template <> +void seq_pool_max(const float* din, + float* dout, + const std::vector lod, + int64_t width) { + for (int i = 0; i < static_cast(lod.size()) - 1; ++i) { + const float* din_ptr = din + lod[i] * width; + float* dout_ptr = dout + i * width; + int64_t height = static_cast(lod[i + 1] - lod[i]); + if (height > 0) { + if (width == 1) { + float max = -std::numeric_limits::max(); + for (int h = 0; h < height; ++h) { + max = std::max(max, din_ptr[h]); + } + *dout_ptr = max; + } else { + memcpy(dout_ptr, din_ptr, width * sizeof(float)); + din_ptr += width; + int remain_h = height - 1; + for (int h = 0; h < remain_h; h++) { + for (int w = 0; w < width; w++) { + dout_ptr[w] = std::max(dout_ptr[w], din_ptr[w]); + } + din_ptr += width; + } + } + } + } +} + +template <> +void seq_pool_min(const float* din, + float* dout, + const std::vector lod, + int64_t width) { + for (int i = 0; i < static_cast(lod.size()) - 1; ++i) { + const float* din_ptr = din + lod[i] * width; 
+ float* dout_ptr = dout + i * width; + int64_t height = static_cast(lod[i + 1] - lod[i]); + if (height > 0) { + if (width == 1) { + float min = std::numeric_limits::max(); + for (int h = 0; h < height; ++h) { + min = std::min(min, din_ptr[h]); + } + *dout_ptr = min; + } else { + memcpy(dout_ptr, din_ptr, width * sizeof(float)); + din_ptr += width; + int remain_h = height - 1; + for (int h = 0; h < remain_h; h++) { + for (int w = 0; w < width; w++) { + dout_ptr[w] = std::min(dout_ptr[w], din_ptr[w]); + } + din_ptr += width; + } + } + } + } +} + +template <> +void seq_pool_first(const float* din, + float* dout, + const std::vector lod, + int64_t width) { + for (int i = 0; i < static_cast(lod.size()) - 1; ++i) { + int64_t height = lod[i + 1] - lod[i]; + const float* din_ptr = din + width * lod[i]; + float* dout_ptr = dout + i * width; + if (height > 0) { + memcpy(dout_ptr, din_ptr, width * sizeof(float)); + } + } +} + +template <> +void seq_pool_last(const float* din, + float* dout, + const std::vector lod, + int64_t width) { + for (int i = 0; i < static_cast(lod.size()) - 1; ++i) { + int64_t height = lod[i + 1] - lod[i]; + int64_t seq_len = static_cast(lod[i + 1] - lod[0]); + const float* din_ptr = din + width * seq_len; + float* dout_ptr = dout + i * width; + if (height > 0) { + memcpy(dout_ptr, din_ptr - width, width * sizeof(float)); + } + } +} + +} // namespace math +} // namespace arm +} // namespace lite +} // namespace paddle diff --git a/lite/arm/math/sequence_pool.h b/lite/arm/math/sequence_pool.h new file mode 100644 index 00000000000..6cbcd7d6d60 --- /dev/null +++ b/lite/arm/math/sequence_pool.h @@ -0,0 +1,69 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include +#include "lite/core/tensor.h" + +namespace paddle { +namespace lite { +namespace arm { +namespace math { + +template +void seq_pool_sum(const T* din, + T* dout, + const std::vector lod, + int64_t width); + +template +void seq_pool_average(const T* din, + T* dout, + const std::vector lod, + int64_t width); + +template +void seq_pool_sqrt(const T* din, + T* dout, + const std::vector lod, + int64_t width); + +template +void seq_pool_max(const T* din, + T* dout, + const std::vector lod, + int64_t width); + +template +void seq_pool_min(const T* din, + T* dout, + const std::vector lod, + int64_t width); + +template +void seq_pool_first(const T* din, + T* dout, + const std::vector lod, + int64_t width); + +template +void seq_pool_last(const T* din, + T* dout, + const std::vector lod, + int64_t width); + +} // namespace math +} // namespace arm +} // namespace lite +} // namespace paddle diff --git a/lite/arm/math/sequence_softmax.cc b/lite/arm/math/sequence_softmax.cc new file mode 100644 index 00000000000..d2be72639ee --- /dev/null +++ b/lite/arm/math/sequence_softmax.cc @@ -0,0 +1,49 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/arm/math/sequence_softmax.h" +#include <algorithm> +#include <cmath> +#include <vector> +#include "lite/utils/cp_logging.h" + +namespace paddle { +namespace lite { +namespace arm { +namespace math { + +bool sequence_softmax(const float* input, + const std::vector<uint64_t>& seq_offset, + float* out, + Context<TARGET(kARM)>* ctx) { + int seq_num = seq_offset.size() - 1; + for (int i = 0; i < seq_num; i++) { + float seq_max = input[seq_offset[i]]; + float exp_sum = 0.f; + for (int j = seq_offset[i]; j < seq_offset[i + 1]; j++) { + seq_max = std::max(seq_max, input[j]); + } + for (int j = seq_offset[i]; j < seq_offset[i + 1]; j++) { + exp_sum += expf(input[j] - seq_max); + } + for (int j = seq_offset[i]; j < seq_offset[i + 1]; j++) { + out[j] = expf(input[j] - seq_max) / exp_sum; + } + } + return true; +} + +} // namespace math +} // namespace arm +} // namespace lite +} // namespace paddle diff --git a/lite/arm/math/sequence_softmax.h b/lite/arm/math/sequence_softmax.h new file mode 100644 index 00000000000..2923039b0c0 --- /dev/null +++ b/lite/arm/math/sequence_softmax.h @@ -0,0 +1,34 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include <cstdint> +#include <vector> +#include "lite/core/context.h" + +namespace paddle { +namespace lite { +namespace arm { +namespace math { + +bool sequence_softmax(const float* input, + const std::vector<uint64_t>& seq_offset, + float* out, + Context<TARGET(kARM)>* ctx); + +} // namespace math +} // namespace arm +} // namespace lite +} // namespace paddle diff --git a/lite/arm/math/sgemm.cc b/lite/arm/math/sgemm.cc new file mode 100644 index 00000000000..bea1ac633ae --- /dev/null +++ b/lite/arm/math/sgemm.cc @@ -0,0 +1,68 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License.
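An illustrative call of `sequence_softmax` above: with offsets {0, 2, 5} the softmax is computed independently over [0, 2) and [2, 5). The implementation shown never dereferences `ctx`, so a null context is passed here purely for illustration.

std::vector<uint64_t> offsets = {0, 2, 5};
float in[5] = {1.f, 1.f, 0.f, 0.f, 0.f};
float out[5];
paddle::lite::arm::math::sequence_softmax(in, offsets, out, /*ctx=*/nullptr);
// out == {0.5, 0.5, 1/3, 1/3, 1/3}: each sequence is normalized on its own.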
+ +#include "lite/arm/math/sgemm.h" + +namespace paddle { +namespace lite { +namespace arm { +namespace math { + +void sgemm(bool is_transA, + bool is_transB, + int M, + int N, + int K, + float alpha, + const float* A, + int lda, + const float* B, + int ldb, + float beta, + float* C, + int ldc, + const float* bias, + bool is_bias, + bool is_relu, + ARMContext* ctx) { + auto arch = ctx->arch(); + int hblock = get_hblock(arch); + int m_roundup = hblock * ((M + hblock - 1) / hblock); + + auto packed_A = static_cast( + TargetMalloc(TargetType::kARM, m_roundup * K * sizeof(float))); + + prepackA(packed_A, A, alpha, lda, 0, M, 0, K, is_transA, ctx); + + sgemm_prepack(is_transB, + M, + N, + K, + packed_A, + B, + ldb, + beta, + C, + ldc, + bias, + is_bias, + is_relu, + ctx); + TargetFree(TargetType::kARM, packed_A); +} + +} // namespace math +} // namespace arm +} // namespace lite +} // namespace paddle diff --git a/lite/arm/math/sgemm.h b/lite/arm/math/sgemm.h new file mode 100644 index 00000000000..63d4a8e5b60 --- /dev/null +++ b/lite/arm/math/sgemm.h @@ -0,0 +1,48 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include "lite/arm/math/packed_sgemm.h" +#include "lite/core/context.h" +#include "lite/core/cpu_info.h" + +namespace paddle { +namespace lite { +namespace arm { +namespace math { + +void sgemm(bool is_transA, + bool is_transB, + int M, + int N, + int K, + float alpha, + const float* A, + int lda, + const float* B, + int ldb, + float beta, + float* C, + int ldc, + const float* bias, + bool is_bias, + bool is_relu, + ARMContext* ctx); + +} // namespace math +} // namespace arm +} // namespace lite +} // namespace paddle diff --git a/lite/arm/math/sgemv.cc b/lite/arm/math/sgemv.cc new file mode 100644 index 00000000000..d1449a88ae0 --- /dev/null +++ b/lite/arm/math/sgemv.cc @@ -0,0 +1,1054 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "lite/arm/math/sgemv.h" +#include +#include "lite/utils/cp_logging.h" + +namespace paddle { +namespace lite { +namespace arm { +namespace math { + +void sgemv(const bool transA, + const int M, + const int N, + const float *A, + const float *x, + float *y); + +void sgemv_relu(const bool transA, + const int M, + const int N, + const float *A, + const float *x, + float *y); + +void sgemv_bias(const bool transA, + const int M, + const int N, + const float *A, + const float *x, + float *y, + const float *bias); + +void sgemv_bias_relu(const bool transA, + const int M, + const int N, + const float *A, + const float *x, + float *y, + const float *bias); + +bool sgemv(const float *A, + const float *x, + float *y, + bool transA, + int M, + int N, + bool is_bias, + const float *bias, + bool is_relu) { + if (transA) { + LOG(ERROR) << " sgemv, transA is not supported now"; + return false; + } + if (is_bias) { + //! with bias + if (is_relu) { + //! with relu + sgemv_bias_relu(transA, M, N, A, x, y, bias); + } else { + //! without relu + sgemv_bias(transA, M, N, A, x, y, bias); + } + } else { + //! without bias + if (is_relu) { + //! with relu + sgemv_relu(transA, M, N, A, x, y); + } else { + //! without relu + sgemv(transA, M, N, A, x, y); + } + } + return true; +} + +//! define compute kernel +#ifdef __aarch64__ +#define SGEMV_IN_8 \ + "prfm pldl1keep, [%[in]] \n" /* preload din */ \ + "prfm pldl1keep, [%[w0]] \n" /* preload w0 */ \ + "prfm pldl1keep, [%[w1]] \n" /* preload w1 */ \ + "prfm pldl1keep, [%[w2]] \n" /* preload w2 */ \ + "prfm pldl1keep, [%[w3]] \n" /* preload w3 */ \ + "prfm pldl1keep, [%[w4]] \n" /* preload w4 */ \ + "prfm pldl1keep, [%[w5]] \n" /* preload w5 */ \ + "prfm pldl1keep, [%[w6]] \n" /* preload w6 */ \ + "prfm pldl1keep, [%[w7]] \n" /* preload w7 */ \ + "movi v0.4s, #0 \n" /* set out0 to 0 */ \ + "movi v1.4s, #0 \n" /* set out1 to 0 */ \ + "movi v2.4s, #0 \n" /* set out2 to 0 */ \ + "movi v3.4s, #0 \n" /* set out3 to 0 */ \ + "movi v4.4s, #0 \n" /* set out4 to 0 */ \ + "movi v5.4s, #0 \n" /* set out5 to 0 */ \ + "movi v6.4s, #0 \n" /* set out6 to 0 */ \ + "movi v7.4s, #0 \n" /* set out7 to 0 */ + +#define SGEMV_IN_8_BIAS \ + "ldp q8, q9, [%[bias_ptr]]\n" /* load bias to q8, q9*/ \ + "prfm pldl1keep, [%[in]] \n" /* preload din */ \ + "prfm pldl1keep, [%[w0]] \n" /* preload w0 */ \ + "prfm pldl1keep, [%[w1]] \n" /* preload w1 */ \ + "prfm pldl1keep, [%[w2]] \n" /* preload w2 */ \ + "prfm pldl1keep, [%[w3]] \n" /* preload w3 */ \ + "prfm pldl1keep, [%[w4]] \n" /* preload w4 */ \ + "prfm pldl1keep, [%[w5]] \n" /* preload w5 */ \ + "prfm pldl1keep, [%[w6]] \n" /* preload w6 */ \ + "prfm pldl1keep, [%[w7]] \n" /* preload w7 */ \ + "movi v0.4s, #0 \n" /* set out0 to 0 */ \ + "movi v1.4s, #0 \n" /* set out1 to 0 */ \ + "movi v2.4s, #0 \n" /* set out2 to 0 */ \ + "movi v3.4s, #0 \n" /* set out3 to 0 */ \ + "movi v4.4s, #0 \n" /* set out4 to 0 */ \ + "movi v5.4s, #0 \n" /* set out5 to 0 */ \ + "movi v6.4s, #0 \n" /* set out6 to 0 */ \ + "movi v7.4s, #0 \n" /* set out7 to 0 */ \ + "ins v0.s[0], v8.s[0] \n" /* out0 = bias0 */ \ + "ins v1.s[0], v8.s[1] \n" /* out1 = bias1 */ \ + "ins v2.s[0], v8.s[2] \n" /* out2 = bias2 */ \ + "ins v3.s[0], v8.s[3] \n" /* out3 = bias3 */ \ + "ins v4.s[0], v9.s[0] \n" /* out4 = bias4 */ \ + "ins v5.s[0], v9.s[1] \n" /* out5 = bias5 */ \ + "ins v6.s[0], v9.s[2] \n" /* out6 = bias6 */ \ + "ins v7.s[0], v9.s[3] \n" /* out7 = bias7 */ + +#define SGEMV_IN_1 \ + "prfm pldl1keep, [%[in]] \n" /* preload din */ \ + "prfm pldl1keep, [%[w0]] \n" /* preload 
w0 */ \ + "movi v0.4s, #0 \n" /* set out0 to 0 */ \ + "movi v1.4s, #0 \n" /* set out0 to 0 */ + +#define SGEMV_IN_1_BIAS \ + "prfm pldl1keep, [%[in]] \n" /* preload din */ \ + "prfm pldl1keep, [%[w0]] \n" /* preload w0 */ \ + "movi v0.4s, #0 \n" /* set out0 to 0 */ \ + "movi v1.4s, #0 \n" /* set out0 to 0 */ \ + "fmov s0, %w[bias0] \n" /* set out0 = bias0 */ + +#define SGEMV_KERNEL_8 \ + /* check main loop */ \ + "cmp %w[cnt], #1 \n" /* check whether has main loop */ \ + "blt 2f \n" /* jump to tail */ /* main loop */ \ + "1: \n" /* main loop */ \ + "ldp q8, q9, [%[in]], #32 \n" /* load input 8 float */ \ + "ldp q10, q11, [%[w0]], #32 \n" /* load w0 8 float */ \ + "ldp q12, q13, [%[w1]], #32 \n" /* load w1 8 float */ \ + "ldp q14, q15, [%[w2]], #32 \n" /* load w2 8 float */ \ + "ldp q16, q17, [%[w3]], #32 \n" /* load w3 8 float */ \ + "ldp q18, q19, [%[w4]], #32 \n" /* load w4 8 float */ \ + "ldp q20, q21, [%[w5]], #32 \n" /* load w5 8 float */ \ + "ldp q22, q23, [%[w6]], #32 \n" /* load w6 8 float */ \ + "ldp q24, q25, [%[w7]], #32 \n" /* load w7 8 float */ \ + "fmla v0.4s, v8.4s, v10.4s \n" /* mul + add*/ \ + "fmla v1.4s, v8.4s, v12.4s \n" /* mul + add*/ \ + "fmla v2.4s, v8.4s, v14.4s \n" /* mul + add*/ \ + "fmla v3.4s, v8.4s, v16.4s \n" /* mul + add*/ \ + "fmla v4.4s, v8.4s, v18.4s \n" /* mul + add*/ \ + "fmla v5.4s, v8.4s, v20.4s \n" /* mul + add*/ \ + "fmla v6.4s, v8.4s, v22.4s \n" /* mul + add*/ \ + "fmla v7.4s, v8.4s, v24.4s \n" /* mul + add*/ \ + "subs %w[cnt], %w[cnt], #1 \n" /* sub main loop count */ \ + "fmla v0.4s, v9.4s, v11.4s \n" /* mul + add*/ \ + "fmla v1.4s, v9.4s, v13.4s \n" /* mul + add*/ \ + "fmla v2.4s, v9.4s, v15.4s \n" /* mul + add*/ \ + "fmla v3.4s, v9.4s, v17.4s \n" /* mul + add*/ \ + "fmla v4.4s, v9.4s, v19.4s \n" /* mul + add*/ \ + "fmla v5.4s, v9.4s, v21.4s \n" /* mul + add*/ \ + "fmla v6.4s, v9.4s, v23.4s \n" /* mul + add*/ \ + "fmla v7.4s, v9.4s, v25.4s \n" /* mul + add*/ \ + "bne 1b \n" /* jump to main loop */ /* pair add to final \ + result */ \ + "2: \n" /* reduce to scale */ \ + "faddp v16.4s, v0.4s, v0.4s\n" /* pair add to vector */ \ + "faddp s8, v16.2s \n" /* pair add to scale */ \ + "faddp v17.4s, v1.4s, v1.4s\n" /* pair add to vector */ \ + "faddp s9, v17.2s \n" /* pair add to scale */ \ + "faddp v18.4s, v2.4s, v2.4s\n" /* pair add to vector */ \ + "faddp s10, v18.2s \n" /* pair add to scale */ \ + "faddp v19.4s, v3.4s, v3.4s\n" /* pair add to vector */ \ + "faddp s11, v19.2s \n" /* pair add to scale */ \ + "faddp v20.4s, v4.4s, v4.4s\n" /* pair add to vector */ \ + "faddp s12, v20.2s \n" /* pair add to scale */ \ + "faddp v21.4s, v5.4s, v5.4s\n" /* pair add to vector */ \ + "faddp s13, v21.2s \n" /* pair add to scale */ \ + "faddp v22.4s, v6.4s, v6.4s\n" /* pair add to vector */ \ + "faddp s14, v22.2s \n" /* pair add to scale */ \ + "faddp v23.4s, v7.4s, v7.4s\n" /* pair add to vector */ \ + "faddp s15, v23.2s \n" /* pair add to scale */ \ + "cmp %w[tail], #1 \n" /* check whether has tail */ \ + "blt 4f \n" /* jump to end */ \ + "3: \n" /* tail loop */ \ + "ldr s16, [%[in]], #4 \n" /* load in, 1 float */ \ + "ldr s17, [%[w0]], #4 \n" /* load w0, 1 float */ \ + "ldr s18, [%[w1]], #4 \n" /* load w1, 1 float */ \ + "ldr s19, [%[w2]], #4 \n" /* load w2, 1 float */ \ + "ldr s20, [%[w3]], #4 \n" /* load w3, 1 float */ \ + "ldr s21, [%[w4]], #4 \n" /* load w4, 1 float */ \ + "ldr s22, [%[w5]], #4 \n" /* load w5, 1 float */ \ + "ldr s23, [%[w6]], #4 \n" /* load w6, 1 float */ \ + "ldr s24, [%[w7]], #4 \n" /* load w7, 1 float */ \ + "fmadd s8, s16, s17, s8 
\n" /* mul + add */ \ + "fmadd s9, s16, s18, s9 \n" /* mul + add */ \ + "fmadd s10, s16, s19, s10 \n" /* mul + add */ \ + "fmadd s11, s16, s20, s11 \n" /* mul + add */ \ + "fmadd s12, s16, s21, s12 \n" /* mul + add */ \ + "fmadd s13, s16, s22, s13 \n" /* mul + add */ \ + "fmadd s14, s16, s23, s14 \n" /* mul + add */ \ + "fmadd s15, s16, s24, s15 \n" /* mul + add */ \ + "subs %w[tail], %w[tail], #1\n" /* sub tail loop count */ \ + "bne 3b \n" /* jump to tail loop */ + +#define SGEMV_KERNEL_1 \ + /* check main loop */ \ + "cmp %w[cnt], #1 \n" /* check whether has main loop */ \ + "blt 2f \n" /* jump to tail */ /* main loop */ \ + "1: \n" /* main loop */ \ + "ldp q8, q9, [%[in]], #32 \n" /* load input 8 float */ \ + "ldp q10, q11, [%[w0]], #32 \n" /* load w0 8 float */ \ + "fmla v0.4s, v8.4s, v10.4s \n" /* mul + add*/ \ + "subs %w[cnt], %w[cnt], #1 \n" /* sub main loop count */ \ + "fmla v1.4s, v9.4s, v11.4s \n" /* mul + add*/ \ + "bne 1b \n" /* jump to main loop */ /* pair add to final \ + result */ \ + "2: \n" /* reduce to scale */ \ + "fadd v9.4s, v0.4s, v1.4s \n" /* add 2 vector */ \ + "faddp v10.4s, v9.4s, v9.4s\n" /* pair add to vector */ \ + "faddp s8, v10.2s \n" /* pair add to scale */ /* check tails */ \ + "cmp %w[tail], #1 \n" /* check whether has tail */ \ + "blt 4f \n" /* jump to end */ \ + "3: \n" /* tail loop */ \ + "ldr s16, [%[in]], #4 \n" /* load in, 1 float */ \ + "ldr s17, [%[w0]], #4 \n" /* load w0, 1 float */ \ + "fmadd s8, s16, s17, s8 \n" /* mul + add */ \ + "subs %w[tail], %w[tail], #1\n" /* sub tail loop count */ \ + "bne 3b \n" /* jump to tail loop */ + +#define SGEMV_OUT_8 \ + /* end */ \ + "4: \n" /* end */ \ + "stp s8, s9, [%[out]] \n" /* save result */ \ + "stp s10, s11, [%[out], #8] \n" /* save result */ \ + "stp s12, s13, [%[out], #16]\n" /* save result */ \ + "stp s14, s15, [%[out], #24]\n" /* save result */ + +#define SGEMV_OUT_8_RELU \ + /* end */ \ + "4: \n" /* end */ \ + "movi d0, #0 \n" /* zero data for relu */ \ + "fmax s8, s8, s0 \n" /* relu */ \ + "fmax s9, s9, s0 \n" /* relu */ \ + "fmax s10, s10, s0 \n" /* relu */ \ + "fmax s11, s11, s0 \n" /* relu */ \ + "fmax s12, s12, s0 \n" /* relu */ \ + "fmax s13, s13, s0 \n" /* relu */ \ + "fmax s14, s14, s0 \n" /* relu */ \ + "fmax s15, s15, s0 \n" /* relu */ \ + "stp s8, s9, [%[out]] \n" /* save result */ \ + "stp s10, s11, [%[out], #8] \n" /* save result */ \ + "stp s12, s13, [%[out], #16]\n" /* save result */ \ + "stp s14, s15, [%[out], #24]\n" /* save result */ + +#define SGEMV_OUT_1 \ + /* end */ \ + "4: \n" /* end */ \ + "str s8, [%[out]] \n" /* save result */ + +#define SGEMV_OUT_1_RELU \ + /* end */ \ + "4: \n" /* end */ \ + "movi d0, #0 \n" /* zero data for relu */ \ + "fmax s8, s8, s0 \n" /* relu */ \ + "str s8, [%[out]] \n" /* save result */ + +#else //__aarch64__ + +#define SGEMV_IN_4 \ + "pld [%[in]] @ preload cache line, input\n" \ + "pld [%[w0]] @ preload cache line, weights r0\n" \ + "pld [%[w1]] @ preload cache line, weights r1\n" \ + "pld [%[w2]] @ preload cache line, weights r2\n" \ + "pld [%[w3]] @ preload cache line, weights r3\n" \ + "vmov.u32 q0, #0 @ set q0 to 0\n" \ + "vmov.u32 q1, #0 @ set q1 to 0\n" \ + "vmov.u32 q2, #0 @ set q2 to 0\n" \ + "vmov.u32 q3, #0 @ set q3 to 0\n" \ + "pld [%[w0], #64] @ preload cache line, weights r0\n" \ + "pld [%[w1], #64] @ preload cache line, weights r1\n" \ + "pld [%[w2], #64] @ preload cache line, weights r2\n" \ + "pld [%[w3], #64] @ preload cache line, weights r3\n" + +#define SGEMV_IN_4_BIAS \ + "pld [%[in]] @ preload cache line, input\n" \ + 
"pld [%[w0]] @ preload cache line, weights r0\n" \ + "pld [%[w1]] @ preload cache line, weights r1\n" \ + "pld [%[w2]] @ preload cache line, weights r2\n" \ + "pld [%[w3]] @ preload cache line, weights r3\n" \ + "vmov.u32 q0, #0 @ set q0 to 0\n" \ + "vmov.u32 q1, #0 @ set q1 to 0\n" \ + "vmov.u32 q2, #0 @ set q2 to 0\n" \ + "vmov.u32 q3, #0 @ set q3 to 0\n" \ + "vmov s0, %[bias0] @ set q0 to bias0\n" \ + "vmov s4, %[bias1] @ set q1 to bias1\n" \ + "vmov s8, %[bias2] @ set q2 to bias2\n" \ + "vmov s12,%[bias3] @ set q3 to bias3\n" \ + "pld [%[w0], #64] @ preload cache line, weights r0\n" \ + "pld [%[w1], #64] @ preload cache line, weights r1\n" \ + "pld [%[w2], #64] @ preload cache line, weights r2\n" \ + "pld [%[w3], #64] @ preload cache line, weights r3\n" + +#define SGEMV_IN_1 \ + "pld [%[in]] @ preload cache line, input\n" \ + "pld [%[w0]] @ preload cache line, weights r0\n" \ + "vmov.u32 q0, #0 @ set q0 to 0\n" + +#define SGEMV_IN_1_BIAS \ + "pld [%[in]] @ preload cache line, input\n" \ + "pld [%[w0]] @ preload cache line, weights r0\n" \ + "vmov.u32 q0, #0 @ set q0 to 0\n" \ + "vmov s0, %[bias0] @ set q0 to 0\n" + +#define SGEMV_KERNEL_4 \ + /* check main loop */ \ + "cmp %[cnt], #1 @ check whether has main loop\n" \ + "blt 2f @ jump to tail\n" \ + "1: @ main loop\n" \ + "vld1.32 {d8-d11}, [%[in]]! @ load input, q4, q5\n" \ + "vld1.32 {d12-d15}, [%[w0]]! @ load weights r0, q6,q7\n" \ + "vld1.32 {d16-d19}, [%[w1]]! @ load weights r1, q8,q9\n" \ + "vld1.32 {d20-d23}, [%[w2]]! @ load weights r2, q10,q11\n" \ + "vld1.32 {d24-d27}, [%[w3]]! @ load weights r3, q12,q13\n" \ + "vmla.f32 q0, q4, q6 @ mul add\n" \ + "vmla.f32 q1, q4, q8 @ mul add\n" \ + "vmla.f32 q2, q4, q10 @ mul add\n" \ + "vmla.f32 q3, q4, q12 @ mul add\n" \ + "subs %[cnt], #1 @ sub loop count \n" \ + "vmla.f32 q0, q5, q7 @ mul add\n" \ + "vmla.f32 q1, q5, q9 @ mul add\n" \ + "vmla.f32 q2, q5, q11 @ mul add\n" \ + "vmla.f32 q3, q5, q13 @ mul add\n" \ + "bne 1b @ jump to main loop\n" /* pair add to final \ + result */ \ + "2: @ pair add \n" \ + "vpadd.f32 d8, d0, d1 @ pair add, first step\n" \ + "vpadd.f32 d9, d2, d3 @ pair add, first step\n" \ + "vpadd.f32 d10, d4, d5 @ pair add, first step\n" \ + "vpadd.f32 d11, d6, d7 @ pair add, first step\n" \ + "vpadd.f32 d0, d8, d9 @ pair add, second step\n" \ + "vpadd.f32 d1, d10, d11 @ pair add, second step\n" /* check tails */ \ + "cmp %[tail], #1 @ check whether has tail\n" \ + "blt 4f @ jump to end\n" \ + "3: @ tail loop\n" \ + "vldm %[in]!, {s16} @ load 1 float\n" \ + "vldm %[w0]!, {s17} @ load 1 float\n" \ + "vldm %[w1]!, {s18} @ load 1 float\n" \ + "vldm %[w2]!, {s19} @ load 1 float\n" \ + "vldm %[w3]!, {s20} @ load 1 float\n" \ + "vmla.f32 s0, s16, s17 @ mul + add\n" \ + "vmla.f32 s1, s16, s18 @ mul + add\n" \ + "vmla.f32 s2, s16, s19 @ mul + add\n" \ + "vmla.f32 s3, s16, s20 @ mul + add\n" \ + "subs %[tail], #1 @ sub loop count \n" \ + "bne 3b @ jump to tail loop\n" + +#define SGEMV_KERNEL_1 \ + "cmp %[cnt], #1 @ check whether has main loop\n" \ + "blt 2f @ jump to tail\n" \ + "1: @ main loop\n" \ + "vld1.32 {d24-d27}, [%[in]]! @ load input, q12,q13\n" \ + "vld1.32 {d28-d31}, [%[w0]]! 
@ load weights r0, q14, q15\n" \ + "vmla.f32 q0, q12, q14 @ mul add\n" \ + "vmla.f32 q0, q13, q15 @ mul add\n" \ + "subs %[cnt] , #1 @ sub loop count \n" \ + "bne 1b @ jump to main loop\n" /* pair add to \ + final result \ + */ \ + "2: @ end processing\n" \ + "vpadd.f32 d2, d0, d1 @ pair add, first step\n" \ + "vpadd.f32 d0, d2, d2 @ pair add, final step\n" /* check tails \ + */ \ + "cmp %[tail], #1 @ check whether has mid cols\n" \ + "blt 4f @ jump to end\n" \ + "3: @ tail loop\n" \ + "vldm %[in]!, {s16} @ load 1 float\n" \ + "vldm %[w0]!, {s17} @ load 1 float\n" \ + "vmla.f32 s0, s16, s17 @ mul + add\n" \ + "subs %[tail], #1 @ sub loop count \n" \ + "bne 3b @ jump to tail loop\n" + +#define SGEMV_OUT_4 \ + /* end */ \ + "4: @ end\n" \ + "vst1.32 {d0-d1}, [%[out]] @ save result\n" + +#define SGEMV_OUT_4_RELU \ + /* end */ \ + "4: @ end\n" \ + "vmov.i32 q1, #0 @ zero for relu\n" \ + "vmax.f32 q0, q0, q1 @ relu\n" \ + "vst1.32 {d0-d1}, [%[out]] @ save result\n" + +#define SGEMV_OUT_1 \ + /* end */ \ + "4: @ end\n" \ + "vst1.32 {d0[0]}, [%[out]] @ save result\n" + +#define SGEMV_OUT_1_RELU \ + /* end */ \ + "4: @ end\n" \ + "vmov.i32 d1, #0 @ zero for relu\n" \ + "vmax.f32 d0, d0, d1 @ relu\n" \ + "vst1.32 {d0[0]}, [%[out]] @ save result\n" +#endif + +void sgemv(const bool transA, + const int M, + const int N, + const float *A, + const float *x, + float *y) { + float *data_out = y; + const float *data_in = x; + const float *weights_ptr = A; + + int cnt = N >> 3; + int tail = N & 7; + +#ifdef __aarch64__ + int out_cnt = M >> 3; + +#pragma omp parallel for + for (int j = 0; j < out_cnt; j++) { + int out_idx = j * 8; + float *ptr_out = data_out + out_idx; + const float *ptr_in = data_in; + const float *ptr_w0 = weights_ptr + (N * out_idx); + const float *ptr_w1 = ptr_w0 + N; + const float *ptr_w2 = ptr_w1 + N; + const float *ptr_w3 = ptr_w2 + N; + const float *ptr_w4 = ptr_w3 + N; + const float *ptr_w5 = ptr_w4 + N; + const float *ptr_w6 = ptr_w5 + N; + const float *ptr_w7 = ptr_w6 + N; + int cnt_loop = cnt; + int tail_loop = tail; + asm volatile(SGEMV_IN_8 SGEMV_KERNEL_8 SGEMV_OUT_8 + : [in] "+r"(ptr_in), + [w0] "+r"(ptr_w0), + [w1] "+r"(ptr_w1), + [w2] "+r"(ptr_w2), + [w3] "+r"(ptr_w3), + [w4] "+r"(ptr_w4), + [w5] "+r"(ptr_w5), + [w6] "+r"(ptr_w6), + [w7] "+r"(ptr_w7), + [cnt] "+r"(cnt_loop), + [tail] "+r"(tail_loop) + : [out] "r"(ptr_out) + : "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15", + "v16", + "v17", + "v18", + "v19", + "v20", + "v21", + "v22", + "v23", + "v24", + "v25", + "cc", + "memory"); + } +//! 
deal with remains +#pragma omp parallel for + for (int j = out_cnt * 8; j < M; ++j) { + float *ptr_out = data_out + j; + const float *ptr_in = data_in; + const float *ptr_w0 = weights_ptr + (N * j); + int cnt_loop = cnt; + int tail_loop = tail; + float tmp[4]; + float tmp1[4]; + float tmp2[4]; + float tmp3[4]; + float tmp4[4]; + asm volatile( + SGEMV_IN_1 SGEMV_KERNEL_1 SGEMV_OUT_1 + : [in] "+r"(ptr_in), + [w0] "+r"(ptr_w0), + [cnt] "+r"(cnt_loop), + [tail] "+r"(tail_loop) + : [out] "r"(ptr_out), + [tmp] "r"(tmp), + [tmp1] "r"(tmp1), + [tmp2] "r"(tmp2), + [tmp3] "r"(tmp3), + [tmp4] "r"(tmp4) + : "v0", "v1", "v8", "v9", "v10", "v11", "v16", "v17", "cc", "memory"); + } +#else //__aarch64__ + int out_cnt = M >> 2; +#pragma omp parallel for + for (int j = 0; j < out_cnt; j++) { + int out_idx = j * 4; + float *ptr_out = data_out + out_idx; + const float *ptr_in = data_in; + const float *ptr_w0 = weights_ptr + (N * out_idx); + const float *ptr_w1 = ptr_w0 + N; + const float *ptr_w2 = ptr_w1 + N; + const float *ptr_w3 = ptr_w2 + N; + + int cnt_loop = cnt; + int tail_loop = tail; + asm volatile(SGEMV_IN_4 SGEMV_KERNEL_4 SGEMV_OUT_4 + : [in] "+r"(ptr_in), + [w0] "+r"(ptr_w0), + [w1] "+r"(ptr_w1), + [w2] "+r"(ptr_w2), + [w3] "+r"(ptr_w3), + [cnt] "+r"(cnt_loop), + [tail] "+r"(tail_loop) + : [out] "r"(ptr_out) + : "q0", + "q1", + "q2", + "q3", + "q4", + "q5", + "q6", + "q7", + "q8", + "q9", + "q10", + "q11", + "q12", + "q13", + "cc", + "memory"); + } +//! deal with remains +#pragma omp parallel for + for (int j = out_cnt * 4; j < M; ++j) { + float *ptr_out = data_out + j; + const float *ptr_in = data_in; + const float *ptr_w0 = weights_ptr + (N * j); + int cnt_loop = cnt; + int tail_loop = tail; + asm volatile(SGEMV_IN_1 SGEMV_KERNEL_1 SGEMV_OUT_1 + : [in] "+r"(ptr_in), + [w0] "+r"(ptr_w0), + [cnt] "+r"(cnt_loop), + [tail] "+r"(tail_loop) + : [out] "r"(ptr_out) + : "q0", "q1", "q12", "q13", "q14", "q15", "cc", "memory"); + } +#endif //__aarch64__ +} + +void sgemv_relu(const bool transA, + const int M, + const int N, + const float *A, + const float *x, + float *y) { + float *data_out = y; + const float *data_in = x; + const float *weights_ptr = A; + + int cnt = N >> 3; + int tail = N & 7; + +#ifdef __aarch64__ + int out_cnt = M >> 3; +#pragma omp parallel for + for (int j = 0; j < out_cnt; j++) { + int out_idx = j * 8; + float *ptr_out = data_out + out_idx; + const float *ptr_in = data_in; + const float *ptr_w0 = weights_ptr + (N * out_idx); + const float *ptr_w1 = ptr_w0 + N; + const float *ptr_w2 = ptr_w1 + N; + const float *ptr_w3 = ptr_w2 + N; + const float *ptr_w4 = ptr_w3 + N; + const float *ptr_w5 = ptr_w4 + N; + const float *ptr_w6 = ptr_w5 + N; + const float *ptr_w7 = ptr_w6 + N; + int cnt_loop = cnt; + int tail_loop = tail; + asm volatile(SGEMV_IN_8 SGEMV_KERNEL_8 SGEMV_OUT_8_RELU + : [in] "+r"(ptr_in), + [w0] "+r"(ptr_w0), + [w1] "+r"(ptr_w1), + [w2] "+r"(ptr_w2), + [w3] "+r"(ptr_w3), + [w4] "+r"(ptr_w4), + [w5] "+r"(ptr_w5), + [w6] "+r"(ptr_w6), + [w7] "+r"(ptr_w7), + [cnt] "+r"(cnt_loop), + [tail] "+r"(tail_loop) + : [out] "r"(ptr_out) + : "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15", + "v16", + "v17", + "v18", + "v19", + "v20", + "v21", + "v22", + "v23", + "v24", + "v25", + "cc", + "memory"); + } +//! 
deal with remains +#pragma omp parallel for + for (int j = out_cnt * 8; j < M; ++j) { + float *ptr_out = data_out + j; + const float *ptr_in = data_in; + const float *ptr_w0 = weights_ptr + (N * j); + int cnt_loop = cnt; + int tail_loop = tail; + asm volatile( + SGEMV_IN_1 SGEMV_KERNEL_1 SGEMV_OUT_1_RELU + : [in] "+r"(ptr_in), + [w0] "+r"(ptr_w0), + [cnt] "+r"(cnt_loop), + [tail] "+r"(tail_loop) + : [out] "r"(ptr_out) + : "v0", "v1", "v8", "v9", "v10", "v11", "v16", "v17", "cc", "memory"); + } +#else //__aarch64__ + int out_cnt = M >> 2; +#pragma omp parallel for + for (int j = 0; j < out_cnt; j++) { + int out_idx = j * 4; + float *ptr_out = data_out + out_idx; + const float *ptr_in = data_in; + const float *ptr_w0 = weights_ptr + (N * out_idx); + const float *ptr_w1 = ptr_w0 + N; + const float *ptr_w2 = ptr_w1 + N; + const float *ptr_w3 = ptr_w2 + N; + + int cnt_loop = cnt; + int tail_loop = tail; + asm volatile(SGEMV_IN_4 SGEMV_KERNEL_4 SGEMV_OUT_4_RELU + : [in] "+r"(ptr_in), + [w0] "+r"(ptr_w0), + [w1] "+r"(ptr_w1), + [w2] "+r"(ptr_w2), + [w3] "+r"(ptr_w3), + [cnt] "+r"(cnt_loop), + [tail] "+r"(tail_loop) + : [out] "r"(ptr_out) + : "q0", + "q1", + "q2", + "q3", + "q4", + "q5", + "q6", + "q7", + "q8", + "q9", + "q10", + "q11", + "q12", + "q13", + "cc", + "memory"); + } +//! deal with remains +#pragma omp parallel for + for (int j = out_cnt * 4; j < M; ++j) { + float *ptr_out = data_out + j; + const float *ptr_in = data_in; + const float *ptr_w0 = weights_ptr + (N * j); + int cnt_loop = cnt; + int tail_loop = tail; + asm volatile(SGEMV_IN_1 SGEMV_KERNEL_1 SGEMV_OUT_1_RELU + : [in] "+r"(ptr_in), + [w0] "+r"(ptr_w0), + [cnt] "+r"(cnt_loop), + [tail] "+r"(tail_loop) + : [out] "r"(ptr_out) + : "q0", "q1", "q12", "q13", "q14", "q15", "cc", "memory"); + } +#endif //__aarch64__ +} + +void sgemv_bias(const bool transA, + const int M, + const int N, + const float *A, + const float *x, + float *y, + const float *bias) { + float *data_out = y; + const float *data_in = x; + const float *weights_ptr = A; + + int cnt = N >> 3; + int tail = N & 7; + +#ifdef __aarch64__ + int out_cnt = M >> 3; +#pragma omp parallel for + for (int j = 0; j < out_cnt; j++) { + int out_idx = j * 8; + float *ptr_out = data_out + out_idx; + const float *ptr_in = data_in; + const float *ptr_w0 = weights_ptr + (N * out_idx); + const float *ptr_w1 = ptr_w0 + N; + const float *ptr_w2 = ptr_w1 + N; + const float *ptr_w3 = ptr_w2 + N; + const float *ptr_w4 = ptr_w3 + N; + const float *ptr_w5 = ptr_w4 + N; + const float *ptr_w6 = ptr_w5 + N; + const float *ptr_w7 = ptr_w6 + N; + const float *bias_ptr = bias + out_idx; + int cnt_loop = cnt; + int tail_loop = tail; + asm volatile(SGEMV_IN_8_BIAS SGEMV_KERNEL_8 SGEMV_OUT_8 + : [in] "+r"(ptr_in), + [w0] "+r"(ptr_w0), + [w1] "+r"(ptr_w1), + [w2] "+r"(ptr_w2), + [w3] "+r"(ptr_w3), + [w4] "+r"(ptr_w4), + [w5] "+r"(ptr_w5), + [w6] "+r"(ptr_w6), + [w7] "+r"(ptr_w7), + [cnt] "+r"(cnt_loop), + [tail] "+r"(tail_loop) + : [out] "r"(ptr_out), [bias_ptr] "r"(bias_ptr) + : "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15", + "v16", + "v17", + "v18", + "v19", + "v20", + "v21", + "v22", + "v23", + "v24", + "v25", + "cc", + "memory"); + } +//! 
deal with remains +#pragma omp parallel for + for (int j = out_cnt * 8; j < M; ++j) { + float *ptr_out = data_out + j; + const float *ptr_in = data_in; + const float *ptr_w0 = weights_ptr + (N * j); + int cnt_loop = cnt; + int tail_loop = tail; + float bias0 = bias[j]; + asm volatile( + SGEMV_IN_1_BIAS SGEMV_KERNEL_1 SGEMV_OUT_1 + : [in] "+r"(ptr_in), + [w0] "+r"(ptr_w0), + [cnt] "+r"(cnt_loop), + [tail] "+r"(tail_loop) + : [out] "r"(ptr_out), [bias0] "r"(bias0) + : "v0", "v1", "v8", "v9", "v10", "v11", "v16", "v17", "cc", "memory"); + } +#else //__aarch64__ + int out_cnt = M >> 2; +#pragma omp parallel for + for (int j = 0; j < out_cnt; j++) { + int out_idx = j * 4; + float *ptr_out = data_out + out_idx; + const float *ptr_in = data_in; + const float *ptr_w0 = weights_ptr + (N * out_idx); + const float *ptr_w1 = ptr_w0 + N; + const float *ptr_w2 = ptr_w1 + N; + const float *ptr_w3 = ptr_w2 + N; + float bias0 = bias[out_idx]; + float bias1 = bias[out_idx + 1]; + float bias2 = bias[out_idx + 2]; + float bias3 = bias[out_idx + 3]; + + int cnt_loop = cnt; + int tail_loop = tail; + asm volatile(SGEMV_IN_4_BIAS SGEMV_KERNEL_4 SGEMV_OUT_4 + : [in] "+r"(ptr_in), + [w0] "+r"(ptr_w0), + [w1] "+r"(ptr_w1), + [w2] "+r"(ptr_w2), + [w3] "+r"(ptr_w3), + [cnt] "+r"(cnt_loop), + [tail] "+r"(tail_loop) + : [out] "r"(ptr_out), + [bias0] "r"(bias0), + [bias1] "r"(bias1), + [bias2] "r"(bias2), + [bias3] "r"(bias3) + : "q0", + "q1", + "q2", + "q3", + "q4", + "q5", + "q6", + "q7", + "q8", + "q9", + "q10", + "q11", + "q12", + "q13", + "cc", + "memory"); + } +//! deal with remains +#pragma omp parallel for + for (int j = out_cnt * 4; j < M; ++j) { + float *ptr_out = data_out + j; + const float *ptr_in = data_in; + const float *ptr_w0 = weights_ptr + (N * j); + int cnt_loop = cnt; + int tail_loop = tail; + float bias0 = bias[j]; + asm volatile(SGEMV_IN_1_BIAS SGEMV_KERNEL_1 SGEMV_OUT_1 + : [in] "+r"(ptr_in), + [w0] "+r"(ptr_w0), + [cnt] "+r"(cnt_loop), + [tail] "+r"(tail_loop) + : [out] "r"(ptr_out), [bias0] "r"(bias0) + : "q0", "q1", "q12", "q13", "q14", "q15", "cc", "memory"); + } +#endif //__aarch64__ +} + +void sgemv_bias_relu(const bool transA, + const int M, + const int N, + const float *A, + const float *x, + float *y, + const float *bias) { + float *data_out = y; + const float *data_in = x; + const float *weights_ptr = A; + int cnt = N >> 3; + int tail = N & 7; +#ifdef __aarch64__ + int out_cnt = M >> 3; +#pragma omp parallel for + for (int j = 0; j < out_cnt; j++) { + int out_idx = j * 8; + float *ptr_out = data_out + out_idx; + const float *ptr_in = data_in; + const float *ptr_w0 = weights_ptr + (N * out_idx); + const float *ptr_w1 = ptr_w0 + N; + const float *ptr_w2 = ptr_w1 + N; + const float *ptr_w3 = ptr_w2 + N; + const float *ptr_w4 = ptr_w3 + N; + const float *ptr_w5 = ptr_w4 + N; + const float *ptr_w6 = ptr_w5 + N; + const float *ptr_w7 = ptr_w6 + N; + const float *bias_ptr = bias + out_idx; + int cnt_loop = cnt; + int tail_loop = tail; + asm volatile(SGEMV_IN_8_BIAS SGEMV_KERNEL_8 SGEMV_OUT_8_RELU + : [in] "+r"(ptr_in), + [w0] "+r"(ptr_w0), + [w1] "+r"(ptr_w1), + [w2] "+r"(ptr_w2), + [w3] "+r"(ptr_w3), + [w4] "+r"(ptr_w4), + [w5] "+r"(ptr_w5), + [w6] "+r"(ptr_w6), + [w7] "+r"(ptr_w7), + [cnt] "+r"(cnt_loop), + [tail] "+r"(tail_loop) + : [out] "r"(ptr_out), [bias_ptr] "r"(bias_ptr) + : "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15", + "v16", + "v17", + "v18", + "v19", + "v20", + "v21", + "v22", + "v23", + "v24", + 
"v25", + "cc", + "memory"); + } +//! deal with remains +#pragma omp parallel for + for (int j = out_cnt * 8; j < M; ++j) { + float *ptr_out = data_out + j; + const float *ptr_in = data_in; + const float *ptr_w0 = weights_ptr + (N * j); + int cnt_loop = cnt; + int tail_loop = tail; + float bias0 = bias[j]; + asm volatile( + SGEMV_IN_1_BIAS SGEMV_KERNEL_1 SGEMV_OUT_1_RELU + : [in] "+r"(ptr_in), + [w0] "+r"(ptr_w0), + [cnt] "+r"(cnt_loop), + [tail] "+r"(tail_loop) + : [out] "r"(ptr_out), [bias0] "r"(bias0) + : "v0", "v1", "v8", "v9", "v10", "v11", "v16", "v17", "cc", "memory"); + } +#else //__aarch64__ + int out_cnt = M >> 2; +#pragma omp parallel for + for (int j = 0; j < out_cnt; j++) { + int out_idx = j * 4; + float *ptr_out = data_out + out_idx; + const float *ptr_in = data_in; + const float *ptr_w0 = weights_ptr + (N * out_idx); + const float *ptr_w1 = ptr_w0 + N; + const float *ptr_w2 = ptr_w1 + N; + const float *ptr_w3 = ptr_w2 + N; + float bias0 = bias[out_idx]; + float bias1 = bias[out_idx + 1]; + float bias2 = bias[out_idx + 2]; + float bias3 = bias[out_idx + 3]; + + int cnt_loop = cnt; + int tail_loop = tail; + asm volatile(SGEMV_IN_4_BIAS SGEMV_KERNEL_4 SGEMV_OUT_4_RELU + : [in] "+r"(ptr_in), + [w0] "+r"(ptr_w0), + [w1] "+r"(ptr_w1), + [w2] "+r"(ptr_w2), + [w3] "+r"(ptr_w3), + [cnt] "+r"(cnt_loop), + [tail] "+r"(tail_loop) + : [out] "r"(ptr_out), + [bias0] "r"(bias0), + [bias1] "r"(bias1), + [bias2] "r"(bias2), + [bias3] "r"(bias3) + : "q0", + "q1", + "q2", + "q3", + "q4", + "q5", + "q6", + "q7", + "q8", + "q9", + "q10", + "q11", + "q12", + "q13", + "cc", + "memory"); + } +//! deal with remains +#pragma omp parallel for + for (int j = out_cnt * 4; j < M; ++j) { + float *ptr_out = data_out + j; + const float *ptr_in = data_in; + const float *ptr_w0 = weights_ptr + (N * j); + int cnt_loop = cnt; + int tail_loop = tail; + float bias0 = bias[j]; + asm volatile(SGEMV_IN_1_BIAS SGEMV_KERNEL_1 SGEMV_OUT_1_RELU + : [in] "+r"(ptr_in), + [w0] "+r"(ptr_w0), + [cnt] "+r"(cnt_loop), + [tail] "+r"(tail_loop) + : [out] "r"(ptr_out), [bias0] "r"(bias0) + : "q0", "q1", "q12", "q13", "q14", "q15", "cc", "memory"); + } +#endif //__aarch64__ +} + +} // namespace math +} // namespace arm +} // namespace lite +} // namespace paddle diff --git a/lite/arm/math/sgemv.h b/lite/arm/math/sgemv.h new file mode 100644 index 00000000000..4d74006f932 --- /dev/null +++ b/lite/arm/math/sgemv.h @@ -0,0 +1,38 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include <cmath> + +namespace paddle { +namespace lite { +namespace arm { +namespace math { + +// TODO(xxx): fix me. Only transA = false is supported for now. +bool sgemv(const float* A, + const float* x, + float* y, + bool transA, + int M, + int N, + bool is_bias = false, + const float* bias = nullptr, + bool is_relu = false); + +} // namespace math +} // namespace arm +} // namespace lite +} // namespace paddle diff --git a/lite/arm/math/shuffle_channel.cc b/lite/arm/math/shuffle_channel.cc new file mode 100644 index 00000000000..bae03e90214 --- /dev/null +++ b/lite/arm/math/shuffle_channel.cc @@ -0,0 +1,81 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/arm/math/shuffle_channel.h" +#include <string.h> +#include "lite/arm/math/funcs.h" + +namespace paddle { +namespace lite { +namespace arm { +namespace math { + +template <typename Dtype> +void shuffle_kernel( + Dtype* output, const Dtype* input, int group_row, int group_col, int len) { + for (int i = 0; i < group_row; ++i) { + for (int j = 0; j < group_col; ++j) { + const Dtype* p_i = input + (i * group_col + j) * len; + Dtype* p_o = output + (j * group_row + i) * len; + memcpy(p_o, p_i, len * sizeof(Dtype)); + } + } +} + +template <> +void shuffle_channel(const float* inputs, + float* outputs, + int group, + int num, + int channel, + int height, + int width) { + int fea_size = channel * height * width; + int spatial_size = height * width; + int group_row = group; + int group_col = channel / group; + for (int i = 0; i < num; ++i) { + shuffle_kernel(outputs + i * fea_size, + inputs + i * fea_size, + group_row, + group_col, + spatial_size); + } +} + +template <> +void shuffle_channel(const char* inputs, + char* outputs, + int group, + int num, + int channel, + int height, + int width) { + int fea_size = channel * height * width; + int spatial_size = height * width; + int group_row = group; + int group_col = channel / group; + for (int i = 0; i < num; ++i) { + shuffle_kernel(outputs + i * fea_size, + inputs + i * fea_size, + group_row, + group_col, + spatial_size); + } +} + +} // namespace math +} // namespace arm +} // namespace lite +} // namespace paddle diff --git a/lite/arm/math/shuffle_channel.h b/lite/arm/math/shuffle_channel.h new file mode 100644 index 00000000000..d0c8b7b81eb --- /dev/null +++ b/lite/arm/math/shuffle_channel.h @@ -0,0 +1,34 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License.
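shuffle_kernel above is a block transpose: viewing the channel axis as a group_row x group_col grid of blocks of len = height * width elements each, the block at (i, j) is copied to slot (j, i), which interleaves channels across groups. A tiny standalone demo with group = 2 and four single-element channels (illustrative, not part of the patch):

#include <cstdio>
#include <cstring>

int main() {
  const float in[4] = {0.f, 1.f, 2.f, 3.f};  // channels c0 c1 | c2 c3
  float out[4];
  const int group_row = 2, group_col = 2, len = 1;
  for (int i = 0; i < group_row; ++i) {
    for (int j = 0; j < group_col; ++j) {
      // block (i, j) -> slot (j, i)
      std::memcpy(out + (j * group_row + i) * len,
                  in + (i * group_col + j) * len,
                  len * sizeof(float));
    }
  }
  printf("%g %g %g %g\n", out[0], out[1], out[2], out[3]);  // 0 2 1 3
  return 0;
}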
+ +#pragma once + +namespace paddle { +namespace lite { +namespace arm { +namespace math { + +template <typename T> +void shuffle_channel(const T* inputs, + T* outputs, + int group, + int num, + int channel, + int height, + int width); + +} // namespace math +} // namespace arm +} // namespace lite +} // namespace paddle diff --git a/lite/arm/math/slice.cc b/lite/arm/math/slice.cc new file mode 100644 index 00000000000..f9251181225 --- /dev/null +++ b/lite/arm/math/slice.cc @@ -0,0 +1,91 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/arm/math/slice.h" +#include <algorithm> +#include <string> +#include <vector> +#include "lite/utils/cp_logging.h" + +namespace paddle { +namespace lite { +namespace arm { +namespace math { + +template <typename Dtype> +void slice(const Dtype* input, + std::vector<int64_t> in_dims, + std::vector<int> axes, + std::vector<int> starts, + std::vector<int> ends, + Dtype* out, + Context<TARGET(kARM)>* ctx) { + auto out_dims = in_dims; + std::vector<int> real_starts(in_dims.size(), 0); + std::vector<int> real_ends(in_dims.size(), 0); + std::vector<int> real_step(in_dims.size(), 0); + for (int i = 0; i < in_dims.size(); i++) { + real_ends[i] = in_dims[i]; + } + for (int i = 0; i < axes.size(); i++) { + int dim_value = in_dims[axes[i]]; + if (dim_value > 0) { + int start = starts[i] < 0 ? (starts[i] + dim_value) : starts[i]; + int end = ends[i] < 0 ? (ends[i] + dim_value) : ends[i]; + start = std::max(start, 0); + end = std::max(end, 0); + end = std::min(end, dim_value); + out_dims[axes[i]] = end - start; + real_starts[axes[i]] = start; + real_ends[axes[i]] = end; + } + } + const int LEN = in_dims.size(); + int dst_step[LEN]; + for (int i = 0; i < in_dims.size(); ++i) { + dst_step[i] = 1; + } + int src_step[LEN]; + for (int i = 0; i < in_dims.size(); ++i) { + src_step[i] = 1; + } + int out_num = out_dims[in_dims.size() - 1]; + for (int i = in_dims.size() - 2; i >= 0; i--) { + dst_step[i] = out_dims[i] * dst_step[i + 1]; + src_step[i] = in_dims[i] * src_step[i + 1]; + out_num *= out_dims[i]; + } + + for (int dst_id = 0; dst_id < out_num; dst_id++) { + int src_id = 0; + for (int j = 0; j < out_dims.size(); j++) { + int cur_id = dst_id / dst_step[j]; + src_id += (cur_id + real_starts[j]) * src_step[j]; + } + out[dst_id] = input[src_id]; + } +} + +template void slice(const int* input, + std::vector<int64_t> dims, + std::vector<int> axes, + std::vector<int> starts, + std::vector<int> ends, + int* out, + Context<TARGET(kARM)>* ctx); + +} // namespace math +} // namespace arm +} // namespace lite +} // namespace paddle diff --git a/lite/arm/math/slice.h b/lite/arm/math/slice.h new file mode 100644 index 00000000000..86172d28a7a --- /dev/null +++ b/lite/arm/math/slice.h @@ -0,0 +1,38 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include <string> +#include <vector> +#include "lite/core/context.h" + +namespace paddle { +namespace lite { +namespace arm { +namespace math { + +template <typename Dtype> +void slice(const Dtype* input, + std::vector<int64_t> dims, + std::vector<int> axes, + std::vector<int> starts, + std::vector<int> ends, + Dtype* out, + Context<TARGET(kARM)>* ctx); + +} // namespace math +} // namespace arm +} // namespace lite +} // namespace paddle diff --git a/lite/arm/math/softmax.cc b/lite/arm/math/softmax.cc new file mode 100644 index 00000000000..c37f66974e6 --- /dev/null +++ b/lite/arm/math/softmax.cc @@ -0,0 +1,616 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/arm/math/softmax.h" +#include <algorithm> +#include "lite/arm/math/funcs.h" + +namespace paddle { +namespace lite { +namespace arm { +namespace math { + +template <> +void softmax_basic(const float* din, + float* dout, + const int axis_size, + const int inner_num, + const int outer_num) { + int compute_size = inner_num * outer_num; +#pragma omp parallel for + for (int i = 0; i < compute_size; ++i) { + int idx_inner = i % inner_num; + int idx_outer = (i / inner_num) * axis_size; + int real_index = idx_outer * inner_num + idx_inner; + + float max_data = din[real_index]; + // get max + for (int j = 1; j < axis_size; ++j) { + real_index += inner_num; + max_data = din[real_index] > max_data ?
din[real_index] : max_data; + } + + real_index = idx_outer * inner_num + idx_inner; + // sub, exp and sum + dout[real_index] = expf(din[real_index] - max_data); + float sum_data = dout[real_index]; + for (int j = 1; j < axis_size; ++j) { + real_index += inner_num; + dout[real_index] = expf(din[real_index] - max_data); + sum_data += dout[real_index]; + } + + float sum_inv = 1.f / sum_data; + real_index = idx_outer * inner_num + idx_inner; + // get softmax result + for (int j = 0; j < axis_size; ++j) { + dout[real_index] *= sum_inv; + real_index += inner_num; + } + } +} + +template <> +void softmax_inner8_axis4(const float* din, + float* dout, + const int axis_size, + const int inner_num, + const int outer_num) { + int compute_size = inner_num * outer_num; + int cmp_cnt = compute_size >> 3; + int remain = compute_size % 8; + float32x4_t vone = vdupq_n_f32(1.0f); + +#pragma omp parallel for + for (int c = 0; c < cmp_cnt; ++c) { + int i = c * 8; + int idx_inner = i % inner_num; + int idx_outer = (i / inner_num) * axis_size; + int real_index = idx_outer * inner_num + idx_inner; + + // get max axis_size == 4 + const float* din_ptr = din + real_index; + const float* din_ptr1 = din_ptr + inner_num; + const float* din_ptr2 = din_ptr1 + inner_num; + const float* din_ptr3 = din_ptr2 + inner_num; + float32x4_t vdata0 = vld1q_f32(din_ptr); + float32x4_t vdata1 = vld1q_f32(din_ptr1); + float32x4_t vdata2 = vld1q_f32(din_ptr2); + float32x4_t vdata3 = vld1q_f32(din_ptr3); + + float32x4_t vdata01 = vld1q_f32(din_ptr + 4); + float32x4_t vdata11 = vld1q_f32(din_ptr1 + 4); + float32x4_t vdata21 = vld1q_f32(din_ptr2 + 4); + float32x4_t vdata31 = vld1q_f32(din_ptr3 + 4); + + float* dout_ptr0 = dout + real_index; + float* dout_ptr1 = dout_ptr0 + inner_num; + float32x4_t vmax1 = vmaxq_f32(vdata0, vdata1); + float32x4_t vmax2 = vmaxq_f32(vdata2, vdata3); + float32x4_t vmax11 = vmaxq_f32(vdata01, vdata11); + float32x4_t vmax21 = vmaxq_f32(vdata21, vdata31); + float* dout_ptr2 = dout_ptr1 + inner_num; + float* dout_ptr3 = dout_ptr2 + inner_num; + float32x4_t vmax = vmaxq_f32(vmax1, vmax2); + float32x4_t vmax_1 = vmaxq_f32(vmax11, vmax21); + + // sub, exp and sum + float32x4_t vsum0 = exp_ps(vsubq_f32(vdata0, vmax)); + float32x4_t vsum1 = exp_ps(vsubq_f32(vdata1, vmax)); + float32x4_t vsum2 = exp_ps(vsubq_f32(vdata2, vmax)); + float32x4_t vsum3 = exp_ps(vsubq_f32(vdata3, vmax)); + + float32x4_t vsum01 = exp_ps(vsubq_f32(vdata01, vmax_1)); + float32x4_t vsum11 = exp_ps(vsubq_f32(vdata11, vmax_1)); + float32x4_t vsum21 = exp_ps(vsubq_f32(vdata21, vmax_1)); + float32x4_t vsum31 = exp_ps(vsubq_f32(vdata31, vmax_1)); + + float32x4_t vsum_1 = vaddq_f32(vsum0, vsum1); + float32x4_t vsum_2 = vaddq_f32(vsum2, vsum3); + float32x4_t vsum_11 = vaddq_f32(vsum01, vsum11); + float32x4_t vsum_21 = vaddq_f32(vsum21, vsum31); + + float32x4_t vsum = vaddq_f32(vsum_1, vsum_2); + float32x4_t vsum111 = vaddq_f32(vsum_11, vsum_21); + + float32x4_t vinf = div_ps(vone, vsum); + float32x4_t vinf1 = div_ps(vone, vsum111); + + vsum0 = vmulq_f32(vsum0, vinf); + vsum1 = vmulq_f32(vsum1, vinf); + vsum2 = vmulq_f32(vsum2, vinf); + vsum3 = vmulq_f32(vsum3, vinf); + + vsum01 = vmulq_f32(vsum01, vinf1); + vsum11 = vmulq_f32(vsum11, vinf1); + vsum21 = vmulq_f32(vsum21, vinf1); + vsum31 = vmulq_f32(vsum31, vinf1); + + vst1q_f32(dout_ptr0, vsum0); + vst1q_f32(dout_ptr1, vsum1); + vst1q_f32(dout_ptr2, vsum2); + vst1q_f32(dout_ptr3, vsum3); + + vst1q_f32(dout_ptr0 + 4, vsum01); + vst1q_f32(dout_ptr1 + 4, vsum11); + vst1q_f32(dout_ptr2 + 4, vsum21); + 
vst1q_f32(dout_ptr3 + 4, vsum31); + } + + int i = cmp_cnt * 8; + + if (remain > 4) { + int idx_inner = i % inner_num; + int idx_outer = (i / inner_num) * axis_size; + int real_index = idx_outer * inner_num + idx_inner; + // get max axis_size == 4 + const float* din_ptr = din + real_index; + const float* din_ptr1 = din_ptr + inner_num; + const float* din_ptr2 = din_ptr1 + inner_num; + const float* din_ptr3 = din_ptr2 + inner_num; + float32x4_t vdata0 = vld1q_f32(din_ptr); + float32x4_t vdata1 = vld1q_f32(din_ptr1); + float32x4_t vdata2 = vld1q_f32(din_ptr2); + float32x4_t vdata3 = vld1q_f32(din_ptr3); + + float* dout_ptr0 = dout + real_index; + float* dout_ptr1 = dout_ptr0 + inner_num; + float32x4_t vmax1 = vmaxq_f32(vdata0, vdata1); + float32x4_t vmax2 = vmaxq_f32(vdata2, vdata3); + float* dout_ptr2 = dout_ptr1 + inner_num; + float* dout_ptr3 = dout_ptr2 + inner_num; + float32x4_t vmax = vmaxq_f32(vmax1, vmax2); + + // sub, exp and sum + float32x4_t vsum0 = exp_ps(vsubq_f32(vdata0, vmax)); + float32x4_t vsum1 = exp_ps(vsubq_f32(vdata1, vmax)); + float32x4_t vsum2 = exp_ps(vsubq_f32(vdata2, vmax)); + float32x4_t vsum3 = exp_ps(vsubq_f32(vdata3, vmax)); + + float32x4_t vsum_1 = vaddq_f32(vsum0, vsum1); + float32x4_t vsum_2 = vaddq_f32(vsum2, vsum3); + + float32x4_t vsum = vaddq_f32(vsum_1, vsum_2); + + float32x4_t vone = vdupq_n_f32(1.0f); + float32x4_t vinf = div_ps(vone, vsum); + + vsum0 = vmulq_f32(vsum0, vinf); + vsum1 = vmulq_f32(vsum1, vinf); + vsum2 = vmulq_f32(vsum2, vinf); + vsum3 = vmulq_f32(vsum3, vinf); + + vst1q_f32(dout_ptr0, vsum0); + vst1q_f32(dout_ptr1, vsum1); + vst1q_f32(dout_ptr2, vsum2); + vst1q_f32(dout_ptr3, vsum3); + + i += 4; + } + for (; i < compute_size; i++) { + int idx_inner = i % inner_num; + int idx_outer = (i / inner_num) * axis_size; + int real_index = idx_outer * inner_num + idx_inner; + + float max_data = din[real_index]; + // get max + for (int j = 1; j < axis_size; ++j) { + real_index += inner_num; + max_data = din[real_index] > max_data ? 
din[real_index] : max_data; + } + + real_index = idx_outer * inner_num + idx_inner; + // sub, exp and sum + dout[real_index] = expf(din[real_index] - max_data); + float sum_data = dout[real_index]; + for (int j = 1; j < axis_size; ++j) { + real_index += inner_num; + dout[real_index] = expf(din[real_index] - max_data); + sum_data += dout[real_index]; + } + + float sum_inv = 1.f / sum_data; + real_index = idx_outer * inner_num + idx_inner; + // get softmax result + for (int j = 0; j < axis_size; ++j) { + dout[real_index] *= sum_inv; + real_index += inner_num; + } + } +} + +template <> +void softmax_inner4_axis4(const float* din, + float* dout, + const int axis_size, + const int inner_num, + const int outer_num) { + int compute_size = inner_num * outer_num; + int cmp_cnt = compute_size >> 2; + int remain = compute_size % 4; + float32x4_t vone = vdupq_n_f32(1.0f); + +#pragma omp parallel for + for (int c = 0; c < cmp_cnt; ++c) { + int i = c * 4; + int idx_inner = i % inner_num; + int idx_outer = (i / inner_num) * axis_size; + int real_index = idx_outer * inner_num + idx_inner; + + // get max axis_size == 4 + const float* din_ptr = din + real_index; + const float* din_ptr1 = din_ptr + inner_num; + const float* din_ptr2 = din_ptr1 + inner_num; + const float* din_ptr3 = din_ptr2 + inner_num; + float32x4_t vdata0 = vld1q_f32(din_ptr); + float32x4_t vdata1 = vld1q_f32(din_ptr1); + float32x4_t vdata2 = vld1q_f32(din_ptr2); + float32x4_t vdata3 = vld1q_f32(din_ptr3); + + float* dout_ptr0 = dout + real_index; + float* dout_ptr1 = dout_ptr0 + inner_num; + float32x4_t vmax1 = vmaxq_f32(vdata0, vdata1); + float32x4_t vmax2 = vmaxq_f32(vdata2, vdata3); + float* dout_ptr2 = dout_ptr1 + inner_num; + float* dout_ptr3 = dout_ptr2 + inner_num; + float32x4_t vmax = vmaxq_f32(vmax1, vmax2); + + // sub, exp and sum + float32x4_t vsum0 = exp_ps(vsubq_f32(vdata0, vmax)); + float32x4_t vsum1 = exp_ps(vsubq_f32(vdata1, vmax)); + float32x4_t vsum2 = exp_ps(vsubq_f32(vdata2, vmax)); + float32x4_t vsum3 = exp_ps(vsubq_f32(vdata3, vmax)); + + float32x4_t vsum_1 = vaddq_f32(vsum0, vsum1); + float32x4_t vsum_2 = vaddq_f32(vsum2, vsum3); + + float32x4_t vsum = vaddq_f32(vsum_1, vsum_2); + + float32x4_t vinf = div_ps(vone, vsum); + + vsum0 = vmulq_f32(vsum0, vinf); + vsum1 = vmulq_f32(vsum1, vinf); + vsum2 = vmulq_f32(vsum2, vinf); + vsum3 = vmulq_f32(vsum3, vinf); + + vst1q_f32(dout_ptr0, vsum0); + vst1q_f32(dout_ptr1, vsum1); + vst1q_f32(dout_ptr2, vsum2); + vst1q_f32(dout_ptr3, vsum3); + } + + int i = cmp_cnt * 4; + for (; i < compute_size; i++) { + int idx_inner = i % inner_num; + int idx_outer = (i / inner_num) * axis_size; + int real_index = idx_outer * inner_num + idx_inner; + + float max_data = din[real_index]; + // get max + for (int j = 1; j < axis_size; ++j) { + real_index += inner_num; + max_data = din[real_index] > max_data ?
din[real_index] : max_data; + } + + real_index = idx_outer * inner_num + idx_inner; + // sub, exp and sum + dout[real_index] = expf(din[real_index] - max_data); + float sum_data = dout[real_index]; + for (int j = 1; j < axis_size; ++j) { + real_index += inner_num; + dout[real_index] = expf(din[real_index] - max_data); + sum_data += dout[real_index]; + } + + float sum_inv = 1.f / sum_data; + real_index = idx_outer * inner_num + idx_inner; + // get softmax result + for (int j = 0; j < axis_size; ++j) { + dout[real_index] *= sum_inv; + real_index += inner_num; + } + } +} + +template <> +void softmax_inner8(const float* din, + float* dout, + const int axis_size, + const int inner_num, + const int outer_num) { + int compute_size = inner_num * outer_num; + int cmp_cnt = compute_size >> 3; +#pragma omp parallel for + for (int c = 0; c < cmp_cnt; ++c) { + int i = c * 8; + int idx_inner = i % inner_num; + int idx_outer = (i / inner_num) * axis_size; + int real_index = idx_outer * inner_num + idx_inner; + + const float* din_ptr = din + real_index; + float32x4_t vmax = vld1q_f32(din_ptr); + float32x4_t vmax2 = vld1q_f32(din_ptr + 4); + // get max + for (int j = 1; j < axis_size; ++j) { + din_ptr += inner_num; + float32x4_t vdata = vld1q_f32(din_ptr); + float32x4_t vdata2 = vld1q_f32(din_ptr + 4); + vmax = vmaxq_f32(vmax, vdata); + vmax2 = vmaxq_f32(vmax2, vdata2); + } + + // sub, exp and sum + din_ptr = din + real_index; + float* dout_ptr = dout + real_index; + float32x4_t vdata = vld1q_f32(din_ptr); + float32x4_t vdata2 = vld1q_f32(din_ptr + 4); + float32x4_t vsum = exp_ps(vsubq_f32(vdata, vmax)); + float32x4_t vsum2 = exp_ps(vsubq_f32(vdata2, vmax2)); + din_ptr += inner_num; + vst1q_f32(dout_ptr, vsum); + vst1q_f32(dout_ptr + 4, vsum2); + dout_ptr += inner_num; + for (int j = 1; j < axis_size; ++j) { + float32x4_t vdata0 = vld1q_f32(din_ptr); + float32x4_t vdata1 = vld1q_f32(din_ptr + 4); + vdata0 = exp_ps(vsubq_f32(vdata0, vmax)); + vdata1 = exp_ps(vsubq_f32(vdata1, vmax2)); + din_ptr += inner_num; + vsum = vaddq_f32(vsum, vdata0); + vsum2 = vaddq_f32(vsum2, vdata1); + vst1q_f32(dout_ptr, vdata0); + vst1q_f32(dout_ptr + 4, vdata1); + dout_ptr += inner_num; + } + + float32x4_t vone = vdupq_n_f32(1.0f); + float32x4_t vinf = div_ps(vone, vsum); + float32x4_t vinf2 = div_ps(vone, vsum2); + dout_ptr = dout + real_index; + // get softmax result + for (int j = 0; j < axis_size; ++j) { + float32x4_t vdata0 = vld1q_f32(dout_ptr); + float32x4_t vdata1 = vld1q_f32(dout_ptr + 4); + vdata0 = vmulq_f32(vdata0, vinf); + vdata1 = vmulq_f32(vdata1, vinf2); + vst1q_f32(dout_ptr, vdata0); + vst1q_f32(dout_ptr + 4, vdata1); + dout_ptr += inner_num; + } + } + + for (int i = cmp_cnt * 8; i < compute_size; i++) { + int idx_inner = i % inner_num; + int idx_outer = (i / inner_num) * axis_size; + int real_index = idx_outer * inner_num + idx_inner; + + float max_data = din[real_index]; + // get max + for (int j = 1; j < axis_size; ++j) { + real_index += inner_num; + max_data = din[real_index] > max_data ? 
din[real_index] : max_data; + } + + real_index = idx_outer * inner_num + idx_inner; + // sub, exp and sum + dout[real_index] = expf(din[real_index] - max_data); + float sum_data = dout[real_index]; + for (int j = 1; j < axis_size; ++j) { + real_index += inner_num; + dout[real_index] = expf(din[real_index] - max_data); + sum_data += dout[real_index]; + } + + float sum_inv = 1.f / sum_data; + real_index = idx_outer * inner_num + idx_inner; + // get softmax result + for (int j = 0; j < axis_size; ++j) { + dout[real_index] *= sum_inv; + real_index += inner_num; + } + } +} + +template <> +void softmax_inner4(const float* din, + float* dout, + const int axis_size, + const int inner_num, + const int outer_num) { + int compute_size = inner_num * outer_num; + int cmp_cnt = compute_size >> 2; +#pragma omp parallel for + for (int c = 0; c < cmp_cnt; ++c) { + int i = c * 4; + int idx_inner = i % inner_num; + int idx_outer = (i / inner_num) * axis_size; + int real_index = idx_outer * inner_num + idx_inner; + + // float max_data = din[real_index]; + const float* din_ptr = din + real_index; + float32x4_t vmax = vld1q_f32(din_ptr); + // get max + for (int j = 1; j < axis_size; ++j) { + din_ptr += inner_num; + float32x4_t vdata = vld1q_f32(din_ptr); + vmax = vmaxq_f32(vmax, vdata); + } + // sub, exp and sum + din_ptr = din + real_index; + float* dout_ptr = dout + real_index; + float32x4_t vdata = vld1q_f32(din_ptr); + float32x4_t vsum = exp_ps(vsubq_f32(vdata, vmax)); + din_ptr += inner_num; + vst1q_f32(dout_ptr, vsum); + dout_ptr += inner_num; + for (int j = 1; j < axis_size; ++j) { + // real_index += inner_num; + float32x4_t vdata0 = vld1q_f32(din_ptr); + vdata0 = exp_ps(vsubq_f32(vdata0, vmax)); + din_ptr += inner_num; + vsum = vaddq_f32(vsum, vdata0); + vst1q_f32(dout_ptr, vdata0); + dout_ptr += inner_num; + } + + float32x4_t vone = vdupq_n_f32(1.0f); + float32x4_t vinf = div_ps(vone, vsum); + dout_ptr = dout + real_index; + // get softmax result + for (int j = 0; j < axis_size; ++j) { + float32x4_t vdata0 = vld1q_f32(dout_ptr); + vdata0 = vmulq_f32(vdata0, vinf); + vst1q_f32(dout_ptr, vdata0); + dout_ptr += inner_num; + } + } + + for (int i = cmp_cnt * 4; i < compute_size; i++) { + int idx_inner = i % inner_num; + int idx_outer = (i / inner_num) * axis_size; + int real_index = idx_outer * inner_num + idx_inner; + + float max_data = din[real_index]; + // get max + for (int j = 1; j < axis_size; ++j) { + real_index += inner_num; + max_data = din[real_index] > max_data ? 
din[real_index] : max_data; + } + + real_index = idx_outer * inner_num + idx_inner; + // sub, exp and sum + dout[real_index] = expf(din[real_index] - max_data); + float sum_data = dout[real_index]; + for (int j = 1; j < axis_size; ++j) { + real_index += inner_num; + dout[real_index] = expf(din[real_index] - max_data); + sum_data += dout[real_index]; + } + + float sum_inv = 1.f / sum_data; + real_index = idx_outer * inner_num + idx_inner; + // get softmax result + for (int j = 0; j < axis_size; ++j) { + dout[real_index] *= sum_inv; + real_index += inner_num; + } + } +} + +template <> +void softmax_inner1_large_axis(const float* din, + float* dout, + const int outer_size, + const int axis_size) { +#pragma omp parallel for + for (int i = 0; i < outer_size; ++i) { + const float* din_ptr = din + i * axis_size; + float* dout_ptr = dout + i * axis_size; + + const float* din_max_ptr = din_ptr; + int nn = axis_size >> 2; + + // get max + float32x4_t vmax = vld1q_f32(din_max_ptr); + din_max_ptr += 4; + int j = 1; + for (; j < nn; ++j) { + vmax = vmaxq_f32(vmax, vld1q_f32(din_max_ptr)); + din_max_ptr += 4; + } + float32x2_t vhmax = vmax_f32(vget_high_f32(vmax), vget_low_f32(vmax)); + float max_data = std::max(vget_lane_f32(vhmax, 0), vget_lane_f32(vhmax, 1)); + for (j = 4 * j; j < axis_size; ++j) { + max_data = std::max(max_data, din_max_ptr[0]); + din_max_ptr++; + } + + // sub, exp and sum + const float* din_sum_ptr = din_ptr; + float* dout_sum_ptr = dout_ptr; + vmax = vdupq_n_f32(max_data); + float32x4_t vsub_exp = exp_ps(vsubq_f32(vld1q_f32(din_sum_ptr), vmax)); + float32x4_t vsum = vsub_exp; + vst1q_f32(dout_sum_ptr, vsub_exp); + din_sum_ptr += 4; + dout_sum_ptr += 4; + + j = 1; + for (; j < nn; ++j) { + vsub_exp = exp_ps(vsubq_f32(vld1q_f32(din_sum_ptr), vmax)); + vst1q_f32(dout_sum_ptr, vsub_exp); + vsum = vaddq_f32(vsum, vsub_exp); + din_sum_ptr += 4; + dout_sum_ptr += 4; + } + float32x2_t vhsum = vadd_f32(vget_high_f32(vsum), vget_low_f32(vsum)); + float sum_data = vget_lane_f32(vhsum, 0) + vget_lane_f32(vhsum, 1); + + for (j = 4 * j; j < axis_size; ++j) { + dout_sum_ptr[0] = expf(din_sum_ptr[0] - max_data); + sum_data += dout_sum_ptr[0]; + din_sum_ptr++; + dout_sum_ptr++; + } + + float sum_inv = 1.f / sum_data; + float* dout_res_ptr = dout_ptr; + float32x4_t vinv = vdupq_n_f32(sum_inv); + // get softmax result + j = 0; + for (; j < nn; ++j) { + float32x4_t vout = vld1q_f32(dout_res_ptr); + float32x4_t vres = vmulq_f32(vout, vinv); + vst1q_f32(dout_res_ptr, vres); + dout_res_ptr += 4; + } + for (j = nn * 4; j < axis_size; ++j) { + dout_ptr[j] *= sum_inv; + } + } +} + +template <> +void softmax_inner1_small_axis(const float* din, + float* dout, + const int outer_size, + const int axis_size) { +#pragma omp parallel for + for (int i = 0; i < outer_size; ++i) { + const float* din_ptr = din + i * axis_size; + float* dout_ptr = dout + i * axis_size; + // get max + float max_data = din_ptr[0]; + for (int j = 1; j < axis_size; ++j) { + max_data = std::max(max_data, din_ptr[j]); + } + + // sub, exp and sum + float sum_data = 0.f; + for (int j = 0; j < axis_size; ++j) { + dout_ptr[j] = expf(din_ptr[j] - max_data); + sum_data += dout_ptr[j]; + } + + float sum_inv = 1.f / sum_data; + for (int j = 0; j < axis_size; ++j) { + dout_ptr[j] *= sum_inv; + } + } +} + +} // namespace math +} // namespace arm +} // namespace lite +} // namespace paddle diff --git a/lite/arm/math/softmax.h b/lite/arm/math/softmax.h new file mode 100644 index 00000000000..cc1957a73ec --- /dev/null +++ b/lite/arm/math/softmax.h @@ 
-0,0 +1,71 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+namespace paddle {
+namespace lite {
+namespace arm {
+namespace math {
+
+template <typename T>
+void softmax_basic(const T* din,
+                   T* dout,
+                   const int axis_size,
+                   const int inner_num,
+                   const int outer_num);
+
+template <typename T>
+void softmax_inner8_axis4(const T* din,
+                          T* dout,
+                          const int axis_size,
+                          const int inner_num,
+                          const int outer_num);
+
+template <typename T>
+void softmax_inner4_axis4(const T* din,
+                          T* dout,
+                          const int axis_size,
+                          const int inner_num,
+                          const int outer_num);
+template <typename T>
+void softmax_inner8(const T* din,
+                    T* dout,
+                    const int axis_size,
+                    const int inner_num,
+                    const int outer_num);
+
+template <typename T>
+void softmax_inner4(const T* din,
+                    T* dout,
+                    const int axis_size,
+                    const int inner_num,
+                    const int outer_num);
+
+template <typename T>
+void softmax_inner1_large_axis(const T* din,
+                               T* dout,
+                               const int outer_size,
+                               const int axis_size);
+
+template <typename T>
+void softmax_inner1_small_axis(const T* din,
+                               T* dout,
+                               const int outer_size,
+                               const int axis_size);
+
+}  // namespace math
+}  // namespace arm
+}  // namespace lite
+}  // namespace paddle
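All of the vectorized variants declared above compute the same numerically stable softmax: subtract the per-axis max before exponentiating so expf cannot overflow, then normalize by the sum. A minimal scalar sketch of that contract (illustrative only, not part of the patch; assumes <cmath> and <algorithm>):

    void softmax_ref(const float* x, float* y, int n) {
      float max_v = x[0];
      for (int i = 1; i < n; ++i) max_v = std::max(max_v, x[i]);  // max-shift for stability
      float sum = 0.f;
      for (int i = 0; i < n; ++i) {
        y[i] = expf(x[i] - max_v);  // exp of the shifted input
        sum += y[i];
      }
      float inv = 1.f / sum;
      for (int i = 0; i < n; ++i) y[i] *= inv;  // normalize
    }

The inner_num/outer_num variants apply exactly this recipe but stride the axis across non-contiguous memory (inner_num > 1), while the inner1 variants vectorize along a contiguous axis.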
+ +#include "lite/arm/math/split.h" +#include +#include "lite/arm/math/funcs.h" + +namespace paddle { +namespace lite { +namespace arm { +namespace math { + +template <> +void split_cpy(const float* din, float* dout, int num) { + int cnt = num >> 4; + int remain = num % 16; +#pragma omp parallel for + for (int i = 0; i < cnt; i++) { + const float* din_ptr = din + (i << 4); + float* dout_ptr = dout + (i << 4); + + float32x4_t din0 = vld1q_f32(din_ptr); + float32x4_t din1 = vld1q_f32(din_ptr + 4); + float32x4_t din2 = vld1q_f32(din_ptr + 8); + float32x4_t din3 = vld1q_f32(din_ptr + 12); + + vst1q_f32(dout_ptr, din0); + vst1q_f32(dout_ptr + 4, din1); + vst1q_f32(dout_ptr + 8, din2); + vst1q_f32(dout_ptr + 12, din3); + } + if (remain > 0) { + const float* din_ptr = din + (cnt << 4); + float* dout_ptr = dout + (cnt << 4); + for (int i = 0; i < remain; i++) { + *dout_ptr = *din_ptr; + dout_ptr++; + din_ptr++; + } + } +} + +template <> +void split(const float* din, + const std::vector& dout, + const int axis, + const std::vector& in_strides) { + int input_offset = 0; + for (auto out : dout) { + auto out_dim = out->dims(); + std::vector out_strides(out_dim.size()); + out_strides[out_dim.size() - 1] = out_dim[out_dim.size() - 1]; + for (int i = out_dim.size() - 2; i >= 0; --i) { + out_strides[i] = out_strides[i + 1] * out_dim[i]; + } + + float* out_data = out->mutable_data(); + int before = out_strides[0] / out_strides[axis]; + int in_after = in_strides[axis]; + int out_after = out_strides[axis]; + + for (int i = 0; i < before; ++i) { + split_cpy(din + input_offset + i * in_after, + out_data + i * out_after, + out_after); + } + input_offset += out_strides[axis]; + } +} + +} // namespace math +} // namespace arm +} // namespace lite +} // namespace paddle diff --git a/lite/arm/math/split.h b/lite/arm/math/split.h new file mode 100644 index 00000000000..2c6f392cc50 --- /dev/null +++ b/lite/arm/math/split.h @@ -0,0 +1,37 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include "lite/core/op_lite.h" + +namespace paddle { +namespace lite { +namespace arm { +namespace math { + +template +void split_cpy(const T* din, T* dout, int num); + +template +void split(const T* din, + const std::vector& dout, + const int axis, + const std::vector& in_strides); + +} // namespace math +} // namespace arm +} // namespace lite +} // namespace paddle diff --git a/lite/arm/math/topk.cc b/lite/arm/math/topk.cc new file mode 100644 index 00000000000..741bb2561a5 --- /dev/null +++ b/lite/arm/math/topk.cc @@ -0,0 +1,53 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
diff --git a/lite/arm/math/split.h b/lite/arm/math/split.h
new file mode 100644
index 00000000000..2c6f392cc50
--- /dev/null
+++ b/lite/arm/math/split.h
@@ -0,0 +1,37 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <vector>
+#include "lite/core/op_lite.h"
+
+namespace paddle {
+namespace lite {
+namespace arm {
+namespace math {
+
+template <typename T>
+void split_cpy(const T* din, T* dout, int num);
+
+template <typename T>
+void split(const T* din,
+           const std::vector<lite::Tensor*>& dout,
+           const int axis,
+           const std::vector<int>& in_strides);
+
+}  // namespace math
+}  // namespace arm
+}  // namespace lite
+}  // namespace paddle
diff --git a/lite/arm/math/topk.cc b/lite/arm/math/topk.cc
new file mode 100644
index 00000000000..741bb2561a5
--- /dev/null
+++ b/lite/arm/math/topk.cc
@@ -0,0 +1,53 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lite/arm/math/topk.h"
+#include <algorithm>
+#include <utility>
+#include "lite/arm/math/funcs.h"
+namespace paddle {
+namespace lite {
+namespace arm {
+namespace math {
+bool comp_func(std::pair<float, int> a, std::pair<float, int> b) {
+  return (a.first > b.first);
+}
+
+void topk(const float* in_data,
+          float* out_val,
+          int* out_ind,
+          int m,
+          int n,
+          int k,
+          Context<TARGET(kARM)>* ctx) {
+  for (int i = 0; i < m; i++) {
+    const float* in_tmp = in_data + i * n;
+    float* out_val_tmp = out_val + i * k;
+    int* out_ind_tmp = out_ind + i * k;
+    std::vector<std::pair<float, int>> vec;
+    for (int j = 0; j < n; j++) {
+      vec.push_back(std::make_pair(in_tmp[j], j));
+    }
+    std::partial_sort(vec.begin(), vec.begin() + k, vec.end(), comp_func);
+    for (int q = 0; q < k; q++) {
+      out_val_tmp[q] = vec[q].first;
+      out_ind_tmp[q] = vec[q].second;
+    }
+  }
+}
+
+}  // namespace math
+}  // namespace arm
+}  // namespace lite
+}  // namespace paddle
diff --git a/lite/arm/math/topk.h b/lite/arm/math/topk.h
new file mode 100644
index 00000000000..5bf472e1af4
--- /dev/null
+++ b/lite/arm/math/topk.h
@@ -0,0 +1,34 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+#include "lite/core/context.h"
+
+namespace paddle {
+namespace lite {
+namespace arm {
+namespace math {
+
+void topk(const float* din,
+          float* out_val,
+          int* out_ind,
+          int m,
+          int n,
+          int k,
+          Context<TARGET(kARM)>* ctx);
+
+}  // namespace math
+}  // namespace arm
+}  // namespace lite
+}  // namespace paddle
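A small usage sketch of topk above (buffers preallocated by the caller; values illustrative). Note that ctx is accepted but unused by this implementation, so nullptr is fine here:

    float in[10] = {0.1f, 0.9f, 0.3f, 0.7f, 0.5f,
                    1.0f, 0.2f, 0.8f, 0.4f, 0.6f};
    float val[6];  // m * k values
    int ind[6];    // m * k indices
    paddle::lite::arm::math::topk(in, val, ind, /*m=*/2, /*n=*/5, /*k=*/3, nullptr);
    // row 0 -> val {0.9f, 0.7f, 0.5f}, ind {1, 3, 4}

std::partial_sort orders only the first k positions, which is O(n log k) per row instead of a full O(n log n) sort.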
+ +#include "lite/arm/math/type_trans.h" +#include +#include +#include +#include "lite/arm/math/saturate.h" + +namespace paddle { +namespace lite { +namespace arm { +namespace math { + +template +void int32_to_dtype(const int* din, + dtype* dout, + const float* scale, + int axis_size, + int64_t outer_size, + int64_t inner_size); + +void fp32_to_int8(const float* din, + int8_t* dout, + const float* scale, + int axis_size, + int64_t outer_size, + int64_t inner_size) { + int cnt = inner_size / 16; + int remain = inner_size & 15; + int64_t loop_size = outer_size * axis_size; + +#pragma omp parallel for + for (int j = 0; j < loop_size; ++j) { + float inv_scale = 1.f / scale[j % axis_size]; + float32x4_t vzero = vdupq_n_f32(0.f); + float32x4_t vscale = vdupq_n_f32(inv_scale); + float32x4_t vpoff = vdupq_n_f32(0.5f); + float32x4_t vnoff = vdupq_n_f32(-0.5f); + const float* din_c = din + j * inner_size; + signed char* dout_c = dout + j * inner_size; + if (cnt > 0) { + int cnt_loop = cnt; + const float* din_ptr = din_c; + signed char* dout_ptr = dout_c; +#ifdef __aarch64__ + asm volatile( + "ldp q0, q1, [%[in]], #32 \n" + "ldp q2, q3, [%[in]], #32 \n" + "0: \n" /* main loop */ + "fmul v4.4s, v0.4s, %[scale].4s \n" + "fmul v5.4s, v1.4s, %[scale].4s \n" + "fmul v6.4s, v2.4s, %[scale].4s \n" + "fmul v7.4s, v3.4s, %[scale].4s \n" + "ldp q0, q1, [%[in]], #32 \n" + "subs %[cnt], %[cnt], #1 \n" + "FCVTAS v8.4s, v4.4s \n" + "FCVTAS v9.4s, v5.4s \n" + "FCVTAS v10.4s, v6.4s \n" + "FCVTAS v11.4s, v7.4s \n" + "ldp q2, q3, [%[in]], #32 \n" + "sqxtn v4.4h, v8.4s \n" + "sqxtn2 v4.8h, v9.4s \n" + "sqxtn v5.4h, v10.4s \n" + "sqxtn2 v5.8h, v11.4s \n" + "sqxtn v8.8b, v4.8h \n" + "sqxtn2 v8.16b, v5.8h \n" + "str q8, [%[out]], #16 \n" + "bne 0b \n" + : [in] "+r"(din_ptr), [out] "+r"(dout_ptr), [cnt] "+r"(cnt_loop) + : [scale] "w"(vscale) + : "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v11"); +#else + asm volatile( + "vld1.32 {d0-d3}, [%[din]]! @ load in0~in7\n" + "vld1.32 {d4-d7}, [%[din]]! @ load in8~in16\n" + "0: @ main loop\n" + "vand.i32 q4, %q[vpoff], %q[vpoff] @ set offset, 0.5\n" + "vand.i32 q5, q4, q4 @ set offset, 0.5\n" + "vand.i32 q6, q4, q4 @ set offset, 0.5\n" + "vand.i32 q7, q4, q4 @ set offset, 0.5\n" + "vcgt.f32 q8, q0, %q[vzero] @ get mask > 0, in0\n" + "vcgt.f32 q9, q1, %q[vzero] @ get mask > 0, in1\n" + "vcgt.f32 q10, q2, %q[vzero] @ get mask > 0, in2\n" + "vcgt.f32 q11, q3, %q[vzero] @ get mask > 0, in3\n" + "vbif.f32 q4, %q[vnoff], q8 @ get right offset\n" + "vbif.f32 q5, %q[vnoff], q9 @ get right offset\n" + "vbif.f32 q6, %q[vnoff], q10 @ get right offset\n" + "vbif.f32 q7, %q[vnoff], q11 @ get right offset\n" + "vmla.f32 q4, q0, %q[vscale] @ mul scale\n" + "vmla.f32 q5, q1, %q[vscale] @ mul scale\n" + "vmla.f32 q6, q2, %q[vscale] @ mul scale\n" + "vmla.f32 q7, q3, %q[vscale] @ mul scale\n" + "vcvt.s32.f32 q0, q4 @ cvt to int32\n" + "vcvt.s32.f32 q1, q5 @ cvt to int32\n" + "vcvt.s32.f32 q2, q6 @ cvt to int32\n" + "vcvt.s32.f32 q3, q7 @ cvt to int32\n" + "vqmovn.s32 d8, q0 @ cnt to int16\n" + "vqmovn.s32 d9, q1 @ cnt to int16\n" + "vqmovn.s32 d10, q2 @ cnt to int16\n" + "vqmovn.s32 d11, q3 @ cnt to int16\n" + "vld1.32 {d0-d3}, [%[din]]! @ load in0~in7\n" + "vqmovn.s16 d12, q4 @ cnt to int8\n" + "vqmovn.s16 d13, q5 @ cnt to int8\n" + "vld1.32 {d4-d7}, [%[din]]! @ load in8~in16\n" + "vst1.32 {d12-d13}, [%[dout]]! 
@ write to output\n" + "subs %[cnt], #1 @ loop count -1\n" + "bne 0b @ to main loop\n" + + : [dout] "+r"(dout_ptr), [din] "+r"(din_ptr), [cnt] "+r"(cnt_loop) + : [vscale] "w"(vscale), + [vpoff] "w"(vpoff), + [vnoff] "w"(vnoff), + [vzero] "w"(vzero) + : "q0", + "q1", + "q2", + "q3", + "q4", + "q5", + "q6", + "q7", + "q8", + "q9", + "q10", + "q11"); +#endif + } + const float* din_r = din_c + 16 * cnt; + signed char* dout_r = dout_c + 16 * cnt; + for (int i = 0; i < remain; ++i) { + dout_r[i] = saturate_cast(roundf(inv_scale * din_r[i])); + } + } +} + +void fp32_to_int16(const float* din, + int16_t* dout, + const float* scale, + int axis_size, + int64_t outer_size, + int64_t inner_size) { + int cnt = inner_size / 8; + int remain = inner_size & 7; + int64_t loop_size = outer_size * axis_size; + +#pragma omp parallel for + for (int j = 0; j < loop_size; ++j) { + float inv_scale = 1.f / scale[j % axis_size]; + float32x4_t vzero = vdupq_n_f32(0.f); + float32x4_t vscale = vdupq_n_f32(inv_scale); + float32x4_t vpoff = vdupq_n_f32(0.5f); + float32x4_t vnoff = vdupq_n_f32(-0.5f); + const float* din_c = din + j * inner_size; + int16_t* dout_c = dout + j * inner_size; + if (cnt > 0) { + int cnt_loop = cnt; + const float* din_ptr = din_c; + int16_t* dout_ptr = dout_c; +#ifdef __aarch64__ + asm volatile( + "ldp q0, q1, [%[in]], #32 \n" + "0: \n" /* main loop */ + "fmul v4.4s, v0.4s, %[scale].4s \n" + "fmul v5.4s, v1.4s, %[scale].4s \n" + "ldp q0, q1, [%[in]], #32 \n" + "subs %[cnt], %[cnt], #1 \n" + "FCVTAS v8.4s, v4.4s \n" + "FCVTAS v9.4s, v5.4s \n" + "sqxtn v4.4h, v8.4s \n" + "sqxtn2 v4.8h, v9.4s \n" + "str q4, [%[out]], #16 \n" + "bne 0b \n" + : [in] "+r"(din_ptr), [out] "+r"(dout_ptr), [cnt] "+r"(cnt_loop) + : [scale] "w"(vscale) + : "v0", "v1", "v4", "v5", "v8", "v9"); +#else + asm volatile( + "vld1.32 {d0-d3}, [%[din]]! @ load in0~in7\n" + "0: @ main loop\n" + "vand.i32 q4, %q[vpoff], %q[vpoff] @ set offset, 0.5\n" + "vand.i32 q5, q4, q4 @ set offset, 0.5\n" + "vand.i32 q6, q4, q4 @ set offset, 0.5\n" + "vand.i32 q7, q4, q4 @ set offset, 0.5\n" + "vcgt.f32 q8, q0, %q[vzero] @ get mask > 0, in0\n" + "vcgt.f32 q9, q1, %q[vzero] @ get mask > 0, in1\n" + "vbif.f32 q4, %q[vnoff], q8 @ get right offset\n" + "vbif.f32 q5, %q[vnoff], q9 @ get right offset\n" + "vmla.f32 q4, q0, %q[vscale] @ mul scale\n" + "vmla.f32 q5, q1, %q[vscale] @ mul scale\n" + "vcvt.s32.f32 q0, q4 @ cvt to int32\n" + "vcvt.s32.f32 q1, q5 @ cvt to int32\n" + "vqmovn.s32 d8, q0 @ cnt to int16\n" + "vqmovn.s32 d9, q1 @ cnt to int16\n" + "vld1.32 {d0-d3}, [%[din]]! @ load in0~in7\n" + "vst1.32 {d8-d9}, [%[dout]]! 
@ write to output\n" + "subs %[cnt], #1 @ loop count -1\n" + "bne 0b @ to main loop\n" + + : [dout] "+r"(dout_ptr), [din] "+r"(din_ptr), [cnt] "+r"(cnt_loop) + : [vscale] "w"(vscale), + [vpoff] "w"(vpoff), + [vnoff] "w"(vnoff), + [vzero] "w"(vzero) + : "q0", "q1", "q4", "q5", "q6", "q7", "q8", "q9"); +#endif + } + const float* din_r = din_c + 8 * cnt; + int16_t* dout_r = dout_c + 8 * cnt; + for (int i = 0; i < remain; ++i) { + dout_r[i] = saturate_cast(roundf(inv_scale * din_r[i])); + } + } +} + +void int8_to_fp32(const int8_t* in, + float* out, + const float* scale, + int axis_size, + int64_t outer_size, + int64_t inner_size) { + int cnt = inner_size / 16; + int remain = inner_size & 15; + int64_t loop_size = axis_size * outer_size; +#pragma omp parallel for + for (int64_t n = 0; n < loop_size; ++n) { + float in_scale = scale[n % axis_size]; + const signed char* din_c = in + n * inner_size; + float* dout_c = out + n * inner_size; + float32x4_t vscale = vdupq_n_f32(in_scale); + if (cnt > 0) { + int loop = cnt; + const signed char* din_ptr = din_c; + float* dout_ptr = dout_c; +#ifdef __aarch64__ + asm volatile( + "ldp d0, d1, [%[in]], #16 \n" /* load 16 int8*/ + "0: \n" /* main loop */ + "sshll v2.8h, v0.8b, #0 \n" /* trans to int16*/ + "sshll v3.8h, v1.8b, #0 \n" /* trans to int16*/ + + "sshll v4.4s, v2.4h, #0 \n" /* trans to int32*/ + "sshll2 v5.4s, v2.8h, #0 \n" /* trans to int32*/ + "sshll v6.4s, v3.4h, #0 \n" /* trans to int32*/ + "sshll2 v7.4s, v3.8h, #0 \n" /* trans to int32*/ + + "ldp d0, d1, [%[in]], #16 \n" /* load 16 int8*/ + + "scvtf v8.4s, v4.4s \n" /* trans to fp32*/ + "scvtf v9.4s, v5.4s \n" /* trans to fp32*/ + "scvtf v10.4s, v6.4s \n" /* trans to fp32*/ + "scvtf v11.4s, v7.4s \n" /* trans to fp32*/ + + "subs %[loop], %[loop], #1 \n" + + "fmul v4.4s, v8.4s, %[scale].4s \n" /* mul with scale*/ + "fmul v5.4s, v9.4s, %[scale].4s \n" /* mul with scale*/ + "fmul v6.4s, v10.4s, %[scale].4s \n" /* mul with scale*/ + "fmul v7.4s, v11.4s, %[scale].4s \n" /* mul with scale*/ + + "stp q4, q5, [%[out]], #32 \n" /* write to memory*/ + "stp q6, q7, [%[out]], #32 \n" /* write to memory*/ + + "bne 0b \n" + : [loop] "+r"(loop), [in] "+r"(din_ptr), [out] "+r"(dout_ptr) + : [scale] "w"(vscale) + : "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v11"); +#else + asm volatile( + "vld1.32 {d0-d1}, [%[in]]! @ load 16 int8\n" + "0: @ main loop\n" + "vmovl.s8 q2, d0 @ trans to int16\n" + "vmovl.s8 q3, d1 @ trans to int16\n" + "vmovl.s16 q4, d4 @ trans to int32\n" + "vmovl.s16 q5, d5 @ trans to int32\n" + "vmovl.s16 q6, d6 @ trans to int32\n" + "vmovl.s16 q7, d7 @ trans to int32\n" + "vcvt.f32.s32 q0, q4 @ trans to fp32\n" + "vcvt.f32.s32 q1, q5 @ trans to fp32\n" + "vcvt.f32.s32 q2, q6 @ trans to fp32\n" + "vcvt.f32.s32 q3, q7 @ trans to fp32\n" + "vmul.f32 q4, q0, %q[scale] @ mul with scale\n" + "vmul.f32 q5, q1, %q[scale] @ mul with scale\n" + "vmul.f32 q6, q2, %q[scale] @ mul with scale\n" + "vmul.f32 q7, q3, %q[scale] @ mul with scale\n" + + "vld1.32 {d0-d1}, [%[in]]! @ load 16 int8\n" + + "subs %[loop], #1 \n" + + "vst1.f32 {d8-d11}, [%[out]]! @ write to memory\n" + "vst1.f32 {d12-d15}, [%[out]]! 
@ write to memory\n" + + "bne 0b \n" + : [loop] "+r"(loop), [in] "+r"(din_ptr), [out] "+r"(dout_ptr) + : [scale] "w"(vscale) + : "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7"); +#endif // __aarch64__ + } + const signed char* din_r = din_c + 16 * cnt; + float* dout_r = dout_c + 16 * cnt; + for (int i = 0; i < remain; ++i) { + dout_r[i] = in_scale * din_r[i]; + } + } +} + +void int16_to_fp32(const int16_t* in, + float* out, + const float* scale, + int axis_size, + int64_t outer_size, + int64_t inner_size) { + int cnt = inner_size / 16; + int remain = inner_size & 15; + int64_t loop_size = axis_size * outer_size; +#pragma omp parallel for + for (int64_t n = 0; n < loop_size; ++n) { + float in_scale = scale[n % axis_size]; + const int16_t* din_c = in + n * inner_size; + float* dout_c = out + n * inner_size; + float32x4_t vscale = vdupq_n_f32(in_scale); + if (cnt > 0) { + int loop = cnt; + const int16_t* din_ptr = din_c; + float* dout_ptr = dout_c; +#ifdef __aarch64__ + asm volatile( + "ldp q0, q1, [%[in]], #32 \n" /* load 16 int16*/ + "0: \n" /* main loop */ + "sshll v4.4s, v0.4h, #0 \n" /* trans to int32*/ + "sshll2 v5.4s, v0.8h, #0 \n" /* trans to int32*/ + "sshll v6.4s, v1.4h, #0 \n" /* trans to int32*/ + "sshll2 v7.4s, v1.8h, #0 \n" /* trans to int32*/ + + "ldp q0, q1, [%[in]], #32 \n" /* load 16 int16*/ + + "scvtf v8.4s, v4.4s \n" /* trans to fp32*/ + "scvtf v9.4s, v5.4s \n" /* trans to fp32*/ + "scvtf v10.4s, v6.4s \n" /* trans to fp32*/ + "scvtf v11.4s, v7.4s \n" /* trans to fp32*/ + + "subs %[loop], %[loop], #1 \n" + + "fmul v4.4s, v8.4s, %[scale].4s \n" /* mul with scale*/ + "fmul v5.4s, v9.4s, %[scale].4s \n" /* mul with scale*/ + "fmul v6.4s, v10.4s, %[scale].4s \n" /* mul with scale*/ + "fmul v7.4s, v11.4s, %[scale].4s \n" /* mul with scale*/ + + "stp q4, q5, [%[out]], #32 \n" /* write to memory*/ + "stp q6, q7, [%[out]], #32 \n" /* write to memory*/ + + "bne 0b \n" + : [loop] "+r"(loop), [in] "+r"(din_ptr), [out] "+r"(dout_ptr) + : [scale] "w"(vscale) + : "v0", "v1", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11"); +#else + asm volatile( + "vld1.32 {d0-d3}, [%[in]]! @ load 16 int16\n" + "0: @ main loop\n" + "vmovl.s16 q4, d0 @ trans to int32\n" + "vmovl.s16 q5, d1 @ trans to int32\n" + "vmovl.s16 q6, d2 @ trans to int32\n" + "vmovl.s16 q7, d3 @ trans to int32\n" + "vcvt.f32.s32 q0, q4 @ trans to fp32\n" + "vcvt.f32.s32 q1, q5 @ trans to fp32\n" + "vcvt.f32.s32 q2, q6 @ trans to fp32\n" + "vcvt.f32.s32 q3, q7 @ trans to fp32\n" + "vmul.f32 q4, q0, %q[scale] @ mul with scale\n" + "vmul.f32 q5, q1, %q[scale] @ mul with scale\n" + "vmul.f32 q6, q2, %q[scale] @ mul with scale\n" + "vmul.f32 q7, q3, %q[scale] @ mul with scale\n" + + "vld1.32 {d0-d3}, [%[in]]! @ load 16 int8\n" + + "subs %[loop], #1 \n" + + "vst1.f32 {d8-d11}, [%[out]]! @ write to memory\n" + "vst1.f32 {d12-d15}, [%[out]]! 
@ write to memory\n" + + "bne 0b \n" + : [loop] "+r"(loop), [in] "+r"(din_ptr), [out] "+r"(dout_ptr) + : [scale] "w"(vscale) + : "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7"); +#endif // __aarch64__ + } + const int16_t* din_r = din_c + 16 * cnt; + float* dout_r = dout_c + 16 * cnt; + for (int i = 0; i < remain; ++i) { + dout_r[i] = in_scale * din_r[i]; + } + } +} + +void int32_to_fp32(const int* din, + float* dout, + const float* scale, + int axis_size, + int64_t outer_size, + int64_t inner_size) { + int cnt = inner_size / 16; + int remain = inner_size & 15; + int64_t loop_size = axis_size * outer_size; +#pragma omp parallel for + for (int64_t n = 0; n < loop_size; ++n) { + float in_scale = scale[n % axis_size]; + const int* din_c = din + n * inner_size; + float* dout_c = dout + n * inner_size; + float32x4_t vscale = vdupq_n_f32(in_scale); + if (cnt > 0) { + int loop = cnt; + const int* din_ptr = din_c; + float* dout_ptr = dout_c; +#ifdef __aarch64__ + asm volatile( + "ldp q0, q1, [%[in]], #32 \n" + "ldp q2, q3, [%[in]], #32 \n" + "0: \n" + "scvtf v4.4s, v0.4s \n" + "scvtf v5.4s, v1.4s \n" + "scvtf v6.4s, v2.4s \n" + "scvtf v7.4s, v3.4s \n" + "ldp q0, q1, [%[in]], #32 \n" + "fmul v8.4s, v4.4s, %[scale].4s \n" + "fmul v9.4s, v5.4s, %[scale].4s \n" + "fmul v10.4s, v6.4s, %[scale].4s \n" + "fmul v11.4s, v7.4s, %[scale].4s \n" + "ldp q2, q3, [%[in]], #32 \n" + "stp q8, q9, [%[out]], #32 \n" + "stp q10, q11, [%[out]], #32 \n" + "subs %[loop], %[loop], #1 \n" + "bne 0b \n" + : [loop] "+r"(loop), [in] "+r"(din_ptr), [out] "+r"(dout_ptr) + : [scale] "w"(vscale) + : "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v11"); +#else + asm volatile( + "vld1.s32 {d0-d3}, [%[in]]! \n" + "vld1.s32 {d4-d7}, [%[in]]! \n" + "0: \n" + "vcvt.f32.s32 q4, q0 \n" + "vcvt.f32.s32 q5, q1 \n" + "vcvt.f32.s32 q6, q2 \n" + "vcvt.f32.s32 q7, q3 \n" + "vld1.s32 {d0-d3}, [%[in]]! \n" + "vmul.f32 q8, q4, %q[scale] \n" + "vmul.f32 q9, q5, %q[scale] \n" + "vmul.f32 q10, q6, %q[scale] \n" + "vmul.f32 q11, q7, %q[scale] \n" + "vld1.s32 {d4-d7}, [%[in]]! \n" + "subs %[loop], #1 \n" + "vst1.f32 {d16-d19}, [%[out]]! \n" + "vst1.f32 {d20-d23}, [%[out]]! 
\n" + "bne 0b \n" + : [loop] "+r"(loop), [in] "+r"(din_ptr), [out] "+r"(dout_ptr) + : [scale] "w"(vscale) + : "q0", + "q1", + "q2", + "q3", + "q4", + "q5", + "q6", + "q7", + "q8", + "q9", + "q10", + "q11"); +#endif // __aarch64__ + } + const int* din_r = din_c + 16 * cnt; + float* dout_r = dout_c + 16 * cnt; + for (int i = 0; i < remain; ++i) { + dout_r[i] = in_scale * din_r[i]; + } + } +} + +void int32_to_int8(const int* din, + int8_t* dout, + const float* scale, + int axis_size, + int64_t outer_size, + int64_t inner_size) { + int cnt = inner_size / 16; + int remain = inner_size & 15; + int64_t loop_size = outer_size * axis_size; +#pragma omp parallel for + for (int64_t n = 0; n < loop_size; ++n) { + float in_scale = scale[n % axis_size]; + const int* din_c = din + n * inner_size; + int8_t* dout_c = dout + n * inner_size; + float32x4_t vscale = vdupq_n_f32(in_scale); + float32x4_t vzero = vdupq_n_f32(0.f); + float32x4_t vpoff = vdupq_n_f32(0.5f); + float32x4_t vnoff = vdupq_n_f32(-0.5f); + if (cnt > 0) { + int loop = cnt; + const int* din_ptr = din_c; + int8_t* dout_ptr = dout_c; +#ifdef __aarch64__ + asm volatile( + "0: \n" + "ld1 {v0.4s, v1.4s}, [%[in]], #32 \n" + "ld1 {v2.4s, v3.4s}, [%[in]], #32 \n" + + "scvtf v4.4s, v0.4s \n" + "scvtf v5.4s, v1.4s \n" + "scvtf v6.4s, v2.4s \n" + "scvtf v7.4s, v3.4s \n" + + "fmul v0.4s, v4.4s, %[scale].4s \n" + "fmul v1.4s, v5.4s, %[scale].4s \n" + "fmul v2.4s, v6.4s, %[scale].4s \n" + "fmul v3.4s, v7.4s, %[scale].4s \n" + + "fcvtas v4.4s, v0.4s \n" + "fcvtas v5.4s, v1.4s \n" + "fcvtas v6.4s, v2.4s \n" + "fcvtas v7.4s, v3.4s \n" + + "sqxtn v0.4h, v4.4s \n" + "sqxtn2 v0.8h, v5.4s \n" + "sqxtn v1.4h, v6.4s \n" + "sqxtn2 v1.8h, v7.4s \n" + + "sqxtn v2.8b, v0.8h \n" + "sqxtn2 v2.16b, v1.8h \n" + + "st1 {v2.16b}, [%[out]], #16 \n" + "subs %[loop], %[loop], #1 \n" + "bne 0b \n" + : [loop] "+r"(loop), [in] "+r"(din_ptr), [out] "+r"(dout_ptr) + : [scale] "w"(vscale) + : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7"); +#else + asm volatile( + "vld1.32 {d0-d3}, [%[din]]! @ load in0~in7\n" + "vld1.32 {d4-d7}, [%[din]]! @ load in8~in16\n" + "0: @ main loop\n" + "vcvt.f32.s32 q4, q0 @ cvt to float\n" + "vcvt.f32.s32 q5, q1 @ cvt to float\n" + "vcvt.f32.s32 q6, q2 @ cvt to float\n" + "vcvt.f32.s32 q7, q3 @ cvt to float\n" + "vand.i32 q0, %q[vpoff], %q[vpoff] @ set offset, 0.5\n" + "vand.i32 q1, q0, q0 @ set offset, 0.5\n" + "vand.i32 q2, q0, q0 @ set offset, 0.5\n" + "vand.i32 q3, q0, q0 @ set offset, 0.5\n" + "vcgt.f32 q8, q4, %q[vzero] @ get mask > 0, in0\n" + "vcgt.f32 q9, q5, %q[vzero] @ get mask > 0, in1\n" + "vcgt.f32 q10, q6, %q[vzero] @ get mask > 0, in2\n" + "vcgt.f32 q11, q7, %q[vzero] @ get mask > 0, in3\n" + "vbif.f32 q0, %q[vnoff], q8 @ get right offset\n" + "vbif.f32 q1, %q[vnoff], q9 @ get right offset\n" + "vbif.f32 q2, %q[vnoff], q10 @ get right offset\n" + "vbif.f32 q3, %q[vnoff], q11 @ get right offset\n" + "vmla.f32 q0, q4, %q[vscale] @ mul scale\n" + "vmla.f32 q1, q5, %q[vscale] @ mul scale\n" + "vmla.f32 q2, q6, %q[vscale] @ mul scale\n" + "vmla.f32 q3, q7, %q[vscale] @ mul scale\n" + "vcvt.s32.f32 q4, q0 @ cvt to int32\n" + "vcvt.s32.f32 q5, q1 @ cvt to int32\n" + "vcvt.s32.f32 q6, q2 @ cvt to int32\n" + "vcvt.s32.f32 q7, q3 @ cvt to int32\n" + "vqmovn.s32 d16, q4 @ cnt to int16\n" + "vqmovn.s32 d17, q5 @ cnt to int16\n" + "vqmovn.s32 d18, q6 @ cnt to int16\n" + "vqmovn.s32 d19, q7 @ cnt to int16\n" + "vld1.32 {d0-d3}, [%[din]]! 
@ load in0~in7\n" + "vqmovn.s16 d8, q8 @ cnt to int8\n" + "vqmovn.s16 d9, q9 @ cnt to int8\n" + "vld1.32 {d4-d7}, [%[din]]! @ load in8~in16\n" + "vst1.32 {d8-d9}, [%[dout]]! @ write to output\n" + "subs %[loop], #1 @ loop count -1\n" + "bne 0b @ to main loop\n" + : [loop] "+r"(loop), [din] "+r"(din_ptr), [dout] "+r"(dout_ptr) + : [vscale] "w"(vscale), + [vzero] "w"(vzero), + [vnoff] "w"(vnoff), + [vpoff] "w"(vpoff) + : "q0", + "q1", + "q2", + "q3", + "q4", + "q5", + "q6", + "q7", + "q8", + "q9", + "q10", + "q11"); +#endif // __aarch64__ + } + const int* din_r = din_c + 16 * cnt; + int8_t* dout_r = dout_c + 16 * cnt; + for (int i = 0; i < remain; ++i) { + dout_r[i] = saturate_cast(roundf(in_scale * din_r[i])); + } + } +} + +/******************************************/ +/******** kernel implement *********/ +/******************************************/ +float compute_max_kernel(const float* din, int64_t size) { + float max_value = 0.f; + int cnt = size / 16; + int remain = size & 15; + float32x4_t vmax_val = vdupq_n_f32(0.f); + const float* ptr_in = din; + if (cnt > 0) { + int loop_cnt = cnt; +#ifdef __aarch64__ + asm volatile( + "ld1 {v0.4s, v1.4s}, [%[in]], #32 \n" + "ld1 {v2.4s, v3.4s}, [%[in]], #32 \n" + "0: \n" + "fabs v4.4s, v0.4s \n" + "fabs v5.4s, v1.4s \n" + "fabs v6.4s, v2.4s \n" + "fabs v7.4s, v3.4s \n" + "ld1 {v0.4s, v1.4s}, [%[in]], #32 \n" + "fmax v2.4s, v4.4s, v5.4s \n" + "fmax v3.4s, v6.4s, v7.4s \n" + "fmax v4.4s, v2.4s, v3.4s \n" + "ld1 {v2.4s, v3.4s}, [%[in]], #32 \n" + "fmax %[max_val].4s, v4.4s, %[max_val].4s \n" + "subs %[cnt], %[cnt], #1 \n" + "bne 0b \n" + : [in] "+r"(ptr_in), [cnt] "+r"(loop_cnt), [max_val] "+w"(vmax_val) + : + : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7"); +#else + asm volatile( + "vld1.32 {d0-d3}, [%[in]]! @ load 8 float\n" + "vld1.32 {d4-d7}, [%[in]]! @ load 8 float\n" + "0: @ main loop\n" + "vabs.f32 q4, q0 @ abs \n" + "vabs.f32 q5, q1 @ abs \n" + "vabs.f32 q6, q2 @ abs \n" + "vabs.f32 q7, q3 @ abs \n" + "vld1.32 {d0-d3}, [%[in]]! @ load 8 float\n" + "vmax.f32 q2, q4, q5 @ max \n" + "vmax.f32 q3, q6, q7 @ max \n" + "vmax.f32 q4, q2, q3 @ max \n" + "vld1.32 {d4-d7}, [%[in]]! @ load 8 float\n" + "vmax.f32 %q[max_val], q4, %q[max_val] @ max \n" + "subs %[cnt], #1 @ loop count -1\n" + "bne 0b @ jump to main loop\n" + + : [in] "+r"(ptr_in), [cnt] "+r"(loop_cnt), [max_val] "+w"(vmax_val) + : + : "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7"); +#endif + float32x2_t vmax_p = + vpmax_f32(vget_high_f32(vmax_val), vget_low_f32(vmax_val)); + float max0 = vget_lane_f32(vmax_p, 0); + float max1 = vget_lane_f32(vmax_p, 1); + float max2 = max0 > max1 ? max0 : max1; + max_value = max_value > max2 ? 
+  }
+  ptr_in = din + 16 * cnt;
+  for (int i = 0; i < remain; ++i) {
+    float data = fabsf(*(ptr_in++));
+    max_value = fmaxf(max_value, data);
+  }
+  return max_value;
+}
+
+std::vector<float> get_tensor_scale_n(const float* in_data,
+                                      int axis_size,
+                                      int64_t inner_size,
+                                      float scale_factor) {
+  std::vector<float> scale_out(axis_size);
+#pragma omp parallel for
+  for (int c = 0; c < axis_size; ++c) {  // num
+    const float* ptr_in = in_data + c * inner_size;  // channel*width*height
+    scale_out[c] = compute_max_kernel(ptr_in, inner_size) / scale_factor;
+  }
+  return scale_out;
+}
+
+std::vector<float> get_tensor_scale_chw(const float* in_data,
+                                        int axis_size,
+                                        int64_t outer_size,
+                                        int64_t inner_size,
+                                        float scale_factor) {
+  std::vector<float> scale_out(axis_size);
+  int64_t inner_size_with_axis = axis_size * inner_size;
+#pragma omp parallel for
+  for (int c = 0; c < axis_size; ++c) {
+    const float* din = in_data + c * inner_size;
+    float max_val = 0.f;
+    for (int j = 0; j < outer_size; ++j) {
+      const float* ptr_in = din + j * inner_size_with_axis;
+      max_val = fmaxf(compute_max_kernel(ptr_in, inner_size), max_val);
+    }
+    scale_out[c] = max_val / scale_factor;
+  }
+  return scale_out;
+}
+
+void int32_to_int32(const int* din,
+                    int* dout,
+                    const float* scale,
+                    int axis_size,
+                    int64_t outer_size,
+                    int64_t inner_size) {
+  int size_all = outer_size * axis_size * inner_size;
+  memmove(dout, din, size_all * sizeof(int));
+}
+
+template <>
+void int32_to_dtype(const int* din,
+                    float* dout,
+                    const float* scale,
+                    int axis_size,
+                    int64_t outer_size,
+                    int64_t inner_size) {
+  return int32_to_fp32(din, dout, scale, axis_size, outer_size, inner_size);
+}
+
+template <>
+void int32_to_dtype(const int* din,
+                    signed char* dout,
+                    const float* scale,
+                    int axis_size,
+                    int64_t outer_size,
+                    int64_t inner_size) {
+  return int32_to_int8(din, dout, scale, axis_size, outer_size, inner_size);
+}
+
+template <>
+void int32_to_dtype(const int* din,
+                    int* dout,
+                    const float* scale,
+                    int axis_size,
+                    int64_t outer_size,
+                    int64_t inner_size) {
+  return int32_to_int32(din, dout, scale, axis_size, outer_size, inner_size);
+}
+
+bool trans_tensor_int32_to_int8(Tensor* tin,
+                                Tensor* tout,
+                                float input_scale,
+                                float output_scale,
+                                std::vector<float> weights_scale,
+                                int axis) {
+  tout->Resize(tin->dims());
+
+  // fold input, weight and output scales into one rescale factor per channel
+  std::vector<float> scale(weights_scale.size());
+  for (int i = 0; i < weights_scale.size(); ++i) {
+    scale[i] = input_scale * weights_scale[i] / output_scale;
+  }
+
+  auto i_dims = tin->dims();
+  int outer_size = i_dims.count(0, axis);
+  int axis_size = i_dims[axis];
+  int inner_size = i_dims.count(axis + 1, i_dims.size());
+
+  const int* i_data = tin->data<int>();
+  int8_t* o_data = tout->mutable_data<int8_t>();
+  int32_to_int8(
+      i_data, o_data, scale.data(), axis_size, outer_size, inner_size);
+
+  return true;
+}
+
+template <>
+bool get_tensor_scale<PRECISION(kFloat)>(const Tensor& tin,
+                                         std::vector<float>* scale_out,
+                                         int axis,
+                                         float scale_factor) {
+  int axis_size = 1;
+  if (axis >= 0 && axis < tin.dims().size()) {
+    axis_size = tin.dims()[axis];
+  }
+  int outer_size = 1;
+  if (axis >= 0) {
+    outer_size = tin.dims().count(0, axis);
+  }
+  int64_t inner_size = tin.dims().count(axis + 1, tin.dims().size());
+
+  const float* in_data = static_cast<const float*>(tin.data<float>());
+  if (axis <= 0) {
+    *scale_out =
+        get_tensor_scale_n(in_data, axis_size, inner_size, scale_factor);
+  } else {
+    *scale_out = get_tensor_scale_chw(
+        in_data, axis_size, outer_size, inner_size, scale_factor);
+  }
+  return true;
+}
+
+bool trans_tensor_int32_to_fp32(Tensor* tin,
+                                Tensor* tout,
+                                float input_scale,
+                                std::vector<float> weights_scale,
+                                int axis) {
+  tout->Resize(tin->dims());
+
+  // compute scale
+  std::vector<float> scale(weights_scale.size());
+  for (int i = 0; i < weights_scale.size(); ++i) {
+    scale[i] = input_scale * weights_scale[i];
+  }
+
+  auto i_dims = tin->dims();
+  int outer_size = i_dims.count(0, axis);
+  int axis_size = i_dims[axis];
+  int inner_size = i_dims.count(axis + 1, i_dims.size());
+
+  const auto* i_data = tin->data<int>();
+  float* o_data = tout->mutable_data<float>();
+  //! convert to fp32
+  int32_to_fp32(
+      i_data, o_data, scale.data(), axis_size, outer_size, inner_size);
+  return true;
+}
+
+bool trans_tensor_fp32_to_int8(Tensor* tin, Tensor* tout, float input_scale) {
+  tout->Resize(tin->dims());
+
+  // compute scale
+  std::vector<float> scale({input_scale});
+  int inner_size = tin->dims().production();
+
+  const auto* i_data = tin->data<float>();
+  int8_t* o_data = tout->mutable_data<int8_t>();
+  fp32_to_int8(i_data, o_data, scale.data(), 1, 1, inner_size);
+  return true;
+}
+
+bool trans_fp32_bias_to_int32_basic(Tensor* tin,
+                                    Tensor* tout,
+                                    float in_scale,
+                                    std::vector<float> vector_weight_scale) {
+  tout->Resize(tin->dims());
+
+  const float* i_data = tin->data<float>();
+  int* o_data = tout->mutable_data<int>();
+  for (int i = 0; i < tin->dims().production(); ++i) {
+    o_data[i] = static_cast<int>(
+        roundf(i_data[i] / in_scale / vector_weight_scale[i]));
+  }
+  return true;
+}
+
+template <>
+bool trans_tensor_dtype<PRECISION(kInt32), PRECISION(kInt8)>(
+    Tensor* tin,
+    Tensor* tout,
+    float input_scale,
+    float output_scale,
+    std::vector<float> weights_scale) {
+  return trans_tensor_int32_to_int8(
+      tin, tout, input_scale, output_scale, weights_scale, 1);
+}
+
+template <>
+bool trans_tensor_dtype<PRECISION(kInt32), PRECISION(kFloat)>(
+    Tensor* tin,
+    Tensor* tout,
+    float input_scale,
+    float output_scale,
+    std::vector<float> weights_scale) {
+  return trans_tensor_int32_to_fp32(tin, tout, input_scale, weights_scale, 1);
+}
+
+}  // namespace math
+}  // namespace arm
+}  // namespace lite
+}  // namespace paddle
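The scale folding in trans_tensor_int32_to_int8 is worth spelling out: an int32 conv accumulator represents acc * input_scale * weights_scale[c] in fp32, so mapping it onto the int8 output grid folds the three scales into the single per-channel multiplier computed above. A numeric check (values illustrative):

    float input_scale = 0.05f, weight_scale = 0.02f, output_scale = 0.1f;
    int acc = 1234;  // int32 accumulator from an int8 conv
    float rescale = input_scale * weight_scale / output_scale;  // 0.01
    // int32_to_int8 then effectively computes, per element:
    int8_t q = saturate_cast<int8_t>(roundf(acc * rescale));  // round(12.34) = 12
    // check: q * output_scale = 1.2, close to acc * input_scale * weight_scale = 1.234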
diff --git a/lite/arm/math/type_trans.h b/lite/arm/math/type_trans.h
new file mode 100644
index 00000000000..e07d798b101
--- /dev/null
+++ b/lite/arm/math/type_trans.h
@@ -0,0 +1,117 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <string>
+#include <vector>
+#include "lite/core/target_wrapper.h"
+#include "lite/core/tensor.h"
+
+namespace paddle {
+namespace lite {
+namespace arm {
+namespace math {
+
+template <PrecisionType PTin, PrecisionType PTout>
+bool trans_tensor_dtype(Tensor* tin,
+                        Tensor* tout,
+                        float input_scale,
+                        float output_scale,
+                        std::vector<float> weights_scale) {
+  LOG(FATAL) << "trans_tensor_dtype has no impl";
+  return false;
+}
+
+template <>
+bool trans_tensor_dtype<PRECISION(kInt32), PRECISION(kInt8)>(
+    Tensor* tin,
+    Tensor* tout,
+    float input_scale,
+    float output_scale,
+    std::vector<float> weights_scale);
+
+template <>
+bool trans_tensor_dtype<PRECISION(kInt32), PRECISION(kFloat)>(
+    Tensor* tin,
+    Tensor* tout,
+    float input_scale,
+    float output_scale,
+    std::vector<float> weights_scale);
+
+template <PrecisionType Ptype>
+bool get_tensor_scale(const Tensor& tin,
+                      std::vector<float>* scale_out,
+                      int axis,
+                      float scale_factor) {
+  return false;
+}
+
+std::vector<float> get_tensor_scale_n(const float* in_data,
+                                      int axis_size,
+                                      int64_t inner_size,
+                                      float scale_factor);
+
+bool trans_fp32_bias_to_int32_basic(Tensor* tin,
+                                    Tensor* tout,
+                                    float in_scale,
+                                    std::vector<float> vector_weight_scale);
+
+bool trans_tensor_int32_to_int8(Tensor* tin,
+                                Tensor* tout,
+                                float input_scale,
+                                float output_scale,
+                                std::vector<float> weights_scale,
+                                int axis = 1);
+
+bool trans_tensor_int32_to_fp32(Tensor* tin,
+                                Tensor* tout,
+                                float input_scale,
+                                std::vector<float> weights_scale,
+                                int axis = 1);
+
+bool trans_tensor_fp32_to_int8(Tensor* tin, Tensor* tout, float input_scale);
+
+template <>
+bool get_tensor_scale<PRECISION(kFloat)>(const Tensor& tin,
+                                         std::vector<float>* scale_out,
+                                         int axis,
+                                         float scale_factor);
+
+template <typename dtype>
+void int32_to_dtype(const int* din,
+                    dtype* dout,
+                    const float* scale,
+                    int axis_size,
+                    int64_t outer_size,
+                    int64_t inner_size);
+
+void fp32_to_int8(const float* din,
+                  int8_t* dout,
+                  const float* scale,
+                  int axis_size,
+                  int64_t outer_size,
+                  int64_t inner_size);
+
+void int8_to_fp32(const int8_t* in,
+                  float* out,
+                  const float* scale,
+                  int axis_size,
+                  int64_t outer_size,
+                  int64_t inner_size);
+
+}  // namespace math
+}  // namespace arm
+}  // namespace lite
+}  // namespace paddle
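A hedged sketch of how these entry points combine (tensor setup elided; axis and factor follow the conventions of the implementations above, where scale_factor is typically 127.f for int8):

    Tensor w;  // fp32 weights with the quantized axis at dim 0
    std::vector<float> w_scale;
    get_tensor_scale<PRECISION(kFloat)>(w, &w_scale, /*axis=*/0, /*scale_factor=*/127.f);

    // later, rescale a conv's int32 accumulator tensor onto the int8 output grid:
    trans_tensor_dtype<PRECISION(kInt32), PRECISION(kInt8)>(
        &acc, &out, in_scale, out_scale, w_scale);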
+ +#include "lite/arm/math/yolo_box.h" +#include "lite/arm/math/funcs.h" + +namespace paddle { +namespace lite { +namespace arm { +namespace math { + +namespace { +inline float sigmoid(float x) { return 1.f / (1.f + expf(-x)); } + +inline void get_yolo_box(float* box, + const float* x, + const int* anchors, + int i, + int j, + int an_idx, + int grid_size, + int input_size, + int index, + int stride, + int img_height, + int img_width) { + box[0] = (i + sigmoid(x[index])) * img_width / grid_size; + box[1] = (j + sigmoid(x[index + stride])) * img_height / grid_size; + box[2] = std::exp(x[index + 2 * stride]) * anchors[2 * an_idx] * img_width / + input_size; + box[3] = std::exp(x[index + 3 * stride]) * anchors[2 * an_idx + 1] * + img_height / input_size; +} + +inline int get_entry_index(int batch, + int an_idx, + int hw_idx, + int an_num, + int an_stride, + int stride, + int entry) { + return (batch * an_num + an_idx) * an_stride + entry * stride + hw_idx; +} + +inline void calc_detection_box(float* boxes, + float* box, + const int box_idx, + const int img_height, + const int img_width) { + boxes[box_idx] = box[0] - box[2] / 2; + boxes[box_idx + 1] = box[1] - box[3] / 2; + boxes[box_idx + 2] = box[0] + box[2] / 2; + boxes[box_idx + 3] = box[1] + box[3] / 2; + + boxes[box_idx] = boxes[box_idx] > 0 ? boxes[box_idx] : static_cast(0); + boxes[box_idx + 1] = + boxes[box_idx + 1] > 0 ? boxes[box_idx + 1] : static_cast(0); + boxes[box_idx + 2] = boxes[box_idx + 2] < img_width - 1 + ? boxes[box_idx + 2] + : static_cast(img_width - 1); + boxes[box_idx + 3] = boxes[box_idx + 3] < img_height - 1 + ? boxes[box_idx + 3] + : static_cast(img_height - 1); +} + +inline void calc_label_score(float* scores, + const float* input, + const int label_idx, + const int score_idx, + const int class_num, + const float conf, + const int stride) { + for (int i = 0; i < class_num; i++) { + scores[score_idx + i] = conf * sigmoid(input[label_idx + i * stride]); + } +} +} // namespace + +void yolobox(lite::Tensor* X, + lite::Tensor* ImgSize, + lite::Tensor* Boxes, + lite::Tensor* Scores, + std::vector anchors, + int class_num, + float conf_thresh, + int downsample_ratio) { + const int n = X->dims()[0]; + const int h = X->dims()[2]; + const int w = X->dims()[3]; + const int b_num = Boxes->dims()[1]; + const int an_num = anchors.size() / 2; + int X_size = downsample_ratio * h; + + const int stride = h * w; + const int an_stride = (class_num + 5) * stride; + + auto anchors_data = anchors.data(); + + const float* X_data = X->data(); + float* ImgSize_data = ImgSize->mutable_data(); + + float* Boxes_data = Boxes->mutable_data(); + + float* Scores_data = Scores->mutable_data(); + + float box[4]; + for (int i = 0; i < n; i++) { + int img_height = static_cast(ImgSize_data[2 * i]); + int img_width = static_cast(ImgSize_data[2 * i + 1]); + + for (int j = 0; j < an_num; j++) { + for (int k = 0; k < h; k++) { + for (int l = 0; l < w; l++) { + int obj_idx = + get_entry_index(i, j, k * w + l, an_num, an_stride, stride, 4); + float conf = sigmoid(X_data[obj_idx]); + if (conf < conf_thresh) { + continue; + } + + int box_idx = + get_entry_index(i, j, k * w + l, an_num, an_stride, stride, 0); + get_yolo_box(box, + X_data, + anchors_data, + l, + k, + j, + h, + X_size, + box_idx, + stride, + img_height, + img_width); + box_idx = (i * b_num + j * stride + k * w + l) * 4; + calc_detection_box(Boxes_data, box, box_idx, img_height, img_width); + + int label_idx = + get_entry_index(i, j, k * w + l, an_num, an_stride, stride, 5); + int score_idx = (i * 
diff --git a/lite/arm/math/yolo_box.h b/lite/arm/math/yolo_box.h
new file mode 100644
index 00000000000..e4543087003
--- /dev/null
+++ b/lite/arm/math/yolo_box.h
@@ -0,0 +1,37 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <vector>
+#include "lite/core/tensor.h"
+
+namespace paddle {
+namespace lite {
+namespace arm {
+namespace math {
+
+void yolobox(lite::Tensor* X,
+             lite::Tensor* ImgSize,
+             lite::Tensor* Boxes,
+             lite::Tensor* Scores,
+             std::vector<int> anchors,
+             int class_num,
+             float conf_thresh,
+             int downsample_ratio);
+
+}  // namespace math
+}  // namespace arm
+}  // namespace lite
+}  // namespace paddle
diff --git a/lite/core/CMakeLists.txt b/lite/core/CMakeLists.txt
new file mode 100644
index 00000000000..cc80637dd44
--- /dev/null
+++ b/lite/core/CMakeLists.txt
@@ -0,0 +1,75 @@
+if (WITH_TESTING)
+  lite_cc_library(lite_gtest_main SRCS lite_gtest_main.cc DEPS gtest gflags)
+endif()
+lite_cc_library(target_wrapper SRCS target_wrapper.cc
+  DEPS target_wrapper_host place
+  X86_DEPS target_wrapper_x86
+  CUDA_DEPS target_wrapper_cuda
+  CL_DEPS cl_target_wrapper
+  FPGA_DEPS fpga_target_wrapper)
+
+lite_cc_library(memory SRCS memory.cc DEPS target_wrapper CL_DEPS cl_target_wrapper)
+
+set(tensor_extra_deps "")
+if (LITE_WITH_FPGA)
+  set(tensor_extra_deps lite_tensor_fpga)
+endif()
+lite_cc_library(tensor SRCS tensor.cc DEPS memory ${tensor_extra_deps})
+
+
+if (NOT LITE_ON_TINY_PUBLISH)
+  proto_library(framework_proto SRCS framework.proto)
+endif()
+
+if (LITE_WITH_X86)
+lite_cc_library(variable SRCS variable.cc DEPS tensor)
+lite_cc_library(types SRCS types.cc)
+else()
+lite_cc_library(variable SRCS variable.cc DEPS tensor)
+lite_cc_library(types SRCS types.cc)
+endif()
+lite_cc_library(op_registry SRCS op_registry.cc DEPS kernel)
+lite_cc_library(scope SRCS scope.cc DEPS tensor)
+lite_cc_library(cpu_info SRCS cpu_info.cc DEPS tensor)
+
+if (LITE_WITH_ARM)
+lite_cc_library(context SRCS context.cc DEPS tensor any cpu_info CL_DEPS cl_context gflags NPU_DEPS ${npu_ddk_libs})
+else()
+lite_cc_library(context SRCS context.cc DEPS tensor any cpu_info eigen3 CL_DEPS cl_context gflags)
+endif()
+lite_cc_library(kernel SRCS kernel.cc DEPS context type_system target_wrapper any op_params tensor)
+lite_cc_library(op SRCS op_lite.cc DEPS scope op_registry target_wrapper kernel
+  cpp_op_desc tensor)
+lite_cc_library(type_system SRCS type_system.cc DEPS tensor target_wrapper)
+
+lite_cc_library(program SRCS program.cc
+  DEPS op kernel model_parser ${ops} ${cpp_wrapper}
+  PROFILE_DEPS basic_profiler)
+
+if (NOT LITE_ON_TINY_PUBLISH)
+  lite_cc_library(optimizer SRCS optimizer.cc DEPS mir_pass_manager model_parser program)
+  add_subdirectory(mir)
+  add_subdirectory(profile)
+  add_subdirectory(arena)
+endif()
+
+# For mobile, it is unnecessary to compile the following tests.
+if (LITE_WITH_LIGHT_WEIGHT_FRAMEWORK)
+  return()
+endif()
+
+# lite_cc_library(program_fake_utils SRCS program_fake_utils.cc DEPS mir_ssa_graph
+#                 scope op_registry proto_desc op
+#                 ${ops}
+#                 ${host_kernels}
#                 )
+
+lite_cc_test(test_scope SRCS scope_test.cc DEPS scope)
+lite_cc_test(test_kernel SRCS kernel_test.cc DEPS kernel target_wrapper any)
+lite_cc_test(test_op SRCS op_lite_test.cc DEPS op)
+lite_cc_test(test_tensor SRCS lite_tensor_test.cc DEPS tensor)
+lite_cc_test(test_type_system SRCS type_system_test.cc DEPS type_system utils)
+#lite_cc_test(test_optimizer SRCS optimizer_test.cc DEPS mir_pass_manager program_fake_utils mir_passes optimizer fc_op)
+lite_cc_test(test_types SRCS types_test.cc DEPS types)
+lite_cc_test(test_memory SRCS memory_test.cc DEPS memory)
+lite_cc_test(test_context SRCS context_test.cc DEPS context)
diff --git a/lite/core/arena/CMakeLists.txt b/lite/core/arena/CMakeLists.txt
new file mode 100644
index 00000000000..854d2f41725
--- /dev/null
+++ b/lite/core/arena/CMakeLists.txt
@@ -0,0 +1,10 @@
+# To make sure the test framework is only activated in TESTING mode.
+if(NOT WITH_TESTING)
+  return()
+endif()
+
+lite_cc_library(arena_framework SRCS framework.cc DEPS program)
+
+if(NOT LITE_WITH_OPENCL AND (LITE_WITH_X86 OR LITE_WITH_ARM))
+  lite_cc_test(test_arena_framework SRCS framework_test.cc DEPS arena_framework ${x86_kernels} ${fpga_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
+endif()
diff --git a/lite/core/arena/framework.cc b/lite/core/arena/framework.cc
new file mode 100644
index 00000000000..c59c078787b
--- /dev/null
+++ b/lite/core/arena/framework.cc
@@ -0,0 +1,70 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+ +#include "lite/core/arena/framework.h" +#include "lite/core/context.h" + +namespace paddle { +namespace lite { +namespace arena { + +void TestCase::CreateInstruction() { + auto op = LiteOpRegistry::Global().Create(op_desc().Type()); + CHECK(op) << "no op for " << op_desc().Type(); + op->Attach(*op_desc_, inst_scope_); + auto kernels = op->CreateKernels({place_}); + // filter out the target kernel + CHECK(!kernels.empty()) << "No kernel found for place " + << place_.DebugString(); + auto it = std::remove_if( + kernels.begin(), kernels.end(), [&](std::unique_ptr& k) { + return k->alias() == alias_; + }); + CHECK(it != kernels.end()) << "failed to create the kernel in " + << place_.DebugString() + << " with alias: " << alias_; + // prepare context + (*it)->SetContext(std::move(ctx_)); + instruction_.reset(new Instruction(op, std::move(*it))); +} + +void TestCase::PrepareInputsForInstruction() { + for (auto& arg : op_desc().InputArgumentNames()) { + for (auto& var : op_desc().Input(arg)) { + std::string kernel_key = instruction_->kernel()->key_with_alias(); + const auto* param_type = ParamTypeRegistry::Global().RetrieveInArgument( + place_, kernel_key, arg); + + const auto* inst_type = Type::GetTensorTy(TARGET(kHost)); + CHECK(scope_->FindVar(var)); + const auto* shared_tensor = scope_->FindTensor((var)); + if (!TargetCompatibleTo(*inst_type, *param_type->type)) { + /// Create a tensor in the instruction's scope, alloc memory and then + /// copy data there. + auto* target_tensor = inst_scope_->NewTensor(var); + CHECK(!shared_tensor->dims().empty()) << "shared_tensor is empty yet"; + target_tensor->Resize(shared_tensor->dims()); + TargetCopy(param_type->type->target(), + target_tensor->mutable_data(param_type->type->target(), + shared_tensor->memory_size()), + shared_tensor->raw_data(), + shared_tensor->memory_size()); + } + } + } +} + +} // namespace arena +} // namespace lite +} // namespace paddle diff --git a/lite/core/arena/framework.h b/lite/core/arena/framework.h new file mode 100644 index 00000000000..48a8571a199 --- /dev/null +++ b/lite/core/arena/framework.h @@ -0,0 +1,258 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include +#include +#include +#include // NOLINT +#include +#include +#include +#include +#include +#include "lite/core/op_registry.h" +#include "lite/core/program.h" +#include "lite/core/scope.h" +#include "lite/core/types.h" +#include "lite/model_parser/cpp/op_desc.h" + +namespace paddle { +namespace lite { +namespace arena { + +/* + * Init data and prepare the op. 
+ */
+class TestCase {
+ public:
+  explicit TestCase(const Place& place, const std::string& alias)
+      : place_(place), scope_(new Scope), alias_(alias) {
+    ctx_ = ContextScheduler::Global().NewContext(place_.target);
+  }
+
+  void Prepare() {
+    PrepareScopes();
+    PrepareData();
+    op_desc_.reset(new cpp::OpDesc);
+    PrepareOpDesc(op_desc_.get());
+
+    PrepareOutputsForInstruction();
+    CreateInstruction();
+    PrepareInputsForInstruction();
+  }
+
+  /// Run the target instruction, that is, run the tested operator.
+  void RunInstruction() { instruction_->Run(); }
+
+  KernelContext* context() { return ctx_.get(); }
+
+  /// The baseline should be implemented, which acts similarly to an operator:
+  /// it takes several tensors as input and outputs several tensors.
+  virtual void RunBaseline(Scope* scope) = 0;
+
+  /// Check the precision of the output tensors. It will compare the same
+  /// tensor in two scopes: one from the instruction execution and the other
+  /// from the baseline.
+  template <typename T>
+  bool CheckPrecision(const std::string& var_name, float abs_error);
+
+  const cpp::OpDesc& op_desc() { return *op_desc_; }
+
+  // Check whether the output tensor is consistent with the output definition
+  // in the kernel registry.
+  void CheckKernelConsistWithDefinition() {}
+
+  Scope& scope() { return *scope_; }
+
+  Scope* baseline_scope() { return base_scope_; }
+  Scope* inst_scope() { return inst_scope_; }
+
+ protected:
+  // Prepare inputs in scope() for the tester.
+  virtual void PrepareData() = 0;
+
+  /// Prepare a tensor on the host. The tensors will be created in scope_.
+  /// Need to specify the targets other than X86 or ARM.
+  template <typename T>
+  void SetCommonTensor(const std::string& var_name,
+                       const DDim& ddim,
+                       const T* data,
+                       const LoD& lod = {}) {
+    auto* tensor = scope_->NewTensor(var_name);
+    tensor->Resize(ddim);
+    auto* d = tensor->mutable_data<T>();
+    memcpy(d, data, ddim.production() * sizeof(T));
+
+    // set lod
+    if (!lod.empty()) *tensor->mutable_lod() = lod;
+  }
+
+  // Prepare the op description for the operator under test.
+  virtual void PrepareOpDesc(cpp::OpDesc* op_desc) = 0;
+
+ public:
+  const Instruction& instruction() { return *instruction_; }
+
+ private:
+  std::unique_ptr<KernelContext> ctx_;
+  void CreateInstruction();
+
+  void PrepareScopes() {
+    inst_scope_ = &scope_->NewScope();
+    base_scope_ = &scope_->NewScope();
+  }
+
+  // Check shape
+  // TODO(Superjomn) Move this method to utils or DDim?
+  bool ShapeEquals(const DDim& a, const DDim& b) {
+    if (a.size() != b.size()) return false;
+    for (int i = 0; i < a.size(); i++) {
+      if (a[i] != b[i]) return false;
+    }
+    return true;
+  }
+
+  /// Copy the input tensors to the target devices needed by the instruction.
+  void PrepareInputsForInstruction();
+
+  // Create output tensors and variables.
+  void PrepareOutputsForInstruction() {
+    for (auto x : op_desc().output_vars()) {
+      inst_scope_->NewTensor(x);
+      base_scope_->NewTensor(x);
+    }
+  }
+
+ private:
+  std::shared_ptr<Scope> scope_;
+  // The workspace for the Instruction.
+  Scope* inst_scope_{};
+  // The workspace for the baseline implementation.
+  Scope* base_scope_{};
+  std::unique_ptr<cpp::OpDesc> op_desc_;
+  std::unique_ptr<Instruction> instruction_;
+  Place place_;
+  std::string alias_;
+};
+
+class Arena {
+  float abs_error_{};
+
+ public:
+  Arena(std::unique_ptr<TestCase>&& tester,
+        const Place& place,
+        float abs_error = 1e-5)
+      : tester_(std::move(tester)), place_(place), abs_error_(abs_error) {
+    tester_->Prepare();
+  }
+
+  bool TestPrecision() {
+    tester_->RunBaseline(tester_->baseline_scope());
+    tester_->RunInstruction();
+
+    bool success = true;
+    for (auto& out : tester_->op_desc().OutputArgumentNames()) {
+      for (auto& var : tester_->op_desc().Output(out)) {
+        success = success && CompareTensor(out, var);
+      }
+    }
+    LOG(INFO) << "done";
+    return success;
+  }
+
+  void TestPerformance(int times = 100) {
+    auto timer = std::chrono::high_resolution_clock::now();
+    for (int i = 0; i < times; i++) {
+      tester_->RunInstruction();
+    }
+    auto duration = std::chrono::duration_cast<std::chrono::milliseconds>(
+        std::chrono::high_resolution_clock::now() - timer);
+    LOG(INFO) << "average duration: " << duration.count() * 1.0 / times
+              << " ms";
+  }
+
+ private:
+  // input_name: X
+  bool CompareTensor(const std::string& arg_name, const std::string& var_name) {
+    // get tensor type.
+    const Type* type =
+        tester_->instruction().kernel()->GetOutputDeclType(arg_name);
+
+    switch (type->precision()) {
+      case PRECISION(kFloat):
+        return tester_->CheckPrecision<float>(var_name, abs_error_);
+      case PRECISION(kInt8):
+        return tester_->CheckPrecision<int8_t>(var_name, abs_error_);
+      case PRECISION(kInt32):
+        return tester_->CheckPrecision<int32_t>(var_name, abs_error_);
+      case PRECISION(kBool):
+        return tester_->CheckPrecision<bool>(var_name, abs_error_);
+
+      default:
+        LOG(FATAL) << "not support type " << PrecisionToStr(type->precision());
+    }
+  }
+
+ private:
+  std::unique_ptr<TestCase> tester_;
+  Place place_;
+};
+
+template <typename T>
+bool TestCase::CheckPrecision(const std::string& var_name, float abs_error) {
+  auto a_tensor = inst_scope_->FindTensor(var_name);
+  auto b_tensor = base_scope_->FindTensor(var_name);
+  CHECK(a_tensor);
+  CHECK(b_tensor);
+
+  CHECK(ShapeEquals(a_tensor->dims(), b_tensor->dims()));
+
+  CHECK(a_tensor->lod() == b_tensor->lod()) << "lod not match";
+
+  // The baseline should output in host devices.
+  CHECK(b_tensor->target() == TARGET(kHost) ||
+        b_tensor->target() == TARGET(kX86) ||
+        b_tensor->target() == TARGET(kARM));
+
+  const T* a_data{};
+  switch (a_tensor->target()) {
+    case TARGET(kX86):
+    case TARGET(kHost):
+    case TARGET(kARM):
+      a_data = static_cast<const T*>(a_tensor->raw_data());
+      break;
+
+    default:
+      // Before comparing, the data would need to be copied from the `target`
+      // device to the host.
+      LOG(FATAL) << "Not supported";
+  }
+
+  CHECK(a_data);
+
+  const T* b_data = static_cast<const T*>(b_tensor->raw_data());
+
+  bool success = true;
+  for (int i = 0; i < a_tensor->dims().production(); i++) {
+    EXPECT_NEAR(a_data[i], b_data[i], abs_error);
+    if (fabsf(a_data[i] - b_data[i]) > abs_error) {
+      success = false;
+    }
+  }
+  return success;
+}
+
+}  // namespace arena
+}  // namespace lite
+}  // namespace paddle
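Putting the pieces together, a test binary drives a TestCase subclass through Arena roughly like this (MyOpTester is a hypothetical subclass; the real pattern appears in the test below):

    Place place(TARGET(kARM));
    std::unique_ptr<arena::TestCase> tester(new MyOpTester(place, "def"));
    arena::Arena arena(std::move(tester), place, /*abs_error=*/1e-4f);
    arena.TestPrecision();      // run kernel + baseline, compare every output
    arena.TestPerformance(50);  // average latency over 50 runs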
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/core/arena/framework.h" +#include +#include "lite/api/paddle_use_kernels.h" +#include "lite/api/paddle_use_ops.h" + +namespace paddle { +namespace lite { + +class ScaleComputeTester : public arena::TestCase { + // common attributes for this op. + std::string input_ = "x"; + std::string output_ = "out"; + float scale_ = 1.2f; + float bias_ = 0.f; + DDim dims_{{3, 2, 10}}; + + public: + explicit ScaleComputeTester(const Place& place, const std::string& alias) + : TestCase(place, alias) {} + + void RunBaseline(Scope* scope) override { + auto* out = scope->NewTensor(output_); + CHECK(out); + out->Resize(dims_); + auto* out_data = out->mutable_data(); + + auto* x = scope->FindTensor(input_); + const auto* x_data = x->data(); + + for (int i = 0; i < dims_.production(); i++) { + out_data[i] = x_data[i] * scale_ + bias_; + } + } + + void PrepareOpDesc(cpp::OpDesc* op_desc) { + op_desc->SetType("scale"); + op_desc->SetInput("X", {input_}); + op_desc->SetOutput("Out", {output_}); + op_desc->SetAttr("scale", scale_); + op_desc->SetAttr("bias", bias_); + op_desc->SetAttr("bias_after_scale", false); + } + + void PrepareData() override { + std::vector data(dims_.production()); + + for (int i = 0; i < dims_.production(); i++) { + data[i] = i * 1.1; + } + + SetCommonTensor(input_, dims_, data.data()); + } +}; + +TEST(scale, basic) { +#ifdef LITE_WITH_X86 + Place place(TARGET(kX86)); +#endif +#ifdef LITE_WITH_ARM + Place place(TARGET(kARM)); +#endif + std::unique_ptr tester(new ScaleComputeTester(place, "def")); + arena::Arena arena(std::move(tester), place); + + arena.TestPrecision(); +} + +} // namespace lite +} // namespace paddle diff --git a/lite/core/context.cc b/lite/core/context.cc new file mode 100644 index 00000000000..948aac0c794 --- /dev/null +++ b/lite/core/context.cc @@ -0,0 +1,23 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/core/context.h" + +#ifdef LITE_WITH_OPENCL +DEFINE_string(cl_path, "/data/local/tmp/opencl", "The OpenCL kernels path."); +#endif + +namespace paddle { +namespace lite {} // namespace lite +} // namespace paddle diff --git a/lite/core/context.h b/lite/core/context.h new file mode 100644 index 00000000000..f36744dc00f --- /dev/null +++ b/lite/core/context.h @@ -0,0 +1,363 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "lite/utils/any.h" +#ifdef LITE_WITH_CUDA +#include "lite/cuda/blas.h" +#include "lite/cuda/cuda_utils.h" +#endif +#ifdef LITE_WITH_OPENCL +#include +#include +#include "lite/opencl/cl_context.h" +#include "lite/opencl/cl_runtime.h" +#endif +#ifdef LITE_WITH_NPU +#include "lite/npu/npu_helper.h" +#endif + +#include +#include +#include +#include +#include +#include +#include "lite/core/cpu_info.h" +#include "lite/core/target_wrapper.h" +#include "lite/core/tensor.h" +#include "lite/utils/all.h" + +#ifdef LITE_WITH_OPENCL +DECLARE_string(cl_path); +#endif + +namespace paddle { +namespace lite { + +template +class Context; + +using HostContext = Context; +using X86Context = Context; +using CUDAContext = Context; +using ARMContext = Context; +using NPUContext = Context; +using OpenCLContext = Context; +using FPGAContext = Context; + +template <> +class Context { + public: + // NOTE: InitOnce should only be used by ContextScheduler + void InitOnce() {} + + void CopySharedTo(HostContext* ctx) {} + + std::string name() const { return "HostContext"; } +}; + +#ifdef LITE_WITH_NPU +template <> +class Context { + public: + Context() {} + explicit Context(const NPUContext& ctx); + // NOTE: InitOnce should only be used by ContextScheduler + void InitOnce() {} + void CopySharedTo(NPUContext* ctx) {} + + NPUContext& operator=(const NPUContext& ctx) {} + std::string name() const { return "NPUContext"; } + hiai::AiModelMngerClient* client(const std::string& model_name) const { + return npu::DeviceInfo::Global().client(model_name); + } +}; +#endif + +#ifdef LITE_WITH_ARM +template <> +class Context { + public: + Context() {} + explicit Context(const ARMContext& ctx); + + ARMContext& operator=(const ARMContext& ctx) {} + + // NOTE: InitOnce should only be used by ContextScheduler + void InitOnce() { DeviceInfo::Init(); } + + void CopySharedTo(ARMContext* ctx) {} + + void SetRunMode(PowerMode mode, int threads) { + return DeviceInfo::Global().SetRunMode(mode, threads); + } + void SetCache(int l1size, int l2size, int l3size) { + return DeviceInfo::Global().SetCache(l1size, l2size, l3size); + } + void SetArch(ARMArch arch) { return DeviceInfo::Global().SetArch(arch); } + + PowerMode mode() const { return DeviceInfo::Global().mode(); } + int threads() const { return DeviceInfo::Global().threads(); } + ARMArch arch() const { return DeviceInfo::Global().arch(); } + int l1_cache_size() const { return DeviceInfo::Global().l1_cache_size(); } + int l2_cache_size() const { return DeviceInfo::Global().l2_cache_size(); } + int l3_cache_size() const { return DeviceInfo::Global().l3_cache_size(); } + int llc_size() const { return DeviceInfo::Global().llc_size(); } + bool has_dot() const { return DeviceInfo::Global().has_dot(); } + bool has_fp16() const { return DeviceInfo::Global().has_fp16(); } + + template + T* workspace_data() { + return DeviceInfo::Global().workspace_data(); + } + + bool ExtendWorkspace(size_t size) { + return DeviceInfo::Global().ExtendWorkspace(size); + } + + std::string name() const { return "ARMContext"; } +}; +#endif + +#ifdef LITE_WITH_FPGA +// 
TODO(tianxiaogang): add needed implementation to context +template <> +class Context { + public: + Context() {} + void InitOnce() {} + + FPGAContext& operator=(const FPGAContext& ctx) {} + + void CopySharedTo(FPGAContext* ctx) {} + + std::string name() const { return "FPGAContext"; } +}; +#endif + +#ifdef LITE_WITH_CUDA +// Only works with CUDA kernels. +template <> +class Context { + public: + // NOTE: InitOnce should only be used by ContextScheduler + void InitOnce() { + cublas_fp32_ = std::make_shared>(); + } + + void CopySharedTo(CUDAContext* ctx) { + CHECK(ctx); + CHECK(cublas_fp32_) << "cublas_fp32 should be set first"; + ctx->cublas_fp32_ = cublas_fp32_; + } + + const cudaStream_t exec_stream() { return exec_stream_; } + void SetExecStream(cudaStream_t stream) { exec_stream_ = stream; } + + const cudaStream_t io_stream() { return io_stream_; } + void SetIoStream(cudaStream_t stream) { io_stream_ = stream; } + + std::shared_ptr> cublas_fp32() { return cublas_fp32_; } + void SetCuBlasFP32(std::shared_ptr> cublas_fp32) { + cublas_fp32_ = cublas_fp32; + } + + const std::vector& input_events() { return input_events_; } + void SetInputEvents(const std::vector& input_events) { + input_events_.clear(); + input_events_.assign(input_events.begin(), input_events.end()); + } + + const std::vector& output_events() { return output_events_; } + void SetOutputEvents(const std::vector& output_events) { + output_events_.clear(); + output_events_.assign(output_events.begin(), output_events.end()); + } + + std::string name() const { return "CUDAContext"; } + + private: + // overall information + cudaStream_t exec_stream_; + cudaStream_t io_stream_; + + // not thread-safe, should allocate for each thread. + std::shared_ptr> cublas_fp32_; + + // kernel information + std::vector input_events_; + std::vector output_events_; +}; +#endif + +#ifdef LITE_WITH_X86 +template <> +class Context { + public: + Context() {} + + Context(Context&& ctx) {} + + // NOTE: InitOnce should only be used by ContextScheduler + void InitOnce() {} + + void CopySharedTo(X86Context* ctx) {} + + std::string name() const { return "X86Context"; } + + private: + // overall information + // + // kernel information +}; +#endif + +#ifdef LITE_WITH_OPENCL +template <> +class Context { + std::shared_ptr cl_context_; + using WaitListType = + std::unordered_map(nullptr)), + std::shared_ptr>; + std::shared_ptr cl_wait_list_; + + public: + CLContext* cl_context() { return cl_context_.get(); } + WaitListType* cl_wait_list() { return cl_wait_list_.get(); } + + void InitOnce() { + // Init cl runtime. + CHECK(CLRuntime::Global()->IsInitSuccess()) << "OpenCL runtime init failed"; + CLRuntime::Global()->set_cl_path(FLAGS_cl_path); + + cl_context_ = std::make_shared(); + cl_wait_list_ = std::make_shared(); + } + + void CopySharedTo(OpenCLContext* ctx) { + ctx->cl_context_ = cl_context_; + ctx->cl_wait_list_ = cl_wait_list_; + } +}; +#endif + +// Context for running a kernel. +// Holds the necessary resource and information. +class KernelContext { + public: + template + ContextT& As() { + if (!ctx_.valid()) { + ctx_.set(); + } + return *ctx_.get_mutable(); + } + + private: + Any ctx_; +}; + +// The ContextScheduler helps to assign different context for each kernel. 
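+// A usage sketch (illustrative; this mirrors what arena::TestCase does in
+// lite/core/arena/framework.h):
+//   std::unique_ptr<KernelContext> ctx =
+//       ContextScheduler::Global().NewContext(TARGET(kARM));
+//   ctx->As<ARMContext>().SetRunMode(LITE_POWER_HIGH, 2);
+// NewContext returns a fresh KernelContext whose per-target shared state was
+// filled in by the matching Context::CopySharedTo.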
+class ContextScheduler { + public: + static ContextScheduler& Global() { + static auto* x = new ContextScheduler; + return *x; + } + + std::unique_ptr NewContext(TargetType target) { + std::unique_ptr ctx(new KernelContext); + switch (target) { + case TARGET(kHost): + kernel_contexts_[TargetType::kHost].As().CopySharedTo( + &ctx->As()); + break; +#ifdef LITE_WITH_X86 + case TARGET(kX86): + kernel_contexts_[TargetType::kX86].As().CopySharedTo( + &ctx->As()); + break; +#endif +#ifdef LITE_WITH_CUDA + case TARGET(kCUDA): + kernel_contexts_[TargetType::kCUDA].As().CopySharedTo( + &ctx->As()); + break; +#endif +#ifdef LITE_WITH_ARM + case TARGET(kARM): + kernel_contexts_[TargetType::kARM].As().CopySharedTo( + &ctx->As()); + break; +#endif +#ifdef LITE_WITH_NPU + case TARGET(kNPU): + kernel_contexts_[TargetType::kNPU].As().CopySharedTo( + &ctx->As()); + break; +#endif +#ifdef LITE_WITH_OPENCL + case TARGET(kOpenCL): + kernel_contexts_[TargetType::kOpenCL].As().CopySharedTo( + &ctx->As()); + break; +#endif +#ifdef LITE_WITH_FPGA + case TARGET(kFPGA): + kernel_contexts_[TargetType::kFPGA].As().CopySharedTo( + &ctx->As()); + break; +#endif + default: + LOG(FATAL) << "unsupported target " << TargetToStr(target); + } + return ctx; + } + + private: + template + void InitContext() { + kernel_contexts_[Type].As().InitOnce(); + } + + ContextScheduler() { + InitContext(); +#ifdef LITE_WITH_X86 + InitContext(); +#endif +#ifdef LITE_WITH_CUDA + InitContext(); +#endif +#ifdef LITE_WITH_ARM + InitContext(); +#endif +#ifdef LITE_WITH_OPENCL + InitContext(); +#endif +#ifdef LITE_WITH_FPGA + InitContext(); +#endif +#ifdef LITE_WITH_NPU + InitContext(); +#endif + } + + private: + std::map kernel_contexts_; +}; + +} // namespace lite +} // namespace paddle diff --git a/lite/core/context_test.cc b/lite/core/context_test.cc new file mode 100644 index 00000000000..80b642bfad1 --- /dev/null +++ b/lite/core/context_test.cc @@ -0,0 +1,51 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
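+// The disabled test below appears to predate the trimmed X86Context above
+// (it pokes x86_device_context(), which no longer exists). A minimal live
+// check could look like this sketch, assuming an ARM build:
+//   auto ctx = ContextScheduler::Global().NewContext(TargetType::kARM);
+//   ASSERT_EQ(ctx->As<ARMContext>().name(), "ARMContext");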
+ +#include "lite/core/context.h" +#include + +namespace paddle { +namespace lite { + +// #ifdef LITE_WITH_X86 +// TEST(ContextScheduler, NewContext) { +// auto ctx1_p = ContextScheduler::Global().NewContext(TargetType::kX86); +// auto ctx2_p = ContextScheduler::Global().NewContext(TargetType::kX86); +// ASSERT_FALSE(ctx1_p.get() == ctx2_p.get()); + +// auto& ctx1 = ctx1_p->As(); +// auto& ctx2 = ctx2_p->As(); + +// ASSERT_EQ(ctx1.name(), "X86Context"); +// ASSERT_EQ(ctx2.name(), "X86Context"); + +// ASSERT_FALSE(ctx1.x86_device_context() == nullptr || +// ctx2.x86_device_context() == nullptr); +// ASSERT_FALSE(ctx1.x86_execution_context() == nullptr || +// ctx2.x86_execution_context() == nullptr); + +// ASSERT_TRUE(ctx1.x86_device_context() != ctx2.x86_device_context()); +// ASSERT_TRUE(ctx1.x86_execution_context() != ctx2.x86_execution_context()); + +// using device_ctx_t = ::paddle::platform::CPUDeviceContext; +// using exec_ctx_t = ::paddle::framework::ExecutionContext; +// auto* device_ctx = new device_ctx_t; +// ctx1.SetX86DeviceContext(std::unique_ptr(device_ctx)); +// ctx1.SetX86ExecutionContext( +// std::unique_ptr(new exec_ctx_t(*device_ctx))); +// } +// #endif + +} // namespace lite +} // namespace paddle diff --git a/lite/core/cpu_info.cc b/lite/core/cpu_info.cc new file mode 100644 index 00000000000..e99adf8b2a0 --- /dev/null +++ b/lite/core/cpu_info.cc @@ -0,0 +1,1073 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
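+// Note (illustrative): on Linux the routines below identify cores purely
+// from sysfs/procfs text. For example, a /proc/cpuinfo entry such as
+//   CPU part : 0xd0a
+// is mapped to Cortex-A75 (kA75) by get_cpu_arch(), and per-core frequencies
+// are read from /sys/devices/system/cpu/cpu<N>/cpufreq/.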
+ +#ifdef LITE_WITH_LINUX +#include +#include +#endif +#if __APPLE__ +#include "TargetConditionals.h" +#if LITE_WITH_IPHONE +#include +#include +#include +#endif // LITE_WITH_IPHONE +#endif // __APPLE__ + +#ifdef ARM_WITH_OMP +#include +#endif + +#include +#include +#include "lite/core/cpu_info.h" + +namespace paddle { +namespace lite { + +#ifdef LITE_WITH_ARM + +#ifdef TARGET_IOS +const int DEFAULT_L1_CACHE_SIZE = 64 * 1024; +const int DEFAULT_L2_CACHE_SIZE = 2048 * 1024; +const int DEFAULT_L3_CACHE_SIZE = 0; +#else +const int DEFAULT_L1_CACHE_SIZE = 32 * 1024; +const int DEFAULT_L2_CACHE_SIZE = 512 * 1024; +const int DEFAULT_L3_CACHE_SIZE = 0; +#endif + +int get_cpu_num() { +#ifdef LITE_WITH_LINUX + // get cpu count from /sys/devices/system/cpu/cpunum/uevent + int max_cpu_num = 20; + int cpu_num = 0; + for (int i = 0; i < max_cpu_num; ++i) { + char path[256]; + snprintf(path, sizeof(path), "/sys/devices/system/cpu/cpu%d/uevent", i); + FILE* fp = fopen(path, "rb"); + if (!fp) { + break; + } + cpu_num++; + fclose(fp); + } + if (cpu_num < 1) { + cpu_num = 1; + } + return cpu_num; +#elif defined(TARGET_IOS) + int cpu_num = 0; + size_t len = sizeof(cpu_num); + sysctlbyname("hw.ncpu", &cpu_num, &len, NULL, 0); + if (cpu_num < 1) { + cpu_num = 1; + } + return cpu_num; +#else + return 1; +#endif +} + +size_t get_mem_size() { +#ifdef LITE_WITH_LINUX + // get cpu count from /proc/cpuinfo + FILE* fp = fopen("/proc/meminfo", "rb"); + if (!fp) { + return 1; + } + size_t memsize = 0; + char line[1024]; + while (!feof(fp)) { + char* s = fgets(line, 1024, fp); + if (!s) { + break; + } + sscanf(s, "MemTotal: %d kB", &memsize); + } + fclose(fp); + return memsize; +#elif defined(TARGET_IOS) + // to be implemented + printf("not implemented\n"); +#endif + return 0; +} + +void get_cpu_arch(std::vector* archs, const int cpu_num) { + archs->resize(cpu_num); + for (int i = 0; i < cpu_num; ++i) { + archs->at(i) = kARMArch_UNKOWN; + } +#ifdef LITE_WITH_LINUX + //! 
get CPU ARCH + FILE* fp = fopen("/proc/cpuinfo", "rb"); + if (!fp) { + return; + } + int cpu_idx = 0; + char line[1024]; + while (!feof(fp)) { + char* s = fgets(line, 1024, fp); + if (!s) { + break; + } + if (strstr(line, "part") != NULL) { + ARMArch arch_type = kARMArch_UNKOWN; + int arch_id = 0; + sscanf(s, "CPU part\t: %x", &arch_id); + switch (arch_id) { + case 0xd03: + arch_type = kA53; + break; + case 0xd05: + arch_type = kA55; + break; + case 0xd07: + arch_type = kA57; + break; + case 0xd08: + arch_type = kA72; + break; + case 0xd09: + arch_type = kA73; + break; + case 0xd0a: + arch_type = kA75; + break; + case 0xd40: + arch_type = kA76; + break; + case 0x804: + // 855 + arch_type = kA76; + break; + case 0x805: + // 855 + arch_type = kA55; + break; + case 0x802: + // 845 + arch_type = kA75; + break; + case 0x803: + // 845 + arch_type = kA55; + break; + case 0x801: + // 835 + arch_type = kA73; + break; + case 0x800: + // 835 + arch_type = kA73; + break; + case 0x205: + // 820 + arch_type = kA72; + break; + default: + LOG(ERROR) << "Unknow cpu arch: " << arch_id; + } + archs->at(cpu_idx) = arch_type; + cpu_idx++; + } + } + fclose(fp); + for (; cpu_idx > 0 && cpu_idx < cpu_num; ++cpu_idx) { + archs->at(cpu_idx) = archs->at(cpu_idx - 1); + } +#elif defined(TARGET_IOS) + for (int i = 0; i < cpu_num; ++i) { + archs->at(i) = APPLE; + } +#endif +} + +#ifdef LITE_WITH_LINUX + +std::string get_cpu_name() { + FILE* fp = fopen("/proc/cpuinfo", "rb"); + if (!fp) { + return ""; + } + char line[1024]; + while (!feof(fp)) { + char* s = fgets(line, 1024, fp); + if (!s) { + break; + } + if (strstr(line, "Hardware") != NULL) { + fclose(fp); + return std::string(line); + } + } + fclose(fp); + return ""; +} + +void get_cpu_max_min_freq(int cpu_id, int* max_freq, int* min_freq) { + *max_freq = 0; + *min_freq = 0; + // first try, for all possible cpu + char path[256]; + snprintf(path, + sizeof(path), + "/sys/devices/system/cpu/cpufreq/stats/cpu%d/time_in_state", + cpu_id); + FILE* fp = fopen(path, "rb"); + if (!fp) { + // second try, for online cpu + snprintf(path, + sizeof(path), + "/sys/devices/system/cpu/cpu%d/cpufreq/stats/time_in_state", + cpu_id); + fp = fopen(path, "rb"); + if (!fp) { + // third try, for online cpu + // get max_freq + snprintf(path, + sizeof(path), + "/sys/devices/system/cpu/cpu%d/cpufreq/cpuinfo_max_freq", + cpu_id); + fp = fopen(path, "rb"); + if (!fp) { + return; + } + fscanf(fp, "%d", max_freq); + fclose(fp); + // get min_freq + snprintf(path, + sizeof(path), + "/sys/devices/system/cpu/cpu%d/cpufreq/cpuinfo_min_freq", + cpu_id); + fp = fopen(path, "rb"); + if (!fp) { + return; + } + fscanf(fp, "%d", min_freq); + fclose(fp); + return; + } + } + *min_freq = std::numeric_limits::max(); + while (!feof(fp)) { + int freq = 0; + int nscan = fscanf(fp, "%d %*d", &freq); + if (nscan != 1) { + break; + } + if (freq > *max_freq) { + *max_freq = freq; + } + if (freq < *min_freq) { + *min_freq = freq; + } + } + fclose(fp); +} + +void sort_cpuid_by_max_freq(const std::vector& max_freqs, + std::vector* cpu_ids, + std::vector* cluster_ids) { + int cpu_num = max_freqs.size(); + if (cpu_num == 0) { + return; + } + cpu_ids->resize(cpu_num); + cluster_ids->resize(cpu_num); + for (int i = 0; i < cpu_num; i++) { + cpu_ids->at(i) = i; + } + // sort cpuid as big core first + // simple bubble sort + for (int i = 0; i < cpu_num; i++) { + for (int j = i + 1; j < cpu_num; j++) { + if (max_freqs[i] < max_freqs[j]) { + // swap + int tmp = cpu_ids->at(i); + cpu_ids->at(i) = cpu_ids->at(j); + cpu_ids->at(j) = 
tmp; + } + } + } + // SMP + int mid_max_freq = + (max_freqs[cpu_ids->at(0)] + max_freqs[cpu_ids->at(cpu_num - 1)]) / 2; + + for (int i = 0; i < cpu_num; i++) { + cpu_ids->at(i) = i; + if (max_freqs[i] >= mid_max_freq) { + cluster_ids->at(i) = 0; + } else { + cluster_ids->at(i) = 1; + } + } +} + +void get_cpu_cache_size(int cpu_id, + int* l1_cache_size, + int* l2_cache_size, + int* l3_cache_size) { + int max_cache_idx_num = 10; + *l1_cache_size = DEFAULT_L1_CACHE_SIZE; + *l2_cache_size = DEFAULT_L2_CACHE_SIZE; + *l3_cache_size = DEFAULT_L3_CACHE_SIZE; + for (int i = 0; i < max_cache_idx_num; i++) { + char path[256]; + snprintf(path, + sizeof(path), + "/sys/devices/system/cpu/cpu%d/cache/index%d/level", + cpu_id, + i); + FILE* fp = fopen(path, "rb"); + if (fp) { + int level = -1; + fscanf(fp, "%d", &level); + fclose(fp); + snprintf(path, + sizeof(path), + "/sys/devices/system/cpu/cpu%d/cache/index%d/size", + cpu_id, + i); + fp = fopen(path, "rb"); + if (fp) { + int size = -1; + fscanf(fp, "%d", &size); + fclose(fp); + if (size >= 0) { + if (level == 1) { + *l1_cache_size = size * 1024; + } else if (level == 2) { + *l2_cache_size = size * 1024; + } else if (level == 3) { + *l3_cache_size = size * 1024; + } + } + } + } + } +} + +bool check_cpu_online(const std::vector& cpu_ids) { + if (cpu_ids.size() == 0) { + return false; + } + char path[256]; + bool all_online = true; + for (int i = 0; i < cpu_ids.size(); ++i) { + snprintf( + path, sizeof(path), "/sys/devices/system/cpu/cpu%d/online", cpu_ids[i]); + FILE* fp = fopen(path, "rb"); + int is_online = 0; + if (fp) { + fscanf(fp, "%d", &is_online); + fclose(fp); + } else { + LOG(ERROR) << "Failed to query the online statue of CPU id:" + << cpu_ids[i]; + } + if (is_online == 0) { + all_online = false; + LOG(ERROR) << "CPU id:" << cpu_ids[i] << " is offine"; + } + } + return all_online; +} + +int set_sched_affinity(const std::vector& cpu_ids) { +// #define CPU_SETSIZE 1024 +// #define __NCPUBITS (8 * sizeof (unsigned long)) +// typedef struct +// { +// unsigned long __bits[CPU_SETSIZE / __NCPUBITS]; +// } cpu_set_t; + +// set affinity for thread +#ifdef __GLIBC__ + pid_t pid = syscall(SYS_gettid); +#else + pid_t pid = gettid(); +#endif + cpu_set_t mask; + CPU_ZERO(&mask); + for (int i = 0; i < cpu_ids.size(); ++i) { + CPU_SET(cpu_ids[i], &mask); + } + int syscallret = syscall(__NR_sched_setaffinity, pid, sizeof(mask), &mask); + if (syscallret) { + return -1; + } + return 0; +} + +bool bind_threads(const std::vector cpu_ids) { +#ifdef ARM_WITH_OMP + int thread_num = cpu_ids.size(); + omp_set_num_threads(thread_num); + std::vector ssarets; + for (int i = 0; i < thread_num; ++i) { + ssarets.push_back(0); + } +#pragma omp parallel for + for (int i = 0; i < thread_num; i++) { + ssarets[i] = set_sched_affinity(cpu_ids); + } + for (int i = 0; i < thread_num; i++) { + if (ssarets[i] != 0) { + LOG(ERROR) << "Set cpu affinity failed, core id: " << cpu_ids[i]; + return false; + } + } +#else // ARM_WITH_OMP + std::vector first_cpu_id; + first_cpu_id.push_back(cpu_ids[0]); + int ssaret = set_sched_affinity(first_cpu_id); + if (ssaret != 0) { + LOG(ERROR) << "Set cpu affinity failed, core id: " << cpu_ids[0]; + return false; + } +#endif // ARM_WITH_OMP + return true; +} + +#endif // LITE_WITH_LINUX + +void DeviceInfo::SetDotInfo(int argc, ...) 
{ + va_list arg_ptr; + va_start(arg_ptr, argc); + dot_.resize(core_num_); + if (argc == 1) { + bool flag = va_arg(arg_ptr, int) > 0; + for (int i = 0; i < core_num_; ++i) { + dot_[i] = flag; + } + } else { + bool flag_big_core = va_arg(arg_ptr, int) > 0; + bool flag_little_core = va_arg(arg_ptr, int) > 0; + int big_core_num = big_core_ids_.size(); + int little_core_num = little_core_ids_.size(); + for (int i = 0; i < big_core_num; ++i) { + dot_[big_core_ids_[i]] = flag_big_core; + } + for (int i = 0; i < little_core_num; ++i) { + dot_[little_core_ids_[i]] = flag_little_core; + } + } + va_end(arg_ptr); +} + +void DeviceInfo::SetFP16Info(int argc, ...) { + va_list arg_ptr; + va_start(arg_ptr, argc); + fp16_.resize(core_num_); + if (argc == 1) { + bool flag = va_arg(arg_ptr, int) > 0; + for (int i = 0; i < core_num_; ++i) { + fp16_[i] = flag; + } + } else { + bool flag_big_core = va_arg(arg_ptr, int) > 0; + bool flag_little_core = va_arg(arg_ptr, int) > 0; + int big_core_num = big_core_ids_.size(); + int little_core_num = little_core_ids_.size(); + for (int i = 0; i < big_core_num; ++i) { + fp16_[big_core_ids_[i]] = flag_big_core; + } + for (int i = 0; i < little_core_num; ++i) { + fp16_[little_core_ids_[i]] = flag_little_core; + } + } + va_end(arg_ptr); +} + +void DeviceInfo::SetFP32Info(int argc, ...) { + va_list arg_ptr; + va_start(arg_ptr, argc); + fp32_.resize(core_num_); + if (argc == 1) { + bool flag = va_arg(arg_ptr, int) > 0; + for (int i = 0; i < core_num_; ++i) { + fp32_[i] = flag; + } + } else { + bool flag_big_core = va_arg(arg_ptr, int) > 0; + bool flag_little_core = va_arg(arg_ptr, int) > 0; + int big_core_num = big_core_ids_.size(); + int little_core_num = little_core_ids_.size(); + for (int i = 0; i < big_core_num; ++i) { + fp32_[big_core_ids_[i]] = flag_big_core; + } + for (int i = 0; i < little_core_num; ++i) { + fp32_[little_core_ids_[i]] = flag_little_core; + } + } + va_end(arg_ptr); +} + +// cache_id : 0 -> L1, 1 -> L2, 2 -> L3 +void DeviceInfo::SetCacheInfo(int cache_id, int argc, ...) { + va_list arg_ptr; + va_start(arg_ptr, argc); + std::vector* cache; + switch (cache_id) { + case 0: + cache = &L1_cache_; + break; + case 1: + cache = &L2_cache_; + break; + case 2: + cache = &L3_cache_; + break; + default: + break; + } + cache->resize(core_num_); + if (argc == 1) { + int cache_size = va_arg(arg_ptr, int); + for (int i = 0; i < core_num_; ++i) { + (*cache)[i] = cache_size; + } + } else { + int big_core_num = big_core_ids_.size(); + int little_core_num = little_core_ids_.size(); + int big_core_cache_size = va_arg(arg_ptr, int); + int little_core_cache_size = va_arg(arg_ptr, int); + for (int i = 0; i < big_core_num; ++i) { + (*cache)[big_core_ids_[i]] = big_core_cache_size; + } + for (int i = 0; i < little_core_num; ++i) { + (*cache)[little_core_ids_[i]] = little_core_cache_size; + } + } + va_end(arg_ptr); +} + +void DeviceInfo::SetArchInfo(int argc, ...) 
{ + va_list arg_ptr; + va_start(arg_ptr, argc); + archs_.resize(core_num_); + if (argc == 1) { + ARMArch arch = (ARMArch)va_arg(arg_ptr, int); + for (int i = 0; i < core_num_; ++i) { + archs_[i] = arch; + } + } else { + ARMArch big_core_arch = (ARMArch)va_arg(arg_ptr, int); + ARMArch little_core_arch = (ARMArch)va_arg(arg_ptr, int); + int big_core_num = big_core_ids_.size(); + int little_core_num = little_core_ids_.size(); + for (int i = 0; i < big_core_num; ++i) { + archs_[big_core_ids_[i]] = big_core_arch; + } + for (int i = 0; i < little_core_num; ++i) { + archs_[little_core_ids_[i]] = little_core_arch; + } + } + va_end(arg_ptr); +} + +bool DeviceInfo::SetCPUInfoByName() { + /* Snapdragon */ + if (dev_name_.find("SM8150") != std::string::npos) { // 855 + core_num_ = 8; + core_ids_ = {0, 1, 2, 3, 4, 5, 6, 7}; + big_core_ids_ = {4, 5, 6, 7}; + little_core_ids_ = {0, 1, 2, 3}; + cluster_ids_ = {1, 1, 1, 1, 0, 0, 0, 0}; + SetArchInfo(2, kA76, kA55); + SetCacheInfo(0, 2, 64 * 1024, 32 * 1024); + SetCacheInfo(1, 2, 256 * 1024, 128 * 1024); + SetCacheInfo(2, 1, 2048 * 1024); + SetFP16Info(1, 1); + SetDotInfo(1, 1); + return true; + } else if (dev_name_.find("SDM845") != std::string::npos) { // 845 + core_num_ = 8; + core_ids_ = {0, 1, 2, 3, 4, 5, 6, 7}; + big_core_ids_ = {4, 5, 6, 7}; + little_core_ids_ = {0, 1, 2, 3}; + cluster_ids_ = {1, 1, 1, 1, 0, 0, 0, 0}; + SetArchInfo(2, kA75, kA55); + SetCacheInfo(0, 2, 64 * 1024, 32 * 1024); + SetCacheInfo(1, 2, 256 * 1024, 128 * 1024); + SetCacheInfo(2, 1, 2048 * 1024); + SetFP16Info(1, 1); + return true; + } else if (dev_name_.find("SDM710") != std::string::npos) { // 710 + core_num_ = 8; + core_ids_ = {0, 1, 2, 3, 4, 5, 6, 7}; + big_core_ids_ = {6, 7}; + little_core_ids_ = {0, 1, 2, 3, 4, 5}; + cluster_ids_ = {1, 1, 1, 1, 1, 1, 0, 0}; + SetArchInfo(2, kA75, kA55); + SetCacheInfo(0, 2, 64 * 1024, 32 * 1024); + SetCacheInfo(1, 2, 256 * 1024, 128 * 1024); + SetCacheInfo(2, 1, 1024 * 1024); + return true; + } else if (dev_name_.find("MSM8998") != std::string::npos) { // 835 + core_num_ = 8; + core_ids_ = {0, 1, 2, 3, 4, 5, 6, 7}; + big_core_ids_ = {4, 5, 6, 7}; + little_core_ids_ = {0, 1, 2, 3}; + cluster_ids_ = {1, 1, 1, 1, 0, 0, 0, 0}; + SetArchInfo(2, kA73, kA53); + SetCacheInfo(0, 2, 64 * 1024, 32 * 1024); + SetCacheInfo(1, + 2, + 1024 * 1024, + /*real cache size is 2M, while that will get bad performace + on conv3x3s1 or gemm, set to 1M or 512K*/ + 1024 * 1024); + return true; + } else if (dev_name_.find("MSM8996") != std::string::npos) { // 820 + core_num_ = 4; + core_ids_ = {0, 1, 2, 3}; + big_core_ids_ = {2, 3}; + little_core_ids_ = {0, 1}; + cluster_ids_ = {1, 1, 0, 0}; + SetArchInfo(1, kA72); + SetCacheInfo(0, 1, 24 * 1024); + SetCacheInfo(1, 2, 1024 * 1024, 512 * 1024); + return true; + } else if (dev_name_.find("SDM660") != std::string::npos || + dev_name_.find("SDM636") != std::string::npos) { // 660, 636 + core_num_ = 8; + core_ids_ = {0, 1, 2, 3, 4, 5, 6, 7}; + big_core_ids_ = {4, 5, 6, 7}; + little_core_ids_ = {0, 1, 2, 3}; + cluster_ids_ = {1, 1, 1, 1, 0, 0, 0, 0}; + SetArchInfo(1, kA73); + SetCacheInfo(0, 2, 64 * 1024, 32 * 1024); + SetCacheInfo(1, 1, 1024 * 1024); + return true; + } else if (dev_name_.find("MSM8976") != std::string::npos) { // 652,653 + core_num_ = 8; + core_ids_ = {0, 1, 2, 3, 4, 5, 6, 7}; + big_core_ids_ = {4, 5, 6, 7}; + little_core_ids_ = {0, 1, 2, 3}; + cluster_ids_ = {1, 1, 1, 1, 0, 0, 0, 0}; + SetArchInfo(2, kA72, kA53); + SetCacheInfo(0, 1, 32 * 1024); + SetCacheInfo(1, 2, 1024 * 1024, 512 * 1024); + return 
true; + } else if (dev_name_.find("MSM8953") != std::string::npos) { // 625 + core_num_ = 8; + core_ids_ = {0, 1, 2, 3, 4, 5, 6, 7}; + big_core_ids_ = {0, 1, 2, 3, 4, 5, 6, 7}; + little_core_ids_ = {}; + cluster_ids_ = {0, 0, 0, 0, 0, 0, 0, 0}; + SetArchInfo(1, kA53); + SetCacheInfo(0, 1, 32 * 1024); + SetCacheInfo(1, 1, 1024 * 1024); + return true; + } else if (dev_name_.find("MSM8939") != std::string::npos) { // 615 + core_num_ = 8; + core_ids_ = {0, 1, 2, 3, 4, 5, 6, 7}; + big_core_ids_ = {0, 1, 2, 3}; + little_core_ids_ = {4, 5, 6, 7}; + cluster_ids_ = {0, 0, 0, 0, 1, 1, 1, 1}; + SetArchInfo(1, kA53); + SetCacheInfo(0, 1, 32 * 1024); + SetCacheInfo(1, 2, 512 * 1024, 256 * 1024); + return true; + /* MediaTek */ + } else if (dev_name_.find("MT6797") != + std::string::npos) { // X20/X23/X25/X27 + core_num_ = 10; + core_ids_ = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}; + big_core_ids_ = {8, 9}; + little_core_ids_ = {0, 1, 2, 3, 4, 5, 6, 7}; + cluster_ids_ = {1, 1, 1, 1, 1, 1, 1, 1, 0, 0}; + SetArchInfo(2, kA72, kA53); + SetCacheInfo(0, 1, 32 * 1024); + SetCacheInfo(1, 2, 1024 * 1024, 512 * 1024); + return true; + } else if (dev_name_.find("MT6799") != std::string::npos) { // X30 + core_num_ = 10; + core_ids_ = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}; + big_core_ids_ = {8, 9}; + little_core_ids_ = {0, 1, 2, 3, 4, 5, 6, 7}; + cluster_ids_ = {1, 1, 1, 1, 1, 1, 1, 1, 0, 0}; + SetArchInfo(2, kA73, kA53); + return true; + } else if (dev_name_.find("MT6795") != std::string::npos || + dev_name_.find("MT6762") != std::string::npos || + dev_name_.find("MT6755T") != std::string::npos || + dev_name_.find("MT6755S") != std::string::npos || + dev_name_.find("MT6753") != std::string::npos || + dev_name_.find("MT6752") != std::string::npos || + dev_name_.find("MT6750") != std::string::npos) { + // X10, P22, P15/P18, MT6753, MT6752/MT6752M, MT6750 + core_num_ = 8; + core_ids_ = {0, 1, 2, 3, 4, 5, 6, 7}; + big_core_ids_ = {0, 1, 2, 3, 4, 5, 6, 7}; + little_core_ids_ = {}; + cluster_ids_ = {0, 0, 0, 0, 0, 0, 0, 0}; + SetArchInfo(1, kA53); + return true; + } else if (dev_name_.find("MT6758") != std::string::npos || + dev_name_.find("MT6757") != std::string::npos || + dev_name_.find("MT6763") != std::string::npos || + dev_name_.find("MT6755M") != std::string::npos || + dev_name_.find("MT6755") != + std::string::npos) { // P30, P20/P25, P23, P10 + core_num_ = 8; + core_ids_ = {0, 1, 2, 3, 4, 5, 6, 7}; + big_core_ids_ = {4, 5, 6, 7}; + little_core_ids_ = {0, 1, 2, 3}; + cluster_ids_ = {1, 1, 1, 1, 0, 0, 0, 0}; + SetArchInfo(1, kA53); + return true; + } else if (dev_name_.find("MT6771") != std::string::npos) { // P60 + core_num_ = 8; + core_ids_ = {0, 1, 2, 3, 4, 5, 6, 7}; + big_core_ids_ = {4, 5, 6, 7}; + little_core_ids_ = {0, 1, 2, 3}; + cluster_ids_ = {1, 1, 1, 1, 0, 0, 0, 0}; + SetArchInfo(2, kA73, kA53); + return true; + } else if (dev_name_.find("MT6765") != std::string::npos || + dev_name_.find("MT6739") != std::string::npos || + dev_name_.find("MT6738") != std::string::npos || + dev_name_.find("MT6737") != + std::string::npos) { // A22, MT6739, MT6738, MT6767 + core_num_ = 4; + core_ids_ = {0, 1, 2, 3}; + big_core_ids_ = {0, 1, 2, 3}; + little_core_ids_ = {}; + cluster_ids_ = {0, 0, 0, 0}; + SetArchInfo(1, kA53); + return true; + } else if (dev_name_.find("KIRIN980") != std::string::npos) { // Kirin 980 + core_num_ = 8; + core_ids_ = {0, 1, 2, 3, 4, 5, 6, 7}; + big_core_ids_ = {4, 5, 6, 7}; + little_core_ids_ = {0, 1, 2, 3}; + cluster_ids_ = {1, 1, 1, 1, 0, 0, 0, 0}; + SetArchInfo(2, kA76, kA55); + SetCacheInfo(0, 2, 64 * 
1024, 32 * 1024); + SetCacheInfo(1, 2, 512 * 1024, 128 * 1024); + SetCacheInfo(2, 1, 4096 * 1024); + SetFP16Info(1, 1); + SetDotInfo(1, 1); + return true; + } + return false; +} + +void DeviceInfo::SetCPUInfoByProb() { +#ifdef LITE_WITH_LINUX + // get big.LITTLE cores by sorting CPU frequency + sort_cpuid_by_max_freq(max_freqs_, &core_ids_, &cluster_ids_); + big_core_ids_.clear(); + little_core_ids_.clear(); + for (int i = 0; i < cluster_ids_.size(); ++i) { + if (cluster_ids_[i] == 0) { + big_core_ids_.push_back(core_ids_[i]); + } else { + little_core_ids_.push_back(core_ids_[i]); + } + } + // get l1, l2, l3 cache size for each core + for (int i = 0; i < core_num_; i++) { + get_cpu_cache_size(i, &(L1_cache_[i]), &(L2_cache_[i]), &(L3_cache_[i])); + } +#endif // LITE_WITH_LINUX +} + +void DeviceInfo::RequestPowerFullMode(int thread_num) { + int big_core_size = big_core_ids_.size(); + int little_core_size = little_core_ids_.size(); + active_ids_.clear(); + for (int i = 0; i < thread_num; ++i) { + if (i < big_core_size) { + active_ids_.push_back(big_core_ids_[i]); + } else if (i < big_core_size + little_core_size) { + active_ids_.push_back(little_core_ids_[i - big_core_size]); + } + } + mode_ = LITE_POWER_FULL; +} + +void DeviceInfo::RequestPowerHighMode(int thread_num) { + int big_core_size = big_core_ids_.size(); + int little_core_size = little_core_ids_.size(); + active_ids_.clear(); + if (big_core_size > 0) { + mode_ = LITE_POWER_HIGH; + if (thread_num > big_core_size) { + LOG(ERROR) << "Request thread num: " << thread_num + << ", exceed the big cores size: " << big_core_size + << ", truncate thread num to " << big_core_size; + active_ids_ = big_core_ids_; + } else { + for (int i = 0; i < thread_num; ++i) { + active_ids_.push_back(big_core_ids_[i]); + } + } + } else { + mode_ = LITE_POWER_LOW; + LOG(ERROR) << "HIGH POWER MODE is not support, switch to little cores."; + if (thread_num > little_core_size) { + active_ids_ = little_core_ids_; + } else { + for (int i = 0; i < thread_num; ++i) { + active_ids_.push_back(little_core_ids_[i]); + } + } + } +} + +void DeviceInfo::RequestPowerLowMode(int thread_num) { + int big_core_size = big_core_ids_.size(); + int little_core_size = little_core_ids_.size(); + active_ids_.clear(); + if (little_core_size > 0) { + mode_ = LITE_POWER_LOW; + if (thread_num > little_core_size) { + LOG(WARNING) << "Request thread num: " << thread_num + << ", exceed the little cores size: " << little_core_size + << ", truncate thread num to " << little_core_size; + active_ids_ = little_core_ids_; + } else { + for (int i = 0; i < thread_num; i++) { + active_ids_.push_back(little_core_ids_[i]); + } + } + } else { + mode_ = LITE_POWER_HIGH; + LOG(WARNING) << "LOW POWER MODE is not support, switch to big cores"; + if (thread_num > big_core_size) { + active_ids_ = big_core_ids_; + } else { + for (int i = 0; i < thread_num; i++) { + active_ids_.push_back(big_core_ids_[i]); + } + } + } +} + +void DeviceInfo::RequestPowerNoBindMode(int thread_num) { + active_ids_.clear(); + if (thread_num > core_ids_.size()) { + active_ids_ = core_ids_; + } else { + active_ids_.resize(thread_num); + for (int i = 0; i < thread_num; ++i) { + if (i < big_core_ids_.size()) { + active_ids_[i] = big_core_ids_[i]; + } else { + active_ids_[i] = little_core_ids_[i - big_core_ids_.size()]; + } + } + } + mode_ = LITE_POWER_NO_BIND; +} + +void DeviceInfo::RequestPowerRandHighMode(int shift_num, int thread_num) { + int big_core_size = big_core_ids_.size(); + int little_core_size = little_core_ids_.size(); + 
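+  // shift_num rotates the starting big core between calls, spreading load
+  // across the cluster. E.g. (illustrative) with big_core_ids_ = {4,5,6,7}
+  // and thread_num = 2: shift_num 0 binds {4,5}, shift_num 1 binds {5,6}.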
active_ids_.clear(); + if (big_core_size > 0) { + mode_ = LITE_POWER_RAND_HIGH; + if (thread_num > big_core_size) { + LOG(WARNING) << "Request thread num: " << thread_num + << ", exceed the big cores size: " << big_core_size + << ", truncate thread num to " << big_core_size; + active_ids_ = big_core_ids_; + } else { + for (int i = 0; i < thread_num; ++i) { + active_ids_.push_back(big_core_ids_[(i + shift_num) % big_core_size]); + } + } + } else { + mode_ = LITE_POWER_LOW; + LOG(WARNING) << "HIGH POWER MODE is not support, switch to little cores."; + if (thread_num > little_core_size) { + active_ids_ = little_core_ids_; + } else { + for (int i = 0; i < thread_num; ++i) { + active_ids_.push_back(little_core_ids_[i]); + } + } + } +} + +void DeviceInfo::RequestPowerRandLowMode(int shift_num, int thread_num) { + int big_core_size = big_core_ids_.size(); + int little_core_size = little_core_ids_.size(); + active_ids_.clear(); + if (little_core_size > 0) { + mode_ = LITE_POWER_RAND_LOW; + if (thread_num > little_core_size) { + LOG(WARNING) << "Request thread num: " << thread_num + << ", exceed the little cores size: " << little_core_size + << ", truncate thread num to " << little_core_size; + active_ids_ = little_core_ids_; + } else { + for (int i = 0; i < thread_num; ++i) { + active_ids_.push_back( + little_core_ids_[(i + shift_num) % little_core_size]); + } + } + } else { + mode_ = LITE_POWER_HIGH; + LOG(WARNING) << "LOW POWER MODE is not support, switch to big cores."; + if (thread_num > big_core_size) { + active_ids_ = big_core_ids_; + } else { + for (int i = 0; i < thread_num; ++i) { + active_ids_.push_back(big_core_ids_[i]); + } + } + } +} + +int DeviceInfo::Setup() { + core_num_ = get_cpu_num(); + mem_size_ = get_mem_size(); + get_cpu_arch(&archs_, core_num_); + // set defalut CPU info + SetCacheInfo(0, 1, DEFAULT_L1_CACHE_SIZE); + SetCacheInfo(1, 1, DEFAULT_L2_CACHE_SIZE); + SetCacheInfo(2, 1, DEFAULT_L3_CACHE_SIZE); + SetFP32Info(1, 1); + SetFP16Info(1, 0); + SetDotInfo(1, 0); +#ifdef LITE_WITH_LINUX + // get max&min freq + max_freqs_.resize(core_num_); + min_freqs_.resize(core_num_); + for (int i = 0; i < core_num_; ++i) { + int max_freq, min_freq; + get_cpu_max_min_freq(i, &max_freq, &min_freq); + max_freqs_[i] = max_freq / 1000; + min_freqs_[i] = min_freq / 1000; + } + // get cache size and big.LITTLE core ids + dev_name_ = get_cpu_name(); + if (!SetCPUInfoByName()) { + SetCPUInfoByProb(); + } + // output info + LOG(INFO) << "ARM multiprocessors name: " << dev_name_; + LOG(INFO) << "ARM multiprocessors number: " << core_num_; + for (int i = 0; i < core_num_; ++i) { + LOG(INFO) << "ARM multiprocessors ID: " << core_ids_[i] + << ", max freq: " << max_freqs_[i] + << ", min freq: " << min_freqs_[i] + << ", cluster ID: " << cluster_ids_[core_ids_[i]] + << ", CPU ARCH: A" << archs_[i]; + } + LOG(INFO) << "L1 DataCache size is: "; + for (int i = 0; i < core_num_; ++i) { + LOG(INFO) << L1_cache_[i] / 1024 << " KB"; + } + LOG(INFO) << "L2 Cache size is: "; + for (int i = 0; i < core_num_; ++i) { + LOG(INFO) << L2_cache_[i] / 1024 << " KB"; + } + LOG(INFO) << "L3 Cache size is: "; + for (int i = 0; i < core_num_; ++i) { + LOG(INFO) << L3_cache_[i] / 1024 << " KB"; + } + LOG(INFO) << "Total memory: " << mem_size_ << "KB"; +#endif + // set default run mode + SetRunMode(LITE_POWER_NO_BIND, 1); // use single thread by default + return 0; +} + +void DeviceInfo::SetRunMode(PowerMode mode, int thread_num) { +#ifdef ARM_WITH_OMP + thread_num = std::min(thread_num, core_num_); +#else + thread_num = 1; 
// force thread_num to 1 if OpenMP is disabled +#endif +#ifdef LITE_WITH_LINUX + int big_core_size = big_core_ids_.size(); + int little_core_size = little_core_ids_.size(); + int big_little_core_size = big_core_size + little_core_size; + thread_num = std::min(thread_num, big_little_core_size); + count_++; + int shift_num = (count_ / 10) % big_core_size; + switch (mode) { + case LITE_POWER_FULL: + RequestPowerFullMode(thread_num); + break; + case LITE_POWER_HIGH: + RequestPowerHighMode(thread_num); + break; + case LITE_POWER_LOW: + RequestPowerLowMode(thread_num); + break; + case LITE_POWER_NO_BIND: + RequestPowerNoBindMode(thread_num); + break; + case LITE_POWER_RAND_HIGH: + RequestPowerRandHighMode(shift_num, thread_num); + break; + case LITE_POWER_RAND_LOW: + RequestPowerRandLowMode(shift_num, thread_num); + break; + default: + LOG(FATAL) << "Unsupported power mode: " << mode; + break; + } + if (active_ids_.empty()) { + active_ids_.push_back(0); + } +#ifdef ARM_WITH_OMP + omp_set_num_threads(active_ids_.size()); +#endif + if (mode_ != LITE_POWER_NO_BIND) { + if (check_cpu_online(active_ids_)) { + bind_threads(active_ids_); + } else { + LOG(WARNING) << "Some cores are offline, switch to NO BIND MODE"; + mode_ = LITE_POWER_NO_BIND; + } + } +#else // LITE_WITH_LINUX + // only LITE_POWER_NO_BIND is supported in other OS + RequestPowerNoBindMode(thread_num); +#ifdef ARM_WITH_OMP + omp_set_num_threads(active_ids_.size()); +#endif +#endif // LITE_WITH_LINUX + //! alloc memory for sgemm in this context + workspace_.Resize({llc_size()}); + workspace_.mutable_data(); + arch_ = archs_[active_ids_[0]]; +} + +void DeviceInfo::SetCache(int l1size, int l2size, int l3size) { + SetCacheInfo(0, 1, l1size); + SetCacheInfo(1, 1, l2size); + SetCacheInfo(2, 1, l3size); + workspace_.Resize({2 * (l1size + l2size)}); +} + +bool DeviceInfo::ExtendWorkspace(size_t size) { + workspace_.Resize({size + llc_size()}); + workspace_.mutable_data(); + return true; +} + +#endif // LITE_WITH_ARM + +} // namespace lite +} // namespace paddle diff --git a/lite/core/cpu_info.h b/lite/core/cpu_info.h new file mode 100644 index 00000000000..495f95943e9 --- /dev/null +++ b/lite/core/cpu_info.h @@ -0,0 +1,135 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
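+// Typical flow (sketch): DeviceInfo::Init() probes the SoC once, after which
+// kernels pick a binding, e.g.
+//   DeviceInfo::Init();
+//   DeviceInfo::Global().SetRunMode(LITE_POWER_HIGH, 4);
+//   int l2 = DeviceInfo::Global().l2_cache_size();
+// ARMContext in lite/core/context.h forwards its SetRunMode/SetCache calls
+// to exactly these methods.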
+ +#pragma once + +#include +#include +#include +#include "lite/core/tensor.h" +#include "lite/utils/cp_logging.h" + +namespace paddle { +namespace lite { + +#ifdef LITE_WITH_ARM + +typedef enum { + LITE_POWER_HIGH = 0, + LITE_POWER_LOW = 1, + LITE_POWER_FULL = 2, + LITE_POWER_NO_BIND = 3, + LITE_POWER_RAND_HIGH = 4, + LITE_POWER_RAND_LOW = 5 +} PowerMode; + +typedef enum { + kAPPLE = 0, + kA53 = 53, + kA55 = 55, + kA57 = 57, + kA72 = 72, + kA73 = 73, + kA75 = 75, + kA76 = 76, + kARMArch_UNKOWN = -1 +} ARMArch; + +class DeviceInfo { + public: + static DeviceInfo& Global() { + static auto* x = new DeviceInfo; + return *x; + } + + static int Init() { + static int ret = Global().Setup(); + return ret; + } + + int Setup(); + + void SetRunMode(PowerMode mode, int thread_num); + void SetCache(int l1size, int l2size, int l3size); + void SetArch(ARMArch arch) { arch_ = arch; } + + PowerMode mode() const { return mode_; } + int threads() const { return active_ids_.size(); } + ARMArch arch() const { return arch_; } + int l1_cache_size() const { return L1_cache_[active_ids_[0]]; } + int l2_cache_size() const { return L2_cache_[active_ids_[0]]; } + int l3_cache_size() const { return L3_cache_[active_ids_[0]]; } + int llc_size() const { + auto size = L3_cache_[active_ids_[0]] > 0 ? L3_cache_[active_ids_[0]] + : L2_cache_[active_ids_[0]]; + return size > 0 ? size : 512 * 1024; + } + bool has_dot() const { return dot_[active_ids_[0]]; } + bool has_fp16() const { return fp16_[active_ids_[0]]; } + + template + T* workspace_data() { + return reinterpret_cast(workspace_.mutable_data()); + } + bool ExtendWorkspace(size_t size); + + private: + int core_num_; + std::vector max_freqs_; + std::vector min_freqs_; + int mem_size_; + std::string dev_name_; + + std::vector L1_cache_; + std::vector L2_cache_; + std::vector L3_cache_; + std::vector core_ids_; + std::vector big_core_ids_; + std::vector little_core_ids_; + std::vector cluster_ids_; + std::vector archs_; + std::vector fp32_; + std::vector fp16_; + std::vector dot_; + + ARMArch arch_; + // LITE_POWER_HIGH stands for using big cores, + // LITE_POWER_LOW stands for using small core, + // LITE_POWER_FULL stands for using all cores + PowerMode mode_; + std::vector active_ids_; + TensorLite workspace_; + int64_t count_{0}; + + void SetDotInfo(int argc, ...); + void SetFP16Info(int argc, ...); + void SetFP32Info(int argc, ...); + void SetCacheInfo(int cache_id, int argc, ...); + void SetArchInfo(int argc, ...); + bool SetCPUInfoByName(); + void SetCPUInfoByProb(); + void RequestPowerFullMode(int thread_num); + void RequestPowerHighMode(int thread_num); + void RequestPowerLowMode(int thread_num); + void RequestPowerNoBindMode(int thread_num); + void RequestPowerRandHighMode(int shift_num, int thread_num); + void RequestPowerRandLowMode(int shift_num, int thread_num); + + DeviceInfo() = default; +}; + +#endif // LITE_WITH_ARM + +} // namespace lite +} // namespace paddle diff --git a/lite/core/framework.proto b/lite/core/framework.proto new file mode 100644 index 00000000000..6c60a041a19 --- /dev/null +++ b/lite/core/framework.proto @@ -0,0 +1,188 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +syntax = "proto2"; +// option optimize_for = LITE_RUNTIME; +package paddle.framework.proto; + +// Any incompatible changes to ProgramDesc and its dependencies should +// raise the version defined version.h. +// +// Serailization and Deserialization codes should be modified in a way +// that supports old versions following the version and compatibility policy. +message Version { optional int64 version = 1 [ default = 0 ]; } + +enum AttrType { + INT = 0; + FLOAT = 1; + STRING = 2; + INTS = 3; + FLOATS = 4; + STRINGS = 5; + BOOLEAN = 6; + BOOLEANS = 7; + BLOCK = 8; + LONG = 9; + BLOCKS = 10; + LONGS = 11; +} + +// OpDesc describes an instance of a C++ framework::OperatorBase +// derived class type. +message OpDesc { + + message Attr { + required string name = 1; + required AttrType type = 2; + optional int32 i = 3; + optional float f = 4; + optional string s = 5; + repeated int32 ints = 6; + repeated float floats = 7; + repeated string strings = 8; + optional bool b = 10; + repeated bool bools = 11; + optional int32 block_idx = 12; + optional int64 l = 13; + repeated int32 blocks_idx = 14; + repeated int64 longs = 15; + }; + + message Var { + required string parameter = 1; + repeated string arguments = 2; + }; + + required string type = 3; + repeated Var inputs = 1; + repeated Var outputs = 2; + repeated Attr attrs = 4; + optional bool is_target = 5 [ default = false ]; +}; + +// OpProto describes a C++ framework::OperatorBase derived class. +message OpProto { + + // VarProto describes the C++ type framework::Variable. + message Var { + required string name = 1; + required string comment = 2; + + optional bool duplicable = 3 [ default = false ]; + optional bool intermediate = 4 [ default = false ]; + optional bool dispensable = 5 [ default = false ]; + } + + // AttrProto describes the C++ type Attribute. + message Attr { + required string name = 1; + required AttrType type = 2; + required string comment = 3; + // If that attribute is generated, it means the Paddle third + // language binding has responsibility to fill that + // attribute. End-User should not set that attribute. + optional bool generated = 4 [ default = false ]; + } + + required string type = 1; + repeated Var inputs = 2; + repeated Var outputs = 3; + repeated Attr attrs = 4; + required string comment = 5; +} + +message VarType { + enum Type { + // Pod Types + BOOL = 0; + INT16 = 1; + INT32 = 2; + INT64 = 3; + FP16 = 4; + FP32 = 5; + FP64 = 6; + // Tensor is used in C++. + SIZE_T = 19; + UINT8 = 20; + INT8 = 21; + + // Other types that may need additional descriptions + LOD_TENSOR = 7; + SELECTED_ROWS = 8; + FEED_MINIBATCH = 9; + FETCH_LIST = 10; + STEP_SCOPES = 11; + LOD_RANK_TABLE = 12; + LOD_TENSOR_ARRAY = 13; + PLACE_LIST = 14; + READER = 15; + // Any runtime decided variable type is raw + // raw variables should manage their own allocations + // in operators like nccl_op + RAW = 17; + TUPLE = 18; + } + + required Type type = 1; + + message TensorDesc { + // Should only be PODType. 
Is enforced in C++ + required Type data_type = 1; + repeated int64 dims = 2; // [UNK, 640, 480] is saved as [-1, 640, 480] + } + optional TensorDesc selected_rows = 2; + + message LoDTensorDesc { + required TensorDesc tensor = 1; + optional int32 lod_level = 2 [ default = 0 ]; + } + optional LoDTensorDesc lod_tensor = 3; + + message LoDTensorArrayDesc { + required TensorDesc tensor = 1; + optional int32 lod_level = 2 [ default = 0 ]; + } + optional LoDTensorArrayDesc tensor_array = 4; + + message ReaderDesc { repeated LoDTensorDesc lod_tensor = 1; } + optional ReaderDesc reader = 5; + + message Tuple { repeated Type element_type = 1; } + optional Tuple tuple = 7; +} + +message VarDesc { + required string name = 1; + required VarType type = 2; + optional bool persistable = 3 [ default = false ]; +} + +message BlockDesc { + required int32 idx = 1; + required int32 parent_idx = 2; + repeated VarDesc vars = 3; + repeated OpDesc ops = 4; + optional int32 forward_block_idx = 5 [ default = -1 ]; +} + +// Please refer to +// https://github.com/PaddlePaddle/Paddle/blob/develop/doc/design/program.md +// for more details. +// TODO(panyx0718): A model can have multiple programs. Need a +// way to distinguish them. Maybe ID or name? +message ProgramDesc { + repeated BlockDesc blocks = 1; + + optional Version version = 2; +} diff --git a/lite/core/kernel.cc b/lite/core/kernel.cc new file mode 100644 index 00000000000..7ec718cb388 --- /dev/null +++ b/lite/core/kernel.cc @@ -0,0 +1,104 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
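+// Serialization format (see SerializeKernelType below): kernel types are
+// encoded as op/alias/target/precision/layout, with the three Place fields
+// stored as integer enum values for easy parsing. E.g. an fc kernel on
+// (kHost, kFloat, kNCHW) round-trips as "fc/def/1/1/1", as asserted in
+// lite/core/kernel_test.cc.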
+ +#include "lite/core/kernel.h" +#include +#include "lite/utils/string.h" + +namespace paddle { +namespace lite { + +std::string KernelBase::summary() const { + STL::stringstream ss; + ss << op_type() << ":" << TargetToStr(target()) << "/" + << PrecisionToStr(precision()) << "/" << DataLayoutToStr(layout()) << "(" + << alias() << ")"; + return ss.str(); +} + +const Type *KernelBase::GetInputDeclType(const std::string &arg_name) const { + CHECK(!op_type_.empty()) << "op_type should be set first"; + const auto *type = ParamTypeRegistry::Global().RetrieveInArgument( + place(), GenParamTypeKey(), arg_name); + CHECK(type) << "no type registered for kernel [" << op_type_ + << "] input argument [" << arg_name << "]" + << " with key " << GenParamTypeKey(); + return type->type; +} + +const Type *KernelBase::GetOutputDeclType(const std::string &arg_name) const { + CHECK(!op_type_.empty()) << "op_type should be set first"; + const auto *type = ParamTypeRegistry::Global().RetrieveOutArgument( + place(), GenParamTypeKey(), arg_name); + CHECK(type) << "no type registered for kernel [" << GenParamTypeKey() + << "] output argument [" << arg_name << "]"; + return type->type; +} + +std::string KernelBase::GenParamTypeKey() const { + STL::stringstream ss; + ss << op_type() << "/" << alias_; + return ss.str(); +} + +void KernelBase::ParseKernelType(const std::string &kernel_type, + std::string *op_type, + std::string *alias, + Place *place) { + auto parts = Split(kernel_type, "/"); + CHECK_EQ(parts.size(), 5); + *op_type = parts[0]; + *alias = parts[1]; + + std::string target, precision, layout; + + target = parts[2]; + precision = parts[3]; + layout = parts[4]; + + place->target = static_cast(std::atoi(target.c_str())); + place->precision = static_cast(std::atoi(precision.c_str())); + place->layout = static_cast(std::atoi(layout.c_str())); +} + +std::string KernelBase::SerializeKernelType(const std::string &op_type, + const std::string &alias, + const Place &place) { + STL::stringstream ss; + ss << op_type << "/"; + ss << alias << "/"; + // We serialize the place value not the string representation here for + // easier deserialization. + ss << static_cast(place.target) << "/"; + ss << static_cast(place.precision) << "/"; + ss << static_cast(place.layout); + return ss.str(); +} + +bool ParamTypeRegistry::KeyCmp::operator()( + const ParamTypeRegistry::key_t &a, + const ParamTypeRegistry::key_t &b) const { + return a.hash() < b.hash(); +} + +STL::ostream &operator<<(STL::ostream &os, + const ParamTypeRegistry::KernelIdTy &other) { + std::string io_s = other.io == ParamTypeRegistry::IO::kInput ? "in" : "out"; + os << other.kernel_type << ":" << other.arg_name << ":" << io_s << ":" + << other.place.DebugString(); + return os; +} + +} // namespace lite +} // namespace paddle diff --git a/lite/core/kernel.h b/lite/core/kernel.h new file mode 100644 index 00000000000..85244c5b14f --- /dev/null +++ b/lite/core/kernel.h @@ -0,0 +1,189 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
diff --git a/lite/core/kernel.h b/lite/core/kernel.h
new file mode 100644
index 00000000000..85244c5b14f
--- /dev/null
+++ b/lite/core/kernel.h
@@ -0,0 +1,189 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <functional>
+#include <map>
+#include <memory>
+#include <set>
+#include <string>
+#include <vector>
+#include "lite/arm/math/type_trans.h"
+#include "lite/core/context.h"
+#include "lite/core/target_wrapper.h"
+#include "lite/core/type_system.h"
+#include "lite/core/types.h"
+#include "lite/core/workspace.h"
+#include "lite/operators/op_params.h"
+#include "lite/utils/all.h"
+#include "lite/utils/replace_stl/stream.h"
+
+namespace paddle {
+namespace lite {
+
+// A base class with virtual functions to unify all the kernel implementations
+// on different targets.
+class KernelBase {
+ public:
+  // type_infer_handler is used to infer an output type by considering the
+  // input types in the type system.
+  using type_infer_handler_t = std::function<const Type*(
+      const std::map<std::string, const Type*>& input_types,
+      const std::string& out_arg)>;
+
+ protected:
+  /// Run some initialization before `Run`; it will be invoked after
+  /// `SetParam` and `SetContext`, that is, when both param_ and context_ are
+  /// valid.
+  virtual void PrepareForRun() {}
+
+  /// Run the kernel. Before Run, both the param_ and context_ should be valid.
+  virtual void Run() = 0;
+
+ public:
+  void Launch() {
+    if (is_first_epoch_) {
+      PrepareForRun();
+      is_first_epoch_ = false;
+    }
+
+    // Reset the workspace to make every kernel in the same thread share the
+    // temporary memory.
+    WorkSpace::Global_Host().AllocReset();
+#if defined(LITE_WITH_X86)
+    WorkSpace::Global_X86().AllocReset();
+#endif
+#if defined(LITE_WITH_CUDA)
+    WorkSpace::Global_CUDA().AllocReset();
+#endif
+    Run();
+  }
+
+  void SetContext(std::unique_ptr<KernelContext>&& ctx) {
+    ctx_ = std::move(ctx);
+  }
+  template <typename T>
+  void SetParam(T param) {
+    param_.set(param);
+  }
+  template <typename P>
+  P& Param() const {
+    return *param_.get_mutable<P>();
+  }
+
+  // This is used in the kernels that take `kAny` places and infer the output
+  // place. For `ScaleCompute` and `IoCopyCompute`, their input types are
+  // declared as `kAny` in some Place field, and the output is also `kAny`;
+  // but in real execution, when some non-kAny type is taken as input, the
+  // output's kAny fields can be determined. For example, when `ScaleCompute`
+  // takes `TensorFp32NCHWTy` as input, its output should also be
+  // `TensorFp32NCHWTy`. This type inference rule is different for each
+  // kernel, so we make it a virtual method.
+  // One can customize this handler to make a specific type inference rule for
+  // a kernel, or leave the default to force the kernel to use the system's
+  // type-inference rules.
+  virtual std::unique_ptr<type_infer_handler_t> GetTypeInferHandler() {
+    return nullptr;
+  }
+
+  void set_op_type(const std::string& type) { op_type_ = type; }
+  const std::string& op_type() const { return op_type_; }
+
+  // Get input declaration Type.
+  const Type* GetInputDeclType(const std::string& arg_name) const;
+
+  // Get output declaration Type.
+  const Type* GetOutputDeclType(const std::string& arg_name) const;
+
+  void set_alias(const std::string& x) { alias_ = x; }
+  const std::string& alias() const { return alias_; }
+
+  virtual Place place() const = 0;
+  virtual TargetType target() const = 0;
+  virtual PrecisionType precision() const = 0;
+  virtual DataLayoutType layout() const = 0;
+  const KernelContext* context() const { return ctx_.get(); }
+  KernelContext* mutable_context() { return ctx_.get(); }
+  virtual std::string name() const = 0;
+
+  // Short human-readable document.
+  std::string summary() const;
+  // Long human-readable document.
+  virtual std::string doc() const { return ""; }
+  // Generate the key of the parameter type.
+  std::string GenParamTypeKey() const;
+
+  // Used to serialize the kernel.
+  std::string SerializedKernelType() const {
+    return SerializeKernelType(op_type(), alias(), place());
+  }
+
+  static std::string SerializeKernelType(const std::string& op_type,
+                                         const std::string& alias,
+                                         const Place& place);
+
+  static void ParseKernelType(const std::string& kernel_type,
+                              std::string* op_type,
+                              std::string* alias,
+                              Place* place);
+
+  std::string key_with_alias() const { return op_type() + "/" + alias(); }
+
+  virtual ~KernelBase() = default;
+  void Torch() {}
+
+ protected:
+  std::unique_ptr<KernelContext> ctx_{nullptr};
+  mutable operators::param_t param_;
+  // The corresponding op type.
+  std::string op_type_{};
+  // The extra identity to help differentiate a specific kernel; op_type_ +
+  // alias_ is the unique ID for the kernel.
+  std::string alias_{};
+  bool is_first_epoch_{true};
+};
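How a caller is expected to drive a kernel, per the contract above: set the parameter and (optionally) the context first, then call Launch(), which runs PrepareForRun() exactly once before the first Run() and resets the workspaces on every call. A minimal sketch under those assumptions (ToyKernel and Drive are illustrative names, not from the patch):

#include "lite/core/kernel.h"
#include "lite/operators/op_params.h"

// A toy kernel used only to illustrate the Launch() lifecycle.
class ToyKernel
    : public paddle::lite::KernelLite<TARGET(kHost), PRECISION(kFloat)> {
 public:
  void PrepareForRun() override { LOG(INFO) << "runs exactly once"; }
  void Run() override { LOG(INFO) << "runs on every Launch()"; }
};

void Drive() {
  ToyKernel kernel;
  paddle::lite::operators::FcParam param;  // any op param fits param_
  kernel.SetParam(param);
  kernel.Launch();  // 1st call: PrepareForRun() + Run()
  kernel.Launch();  // later calls: Run() only; workspaces reset each time
}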
+
+// Light-weight kernel implementation.
+// The OpKernel is designed to implement the specific algorithm on a target
+// device.
+// TODO(Superjomn) Consider to add a Platform type to differentiate CUDNN,
+// MKLDNN, plain CUDA C implementations.
+template <TargetType Target,
+          PrecisionType Precision,
+          DataLayoutType DataLayout = DATALAYOUT(kNCHW)>
+class KernelLite : public KernelBase {
+ public:
+  // Run the kernel.
+  virtual void Run() { CHECK(false) << "Not Implemented"; }
+
+  TargetType target() const override { return Target; }
+  PrecisionType precision() const override { return Precision; }
+  DataLayoutType layout() const override { return DataLayout; }
+  Place place() const override { return Place{Target, Precision, DataLayout}; }
+  std::string name() const override;
+
+  void Touch() {}
+
+  KernelLite() = default;
+  virtual ~KernelLite() = default;
+};
+
+template <TargetType Target, PrecisionType Precision, DataLayoutType DataLayout>
+std::string KernelLite<Target, Precision, DataLayout>::name() const {
+  return op_type() + ":" + TargetToStr(Target) + "/" +
+         PrecisionToStr(Precision) + "/" + DataLayoutToStr(DataLayout);
+}
+
+}  // namespace lite
+}  // namespace paddle
diff --git a/lite/core/kernel_test.cc b/lite/core/kernel_test.cc
new file mode 100644
index 00000000000..8ad8b477443
--- /dev/null
+++ b/lite/core/kernel_test.cc
@@ -0,0 +1,63 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lite/core/kernel.h"
+#include <gtest/gtest.h>
+#include "lite/core/op_lite.h"
+
+namespace paddle {
+namespace lite {
+namespace core {
+
+int test_code{-1};
+class SomeKernel : public KernelLite<TARGET(kHost), PRECISION(kFloat)> {
+ public:
+  void Run() override {
+    LOG(INFO) << "SomeKernel executed";
+    LOG(INFO) << Param<operators::FcParam>().in_num_col_dims;
+    test_code = Param<operators::FcParam>().in_num_col_dims;
+  }
+
+  TargetType target() const override { return TARGET(kHost); }
+  PrecisionType precision() const override { return PRECISION(kFloat); }
+};
+
+TEST(Kernel, test) {
+  SomeKernel kernel;
+  operators::FcParam param;
+  param.in_num_col_dims = 100;
+  kernel.SetParam(param);
+  kernel.Run();
+  ASSERT_EQ(test_code, 100);
+}
+
+TEST(Kernel, kernel_type) {
+  const std::string op_type = "fc";
+  const std::string alias = "def";
+  Place place(TARGET(kHost), PRECISION(kFloat), DATALAYOUT(kNCHW));
+  auto kernel_type = KernelBase::SerializeKernelType(op_type, alias, place);
+  LOG(INFO) << "kernel_type: " << kernel_type;
+  ASSERT_EQ(kernel_type, "fc/def/1/1/1");
+
+  std::string op_type1, alias1;
+  Place place1;
+  KernelBase::ParseKernelType(kernel_type, &op_type1, &alias1, &place1);
+  ASSERT_EQ(op_type, op_type1);
+  ASSERT_EQ(alias, alias1);
+  ASSERT_EQ(place, place1);
+}
+
+}  // namespace core
+}  // namespace lite
+}  // namespace paddle
diff --git a/lite/core/lite.map b/lite/core/lite.map
new file mode 100644
index 00000000000..31adae42196
--- /dev/null
+++ b/lite/core/lite.map
@@ -0,0 +1,6 @@
+{
+  global:
+    *paddle*;
+  local:
+    *;
+};
diff --git a/lite/core/lite_gtest_main.cc b/lite/core/lite_gtest_main.cc
new file mode 100644
index 00000000000..9784fc79945
--- /dev/null
+++ b/lite/core/lite_gtest_main.cc
@@ -0,0 +1,23 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <gflags/gflags.h>
+#include <gtest/gtest.h>
+
+int main(int argc, char** argv) {
+  testing::InitGoogleTest(&argc, argv);
+  google::ParseCommandLineFlags(&argc, &argv, false);
+
+  return RUN_ALL_TESTS();
+}
diff --git a/lite/core/lite_tensor_test.cc b/lite/core/lite_tensor_test.cc
new file mode 100644
index 00000000000..d667a9f8852
--- /dev/null
+++ b/lite/core/lite_tensor_test.cc
@@ -0,0 +1,32 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <gtest/gtest.h>
+#include "lite/core/tensor.h"
+
+namespace paddle {
+namespace lite {
+
+TEST(tensor, test) {
+  TensorLite tensor;
+  DDimLite ddim({1, 8});
+  tensor.Resize(ddim);
+
+  for (int i = 0; i < 8; i++) {
+    tensor.mutable_data<int>()[i] = i;
+  }
+}
+
+}  // namespace lite
+}  // namespace paddle
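The test above only writes through mutable_data<T>(), which lazily allocates the buffer sized by the current dims. A reader-side sketch, assuming the const data<T>() accessor and DDimLite::production() behave as in tensor.h (the helper name is made up):

#include "lite/core/tensor.h"

// Read back what the test wrote: mutable_data<T>() allocates lazily and
// returns a mutable pointer; data<T>() is the const accessor.
void ReadBack(const paddle::lite::TensorLite& tensor) {
  const int* d = tensor.data<int>();
  for (int64_t i = 0; i < tensor.dims().production(); i++) {
    LOG(INFO) << d[i];
  }
}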
diff --git a/lite/core/memory.cc b/lite/core/memory.cc
new file mode 100644
index 00000000000..67a1f83bfdb
--- /dev/null
+++ b/lite/core/memory.cc
@@ -0,0 +1,110 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lite/core/memory.h"
+
+namespace paddle {
+namespace lite {
+
+void* TargetMalloc(TargetType target, size_t size) {
+  void* data{nullptr};
+  switch (target) {
+    case TargetType::kHost:
+    case TargetType::kX86:
+    case TargetType::kARM:
+      data = TargetWrapper<TARGET(kHost)>::Malloc(size);
+      break;
+#ifdef LITE_WITH_CUDA
+    case TargetType::kCUDA:
+      data =
+          TargetWrapper<TARGET(kCUDA), cudaStream_t, cudaEvent_t>::Malloc(size);
+      break;
+#endif  // LITE_WITH_CUDA
+#ifdef LITE_WITH_OPENCL
+    case TargetType::kOpenCL:
+      data = TargetWrapperCL::Malloc(size);
+      break;
+#endif  // LITE_WITH_OPENCL
+#ifdef LITE_WITH_FPGA
+    case TargetType::kFPGA:
+      data = TargetWrapper<TARGET(kFPGA)>::Malloc(size);
+      break;
+#endif  // LITE_WITH_FPGA
+    default:
+      LOG(FATAL) << "Unsupported target " << TargetToStr(target);
+  }
+  return data;
+}
+
+void TargetFree(TargetType target, void* data) {
+  switch (target) {
+    case TargetType::kHost:
+    case TargetType::kX86:
+    case TargetType::kARM:
+      TargetWrapper<TARGET(kHost)>::Free(data);
+      break;
+
+#ifdef LITE_WITH_CUDA
+    case TargetType::kCUDA:
+      TargetWrapper<TARGET(kCUDA), cudaStream_t, cudaEvent_t>::Free(data);
+      break;
+#endif  // LITE_WITH_CUDA
+#ifdef LITE_WITH_OPENCL
+    case TargetType::kOpenCL:
+      TargetWrapperCL::Free(data);
+      break;
+#endif  // LITE_WITH_OPENCL
+#ifdef LITE_WITH_FPGA
+    case TargetType::kFPGA:
+      TargetWrapper<TARGET(kFPGA)>::Free(data);
+      break;
+#endif  // LITE_WITH_FPGA
+    default:
+      LOG(FATAL) << "Unknown type";
+  }
+}
+
+void TargetCopy(TargetType target, void* dst, const void* src, size_t size) {
+  switch (target) {
+    case TargetType::kHost:
+    case TargetType::kX86:
+    case TargetType::kARM:
+      TargetWrapper<TARGET(kHost)>::MemcpySync(
+          dst, src, size, IoDirection::DtoD);
+      break;
+
+#ifdef LITE_WITH_CUDA
+    case TargetType::kCUDA:
+      TargetWrapper<TARGET(kCUDA), cudaStream_t, cudaEvent_t>::MemcpySync(
+          dst, src, size, IoDirection::DtoD);
+      break;
+#endif
+#ifdef LITE_WITH_FPGA
+    case TargetType::kFPGA:
+      TargetWrapper<TARGET(kFPGA)>::MemcpySync(
+          dst, src, size, IoDirection::DtoD);
+      break;
+#endif
+#ifdef LITE_WITH_OPENCL
+    case TargetType::kOpenCL:
+      TargetWrapperCL::MemcpySync(dst, src, size, IoDirection::DtoD);
+      break;
+#endif  // LITE_WITH_OPENCL
+    default:
+      LOG(FATAL) << "unsupported type";
+  }
+}
+
+}  // namespace lite
+}  // namespace paddle
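A host-side round trip through the three helpers above (sizes and fill values are arbitrary; the function name is illustrative):

#include <cstring>
#include "lite/core/memory.h"

void HostRoundTrip() {
  using paddle::lite::TargetType;
  void* src = paddle::lite::TargetMalloc(TargetType::kHost, 16);
  void* dst = paddle::lite::TargetMalloc(TargetType::kHost, 16);
  std::memset(src, 0xAB, 16);
  // Both buffers live on the same target, hence the DtoD direction inside.
  paddle::lite::TargetCopy(TargetType::kHost, dst, src, 16);
  paddle::lite::TargetFree(TargetType::kHost, src);
  paddle::lite::TargetFree(TargetType::kHost, dst);
}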
diff --git a/lite/core/memory.h b/lite/core/memory.h
new file mode 100644
index 00000000000..c35e5bec299
--- /dev/null
+++ b/lite/core/memory.h
@@ -0,0 +1,111 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+#include "lite/api/paddle_place.h"
+#include "lite/core/target_wrapper.h"
+#include "lite/utils/macros.h"
+
+#ifdef LITE_WITH_OPENCL
+#include "lite/opencl/target_wrapper.h"
+#endif  // LITE_WITH_OPENCL
+
+namespace paddle {
+namespace lite {
+
+// Malloc memory for a specific Target. All the targets should be an element
+// in the `switch` here.
+LITE_API void* TargetMalloc(TargetType target, size_t size);
+
+// Free memory for a specific Target. All the targets should be an element in
+// the `switch` here.
+void LITE_API TargetFree(TargetType target, void* data);
+
+// Copy a buffer from host to another target.
+void TargetCopy(TargetType target, void* dst, const void* src, size_t size);
+
+template <TargetType Target>
+void CopySync(void* dst, void* src, size_t size, IoDirection dir) {
+  switch (Target) {
+    case TARGET(kX86):
+    case TARGET(kHost):
+    case TARGET(kARM):
+      TargetWrapper<TARGET(kHost)>::MemcpySync(
+          dst, src, size, IoDirection::HtoH);
+      break;
+#ifdef LITE_WITH_CUDA
+    case TARGET(kCUDA):
+      TargetWrapperCuda::MemcpySync(dst, src, size, dir);
+      break;
+#endif
+#ifdef LITE_WITH_OPENCL
+    case TargetType::kOpenCL:
+      TargetWrapperCL::MemcpySync(dst, src, size, dir);
+      break;
+#endif  // LITE_WITH_OPENCL
+#ifdef LITE_WITH_FPGA
+    case TARGET(kFPGA):
+      TargetWrapper<TARGET(kFPGA)>::MemcpySync(dst, src, size, dir);
+      break;
+#endif
+  }
+}
+
+// Memory buffer manager.
+class Buffer {
+ public:
+  Buffer() = default;
+  Buffer(TargetType target, size_t size) : space_(size), target_(target) {}
+
+  void* data() const { return data_; }
+  TargetType target() const { return target_; }
+  size_t space() const { return space_; }
+
+  void ResetLazy(TargetType target, size_t size) {
+    if (target != target_ || space_ < size) {
+      Free();
+      data_ = TargetMalloc(target, size);
+      target_ = target;
+      space_ = size;
+    }
+  }
+
+  void ResizeLazy(size_t size) { ResetLazy(target_, size); }
+
+  void Free() {
+    if (space_ > 0) {
+      TargetFree(target_, data_);
+    }
+    target_ = TargetType::kHost;
+    space_ = 0;
+  }
+
+  void CopyDataFrom(const Buffer& other, size_t nbytes) {
+    target_ = other.target_;
+    ResizeLazy(nbytes);
+    // TODO(Superjomn) support copy between different targets.
+    TargetCopy(target_, data_, other.data_, nbytes);
+  }
+
+  ~Buffer() { Free(); }
+
+ private:
+  // The size of the memory actually allocated.
+  size_t space_{0};
+  void* data_{nullptr};
+  TargetType target_{TargetType::kHost};
+};
+
+}  // namespace lite
+}  // namespace paddle
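ResetLazy only reallocates when the requested size outgrows the current space or the target changes, so shrinking a Buffer reuses the existing allocation. A small sketch of that behavior, following the code above (the function name is illustrative):

#include "lite/core/memory.h"

void BufferGrowth() {
  paddle::lite::Buffer buf;
  buf.ResetLazy(paddle::lite::TargetType::kHost, 64);
  void* p = buf.data();
  buf.ResizeLazy(32);          // shrink: no reallocation happens
  CHECK_EQ(buf.data(), p);
  buf.ResizeLazy(128);         // grow: frees and reallocates
  CHECK_GE(buf.space(), 128);
}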
diff --git a/lite/core/memory_test.cc b/lite/core/memory_test.cc
new file mode 100644
index 00000000000..eaee9c092cc
--- /dev/null
+++ b/lite/core/memory_test.cc
@@ -0,0 +1,34 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lite/core/memory.h"
+#include <gtest/gtest.h>
+
+namespace paddle {
+namespace lite {
+
+TEST(memory, test) {
+  auto* buf = TargetMalloc(TARGET(kX86), 10);
+  ASSERT_TRUE(buf);
+  TargetFree(TARGET(kX86), buf);
+
+#ifdef LITE_WITH_CUDA
+  auto* buf_cuda = TargetMalloc(TARGET(kCUDA), 10);
+  ASSERT_TRUE(buf_cuda);
+  TargetFree(TARGET(kCUDA), buf_cuda);
+#endif
+}
+
+}  // namespace lite
+}  // namespace paddle
diff --git a/lite/core/mir/CMakeLists.txt b/lite/core/mir/CMakeLists.txt
new file mode 100644
index 00000000000..827c2cf9a41
--- /dev/null
+++ b/lite/core/mir/CMakeLists.txt
@@ -0,0 +1,106 @@
+lite_cc_library(mir_node SRCS node.cc DEPS kernel)
+lite_cc_library(mir_ssa_graph SRCS ssa_graph.cc DEPS mir_node program)
+lite_cc_library(mir_pass SRCS pass.cc DEPS mir_ssa_graph)
+lite_cc_library(mir_pass_manager SRCS pass_manager.cc DEPS mir_pass mir_ssa_graph mir_passes)
+lite_cc_library(mir_pass_registry SRCS pass_registry.cc DEPS mir_pass_manager)
+
+add_subdirectory(fusion)
+add_subdirectory(elimination)
+add_subdirectory(subgraph)
+
+lite_cc_library(mir_passes
+  SRCS
+      fusion/fc_fuse_pass.cc
+      fusion/conv_elementwise_fuse_pass.cc
+      fusion/conv_activation_fuse_pass.cc
+      fusion/conv_bn_fuse_pass.cc
+      fusion/elementwise_add_activation_fuse_pass.cc
+      fusion/quant_dequant_fuse_pass.cc
+      elimination/identity_scale_eliminate_pass.cc
+      static_kernel_pick_pass.cc
+      variable_place_inference_pass.cc
+      type_target_cast_pass.cc
+      type_layout_cast_pass.cc
+      type_precision_cast_pass.cc
+      io_copy_kernel_pick_pass.cc
+      graph_visualize_pass.cc
+      generate_program_pass.cc
+      argument_type_display_pass.cc
+      demo_pass.cc
+      runtime_context_assign_pass.cc
+  DEPS mir_pass types context ${mir_fusers} ${subgraph_passes})
+
+# lite_cc_test(test_ssa_graph SRCS ssa_graph_test.cc DEPS
+    #mir_ssa_graph scope op
+    #fc_op
+    #${host_kernels}
+    #mir_passes
+    #mir_pass_manager
+    #program_fake_utils
+    #)
+# lite_cc_test(test_variable_place_infrence_pass SRCS variable_place_inference_pass_test.cc
+#              DEPS
+#              mul_op
+#              feed_op
+#              fetch_op
+#              io_copy_op
+#              ${host_kernels}
+#              mir_passes
+#              mir_pass_manager
+#              optimizer
+#              program_fake_utils
+#              target_wrapper_host
+#              PROFILE_DEPS basic_profiler
+#              CUDA_DEPS target_wrapper_cuda kernels_cuda
+#              ARM_DEPS mul_compute_arm
+#              X86_DEPS mul_compute_x86
+#              )
+
+set(pattern_deps mir_node mir_ssa_graph op)
+if (WITH_TESTING)
+  list(APPEND pattern_deps gtest)
+endif()
+lite_cc_library(pattern_matcher SRCS pattern_matcher.cc DEPS ${pattern_deps})
+lite_cc_test(test_pattern_matcher SRCS pattern_matcher_test.cc DEPS pattern_matcher)
+
+lite_cc_library(pattern_matcher_high_api SRCS pattern_matcher_high_api.cc DEPS pattern_matcher)
+
+
+# For mobile, it is unnecessary to compile the following tests.
+if (LITE_WITH_LIGHT_WEIGHT_FRAMEWORK)
+  return()
+endif()
+lite_cc_test(test_mir_pass_manager SRCS pass_manager_test.cc DEPS mir_pass_manager mir_passes)
+
+
+# TODO(wz) replace framework/proto with lite proto.
+if (NOT LITE_WITH_LIGHT_WEIGHT_FRAMEWORK)
+  # It depends on fluid/framework/proto, which is too heavy for mobile execution.
+  # TODO(wz) enable it later.
+  # lite_cc_test(test_pattern_matcher_high_api SRCS pattern_matcher_high_api_test.cc DEPS
+  #   pattern_matcher_high_api proto_desc mir_pass_manager fc_op mul_op elementwise_ops
+  #   mir_passes compatible_pb program ${ops})
+endif()
+
+message(STATUS "----> Ops lite: ${ops}")
+message(STATUS "----> Host kernels: ${host_kernels}")
+message(STATUS "----> X86 kernels: ${x86_kernels}")
+
+# lite_cc_test(test_lite_fc_fuse SRCS fusion/fc_fuse_pass_test.cc
+#     DEPS cxx_api mir_passes
+#          ${ops} ${host_kernels} ${x86_kernels} ${arm_kernels}
+#     ARGS --model_dir=${LITE_MODEL_DIR}/lite_fc_model
+#          --optimized_model=${LITE_MODEL_DIR}/lite_fc_model_opt SERIAL)
+
+# lite_download_and_uncompress(${LITE_MODEL_DIR} ${LITE_URL} "lite_fc_model.tar.gz")
+# add_dependencies(test_lite_fc_fuse extern_lite_download_lite_fc_model_tar_gz)
+
+
+# lite_cc_test(test_lite_conv_elementwise_add_activation_fuse
+#     SRCS fusion/conv_elementwise_add_activation_fuse_pass_test.cc
+#     DEPS cxx_api mir_passes
+#          ${ops} ${host_kernels} ${x86_kernels})
+# lite_cc_test(test_lite_elementwise_add_activation_fuse
+#     SRCS fusion/elementwise_add_activation_fuse_pass_test.cc
+#     DEPS cxx_api mir_passes
+#          ${ops} ${host_kernels} ${x86_kernels})
diff --git a/lite/core/mir/argument_type_display_pass.cc b/lite/core/mir/argument_type_display_pass.cc
new file mode 100644
index 00000000000..d53d705a2d7
--- /dev/null
+++ b/lite/core/mir/argument_type_display_pass.cc
@@ -0,0 +1,45 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lite/core/mir/pass.h"
+#include "lite/core/mir/pass_registry.h"
+
+namespace paddle {
+namespace lite {
+namespace mir {
+
+class ArgumentTypeDisplayPass : public DebugPass {
+ public:
+  void Apply(const std::unique_ptr<SSAGraph>& graph) override {
+    VLOG(3) << "== Argument types ==";
+    for (auto& node : graph->mutable_nodes()) {
+      if (!node.IsArg()) continue;
+
+      auto* type = node.AsArg().type;
+      if (type) {
+        VLOG(3) << "* ARG " << node.AsArg().name << " type: " << *type;
+      } else {
+        VLOG(3) << "* ARG " << node.AsArg().name << " type: UNK";
+      }
+    }
+    VLOG(3) << "---------------------";
+  }
+};
+
+}  // namespace mir
+}  // namespace lite
+}  // namespace paddle
+
+REGISTER_MIR_PASS(argument_type_display_pass,
+                  paddle::lite::mir::ArgumentTypeDisplayPass);
diff --git a/lite/core/mir/demo_pass.cc b/lite/core/mir/demo_pass.cc
new file mode 100644
index 00000000000..837a5a1cbcc
--- /dev/null
+++ b/lite/core/mir/demo_pass.cc
@@ -0,0 +1,37 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lite/core/mir/pass.h"
+#include "lite/core/mir/pass_registry.h"
+
+namespace paddle {
+namespace lite {
+namespace mir {
+
+class DemoPass : public mir::DebugPass {
+ public:
+  void Apply(const std::unique_ptr<SSAGraph> &graph) override {}
+};
+
+/*
+bool RegisterDemoPass() {
+  return PassManager::Global().AddNewPass("demo", new DemoPass);
+}
+ */
+
+}  // namespace mir
+}  // namespace lite
+}  // namespace paddle
+
+REGISTER_MIR_PASS(demo, paddle::lite::mir::DemoPass);
diff --git a/lite/core/mir/dot.h b/lite/core/mir/dot.h
new file mode 100644
index 00000000000..df70565c077
--- /dev/null
+++ b/lite/core/mir/dot.h
@@ -0,0 +1,167 @@
+// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+/*
+ * This file implements some helper classes and methods for DOT programming
+ * support. It gives a visualization of the graph, which helps to debug the
+ * logic of each Pass.
+ */
+#pragma once
+
+#include <string>
+#include <unordered_map>
+#include <vector>
+#include "lite/utils/cp_logging.h"
+#include "lite/utils/replace_stl/stream.h"
+#include "lite/utils/string.h"
+
+namespace paddle {
+namespace inference {
+namespace analysis {
+
+static size_t dot_node_counter{0};
+
+/*
+ * A Dot template that helps to build a DOT graph definition.
+ */
+class Dot {
+ public:
+  struct Attr {
+    std::string key;
+    std::string value;
+
+    Attr(const std::string& key, const std::string& value)
+        : key(key), value(value) {}
+
+    std::string repr() const {
+      STL::stringstream ss;
+      ss << key << "=" << '"' << value << '"';
+      return ss.str();
+    }
+  };
+
+  struct Node {
+    std::string name;
+    std::vector<Attr> attrs;
+
+    Node(const std::string& name, const std::vector<Attr>& attrs)
+        : name(name), attrs(attrs) {
+      STL::stringstream ss;
+      ss << "node_" << dot_node_counter++;
+      id_ = ss.str();
+    }
+
+    std::string id() const { return id_; }
+
+    std::string repr() const {
+      STL::stringstream ss;
+      CHECK(!name.empty());
+      ss << id_;
+      if (attrs.empty()) {
+        ss << "[label=" << '"' << name << '"' << "]";
+        return ss.str();
+      }
+      for (size_t i = 0; i < attrs.size(); i++) {
+        if (i == 0) {
+          ss << "[label=" << '"' << name << '"' << " ";
+        }
+        ss << attrs[i].repr();
+        ss << ((i < attrs.size() - 1) ? " " : "]");
+      }
+      return ss.str();
+    }
+
+   private:
+    std::string id_;
+  };
+
+  struct Edge {
+    std::string source;
+    std::string target;
+    std::vector<Attr> attrs;
+
+    Edge(const std::string& source,
+         const std::string& target,
+         const std::vector<Attr>& attrs)
+        : source(source), target(target), attrs(attrs) {}
+
+    std::string repr() const {
+      STL::stringstream ss;
+      CHECK(!source.empty());
+      CHECK(!target.empty());
+      ss << source << "->" << target;
+      for (size_t i = 0; i < attrs.size(); i++) {
+        if (i == 0) {
+          ss << "[";
+        }
+        ss << attrs[i].repr();
+        ss << ((i < attrs.size() - 1) ? " " : "]");
+      }
+      return ss.str();
+    }
+  };
+
+  Dot() = default;
+
+  explicit Dot(const std::vector<Attr>& attrs) : attrs_(attrs) {}
+
+  void AddNode(const std::string& id,
+               const std::vector<Attr>& attrs,
+               std::string label = "") {
+    CHECK(!nodes_.count(id)) << "duplicate Node '" << id << "'";
+    if (label.empty()) label = id;
+    nodes_.emplace(id, Node{label, attrs});
+  }
+
+  void AddEdge(const std::string& source,
+               const std::string& target,
+               const std::vector<Attr>& attrs) {
+    CHECK(!source.empty());
+    CHECK(!target.empty());
+    auto sid = nodes_.at(source).id();
+    auto tid = nodes_.at(target).id();
+    edges_.emplace_back(sid, tid, attrs);
+  }
+
+  // Compile to DOT language codes.
+  std::string Build() const {
+    STL::stringstream ss;
+    const std::string indent = "   ";
+    ss << "digraph G {" << '\n';
+
+    // Add graph attrs
+    for (const auto& attr : attrs_) {
+      ss << indent << attr.repr() << '\n';
+    }
+    // add nodes
+    for (auto& item : nodes_) {
+      ss << indent << item.second.repr() << '\n';
+    }
+    // add edges
+    for (auto& edge : edges_) {
+      ss << indent << edge.repr() << '\n';
+    }
+    ss << "} // end G";
+    return ss.str();
+  }
+
+ private:
+  std::unordered_map<std::string, Node> nodes_;
+  std::vector<Edge> edges_;
+  std::vector<Attr> attrs_;
+};
+
+}  // namespace analysis
+}  // namespace inference
+}  // namespace paddle
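A quick sketch of how the helper is meant to be driven (node names and the attribute are made up); Build() returns a digraph definition ready for `dot -Tpng`:

#include "lite/core/mir/dot.h"

std::string TinyGraph() {
  using paddle::inference::analysis::Dot;
  Dot dot;
  dot.AddNode("conv1", {Dot::Attr("shape", "box")});
  dot.AddNode("relu1", {});
  dot.AddEdge("conv1", "relu1", {});
  return dot.Build();  // "digraph G { ... } // end G"
}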
" " : "]"); + } + return ss.str(); + } + }; + + Dot() = default; + + explicit Dot(const std::vector& attrs) : attrs_(attrs) {} + + void AddNode(const std::string& id, + const std::vector& attrs, + std::string label = "") { + CHECK(!nodes_.count(id)) << "duplicate Node '" << id << "'"; + if (label.empty()) label = id; + nodes_.emplace(id, Node{label, attrs}); + } + + void AddEdge(const std::string& source, + const std::string& target, + const std::vector& attrs) { + CHECK(!source.empty()); + CHECK(!target.empty()); + auto sid = nodes_.at(source).id(); + auto tid = nodes_.at(target).id(); + edges_.emplace_back(sid, tid, attrs); + } + + // Compile to DOT language codes. + std::string Build() const { + STL::stringstream ss; + const std::string indent = " "; + ss << "digraph G {" << '\n'; + + // Add graph attrs + for (const auto& attr : attrs_) { + ss << indent << attr.repr() << '\n'; + } + // add nodes + for (auto& item : nodes_) { + ss << indent << item.second.repr() << '\n'; + } + // add edges + for (auto& edge : edges_) { + ss << indent << edge.repr() << '\n'; + } + ss << "} // end G"; + return ss.str(); + } + + private: + std::unordered_map nodes_; + std::vector edges_; + std::vector attrs_; +}; + +} // namespace analysis +} // namespace inference +} // namespace paddle diff --git a/lite/core/mir/elimination/CMakeLists.txt b/lite/core/mir/elimination/CMakeLists.txt new file mode 100644 index 00000000000..9b6598630ba --- /dev/null +++ b/lite/core/mir/elimination/CMakeLists.txt @@ -0,0 +1,10 @@ +if (NOT LITE_WITH_LIGHT_WEIGHT_FRAMEWORK) + # NOTE disabled for the proto_desc is not valid yet. + # TODO(Superjomn) enable them if valid latter. + # lite_cc_test(test_identity_scale_eliminate_pass + # SRCS identity_scale_eliminate_pass_test.cc + # DEPS mir_passes program proto_desc cpp_op_desc + # ${ops} + # ) +endif() + diff --git a/lite/core/mir/elimination/identity_scale_eliminate_pass.cc b/lite/core/mir/elimination/identity_scale_eliminate_pass.cc new file mode 100644 index 00000000000..07d8dfd3f5b --- /dev/null +++ b/lite/core/mir/elimination/identity_scale_eliminate_pass.cc @@ -0,0 +1,72 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/core/mir/pass.h" +#include "lite/core/mir/pass_registry.h" +#include "lite/core/mir/pattern_matcher_high_api.h" + +namespace paddle { +namespace lite { +namespace mir { + +namespace { + +class Eliminator : public FuseBase { + public: + void BuildPattern() override { + auto* pre_op = OpNode("preop"); // the previous op's output need update + // TODO(Superjomn) check has only one output + auto* x = VarNode("x")->assert_is_op_input("scale", "X"); + auto* scale_op = OpNode("scale", "scale") + ->assert_op_attr("scale", 1.) + ->assert_op_attr("bias", 0.); + auto* out = VarNode("out")->assert_is_op_output("scale", "Out"); + + *pre_op >> *x >> *scale_op >> *out; + + // The pre_op will be eliminated, and a new output-updated op will insert. 
+    x->AsIntermediate();  // x is pre_op's output, which needs updating
+  }
+
+ private:
+  void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override {
+    auto& pre_op = matched.at("preop")->AsStmt();
+    auto op_info = *pre_op.op_info();
+
+    op_info.UpdateAllOutputs(matched.at("x")->AsArg().name,
+                             matched.at("out")->AsArg().name);
+    pre_op.ResetOp(op_info, graph->valid_places());
+
+    GraphSafeRemoveNodes(graph, {matched.at("scale")});
+
+    IR_NODE_LINK_TO(matched.at("preop"), matched.at("out"));
+  }
+};
+
+}  // namespace
+
+class IdentityScaleEliminatePass : public ProgramPass {
+ public:
+  void Apply(const std::unique_ptr<SSAGraph>& graph) override {
+    Eliminator eliminator;
+    eliminator(graph.get());
+  }
+};
+
+}  // namespace mir
+}  // namespace lite
+}  // namespace paddle
+
+REGISTER_MIR_PASS(identity_scale_eliminate_pass,
+                  paddle::lite::mir::IdentityScaleEliminatePass);
diff --git a/lite/core/mir/elimination/identity_scale_eliminate_pass_test.cc b/lite/core/mir/elimination/identity_scale_eliminate_pass_test.cc
new file mode 100644
index 00000000000..7130a13c475
--- /dev/null
+++ b/lite/core/mir/elimination/identity_scale_eliminate_pass_test.cc
@@ -0,0 +1,93 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <gtest/gtest.h>
+#include "lite/core/mir/graph_visualize_pass.h"
+#include "lite/core/mir/pass_registry.h"
+#include "lite/core/mir/ssa_graph.h"
+#include "paddle/fluid/framework/program_desc.h"
+
+namespace paddle {
+namespace lite {
+namespace mir {
+
+std::unique_ptr<SSAGraph> BuildGraph(framework::ProgramDesc* program_desc,
+                                     const std::shared_ptr<Scope>& scope,
+                                     const std::vector<Place>& valid_places) {
+  // Op list:
+  //   (x)->feed -> (feed) -> scale -> (scale_out) -> fetch->(fetch)
+  // After the pass:
+  //   (x)->feed->(scale_out)->fetch->(fetch)
+  auto* main_block = program_desc->MutableBlock(0);
+  auto* feed_op = main_block->AppendOp();
+  auto* scale_op = main_block->AppendOp();
+  auto* fetch_op = main_block->AppendOp();
+  main_block->Var("x");
+  main_block->Var("feed");
+  main_block->Var("scale_out");
+  main_block->Var("fetch_out");
+
+  scope->Var("x")->GetMutable<lite::Tensor>();
+  scope->Var("feed")->GetMutable<lite::Tensor>();
+  scope->Var("scale_out")->GetMutable<lite::Tensor>();
+  scope->Var("fetch_out")->GetMutable<lite::Tensor>();
+
+  feed_op->SetType("feed");
+  feed_op->SetInput("X", {"x"});
+  feed_op->SetAttr("col", 1);
+  feed_op->SetOutput("Out", {"feed"});
+
+  scale_op->SetType("scale");
+  scale_op->SetInput("X", {"feed"});
+  scale_op->SetOutput("Out", {"scale_out"});
+  scale_op->SetAttr("scale", 1.f);
+  scale_op->SetAttr("bias", 0.f);
+  scale_op->SetAttr("bias_after_scale", true);
+
+  fetch_op->SetType("fetch");
+  fetch_op->SetInput("X", {"scale_out"});
+  fetch_op->SetOutput("Out", {"fetch"});
+  fetch_op->SetAttr("col", 1);
+
+  program_desc->Flush();
+
+  lite::Program program(*program_desc->Proto(), scope, valid_places);
+  auto graph = std::unique_ptr<SSAGraph>(new SSAGraph());
+  graph->Build(program, valid_places);
+
+  VLOG(5) << Visualize(graph.get());
+
+  return graph;
+}
+
+TEST(identity_test, test) {
+  framework::ProgramDesc program_desc;
+  std::vector<Place> places{{TARGET(kHost), PRECISION(kFloat)}};
+  auto scope = std::make_shared<lite::Scope>();
+  auto graph = BuildGraph(&program_desc, scope, places);
+  const int num_nodes = graph->nodes().size();
+  auto pass = PassManager::Global().LookUp("identity_scale_eliminate_pass");
+  ASSERT_TRUE(pass);
+  pass->Apply(graph);
+  ASSERT_EQ(graph->nodes().size(), num_nodes - 2UL);
+}
+
+}  // namespace mir
+}  // namespace lite
+}  // namespace paddle
+
+USE_LITE_OP(feed)
+USE_LITE_OP(fetch)
+USE_LITE_OP(scale)
+USE_MIR_PASS(identity_scale_eliminate_pass)
diff --git a/lite/core/mir/fusion/CMakeLists.txt b/lite/core/mir/fusion/CMakeLists.txt
new file mode 100644
index 00000000000..27141a5933f
--- /dev/null
+++ b/lite/core/mir/fusion/CMakeLists.txt
@@ -0,0 +1,36 @@
+lite_cc_library(fuse_fc
+                SRCS fc_fuser.cc
+                DEPS pattern_matcher_high_api)
+lite_cc_library(fuse_conv_elementwise
+                SRCS conv_elementwise_fuser.cc
+                DEPS pattern_matcher_high_api)
+lite_cc_library(fuse_conv_activation
+                SRCS conv_activation_fuser.cc
+                DEPS pattern_matcher_high_api)
+lite_cc_library(fuse_conv_bn
+                SRCS conv_bn_fuser.cc
+                DEPS pattern_matcher_high_api)
+lite_cc_library(fuse_elementwise_add_activation
+                SRCS elementwise_add_activation_fuser.cc
+                DEPS pattern_matcher_high_api)
+lite_cc_library(fuse_quant_dequant
+                SRCS quant_dequant_op_fuser.cc
+                DEPS pattern_matcher_high_api)
+
+set(mir_fusers
+    fuse_fc
+    fuse_conv_elementwise
+    fuse_conv_activation
+    fuse_conv_bn
+    fuse_quant_dequant
+    fuse_elementwise_add_activation
+    CACHE INTERNAL "fusers")
+
+if (LITE_WITH_LIGHT_WEIGHT_FRAMEWORK)
+  return()
+endif()
+
+# TODO(Superjomn) Enable it later
+# NOTE disabled because the proto_desc is not valid yet.
+# lite_cc_test(test_lite_conv_bn_fuse SRCS conv_bn_fuse_pass_test.cc
+#    DEPS elementwise_ops batch_norm_op conv_op proto_desc compatible_pb program mir_pass mir_pass_manager pattern_matcher_high_api)
diff --git a/lite/core/mir/fusion/conv_activation_fuse_pass.cc b/lite/core/mir/fusion/conv_activation_fuse_pass.cc
new file mode 100644
index 00000000000..cad98cb26c2
--- /dev/null
+++ b/lite/core/mir/fusion/conv_activation_fuse_pass.cc
@@ -0,0 +1,38 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lite/core/mir/fusion/conv_activation_fuse_pass.h"
+#include <memory>
+#include <vector>
+#include "lite/core/mir/fusion/conv_activation_fuser.h"
+#include "lite/core/mir/pass_registry.h"
+
+namespace paddle {
+namespace lite {
+namespace mir {
+
+void ConvActivationFusePass::Apply(const std::unique_ptr<SSAGraph>& graph) {
+  fusion::ConvActivationFuser fuser("conv2d", "relu");
+  fuser(graph.get());
+
+  fusion::ConvActivationFuser depthwise_fuser("depthwise_conv2d", "relu");
+  depthwise_fuser(graph.get());
+}
+
+}  // namespace mir
+}  // namespace lite
+}  // namespace paddle
+
+REGISTER_MIR_PASS(lite_conv_activation_fuse_pass,
+                  paddle::lite::mir::ConvActivationFusePass);
diff --git a/lite/core/mir/fusion/conv_activation_fuse_pass.h b/lite/core/mir/fusion/conv_activation_fuse_pass.h
new file mode 100644
index 00000000000..e6f0f34be0c
--- /dev/null
+++ b/lite/core/mir/fusion/conv_activation_fuse_pass.h
@@ -0,0 +1,32 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <memory>
+#include <string>
+#include "lite/core/mir/pass.h"
+
+namespace paddle {
+namespace lite {
+namespace mir {
+
+class ConvActivationFusePass : public ProgramPass {
+ public:
+  void Apply(const std::unique_ptr<SSAGraph>& graph) override;
+};
+
+}  // namespace mir
+}  // namespace lite
+}  // namespace paddle
diff --git a/lite/core/mir/fusion/conv_activation_fuser.cc b/lite/core/mir/fusion/conv_activation_fuser.cc
new file mode 100644
index 00000000000..c49a9ad4f0b
--- /dev/null
+++ b/lite/core/mir/fusion/conv_activation_fuser.cc
@@ -0,0 +1,103 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lite/core/mir/fusion/conv_activation_fuser.h"
+#include <memory>
+#include <vector>
+
+namespace paddle {
+namespace lite {
+namespace mir {
+namespace fusion {
+
+void ConvActivationFuser::BuildPattern() {
+  // create input nodes.
+  auto* input =
+      VarNode("input")->assert_is_op_input(conv_type_, "Input")->AsInput();
+  auto* filter =
+      VarNode("filter")->assert_is_op_input(conv_type_, "Filter")->AsInput();
+  auto* bias =
+      VarNode("bias")->assert_is_op_input(conv_type_, "Bias")->AsInput();
+
+  // create op nodes
+  auto* conv2d =
+      OpNode("conv2d", conv_type_)->assert_is_op(conv_type_)->AsIntermediate();
+
+  auto* act =
+      OpNode("act", act_type_)->assert_is_op(act_type_)->AsIntermediate();
+
+  // create intermediate nodes
+  auto* conv2d_out = VarNode("conv2d_out")
+                         ->assert_is_op_output(conv_type_, "Output")
+                         ->assert_is_op_input(act_type_, "X")
+                         ->AsIntermediate();
+
+  // create output node
+  auto* out =
+      VarNode("output")->assert_is_op_output(act_type_, "Out")->AsOutput();
+
+  // create topology.
+  std::vector<PMNode*> conv2d_inputs{filter, input, bias};
+  conv2d_inputs >> *conv2d >> *conv2d_out;
+  *conv2d_out >> *act >> *out;
+}
+
+void ConvActivationFuser::InsertNewNode(SSAGraph* graph,
+                                        const key2nodes_t& matched) {
+  auto op_desc = GenOpDesc(matched);
+  auto conv_op = LiteOpRegistry::Global().Create(conv_type_);
+  auto conv_old = matched.at("conv2d")->stmt()->op();
+  auto* scope = conv_old->scope();
+  auto& valid_places = conv_old->valid_places();
+  conv_op->Attach(op_desc, scope);
+
+  auto* new_op_node = graph->GraphCreateInstructNode(conv_op, valid_places);
+
+  IR_NODE_LINK_TO(matched.at("input"), new_op_node);
+  IR_NODE_LINK_TO(matched.at("filter"), new_op_node);
+  IR_NODE_LINK_TO(matched.at("bias"), new_op_node);
+  IR_NODE_LINK_TO(new_op_node, matched.at("output"));
+}
+
+cpp::OpDesc ConvActivationFuser::GenOpDesc(const key2nodes_t& matched) {
+  auto* desc = matched.at("conv2d")->stmt()->op_info();
+
+  cpp::OpDesc op_desc = *desc;
+  op_desc.SetType(conv_type_);
+  op_desc.SetInput("Input", {matched.at("input")->arg()->name});
+  op_desc.SetInput("Filter", {matched.at("filter")->arg()->name});
+  op_desc.SetInput("Bias", {matched.at("bias")->arg()->name});
+  op_desc.SetOutput("Output", {matched.at("output")->arg()->name});
+  // Other inputs. See operators/conv_op.h
+  std::vector<std::string> input_arg_names = desc->InputArgumentNames();
+
+  if (std::find(input_arg_names.begin(),
+                input_arg_names.end(),
+                "ResidualData") != input_arg_names.end()) {
+    op_desc.SetInput("ResidualData", desc->Input("ResidualData"));
+  }
+  // Only consider strides, padding, groups, dilations, fuse_relu for now
+  op_desc.SetAttr("strides", desc->GetAttr<std::vector<int>>("strides"));
+  op_desc.SetAttr("paddings", desc->GetAttr<std::vector<int>>("paddings"));
+  op_desc.SetAttr("groups", desc->GetAttr<int>("groups"));
+  op_desc.SetAttr("dilations", desc->GetAttr<std::vector<int>>("dilations"));
+  // TODO(sangoly): support other activation types
+  op_desc.SetAttr("fuse_relu", true);
+  return op_desc;
+}
+
+}  // namespace fusion
+}  // namespace mir
+}  // namespace lite
+}  // namespace paddle
diff --git a/lite/core/mir/fusion/conv_activation_fuser.h b/lite/core/mir/fusion/conv_activation_fuser.h
new file mode 100644
index 00000000000..3377e28e29a
--- /dev/null
+++ b/lite/core/mir/fusion/conv_activation_fuser.h
@@ -0,0 +1,47 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <memory>
+#include <string>
+#include "lite/core/mir/pattern_matcher_high_api.h"
+
+namespace paddle {
+namespace lite {
+namespace mir {
+namespace fusion {
+
+class ConvActivationFuser : public FuseBase {
+ public:
+  explicit ConvActivationFuser(const std::string& conv_type,
+                               const std::string& act_type) {
+    CHECK(act_type == "relu") << "Only the relu activation is supported now";
+    conv_type_ = conv_type;
+    act_type_ = act_type;
+  }
+
+  void BuildPattern() override;
+  void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override;
+
+ private:
+  cpp::OpDesc GenOpDesc(const key2nodes_t& matched) override;
+  std::string conv_type_;
+  std::string act_type_;
+};
+
+}  // namespace fusion
+}  // namespace mir
+}  // namespace lite
+}  // namespace paddle
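The fuser above only rewires the graph and sets fuse_relu; the arithmetic effect, which the conv kernels implement, is a relu epilogue applied per output element. A reference sketch of that epilogue under those assumptions (the function name is made up):

#include <algorithm>
#include <cstddef>

// After fusion, y = relu(conv(x, w) + b): the activation becomes part of
// the conv kernel's epilogue instead of a separate op in the graph.
void ReluEpilogue(float* out, size_t n) {
  for (size_t i = 0; i < n; i++) {
    out[i] = std::max(out[i], 0.f);
  }
}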
diff --git a/lite/core/mir/fusion/conv_bn_fuse_pass.cc b/lite/core/mir/fusion/conv_bn_fuse_pass.cc
new file mode 100644
index 00000000000..954e007a850
--- /dev/null
+++ b/lite/core/mir/fusion/conv_bn_fuse_pass.cc
@@ -0,0 +1,37 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lite/core/mir/fusion/conv_bn_fuse_pass.h"
+#include <memory>
+#include <vector>
+#include "lite/core/mir/fusion/conv_bn_fuser.h"
+#include "lite/core/mir/pass_registry.h"
+
+namespace paddle {
+namespace lite {
+namespace mir {
+
+void ConvBNFusePass::Apply(const std::unique_ptr<SSAGraph>& graph) {
+  fusion::ConvBNFuser fuser("conv2d");
+  fuser(graph.get());
+
+  fusion::ConvBNFuser fuser2("depthwise_conv2d");
+  fuser2(graph.get());
+}
+
+}  // namespace mir
+}  // namespace lite
+}  // namespace paddle
+
+REGISTER_MIR_PASS(lite_conv_bn_fuse_pass, paddle::lite::mir::ConvBNFusePass);
diff --git a/lite/core/mir/fusion/conv_bn_fuse_pass.h b/lite/core/mir/fusion/conv_bn_fuse_pass.h
new file mode 100644
index 00000000000..b2c56d18022
--- /dev/null
+++ b/lite/core/mir/fusion/conv_bn_fuse_pass.h
@@ -0,0 +1,32 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <memory>
+#include <string>
+#include "lite/core/mir/pass.h"
+
+namespace paddle {
+namespace lite {
+namespace mir {
+
+class ConvBNFusePass : public ProgramPass {
+ public:
+  void Apply(const std::unique_ptr<SSAGraph>& graph) override;
+};
+
+}  // namespace mir
+}  // namespace lite
+}  // namespace paddle
diff --git a/lite/core/mir/fusion/conv_bn_fuse_pass_test.cc b/lite/core/mir/fusion/conv_bn_fuse_pass_test.cc
new file mode 100644
index 00000000000..7e720bcc3de
--- /dev/null
+++ b/lite/core/mir/fusion/conv_bn_fuse_pass_test.cc
@@ -0,0 +1,140 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lite/core/mir/fusion/conv_bn_fuse_pass.h"
+#include <gtest/gtest.h>
+#include <memory>
+#include <vector>
+#include "lite/core/mir/graph_visualize_pass.h"
+#include "lite/core/program.h"
+#include "lite/core/tensor.h"
+#include "paddle/fluid/framework/program_desc.h"
+
+namespace paddle {
+namespace lite {
+namespace mir {
+namespace fusion {
+
+std::unique_ptr<SSAGraph> BuildGraph(framework::ProgramDesc* program_desc,
+                                     const std::shared_ptr<Scope>& scope,
+                                     const std::vector<Place>& valid_places) {
+  auto* main_block = program_desc->MutableBlock(0);
+  auto* conv_op = main_block->AppendOp();
+  auto* bn_op = main_block->AppendOp();
+  main_block->Var("conv_i");
+  main_block->Var("conv_param");
+  main_block->Var("conv_out");
+
+  main_block->Var("bn_scale");
+  main_block->Var("bn_bias");
+  main_block->Var("bn_mean");
+  main_block->Var("bn_var");
+  main_block->Var("bn_out");
+  main_block->Var("bn_mean_out");
+  main_block->Var("bn_var_out");
+  main_block->Var("bn_saved_mean");
+  main_block->Var("bn_saved_var");
+
+  scope->Var("conv_i")->GetMutable<lite::Tensor>();
+  auto conv_param_t = scope->Var("conv_param")->GetMutable<lite::Tensor>();
+  std::vector<int64_t> conv_param_shape = {3, 1, 2, 2};
+  conv_param_t->Resize(lite::DDim(conv_param_shape));
+  conv_param_t->mutable_data<float>();
+  scope->Var("conv_out")->GetMutable<lite::Tensor>();
+  auto bn_scale_t = scope->Var("bn_scale")->GetMutable<lite::Tensor>();
+  std::vector<int64_t> bn_scale_shape = {3};
+  bn_scale_t->Resize(lite::DDim(bn_scale_shape));
+  bn_scale_t->mutable_data<float>();
+
+  auto bn_bias_t = scope->Var("bn_bias")->GetMutable<lite::Tensor>();
+  std::vector<int64_t> bn_bias_shape = {3};
+  bn_bias_t->Resize(lite::DDim(bn_bias_shape));
+  bn_bias_t->mutable_data<float>();
+
+  auto bn_mean_t = scope->Var("bn_mean")->GetMutable<lite::Tensor>();
+  bn_mean_t->Resize(lite::DDim(bn_bias_shape));
+  bn_mean_t->mutable_data<float>();
+
+  auto bn_var_t = scope->Var("bn_var")->GetMutable<lite::Tensor>();
+  bn_var_t->Resize(lite::DDim(bn_bias_shape));
+  bn_var_t->mutable_data<float>();
+
+  scope->Var("bn_out")->GetMutable<lite::Tensor>();
+  scope->Var("bn_mean_out")->GetMutable<lite::Tensor>();
+  scope->Var("bn_var_out")->GetMutable<lite::Tensor>();
+  scope->Var("bn_saved_mean")->GetMutable<lite::Tensor>();
+  scope->Var("bn_saved_var")->GetMutable<lite::Tensor>();
+
+  conv_op->SetType("conv2d");
+  conv_op->SetInput("Input", {"conv_i"});
+  conv_op->SetInput("Filter", {"conv_param"});
+  conv_op->SetOutput("Output", {"conv_out"});
+  const std::vector<int> strides({1, 1});
+  const std::vector<int> paddings({1, 1});
+  const std::vector<int> dilations({1, 1});
+  const int groups = 1;
+  conv_op->SetAttr("strides", strides);
+  conv_op->SetAttr("paddings", paddings);
+  conv_op->SetAttr("dilations", dilations);
+  conv_op->SetAttr("groups", groups);
+  conv_op->SetAttr("fuse_relu", false);
+
+  bn_op->SetType("batch_norm");
+  bn_op->SetInput("X", {"conv_out"});
+  bn_op->SetInput("Bias", {"bn_bias"});
+  bn_op->SetInput("Mean", {"bn_mean"});
+  bn_op->SetInput("Scale", {"bn_scale"});
+  bn_op->SetInput("Variance", {"bn_var"});
+
+  bn_op->SetOutput("Y", {"bn_out"});
+  bn_op->SetOutput("MeanOut", {"bn_mean_out"});
+  bn_op->SetOutput("VarianceOut", {"bn_var_out"});
+  bn_op->SetOutput("SavedMean", {"bn_saved_mean"});
+  bn_op->SetOutput("SavedVariance", {"bn_saved_var"});
+  float eps = 1e-5;
+  bn_op->SetAttr("epsilon", eps);
+  bn_op->SetAttr("is_test", static_cast<int>(1));
+  bn_op->SetAttr("use_global_stats", false);
+  bn_op->SetAttr("momentum", 0.9f);
+  bn_op->SetAttr("data_layout", std::string("NCHW"));
+
+  program_desc->Flush();
+
+  lite::Program program(*program_desc->Proto(), scope, valid_places);
+  auto graph = std::unique_ptr<SSAGraph>(new SSAGraph());
+  graph->Build(program, valid_places);
+
+  return graph;
+}
+
+TEST(pattern_matcher2, test) {
+  framework::ProgramDesc program_desc;
+  std::vector<Place> places{{TARGET(kHost), PRECISION(kFloat)}};
+  auto scope = std::make_shared<lite::Scope>();
+  auto graph = BuildGraph(&program_desc, scope, places);
+  const int num_nodes = graph->nodes().size();
+  auto* fuser = new ConvBNFusePass;
+  fuser->Apply(graph);
+  ASSERT_EQ(graph->nodes().size(),
+            num_nodes - 8UL /* nodes removed */ + 1UL /* eltwise_add node */);
+}
+
+}  // namespace fusion
+}  // namespace mir
+}  // namespace lite
+}  // namespace paddle
+
+USE_LITE_OP(conv2d);
+USE_LITE_OP(batch_norm);
+USE_LITE_OP(elementwise_add);
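What the fuser below implements, algebraically: with y = conv(x, w) and batch-norm parameters (scale, bias, mean, var), inference-time BN is bn(y) = scale * (y - mean) / sqrt(var + eps) + bias, which is affine in y: bn(y) = alpha * y + beta. The fuser scales the conv weights of each output channel by alpha[i], folds beta into the BN bias, and keeps a trailing elementwise_add to apply it. A standalone sketch of the per-channel factors, mirroring ComputeAlphaAndBeta in conv_bn_fuser.h (the free function is illustrative):

#include <cmath>
#include <vector>

// Per-channel folding: bn(conv(x)) == alpha * conv(x) + beta, so channel i's
// conv weights are multiplied by alpha[i] and beta[i] becomes part of the
// bias that the inserted elementwise_add applies.
void FoldBN(const std::vector<float>& scale, const std::vector<float>& mean,
            const std::vector<float>& var, float eps,
            std::vector<float>* alpha, std::vector<float>* beta) {
  for (size_t i = 0; i < scale.size(); i++) {
    (*alpha)[i] = scale[i] / std::sqrt(var[i] + eps);
    (*beta)[i] = -mean[i] * (*alpha)[i];
  }
}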
diff --git a/lite/core/mir/fusion/conv_bn_fuser.cc b/lite/core/mir/fusion/conv_bn_fuser.cc
new file mode 100644
index 00000000000..77ad8237fe8
--- /dev/null
+++ b/lite/core/mir/fusion/conv_bn_fuser.cc
@@ -0,0 +1,163 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lite/core/mir/fusion/conv_bn_fuser.h"
+#include <memory>
+#include <vector>
+
+namespace paddle {
+namespace lite {
+namespace mir {
+namespace fusion {
+
+void ConvBNFuser::BuildPattern() {
+  auto* conv_input = VarNode("conv_input")
+                         ->assert_is_op_input(conv_type_, "Input")
+                         ->AsInput();
+  auto* conv_weight = VarNode("conv_weight")
+                          ->assert_is_op_input(conv_type_, "Filter")
+                          ->AsInput();
+  auto* conv = OpNode("conv2d", conv_type_)->assert_is_op(conv_type_);
+  auto* conv_out = VarNode("conv_out")
+                       ->assert_is_op_output(conv_type_, "Output")
+                       ->assert_is_op_input("batch_norm", "X");
+
+  auto* bn_scale = VarNode("bn_scale")
+                       ->assert_is_op_input("batch_norm", "Scale")
+                       ->AsIntermediate();
+  auto* bn_bias =
+      VarNode("bn_bias")->assert_is_op_input("batch_norm", "Bias")->AsInput();
+  auto* bn_mean = VarNode("bn_mean")
+                      ->assert_is_op_input("batch_norm", "Mean")
+                      ->AsIntermediate();
+  auto* bn_var = VarNode("bn_variance")
+                     ->assert_is_op_input("batch_norm", "Variance")
+                     ->AsIntermediate();
+  auto* bn =
+      OpNode("bn", "batch_norm")->assert_is_op("batch_norm")->AsIntermediate();
+
+  auto* bn_out =
+      VarNode("bn_out")->assert_is_op_output("batch_norm", "Y")->AsOutput();
+  auto* bn_mean_out = VarNode("bn_mean_out")
+                          ->assert_is_op_output("batch_norm", "MeanOut")
+                          ->AsIntermediate();
+  auto* bn_var_out = VarNode("bn_var_out")
+                         ->assert_is_op_output("batch_norm", "VarianceOut")
+                         ->AsIntermediate();
+  auto* bn_saved_mean = VarNode("bn_saved_mean")
+                            ->assert_is_op_output("batch_norm", "SavedMean")
+                            ->AsIntermediate();
+  auto* bn_saved_var = VarNode("bn_saved_var")
+                           ->assert_is_op_output("batch_norm", "SavedVariance")
+                           ->AsIntermediate();
+
+  conv->LinksFrom({conv_input, conv_weight}).LinksTo({conv_out});
+
+  bn->LinksFrom({conv_out, bn_scale, bn_bias, bn_mean, bn_var})
+      .LinksTo({bn_out, bn_mean_out, bn_saved_mean, bn_saved_var, bn_var_out});
+}
+
+void ConvBNFuser::InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) {
+  auto op_desc = GenOpDesc(matched);
+  auto eltwise_op = LiteOpRegistry::Global().Create("elementwise_add");
+
+  auto conv_instruct = matched.at("conv2d")->stmt();
+  auto conv = conv_instruct->op();
+  auto* scope = conv->scope();
+  auto& valid_places = conv->valid_places();
+
+  auto conv_weight_t = scope->FindVar(matched.at("conv_weight")->arg()->name)
+                           ->GetMutable<lite::Tensor>();
+  auto conv_weight_dims = conv_weight_t->dims();
+  size_t weight_num = conv_weight_t->data_size();
+
+  auto bn_scale_t = scope->FindVar(matched.at("bn_scale")->arg()->name)
+                        ->GetMutable<lite::Tensor>();
+  size_t bias_size = bn_scale_t->data_size();
+  auto bn_scale_d = bn_scale_t->mutable_data<float>();
+  CHECK_EQ(bias_size, static_cast<size_t>(conv_weight_dims[0]))
+      << "The BN bias's size should be equal to the size of the first "
+      << "dim of the conv weights";
+
+  auto bn_mean_t = scope->FindVar(matched.at("bn_mean")->arg()->name)
+                       ->GetMutable<lite::Tensor>();
+  auto bn_mean_d = bn_mean_t->mutable_data<float>();
+
+  auto bn_var_t = scope->FindVar(matched.at("bn_variance")->arg()->name)
+                      ->GetMutable<lite::Tensor>();
+  auto bn_var_d = bn_var_t->mutable_data<float>();
+
+  auto bn_bias_t = scope->FindVar(matched.at("bn_bias")->arg()->name)
+                       ->GetMutable<lite::Tensor>();
+  auto bn_bias_d = bn_bias_t->mutable_data<float>();
+  auto eps = matched.at("bn")->stmt()->op_info()->GetAttr<float>("epsilon");
+
+  auto conv_op_desc = conv_instruct->mutable_op_info();
+
+  bool enable_int8 = conv_op_desc->HasAttr("enable_int8") ? true : false;
+  Tensor alpha_tensor, beta_tensor;
+  alpha_tensor.CopyDataFrom(*bn_bias_t);
+  beta_tensor.CopyDataFrom(*bn_bias_t);
+  auto alpha_data = alpha_tensor.mutable_data<float>();
+  auto beta_data = beta_tensor.mutable_data<float>();
+
+  int h = bias_size;
+  int w = weight_num / bias_size;
+  ComputeAlphaAndBeta(
+      bn_scale_d, bn_mean_d, bn_var_d, alpha_data, beta_data, eps, h, w);
+
+  if (enable_int8) {
+    PADDLE_ENFORCE(conv_op_desc->HasAttr("weight_scale"),
+                   "INT8 mode: Conv should have the weight_scale attr");
+    auto weight_scale =
+        conv_op_desc->GetAttr<std::vector<float>>("weight_scale");
+    for (int i = 0; i < h; i++) {
+      weight_scale[i] *= alpha_data[i];
+    }
+    // An interface like this should be abandoned.
+    conv_op_desc->SetAttr("weight_scale", weight_scale);
+    auto update_conv_desc = *conv_instruct->mutable_op_info();
+    conv_instruct->ResetOp(update_conv_desc, graph->valid_places());
+  } else {
+    auto conv_weight_d = conv_weight_t->mutable_data<float>();
+    for (int i = 0; i < h; i++) {
+      for (int j = 0; j < w; j++) {
+        conv_weight_d[i * w + j] *= alpha_data[i];
+      }
+    }
+  }
+  for (size_t i = 0; i < bias_size; i++) {
+    bn_bias_d[i] += beta_data[i];
+  }
+  eltwise_op->Attach(op_desc, scope);
+  auto* new_op_node = graph->GraphCreateInstructNode(eltwise_op, valid_places);
+
+  IR_NODE_LINK_TO(matched.at("conv_out"), new_op_node);
+  IR_NODE_LINK_TO(matched.at("bn_bias"), new_op_node);
+  IR_NODE_LINK_TO(new_op_node, matched.at("bn_out"));
+}
+
+cpp::OpDesc ConvBNFuser::GenOpDesc(const key2nodes_t& matched) {
+  cpp::OpDesc op_desc;
+  op_desc.SetType("elementwise_add");
+  op_desc.SetInput("X", {matched.at("conv_out")->arg()->name});
+  op_desc.SetInput("Y", {matched.at("bn_bias")->arg()->name});
+  op_desc.SetOutput("Out", {matched.at("bn_out")->arg()->name});
+  op_desc.SetAttr("axis", 1);
+  return op_desc;
+}
+
+}  // namespace fusion
+}  // namespace mir
+}  // namespace lite
+}  // namespace paddle
diff --git a/lite/core/mir/fusion/conv_bn_fuser.h b/lite/core/mir/fusion/conv_bn_fuser.h
new file mode 100644
index 00000000000..9acf65f9e21
--- /dev/null
+++ b/lite/core/mir/fusion/conv_bn_fuser.h
@@ -0,0 +1,58 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <cmath>
+#include <string>
+#include "lite/core/mir/pattern_matcher_high_api.h"
+#include "lite/utils/paddle_enforce.h"
+
+namespace paddle {
+namespace lite {
+namespace mir {
+namespace fusion {
+
+class ConvBNFuser : public FuseBase {
+ public:
+  explicit ConvBNFuser(const std::string& conv_type) : conv_type_(conv_type) {}
+  void BuildPattern() override;
+  void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override;
+
+ private:
+  cpp::OpDesc GenOpDesc(const key2nodes_t& matched) override;
+  void ComputeAlphaAndBeta(float* scale_d,
+                           float* mean_d,
+                           float* var_d,
+                           float* alpha,
+                           float* beta,
+                           float eps,
+                           int h,
+                           int w) {
+    for (int i = 0; i < h; i++) {
+      alpha[i] = scale_d[i] / std::sqrt(var_d[i] + eps);
+    }
+    for (int i = 0; i < h; i++) {
+      beta[i] = (-mean_d[i]) * alpha[i];
+    }
+  }
+
+ private:
+  std::string conv_type_{"conv2d"};
+};
+
+}  // namespace fusion
+}  // namespace mir
+}  // namespace lite
+}  // namespace paddle
diff --git a/lite/core/mir/fusion/conv_elementwise_add_activation_fuse_pass_test.cc b/lite/core/mir/fusion/conv_elementwise_add_activation_fuse_pass_test.cc
new file mode 100644
index 00000000000..59bf7035e79
--- /dev/null
+++ b/lite/core/mir/fusion/conv_elementwise_add_activation_fuse_pass_test.cc
@@ -0,0 +1,157 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+ +#include +#include +#include +#include "lite/api/cxx_api.h" +#include "lite/api/paddle_use_passes.h" +#include "lite/core/mir/fusion/conv_activation_fuse_pass.h" +#include "lite/core/mir/fusion/conv_elementwise_fuse_pass.h" +#include "lite/core/mir/graph_visualize_pass.h" +#include "lite/core/op_registry.h" +#include "lite/core/program.h" +#include "lite/core/tensor.h" +#include "paddle/fluid/framework/program_desc.h" + +DEFINE_string(model_dir, "", ""); +DEFINE_string(optimized_model, "", ""); + +namespace paddle { +namespace lite { +namespace mir { +namespace fusion { + +std::unique_ptr BuildGraph(framework::ProgramDesc* program_desc, + const std::shared_ptr& scope, + const std::vector& valid_places) { + auto* main_block = program_desc->MutableBlock(0); + + auto* conv2d_1 = main_block->AppendOp(); + auto* conv2d_2 = main_block->AppendOp(); + auto* add_1 = main_block->AppendOp(); + auto* relu_1 = main_block->AppendOp(); + auto* add_2 = main_block->AppendOp(); + auto* relu_2 = main_block->AppendOp(); + + main_block->Var("input_1"); + main_block->Var("input_2"); + main_block->Var("filter_1"); + main_block->Var("filter_2"); + main_block->Var("conv2d_1_out"); + main_block->Var("conv2d_2_out"); + main_block->Var("bias_1"); + main_block->Var("add_1_out"); + main_block->Var("add_2_out"); + main_block->Var("relu_1_out"); + main_block->Var("out"); + + scope->Var("input_1")->GetMutable(); + scope->Var("input_2")->GetMutable(); + scope->Var("filter_1")->GetMutable(); + scope->Var("filter_2")->GetMutable(); + scope->Var("conv2d_1_out")->GetMutable(); + scope->Var("conv2d_2_out")->GetMutable(); + scope->Var("bias_1")->GetMutable(); + scope->Var("add_1_out")->GetMutable(); + scope->Var("add_2_out")->GetMutable(); + scope->Var("relu_1_out")->GetMutable(); + scope->Var("out")->GetMutable(); + + conv2d_1->SetType("conv2d"); + conv2d_1->SetInput("Input", {"input_1"}); + conv2d_1->SetInput("Filter", {"filter_1"}); + conv2d_1->SetOutput("Output", {"conv2d_1_out"}); + conv2d_1->SetAttr("strides", std::vector({1, 1})); + conv2d_1->SetAttr("paddings", std::vector({0, 0})); + conv2d_1->SetAttr("groups", 1); + conv2d_1->SetAttr("dilations", std::vector({1, 1})); + conv2d_1->SetAttr("fuse_relu", false); + + add_1->SetType("elementwise_add"); + add_1->SetInput("X", {"conv2d_1_out"}); + add_1->SetInput("Y", {"bias_1"}); + add_1->SetOutput("Out", {"add_1_out"}); + add_1->SetAttr("axis", 1); + + relu_1->SetType("relu"); + relu_1->SetInput("X", {"add_1_out"}); + relu_1->SetOutput("Out", {"relu_1_out"}); + + conv2d_2->SetType("conv2d"); + conv2d_2->SetInput("Input", {"input_2"}); + conv2d_2->SetInput("Filter", {"filter_2"}); + conv2d_2->SetOutput("Output", {"conv2d_2_out"}); + conv2d_2->SetAttr("strides", std::vector({1, 1})); + conv2d_2->SetAttr("paddings", std::vector({0, 0})); + conv2d_2->SetAttr("groups", 1); + conv2d_2->SetAttr("dilations", std::vector({1, 1})); + conv2d_2->SetAttr("fuse_relu", false); + + add_2->SetType("elementwise_add"); + add_2->SetInput("X", {"conv2d_2_out"}); + add_2->SetInput("Y", {"relu_1_out"}); + add_2->SetOutput("Out", {"add_2_out"}); + add_2->SetAttr("axis", 1); + + relu_2->SetType("relu"); + relu_2->SetInput("X", {"add_2_out"}); + relu_2->SetOutput("Out", {"out"}); + + program_desc->Flush(); + + lite::Program program(*program_desc->Proto(), scope, valid_places); + auto graph = std::unique_ptr(new SSAGraph()); + graph->Build(program, valid_places); + + return graph; +} + +TEST(conv_elementwise_add_relu_fuse_pass, graph_test) { + framework::ProgramDesc program_desc; + std::vector 
places{{TARGET(kHost), PRECISION(kFloat)}}; + auto scope = std::make_shared(); + auto graph = BuildGraph(&program_desc, scope, places); + + Visualize(graph.get()); + ASSERT_EQ(graph->nodes().size(), 11UL /*vars*/ + 6UL /*ops*/); + Visualize(graph.get()); +} + +TEST(conv_elementwise_add_relu_fuse_pass, fuse_test_op) { + framework::ProgramDesc program_desc; + std::vector places{{TARGET(kHost), PRECISION(kFloat)}}; + auto scope = std::make_shared(); + auto graph = BuildGraph(&program_desc, scope, places); + Visualize(graph.get()); + const int num_nodes = graph->nodes().size(); + auto* fuser_eltwise = new ConvElementwiseFusePass; + auto* fuser_act = new ConvActivationFusePass; + fuser_eltwise->Apply(graph); + fuser_act->Apply(graph); + + Visualize(graph.get()); + ASSERT_EQ(graph->nodes().size(), + num_nodes - 5UL * 2 /*nodes removed */ + 1UL * 2 /* fused nodes*/); +} + +} // namespace fusion +} // namespace mir +} // namespace lite +} // namespace paddle + +USE_LITE_OP(elementwise_add); +USE_LITE_OP(conv2d); +USE_LITE_OP(depthwise_conv2d); +USE_LITE_OP(relu); diff --git a/lite/core/mir/fusion/conv_elementwise_fuse_pass.cc b/lite/core/mir/fusion/conv_elementwise_fuse_pass.cc new file mode 100644 index 00000000000..a3040d11894 --- /dev/null +++ b/lite/core/mir/fusion/conv_elementwise_fuse_pass.cc @@ -0,0 +1,38 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/core/mir/fusion/conv_elementwise_fuse_pass.h" +#include +#include +#include "lite/core/mir/fusion/conv_elementwise_fuser.h" +#include "lite/core/mir/pass_registry.h" + +namespace paddle { +namespace lite { +namespace mir { + +void ConvElementwiseFusePass::Apply(const std::unique_ptr& graph) { + fusion::ConvElementwiseFuser fuser("conv2d"); + fuser(graph.get()); + + fusion::ConvElementwiseFuser depthwise_fuser("depthwise_conv2d"); + depthwise_fuser(graph.get()); +} + +} // namespace mir +} // namespace lite +} // namespace paddle + +REGISTER_MIR_PASS(lite_conv_elementwise_fuse_pass, + paddle::lite::mir::ConvElementwiseFusePass); diff --git a/lite/core/mir/fusion/conv_elementwise_fuse_pass.h b/lite/core/mir/fusion/conv_elementwise_fuse_pass.h new file mode 100644 index 00000000000..11953e9b10e --- /dev/null +++ b/lite/core/mir/fusion/conv_elementwise_fuse_pass.h @@ -0,0 +1,32 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
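[Editor's note] The node arithmetic asserted by fuse_test_op above works out
as follows for each of the two conv branches (counts read off the test graph,
no new behavior assumed):

    // removed: conv2d op, elementwise_add op, relu op,
    //          conv2d_N_out var, add_N_out var                  -> 5 nodes
    // added:   one conv2d op with Bias input and fused activation -> 1 node
    // hence:   nodes_after == nodes_before - 5 * 2 + 1 * 2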
+ +#pragma once + +#include +#include +#include "lite/core/mir/pass.h" + +namespace paddle { +namespace lite { +namespace mir { + +class ConvElementwiseFusePass : public ProgramPass { + public: + void Apply(const std::unique_ptr& graph) override; +}; + +} // namespace mir +} // namespace lite +} // namespace paddle diff --git a/lite/core/mir/fusion/conv_elementwise_fuser.cc b/lite/core/mir/fusion/conv_elementwise_fuser.cc new file mode 100644 index 00000000000..c3ab3e4c4ca --- /dev/null +++ b/lite/core/mir/fusion/conv_elementwise_fuser.cc @@ -0,0 +1,102 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/core/mir/fusion/conv_elementwise_fuser.h" +#include +#include + +namespace paddle { +namespace lite { +namespace mir { +namespace fusion { + +void ConvElementwiseFuser::BuildPattern() { + // create input nodes. + auto* input = + VarNode("input")->assert_is_op_input(conv_type_, "Input")->AsInput(); + auto* filter = + VarNode("filter")->assert_is_op_input(conv_type_, "Filter")->AsInput(); + auto* bias = + VarNode("bias")->assert_is_op_input("elementwise_add", "Y")->AsInput(); + + // create op nodes + auto* conv2d = + OpNode("conv2d", conv_type_)->assert_is_op(conv_type_)->AsIntermediate(); + auto* add = OpNode("add", "elementwise_add") + ->assert_is_op("elementwise_add") + ->AsIntermediate(); + + // create intermediate nodes + auto* conv2d_out = VarNode("conv2d_out") + ->assert_is_op_output(conv_type_, "Output") + ->assert_is_op_input("elementwise_add", "X") + ->AsIntermediate(); + // create output node + auto* add_out = VarNode("output") + ->assert_is_op_output("elementwise_add", "Out") + ->AsOutput(); + + // create topology. 
+ std::vector conv2d_inputs{filter, input}; + std::vector add_inputs{conv2d_out, bias}; + conv2d_inputs >> *conv2d >> *conv2d_out; + add_inputs >> *add >> *add_out; +} + +void ConvElementwiseFuser::InsertNewNode(SSAGraph* graph, + const key2nodes_t& matched) { + auto op_desc = GenOpDesc(matched); + auto conv_op = LiteOpRegistry::Global().Create(conv_type_); + auto conv_old = matched.at("conv2d")->stmt()->op(); + auto* scope = conv_old->scope(); + auto& valid_places = conv_old->valid_places(); + conv_op->Attach(op_desc, scope); + + auto* new_op_node = graph->GraphCreateInstructNode(conv_op, valid_places); + + IR_NODE_LINK_TO(matched.at("input"), new_op_node); + IR_NODE_LINK_TO(matched.at("filter"), new_op_node); + IR_NODE_LINK_TO(matched.at("bias"), new_op_node); + IR_NODE_LINK_TO(new_op_node, matched.at("output")); +} + +cpp::OpDesc ConvElementwiseFuser::GenOpDesc(const key2nodes_t& matched) { + auto* desc = matched.at("conv2d")->stmt()->op_info(); + + cpp::OpDesc op_desc = *desc; + op_desc.SetType(conv_type_); + op_desc.SetInput("Input", {matched.at("input")->arg()->name}); + op_desc.SetInput("Filter", {matched.at("filter")->arg()->name}); + op_desc.SetInput("Bias", {matched.at("bias")->arg()->name}); + op_desc.SetOutput("Output", {matched.at("output")->arg()->name}); + // Other inputs. See operators/conv_op.h + std::vector input_arg_names = desc->InputArgumentNames(); + + if (std::find(input_arg_names.begin(), + input_arg_names.end(), + "ResidualData") != input_arg_names.end()) { + op_desc.SetInput("ResidualData", desc->Input("ResidualData")); + } + // Only consider strides, padding, groups, dilations for now + op_desc.SetAttr("strides", desc->GetAttr>("strides")); + op_desc.SetAttr("paddings", desc->GetAttr>("paddings")); + op_desc.SetAttr("groups", desc->GetAttr("groups")); + op_desc.SetAttr("dilations", desc->GetAttr>("dilations")); + return op_desc; +} + +} // namespace fusion +} // namespace mir +} // namespace lite +} // namespace paddle diff --git a/lite/core/mir/fusion/conv_elementwise_fuser.h b/lite/core/mir/fusion/conv_elementwise_fuser.h new file mode 100644 index 00000000000..4514fc5010b --- /dev/null +++ b/lite/core/mir/fusion/conv_elementwise_fuser.h @@ -0,0 +1,43 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include +#include +#include "lite/core/mir/pattern_matcher_high_api.h" + +namespace paddle { +namespace lite { +namespace mir { +namespace fusion { + +class ConvElementwiseFuser : public FuseBase { + public: + explicit ConvElementwiseFuser(const std::string& conv_type) { + conv_type_ = conv_type; + } + + void BuildPattern() override; + void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override; + + private: + cpp::OpDesc GenOpDesc(const key2nodes_t& matched) override; + std::string conv_type_; +}; + +} // namespace fusion +} // namespace mir +} // namespace lite +} // namespace paddle diff --git a/lite/core/mir/fusion/elementwise_add_activation_fuse_pass.cc b/lite/core/mir/fusion/elementwise_add_activation_fuse_pass.cc new file mode 100644 index 00000000000..33223cb140c --- /dev/null +++ b/lite/core/mir/fusion/elementwise_add_activation_fuse_pass.cc @@ -0,0 +1,36 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/core/mir/fusion/elementwise_add_activation_fuse_pass.h" +#include +#include +#include "lite/core/mir/fusion/elementwise_add_activation_fuser.h" +#include "lite/core/mir/pass_registry.h" + +namespace paddle { +namespace lite { +namespace mir { + +void ElementwiseAddActivationFusePass::Apply( + const std::unique_ptr& graph) { + fusion::ElementwiseAddActivationFuser fuser("relu"); + fuser(graph.get()); +} + +} // namespace mir +} // namespace lite +} // namespace paddle + +REGISTER_MIR_PASS(lite_elementwise_add_activation_fuse_pass, + paddle::lite::mir::ElementwiseAddActivationFusePass); diff --git a/lite/core/mir/fusion/elementwise_add_activation_fuse_pass.h b/lite/core/mir/fusion/elementwise_add_activation_fuse_pass.h new file mode 100644 index 00000000000..299b6b89a07 --- /dev/null +++ b/lite/core/mir/fusion/elementwise_add_activation_fuse_pass.h @@ -0,0 +1,32 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
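[Editor's note] ElementwiseAddActivationFusePass above instantiates the fuser
only for "relu", although the fuser itself is parameterized by activation
type. A hedged sketch of how the pass body could cover more activations,
assuming fused kernels for those act_types exist (they are not part of this
patch):

    void ElementwiseAddActivationFusePass::Apply(
        const std::unique_ptr<SSAGraph>& graph) {
      // One fuser instance per supported activation type.
      for (auto act_type : {"relu", "sigmoid", "tanh"}) {  // assumed set
        fusion::ElementwiseAddActivationFuser fuser(act_type);
        fuser(graph.get());
      }
    }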
+ +#pragma once + +#include +#include +#include "lite/core/mir/pass.h" + +namespace paddle { +namespace lite { +namespace mir { + +class ElementwiseAddActivationFusePass : public ProgramPass { + public: + void Apply(const std::unique_ptr& graph) override; +}; + +} // namespace mir +} // namespace lite +} // namespace paddle diff --git a/lite/core/mir/fusion/elementwise_add_activation_fuse_pass_test.cc b/lite/core/mir/fusion/elementwise_add_activation_fuse_pass_test.cc new file mode 100644 index 00000000000..ca5127db168 --- /dev/null +++ b/lite/core/mir/fusion/elementwise_add_activation_fuse_pass_test.cc @@ -0,0 +1,117 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/core/mir/fusion/elementwise_add_activation_fuse_pass.h" +#include +#include +#include +#include "lite/api/cxx_api.h" +#include "lite/api/paddle_use_passes.h" +#include "lite/core/mir/graph_visualize_pass.h" +#include "lite/core/op_registry.h" +#include "lite/core/program.h" +#include "lite/core/tensor.h" +#include "paddle/fluid/framework/program_desc.h" + +namespace paddle { +namespace lite { +namespace mir { +namespace fusion { + +std::unique_ptr BuildGraph(framework::ProgramDesc* program_desc, + const std::shared_ptr& scope, + const std::vector& valid_places) { + auto* main_block = program_desc->MutableBlock(0); + + auto* add_1 = main_block->AppendOp(); + auto* add_2 = main_block->AppendOp(); + auto* relu_1 = main_block->AppendOp(); + auto* relu_2 = main_block->AppendOp(); + + main_block->Var("x_1"); + main_block->Var("y_1"); + main_block->Var("add_out_1"); + main_block->Var("relu_out_1"); + main_block->Var("y_2"); + main_block->Var("add_out_2"); + main_block->Var("out"); + + scope->Var("x_1")->GetMutable(); + scope->Var("y_1")->GetMutable(); + scope->Var("add_out_1")->GetMutable(); + scope->Var("relu_out_1")->GetMutable(); + scope->Var("y_2")->GetMutable(); + scope->Var("add_out_2")->GetMutable(); + scope->Var("out")->GetMutable(); + + add_1->SetType("elementwise_add"); + add_1->SetInput("X", {"x_1"}); + add_1->SetInput("Y", {"y_1"}); + add_1->SetOutput("Out", {"add_out_1"}); + add_1->SetAttr("axis", 1); + + relu_1->SetType("relu"); + relu_1->SetInput("X", {"add_out_1"}); + relu_1->SetOutput("Out", {"relu_out_1"}); + + add_2->SetType("elementwise_add"); + add_2->SetInput("X", {"relu_out_1"}); + add_2->SetInput("Y", {"y_2"}); + add_2->SetOutput("Out", {"add_out_2"}); + add_2->SetAttr("axis", 1); + + relu_2->SetType("relu"); + relu_2->SetInput("X", {"add_out_2"}); + relu_2->SetOutput("Out", {"out"}); + + program_desc->Flush(); + + lite::Program program(*program_desc->Proto(), scope, valid_places); + auto graph = std::unique_ptr(new SSAGraph()); + graph->Build(program, valid_places); + + return graph; +} + +TEST(elementwise_add_activation_fuse_pass, graph_test) { + framework::ProgramDesc program_desc; + std::vector places{{TARGET(kHost), PRECISION(kFloat)}}; + auto scope = std::make_shared(); + auto graph = BuildGraph(&program_desc, 
scope, places); + ASSERT_EQ(graph->nodes().size(), + 7UL /*vars*/ + 4UL /*ops*/ + 1UL /* SSAGraph tmp node*/); +} + +TEST(elementwise_add_activation_fuse_pass, fuse_test_op) { + framework::ProgramDesc program_desc; + std::vector places{{TARGET(kHost), PRECISION(kFloat)}}; + auto scope = std::make_shared(); + auto graph = BuildGraph(&program_desc, scope, places); + Visualize(graph.get()); + const int num_nodes = graph->nodes().size(); + auto* fuser = new ElementwiseAddActivationFusePass; + fuser->Apply(graph); + Visualize(graph.get()); + ASSERT_EQ(graph->nodes().size(), + num_nodes - 3UL * 2 /*nodes removed */ + 1UL * 2 /* fused nodes*/); +} + +} // namespace fusion +} // namespace mir +} // namespace lite +} // namespace paddle + +USE_LITE_OP(elementwise_add); +USE_LITE_OP(fusion_elementwise_add_activation); +USE_LITE_OP(relu); diff --git a/lite/core/mir/fusion/elementwise_add_activation_fuser.cc b/lite/core/mir/fusion/elementwise_add_activation_fuser.cc new file mode 100644 index 00000000000..3c6bf4768bf --- /dev/null +++ b/lite/core/mir/fusion/elementwise_add_activation_fuser.cc @@ -0,0 +1,87 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/core/mir/fusion/elementwise_add_activation_fuser.h" +#include +#include + +namespace paddle { +namespace lite { +namespace mir { +namespace fusion { + +void ElementwiseAddActivationFuser::BuildPattern() { + // create input nodes. + auto* x = VarNode("x")->assert_is_op_input("elementwise_add", "X")->AsInput(); + auto* y = VarNode("y")->assert_is_op_input("elementwise_add", "Y")->AsInput(); + + // create op nodes + auto* add = OpNode("add", "elementwise_add") + ->assert_is_op("elementwise_add") + ->AsIntermediate(); + auto* act = + OpNode("act", act_type_)->assert_is_op(act_type_)->AsIntermediate(); + + // create intermediate nodes + auto* add_out = VarNode("add_out") + ->assert_is_op_output("elementwise_add", "Out") + ->assert_is_op_input(act_type_, "X") + ->AsIntermediate(); + + // create output node + auto* out = + VarNode("output")->assert_is_op_output(act_type_, "Out")->AsOutput(); + + // create topology. 
+ std::vector add_inputs{x, y}; + add_inputs >> *add >> *add_out; + *add_out >> *act >> *out; +} + +void ElementwiseAddActivationFuser::InsertNewNode(SSAGraph* graph, + const key2nodes_t& matched) { + auto op_desc = GenOpDesc(matched); + auto op = + LiteOpRegistry::Global().Create("fusion_elementwise_add_activation"); + auto old_op = matched.at("add")->stmt()->op(); + auto* scope = old_op->scope(); + auto& valid_places = old_op->valid_places(); + op->Attach(op_desc, scope); + + auto* new_op_node = graph->GraphCreateInstructNode(op, valid_places); + + IR_NODE_LINK_TO(matched.at("x"), new_op_node); + IR_NODE_LINK_TO(matched.at("y"), new_op_node); + IR_NODE_LINK_TO(new_op_node, matched.at("output")); +} + +cpp::OpDesc ElementwiseAddActivationFuser::GenOpDesc( + const key2nodes_t& matched) { + auto* desc = matched.at("add")->stmt()->op_info(); + + cpp::OpDesc op_desc; + op_desc.SetType("fusion_elementwise_add_activation"); + op_desc.SetInput("X", {matched.at("x")->arg()->name}); + op_desc.SetInput("Y", {matched.at("y")->arg()->name}); + op_desc.SetOutput("Out", {matched.at("output")->arg()->name}); + + op_desc.SetAttr("axis", desc->GetAttr("axis")); + op_desc.SetAttr("act_type", act_type_); + return op_desc; +} + +} // namespace fusion +} // namespace mir +} // namespace lite +} // namespace paddle diff --git a/lite/core/mir/fusion/elementwise_add_activation_fuser.h b/lite/core/mir/fusion/elementwise_add_activation_fuser.h new file mode 100644 index 00000000000..47bb2fcf821 --- /dev/null +++ b/lite/core/mir/fusion/elementwise_add_activation_fuser.h @@ -0,0 +1,41 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include "lite/core/mir/pattern_matcher_high_api.h" + +namespace paddle { +namespace lite { +namespace mir { +namespace fusion { + +class ElementwiseAddActivationFuser : public FuseBase { + public: + explicit ElementwiseAddActivationFuser(const std::string& act_type) + : act_type_(act_type) {} + void BuildPattern() override; + void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override; + + private: + cpp::OpDesc GenOpDesc(const key2nodes_t& matched) override; + std::string act_type_; +}; + +} // namespace fusion +} // namespace mir +} // namespace lite +} // namespace paddle diff --git a/lite/core/mir/fusion/fc_fuse_pass.cc b/lite/core/mir/fusion/fc_fuse_pass.cc new file mode 100644 index 00000000000..0303ae06e63 --- /dev/null +++ b/lite/core/mir/fusion/fc_fuse_pass.cc @@ -0,0 +1,34 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/core/mir/fusion/fc_fuse_pass.h" +#include +#include +#include "lite/core/mir/fusion/fc_fuser.h" +#include "lite/core/mir/pass_registry.h" + +namespace paddle { +namespace lite { +namespace mir { + +void FcFusePass::Apply(const std::unique_ptr& graph) { + fusion::FcFuser fuser; + fuser(graph.get()); +} + +} // namespace mir +} // namespace lite +} // namespace paddle + +REGISTER_MIR_PASS(lite_fc_fuse_pass, paddle::lite::mir::FcFusePass); diff --git a/lite/core/mir/fusion/fc_fuse_pass.h b/lite/core/mir/fusion/fc_fuse_pass.h new file mode 100644 index 00000000000..44771345a71 --- /dev/null +++ b/lite/core/mir/fusion/fc_fuse_pass.h @@ -0,0 +1,32 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include "lite/core/mir/pass.h" + +namespace paddle { +namespace lite { +namespace mir { + +class FcFusePass : public ProgramPass { + public: + void Apply(const std::unique_ptr& graph) override; +}; + +} // namespace mir +} // namespace lite +} // namespace paddle diff --git a/lite/core/mir/fusion/fc_fuse_pass_test.cc b/lite/core/mir/fusion/fc_fuse_pass_test.cc new file mode 100644 index 00000000000..fb509498d19 --- /dev/null +++ b/lite/core/mir/fusion/fc_fuse_pass_test.cc @@ -0,0 +1,112 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
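[Editor's note] The rewrite exercised by the test below is
semantics-preserving because fc(x; W, b) is by definition mul(x, W) followed
by a row-wise add of b. A minimal standalone check with hypothetical sizes:

    #include <cstdio>

    int main() {
      const int M = 2, K = 3, N = 2;
      float x[M][K] = {{1, 2, 3}, {4, 5, 6}};
      float W[K][N] = {{1, 0}, {0, 1}, {1, 1}};
      float b[N] = {0.5f, -0.5f};
      for (int m = 0; m < M; m++) {
        for (int n = 0; n < N; n++) {
          float mul_out = 0.0f;  // the intermediate "mul_out" in the pattern
          for (int k = 0; k < K; k++) mul_out += x[m][k] * W[k][n];
          float fc_out = mul_out + b[n];  // elementwise_add with axis == 1
          std::printf("%g ", fc_out);
        }
        std::printf("\n");
      }
    }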
+ +#include "lite/core/mir/fusion/fc_fuse_pass.h" +#include +#include +#include +#include "lite/api/cxx_api.h" +#include "lite/api/paddle_use_passes.h" +#include "lite/core/op_registry.h" + +DEFINE_string(model_dir, "", ""); +DEFINE_string(optimized_model, "", ""); + +namespace paddle { +namespace lite { +namespace mir { + +TEST(fc_fuse_pass, fuse_test) { + lite::Predictor predictor; +#ifndef LITE_WITH_CUDA + std::vector valid_places({Place{TARGET(kHost), PRECISION(kFloat)}, + Place{TARGET(kX86), PRECISION(kFloat)}}); +#else + std::vector valid_places({ + Place{TARGET(kHost), PRECISION(kFloat), DATALAYOUT(kNCHW)}, + Place{TARGET(kCUDA), PRECISION(kFloat), DATALAYOUT(kNCHW)}, + Place{TARGET(kCUDA), PRECISION(kAny), DATALAYOUT(kNCHW)}, + Place{TARGET(kHost), PRECISION(kAny), DATALAYOUT(kNCHW)}, + Place{TARGET(kCUDA), PRECISION(kAny), DATALAYOUT(kAny)}, + Place{TARGET(kHost), PRECISION(kAny), DATALAYOUT(kAny)}, + }); +#endif + + predictor.Build(FLAGS_model_dir, + Place{TARGET(kX86), PRECISION(kFloat)}, // origin cuda + valid_places); + + auto* input_tensor = predictor.GetInput(0); + input_tensor->Resize(DDim(std::vector({100, 100}))); + auto* data = input_tensor->mutable_data(); + for (int i = 0; i < 100 * 100; i++) { + data[i] = i; + } + + predictor.Run(); + + auto* out = predictor.GetOutput(0); + LOG(INFO) << out << " memory size " << out->data_size(); + LOG(INFO) << "out " << out->data()[0]; + LOG(INFO) << "out " << out->data()[1]; + LOG(INFO) << "dims " << out->dims(); + EXPECT_NEAR(out->data()[0], 38.120617f, 1e-5); + EXPECT_NEAR(out->data()[1], 10.109812f, 1e-5); + CHECK_EQ(out->dims()[0], 100); + CHECK_EQ(out->dims()[1], 500); +} + +#ifndef LITE_WITH_LIGHT_WEIGHT_FRAMEWORK +TEST(fc_fuse_pass, save_model_test) { + lite::Predictor predictor; + std::vector valid_places({Place{TARGET(kHost), PRECISION(kFloat)}, + Place{TARGET(kX86), PRECISION(kFloat)}}); + predictor.Build( + FLAGS_model_dir, Place{TARGET(kX86), PRECISION(kFloat)}, valid_places); + + LOG(INFO) << "Save optimized model to " << FLAGS_optimized_model; + predictor.SaveModel(FLAGS_optimized_model); +} +#endif // LITE_WITH_LIGHT_WEIGHT_FRAMEWORK + +} // namespace mir +} // namespace lite +} // namespace paddle + +USE_LITE_OP(mul); +USE_LITE_OP(elementwise_add); +USE_LITE_OP(elementwise_sub); +USE_LITE_OP(fc); +USE_LITE_OP(feed); +USE_LITE_OP(fetch); +USE_LITE_OP(io_copy); +USE_LITE_OP(softmax); +USE_LITE_OP(scale); +USE_LITE_KERNEL(feed, kHost, kAny, kAny, def); +USE_LITE_KERNEL(fetch, kHost, kAny, kAny, def); + +// #ifdef LITE_WITH_X86 +// USE_LITE_KERNEL(mul, kX86, kFloat, kNCHW, def); +// USE_LITE_KERNEL(fc, kX86, kFloat, kNCHW, def); +// USE_LITE_KERNEL(elementwise_sub, kX86, kFloat, kNCHW, def); +// USE_LITE_KERNEL(elementwise_add, kX86, kFloat, kNCHW, def); +// USE_LITE_KERNEL(softmax, kX86, kFloat, kNCHW, def); +// USE_LITE_KERNEL(scale, kX86, kFloat, kNCHW, def); +// #endif + +#ifdef LITE_WITH_CUDA +USE_LITE_KERNEL(mul, kCUDA, kFloat, kNCHW, def); +USE_LITE_KERNEL(io_copy, kCUDA, kAny, kAny, host_to_device); +USE_LITE_KERNEL(io_copy, kCUDA, kAny, kAny, device_to_host); +#endif diff --git a/lite/core/mir/fusion/fc_fuser.cc b/lite/core/mir/fusion/fc_fuser.cc new file mode 100644 index 00000000000..72e1a4684d6 --- /dev/null +++ b/lite/core/mir/fusion/fc_fuser.cc @@ -0,0 +1,78 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lite/core/mir/fusion/fc_fuser.h"
+#include <memory>
+#include <vector>
+
+namespace paddle {
+namespace lite {
+namespace mir {
+namespace fusion {
+
+void FcFuser::BuildPattern() {
+  // create nodes.
+  auto* x = VarNode("x")->assert_is_op_input("mul", "X");
+  auto* W = VarNode("W")->assert_is_op_input("mul", "Y");
+  auto* b = VarNode("b");
+  auto* mul = OpNode("mul", "mul");
+  auto* mul_out = VarNode("mul_out");
+  auto* add = OpNode("add", "elementwise_add");
+  auto* Out = VarNode("Out");
+
+  // create topology.
+  std::vector<PMNode*> mul_inputs{W, x};
+  std::vector<PMNode*> add_inputs{mul_out, b};
+  mul_inputs >> *mul >> *mul_out;
+  add_inputs >> *add >> *Out;
+
+  // Mark the matched ops and the intermediate var for removal.
+  mul_out->AsIntermediate();
+  mul->AsIntermediate();
+  add->AsIntermediate();
+}
+
+void FcFuser::InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) {
+  auto op_desc = GenOpDesc(matched);
+  auto fc_op = LiteOpRegistry::Global().Create("fc");
+  auto mul = matched.at("mul")->stmt()->op();
+  auto* scope = mul->scope();
+  auto& valid_places = mul->valid_places();
+  fc_op->Attach(op_desc, scope);
+
+  auto* new_op_node = graph->GraphCreateInstructNode(fc_op, valid_places);
+
+  IR_NODE_LINK_TO(matched.at("W"), new_op_node);
+  IR_NODE_LINK_TO(matched.at("x"), new_op_node);
+  IR_NODE_LINK_TO(matched.at("b"), new_op_node);
+  IR_NODE_LINK_TO(new_op_node, matched.at("Out"));
+}
+
+cpp::OpDesc FcFuser::GenOpDesc(const key2nodes_t& matched) {
+  cpp::OpDesc op_desc = *matched.at("mul")->stmt()->op_info();
+  op_desc.SetType("fc");
+  op_desc.SetInput("Input", {matched.at("x")->arg()->name});
+  op_desc.SetInput("W", {matched.at("W")->arg()->name});
+  op_desc.SetInput("Bias", {matched.at("b")->arg()->name});
+  op_desc.SetOutput("Out", {matched.at("Out")->arg()->name});
+  op_desc.SetAttr(
+      "in_num_col_dims",
+      matched.at("mul")->stmt()->op_info()->GetAttr<int>("x_num_col_dims"));
+  return op_desc;
+}
+
+}  // namespace fusion
+}  // namespace mir
+}  // namespace lite
+}  // namespace paddle
diff --git a/lite/core/mir/fusion/fc_fuser.h b/lite/core/mir/fusion/fc_fuser.h
new file mode 100644
index 00000000000..7ba07527898
--- /dev/null
+++ b/lite/core/mir/fusion/fc_fuser.h
@@ -0,0 +1,38 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+ +#pragma once + +#include +#include +#include "lite/core/mir/pattern_matcher_high_api.h" + +namespace paddle { +namespace lite { +namespace mir { +namespace fusion { + +class FcFuser : public FuseBase { + public: + void BuildPattern() override; + void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override; + + private: + cpp::OpDesc GenOpDesc(const key2nodes_t& matched) override; +}; + +} // namespace fusion +} // namespace mir +} // namespace lite +} // namespace paddle diff --git a/lite/core/mir/fusion/quant_dequant_fuse_pass.cc b/lite/core/mir/fusion/quant_dequant_fuse_pass.cc new file mode 100644 index 00000000000..83b70c78281 --- /dev/null +++ b/lite/core/mir/fusion/quant_dequant_fuse_pass.cc @@ -0,0 +1,45 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/core/mir/fusion/quant_dequant_fuse_pass.h" +#include +#include +#include "lite/core/mir/fusion/quant_dequant_op_fuser.h" +#include "lite/core/mir/pass_registry.h" + +namespace paddle { +namespace lite { +namespace mir { + +void QuantDequantFusePass::Apply(const std::unique_ptr& graph) { + std::unordered_set quant_types = { + "fake_quantize_range_abs_max", "fake_quantize_moving_average_abs_max"}; + std::unordered_set quantized_op_types = { + "conv2d", "mul", "depthwise_conv2d"}; + for (auto& quant_type : quant_types) { + for (auto& op_type : quantized_op_types) { + for (int i = 6; i >= 1; i--) { + fusion::QuantDequantOpFuser fuser(op_type, quant_type, i); + fuser(graph.get()); + } + } + } +} + +} // namespace mir +} // namespace lite +} // namespace paddle + +REGISTER_MIR_PASS(lite_quant_dequant_fuse_pass, + paddle::lite::mir::QuantDequantFusePass); diff --git a/lite/core/mir/fusion/quant_dequant_fuse_pass.h b/lite/core/mir/fusion/quant_dequant_fuse_pass.h new file mode 100644 index 00000000000..243241bfb7d --- /dev/null +++ b/lite/core/mir/fusion/quant_dequant_fuse_pass.h @@ -0,0 +1,33 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include +#include +#include +#include "lite/core/mir/pass.h" + +namespace paddle { +namespace lite { +namespace mir { + +class QuantDequantFusePass : public ProgramPass { + public: + void Apply(const std::unique_ptr& graph) override; +}; + +} // namespace mir +} // namespace lite +} // namespace paddle diff --git a/lite/core/mir/fusion/quant_dequant_op_fuser.cc b/lite/core/mir/fusion/quant_dequant_op_fuser.cc new file mode 100644 index 00000000000..5fd97db96f9 --- /dev/null +++ b/lite/core/mir/fusion/quant_dequant_op_fuser.cc @@ -0,0 +1,198 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/core/mir/fusion/quant_dequant_op_fuser.h" +#include +#include +#include "lite/utils/string.h" + +namespace paddle { +namespace lite { +namespace mir { +namespace fusion { + +void QuantDequantOpFuser::BuildPattern() { + const int kNumFields = 5; + const int kQuantizedWeightOffset = 0; + const int kQuantizedOpOffset = 1; + const int kQuantizedOpOutOffset = 2; + const int kDequantOpOffset = 3; + const int kDequantOpOutOffset = 4; + + std::string weight_name = ""; + if (op_type_ == "conv2d" || op_type_ == "depthwise_conv2d") { + weight_name = "Filter"; + } else { + weight_name = "Y"; + } + auto* quant_op_input = VarNode("quant_op_input") + ->assert_is_op_input(quant_type_, "X") + ->AsInput(); + auto* quant_op_in_scale = VarNode("quant_op_in_scale") + ->assert_is_op_input(quant_type_, "InScale") + ->AsIntermediate(); + auto* quant_op = OpNode("quant_op", quant_type_) + ->assert_is_op(quant_type_) + ->AsIntermediate(); + + auto* quant_op_out_scale = + VarNode("quant_op_out_scale") + ->assert_is_op_output(quant_type_, "OutScale") + ->assert_is_op_input("fake_dequantize_max_abs", "Scale") + ->AsIntermediate(); + + auto* quant_op_out = VarNode("quant_op_out") + ->assert_is_op_output(quant_type_, "Out") + ->assert_is_op_input(op_type_) + ->AsIntermediate(); + std::vector nodes; + for (int i = 0; i < times_; i++) { + nodes.push_back(VarNode(string_format("quantized_op_weight%d", i)) + ->assert_is_op_input(op_type_, weight_name) + ->AsInput()); + + nodes.push_back(OpNode(string_format("quantized_op%d", i), op_type_) + ->assert_is_op(op_type_) + ->AsIntermediate()); + + nodes.push_back(VarNode(string_format("quantized_op_out%d", i)) + ->assert_is_op_output(op_type_) + ->assert_is_op_input("fake_dequantize_max_abs", "X") + ->AsIntermediate()); + + nodes.push_back( + OpNode(string_format("dequant_op%d", i), "fake_dequantize_max_abs") + ->assert_is_op("fake_dequantize_max_abs") + ->AsIntermediate()); + nodes.push_back(VarNode(string_format("dequant_op_out%d", i)) + ->assert_is_op_output("fake_dequantize_max_abs", "Out") + ->AsOutput()); + } + + quant_op->LinksFrom({quant_op_input, quant_op_in_scale}); + quant_op_out->LinksFrom({quant_op}); + quant_op_out_scale->LinksFrom({quant_op}); + for (int i = 0; i < times_; i++) { + nodes[i * kNumFields + kQuantizedOpOffset]->LinksFrom( + {quant_op_out, nodes[i * 
kNumFields + kQuantizedWeightOffset]}); + nodes[i * kNumFields + kQuantizedOpOutOffset]->LinksFrom( + {nodes[i * kNumFields + kQuantizedOpOffset]}); + nodes[i * kNumFields + kDequantOpOffset]->LinksFrom( + {nodes[i * kNumFields + kQuantizedOpOutOffset], quant_op_out_scale}); + nodes[i * kNumFields + kDequantOpOutOffset]->LinksFrom( + {nodes[i * kNumFields + kDequantOpOffset]}); + } +} + +void QuantDequantOpFuser::InsertNewNode(SSAGraph* graph, + const key2nodes_t& matched) { + const int kNumFields = 5; + const int kQuantizedWeightOffset = 0; + const int kQuantizedOpOffset = 1; + const int kDequantOpOffset = 3; + const int kDequantOpOutOffset = 4; + + auto* quant_op_input = matched.at("quant_op_input"); + auto* quant_op_in_scale = matched.at("quant_op_in_scale"); + auto* quant_op = matched.at("quant_op"); + + std::vector nodes; + for (int i = 0; i < times_; i++) { + nodes.push_back(matched.at(string_format("quantized_op_weight%d", i))); + nodes.push_back(matched.at(string_format("quantized_op%d", i))); + nodes.push_back(matched.at(string_format("quantized_op_out%d", i))); + nodes.push_back(matched.at(string_format("dequant_op%d", i))); + nodes.push_back(matched.at(string_format("dequant_op_out%d", i))); + } + int bit_length = quant_op->stmt()->op_info()->GetAttr("bit_length"); + auto* scope = quant_op->stmt()->op()->scope(); + auto& valid_places = quant_op->stmt()->op()->valid_places(); + int range = ((1 << (bit_length - 1)) - 1); + auto input_scale_t = scope->FindVar(quant_op_in_scale->arg()->name) + ->GetMutable(); + float input_scale = input_scale_t->data()[0] / range; + + VLOG(4) << "range: " << range << " input_scale: " << input_scale; + for (int i = 0; i < times_; i++) { + float max_range = nodes[i * kNumFields + kDequantOpOffset] + ->stmt() + ->op_info() + ->GetAttr("max_range"); + // weight_scale = max(abs(weight)) + float whole_weight_scale = + static_cast(range * range) / max_range / range; + + cpp::OpDesc op_desc = + *nodes[i * kNumFields + kQuantizedOpOffset]->stmt()->op_info(); + + auto quantized_weight_var_name = + nodes[i * kNumFields + kQuantizedWeightOffset]->arg()->name; + auto quantized_weight_t = + scope->FindVar(quantized_weight_var_name)->GetMutable(); + std::vector weight_scale; + int weight_scale_size; + + if (op_type_ == "conv2d" || op_type_ == "depthwise_conv2d") { + op_desc.SetInput("Input", {matched.at("quant_op_input")->arg()->name}); + op_desc.SetOutput( + "Output", {nodes[i * kNumFields + kDequantOpOutOffset]->arg()->name}); + // Conv weight shape: Cout * Cin * kh * hw, the weight_scale_size should + // be Cout. + weight_scale_size = quantized_weight_t->dims()[0]; + } else if (op_type_ == "mul") { + op_desc.SetInput("X", {matched.at("quant_op_input")->arg()->name}); + op_desc.SetOutput( + "Out", {nodes[i * kNumFields + kDequantOpOutOffset]->arg()->name}); + // Fc weight: Cin * Cout, the weight_scale_size should be Cout. + weight_scale_size = quantized_weight_t->dims()[1]; + } + for (int i = 0; i < weight_scale_size; i++) { + weight_scale.push_back(whole_weight_scale); + } + op_desc.SetAttr("enable_int8", true); + op_desc.SetAttr("input_scale", input_scale); + op_desc.SetAttr("weight_scale", weight_scale); + + Tensor temp_tensor; + temp_tensor.CopyDataFrom(*quantized_weight_t); + float* temp_data = temp_tensor.mutable_data(); + + size_t weight_num = quantized_weight_t->data_size(); + int8_t* quantized_weight_data = quantized_weight_t->mutable_data(); + + // change the weight from the float type to int8 type. 
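+      // (Editor's note) Weights reaching this point come from the fake-quant
+      // training flow, so each float is expected to already hold an integer
+      // value in [-range, range]; the loop below only narrows the type.
+      // The scale recovered above follows from fake_dequantize_max_abs
+      // defining max_range = range^2 / max(|weight|), hence
+      // whole_weight_scale = range^2 / max_range / range
+      //                    = max(|weight|) / range.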
+    for (size_t k = 0; k < weight_num; k++) {
+      quantized_weight_data[k] = static_cast<int8_t>(temp_data[k]);
+    }
+    auto quantized_op = LiteOpRegistry::Global().Create(op_type_);
+
+    quantized_op->Attach(op_desc, scope);
+    auto* new_op_node =
+        graph->GraphCreateInstructNode(quantized_op, valid_places);
+    IR_NODE_LINK_TO(quant_op_input, new_op_node);
+    IR_NODE_LINK_TO(nodes[i * kNumFields + kQuantizedWeightOffset],
+                    new_op_node);
+    IR_NODE_LINK_TO(new_op_node, nodes[i * kNumFields + kDequantOpOutOffset]);
+  }
+}
+
+cpp::OpDesc QuantDequantOpFuser::GenOpDesc(const key2nodes_t& matched) {
+  cpp::OpDesc op_desc;
+  return op_desc;
+}
+
+}  // namespace fusion
+}  // namespace mir
+}  // namespace lite
+}  // namespace paddle
diff --git a/lite/core/mir/fusion/quant_dequant_op_fuser.h b/lite/core/mir/fusion/quant_dequant_op_fuser.h
new file mode 100644
index 00000000000..15833ad2580
--- /dev/null
+++ b/lite/core/mir/fusion/quant_dequant_op_fuser.h
@@ -0,0 +1,59 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <memory>
+#include <string>
+#include "lite/core/mir/pattern_matcher_high_api.h"
+
+namespace paddle {
+namespace lite {
+namespace mir {
+namespace fusion {
+
+/* A model trained with fluid quantization simulates real int8 inference.
+ * Each quantized op (conv2d, mul, depthwise_conv2d, etc.) has a fake_quant
+ * op in front of it and a fake_dequant op behind it.
+ *
+ * In int8 mode, this fuser detects the pattern
+ * "fake_quant + quantized_op + fake_dequant", extracts the input_scale and
+ * weight_scale info from the fake_quant and fake_dequant ops, and fuses the
+ * scales into the quantized op. Finally, it removes the fake_quant and
+ * fake_dequant ops from the graph.
+ */
+class QuantDequantOpFuser : public FuseBase {
+ public:
+  explicit QuantDequantOpFuser(const std::string& op_type,
+                               const std::string& quant_type,
+                               int times)
+      : op_type_(op_type), quant_type_(quant_type), times_(times) {}
+  void BuildPattern() override;
+  void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override;
+
+ private:
+  cpp::OpDesc GenOpDesc(const key2nodes_t& matched) override;
+
+ private:
+  std::string op_type_{"conv2d"};
+  std::string quant_type_;
+  int times_;
+};
+
+}  // namespace fusion
+}  // namespace mir
+}  // namespace lite
+}  // namespace paddle
diff --git a/lite/core/mir/generate_program_pass.cc b/lite/core/mir/generate_program_pass.cc
new file mode 100644
index 00000000000..b957e70f981
--- /dev/null
+++ b/lite/core/mir/generate_program_pass.cc
@@ -0,0 +1,42 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/core/mir/generate_program_pass.h" +#include +#include +#include +#include "lite/core/mir/graph_visualize_pass.h" +#include "lite/core/mir/pass_registry.h" + +namespace paddle { +namespace lite { +namespace mir { + +void GenerateProgramPass::Apply(const std::unique_ptr& graph) { + VLOG(4) << "final program \n" << Visualize(graph.get()); + for (auto& item : graph->StmtTopologicalOrder()) { + if (item->IsStmt()) { + auto& stmt = item->AsStmt(); + VLOG(4) << stmt; + insts_.emplace_back(stmt.op(), std::move(stmt.kernels().front())); + } + } +} + +} // namespace mir +} // namespace lite +} // namespace paddle + +REGISTER_MIR_PASS(generate_program_pass, + paddle::lite::mir::GenerateProgramPass); diff --git a/lite/core/mir/generate_program_pass.h b/lite/core/mir/generate_program_pass.h new file mode 100644 index 00000000000..b126b4aba4d --- /dev/null +++ b/lite/core/mir/generate_program_pass.h @@ -0,0 +1,50 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include +#include +#include "lite/core/kernel.h" +#include "lite/core/mir/pass.h" + +namespace paddle { +namespace lite { +namespace mir { + +/* + * GenerateProgramPass will build the execution program for executor from a mir + * graph. + */ +class GenerateProgramPass : public ProgramPass { + public: + void Apply(const std::unique_ptr &graph) override; + + std::unique_ptr GenProgram() { + LOG(INFO) << "insts.size " << insts_.size(); + std::unique_ptr program( + new RuntimeProgram(std::move(insts_))); + + return program; + } + + private: + std::vector insts_; +}; + +} // namespace mir +} // namespace lite +} // namespace paddle diff --git a/lite/core/mir/graph_visualize_pass.cc b/lite/core/mir/graph_visualize_pass.cc new file mode 100644 index 00000000000..1aa9ea77d0a --- /dev/null +++ b/lite/core/mir/graph_visualize_pass.cc @@ -0,0 +1,76 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
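[Editor's note] A sketch of how the GenerateProgramPass above is typically
driven, mirroring what an optimizer routine would do (assumes `graph` is the
std::unique_ptr<SSAGraph> being optimized and that the PassManager lookup API
is available):

    auto* pass = mir::PassManager::Global().LookUp<mir::GenerateProgramPass>(
        "generate_program_pass");
    CHECK(pass) << "generate_program_pass not registered";
    pass->Apply(graph);
    auto program = pass->GenProgram();  // std::unique_ptr<RuntimeProgram>
    program->Run();                     // execute the generated instructions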
+ +#include "lite/core/mir/graph_visualize_pass.h" +#include +#include +#include +#include "lite/core/mir/pass_registry.h" +#include "lite/utils/string.h" + +namespace paddle { +namespace lite { +namespace mir { + +using inference::analysis::Dot; + +void GraphVisualizePass::Apply(const std::unique_ptr& graph) { + Visualize(graph.get()); +} + +std::string Visualize(mir::SSAGraph* graph) { + inference::analysis::Dot dot; + + int id = 0; + std::set exists_args; + + for (auto& node : graph->mutable_nodes()) { + std::string key; + if (node.IsArg()) { + key = node.AsArg().name; + } else { + key = string_format("%s%d", node.AsStmt().op_type().c_str(), id++); + } + + if (node.IsStmt()) { + dot.AddNode(key, {Dot::Attr("shape", "box")}); + for (auto& x : node.inlinks) { + auto name = x->AsArg().name; + if (!exists_args.count(name)) { + dot.AddNode(name, {}); + } + dot.AddEdge(name, key, {}); + exists_args.insert(name); + } + for (auto& x : node.outlinks) { + auto name = x->AsArg().name; + if (!exists_args.count(name)) { + dot.AddNode(name, {}); + } + dot.AddEdge(key, name, {}); + exists_args.insert(name); + } + } + } + + auto res = dot.Build(); + LOG(INFO) << "dot:\n" << res; + return res; +} + +} // namespace mir +} // namespace lite +} // namespace paddle + +REGISTER_MIR_PASS(graph_visualze, paddle::lite::mir::GraphVisualizePass); diff --git a/lite/core/mir/graph_visualize_pass.h b/lite/core/mir/graph_visualize_pass.h new file mode 100644 index 00000000000..bde58a63b3b --- /dev/null +++ b/lite/core/mir/graph_visualize_pass.h @@ -0,0 +1,39 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include "lite/core/mir/dot.h" +#include "lite/core/mir/pass.h" + +namespace paddle { +namespace lite { +namespace mir { + +/* + * GraphVisualizePass helps to visualize an mir graph by exporting a DOT + * language file. + */ +class GraphVisualizePass : public DebugPass { + public: + void Apply(const std::unique_ptr& graph) override; +}; + +std::string Visualize(mir::SSAGraph* graph); + +} // namespace mir +} // namespace lite +} // namespace paddle diff --git a/lite/core/mir/io_copy_kernel_pick_pass.cc b/lite/core/mir/io_copy_kernel_pick_pass.cc new file mode 100644 index 00000000000..6c62ac9a1a0 --- /dev/null +++ b/lite/core/mir/io_copy_kernel_pick_pass.cc @@ -0,0 +1,74 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "lite/core/mir/pass.h" +#include "lite/core/mir/pass_registry.h" + +namespace paddle { +namespace lite { +namespace mir { + +class IoCopyKernelPickPass : public StmtPass { + public: + void Apply(const std::unique_ptr& graph) override { + for (auto& node : graph->mutable_nodes()) { + if (!node.IsStmt()) continue; + auto& inst = node.AsStmt(); + if (inst.op_type() != "io_copy") continue; + + LOG(INFO) << "....> picking a IO COPY kernel"; + + auto& kernels = node.AsStmt().kernels(); + CHECK(!kernels.empty()) << "No valid kernels found for IoCopy Op"; + const auto* inty = node.inlinks.front()->AsArg().type; + const auto* outy = node.outlinks.front()->AsArg().type; + LOG(INFO) << "input type " << *inty; + LOG(INFO) << "output type " << *outy; + + bool is_found = false; + LOG(INFO) << "kernels size " << kernels.size(); + for (auto& kernel : kernels) { + CHECK_EQ(node.inlinks.size(), 1UL); + CHECK_EQ(node.outlinks.size(), 1UL); + + const Type* in_arg_ty = kernel->GetInputDeclType("Input"); + const Type* out_arg_ty = kernel->GetOutputDeclType("Out"); + LOG(INFO) << "checking kernel candidate " << *in_arg_ty << "->" + << *out_arg_ty; + if (TargetCompatibleTo(*inty, *in_arg_ty)) { + // Both the input and output type matches, remove other kernels + // directly. + if (TargetCompatibleTo(*outy, *out_arg_ty)) { + LOG(INFO) << "get a IOCopy kernel"; + auto x = std::move(kernel); + kernels.clear(); + kernels.emplace_back(std::move(x)); + is_found = true; + break; + } + } + } + + CHECK(is_found) << "Can't find a IoCopy kernel for IO: " << *inty << "->" + << *outy; + } + } +}; + +} // namespace mir +} // namespace lite +} // namespace paddle + +REGISTER_MIR_PASS(io_copy_kernel_pick_pass, + paddle::lite::mir::IoCopyKernelPickPass); diff --git a/lite/core/mir/node.cc b/lite/core/mir/node.cc new file mode 100644 index 00000000000..61d3d317e7b --- /dev/null +++ b/lite/core/mir/node.cc @@ -0,0 +1,74 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/core/mir/node.h" +#include "lite/core/op_registry.h" + +namespace paddle { +namespace lite { + +const OpInfo *mir::Node::Stmt::op_info() const { + CHECK(op_); + return op_->op_info(); +} + +Place mir::Node::Stmt::place() const { + CHECK(!valid_kernels_.empty()); + return valid_kernels_.front()->place(); +} + +KernelBase &mir::Node::Stmt::picked_kernel() { + CHECK(!valid_kernels_.empty()) << "no kernel for " << op_type(); + return *valid_kernels_.front(); +} + +OpInfo *mir::Node::Stmt::mutable_op_info() { + CHECK(op_); + return op_->mutable_op_info(); +} + +void mir::Node::Stmt::ResetOp(const cpp::OpDesc &op_desc, + const std::vector &valid_places, + lite::Scope *scope) { + CHECK((op_ && op_->scope()) || scope) << "Either scope should be set"; + lite::Scope *the_scope = scope ? scope : op_->scope(); + op_->Attach(op_desc, the_scope); + // Recreate the kernels with the latest OpInfo. 
+  if (!op_ || op_->op_info()->Type() != op_desc.Type()) {
+    op_ = LiteOpRegistry::Global().Create(op_desc.Type());
+    CHECK(op_) << "No op found for " << op_desc.Type();
+  }
+  op_->Attach(op_desc, the_scope);
+  valid_kernels_.clear();
+  valid_kernels_ = op_->CreateKernels(valid_places);
+}
+
+std::ostream &mir::operator<<(std::ostream &os, const mir::Node::Stmt &other) {
+  os << "Statement " << other.op_type() << " " << other.place().DebugString();
+  return os;
+}
+
+mir::Node::Arg &mir::Node::AsArg(const std::string &name, int id) {
+  auto &x = AsArg();
+  x.name = name;
+  x.id = id;
+  return x;
+}
+
+mir::Node::Arg &mir::Node::AsArg(const std::string &name) {
+  auto &x = AsArg();
+  x.name = name;
+  return x;
+}
+
+}  // namespace lite
+}  // namespace paddle
diff --git a/lite/core/mir/node.h b/lite/core/mir/node.h
new file mode 100644
index 00000000000..9c7d441ca38
--- /dev/null
+++ b/lite/core/mir/node.h
@@ -0,0 +1,173 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <list>
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+#include "lite/core/kernel.h"
+#include "lite/core/op_lite.h"
+
+namespace paddle {
+namespace lite {
+namespace mir {
+
+// Node in a MIR graph.
+class Node {
+ public:
+  std::list<Node*> inlinks;
+  std::list<Node*> outlinks;
+
+  Node() = default;
+
+  enum class Role {
+    kArg = 0,
+    kStmt,
+    kNumRoles, /*should be last*/
+    kUnk,
+  };
+
+  class Stmt {
+    // The kernel instances this Statement contains.
+    std::vector<std::unique_ptr<KernelBase>> valid_kernels_;
+    // TODO(Superjomn) make this a shared_ptr for resource safety.
+    std::shared_ptr<OpLite> op_;  // we hold the op to run InferShape
+
+   public:
+    // Refresh the operator and kernels with the latest OpInfo.
+    void ResetOp(const cpp::OpDesc& op_desc,
+                 const std::vector<Place>& valid_places,
+                 lite::Scope* scope = nullptr);
+
+    std::string op_type() const { return op_info()->Type(); }
+    const OpInfo* op_info() const;
+    OpInfo* mutable_op_info();
+
+    void SetKernels(std::vector<std::unique_ptr<KernelBase>>&& kernels) {
+      valid_kernels_ = std::move(kernels);
+    }
+    std::vector<std::unique_ptr<KernelBase>>& kernels() {
+      return valid_kernels_;
+    }
+
+    void ClearSubgraphID() { subgraph_id_ = -1 /* note: not 0 */; }
+    void SetSubgraphID(int id) { subgraph_id_ = id; }
+    int subgraph_id() const { return subgraph_id_; }
+    void SetOp(const std::shared_ptr<OpLite>& op) { op_ = op; }
+    const std::shared_ptr<OpLite> op() const { return op_; }
+
+    Place place() const;
+
+    KernelBase& picked_kernel();
+
+    friend std::ostream& operator<<(std::ostream& os, const Stmt& other);
+
+    // Description.
+    std::string desc;
+
+   protected:
+    // -1 means not in any subgraph; 0 means supported but no id assigned
+    // yet; real ids start from 1.
+    int subgraph_id_{-1};
+  };
+
+  struct Arg {
+    std::string name;
+    int id{0};
+    const Type* type{};
+    // Weight is a special kind of argument; it is marked as a weight
+    // explicitly so that weight-related optimizations can take place.
+    bool is_weight{false};
+    // is_persist indicates whether the argument is derived from a weight;
+    // if more than one auxiliary operator is needed (e.g.
io_copy layout calib), the + // argument between them should be persist to make sure it's only run once + bool is_persist{false}; + }; + + Arg& AsArg(const std::string& name, int id); + + Arg& AsArg(const std::string& name); + + Stmt& AsStmt(const std::string& op_type, + std::vector>&& kernels, + const std::shared_ptr& op) { + auto& x = AsStmt(); + x.SetOp(op); + x.SetKernels(std::move(kernels)); + return x; + } + + Stmt* stmt() const { + CHECK(IsStmt()); + return stmt_.get(); + } + + Arg* arg() const { + CHECK(IsArg()); + return arg_.get(); + } + + // Set roles. + Arg& AsArg() { + if (role_ != Role::kUnk) { + CHECK(role_ == Role::kArg); + return *arg_; + } + role_ = Role::kArg; + arg_.reset(new Arg); + return *arg_; + } + Stmt& AsStmt() { + if (role_ != Role::kUnk) { + CHECK(role_ == Role::kStmt); + return *stmt_; + } + role_ = Role::kStmt; + stmt_.reset(new Stmt); + return *stmt_; + } + + friend std::ostream& operator<<(std::ostream& os, Node& other) { + os << static_cast(other.role_) << " "; + if (!other.IsRoleSet()) { + os << "Unk role node"; + } + if (other.IsArg()) { + auto& arg = other.AsArg(); + os << "Argument " << arg.name; + } + if (other.IsStmt()) { + auto& arg = other.AsStmt(); + os << "Statement " << arg.op_type(); + } + return os; + } + + // Check roles. + bool IsRoleSet() const { return role_ != Role::kUnk; } + bool IsStmt() const { return role_ == Role::kStmt; } + bool IsArg() const { return role_ == Role::kArg; } + + private: + // Either stmt_ or argument_ is used. + std::unique_ptr stmt_; + std::unique_ptr arg_; + Role role_{Role::kUnk}; +}; +} // namespace mir +} // namespace lite +} // namespace paddle diff --git a/lite/core/mir/pass.cc b/lite/core/mir/pass.cc new file mode 100644 index 00000000000..2aaa5a4a171 --- /dev/null +++ b/lite/core/mir/pass.cc @@ -0,0 +1,15 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/core/mir/pass.h" diff --git a/lite/core/mir/pass.h b/lite/core/mir/pass.h new file mode 100644 index 00000000000..bd1ce1412ae --- /dev/null +++ b/lite/core/mir/pass.h @@ -0,0 +1,78 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include +#include +#include "lite/core/mir/node.h" +#include "lite/core/mir/ssa_graph.h" + +namespace paddle { +namespace lite { +namespace mir { + +class Pass { + public: + // Some appoint here, one pass should be only one of the following kinds. 
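+  // A pass declares exactly one Kind; the optimizer relies on it to know
+  // whether the pass may rewrite the graph topology, only the statements,
+  // or nothing at all.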
+ enum class Kind { + // Will modify the program/graph topology. + kProgramWise = 0, + // Will modify the statement, with the graph topology fixed. + kStmtWise, + // Will not modify the IR, just collect information or visualization. + kDebug, + }; + + explicit Pass(Kind kind) : kind_(kind) {} + + virtual void Apply(const std::unique_ptr& graph) = 0; + + void set_name(const std::string& name) { name_ = name; } + const std::string& name() const { return name_; } + + void set_doc(const std::string& doc) { doc_ = doc; } + const std::string& doc() const { return doc_; } + + Kind kind() const { return kind_; } + bool is_debug_pass() const { return kind_ == Kind::kDebug; } + bool is_program_pass() const { return kind_ == Kind::kProgramWise; } + bool is_stmt_pass() const { return kind_ == Kind::kStmtWise; } + + virtual ~Pass() = default; + + private: + const Kind kind_; + std::string name_; + std::string doc_; +}; + +// Different kinds. +class ProgramPass : public Pass { + public: + ProgramPass() : Pass(Kind::kProgramWise) {} +}; + +class StmtPass : public Pass { + public: + StmtPass() : Pass(Kind::kStmtWise) {} +}; + +class DebugPass : public Pass { + public: + DebugPass() : Pass(Kind::kDebug) {} +}; + +} // namespace mir +} // namespace lite +} // namespace paddle diff --git a/lite/core/mir/pass_manager.cc b/lite/core/mir/pass_manager.cc new file mode 100644 index 00000000000..17f81b3bdd1 --- /dev/null +++ b/lite/core/mir/pass_manager.cc @@ -0,0 +1,21 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/core/mir/pass_manager.h" + +namespace paddle { +namespace lite { +namespace mir {} // namespace mir +} // namespace lite +} // namespace paddle diff --git a/lite/core/mir/pass_manager.h b/lite/core/mir/pass_manager.h new file mode 100644 index 00000000000..ca40f2deca1 --- /dev/null +++ b/lite/core/mir/pass_manager.h @@ -0,0 +1,87 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
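Given the three base classes above, a new pass is a small amount of code. A
hedged sketch follows (not part of this patch; `NodeCountPass` and its
registered name are hypothetical, `REGISTER_MIR_PASS` is introduced by
pass_registry.h later in this patch, and the SSAGraph API is the one used
elsewhere in these files):

    #include "lite/core/mir/pass.h"
    #include "lite/core/mir/pass_registry.h"

    namespace paddle {
    namespace lite {
    namespace mir {

    // A DebugPass only inspects the graph (Kind::kDebug), so it can run at
    // any point in the pipeline without invalidating other passes.
    class NodeCountPass : public DebugPass {
     public:
      void Apply(const std::unique_ptr<SSAGraph>& graph) override {
        LOG(INFO) << "graph has " << graph->mutable_nodes().size() << " nodes";
      }
    };

    }  // namespace mir
    }  // namespace lite
    }  // namespace paddle

    REGISTER_MIR_PASS(node_count_pass, paddle::lite::mir::NodeCountPass);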
+ +#pragma once +#include +#include +#include +#include +#include "lite/core/mir/pass.h" + +namespace paddle { +namespace lite { +namespace mir { + +class PassManager { + public: + static PassManager& Global() { + static PassManager x; + return x; + } + + PassManager() {} + + void Run(const std::unique_ptr& graph) { + for (auto& pass : passes_) { + LOG(INFO) << "Running MIR pass " << pass->name(); + pass->Apply(graph); + } + } + + bool AddNewPass(const std::string& name, Pass* pass) { + passes_.emplace_back(pass); + pass_map_.emplace(name, passes_.back().get()); + passes_.back()->set_name(name); + return true; + } + + // Clear all the passes. + void Clear() { passes_.clear(); } + + std::list>::iterator passes_begin() { + return passes_.begin(); + } + std::list>::iterator passes_end() { + return passes_.end(); + } + std::list>::const_iterator passes_const_begin() + const { + return passes_.begin(); + } + std::list>::const_iterator passes_const_end() + const { + return passes_.end(); + } + + Pass* LookUp(const std::string& key) { + auto it = pass_map_.find(key); + if (it != pass_map_.end()) return it->second; + return nullptr; + } + + template + PassTy* LookUp(const std::string& key) { + auto it = pass_map_.find(key); + if (it != pass_map_.end()) return dynamic_cast(it->second); + return nullptr; + } + + private: + std::list> passes_; + std::map pass_map_; +}; + +} // namespace mir +} // namespace lite +} // namespace paddle diff --git a/lite/core/mir/pass_manager_test.cc b/lite/core/mir/pass_manager_test.cc new file mode 100644 index 00000000000..05e11ed5d16 --- /dev/null +++ b/lite/core/mir/pass_manager_test.cc @@ -0,0 +1,33 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/core/mir/pass_manager.h" +#include +#include "lite/core/mir/pass_registry.h" + +namespace paddle { +namespace lite { +namespace mir { + +TEST(PassManager, test) { + auto* pass = PassManager::Global().LookUp("demo"); + LOG(INFO) << "pass: " << pass; + ASSERT_TRUE(pass != nullptr); +} + +} // namespace mir +} // namespace lite +} // namespace paddle + +USE_MIR_PASS(demo); diff --git a/lite/core/mir/pass_registry.cc b/lite/core/mir/pass_registry.cc new file mode 100644 index 00000000000..e80db5d4ca1 --- /dev/null +++ b/lite/core/mir/pass_registry.cc @@ -0,0 +1,21 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
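For orientation, a short sketch of how client code might drive the manager
(assuming only the PassManager API above; the pass name is the one registered
by the graph visualize pass earlier in this patch):

    #include "lite/core/mir/pass_manager.h"

    void RunPasses(const std::unique_ptr<paddle::lite::mir::SSAGraph>& graph) {
      auto& mgr = paddle::lite::mir::PassManager::Global();
      mgr.Run(graph);  // applies every registered pass in registration order
      // Alternatively, look a single pass up by name and apply just that one.
      if (auto* pass = mgr.LookUp("graph_visualize")) {
        pass->Apply(graph);
      }
    }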
+ +#include "lite/core/mir/pass_registry.h" + +namespace paddle { +namespace lite { +namespace mir {} // namespace mir +} // namespace lite +} // namespace paddle diff --git a/lite/core/mir/pass_registry.h b/lite/core/mir/pass_registry.h new file mode 100644 index 00000000000..6144ea2c24a --- /dev/null +++ b/lite/core/mir/pass_registry.h @@ -0,0 +1,44 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include "lite/api/paddle_lite_factory_helper.h" +#include "lite/core/mir/pass_manager.h" + +namespace paddle { +namespace lite { +namespace mir { + +class PassRegistry { + public: + PassRegistry(const std::string& name, mir::Pass* pass) { + VLOG(2) << "Registry add MIR pass " << name; + PassManager::Global().AddNewPass(name, pass); + } + + bool Touch() const { return true; } +}; + +} // namespace mir +} // namespace lite +} // namespace paddle + +#define REGISTER_MIR_PASS(name__, class__) \ + paddle::lite::mir::PassRegistry mir_pass_registry##name__(#name__, \ + new class__); \ + bool mir_pass_registry##name__##_fake() { \ + return mir_pass_registry##name__.Touch(); \ + } diff --git a/lite/core/mir/pattern_matcher.cc b/lite/core/mir/pattern_matcher.cc new file mode 100644 index 00000000000..ade4c26008f --- /dev/null +++ b/lite/core/mir/pattern_matcher.cc @@ -0,0 +1,527 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include + +#include "lite/core/mir/dot.h" +#include "lite/core/mir/pattern_matcher.h" +#include "lite/core/op_lite.h" +#include "lite/utils/string.h" + +namespace paddle { +namespace lite { +namespace mir { + +size_t PMPattern::id_ = 0UL; + +PMNode &PMNode::operator>>(PMNode &right) { + pattern_->AddEdge(this, &right); + // automatically add out op link relation. 
+ if (right.IsOp()) { + CHECK(!right.op_type_.empty()); + this->assert_is_op_input(right.op_type_); + } + + return right; +} + +PMNode &PMNode::operator>>(std::vector &nodes) { + for (auto *node : nodes) { + *this >> *node; + } + return *this; +} + +PMNode &operator>>(std::vector &others, PMNode &me) { + for (auto *o : others) { + *o >> me; + } + return me; +} + +PMNode *PMPattern::NewNode(const std::string &name) { + if (!name.empty()) { + CHECK_EQ(node_map_.count(name), 0UL) + << "PMNode's name should be unique, get duplicate " << name; + } + + nodes_.emplace_back(new PMNode(this, name)); + auto *cur = nodes_.back().get(); + node_map_[name] = cur; + return cur; +} + +PMNode *PMPattern::NewNode(PMNode::teller_t &&teller, const std::string &name) { + if (!name.empty()) { + CHECK_EQ(node_map_.count(name), 0UL) + << "PMNode's name should be unique, get duplicate " << name; + } + + nodes_.emplace_back(new PMNode(std::move(teller), this, name)); + auto *cur = nodes_.back().get(); + node_map_[name] = cur; + return cur; +} + +PMNode *PMPattern::RetrieveNode(const std::string &id) const { + auto it = node_map_.find(id); + if (it == node_map_.end()) { + return nullptr; + } + + return it->second; +} + +void PMPattern::AddEdge(PMNode *a, PMNode *b) { + CHECK(a); + CHECK(b); + CHECK_NE(a, b) << "Can't connect to the same nodes."; + edges_.emplace_back(a, b); +} + +void PatternMatcher::operator()(SSAGraph *graph, + PatternMatcher::handle_t handler) { + if (!MarkPMNodesInGraph(graph)) { + return; + } + + auto subgraphs = DetectPatterns(); + UniquePatterns(&subgraphs); + RemoveOverlappedMatch(&subgraphs); + ValidateByNodeRole(&subgraphs); + + if (subgraphs.empty()) return; + int id = 0; + for (auto &g : subgraphs) { + VLOG(3) << "optimizing #" << id++ << " subgraph"; + handler(g, graph); + } +} + +bool PatternMatcher::MarkPMNodesInGraph(SSAGraph *graph) { + VLOG(3) << "mark pmnodes in graph"; + if (graph->nodes().empty()) return false; + for (auto &node : graph->mutable_nodes()) { + for (const auto &pmnode : pattern_.nodes()) { + if (pmnode->Tell(&node)) { + pmnodes2nodes_[pmnode.get()].insert(&node); + } + } + } + // Check to early stop if some PMNode can't find matched Node. + for (auto &pmnode : pattern_.nodes()) { + if (!pmnodes2nodes_.count(pmnode.get())) { + VLOG(4) << pmnode->name() << " can't find matched Node, early stop"; + // return false; + } + } + VLOG(3) << pmnodes2nodes_.size() << " nodes marked"; + + return !pmnodes2nodes_.empty(); +} + +// The intermediate Nodes can only link to the nodes inside the pattern, or this +// subgraph will be droped. +void PatternMatcher::ValidateByNodeRole( + std::vector *subgraphs) { + std::vector result; + + subgraphs->erase( + std::remove_if(subgraphs->begin(), + subgraphs->end(), + [](const PatternMatcher::subgraph_t &subgraph) -> bool { + // Collect the inlinks and outlinks. 
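+                  // `ios` collects every graph node matched by this
+                  // subgraph; an intermediate node that links to anything
+                  // outside this set disqualifies the whole match.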
+ std::unordered_set ios; + for (auto &item : subgraph) { + ios.insert(item.second); + } + for (auto &item : subgraph) { + if (item.first->IsIntermediate()) { + for (auto *x : item.second->inlinks) { + if (!ios.count(x)) { + return true; + } + } + for (auto *x : item.second->outlinks) { + if (!ios.count(x)) { + return true; + } + } + } + } + return false; + }), + subgraphs->end()); +} + +struct HitGroup { + std::unordered_map roles; + + bool Match(Node *node, PMNode *pat) { + if (nodes_.count(node)) { + if (roles.count(pat) && roles[pat] == node) return true; + return false; + } else { + if (roles.count(pat) && roles[pat] != node) return false; + return true; + } + } + + void Register(Node *node, PMNode *pat) { + roles[pat] = node; + nodes_.insert(node); + } + + private: + std::unordered_set nodes_; +}; + +// Tell whether Node a links to b. +bool IsNodesLink(Node *a, Node *b) { + for (auto *node : a->outlinks) { + if (b == node) { + return true; + } + } + return false; +} + +std::vector PatternMatcher::DetectPatterns() { + // Init empty subgraphs. + std::vector result; + std::vector init_groups; + std::array, 2> bi_records; + auto *first_pnode = pattern_.edges().empty() ? pattern().nodes().front().get() + : pattern_.edges().front().first; + if (!pmnodes2nodes_.count(first_pnode)) return result; + for (auto *node : pmnodes2nodes_[first_pnode]) { + HitGroup group; + group.roles[first_pnode] = node; + init_groups.emplace_back(group); + } + + int step = 0; + bi_records[0] = std::move(init_groups); + + // Extend a PMNode to subgraphs by deducing the connection relations defined + // in edges of PMNodes. + for (const auto &edge : pattern_.edges()) { + VLOG(4) << "check " << edge.first->name() << " -> " << edge.second->name(); + // TODO(Superjomn) Fix bug here, the groups might be duplicate here. + // Each role has two PMNodes, which indicates two roles. + // Detect two Nodes that can match these two roles and they are connected. + auto &pre_groups = bi_records[step % 2]; + auto &cur_groups = bi_records[1 - (step++ % 2)]; + cur_groups.clear(); + if (pre_groups.empty()) break; + // source -> target + for (Node *source : pmnodes2nodes_[edge.first]) { + for (Node *target : pmnodes2nodes_[edge.second]) { + // TODO(Superjomn) add some prune strategies. 
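+        // Try to extend every partial match with the edge (source, target);
+        // only groups whose existing role bindings stay consistent survive.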
+ for (const auto &group : pre_groups) { + if (IsNodesLink(source, target)) { + HitGroup new_group = group; + bool flag = new_group.Match(source, edge.first) && + new_group.Match(target, edge.second); + if (flag) { + new_group.Register(source, edge.first); + new_group.Register(target, edge.second); + cur_groups.push_back(new_group); + // TODO(Superjomn) need to unique + } + } + } + } + } + VLOG(3) << "step " << step << " get records: " << cur_groups.size(); + } + + for (auto &group : bi_records[step % 2]) { + PatternMatcher::subgraph_t subgraph; + for (auto &role : group.roles) { + subgraph.emplace(role.first, role.second); + } + result.emplace_back(subgraph); + } + return result; +} + +struct GraphItemLessThan { + bool operator()(const std::pair &a, + const std::pair &b) { + if (a.first != b.first) { + return a.first < b.first; + } else { + return a.second < b.second; + } + } +}; + +// TODO(Superjomn) enhance the function as it marks unique unique as duplicates +// see https://github.com/PaddlePaddle/Paddle/issues/13550 +void PatternMatcher::UniquePatterns( + std::vector *subgraphs) { + if (subgraphs->empty()) return; + std::vector result; + + std::unordered_set set; + std::hash hasher; + for (auto &g : *subgraphs) { + // Sort the items in the sub-graph, and transform to a string key. + std::vector> sorted_keys(g.begin(), g.end()); + std::sort(sorted_keys.begin(), sorted_keys.end(), GraphItemLessThan()); + STL::stringstream ss; + for (auto &item : sorted_keys) { + ss << reinterpret_cast(item.first) << ":" + << reinterpret_cast(item.second); + } + auto key = hasher(ss.str()); + if (!set.count(key)) { + result.emplace_back(g); + set.insert(key); + } + } + *subgraphs = result; +} + +void PatternMatcher::RemoveOverlappedMatch(std::vector *subgraphs) { + std::vector result; + std::unordered_set node_set; + + for (const auto &subgraph : *subgraphs) { + bool valid = true; + for (auto &item : subgraph) { + if (item.first->IsIntermediate() && node_set.count(item.second)) { + valid = false; + break; + } + } + if (valid) { + for (auto &item : subgraph) { + node_set.insert(item.second); + } + result.push_back(subgraph); + } + } + *subgraphs = result; +} + +std::string PMPattern::DotString() const { + using inference::analysis::Dot; + Dot dot; + int id = 0; + // Create Nodes + std::unordered_map node2dot; + for (const auto &node : nodes()) { + std::string node_id = string_format("Node%d", id++); + dot.AddNode(node_id, {}, node->name()); + node2dot[node.get()] = node_id; + } + // Create Edges + for (const auto &edge : edges()) { + if (!node2dot.count(edge.first) || !node2dot.count(edge.second)) { + continue; + } + auto &src = node2dot.at(edge.first); + auto &trg = node2dot.at(edge.second); + dot.AddEdge(src, trg, {}); + } + return dot.Build(); +} + +PMNode &PMNode::LinksTo(const std::vector &others) { + // extend outlinks. + for (PMNode *x : others) { + pattern_->AddEdge(this, x); + } + return *this; +} + +PMNode &PMNode::LinksFrom(const std::vector &others) { + // extend outlinks. 
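+  // Note the direction: each edge runs from a node in `others` into this
+  // node, so this call extends the node's inlinks.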
+ for (PMNode *x : others) { + pattern_->AddEdge(x, this); + } + return *this; +} + +PMNode *PMNode::assert_is_op() { + asserts_.emplace_back([](const Node *x) { return x && x->IsStmt(); }); + return this; +} + +PMNode *PMNode::assert_is_op(const std::string &op_type) { + asserts_.emplace_back([op_type](const Node *x) { + if (x && x->IsStmt()) { + auto *op_info = x->stmt()->op_info(); + return op_info->Type() == op_type; + } else { + return false; + } + }); + return this; +} + +PMNode *PMNode::assert_is_var() { + asserts_.emplace_back([](const Node *x) { return x && x->IsArg(); }); + return this; +} + +PMNode *PMNode::assert_var_not_persistable() { + assert_is_var(); + asserts_.emplace_back([](const Node *x) { return !x->arg()->is_weight; }); + return this; +} + +PMNode *PMNode::assert_is_persistable_var() { + assert_is_var(); + asserts_.emplace_back([=](const Node *x) { return x->arg()->is_weight; }); + return this; +} + +PMNode *PMNode::assert_is_op_output(const std::string &op_type) { + assert_is_var(); + asserts_.emplace_back([=](const Node *x) { + for (auto *op : x->inlinks) { + if (op && op->IsStmt()) { + auto *op_info = op->stmt()->op_info(); + if (op_info->Type() == op_type) return true; + } + } + return false; + }); + return this; +} + +bool IsNthOutput(const Node *var, + const Node *op, + const std::string &argument, + size_t nth) { + CHECK(var->IsArg()); + CHECK(op->IsStmt()); + auto op_info = op->stmt()->op_info(); + if (op_info->Output(argument).size() <= nth) return false; + return var->arg()->name == op_info->Output(argument)[nth]; +} + +bool IsNthInput(const Node *var, + const Node *op, + const std::string &argument, + size_t nth) { + CHECK(var->IsArg()); + CHECK(op->IsStmt()); + auto op_info = op->stmt()->op_info(); + if (op_info->Input(argument).size() <= nth) return false; + return var->arg()->name == op_info->Input(argument)[nth]; +} + +PMNode *PMNode::assert_is_op_input(const std::string &op_type, + const std::string &argument) { + assert_is_var(); + assert_is_op_nth_input(op_type, argument, 0); + return this; +} + +PMNode *PMNode::assert_is_op_nth_input(const std::string &op_type, + const std::string &argument, + int nth) { + assert_is_var(); + assert_is_op_input(op_type); + asserts_.emplace_back([=](const Node *x) { + for (auto *op : x->outlinks) { + if (op && op->IsStmt() && op->stmt()->op_info()->Type() == op_type && + IsNthInput(x, op, argument, nth)) + return true; + } + return false; + }); + return this; +} + +PMNode *PMNode::assert_is_op_output(const std::string &op_type, + const std::string &argument) { + assert_is_var(); + assert_is_op_nth_output(op_type, argument, 0); + return this; +} + +PMNode *PMNode::assert_is_op_nth_output(const std::string &op_type, + const std::string &argument, + int nth) { + assert_is_var(); + asserts_.emplace_back([=](const Node *x) { + for (auto *op : x->inlinks) { + if (op && op->IsStmt() && op->stmt()->op_info()->Type() == op_type && + IsNthOutput(x, op, argument, nth)) + return true; + } + return false; + }); + return this; +} + +PMNode *PMNode::assert_is_op_input(const std::string &op_type) { + assert_is_var(); + asserts_.emplace_back([=](const Node *x) { + for (auto *op : x->outlinks) { + if (op && op->IsStmt()) { + auto *op_info = op->stmt()->op_info(); + if (op_info->Type() == op_type) { + return true; + } + } + } + return false; + }); + return this; +} + +bool HasInput(const Node &op, const std::string &argument) { + CHECK(op.IsStmt()); + auto const &names = op.stmt()->op_info()->input_argnames(); + if 
(std::find(names.begin(), names.end(), argument) == names.end()) + return false; + return true; +} + +void GraphSafeRemoveNodes(SSAGraph *graph, + const std::unordered_set &nodes) { + for (auto *node : nodes) { + graph->RemoveNode(node); + } + + for (auto &node : graph->mutable_nodes()) { + for (auto it = node.inlinks.begin(); it != node.inlinks.end();) { + if (nodes.count(*it)) { + it = node.inlinks.erase(it); + } else { + it++; + } + } + for (auto it = node.outlinks.begin(); it != node.outlinks.end();) { + if (nodes.count(*it)) { + it = node.outlinks.erase(it); + } else { + it++; + } + } + } +} + +} // namespace mir +} // namespace lite +} // namespace paddle diff --git a/lite/core/mir/pattern_matcher.h b/lite/core/mir/pattern_matcher.h new file mode 100644 index 00000000000..112ff37564d --- /dev/null +++ b/lite/core/mir/pattern_matcher.h @@ -0,0 +1,424 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#ifdef PADDLE_WITH_TESTING +#include +#endif + +#include +#include +#include +#include +#include +#include +#include +#include "lite/core/mir/node.h" +#include "lite/core/mir/ssa_graph.h" +#include "lite/model_parser/pb/op_desc.h" +#include "lite/utils/cp_logging.h" +#include "lite/utils/replace_stl/stream.h" +#include "lite/utils/string.h" + +namespace paddle { +namespace lite { +namespace mir { +class PMPattern; + +// Some basic terminologies: +// - PMPattern: a pattern defined as a data flow graph. +// - PMNode: the node in the pattern, each PMNode represents an `mir::Node` +// that meets some conditions defined in `PMNode.teller`. +// - A pattern is defined with PMNodes with edges. + +// Pattern matcher node. This node helps to build a pattern. +struct PMNode { + // tell whether an mir::Node* is a candidation for a PMNode. + using teller_t = std::function; + enum class Type { kOp, kVar }; + enum class Role { + kUnknown, // No role, + kInput, // an input and will be retained, + kOutput, // an output and will be retained, + kIntermediate // will be removed after handler. + }; + + // this link to others + PMNode& LinksTo(const std::vector& others); + PMNode& LinksFrom(const std::vector& others); + + // Link this to another node. + PMNode& operator>>(PMNode& right); + + // Link many nodes to this node. + friend PMNode& operator>>(std::vector& others, PMNode& me); + + // Link this to many other nodes. + PMNode& operator>>(std::vector& nodes); + + bool Tell(const Node* node) const { + if (teller_) return teller_(node); + + for (auto& asrt : asserts_) { + if (!asrt(node)) return false; + } + return true; + } + + bool IsOp() const { return type_ == Type::kOp; } + bool IsVar() const { return type_ == Type::kVar; } + + const std::string& name() const { return name_; } + + PMNode& operator=(const PMNode&) = delete; + PMNode(const PMNode&) = delete; + + // Mark this node is an Input of a subgraph and will be retained. 
+ PMNode* AsInput() { + role_ = Role::kInput; + return this; + } + // Mark this node is an Output of a subgraph and will be retained. + PMNode* AsOutput() { + role_ = Role::kOutput; + return this; + } + // Mark this node will be removed, so all the links should be inside a matched + // sub-graph. + PMNode* AsIntermediate() { + role_ = Role::kIntermediate; + return this; + } + + PMNode* AsVar() { + type_ = Type::kVar; + assert_is_var(); + return this; + } + + PMNode* AsOp(const std::string& op_type) { + type_ = Type::kOp; + assert_is_op(op_type); + return this; + } + + void set_op_type(const std::string& op_type) { op_type_ = op_type; } + + bool IsIntermediate() const { return role_ == Role::kIntermediate; } + bool IsInput() const { return role_ == Role::kInput; } + bool IsOutput() const { return role_ == Role::kOutput; } + + // Assertions, helper functions to simplify the pattern definition. + PMNode* assert_is_op(); + PMNode* assert_is_op(const std::string& op_type); + PMNode* assert_is_var(); + PMNode* assert_var_not_persistable(); + PMNode* assert_is_persistable_var(); + PMNode* assert_is_op_output(const std::string& op_type); + PMNode* assert_is_op_input(const std::string& op_type); + PMNode* assert_is_op_input(const std::string& op_type, + const std::string& argument); + PMNode* assert_is_op_output(const std::string& op_type, + const std::string& argument); + + PMNode* assert_is_op_nth_input(const std::string& op_type, + const std::string& argument, + int nth); + PMNode* assert_is_op_nth_output(const std::string& op_type, + const std::string& argument, + int nth); + + template + PMNode* assert_op_attr(const std::string& attr_name, const T& attr) { + asserts_.push_back([=](const Node* x) { + if (x && x->IsStmt()) { + auto* op_info = x->stmt()->op_info(); + return op_info->HasAttr(attr_name) && + op_info->GetAttr(attr_name) == attr; + } + return false; + }); + return this; + } + + private: + PMNode(PMPattern* pattern, + const std::string& name = "", + Type type = Type::kVar) + : pattern_(pattern), name_(name), type_(type) {} + PMNode(teller_t&& teller, + PMPattern* pattern, + const std::string& name = "", + Type type = Type::kVar) + : teller_(std::move(teller)), + pattern_(pattern), + name_(name), + type_(type) { + CHECK(teller_ != nullptr) << "invalid teller functer is set."; + } + + PMNode(PMNode&& other) = default; + + friend class PMPattern; + + // Will removed latter. + teller_t teller_; + std::vector asserts_; + PMPattern* pattern_; + std::string name_; + std::string op_type_; + Type type_; + Role role_{Role::kUnknown}; +}; + +/* + * A pattern in a graph, which defined with PMNode and edges. Most graph + * patterns can be divided into PMNodes and link relations between them. + * + * For example, the FC fusion need to filter the MUL and ELEMENTWISE_ADD + * operators from the computation graph, the MUL's output should have only one + * consumer which is the ELEMENTWISE_ADD. + * This pattern can be defined as with the following pseudo codes + * + * // Create two operator PMNodes. + * MUL = PMPattern.NewNode().assert_is_op("mul"); + * ELE = PMPattern.NewNode().assert_is_op("elementwise_add"); + * // Create the variable PMNodes. + * MUL_out = PMPattern.NewNode().assert_is_op_output("mul") \ + * .assert_is_op_input("elementwise_add") \ + * .AsIntermediate(); + * // Add relations. + * MUL->LinksTo({MUL_out}); + * MUL_out->LinksTo({ELE}); + * + * One can add more specific asserts for PMNodes or edges, both the Operator + * and Variable Nodes can be ruled in PMNode.assert_more(...). 
+ * + * PMPattern can record the general patterns, such as the pattern represents + * - Op in CPU -> Op in GPU -> Op in CPU, to findout the IO abnormal place. + * - Ops whose inputs and outputs share the same variables + */ +class PMPattern { + public: + using edge_t = std::pair; + + void AddEdge(PMNode* a, PMNode* b); + + PMNode* NewNode(PMNode::teller_t&& teller, const std::string& name = NewID()); + PMNode* NewNode(const std::string& name = NewID()); + PMNode* NewNode(const std::string& prefix, const std::string& name) { + return NewNode(prefix + "/" + name); + } + PMNode* RetrieveNode(const std::string& id) const; + + const std::vector>& nodes() const { return nodes_; } + const std::vector& edges() const { return edges_; } + + std::string DotString() const; + + private: +#ifdef PADDLE_WITH_TESTING + FRIEND_TEST(PMPattern, AddEdge); + FRIEND_TEST(PMPattern, NewNode); +#endif + + static std::string NewID() { return string_format("pmnode-%d", id_++); } + + std::vector> nodes_; + std::vector edges_; + std::unordered_map node_map_; + static size_t id_; +}; + +/* + * PatternMatcher helps to detect the specific patterns in the graph. + * Input a pattern, output a list of the matched subgraphs/nodes. + * This helper can be used to support fuse(conv+batchnorm => batchnorm e.g.). + * + * The algorithm has three phases: + * 1. Mark the nodes that match the defined PMNodes in a PMPattern, + * 2. Extend a PMNode to subgraphs by deducing the connection relation defined + * in PAPattern(the edges), + * 3. Get the filtered subgraphs and treat them with a pre-defined handler. + * + * Usage: + * // Create a matcher + * PatternMatcher matcher; + * // Define the matcher's pattern, by adding PMNode and define the edges. + * auto* node0 = matcher.mutable_pattern().AddNode(...) + * auto* node1 = matcher.mutable_pattern().AddNode(...) + * node0->teller = some lambda. + * node1->teller = some lambda. + * matcher.mutable_pattern().AddEdge(node0, node1); + * // Create an handler, to define the behavior of treating the filtered + * // subgraphs that comply with the patterns. + * PatternMatcher::handle_t handler = some labmda + * // Execute the matcher. + * matcher(&graph, handler); + */ +class PatternMatcher { + public: + using subgraph_t = std::unordered_map; + + // Operate on the detected pattern. + using handle_t = + std::function; + + void operator()(SSAGraph* graph, handle_t handler); + + const PMPattern& pattern() const { return pattern_; } + PMPattern* mutable_pattern() { return &pattern_; } + + private: + // Mark the nodes that fits the pattern. + bool MarkPMNodesInGraph(SSAGraph* graph); + + // Detect all the pattern and output the hit records. + std::vector DetectPatterns(); + + // Remove duplicate patterns. + void UniquePatterns(std::vector* subgraphs); + + // Remove overlapped match subgraphs, when overlapped, keep the previous one. + // The intermediate PMNodes will be removed, so can't shared by multiple + // patterns. + void RemoveOverlappedMatch(std::vector* subgraphs); + + // Validate whether the intermediate nodes are linked by external nodes. + void ValidateByNodeRole(std::vector* subgraphs); + +#ifdef PADDLE_WITH_TESTING + FRIEND_TEST(PatternMatcher, MarkPMNodesInGraph); + FRIEND_TEST(PatternMatcher, DetectPatterns); +#endif + + private: + using hit_rcd_t = + std::pair; + PMPattern pattern_; + std::unordered_map> pmnodes2nodes_; +}; + +// Check whether a var node is a op node's nth input. 
+bool IsNthInput(const Node& var, + const Node& op, + const std::string& argument, + int nth); + +// Check whether the op node has input of given name. +bool HasInput(const Node& op, const std::string& argument); + +// Graph safely remove some nodes, will automatically clean up the edges. +void GraphSafeRemoveNodes(SSAGraph* graph, + const std::unordered_set& nodes); + +// Some pre-defined patterns those can be reused in multiple passes. +// The related Fluid Layer or Op should be one pattern here for better re-usage +// across different fusion. +namespace patterns { + +struct KeyCounter { + static KeyCounter& Instance() { + static KeyCounter x; + return x; + } + + int IncCounter(const std::string& key) { return dic_[key]++; } + + private: + std::unordered_map dic_; +}; + +// Generate a unique PMNode's name with name_scope and id. +// The format is {name_scope}/{repr}/{id}/{name} +static std::string PMNodeName(const std::string& name_scope, + const std::string& repr, + size_t id, + const std::string& name) { + STL::stringstream ss; + ss << name_scope << "/" << repr << "/" << id << "/" << name; + return ss.str(); +} +// Generate a unique PMNode's name. +// The format is {name_scope}/{repr}/{id} +static std::string PMNodeName(const std::string& name_scope, + const std::string& repr) { + STL::stringstream ss; + ss << name_scope << "/" << repr << "/" + << KeyCounter::Instance().IncCounter(repr); + return ss.str(); +} +// Generate a unique key. It can be used for a universally unique temporary +// name. +// The format is {repr}/{id} +static std::string UniqueKey(const std::string& repr) { + STL::stringstream ss; + ss << repr << "/" << KeyCounter::Instance().IncCounter(repr); + return ss.str(); +} + +// Declare a PMNode in a pattern, will create two methods: +// std::string xxx_repr(); return this PMNode's string id. +// PMNode* xxx_n(); return the corresponding PMNode. +#define PATTERN_DECL_NODE(name__) \ + std::string name__##_repr() const { \ + return PMNodeName(name_scope_, repr_, id_, #name__); \ + } \ + PMNode* name__##_n() const { return pattern->RetrieveNode(name__##_repr()); } + +// Get an mir::Node* from the matched subgraph. +// var: variable. +// arg: the argument declared by PATTERN_DECL_NODE in a pattern definition. +// pat: the pattern object. +#define GET_IR_NODE_FROM_SUBGRAPH(var, arg, pat) \ + CHECK(subgraph.count(pat.arg##_n())) \ + << "Node not found for PMNode " pat.arg##_repr(); \ + Node* var = subgraph.at(pat.arg##_n()); \ + CHECK(var) << "node " << #arg << "not exists in the sub-graph" + +// The base class of all the patterns. +struct PatternBase { + PatternBase(PMPattern* pattern, + const std::string& name_scope, + const std::string& repr) + : pattern(pattern), + name_scope_(name_scope), + repr_(repr), + id_(KeyCounter::Instance().IncCounter(repr)) {} + + PMPattern* pattern; + + protected: + std::string name_scope_; + std::string repr_; + size_t id_; +}; + +} // namespace patterns + +// Link two mir::Nodes from each other. 
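+// Note: both macros below only edit the adjacency lists; ownership of the
+// linked nodes stays with the SSAGraph.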
+#define IR_NODE_LINK_TO(a, b) \ + a->outlinks.push_back(b); \ + b->inlinks.push_back(a); + +// Set the out_var as the output of the op +#define IR_OP_VAR_LINK(op, out_var) \ + op->outlinks.push_back(out_var); \ + out_var->inlinks.clear(); \ + out_var->inlinks.push_back(op); + +} // namespace mir +} // namespace lite +} // namespace paddle diff --git a/lite/core/mir/pattern_matcher_high_api.cc b/lite/core/mir/pattern_matcher_high_api.cc new file mode 100644 index 00000000000..620f4ebbea6 --- /dev/null +++ b/lite/core/mir/pattern_matcher_high_api.cc @@ -0,0 +1,80 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/core/mir/pattern_matcher_high_api.h" +#include "lite/utils/cp_logging.h" + +namespace paddle { +namespace lite { +namespace mir { + +void FuseBase::PerformPatternMatcher(SSAGraph *graph) { + VLOG(4) << "\n" << matcher_.pattern().DotString(); + // Get subgraphs and record the mir::Node pointers for each PMNode. + auto handler = [&](const PatternMatcher::subgraph_t &subgraph, SSAGraph *g) { + // get all the reigistered nodes. + key2nodes_.emplace_back(); + for (auto &item : nodes_) { + key2nodes_.back()[item.first] = subgraph.at(item.second); + } + }; + + matcher_(graph, handler); +} + +void FuseBase::DeleteInterNodes(SSAGraph *graph) { + std::set keys; + for (auto &node : nodes_) { + if (node.second->IsIntermediate()) { + keys.insert(node.first); + } + } + + VLOG(4) << "keys: " << key2nodes_.size(); + std::unordered_set nodes2rm; + for (auto &matched : key2nodes_) { + for (const auto &key : keys) { + nodes2rm.insert(matched.at(key)); + } + } + + VLOG(3) << "clean nodes " << nodes2rm.size(); + GraphSafeRemoveNodes(graph, nodes2rm); +} + +PMNode *FuseBase::GetOrCreateNode(const std::string &key) { + auto it = nodes_.find(key); + if (it != nodes_.end()) { + return it->second; + } + nodes_.emplace(key, + matcher_.mutable_pattern()->NewNode(patterns::UniqueKey(key))); + it = nodes_.find(key); + return it->second; +} + +PMNode *FuseBase::OpNode(const std::string &key, const std::string &op_type) { + GetOrCreateNode(key)->set_op_type(op_type); + GetOrCreateNode(key)->AsOp(op_type); + return GetOrCreateNode(key); +} + +PMNode *FuseBase::VarNode(const std::string &key) { + GetOrCreateNode(key)->AsVar(); + return GetOrCreateNode(key); +} + +} // namespace mir +} // namespace lite +} // namespace paddle diff --git a/lite/core/mir/pattern_matcher_high_api.h b/lite/core/mir/pattern_matcher_high_api.h new file mode 100644 index 00000000000..e62a4fc7494 --- /dev/null +++ b/lite/core/mir/pattern_matcher_high_api.h @@ -0,0 +1,83 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include +#include +#include +#include +#include +#include +#include +#include "lite/core/mir/node.h" +#include "lite/core/mir/pattern_matcher.h" +#include "lite/core/mir/ssa_graph.h" + +namespace paddle { +namespace lite { +namespace mir { + +class FuseBase { + public: + using key2nodes_t = std::map; + + virtual ~FuseBase() = default; + + void operator()(SSAGraph* graph) { + BuildPattern(); + PerformPatternMatcher(graph); + + for (const auto& matched : key2nodes_) { + InsertNewNode(graph, matched); + } + + DeleteInterNodes(graph); + } + + // Build a PMPattern using PMNode. + virtual void BuildPattern() = 0; + + // Generate an operator desc with a matched subgraph. + virtual cpp::OpDesc GenOpDesc(const key2nodes_t& matched) { + return cpp::OpDesc(); + } + + PMNode* OpNode(const std::string& key) { + return GetOrCreateNode(key)->assert_is_op(); + } + + PMNode* OpNode(const std::string& key, const std::string& op_type); + + PMNode* VarNode(const std::string& key); + + protected: + virtual void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) = 0; + + private: + void PerformPatternMatcher(SSAGraph* graph); + + // Delete nodes that are marked as Intermediate + void DeleteInterNodes(SSAGraph* graph); + + PMNode* GetOrCreateNode(const std::string& key); + + protected: + PatternMatcher matcher_; + std::map nodes_; + std::vector key2nodes_; +}; + +} // namespace mir +} // namespace lite +} // namespace paddle diff --git a/lite/core/mir/pattern_matcher_high_api_test.cc b/lite/core/mir/pattern_matcher_high_api_test.cc new file mode 100644 index 00000000000..61914c5a0b0 --- /dev/null +++ b/lite/core/mir/pattern_matcher_high_api_test.cc @@ -0,0 +1,150 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/core/mir/pattern_matcher_high_api.h" +#include +#include +#include "lite/core/mir/graph_visualize_pass.h" +#include "lite/core/program.h" +#include "lite/core/tensor.h" +#include "paddle/fluid/framework/program_desc.h" + +namespace paddle { +namespace lite { +namespace mir { + +// An demo. +class FcFuser : public FuseBase { + public: + void BuildPattern() override { + // create nodes. + auto* x = VarNode("x")->assert_is_op_input("mul", "X"); + auto* W = VarNode("W")->assert_is_op_input("mul", "Y"); + auto* b = VarNode("b"); + auto* mul = OpNode("mul", "mul"); + auto* mul_out = VarNode("mul_out"); + auto* add = OpNode("add", "elementwise_add"); + auto* Out = VarNode("Out"); + + // create topology. 
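+    // operator>> (defined in pattern_matcher.cc) turns these chains into
+    // pattern edges; a vector on the left fans in, so the two statements
+    // below wire {W, x} -> mul -> mul_out and {mul_out, b} -> add -> Out.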
+ std::vector mul_inputs{W, x}; + std::vector add_inputs{mul_out, b}; + mul_inputs >> *mul >> *mul_out; + add_inputs >> *add >> *Out; + + // Some op specialities. + mul_out->AsIntermediate(); + mul->AsIntermediate(); + add->AsIntermediate(); + } + + void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override { + auto op_desc = GenOpDesc(matched); + auto fc_op = LiteOpRegistry::Global().Create("fc"); + auto mul = matched.at("mul")->stmt()->op(); + auto* scope = mul->scope(); + auto& valid_places = mul->valid_places(); + fc_op->Attach(op_desc, scope); + + auto* new_op_node = graph->GraphCreateInstructNode(fc_op, valid_places); + + IR_NODE_LINK_TO(matched.at("W"), new_op_node); + IR_NODE_LINK_TO(matched.at("x"), new_op_node); + IR_NODE_LINK_TO(matched.at("b"), new_op_node); + IR_NODE_LINK_TO(new_op_node, matched.at("Out")); + } + + private: + cpp::OpDesc GenOpDesc(const key2nodes_t& matched) override { + cpp::OpDesc op_desc; + op_desc.SetType("fc"); + op_desc.SetInput("Input", {matched.at("x")->arg()->name}); + op_desc.SetInput("W", {matched.at("W")->arg()->name}); + op_desc.SetInput("Bias", {matched.at("b")->arg()->name}); + op_desc.SetOutput("Out", {matched.at("Out")->arg()->name}); + op_desc.SetAttr("in_num_col_dims", 1); + return op_desc; + } +}; + +std::unique_ptr BuildGraph(framework::ProgramDesc* program_desc, + const std::shared_ptr& scope, + const std::vector& valid_places) { + auto* main_block = program_desc->MutableBlock(0); + auto* mul = main_block->AppendOp(); + auto* add = main_block->AppendOp(); + main_block->Var("x"); + main_block->Var("b"); + main_block->Var("mul_out"); + main_block->Var("w"); + main_block->Var("out"); + + scope->Var("x")->GetMutable(); + scope->Var("b")->GetMutable(); + scope->Var("mul_out")->GetMutable(); + scope->Var("w")->GetMutable(); + scope->Var("out")->GetMutable(); + + mul->SetInput("X", {"x"}); + mul->SetInput("Y", {"w"}); + mul->SetOutput("Out", {"mul_out"}); + mul->SetType("mul"); + mul->SetAttr("x_num_col_dims", 1); + mul->SetAttr("y_num_col_dims", 1); + + add->SetInput("X", {"mul_out"}); + add->SetInput("Y", {"b"}); + add->SetOutput("Out", {"out"}); + add->SetType("elementwise_add"); + add->SetAttr("axis", 1); + + program_desc->Flush(); + + lite::Program program(*program_desc->Proto(), scope, valid_places); + auto graph = std::unique_ptr(new SSAGraph()); + graph->Build(program, valid_places); + + return graph; +} + +TEST(pattern_matcher_high_api, graph_test) { + framework::ProgramDesc program_desc; + std::vector places{{TARGET(kHost), PRECISION(kFloat)}}; + auto scope = std::make_shared(); + auto graph = BuildGraph(&program_desc, scope, places); + + ASSERT_EQ(graph->nodes().size(), 7UL /*real nodes*/); + Visualize(graph.get()); +} + +TEST(pattern_matcher_high_api, fuse_test) { + framework::ProgramDesc program_desc; + std::vector places{{TARGET(kHost), PRECISION(kFloat)}}; + auto scope = std::make_shared(); + auto graph = BuildGraph(&program_desc, scope, places); + const int num_nodes = graph->nodes().size(); + FcFuser fuser; + fuser(graph.get()); + ASSERT_EQ(graph->nodes().size(), + num_nodes - 3UL /*nodes removed */ + 1UL /* fused fc node*/); + Visualize(graph.get()); +} + +} // namespace mir +} // namespace lite +} // namespace paddle + +USE_LITE_OP(fc); +USE_LITE_OP(mul); +USE_LITE_OP(elementwise_add); diff --git a/lite/core/mir/pattern_matcher_test.cc b/lite/core/mir/pattern_matcher_test.cc new file mode 100644 index 00000000000..728681a4590 --- /dev/null +++ b/lite/core/mir/pattern_matcher_test.cc @@ -0,0 +1,233 @@ +// 
Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/core/mir/pattern_matcher.h" + +#include + +namespace paddle { +namespace lite { +namespace mir { + +void BuildGraph(SSAGraph* g) { + g->mutable_nodes().emplace_back(); + Node& o1 = g->mutable_nodes().back(); + o1.AsStmt().desc = "op1"; + g->mutable_nodes().emplace_back(); + Node& o2 = g->mutable_nodes().back(); + o2.AsStmt().desc = "op2"; + g->mutable_nodes().emplace_back(); + Node& o3 = g->mutable_nodes().back(); + o3.AsStmt().desc = "op3"; + g->mutable_nodes().emplace_back(); + Node& o4 = g->mutable_nodes().back(); + o4.AsStmt().desc = "op4"; + g->mutable_nodes().emplace_back(); + Node& o5 = g->mutable_nodes().back(); + o5.AsStmt().desc = "op5"; + g->mutable_nodes().emplace_back(); + Node& v1 = g->mutable_nodes().back(); + v1.AsArg("var1"); + g->mutable_nodes().emplace_back(); + Node& v2 = g->mutable_nodes().back(); + v2.AsArg("var2"); + g->mutable_nodes().emplace_back(); + Node& v3 = g->mutable_nodes().back(); + v3.AsArg("var3"); + g->mutable_nodes().emplace_back(); + Node& v4 = g->mutable_nodes().back(); + v4.AsArg("var4"); + + // o1->v1->o2 + o1.outlinks.push_back(&v1); + o2.inlinks.push_back(&v1); + v1.inlinks.push_back(&o1); + v1.outlinks.push_back(&o2); + // o2->v2->o3 + // o2->v2->o4 + o2.outlinks.push_back(&v2); + o3.inlinks.push_back(&v2); + o4.inlinks.push_back(&v2); + v2.inlinks.push_back(&o2); + v2.outlinks.push_back(&o3); + v2.outlinks.push_back(&o4); + // o2->v3->o5 + o2.outlinks.push_back(&v3); + o5.inlinks.push_back(&v3); + v3.inlinks.push_back(&o2); + v3.outlinks.push_back(&o5); + // o3-v4->o5 + o3.outlinks.push_back(&v4); + o5.inlinks.push_back(&v4); + v4.inlinks.push_back(&o3); + v4.outlinks.push_back(&o5); +} + +TEST(PMPattern, NewNode) { + PMPattern x; + auto* n = x.NewNode([](const Node* x) { return true; }); + ASSERT_TRUE(n); + ASSERT_EQ(x.nodes_.size(), 1UL); +} + +TEST(PMPattern, AddEdge) { + PMPattern x; + auto* a = x.NewNode([](const Node* x) { return true; }); + auto* b = x.NewNode([](const Node* x) { return true; }); + ASSERT_TRUE(a); + ASSERT_TRUE(b); + x.AddEdge(a, b); + ASSERT_EQ(x.nodes_.size(), 2UL); + ASSERT_EQ(x.edges_.size(), 1UL); + ASSERT_EQ(x.edges_.front().first, a); + ASSERT_EQ(x.edges_.front().second, b); + + ASSERT_EQ(x.nodes().size(), 2UL); + ASSERT_EQ(x.edges().size(), 1UL); + ASSERT_EQ(x.edges().front().first, a); + ASSERT_EQ(x.edges().front().second, b); +} + +TEST(PatternMatcher, MarkPMNodesInGraph) { + PatternMatcher x; + // mark o2, o3, v2 + + // The pattern is a graph: + // o2(a node named o2) -> v2(a node named v2) + // v2 -> o3(a node named o3) + auto* o2 = x.pattern_.NewNode([](const Node* node) { + // The teller can be any condition, such as op type, or variable's shape. + return node && node->IsStmt() && node->stmt()->desc == "op2"; + }); + auto* o3 = x.pattern_.NewNode([](const Node* node) { + // The teller can be any condition, such as op type, or variable's shape. 
+ return node && node->IsStmt() && node->stmt()->desc == "op3"; + }); + auto* v2 = x.pattern_.NewNode([](const Node* node) { + // The teller can be any condition, such as op type, or variable's shape. + return node && node->IsArg() && node->arg()->name == "var2"; + }); + + ASSERT_FALSE(o2->Tell(nullptr)); + ASSERT_FALSE(o3->Tell(nullptr)); + ASSERT_FALSE(v2->Tell(nullptr)); + + x.pattern_.AddEdge(o2, v2); + x.pattern_.AddEdge(v2, o3); + + ASSERT_EQ(x.pattern_.edges().size(), 2UL); + ASSERT_EQ(x.pattern_.edges()[0].first, o2); + ASSERT_EQ(x.pattern_.edges()[0].second, v2); + ASSERT_EQ(x.pattern_.edges()[1].first, v2); + ASSERT_EQ(x.pattern_.edges()[1].second, o3); + + SSAGraph graph; + BuildGraph(&graph); + + x.MarkPMNodesInGraph(&graph); + + ASSERT_EQ(x.pmnodes2nodes_.size(), 3UL); + + auto subgraphs = x.DetectPatterns(); + ASSERT_EQ(subgraphs.size(), 1UL); +} + +TEST(PatternMatcher, MultiSubgraph) { + SSAGraph graph; + BuildGraph(&graph); + + PatternMatcher x; + + // The pattern is a graph: + // op -> var + auto* any_op = x.mutable_pattern()->NewNode( + [](const Node* node) { + return node->IsStmt() && + (node->stmt()->desc == "op2" || node->stmt()->desc == "op3"); + }, + "OP0"); + auto* any_var = + x.mutable_pattern() + ->NewNode([](const Node* node) { return node->IsArg(); }, "VAR") + ->AsIntermediate(); + auto* any_op1 = x.mutable_pattern()->NewNode( + [](const Node* node) { return node->IsStmt(); }, "OP1"); + + x.mutable_pattern()->AddEdge(any_op, any_var); + x.mutable_pattern()->AddEdge(any_var, any_op1); + + int count = 0; + PatternMatcher::handle_t handle = [&](const PatternMatcher::subgraph_t& s, + SSAGraph* g) { + LOG(INFO) << "Detect " << s.at(any_op)->stmt()->desc << " -> " + << s.at(any_var)->arg()->name << " -> " + << s.at(any_op1)->stmt()->desc; + count++; + }; + + x(&graph, handle); + + // 1. Detect op3 -> var4 -> op5 + // 2. Detect op2 -> var2 -> op3 + // 3. Detect op2 -> var2 -> op4 + // 4. Detect op2 -> var3 -> op5 + // But 2 and 3 and 4 overlapped, so keep 2, so the final choices are 1 and 2 + ASSERT_GE(count, 1); + ASSERT_LE(count, 2); +} + +TEST(PatternMatcher, IntermediateCheck) { + SSAGraph graph; + BuildGraph(&graph); + + // o2->v2->o3 + // o2->v2->o4 + // check o2+o3 fuse, should fail because v2 also link to o4. + PatternMatcher matcher; + auto* op2 = matcher.mutable_pattern()->NewNode( + [](const Node* x) { + return x && x->IsStmt() && x->stmt()->desc == "op2"; + }, + "op2"); + auto* op3 = matcher.mutable_pattern()->NewNode( + [](const Node* x) { + return x && x->IsStmt() && x->stmt()->desc == "op3"; + }, + "op3"); + auto* v2 = matcher.mutable_pattern() + ->NewNode( + [](const Node* x) { + return x && x->IsArg() && x->arg()->name == "var2"; + }, + "var2") + ->AsIntermediate(); + v2->LinksFrom({op2}).LinksTo({op3}); + + int count = 0; + matcher(&graph, [&](const PatternMatcher::subgraph_t& g, SSAGraph* graph) { + ++count; + }); + EXPECT_EQ(count, 0); + + count = 0; + v2->AsInput(); + matcher(&graph, [&](const PatternMatcher::subgraph_t& g, SSAGraph* graph) { + ++count; + }); + ASSERT_EQ(count, 1); +} + +} // namespace mir +} // namespace lite +} // namespace paddle diff --git a/lite/core/mir/pattern_matcher_tester.cc b/lite/core/mir/pattern_matcher_tester.cc new file mode 100644 index 00000000000..a62c3af62f6 --- /dev/null +++ b/lite/core/mir/pattern_matcher_tester.cc @@ -0,0 +1,233 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. 
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lite/core/mir/pattern_matcher.h"
+
+#include <gtest/gtest.h>
+
+namespace paddle {
+namespace lite {
+namespace mir {
+
+void BuildGraph(SSAGraph* g) {
+  g->mutable_nodes().emplace_back();
+  Node& o1 = g->mutable_nodes().back();
+  o1.AsStmt().op_type = "op1";
+  g->mutable_nodes().emplace_back();
+  Node& o2 = g->mutable_nodes().back();
+  o2.AsStmt().op_type = "op2";
+  g->mutable_nodes().emplace_back();
+  Node& o3 = g->mutable_nodes().back();
+  o3.AsStmt().op_type = "op3";
+  g->mutable_nodes().emplace_back();
+  Node& o4 = g->mutable_nodes().back();
+  o4.AsStmt().op_type = "op4";
+  g->mutable_nodes().emplace_back();
+  Node& o5 = g->mutable_nodes().back();
+  o5.AsStmt().op_type = "op5";
+  g->mutable_nodes().emplace_back();
+  Node& v1 = g->mutable_nodes().back();
+  v1.AsArg("var1");
+  g->mutable_nodes().emplace_back();
+  Node& v2 = g->mutable_nodes().back();
+  v2.AsArg("var2");
+  g->mutable_nodes().emplace_back();
+  Node& v3 = g->mutable_nodes().back();
+  v3.AsArg("var3");
+  g->mutable_nodes().emplace_back();
+  Node& v4 = g->mutable_nodes().back();
+  v4.AsArg("var4");
+
+  // o1->v1->o2
+  o1.outlinks.push_back(&v1);
+  o2.inlinks.push_back(&v1);
+  v1.inlinks.push_back(&o1);
+  v1.outlinks.push_back(&o2);
+  // o2->v2->o3
+  // o2->v2->o4
+  o2.outlinks.push_back(&v2);
+  o3.inlinks.push_back(&v2);
+  o4.inlinks.push_back(&v2);
+  v2.inlinks.push_back(&o2);
+  v2.outlinks.push_back(&o3);
+  v2.outlinks.push_back(&o4);
+  // o2->v3->o5
+  o2.outlinks.push_back(&v3);
+  o5.inlinks.push_back(&v3);
+  v3.inlinks.push_back(&o2);
+  v3.outlinks.push_back(&o5);
+  // o3->v4->o5
+  o3.outlinks.push_back(&v4);
+  o5.inlinks.push_back(&v4);
+  v4.inlinks.push_back(&o3);
+  v4.outlinks.push_back(&o5);
+}
+
+TEST(PMPattern, NewNode) {
+  PMPattern x;
+  auto* n = x.NewNode([](const Node* x) { return true; });
+  ASSERT_TRUE(n);
+  ASSERT_EQ(x.nodes_.size(), 1UL);
+}
+
+TEST(PMPattern, AddEdge) {
+  PMPattern x;
+  auto* a = x.NewNode([](const Node* x) { return true; });
+  auto* b = x.NewNode([](const Node* x) { return true; });
+  ASSERT_TRUE(a);
+  ASSERT_TRUE(b);
+  x.AddEdge(a, b);
+  ASSERT_EQ(x.nodes_.size(), 2UL);
+  ASSERT_EQ(x.edges_.size(), 1UL);
+  ASSERT_EQ(x.edges_.front().first, a);
+  ASSERT_EQ(x.edges_.front().second, b);
+
+  ASSERT_EQ(x.nodes().size(), 2UL);
+  ASSERT_EQ(x.edges().size(), 1UL);
+  ASSERT_EQ(x.edges().front().first, a);
+  ASSERT_EQ(x.edges().front().second, b);
+}
+
+TEST(PatternMatcher, MarkPMNodesInGraph) {
+  PatternMatcher x;
+  // mark o2, o3, v2
+
+  // The pattern is a graph:
+  //   o2(a node named o2) -> v2(a node named v2)
+  //   v2 -> o3(a node named o3)
+  auto* o2 = x.pattern_.NewNode([](const Node* node) {
+    // The teller can be any condition, such as op type, or variable's shape.
+    return node && node->IsStmt() && node->stmt()->op_type == "op2";
+  });
+  auto* o3 = x.pattern_.NewNode([](const Node* node) {
+    // The teller can be any condition, such as op type, or variable's shape.
+    return node && node->IsStmt() && node->stmt()->op_type == "op3";
+  });
+  auto* v2 = x.pattern_.NewNode([](const Node* node) {
+    // The teller can be any condition, such as op type, or variable's shape.
+    return node && node->IsArg() && node->arg()->name == "var2";
+  });
+
+  ASSERT_FALSE(o2->Tell(nullptr));
+  ASSERT_FALSE(o3->Tell(nullptr));
+  ASSERT_FALSE(v2->Tell(nullptr));
+
+  x.pattern_.AddEdge(o2, v2);
+  x.pattern_.AddEdge(v2, o3);
+
+  ASSERT_EQ(x.pattern_.edges().size(), 2UL);
+  ASSERT_EQ(x.pattern_.edges()[0].first, o2);
+  ASSERT_EQ(x.pattern_.edges()[0].second, v2);
+  ASSERT_EQ(x.pattern_.edges()[1].first, v2);
+  ASSERT_EQ(x.pattern_.edges()[1].second, o3);
+
+  SSAGraph graph;
+  BuildGraph(&graph);
+
+  x.MarkPMNodesInGraph(&graph);
+
+  ASSERT_EQ(x.pmnodes2nodes_.size(), 3UL);
+
+  auto subgraphs = x.DetectPatterns();
+  ASSERT_EQ(subgraphs.size(), 1UL);
+}
+
+TEST(PatternMatcher, MultiSubgraph) {
+  SSAGraph graph;
+  BuildGraph(&graph);
+
+  PatternMatcher x;
+
+  // The pattern is a graph:
+  //   op -> var -> op
+  auto* any_op = x.mutable_pattern()->NewNode(
+      [](const Node* node) {
+        return node->IsStmt() && (node->stmt()->op_type == "op2" ||
+                                  node->stmt()->op_type == "op3");
+      },
+      "OP0");
+  auto* any_var =
+      x.mutable_pattern()
+          ->NewNode([](const Node* node) { return node->IsArg(); }, "VAR")
+          ->AsIntermediate();
+  auto* any_op1 = x.mutable_pattern()->NewNode(
+      [](const Node* node) { return node->IsStmt(); }, "OP1");
+
+  x.mutable_pattern()->AddEdge(any_op, any_var);
+  x.mutable_pattern()->AddEdge(any_var, any_op1);
+
+  int count = 0;
+  PatternMatcher::handle_t handle = [&](const PatternMatcher::subgraph_t& s,
+                                        SSAGraph* g) {
+    LOG(INFO) << "Detect " << s.at(any_op)->stmt()->op_type << " -> "
+              << s.at(any_var)->arg()->name << " -> "
+              << s.at(any_op1)->stmt()->op_type;
+    count++;
+  };
+
+  x(&graph, handle);
+
+  // 1. Detect op3 -> var4 -> op5
+  // 2. Detect op2 -> var2 -> op3
+  // 3. Detect op2 -> var2 -> op4
+  // 4. Detect op2 -> var3 -> op5
+  // 2, 3 and 4 overlap, and only 2 is kept, so the final choices are 1 and 2.
+  ASSERT_GE(count, 1);
+  ASSERT_LE(count, 2);
+}
+
+TEST(PatternMatcher, IntermediateCheck) {
+  SSAGraph graph;
+  BuildGraph(&graph);
+
+  // o2->v2->o3
+  // o2->v2->o4
+  // Check the o2+o3 fusion; it should fail because v2 also links to o4.
+  PatternMatcher matcher;
+  auto* op2 = matcher.mutable_pattern()->NewNode(
+      [](const Node* x) {
+        return x && x->IsStmt() && x->stmt()->op_type == "op2";
+      },
+      "op2");
+  auto* op3 = matcher.mutable_pattern()->NewNode(
+      [](const Node* x) {
+        return x && x->IsStmt() && x->stmt()->op_type == "op3";
+      },
+      "op3");
+  auto* v2 = matcher.mutable_pattern()
+                 ->NewNode(
+                     [](const Node* x) {
+                       return x && x->IsArg() && x->arg()->name == "var2";
+                     },
+                     "var2")
+                 ->AsIntermediate();
+  v2->LinksFrom({op2}).LinksTo({op3});
+
+  int count = 0;
+  matcher(&graph, [&](const PatternMatcher::subgraph_t& g, SSAGraph* graph) {
+    ++count;
+  });
+  EXPECT_EQ(count, 0);
+
+  count = 0;
+  v2->AsInput();
+  matcher(&graph, [&](const PatternMatcher::subgraph_t& g, SSAGraph* graph) {
+    ++count;
+  });
+  ASSERT_EQ(count, 1);
+}
+
+}  // namespace mir
+}  // namespace lite
+}  // namespace paddle
diff --git a/lite/core/mir/runtime_context_assign_pass.cc b/lite/core/mir/runtime_context_assign_pass.cc
new file mode 100644
index 00000000000..7a063b0bfd4
--- /dev/null
+++ b/lite/core/mir/runtime_context_assign_pass.cc
@@ -0,0 +1,41 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lite/core/mir/pass.h"
+#include "lite/core/mir/pass_registry.h"
+
+namespace paddle {
+namespace lite {
+namespace mir {
+
+class RuntimeContextAssignPass : public StmtPass {
+ public:
+  RuntimeContextAssignPass() {}
+
+  void Apply(const std::unique_ptr<SSAGraph>& graph) override {
+    for (auto& node : graph->mutable_nodes()) {
+      if (!node.IsStmt()) continue;
+      auto& inst = node.AsStmt();
+      inst.picked_kernel().SetContext(
+          ContextScheduler::Global().NewContext(inst.picked_kernel().target()));
+    }
+  }
+};
+
+}  // namespace mir
+}  // namespace lite
+}  // namespace paddle
+
+REGISTER_MIR_PASS(runtime_context_assign_pass,
+                  paddle::lite::mir::RuntimeContextAssignPass);
diff --git a/lite/core/mir/ssa_graph.cc b/lite/core/mir/ssa_graph.cc
new file mode 100644
index 00000000000..5193d9c899b
--- /dev/null
+++ b/lite/core/mir/ssa_graph.cc
@@ -0,0 +1,240 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
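+//
+// SSAGraph is the in-memory representation the MIR passes operate on: op
+// ("Stmt") nodes and variable ("Arg") nodes linked through mirrored
+// inlinks/outlinks, which CheckBidirectionalConnection() below verifies.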
+
+#include "lite/core/mir/ssa_graph.h"
+#include <algorithm>
+#include <memory>
+#include <set>
+#include <stack>
+#include <unordered_map>
+
+namespace paddle {
+namespace lite {
+namespace mir {
+
+bool SSAGraph::CheckBidirectionalConnection() {
+  VLOG(4) << "node count " << node_storage_.size();
+  for (auto &node : node_storage_) {
+    if (node.IsStmt()) VLOG(4) << node.AsStmt().op_info()->Type();
+    if (node.IsArg()) VLOG(4) << node.AsArg().name << " " << node.AsArg().id;
+    for (auto *in : node.inlinks) {
+      CHECK(in->outlinks.end() !=
+            std::find(in->outlinks.begin(), in->outlinks.end(), &node));
+    }
+    for (auto *out : node.outlinks) {
+      CHECK(out->inlinks.end() !=
+            std::find(out->inlinks.begin(), out->inlinks.end(), &node));
+    }
+  }
+  return true;
+}
+
+std::map<mir::Node *, std::set<mir::Node *>> SSAGraph::BuildOperationAdjList() {
+  std::map<mir::Node *, std::set<mir::Node *>> adj_list;
+
+  for (auto &n : mutable_nodes()) {
+    if (!n.IsStmt()) continue;
+    if (adj_list.find(&n) == adj_list.end()) {
+      adj_list[&n] = std::set<mir::Node *>();
+    }
+    std::vector<mir::Node *> nodes;
+    for (auto &var : n.inlinks) {
+      for (auto &adj_n : var->inlinks) {
+        CHECK(adj_n->IsStmt());
+        nodes.push_back(adj_n);
+      }
+    }
+    std::sort(nodes.begin(),
+              nodes.end(),
+              [](mir::Node *node1, mir::Node *node2) { return node1 > node2; });
+    adj_list[&n].insert(std::make_move_iterator(nodes.begin()),
+                        std::make_move_iterator(nodes.end()));
+  }
+  return adj_list;
+}
+
+void SSAGraph::SortHelper(
+    const std::map<mir::Node *, std::set<mir::Node *>> &adj_list,
+    mir::Node *node,
+    std::set<mir::Node *> *visited,
+    std::vector<mir::Node *> *ret) {
+  visited->insert(node);
+
+  for (auto adj : adj_list.at(node)) {
+    if (visited->find(adj) == visited->end()) {
+      SortHelper(adj_list, adj, visited, ret);
+    }
+  }
+
+  ret->push_back(node);
+}
+
+std::vector<mir::Node *> SSAGraph::StmtTopologicalOrder() {
+  CheckBidirectionalConnection();
+
+  std::stack<mir::Node *> stack;
+  std::set<mir::Node *> visited;
+  std::vector<mir::Node *> res;
+
+  auto adj_list = BuildOperationAdjList();
+
+  for (auto adj : adj_list) {
+    if (visited.find(adj.first) == visited.end()) {
+      SortHelper(adj_list, adj.first, &visited, &res);
+    }
+  }
+
+  return res;
+}
+
+Node *SSAGraph::GraphCreateInstructNode(
+    const std::shared_ptr<OpLite> &op, const std::vector<Place> &valid_places) {
+  node_storage_.emplace_back();
+  // TODO(Superjomn) remove one valid_places here.
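+  // Note: CreateKernels below instantiates one kernel per valid place the op
+  // supports; StaticKernelPickPass later narrows this down to a single one.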
+  op->SetValidPlaces(valid_places);
+  auto &new_node = node_storage_.back();
+  auto kernels = op->CreateKernels(valid_places);
+  node_storage_.back().AsStmt(op->op_type_, std::move(kernels), op);
+
+  CHECK(new_node.inlinks.empty()) << "duplicate Build found";
+  CHECK(new_node.outlinks.empty()) << "duplicate Build found";
+  return &node_storage_.back();
+}
+
+void SSAGraph::Build(const Program &program,
+                     const std::vector<Place> &valid_places) {
+  CHECK(node_storage_.empty());
+
+  auto weights_name = program.weights();
+  auto is_weights = [&](const std::string &name) -> bool {
+    auto it = std::find(weights_name.begin(), weights_name.end(), name);
+    if (it == weights_name.end()) return false;
+    return true;
+  };
+
+  std::unordered_map<std::string, mir::Node *> arg_update_node_map_;
+  for (auto &op : program.ops()) {
+    VLOG(3) << op->op_info()->Type();
+    auto *op_node = GraphCreateInstructNode(op, valid_places);
+    for (const std::string &name : op->op_info()->input_names()) {
+      mir::Node *arg_node = nullptr;
+      if (arg_update_node_map_.count(name)) {
+        arg_node = arg_update_node_map_.at(name);
+      } else {
+        node_storage_.emplace_back();
+        arg_node = &node_storage_.back();
+        arg_node->AsArg(name, node_storage_.size() - 1);
+        arg_update_node_map_[name] = arg_node;
+      }
+      if (is_weights(name)) arg_node->AsArg().is_weight = true;
+      CHECK(arg_node->IsRoleSet());
+      DirectedLink(arg_node, op_node);
+    }
+    for (const std::string &name : op->op_info()->output_names()) {
+      node_storage_.emplace_back();
+      auto *arg_node = &node_storage_.back();
+      arg_node->AsArg(name, node_storage_.size() - 1);
+      arg_update_node_map_[name] = arg_node;
+
+      if (is_weights(name)) arg_node->AsArg().is_weight = true;
+      CHECK(arg_node->IsRoleSet());
+      DirectedLink(op_node, arg_node);
+    }
+    CHECK(CheckLinksRoleSet());
+  }
+
+  CHECK(CheckNodesRoleSet());
+  CheckValid();
+}
+
+void SSAGraph::RemoveNode(const mir::Node *node) {
+  auto pos = std::find_if(node_storage_.begin(),
+                          node_storage_.end(),
+                          [&node](mir::Node &n) { return &n == node; });
+  CHECK(pos != node_storage_.end());
+  node_storage_.erase(pos);
+}
+
+mir::Node *SSAGraph::Argument(const std::string &name) {
+  auto it = arguments_.find(name);
+  CHECK(it != arguments_.end()) << "no argument called " << name;
+  return it->second;
+}
+
+std::vector<mir::Node *> SSAGraph::inputs() {
+  std::vector<mir::Node *> res;
+  for (auto &node : node_storage_) {
+    if (node.inlinks.empty()) {
+      res.push_back(&node);
+    }
+  }
+  return res;
+}
+
+std::vector<mir::Node *> SSAGraph::outputs() {
+  std::vector<mir::Node *> res;
+  for (auto &node : node_storage_) {
+    if (node.outlinks.empty()) {
+      res.push_back(&node);
+    }
+  }
+  return res;
+}
+
+mir::Node *SSAGraph::RetrieveArgument(const std::string &arg) {
+  auto it = arguments_.find(arg);
+  if (it != arguments_.end()) {
+    return it->second;
+  }
+  return nullptr;
+}
+
+bool SSAGraph::CheckNodesRoleSet() {
+  for (auto &node : mutable_nodes()) {
+    CHECK_OR_FALSE(node.IsRoleSet());
+  }
+  return true;
+}
+
+bool SSAGraph::CheckLinksRoleSet() {
+  for (auto &node : mutable_nodes()) {
+    CHECK_OR_FALSE(node.IsRoleSet());
+    if (!node.IsStmt()) continue;
+    for (auto *x : node.inlinks) {
+      CHECK_OR_FALSE(x->IsRoleSet());
+      CHECK_OR_FALSE(x->IsArg());
+    }
+    for (auto *x : node.outlinks) {
+      CHECK_OR_FALSE(x->IsRoleSet());
+      CHECK_OR_FALSE(x->IsArg());
+    }
+  }
+  return true;
+}
+
+Node *SSAGraph::NewArgumentNode(const std::string &name) {
+  node_storage_.emplace_back();
+  auto &arg_node = node_storage_.back();
+  arg_node.AsArg(name, node_storage_.size() - 1);
+  return &arg_node;
+}
+
+Node *SSAGraph::NewInstructNode() {
+  node_storage_.emplace_back();
+  return &node_storage_.back();
+}
+
+}  // namespace mir
+}  // namespace lite
+}  // namespace paddle
diff --git a/lite/core/mir/ssa_graph.h b/lite/core/mir/ssa_graph.h
new file mode 100644
index 00000000000..b5b9fb1cb28
--- /dev/null
+++ b/lite/core/mir/ssa_graph.h
@@ -0,0 +1,144 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <list>
+#include <map>
+#include <memory>
+#include <set>
+#include <stack>
+#include <string>
+#include <vector>
+#include "lite/core/kernel.h"
+#include "lite/core/mir/node.h"
+#include "lite/core/op_lite.h"
+#include "lite/core/program.h"
+
+namespace paddle {
+namespace lite {
+namespace mir {
+
+// A graph for MIR. It is built from a list of Ops and a scope.
+class GraphBase {};
+
+class SSAGraph : GraphBase {
+ public:
+  // @param program: the op program
+  // @param valid_places: the valid places the user set for the system.
+  void Build(const Program &program, const std::vector<Place> &valid_places);
+  void RemoveNode(const mir::Node *node);
+
+  std::vector<mir::Node *> StmtTopologicalOrder();
+
+  // The inputs of the graph.
+  std::vector<mir::Node *> inputs();
+
+  // The outputs of the graph.
+  std::vector<mir::Node *> outputs();
+
+  const std::list<mir::Node> &nodes() const { return node_storage_; }
+  std::list<mir::Node> &mutable_nodes() { return node_storage_; }
+
+  mir::Node *RetrieveArgument(const std::string &arg);
+
+  Node *NewArgumentNode(const std::string &name);
+  Node *NewInstructNode();
+
+  void CheckValid() {
+    CHECK(CheckBidirectionalConnection());
+    CHECK(CheckNodesRoleSet());
+    CHECK(CheckLinksRoleSet());
+  }
+
+  Node *GraphCreateInstructNode(const std::shared_ptr<OpLite> &op,
+                                const std::vector<Place> &valid_places);
+
+  // Device related attributes
+  const std::vector<Place> &valid_places() const { return valid_places_; }
+  void SetValidPlaces(const std::vector<Place> &x) { valid_places_ = x; }
+
+ private:
+  mir::Node *Argument(const std::string &name);
+  // Check the bidirectional connection.
+  bool CheckBidirectionalConnection();
+  bool CheckNodesRoleSet();
+  // Check that the role of every item in inlinks and outlinks is set.
+  bool CheckLinksRoleSet();
+
+  void MarkArgumentWeights(const Program &program) {
+    for (const auto &name : program.weights()) {
+      arguments_[name]->AsArg().is_weight = true;
+    }
+  }
+
+  // Build the operator inlink edge table.
+  std::map<mir::Node *, std::set<mir::Node *>> BuildOperationAdjList();
+
+  void SortHelper(const std::map<mir::Node *, std::set<mir::Node *>> &adj_list,
+                  mir::Node *node,
+                  std::set<mir::Node *> *visited,
+                  std::vector<mir::Node *> *ret);
+
+ private:
+  std::list<mir::Node> node_storage_;
+  std::map<std::string, mir::Node *> arguments_;
+  std::vector<Place> valid_places_;
+};
+
+// Remove the link a -> b.
+static void RemoveDirectedLink(Node *a, Node *b) {
+  auto it = std::find(b->inlinks.begin(), b->inlinks.end(), a);
+  if (it != b->inlinks.end()) {
+    b->inlinks.erase(it);
+  }
+
+  auto it1 = std::find(a->outlinks.begin(), a->outlinks.end(), b);
+  if (it1 != a->outlinks.end()) {
+    a->outlinks.erase(it1);
+  }
+}
+
+// Link a -> b.
+static void DirectedLink(Node *a, Node *b) {
+  // Eagerly remove first, to avoid duplicate links.
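+  // (RemoveDirectedLink is a no-op when the a->b edge is absent, so
+  // DirectedLink is idempotent: calling it twice leaves exactly one edge.)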
+  RemoveDirectedLink(a, b);
+  a->outlinks.push_back(b);
+  b->inlinks.push_back(a);
+}
+
+static void LocalInferenceType(Node *a, Node *b, const std::string &arg_name) {
+  // instr -> output argument
+  if (a->IsStmt() && b->IsArg()) {
+    auto &inst = a->AsStmt();
+    auto &output = b->AsArg();
+
+    if (!output.type) {
+      output.type = inst.picked_kernel().GetOutputDeclType(arg_name);
+    }
+  }
+
+  // input argument -> instr
+  if (a->IsArg() && b->IsStmt()) {
+    auto &input = a->AsArg();
+    auto &inst = b->AsStmt();
+    if (!input.type) {
+      input.type = inst.picked_kernel().GetInputDeclType(arg_name);
+    }
+  }
+}
+
+}  // namespace mir
+}  // namespace lite
+}  // namespace paddle
diff --git a/lite/core/mir/ssa_graph_test.cc b/lite/core/mir/ssa_graph_test.cc
new file mode 100644
index 00000000000..ef49001ba2f
--- /dev/null
+++ b/lite/core/mir/ssa_graph_test.cc
@@ -0,0 +1,59 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lite/core/mir/ssa_graph.h"
+#include <gtest/gtest.h>
+#include <memory>
+#include "lite/api/paddle_use_passes.h"
+#include "lite/core/mir/graph_visualize_pass.h"
+#include "lite/core/op_registry.h"
+#include "lite/core/program_fake_utils.h"
+#include "paddle/fluid/framework/program_desc.h"
+
+namespace paddle {
+namespace lite {
+namespace mir {
+
+void BuildFc(framework::ProgramDesc* desc,
+             const std::string& x,
+             const std::string& w,
+             const std::string& b,
+             const std::string& out) {
+  auto* fc = desc->MutableBlock(0)->AppendOp();
+  fc->SetInput("Input", {x});
+  fc->SetInput("W", {w});
+  fc->SetInput("Bias", {b});
+  fc->SetOutput("Out", {out});
+}
+
+TEST(SSAGraph, test) {
+  auto program_faker = ProgramFaker();
+  SSAGraph graph;
+  std::vector<Place> places{{TARGET(kHost), PRECISION(kFloat)}};
+  auto scope = std::make_shared<Scope>();
+
+  lite::Program program(*program_faker.program()->Proto(), scope, places);
+  graph.Build(program, places);
+
+  Visualize(&graph);
+}
+
+}  // namespace mir
+}  // namespace lite
+}  // namespace paddle
+
+USE_LITE_OP(fc);
+#ifdef LITE_WITH_X86
+// USE_LITE_KERNEL(fc, kX86, kFloat, kNCHW, def);
+#endif
diff --git a/lite/core/mir/static_kernel_pick_pass.cc b/lite/core/mir/static_kernel_pick_pass.cc
new file mode 100644
index 00000000000..729ad4c9ae4
--- /dev/null
+++ b/lite/core/mir/static_kernel_pick_pass.cc
@@ -0,0 +1,135 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lite/core/mir/static_kernel_pick_pass.h"
+#include <algorithm>
+#include <memory>
+#include <utility>
+#include <vector>
+#include "lite/core/mir/pass_registry.h"
+
+namespace paddle {
+namespace lite {
+namespace mir {
+
+bool KernelScoreCmp(const std::pair<size_t, std::unique_ptr<KernelBase>>& a,
+                    const std::pair<size_t, std::unique_ptr<KernelBase>>& b) {
+  return a.first > b.first;
+}
+
+void StaticKernelPickPass::Apply(const std::unique_ptr<SSAGraph>& graph) {
+  CHECK(kernel_pick_factors_.any_factor_considered())
+      << "kernel_pick_factors should be specified first";
+  CHECK(graph) << "graph not valid";
+  // Sort the kernels by the factors.
+
+  for (auto& node : graph->mutable_nodes()) {
+    if (!node.IsStmt()) continue;
+    auto& instruct = node.AsStmt();
+
+    // Get the candidate kernels.
+    std::vector<std::pair<size_t, std::unique_ptr<KernelBase>>> scored;
+    CHECK(!instruct.kernels().empty()) << "No kernels found for "
+                                       << instruct.op_type();
+    for (auto&& kernel : instruct.kernels()) {
+      size_t score = KernelGrade(*kernel);
+      scored.emplace_back(score, std::move(kernel));
+    }
+    std::sort(scored.begin(), scored.end(), KernelScoreCmp);
+    instruct.kernels().clear();
+
+    if (!instruct.op_info()->HasAttr("enable_int8")) {
+      // Move the best kernel back.
+      // Just keep a single best kernel.
+      // TODO(Superjomn) reconsider this.
+      instruct.kernels().emplace_back(std::move(scored.front().second));
+      VLOG(2) << "pick " << instruct.kernels().front()->name();
+
+    } else {
+      bool out_type_int8 = true;
+      // The op output type is int8 only when all ops linked to this op's
+      // output have the enable_int8 attr; otherwise it is fp32.
+      for (auto* out_n : node.outlinks) {
+        CHECK(out_n->IsArg());
+        for (auto* tmp_op : out_n->outlinks) {
+          CHECK(tmp_op->IsStmt());
+          if (!tmp_op->AsStmt().op_info()->HasAttr("enable_int8")) {
+            out_type_int8 = false;
+            break;
+          }
+        }
+        if (!out_type_int8) break;
+      }
+      // If out_type_int8 is true, the output type of this op can be int8,
+      // so we need to specify an output scale for it.
+      if (out_type_int8) {
+        auto out_node = node.outlinks.front();
+        CHECK(out_node->IsArg());
+        auto one_adj_op_node = out_node->outlinks.front();
+        CHECK(one_adj_op_node->IsStmt());
+        auto& one_adj_instruct = one_adj_op_node->AsStmt();
+        CHECK(one_adj_instruct.op_info()->HasAttr("enable_int8"));
+        CHECK(one_adj_instruct.op_info()->HasAttr("input_scale"));
+
+        instruct.mutable_op_info()->SetAttr(
+            "output_scale",
+            one_adj_instruct.op_info()->GetAttr<float>("input_scale"));
+
+        auto update_desc = *instruct.mutable_op_info();
+        instruct.ResetOp(update_desc, graph->valid_places());
+        scored.clear();
+        for (auto&& kernel : instruct.kernels()) {
+          size_t score = KernelGrade(*kernel);
+          scored.emplace_back(score, std::move(kernel));
+        }
+        std::sort(scored.begin(), scored.end(), KernelScoreCmp);
+        instruct.kernels().clear();
+      }
+      // If out_type_int8 is true, we should pick a kernel with int8 input
+      // and int8 output.
+      // If out_type_int8 is false, we should pick a kernel with int8 input
+      // and fp32 output.
+      auto output_arguments = instruct.op_info()->OutputArgumentNames();
+      for (auto& candidate : scored) {
+        bool all_output_type_match = true;
+        auto expect_output_type =
+            out_type_int8 ? PRECISION(kInt8) : PRECISION(kFloat);
+
+        for (auto& arg_name : output_arguments) {
+          const Type* out_arg_ty =
+              candidate.second->GetOutputDeclType(arg_name);
+          if (out_arg_ty->precision() != expect_output_type) {
+            all_output_type_match = false;
+          }
+        }
+
+        if (all_output_type_match) {
+          instruct.kernels().emplace_back(std::move(candidate.second));
+          VLOG(2) << "pick " << instruct.kernels().front()->name();
+          break;
+        }
+      }
+      CHECK(!instruct.kernels().empty()) << "No kernels found for "
+                                         << instruct.op_type();
+    }
+  }
+}
+
+}  // namespace mir
+}  // namespace lite
+}  // namespace paddle
+
+REGISTER_MIR_PASS(static_kernel_pick_pass,
+                  paddle::lite::mir::StaticKernelPickPass);
diff --git a/lite/core/mir/static_kernel_pick_pass.h b/lite/core/mir/static_kernel_pick_pass.h
new file mode 100644
index 00000000000..34122782292
--- /dev/null
+++ b/lite/core/mir/static_kernel_pick_pass.h
@@ -0,0 +1,97 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <limits>
+#include <memory>
+#include "lite/core/mir/pass.h"
+#include "lite/core/types.h"
+
+namespace paddle {
+namespace lite {
+namespace mir {
+
+/*
+ * StaticKernelPickPass is a simple strategy for picking the kernel for each
+ * Operator, using rules defined by the operator developer; there are many
+ * other tactics, such as considering IO or kernel execution latency, that we
+ * will implement later.
+ *
+ * There are two arguments for this pass:
+ * - place, the target place.
+ * - kernel_pick_factors, the factors to consider in picking kernels.
+ * Set them first before executing the pass.
+ */
+class StaticKernelPickPass : public mir::StmtPass {
+ public:
+  void Apply(const std::unique_ptr<SSAGraph>& graph) override;
+
+  void SetPreferPlace(const Place& place) { place_ = place; }
+  const Place& place() const { return place_; }
+  const core::KernelPickFactor& kernel_pick_factors() const {
+    return kernel_pick_factors_;
+  }
+  core::KernelPickFactor* mutable_kernel_pick_factors() {
+    return &kernel_pick_factors_;
+  }
+
+ private:
+  // Score the kernel.
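+  // Each considered factor adds kMax divided by that factor's enum value to
+  // the score, so a target match outweighs a precision match, which in turn
+  // outweighs a data-layout match; kAny on either side also counts as a match.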
+  size_t KernelGrade(const lite::KernelBase& kernel) {
+    size_t score{};
+    const int kMax = std::numeric_limits<int>::max();
+
+    // The more important factor comes first.
+    if (kernel_pick_factors_.IsTargetConsidered() &&
+        (place().target == kernel.target() || kernel.target() == TARGET(kAny) ||
+         place().target == TARGET(kAny))) {
+      score +=
+          kMax / static_cast<int>(core::KernelPickFactor::Factor::TargetFirst);
+    }
+    if (kernel_pick_factors_.IsPrecisionConsidered() &&
+        (place().precision == kernel.precision() ||
+         kernel.precision() == PRECISION(kAny) ||
+         place().precision == PRECISION(kAny))) {
+      score += kMax /
+               static_cast<int>(core::KernelPickFactor::Factor::PrecisionFirst);
+    }
+    if (kernel_pick_factors_.IsDataLayoutConsidered() &&
+        (place().layout == kernel.layout() ||
+         kernel.layout() == DATALAYOUT(kAny) ||
+         place().layout == DATALAYOUT(kAny))) {
+      score += kMax / static_cast<int>(
+                          core::KernelPickFactor::Factor::DataLayoutFirst);
+    }
+    VLOG(4) << "picker tactic " << kernel_pick_factors_;
+    VLOG(4) << "kernel place " << kernel.place().DebugString();
+    VLOG(4) << "picker place " << place().DebugString();
+    VLOG(4) << "score " << score;
+
+    // The data layout is not considered, for the input and output arguments
+    // might have different data layouts.
+    // TODO(Superjomn) reconsider the idea of taking the data layout as a
+    // kernel specification.
+    return score;
+  }
+
+ private:
+  core::KernelPickFactor kernel_pick_factors_;
+  Place place_;
+};
+
+}  // namespace mir
+}  // namespace lite
+}  // namespace paddle
diff --git a/lite/core/mir/subgraph/CMakeLists.txt b/lite/core/mir/subgraph/CMakeLists.txt
new file mode 100644
index 00000000000..4b4eb562c75
--- /dev/null
+++ b/lite/core/mir/subgraph/CMakeLists.txt
@@ -0,0 +1,32 @@
+
+lite_cc_library(subgraph_pass
+    SRCS subgraph_program_pass.cc
+    DEPS mir_pass types ${mir_fusers})
+lite_cc_test(test_subgraph_pass SRCS subgraph_program_pass_test.cc
+    DEPS subgraph_pass mir_passes gflags model_parser cxx_api
+    ARGS --model_dir=${LITE_MODEL_DIR}/mobilenet_v1 SERIAL)
+if (WITH_TESTING)
+    add_dependencies(test_subgraph_pass extern_lite_download_mobilenet_v1_tar_gz)
+    set(LINK_FLAGS "-Wl,--version-script ${PADDLE_SOURCE_DIR}/lite/core/lite.map")
+    set_target_properties(test_subgraph_pass PROPERTIES LINK_FLAGS "${LINK_FLAGS}")
+endif()
+
+set(subgraph_passes subgraph_pass)
+
+if(LITE_WITH_NPU)
+    lite_cc_library(npu_pass SRCS generate_npu_program_pass.cc
+        DEPS mir_pass types context ${mir_fusers} ${npu_bridges} npu_helper ${npu_ddk_libs} graph_op subgraph_pass)
+    list(APPEND subgraph_passes npu_pass)
+    lite_cc_test(test_npu_pass SRCS generate_npu_program_pass_test.cc
+        DEPS npu_pass cxx_api mir_passes gflags
+        ARGS --model_dir=${LITE_MODEL_DIR}/mobilenet_v1
+             --optimized_model=${LITE_MODEL_DIR}/lite_npu_model_opt SERIAL)
+    if (WITH_TESTING)
+        add_dependencies(test_npu_pass extern_lite_download_mobilenet_v1_tar_gz)
+        set(LINK_FLAGS "-Wl,--version-script ${PADDLE_SOURCE_DIR}/lite/core/lite.map")
+        set_target_properties(test_npu_pass PROPERTIES LINK_FLAGS "${LINK_FLAGS}")
+    endif()
+endif()
+
+set(subgraph_passes ${subgraph_passes} CACHE INTERNAL "subgraph_passes")
+message(STATUS "+++++ subgraph_passes: ${subgraph_passes}")
diff --git a/lite/core/mir/subgraph/generate_npu_program_pass.cc b/lite/core/mir/subgraph/generate_npu_program_pass.cc
new file mode 100644
index 00000000000..798faf45443
--- /dev/null
+++ b/lite/core/mir/subgraph/generate_npu_program_pass.cc
@@ -0,0 +1,259 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lite/core/mir/subgraph/generate_npu_program_pass.h"
+#include <memory>
+#include <string>
+#include <unordered_set>
+#include <vector>
+#include "lite/core/mir/graph_visualize_pass.h"
+#include "lite/core/mir/pass_registry.h"
+#include "lite/core/mir/pattern_matcher.h"
+
+#include "ai_ddk_lib/include/HiAiModelManagerService.h"
+#include "ai_ddk_lib/include/graph/graph.h"
+#include "ai_ddk_lib/include/graph/model.h"
+#include "ai_ddk_lib/include/graph/op/all_ops.h"  // for ge::op::Data
+#include "ai_ddk_lib/include/graph/operator_reg.h"
+#include "lite/npu/bridge/paddle_use_npu_bridges.h"
+#include "lite/npu/bridge/registry.h"
+#include "lite/npu/bridge/utils.h"
+#include "lite/npu/npu_helper.h"
+
+namespace paddle {
+namespace lite {
+namespace mir {
+namespace subgraph {
+
+// Call the convert function from the start node.
+// Return whether the conversion succeeded and the nodes to remove.
+// Return the output npu op.
+lite::npu::bridge::node_map_type GenerateNPUProgramPass::CvtOpNodes(
+    const lite::npu::bridge::cvt_map_type& cvtfunc_map,
+    const Node* op_node,
+    const lite::npu::bridge::node_map_type& inputs_map,
+    int sub_id,
+    std::unordered_set<const Node*>* nodes2rm,
+    key2nodes_t* matched) {
+  lite::npu::bridge::node_map_type failed;
+  if (!op_node->IsStmt()) {
+    LOG(INFO) << "stop, return failed";
+    return failed;
+  }
+  auto* stmt = op_node->stmt();
+  auto op_type = stmt->op_type();
+  LOG(INFO) << "cvt op type: " << op_type;
+
+  if (stmt->subgraph_id() != sub_id) {
+    LOG(INFO) << "return as subgraph_id(" << stmt->subgraph_id()
+              << ") != sub_id(" << sub_id << ")";
+    return failed;
+  } else {
+    CHECK(cvtfunc_map.count(op_type)) << "Should be supported " << op_type
+                                      << ", with subgraph_id: " << sub_id;
+  }
+
+  auto outputs_map = cvtfunc_map.at(op_type)(stmt->op(), inputs_map);
+  if (outputs_map.empty()) {
+    return outputs_map;
+  }
+
+  nodes2rm->insert(op_node);
+  for (auto& var_node : op_node->outlinks) {
+    for (auto& next_op_node : var_node->outlinks) {
+      LOG(INFO) << "next op type: " << next_op_node->AsStmt().op_type();
+      if (next_op_node->AsStmt().subgraph_id() != sub_id) {
+        // This is the end condition.
+        // TODO(TJ): this is buggy once more inputs and outputs are enabled.
+        LOG(INFO) << "--- should return once ---";
+        // TODO(TJ): the matched output could be a vector.
+        matched->insert(std::make_pair("Output", var_node));
+        return outputs_map;
+      } else {
+        // LOG(INFO) << "argnames: ";
+        // for (auto sss : next_op_node->AsStmt().op_info()->input_argnames()) {
+        //   LOG(INFO) << sss;
+        // }
+        // LOG(INFO) << "input argnames: ";
+        // for (auto sss : next_op_node->AsStmt().op_info()->input_names()) {
+        //   LOG(INFO) << sss;
+        // }
+        for (auto& i_node : next_op_node->inlinks) {
+          CHECK(i_node->IsArg());
+          auto& arg = i_node->AsArg();
+          LOG(INFO) << arg.name;
+          if (outputs_map.count(arg.name)) continue;
+          if (!arg.is_weight) {
+            LOG(INFO) << "Data arg name:" << arg.name;
+            outputs_map.insert(std::make_pair(
+                arg.name,
+                lite::npu::bridge::CvtNode(
+                    i_node, next_op_node->AsStmt().op()->scope())));
+          }
+        }
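+        // Every non-weight input of the next op has been converted now, so
+        // drop the intermediate var node and recurse deeper into this
+        // subgraph.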
+        nodes2rm->insert(var_node);
+        return CvtOpNodes(
+            cvtfunc_map, next_op_node, outputs_map, sub_id, nodes2rm, matched);
+      }
+    }
+  }
+  // No outlinks to walk; the converted map is the result.
+  return outputs_map;
+}
+
+void GenerateNPUProgramPass::ConvertSubgraph(
+    const std::unique_ptr<SSAGraph>& graph, int sub_num) {
+  const auto& bridges = lite::npu::bridge::Factory::Instance();
+  const auto& cvtfunc_map = bridges.AllFunctions();
+  std::unordered_set<const Node*> nodes2rm_all;
+
+  auto items = graph->StmtTopologicalOrder();
+  for (int id = 1; id <= sub_num; ++id) {
+    LOG(INFO) << "Converting subgraph_id:" << id;
+    for (auto& op_node : items) {
+      std::unordered_set<const Node*> nodes2rm;
+      if (!op_node->IsStmt()) continue;
+      auto& stmt = op_node->AsStmt();
+      if (stmt.subgraph_id() != id) continue;
+      CHECK(bridges.HasType(stmt.op_type()));
+      key2nodes_t matched;
+      matched["target_op"] = op_node;
+      auto& op = stmt.op();
+      auto* scope = op->scope();
+      // Prepare the input data.
+      std::string data_name = "data_subgraph_" + std::to_string(id);
+      lite::npu::bridge::node_map_type npu_inputs_map;
+      int name_id = 0;
+      LOG(INFO) << "op_type: " << stmt.op_type();
+      std::vector<std::string> actual_input_argnames;
+      for (auto& arg_node : op_node->inlinks) {
+        CHECK(arg_node->IsArg());
+        const auto& arg = arg_node->AsArg();
+        if (!arg_node->AsArg().is_weight) {
+          LOG(INFO) << "Input arg name: " << arg.name;
+          npu_inputs_map.insert(std::make_pair(
+              arg.name, lite::npu::bridge::CvtNode(arg_node, scope)));
+          // TODO(TJ): here the matched inputs should also be an input vector.
+          matched["Input"] = arg_node;
+          name_id++;
+        }
+      }
+      CHECK_EQ(name_id, 1) << "mobilenetv1 only has one input data!";
+      auto npu_outputs_map = CvtOpNodes(
+          cvtfunc_map, op_node, npu_inputs_map, id, &nodes2rm, &matched);
+      if (!npu_outputs_map.empty()) {
+        LOG(INFO) << "[NPU] subgraph " << id << ": output not empty ";
+        std::vector<ge::Operator> inputs;
+        std::vector<ge::Operator> outputs;
+        for (auto& i : npu_inputs_map) {
+          LOG(INFO) << "input data argname:" << i.first
+                    << ", ptr: " << i.second;
+          inputs.emplace_back(*(i.second));
+        }
+        for (auto& i : npu_outputs_map) {
+          LOG(INFO) << "output data argname:" << i.first
+                    << ", ptr: " << i.second;
+          outputs.emplace_back(*(i.second));
+        }
+
+        std::string model_name("hiai_npu_client_" + std::to_string(id) + ".om");
+        if (!npu::BuildNPUClient(inputs, outputs, model_name)) {
+          // The build failed, so this subgraph is abandoned.
+          nodes2rm.clear();
+          LOG(WARNING) << "Build NPU failed subgraph " << id;
+          break;
+        }
+        LOG(INFO) << "[NPU] Build NPU Client success subgraph " << id;
+
+        // Then InsertNewNode(graph, matched); make this one function.
+        cpp::OpDesc op_desc;
+        op_desc.SetType("graph_op");
+        // TODO(TJ): change to vectors.
+        op_desc.SetInput("Inputs", {matched.at("Input")->arg()->name});
+        op_desc.SetOutput("Outputs", {matched.at("Output")->arg()->name});
+        op_desc.SetAttr("model_name", model_name);
+        auto graph_op = LiteOpRegistry::Global().Create("graph_op");
+        auto target_op = matched.at("target_op")->stmt()->op();
+        auto* scope = target_op->scope();
+        CHECK(scope);
+        CHECK(graph_op);
+        graph_op->Attach(op_desc, scope);
+
+        auto valid_places =
+            target_op->valid_places();  // TODO(TJ): add npu place?
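+        // The fused graph_op stands in for the whole subgraph: a single
+        // instruct node wired from the matched input var to the output var.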
+        auto* new_op_node =
+            graph->GraphCreateInstructNode(graph_op, valid_places);
+
+        IR_NODE_LINK_TO(matched.at("Input"), new_op_node);
+        IR_NODE_LINK_TO(new_op_node, matched.at("Output"));
+
+        // Assign the context.
+        auto& inst = new_op_node->AsStmt();
+        inst.picked_kernel().SetContext(ContextScheduler::Global().NewContext(
+            inst.picked_kernel().target()));
+
+        if (!nodes2rm.empty()) {
+          nodes2rm_all.insert(nodes2rm.begin(), nodes2rm.end());
+        }
+        break;
+      }  // if npu output success
+    }    // for op_nodes
+  }      // for subgraph id
+  // Remove all unused nodes at once.
+  GraphSafeRemoveNodes(graph.get(), nodes2rm_all);
+  // Clear all npu ops.
+  npu::OpList::Global().clear();
+}
+
+void GenerateNPUProgramPass::Apply(const std::unique_ptr<SSAGraph>& graph) {
+  LOG(INFO) << "Before NPU Pass \n" << Visualize(graph.get());
+  const auto& bridges = lite::npu::bridge::Factory::Instance();
+  const auto& op_map = bridges.AllFunctions();
+  std::vector<std::string> supported_op_types;
+  for (auto& i : op_map) {
+    LOG(INFO) << i.first;
+    supported_op_types.push_back(i.first);
+  }
+  int num_subgraph = FuseSubgraph(graph, supported_op_types);
+  LOG(INFO) << "detected " << num_subgraph << " NPU subgraphs";
+
+  InferOnce(graph);
+  ConvertSubgraph(graph, num_subgraph);
+  // auto graph1 = GenerateFusedGraph(std::move(graph));
+  // GraphSafeRemoveNodes(graph, nodes2rm);
+  LOG(INFO) << "After NPU Pass \n" << Visualize(graph.get());
+
+  for (auto& item : graph->StmtTopologicalOrder()) {
+    if (item->IsStmt()) {
+      auto& stmt = item->AsStmt();
+      LOG(INFO) << stmt;
+      insts_.emplace_back(stmt.op(), std::move(stmt.kernels().front()));
+    }
+  }
+}
+
+std::unique_ptr<RuntimeProgram> GenerateNPUProgramPass::GenProgram() {
+  LOG(INFO) << "insts.size " << insts_.size();
+  std::unique_ptr<RuntimeProgram> program(
+      new RuntimeProgram(std::move(insts_)));
+  return program;
+}
+
+}  // namespace subgraph
+}  // namespace mir
+}  // namespace lite
+}  // namespace paddle
+
+REGISTER_MIR_PASS(generate_npu_program_pass,
+                  paddle::lite::mir::subgraph::GenerateNPUProgramPass);
+
+// USE_LITE_OP(graph_op);
diff --git a/lite/core/mir/subgraph/generate_npu_program_pass.h b/lite/core/mir/subgraph/generate_npu_program_pass.h
new file mode 100644
index 00000000000..908190e4e9f
--- /dev/null
+++ b/lite/core/mir/subgraph/generate_npu_program_pass.h
@@ -0,0 +1,65 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <map>
+#include <memory>
+#include <string>
+#include <unordered_set>
+#include <vector>
+#include "lite/core/mir/pass.h"
+#include "lite/core/mir/subgraph/subgraph_program_pass.h"
+#include "lite/npu/bridge/registry.h"
+#include "lite/npu/npu_helper.h"
+
+namespace paddle {
+namespace lite {
+namespace mir {
+namespace subgraph {
+
+class GenerateNPUProgramPass : public SubgraphProgramPass {
+ public:
+  using key2nodes_t = std::map<std::string, Node*>;
+
+  void Apply(const std::unique_ptr<SSAGraph>& graph) override;
+  std::unique_ptr<RuntimeProgram> GenProgram();
+
+ protected:
+  // TODO(TJ): maybe change the name.
+  // Convert all fused subgraphs to npu clients:
+  // 1. if some subgraph failed, then skip it.
+  // 2. add new graph nodes, kernels and context
+  // 3. remove unused nodes
+  void ConvertSubgraph(const std::unique_ptr<SSAGraph>& graph, int sub_num);
+
+  // Call the convert function from the start node.
+  // Return whether the conversion succeeded and the nodes to remove.
+  // Return the output (arg.name, npu op).
+  lite::npu::bridge::node_map_type CvtOpNodes(
+      const lite::npu::bridge::cvt_map_type& cvtfunc_map,
+      const Node* op_node,
+      const lite::npu::bridge::node_map_type& inputs_map,
+      int sub_id,
+      std::unordered_set<const Node*>* nodes2rm,
+      key2nodes_t* matched);
+
+ private:
+  std::vector<Instruction> insts_;
+};
+
+}  // namespace subgraph
+}  // namespace mir
+}  // namespace lite
+}  // namespace paddle
diff --git a/lite/core/mir/subgraph/generate_npu_program_pass_test.cc b/lite/core/mir/subgraph/generate_npu_program_pass_test.cc
new file mode 100644
index 00000000000..d7ce9ed7d67
--- /dev/null
+++ b/lite/core/mir/subgraph/generate_npu_program_pass_test.cc
@@ -0,0 +1,65 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <gflags/gflags.h>
+#include <gtest/gtest.h>
+#include <vector>
+#include "lite/core/mir/graph_visualize_pass.h"
+#include "lite/core/mir/subgraph/subgraph_program_pass.h"
+#include "lite/core/program.h"
+#include "lite/core/tensor.h"
+
+#include "lite/api/cxx_api.h"
+#include "lite/api/paddle_use_kernels.h"
+#include "lite/api/paddle_use_ops.h"
+#include "lite/api/paddle_use_passes.h"
+#include "lite/core/op_registry.h"
+
+#include "lite/model_parser/pb/program_desc.h"
+
+DEFINE_string(model_dir, "", "model_dir");
+DEFINE_string(optimized_model, "", "optimized_model");
+
+namespace paddle {
+namespace lite {
+
+TEST(NPUSubgraph, mobilenetv1) {
+  lite::Predictor predictor;
+  std::vector<Place> valid_places({Place{TARGET(kHost), PRECISION(kFloat)},
+                                   Place{TARGET(kARM), PRECISION(kFloat)},
+                                   Place{TARGET(kNPU), PRECISION(kFloat)}});
+  predictor.Build(
+      FLAGS_model_dir, Place{TARGET(kARM), PRECISION(kFloat)}, valid_places);
+
+  auto* input_tensor = predictor.GetInput(0);
+  input_tensor->Resize(DDim(std::vector<int64_t>({1, 3, 224, 224})));
+  // input_tensor->Resize(DDim(std::vector<int64_t>({1, 13, 1, 1})));
+  auto* data = input_tensor->mutable_data<float>();
+  auto item_size = input_tensor->dims().production();
+  for (int i = 0; i < item_size; i++) {
+    data[i] = 1;
+  }
+
+  predictor.GenNPURuntimeProgram();
+
+  for (int i = 0; i < 10; ++i) {
+    predictor.Run();
+  }
+
+  LOG(INFO) << "Save optimized model to " << FLAGS_optimized_model;
+  predictor.SaveModel(FLAGS_optimized_model);
+}
+
+}  // namespace lite
+}  // namespace paddle
diff --git a/lite/core/mir/subgraph/subgraph_program_pass.cc b/lite/core/mir/subgraph/subgraph_program_pass.cc
new file mode 100644
index 00000000000..5816eefe18b
--- /dev/null
+++ b/lite/core/mir/subgraph/subgraph_program_pass.cc
@@ -0,0 +1,139 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lite/core/mir/subgraph/subgraph_program_pass.h"
+#include <algorithm>
+#include <string>
+#include <vector>
+#include "lite/core/mir/graph_visualize_pass.h"
+#include "lite/core/mir/pass_registry.h"
+#include "lite/core/mir/pattern_matcher.h"
+
+namespace paddle {
+namespace lite {
+namespace mir {
+namespace subgraph {
+
+void SubgraphProgramPass::InferOnce(const std::unique_ptr<SSAGraph>& graph) {
+  for (auto& item : graph->StmtTopologicalOrder()) {
+    if (!item->IsStmt()) continue;
+    auto& stmt = item->AsStmt();
+    auto& op = stmt.op();
+    op->CheckShape();
+    op->InferShape();
+    // TODO(xxx): remove Launch() at last.
+    auto& kernels = stmt.kernels();
+    if (!kernels.empty()) {
+      auto& kernel = kernels.front();
+      if (kernel) {
+        kernel->Launch();
+      }
+    }
+  }
+}
+
+void SubgraphProgramPass::InitSubgraphID(
+    const std::unique_ptr<SSAGraph>& graph,
+    const std::vector<std::string>& supported_op_types) {
+  for (auto& item : graph->StmtTopologicalOrder()) {
+    if (!item->IsStmt()) continue;
+    auto& stmt = item->AsStmt();
+    stmt.ClearSubgraphID();
+    if (std::find(supported_op_types.begin(),
+                  supported_op_types.end(),
+                  stmt.op_type()) != supported_op_types.end()) {
+      stmt.SetSubgraphID(0);
+      LOG(INFO) << "supported " << stmt.op_type();
+    } else {
+      LOG(INFO) << "======= not supported " << stmt.op_type();
+    }
+  }
+}
+
+// Mark the current node and all supported output nodes.
+void SubgraphProgramPass::ChangeAllOutConnectedID(Node* node,
+                                                  int to_id,
+                                                  int from_id) {
+  if (!node) return;
+  if (node->IsStmt()) {
+    auto& stmt = node->AsStmt();
+    if (stmt.subgraph_id() == from_id) {
+      stmt.SetSubgraphID(to_id);
+      nodes2rm_[to_id].insert(node);
+      for (auto& i : node->outlinks) {
+        ChangeAllOutConnectedID(i, to_id, from_id);
+      }
+    } else {
+      LOG(INFO) << "failed op type:" << stmt.op_type();
+      return;
+    }
+  } else {
+    // This is an arg node.
+    bool all_out_op_supported = true;
+    for (auto& i : node->outlinks) {
+      if (!i->IsStmt()) return;
+      auto& stmt = i->AsStmt();
+      if (stmt.subgraph_id() != from_id) {
+        all_out_op_supported = false;
+      }
+    }
+    if (!all_out_op_supported) {
+      return;
+    }
+    for (auto& i : node->outlinks) {
+      CHECK(i->IsStmt());
+      auto& stmt = i->AsStmt();
+      CHECK_EQ(stmt.subgraph_id(), from_id);
+      stmt.SetSubgraphID(to_id);
+      nodes2rm_[to_id].insert(i);
+      for (auto& o : i->outlinks) {
+        ChangeAllOutConnectedID(o, to_id, from_id);
+      }
+    }
+  }
+}
+
+int SubgraphProgramPass::FuseSubgraphID(
+    const std::unique_ptr<SSAGraph>& graph) {
+  int sub_id = 1;  // id starts from 1, not 0
+  for (auto& item : graph->StmtTopologicalOrder()) {
+    if (!item->IsStmt()) continue;
+    auto& stmt = item->AsStmt();
+    if (stmt.subgraph_id() != 0) continue;
+    ChangeAllOutConnectedID(item, sub_id);
+    sub_id++;
+  }
+  return sub_id - 1;
+}
+
+int SubgraphProgramPass::FuseSubgraph(
+    const std::unique_ptr<SSAGraph>& graph,
+    const std::vector<std::string>& supported_op_types) {
+  InitSubgraphID(graph, supported_op_types);
+  nodes2rm_.clear();
+  i_nodes_.clear();
+  o_nodes_.clear();
+  int num_subgraph = FuseSubgraphID(graph);
+  LOG(INFO) << "detected " << num_subgraph << " subgraphs";
+  return num_subgraph;
+}
+
+}  // namespace subgraph
+}  // namespace mir
+}  // namespace lite
+}  // namespace paddle
+
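+// Registration makes the pass retrievable by name from the MIR pass registry.
+// A driver would look it up and run it roughly like this (hypothetical
+// sketch; the exact lookup API may differ):
+//   auto* pass = mir::PassManager::Global().LookUp("subgraph_program_pass");
+//   pass->Apply(graph);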
+REGISTER_MIR_PASS(subgraph_program_pass,
+                  paddle::lite::mir::subgraph::SubgraphProgramPass);
diff --git a/lite/core/mir/subgraph/subgraph_program_pass.h b/lite/core/mir/subgraph/subgraph_program_pass.h
new file mode 100644
index 00000000000..4348c3439f2
--- /dev/null
+++ b/lite/core/mir/subgraph/subgraph_program_pass.h
@@ -0,0 +1,70 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <map>
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <unordered_set>
+#include <vector>
+#include "lite/core/mir/pass.h"
+
+namespace paddle {
+namespace lite {
+namespace mir {
+namespace subgraph {
+
+class SubgraphProgramPass : public ProgramPass {
+ public:
+  using key2nodes_t = std::map<std::string, Node*>;
+
+  // Give all linked ops in a subgraph the same subgraph_id;
+  // return the number of fused subgraphs.
+  int FuseSubgraph(const std::unique_ptr<SSAGraph>& graph,
+                   const std::vector<std::string>& supported_op_types);
+
+  void Apply(const std::unique_ptr<SSAGraph>& graph) override {}
+
+ protected:
+  void InferOnce(const std::unique_ptr<SSAGraph>& graph);
+
+  // Clear all subgraph ids and mark all ops that could be fused with id zero.
+  void InitSubgraphID(const std::unique_ptr<SSAGraph>& graph,
+                      const std::vector<std::string>& supported_op_types);
+
+  // Give all linked ops in a subgraph the same subgraph_id;
+  // return the number of fused subgraphs.
+  int FuseSubgraphID(const std::unique_ptr<SSAGraph>& graph);
+
+  // // GenerateFusedGraph:
+  // std::unique_ptr<SSAGraph> GenerateFusedGraph(
+  //     const std::unique_ptr<SSAGraph>& graph, int sub_num);
+  void ChangeAllOutConnectedID(Node* node, int to_id, int from_id = 0);
+
+ private:
+  // {1: {nodes2rm_in_subgraph1, ...},
+  //  2: {nodes2rm_in_subgraph2, ...}}
+  std::unordered_map<int, std::unordered_set<const Node*>> nodes2rm_;
+  // input nodes
+  std::unordered_map<int, std::unordered_set<const Node*>> i_nodes_;
+  // output nodes
+  std::unordered_map<int, std::unordered_set<const Node*>> o_nodes_;
+};
+
+}  // namespace subgraph
+}  // namespace mir
+}  // namespace lite
+}  // namespace paddle
diff --git a/lite/core/mir/subgraph/subgraph_program_pass_test.cc b/lite/core/mir/subgraph/subgraph_program_pass_test.cc
new file mode 100644
index 00000000000..f9b8dd38b12
--- /dev/null
+++ b/lite/core/mir/subgraph/subgraph_program_pass_test.cc
@@ -0,0 +1,140 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lite/core/mir/subgraph/subgraph_program_pass.h"
+#include <gflags/gflags.h>
+#include <gtest/gtest.h>
+#include <vector>
+#include "lite/api/paddle_use_ops.h"
+#include "lite/api/paddle_use_passes.h"
+#include "lite/core/mir/graph_visualize_pass.h"
+#include "lite/core/mir/ssa_graph.h"
+#include "lite/core/program.h"
+#include "lite/model_parser/cpp/program_desc.h"
+#include "lite/model_parser/model_parser.h"
+
+DEFINE_string(model_dir, "", "model_dir");
+
+namespace paddle {
+namespace lite {
+
+TEST(SubgraphTest, mobilenetv1) {
+  cpp::ProgramDesc program_desc;
+  auto scope = std::make_shared<Scope>();
+  LoadModelPb(FLAGS_model_dir, scope.get(), &program_desc);
+  std::vector<Place> valid_places({
+      Place{TARGET(kHost), PRECISION(kFloat)},
+#ifdef LITE_WITH_ARM
+      Place{TARGET(kARM), PRECISION(kFloat)},
+#endif
+#ifdef LITE_WITH_NPU
+      Place{TARGET(kNPU), PRECISION(kFloat)},
+#endif
+  });
+  lite::Program program(program_desc, scope, valid_places);
+  auto graph = std::unique_ptr<mir::SSAGraph>(new mir::SSAGraph());
+  graph->Build(program, valid_places);
+
+  std::vector<std::string> supported_op_types{"conv2d",
+                                              "depthwise_conv2d",
+                                              "batch_norm",
+                                              "scale",
+                                              "pool2d",
+                                              "mul",
+                                              "elementwise_add",
+                                              "softmax",
+                                              "relu"};
+  auto* pass = new mir::subgraph::SubgraphProgramPass;
+  ASSERT_EQ(pass->FuseSubgraph(graph, supported_op_types), 1);
+}
+
+// Returns output_var_names.
+std::vector<std::string> AddFCDesc(
+    cpp::BlockDesc* block_desc,
+    const std::shared_ptr<Scope>& scope,
+    const std::vector<std::string>& input_var_names,
+    const std::vector<int64_t>& wshape) {
+  CHECK_EQ(input_var_names.size(), 1);
+  CHECK_EQ(wshape.size(), 2);
+  static int id = 0;
+  std::string prefix = "fc_" + std::to_string(id);
+  auto* op_desc = block_desc->AddOp<cpp::OpDesc>();
+  auto* wgt = block_desc->AddVar<cpp::VarDesc>();
+  auto* bias = block_desc->AddVar<cpp::VarDesc>();
+  auto* out = block_desc->AddVar<cpp::VarDesc>();
+
+  wgt->SetName(prefix + "_W");
+  bias->SetName(prefix + "_Bias");
+  out->SetName(prefix + "_Out");
+  std::vector<std::string> out_var_names{prefix + "_Out"};
+
+  auto* wtensor = scope->Var(prefix + "_W")->GetMutable<lite::Tensor>();
+  wtensor->Resize(wshape);
+  wtensor->mutable_data<float>();
+
+  auto* btensor = scope->Var(prefix + "_Bias")->GetMutable<lite::Tensor>();
+  btensor->Resize({wshape[1]});
+  btensor->mutable_data<float>();
+
+  scope->Var(prefix + "_Out")->GetMutable<lite::Tensor>();
+
+  op_desc->SetType("fc");
+  op_desc->SetInput("Input", input_var_names);
+  op_desc->SetInput("W", {prefix + "_W"});
+  op_desc->SetInput("Bias", {prefix + "_Bias"});
+  op_desc->SetAttr("in_num_col_dims", 1);
+  op_desc->SetOutput("Out", out_var_names);
+  id++;
+  return out_var_names;
+}
+
+std::unique_ptr<mir::SSAGraph> BuildSimpleNet(
+    cpp::ProgramDesc* program_desc,
+    const std::shared_ptr<Scope>& scope,
+    const std::vector<Place>& valid_places) {
+  program_desc->ClearBlocks();
+  auto* block_desc = program_desc->AddBlock<cpp::BlockDesc>();
+  block_desc->ClearOps();
+  block_desc->ClearVars();
+
+  auto* var_desc = block_desc->AddVar<cpp::VarDesc>();
+  var_desc->SetName("feed_var");
+  auto* feed_var = scope->Var("feed_var")->GetMutable<lite::Tensor>();
+  feed_var->Resize({1, 4});
+  auto fc1_out = AddFCDesc(block_desc, scope, {"feed_var"}, {4, 5});
+  auto fc2_out = AddFCDesc(block_desc, scope, fc1_out, {5, 2});
+
+  lite::Program program(*program_desc, scope, valid_places);
+  auto graph = std::unique_ptr<mir::SSAGraph>(new mir::SSAGraph());
+  graph->Build(program, valid_places);
+
+  return graph;
+}
+
+TEST(SubGraphTest, SimpleNet) {
+  cpp::ProgramDesc program_desc;
+  std::vector<Place> places{{TARGET(kHost), PRECISION(kFloat)}};
+  auto scope = std::make_shared<Scope>();
+  auto graph = BuildSimpleNet(&program_desc, scope, places);
+
+  std::vector<std::string> supported_op_types{"fc"};
+  auto* pass = new mir::subgraph::SubgraphProgramPass;
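+  // The net is fc_0 -> fc_1 and both ops are in supported_op_types, so the
+  // two ops should fuse into a single subgraph (hence the expected count 1).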
+  ASSERT_EQ(pass->FuseSubgraph(graph, supported_op_types), 1);
+
+  const int num_nodes = graph->nodes().size();
+  ASSERT_EQ(num_nodes, 9);
+}
+
+}  // namespace lite
+}  // namespace paddle
diff --git a/lite/core/mir/type_layout_cast_pass.cc b/lite/core/mir/type_layout_cast_pass.cc
new file mode 100644
index 00000000000..2b216ceec59
--- /dev/null
+++ b/lite/core/mir/type_layout_cast_pass.cc
@@ -0,0 +1,176 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lite/core/mir/type_layout_cast_pass.h"
+#include <list>
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+#include "lite/core/mir/graph_visualize_pass.h"
+#include "lite/core/mir/pass_registry.h"
+#include "lite/utils/string.h"
+
+namespace paddle {
+namespace lite {
+namespace mir {
+
+void TypeLayoutTransformPass::Apply(const std::unique_ptr<SSAGraph>& graph) {
+  // Start from the inputs of the graph; those should have their place set.
+  std::list<Node*> nodes;
+  for (auto& node : graph->mutable_nodes()) {
+    nodes.push_back(&node);
+  }
+
+  for (auto& node : nodes) {
+    if (!node->IsStmt()) continue;
+    auto inlinks = node->inlinks;
+    for (auto* in : inlinks) {
+      ComplementInputs(graph.get(), node, in);
+    }
+  }
+  VLOG(3) << "\n" << Visualize(graph.get());
+}
+
+void TypeLayoutTransformPass::ComplementInputs(SSAGraph* graph,
+                                               Node* inst_node,
+                                               Node* in) {
+  // Skip if this input is out of date, i.e. no longer one of the inlinks.
+  if (inst_node->inlinks.end() ==
+      std::find(inst_node->inlinks.begin(), inst_node->inlinks.end(), in))
+    return;
+
+  CHECK(inst_node->IsStmt());
+  auto& inst = inst_node->AsStmt();
+  CHECK(in->IsRoleSet());
+  CHECK(in->IsArg());
+  auto in_arg_name = in->AsArg().name;
+  std::string tmp;
+  CHECK(inst.op_info()->GetInputArgname(in_arg_name, &tmp));
+  auto decl_arg_type = inst.picked_kernel().GetInputDeclType(tmp);
+  CHECK(in->AsArg().type);
+  if (!DataLayoutCompatible(*in->AsArg().type, *decl_arg_type)) {
+    VLOG(4) << "found Layout unmatched tensor: " << in->AsArg().name
+            << " for kernel " << inst.op()->DebugString() << " "
+            << *in->AsArg().type << " -> " << *decl_arg_type;
+    AddLayoutInst(*in->AsArg().type,
+                  *decl_arg_type,
+                  in,
+                  graph,
+                  inst_node,
+                  graph->valid_places());
+  }
+}
+
+void TypeLayoutTransformPass::AddLayoutInst(
+    const Type& from,
+    const Type& to,
+    Node* in,
+    SSAGraph* graph,
+    Node* inst_node,
+    const std::vector<Place>& valid_places) {
+  CHECK(!valid_places.empty()) << "valid_place should be set";
+
+  CHECK(in->IsArg());
+  auto node_id = [&] { return graph->nodes().size(); };
+  auto layout_output_name =
+      string_format("%s/trans/%d", in->AsArg().name.c_str(), node_id());
+  auto* layout_output_arg = graph->NewArgumentNode(layout_output_name);
+  auto* layout_inst = graph->NewInstructNode();
+
+  bool in_persist = in->AsArg().is_weight || in->AsArg().is_persist;
+  std::string layout_type = in_persist ? "layout_once" : "layout";
+  // Create the op and its kernels.
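+  // ("layout_once" is used for persistent inputs such as weights, which only
+  // need the transform applied a single time; "layout" runs per execution.)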
+  auto layout_op = LiteOpRegistry::Global().Create(layout_type);
+  CHECK(layout_op) << "create op [" << layout_type << "] failed";
+  layout_output_arg->AsArg().is_persist = in_persist;
+  // Create the new var manually.
+  inst_node->AsStmt().op()->scope()->Var(layout_output_name);
+
+  // Create Layout Instruction.
+  cpp::OpDesc op_desc;
+  op_desc.SetType(layout_type);
+  op_desc.SetInput("Input", {in->AsArg().name});
+  op_desc.SetOutput("Out", {layout_output_name});
+
+  layout_op->Attach(op_desc, inst_node->AsStmt().op()->scope());
+  auto kernels = layout_op->CreateKernels(valid_places);
+  std::vector<std::unique_ptr<KernelBase>> selected_kernels;
+  bool is_found = false;
+  for (auto& kernel : kernels) {
+    const Type* in_arg_ty = kernel->GetInputDeclType("Input");
+    const Type* out_arg_ty = kernel->GetOutputDeclType("Out");
+    if (TypeCompatible(*in_arg_ty, from)) {
+      is_found = true;
+      selected_kernels.emplace_back(std::move(kernel));
+      // we pick the kernel
+      layout_inst->AsStmt(layout_type, std::move(selected_kernels), layout_op);
+      break;
+    }
+  }
+  CHECK(is_found) << "Can't find a layout kernel for layout op: " << from
+                  << ":" << in->AsArg().name << "->" << to << ":"
+                  << inst_node->AsStmt().op_info()->Type();
+
+  // Remove the old link
+  RemoveDirectedLink(in, inst_node);
+
+  // Update the original instruction OpDesc.
+  // Update its input to the layout_output_name
+  // Add new links: var -> new_inst, new_inst -> new_var, new_var -> inst
+  DirectedLink(in, layout_inst);
+  DirectedLink(layout_inst, layout_output_arg);
+  DirectedLink(layout_output_arg, inst_node);
+
+  // reset opdesc and update kernel information
+  UpdateInputTo(inst_node->AsStmt().op()->mutable_op_info(),
+                in->AsArg().name,
+                layout_output_name);
+  auto original_selected_kernel =
+      std::move(inst_node->AsStmt().kernels().front());
+  auto update_op_info = *inst_node->AsStmt().op_info();
+  // ResetOp() will change the Stmt op_info_ value,
+  // after that the old op_info_ value will be nullified.
+  // So we can't pass `*inst_node->AsStmt().op_info()` into ResetOp.
+  // `update_op_info` is a copy of `*inst_node->AsStmt().op_info()`.
+  // Whenever we update the op_info of a stmt, we should call its ResetOp().
+  inst_node->AsStmt().ResetOp(update_op_info, graph->valid_places());
+  inst_node->AsStmt().kernels().clear();
+  inst_node->AsStmt().kernels().emplace_back(
+      std::move(original_selected_kernel));
+
+  std::string tmp;
+  if (inst_node->AsStmt().op_info()->GetInputArgname("a", &tmp)) {
+    CHECK(false) << "get old a " << tmp;
+  }
+
+  for (auto& kernel : inst_node->AsStmt().kernels()) {
+    inst_node->AsStmt().op()->AttachKernel(kernel.get());
+  }
+
+  graph->CheckValid();
+}
+
+void TypeLayoutTransformPass::SetValidPlaces(
+    const std::vector<Place>& valid_places) {
+  CHECK(!valid_places.empty());
+  valid_places_ = valid_places;
+}
+
+}  // namespace mir
+}  // namespace lite
+}  // namespace paddle
+
+REGISTER_MIR_PASS(type_layout_cast_pass,
+                  paddle::lite::mir::TypeLayoutTransformPass);
diff --git a/lite/core/mir/type_layout_cast_pass.h b/lite/core/mir/type_layout_cast_pass.h
new file mode 100644
index 00000000000..bf36214e1dc
--- /dev/null
+++ b/lite/core/mir/type_layout_cast_pass.h
@@ -0,0 +1,62 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <memory>
+#include <string>
+#include <vector>
+#include "lite/core/mir/pass.h"
+#include "lite/core/op_registry.h"
+
+namespace paddle {
+namespace lite {
+namespace mir {
+
+static void UpdateInputTo(cpp::OpDesc* desc,
+                          const std::string& from,
+                          const std::string& to) {
+  for (auto& item : *desc->mutable_inputs()) {
+    for (auto& input : item.second) {
+      if (input == from) {
+        input = to;
+      }
+    }
+  }
+}
+
+class TypeLayoutTransformPass : public ProgramPass {
+ public:
+  void Apply(const std::unique_ptr<SSAGraph>& graph) override;
+
+  void ComplementInputs(SSAGraph* graph, Node* inst_node, Node* in);
+
+  void AddLayoutInst(const Type& from,
+                     const Type& to,
+                     Node* in,
+                     SSAGraph* graph,
+                     Node* inst_node,
+                     const std::vector<Place>& valid_places);
+
+  void SetValidPlaces(const std::vector<Place>& valid_places);
+
+  const std::vector<Place>& valid_places() const { return valid_places_; }
+
+ private:
+  std::vector<Place> valid_places_;
+};
+
+}  // namespace mir
+}  // namespace lite
+}  // namespace paddle
diff --git a/lite/core/mir/type_precision_cast_pass.cc b/lite/core/mir/type_precision_cast_pass.cc
new file mode 100644
index 00000000000..517f9a9b70f
--- /dev/null
+++ b/lite/core/mir/type_precision_cast_pass.cc
@@ -0,0 +1,182 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lite/core/mir/type_precision_cast_pass.h"
+#include <list>
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+#include "lite/core/mir/graph_visualize_pass.h"
+#include "lite/core/mir/pass_registry.h"
+
+namespace paddle {
+namespace lite {
+namespace mir {
+
+void PrecisionCastPass::Apply(const std::unique_ptr<SSAGraph>& graph) {
+  // Start from inputs of the graph, those should have place set.
+  std::list<Node*> nodes;
+  for (auto& node : graph->mutable_nodes()) {
+    nodes.push_back(&node);
+  }
+
+  for (auto& node : nodes) {
+    if (!node->IsStmt()) continue;
+    auto inlinks = node->inlinks;
+    for (auto* in : inlinks) {
+      ComplementInputs(graph.get(), node, in);
+    }
+  }
+}
+
+void PrecisionCastPass::ComplementInputs(SSAGraph* graph,
+                                         Node* inst_node,
+                                         Node* in) {
+  // If this input is out of date.
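+  // "Out of date" means a previously inserted cast already rewired this edge;
+  // in that case `in` is no longer among inst_node's inlinks and there is
+  // nothing left to complement.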
+  if (inst_node->inlinks.end() ==
+      std::find(inst_node->inlinks.begin(), inst_node->inlinks.end(), in))
+    return;
+
+  CHECK(inst_node->IsStmt());
+  auto& inst = inst_node->AsStmt();
+  CHECK(in->IsRoleSet());
+  CHECK(in->IsArg());
+  auto in_arg_name = in->AsArg().name;
+  std::string tmp;
+  CHECK(inst.op_info()->GetInputArgname(in_arg_name, &tmp));
+  auto decl_arg_type = inst.picked_kernel().GetInputDeclType(tmp);
+  CHECK(in->AsArg().type);
+  VLOG(4) << inst.picked_kernel().name();
+  // if (!in->AsArg().is_weight && !PrecisionCompatibleTo(*in->AsArg().type,
+  // *decl_arg_type)) {
+  if (!PrecisionCompatibleTo(*in->AsArg().type, *decl_arg_type)) {
+    VLOG(4) << "found Precision unmatched tensor: " << in->AsArg().name
+            << " for kernel " << inst.op()->DebugString() << " "
+            << *in->AsArg().type << " -> " << *decl_arg_type;
+    // Add a Cast instruction to make the input compatible with the kernel's
+    // declared precision.
+    AddCastInst(*in->AsArg().type,
+                *decl_arg_type,
+                in,
+                graph,
+                inst_node,
+                graph->valid_places());
+  }
+}
+
+void PrecisionCastPass::AddCastInst(const Type& from,
+                                    const Type& to,
+                                    Node* in,
+                                    SSAGraph* graph,
+                                    Node* inst_node,
+                                    const std::vector<Place>& valid_places) {
+  CHECK(!valid_places.empty()) << "valid_place should be set";
+
+  // var -> new_transform_op -> new_var -> inst
+  // So there will be a new Argument node and a new Cast Statement Node.
+  CHECK(in->IsArg());
+  auto node_id = [&] { return graph->nodes().size(); };
+  auto cast_op_output_name =
+      in->AsArg().name + "/trans/" + std::to_string(node_id());
+  auto* cast_op_output_arg = graph->NewArgumentNode(cast_op_output_name);
+  auto* cast_inst = graph->NewInstructNode();
+
+  // create Op and kernels.
+  bool in_persist = in->AsArg().is_weight || in->AsArg().is_persist;
+  std::string cast_type = in_persist ? "calib_once" : "calib";
+  cast_op_output_arg->AsArg().is_persist = in_persist;
+  auto cast_op = LiteOpRegistry::Global().Create(cast_type);
+  CHECK(cast_op) << "create op [" << cast_type << "] failed";
+
+  // Create the new var manually.
+  inst_node->AsStmt().op()->scope()->Var(cast_op_output_name);
+
+  // Create Calib Instruction.
+  cpp::OpDesc op_desc;
+  op_desc.SetType(cast_type);
+  op_desc.SetInput("Input", {in->AsArg().name});
+  op_desc.SetOutput("Out", {cast_op_output_name});
+  if (inst_node->AsStmt().op_info()->HasAttr("input_scale")) {
+    op_desc.SetAttr(
+        "scale", inst_node->AsStmt().op_info()->GetAttr<float>("input_scale"));
+  }
+  cast_op->Attach(op_desc, inst_node->AsStmt().op()->scope());
+  auto kernels = cast_op->CreateKernels(valid_places);
+  std::vector<std::unique_ptr<KernelBase>> selected_kernels;
+  bool is_found = false;
+  for (auto& kernel : kernels) {
+    const Type* in_arg_ty = kernel->GetInputDeclType("Input");
+    const Type* out_arg_ty = kernel->GetOutputDeclType("Out");
+// TODO(xg): to optimize this
+#ifndef LITE_WITH_FPGA
+    if (in_arg_ty->precision() == from.precision() &&
+        out_arg_ty->precision() == to.precision()) {
+#else
+    if (TypeCompatible(*in_arg_ty, from)) {
+#endif
+      is_found = true;
+      selected_kernels.emplace_back(std::move(kernel));
+      // we pick the kernel
+      cast_inst->AsStmt(cast_type, std::move(selected_kernels), cast_op);
+      break;
+    }
+  }
+
+  CHECK(is_found) << "Can't find a Cast kernel for Cast op: " << from << ":"
+                  << in->AsArg().name << "->" << to << ":"
+                  << inst_node->AsStmt().op_info()->Type();
+
+  // Remove the old link
+  RemoveDirectedLink(in, inst_node);
+
+  // Update the original instruction OpDesc.
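+  // For example (hypothetical names): an input slot recorded as
+  //   Input["X"] = {"conv_out"}
+  // is rewritten to
+  //   Input["X"] = {"conv_out/trans/7"}
+  // while the kernel that was selected for the instruction is kept.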
+  // Update its input to the cast_op_output_name.
+
+  // Add new links: var -> new_inst, new_inst -> new_var, new_var -> inst
+  DirectedLink(in, cast_inst);
+  DirectedLink(cast_inst, cast_op_output_arg);
+  DirectedLink(cast_op_output_arg, inst_node);
+
+  // reset opdesc and update kernel information
+  UpdateInputTo(inst_node->AsStmt().op()->mutable_op_info(),
+                in->AsArg().name,
+                cast_op_output_name);
+
+  // recreate the op
+  auto original_selected_kernel =
+      std::move(inst_node->AsStmt().kernels().front());
+  auto updated_op_info = *inst_node->AsStmt().mutable_op_info();
+
+  inst_node->AsStmt().ResetOp(updated_op_info, graph->valid_places());
+  inst_node->AsStmt().kernels().clear();
+  inst_node->AsStmt().kernels().emplace_back(
+      std::move(original_selected_kernel));
+  for (auto& kernel : inst_node->AsStmt().kernels()) {
+    VLOG(4) << "kernel info: " << kernel->name();
+    inst_node->AsStmt().op()->AttachKernel(kernel.get());
+  }
+  graph->CheckValid();
+}
+
+void PrecisionCastPass::SetValidPlaces(const std::vector<Place>& valid_places) {
+  CHECK(!valid_places.empty());
+  valid_places_ = valid_places;
+}
+
+}  // namespace mir
+}  // namespace lite
+}  // namespace paddle
+
+REGISTER_MIR_PASS(type_precision_cast_pass,
+                  paddle::lite::mir::PrecisionCastPass);
diff --git a/lite/core/mir/type_precision_cast_pass.h b/lite/core/mir/type_precision_cast_pass.h
new file mode 100644
index 00000000000..3f55e52ef9f
--- /dev/null
+++ b/lite/core/mir/type_precision_cast_pass.h
@@ -0,0 +1,66 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <memory>
+#include <string>
+#include <vector>
+#include "lite/core/mir/pass.h"
+#include "lite/core/op_registry.h"
+
+namespace paddle {
+namespace lite {
+namespace mir {
+
+static void UpdateInputTo(cpp::OpDesc* desc,
+                          const std::string& from,
+                          const std::string& to) {
+  for (auto& item : *desc->mutable_inputs()) {
+    for (auto& input : item.second) {
+      if (input == from) {
+        input = to;
+      }
+    }
+  }
+}
+
+/*
+ * This pass complements the necessary Cast instructions to make data
+ * transfer or transformation between different precisions possible.
+ */
+class PrecisionCastPass : public ProgramPass {
+ public:
+  void Apply(const std::unique_ptr<SSAGraph>& graph) override;
+
+  void ComplementInputs(SSAGraph* graph, Node* inst_node, Node* in);
+
+  void AddCastInst(const Type& from,
+                   const Type& to,
+                   Node* in,
+                   SSAGraph* graph,
+                   Node* inst_node,
+                   const std::vector<Place>& valid_places);
+
+  void SetValidPlaces(const std::vector<Place>& valid_places);
+
+  const std::vector<Place>& valid_places() const { return valid_places_; }
+
+ private:
+  std::vector<Place> valid_places_;
+};
+
+}  // namespace mir
+}  // namespace lite
+}  // namespace paddle
diff --git a/lite/core/mir/type_target_cast_pass.cc b/lite/core/mir/type_target_cast_pass.cc
new file mode 100644
index 00000000000..f653654e967
--- /dev/null
+++ b/lite/core/mir/type_target_cast_pass.cc
@@ -0,0 +1,182 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lite/core/mir/type_target_cast_pass.h"
+#include <list>
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+#include "lite/core/mir/graph_visualize_pass.h"
+#include "lite/core/mir/pass_registry.h"
+#include "lite/utils/string.h"
+
+namespace paddle {
+namespace lite {
+namespace mir {
+
+void TypeTargetTransformPass::Apply(const std::unique_ptr<SSAGraph>& graph) {
+  // Start from inputs of the graph, those should have place set.
+  std::list<Node*> nodes;
+  for (auto& node : graph->mutable_nodes()) {
+    nodes.push_back(&node);
+  }
+
+  CHECK(!valid_places_.empty());
+
+  for (auto& node : nodes) {
+    if (!node->IsStmt()) continue;
+    auto inlinks = node->inlinks;
+    for (auto* in : inlinks) {
+      ComplementInputs(graph.get(), node, in);
+    }
+  }
+}
+
+void TypeTargetTransformPass::ComplementInputs(SSAGraph* graph,
+                                               Node* inst_node,
+                                               Node* in) {
+  // If this input is out of date.
+  if (inst_node->inlinks.end() ==
+      std::find(inst_node->inlinks.begin(), inst_node->inlinks.end(), in))
+    return;
+
+  CHECK(inst_node->IsStmt());
+  auto& inst = inst_node->AsStmt();
+  LOG(INFO) << "found Target tensor: " << in->AsArg().name;
+  CHECK(in->IsRoleSet());
+  CHECK(in->IsArg());
+  auto in_arg_name = in->AsArg().name;
+  std::string tmp;
+  CHECK(inst.op_info()->GetInputArgname(in_arg_name, &tmp));
+  auto decl_arg_type = inst.picked_kernel().GetInputDeclType(tmp);
+  CHECK(in->AsArg().type);
+  if (!TargetCompatibleTo(*in->AsArg().type, *decl_arg_type)) {
+    LOG(INFO) << "found Target unmatched tensor: " << in->AsArg().name
+              << " for kernel " << inst.op()->DebugString() << " "
+              << *in->AsArg().type << " -> " << *decl_arg_type;
+    // Add an IoCopy instruction to make the input compatible with the
+    // kernel's declared target.
+    AddIoCopyInst(
+        *in->AsArg().type, *decl_arg_type, in, graph, inst_node, valid_places_);
+  }
+}
+
+void TypeTargetTransformPass::AddIoCopyInst(
+    const Type& from,
+    const Type& to,
+    Node* in,
+    SSAGraph* graph,
+    Node* inst_node,
+    const std::vector<Place>& valid_places) {
+  CHECK(!valid_places.empty()) << "valid_place should be set";
+  // var -> new_transform_op -> new_var -> inst
+  // So there will be a new Argument node and a new IoCopy Statement Node.
+
+  CHECK(in->IsArg());
+  auto node_id = [&] { return graph->nodes().size(); };
+  auto io_copy_output_name =
+      string_format("%s/trans/%d", in->AsArg().name.c_str(), node_id());
+  // TODO(MyPandaShaoxiang) should set same place with input?
+  auto* io_copy_output_arg = graph->NewArgumentNode(io_copy_output_name);
+  auto* io_copy_inst = graph->NewInstructNode();
+
+  bool in_persist = in->AsArg().is_weight || in->AsArg().is_persist;
+  std::string io_copy_type = in_persist ? "io_copy_once" : "io_copy";
+  io_copy_output_arg->AsArg().is_persist = in_persist;
+  // create Op and kernels.
+  auto io_copy_op = LiteOpRegistry::Global().Create(io_copy_type);
+  CHECK(io_copy_op) << "create op [" << io_copy_type << "] failed";
+  // Create the new var manually.
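+  // The output variable has to exist in the scope before the io_copy kernel
+  // can run, since kernels resolve their arguments by name from this scope.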
+  inst_node->AsStmt().op()->scope()->Var(io_copy_output_name);
+
+  // Create IoCopy Instruction.
+  cpp::OpDesc op_desc;
+  op_desc.SetType(io_copy_type);
+  op_desc.SetInput("Input", {in->AsArg().name});
+  op_desc.SetOutput("Out", {io_copy_output_name});
+
+  io_copy_op->Attach(op_desc, inst_node->AsStmt().op()->scope());
+  auto kernels = io_copy_op->CreateKernels(valid_places);
+  // fix(MyPandaShaoxiang): select a kernel whose input decl type matches
+  // in.type
+  bool is_found = false;
+  std::vector<std::unique_ptr<KernelBase>> selected_kernels;
+  for (auto& kernel : kernels) {
+    const Type* in_arg_ty = kernel->GetInputDeclType("Input");
+    const Type* out_arg_ty = kernel->GetOutputDeclType("Out");
+    if (TypeCompatible(*in_arg_ty, from)) {
+      is_found = true;
+      selected_kernels.emplace_back(std::move(kernel));
+      // we pick the kernel
+      io_copy_inst->AsStmt(
+          io_copy_type, std::move(selected_kernels), io_copy_op);
+      break;
+    }
+  }
+  CHECK(is_found) << "Can't find an io_copy kernel for io_copy op: " << from
+                  << ":" << in->AsArg().name << "->" << to << ":"
+                  << inst_node->AsStmt().op_info()->Type();
+
+  // Remove the old link
+  RemoveDirectedLink(in, inst_node);
+
+  // Update the original instruction OpDesc.
+  // Update its input to the io_copy_output_name
+  // Add new links: var -> new_inst, new_inst -> new_var, new_var -> inst
+  DirectedLink(in, io_copy_inst);
+  DirectedLink(io_copy_inst, io_copy_output_arg);
+  DirectedLink(io_copy_output_arg, inst_node);
+
+  // reset opdesc and update kernel information
+  UpdateInputTo(inst_node->AsStmt().op()->mutable_op_info(),
+                in->AsArg().name,
+                io_copy_output_name);
+  auto original_selected_kernel =
+      std::move(inst_node->AsStmt().kernels().front());
+  auto update_op_info = *inst_node->AsStmt().op_info();
+  // ResetOp() will change the Stmt op_info_ value,
+  // after that the old op_info_ value will be nullified.
+  // So we can't pass `*inst_node->AsStmt().op_info()` into ResetOp.
+  // `update_op_info` is a copy of `*inst_node->AsStmt().op_info()`.
+  // Whenever we update the op_info of a stmt, we should call its ResetOp().
+  inst_node->AsStmt().ResetOp(update_op_info, graph->valid_places());
+  inst_node->AsStmt().kernels().clear();
+  inst_node->AsStmt().kernels().emplace_back(
+      std::move(original_selected_kernel));
+
+  std::string tmp;
+  if (inst_node->AsStmt().op_info()->GetInputArgname("a", &tmp)) {
+    CHECK(false) << "get old a " << tmp;
+  }
+
+  for (auto& kernel : inst_node->AsStmt().kernels()) {
+    VLOG(4) << "kernel info: " << kernel->name();
+    inst_node->AsStmt().op()->AttachKernel(kernel.get());
+  }
+
+  graph->CheckValid();
+}
+
+void TypeTargetTransformPass::SetValidPlaces(
+    const std::vector<Place>& valid_places) {
+  CHECK(!valid_places.empty());
+  valid_places_ = valid_places;
+}
+
+}  // namespace mir
+}  // namespace lite
+}  // namespace paddle
+
+REGISTER_MIR_PASS(type_target_cast_pass,
+                  paddle::lite::mir::TypeTargetTransformPass);
diff --git a/lite/core/mir/type_target_cast_pass.h b/lite/core/mir/type_target_cast_pass.h
new file mode 100644
index 00000000000..8a8cfaf9f92
--- /dev/null
+++ b/lite/core/mir/type_target_cast_pass.h
@@ -0,0 +1,66 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <memory>
+#include <string>
+#include <vector>
+#include "lite/core/mir/pass.h"
+#include "lite/core/op_registry.h"
+
+namespace paddle {
+namespace lite {
+namespace mir {
+
+static void UpdateInputTo(cpp::OpDesc* desc,
+                          const std::string& from,
+                          const std::string& to) {
+  for (auto& item : *desc->mutable_inputs()) {
+    for (auto& input : item.second) {
+      if (input == from) {
+        input = to;
+      }
+    }
+  }
+}
+
+/*
+ * TypeTargetTransformPass complements the necessary IoCopy instructions to
+ * make data transfer or transformation between different targets possible.
 */
+class TypeTargetTransformPass : public ProgramPass {
+ public:
+  void Apply(const std::unique_ptr<SSAGraph>& graph) override;
+
+  void ComplementInputs(SSAGraph* graph, Node* inst_node, Node* in);
+
+  void AddIoCopyInst(const Type& from,
+                     const Type& to,
+                     Node* in,
+                     SSAGraph* graph,
+                     Node* inst_node,
+                     const std::vector<Place>& valid_places);
+
+  void SetValidPlaces(const std::vector<Place>& valid_places);
+
+  const std::vector<Place>& valid_places() const { return valid_places_; }
+
+ private:
+  std::vector<Place> valid_places_;
+};
+
+}  // namespace mir
+}  // namespace lite
+}  // namespace paddle
diff --git a/lite/core/mir/variable_place_inference_pass.cc b/lite/core/mir/variable_place_inference_pass.cc
new file mode 100644
index 00000000000..e3795ae6429
--- /dev/null
+++ b/lite/core/mir/variable_place_inference_pass.cc
@@ -0,0 +1,34 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lite/core/mir/variable_place_inference_pass.h"
+#include <memory>
+#include "lite/core/mir/pass_registry.h"
+
+namespace paddle {
+namespace lite {
+namespace mir {
+
+void VariablePlaceInferencePass::Apply(const std::unique_ptr<SSAGraph> &graph) {
+  MarkInputPlace(graph.get());
+  InferenceArgumentPlace(graph.get());
+  CheckAllArgumentTypeDetermined(graph.get());
+}
+
+}  // namespace mir
+}  // namespace lite
+}  // namespace paddle
+
+REGISTER_MIR_PASS(variable_place_inference_pass,
+                  paddle::lite::mir::VariablePlaceInferencePass);
diff --git a/lite/core/mir/variable_place_inference_pass.h b/lite/core/mir/variable_place_inference_pass.h
new file mode 100644
index 00000000000..d5b0bb8aa67
--- /dev/null
+++ b/lite/core/mir/variable_place_inference_pass.h
@@ -0,0 +1,157 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+#include <map>
+#include <memory>
+#include <string>
+#include <vector>
+#include "lite/core/mir/pass.h"
+#include "lite/core/target_wrapper.h"
+
+namespace paddle {
+namespace lite {
+namespace mir {
+
+/*
+ * Mark the place of the variables in the SSAGraph; it infers each variable's
+ * place from the kernels that output it.
+ */
+class VariablePlaceInferencePass : public DebugPass {
+ public:
+  void Apply(const std::unique_ptr<SSAGraph>& graph) override;
+
+ private:
+  // Mark the place of input arguments.
+  void MarkInputPlace(SSAGraph* graph) {
+    CHECK(!graph->inputs().empty()) << "graph's inputs should be set";
+    for (const auto& v : graph->inputs()) {
+      // the feed op might be among the inputs
+      if (v->IsStmt()) {
+        VLOG(4) << "found kernel in inputs " << v->AsStmt().op_type();
+        continue;
+      }
+    }
+  }
+
+  void CheckAllArgumentTypeDetermined(SSAGraph* graph) {
+    for (auto& node : graph->mutable_nodes()) {
+      if (node.IsArg()) {
+        CHECK(node.AsArg().type) << "node " << node.AsArg().name
+                                 << " type not determined, " << &node;
+      }
+    }
+  }
+
+  // Set the type of the weight
+  void SetWeightType(Node* w, const LiteType& type) {
+// TODO(xg) to optimize this
+#ifndef LITE_WITH_FPGA
+    w->AsArg().type =
+        LiteType::GetTensorTy(TARGET(kHost), type.precision(), type.layout());
+#else
+    w->AsArg().type = LiteType::GetTensorTy(
+        TARGET(kHost), PRECISION(kFloat), DATALAYOUT(kNCHW));
+#endif
+  }
+
+  void InferenceArgumentPlace(SSAGraph* graph) {
+    VLOG(3) << "param-type-registry:\n" << ParamTypeRegistry::Global();
+    for (auto& x : graph->StmtTopologicalOrder()) {
+      auto& inst = x->AsStmt();
+// The IoCopyOp is a tool operator and does not support type inference.
+// On FPGA we have io_copy + calib + layout tool ops, so type inference is
+// needed for tool operators there.
+#ifndef LITE_WITH_FPGA
+      if (inst.op_type() == "io_copy") continue;
+#endif
+      // deal with inputs
+      VLOG(4) << "Inferring op " << inst.op_info()->Repr();
+      // TODO(zhaolong): check that the node's name is among the op's arguments.
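+      // get_argname maps a variable name back to the op argument slot that
+      // holds it. E.g. for an fc op with inputs
+      // {"Input": {"x"}, "W": {"fc_0_W"}},
+      // get_argname("fc_0_W", inst.op_info()->inputs()) returns "W".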
+
+      auto get_argname = [&](
+          const std::string& node_name,
+          const std::map<std::string, std::vector<std::string>>& argname_map)
+          -> std::string {
+        for (auto& ele : argname_map) {
+          auto it =
+              std::find(ele.second.begin(), ele.second.end(), node_name);
+          if (it != ele.second.end()) return ele.first;
+        }
+        return "";
+      };
+
+      for (auto* x_in : x->inlinks) {
+        std::string node_name = x_in->AsArg().name;
+        std::string arg_name = get_argname(node_name, inst.op_info()->inputs());
+        CHECK(arg_name.size() > 0) << "cannot find op arguments for node "
+                                   << node_name;
+        VLOG(4) << "-- input arg_name " << arg_name
+                << " -- node name: " << node_name;
+        auto type = inst.picked_kernel().GetInputDeclType(arg_name);
+        if (!x_in->AsArg().type) {
+          VLOG(4) << "set type " << *type << " " << x_in->AsArg().name;
+          if (x_in->AsArg().is_weight) {
+            SetWeightType(x_in, *type);
+          } else {
+            x_in->AsArg().type = type;
+          }
+        }
+      }
+
+      VLOG(4) << "inst " << inst.op_info()->Repr();
+      for (auto* x_out : x->outlinks) {
+        std::string node_name = x_out->AsArg().name;
+        std::string arg_name =
+            get_argname(node_name, inst.op_info()->outputs());
+        CHECK(arg_name.size() > 0) << "cannot find op arguments for node "
+                                   << node_name << " in Inst "
+                                   << inst.op_type();
+        VLOG(4) << "-- output arg_name " << arg_name;
+        auto type = inst.picked_kernel().GetOutputDeclType(arg_name);
+        if (!x_out->AsArg().type) {
+          VLOG(4) << "set type " << *type << " " << x_out->AsArg().name;
+          if (x_out->AsArg().is_weight) {
+            SetWeightType(x_out, *type);
+          } else {
+            x_out->AsArg().type = type;
+          }
+        }
+      }
+    }
+  }
+
+  // Update me's kUnk fields from other's fields.
+  void UpdatePlace(Place* me, const Place& other) {
+    CHECK(other.is_valid());
+    if (me->target == TARGET(kUnk)) {
+      me->target = other.target;
+    }
+    if (me->precision == PRECISION(kUnk)) {
+      me->precision = other.precision;
+    }
+    if (me->layout == DATALAYOUT(kUnk)) {
+      me->layout = other.layout;
+    }
+  }
+
+ private:
+  // The default target for arguments, e.g. load weights to CPU memory for CUDA
+  // computation by default.
+  TargetType argument_default_target_{TARGET(kHost)};
+};
+
+}  // namespace mir
+}  // namespace lite
+}  // namespace paddle
diff --git a/lite/core/mir/variable_place_inference_pass_test.cc b/lite/core/mir/variable_place_inference_pass_test.cc
new file mode 100644
index 00000000000..cf86afd590d
--- /dev/null
+++ b/lite/core/mir/variable_place_inference_pass_test.cc
@@ -0,0 +1,101 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
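+
+// A sketch of what this test exercises: a feed -> mul -> mul -> fetch program
+// is built, kernels are picked statically, variable places are inferred, and
+// type_target_cast_pass inserts io_copy ops wherever a kernel's declared
+// target differs from a variable's inferred place.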
+ +#include +#include "lite/api/paddle_use_passes.h" +#include "lite/core/optimizer.h" +#include "lite/core/program_fake_utils.h" +#include "lite/kernels/cuda/use_kernels.h" +#include "lite/kernels/host/use_kernels.h" + +namespace paddle { +namespace lite { +namespace mir { + +TEST(variable_place_inference_pass, test) { + std::shared_ptr scope(new lite::Scope); + ProgramFaker program_faker; + program_faker.AddFeed("a", 0); + program_faker.AddMul("a", "W", "a1"); + program_faker.AddMul("a1", "W1", "a2"); + program_faker.AddFetch("a2", 0); + program_faker.CreateVars(scope.get()); + + auto* desc = program_faker.program(); + + Optimizer optimizer; + std::vector places({ + Place{ + TARGET(kHost), PRECISION(kFloat), DATALAYOUT(kNCHW), + }, + Place{ + TARGET(kCUDA), PRECISION(kFloat), DATALAYOUT(kNCHW), + }, + Place{ + TARGET(kX86), PRECISION(kFloat), DATALAYOUT(kNCHW), + }, + Place{ + TARGET(kX86), PRECISION(kAny), DATALAYOUT(kAny), + }, + }); + + Program program(*desc->Proto(), scope, places); + + core::KernelPickFactor factor; + factor.ConsiderTarget(); + + std::vector passes({ + "static_kernel_pick_pass", // + "argument_type_display_pass", // + "variable_place_inference_pass", // + "argument_type_display_pass", // + "type_target_cast_pass", // + }); + + Place prefered_place{ +#ifdef PADDLE_WITH_CUDA + TARGET(kCUDA), PRECISION(kFloat), DATALAYOUT(kNCHW), +#else +#ifdef PADDLE_WITH_ARM + TARGET(kARM), PRECISION(kFloat), DATALAYOUT(kNCHW), +#else // X86 + TARGET(kX86), PRECISION(kFloat), DATALAYOUT(kNCHW), +#endif // ARM +#endif + }; + optimizer.KernelPickPreferPlace(prefered_place); + optimizer.Run(std::move(program), places, factor, passes); +} + +} // namespace mir +} // namespace lite +} // namespace paddle + +USE_LITE_OP(mul); +USE_LITE_OP(feed); +USE_LITE_OP(fetch); +USE_LITE_OP(io_copy); + +#ifdef LITE_WITH_X86 +USE_LITE_KERNEL(mul, kX86, kFloat, kNCHW, def); +#endif + +#ifdef LITE_WITH_ARM +USE_LITE_KERNEL(mul, kARM, kFloat, kNCHW, def); +#endif + +#ifdef LITE_WITH_CUDA +USE_LITE_KERNEL(io_copy, kCUDA, kAny, kAny, host_to_device); +USE_LITE_KERNEL(io_copy, kCUDA, kAny, kAny, device_to_host); +#endif diff --git a/lite/core/naive_test_model.py b/lite/core/naive_test_model.py new file mode 100644 index 00000000000..f89a5e115fa --- /dev/null +++ b/lite/core/naive_test_model.py @@ -0,0 +1,56 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+import sys, os
+import numpy as np
+import paddle.fluid as fluid
+from paddle.fluid.backward import append_backward
+
+a = fluid.layers.data(name="a", shape=[2], dtype='float32')
+label = fluid.layers.data(name="label", shape=[10], dtype='float32')
+
+a1 = fluid.layers.fc(input=a, size=3, act=None, bias_attr=False)
+
+cost = fluid.layers.square_error_cost(a1, label)
+avg_cost = fluid.layers.mean(cost)
+
+optimizer = fluid.optimizer.SGD(learning_rate=0.001)
+optimizer.minimize(cost)
+
+cpu = fluid.core.CPUPlace()
+exe = fluid.Executor(cpu)
+
+exe.run(fluid.default_startup_program())
+with open('startup_program.pb', 'wb') as f:
+    f.write(fluid.default_startup_program().desc.serialize_to_string())
+
+#data_1 = np.array(np.random.random([100, 100]), dtype='float32')
+
+#prog = fluid.compiler.CompiledProgram(fluid.default_main_program())
+prog = fluid.default_main_program()
+
+#append_backward(avg_cost)
+
+with open('main_program.pb', 'wb') as f:
+    f.write(prog.desc.serialize_to_string())
+
+#outs = exe.run(program=prog, feed={'a': data_1}, fetch_list=[cost])
+
+fluid.io.save_inference_model("./model2", [a.name], [a1], exe)
+
+#print(np.array(outs))
diff --git a/lite/core/op_lite.cc b/lite/core/op_lite.cc
new file mode 100644
index 00000000000..412b299339a
--- /dev/null
+++ b/lite/core/op_lite.cc
@@ -0,0 +1,105 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
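+
+// Note on CreateKernels below: each requested Place is expanded with kAny
+// fallbacks, so a single valid place such as {kARM, kFloat, kNCHW} also tries
+// {kARM, kAny, kNCHW} and {kARM, kAny, kAny} when looking up kernels.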
+ +#include "lite/core/op_lite.h" +#include +#include +#include +#include +#include "lite/core/op_registry.h" + +namespace paddle { +namespace lite { + +std::vector> OpLite::CreateKernels( + const std::vector &places, const std::string &kernel_type) { + std::vector> kernels; + CHECK(!op_type_.empty()) << "op_type_ should be set first"; + + auto pick_kernel = [&](const Place &place) { + auto ks = KernelRegistry::Global().Create( + op_type_, place.target, place.precision, place.layout); + VLOG(5) << "pick kernel for " << op_info()->Type() << " " + << place.DebugString() << " get " << ks.size() << " kernels"; + for (auto &&it : ks) { + AttachKernel(it.get()); + kernels.emplace_back(std::move(it)); + } + }; + + if (!kernel_type.empty()) { + Place place; + std::string op_type, alias; + KernelBase::ParseKernelType(kernel_type, &op_type, &alias, &place); + pick_kernel(place); + CHECK(!kernels.empty()) << "no kernel for kernel type " << kernel_type; + return kernels; + } + + std::set place_set; + for (auto place : places) { + place_set.insert(place); + // Pick kernels those support any Precision and any DataLayout + place.precision = PRECISION(kAny); + place_set.insert(place); + place.layout = DATALAYOUT(kAny); + place_set.insert(place); + } + + std::set targets; + for (auto place : place_set) { + pick_kernel(place); + targets.insert(place.target); + } + + VLOG(4) << "op " << op_type_ << " get " << kernels.size() << " kernels"; + return kernels; +} + +bool OpLite::Run() { + CHECK(kernel_); + SyncInputEvents(); + + kernel_->Launch(); + + RecordOutputEvents(); + return true; +} + +bool OpLite::Attach(const cpp::OpDesc &opdesc, lite::Scope *scope) { + // valid_places_.clear(); + CHECK(scope != nullptr); + // CHECK(!op_info_.get()); + scope_ = scope; + op_info_.reset( + new OpInfo(opdesc)); // Force clean the out-of-date infomation. + return AttachImpl(*op_info(), scope); +} + +const Tensor *OpLite::GetTensor(lite::Scope *scope, + const std::string &name) const { + auto *var = scope->FindVar(name); + CHECK(var) << "no variable called " << name << " found"; + return &var->Get(); +} + +Tensor *OpLite::GetMutableTensor(lite::Scope *scope, + const std::string &name) const { + auto *var = scope->FindVar(name); + CHECK(var) << "no variable called " << name << " found"; + return var->GetMutable(); +} + +} // namespace lite +} // namespace paddle diff --git a/lite/core/op_lite.h b/lite/core/op_lite.h new file mode 100644 index 00000000000..f843ef6f2b3 --- /dev/null +++ b/lite/core/op_lite.h @@ -0,0 +1,231 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include +#include +#include +#include +#include "lite/core/context.h" +#include "lite/core/kernel.h" +#include "lite/core/scope.h" +#include "lite/model_parser/cpp/op_desc.h" + +namespace paddle { +namespace lite { + +// For registry factory. 
+struct Registry {
+  void Touch() {}
+};
+
+namespace mir {
+class Node;
+class SSAGraph;
+}
+
+class OpInfo;
+
+/**
+ * The base class of light-weight operators. It is currently used only in
+ * inference, to eliminate the overhead some operations have in the full
+ * framework.
+ *
+ * Operators are designed as follows:
+ * - an operator can hold members for its arguments and other computation
+ *   resources,
+ * - it should act like a function call, with no extra logic included.
+ */
+class OpLite : public Registry {
+ public:
+  OpLite() = default;
+  explicit OpLite(const std::string &type) : op_type_(type) {}
+  explicit OpLite(const std::vector<Place> &valid_places)
+      : valid_places_(valid_places) {}
+
+  void SetValidPlaces(const std::vector<Place> &places) {
+    VLOG(3) << "valid places " << valid_places_.size();
+    valid_places_ = places;
+  }
+  const std::vector<Place> &valid_places() const { return valid_places_; }
+  // Check the shape.
+  virtual bool CheckShape() const { return true; }
+  // Inference the outputs' shape.
+  virtual bool InferShape() const { return true; }
+  // Run this operator.
+  virtual bool Run();
+  // Indicate whether the Op runs only once or not
+  virtual bool run_once() const { return false; }
+  std::string Type() { return op_type_; }
+
+  // Link the external execution environ to internal context.
+  bool Attach(const cpp::OpDesc &opdesc, lite::Scope *scope);
+
+  const OpInfo *op_info() const { return op_info_.get(); }
+  OpInfo *mutable_op_info() { return op_info_.get(); }
+
+  // Human-readable information.
+  virtual std::string DebugString() const = 0;
+
+  const Place &kernel_place() const { return kernel_place_; }
+
+  // Create all the kernels for the valid targets.
+  std::vector<std::unique_ptr<KernelBase>> CreateKernels(
+      const std::vector<Place> &places, const std::string &kernel_type = "");
+
+  lite::Scope *scope() { return scope_; }
+
+  // Assign op param to kernel.
+  virtual void AttachKernel(KernelBase *kernel) = 0;
+  void SetKernel(std::vector<std::unique_ptr<KernelBase>> &kernels) {  // NOLINT
+    kernel_ = std::move(kernels.front());
+    kernel_->SetContext(
+        ContextScheduler::Global().NewContext(kernel_->target()));
+  }
+
+  KernelBase *GetKernel() {  // NOLINT
+    return kernel_.get();
+  }
+
+  virtual ~OpLite() = default;
+
+ protected:
+  // Attach it with the runtime environment.
+  virtual bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) = 0;
+
+  // Specify the kernel to run by default. This will specify the value of
+  // `kernel_place_`.
+  virtual void StaticPickKernel(const std::vector<Place> &valid_targets) {
+    auto kernels = CreateKernels(valid_targets);
+    kernel_ = std::move(kernels.front());
+  }
+
+  // Wait until all the inputs' events are ready.
+  void SyncInputEvents() {}
+
+  // Record the output events, and that will tell all the dependent operators
+  // some inputs are ready.
+  void RecordOutputEvents() {}
+
+  const Tensor *GetTensor(lite::Scope *scope, const std::string &name) const;
+  Tensor *GetMutableTensor(lite::Scope *scope, const std::string &name) const;
+
+  friend class mir::Node;
+  friend class mir::SSAGraph;
+
+ protected:
+  // some helper functions.
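+  // Usage sketch (hypothetical variable names):
+  //   const auto* w = GetVar<lite::Tensor>(scope, "fc_0_W");
+  //   auto* out = GetMutableVar<lite::Tensor>(scope, "fc_0_Out");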
+  template <typename T>
+  const T *GetVar(Scope *scope, const std::string &name) {
+    auto *var = scope->FindVar(name);
+    CHECK(var) << "No var found for " << name;
+    return &var->Get<T>();
+  }
+  template <typename T>
+  T *GetMutableVar(Scope *scope, const std::string &name) {
+    auto *var = scope->FindVar(name);
+    CHECK(var) << "No var found for " << name;
+    return var->GetMutable<T>();
+  }
+
+ protected:
+  lite::Scope *scope_{nullptr};
+  std::unique_ptr<KernelBase> kernel_;
+  std::string op_type_;
+  std::vector<Place> valid_places_;
+  Place kernel_place_{TARGET(kHost), PRECISION(kFloat)};
+  std::unique_ptr<OpInfo> op_info_;
+};
+
+/*
+ * Operator Information, such as some description. It will be shared by all the
+ * kernels of the same operator.
+ */
+class OpInfo : public cpp::OpDesc {
+ public:
+  OpInfo(const OpInfo &) = default;
+  explicit OpInfo(const cpp::OpDesc &other) : cpp::OpDesc(other) {}
+
+  // Collect all the input variable's name.
+  std::vector<std::string> input_names() const {
+    std::vector<std::string> res;
+    for (auto &param : InputArgumentNames()) {
+      for (auto &x : Input(param)) {
+        res.push_back(x);
+      }
+    }
+    return res;
+  }
+
+  // Collect all the output variable's name.
+  std::vector<std::string> output_names() const {
+    std::vector<std::string> res;
+    for (auto &param : OutputArgumentNames()) {
+      for (auto &x : Output(param)) {
+        res.push_back(x);
+      }
+    }
+    return res;
+  }
+
+  std::vector<std::string> input_argnames() const {
+    return InputArgumentNames();
+  }
+
+  std::vector<std::string> output_argnames() const {
+    return OutputArgumentNames();
+  }
+
+  bool GetInputArgname(const std::string &value_name, std::string *out) const {
+    for (auto &item : inputs_) {
+      auto it = std::find(item.second.begin(), item.second.end(), value_name);
+      if (it != item.second.end()) {
+        *out = item.first;
+        return true;
+      }
+    }
+    return false;
+  }
+  bool GetOutputArgname(const std::string &value_name, std::string *out) const {
+    for (auto &item : outputs_) {
+      auto it = std::find(item.second.begin(), item.second.end(), value_name);
+      if (it != item.second.end()) {
+        *out = item.first;
+        return true;
+      }
+    }
+    return false;
+  }
+
+  void UpdateAllInputs(const std::string &from, const std::string &to) {
+    for (auto &item : inputs_) {
+      for (auto &var : item.second) {
+        if (var == from) var = to;
+      }
+    }
+  }
+
+  void UpdateAllOutputs(const std::string &from, const std::string &to) {
+    for (auto &item : outputs_) {
+      for (auto &var : item.second) {
+        if (var == from) var = to;
+      }
+    }
+  }
+};
+
+}  // namespace lite
+}  // namespace paddle
diff --git a/lite/core/op_lite_test.cc b/lite/core/op_lite_test.cc
new file mode 100644
index 00000000000..a18607834a6
--- /dev/null
+++ b/lite/core/op_lite_test.cc
@@ -0,0 +1,24 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
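+
+// OpInfo (see op_lite.h) resolves argument slots from variable names. For
+// example, with inputs {"X": {"a"}, "Y": {"b"}}, GetInputArgname("b", &s)
+// sets s to "Y" and returns true; a name absent from every slot returns
+// false.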
+ +#include "lite/core/op_lite.h" +#include + +namespace paddle { +namespace lite { + +TEST(OpLite, test) {} + +} // namespace lite +} // namespace paddle diff --git a/lite/core/op_registry.cc b/lite/core/op_registry.cc new file mode 100644 index 00000000000..816837effc2 --- /dev/null +++ b/lite/core/op_registry.cc @@ -0,0 +1,152 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/core/op_registry.h" +#include +#include + +namespace paddle { +namespace lite { + +std::list> KernelRegistry::Create( + const std::string &op_type, + TargetType target, + PrecisionType precision, + DataLayoutType layout) { + Place place{target, precision, layout}; + VLOG(5) << "creating " << op_type << " kernel for " << place.DebugString(); +#define CREATE_KERNEL1(target__, precision__) \ + switch (layout) { \ + case DATALAYOUT(kNCHW): \ + return Create(op_type); \ + case DATALAYOUT(kAny): \ + return Create(op_type); \ + case DATALAYOUT(kNHWC): \ + return Create(op_type); \ + default: \ + LOG(FATAL) << "unsupported kernel layout " << DataLayoutToStr(layout); \ + } + +#define CREATE_KERNEL(target__) \ + switch (precision) { \ + case PRECISION(kFloat): \ + CREATE_KERNEL1(target__, kFloat); \ + case PRECISION(kInt8): \ + CREATE_KERNEL1(target__, kInt8); \ + case PRECISION(kFP16): \ + CREATE_KERNEL1(target__, kFP16); \ + case PRECISION(kAny): \ + CREATE_KERNEL1(target__, kAny); \ + default: \ + CHECK(false) << "not supported kernel precision " \ + << PrecisionToStr(precision); \ + } + + switch (target) { + case TARGET(kHost): { + CREATE_KERNEL(kHost); + } break; + case TARGET(kX86): { + CREATE_KERNEL(kX86); + } break; + case TARGET(kCUDA): { + CREATE_KERNEL(kCUDA); + } break; + case TARGET(kARM): { + CREATE_KERNEL(kARM); + } break; + case TARGET(kOpenCL): { + CREATE_KERNEL(kOpenCL); + } break; + case TARGET(kNPU): { + CREATE_KERNEL(kNPU); + } break; + case TARGET(kFPGA): { + CREATE_KERNEL(kFPGA); + } break; + default: + CHECK(false) << "not supported kernel target " << TargetToStr(target); + } + +#undef CREATE_KERNEL + return std::list>(); +} + +KernelRegistry::KernelRegistry() + : registries_(static_cast(TARGET(NUM)) * + static_cast(PRECISION(NUM)) * + static_cast(DATALAYOUT(NUM))) { +#define INIT_FOR(target__, precision__, layout__) \ + registries_[KernelRegistry::GetKernelOffset()] \ + .set *>( \ + &KernelRegistryForTarget::Global()); + // Currently, just register 2 kernel targets. 
+  INIT_FOR(kCUDA, kFloat, kNCHW);
+  INIT_FOR(kCUDA, kAny, kNCHW);
+  INIT_FOR(kCUDA, kAny, kAny);
+
+  INIT_FOR(kHost, kFloat, kNCHW);
+  INIT_FOR(kHost, kAny, kNCHW);
+  INIT_FOR(kHost, kFloat, kNHWC);
+  INIT_FOR(kHost, kFloat, kAny);
+  INIT_FOR(kHost, kAny, kNHWC);
+  INIT_FOR(kHost, kAny, kAny);
+
+  INIT_FOR(kX86, kFloat, kNCHW);
+  INIT_FOR(kX86, kAny, kNCHW);
+  INIT_FOR(kX86, kAny, kAny);
+
+  INIT_FOR(kARM, kFloat, kNCHW);
+  INIT_FOR(kARM, kInt8, kNCHW);
+  INIT_FOR(kARM, kAny, kNCHW);
+  INIT_FOR(kARM, kAny, kAny);
+
+  INIT_FOR(kOpenCL, kFloat, kNCHW);
+  INIT_FOR(kOpenCL, kAny, kNCHW);
+  INIT_FOR(kOpenCL, kAny, kAny);
+
+  INIT_FOR(kNPU, kFloat, kNCHW);
+  INIT_FOR(kNPU, kInt8, kNCHW);
+  INIT_FOR(kNPU, kAny, kNCHW);
+  INIT_FOR(kNPU, kAny, kAny);
+
+  INIT_FOR(kFPGA, kFP16, kNHWC);
+  INIT_FOR(kFPGA, kFP16, kAny);
+  INIT_FOR(kFPGA, kFloat, kNHWC);
+  INIT_FOR(kFPGA, kAny, kNHWC);
+  INIT_FOR(kFPGA, kAny, kAny);
+#undef INIT_FOR
+}
+
+KernelRegistry &KernelRegistry::Global() {
+  static auto *x = new KernelRegistry;
+  return *x;
+}
+
+}  // namespace lite
+}  // namespace paddle
diff --git a/lite/core/op_registry.h b/lite/core/op_registry.h
new file mode 100644
index 00000000000..3eaa0e033d4
--- /dev/null
+++ b/lite/core/op_registry.h
@@ -0,0 +1,282 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include "lite/api/paddle_lite_factory_helper.h" +#include "lite/core/kernel.h" +#include "lite/core/op_lite.h" +#include "lite/core/target_wrapper.h" +#include "lite/utils/all.h" + +using LiteType = paddle::lite::Type; + +namespace paddle { +namespace lite { + +using KernelFunc = std::function; +using KernelFuncCreator = std::function()>; +class LiteOpRegistry final : public Factory> { + public: + static LiteOpRegistry &Global() { + static auto *x = new LiteOpRegistry; + return *x; + } + + private: + LiteOpRegistry() = default; +}; + +template +class OpLiteRegistor : public Registor { + public: + explicit OpLiteRegistor(const std::string &op_type) + : Registor([&] { + LiteOpRegistry::Global().Register( + op_type, [op_type]() -> std::unique_ptr { + return std::unique_ptr(new OpClass(op_type)); + }); + }) {} +}; + +template +using KernelRegistryForTarget = + Factory, std::unique_ptr>; + +class KernelRegistry final { + public: + using any_kernel_registor_t = + variant *, // + KernelRegistryForTarget *, // + KernelRegistryForTarget *, // + KernelRegistryForTarget *, // + KernelRegistryForTarget *, // + KernelRegistryForTarget *, // + KernelRegistryForTarget *, // + KernelRegistryForTarget *, // + KernelRegistryForTarget *, // + KernelRegistryForTarget *, // + KernelRegistryForTarget *, // + KernelRegistryForTarget *, // + KernelRegistryForTarget *, // + KernelRegistryForTarget *, // + KernelRegistryForTarget *, // + KernelRegistryForTarget *, // + KernelRegistryForTarget *, // + KernelRegistryForTarget *, // + KernelRegistryForTarget *, // + KernelRegistryForTarget *, // + KernelRegistryForTarget *, // + KernelRegistryForTarget *, // + KernelRegistryForTarget *, // + KernelRegistryForTarget *, // + KernelRegistryForTarget * // + >; + + KernelRegistry(); + + static KernelRegistry &Global(); + + template + void Register( + const std::string &name, + typename KernelRegistryForTarget::creator_t + &&creator) { + VLOG(3) << "register for " << TargetToStr(Target) << ":" + << PrecisionToStr(Precision) << "//" + << GetKernelOffset(); + using kernel_registor_t = + KernelRegistryForTarget; + auto &varient = registries_[GetKernelOffset()]; + auto *reg = varient.template get(); + CHECK(reg) << "Can not be empty of " << name; + reg->Register(name, std::move(creator)); + } + + template + std::list> Create(const std::string &op_type) { + using kernel_registor_t = + KernelRegistryForTarget; + return registries_[GetKernelOffset()] + .template get() + ->Creates(op_type); + } + + std::list> Create(const std::string &op_type, + TargetType target, + PrecisionType precision, + DataLayoutType layout); + + // Get a kernel registry offset in all the registries. 
+  template <TargetType Target, PrecisionType Precision, DataLayoutType Layout>
+  static int GetKernelOffset() {
+    CHECK_LT(static_cast<int>(Target), static_cast<int>(TARGET(NUM)));
+    CHECK_LT(static_cast<int>(Precision), static_cast<int>(PRECISION(NUM)));
+    CHECK_LT(static_cast<int>(Layout), static_cast<int>(DATALAYOUT(NUM)));
+    return static_cast<int>(Target) * static_cast<int>(PRECISION(NUM)) *
+               static_cast<int>(DATALAYOUT(NUM)) +                            //
+           static_cast<int>(Precision) * static_cast<int>(DATALAYOUT(NUM)) +  //
+           static_cast<int>(Layout);
+  }
+
+  std::string DebugString() const {
+    STL::stringstream ss;
+    ss << "KernelCreator:\n";
+    constexpr TargetType tgt = TARGET(kHost);
+    constexpr PrecisionType dt = PRECISION(kFloat);
+    constexpr DataLayoutType lt = DATALAYOUT(kNCHW);
+    using kernel_registor_t = KernelRegistryForTarget<tgt, dt, lt>;
+    auto *reg = registries_[GetKernelOffset<tgt, dt, lt>()]
+                    .template get<kernel_registor_t *>();
+    ss << reg->DebugString() << "\n";
+    return ss.str();
+  }
+
+ private:
+  mutable std::vector<any_kernel_registor_t> registries_;
+};
+
+template <TargetType target,
+          PrecisionType precision,
+          DataLayoutType layout,
+          typename KernelType>
+class KernelRegistor : public lite::Registor<KernelType> {
+ public:
+  KernelRegistor(const std::string &op_type, const std::string &alias)
+      : lite::Registor<KernelType>([=] {
+          VLOG(3) << "Register kernel " << op_type << " for "
+                  << TargetToStr(target) << " " << PrecisionToStr(precision)
+                  << " " << DataLayoutToStr(layout) << " alias " << alias;
+          KernelRegistry::Global().Register<target, precision, layout>(
+              op_type, [=]() -> std::unique_ptr<KernelBase> {
+                std::unique_ptr<KernelType> x(new KernelType);
+                x->set_op_type(op_type);
+                x->set_alias(alias);
+                return x;
+              });
+        }) {}
+};
+
+}  // namespace lite
+}  // namespace paddle
+
+// Operator registry
+#define LITE_OP_REGISTER_INSTANCE(op_type__) op_type__##__registry__instance__
+#define REGISTER_LITE_OP(op_type__, OpClass)                              \
+  static paddle::lite::OpLiteRegistor<OpClass> LITE_OP_REGISTER_INSTANCE( \
+      op_type__)(#op_type__);                                             \
+  int touch_op_##op_type__() {                                            \
+    return LITE_OP_REGISTER_INSTANCE(op_type__).Touch();                  \
+  }
+
+// Kernel registry
+#define LITE_KERNEL_REGISTER(op_type__, target__, precision__) \
+  op_type__##__##target__##__##precision__##__registor__
+#define LITE_KERNEL_REGISTER_INSTANCE(                     \
+    op_type__, target__, precision__, layout__, alias__)   \
+  op_type__##__##target__##__##precision__##__registor__instance__##alias__
+#define LITE_KERNEL_REGISTER_FAKE(op_type__, target__, precision__, alias__) \
+  LITE_KERNEL_REGISTER_INSTANCE(op_type__, target__, precision__, alias__)
+
+#define REGISTER_LITE_KERNEL(                                               \
+    op_type__, target__, precision__, layout__, KernelClass, alias__)       \
+  static paddle::lite::KernelRegistor<TARGET(target__),                     \
+                                      PRECISION(precision__),               \
+                                      DATALAYOUT(layout__),                 \
+                                      KernelClass>                          \
+      LITE_KERNEL_REGISTER_INSTANCE(                                        \
+          op_type__, target__, precision__, layout__, alias__)(#op_type__,  \
+                                                               #alias__);   \
+  static KernelClass LITE_KERNEL_INSTANCE(                                  \
+      op_type__, target__, precision__, layout__, alias__);                 \
+  int touch_##op_type__##target__##precision__##layout__##alias__() {       \
+    LITE_KERNEL_INSTANCE(op_type__, target__, precision__, layout__, alias__) \
+        .Touch();                                                           \
+    return 0;                                                               \
+  }                                                                         \
+  static bool LITE_KERNEL_PARAM_INSTANCE(                                   \
+      op_type__, target__, precision__, layout__, alias__)                  \
+      __attribute__((unused)) =                                             \
+          paddle::lite::ParamTypeRegistry::NewInstance<                     \
+              TARGET(target__),                                             \
+              PRECISION(precision__),                                       \
+              DATALAYOUT(layout__)>(#op_type__ "/" #alias__)
+
+#define LITE_KERNEL_INSTANCE(                            \
+    op_type__, target__, precision__, layout__, alias__) \
+  op_type__##target__##precision__##layout__##alias__
+#define LITE_KERNEL_PARAM_INSTANCE(                      \
+    op_type__, target__, precision__, layout__, alias__) \
+  op_type__##target__##precision__##layout__##alias__##param_register
diff --git a/lite/core/optimizer.cc b/lite/core/optimizer.cc
new file mode 100644
index 00000000000..38a64a589f3
--- /dev/null
+++ b/lite/core/optimizer.cc
@@ -0,0 +1,34 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lite/core/optimizer.h"
+#include <fstream>
+#include "lite/core/mir/static_kernel_pick_pass.h"
+#include "lite/core/mir/type_target_cast_pass.h"
+#include "lite/model_parser/model_parser.h"
+#include "lite/utils/all.h"
+
+namespace paddle {
+namespace lite {
+
+void Optimizer::SpecifyKernelPickTactic(core::KernelPickFactor factor) {
+  auto* pass = mir::PassManager::Global().LookUp<mir::StaticKernelPickPass>(
+      "static_kernel_pick_pass");
+  CHECK(pass);
+
+  *pass->mutable_kernel_pick_factors() = factor;
+}
+
+}  // namespace lite
+}  // namespace paddle
diff --git a/lite/core/optimizer.h b/lite/core/optimizer.h
new file mode 100644
index 00000000000..7862c406c1c
--- /dev/null
+++ b/lite/core/optimizer.h
@@ -0,0 +1,196 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+#include <memory>
+#include <string>
+#include <vector>
+#include "lite/core/mir/generate_program_pass.h"
+#include "lite/core/mir/pass_manager.h"
+#include "lite/core/mir/ssa_graph.h"
+#include "lite/core/mir/static_kernel_pick_pass.h"
+#include "lite/core/mir/type_target_cast_pass.h"
+#include "lite/core/program.h"
+#include "lite/core/types.h"
+#include "lite/model_parser/model_parser.h"
+#ifdef LITE_WITH_NPU
+#include "lite/core/mir/subgraph/generate_npu_program_pass.h"
+#endif
+
+namespace paddle {
+namespace lite {
+
+/*
+ * lite::Optimizer optimizes a program. It uses the MIR passes to analyze the
+ * program and exports an optimized program.
+ */
+class Optimizer {
+ public:
+  void Run(Program&& program,
+           const std::vector<Place>& valid_places,
+           core::KernelPickFactor kernel_pick_factor,
+           const std::vector<std::string>& passes = {}) {
+    program_ = &program;
+    valid_places_ = valid_places;
+    CHECK(!valid_places.empty()) << "At least one valid_place should be set";
+    CHECK(!graph_) << "duplicate optimize found";
+    graph_.reset(new mir::SSAGraph);
+    graph_->Build(program, valid_places);
+    graph_->SetValidPlaces(valid_places);
+
+    SpecifyKernelPickTactic(kernel_pick_factor);
+    InitTargetTypeTransformPass();
+
+    if (passes.empty()) {
+      RunPasses(std::vector<std::string>{
+          {"lite_quant_dequant_fuse_pass",  //
+           "lite_conv_bn_fuse_pass",        //
+           // This pass is disabled to force some opencl kernels selected for
+           // final running, otherwise, they will be fused to ARM fusion
+           // kernels, and the OpenCL devices will be discarded.
+           // TODO(Superjomn) Refine the fusion related design to select fusion
+           // kernels for devices automatically.
+           "lite_conv_elementwise_fuse_pass",  //
+           "lite_conv_activation_fuse_pass",   //
+           "lite_fc_fuse_pass",                //
+           "identity_scale_eliminate_pass",    //
+#ifdef LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
+           "lite_elementwise_add_activation_fuse_pass",  //
+#endif
+           "static_kernel_pick_pass",        //
+           "variable_place_inference_pass",  //
+           "argument_type_display_pass",     //
+
+           "type_target_cast_pass",          //
+           "variable_place_inference_pass",  //
+           "argument_type_display_pass",     //
+
+           "io_copy_kernel_pick_pass",       //
+           "variable_place_inference_pass",  //
+           "argument_type_display_pass",     //
+
+           "type_precision_cast_pass",       //
+           "variable_place_inference_pass",  //
+           "argument_type_display_pass",     //
+
+           "type_layout_cast_pass",          //
+           "variable_place_inference_pass",  //
+           "argument_type_display_pass",     //
+
+           "runtime_context_assign_pass",
+           "graph_visualze"}});
+    } else {
+      RunPasses(passes);
+    }
+    exec_scope_ = program.exec_scope();
+  }
+
+  void KernelPickPreferPlace(const Place& place) {
+    auto* pass = mir::PassManager::Global().LookUp<mir::StaticKernelPickPass>(
+        "static_kernel_pick_pass");
+    CHECK(pass);
+    pass->SetPreferPlace(place);
+  }
+
+  const lite::Scope* exec_scope() const { return exec_scope_; }
+
+  // Generate a new program based on the mir graph.
+  std::unique_ptr<RuntimeProgram> GenRuntimeProgram() {
+    auto pass = mir::PassManager::Global().LookUp<mir::GenerateProgramPass>(
+        "generate_program_pass");
+    pass->Apply(graph_);
+    auto program = pass->GenProgram();
+    CHECK(exec_scope_);
+    program->set_exec_scope(exec_scope_);
+    return program;
+  }
+
+  // Check that the input dims recorded in the scope are not empty.
+  void CheckInputDimsNotEmpty(const lite::Scope* scope) {
+    CHECK(scope);
+    auto* feed_var = scope->FindVar("feed");
+    CHECK(feed_var) << "no feed variable in exec_scope: " << scope;
+    auto* feed_tensor_list = feed_var->GetMutable<std::vector<lite::Tensor>>();
+    CHECK_GE(feed_tensor_list->size(), 1);
+    for (size_t i = 0; i < feed_tensor_list->size(); ++i) {
+      CHECK(!feed_tensor_list->at(i).dims().empty())
+          << "Input " << i << " dims can not be empty.";
+    }
+  }
+
+  std::unique_ptr<RuntimeProgram> GenNPURuntimeProgram() {
+#ifdef LITE_WITH_NPU
+    CheckInputDimsNotEmpty(exec_scope_);
+    auto pass = mir::PassManager::Global()
+                    .LookUp<mir::subgraph::GenerateNPUProgramPass>(
+                        "generate_npu_program_pass");
+    pass->Apply(graph_);
+
+    auto program = pass->GenProgram();
+    CHECK(exec_scope_);
+    program->set_exec_scope(exec_scope_);
+    return program;
+#else
+    LOG(WARNING) << "Not compiled with NPU but an NPU program is requested; "
+                    "falling back to the default runtime program!";
+    return GenRuntimeProgram();
+#endif
+  }
+
+  void InitTargetTypeTransformPass() {
+    auto* pass =
+        mir::PassManager::Global().LookUp<mir::TypeTargetTransformPass>(
+            "type_target_cast_pass");
+    CHECK(pass);
+    CHECK(!valid_places_.empty());
+    pass->SetValidPlaces(valid_places_);
+  }
+
+  // Generate C++ code which combines the inference program, model and weights.
+  void GenCode(const std::string& code_dir);
+
+  const mir::SSAGraph& ssa_graph() const {
+    CHECK(graph_);
+    return *graph_;
+  }
+
+  mir::SSAGraph* mutable_ssa_graph() {
+    CHECK(graph_);
+    return graph_.get();
+  }
+
+  lite::Scope* exec_scope() { return exec_scope_; }
+
+ protected:
+  void SpecifyKernelPickTactic(core::KernelPickFactor factor);
+
+  // Specify the passes and run them.
+  void RunPasses(const std::vector<std::string>& passes) {
+    for (auto& x : passes) {
+      LOG(INFO) << "== Running pass " << x;
+      auto* pass = mir::PassManager::Global().LookUp(x);
+      CHECK(pass) << "Cannot find pass: " << x;
+      pass->Apply(graph_);
+      LOG(INFO) << "== Finished running pass " << x;
+    }
+  }
+
+ private:
+  std::unique_ptr<mir::SSAGraph> graph_;
+  std::vector<Place> valid_places_;
+  lite::Scope* exec_scope_{};
+  Program* program_{};
+};
+
+}  // namespace lite
+}  // namespace paddle
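The default pass list above can be overridden through the `passes` argument of `Run`. A minimal editorial sketch of driving the optimizer with an explicit, illustrative subset of the registered pass names (`program` and `valid_places` are assumed to be set up as in the test that follows):

core::KernelPickFactor factor;
factor.ConsiderTarget();  // bias kernel picking towards target matches

Optimizer optimizer;
optimizer.Run(std::move(program), valid_places, factor,
              {"static_kernel_pick_pass",        // choose kernels
               "variable_place_inference_pass",  // propagate places
               "runtime_context_assign_pass"});  // bind runtime contexts
auto runtime_program = optimizer.GenRuntimeProgram();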
diff --git a/lite/core/optimizer_test.cc b/lite/core/optimizer_test.cc
new file mode 100644
index 00000000000..ba5bc01b580
--- /dev/null
+++ b/lite/core/optimizer_test.cc
@@ -0,0 +1,51 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lite/core/optimizer.h"
+#include <gtest/gtest.h>
+#include <memory>
+#include <utility>
+#include "lite/api/paddle_use_passes.h"
+#include "lite/core/mir/generate_program_pass.h"
+#include "lite/core/mir/pass_manager.h"
+#include "lite/core/mir/static_kernel_pick_pass.h"
+#include "lite/core/program_fake_utils.h"
+
+namespace paddle {
+namespace lite {
+
+TEST(Optimizer, test) {
+  Optimizer optimizer;
+  auto program_faker = ProgramFaker();
+  program_faker.AddFeed("X", 0);
+  program_faker.AddFetch("X", 0);
+
+  std::vector<Place> places({Place{TARGET(kHost), PRECISION(kFloat)}});
+
+  core::KernelPickFactor factor;
+  factor.ConsiderTarget();
+
+  auto scope = std::make_shared<lite::Scope>();
+  auto program_proto = *program_faker.program()->Proto();
+  Program program(program_proto, scope, places);
+  optimizer.Run(std::move(program), places, factor);
+  auto runtime_program = optimizer.GenRuntimeProgram();
+  LOG(INFO) << "num statements " << runtime_program->num_instructions();
+}
+
+}  // namespace lite
+}  // namespace paddle
+
+USE_LITE_OP(fc);
+USE_LITE_KERNEL(fc, kHost, kFloat, kNCHW, def);
diff --git a/lite/core/profile/CMakeLists.txt b/lite/core/profile/CMakeLists.txt
new file mode 100644
index 00000000000..de8a60bdc27
--- /dev/null
+++ b/lite/core/profile/CMakeLists.txt
@@ -0,0 +1,8 @@
+if (NOT LITE_WITH_PROFILE)
+  return()
+endif()
+
+lite_cc_library(basic_profiler SRCS basic_profiler.cc)
+lite_cc_test(test_basic_profiler SRCS basic_profiler_test.cc DEPS basic_profiler)
+
+
diff --git a/lite/core/profile/basic_profiler.cc b/lite/core/profile/basic_profiler.cc
new file mode 100644
index 00000000000..031b86beb6b
--- /dev/null
+++ b/lite/core/profile/basic_profiler.cc
@@ -0,0 +1,26 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lite/core/profile/basic_profiler.h"
+
+namespace paddle {
+namespace lite {
+namespace profile {
+
+const int BasicTimer::data_w = 10;
+const int BasicTimer::name_w = 15;
+
+}  // namespace profile
+}  // namespace lite
+}  // namespace paddle
diff --git a/lite/core/profile/basic_profiler.h b/lite/core/profile/basic_profiler.h
new file mode 100644
index 00000000000..4756322cb72
--- /dev/null
+++ b/lite/core/profile/basic_profiler.h
@@ -0,0 +1,201 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+/*
+ * This file implements BasicProfiler, a profiler that helps to profile the
+ * basic CPU execution. It can display the min, max and average latency of
+ * the execution of each kernel.
+ */
+#pragma once
+#include <time.h>
+#include <algorithm>
+#include <chrono>  // NOLINT
+#include <limits>
+#include <memory>
+#include <string>
+#include <vector>
+#include "lite/utils/cp_logging.h"
+#include "lite/utils/replace_stl/stream.h"
+#include "lite/utils/string.h"
+
+namespace paddle {
+namespace lite {
+namespace profile {
+
+/* Base class of all the profile records */
+template <typename ChildT>
+class TimerBase {
+ public:
+  void Start() { self()->Start(); }
+  void Stop() { self()->Stop(); }
+  void Log(uint32_t x) { return self()->Log(x); }
+  std::string basic_repr() const { return const_self()->basic_repr(); }
+
+  void SetId(int id) { self()->SetId(id); }
+  void SetKey(const std::string &key) { self()->SetKey(key); }
+
+  int id() const { return const_self()->id(); }
+
+ protected:
+  ChildT *self() { return reinterpret_cast<ChildT *>(this); }
+  const ChildT *const_self() const {
+    return reinterpret_cast<const ChildT *>(this);
+  }
+};
+
+class BasicTimer : TimerBase<BasicTimer> {
+  uint64_t total_{};
+  uint64_t count_{};
+  uint32_t max_{std::numeric_limits<uint32_t>::min()};
+  uint32_t min_{std::numeric_limits<uint32_t>::max()};
+  int id_{-1};
+  std::string key_;
+  std::chrono::time_point<std::chrono::high_resolution_clock> timer_{};
+
+  // TODO(Superjomn) make static
+  static const int name_w;
+  static const int data_w;
+
+ public:
+  BasicTimer() = default;
+  BasicTimer(int id, const std::string &key) : id_(id), key_(key) {}
+
+  void SetId(int id) { id_ = id; }
+  void SetKey(const std::string &key) { key_ = key; }
+  void Start() { timer_ = std::chrono::high_resolution_clock::now(); }
+  void Stop() {
+    auto duration = std::chrono::duration_cast<std::chrono::milliseconds>(
+        std::chrono::high_resolution_clock::now() - timer_);
+    Log(duration.count());
+  }
+
+  int count() const { return count_; }
+
+  void Log(uint32_t timespan) {
+    total_ += timespan;
+    max_ = std::max(max_, timespan);
+    min_ = std::min(min_, timespan);
+    count_++;
+  }
+
+  static std::string basic_repr_header() {
+    STL::stringstream ss;
+    ss << std::setw(name_w) << "kernel"   //
+       << std::setw(data_w) << "average"  //
+       << std::setw(data_w) << "min"      //
+       << std::setw(data_w) << "max"      //
+       << std::setw(data_w) << "count";
+    return ss.str();
+  }
+
+  std::string basic_repr() const {
+    STL::stringstream ss;
+    ss << std::setw(name_w) << key()  //
+       << std::setw(data_w) << ave()  //
+       << std::setw(data_w) << min()  //
+       << std::setw(data_w) << max()  //
+       << std::setw(data_w) << count_;
+    return ss.str();
+  }
+
+  const std::string &key() const { return key_; }
+
+  int id() const {
+    CHECK_GE(id_, 0) << "id is not inited";
+    return id_;
+  }
+
+  double ave() const { return total_ * 1. / count_; }
+  double max() const { return max_; }
+  double min() const { return min_; }
+
+  // BasicTimer(const BasicTimer &) = delete;
+  void operator=(const BasicTimer &) = delete;
+};
+
+/*
+ * A basic profiler, in which each record logs the accumulated latency of one
+ * timed scope.
+ */
+template <typename TimerT>
+class BasicProfiler {
+ public:
+  explicit BasicProfiler(const std::string &name) : name_(name) {}
+  using record_t = TimerT;
+
+  static BasicProfiler &Global() {
+    static std::unique_ptr<BasicProfiler> x(new BasicProfiler("[global]"));
+    return *x;
+  }
+
+  record_t &NewRcd(const std::string &key) {
+    records_.emplace_back();
+    records_.back().SetId(records_.size() - 1);
+    records_.back().SetKey(key);
+    return records_.back();
+  }
+
+  const record_t &record(int id) {
+    CHECK_LT(id, records_.size());
+    CHECK_GE(id, 0);
+    return records_[id];
+  }
+
+  record_t *mutable_record(int id) {
+    CHECK_GE(id, 0);
+    CHECK_LT(static_cast<size_t>(id), records_.size());
+    return &records_[id];
+  }
+
+  std::string basic_repr() const {
+    STL::stringstream ss;
+    for (const auto &rcd : records_) {
+      ss << rcd.basic_repr() << "\n";
+    }
+    return ss.str();
+  }
+
+  ~BasicProfiler() {
+    LOG(INFO) << "Profile dumps:";
+    LOG(INFO) << "\n" + BasicTimer::basic_repr_header() + "\n" + basic_repr();
+  }
+
+ private:
+  std::string name_;
+  std::vector<record_t> records_;
+};
+
+struct ProfileBlock {
+  explicit ProfileBlock(int id) : id_(id) {
+    BasicProfiler<BasicTimer>::Global().mutable_record(id_)->Start();
+  }
+
+  ~ProfileBlock() {
+    BasicProfiler<BasicTimer>::Global().mutable_record(id_)->Stop();
+  }
+
+ private:
+  int id_{};
+};
+
+#define LITE_PROFILE_ONE(key__)                          \
+  static int key__##__profiler_id =                      \
+      ::paddle::lite::profile::BasicProfiler<            \
+          ::paddle::lite::profile::BasicTimer>::Global() \
+          .NewRcd(#key__)                                \
+          .id();                                         \
+  ::paddle::lite::profile::ProfileBlock key__##profiler__(key__##__profiler_id);
+
+}  // namespace profile
+}  // namespace lite
+}  // namespace paddle
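`LITE_PROFILE_ONE` is the intended entry point: it lazily creates one global record per call site and times the enclosing scope through `ProfileBlock`'s constructor/destructor pair. A minimal sketch (the surrounding function is hypothetical, and the macro is only compiled in when `LITE_WITH_PROFILE` is on):

void RunMyKernelOnce() {        // hypothetical caller
  LITE_PROFILE_ONE(my_kernel);  // one static record, keyed "my_kernel"
  // ... work to be timed; Stop() fires when this scope exits ...
}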
diff --git a/lite/core/profile/basic_profiler_test.cc b/lite/core/profile/basic_profiler_test.cc
new file mode 100644
index 00000000000..928fdd61cb9
--- /dev/null
+++ b/lite/core/profile/basic_profiler_test.cc
@@ -0,0 +1,46 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lite/core/profile/basic_profiler.h"
+#include <gtest/gtest.h>
+#include <chrono>  // NOLINT
+#include <thread>  // NOLINT
+#include "lite/utils/cp_logging.h"
+
+namespace paddle {
+namespace lite {
+namespace profile {
+
+TEST(basic_record, init) {
+  BasicTimer timer;
+  timer.SetKey("hello");
+}
+
+TEST(basic_profile, init) {
+  auto& rcd = BasicProfiler<BasicTimer>::Global().NewRcd("fc");
+  for (int i = 11; i < 100; i++) {
+    rcd.Log(i);
+  }
+
+  LOG(INFO) << BasicProfiler<BasicTimer>::Global().basic_repr();
+}
+
+TEST(basic_profile, real_latency) {
+  LITE_PROFILE_ONE(test0);
+  std::this_thread::sleep_for(std::chrono::milliseconds(1200));
+}
+
+}  // namespace profile
+}  // namespace lite
+}  // namespace paddle
diff --git a/lite/core/profile/precision_profiler.h b/lite/core/profile/precision_profiler.h
new file mode 100644
index 00000000000..65cc1600773
--- /dev/null
+++ b/lite/core/profile/precision_profiler.h
@@ -0,0 +1,102 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+/*
+ * This file implements PrecisionProfiler, which logs the mean value of each
+ * kernel's output tensors. It helps to track down where the numerical
+ * results start to diverge.
+ */
+#pragma once
+#include <string>
+#include <vector>
+#include "lite/core/program.h"
+
+namespace paddle {
+namespace lite {
+namespace profile {
+
+class PrecisionProfiler {
+ public:
+  explicit PrecisionProfiler(const Instruction* inst) : inst_(inst) {}
+  ~PrecisionProfiler() {
+    LOG(INFO) << ">> Running kernel: " << inst_->op()->op_info()->Repr()
+              << " on Target " << TargetToStr(inst_->kernel()->target());
+    auto tensor_mean = [](const Tensor* in, PrecisionType ptype) -> double {
+      double sum = 0.;
+      switch (ptype) {
+        case PRECISION(kFloat): {
+          auto ptr = in->data<float>();
+          for (int i = 0; i < in->numel(); ++i) {
+            sum += ptr[i];
+          }
+          return sum / in->numel();
+        }
+        case PRECISION(kInt8): {
+          auto ptr = in->data<int8_t>();
+          for (int i = 0; i < in->numel(); ++i) {
+            sum += ptr[i];
+          }
+          return sum / in->numel();
+        }
+        case PRECISION(kInt32): {
+          auto ptr = in->data<int32_t>();
+          for (int i = 0; i < in->numel(); ++i) {
+            sum += ptr[i];
+          }
+          return sum / in->numel();
+        }
+        default:
+          LOG(INFO) << "unsupported data type: " << PrecisionToStr(ptype);
+          return 0.;
+      }
+    };
+    if (inst_->op()->op_info()->Type() != "fetch") {
+      auto op = const_cast<OpLite*>(inst_->op());
+      auto kernel = inst_->kernel();
+      auto op_scope = op->scope();
+      auto out_names = op->op_info()->output_names();
+      for (auto& out_name : out_names) {
+        std::string out_arg_name;
+        op->op_info()->GetOutputArgname(out_name, &out_arg_name);
+        auto type = kernel->GetOutputDeclType(out_arg_name);
+        if (type->IsTensor()) {
+          auto tout = op_scope->FindVar(out_name)->GetMutable<Tensor>();
+          double mean = tensor_mean(tout, type->precision());
+          LOG(INFO) << "output name: " << out_name << ", dims: " << tout->dims()
+                    << ", precision: " << PrecisionToStr(type->precision())
+                    << ", mean value: " << mean;
+        } else if (type->IsTensorList()) {
+          auto tout =
+              op_scope->FindVar(out_name)->GetMutable<std::vector<Tensor>>();
+          for (auto& t : *tout) {
+            double mean = tensor_mean(&t, type->precision());
+            LOG(INFO) << "output name: " << out_name << ", dims: " << t.dims()
+                      << ", precision: " << PrecisionToStr(type->precision())
+                      << ", mean value: " << mean;
+          }
+        }
+      }
+    }
+  }
+
+ private:
+  const Instruction* inst_{nullptr};
+};
+
+}  // namespace profile
+}  // namespace lite
+}  // namespace paddle
+
+#define LITE_PRECISION_PROFILE(inst) \
+  { auto a = paddle::lite::profile::PrecisionProfiler(&inst); }
diff --git a/lite/core/program.cc b/lite/core/program.cc
new file mode 100644
index 00000000000..1005b4bb1f0
--- /dev/null
+++ b/lite/core/program.cc
@@ -0,0 +1,133 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lite/core/program.h"
+#include "lite/model_parser/cpp/block_desc.h"
+#include "lite/model_parser/cpp/op_desc.h"
+#include "lite/model_parser/cpp/var_desc.h"
+#include "lite/operators/while_op.h"
+#ifdef LITE_WITH_PROFILE
+#include "lite/core/profile/precision_profiler.h"
+#endif
+
+namespace paddle {
+namespace lite {
+
+void RuntimeProgram::SaveOpInfosToProgram(cpp::ProgramDesc* desc) {
+  CHECK(desc);
+  // NOTE: RuntimeProgram does not carry all the meta info, so saving the
+  // model here just updates the original model in place.
+  CHECK(desc->BlocksSize());
+  auto& main_block = *desc->GetBlock<cpp::BlockDesc>(0);
+  main_block.ClearOps();
+  for (auto& node : instructions_) {
+    auto* op = main_block.AddOp<cpp::OpDesc>();
+    *op = *node.op()->op_info();
+    op->SetAttr(kKernelTypeAttr, node.kernel()->SerializedKernelType());
+  }
+}
+
+void RuntimeProgram::Run() {
+  for (auto& inst : instructions_) {
+    VLOG(4) << ">> Running kernel: " << inst.op()->op_info()->Repr()
+            << " on Target " << TargetToStr(inst.kernel()->target());
+
+    inst.Run();
+#ifdef LITE_WITH_PROFILE
+    LITE_PRECISION_PROFILE(inst)
+#endif
+  }
+}
+
+void Program::Build(const cpp::ProgramDesc& prog) {
+  CHECK(ops_.empty()) << "Executor duplicate Build found";
+
+  // Create operators.
+  auto program = prog;
+  CHECK(program.BlocksSize());
+  auto& main_block = *program.GetBlock<cpp::BlockDesc>(0);
+  for (size_t i = 0; i < main_block.OpsSize(); ++i) {
+    auto& op_desc = *main_block.GetOp<cpp::OpDesc>(i);
+    auto op_type = op_desc.Type();
+    // if (op_type == "feed" || op_type == "fetch") continue;
+    VLOG(4) << "create Op [" << op_type << "]";
+    auto op = LiteOpRegistry::Global().Create(op_type);
+    CHECK(op) << "no Op found for " << op_type;
+    if (op_type == "while") {
+      auto sub_block_idx = op_desc.GetAttr<int32_t>("sub_block");
+      auto sub_block =
+          const_cast<cpp::ProgramDesc&>(prog).GetBlock<cpp::BlockDesc>(
+              sub_block_idx);
+      static_cast<operators::WhileOpLite*>(op.get())->SetSubBlock(sub_block);
+    }
+    ops_.emplace_back(std::move(op));
+    ops_.back()->Attach(op_desc, exec_scope_);
+  }
+}
+
+void Program::PrepareWorkspace(const cpp::ProgramDesc& prog) {
+  CHECK(!exec_scope_) << "Duplicate PrepareWorkspace found";
+  exec_scope_ = &scope_->NewScope();
+  // Create Feed and Fetch var.
+  scope_->Var("feed")->GetMutable<std::vector<lite::Tensor>>();
+  scope_->Var("fetch")->GetMutable<std::vector<lite::Tensor>>();
+  tmp_vars_.push_back("feed");
+  tmp_vars_.push_back("fetch");
+
+  auto program = prog;
+  CHECK(program.BlocksSize());
+  for (size_t b = 0; b < program.BlocksSize(); ++b) {
+    auto& main_block = *program.GetBlock<cpp::BlockDesc>(b);
+    for (size_t i = 0; i < main_block.VarsSize(); ++i) {
+      auto& var_desc = *main_block.GetVar<cpp::VarDesc>(i);
+      if (!var_desc.Persistable()) {
+        tmp_vars_.push_back(var_desc.Name());
+        exec_scope_->Var(var_desc.Name());
+        if (b > 0) {
+          VLOG(4) << "var: " << var_desc.Name();
+        }
+      } else {
+        if (var_desc.Name() == "feed" || var_desc.Name() == "fetch") continue;
+        weights_.push_back(var_desc.Name());
+        if (var_desc.Persistable()) scope_->Var(var_desc.Name());
+      }
+    }
+  }
+}
+
+void Instruction::Run() {
+#ifdef LITE_WITH_PROFILE
+  profile::ProfileBlock x(profile_id_);
+#endif  // LITE_WITH_PROFILE
+  CHECK(op_) << "op null";
+  CHECK(kernel_) << "kernel null";
+  if (first_epoch_) {
+    first_epoch_ = false;
+    CHECK(op_->CheckShape());
+  }
+
+  if (op_->run_once() && has_run_) return;
+  VLOG(4) << "kernel launch";
+  op_->InferShape();
+  kernel_->Launch();
+  has_run_ = true;
+}
+
+STL::ostream& operator<<(STL::ostream& os, const Instruction& other) {
+  os << other.kernel_->summary() << "\t(" << other.kernel_->doc() << ")";
+  return os;
+}
+
+}  // namespace lite
+}  // namespace paddle
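`PrepareWorkspace` pins two tensor-list variables, `feed` and `fetch`, into the root scope; callers move data in and out of a program through them. An editorial sketch of that convention (the driver code itself is illustrative; `exec_scope` and `runtime_program` are assumed to come from the optimizer as shown earlier):

// Fill input slot 0 before running.
auto* feed_list =
    exec_scope->Var("feed")->GetMutable<std::vector<lite::Tensor>>();
feed_list->resize(1);
lite::Tensor& input = feed_list->at(0);
input.Resize({1, 3, 224, 224});
float* in_data = input.mutable_data<float>();
// ... populate in_data ...

runtime_program->Run();

// Read output slot 0 afterwards.
auto& fetch_list =
    *exec_scope->Var("fetch")->GetMutable<std::vector<lite::Tensor>>();
const float* out_data = fetch_list.at(0).data<float>();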
diff --git a/lite/core/program.h b/lite/core/program.h
new file mode 100644
index 00000000000..049dbe2aa67
--- /dev/null
+++ b/lite/core/program.h
@@ -0,0 +1,149 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+#include <list>
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+#include "lite/core/kernel.h"
+#include "lite/core/op_lite.h"
+#include "lite/core/op_registry.h"
+#include "lite/model_parser/cpp/program_desc.h"
+#ifdef LITE_WITH_PROFILE
+#include "lite/core/profile/basic_profiler.h"
+#endif  // LITE_WITH_PROFILE
+
+namespace paddle {
+namespace lite {
+
+static const char kKernelTypeAttr[] = "__@kernel_type_attr@__";
+
+// A Program represents a model program. In Paddle, a program contains:
+// - the main block, which is a list of OpLite instances
+// - a scope, which holds all the weights
+struct Program {
+ public:
+  explicit Program(const std::shared_ptr<Scope>& root) { scope_ = root; }
+  Program(const cpp::ProgramDesc& desc,
+          const std::shared_ptr<Scope>& root,
+          const std::vector<Place>& valid_places)
+      : scope_(root), valid_places_(valid_places), desc_(desc) {
+    CHECK(scope_) << "scope should be init first";
+    VLOG(4) << "prepare work";
+    PrepareWorkspace(desc);
+    VLOG(4) << "build desc";
+    Build(desc);
+    VLOG(4) << "build desc finished";
+  }
+
+  std::unique_ptr<Program> Clone() const {
+    std::unique_ptr<Program> res(new Program(desc_, scope_, valid_places_));
+    return res;
+  }
+
+  const std::list<std::string>& weights() const { return weights_; }
+  const std::list<std::string>& tmp_vars() const { return tmp_vars_; }
+  std::list<std::string>* mutable_weights() { return &weights_; }
+  std::list<std::string>* mutable_tmp_vars() { return &tmp_vars_; }
+
+  const std::list<std::unique_ptr<OpLite>>& ops() const { return ops_; }
+  std::list<std::unique_ptr<OpLite>>* mutable_ops() { return &ops_; }
+
+  lite::Scope* exec_scope() { return exec_scope_; }
+  lite::Scope* scope() { return scope_.get(); }
+
+ private:
+  // Build from a program and scope.
+  void Build(const cpp::ProgramDesc& program);
+  // Create temporary variables.
+  void PrepareWorkspace(const cpp::ProgramDesc& program);
+
+ private:
+  std::list<std::string> tmp_vars_;
+  std::list<std::string> weights_;
+  std::list<std::unique_ptr<OpLite>> ops_;
+  // the scope to run the kernels, NOTE this is the execution scope.
+  std::shared_ptr<Scope> scope_;
+  std::vector<Place> valid_places_;
+  // Runtime scope.
+  lite::Scope* exec_scope_{};
+  cpp::ProgramDesc desc_;
+};
+
+struct Instruction {
+  Instruction(const std::shared_ptr<OpLite>& op,
+              std::unique_ptr<KernelBase>&& kernel)
+      : op_(op), kernel_(std::move(kernel)) {
+#ifdef LITE_WITH_PROFILE
+    profile_id_ = profile::BasicProfiler<profile::BasicTimer>::Global()
+                      .NewRcd(kernel_->SerializedKernelType())
+                      .id();
+#endif  // LITE_WITH_PROFILE
+  }
+
+  // Run the instruction.
+  void Run();
+
+  friend STL::ostream& operator<<(STL::ostream& os, const Instruction& other);
+
+  const OpLite* op() const { return op_.get(); }
+  const KernelBase* kernel() const { return kernel_.get(); }
+  KernelBase* mutable_kernel() { return kernel_.get(); }
+
+ private:
+  std::shared_ptr<OpLite> op_;
+  std::unique_ptr<KernelBase> kernel_;
+  bool first_epoch_{true};
+  bool has_run_{false};
+
+#ifdef LITE_WITH_PROFILE
+  // for profiler
+  int profile_id_{-1};
+#endif  // LITE_WITH_PROFILE
+};
+
+/*
+ * A program contains kernels for runtime.
+ */
+class LITE_API RuntimeProgram {
+ public:
+  explicit RuntimeProgram(std::vector<Instruction>&& insts)
+      : instructions_(std::move(insts)) {
+    if (instructions_.empty()) {
+      LOG(FATAL) << "no instructions";
+    }
+  }
+
+  void Run();
+
+  void set_exec_scope(lite::Scope* x) { exec_scope_ = x; }
+  lite::Scope* exec_scope() { return exec_scope_; }
+
+  size_t num_instructions() const { return instructions_.size(); }
+
+  const std::vector<Instruction>& instructions() const { return instructions_; }
+
+  void SaveOpInfosToProgram(cpp::ProgramDesc* desc);
+
+ private:
+  RuntimeProgram(const RuntimeProgram&) = delete;
+  std::vector<Instruction> instructions_;
+  lite::Scope* exec_scope_{};
+};
+
+}  // namespace lite
+}  // namespace paddle
diff --git a/lite/core/program_fake_utils.cc b/lite/core/program_fake_utils.cc
new file mode 100644
index 00000000000..b4d7a00dfa3
--- /dev/null
+++ b/lite/core/program_fake_utils.cc
@@ -0,0 +1,22 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lite/core/program_fake_utils.h"
+#include "lite/core/op_registry.h"
+
+namespace paddle {
+namespace lite {
+namespace mir {}  // namespace mir
+}  // namespace lite
+}  // namespace paddle
diff --git a/lite/core/program_fake_utils.h b/lite/core/program_fake_utils.h
new file mode 100644
index 00000000000..edcbb101aa5
--- /dev/null
+++ b/lite/core/program_fake_utils.h
@@ -0,0 +1,142 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+#include <memory>
+#include <set>
+#include <string>
+#include <utility>
+#include <vector>
+#include "lite/core/mir/ssa_graph.h"
+#include "lite/core/op_registry.h"
+#include "paddle/fluid/framework/program_desc.h"
+
+namespace paddle {
+namespace lite {
+
+Program FakeProgram() {
+  Program program(std::make_shared<lite::Scope>());
+
+  auto add_fc = [&](int id, std::string x) {
+    // create variables
+    std::string w1 = "w" + std::to_string(id);
+    std::string b1 = "b" + std::to_string(id);
+    std::string out1 = "out" + std::to_string(id);
+    auto w1v = program.scope()->Var(w1)->GetMutable<lite::Tensor>();
+    auto b1v = program.scope()->Var(b1)->GetMutable<lite::Tensor>();
+    auto out1v = program.scope()->Var(out1)->GetMutable<lite::Tensor>();
+
+    cpp::OpDesc desc;
+    desc.SetInput("Input", {x});
+    desc.SetInput("W", {w1});
+    desc.SetInput("Bias", {b1});
+    desc.SetOutput("Out", {out1});
+    desc.SetType("fc");
+    desc.SetAttr("in_num_col_dims", 1);
+
+    // add to input
+    program.mutable_tmp_vars()->push_back(w1);
+    program.mutable_tmp_vars()->push_back(b1);
+
+    auto fc_op = LiteOpRegistry::Global().Create("fc");
+    fc_op->Attach(desc, program.scope());
+    program.mutable_ops()->emplace_back(std::move(fc_op));
+
+    w1v->Resize(DDimHvy(std::vector<int64_t>({100, 100})));
+    b1v->Resize(DDimHvy(std::vector<int64_t>({100, 1})));
+    out1v->Resize(DDimHvy(std::vector<int64_t>({100, 100})));
+
+    return out1;
+  };
+
+  // x1, w1, b1 -fc-> out1
+  // out1, w2, b2 -fc-> out2
+
+  std::string x = "x";
+  program.mutable_tmp_vars()->push_back(x);
+  auto* xv = program.scope()->Var(x)->GetMutable<lite::Tensor>();
+  xv->Resize(DDimHvy(std::vector<int64_t>({100, 100})));
+
+  for (int i = 0; i < 3; i++) {
+    x = add_fc(i, x);
+  }
+  return program;
+}
+
+class ProgramFaker {
+ public:
+  ProgramFaker() {}
+
+  framework::ProgramDesc* program() {
+    desc_.Flush();
+    return &desc_;
+  }
+
+  void CreateVars(lite::Scope* scope) {
+    for (auto& var : tmp_vars_) {
+      auto* x = scope->Var(var);
+      x->GetMutable<lite::Tensor>();
+    }
+
+    for (auto& x : tmp_vars_) {
+      desc_.MutableBlock(0)->Var(x);
+    }
+  }
+
+  void AddMul(const std::string& X,
+              const std::string& Y,
+              const std::string& out) {
+    tmp_vars_.insert(X);
+    tmp_vars_.insert(Y);
+    tmp_vars_.insert(out);
+
+    auto* block = desc_.MutableBlock(0);
+    auto* op = block->AppendOp();
+    op->SetType("mul");
+    op->SetInput("X", {X});
+    op->SetInput("Y", {Y});
+    op->SetOutput("Out", {out});
+    op->SetAttr("x_num_col_dims", 1);
+    op->SetAttr("y_num_col_dims", 1);
+  }
+
+  void AddFeed(const std::string& Out, int col) {
+    tmp_vars_.insert(Out);
+
+    auto* block = desc_.MutableBlock(0);
+    auto* op = block->AppendOp();
+    op->SetType("feed");
+    op->SetInput("X", {"feed"});
+    op->SetOutput("Out", {Out});
+    op->SetAttr("col", col);
+  }
+
+  void AddFetch(const std::string& Input, int col) {
+    tmp_vars_.insert(Input);
+    auto* block = desc_.MutableBlock(0);
+    auto* op = block->AppendOp();
+    op->SetType("fetch");
+    op->SetInput("X", {Input});
+    op->SetOutput("Out", {"fetch"});
+    op->SetAttr("col", col);
+  }
+
+ private:
+  std::set<std::string> tmp_vars_;
+  std::vector<std::string> weight_vars_;
+  framework::ProgramDesc desc_;
+};
+
+}  // namespace lite
+}  // namespace paddle
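An editorial sketch of how `ProgramFaker` is meant to be used in tests (the variable names are arbitrary); optimizer_test.cc above does essentially the same with feed/fetch only:

ProgramFaker faker;
faker.AddFeed("x", 0);        // feed -> x
faker.AddMul("x", "w", "y");  // y = mul(x, w)
faker.AddFetch("y", 0);       // y -> fetch

lite::Scope scope;
faker.CreateVars(&scope);  // materialize the vars in the scope and the desc
framework::ProgramDesc* desc = faker.program();  // flushed program desc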
diff --git a/lite/core/scope.cc b/lite/core/scope.cc
new file mode 100644
index 00000000000..775652e2a0d
--- /dev/null
+++ b/lite/core/scope.cc
@@ -0,0 +1,72 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lite/core/scope.h"
+
+namespace paddle {
+namespace lite {
+
+Scope::~Scope() {
+  for (auto *x : kids_) {
+    if (x) {
+      delete x;
+    }
+  }
+}
+
+Scope &Scope::NewScope() const {
+  kids_.push_back(new Scope);
+  kids_.back()->parent_ = this;
+  return *kids_.back();
+}
+
+Variable *Scope::Var(const std::string &name) {
+  auto *var = FindVar(name);
+  if (var) return var;
+
+  // create a new variable.
+  vars_.emplace(name, std::unique_ptr<Variable>(new Variable));
+  return vars_[name].get();
+}
+
+Variable *Scope::FindVar(const std::string &name) const {
+  Variable *var{nullptr};
+  var = FindLocalVar(name);
+  const Scope *cur_scope = this;
+  while (!var && cur_scope->parent()) {
+    cur_scope = cur_scope->parent();
+    var = cur_scope->FindLocalVar(name);
+  }
+
+  return var;
+}
+
+Variable *Scope::FindLocalVar(const std::string &name) const {
+  auto it = vars_.find(name);
+  if (it != vars_.end()) {
+    return it->second.get();
+  }
+  return nullptr;
+}
+
+std::vector<std::string> Scope::LocalVarNames() const {
+  std::vector<std::string> keys;
+  for (const auto &item : vars_) {
+    keys.push_back(item.first);
+  }
+  return keys;
+}
+
+}  // namespace lite
+}  // namespace paddle
diff --git a/lite/core/scope.h b/lite/core/scope.h
new file mode 100644
index 00000000000..2593c365224
--- /dev/null
+++ b/lite/core/scope.h
@@ -0,0 +1,79 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+#include <list>
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <utility>
+#include <vector>
+#include "lite/core/variable.h"
+
+namespace paddle {
+namespace lite {
+
+class Scope final {
+ public:
+  Scope() {}
+  // The two functions below are deleted so that pybind recognizes that it
+  // cannot make a copy.
+  // link:
+  // https://stackoverflow.com/questions/53807248/pybind11-returning-a-pointer-to-a-container-of-unique-ptr
+  Scope(const Scope&) = delete;
+  Scope& operator=(const Scope&) = delete;
+  ~Scope();
+
+  Scope& NewScope() const;
+
+  Variable* Var(const std::string& name);
+
+  Variable* FindVar(const std::string& name) const;
+
+  Variable* FindLocalVar(const std::string& name) const;
+
+  const Scope* parent() const { return parent_; }
+
+  // Following the legacy scope interface.
+  std::vector<std::string> LocalVarNames() const;
+
+  /// ------------------------------------- helper functions for Tensor
+  /// ----------------------------------
+  // Create a Tensor variable. This will create a new Variable called `name`.
+  Tensor* NewTensor(const std::string& name) {
+    auto* var = Var(name);
+    return var->GetMutable<Tensor>();
+  }
+
+  const Tensor* FindTensor(const std::string& name) {
+    auto* var = FindVar(name);
+    if (!var) return nullptr;
+    return &var->Get<Tensor>();
+  }
+
+  Tensor* FindMutableTensor(const std::string& name) {
+    auto* var = FindVar(name);
+    if (!var) return nullptr;
+    return var->GetMutable<Tensor>();
+  }
+
+ private:
+  // Scopes in `kids_` are owned by this class.
+  mutable std::list<Scope*> kids_;
+  const Scope* parent_{nullptr};
+  std::unordered_map<std::string, std::unique_ptr<Variable>> vars_;
+};
+
+}  // namespace lite
+}  // namespace paddle
diff --git a/lite/core/scope_test.cc b/lite/core/scope_test.cc
new file mode 100644
index 00000000000..8806e6b1c06
--- /dev/null
+++ b/lite/core/scope_test.cc
@@ -0,0 +1,37 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lite/core/scope.h"
+#include <gtest/gtest.h>
+
+namespace paddle {
+namespace lite {
+
+TEST(Scope, Var) {
+  Scope scope;
+  auto* x = scope.Var("x");
+  *x->GetMutable<int>() = 100;
+
+  ASSERT_EQ(x->Get<int>(), 100);
+}
+
+TEST(Scope, FindVar) {
+  Scope scope;
+  ASSERT_FALSE(scope.FindVar("x"));
+  scope.Var("x");
+  ASSERT_TRUE(scope.FindVar("x"));
+}
+
+}  // namespace lite
+}  // namespace paddle
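The tests above cover flat lookup; the parent chain set up by `NewScope()` is worth spelling out. A small editorial sketch of the intended hierarchy (illustrative driver code only):

lite::Scope root;
root.Var("w");                        // weights live in the root scope
lite::Scope& exec = root.NewScope();  // child scope, owned by root

// FindVar walks up through parent_; FindLocalVar does not.
CHECK(exec.FindVar("w"));
CHECK(exec.FindLocalVar("w") == nullptr);
exec.Var("tmp");  // temporaries stay in the child scope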
diff --git a/lite/core/target_wrapper.cc b/lite/core/target_wrapper.cc
new file mode 100644
index 00000000000..046336036bb
--- /dev/null
+++ b/lite/core/target_wrapper.cc
@@ -0,0 +1,21 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lite/core/target_wrapper.h"
+#include
+#include "lite/utils/all.h"
+
+namespace paddle {
+namespace lite {}  // namespace lite
+}  // namespace paddle
diff --git a/lite/core/target_wrapper.h b/lite/core/target_wrapper.h
new file mode 100644
index 00000000000..491695ac7b1
--- /dev/null
+++ b/lite/core/target_wrapper.h
@@ -0,0 +1,208 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+#include
+#include
+#include
+#include "lite/api/paddle_place.h"
+#include "lite/utils/cp_logging.h"
+
+#ifdef LITE_WITH_CUDA
+#include <cuda.h>
+#include <cuda_runtime.h>
+#endif  // LITE_WITH_CUDA
+
+namespace paddle {
+namespace lite {
+
+using lite_api::TargetType;
+using lite_api::PrecisionType;
+using lite_api::DataLayoutType;
+using lite_api::PrecisionTypeLength;
+using lite_api::TargetToStr;
+using lite_api::Place;
+using lite_api::PrecisionToStr;
+using lite_api::DataLayoutToStr;
+using lite_api::TargetRepr;
+using lite_api::PrecisionRepr;
+using lite_api::DataLayoutRepr;
+
+// Memory copy directions.
+enum class IoDirection {
+  HtoH = 0,  // Host to host
+  HtoD,      // Host to device
+  DtoH,      // Device to host
+  DtoD,      // Device to device
+};
+
+// This interface should be specified by each kind of target.
+template <TargetType Target, typename StreamTy = int, typename EventTy = int>
+class TargetWrapper {
+ public:
+  using stream_t = StreamTy;
+  using event_t = EventTy;
+
+  static size_t num_devices() { return 0; }
+  static size_t maximum_stream() { return 0; }
+
+  static void CreateStream(stream_t* stream) {}
+  static void DestroyStream(const stream_t& stream) {}
+
+  static void CreateEvent(event_t* event) {}
+  static void DestroyEvent(const event_t& event) {}
+
+  static void RecordEvent(const event_t& event) {}
+  static void SyncEvent(const event_t& event) {}
+
+  static void StreamSync(const stream_t& stream) {}
+
+  static void* Malloc(size_t size) {
+    LOG(FATAL) << "Unimplemented malloc for " << TargetToStr(Target);
+    return nullptr;
+  }
+  static void Free(void* ptr) { LOG(FATAL) << "Unimplemented"; }
+
+  static void MemcpySync(void* dst,
+                         const void* src,
+                         size_t size,
+                         IoDirection dir) {
+    LOG(FATAL) << "Unimplemented";
+  }
+  static void MemcpyAsync(void* dst,
+                          const void* src,
+                          size_t size,
+                          IoDirection dir,
+                          const stream_t& stream) {
+    MemcpySync(dst, src, size, dir);
+  }
+};
+
+// This interface should be specified by each kind of target.
+using TargetWrapperHost = TargetWrapper<TARGET(kHost)>;
+using TargetWrapperX86 = TargetWrapperHost;
+template <>
+class TargetWrapper<TARGET(kHost)> {
+ public:
+  using stream_t = int;
+  using event_t = int;
+
+  static size_t num_devices() { return 0; }
+  static size_t maximum_stream() { return 0; }
+
+  static void CreateStream(stream_t* stream) {}
+  static void DestroyStream(const stream_t& stream) {}
+
+  static void CreateEvent(event_t* event) {}
+  static void DestroyEvent(const event_t& event) {}
+
+  static void RecordEvent(const event_t& event) {}
+  static void SyncEvent(const event_t& event) {}
+
+  static void StreamSync(const stream_t& stream) {}
+
+  static void* Malloc(size_t size);
+  static void Free(void* ptr);
+
+  static void MemcpySync(void* dst,
+                         const void* src,
+                         size_t size,
+                         IoDirection dir);
+  static void MemcpyAsync(void* dst,
+                          const void* src,
+                          size_t size,
+                          IoDirection dir,
+                          const stream_t& stream) {
+    MemcpySync(dst, src, size, dir);
+  }
+};
+
+#ifdef LITE_WITH_FPGA
+template <>
+class TargetWrapper<TARGET(kFPGA)> {
+ public:
+  using stream_t = int;
+  using event_t = int;
+
+  static size_t num_devices() { return 0; }
+  static size_t maximum_stream() { return 0; }
+
+  static void CreateStream(stream_t* stream) {}
+  static void DestroyStream(const stream_t& stream) {}
+
+  static void CreateEvent(event_t* event) {}
+  static void DestroyEvent(const event_t& event) {}
+
+  static void RecordEvent(const event_t& event) {}
+  static void SyncEvent(const event_t& event) {}
+
+  static void StreamSync(const stream_t& stream) {}
+
+  static void* Malloc(size_t size);
+  static void Free(void* ptr);
+
+  static void MemcpySync(void* dst,
+                         const void* src,
+                         size_t size,
+                         IoDirection dir);
+  static void MemcpyAsync(void* dst,
+                          const void* src,
+                          size_t size,
+                          IoDirection dir,
+                          const stream_t& stream) {
+    MemcpySync(dst, src, size, dir);
+  }
+};
+#endif
+#ifdef LITE_WITH_CUDA
+using TargetWrapperCuda =
+    TargetWrapper<TARGET(kCUDA), cudaStream_t, cudaEvent_t>;
+// This interface should be specified by each kind of target.
+template <>
+class TargetWrapper<TARGET(kCUDA), cudaStream_t, cudaEvent_t> {
+ public:
+  using stream_t = cudaStream_t;
+  using event_t = cudaEvent_t;
+
+  static size_t num_devices() { return 0; }
+  static size_t maximum_stream() { return 0; }
+
+  static void CreateStream(stream_t* stream) {}
+  static void DestroyStream(const stream_t& stream) {}
+
+  static void CreateEvent(event_t* event) {}
+  static void DestroyEvent(const event_t& event) {}
+
+  static void RecordEvent(const event_t& event) {}
+  static void SyncEvent(const event_t& event) {}
+
+  static void StreamSync(const stream_t& stream) {}
+
+  static void* Malloc(size_t size);
+  static void Free(void* ptr);
+
+  static void MemcpySync(void* dst,
+                         const void* src,
+                         size_t size,
+                         IoDirection dir);
+  static void MemcpyAsync(void* dst,
+                          const void* src,
+                          size_t size,
+                          IoDirection dir,
+                          const stream_t& stream);
+};
+#endif  // LITE_WITH_CUDA
+
+}  // namespace lite
+}  // namespace paddle
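An editorial sketch of a host-to-host round trip through the wrapper interface; `Malloc`/`Free`/`MemcpySync` for `kHost` are declared above and implemented out of line elsewhere in this patch:

constexpr size_t kBytes = 16 * sizeof(float);
void* src = TargetWrapperHost::Malloc(kBytes);
void* dst = TargetWrapperHost::Malloc(kBytes);
// ... fill src ...
TargetWrapperHost::MemcpySync(dst, src, kBytes, IoDirection::HtoH);
TargetWrapperHost::Free(src);
TargetWrapperHost::Free(dst);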
diff --git a/lite/core/tensor.cc b/lite/core/tensor.cc
new file mode 100644
index 00000000000..4dd4f5319d6
--- /dev/null
+++ b/lite/core/tensor.cc
@@ -0,0 +1,115 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef LITE_WITH_FPGA
+
+#include "lite/core/tensor.h"
+#include <string>
+#include "lite/utils/string.h"
+
+namespace paddle {
+namespace lite {
+
+using value_type = int64_t;
+
+value_type DDimLite::production() const {
+  value_type res = 1;
+  for (size_t i = 0; i < this->size(); i++) {
+    res *= (*this)[i];
+  }
+  return res;
+}
+
+value_type DDimLite::count(int start, int end) const {
+  if (start < 0) {
+    start = 0;
+  }
+  if (end > size()) {
+    end = size();
+  }
+  if (end < start) {
+    end = start;
+  }
+  value_type sum = 1;
+  for (auto i = start; i < end; ++i) {
+    sum *= data_[i];
+  }
+  return sum;
+}
+
+DDimLite DDimLite::Slice(int start, int end) const {
+  std::vector<value_type> vec;
+  for (int i = start; i < end; i++) {
+    vec.push_back((*this)[i]);
+  }
+  return DDimLite(vec);
+}
+
+std::string DDimLite::repr() const {
+  STL::stringstream ss;
+  if (empty()) {
+    ss << "{}";
+    return ss.str();
+  }
+  ss << "{";
+  for (size_t i = 0; i < this->size() - 1; i++) {
+    ss << (*this)[i] << ",";
+  }
+  if (!this->empty()) ss << (*this)[size() - 1];
+  ss << "}";
+  return ss.str();
+}
+
+void TensorLite::ShareDataWith(const TensorLite &other) {
+  buffer_ = other.buffer_;
+  dims_ = other.dims_;
+  target_ = other.target_;
+  lod_ = other.lod_;
+  memory_size_ = other.memory_size_;
+}
+
+void *TensorLite::mutable_data(size_t memory_size) {
+  memory_size_ = memory_size;
+  buffer_->ResetLazy(target_, memory_size_);
+  return buffer_->data();
+}
+
+void *TensorLite::mutable_data(TargetType target, size_t memory_size) {
+  target_ = target;
+  return mutable_data(memory_size);
+}
+
+void TensorLite::CopyDataFrom(const TensorLite &other) {
+  dims_ = other.dims_;
+  target_ = other.target_;
+  lod_ = other.lod_;
+  memory_size_ = other.memory_size_;
+  buffer_->CopyDataFrom(*other.buffer_, memory_size_);
+}
+
+// static LoD TensorLite::ToAbsOffset(const LoD &lod) {
+//   if (lod.empty() || lod.size() == 1) return lod;
+//   LoD ret = lod;
+//   for (int level = static_cast<int>(lod.size()) - 2; level >= 0; --level) {
+//     for (size_t i = 0; i < lod[level].size(); ++i) {
+//       size_t index = lod[level][i];
+//       ret[level][i] = ret[level + 1][index];
+//     }
+//   }
+//   return ret;
+// }
+
+}  // namespace lite
+}  // namespace paddle
+
+#endif
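A quick editorial illustration of the DDim arithmetic defined above: `production()` multiplies all axes, `count(s, e)` multiplies the clamped half-open range `[s, e)`, and `Slice`/`Flatten2D` derive new shapes:

lite::DDim dims(std::vector<int64_t>({2, 3, 4, 5}));
int64_t total = dims.production();    // 2 * 3 * 4 * 5 = 120
int64_t inner = dims.count(1, 4);     // 3 * 4 * 5 = 60
lite::DDim tail = dims.Slice(2, 4);   // {4, 5}
lite::DDim flat = dims.Flatten2D(2);  // {2 * 3, 4 * 5} = {6, 20}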
diff --git a/lite/core/tensor.h b/lite/core/tensor.h
new file mode 100644
index 00000000000..5ad2db15936
--- /dev/null
+++ b/lite/core/tensor.h
@@ -0,0 +1,227 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#ifdef LITE_WITH_FPGA
+#include "lite/fpga/lite_tensor.h"
+#endif
+
+#ifndef LITE_WITH_FPGA
+
+#include <algorithm>
+#include <functional>  // for multiplies
+#include <memory>
+#include <numeric>
+#include <string>
+#include <vector>
+#include "lite/core/memory.h"
+#include "lite/utils/replace_stl/stream.h"
+
+namespace paddle {
+namespace lite {
+
+class DDimLite;
+class TensorLite;
+
+using DDim = lite::DDimLite;
+using Tensor = lite::TensorLite;
+
+class DDimLite {
+ public:
+  using value_type = int64_t;
+
+  DDimLite() = default;
+
+  explicit DDimLite(const std::vector<value_type> &x) { ConstructFrom(x); }
+  // DDimLite(std::initializer_list<value_type> init_list) :
+  //     DDimLite(std::vector<value_type>(init_list)) {}
+
+  void ConstructFrom(const std::vector<value_type> &x) { data_ = x; }
+
+  value_type operator[](int offset) const { return data_[offset]; }
+  value_type &operator[](int offset) { return data_[offset]; }
+  std::vector<value_type> Vectorize() const { return data_; }
+
+  size_t size() const { return data_.size(); }
+  bool empty() const { return data_.empty(); }
+
+  value_type production() const;
+
+  const std::vector<value_type> &data() const { return data_; }
+  value_type count(int start, int end) const;
+
+  DDimLite Slice(int start, int end) const;
+
+  DDimLite Flatten2D(int col) const {
+    return DDimLite(std::vector<value_type>(
+        {Slice(0, col).production(), Slice(col, size()).production()}));
+  }
+
+  std::string repr() const;
+
+  friend STL::ostream &operator<<(STL::ostream &os, const DDimLite &dims) {
+    os << dims.repr();
+    return os;
+  }
+
+  friend bool operator==(const DDimLite &a, const DDimLite &b) {
+    if (a.size() != b.size()) return false;
+    for (size_t i = 0; i < a.size(); i++) {
+      if (a[i] != b[i]) return false;
+    }
+    return true;
+  }
+
+  friend bool operator!=(const DDimLite &a, const DDimLite &b) {
+    return !(a == b);
+  }
+
+ private:
+  std::vector<value_type> data_;
+};
+
+using LoD = std::vector<std::vector<uint64_t>>;
+
+// A light-weight tensor implementation.
+class TensorLite {
+ public:
+  TensorLite() : buffer_(std::make_shared<Buffer>()) {}
+
+  template <typename DType, typename DimT, TargetType Target>
+  void Assign(DType *data, const DimT &dim) {
+    Resize(dim);
+    auto *dst = mutable_data<DType>(Target);
+    CopySync<Target>(
+        dst, data, dim.production() * sizeof(DType), IoDirection::HtoD);
+  }
+
+  // T is the data type and R is the return type
+  // For OpenCL, the return type can be cl::Buffer
+  // and the data type can be float/int8_t.
+  // For other devices, T and R may be the same type.
+  template <typename T, typename R = T>
+  const R *data() const {
+    return static_cast<const R *>(buffer_->data());
+  }
+
+  void Resize(const DDimLite &ddim) { dims_ = ddim; }
+  void Resize(const std::vector<int64_t> &x) { dims_ = DDimLite(x); }
+
+  const DDimLite &dims() const { return dims_; }
+  int64_t numel() const { return dims_.production(); }
+
+  const LoD &lod() const { return lod_; }
+  LoD *mutable_lod() { return &lod_; }
+  void set_lod(const LoD &lod) { lod_ = lod; }
+
+  // T is the data type and R is the return type
+  // For OpenCL, the return type can be cl::Buffer
+  // and the data type can be float/int8_t.
+  // For other devices, T and R may be the same type.
+  template <typename T, typename R = T>
+  R *mutable_data();
+
+  // T is the data type and R is the return type
+  // For OpenCL, the return type can be cl::Buffer
+  // and the data type can be float/int8_t.
+  // For other devices, T and R may be the same type.
+  template <typename T, typename R = T>
+  R *mutable_data(TargetType target);
+  void *mutable_data(size_t memory_size);
+  void *mutable_data(TargetType target, size_t memory_size);
+
+  const void *raw_data() const {
+    return static_cast<const void *>(
+        (static_cast<const char *>(buffer_->data()) + offset_));
+  }
+
+  size_t data_size() const { return this->dims().production(); }
+
+  size_t memory_size() const { return memory_size_; }
+
+  size_t offset() const { return offset_; }
+
+  bool IsInitialized() const { return buffer_->data(); }
+
+  // Other share data to this.
+  void ShareDataWith(const TensorLite &other);
+
+  void CopyDataFrom(const TensorLite &other);
+
+  TargetType target() const { return target_; }
+
+  template <typename T>
+  TensorLite Slice(int64_t begin, int64_t end) const;
+
+  friend STL::ostream &operator<<(STL::ostream &os, const TensorLite &tensor) {
+    os << "Tensor:" << '\n';
+    os << "dim: " << tensor.dims() << '\n';
+    for (int i = 0; i < tensor.dims().production(); i++) {
+      os << tensor.template data<float>()[i] << " ";
+    }
+    os << "\n";
+    return os;
+  }
+
+ private:
+  TargetType target_{TargetType::kHost};
+  DDimLite dims_;
+  std::shared_ptr<Buffer> buffer_;
+  LoD lod_;
+  size_t memory_size_{};
+
+  /// @brief Buffer may be shared with other tensors
+  size_t offset_{0};
+};
+
+template <typename T, typename R>
+R *TensorLite::mutable_data() {
+  memory_size_ = dims_.production() * sizeof(T);
+  buffer_->ResetLazy(target_, memory_size_);
+  return static_cast<R *>(buffer_->data());
+}
+
+template <typename T, typename R>
+R *TensorLite::mutable_data(TargetType target) {
+  target_ = target;
+  memory_size_ = dims_.production() * sizeof(T);
+  buffer_->ResetLazy(target, memory_size());
+  return static_cast<R *>(buffer_->data());
+}
+
+template <typename T>
+TensorLite TensorLite::Slice(int64_t begin, int64_t end) const {
+  int64_t base = numel() / dims_[0];
+
+  TensorLite dst;
+  dst.buffer_ = buffer_;
+  dst.target_ = target_;
+  auto dst_dims = dims_;
+  dst_dims[0] = end - begin;
+  dst.Resize(dst_dims);
+  dst.offset_ = offset_ + static_cast<size_t>(begin * base) * sizeof(T);
+  return dst;
+}
+
+template <typename TensorT>
+bool TensorCompareWith(const TensorT &a, const TensorT &b) {
+  if (a.dims() != b.dims()) return false;
+  if (memcmp(a.raw_data(), b.raw_data(), a.data_size()) != 0) return false;
+  return true;
+}
+
+}  // namespace lite
+}  // namespace paddle
+
+#endif
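An editorial sketch of the typical tensor lifecycle with the API above; allocation is lazy and only happens when `mutable_data` is called:

lite::Tensor t;
t.Resize({2, 8});                    // set the shape; nothing allocated yet
float* d = t.mutable_data<float>();  // ResetLazy allocates 2 * 8 * sizeof(float)
for (int64_t i = 0; i < t.numel(); ++i) d[i] = static_cast<float>(i);

// Slice<T> shares the underlying buffer; only dims_ and offset_ change.
lite::Tensor row = t.Slice<float>(1, 2);  // rows [1, 2)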
diff --git a/lite/core/type_system.cc b/lite/core/type_system.cc
new file mode 100644
index 00000000000..276d0c4a349
--- /dev/null
+++ b/lite/core/type_system.cc
@@ -0,0 +1,157 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lite/core/type_system.h"
+#include "lite/utils/string.h"
+
+namespace paddle {
+namespace lite {
+
+size_t ParamTypeRegistry::KernelIdTy::hash() const {
+  std::hash<std::string> h;
+  size_t hash = h(kernel_type);
+  hash = hash_combine(hash, place.hash());
+  hash = hash_combine(hash, std::hash<int>()(static_cast<int>(io)));
+  hash = hash_combine(hash, std::hash<std::string>()(arg_name));
+  return hash;
+}
+
+STL::ostream &operator<<(STL::ostream &os, const Type &other) {
+  os << other.name();
+  return os;
+}
+
+// A map is used to maintain a global repo for types. We don't use MACROs with
+// static variables because the TypeSystem is only used at analysis time,
+// which is not performance sensitive, and a map-based way is easier to
+// implement and maintain.
+//
+// The map is declared in each Type::GetXXX method rather than in the Type
+// class so that it is guaranteed to be constructed before any usage.
+
+const Type *Type::GetTensorTy(TargetType target,
+                              PrecisionType precision,
+                              DataLayoutType layout,
+                              int device) {
+  static std::map<size_t, const Type *> type_repo;
+  // NOTE quite naive implementation here, but not performance sensitive.
+  DataType::ID type_id = DataType::ID::Tensor;
+
+#define HASH_ONE(x) v = hash_combine(v, hasher(static_cast<int>(x)))
+
+  std::hash<int> hasher;
+  size_t v = hasher(static_cast<int>(type_id));
+  HASH_ONE(target);
+  HASH_ONE(precision);
+  HASH_ONE(layout);
+  HASH_ONE(device);
+#undef HASH_ONE
+
+  STL::stringstream name;
+  name << "Tensor<";
+  name << TargetToStr(target) << ",";
+  name << PrecisionToStr(precision) << ",";
+  name << DataLayoutToStr(layout) << ",";
+  name << device;
+  name << ">";
+
+  if (!type_repo[v])
+    // The Types should stay alive for the whole process lifetime; no need to
+    // delete.
+    type_repo[v] =
+        new Type(type_id, name.str(), target, precision, layout, device);
+  return type_repo[v];
+}
+
+const Type *Type::GetTensorListTy(TargetType target,
+                                  PrecisionType precision,
+                                  DataLayoutType layout,
+                                  int device) {
+  static std::map<size_t, const Type *> type_repo;
+  DataType::ID type_id = DataType::ID::TensorList;
+
+#define HASH_ONE(x) v = hash_combine(v, hasher(static_cast<int>(x)))
+
+  std::hash<int> hasher;
+  size_t v = hasher(static_cast<int>(type_id));
+  HASH_ONE(target);
+  HASH_ONE(precision);
+  HASH_ONE(layout);
+  HASH_ONE(device);
+#undef HASH_ONE
+
+  STL::stringstream name;
+  name << "TensorList<";
+  name << TargetToStr(target) << ",";
+  name << PrecisionToStr(precision) << ",";
+  name << DataLayoutToStr(layout) << ",";
+  name << device;
+  name << ">";
+
+  if (!type_repo[v])
+    // The Types should stay alive for the whole process lifetime; no need to
+    // delete.
+    type_repo[v] =
+        new Type(type_id, name.str(), target, precision, layout, device);
+  return type_repo[v];
+}
+
+const Type *Type::GetUnsupportedTy() {
+  static std::map<size_t, const Type *> type_repo;
+  std::hash<int> hasher;
+  size_t v = hasher(static_cast<int>(DataType::ID::Unsupported));
+  if (!type_repo[v])
+    type_repo[v] = new Type(DataType::ID::Unsupported,
+                            "Unsupported",
+                            TARGET(kUnk),
+                            PRECISION(kUnk),
+                            DATALAYOUT(kUnk),
+                            -1);
+  return type_repo[v];
+}
+
+const Type *Type::GetVoidTy() {
+  static std::map<size_t, const Type *> type_repo;
+  std::hash<int> hasher;
+  size_t v = hasher(static_cast<int>(DataType::ID::Void));
+  if (!type_repo[v])
+    type_repo[v] = new Type(DataType::ID::Void,
+                            "Void",
+                            TARGET(kAny),
+                            PRECISION(kAny),
+                            DATALAYOUT(kAny),
+                            -1);
+  return type_repo[v];
+}
+
+const Type *Type::Get(DataType::ID type_id,
+                      TargetType target,
+                      PrecisionType precision,
+                      DataLayoutType layout,
+                      int device) {
+  switch (type_id) {
+    case DataType::ID::Void:
+      return GetVoidTy();
+    case DataType::ID::Unsupported:
+      return GetUnsupportedTy();
+    case DataType::ID::Tensor:
+      return GetTensorTy(target, precision, layout, device);
+    case DataType::ID::TensorList:
+      return GetTensorListTy(target, precision, layout, device);
+    default:
+      LOG(FATAL) << "Unknown Type found";
+      return nullptr;
+  }
+}
+
+}  // namespace lite
+}  // namespace paddle
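A short editorial illustration of the interning behavior above: each `GetXXXTy` call returns a singleton from the repo, so pointer equality is type identity, and the `kAny` wildcards interact with the compatibility helpers declared in the header that follows:

const Type* a = Type::GetTensorTy(TARGET(kHost), PRECISION(kFloat),
                                  DATALAYOUT(kNCHW));
const Type* b = Type::GetTensorTy(TARGET(kHost), PRECISION(kFloat),
                                  DATALAYOUT(kNCHW));
CHECK_EQ(a, b);  // same repo entry, identical pointer

const Type* any = Type::GetTensorTy(TARGET(kHost), PRECISION(kAny),
                                    DATALAYOUT(kAny));
CHECK(PrecisionCompatibleTo(*a, *any));  // kAny acts as a wildcard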
+
+/*
+ * Datatype with device info considered.
+ * NOTE A Type with a different device is treated as a different
+ * DeviceDataType.
+ */
+class Type : public DataType {
+ public:
+  // Whether this type can be cast to another type. This is heavily used in
+  // the MIR to determine whether it is possible to add a statement that
+  // transforms one type to another.
+  virtual bool TypeCastable(const Type& type) const { return id_ == type.id(); }
+
+  /// Get a Tensor type.
+  static const Type* GetTensorTy(TargetType target,
+                                 PrecisionType precision = PRECISION(kFloat),
+                                 DataLayoutType layout = DATALAYOUT(kNCHW),
+                                 int device = 0);
+  /// Get a TensorList type.
+  static const Type* GetTensorListTy(
+      TargetType target,
+      PrecisionType precision = PRECISION(kFloat),
+      DataLayoutType layout = DATALAYOUT(kNCHW),
+      int device = 0);
+  /// Get an Unsupported type.
+  static const Type* GetUnsupportedTy();
+  /// Get a Void type.
+  static const Type* GetVoidTy();
+
+  static const Type* Get(DataType::ID type_id,
+                         TargetType target = TARGET(kUnk),
+                         PrecisionType precision = PRECISION(kUnk),
+                         DataLayoutType layout = DATALAYOUT(kUnk),
+                         int device = 0);
+
+  TargetType target() const { return place_.target; }
+  PrecisionType precision() const { return place_.precision; }
+  DataLayoutType layout() const { return place_.layout; }
+  int16_t device() const { return place().device; }
+  const Place& place() const { return place_; }
+  const std::string& name() const { return name_; }
+
+  bool operator==(const Type& other) {
+    return id_ == other.id() && place_ == other.place();
+  }
+  friend STL::ostream& operator<<(STL::ostream& os, const Type& other);
+
+  virtual ~Type() = default;
+
+ protected:
+  /// One should avoid using this constructor.
+  Type(ID id,
+       const std::string& name,
+       TargetType target = TargetType::kHost,
+       PrecisionType precision = PrecisionType::kFloat,
+       DataLayoutType layout = DataLayoutType::kNCHW,
+       int16_t device = 0)
+      : DataType(id), place_{target, precision, layout, device}, name_(name) {}
+
+  Place place_;
+  const std::string name_;
+};
+
+// ------------------------------ compatible check ---------------------------
+static bool TargetCompatibleTo(const Type& a, const Type& b) {
+  auto is_host = [](TargetType x) -> bool {
+    return x == TARGET(kHost) || x == TARGET(kX86) || x == TARGET(kARM);
+  };
+  if (a.IsVoid() || b.IsVoid()) return true;
+  if (a.IsTensor() || b.IsTensor()) {
+    if (a.IsTensor() && b.IsTensor()) {
+      return is_host(a.target()) ? is_host(b.target())
+                                 : a.target() == b.target();
+    }
+    return false;
+  }
+  return true;
+}
+
+static bool DataLayoutCompatibleTo(const Type& a, const Type& b) {
+  return a.IsVoid() ||           //
+         ((a.layout() == b.layout() ||  //
+           b.layout() == DATALAYOUT(kAny)));
+}
+static bool DataLayoutCompatible(const Type& a, const Type& b) {
+  return a.IsVoid() || b.IsVoid() ||  //
+         ((a.layout() == b.layout() ||  //
+           b.layout() == DATALAYOUT(kAny) ||
+           a.layout() == DATALAYOUT(kAny)));
+}
+
+static bool PrecisionCompatibleTo(const Type& a, const Type& b) {
+  return a.IsVoid() ||  //
+         (((a.IsTensor() && b.IsTensor()) ||
+           (a.IsTensorList() && b.IsTensorList())) &&
+          (a.precision() == b.precision() ||  //
+           b.precision() == PRECISION(kAny) ||
+           a.precision() == PRECISION(kAny)));
+}
+static bool PrecisionCompatible(const Type& a, const Type& b) {
+  return a.IsVoid() || b.IsVoid() ||  //
+         (a.IsTensor() && b.IsTensor() &&
+          (a.precision() == b.precision() ||  //
+           b.precision() == PRECISION(kAny) ||
+           a.precision() == PRECISION(kAny)));
+}
+
+static bool DeviceCompatibleTo(const Type& a, const Type& b) {
+  return a.IsVoid() ||  //
+         (a.IsTensor() && b.IsTensor() && (a.device() == b.device()));
+}
+
+// Can type 'a' be passed to 'b' directly?
+static bool TypeCompatibleTo(const Type& a, const Type& b) {
+  return TargetCompatibleTo(a, b) && DataLayoutCompatibleTo(a, b) &&
+         PrecisionCompatibleTo(a, b) && DeviceCompatibleTo(a, b);
+}
+static bool TypeCompatible(const Type& a, const Type& b) {
+  return TargetCompatibleTo(a, b) && DataLayoutCompatible(a, b) &&
+         PrecisionCompatible(a, b) && DeviceCompatibleTo(a, b);
+}
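A hedged sketch of the compatibility helpers in action: `kHost` and `kARM` both count as "host-like" targets, and the remaining fields match the defaults, so the check passes (`CompatibilityExample` is an illustrative name):

```cpp
#include <cassert>
#include "lite/core/type_system.h"

void CompatibilityExample() {
  using namespace paddle::lite;
  const Type* host = Type::GetTensorTy(TARGET(kHost));
  const Type* arm = Type::GetTensorTy(TARGET(kARM));
  // Same precision, layout, and device; host-like targets interchange.
  assert(TypeCompatibleTo(*host, *arm));
}
```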
+
+/*
+ * ParamType is used to represent the data type of a parameter of a kernel.
+ * It can represent any Variable data type.
+ * The element type should be registered in the `TypeSystem`.
+ */
+struct ParamType {
+  const Type* type;
+
+  ParamType() = default;
+  ParamType(const Type* type) : type(type) {}  // NOLINT
+
+  std::string DebugString() const { return type->name(); }
+};
+
+/*
+ * The data types of kernel parameters. It is used to track the types of a
+ * kernel's inputs and outputs.
+ */
+struct ParamTypeRecorder {
+  std::map<std::string, ParamType> inputs;
+  std::map<std::string, ParamType> outputs;
+
+  void RegisterInputType(const std::string& arg_name, const ParamType& type) {
+    Register(&inputs, arg_name, type);
+  }
+
+  void RegisterOutputType(const std::string& arg_name, const ParamType& type) {
+    Register(&outputs, arg_name, type);
+  }
+
+ private:
+  void Register(std::map<std::string, ParamType>* ts,
+                const std::string& arg_name,
+                ParamType type) {
+    (*ts)[arg_name] = type;
+  }
+};
+
+/*
+ * The ParamTypeRegistry helps register the input and output data types for
+ * all the kernels. It is a singleton so that all the objects for the same
+ * kernel can share the same information.
+ *
+ * Usage, registering a kernel for the FC operator:
+ *   ParamTypeRegistry::Global().Register<ParamTypeRegistry::IO::kInput>(
+ *       "fc", Place{TARGET(kCUDA), PRECISION(kFloat)}, "Input",
+ *       {Type::GetTensorTy(TARGET(kCUDA))});
+ */
+class ParamTypeRegistry {
+ public:
+  enum class IO : int { kInput = 0, kOutput };
+
+  /*
+   * Helper class for registering ParamTypes for a kernel.
+   * Usage:
+   *
+   *   NewInstance<TARGET(kHost), PRECISION(kFloat)>("fc")
+   *       .BindInput("Input", {Type::GetTensorTy(TARGET(kHost))})
+   *       .BindInput("W", {Type::GetTensorTy(TARGET(kHost),
+   *                                          PRECISION(kFloat))});
+   */
+  template <TargetType target,
+            PrecisionType precision,
+            DataLayoutType layout = DataLayoutType::kNCHW>
+  struct NewInstance {
+    explicit NewInstance(const std::string& kernel_type)
+        : kernel_type_(kernel_type) {}
+
+    NewInstance& BindInput(const std::string& arg_name,
+                           const ParamType& ptype) {
+      ParamTypeRegistry::Global().Register<IO::kInput>(
+          kernel_type_, Place{target, precision, layout}, arg_name, ptype);
+      return *this;
+    }
+    NewInstance& BindOutput(const std::string& arg_name,
+                            const ParamType& ptype) {
+      ParamTypeRegistry::Global().Register<IO::kOutput>(
+          kernel_type_, Place{target, precision, layout}, arg_name, ptype);
+      return *this;
+    }
+
+    bool Finalize() { return true; }
+
+   private:
+    std::string kernel_type_;
+  };
+
+  template <IO io>
+  void Register(const std::string& kernel_type,
+                const Place& place,
+                const std::string& arg_name,
+                ParamType data_type) {
+    KernelIdTy key{kernel_type, place, io, arg_name};
+    types_[key] = data_type;
+    CHECK(types_.count(key));
+  }
+
+  const ParamType* RetrieveInArgument(const Place& place,
+                                      const std::string& op_type,
+                                      const std::string& arg_name) {
+    return Retrieve<IO::kInput>(place, op_type, arg_name);
+  }
+  const ParamType* RetrieveOutArgument(const Place& place,
+                                       const std::string& op_type,
+                                       const std::string& arg_name) {
+    return Retrieve<IO::kOutput>(place, op_type, arg_name);
+  }
+
+  static ParamTypeRegistry& Global() {
+    static ParamTypeRegistry x;
+    return x;
+  }
+
+  friend STL::ostream& operator<<(STL::ostream& os,
+                                  const ParamTypeRegistry& other) {
+    for (auto& item : other.types_) {
+      os << item.first << " " << item.second.DebugString() << "\n";
+    }
+    return os;
+  }
+
+ protected:
+  template <IO io>
+  const ParamType* Retrieve(const Place& place,
+                            const std::string& op_type,
+                            const std::string& arg_name) {
+    KernelIdTy key{op_type, place, io, arg_name};
+    auto it = types_.find(key);
+    if (it == types_.end()) return nullptr;
+    return &it->second;
+  }
+
+ private:
+  ParamTypeRegistry() = default;
+
+ public:
+  // Identification for a Kernel.
+  struct KernelIdTy {
+    std::string kernel_type;
+    Place place;
+    IO io;
+    std::string arg_name;
+
+    size_t hash() const;
+    friend STL::ostream& operator<<(STL::ostream& os, const KernelIdTy& other);
+  };
+
+  using key_t = KernelIdTy;
+  struct KeyCmp {
+    bool operator()(const key_t& a, const key_t& b) const;
+  };
+
+ private:
+  std::map<key_t, ParamType, KeyCmp> types_;
+};
+
+}  // namespace lite
+}  // namespace paddle
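The registry above is normally driven through `REGISTER_LITE_KERNEL`, but a direct use of `NewInstance` might look like the following sketch (the kernel name "fc" and the argument names are illustrative):

```cpp
#include "lite/core/type_system.h"

namespace paddle {
namespace lite {

// Illustrative only: declare float-tensor inputs/outputs for a kernel "fc".
bool RegisterFcParamTypes() {
  return ParamTypeRegistry::NewInstance<TARGET(kARM), PRECISION(kFloat)>("fc")
      .BindInput("Input", {Type::GetTensorTy(TARGET(kARM))})
      .BindInput("W", {Type::GetTensorTy(TARGET(kARM))})
      .BindOutput("Out", {Type::GetTensorTy(TARGET(kARM))})
      .Finalize();
}

}  // namespace lite
}  // namespace paddle
```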
diff --git a/lite/core/type_system_test.cc b/lite/core/type_system_test.cc
new file mode 100644
index 00000000000..224a779fcb9
--- /dev/null
+++ b/lite/core/type_system_test.cc
@@ -0,0 +1,35 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lite/core/type_system.h"
+#include <gtest/gtest.h>
+
+namespace paddle {
+namespace lite {
+
+TEST(TypeSystem, CheckDuplicateGet) {
+  auto* tensor_ty =
+      Type::GetTensorTy(TARGET(kHost), PRECISION(kFloat), DATALAYOUT(kNCHW));
+  auto* tensor_ty1 =
+      Type::GetTensorTy(TARGET(kHost), PRECISION(kFloat), DATALAYOUT(kNCHW));
+
+  ASSERT_EQ(tensor_ty, tensor_ty1);
+
+  ASSERT_EQ(tensor_ty->target(), TARGET(kHost));
+  ASSERT_EQ(tensor_ty->precision(), PRECISION(kFloat));
+  ASSERT_EQ(tensor_ty->layout(), DATALAYOUT(kNCHW));
+}
+
+}  // namespace lite
+}  // namespace paddle
diff --git a/lite/core/types.cc b/lite/core/types.cc
new file mode 100644
index 00000000000..ec89e83e580
--- /dev/null
+++ b/lite/core/types.cc
@@ -0,0 +1,95 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lite/core/types.h"
+#include <stack>
+
+namespace paddle {
+namespace lite {
+namespace core {
+
+KernelPickFactor& KernelPickFactor::ConsiderDataLayout() {
+  data_ |= static_cast<int>(Factor::DataLayoutFirst);
+  return *this;
+}
+KernelPickFactor& KernelPickFactor::ConsiderPrecision() {
+  data_ |= static_cast<int>(Factor::PrecisionFirst);
+  return *this;
+}
+KernelPickFactor& KernelPickFactor::ConsiderTarget() {
+  data_ |= static_cast<int>(Factor::TargetFirst);
+  return *this;
+}
+KernelPickFactor& KernelPickFactor::ConsiderDevice() {
+  data_ |= static_cast<int>(Factor::DeviceFirst);
+  return *this;
+}
+bool KernelPickFactor::IsPrecisionConsidered() const {
+  return data_ & static_cast<int>(Factor::PrecisionFirst);
+}
+bool KernelPickFactor::IsTargetConsidered() const {
+  return data_ & static_cast<int>(Factor::TargetFirst);
+}
+bool KernelPickFactor::IsDataLayoutConsidered() const {
+  return data_ & static_cast<int>(Factor::DataLayoutFirst);
+}
+bool KernelPickFactor::IsDeviceConsidered() const {
+  return data_ & static_cast<int>(Factor::DeviceFirst);
+}
+
+STL::ostream& operator<<(STL::ostream& os, const KernelPickFactor& k) {
+  std::stack<bool> bits;
+  auto data = k.data_;
+  while (data) {
+    bits.push(data % 2);
+    data /= 2;
+  }
+  int nbits = static_cast<int>(bits.size());
+  for (size_t i = 0; i < sizeof(data) * 8 - nbits; i++) {
+    os << 0;
+  }
+  while (!bits.empty()) {
+    os << bits.top();
+    bits.pop();
+  }
+  return os;
+}
+
+template <>
+Type StdTypeToRepr<int32_t>() {
+  return Type::_int32;
+}
+template <>
+Type StdTypeToRepr<int64_t>() {
+  return Type::_int64;
+}
+template <>
+Type StdTypeToRepr<float>() {
+  return Type::_float32;
+}
+template <>
+Type StdTypeToRepr<double>() {
+  return Type::_float64;
+}
+template <>
+Type StdTypeToRepr<std::string>() {
+  return Type::_string;
+}
+template <>
+Type StdTypeToRepr<bool>() {
+  return Type::_bool;
+}
+
+}  // namespace core
+}  // namespace lite
+}  // namespace paddle
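`StdTypeToRepr` maps C++ scalar types to the `Type` enum at compile time; anything without a specialization falls back to `_unk`. A small sketch (the function name is illustrative):

```cpp
#include "lite/core/types.h"

void ReprExample() {
  namespace core = paddle::lite::core;
  core::Type a = core::StdTypeToRepr<float>();  // Type::_float32
  core::Type b = core::StdTypeToRepr<char>();   // no specialization: Type::_unk
  (void)a;
  (void)b;
}
```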
diff --git a/lite/core/types.h b/lite/core/types.h
new file mode 100644
index 00000000000..0664aba6b6f
--- /dev/null
+++ b/lite/core/types.h
@@ -0,0 +1,116 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <cstdint>
+#include <string>
+#include "lite/api/paddle_place.h"
+#include "lite/utils/all.h"
+
+namespace paddle {
+namespace lite {
+namespace core {
+
+/*
+ * Type representations used to represent standard types.
+ */
+// TODO(Superjomn) unify all the type representations across the lite
+// framework.
+enum class Type {
+  _unk = -1,
+  // primary types
+  _int32,
+  _int64,
+  _float32,
+  _float64,
+  _bool,
+  _string,
+  // primary list types
+  _list,
+  // enum type
+  _enum,
+  _float16,
+  // number of types
+  __num__,
+};
+
+template <typename T>
+Type StdTypeToRepr() {
+  return Type::_unk;
+}
+template <>
+Type StdTypeToRepr<int32_t>();
+template <>
+Type StdTypeToRepr<int64_t>();
+template <>
+Type StdTypeToRepr<float>();
+template <>
+Type StdTypeToRepr<double>();
+template <>
+Type StdTypeToRepr<std::string>();
+template <>
+Type StdTypeToRepr<bool>();
+
+// Factors that impact the kernel picking strategy. Multiple factors can be
+// considered together, e.g. by chaining calls such as
+// `factor.ConsiderTarget().ConsiderPrecision()`.
+class KernelPickFactor {
+ public:
+  using value_type = unsigned char;
+  enum class Factor : int {
+    // The following factors are sorted by priority.
+    TargetFirst = 1,
+    PrecisionFirst = 1 << 1,
+    DataLayoutFirst = 1 << 2,
+    DeviceFirst = 1 << 3,
+  };
+
+  // Whether any factor has been considered.
+  bool any_factor_considered() const { return data_; }
+
+  // Prefer a specific target, e.g. prefer CUDA kernels.
+  KernelPickFactor& ConsiderTarget();
+  KernelPickFactor& ConsiderPrecision();
+  KernelPickFactor& ConsiderDataLayout();
+  KernelPickFactor& ConsiderDevice();
+
+  bool IsTargetConsidered() const;
+  bool IsPrecisionConsidered() const;
+  bool IsDataLayoutConsidered() const;
+  bool IsDeviceConsidered() const;
+
+  friend STL::ostream& operator<<(STL::ostream& os, const KernelPickFactor& k);
+
+ private:
+  unsigned char data_{};
+  lite_api::TargetType target_{TARGET(kUnk)};
+};
+
+struct dim2 {
+  int x{};
+  int y{};
+
+  dim2(int x, int y) : x(x), y(y) {}
+};
+
+struct dim3 {
+  int x{};
+  int y{};
+  int z{};
+
+  dim3(int x, int y, int z) : x(x), y(y), z(z) {}
+};
+
+using byte_t = uint8_t;
+
+}  // namespace core
+}  // namespace lite
+}  // namespace paddle
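Since every `Consider*` call sets one bit and returns `*this`, the factors chain naturally; a sketch (the function name is illustrative):

```cpp
#include <cassert>
#include "lite/core/types.h"

void PickFactorExample() {
  paddle::lite::core::KernelPickFactor factor;
  factor.ConsiderTarget().ConsiderPrecision();
  assert(factor.IsTargetConsidered());
  assert(factor.IsPrecisionConsidered());
  assert(!factor.IsDataLayoutConsidered());  // that bit was never set
}
```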
diff --git a/lite/core/types_test.cc b/lite/core/types_test.cc
new file mode 100644
index 00000000000..9b7e5b6f05b
--- /dev/null
+++ b/lite/core/types_test.cc
@@ -0,0 +1,43 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lite/core/types.h"
+#include <gtest/gtest.h>
+
+namespace paddle {
+namespace lite {
+namespace core {
+
+TEST(KernelPickFactor, Default) {
+  KernelPickFactor factor;
+  ASSERT_FALSE(factor.IsTargetConsidered());
+  ASSERT_FALSE(factor.IsPrecisionConsidered());
+  ASSERT_FALSE(factor.IsDataLayoutConsidered());
+}
+
+TEST(KernelPickFactor, Set) {
+  KernelPickFactor factor;
+  factor.ConsiderTarget();
+  ASSERT_TRUE(factor.IsTargetConsidered());
+  factor.ConsiderPrecision();
+  ASSERT_TRUE(factor.IsPrecisionConsidered());
+  factor.ConsiderDataLayout();
+  ASSERT_TRUE(factor.IsDataLayoutConsidered());
+
+  LOG(INFO) << "factor " << factor;
+}
+
+}  // namespace core
+}  // namespace lite
+}  // namespace paddle
diff --git a/lite/core/variable.cc b/lite/core/variable.cc
new file mode 100644
index 00000000000..a344da63f1b
--- /dev/null
+++ b/lite/core/variable.cc
@@ -0,0 +1,19 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lite/core/variable.h"
+
+namespace paddle {
+namespace lite {}  // namespace lite
+}  // namespace paddle
diff --git a/lite/core/variable.h b/lite/core/variable.h
new file mode 100644
index 00000000000..2c1e737a937
--- /dev/null
+++ b/lite/core/variable.h
@@ -0,0 +1,52 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+#include <memory>
+#include <string>
+#include <vector>
+#include "lite/core/tensor.h"
+#include "lite/utils/all.h"
+
+namespace paddle {
+namespace lite {
+
+using FeedFetchList = std::vector<lite::Tensor>;
+
+class Variable {
+ public:
+  template <typename T>
+  const T& Get() const {
+    return blob_.get<T>();
+  }
+
+  template <typename T>
+  T* GetMutable() {
+    if (!blob_.is<T>()) blob_.set<T>();
+    return blob_.get_mutable<T>();
+  }
+
+  template <typename T>
+  bool IsType() {
+    return blob_.type() == typeid(T).hash_code();
+  }
+
+ private:
+  // variant<int, float, std::string, lite::Tensor> blob_;
+  variant<int, float, std::string, lite::Tensor, std::vector<lite::Tensor>>
+      blob_;
+};
+
+}  // namespace lite
+}  // namespace paddle
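A sketch of the `Variable` API above: the first `GetMutable<T>()` switches the variant payload to `T` (the function name is illustrative, and `variant::type()` is assumed to report the stored `typeid` hash as `IsType` expects):

```cpp
#include "lite/core/variable.h"

void VariableExample() {
  paddle::lite::Variable var;
  auto* t = var.GetMutable<paddle::lite::Tensor>();  // payload becomes Tensor
  (void)t;
  bool ok = var.IsType<paddle::lite::Tensor>();  // now true
  (void)ok;
}
```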
diff --git a/lite/core/workspace.cc b/lite/core/workspace.cc
new file mode 100644
index 00000000000..196536f9551
--- /dev/null
+++ b/lite/core/workspace.cc
@@ -0,0 +1,15 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lite/core/workspace.h"
diff --git a/lite/core/workspace.h b/lite/core/workspace.h
new file mode 100644
index 00000000000..117b80aaa78
--- /dev/null
+++ b/lite/core/workspace.h
@@ -0,0 +1,83 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+#include <memory>
+#include "lite/core/memory.h"
+#include "lite/core/types.h"
+#include "lite/utils/macros.h"
+
+namespace paddle {
+namespace lite {
+
+/*
+ * WorkSpace is a container that helps to manage the temporary memory shared
+ * across kernels during serial execution.
+ *
+ * Due to the mobile library size limit, a complex allocator or GC algorithm
+ * is not suitable here; one needs to carefully manage the workspace inside a
+ * single kernel.
+ *
+ * NOTE
+ *
+ * For kernel developers: call `WorkSpace::Global_Host().Alloc()` (or the
+ * target-specific variant) if you need to allocate a temporary buffer.
+ */
+class WorkSpace {
+ public:
+  // Reset the workspace and treat it as empty.
+  void AllocReset() { cursor_ = 0; }
+
+  // Allocate a memory buffer.
+  core::byte_t* Alloc(size_t size) {
+    buffer_.ResetLazy(target_, cursor_ + size);
+    auto* data = static_cast<core::byte_t*>(buffer_.data()) + cursor_;
+    cursor_ += size;
+    return data;
+  }
+
+  static WorkSpace& Global_Host() {
+    thread_local std::unique_ptr<WorkSpace> x(new WorkSpace(TARGET(kHost)));
+    return *x;
+  }
+
+#if defined(LITE_WITH_X86)
+  static WorkSpace& Global_X86() { return Global_Host(); }
+#endif
+
+#if defined(LITE_WITH_ARM)
+  static WorkSpace& Global_ARM() { return Global_Host(); }
+#endif
+
+#if defined(LITE_WITH_CUDA)
+  static WorkSpace& Global_CUDA() {
+    thread_local std::unique_ptr<WorkSpace> x(new WorkSpace(TARGET(kCUDA)));
+    return *x;
+  }
+#endif
+
+ private:
+  explicit WorkSpace(TargetType x) : target_(x) {}
+
+  TargetType target_;
+  Buffer buffer_;
+  size_t cursor_{};
+
+  DISALLOW_COPY_AND_ASSIGN(WorkSpace);
+};
+
+}  // namespace lite
+}  // namespace paddle
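A usage sketch for the cursor-based arena above (the function name is illustrative); note that `Alloc` only grows the cursor, so a kernel should call `AllocReset` before carving out its buffers:

```cpp
#include "lite/core/workspace.h"

void KernelScratchExample() {
  auto& ws = paddle::lite::WorkSpace::Global_Host();
  ws.AllocReset();  // rewind the cursor; earlier buffers become invalid
  paddle::lite::core::byte_t* a = ws.Alloc(1024);  // first scratch buffer
  paddle::lite::core::byte_t* b = ws.Alloc(2048);  // placed right after `a`
  (void)a;
  (void)b;
}
```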
diff --git a/lite/cuda/CMakeLists.txt b/lite/cuda/CMakeLists.txt
new file mode 100644
index 00000000000..ded31c77e9c
--- /dev/null
+++ b/lite/cuda/CMakeLists.txt
@@ -0,0 +1,8 @@
+if(NOT LITE_WITH_CUDA)
+  return()
+endif()
+
+nv_library(target_wrapper_cuda SRCS target_wrapper.cc)
+nv_library(cuda_blas SRCS blas.cc)
+
+
diff --git a/lite/cuda/blas.cc b/lite/cuda/blas.cc
new file mode 100644
index 00000000000..eacbb27bd85
--- /dev/null
+++ b/lite/cuda/blas.cc
@@ -0,0 +1,57 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lite/cuda/blas.h"
+
+namespace paddle {
+namespace lite {
+namespace cuda {
+
+template <>
+class Blas<float> : public BlasBase {
+  using T = float;
+
+  void sgemm(cublasOperation_t transa,
+             cublasOperation_t transb,  //
+             int m,
+             int n,
+             int k,           //
+             const T* alpha,  //
+             const T* A,
+             int lda,  //
+             const T* B,
+             int ldb,         //
+             const T* beta,   //
+             T* C,
+             int ldc) const {
+    CUBLAS_CALL(cublasSgemm(handle(),
+                            transa,
+                            transb,
+                            m,
+                            n,
+                            k,
+                            alpha,
+                            A,
+                            lda,
+                            B,
+                            ldb,
+                            beta,
+                            C,
+                            ldc));
+  }
+};
+
+}  // namespace cuda
+}  // namespace lite
+}  // namespace paddle
diff --git a/lite/cuda/blas.h b/lite/cuda/blas.h
new file mode 100644
index 00000000000..4e65d569a81
--- /dev/null
+++ b/lite/cuda/blas.h
@@ -0,0 +1,99 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+#include <cublas_v2.h>
+#include <cuda.h>
+#include <cuda_runtime.h>
+#include <library_types.h>
+#include "lite/cuda/cuda_utils.h"
+#include "lite/utils/all.h"
+
+namespace paddle {
+namespace lite {
+namespace cuda {
+
+#define CUBLAS_CHECK(xxx) CHECK_EQ((xxx), CUBLAS_STATUS_SUCCESS);
+
+/*
+ * Some basic methods.
+ */
+struct BlasBase {
+  /*
+  BlasBase() { CUBLAS_CHECK(cublasCreate(&handle_)); }
+  ~BlasBase() { CUBLAS_CHECK(cublasDestroy(handle_)); }
+  */
+
+  void SetStream(cudaStream_t stream) {
+    CUBLAS_CHECK(cublasSetStream(handle_, stream));
+  }
+
+  cudaStream_t GetStream() const {
+    cudaStream_t stream;
+    CUBLAS_CHECK(cublasGetStream_v2(handle_, &stream));
+    return stream;
+  }
+
+  int GetVersion() const {
+    int version{};
+    CUBLAS_CHECK(cublasGetVersion_v2(handle_, &version));
+    return version;
+  }
+
+  cublasHandle_t& handle() const { return handle_; }
+
+ protected:
+  // Not thread-safe: a handle should be created for each thread, according to
+  // the cuBLAS documentation. Note that the constructor above is commented
+  // out, so cublasCreate is never called in this version; callers must ensure
+  // the handle is initialized before use.
+  mutable cublasHandle_t handle_;
+};
+
+// T: Scalar type.
+template <typename T>
+class Blas : public lite::cuda::BlasBase {
+ public:
+  void sgemm(cublasOperation_t transa,
+             cublasOperation_t transb,  //
+             int m,
+             int n,
+             int k,           //
+             const T* alpha,  //
+             const T* A,
+             int lda,  //
+             const T* B,
+             int ldb,        //
+             const T* beta,  //
+             T* C,
+             int ldc) const {
+    // NOTE transa/transb are accepted but CUBLAS_OP_N is currently hard-coded
+    // in the call below.
+    CHECK_EQ(CUBLAS_STATUS_SUCCESS,
+             cublasSgemm(handle_,  //
+                         CUBLAS_OP_N,
+                         CUBLAS_OP_N,  //
+                         m,
+                         n,
+                         k,
+                         alpha,
+                         A,
+                         lda,
+                         B,
+                         ldb,
+                         beta,
+                         C,
+                         ldc));
+  }
+};
+
+}  // namespace cuda
+}  // namespace lite
+}  // namespace paddle
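A hedged usage sketch for the wrapper above: column-major C = A * B with no transposes. Since `BlasBase` never calls `cublasCreate` in this version, a real caller must ensure the handle is valid first; the function name and leading dimensions are illustrative:

```cpp
#include "lite/cuda/blas.h"

void GemmExample(int m, int n, int k, const float* A, const float* B,
                 float* C, cudaStream_t stream) {
  paddle::lite::cuda::Blas<float> blas;  // assumes handle_ was initialized
  blas.SetStream(stream);
  const float alpha = 1.f;
  const float beta = 0.f;
  blas.sgemm(CUBLAS_OP_N, CUBLAS_OP_N, m, n, k, &alpha, A, m, B, k, &beta, C,
             m);
}
```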
diff --git a/lite/cuda/cuda_utils.h b/lite/cuda/cuda_utils.h
new file mode 100644
index 00000000000..0db3c4b179d
--- /dev/null
+++ b/lite/cuda/cuda_utils.h
@@ -0,0 +1,76 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+#include <cublas_api.h>
+#include <cublas_v2.h>
+#include <cuda.h>
+#include <cuda_runtime.h>
+#include "lite/utils/cp_logging.h"
+
+/*
+ * This file contains some CUDA specific utils.
+ */
+
+// For quickly implementing the prototype, some of the following code snippets
+// are borrowed from the MXNet project; great thanks to the original
+// developers.
+
+#define CHECK_CUDA_ERROR(msg)                                                \
+  {                                                                          \
+    auto e = cudaGetLastError();                                             \
+    CHECK_EQ(e, cudaSuccess) << (msg) << " CUDA: " << cudaGetErrorString(e); \
+  }
+
+#define CUDA_CALL(func)                                      \
+  {                                                          \
+    auto e = (func);                                         \
+    CHECK(e == cudaSuccess || e == cudaErrorCudartUnloading) \
+        << "CUDA: " << cudaGetErrorString(e);                \
+  }
+
+#define CUBLAS_CALL(func)                                        \
+  {                                                              \
+    auto e = (func);                                             \
+    CHECK_EQ(e, CUBLAS_STATUS_SUCCESS)                           \
+        << "cuBlas: " << paddle::lite::cuda::CublasErrorInfo(e); \
+  }
+
+namespace paddle {
+namespace lite {
+namespace cuda {
+
+static const char* CublasErrorInfo(int error) {
+  switch (error) {
+#define LITE_CUBLAS_ERROR_INFO(xx) \
+  case xx:                         \
+    return #xx;                    \
+    break;
+    LITE_CUBLAS_ERROR_INFO(CUBLAS_STATUS_NOT_INITIALIZED);
+    LITE_CUBLAS_ERROR_INFO(CUBLAS_STATUS_ALLOC_FAILED);
+    LITE_CUBLAS_ERROR_INFO(CUBLAS_STATUS_INVALID_VALUE);
+    LITE_CUBLAS_ERROR_INFO(CUBLAS_STATUS_ARCH_MISMATCH);
+    LITE_CUBLAS_ERROR_INFO(CUBLAS_STATUS_MAPPING_ERROR);
+    LITE_CUBLAS_ERROR_INFO(CUBLAS_STATUS_EXECUTION_FAILED);
+    LITE_CUBLAS_ERROR_INFO(CUBLAS_STATUS_INTERNAL_ERROR);
+    LITE_CUBLAS_ERROR_INFO(CUBLAS_STATUS_NOT_SUPPORTED);
+    LITE_CUBLAS_ERROR_INFO(CUBLAS_STATUS_LICENSE_ERROR);
+#undef LITE_CUBLAS_ERROR_INFO
+    default:
+      return "unknown error";
+  }
+}
+
+}  // namespace cuda
+}  // namespace lite
+}  // namespace paddle
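A small sketch of the checking macros above (the function name and sizes are illustrative); on failure, `CUDA_CALL` aborts with the CUDA error string, and `CUBLAS_CALL` maps status codes through `CublasErrorInfo`:

```cpp
#include "lite/cuda/cuda_utils.h"

void CudaCallExample(size_t bytes) {
  float* dev = nullptr;
  CUDA_CALL(cudaMalloc(reinterpret_cast<void**>(&dev), bytes));
  CUDA_CALL(cudaMemset(dev, 0, bytes));
  CUDA_CALL(cudaFree(dev));
}
```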
diff --git a/lite/cuda/target_wrapper.cc b/lite/cuda/target_wrapper.cc
new file mode 100644
index 00000000000..2a9702aebad
--- /dev/null
+++ b/lite/cuda/target_wrapper.cc
@@ -0,0 +1,74 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lite/cuda/target_wrapper.h"
+
+namespace paddle {
+namespace lite {
+
+using TargetW = TargetWrapper<TARGET(kCUDA), cudaStream_t, cudaEvent_t>;
+
+void* TargetW::Malloc(size_t size) {
+  void* ptr{};
+  CHECK_EQ(cudaSuccess, cudaMalloc(&ptr, size));
+  return ptr;
+}
+
+void TargetW::Free(void* ptr) { CHECK_EQ(cudaSuccess, cudaFree(ptr)); }
+
+void TargetW::MemcpySync(void* dst,
+                         const void* src,
+                         size_t size,
+                         IoDirection dir) {
+  switch (dir) {
+    case IoDirection::DtoD:
+      CHECK(cudaSuccess ==
+            cudaMemcpy(dst, src, size, cudaMemcpyDeviceToDevice));
+      break;
+    case IoDirection::HtoD:
+      CHECK(cudaSuccess == cudaMemcpy(dst, src, size, cudaMemcpyHostToDevice));
+      break;
+    case IoDirection::DtoH:
+      CHECK(cudaSuccess == cudaMemcpy(dst, src, size, cudaMemcpyDeviceToHost));
+      break;
+    default:
+      LOG(FATAL) << "Unsupported IoDirection " << static_cast<int>(dir);
+  }
+}
+
+void TargetW::MemcpyAsync(void* dst,
+                          const void* src,
+                          size_t size,
+                          IoDirection dir,
+                          const stream_t& stream) {
+  switch (dir) {
+    case IoDirection::DtoD:
+      CHECK(cudaSuccess ==
+            cudaMemcpyAsync(dst, src, size, cudaMemcpyDeviceToDevice, stream));
+      break;
+    case IoDirection::HtoD:
+      CHECK(cudaSuccess ==
+            cudaMemcpyAsync(dst, src, size, cudaMemcpyHostToDevice, stream));
+      break;
+    case IoDirection::DtoH:
+      CHECK(cudaSuccess ==
+            cudaMemcpyAsync(dst, src, size, cudaMemcpyDeviceToHost, stream));
+      break;
+    default:
+      LOG(FATAL) << "Unsupported IoDirection " << static_cast<int>(dir);
+  }
+}
+
+}  // namespace lite
+}  // namespace paddle
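A sketch of copying a host buffer to the device through the wrapper just defined (the function name is illustrative; the alias mirrors the one in the file above):

```cpp
#include "lite/cuda/target_wrapper.h"

void CopyToDeviceExample(const float* host_src, size_t bytes) {
  using TargetW =
      paddle::lite::TargetWrapper<TARGET(kCUDA), cudaStream_t, cudaEvent_t>;
  void* dev = TargetW::Malloc(bytes);
  TargetW::MemcpySync(dev, host_src, bytes, paddle::lite::IoDirection::HtoD);
  TargetW::Free(dev);
}
```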
diff --git a/lite/cuda/target_wrapper.h b/lite/cuda/target_wrapper.h
new file mode 100644
index 00000000000..d43172a99f8
--- /dev/null
+++ b/lite/cuda/target_wrapper.h
@@ -0,0 +1,29 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+#include <cuda.h>
+#include <cuda_runtime.h>
+#include "lite/core/target_wrapper.h"
+
+namespace paddle {
+namespace lite {
+namespace cuda {
+
+using TargetWrap = TargetWrapper<TARGET(kCUDA)>;
+using TargetWrapAsync = TargetWrapper<TARGET(kCUDA), cudaStream_t, cudaEvent_t>;
+
+}  // namespace cuda
+}  // namespace lite
+}  // namespace paddle
diff --git a/lite/demo/cxx/Makefile.def b/lite/demo/cxx/Makefile.def
new file mode 100644
index 00000000000..f0a0ec1dcb1
--- /dev/null
+++ b/lite/demo/cxx/Makefile.def
@@ -0,0 +1,35 @@
+CXX_DEFINES = -DARM_WITH_OMP -DHPPL_STUB_FUNC -DLITE_WITH_ARM -DLITE_WITH_LIGHT_WEIGHT_FRAMEWORK \
+              -DLITE_WITH_LINUX -DPADDLE_DISABLE_PROFILER -DPADDLE_NO_PYTHON -DPADDLE_WITH_TESTING
+LDFLAGS = -latomic -pthread -ldl
+
+SYSROOT_COMPLILE = --sysroot=/opt/android-ndk-r17c/sysroot
+
+THIRD_PARTY_LIBS = ../../../third_party/gflags/lib/libgflags.a
+
+SYSTEM_INCLUDES = -I/opt/android-ndk-r17c/sources/cxx-stl/llvm-libc++/include \
+                  -I/opt/android-ndk-r17c/sources/cxx-stl/llvm-libc++abi/include \
+                  -I/opt/android-ndk-r17c/sources/android/support/include \
+                  -I/opt/android-ndk-r17c/sysroot/usr/include
+
+THIRD_PARTY_INCLUDES = -I../../../third_party/gflags/include
+
+ifeq ($(ARM_ABI), arm8)
+    CC = /opt/android-ndk-r17c/toolchains/aarch64-linux-android-4.9/prebuilt/linux-x86_64/bin/aarch64-linux-android-g++
+    CXX_FLAGS = -funwind-tables -no-canonical-prefixes -D__ANDROID_API__=22 -fexceptions -frtti -std=c++11 -fopenmp -O3 -DNDEBUG -fPIE
+    CXXFLAGS_LINK = $(CXX_FLAGS) -pie -Wl,--gc-sections
+    SYSROOT_LINK = --sysroot=/opt/android-ndk-r17c/platforms/android-24/arch-arm64
+    SYSTEM_LIBS = /opt/android-ndk-r17c/sources/cxx-stl/llvm-libc++/libs/arm64-v8a/libc++_static.a \
+                  /opt/android-ndk-r17c/sources/cxx-stl/llvm-libc++/libs/arm64-v8a/libc++abi.a
+    INCLUDES = $(SYSTEM_INCLUDES) -I/opt/android-ndk-r17c/sysroot/usr/include/aarch64-linux-android $(THIRD_PARTY_INCLUDES)
+else
+    CC = /opt/android-ndk-r17c/toolchains/arm-linux-androideabi-4.9/prebuilt/linux-x86_64/bin/arm-linux-androideabi-g++
+    CXX_FLAGS = -march=armv7-a -mthumb -mfpu=neon -mfloat-abi=softfp -funwind-tables -no-canonical-prefixes \
+                -D__ANDROID_API__=22 -fexceptions -frtti -std=c++11 -fopenmp -O3 -DNDEBUG -fPIE
+    CXXFLAGS_LINK = $(CXX_FLAGS) -pie -Wl,--fix-cortex-a8 -Wl,--gc-sections -Wl,-z,nocopyreloc
+    SYSROOT_LINK = --sysroot=/opt/android-ndk-r17c/platforms/android-22/arch-arm
+    SYSTEM_LIBS = /opt/android-ndk-r17c/sources/cxx-stl/llvm-libc++/libs/armeabi-v7a/libc++_static.a \
+                  /opt/android-ndk-r17c/sources/cxx-stl/llvm-libc++/libs/armeabi-v7a/libc++abi.a \
+                  /opt/android-ndk-r17c/sources/cxx-stl/llvm-libc++/libs/armeabi-v7a/libandroid_support.a \
+                  /opt/android-ndk-r17c/sources/cxx-stl/llvm-libc++/libs/armeabi-v7a/libunwind.a
+    INCLUDES = $(SYSTEM_INCLUDES) -I/opt/android-ndk-r17c/sysroot/usr/include/arm-linux-androideabi $(THIRD_PARTY_INCLUDES)
+endif
diff --git a/lite/demo/cxx/README.md b/lite/demo/cxx/README.md
new file mode 100644
index 00000000000..ec72c044e3f
--- /dev/null
+++ b/lite/demo/cxx/README.md
@@ -0,0 +1,42 @@
+# C++ Demo
+1. Build the docker image from `lite/tools/Dockerfile.mobile`.
+2. Start and enter the docker container, then run `wget http://paddle-inference-dist.bj.bcebos.com/lite_release/r0.1/inference_lite_lib.android.armv8.tar.gz` to download the demo environment. (For the armv7 demo, download it with `wget http://paddle-inference-dist.bj.bcebos.com/lite_release/r0.1/inference_lite_lib.android.armv7.tar.gz`.)
+3. Unpack the downloaded archive: `tar zxvf inference_lite_lib.android.armv8.tar.gz`
+4. Prepare the emulator environment with the following commands:
+```shell
+# armv8
+adb kill-server
+adb devices | grep emulator | cut -f1 | while read line; do adb -s $line emu kill; done
+echo n | avdmanager create avd -f -n paddle-armv8 -k "system-images;android-24;google_apis;arm64-v8a"
+echo -ne '\n' | ${ANDROID_HOME}/emulator/emulator -avd paddle-armv8 -noaudio -no-window -gpu off -port 5554 &
+sleep 1m
+```
+```shell
+# armv7
+adb kill-server
+adb devices | grep emulator | cut -f1 | while read line; do adb -s $line emu kill; done
+echo n | avdmanager create avd -f -n paddle-armv7 -k "system-images;android-24;google_apis;armeabi-v7a"
+echo -ne '\n' | ${ANDROID_HOME}/emulator/emulator -avd paddle-armv7 -noaudio -no-window -gpu off -port 5554 &
+sleep 1m
+```
+5. Prepare the model, then build and run the full-API demo:
+```shell
+cd inference_lite_lib.android.armv8/demo/cxx/mobile_full
+wget http://paddle-inference-dist.bj.bcebos.com/mobilenet_v1.tar.gz
+tar zxvf mobilenet_v1.tar.gz
+make
+adb -s emulator-5554 push mobilenet_v1 /data/local/tmp/
+adb -s emulator-5554 push mobilenetv1_full_api /data/local/tmp/
+adb -s emulator-5554 shell chmod +x /data/local/tmp/mobilenetv1_full_api
+adb -s emulator-5554 shell "/data/local/tmp/mobilenetv1_full_api --model_dir=/data/local/tmp/mobilenet_v1 --optimized_model_dir=/data/local/tmp/mobilenet_v1.opt"
+```
+On success, the demo prints the predicted probabilities of the first 10 classes to the console.
+
+6. Build and run the light-API demo:
+```shell
+cd ../mobile_light
+make
+adb -s emulator-5554 push mobilenetv1_light_api /data/local/tmp/
+adb -s emulator-5554 shell chmod +x /data/local/tmp/mobilenetv1_light_api
+adb -s emulator-5554 shell "/data/local/tmp/mobilenetv1_light_api --model_dir=/data/local/tmp/mobilenet_v1.opt"
+```
diff --git a/lite/demo/cxx/makefiles/mobile_full/Makefile.android.armv7 b/lite/demo/cxx/makefiles/mobile_full/Makefile.android.armv7
new file mode 100644
index 00000000000..6c9b7413f49
--- /dev/null
+++ b/lite/demo/cxx/makefiles/mobile_full/Makefile.android.armv7
@@ -0,0 +1,22 @@
+ARM_ABI = arm7
+export ARM_ABI
+
+include ../Makefile.def
+
+LITE_ROOT=../../../
+
+CXX_INCLUDES = $(INCLUDES) -I$(LITE_ROOT)/cxx/include
+
+CXX_LIBS = $(THIRD_PARTY_LIBS) $(LITE_ROOT)/cxx/lib/libpaddle_api_full_bundled.a $(SYSTEM_LIBS)
+
+mobilenetv1_full_api: mobilenetv1_full_api.o
+	$(CC) $(SYSROOT_LINK) $(CXXFLAGS_LINK) mobilenetv1_full_api.o -o mobilenetv1_full_api $(CXX_LIBS) $(LDFLAGS)
+
+mobilenetv1_full_api.o: mobilenetv1_full_api.cc
+	$(CC) $(SYSROOT_COMPLILE) $(CXX_DEFINES) $(CXX_INCLUDES) $(CXX_FLAGS) -o mobilenetv1_full_api.o -c mobilenetv1_full_api.cc
+
+
+.PHONY: clean
+clean:
+	rm mobilenetv1_full_api.o
+	rm mobilenetv1_full_api
diff --git a/lite/demo/cxx/makefiles/mobile_full/Makefile.android.armv8 b/lite/demo/cxx/makefiles/mobile_full/Makefile.android.armv8
new file mode 100644
index 00000000000..7735f74d109
--- /dev/null
+++ b/lite/demo/cxx/makefiles/mobile_full/Makefile.android.armv8
@@ -0,0 +1,22 @@
+ARM_ABI = arm8
+export ARM_ABI
+
+include ../Makefile.def
+
+LITE_ROOT=../../../
+
+CXX_INCLUDES = $(INCLUDES) -I$(LITE_ROOT)/cxx/include
+
+CXX_LIBS = $(THIRD_PARTY_LIBS) $(LITE_ROOT)/cxx/lib/libpaddle_api_full_bundled.a $(SYSTEM_LIBS)
+
+mobilenetv1_full_api: mobilenetv1_full_api.o
+	$(CC) $(SYSROOT_LINK) $(CXXFLAGS_LINK) mobilenetv1_full_api.o -o mobilenetv1_full_api $(CXX_LIBS) $(LDFLAGS)
+
+mobilenetv1_full_api.o: mobilenetv1_full_api.cc
+	$(CC) $(SYSROOT_COMPLILE) $(CXX_DEFINES) $(CXX_INCLUDES) $(CXX_FLAGS) -o mobilenetv1_full_api.o -c mobilenetv1_full_api.cc
+
+
+.PHONY: clean
+clean:
+	rm mobilenetv1_full_api.o
+	rm mobilenetv1_full_api
diff --git a/lite/demo/cxx/makefiles/mobile_light/Makefile.android.armv7 b/lite/demo/cxx/makefiles/mobile_light/Makefile.android.armv7
new file mode 100644
index 00000000000..66a6d8f31dc
--- /dev/null
+++ b/lite/demo/cxx/makefiles/mobile_light/Makefile.android.armv7
@@ -0,0 +1,22 @@
+ARM_ABI = arm7
+export ARM_ABI
+
+include ../Makefile.def
+
+LITE_ROOT=../../../
+
+CXX_INCLUDES = $(INCLUDES) -I$(LITE_ROOT)/cxx/include
+
+CXX_LIBS = $(THIRD_PARTY_LIBS) $(LITE_ROOT)/cxx/lib/libpaddle_api_light_bundled.a $(SYSTEM_LIBS)
+
+mobilenetv1_light_api: mobilenetv1_light_api.o
+	$(CC) $(SYSROOT_LINK) $(CXXFLAGS_LINK) mobilenetv1_light_api.o -o mobilenetv1_light_api $(CXX_LIBS) $(LDFLAGS)
+
+mobilenetv1_light_api.o: mobilenetv1_light_api.cc
+	$(CC) $(SYSROOT_COMPLILE) $(CXX_DEFINES) $(CXX_INCLUDES) $(CXX_FLAGS) -o mobilenetv1_light_api.o -c mobilenetv1_light_api.cc
+
+
+.PHONY: clean
+clean:
+	rm mobilenetv1_light_api.o
+	rm mobilenetv1_light_api
diff --git a/lite/demo/cxx/makefiles/mobile_light/Makefile.android.armv8 b/lite/demo/cxx/makefiles/mobile_light/Makefile.android.armv8
new file mode 100644
index 00000000000..91b281c49c8
--- /dev/null
+++ b/lite/demo/cxx/makefiles/mobile_light/Makefile.android.armv8
@@ -0,0 +1,22 @@
+ARM_ABI = arm8
+export ARM_ABI
+
+include ../Makefile.def
+
+LITE_ROOT=../../../
+
+CXX_INCLUDES = $(INCLUDES) -I$(LITE_ROOT)/cxx/include
+
+CXX_LIBS = $(THIRD_PARTY_LIBS) $(LITE_ROOT)/cxx/lib/libpaddle_api_light_bundled.a $(SYSTEM_LIBS)
+
+mobilenetv1_light_api: mobilenetv1_light_api.o
+	$(CC) $(SYSROOT_LINK) $(CXXFLAGS_LINK) mobilenetv1_light_api.o -o mobilenetv1_light_api $(CXX_LIBS) $(LDFLAGS)
+
+mobilenetv1_light_api.o: mobilenetv1_light_api.cc
+	$(CC) $(SYSROOT_COMPLILE) $(CXX_DEFINES) $(CXX_INCLUDES) $(CXX_FLAGS) -o mobilenetv1_light_api.o -c mobilenetv1_light_api.cc
+
+
+.PHONY: clean
+clean:
+	rm mobilenetv1_light_api.o
+	rm mobilenetv1_light_api
diff --git a/lite/demo/cxx/mobile_full/mobilenetv1_full_api.cc b/lite/demo/cxx/mobile_full/mobilenetv1_full_api.cc
new file mode 100644
index 00000000000..43e3662982e
--- /dev/null
+++ b/lite/demo/cxx/mobile_full/mobilenetv1_full_api.cc
@@ -0,0 +1,73 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <gflags/gflags.h>
+#include <stdio.h>
+#include <vector>
+#include "paddle_api.h"          // NOLINT
+#include "paddle_use_kernels.h"  // NOLINT
+#include "paddle_use_ops.h"      // NOLINT
+#include "paddle_use_passes.h"   // NOLINT
+
+using namespace paddle::lite_api;  // NOLINT
+
+DEFINE_string(model_dir, "", "Model dir path.");
+DEFINE_string(optimized_model_dir, "", "Optimized model dir.");
+
+int64_t ShapeProduction(const shape_t& shape) {
+  int64_t res = 1;
+  for (auto i : shape) res *= i;
+  return res;
+}
+
+void RunModel() {
+  // 1. Set CxxConfig
+  CxxConfig config;
+  config.set_model_dir(FLAGS_model_dir);
+  config.set_preferred_place(Place{TARGET(kARM), PRECISION(kFloat)});
+  config.set_valid_places({Place{TARGET(kARM), PRECISION(kFloat)}});
+
+  // 2. Create PaddlePredictor by CxxConfig
+  std::shared_ptr<PaddlePredictor> predictor =
+      CreatePaddlePredictor<CxxConfig>(config);
+
+  // 3. Prepare input data
+  std::unique_ptr<Tensor> input_tensor(std::move(predictor->GetInput(0)));
+  input_tensor->Resize(shape_t({1, 3, 224, 224}));
+  auto* data = input_tensor->mutable_data<float>();
+  for (int i = 0; i < ShapeProduction(input_tensor->shape()); ++i) {
+    data[i] = 1;
+  }
+
+  // 4. Run predictor
+  predictor->Run();
+
+  // 5. Get output
+  std::unique_ptr<const Tensor> output_tensor(
+      std::move(predictor->GetOutput(0)));
+  printf("Output dim: %d\n", static_cast<int>(output_tensor->shape()[1]));
+  for (int i = 0; i < ShapeProduction(output_tensor->shape()); i += 100) {
+    printf("Output[%d]: %f\n", i, output_tensor->data<float>()[i]);
+  }
+
+  // 6. Save the optimized model
+  predictor->SaveOptimizedModel(FLAGS_optimized_model_dir,
+                                LiteModelType::kNaiveBuffer);
+}
+
+int main(int argc, char** argv) {
+  google::ParseCommandLineFlags(&argc, &argv, true);
+  RunModel();
+  return 0;
+}
diff --git a/lite/demo/cxx/mobile_light/mobilenetv1_light_api.cc b/lite/demo/cxx/mobile_light/mobilenetv1_light_api.cc
new file mode 100644
index 00000000000..e1833814cad
--- /dev/null
+++ b/lite/demo/cxx/mobile_light/mobilenetv1_light_api.cc
@@ -0,0 +1,65 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <gflags/gflags.h>
+#include <stdio.h>
+#include <vector>
+#include "paddle_api.h"          // NOLINT
+#include "paddle_use_kernels.h"  // NOLINT
+#include "paddle_use_ops.h"      // NOLINT
+
+using namespace paddle::lite_api;  // NOLINT
+
+DEFINE_string(model_dir, "", "Model dir path.");
+
+int64_t ShapeProduction(const shape_t& shape) {
+  int64_t res = 1;
+  for (auto i : shape) res *= i;
+  return res;
+}
+
+void RunModel() {
+  // 1. Set MobileConfig
+  MobileConfig config;
+  config.set_model_dir(FLAGS_model_dir);
+
+  // 2. Create PaddlePredictor by MobileConfig
+  std::shared_ptr<PaddlePredictor> predictor =
+      CreatePaddlePredictor<MobileConfig>(config);
+
+  // 3. Prepare input data
+  std::unique_ptr<Tensor> input_tensor(std::move(predictor->GetInput(0)));
+  input_tensor->Resize({1, 3, 224, 224});
+  auto* data = input_tensor->mutable_data<float>();
+  for (int i = 0; i < ShapeProduction(input_tensor->shape()); ++i) {
+    data[i] = 1;
+  }
+
+  // 4. Run predictor
+  predictor->Run();
+
+  // 5. Get output
+  std::unique_ptr<const Tensor> output_tensor(
+      std::move(predictor->GetOutput(0)));
+  printf("Output dim: %d\n", static_cast<int>(output_tensor->shape()[1]));
+  for (int i = 0; i < ShapeProduction(output_tensor->shape()); i += 100) {
+    printf("Output[%d]: %f\n", i, output_tensor->data<float>()[i]);
+  }
+}
+
+int main(int argc, char** argv) {
+  google::ParseCommandLineFlags(&argc, &argv, true);
+  RunModel();
+  return 0;
+}
diff --git a/lite/demo/java/README.md b/lite/demo/java/README.md
new file mode 100644
index 00000000000..904726d744b
--- /dev/null
+++ b/lite/demo/java/README.md
@@ -0,0 +1,118 @@
+# Java Android Demo
+
+To build and run the Android demo app PaddlePredictor under the ./android folder, you need:
+
+1. An Android phone capable of running Android apps
+2. A development machine with Android Studio installed
+
+## Building
+
+In the PaddleLite development docker image, pull the latest PaddleLite code and build the
+inference library for your phone's architecture; below we use arm8 as the example. Enter the
+paddlelite directory and run the following cmake and make commands:
+
+```
+mkdir -p build.lite.android.arm8.gcc
+cd build.lite.android.arm8.gcc
+
+cmake .. \
+-DWITH_GPU=OFF \
+-DWITH_MKL=OFF \
+-DWITH_LITE=ON \
+-DLITE_WITH_JAVA=ON \
+-DLITE_WITH_CUDA=OFF \
+-DLITE_WITH_X86=OFF \
+-DLITE_WITH_ARM=ON \
+-DLITE_WITH_LIGHT_WEIGHT_FRAMEWORK=ON \
+-DWITH_TESTING=OFF \
+-DLITE_SHUTDOWN_LOG=ON \
+-DLITE_ON_TINY_PUBLISH=ON \
+-DARM_TARGET_OS=android -DARM_TARGET_ARCH_ABI=armv8 -DARM_TARGET_LANG=gcc
+
+make publish_inference -j4
+```
+
+After make finishes, check that the following two files exist:
+```
+build.lite.android.arm8.gcc/lite/api/android/jni/native/libpaddle_lite_jni.so
+build.lite.android.arm8.gcc/lite/api/android/jni/PaddlePredictor.jar
+```
+They are the PaddleLite C++ shared library and the Java jar package, which contain the
+PaddleLite Java API; the Android Java code below uses these APIs.
+
+## Preparing the other files the demo needs
+
+Besides the code, the demo needs the JNI .so library (`libpaddle_lite_jni.so` mentioned above),
+the Java .jar package (`PaddlePredictor.jar` mentioned above), and the model files. We provide
+both an automated script and a manual copy procedure; choose whichever fits your needs.
+
+### Script method
+
+Enter `build.lite.android.armv8/inference_lite_lib.android.armv8/demo/java/android/`, where we
+provide a script `prepare_demo.bash`. The script takes one argument: the name of the
+architecture folder that the .so you want to copy belongs to.
+
+For example, run
+```
+bash prepare_demo.bash armv8
+```
+The script downloads and unpacks the model files, copies the .jar into the demo, and copies the
+generated .so into `PaddlePredictor/app/src/main/jniLibs/<architecture folder>`; in this
+example, armv8 is the architecture folder. Note: a demo built this way runs correctly on armv8
+phones. To make the demo also run on other phone architectures (such as armv7), add those
+architectures as well.
+
+### Manual copy method
+
+The manual copy steps follow; if you used the script, you can skip them.
+
+### Copy the .so library and .jar into the Android demo app:
+
+Load demo/PaddlePredictor under this folder into Android Studio. Copy the
+`libpaddle_lite_jni.so` mentioned above into `PaddlePredictor/app/src/main/jniLibs/<architecture
+folder>`; for example, the arm8 folder should contain that .so file. Copy the
+`PaddlePredictor.jar` mentioned above into `PaddlePredictor/app/libs`.
+
+### Copy the model files used by the demo into the Android app:
+
+Download our five model files and unpack them into `PaddlePredictor/app/src/main/assets`.
+The model files to copy and their download addresses:
+
+    inception_v4_simple_opt.nb http://paddle-inference-dist.bj.bcebos.com/inception_v4_simple_opt.nb.tar.gz
+    lite_naive_model_opt.nb http://paddle-inference-dist.bj.bcebos.com/lite_naive_model_opt.nb.tar.gz
+    mobilenet_v1_opt.nb http://paddle-inference-dist.bj.bcebos.com/mobilenet_v1_opt.nb.tar.gz
+    mobilenet_v2_relu_opt.nb http://paddle-inference-dist.bj.bcebos.com/mobilenet_v2_relu_opt.nb.tar.gz
+    resnet50_opt.nb http://paddle-inference-dist.bj.bcebos.com/resnet50_opt.nb.tar.gz
+
+After downloading, the assets folder should contain the five unpacked model folders above; the
+demo does not need to keep the original .tar.gz archives.
+
+## Running the Android app
+
+With the preparation above done, you can build, install, and run the Android demo. When you run
+the PaddlePredictor app, it takes roughly 10 seconds and then shows something like:
+
+    lite_naive_model output: 50.213173, -28.872887
+    expected: 50.2132, -28.8729
+
+    inception_v4_simple test:true
+    time: xxx ms
+
+    resnet50 test:true
+    time: xxx ms
+
+    mobilenet_v1 test:true
+    time: xxx ms
+
+    mobilenet_v2 test:true
+    time: xxx ms
+
+The demo runs our five models. For the first model it prints the first two output numbers, with
+the expected correct values on the second line; the error should be below 0.001. For the other
+four models, `test:true` means the model's output passed the checks the demo applies to it, and
+`time` is how long that test took.
+
+## Instrumented tests of the Android demo
+
+This section is for testers who want to drive the demo from the command line.
+
+To run the demo on a phone from the command line, enter the demo's `PaddlePredictor` folder and
+run
+```
+./gradlew init
+```
+The command above only needs to run once; it initializes the tasks the demo needs. Afterwards,
+run our tests with
+```
+./gradlew connectedAndroidTest
+```
diff --git a/lite/demo/java/android/PaddlePredictor/.gitignore b/lite/demo/java/android/PaddlePredictor/.gitignore
new file mode 100644
index 00000000000..2b75303ac58
--- /dev/null
+++ b/lite/demo/java/android/PaddlePredictor/.gitignore
@@ -0,0 +1,13 @@
+*.iml
+.gradle
+/local.properties
+/.idea/caches
+/.idea/libraries
+/.idea/modules.xml
+/.idea/workspace.xml
+/.idea/navEditor.xml
+/.idea/assetWizardSettings.xml +.DS_Store +/build +/captures +.externalNativeBuild diff --git a/lite/demo/java/android/PaddlePredictor/app/.gitignore b/lite/demo/java/android/PaddlePredictor/app/.gitignore new file mode 100644 index 00000000000..796b96d1c40 --- /dev/null +++ b/lite/demo/java/android/PaddlePredictor/app/.gitignore @@ -0,0 +1 @@ +/build diff --git a/lite/demo/java/android/PaddlePredictor/app/build.gradle b/lite/demo/java/android/PaddlePredictor/app/build.gradle new file mode 100644 index 00000000000..b86d2f8e3dd --- /dev/null +++ b/lite/demo/java/android/PaddlePredictor/app/build.gradle @@ -0,0 +1,28 @@ +apply plugin: 'com.android.application' + +android { + compileSdkVersion 28 + defaultConfig { + applicationId "com.baidu.paddle.lite" + minSdkVersion 23 + targetSdkVersion 28 + versionCode 1 + versionName "1.0" + testInstrumentationRunner "android.support.test.runner.AndroidJUnitRunner" + } + buildTypes { + release { + minifyEnabled false + proguardFiles getDefaultProguardFile('proguard-android-optimize.txt'), 'proguard-rules.pro' + } + } +} + +dependencies { + implementation fileTree(dir: 'libs', include: ['*.jar']) + implementation 'com.android.support:appcompat-v7:28.0.0' + implementation 'com.android.support.constraint:constraint-layout:1.1.3' + testImplementation 'junit:junit:4.12' + androidTestImplementation 'com.android.support.test:runner:1.0.2' + androidTestImplementation 'com.android.support.test.espresso:espresso-core:3.0.2' +} diff --git a/lite/demo/java/android/PaddlePredictor/app/proguard-rules.pro b/lite/demo/java/android/PaddlePredictor/app/proguard-rules.pro new file mode 100644 index 00000000000..f1b424510da --- /dev/null +++ b/lite/demo/java/android/PaddlePredictor/app/proguard-rules.pro @@ -0,0 +1,21 @@ +# Add project specific ProGuard rules here. +# You can control the set of applied configuration files using the +# proguardFiles setting in build.gradle. +# +# For more details, see +# http://developer.android.com/guide/developing/tools/proguard.html + +# If your project uses WebView with JS, uncomment the following +# and specify the fully qualified class name to the JavaScript interface +# class: +#-keepclassmembers class fqcn.of.javascript.interface.for.webview { +# public *; +#} + +# Uncomment this to preserve the line number information for +# debugging stack traces. +#-keepattributes SourceFile,LineNumberTable + +# If you keep the line number information, uncomment this to +# hide the original source file name. 
+#-renamesourcefileattribute SourceFile diff --git a/lite/demo/java/android/PaddlePredictor/app/src/androidTest/java/com/baidu/paddle/lite/ExampleInstrumentedTest.java b/lite/demo/java/android/PaddlePredictor/app/src/androidTest/java/com/baidu/paddle/lite/ExampleInstrumentedTest.java new file mode 100644 index 00000000000..ca40855be7a --- /dev/null +++ b/lite/demo/java/android/PaddlePredictor/app/src/androidTest/java/com/baidu/paddle/lite/ExampleInstrumentedTest.java @@ -0,0 +1,114 @@ +package com.baidu.paddle.lite; + +import android.content.Context; +import android.support.test.InstrumentationRegistry; +import android.support.test.runner.AndroidJUnit4; + +import org.junit.Test; +import org.junit.runner.RunWith; + +import java.util.ArrayList; + +import static org.junit.Assert.*; + +/** + * Lite example Instrument test + */ +@RunWith(AndroidJUnit4.class) +public class ExampleInstrumentedTest { + @Test + public void naiveModel_isCorrect() { + Context appContext = InstrumentationRegistry.getTargetContext(); + ArrayList result = MainActivity.setInputAndRunNaiveModel("lite_naive_model", appContext); + Tensor output = result.get(0); + long[] shape = output.shape(); + assertEquals(2, shape.length); + assertEquals(100L, shape[0]); + assertEquals(500L, shape[1]); + + float[] outputBuffer = output.getFloatData(); + assertEquals(50000, outputBuffer.length); + assertEquals(50.2132f, outputBuffer[0], 1e-4); + assertEquals(-28.8729, outputBuffer[1], 1e-4); + } + + @Test + public void inceptionV4Simple_isCorrect() { + Context appContext = InstrumentationRegistry.getTargetContext(); + ArrayList result = MainActivity.setInputAndRunImageModel("inception_v4_simple", appContext); + float[] expected = {0.0011684548f, 0.0010390386f, 0.0011301535f, 0.0010133048f, + 0.0010259597f, 0.0010982729f, 0.00093195855f, 0.0009141837f, + 0.00096620916f, 0.00089982944f, 0.0010064574f, 0.0010474789f, + 0.0009782845f, 0.0009230255f, 0.0010548076f, 0.0010974824f, + 0.0010612885f, 0.00089107914f, 0.0010112736f, 0.00097655767f}; + assertImageResult(expected, result); + } + + @Test + public void mobilenetV1_isCorrect() { + Context appContext = InstrumentationRegistry.getTargetContext(); + ArrayList result = MainActivity.setInputAndRunImageModel("mobilenet_v1", appContext); + float[] expected = {0.00019130898f, 9.467885e-05f, 0.00015971427f, 0.0003650665f, + 0.00026431272f, 0.00060884043f, 0.0002107942f, 0.0015819625f, + 0.0010323516f, 0.00010079765f, 0.00011006987f, 0.0017364529f, + 0.0048292773f, 0.0013995157f, 0.0018453331f, 0.0002428986f, + 0.00020211363f, 0.00013668182f, 0.0005855956f, 0.00025901722f}; + assertImageResult(expected, result); + } + + @Test + public void mobilenetV2Relu_isCorrect() { + Context appContext = InstrumentationRegistry.getTargetContext(); + ArrayList result = MainActivity.setInputAndRunImageModel("mobilenet_v2_relu", appContext); + float[] expected = {0.00017082224f, 5.699624e-05f, 0.000260885f, 0.00016412718f, + 0.00034818667f, 0.00015230637f, 0.00032959113f, 0.0014772735f, + 0.0009059976f, 9.5378724e-05f, 5.386537e-05f, 0.0006427285f, + 0.0070957416f, 0.0016094646f, 0.0018807327f, 0.00010506048f, + 6.823785e-05f, 0.00012269315f, 0.0007806194f, 0.00022354358f}; + assertImageResult(expected, result); + } + + @Test + public void resnet50_isCorrect() { + Context appContext = InstrumentationRegistry.getTargetContext(); + ArrayList result = MainActivity.setInputAndRunImageModel("resnet50", appContext); + float[] expected = {0.00024139918f, 0.00020566184f, 0.00022418296f, 0.00041731037f, + 0.0005366107f, 
diff --git a/lite/demo/java/android/PaddlePredictor/app/src/androidTest/java/com/baidu/paddle/lite/ExampleInstrumentedTest.java b/lite/demo/java/android/PaddlePredictor/app/src/androidTest/java/com/baidu/paddle/lite/ExampleInstrumentedTest.java
new file mode 100644
index 00000000000..ca40855be7a
--- /dev/null
+++ b/lite/demo/java/android/PaddlePredictor/app/src/androidTest/java/com/baidu/paddle/lite/ExampleInstrumentedTest.java
@@ -0,0 +1,99 @@
+package com.baidu.paddle.lite;
+
+import android.content.Context;
+import android.support.test.InstrumentationRegistry;
+import android.support.test.runner.AndroidJUnit4;
+
+import org.junit.Test;
+import org.junit.runner.RunWith;
+
+import static org.junit.Assert.*;
+
+/**
+ * Lite example instrumented test, which executes on an Android device.
+ */
+@RunWith(AndroidJUnit4.class)
+public class ExampleInstrumentedTest {
+    @Test
+    public void naiveModel_isCorrect() {
+        Context appContext = InstrumentationRegistry.getTargetContext();
+        Tensor output = MainActivity.setInputAndRunNaiveModel("lite_naive_model", appContext);
+        long[] shape = output.shape();
+        assertEquals(2, shape.length);
+        assertEquals(100L, shape[0]);
+        assertEquals(500L, shape[1]);
+
+        float[] outputBuffer = output.getFloatData();
+        assertEquals(50000, outputBuffer.length);
+        assertEquals(50.2132f, outputBuffer[0], 1e-4f);
+        assertEquals(-28.8729f, outputBuffer[1], 1e-4f);
+    }
+
+    @Test
+    public void inceptionV4Simple_isCorrect() {
+        Context appContext = InstrumentationRegistry.getTargetContext();
+        Tensor output = MainActivity.setInputAndRunImageModel("inception_v4_simple", appContext);
+        float[] expected = {0.0011684548f, 0.0010390386f, 0.0011301535f, 0.0010133048f,
+                0.0010259597f, 0.0010982729f, 0.00093195855f, 0.0009141837f,
+                0.00096620916f, 0.00089982944f, 0.0010064574f, 0.0010474789f,
+                0.0009782845f, 0.0009230255f, 0.0010548076f, 0.0010974824f,
+                0.0010612885f, 0.00089107914f, 0.0010112736f, 0.00097655767f};
+        assertImageResult(expected, output);
+    }
+
+    @Test
+    public void mobilenetV1_isCorrect() {
+        Context appContext = InstrumentationRegistry.getTargetContext();
+        Tensor output = MainActivity.setInputAndRunImageModel("mobilenet_v1", appContext);
+        float[] expected = {0.00019130898f, 9.467885e-05f, 0.00015971427f, 0.0003650665f,
+                0.00026431272f, 0.00060884043f, 0.0002107942f, 0.0015819625f,
+                0.0010323516f, 0.00010079765f, 0.00011006987f, 0.0017364529f,
+                0.0048292773f, 0.0013995157f, 0.0018453331f, 0.0002428986f,
+                0.00020211363f, 0.00013668182f, 0.0005855956f, 0.00025901722f};
+        assertImageResult(expected, output);
+    }
+
+    @Test
+    public void mobilenetV2Relu_isCorrect() {
+        Context appContext = InstrumentationRegistry.getTargetContext();
+        Tensor output = MainActivity.setInputAndRunImageModel("mobilenet_v2_relu", appContext);
+        float[] expected = {0.00017082224f, 5.699624e-05f, 0.000260885f, 0.00016412718f,
+                0.00034818667f, 0.00015230637f, 0.00032959113f, 0.0014772735f,
+                0.0009059976f, 9.5378724e-05f, 5.386537e-05f, 0.0006427285f,
+                0.0070957416f, 0.0016094646f, 0.0018807327f, 0.00010506048f,
+                6.823785e-05f, 0.00012269315f, 0.0007806194f, 0.00022354358f};
+        assertImageResult(expected, output);
+    }
+
+    @Test
+    public void resnet50_isCorrect() {
+        Context appContext = InstrumentationRegistry.getTargetContext();
+        Tensor output = MainActivity.setInputAndRunImageModel("resnet50", appContext);
+        float[] expected = {0.00024139918f, 0.00020566184f, 0.00022418296f, 0.00041731037f,
+                0.0005366107f, 0.00016948722f, 0.00028638865f, 0.0009257241f,
+                0.00072681636f, 8.531815e-05f, 0.0002129998f, 0.0021168243f,
+                0.006387163f, 0.0037145028f, 0.0012812682f, 0.00045948103f,
+                0.00013535398f, 0.0002483765f, 0.00076759676f, 0.0002773295f};
+        assertImageResult(expected, output);
+    }
+
+    // Checks a 1x1000 image-model output against 20 reference scores sampled
+    // every 50 classes.
+    public void assertImageResult(float[] expected, Tensor result) {
+        assertEquals(20, expected.length);
+
+        long[] shape = result.shape();
+        assertEquals(2, shape.length);
+        assertEquals(1L, shape[0]);
+        assertEquals(1000L, shape[1]);
+
+        float[] output = result.getFloatData();
+        assertEquals(1000, output.length);
+
+        int step = 50;
+        for (int i = 0; i < expected.length; ++i) {
+            assertEquals(expected[i], output[i * step], 1e-6f);
+        }
+    }
+}
+
diff --git a/lite/demo/java/android/PaddlePredictor/app/src/main/AndroidManifest.xml b/lite/demo/java/android/PaddlePredictor/app/src/main/AndroidManifest.xml
new file mode 100644
index 00000000000..240078a5877
--- /dev/null
+++ b/lite/demo/java/android/PaddlePredictor/app/src/main/AndroidManifest.xml
@@ -0,0 +1,21 @@
+<?xml version="1.0" encoding="utf-8"?>
+<manifest xmlns:android="http://schemas.android.com/apk/res/android"
+    package="com.baidu.paddle.lite">
+
+    <application
+        android:allowBackup="true"
+        android:icon="@mipmap/ic_launcher"
+        android:label="@string/app_name"
+        android:roundIcon="@mipmap/ic_launcher_round"
+        android:supportsRtl="true"
+        android:theme="@style/AppTheme">
+        <activity android:name=".MainActivity">
+            <intent-filter>
+                <action android:name="android.intent.action.MAIN" />
+
+                <category android:name="android.intent.category.LAUNCHER" />
+            </intent-filter>
+        </activity>
+    </application>
+
+</manifest>
\ No newline at end of file
diff --git a/lite/demo/java/android/PaddlePredictor/app/src/main/assets/README.txt b/lite/demo/java/android/PaddlePredictor/app/src/main/assets/README.txt
new file mode 100644
index 00000000000..14aace8f9b4
--- /dev/null
+++ b/lite/demo/java/android/PaddlePredictor/app/src/main/assets/README.txt
@@ -0,0 +1,8 @@
+After building PaddleLite in your build folder, download and decompress the
+following models into this directory:
+
+inception_v4_simple_opt.nb http://paddle-inference-dist.bj.bcebos.com/inception_v4_simple_opt.nb.tar.gz
+lite_naive_model_opt.nb http://paddle-inference-dist.bj.bcebos.com/lite_naive_model_opt.nb.tar.gz
+mobilenet_v1_opt.nb http://paddle-inference-dist.bj.bcebos.com/mobilenet_v1_opt.nb.tar.gz
+mobilenet_v2_relu_opt.nb http://paddle-inference-dist.bj.bcebos.com/mobilenet_v2_relu_opt.nb.tar.gz
+resnet50_opt.nb http://paddle-inference-dist.bj.bcebos.com/resnet50_opt.nb.tar.gz
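For orientation before the activity code: the inference flow that MainActivity wraps below reduces to a handful of Lite Java API calls. A minimal sketch, assuming one of the optimized models above (here mobilenet_v1_opt.nb) has already been unpacked to a local path; the class name, path literal, and the 1x3x224x224 all-ones input are illustrative assumptions, not part of the patch:

    import com.baidu.paddle.lite.MobileConfig;
    import com.baidu.paddle.lite.PaddlePredictor;
    import com.baidu.paddle.lite.Tensor;

    public class MinimalLiteRunner {
        public static void main(String[] args) {
            // Point the mobile config at an optimized (*_opt.nb) model directory
            // (the path is an assumption for illustration).
            MobileConfig config = new MobileConfig();
            config.setModelDir("/data/local/tmp/mobilenet_v1_opt.nb");

            // Build a predictor from the config; the model is loaded here.
            PaddlePredictor predictor = PaddlePredictor.createPaddlePredictor(config);

            // Feed input 0 with a 1x3x224x224 all-ones image, as the demo's tests do.
            Tensor input = predictor.getInput(0);
            input.resize(new long[]{1, 3, 224, 224});
            float[] data = new float[3 * 224 * 224];
            java.util.Arrays.fill(data, 1.0f);
            input.setData(data);

            // Run inference and read back the 1x1000 classification scores.
            predictor.run();
            Tensor output = predictor.getOutput(0);
            float[] scores = output.getFloatData();
            System.out.println("first score: " + scores[0]);
        }
    }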
diff --git a/lite/demo/java/android/PaddlePredictor/app/src/main/java/com/baidu/paddle/lite/MainActivity.java b/lite/demo/java/android/PaddlePredictor/app/src/main/java/com/baidu/paddle/lite/MainActivity.java
new file mode 100644
index 00000000000..d2b04e6b91c
--- /dev/null
+++ b/lite/demo/java/android/PaddlePredictor/app/src/main/java/com/baidu/paddle/lite/MainActivity.java
@@ -0,0 +1,206 @@
+package com.baidu.paddle.lite;
+
+import android.content.Context;
+import android.support.v7.app.AppCompatActivity;
+import android.os.Bundle;
+import android.widget.TextView;
+
+import java.io.BufferedOutputStream;
+import java.io.File;
+import java.io.FileOutputStream;
+import java.io.InputStream;
+import java.io.OutputStream;
+import java.util.Date;
+
+public class MainActivity extends AppCompatActivity {
+
+    @Override
+    protected void onCreate(Bundle savedInstanceState) {
+        super.onCreate(savedInstanceState);
+        setContentView(R.layout.activity_main);
+
+        String textOutput = "";
+        Tensor output;
+        output = setInputAndRunNaiveModel("lite_naive_model_opt.nb", this);
+        textOutput += "lite_naive_model output: " + output.getFloatData()[0] + ", " +
+                output.getFloatData()[1] + "\n";
+        textOutput += "expected: 50.2132, -28.8729\n";
+
+        Date start = new Date();
+        output = setInputAndRunImageModel("inception_v4_simple_opt.nb", this);
+        Date end = new Date();
+        textOutput += "\ninception_v4_simple test: " + testInceptionV4Simple(output) + "\n";
+        textOutput += "time: " + (end.getTime() - start.getTime()) + " ms\n";
+
+        start = new Date();
+        output = setInputAndRunImageModel("resnet50_opt.nb", this);
+        end = new Date();
+        textOutput += "\nresnet50 test: " + testResnet50(output) + "\n";
+        textOutput += "time: " + (end.getTime() - start.getTime()) + " ms\n";
+
+        start = new Date();
+        output = setInputAndRunImageModel("mobilenet_v1_opt.nb", this);
+        end = new Date();
+        textOutput += "\nmobilenet_v1 test: " + testMobileNetV1(output) + "\n";
+        textOutput += "time: " + (end.getTime() - start.getTime()) + " ms\n";
+
+        start = new Date();
+        output = setInputAndRunImageModel("mobilenet_v2_relu_opt.nb", this);
+        end = new Date();
+        textOutput += "\nmobilenet_v2 test: " + testMobileNetV2Relu(output) + "\n";
+        textOutput += "time: " + (end.getTime() - start.getTime()) + " ms\n";
+
+        TextView textView = findViewById(R.id.text_view);
+        textView.setText(textOutput);
+    }
+
+    // The native predictor reads model files from a plain filesystem path, so
+    // models bundled as assets are first copied into the app's cache directory.
+    public static String copyFromAssetsToCache(String modelPath, Context context) {
+        String newPath = context.getCacheDir() + "/" + modelPath;
+        // String newPath = "/sdcard/" + modelPath;
+        File desDir = new File(newPath);
+
+        try {
+            if (!desDir.exists()) {
+                desDir.mkdir();
+            }
+            for (String fileName : context.getAssets().list(modelPath)) {
+                InputStream stream = context.getAssets().open(modelPath + "/" + fileName);
+                OutputStream output = new BufferedOutputStream(new FileOutputStream(newPath + "/" + fileName));
+
+                byte[] data = new byte[1024];
+                int count;
+
+                while ((count = stream.read(data)) != -1) {
+                    output.write(data, 0, count);
+                }
+
+                output.flush();
+                output.close();
+                stream.close();
+            }
+
+        } catch (Exception e) {
+            throw new RuntimeException(e);
+        }
+
+        return desDir.getPath();
+    }
+
+    public static Tensor runModel(String modelName, long[] dims, float[] inputBuffer, Context context) {
+        String modelPath = copyFromAssetsToCache(modelName, context);
+
+        MobileConfig config = new MobileConfig();
+        config.setModelDir(modelPath);
+        PaddlePredictor predictor = PaddlePredictor.createPaddlePredictor(config);
+
+        Tensor input = predictor.getInput(0);
+        input.resize(dims);
+        input.setData(inputBuffer);
+        predictor.run();
+
+        Tensor output = predictor.getOutput(0);
+
+        return output;
+    }
+
+    public static Tensor setInputAndRunNaiveModel(String modelName, Context context) {
+        long[] dims = {100, 100};
+        float[] inputBuffer = new float[10000];
+        for (int i = 0; i < 10000; ++i) {
+            inputBuffer[i] = i;
+        }
+        return runModel(modelName, dims, inputBuffer, context);
+    }
+
+    /**
+     * Runs an image model on an all-ones input of size 3 * 224 * 224.
+     *
+     * @param modelName model directory name under the app's assets
+     * @param context   context used to resolve the assets
+     * @return the first output tensor of the model
+     */
+    public static Tensor setInputAndRunImageModel(String modelName, Context context) {
+        long[] dims = {1, 3, 224, 224};
+        int item_size = 3 * 224 * 224;
+        float[] inputBuffer = new float[item_size];
+        for (int i = 0; i < item_size; ++i) {
+            inputBuffer[i] = 1;
+        }
+        return runModel(modelName, dims, inputBuffer, context);
+    }
+
+    public boolean equalsNear(float a, float b, float delta) {
+        return a >= b - delta && a <= b + delta;
+    }
+
+    public boolean expectedResult(float[] expected, Tensor result) {
+        if (expected.length != 20) {
+            return false;
+        }
+
+        long[] shape = result.shape();
+
+        if (shape.length
!= 2) { + return false; + } + + if (shape[0] != 1 || shape[1] != 1000) { + return false; + } + + float[] output = result.getFloatData(); + + if (output.length != 1000) { + return false; + } + + int step = 50; + for (int i = 0; i < expected.length; ++i) { + if (!equalsNear(output[i * step], expected[i], 1e-6f)) { + return false; + } + } + + return true; + } + + public boolean testInceptionV4Simple(Tensor output) { + float[] expected = {0.0011684548f, 0.0010390386f, 0.0011301535f, 0.0010133048f, + 0.0010259597f, 0.0010982729f, 0.00093195855f, 0.0009141837f, + 0.00096620916f, 0.00089982944f, 0.0010064574f, 0.0010474789f, + 0.0009782845f, 0.0009230255f, 0.0010548076f, 0.0010974824f, + 0.0010612885f, 0.00089107914f, 0.0010112736f, 0.00097655767f}; + return expectedResult(expected, output); + } + + public boolean testResnet50(Tensor output) { + float[] expected = {0.00024139918f, 0.00020566184f, 0.00022418296f, 0.00041731037f, + 0.0005366107f, 0.00016948722f, 0.00028638865f, 0.0009257241f, + 0.00072681636f, 8.531815e-05f, 0.0002129998f, 0.0021168243f, + 0.006387163f, 0.0037145028f, 0.0012812682f, 0.00045948103f, + 0.00013535398f, 0.0002483765f, 0.00076759676f, 0.0002773295f}; + return expectedResult(expected, output); + } + + public boolean testMobileNetV1(Tensor output) { + float[] expected = {0.00019130898f, 9.467885e-05f, 0.00015971427f, 0.0003650665f, + 0.00026431272f, 0.00060884043f, 0.0002107942f, 0.0015819625f, + 0.0010323516f, 0.00010079765f, 0.00011006987f, 0.0017364529f, + 0.0048292773f, 0.0013995157f, 0.0018453331f, 0.0002428986f, + 0.00020211363f, 0.00013668182f, 0.0005855956f, 0.00025901722f}; + return expectedResult(expected, output); + } + + public boolean testMobileNetV2Relu(Tensor output) { + float[] expected = {0.00017082224f, 5.699624e-05f, 0.000260885f, 0.00016412718f, + 0.00034818667f, 0.00015230637f, 0.00032959113f, 0.0014772735f, + 0.0009059976f, 9.5378724e-05f, 5.386537e-05f, 0.0006427285f, + 0.0070957416f, 0.0016094646f, 0.0018807327f, 0.00010506048f, + 6.823785e-05f, 0.00012269315f, 0.0007806194f, 0.00022354358f}; + return expectedResult(expected, output); + } + +} + diff --git a/lite/demo/java/android/PaddlePredictor/app/src/main/res/drawable-v24/ic_launcher_foreground.xml b/lite/demo/java/android/PaddlePredictor/app/src/main/res/drawable-v24/ic_launcher_foreground.xml new file mode 100644 index 00000000000..1f6bb290603 --- /dev/null +++ b/lite/demo/java/android/PaddlePredictor/app/src/main/res/drawable-v24/ic_launcher_foreground.xml @@ -0,0 +1,34 @@ + + + + + + + + + + + diff --git a/lite/demo/java/android/PaddlePredictor/app/src/main/res/drawable/ic_launcher_background.xml b/lite/demo/java/android/PaddlePredictor/app/src/main/res/drawable/ic_launcher_background.xml new file mode 100644 index 00000000000..0d025f9bf6b --- /dev/null +++ b/lite/demo/java/android/PaddlePredictor/app/src/main/res/drawable/ic_launcher_background.xml @@ -0,0 +1,170 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/lite/demo/java/android/PaddlePredictor/app/src/main/res/layout/activity_main.xml b/lite/demo/java/android/PaddlePredictor/app/src/main/res/layout/activity_main.xml new file mode 100644 index 00000000000..0d1e60b97e1 --- /dev/null +++ b/lite/demo/java/android/PaddlePredictor/app/src/main/res/layout/activity_main.xml @@ -0,0 +1,19 @@ + + + + + + \ No newline at end of file diff --git a/lite/demo/java/android/PaddlePredictor/app/src/main/res/mipmap-anydpi-v26/ic_launcher.xml 
b/lite/demo/java/android/PaddlePredictor/app/src/main/res/mipmap-anydpi-v26/ic_launcher.xml new file mode 100644 index 00000000000..eca70cfe52e --- /dev/null +++ b/lite/demo/java/android/PaddlePredictor/app/src/main/res/mipmap-anydpi-v26/ic_launcher.xml @@ -0,0 +1,5 @@ + + + + + \ No newline at end of file diff --git a/lite/demo/java/android/PaddlePredictor/app/src/main/res/mipmap-anydpi-v26/ic_launcher_round.xml b/lite/demo/java/android/PaddlePredictor/app/src/main/res/mipmap-anydpi-v26/ic_launcher_round.xml new file mode 100644 index 00000000000..eca70cfe52e --- /dev/null +++ b/lite/demo/java/android/PaddlePredictor/app/src/main/res/mipmap-anydpi-v26/ic_launcher_round.xml @@ -0,0 +1,5 @@ + + + + + \ No newline at end of file diff --git a/lite/demo/java/android/PaddlePredictor/app/src/main/res/mipmap-hdpi/ic_launcher.png b/lite/demo/java/android/PaddlePredictor/app/src/main/res/mipmap-hdpi/ic_launcher.png new file mode 100644 index 0000000000000000000000000000000000000000..898f3ed59ac9f3248734a00e5902736c9367d455 GIT binary patch literal 2963 zcmV;E3vBd>P)a+K}1d8+^p? z!e{m!F(8(%L-Or7x3OYORF&;mRAm8a^;km%J=s!AdNyc=+ezQqUM;oHYO18U%`T}O zHf$ra^L^sklEoIeAKmbOvX~v2@Y|vHs<^3JwwH?D$4l*XnPNs zMOqozmbkT?^lZ?$DjQ9%E0x+GsV=1PwZ&39Y}iI-$Fb3d%nsk+qrN@cV=OmQMEdF% z)iHMl(4Yu=cIkixWXtwMIV=>BvDSrHg8?)+vLJKozy*}$iE>&gGGonlG0cJhG&DRv ztzkg-AO(q)B7~G^EwE#tK@nqmJ}!(Bqtf z=eN{I?X#P!Xx=uL)D9cAk=b!~&@H~6S)=a?R4fDdP{-5E5X_!5&FwFJ^7&W2WS z;CnxBCOsSU^v-%(vad;MPukr;&+ciI+F`>sGCPiqHe`1A1|N0p^<|#<+iECwOG@y7 zBF$;;0YAhxtqK7O0SW;M0SW;ckbsQ#9QTYyC*g`2j%bA%1Zh^g9=9l*Cy!I^{_p2$PP2>j_D2AybM$NwY}iJ(ZH9O3 zlM8g4+dw;}V{dlY2EM^Z-Q(AmcmO|Ub1&3EFTS>iuHC#rcNo$wkB3@5c#lSunxsQ) zaA7tLFV3Oxk}X2`9qVL6?4fcq?f>Yk0E0IEcm0~^P5ovLLV$&D9ibbZTOt4ivg_<= zu^#q8tYJktl(egXwj4c3u6N&}S3mj_9pv5y{gQvL;&nM}TeNE{4K3O%_QAdpCAswa z`Ev>!oQREY9uPqL)g(QPVc1U`Q3An`+x_7g8edZ^0zdcpXNv7^!ZsgV{ugB){w+5&3-Wlp}yI7?tN)6*ST)-XSL4g8_rtDVlw+a zE+K|#(tV!KfQE22d-}7B(mLkHukIp4?na@q?%@4Kb%u!@F-ww?o?tn_Ohb zPi3Do`yL?Y$rDPYtEV;|250yzpS^rZT*TflAZ&YqC;by2Ul7NTZHKmC)9NA6Vv+>C%^1XhNlp5*!7zxTTKfHTPhe?@XbH=VzWEuCcmX z@L_&qCB;=(Xi;-D&DvT)kGOiMQ0&YQTezdH&j4D;U@#9&WiZClJThS7w)OHH^fIT| z+jn{&5bhMbynmM$P<0U*%ksp0WUy)=J!n9~WJ&YNn$e3{jMFOW6n~uqMHg+M3FY|#>(q)ZF;RS(xqTh>S1Ez_jfFig z#ivbPnZ26mv{5wdB5SFYrUNM5D?g-OsiZZK?hPof9gqf&7m!5-C=d>yOsw<)(t*G@h5zIY2saaEx|99pU%^#gvdI(Qqf>)zFjf zN}5zm9~oT`PmH~EF012{9eT8?4piYolF(86uiGy`^r#V4yu7SA-c zjm})#d$(Kx2|Yn~i19Fr<)Gs+1XaUIJs~G>kg>3 zkQ$CqUj*cb1ORzHKmZ`Ab2^0!}Qkq&-DC(S~W*1GV zw9}L-zX}y4ZLblxEO1qhqE9Q-IY{NmR+w+RDpB;$@R(PRjCP|D$yJ+BvI$!mIbb<+GQ3MGKxUdIY{N`DOv%} zWA){tEw8M2f!r&ugC6C5AMVXM=w7ej#c_{G;Obab=fD={ut@71RLCd*b?Y1+R_HMR zqYNuWxFqU^Yq9YB)SmxVgNKR;UMH207l5qNItP~xUO*YTsayf1g`)yAJoRV6f2$Fh z|A1cNgyW)@1ZJ!8eBC7gN$MOgAgg|zqX4pYgkw{E4wcr09u#3tt$JW@xgr2dT0piE zfSguooznr3CR>T88cu6RII0io!Z)mN2S3C%toVr+P`0PTJ>8yo4OoHX161h;q+jRY zs$2o2lgirxY2o-j$>c;3w)BT<1fb;PVV(V`cL*zHj5+On;kX@;0)6rF-I?1)gyZtM6}?#ji{u+_Jz`IW9a=87nIA3aK2~3iFMS zzYP&fCXLEibCzR_6R~#sKN@)HB>);Za`ud*QCaKG8jEwqgoknK7rwW`Cq?RYYE5r+ zh-YUqJ082>*;EG`_lhV^vHEM7d+5Y#e$d^rC*jx{U%h3B^nU%7N|*y`o4g{@w;KP-89>&W#h zTBB2vTk*S|My+4jYTPKdk6yR3b?nAfcd`FeC@gttYuGBEl9wuf8`rOD9VP6`bhNxR znvXql-3ssVUSXfvcf^2L5R-^4E-s=g|M$Wm!?BMl!51d{AS*7Ggjwh^YsbK?6jgCA5T=(9$oK{{z$fCe9x5IJ^J=002ov JPDHLkV1g@XpTGbB literal 0 HcmV?d00001 diff --git a/lite/demo/java/android/PaddlePredictor/app/src/main/res/mipmap-hdpi/ic_launcher_round.png b/lite/demo/java/android/PaddlePredictor/app/src/main/res/mipmap-hdpi/ic_launcher_round.png new file mode 100644 index 
0000000000000000000000000000000000000000..dffca3601eba7bf5f409bdd520820e2eb5122c75 GIT binary patch literal 4905 zcmV+^6V~jBP)sCJ+Khgs=qzz9*aFfTF@MBLc!81jy1$_D*`qMnYCeSOOSS zh~l6kD7e75FgOnvP=_arGNJ+k0uBt2?%a3It*Y+o?&`L?*#fV=?@xECZq+^KuXD~l z_tdQ>JOSF%q}x5h@>Id>gloHZ!fr_@%N)Qad* zI}<}@Poh`#X29>b50CkB%{yWf?z(t0rQf48W{j1a($$IrZ9{N{@#9Wqx}%DM^fL-m z`X#_s9{BwX>^};}KMtudHpmMyRCq34!+|XCtnqeli6}6}7JiE;H+GAtDViHuQ~X9` zP0^{y>Ov~ufreT-w7!yx_c;QOV>|0UxJK{lqSx`7cx`b!OLV*;Ez4q9Y_XdB$PKk4 z+Aq(kmz%WbOV3IpYsa0#_Vd?)>*2Lc zn) zvVw}USbx|rlL2LMl<$^rb@TnK-;J83fd3GKh6#=C5WlXv83lKz{0$(8x1g-%;q}$b z1=&8M<_eQZO4eJk#nshu9TsZZ11Z~hVkpt8oA4831ZP3Fj3C~EG*%gSnciYD-cpkI zj{J=o1Bg-kJrjfz${Js8D?vh>vJwR{=4)c@ZtTqt#tHRR<9b9ew~kVG6oc8(lNE=Pu>)F6HIf=`kIH3oJBkSO2;+SnG--LDU5kx zC0($63w`LN)znoR#GhW@M5n&8!EGBnj_usF!G5qm>{qhQ`sdB#K+CoQF7f-se z?#7!W#vF7jw48A-)Ulxz@0b)?7iKWQI+fE6Ud#Le4H#? z*wIeM>mtaY-X;WO^yfR4Adp*W)N+A4Yv~TqOy)a5g8AjAEfJ4acRWELKhbNNKrc!( z&!ze1YQkhsw=A3()t7B^pu2=1)CJq>k}s1bv-{fV>=i+J^=8Lh=Pn_L(@77X+QqLi zSM!u0YfVL$I)-o^+D$g^8iKevTQlfM$k z8A}@MLX0cd>SIdp0%mtcJaTy&g94$WW9QB?a!}a+T)Rd$eDM!(fgHCnNCsx!svv{S z@9-MjC~sfoKOK+dN>{)_sV(mjhof{qxwvX-7Df1DQTI(g)o z>s6XRhgIhE&g6I!q!Sxz>EW}#SnudH5WeBSekYPp`9~Vp)1-G^r@B46=-SWs(Z;X8 z02evPKG%G)Nf*Dpl|HNSeWdw0`U#|(mpohWGktDRF;Bo`A2K9T}=|{(p(X*E>(aYDag2maC6ay^+ zk7K(%-yfyPJKv6-`qy{#2oNV$%o|*T^A7!TivIn?ahqEKj{ka& z1#*R?@}3aHxtTmO=~U-w(|Xu(B2EmI8B50EvnOk9*GGbcJZK_}E{D#X@`(&j@%hg` zvgc+#V--FuV!3MbUy#-AgE($~;1gULUsw`94gkTgN-nwH+_TiyxD=9t>#{5GHSR=+VC|3HUj>p$m zF=5TOh#WCVpZxG0Mfs)VLU~bclwVS}a)Tud>)$I3M@i?-ZEb;CNQ$OT?W!i>WPgI2K-%bDAV3iV{YFpxIA_D~#F;z7mA_2ToA0 zz;J#$$gz?H{f~tykIYwsN^&ofDHEcc3HtMs_ksmo_H~%=S!trXzdzzq@XJ@P(yd>A zNh?17fF3z>nk9kWDu3|gPt>$~7yTPdOfi9U)o%B9hiOkpO1&hgnGv)+?=lcH(3zlF z)1$73Anp4*+{T@4Fog)rOQR%n2^~~bNRNp!ZBKCK-@noL+ER9Y8^~8Se*UT3c%b7TLtsqf14?X2rJH|pTWGz8-n&h;14Ov z#z`fWWiO*ed){^1em`8ly%A*0PxH#fdX?ndqyYz250dgaflgvo+ zJV{-K7`Kl9diHm3hJcly zengd6QU#LyA&GQLke(wb%#d-6v?HDD3F1f!>{yWg5#|xN?9J0WD7v z;l~T-X%q||!6msgyeyyoVe>kdc~D4&(TwHYfu@{&z(qUzHQHR6u}wE)#*5x&(o-7O zw@7jXJiKu=?N?bq2i6qRnT;Fhz}ixmnKagt?l)w-)BzP^3@k~*Wp97@gTqNpbZPR zy$S@S*a*rO5riY0Ud8DORwP?Adna(v!QOi8<4{14v_(t!#gLwrT(JX4+=L_$A%|pc zXmt?{(xut$cSLlVo(30Y+4jMCjtGY2uwS_m`dG?inGHD{f(#luthNkXB!$a+a>Yn- zK~O4(yi`tCXd{2}Q7v*n=1Z+W<4npgXvmO$@_f~4uO9n2kmNBzD-1S*B*<|l$eA1@ z#7YnNRI?n@&u)dVc}PLoFRSt;=(FF*KZU}pY9KTJIT}LH;AkK9+f+gq?~2G z5#)j#B*jLMG&xp+>KqBOk%JavBS>X$J^3kS)@II(S5WsDjsv%=Is#fvo%C=}VJ79C zu4XlR`eZez2+jdtZkwl~W8jW?O+mCNa{m8IZH0?IgmNQbXlLF4NHs~k~IN5KqX9?a!NuC1W) zYsz_4m;p2B(rNZ|bq7KTK$6gs(A^{fuF@Y|C$u<+ zeYYY3Gn!;AyU4%y;QbOj@OvR}OAX~1e60jYkYi7fGch)Tw9J(lK@#LJf(#;pbZHir zB&II7NTQ;~GF=lByQEr3##lyCO%LAbWBIf<~=H3(^R#^&aTfo7d6DH>o+Z>qt5T4kD_BN0|i~wM{;) zQDk{ivKxY=^BgNdF34d7nZyJ+lfx0Dp`+JSH331CES`Ogv=4}5y2Zs^=PLgRUr*8)xq~v8}M$U zLOie%h{Y~;4ui@DJqJtzG0(xF97ij3CmS@3983s@mls%CJveFs=+cwd>4yDCfvm&e z!5#1cb>BZeo;3I6^_Foju7YH-rfKy08n55>!E;8!9e--mI{HXM9UTG5-bio}4&^qi zE~isoTuo;*ZeZWBo`Vxk8!8zvL!O6k1VIoUEds_IbStzRBxm^3Gm}w=_OY=YZzMUw zCMRKGc;U#1X^+ec$Xs%Pdmk&k3F4CX?~8#O4uI@BY`Kmq!J0Uv+5@a9tSpblLOV))hr-m%u%E*xX4>hBnb`e#B{kyo18?4;4dFUw7M^53Rybu z824~aV-c4}JY7hR>xV*sAg3fy6mLS7LnaNbD2_RfLpjc^aO!{=GM5BGo|C6yB@D9o z>0^ok{idSKZKI>_xtZixNop4pgLk193Gf?Ao}Iaq1y@!>f+5tPYW8ZSJw77VrMS#< zkU%RzE|Nf;cya`#HnR*FQxeQ`<~;c>Y2!DH$r^KWEyp=Wij2g!i9-MbcG4!}i^_bU5@kB8)I8_7rlg4C4#@0J#r1#qtCFoLQJrO9E% zt`s&x4TB&q*Dj{y&(q&hhKJ${y!SHMP)2fle^N(DLRef11H>ps$3G)mFl*0{%0f#} zK?dh~_$b?`;>l7qyL_2N&lj^qc}_^Fh@jk*X2^mq@ZAj7%2fh^%)qQAA zZ3@z-Q#;=6kf<1C_wHkrQ^se@o}KxQJaxedR`bDn4a5ufwojD_f5pWfSc3vWaa8IF 
z!+Z?HAa-6lxNq{aCuDPGysez_-`RL=-eMvHI(P2D`bHVO)$w1e0^WP&R`mBpOFQKR>_w07I2s zIwmM1dOoD+-D@HOzvDhQc0abkw){E0*){N5cul3$g6n-PcZs4>q4bV;KlnN~%kbn}!V8maBKN?~PDN77Zj6xT>KxccMrJYVYoo)adu8>W% zmv*U9KCo@D{=sCEstjFGl{%?R9Bd_S;`C@G{FNG~X;+5Z0h*dJ1r|5g4wB8=?S#Zy zt3sAsXM@aL)nWAyCYz08&uXYp$}38nkeVvA0^C`|ts22ve2Y2>mf~J~_Til&y|FUz z%#l)O^+i>bDr7NsoiC}@GN^5^{=sAkPSF?VF#7ysBZm@DnF?;le_~|Un-B}Itc2u|IlX``0V1M3jKlcCTY73+_+5_^1 zO|_7<%PEyPhbqxCEnFv#uom}FdO$lY%`OKi#h<5Co8ZPBFZA{I!|wAx!c?aisEfxs z?T$*AUTc9D8_Hpt%L37MoudCVml+QIa-Q{X>F$I{4t=051yd2KXJy7g2ho;dPy9%m z&|3%hK)bgG?)N=_y3^l5BAU(HpEX16sc+%jjdr-wd5e*w`^js6LDPj(u<}q7%axih zoQB@MKIp*y%l0*noe!-3>L8Nvz`X|#;P=}%;m-Yg;Pd%Hg6jXkc0~S4=WWP7_Qlvb zG1>9)E0=~O9SWcSdXd@th$;|?3QV+Z@1bR;tdb%M2ko%(GTA+u#e@F7$5Mb+;mB`4 z!xVgv{Jp95%Y!hpT7-)jrQ~&IJFY@h`L?H{0L^~?0CJaZ z{tZjr)sT1m=#VQw^-Fg;S$l@ofMbuY0uykS+-JWJI=h~`ci}FY$50ATJ+%wA zO77DqVS>075^y6_kJfo$5r(}BH#(lkaYNw(n&Hbh&XQd-lYhgIk-UdHhZ4HzOR6cX9O(7$kLq}D}u9EB; z-dhHFDZZ<8Lc2GP(}(AKLrJ-Oau&a1s?6Nk^&FO z6KSRZhEqx_SQs6S0+Eca!Fb^G1gONmI zC+HbyhfVOuc?OI&h7uoNn}=`c_>iW5NO1q-GUX8K1^!Zxzl z4XfveR)GIBSo>}=cI+IH9~|U>#(X~teA-&84{aZTo0BMk;yjBqEL^gX=_9kDnP=}a z`+sm4^17nldnZj&U`51GznG$gf}Fz|OlbvM2~cNtN6bbO;LjW>4doDpXIHr_#-WEK zTp3oTSyarnG|L?64R(Lh#u7IM@+CF;0?j-dAKR%u-gp$bMThf`Y=V%QniZFqb4;b% z+^sU^c~$y+58W}2ds$fqbXadxS)oD}YcBF8+Kmro`dqK7bh9_jZo>N(2|7ZqH?6u% zs@LZQps|*E)s_+u&N{X0R(-hsYauy#KI0bVpUP;&tcc8vw<4D;UKP1mLj0?AU!cHb ztdAKWi}A~qZL?OzGg+1b@q^keUNsrViJ`HuE@E!RO5*b9*&nDxR@U?Q6pMIaj1kMY qJl2nQa+aK&iDQb84*TpHAJ>1BQ$$nT?9A!_0000+Hy9+Dw zQlg?UKB$_cZ8RBMYcyI%jkQf{#wz1Xr!PxQ>w~B~cKP~!=iIw{_rdOp7tZhwZ1+g(AXy-HL10DFmbXNx@L~ z3H0wQYEpsnp{iIyzhEeKgc((i$;}oAoqHl}Yb`&gx~}ISy|wl# zwdwQ;nvEgzkAnwYj%g}=Nide26RJwsNTUEE)Q2P-5}7cQ3Z84R%7rdvN4sQKhOlPcRnSrOp+WGP}nNJgfkDx!pMkypKGe90p51ezT#4MxAxQ zN3CC+fuRy0nP8u@+)%h}@FHZ>vWFTTCD?*bPf|6Oz4#LAYDsH*sO<_ z+8Vve2|wE19JrkK!TNc*tzkb>2=OxIfDS8-yiLEA$m0k(kQf0ZJlj+Q&+pg*@-o6x zTdEi#&vL>m?`;jX+>v0bbWnM`S<~tiA>-z6^m&Xo6y=iH&}dMDp40vqOvn?CbR0P3 z0YX_`z8klIalWefMaf}lN@-MvK>)C@OTMQsvEFV1j6zbmglN3)tDNw{&IYft@#yp|U;GYg&z^)Rt7d@u#0Bpe zimnOEmq&Tef~aWH7SjqERa#-iBMX%jZKUfNcy71bp|`IOKD_d0nA~D<-XkQV*jewl zx|K$GjP@M*^t)>e04FWS7-Uwy|!6q{ICob5gfvYaErq&g;Btk^VqnotOu zSN-|V;a*P<^rDbv9KD!YExR|ex)jop)as*$VeKa$K-3I_~rZ#$8n0D;V;;rwan!I2{& zEnl34toAlI^wpPe zlye)Ao4ycY%W~JdLaI0e(MHvF%G1SkH=uyAXf{=!ABS!n#lZ@o8CZ4XFmw8#1n{&R zVs(YP+3GCIkwRjs%TCiYQa(?iP=b^m$jib}=-N*{ggXx&44S-zukU>W+LOO#ZOZ!~ zOnukpUM6x&FsRNVXIChVTfbhB(rD_SHz|4}839cXjAmbiVtspfigR#uEFjIMj@si>Ore+Oei$<1cCarcfF2@0*j682U1A9rp; zlE=d6(}XYz#@Cd03QHCwxdi0=G&$N_{=Yy1XfbK~!v(L-Fa7gxu<_$VaOSVq1CpmY z8$Ujb&-~r%UfZSfpfHyQ7GTlb5>~#R>JqSaSxPVhD7~ea?b-3_j}BnQxCvh0zmvuF zfymQ6C7Oj$o(rpg(e8EsF8b6fI~#$e4S@tKotNPf@Ro97lv&dmNB}MOzKDHx{Td^7 z^e>kK&H&X>w(nxk__|+v<^;uhpfq|w0oCgN2n*&Uy98ur#zdLa9sUH2!{g=78$;%} z1L1P#zaX{-%}ARM>G(3`OF*1abzPV`HC~?1g-^B_&(OXN<=~`T0!1J)ouwb`hnx4h z9=m{>-*my^gYQ9FLp5Z*znzJYxJcY)*bL{8bEG_x3mc;?*yV2q=Kg#a+Xvy`pEue zJ2#<55|A&7Ku(lOR2IUxb#E82l~|riL@t>>J=|1!XP{(Gfq7D*RSSuh3Wmux1H9O5 zbzVzIvg#nSb+dS_bpfB9xub!%!Jvc0T8>$5O?a$?#5xXzQ6&nfaS6~B@Yl=oyt`5J zUi|^Lo>^h?bXpN!k$b{#I*o}Gg+L0KqjiNap+>{bdB$Wh1B{gdNt&z zkU*wl;*p0Tp96`fH`Pew34JvBLf)EFl)AaU3W$CXzIJ5}*_hmnyplOlgkJ%5dN1-^ zfYFOQ7f|g*o(nK@@|F3Nh4!=hOBWWfJjm^}QhYrdl{|g|c5+Shdb>Od$s<#GvjwI% znqg*ZJ*3tdIBXmlNOJbhCP>{}#ZfQ82y=FCgS0Is7aB~A{A+vOWk<4kG8-CsBA>N) z2Ro)Vo9)zRim|LCBI$`F-!JxDQG~E+nVNaMkGbGoHB3M|cbfqm?Jyjr6ln%D z61dqAY5B-YX2WN|HS&_#uo&dO1ZLdVcx6-*l>@yGiUd^twKIQ z1myy3dN1;B0z4enBibGcLp_=&v^1A84wc`CetouQG9=$!N7f##SDg2(;-$ z`!;UT3E!5cpgGLm)#4Fpf{Qj}^JF&E4%N%lmmNV4&oVB`hy6ytSLkp=a!l^3{cMD2 
zTZ1ifMFW4}K)*?$c>mDR24g)rEZIEGUiM-d`ALieTX6^VNp)73C?Y9z`9d?=c(?d1 zs~_K-`cOc>&%IHK9z-;#Xp`TMv(d*wB}E%mPIu_y`4;N)(a6iqDI;Sfv%{G`Tq?Y? z`XY5qua{3ZRrAk6vM-O$&0Shch^Vh+#oUI{16*NgkrFgmFX!!x!YeN2Yr^QVW|_o)XG(ZcBN)a|R?) zB#;P8w$4loZCthCwyD)Kv~>DA|AHfFa+EnB3aXYkonv5irz&0+e_1c`|f ziIC%^3DMCrgrvlo!j#n640IkHIfLEfbrQs9Mtu8!_VBgvQKZl*M~Z$T%?|zlVT_2; lV%Z2*hu);6rydA(}wUDXPCF_W1vnaRBK zeoR6LNsxyaZGA2++G?*?dRwg0Dq5+E#aFEgnub(`IsNLD^CGWJ)s74L)DOcaT_gD&woh@MDDT7paS^E*rkp>8F->o#K*x;hPkb-{g{@G1-RXg&d5PhrJUf$gT>-Kc2+T~(?$>*Yu zT4h`0W>J$pZ%Azsi;{nVW%G=At*)awy8+_t6`#e`RGh(2zZ43)n*13}cE8;I5R%*` z|5tXk`=>gMs>q*$@(4m8?`JI1Q?{ zRHAd+JgRmHP9yV))rP7q3IO??4XSoJ$5!Su*=~JDub(K$fM<8yf*a-K*Qz zPelO^(`|+V_|-0Wk_vz*qdO0>?1mS)wM$Y29FC;)bEP-uAW0uG0ct9EO#m6#%K0RZ z39?+K6Wk5gE*|+^5I8uFyX{ALNYa2Nz%T`Hn@(}pU9*C57Xtylz}>iUsV2Z#2;ejg zaNoZ2a>iW@1kiDtzFVLPa8^~&DQ^ARm5e)008Ic*fO8jsh19y~Ki*W3-Qpae2p0nv zo(NXL_4n_CukY&uHM^BPt?*wD_pyjn&Gy=Rcfp3fUR68tMLx;5n(a64-U;9T#U52V zit5Q{QE!`~T|s99zY=X$w0cfmaNYW#0DU9B1CnnlE=a4Z9-s@!Y^>p_bSr_8-_-*O#n>*O#n>*O#n>*O#n@Ra~B|fQ*l9(%QQf9xcJEvaY~>ll!7d& zeMy*!>i>NLUU=_aXnXb`eD~hF-~w+IsQDzK^0wEj+D$`WSMKSA3v0K*aIW*wzx){v z|Lq;P{lJ5=b}1e+^O;s(t?biT$yLHOtC&t(07^{x))^Qyf&6nz%;wDIf6##eu8#&sKFHx$9)9f0Z%(CUS$4kJ%h zh7xEzhK3iU_R;u@KbYx|2=~79C&+BFEBd6;PpcBt&P}D2M4-D$&W5VeCtg1)xQ^3! z9dwsT*;DBzpVRTKQar!Iz)wS)Y_}P!pfNfWp?4YK(O3Tre#~%m=I?&-Fr?${tJVhS z>=lrTBvW+|8iS#2`i=IfwE<-R;44R%@X>{!`|u$=e(U6DgfD8a!sD+U6_7w8>_2iC zX4F|kjj91=H`?IFhx(x5cTdB<7oUfx-gpfTz4Im<`TO4(Xq$f9`@-{Je(C_+`S?TZ z4vcpQ8~0gw-iMFABs?!xhr3^RjtMxadO=JCss=`ts28z5FLd@+WjRbPjd{sS);z$b0hGtE^P}he^1i z7>H-yd;^|7eoS~C1QmcUcehUNIDmRU&%AkT#6+Jh?!%J56dPSF5W|cS2~^FD7Wvd} zT-c21)vi6B=%lT`_GJe6+|LDhTUPB z>Kqr7@|jIF1GGeZq0h@xpIiwP1yjb9Y*zKO!2wZMbhJU|{xvrEbS+BPy11i`MdHh_ zU@6%x@Ok(Gv{}~ZjMb!kP=K2@70hm|8K6>-+veseAW{OYUZ4qdx&3t8|MsoFVo&7r zBR|p`^0RB9Ym&QOBA13Klxzr>w7U5`YSn4T7nW@sCeFfg|s|3n!5j{|JLH@6H|aVdjq+q(_^fRXaK3P8tZdo9e@(iRu< zt#-^$ANe`N*~%uK05m~D0gxI2h64{X!b14LJ-fp52WMNa-_Ungz>n!?42H)aRu9tf zZn@BbcY(EZVhL~!%>xXh%jx{h69NHlePI7Nbyew@+aBx-lTRSu!x_l?#;y+Fs_qPn zFzyAQVd36CK07Sp-tGSwzO%a%W;so;wyOnR9>!fGhokSm2Wxk>z$}*;zO!cs^F5s7 zdN4|kx0C?4Z8H;L+zUX*9sl^`u!*Ba_}GaL;N;-QdrRble38%L9&`MolaSM3!@FQJ z6G4Z0_?!g@Oi9v1(0V6LNg6>3G$lEgO-Tm6-~7mZF&SDOz2J<8TOPaz5~@oX5^WXm zRgCN}thFfSJHcV(r^j|mGB%U)4;_7J+>jr_V@F?x)tyaH)Y%AYx|-ou6lC4*?Vr!2 zJS|H}beRSgvSlfiJk7T%A+RjP#kOg-=>Ybx$D05Lj~|1XcHQh<^OqD2_9kucVwoaqihgiFwGD}j~1T8KAq z9 z0*J_$7eGipRXI8<3eY7Ipjr$(pS5fpOv=;6o~r=0)r#cH3Lrr~6QEWsz)#GN7h+$5Xou}0dN}v_c^boY%{;YZ{WV+0(M1QNN9kM;!AOnLO zA!aO<$`pxu4!x90Kzr3RkuIy=J+gW&=9H=qA z_U>+&-|S@9p4AWyTLkr1J{JXz;e*%scI*>vDKlk)jL}tnO0kitDO+6 z?2}J&RYIn-a{R1}qm0E@ZB`_oFkdWy1o&B&jg?@V^{!r@`-SP05aqg;X(mq$fxs-TLGNGl11do^z)ej zbyh|4sl+n@Iva%o$n^8W0w|C#6u>A?ev|-N<5GZdoFLuJoL?^%Ksv}8B7j1W6%fFy zNPbv=Zjk_D@+X75dvA_6E6 zFN6iKm8nL!k^)EsSvqW^!UD*VZ;KXSB0MP{62Yt>fJB5F5ujW(!es*ZyvoB1VF6kp z*=dv~|NIJ2T%dOv2k0&0@pc1G%QTb_ih|Yb=$T%62%3bDw82d2XhH;WDF$Wp8)|TS zO9Yk>O2SA)vS<#MrV(i-iw4q$z#0HWxD;ejKcAgz2+A3z)@+3bosdkEd0g z;D&1#CpZiz#?%|L1R`t^3D6uAKsmytNfdzqGC|f*0VK$e7Qk*e$z8qXvXKiA`1=hV zmpdyx!B&1`%>9K46G0ec(a5T#01`o#KmdgZm-_e-0c6Mz|AmPOGO9|Ba#>%@WZZ2W z>Ho;wdKvvm*|hl5+kCX*InGgW8c#HK{=|ok`9yjeW-XboyKLmQg9WCdk*LNJcD!Wm8!M{^|rzMI;*ms)i5}x+Az2Z&!25I4rWwWL}BX? 
zEOKufEUd2?%)sM9ARn2w5R42L+weM@-Ge!fsOt>oIm=qnPh6z`_Ydz*&dt4=I7*o{ zE1hu`!$e9>O-f74pc5eSr(Br2T9<$6_jJqiuh$jk6-OgwWnppRih^SC?_wkr78Flg zxdOMJdh#qTEon9)Lx{AD zp})x??JVrlV(c?%q&{ae4u}ilB*0A^Hwr0^^>G9BT>K=*lpq(QLcEr=q$MqBNlRMN c(!@yr22-Ey)4s~&`~Uy|07*qoM6N<$g6%nSQUCw| literal 0 HcmV?d00001 diff --git a/lite/demo/java/android/PaddlePredictor/app/src/main/res/mipmap-xhdpi/ic_launcher_round.png b/lite/demo/java/android/PaddlePredictor/app/src/main/res/mipmap-xhdpi/ic_launcher_round.png new file mode 100644 index 0000000000000000000000000000000000000000..14ed0af35023e4f1901cf03487b6c524257b8483 GIT binary patch literal 6895 zcmVBruHaWfboaZ^`J@5OTb59uN+UwfO z>5DKPj6xxy*f-15A^38Hcw8gS)fY>m7X^~)>WdY`i-Y7Ev5tB;lGU`#+aci!MOUUM zD}qsF_F|N>IHn{!fdYTV_wX|;<46$x9(d2I{>ArDOEMG+AD^=P{ywF-GrY99`C;pd zTVmI*ebJ{Z?*lK5{2OnL{2bsnz#klb&V^vTF8LL3idsEt+KcA+ISDVmw89n=b3!uh}YH8Am2dcyFwO zP>3sYL|70%XiHU}0Zo+(MxFf$fG{c^GK8Lk0nm!?MOUlH=$7@wQ=P+?afrb30+O<` ziTG*r2zL#G;JREn?w(KwKTW>kAG@~nvD;BDbNA6Sw3X7nOleNtO`EFE_iw7?Nk@V% z2nn}DI|Z-=FUSS{e!iMKGH%z#^FftGb+nGAxybACovek#YjQ#vb&d*p+t1kJZ`xQz z;u|ZlH|p$>-hl#GilOt>$n{u0Xl)T;>j-tlI@@Z?Wzp-=)#G34?74swCQ~ERfdKmc zFhPnTvx5a7>%ShCv+=IbEiP%zhTLzjnoMn+{p#7s56cR+1Ip9!b!Tb z`Sm7~BP+1z^;S0iG7&)FAn@&x7D5ZD8A|Rn^8#NH904lXb|d*p^Im_M3cx}s7!4)T z9gHH`t8+}w++;htxjC@gx{~KPlVjj*{S_ks3$9(+#6u-Jl&IAP3pu!CJwK#M5t6c_ z>9wdD74a&~(E(Zk#1U@ZTtm|Z&dTxVSzAiRZr?zO5>r03qKN!s*CrAGLWn8vUzShH zLj>)tEVfOD(e%jX+M_)bim*#E5_p?Gy16VcdB?_AS3UnYnfh>x4oMP&MNjS{^B>++6>|-QpN0X@X6L&Y0v_nr&QpJ?Nedk76e$t+1QRS1iuh%{F%%f!H-mR|< zQLG8Eng=h6w*&uot15mDdp?pMw_z>mzOGmllD0RJTU#1Lm&egEdG8hyS)~+JzIUCL zOasw+)T%|5zrIFI%imD16;(cBT?v`6d!z2=P1Pi}_cC zaY){_eM2i&Osq}6Oy>Y2JfPjfx74>{k`N|n!sM^n$$Li~8z=DouS%NFPq=6oaadk$ z0*u&FPkPm9z)j6IfM-M)d8(pgV+4M-S4t-d{CpIET*U$q-ZNqpnS{w$epknMM*J)< zPm6>bel7I#uL*$fN%fSIg0yd#CHM7kuV;h_C^iY@0i^Gty9+J2aLrPcO&e_I4V!m|%QLzX;!0D_phPA9;f z54Vuq!_U%`L{EsIT^4|j0x3HRvX(Vc4%<2x@Oh2+Dn;)>o2t)Xj~&>w&Vc`00uyVP z+rjjLt~xt1(^VjmUESy@cLz5nC)L@%fx;yxhQ-ro#ptR%A^-9B0u$XgK)sha_CY+|f}c==vHJ zIsE14R^;ECC&mE-m5-zZK z+8{Cl>U!wJC$s|y>+%=$e8oRsp!aOoBrJ@MF;SPkbU$$FNuOD87#(v%q_;vE<)g{{ z)}HI>svC+uv;Os$twg|H_&AuO>#CKsTo>rM<9BT$m9M@;K7t9+k|;62$@KkG-xKZ2 zhe^_oMi>opdhOmo+KXR&YGro*f{q}Ep3j$aj{uxYnw$E)-`r`v*$LKBT)@uM9ye4J z-Q#1bNUOU9;6>Q;!8^3)TN3u@@%O2>^UtqNkTbvkW<`=Kz-yfT?N{=`iBIXo`W%cP zOF@78`!8CjaFJ~gEr7rbg{*#HA!~+a`8W%{Bz>w?4Y=;y{O2FrCCt!4 zuy^g+qyHvTAKvPoK+M_<8JLnR5|X`g3r*75jg0vjI+5}2Tc>@aBLzSo8U5@X@4sm^ z5-ujt+fn`dMM}KeB4Jx*2>uVv&wPi8j_zvT3~}C%Z`$&>zV&72aX)=W3XlNt!|X?Q zQm^Au32^rJ-)S6xb54f}0OiA!vY*2j%^E_@&@x*=87F{e-s!CjZ|nOe1f`XR>1IGiFlvUuJSK*t=o+=Yf5Tc5TadL2IQF() zEi;A4K7Fc758(rGN!uFr7=1be_I@-cIEM1amN~NnsQVQ zGnAj7{i)NE&jag-b#>GhG`pj=Hqeb+VmN|mT#uW%u2aZ9WP0=nqgD1a!xX1#>7~!l<@*A zoYvP%oqLK3P?~FShX9z1Sqj6ovlDNLrBCj+nMZO-0B}XA0IJ;6%pJ)C?Fk@Zmdxqz ztUAO8CbdHVQ=%<(ai;xq23`ZNh1c{dOsDraC(;Gp_x{_&8?%}28UgCOUzsT>BkT#_$;_WV*qs7k zaPyN$mvj4DM~Poi24V76Q+NQ14?o+kc?17edH8v_RvLR<5W!E8Nw&XzRMg*N-BY$S zuzP*nCBWq5k(6tj0?eD4;4Tw{lUUiyM?|NRtpotF6fZvOQYu;~fC>eGYcU+!A^_gI z>|g&+Jh5H^5!z*f#wXumUx4XTZuC;;xMdO!D9;DmFW!WFarO)uTvuikAf~*Cy!Q2% z?KVMgd~=fYTB|S$Fu1;)-b?J?fAZ6hBmmb%3fCA#XxAj1GG?%S0g^}b05|kYcetUL z-fe4Y`Q-Vtqy|P!>5)U^_~}z_aa-{kcrCnU&C4&rJ`sE|B!wvbkd_OtElu>j6jNVj3Vxd?2fw$+FBYCS|S$=CYSc<5Xi_2*; z&gOy)`=+1ggA3j5q=$gF`8aHR>b`OQ}eQ6h8^930& zTfz6uT#6in{r9oABIe_L$ArY#I_=r^EJ;?q_OB~WfagCwZZ1HRKmdgU5x6DEkfO}< zfwzyo4LP-t+{?-ekO2Z@S_?o$$g;aAA0l1(9&md- z<=AWj7QQA=_Jw~#d#mJ4?b#K9JJqf<0gnCn1538001ANs_@tzj2-yZ49YM<%;c8eY z$FZH)D*9o-^{baHqyo6OF>A<%3Ni|8q&>{r+d^jT-r}%~5L31_lEnvhk3OrL;pn_Wlg^IkA4rJe+-a^UwY7R5qH&49$;zI8q6 
zuFa?QWFa#_X%0VCHo0|kEkwel#20?HhOE_Boonzd$ROVHrqv>s49lswR{|TU1x4L9 zYWUdAHK)eyY$D^fHyXs|f^6qRnrJT@3q;P}(?aHg7lc1M1q}7Ow>ObxkL;#qWh{6p zNoJ@q2lV_2;LW5yv5(xor2$M!4PBBnq0SsoCnSIMQwPW-xK9!YXN?9Ewl1gu%s7*t+Bg35~wxOdVL z_!J6maK$|`wmvrlW(J|R4Qp6SZiZ11h`rAlpa;f+xk}ztOG1=6^mika+17v_cwJcm znb@*{glqHQ_Z$<{mdK^Ro{!{5S13qeX|4t2CTLg$Yx3A^XhS&(#Cr%31fKxLk>AE+jwroWIAJqGD8O53ik6ycRr{+uucnefYQ1B=j?lwCZCL0Z!rfHSi)rM z13-u*5X=u3)NR;&OIH(34)$~;+?LI^bTx53U>L*(G1V#y+YdHhk;R@Ll=i?+OkCd- z%3*SEKUbcW_h90>pZQtm|g{tib$ zTp&#%&A4L)t+45A(Dt7dVJl9s;bIyEC|u)|eC+Xd1+WujnF-*8d}{%+%uSDM1z{$R z&7_>g#s<0G`%Nz|CMXD((fWe2kIJa1h~| z1dux=-=+ZA>r1lqv|jhme3Ej-a^{v(vpkqY`fO7a6BRX#kuLv&l7`Q~y7ROYB*UHn z+5!+@oj?G`=>;nRoTL}fw?`M#BtWKv2$vOLIJmo103=_5DFBm)B`<7DKe~FO@{*5NG})#;LV$p z^ny_Ujoc~u*wc9ddR8e}^0QYE$@Iz9$PLF)hny$v0ZvsH#-G7`E%D3)bN6Cny)?Oo z+qSv+;8rB2z(RmV8v@wL?N9-lEd{Wj+o1w%wGhA#`MdzbHr2Go)TqJbTt%3<(;lIm zAUDzU378K1rVR-b78b-Utqt;cXu%;L^r5#m;S(UOxMfca@Vp&7^2Kf$-2R72FCZ2X z4Uz3AJnS1&!MHIBQ6xl$8R)*9=6bq&fnGYy#$XFui~gt_LO97NkaamPlJi zG}q~I`=rPHvkwCoH&ISlZaVxMHavs*`M}$I$W4lzSC%}s2RCQw@i<@HvgZtV*b$z$ z1usHku}*8?kXySDgM-1OS3 zUTf%8r$G=$z>}u%up?*XVrolC&vhjv5k$Ci$41h-vY7O&P;e-=MkR~*S`E2p?^e2R z2iI-Qp)^O8l4dnAv4*)FoLKDvZ9bYE?D@AANMDDx52qZkTzGY)>9HjOKPle;xH&j= z@eBOKOmjv`Hyzps*NFnc=^TJ|TSRUrK%GPVdOzN?a*|%a6f$NpF_~t|=CiIQ=k0*a z_gF9s&CV^f?WRfhqJP7Z2i@Zm5rN+@gx^9pm|1YoJ~}B;5wdmmL}=@&iPu5z8@0Jc zAb{iaf=vM&M7XvE5Rxy|@!k$I=PsOZhtM{&ZTGnpnJdqF)xt#!N9$N6F zgblJ1XdAJum&oim79o@gW2kW(w3Y;Pl=9zrpi`& z!mJaI$>Fh;R0Qh?H=tA~fP;NIicACUUhq}tw&EHtE`c(si%&^rOkR(5#=6rsU|XEx(9YvlOxt7`7r?j;Y@Ha zPS9~Uq=Rp`VM6r6xi!r4g~#X|fyA-jV9L%Fxb&&yzc@|W8V$kHtq`T!J->k$fwT9f zIY8D*dwEf&fqFE>)T?2)4Pu@N7f&9Xf6RBr>&*6g&&!c~>&O}H zr#}qk$lyMl5QDrSl9VKmNn_^Ee2iK3e)M7{i32${3oSk1TC7gGkDd~w?cAO{}c+|2tHX7 zU#BJGcQlcR%3^u|EI#sS6Kjh|H*En;OH2Zj6;&!Hp+#ASkepSggI6tnD`?^Do&Mky z_(gS3!Fy7-66*lojXxVy`EzxYFjw%47oscmr^CW}fN#x@ih)QBU|84q*gJzJCZ~13 zcV=bGip38P%u7EKDP8$aq&)5O$o!1&t}Dv=F{)U027y0E7G!>hpM_^Fehd{2TmRyarwi zugRJiU+!L#tDSf;g80yf8j!fq&|tdLATY2y^~;e|A@Du?49j3d&XV1QyT&!b+bIYy pii9&6o*bz{@b60mWOsVP{|BB8eXZ|AYE1wD002ovPDHLkV1li`I!yoo literal 0 HcmV?d00001 diff --git a/lite/demo/java/android/PaddlePredictor/app/src/main/res/mipmap-xxhdpi/ic_launcher.png b/lite/demo/java/android/PaddlePredictor/app/src/main/res/mipmap-xxhdpi/ic_launcher.png new file mode 100644 index 0000000000000000000000000000000000000000..b0907cac3bfd8fbfdc46e1108247f0a1055387ec GIT binary patch literal 6387 zcma($WmFVQySpr~^b#u_OG=0|(kva)DP1B+cP_AmARxJ*NC=Wrg0zUl5(`L)gp{N- z(%_OG?|Z*r_s2c=$2@ap&UtF)$(eXP9W_!SdLjS-K&qjxY;ZTH{xb;h@8E{&N(%r$ z+p3|gU=%dFmq%!1q&9_NsUvvk-GvvZjaIJ%uU(o!Ypc=Wv%E8e<<)SFdRM{tz(T@!nKT{;0jT2A&dgKu3 zk|GDUX<&73+f+CnZza0G4g29@hmNkl+2wP#$0yi6=u-4CD#*a8LxJLG9KlkveQ7v} z>E#)-tL=xh89y&5li1I!>Zzc!_i6V~nKP^5-+!69FtnX*f=*tr+cf&UpZtLBY|wv< zJ6r*Z5374 zi$7+B3A@szy#|*$Tb~kkzc_N~h3;oe8q95K$w@e#5FRGcF}wXTR}t#^!OnNc>Z52w zu23YrlIQY7UrLLcFSW5ctMBzwrTz=X-m{1Y!*LWUbO~;u&&q8Lu;wlGFqO2h4olL; z{rpPfr}7f=Z)eZhFw1_ITpft-VzPF1CHv-W>u;OCBJBEOEn$HmTpFjX=xN6-H5#V{ zn6Si;q3V*@lFMd>H8;M}vOp8McQcJ}^bBfV`1xb0g0`9ZZa9(wb+L_RGO6wD&I8ouM<}YVDFU ztMSz*yMDz3AkS0YO)3_lYDarEUyj?A#9s@-ln${-1Op^nD7zREi=%4Hy%V?=YS7G`L@>`3kHM4eAD%)t@F};|C zfj?B^Kox-WuPMuDp2=LPZU3Obgnl7{dD>|>*A`fn-0|^8uAHJz;<)tkTXA8lI&dHt&xG(4Il=e~QNN6o9YD7H{TR?17eM>#Z8#Y@_=7fZ?HkZX8i|mEGs5mR`uBi^ zzFh5AG^3EMyvpx(a*)!eOI1?nPTn?v0Ly$)KlQ16Xfrzh+}+Ua_I!5XU@ciwrAZ>O z<7!MU$n6`x${EB6YH$hWOMuSEw+72Lb~rgO*Yp26LGdNp*;^;HAD@(SAr(Dk;j7w! 
zQ>!M4rxUFYn7E?v7)2q)2rJ2%PY>A>-1O7bY~nt&n)jYnG$(iR#hvlih1p}c)I+|I zy^C;=uIJImfY zL~pm6t6Zw8FiOIY<1>EBS(<5`Cv8DBcZEpTCQ{@@-|2$Bhi;6H?Pofq1Z%b2@)&at zUA{9iaqi62D1|=T{xTe3Czr|z52P;M7EB|V-ss{qspYc0Cj~hUUURef8?i5H?e;kA z<~qW5`JIc(rCLz_oJ~>x8O2IVR%>+7%}`TBSQt%i+m+4tV?z0(?5cf&1v8cNlz7Lg z%ZS>-e!({r)+sH_1+QJvE5BqOgmfK_$X*P0*x6beoRN|0FV zBu+T9^1E5}1I>g&wC|Bn^{(R$!_A@+E4<}3n|QMU=H|GuQZRAZ+zSZ}SS{MNj&mi0 zRY+fp&8IQn-}zGeIVj+qntrIP-IpXF?2xAoyT|i)X+@HL$+|t{#ZAvBrd?L!=9aLy z%@CY;X7U41O6VpHq<1UBk2vi~afo_h1Xrb{vQ%cE|Fvi8EjFCP^~ zabJnB#=NPyBD*BaNSQW*VI+TbEmlu2&HD<4U_UQNUR_`K~u~XWideSoLc(k)vEtG^CT* zG`Zdarw^M&6C=~oi^6W#WL!BMe{E&Gg9Arbg2gg;cO^sJ#+L$ zWBP!R+lcV(p-B#aK<&Ly>?*3fngF)TwSRSmGJ!zET{Brabip#AUPyChm}S9IFG!l{ z%+I_?Cl?zVm9nbGSU`Ksi%z1{vEPpxnv}!StZLIR4yl9y>GM~KIIbNdVs|xsuCpX=J#rE`8<@v*FO%Lb)=#c`~s7W#9EDhRI!G*VBK(y z5D`)jJo4o1={q}Kg%YGhdH~@PGate(xi{(OiQn~MMSZM;!kHNh*1-e<+YS5-j3b?2 zq7SYPWMn1a!^Gqxr4d1gZ5G`QQ(&4Ag*OcnWO}~9rz5xeE3Ycol5cj$@jggn@8x2* z)UpG-U2|Av7a)Hi=b^@SNp#`PEDfswF$nyx&rD*+4SF}`_U48`=1VnBn}aEm{Funk zSWQuC>r8yUkd_D(dKEqo`7i}}{#+a?O4 zDIg~&^q#d5-Ji>``G%gDDzV<~+=*qePTy_lbVjK?!d`>ygnhxwtyL65_G4A=A}{Dh zq;iS@h|Y-wJdeGj1b{KBTkst|klERM7*Hwy#ZO<~Q$5~GzC~WjZHz>=z3~>oAVbbv zzmgOw2JQ#Kv)GT9dwrXGJKz5(Jw%&rYPjfi;TI|dyVJrvaZ*ivGRT;i>R6}8B>7*j zbJi0%9UfLcYKp+TU9qXLSp`rm`)3(g6YOdHa4cv2Y)-JCPZ&g1Z*%F~T@dw@_HA~- zxeq6NeOi{(yh(ziMZ)4yIfDP6nhTg;)$=9N_-{KO!ZB@c@e$(SVH`%0b3YF`lgX)? zmPOF$H%(2yD*LrQ;d*vDgW=s=2h+1RYg?DCXa2gXNT~W+Hu+pBZ$bO8IlS+nqXw^| zBM2iS@v_S^5P@J5V0gw2hamKs7Wro(xWlv)U$%_D)AA{;Mb;l$7?FOK*2{U?f_M(W z4#aOFFlOC*Grkxzi#w)?qgNP48e=dJ*`EYNKfLm6BlZ-j@VMi+{0T>$Y6e%gC|6;v z4=~J;U-H`Rv(<}l7sEXpm?7;(jXl{O>aLca zP;<5GjkKb?74YTOqJAtFKzq|v(-+j{(@?GPIKVS95tsog!>*S60XwAsnYHqG)dW<#@2UIte}({hi5+*r;^rQeDpKps%Ql|LRink z=CR6^g!&1h1Ks5JplDey{0{E~MNPgvQNeH21%lrCFFh~_7#;b73>@zaFo0B}hXo(J z#OVP*a2!ZeK|x0LfazsE0=vAP5xpQ58{e}Xtzn5B`l%b)PM2PI{UmZ`}XbW%4eE=4-VAbQ|zojxNh6BnLDzTlx-stKQP0|=pi5R7qw0g}ivih_z$ zN`Pc6h9K3P5vFz^s^};EaGwq5yEdpH4Um!3Lju85e*w5hg)|yEkihSklp#pqhWjij zaK_T%_)PG>g`7N9$25qwhR3WB{&pp8G2;J-#qe6%xdFHO2AeceqW`Q#`J1X4*a>V4 z;Y4EVTMA!^vxOA;$ZDCt!CPots~0yn*Erio(G!n)@W*|^D_=Wy;f*k=tF~9Zmr)dn zCzfODoJ@UXXs>1NP-A4#YmmhGXavn<+z_gJ`>cZaGo@Iz2J)=M7{{ zJ;n45y6T86%gls;?`*1bFl=sXf1H<+2AiBU`}H6YM=+eFPoz%Sg=s>Dva{ls1mJO? 
zTWP*i(U7Ec^3%Z$g`f%l##*mSt_wOa-d&(0A0@(ms#pY$P8SX-ZAVg)> zpsk00`SNH__*AQ#=>~|-wScS`e>RBCs6NsQ18sz`Q({qI(fOQUY10Mt%YO^v{>w>TEBSR zi>oS_n(}3A8W+^iWG~}cr3Bv#s3W>CFUJm0ejS>=V^X>!UmDV@|xH@hWB5yhc zuXagN9&cY%tMFc@?PqIxYmy+OSGU`O5gvK2Yaic7tFAiaz`*T*dLafG4tz~<{L=*n z1iRA9k6#TYhCWcSFW6P4&4yOea4q&Fy6Mbkfl&!{&@KmDXMWs7;2Q2bRU~gBtDs>o zNeUgzt#lWV4oq=C=5{Id0)=a+u5HaCtDZwXnX5u!bO%{LbXF-L40}KeG4lG*uU{E_AOMMd4ch=Q9&rc=;3fB`I@EFBuF!XcuT783*FH`4zO zxZ=AOG#fzwnh^u6!|A7Fqf5u{$IesB&EF?V9g5dyhcmbVh)|M3^!U*}qJEYbGFaK2 z#0I`dWniJzl~+;sJs^jty%7`^Yv#{r+=Q<#CleH22pEWpQ)lwX9b5uv064&fPlS+b zqZM<&o~(2`QgUJ$O29zuo%|4(uP+zAeibd;jfc(zz|+6+9EUrZ?#^|ymX-knV0Dsz zFn=Bg(*p-JjWR}+{_C#CZ~dR&on|-C9&{&ij%~0x9gtgIMPCkr_rc{WE_}pL*bCnZ z3d?M3AYq3)iUS7jPOFD3m9DVG)E&SJ1*`YXzZQib9R(``({n~0aGXEhgZnJU3vy*N zlEAeqef_?@nqICTH{?wuZFw#7F{`&i?NLpf<7G2noyziDxMHBmK=Z&P8jf>~^fSVF zFmD1h)DVg7D8erkb}OkfElv2i`s#7j5-;7~&l>SlgLRqNM90B`oFJ!3Z!I+~g7^$B zkD<7Y^U2QID5DVT!a*uS%0aL5KAD#Lk5^|WCC!!OQcFyxCl$386q*ohKGP#?pNL0_ zG0d|NfxU%N?);5-{u0rA@S7+4>7&sDwppXmJaj`?8D#?9@k90l(a-Vg>E`q1zXh9B zEsyo)21!OKE@yf_^P?a!d>O%I$~z&Bg| z{KuO5lVh07O|keMJh@ks$3EfHm`nFk6qNS&_PxPbKN1c~Ds8?;y>OzV;B0$XVQ=LQx12PJ2~x!&?qm%Tl)eivoas}<)&`&84*`tT{?ou45c+RPjX;imIsuwmXJs;5Klbii3#Q0kSLKcW+Y@xKcRce+GJ-RTlpMp(c)D`xrv zd|#_rj!Bm<&cad=Pq($+uKOY#CGCK-8EXOLAo{LJ2l({+_%87YR(e2EErULI*gm@X z*m6LuczdHTQHH`3=)x;unt9KH-4duW3nu}xk&Cu4-DS4wjNG}S$tO5H_$l1*S3Go6 z0HH1rN4WcDUK${}+a@ICZ(ZC#*`6h6EK7)q2OePook_w)c5%-9AxwoT6E*>!XDxpM zy_C$yP!`aN2TiCVLn_z`_E((J%LUYuw%2%(GBL3Cve+5zmepidD|^#$=@2Wfp!?NR zUpV2SwaMg68}9+`X#n-Ust|TK-Qk@HXu7dM*@>KO~@YA_S!geT; zxLp>TbIo9^WI=ZuT?ErRN;LqRSZX$7)+{MdSSiDnSdSwQ+6Yqb#nF393O_Ow-rRZD z1MtC55vP=~4kwe+$#2C8b3Q6*<^!T_D^X($HS$*Ns2(pd5~m<_QgfsetRt77rwh}yjg#yx`@p|%;RnzvAN8~6i5D;EQg*azSU-+F9W;M>-%sM=r4J zY%}@{t+!2883WSGMgw_85U#I}O75Rr0Q_D5;Du8|l@ zHWBq-r2&(pezi>6+daPx-qwVIQ3A6$h}GxIH72G*;HeRgyXKy?Uf!HvVg$M3Vs?lo j7HB*8-{6~e<}KKy%g|C8?m&3=nE}vH(NX@WXdCq(XawjJ literal 0 HcmV?d00001 diff --git a/lite/demo/java/android/PaddlePredictor/app/src/main/res/mipmap-xxhdpi/ic_launcher_round.png b/lite/demo/java/android/PaddlePredictor/app/src/main/res/mipmap-xxhdpi/ic_launcher_round.png new file mode 100644 index 0000000000000000000000000000000000000000..d8ae03154975f397f8ed1b84f2d4bf9783ecfa26 GIT binary patch literal 10413 zcmV;eC{ovnP){+^kJY@_qlWNt)byXXcl4&di)UgOL4U zf7l=Phy7uH*dML-fsqKMr;DlfM>yz|;&bpF`{OQzgo8jbktkySeg~64fbWuHz_H+% zO2F)JwJEE@HLSkR79_Z#oHbogc3dx%o7^AeCk{b5(&1F_9NvTf!DryJ`XFJT+JS0q z&?sCD-y=8K2W2PRhjJ3<`jzFS2UeBViE9@x1RKUQCZdv7kl1SX?3WZMS(_}*GPxT+MhW0P|fyhZ+Qq30&o zK&_A(Oze8$+U<`PdXPq;v4_f|Urm8qVAY042UnGp45})9cTiQyEh4N`WieG?WwHFJ zL%SQEJASBPNL8tfyeEVAm>Ttneh$6^dT@7TL)6K`4dZuI$Q8$@YC7*NxE8o3xHh;( z)oY%paC7#DbzBq#z7eX{hBSaAFX=&XZgM%%7vkI`tW*yCO_Yg=`yqnAa-v2eeE;?> zc{iKw z56$?22D^!CP)@={l~{!+p^?NV4J00s5s~K!m``K3Z^mK!w_^!uRBfLTqF!aWIQ-yF z+-+mFw$C)OYiVHDrh2UxX&Im_YA#t%&~JYj4^H@@?c?sN*|d{1z)fXCWK#h&a-j`x zMSwIVr!Zx+>*mUE)45>nPAFTm4uSn)0ywG_n3eP}spMCtk;WQXTc!Xa#?G<8~9?@D4_J^SH8;MHSdkm@M;{c4Zl4~|K=yFf32q2}KbIxDWFpb1y zO+OA&=Iq3=s^1(B1GFU0ED0TN)1GUEzJjf&cITr}~_843H9IFf?D zpy-;D=W+{Ha$5$7>!~TGM>3^{(aM!hTwS-Zu6}T3B@Ohtm!x|WXwD0DS$2Sg4MHki zT4wy)C@!)S)O94Q^ENX$IJLgcuiK`aOAMYnR<7i>43I*17(|~2Z^{a28-tFl06j}G z1E(L_b%g+AG(2{IghMo@X493&wrmJ$)etG%R?khj1IO;za&76!!+2C}`5mZmW7T)d zdc5TLAso7|4x4fu(6j?P@#13#aX@*#Nyh;YpF8maDO(w~k+R(hKe!7&`(pji{+WqG zRNJD}1i%xZuq*IN{U@la2#gbNVFCfAchs zIJDcO;{ZH`Z=Jz5RkkxH?-ZOri>KGuU75U|b7#sb@!GV{ltwd6tl0 z`-tj|)YKcR-o#ogdg%auyuQ|?Hi%I3R1^-|ZB z3w@dmquBHyVR{7VswXIVTX$?MPH4+9kb2qjlDK$t-RcV{VoZD69&BtHN{89>gQ~qP 
zJ3uX1wj2^zXGt+iUU`JHjaZ|tY;IN^;K@-L=fQS>Y@uwVEi&RUN?2Y*+sNids}(cC z+40kwrYD*P3GD#2c-goFwX_(F;ug=ctyz2p&FRs8BZP#KW)rz1wGkz3b++zpGX3NIKL+e&!v|_Kf@T~~axF4tuT$cD=XZI()UWvicEV_jFqjbw^Y;_9AkJsqs?mSQ_V zHd!_~?Uk)r`5Rg=yAOj%Y^~TwjIt7{g{Gt00kYMyk+w^ZgMfMuZBvVP>lJ}>TFiaQ z6}$vw71{x^*|Ko~^_rD(w0N!+0&330f%Q3TNHV+~AX_dQo92j#JW0ofEat`()+cpU zNK-<*Wh>c%oF}ld7(cPM7T>>P3+`N++2#S7TwjYH+FeDL-}5iew@%rhE!V8XXvx!0 zTFweF>(f3j`6XB-!?_??289+P$hL!oDad&d`knUqYw_}zU&NQL{fPhk`)_>p#vk~F zOaH-9ClAxr#e^P5nv&DV0je~`L#5{FGh$URTHx9AYn@Acj8H9 z-fn2Xa=Bbhm#_bhv)?!+_&C~>bovC&J9ipS=gMNVj42zRq^}*vKi$01ti15vyd!%p zUA9JO)5+CkcwA~i2(aSSaRpH~0l2>#}`U$mAt<;*`UUpCUF!4<_g zFf*C<$Rf;^y{H)XiCNlB=(vxmae|1Pqx`~~S}Rm0li_pUevNx<%Eh8q90Q566YDZZYFMh0VeMrAMOVe1 z|Lz;ye`{f@1!x?J0yCotz`^}fMr`Fm4fEt{bxGcZ@CDfQlmg-(RljEY}^PEkElrDm9b@vQz3{qdC=2bx32OI6ixaob7Peg<(shE$A37*Y0*ydf7hWB3l zfOPA%yE6dnF4t(NpuypoFMj$Fe(uB} zYGE`j2L$`WNWctZJGzc_^Y7cZ=&iGKe5Qp4N#!&iijDjXjTz(3xiMo>J=mmazv7G# zF};w)79FkiA@1zpCm-spe1PcGSD#bY2j6kZTSF>x2d*b>5aJ1Q0i#dXZr;STA6&qX z?AfNYN-*H~;g8?zcE?0p{`DpSKBZ+x+2NX#R$#Yh=T4y^j8P-g+?ON+%kpw5Ksi!b zOAq(oLt>AA{_iWD?hG2?wJ$%XV>2K8a2fw~=WnZlqj?=Lg8tUGU(+#}_pV&l`FXI2 z2R{CgjGSMfif5%=Dvs=1Gg5Q<1A2u%ogU0AeaR=a7WglGq9Gm z05rN_()Itp2xw&&&f%Gd_t?ff9{`jo#qQFme-Q@S8}7!~yjOSWsy>00CD&oc8BE zFMG|E_M?KjbKQ9%c|x42azM)$4)-h1zrz4(v;}}*K(PA#cWCU;R^U~Jl3;7>rw{Cu!{8QN zl(B*ZEn!VUSbEKv??13(3(hAM`|DqSwpn--f-*wJC6w9N`i?w)2q&I8VbU?i)Rp5$ zpRbmO?ySVUW0vO8F+m{!u@5;7*qFB&61$hYbWjGt9T07-U^P?#05ata{Vwd{2a}a; z(QWDK-j|R#Z<>+y4)Emu^ECb8n$m7_4%f@(9^8ck*T(DwCIkV5Cej$Fy(m5INbk)B z81_|%Sz$1T#tN3wg#Zy2eKhpDFrV~OEAFZrs~>OtfgjpaWmJ8GEc7e5$ z<-7`0<%3Bl$~A83zX=m=j13)K`E?&RU1#)%u;U-p*j;=g6-ytEUsw>Kreg^;rRu)?wAO})#2n1X6G=;eY zbpY#7JLDu;AE2T%dC;~}?3TFl3JMDHXKYCH0n`pX@o;Z)fS+3mpgvpH+sc<*x z1F}9*_-oA}DzIg@@Ei1s?3sQ04(rg@i;xN56+FJ0yx!{~|Zn%b_xqcb^P%5t(dMXW@Ug}*T&pN4~-o|+0Y3PH&pF}W=|bT0Q%e706_}svCls?Dd?;u zzf`BxSd7-LQcApTHC}%70KMPb((ph|^QvQq=sA_wK%P6L#o@{e=S=Dp9Q*VlcFK&` z3z4}2a!ZM6K#x2yjjU$pQYbW-n|+%|^QNhAEZ%^{+o;|Dp_Dctk{ReEnaG1N7!M zUvln?NB+f`^cqb${^jex;SpPlIV(gVl3I2ghz8NCZ=kUwM+yh%k@0;{mh_r60fM<7 zQyUMG(-U4kq8@)Rcpf7Gs5P<|e4I7+Y4)N_=QfSdz}A0i8M z<9|WJh7HjV5X(eFBM0>$=J8u=0pwnoia*!0$bca|pm_&(<4!rrxI=n8_RLDeAtY}2 z=*KHo>(0ZuLTbvfXLb_qK-^8I+%| zUdG%Cl=sFd>;Oyj@<24U&RhVc(aBVo=p`QzCVUthI@4N3$j=WxTE)7Iqpe%ok|sRnzE-FFFLy4v@Ojy zAh^N;M6&#AA&{i2o>0u#PM074u4E9~0hJ6dw^~A0!+7s~xzzXy*t&$}*`nH~ad24Swg^YQW%SiNd)(;TZ&v!xo_w?$uA?IrfP_|`m zEQFQk^)0w$mv+7L-8Z=N`c!^^cB=rCZUjVG+>M2OQ>B-YZ>N5giD0_7nBKcn9Z(nY zVT8K$EKGZqvp|-)wRvDgk=|8G?b5E#u3g0gVLJp(fT}bAG6o{JwYgv&4v1g=CLIIv zMIDs;tm=7)QDC4e`P->SW@4!&?~R8=%fD+wwQ%fNlz;`*m_7f4lZg zPs+CxK;6mf8GGySjQUzZnze5S&OQAymYz5)_&eH^bn*y2)>B%~UnfXQkL<$*XJ5rj zUfj!-MX2_vYu16CIG-E`Qa)zv+b&q$i!-$Vw2cR#ICW+4KtvPw2|#OCVb?j+tDrN5 z?)7#T8bCM2K|x)hC)UY#!K_emE(FoWtx~UdHXaJ8k-wu&kn8+J-4;A-Q@)_j>(YJY zg?Mu97A%3iAvFK5B_WJYJ=Uk;DLX5%Z$S!1DXUc!tzD^_ios5qQXIOg3I}f~YCb`# zRk6GpUA2J+pg4XtgGkD)Rv#BBbDlJQ4i`ZC2o9iC;vkyV;Ys8tPL2MM0+eN;g~p)} z0w6LgK%2DyWB@z>N{>Q5fDD62D?moT1F($VrU{S^crr8~0`~=JA&cjHO4_~;Wq@Nr zWEemQNj!S?^ny4@yn0cIMFA2Bk;MTr5FUPj42OpoAS2;v4v+wNsNimoCijJ&noYkkmt8oOdws$f#{!w*f?U)Jch8E3A=KN%$ z+~TWqXo1Kw0L2&$j}jo#@V*79M#G~7Xtyqagu%lBw2>bmUGSvS8y4j#ei=rgkL1%f z@7Ap&y`32$qxTGRKt41A?~MHXhN9HfKQK2YxA^)%Jnqcg06k8QB}t7j8Xmm>352H! 
zplw$Td3)1=B;S71raVS|C4XCE+i!)Y)YsxC zwr{1D2jEFPc?7RGyqCV#udVzd$BRCC0H?lu6o-;y!s{o=UxTz0REZZH+>J9|JAt3s zzmvYE+Eq#889~}zMJ*4&lX>bSjy`sXzE)_;9zIn!*Yltns(4batkeI%Q%T*?_v-l- zwzrm3eQo2^eRVjbFzZgQkn!Qr)?Qv-9>(^*n!7QC+Pie_+=cw@9hkfB2xJx-vh}yA zTVn@TmEvJ#1=R8YJWubbp>9m4%JS)VG&LMlUV!KB-HunhxDSsc$As6z%h&U3vo;k{ zO$HcWI*2C`VCj2X3Q12&RYlshwMk%k0G`!-Fx?$J^uSaSsW%wXr8mn$ z;~AVgF)0R8iD^b{(GvruXp?%J)1xrGDF!ki=FyCE)MFsSVjfM6Au&)Wu}Bi=^k|QH z6l$achszhr(CFcFXd8EPGdXzH1jvCdyxFM(++21qTCwm28srMxgw9+m)jJWN4erJ$ zfHVLZMJ&MMe#UxB{gzxExlj?R><7D^?>gd zIsvP#Th0rRf$)HO7NyhMYMKBt93Bp!1R5YW1IR#lv;!2+Z+#M@Fq;1OKH8?<-rZ>% zn<;qKH8R~3_2@bhB`p7*PXFr}owme&VS;Ayb&TsY1IP$?02pEJib{@y9PbYJ9-F0^9DWM#x0cd9E8d{Nhwu7<=K>8+N^$ZNE0c0dR zf&mgRx77?FBjITdP&~i&$sz#7EWzl}kQ~~U7Pda>u@Fr0w?{q5-~J?^euK+yOKh+@ zK-wS@FtV&4AYl`uO#r1C4No(GOn|2epc(>Df)>{$ZJ_HW%?-am+He4COHWJ0KH7U^ zJ}zBh%m57^@+5I(e{q>?{I1NR0BKHp2%Oha0+beGG(36%GGJC+2~b6`N$@BEs@DQg zX1pBgOSE*}Efmy$I&DJ>^}KXhp?36ES5Hqr^0%LO&a^z*cv>b}Ee=pNt0)6z*0lp< zSV{&gYQPJSfhidrK-D||#TlBCfycn$tyX}D>xy2C#ZNx60osnWp*w3+F|xu#VTHJL zgq)pW3H*WRxp}YA%HipiSp^_NAR?fQ+R6uz;rTqg02z_b!w-<*@IW1C1t<%~d{$u5 ztf~K`ZN{~oH)~6)SfAzrbq8wx0#N79V@ObTnO>*{L{8A*)}e#1H3DaS0kwz1l{q{-VIh)6$u;94s{*9U z5~XMZ$oNb`HGoXWBy0kx#3Xo{0hGz&9?~NdEngrPj~y9BU6+T4KW#fJ1kU3zQ!wON-a=10NQ87wwb%6LRQHnNzVok~O}hUVsF`(;T3r*TuC}N0kXv5o)1FlPiM+Bqt}hut8}4Q~S}Hl}cCEA^@pEl%fTo9TnOE z5;!qR0U`~r9Ux&7qZFX$wE$!QJWT-AasYwrihB-=rayj^whh-tom(<6q$B9d zZUq^P7R@|EduBNavK9kK0a0o+4?xA*0Wx4#9hQ{S4v_F!bx8Vx+?{3s83>O8AUKu; z7R5-2!lIdB=SZ6jp>5M1b)#+7g073t3W?bexF?D1dr=>Y&`=aP=RG=KRF>NSOQy95 zK)et|<53k_05UKoLpwl*rDX5|WCT1=*3s1jpuM#X5*RF;GwnaH88>Ycu5CP3rYl6q zMjop1khimkM{gLVb|XErK`9BJ!`9JjPoHdbLU(bm z;eEj(uqd?P&>oz1`XpVG5SEpLMGg41O+(c*@m(RvVTLqR$Rvb$EPmC{;Fw=5eU(@q zfM-E*{{K4m?)@;dfs>DWA9{;2*ESMcghxGlkqgj#6g@N7fPjz(bJITSk)MJkc}X&3 zx1n||Scj*RSZZ`#x$)as6IUTgi=&nY;DLm932`IpiqozPb@`WM;c2AddJtCz%c<}x zlTT7LK>|GFFhd$DOoH+&LAOZEBO#raL9xrfVDKn#VxV-BG6@wi5acWy8uM^nb<*3C zF2kbP(>^3_>j4H&AJ*e?wdPcXIU#bR%Y(SN^(B7;+qG*q9Lts!hUfDDKvSRB0+0c->J*@QZ2-mV0!U8Bd1526=;cl}bkQ8tzni+Ng#wO^Uu3(L_tPcUJ2^F{|sY8r}6)1CKU{y0Ag40i>Wq#8V$DMynRd zXk`mr#M7(*DR#7h*J;LQ680?4Yz~kS`8@mp>4Aq_pJ?eknRs%@Ca6=I+r!mym(~ss zA4IM+m~%${$kj2BJP&es;J(Eua`v~}s5PX5=yquq0SGoEfnRZ&amirK05UQetT{mO z+VYs?G@CFn3XA4Hby++zco~HU>eLzaW&yLSEe#Z!GbVCj-N~NF)fFHbEb;NWAI%Ow z1wNeH15|rvqs0JH3^oD)2Bu^v0V+y2DU+}Xpi&+1NE_($Rg19bsnD~MPM#C!sK1x% zAX=wf-MX~Km`A83YRASRU?Q&vfoLGi&p=!xesa=!(en8>x#^F@M!Hf~mK6a~LS$G< zhHij_&#Ef{sw!;`4kW-spbWV@OXl1ZKNeC#V@a6X;(mxdSet;y4)0u*1N9VQ6mnIhyQEZyBO%Gb%x{I6!oXH>p9h>Ks5dJOCM%k^un0ed6UHP%Pb8m@^LR*1I5nOkq_hdUc^+S%FHIjIFJs_SQx=R!_ z{|}V3f?1%o4b%2-m&4)?76nK(Cekx8+8iL`lEGk!m8tc$a$f-|$Uu0~PAo}G2sF?{mwdqxbK&cGQ$%gni}UaT%W z>{iFH*vN(TF1pf6baWg*dmhXpN!;AVi65PqEqZ491+;wOpOAS+8#RZ)#91aeU3opr zM1U0TES(RaEFAz5U^3zeEO9c{qvEDbq@;7OZ2q63IpG(?4?U1W%5uNL;yAjv45nq} z!0F2Bz~yd^b&Rz}5@xDhSt1nNKIG>}ewB_*u5Bn$utQM)S>h>^Dn$#P{*b_Qi}v2A zWlB&7DvMeu3e}jpavVlt4oQvyTVrcNloqGbjn8N#ujME$ULBYWcGoQFO`)jyw?y-1 zd?*fmxYA*8|JiWuY&?g$Do4)Z__4Bjv$8v>bkFVZm;oftBGK_9@@pl%lXjej!A!LC zh#}9ohCi{{ZQ-mp-B&KY>P}({57N+{xyjh8FctPfr+T!$Mn30oz09XHQwIB^dljb1 z$^SVOsXW(wZ+)uVGjE;TvtW(PvtX@k@RmZ^+(Uch12(V6o&_nG{11DO9u@4h`w=yp@yLR7+-F_P_1>{dzv%Vc z{4?EWO|R#D_cC>41Q@6rEpfZPY}Qsw(iu+VtM zk?VfLxt-`8D*o)6RH0G0sdlU^c5qq%Bu%TN3R6ec{q<$PcmS#o?ctDy1vk>p({m{8 zE>kOk6c$U>a;ZxBKlm)ODnpQ`%TPxJEO2ZmdS9GBJEt$ZhK?H0Xj&UPI5rAX2R88L z$%0cK7N~Y(7NHkw?B3M1K;whO01!A0WE#NW=*IvFVBhg)$LPV1*_EBco1N2*U4tE( zRtl2?YqWMOIBn0yR9sp7qyVcUb1gnBpzXq7P*oT9KOgqljw+zIvtzojb2zbcN;KS) z9hz1SlqysTupC)~JF~`b&#VTY6#sW--*Hp{MHLo1Fn0-5nsA9VKvNapXEcv<*FF9Z XdJ+W}DiIkV00000NkvXXu0mjfKBlg6 literal 
0 HcmV?d00001 diff --git a/lite/demo/java/android/PaddlePredictor/app/src/main/res/mipmap-xxxhdpi/ic_launcher.png b/lite/demo/java/android/PaddlePredictor/app/src/main/res/mipmap-xxxhdpi/ic_launcher.png new file mode 100644 index 0000000000000000000000000000000000000000..2c18de9e66108411737e910f5c1972476f03ddbf GIT binary patch literal 9128 zcmb`NcT^K!5btji2)!5SAPPuNq)Ls56s4*38hVo^(nUfO6%ZAH(6N9hNR=iCp@USV zNUs_|I-wKc#ou}5-}laWIcKxU$(_yIot@8o_s%{sGSH@@=As4w(CO-E-X`sF|29fE z>HYT9T?zm$_~>e0H4dIw&!!4C9vSZxNlr9*d^_s#H!1R~WS_6MVYz@X@%G!e zXHz-tb|VivQj`iFZDUWNj>i`*9rwT8VC9f`)ww2)D0tG&WBFX^J|oMigqUy#_eV)Q z<3?;pz6pkr(;Z)thNWZ3Tu^XIU(m2~K2{iFEAS`~Gy5VW_tC>i*Cl0kv`b9xtW+!e zPD_a1*)E4YGCWy+8(ZVrP7}Y9URLg*>8E8fyY^0u;VQCkoBQJ<_5zdXl(d!zb~b;b z)6|dkG)>oK`*erN6Q98nTc z*T4b)onLqyA@?UYxy_MYQjd+D&|e(Pm(0oT&BjWQ4@?kFIoB**?M#(;rSUW9SnG<- zSt-|WaL6iG_P3uZd9eIpr{TtNWC*$Hh2Qz?uBS}bIbRfO#e{zRE!IEy&YexD%F}@N zL-y@k#YdI*GK@^S9Mw$gu9^2z1mSnEkrdxz+MPN|ZNhhS)_oYvhM)cLTYGn3J-&{3 z*gO%dE$+F=!pgEJp;TQOxUvmXY0MZXd)l&aIQ@q%&TOO4FwrA~ak$>;=zXV4zzr%` z=0~OcyNxrVAu`L~2ctf1)jOUXrl5QhI{u_3cR4;2>t?n_c`o(TMz?xA14+Wh$Va%BY0&2$WKO9mM2sYf3h-OCY*=ZOJ$Ngw)1D_iorRZXHQZi4&2K7qT927nQC0Lrg3 z(#lL522bDvLQQ|!4#s}u&v;Yf6v=QytSm1*VR`JzNHPFHGlJ!`WMgHC3lNnE^`=*0 zy?^9tJWsJlLSn+d=%5(DNQYCcv%)omexK}hyZmUHWQF=7JRFKXB_b-*?UD4{x!=dVwazRjll3YN!e1GQ6{ViI{ zhkd)N+MWKT`q_V0)j;tA_oAca{;nI(Y$Pb7t7Zgb7)DUREOEf@igE4Q;TqcgkX-wd zJ;8G+7!?>DALr#bk)GNchOvQs{BBN~iU1F0&RMR&ou$CHl>C|ZrZ@PkAenI@K>Al% zQ7|N8uxRTq4vM*lnm?oa%}HLn-3G$yJC_b75?=65k%LM)%(H@{N`65=i4pdO>Mz+= zLeav25B?f086=X6O6;%!2@%ZP1|;Nvbnj_2aSc+8ZOx$k{x3Drh^ zc*UWh!@lFm$>1}Uo>u2rUqXSar;=W-2Mqo41Pl(rQD;>HWC;@e#W@Z29HUt(caNqC zC&6BqG(7E8;B^rX*m6|Ejm>-6L>RWQs{?%J*!{N&Cn3FMX$DmBS8~(Emio*Dj(^J_ zk~mE@d*561epZk|Er>78iC#q_4Sp0Y3GD6B@JKKrmyoJG4WGBh)HqTZZw>kH>(OJH zlp#iE)N?g*Z@4^*MV+s+H!!1LJlIN*`JxC#o-v0{2|BS}}kDUMqX8%d%;Zo1pF*{G_rVrzNd`M2ya!T0DJTesuRVwL9u7n&PS ze_~l@1G?`(riUCq#<3T)^gi`sw~pk^JSP})C#_iBKTD*{^N7d0$A0wJ3#IRYe;0q4 zA*$YJb_LE1lo-`!M^fB~U00SLiLywh>%-_CXgSb{ju=7v+FzB+78O;y>TeZvRv&RoWxTLP?d+9Zi&Ypua2+{3 z?&P=TOQKt{%~L~p0$j8^;iia9j_>fKovkcwq%sUQ@nh>Z!)%cfJ0$;z4CPrz6I0OU z@+^ZT$qbq`@V*LyaM7l>CZ1ZQo!IplAN5a81(Tt~ztAbYc(d{@u2@?f2YdnGcoX!#60Ixw-Nvix#$k1X*NJg)beTLqL8^6*<{2f@@ns|Q}RjZ!$JIHK8NbS8xrmu#@ z6ulfiVr7xxNb~dV#acSrSX_pQm;bUeyjdV!{OZy#M4(A` zwu81?V`O!?oZ`D{REMi+x!1hB*6Cy(I?k8T%kET=uKQWo39E}=ca$my=uHTEyP8y z54Nz1YH*)(w%#ztIo^C*PQOjte`Hel~gpFN_jZaXoFZnUzuu<)94E6T<5ZU?s4>c zpU3Uo@d?+!hgYmVil!6X(ly;KNm*OwbI8{z3v|%I_4HT>Nt&7^q0@@SPXaA`iAvAR zSr*v1muELwpeL3wqu$P7L5q4m)-N%|J6fE`4!V+xyrOkr+X2!LT$k#tFYksHJH=n z3F!I2Qe4B5pnFmAer;+($yQcgD*uHlDurPx@2dd)1-RjhQe(5`*~SLS`q|S9v+`3~ zQ>IMi+hcTX^%}_YWT=}koWlGSwSH~mOvRNJ&Sfrc>H__ux(6*kTUubhdoQN>V2}J< zR)ymBx4g=I%zlp1J+QjI7joltSLskIt}qG%d@lfB@0(d>+A&l+Glwv&La86NxDmfT zNv>`p7eT?@iBSF8R6M^wCx1D;HRt!F#6s8>2mF;&B-MF;2m~@G4CaiZ!p=4aG-$V0 zYR+PtSNvY$YwW0OPYxL-i+8&!G0&s(?(IcQ&Iv2 z0Nx*-7_~pZT6#2L-so8nF7QMgH5}#22w+dCGMyllm->HAO8q%eYuJ_BHB7343cyG+ zgo9$W05T7{CPl`Zw^P=q+#rx_`T2%M zMCeCJLfZT%fI{csusPnQ7Xv@XSzVNmPU{iX2w134>~=VfgQ82*rq^p^97wA647vgT`a# z85e!NpbSl#8uA*dnopv4RMby4F4MY{UFn^r{Li3l%Ume;QtBh5?8wCixw0*zSQ${* z6)@M`djm|Nz;H2K_j1ACvx90`pqKN#`9b8Cd=@J|$6R{ZYc5yw){(D1GtABWH=Zy` z-HxQuV(8LOB`UjI4iAOJ34LY@KVEmPb@XIC)FfA6m5B&*8T*hQyR{mweAL1#*kA9n z;O}eZUE%DcD;yjrQM!F!8~hPzPrCH2Fvr-ItjJE$$pV*gv9>ye(q2lsB=uQP$h%X% zlekK6q~fP4niGy&O9mR~_I;)G@;?e;L8#rja{}{3_rR(d$+fAsX?PiFx`2ashkOGP zw9A><#);kE3G}H}!W&WxH1$sg*P@*n!{=#L{PK)y~GHI;RsgpA$#8cpY~ zct*9kjG$l!k{*0T43n={dVV!idt6Zw;lPW%!2K;#E>?J>D|V%r^A`&*)MdYZJT>jL 
z*;x5TTDFevc8OARtqyN`Wyt;0MTTO-DDG|wtNxUqM1$~ye0&&wUtZ&eqI0=0|Y{WT*|Ia1An)J!bjzf9y3P874R^|FamuD zD47YqkS6Zsd3^fEq_zq1i3zN7fM#ldxb7Z@0Y;<&n|qFI`e8q;TO3t$s`geh?U*oK zp&F$0CKJFD-a%BYO^4KA!5J4T1f9rK@Izkpt4qui#^S_s8AE_pvL7$dKQ z*TXfMJYx+MCq$g?pCj@15ZQdjbAm~v`@A?MCg`$$;e!iKvcv423 z^QOF{_mgOGh3-cDZ={Gyr z_&&UYqVw>f(5K`SHp~Mm5XB0N9$~=XOXd$uQNj=bO95ChnZX9K@n&#T?vXPDfqt07xJZVvBuujM>H*4hP6HvbJ~#$K=z-vNQnRCryVz5?3YqR02@1#K{#%aX?h4VQ45b zcmM<+1V?|eCnx}P7(IWh<1mpP1d4*Z4r1WAfB;C4dhrfKPC^**Pz;nD$YOJ0I9i3T zdQ`v*UjtnCM$WL`J8L<$;~1_X+Oyzj(IKG(tLOn!YS8Vny{ z@>lc1XCA-~hhrD7h1@0O)T))gw+GcvsVwxcnaCv{EQzu|qcwKGyiwb`TTP(}njGXHh$KxOryTWq$B1F6I8!hh2O<$rL^FOXZoKME=~3M&0eN93bd- zfpL<(mU)+asMc@#Mvb?Ws^Rw;E;iny$Mb$bu)1ovt0lOm4f(~cAmY<65o0ePN*$EX zrmHUhGI1J_t=@d`{#mmFd?eV^Q&jw>g^;Pf)7JHdLzQB*87{77?Kto0xMvGjC=&M5EOW+c zXpXOY6|Uf)0am19ZLde+hX5J6c11*#mSinvk^A4NWc#m5P)?v~|Bppv*0~T;-^rI9{w3{`~5)bC}`nF?zGx z#@S`#(Q@kl-1Fmze)A@u^#@9=c>MA>$*eslP^G`Zvb5N|sKK{mQ*V?4eX_x+nT?*N zalRRl;P=w1HG57g+d^AJQCZh4&g{?mbJZuj*>jJpGL#!`*C>{MRd4-HML#+BNUG#EHx5`rs8QUMda13u9eMG(lKCYTHCS2gO0L&PIU zkkI-^jv5$aR|blKRsJ6xJ^?au7%A7>eD6+l!ALkEL&*RPl442Nll#UeUv)cn5=YV~ zP)$eQ=SZYMG+hSAy@o*c95}KXP7(~*M%`ovFuZos#RM5t0XkRn?DdjD!7zh+HMGoz6C^Gk*}xdzg{VaE0-2L4An_I# z_)DVjA|u=a+{fkuUkWg+!HA~@f87&ENbQ{u_}}LPin9T}}BZ5K1W#~XT5z0gcc+cy7@$?+tH6Ta*1qVBL@ zBwd%m=LAwRv8~~Cx3MfLmwax@N%=M`ciGYizcDPi#Qug{`#^)V(iZGpR*3ayNFiWv zCT;%Yg?Tn;SO3Pvyu6Dolgt$Pq@8;O(nD{uHM<__6!t9UUP@K#N73GQB){T~9Hpci z<4P6T>Kb;ktBMTne4`e~@)E&sIdENQj5G9OYu`7~bvsRTeRl1z?i^aI{)?VNlekCC zXJKVy+B;Z0|Abe1cpfcW)93y`*4%NW#+1!-OVtut{#3Q5fvBQ-b<*gu4x4f6pmz-x)Q8wc+4G^!kGq??b_{28Zdu9+dS0=wgR`1Va^@f*j96v zE?=;Q{AtjKXi>F3-EkrPfL<`s@S z(Cl$t|NBt^_k;7j{U(%~9iLt{7g5yFfhq?^mE$`_Z>W$9l{seeXUdzmz8$X$3_fz0 zNc_d*naeGkU7&S83}C%)Owd-QTjWCq)4F3puS?Y*tOH3*JX`9t7=HyB%;}BFw)~fX zP3M8Ef?E#|5Tf;EuVktd)#&vh7trJcyxkI{{O|eok{tE^hzi3_4LW$*rN)J?Qmy@$ z@GmJ)5nOLC0(h_C(Ayd(aO3hP5pxuMsRZfvoFgBCNNrsu!(1gLl_W1XDWi)1KiM4& z4TFIN4Z44?71-@F^TGn<^DjNF#jfDTD;qdJ36mB3{oK$>kk1T9x32)H^4{v<&J$?GFZQeeKn zog^e?9JHCkaVAg{99*Xytpn)yWZ-y+!;hT(I=Fwaat_Fckc87LJ*r7!)y;@7k^fUK zxl{eySNWG_U%a8X+L`q+Pwk<%iyJN!iw;Q%=1>$p(4~A8CwtPS13^pt$BA_79TEm3 z!hx@gB4KmstaCTszUdc8*ch3y0f@{;*awP0cxYg(J0u?XLQsFzBA;#(`vHd`I*lBM z;(99!j{626=)R8+$DgEz-MfuzaGI&_b*%9#-BUQaw^>IHgp<=gob@UA0r`@#>-qw0 zpfFP4HZ?#}t^J2jFG?J|6<^ALo3?t>Oz5`IuInteCESw+$NTFo3L77A?}>NbqA$vz z-v81kRTwtLT8^1Hkf#X&iRsn`fKmr-Mu&N{*qwp;$qBXyT}BAQ@L;wB^UWEXX)3_b zh&*ke8czIhFd!IxCi_N!jnrKGIQpfPR2xJo1%*JNF^PvDwB;>G~7@ zQVZ23Q}9_P0C|)?QPY(DS0!&Y!!b^`S|XCy zKNy*Kil!;HIXgI}+mn{ko*V0S7_|JPJm`{p{nOe9Vi^>B;a*toh zNY>_;v-=$AgIA44ebwp@a!75wJN7K9j;+SW z8uoQjVUb03=55d=@#Y_9`Fs=Ut|9xs?0ce>@0mn&q+oSJdb^!tTO8;mb$%l));(4- zKPebA@3lPn z@G1otTd9DCo-AAllf-ruy4anJn=H{RXLG>6j;g|@m(&__Lzek=U-sRZzRO1lOrtOJ zm+5k9slTfFKsku7%a$T6ENphjA3uy9eG=kh6ii90n}D&mc!E$-XY)ycsx6qljq9PY zpDzzbG!`4}xmvrE+7f*Jx351b!!}L5XmvDjt;&0$*g9U$nbVZwscA2!5>S?vG~K*d zPzXIIrnkt|yfEO5^dk>cVc0*&Hh$%zYA8nPL(Hwwk?vVuZpJ+&#LxCsujZ^dalGUq zk8X*2y(traI^+1KZEu-(_j%t<)w?tI>hVd#CUfisw!-|mSM{#>X=67C83>oRW^)Nc z_@hYvV5!q}p#c+`qTV9*kqk5GkA6Z;&)MXHw7m;gzS)ito45k#Ejt_oX>5cfTLfXUX@_N^+#UicK@ zbUwcCAj!Nyi??H{sraN8NiTB?aleSuG-iy_c^*{zg2xn*m1e+7rBnP~o!PuP9z$Gcf(C!4f_G&|`v9JI zHr460gE4qwW4yYiYMyx4c#(d_<1JDCcBZLe=D9DE4fC#q8)2D2Dpnaszf0h1)i*7) zxyKd8y*&dyiKySsH2Uj5(~gfdkoWmaI$)6ycN3CquawfZ+R8$$x+k;L>%Fd*;XYy0 zkq~3{maC~f(~h3ZUsXWo-EodvK!+KO{DW8g|IOnpPq%l@9Ky`Dd0%sz0@6$Ox`Aei I20H400LcNok^lez literal 0 HcmV?d00001 diff --git a/lite/demo/java/android/PaddlePredictor/app/src/main/res/mipmap-xxxhdpi/ic_launcher_round.png 
b/lite/demo/java/android/PaddlePredictor/app/src/main/res/mipmap-xxxhdpi/ic_launcher_round.png
new file mode 100644
index 0000000000000000000000000000000000000000..beed3cdd2c32af5114a7dc70b9ef5b698eb8797e
GIT binary patch (literal 15132; PNG launcher icon, base85 payload omitted)
diff --git a/lite/demo/java/android/PaddlePredictor/app/src/main/res/values/colors.xml b/lite/demo/java/android/PaddlePredictor/app/src/main/res/values/colors.xml
new file mode 100644
index 00000000000..69b22338c65
--- /dev/null
+++ b/lite/demo/java/android/PaddlePredictor/app/src/main/res/values/colors.xml
@@ -0,0 +1,6 @@
+<?xml version="1.0" encoding="utf-8"?>
+<resources>
+    <color name="colorPrimary">#008577</color>
+    <color name="colorPrimaryDark">#00574B</color>
+    <color name="colorAccent">#D81B60</color>
+</resources>
diff --git a/lite/demo/java/android/PaddlePredictor/app/src/main/res/values/strings.xml b/lite/demo/java/android/PaddlePredictor/app/src/main/res/values/strings.xml
new file mode 100644
index 00000000000..168adfb0a0c
--- /dev/null
+++ b/lite/demo/java/android/PaddlePredictor/app/src/main/res/values/strings.xml
@@ -0,0 +1,3 @@
+<resources>
+    <string name="app_name">PaddlePredictor</string>
+</resources>
diff --git a/lite/demo/java/android/PaddlePredictor/app/src/main/res/values/styles.xml b/lite/demo/java/android/PaddlePredictor/app/src/main/res/values/styles.xml
new file mode 100644
index 00000000000..5885930df6d
--- /dev/null
+++ b/lite/demo/java/android/PaddlePredictor/app/src/main/res/values/styles.xml
@@ -0,0 +1,11 @@
+<resources>
+
+    <!-- Base application theme. -->
+    <style name="AppTheme" parent="Theme.AppCompat.Light.DarkActionBar">
+        <!-- Customize your theme here. -->
+        <item name="colorPrimary">@color/colorPrimary</item>
+        <item name="colorPrimaryDark">@color/colorPrimaryDark</item>
+        <item name="colorAccent">@color/colorAccent</item>
+    </style>
+
+</resources>
diff --git a/lite/demo/java/android/PaddlePredictor/app/src/test/java/com/baidu/paddle/lite/ExampleUnitTest.java b/lite/demo/java/android/PaddlePredictor/app/src/test/java/com/baidu/paddle/lite/ExampleUnitTest.java
new file mode 100644
index 00000000000..99dc6d27b35
--- /dev/null
+++ b/lite/demo/java/android/PaddlePredictor/app/src/test/java/com/baidu/paddle/lite/ExampleUnitTest.java
@@ -0,0 +1,17 @@
+package com.baidu.paddle.lite;
+
+import org.junit.Test;
+
+import static org.junit.Assert.*;
+
+/**
+ * Example local unit test, which will execute on the development machine (host).
+ *
+ * @see <a href="http://d.android.com/tools/testing">Testing documentation</a>
+ */
+public class ExampleUnitTest {
+    @Test
+    public void addition_isCorrect() {
+        assertEquals(4, 2 + 2);
+    }
+}
\ No newline at end of file
diff --git a/lite/demo/java/android/PaddlePredictor/build.gradle b/lite/demo/java/android/PaddlePredictor/build.gradle
new file mode 100644
index 00000000000..02199bb823f
--- /dev/null
+++ b/lite/demo/java/android/PaddlePredictor/build.gradle
@@ -0,0 +1,27 @@
+// Top-level build file where you can add configuration options common to all sub-projects/modules.
+
+buildscript {
+    repositories {
+        google()
+        jcenter()
+
+    }
+    dependencies {
+        classpath 'com.android.tools.build:gradle:3.4.1'
+
+        // NOTE: Do not place your application dependencies here; they belong
+        // in the individual module build.gradle files
+    }
+}
+
+allprojects {
+    repositories {
+        google()
+        jcenter()
+
+    }
+}
+
+task clean(type: Delete) {
+    delete rootProject.buildDir
+}
diff --git a/lite/demo/java/android/PaddlePredictor/gradle.properties b/lite/demo/java/android/PaddlePredictor/gradle.properties
new file mode 100644
index 00000000000..743d692ce15
--- /dev/null
+++ b/lite/demo/java/android/PaddlePredictor/gradle.properties
@@ -0,0 +1,13 @@
+# Project-wide Gradle settings.
+# IDE (e.g. Android Studio) users:
+# Gradle settings configured through the IDE *will override*
+# any settings specified in this file.
+# For more details on how to configure your build environment visit
+# http://www.gradle.org/docs/current/userguide/build_environment.html
+# Specifies the JVM arguments used for the daemon process.
+# The setting is particularly useful for tweaking memory settings.
+org.gradle.jvmargs=-Xmx1536m
+# When configured, Gradle will run in incubating parallel mode.
+# This option should only be used with decoupled projects. More details, visit
+# http://www.gradle.org/docs/current/userguide/multi_project_builds.html#sec:decoupled_projects
+# org.gradle.parallel=true
diff --git a/lite/demo/java/android/PaddlePredictor/gradle/wrapper/gradle-wrapper.jar b/lite/demo/java/android/PaddlePredictor/gradle/wrapper/gradle-wrapper.jar
new file mode 100644
index 0000000000000000000000000000000000000000..f6b961fd5a86aa5fbfe90f707c3138408be7c718
GIT binary patch (literal 54329; standard Gradle wrapper JAR, base85 payload omitted)
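Note: together with the gradle-wrapper.properties file in the next diff, the wrapper JAR above lets the demo build without a locally installed Gradle; on first run the wrapper downloads Gradle 5.1.1 as pinned by distributionUrl. A typical build sequence is sketched below (the ABI argument is illustrative and the assembleDebug task comes from the stock Android Gradle plugin, not from this patch; prepare_demo.bash appears later in this patch):

    cd lite/demo/java/android
    bash prepare_demo.bash arm64-v8a   # ABI name is an assumption; it selects the jniLibs/ subdirectory
    cd PaddlePredictor
    ./gradlew assembleDebug            # wrapper bootstraps Gradle 5.1.1, then builds the demo APK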
zG&;6q?m%NJ36K1Sq_=fdP(4f{Hop;_G_(i?sPzvB zDM}>*(uOsY0I1j^{$yn3#U(;B*g4cy$-1DTOkh3P!LQ;lJlP%jY8}Nya=h8$XD~%Y zbV&HJ%eCD9nui-0cw!+n`V~p6VCRqh5fRX z8`GbdZ@73r7~myQLBW%db;+BI?c-a>Y)m-FW~M=1^|<21_Sh9RT3iGbO{o-hpN%d6 z7%++#WekoBOP^d0$$|5npPe>u3PLvX_gjH2x(?{&z{jJ2tAOWTznPxv-pAv<*V7r$ z6&glt>7CAClWz6FEi3bToz-soY^{ScrjwVPV51=>n->c(NJngMj6TyHty`bfkF1hc zkJS%A@cL~QV0-aK4>Id!9dh7>0IV;1J9(myDO+gv76L3NLMUm9XyPauvNu$S<)-|F zZS}(kK_WnB)Cl`U?jsdYfAV4nrgzIF@+%1U8$poW&h^c6>kCx3;||fS1_7JvQT~CV zQ8Js+!p)3oW>Df(-}uqC`Tcd%E7GdJ0p}kYj5j8NKMp(KUs9u7?jQ94C)}0rba($~ zqyBx$(1ae^HEDG`Zc@-rXk1cqc7v0wibOR4qpgRDt#>-*8N3P;uKV0CgJE2SP>#8h z=+;i_CGlv+B^+$5a}SicVaSeaNn29K`C&=}`=#Nj&WJP9Xhz4mVa<+yP6hkrq1vo= z1rX4qg8dc4pmEvq%NAkpMK>mf2g?tg_1k2%v}<3`$6~Wlq@ItJ*PhHPoEh1Yi>v57 z4k0JMO)*=S`tKvR5gb-(VTEo>5Y>DZJZzgR+j6{Y`kd|jCVrg!>2hVjz({kZR z`dLlKhoqT!aI8=S+fVp(5*Dn6RrbpyO~0+?fy;bm$0jmTN|t5i6rxqr4=O}dY+ROd zo9Et|x}!u*xi~>-y>!M^+f&jc;IAsGiM_^}+4|pHRn{LThFFpD{bZ|TA*wcGm}XV^ zr*C6~@^5X-*R%FrHIgo-hJTBcyQ|3QEj+cSqp#>&t`ZzB?cXM6S(lRQw$I2?m5=wd z78ki`R?%;o%VUhXH?Z#(uwAn9$m`npJ=cA+lHGk@T7qq_M6Zoy1Lm9E0UUysN)I_x zW__OAqvku^>`J&CB=ie@yNWsaFmem}#L3T(x?a`oZ+$;3O-icj2(5z72Hnj=9Z0w% z<2#q-R=>hig*(t0^v)eGq2DHC%GymE-_j1WwBVGoU=GORGjtaqr0BNigOCqyt;O(S zKG+DoBsZU~okF<7ahjS}bzwXxbAxFfQAk&O@>LsZMsZ`?N?|CDWM(vOm%B3CBPC3o z%2t@%H$fwur}SSnckUm0-k)mOtht`?nwsDz=2#v=RBPGg39i#%odKq{K^;bTD!6A9 zskz$}t)sU^=a#jLZP@I=bPo?f-L}wpMs{Tc!m7-bi!Ldqj3EA~V;4(dltJmTXqH0r z%HAWKGutEc9vOo3P6Q;JdC^YTnby->VZ6&X8f{obffZ??1(cm&L2h7q)*w**+sE6dG*;(H|_Q!WxU{g)CeoT z(KY&bv!Usc|m+Fqfmk;h&RNF|LWuNZ!+DdX*L=s-=_iH=@i` z?Z+Okq^cFO4}_n|G*!)Wl_i%qiMBaH8(WuXtgI7EO=M>=i_+;MDjf3aY~6S9w0K zUuDO7O5Ta6+k40~xh~)D{=L&?Y0?c$s9cw*Ufe18)zzk%#ZY>Tr^|e%8KPb0ht`b( zuP@8#Ox@nQIqz9}AbW0RzE`Cf>39bOWz5N3qzS}ocxI=o$W|(nD~@EhW13Rj5nAp; zu2obEJa=kGC*#3=MkdkWy_%RKcN=?g$7!AZ8vBYKr$ePY(8aIQ&yRPlQ=mudv#q$q z4%WzAx=B{i)UdLFx4os?rZp6poShD7Vc&mSD@RdBJ=_m^&OlkEE1DFU@csgKcBifJ zz4N7+XEJhYzzO=86 z#%eBQZ$Nsf2+X0XPHUNmg#(sNt^NW1Y0|M(${e<0kW6f2q5M!2YE|hSEQ*X-%qo(V zHaFwyGZ0on=I{=fhe<=zo{=Og-_(to3?cvL4m6PymtNsdDINsBh8m>a%!5o3s(en) z=1I z6O+YNertC|OFNqd6P=$gMyvmfa`w~p9*gKDESFqNBy(~Zw3TFDYh}$iudn)9HxPBi zdokK@o~nu?%imcURr5Y~?6oo_JBe}t|pU5qjai|#JDyG=i^V~7+a{dEnO<(y>ahND#_X_fcEBNiZ)uc&%1HVtx8Ts z*H_Btvx^IhkfOB#{szN*n6;y05A>3eARDXslaE>tnLa>+`V&cgho?ED+&vv5KJszf zG4@G;7i;4_bVvZ>!mli3j7~tPgybF5|J6=Lt`u$D%X0l}#iY9nOXH@(%FFJLtzb%p zzHfABnSs;v-9(&nzbZytLiqqDIWzn>JQDk#JULcE5CyPq_m#4QV!}3421haQ+LcfO*>r;rg6K|r#5Sh|y@h1ao%Cl)t*u`4 zMTP!deC?aL7uTxm5^nUv#q2vS-5QbBKP|drbDXS%erB>fYM84Kpk^au99-BQBZR z7CDynflrIAi&ahza+kUryju5LR_}-Z27g)jqOc(!Lx9y)e z{cYc&_r947s9pteaa4}dc|!$$N9+M38sUr7h(%@Ehq`4HJtTpA>B8CLNO__@%(F5d z`SmX5jbux6i#qc}xOhumzbAELh*Mfr2SW99=WNOZRZgoCU4A2|4i|ZVFQt6qEhH#B zK_9G;&h*LO6tB`5dXRSBF0hq0tk{2q__aCKXYkP#9n^)@cq}`&Lo)1KM{W+>5mSed zKp~=}$p7>~nK@va`vN{mYzWN1(tE=u2BZhga5(VtPKk(*TvE&zmn5vSbjo zZLVobTl%;t@6;4SsZ>5+U-XEGUZGG;+~|V(pE&qqrp_f~{_1h@5ZrNETqe{bt9ioZ z#Qn~gWCH!t#Ha^n&fT2?{`}D@s4?9kXj;E;lWV9Zw8_4yM0Qg-6YSsKgvQ*fF{#Pq z{=(nyV>#*`RloBVCs;Lp*R1PBIQOY=EK4CQa*BD0MsYcg=opP?8;xYQDSAJBeJpw5 zPBc_Ft9?;<0?pBhCmOtWU*pN*;CkjJ_}qVic`}V@$TwFi15!mF1*m2wVX+>5p%(+R zQ~JUW*zWkalde{90@2v+oVlkxOZFihE&ZJ){c?hX3L2@R7jk*xjYtHi=}qb+4B(XJ z$gYcNudR~4Kz_WRq8eS((>ALWCO)&R-MXE+YxDn9V#X{_H@j616<|P(8h(7z?q*r+ zmpqR#7+g$cT@e&(%_|ipI&A%9+47%30TLY(yuf&*knx1wNx|%*H^;YB%ftt%5>QM= z^i;*6_KTSRzQm%qz*>cK&EISvF^ovbS4|R%)zKhTH_2K>jP3mBGn5{95&G9^a#4|K zv+!>fIsR8z{^x4)FIr*cYT@Q4Z{y}};rLHL+atCgHbfX*;+k&37DIgENn&=k(*lKD zG;uL-KAdLn*JQ?@r6Q!0V$xXP=J2i~;_+i3|F;_En;oAMG|I-RX#FwnmU&G}w`7R{ z788CrR-g1DW4h_`&$Z`ctN~{A)Hv_-Bl!%+pfif8wN32rMD 
zJDs$eVWBYQx1&2sCdB0!vU5~uf)=vy*{}t{2VBpcz<+~h0wb7F3?V^44*&83Z2#F` z32!rd4>uc63rQP$3lTH3zb-47IGR}f)8kZ4JvX#toIpXH`L%NnPDE~$QI1)0)|HS4 zVcITo$$oWWwCN@E-5h>N?Hua!N9CYb6f8vTFd>h3q5Jg-lCI6y%vu{Z_Uf z$MU{{^o~;nD_@m2|E{J)q;|BK7rx%`m``+OqZAqAVj-Dy+pD4-S3xK?($>wn5bi90CFAQ+ACd;&m6DQB8_o zjAq^=eUYc1o{#+p+ zn;K<)Pn*4u742P!;H^E3^Qu%2dM{2slouc$AN_3V^M7H_KY3H)#n7qd5_p~Za7zAj|s9{l)RdbV9e||_67`#Tu*c<8!I=zb@ z(MSvQ9;Wrkq6d)!9afh+G`!f$Ip!F<4ADdc*OY-y7BZMsau%y?EN6*hW4mOF%Q~bw z2==Z3^~?q<1GTeS>xGN-?CHZ7a#M4kDL zQxQr~1ZMzCSKFK5+32C%+C1kE#(2L=15AR!er7GKbp?Xd1qkkGipx5Q~FI-6zt< z*PTpeVI)Ngnnyaz5noIIgNZtb4bQdKG{Bs~&tf)?nM$a;7>r36djllw%hQxeCXeW^ z(i6@TEIuxD<2ulwLTt|&gZP%Ei+l!(%p5Yij6U(H#HMkqM8U$@OKB|5@vUiuY^d6X zW}fP3;Kps6051OEO(|JzmVU6SX(8q>*yf*x5QoxDK={PH^F?!VCzES_Qs>()_y|jg6LJlJWp;L zKM*g5DK7>W_*uv}{0WUB0>MHZ#oJZmO!b3MjEc}VhsLD~;E-qNNd?x7Q6~v zR=0$u>Zc2Xr}>x_5$-s#l!oz6I>W?lw;m9Ae{Tf9eMX;TI-Wf_mZ6sVrMnY#F}cDd z%CV*}fDsXUF7Vbw>PuDaGhu631+3|{xp<@Kl|%WxU+vuLlcrklMC!Aq+7n~I3cmQ! z`e3cA!XUEGdEPSu``&lZEKD1IKO(-VGvcnSc153m(i!8ohi`)N2n>U_BemYJ`uY>8B*Epj!oXRLV}XK}>D*^DHQ7?NY*&LJ9VSo`Ogi9J zGa;clWI8vIQqkngv2>xKd91K>?0`Sw;E&TMg&6dcd20|FcTsnUT7Yn{oI5V4@Ow~m zz#k~8TM!A9L7T!|colrC0P2WKZW7PNj_X4MfESbt<-soq*0LzShZ}fyUx!(xIIDwx zRHt^_GAWe0-Vm~bDZ(}XG%E+`XhKpPlMBo*5q_z$BGxYef8O!ToS8aT8pmjbPq)nV z%x*PF5ZuSHRJqJ!`5<4xC*xb2vC?7u1iljB_*iUGl6+yPyjn?F?GOF2_KW&gOkJ?w z3e^qc-te;zez`H$rsUCE0<@7PKGW?7sT1SPYWId|FJ8H`uEdNu4YJjre`8F*D}6Wh z|FQ`xf7yiphHIAkU&OYCn}w^ilY@o4larl?^M7&8YI;hzBIsX|i3UrLsx{QDKwCX< zy;a>yjfJ6!sz`NcVi+a!Fqk^VE^{6G53L?@Tif|j!3QZ0fk9QeUq8CWI;OmO-Hs+F zuZ4sHLA3{}LR2Qlyo+{d@?;`tpp6YB^BMoJt?&MHFY!JQwoa0nTSD+#Ku^4b{5SZVFwU9<~APYbaLO zu~Z)nS#dxI-5lmS-Bnw!(u15by(80LlC@|ynj{TzW)XcspC*}z0~8VRZq>#Z49G`I zgl|C#H&=}n-ajxfo{=pxPV(L*7g}gHET9b*s=cGV7VFa<;Htgjk>KyW@S!|z`lR1( zGSYkEl&@-bZ*d2WQ~hw3NpP=YNHF^XC{TMG$Gn+{b6pZn+5=<()>C!N^jncl0w6BJ zdHdnmSEGK5BlMeZD!v4t5m7ct7{k~$1Ie3GLFoHjAH*b?++s<|=yTF+^I&jT#zuMx z)MLhU+;LFk8bse|_{j+d*a=&cm2}M?*arjBPnfPgLwv)86D$6L zLJ0wPul7IenMvVAK$z^q5<^!)7aI|<&GGEbOr=E;UmGOIa}yO~EIr5xWU_(ol$&fa zR5E(2vB?S3EvJglTXdU#@qfDbCYs#82Yo^aZN6`{Ex#M)easBTe_J8utXu(fY1j|R z9o(sQbj$bKU{IjyhosYahY{63>}$9_+hWxB3j}VQkJ@2$D@vpeRSldU?&7I;qd2MF zSYmJ>zA(@N_iK}m*AMPIJG#Y&1KR)6`LJ83qg~`Do3v^B0>fU&wUx(qefuTgzFED{sJ65!iw{F2}1fQ3= ziFIP{kezQxmlx-!yo+sC4PEtG#K=5VM9YIN0z9~c4XTX?*4e@m;hFM!zVo>A`#566 z>f&3g94lJ{r)QJ5m7Xe3SLau_lOpL;A($wsjHR`;xTXgIiZ#o&vt~ zGR6KdU$FFbLfZCC3AEu$b`tj!9XgOGLSV=QPIYW zjI!hSP#?8pn0@ezuenOzoka8!8~jXTbiJ6+ZuItsWW03uzASFyn*zV2kIgPFR$Yzm zE<$cZlF>R8?Nr2_i?KiripBc+TGgJvG@vRTY2o?(_Di}D30!k&CT`>+7ry2!!iC*X z<@=U0_C#16=PN7bB39w+zPwDOHX}h20Ap);dx}kjXX0-QkRk=cr};GYsjSvyLZa-t zzHONWddi*)RDUH@RTAsGB_#&O+QJaaL+H<<9LLSE+nB@eGF1fALwjVOl8X_sdOYme z0lk!X=S(@25=TZHR7LlPp}fY~yNeThMIjD}pd9+q=j<_inh0$>mIzWVY+Z9p<{D^#0Xk+b_@eNSiR8;KzSZ#7lUsk~NGMcB8C2c=m2l5paHPq`q{S(kdA7Z1a zyfk2Y;w?^t`?@yC5Pz9&pzo}Hc#}mLgDmhKV|PJ3lKOY(Km@Fi2AV~CuET*YfUi}u zfInZnqDX(<#vaS<^fszuR=l)AbqG{}9{rnyx?PbZz3Pyu!eSJK`uwkJU!ORQXy4x83r!PNgOyD33}}L=>xX_93l6njNTuqL8J{l%*3FVn3MG4&Fv*`lBXZ z?=;kn6HTT^#SrPX-N)4EZiIZI!0ByXTWy;;J-Tht{jq1mjh`DSy7yGjHxIaY%*sTx zuy9#9CqE#qi>1misx=KRWm=qx4rk|}vd+LMY3M`ow8)}m$3Ggv&)Ri*ON+}<^P%T5 z_7JPVPfdM=Pv-oH<tecoE}(0O7|YZc*d8`Uv_M*3Rzv7$yZnJE6N_W=AQ3_BgU_TjA_T?a)U1csCmJ&YqMp-lJe`y6>N zt++Bi;ZMOD%%1c&-Q;bKsYg!SmS^#J@8UFY|G3!rtyaTFb!5@e(@l?1t(87ln8rG? z--$1)YC~vWnXiW3GXm`FNSyzu!m$qT=Eldf$sMl#PEfGmzQs^oUd=GIQfj(X=}dw+ zT*oa0*oS%@cLgvB&PKIQ=Ok?>x#c#dC#sQifgMwtAG^l3D9nIg(Zqi;D%807TtUUCL3_;kjyte#cAg?S%e4S2W>9^A(uy8Ss0Tc++ZTjJw1 z&Em2g!3lo@LlDyri(P^I8BPpn$RE7n*q9Q-c^>rfOMM6Pd5671I=ZBjAvpj8oIi$! 
zl0exNl(>NIiQpX~FRS9UgK|0l#s@#)p4?^?XAz}Gjb1?4Qe4?j&cL$C8u}n)?A@YC zfmbSM`Hl5pQFwv$CQBF=_$Sq zxsV?BHI5bGZTk?B6B&KLdIN-40S426X3j_|ceLla*M3}3gx3(_7MVY1++4mzhH#7# zD>2gTHy*%i$~}mqc#gK83288SKp@y3wz1L_e8fF$Rb}ex+`(h)j}%~Ld^3DUZkgez zOUNy^%>>HHE|-y$V@B}-M|_{h!vXpk01xaD%{l{oQ|~+^>rR*rv9iQen5t?{BHg|% zR`;S|KtUb!X<22RTBA4AAUM6#M?=w5VY-hEV)b`!y1^mPNEoy2K)a>OyA?Q~Q*&(O zRzQI~y_W=IPi?-OJX*&&8dvY0zWM2%yXdFI!D-n@6FsG)pEYdJbuA`g4yy;qrgR?G z8Mj7gv1oiWq)+_$GqqQ$(ZM@#|0j7})=#$S&hZwdoijFI4aCFLVI3tMH5fLreZ;KD zqA`)0l~D2tuIBYOy+LGw&hJ5OyE+@cnZ0L5+;yo2pIMdt@4$r^5Y!x7nHs{@>|W(MzJjATyWGNwZ^4j+EPU0RpAl-oTM@u{lx*i0^yyWPfHt6QwPvYpk9xFMWfBFt!+Gu6TlAmr zeQ#PX71vzN*_-xh&__N`IXv6`>CgV#eA_%e@7wjgkj8jlKzO~Ic6g$cT`^W{R{606 zCDP~+NVZ6DMO$jhL~#+!g*$T!XW63#(ngDn#Qwy71yj^gazS{e;3jGRM0HedGD@pt z?(ln3pCUA(ekqAvvnKy0G@?-|-dh=eS%4Civ&c}s%wF@0K5Bltaq^2Os1n6Z3%?-Q zAlC4goQ&vK6TpgtzkHVt*1!tBYt-`|5HLV1V7*#45Vb+GACuU+QB&hZ=N_flPy0TY zR^HIrdskB#<$aU;HY(K{a3(OQa$0<9qH(oa)lg@Uf>M5g2W0U5 zk!JSlhrw8quBx9A>RJ6}=;W&wt@2E$7J=9SVHsdC?K(L(KACb#z)@C$xXD8^!7|uv zZh$6fkq)aoD}^79VqdJ!Nz-8$IrU(_-&^cHBI;4 z^$B+1aPe|LG)C55LjP;jab{dTf$0~xbXS9!!QdcmDYLbL^jvxu2y*qnx2%jbL%rB z{aP85qBJe#(&O~Prk%IJARcdEypZ)vah%ZZ%;Zk{eW(U)Bx7VlzgOi8)x z`rh4l`@l_Ada7z&yUK>ZF;i6YLGwI*Sg#Fk#Qr0Jg&VLax(nNN$u-XJ5=MsP3|(lEdIOJ7|(x3iY;ea)5#BW*mDV%^=8qOeYO&gIdJVuLLN3cFaN=xZtFB=b zH{l)PZl_j^u+qx@89}gAQW7ofb+k)QwX=aegihossZq*+@PlCpb$rpp>Cbk9UJO<~ zDjlXQ_Ig#W0zdD3&*ei(FwlN#3b%FSR%&M^ywF@Fr>d~do@-kIS$e%wkIVfJ|Ohh=zc zF&Rnic^|>@R%v?@jO}a9;nY3Qrg_!xC=ZWUcYiA5R+|2nsM*$+c$TOs6pm!}Z}dfM zGeBhMGWw3$6KZXav^>YNA=r6Es>p<6HRYcZY)z{>yasbC81A*G-le8~QoV;rtKnkx z;+os8BvEe?0A6W*a#dOudsv3aWs?d% z0oNngyVMjavLjtjiG`!007#?62ClTqqU$@kIY`=x^$2e>iqIy1>o|@Tw@)P)B8_1$r#6>DB_5 zmaOaoE~^9TolgDgooKFuEFB#klSF%9-~d2~_|kQ0Y{Ek=HH5yq9s zDq#1S551c`kSiWPZbweN^A4kWiP#Qg6er1}HcKv{fxb1*BULboD0fwfaNM_<55>qM zETZ8TJDO4V)=aPp_eQjX%||Ud<>wkIzvDlpNjqW>I}W!-j7M^TNe5JIFh#-}zAV!$ICOju8Kx)N z0vLtzDdy*rQN!7r>Xz7rLw8J-(GzQlYYVH$WK#F`i_i^qVlzTNAh>gBWKV@XC$T-` z3|kj#iCquDhiO7NKum07i|<-NuVsX}Q}mIP$jBJDMfUiaWR3c|F_kWBMw0_Sr|6h4 zk`_r5=0&rCR^*tOy$A8K;@|NqwncjZ>Y-75vlpxq%Cl3EgH`}^^~=u zoll6xxY@a>0f%Ddpi;=cY}fyG!K2N-dEyXXmUP5u){4VnyS^T4?pjN@Ot4zjL(Puw z_U#wMH2Z#8Pts{olG5Dy0tZj;N@;fHheu>YKYQU=4Bk|wcD9MbA`3O4bj$hNRHwzb zSLcG0SLV%zywdbuwl(^E_!@&)TdXge4O{MRWk2RKOt@!8E{$BU-AH(@4{gxs=YAz9LIob|Hzto0}9cWoz6Tp2x0&xi#$ zHh$dwO&UCR1Ob2w00-2eG7d4=cN(Y>0R#$q8?||q@iTi+7-w-xR%uMr&StFIthC<# zvK(aPduwuNB}oJUV8+Zl)%cnfsHI%4`;x6XW^UF^e4s3Z@S<&EV8?56Wya;HNs0E> z`$0dgRdiUz9RO9Au3RmYq>K#G=X%*_dUbSJHP`lSfBaN8t-~@F>)BL1RT*9I851A3 z<-+Gb#_QRX>~av#Ni<#zLswtu-c6{jGHR>wflhKLzC4P@b%8&~u)fosoNjk4r#GvC zlU#UU9&0Hv;d%g72Wq?Ym<&&vtA3AB##L}=ZjiTR4hh7J)e>ei} zt*u+>h%MwN`%3}b4wYpV=QwbY!jwfIj#{me)TDOG`?tI!%l=AwL2G@9I~}?_dA5g6 zCKgK(;6Q0&P&K21Tx~k=o6jwV{dI_G+Ba*Zts|Tl6q1zeC?iYJTb{hel*x>^wb|2RkHkU$!+S4OU4ZOKPZjV>9OVsqNnv5jK8TRAE$A&^yRwK zj-MJ3Pl?)KA~fq#*K~W0l4$0=8GRx^9+?w z!QT8*-)w|S^B0)ZeY5gZPI2G(QtQf?DjuK(s^$rMA!C%P22vynZY4SuOE=wX2f8$R z)A}mzJi4WJnZ`!bHG1=$lwaxm!GOnRbR15F$nRC-M*H<*VfF|pQw(;tbSfp({>9^5 zw_M1-SJ9eGF~m(0dvp*P8uaA0Yw+EkP-SWqu zqal$hK8SmM7#Mrs0@OD+%_J%H*bMyZiWAZdsIBj#lkZ!l2c&IpLu(5^T0Ge5PHzR} zn;TXs$+IQ_&;O~u=Jz+XE0wbOy`=6>m9JVG} zJ~Kp1e5m?K3x@@>!D)piw^eMIHjD4RebtR`|IlckplP1;r21wTi8v((KqNqn%2CB< zifaQc&T}*M&0i|LW^LgdjIaX|o~I$`owHolRqeH_CFrqCUCleN130&vH}dK|^kC>) z-r2P~mApHotL4dRX$25lIcRh_*kJaxi^%ZN5-GAAMOxfB!6flLPY-p&QzL9TE%ho( zRwftE3sy5<*^)qYzKkL|rE>n@hyr;xPqncY6QJ8125!MWr`UCWuC~A#G1AqF1@V$kv>@NBvN&2ygy*{QvxolkRRb%Ui zsmKROR%{*g*WjUUod@@cS^4eF^}yQ1>;WlGwOli z+Y$(8I`0(^d|w>{eaf!_BBM;NpCoeem2>J}82*!em=}}ymoXk>QEfJ>G(3LNA2-46 
z5PGvjr)Xh9>aSe>vEzM*>xp{tJyZox1ZRl}QjcvX2TEgNc^(_-hir@Es>NySoa1g^ zFow_twnHdx(j?Q_3q51t3XI7YlJ4_q&(0#)&a+RUy{IcBq?)eaWo*=H2UUVIqtp&lW9JTJiP&u zw8+4vo~_IJXZIJb_U^&=GI1nSD%e;P!c{kZALNCm5c%%oF+I3DrA63_@4)(v4(t~JiddILp7jmoy+>cD~ivwoctFfEL zP*#2Rx?_&bCpX26MBgp^4G>@h`Hxc(lnqyj!*t>9sOBcXN(hTwEDpn^X{x!!gPX?1 z*uM$}cYRwHXuf+gYTB}gDTcw{TXSOUU$S?8BeP&sc!Lc{{pEv}x#ELX>6*ipI1#>8 zKes$bHjiJ1OygZge_ak^Hz#k;=od1wZ=o71ba7oClBMq>Uk6hVq|ePPt)@FM5bW$I z;d2Or@wBjbTyZj|;+iHp%Bo!Vy(X3YM-}lasMItEV_QrP-Kk_J4C>)L&I3Xxj=E?| zsAF(IfVQ4w+dRRnJ>)}o^3_012YYgFWE)5TT=l2657*L8_u1KC>Y-R{7w^S&A^X^U}h20jpS zQsdeaA#WIE*<8KG*oXc~$izYilTc#z{5xhpXmdT-YUnGh9v4c#lrHG6X82F2-t35} zB`jo$HjKe~E*W$=g|j&P>70_cI`GnOQ;Jp*JK#CT zuEGCn{8A@bC)~0%wsEv?O^hSZF*iqjO~_h|>xv>PO+?525Nw2472(yqS>(#R)D7O( zg)Zrj9n9$}=~b00=Wjf?E418qP-@8%MQ%PBiCTX=$B)e5cHFDu$LnOeJ~NC;xmOk# z>z&TbsK>Qzk)!88lNI8fOE2$Uxso^j*1fz>6Ot49y@=po)j4hbTIcVR`ePHpuJSfp zxaD^Dn3X}Na3@<_Pc>a;-|^Pon(>|ytG_+U^8j_JxP=_d>L$Hj?|0lz>_qQ#a|$+( z(x=Lipuc8p4^}1EQhI|TubffZvB~lu$zz9ao%T?%ZLyV5S9}cLeT?c} z>yCN9<04NRi~1oR)CiBakoNhY9BPnv)kw%*iv8vdr&&VgLGIs(-FbJ?d_gfbL2={- zBk4lkdPk~7+jIxd4{M(-W1AC_WcN&Oza@jZoj zaE*9Y;g83#m(OhA!w~LNfUJNUuRz*H-=$s*z+q+;snKPRm9EptejugC-@7-a-}Tz0 z@KHra#Y@OXK+KsaSN9WiGf?&jlZ!V7L||%KHP;SLksMFfjkeIMf<1e~t?!G3{n)H8 zQAlFY#QwfKuj;l@<$YDATAk;%PtD%B(0<|8>rXU< zJ66rkAVW_~Dj!7JGdGGi4NFuE?7ZafdMxIh65Sz7yQoA7fBZCE@WwysB=+`kT^LFX zz8#FlSA5)6FG9(qL3~A24mpzL@@2D#>0J7mMS1T*9UJ zvOq!!a(%IYY69+h45CE?(&v9H4FCr>gK0>mK~F}5RdOuH2{4|}k@5XpsX7+LZo^Qa4sH5`eUj>iffoBVm+ zz4Mtf`h?NW$*q1yr|}E&eNl)J``SZvTf6Qr*&S%tVv_OBpbjnA0&Vz#(;QmGiq-k! zgS0br4I&+^2mgA15*~Cd00cXLYOLA#Ep}_)eED>m+K@JTPr_|lSN}(OzFXQSBc6fM z@f-%2;1@BzhZa*LFV z-LrLmkmB%<<&jEURBEW>soaZ*rSIJNwaV%-RSaCZi4X)qYy^PxZ=oL?6N-5OGOMD2 z;q_JK?zkwQ@b3~ln&sDtT5SpW9a0q+5Gm|fpVY2|zqlNYBR}E5+ahgdj!CvK$Tlk0 z9g$5N;aar=CqMsudQV>yb4l@hN(9Jcc=1(|OHsqH6|g=K-WBd8GxZ`AkT?OO z-z_Ued-??Z*R4~L7jwJ%-`s~FK|qNAJ;EmIVDVpk{Lr7T4l{}vL)|GuUuswe9c5F| zv*5%u01hlv08?00Vpwyk*Q&&fY8k6MjOfpZfKa@F-^6d=Zv|0@&4_544RP5(s|4VPVP-f>%u(J@23BHqo2=zJ#v9g=F!cP((h zpt0|(s++ej?|$;2PE%+kc6JMmJjDW)3BXvBK!h!E`8Y&*7hS{c_Z?4SFP&Y<3evqf z9-ke+bSj$%Pk{CJlJbWwlBg^mEC^@%Ou?o>*|O)rl&`KIbHrjcpqsc$Zqt0^^F-gU2O=BusO+(Op}!jNzLMc zT;0YT%$@ClS%V+6lMTfhuzzxomoat=1H?1$5Ei7&M|gxo`~{UiV5w64Np6xV zVK^nL$)#^tjhCpTQMspXI({TW^U5h&Wi1Jl8g?P1YCV4=%ZYyjSo#5$SX&`r&1PyC zzc;uzCd)VTIih|8eNqFNeBMe#j_FS6rq81b>5?aXg+E#&$m++Gz9<+2)h=K(xtn}F ziV{rmu+Y>A)qvF}ms}4X^Isy!M&1%$E!rTO~5(p+8{U6#hWu>(Ll1}eD64Xa>~73A*538wry?v$vW z>^O#FRdbj(k0Nr&)U`Tl(4PI*%IV~;ZcI2z&rmq=(k^}zGOYZF3b2~Klpzd2eZJl> zB=MOLwI1{$RxQ7Y4e30&yOx?BvAvDkTBvWPpl4V8B7o>4SJn*+h1Ms&fHso%XLN5j z-zEwT%dTefp~)J_C8;Q6i$t!dnlh-!%haR1X_NuYUuP-)`IGWjwzAvp!9@h`kPZhf zwLwFk{m3arCdx8rD~K2`42mIN4}m%OQ|f)4kf%pL?Af5Ul<3M2fv>;nlhEPR8b)u} zIV*2-wyyD%%) zl$G@KrC#cUwoL?YdQyf9WH)@gWB{jd5w4evI& zOFF)p_D8>;3-N1z6mES!OPe>B^<;9xsh)){Cw$Vs-ez5nXS95NOr3s$IU;>VZSzKn zBvub8_J~I%(DozZW@{)Vp37-zevxMRZ8$8iRfwHmYvyjOxIOAF2FUngKj289!(uxY zaClWm!%x&teKmr^ABrvZ(ikx{{I-lEzw5&4t3P0eX%M~>$wG0ZjA4Mb&op+0$#SO_ z--R`>X!aqFu^F|a!{Up-iF(K+alKB{MNMs>e(i@Tpy+7Z-dK%IEjQFO(G+2mOb@BO zP>WHlS#fSQm0et)bG8^ZDScGnh-qRKIFz zfUdnk=m){ej0i(VBd@RLtRq3Ep=>&2zZ2%&vvf?Iex01hx1X!8U+?>ER;yJlR-2q4 z;Y@hzhEC=d+Le%=esE>OQ!Q|E%6yG3V_2*uh&_nguPcZ{q?DNq8h_2ahaP6=pP-+x zK!(ve(yfoYC+n(_+chiJ6N(ZaN+XSZ{|H{TR1J_s8x4jpis-Z-rlRvRK#U%SMJ(`C z?T2 zF(NNfO_&W%2roEC2j#v*(nRgl1X)V-USp-H|CwFNs?n@&vpRcj@W@xCJwR6@T!jt377?XjZ06=`d*MFyTdyvW!`mQm~t3luzYzvh^F zM|V}rO>IlBjZc}9Z zd$&!tthvr>5)m;5;96LWiAV0?t)7suqdh0cZis`^Pyg@?t>Ms~7{nCU;z`Xl+raSr zXpp=W1oHB*98s!Tpw=R5C)O{{Inl>9l7M*kq%#w9a$6N~v?BY2GKOVRkXYCgg*d

<5G2M1WZP5 zzqSuO91lJod(SBDDw<*sX(+F6Uq~YAeYV#2A;XQu_p=N5X+#cmu19Qk>QAnV=k!?wbk5I;tDWgFc}0NkvC*G=V+Yh1cyeJVq~9czZiDXe+S=VfL2g`LWo8om z$Y~FQc6MFjV-t1Y`^D9XMwY*U_re2R?&(O~68T&D4S{X`6JYU-pz=}ew-)V0AOUT1 zVOkHAB-8uBcRjLvz<9HS#a@X*Kc@|W)nyiSgi|u5$Md|P()%2(?olGg@ypoJwp6>m z*dnfjjWC>?_1p;%1brqZyDRR;8EntVA92EJ3ByOxj6a+bhPl z;a?m4rQAV1@QU^#M1HX)0+}A<7TCO`ZR_RzF}X9-M>cRLyN4C+lCk2)kT^3gN^`IT zNP~fAm(wyIoR+l^lQDA(e1Yv}&$I!n?&*p6?lZcQ+vGLLd~fM)qt}wsbf3r=tmVYe zl)ntf#E!P7wlakP9MXS7m0nsAmqxZ*)#j;M&0De`oNmFgi$ov#!`6^4)iQyxg5Iuj zjLAhzQ)r`^hf7`*1`Rh`X;LVBtDSz@0T?kkT1o!ijeyTGt5vc^Cd*tmNgiNo^EaWvaC8$e+nb_{W01j3%=1Y&92YacjCi>eNbwk%-gPQ@H-+4xskQ}f_c=jg^S-# zYFBDf)2?@5cy@^@FHK5$YdAK9cI;!?Jgd}25lOW%xbCJ>By3=HiK@1EM+I46A)Lsd zeT|ZH;KlCml=@;5+hfYf>QNOr^XNH%J-lvev)$Omy8MZ`!{`j>(J5cG&ZXXgv)TaF zg;cz99i$4CX_@3MIb?GL0s*8J=3`#P(jXF(_(6DXZjc@(@h&=M&JG)9&Te1?(^XMW zjjC_70|b=9hB6pKQi`S^Ls7JyJw^@P>Ko^&q8F&?>6i;#CbxUiLz1ZH4lNyd@QACd zu>{!sqjB!2Dg}pbAXD>d!3jW}=5aN0b;rw*W>*PAxm7D)aw(c*RX2@bTGEI|RRp}vw7;NR2wa;rXN{L{Q#=Fa z$x@ms6pqb>!8AuV(prv>|aU8oWV={C&$c zMa=p=CDNOC2tISZcd8~18GN5oTbKY+Vrq;3_obJlfSKRMk;Hdp1`y`&LNSOqeauR_ z^j*Ojl3Ohzb5-a49A8s|UnM*NM8tg}BJXdci5%h&;$afbmRpN0&~9rCnBA`#lG!p zc{(9Y?A0Y9yo?wSYn>iigf~KP$0*@bGZ>*YM4&D;@{<%Gg5^uUJGRrV4 z(aZOGB&{_0f*O=Oi0k{@8vN^BU>s3jJRS&CJOl3o|BE{FAA&a#2YYiX3pZz@|Go-F z|Fly;7eX2OTs>R}<`4RwpHFs9nwh)B28*o5qK1Ge=_^w0m`uJOv!=&!tzt#Save(C zgKU=Bsgql|`ui(e1KVxR`?>Dx>(rD1$iWp&m`v)3A!j5(6vBm*z|aKm*T*)mo(W;R zNGo2`KM!^SS7+*9YxTm6YMm_oSrLceqN*nDOAtagULuZl5Q<7mOnB@Hq&P|#9y{5B z!2x+2s<%Cv2Aa0+u{bjZXS);#IFPk(Ph-K7K?3i|4ro> zRbqJoiOEYo(Im^((r}U4b8nvo_>4<`)ut`24?ILnglT;Pd&U}$lV3U$F9#PD(O=yV zgNNA=GW|(E=&m_1;uaNmipQe?pon4{T=zK!N!2_CJL0E*R^XXIKf*wi!>@l}3_P9Z zF~JyMbW!+n-+>!u=A1ESxzkJy$DRuG+$oioG7(@Et|xVbJ#BCt;J43Nvj@MKvTxzy zMmjNuc#LXBxFAwIGZJk~^!q$*`FME}yKE8d1f5Mp}KHNq(@=Z8YxV}0@;YS~|SpGg$_jG7>_8WWYcVx#4SxpzlV9N4aO>K{c z$P?a_fyDzGX$Of3@ykvedGd<@-R;M^Shlj*SswJLD+j@hi_&_>6WZ}#AYLR0iWMK|A zH_NBeu(tMyG=6VO-=Pb>-Q#$F*or}KmEGg*-n?vWQREURdB#+6AvOj*I%!R-4E_2$ zU5n9m>RWs|Wr;h2DaO&mFBdDb-Z{APGQx$(L`if?C|njd*fC=rTS%{o69U|meRvu?N;Z|Y zbT|ojL>j;q*?xXmnHH#3R4O-59NV1j=uapkK7}6@Wo*^Nd#(;$iuGsb;H315xh3pl zHaJ>h-_$hdNl{+|Zb%DZH%ES;*P*v0#}g|vrKm9;j-9e1M4qX@zkl&5OiwnCz=tb6 zz<6HXD+rGIVpGtkb{Q^LIgExOm zz?I|oO9)!BOLW#krLmWvX5(k!h{i>ots*EhpvAE;06K|u_c~y{#b|UxQ*O@Ks=bca z^_F0a@61j3I(Ziv{xLb8AXQj3;R{f_l6a#H5ukg5rxwF9A$?Qp-Mo54`N-SKc}fWp z0T)-L@V$$&my;l#Ha{O@!fK4-FSA)L&3<${Hcwa7ue`=f&YsXY(NgeDU#sRlT3+9J z6;(^(sjSK@3?oMo$%L-nqy*E;3pb0nZLx6 z;h5)T$y8GXK1DS-F@bGun8|J(v-9o=42&nLJy#}M5D0T^5VWBNn$RpC zZzG6Bt66VY4_?W=PX$DMpKAI!d`INr) zkMB{XPQ<52rvWVQqgI0OL_NWxoe`xxw&X8yVftdODPj5|t}S6*VMqN$-h9)1MBe0N zYq?g0+e8fJCoAksr0af1)FYtz?Me!Cxn`gUx&|T;)695GG6HF7!Kg1zzRf_{VWv^bo81v4$?F6u2g|wxHc6eJQAg&V z#%0DnWm2Rmu71rPJ8#xFUNFC*V{+N_qqFH@gYRLZ6C?GAcVRi>^n3zQxORPG)$-B~ z%_oB?-%Zf7d*Fe;cf%tQwcGv2S?rD$Z&>QC2X^vwYjnr5pa5u#38cHCt4G3|efuci z@3z=#A13`+ztmp;%zjXwPY_aq-;isu*hecWWX_=Z8paSqq7;XYnUjK*T>c4~PR4W7 z#C*%_H&tfGx`Y$w7`dXvVhmovDnT>btmy~SLf>>~84jkoQ%cv=MMb+a{JV&t0+1`I z32g_Y@yDhKe|K^PevP~MiiVl{Ou7^Mt9{lOnXEQ`xY^6L8D$705GON{!1?1&YJEl#fTf5Z)da=yiEQ zGgtC-soFGOEBEB~ZF_{7b(76En>d}mI~XIwNw{e>=Fv)sgcw@qOsykWr?+qAOZSVrQfg}TNI ztKNG)1SRrAt6#Q?(me%)>&A_^DM`pL>J{2xu>xa$3d@90xR61TQDl@fu%_85DuUUA za9tn64?At;{`BAW6oykwntxHeDpXsV#{tmt5RqdN7LtcF4vR~_kZNT|wqyR#z^Xcd zFdymVRZvyLfTpBT>w9<)Ozv@;Yk@dOSVWbbtm^y@@C>?flP^EgQPAwsy75bveo=}T zFxl(f)s)j(0#N_>Or(xEuV(n$M+`#;Pc$1@OjXEJZumkaekVqgP_i}p`oTx;terTx zZpT+0dpUya2hqlf`SpXN{}>PfhajNk_J0`H|2<5E;U5Vh4F8er z;RxLSFgpGhkU>W?IwdW~NZTyOBrQ84H7_?gviIf71l`EETodG9a1!8e{jW?DpwjL? 
zGEM&eCzwoZt^P*8KHZ$B<%{I}>46IT%jJ3AnnB5P%D2E2Z_ z1M!vr#8r}1|KTqWA4%67ZdbMW2YJ81b(KF&SQ2L1Qn(y-=J${p?xLMx3W7*MK;LFQ z6Z`aU;;mTL4XrrE;HY*Rkh6N%?qviUGNAKiCB~!P}Z->IpO6E(gGd7I#eDuT7j|?nZ zK}I(EJ>$Kb&@338M~O+em9(L!+=0zBR;JAQesx|3?Ok90)D1aS9P?yTh6Poh8Cr4X zk3zc=f2rE7jj+aP7nUsr@~?^EGP>Q>h#NHS?F{Cn`g-gD<8F&dqOh-0sa%pfL`b+1 zUsF*4a~)KGb4te&K0}bE>z3yb8% zibb5Q%Sfiv7feb1r0tfmiMv z@^4XYwg@KZI=;`wC)`1jUA9Kv{HKe2t$WmRcR4y8)VAFjRi zaz&O7Y2tDmc5+SX(bj6yGHYk$dBkWc96u3u&F)2yEE~*i0F%t9Kg^L6MJSb&?wrXi zGSc;_rln$!^ybwYBeacEFRsVGq-&4uC{F)*Y;<0y7~USXswMo>j4?~5%Zm!m@i@-> zXzi82sa-vpU{6MFRktJy+E0j#w`f`>Lbog{zP|9~hg(r{RCa!uGe>Yl536cn$;ouH za#@8XMvS-kddc1`!1LVq;h57~zV`7IYR}pp3u!JtE6Q67 zq3H9ZUcWPm2V4IukS}MCHSdF0qg2@~ufNx9+VMjQP&exiG_u9TZAeAEj*jw($G)zL zq9%#v{wVyOAC4A~AF=dPX|M}MZV)s(qI9@aIK?Pe+~ch|>QYb+78lDF*Nxz2-vpRbtQ*F4$0fDbvNM#CCatgQ@z1+EZWrt z2dZfywXkiW=no5jus-92>gXn5rFQ-COvKyegmL=4+NPzw6o@a?wGE-1Bt;pCHe;34K%Z z-FnOb%!nH;)gX+!a3nCk?5(f1HaWZBMmmC@lc({dUah+E;NOros{?ui1zPC-Q0);w zEbJmdE$oU$AVGQPdm{?xxI_0CKNG$LbY*i?YRQ$(&;NiA#h@DCxC(U@AJ$Yt}}^xt-EC_ z4!;QlLkjvSOhdx!bR~W|Ezmuf6A#@T`2tsjkr>TvW*lFCMY>Na_v8+{Y|=MCu1P8y z89vPiH5+CKcG-5lzk0oY>~aJC_0+4rS@c@ZVKLAp`G-sJB$$)^4*A!B zmcf}lIw|VxV9NSoJ8Ag3CwN&d7`|@>&B|l9G8tXT^BDHOUPrtC70NgwN4${$k~d_4 zJ@eo6%YQnOgq$th?0{h`KnqYa$Nz@vlHw<%!C5du6<*j1nwquk=uY}B8r7f|lY+v7 zm|JU$US08ugor8E$h3wH$c&i~;guC|3-tqJy#T;v(g( zBZtPMSyv%jzf->435yM(-UfyHq_D=6;ouL4!ZoD+xI5uCM5ay2m)RPmm$I}h>()hS zO!0gzMxc`BPkUZ)WXaXam%1;)gedA7SM8~8yIy@6TPg!hR0=T>4$Zxd)j&P-pXeSF z9W`lg6@~YDhd19B9ETv(%er^Xp8Yj@AuFVR_8t*KS;6VHkEDKI#!@l!l3v6`W1`1~ zP{C@keuV4Q`Rjc08lx?zmT$e$!3esc9&$XZf4nRL(Z*@keUbk!GZi(2Bmyq*saOD? z3Q$V<*P-X1p2}aQmuMw9nSMbOzuASsxten7DKd6A@ftZ=NhJ(0IM|Jr<91uAul4JR zADqY^AOVT3a(NIxg|U;fyc#ZnSzw2cr}#a5lZ38>nP{05D)7~ad7JPhw!LqOwATXtRhK!w0X4HgS1i<%AxbFmGJx9?sEURV+S{k~g zGYF$IWSlQonq6}e;B(X(sIH|;52+(LYW}v_gBcp|x%rEAVB`5LXg_d5{Q5tMDu0_2 z|LOm$@K2?lrLNF=mr%YP|U-t)~9bqd+wHb4KuPmNK<}PK6e@aosGZK57=Zt+kcszVOSbe;`E^dN! 
ze7`ha3WUUU7(nS0{?@!}{0+-VO4A{7+nL~UOPW9_P(6^GL0h${SLtqG!} zKl~Ng5#@Sy?65wk9z*3SA`Dpd4b4T^@C8Fhd8O)k_4%0RZL5?#b~jmgU+0|DB%0Z) zql-cPC>A9HPjdOTpPC` zQwvF}uB5kG$Xr4XnaH#ruSjM*xG?_hT7y3G+8Ox`flzU^QIgb_>2&-f+XB6MDr-na zSi#S+c!ToK84<&m6sCiGTd^8pNdXo+$3^l3FL_E`0 z>8it5YIDxtTp2Tm(?}FX^w{fbfgh7>^8mtvN>9fWgFN_*a1P`Gz*dyOZF{OV7BC#j zQV=FQM5m>47xXgapI$WbPM5V`V<7J9tD)oz@d~MDoM`R^Y6-Na(lO~uvZlpu?;zw6 zVO1faor3dg#JEb5Q*gz4<W8tgC3nE2BG2jeIQs1)<{In&7hJ39x=;ih;CJDy)>0S1at*7n?Wr0ahYCpFjZ|@u91Zl7( zv;CSBRC65-6f+*JPf4p1UZ)k=XivKTX6_bWT~7V#rq0Xjas6hMO!HJN8GdpBKg_$B zwDHJF6;z?h<;GXFZan8W{XFNPpOj!(&I1`&kWO86p?Xz`a$`7qV7Xqev|7nn_lQuX ziGpU1MMYt&5dE2A62iX3;*0WzNB9*nSTzI%62A+N?f?;S>N@8M=|ef3gtQTIA*=yq zQAAjOqa!CkHOQo4?TsqrrsJLclXcP?dlAVv?v`}YUjo1Htt;6djP@NPFH+&p1I+f_ z)Y279{7OWomY8baT(4TAOlz1OyD{4P?(DGv3XyJTA2IXe=kqD)^h(@*E3{I~w;ws8 z)ZWv7E)pbEM zd3MOXRH3mQhks9 zv6{s;k0y5vrcjXaVfw8^>YyPo=oIqd5IGI{)+TZq5Z5O&hXAw%ZlL}^6FugH;-%vP zAaKFtt3i^ag226=f0YjzdPn6|4(C2sC5wHFX{7QF!tG1E-JFA`>eZ`}$ymcRJK?0c zN363o{&ir)QySOFY0vcu6)kX#;l??|7o{HBDVJN+17rt|w3;(C_1b>d;g9Gp=8YVl zYTtA52@!7AUEkTm@P&h#eg+F*lR zQ7iotZTcMR1frJ0*V@Hw__~CL>_~2H2cCtuzYIUD24=Cv!1j6s{QS!v=PzwQ(a0HS zBKx04KA}-Ue+%9d`?PG*hIij@54RDSQpA7|>qYVIrK_G6%6;#ZkR}NjUgmGju)2F`>|WJoljo)DJgZr4eo1k1i1+o z1D{>^RlpIY8OUaOEf5EBu%a&~c5aWnqM zxBpJq98f=%M^{4mm~5`CWl%)nFR64U{(chmST&2jp+-r z3675V<;Qi-kJud%oWnCLdaU-)xTnMM%rx%Jw6v@=J|Ir=4n-1Z23r-EVf91CGMGNz zb~wyv4V{H-hkr3j3WbGnComiqmS0vn?n?5v2`Vi>{Ip3OZUEPN7N8XeUtF)Ry6>y> zvn0BTLCiqGroFu|m2zG-;Xb6;W`UyLw)@v}H&(M}XCEVXZQoWF=Ykr5lX3XWwyNyF z#jHv)A*L~2BZ4lX?AlN3X#axMwOC)PoVy^6lCGse9bkGjb=qz%kDa6}MOmSwK`cVO zt(e*MW-x}XtU?GY5}9{MKhRhYOlLhJE5=ca+-RmO04^ z66z{40J=s=ey9OCdc(RCzy zd7Zr1%!y3}MG(D=wM_ebhXnJ@MLi7cImDkhm0y{d-Vm81j`0mbi4lF=eirlr)oW~a zCd?26&j^m4AeXEsIUXiTal)+SPM4)HX%%YWF1?(FV47BaA`h9m67S9x>hWMVHx~Hg z1meUYoLL(p@b3?x|9DgWeI|AJ`Ia84*P{Mb%H$ZRROouR4wZhOPX15=KiBMHl!^JnCt$Az`KiH^_d>cev&f zaG2>cWf$=A@&GP~DubsgYb|L~o)cn5h%2`i^!2)bzOTw2UR!>q5^r&2Vy}JaWFUQE04v>2;Z@ZPwXr?y&G(B^@&y zsd6kC=hHdKV>!NDLIj+3rgZJ|dF`%N$DNd;B)9BbiT9Ju^Wt%%u}SvfM^=|q-nxDG zuWCQG9e#~Q5cyf8@y76#kkR^}{c<_KnZ0QsZcAT|YLRo~&tU|N@BjxOuy`#>`X~Q< z?R?-Gsk$$!oo(BveQLlUrcL#eirhgBLh`qHEMg`+sR1`A=1QX7)ZLMRT+GBy?&mM8 zQG^z-!Oa&J-k7I(3_2#Q6Bg=NX<|@X&+YMIOzfEO2$6Mnh}YV!m!e^__{W@-CTprr zbdh3f=BeCD$gHwCrmwgM3LAv3!Mh$wM)~KWzp^w)Cu6roO7uUG5z*}i0_0j47}pK; ztN530`ScGatLOL06~zO)Qmuv`h!gq5l#wx(EliKe&rz-5qH(hb1*fB#B+q`9=jLp@ zOa2)>JTl7ovxMbrif`Xe9;+fqB1K#l=Dv!iT;xF zdkCvS>C5q|O;}ns3AgoE({Ua-zNT-9_5|P0iANmC6O76Sq_(AN?UeEQJ>#b54fi3k zFmh+P%b1x3^)0M;QxXLP!BZ^h|AhOde*{9A=f3|Xq*JAs^Y{eViF|=EBfS6L%k4ip zk+7M$gEKI3?bQg?H3zaE@;cyv9kv;cqK$VxQbFEsy^iM{XXW0@2|DOu$!-k zSFl}Y=jt-VaT>Cx*KQnHTyXt}f9XswFB9ibYh+k2J!ofO+nD?1iw@mwtrqI4_i?nE zhLkPp41ED62me}J<`3RN80#vjW;wt`pP?%oQ!oqy7`miL>d-35a=qotK$p{IzeSk# ze_$CFYp_zIkrPFVaW^s#U4xT1lI^A0IBe~Y<4uS%zSV=wcuLr%gQT=&5$&K*bwqx| zWzCMiz>7t^Et@9CRUm9E+@hy~sBpm9fri$sE1zgLU((1?Yg{N1Sars=DiW&~Zw=3I zi7y)&oTC?UWD2w97xQ&5vx zRXEBGeJ(I?Y}eR0_O{$~)bMJRTsNUPIfR!xU9PE7A>AMNr_wbrFK>&vVw=Y;RH zO$mlpmMsQ}-FQ2cSj7s7GpC+~^Q~dC?y>M}%!-3kq(F3hGWo9B-Gn02AwUgJ>Z-pKOaj zysJBQx{1>Va=*e@sLb2z&RmQ7ira;aBijM-xQ&cpR>X3wP^foXM~u1>sv9xOjzZpX z0K;EGouSYD~oQ&lAafj3~EaXfFShC+>VsRlEMa9cg9i zFxhCKO}K0ax6g4@DEA?dg{mo>s+~RPI^ybb^u--^nTF>**0l5R9pocwB?_K)BG_)S zyLb&k%XZhBVr7U$wlhMqwL)_r&&n%*N$}~qijbkfM|dIWP{MyLx}X&}ES?}7i;9bW zmTVK@zR)7kE2+L42Q`n4m0VVg5l5(W`SC9HsfrLZ=v%lpef=Gj)W59VTLe+Z$8T8i z4V%5+T0t8LnM&H>Rsm5C%qpWBFqgTwL{=_4mE{S3EnBXknM&u8n}A^IIM4$s3m(Rd z>zq=CP-!9p9es2C*)_hoL@tDYABn+o#*l;6@7;knWIyDrt5EuakO99S$}n((Fj4y} zD!VvuRzghcE{!s;jC*<_H$y6!6QpePo2A3ZbX*ZzRnQq*b%KK^NF^z96CHaWmzU@f 
z#j;y?X=UP&+YS3kZx7;{ zDA{9(wfz7GF`1A6iB6fnXu0?&d|^p|6)%3$aG0Uor~8o? z*e}u#qz7Ri?8Uxp4m_u{a@%bztvz-BzewR6bh*1Xp+G=tQGpcy|4V_&*aOqu|32CM zz3r*E8o8SNea2hYJpLQ-_}R&M9^%@AMx&`1H8aDx4j%-gE+baf2+9zI*+Pmt+v{39 zDZ3Ix_vPYSc;Y;yn68kW4CG>PE5RoaV0n@#eVmk?p$u&Fy&KDTy!f^Hy6&^-H*)#u zdrSCTJPJw?(hLf56%2;_3n|ujUSJOU8VPOTlDULwt0jS@j^t1WS z!n7dZIoT+|O9hFUUMbID4Ec$!cc($DuQWkocVRcYSikFeM&RZ=?BW)mG4?fh#)KVG zcJ!<=-8{&MdE)+}?C8s{k@l49I|Zwswy^ZN3;E!FKyglY~Aq?4m74P-0)sMTGXqd5(S<-(DjjM z&7dL-Mr8jhUCAG$5^mI<|%`;JI5FVUnNj!VO2?Jiqa|c2;4^n!R z`5KK0hyB*F4w%cJ@Un6GC{mY&r%g`OX|1w2$B7wxu97%<@~9>NlXYd9RMF2UM>(z0 zouu4*+u+1*k;+nFPk%ly!nuMBgH4sL5Z`@Rok&?Ef=JrTmvBAS1h?C0)ty5+yEFRz zY$G=coQtNmT@1O5uk#_MQM1&bPPnspy5#>=_7%WcEL*n$;sSAZcXxMpcXxLe;_mLA z5F_paad+bGZV*oh@8h0(|D2P!q# zTHjmiphJ=AazSeKQPkGOR-D8``LjzToyx{lfK-1CDD6M7?pMZOdLKFtjZaZMPk4}k zW)97Fh(Z+_Fqv(Q_CMH-YYi?fR5fBnz7KOt0*t^cxmDoIokc=+`o# zrud|^h_?KW=Gv%byo~(Ln@({?3gnd?DUf-j2J}|$Mk>mOB+1{ZQ8HgY#SA8END(Zw z3T+W)a&;OO54~m}ffemh^oZ!Vv;!O&yhL0~hs(p^(Yv=(3c+PzPXlS5W79Er8B1o* z`c`NyS{Zj_mKChj+q=w)B}K za*zzPhs?c^`EQ;keH{-OXdXJet1EsQ)7;{3eF!-t^4_Srg4(Ot7M*E~91gwnfhqaM zNR7dFaWm7MlDYWS*m}CH${o?+YgHiPC|4?X?`vV+ws&Hf1ZO-w@OGG^o4|`b{bLZj z&9l=aA-Y(L11!EvRjc3Zpxk7lc@yH1e$a}8$_-r$)5++`_eUr1+dTb@ zU~2P1HM#W8qiNN3b*=f+FfG1!rFxnNlGx{15}BTIHgxO>Cq4 z;#9H9YjH%>Z2frJDJ8=xq>Z@H%GxXosS@Z>cY9ppF+)e~t_hWXYlrO6)0p7NBMa`+ z^L>-#GTh;k_XnE)Cgy|0Dw;(c0* zSzW14ZXozu)|I@5mRFF1eO%JM=f~R1dkNpZM+Jh(?&Zje3NgM{2ezg1N`AQg5%+3Y z64PZ0rPq6;_)Pj-hyIOgH_Gh`1$j1!jhml7ksHA1`CH3FDKiHLz+~=^u@kUM{ilI5 z^FPiJ7mSrzBs9{HXi2{sFhl5AyqwUnU{sPcUD{3+l-ZHAQ)C;c$=g1bdoxeG(5N01 zZy=t8i{*w9m?Y>V;uE&Uy~iY{pY4AV3_N;RL_jT_QtLFx^KjcUy~q9KcLE3$QJ{!)@$@En{UGG7&}lc*5Kuc^780;7Bj;)X?1CSy*^^ zPP^M)Pr5R>mvp3_hmCtS?5;W^e@5BjE>Cs<`lHDxj<|gtOK4De?Sf0YuK5GX9G93i zMYB{8X|hw|T6HqCf7Cv&r8A$S@AcgG1cF&iJ5=%+x;3yB`!lQ}2Hr(DE8=LuNb~Vs z=FO&2pdc16nD$1QL7j+!U^XWTI?2qQKt3H8=beVTdHHa9=MiJ&tM1RRQ-=+vy!~iz zj3O{pyRhCQ+b(>jC*H)J)%Wq}p>;?@W*Eut@P&?VU+Sdw^4kE8lvX|6czf{l*~L;J zFm*V~UC;3oQY(ytD|D*%*uVrBB}BbAfjK&%S;z;7$w68(8PV_whC~yvkZmX)xD^s6 z{$1Q}q;99W?*YkD2*;)tRCS{q2s@JzlO~<8x9}X<0?hCD5vpydvOw#Z$2;$@cZkYrp83J0PsS~!CFtY%BP=yxG?<@#{7%2sy zOc&^FJxsUYN36kSY)d7W=*1-{7ghPAQAXwT7z+NlESlkUH&8ODlpc8iC*iQ^MAe(B z?*xO4i{zFz^G=^G#9MsLKIN64rRJykiuIVX5~0#vAyDWc9-=6BDNT_aggS2G{B>dD ze-B%d3b6iCfc5{@yz$>=@1kdK^tX9qh0=ocv@9$ai``a_ofxT=>X7_Y0`X}a^M?d# z%EG)4@`^Ej_=%0_J-{ga!gFtji_byY&Vk@T1c|ucNAr(JNr@)nCWj?QnCyvXg&?FW;S-VOmNL6^km_dqiVjJuIASVGSFEos@EVF7St$WE&Z%)`Q##+0 zjaZ=JI1G@0!?l|^+-ZrNd$WrHBi)DA0-Eke>dp=_XpV<%CO_Wf5kQx}5e<90dt>8k zAi00d0rQ821nA>B4JHN7U8Zz=0;9&U6LOTKOaC1FC8GgO&kc=_wHIOGycL@c*$`ce703t%>S}mvxEnD-V!;6c`2(p74V7D0No1Xxt`urE66$0(ThaAZ1YVG#QP$ zy~NN%kB*zhZ2Y!kjn826pw4bh)75*e!dse+2Db(;bN34Uq7bLpr47XTX{8UEeC?2i z*{$`3dP}32${8pF$!$2Vq^gY|#w+VA_|o(oWmQX8^iw#n_crb(K3{69*iU?<%C-%H zuKi)3M1BhJ@3VW>JA`M>L~5*_bxH@Euy@niFrI$82C1}fwR$p2E&ZYnu?jlS}u7W9AyfdXh2pM>78bIt3 z)JBh&XE@zA!kyCDfvZ1qN^np20c1u#%P6;6tU&dx0phT1l=(mw7`u!-0e=PxEjDds z9E}{E!7f9>jaCQhw)&2TtG-qiD)lD(4jQ!q{`x|8l&nmtHkdul# zy+CIF8lKbp9_w{;oR+jSLtTfE+B@tOd6h=QePP>rh4@~!8c;Hlg9m%%&?e`*Z?qz5-zLEWfi>`ord5uHF-s{^bexKAoMEV@9nU z^5nA{f{dW&g$)BAGfkq@r5D)jr%!Ven~Q58c!Kr;*Li#`4Bu_?BU0`Y`nVQGhNZk@ z!>Yr$+nB=`z#o2nR0)V3M7-eVLuY`z@6CT#OTUXKnxZn$fNLPv7w1y7eGE=Qv@Hey`n;`U=xEl|q@CCV^#l)s0ZfT+mUf z^(j5r4)L5i2jnHW4+!6Si3q_LdOLQi<^fu?6WdohIkn79=jf%Fs3JkeXwF(?_tcF? 
z?z#j6iXEd(wJy4|p6v?xNk-)iIf2oX5^^Y3q3ziw16p9C6B;{COXul%)`>nuUoM*q zzmr|NJ5n)+sF$!yH5zwp=iM1#ZR`O%L83tyog-qh1I z0%dcj{NUs?{myT~33H^(%0QOM>-$hGFeP;U$puxoJ>>o-%Lk*8X^rx1>j|LtH$*)>1C!Pv&gd16%`qw5LdOIUbkNhaBBTo}5iuE%K&ZV^ zAr_)kkeNKNYJRgjsR%vexa~&8qMrQYY}+RbZ)egRg9_$vkoyV|Nc&MH@8L)`&rpqd zXnVaI@~A;Z^c3+{x=xgdhnocA&OP6^rr@rTvCnhG6^tMox$ulw2U7NgUtW%|-5VeH z_qyd47}1?IbuKtqNbNx$HR`*+9o=8`%vM8&SIKbkX9&%TS++x z5|&6P<%=F$C?owUI`%uvUq^yW0>`>yz!|WjzsoB9dT;2Dx8iSuK%%_XPgy0dTD4kd zDXF@&O_vBVVKQq(9YTClUPM30Sk7B!v7nOyV`XC!BA;BIVwphh+c)?5VJ^(C;GoQ$ zvBxr7_p*k$T%I1ke}`U&)$uf}I_T~#3XTi53OX)PoXVgxEcLJgZG^i47U&>LY(l%_ z;9vVDEtuMCyu2fqZeez|RbbIE7@)UtJvgAcVwVZNLccswxm+*L&w`&t=ttT=sv6Aq z!HouSc-24Y9;0q$>jX<1DnnGmAsP))- z^F~o99gHZw`S&Aw7e4id6Lg7kMk-e)B~=tZ!kE7sGTOJ)8@q}np@j7&7Sy{2`D^FH zI7aX%06vKsfJ168QnCM2=l|i>{I{%@gcr>ExM0Dw{PX6ozEuqFYEt z087%MKC;wVsMV}kIiuu9Zz9~H!21d!;Cu#b;hMDIP7nw3xSX~#?5#SSjyyg+Y@xh| z%(~fv3`0j#5CA2D8!M2TrG=8{%>YFr(j)I0DYlcz(2~92?G*?DeuoadkcjmZszH5& zKI@Lis%;RPJ8mNsbrxH@?J8Y2LaVjUIhRUiO-oqjy<&{2X~*f|)YxnUc6OU&5iac= z*^0qwD~L%FKiPmlzi&~a*9sk2$u<7Al=_`Ox^o2*kEv?p`#G(p(&i|ot8}T;8KLk- zPVf_4A9R`5^e`Om2LV*cK59EshYXse&IoByj}4WZaBomoHAPKqxRKbPcD`lMBI)g- zeMRY{gFaUuecSD6q!+b5(?vAnf>c`Z(8@RJy%Ulf?W~xB1dFAjw?CjSn$ph>st5bc zUac1aD_m6{l|$#g_v6;=32(mwpveQDWhmjR7{|B=$oBhz`7_g7qNp)n20|^^op3 zSfTdWV#Q>cb{CMKlWk91^;mHap{mk)o?udk$^Q^^u@&jd zfZ;)saW6{e*yoL6#0}oVPb2!}r{pAUYtn4{P~ES9tTfC5hXZnM{HrC8^=Pof{G4%Bh#8 ze~?C9m*|fd8MK;{L^!+wMy>=f^8b&y?yr6KnTq28$pFMBW9Oy7!oV5z|VM$s-cZ{I|Xf@}-)1=$V&x7e;9v81eiTi4O5-vs?^5pCKy2l>q);!MA zS!}M48l$scB~+Umz}7NbwyTn=rqt@`YtuwiQSMvCMFk2$83k50Q>OK5&fe*xCddIm)3D0I6vBU<+!3=6?(OhkO|b4fE_-j zimOzyfBB_*7*p8AmZi~X2bgVhyPy>KyGLAnOpou~sx9)S9%r)5dE%ADs4v%fFybDa_w*0?+>PsEHTbhKK^G=pFz z@IxLTCROWiKy*)cV3y%0FwrDvf53Ob_XuA1#tHbyn%Ko!1D#sdhBo`;VC*e1YlhrC z?*y3rp86m#qI|qeo8)_xH*G4q@70aXN|SP+6MQ!fJQqo1kwO_v7zqvUfU=Gwx`CR@ zRFb*O8+54%_8tS(ADh}-hUJzE`s*8wLI>1c4b@$al)l}^%GuIXjzBK!EWFO8W`>F^ ze7y#qPS0NI7*aU)g$_ziF(1ft;2<}6Hfz10cR8P}67FD=+}MfhrpOkF3hFhQu;Q1y zu%=jJHTr;0;oC94Hi@LAF5quAQ(rJG(uo%BiRQ@8U;nhX)j0i?0SL2g-A*YeAqF>RVCBOTrn{0R27vu}_S zS>tX4!#&U4W;ikTE!eFH+PKw%p+B(MR2I%n#+m0{#?qRP_tR@zpgCb=4rcrL!F=;A zh%EIF8m6%JG+qb&mEfuFTLHSxUAZEvC-+kvZKyX~SA3Umt`k}}c!5dy?-sLIM{h@> z!2=C)@nx>`;c9DdwZ&zeUc(7t<21D7qBj!|1^Mp1eZ6)PuvHx+poKSDCSBMFF{bKy z;9*&EyKitD99N}%mK8431rvbT+^%|O|HV23{;RhmS{$5tf!bIPoH9RKps`-EtoW5h zo6H_!s)Dl}2gCeGF6>aZtah9iLuGd19^z0*OryPNt{70RvJSM<#Ox9?HxGg04}b^f zrVEPceD%)#0)v5$YDE?f`73bQ6TA6wV;b^x*u2Ofe|S}+q{s5gr&m~4qGd!wOu|cZ||#h_u=k*fB;R6&k?FoM+c&J;ISg70h!J7*xGus)ta4veTdW)S^@sU@ z4$OBS=a~@F*V0ECic;ht4@?Jw<9kpjBgHfr2FDPykCCz|v2)`JxTH55?b3IM={@DU z!^|9nVO-R#s{`VHypWyH0%cs;0GO3E;It6W@0gX6wZ%W|Dzz&O%m17pa19db(er}C zUId1a4#I+Ou8E1MU$g=zo%g7K(=0Pn$)Rk z<4T2u<0rD)*j+tcy2XvY+0 z0d2pqm4)4lDewsAGThQi{2Kc3&C=|OQF!vOd#WB_`4gG3@inh-4>BoL!&#ij8bw7? 
zqjFRDaQz!J-YGitV4}$*$hg`vv%N)@#UdzHFI2E<&_@0Uw@h_ZHf}7)G;_NUD3@18 zH5;EtugNT0*RXVK*by>WS>jaDDfe!A61Da=VpIK?mcp^W?!1S2oah^wowRnrYjl~`lgP-mv$?yb6{{S55CCu{R z$9;`dyf0Y>uM1=XSl_$01Lc1Iy68IosWN8Q9Op=~I(F<0+_kKfgC*JggjxNgK6 z-3gQm6;sm?J&;bYe&(dx4BEjvq}b`OT^RqF$J4enP1YkeBK#>l1@-K`ajbn05`0J?0daOtnzh@l3^=BkedW1EahZlRp;`j*CaT;-21&f2wU z+Nh-gc4I36Cw+;3UAc<%ySb`#+c@5y ze~en&bYV|kn?Cn|@fqmGxgfz}U!98$=drjAkMi`43I4R%&H0GKEgx-=7PF}y`+j>r zg&JF`jomnu2G{%QV~Gf_-1gx<3Ky=Md9Q3VnK=;;u0lyTBCuf^aUi?+1+`4lLE6ZK zT#(Bf`5rmr(tgTbIt?yA@y`(Ar=f>-aZ}T~>G32EM%XyFvhn&@PWCm#-<&ApLDCXT zD#(9m|V(OOo7PmE@`vD4$S5;+9IQm19dd zvMEU`)E1_F+0o0-z>YCWqg0u8ciIknU#{q02{~YX)gc_u;8;i233D66pf(IkTDxeN zL=4z2)?S$TV9=ORVr&AkZMl<4tTh(v;Ix1{`pPVqI3n2ci&4Dg+W|N8TBUfZ*WeLF zqCH_1Q0W&f9T$lx3CFJ$o@Lz$99 zW!G&@zFHxTaP!o#z^~xgF|(vrHz8R_r9eo;TX9}2ZyjslrtH=%6O)?1?cL&BT(Amp zTGFU1%%#xl&6sH-UIJk_PGk_McFn7=%yd6tAjm|lnmr8bE2le3I~L{0(ffo}TQjyo zHZZI{-}{E4ohYTlZaS$blB!h$Jq^Rf#(ch}@S+Ww&$b);8+>g84IJcLU%B-W?+IY& zslcZIR>+U4v3O9RFEW;8NpCM0w1ROG84=WpKxQ^R`{=0MZCubg3st z48AyJNEvyxn-jCPTlTwp4EKvyEwD3e%kpdY?^BH0!3n6Eb57_L%J1=a*3>|k68A}v zaW`*4YitylfD}ua8V)vb79)N_Ixw_mpp}yJGbNu+5YYOP9K-7nf*jA1#<^rb4#AcS zKg%zCI)7cotx}L&J8Bqo8O1b0q;B1J#B5N5Z$Zq=wX~nQFgUfAE{@u0+EnmK{1hg> zC{vMfFLD;L8b4L+B51&LCm|scVLPe6h02rws@kGv@R+#IqE8>Xn8i|vRq_Z`V;x6F zNeot$1Zsu`lLS92QlLWF54za6vOEKGYQMdX($0JN*cjG7HP&qZ#3+bEN$8O_PfeAb z0R5;=zXac2IZ?fxu59?Nka;1lKm|;0)6|#RxkD05P5qz;*AL@ig!+f=lW5^Jbag%2 z%9@iM0ph$WFlxS!`p31t92z~TB}P-*CS+1Oo_g;7`6k(Jyj8m8U|Q3Sh7o-Icp4kV zK}%qri5>?%IPfamXIZ8pXbm-#{ytiam<{a5A+3dVP^xz!Pvirsq7Btv?*d7eYgx7q zWFxrzb3-%^lDgMc=Vl7^={=VDEKabTG?VWqOngE`Kt7hs236QKidsoeeUQ_^FzsXjprCDd@pW25rNx#6x&L6ZEpoX9Ffzv@olnH3rGOSW( zG-D|cV0Q~qJ>-L}NIyT?T-+x+wU%;+_GY{>t(l9dI%Ximm+Kmwhee;FK$%{dnF;C% zFjM2&$W68Sz#d*wtfX?*WIOXwT;P6NUw}IHdk|)fw*YnGa0rHx#paG!m=Y6GkS4VX zX`T$4eW9k1W!=q8!(#8A9h67fw))k_G)Q9~Q1e3f`aV@kbcSv7!priDUN}gX(iXTy zr$|kU0Vn%*ylmyDCO&G0Z3g>%JeEPFAW!5*H2Ydl>39w3W+gEUjL&vrRs(xGP{(ze zy7EMWF14@Qh>X>st8_029||TP0>7SG9on_xxeR2Iam3G~Em$}aGsNt$iES9zFa<3W zxtOF*!G@=PhfHO!=9pVPXMUVi30WmkPoy$02w}&6A7mF)G6-`~EVq5CwD2`9Zu`kd)52``#V zNSb`9dG~8(dooi1*-aSMf!fun7Sc`-C$-E(3BoSC$2kKrVcI!&yC*+ff2+C-@!AT_ zsvlAIV+%bRDfd{R*TMF><1&_a%@yZ0G0lg2K;F>7b+7A6pv3-S7qWIgx+Z?dt8}|S z>Qbb6x(+^aoV7FQ!Ph8|RUA6vXWQH*1$GJC+wXLXizNIc9p2yLzw9 z0=MdQ!{NnOwIICJc8!+Jp!zG}**r#E!<}&Te&}|B4q;U57$+pQI^}{qj669zMMe_I z&z0uUCqG%YwtUc8HVN7?0GHpu=bL7&{C>hcd5d(iFV{I5c~jpX&!(a{yS*4MEoYXh z*X4|Y@RVfn;piRm-C%b@{0R;aXrjBtvx^HO;6(>i*RnoG0Rtcd25BT6edxTNOgUAOjn zJ2)l{ipj8IP$KID2}*#F=M%^n&=bA0tY98@+2I+7~A&T-tw%W#3GV>GTmkHaqftl)#+E zMU*P(Rjo>8%P@_@#UNq(_L{}j(&-@1iY0TRizhiATJrnvwSH0v>lYfCI2ex^><3$q znzZgpW0JlQx?JB#0^^s-Js1}}wKh6f>(e%NrMwS`Q(FhazkZb|uyB@d%_9)_xb$6T zS*#-Bn)9gmobhAtvBmL+9H-+0_0US?g6^TOvE8f3v=z3o%NcPjOaf{5EMRnn(_z8- z$|m0D$FTU zDy;21v-#0i)9%_bZ7eo6B9@Q@&XprR&oKl4m>zIj-fiRy4Dqy@VVVs?rscG| zmzaDQ%>AQTi<^vYCmv#KOTd@l7#2VIpsj?nm_WfRZzJako`^uU%Nt3e;cU*y*|$7W zLm%fX#i_*HoUXu!NI$ey>BA<5HQB=|nRAwK!$L#n-Qz;~`zACig0PhAq#^5QS<8L2 zS3A+8%vbVMa7LOtTEM?55apt(DcWh#L}R^P2AY*c8B}Cx=6OFAdMPj1f>k3#^#+Hk z6uW1WJW&RlBRh*1DLb7mJ+KO>!t^t8hX1#_Wk`gjDio9)9IGbyCAGI4DJ~orK+YRv znjxRMtshZQHc$#Y-<-JOV6g^Cr@odj&Xw5B(FmI)*qJ9NHmIz_r{t)TxyB`L-%q5l ztzHgD;S6cw?7Atg*6E1!c6*gPRCb%t7D%z<(xm+K{%EJNiI2N0l8ud0Ch@_av_RW? zIr!nO4dL5466WslE6MsfMss7<)-S!e)2@r2o=7_W)OO`~CwklRWzHTfpB)_HYwgz=BzLhgZ9S<{nLBOwOIgJU=94uj6r!m>Xyn9>&xP+=5!zG_*yEoRgM0`aYts z^)&8(>z5C-QQ*o_s(8E4*?AX#S^0)aqB)OTyX>4BMy8h(cHjA8ji1PRlox@jB*1n? 
zDIfyDjzeg91Ao(;Q;KE@zei$}>EnrF6I}q&Xd=~&$WdDsyH0H7fJX|E+O~%LS*7^Q zYzZ4`pBdY{b7u72gZm6^5~O-57HwzwAz{)NvVaowo`X02tL3PpgLjwA`^i9F^vSpN zAqH3mRjG8VeJNHZ(1{%!XqC+)Z%D}58Qel{_weSEHoygT9pN@i zi=G;!Vj6XQk2tuJC>lza%ywz|`f7TIz*EN2Gdt!s199Dr4Tfd_%~fu8gXo~|ogt5Q zlEy_CXEe^BgsYM^o@L?s33WM14}7^T(kqohOX_iN@U?u;$l|rAvn{rwy>!yfZw13U zB@X9)qt&4;(C6dP?yRsoTMI!j-f1KC!<%~i1}u7yLXYn)(#a;Z6~r>hp~kfP));mi zcG%kdaB9H)z9M=H!f>kM->fTjRVOELNwh1amgKQT=I8J66kI)u_?0@$$~5f`u%;zl zC?pkr^p2Fe=J~WK%4ItSzKA+QHqJ@~m|Cduv=Q&-P8I5rQ-#G@bYH}YJr zUS(~(w|vKyU(T(*py}jTUp%I%{2!W!K(i$uvotcPjVddW z8_5HKY!oBCwGZcs-q`4Yt`Zk~>K?mcxg51wkZlX5e#B08I75F7#dgn5yf&Hrp`*%$ zQ;_Qg>TYRzBe$x=T(@WI9SC!ReSas9vDm(yslQjBJZde5z8GDU``r|N(MHcxNopGr z_}u39W_zwWDL*XYYt>#Xo!9kL#97|EAGyGBcRXtLTd59x%m=3i zL^9joWYA)HfL15l9%H?q`$mY27!<9$7GH(kxb%MV>`}hR4a?+*LH6aR{dzrX@?6X4 z3e`9L;cjqYb`cJmophbm(OX0b)!AFG?5`c#zLagzMW~o)?-!@e80lvk!p#&CD8u5_r&wp4O0zQ>y!k5U$h_K;rWGk=U)zX!#@Q%|9g*A zWx)qS1?fq6X<$mQTB$#3g;;5tHOYuAh;YKSBz%il3Ui6fPRv#v62SsrCdMRTav)Sg zTq1WOu&@v$Ey;@^+_!)cf|w_X<@RC>!=~+A1-65O0bOFYiH-)abINwZvFB;hJjL_$ z(9iScmUdMp2O$WW!520Hd0Q^Yj?DK%YgJD^ez$Z^?@9@Ab-=KgW@n8nC&88)TDC+E zlJM)L3r+ZJfZW_T$;Imq*#2<(j+FIk8ls7)WJ6CjUu#r5PoXxQs4b)mZza<8=v{o)VlLRM<9yw^0En#tXAj`Sylxvki{<1DPe^ zhjHwx^;c8tb?Vr$6ZB;$Ff$+3(*oinbwpN-#F)bTsXq@Sm?43MC#jQ~`F|twI=7oC zH4TJtu#;ngRA|Y~w5N=UfMZi?s0%ZmKUFTAye&6Y*y-%c1oD3yQ%IF2q2385Zl+=> zfz=o`Bedy|U;oxbyb^rB9ixG{Gb-{h$U0hVe`J;{ql!s_OJ_>>eoQn(G6h7+b^P48 zG<=Wg2;xGD-+d@UMZ!c;0>#3nws$9kIDkK13IfloGT@s14AY>&>>^#>`PT7GV$2Hp zN<{bN*ztlZu_%W=&3+=#3bE(mka6VoHEs~0BjZ$+=0`a@R$iaW)6>wp2w)=v2@|2d z%?34!+iOc5S@;AAC4hELWLH56RGxo4jw8MDMU0Wk2k_G}=Vo(>eRFo(g3@HjG|`H3 zm8b*dK=moM*oB<)*A$M9!!5o~4U``e)wxavm@O_R(`P|u%9^LGi(_%IF<6o;NLp*0 zKsfZ0#24GT8(G`i4UvoMh$^;kOhl?`0yNiyrC#HJH=tqOH^T_d<2Z+ zeN>Y9Zn!X4*DMCK^o75Zk2621bdmV7Rx@AX^alBG4%~;G_vUoxhfhFRlR&+3WwF^T zaL)8xPq|wCZoNT^>3J0K?e{J-kl+hu2rZI>CUv#-z&u@`hjeb+bBZ>bcciQVZ{SbW zez04s9oFEgc8Z+Kp{XFX`MVf-s&w9*dx7wLen(_@y34}Qz@&`$2+osqfxz4&d}{Ql z*g1ag00Gu+$C`0avds{Q65BfGsu9`_`dML*rX~hyWIe$T>CsPRoLIr%MTk3pJ^2zH1qub1MBzPG}PO;Wmav9w%F7?%l=xIf#LlP`! z_Nw;xBQY9anH5-c8A4mME}?{iewjz(Sq-29r{fV;Fc>fv%0!W@(+{={Xl-sJ6aMoc z)9Q+$bchoTGTyWU_oI19!)bD=IG&OImfy;VxNXoIO2hYEfO~MkE#IXTK(~?Z&!ae! 
zl8z{D&2PC$Q*OBC(rS~-*-GHNJ6AC$@eve>LB@Iq;jbBZj`wk4|LGogE||Ie=M5g= z9d`uYQ1^Sr_q2wmZE>w2WG)!F%^KiqyaDtIAct?}D~JP4shTJy5Bg+-(EA8aXaxbd~BKMtTf2iQ69jD1o* zZF9*S3!v-TdqwK$%&?91Sh2=e63;X0Lci@n7y3XOu2ofyL9^-I767eHESAq{m+@*r zbVDx!FQ|AjT;!bYsXv8ilQjy~Chiu&HNhFXt3R_6kMC8~ChEFqG@MWu#1Q1#=~#ix zrkHpJre_?#r=N0wv`-7cHHqU`phJX2M_^{H0~{VP79Dv{6YP)oA1&TSfKPEPZn2)G z9o{U1huZBLL;Tp_0OYw@+9z(jkrwIGdUrOhKJUbwy?WBt zlIK)*K0lQCY0qZ!$%1?3A#-S70F#YyUnmJF*`xx?aH5;gE5pe-15w)EB#nuf6B*c~ z8Z25NtY%6Wlb)bUA$w%HKs5$!Z*W?YKV-lE0@w^{4vw;J>=rn?u!rv$&eM+rpU6rc=j9>N2Op+C{D^mospMCjF2ZGhe4eADA#skp2EA26%p3Ex9wHW8l&Y@HX z$Qv)mHM}4*@M*#*ll5^hE9M^=q~eyWEai*P;4z<9ZYy!SlNE5nlc7gm;M&Q zKhKE4d*%A>^m0R?{N}y|i6i^k>^n4(wzKvlQeHq{l&JuFD~sTsdhs`(?lFK@Q{pU~ zb!M3c@*3IwN1RUOVjY5>uT+s-2QLWY z4T2>fiSn>>Fob+%B868-v9D@AfWr#M8eM6w#eAlhc#zk6jkLxGBGk`E3$!A@*am!R zy>29&ptYK6>cvP`b!syNp)Q$0UOW|-O@)8!?94GOYF_}+zlW%fCEl|Tep_zx05g6q z>tp47e-&R*hSNe{6{H!mL?+j$c^TXT{C&@T-xIaesNCl05 z9SLb@q&mSb)I{VXMaiWa3PWj=Ed!>*GwUe;^|uk=Pz$njNnfFY^MM>E?zqhf6^{}0 zx&~~dA5#}1ig~7HvOQ#;d9JZBeEQ+}-~v$at`m!(ai z$w(H&mWCC~;PQ1$%iuz3`>dWeb3_p}X>L2LK%2l59Tyc}4m0>9A!8rhoU3m>i2+hl zx?*qs*c^j}+WPs>&v1%1Ko8_ivAGIn@QK7A`hDz-Emkcgv2@wTbYhkiwX2l=xz*XG zaiNg+j4F-I>9v+LjosI-QECrtKjp&0T@xIMKVr+&)gyb4@b3y?2CA?=ooN zT#;rU86WLh(e@#mF*rk(NV-qSIZyr z$6!ZUmzD)%yO-ot`rw3rp6?*_l*@Z*IB0xn4|BGPWHNc-1ZUnNSMWmDh=EzWJRP`) zl%d%J613oXzh5;VY^XWJi{lB`f#u+ThvtP7 zq(HK<4>tw(=yzSBWtYO}XI`S1pMBe3!jFxBHIuwJ(@%zdQFi1Q_hU2eDuHqXte7Ki zOV55H2D6u#4oTfr7|u*3p75KF&jaLEDpxk!4*bhPc%mpfj)Us3XIG3 zIKMX^s^1wt8YK7Ky^UOG=w!o5e7W-<&c|fw2{;Q11vm@J{)@N3-p1U>!0~sKWHaL= zWV(0}1IIyt1p%=_-Fe5Kfzc71wg}`RDDntVZv;4!=&XXF-$48jS0Sc;eDy@Sg;+{A zFStc{dXT}kcIjMXb4F7MbX~2%i;UrBxm%qmLKb|2=?uPr00-$MEUIGR5+JG2l2Nq` zkM{{1RO_R)+8oQ6x&-^kCj)W8Z}TJjS*Wm4>hf+4#VJP)OBaDF%3pms7DclusBUw} z{ND#!*I6h85g6DzNvdAmnwWY{&+!KZM4DGzeHI?MR@+~|su0{y-5-nICz_MIT_#FE zm<5f3zlaKq!XyvY3H`9s&T};z!cK}G%;~!rpzk9-6L}4Rg7vXtKFsl}@sT#U#7)x- z7UWue5sa$R>N&b{J61&gvKcKlozH*;OjoDR+elkh|4bJ!_3AZNMOu?n9&|L>OTD78 z^i->ah_Mqc|Ev)KNDzfu1P3grBIM#%`QZqj5W{qu(HocQhjyS;UINoP`{J+DvV?|1 z_sw6Yr3z6%e7JKVDY<$P=M)dbk@~Yw9|2!Cw!io3%j92wTD!c^e9Vj+7VqXo3>u#= zv#M{HHJ=e$X5vQ>>ML?E8#UlmvJgTnb73{PSPTf*0)mcj6C z{KsfUbDK|F$E(k;ER%8HMdDi`=BfpZzP3cl5yJHu;v^o2FkHNk;cXc17tL8T!CsYI zfeZ6sw@;8ia|mY_AXjCS?kUfxdjDB28)~Tz1dGE|{VfBS9`0m2!m1yG?hR})er^pl4c@9Aq+|}ZlDaHL)K$O| z%9Jp-imI-Id0|(d5{v~w6mx)tUKfbuVD`xNt04Mry%M+jXzE>4(TBsx#&=@wT2Vh) z1yeEY&~17>0%P(eHP0HB^|7C+WJxQBTG$uyOWY@iDloRIb-Cf!p<{WQHR!422#F34 zG`v|#CJ^G}y9U*7jgTlD{D&y$Iv{6&PYG>{Ixg$pGk?lWrE#PJ8KunQC@}^6OP!|< zS;}p3to{S|uZz%kKe|;A0bL0XxPB&Q{J(9PyX`+Kr`k~r2}yP^ND{8!v7Q1&vtk& z2Y}l@J@{|2`oA%sxvM9i0V+8IXrZ4;tey)d;LZI70Kbim<4=WoTPZy=Yd|34v#$Kh zx|#YJ8s`J>W&jt#GcMpx84w2Z3ur-rK7gf-p5cE)=w1R2*|0mj12hvapuUWM0b~dG zMg9p8FmAZI@i{q~0@QuY44&mMUNXd7z>U58shA3o`p5eVLpq>+{(<3->DWuSFVZwC zxd50Uz(w~LxC4}bgag#q#NNokK@yNc+Q|Ap!u>Ddy+df>v;j@I12CDNN9do+0^n8p zMQs7X#+FVF0C5muGfN{r0|Nkql%BQT|K(DDNdR2pzM=_ea5+GO|J67`05AV92t@4l z0Qno0078PIHdaQGHZ~Scw!dzgqjK~3B7kf>BcP__&lLyU(cu3B^uLo%{j|Mb0NR)tkeT7Hcwp4O# z)yzu>cvG(d9~0a^)eZ;;%3ksk@F&1eEBje~ zW+-_s)&RgiweQc!otF>4%vbXKaOU41{!hw?|2`Ld3I8$&#WOsq>EG)1ANb!{N4z9@ zsU!bPG-~-bqCeIDzo^Q;gnucB{tRzm{ZH^Orphm2U+REA!*<*J6YQV83@&xoDl%#wnl5qcBqCcAF-vX5{30}(oJrnSH z{RY85hylK2dMOh2%oO1J8%)0?8TOL%rS8)+CsDv}aQ>4D)Jv+DLK)9gI^n-T^$)Tc zFPUD75qJm!Y-KBqj;JP4dV4 z`X{lGmn<)1IGz330}s}Jrjtf{(lnuuNHe5(ezA(pYa=1|Ff-LhPFK8 zyJh_b{yzu0yll6ZkpRzRjezyYivjyjW7QwO;@6X`m;2Apn2EK2!~7S}-*=;5*7K$B z`x(=!^?zgj(-`&ApZJXI09aDLXaT@<;CH=?fBOY5d|b~wBA@@p^K#nxr`)?i?SqTupI_PJ(A3cx`z~9mX_*)>L F{|7XC?P&l2 literal 0 HcmV?d00001 diff --git 
a/lite/demo/java/android/PaddlePredictor/gradle/wrapper/gradle-wrapper.properties b/lite/demo/java/android/PaddlePredictor/gradle/wrapper/gradle-wrapper.properties new file mode 100644 index 00000000000..2d135d7b25c --- /dev/null +++ b/lite/demo/java/android/PaddlePredictor/gradle/wrapper/gradle-wrapper.properties @@ -0,0 +1,6 @@ +#Wed Jun 26 10:57:21 CST 2019 +distributionBase=GRADLE_USER_HOME +distributionPath=wrapper/dists +zipStoreBase=GRADLE_USER_HOME +zipStorePath=wrapper/dists +distributionUrl=https\://services.gradle.org/distributions/gradle-5.1.1-all.zip diff --git a/lite/demo/java/android/PaddlePredictor/gradlew b/lite/demo/java/android/PaddlePredictor/gradlew new file mode 100755 index 00000000000..cccdd3d517f --- /dev/null +++ b/lite/demo/java/android/PaddlePredictor/gradlew @@ -0,0 +1,172 @@ +#!/usr/bin/env sh + +############################################################################## +## +## Gradle start up script for UN*X +## +############################################################################## + +# Attempt to set APP_HOME +# Resolve links: $0 may be a link +PRG="$0" +# Need this for relative symlinks. +while [ -h "$PRG" ] ; do + ls=`ls -ld "$PRG"` + link=`expr "$ls" : '.*-> \(.*\)$'` + if expr "$link" : '/.*' > /dev/null; then + PRG="$link" + else + PRG=`dirname "$PRG"`"/$link" + fi +done +SAVED="`pwd`" +cd "`dirname \"$PRG\"`/" >/dev/null +APP_HOME="`pwd -P`" +cd "$SAVED" >/dev/null + +APP_NAME="Gradle" +APP_BASE_NAME=`basename "$0"` + +# Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script. +DEFAULT_JVM_OPTS="" + +# Use the maximum available, or set MAX_FD != -1 to use that value. +MAX_FD="maximum" + +warn () { + echo "$*" +} + +die () { + echo + echo "$*" + echo + exit 1 +} + +# OS specific support (must be 'true' or 'false'). +cygwin=false +msys=false +darwin=false +nonstop=false +case "`uname`" in + CYGWIN* ) + cygwin=true + ;; + Darwin* ) + darwin=true + ;; + MINGW* ) + msys=true + ;; + NONSTOP* ) + nonstop=true + ;; +esac + +CLASSPATH=$APP_HOME/gradle/wrapper/gradle-wrapper.jar + +# Determine the Java command to use to start the JVM. +if [ -n "$JAVA_HOME" ] ; then + if [ -x "$JAVA_HOME/jre/sh/java" ] ; then + # IBM's JDK on AIX uses strange locations for the executables + JAVACMD="$JAVA_HOME/jre/sh/java" + else + JAVACMD="$JAVA_HOME/bin/java" + fi + if [ ! -x "$JAVACMD" ] ; then + die "ERROR: JAVA_HOME is set to an invalid directory: $JAVA_HOME + +Please set the JAVA_HOME variable in your environment to match the +location of your Java installation." + fi +else + JAVACMD="java" + which java >/dev/null 2>&1 || die "ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH. + +Please set the JAVA_HOME variable in your environment to match the +location of your Java installation." +fi + +# Increase the maximum file descriptors if we can. +if [ "$cygwin" = "false" -a "$darwin" = "false" -a "$nonstop" = "false" ] ; then + MAX_FD_LIMIT=`ulimit -H -n` + if [ $? -eq 0 ] ; then + if [ "$MAX_FD" = "maximum" -o "$MAX_FD" = "max" ] ; then + MAX_FD="$MAX_FD_LIMIT" + fi + ulimit -n $MAX_FD + if [ $? 
-ne 0 ] ; then + warn "Could not set maximum file descriptor limit: $MAX_FD" + fi + else + warn "Could not query maximum file descriptor limit: $MAX_FD_LIMIT" + fi +fi + +# For Darwin, add options to specify how the application appears in the dock +if $darwin; then + GRADLE_OPTS="$GRADLE_OPTS \"-Xdock:name=$APP_NAME\" \"-Xdock:icon=$APP_HOME/media/gradle.icns\"" +fi + +# For Cygwin, switch paths to Windows format before running java +if $cygwin ; then + APP_HOME=`cygpath --path --mixed "$APP_HOME"` + CLASSPATH=`cygpath --path --mixed "$CLASSPATH"` + JAVACMD=`cygpath --unix "$JAVACMD"` + + # We build the pattern for arguments to be converted via cygpath + ROOTDIRSRAW=`find -L / -maxdepth 1 -mindepth 1 -type d 2>/dev/null` + SEP="" + for dir in $ROOTDIRSRAW ; do + ROOTDIRS="$ROOTDIRS$SEP$dir" + SEP="|" + done + OURCYGPATTERN="(^($ROOTDIRS))" + # Add a user-defined pattern to the cygpath arguments + if [ "$GRADLE_CYGPATTERN" != "" ] ; then + OURCYGPATTERN="$OURCYGPATTERN|($GRADLE_CYGPATTERN)" + fi + # Now convert the arguments - kludge to limit ourselves to /bin/sh + i=0 + for arg in "$@" ; do + CHECK=`echo "$arg"|egrep -c "$OURCYGPATTERN" -` + CHECK2=`echo "$arg"|egrep -c "^-"` ### Determine if an option + + if [ $CHECK -ne 0 ] && [ $CHECK2 -eq 0 ] ; then ### Added a condition + eval `echo args$i`=`cygpath --path --ignore --mixed "$arg"` + else + eval `echo args$i`="\"$arg\"" + fi + i=$((i+1)) + done + case $i in + (0) set -- ;; + (1) set -- "$args0" ;; + (2) set -- "$args0" "$args1" ;; + (3) set -- "$args0" "$args1" "$args2" ;; + (4) set -- "$args0" "$args1" "$args2" "$args3" ;; + (5) set -- "$args0" "$args1" "$args2" "$args3" "$args4" ;; + (6) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" ;; + (7) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" ;; + (8) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" "$args7" ;; + (9) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" "$args7" "$args8" ;; + esac +fi + +# Escape application args +save () { + for i do printf %s\\n "$i" | sed "s/'/'\\\\''/g;1s/^/'/;\$s/\$/' \\\\/" ; done + echo " " +} +APP_ARGS=$(save "$@") + +# Collect all arguments for the java command, following the shell quoting and substitution rules +eval set -- $DEFAULT_JVM_OPTS $JAVA_OPTS $GRADLE_OPTS "\"-Dorg.gradle.appname=$APP_BASE_NAME\"" -classpath "\"$CLASSPATH\"" org.gradle.wrapper.GradleWrapperMain "$APP_ARGS" + +# by default we should be in the correct project dir, but when run from Finder on Mac, the cwd is wrong +if [ "$(uname)" = "Darwin" ] && [ "$HOME" = "$PWD" ]; then + cd "$(dirname "$0")" +fi + +exec "$JAVACMD" "$@" diff --git a/lite/demo/java/android/PaddlePredictor/gradlew.bat b/lite/demo/java/android/PaddlePredictor/gradlew.bat new file mode 100644 index 00000000000..e95643d6a2c --- /dev/null +++ b/lite/demo/java/android/PaddlePredictor/gradlew.bat @@ -0,0 +1,84 @@ +@if "%DEBUG%" == "" @echo off +@rem ########################################################################## +@rem +@rem Gradle startup script for Windows +@rem +@rem ########################################################################## + +@rem Set local scope for the variables with windows NT shell +if "%OS%"=="Windows_NT" setlocal + +set DIRNAME=%~dp0 +if "%DIRNAME%" == "" set DIRNAME=. +set APP_BASE_NAME=%~n0 +set APP_HOME=%DIRNAME% + +@rem Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script. 
+set DEFAULT_JVM_OPTS= + +@rem Find java.exe +if defined JAVA_HOME goto findJavaFromJavaHome + +set JAVA_EXE=java.exe +%JAVA_EXE% -version >NUL 2>&1 +if "%ERRORLEVEL%" == "0" goto init + +echo. +echo ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH. +echo. +echo Please set the JAVA_HOME variable in your environment to match the +echo location of your Java installation. + +goto fail + +:findJavaFromJavaHome +set JAVA_HOME=%JAVA_HOME:"=% +set JAVA_EXE=%JAVA_HOME%/bin/java.exe + +if exist "%JAVA_EXE%" goto init + +echo. +echo ERROR: JAVA_HOME is set to an invalid directory: %JAVA_HOME% +echo. +echo Please set the JAVA_HOME variable in your environment to match the +echo location of your Java installation. + +goto fail + +:init +@rem Get command-line arguments, handling Windows variants + +if not "%OS%" == "Windows_NT" goto win9xME_args + +:win9xME_args +@rem Slurp the command line arguments. +set CMD_LINE_ARGS= +set _SKIP=2 + +:win9xME_args_slurp +if "x%~1" == "x" goto execute + +set CMD_LINE_ARGS=%* + +:execute +@rem Setup the command line + +set CLASSPATH=%APP_HOME%\gradle\wrapper\gradle-wrapper.jar + +@rem Execute Gradle +"%JAVA_EXE%" %DEFAULT_JVM_OPTS% %JAVA_OPTS% %GRADLE_OPTS% "-Dorg.gradle.appname=%APP_BASE_NAME%" -classpath "%CLASSPATH%" org.gradle.wrapper.GradleWrapperMain %CMD_LINE_ARGS% + +:end +@rem End local scope for the variables with windows NT shell +if "%ERRORLEVEL%"=="0" goto mainEnd + +:fail +rem Set variable GRADLE_EXIT_CONSOLE if you need the _script_ return code instead of +rem the _cmd.exe /c_ return code! +if not "" == "%GRADLE_EXIT_CONSOLE%" exit 1 +exit /b 1 + +:mainEnd +if "%OS%"=="Windows_NT" endlocal + +:omega diff --git a/lite/demo/java/android/PaddlePredictor/settings.gradle b/lite/demo/java/android/PaddlePredictor/settings.gradle new file mode 100644 index 00000000000..e7b4def49cb --- /dev/null +++ b/lite/demo/java/android/PaddlePredictor/settings.gradle @@ -0,0 +1 @@ +include ':app' diff --git a/lite/demo/java/android/prepare_demo.bash b/lite/demo/java/android/prepare_demo.bash new file mode 100644 index 00000000000..e0dbdaf75f4 --- /dev/null +++ b/lite/demo/java/android/prepare_demo.bash @@ -0,0 +1,23 @@ +#!/bin/bash + +# Script to download model files and copy .Jar and JNI lib for Android demo +# $1 will be the arch name + +if [ x$1 != x ]; then + cp ../../../java/so/libpaddle_lite_jni.so PaddlePredictor/app/src/main/jniLibs/$1/ +else + echo "Warning: didn't copy JNI .so lib because arch name is empty" +fi + +MODELS=(inception_v4_simple_opt.nb lite_naive_model_opt.nb mobilenet_v1_opt.nb mobilenet_v2_relu_opt.nb resnet50_opt.nb) +MODELS_DIR=PaddlePredictor/app/src/main/assets/ + +for m in "${MODELS[@]}" +do + wget --no-check-certificate -q http://paddle-inference-dist.bj.bcebos.com/${m}.tar.gz \ + -O ${MODELS_DIR}${m}.tar.gz + tar xzf ${MODELS_DIR}${m}.tar.gz -C ${MODELS_DIR} + rm -rf ${MODELS_DIR}${m}.tar.gz +done + +cp ../../../java/jar/PaddlePredictor.jar PaddlePredictor/app/libs/ diff --git a/lite/fluid/CMakeLists.txt b/lite/fluid/CMakeLists.txt new file mode 100644 index 00000000000..2258ae61cda --- /dev/null +++ b/lite/fluid/CMakeLists.txt @@ -0,0 +1,4 @@ +if (LITE_WITH_X86) +lite_cc_library(fluid_data_type SRCS data_type.cc DEPS framework_proto) +# lite_cc_library(selected_rows SRCS selected_rows.cc) +endif() diff --git a/lite/fluid/data_type.cc b/lite/fluid/data_type.cc new file mode 100644 index 00000000000..aa8971499fb --- /dev/null +++ b/lite/fluid/data_type.cc @@ -0,0 +1,101 @@ +// Copyright (c) 2018 PaddlePaddle 
Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lite/fluid/data_type.h"
+#include <stdint.h>
+#include <string>
+#include <unordered_map>
+
+using float16 = paddle::lite::fluid::float16;
+
+namespace paddle {
+namespace lite {
+namespace fluid {
+
+struct DataTypeMap {
+  std::unordered_map<std::type_index, framework::proto::VarType::Type>
+      cpp_to_proto_;
+  std::unordered_map<int, std::type_index> proto_to_cpp_;
+  std::unordered_map<int, std::string> proto_to_str_;
+  std::unordered_map<int, size_t> proto_to_size_;
+};
+
+static DataTypeMap* InitDataTypeMap();
+// C++11 removes the need for manual locking: concurrent callers wait while a
+// static local variable is being initialized.
+// https://stackoverflow.com/questions/11711920/how-to-implement-multithread-safe-singleton-in-c11-without-using-mutex
+static DataTypeMap& gDataTypeMap() {
+  static DataTypeMap* g_data_type_map_ = InitDataTypeMap();
+  return *g_data_type_map_;
+}
+
+template <typename T>
+static inline void RegisterType(DataTypeMap* map,
+                                framework::proto::VarType::Type proto_type,
+                                const std::string& name) {
+  map->proto_to_cpp_.emplace(static_cast<int>(proto_type), typeid(T));
+  map->cpp_to_proto_.emplace(typeid(T), proto_type);
+  map->proto_to_str_.emplace(static_cast<int>(proto_type), name);
+  map->proto_to_size_.emplace(static_cast<int>(proto_type), sizeof(T));
+}
+
+static DataTypeMap* InitDataTypeMap() {
+  auto retv = new DataTypeMap();
+
+#define RegType(cc_type, proto_type) \
+  RegisterType<cc_type>(retv, proto_type, #cc_type)
+
+  _ForEachDataType_(RegType);
+
+#undef RegType
+  return retv;
+}
+
+framework::proto::VarType::Type ToDataType(std::type_index type) {
+  auto it = gDataTypeMap().cpp_to_proto_.find(type);
+  if (it != gDataTypeMap().cpp_to_proto_.end()) {
+    return it->second;
+  }
+  PADDLE_THROW("Not support %s as tensor type", type.name());
+}
+
+std::type_index ToTypeIndex(framework::proto::VarType::Type type) {
+  auto it = gDataTypeMap().proto_to_cpp_.find(static_cast<int>(type));
+  if (it != gDataTypeMap().proto_to_cpp_.end()) {
+    return it->second;
+  }
+  PADDLE_THROW("Not support framework::proto::VarType::Type(%d) as tensor type",
+               static_cast<int>(type));
+}
+
+std::string DataTypeToString(const framework::proto::VarType::Type type) {
+  auto it = gDataTypeMap().proto_to_str_.find(static_cast<int>(type));
+  if (it != gDataTypeMap().proto_to_str_.end()) {
+    return it->second;
+  }
+  PADDLE_THROW("Not support framework::proto::VarType::Type(%d) as tensor type",
+               static_cast<int>(type));
+}
+
+size_t SizeOfType(framework::proto::VarType::Type type) {
+  auto it = gDataTypeMap().proto_to_size_.find(static_cast<int>(type));
+  if (it != gDataTypeMap().proto_to_size_.end()) {
+    return it->second;
+  }
+  PADDLE_THROW("Not support %s as tensor type", DataTypeToString(type));
+}
+
+}  // namespace fluid
+}  // namespace lite
+}  // namespace paddle
diff --git a/lite/fluid/data_type.h b/lite/fluid/data_type.h
new file mode 100644
index 00000000000..a8b11ec465e
--- /dev/null
+++ b/lite/fluid/data_type.h
@@ -0,0 +1,88 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+#include <string>
+#include <typeindex>
+#include "lite/core/framework.pb.h"
+#include "lite/fluid/float16.h"
+#include "lite/utils/paddle_enforce.h"
+
+namespace paddle {
+namespace lite {
+namespace fluid {
+
+template <typename T>
+struct DataTypeTrait {};
+
+// Stub handle for void
+template <>
+struct DataTypeTrait<void> {
+  constexpr static auto DataType = framework::proto::VarType::RAW;
+};
+
+#define _ForEachDataTypeHelper_(callback, cpp_type, proto_type) \
+  callback(cpp_type, ::paddle::framework::proto::VarType::proto_type);
+
+#define _ForEachDataType_(callback)                                        \
+  _ForEachDataTypeHelper_(callback, float, FP32);                          \
+  _ForEachDataTypeHelper_(callback, ::paddle::lite::fluid::float16, FP16); \
+  _ForEachDataTypeHelper_(callback, double, FP64);                         \
+  _ForEachDataTypeHelper_(callback, int, INT32);                           \
+  _ForEachDataTypeHelper_(callback, int64_t, INT64);                       \
+  _ForEachDataTypeHelper_(callback, bool, BOOL);                           \
+  _ForEachDataTypeHelper_(callback, uint8_t, UINT8);                       \
+  _ForEachDataTypeHelper_(callback, int16_t, INT16);                       \
+  _ForEachDataTypeHelper_(callback, int8_t, INT8)
+
+#define DefineDataTypeTrait(cpp_type, proto_type) \
+  template <>                                     \
+  struct DataTypeTrait<cpp_type> {                \
+    constexpr static auto DataType = proto_type;  \
+  }
+
+_ForEachDataType_(DefineDataTypeTrait);
+
+#undef DefineDataTypeTrait
+
+extern framework::proto::VarType::Type ToDataType(std::type_index type);
+extern std::type_index ToTypeIndex(framework::proto::VarType::Type type);
+
+template <typename Visitor>
+inline void VisitDataType(framework::proto::VarType::Type type,
+                          Visitor visitor) {
+#define VisitDataTypeCallback(cpp_type, proto_type) \
+  do {                                              \
+    if (type == proto_type) {                       \
+      visitor.template apply<cpp_type>();           \
+      return;                                       \
+    }                                               \
+  } while (0)
+
+  _ForEachDataType_(VisitDataTypeCallback);
+#undef VisitDataTypeCallback
+  PADDLE_THROW("Not supported %d", type);
+}
+
+extern std::string DataTypeToString(const framework::proto::VarType::Type type);
+extern size_t SizeOfType(framework::proto::VarType::Type type);
+inline std::ostream& operator<<(std::ostream& out,
+                                const framework::proto::VarType::Type& type) {
+  out << DataTypeToString(type);
+  return out;
+}
+
+}  // namespace fluid
+}  // namespace lite
+}  // namespace paddle
diff --git a/lite/fluid/data_type_test.cc b/lite/fluid/data_type_test.cc
new file mode 100644
index 00000000000..2a380201f29
--- /dev/null
+++ b/lite/fluid/data_type_test.cc
@@ -0,0 +1,40 @@
+// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
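+
+// A minimal usage sketch of the VisitDataType helper from
+// lite/fluid/data_type.h above (illustrative only; `PrintSize` is a
+// hypothetical functor, not part of this patch):
+//
+//   struct PrintSize {
+//     template <typename T>
+//     void apply() const { std::cout << sizeof(T) << "\n"; }
+//   };
+//   paddle::lite::fluid::VisitDataType(
+//       ::paddle::framework::proto::VarType::FP32, PrintSize());
+//
+// The helper compares the runtime proto type against every entry of
+// _ForEachDataType_ and invokes visitor.template apply<cpp_type>() on the
+// first match.
+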
+#include "paddle/fluid/framework/data_type.h" + +#include +#include "gtest/gtest.h" +#include "paddle/fluid/framework/tensor.h" + +TEST(DataType, float16) { + using paddle::framework::Tensor; + using paddle::platform::CPUPlace; + using paddle::platform::float16; + namespace f = paddle::framework; + f::proto::VarType::Type dtype = f::proto::VarType::FP16; + + Tensor tensor; + CPUPlace cpu; + tensor.mutable_data(cpu, dtype); + + // test fp16 tensor + EXPECT_EQ(tensor.type(), f::ToDataType(typeid(float16))); + + // test fp16 size + EXPECT_EQ(f::SizeOfType(dtype), 2u); + + // test debug info + std::string type = "::paddle::platform::float16"; + EXPECT_STREQ(f::DataTypeToString(dtype).c_str(), type.c_str()); +} diff --git a/lite/fluid/eigen.h b/lite/fluid/eigen.h new file mode 100644 index 00000000000..f5d5e4b5e51 --- /dev/null +++ b/lite/fluid/eigen.h @@ -0,0 +1,141 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include +#include "lite/core/tensor.h" +#include "lite/fluid/float16.h" +#include "lite/utils/paddle_enforce.h" +#include "unsupported/Eigen/CXX11/Tensor" + +namespace paddle { +namespace lite { +namespace fluid { + +// EigenDim converts paddle::platform::DDim into Eigen::DSizes. +template +struct EigenDim { + using Type = Eigen::DSizes; + + static Type From(const lite::DDim& dims) { + PADDLE_ENFORCE(dims.size() == D, "D must match DDim::size"); + Type ret; + for (int64_t d = 0; d < dims.size(); d++) { + ret[d] = dims[d]; + } + return ret; + } +}; + +// Interpret paddle::platform::Tensor as EigenTensor and EigenConstTensor. +template +struct EigenTensor { + // TODO(qijun) Now, default type in unaligned, and we will make a benchmark on + // the speed of aligned and unaligned version in future. 
+  using Type = Eigen::TensorMap<Eigen::Tensor<T, D, MajorType, IndexType>>;
+
+  using ConstType =
+      Eigen::TensorMap<Eigen::Tensor<const T, D, MajorType, IndexType>>;
+
+  static Type From(Tensor& tensor, lite::DDim dims) {  // NOLINT
+    return Type(const_cast<T*>(tensor.data<T>()),
+                EigenDim<D>::From(dims));  // NOLINT
+  }
+
+  static Type From(Tensor& tensor) {  // NOLINT
+    return From(tensor, tensor.dims());
+  }  // NOLINT
+
+  static ConstType From(const Tensor& tensor, lite::DDim dims) {
+    return ConstType(tensor.data<T>(), EigenDim<D>::From(dims));
+  }
+
+  static ConstType From(const Tensor& tensor) {
+    return From(tensor, tensor.dims());
+  }
+};
+
+template <typename T, int MajorType = Eigen::RowMajor,
+          typename IndexType = Eigen::DenseIndex>
+struct EigenMatrix : public EigenTensor<T, 2, MajorType, IndexType> {
+  static typename EigenMatrix::Type Reshape(Tensor& tensor,  // NOLINT
+                                            int num_col_dims) {
+    int rank = tensor.dims().size();
+    PADDLE_ENFORCE(num_col_dims > 0 && num_col_dims < rank,
+                   "`num_col_dims` must be between (0, rank_of_tensor).");
+    return EigenMatrix::From(tensor, tensor.dims().Flatten2D(num_col_dims));
+  }
+
+  static typename EigenMatrix::ConstType Reshape(const Tensor& tensor,
+                                                 int num_col_dims) {
+    int rank = tensor.dims().size();
+    PADDLE_ENFORCE(num_col_dims > 0 && num_col_dims < rank,
+                   "`num_col_dims` must be between (0, rank_of_tensor).");
+    return EigenMatrix::From(tensor, tensor.dims().Flatten2D(num_col_dims));
+  }
+};
+
+template <typename T, int MajorType = Eigen::RowMajor,
+          typename IndexType = Eigen::DenseIndex>
+struct EigenVector : public EigenTensor<T, 1, MajorType, IndexType> {
+  // Flatten reshapes a Tensor into an EigenVector.
+  static typename EigenVector::Type Flatten(Tensor& tensor) {  // NOLINT
+    return EigenVector::From(
+        tensor,
+        lite::DDim(std::vector<int64_t>({tensor.dims().production()})));
+  }
+
+  static typename EigenVector::ConstType Flatten(
+      const Tensor& tensor) {  // NOLINT
+    return EigenVector::From(
+        tensor,
+        lite::DDim(std::vector<int64_t>({tensor.dims().production()})));
+  }
+};
+
+template <typename T, int MajorType = Eigen::RowMajor,
+          typename IndexType = Eigen::DenseIndex>
+struct EigenScalar {
+  // Scalar tensor (implemented as a rank-0 tensor) of scalar type T.
+  using Type = Eigen::TensorMap<
+      Eigen::TensorFixedSize<T, Eigen::Sizes<>, MajorType, IndexType>>;
+  using ConstType = Eigen::TensorMap<
+      Eigen::TensorFixedSize<const T, Eigen::Sizes<>, MajorType, IndexType>>;
+
+  static Type From(Tensor& tensor) { return Type(tensor.data<T>()); }  // NOLINT
+
+  static ConstType From(const Tensor& tensor) {
+    return ConstType(tensor.data<T>());
+  }
+};
+
+template <lite::TargetType Target>
+struct EigenDevice;
+
+template <>
+struct EigenDevice<lite::TargetType::kX86> {
+  using Type = ::Eigen::DefaultDevice;
+};
+
+template <lite::TargetType Target>
+using EigenDeviceType = typename EigenDevice<Target>::Type;
+
+}  // namespace fluid
+}  // namespace lite
+}  // namespace paddle
diff --git a/lite/fluid/float16.h b/lite/fluid/float16.h
new file mode 100644
index 00000000000..d1ef6f7dc59
--- /dev/null
+++ b/lite/fluid/float16.h
@@ -0,0 +1,1100 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
*/ + +#pragma once + +#include +#include + +#ifdef PADDLE_WITH_CUDA +#include +#endif // PADDLE_WITH_CUDA + +#ifdef __GNUC__ +#define PADDLE_GNUC_VER (__GNUC__ * 10 + __GNUC_MINOR__) +#else +#define PADDLE_GNUC_VER 0 +#endif // __GNUC__ + +#ifdef __clang__ +#define PADDLE_CLANG_VER (__clang_major__ * 10 + __clang_minor__) +#else +#define PADDLE_CLANG_VER 0 +#endif // __clang__ + +#if defined(__CUDACC__) && CUDA_VERSION >= 7050 +#define PADDLE_CUDA_FP16 +#include +#endif + +#if defined(__arm__) || defined(__aarch64__) +#define PADDLE_ARM +#endif + +#if defined(__ARM_NEON) || defined(__ARM_NEON__) +#define PADDLE_NEON +#include +#endif + +#if defined(PADDLE_NEON) && defined(PADDLE_ARM_FP16) && \ + (PADDLE_GNUC_VER >= 62 || PADDLE_CLANG_VER >= 37) +#define PADDLE_WITH_NATIVE_FP16 +#endif + +#ifndef PADDLE_ARM +#include +#endif // PADDLE_ARM + +#if !defined(_WIN32) +#define PADDLE_ALIGN(x) __attribute__((aligned(x))) +#else +#define PADDLE_ALIGN(x) __declspec(align(x)) +#endif + +namespace paddle { +namespace lite { +namespace fluid { + +// Forward declare float16 for eigen.h +struct float16; + +} // namespace fluid +} // namespace lite +} // namespace paddle + +#include "lite/utils/macros.h" +#include "unsupported/Eigen/CXX11/Tensor" + +namespace paddle { +namespace lite { +namespace fluid { + +// Use PADDLE_ALIGNED(2) to ensure that each float16 will be allocated +// and aligned at least on a 2-byte boundary, which leads to efficient +// memory access of float16 struct and also makes float16 compatible +// with CUDA half, ARM float16_t, and Eigen::half data types. +struct PADDLE_ALIGN(2) float16 { + public: + uint16_t x; + + // The following defaulted special class member functions + // are added to make float16 pass the std::is_trivial test + float16() = default; + float16(const float16& o) = default; + float16& operator=(const float16& o) = default; + float16(float16&& o) = default; + float16& operator=(float16&& o) = default; + ~float16() = default; + +// Constructors +#ifdef PADDLE_CUDA_FP16 + HOSTDEVICE inline explicit float16(const half& h) { +#if CUDA_VERSION >= 9000 + x = reinterpret_cast<__half_raw*>(const_cast(&h))->x; +#else + x = h.x; +#endif // CUDA_VERSION >= 9000 + } +#endif // PADDLE_CUDA_FP16 + + HOSTDEVICE inline explicit float16(const Eigen::half& h) : x(h.x) {} + +#ifdef PADDLE_WITH_NATIVE_FP16 + // __fp16 is a native half precision data type for arm cpu, + // float16_t is an alias for __fp16 + HOSTDEVICE inline explicit float16(const float16_t& h) { + x = *reinterpret_cast(&h); + } +#endif + + HOSTDEVICE inline explicit float16(float val) { +#if defined(PADDLE_CUDA_FP16) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 300 + half tmp = __float2half(val); + x = *reinterpret_cast(&tmp); + +#elif defined(PADDLE_WITH_NATIVE_FP16) + float32x4_t tmp = vld1q_dup_f32(&val); + float16_t res = vget_lane_f16(vcvt_f16_f32(tmp), 0); + x = *reinterpret_cast(&res); + +#elif defined(__F16C__) + x = _cvtss_sh(val, 0); + +#else + // Conversion routine adapted from + // http://stackoverflow.com/questions/1659440/32-bit-to-16-bit-floating-point-conversion + Bits v, s; + v.f = val; + uint32_t sign = v.si & sigN; + v.si ^= sign; + sign >>= shiftSign; // logical shift + s.si = mulN; + s.si = s.f * v.f; // correct subnormals + v.si ^= (s.si ^ v.si) & -(minN > v.si); + v.si ^= (infN ^ v.si) & -((infN > v.si) & (v.si > maxN)); + v.si ^= (nanN ^ v.si) & -((nanN > v.si) & (v.si > infN)); + v.ui >>= shift; // logical shift + v.si ^= ((v.si - maxD) ^ v.si) & -(v.si > maxC); + v.si ^= ((v.si - minD) ^ 
v.si) & -(v.si > subC); + x = v.ui | sign; + +#endif + } + + HOSTDEVICE inline explicit float16(bool b) : x(b ? 0x3c00 : 0) {} + + template + HOSTDEVICE inline explicit float16(const T& val) + : x(float16(static_cast(val)).x) {} + +// Assignment operators +#ifdef PADDLE_CUDA_FP16 + HOSTDEVICE inline float16& operator=(const half& rhs) { +#if CUDA_VERSION >= 9000 + x = reinterpret_cast<__half_raw*>(const_cast(&rhs))->x; +#else + x = rhs.x; +#endif + return *this; + } +#endif + + HOSTDEVICE inline float16& operator=(const Eigen::half& rhs) { + x = rhs.x; + return *this; + } + +#ifdef PADDLE_WITH_NATIVE_FP16 + HOSTDEVICE inline float16& operator=(const float16_t& rhs) { + x = *reinterpret_cast(&rhs); + return *this; + } +#endif + + HOSTDEVICE inline float16& operator=(bool b) { + x = b ? 0x3c00 : 0; + return *this; + } + + HOSTDEVICE inline float16& operator=(int8_t val) { + x = float16(val).x; + return *this; + } + + HOSTDEVICE inline float16& operator=(uint8_t val) { + x = float16(val).x; + return *this; + } + + HOSTDEVICE inline float16& operator=(int16_t val) { + x = float16(val).x; + return *this; + } + + HOSTDEVICE inline float16& operator=(uint16_t val) { + x = float16(val).x; + return *this; + } + + HOSTDEVICE inline float16& operator=(int32_t val) { + x = float16(val).x; + return *this; + } + + HOSTDEVICE inline float16& operator=(uint32_t val) { + x = float16(val).x; + return *this; + } + + HOSTDEVICE inline float16& operator=(int64_t val) { + x = float16(val).x; + return *this; + } + + HOSTDEVICE inline float16& operator=(uint64_t val) { + x = float16(val).x; + return *this; + } + + HOSTDEVICE inline float16& operator=(float val) { + x = float16(val).x; + return *this; + } + + HOSTDEVICE inline float16& operator=(double val) { + x = float16(val).x; + return *this; + } + +// Conversion opertors +#ifdef PADDLE_CUDA_FP16 + HOSTDEVICE inline explicit operator half() const { +#if CUDA_VERSION >= 9000 + __half_raw h; + h.x = x; + return half(h); +#else + half h; + h.x = x; + return h; +#endif // CUDA_VERSION >= 9000 + } +#endif // PADDLE_CUDA_FP16 + + HOSTDEVICE inline explicit operator Eigen::half() const { + Eigen::half h; + h.x = x; + return h; + } + +#ifdef PADDLE_WITH_NATIVE_FP16 + HOSTDEVICE inline explicit operator float16_t() const { + return *reinterpret_cast(this); + } +#endif + + HOSTDEVICE inline explicit operator float() const { +#if defined(PADDLE_CUDA_FP16) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 300 + half tmp = *reinterpret_cast(this); + return __half2float(tmp); + +#elif defined(PADDLE_WITH_NATIVE_FP16) + float16x4_t res = vld1_dup_f16(reinterpret_cast(this)); + return vgetq_lane_f32(vcvt_f32_f16(res), 0); + +#elif defined(__F16C__) + return _cvtsh_ss(this->x); + +#else + // Conversion routine adapted from + // http://stackoverflow.com/questions/1659440/32-bit-to-16-bit-floating-point-conversion + Bits v; + v.ui = this->x; + int32_t sign = v.si & sigC; + v.si ^= sign; + sign <<= shiftSign; + v.si ^= ((v.si + minD) ^ v.si) & -(v.si > subC); + v.si ^= ((v.si + maxD) ^ v.si) & -(v.si > maxC); + Bits s; + s.si = mulC; + s.f *= v.si; + int32_t mask = -(norC > v.si); + v.si <<= shift; + v.si ^= (s.si ^ v.si) & mask; + v.si |= sign; + return v.f; + +#endif + } + + HOSTDEVICE inline explicit operator bool() const { return (x & 0x7fff) != 0; } + + HOSTDEVICE inline explicit operator int8_t() const { + return static_cast(static_cast(*this)); + } + + HOSTDEVICE inline explicit operator uint8_t() const { + return static_cast(static_cast(*this)); + } + + HOSTDEVICE inline 
explicit operator int16_t() const { + return static_cast(static_cast(*this)); + } + + HOSTDEVICE inline explicit operator uint16_t() const { + return static_cast(static_cast(*this)); + } + + HOSTDEVICE inline explicit operator int32_t() const { + return static_cast(static_cast(*this)); + } + + HOSTDEVICE inline explicit operator uint32_t() const { + return static_cast(static_cast(*this)); + } + + HOSTDEVICE inline explicit operator int64_t() const { + return static_cast(static_cast(*this)); + } + + HOSTDEVICE inline explicit operator uint64_t() const { + return static_cast(static_cast(*this)); + } + + HOSTDEVICE inline explicit operator double() const { + return static_cast(static_cast(*this)); + } + + private: + union Bits { + float f; + int32_t si; + uint32_t ui; + }; + + static const int shift = 13; + static const int shiftSign = 16; + + static const int32_t infN = 0x7F800000; + static const int32_t maxN = 0x477FE000; // max flt16 as flt32 + static const int32_t minN = 0x38800000; // min flt16 normal as flt32 + static const int32_t sigN = 0x80000000; // sign bit + + static constexpr int32_t infC = infN >> shift; + static constexpr int32_t nanN = (infC + 1) + << shift; // minimum flt16 nan as float32 + static constexpr int32_t maxC = maxN >> shift; + static constexpr int32_t minC = minN >> shift; + static constexpr int32_t sigC = sigN >> shiftSign; + + static const int32_t mulN = 0x52000000; // (1 << 23) / minN + static const int32_t mulC = 0x33800000; // minN / (1 << (23 - shift)) + static const int32_t subC = 0x003FF; // max flt32 subnormal downshifted + static const int32_t norC = 0x00400; // min flt32 normal downshifted + + static constexpr int32_t maxD = infC - maxC - 1; + static constexpr int32_t minD = minC - subC - 1; +}; + +// Arithmetic operators on GPU +// CUDA 9.0 provides built-in arithmetic operators for half while +// CUDA 7.5 and 8.0 do not. The arithmetic operators defined here are +// for users to write similar CUDA code in CUDA 7.5 and 8.0 as in +// CUDA 9.0 regarding the half data type. 
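+// In sketch form, the fallback pattern used throughout the block below is:
+//   float fa = static_cast<float>(float16(a));  // widen half -> float
+//   float fb = static_cast<float>(float16(b));
+//   half  r  = half(float16(fa + fb));          // narrow float -> half
+// i.e. on devices older than sm_53 each operator widens both operands to
+// float, computes there, and narrows the result back to half.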
+#if defined(PADDLE_CUDA_FP16) && CUDA_VERSION < 9000 + +DEVICE inline half operator+(const half& a, const half& b) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530 + return __hadd(a, b); +#else + float res = static_cast(float16(a)) + static_cast(float16(b)); + return half(float16(res)); +#endif +} + +DEVICE inline half operator-(const half& a, const half& b) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530 + return __hsub(a, b); +#else + float res = static_cast(float16(a)) - static_cast(float16(b)); + return half(float16(res)); +#endif +} + +DEVICE inline half operator*(const half& a, const half& b) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530 + return __hmul(a, b); +#else + float res = static_cast(float16(a)) * static_cast(float16(b)); + return half(float16(res)); +#endif +} + +DEVICE inline half operator/(const half& a, const half& b) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 300 + float num = __half2float(a); + float denom = __half2float(b); + return __float2half(num / denom); +#else + float res = static_cast(float16(a)) / static_cast(float16(b)); + return half(float16(res)); +#endif +} + +DEVICE inline half operator-(const half& a) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530 + return __hneg(a); +#else + float res = -static_cast(float16(a)); + return half(float16(res)); +#endif +} + +DEVICE inline half& operator+=(half& a, const half& b) { // NOLINT + a = a + b; + return a; +} + +DEVICE inline half& operator-=(half& a, const half& b) { // NOLINT + a = a - b; + return a; +} + +DEVICE inline half& operator*=(half& a, const half& b) { // NOLINT + a = a * b; + return a; +} + +DEVICE inline half& operator/=(half& a, const half& b) { // NOLINT + a = a / b; + return a; +} + +DEVICE inline bool operator==(const half& a, const half& b) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530 + return __heq(a, b); +#else + return static_cast(float16(a)) == static_cast(float16(b)); +#endif +} + +DEVICE inline bool operator!=(const half& a, const half& b) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530 + return __hne(a, b); +#else + return static_cast(float16(a)) != static_cast(float16(b)); +#endif +} + +DEVICE inline bool operator<(const half& a, const half& b) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530 + return __hlt(a, b); +#else + return static_cast(float16(a)) < static_cast(float16(b)); +#endif +} + +DEVICE inline bool operator<=(const half& a, const half& b) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530 + return __hle(a, b); +#else + return static_cast(float16(a)) <= static_cast(float16(b)); +#endif +} + +DEVICE inline bool operator>(const half& a, const half& b) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530 + return __hgt(a, b); +#else + return static_cast(float16(a)) > static_cast(float16(b)); +#endif +} + +DEVICE inline bool operator>=(const half& a, const half& b) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530 + return __hge(a, b); +#else + return static_cast(float16(a)) >= static_cast(float16(b)); +#endif +} + +#endif // PADDLE_CUDA_FP16 + +// Arithmetic operators for float16 on GPU +#if defined(PADDLE_CUDA_FP16) +HOSTDEVICE inline float16 operator+(const float16& a, const float16& b) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530 + return float16(__hadd(half(a), half(b))); +#else + return float16(static_cast(a) + static_cast(b)); +#endif +} + +HOSTDEVICE inline float16 operator-(const float16& a, const float16& b) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530 + return float16(__hsub(half(a), 
half(b))); +#else + return float16(static_cast(a) - static_cast(b)); +#endif +} + +HOSTDEVICE inline float16 operator*(const float16& a, const float16& b) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530 + return float16(__hmul(half(a), half(b))); +#else + return float16(static_cast(a) * static_cast(b)); +#endif +} + +HOSTDEVICE inline float16 operator/(const float16& a, const float16& b) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 300 + // TODO(kexinzhao): check which cuda version starts to support __hdiv + float num = __half2float(half(a)); + float denom = __half2float(half(b)); + return float16(num / denom); +#else + return float16(static_cast(a) / static_cast(b)); +#endif +} + +HOSTDEVICE inline float16 operator-(const float16& a) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530 + return float16(__hneg(half(a))); +#else + float16 res; + res.x = a.x ^ 0x8000; + return res; +#endif +} + +HOSTDEVICE inline float16& operator+=(float16& a, const float16& b) { // NOLINT + a = a + b; + return a; +} + +HOSTDEVICE inline float16& operator-=(float16& a, const float16& b) { // NOLINT + a = a - b; + return a; +} + +HOSTDEVICE inline float16& operator*=(float16& a, const float16& b) { // NOLINT + a = a * b; + return a; +} + +HOSTDEVICE inline float16& operator/=(float16& a, const float16& b) { // NOLINT + a = a / b; + return a; +} + +HOSTDEVICE inline bool operator==(const float16& a, const float16& b) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530 + return __heq(half(a), half(b)); +#else + return static_cast(a) == static_cast(b); +#endif +} + +HOSTDEVICE inline bool operator!=(const float16& a, const float16& b) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530 + return __hne(half(a), half(b)); +#else + return static_cast(a) != static_cast(b); +#endif +} + +HOSTDEVICE inline bool operator<(const float16& a, const float16& b) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530 + return __hlt(half(a), half(b)); +#else + return static_cast(a) < static_cast(b); +#endif +} + +HOSTDEVICE inline bool operator<=(const float16& a, const float16& b) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530 + return __hle(half(a), half(b)); +#else + return static_cast(a) <= static_cast(b); +#endif +} + +HOSTDEVICE inline bool operator>(const float16& a, const float16& b) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530 + return __hgt(half(a), half(b)); +#else + return static_cast(a) > static_cast(b); +#endif +} + +HOSTDEVICE inline bool operator>=(const float16& a, const float16& b) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530 + return __hge(half(a), half(b)); +#else + return static_cast(a) >= static_cast(b); +#endif +} + +// Arithmetic operators for float16 on ARMv8.2-A CPU +#elif defined(PADDLE_WITH_NATIVE_FP16) +inline float16 operator+(const float16& a, const float16& b) { + float16 res; + asm volatile( + "ld1 {v0.h}[0], [%[a_ptr]]\n" + "ld1 {v1.h}[0], [%[b_ptr]]\n" + "fadd h0, h0, h1\n" + "st1 {v0.h}[0], [%[res_ptr]]\n" + : // outputs + : // inputs + [a_ptr] "r"(&(a.x)), + [b_ptr] "r"(&(b.x)), + [res_ptr] "r"(&(res.x)) + : // clobbers + "memory", "v0", "v1"); + return res; +} + +inline float16 operator-(const float16& a, const float16& b) { + float16 res; + asm volatile( + "ld1 {v0.h}[0], [%[a_ptr]]\n" + "ld1 {v1.h}[0], [%[b_ptr]]\n" + "fsub h0, h0, h1\n" + "st1 {v0.h}[0], [%[res_ptr]]\n" + : // outputs + : // inputs + [a_ptr] "r"(&(a.x)), + [b_ptr] "r"(&(b.x)), + [res_ptr] "r"(&(res.x)) + : // clobbers + "memory", "v0", "v1"); + return res; +} + +inline float16 
operator*(const float16& a, const float16& b) { + float16 res; + asm volatile( + "ld1 {v0.h}[0], [%[a_ptr]]\n" + "ld1 {v1.h}[0], [%[b_ptr]]\n" + "fmul h0, h0, h1\n" + "st1 {v0.h}[0], [%[res_ptr]]\n" + : // outputs + : // inputs + [a_ptr] "r"(&(a.x)), + [b_ptr] "r"(&(b.x)), + [res_ptr] "r"(&(res.x)) + : // clobbers + "memory", "v0", "v1"); + return res; +} + +inline float16 operator/(const float16& a, const float16& b) { + float16 res; + asm volatile( + "ld1 {v0.h}[0], [%[a_ptr]]\n" + "ld1 {v1.h}[0], [%[b_ptr]]\n" + "fdiv h0, h0, h1\n" + "st1 {v0.h}[0], [%[res_ptr]]\n" + : // outputs + : // inputs + [a_ptr] "r"(&(a.x)), + [b_ptr] "r"(&(b.x)), + [res_ptr] "r"(&(res.x)) + : // clobbers + "memory", "v0", "v1"); + return res; +} + +inline float16 operator-(const float16& a) { + float16 res; + asm volatile( + "ld1 {v0.h}[0], [%[a_ptr]]\n" + "fneg h0, h0\n" + "st1 {v0.h}[0], [%[res_ptr]]\n" + : // outputs + : // inputs + [a_ptr] "r"(&(a.x)), + [res_ptr] "r"(&(res.x)) + : // clobbers + "memory", "v0"); + return res; +} + +inline float16& operator+=(float16& a, const float16& b) { // NOLINT + a = a + b; + return a; +} + +inline float16& operator-=(float16& a, const float16& b) { // NOLINT + a = a - b; + return a; +} + +inline float16& operator*=(float16& a, const float16& b) { // NOLINT + a = a * b; + return a; +} + +inline float16& operator/=(float16& a, const float16& b) { // NOLINT + a = a / b; + return a; +} + +inline bool operator==(const float16& a, const float16& b) { + uint16_t res; + asm volatile( + "ld1 {v0.h}[0], [%[a_ptr]]\n" + "ld1 {v1.h}[0], [%[b_ptr]]\n" + "fcmeq h0, h0, h1\n" + "st1 {v0.h}[0], [%[res_ptr]]\n" + : // outputs + : // inputs + [a_ptr] "r"(&(a.x)), + [b_ptr] "r"(&(b.x)), + [res_ptr] "r"(&res) + : // clobbers + "memory", "v0", "v1"); + return (res & 0xffff) != 0; +} + +inline bool operator!=(const float16& a, const float16& b) { return !(a == b); } + +inline bool operator<(const float16& a, const float16& b) { + uint16_t res; + asm volatile( + "ld1 {v1.h}[0], [%[a_ptr]]\n" + "ld1 {v0.h}[0], [%[b_ptr]]\n" + "fcmgt h0, h0, h1\n" + "st1 {v0.h}[0], [%[res_ptr]]\n" + : // outputs + : // inputs + [a_ptr] "r"(&(a.x)), + [b_ptr] "r"(&(b.x)), + [res_ptr] "r"(&res) + : // clobbers + "memory", "v0", "v1"); + return (res & 0xffff) != 0; +} + +inline bool operator<=(const float16& a, const float16& b) { + uint16_t res; + asm volatile( + "ld1 {v1.h}[0], [%[a_ptr]]\n" + "ld1 {v0.h}[0], [%[b_ptr]]\n" + "fcmge h0, h0, h1\n" + "st1 {v0.h}[0], [%[res_ptr]]\n" + : // outputs + : // inputs + [a_ptr] "r"(&(a.x)), + [b_ptr] "r"(&(b.x)), + [res_ptr] "r"(&res) + : // clobbers + "memory", "v0", "v1"); + return (res & 0xffff) != 0; +} + +inline bool operator>(const float16& a, const float16& b) { + uint16_t res; + asm volatile( + "ld1 {v0.h}[0], [%[a_ptr]]\n" + "ld1 {v1.h}[0], [%[b_ptr]]\n" + "fcmgt h0, h0, h1\n" + "st1 {v0.h}[0], [%[res_ptr]]\n" + : // outputs + : // inputs + [a_ptr] "r"(&(a.x)), + [b_ptr] "r"(&(b.x)), + [res_ptr] "r"(&res) + : // clobbers + "memory", "v0", "v1"); + return (res & 0xffff) != 0; +} + +inline bool operator>=(const float16& a, const float16& b) { + uint16_t res; + asm volatile( + "ld1 {v0.h}[0], [%[a_ptr]]\n" + "ld1 {v1.h}[0], [%[b_ptr]]\n" + "fcmge h0, h0, h1\n" + "st1 {v0.h}[0], [%[res_ptr]]\n" + : // outputs + : // inputs + [a_ptr] "r"(&(a.x)), + [b_ptr] "r"(&(b.x)), + [res_ptr] "r"(&res) + : // clobbers + "memory", "v0", "v1"); + return (res & 0xffff) != 0; +} + +// Arithmetic operators for float16, software emulated on other CPU +#else +inline float16 
operator+(const float16& a, const float16& b) { + return float16(static_cast(a) + static_cast(b)); +} + +inline float16 operator-(const float16& a, const float16& b) { + return float16(static_cast(a) - static_cast(b)); +} + +inline float16 operator*(const float16& a, const float16& b) { + return float16(static_cast(a) * static_cast(b)); +} + +inline float16 operator/(const float16& a, const float16& b) { + return float16(static_cast(a) / static_cast(b)); +} + +inline float16 operator-(const float16& a) { + float16 res; + res.x = a.x ^ 0x8000; + return res; +} + +inline float16& operator+=(float16& a, const float16& b) { // NOLINT + a = float16(static_cast(a) + static_cast(b)); + return a; +} + +inline float16& operator-=(float16& a, const float16& b) { // NOLINT + a = float16(static_cast(a) - static_cast(b)); + return a; +} + +inline float16& operator*=(float16& a, const float16& b) { // NOLINT + a = float16(static_cast(a) * static_cast(b)); + return a; +} + +inline float16& operator/=(float16& a, const float16& b) { // NOLINT + a = float16(static_cast(a) / static_cast(b)); + return a; +} + +inline bool operator==(const float16& a, const float16& b) { + return static_cast(a) == static_cast(b); +} + +inline bool operator!=(const float16& a, const float16& b) { + return static_cast(a) != static_cast(b); +} + +inline bool operator<(const float16& a, const float16& b) { + return static_cast(a) < static_cast(b); +} + +inline bool operator<=(const float16& a, const float16& b) { + return static_cast(a) <= static_cast(b); +} + +inline bool operator>(const float16& a, const float16& b) { + return static_cast(a) > static_cast(b); +} + +inline bool operator>=(const float16& a, const float16& b) { + return static_cast(a) >= static_cast(b); +} +#endif + +HOSTDEVICE inline float16 raw_uint16_to_float16(uint16_t a) { + float16 res; + res.x = a; + return res; +} + +HOSTDEVICE inline bool(isnan)(const float16& a) { +#if defined(PADDLE_CUDA_FP16) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530 + return __hisnan(half(a)); +#else + return (a.x & 0x7fff) > 0x7c00; +#endif +} + +HOSTDEVICE inline bool(isinf)(const float16& a) { + return (a.x & 0x7fff) == 0x7c00; +} + +HOSTDEVICE inline bool(isfinite)(const float16& a) { + return !((isnan)(a)) && !((isinf)(a)); +} + +inline std::ostream& operator<<(std::ostream& os, const float16& a) { + os << static_cast(a); + return os; +} + +} // namespace fluid +} // namespace lite +} // namespace paddle + +namespace std { + +// Override the std::is_pod::value for float16 +// The reason is that different compilers implemented std::is_pod based on +// different C++ standards. float16 class is a plain old data in C++11 given +// that it is both trivial and standard_layout. +// However, std::is_pod in nvcc 8.0 host c++ compiler follows C++0x and is +// more restricted in that you cannot provide any customized +// constructor in float16. Hence, we override is_pod here following C++11 +// so that .cu files can be successfully compiled by nvcc. 
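+// Concretely: float16 is trivial because all of its special member functions
+// are defaulted, and it is standard-layout because its only non-static data
+// member is the public uint16_t x, so both operands of the conjunction below
+// hold under C++11.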
+template <> +struct is_pod { + static const bool value = + is_trivial::value && + is_standard_layout::value; +}; + +template <> +struct is_floating_point + : std::integral_constant< + bool, + std::is_same::type>::value> {}; +template <> +struct is_signed { + static const bool value = true; +}; + +template <> +struct is_unsigned { + static const bool value = false; +}; + +inline bool isnan(const paddle::lite::fluid::float16& a) { + return paddle::lite::fluid::isnan(a); +} + +inline bool isinf(const paddle::lite::fluid::float16& a) { + return paddle::lite::fluid::isinf(a); +} + +template <> +struct numeric_limits { + static const bool is_specialized = true; + static const bool is_signed = true; + static const bool is_integer = false; + static const bool is_exact = false; + static const bool has_infinity = true; + static const bool has_quiet_NaN = true; + static const bool has_signaling_NaN = true; + static const float_denorm_style has_denorm = denorm_present; + static const bool has_denorm_loss = false; + static const std::float_round_style round_style = std::round_to_nearest; + static const bool is_iec559 = false; + static const bool is_bounded = false; + static const bool is_modulo = false; + static const int digits = 11; + static const int digits10 = 3; + static const int max_digits10 = 5; + static const int radix = 2; + static const int min_exponent = -13; + static const int min_exponent10 = -4; + static const int max_exponent = 16; + static const int max_exponent10 = 4; + static const bool traps = true; + static const bool tinyness_before = false; + + static paddle::lite::fluid::float16(min)() { + return paddle::lite::fluid::raw_uint16_to_float16(0x400); + } + static paddle::lite::fluid::float16 lowest() { + return paddle::lite::fluid::raw_uint16_to_float16(0xfbff); + } + static paddle::lite::fluid::float16(max)() { + return paddle::lite::fluid::raw_uint16_to_float16(0x7bff); + } + static paddle::lite::fluid::float16 epsilon() { + return paddle::lite::fluid::raw_uint16_to_float16(0x0800); + } + static paddle::lite::fluid::float16 round_error() { + return paddle::lite::fluid::float16(0.5); + } + static paddle::lite::fluid::float16 infinity() { + return paddle::lite::fluid::raw_uint16_to_float16(0x7c00); + } + static paddle::lite::fluid::float16 quiet_NaN() { + return paddle::lite::fluid::raw_uint16_to_float16(0x7e00); + } + static paddle::lite::fluid::float16 signaling_NaN() { + return paddle::lite::fluid::raw_uint16_to_float16(0x7e00); + } + static paddle::lite::fluid::float16 denorm_min() { + return paddle::lite::fluid::raw_uint16_to_float16(0x1); + } +}; + +} // namespace std + +namespace Eigen { + +using float16 = paddle::lite::fluid::float16; + +template <> +struct NumTraits : GenericNumTraits { + enum { + IsSigned = true, + IsInteger = false, + IsComplex = false, + RequireInitialization = false + }; + + HOSTDEVICE static inline float16 epsilon() { + return paddle::lite::fluid::raw_uint16_to_float16(0x0800); + } + HOSTDEVICE static inline float16 dummy_precision() { return float16(1e-2f); } + HOSTDEVICE static inline float16 highest() { + return paddle::lite::fluid::raw_uint16_to_float16(0x7bff); + } + HOSTDEVICE static inline float16 lowest() { + return paddle::lite::fluid::raw_uint16_to_float16(0xfbff); + } + HOSTDEVICE static inline float16 infinity() { + return paddle::lite::fluid::raw_uint16_to_float16(0x7c00); + } + HOSTDEVICE static inline float16 quiet_NaN() { + return paddle::lite::fluid::raw_uint16_to_float16(0x7c01); + } +}; + +namespace numext { + +template <> 
+HOSTDEVICE inline bool(isnan)(const float16& a) {
+  return (paddle::lite::fluid::isnan)(a);
+}
+
+template <>
+HOSTDEVICE inline bool(isinf)(const float16& a) {
+  return (paddle::lite::fluid::isinf)(a);
+}
+
+template <>
+HOSTDEVICE inline bool(isfinite)(const float16& a) {
+  return (paddle::lite::fluid::isfinite)(a);
+}
+
+template <>
+HOSTDEVICE inline float16 exp(const float16& a) {
+  return float16(::expf(static_cast<float>(a)));
+}
+
+template <>
+HOSTDEVICE inline float16 erf(const float16& a) {
+  return float16(::erff(static_cast<float>(a)));
+}
+
+template <>
+HOSTDEVICE inline float16 log(const float16& a) {
+  return float16(::logf(static_cast<float>(a)));
+}
+
+template <>
+HOSTDEVICE inline float16 tanh(const float16& a) {
+  return float16(::tanhf(static_cast<float>(a)));
+}
+
+template <>
+HOSTDEVICE inline float16 sqrt(const float16& a) {
+  return float16(::sqrtf(static_cast<float>(a)));
+}
+
+template <>
+HOSTDEVICE inline float16 ceil(const float16& a) {
+  return float16(::ceilf(static_cast<float>(a)));
+}
+
+template <>
+HOSTDEVICE inline float16 floor(const float16& a) {
+  return float16(::floorf(static_cast<float>(a)));
+}
+
+template <>
+HOSTDEVICE inline float16 round(const float16& a) {
+  return float16(::roundf(static_cast<float>(a)));
+}
+
+template <>
+HOSTDEVICE inline float16 pow(const float16& a, const float16& b) {
+  return float16(::powf(static_cast<float>(a), static_cast<float>(b)));
+}
+
+template <>
+HOSTDEVICE inline float16 abs(const float16& a) {
+  return float16(::fabs(static_cast<float>(a)));
+}
+
+}  // namespace numext
+
+}  // namespace Eigen
diff --git a/lite/fluid/lod.h b/lite/fluid/lod.h
new file mode 100644
index 00000000000..68068ba1d01
--- /dev/null
+++ b/lite/fluid/lod.h
@@ -0,0 +1,38 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <vector>
+
+namespace paddle {
+namespace lite {
+namespace fluid {
+using LoD = std::vector<std::vector<size_t>>;
+
+LoD ToAbsOffset(const LoD &in) {
+  // the lowest level stores relative offsets
+  if (in.empty() || in.size() == 1) return in;
+  LoD result = in;
+  for (auto level = static_cast<int>(in.size() - 2); level >= 0; level--) {
+    for (size_t i = 0; i < in[level].size(); ++i) {
+      size_t index = in[level][i];
+      result[level][i] = result[level + 1][index];
+    }
+  }
+  return result;
+}
+}  // namespace fluid
+}  // namespace lite
+}  // namespace paddle
diff --git a/lite/fluid/math.h b/lite/fluid/math.h
new file mode 100644
index 00000000000..8cc24200d37
--- /dev/null
+++ b/lite/fluid/math.h
@@ -0,0 +1,42 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "paddle/fluid/platform/float16.h"
+#include "paddle/fluid/platform/hostdevice.h"
+
+#include "math.h"  // NOLINT
+
+namespace paddle {
+namespace operators {
+
+inline HOSTDEVICE platform::float16 real_exp(platform::float16 x) {
+  return static_cast<platform::float16>(::expf(static_cast<float>(x)));
+}
+
+inline HOSTDEVICE float real_exp(float x) { return ::expf(x); }
+
+inline HOSTDEVICE double real_exp(double x) { return ::exp(x); }
+
+inline HOSTDEVICE platform::float16 real_log(platform::float16 x) {
+  return static_cast<platform::float16>(::logf(static_cast<float>(x)));
+}
+
+inline HOSTDEVICE float real_log(float x) { return ::logf(x); }
+
+inline HOSTDEVICE double real_log(double x) { return ::log(x); }
+
+}  // namespace operators
+}  // namespace paddle
diff --git a/lite/fpga/CMakeLists.txt b/lite/fpga/CMakeLists.txt
new file mode 100644
index 00000000000..2956d947466
--- /dev/null
+++ b/lite/fpga/CMakeLists.txt
@@ -0,0 +1,15 @@
+if (NOT LITE_WITH_FPGA)
+  return()
+endif()
+
+set(LITE_FPGA_KD_PATH "${PADDLE_SOURCE_DIR}/lite/fpga/KD")
+set(LITE_FPGA_PATH "${PADDLE_SOURCE_DIR}/lite/fpga")
+
+message("fpga_kd_path ${LITE_FPGA_KD_PATH}")
+message("fpga_path ${LITE_FPGA_PATH}")
+file(GLOB_RECURSE KD_CPP *.cpp *.cc)
+file(GLOB FPGA_CPP "${LITE_FPGA_PATH}/*.cc")
+
+cc_library(kernel_fpga SRCS ${KD_CPP} ${FPGA_CPP})
+cc_library(lite_tensor_fpga SRCS lite_tensor.cc DEPS memory)
+cc_library(fpga_target_wrapper SRCS ${LITE_FPGA_PATH}/target_wrapper.cc DEPS kernel_fpga)
diff --git a/lite/fpga/KD/alignment.h b/lite/fpga/KD/alignment.h
new file mode 100644
index 00000000000..74b80de042f
--- /dev/null
+++ b/lite/fpga/KD/alignment.h
@@ -0,0 +1,26 @@
+/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+#include
+
+#include "lite/fpga/KD/llapi/zynqmp_api.h"
+
+namespace paddle {
+namespace zynqmp {
+
+inline int align_image(int wc) { return align_to_x(wc, IMAGE_ALIGNMENT); }
+
+}  // namespace zynqmp
+}  // namespace paddle
diff --git a/lite/fpga/KD/context.hpp b/lite/fpga/KD/context.hpp
new file mode 100644
index 00000000000..adb9f8a13d8
--- /dev/null
+++ b/lite/fpga/KD/context.hpp
@@ -0,0 +1,50 @@
+/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+
+#include
+#include "lite/fpga/KD/pe.hpp"
+#include "lite/fpga/KD/pes/conv_pe.hpp"
+#include "lite/fpga/KD/pes/depthwise_conv_pe.hpp"
+#include "lite/fpga/KD/pes/fully_connected_pe.hpp"
+#include "lite/fpga/KD/pes/input_pe.hpp"
+#include "lite/fpga/KD/pes/output_pe.hpp"
+#include "lite/fpga/KD/pes/pooling_pe.hpp"
+#include "lite/fpga/KD/pes/softmax_pe.hpp"
+
+namespace paddle {
+namespace zynqmp {
+
+class Context {
+ public:
+  template <typename Ptype>
+  Ptype& pe() {
+    if (pe_ == nullptr) {
+      pe_ = new Ptype();
+    }
+    return static_cast<Ptype&>(*pe_);
+  }
+
+  ~Context() {
+    if (pe_ != nullptr) {
+      delete pe_;
+    }
+  }
+
+ private:
+  PE* pe_ = nullptr;
+};
+}  // namespace zynqmp
+}  // namespace paddle
diff --git a/lite/fpga/KD/dl_engine.cpp b/lite/fpga/KD/dl_engine.cpp
new file mode 100644
index 00000000000..90c447c8897
--- /dev/null
+++ b/lite/fpga/KD/dl_engine.cpp
@@ -0,0 +1,27 @@
+/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "lite/fpga/KD/dl_engine.hpp"
+namespace paddle {
+namespace zynqmp {
+
+DLEngine::DLEngine() {
+  open_device();
+  struct DeviceInfo info;
+  int ret = get_device_info(info);
+  filter::set_filter_capacity(info.filter_cap);
+}
+
+}  // namespace zynqmp
+}  // namespace paddle
diff --git a/lite/fpga/KD/dl_engine.hpp b/lite/fpga/KD/dl_engine.hpp
new file mode 100644
index 00000000000..1ac4f4f0fe7
--- /dev/null
+++ b/lite/fpga/KD/dl_engine.hpp
@@ -0,0 +1,36 @@
+/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+
+#include
+
+#include "lite/fpga/KD/llapi/filter.h"
+#include "lite/fpga/KD/llapi/zynqmp_api.h"
+
+namespace paddle {
+namespace zynqmp {
+
+class DLEngine {
+ public:
+  static DLEngine& get_instance() {
+    static DLEngine s_instance;
+    return s_instance;
+  }
+
+ private:
+  DLEngine();
+};
+}  // namespace zynqmp
+}  // namespace paddle
diff --git a/lite/fpga/KD/float16.hpp b/lite/fpga/KD/float16.hpp
new file mode 100755
index 00000000000..9f123171969
--- /dev/null
+++ b/lite/fpga/KD/float16.hpp
@@ -0,0 +1,508 @@
+/* Copyright (c) 2019 PaddlePaddle Authors.
All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include + +namespace paddle { +namespace zynqmp { + +typedef uint16_t float16; + +static const uint32_t mantissatable[2048] = { + 0x00000000, 0x33800000, 0x34000000, 0x34400000, 0x34800000, 0x34a00000, + 0x34c00000, 0x34e00000, 0x35000000, 0x35100000, 0x35200000, 0x35300000, + 0x35400000, 0x35500000, 0x35600000, 0x35700000, 0x35800000, 0x35880000, + 0x35900000, 0x35980000, 0x35a00000, 0x35a80000, 0x35b00000, 0x35b80000, + 0x35c00000, 0x35c80000, 0x35d00000, 0x35d80000, 0x35e00000, 0x35e80000, + 0x35f00000, 0x35f80000, 0x36000000, 0x36040000, 0x36080000, 0x360c0000, + 0x36100000, 0x36140000, 0x36180000, 0x361c0000, 0x36200000, 0x36240000, + 0x36280000, 0x362c0000, 0x36300000, 0x36340000, 0x36380000, 0x363c0000, + 0x36400000, 0x36440000, 0x36480000, 0x364c0000, 0x36500000, 0x36540000, + 0x36580000, 0x365c0000, 0x36600000, 0x36640000, 0x36680000, 0x366c0000, + 0x36700000, 0x36740000, 0x36780000, 0x367c0000, 0x36800000, 0x36820000, + 0x36840000, 0x36860000, 0x36880000, 0x368a0000, 0x368c0000, 0x368e0000, + 0x36900000, 0x36920000, 0x36940000, 0x36960000, 0x36980000, 0x369a0000, + 0x369c0000, 0x369e0000, 0x36a00000, 0x36a20000, 0x36a40000, 0x36a60000, + 0x36a80000, 0x36aa0000, 0x36ac0000, 0x36ae0000, 0x36b00000, 0x36b20000, + 0x36b40000, 0x36b60000, 0x36b80000, 0x36ba0000, 0x36bc0000, 0x36be0000, + 0x36c00000, 0x36c20000, 0x36c40000, 0x36c60000, 0x36c80000, 0x36ca0000, + 0x36cc0000, 0x36ce0000, 0x36d00000, 0x36d20000, 0x36d40000, 0x36d60000, + 0x36d80000, 0x36da0000, 0x36dc0000, 0x36de0000, 0x36e00000, 0x36e20000, + 0x36e40000, 0x36e60000, 0x36e80000, 0x36ea0000, 0x36ec0000, 0x36ee0000, + 0x36f00000, 0x36f20000, 0x36f40000, 0x36f60000, 0x36f80000, 0x36fa0000, + 0x36fc0000, 0x36fe0000, 0x37000000, 0x37010000, 0x37020000, 0x37030000, + 0x37040000, 0x37050000, 0x37060000, 0x37070000, 0x37080000, 0x37090000, + 0x370a0000, 0x370b0000, 0x370c0000, 0x370d0000, 0x370e0000, 0x370f0000, + 0x37100000, 0x37110000, 0x37120000, 0x37130000, 0x37140000, 0x37150000, + 0x37160000, 0x37170000, 0x37180000, 0x37190000, 0x371a0000, 0x371b0000, + 0x371c0000, 0x371d0000, 0x371e0000, 0x371f0000, 0x37200000, 0x37210000, + 0x37220000, 0x37230000, 0x37240000, 0x37250000, 0x37260000, 0x37270000, + 0x37280000, 0x37290000, 0x372a0000, 0x372b0000, 0x372c0000, 0x372d0000, + 0x372e0000, 0x372f0000, 0x37300000, 0x37310000, 0x37320000, 0x37330000, + 0x37340000, 0x37350000, 0x37360000, 0x37370000, 0x37380000, 0x37390000, + 0x373a0000, 0x373b0000, 0x373c0000, 0x373d0000, 0x373e0000, 0x373f0000, + 0x37400000, 0x37410000, 0x37420000, 0x37430000, 0x37440000, 0x37450000, + 0x37460000, 0x37470000, 0x37480000, 0x37490000, 0x374a0000, 0x374b0000, + 0x374c0000, 0x374d0000, 0x374e0000, 0x374f0000, 0x37500000, 0x37510000, + 0x37520000, 0x37530000, 0x37540000, 0x37550000, 0x37560000, 0x37570000, + 0x37580000, 0x37590000, 0x375a0000, 0x375b0000, 0x375c0000, 0x375d0000, + 0x375e0000, 0x375f0000, 0x37600000, 0x37610000, 0x37620000, 0x37630000, + 0x37640000, 0x37650000, 
0x37660000, 0x37670000, 0x37680000, 0x37690000, + 0x376a0000, 0x376b0000, 0x376c0000, 0x376d0000, 0x376e0000, 0x376f0000, + 0x37700000, 0x37710000, 0x37720000, 0x37730000, 0x37740000, 0x37750000, + 0x37760000, 0x37770000, 0x37780000, 0x37790000, 0x377a0000, 0x377b0000, + 0x377c0000, 0x377d0000, 0x377e0000, 0x377f0000, 0x37800000, 0x37808000, + 0x37810000, 0x37818000, 0x37820000, 0x37828000, 0x37830000, 0x37838000, + 0x37840000, 0x37848000, 0x37850000, 0x37858000, 0x37860000, 0x37868000, + 0x37870000, 0x37878000, 0x37880000, 0x37888000, 0x37890000, 0x37898000, + 0x378a0000, 0x378a8000, 0x378b0000, 0x378b8000, 0x378c0000, 0x378c8000, + 0x378d0000, 0x378d8000, 0x378e0000, 0x378e8000, 0x378f0000, 0x378f8000, + 0x37900000, 0x37908000, 0x37910000, 0x37918000, 0x37920000, 0x37928000, + 0x37930000, 0x37938000, 0x37940000, 0x37948000, 0x37950000, 0x37958000, + 0x37960000, 0x37968000, 0x37970000, 0x37978000, 0x37980000, 0x37988000, + 0x37990000, 0x37998000, 0x379a0000, 0x379a8000, 0x379b0000, 0x379b8000, + 0x379c0000, 0x379c8000, 0x379d0000, 0x379d8000, 0x379e0000, 0x379e8000, + 0x379f0000, 0x379f8000, 0x37a00000, 0x37a08000, 0x37a10000, 0x37a18000, + 0x37a20000, 0x37a28000, 0x37a30000, 0x37a38000, 0x37a40000, 0x37a48000, + 0x37a50000, 0x37a58000, 0x37a60000, 0x37a68000, 0x37a70000, 0x37a78000, + 0x37a80000, 0x37a88000, 0x37a90000, 0x37a98000, 0x37aa0000, 0x37aa8000, + 0x37ab0000, 0x37ab8000, 0x37ac0000, 0x37ac8000, 0x37ad0000, 0x37ad8000, + 0x37ae0000, 0x37ae8000, 0x37af0000, 0x37af8000, 0x37b00000, 0x37b08000, + 0x37b10000, 0x37b18000, 0x37b20000, 0x37b28000, 0x37b30000, 0x37b38000, + 0x37b40000, 0x37b48000, 0x37b50000, 0x37b58000, 0x37b60000, 0x37b68000, + 0x37b70000, 0x37b78000, 0x37b80000, 0x37b88000, 0x37b90000, 0x37b98000, + 0x37ba0000, 0x37ba8000, 0x37bb0000, 0x37bb8000, 0x37bc0000, 0x37bc8000, + 0x37bd0000, 0x37bd8000, 0x37be0000, 0x37be8000, 0x37bf0000, 0x37bf8000, + 0x37c00000, 0x37c08000, 0x37c10000, 0x37c18000, 0x37c20000, 0x37c28000, + 0x37c30000, 0x37c38000, 0x37c40000, 0x37c48000, 0x37c50000, 0x37c58000, + 0x37c60000, 0x37c68000, 0x37c70000, 0x37c78000, 0x37c80000, 0x37c88000, + 0x37c90000, 0x37c98000, 0x37ca0000, 0x37ca8000, 0x37cb0000, 0x37cb8000, + 0x37cc0000, 0x37cc8000, 0x37cd0000, 0x37cd8000, 0x37ce0000, 0x37ce8000, + 0x37cf0000, 0x37cf8000, 0x37d00000, 0x37d08000, 0x37d10000, 0x37d18000, + 0x37d20000, 0x37d28000, 0x37d30000, 0x37d38000, 0x37d40000, 0x37d48000, + 0x37d50000, 0x37d58000, 0x37d60000, 0x37d68000, 0x37d70000, 0x37d78000, + 0x37d80000, 0x37d88000, 0x37d90000, 0x37d98000, 0x37da0000, 0x37da8000, + 0x37db0000, 0x37db8000, 0x37dc0000, 0x37dc8000, 0x37dd0000, 0x37dd8000, + 0x37de0000, 0x37de8000, 0x37df0000, 0x37df8000, 0x37e00000, 0x37e08000, + 0x37e10000, 0x37e18000, 0x37e20000, 0x37e28000, 0x37e30000, 0x37e38000, + 0x37e40000, 0x37e48000, 0x37e50000, 0x37e58000, 0x37e60000, 0x37e68000, + 0x37e70000, 0x37e78000, 0x37e80000, 0x37e88000, 0x37e90000, 0x37e98000, + 0x37ea0000, 0x37ea8000, 0x37eb0000, 0x37eb8000, 0x37ec0000, 0x37ec8000, + 0x37ed0000, 0x37ed8000, 0x37ee0000, 0x37ee8000, 0x37ef0000, 0x37ef8000, + 0x37f00000, 0x37f08000, 0x37f10000, 0x37f18000, 0x37f20000, 0x37f28000, + 0x37f30000, 0x37f38000, 0x37f40000, 0x37f48000, 0x37f50000, 0x37f58000, + 0x37f60000, 0x37f68000, 0x37f70000, 0x37f78000, 0x37f80000, 0x37f88000, + 0x37f90000, 0x37f98000, 0x37fa0000, 0x37fa8000, 0x37fb0000, 0x37fb8000, + 0x37fc0000, 0x37fc8000, 0x37fd0000, 0x37fd8000, 0x37fe0000, 0x37fe8000, + 0x37ff0000, 0x37ff8000, 0x38000000, 0x38004000, 0x38008000, 0x3800c000, + 0x38010000, 0x38014000, 
0x38018000, 0x3801c000, 0x38020000, 0x38024000, + 0x38028000, 0x3802c000, 0x38030000, 0x38034000, 0x38038000, 0x3803c000, + 0x38040000, 0x38044000, 0x38048000, 0x3804c000, 0x38050000, 0x38054000, + 0x38058000, 0x3805c000, 0x38060000, 0x38064000, 0x38068000, 0x3806c000, + 0x38070000, 0x38074000, 0x38078000, 0x3807c000, 0x38080000, 0x38084000, + 0x38088000, 0x3808c000, 0x38090000, 0x38094000, 0x38098000, 0x3809c000, + 0x380a0000, 0x380a4000, 0x380a8000, 0x380ac000, 0x380b0000, 0x380b4000, + 0x380b8000, 0x380bc000, 0x380c0000, 0x380c4000, 0x380c8000, 0x380cc000, + 0x380d0000, 0x380d4000, 0x380d8000, 0x380dc000, 0x380e0000, 0x380e4000, + 0x380e8000, 0x380ec000, 0x380f0000, 0x380f4000, 0x380f8000, 0x380fc000, + 0x38100000, 0x38104000, 0x38108000, 0x3810c000, 0x38110000, 0x38114000, + 0x38118000, 0x3811c000, 0x38120000, 0x38124000, 0x38128000, 0x3812c000, + 0x38130000, 0x38134000, 0x38138000, 0x3813c000, 0x38140000, 0x38144000, + 0x38148000, 0x3814c000, 0x38150000, 0x38154000, 0x38158000, 0x3815c000, + 0x38160000, 0x38164000, 0x38168000, 0x3816c000, 0x38170000, 0x38174000, + 0x38178000, 0x3817c000, 0x38180000, 0x38184000, 0x38188000, 0x3818c000, + 0x38190000, 0x38194000, 0x38198000, 0x3819c000, 0x381a0000, 0x381a4000, + 0x381a8000, 0x381ac000, 0x381b0000, 0x381b4000, 0x381b8000, 0x381bc000, + 0x381c0000, 0x381c4000, 0x381c8000, 0x381cc000, 0x381d0000, 0x381d4000, + 0x381d8000, 0x381dc000, 0x381e0000, 0x381e4000, 0x381e8000, 0x381ec000, + 0x381f0000, 0x381f4000, 0x381f8000, 0x381fc000, 0x38200000, 0x38204000, + 0x38208000, 0x3820c000, 0x38210000, 0x38214000, 0x38218000, 0x3821c000, + 0x38220000, 0x38224000, 0x38228000, 0x3822c000, 0x38230000, 0x38234000, + 0x38238000, 0x3823c000, 0x38240000, 0x38244000, 0x38248000, 0x3824c000, + 0x38250000, 0x38254000, 0x38258000, 0x3825c000, 0x38260000, 0x38264000, + 0x38268000, 0x3826c000, 0x38270000, 0x38274000, 0x38278000, 0x3827c000, + 0x38280000, 0x38284000, 0x38288000, 0x3828c000, 0x38290000, 0x38294000, + 0x38298000, 0x3829c000, 0x382a0000, 0x382a4000, 0x382a8000, 0x382ac000, + 0x382b0000, 0x382b4000, 0x382b8000, 0x382bc000, 0x382c0000, 0x382c4000, + 0x382c8000, 0x382cc000, 0x382d0000, 0x382d4000, 0x382d8000, 0x382dc000, + 0x382e0000, 0x382e4000, 0x382e8000, 0x382ec000, 0x382f0000, 0x382f4000, + 0x382f8000, 0x382fc000, 0x38300000, 0x38304000, 0x38308000, 0x3830c000, + 0x38310000, 0x38314000, 0x38318000, 0x3831c000, 0x38320000, 0x38324000, + 0x38328000, 0x3832c000, 0x38330000, 0x38334000, 0x38338000, 0x3833c000, + 0x38340000, 0x38344000, 0x38348000, 0x3834c000, 0x38350000, 0x38354000, + 0x38358000, 0x3835c000, 0x38360000, 0x38364000, 0x38368000, 0x3836c000, + 0x38370000, 0x38374000, 0x38378000, 0x3837c000, 0x38380000, 0x38384000, + 0x38388000, 0x3838c000, 0x38390000, 0x38394000, 0x38398000, 0x3839c000, + 0x383a0000, 0x383a4000, 0x383a8000, 0x383ac000, 0x383b0000, 0x383b4000, + 0x383b8000, 0x383bc000, 0x383c0000, 0x383c4000, 0x383c8000, 0x383cc000, + 0x383d0000, 0x383d4000, 0x383d8000, 0x383dc000, 0x383e0000, 0x383e4000, + 0x383e8000, 0x383ec000, 0x383f0000, 0x383f4000, 0x383f8000, 0x383fc000, + 0x38400000, 0x38404000, 0x38408000, 0x3840c000, 0x38410000, 0x38414000, + 0x38418000, 0x3841c000, 0x38420000, 0x38424000, 0x38428000, 0x3842c000, + 0x38430000, 0x38434000, 0x38438000, 0x3843c000, 0x38440000, 0x38444000, + 0x38448000, 0x3844c000, 0x38450000, 0x38454000, 0x38458000, 0x3845c000, + 0x38460000, 0x38464000, 0x38468000, 0x3846c000, 0x38470000, 0x38474000, + 0x38478000, 0x3847c000, 0x38480000, 0x38484000, 0x38488000, 0x3848c000, + 0x38490000, 0x38494000, 
0x38498000, 0x3849c000, 0x384a0000, 0x384a4000, + 0x384a8000, 0x384ac000, 0x384b0000, 0x384b4000, 0x384b8000, 0x384bc000, + 0x384c0000, 0x384c4000, 0x384c8000, 0x384cc000, 0x384d0000, 0x384d4000, + 0x384d8000, 0x384dc000, 0x384e0000, 0x384e4000, 0x384e8000, 0x384ec000, + 0x384f0000, 0x384f4000, 0x384f8000, 0x384fc000, 0x38500000, 0x38504000, + 0x38508000, 0x3850c000, 0x38510000, 0x38514000, 0x38518000, 0x3851c000, + 0x38520000, 0x38524000, 0x38528000, 0x3852c000, 0x38530000, 0x38534000, + 0x38538000, 0x3853c000, 0x38540000, 0x38544000, 0x38548000, 0x3854c000, + 0x38550000, 0x38554000, 0x38558000, 0x3855c000, 0x38560000, 0x38564000, + 0x38568000, 0x3856c000, 0x38570000, 0x38574000, 0x38578000, 0x3857c000, + 0x38580000, 0x38584000, 0x38588000, 0x3858c000, 0x38590000, 0x38594000, + 0x38598000, 0x3859c000, 0x385a0000, 0x385a4000, 0x385a8000, 0x385ac000, + 0x385b0000, 0x385b4000, 0x385b8000, 0x385bc000, 0x385c0000, 0x385c4000, + 0x385c8000, 0x385cc000, 0x385d0000, 0x385d4000, 0x385d8000, 0x385dc000, + 0x385e0000, 0x385e4000, 0x385e8000, 0x385ec000, 0x385f0000, 0x385f4000, + 0x385f8000, 0x385fc000, 0x38600000, 0x38604000, 0x38608000, 0x3860c000, + 0x38610000, 0x38614000, 0x38618000, 0x3861c000, 0x38620000, 0x38624000, + 0x38628000, 0x3862c000, 0x38630000, 0x38634000, 0x38638000, 0x3863c000, + 0x38640000, 0x38644000, 0x38648000, 0x3864c000, 0x38650000, 0x38654000, + 0x38658000, 0x3865c000, 0x38660000, 0x38664000, 0x38668000, 0x3866c000, + 0x38670000, 0x38674000, 0x38678000, 0x3867c000, 0x38680000, 0x38684000, + 0x38688000, 0x3868c000, 0x38690000, 0x38694000, 0x38698000, 0x3869c000, + 0x386a0000, 0x386a4000, 0x386a8000, 0x386ac000, 0x386b0000, 0x386b4000, + 0x386b8000, 0x386bc000, 0x386c0000, 0x386c4000, 0x386c8000, 0x386cc000, + 0x386d0000, 0x386d4000, 0x386d8000, 0x386dc000, 0x386e0000, 0x386e4000, + 0x386e8000, 0x386ec000, 0x386f0000, 0x386f4000, 0x386f8000, 0x386fc000, + 0x38700000, 0x38704000, 0x38708000, 0x3870c000, 0x38710000, 0x38714000, + 0x38718000, 0x3871c000, 0x38720000, 0x38724000, 0x38728000, 0x3872c000, + 0x38730000, 0x38734000, 0x38738000, 0x3873c000, 0x38740000, 0x38744000, + 0x38748000, 0x3874c000, 0x38750000, 0x38754000, 0x38758000, 0x3875c000, + 0x38760000, 0x38764000, 0x38768000, 0x3876c000, 0x38770000, 0x38774000, + 0x38778000, 0x3877c000, 0x38780000, 0x38784000, 0x38788000, 0x3878c000, + 0x38790000, 0x38794000, 0x38798000, 0x3879c000, 0x387a0000, 0x387a4000, + 0x387a8000, 0x387ac000, 0x387b0000, 0x387b4000, 0x387b8000, 0x387bc000, + 0x387c0000, 0x387c4000, 0x387c8000, 0x387cc000, 0x387d0000, 0x387d4000, + 0x387d8000, 0x387dc000, 0x387e0000, 0x387e4000, 0x387e8000, 0x387ec000, + 0x387f0000, 0x387f4000, 0x387f8000, 0x387fc000, 0x38000000, 0x38002000, + 0x38004000, 0x38006000, 0x38008000, 0x3800a000, 0x3800c000, 0x3800e000, + 0x38010000, 0x38012000, 0x38014000, 0x38016000, 0x38018000, 0x3801a000, + 0x3801c000, 0x3801e000, 0x38020000, 0x38022000, 0x38024000, 0x38026000, + 0x38028000, 0x3802a000, 0x3802c000, 0x3802e000, 0x38030000, 0x38032000, + 0x38034000, 0x38036000, 0x38038000, 0x3803a000, 0x3803c000, 0x3803e000, + 0x38040000, 0x38042000, 0x38044000, 0x38046000, 0x38048000, 0x3804a000, + 0x3804c000, 0x3804e000, 0x38050000, 0x38052000, 0x38054000, 0x38056000, + 0x38058000, 0x3805a000, 0x3805c000, 0x3805e000, 0x38060000, 0x38062000, + 0x38064000, 0x38066000, 0x38068000, 0x3806a000, 0x3806c000, 0x3806e000, + 0x38070000, 0x38072000, 0x38074000, 0x38076000, 0x38078000, 0x3807a000, + 0x3807c000, 0x3807e000, 0x38080000, 0x38082000, 0x38084000, 0x38086000, + 0x38088000, 0x3808a000, 
0x3808c000, 0x3808e000, 0x38090000, 0x38092000, + 0x38094000, 0x38096000, 0x38098000, 0x3809a000, 0x3809c000, 0x3809e000, + 0x380a0000, 0x380a2000, 0x380a4000, 0x380a6000, 0x380a8000, 0x380aa000, + 0x380ac000, 0x380ae000, 0x380b0000, 0x380b2000, 0x380b4000, 0x380b6000, + 0x380b8000, 0x380ba000, 0x380bc000, 0x380be000, 0x380c0000, 0x380c2000, + 0x380c4000, 0x380c6000, 0x380c8000, 0x380ca000, 0x380cc000, 0x380ce000, + 0x380d0000, 0x380d2000, 0x380d4000, 0x380d6000, 0x380d8000, 0x380da000, + 0x380dc000, 0x380de000, 0x380e0000, 0x380e2000, 0x380e4000, 0x380e6000, + 0x380e8000, 0x380ea000, 0x380ec000, 0x380ee000, 0x380f0000, 0x380f2000, + 0x380f4000, 0x380f6000, 0x380f8000, 0x380fa000, 0x380fc000, 0x380fe000, + 0x38100000, 0x38102000, 0x38104000, 0x38106000, 0x38108000, 0x3810a000, + 0x3810c000, 0x3810e000, 0x38110000, 0x38112000, 0x38114000, 0x38116000, + 0x38118000, 0x3811a000, 0x3811c000, 0x3811e000, 0x38120000, 0x38122000, + 0x38124000, 0x38126000, 0x38128000, 0x3812a000, 0x3812c000, 0x3812e000, + 0x38130000, 0x38132000, 0x38134000, 0x38136000, 0x38138000, 0x3813a000, + 0x3813c000, 0x3813e000, 0x38140000, 0x38142000, 0x38144000, 0x38146000, + 0x38148000, 0x3814a000, 0x3814c000, 0x3814e000, 0x38150000, 0x38152000, + 0x38154000, 0x38156000, 0x38158000, 0x3815a000, 0x3815c000, 0x3815e000, + 0x38160000, 0x38162000, 0x38164000, 0x38166000, 0x38168000, 0x3816a000, + 0x3816c000, 0x3816e000, 0x38170000, 0x38172000, 0x38174000, 0x38176000, + 0x38178000, 0x3817a000, 0x3817c000, 0x3817e000, 0x38180000, 0x38182000, + 0x38184000, 0x38186000, 0x38188000, 0x3818a000, 0x3818c000, 0x3818e000, + 0x38190000, 0x38192000, 0x38194000, 0x38196000, 0x38198000, 0x3819a000, + 0x3819c000, 0x3819e000, 0x381a0000, 0x381a2000, 0x381a4000, 0x381a6000, + 0x381a8000, 0x381aa000, 0x381ac000, 0x381ae000, 0x381b0000, 0x381b2000, + 0x381b4000, 0x381b6000, 0x381b8000, 0x381ba000, 0x381bc000, 0x381be000, + 0x381c0000, 0x381c2000, 0x381c4000, 0x381c6000, 0x381c8000, 0x381ca000, + 0x381cc000, 0x381ce000, 0x381d0000, 0x381d2000, 0x381d4000, 0x381d6000, + 0x381d8000, 0x381da000, 0x381dc000, 0x381de000, 0x381e0000, 0x381e2000, + 0x381e4000, 0x381e6000, 0x381e8000, 0x381ea000, 0x381ec000, 0x381ee000, + 0x381f0000, 0x381f2000, 0x381f4000, 0x381f6000, 0x381f8000, 0x381fa000, + 0x381fc000, 0x381fe000, 0x38200000, 0x38202000, 0x38204000, 0x38206000, + 0x38208000, 0x3820a000, 0x3820c000, 0x3820e000, 0x38210000, 0x38212000, + 0x38214000, 0x38216000, 0x38218000, 0x3821a000, 0x3821c000, 0x3821e000, + 0x38220000, 0x38222000, 0x38224000, 0x38226000, 0x38228000, 0x3822a000, + 0x3822c000, 0x3822e000, 0x38230000, 0x38232000, 0x38234000, 0x38236000, + 0x38238000, 0x3823a000, 0x3823c000, 0x3823e000, 0x38240000, 0x38242000, + 0x38244000, 0x38246000, 0x38248000, 0x3824a000, 0x3824c000, 0x3824e000, + 0x38250000, 0x38252000, 0x38254000, 0x38256000, 0x38258000, 0x3825a000, + 0x3825c000, 0x3825e000, 0x38260000, 0x38262000, 0x38264000, 0x38266000, + 0x38268000, 0x3826a000, 0x3826c000, 0x3826e000, 0x38270000, 0x38272000, + 0x38274000, 0x38276000, 0x38278000, 0x3827a000, 0x3827c000, 0x3827e000, + 0x38280000, 0x38282000, 0x38284000, 0x38286000, 0x38288000, 0x3828a000, + 0x3828c000, 0x3828e000, 0x38290000, 0x38292000, 0x38294000, 0x38296000, + 0x38298000, 0x3829a000, 0x3829c000, 0x3829e000, 0x382a0000, 0x382a2000, + 0x382a4000, 0x382a6000, 0x382a8000, 0x382aa000, 0x382ac000, 0x382ae000, + 0x382b0000, 0x382b2000, 0x382b4000, 0x382b6000, 0x382b8000, 0x382ba000, + 0x382bc000, 0x382be000, 0x382c0000, 0x382c2000, 0x382c4000, 0x382c6000, + 0x382c8000, 0x382ca000, 
0x382cc000, 0x382ce000, 0x382d0000, 0x382d2000, + 0x382d4000, 0x382d6000, 0x382d8000, 0x382da000, 0x382dc000, 0x382de000, + 0x382e0000, 0x382e2000, 0x382e4000, 0x382e6000, 0x382e8000, 0x382ea000, + 0x382ec000, 0x382ee000, 0x382f0000, 0x382f2000, 0x382f4000, 0x382f6000, + 0x382f8000, 0x382fa000, 0x382fc000, 0x382fe000, 0x38300000, 0x38302000, + 0x38304000, 0x38306000, 0x38308000, 0x3830a000, 0x3830c000, 0x3830e000, + 0x38310000, 0x38312000, 0x38314000, 0x38316000, 0x38318000, 0x3831a000, + 0x3831c000, 0x3831e000, 0x38320000, 0x38322000, 0x38324000, 0x38326000, + 0x38328000, 0x3832a000, 0x3832c000, 0x3832e000, 0x38330000, 0x38332000, + 0x38334000, 0x38336000, 0x38338000, 0x3833a000, 0x3833c000, 0x3833e000, + 0x38340000, 0x38342000, 0x38344000, 0x38346000, 0x38348000, 0x3834a000, + 0x3834c000, 0x3834e000, 0x38350000, 0x38352000, 0x38354000, 0x38356000, + 0x38358000, 0x3835a000, 0x3835c000, 0x3835e000, 0x38360000, 0x38362000, + 0x38364000, 0x38366000, 0x38368000, 0x3836a000, 0x3836c000, 0x3836e000, + 0x38370000, 0x38372000, 0x38374000, 0x38376000, 0x38378000, 0x3837a000, + 0x3837c000, 0x3837e000, 0x38380000, 0x38382000, 0x38384000, 0x38386000, + 0x38388000, 0x3838a000, 0x3838c000, 0x3838e000, 0x38390000, 0x38392000, + 0x38394000, 0x38396000, 0x38398000, 0x3839a000, 0x3839c000, 0x3839e000, + 0x383a0000, 0x383a2000, 0x383a4000, 0x383a6000, 0x383a8000, 0x383aa000, + 0x383ac000, 0x383ae000, 0x383b0000, 0x383b2000, 0x383b4000, 0x383b6000, + 0x383b8000, 0x383ba000, 0x383bc000, 0x383be000, 0x383c0000, 0x383c2000, + 0x383c4000, 0x383c6000, 0x383c8000, 0x383ca000, 0x383cc000, 0x383ce000, + 0x383d0000, 0x383d2000, 0x383d4000, 0x383d6000, 0x383d8000, 0x383da000, + 0x383dc000, 0x383de000, 0x383e0000, 0x383e2000, 0x383e4000, 0x383e6000, + 0x383e8000, 0x383ea000, 0x383ec000, 0x383ee000, 0x383f0000, 0x383f2000, + 0x383f4000, 0x383f6000, 0x383f8000, 0x383fa000, 0x383fc000, 0x383fe000, + 0x38400000, 0x38402000, 0x38404000, 0x38406000, 0x38408000, 0x3840a000, + 0x3840c000, 0x3840e000, 0x38410000, 0x38412000, 0x38414000, 0x38416000, + 0x38418000, 0x3841a000, 0x3841c000, 0x3841e000, 0x38420000, 0x38422000, + 0x38424000, 0x38426000, 0x38428000, 0x3842a000, 0x3842c000, 0x3842e000, + 0x38430000, 0x38432000, 0x38434000, 0x38436000, 0x38438000, 0x3843a000, + 0x3843c000, 0x3843e000, 0x38440000, 0x38442000, 0x38444000, 0x38446000, + 0x38448000, 0x3844a000, 0x3844c000, 0x3844e000, 0x38450000, 0x38452000, + 0x38454000, 0x38456000, 0x38458000, 0x3845a000, 0x3845c000, 0x3845e000, + 0x38460000, 0x38462000, 0x38464000, 0x38466000, 0x38468000, 0x3846a000, + 0x3846c000, 0x3846e000, 0x38470000, 0x38472000, 0x38474000, 0x38476000, + 0x38478000, 0x3847a000, 0x3847c000, 0x3847e000, 0x38480000, 0x38482000, + 0x38484000, 0x38486000, 0x38488000, 0x3848a000, 0x3848c000, 0x3848e000, + 0x38490000, 0x38492000, 0x38494000, 0x38496000, 0x38498000, 0x3849a000, + 0x3849c000, 0x3849e000, 0x384a0000, 0x384a2000, 0x384a4000, 0x384a6000, + 0x384a8000, 0x384aa000, 0x384ac000, 0x384ae000, 0x384b0000, 0x384b2000, + 0x384b4000, 0x384b6000, 0x384b8000, 0x384ba000, 0x384bc000, 0x384be000, + 0x384c0000, 0x384c2000, 0x384c4000, 0x384c6000, 0x384c8000, 0x384ca000, + 0x384cc000, 0x384ce000, 0x384d0000, 0x384d2000, 0x384d4000, 0x384d6000, + 0x384d8000, 0x384da000, 0x384dc000, 0x384de000, 0x384e0000, 0x384e2000, + 0x384e4000, 0x384e6000, 0x384e8000, 0x384ea000, 0x384ec000, 0x384ee000, + 0x384f0000, 0x384f2000, 0x384f4000, 0x384f6000, 0x384f8000, 0x384fa000, + 0x384fc000, 0x384fe000, 0x38500000, 0x38502000, 0x38504000, 0x38506000, + 0x38508000, 0x3850a000, 
0x3850c000, 0x3850e000, 0x38510000, 0x38512000, + 0x38514000, 0x38516000, 0x38518000, 0x3851a000, 0x3851c000, 0x3851e000, + 0x38520000, 0x38522000, 0x38524000, 0x38526000, 0x38528000, 0x3852a000, + 0x3852c000, 0x3852e000, 0x38530000, 0x38532000, 0x38534000, 0x38536000, + 0x38538000, 0x3853a000, 0x3853c000, 0x3853e000, 0x38540000, 0x38542000, + 0x38544000, 0x38546000, 0x38548000, 0x3854a000, 0x3854c000, 0x3854e000, + 0x38550000, 0x38552000, 0x38554000, 0x38556000, 0x38558000, 0x3855a000, + 0x3855c000, 0x3855e000, 0x38560000, 0x38562000, 0x38564000, 0x38566000, + 0x38568000, 0x3856a000, 0x3856c000, 0x3856e000, 0x38570000, 0x38572000, + 0x38574000, 0x38576000, 0x38578000, 0x3857a000, 0x3857c000, 0x3857e000, + 0x38580000, 0x38582000, 0x38584000, 0x38586000, 0x38588000, 0x3858a000, + 0x3858c000, 0x3858e000, 0x38590000, 0x38592000, 0x38594000, 0x38596000, + 0x38598000, 0x3859a000, 0x3859c000, 0x3859e000, 0x385a0000, 0x385a2000, + 0x385a4000, 0x385a6000, 0x385a8000, 0x385aa000, 0x385ac000, 0x385ae000, + 0x385b0000, 0x385b2000, 0x385b4000, 0x385b6000, 0x385b8000, 0x385ba000, + 0x385bc000, 0x385be000, 0x385c0000, 0x385c2000, 0x385c4000, 0x385c6000, + 0x385c8000, 0x385ca000, 0x385cc000, 0x385ce000, 0x385d0000, 0x385d2000, + 0x385d4000, 0x385d6000, 0x385d8000, 0x385da000, 0x385dc000, 0x385de000, + 0x385e0000, 0x385e2000, 0x385e4000, 0x385e6000, 0x385e8000, 0x385ea000, + 0x385ec000, 0x385ee000, 0x385f0000, 0x385f2000, 0x385f4000, 0x385f6000, + 0x385f8000, 0x385fa000, 0x385fc000, 0x385fe000, 0x38600000, 0x38602000, + 0x38604000, 0x38606000, 0x38608000, 0x3860a000, 0x3860c000, 0x3860e000, + 0x38610000, 0x38612000, 0x38614000, 0x38616000, 0x38618000, 0x3861a000, + 0x3861c000, 0x3861e000, 0x38620000, 0x38622000, 0x38624000, 0x38626000, + 0x38628000, 0x3862a000, 0x3862c000, 0x3862e000, 0x38630000, 0x38632000, + 0x38634000, 0x38636000, 0x38638000, 0x3863a000, 0x3863c000, 0x3863e000, + 0x38640000, 0x38642000, 0x38644000, 0x38646000, 0x38648000, 0x3864a000, + 0x3864c000, 0x3864e000, 0x38650000, 0x38652000, 0x38654000, 0x38656000, + 0x38658000, 0x3865a000, 0x3865c000, 0x3865e000, 0x38660000, 0x38662000, + 0x38664000, 0x38666000, 0x38668000, 0x3866a000, 0x3866c000, 0x3866e000, + 0x38670000, 0x38672000, 0x38674000, 0x38676000, 0x38678000, 0x3867a000, + 0x3867c000, 0x3867e000, 0x38680000, 0x38682000, 0x38684000, 0x38686000, + 0x38688000, 0x3868a000, 0x3868c000, 0x3868e000, 0x38690000, 0x38692000, + 0x38694000, 0x38696000, 0x38698000, 0x3869a000, 0x3869c000, 0x3869e000, + 0x386a0000, 0x386a2000, 0x386a4000, 0x386a6000, 0x386a8000, 0x386aa000, + 0x386ac000, 0x386ae000, 0x386b0000, 0x386b2000, 0x386b4000, 0x386b6000, + 0x386b8000, 0x386ba000, 0x386bc000, 0x386be000, 0x386c0000, 0x386c2000, + 0x386c4000, 0x386c6000, 0x386c8000, 0x386ca000, 0x386cc000, 0x386ce000, + 0x386d0000, 0x386d2000, 0x386d4000, 0x386d6000, 0x386d8000, 0x386da000, + 0x386dc000, 0x386de000, 0x386e0000, 0x386e2000, 0x386e4000, 0x386e6000, + 0x386e8000, 0x386ea000, 0x386ec000, 0x386ee000, 0x386f0000, 0x386f2000, + 0x386f4000, 0x386f6000, 0x386f8000, 0x386fa000, 0x386fc000, 0x386fe000, + 0x38700000, 0x38702000, 0x38704000, 0x38706000, 0x38708000, 0x3870a000, + 0x3870c000, 0x3870e000, 0x38710000, 0x38712000, 0x38714000, 0x38716000, + 0x38718000, 0x3871a000, 0x3871c000, 0x3871e000, 0x38720000, 0x38722000, + 0x38724000, 0x38726000, 0x38728000, 0x3872a000, 0x3872c000, 0x3872e000, + 0x38730000, 0x38732000, 0x38734000, 0x38736000, 0x38738000, 0x3873a000, + 0x3873c000, 0x3873e000, 0x38740000, 0x38742000, 0x38744000, 0x38746000, + 0x38748000, 0x3874a000, 
0x3874c000, 0x3874e000, 0x38750000, 0x38752000, + 0x38754000, 0x38756000, 0x38758000, 0x3875a000, 0x3875c000, 0x3875e000, + 0x38760000, 0x38762000, 0x38764000, 0x38766000, 0x38768000, 0x3876a000, + 0x3876c000, 0x3876e000, 0x38770000, 0x38772000, 0x38774000, 0x38776000, + 0x38778000, 0x3877a000, 0x3877c000, 0x3877e000, 0x38780000, 0x38782000, + 0x38784000, 0x38786000, 0x38788000, 0x3878a000, 0x3878c000, 0x3878e000, + 0x38790000, 0x38792000, 0x38794000, 0x38796000, 0x38798000, 0x3879a000, + 0x3879c000, 0x3879e000, 0x387a0000, 0x387a2000, 0x387a4000, 0x387a6000, + 0x387a8000, 0x387aa000, 0x387ac000, 0x387ae000, 0x387b0000, 0x387b2000, + 0x387b4000, 0x387b6000, 0x387b8000, 0x387ba000, 0x387bc000, 0x387be000, + 0x387c0000, 0x387c2000, 0x387c4000, 0x387c6000, 0x387c8000, 0x387ca000, + 0x387cc000, 0x387ce000, 0x387d0000, 0x387d2000, 0x387d4000, 0x387d6000, + 0x387d8000, 0x387da000, 0x387dc000, 0x387de000, 0x387e0000, 0x387e2000, + 0x387e4000, 0x387e6000, 0x387e8000, 0x387ea000, 0x387ec000, 0x387ee000, + 0x387f0000, 0x387f2000, 0x387f4000, 0x387f6000, 0x387f8000, 0x387fa000, + 0x387fc000, 0x387fe000}; + +static const uint16_t offsettable[64] = { + 0x0000, 0x0400, 0x0400, 0x0400, 0x0400, 0x0400, 0x0400, 0x0400, + 0x0400, 0x0400, 0x0400, 0x0400, 0x0400, 0x0400, 0x0400, 0x0400, + 0x0400, 0x0400, 0x0400, 0x0400, 0x0400, 0x0400, 0x0400, 0x0400, + 0x0400, 0x0400, 0x0400, 0x0400, 0x0400, 0x0400, 0x0400, 0x0400, + 0x0000, 0x0400, 0x0400, 0x0400, 0x0400, 0x0400, 0x0400, 0x0400, + 0x0400, 0x0400, 0x0400, 0x0400, 0x0400, 0x0400, 0x0400, 0x0400, + 0x0400, 0x0400, 0x0400, 0x0400, 0x0400, 0x0400, 0x0400, 0x0400, + 0x0400, 0x0400, 0x0400, 0x0400, 0x0400, 0x0400, 0x0400, 0x0400}; + +static const uint32_t exponenttable[64] = { + 0x00000000, 0x00800000, 0x01000000, 0x01800000, 0x02000000, 0x02800000, + 0x03000000, 0x03800000, 0x04000000, 0x04800000, 0x05000000, 0x05800000, + 0x06000000, 0x06800000, 0x07000000, 0x07800000, 0x08000000, 0x08800000, + 0x09000000, 0x09800000, 0x0a000000, 0x0a800000, 0x0b000000, 0x0b800000, + 0x0c000000, 0x0c800000, 0x0d000000, 0x0d800000, 0x0e000000, 0x0e800000, + 0x0f000000, 0x47800000, 0x80000000, 0x80800000, 0x81000000, 0x81800000, + 0x82000000, 0x82800000, 0x83000000, 0x83800000, 0x84000000, 0x84800000, + 0x85000000, 0x85800000, 0x86000000, 0x86800000, 0x87000000, 0x87800000, + 0x88000000, 0x88800000, 0x89000000, 0x89800000, 0x8a000000, 0x8a800000, + 0x8b000000, 0x8b800000, 0x8c000000, 0x8c800000, 0x8d000000, 0x8d800000, + 0x8e000000, 0x8e800000, 0x8f000000, 0xc7800000}; + +static const uint16_t basetable[512] = { + 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, + 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, + 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, + 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, + 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, + 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, + 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, + 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, + 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, + 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, + 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, + 0x0000, 0x0000, 0x0000, 0x0000, 0x0001, 0x0002, 0x0004, 0x0008, 0x0010, + 0x0020, 0x0040, 0x0080, 0x0100, 0x0200, 0x0400, 0x0800, 0x0c00, 0x1000, + 0x1400, 0x1800, 0x1c00, 
0x2000, 0x2400, 0x2800, 0x2c00, 0x3000, 0x3400, + 0x3800, 0x3c00, 0x4000, 0x4400, 0x4800, 0x4c00, 0x5000, 0x5400, 0x5800, + 0x5c00, 0x6000, 0x6400, 0x6800, 0x6c00, 0x7000, 0x7400, 0x7800, 0x7c00, + 0x7c00, 0x7c00, 0x7c00, 0x7c00, 0x7c00, 0x7c00, 0x7c00, 0x7c00, 0x7c00, + 0x7c00, 0x7c00, 0x7c00, 0x7c00, 0x7c00, 0x7c00, 0x7c00, 0x7c00, 0x7c00, + 0x7c00, 0x7c00, 0x7c00, 0x7c00, 0x7c00, 0x7c00, 0x7c00, 0x7c00, 0x7c00, + 0x7c00, 0x7c00, 0x7c00, 0x7c00, 0x7c00, 0x7c00, 0x7c00, 0x7c00, 0x7c00, + 0x7c00, 0x7c00, 0x7c00, 0x7c00, 0x7c00, 0x7c00, 0x7c00, 0x7c00, 0x7c00, + 0x7c00, 0x7c00, 0x7c00, 0x7c00, 0x7c00, 0x7c00, 0x7c00, 0x7c00, 0x7c00, + 0x7c00, 0x7c00, 0x7c00, 0x7c00, 0x7c00, 0x7c00, 0x7c00, 0x7c00, 0x7c00, + 0x7c00, 0x7c00, 0x7c00, 0x7c00, 0x7c00, 0x7c00, 0x7c00, 0x7c00, 0x7c00, + 0x7c00, 0x7c00, 0x7c00, 0x7c00, 0x7c00, 0x7c00, 0x7c00, 0x7c00, 0x7c00, + 0x7c00, 0x7c00, 0x7c00, 0x7c00, 0x7c00, 0x7c00, 0x7c00, 0x7c00, 0x7c00, + 0x7c00, 0x7c00, 0x7c00, 0x7c00, 0x7c00, 0x7c00, 0x7c00, 0x7c00, 0x7c00, + 0x7c00, 0x7c00, 0x7c00, 0x7c00, 0x7c00, 0x7c00, 0x7c00, 0x7c00, 0x7c00, + 0x7c00, 0x7c00, 0x7c00, 0x7c00, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, + 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, + 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, + 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, + 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, + 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, + 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, + 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, + 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, + 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, + 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, + 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8001, + 0x8002, 0x8004, 0x8008, 0x8010, 0x8020, 0x8040, 0x8080, 0x8100, 0x8200, + 0x8400, 0x8800, 0x8c00, 0x9000, 0x9400, 0x9800, 0x9c00, 0xa000, 0xa400, + 0xa800, 0xac00, 0xb000, 0xb400, 0xb800, 0xbc00, 0xc000, 0xc400, 0xc800, + 0xcc00, 0xd000, 0xd400, 0xd800, 0xdc00, 0xe000, 0xe400, 0xe800, 0xec00, + 0xf000, 0xf400, 0xf800, 0xfc00, 0xfc00, 0xfc00, 0xfc00, 0xfc00, 0xfc00, + 0xfc00, 0xfc00, 0xfc00, 0xfc00, 0xfc00, 0xfc00, 0xfc00, 0xfc00, 0xfc00, + 0xfc00, 0xfc00, 0xfc00, 0xfc00, 0xfc00, 0xfc00, 0xfc00, 0xfc00, 0xfc00, + 0xfc00, 0xfc00, 0xfc00, 0xfc00, 0xfc00, 0xfc00, 0xfc00, 0xfc00, 0xfc00, + 0xfc00, 0xfc00, 0xfc00, 0xfc00, 0xfc00, 0xfc00, 0xfc00, 0xfc00, 0xfc00, + 0xfc00, 0xfc00, 0xfc00, 0xfc00, 0xfc00, 0xfc00, 0xfc00, 0xfc00, 0xfc00, + 0xfc00, 0xfc00, 0xfc00, 0xfc00, 0xfc00, 0xfc00, 0xfc00, 0xfc00, 0xfc00, + 0xfc00, 0xfc00, 0xfc00, 0xfc00, 0xfc00, 0xfc00, 0xfc00, 0xfc00, 0xfc00, + 0xfc00, 0xfc00, 0xfc00, 0xfc00, 0xfc00, 0xfc00, 0xfc00, 0xfc00, 0xfc00, + 0xfc00, 0xfc00, 0xfc00, 0xfc00, 0xfc00, 0xfc00, 0xfc00, 0xfc00, 0xfc00, + 0xfc00, 0xfc00, 0xfc00, 0xfc00, 0xfc00, 0xfc00, 0xfc00, 0xfc00, 0xfc00, + 0xfc00, 0xfc00, 0xfc00, 0xfc00, 0xfc00, 0xfc00, 0xfc00, 0xfc00, 0xfc00, + 0xfc00, 0xfc00, 0xfc00, 0xfc00, 0xfc00, 0xfc00, 0xfc00, 0xfc00}; + +static const uint8_t shifttable[512] = { + 0x18, 0x18, 0x18, 0x18, 0x18, 0x18, 0x18, 0x18, 0x18, 0x18, 0x18, 0x18, + 0x18, 0x18, 0x18, 0x18, 0x18, 0x18, 0x18, 0x18, 0x18, 0x18, 0x18, 0x18, + 0x18, 0x18, 0x18, 0x18, 0x18, 0x18, 0x18, 0x18, 0x18, 0x18, 0x18, 0x18, + 0x18, 0x18, 0x18, 0x18, 0x18, 0x18, 0x18, 0x18, 0x18, 0x18, 
0x18, 0x18, + 0x18, 0x18, 0x18, 0x18, 0x18, 0x18, 0x18, 0x18, 0x18, 0x18, 0x18, 0x18, + 0x18, 0x18, 0x18, 0x18, 0x18, 0x18, 0x18, 0x18, 0x18, 0x18, 0x18, 0x18, + 0x18, 0x18, 0x18, 0x18, 0x18, 0x18, 0x18, 0x18, 0x18, 0x18, 0x18, 0x18, + 0x18, 0x18, 0x18, 0x18, 0x18, 0x18, 0x18, 0x18, 0x18, 0x18, 0x18, 0x18, + 0x18, 0x18, 0x18, 0x18, 0x18, 0x18, 0x18, 0x17, 0x16, 0x15, 0x14, 0x13, + 0x12, 0x11, 0x10, 0x0f, 0x0e, 0x0d, 0x0d, 0x0d, 0x0d, 0x0d, 0x0d, 0x0d, + 0x0d, 0x0d, 0x0d, 0x0d, 0x0d, 0x0d, 0x0d, 0x0d, 0x0d, 0x0d, 0x0d, 0x0d, + 0x0d, 0x0d, 0x0d, 0x0d, 0x0d, 0x0d, 0x0d, 0x0d, 0x0d, 0x0d, 0x0d, 0x18, + 0x18, 0x18, 0x18, 0x18, 0x18, 0x18, 0x18, 0x18, 0x18, 0x18, 0x18, 0x18, + 0x18, 0x18, 0x18, 0x18, 0x18, 0x18, 0x18, 0x18, 0x18, 0x18, 0x18, 0x18, + 0x18, 0x18, 0x18, 0x18, 0x18, 0x18, 0x18, 0x18, 0x18, 0x18, 0x18, 0x18, + 0x18, 0x18, 0x18, 0x18, 0x18, 0x18, 0x18, 0x18, 0x18, 0x18, 0x18, 0x18, + 0x18, 0x18, 0x18, 0x18, 0x18, 0x18, 0x18, 0x18, 0x18, 0x18, 0x18, 0x18, + 0x18, 0x18, 0x18, 0x18, 0x18, 0x18, 0x18, 0x18, 0x18, 0x18, 0x18, 0x18, + 0x18, 0x18, 0x18, 0x18, 0x18, 0x18, 0x18, 0x18, 0x18, 0x18, 0x18, 0x18, + 0x18, 0x18, 0x18, 0x18, 0x18, 0x18, 0x18, 0x18, 0x18, 0x18, 0x18, 0x18, + 0x18, 0x18, 0x18, 0x18, 0x18, 0x18, 0x18, 0x18, 0x18, 0x18, 0x18, 0x18, + 0x18, 0x18, 0x18, 0x0d, 0x18, 0x18, 0x18, 0x18, 0x18, 0x18, 0x18, 0x18, + 0x18, 0x18, 0x18, 0x18, 0x18, 0x18, 0x18, 0x18, 0x18, 0x18, 0x18, 0x18, + 0x18, 0x18, 0x18, 0x18, 0x18, 0x18, 0x18, 0x18, 0x18, 0x18, 0x18, 0x18, + 0x18, 0x18, 0x18, 0x18, 0x18, 0x18, 0x18, 0x18, 0x18, 0x18, 0x18, 0x18, + 0x18, 0x18, 0x18, 0x18, 0x18, 0x18, 0x18, 0x18, 0x18, 0x18, 0x18, 0x18, + 0x18, 0x18, 0x18, 0x18, 0x18, 0x18, 0x18, 0x18, 0x18, 0x18, 0x18, 0x18, + 0x18, 0x18, 0x18, 0x18, 0x18, 0x18, 0x18, 0x18, 0x18, 0x18, 0x18, 0x18, + 0x18, 0x18, 0x18, 0x18, 0x18, 0x18, 0x18, 0x18, 0x18, 0x18, 0x18, 0x18, + 0x18, 0x18, 0x18, 0x18, 0x18, 0x18, 0x18, 0x18, 0x18, 0x18, 0x18, 0x17, + 0x16, 0x15, 0x14, 0x13, 0x12, 0x11, 0x10, 0x0f, 0x0e, 0x0d, 0x0d, 0x0d, + 0x0d, 0x0d, 0x0d, 0x0d, 0x0d, 0x0d, 0x0d, 0x0d, 0x0d, 0x0d, 0x0d, 0x0d, + 0x0d, 0x0d, 0x0d, 0x0d, 0x0d, 0x0d, 0x0d, 0x0d, 0x0d, 0x0d, 0x0d, 0x0d, + 0x0d, 0x0d, 0x0d, 0x18, 0x18, 0x18, 0x18, 0x18, 0x18, 0x18, 0x18, 0x18, + 0x18, 0x18, 0x18, 0x18, 0x18, 0x18, 0x18, 0x18, 0x18, 0x18, 0x18, 0x18, + 0x18, 0x18, 0x18, 0x18, 0x18, 0x18, 0x18, 0x18, 0x18, 0x18, 0x18, 0x18, + 0x18, 0x18, 0x18, 0x18, 0x18, 0x18, 0x18, 0x18, 0x18, 0x18, 0x18, 0x18, + 0x18, 0x18, 0x18, 0x18, 0x18, 0x18, 0x18, 0x18, 0x18, 0x18, 0x18, 0x18, + 0x18, 0x18, 0x18, 0x18, 0x18, 0x18, 0x18, 0x18, 0x18, 0x18, 0x18, 0x18, + 0x18, 0x18, 0x18, 0x18, 0x18, 0x18, 0x18, 0x18, 0x18, 0x18, 0x18, 0x18, + 0x18, 0x18, 0x18, 0x18, 0x18, 0x18, 0x18, 0x18, 0x18, 0x18, 0x18, 0x18, + 0x18, 0x18, 0x18, 0x18, 0x18, 0x18, 0x18, 0x18, 0x18, 0x18, 0x18, 0x18, + 0x18, 0x18, 0x18, 0x18, 0x18, 0x18, 0x18, 0x0d}; + +inline float16 float_to_half(float f) { + uint32_t v = *reinterpret_cast(&f); + return basetable[(v >> 23) & 0x1ff] + + ((v & 0x007fffff) >> shifttable[(v >> 23) & 0x1ff]); +} + +inline float half_to_float(float16 h) { + uint32_t v = mantissatable[offsettable[h >> 10] + (h & 0x3ff)] + + exponenttable[h >> 10]; + return *reinterpret_cast(&v); +} + +} // namespace zynqmp +} // namespace paddle diff --git a/lite/fpga/KD/fpga_cv.cpp b/lite/fpga/KD/fpga_cv.cpp new file mode 100644 index 00000000000..92ddccb0ada --- /dev/null +++ b/lite/fpga/KD/fpga_cv.cpp @@ -0,0 +1,80 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. 
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "lite/fpga/KD/fpga_cv.hpp"
+
+using paddle::zynqmp::float16;
+
+void fpga_resize(float* input,
+                 int input_width,
+                 int input_height,
+                 int input_channel,
+                 uint8_t* output,
+                 int output_width,
+                 int output_height) {
+  paddle::zynqmp::InplaceArgs inplace_args = {
+      .relu_enable = 0, .power_enable = 0,
+  };
+  paddle::zynqmp::config_inplace(inplace_args);
+
+  paddle::zynqmp::ImageInputArgs input_args = {nullptr};
+  input_args.address = nullptr;
+  input_args.scale_address = nullptr;
+
+  float16* input_image_address =
+      reinterpret_cast<float16*>(paddle::zynqmp::fpga_malloc(
+          input_width * input_height * input_channel * sizeof(float16)));
+
+  for (int i = 0; i < input_width * input_height * input_channel; i++) {
+    input_image_address[i] = float16(1.0 * input[i]);
+  }
+
+  paddle::zynqmp::ResizeArgs resize_args = {0};
+
+  resize_args.input_width = input_width;
+  resize_args.input_height = input_height;
+  resize_args.image_channel = input_channel;
+  resize_args.output_width = output_width;
+  resize_args.output_height = output_height;
+  float height_ratio = static_cast<float>(input_height) /
+                       static_cast<float>(resize_args.output_height);
+  float width_ratio = static_cast<float>(input_width) /
+                      static_cast<float>(resize_args.output_width);
+  resize_args.height_ratio = *reinterpret_cast<uint32_t*>(&height_ratio);
+  resize_args.width_ratio = *reinterpret_cast<uint32_t*>(&width_ratio);
+
+  int output_size =
+      resize_args.output_width * resize_args.output_height * input_channel;
+  float16* fpga_output = reinterpret_cast<float16*>(
+      paddle::zynqmp::fpga_malloc(output_size * sizeof(float16)));
+  resize_args.input_image_address = input_image_address;
+  resize_args.output_image_address = fpga_output;
+
+  memset(fpga_output, 0, output_size * sizeof(float16));
+  paddle::zynqmp::fpga_flush(
+      input_image_address,
+      input_width * input_height * input_channel * sizeof(float16));
+  paddle::zynqmp::fpga_flush(resize_args.output_image_address,
+                             output_size * sizeof(float16));
+  int ret = paddle::zynqmp::compute_fpga_resize(resize_args);
+  if (ret == 0) {
+    paddle::zynqmp::fpga_invalidate(resize_args.output_image_address,
+                                    output_size * sizeof(float16));
+  }
+
+  for (int i = 0; i < output_size; i++) {
+    output[i] = fpga_output[i];
+  }
+
+  // Release the temporary FPGA buffers allocated above.
+  paddle::zynqmp::fpga_free(input_image_address);
+  paddle::zynqmp::fpga_free(fpga_output);
+}
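The float16 conversions above go through the table-driven float_to_half/half_to_float pair defined earlier: float_to_half indexes basetable and shifttable by the nine exponent-plus-sign bits of the fp32 word, and half_to_float reverses it via mantissatable, offsettable, and exponenttable. For reference, a self-contained scalar sketch of the same truncating fp32-to-fp16 conversion (illustrative only, not part of the patch):

#include <cstdint>
#include <cstdio>
#include <cstring>

// Scalar equivalent of the table scheme: the fp32 exponent selects a base
// fp16 bit pattern (basetable) and a mantissa shift (shifttable). Truncating,
// like the tables; NaN payloads collapse to infinity in this sketch.
static uint16_t fp32_to_fp16_scalar(float f) {
  uint32_t v;
  std::memcpy(&v, &f, sizeof(v));      // bit-exact view of the float
  uint16_t sign = (v >> 16) & 0x8000;  // sign moves to bit 15
  int exp = static_cast<int>((v >> 23) & 0xff) - 127;  // unbiased exponent
  uint32_t man = v & 0x007fffff;
  if (exp > 15) return sign | 0x7c00;  // overflow -> infinity
  if (exp >= -14) return sign | ((exp + 15) << 10) | (man >> 13);  // normal
  if (exp >= -24)
    return sign | ((man | 0x00800000) >> (-exp - 1));  // fp16 subnormal
  return sign;                                         // underflow -> +/-0
}

int main() {
  std::printf("0x%04x\n", fp32_to_fp16_scalar(1.0f));  // prints 0x3c00
}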
diff --git a/lite/fpga/KD/fpga_cv.hpp b/lite/fpga/KD/fpga_cv.hpp
new file mode 100644
index 00000000000..0f68ab239d0
--- /dev/null
+++ b/lite/fpga/KD/fpga_cv.hpp
@@ -0,0 +1,28 @@
+/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+
+#include <stdint.h>
+#include "lite/fpga/KD/float16.hpp"
+#include "lite/fpga/KD/llapi/zynqmp_api.h"
+#include "lite/fpga/KD/pe.hpp"
+
+void fpga_resize(float* input,
+                 int input_width,
+                 int input_height,
+                 int input_channel,
+                 uint8_t* output,
+                 int output_width,
+                 int output_height);
diff --git a/lite/fpga/KD/layout.hpp b/lite/fpga/KD/layout.hpp
new file mode 100644
index 00000000000..81ec746e39c
--- /dev/null
+++ b/lite/fpga/KD/layout.hpp
@@ -0,0 +1,99 @@
+/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+
+#include <vector>
+
+#include "lite/fpga/KD/alignment.h"
+
+namespace paddle {
+namespace zynqmp {
+
+enum LayoutType {
+  N,
+  NC,
+  NCHW,
+  NHWC,
+  NHW,
+};
+
+class Layout {
+ public:
+  virtual int numIndex() = 0;
+  virtual int channelIndex() { return -1; }
+  virtual int heightIndex() { return -1; }
+  virtual int widthIndex() { return -1; }
+  virtual int alignedElementCount(const std::vector<int>& dims) = 0;
+  virtual int elementCount(const std::vector<int>& dims) = 0;
+};
+
+struct NCHW : Layout {
+  int numIndex() { return 0; }
+  int channelIndex() { return 1; }
+  int heightIndex() { return 2; }
+  int widthIndex() { return 3; }
+  int alignedElementCount(const std::vector<int>& dims) {
+    return dims[0] * dims[2] * align_image(dims[1] * dims[3]);
+  }
+  virtual int elementCount(const std::vector<int>& dims) {
+    return dims[0] * dims[1] * dims[2] * dims[3];
+  }
+};
+
+struct NHWC : Layout {
+  int numIndex() { return 0; }
+  int heightIndex() { return 1; }
+  int widthIndex() { return 2; }
+  int channelIndex() { return 3; }
+  int alignedElementCount(const std::vector<int>& dims) {
+    return dims[0] * dims[1] * align_image(dims[2] * dims[3]);
+  }
+  virtual int elementCount(const std::vector<int>& dims) {
+    return dims[0] * dims[1] * dims[2] * dims[3];
+  }
+};
+
+struct NC : Layout {
+  int numIndex() { return 0; }
+  int channelIndex() { return 1; }
+  int alignedElementCount(const std::vector<int>& dims) {
+    return dims[0] * dims[1];
+  }
+  virtual int elementCount(const std::vector<int>& dims) {
+    return dims[0] * dims[1];
+  }
+};
+
+struct N : Layout {
+  int numIndex() { return 0; }
+  int alignedElementCount(const std::vector<int>& dims) { return dims[0]; }
+  virtual int elementCount(const std::vector<int>& dims) { return dims[0]; }
+};
+
+struct NHW : Layout {
+  int numIndex() { return 0; }
+  int heightIndex() { return 1; }
+  int widthIndex() { return 2; }
+  int alignedElementCount(const std::vector<int>& dims) {
+    // TODO(chonwhite) align it;
+    return dims[0] * dims[1] * dims[2];
+  }
+  virtual int elementCount(const std::vector<int>& dims) {
+    return dims[0] * dims[1] * dims[2];
+  }
+};
+
+}  // namespace zynqmp
+}  // namespace paddle
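The distinction between elementCount and alignedElementCount is the whole point of these classes: the FPGA pads each row of an image so that the innermost W*C run is a multiple of the hardware alignment (IMAGE_ALIGNMENT, 16 elements, per zynqmp_api.h). A small host-side illustration, assuming align_image() from alignment.h rounds up to 16:

#include <cstdio>
#include <vector>

// Hypothetical stand-in for align_image() from lite/fpga/KD/alignment.h,
// assumed to round a W*C run up to IMAGE_ALIGNMENT (= 16) elements.
inline int align_image(int wc) { return (wc + 15) / 16 * 16; }

int main() {
  std::vector<int> dims = {1, 7, 7, 3};  // NHWC: N=1, H=7, W=7, C=3
  int logical = dims[0] * dims[1] * dims[2] * dims[3];  // 147 elements
  // Each of the 7 rows pads W*C = 21 up to 32, so storage is 7 * 32 = 224.
  int aligned = dims[0] * dims[1] * align_image(dims[2] * dims[3]);
  std::printf("elementCount=%d alignedElementCount=%d\n", logical, aligned);
}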
diff --git a/lite/fpga/KD/llapi/bias_scale.cpp b/lite/fpga/KD/llapi/bias_scale.cpp
new file mode 100644
index 00000000000..4ea6897ced4
--- /dev/null
+++ b/lite/fpga/KD/llapi/bias_scale.cpp
@@ -0,0 +1,102 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include <string.h>
+
+#include "lite/fpga/KD/llapi/bias_scale.h"
+#include "lite/fpga/KD/llapi/zynqmp_api.h"
+
+namespace paddle {
+namespace zynqmp {
+namespace bias_scale {
+
+void align_element(float **data_in, int num_per_div_before_alignment, int num) {
+  int copynum = 0;
+  float *ptr_unaligned = *data_in;
+  int div_num =
+      (num + num_per_div_before_alignment - 1) / num_per_div_before_alignment;
+  int num_per_div_after_alignment =
+      align_to_x(num_per_div_before_alignment, BS_NUM_ALIGNMENT);
+  int num_element =
+      2 * div_num * num_per_div_after_alignment;  // including bias & scale
+  float *ptr_aligned =
+      (float *)fpga_malloc(num_element * sizeof(float));  // NOLINT
+
+  memset(ptr_aligned, 0, num_element * sizeof(float));
+  for (int i = 0; i < div_num; i++) {
+    if (i == div_num - 1) {
+      copynum = (num_per_div_after_alignment * div_num > num)
+                    ? (num % num_per_div_after_alignment)
+                    : (num_per_div_before_alignment);
+    } else {
+      copynum = num_per_div_before_alignment;
+    }
+
+    memcpy(ptr_aligned + i * num_per_div_after_alignment,
+           ptr_unaligned + num_per_div_before_alignment * i,
+           copynum * sizeof(float));
+    memcpy(ptr_aligned + (div_num + i) * num_per_div_after_alignment,
+           ptr_unaligned + num_per_div_before_alignment * i + num,
+           copynum * sizeof(float));
+  }
+  fpga_free(ptr_unaligned);
+  *data_in = ptr_aligned;
+}
+
+void interleave(float **data_in, int num_after_alignment) {
+  float *ptr_uninterleaved = *data_in;
+  float *ptr_interleaved =
+      (float *)fpga_malloc(2 * num_after_alignment * sizeof(float));  // NOLINT
+  int num = num_after_alignment / 4;
+  for (int i = 0; i < num; i++) {
+    memcpy(
+        ptr_interleaved + 8 * i, ptr_uninterleaved + 4 * i, 4 * sizeof(float));
+    memcpy(ptr_interleaved + 8 * i + 4,
+           ptr_uninterleaved + num_after_alignment + 4 * i,
+           4 * sizeof(float));
+  }
+
+  fpga_free(ptr_uninterleaved);
+  *data_in = ptr_interleaved;
+}
+
+void format_bias_scale_array(float **bias_scale_array,
+                             int element_num_per_division,
+                             int num) {
+  align_element(bias_scale_array, element_num_per_division, num);
+  int div_num = (num + element_num_per_division - 1) / element_num_per_division;
+  int element_num_after_division =
+      align_to_x(element_num_per_division, BS_NUM_ALIGNMENT);
+  interleave(bias_scale_array, div_num * element_num_after_division);
+  fpga_flush(*bias_scale_array, 2 * element_num_after_division * sizeof(float));
+}
+
+void format_bias_array(float **bias_array, int num) {
+  float *ptr_unaligned = *bias_array;
+  int num_before_align = num;
+  int num_after_align = align_to_x(num_before_align, BIAS_NUM_ALIGNMENT);
+  int16_t *ptr_aligned =
+      (int16_t *)fpga_malloc(num_after_align * sizeof(int16_t));  // NOLINT
+
+  memset(ptr_aligned, 0, num_after_align * sizeof(int16_t));
+  for (int i = 0; i < num_before_align; i++) {
+    ptr_aligned[i] = fp32_2_fp16(ptr_unaligned[i]);
+  }
+  *bias_array = (float *)ptr_aligned;  // NOLINT
+  fpga_free(ptr_unaligned);
+}
+
+}  // namespace bias_scale
+}  // namespace zynqmp
+}  // namespace paddle
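After align_element(), the buffer holds all biases followed by all scales, each division padded to BS_NUM_ALIGNMENT (8); interleave() then alternates them in blocks of four floats so the hardware can stream bias/scale pairs together. A standalone sketch of the resulting layout (host-side illustration only, plain arrays in place of fpga_malloc):

#include <cstdio>
#include <cstring>

int main() {
  // Input layout: [bias[0..8), scale[0..8)], already 8-aligned.
  const int num_after_alignment = 8;
  float in[2 * num_after_alignment];
  for (int i = 0; i < num_after_alignment; i++) {
    in[i] = 1.0f + i;                        // biases: 1..8
    in[num_after_alignment + i] = 0.1f * i;  // scales: 0.0, 0.1, ...
  }
  float out[2 * num_after_alignment];
  // Same copy pattern as bias_scale::interleave(): 4 biases, then 4 scales.
  for (int i = 0; i < num_after_alignment / 4; i++) {
    std::memcpy(out + 8 * i, in + 4 * i, 4 * sizeof(float));
    std::memcpy(out + 8 * i + 4, in + num_after_alignment + 4 * i,
                4 * sizeof(float));
  }
  // out = {1,2,3,4, 0.0,0.1,0.2,0.3, 5,6,7,8, 0.4,0.5,0.6,0.7}
  for (float v : out) std::printf("%g ", v);
  std::printf("\n");
}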
diff --git a/lite/fpga/KD/llapi/bias_scale.h b/lite/fpga/KD/llapi/bias_scale.h
new file mode 100644
index 00000000000..83f30df18fc
--- /dev/null
+++ b/lite/fpga/KD/llapi/bias_scale.h
@@ -0,0 +1,30 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+
+namespace paddle {
+namespace zynqmp {
+namespace bias_scale {
+
+void align_element(float** data_in, int num_per_div_before_alignment, int num);
+void interleave(float** data_in, int num_after_alignment);
+void format_bias_scale_array(float** bias_scale_array,
+                             int element_num_per_division,
+                             int num);
+void format_bias_array(float** bias_array, int num);
+
+}  // namespace bias_scale
+}  // namespace zynqmp
+}  // namespace paddle
diff --git a/lite/fpga/KD/llapi/config.h b/lite/fpga/KD/llapi/config.h
new file mode 100755
index 00000000000..acf8c8adf4f
--- /dev/null
+++ b/lite/fpga/KD/llapi/config.h
@@ -0,0 +1,19 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+
+#define PADDLE_LITE_ZU5
+#define FPGA_PRINT_MODE
+#define PADDLE_LITE_PROFILE
diff --git a/lite/fpga/KD/llapi/filter.cpp b/lite/fpga/KD/llapi/filter.cpp
new file mode 100644
index 00000000000..2bc1e28a4ba
--- /dev/null
+++ b/lite/fpga/KD/llapi/filter.cpp
@@ -0,0 +1,317 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "lite/fpga/KD/llapi/filter.h"
+#include <algorithm>
+#include <memory.h>
+#include "lite/fpga/KD/float16.hpp"
+#include "lite/fpga/KD/llapi/zynqmp_api.h"
+
+namespace paddle {
+namespace zynqmp {
+namespace filter {
+
+static int FILTER_SIZE = 2048;
+
+void set_filter_capacity(uint32_t cap) { FILTER_SIZE = cap; }
+
+int calc_division_capacity(int chw) {
+  int n = FILTER_SIZE / ((chw + 15) / 16) * 32;
+  return n < FILTER_SIZE ?
n : FILTER_SIZE; +} + +int calc_split_num(int num, int division_capacity) { + return (num + division_capacity - 1) / division_capacity; +} + +int calc_division_number(int num, int group_num, int division_capacity) { + int split_num = calc_split_num(num, division_capacity); + return group_num * split_num; +} + +int calc_num_per_div(int num, int group_num, int division_capacity) { + if (group_num == 1) { + if (num > division_capacity) { + return division_capacity; + } else { + return num; + } + } else { + return (num + group_num - 1) / group_num; + } +} + +void convert_to_hwc( + char **data_in, int num, int channel, int height, int width) { + char *tmp = *data_in; + int chw = channel * height * width; + char *data_tmp = (char *)fpga_malloc(chw * num * sizeof(char)); // NOLINT + for (int n = 0; n < num; n++) { + int64_t amount_per_row = width * channel; + for (int c = 0; c < channel; c++) { + for (int h = 0; h < height; h++) { + int64_t offset_height = h * amount_per_row; + for (int w = 0; w < width; w++) { + *(data_tmp + n * chw + offset_height + w * channel + c) = + *((*data_in)++); + } + } + } + } + *data_in = data_tmp; + fpga_free(tmp); +} + +float find_max(float *data_in, int data_size) { + float max = 0.0; + for (int i = 0; i < data_size; ++i) { + float value = data_in[i]; + float abs = value > 0 ? value : -value; + max = std::max(max, abs); + } + return max; +} + +signed char float_to_int8(float fdata) { + if (fdata < 0.0) { + fdata -= 0.5; + } else { + fdata += 0.5; + } + return (signed char)fdata; +} + +void quantize(float **data_in, int data_size, float max) { + float *tmp = *data_in; + float fix_range = 127; + float scale = fix_range / max; + + signed char *tmp_data = (signed char *)fpga_malloc(data_size * sizeof(char)); + for (int i = 0; i < data_size; i++) { + tmp_data[i] = float_to_int8( + (*data_in)[i] * scale); // (signed char)((*data_in)[i] * scale); + } + *data_in = (float *)tmp_data; // NOLINT + fpga_free(tmp); +} + +void align_element(char **data_in, int num, int chw) { + int j = 0; + int align_chw = align_to_x(chw, FILTER_ELEMENT_ALIGNMENT); + if (align_chw != chw) { + char *tmp = *data_in; + char *data_tmp = + (char *)fpga_malloc(num * align_chw * sizeof(char)); // NOLINT + + memset(data_tmp, 0, num * align_chw); + for (j = 0; j < num; j++) { + memcpy(data_tmp + j * align_chw, (*data_in) + j * chw, chw); + } + *data_in = data_tmp; + fpga_free(tmp); + } +} + +void align_num(char **data_in, + int num_per_div_before_alignment, + int num, + int chw) { + int i = 0; + int align_chw = align_to_x(chw, FILTER_ELEMENT_ALIGNMENT); + int num_per_div_after_alignment = + align_to_x(num_per_div_before_alignment, FILTER_NUM_ALIGNMENT); + + char *tmp = *data_in; + int div_num = + (num + num_per_div_before_alignment - 1) / num_per_div_before_alignment; + int num_element = div_num * num_per_div_after_alignment * align_chw; + char *data_tmp = (char *)fpga_malloc(num_element * sizeof(char)); // NOLINT + + memset(data_tmp, 0, num_element * sizeof(char)); + + for (i = 0; i < div_num - 1; i++) { + memcpy(data_tmp + num_per_div_after_alignment * align_chw * i, + *data_in + num_per_div_before_alignment * align_chw * i, + num_per_div_before_alignment * align_chw); + } + + memcpy(data_tmp + num_per_div_after_alignment * align_chw * i, + *data_in + num_per_div_before_alignment * align_chw * i, + (num - (div_num - 1) * num_per_div_before_alignment) * align_chw); + + *data_in = data_tmp; + fpga_free(tmp); +} + +void reorder(char **data_in, int num_after_alignment, int chw) { + int index = 0; + int 
new_index = 0; + + int chw_align = align_to_x(chw, FILTER_ELEMENT_ALIGNMENT); + + char *data_tmp = + (char *)fpga_malloc(chw_align * num_after_alignment * // NOLINT + sizeof(char)); + char *tmp = *data_in; + for (index = 0; index < num_after_alignment; index++) { + new_index = index / 32 * 32 + (index % 16 / 4 * 8) + (index % 16 % 4) + + (index / 16 % 2 * 4); + memcpy(data_tmp + index * chw_align, + *data_in + new_index * chw_align, + chw_align); + } + *data_in = data_tmp; + fpga_free(tmp); +} + +size_t interleave(char **data_in, int num_after_alignment, int chw) { + int i = 0; + int j = 0; + int k = 0; + int interleave_per_num = 16; + + int chw_align = align_to_x(chw, FILTER_ELEMENT_ALIGNMENT); + char *data_tmp = + (char *)fpga_malloc(chw_align * num_after_alignment * // NOLINT + sizeof(char)); + char *tmp = *data_in; + int interleave_num = chw_align * 2 / interleave_per_num; + for (i = 0; i < num_after_alignment; i += 2) { + for (j = 0, k = 0; j < interleave_num; j += 2, k++) { + memcpy(data_tmp + i * chw_align + interleave_per_num * j, + *data_in + i * chw_align + interleave_per_num * k, + interleave_per_num); + memcpy(data_tmp + i * chw_align + interleave_per_num * (j + 1), + *data_in + (i + 1) * chw_align + interleave_per_num * k, + interleave_per_num); + } + } + *data_in = data_tmp; + fpga_free(tmp); + return chw_align * num_after_alignment; +} + +size_t format_filter(float **data_in, + int num, + int channel, + int height, + int width, + int group_num, + float max) { + int data_size = channel * height * width * num; + int chw = channel * height * width; + + int division_capacity = calc_division_capacity(chw); + int num_per_div_before_alignment = + calc_num_per_div(num, group_num, division_capacity); + int num_per_div_after_alignment = + align_to_x(num_per_div_before_alignment, FILTER_NUM_ALIGNMENT); + int div_num = + (num + num_per_div_before_alignment - 1) / num_per_div_before_alignment; + int residual = num % num_per_div_before_alignment; + int num_after_alignment = num_per_div_after_alignment * + ((residual == 0) ? 
div_num : (div_num - 1)) + + align_to_x(residual, FILTER_NUM_ALIGNMENT); + quantize(data_in, data_size, max); + char **quantize_data = (char **)data_in; // NOLINT + convert_to_hwc(quantize_data, num, channel, height, width); + align_element(quantize_data, num, chw); + if (num_after_alignment != num) { + align_num(quantize_data, num_per_div_before_alignment, num, chw); + } + + reorder(quantize_data, num_after_alignment, chw); + size_t mem_size = interleave(quantize_data, num_after_alignment, chw); + fpga_flush(*quantize_data, + align_to_x(chw, FILTER_ELEMENT_ALIGNMENT) * num_after_alignment * + sizeof(char)); + return mem_size; +} + +void convert_to_hwn(int16_t **data_in, int num, int height, int width) { + int16_t *tmp = *data_in; + int16_t *data_tmp = + (int16_t *)fpga_malloc(height * width * num * sizeof(int16_t)); // NOLINT + for (int n = 0; n < num; n++) { + for (int h = 0; h < height; h++) { + for (int w = 0; w < width; w++) { + *(data_tmp + h * width * num + w * num + n) = *((*data_in)++); + } + } + } + *data_in = data_tmp; + fpga_free(tmp); +} + +size_t align_element_n(int16_t **data_in, int num, int height, int width) { + int unalign_n = num; + int align_n = align_to_x(num, FILTER_ELEMENT_ALIGNMENT); + int num_element = height * width * align_n; + if (unalign_n != align_n) { + int16_t *tmp = *data_in; + + int num_element = height * width * align_n; + int16_t *data_tmp = + (int16_t *)fpga_malloc(num_element * sizeof(int16_t)); // NOLINT + + memset(data_tmp, 0, num_element * sizeof(int16_t)); + for (int h = 0; h < height; h++) { + for (int w = 0; w < width; w++) { + int offset_unalign = h * width * unalign_n + w * unalign_n; + int offset_align = h * width * align_n + w * align_n; + for (int n = 0; n < unalign_n; n++) { + data_tmp[offset_align + n] = *((*data_in) + offset_unalign + n); + } + } + } + *data_in = data_tmp; + free(tmp); + } + return num_element * sizeof(int16_t); +} + +void quantize_to_fp16( + float **data_in, int num, int height, int width, float *scale_ptr) { + float *tmp = *data_in; + int size = num * height * width; + + float16 *tmp_data = (float16 *)fpga_malloc(size * sizeof(float16)); // NOLINT + for (int n = 0; n < num; n++) { + float scale_val = scale_ptr[n]; + for (int h = 0; h < height; h++) { + for (int w = 0; w < width; w++) { + int index = n * height * width + h * width + w; + float value = tmp[index] * scale_val; + tmp_data[index] = float_to_half(value); + } + } + } + fpga_flush(tmp_data, size * sizeof(int16_t)); + *data_in = (float *)tmp_data; // NOLINT + fpga_free(tmp); +} +size_t format_dwconv_filter( + float **data_in, int num, int height, int width, float *scale_ptr) { + quantize_to_fp16(data_in, num, height, width, scale_ptr); + int16_t **quantize_data = (int16_t **)data_in; // NOLINT + convert_to_hwn(quantize_data, num, height, width); + size_t size = align_element_n(quantize_data, num, height, width); + fpga_flush(*quantize_data, + align_to_x(num, FILTER_ELEMENT_ALIGNMENT) * height * width * + sizeof(int16_t)); + return size; +} +} // namespace filter +} // namespace zynqmp +} // namespace paddle diff --git a/lite/fpga/KD/llapi/filter.h b/lite/fpga/KD/llapi/filter.h new file mode 100644 index 00000000000..7d9c6c2e015 --- /dev/null +++ b/lite/fpga/KD/llapi/filter.h @@ -0,0 +1,58 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+
+#include <stddef.h>
+#include <stdint.h>
+#include <stdio.h>
+
+namespace paddle {
+namespace zynqmp {
+namespace filter {
+
+void set_filter_capacity(uint32_t cap);
+int calc_division_capacity(int chw);
+int calc_split_num(int num, int division_capacity);
+int calc_division_number(int num, int group_num, int division_capacity);
+int calc_num_per_div(int num, int group_num, int division_capacity);
+void convert_to_hwc(
+    char** data_in, int num, int channel, int height, int width);
+float find_max(float* data_in, int data_size);
+void quantize(float** data_in, int data_size, float max);
+void align_element(char** data_in, int num, int chw);
+void align_num(char** data_in,
+               int num_per_div_before_alignment,
+               int num,
+               int chw);
+void reorder(char** data_in, int num_after_alignment, int chw);
+size_t interleave(char** data_in, int num_after_alignment, int chw);
+size_t format_filter(float** data_in,
+                     int num,
+                     int channel,
+                     int height,
+                     int width,
+                     int group_num,
+                     float max);
+
+void convert_to_hwn(int16_t** data_in, int num, int height, int width);
+size_t align_element_n(int16_t** data_in, int num, int height, int width);
+void quantize_to_fp16(
+    float** data_in, int num, int height, int width, float* scale_ptr);
+size_t format_dwconv_filter(
+    float** data_in, int num, int height, int width, float* scale_ptr);
+
+}  // namespace filter
+}  // namespace zynqmp
+}  // namespace paddle
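format_filter() above drives the whole filter pipeline: symmetric int8 quantization against the per-filter max, HWC reordering, element and filter-count alignment, the 32-filter reorder pattern, and pairwise interleaving. The quantization step alone is easy to illustrate on the host (a sketch mirroring find_max() and float_to_int8(), not the FPGA path itself):

#include <algorithm>
#include <cmath>
#include <cstdio>

// Symmetric int8 quantization as used by filter::quantize():
// scale = 127 / max(|w|), rounding half away from zero.
int main() {
  float weights[] = {0.50f, -0.25f, 0.10f, -1.00f};
  float max = 0.0f;
  for (float w : weights) max = std::max(max, std::fabs(w));  // find_max
  float scale = 127.0f / max;
  for (float w : weights) {
    float f = w * scale;
    signed char q = (signed char)(f < 0.0f ? f - 0.5f : f + 0.5f);
    std::printf("%+.2f -> %d\n", w, q);  // dequantize with q * max / 127
  }
}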
diff --git a/lite/fpga/KD/llapi/zynqmp_api.cpp b/lite/fpga/KD/llapi/zynqmp_api.cpp
new file mode 100644
index 00000000000..b61eda8d9d7
--- /dev/null
+++ b/lite/fpga/KD/llapi/zynqmp_api.cpp
@@ -0,0 +1,323 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include <fcntl.h>
+#include <sys/ioctl.h>
+#include <sys/mman.h>
+#include <unistd.h>
+#include <algorithm>
+#include <cstdlib>
+#include <cstring>
+#include <iostream>
+#include <map>
+#include <utility>
+
+#include "lite/fpga/KD/llapi/config.h"
+#include "lite/fpga/KD/llapi/zynqmp_api.h"
+
+namespace paddle {
+namespace zynqmp {
+
+#define PADDLE_LITE_OS_LINUX
+
+static int fd = -1;
+static const char *device_path = "/dev/fpgadrv0";
+static std::map<void *, size_t> memory_map;
+
+static size_t memory_size_max = 0;
+static size_t memory_size = 0;
+
+static inline int do_ioctl(uint64_t req, const void *arg) {
+#ifdef PADDLE_LITE_OS_LINUX
+  return ioctl(fd, req, arg);
+#else
+  return -1;
+#endif
+}
+
+int open_device() {
+  if (fd == -1) {
+    fd = open(device_path, O_RDWR);
+  }
+  return fd;
+}
+
+void close_device() { close(fd); }
+
+void reset_device() {
+  FpgaResetArgs args;
+  do_ioctl(IOCTL_FPGA_RESET, &args);
+}
+
+// memory management;
+void *fpga_malloc(size_t size) {
+#ifdef PADDLE_LITE_OS_LINUX
+  void *ptr = reinterpret_cast<void *>(
+      mmap64(NULL, size, PROT_READ | PROT_WRITE, MAP_SHARED, fd, 0));
+  if (ptr == MAP_FAILED) {  // mmap reports failure as MAP_FAILED, not NULL
+    std::cout << "not enough memory !";
+    exit(-1);
+  }
+  memory_map.insert(std::make_pair(ptr, size));
+  memory_size += size;
+  if (memory_size > memory_size_max) {
+    memory_size_max = memory_size;
+  }
+  return ptr;
+#else
+  return malloc(size);
+#endif
+}
+
+size_t fpga_get_memory_size(void *ptr) { return memory_map[ptr]; }
+
+size_t fpga_get_memory_size_max() { return memory_size_max; }
+
+size_t fpga_diagnose_memory(int detailed) {
+  size_t total = 0;
+  auto iter = memory_map.begin();  // std::map<void *, size_t>::iterator
+  while (iter != memory_map.end()) {
+    total += iter->second;
+    iter++;
+  }
+  return total;
+}
+
+void fpga_free(void *ptr) {
+  size_t size = 0;
+  auto iter = memory_map.find(ptr);  // std::map<void *, size_t>::iterator
+  if (iter != memory_map.end()) {
+    size = iter->second;
+    memory_map.erase(iter);
+  }
+
+  memory_size -= size;
+
+#ifdef PADDLE_LITE_OS_LINUX
+
+  munmap(ptr, size);
+#else
+  free(ptr);
+#endif
+}
+
+void fpga_copy(void *dst, const void *src, int size) { memcpy(dst, src, size); }
+
+int fpga_flush(void *address, size_t size) {
+  struct MemoryCacheArgs args;
+  args.address = address;
+  args.size = size;
+  return do_ioctl(IOCTL_MEMCACHE_FLUSH, &args);
+}
+
+int fpga_invalidate(void *address, size_t size) {
+  struct MemoryCacheArgs args;
+  args.address = address;
+  args.size = size;
+  return do_ioctl(IOCTL_MEMCACHE_INVAL, &args);
+}
+
+int invalidate_cache(void *addr, int size) {
+  struct MemoryCacheArgs args;
+  args.address = addr;
+  args.size = size;
+  return do_ioctl(IOCTL_MEMCACHE_INVAL, &args);
+}
+
+int flush_cache(void *addr, int size) {
+  struct MemoryCacheArgs args;
+  args.address = addr;
+  args.size = size;
+  return do_ioctl(IOCTL_MEMCACHE_FLUSH, &args);
+}
+
+void fpga_copy(void *dest, const void *src, size_t num) {
+  memcpy(dest, src, num);
+}
+
+int ioctl_conv(const struct ConvArgs &args) {
+  return do_ioctl(IOCTL_CONFIG_CONV, &args);
+}
+
+int compute_fpga_conv_basic(const struct ConvArgs &args) {
+  return do_ioctl(IOCTL_CONFIG_CONV, &args);
+}
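fpga_malloc() hands out mmap-backed memory that both the CPU and the accelerator see, so correctness depends on the flush/invalidate discipline around each ioctl: flush CPU writes before the FPGA reads a buffer, invalidate before the CPU reads FPGA results. A usage sketch (assumes a working device at /dev/fpgadrv0 and the helpers defined in this file):

#include <cstddef>
#include <cstdint>
#include "lite/fpga/KD/llapi/zynqmp_api.h"

// Round-trip a shared buffer through the accelerator (illustrative sketch).
void example_roundtrip(size_t n) {
  using namespace paddle::zynqmp;
  open_device();  // subsequent mmaps are served by the FPGA driver
  int16_t *buf =
      reinterpret_cast<int16_t *>(fpga_malloc(n * sizeof(int16_t)));
  for (size_t i = 0; i < n; i++) buf[i] = 0;  // CPU-side initialization
  fpga_flush(buf, n * sizeof(int16_t));  // push dirty cache lines to DRAM
  // ... submit an IOCTL_CONFIG_* operation that writes into buf ...
  fpga_invalidate(buf, n * sizeof(int16_t));  // drop stale lines before reading
  fpga_free(buf);
  close_device();
}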
+
+int compute_fpga_conv(const struct SplitConvArgs &args) {
+  int split_num = args.split_num;
+  int ret = -1;
+  for (int i = 0; i < split_num; i++) {
+    ret = compute_fpga_conv_basic(args.conv_arg[i]);
+  }
+
+  if (split_num > 1) {
+    std::cout << "Split num > 1 !!!!!!!!!!!!!!!!!!" << std::endl;
+    exit(-1);
+  }
+  return ret;
+}
+
+int compute_fpga_pool(const struct PoolingArgs &args) {
+  return do_ioctl(IOCTL_CONFIG_POOLING, &args);
+}
+
+int compute_fpga_ewadd(const struct EWAddArgs &args) {
+  return do_ioctl(IOCTL_CONFIG_EW, &args);
+}
+
+int get_device_info(const struct DeviceInfo &args) {
+  int ret = do_ioctl(IOCTL_DEVICE_INFO, &args);
+  return ret;
+}
+
+int perform_bypass(const struct BypassArgs &args) {
+  int size = args.image.channels * args.image.width * args.image.height;
+  int max_size = 1 << 21;
+
+  float times = 1.0 * size / max_size;
+  int count = static_cast<int>(times);
+
+  void *input_address = args.image.address;
+  int type_size =
+      args.input_data_type == DATA_TYPE_FP32 ? sizeof(float) : sizeof(int16_t);
+
+  void *output_address = args.output.address;
+  int out_type_size =
+      args.output_data_type == DATA_TYPE_FP32 ? sizeof(float) : sizeof(int16_t);
+
+  float scales[2];
+  struct BypassArgs bypassArgs = args;
+  bypassArgs.image.width = 1;
+  bypassArgs.image.height = 1;
+  bypassArgs.output.scale_address = scales;
+
+  float scale = 0;
+  for (int i = 0; i < count; ++i) {
+    bypassArgs.image.channels = max_size;
+    bypassArgs.image.address =
+        reinterpret_cast<char *>(input_address) + i * max_size * type_size;
+    bypassArgs.output.address =
+        reinterpret_cast<char *>(output_address) + i * max_size * out_type_size;
+    int ret = do_ioctl(IOCTL_CONFIG_BYPASS, &bypassArgs);
+    scale = std::max(scale, scales[0]);
+
+    if (ret != 0) {
+      return ret;
+    }
+  }
+
+  int remainder = size - max_size * count;
+  bypassArgs.image.channels = remainder;
+  bypassArgs.image.address =
+      reinterpret_cast<char *>(input_address) + count * max_size * type_size;
+  bypassArgs.output.address =
+      reinterpret_cast<char *>(output_address) + count * max_size * out_type_size;
+  int ret = do_ioctl(IOCTL_CONFIG_BYPASS, &bypassArgs);
+  scale = std::max(scale, scales[0]);
+  args.output.scale_address[0] = scale;
+  args.output.scale_address[1] = 1.0f / scale;
+  return ret;
+}
+
+int compute_fpga_concat(const struct ConcatArgs &args) { return -1; }
+
+int compute_fpga_scale(const struct ScaleArgs &args) {
+#ifdef ENABLE_DEBUG
+  std::cout << "======Compute Scale======";
+  std::cout << "scale_address:" << args.scale_address << std::endl;
+  std::cout << "bias_address:" << args.bias_address << std::endl;
+
+  std::cout << "wc_alignment:" << args.wc_alignment << std::endl;
+  std::cout << "channel_alignment:" << args.channel_alignment << std::endl;
+
+  std::cout << " image_address:" << args.image.address
+            << " image_scale_address:" << args.image.scale_address
+            << " image_channels:" << args.image.channels
+            << " image_height:" << args.image.height
+            << " image_width:" << args.image.width
+            << " pad_height:" << args.image.pad_height
+            << " pad_width:" << args.image.pad_width;
+
+  std::cout << " out_address:" << args.output.address
+            << " out_scale_address:" << args.output.scale_address;
+
+#endif
+  return do_ioctl(IOCTL_CONFIG_SCALE, &args);
+}
+
+int compute_fpga_dwconv(const struct DWconvArgs &args) {
+#ifdef ENABLE_DEBUG
+  std::cout << "======Compute Basic Conv======";
+  std::cout << " relu_enabled:" << args.relu_enabled
+            << " filter_address:" << args.filter_address;
+  std::cout << " image_address:" << args.image.address
+            << " image_scale_address:" << args.image.scale_address
+            << " image_channels:" << args.image.channels
+            << " image_height:" << args.image.height
+            << " image_width:" << args.image.width
+            << " pad_height:" << args.image.pad_height
+            << " pad_width:" << args.image.pad_width;
+  std::cout << " kernel_height:" <<
args.kernel.height + << " kernel_width:" << args.kernel.width + << " stride_h:" << args.kernel.stride_h + << " stride_w:" << args.kernel.stride_w; + std::cout << " out_address:" << args.output.address + << " out_scale_address:" << args.output.scale_address; + +#endif + return do_ioctl(IOCTL_CONFIG_DWCONV, &args); +} + +int config_inplace(const struct InplaceArgs &args) { + return do_ioctl(IOCTL_CONFIG_INPLACE, &args); +} + +int config_norm_param(const struct NormalizeParameterArgs &args) { + return do_ioctl(IOCTL_CONFIG_NORMALIZE_PARAMETER, &args); +} + +int compute_norm(const struct NormalizeArgs &args) { + return do_ioctl(IOCTL_CONFIG_NORMALIZE, &args); +} + +int compute_fpga_resize(const struct ResizeArgs &args) { + return do_ioctl(IOCTL_CONFIG_RESIZE, &args); +} + +int16_t fp32_2_fp16(float fp32_num) { + unsigned long tmp = *(unsigned long *)(&fp32_num); // NOLINT + auto t = (int16_t)(((tmp & 0x007fffff) >> 13) | ((tmp & 0x80000000) >> 16) | + (((tmp & 0x7f800000) >> 13) - (112 << 10))); + if (tmp & 0x1000) { + t++; // roundoff + } + return t; +} + +float fp16_2_fp32(int16_t fp16_num) { + if (0 == fp16_num) { + return 0; + } + int frac = (fp16_num & 0x3ff); + int exp = ((fp16_num & 0x7c00) >> 10) + 112; + int s = fp16_num & 0x8000; + int tmp = 0; + float fp32_num = 0; + tmp = s << 16 | exp << 23 | frac << 13; + fp32_num = *(float *)&tmp; // NOLINT + return fp32_num; +} + +} // namespace zynqmp +} // namespace paddle diff --git a/lite/fpga/KD/llapi/zynqmp_api.h b/lite/fpga/KD/llapi/zynqmp_api.h new file mode 100644 index 00000000000..3dd7f1e981a --- /dev/null +++ b/lite/fpga/KD/llapi/zynqmp_api.h @@ -0,0 +1,337 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
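An aside on the half-precision helpers defined in zynqmp_api.cc above (illustrative, not part of the patch): fp32_2_fp16 packs a float into the IEEE fp16 layout by shifting sign, exponent, and mantissa into place, re-biasing the exponent from 127 to 15 (the 112 term) and rounding to nearest on the highest dropped mantissa bit; fp16_2_fp32 inverts the packing. A standalone round-trip sketch under the same assumptions (like the code above, it ignores denormals, infinities, and NaNs):

  #include <cstdint>
  #include <cstdio>
  #include <cstring>

  int16_t to_fp16(float f) {
    uint32_t u;
    std::memcpy(&u, &f, sizeof(u));  // bit-cast without aliasing UB
    int16_t h = (int16_t)(((u & 0x007fffff) >> 13) | ((u & 0x80000000) >> 16) |
                          (((u & 0x7f800000) >> 13) - (112 << 10)));
    if (u & 0x1000) h++;  // round to nearest
    return h;
  }

  float to_fp32(int16_t h) {
    if (h == 0) return 0.0f;
    uint32_t u = ((uint32_t)(h & 0x8000) << 16) |
                 ((((uint32_t)(h & 0x7c00) >> 10) + 112) << 23) |
                 (((uint32_t)h & 0x3ff) << 13);
    float f;
    std::memcpy(&f, &u, sizeof(f));
    return f;
  }

  int main() {
    float x = 1.5f;
    int16_t h = to_fp16(x);  // 0x3e00
    std::printf("%f -> 0x%04x -> %f\n", x, h & 0xffff, to_fp32(h));
  }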
*/ + +#pragma once + +#include +#include +#include +#include + +namespace paddle { +namespace zynqmp { + +typedef int16_t half; + +#define IMAGE_ALIGNMENT 16 // Aligned to 16 +#define FILTER_NUM_ALIGNMENT 32 // Filter number aligned to 32 +#define FILTER_ELEMENT_ALIGNMENT 16 // Filter element number aligned to 16 +#define BS_NUM_ALIGNMENT 8 +#define BIAS_NUM_ALIGNMENT 16 + +enum DDataType { + DATA_TYPE_FP32 = 1, + DATA_TYPE_FP16 = 0, +}; + +enum DLayoutType { + LAYOUT_CHW = 1, + LAYOUT_HWC = 0, +}; + +struct VersionArgs { + void* buffer; +}; + +struct DeviceInfo { + uint32_t filter_cap; +}; + +struct MemoryCopyArgs { + void* src; + void* dest; + size_t size; +}; + +struct MemoryCacheArgs { + void* address; + size_t size; +}; + +struct MemoryBarrierArgs {}; + +struct BNArgs { + bool enabled; + void* bias_address; + void* scale_address; +}; + +/** +Conv and Pooling kernel +*/ +struct KernelArgs { + uint32_t width; + uint32_t height; + uint32_t stride_w; + uint32_t stride_h; +}; + +struct ImageInputArgs { + void* address; // input featuremap virtual address + void* scale_address; // input scale address; + uint32_t channels; + uint32_t width; // featuremap width + uint32_t height; + uint32_t pad_width; // padding width; + uint32_t pad_height; +}; + +struct ImageOutputArgs { + void* address; // output result address; + float* scale_address; // output scale address; +}; + +struct ConvArgs { + bool relu_enabled; + void* sb_address; // scale and bias are interlaced; + void* filter_address; + void* filter_scale_address; + uint32_t filter_num; + uint32_t group_num; + + struct KernelArgs kernel; + struct ImageInputArgs image; // input image; + struct ImageOutputArgs output; +}; + +struct DWconvArgs { + bool relu_enabled; + void* bias_address; + void* filter_address; + struct KernelArgs kernel; + struct ImageInputArgs image; + struct ImageOutputArgs output; + uint16_t out_width; + uint16_t out_height; + uint16_t sub_conv_num; +}; + +struct PoolingArgs { + uint16_t mode; + uint16_t kernel_reciprocal; + struct KernelArgs kernel; + struct ImageInputArgs image; // input image; + struct ImageOutputArgs output; + uint16_t out_width; + uint16_t out_height; +}; + +// elementwise add arguments +struct EWAddArgs { + bool relu_enabled; + + uint32_t const0; // output0 = const0 x input0 + const1 x input1; + uint32_t const1; + struct ImageInputArgs image0; + struct ImageInputArgs image1; + struct ImageOutputArgs output; +}; + +struct BypassArgs { + enum DDataType input_data_type; + enum DDataType output_data_type; + enum DLayoutType input_layout_type; + enum DLayoutType output_layout_type; + struct ImageInputArgs image; + struct ImageOutputArgs output; +}; + +struct ScaleArgs { + void* scale_address; + void* bias_address; + uint32_t wc_alignment; + uint32_t channel_alignment; + + struct ImageInputArgs image; + struct ImageOutputArgs output; +}; + +struct NormalizeArgs { + void* input_image_address; + void* output_image_address; + uint32_t image_width; + uint32_t image_height; + uint32_t image_channel; + uint32_t* output_scale_address; +}; + +struct ResizeArgs { + void* input_image_address; + void* output_image_address; + uint32_t input_width; + uint32_t input_height; + uint32_t image_channel; + uint32_t output_width; + uint32_t output_height; + uint32_t height_ratio; + uint32_t width_ratio; + uint32_t* output_scale_address; +}; + +struct PowerParameterArgs { + uint16_t shift; + uint16_t scale; + uint16_t power; +}; + +struct NormalizeParameterArgs { + uint32_t channel; + uint32_t hight_width; +}; + +struct 
InplaceArgs { + bool relu_enable; + bool power_enable; + bool normalize_enable; +}; + +struct FpgaRegWriteArgs { + uint64_t address; // + uint64_t value; +}; + +struct FpgaRegReadArgs { + uint64_t address; + uint64_t value; +}; + +struct FpgaResetArgs {}; + +#define IOCTL_FPGA_MAGIC (('F' + 'P' + 'G' + 'A') / 4) + +#define IOCTL_VERSION _IOW(IOCTL_FPGA_MAGIC, 01, struct VersionArgs) +#define IOCTL_DEVICE_INFO _IOW(IOCTL_FPGA_MAGIC, 100, struct DeviceInfo) + +#define IOCTL_SEPARATOR_0 10 + +#define IOCTL_MEM_COPY _IOW(IOCTL_FPGA_MAGIC, 11, struct MemoryCopyArgs) +#define IOCTL_MEMCACHE_INVAL _IOW(IOCTL_FPGA_MAGIC, 12, struct MemoryCacheArgs) +#define IOCTL_MEMCACHE_FLUSH _IOW(IOCTL_FPGA_MAGIC, 13, struct MemoryCacheArgs) +#define IOCTL_MEMORY_BARRIER \ + _IOW(IOCTL_FPGA_MAGIC, 14, struct MemoryBarrierArgs) + +#define IOCTL_SEPARATOR_1 20 + +#define IOCTL_CONFIG_CONV _IOW(IOCTL_FPGA_MAGIC, 21, struct ConvArgs) +#define IOCTL_CONFIG_POOLING _IOW(IOCTL_FPGA_MAGIC, 22, struct PoolingArgs) +#define IOCTL_CONFIG_EW _IOW(IOCTL_FPGA_MAGIC, 23, struct EWAddArgs) +#define IOCTL_CONFIG_BYPASS _IOW(IOCTL_FPGA_MAGIC, 24, struct BypassArgs) +#define IOCTL_CONFIG_SCALE _IOW(IOCTL_FPGA_MAGIC, 25, struct ScaleArgs) +#define IOCTL_CONFIG_NORMALIZE _IOW(IOCTL_FPGA_MAGIC, 26, struct NormalizeArgs) +#define IOCTL_CONFIG_RESIZE _IOW(IOCTL_FPGA_MAGIC, 30, struct ResizeArgs) + +#define IOCTL_CONFIG_DWCONV _IOW(IOCTL_FPGA_MAGIC, 31, struct DWconvArgs) + +#define IOCTL_CONFIG_INPLACE _IOW(IOCTL_FPGA_MAGIC, 40, struct InplaceArgs) +#define IOCTL_CONFIG_POWER_PARAMETER \ + _IOW(IOCTL_FPGA_MAGIC, 41, struct PowerParameterArgs) +#define IOCTL_CONFIG_NORMALIZE_PARAMETER \ + _IOW(IOCTL_FPGA_MAGIC, 42, struct NormalizeParameterArgs) +#define IOCTL_FPGA_REG_READ _IOW(IOCTL_FPGA_MAGIC, 50, struct FpgaRegReadArgs) +#define IOCTL_FPGA_REG_WRITE _IOW(IOCTL_FPGA_MAGIC, 51, struct FpgaRegWriteArgs) +#define IOCTL_FPGA_RESET _IOW(IOCTL_FPGA_MAGIC, 52, struct FpgaResetArgs) + +//============================== API ============================= + +struct DeconvArgs { + uint32_t sub_conv_num; + uint32_t group_num; + uint32_t filter_num; + uint32_t omit_size; + uint32_t sub_output_width; + uint32_t sub_output_height; + struct ImageOutputArgs output; + struct SplitConvArgs* split_conv_args; +}; + +struct SplitArgs { + uint32_t image_num; + int16_t* image_in; + float* scale_in; + void** images_out; + float** scales_out; + uint32_t* out_channel_nums; + uint32_t height; + uint32_t width; +}; + +struct ConcatArgs { + uint32_t image_num; + half** images_in; + float** scales_in; + void* image_out; + float* scale_out; + uint32_t* channel_num; + uint32_t height; + uint32_t width; +}; + +struct SplitConvArgs { + uint32_t split_num; + uint32_t group_num; + uint32_t filter_num; + struct ImageOutputArgs output; + struct ConvArgs* conv_arg; + struct ConcatArgs concat_arg; +}; + +struct GroupConvArgs { + uint32_t group_num; + uint32_t filter_num; + struct ImageOutputArgs output; + struct SplitConvArgs* conv_args; + struct ConcatArgs concat_arg; +}; + +inline int align_to_x(int num, int x) { return (num + x - 1) / x * x; } +int open_device(); +void close_device(); +void reset_device(); + +void* fpga_malloc(size_t size); +void fpga_free(void* ptr); +size_t fpga_get_memory_size(void* ptr); +size_t fpga_get_memory_size_max(); +size_t fpga_diagnose_memory(int detailed); + +void fpga_copy(void* dst, const void* src, int size); + +int fpga_flush(void* address, size_t size); +int fpga_invalidate(void* address, size_t size); + +int get_device_info(const 
struct DeviceInfo& args); + +int perform_bypass(const struct BypassArgs& args); +int compute_fpga_conv_basic(const struct ConvArgs& args); +int compute_fpga_conv(const struct SplitConvArgs& args); +int compute_fpga_pool(const struct PoolingArgs& args); +int compute_fpga_ewadd(const struct EWAddArgs& args); +int compute_fpga_scale(const struct ScaleArgs& args); +int compute_fpga_concat(const struct ConcatArgs& args); +int compute_fpga_resize(const struct ResizeArgs& args); + +int config_power(const struct PowerArgs& args); +int compute_fpga_dwconv(const struct DWconvArgs& args); +int config_norm_param(const struct NormalizeParameterArgs& args); +int compute_norm(const struct NormalizeArgs& args); + +int config_inplace(const struct InplaceArgs& args); + +int flush_cache(void* addr, int size); +int invalidate_cache(void* addr, int size); + +int16_t fp32_2_fp16(float fp32_num); +float fp16_2_fp32(int16_t fp16_num); +} // namespace zynqmp +} // namespace paddle diff --git a/lite/fpga/KD/pe.hpp b/lite/fpga/KD/pe.hpp new file mode 100644 index 00000000000..0695be603ae --- /dev/null +++ b/lite/fpga/KD/pe.hpp @@ -0,0 +1,37 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include +#include +#include "lite/fpga/KD/pe_params.hpp" +#include "lite/fpga/KD/tensor_util.hpp" + +namespace paddle { +namespace zynqmp { + +class PE { + public: + virtual bool init() { return false; } + + virtual void apply() {} + + virtual bool dispatch() { return false; } + + virtual ~PE() {} +}; + +} // namespace zynqmp +} // namespace paddle diff --git a/lite/fpga/KD/pe_params.hpp b/lite/fpga/KD/pe_params.hpp new file mode 100644 index 00000000000..b6f4500a1a9 --- /dev/null +++ b/lite/fpga/KD/pe_params.hpp @@ -0,0 +1,233 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
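A reading aid for the PE interface just declared in pe.hpp (illustrative, not part of the patch): every kernel in this directory follows the same three-phase contract. A minimal sketch with a hypothetical IdentityPE, assuming only the interface shown above:

  #include "lite/fpga/KD/pe.hpp"

  class IdentityPE : public paddle::zynqmp::PE {
   public:
    bool init() override { return true; }      // one-time setup and validation
    void apply() override {}                   // precompute driver arguments
    bool dispatch() override { return true; }  // issue the op per inference
  };

  // Callers run init() and apply() once, then dispatch() for every inference.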
*/ + +#pragma once + +#include +#include + +#include "lite/fpga/KD/llapi/zynqmp_api.h" +#include "lite/fpga/KD/tensor.hpp" + +namespace paddle { +namespace zynqmp { + +struct ReLUParam { + public: + bool enabled = false; +}; + +struct PEParam { + ReLUParam relu; +}; + +struct InputParam : PEParam { + public: + Tensor* input = nullptr; + Tensor* output = nullptr; +}; + +struct OutputParam : PEParam { + public: + Tensor* input = nullptr; + Tensor* output = nullptr; +}; + +struct BatchnormParam : PEParam { + public: + Tensor* input = nullptr; + Tensor* output = nullptr; + + Tensor* bias = nullptr; + Tensor* scale = nullptr; + Tensor* mean = nullptr; + Tensor* variance = nullptr; + float epsilon = 0; +}; + +struct BasicConvParam { + Tensor input; + Tensor output; + Tensor filter; + Tensor scaleBias; + ConvArgs args; +}; + +struct ConvParam : PEParam { + public: + Tensor* input = nullptr; + Tensor* output = nullptr; + Tensor* filter = nullptr; + + int groups = 1; + std::vector strides; + std::vector paddings; + std::vector kernelSize; + std::vector dilations; + + Tensor* scale() { return scale_; } + + Tensor* bias() { return bias_; } + + std::vector& splitParams() { return splitParams_; } + + protected: + std::vector splitParams_; + Tensor* scale_ = new Tensor(); + Tensor* bias_ = new Tensor(); +}; + +struct DepthwiseConvParam : ConvParam { + public: + Tensor* quantizedFilter() { return quantizedFilter_; } + + DWconvArgs args; + + protected: + Tensor* quantizedFilter_ = new Tensor(); +}; + +enum PoolingType : int { + MAX = 0, + AVERAGE = 1, +}; + +struct PoolingParam : PEParam { + public: + Tensor* input = nullptr; + Tensor* output = nullptr; + + PoolingType type = PoolingType::MAX; + bool globalPooling = false; + std::vector kernelSize; + std::vector strides; + std::vector paddings; + + PoolingArgs poolingArgs = {0}; +}; + +struct ConcatParam : PEParam { + public: + std::vector inputs; + Tensor* output; + int axis = 0; +}; + +struct ElementwiseAddParam : PEParam { + public: + std::vector inputs; + Tensor* output = nullptr; + int axis = 0; + + EWAddArgs ewargs; +}; + +struct FullyConnectedParam : PEParam { + public: + Tensor* input = nullptr; + Tensor* filter = nullptr; + Tensor* bias = nullptr; + Tensor* output = nullptr; + + Tensor* quantizedFilter() { return quantizedFilter_; } + + Tensor* biasScale() { return biasScale_; } + + protected: + Tensor* quantizedFilter_ = new Tensor(); + Tensor* biasScale_ = new Tensor(); +}; + +struct SoftmaxParam : PEParam { + public: + Tensor* input = nullptr; + + Tensor* output = nullptr; + + private: + Tensor* floatInput = nullptr; +}; + +struct SplitParam : PEParam { + public: + Tensor* input = nullptr; + std::vector outputs; + int axis = 1; + int num = 1; +}; + +struct NormParam : PEParam { + public: + Tensor* input = nullptr; + + Tensor* output = nullptr; + float epsilon = 0; + + private: + Tensor* floatInput = nullptr; +}; + +struct PriorBoxParam : PEParam { + Tensor* input; + Tensor* image; + Tensor* outputBoxes; + Tensor* outputVariances; + + std::vector minSizes; + std::vector maxSizes; + std::vector aspectRatios; + std::vector variances; + + bool minMaxAspectRatiosOrder; + bool flip; + bool clip; + float stepW; + float stepH; + float offset; +}; + +struct ScaleParam : PEParam { + public: + Tensor* input = nullptr; + Tensor* output = nullptr; + Tensor* scale = nullptr; + Tensor* bias = nullptr; + + Tensor* alignedScale() { return alignedScale_; } + + Tensor* alignedBias() { return alignedBias_; } + + ScaleArgs args = {0}; + + protected: + Tensor* 
alignedScale_ = new Tensor();
+  Tensor* alignedBias_ = new Tensor();
+};
+
+struct ResizeParam : PEParam {
+ public:
+  Tensor* input = nullptr;
+  Tensor* output = nullptr;
+};
+
+struct CropParam : PEParam {
+ public:
+  Tensor* input = nullptr;
+  Tensor* output = nullptr;
+  int axis = 2;
+  std::vector<int> offsets;
+  std::vector<int> shape;
+};
+}  // namespace zynqmp
+}  // namespace paddle
diff --git a/lite/fpga/KD/pes/batchnorm_pe.hpp b/lite/fpga/KD/pes/batchnorm_pe.hpp
new file mode 100644
index 00000000000..73f9849680b
--- /dev/null
+++ b/lite/fpga/KD/pes/batchnorm_pe.hpp
@@ -0,0 +1,105 @@
+/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+
+#include <cmath>
+#include <vector>
+
+#include "lite/fpga/KD/pe.hpp"
+#include "lite/fpga/KD/pe_params.hpp"
+#include "lite/fpga/KD/pes/scale_pe.hpp"
+
+namespace paddle {
+namespace zynqmp {
+class BatchnormPE : public PE {
+ public:
+  bool init() {
+    Tensor* output = param_.output;
+    output->setAligned(true);
+    output->setDataLocation(Device);
+
+    ScaleParam& scale_param = scalePE_.param();
+    scale_param.input = param_.input;
+    scale_param.output = param_.output;
+    Tensor* scale = new Tensor();
+    Tensor* bias = new Tensor();
+    Shape shape(N, {output->shape().channel()});
+
+    auto mean_data = param_.mean->data<float>();
+    auto variance_data = param_.variance->data<float>();
+    auto scale_data = param_.scale->data<float>();
+    auto bias_data = param_.bias->data<float>();
+    auto new_scale_ptr = scale->mutableData<float>(FP32, shape);
+    auto new_bias_ptr = bias->mutableData<float>(FP32, shape);
+
+    float epsilon = param_.epsilon;
+
+    // Fold batch-norm into a per-channel scale/bias:
+    //   y = scale * (x - mean) / sqrt(var + eps) + bias
+    for (int c = 0; c < output->shape().channel(); c++) {
+      float var = variance_data[c];
+      float inv_scale = 1.0 / (std::sqrt(var + epsilon));
+      float scale_value = inv_scale * scale_data[c];
+      float bias_value = bias_data[c] - scale_value * mean_data[c];
+      new_scale_ptr[c] = scale_value;
+      new_bias_ptr[c] = bias_value;
+    }
+
+    scale->flush();
+    bias->flush();
+
+    scale_param.scale = scale;
+    scale_param.bias = bias;
+    scale_param.relu = param_.relu;
+
+    scalePE_.init();
+
+    inplace_.relu_enable = param_.relu.enabled;
+    inplace_.power_enable = false;
+    inplace_.normalize_enable = false;
+
+    return true;
+  }
+
+  void apply() { scalePE_.apply(); }
+
+  bool dispatch() {
+    if (param_.relu.enabled) {
+      inplace_.relu_enable = true;
+      config_inplace(inplace_);
+    }
+    bool ret = scalePE_.dispatch();
+    if (param_.relu.enabled) {
+      inplace_.relu_enable = false;
+      config_inplace(inplace_);
+    }
+    return ret;
+  }
+
+  BatchnormParam& param() { return param_; }
+
+  ~BatchnormPE() {
+    scalePE_.param().input = nullptr;
+    scalePE_.param().output = nullptr;
+  }
+
+ private:
+  BatchnormParam param_;
+  ScalePE scalePE_;
+  InplaceArgs inplace_ = {0};
+};
+}  // namespace zynqmp
+}  // namespace paddle
diff --git a/lite/fpga/KD/pes/concat_pe.hpp b/lite/fpga/KD/pes/concat_pe.hpp
new file mode 100644
index 00000000000..fb3f87ecafb
--- /dev/null
+++
b/lite/fpga/KD/pes/concat_pe.hpp @@ -0,0 +1,135 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include +#include + +#include "lite/fpga/KD/pe.hpp" +#include "lite/fpga/KD/pe_params.hpp" + +namespace paddle { +namespace zynqmp { + +class ConcatPE : public PE { + public: + bool init() { + Tensor* output = param_.output; + output->setAligned(false); + output->setDataLocation(CPU); + return true; + } + + void apply() {} + + void concat2D() { + int offset = 0; + float16* out_data = param_.output->data(); + for (unsigned int n = 0; n < param_.inputs.size(); n++) { + Tensor* input = param_.inputs[n]; + Shape& input_shape = input->shape(); + + float16* src = input->data(); + memcpy(out_data + offset, src, input_shape.numel() * sizeof(float16)); + offset += input_shape.numel(); + } + Tensor* output = param_.output; + output->flush(); + } + + void concat3D() { + auto input = param_.inputs; + Tensor* output = param_.output; + int axis = param_.axis; + int num = input.size(); + int rows = 1; + auto dim_0 = input[0]->shape().dims(); + for (int i = 0; i < axis; ++i) { + rows *= dim_0[i]; + } + int out_rows = rows, out_cols = 0; + + std::vector input_cols(input.size()); + for (int i = 0; i < num; ++i) { + int t_cols = input[i]->shape().numel() / rows; + out_cols += t_cols; + input_cols[i] = t_cols; + } + + // computation + for (int k = 0; k < out_rows; ++k) { + float16* dst_ptr = output->data() + k * out_cols; + int col_idx = 0; + for (int j = 0; j < num; ++j) { + int col_len = input_cols[j]; + const float16* src_prt = input[j]->data() + k * col_len; + memcpy(dst_ptr + col_idx, src_prt, sizeof(float16) * col_len); + col_idx += col_len; + } + } + output->flush(); + } + + bool dispatch() { + Tensor* output = param_.output; + Shape& output_shape = output->shape(); + + float scale = 0; + for (unsigned int n = 0; n < param_.inputs.size(); n++) { + Tensor* input = param_.inputs[n]; + input->syncToCPU(); + input->unalignImage(); + scale = std::max(scale, input->scale()[0]); + } + output->scale()[0] = scale; + output->scale()[1] = 1.0f / scale; + + if (output_shape.dimSize() == 3) { + concat3D(); + return true; + } + + if (output_shape.dimSize() == 2) { + concat2D(); + return true; + } + + float16* out_data = param_.output->data(); + int channel_sum = 0; + int out_channel = output_shape.channel(); + for (unsigned int n = 0; n < param_.inputs.size(); n++) { + Tensor* input = param_.inputs[n]; + Shape& input_shape = input->shape(); + int wh = output_shape.width() * output_shape.height(); + for (int j = 0; j < wh; j++) { + float16* src = input->data() + j * input_shape.channel(); + memcpy(out_data + j * out_channel + channel_sum, + src, + input_shape.channel() * sizeof(float16)); + } + channel_sum += input_shape.channel(); + } + output->flush(); + return true; + } + + ConcatParam& param() { return param_; } + + private: + ConcatParam param_; +}; + +} // namespace zynqmp +} // namespace paddle diff --git a/lite/fpga/KD/pes/conv_pe.hpp 
b/lite/fpga/KD/pes/conv_pe.hpp new file mode 100644 index 00000000000..06071001358 --- /dev/null +++ b/lite/fpga/KD/pes/conv_pe.hpp @@ -0,0 +1,138 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include +#include + +#include "lite/fpga/KD/pe.hpp" +#include "lite/fpga/KD/pe_params.hpp" +#include "lite/fpga/KD/pes/concat_pe.hpp" +#include "lite/fpga/KD/pes/conv_pe.hpp" +#include "lite/fpga/KD/pes/conv_process.hpp" +#include "lite/fpga/KD/pes/elementwise_add_pe.hpp" +#include "lite/fpga/KD/pes/scale_pe.hpp" + +namespace paddle { +namespace zynqmp { + +class ConvPE : public PE { + public: + bool init() { + Tensor* output = param_.output; + output->setAligned(true); + output->setDataLocation(Device); + return true; + } + + void apply() { + split_axis = fill_split_arg(param_); + + if (split_axis == 0 && param_.splitParams().size() > 1) { + ConcatParam& concat_param = concatPE_.param(); + for (auto conv_param : param_.splitParams()) { + concat_param.inputs.push_back(&conv_param->output); + } + concat_param.output = param_.output; + concatPE_.init(); + concatPE_.apply(); + } + } + void cpu_compute() { + Tensor* input = param_.input; + Tensor* output = param_.output; + input->syncToCPU(); + + Tensor float_input; + Tensor float_output; + float* image_addr = float_input.mutableData(FP32, input->shape()); + float_input.copyFrom(input); + float* out = float_output.mutableData(FP32, output->shape()); + + int out_channel = output->shape().channel(); + int in_channel = input->shape().channel(); + + float* filter_data = param_.filter->data(); + float* mi = new float[in_channel]; + + for (int i = 0; i < out_channel; i++) { + float* image = image_addr; + float* filter_ptr = filter_data + i * in_channel; + float* out_ptr = mi; +#pragma omp parallel for + for (int j = 0; j < in_channel; j++) { + float value = image_addr[j] * filter_ptr[j]; + mi[j] = value; + } + + float sum = 0; + for (int j = 0; j < in_channel; j++) { + sum += mi[j]; + } + out[i] = sum; + } + delete[] mi; + float_output.flush(); + output->copyFrom(&float_output); + } + + bool dispatch() { + inplace_.relu_enable = param_.relu.enabled; + inplace_.power_enable = false; + inplace_.normalize_enable = false; + + if (param_.relu.enabled) { + inplace_.relu_enable = param_.relu.enabled; + config_inplace(inplace_); + } + + std::vector& params = param_.splitParams(); + int ret = 0; + for (auto conv_param : params) { + ret |= compute_fpga_conv_basic(conv_param->args); + } + + if (param_.relu.enabled) { + inplace_.relu_enable = false; + config_inplace(inplace_); + } + + size_t size = params.size(); + if (split_axis == 0 && ret == 0 && size > 1) { + concatPE_.dispatch(); + } + if (split_axis == 1 && ret == 0 && size > 1) { + ElementwiseAddParam& add_param = addPE_.param(); + add_param.inputs = {¶ms[0]->output, ¶ms[1]->output}; + add_param.output = param_.output; + addPE_.init(); + addPE_.apply(); + addPE_.dispatch(); + } + return ret == 0; + } + + ConvParam& param() { return param_; } + + 
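+  // Descriptive note: split_axis == 0 means the filters were split along the
+  // output-channel axis, so concatPE_ stitches the per-slice outputs back
+  // together; split_axis == 1 means the input channels were split, so addPE_
+  // sums the two partial results. cpu_compute() is a scalar fallback that
+  // computes each output channel as a dot product over the input channels.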
private: + ConvParam param_; + ConcatPE concatPE_; + ElementwiseAddPE addPE_; + int split_axis = 0; + InplaceArgs inplace_ = {0}; +}; + +} // namespace zynqmp +} // namespace paddle diff --git a/lite/fpga/KD/pes/conv_process.hpp b/lite/fpga/KD/pes/conv_process.hpp new file mode 100644 index 00000000000..e6cf2fef3bc --- /dev/null +++ b/lite/fpga/KD/pes/conv_process.hpp @@ -0,0 +1,418 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include +#include +#include + +#include "lite/fpga/KD/float16.hpp" +#include "lite/fpga/KD/llapi/bias_scale.h" +#include "lite/fpga/KD/llapi/filter.h" +#include "lite/fpga/KD/pe_params.hpp" +#include "lite/fpga/KD/tensor.hpp" +#include "lite/fpga/KD/tensor_util.hpp" + +namespace paddle { +namespace zynqmp { + +inline int get_aligned_filter_element_num(int chw) { + return align_to_x(chw, FILTER_ELEMENT_ALIGNMENT); +} + +inline int get_filter_num_per_div(Tensor* filter, int group_num) { + auto chw = filter->shape().channel() * filter->shape().height() * + filter->shape().width(); + auto num = filter->shape().num(); + int div_capacity = filter::calc_division_capacity(chw); + return filter::calc_num_per_div(num, group_num, div_capacity); +} + +inline int get_split_num(Tensor* filter) { + auto chw = filter->shape().channel() * filter->shape().height() * + filter->shape().width(); + auto num = filter->shape().num(); + int div_capacity = filter::calc_division_capacity(chw); + return filter::calc_split_num(num, div_capacity); +} + +inline void fill_scale_bias_const(ConvParam* param_) { + int channel = param_->output->shape().channel(); + Shape sb_shape(N, {channel}); + float* new_scale_ptr = param_->scale()->mutableData(FP32, sb_shape); + float* new_bias_ptr = param_->bias()->mutableData(FP32, sb_shape); + for (int i = 0; i < channel; i++) { + new_scale_ptr[i] = 1.0f; + new_bias_ptr[i] = 0.0f; + } + param_->scale()->flush(); + param_->bias()->flush(); +} + +inline void combine_bn_params(BatchnormParam* bn, ConvParam* param_) { + int channel = param_->output->shape().channel(); + Shape sb_shape(N, {channel}); + float* new_scale_ptr = param_->scale()->mutableData(FP32, sb_shape); + float* new_bias_ptr = param_->bias()->mutableData(FP32, sb_shape); + float* bn_scale_ptr = bn->scale->data(); + float* bn_bias_ptr = bn->bias->data(); + float* bn_var_ptr = bn->variance->data(); + float* bn_mean_ptr = bn->mean->data(); + float epsilon = bn->epsilon; + for (int i = 0; i < channel; i++) { + float new_scale = bn_scale_ptr[i] / + static_cast(pow((bn_var_ptr[i] + epsilon), 0.5)); + new_scale_ptr[i] = new_scale; + new_bias_ptr[i] = bn_bias_ptr[i] + (0 - bn_mean_ptr[i]) * new_scale_ptr[i]; + } +} + +inline void combine_add_bn_params(BatchnormParam* bn, + Tensor* bias, + ConvParam* param_) { + int channel = param_->output->shape().channel(); + Shape sb_shape(N, {channel}); + float* new_scale_ptr = param_->scale()->mutableData(FP32, sb_shape); + float* new_bias_ptr = param_->bias()->mutableData(FP32, sb_shape); + if (bn 
!= nullptr) { + float* bn_scale_ptr = bn->scale->data(); + float* bn_bias_ptr = bn->bias->data(); + float* bn_var_ptr = bn->variance->data(); + float* bn_mean_ptr = bn->mean->data(); + float epsilon = bn->epsilon; + float* bias_data = bias->data(); + for (int i = 0; i < channel; i++) { + float new_scale = bn_scale_ptr[i] / + static_cast(pow((bn_var_ptr[i] + epsilon), 0.5)); + new_scale_ptr[i] = new_scale; + new_bias_ptr[i] = + bn_bias_ptr[i] + (bias_data[i] - bn_mean_ptr[i]) * new_scale_ptr[i]; + } + } else { + for (int i = 0; i < channel; i++) { + new_scale_ptr[i] = 1.0f; + new_bias_ptr[i] = 0.0f; + } + } + param_->scale()->flush(); + param_->bias()->flush(); + param_->scale()->setDataLocation(CPU); + param_->bias()->setDataLocation(CPU); +} + +inline void format_scale_bias(Tensor* scale, + Tensor* bias, + Tensor* filter, + Tensor* scale_bias, + int group) { + float* scale_data = nullptr; + float* bias_data = nullptr; + if (scale != nullptr) { + scale_data = scale->data(); + } + if (bias != nullptr) { + bias_data = bias->data(); + } + int channel = filter->shape().num(); + Shape bias_scale_shape(N, {2 * channel}); + float* bs_data = scale_bias->mutableData(FP32, bias_scale_shape); + for (int i = 0; i < channel; i++) { + float scale_value = scale_data == nullptr ? 1 : scale_data[i]; + float bias_value = bias_data == nullptr ? 0 : bias_data[i]; + bs_data[i + channel] = scale_value; + bs_data[i] = bias_value; + } + + int element_num_per_div = get_filter_num_per_div(filter, group); + bias_scale::format_bias_scale_array(&bs_data, element_num_per_div, channel); +} + +inline void format_filter(Tensor* filter, Tensor* quantized_filter, int group) { + float max_value = find_max(*filter); + Shape& filter_shape = filter->shape(); + quantized_filter->setAligned(true); + quantized_filter->mutableData(INT8, filter->shape()); + quantized_filter->scale()[0] = max_value / 127.0f; + quantized_filter->scale()[1] = 127.0f / max_value; + + auto memory_size = filter->shape().memorySize(sizeof(float)); + auto new_data = reinterpret_cast(fpga_malloc(memory_size)); + memcpy(new_data, filter->data(), memory_size); + size_t mem_size = filter::format_filter(&new_data, + filter_shape.num(), + filter_shape.channel(), + filter_shape.height(), + filter_shape.width(), + group, + max_value); + int8_t* src = quantized_filter->mutableData(INT8, filter->shape()); + memcpy(src, new_data, mem_size); + fpga_free(new_data); + quantized_filter->flush(); +} + +inline void format_dw_filter(Tensor* filter, + Tensor* quantized_filter, + float* scale) { + int num = filter->shape().num(); + int height = filter->shape().height(); + int width = filter->shape().width(); + auto memory_size = filter->shape().memorySize(sizeof(float)); + auto new_data = (float*)fpga_malloc(memory_size); // NOLINT + memcpy(new_data, filter->data(), memory_size); + + size_t size = + filter::format_dwconv_filter(&new_data, num, height, width, scale); + float16* src = quantized_filter->mutableData(FP16, filter->shape()); + + memcpy(src, new_data, size); + quantized_filter->flush(); + + fpga_free(new_data); +} + +inline void format_fc_filter(Tensor* filter, Tensor* quantized_filter) { + float max_value = find_max(*filter); + Shape& filter_shape = filter->shape(); + quantized_filter->setAligned(true); + quantized_filter->mutableData(INT8, filter->shape()); + quantized_filter->scale()[0] = max_value / 127.0f; + quantized_filter->scale()[1] = 127.0f / max_value; + + size_t memory_size = filter->shape().memorySize(sizeof(float)); + auto new_data = 
(float*)fpga_malloc(memory_size); // NOLINT + memcpy(new_data, filter->data(), memory_size); + + int8_t* src = quantized_filter->mutableData(INT8, filter->shape()); + memcpy(src, new_data, quantized_filter->shape().memorySize(sizeof(int8_t))); + quantized_filter->flush(); + fpga_free(new_data); +} + +inline void split_filter_num(const ConvParam& c_param) { + ConvParam& param = const_cast(c_param); + Tensor* input = param.input; + Tensor* out = param.output; + Tensor* filter = param.filter; + auto channel = out->shape().channel(); + + int split_num = param.groups == 1 ? get_split_num(param.filter) : 1; + int filter_num_per_div = get_filter_num_per_div(filter, param.groups); + + Shape& out_shape = out->shape(); + for (int i = 0; i < split_num; i++) { + BasicConvParam* conv_param = new BasicConvParam(); + conv_param->output.setDataLocation(Device); + conv_param->output.setAligned(true); + + int filter_num = filter->shape().num(); + float16* out_address = nullptr; + float* out_scale_address = nullptr; + + ConvArgs& args = conv_param->args; + + if (split_num == 1) { + out_address = out->data(); + out_scale_address = out->scale(); + } + filter_num = i == split_num - 1 + ? channel - (split_num - 1) * filter_num_per_div // NOLINT + : filter_num_per_div; + + if (split_num != 1) { + Shape shape(NHWC, {1, out_shape.height(), out_shape.width(), filter_num}); + out_address = conv_param->output.mutableData(FP16, shape); + out_scale_address = conv_param->output.scale(); + } + Shape f_shape(NCHW, + {filter_num, + filter->shape().channel(), + filter->shape().height(), + filter->shape().width()}); + + Tensor new_filter; + float* new_filter_data = new_filter.mutableData(FP32, f_shape); + int filter_hwc = filter->shape().height() * filter->shape().width() * + filter->shape().channel(); + + memcpy(new_filter_data, + filter->data() + i * filter_num_per_div * filter_hwc, + filter_num * filter_hwc * sizeof(float)); + new_filter.flush(); + + conv_param->filter.mutableData(FP32, f_shape); + format_filter(&new_filter, &(conv_param->filter), param.groups); + + int sb_num = 2 * align_to_x(filter_num, BS_NUM_ALIGNMENT); + Tensor scale; + Tensor bias; + + int chnnnel_start = i * filter_num_per_div; + + Shape s_shape(N, {filter_num}); + float* scale_data = scale.mutableData(FP32, s_shape); + float* bias_data = bias.mutableData(FP32, s_shape); + for (int n = 0; n < filter_num; n++) { + scale_data[n] = param.scale()->data()[n + chnnnel_start]; + } + for (int n = 0; n < filter_num; n++) { + bias_data[n] = param.bias()->data()[n + chnnnel_start]; + } + Shape sb_shape(N, {sb_num}); + format_scale_bias(&scale, + &bias, + &conv_param->filter, + &conv_param->scaleBias, + param.groups); + conv_param->scaleBias.flush(); + + args.group_num = param.groups; + args.relu_enabled = param.relu.enabled; + args.sb_address = conv_param->scaleBias.data(); + args.kernel.stride_h = param.strides[1]; + args.kernel.stride_w = param.strides[0]; + args.kernel.height = new_filter.shape().height(); + args.kernel.width = new_filter.shape().width(); + + args.filter_address = conv_param->filter.data(); + args.filter_num = filter_num; + args.filter_scale_address = conv_param->filter.scale(); + args.image.address = input->data(); + args.image.scale_address = input->scale(); + args.image.channels = input->shape().channel(); + args.image.width = input->shape().width(); + args.image.height = input->shape().height(); + args.image.pad_width = param.paddings[1]; + args.image.pad_height = param.paddings[0]; + args.output.address = out_address; + 
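+    // Descriptive note: each BasicConvParam built by this loop carries one
+    // slice of the filters plus its own formatted scale/bias block; when
+    // split_num == 1 the slice writes straight into the shared output,
+    // otherwise it writes into a per-slice tensor that is recombined later.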
args.output.scale_address = out_scale_address; + param.splitParams().push_back(conv_param); + } +} + +inline void split_channel(const ConvParam& c_param) { + ConvParam& param = const_cast(c_param); + Tensor* input = param.input; + Tensor* output = param.output; + input->syncToCPU(); + + int num = ceil(input->shape().channel() * 1.0f / 2047); + int channel = input->shape().channel() / num; + std::cout << "channel::" << channel << "num::" << num << std::endl; + Shape bs_shape(N, {channel}); + + for (int i = 0; i < num; i++) { + BasicConvParam* conv_param = new BasicConvParam(); + + // input && output; + Shape in_shape( + NCHW, {1, channel, input->shape().height(), input->shape().width()}); + conv_param->input.shareDataWith(input, in_shape, channel * i); + conv_param->output.mutableData(FP16, output->shape()); + + // filter transformation; + Shape f_shape(NCHW, {param.filter->shape().num(), channel, 1, 1}); + Tensor new_filter; + + float* dst = new_filter.mutableData(FP32, f_shape); + float* src = param.filter->data() + i * channel; + for (int n = 0; n < f_shape.num(); n++) { + memcpy(dst, src, channel * sizeof(float)); + dst += channel; + src += param.filter->shape().channel(); + } + new_filter.flush(); + format_filter(&new_filter, &(conv_param->filter), param.groups); + + Tensor bias; + Tensor scale; + + float* bias_data = bias.mutableData(FP32, bs_shape); + float* scale_data = scale.mutableData(FP32, bs_shape); + for (int c = 0; c < channel; c++) { + scale_data[c] = 1; + bias_data[c] = param.bias()->data()[c] / num; + } + scale.flush(); + bias.flush(); + format_scale_bias(&scale, + &bias, + &conv_param->filter, + &conv_param->scaleBias, + param.groups); + conv_param->scaleBias.flush(); + + ConvArgs& args = conv_param->args; + args.group_num = param.groups; + args.relu_enabled = param.relu.enabled; + args.sb_address = conv_param->scaleBias.data(); + args.kernel.stride_h = param.strides[1]; + args.kernel.stride_w = param.strides[0]; + args.kernel.height = new_filter.shape().height(); + args.kernel.width = new_filter.shape().width(); + + args.filter_address = conv_param->filter.data(); + args.filter_num = f_shape.num(); + args.filter_scale_address = conv_param->filter.scale(); + args.image.address = conv_param->input.mutableData(); + args.image.scale_address = conv_param->input.scale(); + + args.image.channels = conv_param->input.shape().channel(); + args.image.width = conv_param->input.shape().width(); + args.image.height = conv_param->input.shape().height(); + args.image.pad_width = param.paddings[1]; + args.image.pad_height = param.paddings[0]; + args.output.address = conv_param->output.mutableData(); + args.output.scale_address = conv_param->output.scale(); + param.splitParams().push_back(conv_param); + } +} + +inline int fill_split_arg(const ConvParam& c_param) { + ConvParam& param = const_cast(c_param); + Tensor* input = param.input; + Tensor* output = param.output; + if (output->shape().dimSize() == 4 && input->shape().channel() > 2047 && + input->shape().width() == 1) { + split_channel(c_param); + return 1; + } else { + split_filter_num(c_param); + return 0; + } +} + +inline bool compute_conv(const ConvParam& c_conv_params) { + ConvParam& conv_params = const_cast(c_conv_params); + std::vector& params = conv_params.splitParams(); + int ret = 0; + for (auto conv_param : params) { + ret |= compute_fpga_conv_basic(conv_param->args); + } + size_t size = params.size(); + if (ret == 0 && size > 1) { + Tensor& img = params[0]->output; + for (int i = 0; i < 1; i++) { + for (int i = 0; i < 
img.shape().numel(); i++) { + float value = half_to_float(img.data()[i]); + std::cout << "value:" << value << std::endl; + } + } + } + return ret == 0; +} + +} // namespace zynqmp +} // namespace paddle diff --git a/lite/fpga/KD/pes/crop_pe.cpp b/lite/fpga/KD/pes/crop_pe.cpp new file mode 100644 index 00000000000..48347c39152 --- /dev/null +++ b/lite/fpga/KD/pes/crop_pe.cpp @@ -0,0 +1,88 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "lite/fpga/KD/pes/crop_pe.hpp" + +#include + +namespace paddle { +namespace zynqmp { + +bool CropPE::dispatch() { + Tensor* input = param_.input; + input->syncToCPU(); + const auto axis = param_.axis; + std::vector shape = param_.shape; + auto* out = param_.output; + + Shape out_shape = out->shape(); + float16* src_ptr = reinterpret_cast(input->data()); + float16* dst_ptr = reinterpret_cast( + out->mutableData(DataType::FP16, out_shape)); + + std::vector offsets = param_.offsets; + + int input_c = input->shape().channel(); + int input_h = input->shape().height(); + int input_w = input->shape().width(); + + int out_c = out->shape().channel(); + int out_h = out->shape().height(); + int out_w = out->shape().width(); + if (axis == 1) { + int index = 0; + + int offset_h = offsets[0]; + int offset_w = offsets[0]; + int offset_c = offsets[0]; + + if (offsets.size() == 3) { + offset_h = offsets[1]; + offset_w = offsets[2]; + offset_c = offsets[0]; + } + + for (int h = 0; h < out_h; h++) { + for (int w = 0; w < out_w; w++) { + float16* crop_start = src_ptr + (h + offset_h) * input_w * input_c + + (offset_w * input_c) + offset_c; + std::memcpy(dst_ptr + h * (out_w * out_c) + w * out_c, + crop_start, + out_c * sizeof(float16)); + } + } + } else if (axis == 2) { + int offset_h = offsets[0]; + int offset_w = offsets[0]; + + if (offsets.size() == 2) { + offset_h = offsets[0]; + offset_w = offsets[1]; + } + + for (int h = 0; h < out_h; h++) { + float16* crop_start = + src_ptr + (h + offset_h) * input_w * input_c + (offset_w * input_c); + std::memcpy(dst_ptr + h * out_w * input_c, + crop_start, + out_w * input_c * sizeof(float16)); + } + } + out->flush(); + out->copyScaleFrom(input); + return true; +} + +} // namespace zynqmp +} // namespace paddle diff --git a/lite/fpga/KD/pes/crop_pe.hpp b/lite/fpga/KD/pes/crop_pe.hpp new file mode 100755 index 00000000000..01ae7b50003 --- /dev/null +++ b/lite/fpga/KD/pes/crop_pe.hpp @@ -0,0 +1,45 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include +#include + +#include "lite/fpga/KD/float16.hpp" +#include "lite/fpga/KD/pe.hpp" +#include "lite/fpga/KD/pe_params.hpp" + +namespace paddle { +namespace zynqmp { +class CropPE : public PE { + public: + bool init() { + Tensor* output = param_.output; + output->setAligned(true); + output->setDataLocation(CPU); + return true; + } + + void apply() {} + + bool dispatch(); + + CropParam& param() { return param_; } + + private: + CropParam param_; +}; +} // namespace zynqmp +} // namespace paddle diff --git a/lite/fpga/KD/pes/depthwise_conv_pe.hpp b/lite/fpga/KD/pes/depthwise_conv_pe.hpp new file mode 100755 index 00000000000..759ce8ecee2 --- /dev/null +++ b/lite/fpga/KD/pes/depthwise_conv_pe.hpp @@ -0,0 +1,102 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "lite/fpga/KD/float16.hpp" +#include "lite/fpga/KD/pe.hpp" +#include "lite/fpga/KD/pe_params.hpp" +#include "lite/fpga/KD/pes/conv_process.hpp" + +namespace paddle { +namespace zynqmp { + +class DepthwiseConvPE : public PE { + public: + bool init() { + Tensor* output = param_.output; + output->setAligned(true); + output->setDataLocation(Device); + return true; + } + + void apply() { + DepthwiseConvParam& param = param_; + Tensor* input = param.input; + Tensor* output = param.output; + int channel = output->shape().channel(); + + float* new_scale_data = param_.scale()->data(); + float* new_bias_data = param_.bias()->data(); + + float16* b_data = bias_.mutableData(FP16, param_.bias()->shape()); + for (int i = 0; i < channel; i++) { + b_data[i] = float_to_half(new_bias_data[i]); + } + bias_.flush(); + + Tensor* quantized_filter = param.quantizedFilter(); + quantized_filter->mutableData(FP16, param.filter->shape()); + format_dw_filter(param.filter, param.quantizedFilter(), new_scale_data); + + DWconvArgs args = {0}; + args.bias_address = b_data; + args.filter_address = param.quantizedFilter()->data(); + args.kernel.width = param.filter->shape().height(); + args.kernel.height = param.filter->shape().width(); + args.kernel.stride_w = param.strides[0]; + args.kernel.stride_h = param.strides[1]; + args.image.address = input->data(); + args.image.channels = input->shape().channel(); + args.image.height = input->shape().height(); + args.image.width = input->shape().width(); + args.image.pad_width = param.paddings[0]; + args.image.pad_height = param.paddings[1]; + args.image.scale_address = input->scale(); + args.output.address = output->data(); + args.output.scale_address = output->scale(); + args.out_width = param.output->shape().width(); + args.out_height = param.output->shape().height(); + args.sub_conv_num = 1; + param.args = args; + + inplace_.relu_enable = param_.relu.enabled; + inplace_.power_enable = false; + inplace_.normalize_enable = false; + } + + bool dispatch() { + param_.input->syncToDevice(); + if (param_.relu.enabled) { + 
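+      // Descriptive note: ReLU is fused as an FPGA "inplace" post-op; it is
+      // switched on just before the dwconv is issued and switched off right
+      // after, so the flag cannot leak into later dispatches.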
inplace_.relu_enable = param_.relu.enabled; + config_inplace(inplace_); + } + bool ret = compute_fpga_dwconv(param_.args) == 0; + if (param_.relu.enabled) { + inplace_.relu_enable = false; + config_inplace(inplace_); + } + return ret; + } + + DepthwiseConvParam& param() { return param_; } + + private: + DepthwiseConvParam param_; + Tensor bias_; + InplaceArgs inplace_ = {0}; +}; + +} // namespace zynqmp +} // namespace paddle diff --git a/lite/fpga/KD/pes/elementwise_add_pe.hpp b/lite/fpga/KD/pes/elementwise_add_pe.hpp new file mode 100755 index 00000000000..015d861c29d --- /dev/null +++ b/lite/fpga/KD/pes/elementwise_add_pe.hpp @@ -0,0 +1,81 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "lite/fpga/KD/pe.hpp" +#include "lite/fpga/KD/pe_params.hpp" + +namespace paddle { +namespace zynqmp { + +class ElementwiseAddPE : public PE { + public: + bool init() { + Tensor* output = param_.output; + output->setAligned(true); + output->setDataLocation(Device); + return true; + } + + void apply() { + Tensor* input0 = param_.inputs[0]; + Tensor* input1 = param_.inputs[1]; + Tensor* output = param_.output; + EWAddArgs args = {0}; + args.const0 = 0x3c00; + args.const1 = 0x3c00; // =1 + args.image0.address = input0->data(); + args.image0.channels = input0->shape().channel(); + args.image0.scale_address = input0->scale(); + args.image0.height = input0->shape().height(); + args.image0.width = input0->shape().width(); + args.image0.pad_height = 0; + args.image0.pad_width = 0; + args.image1.address = input1->data(); + args.image1.channels = input1->shape().channel(); + args.image1.scale_address = input1->scale(); + args.image1.height = input1->shape().height(); + args.image1.width = input1->shape().width(); + args.image1.pad_height = 0; + args.image1.pad_width = 0; + args.output.scale_address = output->scale(); + args.output.address = output->data(); + param_.ewargs = args; + } + + bool dispatch() { + param_.inputs[0]->syncToDevice(); + param_.inputs[1]->syncToDevice(); + InplaceArgs inplace_args = {0}; + if (param_.relu.enabled) { + inplace_args.relu_enable = true; + config_inplace(inplace_args); + } + compute_fpga_ewadd(param_.ewargs); + if (param_.relu.enabled) { + inplace_args.relu_enable = false; + config_inplace(inplace_args); + } + return true; + } + + ElementwiseAddParam& param() { return param_; } + + private: + ElementwiseAddParam param_; +}; + +} // namespace zynqmp +} // namespace paddle diff --git a/lite/fpga/KD/pes/fully_connected_pe.hpp b/lite/fpga/KD/pes/fully_connected_pe.hpp new file mode 100644 index 00000000000..ab02991c93c --- /dev/null +++ b/lite/fpga/KD/pes/fully_connected_pe.hpp @@ -0,0 +1,94 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include + +#include "lite/fpga/KD/pe.hpp" +#include "lite/fpga/KD/pe_params.hpp" +#include "lite/fpga/KD/pes/conv_pe.hpp" +#include "lite/fpga/KD/pes/conv_process.hpp" + +namespace paddle { +namespace zynqmp { + +class FullyConnectedPE : public PE { + public: + bool init() { + Tensor* output = param_.output; + output->setAligned(true); + output->setDataLocation(Device); + return true; + } + + void apply() { + ConvParam& convParam_ = convPE_.param(); + Tensor* input = param_.input; + convParam_.input = param_.input; + convParam_.output = param_.output; + convParam_.groups = 1; + convParam_.strides = {1, 1}; + convParam_.paddings = {0, 0}; + convParam_.kernelSize = {input->shape().width(), input->shape().height()}; + convParam_.dilations = {1, 1}; + + int num = param_.filter->shape().channel(); + int chw = param_.filter->shape().num(); + + int height = param_.input->shape().height(); + int width = param_.input->shape().width(); + int filter_channel = chw / height / width; + + int channel = param_.output->shape().channel(); + Shape shape(NCHW, {num, filter_channel, height, width}); + Tensor* conv_filter = new Tensor(); + float* new_filter_data = conv_filter->mutableData(FP32, shape); + float* filter_data = param_.filter->data(); + + for (int i = 0; i < num; i++) { + for (int j = 0; j < chw; j++) { + float scale = filter_data[j * num + i]; + new_filter_data[i * chw + j] = scale; + } + } + + conv_filter->flush(); + convParam_.filter = conv_filter; + + Shape sb_shape(N, {channel}); + float* scale_data = convParam_.scale()->mutableData(FP32, sb_shape); + float* bias_data = convParam_.bias()->mutableData(FP32, sb_shape); + + for (int i = 0; i < channel; i++) { + scale_data[i] = 1.0f; + bias_data[i] = param_.bias->data()[i]; + } + convParam_.scale()->flush(); + convParam_.bias()->flush(); + + convPE_.init(); + convPE_.apply(); + } + + bool dispatch() { return convPE_.dispatch(); } + + FullyConnectedParam& param() { return param_; } + + private: + FullyConnectedParam param_; + ConvPE convPE_; +}; +} // namespace zynqmp +} // namespace paddle diff --git a/lite/fpga/KD/pes/input_pe.hpp b/lite/fpga/KD/pes/input_pe.hpp new file mode 100755 index 00000000000..0c2629fc1ee --- /dev/null +++ b/lite/fpga/KD/pes/input_pe.hpp @@ -0,0 +1,54 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
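An aside on FullyConnectedPE::apply() above (illustrative, not part of the patch): it reinterprets the chw x num fully-connected weight matrix as num 1x1-convolution filters of length chw by transposing it. A standalone sketch of that reindexing:

  #include <vector>

  // w is laid out chw x num (as the FC op stores it); the result holds num
  // contiguous filters of length chw, matching the loop in apply().
  std::vector<float> fc_to_conv_filter(const std::vector<float>& w,
                                       int num, int chw) {
    std::vector<float> filter(static_cast<size_t>(num) * chw);
    for (int i = 0; i < num; i++) {
      for (int j = 0; j < chw; j++) {
        filter[i * chw + j] = w[j * num + i];
      }
    }
    return filter;
  }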
*/ + +#pragma once + +#include "lite/fpga/KD/pe.hpp" +#include "lite/fpga/KD/pe_params.hpp" +namespace paddle { +namespace zynqmp { + +class InputPE : public PE { + public: + bool init() { + Tensor* output = param_.output; + output->setAligned(true); + output->setDataLocation(Device); + return true; + } + + bool dispatch() { + Tensor* input = param_.input; + Tensor* output = param_.output; + + Tensor* src = input; + input->flush(); + Tensor half_tensor; + if (input->dataType() == DataType::FP32) { + half_tensor.mutableData(DataType::FP16, input->shape()); + half_tensor.copyFrom(input); + src = &half_tensor; + } + output->mutableData(); + src->alignImage(output, true); + return true; + } + + InputParam& param() { return param_; } + + private: + InputParam param_; +}; +} // namespace zynqmp +} // namespace paddle diff --git a/lite/fpga/KD/pes/norm_pe.hpp b/lite/fpga/KD/pes/norm_pe.hpp new file mode 100644 index 00000000000..ad009e0e964 --- /dev/null +++ b/lite/fpga/KD/pes/norm_pe.hpp @@ -0,0 +1,121 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include +#include + +#include "lite/fpga/KD/float16.hpp" +#include "lite/fpga/KD/pe.hpp" +#include "lite/fpga/KD/pe_params.hpp" + +namespace paddle { +namespace zynqmp { +class NormPE : public PE { + public: + bool init() { + Tensor* output = param_.output; + output->setAligned(true); + output->setDataLocation(Device); + return true; + } + + void apply() { + inplace_args_.relu_enable = false; + inplace_args_.power_enable = false; + inplace_args_.normalize_enable = true; + + Shape& input_shape = param_.input->shape(); + + norm_param_args_.channel = input_shape.channel(); + norm_param_args_.hight_width = input_shape.height() * input_shape.width(); + + float16* mid_data = + mid_out_.mutableData(FP16, param_.output->shape()); + + bypass_args_.input_data_type = DATA_TYPE_FP16; + bypass_args_.output_data_type = DATA_TYPE_FP16; + bypass_args_.input_layout_type = LAYOUT_HWC; + bypass_args_.output_layout_type = LAYOUT_HWC; + bypass_args_.image.address = param_.input->data(); + bypass_args_.image.scale_address = param_.input->scale(); + bypass_args_.image.channels = input_shape.channel(); + bypass_args_.image.height = input_shape.height(); + bypass_args_.image.width = input_shape.width(); + bypass_args_.output.address = mid_out_.data(); + bypass_args_.output.scale_address = mid_out_.scale(); + + norm_args_.input_image_address = mid_data; + norm_args_.image_width = input_shape.width(); + norm_args_.image_height = input_shape.height(); + norm_args_.image_channel = input_shape.channel(); + norm_args_.output_image_address = param_.output->data(); + norm_args_.output_scale_address = + reinterpret_cast(param_.output->scale()); + } + + void cpuCompute() { + Tensor input_float; + Tensor float_out; + input_float.mutableData(FP32, param_.input->shape()); + float_out.mutableData(FP32, param_.output->shape()); + + input_float.copyFrom(param_.input); + input_float.syncToCPU(); + + int channel = 
input_float.shape().channel(); + int height = input_float.shape().height(); + int width = input_float.shape().width(); + int cw = channel * width; + + Tensor* input = &input_float; + float* input_ptr = input->data(); + float* out_ptr = float_out.data(); + + int loop = height * width; + for (int i = 0; i < loop; i++) { + float sum = param_.epsilon; + for (int c = 0; c < channel; c++) { + float value = input_ptr[i * channel + c]; + sum += value * value; + } + float norm = sqrtf(sum); +#pragma omp parallel for + for (int c = 0; c < channel; c++) { + out_ptr[i * channel + c] = input_ptr[i * channel + c] / norm; + } + } + float_out.flush(); + param_.output->copyFrom(&float_out); + } + + bool dispatch() { + cpuCompute(); + return true; + } + + NormParam& param() { return param_; } + + private: + NormParam param_; + Tensor mid_out_; + InplaceArgs inplace_args_ = {0}; + NormalizeParameterArgs norm_param_args_ = {0}; + BypassArgs bypass_args_; + + NormalizeArgs norm_args_ = {0}; +}; +} // namespace zynqmp +} // namespace paddle diff --git a/lite/fpga/KD/pes/output_pe.hpp b/lite/fpga/KD/pes/output_pe.hpp new file mode 100644 index 00000000000..f4ce136852a --- /dev/null +++ b/lite/fpga/KD/pes/output_pe.hpp @@ -0,0 +1,53 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "lite/fpga/KD/pe.hpp" +#include "lite/fpga/KD/pe_params.hpp" + +namespace paddle { +namespace zynqmp { + +class OutputPE : public PE { + public: + bool init() { + Tensor* output = param_.output; + output->setAligned(false); + return true; + } + + bool dispatch() { + Tensor* input = param_.input; + Tensor* output = param_.output; + if (input->aligned()) { + Tensor tmp; + tmp.setAligned(true); + tmp.mutableData(FP16, input->shape()); + tmp.copyFrom(input); + tmp.unalignImage(); + output->copyFrom(&tmp); + } else { + output->copyFrom(input); + } + return true; + } + + OutputParam& param() { return param_; } + + private: + OutputParam param_; +}; +} // namespace zynqmp +} // namespace paddle diff --git a/lite/fpga/KD/pes/pooling_pe.hpp b/lite/fpga/KD/pes/pooling_pe.hpp new file mode 100644 index 00000000000..fb4c2924f0f --- /dev/null +++ b/lite/fpga/KD/pes/pooling_pe.hpp @@ -0,0 +1,176 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#pragma once + +#include + +#include "lite/fpga/KD/pe.hpp" +#include "lite/fpga/KD/pe_params.hpp" + +namespace paddle { +namespace zynqmp { + +class PoolingPE : public PE { + public: + bool init() { + Tensor* output = param_.output; + output->setAligned(true); + output->setDataLocation(Device); + return true; + } + + void apply() { + Tensor* input = param_.input; + Tensor* output = param_.output; + + uint32_t k_width = param_.kernelSize[0]; + uint32_t k_height = param_.kernelSize[1]; + + if (param_.globalPooling) { + k_width = input->shape().width(); + k_height = input->shape().height(); + } + + PoolingArgs args = {0}; + args.mode = param_.type; + args.kernel_reciprocal = fp32_2_fp16(1.0f / (k_width * k_height)); + args.image.address = input->data(); + args.image.channels = input->shape().channel(); + args.image.height = input->shape().height(); + args.image.width = input->shape().width(); + args.image.pad_height = param_.paddings[0]; + args.image.pad_width = param_.paddings[1]; + args.image.scale_address = input->scale(); + args.output.address = output->mutableData(); + args.output.scale_address = output->scale(); + args.kernel.height = k_height; + args.kernel.width = k_width; + args.kernel.stride_h = param_.strides[0]; + args.kernel.stride_w = param_.strides[1]; + args.out_height = output->shape().height(); + args.out_width = output->shape().width(); + param_.poolingArgs = args; + + use_cpu_ = output->shape().width() == 1 && output->shape().height() == 1 && + (k_width > 7 || k_height > 7); + } + + void compute() { + Tensor* input = param_.input; + Tensor* output = param_.output; + input->syncToCPU(); + + Tensor float_input; + float* image_addr = float_input.mutableData(FP32, input->shape()); + float_input.copyFrom(input); + float16* data_out = output->data(); + + int image_height = input->shape().height(); + int image_width = input->shape().width(); + int image_channels = input->shape().channel(); + int image_pad_h = param_.paddings[0]; + int image_pad_w = param_.paddings[1]; + int kernel_height = param_.kernelSize[1]; + int kernel_width = param_.kernelSize[0]; + int kernel_step_h = param_.strides[0]; + int kernel_step_w = param_.strides[1]; + + int pooled_height_ = output->shape().height(); + int pooled_width_ = output->shape().width(); + + int kernel = kernel_height * kernel_width; + + float max = 0; + + for (int ph = 0; ph < pooled_height_; ++ph) { + for (int pw = 0; pw < pooled_width_; ++pw) { + int hstart = ph * kernel_step_h - image_pad_h; + int wstart = pw * kernel_step_w - image_pad_w; + int hend = std::min(hstart + kernel_height, image_height); + int wend = std::min(wstart + kernel_width, image_width); + hstart = std::max(hstart, 0); + wstart = std::max(wstart, 0); + + kernel = (hend - hstart) * (wend - wstart); + for (int c = 0; c < image_channels; ++c) { + const int pool_index = (ph * pooled_width_ + pw) * image_channels + c; + float sum = 0; + for (int h = hstart; h < hend; ++h) { + for (int w = wstart; w < wend; ++w) { + const int index = (h * image_width + w) * image_channels + c; + float value = image_addr[index]; + sum += value; + } + } + float value = sum / kernel; + if (value > max) { + max = value; + } + data_out[pool_index] = float_to_half(value); + } + } + } + output->scale()[0] = max / 127.0f; + output->scale()[1] = 127.0f / max; + output->flush(); + } + + void cpu_compute() { + Tensor* input = param_.input; + Tensor* output = param_.output; + input->syncToCPU(); + + Tensor float_input; + float_input.mutableData(FP32, input->shape()); + 
float_input.copyFrom(input); + float16* data_out = output->data(); + + int kernel_hw = param_.kernelSize[0] * param_.kernelSize[1]; + + float scale_max = 0; + for (int i = 0; i < output->shape().channel(); i++) { + float sum = 0; + for (int j = 0; j < kernel_hw; j++) { + float value = half_to_float(input->data()[i * kernel_hw + j]); + sum += value; + } + float value = sum / kernel_hw; + data_out[i] = float_to_half(value); + scale_max = std::max(scale_max, std::abs(value)); + } + output->scale()[0] = scale_max / 127.0f; + output->scale()[1] = 127.0f / scale_max; + std::cout << "pool scale:" << scale_max / 127.0f << std::endl; + output->flush(); + } + + bool dispatch() { + if (use_cpu_) { + compute(); + return true; + } + param_.input->syncToDevice(); + return compute_fpga_pool(param_.poolingArgs) == 0; + } + + PoolingParam& param() { return param_; } + + private: + PoolingParam param_; + bool use_cpu_; +}; + +} // namespace zynqmp +} // namespace paddle diff --git a/lite/fpga/KD/pes/prior_box_pe.cpp b/lite/fpga/KD/pes/prior_box_pe.cpp new file mode 100644 index 00000000000..d4b90b5240c --- /dev/null +++ b/lite/fpga/KD/pes/prior_box_pe.cpp @@ -0,0 +1,273 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include +#include +#include + +#include "lite/fpga/KD/pes/prior_box_pe.hpp" + +namespace paddle { +namespace zynqmp { + +struct Transform { + template + void operator()(InputIter first, + InputIter last, + OutputIter result, + UnaryOperation op) { + std::transform(first, last, result, op); + } + + template + void operator()(InputIter1 first1, + InputIter1 last1, + InputIter2 first2, + OutputIter result, + BinaryOperation op) { + std::transform(first1, last1, first2, result, op); + } +}; + +inline void ExpandAspectRatios(const std::vector &input_aspect_ratior, + bool flip, + std::vector *output_aspect_ratior) { + constexpr float epsilon = 1e-6; + output_aspect_ratior->clear(); + output_aspect_ratior->push_back(1.0f); + for (size_t i = 0; i < input_aspect_ratior.size(); ++i) { + float ar = input_aspect_ratior[i]; + bool already_exist = false; + for (size_t j = 0; j < output_aspect_ratior->size(); ++j) { + if (fabs(ar - output_aspect_ratior->at(j)) < epsilon) { + already_exist = true; + break; + } + } + if (!already_exist) { + output_aspect_ratior->push_back(ar); + if (flip) { + output_aspect_ratior->push_back(1.0f / ar); + } + } + } +} + +template +struct ClipFunctor { + inline T operator()(T in) const { + return std::min(std::max(in, 0.), 1.); + } +}; + +void PriorBoxPE::compute_prior_box() { + PriorBoxParam ¶m = param_; + Tensor *input = param.input; + Shape &input_shape = input->shape(); + + Tensor *input_image = param.image; + Shape &image_shape = input_image->shape(); + + const auto &min_sizes = param.minSizes; + const auto &max_sizes = param.maxSizes; + const auto &input_aspect_ratio = param.aspectRatios; + const bool &flip = param.flip; + const bool &clip = param.clip; + const float &step_w = param.stepW; + const float 
&step_h = param.stepH; + const float &offset = param.offset; + + Tensor *output_boxes = this->cachedBoxes_; + Tensor *output_variances = this->cachedVariances_; + + Tensor boxes; + Tensor variances; + + float *output_boxes_dataptr = + boxes.mutableData(FP32, output_boxes->shape()); + memset(output_boxes_dataptr, 0, boxes.memorySize()); + float *output_variances_dataptr = + variances.mutableData(FP32, output_boxes->shape()); + + std::vector aspect_ratios; + ExpandAspectRatios(input_aspect_ratio, flip, &aspect_ratios); + + auto img_width = image_shape.width(); + auto img_height = image_shape.height(); + auto feature_width = input_shape.width(); + auto feature_height = input_shape.height(); + + auto stride0 = output_boxes->shape().channel() * + output_boxes->shape().height() * output_boxes->shape().width(); + auto stride1 = output_boxes->shape().height() * output_boxes->shape().width(); + auto stride2 = output_boxes->shape().width(); + + float step_width = step_w; + float step_height = step_h; + if (step_w == 0 || step_h == 0) { + step_width = static_cast(img_width) / feature_width; + step_height = static_cast(img_height) / feature_height; + } + + int num_priors = aspect_ratios.size() * min_sizes.size(); + if (!max_sizes.empty()) { + num_priors += max_sizes.size(); + } + + for (int h = 0; h < feature_height; ++h) { + for (int w = 0; w < feature_width; ++w) { + /// map origin image + float center_x = (w + offset) * step_width; + float center_y = (h + offset) * step_height; + float box_width, box_height; + int idx = 0; + for (size_t s = 0; s < min_sizes.size(); ++s) { + auto min_size = min_sizes[s]; + if (param.minMaxAspectRatiosOrder) { + box_width = box_height = min_size / 2.; + output_boxes_dataptr[h * stride0 + w * stride1 + idx * stride2 + 0] = + (center_x - box_width) / img_width; + output_boxes_dataptr[h * stride0 + w * stride1 + idx * stride2 + 1] = + (center_y - box_height) / img_height; + output_boxes_dataptr[h * stride0 + w * stride1 + idx * stride2 + 2] = + (center_x + box_width) / img_width; + output_boxes_dataptr[h * stride0 + w * stride1 + idx * stride2 + 3] = + (center_y + box_height) / img_height; + idx++; + + if (max_sizes.size() > 0) { + auto max_size = max_sizes[s]; + // square prior with size sqrt(minSize * maxSize) + box_width = box_height = sqrt(min_size * max_size) / 2.; + output_boxes_dataptr[h * stride0 + w * stride1 + idx * stride2 + + 0] = (center_x - box_width) / img_width; + output_boxes_dataptr[h * stride0 + w * stride1 + idx * stride2 + + 1] = (center_y - box_height) / img_height; + output_boxes_dataptr[h * stride0 + w * stride1 + idx * stride2 + + 2] = (center_x + box_width) / img_width; + output_boxes_dataptr[h * stride0 + w * stride1 + idx * stride2 + + 3] = (center_y + box_height) / img_height; + + idx++; + } + + // priors with different aspect ratios + for (float ar : aspect_ratios) { + if (fabs(ar - 1.) 
< 1e-6) { + continue; + } + box_width = min_size * sqrt(ar) / 2.; + box_height = min_size / sqrt(ar) / 2.; + output_boxes_dataptr[h * stride0 + w * stride1 + idx * stride2 + + 0] = (center_x - box_width) / img_width; + output_boxes_dataptr[h * stride0 + w * stride1 + idx * stride2 + + 1] = (center_y - box_height) / img_height; + output_boxes_dataptr[h * stride0 + w * stride1 + idx * stride2 + + 2] = (center_x + box_width) / img_width; + output_boxes_dataptr[h * stride0 + w * stride1 + idx * stride2 + + 3] = (center_y + box_height) / img_height; + + idx++; + } + + } else { + // priors with different aspect ratios + for (float ar : aspect_ratios) { + box_width = min_size * sqrt(ar) / 2.; + box_height = min_size / sqrt(ar) / 2.; + output_boxes_dataptr[h * stride0 + w * stride1 + idx * stride2 + + 0] = (center_x - box_width) / img_width; + output_boxes_dataptr[h * stride0 + w * stride1 + idx * stride2 + + 1] = (center_y - box_height) / img_height; + output_boxes_dataptr[h * stride0 + w * stride1 + idx * stride2 + + 2] = (center_x + box_width) / img_width; + output_boxes_dataptr[h * stride0 + w * stride1 + idx * stride2 + + 3] = (center_y + box_height) / img_height; + idx++; + } + if (!max_sizes.empty()) { + auto max_size = max_sizes[s]; + // square prior with size sqrt(minSize * maxSize) + box_width = box_height = sqrt(min_size * max_size) / 2.; + output_boxes_dataptr[h * stride0 + w * stride1 + idx * stride2 + + 0] = (center_x - box_width) / img_width; + output_boxes_dataptr[h * stride0 + w * stride1 + idx * stride2 + + 1] = (center_y - box_height) / img_height; + output_boxes_dataptr[h * stride0 + w * stride1 + idx * stride2 + + 2] = (center_x + box_width) / img_width; + output_boxes_dataptr[h * stride0 + w * stride1 + idx * stride2 + + 3] = (center_y + box_height) / img_height; + idx++; + } + } + } + } + } + if (clip) { + for (int i = 0; i < output_boxes->shape().numel(); i++) { + float value = output_boxes_dataptr[i]; + value = std::min(std::max(0.0f, value), 1.0f); + output_boxes_dataptr[i] = value; + } + } + + if ((param.variances.size() != 4)) { + // TODO(chonwhite) throw error; + } + + int64_t box_num = feature_height * feature_width * num_priors; + + for (int i = 0; i < box_num; i++) { + output_variances_dataptr[4 * i] = param.variances[0]; + output_variances_dataptr[4 * i + 1] = param.variances[1]; + output_variances_dataptr[4 * i + 2] = param.variances[2]; + output_variances_dataptr[4 * i + 3] = param.variances[3]; + } + + boxes.flush(); + boxes.syncToCPU(); + variances.flush(); + output_boxes->copyFrom(&boxes); + output_variances->copyFrom(&variances); +} + +void PriorBoxPE::apply() {} + +bool PriorBoxPE::dispatch() { + if (cachedBoxes_ == nullptr) { + cachedBoxes_ = new Tensor(); + cachedVariances_ = new Tensor(); + cachedBoxes_->mutableData(FP16, param_.outputBoxes->shape()); + cachedVariances_->mutableData(FP16, + param_.outputVariances->shape()); + cachedBoxes_->setDataLocation(CPU); + cachedVariances_->setDataLocation(CPU); + compute_prior_box(); + } + + param_.outputBoxes->copyFrom(this->cachedBoxes_); + + param_.outputVariances->copyFrom(this->cachedVariances_); + param_.outputBoxes->flush(); + param_.outputBoxes->syncToCPU(); + param_.outputVariances->flush(); +} + +} // namespace zynqmp +} // namespace paddle diff --git a/lite/fpga/KD/pes/prior_box_pe.hpp b/lite/fpga/KD/pes/prior_box_pe.hpp new file mode 100755 index 00000000000..ca5382687c9 --- /dev/null +++ b/lite/fpga/KD/pes/prior_box_pe.hpp @@ -0,0 +1,46 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. 
All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "lite/fpga/KD/pe.hpp" +#include "lite/fpga/KD/pe_params.hpp" +namespace paddle { +namespace zynqmp { + +class PriorBoxPE : public PE { + public: + bool init() { + param_.outputBoxes->setAligned(false); + param_.outputVariances->setAligned(false); + param_.outputBoxes->setDataLocation(CPU); + param_.outputVariances->setDataLocation(CPU); + return true; + } + + bool dispatch(); + + void apply(); + + PriorBoxParam& param() { return param_; } + + private: + PriorBoxParam param_; + Tensor* cachedBoxes_ = nullptr; + Tensor* cachedVariances_ = nullptr; + + void compute_prior_box(); +}; +} // namespace zynqmp +} // namespace paddle diff --git a/lite/fpga/KD/pes/relu_pe.hpp b/lite/fpga/KD/pes/relu_pe.hpp new file mode 100755 index 00000000000..0b3a0868dc4 --- /dev/null +++ b/lite/fpga/KD/pes/relu_pe.hpp @@ -0,0 +1,75 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#pragma once + +#include "lite/fpga/KD/pe.hpp" +#include "lite/fpga/KD/pe_params.hpp" +namespace paddle { +namespace zynqmp { + +class ReluPE : public PE { + public: + bool init() { + Tensor* output = param_.output; + output->setAligned(true); + output->setDataLocation(Device); + return true; + } + + void apply() { + Tensor* src = param_.input; + + args_.input_data_type = DATA_TYPE_FP16; + args_.output_data_type = DATA_TYPE_FP16; + args_.input_layout_type = LAYOUT_HWC; + args_.output_layout_type = LAYOUT_HWC; + args_.image = {.address = src->data(), + .scale_address = src->scale(), + .channels = (uint32_t)src->shape().channel(), + .width = (uint32_t)src->shape().width(), + .height = (uint32_t)src->shape().height(), + .pad_width = 0u, + .pad_height = 0u}; + args_.output = { + .address = param_.output->data(), + .scale_address = param_.output->scale(), + }; + + inplace_.relu_enable = false; + inplace_.power_enable = false; + inplace_.normalize_enable = false; + } + + bool dispatch() { + inplace_.relu_enable = true; + config_inplace(inplace_); + param_.input->syncToDevice(); + param_.output->copyFrom(param_.input); + param_.output->invalidate(); + inplace_.relu_enable = false; + config_inplace(inplace_); + return true; + } + + InputParam& param() { return param_; } + + private: + InputParam param_; + BypassArgs args_; + InplaceArgs inplace_; +}; + +} // namespace zynqmp +} // namespace paddle diff --git a/lite/fpga/KD/pes/resize.hpp b/lite/fpga/KD/pes/resize.hpp new file mode 100644 index 00000000000..370932dba9e --- /dev/null +++ b/lite/fpga/KD/pes/resize.hpp @@ -0,0 +1,89 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#pragma once + +#include "lite/fpga/KD/pe.hpp" +#include "lite/fpga/KD/pe_params.hpp" + +namespace paddle { +namespace zynqmp { +class ResizePE : public PE { + public: + bool init() { + Tensor* output = param_.output; + output->setAligned(true); + output->setDataLocation(Device); + return true; + } + + void apply() { + Tensor* input = param_.input; + Tensor* output = param_.output; + ResizeArgs& args = args_; + + int input_width = input->shape().width(); + int input_height = input->shape().height(); + int input_channel = input->shape().channel(); + + int output_width = output->shape().width(); + int output_height = output->shape().height(); + + args.input_width = input_width; + args.input_height = input_height; + args.image_channel = input_channel; + args.output_width = output_width; + args.output_height = output_height; + float height_ratio = static_cast(input_height) / + static_cast(args.output_height); + float width_ratio = + static_cast(input_width) / static_cast(args.output_width); + args.height_ratio = *reinterpret_cast(&height_ratio); + args.width_ratio = *reinterpret_cast(&width_ratio); + + args.input_image_address = input->mutableData(); + args.output_image_address = output->mutableData(); + args.output_scale_address = reinterpret_cast(output->scale()); + } + + void compute_scale(Tensor* src, float* scale) { + float16* data = src->data(); + src->invalidate(); + float max = 0; + for (int i = 0; i < src->shape().numel(); i++) { + float value = half_to_float(data[i]); + if (value < 0) { + value = -value; + } + if (value > max) { + max = value; + } + } + scale[0] = max / 127.0; + scale[1] = 127.0 / max; + } + + bool dispatch() { + bool ret = compute_fpga_resize(args_) == 0; + return true; + } + + ResizeParam& param() { return param_; } + + private: + ResizeParam param_; + ResizeArgs args_; +}; +} // namespace zynqmp +} // namespace paddle diff --git a/lite/fpga/KD/pes/scale_pe.hpp b/lite/fpga/KD/pes/scale_pe.hpp new file mode 100755 index 00000000000..664a4db2e99 --- /dev/null +++ b/lite/fpga/KD/pes/scale_pe.hpp @@ -0,0 +1,120 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#pragma once + +#include "lite/fpga/KD/pe.hpp" +#include "lite/fpga/KD/pe_params.hpp" + +namespace paddle { +namespace zynqmp { +class ScalePE : public PE { + public: + inline int gcd(int a, int b) { + while (b) { + int temp = a; + a = b; + b = temp % b; + } + return a; + } + + inline int lcm(int a, int b) { return a * b / gcd(a, b); } + bool init() { + Tensor* output = param_.output; + output->setAligned(true); + output->setDataLocation(Device); + return true; + } + + void apply() { + Tensor* input = param_.input; + Tensor* output = param_.output; + Shape& input_shape = input->shape(); + int channel = input_shape.channel(); + int repeat = 1; + int alignment = 16; + int length = channel; + + if (channel % alignment != 0 || channel < alignment) { + int c_lcm = lcm(channel, alignment); + repeat = c_lcm / (channel); + } + Shape shape(N, {channel * repeat}); + param_.alignedBias()->mutableData(FP16, shape); + param_.alignedScale()->mutableData(FP16, shape); + + float16* bias_data = param_.alignedBias()->data(); + float16* scale_data = param_.alignedScale()->data(); + + if (param_.bias != nullptr) { + float* bias_data_float = param_.bias->data(); + for (int i = 0; i < repeat; i++) { + for (int j = 0; j < length; j++) { + float16 value = float_to_half(bias_data_float[j]); + bias_data[i * length + j] = value; + } + } + } else { + float16 zero = float_to_half(0.0f); + for (int i = 0; i < repeat; i++) { + for (int j = 0; j < length; j++) { + bias_data[i * length + j] = zero; + } + } + } + + float* scale_data_float = param_.scale->data(); + for (int i = 0; i < repeat; i++) { + for (int j = 0; j < length; j++) { + float16 value = float_to_half(scale_data_float[j]); + scale_data[i * length + j] = value; + } + } + + param_.alignedScale()->flush(); + param_.alignedBias()->flush(); + + int wc = input_shape.width() * input_shape.channel(); + int wc_aligned = align_image(wc); + + ScaleArgs& args = param_.args; + args.scale_address = param_.alignedScale()->data(); + args.bias_address = param_.alignedBias()->data(); + args.wc_alignment = wc_aligned; + args.channel_alignment = channel * repeat; + + args.image.address = input->data(); + args.image.scale_address = input->scale(); + args.image.channels = channel; + args.image.height = input_shape.height(); + args.image.width = input_shape.width(); + args.image.pad_width = 0; + args.image.pad_height = 0; + args.output.address = output->data(); + args.output.scale_address = output->scale(); + } + + bool dispatch() { + param_.input->syncToDevice(); + return compute_fpga_scale(param_.args) == 0; + } + + ScaleParam& param() { return param_; } + + private: + ScaleParam param_; +}; +} // namespace zynqmp +} // namespace paddle diff --git a/lite/fpga/KD/pes/softmax_pe.cpp b/lite/fpga/KD/pes/softmax_pe.cpp new file mode 100755 index 00000000000..8e3296ffa0c --- /dev/null +++ b/lite/fpga/KD/pes/softmax_pe.cpp @@ -0,0 +1,162 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#include "lite/fpga/KD/pes/softmax_pe.hpp" + +#include + +namespace paddle { +namespace zynqmp { + +#if defined(__ARM_NEON) || defined(__ARM_NEON__) +#ifndef __aarch64__ +static inline float32_t vmaxvq_f32(const float32x4_t &r) { + float32x2_t v = vmax_f32(vget_high_f32(r), vget_low_f32(r)); + return vget_lane_f32(vpmax_f32(v, v), 0); +} + +static inline float32_t vaddvq_f32(const float32x4_t &r) { + float32x2_t v = vadd_f32(vget_high_f32(r), vget_low_f32(r)); + return vget_lane_f32(vpadd_f32(v, v), 0); +} +#endif // __aarch64__ +#endif // __ARM_NEON__ + +static float find_max(const float *input, const int num_classes) { + int remain = num_classes; + float max = -std::numeric_limits::max(); +#if defined(__ARM_NEON) || defined(__ARM_NEON__) + int loop = num_classes >> 3; + remain = num_classes & 0x7; + float32x4_t __max = vdupq_n_f32(max); + for (int i = 0; i < loop; ++i, input += 8) { + float32x4_t x0 = vld1q_f32(input); + float32x4_t x1 = vld1q_f32(input + 4); + __max = vmaxq_f32(x0, __max); + __max = vmaxq_f32(x1, __max); + } + max = vmaxvq_f32(__max); +#endif + for (int i = 0; i < remain; ++i) { + max = std::max(max, input[i]); + } + return max; +} + +static void softmax(Tensor *X, Tensor *Y) { + std::vector dims = X->shape().dims(); + int batch_size = X->shape().num(); + int num_classes = dims[X->shape().dimSize() - 1]; + int channels = X->shape().numel() / batch_size / num_classes; + float *x = X->data(); + float *y = Y->mutableData(); + +#pragma omp parallel for collapse(2) + for (int batch = 0; batch < batch_size; ++batch) { + for (int channel = 0; channel < channels; ++channel) { + size_t offset = (batch * channels + channel) * num_classes; + const float *input = x + offset; + float *output = y + offset; + // find max + float max = find_max(input, num_classes); + + // exp(x - max) + int remain = num_classes; +#if defined(__ARM_NEON) || defined(__ARM_NEON__) + int loop = num_classes >> 3; + remain = num_classes & 0x7; + float32x4_t __max = vdupq_n_f32(max); + for (int i = 0; i < loop; ++i, input += 8, output += 8) { + float32x4_t x0 = vld1q_f32(input); + float32x4_t x1 = vld1q_f32(input + 4); + x0 = vsubq_f32(x0, __max); + x1 = vsubq_f32(x1, __max); + x0 = lite::arm::math::exp_ps(x0); + x1 = lite::arm::math::exp_ps(x1); + vst1q_f32(output, x0); + vst1q_f32(output + 4, x1); + } +#endif // __ARM_NEON__ + for (int i = 0; i < remain; ++i) { + output[i] = expf(input[i] - max); + } + + // sum(exp(x - max)) + float sum = 0.f; + output = y + offset; +#if defined(__ARM_NEON) || defined(__ARM_NEON__) + float32x4_t __sum = vdupq_n_f32(0.f); + for (int i = 0; i < loop; ++i, output += 8) { + float32x4_t x0 = vld1q_f32(output); + float32x4_t x1 = vld1q_f32(output + 4); + __sum = vaddq_f32(x0, __sum); + __sum = vaddq_f32(x1, __sum); + } + sum += vaddvq_f32(__sum); +#endif // __ARM_NEON__ + for (int i = 0; i < remain; ++i) { + sum += output[i]; + } + + // exp(x - max) / sum + float inv_sum = 1.f / sum; + output = y + offset; +#if defined(__ARM_NEON) || defined(__ARM_NEON__) + float32x4_t __inv_sum = vdupq_n_f32(inv_sum); + for (int i = 0; i < loop; ++i, output += 8) { + float32x4_t x0 = vld1q_f32(output); + float32x4_t x1 = vld1q_f32(output + 4); + x0 = vmulq_f32(x0, __inv_sum); + x1 = vmulq_f32(x1, __inv_sum); + vst1q_f32(output, x0); + vst1q_f32(output + 4, x1); + } +#endif + for (int i = 0; i < remain; ++i) { + output[i] *= inv_sum; + } + } + } +} + +bool SoftmaxPE::init() { + Tensor *output = param_.output; + output->setAligned(false); + output->setDataLocation(CPU); + return true; +} + 
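+// Editor's note, not part of the original patch: softmax() above implements
+// the numerically stable form softmax(x)_i = exp(x_i - max(x)) / sum_j
+// exp(x_j - max(x)). The NEON path consumes eight lanes per iteration and the
+// scalar tail loop covers the remaining (num_classes & 0x7) elements. Worked
+// check for x = {1, 2, 3}: max = 3, exp(x - max) = {0.135, 0.368, 1.000},
+// sum = 1.503, so the outputs are {0.090, 0.245, 0.665}.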
+
+bool SoftmaxPE::dispatch() {
+  Tensor *input = param_.input;
+  Tensor *output = param_.output;
+  input->syncToCPU();
+
+  Tensor float_input;
+  Tensor float_output;
+  float_input.mutableData<float>(DataType::FP32, input->shape());
+  float_input.copyFrom(input);
+
+  float *out_data =
+      float_output.mutableData<float>(DataType::FP32, input->shape());
+
+  softmax(&float_input, &float_output);
+  float_output.flush();
+
+  output->copyFrom(&float_output);
+  return true;
+}
+
+SoftmaxParam &SoftmaxPE::param() { return param_; }
+}  // namespace zynqmp
+}  // namespace paddle
diff --git a/lite/fpga/KD/pes/softmax_pe.hpp b/lite/fpga/KD/pes/softmax_pe.hpp
new file mode 100644
index 00000000000..6ac8e6bb975
--- /dev/null
+++ b/lite/fpga/KD/pes/softmax_pe.hpp
@@ -0,0 +1,44 @@
+/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+
+#include <math.h>
+#include <algorithm>
+#include <limits>
+
+#if defined(__ARM_NEON) || defined(__ARM_NEON__)
+#include <arm_neon.h>
+#include "lite/arm/math/funcs.h"
+#endif
+
+#include "lite/fpga/KD/pe.hpp"
+#include "lite/fpga/KD/pe_params.hpp"
+
+namespace paddle {
+namespace zynqmp {
+
+class SoftmaxPE : public PE {
+ public:
+  bool init();
+  bool dispatch();
+
+  SoftmaxParam& param();
+
+ private:
+  SoftmaxParam param_;
+};
+
+}  // namespace zynqmp
+}  // namespace paddle
diff --git a/lite/fpga/KD/pes/split_pe.hpp b/lite/fpga/KD/pes/split_pe.hpp
new file mode 100644
index 00000000000..074cc534e7e
--- /dev/null
+++ b/lite/fpga/KD/pes/split_pe.hpp
@@ -0,0 +1,124 @@
+/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
*/ + +#pragma once + +#include + +#include "lite/fpga/KD/pe.hpp" +#include "lite/fpga/KD/pe_params.hpp" +namespace paddle { +namespace zynqmp { + +class SplitPE : public PE { + public: + bool init() { + std::vector outputs = param_.outputs; + for (size_t i = 0; i < outputs.size(); i++) { + Tensor* out = outputs[i]; + out->setAligned(false); + out->setDataLocation(CPU); + } + return true; + } + + std::vector stride_numel(std::vector ddim) { + std::vector strides(ddim.size()); + strides[ddim.size() - 1] = ddim[ddim.size() - 1]; + for (int i = ddim.size() - 2; i >= 0; --i) { + strides[i] = strides[i + 1] * ddim[i]; + } + return strides; + } + + template + inline void StridedNumelCopyWithAxis(int64_t axis, + T* dst, + const std::vector& dst_stride_numel, + T* src, + const std::vector& src_stride_numel, + int64_t size) { + int64_t before = dst_stride_numel[0] / dst_stride_numel[axis]; + int64_t src_after = src_stride_numel[axis]; + int64_t dst_after = dst_stride_numel[axis]; + + for (int64_t i = 0; i < axis; ++i) { + if (i < axis) { + } else if (i == axis) { + continue; + } else { + } + } + + for (int64_t i = 0; i < before; ++i) { + memory::Copy(dst + i * dst_after, src + i * src_after, sizeof(T) * size); + } + } + + void split3D() { int axis = param_.axis; } + + bool dispatch() { + Tensor* input = param_.input; + input->syncToCPU(); + if (input->shape().dimSize() <= 3) { + auto in_stride = stride_numel(input->shape().dims()); + int64_t axis = param_.axis; + size_t input_offset = 0; + float16* in_data = input->data(); + + for (auto& out : param_.outputs) { + float16* out_data = out->mutableData(); + auto out_stride = stride_numel(out->shape().dims()); + + StridedNumelCopyWithAxis(axis, + out_data, + out_stride, + in_data + input_offset, + in_stride, + out_stride[axis]); + input_offset += out_stride[axis]; + } + return true; + } + + std::vector outputs = param_.outputs; + + int in_channel = input->shape().channel(); + int split_channel = input->shape().channel() / param_.num; + int hw = input->shape().height() * input->shape().width(); + + float16* in_data = input->data(); + for (int i = 0; i < hw; i++) { + for (int n = 0; n < outputs.size(); n++) { + Tensor* out = outputs[n]; + float16* out_data = out->data(); + memcpy(out_data + i * split_channel, + in_data + i * in_channel + n * split_channel, + split_channel * sizeof(float16)); + } + } + for (int n = 0; n < outputs.size(); n++) { + Tensor* out = outputs[n]; + out->copyScaleFrom(input); + } + return true; + } + + SplitParam& param() { return param_; } + + private: + SplitParam param_; +}; +} // namespace zynqmp +} // namespace paddle diff --git a/lite/fpga/KD/shape.hpp b/lite/fpga/KD/shape.hpp new file mode 100755 index 00000000000..12be3ac4853 --- /dev/null +++ b/lite/fpga/KD/shape.hpp @@ -0,0 +1,116 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#pragma once + +#include +#include + +#include "lite/fpga/KD/alignment.h" +#include "lite/fpga/KD/layout.hpp" + +namespace paddle { +namespace zynqmp { + +static struct NCHW nchw_; +static struct NHWC nhwc_; +static struct NC nc_; +static struct NHW nhw_; +static struct N n_; + +class Shape { + public: + explicit Shape(std::vector dims) { dims_ = dims; } + + Shape(LayoutType type, std::vector dims) { + dims_ = dims; + setLayoutType(type); + } + + Shape(const Shape& src) { + dims_ = src.dims_; + setLayoutType(src.layoutType_); + } + + bool shouldAlign() { + return layout_->alignedElementCount(dims_) != layout_->elementCount(dims_); + } + + int num() { + int index = layout_->numIndex(); + return index == -1 ? 1 : dims_[index]; + } + + int channel() { + int index = layout_->channelIndex(); + return index == -1 ? 1 : dims_[index]; + } + + int height() { + int index = layout_->heightIndex(); + return index == -1 ? 1 : dims_[index]; + } + + int width() { + int index = layout_->widthIndex(); + return index == -1 ? 1 : dims_[index]; + } + + int dimSize() { return dims_.size(); } + + std::vector dims() { return dims_; } + + size_t memorySize(int cellSize) { + return layout_->alignedElementCount(dims_) * cellSize; + } + + int numel() { return layout_->elementCount(dims_); } + + int alignedElementCount() { return layout_->alignedElementCount(dims_); } + + void setLayoutType(LayoutType layout) { + this->layoutType_ = layout; + switch (layout) { + case NCHW: + layout_ = &nchw_; + break; + case NHWC: + layout_ = &nhwc_; + break; + case NC: + layout_ = &nc_; + break; + case NHW: + layout_ = &nhw_; + break; + case N: + layout_ = &n_; + break; + default: + break; + } + } + + void print() {} + + int operator[](int index) { return dims_[index]; } + + private: + LayoutType layoutType_; + Layout* layout_ = &nhwc_; + std::vector dims_; +}; + +} // namespace zynqmp +} // namespace paddle diff --git a/lite/fpga/KD/tensor.hpp b/lite/fpga/KD/tensor.hpp new file mode 100644 index 00000000000..4b93c4671d1 --- /dev/null +++ b/lite/fpga/KD/tensor.hpp @@ -0,0 +1,456 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +// #include "lite/core/tensor.h" + +#include "lite/fpga/KD/dl_engine.hpp" +#include "lite/fpga/KD/float16.hpp" +#include "lite/fpga/KD/llapi/zynqmp_api.h" +#include "lite/fpga/KD/shape.hpp" +// #include "lite/fpga/KD/types.hpp" + +namespace paddle { +namespace zynqmp { + +enum DataType : int { + FP32 = 0, + FP16 = 1, + INT8 = 2, + INT32 = 3, +}; + +enum DataSyncStatus : int { + Synched = 0, + Device = 1, + CPU = 2, +}; + +typedef uint16_t float16; + +inline int CellSize(DataType type) { + switch (type) { + case FP32: + return sizeof(float); + case FP16: + return sizeof(float16); + case INT32: + return sizeof(int32_t); + case INT8: + return sizeof(int8_t); + default: + return 0; + } + return 0; +} + +class PlaceHolder { + public: + PlaceHolder() {} + explicit PlaceHolder(size_t size) { + size_ = size; + data_ = fpga_malloc(size_); + } + + void* data() { return data_; } + void set_data(const void* ptr) { data_ = const_cast(ptr); } + + size_t memorySize() { return size_; } + void set_size(size_t new_size) { size_ = new_size; } + + ~PlaceHolder() { fpga_free(data_); } + + float scale_[2]; + + private: + void* data_ = nullptr; + size_t size_ = 0; +}; + +class Tensor { + public: + Tensor() { DLEngine::get_instance(); } + + int id() { return id_; } + + template + Dtype* data() { + if (placeHolder_ == nullptr) { + return nullptr; + } + void* ptr = reinterpret_cast(this->placeHolder_->data()) + + offset * CellSize(dataType_); + return reinterpret_cast(ptr); + } + + template + Dtype* mutableData(DataType dataType, const Shape& shape) { + if (this->shape_ != nullptr) { + delete shape_; + } + this->shape_ = new Shape(shape); + this->dataType_ = dataType; + return mutableData(); + } + + template + Dtype* mutableData() { + size_t memorySize = shape_->memorySize(CellSize(dataType_)); + if (placeHolder_ != nullptr) { + if (memorySize > placeHolder_->memorySize()) { + placeHolder_.reset(new PlaceHolder(memorySize)); + } + } else { + placeHolder_.reset(new PlaceHolder(memorySize)); + } + return data(); + } + + size_t memorySize() { + if (placeHolder_ == nullptr) { + return 0; + } + return placeHolder_->memorySize(); + } + + void setDataType(DataType dataType) { this->dataType_ = dataType; } + + DataType dataType() { return this->dataType_; } + + Shape& shape() { return *shape_; } + + bool aligned() { return this->aligned_; } + + void setAligned(bool aligned) { this->aligned_ = aligned; } + + float* scale() { return placeHolder_->scale_; } + + void alignImage(Tensor* dst = nullptr, bool copy = false) { + if (shape_->shouldAlign()) { + int cell_size = CellSize(this->dataType_); + char* dst_data = nullptr; + size_t mem_size = shape_->memorySize(cell_size); + if (dst == nullptr) { + dst_data = reinterpret_cast(fpga_malloc(mem_size)); + } else { + dst_data = dst->data(); + } + int wc = shape_->width() * shape_->channel(); + int wc_aligned = align_image(wc); + int remainder = wc_aligned - wc; + + char* src_start = data(); + char* dst_start = dst_data; + for (int n = 0; n < shape_->num(); n++) { + for (int h = 0; h < shape_->height(); h++) { + memcpy(dst_start, src_start, wc * cell_size); + memset(dst_start + wc * cell_size, 0, remainder * cell_size); + src_start += wc * cell_size; + dst_start += wc_aligned * cell_size; + } + } + if (dst == nullptr) { + memcpy(data(), dst_data, mem_size); + flush(); + fpga_free(dst_data); + } else { + dst->flush(); + } + } else { + if (copy) { + dst->copyFrom(this); + 
} else { + // TODO(chonwhite) share data. + } + } + if (dst != nullptr) { + dst->copyScaleFrom(this); + } + } + + inline void copyScaleFrom(Tensor* src) { + placeHolder_->scale_[0] = src->placeHolder_->scale_[0]; + placeHolder_->scale_[1] = src->placeHolder_->scale_[1]; + } + + void unalignImage(Tensor* dst = nullptr, bool copy = false) { + Tensor* target = dst == nullptr ? this : dst; + if (!target->aligned_) { + if (copy && dst != nullptr) { + dst->copyFrom(this); + } + return; + } + target->syncToCPU(); + if (shape_->shouldAlign()) { + int cell_size = CellSize(this->dataType_); + char* dst_data = nullptr; + size_t mem_size = shape_->memorySize(cell_size); + if (dst == nullptr) { + dst_data = reinterpret_cast(fpga_malloc(mem_size)); + } else { + dst_data = dst->data(); + } + int wc = shape_->width() * shape_->channel(); + int wc_aligned = align_image(wc); + + char* src_start = data(); + char* dst_start = dst_data; + for (int n = 0; n < shape_->num(); n++) { + for (int h = 0; h < shape_->height(); h++) { + memcpy(dst_start, src_start, wc * cell_size); + src_start += wc_aligned * cell_size; + dst_start += wc * cell_size; + } + } + if (dst == nullptr) { + memcpy(data(), dst_data, mem_size); + flush(); + fpga_free(dst_data); + } else { + dst->flush(); + } + } else { + if (copy) { + dst->copyFrom(this); + } else { + // TODO(chonwhite) share data. + } + } + } + + void shareDataWith(Tensor* src) { shareDataWith(src, src->shape()); } + + void shareDataWith(Tensor* src, const Shape& shape, int offset = 0) { + if (shape_ != nullptr) { + delete shape_; + } + this->placeHolder_ = src->placeHolder_; + this->dataType_ = src->dataType_; + this->aligned_ = src->aligned_; + this->dateLocation_ = src->dateLocation_; + this->offset = offset; + shape_ = new Shape(const_cast(shape)); + } + + void copyFrom(Tensor* src) { + if (src->dataType_ == dataType_) { + src->syncToCPU(); + memcpy(data(), src->data(), memorySize()); + copyScaleFrom(src); + flush(); + return; + } + BypassArgs args; + args.input_data_type = + src->dataType_ == FP32 ? DATA_TYPE_FP32 : DATA_TYPE_FP16; + args.output_data_type = dataType_ == FP32 ? DATA_TYPE_FP32 : DATA_TYPE_FP16; + args.input_layout_type = LAYOUT_HWC; + args.output_layout_type = LAYOUT_HWC; + args.image = {.address = src->data(), + .scale_address = src->scale(), + .channels = (uint32_t)src->shape().numel(), + .width = 1, + .height = 1, + .pad_width = 0u, + .pad_height = 0u}; + args.output = { + .address = data(), .scale_address = scale(), + }; + src->syncToDevice(); + size_t aligned_remainder = src->shape().numel() % 16; + if (aligned_remainder > 0) { + size_t dtype_size = + src->dataType_ == FP32 ? 
sizeof(float) : sizeof(float16); + void* dst = src->data() + src->shape().numel() * dtype_size; + memset(dst, 0, aligned_remainder * dtype_size); + fpga_flush(dst, aligned_remainder * dtype_size); + } + src->syncToDevice(); + this->invalidate(); + perform_bypass(args); + this->invalidate(); + } + + void flush() { fpga_flush(placeHolder_->data(), placeHolder_->memorySize()); } + + void invalidate() { + fpga_invalidate(placeHolder_->data(), placeHolder_->memorySize()); + } + + void sync() { + switch (synchedStatus_) { + case CPU: + flush(); + break; + case Device: + invalidate(); + break; + default: + break; + } + } + + void syncToCPU() { + if (dateLocation_ == Device) { + invalidate(); + } + } + + void syncToDevice() { + if (dateLocation_ == CPU) { + flush(); + } + } + + DataSyncStatus synchedStatus() { return synchedStatus_; } + + void setSynchedStatus(DataSyncStatus status) { synchedStatus_ = status; } + + void setDataLocation(DataSyncStatus location) { dateLocation_ = location; } + + void print() {} + + void printScale() { + if (placeHolder_ == nullptr) { + return; + } + } + + std::string dimsFileName() { + return std::to_string(shape_->num()) + "_" + + std::to_string(shape_->channel()) + "_" + + std::to_string(shape_->height()) + "_" + + std::to_string(shape_->width()) + ".txt"; + } + + void saveToFile() { std::string path = dimsFileName(); } + + void saveToFile(std::string prefix, bool with_shape) { + std::string path = prefix; + if (with_shape) { + path = path + "_" + dimsFileName(); + } else { + path = path + ".txt"; + } + saveToFile(path); + } + + friend std::ostream& operator<<(std::ostream& os, Tensor& tensor) { + os << "tensor:" + << "\n"; + os << "dims: {"; + for (int i = 0; i < tensor.shape().dimSize(); ++i) { + os << tensor.shape()[i] << " "; + } + os << "}\n"; + for (int i = 0; i < tensor.shape().numel(); i++) { + float value = 0; + if (tensor.dataType() == FP32) { + value = tensor.data()[i]; + } else { + value = half_to_float(tensor.data()[i]); + } + os << value << " "; + } + os << "\n"; + return os; + } + + void saveToFile(std::string path) { + syncToCPU(); + std::ofstream ofs; + static int counter = 0; + std::string npath = std::to_string(counter) + "_" + path; + counter++; + save_file_with_name(npath); + } + + void save_file_with_name(std::string path) { + // return; + invalidate(); + std::ofstream ofs; + + ofs.open(path); + for (int i = 0; i < shape_->numel(); i++) { + float value = 0; + if (dataType_ == FP32) { + value = data()[i]; + } else { + value = half_to_float(data()[i]); + } + ofs << value << std::endl; + } + ofs.close(); + } + + void readFromFile(std::string path) { + std::ifstream file_stream; + file_stream.open(path); + if (!file_stream) { + return; + } + int num = shape_->numel(); + invalidate(); + float max = 0.0f; + float16* data = mutableData(); + for (int i = 0; i < num; ++i) { + float value = 0; + file_stream >> value; + max = std::max(std::abs(value), max); + data[i] = float_to_half(value); + } + flush(); + placeHolder_->scale_[0] = max / 127.0f; + placeHolder_->scale_[1] = 127.0f / max; + } + + ~Tensor() { + if (shape_ != nullptr) { + delete shape_; + shape_ = nullptr; + } + } + + private: + int offset = 0; + std::shared_ptr placeHolder_; + Shape* shape_ = nullptr; + DataType dataType_ = FP32; + bool aligned_ = false; + DataSyncStatus synchedStatus_ = Synched; + DataSyncStatus dateLocation_ = Device; + + static int generateID() { + static int sID = 0; + int id = sID++; + return id; + } + + int id_ = generateID(); +}; + +} // namespace zynqmp +} // 
namespace paddle
diff --git a/lite/fpga/KD/tensor_util.cpp b/lite/fpga/KD/tensor_util.cpp
new file mode 100644
index 00000000000..49c58ec91a4
--- /dev/null
+++ b/lite/fpga/KD/tensor_util.cpp
@@ -0,0 +1,32 @@
+/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include <algorithm>
+
+#include "lite/fpga/KD/tensor_util.hpp"
+
+namespace paddle {
+namespace zynqmp {
+float find_max(const Tensor& tensor) {
+  float max = 0;
+  Tensor& t = const_cast<Tensor&>(tensor);
+  float* data = t.data<float>();
+  for (int i = 0; i < t.shape().numel(); i++) {
+    float value = data[i] > 0 ? data[i] : -data[i];
+    max = std::max(value, max);
+  }
+  return max;
+}
+}  // namespace zynqmp
+}  // namespace paddle
diff --git a/lite/fpga/KD/tensor_util.hpp b/lite/fpga/KD/tensor_util.hpp
new file mode 100644
index 00000000000..6022089264a
--- /dev/null
+++ b/lite/fpga/KD/tensor_util.hpp
@@ -0,0 +1,25 @@
+/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+
+#include "lite/fpga/KD/tensor.hpp"
+
+namespace paddle {
+namespace zynqmp {
+float find_max(const Tensor& tensor);
+}  // namespace zynqmp
+}  // namespace paddle
diff --git a/lite/fpga/lite_tensor.cc b/lite/fpga/lite_tensor.cc
new file mode 100644
index 00000000000..2653dd6b217
--- /dev/null
+++ b/lite/fpga/lite_tensor.cc
@@ -0,0 +1,110 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
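+// Editor's sketch, not part of the original patch: expected behavior of the
+// DDimLite helpers implemented in this file, assuming a 4-D dim {2, 3, 4, 5}:
+//
+//   paddle::lite::DDimLite d({2, 3, 4, 5});
+//   d.production();  // 2 * 3 * 4 * 5 = 120
+//   d.count(1, 3);   // 3 * 4 = 12
+//   d.Slice(1, 3);   // DDimLite{3,4}
+//   d.repr();        // "{2,3,4,5}"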
+
+#include "lite/fpga/lite_tensor.h"
+
+#include <sstream>
+#include <string>
+
+namespace paddle {
+namespace lite {
+
+using value_type = int64_t;
+
+value_type DDimLite::production() const {
+  value_type res = 1;
+  for (size_t i = 0; i < this->size(); i++) {
+    res *= (*this)[i];
+  }
+  return res;
+}
+
+value_type DDimLite::count(int start, int end) const {
+  if (start < 0) {
+    start = 0;
+  }
+  if (end > static_cast<int>(size())) {
+    end = static_cast<int>(size());
+  }
+  if (end < start) {
+    end = start;
+  }
+  value_type sum = 1;
+  for (auto i = start; i < end; ++i) {
+    sum *= data_[i];
+  }
+  return sum;
+}
+
+DDimLite DDimLite::Slice(int start, int end) const {
+  std::vector<value_type> vec;
+  for (int i = start; i < end; i++) {
+    vec.push_back((*this)[i]);
+  }
+  return DDimLite(vec);
+}
+
+std::string DDimLite::repr() const {
+  std::stringstream ss;
+  if (empty()) {
+    ss << "{}";
+    return ss.str();
+  }
+  ss << "{";
+  for (size_t i = 0; i < this->size() - 1; i++) {
+    ss << (*this)[i] << ",";
+  }
+  if (!this->empty()) ss << (*this)[size() - 1];
+  ss << "}";
+  return ss.str();
+}
+
+void TensorLite::ShareDataWith(const TensorLite &other) {
+  buffer_ = other.buffer_;
+  dims_ = other.dims_;
+  zynq_tensor_ = other.zynq_tensor_;
+  target_ = other.target_;
+  lod_ = other.lod_;
+  memory_size_ = other.memory_size_;
+}
+
+void *TensorLite::mutable_data(size_t memory_size) {
+  memory_size_ = memory_size;
+  buffer_->ResetLazy(target_, memory_size_);
+  return buffer_->data();
+}
+
+void *TensorLite::mutable_data(TargetType target, size_t memory_size) {
+  target_ = target;
+  return mutable_data(memory_size);
+}
+
+void TensorLite::CopyDataFrom(const TensorLite &other) {
+  dims_ = other.dims_;
+  target_ = other.target_;
+  lod_ = other.lod_;
+  zynq_tensor_->mutableData<void>(other.zynq_tensor_->dataType(),
+                                  other.zynq_tensor_->shape());
+}
+
+// template <typename T, typename R>
+// void TensorLite::mutable_data_internal() {
+// }
+
+}  // namespace lite
+}  // namespace paddle
diff --git a/lite/fpga/lite_tensor.h b/lite/fpga/lite_tensor.h
new file mode 100644
index 00000000000..296f462dc98
--- /dev/null
+++ b/lite/fpga/lite_tensor.h
@@ -0,0 +1,224 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
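+// Editor's sketch, not part of the original patch: a typical host-side round
+// trip through the FPGA TensorLite declared below, which routes storage
+// through zynqmp::Tensor:
+//
+//   paddle::lite::TensorLite t;
+//   t.Resize({1, 3, 8, 8});
+//   float* in = t.mutable_data<float>();  // 4-D dims map to a zynqmp NCHW shape
+//   // ... fill in[0 .. t.numel()) ...
+//   const float* out = t.data<float>();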
+
+#pragma once
+
+#include <algorithm>
+#include <functional>  // for multiplies
+#include <memory>
+#include <numeric>
+#include <string>
+#include <vector>
+
+#include "lite/core/memory.h"
+#include "lite/fpga/KD/tensor.hpp"
+
+namespace paddle {
+namespace lite {
+
+class DDimLite;
+class TensorLite;
+
+using DDim = lite::DDimLite;
+using Tensor = lite::TensorLite;
+
+class DDimLite {
+ public:
+  using value_type = int64_t;
+
+  DDimLite() = default;
+
+  explicit DDimLite(const std::vector<value_type> &x) { ConstructFrom(x); }
+
+  void ConstructFrom(const std::vector<value_type> &x) { data_ = x; }
+
+  value_type operator[](int offset) const { return data_[offset]; }
+  value_type &operator[](int offset) { return data_[offset]; }
+  std::vector<value_type> Vectorize() const { return data_; }
+
+  size_t size() const { return data_.size(); }
+  bool empty() const { return data_.empty(); }
+
+  value_type production() const;
+
+  const std::vector<value_type> &data() const { return data_; }
+  value_type count(int start, int end) const;
+
+  DDimLite Slice(int start, int end) const;
+
+  DDimLite Flattern2D(int col) const {
+    return DDimLite(std::vector<value_type>(
+        {Slice(0, col).production(), Slice(col, size()).production()}));
+  }
+
+  std::string repr() const;
+
+  friend std::ostream &operator<<(std::ostream &os, const DDimLite &dims) {
+    os << dims.repr();
+    return os;
+  }
+
+  friend bool operator==(const DDimLite &a, const DDimLite &b) {
+    if (a.size() != b.size()) return false;
+    for (size_t i = 0; i < a.size(); i++) {
+      if (a[i] != b[i]) return false;
+    }
+    return true;
+  }
+
+  friend bool operator!=(const DDimLite &a, const DDimLite &b) {
+    return !(a == b);
+  }
+
+ private:
+  std::vector<value_type> data_;
+};
+
+using LoD = std::vector<std::vector<uint64_t>>;
+
+// A light-weight tensor implementation.
+class TensorLite {
+ public:
+  TensorLite() : buffer_(std::make_shared<Buffer>()) {}
+
+  template <typename DType, typename DimT, TargetType Target>
+  void Assign(DType *data, const DimT &dim) {
+    Resize(dim);
+    auto *dst = mutable_data<DType>(Target);
+    CopySync<Target>(
+        dst, data, dim.production() * sizeof(DType), IoDirection::HtoD);
+  }
+
+  // T is the data type and R is the return type
+  // For OpenCL, the return type can be cl::Buffer
+  // and the data type can be float/int8_t.
+  // For other devices, T and R may be the same type.
+  template <typename T, typename R = T>
+  const R *data() const {
+    return zynq_tensor_->data<R>();
+  }
+
+  void Resize(const DDimLite &ddim) { dims_ = ddim; }
+  void Resize(const std::vector<int64_t> &x) { dims_ = DDimLite(x); }
+
+  const DDimLite &dims() const { return dims_; }
+  int64_t numel() const { return dims_.production(); }
+
+  const LoD &lod() const { return lod_; }
+  LoD *mutable_lod() { return &lod_; }
+
+  // T is the data type and R is the return type
+  // For OpenCL, the return type can be cl::Buffer
+  // and the data type can be float/int8_t.
+  // For other devices, T and R may be the same type.
+  template <typename T, typename R = T>
+  R *mutable_data();
+
+  // T is the data type and R is the return type
+  // For OpenCL, the return type can be cl::Buffer
+  // and the data type can be float/int8_t.
+  // For other devices, T and R may be the same type.
+  template <typename T, typename R = T>
+  R *mutable_data(TargetType target);
+  void *mutable_data(size_t memory_size);
+  void *mutable_data(TargetType target, size_t memory_size);
+
+  const void *raw_data() const { return buffer_->data(); }
+
+  size_t data_size() const { return this->dims().production(); }
+
+  size_t memory_size() const { return zynq_tensor_->memorySize(); }
+
+  bool IsInitialized() const { return buffer_->data(); }
+
+  // Other share data to this.
+  void ShareDataWith(const TensorLite &other);
+
+  void CopyDataFrom(const TensorLite &other);
+
+  TargetType target() const { return target_; }
+
+  zynqmp::Tensor *ZynqTensor() const { return zynq_tensor_; }
+
+  friend std::ostream &operator<<(std::ostream &os, const TensorLite &tensor) {
+    os << "Tensor:" << '\n';
+    os << "dim: " << tensor.dims() << '\n';
+    for (int i = 0; i < tensor.dims().production(); i++) {
+      os << tensor.template data<float>()[i] << " ";
+    }
+    os << "\n";
+    return os;
+  }
+
+ private:
+  TargetType target_{TargetType::kHost};
+  DDimLite dims_;
+  std::shared_ptr<Buffer> buffer_;
+  LoD lod_;
+  size_t memory_size_{};
+
+  zynqmp::Tensor *zynq_tensor_ = new zynqmp::Tensor();
+
+  template <typename T>
+  void mutable_data_internal();
+};
+
+template <typename T, typename R>
+R *TensorLite::mutable_data() {
+  std::vector<int> v;
+  for (size_t i = 0; i < dims_.size(); i++) {
+    v.push_back(dims_[i]);
+  }
+  zynqmp::LayoutType layout_type = zynqmp::NCHW;
+  switch (v.size()) {
+    case 1:
+      layout_type = zynqmp::N;
+      break;
+    case 2:
+      layout_type = zynqmp::NC;
+      break;
+    case 3:
+      layout_type = zynqmp::NHW;
+      break;
+    case 4:
+      layout_type = zynqmp::NCHW;
+      break;
+  }
+  zynqmp::Shape input_shape(layout_type, v);
+
+  zynqmp::DataType data_type = zynqmp::FP32;
+  if (typeid(T) == typeid(float)) {
+    data_type = zynqmp::FP32;
+  }
+  if (typeid(T) == typeid(zynqmp::float16)) {
+    data_type = zynqmp::FP16;
+  }
+  return zynq_tensor_->mutableData<R>(data_type, input_shape);
+}
+
+template <typename T, typename R>
+R *TensorLite::mutable_data(TargetType target) {
+  target_ = target;
+  return mutable_data<T, R>();
+}
+
+template <typename TensorT>
+bool TensorCompareWith(const TensorT &a, const TensorT &b) {
+  if (a.dims() != b.dims()) return false;
+  if (memcmp(a.raw_data(), b.raw_data(), a.data_size()) != 0) return false;
+  return true;
+}
+
+}  // namespace lite
+}  // namespace paddle
diff --git a/lite/fpga/target_wrapper.cc b/lite/fpga/target_wrapper.cc
new file mode 100644
index 00000000000..4d5350b8fdd
--- /dev/null
+++ b/lite/fpga/target_wrapper.cc
@@ -0,0 +1,37 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
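
A minimal usage sketch of the FPGA tensor defined above (assumes an FPGA build; zynqmp types come from lite/fpga/KD/tensor.hpp):

    paddle::lite::TensorLite t;
    t.Resize({1, 3, 224, 224});          // rank 4 selects zynqmp::NCHW
    float* p = t.mutable_data<float>();  // allocates FP32 zynqmp storage
    // Rank 1/2/3 resizes select zynqmp::N / NC / NHW instead, and requesting
    // zynqmp::float16 data allocates FP16 storage via the typeid dispatch.
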
+ +#include "lite/core/target_wrapper.h" +#include "lite/fpga/KD/llapi/zynqmp_api.h" +#include "lite/utils/all.h" +#ifdef LITE_WITH_FPGA +namespace paddle { +namespace lite { + +void* TargetWrapper::Malloc(size_t size) { + return zynqmp::fpga_malloc(size); +} + +void TargetWrapper::Free(void* ptr) { zynqmp::fpga_free(ptr); } + +void TargetWrapper::MemcpySync(void* dst, + const void* src, + size_t size, + IoDirection dir) { + memcpy(dst, src, size); +} + +} // namespace lite +} // namespace paddle +#endif diff --git a/lite/gen_code/CMakeLists.txt b/lite/gen_code/CMakeLists.txt new file mode 100644 index 00000000000..d83657ad3e2 --- /dev/null +++ b/lite/gen_code/CMakeLists.txt @@ -0,0 +1,49 @@ +if (LITE_ON_TYNY_PUBLISH) + return() +endif() + +lite_cc_library(gen_code SRCS gen_code.cc + DEPS program op scope + cpp_op_desc + HVY_DEPS operator) +lite_cc_library(paddle_infer_gencode SRCS paddle_infer.cc DEPS program utils) + +lite_cc_test(test_gen_code SRCS gen_code_test.cc + DEPS gen_code tensor ${host_kernels} ${ops} + compatible_pb + model_parser + X86_DEPS ${x86_kernels} + ARM_DEPS ${arm_kernels} + NPU_DEPS ${npu_kernels} + CL_DEPS ${opencl_kernels} + FPGA_DEPS ${fpga_kernels} + EXCLUDE_COMPILE_DEPS "ON" + ARGS --optimized_model=${LITE_MODEL_DIR}/lite_naive_model_opt SERIAL) + +lite_cc_library(__generated_code__ + SRCS ${CMAKE_BINARY_DIR}/lite/gen_code/__generated_code__.cc + DEPS scope op kernel paddle_infer_gencode + EXCLUDE_COMPILE_DEPS "ON" +) +if(WITH_TESTING) + add_dependencies(__generated_code__ test_gen_code) + add_dependencies(__generated_code__ extern_lite_download_lite_naive_model_tar_gz) +endif(WITH_TESTING) + +lite_cc_binary(paddle_code_generator SRCS paddle_code_generator.cc DEPS model_parser gen_code gflags) + +# TODO(xxx): fix the gen code bug on ios +if(IOS) + return() +endif() + +lite_cc_test(test_generated_code SRCS generated_code_test.cc DEPS __generated_code__ + ${ops} ${host_kernels} + X86_DEPS ${x86_kernels} + ARM_DEPS ${arm_kernels} + NPU_DEPS ${npu_kernels} + CL_DEPS ${opencl_kernels} + FPGA_DEPS ${fpga_kernels} + EXCLUDE_COMPILE_DEPS "ON" +) + diff --git a/lite/gen_code/gen_code.cc b/lite/gen_code/gen_code.cc new file mode 100644 index 00000000000..0d8f4d0d192 --- /dev/null +++ b/lite/gen_code/gen_code.cc @@ -0,0 +1,223 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "lite/gen_code/gen_code.h" +#include +#include +#include + +namespace paddle { +namespace lite { +namespace gencode { + +void Module::AddWeight(const std::string &name, const TensorRepr &tensor) { + auto w_name = WeightUniqueName(); + Line(string_format("// Create weight: %s", name.c_str())); + // auto* w0 = scope.Var("w0")->GetMutable(); + Line(string_format("auto* %s = scope->Var(%s)->GetMutable();", + w_name.c_str(), + Repr(name).c_str())); + // lite::DDim w_ddim({1, 2}) + Line(string_format("lite::DDim %s_ddim(std::vector(%s));", + w_name.c_str(), + tensor.ddim.repr().c_str())); + // std::vector w_data({}); + auto w_data_repr = DataRepr( + std::string(static_cast(tensor.raw_data), tensor.num_bytes), + tensor.dtype); + Line(string_format("std::vector<%s> %s_data({%s});", + PrecisionToStr(tensor.dtype).c_str(), + w_name.c_str(), + w_data_repr.c_str())); + // w0->Assign(w0_data.data(), w0_ddim); + Line(string_format( + "%s->Assign<%s, lite::DDim, TARGET(kX86)>(%s_data.data(), %s_ddim);", + w_name.c_str(), + PrecisionToStr(tensor.dtype).c_str(), + w_name.c_str(), + w_name.c_str())); + Line(""); +} + +void Module::AddHeaderIncludeGenCode() { + Line(""); + Line("#include "); + Line("#include "); + Line("#include \"lite/core/tensor.h\""); + Line("#include \"lite/core/context.h\""); + Line("#include \"lite/gen_code/paddle_infer.h\""); + Line("#include \"lite/core/op_registry.h\""); + Line("#include \"lite/core/scope.h\""); + Line("#include \"lite/model_parser/cpp/op_desc.h\""); + Line(""); + Line(""); +} + +std::string Module::DataRepr(const std::string &raw_data, PrecisionType dtype) { + STL::stringstream ss; + switch (dtype) { + case PRECISION(kFloat): { + const float *raw = reinterpret_cast(raw_data.c_str()); + int num_elems = raw_data.size() / sizeof(float); + if (num_elems) { + for (int i = 0; i < num_elems - 1; i++) { + ss << raw[i] << ","; + } + ss << raw[num_elems - 1]; + } + } break; + + default: + LOG(FATAL) << "Unsupported type " << PrecisionToStr(dtype); + } + return ss.str(); +} + +void Module::AddOpDescHelper(const std::string &op_id, + const cpp::OpDesc &desc) { + std::string desc_var = op_id + "_desc"; + Line(string_format("lite::cpp::OpDesc %s;", desc_var.c_str())); + auto vec_str_repr = [](const std::vector &vec) { + return Repr(vec); + }; + for (auto &item : desc.inputs()) { + Line(string_format("%s.SetInput(%s, %s);", + desc_var.c_str(), + Repr(item.first).c_str(), + vec_str_repr(item.second).c_str())); + } + + for (auto &item : desc.outputs()) { + Line(string_format("%s.SetOutput(%s, %s);", + desc_var.c_str(), + Repr(item.first).c_str(), + vec_str_repr(item.second).c_str())); + } + + auto attr_repr = [&](const std::string &name) -> std::string { + using AttrType = OpDescAPI::AttrType; + auto type = desc.GetAttrType(name); + + switch (type) { + case AttrType::INT: + return std::to_string(desc.GetAttr(name)); + case AttrType::FLOAT: + return std::to_string(desc.GetAttr(name)); + case AttrType::BOOLEAN: + return std::to_string(desc.GetAttr(name)); + case AttrType::STRING: + return "\"" + desc.GetAttr(name) + "\""; + case AttrType::FLOATS: { + auto vals = desc.GetAttr>(name); + return "{" + Join(vals, ",") + "}"; + } + case AttrType::INTS: { + auto vals = desc.GetAttr>(name); + return "{" + Join(vals, ",") + "}"; + } + + case AttrType::STRINGS: { + std::vector tmp; + auto vals = desc.GetAttr>(name); + std::transform(vals.begin(), + vals.end(), + std::back_inserter(tmp), + [](const std::string &x) { return Repr(x); }); + return "{" + Join(tmp, ",") + "}"; + } + 
default: + LOG(FATAL) << "Unsupported attribute type: " << static_cast(type); + } + return ""; + }; + + auto attr_type_repr = [&](const std::string &name) -> std::string { + using AttrType = OpDescAPI::AttrType; + auto type = desc.GetAttrType(name); + + switch (type) { + case AttrType::INT: + return "int"; + case AttrType::FLOAT: + return "float"; + case AttrType::BOOLEAN: + return "bool"; + case AttrType::STRING: + return "std::string"; + case AttrType::FLOATS: + return "std::vector"; + case AttrType::STRINGS: + return "std::vector"; + case AttrType::INTS: + return "std::vector"; + default: + LOG(FATAL) << "Unsupported attribute type: " << static_cast(type); + } + + return "unk_t"; + }; + for (auto &item : desc.AttrNames()) { + // Drop the python information. + if (item == "op_callstack") continue; + auto attr_type = attr_type_repr(item); + auto attr_val = attr_repr(item); + Line(string_format("%s.SetAttr<%s>(%s, %s);", // + desc_var.c_str(), + attr_type.c_str(), + Repr(item).c_str(), + attr_val.c_str())); + } +} + +void Module::AddOp(const cpp::OpDesc &op) { + auto op_name = OpUniqueName(); + AddOpDescHelper(op_name, op); + + LOG(INFO) << "add op " << op_name; + + Line(string_format("// Create Op: %s", op.Type().c_str())); + + Line(string_format("auto %s = lite::LiteOpRegistry::Global().Create(\"%s\");", + op_name.c_str(), + op.Type().c_str())); + + CHECK(op.HasAttr(kKernelTypeAttr)) + << "the kernel type should be specified before generate code."; + auto kernel_type = op.GetAttr(kKernelTypeAttr); + Line(string_format("%s->Attach(%s, exec_scope);", + op_name.c_str(), + (op_name + "_desc").c_str())); + + // Create kernel + auto kernel_name = KernelUniqueName(); + Line(string_format( + "auto %s = std::move(%s->CreateKernels(valid_places, \"%s\").front());", + kernel_name.c_str(), + op_name.c_str(), + kernel_type.c_str())); + + // Set Context for kernel + // clang-format off + Line(string_format("%s->SetContext(lite::ContextScheduler::Global().NewContext(%s->target()));", kernel_name.c_str(), kernel_name.c_str())); // NOLINT + // clang-format on + + Line(string_format("ops.push_back(%s);", op_name.c_str())); + Line(string_format("kernels.push_back(std::move(%s));", kernel_name.c_str())); + + op_kinds_.insert(op.Type()); + kernel_kinds_.insert(kernel_type); +} +} // namespace gencode +} // namespace lite +} // namespace paddle diff --git a/lite/gen_code/gen_code.h b/lite/gen_code/gen_code.h new file mode 100644 index 00000000000..7dea36636af --- /dev/null +++ b/lite/gen_code/gen_code.h @@ -0,0 +1,258 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
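
Pieced together from the emission strings above, the generated Init() body for one weight and one op looks roughly like this (the names w_0/op_0/kernel_0 come from the unique-name counters; the exact text is illustrative only):

    // Create weight: w0
    auto* w_0 = scope->Var("w0")->GetMutable<lite::Tensor>();
    lite::DDim w_0_ddim(std::vector<int64_t>({2, 2}));
    std::vector<float> w_0_data({0,1,2,3});
    w_0->Assign<float, lite::DDim, TARGET(kX86)>(w_0_data.data(), w_0_ddim);

    // Create Op: mul
    lite::cpp::OpDesc op_0_desc;
    op_0_desc.SetInput("X", {"x"});
    op_0_desc.SetOutput("Out", {"out"});
    op_0_desc.SetAttr<int>("x_num_col_dims", 1);
    auto op_0 = lite::LiteOpRegistry::Global().Create("mul");
    op_0->Attach(op_0_desc, exec_scope);
    auto kernel_0 = std::move(op_0->CreateKernels(valid_places, "arm").front());
    kernel_0->SetContext(
        lite::ContextScheduler::Global().NewContext(kernel_0->target()));
    ops.push_back(op_0);
    kernels.push_back(std::move(kernel_0));
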
+ +#pragma once +#include +#include +#include +#include "lite/core/framework.pb.h" +#include "lite/core/program.h" +#include "lite/core/target_wrapper.h" +#include "lite/core/tensor.h" +#include "lite/model_parser/compatible_pb.h" +#include "lite/model_parser/cpp/op_desc.h" +#include "lite/model_parser/desc_apis.h" +#include "lite/model_parser/pb/op_desc.h" +#include "lite/utils/all.h" + +namespace paddle { +namespace lite { +namespace gencode { + +struct TensorRepr { + TensorRepr() = default; + TensorRepr(PrecisionType dtype, + const std::vector &ddim, + void *raw_data, + size_t num_bytes) + : dtype(dtype), ddim(ddim), raw_data(raw_data), num_bytes(num_bytes) {} + + PrecisionType dtype; + lite::DDim ddim; + const void *raw_data; + size_t num_bytes{}; +}; + +class Module { + std::vector ops; + std::vector weights; + std::vector tmp_vars_; + STL::stringstream stream_; + std::set kernel_kinds_; + std::set op_kinds_; + + int line_indent_{}; + const int indent_unit_{2}; + + public: + void NewOp(const cpp::OpDesc &desc) { ops.push_back(desc); } + void NewWeight(const TensorRepr &x) { weights.push_back(x); } + void NewTmpVar(const std::string &x) { tmp_vars_.push_back(x); } + + STL::stringstream &stream() { return stream_; } + + void AddHeaderIncludeGenCode(); + + void AddNamespaceBegin() { + Line("namespace paddle {"); + Line("namespace gencode{"); + Line(""); + } + + void AddNamespaceEnd() { + Line(""); + Line("} // namespace gencode"); + Line("} // namespace paddle"); + } + + void AddInitFuncBegin() { + Line("void PaddlePredictor::Init() {"); + Line(""); + IncIndent(); + } + + void AddInitFuncEnd() { + DecIndent(); + Line(""); + Line("}"); + } + + void AddScopeDecl() { + Line("lite::Scope* scope = static_cast(raw_scope_);"); + + // clang-format off + Line("lite::Scope* exec_scope = static_cast(raw_exe_scope_);"); // NOLINT + // clang-format on + + // Create feed and fetch in exec_scope. 
+ Line(string_format("exec_scope->Var(%s);", Repr("feed").c_str())); + Line(string_format("exec_scope->Var(%s);", Repr("fetch").c_str())); + } + + void AddValidPlaceDecl() { + // clang-format off + Line("std::vector valid_places({lite::Place({TARGET(kX86), PRECISION(kFloat), DATALAYOUT(kNCHW)}), lite::Place({TARGET(kHost), PRECISION(kAny), DATALAYOUT(kAny)})});"); // NOLINT + // clang-format on + } + + void AddMemberCast() { + Line("// Cast the raw members"); + // clang-format off + Line(string_format("auto& ops = *static_cast>*>(raw_ops_);")); // NOLINT + Line(string_format("auto& kernels = *static_cast>*>(raw_kernels_);")); // NOLINT + // clang-format on + Line(""); + } + + void AddWeight(const std::string &name, const TensorRepr &tensor); + + void AddTmpVar(const std::string &x) { + Line(string_format("// Create temporary variable: %s", x.c_str())); + Line(string_format("exec_scope->Var(%s);", Repr(x).c_str())); + Line(""); + } + + void AddOp(const cpp::OpDesc &op); + + void AddOpDescHelper(const std::string &op_id, const cpp::OpDesc &desc); + + void AddOpCompileDeps() { + Line(""); + Line("// Add Operator compile deps"); + for (auto &op_type : op_kinds_) { + Line(string_format("USE_LITE_OP(%s)", op_type.c_str())); + } + Line(""); + } + void AddKernelCompileDeps() { + Line("// Add Kernel compile deps"); + + std::string op_type, alias; + Place place; + for (auto &kernel_type : kernel_kinds_) { + KernelBase::ParseKernelType(kernel_type, &op_type, &alias, &place); + Line(string_format("USE_LITE_KERNEL(%s, %s, %s, %s, %s)", // + op_type.c_str(), // + TargetRepr(place.target).c_str(), + PrecisionRepr(place.precision).c_str(), + DataLayoutRepr(place.layout).c_str(), + alias.c_str())); + } + } + + private: + std::string WeightUniqueName() const { + return "w_" + std::to_string(weight_counter_++); + } + std::string TmpVarUniqueName() const { + return "tmp_" + std::to_string(tmp_var_counter_++); + } + std::string OpUniqueName() const { + return "op_" + std::to_string(op_counter_++); + } + std::string KernelUniqueName() const { + return "kernel_" + std::to_string(kernel_counter_++); + } + + std::string DataRepr(const std::string &raw_data, PrecisionType dtype); + + void IncIndent() { line_indent_++; } + void DecIndent() { line_indent_--; } + + void Line(const std::string &x) { + std::string indent_str(line_indent_ * indent_unit_, ' '); + stream() << indent_str << x << "\n"; + } + + private: + mutable int weight_counter_{}; + mutable int tmp_var_counter_{}; + mutable int op_counter_{}; + mutable int kernel_counter_{}; +}; + +class ProgramCodeGenerator { + public: + ProgramCodeGenerator(const framework::proto::ProgramDesc &program, + const lite::Scope &exec_scope) + : program_(program), exec_scope_(exec_scope) {} + + std::string GenCode() { + Module m; + m.AddHeaderIncludeGenCode(); + m.AddNamespaceBegin(); + m.AddInitFuncBegin(); + m.AddMemberCast(); + m.AddScopeDecl(); + m.AddValidPlaceDecl(); + + AddWeights(&m); + AddTmpVars(&m); + AddOps(&m); + + m.AddInitFuncEnd(); + m.AddNamespaceEnd(); + + m.AddOpCompileDeps(); + m.AddKernelCompileDeps(); + + return m.stream().str(); + } + + void AddWeights(Module *m) { + for (auto &var : program_.blocks(0).vars()) { + if (var.persistable()) { + auto name = var.name(); + if (name == "feed" || name == "fetch") continue; + const auto &tensor = exec_scope_.FindVar(name)->Get(); + TensorRepr repr; + TensorToRepr(tensor, &repr); + m->AddWeight(name, repr); + } + } + } + void AddTmpVars(Module *m) { + for (auto &var : program_.blocks(0).vars()) { + if 
(!var.persistable()) { + m->AddTmpVar(var.name()); + } + } + } + void AddOps(Module *m) { + for (auto &pb_op : program_.blocks(0).ops()) { + auto op = pb_op; + lite::pb::OpDesc pb_desc(&op); + lite::cpp::OpDesc cpp_desc; + TransformOpDescAnyToCpp(pb_desc, &cpp_desc); + m->AddOp(cpp_desc); + } + } + + private: + void TensorToRepr(const lite::Tensor &tensor, TensorRepr *repr) { + repr->ddim = tensor.dims(); + // TODO(Superjomn) support other types. + repr->dtype = PRECISION(kFloat); + repr->raw_data = tensor.data(); + repr->num_bytes = repr->ddim.production() * sizeof(float); + } + + private: + const framework::proto::ProgramDesc &program_; + const lite::Scope &exec_scope_; +}; + +} // namespace gencode +} // namespace lite +} // namespace paddle diff --git a/lite/gen_code/gen_code_test.cc b/lite/gen_code/gen_code_test.cc new file mode 100644 index 00000000000..caf0921cc17 --- /dev/null +++ b/lite/gen_code/gen_code_test.cc @@ -0,0 +1,165 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/gen_code/gen_code.h" +#include +#include +#include +#include +#include +#include +#include "lite/api/paddle_use_kernels.h" +#include "lite/api/paddle_use_ops.h" +#include "lite/core/context.h" +#include "lite/core/scope.h" +#include "lite/core/tensor.h" +#include "lite/model_parser/compatible_pb.h" +#include "lite/model_parser/cpp/op_desc.h" +#include "lite/model_parser/model_parser.h" +#include "lite/model_parser/pb/program_desc.h" + +DEFINE_string(optimized_model, "", ""); +DEFINE_string(generated_code_file, "__generated_code__.cc", ""); + +namespace paddle { +namespace lite { +namespace gencode { + +// Manually construct a program. +TEST(gen_code, manual) { + // For holding the weights. + lite::Scope scope; + // For holding the temporary variables. + auto &tmp_scope = scope.NewScope(); + + // Create weight variables. + auto *w0 = scope.Var("w0")->GetMutable(); + // Create temporary variables. + auto *a = tmp_scope.Var("x")->GetMutable(); + tmp_scope.Var("out")->GetMutable(); + + // Set weights. 
+ std::vector w0_data({0, 1, 2, 3}); + std::vector a_data({0, 1, 2, 3}); +#ifdef LITE_WITH_ARM + w0->Assign( + w0_data.data(), lite::DDim{std::vector({2, 2})}); + a->Assign( + a_data.data(), lite::DDim{std::vector({2, 2})}); +#else + w0->Assign( + w0_data.data(), lite::DDim{std::vector({2, 2})}); + a->Assign( + a_data.data(), lite::DDim{std::vector({2, 2})}); +#endif + + std::vector valid_places({ +#ifdef LITE_WITH_ARM + Place{TARGET(kARM), PRECISION(kFloat)}, +#else + Place{TARGET(kX86), PRECISION(kFloat)}, +#endif + Place{TARGET(kHost), PRECISION(kFloat)}, + Place{TARGET(kHost), PRECISION(kAny)}, + }); + auto mul_op = LiteOpRegistry::Global().Create("mul"); + cpp::OpDesc mul_op_desc; + mul_op_desc.SetType("mul"); + mul_op_desc.SetInput("X", {"x"}); + mul_op_desc.SetInput("Y", {"w0"}); + mul_op_desc.SetAttr("x_num_col_dims", 1); + mul_op_desc.SetAttr("y_num_col_dims", 1); + mul_op_desc.SetOutput("Out", {"out"}); + + mul_op->Attach(mul_op_desc, &tmp_scope); + auto mul_kernel = std::move(mul_op->CreateKernels(valid_places).front()); +#ifdef LITE_WITH_ARM + auto fc_ctx = ContextScheduler::Global().NewContext(TARGET(kARM)); +#else + auto fc_ctx = ContextScheduler::Global().NewContext(TARGET(kX86)); +#endif + mul_op->CheckShape(); + mul_op->InferShape(); + mul_kernel->SetContext(std::move(fc_ctx)); + mul_kernel->Launch(); +} + +TEST(gen_code, auto_gen) { + std::vector w0_data({0, 1, 2, 3}); + TensorRepr w0(PRECISION(kFloat), + std::vector({2, 2}), + w0_data.data(), + w0_data.size() * sizeof(float)); + + std::vector w1_data({0.01, 1.2, 2.3, 3.4, 1.1, 2.2}); + TensorRepr w1(PRECISION(kFloat), + std::vector({3, 2}), + w1_data.data(), + w1_data.size() * sizeof(float)); + + cpp::OpDesc op0; + op0.SetType("mul"); + op0.SetInput("X", {"a", "b"}); + op0.SetOutput("Out", {"out0"}); + op0.SetAttr("desc", "this is a desc"); + op0.SetAttr("x_col", 1); + op0.SetAttr("y_col", 2); +#ifdef LITE_WITH_ARM + op0.SetAttr(kKernelTypeAttr, "arm"); +#else + op0.SetAttr(kKernelTypeAttr, "x86"); +#endif + + gencode::Module module; + module.AddHeaderIncludeGenCode(); + + module.AddNamespaceBegin(); + module.AddInitFuncBegin(); + + module.AddMemberCast(); + + module.AddWeight("w0", w0); + module.AddWeight("w1", w1); + module.AddTmpVar("a"); + module.AddTmpVar("b"); + + module.AddOp(op0); + + module.AddInitFuncEnd(); + module.AddNamespaceEnd(); + + LOG(INFO) << module.stream().str(); +} + +TEST(gen_code, optimized_program) { + lite::Scope scope; + cpp::ProgramDesc cpp_desc; + LoadModelPb(FLAGS_optimized_model, &scope, &cpp_desc); + + framework::proto::ProgramDesc pb_proto_desc; + lite::pb::ProgramDesc pb_desc(&pb_proto_desc); + TransformProgramDescCppToAny(cpp_desc, &pb_desc); + + ProgramCodeGenerator codegen(pb_proto_desc, scope); + + std::ofstream file(FLAGS_generated_code_file); + + file << codegen.GenCode(); + + file.close(); +} + +} // namespace gencode +} // namespace lite +} // namespace paddle diff --git a/lite/gen_code/generated_code_test.cc b/lite/gen_code/generated_code_test.cc new file mode 100644 index 00000000000..199ba579d47 --- /dev/null +++ b/lite/gen_code/generated_code_test.cc @@ -0,0 +1,87 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include "lite/gen_code/paddle_infer.h" +#include "lite/utils/cp_logging.h" + +namespace paddle { +namespace lite { + +TEST(PaddlePredictor, Init) { + gencode::PaddlePredictor predictor; + predictor.Init(); +} + +#ifdef LITE_WITH_X86 +TEST(PaddlePredictor, RunX86) { + gencode::PaddlePredictor predictor; + predictor.Init(); + + LOG(INFO) << "run the generated code"; + auto input_tensor = predictor.GetInput(0); + input_tensor->Resize(std::vector({100, 100})); + auto* data = input_tensor->mutable_data(); + for (int i = 0; i < 100 * 100; i++) { + data[i] = i; + } + + predictor.Run(); + + auto output_tensor = predictor.GetOutput(0); + LOG(INFO) << "output: " << output_tensor->data()[0]; +} +#endif + +#ifdef LITE_WITH_ARM +TEST(PaddlePredictor, RunARM) { + gencode::PaddlePredictor predictor; + predictor.Init(); + + LOG(INFO) << "run the generated code"; + auto input_tensor = predictor.GetInput(0); + input_tensor->Resize(std::vector({1, 100})); + auto* data = input_tensor->mutable_data(); + for (int i = 0; i < 100; i++) { + data[i] = 1; + } + + predictor.Run(); + + std::vector result({0.4350058, + -0.6048313, + -0.29346266, + 0.40377066, + -0.13400325, + 0.37114543, + -0.3407839, + 0.14574292, + 0.4104212, + 0.8938774}); + + auto output_tensor = predictor.GetOutput(0); + auto output_shape = output_tensor->shape(); + ASSERT_EQ(output_shape.size(), 2); + ASSERT_EQ(output_shape[0], 1); + ASSERT_EQ(output_shape[1], 500); + + int step = 50; + for (int i = 0; i < result.size(); i += step) { + EXPECT_NEAR(output_tensor->data()[i], result[i], 1e-6); + } +} +#endif + +} // namespace lite +} // namespace paddle diff --git a/lite/gen_code/paddle_code_generator.cc b/lite/gen_code/paddle_code_generator.cc new file mode 100644 index 00000000000..344a63297a4 --- /dev/null +++ b/lite/gen_code/paddle_code_generator.cc @@ -0,0 +1,54 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include +#include "lite/gen_code/gen_code.h" +#include "lite/model_parser/model_parser.h" +#include "lite/model_parser/pb/program_desc.h" + +DEFINE_string(optimized_model, "", ""); +DEFINE_string(generated_code_file, "__generated_code__.cc", ""); + +namespace paddle { +namespace lite { +namespace gencode { + +void GenCode(const std::string& model_dir, const std::string& out_file) { + lite::Scope scope; + cpp::ProgramDesc cpp_desc; + LoadModelPb(model_dir, &scope, &cpp_desc); + + framework::proto::ProgramDesc pb_proto_desc; + lite::pb::ProgramDesc pb_desc(&pb_proto_desc); + TransformProgramDescCppToAny(cpp_desc, &pb_desc); + + ProgramCodeGenerator codegen(pb_proto_desc, scope); + + std::ofstream file(out_file); + + file << codegen.GenCode(); + + file.close(); +} + +} // namespace gencode +} // namespace lite +} // namespace paddle + +int main(int argc, char** argv) { + gflags::ParseCommandLineFlags(&argc, &argv, false); + paddle::lite::gencode::GenCode(FLAGS_optimized_model, + FLAGS_generated_code_file); + return 0; +} diff --git a/lite/gen_code/paddle_infer.cc b/lite/gen_code/paddle_infer.cc new file mode 100644 index 00000000000..180e75e1a6c --- /dev/null +++ b/lite/gen_code/paddle_infer.cc @@ -0,0 +1,145 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
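
For reference, the generator above is driven entirely by its two gflags (the model path below is a placeholder):

    ./paddle_code_generator \
        --optimized_model=/path/to/optimized_model \
        --generated_code_file=__generated_code__.cc
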
+ +#include "lite/gen_code/paddle_infer.h" +#include "lite/core/op_lite.h" +#include "lite/core/tensor.h" + +namespace paddle { +namespace gencode { + +void Tensor::Resize(const Tensor::ddim_t &shape) { + CHECK(raw_mutable_tensor_); + auto *tensor = static_cast(raw_mutable_tensor_); + tensor->Resize(shape); +} + +std::vector Tensor::shape() const { + CHECK(raw_tensor_); + auto *tensor = static_cast(raw_tensor_); + return tensor->dims().Vectorize(); +} + +#define FOR_EACH_TYPE(HANDLE) \ + HANDLE(int); \ + HANDLE(float); \ + HANDLE(int8_t); \ + HANDLE(int64_t); + +#define IMPL_DATA(T) \ + template <> \ + const T *Tensor::data() const { \ + CHECK(raw_tensor_); \ + const auto *tensor = static_cast(raw_tensor_); \ + return tensor->data(); \ + } +FOR_EACH_TYPE(IMPL_DATA); +#undef IMPL_DATA + +#define IMPL_MUTABLE_DATA(T) \ + template <> \ + T *Tensor::mutable_data() { \ + CHECK(raw_mutable_tensor_); \ + auto *tensor = static_cast(raw_mutable_tensor_); \ + return tensor->mutable_data(); \ + } +FOR_EACH_TYPE(IMPL_MUTABLE_DATA); +#undef IMPL_MUTABLE_DATA + +PaddlePredictor::PaddlePredictor() { + raw_ops_ = new std::vector>; + raw_kernels_ = new std::vector>; + raw_scope_ = new lite::Scope; + raw_exe_scope_ = &(static_cast(raw_scope_)->NewScope()); +} + +std::unique_ptr PaddlePredictor::GetTensor( + const std::string &id) const { + auto *exe_scope = static_cast(raw_exe_scope_); + const auto *var = exe_scope->FindVar(id); + const auto &tensor = var->Get(); + return std::unique_ptr(new Tensor(&tensor, nullptr)); +} + +std::unique_ptr PaddlePredictor::GetMutableTensor( + const std::string &id) { + auto *exe_scope = static_cast(raw_exe_scope_); + auto *var = exe_scope->FindVar(id); + auto *tensor = var->GetMutable(); + return std::unique_ptr(new Tensor(nullptr, tensor)); +} + +#define CAST_OPS \ + auto *ops = \ + static_cast> *>(raw_ops_); +#define CAST_KERNELS \ + auto *kernels = \ + static_cast> *>( \ + raw_kernels_); +#define CAST_SCOPE auto *scope = static_cast(raw_scope_); + +PaddlePredictor::~PaddlePredictor() { + CAST_OPS + CAST_KERNELS + CAST_SCOPE + + if (ops) { + delete ops; + } + if (kernels) { + delete kernels; + } + if (scope) { + delete scope; + } +} + +void PaddlePredictor::Run() { + CAST_OPS + CAST_KERNELS + + CHECK(ops); + CHECK(kernels); + CHECK_EQ(ops->size(), kernels->size()); + + for (size_t i = 0; i < ops->size(); i++) { + LOG(INFO) << "Running the " << i << "-th operator"; + ops->at(i)->InferShape(); + kernels->at(i)->Launch(); + } +} + +std::unique_ptr PaddlePredictor::GetInput(size_t offset) { + auto *exec_scope = static_cast(raw_exe_scope_); + auto *_feed_list = exec_scope->FindVar("feed"); + CHECK(_feed_list) << "no feed variable in exec_scope"; + auto *feed_list = _feed_list->GetMutable>(); + if (offset >= feed_list->size()) { + feed_list->resize(offset + 1); + } + + return std::unique_ptr(new Tensor(nullptr, &feed_list->at(offset))); +} + +std::unique_ptr PaddlePredictor::GetOutput(size_t offset) { + auto *exec_scope = static_cast(raw_exe_scope_); + auto *_fetch_list = exec_scope->FindVar("fetch"); + CHECK(_fetch_list) << "no fatch variable in exec_scope"; + auto &fetch_list = *_fetch_list->GetMutable>(); + CHECK_LT(offset, fetch_list.size()) << "offset " << offset << " overflow"; + return std::unique_ptr(new Tensor(&fetch_list.at(offset), nullptr)); +} + +} // namespace gencode +} // namespace paddle diff --git a/lite/gen_code/paddle_infer.h b/lite/gen_code/paddle_infer.h new file mode 100644 index 00000000000..e01ffc25e29 --- /dev/null +++ b/lite/gen_code/paddle_infer.h @@ 
-0,0 +1,72 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include +#include +#include + +namespace paddle { +namespace gencode { + +/// Zero Copy Tensor. +class Tensor { + public: + using ddim_t = std::vector; + + Tensor(const void *raw_tensor, void *raw_mutable_tensor) + : raw_tensor_(raw_tensor), raw_mutable_tensor_(raw_mutable_tensor) {} + + void Resize(const ddim_t &shape); + template + const T *data() const; + template + T *mutable_data(); + + ddim_t shape() const; + + private: + const void *raw_tensor_; + void *raw_mutable_tensor_{}; +}; + +/* + * Predictor for the generated code. + */ +class PaddlePredictor { + public: + void Init(); + + std::unique_ptr GetTensor(const std::string &id) const; + std::unique_ptr GetMutableTensor(const std::string &id); + + // Get offset-th col of feed. + std::unique_ptr GetInput(size_t offset); + + std::unique_ptr GetOutput(size_t offset); + + void Run(); + + PaddlePredictor(); + ~PaddlePredictor(); + + private: + void *raw_ops_; + void *raw_kernels_; + void *raw_scope_{}; + void *raw_exe_scope_{}; // raw_exe_scope is not owned. +}; + +} // namespace gencode +} // namespace paddle diff --git a/lite/host/CMakeLists.txt b/lite/host/CMakeLists.txt new file mode 100644 index 00000000000..8c22d8da751 --- /dev/null +++ b/lite/host/CMakeLists.txt @@ -0,0 +1,3 @@ +lite_cc_library(target_wrapper_host SRCS target_wrapper.cc) + + diff --git a/lite/host/target_wrapper.cc b/lite/host/target_wrapper.cc new file mode 100644 index 00000000000..5f020662a9d --- /dev/null +++ b/lite/host/target_wrapper.cc @@ -0,0 +1,49 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
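
The host allocator defined next over-allocates, rounds the user pointer up to a MALLOC_ALIGN (64-byte) boundary, and stashes the raw malloc pointer one slot before the aligned address so Free() can recover it. A worked example of the arithmetic:

    // Assume malloc returns p = 0x1003 and sizeof(void*) == 8:
    //   offset = 8 + 63              = 71
    //   r      = (0x1003 + 71) & ~63 = 0x1040  // 64-byte aligned
    //   ((void**)r)[-1] = p                    // saved for Free()
    // r - p == 0x3D >= sizeof(void*), so the stash never underruns p.
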
+ +#include "lite/core/target_wrapper.h" +#include +#include + +namespace paddle { +namespace lite { + +const int MALLOC_ALIGN = 64; + +void* TargetWrapper::Malloc(size_t size) { + size_t offset = sizeof(void*) + MALLOC_ALIGN - 1; + char* p = static_cast(malloc(offset + size)); + if (!p) { + return nullptr; + } + void* r = reinterpret_cast(reinterpret_cast(p + offset) & + (~(MALLOC_ALIGN - 1))); + static_cast(r)[-1] = p; + memset(r, 0, size); + return r; +} +void TargetWrapper::Free(void* ptr) { + if (ptr) { + free(static_cast(ptr)[-1]); + } +} +void TargetWrapper::MemcpySync(void* dst, + const void* src, + size_t size, + IoDirection dir) { + memcpy(dst, src, size); +} + +} // namespace lite +} // namespace paddle diff --git a/lite/kernels/CMakeLists.txt b/lite/kernels/CMakeLists.txt new file mode 100644 index 00000000000..d5a3f6d9f02 --- /dev/null +++ b/lite/kernels/CMakeLists.txt @@ -0,0 +1,11 @@ +message(STATUS "add lite kernels") + +set(lite_kernel_deps type_system kernel op op_registry context tensor CACHE INTERNAL "" FORCE) + +add_subdirectory(host) +add_subdirectory(arm) +add_subdirectory(cuda) +add_subdirectory(x86) +add_subdirectory(opencl) +add_subdirectory(fpga) +add_subdirectory(npu) diff --git a/lite/kernels/arm/CMakeLists.txt b/lite/kernels/arm/CMakeLists.txt new file mode 100644 index 00000000000..a59737809ad --- /dev/null +++ b/lite/kernels/arm/CMakeLists.txt @@ -0,0 +1,141 @@ +if(NOT (LITE_WITH_LIGHT_WEIGHT_FRAMEWORK AND LITE_WITH_ARM)) + return() +endif() + +message(STATUS "compile with lite ARM kernels") + +lite_cc_library(fc_compute_arm SRCS fc_compute.cc DEPS ${lite_kernel_deps} math_arm) +lite_cc_library(activation_compute_arm SRCS activation_compute.cc DEPS ${lite_kernel_deps} math_arm) +lite_cc_library(mul_compute_arm SRCS mul_compute.cc DEPS ${lite_kernel_deps} math_arm) +lite_cc_library(scale_compute_arm SRCS scale_compute.cc DEPS ${lite_kernel_deps} math_arm) +lite_cc_library(softmax_compute_arm SRCS softmax_compute.cc DEPS ${lite_kernel_deps} math_arm) +lite_cc_library(conv_compute_arm SRCS conv_compute.cc DEPS ${lite_kernel_deps} math_arm) +lite_cc_library(batch_norm_compute_arm SRCS batch_norm_compute.cc DEPS ${lite_kernel_deps} math_arm) +lite_cc_library(elementwise_compute_arm SRCS elementwise_compute.cc DEPS ${lite_kernel_deps} math_arm) +lite_cc_library(lrn_compute_arm SRCS lrn_compute.cc DEPS ${lite_kernel_deps} math_arm) +lite_cc_library(decode_bboxes_compute_arm SRCS decode_bboxes_compute.cc DEPS ${lite_kernel_deps} math_arm) +lite_cc_library(multiclass_nms_compute_arm SRCS multiclass_nms_compute.cc DEPS ${lite_kernel_deps} math_arm) +lite_cc_library(pool_compute_arm SRCS pool_compute.cc DEPS ${lite_kernel_deps} math_arm) +lite_cc_library(split_compute_arm SRCS split_compute.cc DEPS ${lite_kernel_deps} math_arm) +lite_cc_library(concat_compute_arm SRCS concat_compute.cc DEPS ${lite_kernel_deps} math_arm) +lite_cc_library(pad2d_compute_arm SRCS pad2d_compute.cc DEPS ${lite_kernel_deps} math_arm) +lite_cc_library(prior_box_compute_arm SRCS prior_box_compute.cc DEPS ${lite_kernel_deps} math_arm) +lite_cc_library(density_prior_box_compute_arm SRCS density_prior_box_compute.cc DEPS ${lite_kernel_deps} math_arm) +lite_cc_library(negative_compute_arm SRCS negative_compute.cc DEPS ${lite_kernel_deps} math_arm) +lite_cc_library(crop_compute_arm SRCS crop_compute.cc DEPS ${lite_kernel_deps} math_arm) +lite_cc_library(dropout_compute_arm SRCS dropout_compute.cc DEPS ${lite_kernel_deps} math_arm) +lite_cc_library(calib_compute_arm SRCS calib_compute.cc DEPS 
${lite_kernel_deps} math_arm) +lite_cc_library(transpose_compute_arm SRCS transpose_compute.cc DEPS ${lite_kernel_deps} math_arm) +lite_cc_library(power_compute_arm SRCS power_compute.cc DEPS ${lite_kernel_deps} math_arm) +lite_cc_library(yolo_box_compute_arm SRCS yolo_box_compute.cc DEPS ${lite_kernel_deps} math_arm) +lite_cc_library(shuffle_channel_compute_arm SRCS shuffle_channel_compute.cc DEPS ${lite_kernel_deps} math_arm) +lite_cc_library(argmax_compute_arm SRCS argmax_compute.cc DEPS ${lite_kernel_deps} math_arm) +lite_cc_library(axpy_compute_arm SRCS axpy_compute.cc DEPS ${lite_kernel_deps} math_arm) +lite_cc_library(conv_transpose_compute_arm SRCS conv_transpose_compute.cc DEPS ${lite_kernel_deps} math_arm) +lite_cc_library(gru_unit_compute_arm SRCS gru_unit_compute.cc DEPS ${lite_kernel_deps} math_arm) +lite_cc_library(gru_compute_arm SRCS gru_compute.cc DEPS ${lite_kernel_deps} math_arm) +lite_cc_library(beam_search_decode_compute_arm SRCS beam_search_decode_compute.cc DEPS ${lite_kernel_deps} math_arm) +lite_cc_library(lookup_table_compute_arm SRCS lookup_table_compute.cc DEPS ${lite_kernel_deps} math_arm) +lite_cc_library(im2sequence_compute_arm SRCS im2sequence_compute.cc DEPS ${lite_kernel_deps} math_arm) +lite_cc_library(sequence_softmax_compute_arm SRCS sequence_softmax_compute.cc DEPS ${lite_kernel_deps} math_arm) +lite_cc_library(norm_compute_arm SRCS norm_compute.cc DEPS ${lite_kernel_deps} math_arm) +lite_cc_library(interpolate_compute_arm SRCS interpolate_compute.cc DEPS ${lite_kernel_deps} math_arm) +lite_cc_library(logical_compute_arm SRCS logical_compute.cc DEPS ${lite_kernel_deps} math_arm) +lite_cc_library(less_than_arm SRCS compare_compute.cc DEPS ${lite_kernel_deps} math_arm) +lite_cc_library(while_compute_arm SRCS while_compute.cc DEPS ${lite_kernel_deps} math_arm) +lite_cc_library(compare_compute_arm SRCS compare_compute.cc DEPS ${lite_kernel_deps} math_arm) +lite_cc_library(topk_compute_arm SRCS topk_compute.cc DEPS ${lite_kernel_deps} math_arm) +lite_cc_library(increment_compute_arm SRCS increment_compute.cc DEPS ${lite_kernel_deps} math_arm) +lite_cc_library(write_to_array_compute_arm SRCS write_to_array_compute.cc DEPS ${lite_kernel_deps} math_arm) +lite_cc_library(read_from_array_compute_arm SRCS read_from_array_compute.cc DEPS ${lite_kernel_deps} math_arm) +lite_cc_library(beam_search_compute_arm SRCS beam_search_compute.cc DEPS ${lite_kernel_deps} math_arm) +lite_cc_library(fill_constant_compute_arm SRCS fill_constant_compute.cc DEPS ${lite_kernel_deps} math_arm) +lite_cc_library(lod_reset_compute_arm SRCS lod_reset_compute.cc DEPS ${lite_kernel_deps} math_arm) +lite_cc_library(box_coder_compute_arm SRCS box_coder_compute.cc DEPS ${lite_kernel_deps} math_arm) +lite_cc_library(sequence_pool_compute_arm SRCS sequence_pool_compute.cc DEPS ${lite_kernel_deps} math_arm) +lite_cc_library(sequence_expand_compute_arm SRCS sequence_expand_compute.cc DEPS ${lite_kernel_deps} math_arm) +lite_cc_library(reduce_max_compute_arm SRCS reduce_max_compute.cc DEPS ${lite_kernel_deps} math_arm) +lite_cc_library(is_empty_compute_arm SRCS is_empty_compute.cc DEPS ${lite_kernel_deps} math_arm) +lite_cc_library(shape_compute_arm SRCS shape_compute.cc DEPS ${lite_kernel_deps} math_arm) +lite_cc_library(slice_compute_arm SRCS slice_compute.cc DEPS ${lite_kernel_deps} math_arm) +lite_cc_library(cast_compute_arm SRCS cast_compute.cc DEPS ${lite_kernel_deps} math_arm) + +lite_cc_test(test_fc_compute_arm SRCS fc_compute_test.cc DEPS fc_compute_arm math_arm) 
+lite_cc_test(test_scale_compute_arm SRCS scale_compute_test.cc DEPS scale_compute_arm) +lite_cc_test(test_softmax_compute_arm SRCS softmax_compute_test.cc DEPS softmax_compute_arm) +lite_cc_test(test_conv_compute_arm SRCS conv_compute_test.cc DEPS conv_compute_arm) +lite_cc_test(test_batch_norm_compute_arm SRCS batch_norm_compute_test.cc DEPS batch_norm_compute_arm) +lite_cc_test(test_elementwise_compute_arm SRCS elementwise_compute_test.cc DEPS elementwise_compute_arm) +lite_cc_test(test_lrn_compute_arm SRCS lrn_compute_test.cc DEPS lrn_compute_arm) +lite_cc_test(test_decode_bboxes_compute_arm SRCS decode_bboxes_compute_test.cc DEPS decode_bboxes_compute_arm) +lite_cc_test(test_multiclass_nms_compute_arm SRCS multiclass_nms_compute_test.cc DEPS multiclass_nms_compute_arm) +lite_cc_test(test_pool_compute_arm SRCS pool_compute_test.cc DEPS pool_compute_arm) +lite_cc_test(test_mul_compute_arm SRCS mul_compute_test.cc DEPS mul_compute_arm) +lite_cc_test(test_split_compute_arm SRCS split_compute_test.cc DEPS split_compute_arm) +lite_cc_test(test_concat_compute_arm SRCS concat_compute_test.cc DEPS concat_compute_arm) +lite_cc_test(test_dropout_compute_arm SRCS dropout_compute_test.cc DEPS dropout_compute_arm) +lite_cc_test(test_transpose_compute_arm SRCS transpose_compute_test.cc DEPS transpose_compute_arm) +lite_cc_test(test_argmax_compute_arm SRCS argmax_compute_test.cc DEPS argmax_compute_arm) +lite_cc_test(test_axpy_compute_arm SRCS axpy_compute_test.cc DEPS axpy_compute_arm) +lite_cc_test(test_conv_transpose_compute_arm SRCS conv_transpose_compute_test.cc DEPS conv_transpose_compute_arm) + + +set(arm_kernels + fc_compute_arm + activation_compute_arm + mul_compute_arm + scale_compute_arm + softmax_compute_arm + conv_compute_arm + batch_norm_compute_arm + elementwise_compute_arm + lrn_compute_arm + decode_bboxes_compute_arm + multiclass_nms_compute_arm + pool_compute_arm + split_compute_arm + concat_compute_arm + pad2d_compute_arm + prior_box_compute_arm + density_prior_box_compute_arm + negative_compute_arm + crop_compute_arm + dropout_compute_arm + transpose_compute_arm + calib_compute_arm + argmax_compute_arm + axpy_compute_arm + conv_transpose_compute_arm + gru_unit_compute_arm + gru_compute_arm + beam_search_decode_compute_arm + lookup_table_compute_arm + im2sequence_compute_arm + sequence_softmax_compute_arm + norm_compute_arm + power_compute_arm + shuffle_channel_compute_arm + yolo_box_compute_arm + interpolate_compute_arm + logical_compute_arm + less_than_arm + while_compute_arm + compare_compute_arm + topk_compute_arm + increment_compute_arm + write_to_array_compute_arm + read_from_array_compute_arm + beam_search_compute_arm + fill_constant_compute_arm + lod_reset_compute_arm + box_coder_compute_arm + reduce_max_compute_arm + sequence_expand_compute_arm + sequence_pool_compute_arm + is_empty_compute_arm + shape_compute_arm + slice_compute_arm + cast_compute_arm + ) + +set(arm_kernels "${arm_kernels}" CACHE INTERNAL "arm kernels") diff --git a/lite/kernels/arm/activation_compute.cc b/lite/kernels/arm/activation_compute.cc new file mode 100644 index 00000000000..6a56633965b --- /dev/null +++ b/lite/kernels/arm/activation_compute.cc @@ -0,0 +1,196 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/kernels/arm/activation_compute.h" +#include "lite/arm/math/funcs.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace arm { + +void ReluCompute::Run() { + auto& param = this->Param(); + auto& ctx = this->ctx_->template As(); + auto x_dims = param.X->dims(); + auto x_data = param.X->data(); + auto output_data = param.Out->mutable_data(); + lite::arm::math::act_relu( + x_data, output_data, x_dims.production(), ctx.threads()); +} + +void LeakyReluCompute::Run() { + auto& param = this->Param(); + auto& ctx = this->ctx_->template As(); + auto x_dims = param.X->dims(); + auto x_data = param.X->data(); + auto alpha = param.Leaky_relu_alpha; + auto output_data = param.Out->mutable_data(); + lite::arm::math::act_relu_neg( + x_data, output_data, x_dims.production(), alpha, ctx.threads()); +} + +void ReluClippedCompute::Run() { + auto& param = this->Param(); + auto& ctx = this->ctx_->template As(); + auto x_dims = param.X->dims(); + auto x_data = param.X->data(); + auto coef = param.Relu_clipped_coef; + auto output_data = param.Out->mutable_data(); + lite::arm::math::act_clipped_relu( + x_data, output_data, x_dims.production(), coef, ctx.threads()); +} + +void PReluCompute::Run() { + auto& param = this->Param(); + auto& ctx = this->ctx_->template As(); + auto x_dims = param.X->dims(); + auto x_data = param.X->data(); + auto mode = param.Prelu_mode; + auto alpha_data = param.Prelu_alpha->data(); + auto output_data = param.Out->mutable_data(); + + int outer_size = x_dims[0]; + int channel_size = x_dims[1]; + int inner_size = x_dims.count(2, x_dims.size()); + + lite::arm::math::act_prelu(x_data, + output_data, + outer_size, + channel_size, + inner_size, + mode, + alpha_data, + ctx.threads()); +} + +void SigmoidCompute::Run() { + auto& param = this->Param(); + auto& ctx = this->ctx_->template As(); + auto x_dims = param.X->dims(); + auto x_data = param.X->data(); + auto output_data = param.Out->mutable_data(); + lite::arm::math::act_sigmoid( + x_data, output_data, x_dims.production(), ctx.threads()); +} + +void TanhCompute::Run() { + auto& param = this->Param(); + auto& ctx = this->ctx_->template As(); + auto x_dims = param.X->dims(); + auto x_data = param.X->data(); + auto output_data = param.Out->mutable_data(); + lite::arm::math::act_tanh( + x_data, output_data, x_dims.production(), ctx.threads()); +} + +void SwishCompute::Run() { + auto& param = this->Param(); + auto& ctx = this->ctx_->template As(); + auto x_dims = param.X->dims(); + auto x_data = param.X->data(); + auto beta = param.Swish_beta; + auto output_data = param.Out->mutable_data(); + lite::arm::math::act_swish( + x_data, output_data, x_dims.production(), beta, ctx.threads()); +} + +void Relu6Compute::Run() { + auto& param = this->Param(); + auto& ctx = this->ctx_->template As(); + auto x_dims = param.X->dims(); + auto x_data = param.X->data(); + float coef = 6.; + auto output_data = param.Out->mutable_data(); + lite::arm::math::act_clipped_relu( + x_data, output_data, x_dims.production(), coef, ctx.threads()); +} + +void LogCompute::Run() { + auto& param = this->Param(); + auto& ctx 
= this->ctx_->template As<ARMContext>();
+  auto x_dims = param.X->dims();
+  auto x_data = param.X->data<float>();
+  auto output_data = param.Out->mutable_data<float>();
+  lite::arm::math::act_log<float>(
+      x_data, output_data, x_dims.production(), ctx.threads());
+}
+
+}  // namespace arm
+}  // namespace kernels
+}  // namespace lite
+}  // namespace paddle
+
+REGISTER_LITE_KERNEL(
+    relu, kARM, kFloat, kNCHW, paddle::lite::kernels::arm::ReluCompute, def)
+    .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))})
+    .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))})
+    .Finalize();
+REGISTER_LITE_KERNEL(leaky_relu,
+                     kARM,
+                     kFloat,
+                     kNCHW,
+                     paddle::lite::kernels::arm::LeakyReluCompute,
+                     def)
+    .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))})
+    .BindInput("alpha", {LiteType::GetTensorTy(TARGET(kARM))})
+    .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))})
+    .Finalize();
+REGISTER_LITE_KERNEL(relu_clipped,
+                     kARM,
+                     kFloat,
+                     kNCHW,
+                     paddle::lite::kernels::arm::ReluClippedCompute,
+                     def)
+    .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))})
+    .BindInput("Relu_clipped_coef", {LiteType::GetTensorTy(TARGET(kARM))})
+    .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))})
+    .Finalize();
+REGISTER_LITE_KERNEL(
+    prelu, kARM, kFloat, kNCHW, paddle::lite::kernels::arm::PReluCompute, def)
+    .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))})
+    .BindInput("mode", {LiteType::GetTensorTy(TARGET(kARM))})
+    .BindInput("Alpha", {LiteType::GetTensorTy(TARGET(kARM))})
+    .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))})
+    .Finalize();
+REGISTER_LITE_KERNEL(sigmoid,
+                     kARM,
+                     kFloat,
+                     kNCHW,
+                     paddle::lite::kernels::arm::SigmoidCompute,
+                     def)
+    .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))})
+    .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))})
+    .Finalize();
+REGISTER_LITE_KERNEL(
+    tanh, kARM, kFloat, kNCHW, paddle::lite::kernels::arm::TanhCompute, def)
+    .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))})
+    .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))})
+    .Finalize();
+REGISTER_LITE_KERNEL(
+    swish, kARM, kFloat, kNCHW, paddle::lite::kernels::arm::SwishCompute, def)
+    .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))})
+    .BindInput("beta", {LiteType::GetTensorTy(TARGET(kARM))})
+    .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))})
+    .Finalize();
+REGISTER_LITE_KERNEL(
+    relu6, kARM, kFloat, kNCHW, paddle::lite::kernels::arm::Relu6Compute, def)
+    .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))})
+    .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))})
+    .Finalize();
+REGISTER_LITE_KERNEL(
+    log, kARM, kFloat, kNCHW, paddle::lite::kernels::arm::LogCompute, def)
+    .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))})
+    .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))})
+    .Finalize();
diff --git a/lite/kernels/arm/activation_compute.h b/lite/kernels/arm/activation_compute.h
new file mode 100644
index 00000000000..9360528b812
--- /dev/null
+++ b/lite/kernels/arm/activation_compute.h
@@ -0,0 +1,108 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include +#include "lite/core/kernel.h" +#include "lite/core/op_registry.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace arm { + +class ReluCompute : public KernelLite { + public: + using param_t = operators::ActivationParam; + + void Run() override; + + virtual ~ReluCompute() = default; +}; + +class LeakyReluCompute : public KernelLite { + public: + using param_t = operators::ActivationParam; + + void Run() override; + + virtual ~LeakyReluCompute() = default; +}; + +class ReluClippedCompute : public KernelLite { + public: + using param_t = operators::ActivationParam; + + void Run() override; + + virtual ~ReluClippedCompute() = default; +}; + +class PReluCompute : public KernelLite { + public: + using param_t = operators::ActivationParam; + + void Run() override; + + virtual ~PReluCompute() = default; +}; + +class SigmoidCompute : public KernelLite { + public: + using param_t = operators::ActivationParam; + + void Run() override; + + virtual ~SigmoidCompute() = default; +}; + +class TanhCompute : public KernelLite { + public: + using param_t = operators::ActivationParam; + + void Run() override; + + virtual ~TanhCompute() = default; +}; + +class SwishCompute : public KernelLite { + public: + using param_t = operators::ActivationParam; + + void Run() override; + + virtual ~SwishCompute() = default; +}; + +class Relu6Compute : public KernelLite { + public: + using param_t = operators::ActivationParam; + + void Run() override; + + virtual ~Relu6Compute() = default; +}; + +class LogCompute : public KernelLite { + public: + using param_t = operators::ActivationParam; + + void Run() override; + + virtual ~LogCompute() = default; +}; +} // namespace arm +} // namespace kernels +} // namespace lite +} // namespace paddle diff --git a/lite/kernels/arm/argmax_compute.cc b/lite/kernels/arm/argmax_compute.cc new file mode 100644 index 00000000000..5087038ff45 --- /dev/null +++ b/lite/kernels/arm/argmax_compute.cc @@ -0,0 +1,47 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
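
The activation kernels above all follow the same pattern; a hedged sketch of adding another one (the op name my_act and class MyActCompute are hypothetical):

    class MyActCompute : public KernelLite<TARGET(kARM), PRECISION(kFloat)> {
     public:
      using param_t = operators::ActivationParam;
      void Run() override {
        auto& param = this->Param<param_t>();
        const float* x = param.X->data<float>();
        float* out = param.Out->mutable_data<float>();
        for (int64_t i = 0; i < param.X->dims().production(); ++i) {
          out[i] = x[i] > 0.f ? x[i] : 0.f;  // plain ReLU as a stand-in body
        }
      }
      virtual ~MyActCompute() = default;
    };

    REGISTER_LITE_KERNEL(my_act, kARM, kFloat, kNCHW,
                         paddle::lite::kernels::arm::MyActCompute, def)
        .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))})
        .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))})
        .Finalize();
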
+ +#include "lite/kernels/arm/argmax_compute.h" +#include +#include +#include "lite/arm/math/funcs.h" +#include "lite/core/op_registry.h" +#include "lite/core/tensor.h" +#include "lite/core/type_system.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace arm { + +void ArgmaxCompute::Run() { + auto& param = Param(); + lite::Tensor* input = param.X; + lite::Tensor* output = param.Out; + int axis = param.Axis; + + lite::arm::math::argmax_func(input, axis, output); + return; +} + +} // namespace arm +} // namespace kernels +} // namespace lite +} // namespace paddle + +REGISTER_LITE_KERNEL( + argmax, kARM, kFloat, kNCHW, paddle::lite::kernels::arm::ArgmaxCompute, def) + .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))}) + .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))}) + .Finalize(); diff --git a/lite/kernels/arm/argmax_compute.h b/lite/kernels/arm/argmax_compute.h new file mode 100644 index 00000000000..c87f5a451bc --- /dev/null +++ b/lite/kernels/arm/argmax_compute.h @@ -0,0 +1,37 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include +#include "lite/core/kernel.h" +#include "lite/operators/argmax_op.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace arm { + +class ArgmaxCompute : public KernelLite { + public: + using param_t = operators::ArgmaxParam; + + void Run() override; + + virtual ~ArgmaxCompute() = default; +}; + +} // namespace arm +} // namespace kernels +} // namespace lite +} // namespace paddle diff --git a/lite/kernels/arm/argmax_compute_test.cc b/lite/kernels/arm/argmax_compute_test.cc new file mode 100644 index 00000000000..ee603efa86a --- /dev/null +++ b/lite/kernels/arm/argmax_compute_test.cc @@ -0,0 +1,139 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "lite/kernels/arm/argmax_compute.h" +#include +#include +#include +#include +#include +#include +#include "lite/core/op_registry.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace arm { + +template +void argmax_compute_ref(const operators::ArgmaxParam& param) { + lite::Tensor* x = param.X; + lite::Tensor* output = param.Out; + int axis = param.Axis; + + auto x_data = x->data(); + auto output_data = output->mutable_data(); + DDim x_dims = x->dims(); + DDim output_dims = output->dims(); + + // int in_channel = x_dims + const int size = x_dims[axis]; + const int in_channel = x_dims.count(axis, x_dims.size()); + const int out_channel = output_dims.count(axis, output_dims.size()); + const int in_stride = x_dims.count(axis + 1, x_dims.size()); + const int out_stride = x_dims.count(0, axis); + + for (int n = 0; n < out_stride; n++) { + for (int k = 0; k < in_stride; k++) { + const dtype* in_ptr = x_data + n * in_channel + k; + std::vector> vec; + vec.resize(size); + for (int i = 0; i < size; i++) { + vec[i] = std::make_pair(in_ptr[i * in_stride], i); + } + // sort + std::partial_sort(vec.begin(), + vec.begin() + 1, + vec.end(), + std::greater>()); + + // out + dtype* out_ptr = output_data + n * out_channel + k; + *out_ptr = vec[0].second; + } + } +} + +TEST(argmax_arm, retrive_op) { + auto argmax = + KernelRegistry::Global().Create( + "argmax"); + ASSERT_FALSE(argmax.empty()); + ASSERT_TRUE(argmax.front()); +} + +TEST(argmax_arm, init) { + ArgmaxCompute argmax; + ASSERT_EQ(argmax.precision(), PRECISION(kFloat)); + ASSERT_EQ(argmax.target(), TARGET(kARM)); +} +TEST(argmax_arm, compute) { + DeviceInfo::Init(); + for (auto n : {2, 3}) { + for (auto c : {3, 4 /*, 128*/}) { + for (auto h : {4, 5 /*, 56 , 112, 224, 512*/}) { + for (auto w : {5, 6 /*, 56, 112, 224, 512*/}) { + Tensor x; + Tensor output; + Tensor output_ref; + int axis = (n + c + h + w) % 4; + + // get tensor x data + x.Resize({n, c, h, w}); + auto* x_data = x.mutable_data(); + for (int i = 0; i < x.dims().production(); i++) { + float sign = i % 3 == 0 ? -1.0f : 1.0f; + x_data[i] = sign * static_cast(i % 128) * 0.013f; + } + + // resize output and output_ref + int nchw[] = {n, c, h, w}; + std::vector output_size(nchw, nchw + 4); + output_size.erase(output_size.begin() + axis); + output.Resize(output_size); + output_ref.Resize(output_size); + + // obtain output_data + ArgmaxCompute argmaxOp; + std::unique_ptr ctx(new KernelContext); + ctx->As(); + argmaxOp.SetContext(std::move(ctx)); + operators::ArgmaxParam param; + param.X = &x; + param.Out = &output; + param.Axis = axis; + argmaxOp.SetParam(param); + argmaxOp.Launch(); + auto* output_data = output.mutable_data(); + + // obtain output_ref_data + param.Out = &output_ref; + argmax_compute_ref(param); + auto* output_ref_data = output_ref.mutable_data(); + + // compare + for (int i = 0; i < output.dims().production(); i++) { + EXPECT_NEAR(output_data[i], output_ref_data[i], 1e-5); + } + } + } + } + } +} + +} // namespace arm +} // namespace kernels +} // namespace lite +} // namespace paddle +USE_LITE_KERNEL(argmax, kARM, kFloat, kNCHW, def); diff --git a/lite/kernels/arm/axpy_compute.cc b/lite/kernels/arm/axpy_compute.cc new file mode 100644 index 00000000000..2cd4435d6af --- /dev/null +++ b/lite/kernels/arm/axpy_compute.cc @@ -0,0 +1,62 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/kernels/arm/axpy_compute.h" +#include +#include +#include "lite/arm/math/funcs.h" +#include "lite/core/op_registry.h" +#include "lite/core/tensor.h" +#include "lite/core/type_system.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace arm { + +void AxpyCompute::Run() { + auto& param = Param(); + lite::Tensor* scale = param.Scale; + lite::Tensor* x = param.X; + lite::Tensor* bias = param.Bias; + lite::Tensor* out = param.Out; + + const float* scale_ptr = scale->data(); + const float* x_ptr = x->data(); + const float* bias_ptr = bias->data(); + float* out_ptr = out->mutable_data(); + + auto bias_dims = bias->dims(); + int num = bias_dims[0]; + int channel = bias_dims[1]; + int size = bias_dims[2] * bias_dims[3]; + int in_channel = channel * size; + + lite::arm::math::axpy_kernel_fp32( + scale_ptr, x_ptr, bias_ptr, out_ptr, num, channel, size, in_channel); + return; +} + +} // namespace arm +} // namespace kernels +} // namespace lite +} // namespace paddle + +REGISTER_LITE_KERNEL( + axpy, kARM, kFloat, kNCHW, paddle::lite::kernels::arm::AxpyCompute, def) + .BindInput("Scale", {LiteType::GetTensorTy(TARGET(kARM))}) + .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))}) + .BindInput("Bias", {LiteType::GetTensorTy(TARGET(kARM))}) + .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))}) + .Finalize(); diff --git a/lite/kernels/arm/axpy_compute.h b/lite/kernels/arm/axpy_compute.h new file mode 100644 index 00000000000..29983bdb993 --- /dev/null +++ b/lite/kernels/arm/axpy_compute.h @@ -0,0 +1,37 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include +#include "lite/core/kernel.h" +#include "lite/operators/axpy_op.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace arm { + +class AxpyCompute : public KernelLite { + public: + using param_t = operators::AxpyParam; + + void Run() override; + + virtual ~AxpyCompute() = default; +}; + +} // namespace arm +} // namespace kernels +} // namespace lite +} // namespace paddle diff --git a/lite/kernels/arm/axpy_compute_test.cc b/lite/kernels/arm/axpy_compute_test.cc new file mode 100644 index 00000000000..af145435ebe --- /dev/null +++ b/lite/kernels/arm/axpy_compute_test.cc @@ -0,0 +1,142 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/kernels/arm/axpy_compute.h" +#include +#include +#include +#include +#include +#include "lite/core/op_registry.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace arm { + +template +void axpy_compute_ref(const operators::AxpyParam& param) { + lite::Tensor* scale = param.Scale; + lite::Tensor* x = param.X; + lite::Tensor* bias = param.Bias; + lite::Tensor* output = param.Out; + + auto scale_data = scale->data(); + auto x_data = x->data(); + auto bias_data = bias->data(); + auto output_data = output->mutable_data(); + + DDim x_dims = x->dims(); + int num = x_dims[0]; + int channel = x_dims[1]; + int size = x_dims[2] * x_dims[3]; + int in_channel = channel * size; + + for (int i = 0; i < num; i++) { + auto scale_data_i = scale_data + i * channel; + auto x_data_i = x_data + i * in_channel; + auto bias_data_i = bias_data + i * in_channel; + auto output_data_i = output_data + i * in_channel; + for (int j = 0; j < channel; j++) { + auto scale_data_j = scale_data_i + j; + auto x_data_j = x_data_i + j * size; + auto bias_data_j = bias_data_i + j * size; + auto output_data_j = output_data_i + j * size; + for (int k = 0; k < size; k++) { + output_data_j[k] = scale_data_j[0] * x_data_j[k] + bias_data_j[k]; + } + } + } +} + +TEST(axpy_arm, retrive_op) { + auto axpy = + KernelRegistry::Global().Create("axpy"); + ASSERT_FALSE(axpy.empty()); + ASSERT_TRUE(axpy.front()); +} + +TEST(axpy_arm, init) { + AxpyCompute axpy; + ASSERT_EQ(axpy.precision(), PRECISION(kFloat)); + ASSERT_EQ(axpy.target(), TARGET(kARM)); +} +TEST(axpy_arm, compute) { + DeviceInfo::Init(); + int iter = 10; + for (int i = 0; i < iter; i++) { + Tensor scale; + Tensor x; + Tensor bias; + Tensor output; + Tensor output_ref; + + // set the dims of scale, x, bias and output_ref + int n = 2, c = 3, h = 4, w = 5; + scale.Resize({n, c}); + x.Resize({n, c, h, w}); + bias.Resize({n, c, h, w}); + output.Resize({n, c, h, w}); + output_ref.Resize({n, c, h, w}); + + // initialize the data of scale, x, bias + // initialize_random_data(scale); + // initialize_random_data(x); + // initialize_random_data(bias); + auto* scale_data = scale.mutable_data(); + for (int i = 0; i < scale.dims().production(); i++) { + float sign = i % 3 == 0 ? -1.0f : 1.0f; + scale_data[i] = sign * static_cast(i % 128) * 0.010f; + } + auto* x_data = x.mutable_data(); + for (int i = 0; i < x.dims().production(); i++) { + float sign = i % 4 == 0 ? -1.0f : 1.0f; + x_data[i] = sign * static_cast(i % 128) * 0.007f; + } + auto* bias_data = bias.mutable_data(); + for (int i = 0; i < bias.dims().production(); i++) { + float sign = i % 5 == 0 ? 
-1.0f : 1.0f; + bias_data[i] = sign * static_cast(i % 128) * 0.005f; + } + + // prepare kernel params and run to obtain output_data + AxpyCompute axpy_op; + std::unique_ptr ctx(new KernelContext); + ctx->As(); + axpy_op.SetContext(std::move(ctx)); + operators::AxpyParam param; + param.Scale = &scale; + param.X = &x; + param.Bias = &bias; + param.Out = &output; + axpy_op.SetParam(param); + axpy_op.Launch(); + auto* output_data = output.mutable_data(); + + // invoking ref implementation and compare results + param.Out = &output_ref; + axpy_compute_ref(param); + auto* output_ref_data = output_ref.mutable_data(); + + for (int i = 0; i < output.dims().production(); i++) { + EXPECT_NEAR(output_data[i], output_ref_data[i], 1e-5); + } + } +} + +} // namespace arm +} // namespace kernels +} // namespace lite +} // namespace paddle +USE_LITE_KERNEL(axpy, kARM, kFloat, kNCHW, def); diff --git a/lite/kernels/arm/batch_norm_compute.cc b/lite/kernels/arm/batch_norm_compute.cc new file mode 100644 index 00000000000..855f873eee5 --- /dev/null +++ b/lite/kernels/arm/batch_norm_compute.cc @@ -0,0 +1,123 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/kernels/arm/batch_norm_compute.h" +#include "lite/arm/math/funcs.h" +#include "lite/core/op_registry.h" +#include "lite/core/type_system.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace arm { + +void BatchNormCompute::PrepareForRun() { + auto& param = this->Param(); + auto x_dims = param.x->dims(); + bool global_stats = param.is_test || param.use_global_stats; + if (global_stats) { + int64_t channel_size = 0; + switch (param.data_layout) { + case DATALAYOUT(kNCHW): + channel_size = x_dims[1]; + break; + // case DATALAYOUT(kNHWC): + // channel_size = x_dims[x_dims.size() - 1]; + // break; + default: + LOG(FATAL) << "Unknown storage order: " + << DataLayoutToStr(param.data_layout); + break; + } + new_scale.Resize({channel_size}); + new_bias.Resize({channel_size}); + auto* scale_data = param.scale->mutable_data(); + auto* bias_data = param.bias->mutable_data(); + auto* mean_data = param.mean->mutable_data(); + auto* variance_data = param.variance->mutable_data(); + auto* new_scale_data = new_scale.mutable_data(); + auto* new_bias_data = new_bias.mutable_data(); + for (int c = 0; c < channel_size; c++) { + float inv_scale = 1.f / (std::sqrt(variance_data[c] + param.epsilon)); + new_bias_data[c] = + bias_data[c] - inv_scale * scale_data[c] * mean_data[c]; + new_scale_data[c] = inv_scale * scale_data[c]; + } + } +} + +void BatchNormCompute::Run() { + auto& param = this->Param(); + auto x_dims = param.x->dims(); + auto x_data = param.x->mutable_data(); + auto y_data = param.y->mutable_data(); + bool global_stats = param.is_test || param.use_global_stats; + if (global_stats) { + auto* new_scale_data = new_scale.mutable_data(); + auto* new_bias_data = new_bias.mutable_data(); + int64_t outer_size = 0; + int64_t channel_size = 0; + int64_t inner_size = 
0; + switch (param.data_layout) { + case DATALAYOUT(kNCHW): + outer_size = x_dims[0]; + channel_size = x_dims[1]; + inner_size = x_dims.Slice(2, x_dims.size()).production(); + lite::arm::math::scale(x_data, + y_data, + outer_size, + channel_size, + inner_size, + new_scale_data, + new_bias_data); + break; + // case DATALAYOUT(kNHWC): + // outer_size = x_dims.Slice(0, x_dims.size() - 1).production(); + // channel_size = x_dims[x_dims.size() - 1]; + // lite::arm::math::scale(x_data, y_data, outer_size, channel_size, + // new_scale_data, new_bias_data); + // break; + default: + LOG(FATAL) << "Unknown storage order: " + << DataLayoutToStr(param.data_layout); + break; + } + } else { + // TODO(hong19860320) calculate mean_out, variance_out, saved_mean and + // saved_variance + } +} + +} // namespace arm +} // namespace kernels +} // namespace lite +} // namespace paddle + +REGISTER_LITE_KERNEL(batch_norm, + kARM, + kFloat, + kNCHW, + paddle::lite::kernels::arm::BatchNormCompute, + def) + .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))}) + .BindInput("Scale", {LiteType::GetTensorTy(TARGET(kARM))}) + .BindInput("Bias", {LiteType::GetTensorTy(TARGET(kARM))}) + .BindInput("Mean", {LiteType::GetTensorTy(TARGET(kARM))}) + .BindInput("Variance", {LiteType::GetTensorTy(TARGET(kARM))}) + .BindOutput("Y", {LiteType::GetTensorTy(TARGET(kARM))}) + .BindOutput("MeanOut", {LiteType::GetTensorTy(TARGET(kARM))}) + .BindOutput("VarianceOut", {LiteType::GetTensorTy(TARGET(kARM))}) + .BindOutput("SavedMean", {LiteType::GetTensorTy(TARGET(kARM))}) + .BindOutput("SavedVariance", {LiteType::GetTensorTy(TARGET(kARM))}) + .Finalize(); diff --git a/lite/kernels/arm/batch_norm_compute.h b/lite/kernels/arm/batch_norm_compute.h new file mode 100644 index 00000000000..22553f55d5d --- /dev/null +++ b/lite/kernels/arm/batch_norm_compute.h @@ -0,0 +1,42 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include "lite/core/kernel.h" +#include "lite/core/op_registry.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace arm { + +class BatchNormCompute : public KernelLite { + public: + using param_t = operators::BatchNormParam; + + void PrepareForRun() override; + + void Run() override; + + virtual ~BatchNormCompute() = default; + + private: + Tensor new_scale; + Tensor new_bias; +}; + +} // namespace arm +} // namespace kernels +} // namespace lite +} // namespace paddle diff --git a/lite/kernels/arm/batch_norm_compute_test.cc b/lite/kernels/arm/batch_norm_compute_test.cc new file mode 100644 index 00000000000..c603a04d470 --- /dev/null +++ b/lite/kernels/arm/batch_norm_compute_test.cc @@ -0,0 +1,221 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/kernels/arm/batch_norm_compute.h" +#include +#include +#include +#include +#include "lite/core/op_registry.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace arm { + +template +void batch_norm_compute_ref(const operators::BatchNormParam& param) { + DDim x_dims = param.x->dims(); + auto x_data = param.x->mutable_data(); + auto scale_data = param.scale->mutable_data(); + auto bias_data = param.bias->mutable_data(); + auto mean_data = param.mean->mutable_data(); + auto variance_data = param.variance->mutable_data(); + auto y_data = param.y->mutable_data(); + float epsilon = param.epsilon; + float momentum = param.momentum; + DataLayoutType data_layout = param.data_layout; + + bool global_stats = param.is_test || param.use_global_stats; + if (global_stats) { + int64_t outer_size = 0; + int64_t channel_size = 0; + int64_t inner_size = 0; + switch (data_layout) { + case DATALAYOUT(kNCHW): + outer_size = x_dims[0]; + channel_size = x_dims[1]; + inner_size = x_dims.Slice(2, x_dims.size()).production(); + break; + // case DATALAYOUT(kNHWC): + // outer_size = x_dims.Slice(0, x_dims.size() - 1).production(); + // channel_size = x_dims[x_dims.size() - 1]; + // inner_size = 1; + // break; + default: + LOG(FATAL) << "Unknown storage order: " << DataLayoutToStr(data_layout); + break; + } + auto x_ptr = x_data; + auto y_ptr = y_data; + for (int o = 0; o < outer_size; o++) { + for (int c = 0; c < channel_size; c++) { + for (int i = 0; i < inner_size; i++) { + dtype norm_x = + (*x_ptr - mean_data[c]) / std::sqrt(variance_data[c] + epsilon); + *y_ptr = norm_x * scale_data[c] + bias_data[c]; + x_ptr++; + y_ptr++; + } + } + } + } else { + // TODO(hong19860320) calculate mean_out, variance_out, saved_mean and + // saved_variance + } +} + +TEST(batch_norm_arm, retrive_op) { + auto batch_norm = + KernelRegistry::Global().Create( + "batch_norm"); + ASSERT_FALSE(batch_norm.empty()); + ASSERT_TRUE(batch_norm.front()); +} + +TEST(batch_norm_arm, init) { + BatchNormCompute batch_norm; + ASSERT_EQ(batch_norm.precision(), PRECISION(kFloat)); + ASSERT_EQ(batch_norm.target(), TARGET(kARM)); +} + +TEST(batch_norm_arm, compute) { + DeviceInfo::Init(); + for (auto n : {1, 2}) { + for (auto c : {6, 32 /*, 128*/}) { + for (auto h : {9, 18 /*, 56 , 112, 224, 512*/}) { + for (auto w : {9, 18 /*, 56, 112, 224, 512*/}) { + for (auto is_test : {/*false, */ true}) { + for (auto use_global_stats : {false, true}) { + for (auto epsilon : {1e-4f, 1e-5f}) { + for (auto momentum : {0.9f, 0.99f}) { + for (auto data_layout : + {DATALAYOUT(kNCHW) /*, DATALAYOUT(kNHWC)*/}) { + Tensor x; + Tensor scale; + Tensor bias; + Tensor mean; + Tensor variance; + Tensor y; + Tensor mean_out; + Tensor variance_out; + Tensor saved_mean; + Tensor saved_variance; + Tensor y_ref; + Tensor mean_out_ref; + Tensor variance_out_ref; + Tensor saved_mean_ref; + Tensor saved_variance_ref; + // set the dims of input, output, ref output tensors + std::vector in_out_shape; + switch (data_layout) { + case DATALAYOUT(kNCHW): + in_out_shape = {n, c, h, w}; + break; + // case DATALAYOUT(kNHWC): + // in_out_shape = 
{n, h, w, c}; + // break; + default: + LOG(FATAL) << "Unknown storage order: " + << DataLayoutToStr(data_layout); + break; + } + x.Resize(in_out_shape); + scale.Resize({c}); + bias.Resize({c}); + mean.Resize({c}); + variance.Resize({c}); + y.Resize(in_out_shape); + mean_out.Resize({c}); + variance_out.Resize({c}); + saved_mean.Resize({c}); + saved_variance.Resize({c}); + y_ref.Resize(in_out_shape); + mean_out_ref.Resize({c}); + variance_out_ref.Resize({c}); + saved_mean_ref.Resize({c}); + saved_variance_ref.Resize({c}); + // initialize the data of input tensors + auto* x_data = x.mutable_data(); + auto* scale_data = scale.mutable_data(); + auto* bias_data = bias.mutable_data(); + auto* mean_data = mean.mutable_data(); + auto* variance_data = variance.mutable_data(); + auto* y_data = y.mutable_data(); + for (int i = 0; i < x.dims().production(); i++) { + x_data[i] = static_cast(i % 64); + } + for (int i = 0; i < scale.dims().production(); i++) { + scale_data[i] = static_cast(i) * 0.01f + 0.03f; + } + for (int i = 0; i < bias.dims().production(); i++) { + bias_data[i] = static_cast(i) * 0.065f + 0.1f; + } + for (int i = 0; i < mean.dims().production(); i++) { + mean_data[i] = static_cast(i) * 0.0565f; + } + for (int i = 0; i < variance.dims().production(); i++) { + variance_data[i] = static_cast(i) * 2.08f + 1.5f; + } + // prepare kernel params and run + BatchNormCompute batch_norm; + std::unique_ptr ctx(new KernelContext); + ctx->As(); + batch_norm.SetContext(std::move(ctx)); + operators::BatchNormParam param; + param.x = &x; + param.scale = &scale; + param.bias = &bias; + param.mean = &mean; + param.variance = &variance; + param.is_test = is_test; + param.use_global_stats = use_global_stats; + param.epsilon = epsilon; + param.momentum = momentum; + param.data_layout = data_layout; + param.y = &y; + param.mean_out = &mean_out; + param.variance_out = &variance_out; + param.saved_mean = &saved_mean; + param.saved_variance = &saved_variance; + batch_norm.SetParam(param); + batch_norm.Launch(); + // invoking ref implementation and compare results + param.y = &y_ref; + param.mean_out = &mean_out_ref; + param.variance_out = &variance_out_ref; + param.saved_mean = &saved_mean_ref; + param.saved_variance = &saved_variance_ref; + batch_norm_compute_ref(param); + auto* y_ref_data = y_ref.mutable_data(); + for (int i = 0; i < y.dims().production(); i++) { + EXPECT_NEAR(y_data[i], y_ref_data[i], 1e-5); + } + } + } + } + } + } + } + } + } + } +} + +} // namespace arm +} // namespace kernels +} // namespace lite +} // namespace paddle + +USE_LITE_KERNEL(batch_norm, kARM, kFloat, kNCHW, def); diff --git a/lite/kernels/arm/beam_search_compute.cc b/lite/kernels/arm/beam_search_compute.cc new file mode 100644 index 00000000000..6be5d680b4e --- /dev/null +++ b/lite/kernels/arm/beam_search_compute.cc @@ -0,0 +1,60 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
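[Editor's note] One note on the batch_norm kernel above before the beam_search files: PrepareForRun() folds the four per-channel statistics into a single scale/bias pair so that Run() reduces to one fused multiply-add per element. The algebra being relied on, checked numerically for a single channel with arbitrary test values:

    #include <cassert>
    #include <cmath>

    // gamma*(x - mean)/sqrt(var + eps) + beta
    //   == x*(gamma/sqrt(var + eps)) + (beta - gamma*mean/sqrt(var + eps))
    int main() {
      const float x = 3.f, mean = 0.5f, var = 2.f;
      const float gamma = 1.2f, beta = -0.3f, eps = 1e-5f;
      const float inv_std = 1.f / std::sqrt(var + eps);
      const float canonical = gamma * (x - mean) * inv_std + beta;
      const float new_scale = gamma * inv_std;               // as PrepareForRun
      const float new_bias = beta - gamma * mean * inv_std;  // as PrepareForRun
      assert(std::fabs(canonical - (new_scale * x + new_bias)) < 1e-6f);
      return 0;
    }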
+ +#include "lite/kernels/arm/beam_search_compute.h" +#include "lite/arm/math/funcs.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace arm { + +void BeamSearchCompute::PrepareForRun() {} + +void BeamSearchCompute::Run() { + auto& ctx = this->ctx_->template As(); + auto& param = this->Param(); + lite::arm::math::beam_search(param.pre_ids, + param.pre_scores, + param.ids, + param.scores, + param.selected_ids, + param.selected_scores, + param.parent_idx, + param.level, + param.beam_size, + param.end_id, + param.is_accumulated, + &ctx); +} + +} // namespace arm +} // namespace kernels +} // namespace lite +} // namespace paddle + +REGISTER_LITE_KERNEL(beam_search, + kARM, + kFloat, + kNCHW, + paddle::lite::kernels::arm::BeamSearchCompute, + def) + .BindInput("pre_ids", {LiteType::GetTensorTy(TARGET(kARM))}) + .BindInput("pre_scores", {LiteType::GetTensorTy(TARGET(kARM))}) + .BindInput("ids", {LiteType::GetTensorTy(TARGET(kARM))}) + .BindInput("scores", {LiteType::GetTensorTy(TARGET(kARM))}) + .BindOutput("selected_ids", {LiteType::GetTensorTy(TARGET(kARM))}) + .BindOutput("selected_scores", {LiteType::GetTensorTy(TARGET(kARM))}) + .BindOutput("parent_idx", {LiteType::GetTensorTy(TARGET(kARM))}) + .Finalize(); diff --git a/lite/kernels/arm/beam_search_compute.h b/lite/kernels/arm/beam_search_compute.h new file mode 100644 index 00000000000..ef150ba74ac --- /dev/null +++ b/lite/kernels/arm/beam_search_compute.h @@ -0,0 +1,42 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include +#include "lite/arm/math/type_trans.h" +#include "lite/core/kernel.h" +#include "lite/core/op_registry.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace arm { + +class BeamSearchCompute : public KernelLite { + public: + using param_t = operators::BeamSearchParam; + + void PrepareForRun() override; + + void Run() override; + + ~BeamSearchCompute() {} + + private: +}; + +} // namespace arm +} // namespace kernels +} // namespace lite +} // namespace paddle diff --git a/lite/kernels/arm/beam_search_decode_compute.cc b/lite/kernels/arm/beam_search_decode_compute.cc new file mode 100644 index 00000000000..d0640e2c3be --- /dev/null +++ b/lite/kernels/arm/beam_search_decode_compute.cc @@ -0,0 +1,296 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "lite/kernels/arm/beam_search_decode_compute.h" +#include +#include +#include "lite/api/paddle_place.h" +#include "lite/arm/math/funcs.h" +#include "lite/core/op_registry.h" +#include "lite/core/tensor.h" +#include "lite/core/type_system.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace arm { + +using LoDTensor = lite::Tensor; +using LoDTensorArray = std::vector; + +// all the lod have 2 levels. +// The first is source level, the second is sentence level. +// source level describe how many prefixes (branchs) for each source sentece +// (beam). sentence level describe how these candidates belong to the prefixes. +const size_t kSourceLevel = 0; +const size_t kSentenceLevel = 1; + +template +struct Sentence { + std::vector word_ids; + std::vector scores; +}; + +template +using SentenceVector = std::vector>; + +template +struct BeamSearchDecoder { + BeamSearchDecoder(size_t beam_size, int end_id) + : beam_size_(beam_size), end_id_(end_id) {} + + /** + * convert the result sentence_vector for each source sentence into two + * LodTensor. + * One is all candidate sentences with word id, one is all candidate sentences + * with word score. + * Param: + * sentence_vector_list: sentence_vector for each source sentence. + * id_tensor: result LoDTensor for sentences of id. + * score_tensor: result LoDTensor for sentences of score. + * reverse: whether ids of sentence in sentence_vector_list is reversed + * sort_by_score: whether to sort hypotheses of each sentence by scores. + */ + void ConvertSentenceVectorToLodTensor( + std::vector> sentence_vector_list, + LoDTensor* id_tensor, + LoDTensor* score_tensor, + bool reverse = true, + bool sort_by_score = true) const { + size_t src_num = sentence_vector_list.size(); + CHECK_GT(src_num, 0) << "src_num should not be 0"; + + std::vector source_level_lod = {0}; + std::vector sentence_level_lod = {0}; + std::vector id_data; + std::vector score_data; + + for (size_t src_idx = 0; src_idx < src_num; ++src_idx) { + if (sort_by_score) { + sort(sentence_vector_list[src_idx].begin(), + sentence_vector_list[src_idx].end(), + [reverse](const Sentence& a, const Sentence& b) { + if (reverse) + return a.scores.front() > b.scores.front(); + else + return a.scores.back() > b.scores.back(); + }); + } + for (Sentence& sentence : sentence_vector_list[src_idx]) { + if (reverse) { + id_data.insert(id_data.end(), + sentence.word_ids.rbegin(), + sentence.word_ids.rend()); + score_data.insert(score_data.end(), + sentence.scores.rbegin(), + sentence.scores.rend()); + } else { + id_data.insert(id_data.end(), + sentence.word_ids.begin(), + sentence.word_ids.end()); + score_data.insert( + score_data.end(), sentence.scores.begin(), sentence.scores.end()); + } + + sentence_level_lod.push_back(sentence_level_lod.back() + + sentence.word_ids.size()); + } + source_level_lod.push_back(source_level_lod.back() + + sentence_vector_list[src_idx].size()); + } + + LoD lod; + lod.push_back(source_level_lod); + lod.push_back(sentence_level_lod); + + *(id_tensor->mutable_lod()) = lod; + + id_tensor->Resize({static_cast(id_data.size())}); + auto id_ptr = id_tensor->mutable_data(); + TargetCopy( + TARGET(kARM), id_ptr, id_data.data(), id_data.size() * sizeof(float)); + + *(score_tensor->mutable_lod()) = lod; + score_tensor->Resize({static_cast(score_data.size())}); + auto score_ptr = score_tensor->mutable_data(); + TargetCopy(TARGET(kARM), + score_ptr, + score_data.data(), + score_data.size() * sizeof(T)); + } + + /** + * Gather the hypotheses for each source 
sentence by backtrace though the + * LoDTensorArray step_ids whose lods reserve the path in the tree. + */ + void Backtrace(const LoDTensorArray& step_ids, + const LoDTensorArray& step_scores, + LoDTensor* id_tensor, + LoDTensor* score_tensor) const { + CHECK(!step_ids.empty()) << "step num should be larger than 0"; + CHECK_EQ(step_ids.size(), step_scores.size()) + << "step_ids and step_scores should be the same"; + const size_t step_num = step_ids.size(); + const size_t src_num = step_ids.at(0).lod().at(kSourceLevel).size() - 1; + std::vector> sentence_vector_list( + src_num, SentenceVector(beam_size_)); + std::vector> prefix_idx_vector_list(src_num); + for (int step_id = step_num - 1; step_id >= 0; --step_id) { + auto& cur_ids = step_ids.at(step_id); + auto& cur_scores = step_scores.at(step_id); + for (size_t src_idx = 0; src_idx < src_num; ++src_idx) { + // for each source sentence + auto& sentence_vector = sentence_vector_list.at(src_idx); + auto& prefix_idx_vector = prefix_idx_vector_list.at(src_idx); + size_t src_prefix_start = cur_ids.lod().at(kSourceLevel)[src_idx]; + size_t src_prefix_end = cur_ids.lod().at(kSourceLevel)[src_idx + 1]; + if (prefix_idx_vector.empty()) { // be finished and pruned at this step + // or the last time step + for (size_t prefix_idx = src_prefix_start; + prefix_idx < src_prefix_end; + ++prefix_idx) { + size_t candidate_start = + cur_ids.lod().at(kSentenceLevel)[prefix_idx]; + size_t candidate_end = + cur_ids.lod().at(kSentenceLevel)[prefix_idx + 1]; + for (size_t candidate_idx = candidate_start; + candidate_idx < candidate_end; + ++candidate_idx) { + prefix_idx_vector.push_back(prefix_idx); + size_t idx = prefix_idx_vector.size() - 1; + auto cur_id = cur_ids.data()[candidate_idx]; + auto cur_score = cur_scores.data()[candidate_idx]; + sentence_vector.at(idx).word_ids.push_back(cur_id); + sentence_vector.at(idx).scores.push_back(cur_score); + } + } + } else { // use prefix_idx_vector to backtrace + size_t src_candidate_start = + cur_ids.lod().at(kSentenceLevel)[src_prefix_start]; + size_t prefix_idx = src_prefix_start; + size_t candidate_num = + cur_ids.lod().at(kSentenceLevel)[prefix_idx + 1] - + cur_ids.lod().at(kSentenceLevel)[prefix_idx]; + for (size_t idx = 0; idx < prefix_idx_vector.size(); ++idx) { + auto candidate_idx = prefix_idx_vector.at(idx); + auto cur_id = cur_ids.data()[candidate_idx]; + auto cur_score = cur_scores.data()[candidate_idx]; + if (cur_id != end_id_ || sentence_vector.at(idx).word_ids.empty()) { + // to skip redundant end tokens + sentence_vector.at(idx).word_ids.push_back(cur_id); + sentence_vector.at(idx).scores.push_back(cur_score); + } + + while (src_candidate_start + candidate_num <= + candidate_idx) { // search the corresponding prefix + prefix_idx++; + candidate_num += + cur_ids.lod().at(kSentenceLevel)[prefix_idx + 1] - + cur_ids.lod().at(kSentenceLevel)[prefix_idx]; + } + prefix_idx_vector.at(idx) = prefix_idx; + } + } + } + } + + ConvertSentenceVectorToLodTensor( + sentence_vector_list, id_tensor, score_tensor, true, true); + } + + size_t beam_size_; + int end_id_; +}; + +struct BeamSearchDecodeFunctor { + BeamSearchDecodeFunctor(const LoDTensorArray& step_ids, + const LoDTensorArray& step_scores, + LoDTensor* id_tensor, + LoDTensor* score_tensor, + size_t beam_size, + int end_id) + : beam_size_(beam_size), + end_id_(end_id), + step_ids_(step_ids), + step_scores_(step_scores), + id_tensor_(id_tensor), + score_tensor_(score_tensor) {} + + template + void apply() const { + BeamSearchDecoder beam_search_decoder(beam_size_, 
end_id_); + beam_search_decoder.Backtrace( + step_ids_, step_scores_, id_tensor_, score_tensor_); + } + + size_t beam_size_; + int end_id_; + const LoDTensorArray& step_ids_; + const LoDTensorArray& step_scores_; + LoDTensor* id_tensor_; + LoDTensor* score_tensor_; +}; + +template <> +void BeamSearchDecodeFunctor::apply() const { + LOG(FATAL) << "beam search decode op does not support bool!"; +} + +void BeamSearchDecodeCompute::Run() { + auto& param = this->Param(); + auto& ctx = this->ctx_->template As(); + // inputs + auto ids = param.ids; + auto scores = param.scores; + // outputs + auto sentence_ids = param.sentence_ids; + auto sentence_scores = param.sentence_scores; + + const size_t step_num = ids->size(); + CHECK_GT(step_num, 0UL) << "beam search steps should be larger than 0"; + const size_t source_num = ids->at(0).lod().at(0).size() - 1; + CHECK_GT(source_num, 0UL) << "source num should be larger than 0"; + + for (size_t i = 0; i < step_num; ++i) { + CHECK_EQ(ids->at(i).lod().size(), 2UL) << "Level of LodTensor should be 2"; + } + + //! fixme + // only support float score now + BeamSearchDecodeFunctor func(*ids, + *scores, + sentence_ids, + sentence_scores, + param.beam_size, + param.end_id); + + func.apply(); +} + +} // namespace arm +} // namespace kernels +} // namespace lite +} // namespace paddle + +REGISTER_LITE_KERNEL(beam_search_decode, + kARM, + kFloat, + kNCHW, + paddle::lite::kernels::arm::BeamSearchDecodeCompute, + def) + .BindInput("Ids", {LiteType::GetTensorListTy(TARGET(kARM))}) + .BindInput("Scores", {LiteType::GetTensorListTy(TARGET(kARM))}) + .BindOutput("SentenceIds", {LiteType::GetTensorTy(TARGET(kARM))}) + .BindOutput("SentenceScores", {LiteType::GetTensorTy(TARGET(kARM))}) + .Finalize(); diff --git a/lite/kernels/arm/beam_search_decode_compute.h b/lite/kernels/arm/beam_search_decode_compute.h new file mode 100644 index 00000000000..db1961ad937 --- /dev/null +++ b/lite/kernels/arm/beam_search_decode_compute.h @@ -0,0 +1,39 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include +#include "lite/core/kernel.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace arm { + +class BeamSearchDecodeCompute + : public KernelLite { + public: + using param_t = operators::BeamSearchDecodeParam; + + BeamSearchDecodeCompute() = default; + + void Run() override; + + virtual ~BeamSearchDecodeCompute() = default; +}; + +} // namespace arm +} // namespace kernels +} // namespace lite +} // namespace paddle diff --git a/lite/kernels/arm/box_coder_compute.cc b/lite/kernels/arm/box_coder_compute.cc new file mode 100644 index 00000000000..75dca9496dc --- /dev/null +++ b/lite/kernels/arm/box_coder_compute.cc @@ -0,0 +1,55 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/kernels/arm/box_coder_compute.h" +#include +#include +#include "lite/arm/math/funcs.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace arm { + +void BoxCoderCompute::Run() { + auto& param = Param(); + int axis = param.axis; + bool box_normalized = param.box_normalized; + std::string code_type = param.code_type; + + lite::arm::math::box_coder(param.proposals, + param.prior_box, + param.prior_box_var, + param.target_box, + code_type, + box_normalized, + axis); +} + +} // namespace arm +} // namespace kernels +} // namespace lite +} // namespace paddle + +REGISTER_LITE_KERNEL(box_coder, + kARM, + kFloat, + kNCHW, + paddle::lite::kernels::arm::BoxCoderCompute, + def) + .BindInput("PriorBox", {LiteType::GetTensorTy(TARGET(kARM))}) + .BindInput("PriorBoxVar", {LiteType::GetTensorTy(TARGET(kARM))}) + .BindInput("TargetBox", {LiteType::GetTensorTy(TARGET(kARM))}) + .BindOutput("OutputBox", {LiteType::GetTensorTy(TARGET(kARM))}) + .Finalize(); diff --git a/lite/kernels/arm/box_coder_compute.h b/lite/kernels/arm/box_coder_compute.h new file mode 100644 index 00000000000..0279af4ea58 --- /dev/null +++ b/lite/kernels/arm/box_coder_compute.h @@ -0,0 +1,36 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include "lite/core/kernel.h" +#include "lite/core/op_registry.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace arm { + +class BoxCoderCompute : public KernelLite { + public: + using param_t = operators::BoxCoderParam; + + void Run() override; + + virtual ~BoxCoderCompute() = default; +}; + +} // namespace arm +} // namespace kernels +} // namespace lite +} // namespace paddle diff --git a/lite/kernels/arm/calib_compute.cc b/lite/kernels/arm/calib_compute.cc new file mode 100644 index 00000000000..3bc434329a9 --- /dev/null +++ b/lite/kernels/arm/calib_compute.cc @@ -0,0 +1,90 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
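[Editor's note] BoxCoderCompute above likewise hands off to lite::arm::math::box_coder with the code_type string, the box_normalized flag, and the axis. For orientation, the conventional center-size decoding that decode_center_size-style coders compute is sketched below; this is the standard SSD/Faster R-CNN transform stated from general knowledge, not read from this patch, and the unnormalized-coordinate variant additionally adds 1 to the prior width/height:

    #include <cmath>

    // Decode one target box against a prior box with per-coordinate variances.
    // Boxes are [xmin, ymin, xmax, ymax]; variances are [vx, vy, vw, vh].
    void decode_center_size_ref(const float prior[4], const float var[4],
                                const float target[4], float out[4]) {
      const float pw = prior[2] - prior[0];
      const float ph = prior[3] - prior[1];
      const float pcx = prior[0] + 0.5f * pw;
      const float pcy = prior[1] + 0.5f * ph;
      const float cx = var[0] * target[0] * pw + pcx;  // decoded center x
      const float cy = var[1] * target[1] * ph + pcy;  // decoded center y
      const float w = std::exp(var[2] * target[2]) * pw;
      const float h = std::exp(var[3] * target[3]) * ph;
      out[0] = cx - 0.5f * w;
      out[1] = cy - 0.5f * h;
      out[2] = cx + 0.5f * w;
      out[3] = cy + 0.5f * h;
    }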
+ +#include "lite/kernels/arm/calib_compute.h" +#include +#include "lite/arm/math/type_trans.h" +#include "lite/core/op_registry.h" +#include "lite/core/type_system.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace arm { + +void CalibComputeFp32ToInt8::Run() { + auto& param = this->Param(); + std::vector scale = {param.scale}; + const auto* din = param.input->data(); + auto* dout = param.output->mutable_data(); + lite::arm::math::fp32_to_int8( + din, dout, scale.data(), 1, 1, param.input->numel()); + return; +} + +void CalibComputeInt8ToFp32::Run() { + auto& param = this->Param(); + const auto* din = param.input->data(); + std::vector scale = {param.scale}; + auto* dout = param.output->mutable_data(); + lite::arm::math::int8_to_fp32( + din, dout, scale.data(), 1, 1, param.input->numel()); + return; +} + +} // namespace arm +} // namespace kernels +} // namespace lite +} // namespace paddle + +REGISTER_LITE_KERNEL(calib, + kARM, + kInt8, + kNCHW, + paddle::lite::kernels::arm::CalibComputeFp32ToInt8, + fp32_to_int8) + .BindInput("Input", + {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kFloat))}) + .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt8))}) + .Finalize(); + +REGISTER_LITE_KERNEL(calib, + kARM, + kInt8, + kNCHW, + paddle::lite::kernels::arm::CalibComputeInt8ToFp32, + int8_to_fp32) + .BindInput("Input", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt8))}) + .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kFloat))}) + .Finalize(); +REGISTER_LITE_KERNEL(calib_once, + kARM, + kInt8, + kNCHW, + paddle::lite::kernels::arm::CalibComputeFp32ToInt8, + fp32_to_int8) + .BindInput("Input", + {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kFloat))}) + .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt8))}) + .Finalize(); + +REGISTER_LITE_KERNEL(calib_once, + kARM, + kInt8, + kNCHW, + paddle::lite::kernels::arm::CalibComputeInt8ToFp32, + int8_to_fp32) + .BindInput("Input", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt8))}) + .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kFloat))}) + .Finalize(); diff --git a/lite/kernels/arm/calib_compute.h b/lite/kernels/arm/calib_compute.h new file mode 100644 index 00000000000..8d9a32bc245 --- /dev/null +++ b/lite/kernels/arm/calib_compute.h @@ -0,0 +1,51 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once +#include "lite/core/kernel.h" +#include "lite/operators/calib_op.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace arm { + +class CalibComputeFp32ToInt8 + : public KernelLite { + public: + using param_t = operators::CalibParam; + + void Run() override; + + ~CalibComputeFp32ToInt8() override{}; + + private: +}; + +class CalibComputeInt8ToFp32 + : public KernelLite { + public: + using param_t = operators::CalibParam; + + void Run() override; + + ~CalibComputeInt8ToFp32() override{}; + + private: +}; + +} // namespace arm +} // namespace kernels +} // namespace lite +} // namespace paddle diff --git a/lite/kernels/arm/calib_compute_test.cc b/lite/kernels/arm/calib_compute_test.cc new file mode 100644 index 00000000000..ee29424293e --- /dev/null +++ b/lite/kernels/arm/calib_compute_test.cc @@ -0,0 +1,156 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/kernels/arm/calib_compute.h" +#include +#include +#include +#include +#include +#include +#include +#include "lite/arm/math/funcs.h" +#include "lite/core/op_registry.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace arm { + +static int get_rand(int start, int end) { + int i = rand(); // NOLINT + i = (i % (end - start)) + start; + return i; +} + +static void int8_to_fp32_basic(const int8_t* din, + float* dout, + const float* scale, + int axis_size, + int64_t outer_size, + int64_t inner_size) { + int loop_size = axis_size * outer_size; + for (int i = 0; i < loop_size; ++i) { + float scale_in = scale[i % axis_size]; + for (int j = 0; j < inner_size; ++j) { + dout[j] = din[j] * scale_in; + } + dout += inner_size; + din += inner_size; + } +} + +static void fp32_to_int8_basic(const float* din, + int8_t* dout, + const float* scale, + int axis_size, + int64_t outer_size, + int64_t inner_size) { + int loop_size = axis_size * outer_size; + for (int i = 0; i < loop_size; ++i) { + float inv_scale = 1.f / scale[i % axis_size]; + for (int j = 0; j < inner_size; ++j) { + dout[j] = static_cast(roundf(din[j] * inv_scale)); + } + dout += inner_size; + din += inner_size; + } +} + +void calib_ref(const operators::CalibParam& param) { + std::vector scale = {param.in_scale}; + if (param.in_dtype == PRECISION(kFloat) && + param.out_dtype == PRECISION(kInt8)) { + const auto* din = param.input->data(); + auto* dout = param.output->mutable_data(); + fp32_to_int8_basic(din, dout, scale.data(), 1, 1, param.input->numel()); + return; + } + if (param.in_dtype == PRECISION(kInt8) && + param.out_dtype == PRECISION(kFloat)) { + const auto* din = param.input->data(); + auto* dout = param.output->mutable_data(); + int8_to_fp32_basic(din, dout, scale.data(), 1, 1, param.input->numel()); + return; + } + LOG(FATAL) << "Unsupport Dtype."; +} + +TEST(calib_arm, retrive_op) { + auto calib = + KernelRegistry::Global() + .Create("calib"); + ASSERT_FALSE(calib.empty()); + ASSERT_TRUE(calib.front()); +} + +TEST(calib_arm, init) { + 
CalibCompute calib; + ASSERT_EQ(calib.precision(), PRECISION(kInt8)); + ASSERT_EQ(calib.target(), TARGET(kARM)); +} + +TEST(calib_arm, int8_to_fp32) { + DeviceInfo::Init(); + for (auto n : {1, 2}) { + for (auto c : {6, 32 /*, 128*/}) { + for (auto h : {9, 18 /*, 56 , 112, 224, 512*/}) { + for (auto w : {9, 18 /*, 56, 112, 224, 512*/}) { + Tensor x; + Tensor output; + Tensor output_ref; + // set the dims of input, output, ref output tensors + x.Resize({n, c, h, w}); + output.Resize({n, c, h, w}); + output_ref.Resize({n, c, h, w}); + // initialize the data of input tensors + auto* x_data = x.mutable_data(); + auto* output_data = output.mutable_data(); + for (int i = 0; i < x.dims().production(); i++) { + float sign = i % 3 == 0 ? -1.0f : 1.0f; + x_data[i] = sign * static_cast(i % 128) * 0.013f; + } + // prepare kernel params and run + CalibCompute calib; + std::unique_ptr ctx(new KernelContext); + ctx->As(); + calib.SetContext(std::move(ctx)); + operators::CalibParam param; + param.in_scale = get_rand(0, 100) * 0.1f; + param.in_dtype = PRECISION(kInt8); + param.out_dtype = PRECISION(kFloat); + param.input = &x; + param.output = &output; + calib.SetParam(param); + calib.Launch(); + // invoking ref implementation and compare results + param.output = &output_ref; + calib_ref(param); + auto* output_ref_data = output_ref.mutable_data(); + for (int i = 0; i < output.dims().production(); i++) { + EXPECT_NEAR(output_data[i], output_ref_data[i], 1e-5); + } + } + } + } + } +} + +} // namespace arm +} // namespace kernels +} // namespace lite +} // namespace paddle + +USE_LITE_KERNEL(calib, kARM, kInt8, kNCHW, int8_to_fp32); +USE_LITE_KERNEL(calib, kARM, kInt8, kNCHW, fp32_to_int8); diff --git a/lite/kernels/arm/cast_compute.cc b/lite/kernels/arm/cast_compute.cc new file mode 100644 index 00000000000..ad4cc82d3a1 --- /dev/null +++ b/lite/kernels/arm/cast_compute.cc @@ -0,0 +1,50 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "lite/kernels/arm/cast_compute.h" +#include "lite/arm/math/funcs.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace arm { + +void CastCompute::PrepareForRun() {} + +void CastCompute::Run() { + auto& ctx = this->ctx_->template As(); + auto& param = this->Param(); + + auto input_dims = param.X->dims(); + + if (param.in_dtype == param.out_dtype && param.in_dtype == 2 || + param.in_dtype == 0) { + const auto* x_data = param.X->data(); + auto* o_data = param.Out->mutable_data(); + memcpy(o_data, x_data, sizeof(float) * param.X->numel()); + } else { + LOG(FATAL) << "other has not been implemented"; + } +} + +} // namespace arm +} // namespace kernels +} // namespace lite +} // namespace paddle + +REGISTER_LITE_KERNEL( + cast, kARM, kFloat, kNCHW, paddle::lite::kernels::arm::CastCompute, def) + .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))}) + .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))}) + .Finalize(); diff --git a/lite/kernels/arm/cast_compute.h b/lite/kernels/arm/cast_compute.h new file mode 100644 index 00000000000..fc5b82f4c9f --- /dev/null +++ b/lite/kernels/arm/cast_compute.h @@ -0,0 +1,42 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include +#include "lite/arm/math/type_trans.h" +#include "lite/core/kernel.h" +#include "lite/core/op_registry.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace arm { + +class CastCompute : public KernelLite { + public: + using param_t = operators::CastParam; + + void PrepareForRun() override; + + void Run() override; + + ~CastCompute() {} + + private: +}; + +} // namespace arm +} // namespace kernels +} // namespace lite +} // namespace paddle diff --git a/lite/kernels/arm/compare_compute.cc b/lite/kernels/arm/compare_compute.cc new file mode 100644 index 00000000000..72465b1bc45 --- /dev/null +++ b/lite/kernels/arm/compare_compute.cc @@ -0,0 +1,186 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "lite/kernels/arm/compare_compute.h" +#include +#include "lite/api/paddle_place.h" +#include "lite/arm/math/funcs.h" +#include "lite/core/op_registry.h" +#include "lite/core/type_system.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace arm { + +#define COMPARE_FUNCTOR(name, op) \ + template \ + struct _##name##Functor { \ + inline bool operator()(const T &a, const T &b) const { return a op b; } \ + }; + +COMPARE_FUNCTOR(Equal, ==); +COMPARE_FUNCTOR(NotEqual, !=); +COMPARE_FUNCTOR(LessThan, <); +COMPARE_FUNCTOR(LessEqual, <=); +COMPARE_FUNCTOR(GreaterThan, >); +COMPARE_FUNCTOR(GreaterEqual, >=); + +template <> +struct _EqualFunctor { + inline bool operator()(const float &a, const float &b) const { + // It is safe to cast a and b to double. + return fabs(static_cast(a - b)) < 1e-8; + } +}; + +template <> +struct _NotEqualFunctor { + inline bool operator()(const float &a, const float &b) const { + return !_EqualFunctor()(a, b); + } +}; + +inline void get_mid_dims(const lite::DDim &x_dims, + const lite::DDim &y_dims, + const int axis, + int *pre, + int *n, + int *post) { + *pre = 1; + *n = 1; + *post = 1; + for (int i = 0; i < axis; ++i) { + (*pre) *= x_dims[i]; + } + + for (int i = 0; i < y_dims.size(); ++i) { + (*n) *= y_dims[i]; + } + + for (int i = axis + y_dims.size(); i < x_dims.size(); ++i) { + (*post) *= x_dims[i]; + } +} +template