3 changes: 2 additions & 1 deletion paddle/infrt/naive/CMakeLists.txt
@@ -4,4 +4,5 @@ cc_library(infrt_naive SRCS meta_tensor.cc
 infershaped/infershaped_kernel_launchers.cc
 )
 
-cc_test_tiny(test_infrt_infershape_launchers SRCS infershaped/infershape_launchers_test.cc DEPS infrt)
+cc_test_tiny(test_infrt_infershape_launchers SRCS
+             infershaped/infershape_launchers_test.cc DEPS infrt)
30 changes: 14 additions & 16 deletions paddle/infrt/naive/infershaped/elementwise_add.h
@@ -17,6 +17,7 @@
 
 #include "paddle/infrt/host_context/kernel_utils.h"
 #include "paddle/infrt/naive/infershaped/infershaped_kernel_launcher.h"
+#include "paddle/infrt/naive/infershaped/infershaped_utils.h"
 
 // This file contains an example of the infershape ElementwiseAdd kernel.
 // Some of the following code should be generated from PTEN by script.
@@ -32,39 +33,36 @@ static void ElementwiseAddInferShape(const MetaTensor& a,
   *c->mutable_shape() = a.shape();
 }

-static void ElementwiseAdd(const tensor::DenseHostTensor& a,
+static void ElementwiseAdd(tensor::DenseHostTensor* /*Context*/,
+                           const tensor::DenseHostTensor& a,
                            const tensor::DenseHostTensor& b,
                            tensor::DenseHostTensor* c) {}

 // TODO(zhiqiang) This class should be generated by a script offline.
-class ElementwiseAddLauncher : public InferShapedKernelLauncher {
+template <typename KernelFunc,
+          KernelFunc kernel,
+          typename InferShapedFunc,
+          InferShapedFunc infershape>
+class KernelLauncher : public InferShapedKernelLauncher {
Review comment (Contributor Author): [TODO] This class should be moved to a different location.
  public:
-  static const uint16_t input_tensor_indices[2];
-  static const uint16_t num_input_tensors{2};
+  static const uint16_t num_input_tensors{InferShapeHelper<KernelFunc>::count};
   static const bool turn_on_infer_shape_cache{true};

   void Invoke(host_context::KernelFrame* frame) override {
     // Build the infershape KernelFrame if needed.
     // TODO(Superjomn) add unlikely here.
     if (infershape_kernel_frame_builder.IsEmpty()) {
       CreateKernelFrameForInferShape(frame);
     }
-    if (turn_on_infer_shape_cache) {
-      if (IsShapeChanged(input_tensor_indices, num_input_tensors)) {
-        INFRT_KERNEL(ElementwiseAddInferShape)
-        (&infershape_kernel_frame_builder);
-        BuildInferShapeCache(input_tensor_indices, num_input_tensors);
+    if (!turn_on_infer_shape_cache || IsShapeChanged(num_input_tensors)) {
+      ::infrt::host_context::KernelImpl<InferShapedFunc, infershape>::Invoke(
+          &infershape_kernel_frame_builder);
+      BuildInferShapeCache(num_input_tensors);
     }
-    } else {
-      INFRT_KERNEL(ElementwiseAddInferShape)(&infershape_kernel_frame_builder);
-      BuildInferShapeCache(input_tensor_indices, num_input_tensors);
-    }
 
-    INFRT_KERNEL(ElementwiseAdd)(frame);
+    ::infrt::host_context::KernelImpl<KernelFunc, kernel>::Invoke(frame);
   }
 };

-const uint16_t ElementwiseAddLauncher::input_tensor_indices[2] = {0, 1};
-
 }  // namespace naive
 }  // namespace infrt
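The rewritten Invoke collapses the old nested if/else into a single condition: the infershape function runs when the cache is disabled or when an input shape has changed since the last call, and the compute kernel runs unconditionally afterwards. A minimal standalone model of that decision logic — ToyLauncher and its members are hypothetical names for illustration, not infrt API:

#include <cassert>
#include <vector>

// Toy stand-in for the shape-cache decision in KernelLauncher::Invoke.
struct ToyLauncher {
  bool cache_enabled{true};  // models turn_on_infer_shape_cache
  std::vector<std::vector<int>> shape_cache;
  int infershape_runs{0};

  bool ShapeChanged(const std::vector<std::vector<int>>& shapes) const {
    return shape_cache != shapes;  // models IsShapeChanged
  }

  void Invoke(const std::vector<std::vector<int>>& input_shapes) {
    if (!cache_enabled || ShapeChanged(input_shapes)) {
      ++infershape_runs;           // models KernelImpl<InferShapedFunc, infershape>::Invoke
      shape_cache = input_shapes;  // models BuildInferShapeCache
    }
    // The compute kernel (KernelImpl<KernelFunc, kernel>::Invoke) always runs here.
  }
};

int main() {
  ToyLauncher launcher;
  launcher.Invoke({{2, 8}, {2, 8}});  // empty cache -> infershape runs
  launcher.Invoke({{2, 8}, {2, 8}});  // same shapes -> cache hit, skipped
  launcher.Invoke({{4, 8}, {4, 8}});  // shape changed -> infershape runs again
  assert(launcher.infershape_runs == 2);
  return 0;
}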
14 changes: 14 additions & 0 deletions paddle/infrt/naive/infershaped/infershape_launchers_test.cc
@@ -17,11 +17,24 @@
 #include "paddle/infrt/naive/infershaped/infershaped_kernel_launcher.h"
 #include "paddle/infrt/naive/infershaped/infershaped_kernel_launchers.h"
 #include "paddle/infrt/naive/infershaped/infershaped_registry.h"
+#include "paddle/infrt/naive/infershaped/infershaped_utils.h"
 #include "paddle/infrt/tensor/dense_host_tensor.h"
 
 namespace infrt {
 namespace naive {
 
+namespace {
+static void ElementwiseAddTest(const tensor::DenseHostTensor& a,
+                               const tensor::DenseHostTensor& b,
+                               tensor::DenseHostTensor* c);
+}
+
+TEST(utils, registry) {
+  constexpr uint8_t count =
+      InferShapeHelper<decltype(&ElementwiseAddTest)>::count;
+  CHECK_EQ(count, 2U);
+}
+
 TEST(ElementwiseAdd, registry) {
   InferShapedKernelRegistry registry;
   RegisterInferShapeLaunchers(&registry);
@@ -35,6 +48,7 @@ TEST(ElementwiseAdd, registry) {
   tensor::DenseHostTensor c({2, 8}, GetDType<float>());
 
   host_context::KernelFrameBuilder kernel_frame_builder;
+  kernel_frame_builder.AddArgument(new host_context::Value(0));
   kernel_frame_builder.AddArgument(new host_context::Value(std::move(a)));
   kernel_frame_builder.AddArgument(new host_context::Value(std::move(b)));
   kernel_frame_builder.SetResults({new host_context::Value(std::move(c))});
17 changes: 7 additions & 10 deletions paddle/infrt/naive/infershaped/infershaped_kernel_launcher.cc
@@ -20,7 +20,7 @@ namespace naive {
 void InferShapedKernelLauncher::CreateKernelFrameForInferShape(
     host_context::KernelFrame* frame) {
   for (host_context::Value* value :
-       frame->GetValues(0, frame->GetNumElements())) {
+       frame->GetValues(1, frame->GetNumElements() - 1)) {
     // TODO(Superjomn) To extend this.
     if (value->is_type<tensor::DenseHostTensor>()) {
       values.emplace_back(MetaTensor{&value->get<tensor::DenseHostTensor>()});
@@ -32,27 +32,24 @@ void InferShapedKernelLauncher::CreateKernelFrameForInferShape(
 }

 void InferShapedKernelLauncher::BuildInferShapeCache(
-    const uint16_t* input_indices, const uint16_t num_inputs) {
+    const uint16_t num_inputs) {
   tensor_shape_cache.resize(num_inputs);
   for (uint16_t i = 0; i < num_inputs; i++) {
     tensor_shape_cache[i] =
-        infershape_kernel_frame_builder.GetArgAt(input_indices[i])
-            ->get<MetaTensor>()
-            .shape();
+        infershape_kernel_frame_builder.GetArgAt(i)->get<MetaTensor>().shape();
   }
 }
 
 bool InferShapedKernelLauncher::IsShapeChanged(
-    const uint16_t* input_indices, const uint16_t num_inputs) const {
+    const uint16_t num_inputs) const {
   if (tensor_shape_cache.empty() && !infershape_kernel_frame_builder.IsEmpty())
     return true;
 
   bool changed = false;
   for (uint16_t i = 0; i < num_inputs && !changed; i++) {
-    changed = changed || (tensor_shape_cache[i] !=
-                          infershape_kernel_frame_builder
-                              .GetArgAt<MetaTensor>(input_indices[i])
-                              .shape());
+    changed = changed ||
+              (tensor_shape_cache[i] !=
+               infershape_kernel_frame_builder.GetArgAt<MetaTensor>(i).shape());
   }
   return changed;
 }
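Dropping the input_indices arrays works because the infershape frame is now built positionally: CreateKernelFrameForInferShape skips frame value 0 (the context argument that the updated test adds) and copies the remaining tensor values in order, so argument i of the infershape frame is simply input i. A sketch of the assumed frame layout, inferred from this diff rather than stated anywhere authoritative:

// Assumed KernelFrame layout for a compute kernel after this change:
//   arg 0      -> context (e.g. tensor::DenseHostTensor* in ElementwiseAdd);
//                 skipped by CreateKernelFrameForInferShape
//   args 1..N  -> input DenseHostTensors, copied in order into
//                 infershape_kernel_frame_builder as args 0..N-1
//   results    -> output tensors
// Hence GetArgAt(i) addresses input i directly, and a per-launcher
// input_tensor_indices table is no longer needed.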
6 changes: 2 additions & 4 deletions paddle/infrt/naive/infershaped/infershaped_kernel_launcher.h
@@ -39,12 +39,10 @@ struct InferShapedKernelLauncher {
 
   //! Build or update the infer-shape cache using the latest shape from
   //! InferShapeFrame.
-  void BuildInferShapeCache(const uint16_t* input_indices,
-                            const uint16_t num_inputs);
+  void BuildInferShapeCache(const uint16_t num_inputs);
 
   //! Compare the latest shape with the shape cache.
-  bool IsShapeChanged(const uint16_t* input_indices,
-                      const uint16_t num_inputs) const;
+  bool IsShapeChanged(const uint16_t num_inputs) const;
 
   // values to hold the TensorMeta.
   llvm::SmallVector<host_context::ValueRef, 3> values;
paddle/infrt/naive/infershaped/infershaped_kernel_launchers.cc
@@ -13,12 +13,18 @@
 // limitations under the License.
 
 #include "paddle/infrt/naive/infershaped/infershaped_kernel_launchers.h"
+
 #include "paddle/infrt/naive/infershaped/elementwise_add.h"
 #include "paddle/infrt/naive/infershaped/infershaped_registry.h"
 
 namespace infrt {
 namespace naive {
 
+using ElementwiseAddLauncher =
+    KernelLauncher<decltype(&ElementwiseAdd),
+                   &ElementwiseAdd,
+                   decltype(&ElementwiseAddInferShape),
+                   &ElementwiseAddInferShape>;
+
 void RegisterInferShapeLaunchers(InferShapedKernelRegistry* registry) {
   registry->AddKernel("elementwise_add",
                       INFERSHAPED_KERNEL_CREATOR(ElementwiseAddLauncher));
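Since KernelLauncher is parameterized over the compute kernel and its infershape function, registering another kernel now takes only a new alias and an AddKernel call. A hedged sketch with a hypothetical Relu kernel pair — not part of this PR — following the same signature conventions as ElementwiseAdd, placed inside namespace infrt::naive:

// Hypothetical kernel pair, mirroring the ElementwiseAdd signatures.
static void ReluInferShape(const MetaTensor& x, MetaTensor* out) {
  *out->mutable_shape() = x.shape();
}
static void Relu(tensor::DenseHostTensor* /*Context*/,
                 const tensor::DenseHostTensor& x,
                 tensor::DenseHostTensor* out) {}

using ReluLauncher = KernelLauncher<decltype(&Relu),
                                    &Relu,
                                    decltype(&ReluInferShape),
                                    &ReluInferShape>;

// Inside RegisterInferShapeLaunchers:
//   registry->AddKernel("relu", INFERSHAPED_KERNEL_CREATOR(ReluLauncher));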
77 changes: 77 additions & 0 deletions paddle/infrt/naive/infershaped/infershaped_utils.h
@@ -0,0 +1,77 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <type_traits>
#include "paddle/infrt/tensor/dense_host_tensor.h"

namespace infrt {
namespace naive {
namespace infershaped {

using KeyType = const tensor::DenseHostTensor&;
using CountType = uint8_t;

constexpr CountType value(std::true_type) { return 1; }

constexpr CountType value(std::false_type) { return 0; }

template <typename T>
constexpr CountType value() {
return value(std::integral_constant<bool, std::is_same<T, KeyType>::value>{});
}

template <typename FirstArg>
constexpr CountType count(CountType num) {
return num;
}

template <typename FirstArg>
constexpr CountType count() {
return 0;
}

template <>
constexpr CountType count<KeyType>(CountType num) {
return num + 1;
}

template <>
constexpr CountType count<KeyType>() {
return 1;
}

template <typename FirstArg, typename SecondArg, typename... RestOfArgs>
constexpr CountType count(CountType num) {
return count<SecondArg, RestOfArgs...>(num + value<FirstArg>());
}

template <typename FirstArg, typename SecondArg, typename... RestOfArgs>
constexpr CountType count() {
return count<SecondArg, RestOfArgs...>(value<FirstArg>());
}

} // namespace infershaped

template <typename F>
struct InferShapeHelper;

template <typename Return, typename... Args>
struct InferShapeHelper<Return (*)(Args...)> {
static constexpr int count = infershaped::count<Args...>();
};

} // namespace naive
} // namespace infrt
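The count metafunction tallies only parameters of type const tensor::DenseHostTensor& (KeyType): value<T>() yields 1 or 0 per parameter, the variadic count() folds those contributions left to right, and the single-parameter overload plus its KeyType specialization terminate the recursion. A small compile-time check against hypothetical kernel signatures — Scale and Concat are illustrative, not from this PR:

// Hypothetical signatures exercising InferShapeHelper (declarations suffice
// for a compile-time check).
static void Scale(const infrt::tensor::DenseHostTensor& x,
                  float factor,
                  infrt::tensor::DenseHostTensor* out);
static void Concat(const infrt::tensor::DenseHostTensor& a,
                   const infrt::tensor::DenseHostTensor& b,
                   const infrt::tensor::DenseHostTensor& c,
                   infrt::tensor::DenseHostTensor* out);

static_assert(infrt::naive::InferShapeHelper<decltype(&Scale)>::count == 1,
              "only const DenseHostTensor& parameters are counted");
static_assert(infrt::naive::InferShapeHelper<decltype(&Concat)>::count == 3,
              "each const DenseHostTensor& parameter adds one");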