Skip to content

Commit

Permalink
feat: add fast multiplication algorithm (strassen)
Browse files Browse the repository at this point in the history
* add an implementation of matrix product in the benchmark (strassen's algorithm)
* add the function mmul_strassen to the class abstractMatrix.
* Modification of the benchmark of mmul : use integer instead of float between 0 and 1.
* Add a test of mmul_strassen
  • Loading branch information
jajoe authored and targos committed Aug 23, 2016
1 parent 3de8a15 commit fdc1c07
Show file tree
Hide file tree
Showing 3 changed files with 238 additions and 26 deletions.
129 changes: 104 additions & 25 deletions benchmark/mmul.js
Original file line number Diff line number Diff line change
Expand Up @@ -2,47 +2,126 @@

var x = parseInt(process.argv[2]) || 5;
var y = parseInt(process.argv[3]) || x;
console.log(`mmul operations benchmark for ${x}x${y} matrix`);
// console.log(`mmul operations benchmark for ${x}x${y} matrix`);

var Benchmark = require('benchmark');
var suite = new Benchmark.Suite;

var Matrix = require('../src/index');

Matrix.prototype.mmul2 = function (other) {
other = Matrix.checkMatrix(other);
if (this.columns !== other.rows)
console.warn('Number of columns of left matrix are not equal to number of rows of right matrix.');

var m = this.rows;
var n = this.columns;
var p = other.columns;

var result = Matrix.zeros(m, p);
for (var i = 0; i < m; i++) {
for (var k = 0; k < n; k++) {
for (var j = 0; j < p; j++) {
result[i][j] += this[i][k] * other[k][j];
}
}
function strassen_2x2(a,b){
var a11 = a.get(0,0);
var b11 = b.get(0,0);
var a12 = a.get(0,1);
var b12 = b.get(0,1);
var a21 = a.get(1,0);
var b21 = b.get(1,0);
var a22 = a.get(1,1);
var b22 = b.get(1,1);

// Compute intermediate values.
var m1 = (a11+a22)*(b11+b22);
var m2 = (a21+a22)*b11;
var m3 = a11*(b12-b22);
var m4 = a22*(b21-b11);
var m5 = (a11+a12)*b22;
var m6 = (a21-a11)*(b11+b12);
var m7 = (a12-a22)*(b21+b22);

// Combine intermediate values into the output.
var c11 =m1+m4-m5+m7;
var c12 = m3+m5;
var c21 = m2+m4;
var c22 = m1-m2+m3+m6;

var c = new Matrix(2,2);
c.set(0,0,c11);
c.set(0,1,c12);
c.set(1,0,c21);
c.set(1,1,c22);
return c;
}

// bad, very bad...
function strassen_nxn(a,b){
if(a.rows == 2){
return strassen_2x2(a, b);
}
return result;
};
else{
var size = a.rows;
var size1 = size - 1;
var demi_size0 = parseInt(size/2);
var demi_size1 = parseInt(demi_size0 - 1);
// a et b must be the same size and rows = columns

var a11 = a.subMatrix(0, demi_size1, 0, demi_size1);
var b11 = b.subMatrix(0, demi_size1, 0, demi_size1);
var a12 = a.subMatrix(0, demi_size1, demi_size0, size1);
var b12 = b.subMatrix(0, demi_size1, demi_size0, size1);
var a21 = a.subMatrix(demi_size0, size1, 0, demi_size1);
var b21 = b.subMatrix(demi_size0, size1, 0, demi_size1);
var a22 = a.subMatrix(demi_size0, size1, demi_size0, size1);
var b22 = b.subMatrix(demi_size0, size1, demi_size0, size1);

// Compute intermediate values.
var m1 = strassen_nxn(Matrix.add(a11,a22),Matrix.add(b11,b22));
var m2 = strassen_nxn(Matrix.add(a21,a22),b11);
var m3 = strassen_nxn(a11,Matrix.sub(b12,b22));
var m4 = strassen_nxn(a22,Matrix.sub(b21,b11));
var m5 = strassen_nxn(Matrix.add(a11,a12),b22);
var m6 = strassen_nxn(Matrix.sub(a21,a11),Matrix.add(b11,b12));
var m7 = strassen_nxn(Matrix.sub(a12,a22),Matrix.add(b21,b22));

// Combine intermediate values into the output.
var c11 = Matrix.add(m1,m4).sub(m5).add(m7);
var c12 = Matrix.add(m3,m5);
var c21 = Matrix.add(m2,m4);
var c22 = Matrix.sub(m1,m2).add(m3).add(m6);

var c = new Matrix(size,size);
c.setSubMatrix(c11,0,0);
c.setSubMatrix(c12,0,demi_size0);
c.setSubMatrix(c21,demi_size0,0);
c.setSubMatrix(c22,demi_size0,demi_size0);
return c;
}
}


var m = Matrix.randInt(x, y);
var m2 = Matrix.randInt(y, x);

var m = Matrix.rand(x, y);
var m2 = Matrix.rand(y, x);
/*console.log("test avec strassen n by n")
console.time("r0");
var r0 = m.mmul_strassen_2(m, m2);
console.timeEnd("r0")*/
console.log("test avec une implementation standard")
console.time("r1");
var r1 = m.mmul(m2);
console.timeEnd("r1")
console.log("test avec une implementation de Strassen basee sur du Dynamic Padding")
console.time("r2")
var r2 = m.mmul_strassen(m, m2);
console.timeEnd("r2")
if(x == 2 && y == 2){
console.log("Test avec Strassen 2*2")
console.time("r3")
var r3 =strassen_2x2(m, m2);
console.timeEnd("r3")
}

suite
/*suite
.add('mmul1', function() {
m.mmul(m2);
})
.add('mmul2', function() {
m.mmul2(m2);
m.mmul_strassen(m, m2);
})
.on('cycle', function(event) {
console.log(String(event.target));
console.log(String(event.target));
})
.on('complete', function() {
console.log('Fastest is ' + this.filter('fastest').map('name'));
console.log('Fastest is ' + this.filter('fastest').map('name'));
})
.run();
*/
129 changes: 128 additions & 1 deletion src/abstractMatrix.js
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,25 @@ function abstractMatrix(superCtor) {
return matrix;
}

/**
* Creates a matrix with the given dimensions. Values will be randomly set.
* @param {number} rows - Number of rows
* @param {number} columns - Number of columns
* @param {function} [rng] - Random number generator (default: Math.random)
* @returns {Matrix} The new matrix
*/
static randInt(rows, columns, rng) {
if (rng === undefined) rng = Math.random;
var matrix = this.empty(rows, columns);
for (var i = 0; i < rows; i++) {
for (var j = 0; j < columns; j++) {
var value = parseInt(rng()*1000);
matrix.set(i, j, value);
}
}
return matrix;
}

/**
* Creates an identity matrix with the given dimension. Values of the diagonal will be 1 and others will be 0.
* @param {number} rows - Number of rows
Expand Down Expand Up @@ -958,6 +977,114 @@ function abstractMatrix(superCtor) {
return result;
}

/**
* Returns the matrix product between x and y. More efficient than mmul(other) only when we multiply squared matrix and when the size of the matrix is > 1000.
* @param {Matrix} x
* @param {Matrix} y
* @returns {Matrix}
*/
mmul_strassen(y){
var x = this.clone();
var r1 = x.rows;
var c1 = x.columns;
var r2 = y.rows;
var c2 = y.columns;
if(c1 != r2){
console.log(`Multiplying ${r1} x ${c1} and ${r2} x ${c2} matrix: dimensions do not match.`)
}

// Put a matrix into the top left of a matrix of zeros.
// `rows` and `cols` are the dimensions of the output matrix.
function embed(mat, rows, cols){
var r = mat.rows;
var c = mat.columns;
if((r==rows) && (c==cols)){
return mat;
}
else{
var resultat = Matrix.zeros(rows, cols);
resultat = resultat.setSubMatrix(mat, 0, 0);
return resultat;
}
}


// Make sure both matrices are the same size.
// This is exclusively for simplicity:
// this algorithm can be implemented with matrices of different sizes.

var r = Math.max(r1, r2);
var c = Math.max(c1, c2);
var x = embed(x, r, c);
var y = embed(y, r, c);

// Our recursive multiplication function.
function block_mult(a, b, rows, cols){
// For small matrices, resort to naive multiplication.
if (rows <= 512 || cols <= 512){
return a.mmul(b); // a is equivalent to this
}

// Apply dynamic padding.
if ((rows % 2 == 1) && (cols % 2 == 1)) {
a = embed(a, rows + 1, cols + 1);
b = embed(b, rows + 1, cols + 1);
}
else if (rows % 2 == 1){
a = embed(a, rows + 1, cols);
b = embed(b, rows + 1, cols);
}
else if (cols % 2 == 1){
a = embed(a, rows, cols + 1);
b = embed(b, rows, cols + 1);
}

var half_rows = parseInt(a.rows / 2);
var half_cols = parseInt(a.columns / 2);
// Subdivide input matrices.
var a11 = a.subMatrix(0, half_rows -1, 0, half_cols - 1);
var b11 = b.subMatrix(0, half_rows -1, 0, half_cols - 1);

var a12 = a.subMatrix(0, half_rows -1, half_cols, a.columns - 1);
var b12 = b.subMatrix(0, half_rows -1, half_cols, b.columns - 1);

var a21 = a.subMatrix(half_rows, a.rows - 1, 0, half_cols - 1);
var b21 = b.subMatrix(half_rows, b.rows - 1, 0, half_cols - 1);

var a22 = a.subMatrix(half_rows, a.rows - 1, half_cols, a.columns - 1);
var b22 = b.subMatrix(half_rows, b.rows - 1, half_cols, b.columns - 1);

// Compute intermediate values.
var m1 = block_mult(Matrix.add(a11,a22), Matrix.add(b11,b22), half_rows, half_cols);
var m2 = block_mult(Matrix.add(a21,a22), b11, half_rows, half_cols);
var m3 = block_mult(a11, Matrix.sub(b12, b22), half_rows, half_cols);
var m4 = block_mult(a22, Matrix.sub(b21,b11), half_rows, half_cols);
var m5 = block_mult(Matrix.add(a11,a12), b22, half_rows, half_cols);
var m6 = block_mult(Matrix.sub(a21, a11), Matrix.add(b11, b12), half_rows, half_cols);
var m7 = block_mult(Matrix.sub(a12,a22), Matrix.add(b21,b22), half_rows, half_cols);

// Combine intermediate values into the output.
var c11 = Matrix.add(m1, m4);
c11.sub(m5);
c11.add(m7);
var c12 = Matrix.add(m3,m5);
var c21 = Matrix.add(m2,m4);
var c22 = Matrix.sub(m1,m2);
c22.add(m3);
c22.add(m6);

//Crop output to the desired size (undo dynamic padding).
var resultat = Matrix.zeros(2*c11.rows, 2*c11.columns);
resultat = resultat.setSubMatrix(c11, 0, 0);
resultat = resultat.setSubMatrix(c12, c11.rows, 0)
resultat = resultat.setSubMatrix(c21, 0, c11.columns);
resultat = resultat.setSubMatrix(c22, c11.rows, c11.columns);
return resultat.subMatrix(0, rows - 1, 0, cols - 1);
}
var resultat_final = block_mult(x, y, r, c);
return resultat_final;
};

/**
* Returns a row-by-row scaled matrix
* @param {Number} [min=0] - Minimum scaled value
Expand Down Expand Up @@ -1198,7 +1325,7 @@ function abstractMatrix(superCtor) {
}

/*
Matrix views
Matrix views
*/

/**
Expand Down
6 changes: 6 additions & 0 deletions test/matrix/utility.js
Original file line number Diff line number Diff line change
Expand Up @@ -136,4 +136,10 @@ describe('utility methods', function () {
matrix.repeat(2, 2).to2DArray().should.eql([[1, 2, 1, 2], [3, 4, 3, 4], [1, 2, 1, 2], [3, 4, 3, 4]]);
matrix.repeat(1, 2).to2DArray().should.eql([[1, 2, 1, 2], [3, 4, 3, 4]]);
});

it('mmul strassen', function (){
var matrix = new Matrix([[2,4],[7,1]]);
var matrix2 = new Matrix([[2,1],[1,1]]);
matrix.mmul_strassen(matrix2).to2DArray().should.eql([[8,6], [15,8]]);
});
});

0 comments on commit fdc1c07

Please sign in to comment.