
Add CUDA non-contiguous Unary Ops support #14639


Open
wants to merge 1 commit into master from feature/cuda-non-cont-unary-support

Conversation

YavorGIvanov
Contributor

No description provided.

@github-actions github-actions bot added documentation Improvements or additions to documentation build Compilation issues Nvidia GPU Issues specific to Nvidia GPUs ggml changes relating to the ggml tensor library for machine learning labels Jul 11, 2025
@YavorGIvanov YavorGIvanov force-pushed the feature/cuda-non-cont-unary-support branch from c44bfde to 919ce38 Compare July 11, 2025 23:34
{ "name": "x64-linux-gcc-debug", "inherits": [ "base", "x64-linux-gcc", "debug" ] },
{ "name": "x64-linux-gcc-release", "inherits": [ "base", "x64-linux-gcc", "release" ] },
{ "name": "x64-linux-gcc-reldbg", "inherits": [ "base", "x64-linux-gcc", "reldbg" ] },
{ "name": "x64-linux-gcc+static-release", "inherits": [ "base", "x64-linux-gcc", "release", "static" ] },
Collaborator

is this accidental?

Contributor Author

No. Should I separate it into another PR? c4ecdef

I am fine with removing it, but I did not see a preset that fit my use case, so I decided to add one.

Collaborator

It may be easier to merge if you separate it into another PR.

Collaborator

Please put it into a separate PR.

@am17an am17an requested a review from JohannesGaessler July 12, 2025 10:08
{ "name": "x64-linux-gcc-debug", "inherits": [ "base", "x64-linux-gcc", "debug" ] },
{ "name": "x64-linux-gcc-release", "inherits": [ "base", "x64-linux-gcc", "release" ] },
{ "name": "x64-linux-gcc-reldbg", "inherits": [ "base", "x64-linux-gcc", "reldbg" ] },
{ "name": "x64-linux-gcc+static-release", "inherits": [ "base", "x64-linux-gcc", "release", "static" ] },
Collaborator

Please put it into a separate PR.

Collaborator

What is this file? Did you add it by accident?

Collaborator

@JohannesGaessler this comes from the recent merge #14598; in subsequent PRs we'll work out how to avoid such a huge diff when merging. Currently the file records the timestamp, device, etc., so it becomes an entirely new file each time.

Member

@YavorGIvanov For now don't commit the docs/ops/CUDA.csv and docs/ops.md. I'll make a follow-up PR after this gets merged to update the ops table.

Contributor Author

I am fine with improving and simplifying the process of generating docs/ops.md myself so that it does not produce huge diffs.

Comment on lines 103 to 105
const int k) {

const int i = blockDim.x*blockIdx.x + threadIdx.x;
Collaborator

Suggested change
-    const int k) {
-    const int i = blockDim.x*blockIdx.x + threadIdx.x;
+    const int64_t k) {
+    const int64_t i = blockDim.x*blockIdx.x + threadIdx.x;

Contributor Author

Thanks. Applied as part of other PR review changes.

Comment on lines +131 to +133
if (ggml_is_contiguous(src) && ggml_is_contiguous(dst_tensor)) {
unary_op_kernel<op><<<num_blocks, CUDA_NEG_BLOCK_SIZE, 0, stream>>>(x, dst, k);
} else {
Collaborator

Remove the contiguous path, it's no longer needed.

Contributor Author

I kept it because the performance of the simple contiguous kernel is clearly better, and I thought you might prefer to keep the most optimal path in this case. I know that in the big scheme of things these unary operations are a very small part of inference time, but I think it is a good idea not to degrade contiguous performance here.

  ABS(type=f32,ne_a=[256,256,3,1],v=0):               532415 runs -     1.88 us/run -     1536 kB/run -  778.95 GB/s
  ABS(type=f32,ne_a=[256,256,3,1],v=1):               311220 runs -     3.24 us/run -     3070 kB/run -  903.14 GB/s

Here is an example perf test using test-backend-ops on an H100 SXM5 (v=0 means contiguous, v=1 means non-contiguous).

Let me know whether you still want the contiguous path removed, or whether you agree I should keep it for now.

@github-actions github-actions bot added the testing Everything test related label Jul 12, 2025
@YavorGIvanov YavorGIvanov force-pushed the feature/cuda-non-cont-unary-support branch from 1174a95 to 1752873 Compare July 12, 2025 23:43
@YavorGIvanov YavorGIvanov force-pushed the feature/cuda-non-cont-unary-support branch from 1752873 to 64be8c5 Compare July 12, 2025 23:44
@YavorGIvanov
Contributor Author

@JohannesGaessler @am17an Tried to address all comments.

Labels
build Compilation issues documentation Improvements or additions to documentation ggml changes relating to the ggml tensor library for machine learning Nvidia GPU Issues specific to Nvidia GPUs testing Everything test related
4 participants