[model_zoo/gpt-3] Fix bugs from PR-61236 which cleared `paddle.jit.dy2static.utils_helper` (#7989)

* fix bugs

* add try-import fallback to support both develop and release
haohongxiang authored Feb 20, 2024
1 parent 4eb6f0a · commit 5c9c8d3
Showing 3 changed files with 18 additions and 3 deletions.
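All three files get the same fix: import `set_dynamic_shape` from its new home in `paddle.jit.api`, fall back to the old `paddle.jit.dy2static.utils_helper` location on older releases, and call the imported name directly instead of the removed fully-qualified path. Below is a minimal standalone sketch of that pattern; the 4-D attention-mask tensor is purely illustrative (not code from the repository), and it assumes, as the diff suggests, that calling `set_dynamic_shape` outside dynamic-to-static conversion is harmless.

import paddle

# Prefer the new import location; fall back for older Paddle releases
# where set_dynamic_shape still lives under dy2static.utils_helper.
try:
    from paddle.jit.api import set_dynamic_shape
except ImportError:
    from paddle.jit.dy2static.utils_helper import set_dynamic_shape

# Illustrative stand-in for model_kwargs["attention_mask"]; shape is arbitrary.
attention_mask = paddle.ones([1, 1, 8, 8], dtype="float32")

# Mark every dimension as dynamic (-1) so dynamic-to-static conversion does
# not freeze the mask to its trace-time shape during generation.
set_dynamic_shape(attention_mask, [-1, -1, -1, -1])

The repository code uses a bare `except:` for the same purpose; catching `ImportError` explicitly is shown here only because it is the more idiomatic form.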
File 1 of 3
@@ -49,6 +49,11 @@
 except:
     flash_attention = None

+try:
+    from paddle.jit.api import set_dynamic_shape
+except:
+    from paddle.jit.dy2static.utils_helper import set_dynamic_shape
+
 def shard_op_for_sequence_parallel_linear(tgt, mesh):
     # FIXME Hack to shard op for module (linear)
     # we only shard the second to the last op (matmul) leave the last op (elementwise_add) un-touched
@@ -1206,7 +1211,7 @@ def _post_process_(outputs, input_ids, cur_len, origin_len, scores, unfinished_f

 attn_mask = model_kwargs["attention_mask"]
 # make the shape of attention_mask = (-1, -1, -1, -1) in dy2static.
-paddle.jit.dy2static.utils_helper.set_dynamic_shape(model_kwargs["attention_mask"], [-1, -1, -1, -1])
+set_dynamic_shape(model_kwargs["attention_mask"], [-1, -1, -1, -1])
 model_kwargs["cache"] = outputs[1] if isinstance(outputs, tuple) else None
 max_length = paddle.to_tensor(max_length)
 while cur_len < max_length:
File 2 of 3
@@ -62,11 +62,17 @@
     from paddle.nn.functional.flash_attention import flash_attention
 except:
     flash_attention = None

 try:
     from paddle.incubate.nn.layer.fused_dropout_add import FusedDropoutAdd
 except:
     FusedDropoutAdd = None
+
+try:
+    from paddle.jit.api import set_dynamic_shape
+except:
+    from paddle.jit.dy2static.utils_helper import set_dynamic_shape
+
 def get_attr(layer, name):
     if getattr(layer, name, None) is not None:
         return getattr(layer, name, None)
@@ -1501,7 +1507,7 @@ def _post_process_(outputs, input_ids, cur_len, origin_len, scores, unfinished_f

 attn_mask = model_kwargs["attention_mask"]
 # make the shape of attention_mask = (-1, -1, -1, -1) in dy2static.
-paddle.jit.dy2static.utils_helper.set_dynamic_shape(model_kwargs["attention_mask"], [-1, -1, -1, -1])
+set_dynamic_shape(model_kwargs["attention_mask"], [-1, -1, -1, -1])
 model_kwargs["cache"] = outputs[1] if isinstance(outputs, tuple) else None
 while cur_len < max_length:
     # Note(GuoxiaWang): Remove outputs = _forward_(**model_kwargs)
File 3 of 3
@@ -43,6 +43,10 @@
 except:
     flash_attention = None

+try:
+    from paddle.jit.api import set_dynamic_shape
+except:
+    from paddle.jit.dy2static.utils_helper import set_dynamic_shape

 def get_attr(layer, name):
     if getattr(layer, name, None) is not None:
@@ -1077,7 +1081,7 @@ def _post_process_(outputs, input_ids, cur_len, origin_len, scores, unfinished_f

 attn_mask = model_kwargs["attention_mask"]
 # make the shape of attention_mask = (-1, -1, -1, -1) in dy2static.
-paddle.jit.dy2static.utils_helper.set_dynamic_shape(model_kwargs["attention_mask"], [-1, -1, -1, -1])
+set_dynamic_shape(model_kwargs["attention_mask"], [-1, -1, -1, -1])
 model_kwargs["cache"] = outputs[1] if isinstance(outputs, tuple) else None
 if hasattr(paddle.framework, "_no_check_dy2st_diff"):
     # TODO(wanghuancoder): _no_check_dy2st_diff is used to turn off the checking of behavior
