use tensor.shape but not paddle.shape(tensor) (#8260)
* use tensor.shape but not paddle.shape(tensor)

* refine

* refine
wanghuancoder authored Apr 16, 2024
1 parent f658fa7 commit ee88c12
Showing 68 changed files with 292 additions and 326 deletions.
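The substitution is the same in every hunk that follows: an index into paddle.shape(tensor) becomes an index into the tensor.shape attribute. A minimal sketch of the difference (my illustration, not part of the commit; in dygraph, Tensor.shape is a plain Python list, while paddle.shape() returns a 1-D integer Tensor):

import paddle

x = paddle.rand([4, 16, 8])

# Old spelling removed by this commit: the shape is materialized as a Tensor,
# so indexing it yields another Tensor and may cost an extra op or a
# device-to-host copy when the value is needed as a Python int.
batch_from_op = paddle.shape(x)[0]

# New spelling used throughout the diff: .shape is a Python list in dygraph,
# so indexing it gives a plain int directly.
batch_from_attr = x.shape[0]

print(type(batch_from_op), type(batch_from_attr))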
4 changes: 2 additions & 2 deletions examples/language_model/moe/dygraph/modeling.py
@@ -748,8 +748,8 @@ def forward(self, input_ids, position_ids=None, attention_mask=None, use_cache=F
         if position_ids is None:
             past_length = 0
             if cache is not None:
-                past_length = paddle.shape(cache[0].k)[-2]
-            position_ids = paddle.arange(past_length, paddle.shape(input_ids)[-1] + past_length, dtype="int64")
+                past_length = cache[0].k.shape[-2]
+            position_ids = paddle.arange(past_length, input_ids.shape[-1] + past_length, dtype="int64")
             position_ids = position_ids.unsqueeze(0)
             # .expand_as(input_ids)
             position_ids = paddle.expand_as(position_ids, input_ids)
12 changes: 5 additions & 7 deletions examples/model_interpretation/task/senti/rnn/model.py
@@ -207,7 +207,7 @@ def forward(self, input, mask=None):
         # Shape: (batch_size, max_seq_len, hidden_size)
         h = paddle.add_n([forward_input, backward_input])
         # Shape: (batch_size, hidden_size, 1)
-        att_weight = self.att_weight.tile(repeat_times=(paddle.shape(h)[0], 1, 1))
+        att_weight = self.att_weight.tile(repeat_times=(h.shape[0], 1, 1))
         # Shape: (batch_size, max_seq_len, 1)
         att_score = paddle.bmm(paddle.tanh(h), att_weight)
         if mask is not None:
@@ -246,20 +246,18 @@ def forward(self, input, mask=None):
             Tensor is a bool tensor, whose each element identifies whether the input word id is pad token or not.
             Defaults to `None
         """
-        weight = self.input_weight.tile(
-            repeat_times=(paddle.shape(input)[0], 1, 1)
-        ) # tensor[batch, hidden_size, hidden_size]
-        bias = self.bias.tile(repeat_times=(paddle.shape(input)[0], 1, 1)) # tensor[batch, 1, hidden_size]
+        weight = self.input_weight.tile(repeat_times=(input.shape[0], 1, 1)) # tensor[batch, hidden_size, hidden_size]
+        bias = self.bias.tile(repeat_times=(input.shape[0], 1, 1)) # tensor[batch, 1, hidden_size]
         word_squish = paddle.bmm(input, weight) + bias # Shape: (batch_size, seq_len, hidden_size)
         att_context_vector = self.att_context_vector.tile(
-            repeat_times=(paddle.shape(input)[0], 1, 1)
+            repeat_times=(input.shape[0], 1, 1)
         ) # Shape: (batch_size, hidden_size, 1)
         att_score = paddle.bmm(word_squish, att_context_vector) # tensor[batch_size, seq_len, 1]
         if mask is not None:
             # mask, remove the effect of 'PAD'
             mask = paddle.cast(mask, dtype="float32")
             mask = mask.unsqueeze(axis=-1)
-            inf_tensor = paddle.full(shape=paddle.shape(mask), dtype="float32", fill_value=-INF)
+            inf_tensor = paddle.full(shape=mask.shape, dtype="float32", fill_value=-INF)
             att_score = paddle.multiply(att_score, mask) + paddle.multiply(inf_tensor, (1 - mask))
         att_weight = F.softmax(att_score, axis=1) # tensor[batch_size, seq_len, 1]
2 changes: 1 addition & 1 deletion examples/simultaneous_translation/stacl/demo/model_demo.py
@@ -34,7 +34,7 @@ def greedy_search(self, src_word, max_len=256, waitk=-1, caches=None, bos_id=Non
         So, it needsprevious state(caches) and last one of generated
         tokens id last time.
         """
-        src_max_len = paddle.shape(src_word)[-1]
+        src_max_len = src_word.shape[-1]
         base_attn_bias = (
             paddle.cast(src_word == self.bos_id, dtype=paddle.get_default_dtype()).unsqueeze([1, 2]) * -1e9
         )
10 changes: 5 additions & 5 deletions examples/simultaneous_translation/stacl/model.py
@@ -15,11 +15,11 @@
 from __future__ import print_function

 import numpy as np
-
 import paddle
 import paddle.nn as nn
 import paddle.nn.functional as F
-from paddlenlp.transformers import WordEmbedding, PositionalEmbedding
+
+from paddlenlp.transformers import PositionalEmbedding, WordEmbedding


 class CrossEntropyCriterion(nn.Layer):
@@ -190,8 +190,8 @@ def __init__(
         self.linear = nn.Linear(in_features=d_model, out_features=trg_vocab_size, bias_attr=False)

     def forward(self, src_word, trg_word):
-        src_max_len = paddle.shape(src_word)[-1]
-        trg_max_len = paddle.shape(trg_word)[-1]
+        src_max_len = src_word.shape[-1]
+        trg_max_len = trg_word.shape[-1]
         base_attn_bias = (
             paddle.cast(src_word == self.bos_id, dtype=paddle.get_default_dtype()).unsqueeze([1, 2]) * -1e9
         )
@@ -236,7 +236,7 @@ def beam_search(self, src_word, beam_size=4, max_len=256, waitk=-1):
         raise NotImplementedError

     def greedy_search(self, src_word, max_len=256, waitk=-1):
-        src_max_len = paddle.shape(src_word)[-1]
+        src_max_len = src_word.shape[-1]
         base_attn_bias = (
             paddle.cast(src_word == self.bos_id, dtype=paddle.get_default_dtype()).unsqueeze([1, 2]) * -1e9
         )
10 changes: 5 additions & 5 deletions examples/text_classification/rnn/model.py
@@ -253,7 +253,7 @@ def forward(self, input, mask=None):
         # Shape: (batch_size, max_seq_len, hidden_size)
         h = paddle.add_n([forward_input, backward_input])
         # Shape: (batch_size, hidden_size, 1)
-        att_weight = self.att_weight.tile(repeat_times=(paddle.shape(h)[0], 1, 1))
+        att_weight = self.att_weight.tile(repeat_times=(h.shape[0], 1, 1))
         # Shape: (batch_size, max_seq_len, 1)
         att_score = paddle.bmm(paddle.tanh(h), att_weight)
         if mask is not None:
@@ -292,19 +292,19 @@ def forward(self, input, mask=None):
             Tensor is a bool tensor, whose each element identifies whether the input word id is pad token or not.
             Defaults to `None
         """
-        weight = self.input_weight.tile(repeat_times=(paddle.shape(input)[0], 1, 1))
-        bias = self.bias.tile(repeat_times=(paddle.shape(input)[0], 1, 1))
+        weight = self.input_weight.tile(repeat_times=(input.shape[0], 1, 1))
+        bias = self.bias.tile(repeat_times=(input.shape[0], 1, 1))
         # Shape: (batch_size, max_seq_len, hidden_size)
         word_squish = paddle.bmm(input, weight) + bias

-        att_context_vector = self.att_context_vector.tile(repeat_times=(paddle.shape(input)[0], 1, 1))
+        att_context_vector = self.att_context_vector.tile(repeat_times=(input.shape[0], 1, 1))
         # Shape: (batch_size, max_seq_len, 1)
         att_score = paddle.bmm(word_squish, att_context_vector)
         if mask is not None:
             # mask, remove the effect of 'PAD'
             mask = paddle.cast(mask, dtype="float32")
             mask = mask.unsqueeze(axis=-1)
-            inf_tensor = paddle.full(shape=paddle.shape(mask), dtype="float32", fill_value=-INF)
+            inf_tensor = paddle.full(shape=mask.shape, dtype="float32", fill_value=-INF)
             att_score = paddle.multiply(att_score, mask) + paddle.multiply(inf_tensor, (1 - mask))
         att_weight = F.softmax(att_score, axis=1)
2 changes: 1 addition & 1 deletion examples/text_to_sql/RAT-SQL/text2sql/utils/nn_utils.py
@@ -74,7 +74,7 @@ def batch_gather_2d(var, indices):
             "shape of indices error. it should be a 2-D layers. " "but got shape = %s" % (str(indices.shape),)
         )

-    batch_size = paddle.shape(indices)[0]
+    batch_size = indices.shape[0]

     zero = paddle.to_tensor([0], dtype="int64")
     one = paddle.to_tensor([1], dtype="int64")
4 changes: 2 additions & 2 deletions llm/ernie-3.5-se/modeling.py
@@ -142,7 +142,7 @@ def scaled_dot_product_attention(
     query_states, key_states, value_states, attention_mask, output_attentions, config, is_causal=True
 ):

-    bsz, q_len, num_heads, _ = paddle.shape(query_states)
+    bsz, q_len, num_heads, _ = query_states.shape
     head_dim = config.hidden_size // config.num_attention_heads
     _, kv_seq_len, _, _ = value_states.shape

@@ -1054,7 +1054,7 @@ def forward(
         seq_length_with_past = seq_length
         cache_length = 0
         if past_key_values[0] is not None:
-            cache_length = paddle.shape(past_key_values[0][0])[1]
+            cache_length = past_key_values[0][0].shape[1]
             seq_length_with_past += cache_length
         if inputs_embeds is None:
             inputs_embeds = self.embed_tokens(input_ids).astype(self.embed_tokens.weight.dtype)
@@ -735,8 +735,8 @@ def forward(self, input_ids, position_ids=None, attention_mask=None, use_cache=F
         if position_ids is None:
             past_length = 0
             if cache is not None:
-                past_length = paddle.shape(attention_mask)[-1] - 1
-            position_ids = paddle.arange(past_length, paddle.shape(input_ids)[-1] + past_length, dtype=input_ids.dtype)
+                past_length = attention_mask.shape[-1] - 1
+            position_ids = paddle.arange(past_length, input_ids.shape[-1] + past_length, dtype=input_ids.dtype)
             position_ids = position_ids.unsqueeze(0)
             position_ids = paddle.expand_as(position_ids, input_ids)

@@ -753,7 +753,7 @@ def forward(self, input_ids, position_ids=None, attention_mask=None, use_cache=F
         if not self.fused_softmax_with_triangular or not paddle.is_compiled_with_cuda():
             # TODO, use registered buffer
             causal_mask = paddle.tensor.triu(
-                paddle.ones((paddle.shape(input_ids)[-1], paddle.shape(input_ids)[-1])) * -1e4, diagonal=1
+                paddle.ones((input_ids.shape[-1], input_ids.shape[-1])) * -1e4, diagonal=1
             )
             if attention_mask is not None:
                 if len(attention_mask.shape) == 2:
@@ -972,7 +972,7 @@ def get_logits_processor(

     def expand_inputs_for_generation(self, input_ids, expand_size, attention_mask=None, **model_kwargs):

-        index = paddle.tile(paddle.arange(paddle.shape(input_ids)[0]).unsqueeze(-1), [1, expand_size]).reshape([-1])
+        index = paddle.tile(paddle.arange(input_ids.shape[0]).unsqueeze(-1), [1, expand_size]).reshape([-1])

         input_ids = paddle.gather(input_ids, index)

@@ -1109,11 +1109,11 @@ def TopPProcess(probs, top_p, min_tokens_to_keep):
             probs = paddle.where(condition, paddle.full_like(probs, 0.0), probs)
             return probs

-        batch_size, cur_len = paddle.shape(input_ids)
+        batch_size, cur_len = input_ids.shape
         # used for compute on gpu, avoid memcpy D2H
         cur_len_gpu = paddle.full([1], cur_len, dtype="int64")

-        origin_len = paddle.shape(input_ids)[1]
+        origin_len = input_ids.shape[1]
         # used for compute on gpu, avoid memcpy D2H
         origin_len_gpu = paddle.full([1], origin_len, dtype="int64")

@@ -1167,7 +1167,7 @@ def _post_process_(outputs, input_ids, cur_len, origin_len, scores, unfinished_f
                     raise ImportError(
                         "please install ppfleetx_ops by 'cd ppfleetx/ops && python setup_cuda.py install'!"
                     )
-                top_ps_tensor = paddle.full(shape=[paddle.shape(probs)[0]], fill_value=top_p, dtype=probs.dtype)
+                top_ps_tensor = paddle.full(shape=[probs.shape[0]], fill_value=top_p, dtype=probs.dtype)
                 # TODO fake random seed here
                 # Users should set the random seed dynamically when inference
                 _, next_tokens = topp_sampling(probs, top_ps_tensor, random_seed=100)
@@ -1299,7 +1299,7 @@ def forward(self, input_ids=None, **model_kwargs):

         if model_kwargs.get("position_ids", None) is None:
             model_kwargs["position_ids"] = paddle.arange(
-                0, paddle.shape(model_kwargs["attention_mask"])[-1], dtype=input_ids.dtype
+                0, model_kwargs["attention_mask"].shape[-1], dtype=input_ids.dtype
             ).unsqueeze(0)

         self.is_encoder_decoder = False
@@ -837,8 +837,8 @@ def forward(self, input_ids, position_ids=None, attention_mask=None, use_cache=F
         if position_ids is None:
             past_length = 0
             if cache is not None:
-                past_length = paddle.shape(attention_mask)[-1] - 1
-            position_ids = paddle.arange(past_length, paddle.shape(input_ids)[-1] + past_length, dtype=input_ids.dtype)
+                past_length = attention_mask.shape[-1] - 1
+            position_ids = paddle.arange(past_length, input_ids.shape[-1] + past_length, dtype=input_ids.dtype)
             position_ids = position_ids.unsqueeze(0)
             # .expand_as(input_ids)
             position_ids = paddle.expand_as(position_ids, input_ids)
@@ -851,7 +851,7 @@ def forward(self, input_ids, position_ids=None, attention_mask=None, use_cache=F
         if not self.fused_softmax_with_triangular or not paddle.is_compiled_with_cuda():
             # TODO, use registered buffer
             causal_mask = paddle.tensor.triu(
-                paddle.ones((paddle.shape(input_ids)[-1], paddle.shape(input_ids)[-1])) * -1e4, diagonal=1
+                paddle.ones((input_ids.shape[-1], input_ids.shape[-1])) * -1e4, diagonal=1
             )
             if attention_mask is not None:
                 if len(attention_mask.shape) == 2:
@@ -1304,7 +1304,7 @@ def get_logits_processor(

     def expand_inputs_for_generation(self, input_ids, expand_size, attention_mask=None, **model_kwargs):

-        index = paddle.tile(paddle.arange(paddle.shape(input_ids)[0]).unsqueeze(-1), [1, expand_size]).reshape([-1])
+        index = paddle.tile(paddle.arange(input_ids.shape[0]).unsqueeze(-1), [1, expand_size]).reshape([-1])

         input_ids = paddle.gather(input_ids, index)

@@ -602,8 +602,8 @@ def forward(self, input_ids, position_ids=None, attention_mask=None, use_cache=F
         if position_ids is None:
             past_length = 0
             if cache is not None:
-                past_length = paddle.shape(attention_mask)[-1] - 1
-            position_ids = paddle.arange(past_length, paddle.shape(input_ids)[-1] + past_length, dtype=input_ids.dtype)
+                past_length = attention_mask.shape[-1] - 1
+            position_ids = paddle.arange(past_length, input_ids.shape[-1] + past_length, dtype=input_ids.dtype)
             position_ids = position_ids.unsqueeze(0)
             # .expand_as(input_ids)
             position_ids = paddle.expand_as(position_ids, input_ids)
@@ -615,7 +615,7 @@ def forward(self, input_ids, position_ids=None, attention_mask=None, use_cache=F
         if not self.fused_softmax_with_triangular or not paddle.is_compiled_with_cuda():
             # TODO, use registered buffer
             causal_mask = paddle.tensor.triu(
-                paddle.ones((paddle.shape(input_ids)[-1], paddle.shape(input_ids)[-1])) * -1e4, diagonal=1
+                paddle.ones((input_ids.shape[-1], input_ids.shape[-1])) * -1e4, diagonal=1
             )
             if attention_mask is not None:
                 if len(attention_mask.shape) == 2:
@@ -848,7 +848,7 @@ def get_logits_processor(

     def expand_inputs_for_generation(self, input_ids, expand_size, attention_mask=None, **model_kwargs):

-        index = paddle.tile(paddle.arange(paddle.shape(input_ids)[0]).unsqueeze(-1), [1, expand_size]).reshape([-1])
+        index = paddle.tile(paddle.arange(input_ids.shape[0]).unsqueeze(-1), [1, expand_size]).reshape([-1])

         input_ids = paddle.gather(input_ids, index)

@@ -1039,7 +1039,7 @@ def _post_process_(outputs, input_ids, cur_len, origin_len, scores, unfinished_f
                     raise ImportError(
                         "please install ppfleetx_ops by 'cd ppfleetx/ops && python setup_cuda.py install'!"
                     )
-                top_ps_tensor = paddle.full(shape=[paddle.shape(probs)[0]], fill_value=top_p, dtype=probs.dtype)
+                top_ps_tensor = paddle.full(shape=[probs.shape[0]], fill_value=top_p, dtype=probs.dtype)
                 _, next_tokens = topp_sampling(probs, top_ps_tensor, random_seed=100)
             else:
                 probs = TopPProcess(probs, top_p, min_tokens_to_keep)
@@ -1194,7 +1194,7 @@ def forward(self, input_ids=None, **model_kwargs):

         if model_kwargs.get("position_ids", None) is None:
             model_kwargs["position_ids"] = paddle.arange(
-                0, paddle.shape(model_kwargs["attention_mask"])[-1], dtype=input_ids.dtype
+                0, model_kwargs["attention_mask"].shape[-1], dtype=input_ids.dtype
             ).unsqueeze(0)

         self.is_encoder_decoder = False
14 changes: 7 additions & 7 deletions paddlenlp/generation/utils.py
@@ -412,9 +412,9 @@ def get_logits_processor(
     @staticmethod
     def expand_inputs_for_generation(input_ids, expand_size, attention_mask=None, **model_kwargs):

-        index = paddle.tile(
-            paddle.arange(paddle.shape(input_ids)[0], dtype="int64").unsqueeze(-1), [1, expand_size]
-        ).reshape([-1])
+        index = paddle.tile(paddle.arange(input_ids.shape[0], dtype="int64").unsqueeze(-1), [1, expand_size]).reshape(
+            [-1]
+        )

         input_ids = paddle.gather(input_ids, index)

@@ -1340,11 +1340,11 @@ def sample_d2s(
                 "you should not specify InputSpec for top_k and top_p parameters, one of InputSpec is expected"
             )

-        batch_size, cur_len = paddle.shape(input_ids)
+        batch_size, cur_len = input_ids.shape
         # used for compute on gpu, avoid memcpy D2H
         cur_len_gpu = paddle.full([1], cur_len, dtype="int64")

-        origin_len = paddle.shape(input_ids)[1]
+        origin_len = input_ids.shape[1]
         # used for compute on gpu, avoid memcpy D2H
         origin_len_gpu = paddle.full([1], origin_len, dtype="int64")

@@ -1384,7 +1384,7 @@ def _post_process_(outputs, input_ids, cur_len, origin_len, scores, unfinished_f
             # compute next_tokens
             if use_top_p:
                 logits = logits / temperature
-                top_ps_tensor = paddle.full(shape=[paddle.shape(probs)[0], 1], fill_value=top_p, dtype=probs.dtype)
+                top_ps_tensor = paddle.full(shape=[probs.shape[0], 1], fill_value=top_p, dtype=probs.dtype)
                 _, next_tokens = paddle.tensor.top_p_sampling(probs, top_ps_tensor)
             else:
                 probs = TopKProcess(probs, top_k, min_tokens_to_keep)
@@ -1428,7 +1428,7 @@ def _post_process_(outputs, input_ids, cur_len, origin_len, scores, unfinished_f

         attn_mask = model_kwargs["attention_mask"]
         # make the shape of attention_mask = (-1, -1, -1, -1) in dy2static.
-        model_kwargs["attention_mask"] = paddle.reshape(attn_mask, paddle.shape(attn_mask))
+        model_kwargs["attention_mask"] = paddle.reshape(attn_mask, attn_mask.shape)
         model_kwargs["cache"] = outputs[1] if isinstance(outputs, tuple) else None
         max_new_tokens = paddle.full([1], max_new_tokens + cur_len - 1, dtype="int64")

8 changes: 4 additions & 4 deletions paddlenlp/layers/crf.py
@@ -303,7 +303,7 @@ def __init__(self, transitions, with_start_stop_tag=True):
         if with_start_stop_tag:
             self.start_idx = -1
             self.stop_idx = -2
-        self.num_tags = paddle.shape(transitions)[0]
+        self.num_tags = transitions.shape[0]

         self._initial_alpha = None
         self._index = None
@@ -312,7 +312,7 @@ def __init__(self, transitions, with_start_stop_tag=True):

     def _initialize_alpha(self, batch_size):
         # alpha accumulate the path value to get the different next tag
-        if self._initial_alpha is None or batch_size > paddle.shape(self._initial_alpha)[0]:
+        if self._initial_alpha is None or batch_size > self._initial_alpha.shape[0]:
             # Initialized by a small value.
             initial_alpha = paddle.full([batch_size, self.num_tags - 1], dtype="float32", fill_value=-10000.0)
             # alpha_start fill_value = 0. > -10000., means the first one step START gets the most score.
@@ -336,7 +336,7 @@ def forward(self, inputs, lengths):
                 The `paths` tensor containing the highest scoring tag indices.
                 Its dtype is int64 and has a shape of `[batch_size, sequence_length]`.
         """
-        input_shape = paddle.shape(inputs)
+        input_shape = inputs.shape
         batch_size = input_shape[0]
         n_label = input_shape[2]

@@ -412,6 +412,6 @@ def forward(self, inputs, lengths):
         return scores, batch_path

     def _get_batch_index(self, batch_size):
-        if self._batch_index is None or batch_size != paddle.shape(self._batch_index)[0]:
+        if self._batch_index is None or batch_size != self._batch_index.shape[0]:
             self._batch_index = paddle.arange(end=batch_size, dtype="int64")
         return self._batch_index
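A note for reviewers on the dy2static comment in the paddlenlp/generation/utils.py hunk above: the two spellings agree in dygraph, but under static graph the .shape attribute is the compile-time shape and can hold -1 for dimensions fixed only at run time, whereas paddle.shape() is evaluated at run time. A small illustration (my example, not from the commit):

import paddle

paddle.enable_static()
x = paddle.static.data(name="x", shape=[None, 128], dtype="float32")

print(x.shape)          # (-1, 128): compile-time shape, -1 marks a runtime-only dimension
print(paddle.shape(x))  # a Variable whose value is the concrete shape at run time

paddle.disable_static()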