Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,10 @@ const isGroup = <T extends object>(optionOrGroup: OptionOrGroup<T>): optionOrGro
return uniqueGroupKey in optionOrGroup && optionOrGroup[uniqueGroupKey] === true;
};

export const isOption = <T extends object>(optionOrGroup: OptionOrGroup<T>): optionOrGroup is T => {
return !(uniqueGroupKey in optionOrGroup);
};

const DefaultOptionComponent = typedMemo(<T extends object>({ option }: { option: T }) => {
const { getOptionId } = usePickerContext();
return <Text fontWeight="bold">{getOptionId(option)}</Text>;
Expand Down
35 changes: 18 additions & 17 deletions invokeai/frontend/web/src/features/lora/components/LoRASelect.tsx
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import { FormControl, FormLabel } from '@invoke-ai/ui-library';
import { createSelector } from '@reduxjs/toolkit';
import { EMPTY_ARRAY } from 'app/store/constants';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
import type { GroupStatusMap } from 'common/components/Picker/Picker';
import { useRelatedGroupedModelCombobox } from 'common/hooks/useRelatedGroupedModelCombobox';
import { loraAdded, selectLoRAsSlice } from 'features/controlLayers/store/lorasSlice';
import { selectBase } from 'features/controlLayers/store/paramsSlice';
import { ModelPicker } from 'features/parameters/components/ModelPicker';
Expand All @@ -20,16 +20,23 @@ const LoRASelect = () => {
const [modelConfigs, { isLoading }] = useLoRAModels();
const { t } = useTranslation();
const addedLoRAs = useAppSelector(selectLoRAs);

const currentBaseModel = useAppSelector(selectBase);

// Filter to only show compatible LoRAs
const compatibleLoRAs = useMemo(() => {
if (!currentBaseModel) {
return EMPTY_ARRAY;
}
return modelConfigs.filter((model) => model.base === currentBaseModel);
}, [modelConfigs, currentBaseModel]);

const getIsDisabled = useCallback(
(model: LoRAModelConfig): boolean => {
const isCompatible = currentBaseModel === model.base;
const isAdded = Boolean(addedLoRAs.find((lora) => lora.model.key === model.key));
const hasMainModel = Boolean(currentBaseModel);
return !hasMainModel || !isCompatible || isAdded;
return isAdded;
},
[addedLoRAs, currentBaseModel]
[addedLoRAs]
);

const onChange = useCallback(
Expand All @@ -42,23 +49,17 @@ const LoRASelect = () => {
[dispatch]
);

const { options } = useRelatedGroupedModelCombobox({
modelConfigs,
getIsDisabled,
onChange,
});

const placeholder = useMemo(() => {
if (isLoading) {
return t('common.loading');
}

if (options.length === 0) {
return t('models.noLoRAsInstalled');
if (compatibleLoRAs.length === 0) {
return currentBaseModel ? t('models.noCompatibleLoRAs') : t('models.selectModelFirst');
}

return t('models.addLora');
}, [isLoading, options.length, t]);
}, [isLoading, compatibleLoRAs.length, currentBaseModel, t]);

// Calculate initial group states to default to the current base model architecture
const initialGroupStates = useMemo(() => {
Expand All @@ -79,15 +80,15 @@ const LoRASelect = () => {
<FormLabel>{t('models.concepts')} </FormLabel>
</InformationalPopover>
<ModelPicker
modelConfigs={modelConfigs}
modelConfigs={compatibleLoRAs}
onChange={onChange}
grouped
grouped={false}
selectedModelConfig={undefined}
allowEmpty
placeholder={placeholder}
getIsOptionDisabled={getIsDisabled}
noOptionsText={t('models.noLoRAsInstalled')}
initialGroupStates={initialGroupStates}
noOptionsText={currentBaseModel ? t('models.noCompatibleLoRAs') : t('models.selectModelFirst')}
/>
</FormControl>
);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ import type { BoxProps, ButtonProps, SystemStyleObject } from '@invoke-ai/ui-lib
import {
Button,
Flex,
Icon,
Popover,
PopoverArrow,
PopoverBody,
Expand All @@ -12,12 +13,17 @@ import {
Text,
} from '@invoke-ai/ui-library';
import { useStore } from '@nanostores/react';
import { EMPTY_ARRAY } from 'app/store/constants';
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { $onClickGoToModelManager } from 'app/store/nanostores/onClickGoToModelManager';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import type { Group, PickerContextState } from 'common/components/Picker/Picker';
import { buildGroup, getRegex, Picker, usePickerContext } from 'common/components/Picker/Picker';
import { buildGroup, getRegex, isOption, Picker, usePickerContext } from 'common/components/Picker/Picker';
import { useDisclosure } from 'common/hooks/useBoolean';
import { typedMemo } from 'common/util/typedMemo';
import { uniq } from 'es-toolkit/compat';
import { selectLoRAsSlice } from 'features/controlLayers/store/lorasSlice';
import { selectParamsSlice } from 'features/controlLayers/store/paramsSlice';
import { setInstallModelsTabByName } from 'features/modelManagerV2/store/installModelsStore';
import { BASE_COLOR_MAP } from 'features/modelManagerV2/subpanels/ModelManagerPanel/ModelBaseBadge';
import ModelImage from 'features/modelManagerV2/subpanels/ModelManagerPanel/ModelImage';
Expand All @@ -28,10 +34,40 @@ import { setActiveTab } from 'features/ui/store/uiSlice';
import { filesize } from 'filesize';
import { memo, useCallback, useMemo, useRef } from 'react';
import { Trans, useTranslation } from 'react-i18next';
import { PiCaretDownBold } from 'react-icons/pi';
import { PiCaretDownBold, PiLinkSimple } from 'react-icons/pi';
import { useGetRelatedModelIdsBatchQuery } from 'services/api/endpoints/modelRelationships';
import type { AnyModelConfig, BaseModelType } from 'services/api/types';

const getOptionId = (modelConfig: AnyModelConfig) => modelConfig.key;
const selectSelectedModelKeys = createMemoizedSelector(selectParamsSlice, selectLoRAsSlice, (params, loras) => {
const keys: string[] = [];
const main = params.model;
const vae = params.vae;
const refiner = params.refinerModel;
const controlnet = params.controlLora;

if (main) {
keys.push(main.key);
}
if (vae) {
keys.push(vae.key);
}
if (refiner) {
keys.push(refiner.key);
}
if (controlnet) {
keys.push(controlnet.key);
}
for (const { model } of loras.loras) {
keys.push(model.key);
}

return uniq(keys);
});

type WithStarred<T> = T & { starred?: boolean };

// Type for models with starred field
const getOptionId = <T extends AnyModelConfig>(modelConfig: WithStarred<T>) => modelConfig.key;

const ModelManagerLink = memo((props: ButtonProps) => {
const onClickGoToModelManager = useStore($onClickGoToModelManager);
Expand Down Expand Up @@ -104,6 +140,15 @@ const getGroupColorSchemeFromModelConfig = (modelConfig: AnyModelConfig): string
return BASE_COLOR_MAP[modelConfig.base];
};

const relatedModelKeysQueryOptions = {
selectFromResult: ({ data }) => {
if (!data) {
return { relatedModelKeys: EMPTY_ARRAY };
}
return { relatedModelKeys: data };
},
} satisfies Parameters<typeof useGetRelatedModelIdsBatchQuery>[1];

const popperModifiers = [
{
// Prevents the popover from "touching" the edges of the screen
Expand All @@ -112,6 +157,11 @@ const popperModifiers = [
},
];

const removeStarred = <T,>(obj: WithStarred<T>): T => {
const { starred: _, ...rest } = obj;
return rest as T;
};

export const ModelPicker = typedMemo(
<T extends AnyModelConfig = AnyModelConfig>({
modelConfigs,
Expand Down Expand Up @@ -141,19 +191,38 @@ export const ModelPicker = typedMemo(
initialGroupStates?: Record<string, boolean>;
}) => {
const { t } = useTranslation();
const options = useMemo<T[] | Group<T>[]>(() => {
const selectedKeys = useAppSelector(selectSelectedModelKeys);

const { relatedModelKeys } = useGetRelatedModelIdsBatchQuery(selectedKeys, relatedModelKeysQueryOptions);

const options = useMemo<WithStarred<T>[] | Group<WithStarred<T>>[]>(() => {
if (!grouped) {
return modelConfigs;
// Add starred field to model options and sort them
const modelsWithStarred = modelConfigs.map((model) => ({
...model,
starred: relatedModelKeys.includes(model.key),
}));

// Sort so starred models come first
return modelsWithStarred.sort((a, b) => {
if (a.starred && !b.starred) {
return -1;
}
if (!a.starred && b.starred) {
return 1;
}
return 0;
});
}

// When all groups are disabled, we show all models
const groups: Record<string, Group<T>> = {};
const groups: Record<string, Group<WithStarred<T>>> = {};

for (const modelConfig of modelConfigs) {
const groupId = getGroupIDFromModelConfig(modelConfig);
let group = groups[groupId];
if (!group) {
group = buildGroup<T>({
group = buildGroup<WithStarred<T>>({
id: modelConfig.base,
color: `${getGroupColorSchemeFromModelConfig(modelConfig)}.300`,
shortName: getGroupShortNameFromModelConfig(modelConfig),
Expand All @@ -164,35 +233,60 @@ export const ModelPicker = typedMemo(
groups[groupId] = group;
}
if (group) {
group.options.push(modelConfig);
// Add starred field to the model
const modelWithStarred = {
...modelConfig,
starred: relatedModelKeys.includes(modelConfig.key),
};
group.options.push(modelWithStarred);
}
}

const _options: Group<T>[] = [];
const _options: Group<WithStarred<T>>[] = [];

// Add groups in the original order
for (const groupId of ['api', 'flux', 'cogview4', 'sdxl', 'sd-3', 'sd-2', 'sd-1']) {
const group = groups[groupId];
if (group) {
// Sort options within each group so starred ones come first
group.options.sort((a, b) => {
if (a.starred && !b.starred) {
return -1;
}
if (!a.starred && b.starred) {
return 1;
}
return 0;
});
_options.push(group);
delete groups[groupId];
}
}
_options.push(...Object.values(groups));

return _options;
}, [grouped, modelConfigs, t]);
}, [grouped, modelConfigs, relatedModelKeys, t]);
const popover = useDisclosure(false);
const pickerRef = useRef<PickerContextState<T>>(null);
const pickerRef = useRef<PickerContextState<WithStarred<T>>>(null);

const selectedOption = useMemo<WithStarred<T> | undefined>(() => {
if (!selectedModelConfig) {
return undefined;
}

return options.filter(isOption).find((o) => o.key === selectedModelConfig.key);
}, [options, selectedModelConfig]);

const onClose = useCallback(() => {
popover.close();
pickerRef.current?.$searchTerm.set('');
}, [popover]);

const onSelect = useCallback(
(model: T) => {
(model: WithStarred<T>) => {
onClose();
onChange(model);
// Remove the starred field before passing to onChange
onChange(removeStarred(model));
},
[onChange, onClose]
);
Expand Down Expand Up @@ -232,15 +326,15 @@ export const ModelPicker = typedMemo(
<Portal appendToParentPortal={false}>
<PopoverContent p={0} w={400} h={400}>
<PopoverArrow />
<PopoverBody p={0} w="full" h="full">
<Picker<T>
<PopoverBody p={0} w="full" h="full" borderWidth={1} borderColor="base.700" borderRadius="base">
<Picker<WithStarred<T>>
handleRef={pickerRef}
optionsOrGroups={options}
getOptionId={getOptionId}
getOptionId={getOptionId<T>}
onSelect={onSelect}
selectedOption={selectedModelConfig}
isMatch={isMatch}
OptionComponent={PickerOptionComponent}
selectedOption={selectedOption}
isMatch={isMatch<T>}
OptionComponent={PickerOptionComponent<T>}
noOptionsFallback={<NoOptionsFallback noOptionsText={noOptionsText} />}
noMatchesFallback={t('modelManager.noMatchingModels')}
NextToSearchBar={<NavigateToModelManagerButton />}
Expand Down Expand Up @@ -291,35 +385,38 @@ const optionNameSx: SystemStyleObject = {
},
};

const PickerOptionComponent = typedMemo(({ option, ...rest }: { option: AnyModelConfig } & BoxProps) => {
const { $compactView } = usePickerContext<AnyModelConfig>();
const compactView = useStore($compactView);
const PickerOptionComponent = typedMemo(
<T extends AnyModelConfig>({ option, ...rest }: { option: WithStarred<T> } & BoxProps) => {
const { $compactView } = usePickerContext<WithStarred<T>>();
const compactView = useStore($compactView);

return (
<Flex {...rest} sx={optionSx} data-is-compact={compactView}>
{!compactView && option.cover_image && <ModelImage image_url={option.cover_image} />}
<Flex flexDir="column" gap={1} flex={1}>
<Flex gap={2} alignItems="center">
<Text sx={optionNameSx} data-is-compact={compactView}>
{option.name}
</Text>
<Spacer />
{option.file_size > 0 && (
<Text variant="subtext" fontStyle="italic" noOfLines={1} flexShrink={0} overflow="visible">
{filesize(option.file_size)}
</Text>
)}
{option.usage_info && (
<Text variant="subtext" fontStyle="italic" noOfLines={1} flexShrink={0} overflow="visible">
{option.usage_info}
return (
<Flex {...rest} sx={optionSx} data-is-compact={compactView}>
{!compactView && option.cover_image && <ModelImage image_url={option.cover_image} />}
<Flex flexDir="column" gap={1} flex={1}>
<Flex gap={2} alignItems="center">
{option.starred && <Icon as={PiLinkSimple} color="invokeYellow.500" boxSize={4} />}
<Text sx={optionNameSx} data-is-compact={compactView}>
{option.name}
</Text>
)}
<Spacer />
{option.file_size > 0 && (
<Text variant="subtext" fontStyle="italic" noOfLines={1} flexShrink={0} overflow="visible">
{filesize(option.file_size)}
</Text>
)}
{option.usage_info && (
<Text variant="subtext" fontStyle="italic" noOfLines={1} flexShrink={0} overflow="visible">
{option.usage_info}
</Text>
)}
</Flex>
{option.description && !compactView && <Text color="base.200">{option.description}</Text>}
</Flex>
{option.description && !compactView && <Text color="base.200">{option.description}</Text>}
</Flex>
</Flex>
);
});
);
}
);
PickerOptionComponent.displayName = 'PickerItemComponent';

const BASE_KEYWORDS: { [key in BaseModelType]?: string[] } = {
Expand All @@ -328,7 +425,7 @@ const BASE_KEYWORDS: { [key in BaseModelType]?: string[] } = {
'sd-3': ['sd3', 'sd3.0', 'sd3.5', 'sd-3'],
};

const isMatch = (model: AnyModelConfig, searchTerm: string) => {
const isMatch = <T extends AnyModelConfig>(model: WithStarred<T>, searchTerm: string) => {
const regex = getRegex(searchTerm);
const bases = BASE_KEYWORDS[model.base] ?? [model.base];
const testString =
Expand Down