[Operator] Add repeat_interleave_self_int op #214

Merged: 4 commits, Sep 18, 2024

@zfu82 zfu82 commented Sep 14, 2024

Performance

Tested on an NVIDIA A100.

benchmark/test_pointwise_perf.py::test_perf_repeat_interleave_self_int
Operator repeat_interleave_self_int Performance Test (torch.float16)
Size        Torch Latency (ms)   Gems Latency (ms)
--------------------------------------------------
1024                  0.019456            0.012288
6144                   0.05632            0.036864
11264                 0.096256             0.06144
16384                 0.136192            0.084992
21504                 0.177152            0.169984
26624                 0.216064            0.171008
31744                    0.256            0.172032
36864                 0.295936            0.246784
41984                 0.334848            0.247808
47104                 0.374784            0.249856
52224                  0.41472            0.323584
57344                 0.454656            0.325632
62464                 0.494592            0.326656
67584                 0.533504            0.400384
72704                  0.57344            0.402432
77824                 0.612352            0.402432
Operator repeat_interleave_self_int Performance Test (torch.float32)
Size        Torch Latency (ms)   Gems Latency (ms)
--------------------------------------------------
1024                  0.017408            0.014336
6144                  0.065536            0.055296
11264                 0.111616            0.094208
16384                  0.15872            0.134144
21504                   0.2048            0.198656
26624                 0.251904            0.226304
31744                 0.297984            0.254976
36864                 0.345088            0.295936
41984                 0.392192            0.336896
47104                 0.438272            0.375808
52224                   0.4864            0.417792
57344                  0.53248            0.457728
62464                  0.57856            0.497664
67584                 0.625664            0.539648
72704                 0.672768            0.581632
77824                 0.718848             0.61952
Operator repeat_interleave_self_int Performance Test (torch.bfloat16)
Size        Torch Latency (ms)   Gems Latency (ms)
--------------------------------------------------
1024                  0.017408            0.012288
6144                   0.05632            0.036864
11264                  0.09728            0.060416
16384                 0.137216            0.086016
21504                 0.176128            0.169984
26624                 0.216064            0.171008
31744                 0.254976            0.172032
36864                 0.294912            0.246784
41984                 0.334848            0.248832
47104                 0.374784            0.249856
52224                  0.41472            0.323584
57344                 0.453632            0.324608
62464                 0.494592            0.326656
67584                 0.533504            0.400384
72704                  0.57344            0.402432
77824                 0.612352            0.402432
PASSED
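
The latencies above are easier to compare as a ratio of Torch latency to Gems latency. The following is a minimal sketch that uses three rows copied from the float16 table; the `speedup` helper and the `FLOAT16_ROWS` name are illustrative and are not part of the benchmark suite.

```python
# (size, torch_ms, gems_ms) rows copied from the float16 table above.
FLOAT16_ROWS = [
    (1024, 0.019456, 0.012288),
    (21504, 0.177152, 0.169984),
    (77824, 0.612352, 0.402432),
]


def speedup(torch_ms: float, gems_ms: float) -> float:
    """Return Torch latency over Gems latency; a value above 1 means Gems is faster."""
    return torch_ms / gems_ms


for size, torch_ms, gems_ms in FLOAT16_ROWS:
    print(f"{size:>6}: {speedup(torch_ms, gems_ms):.2f}x")
```

For these three rows the ratio is roughly 1.58x, 1.04x, and 1.52x, so the Gems kernel is faster in each case, with the smallest margin at size 21504.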

@zfu82 zfu82 force-pushed the dev_repeat_interleave_self_int branch from dedd293 to 1e8f1aa on September 14, 2024 07:37
@iclementine iclementine self-assigned this Sep 18, 2024
@iclementine iclementine left a comment

The current implementation does not handle non-contiguous input correctly.
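
The review does not spell out the failing case, so the following is only an illustrative sketch of why non-contiguous input needs care. It is pure Python rather than the flag_gems kernel, and the helper names (`logical_elements`, `repeat_interleave_flat`) are hypothetical. It models a tensor as a flat storage list plus shape and strides, and compares a stride-aware repeat with a naive one that walks the storage linearly.

```python
from itertools import product


def logical_elements(storage, shape, strides, offset=0):
    """Yield the elements of a strided view in logical (row-major) order."""
    for index in product(*(range(n) for n in shape)):
        yield storage[offset + sum(i * s for i, s in zip(index, strides))]


def repeat_interleave_flat(storage, shape, strides, repeats, offset=0):
    """Reference for the flattened case: repeat each logical element `repeats` times."""
    return [
        x
        for x in logical_elements(storage, shape, strides, offset)
        for _ in range(repeats)
    ]


# A 2x3 row-major buffer and its transpose, which shares the same storage.
storage = [0, 1, 2, 3, 4, 5]
contiguous = dict(shape=(2, 3), strides=(3, 1))
transposed = dict(shape=(3, 2), strides=(1, 3))

expected = repeat_interleave_flat(storage, repeats=2, **transposed)
naive = [x for x in storage for _ in range(2)]  # ignores strides entirely

print(expected)
print(naive)
```

For the contiguous layout the two results agree, but for the transposed view the naive version returns elements in storage order rather than logical order. One common way to avoid this class of bug in PyTorch code is to call `.contiguous()` on the input before launching a kernel that assumes linear storage; whether this PR took that route is not stated in the conversation.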

src/flag_gems/ops/repeat_interleave.py (outdated, resolved)
@iclementine iclementine left a comment

LGTM

@iclementine iclementine merged commit f784450 into master Sep 18, 2024
4 checks passed
@iclementine iclementine deleted the dev_repeat_interleave_self_int branch September 18, 2024 06:44
DuanYaQi pushed a commit that referenced this pull request Sep 20, 2024
* [Operator] Add repeat_interleave_self_int op