Skip to content

feat(ui): Adds Imagen4 scaffold support #8032

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
May 22, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions invokeai/app/invocations/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ class UIType(str, Enum, metaclass=MetaEnum):
FluxReduxModel = "FluxReduxModelField"
LlavaOnevisionModel = "LLaVAModelField"
Imagen3Model = "Imagen3ModelField"
Imagen4Model = "Imagen4ModelField"
ChatGPT4oModel = "ChatGPT4oModelField"
# endregion

Expand Down
1 change: 1 addition & 0 deletions invokeai/backend/model_manager/taxonomy.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ class BaseModelType(str, Enum):
Flux = "flux"
CogView4 = "cogview4"
Imagen3 = "imagen3"
Imagen4 = "imagen4"
ChatGPT4o = "chatgpt-4o"


Expand Down
2 changes: 1 addition & 1 deletion invokeai/frontend/web/public/locales/en.json
Original file line number Diff line number Diff line change
Expand Up @@ -1332,7 +1332,7 @@
"unableToCopyDesc": "Your browser does not support clipboard access. Firefox users may be able to fix this by following ",
"unableToCopyDesc_theseSteps": "these steps",
"fluxFillIncompatibleWithT2IAndI2I": "FLUX Fill is not compatible with Text to Image or Image to Image. Use other FLUX models for these tasks.",
"imagen3IncompatibleGenerationMode": "Google Imagen3 supports Text to Image only. Use other models for Image to Image, Inpainting and Outpainting tasks.",
"imagenIncompatibleGenerationMode": "Google {{model}} supports Text to Image only. Use other models for Image to Image, Inpainting and Outpainting tasks.",
"chatGPT4oIncompatibleGenerationMode": "ChatGPT 4o supports Text to Image and Image to Image only. Use other models Inpainting and Outpainting tasks.",
"problemUnpublishingWorkflow": "Problem Unpublishing Workflow",
"problemUnpublishingWorkflowDescription": "There was a problem unpublishing the workflow. Please try again.",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import { buildChatGPT4oGraph } from 'features/nodes/util/graph/generation/buildC
import { buildCogView4Graph } from 'features/nodes/util/graph/generation/buildCogView4Graph';
import { buildFLUXGraph } from 'features/nodes/util/graph/generation/buildFLUXGraph';
import { buildImagen3Graph } from 'features/nodes/util/graph/generation/buildImagen3Graph';
import { buildImagen4Graph } from 'features/nodes/util/graph/generation/buildImagen4Graph';
import { buildSD1Graph } from 'features/nodes/util/graph/generation/buildSD1Graph';
import { buildSD3Graph } from 'features/nodes/util/graph/generation/buildSD3Graph';
import { buildSDXLGraph } from 'features/nodes/util/graph/generation/buildSDXLGraph';
Expand Down Expand Up @@ -54,6 +55,8 @@ export const addEnqueueRequestedLinear = (startAppListening: AppStartListening)
return await buildCogView4Graph(state, manager);
case 'imagen3':
return await buildImagen3Graph(state, manager);
case 'imagen4':
return await buildImagen4Graph(state, manager);
case 'chatgpt-4o':
return await buildChatGPT4oGraph(state, manager);
default:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ import {
selectIsChatGTP4o,
selectIsCogView4,
selectIsImagen3,
selectIsImagen4,
selectIsSD3,
} from 'features/controlLayers/store/paramsSlice';
import type { CanvasEntityType } from 'features/controlLayers/store/types';
Expand All @@ -14,24 +15,25 @@ export const useIsEntityTypeEnabled = (entityType: CanvasEntityType) => {
const isSD3 = useAppSelector(selectIsSD3);
const isCogView4 = useAppSelector(selectIsCogView4);
const isImagen3 = useAppSelector(selectIsImagen3);
const isImagen4 = useAppSelector(selectIsImagen4);
const isChatGPT4o = useAppSelector(selectIsChatGTP4o);

const isEntityTypeEnabled = useMemo<boolean>(() => {
switch (entityType) {
case 'reference_image':
return !isSD3 && !isCogView4 && !isImagen3;
return !isSD3 && !isCogView4 && !isImagen3 && !isImagen4;
case 'regional_guidance':
return !isSD3 && !isCogView4 && !isImagen3 && !isChatGPT4o;
return !isSD3 && !isCogView4 && !isImagen3 && !isImagen4 && !isChatGPT4o;
case 'control_layer':
return !isSD3 && !isCogView4 && !isImagen3 && !isChatGPT4o;
return !isSD3 && !isCogView4 && !isImagen3 && !isImagen4 && !isChatGPT4o;
case 'inpaint_mask':
return !isImagen3 && !isChatGPT4o;
return !isImagen3 && !isImagen4 && !isChatGPT4o;
case 'raster_layer':
return !isImagen3 && !isChatGPT4o;
return !isImagen3 && !isImagen4 && !isChatGPT4o;
default:
assert<Equals<typeof entityType, never>>(false);
}
}, [entityType, isSD3, isCogView4, isImagen3, isChatGPT4o]);
}, [entityType, isSD3, isCogView4, isImagen3, isImagen4, isChatGPT4o]);

