This repository has been archived by the owner on Oct 25, 2024. It is now read-only.

Clean data #58

Merged · 149 commits · Apr 4, 2023
Commits (149); the diff shown further below is from a single commit, 41c4281.
ed02951
Added feature and example of NAS (#516)
XinyuYe-Intel Jan 17, 2023
9812468
fix input_shape_recorder issue in SUT multi-ins (#527)
zhentaoyu Jan 30, 2023
e04b965
revision (#545)
XuhuiRen Jan 30, 2023
2ce0b0d
fix auto-distillation example failed (#550)
n1ck-guo Jan 30, 2023
e45ec64
Update pt version (#548)
VincyZhang Jan 31, 2023
cec31c0
Update docs (#544)
VincyZhang Jan 31, 2023
55b7cd5
Update pt version (#554)
VincyZhang Feb 1, 2023
12f77b6
add src1_perm in matmul python op (#537)
zhentaoyu Feb 2, 2023
224bcbf
Fix requirement (#556)
VincyZhang Feb 2, 2023
f039a9c
remove version in yaml (#557)
VincyZhang Feb 3, 2023
3b116dd
[Kernels] fix windows fail (#542)
airMeng Feb 3, 2023
6d11d72
turn on dynamic link by default (#539)
zhenwei-intel Feb 7, 2023
2a97052
[Kernels] Attention ref klocwork issues (#553)
airMeng Feb 7, 2023
9dafe11
add examples for gptj (#541)
XuhuiRen Feb 8, 2023
eef6487
improve to support C++ api document (#567)
NeoZhangJianyu Feb 8, 2023
28515ad
Added setfit notebook example. (#552)
XinyuYe-Intel Feb 9, 2023
6bbef46
Simplize readme (#564)
VincyZhang Feb 9, 2023
6cae222
Added a nas example. (#563)
XinyuYe-Intel Feb 10, 2023
ea86040
Udpate onnx version (#572)
VincyZhang Feb 10, 2023
de08b4b
update main page and example (#577)
VincyZhang Feb 12, 2023
cc5f325
Update jit_seq_cpy_2x8x8.hpp (#576)
NeoZhangJianyu Feb 12, 2023
95c6e4c
Enable lat_int8 (#565)
intellinjun Feb 13, 2023
70875d6
add docstring in tf_extractor and tf_utils (#569)
zhentaoyu Feb 13, 2023
a780c6f
docstring (#566)
Zhenzhong1 Feb 13, 2023
accfd87
update readme (#568)
zhenwei-intel Feb 13, 2023
401bd77
Refined example documents (#562)
XinyuYe-Intel Feb 13, 2023
5b8e386
add docstring to optimization (#573)
violetch24 Feb 13, 2023
e60b00e
Refactor TF quantization/pruning/distillation examples document (#571)
Spycsh Feb 14, 2023
3417a6a
Update readme (#581)
VincyZhang Feb 14, 2023
2440ca8
[Kernels] Trans MHA merge lnormalized spmm (#558)
zhewang1-intc Feb 15, 2023
6281d8d
sync external repo (#590)
VincyZhang Feb 15, 2023
084da57
Document fix (#591)
VincyZhang Feb 15, 2023
bfbb3c9
Add showcase bloom (#592)
VincyZhang Feb 15, 2023
c2f9ec6
[Kernels] visualize sparsity script (#454)
yuchengliu1 Feb 15, 2023
366e1ee
Enhance compile op registering (#584)
zhentaoyu Feb 15, 2023
181fbb3
Update distillation examples (#595)
VincyZhang Feb 16, 2023
4352992
add base and large bert example to pruner (#560)
n1ck-guo Feb 16, 2023
720d36d
[Engine]: add squeeze op and binary ops (#456)
zhenwei-intel Feb 17, 2023
478abd2
add docstring and update README (#579)
changwangss Feb 17, 2023
097f1c7
docstring (#599)
XuhuiRen Feb 17, 2023
a7671a0
[Kernels] fix improper-null-terminator and MHA cpplint (#594)
sunjiweiswift Feb 17, 2023
66cb3c9
[Neural Engine] Add the code to support tiny vit HF model (#561)
a32543254 Feb 17, 2023
047dd24
Fixed type error for PyTorch Pruning examples (#603)
PenghuiCheng Feb 17, 2023
f79b14d
revise md file for examples (#601)
XuhuiRen Feb 20, 2023
328455a
Changed to quantize SetFit model with INC (#606)
XinyuYe-Intel Feb 21, 2023
e144884
Wangwenqi add op (#596)
CeciliaWwq Feb 21, 2023
b8fbf9e
Zhenzhong/op attr (#604)
Zhenzhong1 Feb 21, 2023
3d39678
BinaryOP->BinaryOp frontend (#613)
Zhenzhong1 Feb 21, 2023
59491be
fix link color (#536)
NeoZhangJianyu Feb 21, 2023
aa5c714
fix of JIRA-391: windows build issue (#588)
luoyu-intel Feb 21, 2023
de05ee8
example README docs refine (#574)
violetch24 Feb 22, 2023
07b95bb
add gpt-neox example (#540)
violetch24 Feb 22, 2023
752942a
Update requiremend docs (#618)
VincyZhang Feb 23, 2023
82b3343
added multi-nodes QAT support for Question Answering and Text Classif…
XinyuYe-Intel Feb 23, 2023
50709bd
Pick back public repo (#622)
VincyZhang Feb 23, 2023
9eaf30a
Guoheng/fix bug 432 (#587)
n1ck-guo Feb 24, 2023
71441fb
[Kernels] bugfix benchmark spmm (#611)
sunjiweiswift Feb 25, 2023
8499771
Remove redundant code (#616)
VincyZhang Feb 27, 2023
bc81ecb
add image classification example (#225)
lkk12014402 Feb 27, 2023
882ee35
[Kernels] Refine headers for library compatibility and documents (#605)
airMeng Feb 28, 2023
41c4281
[Kernels] Reference impl and UT for Dense MHA with dynamic quantization (#612)
yi1ding Feb 28, 2023
a01bb90
update main page (#651)
VincyZhang Mar 1, 2023
1bf6741
fix klocwork issues (#649)
zhenwei-intel Mar 1, 2023
084ae49
Fix sparse bert mini example (#647)
a32543254 Mar 2, 2023
df6e369
fix for pruning import (#653)
violetch24 Mar 2, 2023
14e6a34
update README (#655)
violetch24 Mar 2, 2023
24f247c
Support gather with pytorch interface (#607)
yuchengliu1 Mar 2, 2023
d16e2ab
remove onnxruntime-extension (#660)
VincyZhang Mar 3, 2023
5c5a5ac
add longformer pruning codes (#585)
lkk12014402 Mar 4, 2023
827dcaa
[Kernels] fix translnorm benchmark fail (#643)
zhewang1-intc Mar 6, 2023
34d8b90
Opennmt fp32 (#598)
zhentaoyu Mar 6, 2023
0da5eb0
update inc build from source (#671)
VincyZhang Mar 6, 2023
7c3da8b
[Kernels] Static Q10N MHA support for GPT-J (#657)
yi1ding Mar 7, 2023
154aabb
Add the DLSA E2E solution to the ITREX (#632)
LifengWang Mar 7, 2023
7c43cd4
[Kernels] fix kernels format (#673)
airMeng Mar 8, 2023
2325063
fix empty_ops (#676)
zhentaoyu Mar 9, 2023
a4aa7f0
remove invalid code (#677)
PenghuiCheng Mar 9, 2023
b5c54de
fix for int8 flag (#684)
violetch24 Mar 10, 2023
5a60fc9
[Kernels] kernel code generator for gpu (#610)
VincyZhang Mar 10, 2023
f2390e3
stable diffusion enabling, including text encoder / vae decoder / une…
Zhenzhong1 Mar 10, 2023
da4d9cd
update pytorch pruner to v2.0 (#624)
n1ck-guo Mar 10, 2023
edc9090
Support smooth quantization and enable bloom model example (#675)
PenghuiCheng Mar 13, 2023
46fa399
Dynamic quantization in executor (#593)
yuchengliu1 Mar 14, 2023
a089147
fix pytest (#699)
yuchengliu1 Mar 14, 2023
bc38e86
parse torchscript model and build new graph (#687)
zhenwei-intel Mar 14, 2023
e135aa0
design a new benchmark API (#656)
xin3he Mar 15, 2023
35aba4e
logsoftmax modified solved conflicts (#682)
CeciliaWwq Mar 15, 2023
6e38b5d
avoid aggr init list & some windows warnings (#697)
yi1ding Mar 15, 2023
a10232d
fix bf16 (#701)
a32543254 Mar 15, 2023
edc855c
removed unspport recipe (#692)
PenghuiCheng Mar 15, 2023
8805289
add longformer (#669)
violetch24 Mar 15, 2023
684b6c6
fix for new benchmark (#706)
violetch24 Mar 15, 2023
2d0fec0
return torch model instead of inc model (#695)
xin3he Mar 15, 2023
56cf71a
stable diffusion bf16 enabling and example initialize (#691)
Zhenzhong1 Mar 16, 2023
6a2f259
[Kernels] Dynamic quant matmul for stable diffusion (#686)
zhewang1-intc Mar 16, 2023
acb693c
add devcatalog (#666)
VincyZhang Mar 16, 2023
eb21d2e
ut optimize (#731)
Zhenzhong1 Mar 20, 2023
60b0a00
[Engine]: Support int8 torch model per-tensor and per-channel (#703)
zhenwei-intel Mar 21, 2023
a8b9e8b
fix for new benchmark API (#729)
violetch24 Mar 21, 2023
525490f
recover tf examples (#723)
Spycsh Mar 21, 2023
5d61e9a
close yaml and bin file (#716)
zhentaoyu Mar 21, 2023
b49c161
[Kernels] dynamic quant mop up (#715)
zhewang1-intc Mar 21, 2023
5387bc0
readd engine related ut (#483)
zhentaoyu Mar 22, 2023
547323d
[Engine]fix lat (#704)
a32543254 Mar 22, 2023
1067b41
fix quant node and pattern order (#734)
zhenwei-intel Mar 22, 2023
c593438
add example for text generation (#664)
XuhuiRen Mar 22, 2023
44570e6
Added a textual inversion distillation for quantization example. (#586)
XinyuYe-Intel Mar 23, 2023
eaa14ce
Stable diffusion example optimize (#741)
Zhenzhong1 Mar 24, 2023
828e5d2
Stable Diffusion README and UT optimize. (#747)
Zhenzhong1 Mar 24, 2023
1ce5cf7
add flan-t5 for summarization (#733)
changwangss Mar 24, 2023
9e818cc
Bert dq examples (#742)
yuchengliu1 Mar 26, 2023
fbb4a6c
skip weight sharing ut (#751)
VincyZhang Mar 27, 2023
4640a05
fix compile fail (#752)
zhewang1-intc Mar 27, 2023
6b9c40d
Refine examples (#690)
VincyZhang Mar 27, 2023
178e85e
Fix example readme (#757)
VincyZhang Mar 28, 2023
9dbf282
add save_model API (#735)
xin3he Mar 28, 2023
93128ce
fix doc typo (#759)
zhentaoyu Mar 28, 2023
f1b41de
Patterns for GPT-J (#743)
zhenwei-intel Mar 28, 2023
e107817
Improve online document with source link (#753)
NeoZhangJianyu Mar 28, 2023
8eca747
Enhancement document of data augmentation (#661)
PenghuiCheng Mar 28, 2023
da67747
fix windows UT (#761)
yuchengliu1 Mar 28, 2023
02ed5b7
klockworks issues (#756)
airMeng Mar 29, 2023
a527766
Support smooth quant args with 'auto' and impove the docstring for co…
PenghuiCheng Mar 29, 2023
3645965
Fixed benchmark error since neural_compressor changed API name (#770)
PenghuiCheng Mar 30, 2023
ac2cbdd
[GPT-J] cherry-pick patterns and ops (#760)
zhenwei-intel Mar 30, 2023
ead0a15
build32 klocwork
airMeng Mar 31, 2023
222bfb8
leave TODO for dynamic_quant_matmul_ref
airMeng Mar 31, 2023
10e5a39
Revert "build32 klocwork"
airMeng Mar 31, 2023
ae80fb6
Revert "leave TODO for dynamic_quant_matmul_ref"
airMeng Mar 31, 2023
1cfdebf
fix bug with Escape characters issues by shlex quote (#766)
CeciliaWwq Mar 31, 2023
968064f
add gpt int8 test (#782)
zhenwei-intel Mar 31, 2023
522c0ec
Use DLOG to improve release efficiency (#781)
sunjiweiswift Mar 31, 2023
8439e5a
update example (#778)
yuchengliu1 Apr 1, 2023
de30df4
fix release (#787)
a32543254 Apr 1, 2023
6c7c6f6
fix windows pytest (#790)
a32543254 Apr 3, 2023
4b2e666
fix engine integration doc (#795)
zhentaoyu Apr 3, 2023
1fe646b
fix the shlex.quote issue (#794)
Zhenzhong1 Apr 3, 2023
20c75db
Fixed typo for smooth_quant example (#792)
PenghuiCheng Apr 3, 2023
96c1b95
add the vit example (#797)
a32543254 Apr 3, 2023
dabf27f
Added example for finetuning chatbot. (#763)
XinyuYe-Intel Apr 3, 2023
cb355f0
add workaround (#800)
Spycsh Apr 4, 2023
5ad2ca7
update example requiremnet (#789)
VincyZhang Apr 4, 2023
47e33b6
Build 31 klockwork (#777)
airMeng Apr 4, 2023
1c4eee4
fix pytorch examples (#796)
violetch24 Apr 4, 2023
9676fb9
Remove LLaMA for legal issue. (#803)
XinyuYe-Intel Apr 4, 2023
64ab07b
update version to 1.0 (#805)
VincyZhang Apr 4, 2023
b0e6088
Data cleaning for Intel domain dataset (#807)
XuhuiRen Apr 4, 2023
035e4a9
update readme
VincyZhang Apr 4, 2023
32c3d2e
Merge branch 'main' into clean_data
VincyZhang Apr 4, 2023
[Kernels] Reference impl and UT for Dense MHA with dynamic quantization (#612)



---------

Co-authored-by: Wang,Zhe <zhe1.wang@intel.com>
yi1ding and zhewang1-intc authored Feb 28, 2023
commit 41c4281af18e8e8b2ac4f4a7fa1c173255658e1d
@@ -88,6 +88,7 @@ class SPARSE_API_ kernel_proxy : public proxy_base<kernel_t, std::shared_ptr<con
public:
inline const jd::kernel_kind& kernel_kind() const { return get_sp()->kd()->kernel_kind(); }
void execute(const std::vector<const void*>& rt_data) const;
size_t get_workspace_size() const;
};

//// The following paragraphs are the various derived kernels and its descriptors.
@@ -164,6 +165,13 @@ class SPARSE_API_ transpose_mha_desc : public kernel_desc_proxy {
virtual ~transpose_mha_desc() {}
};

class SPARSE_API_ dyn_quantize_mha_desc : public kernel_desc_proxy {
public:
dyn_quantize_mha_desc() {}
explicit dyn_quantize_mha_desc(const operator_desc& op_desc) : kernel_desc_proxy(op_desc) {}
virtual ~dyn_quantize_mha_desc() {}
};

/**
* @brief Derived proxy class, interfacing to the real/cached sparse_matmul_t.
*/
@@ -237,5 +245,12 @@ class SPARSE_API_ transpose_mha : public kernel_proxy {
virtual ~transpose_mha() {}
};

class SPARSE_API_ dyn_quantize_mha : public kernel_proxy {
public:
dyn_quantize_mha() {}
explicit dyn_quantize_mha(const kernel_desc_proxy& kdp) : kernel_proxy(kdp) {}
virtual ~dyn_quantize_mha() {}
};

} // namespace jd
#endif // ENGINE_SPARSELIB_INCLUDE_INTERFACE_HPP_
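
For context on how the new proxy pair is driven, here is a hedged usage sketch following the same create-then-execute flow as the other kernels in this header. Only the constructors, get_workspace_size(), and execute() appear in this diff; the operator_desc and the runtime buffers are caller-supplied placeholders.

// Hedged usage sketch of the proxy pair added above; not library code.
#include <vector>
#include "interface.hpp"  // assumed include path

// Assumption: the caller passes rt_data already sized to the io layout this
// commit defines in dyn_quantize_mha_types.hpp (see later in this diff).
void run_dyn_quantize_mha(const jd::operator_desc& op_desc,
                          std::vector<const void*> rt_data) {
  jd::dyn_quantize_mha_desc kd(op_desc);  // descriptor proxy added above
  jd::dyn_quantize_mha kern(kd);          // kernel proxy added above
  // New in this commit: size the scratch buffer from the kernel itself.
  std::vector<char> workspace(kern.get_workspace_size());
  rt_data[5] = workspace.data();  // assumption: slot 5 is io::TMP
  kern.execute(rt_data);
}
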
@@ -55,6 +55,7 @@ class kernel_t {
// init kernel_t
virtual bool init() = 0;
virtual bool execute(const std::vector<const void*>& rt_data) const = 0;
virtual size_t get_workspace_size() const { return 0; }

public:
const std::shared_ptr<const kernel_desc_t>& kd() const { return kd_; }
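
Because the new virtual defaults to 0, existing kernels compile unchanged; only kernels that need scratch memory override it. An illustrative override, loosely modeled on the "size of K + size of V + ~1M" note attached to the TMP slot later in this diff (the class and its sizing fields are hypothetical):

// Illustrative only: a derived kernel_t reporting its scratch requirement.
#include <cstddef>
#include <vector>
#include "kernel.hpp"  // assumed include path

class my_mha_k_t : public jd::kernel_t {
 public:
  using jd::kernel_t::kernel_t;  // inherit the base constructor
  bool init() override { return true; }
  bool execute(const std::vector<const void*>&) const override { return true; }
  size_t get_workspace_size() const override {
    return k_bytes_ + v_bytes_ + (1 << 20);  // reordered K + V + ~1 MiB
  }

 private:
  size_t k_bytes_ = 0;  // hypothetical: bytes of the reordered K buffer
  size_t v_bytes_ = 0;  // hypothetical: bytes of the reordered V buffer
};
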
@@ -0,0 +1,100 @@
// Copyright (c) 2021 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef ENGINE_SPARSELIB_INCLUDE_KERNELS_DYN_QUANTIZE_MHA_REF_HPP_
#define ENGINE_SPARSELIB_INCLUDE_KERNELS_DYN_QUANTIZE_MHA_REF_HPP_

#include <memory>
#include <vector>

#include "amx_utils.hpp"
#include "cpu_isa.hpp"
#include "dyn_quantize_mha_types.hpp"
#include "kernel.hpp"
#include "kernel_desc.hpp"
#include "operator_desc.hpp"
#include "utils.hpp"

namespace jd {

/**
* @brief
* Q K V
* | | |
* | | |
* | Reorder |
* \ / |
* \ / Reorder
* Matmul /
* | /
* | /
* Softmax /
* \ /
* \ /
* Matmul
* |
* |
* Output
*/
class dyn_quantize_mha_ref_k_t;

class SPARSE_API_ dyn_quantize_mha_ref_kd_t : public kernel_desc_t {
public:
using io = ssd::dyn_quantize_mha_io::io;
explicit dyn_quantize_mha_ref_kd_t(const jd::operator_desc& op_desc)
: kernel_desc_t(kernel_kind::dyn_quantize_mha), op_desc_(op_desc) {}
virtual ~dyn_quantize_mha_ref_kd_t() {}

bool init() override;
DECLARE_COMMON_PD_T(dyn_quantize_mha_ref_k_t, dyn_quantize_mha_ref_kd_t);

const jd::operator_desc& get_operator_desc() const override { return op_desc_; }
inline std::vector<dim_t> shape() const override {
return {
op_desc_.tensor_descs()[io::Q].shape()[0], // batch_size
op_desc_.tensor_descs()[io::Q].shape()[2], // head_num
op_desc_.tensor_descs()[io::Q].shape()[1], // M
op_desc_.tensor_descs()[io::Q].shape()[3], // head_size
op_desc_.tensor_descs()[io::K].shape()[1], // N
};
}

private:
jd::operator_desc op_desc_;
};

class SPARSE_API_ dyn_quantize_mha_ref_k_t : public kernel_t {
public:
using io = ssd::dyn_quantize_mha_io::io;
using kd_t = dyn_quantize_mha_ref_kd_t;
explicit dyn_quantize_mha_ref_k_t(const std::shared_ptr<const kernel_desc_t>& kd);
virtual ~dyn_quantize_mha_ref_k_t() {}
// Delete move constructor and move operator
dyn_quantize_mha_ref_k_t(dyn_quantize_mha_ref_k_t&&) = delete;
dyn_quantize_mha_ref_k_t& operator=(dyn_quantize_mha_ref_k_t&&) = delete;
// Delete copy constructor and copy operator
dyn_quantize_mha_ref_k_t(const dyn_quantize_mha_ref_k_t&) = delete;
dyn_quantize_mha_ref_k_t& operator=(const dyn_quantize_mha_ref_k_t&) = delete;

bool init() override;
bool execute(const std::vector<const void*>& rt_data) const override;
const std::shared_ptr<const kd_t> derived_kd() const { return std::static_pointer_cast<const kd_t>(kd_); }

private:
std::vector<std::vector<dim_t>> t_shapes_;
int32_t batch_size_, head_num_, M_, head_size_, N_;
};

} // namespace jd
#endif // ENGINE_SPARSELIB_INCLUDE_KERNELS_DYN_QUANTIZE_MHA_REF_HPP_
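
The block comment above doubles as the kernel's dataflow spec: reorder K and V, multiply Q by K^T, apply softmax, then multiply by V. As a cross-check, a freestanding fp32 restatement of that dataflow for a single head follows, with a dynamic per-tensor int8 quantization of the output tacked on. This is a sketch of the math under a symmetric-quantization assumption (zero-point 0), not the library's reference kernel, which also handles the MASK input and the per-tensor scale and zero-point slots declared in the io list later in this diff.

// Freestanding single-head sketch of Q*K^T -> softmax -> *V; fp32 only.
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <vector>

// O = softmax(Q * K^T) * V; Q is MxD, K and V are NxD, all row-major.
void mha_ref(const float* q, const float* k, const float* v, float* o,
             int M, int N, int D) {
  std::vector<float> s(N);
  for (int m = 0; m < M; ++m) {
    float row_max = -INFINITY;
    for (int n = 0; n < N; ++n) {  // S = Q * K^T
      float acc = 0.f;
      for (int d = 0; d < D; ++d) acc += q[m * D + d] * k[n * D + d];
      s[n] = acc;
      row_max = std::max(row_max, acc);
    }
    float sum = 0.f;  // numerically stable softmax over the row
    for (int n = 0; n < N; ++n) { s[n] = std::exp(s[n] - row_max); sum += s[n]; }
    for (int n = 0; n < N; ++n) s[n] /= sum;
    for (int d = 0; d < D; ++d) {  // O = P * V
      float acc = 0.f;
      for (int n = 0; n < N; ++n) acc += s[n] * v[n * D + d];
      o[m * D + d] = acc;
    }
  }
}

// Dynamic per-tensor quantization: the scale comes from the runtime min/max
// of the data rather than being fixed ahead of time (symmetric, zp = 0).
void dyn_quantize(const float* x, int8_t* xq, int len, float* scale) {
  float amax = 0.f;
  for (int i = 0; i < len; ++i) amax = std::max(amax, std::fabs(x[i]));
  *scale = amax > 0.f ? amax / 127.f : 1.f;
  for (int i = 0; i < len; ++i)
    xq[i] = static_cast<int8_t>(std::lround(x[i] / *scale));
}

The real kernel runs this per batch and per head, presumably over the 4D tensors that the abcd format_type added later in this diff describes.
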
@@ -0,0 +1,54 @@
// Copyright (c) 2022 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef ENGINE_SPARSELIB_INCLUDE_KERNELS_DYN_QUANTIZE_MHA_TYPES_HPP_
#define ENGINE_SPARSELIB_INCLUDE_KERNELS_DYN_QUANTIZE_MHA_TYPES_HPP_

#include <vector>

#include "amx_utils.hpp"
#include "param_types.hpp"

namespace jd {
namespace ssd {
namespace dyn_quantize_mha_io {
enum io {
Q,
K,
MASK,
V,
DST,
TMP, // size of K + size of V + ~1M

Q_SCALE,
Q_ZP,
K_SCALE,
K_ZP,
V_SCALE,
V_ZP,
DST_SCALE,
DST_ZP,

BATCH_SIZE,
HEAD_NUM,
HEAD_SIZE,
M, // "seq_len" for Q & DST
N, // "seq_len" for K & V
dyn_quantize_mha_io_MAX = N,
};
} // namespace dyn_quantize_mha_io

} // namespace ssd
} // namespace jd
#endif // ENGINE_SPARSELIB_INCLUDE_KERNELS_DYN_QUANTIZE_MHA_TYPES_HPP_
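
Since the enum doubles as the index layout of the runtime-data vector, a caller populates rt_data in this order. A hedged sketch follows; every pointer is a placeholder, and how the scale, zero-point, and shape slots are bound is not shown in this diff, so they stay null here.

// Hypothetical rt_data assembly keyed by the io enum above.
#include <vector>
#include "dyn_quantize_mha_types.hpp"  // assumed include path

using io = jd::ssd::dyn_quantize_mha_io::io;

std::vector<const void*> make_rt_data(const void* q, const void* k,
                                      const void* mask, const void* v,
                                      void* dst, void* tmp) {
  std::vector<const void*> rt(io::dyn_quantize_mha_io_MAX + 1, nullptr);
  rt[io::Q] = q;
  rt[io::K] = k;
  rt[io::MASK] = mask;
  rt[io::V] = v;
  rt[io::DST] = dst;
  rt[io::TMP] = tmp;  // sized via kernel_proxy::get_workspace_size()
  // Q_SCALE ... N are left null; their binding is not shown in this diff.
  return rt;
}
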
@@ -23,16 +23,18 @@

namespace jd {
namespace ssd {
/**
* @brief tensors index configuration of this kernel.
* TODO(Yi): potential confliction with indices of other op types
*/
static constexpr int SRC0 = 0;
static constexpr int SRC1 = 1;
static constexpr int DST0 = 2;
static constexpr int SRC2 = 3; // for binary add
static constexpr int SCALE0 = 4;
static constexpr int ZP0 = 5;

namespace matmul_io {
enum io {
SRC0,
SRC1,
DST0,
SRC2,
SCALE0,
ZP0,
matmul_io_MAX = ZP0,
};
} // namespace matmul_io

struct matmul_param_t {
dim_t M;
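
Moving the file-scope constants into a named enum namespace is what lets each op family scope its own slot indices; the removed TODO had flagged exactly this collision risk. A self-contained illustration with made-up enums:

// Two op families can order their slots differently without clashing at
// file scope once each scopes its own io enum. Names here are illustrative.
#include <cassert>
#include <vector>

namespace matmul_io { enum io { SRC0, SRC1, DST0, SRC2, SCALE0, ZP0 }; }
namespace other_io { enum io { DST0, SRC0 }; }  // a different slot order

int main() {
  std::vector<const void*> data(6, nullptr);
  int x = 0;
  data[matmul_io::DST0] = &x;               // slot 2 for matmul
  assert(data[other_io::DST0] == nullptr);  // slot 0 for the other op
  return 0;
}
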
@@ -31,7 +31,8 @@ enum class kernel_kind : uint8_t {
logsoftmax,
gather,
attention,
transpose_mha
transpose_mha,
dyn_quantize_mha,
};

enum class postop_alg : uint8_t { undef, exp, tanh, gelu, relu, quantize, dequantize, linear, eltop_int_lut };
@@ -82,6 +83,7 @@ enum class format_type : uint8_t {
ab, // shape permutation = {0, 1}
ba, // shape permutation = {1, 0}
abc,
abcd,

// encoding format of sparse matrix
uncoded,
@@ -29,6 +29,7 @@ DECLARE_IMPL_LIST(softmax);
DECLARE_IMPL_LIST(gather);
DECLARE_IMPL_LIST(attention);
DECLARE_IMPL_LIST(transpose_mha);
DECLARE_IMPL_LIST(dyn_quantize_mha);

#undef DECLARE_IMPL_LIST

@@ -48,6 +49,7 @@ const std::vector<impl_list_item_t>* cpu_engine::get_implementation_list(const o
CASE(softmax);
CASE(attention);
CASE(transpose_mha);
CASE(dyn_quantize_mha);
default:
return &cpu_engine::empty_list;
}
@@ -76,6 +76,8 @@ bool kernel_proxy::create_proxy_object(std::shared_ptr<const kernel_t>& result_r
return true;
}

size_t kernel_proxy::get_workspace_size() const { return get_sp()->get_workspace_size(); }

void kernel_proxy::execute(const std::vector<const void*>& rt_data) const {
bool status = false;
#ifdef SPARSE_LIB_USE_VTUNE
@@ -63,6 +63,8 @@ enum SubKernel {
} // namespace

namespace jd {
using matmul_io = ssd::matmul_io::io;

template <typename T_kd>
inline bool attention_kd_t::add_kernel_desc(const operator_desc& op_desc, const char* name) {
std::shared_ptr<const kernel_desc_t> kd;
@@ -283,7 +285,7 @@ void attention_k_t::setup_memory() {
const auto tensor_bytes = [](const jd::tensor_desc& d) { return d.size() * type2bytes[d.dtype()]; };

offset.push_back(tensor_bytes(ker_opdesc(SubKernel::QK_SPMM).tensor_descs()[ssd::DST]));
offset.push_back(tensor_bytes(ker_opdesc(SubKernel::Q_K_GEMM).tensor_descs()[ssd::DST0]));
offset.push_back(tensor_bytes(ker_opdesc(SubKernel::Q_K_GEMM).tensor_descs()[matmul_io::DST0]));
offset.push_back(tensor_bytes(ker_opdesc(SubKernel::SOFTMAX).tensor_descs()[1]));
offset.push_back(tensor_bytes(ker_opdesc(SubKernel::V_SPMM).tensor_descs()[ssd::DST]));
// the last kernel(QK(softmax) * V) don't need alloc memory
@@ -301,15 +303,16 @@

// part1 Q X K
mem_[SubKernel::Q_K_GEMM].resize(4);
mem_[SubKernel::Q_K_GEMM][ssd::SRC0] = mem_[SubKernel::QK_SPMM][ssd::DST];
mem_[SubKernel::Q_K_GEMM][ssd::SRC1] = mem_[SubKernel::QK_SPMM][ssd::DST] + offset[0] / 2; // split qk out to q and k
mem_[SubKernel::Q_K_GEMM][ssd::DST0] = mem_[SubKernel::QK_SPMM][ssd::DST] + offset[0]; // dst
mem_[SubKernel::Q_K_GEMM][ssd::SRC2] = nullptr;
mem_[SubKernel::Q_K_GEMM][matmul_io::SRC0] = mem_[SubKernel::QK_SPMM][ssd::DST];
mem_[SubKernel::Q_K_GEMM][matmul_io::SRC1] =
mem_[SubKernel::QK_SPMM][ssd::DST] + offset[0] / 2; // split qk out to q and k
mem_[SubKernel::Q_K_GEMM][matmul_io::DST0] = mem_[SubKernel::QK_SPMM][ssd::DST] + offset[0]; // dst
mem_[SubKernel::Q_K_GEMM][matmul_io::SRC2] = nullptr;

// part2 Softmax
mem_[SubKernel::SOFTMAX].resize(2);
mem_[SubKernel::SOFTMAX][0] = mem_[SubKernel::Q_K_GEMM][ssd::DST0];
mem_[SubKernel::SOFTMAX][1] = mem_[SubKernel::Q_K_GEMM][ssd::DST0] + offset[1];
mem_[SubKernel::SOFTMAX][0] = mem_[SubKernel::Q_K_GEMM][matmul_io::DST0];
mem_[SubKernel::SOFTMAX][1] = mem_[SubKernel::Q_K_GEMM][matmul_io::DST0] + offset[1];

// part5 spmm for V
mem_[SubKernel::V_SPMM].resize(ssd::SCALES + 1);
@@ -321,13 +324,13 @@
mem_[SubKernel::V_SPMM][ssd::DST] = mem_[SubKernel::SOFTMAX][1] + offset[2];

// part6 V X QK(softmax out)
mem_[SubKernel::QK_V_MATMUL].resize(ssd::ZP0 + 1);
mem_[SubKernel::QK_V_MATMUL][ssd::SRC0] = mem_[SubKernel::SOFTMAX][1];
mem_[SubKernel::QK_V_MATMUL][ssd::SRC1] = mem_[SubKernel::V_SPMM][ssd::DST];
mem_[SubKernel::QK_V_MATMUL][ssd::DST0] = nullptr;
mem_[SubKernel::QK_V_MATMUL][ssd::SRC2] = nullptr;
mem_[SubKernel::QK_V_MATMUL][ssd::SCALE0] = nullptr;
mem_[SubKernel::QK_V_MATMUL][ssd::ZP0] = nullptr;
mem_[SubKernel::QK_V_MATMUL].resize(matmul_io::ZP0 + 1);
mem_[SubKernel::QK_V_MATMUL][matmul_io::SRC0] = mem_[SubKernel::SOFTMAX][1];
mem_[SubKernel::QK_V_MATMUL][matmul_io::SRC1] = mem_[SubKernel::V_SPMM][ssd::DST];
mem_[SubKernel::QK_V_MATMUL][matmul_io::DST0] = nullptr;
mem_[SubKernel::QK_V_MATMUL][matmul_io::SRC2] = nullptr;
mem_[SubKernel::QK_V_MATMUL][matmul_io::SCALE0] = nullptr;
mem_[SubKernel::QK_V_MATMUL][matmul_io::ZP0] = nullptr;
}
bool attention_k_t::init() {
// Create kernel
@@ -357,12 +360,12 @@ std::vector<const void*> attention_k_t::set_input_output(int index, const std::v
// part0 QK spmm_vnni and part5 V spmm_vnni
data[ssd::SRC] = rt_data[attention_io::MERGE_SRC];
} else if (index == SubKernel::Q_K_GEMM) {
data[ssd::SRC2] = rt_data[attention_io::Q_K_SRC2];
data[matmul_io::SRC2] = rt_data[attention_io::Q_K_SRC2];
} else if (index == SubKernel::QK_V_MATMUL) {
// part4 transpose matmul for QK x V
data[ssd::DST0] = rt_data[attention_io::MERGE_DST];
data[ssd::SCALE0] = rt_data[attention_io::QK_V_OUTPUT_SCALES];
data[ssd::ZP0] = rt_data[attention_io::QK_V_OUTPUT_ZERO_POINT];
data[matmul_io::DST0] = rt_data[attention_io::MERGE_DST];
data[matmul_io::SCALE0] = rt_data[attention_io::QK_V_OUTPUT_SCALES];
data[matmul_io::ZP0] = rt_data[attention_io::QK_V_OUTPUT_ZERO_POINT];
}
return data;
}