Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Do not implement inference model in high-level model classes #509

Merged
merged 2 commits into from
Jan 2, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ public abstract class ModelHub {
* @param [loadingMode] Strategy of existing model use-case handling.
* @return Raw model without weights. Needs in compilation and weights loading before usage.
*/
public abstract fun <T : InferenceModel, U : InferenceModel> loadModel(
public abstract fun <T : InferenceModel, U> loadModel(
modelType: ModelType<T, U>,
loadingMode: LoadingMode = LoadingMode.SKIP_LOADING_IF_EXISTS
): T
Expand All @@ -40,7 +40,7 @@ public abstract class ModelHub {
* @param [loadingMode] Strategy of existing model use-case handling.
* @return Pretrained model.
*/
public fun <T : InferenceModel, U : InferenceModel> loadPretrainedModel(
public fun <T : InferenceModel, U> loadPretrainedModel(
modelType: ModelType<T, U>,
loadingMode: LoadingMode = LoadingMode.SKIP_LOADING_IF_EXISTS
): U {
Expand All @@ -50,7 +50,7 @@ public abstract class ModelHub {
/**
* This operator equivalent to [loadPretrainedModel].
*/
public operator fun <T : InferenceModel, U : InferenceModel> get(modelType: ModelType<T, U>): U {
public operator fun <T : InferenceModel, U> get(modelType: ModelType<T, U>): U {
return loadPretrainedModel(modelType = modelType)
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ import org.jetbrains.kotlinx.dl.api.preprocessing.Operation
* @param T the type of the basic model for common functionality.
* @param U the type of the pre-trained model for usage in Easy API.
*/
public interface ModelType<T : InferenceModel, U : InferenceModel> {
public interface ModelType<T : InferenceModel, U> {
/** Relative path to model for local and S3 buckets storages. */
public val modelRelativePath: String

Expand Down
133 changes: 0 additions & 133 deletions examples/src/test/kotlin/examples/onnx/ModelCopyingTestSuite.kt

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ import org.jetbrains.kotlinx.dl.api.summary.ModelWithSummary
public abstract class ImageRecognitionModelBase<I>(
protected val internalModel: InferenceModel,
protected val modelKindDescription: String? = null
) : InferenceModel by internalModel, ModelWithSummary {
) : ModelWithSummary, AutoCloseable {
/**
* Preprocessing operation specific to this model.
*/
Expand Down Expand Up @@ -70,4 +70,6 @@ public abstract class ImageRecognitionModelBase<I>(
ModelHubModelSummary(EmptySummary(), modelKindDescription)
}
}

override fun close(): Unit = internalModel.close()
}
Original file line number Diff line number Diff line change
Expand Up @@ -63,20 +63,6 @@ public class ImageRecognitionModel(
return predictObject(ImageConverter.toBufferedImage(imageFile))
}

override fun copy(
copiedModelName: String?,
saveOptimizerState: Boolean,
copyWeights: Boolean
): ImageRecognitionModel {
return ImageRecognitionModel(
internalModel.copy(copiedModelName, saveOptimizerState, copyWeights),
inputColorMode,
channelsFirst,
preprocessor,
modelKindDescription
)
}

public companion object {
/**
* Creates a preprocessing [Operation] which converts given [BufferedImage] to [FloatData].
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ public object ONNXModels {
}

/** Pose detection models. */
public sealed class PoseDetection<U : InferenceModel>(override val modelRelativePath: String) :
public sealed class PoseDetection<U>(override val modelRelativePath: String) :
OnnxModelType<U> {
/**
* This model is a convolutional neural network model that runs on RGB images and predicts human joint locations of a single person.
Expand Down Expand Up @@ -164,7 +164,7 @@ public object ONNXModels {
}

/** Object detection models and preprocessing. */
public sealed class ObjectDetection<U : InferenceModel>(override val modelRelativePath: String) :
public sealed class ObjectDetection<U>(override val modelRelativePath: String) :
OnnxModelType<U> {
/**
* This model is a real-time neural network for object detection that detects 90 different classes
Expand Down Expand Up @@ -291,7 +291,7 @@ public object ONNXModels {
}

/** Face alignment models */
public sealed class FaceAlignment<U : InferenceModel> : OnnxModelType<U> {
public sealed class FaceAlignment<U> : OnnxModelType<U> {
/**
* This model is a neural network for face alignment that take RGB images of faces as input and produces coordinates of 106 faces landmarks.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,10 +43,10 @@ public class ONNXModelHub(private val context: Context) : ModelHub() {
* @param [loadingMode] it's ignored
*/
@Suppress("UNCHECKED_CAST")
override fun <T : InferenceModel, U : InferenceModel> loadModel(
override fun <T : InferenceModel, U> loadModel(
modelType: ModelType<T, U>,
loadingMode: LoadingMode, /* unused */
): T {
return loadModel(modelType as OnnxModelType<U>) as T
return loadModel(modelType as OnnxModelType<*>) as T
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ import org.jetbrains.kotlinx.dl.onnx.inference.doWithRotation
public class FaceDetectionModel(
override val internalModel: OnnxInferenceModel,
modelKindDescription: String? = null
) : FaceDetectionModelBase<Bitmap>(modelKindDescription), CameraXCompatibleModel, InferenceModel by internalModel {
) : FaceDetectionModelBase<Bitmap>(modelKindDescription), CameraXCompatibleModel {
override var targetRotation: Int = 0
override val preprocessing: Operation<Bitmap, FloatData>
get() = pipeline<Bitmap>()
Expand All @@ -40,12 +40,7 @@ public class FaceDetectionModel(
.toFloatArray { layout = TensorLayout.NCHW }
.call(ONNXModels.FaceDetection.defaultPreprocessor)

override fun copy(copiedModelName: String?, saveOptimizerState: Boolean, copyWeights: Boolean): InferenceModel {
return FaceDetectionModel(
internalModel.copy(copiedModelName, saveOptimizerState, copyWeights),
modelKindDescription
)
}
override fun close(): Unit = internalModel.close()
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ import org.jetbrains.kotlinx.dl.onnx.inference.doWithRotation
public class Fan2D106FaceAlignmentModel(
override val internalModel: OnnxInferenceModel,
modelKindDescription: String? = null
) : FaceAlignmentModelBase<Bitmap>(modelKindDescription), CameraXCompatibleModel, InferenceModel by internalModel {
) : FaceAlignmentModelBase<Bitmap>(modelKindDescription), CameraXCompatibleModel {
override val outputName: String = "fc1"
override var targetRotation: Int = 0

Expand All @@ -41,12 +41,7 @@ public class Fan2D106FaceAlignmentModel(
.rotate { degrees = targetRotation.toFloat() }
.toFloatArray { layout = TensorLayout.NCHW }

override fun copy(copiedModelName: String?, saveOptimizerState: Boolean, copyWeights: Boolean): InferenceModel {
return Fan2D106FaceAlignmentModel(
internalModel.copy(copiedModelName, saveOptimizerState, copyWeights),
modelKindDescription
)
}
override fun close(): Unit = internalModel.close()
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ import org.jetbrains.kotlinx.dl.onnx.inference.doWithRotation
public class SSDLikeModel(
override val internalModel: OnnxInferenceModel, metadata: SSDLikeModelMetadata,
modelKindDescription: String? = null
) : SSDLikeModelBase<Bitmap>(metadata, modelKindDescription), CameraXCompatibleModel, InferenceModel by internalModel {
) : SSDLikeModelBase<Bitmap>(metadata, modelKindDescription), CameraXCompatibleModel {

override val classLabels: Map<Int, String> = Coco.V2017.labels(zeroIndexed = true)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,7 @@ import org.jetbrains.kotlinx.dl.onnx.inference.executionproviders.ExecutionProvi
public class SinglePoseDetectionModel(
override val internalModel: OnnxInferenceModel,
modelKindDescription: String? = null
) : SinglePoseDetectionModelBase<Bitmap>(modelKindDescription), InferenceModel by internalModel,
CameraXCompatibleModel {
) : SinglePoseDetectionModelBase<Bitmap>(modelKindDescription), CameraXCompatibleModel {
override val preprocessing: Operation<Bitmap, FloatData>
get() = pipeline<Bitmap>()
.resize {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,12 @@

package org.jetbrains.kotlinx.dl.onnx.inference

import org.jetbrains.kotlinx.dl.api.inference.InferenceModel
import org.jetbrains.kotlinx.dl.api.inference.loaders.ModelType

/**
* Base type for [OnnxInferenceModel].
*/
public interface OnnxModelType<U : InferenceModel> : ModelType<OnnxInferenceModel, U> {
public interface OnnxModelType<U> : ModelType<OnnxInferenceModel, U> {
/**
* Shape of the input accepted by this model, without batch size.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,11 +47,11 @@ public class ONNXModelHub(public val cacheDirectory: File) : ModelHub() {
* @return An instance of [OnnxInferenceModel].
*/
@Suppress("UNCHECKED_CAST")
public override fun <T : InferenceModel, U : InferenceModel> loadModel(
public override fun <T : InferenceModel, U> loadModel(
modelType: ModelType<T, U>,
loadingMode: LoadingMode
): T {
return loadModel(modelType as OnnxModelType<U>, ExecutionProvider.CPU(), loadingMode = loadingMode) as T
return loadModel(modelType as OnnxModelType<*>, ExecutionProvider.CPU(), loadingMode = loadingMode) as T
}

private fun getONNXModelFile(modelFile: String, loadingMode: LoadingMode): File {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -579,7 +579,7 @@ public object ONNXModels {
}

/** Object detection models and preprocessing. */
public sealed class ObjectDetection<U : InferenceModel>(override val modelRelativePath: String) :
public sealed class ObjectDetection<U>(override val modelRelativePath: String) :
OnnxModelType<U> {
/**
* This model is a real-time neural network for object detection that detects 80 different classes
Expand Down Expand Up @@ -895,7 +895,7 @@ public object ONNXModels {
}

/** Face alignment models and preprocessing. */
public sealed class FaceAlignment<U : InferenceModel>(override val modelRelativePath: String) :
public sealed class FaceAlignment<U>(override val modelRelativePath: String) :
OnnxModelType<U> {
/**
* This model is a neural network for face alignment that take RGB images of faces as input and produces coordinates of 106 faces landmarks.
Expand All @@ -916,7 +916,7 @@ public object ONNXModels {
}

/** Pose detection models. */
public sealed class PoseDetection<U : InferenceModel>(override val modelRelativePath: String) :
public sealed class PoseDetection<U>(override val modelRelativePath: String) :
OnnxModelType<U> {
/**
* This model is a convolutional neural network model that runs on RGB images and predicts human joint locations of a single person.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ import java.awt.image.BufferedImage
public class FaceDetectionModel(
override val internalModel: OnnxInferenceModel,
modelKindDescription: String? = null
) : FaceDetectionModelBase<BufferedImage>(modelKindDescription), InferenceModel by internalModel {
) : FaceDetectionModelBase<BufferedImage>(modelKindDescription) {
override val preprocessing: Operation<BufferedImage, FloatData>
get() = pipeline<BufferedImage>()
.resize {
Expand All @@ -38,10 +38,5 @@ public class FaceDetectionModel(
.toFloatArray { }
.call(ONNXModels.FaceDetection.defaultPreprocessor)

override fun copy(copiedModelName: String?, saveOptimizerState: Boolean, copyWeights: Boolean): InferenceModel {
return FaceDetectionModel(
internalModel.copy(copiedModelName, saveOptimizerState, copyWeights),
modelKindDescription
)
}
override fun close(): Unit = internalModel.close()
}
Loading