beenotung/tensorflow-helpers


tensorflow-helpers

Helper functions to use tensorflow in nodejs/browser for transfer learning, image classification, and more.


Features

  • Support transfer learning and continuous learning
  • Custom image classifier using embedding features from pre-trained image model
  • Extract both spatial and pooled features from image models
  • Correctly save/load model on filesystem[1]
  • Load image file into tensor with resize and crop
  • List various pre-trained models (url, image dimension, embedding size)
  • Support both nodejs and browser environments
  • Support caching models and image embeddings
  • CLI tool for downloading TensorFlow.js models from various sources
  • Interactive model topology visualization
  • Standard weightsManifest format support for better compatibility
  • TypeScript support
  • Works with plain JavaScript; TypeScript is not mandatory

[1]: The built-in tf.loadGraphModel() cannot load the model saved by model.save()

Model Artifacts Management

Models now provide better access to their internal artifacts while maintaining backward compatibility:

// The familiar classNames API still works (now uses getter/setter proxy)
model.classNames = ['cat', 'dog', 'bird']

// New: Direct access to model artifacts
const artifacts = model.getArtifacts()

// Access classNames through the artifacts (standard TensorFlow.js format)
const classNames = artifacts.userDefinedMetadata?.classNames
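The "getter/setter proxy" mentioned above can be pictured with the minimal sketch below. It is not the library's implementation: `ModelArtifactsLike` and `exposeClassNamesSketch` are hypothetical names, and the only assumption carried over from this section is that `classNames` lives under `userDefinedMetadata.classNames` in the artifacts.

```typescript
// Hypothetical, simplified artifacts shape: only the field this sketch needs.
type ModelArtifactsLike = {
  userDefinedMetadata?: { classNames?: string[] }
}

type ExposedModel<M> = M & {
  getArtifacts(): ModelArtifactsLike
  classNames?: string[]
}

// Hypothetical helper: attach getArtifacts() plus a classNames accessor that
// reads from and writes to the artifacts object, so both APIs stay in sync.
function exposeClassNamesSketch<M extends object>(
  model: M,
  artifacts: ModelArtifactsLike,
): ExposedModel<M> {
  Object.defineProperty(model, 'getArtifacts', {
    value: () => artifacts,
    enumerable: false,
  })
  Object.defineProperty(model, 'classNames', {
    get: () => artifacts.userDefinedMetadata?.classNames,
    set: (classNames: string[] | undefined) => {
      // Write through to the standard metadata location.
      artifacts.userDefinedMetadata = {
        ...artifacts.userDefinedMetadata,
        classNames,
      }
    },
    enumerable: true,
    configurable: true,
  })
  return model as ExposedModel<M>
}

const model = exposeClassNamesSketch({}, {})
model.classNames = ['cat', 'dog', 'bird']
console.log(model.getArtifacts().userDefinedMetadata?.classNames)
// [ 'cat', 'dog', 'bird' ]
```

Because the accessor stores nothing of its own, code using the older `model.classNames` API and code reading the artifacts directly always see the same list.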

Installation

npm install tensorflow-helpers

You can also install tensorflow-helpers with pnpm, yarn, or slnpm

Development Scripts

When working with the source code, the following scripts are available:

# Development
npm run dev          # Watch mode for browser testing
npm run dev:chart    # Watch mode for model visualization

# Building
npm run bundle       # Build browser test bundle
npm run bundle:chart # Build chart visualization bundle
npm run build        # Build the library for distribution

# Testing
npm run test         # TypeScript type checking
npm run clean        # Clean build artifacts

Usage Example

See model.test.ts and classifier.test.ts for complete examples.

Quick Start: Use the CLI tool to download models: npx download-tfjs-model <source> <output-dir> (see CLI Tool section for details).

Usage from browser:

import {
  tf,
  loadImageModel,
  getImageFeatures,
  loadImageClassifierModel,
} from 'tensorflow-helpers/browser'

declare var fileInput: HTMLInputElement

async function main() {
  let baseModel = await loadImageModel({
    url: 'saved_model/mobilenet-v3-large-100',
    cacheUrl: 'indexeddb://mobilenet-v3-large-100',
    checkForUpdates: false,
  })

  let classifier = await loadImageClassifierModel({
    baseModel,
    classNames: ['anime', 'real', 'others'],
    modelUrl: 'saved_model/emotion-classifier',
    cacheUrl: 'indexeddb://emotion-classifier',
  })

  fileInput.onchange = async () => {
    let file = fileInput.files?.[0]
    if (!file) return

    // Extract both spatial and pooled features
    let features = await getImageFeatures({
      tf,
      imageModel: baseModel,
      image: file,
    })
    console.log('spatial features shape:', features.spatialFeatures.shape) // [1, 7, 7, 160]
    console.log('pooled features shape:', features.pooledFeatures.shape) // [1, 1280]

    // Classify the image
    let result = await classifier.classifyImageFile(file)
    // classifyImageFile handles end-to-end classification: auto-resize image → extract features → classify
    console.log('classification result:', result)
    // result is Array<{ label: string, confidence: number }> - e.g. [{ label: 'anime', confidence: 0.8 }, ...]
  }
}
main().catch(e => console.error(e))

Usage from nodejs:

import {
  tf,
  loadImageModel,
  PreTrainedImageModels,
  getImageFeatures,
  loadImageClassifierModel,
  topClassifyResult,
} from 'tensorflow-helpers'

// Load pre-trained base model
let baseModel = await loadImageModel({
  spec: PreTrainedImageModels.mobilenet['mobilenet-v3-large-100'],
  dir: 'saved_model/base_model',
})
console.log('embedding features:', baseModel.spec.features)
// [print] embedding features: 1280

// Extract both spatial and pooled features
let features = await getImageFeatures({
  tf,
  imageModel: baseModel,
  image: 'image.jpg',
})
console.log('spatial features shape:', features.spatialFeatures.shape) // [1, 7, 7, 160]
console.log('pooled features shape:', features.pooledFeatures.shape) // [1, 1280]

// Create classifier for image classification
let classifier = await loadImageClassifierModel({
  baseModel,
  modelDir: 'saved_model/classifier_model',
  hiddenLayers: [128],
  datasetDir: 'dataset',
  // classNames: ['anime', 'real', 'others'], // auto scan from datasetDir
})

// auto load training dataset
let history = await classifier.train({
  epochs: 5,
  batchSize: 32,
})

// persist the parameters across restart
await classifier.save()

// auto load image from filesystem, resize and crop
let classes = await classifier.classifyImageFile('image.jpg')
let topClass = topClassifyResult(classes)

console.log('result:', topClass)
// [print] result: { label: 'anime', confidence: 0.7991582155227661 }
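The `// auto scan from datasetDir` comment implies that class names are derived from the dataset directory. The sketch below shows one plausible layout and scan, assuming one subdirectory per class (`dataset/<className>/<image files>`) and alphabetical ordering. Both the layout details and `scanClassNamesSketch` are assumptions, not the library's documented behavior; if the actual order matters for your use, pass `classNames` explicitly.

```typescript
import { mkdirSync, mkdtempSync, readdirSync, rmSync } from 'node:fs'
import { tmpdir } from 'node:os'
import { join } from 'node:path'

// Hypothetical helper: each subdirectory of datasetDir is one class.
function scanClassNamesSketch(datasetDir: string): string[] {
  return readdirSync(datasetDir, { withFileTypes: true })
    .filter(entry => entry.isDirectory())
    .map(entry => entry.name)
    .sort() // assumed ordering; the classifier's i-th output is read as classNames[i]
}

// Build a throwaway dataset/<className>/ layout to demonstrate the scan.
const datasetDir = mkdtempSync(join(tmpdir(), 'dataset-'))
for (const className of ['real', 'anime', 'others']) {
  mkdirSync(join(datasetDir, className))
}
const classNames = scanClassNamesSketch(datasetDir)
console.log(classNames) // [ 'anime', 'others', 'real' ]
rmSync(datasetDir, { recursive: true, force: true })
```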

Model Visualization (Local Development)

Interactive model topology visualization for analyzing model structure and selecting feature extraction points:

# Start the visualization development server
npm run dev:chart

# Open chart.html in your browser to visualize model topology

Note: This feature is currently only available for local development due to CORS restrictions. A hosted version is not available at this time.

Main Purpose:

  • Model Topology Analysis: Visualize the complete graph of model nodes and their connections
  • Feature Selection: Identify optimal nodes to tap into for extracting intermediate features (spatial features, embeddings, etc.)
  • Shape Inspection: See the tensor shapes at each node to understand data flow through the model

Features:

  • Interactive node exploration with hover and click-to-lock functionality
  • Visual connections showing data flow between layers
  • Node details including operation type and tensor shapes
  • Support for various model formats (GraphModel, LayersModel)

TypeScript Signature

For details, see the type hints in your IDE.

Shortcut to tensorflow

exported as 'tensorflow-helpers':

import * as tfjs from '@tensorflow/tfjs-node'

export let tensorflow: typeof tfjs
export let tf: typeof tfjs

exported as 'tensorflow-helpers/browser':

import * as tfjs from '@tensorflow/tfjs'

export let tensorflow: typeof tfjs
export let tf: typeof tfjs

Pre-trained model constants
export const PreTrainedImageModels: {
  mobilenet: {
    'mobilenet-v3-large-100': {
      url: 'https://www.kaggle.com/models/google/mobilenet-v3/TfJs/large-100-224-feature-vector/1'
      width: 224
      height: 224
      channels: 3
      features: 1280
    }
    // more models omitted ...
  }
}

Model helper functions
export type Model = tf.GraphModel | tf.LayersModel

export function saveModel(options: {
  model: Model
  dir: string
}): Promise<SaveResult>

export function loadGraphModel(options: { dir: string }): Promise<tf.GraphModel>

export function loadLayersModel(options: {
  dir: string
}): Promise<tf.LayersModel>

export function cachedLoadGraphModel(options: {
  url: string
  dir: string
}): Promise<Model>

export function cachedLoadLayersModel(options: {
  url: string
  dir: string
}): Promise<Model>

// Model artifacts management
export function getModelArtifacts<Model extends object>(
  model: Model,
): PatchedModelArtifacts

export function exposeModelArtifacts<Model extends object>(
  model: Model,
): Model & {
  getArtifacts: () => PatchedModelArtifacts
  classNames?: string[]
}

export type PatchedModelArtifacts = ModelJSON &
  Pick<ModelArtifacts, 'weightData' | 'weightSpecs'> & {
    userDefinedMetadata?: {
      classNames?: string[]
    }
  }

export function loadImageModel(options: {
  spec: ImageModelSpec
  dir: string
  aspectRatio?: CropAndResizeAspectRatio
  cache?: EmbeddingCache | boolean
}): Promise<ImageModel>

export type EmbeddingCache = {
  get(filename: string): number[] | null | undefined
  set(filename: string, values: number[]): void
}

export type ImageModelSpec = {
  url: string
  width: number
  height: number
  channels: number
  features: number
}

export type ImageModel = {
  spec: ImageModelSpec
  model: Model

  fileEmbeddingCache: Map<string, tf.Tensor> | null
  checkCache(file_or_filename: string): tf.Tensor | void

  loadImageCropped(
    file: string,
    options?: {
      expandAnimations?: boolean
    },
  ): Promise<tf.Tensor3D | tf.Tensor4D>

  imageFileToEmbedding(
    file: string,
    options?: {
      expandAnimations?: boolean
    },
  ): Promise<tf.Tensor>

  imageTensorToEmbedding(imageTensor: tf.Tensor3D | tf.Tensor4D): tf.Tensor
}

Image helper functions and types
export function loadImageFile(
  file: string,
  options?: {
    channels?: number
    dtype?: string
    expandAnimations?: boolean
    crop?: {
      width: number
      height: number
      aspectRatio?: CropAndResizeAspectRatio
    }
  },
): Promise<tf.Tensor3D | tf.Tensor4D>

export type ImageTensor = tf.Tensor3D | tf.Tensor4D

export function getImageTensorShape(imageTensor: tf.Tensor3D | tf.Tensor4D): {
  width: number
  height: number
}

export type Box = [top: number, left: number, bottom: number, right: number]

/**
 * @description calculate center-crop box
 * @returns [top,left,bottom,right], values range: 0..1
 */
export function calcCropBox(options: {
  sourceShape: { width: number; height: number }
  targetShape: { width: number; height: number }
}): Box

/**
 * @description default is 'rescale'
 *
 * 'rescale' -> stretch/transform to target shape;
 *
 * 'center-crop' -> crop the edges, maintain aspect ratio at center
 */
export type CropAndResizeAspectRatio = 'rescale' | 'center-crop'

export function cropAndResizeImageTensor(options: {
  imageTensor: tf.Tensor3D | tf.Tensor4D
  width: number
  height: number
  aspectRatio?: CropAndResizeAspectRatio
}): tf.Tensor4D

export function cropAndResizeImageFile(options: {
  srcFile: string
  destFile: string
  width: number
  height: number
  aspectRatio?: CropAndResizeAspectRatio
}): Promise<void>
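The center-crop box described by `calcCropBox` can be derived from the two aspect ratios. The sketch below illustrates the idea under stated assumptions and is not the library source: `tensorShapeToSize` and `calcCropBoxSketch` are hypothetical names, and it assumes the usual TensorFlow.js NHWC layout (`[height, width, channels]` or `[batch, height, width, channels]`). The returned box uses normalized `[top, left, bottom, right]` coordinates, which is the `[y1, x1, y2, x2]` order that `tf.image.cropAndResize` expects for its `boxes` argument.

```typescript
type Box = [top: number, left: number, bottom: number, right: number]
type Size = { width: number; height: number }

// Hypothetical helper: read width/height from an NHWC shape (3D or 4D).
function tensorShapeToSize(shape: number[]): Size {
  if (shape.length !== 3 && shape.length !== 4) {
    throw new Error(`expected a 3D or 4D image shape, got rank ${shape.length}`)
  }
  // The last three dimensions are always [height, width, channels].
  const [height, width] = shape.slice(-3)
  return { width, height }
}

// Hypothetical helper: largest centered box with the target aspect ratio,
// in normalized [0, 1] coordinates relative to the source image.
function calcCropBoxSketch(source: Size, target: Size): Box {
  const sourceRatio = source.width / source.height
  const targetRatio = target.width / target.height
  if (sourceRatio > targetRatio) {
    // Source is too wide: keep the full height, trim left and right edges.
    const keptWidth = source.height * targetRatio
    const left = (source.width - keptWidth) / 2 / source.width
    return [0, left, 1, 1 - left]
  }
  // Source is too tall (or already matches): keep the full width, trim top and bottom.
  const keptHeight = source.width / targetRatio
  const top = (source.height - keptHeight) / 2 / source.height
  return [top, 0, 1 - top, 1]
}

const source = tensorShapeToSize([1, 200, 400, 3]) // a 400x200 landscape image
console.log(calcCropBoxSketch(source, { width: 224, height: 224 }))
// [ 0, 0.25, 1, 0.75 ]
```

With `aspectRatio: 'rescale'` no edges are trimmed (equivalent to the full box `[0, 0, 1, 1]`), so the image is stretched whenever the source and target aspect ratios differ.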

Tensor helper functions
export function disposeTensor(tensor: tf.Tensor | tf.Tensor[]): void

export function toOneTensor(
  tensor: tf.Tensor | tf.Tensor[] | tf.NamedTensorMap,
): tf.Tensor

export function toTensor4D(tensor: tf.Tensor3D | tf.Tensor4D): tf.Tensor4D

export function toTensor3D(tensor: tf.Tensor3D | tf.Tensor4D): tf.Tensor3D

Classifier helper functions
export type ClassifierModelSpec = {
  embeddingFeatures: number
  hiddenLayers?: number[]
  classes: number
}

export function createImageClassifier(spec: ClassifierModelSpec): tf.Sequential

export type ClassificationResult = {
  label: string
  /** @description between 0 and 1 */
  confidence: number
}

export type ClassifierModel = {
  baseModel: {
    spec: ImageModelSpec
    model: Model
    loadImageAsync: (file: string) => Promise<tf.Tensor4D>
    loadImageSync: (file: string) => tf.Tensor4D
    loadAnimatedImageAsync: (file: string) => Promise<tf.Tensor4D>
    loadAnimatedImageSync: (file: string) => tf.Tensor4D
    inferEmbeddingAsync: (
      file_or_image_tensor: string | tf.Tensor,
    ) => Promise<tf.Tensor>
    inferEmbeddingSync: (file_or_image_tensor: string | tf.Tensor) => tf.Tensor
  }
  classifierModel: tf.LayersModel | tf.Sequential
  classNames: string[]
  classifyAsync: (
    file_or_image_tensor: string | tf.Tensor,
  ) => Promise<ClassificationResult[]>
  classifySync: (
    file_or_image_tensor: string | tf.Tensor,
  ) => ClassificationResult[]
  loadDatasetFromDirectoryAsync: () => Promise<{
    x: tf.Tensor<tf.Rank>
    y: tf.Tensor<tf.Rank>
  }>
  compile: () => void
  train: (options?: tf.ModelFitArgs) => Promise<tf.History>
  save: (dir?: string) => Promise<SaveResult>
}

export function loadImageClassifierModel(options: {
  baseModel: ImageModel
  hiddenLayers?: number[]
  modelDir: string
  datasetDir: string
  classNames?: string[]
}): Promise<ClassifierModel>

export function topClassifyResult(
  items: ClassificationResult[],
): ClassificationResult

/**
 * @description the values are returned as is.
 * Softmax should already have been applied.
 */
export function mapWithClassName(
  classNames: string[],
  values: ArrayLike<number>,
  options?: {
    sort?: boolean
  },
): ClassificationResult[]
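Conceptually, `mapWithClassName` pairs each class name with the probability at the same index, and `topClassifyResult` picks the entry with the highest confidence. The sketch below re-creates that behavior from the signatures alone; the `...Sketch` names are hypothetical, and the library's handling of ties or mismatched lengths may differ.

```typescript
type ClassificationResult = { label: string; confidence: number }

// Hypothetical re-creation: values are assumed to be softmax probabilities already.
function mapWithClassNameSketch(
  classNames: string[],
  values: ArrayLike<number>,
  options?: { sort?: boolean },
): ClassificationResult[] {
  if (classNames.length !== values.length) {
    throw new Error(`expected ${classNames.length} values, got ${values.length}`)
  }
  const results = classNames.map((label, i) => ({ label, confidence: values[i] }))
  if (options?.sort) {
    results.sort((a, b) => b.confidence - a.confidence) // highest confidence first
  }
  return results
}

// Hypothetical re-creation: return the single most confident result.
function topClassifyResultSketch(items: ClassificationResult[]): ClassificationResult {
  if (items.length === 0) throw new Error('no classification results')
  return items.reduce((best, item) => (item.confidence > best.confidence ? item : best))
}

// e.g. the output of a 3-class softmax read back as a Float32Array
const results = mapWithClassNameSketch(
  ['anime', 'real', 'others'],
  new Float32Array([0.8, 0.15, 0.05]),
)
console.log(topClassifyResultSketch(results).label) // anime
```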

Feature extraction functions
export async function getImageFeatures(options: {
  tf: typeof import('@tensorflow/tfjs-node')
  imageModel: ImageModel
  image: string | Tensor
  /** default: 'Identity:0' */
  outputNode?: string
}): Promise<{
  spatialFeatures: Tensor // e.g. [1, 7, 7, 160] - spatial feature map
  pooledFeatures: Tensor // e.g. [1, 1280] - global average pooled features
}>

/**
 * @description Get the name of the last spatial node in the model
 * Used internally by getImageFeatures to extract spatial features
 */
export function getLastSpatialNodeName(model: GraphModel): string
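`pooledFeatures` is described as global average pooled. As a reference for what that operation does, the sketch below averages an NHWC feature map of shape `[1 x H x W x C]` over its `H x W` positions to get `[1 x C]`. It uses a plain `Float32Array` instead of TensorFlow.js so it runs without dependencies, and `globalAveragePoolSketch` is a hypothetical name. Note that in the MobileNet example the `[1, 1280]` pooled output has more channels than the `[1, 7, 7, 160]` spatial map, so it comes from layers further along the model rather than from pooling that spatial map directly.

```typescript
// Hypothetical helper: global average pooling over an NHWC map with batch size 1.
// `data` holds height * width * channels values in row-major [h][w][c] order.
function globalAveragePoolSketch(
  data: Float32Array,
  height: number,
  width: number,
  channels: number,
): Float32Array {
  if (data.length !== height * width * channels) {
    throw new Error('data length does not match height * width * channels')
  }
  const pooled = new Float32Array(channels)
  // Sum every spatial position, channel by channel.
  for (let position = 0; position < height * width; position++) {
    for (let c = 0; c < channels; c++) {
      pooled[c] += data[position * channels + c]
    }
  }
  // Divide by the number of positions to get the mean per channel.
  for (let c = 0; c < channels; c++) {
    pooled[c] /= height * width
  }
  return pooled // shape [C]; reshape to [1, C] to match pooledFeatures
}

// A 2x2 map with 2 channels: channel 0 holds 1..4, channel 1 holds 10..40.
const featureMap = new Float32Array([1, 10, 2, 20, 3, 30, 4, 40])
console.log(Array.from(globalAveragePoolSketch(featureMap, 2, 2, 2))) // [ 2.5, 25 ]
```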

Model helper functions (hidden layer sizing)
/**
 * A factor to give larger hidden layer size for complex tasks:
 * - 1 for easy tasks
 * - 2-3 for medium difficulty tasks
 * - 4-5 for complex tasks
 *
 * Remark: setting the difficulty too high may result in over-fitting.
 */
export type Difficulty = number

/** Formula `hiddenSize = difficulty * sqrt(inputSize * outputSize)` */
export function calcHiddenLayerSize(options: {
  inputSize: number
  outputSize: number
  difficulty?: Difficulty
})

/** Inject one or more hidden layers where there is a large gap between the input size and the output size. */
export function injectHiddenLayers(options: {
  layers: number[]
  difficulty?: Difficulty
  numberOfHiddenLayers?: number
})
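Applying the documented formula `hiddenSize = difficulty * sqrt(inputSize * outputSize)` to the MobileNet example (1280 embedding features, 3 classes) gives a sense of scale. The sketch below is a hypothetical re-implementation; a default difficulty of 1 and rounding to the nearest integer are assumptions, so the library's exact numbers may differ slightly.

```typescript
// Hypothetical re-implementation of the documented formula.
// Assumptions: default difficulty is 1 and the result is rounded to an integer.
function calcHiddenLayerSizeSketch(options: {
  inputSize: number
  outputSize: number
  difficulty?: number
}): number {
  const difficulty = options.difficulty ?? 1
  return Math.round(difficulty * Math.sqrt(options.inputSize * options.outputSize))
}

// 1280 embedding features -> 3 classes
console.log(calcHiddenLayerSizeSketch({ inputSize: 1280, outputSize: 3 })) // 62
console.log(calcHiddenLayerSizeSketch({ inputSize: 1280, outputSize: 3, difficulty: 3 })) // 186
```

A size computed this way could then be passed as `hiddenLayers: [62]` to `loadImageClassifierModel`.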

File helper functions
/**
 * @description
 * - rename filename to content hash + extname;
 * - return list of (renamed) filenames
 */
export async function scanDir(dir: string): Promise<string[]>

export function isContentHash(file_or_filename: string): boolean

export async function saveFile(args: {
  dir: string
  content: Buffer
  mimeType: string
}): Promise<void>

export function hashContent(
  content: Buffer,
  encoding: BufferEncoding = 'hex',
): string

/** @returns new filename with content hash and extname */
export async function renameFileByContentHash(file: string): Promise<string>
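The content-hash naming used by `scanDir` and `renameFileByContentHash` can be sketched with Node's `crypto` module. The hash algorithm is not documented here; the sketch assumes SHA-256 with hex encoding (matching the `'hex'` default of `hashContent`), and `hashContentSketch`, `contentHashFilenameSketch` and `isContentHashSketch` are hypothetical names.

```typescript
import { createHash } from 'node:crypto'
import { basename, extname } from 'node:path'

// Assumption: SHA-256, hex-encoded (64 lowercase hex characters).
function hashContentSketch(content: Buffer): string {
  return createHash('sha256').update(content).digest('hex')
}

// New filename = content hash + original extname, e.g. "<hash>.jpg".
function contentHashFilenameSketch(originalFile: string, content: Buffer): string {
  return hashContentSketch(content) + extname(originalFile)
}

// True when the basename (dirname ignored) is a 64-char hex hash, with or without extname.
function isContentHashSketch(file: string): boolean {
  return /^[0-9a-f]{64}(\.[^./]+)?$/.test(basename(file))
}

const name = contentHashFilenameSketch('photos/cat.jpg', Buffer.from('hello'))
console.log(name)
// 2cf24dba5fb0a30e26e83b2ac5b9e29e1b161e5c1fa7425e73043362938b9824.jpg
console.log(isContentHashSketch('dataset/anime/' + name)) // true
```

Because the name depends only on the file's bytes, duplicate images collapse to one filename, and an embedding cache keyed by filename stays valid no matter which directory the file is in.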

CLI Tool

The package includes a command-line tool for downloading and converting TensorFlow.js models:

# Usage
npx download-tfjs-model <source> <output-dir>

# Examples
npx download-tfjs-model https://www.kaggle.com/models/google/mobilenet-v3/TfJs/large-100-224-feature-vector/1 ./browser-models/mobilenet-v3-large-100
npx download-tfjs-model ./hub-models/mobilenet-v2-035-128-feature-vector ./browser-models/mobilenet-v2-035

Supported sources:

  • TensorFlow Hub URLs
  • Kaggle model URLs
  • Local model directories
  • Local model.json files

Features:

  • Automatic model format detection (GraphModel vs LayersModel)
  • Standard weightsManifest format conversion
  • Source metadata preservation
  • Recursive directory creation

(Browser version) model functions and types
/**
 * @example `loadGraphModel({ url: 'saved_model/mobilenet-v3-large-100' })`
 */
export function loadGraphModel(options: { url: string }): Promise<tf.GraphModel>

/**
 * @example `loadLayersModel({ url: 'saved_model/emotion-classifier' })`
 */
export function loadLayersModel(options: {
  url: string
}): Promise<tf.LayersModel>

/**
 * @example ```
 * cachedLoadGraphModel({
 *   url: 'saved_model/mobilenet-v3-large-100',
 *   cacheUrl: 'indexeddb://mobilenet-v3-large-100',
 * })
 * ```
 */
export function cachedLoadGraphModel(options: {
  url: string
  cacheUrl: string
  checkForUpdates?: boolean
}): Promise<tf.GraphModel<string | tf.io.IOHandler>>

/**
 * @example ```
 * cachedLoadLayersModel({
 *   url: 'saved_model/emotion-classifier',
 *   cacheUrl: 'indexeddb://emotion-classifier',
 * })
 * ```
 */
export function cachedLoadLayersModel(options: {
  url: string
  cacheUrl: string
  checkForUpdates?: boolean
}): Promise<tf.LayersModel>

(Browser version) image model functions and types
export type ImageModel = {
  spec: ImageModelSpec
  model: tf.GraphModel<string | tf.io.IOHandler>
  fileEmbeddingCache: Map<string, tf.Tensor<tf.Rank>> | null
  checkCache: (url: string) => tf.Tensor | void
  loadImageCropped: (url: string) => Promise<tf.Tensor4D & tf.Tensor<tf.Rank>>
  imageUrlToEmbedding: (url: string) => Promise<tf.Tensor>
  imageFileToEmbedding: (file: File) => Promise<tf.Tensor>
  imageTensorToEmbedding: (imageTensor: ImageTensor) => tf.Tensor
}

/**
 * @description cache image embedding keyed by filename.
 * The dirname is ignored.
 * The filename is expected to be a content hash (with or without extname)
 */
export type EmbeddingCache = {
  get(url: string): number[] | null | undefined
  set(url: string, values: number[]): void
}

export function loadImageModel<Cache extends EmbeddingCache>(options: {
  url: string
  cacheUrl?: string
  checkForUpdates?: boolean
  aspectRatio?: CropAndResizeAspectRatio
  cache?: Cache | boolean
}): Promise<ImageModel>
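Any object with synchronous `get` and `set` methods satisfies `EmbeddingCache`. The sketch below is a minimal in-memory implementation that follows the doc comment above by keying entries on the filename only; `MemoryEmbeddingCache` is a hypothetical name. A persistent variant could back the same two methods with `localStorage` or a database.

```typescript
type EmbeddingCache = {
  get(url: string): number[] | null | undefined
  set(url: string, values: number[]): void
}

// Hypothetical in-memory cache keyed by filename; the dirname is ignored.
class MemoryEmbeddingCache implements EmbeddingCache {
  private store = new Map<string, number[]>()

  // Keep only the last path segment, e.g. "/a/b/<hash>.jpg" -> "<hash>.jpg".
  private key(url: string): string {
    return url.slice(url.lastIndexOf('/') + 1)
  }

  get(url: string): number[] | undefined {
    return this.store.get(this.key(url))
  }

  set(url: string, values: number[]): void {
    this.store.set(this.key(url), values)
  }
}

const cache = new MemoryEmbeddingCache()
cache.set('/uploads/2cf24dba.jpg', [0.1, 0.2, 0.3])
// Same filename under a different directory hits the same entry.
console.log(cache.get('https://example.com/images/2cf24dba.jpg')) // [ 0.1, 0.2, 0.3 ]
```

Such an object would be passed as `cache: new MemoryEmbeddingCache()` to `loadImageModel`.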

(Browser version) classifier functions and types
export type ClassifierModel = {
  baseModel: ImageModel
  classifierModel: tf.LayersModel | tf.Sequential
  classNames: string[]
  classifyImageUrl(url: string): Promise<ClassificationResult[]>
  classifyImageFile(file: File): Promise<ClassificationResult[]>
  classifyImageTensor(
    imageTensor: tf.Tensor3D | tf.Tensor4D,
  ): Promise<ClassificationResult[]>
  classifyImage(
    image: Parameters<typeof tf.browser.fromPixels>[0],
  ): Promise<ClassificationResult[]>
  classifyImageEmbedding(embedding: tf.Tensor): Promise<ClassificationResult[]>
  compile(): void
  train(
    options: tf.ModelFitArgs & {
      x: tf.Tensor<tf.Rank>
      y: tf.Tensor<tf.Rank>
      /** @description to calculate classWeight */
      classCounts?: number[]
    },
  ): Promise<tf.History>
}

export function loadImageClassifierModel(options: {
  baseModel: ImageModel
  hiddenLayers?: number[]
  modelUrl?: string
  cacheUrl?: string
  checkForUpdates?: boolean
  classNames: string[]
}): Promise<ClassifierModel>
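The `classCounts` option of `train` is documented only as being used "to calculate classWeight". One common way to do that, shown here as an assumption rather than the library's actual formula, is inverse-frequency ("balanced") weighting: `weight[i] = total / (numClasses * count[i])`, so rarer classes contribute more to the loss. The result has the `{ [classIndex]: weight }` shape that the `classWeight` field of `tf.ModelFitArgs` accepts; `balancedClassWeightSketch` is a hypothetical name.

```typescript
// Same shape as TensorFlow.js's ClassWeight: class index -> loss weight.
type ClassWeight = { [classIndex: number]: number }

// Hypothetical helper: inverse-frequency weighting, normalized so that a
// perfectly balanced dataset gets a weight of 1 for every class.
function balancedClassWeightSketch(classCounts: number[]): ClassWeight {
  const total = classCounts.reduce((sum, count) => sum + count, 0)
  const weights: ClassWeight = {}
  classCounts.forEach((count, classIndex) => {
    // Guard against empty classes, which would otherwise divide by zero.
    weights[classIndex] = count > 0 ? total / (classCounts.length * count) : 0
  })
  return weights
}

// 10 images of class 0 and 40 images of class 1
const weights = balancedClassWeightSketch([10, 40])
console.log(weights[0], weights[1]) // 2.5 0.625
```

Since `train` accepts `classCounts` directly, computing weights by hand like this is mainly useful for inspecting or overriding them.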

(Browser version) feature extraction functions
export async function getImageFeatures(options: {
  tf: typeof import('@tensorflow/tfjs-core')
  imageModel: ImageModel
  image: string | Tensor
  /** default: 'Identity:0' */
  outputNode?: string
  /** default: getLastSpatialNodeName(model) */
  spatialNode?: node
}): Promise<{
  /** e.g. `[1 x 7 x 7 x 160]` spatial feature map */
  spatialFeatures: Tensor
  /** e.g. `[1 x 1280]` global average pooled features */
  pooledFeatures: Tensor
}>
export async function getImageFeatures(options: {
  tf: typeof import('@tensorflow/tfjs-core')
  imageModel: ImageModel
  image: string | Tensor
  /** default: 'Identity:0' */
  outputNode?: string
  /** e.g. `imageModel.spatialNodesWithUniqueShapes` */
  spatialNodes: node[]
}): Promise<{
  /** list of spatial feature maps
   * e.g.
   * ```
   * [
   *   [1 x 56 x 56 x 24],
   *   [1 x 28 x 28 x 40],
   *   [1 x 14 x 14 x 80],
   *   [1 x 14 x 14 x 112],
   *   [1 x 7 x 7 x 160],
   * ]
   * ```
   *  */
  spatialFeatures: Tensor[]
  /** e.g. `[1 x 1280]` global average pooled features */
  pooledFeatures: Tensor
}>
