Dev where scalar #5797
Conversation
out->mut_dptr<T>());
if (!(x->shape() == y->shape() && y->shape() == cond->shape())) {
  size_t num_axes = out->shape().NumAxes();
  int64_t elem_cnt = out->shape().elem_cnt();
Does torch's where support zero-element tensors, e.g. torch.randn(2, 0, 2)? That tensor is actually empty (since 2 x 0 x 2 = 0).
If it is supported, the elem_cnt == 0 case needs to be special-cased.
Good! I'll take a look.
It looks like you still haven't handled the zero-element tensor case here.
namespace oneflow {
namespace one {

namespace {
Could you briefly explain why it is handled this way here?
My impression is that you want to prepend a broadcast operation to reduce_sum_like.
If so, I think it would be better to move this into functional and rename it BroadcastReduceSumLike; in my opinion gradient_funcs is not the right place for it, and other code may want to use this function later.
Sure. It seems another functor uses it as well.
const std::shared_ptr<oneflow::one::Tensor>& y = ctx->SavedTensors().at(1);

std::shared_ptr<oneflow::one::Tensor> zero_out = JUST(functional::ZerosLike(y));
in_grads->resize(2);
} else {
  UNIMPLEMENTED_THEN_RETURN() << "The scalar in Where should be float or int.";
}
return OpInterpUtil::Dispatch<Tensor>(*op_, {condition, y}, attrs);
};

template<DeviceType device_type, typename T, typename CondT>
class WhereScalarXKernel final : public user_op::OpKernel {
x_scalar_operand = static_cast<T>(ctx->Attr<double>("x_float_operand"));
y_scalar_operand = static_cast<T>(ctx->Attr<double>("y_float_operand"));
} else {
  UNIMPLEMENTED();
Add a message to this error output.
const T x_scalar, const T* rhs, T* out) {
  FOR_RANGE(int64_t, i, 0, elem_cnt) { out[i] = static_cast<bool>(cond[i]) ? x_scalar : rhs[i]; }
}
static void WhereYScalar(DeviceCtx* ctx, const int64_t elem_cnt, const CondT* cond, const T* lhs,
@@ -27,6 +27,35 @@ struct WhereFunctor {
  }
};

template<typename T, typename CondT>
struct WhereScalarXFunctor {
The changes in this PR can make masked_fill much simpler; please also update the part of it that constructs a temporary tensor.
No description provided.