Skip to content

Commit 37a0835

Browse files
Math: Optimise 16-bit matrix multiplication function
- Switch to int32_t in accumulators to reduce memory and improve speed in mat_multiply and mat_multiply_elementwise functions. - Refactor pointer incrementation for code clarity and optimizations, enhancing maintainability and execution speed. - Handle Q0 data directly in core logic, removing redundant conditionals for faster execution in common scenarios. - Revise fractional bit shift logic for better accuracy in fixed-point calculations. This consolidation achieves notable performance boosts and memory savings. Signed-off-by: Shriram Shastry <malladi.sastry@intel.com>
1 parent 07b762e commit 37a0835

File tree

1 file changed

+73
-61
lines changed

1 file changed

+73
-61
lines changed

src/math/matrix.c

Lines changed: 73 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -3,93 +3,105 @@
33
// Copyright(c) 2022 Intel Corporation. All rights reserved.
44
//
55
// Author: Seppo Ingalsuo <seppo.ingalsuo@linux.intel.com>
6+
// Shriram Shastry <malladi.sastry@linux.intel.com>
7+
//
68

79
#include <sof/math/matrix.h>
810
#include <errno.h>
911
#include <stdint.h>
1012

11-
int mat_multiply(struct mat_matrix_16b *a, struct mat_matrix_16b *b, struct mat_matrix_16b *c)
13+
/* Performs matrix multiplication of two fixed-point 16-bit integer matrices,
14+
* storing the result in a third matrix. It accounts for fractional bits for
15+
* fixed-point arithmetic, adjusting the result accordingly.
16+
*
17+
* Arguments:
18+
* a: pointer to the first input matrix
19+
* b: pointer to the second input matrix
20+
* c: pointer to the output matrix to store result
21+
*
22+
* Return:
23+
* 0 on successful multiplication.
24+
* -EINVAL if input dimensions do not allow for multiplication.
25+
* -ERANGE if the shift operation might cause integer overflow.
26+
*/
27+
int mat_multiply(struct mat_matrix_16b *a, struct mat_matrix_16b *b,
28+
struct mat_matrix_16b *c)
1229
{
13-
int64_t s;
14-
int16_t *x;
15-
int16_t *y;
16-
int16_t *z = c->data;
17-
int i, j, k;
18-
int y_inc = b->columns;
19-
const int shift_minus_one = a->fractions + b->fractions - c->fractions - 1;
30+
int32_t acc; /* Accumulator for dot product calculation */
31+
int16_t *x, *y, *z = c->data; /* Pointers for matrices a, b, and c */
32+
int i, j, k; /* Loop counters */
33+
int y_inc = b->columns; /* Column increment for matrix b elements */
34+
/* Calculate shift amount for adjusting fractional bits in the result */
35+
const int shift = a->fractions + b->fractions - c->fractions - 1;
2036

37+
/* Validate matrix dimensions are compatible for multiplication */
2138
if (a->columns != b->rows || a->rows != c->rows || b->columns != c->columns)
2239
return -EINVAL;
2340

24-
/* If all data is Q0 */
25-
if (shift_minus_one == -1) {
26-
for (i = 0; i < a->rows; i++) {
27-
for (j = 0; j < b->columns; j++) {
28-
s = 0;
29-
x = a->data + a->columns * i;
30-
y = b->data + j;
31-
for (k = 0; k < b->rows; k++) {
32-
s += (int32_t)(*x) * (*y);
33-
x++;
34-
y += y_inc;
35-
}
36-
*z = (int16_t)s; /* For Q16.0 */
37-
z++;
38-
}
39-
}
40-
41-
return 0;
42-
}
41+
/* Check shift to ensure no integer overflow occurs during shifting */
42+
if (shift < -1 || shift > 31)
43+
return -ERANGE;
4344

45+
/* Matrix multiplication loop */
4446
for (i = 0; i < a->rows; i++) {
4547
for (j = 0; j < b->columns; j++) {
46-
s = 0;
47-
x = a->data + a->columns * i;
48-
y = b->data + j;
48+
acc = 0; /* Initialize accumulator for each element */
49+
x = a->data + a->columns * i; /* Set x at the start of ith row of a */
50+
y = b->data + j; /* Set y at the top of jth column of b */
51+
/* Dot product loop */
4952
for (k = 0; k < b->rows; k++) {
50-
s += (int32_t)(*x) * (*y);
51-
x++;
52-
y += y_inc;
53+
acc += (int32_t)(*x++) * (*y); /* Multiply & accumulate */
54+
y += y_inc; /* Move to next row in the current column of b */
5355
}
54-
*z = (int16_t)(((s >> shift_minus_one) + 1) >> 1); /*Shift to Qx.y */
55-
z++;
56+
/* Assign computed value to c matrix, adjusting for fractional bits */
57+
if (shift == -1)
58+
*z = (int16_t)acc;
59+
else
60+
*z = (int16_t)(((acc >> shift) + 1) >> 1);
61+
z++; /* Move to the next element in the output matrix */
5662
}
5763
}
5864
return 0;
5965
}
6066

