Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Option to calculate and view score & toggle for model switch #2

Closed
wants to merge 12 commits into from
Prev Previous commit
Next Next commit
Add option to toggle model
  • Loading branch information
ua741 committed Oct 5, 2023
commit 1a4b6fceec4dff7b7a2fba7ab8fa8c4bf888f1c1
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,11 @@ class ImageEncoderONNX(private val context: MainActivity, useQuantizedModel: Boo
ortSession = createOrtSession()
}

fun close() {
ortSession?.close()
ortEnv?.close()
}

private fun createOrtSession(): OrtSession? {
val p = assetFilePath(context, modelPath) ?: return null
return ortEnv?.createSession(p)
Expand Down
26 changes: 23 additions & 3 deletions app/src/main/java/com/example/mycomposeapplication/MainActivity.kt
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
package com.example.mycomposeapplication

import android.graphics.BitmapFactory
import android.media.Image
import android.os.Bundle
import android.util.Log
import android.widget.Toast
import androidx.activity.ComponentActivity
import androidx.activity.compose.setContent
import androidx.compose.foundation.Image
import androidx.compose.foundation.layout.Arrangement
import androidx.compose.foundation.layout.Column
import androidx.compose.foundation.layout.Row
import androidx.compose.foundation.layout.fillMaxSize
import androidx.compose.foundation.layout.fillMaxWidth
import androidx.compose.foundation.layout.height
Expand Down Expand Up @@ -112,6 +113,23 @@ class MainActivity : ComponentActivity() {

Text(text = encodeImageState1.value)
Text(text = encodeImageState2.value)
Row(
horizontalArrangement = Arrangement.SpaceBetween,
modifier = Modifier.fillMaxWidth()
) {
Text("Use Quantized Model")
Switch(
checked = useQuantizedModel.value,
onCheckedChange = { newState ->
useQuantizedModel.value = newState
textEncoderONNX?.close()
imageEncoderONNX?.close()
textEncoderONNX = null
imageEncoderONNX = null
// You can also call any other logic that needs to run when the state changes
}
)
}
Button(onClick = { testScoring() }) {
Text(text = "testScore")
}
Expand All @@ -134,6 +152,7 @@ class MainActivity : ComponentActivity() {
private var encodeImageState2: MutableState<String> = mutableStateOf("None")
private var scoreState: MutableState<String> = mutableStateOf("")
private var imagePathState: MutableState<String> = mutableStateOf("")
private var useQuantizedModel: MutableState<Boolean> = mutableStateOf(true)

private var textEncoderONNX: TextEncoderONNX? = null
private var imageEncoderONNX: ImageEncoderONNX? = null
Expand All @@ -154,6 +173,7 @@ class MainActivity : ComponentActivity() {
}
}


private fun testImageEncoder() {
lifecycleScope.launch {
if (imageEncoderONNX == null) {
Expand Down Expand Up @@ -216,7 +236,7 @@ class MainActivity : ComponentActivity() {
if (imageEncoderONNX == null) {
encodeImageState1.value = "Loading ImageEncoder ONNX ..."
encodeImageState2.value = "Loading ImageEncoder ONNX ..."
imageEncoderONNX = ImageEncoderONNX(context = this@MainActivity)
imageEncoderONNX = ImageEncoderONNX(context = this@MainActivity, useQuantizedModel.value)
encodeImageState1.value = "Loading ImageEncoder ONNX done"
encodeImageState2.value = "Loading ImageEncoder ONNX done"
}
Expand All @@ -227,7 +247,7 @@ class MainActivity : ComponentActivity() {
withContext(Dispatchers.Default) {
if (textEncoderONNX == null) {
Log.i(this.javaClass.canonicalName, "Starting loading textEncoder")
textEncoderONNX = TextEncoderONNX(context = this@MainActivity)
textEncoderONNX = TextEncoderONNX(context = this@MainActivity, useQuantizedModel.value)
Log.i(this.javaClass.canonicalName, "Done loading textEncoder")
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,11 @@ class TextEncoderONNX(private val context: Context, useQuantizedModel: Boolean =
tokenizer = BPETokenizer(context)
}

fun close() {
ortSession?.close()
ortEnv?.close()
}

private fun createOrtSession(): OrtSession? {
val p = assetFilePath(context, modelPath) ?: return null
return ortEnv?.createSession(p)
Expand Down