-
Notifications
You must be signed in to change notification settings - Fork 5.7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
【PIR API adaptor No.31】Migrate paddle.distribution.Normal
into pir
#59910
Conversation
你的PR提交成功,感谢你对开源项目的贡献! |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nice work ,希望下面的 comment 能帮助你解决现在碰到的问题
- 关于 paddle.distribution.Normal 的迁移,需要同时适配 class Distribution 和 class Normal。可以参考 [WIP] Migrate paddle.distribution.Normal into pir #59922
- test/distribution/test_distribution_normal.py 文件里
NormalTest
单测及其派生单测,也可以参考 [WIP] Migrate paddle.distribution.Normal into pir #59922 来进行适配。我的方法是新建了一个上下文管理器InitDataContextManager
,可以根据当前所处模式,指定在哪个 program 里构图。在NormalTest
派生类的 init_static_data 方法里调用InitDataContextManager
. 这个方法仅供参考,如果大佬有更好的方法,欢迎提出~ 😄 - 关于
TestNormalSampleStaic
和TestNormalRSampleStaic
单测,旧 ir 和 pir 不好作统一。建议重新写一份 pir 模式下的单测
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
great work~
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
raise TypeError( | ||
"Type of input args must be float, list, tuple, numpy.ndarray or Tensor, but received type {}".format( | ||
type(arg) | ||
) | ||
) | ||
if isinstance(arg, paddle.base.libpaddle.pir.Value): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里为什么要对Value特殊处理呢?老静态图的Variable是怎么处理的呢?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
不处理会这样子~
这我看的陆师傅的,我不清楚😂 @MarioLulab
for arg in numpy_args: | ||
if isinstance(arg, paddle.base.libpaddle.pir.Value): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
同上
Sorry to inform you that ddaadc9's CIs have passed for more than 7 days. To prevent PR conflicts, you need to re-run all CIs manually. |
PR types
Others
PR changes
APIs
Description
link #58067
PIR API 推全升级
将 paddle.distribution.Normal 迁移升级至 pir,并更新单测
单测覆盖率: 13/13