From daeac287927f538b017cb83962650b52fb827812 Mon Sep 17 00:00:00 2001
From: YdrMaster
Date: Mon, 29 Jan 2024 09:38:16 +0800
Subject: [PATCH] feat: build the basic structure of Attention at each layer
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Signed-off-by: YdrMaster

---
 .../include/kernel/collectors/attention.h     | 19 ++++++++++++
 src/04kernel/src/collectors/attention.cc      | 31 +++++++++++++++++++
 .../include/computation/operators/attention.h | 21 +++++++++++++
 src/05computation/src/operators/attention.cc  | 13 ++++++++
 src/08-01llm/src/operators.cpp                |  4 ++-
 src/08-01llm/src/operators/attention.cc       | 31 +++++++++++++++++++
 src/08-01llm/src/operators/attention.hh       | 25 +++++++++++++++
 7 files changed, 143 insertions(+), 1 deletion(-)
 create mode 100644 src/04kernel/include/kernel/collectors/attention.h
 create mode 100644 src/04kernel/src/collectors/attention.cc
 create mode 100644 src/05computation/include/computation/operators/attention.h
 create mode 100644 src/05computation/src/operators/attention.cc
 create mode 100644 src/08-01llm/src/operators/attention.cc
 create mode 100644 src/08-01llm/src/operators/attention.hh

diff --git a/src/04kernel/include/kernel/collectors/attention.h b/src/04kernel/include/kernel/collectors/attention.h
new file mode 100644
index 00000000..527bc63f
--- /dev/null
+++ b/src/04kernel/include/kernel/collectors/attention.h
@@ -0,0 +1,19 @@
+#ifndef KERNEL_ATTENTION_H
+#define KERNEL_ATTENTION_H
+
+#include "../collector.h"
+
+namespace refactor::kernel {
+
+    struct AttentionCollector final : public InfoCollector {
+        dim_t maxSeqLen;
+
+        AttentionCollector(decltype(_target), decltype(maxSeqLen)) noexcept;
+
+        std::vector<KernelBox>
+        filter(TensorRefs inputs, TensorRefs outputs) const final;
+    };
+
+}// namespace refactor::kernel
+
+#endif// KERNEL_ATTENTION_H
diff --git a/src/04kernel/src/collectors/attention.cc b/src/04kernel/src/collectors/attention.cc
new file mode 100644
index 00000000..736db6cd
--- /dev/null
+++ b/src/04kernel/src/collectors/attention.cc
@@ -0,0 +1,31 @@
+#include "kernel/collectors/attention.h"
+#include "kernel/kernel.h"
+#include "kernel/tensor.h"
+// #include "../kernels/attention/cpu_kernel.hh"
+// #include "../kernels/attention/cuda_kernel.hh"
+
+namespace refactor::kernel {
+
+    AttentionCollector::AttentionCollector(
+        decltype(_target) target,
+        decltype(maxSeqLen) maxSeqLen_) noexcept
+        : InfoCollector(target),
+          maxSeqLen(maxSeqLen_) {}
+
+    std::vector<KernelBox>
+    AttentionCollector::filter(TensorRefs inputs, TensorRefs outputs) const {
+        std::vector<KernelBox> ans;
+        switch (_target) {
+            case decltype(_target)::Cpu:
+                break;
+            case decltype(_target)::Nvidia:
+                break;
+            case decltype(_target)::Mlu:
+                break;
+            default:
+                UNREACHABLEX(void, "Unknown target");
+        }
+        return ans;
+    }
+
+}// namespace refactor::kernel
diff --git a/src/05computation/include/computation/operators/attention.h b/src/05computation/include/computation/operators/attention.h
new file mode 100644
index 00000000..d5f37997
--- /dev/null
+++ b/src/05computation/include/computation/operators/attention.h
@@ -0,0 +1,21 @@
+#ifndef COMPUTATION_ATTENTION_H
+#define COMPUTATION_ATTENTION_H
+
+#include "../operator.h"
+
+namespace refactor::computation {
+
+    struct Attention final : public Operator {
+        dim_t maxSeqLen;
+
+        constexpr Attention(decltype(maxSeqLen) maxSeqLen_) noexcept
+            : Operator(), maxSeqLen(maxSeqLen_) {}
+
+        static size_t typeId() noexcept;
+        size_t opTypeId() const noexcept final;
+        std::string_view name() const noexcept final;
+    };
+
+}// namespace refactor::computation
+
+#endif// COMPUTATION_ATTENTION_H
diff --git a/src/05computation/src/operators/attention.cc b/src/05computation/src/operators/attention.cc
new file mode 100644
index 00000000..4624482a
--- /dev/null
+++ b/src/05computation/src/operators/attention.cc
@@ -0,0 +1,13 @@
+#include "computation/operators/attention.h"
+
+namespace refactor::computation {
+    using Op = Attention;
+
+    auto Op::typeId() noexcept -> size_t {
+        static uint8_t ID = 1;
+        return reinterpret_cast<size_t>(&ID);
+    }
+    auto Op::opTypeId() const noexcept -> size_t { return typeId(); }
+    auto Op::name() const noexcept -> std::string_view { return "Attention"; }
+
+}// namespace refactor::computation
diff --git a/src/08-01llm/src/operators.cpp b/src/08-01llm/src/operators.cpp
index b48e56fe..a99adb08 100644
--- a/src/08-01llm/src/operators.cpp
+++ b/src/08-01llm/src/operators.cpp
@@ -1,4 +1,5 @@
 #include "llm/operators.h"
+#include "operators/attention.hh"
 #include "operators/mat_mul.hh"
 #include "operators/rms_normalization.hh"
 
@@ -8,8 +9,9 @@ namespace refactor::llm {
     void register_() {
 #define REGISTER(NAME, CLASS) Operator::register_<CLASS>("llm::" #NAME)
         // clang-format off
-        REGISTER(MatMul          , MatMul          );
+        REGISTER(Attention       , Attention       );
         REGISTER(RmsNormalization, RmsNormalization);
+        REGISTER(MatMul          , MatMul          );
         // clang-format on
 #undef REGISTER
     }
diff --git a/src/08-01llm/src/operators/attention.cc b/src/08-01llm/src/operators/attention.cc
new file mode 100644
index 00000000..d8704dd6
--- /dev/null
+++ b/src/08-01llm/src/operators/attention.cc
@@ -0,0 +1,31 @@
+#include "computation/operators/attention.h"
+#include "attention.hh"
+#include "common.h"
+
+namespace refactor::llm {
+    using Op = Attention;
+
+    Op::Attention(decltype(maxSeqLen) maxSeqLen_)
+        : Operator(), maxSeqLen(maxSeqLen_) {}
+
+    auto Op::build(ModelContext const &, std::string_view, Attributes attributes) -> OpBox {
+        auto maxSeqLen = attributes.getOrInsert("max_seq_len", {0}).int_();
+        return OpBox(std::make_unique<Op>(maxSeqLen));
+    }
+    auto Op::typeId() -> size_t {
+        static uint8_t ID = 1;
+        return reinterpret_cast<size_t>(&ID);
+    }
+
+    auto Op::opTypeId() const -> size_t { return typeId(); }
+    auto Op::opTypeName() const -> std::string_view { return "llm::Attention"; }
+
+    auto Op::infer(TensorRefs inputs, InferOptions const &) const -> InferResult {
+        TODO("");
+    }
+
+    auto Op::lower(TensorRefs) const -> computation::OpBox {
+        TODO("");
+    }
+
+}// namespace refactor::llm
diff --git a/src/08-01llm/src/operators/attention.hh b/src/08-01llm/src/operators/attention.hh
new file mode 100644
index 00000000..1ec0d3e8
--- /dev/null
+++ b/src/08-01llm/src/operators/attention.hh
@@ -0,0 +1,25 @@
+#ifndef LLM_ATTENTION_HH
+#define LLM_ATTENTION_HH
+
+#include "frontend/operator.h"
+
+namespace refactor::llm {
+    using namespace frontend;
+
+    struct Attention final : public Operator {
+        dim_t maxSeqLen;
+
+        explicit Attention(decltype(maxSeqLen));
+
+        static OpBox build(ModelContext const &, std::string_view, Attributes);
+        static size_t typeId();
+
+        size_t opTypeId() const final;
+        std::string_view opTypeName() const final;
+        InferResult infer(TensorRefs, InferOptions const &) const final;
+        computation::OpBox lower(TensorRefs) const final;
+    };
+
+}// namespace refactor::llm
+
+#endif// LLM_ATTENTION_HH
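
Review note (not part of the patch): both new typeId() definitions rely on the
address-of-a-function-local-static idiom, giving each operator a unique,
process-wide id without RTTI or a central enum. Below is a minimal standalone
sketch of that idiom; the operator names are hypothetical and nothing beyond
the pattern itself is taken from the repository.

// Sketch of the typeId() idiom: each type's id is the address of a
// function-local static, which the runtime guarantees is unique per
// definition. Operator names here are illustrative only.
#include <cstddef>
#include <cstdint>
#include <cstdio>

struct Attention {
    static std::size_t typeId() noexcept {
        static std::uint8_t ID = 1;               // one static byte per type
        return reinterpret_cast<std::size_t>(&ID);// its address is the id
    }
};

struct MatMul {
    static std::size_t typeId() noexcept {
        static std::uint8_t ID = 1;
        return reinterpret_cast<std::size_t>(&ID);
    }
};

int main() {
    // Distinct types yield distinct ids; repeated calls return the same id.
    std::printf("Attention: %zx\n", Attention::typeId());
    std::printf("MatMul:    %zx\n", MatMul::typeId());
    std::printf("distinct:  %s\n",
                Attention::typeId() != MatMul::typeId() ? "yes" : "no");
}

Defining typeId() once in a .cc file, as the patch does, rather than inline in
the header, keeps exactly one ID object per operator even across shared-library
boundaries, where inline statics can otherwise be duplicated.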