Skip to content

[pose-detection]Change enableSmoothing to modelConfig for MoveNet. #727

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
May 17, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pose-detection/src/blazepose_mediapipe/detector.ts
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ class BlazePoseMediaPipeDetector implements PoseDetector {
/**
* Loads the MediaPipe solution.
*
* @param modelConfig ModelConfig dictionary that contains parameters for
* @param modelConfig ModelConfig object that contains parameters for
* the BlazePose loading process. Please find more details of each parameters
* in the documentation of the `BlazePoseMediaPipeModelConfig` interface.
*/
Expand Down
28 changes: 27 additions & 1 deletion pose-detection/src/blazepose_mediapipe/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,38 @@ export interface BlazePoseModelConfig extends ModelConfig {
export interface BlazePoseEstimationConfig extends EstimationConfig {}

/**
* Mediapipe model loading config.
* Model parameters for BlazePose MediaPipe runtime
*
* `runtime`: Must set to be 'mediapipe'.
*
* `enableSmoothing`: Optional. A boolean indicating whether to use temporal
* filter to smooth the predicted keypoints. Defaults to True. The temporal
* filter relies on `performance.now()`. You can override this timestamp by
* passing in your own timestamp (in milliseconds) as the third parameter in
* `estimatePoses`.
*
* `modelType`: Optional. Possible values: 'lite'|'full'|'heavy'. Defaults to
* 'full'. The model accuracy increases from lite to heavy, while the inference
* speed decreases and memory footprint increases. The heavy variant is intended
* for applications that require high accuracy, while the lite variant is
* intended for latency-critical applications. The full variant is a balanced
* option.
*
* `solutionPath`: Optional. The path to where the wasm binary and model files
* are located.
*/
export interface BlazePoseMediaPipeModelConfig extends BlazePoseModelConfig {
runtime: 'mediapipe';
solutionPath?: string;
}

/**
* Pose estimation parameters for BlazePose MediaPipe runtime.
*
* `maxPoses`: Optional. Defaults to 1. BlazePose only supports 1 pose for now.
*
* `flipHorizontal`: Optional. Default to false. When image data comes from
* camera, the result has to flip horizontally.
*/
export interface BlazePoseMediaPipeEstimationConfig extends
BlazePoseEstimationConfig {}
16 changes: 4 additions & 12 deletions pose-detection/src/blazepose_tfjs/detector.ts
Original file line number Diff line number Diff line change
Expand Up @@ -96,14 +96,8 @@ class BlazePoseTfjsDetector implements PoseDetector {
* ImageData|HTMLImageElement|HTMLCanvasElement|HTMLVideoElement The input
* image to feed through the network.
*
* @param config Optional.
* maxPoses: Optional. Max number of poses to estimate.
* When maxPoses = 1, a single pose is detected, it is usually much more
* efficient than maxPoses > 1. When maxPoses > 1, multiple poses are
* detected.
*
* flipHorizontal: Optional. Default to false. When image data comes
* from camera, the result has to flip horizontally.
* @param estimationConfig Optional. See `BlazePoseTfjsEstimationConfig`
* documentation for detail.
*
* @param timestamp Optional. In milliseconds. This is useful when image is
* a tensor, which doesn't have timestamp info. Or to override timestamp
Expand Down Expand Up @@ -473,11 +467,9 @@ class BlazePoseTfjsDetector implements PoseDetector {
}

/**
* Loads the BlazePose model. The model to be loaded is configurable using the
* config dictionary `BlazePoseTfjsModelConfig`. Please find more details in
* the documentation of the `BlazePoseTfjsModelConfig`.
* Loads the BlazePose model.
*
* @param modelConfig ModelConfig dictionary that contains parameters for
* @param modelConfig ModelConfig object that contains parameters for
* the BlazePose loading process. Please find more details of each parameters
* in the documentation of the `BlazePoseTfjsModelConfig` interface.
*/
Expand Down
27 changes: 26 additions & 1 deletion pose-detection/src/blazepose_tfjs/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,24 @@
import {BlazePoseEstimationConfig, BlazePoseModelConfig} from '../blazepose_mediapipe/types';

/**
* Additional model parameters for BlazePose Tfjs backend
* Model parameters for BlazePose TFJS runtime.
*
* `runtime`: Must set to be 'tfjs'.
*
* `enableSmoothing`: Optional. A boolean indicating whether to use temporal
* filter to smooth the predicted keypoints. Defaults to True. The temporal
* filter relies on the currentTime field of the HTMLVideoElement. You can
* override this timestamp by passing in your own timestamp (in milliseconds)
* as the third parameter in `estimatePoses`. This is useful when the input is
* a tensor, which doesn't have the currentTime field. Or in testing, to
* simulate different FPS.
*
* `modelType`: Optional. Possible values: 'lite'|'full'|'heavy'. Defaults to
* 'full'. The model accuracy increases from lite to heavy, while the inference
* speed decreases and memory footprint increases. The heavy variant is intended
* for applications that require high accuracy, while the lite variant is
* intended for latency-critical applications. The full variant is a balanced
* option.
*
* `detectorModelUrl`: Optional. An optional string that specifies custom url of
* the detector model. This is useful for area/countries that don't have access
Expand All @@ -34,5 +51,13 @@ export interface BlazePoseTfjsModelConfig extends BlazePoseModelConfig {
landmarkModelUrl?: string;
}

/**
* Pose estimation parameters for BlazePose TFJS runtime.
*
* `maxPoses`: Optional. Defaults to 1. BlazePose only supports 1 pose for now.
*
* `flipHorizontal`: Optional. Default to false. When image data comes from
* camera, the result has to flip horizontally.
*/
export interface BlazePoseTfjsEstimationConfig extends
BlazePoseEstimationConfig {}
6 changes: 3 additions & 3 deletions pose-detection/src/movenet/constants.ts
Original file line number Diff line number Diff line change
Expand Up @@ -32,12 +32,12 @@ export const MOVENET_SINGLEPOSE_THUNDER_RESOLUTION = 256;

// The default configuration for loading MoveNet.
export const MOVENET_CONFIG: MoveNetModelConfig = {
modelType: SINGLEPOSE_LIGHTNING
modelType: SINGLEPOSE_LIGHTNING,
enableSmoothing: true
};

export const MOVENET_SINGLE_POSE_ESTIMATION_CONFIG: MoveNetEstimationConfig = {
maxPoses: 1,
enableSmoothing: true
maxPoses: 1
};

export const KEYPOINT_FILTER_CONFIG = {
Expand Down
6 changes: 3 additions & 3 deletions pose-detection/src/movenet/detector.ts
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ class MoveNetDetector implements PoseDetector {
InputResolution = {height: 0, width: 0};
private readonly keypointIndexByName =
getKeypointIndexByName(SupportedModels.MoveNet);
private readonly enableSmoothing: boolean;

// Global states.
private keypointsFilter = new KeypointsOneEuroFilter(KEYPOINT_FILTER_CONFIG);
Expand All @@ -61,6 +62,7 @@ class MoveNetDetector implements PoseDetector {
this.modelInputResolution.width = MOVENET_SINGLEPOSE_THUNDER_RESOLUTION;
this.modelInputResolution.height = MOVENET_SINGLEPOSE_THUNDER_RESOLUTION;
}
this.enableSmoothing = config.enableSmoothing;
}

/**
Expand Down Expand Up @@ -135,8 +137,6 @@ class MoveNetDetector implements PoseDetector {
* @param config Optional. A configuration object with the following
* properties:
* `maxPoses`: Optional. Has to be set to 1.
* `enableSmoothing`: Optional. Defaults to `true`. When enabled, a temporal
* smoothing filter will be used on the keypoint locations to reduce jitter.
*
* @param timestamp Optional. In milliseconds. This is useful when image is
* a tensor, which doesn't have timestamp info. Or to override timestamp
Expand Down Expand Up @@ -214,7 +214,7 @@ class MoveNetDetector implements PoseDetector {

// Apply the sequential filter before estimating the cropping area to make
// it more stable.
if (timestamp != null && estimationConfig.enableSmoothing) {
if (timestamp != null && this.enableSmoothing) {
keypoints = this.keypointsFilter.apply(keypoints, timestamp);
}

Expand Down
8 changes: 4 additions & 4 deletions pose-detection/src/movenet/detector_utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,10 @@ export function validateModelConfig(modelConfig: MoveNetModelConfig):
`Should be one of ${VALID_MODELS}`);
}

if (config.enableSmoothing == null) {
config.enableSmoothing = true;
}

return config;
}

Expand All @@ -47,9 +51,5 @@ export function validateEstimationConfig(
throw new Error(`Invalid maxPoses ${config.maxPoses}. Should be 1.`);
}

if (config.enableSmoothing == null) {
config.enableSmoothing = true;
}

return config;
}
25 changes: 16 additions & 9 deletions pose-detection/src/movenet/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,30 +17,37 @@
import {EstimationConfig, ModelConfig} from '../types';

/**
* Additional MoveNet model loading config.
* MoveNet model loading config.
*
* 'modelType': Optional. The type of MoveNet model to load, Lighting or
* `enableSmoothing`: Optional. A boolean indicating whether to use temporal
* filter to smooth the predicted keypoints. Defaults to True. The temporal
* filter relies on the currentTime field of the HTMLVideoElement. You can
* override this timestamp by passing in your own timestamp (in milliseconds)
* as the third parameter in `estimatePoses`. This is useful when the input is
* a tensor, which doesn't have the currentTime field. Or in testing, to
* simulate different FPS.
*
* `modelType`: Optional. The type of MoveNet model to load, Lighting or
* Thunder. Defaults to Lightning. Lightning is a lower capacity model that can
* run >50FPS on most modern laptops while achieving good performance. Thunder
* is A higher capacity model that performs better prediction quality while
* is a higher capacity model that performs better prediction quality while
* still achieving real-time (>30FPS) speed. Thunder will lag behind the
* lightning, but it will pack a punch.
*
* `modelUrl`: Optional. An optional string that specifies custom url of the
* model. This is useful for area/countries that don't have access to the model
* hosted on TF Hub.
* hosted on TF Hub. If not provided, it will load the model specified by
* `modelType` from tf.hub.
*/
export interface MoveNetModelConfig extends ModelConfig {
enableSmoothing?: boolean;
modelType?: string;
modelUrl?: string;
}

/**
* MoveNet Specific Inference Config.
*
* `enableSmoothing`: Optional. Defaults to 'true'. When enabled, a temporal
* smoothing filter will be used on the keypoint locations to reduce jitter.
* `maxPoses`: Optional. Defaults to 1. MoveNet only supports 1 pose for now.
*/
export interface MoveNetEstimationConfig extends EstimationConfig {
enableSmoothing?: boolean;
}
export interface MoveNetEstimationConfig extends EstimationConfig {}
2 changes: 1 addition & 1 deletion pose-detection/yarn.lock
Original file line number Diff line number Diff line change
Expand Up @@ -880,7 +880,7 @@
resolved "https://registry.npmjs.org/@istanbuljs/schema/-/schema-0.1.3.tgz#e45e384e4b8ec16bce2fd903af78450f6bf7ec98"
integrity sha512-ZXRY4jNvVgSVQ8DL3LTcakaAtXwTVUxE81hslsyD2AtoXW/wVob10HkOJ1X/pAlcI7D+2YoZKg5do8G/w6RYgA==

"@mediapipe/pose@0.3.x":
"@mediapipe/pose@~0.3.0":
version "0.3.1620676110"
resolved "https://registry.npmjs.org/@mediapipe/pose/-/pose-0.3.1620676110.tgz#35d522196db581e213259fbc42290a577bd239ed"
integrity sha512-macv7l2fuaqLArYvUSuoUrxoOVut4KEFHEqPU0i6SLpfohJFhBRgpJXiy8w+KTWkdh0e7cTLkPTESiV2+u9tNA==
Expand Down