Skip to content

Commit

Permalink
[XPU] Add unbind kernel (#10395)
Browse files Browse the repository at this point in the history
  • Loading branch information
leolishaohao authored Sep 25, 2023
1 parent edea477 commit 68ced6b
Show file tree
Hide file tree
Showing 3 changed files with 99 additions and 0 deletions.
1 change: 1 addition & 0 deletions lite/kernels/xpu/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ add_kernel(where_compute_xpu XPU extra SRCS where_compute.cc)
add_kernel(gather_nd_compute_xpu XPU extra SRCS gather_nd_compute.cc)
add_kernel(meshgrid_compute_xpu XPU basic SRCS meshgrid_compute.cc)
add_kernel(fetch_compute_xpu XPU basic SRCS fetch_compute.cc)
add_kernel(unbind_compute_xpu XPU basic SRCS unbind_compute.cc)

# extra
add_kernel(lookup_table_compute_xpu XPU extra SRCS lookup_table_compute.cc)
Expand Down
61 changes: 61 additions & 0 deletions lite/kernels/xpu/unbind_compute.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "lite/kernels/xpu/unbind_compute.h"
#include <vector>
#include "lite/backends/xpu/xpu_header_sitter.h"
#include "lite/core/op_registry.h"

namespace paddle {
namespace lite {
namespace kernels {
namespace xpu {

template <typename T, PrecisionType PType>
void UnbindCompute<T, PType>::Run() {
auto& param = this->template Param<param_t>();
auto& ctx = this->ctx_->template As<XPUContext>();
auto x = param.x;
auto& axis = param.axis;

auto output = param.output;

std::vector<T*> y_ptrs;
for (size_t j = 0; j < output.size(); ++j) {
y_ptrs.emplace_back(output[j]->template mutable_data<T>(TARGET(kXPU)));
}
auto x_shape = x->dims().Vectorize();
int r = xdnn::unbind(
ctx.GetRawContext(), x->template data<T>(), y_ptrs, x_shape, axis);
CHECK_EQ(r, 0);
}

} // namespace xpu
} // namespace kernels
} // namespace lite
} // namespace paddle

using unbind_fp32 =
paddle::lite::kernels::xpu::UnbindCompute<float, PRECISION(kFloat)>;
REGISTER_LITE_KERNEL(unbind, kXPU, kFloat, kNCHW, unbind_fp32, def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kXPU), PRECISION(kFloat))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kXPU), PRECISION(kFloat))})
.Finalize();

using unbind_int64 =
paddle::lite::kernels::xpu::UnbindCompute<int64_t, PRECISION(kFloat)>;
REGISTER_LITE_KERNEL(unbind, kXPU, kFloat, kNCHW, unbind_int64, int64)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kXPU), PRECISION(kInt64))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kXPU), PRECISION(kInt64))})
.Finalize();
37 changes: 37 additions & 0 deletions lite/kernels/xpu/unbind_compute.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "lite/core/kernel.h"

namespace paddle {
namespace lite {
namespace kernels {
namespace xpu {

template <typename T, PrecisionType PType>
class UnbindCompute : public KernelLite<TARGET(kXPU), PType> {
public:
using param_t = operators::UnbindParam;

virtual void Run();

virtual ~UnbindCompute() = default;
};

} // namespace xpu
} // namespace kernels
} // namespace lite
} // namespace paddle

0 comments on commit 68ced6b

Please sign in to comment.