add tflite example #457

133 changes: 133 additions & 0 deletions lite/examples/generative_ai/android/README.md
@@ -0,0 +1,133 @@
# Generative AI

## Introduction
Large language models (LLMs) are machine learning models trained on large bodies of text data to generate output for natural language processing (NLP) tasks, including text generation, question answering, and machine translation. They are based on the Transformer architecture and are trained on massive amounts of text, often involving billions of words. Even LLMs of a smaller scale, such as GPT-2, can perform impressively. Converting TensorFlow models to a lighter, faster, low-power format lets us run generative AI models on-device, with the added benefit of user privacy because data never leaves the device.

This example shows you how to build an Android app with TensorFlow Lite that runs a Keras LLM, and it provides suggestions for model optimization using quantization techniques; without them, the model would require far more memory and computational power to run.

This example open-sources an Android app framework that any compatible TFLite LLM can plug into. Here are two demos:
* In Figure 1, we use a Keras GPT-2 model to perform text completion tasks on-device.
* In Figure 2, we converted a version of the instruction-tuned [PaLM model](https://ai.googleblog.com/2022/04/pathways-language-model-palm-scaling-to.html) (1.5 billion parameters) to TFLite and executed it through the TFLite runtime.

<p align="center">
<img src="figures/autocomplete_fig1.gif" width="300">
</p>
Figure 1: Example of running the Keras GPT-2 model (converted from this Codelab) on a Pixel 7 to perform text completion on-device. The demo shows the real latency with no speedup.
<p align="center">
<img src="figures/auto_complete_2.gif" width="300">
</p>
Figure 2: Example of running a version of the [PaLM model](https://ai.googleblog.com/2022/04/pathways-language-model-palm-scaling-to.html) with 1.5 billion parameters. The demo is recorded on a Pixel 7 Pro with no playback speedup.




## Guides
### Step 1. Train a language model using Keras

For this demonstration, we will use KerasNLP to get the GPT-2 model. KerasNLP is a library that contains state-of-the-art pretrained models for natural language processing tasks and can support users through their entire development cycle. You can see the list of available models in the [KerasNLP repository](https://github.com/keras-team/keras-nlp/tree/master/keras_nlp/models). The workflows are built from modular components that come with state-of-the-art preset weights and architectures out of the box and are easily customizable when more control is needed. You can create the GPT-2 model with the following steps:

```python
import keras_nlp

# Tokenizer, preprocessor, and model all use the same "gpt2_base_en" preset.
gpt2_tokenizer = keras_nlp.models.GPT2Tokenizer.from_preset("gpt2_base_en")

gpt2_preprocessor = keras_nlp.models.GPT2CausalLMPreprocessor.from_preset(
    "gpt2_base_en",
    sequence_length=256,
    add_end_token=True,
)

gpt2_lm = keras_nlp.models.GPT2CausalLM.from_preset(
    "gpt2_base_en",
    preprocessor=gpt2_preprocessor,
)
```

You can check out the full GPT-2 model implementation [on GitHub](https://github.com/keras-team/keras-nlp/tree/master/keras_nlp/models/gpt2).
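
As a quick sanity check before any conversion work (a minimal sketch; the prompt and output length are arbitrary choices), you can generate text directly from the Keras model:

```python
# Run generation eagerly with the Keras model; the first call compiles and is slow.
output = gpt2_lm.generate("I'm enjoying a", max_length=100)
print(output)
```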


### Step 2. Convert a Keras model to a TFLite model

Start with the `generate()` function from `GPT2CausalLM`, which is what you will convert. Wrap the `generate()` function to create a concrete TensorFlow function:

```python
import tensorflow as tf

@tf.function
def generate(prompt, max_length):
    # prompt: input prompt to the LLM in string format
    # max_length: the max length of the generated tokens
    return gpt2_lm.generate(prompt, max_length)

concrete_func = generate.get_concrete_function(tf.TensorSpec([], tf.string), 100)
```
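
If you want to check the traced function before converting it (a minimal sketch; the prompt string is an arbitrary choice), you can call the concrete function directly. Because `max_length` was bound to 100 at trace time, only the prompt is passed:

```python
# Runs generate() through the traced graph and prints the generated string tensor.
print(concrete_func(tf.constant("I'm enjoying a")))
```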

Now define a helper function that runs inference with an input and a TFLite model. TensorFlow Text ops are not built-in ops in the TFLite runtime, so you will need to add these custom ops in order for the interpreter to run inference on this model. This helper function accepts an input prompt and the converted TFLite model, i.e. the result of converting the `generate()` function defined above.

```python
import numpy as np
import tensorflow_text as tf_text
from tensorflow.lite.python import interpreter

def run_inference(input, generate_tflite):
    # Register the TensorFlow Text custom ops so the interpreter can resolve them.
    interp = interpreter.InterpreterWithCustomOps(
        model_content=generate_tflite,
        custom_op_registerers=tf_text.tflite_registrar.SELECT_TFTEXT_OPS)
    interp.get_signature_list()

    generator = interp.get_signature_runner('serving_default')
    output = generator(prompt=np.array([input]))
```

You can convert the model now:

```python
gpt2_lm.jit_compile = False
converter = tf.lite.TFLiteConverter.from_concrete_functions(
    [concrete_func],
    gpt2_lm)

converter.target_spec.supported_ops = [
    tf.lite.OpsSet.TFLITE_BUILTINS,  # enable TFLite ops
    tf.lite.OpsSet.SELECT_TF_OPS,  # enable TF ops
]
converter.allow_custom_ops = True
converter.target_spec.experimental_select_user_tf_ops = [
    "UnsortedSegmentJoin",
    "UpperBound"
]
converter._experimental_guarantee_all_funcs_one_use = True
generate_tflite = converter.convert()
run_inference("I'm enjoying a", generate_tflite)
```
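
If you'd like to keep the unquantized model around to compare against the quantized one in the next step (a minimal sketch; the filename is an arbitrary choice), you can write the flatbuffer to disk the same way the quantized model is saved below:

```python
# Save the unquantized flatbuffer for later size and latency comparison.
with open('unquantized_gpt2.tflite', 'wb') as f:
    f.write(generate_tflite)
```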

### Step 3. Quantization
TensorFlow Lite implements an optimization technique called quantization that can reduce model size and accelerate inference. During quantization, 32-bit floats are mapped to smaller 8-bit integers, reducing the model size by a factor of 4 for more efficient execution on modern hardware. There are several ways to do quantization in TensorFlow; visit the [TFLite Model optimization](https://www.tensorflow.org/lite/performance/model_optimization) and [TensorFlow Model Optimization Toolkit](https://www.tensorflow.org/model_optimization) pages for more information.

Here, you will use post-training dynamic range quantization on the GPT-2 model by setting the converter optimization flag to `tf.lite.Optimize.DEFAULT`; the rest of the conversion process is the same as before. In our tests, this quantization technique gave a latency of around 6.7 seconds on a Pixel 7 with the maximum output length set to 100.

```python
gpt2_lm.jit_compile = False
converter = tf.lite.TFLiteConverter.from_concrete_functions(
    [concrete_func],
    gpt2_lm)

converter.target_spec.supported_ops = [
    tf.lite.OpsSet.TFLITE_BUILTINS,  # enable TFLite ops
    tf.lite.OpsSet.SELECT_TF_OPS,  # enable TF ops
]
converter.allow_custom_ops = True
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.target_spec.experimental_select_user_tf_ops = [
    "UnsortedSegmentJoin",
    "UpperBound"
]
converter._experimental_guarantee_all_funcs_one_use = True
quant_generate_tflite = converter.convert()
run_inference("I'm enjoying a", quant_generate_tflite)

with open('quantized_gpt2.tflite', 'wb') as f:
    f.write(quant_generate_tflite)
```
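
To see what dynamic range quantization buys you in terms of size (a minimal sketch, assuming `generate_tflite` from Step 2 is still in memory), you can compare the two flatbuffers directly; the quantized one should be roughly 4x smaller:

```python
# Both conversion results are bytes objects, so len() gives the flatbuffer size.
print(f"Unquantized model: {len(generate_tflite) / 2**20:.1f} MiB")
print(f"Quantized model:   {len(quant_generate_tflite) / 2**20:.1f} MiB")
```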



### Step 4. Android App integration

You can clone this repo and replace `android/app/src/main/assets/autocomplete.tflite` with your converted `quant_generate_tflite` file. Refer to [how-to-build.md](https://github.com/tensorflow/examples/blob/master/lite/examples/generative_ai/android/how-to-build.md) to build the Android app.
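
Before swapping the file into the app's assets, it can be worth confirming that the converted model loads and exposes the expected signature (a minimal sketch, assuming the `quantized_gpt2.tflite` file written in Step 3):

```python
import tensorflow_text as tf_text
from tensorflow.lite.python import interpreter

# Load the saved flatbuffer from disk and list its signatures.
interp = interpreter.InterpreterWithCustomOps(
    model_path="quantized_gpt2.tflite",
    custom_op_registerers=tf_text.tflite_registrar.SELECT_TFTEXT_OPS)
print(interp.get_signature_list())  # expect a 'serving_default' entry
```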

## Safety and Responsible AI
As noted in the original [OpenAI GPT-2 announcement](https://openai.com/research/better-language-models), there are [notable caveats and limitations](https://github.com/openai/gpt-2#some-caveats) with the GPT-2 model. In fact, LLMs today generally have well-known challenges such as hallucinations, fairness, and bias; this is because these models are trained on real-world data, which makes them reflect real-world issues.
This codelab is created only to demonstrate how to create an app powered by LLMs with TensorFlow tooling. The model produced in this codelab is for educational purposes only and is not intended for production use.
Production use of LLMs requires thoughtful selection of training datasets and comprehensive safety mitigations. One such feature offered in this Android app is the profanity filter, which rejects inappropriate user inputs or model outputs: if any inappropriate language is detected, the app rejects that action. To learn more about Responsible AI in the context of LLMs, watch the Safe and Responsible Development with Generative Language Models technical session at Google I/O 2023 and check out the [Responsible AI Toolkit](https://www.tensorflow.org/responsible_ai).
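
The app implements its filter in Kotlin; purely to illustrate the idea (a hypothetical sketch, not the app's actual logic or word list), a blocklist-style check applied to both the prompt and the model output could look like:

```python
# Hypothetical blocklist; the real app uses a third-party word-filter library.
BLOCKLIST = {"badword1", "badword2"}

def is_allowed(text: str) -> bool:
    # Reject the prompt or the model output if it contains a blocked word.
    words = {w.strip(".,!?").lower() for w in text.split()}
    return BLOCKLIST.isdisjoint(words)
```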
102 changes: 102 additions & 0 deletions lite/examples/generative_ai/android/app/build.gradle.kts
@@ -0,0 +1,102 @@
@file:Suppress("UnstableApiUsage")

plugins {
    kotlin("android")
    id("com.android.application")
    id("de.undercouch.download")
}

ext {
    set("AAR_URL", "https://storage.googleapis.com/download.tensorflow.org/models/tflite/generativeai/tensorflow-lite-select-tf-ops.aar")
    set("AAR_PATH", "$projectDir/libs/tensorflow-lite-select-tf-ops.aar")
}

apply {
    from("download.gradle")
}

android {
    namespace = "com.google.tensorflowdemo"
    compileSdk = 33

    defaultConfig {
        applicationId = "com.google.tensorflowdemo"
        minSdk = 24
        targetSdk = 33
        versionCode = 1
        versionName = "1.0"
    }
    buildFeatures {
        compose = true
        buildConfig = true
        viewBinding = true
    }
    composeOptions {
        kotlinCompilerExtensionVersion = "1.3.2"
    }
    packagingOptions {
        resources {
            excludes += "/META-INF/{AL2.0,LGPL2.1}"
        }
    }
    buildTypes {
        getByName("release") {
            isMinifyEnabled = true
            proguardFiles(getDefaultProguardFile("proguard-android-optimize.txt"), "proguard-rules.pro")
            isDebuggable = false
        }
        getByName("debug") {
            applicationIdSuffix = ".debug"
        }
    }
    compileOptions {
        sourceCompatibility = JavaVersion.VERSION_1_8
        targetCompatibility = JavaVersion.VERSION_1_8
    }
    kotlinOptions {
        jvmTarget = "1.8"
        freeCompilerArgs = listOf(
            "-P",
            "plugin:androidx.compose.compiler.plugins.kotlin:suppressKotlinVersionCompatibilityCheck=1.8.10"
        )
    }
}

dependencies {
    implementation(fileTree(mapOf("dir" to "libs", "include" to listOf("*.aar"))))

    // Compose
    implementation(libraries.compose.ui)
    implementation(libraries.compose.ui.tooling)
    implementation(libraries.compose.ui.tooling.preview)
    implementation(libraries.compose.foundation)
    implementation(libraries.compose.material)
    implementation(libraries.compose.material.icons)
    implementation(libraries.compose.activity)

    // Accompanist for Compose
    implementation(libraries.accompanist.systemuicontroller)

    // Koin
    implementation(libraries.koin.core)
    implementation(libraries.koin.android)
    implementation(libraries.koin.compose)

    // Lifecycle
    implementation(libraries.lifecycle.viewmodel)
    implementation(libraries.lifecycle.viewmodel.compose)
    implementation(libraries.lifecycle.viewmodel.ktx)
    implementation(libraries.lifecycle.runtime.compose)

    // Logging
    implementation(libraries.napier)

    // Profanity filter
    implementation(libraries.wordfilter)

    // TensorFlow Lite
    implementation(libraries.tflite)

    // Unit tests
    testImplementation(libraries.junit)
}
7 changes: 7 additions & 0 deletions lite/examples/generative_ai/android/app/download.gradle
@@ -0,0 +1,7 @@
task downloadAAR {
    download {
        src project.ext.AAR_URL
        dest project.ext.AAR_PATH
        overwrite false
    }
}
1 change: 1 addition & 0 deletions lite/examples/generative_ai/android/app/libs/.gitignore
@@ -0,0 +1 @@
tensorflow-lite-select-tf-ops.aar
@@ -0,0 +1,16 @@
# Build your own aar

By default the app automatically downloads the needed AAR files, but if you want
to build your own, just run `./build_aar.sh`. This script will pull
in the necessary ops from [TensorFlow Text](https://www.tensorflow.org/text) and
build the AAR for [Select TF operators](https://www.tensorflow.org/lite/guide/ops_select).

After compilation, a new file `tftext_tflite_flex.aar` is generated. Replace the
one in the `app/libs/` folder and rebuild the app.

By default, the script builds only for `android_x86_64`. You can change it to
`android_x86`, `android_arm` or `android_arm64`.

Note that you still need to include the standard `tensorflow-lite` aar in your
gradle file.

@@ -0,0 +1,28 @@
#! /bin/bash

set -e

# Clone TensorFlow Text repo
git clone https://github.com/tensorflow/text.git tensorflow_text

cd tensorflow_text/
echo 'exports_files(["LICENSE"])' > BUILD

# Checkout 2.12 branch
git checkout 2.12

# Apply tftext-2.12.patch
git apply ../tftext-2.12.patch

# Run config
./oss_scripts/configure.sh

# Run bazel build
bazel build -c opt --cxxopt='--std=c++14' --config=monolithic --config=android_x86_64 --experimental_repo_remote_exec //tensorflow_text:tftext_tflite_flex

if [ $? -eq 0 ]; then
  # Print the location of the generated aar
  echo "Please find the aar file: tensorflow_text/bazel-bin/tensorflow_text/tftext_tflite_flex.aar"
else
  echo "build_aar.sh has failed. Please find the error message above and address it before proceeding."
fi
@@ -0,0 +1,61 @@
diff --git a/WORKSPACE b/WORKSPACE
index 28b7ee5..5ad0b55 100644
--- a/WORKSPACE
+++ b/WORKSPACE
@@ -116,3 +116,10 @@ load("@org_tensorflow//third_party/android:android_configure.bzl", "android_conf
android_configure(name="local_config_android")
load("@local_config_android//:android.bzl", "android_workspace")
android_workspace()
+
+android_sdk_repository(name = "androidsdk")
+
+android_ndk_repository(
+ name = "androidndk",
+ api_level = 21,
+)
diff --git a/tensorflow_text/BUILD b/tensorflow_text/BUILD
index 9b5ee5b..880c7c5 100644
--- a/tensorflow_text/BUILD
+++ b/tensorflow_text/BUILD
@@ -2,6 +2,8 @@ load("//tensorflow_text:tftext.bzl", "py_tf_text_library")

# [internal] load build_test.bzl
load("@org_tensorflow//tensorflow/lite:build_def.bzl", "tflite_cc_shared_object")
+load("@org_tensorflow//tensorflow/lite/delegates/flex:build_def.bzl", "tflite_flex_android_library")
+load("@org_tensorflow//tensorflow/lite/java:aar_with_jni.bzl", "aar_with_jni")

# Visibility rules
package(
@@ -61,6 +63,20 @@ tflite_cc_shared_object(
deps = [":ops_lib"],
)

+tflite_flex_android_library(
+ name = "tftext_ops",
+ additional_deps = [
+ "@org_tensorflow//tensorflow/lite/delegates/flex:delegate",
+ ":ops_lib",
+ ],
+ visibility = ["//visibility:public"],
+)
+
+aar_with_jni(
+ name = "tftext_tflite_flex",
+ android_library = ":tftext_ops",
+)
+
py_library(
name = "ops",
srcs = [
diff --git a/tensorflow_text/tftext.bzl b/tensorflow_text/tftext.bzl
index 65430ca..04f68d8 100644
--- a/tensorflow_text/tftext.bzl
+++ b/tensorflow_text/tftext.bzl
@@ -140,6 +140,7 @@ def tf_cc_library(
deps += select({
"@org_tensorflow//tensorflow:mobile": [
"@org_tensorflow//tensorflow/core:portable_tensorflow_lib_lite",
+ "@org_tensorflow//tensorflow/lite/kernels/shim:tf_op_shim",
],
"//conditions:default": [
"@local_config_tf//:libtensorflow_framework",
21 changes: 21 additions & 0 deletions lite/examples/generative_ai/android/app/proguard-rules.pro
@@ -0,0 +1,21 @@
# Add project specific ProGuard rules here.
# You can control the set of applied configuration files using the
# proguardFiles setting in build.gradle.
#
# For more details, see
# http://developer.android.com/guide/developing/tools/proguard.html

# If your project uses WebView with JS, uncomment the following
# and specify the fully qualified class name to the JavaScript interface
# class:
#-keepclassmembers class fqcn.of.javascript.interface.for.webview {
# public *;
#}

# Uncomment this to preserve the line number information for
# debugging stack traces.
#-keepattributes SourceFile,LineNumberTable

# If you keep the line number information, uncomment this to
# hide the original source file name.
#-renamesourcefileattribute SourceFile