Skip to content

feat(ui): chatgpt ref images & img2img #7974

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Apr 30, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion invokeai/frontend/web/public/locales/en.json
Original file line number Diff line number Diff line change
Expand Up @@ -1322,7 +1322,8 @@
"unableToCopyDesc": "Your browser does not support clipboard access. Firefox users may be able to fix this by following ",
"unableToCopyDesc_theseSteps": "these steps",
"fluxFillIncompatibleWithT2IAndI2I": "FLUX Fill is not compatible with Text to Image or Image to Image. Use other FLUX models for these tasks.",
"imagen3IncompatibleGenerationMode": "Imagen3 only supports Text to Image. Use other models for Image to Image, Inpainting and Outpainting tasks.",
"imagen3IncompatibleGenerationMode": "Google Imagen3 supports Text to Image only. Ensure the bounding box is empty, or use other models for Image to Image, Inpainting and Outpainting tasks.",
"chatGPT4oIncompatibleGenerationMode": "ChatGPT 4o supports Text to Image and Image to Image. Use other models Inpainting and Outpainting tasks.",
"problemUnpublishingWorkflow": "Problem Unpublishing Workflow",
"problemUnpublishingWorkflowDescription": "There was a problem unpublishing the workflow. Please try again.",
"workflowUnpublished": "Workflow Unpublished"
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
import { Combobox, FormControl, Tooltip } from '@invoke-ai/ui-library';
import { useAppSelector } from 'app/store/storeHooks';
import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
import { selectBase } from 'features/controlLayers/store/paramsSlice';
import { NavigateToModelManagerButton } from 'features/parameters/components/MainModel/NavigateToModelManagerButton';
import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { useGlobalReferenceImageModels } from 'services/api/hooks/modelsByType';
import type { AnyModelConfig, ApiModelConfig, FLUXReduxModelConfig, IPAdapterModelConfig } from 'services/api/types';

type Props = {
modelKey: string | null;
onChangeModel: (modelConfig: IPAdapterModelConfig | FLUXReduxModelConfig | ApiModelConfig) => void;
};

export const GlobalReferenceImageModel = memo(({ modelKey, onChangeModel }: Props) => {
const { t } = useTranslation();
const currentBaseModel = useAppSelector(selectBase);
const [modelConfigs, { isLoading }] = useGlobalReferenceImageModels();
const selectedModel = useMemo(() => modelConfigs.find((m) => m.key === modelKey), [modelConfigs, modelKey]);

const _onChangeModel = useCallback(
(modelConfig: IPAdapterModelConfig | FLUXReduxModelConfig | ApiModelConfig | null) => {
if (!modelConfig) {
return;
}
onChangeModel(modelConfig);
},
[onChangeModel]
);

const getIsDisabled = useCallback(
(model: AnyModelConfig): boolean => {
const hasMainModel = Boolean(currentBaseModel);
const hasSameBase = currentBaseModel === model.base;
return !hasMainModel || !hasSameBase;
},
[currentBaseModel]
);

const { options, value, onChange, noOptionsMessage } = useGroupedModelCombobox({
modelConfigs,
onChange: _onChangeModel,
selectedModel,
getIsDisabled,
isLoading,
});

return (
<Tooltip label={selectedModel?.description}>
<FormControl isInvalid={!value || currentBaseModel !== selectedModel?.base} w="full">
<Combobox
options={options}
placeholder={t('common.placeholderSelectAModel')}
value={value}
onChange={onChange}
noOptionsMessage={noOptionsMessage}
/>
<NavigateToModelManagerButton />
</FormControl>
</Tooltip>
);
});