return isEntityTypeEnabled;
};
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import { selectModel } from 'features/controlLayers/store/paramsSlice';
import { selectBbox } from 'features/controlLayers/store/selectors';
import type { Coordinate, Rect, Tool } from 'features/controlLayers/store/types';
import type { ModelIdentifierField } from 'features/nodes/types/common';
import { API_BASE_MODELS } from 'features/parameters/types/constants';
import Konva from 'konva';
import { noop } from 'lodash-es';
import { atom } from 'nanostores';
Expand Down Expand Up @@ -235,7 +236,7 @@ export class CanvasBboxToolModule extends CanvasModuleBase {
if (tool !== 'bbox') {
return NO_ANCHORS;
}
if (model?.base === 'imagen3' || model?.base === 'chatgpt-4o') {
if (model?.base && API_BASE_MODELS.includes(model.base)) {
// The bbox is not resizable in these modes
return NO_ANCHORS;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ import {
import { simplifyFlatNumbersArray } from 'features/controlLayers/util/simplify';
import { isMainModelBase, zModelIdentifierField } from 'features/nodes/types/common';
import { ASPECT_RATIO_MAP } from 'features/parameters/components/Bbox/constants';
import { API_BASE_MODELS } from 'features/parameters/types/constants';
import { getGridSize, getIsSizeOptimal, getOptimalDimension } from 'features/parameters/util/optimalDimension';
import type { IRect } from 'konva/lib/types';
import { isEqual, merge } from 'lodash-es';
Expand Down Expand Up @@ -68,7 +69,7 @@ import type {
IPMethodV2,
T2IAdapterConfig,
} from './types';
import { getEntityIdentifier, isChatGPT4oAspectRatioID, isImagen3AspectRatioID, isRenderableEntity } from './types';
import { getEntityIdentifier, isChatGPT4oAspectRatioID, isImagenAspectRatioID, isRenderableEntity } from './types';
import {
converters,
getControlLayerState,
Expand Down Expand Up @@ -1236,7 +1237,10 @@ export const canvasSlice = createSlice({
state.bbox.aspectRatio.id = id;
if (id === 'Free') {
state.bbox.aspectRatio.isLocked = false;
} else if (state.bbox.modelBase === 'imagen3' && isImagen3AspectRatioID(id)) {
} else if (
(state.bbox.modelBase === 'imagen3' || state.bbox.modelBase === 'imagen4') &&
isImagenAspectRatioID(id)
) {
// Imagen3 has specific output sizes that are not exactly the same as the aspect ratio. Need special handling.
if (id === '16:9') {
state.bbox.rect.width = 1408;
Expand Down Expand Up @@ -1742,7 +1746,7 @@ export const canvasSlice = createSlice({
const base = model?.base;
if (isMainModelBase(base) && state.bbox.modelBase !== base) {
state.bbox.modelBase = base;
if (base === 'imagen3' || base === 'chatgpt-4o') {
if (API_BASE_MODELS.includes(base)) {
state.bbox.aspectRatio.isLocked = true;
state.bbox.aspectRatio.value = 1;
state.bbox.aspectRatio.id = '1:1';
Expand Down Expand Up @@ -1881,7 +1885,7 @@ export const canvasPersistConfig: PersistConfig<CanvasState> = {
};

const syncScaledSize = (state: CanvasState) => {
if (state.bbox.modelBase === 'imagen3' || state.bbox.modelBase === 'chatgpt-4o') {
if (API_BASE_MODELS.includes(state.bbox.modelBase)) {
// Imagen3 has fixed sizes. Scaled bbox is not supported.
return;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -381,6 +381,7 @@ export const selectIsFLUX = createParamsSelector((params) => params.model?.base
export const selectIsSD3 = createParamsSelector((params) => params.model?.base === 'sd-3');
export const selectIsCogView4 = createParamsSelector((params) => params.model?.base === 'cogview4');
export const selectIsImagen3 = createParamsSelector((params) => params.model?.base === 'imagen3');
export const selectIsImagen4 = createParamsSelector((params) => params.model?.base === 'imagen4');
export const selectIsChatGTP4o = createParamsSelector((params) => params.model?.base === 'chatgpt-4o');

export const selectModel = createParamsSelector((params) => params.model);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -406,7 +406,7 @@ export type StagingAreaImage = {
export const zAspectRatioID = z.enum(['Free', '16:9', '3:2', '4:3', '1:1', '3:4', '2:3', '9:16']);

export const zImagen3AspectRatioID = z.enum(['16:9', '4:3', '1:1', '3:4', '9:16']);
export const isImagen3AspectRatioID = (v: unknown): v is z.infer<typeof zImagen3AspectRatioID> =>
export const isImagenAspectRatioID = (v: unknown): v is z.infer<typeof zImagen3AspectRatioID> =>
zImagen3AspectRatioID.safeParse(v).success;

export const zChatGPT4oAspectRatioID = z.enum(['3:2', '1:1', '2:3']);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ export const BASE_COLOR_MAP: Record<BaseModelType, string> = {
flux: 'gold',
cogview4: 'red',
imagen3: 'pink',
imagen4: 'pink',
'chatgpt-4o': 'pink',
};

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import { FloatGeneratorFieldInputComponent } from 'features/nodes/components/flo
import { ImageFieldCollectionInputComponent } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ImageFieldCollectionInputComponent';
import { ImageGeneratorFieldInputComponent } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ImageGeneratorFieldComponent';
import Imagen3ModelFieldInputComponent from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/Imagen3ModelFieldInputComponent';
import Imagen4ModelFieldInputComponent from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/Imagen4ModelFieldInputComponent';
import { IntegerFieldCollectionInputComponent } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/IntegerFieldCollectionInputComponent';
import { IntegerGeneratorFieldInputComponent } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/IntegerGeneratorFieldComponent';
import ModelIdentifierFieldInputComponent from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelIdentifierFieldInputComponent';
Expand Down Expand Up @@ -63,6 +64,8 @@ import {
isImageGeneratorFieldInputTemplate,
isImagen3ModelFieldInputInstance,
isImagen3ModelFieldInputTemplate,
isImagen4ModelFieldInputInstance,
isImagen4ModelFieldInputTemplate,
isIntegerFieldCollectionInputInstance,
isIntegerFieldCollectionInputTemplate,
isIntegerFieldInputInstance,
Expand Down Expand Up @@ -407,6 +410,13 @@ export const InputFieldRenderer = memo(({ nodeId, fieldName, settings }: Props)
return <Imagen3ModelFieldInputComponent nodeId={nodeId} field={field} fieldTemplate={template} />;
}

if (isImagen4ModelFieldInputTemplate(template)) {
if (!isImagen4ModelFieldInputInstance(field)) {
return null;
}
return <Imagen4ModelFieldInputComponent nodeId={nodeId} field={field} fieldTemplate={template} />;
}

if (isChatGPT4oModelFieldInputTemplate(template)) {
if (!isChatGPT4oModelFieldInputInstance(field)) {
return null;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
import { useAppDispatch } from 'app/store/storeHooks';
import { ModelFieldCombobox } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelFieldCombobox';
import { fieldImagen4ModelValueChanged } from 'features/nodes/store/nodesSlice';
import type { Imagen4ModelFieldInputInstance, Imagen4ModelFieldInputTemplate } from 'features/nodes/types/field';
import { memo, useCallback } from 'react';
import { useImagen4Models } from 'services/api/hooks/modelsByType';
import type { ApiModelConfig } from 'services/api/types';

import type { FieldComponentProps } from './types';

const Imagen4ModelFieldInputComponent = (
props: FieldComponentProps<Imagen4ModelFieldInputInstance, Imagen4ModelFieldInputTemplate>
) => {
const { nodeId, field } = props;
const dispatch = useAppDispatch();

const [modelConfigs, { isLoading }] = useImagen4Models();

const onChange = useCallback(
(value: ApiModelConfig | null) => {
if (!value) {
return;
}
dispatch(
fieldImagen4ModelValueChanged({
nodeId,
fieldName: field.name,
value,
})
);
},
[dispatch, field.name, nodeId]
);

return (
<ModelFieldCombobox
value={field.value}
modelConfigs={modelConfigs}
isLoadingConfigs={isLoading}
onChange={onChange}
required={props.fieldTemplate.required}
/>
);
};

export default memo(Imagen4ModelFieldInputComponent);
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,7 @@ const NODE_TYPE_PUBLISH_DENYLIST = [
'metadata_to_t2i_adapters',
'google_imagen3_generate',
'google_imagen3_edit',
'google_imagen4_generate',
'chatgpt_create_image',
'chatgpt_edit_image',
];
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ import type {
ImageFieldValue,
ImageGeneratorFieldValue,
Imagen3ModelFieldValue,
Imagen4ModelFieldValue,
IntegerFieldCollectionValue,
IntegerFieldValue,
IntegerGeneratorFieldValue,
Expand Down Expand Up @@ -80,6 +81,7 @@ import {
zImageFieldValue,
zImageGeneratorFieldValue,
zImagen3ModelFieldValue,
zImagen4ModelFieldValue,
zIntegerFieldCollectionValue,
zIntegerFieldValue,
zIntegerGeneratorFieldValue,
Expand Down Expand Up @@ -519,6 +521,9 @@ export const nodesSlice = createSlice({
fieldImagen3ModelValueChanged: (state, action: FieldValueAction<Imagen3ModelFieldValue>) => {
fieldValueReducer(state, action, zImagen3ModelFieldValue);
},
fieldImagen4ModelValueChanged: (state, action: FieldValueAction<Imagen4ModelFieldValue>) => {
fieldValueReducer(state, action, zImagen4ModelFieldValue);
},
fieldChatGPT4oModelValueChanged: (state, action: FieldValueAction<ChatGPT4oModelFieldValue>) => {
fieldValueReducer(state, action, zChatGPT4oModelFieldValue);
},
Expand Down Expand Up @@ -690,6 +695,7 @@ export const {
fieldSigLipModelValueChanged,
fieldFluxReduxModelValueChanged,
fieldImagen3ModelValueChanged,
fieldImagen4ModelValueChanged,
fieldChatGPT4oModelValueChanged,
fieldFloatGeneratorValueChanged,
fieldIntegerGeneratorValueChanged,
Expand Down
13 changes: 12 additions & 1 deletion invokeai/frontend/web/src/features/nodes/types/common.ts
Original file line number Diff line number Diff line change
Expand Up @@ -76,10 +76,21 @@ const zBaseModel = z.enum([
'flux',
'cogview4',
'imagen3',
'imagen4',
'chatgpt-4o',
]);
export type BaseModelType = z.infer<typeof zBaseModel>;
export const zMainModelBase = z.enum(['sd-1', 'sd-2', 'sd-3', 'sdxl', 'flux', 'cogview4', 'imagen3', 'chatgpt-4o']);
export const zMainModelBase = z.enum([
'sd-1',
'sd-2',
'sd-3',
'sdxl',
'flux',
'cogview4',
'imagen3',
'imagen4',
'chatgpt-4o',
]);
export type MainModelBase = z.infer<typeof zMainModelBase>;
export const isMainModelBase = (base: unknown): base is MainModelBase => zMainModelBase.safeParse(base).success;
const zModelType = z.enum([
Expand Down
26 changes: 26 additions & 0 deletions invokeai/frontend/web/src/features/nodes/types/field.ts
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,10 @@ const zImagen3ModelFieldType = zFieldTypeBase.extend({
name: z.literal('Imagen3ModelField'),
originalType: zStatelessFieldType.optional(),
});
const zImagen4ModelFieldType = zFieldTypeBase.extend({
name: z.literal('Imagen4ModelField'),
originalType: zStatelessFieldType.optional(),
});
const zChatGPT4oModelFieldType = zFieldTypeBase.extend({
name: z.literal('ChatGPT4oModelField'),
originalType: zStatelessFieldType.optional(),
Expand Down Expand Up @@ -307,6 +311,7 @@ const zStatefulFieldType = z.union([
zSigLipModelFieldType,
zFluxReduxModelFieldType,
zImagen3ModelFieldType,
zImagen4ModelFieldType,
zChatGPT4oModelFieldType,
zColorFieldType,
zSchedulerFieldType,
Expand Down Expand Up @@ -347,6 +352,7 @@ const modelFieldTypeNames = [
zSigLipModelFieldType.shape.name.value,
zFluxReduxModelFieldType.shape.name.value,
zImagen3ModelFieldType.shape.name.value,
zImagen4ModelFieldType.shape.name.value,
zChatGPT4oModelFieldType.shape.name.value,
// Stateless model fields
'UNetField',
Expand Down Expand Up @@ -1207,6 +1213,24 @@ export const isImagen3ModelFieldInputTemplate =
buildTemplateTypeGuard<Imagen3ModelFieldInputTemplate>('Imagen3ModelField');
// #endregion

// #region Imagen4ModelField
export const zImagen4ModelFieldValue = zModelIdentifierField.optional();
const zImagen4ModelFieldInputInstance = zFieldInputInstanceBase.extend({
value: zImagen4ModelFieldValue,
});
const zImagen4ModelFieldInputTemplate = zFieldInputTemplateBase.extend({
type: zImagen4ModelFieldType,
originalType: zFieldType.optional(),
default: zImagen4ModelFieldValue,
});
export type Imagen4ModelFieldValue = z.infer<typeof zImagen4ModelFieldValue>;
export type Imagen4ModelFieldInputInstance = z.infer<typeof zImagen4ModelFieldInputInstance>;
export type Imagen4ModelFieldInputTemplate = z.infer<typeof zImagen4ModelFieldInputTemplate>;
export const isImagen4ModelFieldInputInstance = buildInstanceTypeGuard(zImagen4ModelFieldInputInstance);
export const isImagen4ModelFieldInputTemplate =
buildTemplateTypeGuard<Imagen4ModelFieldInputTemplate>('Imagen4ModelField');
// #endregion

// #region ChatGPT4oModelField
export const zChatGPT4oModelFieldValue = zModelIdentifierField.optional();
const zChatGPT4oModelFieldInputInstance = zFieldInputInstanceBase.extend({
Expand Down Expand Up @@ -1857,6 +1881,7 @@ export const zStatefulFieldValue = z.union([
zSigLipModelFieldValue,
zFluxReduxModelFieldValue,
zImagen3ModelFieldValue,
zImagen4ModelFieldValue,
zChatGPT4oModelFieldValue,
zColorFieldValue,
zSchedulerFieldValue,
Expand Down Expand Up @@ -1949,6 +1974,7 @@ const zStatefulFieldInputTemplate = z.union([
zSigLipModelFieldInputTemplate,
zFluxReduxModelFieldInputTemplate,
zImagen3ModelFieldInputTemplate,
zImagen4ModelFieldInputTemplate,
zChatGPT4oModelFieldInputTemplate,
zColorFieldInputTemplate,
zSchedulerFieldInputTemplate,
Expand Down
Loading