-
Notifications
You must be signed in to change notification settings - Fork 5.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[API]Support static branch in paddle.to_tensor #45164
Conversation
你的PR提交成功,感谢你对开源项目的贡献! |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
有几个遗留的TODO,可以再优化下。
place = _get_paddle_place(place) | ||
if place is None: | ||
place = _current_expected_place() | ||
elif not isinstance( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个判断其实可以去掉,_get_paddle_place里已经有合法性检查了。可以复用_to_tensor_non_static
"'place' must be any of paddle.Place, paddle.CPUPlace, paddle.CUDAPinnedPlace, paddle.CUDAPlace, paddle.NPUPlace, paddle.XPUPlace, paddle.MLUPlace, paddle.CustomPlace" | ||
) | ||
|
||
import re |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
不要动态import, import语句放到文件最前面
def call_assign(data, dtype=None, stop_grandient=None): | ||
|
||
if isinstance(data, | ||
(Variable, core.VarBase)) and (dtype is None or dtype |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里只需要判断Variable就可以了吧,这个是静态图分支,如果出现core.VarBase应该是非法的吧?
if dtype: | ||
target_dtype = convert_dtype(dtype) | ||
elif hasattr(data, 'dtype'): | ||
target_dtype = convert_dtype(data.dtype) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
if hasattr(data, 'dtype'):
dtype = data.dtype
elif dtype is None:
dtype = paddle.get_default_dtype()
target_dtype = convert_dtype(dtype)
else: | ||
place = paddle.CPUPlace() | ||
|
||
x = paddle.to_tensor(paddle.randn([5, 2]), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
单测需要再丰富一下,你这里并没有测试[var, 1, 1]的情况
@@ -12,6 +12,8 @@ | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
|
|||
# to_tensor api will create 1 less op now, this test was changed |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
不用加这一行
# call assign for static graph | ||
else: | ||
|
||
def call_assign(data, dtype=None, stop_grandient=None): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
_to_tensor_static() 放到外面
) | ||
|
||
import re | ||
re_exp = re.compile(r'[(](.*?)[)]', re.S) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
re_exp = re.compile(r'[(](.*?)[)]', re.S) | |
re_exp = re.compile(r'[(](.*+)[)]', re.S) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
PR types
Others
PR changes
Others
Describe