Skip to content

Commit

Permalink
feat: 搭建 Attention 在各层的基本结构
Browse files Browse the repository at this point in the history
Signed-off-by: YdrMaster <ydrml@hotmail.com>
  • Loading branch information
YdrMaster committed Jan 29, 2024
1 parent add61cb commit daeac28
Show file tree
Hide file tree
Showing 7 changed files with 143 additions and 1 deletion.
19 changes: 19 additions & 0 deletions src/04kernel/include/kernel/collectors/attention.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
#ifndef KERNEL_ATTENTION_H
#define KERNEL_ATTENTION_H

#include "../collector.h"

namespace refactor::kernel {

    /// Collects candidate kernels for the Attention operator on the
    /// current target backend (see filter).
    struct AttentionCollector final : public InfoCollector {
        // Maximum sequence length the attention implementation must support
        // (presumably sizes the KV cache — TODO confirm once kernels exist).
        dim_t maxSeqLen;

        /// Constructs the collector for a target backend with the given
        /// maximum sequence length.
        AttentionCollector(decltype(_target), decltype(maxSeqLen)) noexcept;

        /// Returns the kernels able to run Attention on `inputs`/`outputs`;
        /// currently none are registered (see attention.cc).
        std::vector<KernelBox>
        filter(TensorRefs inputs, TensorRefs outputs) const final;
    };

}// namespace refactor::kernel

#endif// KERNEL_ATTENTION_H
31 changes: 31 additions & 0 deletions src/04kernel/src/collectors/attention.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
#include "kernel/collectors/attention.h"
#include "kernel/kernel.h"
#include "kernel/tensor.h"
// #include "../kernels/attention/cpu_kernel.hh"
// #include "../kernels/attention/cuda_kernel.hh"

namespace refactor::kernel {

    /// Forwards the target to the base collector and stores the
    /// maximum sequence length for later kernel selection.
    AttentionCollector::AttentionCollector(
        decltype(_target) target,
        decltype(maxSeqLen) maxSeqLen_) noexcept
        : InfoCollector(target),
          maxSeqLen(maxSeqLen_) {}

    /// Selects attention kernels for the current target.
    /// No kernel implementation is wired up yet, so every known target
    /// deliberately yields an empty candidate list.
    std::vector<KernelBox>
    AttentionCollector::filter(TensorRefs inputs, TensorRefs outputs) const {
        std::vector<KernelBox> ans;
        switch (_target) {
            case decltype(_target)::Cpu:
            case decltype(_target)::Nvidia:
            case decltype(_target)::Mlu:
                // fall through: kernels to be added per target later
                break;
            default:
                UNREACHABLEX(void, "Unknown target");
        }
        return ans;
    }

}// namespace refactor::kernel
21 changes: 21 additions & 0 deletions src/05computation/include/computation/operators/attention.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
#ifndef COMPUTATION_ATTENTION_H
#define COMPUTATION_ATTENTION_H

#include "../operator.h"

namespace refactor::computation {

    /// Attention operator at the computation (backend-independent) layer.
    struct Attention final : public Operator {
        // Maximum sequence length (presumably bounds the KV cache —
        // TODO confirm once lowering/kernels are implemented).
        dim_t maxSeqLen;

        // `explicit` keeps a bare dim_t from silently converting into an
        // Attention op, and matches the sibling llm::Attention constructor.
        constexpr explicit Attention(decltype(maxSeqLen) maxSeqLen_) noexcept
            : Operator(), maxSeqLen(maxSeqLen_) {}

        /// Process-wide unique id for this operator type.
        static size_t typeId() noexcept;
        size_t opTypeId() const noexcept final;
        std::string_view name() const noexcept final;
    };

}// namespace refactor::computation

#endif// COMPUTATION_ATTENTION_H
13 changes: 13 additions & 0 deletions src/05computation/src/operators/attention.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
#include "computation/operators/attention.h"

namespace refactor::computation {
    using Op = Attention;

