This repository has been archived by the owner on Oct 25, 2024. It is now read-only.

Clean data #58

Merged
merged 149 commits into main from clean_data on Apr 4, 2023
Changes from 1 commit

149 commits
ed02951
Added feature and example of NAS (#516)
XinyuYe-Intel Jan 17, 2023
9812468
fix input_shape_recorder issue in SUT multi-ins (#527)
zhentaoyu Jan 30, 2023
e04b965
revision (#545)
XuhuiRen Jan 30, 2023
2ce0b0d
fix auto-distillation example failed (#550)
n1ck-guo Jan 30, 2023
e45ec64
Update pt version (#548)
VincyZhang Jan 31, 2023
cec31c0
Update docs (#544)
VincyZhang Jan 31, 2023
55b7cd5
Update pt version (#554)
VincyZhang Feb 1, 2023
12f77b6
add src1_perm in matmul python op (#537)
zhentaoyu Feb 2, 2023
224bcbf
Fix requirement (#556)
VincyZhang Feb 2, 2023
f039a9c
remove version in yaml (#557)
VincyZhang Feb 3, 2023
3b116dd
[Kernels] fix windows fail (#542)
airMeng Feb 3, 2023
6d11d72
turn on dynamic link by default (#539)
zhenwei-intel Feb 7, 2023
2a97052
[Kernels] Attention ref klocwork issues (#553)
airMeng Feb 7, 2023
9dafe11
add examples for gptj (#541)
XuhuiRen Feb 8, 2023
eef6487
improve to support C++ api document (#567)
NeoZhangJianyu Feb 8, 2023
28515ad
Added setfit notebook example. (#552)
XinyuYe-Intel Feb 9, 2023
6bbef46
Simplize readme (#564)
VincyZhang Feb 9, 2023
6cae222
Added a nas example. (#563)
XinyuYe-Intel Feb 10, 2023
ea86040
Udpate onnx version (#572)
VincyZhang Feb 10, 2023
de08b4b
update main page and example (#577)
VincyZhang Feb 12, 2023
cc5f325
Update jit_seq_cpy_2x8x8.hpp (#576)
NeoZhangJianyu Feb 12, 2023
95c6e4c
Enable lat_int8 (#565)
intellinjun Feb 13, 2023
70875d6
add docstring in tf_extractor and tf_utils (#569)
zhentaoyu Feb 13, 2023
a780c6f
docstring (#566)
Zhenzhong1 Feb 13, 2023
accfd87
update readme (#568)
zhenwei-intel Feb 13, 2023
401bd77
Refined example documents (#562)
XinyuYe-Intel Feb 13, 2023
5b8e386
add docstring to optimization (#573)
violetch24 Feb 13, 2023
e60b00e
Refactor TF quantization/pruning/distillation examples document (#571)
Spycsh Feb 14, 2023
3417a6a
Update readme (#581)
VincyZhang Feb 14, 2023
2440ca8
[Kernels] Trans MHA merge lnormalized spmm (#558)
zhewang1-intc Feb 15, 2023
6281d8d
sync external repo (#590)
VincyZhang Feb 15, 2023
084da57
Document fix (#591)
VincyZhang Feb 15, 2023
bfbb3c9
Add showcase bloom (#592)
VincyZhang Feb 15, 2023
c2f9ec6
[Kernels] visualize sparsity script (#454)
yuchengliu1 Feb 15, 2023
366e1ee
Enhance compile op registering (#584)
zhentaoyu Feb 15, 2023
181fbb3
Update distillation examples (#595)
VincyZhang Feb 16, 2023
4352992
add base and large bert example to pruner (#560)
n1ck-guo Feb 16, 2023
720d36d
[Engine]: add squeeze op and binary ops (#456)
zhenwei-intel Feb 17, 2023
478abd2
add docstring and update README (#579)
changwangss Feb 17, 2023
097f1c7
docstring (#599)
XuhuiRen Feb 17, 2023
a7671a0
[Kernels] fix improper-null-terminator and MHA cpplint (#594)
sunjiweiswift Feb 17, 2023
66cb3c9
[Neural Engine] Add the code to support tiny vit HF model (#561)
a32543254 Feb 17, 2023
047dd24
Fixed type error for PyTorch Pruning examples (#603)
PenghuiCheng Feb 17, 2023
f79b14d
revise md file for examples (#601)
XuhuiRen Feb 20, 2023
328455a
Changed to quantize SetFit model with INC (#606)
XinyuYe-Intel Feb 21, 2023
e144884
Wangwenqi add op (#596)
CeciliaWwq Feb 21, 2023
b8fbf9e
Zhenzhong/op attr (#604)
Zhenzhong1 Feb 21, 2023
3d39678
BinaryOP->BinaryOp frontend (#613)
Zhenzhong1 Feb 21, 2023
59491be
fix link color (#536)
NeoZhangJianyu Feb 21, 2023
aa5c714
fix of JIRA-391: windows build issue (#588)
luoyu-intel Feb 21, 2023
de05ee8
example README docs refine (#574)
violetch24 Feb 22, 2023
07b95bb
add gpt-neox example (#540)
violetch24 Feb 22, 2023
752942a
Update requiremend docs (#618)
VincyZhang Feb 23, 2023
82b3343
added multi-nodes QAT support for Question Answering and Text Classif…
XinyuYe-Intel Feb 23, 2023
50709bd
Pick back public repo (#622)
VincyZhang Feb 23, 2023
9eaf30a
Guoheng/fix bug 432 (#587)
n1ck-guo Feb 24, 2023
71441fb
[Kernels] bugfix benchmark spmm (#611)
sunjiweiswift Feb 25, 2023
8499771
Remove redundant code (#616)
VincyZhang Feb 27, 2023
bc81ecb
add image classification example (#225)
lkk12014402 Feb 27, 2023
882ee35
[Kernels] Refine headers for library compatibility and documents (#605)
airMeng Feb 28, 2023
41c4281
[Kernels] Reference impl and UT for Dense MHA with dynamic quantizati…
yi1ding Feb 28, 2023
a01bb90
update main page (#651)
VincyZhang Mar 1, 2023
1bf6741
fix klocwork issues (#649)
zhenwei-intel Mar 1, 2023
084ae49
Fix sparse bert mini example (#647)
a32543254 Mar 2, 2023
df6e369
fix for pruning import (#653)
violetch24 Mar 2, 2023
14e6a34
update README (#655)
violetch24 Mar 2, 2023
24f247c
Support gather with pytorch interface (#607)
yuchengliu1 Mar 2, 2023
d16e2ab
remove onnxruntime-extension (#660)
VincyZhang Mar 3, 2023
5c5a5ac
add longformer pruning codes (#585)
lkk12014402 Mar 4, 2023
827dcaa
[Kernels] fix translnorm benchmark fail (#643)
zhewang1-intc Mar 6, 2023
34d8b90
Opennmt fp32 (#598)
zhentaoyu Mar 6, 2023
0da5eb0
update inc build from source (#671)
VincyZhang Mar 6, 2023
7c3da8b
[Kernels] Static Q10N MHA support for GPT-J (#657)
yi1ding Mar 7, 2023
154aabb
Add the DLSA E2E solution to the ITREX (#632)
LifengWang Mar 7, 2023
7c43cd4
[Kernels] fix kernels format (#673)
airMeng Mar 8, 2023
2325063
fix empty_ops (#676)
zhentaoyu Mar 9, 2023
a4aa7f0
remove invalid code (#677)
PenghuiCheng Mar 9, 2023
b5c54de
fix for int8 flag (#684)
violetch24 Mar 10, 2023
5a60fc9
[Kernels] kernel code generator for gpu (#610)
VincyZhang Mar 10, 2023
f2390e3
stable diffusion enabling, including text encoder / vae decoder / une…
Zhenzhong1 Mar 10, 2023
da4d9cd
update pytorch pruner to v2.0 (#624)
n1ck-guo Mar 10, 2023
edc9090
Support smooth quantization and enable bloom model example (#675)
PenghuiCheng Mar 13, 2023
46fa399
Dynamic quantization in executor (#593)
yuchengliu1 Mar 14, 2023
a089147
fix pytest (#699)
yuchengliu1 Mar 14, 2023
bc38e86
parse torchscript model and build new graph (#687)
zhenwei-intel Mar 14, 2023
e135aa0
design a new benchmark API (#656)
xin3he Mar 15, 2023
35aba4e
logsoftmax modified solved conflicts (#682)
CeciliaWwq Mar 15, 2023
6e38b5d
avoid aggr init list & some windows warnings (#697)
yi1ding Mar 15, 2023
a10232d
fix bf16 (#701)
a32543254 Mar 15, 2023
edc855c
removed unspport recipe (#692)
PenghuiCheng Mar 15, 2023
8805289
add longformer (#669)
violetch24 Mar 15, 2023
684b6c6
fix for new benchmark (#706)
violetch24 Mar 15, 2023
2d0fec0
return torch model instead of inc model (#695)
xin3he Mar 15, 2023
56cf71a
stable diffusion bf16 enabling and example initialize (#691)
Zhenzhong1 Mar 16, 2023
6a2f259
[Kernels] Dynamic quant matmul for stable diffusion (#686)
zhewang1-intc Mar 16, 2023
acb693c
add devcatalog (#666)
VincyZhang Mar 16, 2023
eb21d2e
ut optimize (#731)
Zhenzhong1 Mar 20, 2023
60b0a00
[Engine]: Support int8 torch model per-tensor and per-channel (#703)
zhenwei-intel Mar 21, 2023
a8b9e8b
fix for new benchmark API (#729)
violetch24 Mar 21, 2023
525490f
recover tf examples (#723)
Spycsh Mar 21, 2023
5d61e9a
close yaml and bin file (#716)
zhentaoyu Mar 21, 2023
b49c161
[Kernels] dynamic quant mop up (#715)
zhewang1-intc Mar 21, 2023
5387bc0
readd engine related ut (#483)
zhentaoyu Mar 22, 2023
547323d
[Engine]fix lat (#704)
a32543254 Mar 22, 2023
1067b41
fix quant node and pattern order (#734)
zhenwei-intel Mar 22, 2023
c593438
add example for text generation (#664)
XuhuiRen Mar 22, 2023
44570e6
Added a textual inversion distillation for quantization example. (#586)
XinyuYe-Intel Mar 23, 2023
eaa14ce
Stable diffusion example optimize (#741)
Zhenzhong1 Mar 24, 2023
828e5d2
Stable Diffusion README and UT optimize. (#747)
Zhenzhong1 Mar 24, 2023
1ce5cf7
add flan-t5 for summarization (#733)
changwangss Mar 24, 2023
9e818cc
Bert dq examples (#742)
yuchengliu1 Mar 26, 2023
fbb4a6c
skip weight sharing ut (#751)
VincyZhang Mar 27, 2023
4640a05
fix compile fail (#752)
zhewang1-intc Mar 27, 2023
6b9c40d
Refine examples (#690)
VincyZhang Mar 27, 2023
178e85e
Fix example readme (#757)
VincyZhang Mar 28, 2023
9dbf282
add save_model API (#735)
xin3he Mar 28, 2023
93128ce
fix doc typo (#759)
zhentaoyu Mar 28, 2023
f1b41de
Patterns for GPT-J (#743)
zhenwei-intel Mar 28, 2023
e107817
Improve online document with source link (#753)
NeoZhangJianyu Mar 28, 2023
8eca747
Enhancement document of data augmentation (#661)
PenghuiCheng Mar 28, 2023
da67747
fix windows UT (#761)
yuchengliu1 Mar 28, 2023
02ed5b7
klockworks issues (#756)
airMeng Mar 29, 2023
a527766
Support smooth quant args with 'auto' and impove the docstring for co…
PenghuiCheng Mar 29, 2023
3645965
Fixed benchmark error since neural_compressor changed API name (#770)
PenghuiCheng Mar 30, 2023
ac2cbdd
[GPT-J] cherry-pick patterns and ops (#760)
zhenwei-intel Mar 30, 2023
ead0a15
build32 klocwork
airMeng Mar 31, 2023
222bfb8
leave TODO for dynamic_quant_matmul_ref
airMeng Mar 31, 2023
10e5a39
Revert "build32 klocwork"
airMeng Mar 31, 2023
ae80fb6
Revert "leave TODO for dynamic_quant_matmul_ref"
airMeng Mar 31, 2023
1cfdebf
fix bug with Escape characters issues by shlex quote (#766)
CeciliaWwq Mar 31, 2023
968064f
add gpt int8 test (#782)
zhenwei-intel Mar 31, 2023
522c0ec
Use DLOG to improve release efficiency (#781)
sunjiweiswift Mar 31, 2023
8439e5a
update example (#778)
yuchengliu1 Apr 1, 2023
de30df4
fix release (#787)
a32543254 Apr 1, 2023
6c7c6f6
fix windows pytest (#790)
a32543254 Apr 3, 2023
4b2e666
fix engine integration doc (#795)
zhentaoyu Apr 3, 2023
1fe646b
fix the shlex.quote issue (#794)
Zhenzhong1 Apr 3, 2023
20c75db
Fixed typo for smooth_quant example (#792)
PenghuiCheng Apr 3, 2023
96c1b95
add the vit example (#797)
a32543254 Apr 3, 2023
dabf27f
Added example for finetuning chatbot. (#763)
XinyuYe-Intel Apr 3, 2023
cb355f0
add workaround (#800)
Spycsh Apr 4, 2023
5ad2ca7
update example requiremnet (#789)
VincyZhang Apr 4, 2023
47e33b6
Build 31 klockwork (#777)
airMeng Apr 4, 2023
1c4eee4
fix pytorch examples (#796)
violetch24 Apr 4, 2023
9676fb9
Remove LLaMA for legal issue. (#803)
XinyuYe-Intel Apr 4, 2023
64ab07b
update version to 1.0 (#805)
VincyZhang Apr 4, 2023
b0e6088
Data cleaning for Intel domain dataset (#807)
XuhuiRen Apr 4, 2023
035e4a9
update readme
VincyZhang Apr 4, 2023
32c3d2e
Merge branch 'main' into clean_data
VincyZhang Apr 4, 2023
[Engine]fix lat (#704)
a32543254 authored Mar 22, 2023
commit 547323d07b0516bcdb17631313c29f811ac0b55a
@@ -92,21 +92,48 @@ def __call__(self, model):
},
'returns': [13, 14]
},
+            # minilmv2-lat-roberta int8
+            {
+                'patterns': {
+                    'in': [[(0, 'Shape'), (1, 'Gather'), (2, 'Unsqueeze'), (5, 'Concat'),
+                            (6, 'Reshape'), (7, 'Equal'), (8, 'Where'), (11, 'Expand')],
+                           [(0, 'Shape'), (3, 'Gather'), (4, 'Unsqueeze'), (5, 'Concat')],
+                           [(), (9, 'Unsqueeze'), (10, 'Unsqueeze'), (11, 'Expand')]],
+                    'out': [[(0, 'ExpandIndices')]]
+                },
+                'search_mode': 'op_type',
+                'node_names': {
+                    0: 11,
+                },
+                'input_tensors': {
+                    0: [[{
+                        9: [0]
+                    }, {
+                        0: [0]
+                    }], [[0, 1], 2]],
+                },
+                'output_tensors': {
+                    0: [[{
+                        11: [0]
+                    }], [[0], 1]],
+                },
+                'returns': [9, 10]
+            },
]
}

# minilmv2-lat-roberta
for idx, pattern_dict in enumerate(pattern_mapping_config['AttentionMaskLengthAdaptiveExpandIndices']):
-            model, new_node_names, ret_old_nodes = \
+            model, new_node_names, ret_old_nodes = \
util.pattern_mapping('AttentionMaskLengthAdaptiveExpandIndices', pattern_dict, model)
-            if len(new_node_names) != 0:
-                for i in range(len(new_node_names)):
-                    attr = OrderedDict()
-                    input_indices = []
-                    for unsqueeze_node in ret_old_nodes[i]:
-                        input_indices.append(int(unsqueeze_node.attr['axis']))
-                    attr['position'] = util.list2str(input_indices)
-                    keep_indices_node_idx = model.get_node_id(new_node_names[i][0])
-                    model.nodes[keep_indices_node_idx].attr = attr
+            if len(new_node_names) != 0:
+                for i in range(len(new_node_names)):
+                    attr = OrderedDict()
+                    input_indices = []
+                    for unsqueeze_node in ret_old_nodes[i]:
+                        input_indices.append(int(unsqueeze_node.attr['axes']))
+                    attr['position'] = util.list2str(input_indices)
+                    keep_indices_node_idx = model.get_node_id(new_node_names[i][0])
+                    model.nodes[keep_indices_node_idx].attr = attr

return model
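For readers unfamiliar with the engine's pattern-mapping notation, here is a minimal sketch condensing the loop above into a helper. The name `fuse_expand_indices` is hypothetical; `util.pattern_mapping`, `util.list2str`, and the model accessors are assumed to behave exactly as in the diff: each match replaces the matched Shape→…→Expand subgraph with one ExpandIndices node whose `position` attribute is taken from the matched Unsqueeze nodes' ONNX `axes` attributes.

```python
from collections import OrderedDict

def fuse_expand_indices(model, util, pattern_dict):
    # Replace every occurrence of the 'in' subgraph with the 'out' op.
    model, new_node_names, ret_old_nodes = util.pattern_mapping(
        'AttentionMaskLengthAdaptiveExpandIndices', pattern_dict, model)
    for i in range(len(new_node_names)):
        # ret_old_nodes[i] holds the nodes listed in 'returns' (Unsqueeze ops);
        # their 'axes' attributes become the fused op's 'position' attribute.
        position = [int(n.attr['axes']) for n in ret_old_nodes[i]]
        attr = OrderedDict({'position': util.list2str(position)})
        model.nodes[model.get_node_id(new_node_names[i][0])].attr = attr
    return model
```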
@@ -159,6 +159,33 @@ def __call__(self, model):
},
'returns': [5,6]
},

+            # int8 lat
+            {
+                'patterns': {
+                    'in': [[(0, 'Shape'), (1, 'Gather'), (2, 'Unsqueeze'), (3, 'Concat'),
+                            (4, 'Reshape'), (5, 'Equal'), (6, 'Where'), (8, 'Expand')],
+                           [(), (7, 'Unsqueeze'), (8, 'Expand')]],
+                    'out': [[(0, 'ExpandIndices')]]
+                },
+                'search_mode': 'op_type',
+                'node_names': {
+                    0: 7,
+                },
+                'input_tensors': {
+                    0: [[{
+                        7: [0]
+                    }, {
+                        0: [0]
+                    }], [[0, 1], 2]],
+                },
+                'output_tensors': {
+                    0: [[{
+                        8: [0]
+                    }], [[0], 1]],
+                },
+                'returns': [7]
+            },
]
}

@@ -179,7 +206,7 @@ def __call__(self, model):
axis_gather = []
for ret_old_node in ret_old_nodes[i]:
if ret_old_node.op_type == 'Unsqueeze':
-                    input_indices.append(int(ret_old_node.attr['axis']))
+                    input_indices.append(int(ret_old_node.attr['axes']))
elif ret_old_node.op_type == 'GatherElements':
axis_gather.append(int(ret_old_node.attr['axis']))
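The recurring one-line change in this commit (here and in the other pattern files) swaps `attr['axis']` for `attr['axes']` when reading Unsqueeze nodes, matching the ONNX attribute name (`axes` in opsets before 13); note that the GatherElements branch above correctly keeps `axis`, which is that op's ONNX attribute name. A hypothetical defensive accessor tolerating either spelling might look like this sketch:

```python
def unsqueeze_axes(node):
    """Return the unsqueeze axis as an int, preferring the ONNX 'axes' key."""
    for key in ('axes', 'axis'):
        if key in node.attr:
            return int(node.attr[key])
    raise KeyError("node has neither 'axes' nor 'axis' attribute")
```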

@@ -57,8 +57,34 @@ def __call__(self, model):
}], [[0], 1]],
},
'returns': [11]
-            }]
+            },
+            {
+                'patterns': {
+                    'in': [[(0, "Shape"), (1, 'Gather'), (2, "Unsqueeze"), (3, "Concat"),
+                            (7, "Tile")],
+                           [(0, "Shape"), (4, 'Gather'), (5, 'Range'),
+                            (6, "Unsqueeze"), (7, "Tile")]],
+                    'out': [[(0, 'Range')]]
+                },
+                'search_mode': 'op_type',
+                'node_names': {
+                    0: 5
+                },
+                'input_tensors': {
+                    0: [[{
+                        'input_data': [0]
+                    }], [[0], 1]],
+                },
+                'output_tensors': {
+                    0: [[{
+                        7: [0]
+                    }], [[0], 1]],
+                },
+                'returns': [5, 0]
+            }
+        ]
        }
+        collect_node = []

for i in range(len(pattern_mapping_config['GenerateSequence'])):
pattern_dict = pattern_mapping_config['GenerateSequence'][i]
@@ -72,6 +98,10 @@ def __call__(self, model):
attr["step"] = int(old_node.input_tensors[2].data)
new_node_idx = model.get_node_id(new_node_names[j][0])
model.nodes[new_node_idx].attr = attr

+                    if i == 1:
+                        collect_node.append(ret_old_nodes[j][1])
+        model.insert_nodes(10, collect_node)
+        return model

return model
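The new `collect_node` logic keeps the second returned node of each match (for this pattern, the original Shape op, per `'returns': [5, 0]`) and splices it back into the graph. A minimal sketch of that idiom, with a hypothetical helper name; `model.insert_nodes` is assumed to take a position and a node list as above, and position 10 is model-specific, copied from the diff rather than a general rule:

```python
def recover_returned_nodes(model, ret_old_nodes, keep_index=1, insert_at=10):
    # Collect the kept node from each pattern match, then splice them
    # back into the graph at a fixed position.
    collected = [old_nodes[keep_index] for old_nodes in ret_old_nodes]
    if collected:
        model.insert_nodes(insert_at, collected)
    return model
```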
@@ -57,12 +57,22 @@ def search_quant_fusion(node):
quant_node, can_fuse = search_quant_fusion(node)
if can_fuse:
if dtype == 'u8' or dtype == 's8':

if quant_node.op_type == "Softmax":
-                        model.change_node_input_tensors(quant_node.name, 1, node.input_tensors[1],
-                                                        'insert')
-                        model.change_node_input_tensors(quant_node.name, 2, node.input_tensors[2],
-                                                        'insert')
-                        quant_node.attr['output_dtype'] = "u8"
+                        def is_lat_model(model, p=None):
+                            if p == None:
+                                p = [[(0, 'TopK'),(1, 'GatherElements')]]
+                            match_result = util.search_pattern(p, model)
+                            return len(match_result) != 0
+                        if is_lat_model(model):
+                            node.attr = OrderedDict({'output_dtype': "u8"})
+                            continue
+                        else:
+                            model.change_node_input_tensors(quant_node.name, 1, node.input_tensors[1],
+                                                            'insert')
+                            model.change_node_input_tensors(quant_node.name, 2, node.input_tensors[2],
+                                                            'insert')
+                            quant_node.attr['output_dtype'] = "u8"
else:
model.change_node_input_tensors(quant_node.name, -2, node.input_tensors[1],
'modify')
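The inline `is_lat_model` check chooses between two Softmax quantization paths: graphs recognized by a TopK→GatherElements chain (LAT, presumably length-adaptive transformer, matching the "minilmv2-lat-roberta" patterns elsewhere in this commit) get `output_dtype = "u8"` set directly, while other graphs get extra input tensors (apparently the quantize op's min/max) inserted. A standalone sketch of the detection, with the pattern notation and `util.search_pattern` assumed to behave as in the diff:

```python
def is_lat_model(model, util):
    # A LAT graph prunes tokens with TopK and restores them with
    # GatherElements, so this two-op chain serves as a fingerprint.
    pattern = [[(0, 'TopK'), (1, 'GatherElements')]]
    return len(util.search_pattern(pattern, model)) != 0
```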
@@ -84,6 +84,10 @@ def _set_attr(se_attr, mat_attr, hidden_size, node_names, model):
hidden_size = int(ret_old_nodes[i][1].input_tensors[1].shape[0])
se_attr = ret_old_nodes[i][0].attr
mat_attr = ret_old_nodes[i][1].attr
+            mat_node = model.get_node_by_name(new_node_names[i][2])
+            reshape_node = model.get_node_by_name(new_node_names[i][1])
+            mat_node.input_tensors[0].name = ret_old_nodes[i][1].input_tensors[0].name
+            reshape_node.output_tensors[0].name = mat_node.input_tensors[0].name
_set_attr(se_attr, mat_attr, hidden_size, new_node_names[i], model)

return model
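The four added lines rewire tensor names so the new Reshape output feeds the new Matmul input under the original tensor's name. A commented sketch of the same wiring; the helper name is hypothetical and the node-index choices `[i][1]`/`[i][2]` are copied from the diff:

```python
def wire_reshape_into_matmul(model, new_node_names, ret_old_nodes, i):
    mat_node = model.get_node_by_name(new_node_names[i][2])
    reshape_node = model.get_node_by_name(new_node_names[i][1])
    # Reuse the original matmul's first input-tensor name ...
    mat_node.input_tensors[0].name = ret_old_nodes[i][1].input_tensors[0].name
    # ... and point the reshape output at that same name, so the
    # two new nodes connect through a single shared tensor.
    reshape_node.output_tensors[0].name = mat_node.input_tensors[0].name
```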
@@ -76,6 +76,12 @@ def _set_attr(ln_attr, se_attr, hidden_size, node_names, model):
model.nodes[scatter_elements_node_idx].attr = se_attr

# minilmv2-lat-roberta
+        layer_norm_idx = []
+        remove_list = []
+        pattern = pattern_mapping_config['ReshapeBeforeRestoreHiddenStates'][0]['patterns']['in']
+        patterns_nodes_name = util.search_pattern(pattern, model)
+        for pattern_nodes_name in patterns_nodes_name:
+            layer_norm_idx.append(model.get_node_id(pattern_nodes_name[0]))
pattern_dict = pattern_mapping_config['ReshapeBeforeRestoreHiddenStates'][0]
model, new_node_names, ret_old_nodes = util.pattern_mapping(
'ReshapeBeforeRestoreHiddenStates', pattern_dict, model)
@@ -85,7 +91,15 @@ def _set_attr(ln_attr, se_attr, hidden_size, node_names, model):
ln_attr = ret_old_nodes[i][0].attr
se_attr = ret_old_nodes[i][1].attr
_set_attr(ln_attr, se_attr, hidden_size, new_node_names[i], model)
+                import copy
+                ln_node = copy.deepcopy(model.get_node_by_name(new_node_names[i][0]))
+                model.remove_nodes([new_node_names[i][0]])
+                model.insert_nodes(layer_norm_idx[i] + i, [ln_node])
+
+                remove_list.append(new_node_names[i][0])
+
+
+        # model.remove_nodes(remove_list)
+        return model

return model
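The added block records where each matched pattern started (`layer_norm_idx`, captured before `pattern_mapping` rewrites the graph), then moves the fused node back to that position. A sketch of the relocation step with a hypothetical helper name; the `+ i` offset presumably compensates for the nodes re-inserted on earlier loop iterations, which shift later recorded indices by one each:

```python
import copy

def relocate_node(model, name, recorded_idx, iteration):
    # Copy, remove, then re-insert the node at its originally recorded
    # position, adjusted for insertions made by earlier iterations.
    node_copy = copy.deepcopy(model.get_node_by_name(name))
    model.remove_nodes([name])
    model.insert_nodes(recorded_idx + iteration, [node_copy])
```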
@@ -76,6 +76,44 @@ def __call__(self, model):
},
'returns': [10, 12]
},
+            {
+                'patterns': {
+                    'in': [[(0, 'Shape'), (1, 'Gather'), (2, 'Unsqueeze'), (3, 'Concat'),
+                            (4, 'Reshape'), (5, 'Equal'), (6, 'Where'), (8, 'Expand'),
+                            (9, 'ScatterElements')],
+                           [(), (7, 'Unsqueeze'), (8, 'Expand')]],
+                    'out': [[(0, 'Reshape'), (1, 'ExpandIndices'), (2, 'ScatterElements')]]
+                },
+                'search_mode': 'op_type',
+                'node_names': {
+                    0: 'reshape_to_3d_before_restoration',
+                    1: 8,
+                    2: 9,
+                },
+                'input_tensors': {
+                    0: [[{
+                        0: [0]
+                    }, {
+                        'input_data': [0]
+                    }], [[0, 1], 2]],
+                    1: [[{
+                        7: [0]
+                    }], [[0], 2]],
+                    2: [[{
+                        9: [0],
+                    }, {
+                        9: [2]
+                    }], [[0, 2], 3]],
+                },
+                'output_tensors': {
+                    0: [[], [[], 1]],
+                    1: [[], [[], 1]],
+                    2: [[{
+                        9: [0]
+                    }], [[0], 1]],
+                },
+                'returns': [7, 9]
+            },
]
}

@@ -94,29 +132,29 @@ def _set_attr(input_indices, se_attr, node_names, model):
model.nodes[se_node_idx].attr = se_attr

# minilmv2-lat-roberta
-        pattern_dict = pattern_mapping_config['RestoreHiddenStatesInLengthAdaptiveUpdateIndices'][
-            0]
-        model, new_node_names, ret_old_nodes = util.pattern_mapping(
-            'RestoreHiddenStatesInLengthAdaptiveUpdateIndices', pattern_dict, model)
-        if len(new_node_names) != 0:
-            for i in range(len(new_node_names)):
-                attr = OrderedDict()
-                input_indices = []
-                unsqueeze_node = ret_old_nodes[i][0]
-                input_indices.append(int(unsqueeze_node.attr['axis']))
-                se_attr = ret_old_nodes[i][1].attr
-                _set_attr(input_indices, se_attr, new_node_names[i], model)
-                # the first scatter elements operation need the output of embedding layer norm
-                # but its output shape is [bsxseq_len, hidden_size]
-                # so the first scatter node need modify this tensor to 3d tensor
-                # whose shape is [bs, seq_len, hidden_size]
-                reshape_3d_node = model.get_node_by_name(new_node_names[i][0])
-                embedding_ln_out_tensor = copy.deepcopy(reshape_3d_node.output_tensors[0])
-                scatter_node = model.get_node_by_name(new_node_names[i][2])
-                # check if one input tensor is from embedding_layer_norm node
-                if scatter_node.input_tensors[0].name == reshape_3d_node.input_tensors[0].name:
-                    model.change_node_input_tensors(new_node_names[i][2], 0,
-                                                    tensor=embedding_ln_out_tensor, mode='modify')
-        return model
+        for i in range(len(pattern_mapping_config['RestoreHiddenStatesInLengthAdaptiveUpdateIndices'])):
+            pattern_dict = pattern_mapping_config['RestoreHiddenStatesInLengthAdaptiveUpdateIndices'][i]
+            model, new_node_names, ret_old_nodes = util.pattern_mapping(
+                'RestoreHiddenStatesInLengthAdaptiveUpdateIndices', pattern_dict, model)
+            if len(new_node_names) != 0:
+                for i in range(len(new_node_names)):
+                    attr = OrderedDict()
+                    input_indices = []
+                    unsqueeze_node = ret_old_nodes[i][0]
+                    input_indices.append(int(unsqueeze_node.attr['axes']))
+                    se_attr = ret_old_nodes[i][1].attr
+                    _set_attr(input_indices, se_attr, new_node_names[i], model)
+                    # the first scatter elements operation need the output of embedding layer norm
+                    # but its output shape is [bsxseq_len, hidden_size]
+                    # so the first scatter node need modify this tensor to 3d tensor
+                    # whose shape is [bs, seq_len, hidden_size]
+                    reshape_3d_node = model.get_node_by_name(new_node_names[i][0])
+                    embedding_ln_out_tensor = copy.deepcopy(reshape_3d_node.output_tensors[0])
+                    scatter_node = model.get_node_by_name(new_node_names[i][2])
+                    # check if one input tensor is from embedding_layer_norm node
+                    if scatter_node.input_tensors[0].name == reshape_3d_node.input_tensors[0].name:
+                        model.change_node_input_tensors(new_node_names[i][2], 0,
+                                                        tensor=embedding_ln_out_tensor, mode='modify')
+                return model

return model
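The in-code comment explains why the new pattern emits a Reshape before the first ScatterElements: the embedding LayerNorm output arrives flattened as [bs*seq_len, hidden_size]. A toy NumPy illustration of that shape fix, with shapes invented for the example:

```python
import numpy as np

bs, seq_len, hidden_size = 2, 4, 8
flat = np.zeros((bs * seq_len, hidden_size), dtype=np.float32)  # LayerNorm output
three_d = flat.reshape(bs, seq_len, hidden_size)  # what the inserted Reshape does
assert three_d.shape == (bs, seq_len, hidden_size)
```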
@@ -37,6 +37,8 @@ void GatherElementsOperator::Reshape(const vector<Tensor*>& input, const vector<
auto& input_data_dtype = data->dtype();
dst_tensor_ptr->set_shape(dst_shape_);
dst_tensor_ptr->set_dtype(input_data_dtype);
+  outer_ = 1;
+  inner_ = 1;
for (int i = 0; i < input[0]->shape().size(); i++) {
if (i < axis_) outer_ *= input[0]->shape()[i];
if (i > axis_) inner_ *= input[0]->shape()[i];
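The C++ change resets `outer_` and `inner_` at the top of `Reshape`; without it, the `*=` accumulation in the loop carries stale values across repeated calls on the same operator instance. A toy Python analogy of the bug and the fix:

```python
class GatherElementsToy:
    """Toy model of the operator's stride bookkeeping."""
    def __init__(self, axis):
        self.axis = axis
        self.outer_ = 1
        self.inner_ = 1

    def reshape(self, shape):
        self.outer_ = 1  # the fix: re-initialize before accumulating
        self.inner_ = 1
        for i, dim in enumerate(shape):
            if i < self.axis:
                self.outer_ *= dim
            if i > self.axis:
                self.inner_ *= dim

op = GatherElementsToy(axis=1)
op.reshape([2, 3, 4])
assert (op.outer_, op.inner_) == (2, 4)
op.reshape([2, 3, 4])  # without the reset, this second call would give (4, 16)
assert (op.outer_, op.inner_) == (2, 4)
```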
@@ -76,7 +76,7 @@ def test_attention_reshape_0(self):
Tensor(name='unsqueeze_input1', data=np.array(1), shape=[1])]
output_tensors = [Tensor(name='unsqueeze_output', source_op=['unsqueeze'], dest_op=['expand'])]
unsqueeze_node.construct('unsqueeze', 'Unsqueeze', input_tensors=input_tensors,
-                                output_tensors=output_tensors,attr=OrderedDict({'axis': '1'}))
+                                output_tensors=output_tensors,attr=OrderedDict({'axes': '1'}))

expand_node = OPERATORS['Expand']()
input_tensors = [Tensor(name='unsqueeze_output', source_op=['unsqueeze'], dest_op=['expand']),