10 changes: 10 additions & 0 deletions paddle/phi/api/yaml/fused_ops.yaml
@@ -160,6 +160,16 @@
backward: fused_rotary_position_embedding_grad
support_dygraph_mode : true

- op : fused_scale_bias_relu_conv_bnstats
args : (Tensor x, Tensor w, Tensor scale, Tensor bias, Tensor bn_scale, Tensor bn_bias, Tensor input_running_mean, Tensor input_running_var, int[] paddings, int[] dilations, int[] strides, str padding_algorithm, int groups, str data_format, float momentum, float epsilon, bool fuse_prologue, bool exhaustive_search, int64_t accumulation_count = 0)
optional : scale, bias
output : Tensor(out), Tensor(out_running_mean), Tensor(out_running_var), Tensor(saved_mean), Tensor(saved_var), Tensor(eq_scale), Tensor(eq_bias)
infer_meta :
func : FusedScaleBiasReluConvBnstatsInferMeta
kernel :
func : fused_scale_bias_relu_conv_bnstats
data_type : x
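In plain terms, this op fuses an optional scale/bias + ReLU prologue, a convolution, and the statistics-generation half of BatchNorm. A rough unfused reference of the intended semantics (the momentum convention and the eq_scale/eq_bias folding below are assumptions based on standard BatchNorm algebra, not read off this diff):

```cpp
// Pseudocode sketch of the unfused equivalent (assumptions, not the kernel):
//   conv_in = fuse_prologue ? relu(x * scale + bias) : x
//   out     = conv2d(conv_in, w)                              // NHWC layout
//   saved_mean = mean(out, over {N, H, W})                    // per channel
//   saved_var  = var(out,  over {N, H, W})
//   out_running_mean = momentum * input_running_mean + (1 - momentum) * saved_mean
//   out_running_var  = momentum * input_running_var  + (1 - momentum) * saved_var
//   eq_scale = bn_scale / sqrt(saved_var + epsilon)           // folded affine
//   eq_bias  = bn_bias  - saved_mean * eq_scale
```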

- op : generate_sequence_xpu
args : (Tensor x, DataType dtype)
output : Tensor
134 changes: 134 additions & 0 deletions paddle/phi/infermeta/fusion.cc
@@ -821,4 +821,138 @@ void FastLayernormXPUInferMeta(const MetaTensor& x,
out->set_layout(x.layout());
}

void FusedScaleBiasReluConvBnstatsInferMeta(
const MetaTensor& x,
const MetaTensor& w,
const MetaTensor& scale,
const MetaTensor& bias,
const MetaTensor& bn_scale,
const MetaTensor& bn_bias,
const MetaTensor& input_running_mean,
const MetaTensor& input_running_var,
const std::vector<int>& paddings,
const std::vector<int>& dilations,
const std::vector<int>& strides,
const std::string& padding_algorithm,
int groups,
const std::string& data_format,
float momentum,
float epsilon,
bool fuse_prologue,
bool exhaustive_search,
int64_t accumulation_count,
Contributor commented:

What is the purpose of this parameter?

@Tom-Zheng (Contributor, Author) replied on Aug 7, 2023:

The number of elements BatchNorm normalizes over; on a single GPU (i.e. without SyncBatchNorm) this is N*H*W. See Table 42.
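As a companion to that answer, a minimal sketch of how such a count could be derived (the helper name and placement are illustrative, not part of this PR):

```cpp
// Hypothetical helper: elements BatchNorm accumulates per channel on one GPU.
// For an NHWC convolution output of shape [N, H_out, W_out, K], statistics
// for each of the K channels are reduced over N * H_out * W_out values.
inline int64_t AccumulationCount(int64_t n, int64_t h_out, int64_t w_out) {
  return n * h_out * w_out;
}
```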

MetaTensor* out,
MetaTensor* out_running_mean,
MetaTensor* out_running_var,
MetaTensor* saved_mean,
MetaTensor* saved_var,
MetaTensor* eq_scale,
MetaTensor* eq_bias) {
auto in_dims = x.dims();
auto filter_dims = w.dims();
  // sanity-check ranks, layout, and channel counts
PADDLE_ENFORCE_EQ(
in_dims.size(),
4,
phi::errors::InvalidArgument(
"The input of Op(FusedScaleBiasReluConvBnstats) should be a 4-D "
"Tensor. But "
"received: input's dimension is %u, input's shape is [%s].",
in_dims.size(),
in_dims));

PADDLE_ENFORCE_EQ(
in_dims.size(),
filter_dims.size(),
phi::errors::InvalidArgument(
"The input's dimension and filter's dimension of "
"Op(FusedScaleBiasReluConvBnstats) should be equal. But received: "
"the input's"
" shape is [%s], "
"the input's dimension is %d; the filter's shape is [%s], "
"the filter's dimension is %d.",
in_dims,
in_dims.size(),
filter_dims,
filter_dims.size()));

// Check if data format is NHWC
PADDLE_ENFORCE_EQ(
data_format,
"NHWC",
phi::errors::InvalidArgument(
"Operator(FusedScaleBiasReluConvBnstats) only supports data format "
"of "
"channel last (NHWC) now. But recieved: data_format = '%s'.",
data_format));

PADDLE_ENFORCE_EQ(
groups,
1,
phi::errors::InvalidArgument("Expect group to be 1, got %d.", groups));

const auto input_channels = in_dims[in_dims.size() - 1];
int dilation_size = dilations.size();
for (int i = 0; i < dilation_size; ++i) {
PADDLE_ENFORCE_GT(
dilations[i],
0,
phi::errors::InvalidArgument(
"The dilation of Op(Conv) should be larget than 0, but received "
"dilation is %d.",
dilations[i]));
}

PADDLE_ENFORCE_EQ(
input_channels,
filter_dims[1] * groups,
phi::errors::InvalidArgument(
"The number of input's channels should be equal to filter's channels "
"* groups for Op(FusedScaleBiasReluConvBnstats). But received: the "
"input's"
" channels is %d, "
"the input's shape is [%s]; the filter's channels is %d, the "
"filter's shape is [%s]; the groups is %d. ",
input_channels,
in_dims,
filter_dims[1],
filter_dims,
groups));

  // update paddings and dilations according to padding_algorithm
std::vector<int> paddings_vec = paddings;
std::vector<int> dilations_vec = dilations;
// get "HW" from "NHWC"
DDim in_data_dims = phi::slice_ddim(in_dims, 1, in_dims.size() - 1);
DDim filter_data_dims = phi::slice_ddim(filter_dims, 2, filter_dims.size());
std::vector<int> ksize = phi::vectorize<int>(filter_data_dims);
phi::UpdatePaddingAndDilation(&paddings_vec,
&dilations_vec,
padding_algorithm,
in_data_dims,
strides,
ksize);

std::vector<int64_t> out_shape({in_dims[0]});
for (size_t i = 0; i < strides.size(); ++i) {
out_shape.push_back(ConvOutSize(in_dims[i + 1],
filter_dims[i + 2],
dilations[i],
paddings_vec[i * 2],
paddings_vec[i * 2 + 1],
strides[i]));
}
out_shape.push_back(filter_dims[0]);
// make shape for other outputs
auto c_dims = phi::make_ddim({filter_dims[0]});
  // set output dims
out->set_dims(DDim(out_shape.data(), out_shape.size()));
out_running_mean->set_dims(c_dims);
out_running_var->set_dims(c_dims);
saved_mean->set_dims(c_dims);
saved_var->set_dims(c_dims);
eq_scale->set_dims(c_dims);
eq_bias->set_dims(c_dims);
}

} // namespace phi
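For reference, the out_shape loop above relies on the standard convolution output-size arithmetic. A sketch of what ConvOutSize is expected to compute (the real helper is defined elsewhere in phi; this is a restatement, not its source):

```cpp
// out = (in + pad_begin + pad_end - dilated_kernel) / stride + 1,
// where dilated_kernel = dilation * (kernel - 1) + 1.
inline int64_t ConvOutSizeSketch(int64_t in, int64_t kernel, int64_t dilation,
                                 int64_t pad_begin, int64_t pad_end,
                                 int64_t stride) {
  const int64_t dilated_kernel = dilation * (kernel - 1) + 1;
  return (in + pad_begin + pad_end - dilated_kernel) / stride + 1;
}
// Example: in = 56, kernel = 3, dilation = 1, pads = 1/1, stride = 1 -> 56.
```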
28 changes: 28 additions & 0 deletions paddle/phi/infermeta/fusion.h
@@ -201,4 +201,32 @@ void FastLayernormXPUInferMeta(const MetaTensor& x,
float epsilon,
MetaTensor* out);

void FusedScaleBiasReluConvBnstatsInferMeta(
const MetaTensor& x,
const MetaTensor& w,
const MetaTensor& scale,
const MetaTensor& bias,
const MetaTensor& bn_scale,
const MetaTensor& bn_bias,
const MetaTensor& input_running_mean,
const MetaTensor& input_running_var,
const std::vector<int>& paddings,
const std::vector<int>& dilations,
const std::vector<int>& strides,
const std::string& padding_algorithm,
int groups,
const std::string& data_format,
float momentum,
float epsilon,
bool fuse_prologue,
bool exhaustive_search,
int64_t accumulation_count,
MetaTensor* out,
MetaTensor* out_running_mean,
MetaTensor* out_running_var,
MetaTensor* saved_mean,
MetaTensor* saved_var,
MetaTensor* eq_scale,
MetaTensor* eq_bias);

} // namespace phi
5 changes: 5 additions & 0 deletions paddle/phi/kernels/CMakeLists.txt
@@ -94,6 +94,11 @@ if(WITH_CUTLASS)
list(APPEND kernel_cu ${cutlass_cu})
endif()

if(NOT WITH_CUDNN_FRONTEND)
list(REMOVE_ITEM kernel_cu
"fusion/gpu/fused_scale_bias_relu_conv_bnstats_kernel.cu")
endif()

set(cc_search_pattern
"*.cc"
"cpu/*.cc"
5 changes: 5 additions & 0 deletions paddle/phi/kernels/autotune/cache.cc
@@ -47,6 +47,11 @@ std::string AlgorithmTypeString(int64_t algo_type) {
} else if (algo_type ==
static_cast<int64_t>(AlgorithmType::kConvBackwardFilterV8)) {
return "conv_backward_filter_v8";
} else if (algo_type ==
static_cast<int64_t>(AlgorithmType::kScaleBiasReluConvBNstats)) {
return "scale_bias_relu_conv_bnstats";
} else if (algo_type == static_cast<int64_t>(AlgorithmType::kBNFinalize)) {
return "bn_finalize";
}
#endif
return std::to_string(algo_type);
9 changes: 5 additions & 4 deletions paddle/phi/kernels/autotune/cache.h
@@ -55,7 +55,9 @@ enum class AlgorithmType {
kConvForwardV8 = 10,
kConvBackwardDataV8 = 11,
kConvBackwardFilterV8 = 12,
kAlgorithmCount = 13
kScaleBiasReluConvBNstats = 13,
kBNFinalize = 14,
kAlgorithmCount = 15
#endif
};

@@ -178,9 +180,8 @@ class AutoTuneCache {
conv_auto_tune_map_[key] = cache;
}
#ifdef PADDLE_WITH_CUDNN_FRONTEND
} else if (algo_type == AlgorithmType::kConvForwardV8 ||
algo_type == AlgorithmType::kConvBackwardDataV8 ||
algo_type == AlgorithmType::kConvBackwardFilterV8) {
} else if (algo_type >= AlgorithmType::kConvForwardV8 &&
algo_type <= AlgorithmType::kBNFinalize) {
int64_t key = static_cast<int64_t>(algo_type);
if (cudnn_v8_auto_tune_map_.find(key) == cudnn_v8_auto_tune_map_.end()) {
CudnnFrontendPlanCache cache;
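Note on the widened condition above: replacing the three equality tests with a range test assumes the cudnn-frontend entries stay contiguous in AlgorithmType. A defensive check, not part of this PR, could pin that assumption down:

```cpp
// Hypothetical guard: kConvForwardV8 (10) through kBNFinalize (14) span five
// contiguous enum values; the range test in AutoTuneCache depends on this.
static_assert(static_cast<int64_t>(AlgorithmType::kBNFinalize) -
                      static_cast<int64_t>(AlgorithmType::kConvForwardV8) ==
                  4,
              "cudnn-frontend AlgorithmType entries must remain contiguous");
```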
73 changes: 59 additions & 14 deletions paddle/phi/kernels/autotune/cache_cudnn_frontend.h
@@ -79,22 +79,22 @@ class CudnnFrontendPlanCache {
return ret;
}

void GetPlan(const cudnn_frontend::feature_vector_t &feature,
const cudnn_frontend::ExecutionPlan **plan,
int64_t *workspace_size,
cudnnHandle_t handle) {
void GetPlanAndWorkspaceSize(const cudnn_frontend::feature_vector_t &feature,
const cudnn_frontend::ExecutionPlan **plan,
int64_t *workspace_size,
cudnnHandle_t handle) {
// Note(tizheng): CUDNNv8 execution plan is not thread-safe.
// A shared plan being executed by different threads is
// generally not safe (for now).
std::lock_guard<std::mutex> lock(*cache_mutex_);
auto &local_map = map_[hasher(std::this_thread::get_id())];

auto it = local_map.find(GetExtendedFeature(feature, handle));
if (it == local_map.end()) {
PADDLE_THROW(phi::errors::InvalidArgument(
"[cudnn_frontend] Cached Plan Not Found."));
return;
}
PADDLE_ENFORCE_NE(it,
local_map.end(),
phi::errors::InvalidArgument(
"[cudnn_frontend] Cached Plan Not Found."));

*plan = &(it->second);
*workspace_size = (*plan)->getWorkspaceSize();
VLOG(4) << "Cached execution plan found." << (*plan)->getTag()
@@ -133,11 +133,12 @@ class CudnnFrontendPlanCache {
return FindPlan(op_graph.getFeatureVector(), handle);
}

void GetPlan(const cudnn_frontend::OperationGraph &op_graph,
const cudnn_frontend::ExecutionPlan **plan,
int64_t *workspace_size,
cudnnHandle_t handle) {
GetPlan(op_graph.getFeatureVector(), plan, workspace_size, handle);
void GetPlanAndWorkspaceSize(const cudnn_frontend::OperationGraph &op_graph,
const cudnn_frontend::ExecutionPlan **plan,
int64_t *workspace_size,
cudnnHandle_t handle) {
GetPlanAndWorkspaceSize(
op_graph.getFeatureVector(), plan, workspace_size, handle);
}

void InsertPlan(const cudnn_frontend::OperationGraph &op_graph,
@@ -176,5 +177,49 @@ class CudnnFrontendPlanCache {
int64_t cache_misses_{0};
}; // class CudnnFrontendPlanCache
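To see the renamed API in context, a hypothetical caller probes the cache, builds a plan only on a miss, and then fetches the plan together with its workspace requirement (call sites and autotuning details are elided; names follow the class above):

```cpp
// Illustrative usage of CudnnFrontendPlanCache (not an actual call site):
void RunWithCachedPlan(phi::autotune::CudnnFrontendPlanCache *cache,
                       cudnn_frontend::OperationGraph *op_graph,
                       cudnnHandle_t handle) {
  if (!cache->FindPlan(*op_graph, handle)) {
    // ... autotune, build an ExecutionPlan `plan`, then:
    // cache->InsertPlan(*op_graph, plan, handle);
  }
  const cudnn_frontend::ExecutionPlan *plan = nullptr;
  int64_t workspace_size = 0;
  cache->GetPlanAndWorkspaceSize(*op_graph, &plan, &workspace_size, handle);
  // ... allocate workspace_size bytes and execute *plan.
}
```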

template <typename T>
inline void BuildFeatureVectorSingle(cudnn_frontend::feature_vector_t *v,
const T &value) {
v->push_back(static_cast<int64_t>(value));
}

template <>
inline void BuildFeatureVectorSingle(cudnn_frontend::feature_vector_t *v,
const float &value) {
int64_t val = 0;
memcpy(&val, &value, sizeof(float));
v->push_back(val);
}

template <>
inline void BuildFeatureVectorSingle<std::vector<int64_t>>(
cudnn_frontend::feature_vector_t *v, const std::vector<int64_t> &value) {
v->insert(v->end(), value.begin(), value.end());
}

template <>
inline void BuildFeatureVectorSingle<std::vector<int>>(
cudnn_frontend::feature_vector_t *v, const std::vector<int> &value) {
for (auto &val : value) {
v->push_back(static_cast<int64_t>(val));
}
}

template <>
inline void BuildFeatureVectorSingle<std::string>(
cudnn_frontend::feature_vector_t *v, const std::string &value) {
v->push_back(std::hash<std::string>()(value));
}

inline void BuildFeatureVector(cudnn_frontend::feature_vector_t *v) { return; }

template <typename T, typename... Args>
inline void BuildFeatureVector(cudnn_frontend::feature_vector_t *v,
const T &value,
Args... args) {
BuildFeatureVectorSingle(v, value);
BuildFeatureVector(v, args...);
}

} // namespace autotune
} // namespace phi
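The BuildFeatureVector recursion above flattens heterogeneous op attributes into a single int64_t feature vector: integral values are widened, floats are bit-copied, int vectors are appended element-wise, and strings are hashed. A short illustrative use (the attribute values are made up):

```cpp
// Illustrative only: build one cache key from mixed attribute types.
cudnn_frontend::feature_vector_t feature;
phi::autotune::BuildFeatureVector(&feature,
                                  /*groups=*/1,            // widened to int64_t
                                  /*epsilon=*/1e-5f,       // float, bit-copied
                                  std::vector<int>{1, 1},  // appended per element
                                  std::string("NHWC"));    // hashed
```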