-
Notifications
You must be signed in to change notification settings - Fork 825
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
add index_select op #5661
add index_select op #5661
Conversation
Ikkyu321
commented
Jul 29, 2021
•
edited
Loading
edited
…flow into dev_index_select git pull project needed
添加文档截图 参考 #5636 oneflow/docs 执行 |
|
||
|
||
@register_tensor_op("index_select") | ||
def index_select_op(input, dim, index, sparse_grad=False): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
flow.xxx
和 flow.Tensor.xxx
共用一个函数的话,就共用了 docstring,那么总有一方的文档是不准确的。
需要为 flow.Tensor.xxx
单独准备一个函数,书写 docstring。
可以参考 acos
等算子的写法
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
flow.xxx
和flow.Tensor.xxx
共用一个函数的话,就共用了 docstring,那么总有一方的文档是不准确的。
需要为flow.Tensor.xxx
单独准备一个函数,书写 docstring。
可以参考acos
等算子的写法
已修改
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
按照 torch 的 API 文档 https://pytorch.org/docs/stable/generated/torch.index_select.html?highlight=index_select#torch.index_select
这个算子的原型是:
torch.index_select(input, dim, index, *, out=None) → Tensor
考虑到暂时不对齐 out 参数,那么原型也应该是
.index_select(input, dim, index)
这里的 sparse_grad
是那个版本的呢?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
input.index_select(dim, index, sparse_grad) -> Tensor | ||
See :func:`oneflow.index_select` | ||
""" | ||
assert sparse_grad is False, "Only support bool = False for now!" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
不必要把一样的代码写两遍,直接调用 index_select_op
就可以了
index_select_op(input, dim, index, sparse_grad)
|
||
|
||
@register_tensor_op("index_select") | ||
def index_select_op(input, dim, index, sparse_grad=False): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
按照 torch 的 API 文档 https://pytorch.org/docs/stable/generated/torch.index_select.html?highlight=index_select#torch.index_select
这个算子的原型是:
torch.index_select(input, dim, index, *, out=None) → Tensor
考虑到暂时不对齐 out 参数,那么原型也应该是
.index_select(input, dim, index)
这里的 sparse_grad
是那个版本的呢?
|
||
|
||
@register_tensor_op("index_select") | ||
def index_select_op(input, dim, index, sparse_grad=False): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Args: | ||
input (Tensor): the source tensor | ||
dim (int): the axis along which to index | ||
index (LongTensor): the indices of elements to select |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
flow 里面还没有 LongTensor
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
the 1-D tensor containing the indices to index
我给的建议是直接参考 torch 的
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
the 1-D tensor containing the indices to index
我给的建议是直接参考 torch 的
已修改
|
||
|
||
def index_select_op(input, dim, index, sparse_grad=False): | ||
r"""Select values along an axis specified by `dim`. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
加一个对 PyTorch 参考的说明。可以参考 Conv1d
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
加一个对 PyTorch 参考的说明。可以参考
Conv1d
已修改
""" | ||
assert sparse_grad is False, "Only support bool = False for now!" | ||
assert len(index.shape) == 1, "Dimensions of index should be an LongTensor" | ||
assert dim < len(input.shape) and dim > -1, "Value of dim is out of range" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
assert dim < len(input.shape) and dim > -1, "Value of dim is out of range" | |
assert dim < len(input.shape) and dim >= 0, "Value of dim is out of range" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已修改
assert sparse_grad is False, "Only support bool = False for now!" | ||
assert len(index.shape) == 1, "Dimensions of index should be an LongTensor" | ||
assert dim < len(input.shape) and dim > -1, "Value of dim is out of range" | ||
assert _input_args_is_int(index.tolist()), "input index parameter is not illegal!" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
我觉得这里的提示可以更明确点,告诉用户怎样做可以修正好。而不是只告诉用户他做的不合法
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
我觉得这里的提示可以更明确点,告诉用户怎样做可以修正好。而不是只告诉用户他做的不合法
已修改
@register_tensor_op("index_select") | ||
def index_select_op_tensor(input, dim, index, sparse_grad=False): | ||
""" | ||
input.index_select(dim, index, sparse_grad) -> Tensor |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
函数原型是错的(不应该有 sparase_grad)
可以直接参考 https://pytorch.org/docs/1.9.0/generated/torch.Tensor.index_select.html?highlight=index_select#torch.Tensor.index_select 或者 oneflow 已有的其它例子
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
函数原型是错的(不应该有 sparase_grad)
可以直接参考 https://pytorch.org/docs/1.9.0/generated/torch.Tensor.index_select.html?highlight=index_select#torch.Tensor.index_select 或者 oneflow 已有的其它例子
已修改
device = random_device() | ||
|
||
# test 4 dimensions tensor | ||
axis = random(0, 4).to(int) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
我不是很赞同 axis 和 dim 混用。 虽然无伤大雅
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
我不是很赞同 axis 和 dim 混用。 虽然无伤大雅
修改为dim
x = index_i.expand(index_rshp) | ||
index_gather = flow.cat((index_gather, x), dim) | ||
|
||
return flow.gather(input, index_gather, dim, sparse_grad) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
return flow.gather(input, index_gather, dim, sparse_grad) | |
return flow.gather(input, index_gather, dim) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已修改
tensor_dim.append(random(2, 6).to(int).value()) | ||
|
||
index = random_pytorch_tensor( | ||
ndim=1, low=0, high=tensor_dim[dim.value()], dtype=int |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ndim=1, low=0, high=tensor_dim[dim.value()], dtype=int | |
ndim=1, dim0=dim.value()+1, low=0, high=tensor_dim[dim.value()], dtype=int |
现在这样的测试案例 ,index 是shape 总是 [1],不能充分覆盖各种情况。
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
现在这样的测试案例 ,index 是shape 总是 [1],不能充分覆盖各种情况。
已修改index的长度为随机(1,10)
Speed stats:
|