@@ -2,6 +2,7 @@ import type { BoxProps, ButtonProps, SystemStyleObject } from '@invoke-ai/ui-lib
2
2
import {
3
3
Button ,
4
4
Flex ,
5
+ Icon ,
5
6
Popover ,
6
7
PopoverArrow ,
7
8
PopoverBody ,
@@ -12,12 +13,17 @@ import {
12
13
Text ,
13
14
} from '@invoke-ai/ui-library' ;
14
15
import { useStore } from '@nanostores/react' ;
16
+ import { EMPTY_ARRAY } from 'app/store/constants' ;
17
+ import { createMemoizedSelector } from 'app/store/createMemoizedSelector' ;
15
18
import { $onClickGoToModelManager } from 'app/store/nanostores/onClickGoToModelManager' ;
16
19
import { useAppDispatch , useAppSelector } from 'app/store/storeHooks' ;
17
20
import type { Group , PickerContextState } from 'common/components/Picker/Picker' ;
18
- import { buildGroup , getRegex , Picker , usePickerContext } from 'common/components/Picker/Picker' ;
21
+ import { buildGroup , getRegex , isOption , Picker , usePickerContext } from 'common/components/Picker/Picker' ;
19
22
import { useDisclosure } from 'common/hooks/useBoolean' ;
20
23
import { typedMemo } from 'common/util/typedMemo' ;
24
+ import { uniq } from 'es-toolkit/compat' ;
25
+ import { selectLoRAsSlice } from 'features/controlLayers/store/lorasSlice' ;
26
+ import { selectParamsSlice } from 'features/controlLayers/store/paramsSlice' ;
21
27
import { setInstallModelsTabByName } from 'features/modelManagerV2/store/installModelsStore' ;
22
28
import { BASE_COLOR_MAP } from 'features/modelManagerV2/subpanels/ModelManagerPanel/ModelBaseBadge' ;
23
29
import ModelImage from 'features/modelManagerV2/subpanels/ModelManagerPanel/ModelImage' ;
@@ -29,10 +35,39 @@ import { filesize } from 'filesize';
29
35
import { memo , useCallback , useMemo , useRef } from 'react' ;
30
36
import { Trans , useTranslation } from 'react-i18next' ;
31
37
import { PiCaretDownBold , PiLinkSimple } from 'react-icons/pi' ;
38
+ import { useGetRelatedModelIdsBatchQuery } from 'services/api/endpoints/modelRelationships' ;
32
39
import type { AnyModelConfig , BaseModelType } from 'services/api/types' ;
33
40
41
+ const selectSelectedModelKeys = createMemoizedSelector ( selectParamsSlice , selectLoRAsSlice , ( params , loras ) => {
42
+ const keys : string [ ] = [ ] ;
43
+ const main = params . model ;
44
+ const vae = params . vae ;
45
+ const refiner = params . refinerModel ;
46
+ const controlnet = params . controlLora ;
47
+
48
+ if ( main ) {
49
+ keys . push ( main . key ) ;
50
+ }
51
+ if ( vae ) {
52
+ keys . push ( vae . key ) ;
53
+ }
54
+ if ( refiner ) {
55
+ keys . push ( refiner . key ) ;
56
+ }
57
+ if ( controlnet ) {
58
+ keys . push ( controlnet . key ) ;
59
+ }
60
+ for ( const { model } of loras . loras ) {
61
+ keys . push ( model . key ) ;
62
+ }
63
+
64
+ return uniq ( keys ) ;
65
+ } ) ;
66
+
67
+ type WithStarred < T > = T & { starred ?: boolean } ;
68
+
34
69
// Type for models with starred field
35
- const getOptionId = < T extends AnyModelConfig > ( modelConfig : T & { starred ?: boolean } ) => modelConfig . key ;
70
+ const getOptionId = < T extends AnyModelConfig > ( modelConfig : WithStarred < T > ) => modelConfig . key ;
36
71
37
72
const ModelManagerLink = memo ( ( props : ButtonProps ) => {
38
73
const onClickGoToModelManager = useStore ( $onClickGoToModelManager ) ;
@@ -105,6 +140,15 @@ const getGroupColorSchemeFromModelConfig = (modelConfig: AnyModelConfig): string
105
140
return BASE_COLOR_MAP [ modelConfig . base ] ;
106
141
} ;
107
142
143
+ const relatedModelKeysQueryOptions = {
144
+ selectFromResult : ( { data } ) => {
145
+ if ( ! data ) {
146
+ return { relatedModelKeys : EMPTY_ARRAY } ;
147
+ }
148
+ return { relatedModelKeys : data } ;
149
+ } ,
150
+ } satisfies Parameters < typeof useGetRelatedModelIdsBatchQuery > [ 1 ] ;
151
+
108
152
const popperModifiers = [
109
153
{
110
154
// Prevents the popover from "touching" the edges of the screen
@@ -113,13 +157,17 @@ const popperModifiers = [
113
157
} ,
114
158
] ;
115
159
160
+ const removeStarred = < T , > ( obj : WithStarred < T > ) : T => {
161
+ const { starred : _ , ...rest } = obj ;
162
+ return rest as T ;
163
+ } ;
164
+
116
165
export const ModelPicker = typedMemo (
117
166
< T extends AnyModelConfig = AnyModelConfig > ( {
118
167
modelConfigs,
119
168
selectedModelConfig,
120
169
onChange,
121
170
grouped,
122
- relatedModelKeys = [ ] ,
123
171
getIsOptionDisabled,
124
172
placeholder,
125
173
allowEmpty,
@@ -133,7 +181,6 @@ export const ModelPicker = typedMemo(
133
181
selectedModelConfig : T | undefined ;
134
182
onChange : ( modelConfig : T ) => void ;
135
183
grouped ?: boolean ;
136
- relatedModelKeys ?: string [ ] ;
137
184
getIsOptionDisabled ?: ( model : T ) => boolean ;
138
185
placeholder ?: string ;
139
186
allowEmpty ?: boolean ;
@@ -144,7 +191,11 @@ export const ModelPicker = typedMemo(
144
191
initialGroupStates ?: Record < string , boolean > ;
145
192
} ) => {
146
193
const { t } = useTranslation ( ) ;
147
- const options = useMemo < ( T & { starred ?: boolean } ) [ ] | Group < T & { starred ?: boolean } > [ ] > ( ( ) => {
194
+ const selectedKeys = useAppSelector ( selectSelectedModelKeys ) ;
195
+
196
+ const { relatedModelKeys } = useGetRelatedModelIdsBatchQuery ( selectedKeys , relatedModelKeysQueryOptions ) ;
197
+
198
+ const options = useMemo < WithStarred < T > [ ] | Group < WithStarred < T > > [ ] > ( ( ) => {
148
199
if ( ! grouped ) {
149
200
// Add starred field to model options and sort them
150
201
const modelsWithStarred = modelConfigs . map ( ( model ) => ( {
@@ -165,13 +216,13 @@ export const ModelPicker = typedMemo(
165
216
}
166
217
167
218
// When all groups are disabled, we show all models
168
- const groups : Record < string , Group < T & { starred ?: boolean } > > = { } ;
219
+ const groups : Record < string , Group < WithStarred < T > > > = { } ;
169
220
170
221
for ( const modelConfig of modelConfigs ) {
171
222
const groupId = getGroupIDFromModelConfig ( modelConfig ) ;
172
223
let group = groups [ groupId ] ;
173
224
if ( ! group ) {
174
- group = buildGroup < T & { starred ?: boolean } > ( {
225
+ group = buildGroup < WithStarred < T > > ( {
175
226
id : modelConfig . base ,
176
227
color : `${ getGroupColorSchemeFromModelConfig ( modelConfig ) } .300` ,
177
228
shortName : getGroupShortNameFromModelConfig ( modelConfig ) ,
@@ -191,7 +242,7 @@ export const ModelPicker = typedMemo(
191
242
}
192
243
}
193
244
194
- const _options : Group < T & { starred ?: boolean } > [ ] = [ ] ;
245
+ const _options : Group < WithStarred < T > > [ ] = [ ] ;
195
246
196
247
// Add groups in the original order
197
248
for ( const groupId of [ 'api' , 'flux' , 'cogview4' , 'sdxl' , 'sd-3' , 'sd-2' , 'sd-1' ] ) {
@@ -216,19 +267,26 @@ export const ModelPicker = typedMemo(
216
267
return _options ;
217
268
} , [ grouped , modelConfigs , relatedModelKeys , t ] ) ;
218
269
const popover = useDisclosure ( false ) ;
219
- const pickerRef = useRef < PickerContextState < T & { starred ?: boolean } > > ( null ) ;
270
+ const pickerRef = useRef < PickerContextState < WithStarred < T > > > ( null ) ;
271
+
272
+ const selectedOption = useMemo < WithStarred < T > | undefined > ( ( ) => {
273
+ if ( ! selectedModelConfig ) {
274
+ return undefined ;
275
+ }
276
+
277
+ return options . filter ( isOption ) . find ( ( o ) => o . key === selectedModelConfig . key ) ;
278
+ } , [ options , selectedModelConfig ] ) ;
220
279
221
280
const onClose = useCallback ( ( ) => {
222
281
popover . close ( ) ;
223
282
pickerRef . current ?. $searchTerm . set ( '' ) ;
224
283
} , [ popover ] ) ;
225
284
226
285
const onSelect = useCallback (
227
- ( model : T & { starred ?: boolean } ) => {
286
+ ( model : WithStarred < T > ) => {
228
287
onClose ( ) ;
229
288
// Remove the starred field before passing to onChange
230
- const { starred : _ , ...modelWithoutStarred } = model ;
231
- onChange ( modelWithoutStarred as T ) ;
289
+ onChange ( removeStarred ( model ) ) ;
232
290
} ,
233
291
[ onChange , onClose ]
234
292
) ;
@@ -268,17 +326,13 @@ export const ModelPicker = typedMemo(
268
326
< Portal appendToParentPortal = { false } >
269
327
< PopoverContent p = { 0 } w = { 400 } h = { 400 } >
270
328
< PopoverArrow />
271
- < PopoverBody p = { 0 } w = "full" h = "full" >
272
- < Picker < T & { starred ?: boolean } >
329
+ < PopoverBody p = { 0 } w = "full" h = "full" borderWidth = { 1 } borderColor = "base.700" borderRadius = "base" >
330
+ < Picker < WithStarred < T > >
273
331
handleRef = { pickerRef }
274
332
optionsOrGroups = { options }
275
333
getOptionId = { getOptionId < T > }
276
334
onSelect = { onSelect }
277
- selectedOption = {
278
- selectedModelConfig
279
- ? { ...selectedModelConfig , starred : relatedModelKeys . includes ( selectedModelConfig . key ) }
280
- : undefined
281
- }
335
+ selectedOption = { selectedOption }
282
336
isMatch = { isMatch < T > }
283
337
OptionComponent = { PickerOptionComponent < T > }
284
338
noOptionsFallback = { < NoOptionsFallback noOptionsText = { noOptionsText } /> }
@@ -332,16 +386,16 @@ const optionNameSx: SystemStyleObject = {
332
386
} ;
333
387
334
388
const PickerOptionComponent = typedMemo (
335
- < T extends AnyModelConfig > ( { option, ...rest } : { option : T & { starred ?: boolean } } & BoxProps ) => {
336
- const { $compactView } = usePickerContext < T & { starred ?: boolean } > ( ) ;
389
+ < T extends AnyModelConfig > ( { option, ...rest } : { option : WithStarred < T > } & BoxProps ) => {
390
+ const { $compactView } = usePickerContext < WithStarred < T > > ( ) ;
337
391
const compactView = useStore ( $compactView ) ;
338
392
339
393
return (
340
394
< Flex { ...rest } sx = { optionSx } data-is-compact = { compactView } >
341
395
{ ! compactView && option . cover_image && < ModelImage image_url = { option . cover_image } /> }
342
396
< Flex flexDir = "column" gap = { 1 } flex = { 1 } >
343
397
< Flex gap = { 2 } alignItems = "center" >
344
- { option . starred && < PiLinkSimple color = "yellow" size = { 16 } /> }
398
+ { option . starred && < Icon as = { PiLinkSimple } color = "invokeYellow.500" boxSize = { 4 } /> }
345
399
< Text sx = { optionNameSx } data-is-compact = { compactView } >
346
400
{ option . name }
347
401
</ Text >
@@ -371,7 +425,7 @@ const BASE_KEYWORDS: { [key in BaseModelType]?: string[] } = {
371
425
'sd-3' : [ 'sd3' , 'sd3.0' , 'sd3.5' , 'sd-3' ] ,
372
426
} ;
373
427
374
- const isMatch = < T extends AnyModelConfig > ( model : T & { starred ?: boolean } , searchTerm : string ) => {
428
+ const isMatch = < T extends AnyModelConfig > ( model : WithStarred < T > , searchTerm : string ) => {
375
429
const regex = getRegex ( searchTerm ) ;
376
430
const bases = BASE_KEYWORDS [ model . base ] ?? [ model . base ] ;
377
431
const testString =
0 commit comments