Merged
17 commits
3146d44
feat(nodes): rename "FLUX Fill" -> "FLUX Fill Conditioning"
psychedelicious Mar 20, 2025
715c90b
fix(nodes): ensure alpha mask is opened as RGBA
psychedelicious Mar 20, 2025
8711475
feat(nodes): add `apply_mask_to_image` node
psychedelicious Mar 20, 2025
a0d2cd7
feat(nodes): add expand_mask_with_fade to better handle canvas compos…
psychedelicious Mar 20, 2025
d3d8aaf
feat(nodes): deprecate canvas_v2_mask_and_crop
psychedelicious Mar 20, 2025
f66bf61
chore(ui): typegen
psychedelicious Mar 20, 2025
701cc28
refactor(ui): use more succinct syntax to opt-out of RTKQ caching for…
psychedelicious Mar 20, 2025
8785ad9
feat(ui): add FLUX Fill graph builder util
psychedelicious Mar 20, 2025
c178e5e
refactor(ui): use new compositing nodes for inpaint/outpaint graphs
psychedelicious Mar 20, 2025
5a41ce6
feat(ui): add selector to select the main model full config object
psychedelicious Mar 20, 2025
83df5e3
feat(ui): pass the full model config throughout validation logic
psychedelicious Mar 20, 2025
b8df47e
feat(ui): add warning for FLUX Fill + Control LoRA
psychedelicious Mar 20, 2025
8963bff
feat(ui): bump FLUX guidance up to 30 if it's too low during graph bu…
psychedelicious Mar 20, 2025
bb89c4c
feat(ui): better error message/warning for FLUX Fill w/ Control LoRA
psychedelicious Mar 20, 2025
c0014bb
feat(ui): better error for FLUX Fill + t2i/i2i incompatibility
psychedelicious Mar 20, 2025
e074d52
refactor(ui): just always set guidance to 30 when using FLUX Fill
psychedelicious Mar 20, 2025
0f1519e
feat(ui): hide guidance when FLUX Fill model selected
psychedelicious Mar 20, 2025
2 changes: 1 addition & 1 deletion invokeai/app/invocations/flux_fill.py
@@ -27,7 +27,7 @@ class FluxFillOutput(BaseInvocationOutput):

@invocation(
"flux_fill",
title="FLUX Fill",
title="FLUX Fill Conditioning",
tags=["inpaint"],
category="inpaint",
version="1.0.0",
108 changes: 107 additions & 1 deletion invokeai/app/invocations/image.py
@@ -1051,7 +1051,7 @@ def invoke(self, context: InvocationContext) -> ImageOutput:
tags=["image", "mask", "id"],
category="image",
version="1.0.0",
classification=Classification.Internal,
classification=Classification.Deprecated,
)
class CanvasV2MaskAndCropInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Handles Canvas V2 image output masking and cropping"""
@@ -1089,6 +1089,112 @@ def invoke(self, context: InvocationContext) -> ImageOutput:
return ImageOutput.build(image_dto)


@invocation(
"expand_mask_with_fade", title="Expand Mask with Fade", tags=["image", "mask"], category="image", version="1.0.0"
)
class ExpandMaskWithFadeInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Expands a mask with a fade effect. The mask uses black to indicate areas to keep from the generated image and white for areas to discard.
The mask is thresholded to create a binary mask, and then a distance transform is applied to create a fade effect.
The fade size is specified in pixels, and the mask is expanded by that amount. The result is a mask with a smooth transition from black to white.
"""

mask: ImageField = InputField(description="The mask to expand")
threshold: int = InputField(default=0, ge=0, le=255, description="The threshold for the binary mask (0-255)")
fade_size_px: int = InputField(default=32, ge=0, description="The size of the fade in pixels")

def invoke(self, context: InvocationContext) -> ImageOutput:
pil_mask = context.images.get_pil(self.mask.image_name, mode="L")

np_mask = numpy.array(pil_mask)

# Threshold the mask to create a binary mask - 0 for black, 255 for white
# If we don't threshold we can get some weird artifacts
np_mask = numpy.where(np_mask > self.threshold, 255, 0).astype(numpy.uint8)

# Create a mask for the black region (1 where black, 0 otherwise)
black_mask = (np_mask == 0).astype(numpy.uint8)

# Invert the black region
bg_mask = 1 - black_mask

# Create a distance transform of the inverted mask
dist = cv2.distanceTransform(bg_mask, cv2.DIST_L2, 5)

# Normalize distances so that pixels <fade_size_px become a linear gradient (0 to 1)
d_norm = numpy.clip(dist / self.fade_size_px, 0, 1)

# Control points: x values (normalized distance) and corresponding fade pct y values.

# There are some magic numbers here that are used to create a smooth transition:
# - The first point is at 0% of fade size from edge of mask (meaning the edge of the mask), and is 0% fade (black)
# - The second point is 1px from the edge of the mask and also has 0% fade, effectively expanding the mask
# by 1px. This fixes an issue where artifacts can occur at the edge of the mask
# - The third point is at 20% of the fade size from the edge of the mask and has 20% fade
# - The fourth point is at 80% of the fade size from the edge of the mask and has 90% fade
# - The last point is at 100% of the fade size from the edge of the mask and has 100% fade (white)

# x values: 0 = mask edge, 1 = fade_size_px from edge
x_control = numpy.array([0.0, 1.0 / self.fade_size_px, 0.2, 0.8, 1.0])
# y values: 0 = black, 1 = white
y_control = numpy.array([0.0, 0.0, 0.2, 0.9, 1.0])

# Fit a cubic polynomial that smoothly passes through the control points
coeffs = numpy.polyfit(x_control, y_control, 3)
poly = numpy.poly1d(coeffs)

# Evaluate and clip the smooth mapping
feather = numpy.clip(poly(d_norm), 0, 1)

# Build final image.
np_result = numpy.where(black_mask == 1, 0, (feather * 255).astype(numpy.uint8))

# Convert back to PIL, grayscale
pil_result = Image.fromarray(np_result.astype(numpy.uint8), mode="L")

image_dto = context.images.save(image=pil_result, image_category=ImageCategory.MASK)

return ImageOutput.build(image_dto)
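
For reference, the fade mapping described in the comments above can be reproduced outside the node graph as a small standalone script. This is only an illustrative sketch of the same thresholding, distance-transform, and cubic-fit steps; it assumes OpenCV, NumPy, and Pillow are installed, and the file names are hypothetical.

import cv2
import numpy as np
from PIL import Image

fade_size_px = 32
threshold = 0

# Load a grayscale mask (hypothetical file) and threshold it to a binary mask.
np_mask = np.array(Image.open("mask.png").convert("L"))
np_mask = np.where(np_mask > threshold, 255, 0).astype(np.uint8)

# 1 where the mask is black (the region to keep), 0 elsewhere.
black_mask = (np_mask == 0).astype(np.uint8)

# Distance of every non-black pixel from the black region, normalized so that
# fade_size_px maps to 1.0.
dist = cv2.distanceTransform(1 - black_mask, cv2.DIST_L2, 5)
d_norm = np.clip(dist / fade_size_px, 0, 1)

# Cubic least-squares fit through the same control points used by the node.
x_control = [0.0, 1.0 / fade_size_px, 0.2, 0.8, 1.0]
y_control = [0.0, 0.0, 0.2, 0.9, 1.0]
poly = np.poly1d(np.polyfit(x_control, y_control, 3))

# Evaluate the fade, clamp it, and keep the black region fully black.
feather = np.clip(poly(d_norm), 0, 1)
np_result = np.where(black_mask == 1, 0, (feather * 255).astype(np.uint8))

Image.fromarray(np_result.astype(np.uint8), mode="L").save("mask_faded.png")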


@invocation(
"apply_mask_to_image",
title="Apply Mask to Image",
tags=["image", "mask", "blend"],
category="image",
version="1.0.0",
)
class ApplyMaskToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
"""
Extracts a region from a generated image using a mask and blends it seamlessly onto a source image.
The mask uses black to indicate areas to keep from the generated image and white for areas to discard.
"""

image: ImageField = InputField(description="The image from which to extract the masked region")
mask: ImageField = InputField(description="The mask defining the region (black=keep, white=discard)")
invert_mask: bool = InputField(
default=False,
description="Whether to invert the mask before applying it",
)

def invoke(self, context: InvocationContext) -> ImageOutput:
# Load images
image = context.images.get_pil(self.image.image_name, mode="RGBA")
mask = context.images.get_pil(self.mask.image_name, mode="L")

if self.invert_mask:
# Invert the mask if requested
mask = ImageOps.invert(mask.copy())

# Combine the mask as the alpha channel of the image
r, g, b, _ = image.split() # Split the image into RGB and alpha channels
result_image = Image.merge("RGBA", (r, g, b, mask)) # Use the mask as the new alpha channel

# Save the resulting image
image_dto = context.images.save(image=result_image)

return ImageOutput.build(image_dto)
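
A usage illustration (not part of the diff): because the node writes the mask into the alpha channel, its output composites naturally with Pillow downstream. A minimal sketch, assuming same-size RGBA inputs; the file names are hypothetical.

from PIL import Image, ImageOps

generated = Image.open("generated.png").convert("RGBA")
source = Image.open("source.png").convert("RGBA")
mask = Image.open("mask.png").convert("L")

# Optionally flip the convention, as the node's invert_mask flag does.
# mask = ImageOps.invert(mask)

# Use the mask as the alpha channel: 255 (white) -> opaque, 0 (black) -> transparent.
r, g, b, _ = generated.split()
masked = Image.merge("RGBA", (r, g, b, mask))

# Alpha compositing keeps the generated pixels where the alpha is opaque and
# lets the source image show through where it is transparent.
result = Image.alpha_composite(source, masked)
result.save("composited.png")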


@invocation(
"img_noise",
title="Add Image Noise",
2 changes: 1 addition & 1 deletion invokeai/app/invocations/mask.py
@@ -67,7 +67,7 @@ class AlphaMaskToTensorInvocation(BaseInvocation):
invert: bool = InputField(default=False, description="Whether to invert the mask.")

def invoke(self, context: InvocationContext) -> MaskOutput:
image = context.images.get_pil(self.image.image_name)
image = context.images.get_pil(self.image.image_name, mode="RGBA")
mask = torch.zeros((1, image.height, image.width), dtype=torch.bool)
if self.invert:
mask[0] = torch.tensor(np.array(image)[:, :, 3] == 0, dtype=torch.bool)
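
The one-line mode="RGBA" change above matters because indexing the alpha plane only works when the underlying array actually has four channels. A minimal sketch of the same idea with plain Pillow/NumPy (the file name is hypothetical):

import numpy as np
from PIL import Image

# Without forcing a mode, np.array(img) has shape (H, W) for an "L" image and
# (H, W, 3) for "RGB", so [:, :, 3] fails for anything not already RGBA.
image = Image.open("alpha_mask.png").convert("RGBA")

alpha = np.array(image)[:, :, 3]  # shape (H, W); channel 3 is the alpha plane
mask = alpha > 0                  # True where the layer is at least partially opaque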
6 changes: 4 additions & 2 deletions invokeai/frontend/web/public/locales/en.json
@@ -1299,7 +1299,8 @@
"problemDeletingWorkflow": "Problem Deleting Workflow",
"unableToCopy": "Unable to Copy",
"unableToCopyDesc": "Your browser does not support clipboard access. Firefox users may be able to fix this by following ",
"unableToCopyDesc_theseSteps": "these steps"
"unableToCopyDesc_theseSteps": "these steps",
"fluxFillIncompatibleWithT2IAndI2I": "FLUX Fill is not compatible with Text to Image or Image to Image. Use other FLUX models for these tasks."
},
"popovers": {
"clipSkip": {
@@ -1950,7 +1951,8 @@
"rgNegativePromptNotSupported": "Negative Prompt not supported for selected base model",
"rgReferenceImagesNotSupported": "regional Reference Images not supported for selected base model",
"rgAutoNegativeNotSupported": "Auto-Negative not supported for selected base model",
"rgNoRegion": "no region drawn"
"rgNoRegion": "no region drawn",
"fluxFillIncompatibleWithControlLoRA": "Control LoRA is not compatible with FLUX Fill"
},
"errors": {
"unableToFindImage": "Unable to find image",
Original file line number Diff line number Diff line change
@@ -4,7 +4,6 @@ import { EMPTY_ARRAY } from 'app/store/constants';
import { useAppSelector } from 'app/store/storeHooks';
import { useEntityIdentifierContext } from 'features/controlLayers/contexts/EntityIdentifierContext';
import { useEntityIsEnabled } from 'features/controlLayers/hooks/useEntityIsEnabled';
import { selectModel } from 'features/controlLayers/store/paramsSlice';
import { selectCanvasSlice, selectEntityOrThrow } from 'features/controlLayers/store/selectors';
import type { CanvasEntityIdentifier } from 'features/controlLayers/store/types';
import {
@@ -19,11 +18,12 @@ import { upperFirst } from 'lodash-es';
import { memo, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { PiWarningBold } from 'react-icons/pi';
import { selectMainModelConfig } from 'services/api/endpoints/models';
import type { Equals } from 'tsafe';
import { assert } from 'tsafe';

const buildSelectWarnings = (entityIdentifier: CanvasEntityIdentifier, t: TFunction) => {
return createSelector(selectCanvasSlice, selectModel, (canvas, model) => {
return createSelector(selectCanvasSlice, selectMainModelConfig, (canvas, model) => {
// This component is used within a <CanvasEntityStateGate /> so we can safely assume that the entity exists.
// Should never throw.
const entity = selectEntityOrThrow(canvas, entityIdentifier, 'CanvasEntityHeaderWarnings');
Original file line number Diff line number Diff line change
@@ -5,7 +5,7 @@ import type {
CanvasReferenceImageState,
CanvasRegionalGuidanceState,
} from 'features/controlLayers/store/types';
import type { ParameterModel } from 'features/parameters/types/parameterSchemas';
import type { MainModelConfig } from 'services/api/types';

const WARNINGS = {
UNSUPPORTED_MODEL: 'controlLayers.warnings.unsupportedModel',
@@ -20,13 +20,14 @@ const WARNINGS = {
CONTROL_ADAPTER_NO_MODEL_SELECTED: 'controlLayers.warnings.controlAdapterNoModelSelected',
CONTROL_ADAPTER_INCOMPATIBLE_BASE_MODEL: 'controlLayers.warnings.controlAdapterIncompatibleBaseModel',
CONTROL_ADAPTER_NO_CONTROL: 'controlLayers.warnings.controlAdapterNoControl',
FLUX_FILL_NO_WORKY_WITH_CONTROL_LORA: 'controlLayers.warnings.fluxFillIncompatibleWithControlLoRA',
} as const;

type WarningTKey = (typeof WARNINGS)[keyof typeof WARNINGS];

export const getRegionalGuidanceWarnings = (
entity: CanvasRegionalGuidanceState,
model: ParameterModel | null
model: MainModelConfig | null | undefined
): WarningTKey[] => {
const warnings: WarningTKey[] = [];

@@ -78,7 +79,7 @@ export const getRegionalGuidanceWarnings = (

export const getGlobalReferenceImageWarnings = (
entity: CanvasReferenceImageState,
model: ParameterModel | null
model: MainModelConfig | null | undefined
): WarningTKey[] => {
const warnings: WarningTKey[] = [];

@@ -110,7 +111,7 @@ export const getGlobalReferenceImageWarnings = (

export const getControlLayerWarnings = (
entity: CanvasControlLayerState,
model: ParameterModel | null
model: MainModelConfig | null | undefined
): WarningTKey[] => {
const warnings: WarningTKey[] = [];

@@ -129,6 +130,13 @@ export const getControlLayerWarnings = (
} else if (entity.controlAdapter.model.base !== model.base) {
// Supported model architecture but doesn't match
warnings.push(WARNINGS.CONTROL_ADAPTER_INCOMPATIBLE_BASE_MODEL);
} else if (
model.base === 'flux' &&
model.variant === 'inpaint' &&
entity.controlAdapter.model.type === 'control_lora'
) {
// FLUX inpaint variants are FLUX Fill models - not compatible w/ Control LoRA
warnings.push(WARNINGS.FLUX_FILL_NO_WORKY_WITH_CONTROL_LORA);
}
}

@@ -137,7 +145,7 @@ export const getRasterLayerWarnings = (

export const getRasterLayerWarnings = (
_entity: CanvasRasterLayerState,
_model: ParameterModel | null
_model: MainModelConfig | null | undefined
): WarningTKey[] => {
const warnings: WarningTKey[] = [];

@@ -148,7 +156,7 @@ export const getInpaintMaskWarnings = (

export const getInpaintMaskWarnings = (
_entity: CanvasInpaintMaskState,
_model: ParameterModel | null
_model: MainModelConfig | null | undefined
): WarningTKey[] => {
const warnings: WarningTKey[] = [];

Original file line number Diff line number Diff line change
@@ -42,8 +42,7 @@ export class InvalidModelConfigError extends Error {
export const fetchModelConfig = async (key: string): Promise<AnyModelConfig> => {
const { dispatch } = getStore();
try {
const req = dispatch(modelsApi.endpoints.getModelConfig.initiate(key));
req.unsubscribe();
const req = dispatch(modelsApi.endpoints.getModelConfig.initiate(key, { subscribe: false }));
return await req.unwrap();
} catch {
throw new ModelConfigNotFoundError(`Unable to retrieve model config for key ${key}`);
@@ -62,8 +61,9 @@
const fetchModelConfigByAttrs = async (name: string, base: BaseModelType, type: ModelType): Promise<AnyModelConfig> => {
const { dispatch } = getStore();
try {
const req = dispatch(modelsApi.endpoints.getModelConfigByAttrs.initiate({ name, base, type }));
req.unsubscribe();
const req = dispatch(
modelsApi.endpoints.getModelConfigByAttrs.initiate({ name, base, type }, { subscribe: false })
);
return await req.unwrap();
} catch {
throw new ModelConfigNotFoundError(`Unable to retrieve model config for name/base/type ${name}/${base}/${type}`);
Original file line number Diff line number Diff line change
@@ -4,9 +4,8 @@ import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager';
import type { CanvasControlLayerState, Rect } from 'features/controlLayers/store/types';
import { getControlLayerWarnings } from 'features/controlLayers/store/validators';
import type { Graph } from 'features/nodes/util/graph/generation/Graph';
import type { ParameterModel } from 'features/parameters/types/parameterSchemas';
import { serializeError } from 'serialize-error';
import type { ImageDTO, Invocation } from 'services/api/types';
import type { ImageDTO, Invocation, MainModelConfig } from 'services/api/types';
import { assert } from 'tsafe';

const log = logger('system');
@@ -17,7 +16,7 @@ type AddControlNetsArg = {
g: Graph;
rect: Rect;
collector: Invocation<'collect'>;
model: ParameterModel;
model: MainModelConfig;
};

type AddControlNetsResult = {
@@ -66,7 +65,7 @@ type AddT2IAdaptersArg = {
g: Graph;
rect: Rect;
collector: Invocation<'collect'>;
model: ParameterModel;
model: MainModelConfig;
};

type AddT2IAdaptersResult = {
@@ -114,7 +113,7 @@ type AddControlLoRAArg = {
entities: CanvasControlLayerState[];
g: Graph;
rect: Rect;
model: ParameterModel;
model: MainModelConfig;
denoise: Invocation<'flux_denoise'>;
};

@@ -129,9 +128,9 @@ export const addControlLoRA = async ({ manager, entities, g, rect, model, denois
// No valid control LoRA found
return;
}
if (validControlLayers.length > 1) {
throw new Error('Cannot add more than one FLUX control LoRA.');
}

assert(model.variant !== 'inpaint', 'FLUX Control LoRA is not compatible with FLUX Fill.');
assert(validControlLayers.length <= 1, 'Cannot add more than one FLUX control LoRA.');

const getImageDTOResult = await withResultAsync(() => {
const adapter = manager.adapters.controlLayers.get(validControlLayer.id);