Skip to content

Commit

Permalink
test++
Browse files Browse the repository at this point in the history
  • Loading branch information
nihui committed Oct 10, 2024
1 parent f5a828b commit 5e4fcfa
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 32 deletions.
12 changes: 6 additions & 6 deletions src/layer/vulkan/gemm_vulkan.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -223,33 +223,33 @@ int Gemm_vulkan::forward(const std::vector<VkMat>& bottom_blobs, std::vector<VkM
{
vkdev->convert_packing(C0, C, 1, cmd, opt);

if (C0.dims == 1 && C0.w == 1)
if (C.dims == 1 && C.w == 1)
{
// scalar
broadcast_type_C = 0;
}
if (C0.dims == 1 && C0.w == M)
if (C.dims == 1 && C.w == M)
{
// M
// auto broadcast from h to w is the ncnn-style convention
broadcast_type_C = 1;
}
if (C0.dims == 1 && C0.w == N)
if (C.dims == 1 && C.w == N)
{
// N
broadcast_type_C = 4;
}
if (C0.dims == 2 && C0.w == 1 && C0.h == M)
if (C.dims == 2 && C.w == 1 && C.h == M)
{
// Mx1
broadcast_type_C = 2;
}
if (C0.dims == 2 && C0.w == N && C0.h == M)
if (C.dims == 2 && C.w == N && C.h == M)
{
// MxN
broadcast_type_C = 3;
}
if (C0.dims == 2 && C0.w == N && C0.h == 1)
if (C.dims == 2 && C.w == N && C.h == 1)
{
// 1xN
broadcast_type_C = 4;
Expand Down
60 changes: 34 additions & 26 deletions tests/test_gemm_3.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -181,8 +181,12 @@ static int test_gemm_1(int M, int N, int K)
|| test_gemm_int8_bias(M, N, K, RandomMat(1, M), 4.1f, 0.7f, 1, 0, 1, 1, 0, 0, 0)
|| test_gemm_int8_bias(M, N, K, RandomMat(N, M), 5.1f, -0.8f, 1, 1, 2, 0, 0, 0, 0)
|| test_gemm_int8_bias(M, N, K, RandomMat(N, M), 5.1f, -0.8f, 1, 1, 3, 1, 0, 0, 0)
|| test_gemm_int8_bias(M, N, K, RandomMat(N, 1), 2.1f, -0.5f, 0, 0, 0, 0, 0, 0, 0)
|| test_gemm_int8_bias(M, N, K, RandomMat(N, 1), 2.1f, -0.5f, 0, 0, 1, 1, 0, 0, 0)
|| test_gemm_int8_bias(M, N, K, RandomMat(N, M), 1.f, 1.f, 1, 1, 0, 0, 0, 0, 0)
|| test_gemm_int8_bias(M, N, K, RandomMat(N, M), 1.f, 1.f, 1, 1, 1, 1, 0, 0, 0)
|| test_gemm_int8_bias(M, N, K, RandomMat(N, 1), 2.1f, -0.5f, 0, 0, 2, 0, 0, 0, 0)
|| test_gemm_int8_bias(M, N, K, RandomMat(N, 1), 2.1f, -0.5f, 0, 0, 3, 1, 0, 0, 0)
|| test_gemm_int8_bias(M, N, K, RandomMat(N, 1), 0.8f, 1.f, 0, 0, 0, 0, 0, 0, 0)
|| test_gemm_int8_bias(M, N, K, RandomMat(N), 0.8f, 1.f, 0, 0, 1, 1, 0, 0, 0)
|| test_gemm_int8_bias(M, N, K, RandomMat(N), 3.1f, -0.6f, 0, 1, 2, 0, 0, 0, 0)
|| test_gemm_int8_bias(M, N, K, RandomMat(N), 3.1f, -0.6f, 0, 1, 3, 1, 0, 0, 0)

Expand All @@ -194,8 +198,12 @@ static int test_gemm_1(int M, int N, int K)
|| test_gemm_int8_bias(M, N, K, RandomMat(1, M), -4.1f, 0.7f, 1, 0, 1, 1, 1, 1, 1)
|| test_gemm_int8_bias(M, N, K, RandomMat(N, M), -5.1f, -0.8f, 1, 1, 2, 0, 1, 1, 1)
|| test_gemm_int8_bias(M, N, K, RandomMat(N, M), -5.1f, -0.8f, 1, 1, 3, 1, 1, 1, 1)
|| test_gemm_int8_bias(M, N, K, RandomMat(N, 1), -2.1f, -0.5f, 0, 0, 0, 0, 1, 1, 1)
|| test_gemm_int8_bias(M, N, K, RandomMat(N, 1), -2.1f, -0.5f, 0, 0, 1, 1, 1, 1, 1)
|| test_gemm_int8_bias(M, N, K, RandomMat(N, M), 1.f, 1.f, 1, 1, 0, 0, 1, 1, 1)
|| test_gemm_int8_bias(M, N, K, RandomMat(N, M), 1.f, 1.f, 1, 1, 1, 1, 1, 1, 1)
|| test_gemm_int8_bias(M, N, K, RandomMat(N, 1), -2.1f, -0.5f, 0, 0, 2, 0, 1, 1, 1)
|| test_gemm_int8_bias(M, N, K, RandomMat(N, 1), -2.1f, -0.5f, 0, 0, 3, 1, 1, 1, 1)
|| test_gemm_int8_bias(M, N, K, RandomMat(N, 1), 0.8f, 1.f, 0, 0, 0, 0, 1, 1, 1)
|| test_gemm_int8_bias(M, N, K, RandomMat(N), 0.8f, 1.f, 0, 0, 1, 1, 1, 1, 1)
|| test_gemm_int8_bias(M, N, K, RandomMat(N), -3.1f, -0.6f, 0, 1, 2, 0, 1, 1, 1)
|| test_gemm_int8_bias(M, N, K, RandomMat(N), -3.1f, -0.6f, 0, 1, 3, 1, 1, 1, 1);
}
Expand All @@ -208,40 +216,40 @@ int main()
#if NCNN_INT8
int mnk[][3] = {
{1, 1, 1},
{1, 1, 23},
{1, 1, 47},
{1, 23, 1},
{1, 23, 23},
{1, 31, 1},
{1, 35, 1},
{1, 35, 47},
{1, 47, 1},
{2, 2, 2},
{3, 3, 3},
{4, 4, 4},
{5, 5, 5},
{6, 6, 6},
{7, 7, 7},
{7, 31, 3},
{8, 8, 8},
{15, 15, 15},
{16, 16, 16},
{31, 31, 31},
{40, 40, 40},
{1, 1, 23},
{1, 31, 1},
{23, 1, 1},
{12, 12, 23},
{12, 23, 12},
{12, 31, 12},
{23, 12, 12},
{1, 1, 47},
{1, 35, 1},
{47, 1, 1},
{24, 24, 47},
{24, 35, 24},
{47, 24, 24},
{1, 35, 47},
{15, 15, 15},
{16, 16, 16},
{19, 44, 7},
{20, 28, 7},
{23, 31, 1},
{23, 1, 23},
{23, 31, 23},
{31, 7, 3},
{28, 20, 7},
{24, 24, 47},
{24, 35, 24},
{24, 47, 24},
{31, 31, 31},
{32, 32, 9},
{44, 19, 7},
{47, 35, 48},
{47, 48, 47},
{48, 35, 47}
{35, 47, 48},
{35, 48, 47},
{40, 40, 40},
{47, 48, 47}
};

int mnk_count = sizeof(mnk) / sizeof(int) / 3;
Expand Down

0 comments on commit 5e4fcfa

Please sign in to comment.