Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions src/04kernel/cuda/include/kernel/cuda/pad.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
#ifndef KERNEL_CUDA_PAD_CUH
#define KERNEL_CUDA_PAD_CUH

#include "threads_distributer.cuh"
#include <cstdint>

namespace refactor::kernel::cuda {

struct PadDimInfo {
unsigned int strideI, strideO, padS, dimI;
};

void launchPad(
KernelLaunchParameters const &,
uint8_t const *src, uint8_t const *src_const,
PadDimInfo const *dims, void *output,
unsigned int rank,
unsigned int blockSize);

}// namespace refactor::kernel::cuda

#endif// KERNEL_CUDA_PAD_CUH
4 changes: 2 additions & 2 deletions src/04kernel/cuda/include/kernel/cuda/slice.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,14 @@

namespace refactor::kernel::cuda {

struct DimInfo {
struct SliceDimInfo {
unsigned int strideO, skip;
int strideI;
};

void launchSlice(
KernelLaunchParameters const &,
void const *src, DimInfo const *dims, void *output,
void const *src, SliceDimInfo const *dims, void *output,
unsigned int rank,
unsigned int blockSize);

Expand Down
64 changes: 64 additions & 0 deletions src/04kernel/cuda/src/pad.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
#include "kernel/cuda/pad.cuh"
#include "macro.cuh"
#include <cstdint>

namespace refactor::kernel::cuda {

__global__ static void padKernel(
unsigned long long n,
uint8_t const *__restrict__ src,
uint8_t const *__restrict__ src_const,
PadDimInfo const *__restrict__ dims,
uint8_t *__restrict__ dst,
unsigned int rank,
unsigned int blockSize) {
for (auto tid = blockIdx.x * blockDim.x + threadIdx.x,
step = blockDim.x * gridDim.x;
tid < n;
tid += step) {
long rem = tid, j = 0;
bool flag = false;
for (auto i = 0; i < rank; ++i) {
auto strideO = __ldg(&(dims[i].strideO));
auto strideI = __ldg(&(dims[i].strideI));
auto padS = __ldg(&(dims[i].padS));
auto dimI = __ldg(&(dims[i].dimI));
auto pos = rem / strideO - padS;
if (pos < 0 || pos >= dimI) {
flag = true;
break;
}
j += pos * strideI;
rem %= strideO;
}
if (flag) {
optimizedMemcpy(dst + tid * blockSize, src_const, blockSize);
} else {
optimizedMemcpy(dst + tid * blockSize, src + j * blockSize, blockSize);
}
}
}

void launchPad(
KernelLaunchParameters const &params,
uint8_t const *src, uint8_t const *src_const,
PadDimInfo const *dims, void *output,
unsigned int rank,
unsigned int blockSize) {


padKernel<<<
params.gridSize,
params.blockSize,
0,
reinterpret_cast<cudaStream_t>(params.stream)>>>(
params.n,
src,
src_const,
dims,
reinterpret_cast<uint8_t *>(output),
rank,
blockSize);
}

}// namespace refactor::kernel::cuda
4 changes: 2 additions & 2 deletions src/04kernel/cuda/src/slice.cu
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ namespace refactor::kernel::cuda {
__global__ static void sliceKernel(
unsigned long long n,
uint8_t const *__restrict__ src,
DimInfo const *__restrict__ dims,
SliceDimInfo const *__restrict__ dims,
uint8_t *__restrict__ dst,
unsigned int rank,
unsigned int blockSize) {
Expand All @@ -29,7 +29,7 @@ namespace refactor::kernel::cuda {

void launchSlice(
KernelLaunchParameters const &params,
void const *src, DimInfo const *dims, void *output,
void const *src, SliceDimInfo const *dims, void *output,
unsigned int rank,
unsigned int blockSize) {
sliceKernel<<<
Expand Down
62 changes: 62 additions & 0 deletions src/04kernel/include/kernel/attributes/pad_info.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
#ifndef KERNEL_PAD_ATTRIBUTES_H
#define KERNEL_PAD_ATTRIBUTES_H

#include "../tensor.h"
#include "common.h"

namespace refactor::kernel {

struct PadType {
enum : uint8_t {
Constant,
Reflect,
Edge,
Wrap,
} type;

constexpr PadType() noexcept
: type(Constant) {}
constexpr PadType(decltype(type) type_) noexcept
: type(type_) {}
constexpr operator decltype(type)() const noexcept {
return type;
}
constexpr std::string_view toString() const noexcept {
switch (type) {
case Constant:
return "Constant";
case Reflect:
return "Reflect";
case Edge:
return "Edge";
case Wrap:
return "Wrap";
default:
UNREACHABLE();
}
}
};

namespace pad {
struct Dim {
int64_t dimI, dimO, pads;
};
}// namespace pad

using PadDimension = std::vector<pad::Dim>;

struct PadInfo {
struct Dim {
dim_t strideI, strideO, padS, dimI;
};
std::vector<Dim> dims;
dim_t blockCount, blockSize;

PadInfo(decltype(dims), dim_t, dim_t) noexcept;
PadInfo(PadDimension, Tensor const &);
void reform(dim_t) noexcept;
};

}// namespace refactor::kernel

#endif// KERNEL_PAD_ATTRIBUTES_H
21 changes: 21 additions & 0 deletions src/04kernel/include/kernel/collectors/pad.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
#ifndef KERNEL_PAD_H
#define KERNEL_PAD_H

#include "../attributes/pad_info.h"
#include "../collector.h"

namespace refactor::kernel {

struct PadCollector final : public InfoCollector {
PadDimension dims;
PadType mode;

explicit PadCollector(decltype(_target) target, PadDimension const &dims_, PadType mode_) noexcept
: InfoCollector(target), dims(std::move(dims_)), mode(mode_) {}

std::vector<KernelBox>
filter(TensorRefs inputs, TensorRefs outputs) const final;
};
}// namespace refactor::kernel

#endif// KERNEL_PAD_H
74 changes: 74 additions & 0 deletions src/04kernel/src/attributes/pad_info.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
#include "kernel/attributes/pad_info.h"
#include <numeric>

namespace refactor::kernel {
using PI = PadInfo;

PI::PadInfo(decltype(dims) dims_, dim_t blockCount_, dim_t blockSize_) noexcept
: dims(std::move(dims_)), blockCount(blockCount_), blockSize(blockSize_) {}

PI::PadInfo(PadDimension dims_, Tensor const &input) : dims{}, blockCount(1),
blockSize(input.dataType.size()) {
size_t rank = input.rank();
ASSERT(dims_.size() == rank, "Invalid to get PadInfo.");

size_t j = 0;
for (auto i : range0_(rank)) {
if (dims_[i].dimI != dims_[i].dimO || dims_[i].dimI != 1) {
if (j < i) { dims_[j] = dims_[i]; }
j++;
}
}
dims_.resize(rank = j);

// 合并末尾连续维度
for (auto i : range0_(rank).rev()) {
if (auto d = dims_[i].dimI; d == dims_[i].dimO) {
blockSize *= d;
dims_.pop_back();
} else {
auto &dim = dims_[i];
if (auto times = std::gcd(std::gcd(dims_[i].dimI, dims_[i].pads), dims_[i].dimO); times > 1) {
blockSize *= times;
dim.dimI /= times;
dim.dimO /= times;
dim.pads /= times;
}
break;
}
}
dims.reserve(rank = dims_.size());

dim_t strideI = 1, strideO = 1;
for (auto i : range0_(rank).rev()) {
auto const &dim = dims_[i];
dims.push_back({
strideI,
strideO,
static_cast<dim_t>(dim.pads),
static_cast<dim_t>(dim.dimI),
});
strideI *= dim.dimI;
strideO *= dim.dimO;
}
std::reverse(dims.begin(), dims.end());
blockCount = strideO;
}

void PI::reform(dim_t maxblockSize) noexcept {
auto blockSize_ = std::gcd(blockSize, maxblockSize);
if (blockSize_ == blockSize) { return; }
auto t = blockSize / blockSize_;
blockCount *= t;
blockSize = blockSize_;
for (auto &d : dims) {
d.strideI *= t;
d.strideO *= t;
d.padS *= t;
d.dimI *= t;
}
dims.resize(dims.size() + 1);
dims.back() = {1, 1, 0, t};
}

}// namespace refactor::kernel
2 changes: 1 addition & 1 deletion src/04kernel/src/attributes/slice_info.cc
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,6 @@ namespace refactor::kernel {
shape.pop_back();
dims_.pop_back();
} else {
dims.resize(rank = shape.size());
if (auto &dim = dims_[i]; dim.step == 1) {
if (auto times = std::gcd(std::gcd(dim.start, dim.length), shape[i]); times > 1) {
blockSize *= times;
Expand All @@ -58,6 +57,7 @@ namespace refactor::kernel {
break;
}
}
dims.resize(rank = shape.size());
dim_t strideI = 1;
for (auto i : range0_(rank).rev()) {
auto const &dim = dims_[i];
Expand Down
32 changes: 32 additions & 0 deletions src/04kernel/src/collectors/pad.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
#include "kernel/collectors/pad.h"
#include "../kernels/pad/cpu_kernel.hh"
#include "../kernels/pad/cuda_kernel.hh"

namespace refactor::kernel {

std::vector<KernelBox>
PadCollector::filter(TensorRefs inputs, TensorRefs outputs) const {
auto const &input = inputs[0];
PadInfo info(dims, input);
auto const_value = inputs.size() >= 3 ? std::make_optional(inputs[2]) : std::nullopt;

std::vector<KernelBox> ans;
switch (_target) {
case decltype(_target)::Cpu:
if (auto ptr = PadCpu::build(std::move(info), mode, const_value); ptr) {
ans.emplace_back(std::move(ptr));
}
break;
case decltype(_target)::Nvidia:
if (auto ptr = PadCuda::build(std::move(info), mode, const_value); ptr) {
ans.emplace_back(std::move(ptr));
}
break;
default:
UNREACHABLEX(void, "Unknown target");
}
return ans;
}

}// namespace refactor::kernel

66 changes: 66 additions & 0 deletions src/04kernel/src/kernels/pad/cpu_kernel.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
#include "cpu_kernel.hh"
#include <execution>

namespace refactor::kernel {
using K = PadCpu;

K::PadCpu(PadInfo info_, PadType mode_, size_t value_) noexcept
: Kernel(), info(std::move(info_)), mode(mode_), valueLength(value_) {}

auto K::build(PadInfo info, PadType mode, std::optional<std::reference_wrapper<Tensor const>> value_) noexcept -> KernelBox {
if (mode != PadType::Constant) {
return nullptr;
}
size_t value = value_ ? value_->get().dataType.size() : 0;
return std::make_unique<K>(std::move(info), mode, value);
}
auto K::typeId() noexcept -> size_t {
static uint8_t ID = 1;
return reinterpret_cast<size_t>(&ID);
}

auto K::kernelTypeId() const noexcept -> size_t {
return typeId();
}
auto K::description() const noexcept -> std::string_view {
return "Performing pad operation on generic cpu";
}


auto K::lower(Resources &) const noexcept -> RoutineWorkspace {
using namespace runtime;

return [info = this->info, value = this->valueLength](Resources &, void *workspace, void const *const *inputs, void *const *outputs) {
auto src = reinterpret_cast<uint8_t const *>(inputs[0]);
auto dst = reinterpret_cast<uint8_t *>(outputs[0]);
std::vector<uint8_t> defaultValue(info.blockSize, 0);
if (value != 0) {
auto constValue = reinterpret_cast<uint8_t const *>(inputs[2]);
for (auto i : range0_(info.blockSize / value)) {
std::memcpy(defaultValue.data() + i * value, constValue, value);
}
}
std::for_each_n(std::execution::par_unseq,
natural_t(0), info.blockCount,
[=, &info](auto i) {
long rem = i, j = 0;
bool flag = false;
for (auto const &dim : info.dims) {
auto pos = rem / dim.strideO - dim.padS;
if (pos < 0 || pos >= dim.dimI) {
flag = true;
break;
}
j += pos * dim.strideI;
rem %= dim.strideO;
}
if (flag) {
std::memcpy(dst + i * info.blockSize, defaultValue.data(), info.blockSize);
} else {
std::memcpy(dst + i * info.blockSize, src + j * info.blockSize, info.blockSize);
}
});
};
}

}// namespace refactor::kernel
Loading