    /// The address of a function-local static is unique per type in the
    /// process, so it serves as a cheap runtime type id.
    size_t Op::typeId() noexcept {
        static uint8_t ID = 1;
        return reinterpret_cast<size_t>(&ID);
    }
    size_t Op::opTypeId() const noexcept { return typeId(); }
    std::string_view Op::name() const noexcept { return "Attention"; }

}// namespace refactor::computation
4 changes: 3 additions & 1 deletion src/08-01llm/src/operators.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#include "llm/operators.h"
#include "operators/attention.hh"
#include "operators/mat_mul.hh"
#include "operators/rms_normalization.hh"

Expand All @@ -8,8 +9,9 @@ namespace refactor::llm {
void register_() {
#define REGISTER(NAME, CLASS) Operator::register_<CLASS>("llm::" #NAME)
// clang-format off
REGISTER(MatMul , MatMul );
REGISTER(Attention , Attention );
REGISTER(RmsNormalization, RmsNormalization);
REGISTER(MatMul , MatMul );
// clang-format on
#undef REGISTER
}
Expand Down
31 changes: 31 additions & 0 deletions src/08-01llm/src/operators/attention.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
#include "computation/operators/attention.h"
#include "attention.hh"
#include "common.h"

namespace refactor::llm {
    using Op = Attention;

    /// Stores the maximum sequence length for the frontend Attention op.
    Op::Attention(decltype(maxSeqLen) maxSeqLen_)
        : Operator(), maxSeqLen(maxSeqLen_) {}

    /// Builds the operator from serialized attributes.
    /// "max_seq_len" defaults to integer 0; since the attribute is declared
    /// with an integer default and feeds an integral dim_t, it must be read
    /// back with int_(), not float_() (the latter was a type-mismatch bug).
    auto Op::build(ModelContext const &, std::string_view, Attributes attributes) -> OpBox {
        auto maxSeqLen = attributes.getOrInsert("max_seq_len", {0}).int_();
        return OpBox(std::make_unique<Op>(maxSeqLen));
    }
    auto Op::typeId() -> size_t {
        // Address of a function-local static: unique per type, used as id.
        static uint8_t ID = 1;
        return reinterpret_cast<size_t>(&ID);
    }

    auto Op::opTypeId() const -> size_t { return typeId(); }
    auto Op::opTypeName() const -> std::string_view { return "llm::Attention"; }

    /// Shape/dtype inference — not implemented yet in this commit.
    auto Op::infer(TensorRefs inputs, InferOptions const &) const -> InferResult {
        TODO("");
    }

    /// Lowering to the computation-layer Attention — not implemented yet.
    auto Op::lower(TensorRefs) const -> computation::OpBox {
        TODO("");
    }

}// namespace refactor::llm
25 changes: 25 additions & 0 deletions src/08-01llm/src/operators/attention.hh
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
// Include guard renamed from LLM_RMS_ATTENTION_HH: the "RMS" was a
// copy-paste leftover from rms_normalization.hh and did not match this file.
#ifndef LLM_ATTENTION_HH
#define LLM_ATTENTION_HH

#include "frontend/operator.h"

namespace refactor::llm {
    using namespace frontend;

    /// Frontend Attention operator for the llm domain.
    struct Attention final : public Operator {
        // Maximum sequence length (presumably bounds the KV cache —
        // TODO confirm once infer/lower are implemented).
        dim_t maxSeqLen;

        explicit Attention(decltype(maxSeqLen));

        /// Factory used by the operator registry (see operators.cpp).
        static OpBox build(ModelContext const &, std::string_view, Attributes);
        static size_t typeId();

        size_t opTypeId() const final;
        std::string_view opTypeName() const final;
        InferResult infer(TensorRefs, InferOptions const &) const final;
        computation::OpBox lower(TensorRefs) const final;
    };

}// namespace refactor::llm

#endif// LLM_ATTENTION_HH

0 comments on commit daeac28

Please sign in to comment.