GlobalReferenceImageModel.displayName = 'GlobalReferenceImageModel';
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ export const IPAdapterImagePreview = memo(
)}
{imageDTO && (
<>
<DndImage imageDTO={imageDTO} borderWidth={1} borderStyle="solid" />
<DndImage imageDTO={imageDTO} borderWidth={1} borderStyle="solid" w="full" />
<Flex position="absolute" flexDir="column" top={2} insetInlineEnd={2} gap={1}>
<DndImageIcon
onClick={handleResetControlImage}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import { CanvasEntitySettingsWrapper } from 'features/controlLayers/components/c
import { Weight } from 'features/controlLayers/components/common/Weight';
import { CLIPVisionModel } from 'features/controlLayers/components/IPAdapter/CLIPVisionModel';
import { FLUXReduxImageInfluence } from 'features/controlLayers/components/IPAdapter/FLUXReduxImageInfluence';
import { GlobalReferenceImageModel } from 'features/controlLayers/components/IPAdapter/GlobalReferenceImageModel';
import { IPAdapterMethod } from 'features/controlLayers/components/IPAdapter/IPAdapterMethod';
import { IPAdapterSettingsEmptyState } from 'features/controlLayers/components/IPAdapter/IPAdapterSettingsEmptyState';
import { useEntityIdentifierContext } from 'features/controlLayers/contexts/EntityIdentifierContext';
Expand Down Expand Up @@ -33,10 +34,9 @@ import { setGlobalReferenceImageDndTarget } from 'features/dnd/dnd';
import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { PiBoundingBoxBold } from 'react-icons/pi';
import type { FLUXReduxModelConfig, ImageDTO, IPAdapterModelConfig } from 'services/api/types';
import type { ApiModelConfig, FLUXReduxModelConfig, ImageDTO, IPAdapterModelConfig } from 'services/api/types';

import { IPAdapterImagePreview } from './IPAdapterImagePreview';
import { IPAdapterModel } from './IPAdapterModel';

const buildSelectIPAdapter = (entityIdentifier: CanvasEntityIdentifier<'reference_image'>) =>
createSelector(
Expand Down Expand Up @@ -80,7 +80,7 @@ const IPAdapterSettingsContent = memo(() => {
);

const onChangeModel = useCallback(
(modelConfig: IPAdapterModelConfig | FLUXReduxModelConfig) => {
(modelConfig: IPAdapterModelConfig | FLUXReduxModelConfig | ApiModelConfig) => {
dispatch(referenceImageIPAdapterModelChanged({ entityIdentifier, modelConfig }));
},
[dispatch, entityIdentifier]
Expand Down Expand Up @@ -113,11 +113,7 @@ const IPAdapterSettingsContent = memo(() => {
<CanvasEntitySettingsWrapper>
<Flex flexDir="column" gap={2} position="relative" w="full">
<Flex gap={2} alignItems="center" w="full">
<IPAdapterModel
isRegionalGuidance={false}
modelKey={ipAdapter.model?.key ?? null}
onChangeModel={onChangeModel}
/>
<GlobalReferenceImageModel modelKey={ipAdapter.model?.key ?? null} onChangeModel={onChangeModel} />
{ipAdapter.type === 'ip_adapter' && (
<CLIPVisionModel model={ipAdapter.clipVisionModel} onChange={onChangeCLIPVisionModel} />
)}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,29 +5,26 @@ import { selectBase } from 'features/controlLayers/store/paramsSlice';
import { NavigateToModelManagerButton } from 'features/parameters/components/MainModel/NavigateToModelManagerButton';
import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { useIPAdapterOrFLUXReduxModels } from 'services/api/hooks/modelsByType';
import { useRegionalReferenceImageModels } from 'services/api/hooks/modelsByType';
import type { AnyModelConfig, FLUXReduxModelConfig, IPAdapterModelConfig } from 'services/api/types';

type Props = {
isRegionalGuidance: boolean;
modelKey: string | null;
onChangeModel: (modelConfig: IPAdapterModelConfig | FLUXReduxModelConfig) => void;
};

export const IPAdapterModel = memo(({ isRegionalGuidance, modelKey, onChangeModel }: Props) => {
const filter = (config: IPAdapterModelConfig | FLUXReduxModelConfig) => {
// FLUX supports regional guidance for FLUX Redux models only - not IP Adapter models.
if (config.base === 'flux' && config.type === 'ip_adapter') {
return false;
}
return true;
};

export const RegionalReferenceImageModel = memo(({ modelKey, onChangeModel }: Props) => {
const { t } = useTranslation();
const currentBaseModel = useAppSelector(selectBase);
const filter = useCallback(
(config: IPAdapterModelConfig | FLUXReduxModelConfig) => {
// FLUX supports regional guidance for FLUX Redux models only - not IP Adapter models.
if (isRegionalGuidance && config.base === 'flux' && config.type === 'ip_adapter') {
return false;
}
return true;
},
[isRegionalGuidance]
);
const [modelConfigs, { isLoading }] = useIPAdapterOrFLUXReduxModels(filter);
const [modelConfigs, { isLoading }] = useRegionalReferenceImageModels(filter);
const selectedModel = useMemo(() => modelConfigs.find((m) => m.key === modelKey), [modelConfigs, modelKey]);

const _onChangeModel = useCallback(
Expand Down Expand Up @@ -73,4 +70,4 @@ export const IPAdapterModel = memo(({ isRegionalGuidance, modelKey, onChangeMode
);
});

IPAdapterModel.displayName = 'IPAdapterModel';
RegionalReferenceImageModel.displayName = 'RegionalReferenceImageModel';
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import { CLIPVisionModel } from 'features/controlLayers/components/IPAdapter/CLI
import { FLUXReduxImageInfluence } from 'features/controlLayers/components/IPAdapter/FLUXReduxImageInfluence';
import { IPAdapterImagePreview } from 'features/controlLayers/components/IPAdapter/IPAdapterImagePreview';
import { IPAdapterMethod } from 'features/controlLayers/components/IPAdapter/IPAdapterMethod';
import { IPAdapterModel } from 'features/controlLayers/components/IPAdapter/IPAdapterModel';
import { RegionalReferenceImageModel } from 'features/controlLayers/components/IPAdapter/RegionalReferenceImageModel';
import { RegionalGuidanceIPAdapterSettingsEmptyState } from 'features/controlLayers/components/RegionalGuidance/RegionalGuidanceIPAdapterSettingsEmptyState';
import { useEntityIdentifierContext } from 'features/controlLayers/contexts/EntityIdentifierContext';
import { usePullBboxIntoRegionalGuidanceReferenceImage } from 'features/controlLayers/hooks/saveCanvasHooks';
Expand Down Expand Up @@ -140,11 +140,7 @@ const RegionalGuidanceIPAdapterSettingsContent = memo(({ referenceImageId }: Pro
</Flex>
<Flex flexDir="column" gap={2} position="relative" w="full">
<Flex gap={2} alignItems="center" w="full">
<IPAdapterModel
isRegionalGuidance={true}
modelKey={ipAdapter.model?.key ?? null}
onChangeModel={onChangeModel}
/>
<RegionalReferenceImageModel modelKey={ipAdapter.model?.key ?? null} onChangeModel={onChangeModel} />
{ipAdapter.type === 'ip_adapter' && (
<CLIPVisionModel model={ipAdapter.clipVisionModel} onChange={onChangeCLIPVisionModel} />
)}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,26 @@ import { selectBase } from 'features/controlLayers/store/paramsSlice';
import { selectCanvasSlice, selectEntity } from 'features/controlLayers/store/selectors';
import type {
CanvasEntityIdentifier,
CanvasReferenceImageState,
CanvasRegionalGuidanceState,
ControlLoRAConfig,
ControlNetConfig,
IPAdapterConfig,
T2IAdapterConfig,
} from 'features/controlLayers/store/types';
import { initialControlNet, initialIPAdapter, initialT2IAdapter } from 'features/controlLayers/store/util';
import {
initialChatGPT4oReferenceImage,
initialControlNet,
initialIPAdapter,
initialT2IAdapter,
} from 'features/controlLayers/store/util';
import { zModelIdentifierField } from 'features/nodes/types/common';
import { useCallback } from 'react';
import { modelConfigsAdapterSelectors, selectModelConfigsQuery } from 'services/api/endpoints/models';
import {
modelConfigsAdapterSelectors,
selectMainModelConfig,
selectModelConfigsQuery,
} from 'services/api/endpoints/models';
import type {
ControlLoRAModelConfig,
ControlNetModelConfig,
Expand Down Expand Up @@ -64,6 +74,35 @@ export const selectDefaultControlAdapter = createSelector(
}
);

export const selectDefaultRefImageConfig = createSelector(
selectMainModelConfig,
selectModelConfigsQuery,
selectBase,
(selectedMainModel, query, base): CanvasReferenceImageState['ipAdapter'] => {
if (selectedMainModel?.base === 'chatgpt-4o') {
const referenceImage = deepClone(initialChatGPT4oReferenceImage);
referenceImage.model = zModelIdentifierField.parse(selectedMainModel);
return referenceImage;
}

const { data } = query;
let model: IPAdapterModelConfig | null = null;
if (data) {
const modelConfigs = modelConfigsAdapterSelectors.selectAll(data).filter(isIPAdapterModelConfig);
const compatibleModels = modelConfigs.filter((m) => (base ? m.base === base : true));
model = compatibleModels[0] ?? modelConfigs[0] ?? null;
}
const ipAdapter = deepClone(initialIPAdapter);
if (model) {
ipAdapter.model = zModelIdentifierField.parse(model);
if (model.base === 'flux') {
ipAdapter.clipVisionModel = 'ViT-L';
}
}
return ipAdapter;
}
);

/**
* Selects the default IP adapter configuration based on the model configurations and the base.
*
Expand Down Expand Up @@ -146,11 +185,11 @@ export const useAddRegionalReferenceImage = () => {

export const useAddGlobalReferenceImage = () => {
const dispatch = useAppDispatch();
const defaultIPAdapter = useAppSelector(selectDefaultIPAdapter);
const defaultRefImage = useAppSelector(selectDefaultRefImageConfig);
const func = useCallback(() => {
const overrides = { ipAdapter: deepClone(defaultIPAdapter) };
const overrides = { ipAdapter: deepClone(defaultRefImage) };
dispatch(referenceImageAdded({ isSelected: true, overrides }));
}, [defaultIPAdapter, dispatch]);
}, [defaultRefImage, dispatch]);

return func;
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ import { useAppDispatch, useAppSelector, useAppStore } from 'app/store/storeHook
import { deepClone } from 'common/util/deepClone';
import { withResultAsync } from 'common/util/result';
import { useCanvasManager } from 'features/controlLayers/contexts/CanvasManagerProviderGate';
import { selectDefaultIPAdapter } from 'features/controlLayers/hooks/addLayerHooks';
import { selectDefaultIPAdapter, selectDefaultRefImageConfig } from 'features/controlLayers/hooks/addLayerHooks';
import { getPrefixedId } from 'features/controlLayers/konva/util';
import {
controlLayerAdded,
Expand Down Expand Up @@ -198,7 +198,7 @@ export const useNewRegionalReferenceImageFromBbox = () => {
export const useNewGlobalReferenceImageFromBbox = () => {
const { t } = useTranslation();
const dispatch = useAppDispatch();
const defaultIPAdapter = useAppSelector(selectDefaultIPAdapter);
const defaultIPAdapter = useAppSelector(selectDefaultRefImageConfig);

const arg = useMemo<UseSaveCanvasArg>(() => {
const onSave = (imageDTO: ImageDTO) => {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ export const useIsEntityTypeEnabled = (entityType: CanvasEntityType) => {
const isEntityTypeEnabled = useMemo<boolean>(() => {
switch (entityType) {
case 'reference_image':
return !isSD3 && !isCogView4 && !isImagen3 && !isChatGPT4o;
return !isSD3 && !isCogView4 && !isImagen3;
case 'regional_guidance':
return !isSD3 && !isCogView4 && !isImagen3 && !isChatGPT4o;
case 'control_layer':
Expand Down
Loading