
Exponential added. #138

Merged: 5 commits merged into master from the exponential branch on Aug 1, 2024

Conversation

@tongxin (Contributor) commented Jul 26, 2024

Added a Triton implementation of the in-place operator tensor.exponential_.
The output value range is (0, ∞).

import torch
import flag_gems

shape, dtype = (1024, 1024), torch.float32  # example shape and dtype
x = torch.randn(size=shape, dtype=dtype, device="cuda")
with flag_gems.use_gems():
    res_out = x.exponential_(lambd=0.5)
assert res_out.min() > 0
Collaborator

Maybe add a K-S test for the distribution?
Testing only for positivity is too minimal:
https://en.wikipedia.org/wiki/Kolmogorov%E2%80%93Smirnov_test
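
For illustration, a minimal K-S check of this kind could look like the sketch below (the shape, lambd, and significance threshold are assumptions for illustration, not the project's actual test parameters):

import torch
import flag_gems
from scipy import stats

lambd = 0.5
x = torch.randn(1024, 1024, dtype=torch.float32, device="cuda")
with flag_gems.use_gems():
    x.exponential_(lambd=lambd)

# compare the empirical sample against the exponential CDF;
# scipy parameterizes expon by scale = 1 / lambd
statistic, p_value = stats.kstest(x.flatten().cpu().numpy(), "expon", args=(0.0, 1.0 / lambd))
assert p_value > 0.05  # assumed significance threshold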

Contributor Author

Thanks, more substantial tests are on the way.

@StrongSpoon (Collaborator) left a comment

benchmark to do
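
A rough micro-benchmark along these lines might be sketched as follows (tensor size, warm-up and iteration counts are arbitrary assumptions, not the project's benchmark harness):

import torch
import flag_gems

def bench(fn, iters=100):
    # CUDA-event timing with a short warm-up; returns average ms per call
    for _ in range(10):
        fn()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        fn()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters

x = torch.empty(1 << 24, device="cuda")
eager_ms = bench(lambda: x.exponential_(lambd=0.5))
with flag_gems.use_gems():
    gems_ms = bench(lambda: x.exponential_(lambd=0.5))
print(f"eager: {eager_ms:.3f} ms, gems: {gems_ms:.3f} ms")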

src/flag_gems/ops/exponential_.py (outdated)
x = torch.randn(size=shape, dtype=dtype, device="cuda")
with flag_gems.use_gems():
res_out = x.exponential_(lambd=0.5)
assert res_out.min() > 0
Collaborator

Is it enough to ensure accuracy? I think we need more checks.

Contributor Author

More tests will be done.

src/flag_gems/ops/exponential_.py
def transform_exponential(u, lambd, eps):
    # clamp u away from 1.0 so log(u) never evaluates to exactly 0,
    # keeping the transformed sample strictly inside (0, inf)
    eps1 = -0.5 * eps
    is_min = u >= 1.0 + eps1
    log = tl.where(is_min, eps1, tl.math.log(u))
Collaborator

Is this really needed? What about just using log(u) or log(1-u)?

Contributor Author

This is for enforcing compatibility with PyTorch.
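
For reference, a host-side NumPy sketch of the inverse-transform sampling the kernel appears to implement (the default eps and the final -log(u) / lambd scaling are assumptions inferred from the quoted hunk, not the exact kernel code):

import numpy as np

def exponential_from_uniform(u, lambd, eps=np.finfo(np.float32).eps):
    # clamp u away from 1.0 so log never returns exactly 0, which keeps the
    # sampled values strictly positive, matching torch's (0, inf) range
    eps1 = -0.5 * eps
    log_u = np.where(u >= 1.0 + eps1, eps1, np.log(u))
    return -log_u / lambd  # standard inverse transform: x = -log(u) / lambd

u = np.random.uniform(size=8).astype(np.float32)
print(exponential_from_uniform(u, lambd=0.5))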

@Bowen12992 (Collaborator) left a comment

LGTM, added scipy to the UT env.

logging.debug("GEMS EXPONENTIAL_")
dtype = x.dtype
device = x.device
inplace = x.is_contiguous()  # write in place only when x is contiguous
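
For context, one plausible shape of the non-contiguous fallback that this flag suggests is sketched below; launch_exponential_kernel is a hypothetical placeholder, not the actual FlagGems kernel launch:

import torch

def launch_exponential_kernel(out, lambd):
    # hypothetical stand-in for the Triton kernel launch
    out.exponential_(lambd=lambd)

def exponential_(x, lambd=1.0):
    # hypothetical sketch: write directly when x is contiguous, otherwise fill a
    # contiguous buffer and copy back (copy_ is where PyTorch raises for tensors
    # with internal overlap, e.g. broadcast views)
    inplace = x.is_contiguous()
    out = x if inplace else torch.empty_like(x, memory_format=torch.contiguous_format)
    launch_exponential_kernel(out, lambd)
    if not inplace:
        x.copy_(out)
    return x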
Collaborator

Performing an in-place operation on a tensor with internal overlap should raise a RuntimeError.

Collaborator

RuntimeError: unsupported operation: more than one element of the written-to tensor refers to a single memory location. Please clone() the tensor before performing the operation.

@iclementine (Collaborator) commented Aug 1, 2024

Now it would raise a runtime error when copying data back.

import torch
import flag_gems
flag_gems.enable()
x = torch.ones(2, device="cuda")
x = torch.broadcast_to(x, (3, 2))  # view with internal overlap (stride 0 along dim 0)
x.exponential_()  # raises the RuntimeError above when copying data back

Contributor Author

I'll fix it.

Contributor Author

PyTorch throws exactly the same error, so we'll just keep the current behavior.
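
The same behavior can be reproduced in eager PyTorch (without flag_gems) as a sanity check; this snippet is an illustrative sketch, not part of the PR's test suite:

import torch

x = torch.ones(2, device="cuda")
x = torch.broadcast_to(x, (3, 2))  # view with internal overlap (stride 0 along dim 0)
try:
    x.exponential_()  # eager PyTorch raises the RuntimeError quoted above
except RuntimeError as e:
    print(e)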

@iclementine merged commit 4a5acdd into master on Aug 1, 2024
3 checks passed
Bowen12992 pushed a commit to Bowen12992/FlagGems that referenced this pull request Aug 6, 2024
* exponential added.
* Added K-S tests to exponential_, fp64 corrected.
* aligned with aten prototype
* Exponential_ uses uint64 offsets in Triton kernel.
* Update pyproject config for new test dependencies.
@StrongSpoon deleted the exponential branch on August 13, 2024 at 07:36
tongxin added a commit that referenced this pull request Sep 2, 2024
* WIP: multinomial

* add Ops & UT & Bench

* add full zero ones Ops & UT & Bench

* split normal op

* Adding multinomial.

* fixed off-by-one error in binary search

* Added multinomial tests without replacement.

* PR comment

* split test_special_ops

* updated with_replacement  tests

* add K-S test

* split special perf

* Update to a more reliable without-replacement test

* Exponential added. (#138)

* exponential added.
* Added K-S tests to exponential_, fp64 corrected.
* aligned with aten prototype
* Exponential_ uses uint64 offsets in Triton kernel.
* Update pyproject config for new test dependencies.

* resolve conflict

* Use int64 indexing when needed & fix argmax (#146)

1. fix amax, argmax and triu, using int64 indexing when the largest tensor's size_in_bytes exceeds int32's max;
2. change the tiling scheme for argmax to loop over the reduction dimension, instead of a data-size-dependent tile size

* test for op

* test for op

* Added multinomial perf tests.

* Making libentry thread safe (#136)

* libentry now is lock protected.

* Add multithreading tests for libentry.

* polish code.

* add argparse

* fix desc

* fix num

* Update test_specific_ops.py

* split UT files

* fix

* fix

* resolved conflicts with master.

* fixing multinomial, working in progress.

* Multinomial passes tests.

* Enhance multinomial tests and benchmarks.

* [bugfix] keepdim when samples one

* [bugfix] fix accu test

* fix anomaly behavior in fused_renorm_cumsum

* Polish multinomial tests.

* remove garbage files.

* bfloat16 added for multinomial, polish without replacement test.

* Enable two-pass normed cumsum.

* cumsum updated

* normed cumsum complete.

* Fixed multinomial binary search boundary bug

* fix normed_cumsum bugs.

* quick fix dim check.

---------

Co-authored-by: Bowen12992 <zhangbluestars@gmail.com>
Co-authored-by: Clement Chan <iclementine@outlook.com>
Co-authored-by: Bowen <81504862+Bowen12992@users.noreply.github.com>
Co-authored-by: StrongSpoon <35829812+StrongSpoon@users.noreply.github.com>
Co-authored-by: StrongSpoon <strongspoon@outlook.com>
GwokHiujin added a commit that referenced this pull request Oct 28, 2024
* add Ops & UT & Bench

* add full zero ones Ops & UT & Bench

* split normal op

* [Operator] init slice&select scatter

* code format

* PR comment

* split test_special_ops

* add K-S test

* split special perf

* Exponential added. (#138)

* exponential added.
* Added K-S tests to exponential_, fp64 corrected.
* aligned with aten prototype
* Exponential_ uses uint64 offsets in Triton kernel.
* Update pyproject config for new test dependencies.

* resolve conflict

* Use int64 indexing when needed & fix argmax (#146)

1. fix amax, argmax and triu, using int64 indexing when the largest tensor's size_in_bytes exceeds int32's max;
2. change the tiling scheme for argmax to loop over the reduction dimension, instead of a data-size-dependent tile size

* test for op

* test for op

* Making libentry thread safe (#136)

* libentry now is lock protected.

* Add multithreading tests for libentry.

* polish code.

* add argparse

* fix desc

* fix num

* Update test_specific_ops.py

* split UT files

* fix

* fix

* [Operator] Optimize CrossEntropyLoss (#131)

reimplement cross_entropy_loss forward and backward
supports indices/probabilities/weight/reduction/ignore_index/label_smoothing; performs better than torch eager on large-scale tensors

* Exponential added. (#138)

* exponential added.
* Added K-S tests to exponential_, fp64 corrected.
* aligned with aten prototype
* Exponential_ uses uint64 offsets in Triton kernel.
* Update pyproject config for new test dependencies.

* Use int64 indexing when needed & fix argmax (#146)

1. fix amax, argmax and triu, using int64 indexing when the largest tensor's size_in_bytes exceeds int32's max;
2. change the tiling scheme for argmax to loop over the reduction dimension, instead of a data-size-dependent tile size

* Making libentry thread safe (#136)

* libentry now is lock protected.

* Add multithreading tests for libentry.

* polish code.

* [Test] Test for op (#151)

* [chore] solve slice&select scatter's test cases

* [fix] fix slice&select scatter's test cases

* [chore] remove out-of-range indices in select_scatter's test cases

* [chore] simplify slice_scatter's test cases

* [fix] Added range that is deleted by mistake

* Merge branch 'master' into slice&select_scatter

* [chore] reformat

* [fix] typo

* [chore] Considering perf, pause the replacement of some aTen operators
* slice_scatter
* select_scatter
* index_select

* [fix] Add libentry in op.cumsum

* [fix] Del slice&select scatter's perf tests

* [Chore] Add pytest mark for slice&select scatter's test

* [Fix] Correct slice_scatter test

* [Fix] Replace CPU Tensor

---------

Co-authored-by: Bowen12992 <zhangbluestars@gmail.com>
Co-authored-by: Tongxin Bai <waffle.bai@gmail.com>
Co-authored-by: Clement Chan <iclementine@outlook.com>
Co-authored-by: Bowen <81504862+Bowen12992@users.noreply.github.com>
Co-authored-by: StrongSpoon <35829812+StrongSpoon@users.noreply.github.com>