Skip to content

Commit

Permalink
feat: add mean by dimension and product methods
Browse files Browse the repository at this point in the history
  • Loading branch information
targos committed Apr 18, 2019
1 parent 220f2df commit 6b57aae
Show file tree
Hide file tree
Showing 6 changed files with 123 additions and 11 deletions.
27 changes: 25 additions & 2 deletions matrix.d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ declare module 'ml-matrix' {
type MaybeMatrix = Matrix | number[][];
type Rng = () => number;
type ScalarOrMatrix = number | Matrix;
type MatrixDimension = 'row' | 'column';

class BaseView extends Matrix {}
class MatrixColumnView extends BaseView {
Expand Down Expand Up @@ -154,13 +155,35 @@ declare module 'ml-matrix' {
* Returns the sum of all elements of the matrix.
*/
sum(): number;

/**
* Returns the sum by the dimension given.
* Returns the sum by the given dimension.
* @param by - sum by 'row' or 'column'.
*/
sum(by: 'row' | 'column'): number[];
sum(by: MatrixDimension): number[];

/**
* Returns the product of all elements of the matrix.
*/
product(): number;

/**
* Returns the product by the given dimension.
* @param by - product by 'row' or 'column'.
*/
product(by: MatrixDimension): number[];

/**
* Returns the mean of all elements of the matrix.
*/
mean(): number;

/**
* Returns the mean by the given dimension.
* @param by - mean by 'row' or 'column'.
*/
mean(by: MatrixDimension): number[];

prod(): number;
norm(type: 'frobenius' | 'max'): number;
cumulativeSum(): Matrix;
Expand Down
16 changes: 16 additions & 0 deletions src/__tests__/matrix/mean.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
import { Matrix } from '../..';

describe('mean by row and columns', () => {
const matrix = new Matrix([[1, 2, 3], [4, 5, 6]]);
it('mean by row', () => {
expect(matrix.mean('row')).toStrictEqual([2, 5]);
});

it('mean by column', () => {
expect(matrix.mean('column')).toStrictEqual([2.5, 3.5, 4.5]);
});

it('mean all', () => {
expect(matrix.mean()).toBe(3.5);
});
});
16 changes: 16 additions & 0 deletions src/__tests__/matrix/product.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
import { Matrix } from '../..';

describe('product by row and columns', () => {
const matrix = new Matrix([[1, 2, 3], [4, 5, 6]]);
it('product by row', () => {
expect(matrix.product('row')).toStrictEqual([6, 120]);
});

it('product by column', () => {
expect(matrix.product('column')).toStrictEqual([4, 10, 18]);
});

it('product all', () => {
expect(matrix.product()).toBe(720);
});
});
41 changes: 34 additions & 7 deletions src/abstractMatrix.js
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import {
checkRange,
checkIndices
} from './util';
import { sumByRow, sumByColumn, sumAll } from './stat';
import { sumByRow, sumByColumn, sumAll, productByRow, productByColumn, productAll } from './stat';
import MatrixTransposeView from './views/transpose';
import MatrixRowView from './views/row';
import MatrixSubView from './views/sub';
Expand Down Expand Up @@ -960,12 +960,39 @@ export default function AbstractMatrix(superCtor) {
}
}

/**
* Returns the mean of all elements of the matrix
* @return {number}
*/
mean() {
return this.sum() / this.size;
product(by) {
switch (by) {
case 'row':
return productByRow(this);
case 'column':
return productByColumn(this);
case undefined:
return productAll(this);
default:
throw new Error(`invalid option: ${by}`);
}
}

mean(by) {
const sum = this.sum(by);
switch (by) {
case 'row': {
for (let i = 0; i < this.rows; i++) {
sum[i] /= this.columns;
}
return sum;
}
case 'column': {
for (let i = 0; i < this.columns; i++) {
sum[i] /= this.rows;
}
return sum;
}
case undefined:
return sum / this.size;
default:
throw new Error(`invalid option: ${by}`);
}
}

/**
Expand Down
30 changes: 30 additions & 0 deletions src/stat.js
Original file line number Diff line number Diff line change
Expand Up @@ -29,3 +29,33 @@ export function sumAll(matrix) {
}
return v;
}

export function productByRow(matrix) {
var sum = newArray(matrix.rows, 1);
for (var i = 0; i < matrix.rows; ++i) {
for (var j = 0; j < matrix.columns; ++j) {
sum[i] *= matrix.get(i, j);
}
}
return sum;
}

export function productByColumn(matrix) {
var sum = newArray(matrix.columns, 1);
for (var i = 0; i < matrix.rows; ++i) {
for (var j = 0; j < matrix.columns; ++j) {
sum[j] *= matrix.get(i, j);
}
}
return sum;
}

export function productAll(matrix) {
var v = 1;
for (var i = 0; i < matrix.rows; i++) {
for (var j = 0; j < matrix.columns; j++) {
v *= matrix.get(i, j);
}
}
return v;
}
4 changes: 2 additions & 2 deletions src/util.js
Original file line number Diff line number Diff line change
Expand Up @@ -138,10 +138,10 @@ export function getRange(from, to) {
return arr;
}

export function newArray(length) {
export function newArray(length, value = 0) {
var array = [];
for (var i = 0; i < length; i++) {
array.push(0);
array.push(value);
}
return array;
}
Expand Down

0 comments on commit 6b57aae

Please sign in to comment.