diff --git a/tensorboard/plugins/projector/vz_projector/knn.ts b/tensorboard/plugins/projector/vz_projector/knn.ts
index 819eb8b14ef..8ab7a1a48e3 100644
--- a/tensorboard/plugins/projector/vz_projector/knn.ts
+++ b/tensorboard/plugins/projector/vz_projector/knn.ts
@@ -22,17 +22,9 @@ export type NearestEntry = {
   index: number;
   dist: number;
 };
-/**
- * Optimal size for the height of the matrix when doing computation on the GPU
- * using WebGL. This was found experimentally.
- *
- * This also guarantees that for computing pair-wise distance for up to 10K
- * vectors, no more than 40MB will be allocated in the GPU. Without the
- * allocation limit, we can freeze the graphics of the whole OS.
- */
-const OPTIMAL_GPU_BLOCK_SIZE = 256;
-/** Id of message box used for knn gpu progress bar. */
-const KNN_GPU_MSG_ID = 'knn-gpu';
+
+/** Id of message box used for knn. */
+const KNN_MSG_ID = 'knn';
 
 /**
  * Returns the K nearest neighbors for each vector where the distance
@@ -52,105 +44,63 @@ export function findKNNGPUCosDistNorm(
   const N = dataPoints.length;
   const dim = accessor(dataPoints[0]).length;
   // The goal is to compute a large matrix multiplication A*A.T where A is of
-  // size NxD and A.T is its transpose. This results in a NxN matrix which
-  // could be too big to store on the GPU memory. To avoid memory overflow, we
-  // compute multiple A*partial_A.T where partial_A is of size BxD (B is much
-  // smaller than N). This results in storing only NxB size matrices on the GPU
-  // at a given time.
+  // size NxD and A.T is its transpose. This results in a NxN matrix.
   // A*A.T will give us NxN matrix holding the cosine distance between every
   // pair of points, which we sort using KMin data structure to obtain the
   // K nearest neighbors for each point.
   const nearest: NearestEntry[][] = new Array(N);
-  let numPieces = Math.ceil(N / OPTIMAL_GPU_BLOCK_SIZE);
-  const actualPieceSize = Math.floor(N / numPieces);
-  const modulo = N % actualPieceSize;
-  numPieces += modulo ? 1 : 0;
-  let offset = 0;
-  let progress = 0;
-  let progressDiff = 1 / (2 * numPieces);
-  let piece = 0;
-
-  const typedArray = vector.toTypedArray(dataPoints, accessor);
-  const bigMatrix = tf.tensor(typedArray, [N, dim]);
-  const bigMatrixTransposed = tf.transpose(bigMatrix);
-  // 1 - A * A^T.
-  const bigMatrixSquared = tf.matMul(bigMatrix, bigMatrixTransposed);
-  const cosDistMatrix = tf.sub(1, bigMatrixSquared);
-
-  let maybePaddedCosDistMatrix = cosDistMatrix;
-  if (actualPieceSize * numPieces > N) {
-    // Expect the input to be rank 2 (though it is not typed that way) so we
-    // want to pad the first dimension so we split very evenly (all splitted
-    // tensor have exactly the same dimesion).
-    const padding: Array<[number, number]> = [
-      [0, actualPieceSize * numPieces - N],
-      [0, 0],
-    ];
-    maybePaddedCosDistMatrix = tf.pad(cosDistMatrix, padding);
-  }
-  const splits = tf.split(
-    maybePaddedCosDistMatrix,
-    new Array(numPieces).fill(actualPieceSize),
-    0
-  );
-  function step(resolve: (result: NearestEntry[][]) => void) {
-    let progressMsg =
-      'Finding nearest neighbors: ' + (progress * 100).toFixed() + '%';
     util
       .runAsyncTask(
-        progressMsg,
+        'Finding nearest neighbors...',
         async () => {
+          const cosSimilarityMatrix = tf.tidy(() => {
+            const typedArray = vector.toTypedArray(dataPoints, accessor);
+            const bigMatrix = tf.tensor(typedArray, [N, dim]);
+            const bigMatrixTransposed = tf.transpose(bigMatrix);
+            // A * A^T.
+            return tf.matMul(bigMatrix, bigMatrixTransposed);
+          });
 
           // `.data()` returns flattened Float32Array of B * N dimension.
           // For matrix of
           // [ 1 2 ]
           // [ 3 4 ],
           // `.data()` returns [1, 2, 3, 4].
-          const partial = await splits[piece].data();
-          progress += progressDiff;
-          for (let i = 0; i < actualPieceSize; i++) {
+          let partial;
+          try {
+            partial = await cosSimilarityMatrix.data();
+          } finally {
+            // Discard all tensors and free up the memory.
+            cosSimilarityMatrix.dispose();
+          }
+          for (let i = 0; i < N; i++) {
             let kMin = new KMin<NearestEntry>(k);
-            let iReal = offset + i;
-            if (iReal >= N) break;
             for (let j = 0; j < N; j++) {
               // Skip diagonal entries.
-              if (j === iReal) {
+              if (j === i) {
                 continue;
               }
               // Access i * N's row at `j` column.
               // Reach row has N entries and j-th index has cosine distance
-              // between iReal vs. j-th vectors.
-              const cosDist = partial[i * N + j];
+              // between i-th vs. j-th vectors.
+              const cosDist = 1 - partial[i * N + j];
               if (cosDist >= 0) {
                 kMin.add(cosDist, {index: j, dist: cosDist});
               }
             }
-            nearest[iReal] = kMin.getMinKItems();
+            nearest[i] = kMin.getMinKItems();
           }
-          progress += progressDiff;
-          offset += actualPieceSize;
-          piece++;
         },
-        KNN_GPU_MSG_ID
+        KNN_MSG_ID
       )
      .then(
        () => {
-          if (piece < numPieces) {
-            step(resolve);
-          } else {
-            logging.setModalMessage(null!, KNN_GPU_MSG_ID);
-            // Discard all tensors and free up the memory.
-            bigMatrix.dispose();
-            bigMatrixTransposed.dispose();
-            bigMatrixSquared.dispose();
-            cosDistMatrix.dispose();
-            splits.forEach((split) => split.dispose());
-            resolve(nearest);
-          }
+          logging.setModalMessage(null!, KNN_MSG_ID);
+          resolve(nearest);
        },
        (error) => {
          // GPU failed. Reverting back to CPU.
-          logging.setModalMessage(null!, KNN_GPU_MSG_ID);
+          logging.setModalMessage(null!, KNN_MSG_ID);
          let distFunc = (a, b, limit) => vector.cosDistNorm(a, b);
          findKNN(dataPoints, k, accessor, distFunc).then((nearest) => {
            resolve(nearest);
@@ -212,47 +162,12 @@ export function findKNN(
       for (let i = 0; i < N; i++) {
         nearest[i] = kMin[i].getMinKItems();
       }
+      logging.setModalMessage(null!, KNN_MSG_ID);
       return nearest;
-    }
+    },
+    KNN_MSG_ID
   );
 }
-/** Calculates the minimum distance between a search point and a rectangle. */
-function minDist(
-  point: [number, number],
-  x1: number,
-  y1: number,
-  x2: number,
-  y2: number
-) {
-  let x = point[0];
-  let y = point[1];
-  let dx1 = x - x1;
-  let dx2 = x - x2;
-  let dy1 = y - y1;
-  let dy2 = y - y2;
-  if (dx1 * dx2 <= 0) {
-    // x is between x1 and x2
-    if (dy1 * dy2 <= 0) {
-      // (x,y) is inside the rectangle
-      return 0; // return 0 as point is in rect
-    }
-    return Math.min(Math.abs(dy1), Math.abs(dy2));
-  }
-  if (dy1 * dy2 <= 0) {
-    // y is between y1 and y2
-    // We know it is already inside the rectangle
-    return Math.min(Math.abs(dx1), Math.abs(dx2));
-  }
-  let corner: [number, number];
-  if (x > x2) {
-    // Upper-right vs lower-right.
-    corner = y > y2 ? [x2, y2] : [x2, y1];
-  } else {
-    // Upper-left vs lower-left.
-    corner = y > y2 ? [x1, y2] : [x1, y1];
-  }
-  return Math.sqrt(vector.dist22D([x, y], corner));
-}
 /**
  * Returns the nearest neighbors of a particular point.
  *
@@ -281,5 +196,3 @@ export function findKNNofPoint(
   }
   return kMin.getMinKItems();
 }
-
-export const TEST_ONLY = {OPTIMAL_GPU_BLOCK_SIZE};
diff --git a/tensorboard/plugins/projector/vz_projector/knn_test.ts b/tensorboard/plugins/projector/vz_projector/knn_test.ts
index ce7c771b659..a8e4009d23a 100644
--- a/tensorboard/plugins/projector/vz_projector/knn_test.ts
+++ b/tensorboard/plugins/projector/vz_projector/knn_test.ts
@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
-import {findKNN, findKNNGPUCosDistNorm, NearestEntry, TEST_ONLY} from './knn';
+import {findKNN, findKNNGPUCosDistNorm, NearestEntry} from './knn';
 import {cosDistNorm, unit} from './vector';
 
 describe('projector knn test', () => {
@@ -65,22 +65,6 @@ describe('projector knn test', () => {
 
       expect(getIndices(values)).toEqual([[1], [0]]);
     });
-
-    it('splits a large data into one that would fit into GPU memory', async () => {
-      const size = TEST_ONLY.OPTIMAL_GPU_BLOCK_SIZE + 5;
-      const data = new Array(size).fill(
-        unitVector(new Float32Array([1, 1, 1]))
-      );
-      const values = await findKNNGPUCosDistNorm(data, 1, (a) => a);
-
-      expect(getIndices(values)).toEqual([
-        // Since distance to the diagonal entries (distance to self is 0) is
-        // non-sensical, the diagonal entires are ignored. So for the first
-        // item, the nearest neighbor should be 2nd item (index 1).
-        [1],
-        ...new Array(size - 1).fill([0]),
-      ]);
-    });
   });
 
   describe('#findKNN', () => {
@@ -108,22 +92,5 @@ describe('projector knn test', () => {
       // Floating point precision makes it hard to test. Just assert indices.
       expect(getIndices(findKnnGpuCosVal)).toEqual(getIndices(findKnnVal));
     });
-
-    it('splits a large data without the result being wrong', async () => {
-      const size = TEST_ONLY.OPTIMAL_GPU_BLOCK_SIZE + 5;
-      const data = Array.from(new Array(size)).map((_, index) => {
-        return unitVector(new Float32Array([index + 1, index + 1]));
-      });
-
-      const findKnnGpuCosVal = await findKNNGPUCosDistNorm(data, 2, (a) => a);
-      const findKnnVal = await findKNN(
-        data,
-        2,
-        (a) => a,
-        (a, b, limit) => cosDistNorm(a, b)
-      );
-
-      expect(getIndices(findKnnGpuCosVal)).toEqual(getIndices(findKnnVal));
-    });
   });
 });
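For reviewers who want to try the new approach outside the projector, here is a rough, self-contained sketch of the idea the diff adopts: when the rows of A are unit-normalized, a single `A * A^T` matmul inside `tf.tidy` produces every pairwise cosine similarity, and cosine distance is simply `1 - similarity`. This is not the projector's own helper; it assumes a plain `@tensorflow/tfjs` import and substitutes an ordinary sort for the `KMin` heap used in `knn.ts`.

```ts
// Sketch only: mirrors the single-matmul cosine-distance kNN above, not the
// projector implementation. Rows of `points` are assumed unit-normalized.
import * as tf from '@tensorflow/tfjs';

async function knnCosDistNorm(
  points: Float32Array[],
  k: number
): Promise<{index: number; dist: number}[][]> {
  const n = points.length;
  const dim = points[0].length;
  // Flatten the points into one typed array so they can back a single tensor.
  const flat = new Float32Array(n * dim);
  points.forEach((p, i) => flat.set(p, i * dim));

  // One A * A^T multiplication yields all pairwise cosine similarities.
  // tf.tidy disposes every intermediate tensor created inside the closure.
  const similarities = tf.tidy(() => {
    const a = tf.tensor2d(flat, [n, dim]);
    return tf.matMul(a, tf.transpose(a));
  });
  let sims: Float32Array;
  try {
    sims = (await similarities.data()) as Float32Array;
  } finally {
    // Free the N x N matrix even if reading it back fails.
    similarities.dispose();
  }

  // Cosine distance = 1 - cosine similarity (valid because rows are unit norm).
  const nearest: {index: number; dist: number}[][] = new Array(n);
  for (let i = 0; i < n; i++) {
    const row: {index: number; dist: number}[] = [];
    for (let j = 0; j < n; j++) {
      if (j === i) continue; // skip the self-distance on the diagonal
      row.push({index: j, dist: 1 - sims[i * n + j]});
    }
    row.sort((a, b) => a.dist - b.dist);
    nearest[i] = row.slice(0, k);
  }
  return nearest;
}
```

The trade-off the diff accepts is the one the removed comments describe: the full N x N similarity matrix now lives on the GPU at once, which the old `OPTIMAL_GPU_BLOCK_SIZE` chunking existed to avoid, in exchange for far simpler bookkeeping and no manual padding, splitting, or per-piece disposal.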