Skip to content

Commit 1ded459

Browse files
refactor(ui): clean up related models impl for picker
1 parent d9024dc commit 1ded459

File tree

3 files changed

+84
-65
lines changed

3 files changed

+84
-65
lines changed

invokeai/frontend/web/src/common/components/Picker/Picker.tsx

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,10 @@ const isGroup = <T extends object>(optionOrGroup: OptionOrGroup<T>): optionOrGro
9191
return uniqueGroupKey in optionOrGroup && optionOrGroup[uniqueGroupKey] === true;
9292
};
9393

94+
export const isOption = <T extends object>(optionOrGroup: OptionOrGroup<T>): optionOrGroup is T => {
95+
return !(uniqueGroupKey in optionOrGroup);
96+
};
97+
9498
const DefaultOptionComponent = typedMemo(<T extends object>({ option }: { option: T }) => {
9599
const { getOptionId } = usePickerContext();
96100
return <Text fontWeight="bold">{getOptionId(option)}</Text>;

invokeai/frontend/web/src/features/lora/components/LoRASelect.tsx

Lines changed: 3 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -4,67 +4,29 @@ import { EMPTY_ARRAY } from 'app/store/constants';
44
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
55
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
66
import type { GroupStatusMap } from 'common/components/Picker/Picker';
7-
import { uniq } from 'es-toolkit/compat';
87
import { loraAdded, selectLoRAsSlice } from 'features/controlLayers/store/lorasSlice';
9-
import { selectParamsSlice } from 'features/controlLayers/store/paramsSlice';
8+
import { selectBase } from 'features/controlLayers/store/paramsSlice';
109
import { ModelPicker } from 'features/parameters/components/ModelPicker';
1110
import { API_BASE_MODELS } from 'features/parameters/types/constants';
1211
import { memo, useCallback, useMemo } from 'react';
1312
import { useTranslation } from 'react-i18next';
14-
import { useGetRelatedModelIdsBatchQuery } from 'services/api/endpoints/modelRelationships';
1513
import { useLoRAModels } from 'services/api/hooks/modelsByType';
1614
import type { LoRAModelConfig } from 'services/api/types';
1715

1816
const selectLoRAs = createSelector(selectLoRAsSlice, (loras) => loras.loras);
1917

20-
const selectSelectedModelKeys = createSelector(selectParamsSlice, selectLoRAsSlice, (params, loras) => {
21-
const keys: string[] = [];
22-
const main = params.model;
23-
const vae = params.vae;
24-
const refiner = params.refinerModel;
25-
const controlnet = params.controlLora;
26-
27-
if (main) {
28-
keys.push(main.key);
29-
}
30-
if (vae) {
31-
keys.push(vae.key);
32-
}
33-
if (refiner) {
34-
keys.push(refiner.key);
35-
}
36-
if (controlnet) {
37-
keys.push(controlnet.key);
38-
}
39-
for (const { model } of loras.loras) {
40-
keys.push(model.key);
41-
}
42-
43-
return uniq(keys);
44-
});
45-
4618
const LoRASelect = () => {
4719
const dispatch = useAppDispatch();
4820
const [modelConfigs, { isLoading }] = useLoRAModels();
4921
const { t } = useTranslation();
5022
const addedLoRAs = useAppSelector(selectLoRAs);
51-
const selectedKeys = useAppSelector(selectSelectedModelKeys);
52-
53-
const { relatedKeys } = useGetRelatedModelIdsBatchQuery(selectedKeys, {
54-
selectFromResult: ({ data }) => {
55-
if (!data) {
56-
return { relatedKeys: EMPTY_ARRAY };
57-
}
58-
return { relatedKeys: data };
59-
},
60-
});
6123

62-
const currentBaseModel = useAppSelector((state) => state.params.model?.base);
24+
const currentBaseModel = useAppSelector(selectBase);
6325

6426
// Filter to only show compatible LoRAs
6527
const compatibleLoRAs = useMemo(() => {
6628
if (!currentBaseModel) {
67-
return [];
29+
return EMPTY_ARRAY;
6830
}
6931
return modelConfigs.filter((model) => model.base === currentBaseModel);
7032
}, [modelConfigs, currentBaseModel]);
@@ -121,7 +83,6 @@ const LoRASelect = () => {
12183
modelConfigs={compatibleLoRAs}
12284
onChange={onChange}
12385
grouped={false}
124-
relatedModelKeys={relatedKeys}
12586
selectedModelConfig={undefined}
12687
allowEmpty
12788
placeholder={placeholder}

invokeai/frontend/web/src/features/parameters/components/ModelPicker.tsx

Lines changed: 77 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ import type { BoxProps, ButtonProps, SystemStyleObject } from '@invoke-ai/ui-lib
22
import {
33
Button,
44
Flex,
5+
Icon,
56
Popover,
67
PopoverArrow,
78
PopoverBody,
@@ -12,12 +13,17 @@ import {
1213
Text,
1314
} from '@invoke-ai/ui-library';
1415
import { useStore } from '@nanostores/react';
16+
import { EMPTY_ARRAY } from 'app/store/constants';
17+
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
1518
import { $onClickGoToModelManager } from 'app/store/nanostores/onClickGoToModelManager';
1619
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
1720
import type { Group, PickerContextState } from 'common/components/Picker/Picker';
18-
import { buildGroup, getRegex, Picker, usePickerContext } from 'common/components/Picker/Picker';
21+
import { buildGroup, getRegex, isOption, Picker, usePickerContext } from 'common/components/Picker/Picker';
1922
import { useDisclosure } from 'common/hooks/useBoolean';
2023
import { typedMemo } from 'common/util/typedMemo';
24+
import { uniq } from 'es-toolkit/compat';
25+
import { selectLoRAsSlice } from 'features/controlLayers/store/lorasSlice';
26+
import { selectParamsSlice } from 'features/controlLayers/store/paramsSlice';
2127
import { setInstallModelsTabByName } from 'features/modelManagerV2/store/installModelsStore';
2228
import { BASE_COLOR_MAP } from 'features/modelManagerV2/subpanels/ModelManagerPanel/ModelBaseBadge';
2329
import ModelImage from 'features/modelManagerV2/subpanels/ModelManagerPanel/ModelImage';
@@ -29,10 +35,39 @@ import { filesize } from 'filesize';
2935
import { memo, useCallback, useMemo, useRef } from 'react';
3036
import { Trans, useTranslation } from 'react-i18next';
3137
import { PiCaretDownBold, PiLinkSimple } from 'react-icons/pi';
38+
import { useGetRelatedModelIdsBatchQuery } from 'services/api/endpoints/modelRelationships';
3239
import type { AnyModelConfig, BaseModelType } from 'services/api/types';
3340

41+
const selectSelectedModelKeys = createMemoizedSelector(selectParamsSlice, selectLoRAsSlice, (params, loras) => {
42+
const keys: string[] = [];
43+
const main = params.model;
44+
const vae = params.vae;
45+
const refiner = params.refinerModel;
46+
const controlnet = params.controlLora;
47+
48+
if (main) {
49+
keys.push(main.key);
50+
}
51+
if (vae) {
52+
keys.push(vae.key);
53+
}
54+
if (refiner) {
55+
keys.push(refiner.key);
56+
}
57+
if (controlnet) {
58+
keys.push(controlnet.key);
59+
}
60+
for (const { model } of loras.loras) {
61+
keys.push(model.key);
62+
}
63+
64+
return uniq(keys);
65+
});
66+
67+
type WithStarred<T> = T & { starred?: boolean };
68+
3469
// Type for models with starred field
35-
const getOptionId = <T extends AnyModelConfig>(modelConfig: T & { starred?: boolean }) => modelConfig.key;
70+
const getOptionId = <T extends AnyModelConfig>(modelConfig: WithStarred<T>) => modelConfig.key;
3671

3772
const ModelManagerLink = memo((props: ButtonProps) => {
3873
const onClickGoToModelManager = useStore($onClickGoToModelManager);
@@ -105,6 +140,15 @@ const getGroupColorSchemeFromModelConfig = (modelConfig: AnyModelConfig): string
105140
return BASE_COLOR_MAP[modelConfig.base];
106141
};
107142

143+
const relatedModelKeysQueryOptions = {
144+
selectFromResult: ({ data }) => {
145+
if (!data) {
146+
return { relatedModelKeys: EMPTY_ARRAY };
147+
}
148+
return { relatedModelKeys: data };
149+
},
150+
} satisfies Parameters<typeof useGetRelatedModelIdsBatchQuery>[1];
151+
108152
const popperModifiers = [
109153
{
110154
// Prevents the popover from "touching" the edges of the screen
@@ -113,13 +157,17 @@ const popperModifiers = [
113157
},
114158
];
115159

160+
const removeStarred = <T,>(obj: WithStarred<T>): T => {
161+
const { starred: _, ...rest } = obj;
162+
return rest as T;
163+
};
164+
116165
export const ModelPicker = typedMemo(
117166
<T extends AnyModelConfig = AnyModelConfig>({
118167
modelConfigs,
119168
selectedModelConfig,
120169
onChange,
121170
grouped,
122-
relatedModelKeys = [],
123171
getIsOptionDisabled,
124172
placeholder,
125173
allowEmpty,
@@ -133,7 +181,6 @@ export const ModelPicker = typedMemo(
133181
selectedModelConfig: T | undefined;
134182
onChange: (modelConfig: T) => void;
135183
grouped?: boolean;
136-
relatedModelKeys?: string[];
137184
getIsOptionDisabled?: (model: T) => boolean;
138185
placeholder?: string;
139186
allowEmpty?: boolean;
@@ -144,7 +191,11 @@ export const ModelPicker = typedMemo(
144191
initialGroupStates?: Record<string, boolean>;
145192
}) => {
146193
const { t } = useTranslation();
147-
const options = useMemo<(T & { starred?: boolean })[] | Group<T & { starred?: boolean }>[]>(() => {
194+
const selectedKeys = useAppSelector(selectSelectedModelKeys);
195+
196+
const { relatedModelKeys } = useGetRelatedModelIdsBatchQuery(selectedKeys, relatedModelKeysQueryOptions);
197+
198+
const options = useMemo<WithStarred<T>[] | Group<WithStarred<T>>[]>(() => {
148199
if (!grouped) {
149200
// Add starred field to model options and sort them
150201
const modelsWithStarred = modelConfigs.map((model) => ({
@@ -165,13 +216,13 @@ export const ModelPicker = typedMemo(
165216
}
166217

167218
// When all groups are disabled, we show all models
168-
const groups: Record<string, Group<T & { starred?: boolean }>> = {};
219+
const groups: Record<string, Group<WithStarred<T>>> = {};
169220

170221
for (const modelConfig of modelConfigs) {
171222
const groupId = getGroupIDFromModelConfig(modelConfig);
172223
let group = groups[groupId];
173224
if (!group) {
174-
group = buildGroup<T & { starred?: boolean }>({
225+
group = buildGroup<WithStarred<T>>({
175226
id: modelConfig.base,
176227
color: `${getGroupColorSchemeFromModelConfig(modelConfig)}.300`,
177228
shortName: getGroupShortNameFromModelConfig(modelConfig),
@@ -191,7 +242,7 @@ export const ModelPicker = typedMemo(
191242
}
192243
}
193244

194-
const _options: Group<T & { starred?: boolean }>[] = [];
245+
const _options: Group<WithStarred<T>>[] = [];
195246

196247
// Add groups in the original order
197248
for (const groupId of ['api', 'flux', 'cogview4', 'sdxl', 'sd-3', 'sd-2', 'sd-1']) {
@@ -216,19 +267,26 @@ export const ModelPicker = typedMemo(
216267
return _options;
217268
}, [grouped, modelConfigs, relatedModelKeys, t]);
218269
const popover = useDisclosure(false);
219-
const pickerRef = useRef<PickerContextState<T & { starred?: boolean }>>(null);
270+
const pickerRef = useRef<PickerContextState<WithStarred<T>>>(null);
271+
272+
const selectedOption = useMemo<WithStarred<T> | undefined>(() => {
273+
if (!selectedModelConfig) {
274+
return undefined;
275+
}
276+
277+
return options.filter(isOption).find((o) => o.key === selectedModelConfig.key);
278+
}, [options, selectedModelConfig]);
220279

221280
const onClose = useCallback(() => {
222281
popover.close();
223282
pickerRef.current?.$searchTerm.set('');
224283
}, [popover]);
225284

226285
const onSelect = useCallback(
227-
(model: T & { starred?: boolean }) => {
286+
(model: WithStarred<T>) => {
228287
onClose();
229288
// Remove the starred field before passing to onChange
230-
const { starred: _, ...modelWithoutStarred } = model;
231-
onChange(modelWithoutStarred as T);
289+
onChange(removeStarred(model));
232290
},
233291
[onChange, onClose]
234292
);
@@ -268,17 +326,13 @@ export const ModelPicker = typedMemo(
268326
<Portal appendToParentPortal={false}>
269327
<PopoverContent p={0} w={400} h={400}>
270328
<PopoverArrow />
271-
<PopoverBody p={0} w="full" h="full">
272-
<Picker<T & { starred?: boolean }>
329+
<PopoverBody p={0} w="full" h="full" borderWidth={1} borderColor="base.700" borderRadius="base">
330+
<Picker<WithStarred<T>>
273331
handleRef={pickerRef}
274332
optionsOrGroups={options}
275333
getOptionId={getOptionId<T>}
276334
onSelect={onSelect}
277-
selectedOption={
278-
selectedModelConfig
279-
? { ...selectedModelConfig, starred: relatedModelKeys.includes(selectedModelConfig.key) }
280-
: undefined
281-
}
335+
selectedOption={selectedOption}
282336
isMatch={isMatch<T>}
283337
OptionComponent={PickerOptionComponent<T>}
284338
noOptionsFallback={<NoOptionsFallback noOptionsText={noOptionsText} />}
@@ -332,16 +386,16 @@ const optionNameSx: SystemStyleObject = {
332386
};
333387

334388
const PickerOptionComponent = typedMemo(
335-
<T extends AnyModelConfig>({ option, ...rest }: { option: T & { starred?: boolean } } & BoxProps) => {
336-
const { $compactView } = usePickerContext<T & { starred?: boolean }>();
389+
<T extends AnyModelConfig>({ option, ...rest }: { option: WithStarred<T> } & BoxProps) => {
390+
const { $compactView } = usePickerContext<WithStarred<T>>();
337391
const compactView = useStore($compactView);
338392

339393
return (
340394
<Flex {...rest} sx={optionSx} data-is-compact={compactView}>
341395
{!compactView && option.cover_image && <ModelImage image_url={option.cover_image} />}
342396
<Flex flexDir="column" gap={1} flex={1}>
343397
<Flex gap={2} alignItems="center">
344-
{option.starred && <PiLinkSimple color="yellow" size={16} />}
398+
{option.starred && <Icon as={PiLinkSimple} color="invokeYellow.500" boxSize={4} />}
345399
<Text sx={optionNameSx} data-is-compact={compactView}>
346400
{option.name}
347401
</Text>
@@ -371,7 +425,7 @@ const BASE_KEYWORDS: { [key in BaseModelType]?: string[] } = {
371425
'sd-3': ['sd3', 'sd3.0', 'sd3.5', 'sd-3'],
372426
};
373427

374-
const isMatch = <T extends AnyModelConfig>(model: T & { starred?: boolean }, searchTerm: string) => {
428+
const isMatch = <T extends AnyModelConfig>(model: WithStarred<T>, searchTerm: string) => {
375429
const regex = getRegex(searchTerm);
376430
const bases = BASE_KEYWORDS[model.base] ?? [model.base];
377431
const testString =

0 commit comments

Comments
 (0)