Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Option to calculate and view score & toggle for model switch #2

Closed
wants to merge 12 commits into from
Prev Previous commit
Next Next commit
Show image + button to compute and display scores
  • Loading branch information
ua741 committed Oct 5, 2023
commit dd623d185eaf35b8a761c02271c5d93a9ff79f33
Original file line number Diff line number Diff line change
@@ -1,17 +1,23 @@
package com.example.mycomposeapplication

import android.graphics.BitmapFactory
import android.media.Image
import android.os.Bundle
import android.util.Log
import android.widget.Toast
import androidx.activity.ComponentActivity
import androidx.activity.compose.setContent
import androidx.compose.foundation.Image
import androidx.compose.foundation.layout.Column
import androidx.compose.foundation.layout.fillMaxSize
import androidx.compose.foundation.layout.fillMaxWidth
import androidx.compose.foundation.layout.height
import androidx.compose.foundation.layout.padding
import androidx.compose.material.*
import androidx.compose.runtime.*
import androidx.compose.ui.Modifier
import androidx.compose.ui.graphics.asImageBitmap
import androidx.compose.ui.layout.ContentScale
import androidx.compose.ui.tooling.preview.Preview
import androidx.compose.ui.unit.dp
import androidx.lifecycle.lifecycleScope
Expand All @@ -22,6 +28,9 @@ import kotlinx.coroutines.withContext
import java.io.File
import java.util.*
import kotlin.concurrent.thread
import androidx.compose.foundation.rememberScrollState
import androidx.compose.foundation.verticalScroll


private val imageList =
listOf(
Expand All @@ -31,6 +40,19 @@ private val imageList =
"image@4000px-large.jpg",
"image-large-17.2MB.jpg"
)
private val textList =
listOf(
"A bird flying in the sky, cloudy",
"A helicopter in water",
"cat",
"cat on the pavement",
"pink rose in the pond",
"red cloth inside blue bag",
"white brown cat on pavement with shadow",
"keyboard",
"white computer keyboard keys",
"dog face in cold weather"
)

class MainActivity : ComponentActivity() {

Expand All @@ -45,7 +67,13 @@ class MainActivity : ComponentActivity() {
.fillMaxSize(),
color = MaterialTheme.colors.background
) {
Column {
// Create a scroll state
val scrollState = rememberScrollState()

// Wrap the Column with verticalScroll
Column(
modifier = Modifier.verticalScroll(scrollState)
) {
Greeting(selectedImage.value)
OutlinedButton(onClick = { imageListExpanded.value = true }) {
Text(text = "selectImage")
Expand Down Expand Up @@ -84,11 +112,17 @@ class MainActivity : ComponentActivity() {

Text(text = encodeImageState1.value)
Text(text = encodeImageState2.value)
Button(onClick = { testScoring() }) {
Text(text = "testScore")
}
DisplayImage(imagePath = imagePathState.value)
Text(text = scoreState.value)
}
}
}
}
imagePath = assetFilePath(this, selectedImage.value).toString()
imagePathState.value = imagePath
}

var imageListExpanded = mutableStateOf(false)
Expand All @@ -98,6 +132,8 @@ class MainActivity : ComponentActivity() {
private var encodeImageCost: MutableState<Long> = mutableStateOf(0L)
private var encodeImageState1: MutableState<String> = mutableStateOf("None")
private var encodeImageState2: MutableState<String> = mutableStateOf("None")
private var scoreState: MutableState<String> = mutableStateOf("")
private var imagePathState: MutableState<String> = mutableStateOf("")

private var textEncoderONNX: TextEncoderONNX? = null
private var imageEncoderONNX: ImageEncoderONNX? = null
Expand Down Expand Up @@ -138,6 +174,44 @@ class MainActivity : ComponentActivity() {
}
}

private fun testScoring() {
lifecycleScope.launch {
scoreState.value = ""
if (imageEncoderONNX == null) {
loadImageEncoderONNX()
}
if (textEncoderONNX == null) {
loadTextEncoderONNX()
}
val time = System.currentTimeMillis()
val bitmap = loadThumbnail(this@MainActivity, imagePath)
Log.d("loadImage", "${System.currentTimeMillis() - time} ms")
saveBitMap(this@MainActivity, bitmap, "decodeSampledBitmapFromFile")
val imageEmbedding = imageEncoderONNX?.encode(bitmap)
encodeImageCost.value = System.currentTimeMillis() - time

// Create a mutable list to hold text and its corresponding score
val scoreList = mutableListOf<Pair<String, Double>>()
for (text in textList) {
val textEmbedding = textEncoderONNX?.encode(text)
if (imageEmbedding != null && textEmbedding != null) {
val score = computeScore(imageEmbedding[0], textEmbedding[0])
scoreList.add(Pair(text, score))
}
}
scoreList.sortBy { -1.0 * it.second }
val displayString = buildString {
for ((text, score) in scoreList) {
append("\n%.4f".format(score))
append(" : $text")
// Limit to 4 digits
}
}
scoreState.value = displayString

}
}

private fun loadImageEncoderONNX() {
if (imageEncoderONNX == null) {
encodeImageState1.value = "Loading ImageEncoder ONNX ..."
Expand Down Expand Up @@ -269,8 +343,11 @@ class MainActivity : ComponentActivity() {
imageListExpanded.value = false
selectedImage.value = im
imagePath = assetFilePath(this@MainActivity, selectedImage.value).toString()
imagePathState.value = imagePath
if (imagePath.isEmpty()) {
Toast.makeText(this, "图片加载失败!", Toast.LENGTH_SHORT).show()
} else {
scoreState.value = ""
}
}
}
Expand All @@ -279,7 +356,24 @@ class MainActivity : ComponentActivity() {
fun Greeting(name: String) {
Text(text = "Selected image: $name")
}
@Composable
fun DisplayImage(imagePath: String) {
val bitmap = BitmapFactory.decodeFile(imagePath)
val imageBitmap = bitmap?.asImageBitmap()

if (imageBitmap != null) {
Image(
bitmap = imageBitmap,
contentDescription = null, // decorative
modifier = Modifier
.fillMaxWidth()
.height(200.dp),
contentScale = ContentScale.Crop
)
} else {
// Handle the case where the image could not be loaded
}
}
@Preview(showBackground = true)
@Composable
fun DefaultPreview() {
Expand Down