Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[API]Support static branch in paddle.to_tensor #45164

Merged
merged 48 commits into from
Aug 18, 2022

Conversation

feifei-111
Copy link
Contributor

PR types

Others

PR changes

Others

Describe

  1. to_tensor will call assign in static graph, so transformer is not needed
  2. to_tensor will set output a correct dtype when it calls assign in static mode
  3. to_tensor will skip assign op when it is not necessary
  4. to_variable will not be changed to to_tensor api , cause it was used to copy a variable in some cases (but to_tensor will not)

@paddle-bot
Copy link

paddle-bot bot commented Aug 15, 2022

你的PR提交成功,感谢你对开源项目的贡献!
请关注后续CI自动化测试结果,详情请参考Paddle-CI手册
Your PR has been submitted. Thanks for your contribution!
Please wait for the result of CI firstly. See Paddle CI Manual for details.

@feifei-111 feifei-111 changed the title [API]Support static branch in paddle.to_tensorPr 45088 [API]Support static branch in paddle.to_tensor Aug 16, 2022
Copy link
Contributor

@Aurelius84 Aurelius84 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

有几个遗留的TODO,可以再优化下。

place = _get_paddle_place(place)
if place is None:
place = _current_expected_place()
elif not isinstance(
Copy link
Contributor

@Aurelius84 Aurelius84 Aug 18, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个判断其实可以去掉,_get_paddle_place里已经有合法性检查了。可以复用_to_tensor_non_static

"'place' must be any of paddle.Place, paddle.CPUPlace, paddle.CUDAPinnedPlace, paddle.CUDAPlace, paddle.NPUPlace, paddle.XPUPlace, paddle.MLUPlace, paddle.CustomPlace"
)

import re
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

不要动态import, import语句放到文件最前面

def call_assign(data, dtype=None, stop_grandient=None):

if isinstance(data,
(Variable, core.VarBase)) and (dtype is None or dtype
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里只需要判断Variable就可以了吧,这个是静态图分支,如果出现core.VarBase应该是非法的吧?

if dtype:
target_dtype = convert_dtype(dtype)
elif hasattr(data, 'dtype'):
target_dtype = convert_dtype(data.dtype)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if hasattr(data, 'dtype'):
    dtype = data.dtype
elif dtype is None:
    dtype = paddle.get_default_dtype()

target_dtype = convert_dtype(dtype)

else:
place = paddle.CPUPlace()

x = paddle.to_tensor(paddle.randn([5, 2]),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

单测需要再丰富一下,你这里并没有测试[var, 1, 1]的情况

@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

# to_tensor api will create 1 less op now, this test was changed
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

不用加这一行

# call assign for static graph
else:

def call_assign(data, dtype=None, stop_grandient=None):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

_to_tensor_static() 放到外面

)

import re
re_exp = re.compile(r'[(](.*?)[)]', re.S)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
re_exp = re.compile(r'[(](.*?)[)]', re.S)
re_exp = re.compile(r'[(](.*+)[)]', re.S)

Copy link
Contributor

@Aurelius84 Aurelius84 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@Aurelius84 Aurelius84 merged commit 3012221 into PaddlePaddle:develop Aug 18, 2022
@feifei-111 feifei-111 deleted the pr_45088 branch December 15, 2023 04:19
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants