Skip to content

Add imagen editing options like inpainting and outpainting #7075

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 47 additions & 0 deletions firebase-ai/api.txt
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ package com.google.firebase.ai {
}

@com.google.firebase.ai.type.PublicPreviewAPI public final class ImagenModel {
method public suspend Object? editImage(String prompt, com.google.firebase.ai.type.ImagenEditingConfig config, kotlin.coroutines.Continuation<? super com.google.firebase.ai.type.ImagenGenerationResponse<com.google.firebase.ai.type.ImagenInlineImage>>);
method public suspend Object? generateImages(String prompt, kotlin.coroutines.Continuation<? super com.google.firebase.ai.type.ImagenGenerationResponse<com.google.firebase.ai.type.ImagenInlineImage>>);
}

Expand Down Expand Up @@ -104,6 +105,7 @@ package com.google.firebase.ai.java {
}

@com.google.firebase.ai.type.PublicPreviewAPI public abstract class ImagenModelFutures {
method public abstract com.google.common.util.concurrent.ListenableFuture<com.google.firebase.ai.type.ImagenGenerationResponse<com.google.firebase.ai.type.ImagenInlineImage>> editImage(String prompt, com.google.firebase.ai.type.ImagenEditingConfig config);
method public static final com.google.firebase.ai.java.ImagenModelFutures from(com.google.firebase.ai.ImagenModel model);
method public abstract com.google.common.util.concurrent.ListenableFuture<com.google.firebase.ai.type.ImagenGenerationResponse<com.google.firebase.ai.type.ImagenInlineImage>> generateImages(String prompt);
method public abstract com.google.firebase.ai.ImagenModel getImageModel();
Expand Down Expand Up @@ -484,6 +486,47 @@ package com.google.firebase.ai.type {
public static final class ImagenAspectRatio.Companion {
}

public final class ImagenEditMode {
field public static final com.google.firebase.ai.type.ImagenEditMode.Companion Companion;
}

public static final class ImagenEditMode.Companion {
method public com.google.firebase.ai.type.ImagenEditMode getINPAINT_INSERTION();
method public com.google.firebase.ai.type.ImagenEditMode getINPAINT_REMOVAL();
method public com.google.firebase.ai.type.ImagenEditMode getOUTPAINT();
property public final com.google.firebase.ai.type.ImagenEditMode INPAINT_INSERTION;
property public final com.google.firebase.ai.type.ImagenEditMode INPAINT_REMOVAL;
property public final com.google.firebase.ai.type.ImagenEditMode OUTPAINT;
}

@com.google.firebase.ai.type.PublicPreviewAPI public final class ImagenEditingConfig {
ctor public ImagenEditingConfig(com.google.firebase.ai.type.ImagenInlineImage image, com.google.firebase.ai.type.ImagenEditMode editMode, com.google.firebase.ai.type.ImagenInlineImage? mask = null, Double? maskDilation = null, Integer? editSteps = null);
field public static final com.google.firebase.ai.type.ImagenEditingConfig.Companion Companion;
}

public static final class ImagenEditingConfig.Builder {
ctor public ImagenEditingConfig.Builder();
method public com.google.firebase.ai.type.ImagenEditingConfig build();
method public com.google.firebase.ai.type.ImagenEditingConfig.Builder setEditMode(com.google.firebase.ai.type.ImagenEditMode editMode);
method public com.google.firebase.ai.type.ImagenEditingConfig.Builder setEditSteps(int editSteps);
method public com.google.firebase.ai.type.ImagenEditingConfig.Builder setImage(com.google.firebase.ai.type.ImagenInlineImage image);
method public com.google.firebase.ai.type.ImagenEditingConfig.Builder setMask(com.google.firebase.ai.type.ImagenInlineImage mask);
method public com.google.firebase.ai.type.ImagenEditingConfig.Builder setMaskDilation(double maskDilation);
field public com.google.firebase.ai.type.ImagenEditMode? editMode;
field public Integer? editSteps;
field public com.google.firebase.ai.type.ImagenInlineImage? image;
field public com.google.firebase.ai.type.ImagenInlineImage? mask;
field public Double? maskDilation;
}

public static final class ImagenEditingConfig.Companion {
method public com.google.firebase.ai.type.ImagenEditingConfig.Builder builder();
}

public final class ImagenEditingConfigKt {
method @com.google.firebase.ai.type.PublicPreviewAPI public static com.google.firebase.ai.type.ImagenEditingConfig imagenEditingConfig(kotlin.jvm.functions.Function1<? super com.google.firebase.ai.type.ImagenEditingConfig.Builder,kotlin.Unit> init);
}

@com.google.firebase.ai.type.PublicPreviewAPI public final class ImagenGenerationConfig {
ctor public ImagenGenerationConfig(String? negativePrompt = null, Integer? numberOfImages = 1, com.google.firebase.ai.type.ImagenAspectRatio? aspectRatio = null, com.google.firebase.ai.type.ImagenImageFormat? imageFormat = null, Boolean? addWatermark = null);
method public Boolean? getAddWatermark();
Expand Down Expand Up @@ -552,6 +595,10 @@ package com.google.firebase.ai.type {
property public final String mimeType;
}

public final class ImagenInlineImageKt {
method @com.google.firebase.ai.type.PublicPreviewAPI public static com.google.firebase.ai.type.ImagenInlineImage toImagenInlineImage(android.graphics.Bitmap);
}

@com.google.firebase.ai.type.PublicPreviewAPI public final class ImagenPersonFilterLevel {
field public static final com.google.firebase.ai.type.ImagenPersonFilterLevel ALLOW_ADULT;
field public static final com.google.firebase.ai.type.ImagenPersonFilterLevel ALLOW_ALL;
Expand Down
79 changes: 71 additions & 8 deletions firebase-ai/src/main/kotlin/com/google/firebase/ai/ImagenModel.kt
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import com.google.firebase.ai.common.AppCheckHeaderProvider
import com.google.firebase.ai.common.ContentBlockedException
import com.google.firebase.ai.common.GenerateImageRequest
import com.google.firebase.ai.type.FirebaseAIException
import com.google.firebase.ai.type.ImagenEditingConfig
import com.google.firebase.ai.type.ImagenGenerationConfig
import com.google.firebase.ai.type.ImagenGenerationResponse
import com.google.firebase.ai.type.ImagenInlineImage
Expand Down Expand Up @@ -75,30 +76,92 @@ internal constructor(
public suspend fun generateImages(prompt: String): ImagenGenerationResponse<ImagenInlineImage> =
try {
controller
.generateImage(constructRequest(prompt, null, generationConfig))
.generateImage(constructGenerateImageRequest(prompt, generationConfig))
.validate()
.toPublicInline()
} catch (e: Throwable) {
throw FirebaseAIException.from(e)
}

private fun constructRequest(
public suspend fun editImage(
prompt: String,
gcsUri: String?,
config: ImagenGenerationConfig?,
config: ImagenEditingConfig
): ImagenGenerationResponse<ImagenInlineImage> =
try {
controller.generateImage(constructEditRequest(prompt, config)).validate().toPublicInline()
} catch (e: Throwable) {
throw FirebaseAIException.from(e)
}

private fun constructGenerateImageRequest(
prompt: String,
generationConfig: ImagenGenerationConfig? = null,
): GenerateImageRequest {
return GenerateImageRequest(
listOf(GenerateImageRequest.ImagenPrompt(prompt)),
GenerateImageRequest.ImagenParameters(
sampleCount = config?.numberOfImages ?: 1,
sampleCount = generationConfig?.numberOfImages ?: 1,
includeRaiReason = true,
addWatermark = generationConfig?.addWatermark,
personGeneration = safetySettings?.personFilterLevel?.internalVal,
negativePrompt = generationConfig?.negativePrompt,
safetySetting = safetySettings?.safetyFilterLevel?.internalVal,
storageUri = null,
aspectRatio = generationConfig?.aspectRatio?.internalVal,
imageOutputOptions = generationConfig?.imageFormat?.toInternal(),
editMode = null,
editConfig = null
),
)
}

private fun constructEditRequest(
prompt: String,
editConfig: ImagenEditingConfig,
): GenerateImageRequest {
return GenerateImageRequest(
listOf(
GenerateImageRequest.ImagenPrompt(
prompt = prompt,
referenceImages =
buildList {
add(
GenerateImageRequest.ReferenceImage(
referenceType = GenerateImageRequest.ReferenceType.RAW,
referenceId = 1,
referenceImage = editConfig.image.toInternal(),
maskImageConfig = null
)
)
if (editConfig.mask != null) {
add(
GenerateImageRequest.ReferenceImage(
referenceType = GenerateImageRequest.ReferenceType.MASK,
referenceId = 2,
referenceImage = editConfig.mask.toInternal(),
maskImageConfig =
GenerateImageRequest.MaskImageConfig(
maskMode = GenerateImageRequest.MaskMode.USER_PROVIDED,
dilation = editConfig.maskDilation
)
)
)
}
}
)
),
GenerateImageRequest.ImagenParameters(
sampleCount = generationConfig?.numberOfImages ?: 1,
includeRaiReason = true,
addWatermark = generationConfig?.addWatermark,
personGeneration = safetySettings?.personFilterLevel?.internalVal,
negativePrompt = config?.negativePrompt,
negativePrompt = generationConfig?.negativePrompt,
safetySetting = safetySettings?.safetyFilterLevel?.internalVal,
storageUri = gcsUri,
aspectRatio = config?.aspectRatio?.internalVal,
storageUri = null,
aspectRatio = generationConfig?.aspectRatio?.internalVal,
imageOutputOptions = generationConfig?.imageFormat?.toInternal(),
editMode = editConfig.editMode.value,
editConfig = editConfig.toInternal()
),
)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,9 @@ import com.google.firebase.ai.common.util.fullModelName
import com.google.firebase.ai.common.util.trimmedModelName
import com.google.firebase.ai.type.Content
import com.google.firebase.ai.type.GenerationConfig
import com.google.firebase.ai.type.ImagenEditingConfig
import com.google.firebase.ai.type.ImagenImageFormat
import com.google.firebase.ai.type.ImagenInlineImage
import com.google.firebase.ai.type.PublicPreviewAPI
import com.google.firebase.ai.type.SafetySetting
import com.google.firebase.ai.type.Tool
Expand Down Expand Up @@ -75,11 +77,17 @@ internal data class CountTokensRequest(
}

@Serializable
@PublicPreviewAPI
internal data class GenerateImageRequest(
val instances: List<ImagenPrompt>,
val parameters: ImagenParameters,
) : Request {
@Serializable internal data class ImagenPrompt(val prompt: String)
/**
 * A single entry of the request's `instances` list.
 *
 * Every field is optional and defaults to `null`.
 *
 * @property prompt text prompt for the request.
 * @property image an inline image payload, in its serialized (base64) form.
 * @property referenceImages reference images (for example a raw image and a mask) that accompany
 * the prompt.
 */
@Serializable
internal data class ImagenPrompt(
  val prompt: String? = null,
  val image: ImagenInlineImage.Internal? = null,
  val referenceImages: List<ReferenceImage>? = null,
)

@OptIn(PublicPreviewAPI::class)
@Serializable
Expand All @@ -93,5 +101,38 @@ internal data class GenerateImageRequest(
val personGeneration: String?,
val addWatermark: Boolean?,
val imageOutputOptions: ImagenImageFormat.Internal?,
val editMode: String?,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't this be ImagenEditMode instead of String?

val editConfig: ImagenEditingConfig.Internal?,
)

@Serializable
internal enum class ReferenceType {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't some/all of these values be available to the dev for some features to work? Like https://cloud.google.com/vertex-ai/generative-ai/docs/image/style-customization

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, but right now the rest of the scaffolding isn't there to enable these features. As I add them, I'll also add access to these values.

@SerialName("REFERENCE_TYPE_UNSPECIFIED") UNSPECIFIED,
@SerialName("REFERENCE_TYPE_RAW") RAW,
@SerialName("REFERENCE_TYPE_MASK") MASK,
@SerialName("REFERENCE_TYPE_CONTROL") CONTROL,
@SerialName("REFERENCE_TYPE_STYLE") STYLE,
@SerialName("REFERENCE_TYPE_SUBJECT") SUBJECT,
@SerialName("REFERENCE_TYPE_MASKED_SUBJECT") MASKED_SUBJECT,
@SerialName("REFERENCE_TYPE_PRODUCT") PRODUCT
}

/**
 * Mask mode values for [MaskImageConfig.maskMode].
 *
 * Each constant is serialized under the `MASK_MODE_*` name given by its [SerialName] annotation.
 */
@Serializable
internal enum class MaskMode {
  @SerialName("MASK_MODE_DEFAULT")
  DEFAULT,
  @SerialName("MASK_MODE_USER_PROVIDED")
  USER_PROVIDED,
  @SerialName("MASK_MODE_BACKGROUND")
  BACKGROUND,
  @SerialName("MASK_MODE_FOREGROUND")
  FOREGROUND,
  @SerialName("MASK_MODE_SEMANTIC")
  SEMANTIC
}

/**
 * Mask settings attached to a [ReferenceImage].
 *
 * @property maskMode the [MaskMode] to send for this mask.
 * @property dilation optional dilation value; `null` when none was supplied.
 */
@Serializable
internal data class MaskImageConfig(
  val maskMode: MaskMode,
  val dilation: Double?,
)

/**
 * One reference image sent alongside a prompt.
 *
 * @property referenceType the [ReferenceType] of this image (for example raw or mask).
 * @property referenceId integer identifier of this reference within the request.
 * @property referenceImage the image payload in its serialized (base64) form.
 * @property maskImageConfig mask settings for this image, or `null` when it carries none.
 */
@Serializable
internal data class ReferenceImage(
  val referenceType: ReferenceType,
  val referenceId: Int,
  val referenceImage: ImagenInlineImage.Internal,
  val maskImageConfig: MaskImageConfig?,
)
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package com.google.firebase.ai.java
import androidx.concurrent.futures.SuspendToFutureAdapter
import com.google.common.util.concurrent.ListenableFuture
import com.google.firebase.ai.ImagenModel
import com.google.firebase.ai.type.ImagenEditingConfig
import com.google.firebase.ai.type.ImagenGenerationResponse
import com.google.firebase.ai.type.ImagenInlineImage
import com.google.firebase.ai.type.PublicPreviewAPI
Expand All @@ -39,6 +40,11 @@ public abstract class ImagenModelFutures internal constructor() {
prompt: String,
): ListenableFuture<ImagenGenerationResponse<ImagenInlineImage>>

/**
 * Edits an image according to [prompt] and [config], returning the result as a [ListenableFuture].
 *
 * @param prompt The text prompt describing the edit.
 * @param config The [ImagenEditingConfig] carrying the source image, the edit mode and any optional
 * mask settings.
 */
public abstract fun editImage(
prompt: String,
config: ImagenEditingConfig
): ListenableFuture<ImagenGenerationResponse<ImagenInlineImage>>

/** Returns the [ImagenModel] object wrapped by this object. */
public abstract fun getImageModel(): ImagenModel

Expand All @@ -48,6 +54,12 @@ public abstract class ImagenModelFutures internal constructor() {
): ListenableFuture<ImagenGenerationResponse<ImagenInlineImage>> =
SuspendToFutureAdapter.launchFuture { model.generateImages(prompt) }

// Bridges the suspending ImagenModel.editImage call into a ListenableFuture.
override fun editImage(
  prompt: String,
  config: ImagenEditingConfig
): ListenableFuture<ImagenGenerationResponse<ImagenInlineImage>> {
  return SuspendToFutureAdapter.launchFuture { model.editImage(prompt, config) }
}

override fun getImageModel(): ImagenModel = model
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
package com.google.firebase.ai.type
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please always include the copyright header to make sure the copyright check passes.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, I have it in the documentation PR instead


/**
 * Identifies the kind of edit an Imagen edit request performs.
 *
 * Instances cannot be created by callers; use one of the constants on the companion object. Each
 * constant wraps the string that is sent to the backend as the edit mode.
 */
public class ImagenEditMode private constructor(internal val value: String) {
  public companion object {
    /** Edit mode whose wire value is `EDIT_MODE_INPAINT_INSERTION`. */
    public val INPAINT_INSERTION: ImagenEditMode =
      ImagenEditMode("EDIT_MODE_INPAINT_INSERTION")

    /** Edit mode whose wire value is `EDIT_MODE_INPAINT_REMOVAL`. */
    public val INPAINT_REMOVAL: ImagenEditMode =
      ImagenEditMode("EDIT_MODE_INPAINT_REMOVAL")

    /** Edit mode whose wire value is `EDIT_MODE_OUTPAINT`. */
    public val OUTPAINT: ImagenEditMode =
      ImagenEditMode("EDIT_MODE_OUTPAINT")
  }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
package com.google.firebase.ai.type
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(I'm not sure what the "same" here is referring to)


import kotlinx.serialization.Serializable

@PublicPreviewAPI
public class ImagenEditingConfig(
internal val image: ImagenInlineImage,
internal val editMode: ImagenEditMode,
internal val mask: ImagenInlineImage? = null,
internal val maskDilation: Double? = null,
internal val editSteps: Int? = null,
) {
public companion object {
  /** Returns a new [Builder] with every field unset. */
  public fun builder(): Builder {
    return Builder()
  }
}

/**
 * Mutable builder for [ImagenEditingConfig].
 *
 * Every field starts out as `null`. [image] and [editMode] must be set before calling [build];
 * the remaining fields are optional and are passed through unchanged (including `null`).
 */
public class Builder {
  @JvmField public var image: ImagenInlineImage? = null
  @JvmField public var editMode: ImagenEditMode? = null
  @JvmField public var mask: ImagenInlineImage? = null
  @JvmField public var maskDilation: Double? = null
  @JvmField public var editSteps: Int? = null

  /** Sets the image to edit and returns this builder. */
  public fun setImage(image: ImagenInlineImage): Builder = apply { this.image = image }

  /** Sets the edit mode and returns this builder. */
  public fun setEditMode(editMode: ImagenEditMode): Builder = apply { this.editMode = editMode }

  /** Sets the optional mask image and returns this builder. */
  public fun setMask(mask: ImagenInlineImage): Builder = apply { this.mask = mask }

  /** Sets the optional mask dilation and returns this builder. */
  public fun setMaskDilation(maskDilation: Double): Builder = apply {
    this.maskDilation = maskDilation
  }

  /** Sets the optional number of edit steps and returns this builder. */
  public fun setEditSteps(editSteps: Int): Builder = apply { this.editSteps = editSteps }

  /**
   * Creates an [ImagenEditingConfig] from the current field values.
   *
   * @throws IllegalStateException if [image] or [editMode] has not been set.
   */
  public fun build(): ImagenEditingConfig =
    ImagenEditingConfig(
      // checkNotNull reads each mutable field once and yields a non-null value, so no `!!` is
      // needed. It throws IllegalStateException with the given message, image being checked first.
      image = checkNotNull(image) { "ImagenEditingConfig must contain an image" },
      editMode = checkNotNull(editMode) { "ImagenEditingConfig must contain an editMode" },
      mask = mask,
      maskDilation = maskDilation,
      editSteps = editSteps,
    )
}

internal fun toInternal(): Internal {
return Internal(baseSteps = editSteps)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could the difference in name cause issues with error messages returned by the server? See shortn/_sPJ9lr7YkR

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hmm, it could, but maybe users could figure it out? or we could include it in the documentation?

I can rename it if you think that's the best solution, but I'd prefer a more descriptive name.

}

/**
 * Serialized form of the editing parameters.
 *
 * @property baseSteps the step count taken from the config's `editSteps`; `null` when unset.
 */
@Serializable internal data class Internal(val baseSteps: Int?)
}

@PublicPreviewAPI
public fun imagenEditingConfig(init: ImagenEditingConfig.Builder.() -> Unit): ImagenEditingConfig {
val builder = ImagenEditingConfig.builder()
builder.init()
return builder.build()
Comment on lines +64 to +67

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Consider adding a default value for maskDilation and editSteps in the ImagenEditingConfig constructor to avoid having to use nullable types. This would simplify the usage of the class and reduce the risk of null pointer exceptions.

Suggested change
public fun imagenEditingConfig(init: ImagenEditingConfig.Builder.() -> Unit): ImagenEditingConfig {
val builder = ImagenEditingConfig.builder()
builder.init()
return builder.build()
public class ImagenEditingConfig(
public val image: ImagenInlineImage,
public val editMode: ImagenEditMode,
public val mask: ImagenInlineImage? = null,
public val maskDilation: Double = 0.0,
public val editSteps: Int = 0,
)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a bad idea: the server has default values that should be preferred over these. In particular, 0 edit steps would likely cause issues.

}
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@ package com.google.firebase.ai.type

import android.graphics.Bitmap
import android.graphics.BitmapFactory
import android.util.Base64
import java.io.ByteArrayOutputStream
import kotlinx.serialization.Serializable

/**
* Represents an Imagen-generated image that is returned as inline data.
Expand All @@ -36,4 +39,19 @@ internal constructor(public val data: ByteArray, public val mimeType: String) {
public fun asBitmap(): Bitmap {
return BitmapFactory.decodeByteArray(data, 0, data.size)
}

/**
 * Serialized form of an inline image.
 *
 * @property bytesBase64Encoded the image bytes encoded as a base64 string.
 */
@Serializable
internal data class Internal(
  val bytesBase64Encoded: String,
)

/**
 * Converts this image to its serialized form by base64-encoding [data].
 *
 * `Base64.NO_WRAP` is used so the encoded string contains no line terminators.
 */
internal fun toInternal(): Internal = Internal(Base64.encodeToString(data, Base64.NO_WRAP))
}

@PublicPreviewAPI
public fun Bitmap.toImagenInlineImage(): ImagenInlineImage {
val byteArrayOutputStream = ByteArrayOutputStream()
this.compress(Bitmap.CompressFormat.PNG, 100, byteArrayOutputStream)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Previously we decided to go with JPEG for these conversions, see

Any reason to go with PNG in this scenario?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For image editing like this, I figured the lossless format would make a better default.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I can replace it with the 80% jpeg, or we could change it to PNG across the board.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's go with JPEG for consistency. That being said, it's absolutely true that we need to provide better interfaces to allow devs to pass the raw bytes (and/or file descriptors) for better compatibility

val byteArray = byteArrayOutputStream.toByteArray()
return ImagenInlineImage(data = byteArray, mimeType = "image/png")
}
Loading