67+
/* Description: Performs element-wise multiplication of 16-bit int matrices
68+
* (a and b), storing the result in a third matrix (c). Assumes
69+
* all matrices are of the same dimensions and adjusts fractional
70+
* bits appropriately.
71+
* Arguments:
72+
* a - pointer to the first input matrix
73+
* b - pointer to the second input matrix
74+
* c - pointer to the output matrix to store result
75+
* Returns:
76+
* 0 on successful element-wise multiplication
77+
* -EINVAL if pointers are null or dimensions do not match
78+
*/
6179
int mat_multiply_elementwise(struct mat_matrix_16b *a, struct mat_matrix_16b *b,
6280
struct mat_matrix_16b *c)
63-
{ int64_t p;
64-
int16_t *x = a->data;
65-
int16_t *y = b->data;
66-
int16_t *z = c->data;
67-
int i;
68-
const int shift_minus_one = a->fractions + b->fractions - c->fractions - 1;
81+
{
82+
int16_t *x = a->data; /* Pointer to data of matrix a */
83+
int16_t *y = b->data; /* Pointer to data of matrix b */
84+
int16_t *z = c->data; /* Pointer to data of matrix c */
85+
int32_t prod; /* Product of element-wise multiplication */
6986

70-
if (a->columns != b->columns || b->columns != c->columns ||
71-
a->rows != b->rows || b->rows != c->rows) {
87+
/* Validate pointers and matrices dimensions */
88+
if (!a || !b || !c || a->columns != b->columns || a->rows != b->rows)
7289
return -EINVAL;
73-
}
7490

75-
/* If all data is Q0 */
76-
if (shift_minus_one == -1) {
77-
for (i = 0; i < a->rows * a->columns; i++) {
78-
*z = *x * *y;
79-
x++;
80-
y++;
81-
z++;
82-
}
91+
const int total_elements = a->rows * a->columns; /* Total elements count */
92+
const int shift = a->fractions + b->fractions - c->fractions - 1;
8393

84-
return 0;
85-
}
86-
87-
for (i = 0; i < a->rows * a->columns; i++) {
88-
p = (int32_t)(*x) * *y;
89-
*z = (int16_t)(((p >> shift_minus_one) + 1) >> 1); /*Shift to Qx.y */
90-
x++;
91-
y++;
92-
z++;
94+
/* Perform element-wise multiplication */
95+
if (shift == -1) {
96+
/* No adjustment needed for fractional bits */
97+
for (int i = 0; i < total_elements; i++)
98+
z[i] = x[i] * y[i]; /* Direct multiplication */
99+
} else {
100+
/* Adjustment needed for fractional bits */
101+
for (int i = 0; i < total_elements; i++) {
102+
prod = (int32_t)x[i] * y[i]; /* Multiply with extended precision */
103+
z[i] = (int16_t)(((prod >> shift) + 1) >> 1);
104+
}
93105
}
94106

95107
return 0;

0 commit comments

Comments
 (0)