make ggml_conv_2d faster #483

Merged 7 commits into ggerganov:master on Oct 9, 2023

Conversation

leejet (Contributor) commented on Aug 27, 2023

I've split the CONV_2D operation into two parts: im2col (CONV_2D_STAGE_0) and GEMM (CONV_2D_STAGE_1). I've enabled multi-threading for im2col and restructured the loops to make them more cache-friendly. For clarity, I've used more intuitive and general variable names.
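
For reference, here is a minimal sketch of what the two stages compute (illustrative C++ only, not the actual ggml kernels; single input channel, single output channel, stride 1, no padding):

```cpp
#include <vector>

// stage 0 (im2col): unroll every KH x KW input patch into one row of a matrix
// stage 1 (GEMM)  : the convolution then becomes a matrix multiplication over those rows
std::vector<float> conv2d_im2col_gemm(const std::vector<float> & src, int IW, int IH,
                                      const std::vector<float> & ker, int KW, int KH) {
    const int OW = IW - KW + 1;
    const int OH = IH - KH + 1;

    // stage 0: patch matrix of shape [OH*OW, KH*KW]
    std::vector<float> cols((size_t) OH * OW * KH * KW);
    for (int oy = 0; oy < OH; oy++) {
        for (int ox = 0; ox < OW; ox++) {
            float * row = cols.data() + ((size_t) oy * OW + ox) * KH * KW;
            for (int ky = 0; ky < KH; ky++) {
                for (int kx = 0; kx < KW; kx++) {
                    row[ky * KW + kx] = src[(oy + ky) * IW + (ox + kx)];
                }
            }
        }
    }

    // stage 1: with a single output channel the GEMM degenerates to a matrix-vector product
    std::vector<float> dst((size_t) OH * OW, 0.0f);
    for (int i = 0; i < OH * OW; i++) {
        for (int k = 0; k < KH * KW; k++) {
            dst[i] += cols[(size_t) i * KH * KW + k] * ker[k];
        }
    }
    return dst;
}
```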

Take the decode_first_stage phase in stable-diffusion as an example, when generating a 512x512 image.
For a detailed comparison, see leejet/stable-diffusion.cpp#30

before:

perf_total_per_op_us[             ADD] = 1590.776 ms
perf_total_per_op_us[             MUL] = 708.338 ms
perf_total_per_op_us[          REPEAT] = 3385.188 ms
perf_total_per_op_us[      GROUP_NORM] = 754.513 ms
perf_total_per_op_us[         MUL_MAT] = 244.121 ms
perf_total_per_op_us[           SCALE] =  11.161 ms
perf_total_per_op_us[            CONT] =  16.623 ms
perf_total_per_op_us[         RESHAPE] =   0.111 ms
perf_total_per_op_us[         PERMUTE] =   0.003 ms
perf_total_per_op_us[        SOFT_MAX] =  12.856 ms
perf_total_per_op_us[         CONV_2D] = 43098.730 ms
perf_total_per_op_us[         UPSCALE] = 313.308 ms
perf_total_per_op_us[           UNARY] = 254.900 ms

after:

perf_total_per_op_us[             ADD] = 1595.238 ms
perf_total_per_op_us[             MUL] = 698.069 ms
perf_total_per_op_us[          REPEAT] = 3551.963 ms
perf_total_per_op_us[      GROUP_NORM] = 760.794 ms
perf_total_per_op_us[         MUL_MAT] = 219.111 ms
perf_total_per_op_us[           SCALE] =   9.338 ms
perf_total_per_op_us[            CONT] =  15.030 ms
perf_total_per_op_us[         RESHAPE] =   0.106 ms
perf_total_per_op_us[         PERMUTE] =   0.003 ms
perf_total_per_op_us[        SOFT_MAX] =  10.828 ms
perf_total_per_op_us[ CONV_2D_STAGE_0] = 5607.402 ms
perf_total_per_op_us[ CONV_2D_STAGE_1] = 12835.200 ms
perf_total_per_op_us[         UPSCALE] = 317.210 ms
perf_total_per_op_us[           UNARY] = 289.982 ms

The time for CONV_2D has been reduced from the original 43098.730 ms to 18442.602 ms (CONV_2D_STAGE_0 + CONV_2D_STAGE_1).

leejet (Contributor, Author) commented on Aug 27, 2023

By the way, I believe that the matrix multiplication used in both MUL_MAT and CONV_2D operations should be unified for easier optimization.

ggerganov (Owner) commented on Aug 28, 2023

Nice speedup!

I am thinking how to avoid the 2-step process. Do you need it just because the INIT step is not multi-threaded?
If it becomes multi-threaded, would you still need 2 stages?

I think it should be possible to reuse ggml_compute_forward_mul_mat() in the second stage; however, it does not support src1 being F16. It has to be F32.
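
A graph-level sketch of that idea (purely hypothetical wiring, using the stage-0 op from this PR plus existing ggml ops; the extra F16 to F32 copy is only there to satisfy the src1 restriction, and the final reshape/permute back to the usual CONV_2D output layout is omitted):

```cpp
#include "ggml.h"

// hypothetical: stage 0 (im2col) -> cast to F32 -> reuse ggml_mul_mat for the GEMM
static struct ggml_tensor * conv_2d_via_mul_mat(
        struct ggml_context * ctx,
        struct ggml_tensor  * kernel,   // F16, [KW, KH, IC, OC]
        struct ggml_tensor  * input,    // F32, [IW, IH, IC, N]
        int s0, int s1, int p0, int p1, int d0, int d1) {

    // im2col result: [N, OH, OW, IC * KH * KW], F16 (ne[0] == IC*KH*KW)
    struct ggml_tensor * im2col = ggml_conv_2d_stage_0(ctx, kernel, input, s0, s1, p0, p1, d0, d1);

    // cast to F32 so it can be used as src1 of mul_mat
    struct ggml_tensor * im2col_f32 = ggml_cpy(ctx, im2col,
        ggml_new_tensor_4d(ctx, GGML_TYPE_F32, im2col->ne[0], im2col->ne[1], im2col->ne[2], im2col->ne[3]));

    // patches: [IC*KH*KW, N*OH*OW], weights: [IC*KH*KW, OC]
    struct ggml_tensor * patches = ggml_reshape_2d(ctx, im2col_f32,
        im2col->ne[0], im2col->ne[1]*im2col->ne[2]*im2col->ne[3]);
    struct ggml_tensor * weights = ggml_reshape_2d(ctx, kernel,
        kernel->ne[0]*kernel->ne[1]*kernel->ne[2], kernel->ne[3]);

    return ggml_mul_mat(ctx, weights, patches);  // [OC, N*OH*OW]
}
```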

Btw, I'm not 100% sure that the strategy of making the Conv2D into a matrix multiplication is the best thing to do. I initially did it like this since mul mat is familiar, but we waste a lot of extra memory this way. Maybe there is something better to do here, not sure. On the other hand, if we reuse the mul mat implementation, it would be great because we won't have to write GPU code for Conv2D.

Will need some time to figure out how to refactor and merge this.
@slaren Let me know if you get any ideas how to improve this

leejet commented on Aug 28, 2023

> I am thinking how to avoid the 2-step process. Do you need it just because the INIT step is not multi-threaded?
> If it becomes multi-threaded, would you still need 2 stages?

If the INIT step is multithreaded, then indeed there might be no need to split it into two stages.

> Btw, I'm not 100% sure that the strategy of making the Conv2D into a matrix multiplication is the best thing to do.

Im2col is a common practice in many frameworks like Caffe, PyTorch, and others. However, these frameworks often implement other optimization techniques such as Winograd, where they choose different calculation methods based on the input format.

ggerganov (Owner) commented

I think for now we'll merge this as proposed.

At some point, we have to make Stage 1 reuse the matrix multiplication operator because this will bring massive speed-up on Apple Silicon via the Accelerate framework.

@leejet Let me know if this is still the version you want to have upstream.

leejet commented on Sep 9, 2023

> @leejet Let me know if this is still the version you want to have upstream.

@ggerganov The code in this PR is up to date now.

AngryLoki commented

@leejet, could you please check whether there is a race condition? I tried to compare the results of the old conv2d with the new GEMM-based one and got the same values single-threaded, but different values multi-threaded (and I know for sure that the old code was correct).

    int64_t ne00 = 3, ne01 = 3, ne02 = 640, ne03 = 640;
    int64_t ne10 = 80, ne11 = 80, ne12 = 640, ne13 = 1;

    int s0 = 1;
    int s1 = 1;
    int p0 = 1;
    int p1 = 1;
    int d0 = 1;
    int d1 = 1;

    std::vector<float> adata(ne00 * ne01 * ne02 * ne03);
    for (size_t i = 0; i < adata.size(); i++) adata[i] = 1;

    std::vector<uint16_t> hadata(ne00 * ne01 * ne02 * ne03);
    ggml_fp32_to_fp16_row(adata.data(), hadata.data(), adata.size());

    std::vector<float> bdata(ne10 * ne11 * ne12 * ne13);
    for (size_t i = 0; i < bdata.size(); i++) bdata[i] = 1;

    struct ggml_init_params params_ctx = {
        .mem_size = 200 * 1024 * 1024,
        .mem_buffer = NULL,
        .no_alloc = false,
    };

    struct ggml_context* ctx = ggml_init(params_ctx);

    struct ggml_tensor* a = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, ne00, ne01, ne02, ne03);
    memcpy(a->data, adata.data(), adata.size() * sizeof(adata[0]));
    struct ggml_tensor* ha = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, ne00, ne01, ne02, ne03);
    memcpy(ha->data, hadata.data(), hadata.size() * sizeof(hadata[0]));

    struct ggml_tensor* b = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, ne10, ne11, ne12, ne13);
    memcpy(b->data, bdata.data(), bdata.size() * sizeof(bdata[0]));

    struct ggml_tensor* r1 = ggml_conv_2d_stage_0(ctx, ha, b, s0, s1, p0, p1, d0, d1); // [N, OH, OW, IC * KH * KW]
    struct ggml_tensor* ref1 = ggml_conv_2d_stage_1(ctx, ha, r1);

    {
        struct ggml_compute_params params = {
            .type = GGML_TASK_INIT,
            .ith = 0,
            .nth = 1,
            .wsize = 0,
            .wdata = NULL,
        };

        ggml_compute_forward_conv_2d_stage_0(&params, ha, b, r1);

        params.type = GGML_TASK_COMPUTE;
        #pragma omp parallel
        {
            params.nth = omp_get_num_threads();
            params.ith = omp_get_thread_num();
            ggml_compute_forward_conv_2d_stage_0(&params, ha, b, r1);
        }

        #pragma omp parallel
        {
            params.nth = omp_get_num_threads();
            params.ith = omp_get_thread_num();
            ggml_compute_forward_conv_2d_stage_1(&params, ha, r1, ref1);
        }
    }

    struct ggml_tensor* ref2 = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, ref1->ne[0], ref1->ne[1], ref1->ne[2], ref1->ne[3]);
    ref2->op_params[0] = ref2->op_params[1] = ref2->op_params[2] = ref2->op_params[3] = ref2->op_params[4] = ref2->op_params[5] = 1;
    void *buf = malloc(200*1024*1024);

    {
        struct ggml_compute_params params = {
            .type = GGML_TASK_INIT,
            .ith = 0,
            .nth = 1,
            .wsize = 200*1024*1024,
            .wdata = buf,
        };

        ggml_compute_forward_conv_2d_f16_f32(&params, ha, b, ref2);

        params.type = GGML_TASK_COMPUTE;
        #pragma omp parallel
        {
            params.nth = omp_get_num_threads();
            params.ith = omp_get_thread_num();
            ggml_compute_forward_conv_2d_f16_f32(&params, ha, b, ref2);
        }
    }

    const float *dref1 = (float *)(ref1->data);
    printf("new: %f %f %f %f %f %f %f\n", dref1[0], dref1[1], dref1[2], dref1[3], dref1[4], dref1[5], dref1[6]);

    const float *dref2 = (float *)(ref2->data);
    printf("old: %f %f %f %f %f %f %f\n", dref2[0], dref2[1], dref2[2], dref2[3], dref2[4], dref2[5], dref2[6]);
new: 2240.000000 3360.000000 3360.000000 3360.000000 3360.000000 3360.000000 3360.000000
old: 2560.000000 3840.000000 3840.000000 3840.000000 3840.000000 3840.000000 3840.000000

where new is randomly 2560..., 2240..., or even 0.000000..., but only when OpenMP is enabled. In single-threaded runs the results are the same.

leejet commented on Sep 14, 2023

@AngryLoki That's because you modify the shared params variable from different OpenMP threads at the same time. You can try copying params, modifying the copy, and passing it in as the argument instead.

    int64_t ne00 = 3, ne01 = 3, ne02 = 640, ne03 = 640;
    int64_t ne10 = 80, ne11 = 80, ne12 = 640, ne13 = 1;

    int s0 = 1;
    int s1 = 1;
    int p0 = 1;
    int p1 = 1;
    int d0 = 1;
    int d1 = 1;

    std::vector<float> adata(ne00 * ne01 * ne02 * ne03);
    for (size_t i = 0; i < adata.size(); i++) adata[i] = 1;

    std::vector<uint16_t> hadata(ne00 * ne01 * ne02 * ne03);
    ggml_fp32_to_fp16_row(adata.data(), hadata.data(), adata.size());

    std::vector<float> bdata(ne10 * ne11 * ne12 * ne13);
    for (size_t i = 0; i < bdata.size(); i++) bdata[i] = 1;

    struct ggml_init_params params_ctx = {
        .mem_size = 200 * 1024 * 1024,
        .mem_buffer = NULL,
        .no_alloc = false,
    };

    struct ggml_context* ctx = ggml_init(params_ctx);

    struct ggml_tensor* a = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, ne00, ne01, ne02, ne03);
    memcpy(a->data, adata.data(), adata.size() * sizeof(adata[0]));
    struct ggml_tensor* ha = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, ne00, ne01, ne02, ne03);
    memcpy(ha->data, hadata.data(), hadata.size() * sizeof(hadata[0]));

    struct ggml_tensor* b = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, ne10, ne11, ne12, ne13);
    memcpy(b->data, bdata.data(), bdata.size() * sizeof(bdata[0]));

    struct ggml_tensor* r1 = ggml_conv_2d_stage_0(ctx, ha, b, s0, s1, p0, p1, d0, d1); // [N, OH, OW, IC * KH * KW]
    struct ggml_tensor* ref1 = ggml_conv_2d_stage_1(ctx, ha, r1);

    {
        struct ggml_compute_params params = {
            .type = GGML_TASK_INIT,
            .ith = 0,
            .nth = 1,
            .wsize = 0,
            .wdata = NULL,
        };

        ggml_compute_forward_conv_2d_stage_0(&params, ha, b, r1);

        params.type = GGML_TASK_COMPUTE;
        #pragma omp parallel
        {
            struct ggml_compute_params  params2 = params;
            params2.nth = omp_get_num_threads();
            params2.ith = omp_get_thread_num();
            ggml_compute_forward_conv_2d_stage_0(&params2, ha, b, r1);
        }

        #pragma omp parallel
        {
            struct ggml_compute_params  params2 = params;
            params2.nth = omp_get_num_threads();
            params2.ith = omp_get_thread_num();
            ggml_compute_forward_conv_2d_stage_1(&params2, ha, r1, ref1);
        }
    }

    struct ggml_tensor* ref2 = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, ref1->ne[0], ref1->ne[1], ref1->ne[2], ref1->ne[3]);
    ref2->op_params[0] = ref2->op_params[1] = ref2->op_params[2] = ref2->op_params[3] = ref2->op_params[4] = ref2->op_params[5] = 1;
    void *buf = malloc(200*1024*1024);

    {
        struct ggml_compute_params params = {
            .type = GGML_TASK_INIT,
            .ith = 0,
            .nth = 1,
            .wsize = 200*1024*1024,
            .wdata = buf,
        };

        ggml_compute_forward_conv_2d_f16_f32(&params, ha, b, ref2);

        params.type = GGML_TASK_COMPUTE;
        #pragma omp parallel
        {
            struct ggml_compute_params params2 = params;
            params2.nth = omp_get_num_threads();
            params2.ith = omp_get_thread_num();
            
            ggml_compute_forward_conv_2d_f16_f32(&params2, ha, b, ref2);
        }
    }
    const float *dref1 = (float *)(ref1->data);
    printf("new: %f %f %f %f %f %f %f\n", dref1[0], dref1[1], dref1[2], dref1[3], dref1[4], dref1[5], dref1[6]);

    const float *dref2 = (float *)(ref2->data);
    printf("old: %f %f %f %f %f %f %f\n", dref2[0], dref2[1], dref2[2], dref2[3], dref2[4], dref2[5], dref2[6]);
new: 2560.000000 3840.000000 3840.000000 3840.000000 3840.000000 3840.000000 3840.000000
old: 2560.000000 3840.000000 3840.000000 3840.000000 3840.000000 3840.000000 3840.000000
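
An equivalent, shorter fix (just a sketch, assuming the same OpenMP test harness as above) is to let OpenMP make the per-thread copy via firstprivate, e.g. for the stage_0 compute pass:

```cpp
        params.type = GGML_TASK_COMPUTE;
        #pragma omp parallel firstprivate(params)  // each thread gets its own copy of params
        {
            params.nth = omp_get_num_threads();
            params.ith = omp_get_thread_num();
            ggml_compute_forward_conv_2d_stage_0(&params, ha, b, r1);
        }
```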

AngryLoki commented

Ouch, my mistake... Yes, the new version is correct.

LeonNerd commented

I have a question. I see in stable-diffusion.cpp that it doesn't run on GPUs. Why does ggml_conv_2d not support the GPU, and what's holding it back? Convolution seems like a pretty general operator. I hope someone can help me. Thanks.

ggerganov (Owner) commented

We will eventually support GPU implementations of the convolution operators, but it's not the focus yet. Contributions are welcome.

Sorry for delaying this merge for so long. After merging #547, I will focus on bringing this up to date and merging it.

ggerganov self-assigned this on Oct 6, 2023
FSSRepo (Collaborator) commented on Oct 7, 2023

> We will eventually support GPU implementations of the convolution operators, but it's not the focus yet. Contributions are welcome.

I need some feedback to finish the implementation in #556

ggerganov (Owner) commented

I would like to merge this.

@PABannier It would be nice if you could test that this branch works for your cases, as it touches conv_2d and the change is not trivial. My SAM example still works, but I want to make sure nothing else broke.

PABannier (Contributor) commented

Hey @ggerganov! I used the same two-step implementation for the ggml_conv_1d operation. As discussed in #523, this works well and all the examples pass with this implementation.

Yet, since neither encodec nor bark uses 2D convolution, I can't test this specific implementation of 2D convolution.

ggerganov merged commit 6549d12 into ggerganov:master on Oct 9, 2023. 4 checks passed.
rayrayraykk commented

> I am thinking how to avoid the 2-step process. Do you need it just because the INIT step is not multi-threaded?
> If it becomes multi-threaded, would you still need 2 stages?
>
> If the INIT step is multithreaded, then indeed there might be no need to split it into two stages.
>
> Btw, I'm not 100% sure that the strategy of making the Conv2D into a matrix multiplication is the best thing to do.
>
> Im2col is a common practice in many frameworks like Caffe, PyTorch, and others. However, these frameworks often implement other optimization techniques such as Winograd, where they choose different calculation methods based on the input format.

I would like to ask a question: is it possible to be faster with winograd convolution?

leejet commented on Nov 10, 2023

> I would like to ask a question: is it possible to be faster with winograd convolution?

It might be faster, but I'm not sure; I haven't tried it.

FSSRepo (Collaborator) commented on Nov 18, 2023

Im2col is quite fast on CUDA, but it consumes a lot of memory. I'm thinking of adding an option to split the input into tiles and process them sequentially to save memory when the input is very large. I would then concatenate each result of the matrix multiplication. However, it's somewhat challenging to perform data shifting in ggml_mul_mat.
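
As a rough sketch of the tiling idea (hypothetical names, plain C++; single channel, stride 1, no padding, tiling over output rows), the patch matrix then only needs to hold one tile at a time:

```cpp
#include <algorithm>
#include <vector>

// im2col + matmul one tile of output rows at a time, so the temporary patch matrix
// is bounded by tile_rows*OW*KH*KW instead of OH*OW*KH*KW
void conv2d_tiled(const float * src, int IW, int IH,
                  const float * ker, int KW, int KH,
                  float * dst, int tile_rows) {
    const int OW = IW - KW + 1;
    const int OH = IH - KH + 1;

    std::vector<float> cols((size_t) tile_rows * OW * KH * KW);

    for (int oy0 = 0; oy0 < OH; oy0 += tile_rows) {
        const int rows = std::min(tile_rows, OH - oy0);

        // im2col for this tile only
        for (int oy = 0; oy < rows; oy++) {
            for (int ox = 0; ox < OW; ox++) {
                float * col = cols.data() + ((size_t) oy * OW + ox) * KH * KW;
                for (int ky = 0; ky < KH; ky++) {
                    for (int kx = 0; kx < KW; kx++) {
                        col[ky * KW + kx] = src[(oy0 + oy + ky) * IW + (ox + kx)];
                    }
                }
            }
        }

        // "GEMM" for this tile (one output channel, so a matrix-vector product),
        // written directly into the corresponding slice of the output
        for (int i = 0; i < rows * OW; i++) {
            float sum = 0.0f;
            for (int k = 0; k < KH * KW; k++) {
                sum += cols[(size_t) i * KH * KW + k] * ker[k];
            }
            dst[(size_t) oy0 * OW + i] = sum;
        }
    }
}
```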

Some useful info:
Winograd precomputes a kernel transform to remove redundant calculations from the matrix multiplication, giving more speed with fewer multiplications.

Winograd CUDA

(screenshot attached)

I'm trying this:

def winograd_1d(data, filter):
    N = len(data)
    b1 = filter[0] + filter[2]
    b2 = 0.5 * (b1 + filter[1])
    b3 = 0.5 * (b1 - filter[1])

    output = [0] * N
    print(f"N {N}")
    for i in range(0, N - 1, 2):
        a1 = (data[i] + data[i + 1]) * b2
        a2 = (data[i + 1] - data[i]) * b3
        output[i] = (data[i] - data[i + 1]) * filter[0] + a1 + 2
        output[i + 1] = a1 - a2 - (data[i] - data[i + 2]) * filter[2] # Error: data + 2 index out of range

    if N % 2 != 0:
        a1 = (data[N - 1] + data[N]) * b2
        a2 = (data[N] - data[N - 1]) * b3
        output[N - 1] = (data[N - 2] - data[N]) * filter[0] + a1 + a2

    return output

# Example usage:
data = [1, 2, 3, 4, 5, 6, 7, 8]
filter = [0.1, 0.2, 0.3]

result = winograd_1d(data, filter)
print(result)
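
For reference, a corrected sketch of the same F(2,3) scheme (illustrative C++, hypothetical helper names): each step consumes 4 inputs and produces 2 outputs of a valid 3-tap convolution using 4 multiplications instead of 6.

```cpp
#include <cstdio>
#include <vector>

static std::vector<float> winograd_1d_f23(const std::vector<float> & d, const float g[3]) {
    const size_t n_out = d.size() - 2; // valid convolution length
    std::vector<float> y(n_out);

    // filter transform, computed once
    const float b0 = g[0];
    const float b1 = 0.5f * (g[0] + g[1] + g[2]);
    const float b2 = 0.5f * (g[0] - g[1] + g[2]);
    const float b3 = g[2];

    size_t i = 0;
    for (; i + 1 < n_out; i += 2) {
        const float m0 = (d[i]     - d[i + 2]) * b0;
        const float m1 = (d[i + 1] + d[i + 2]) * b1;
        const float m2 = (d[i + 2] - d[i + 1]) * b2;
        const float m3 = (d[i + 1] - d[i + 3]) * b3;

        y[i]     = m0 + m1 + m2; // == d[i]*g0 + d[i+1]*g1 + d[i+2]*g2
        y[i + 1] = m1 - m2 - m3; // == d[i+1]*g0 + d[i+2]*g1 + d[i+3]*g2
    }

    // odd number of outputs: compute the last one directly
    if (i < n_out) {
        y[i] = d[i]*g[0] + d[i + 1]*g[1] + d[i + 2]*g[2];
    }
    return y;
}

int main() {
    const std::vector<float> data = {1, 2, 3, 4, 5, 6, 7, 8};
    const float filt[3] = {0.1f, 0.2f, 0.3f};
    for (float v : winograd_1d_f23(data, filt)) {
        printf("%f ", v);
    }
    printf("\n");
}
```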
