-
Notifications
You must be signed in to change notification settings - Fork 5.9k
Closed
Labels
Description
文档链接&描述 Document Links & Description
https://www.paddlepaddle.org.cn/documentation/docs/zh/api/paddle/distribution/Categorical_cn.html
对于Categorical.probs()方法:
logtis拼错value[:-1] = logits[:-1]存在问题,例如当logits是2D tensor时:
logits = paddle.to_tensor([
[0.8, 0.2],
[0.4, 0.6]]
)
value = paddle.to_tensor([
[1],
[0]]
)logits[:-1]为[0.8, 0.2],value[:-1]为[1],并不相等
请提出你的建议 Please give your suggestion
应将value[:-1] = logits[:-1]修改为value.shape[:-1] = logits.shape[:-1]