
samuelklam/strassens-algorithm


Strassen's Algorithm

We implement Strassen's algorithm, which improves on the $\Theta(n^3)$ running time of standard matrix multiplication with a running time of $\Theta(n^{\log_2 7}) \approx \Theta(n^{2.81})$. For sufficiently large n, Strassen's algorithm runs faster than the conventional algorithm. For small n, however, the conventional algorithm may be faster, because Strassen's algorithm performs many extra matrix additions and subtractions. We define the cross-over point as the value of n at which we stop recursing with Strassen's algorithm and switch to conventional matrix multiplication. We seek to determine this cross-over point both analytically and experimentally.
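To make the analytic side concrete, here is a minimal C++ sketch under an assumed cost model (our assumption, not necessarily the model used in the write-up): every scalar addition, subtraction, or multiplication costs one unit, and memory traffic is ignored. It compares conventional multiplication with a single Strassen step followed by conventional multiplication of the seven half-size products. All function names are hypothetical.

```cpp
#include <cstdint>

// Assumed cost model (not from the repository): each scalar addition,
// subtraction, or multiplication costs one unit; memory traffic is ignored.

// Conventional n x n product: n^2 entries, each needing n multiplications
// and n - 1 additions.
std::int64_t conventional_ops(std::int64_t n) { return n * n * (2 * n - 1); }

// One Strassen step on even n: seven conventional (n/2) x (n/2) products,
// plus 18 additions/subtractions of (n/2) x (n/2) matrices
// (10 to form the operands of the seven products, 8 to combine them into C).
std::int64_t one_step_strassen_ops(std::int64_t n) {
    std::int64_t m = n / 2;
    return 7 * conventional_ops(m) + 18 * m * m;
}

// Smallest even n <= limit for which one Strassen step is cheaper than
// conventional multiplication under this model; returns -1 if there is none.
std::int64_t analytic_crossover(std::int64_t limit) {
    for (std::int64_t n = 2; n <= limit; n += 2)
        if (one_step_strassen_ops(n) < conventional_ops(n)) return n;
    return -1;
}
```

Under this model, `analytic_crossover(1000)` returns 16: with m = n/2, the inequality $14m^3 + 11m^2 < 16m^3 - 4m^2$ holds exactly when m > 7.5. Odd n (which needs padding or another split) is not modeled, and an experimentally measured cross-over can differ substantially, since allocation, copying, and cache effects are ignored.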

We define Strassen's algorithm as follows. Given two n × n matrices A and B, we split each into four (n/2) × (n/2) sub-matrices. We denote by $A_{1,1}, A_{1,2}, A_{2,1}, A_{2,2}$ the sub-matrices of A and by $B_{1,1}, B_{1,2}, B_{2,1}, B_{2,2}$ the sub-matrices of B, where $A_{1,2}$ denotes the top-right sub-matrix of A. We compute the following seven products:

  • $M_1 = A_{1,1}(B_{1,2} - B_{2,2})$
  • $M_2 = (A_{1,1} + A_{1,2})B_{2,2}$
  • $M_3 = (A_{2,1} + A_{2,2})B_{1,1}$
  • $M_4 = A_{2,2}(B_{2,1} - B_{1,1})$
  • $M_5 = (A_{1,1} + A_{2,2})(B_{1,1} + B_{2,2})$
  • $M_6 = (A_{1,2} - A_{2,2})(B_{2,1} + B_{2,2})$
  • $M_7 = (A_{1,1} - A_{2,1})(B_{1,1} + B_{1,2})$

The quadrants of the product C = AB then follow from additions and subtractions of these seven products, where $C_{1,2}$ denotes the top-right quadrant of C:

  • $C_{1,1} = M_5 + M_4 - M_2 + M_6$
  • $C_{1,2} = M_1 + M_2$
  • $C_{2,1} = M_3 + M_4$
  • $C_{2,2} = M_5 + M_1 - M_3 - M_7$
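The recursion, including the switch to conventional multiplication below a cross-over size, can be sketched in C++ as follows. This is a minimal, unoptimized sketch rather than the code in the repository's main.cpp: the `Matrix` alias, the helper names, and the `cutoff` parameter are hypothetical, and instead of padding it simply falls back to conventional multiplication whenever the current size is odd.

```cpp
#include <cstddef>
#include <vector>

// Hypothetical matrix type; assumes square matrices of equal size.
using Matrix = std::vector<std::vector<long long>>;

Matrix zeros(std::size_t n) { return Matrix(n, std::vector<long long>(n, 0)); }

// Conventional Theta(n^3) multiplication.
Matrix conventional(const Matrix& A, const Matrix& B) {
    std::size_t n = A.size();
    Matrix C = zeros(n);
    for (std::size_t i = 0; i < n; ++i)
        for (std::size_t k = 0; k < n; ++k)
            for (std::size_t j = 0; j < n; ++j)
                C[i][j] += A[i][k] * B[k][j];
    return C;
}

// Entry-wise X + sign * Y (sign = -1 gives subtraction).
Matrix add(const Matrix& X, const Matrix& Y, long long sign = 1) {
    std::size_t n = X.size();
    Matrix Z = zeros(n);
    for (std::size_t i = 0; i < n; ++i)
        for (std::size_t j = 0; j < n; ++j)
            Z[i][j] = X[i][j] + sign * Y[i][j];
    return Z;
}

// Copies the h x h quadrant of M whose top-left corner is (r, c).
Matrix quadrant(const Matrix& M, std::size_t r, std::size_t c, std::size_t h) {
    Matrix Q = zeros(h);
    for (std::size_t i = 0; i < h; ++i)
        for (std::size_t j = 0; j < h; ++j)
            Q[i][j] = M[r + i][c + j];
    return Q;
}

// Strassen's algorithm with a cross-over: at sizes <= cutoff, or at odd
// sizes, it switches to conventional multiplication.
Matrix strassen(const Matrix& A, const Matrix& B, std::size_t cutoff) {
    std::size_t n = A.size();
    if (n <= cutoff || n % 2 != 0) return conventional(A, B);
    std::size_t h = n / 2;
    Matrix A11 = quadrant(A, 0, 0, h), A12 = quadrant(A, 0, h, h),
           A21 = quadrant(A, h, 0, h), A22 = quadrant(A, h, h, h);
    Matrix B11 = quadrant(B, 0, 0, h), B12 = quadrant(B, 0, h, h),
           B21 = quadrant(B, h, 0, h), B22 = quadrant(B, h, h, h);

    // The seven products M1..M7, each computed recursively.
    Matrix M1 = strassen(A11, add(B12, B22, -1), cutoff);
    Matrix M2 = strassen(add(A11, A12), B22, cutoff);
    Matrix M3 = strassen(add(A21, A22), B11, cutoff);
    Matrix M4 = strassen(A22, add(B21, B11, -1), cutoff);
    Matrix M5 = strassen(add(A11, A22), add(B11, B22), cutoff);
    Matrix M6 = strassen(add(A12, A22, -1), add(B21, B22), cutoff);
    Matrix M7 = strassen(add(A11, A21, -1), add(B11, B12), cutoff);

    // Combine into the four quadrants of C.
    Matrix C11 = add(add(M5, M4), add(M6, M2, -1));   // M5 + M4 - M2 + M6
    Matrix C12 = add(M1, M2);                         // M1 + M2
    Matrix C21 = add(M3, M4);                         // M3 + M4
    Matrix C22 = add(add(M5, M1), add(M3, M7), -1);   // M5 + M1 - M3 - M7

    Matrix C = zeros(n);
    for (std::size_t i = 0; i < h; ++i)
        for (std::size_t j = 0; j < h; ++j) {
            C[i][j] = C11[i][j];
            C[i][j + h] = C12[i][j];
            C[i + h][j] = C21[i][j];
            C[i + h][j + h] = C22[i][j];
        }
    return C;
}
```

For example, `strassen(A, B, 16)` on 64 × 64 inputs recurses down to 16 × 16 blocks and multiplies those conventionally. Copying each quadrant keeps the sketch short at the cost of extra allocation; a tuned implementation would more likely operate on index ranges within preallocated buffers.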

For a full write-up refer to: strassens-writeup.pdf.

Usage

From the command line:

$ make
$ ./strassen 0 dimension inputfile

The first argument, the flag 0, can be used for testing, debugging, or extensions. The dimension d is the dimension of the matrices being multiplied; e.g. 32 denotes multiplying two 32 × 32 matrices together. The inputfile is an ASCII file with $2d^2$ integers, one per line, representing two matrices A and B. The first integer is matrix entry $a_{0,0}$, followed by $a_{0,1}, a_{0,2}, \ldots, a_{0,d-1}$; next comes $a_{1,0}, a_{1,1}$, and so on, for the first $d^2$ numbers. The next $d^2$ numbers give matrix B in the same row-major order.
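Reading this input format can be sketched as follows. `read_matrices` is a hypothetical helper, not a function from the repository; it assumes every entry fits in a `long long` and relies on `>>` skipping the newlines between numbers.

```cpp
#include <cstddef>
#include <istream>
#include <sstream>   // included so the reader can be exercised with std::istringstream
#include <stdexcept>
#include <utility>
#include <vector>

using Matrix = std::vector<std::vector<long long>>;

// Hypothetical reader: parses 2*d*d integers in row-major order,
// the first d*d into A and the next d*d into B.
std::pair<Matrix, Matrix> read_matrices(std::istream& in, std::size_t d) {
    Matrix A(d, std::vector<long long>(d)), B(d, std::vector<long long>(d));
    for (Matrix* M : {&A, &B})
        for (std::size_t i = 0; i < d; ++i)
            for (std::size_t j = 0; j < d; ++j)
                if (!(in >> (*M)[i][j]))
                    throw std::runtime_error("input has fewer than 2*d*d integers");
    return {A, B};
}
```

With the invocation `./strassen 0 dimension inputfile`, the dimension is `argv[2]` and the file name is `argv[3]`, so a call might look like `std::ifstream file(argv[3]); auto [A, B] = read_matrices(file, std::stoul(argv[2]));` (C++17). A `std::istringstream` works the same way in tests.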

The output is a list of the diagonal entries of C = AB, i.e. $c_{0,0}, c_{1,1}, \ldots, c_{d-1,d-1}$, one per line, followed by a trailing newline.
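Producing this output from a computed product C takes only a short loop. `diagonal_output` is a hypothetical helper that returns the text, so it can be checked in tests or written directly with `std::cout << diagonal_output(C);`.

```cpp
#include <cstddef>
#include <sstream>
#include <string>
#include <vector>

using Matrix = std::vector<std::vector<long long>>;

// Returns c_{0,0}, c_{1,1}, ..., c_{d-1,d-1}, one per line; the '\n' written
// after the last entry supplies the required trailing newline.
std::string diagonal_output(const Matrix& C) {
    std::ostringstream out;
    for (std::size_t i = 0; i < C.size(); ++i)
        out << C[i][i] << '\n';
    return out.str();
}
```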

There are four Strassen-variant algorithms. To run a particular variant, set strassen_algo_flag in main.cpp to the corresponding value from 1 to 4.