
Dev where scalar #5797

Merged
merged 18 commits into from
Aug 17, 2021

Conversation

simonJJJ (Contributor) commented Aug 8, 2021

No description provided.

out->mut_dptr<T>());
if (!(x->shape() == y->shape() && y->shape() == cond->shape())) {
size_t num_axes = out->shape().NumAxes();
int64_t elem_cnt = out->shape().elem_cnt();
Contributor:

Does where on the torch side support 0-d (zero-element) tensors, e.g. torch.randn(2, 0, 2)? That one is actually empty (since 2x0x2 = 0).

If it does, the elem_cnt == 0 case needs to be special-cased.
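The empty-tensor behavior the reviewer is asking about can be illustrated with a minimal pure-Python sketch of the where-with-scalar-x semantics (illustrative names, not OneFlow's actual API): a zero-element input should simply produce a zero-element output rather than crash the kernel.

```python
# Hypothetical sketch of "where" with a scalar on the x (true) branch.
# out[i] = x_scalar if cond[i] else y[i]
def where_x_scalar(cond, x_scalar, y):
    assert len(cond) == len(y)
    return [x_scalar if c else yi for c, yi in zip(cond, y)]

print(where_x_scalar([1, 0, 1], 5.0, [1.0, 2.0, 3.0]))  # [5.0, 2.0, 5.0]
# elem_cnt == 0: the loop body never runs, so the result is just empty
print(where_x_scalar([], 5.0, []))  # []
```

In a real kernel the elementwise loop is likewise a no-op when elem_cnt == 0; the guard matters mainly where a launch or division by elem_cnt would otherwise misbehave.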

Contributor Author:

Good, I'll take a look.

Contributor:

It looks like you haven't accounted for the 0-element tensor case here yet.


namespace oneflow {
namespace one {

namespace {
Contributor:

Could you briefly explain why it is handled this way here?

My impression is that you want to prepend a broadcast operation before reduce_sum_like.

If so, I think it would be better to move it into functional and rename it BroadcastReduceSumLike; in my opinion gradient_funcs is not the right place for it, and other places may need this function later.
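The "reduce sum like" step being discussed can be sketched in pure Python (hypothetical names, not OneFlow's API): when an input was broadcast during the forward pass, its gradient must be summed back over the broadcast axes so the shapes match again.

```python
# Hypothetical 2-D sketch: collapse a gradient of shape (rows, cols) back
# to a target that had size `like_rows` (1 or rows) on axis 0.
def reduce_sum_like_axis0(grad, like_rows):
    rows, cols = len(grad), len(grad[0])
    assert like_rows in (1, rows)
    if like_rows == rows:
        return [row[:] for row in grad]  # shapes already match
    # target had size 1 on axis 0: sum away the broadcast axis
    return [[sum(grad[r][c] for r in range(rows)) for c in range(cols)]]

print(reduce_sum_like_axis0([[1, 2], [3, 4]], 1))  # [[4, 6]]
```

A shared BroadcastReduceSumLike helper in functional would package exactly this broadcast-then-reduce pairing for reuse by other gradient functors.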

Contributor Author:

Sure. It seems another functor uses it as well.

const std::shared_ptr<oneflow::one::Tensor>& y = ctx->SavedTensors().at(1);

std::shared_ptr<oneflow::one::Tensor> zero_out = JUST(functional::ZerosLike(y));
in_grads->resize(2);

This comment was marked as resolved.

} else {
UNIMPLEMENTED_THEN_RETURN() << "The scalar in Where should be float or int.";
}
return OpInterpUtil::Dispatch<Tensor>(*op_, {condition, y}, attrs);

This comment was marked as resolved.

This comment was marked as resolved.

out->mut_dptr<T>());
if (!(x->shape() == y->shape() && y->shape() == cond->shape())) {
size_t num_axes = out->shape().NumAxes();
int64_t elem_cnt = out->shape().elem_cnt();
Contributor:

It looks like you haven't accounted for the 0-element tensor case here yet.

};

template<DeviceType device_type, typename T, typename CondT>
class WhereScalarXKernel final : public user_op::OpKernel {

This comment was marked as resolved.

x_scalar_operand = static_cast<T>(ctx->Attr<double>("x_float_operand"));
y_scalar_operand = static_cast<T>(ctx->Attr<double>("y_float_operand"));
} else {
UNIMPLEMENTED();
Contributor:

Add an error message to this output.

const T x_scalar, const T* rhs, T* out) {
FOR_RANGE(int64_t, i, 0, elem_cnt) { out[i] = static_cast<bool>(cond[i]) ? x_scalar : rhs[i]; }
}
static void WhereYScalar(DeviceCtx* ctx, const int64_t elem_cnt, const CondT* cond, const T* lhs,

This comment was marked as resolved.
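The WhereXScalar loop shown in the hunk above has a symmetric WhereYScalar counterpart; a minimal pure-Python sketch of the y-scalar variant (hypothetical names) is:

```python
# Hypothetical mirror of WhereYScalar: the scalar sits on the "false"
# branch instead of the "true" branch.
# out[i] = lhs[i] if cond[i] else y_scalar
def where_y_scalar(cond, lhs, y_scalar):
    return [li if c else y_scalar for c, li in zip(cond, lhs)]

print(where_y_scalar([1, 0], [10, 20], -1))  # [10, -1]
```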

@@ -27,6 +27,35 @@ struct WhereFunctor {
}
};

template<typename T, typename CondT>
struct WhereScalarXFunctor {

This comment was marked as resolved.

MARD1NO (Contributor) commented Aug 17, 2021

The changes in this PR can make masked_fill much simpler; please also update the tensor-construction part there:
https://github.com/Oneflow-Inc/oneflow/blob/master/python/oneflow/nn/modules/masked_fill.py#L59-L62
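The simplification MARD1NO suggests can be sketched in pure Python (hypothetical names, not the actual module code): once where accepts a scalar operand, masked_fill no longer needs to construct a value-filled tensor and becomes a one-line forward to where.

```python
# Hypothetical scalar-capable where (scalar on the "true" branch).
def where_x_scalar(cond, x_scalar, y):
    return [x_scalar if c else yi for c, yi in zip(cond, y)]

# masked_fill(x, mask, value) == where(mask, value, x): positions where
# mask is truthy take the scalar, the rest keep x.
def masked_fill(x, mask, value):
    return where_x_scalar(mask, value, x)

print(masked_fill([1.0, 2.0, 3.0], [0, 1, 0], 9.0))  # [1.0, 9.0, 3.0]
```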

@oneflow-ci-bot oneflow-ci-bot requested review from oneflow-ci-bot and removed request for oneflow-ci-bot August 17, 2021 08:38
@oneflow-ci-bot oneflow-ci-bot self-requested a review August 17, 2021 10:12
github-actions (Contributor):

Speed stats:
GPU Name: GeForce GTX 1080 

PyTorch resnet50 time: 139.4ms (= 6969.1ms / 50, input_shape=[16, 3, 224, 224], backward is enabled)
OneFlow resnet50 time: 127.8ms (= 6389.2ms / 50, input_shape=[16, 3, 224, 224], backward is enabled)
Relative speed: 1.09 (= 139.4ms / 127.8ms)

PyTorch resnet50 time: 84.9ms (= 4244.6ms / 50, input_shape=[8, 3, 224, 224], backward is enabled)
OneFlow resnet50 time: 74.2ms (= 3712.0ms / 50, input_shape=[8, 3, 224, 224], backward is enabled)
Relative speed: 1.14 (= 84.9ms / 74.2ms)

PyTorch resnet50 time: 60.1ms (= 3003.9ms / 50, input_shape=[4, 3, 224, 224], backward is enabled)
OneFlow resnet50 time: 47.6ms (= 2380.8ms / 50, input_shape=[4, 3, 224, 224], backward is enabled)
Relative speed: 1.26 (= 60.1ms / 47.6ms)

PyTorch resnet50 time: 49.8ms (= 2488.6ms / 50, input_shape=[2, 3, 224, 224], backward is enabled)
OneFlow resnet50 time: 42.9ms (= 2142.8ms / 50, input_shape=[2, 3, 224, 224], backward is enabled)
Relative speed: 1.16 (= 49.8ms / 42.9ms)

PyTorch resnet50 time: 40.0ms (= 1999.7ms / 50, input_shape=[1, 3, 224, 224], backward is enabled)
OneFlow resnet50 time: 37.7ms (= 1886.0ms / 50, input_shape=[1, 3, 224, 224], backward is enabled)
Relative speed: 1.06 (= 40.0ms / 37.7ms)

@oneflow-ci-bot oneflow-ci-bot merged commit 572951d into master Aug 17, 2021
@oneflow-ci-bot oneflow-ci-bot deleted the dev_where_scalar branch August 17, 2021 11:31