Skip to content

feat: select model in android super resolution app #10

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -113,3 +113,4 @@ dmypy.json
# Hub exports
**/*.mlmodel
**/*.tflite
apps/android/SuperResolution/src/main/res/values/models.xml
35 changes: 32 additions & 3 deletions apps/android/SuperResolution/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,7 @@ android {
}

preBuild.doFirst {
if (!file("./src/main/assets/" + project.properties['superresolution_tfLiteModelAsset']).exists()) {
throw new RuntimeException(missingModelErrorMsg)
}
generateModelList()

for (int i = 1; i <= 2; ++i) {
String filename = "./src/main/assets/images/Sample${i}.jpg"
Expand Down Expand Up @@ -61,3 +59,34 @@ dependencies {
if (System.getProperty("user.dir") != project.rootDir.path) {
throw new RuntimeException("This project should be opened from the `android` directory (parent of SuperResolution directory), NOT the SuperResolution directory.")
}


def generateModelList() {
def assetsDir = file("${projectDir}/src/main/assets")
def outputDir = file("${projectDir}/src/main/res/values")
def outputFile = file("${outputDir}/models.xml")
if (!outputDir.exists()) {
throw new GradleException("res directory not exist: ${outputDir}")
}
if (!assetsDir.exists()) {
throw new GradleException("assets directory not exist: ${assetsDir}")
}

def files = []
if (assetsDir.exists()) {
files = assetsDir.listFiles().findAll { it.name.endsWith('.tflite') || it.name.endsWith('.bin') }.collect { it.name }
}

def xmlContent = """<?xml version="1.0" encoding="utf-8"?>
<resources>
<string-array name="model_files">
"""
files.each { fileName ->
xmlContent += " <item>${fileName}</item>\n"
}
xmlContent += """ </string-array>
</resources>
"""
outputFile.text = xmlContent

}
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
import android.os.Handler;
import android.os.Looper;
import android.provider.MediaStore;
import android.text.TextUtils;
import android.util.Log;
import android.view.View;
import android.widget.AdapterView;
import android.widget.ArrayAdapter;
Expand Down Expand Up @@ -51,7 +53,7 @@ public class MainActivity extends AppCompatActivity {
ImageView selectedImageView;
TextView inferenceTimeView;
TextView predictionTimeView;
Spinner imageSelector;
Spinner imageSelector, modelSelector;
Button predictionButton;
ActivityResultLauncher<Intent> selectImageResultLauncher;
private final String fromGalleryImageSelectorOption = "From Gallery";
Expand All @@ -62,6 +64,8 @@ public class MainActivity extends AppCompatActivity {
"Sample2.jpg",
fromGalleryImageSelectorOption};

private String[] modelSelectorOptions;

// Inference Elements
Bitmap selectedImage = null; // Raw image, not resized
private SuperResolution defaultDelegateUpscaler;
Expand Down Expand Up @@ -91,6 +95,7 @@ protected void onCreate(Bundle savedInstanceState) {
allDelegatesButton = (RadioButton)findViewById(R.id.defaultDelegateRadio);

imageSelector = (Spinner) findViewById((R.id.imageSelector));
modelSelector = (Spinner) findViewById((R.id.modelSelector));
inferenceTimeView = (TextView)findViewById(R.id.inferenceTimeResultText);
predictionTimeView = (TextView)findViewById(R.id.predictionTimeResultText);
predictionButton = (Button)findViewById(R.id.runModelButton);
Expand Down Expand Up @@ -122,6 +127,26 @@ public void onItemSelected(AdapterView<?> parent, View view, int position, long
public void onNothingSelected(AdapterView<?> parent) { }
});

// Setup Model Selector Dropdown
modelSelectorOptions = getResources().getStringArray(R.array.model_files);
ArrayAdapter modelAdapter = new ArrayAdapter(this, android.R.layout.simple_spinner_item, modelSelectorOptions);
modelAdapter.setDropDownViewResource(android.R.layout.simple_spinner_dropdown_item);
modelSelector.setAdapter(modelAdapter);
modelSelector.setOnItemSelectedListener(new AdapterView.OnItemSelectedListener() {
@Override
public void onItemSelected(AdapterView<?> parent, View view, int position, long id) {
// Load selected models from assets
((TextView) view).setTextColor(getResources().getColor(R.color.white));
((TextView) view).setEllipsize(TextUtils.TruncateAt.END);

// Exit the UI thread and instantiate the model in the background.
String modelName = parent.getItemAtPosition(position).toString();
createTFLiteUpscalerAsync(modelName);
}

@Override
public void onNothingSelected(AdapterView<?> parent) { }
});
// Setup Image Selection from Phone Gallery
selectImageResultLauncher = registerForActivityResult(
new ActivityResultContracts.StartActivityForResult(),
Expand Down Expand Up @@ -155,9 +180,6 @@ public void onNothingSelected(AdapterView<?> parent) { }
// Setup button callback
predictionButton.setOnClickListener((view) -> updatePredictionDataAsync());

// Exit the UI thread and instantiate the model in the background.
createTFLiteUpscalerAsync();

// Enable image selection
enableImageSelector();
enableDelegateSelectionButtons();
Expand All @@ -176,12 +198,15 @@ void setInferenceUIEnabled(boolean enabled) {
predictionButton.setAlpha(0.5f);
imageSelector.setEnabled(false);
imageSelector.setAlpha(0.5f);
modelSelector.setEnabled(false);
modelSelector.setAlpha(0.5f);
cpuOnlyButton.setEnabled(false);
allDelegatesButton.setEnabled(false);
} else if (cpuOnlyUpscaler != null && defaultDelegateUpscaler != null && selectedImage != null) {
predictionButton.setEnabled(true);
predictionButton.setAlpha(1.0f);
enableImageSelector();
enableModelSelector();
enableDelegateSelectionButtons();
}
}
Expand All @@ -193,6 +218,13 @@ void enableImageSelector() {
imageSelector.setEnabled(true);
imageSelector.setAlpha(1.0f);
}
/**
* Enable the model selector UI spinner.
*/
void enableModelSelector() {
modelSelector.setEnabled(true);
modelSelector.setAlpha(1.0f);
}

/**
* Enable the image selector UI radio buttons.
Expand Down Expand Up @@ -327,17 +359,18 @@ void updatePredictionDataAsync() {
* Loading the TF Lite model takes time, so this is done asynchronously to the main UI thread.
* Disables the inference UI during load and reenables it afterwards.
*/
void createTFLiteUpscalerAsync() {
void createTFLiteUpscalerAsync(final String tfLiteModelAsset) {
if (defaultDelegateUpscaler != null || cpuOnlyUpscaler != null) {
throw new RuntimeException("Classifiers were already created");
defaultDelegateUpscaler.close();
cpuOnlyUpscaler.close();
// throw new RuntimeException("Classifiers were already created");
}
setInferenceUIEnabled(false);

// Exit the UI thread and instantiate the model in the background.
backgroundTaskExecutor.execute(() -> {
// Create two upscalers.
// One uses the default set of delegates (can access NPU, GPU, CPU), and the other uses only XNNPack (CPU).
String tfLiteModelAsset = this.getResources().getString(R.string.tfLiteModelAsset);
try {
defaultDelegateUpscaler = new SuperResolution(
this,
Expand All @@ -352,6 +385,7 @@ void createTFLiteUpscalerAsync() {
} catch (IOException | NoSuchAlgorithmException e) {
throw new RuntimeException(e.getMessage());
}
Log.i("createTFLiteUpscalerAsync","model load finish: "+tfLiteModelAsset);

mainLooperHandler.post(() -> setInferenceUIEnabled(true));
});
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,39 @@

</androidx.cardview.widget.CardView>

<LinearLayout
android:orientation="horizontal"
android:id="@+id/modelSelectorCard"
android:layout_width="409sp"
android:layout_height="50sp"
android:background="@color/purple_qcom"
app:layout_constraintBottom_toTopOf="@id/imageSelectorCard"
app:layout_constraintEnd_toEndOf="parent"
app:layout_constraintStart_toStartOf="parent">

<TextView
android:id="@+id/modelSelectorText"
android:layout_width="wrap_content"
android:layout_height="match_parent"
android:layout_marginStart="60sp"
android:gravity="center"
android:text="Model"
android:textColor="@color/white"
android:textSize="17sp" />


<Spinner
android:id="@+id/modelSelector"
android:layout_width="0dp"
android:layout_height="match_parent"
android:layout_gravity="end"
android:layout_marginEnd="30sp"
android:layout_weight="1"
android:backgroundTint="@color/white"
android:textAlignment="textEnd"
android:theme="@style/spinnerTheme" />

</LinearLayout>
<androidx.cardview.widget.CardView
android:id="@+id/imageSelectorCard"
android:layout_width="409sp"
Expand Down Expand Up @@ -67,7 +100,7 @@
android:layout_marginBottom="8sp"
android:backgroundTint="@color/purple_qcom"
android:orientation="horizontal"
app:layout_constraintBottom_toTopOf="@+id/imageSelectorCard"
app:layout_constraintBottom_toTopOf="@+id/modelSelectorCard"
tools:layout_editor_absoluteX="2sp">

<RadioButton
Expand Down