-
Notifications
You must be signed in to change notification settings - Fork 825
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
add randint #5718
add randint #5718
Conversation
Kevin-XiongC
commented
Aug 4, 2021
•
edited
Loading
edited
python/oneflow/nn/modules/randint.py
Outdated
>>> import numpy as np | ||
>>> generator = flow.Generator() | ||
>>> generator.manual_seed(0) | ||
>>> flow.randint(10,(1,10),generator=generator) |
This comment was marked as resolved.
This comment was marked as resolved.
Sorry, something went wrong.
(user_op::HobDeviceTag() == "gpu")); | ||
|
||
REGISTER_GPU_RANDINT_KERNEL | ||
} // namespace oneflow |
This comment was marked as resolved.
This comment was marked as resolved.
Sorry, something went wrong.
python/oneflow/nn/modules/randint.py
Outdated
high: flow.int32, | ||
size: tuple, | ||
generator: flow.Generator = None, | ||
dtype: flow.dtype = flow.int32, |
This comment was marked as resolved.
This comment was marked as resolved.
Sorry, something went wrong.
oneflow/user/ops/randint_op.cpp
Outdated
.Output("out") | ||
.Attr<int64_t>("low") | ||
.Attr<int64_t>("high") | ||
.Attr<int64_t>("seed", -1) |
This comment was marked as resolved.
This comment was marked as resolved.
Sorry, something went wrong.
oneflow/user/ops/randint_op.cpp
Outdated
return Maybe<void>::Ok(); | ||
}); | ||
|
||
} // namespace oneflow |
This comment was marked as resolved.
This comment was marked as resolved.
Sorry, something went wrong.
python/oneflow/nn/modules/randint.py
Outdated
def randint( | ||
low: flow.int64 = 0, | ||
high: Union[int, tuple] = None, | ||
size: tuple = None, | ||
generator: flow.Generator = None, | ||
dtype: flow.dtype = flow.int64, | ||
layout=None, | ||
device: flow.device = flow.device("cpu"), | ||
requires_grad: bool = False, | ||
) -> flow.Tensor: |
This comment was marked as resolved.
This comment was marked as resolved.
Sorry, something went wrong.
…into dev-randint
CI failed, removing label automerge |
Speed stats:
|