[Hackathon 5th No.17] Add pdist API to Paddle -part #57869
Conversation
Your PR was submitted successfully. Thank you for your contribution to the open-source project!
Sorry to inform you that 9c651d9's CIs have passed for more than 7 days. To prevent PR conflicts, you need to re-run all CIs manually.
Sorry to inform you that 13d04ad's CIs have passed for more than 7 days. To prevent PR conflicts, you need to re-run all CIs manually.
@zxcd Could you please help review this and check whether there are any problems?
test/legacy_test/test_pdist.py
Outdated
paddle.enable_static()


class TestpdistAPICase1(TestpdistAPI):
Please standardize the test naming.
exe = paddle.static.Executor(self.place)
res = exe.run(feed={'x': self.x}, fetch_list=[out])
out_ref = ref_pdist(self.x, self.p)
np.testing.assert_allclose(out_ref, res[0], rtol=1e-5, atol=1e-5)
rtol and atol should not be needed here, should they?
>>> x = np.random.rand(3, 4).astype('float64')
>>> t_x = paddle.to_tensor(x)
>>> x == t_x.numpy()
array([[ True, True, True, True],
[ True, True, True, True],
[ True, True, True, True]])
>>> np_norm = np.linalg.norm(x, axis=-1)
>>> pd_norm = paddle.linalg.norm(t_x, axis=-1)
>>> np_norm == pd_norm.numpy()
array([ True, True, True])
>>> x = np.random.rand(5, 6).astype('float64')
>>> t_x = paddle.to_tensor(x)
>>> np_norm = np.linalg.norm(x, axis=-1)
>>> pd_norm = paddle.linalg.norm(t_x, axis=-1)
>>> np_norm == pd_norm.numpy()
array([ True, True, True, True, True])
>>> x = np.random.rand(5, 6).astype('float64')
>>> t_x = paddle.to_tensor(x)
>>> np_norm = np.linalg.norm(x, axis=-1)
>>> pd_norm = paddle.linalg.norm(t_x, axis=-1)
>>> np_norm == pd_norm.numpy()
array([ True, True, True, True, False])
>>>
>>> x = np.random.rand(5, 6).astype('float64')
>>> t_x = paddle.to_tensor(x)
>>> np_norm = np.linalg.norm(x, axis=-1)
>>> pd_norm = paddle.linalg.norm(t_x, axis=-1)
>>> np_norm == pd_norm.numpy()
array([ True, True, True, True, False])
>>>
>>> x = np.random.rand(5, 6).astype('float64')
>>> t_x = paddle.to_tensor(x)
>>> np_norm = np.linalg.norm(x, axis=-1)
>>> pd_norm = paddle.linalg.norm(t_x, axis=-1)
>>> np_norm == pd_norm.numpy()
array([ True, True, True, True, False])
>>>
>>> x = np.random.rand(5, 6).astype('float64')
>>> t_x = paddle.to_tensor(x)
>>> np_norm = np.linalg.norm(x, axis=-1)
>>> pd_norm = paddle.linalg.norm(t_x, axis=-1)
>>> np_norm == pd_norm.numpy()
array([False, False, False, True, True])
>>>
>>> x = np.random.rand(5, 6).astype('float64')
>>> t_x = paddle.to_tensor(x)
>>> np_norm = np.linalg.norm(x, axis=-1)
>>> pd_norm = paddle.linalg.norm(t_x, axis=-1)
>>> np_norm == pd_norm.numpy()
array([ True, True, False, False, True])
>>>
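Without rtol and atol the precision check does not pass. I ran some experiments (above) and found that paddle.linalg.norm does not appear to match numpy exactly. The cdist unit test also relaxes the tolerance:
Paddle/test/legacy_test/test_cdist.py
Line 56 in 907e425
np.testing.assert_allclose(out_ref, res[0], rtol=1e-5, atol=1e-5)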
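The point about tolerances can be illustrated without Paddle. The following is a minimal sketch using NumPy only: it computes a row-wise 2-norm in two mathematically equivalent ways (standing in for two independent implementations, which is an assumption for illustration) and compares them with a tolerance instead of `==`. The variable names are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.random((5, 6))  # float64, same shape as the trials above

# Two mathematically equivalent ways to compute the row-wise 2-norm.
norm_a = np.linalg.norm(x, axis=-1)
norm_b = np.sqrt(np.sum(x * x, axis=-1))

# Element-wise exact equality may or may not hold, depending on how each
# implementation orders its floating-point operations.
exact_match = norm_a == norm_b

# A tolerance-based comparison accepts last-bit differences.
np.testing.assert_allclose(norm_a, norm_b, rtol=1e-5, atol=1e-5)
max_abs_diff = np.max(np.abs(norm_a - norm_b))
```

Any difference here is on the order of floating-point rounding, which is why a tolerance such as `rtol=1e-5, atol=1e-5` is comfortably sufficient for float64 inputs.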
compute_mode (str, optional): The mode for compute distance.

- ``use_mm_for_euclid_dist_if_necessary`` , for p = 2.0 and (P > 25 or R > 25), it will use matrix multiplication to calculate euclid distance if possible.
- ``use_mm_for_euclid_dist`` , for p = 2.0, it will use matrix multiplication to calculate euclid distance.
For the pdist API specifically, will these modes from cdist ever be used? If they are never used, I think the mode parameter does not need to be added here.
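> For the pdist API specifically, will these modes from cdist ever be used? If they are never used, I think the mode parameter does not need to be added here.
I implemented pdist as a special case of cdist. cdist takes inputs x: [P, M] and y: [R, M] and produces an output of shape [P, R], so I compute pdist by running cdist on two copies of the same input (setting y = x), that is, cdist(x, x). The input shape is [N, M], the output shape is [N, N], and I then take the upper triangle (excluding the diagonal) as the result.
Under that correspondence, P and R in cdist both map to N in the current [N, M] input, so when N is large (greater than 25) I wondered whether the matrix-multiplication strategy from cdist could also apply. Without the matrix approach, it would simply be a norm followed by taking the triangle. I am not sure whether this reasoning is right.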
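The approach described in this comment (compute the full [N, N] distance matrix of x against itself, then keep the strict upper triangle) can be sketched in NumPy as follows. This is an illustrative reference under stated assumptions, not the PR's actual Paddle code; the function name `pdist_via_full_matrix` is hypothetical.

```python
import numpy as np


def pdist_via_full_matrix(x, p=2.0):
    """Hypothetical NumPy sketch: p-norm distance between every pair of rows of x."""
    # diff[i, j, :] = x[i] - x[j], shape [N, N, M]
    diff = x[:, None, :] - x[None, :, :]
    # Full symmetric distance matrix, shape [N, N]; the diagonal is zero.
    full = np.linalg.norm(diff, ord=p, axis=-1)
    # Keep only the strict upper triangle (k=1 excludes the diagonal),
    # flattened in row-major order: (0,1), (0,2), ..., (N-2,N-1).
    rows, cols = np.triu_indices(x.shape[0], k=1)
    return full[rows, cols]


x = np.random.default_rng(0).random((4, 5))
out = pdist_via_full_matrix(x)  # length N * (N - 1) / 2 = 6
```

The full-matrix route computes every distance twice plus the diagonal, which is simple but does roughly double the necessary work; that trade-off is independent of whether a `compute_mode` parameter is exposed.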
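Compared with competing frameworks, pdist results are quite uniform, and none of them needs extra handling when N exceeds 25. In addition, judging by the formulas, cdist and pdist do not mean exactly the same thing. I suggest staying consistent with the competing frameworks and not adding this extra parameter.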
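> Compared with competing frameworks, pdist results are quite uniform, and none of them needs extra handling when N exceeds 25. I suggest staying consistent with the competing frameworks and not adding this extra parameter.
OK, thanks for the review. I will make some changes later.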
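For readers unfamiliar with the `use_mm_for_euclid_dist` modes mentioned in this thread: the usual way to compute Euclidean distances with matrix multiplication is the identity ||a - b||^2 = ||a||^2 + ||b||^2 - 2 a·b. The NumPy sketch below shows that identity only; whether Paddle's cdist implements the mode exactly this way is an assumption, and the function name is hypothetical.

```python
import numpy as np


def euclid_dist_mm(x):
    """Hypothetical sketch: pairwise Euclidean distances via one matrix product."""
    sq = np.sum(x * x, axis=-1)                  # squared row norms, shape [N]
    gram = x @ x.T                               # inner products, shape [N, N]
    d2 = sq[:, None] + sq[None, :] - 2.0 * gram  # squared distances
    # Rounding can push tiny values slightly below zero; clamp before sqrt.
    np.maximum(d2, 0.0, out=d2)
    return np.sqrt(d2)


x = np.random.default_rng(0).random((30, 8))
dist_mm = euclid_dist_mm(x)
dist_direct = np.linalg.norm(x[:, None, :] - x[None, :, :], axis=-1)
```

The matrix form trades a little accuracy (cancellation in the subtraction) for speed on large inputs, which is one more reason tests for such code compare with a tolerance.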
Sorry to inform you that 8bcbe47's CIs have passed for more than 7 days. To prevent PR conflicts, you need to re-run all CIs manually.
test/legacy_test/test_pdist.py
Outdated
self.x = np.random.rand(50, 20).astype('float64')


class TestPdistAPICase8(TestPdistAPI):
I suggest giving the tests meaningful names.
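One way to follow this suggestion is to name each subclass after the condition it exercises instead of numbering it. The sketch below uses plain `unittest` and a NumPy-only reference; the class names and the body of `ref_pdist` are hypothetical illustrations, not the names or code that ended up in the PR.

```python
import unittest

import numpy as np


def ref_pdist(x, p=2.0):
    # Hypothetical NumPy reference: p-norm distance for every row pair i < j.
    rows, cols = np.triu_indices(x.shape[0], k=1)
    return np.linalg.norm(x[rows] - x[cols], ord=p, axis=-1)


class TestPdistReference(unittest.TestCase):
    def setUp(self):
        self.x = np.random.default_rng(0).random((10, 20))
        self.p = 2.0

    def test_output_length(self):
        n = self.x.shape[0]
        out = ref_pdist(self.x, self.p)
        self.assertEqual(out.shape, (n * (n - 1) // 2,))


class TestPdistReferencePInf(TestPdistReference):
    # The name states what differs from the base case: p is infinity.
    def setUp(self):
        super().setUp()
        self.p = float("inf")


class TestPdistReferenceLargeInput(TestPdistReference):
    # The name states what differs from the base case: a larger input.
    def setUp(self):
        super().setUp()
        self.x = np.random.default_rng(0).random((50, 20))
```

With names like these, a failing CI log identifies the failing configuration without opening the test file.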
.. code-block:: python

    >>> import paddle
    >>> a = paddle.randn([4, 5])
Set a seed in the doc example; otherwise PR-CI-Static-Check will not pass.
Reference:
Paddle/python/paddle/tensor/linalg.py
Line 764 in 855e51e
>>> paddle.seed(2023)
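> Set a seed in the doc example; otherwise PR-CI-Static-Check will not pass. Reference:
> Paddle/python/paddle/tensor/linalg.py
> Line 764 in 855e51e
> >>> paddle.seed(2023)
I have added it now, but it does not seem to have any effect?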
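After you add the seed, the printed results in your example change as well. For the expected output, refer to the content of the error message.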
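> After you add the seed, the printed results in your example change as well. For the expected output, refer to the content of the error message.
Understood, I will fix it shortly.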
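The reasoning in this thread is that a doc example whose printed output is checked must be deterministic: the seed fixes the random input, and the printed values in the docstring must then be regenerated to match that seeded input. The sketch below shows the principle with NumPy's generator as a stand-in; it does not reproduce PR-CI-Static-Check itself, and in the Paddle docstring the equivalent step is the `paddle.seed(2023)` line quoted above.

```python
import numpy as np


def sample_input(seed=None):
    # A seeded generator yields the same values on every run;
    # seed=None draws fresh entropy each time.
    rng = np.random.default_rng(seed)
    return rng.standard_normal((4, 5))


first = sample_input(seed=2023)
second = sample_input(seed=2023)
unseeded = sample_input()

# Seeded runs agree exactly, so printed output can be checked verbatim.
seeded_runs_match = np.array_equal(first, second)
```

Adding the seed alone is not enough: the expected output pasted into the docstring has to come from a run with that same seed, which is what the reviewer's pointer to the error message is about.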
… pdist_coco_dev
…to pdist_coco_dev
…to pdist_coco_dev
LGTM
@@ -162,6 +163,7 @@
'conv3d',
'conv3d_transpose',
'pairwise_distance',
'pdist',
Which path do we recommend users to use? If paddle.pdist is recommended, it cannot be added to this __all__ list. If paddle.nn.functional.pdist is recommended, it cannot be added to the __all__ list in python/paddle/__init__.py.
Thanks for your review. I prefer the paddle.pdist path, so I have removed this line.
LGTM
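As background for the `__all__` discussion above: in Python, a module's `__all__` list determines which names `from module import *` exports, and it is commonly read as the module's declared public API. The sketch below demonstrates that language behavior with a throwaway module; `demo_functional` is a hypothetical stand-in, not a Paddle module, and the one-recommended-path policy itself comes from the reviewer's comment.

```python
import sys
import types

# Build a hypothetical stand-in module at runtime (not part of Paddle).
demo = types.ModuleType("demo_functional")
exec(
    "def pdist(x):\n"
    "    return x\n"
    "\n"
    "def pairwise_distance(x, y):\n"
    "    return (x, y)\n"
    "\n"
    "# Only names listed here are exported by a star-import.\n"
    "__all__ = ['pairwise_distance']\n",
    demo.__dict__,
)
sys.modules["demo_functional"] = demo

# Star-import into a fresh namespace and see which names arrive.
namespace = {}
exec("from demo_functional import *", namespace)
exported = sorted(name for name in namespace if not name.startswith("__"))
```

Here `pdist` is still reachable as `demo.pdist`, but it is not advertised through `__all__`, which mirrors the outcome of this thread: the function stays importable from the module while only one path is listed as public.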
The corresponding Chinese documentation can be submitted now.
Co-authored-by: zachary sun <70642955+sunzhongkai588@users.noreply.github.com>
LGTM
PR types
New features
PR changes
APIs
Description
RFC:
Chinese documentation: