Skip to content

Commit 3b33984

Browse files
feat(ui): genericizing picker
1 parent 8d83af7 commit 3b33984

File tree

3 files changed

+76
-37
lines changed

3 files changed

+76
-37
lines changed

invokeai/frontend/web/src/common/components/Picker/Picker.tsx

Lines changed: 46 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
1-
import type { SystemStyleObject } from '@invoke-ai/ui-library';
1+
import type { InputProps, SystemStyleObject } from '@invoke-ai/ui-library';
22
import { Box, Divider, Flex, Input, Text } from '@invoke-ai/ui-library';
33
import { IAINoContentFallback } from 'common/components/IAIImageFallback';
44
import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableContent';
55
import { useStateImperative } from 'common/hooks/useStateImperative';
6+
import { fixedForwardRef } from 'common/util/fixedForwardRef';
67
import { typedMemo } from 'common/util/typedMemo';
7-
import { NavigateToModelManagerButton } from 'features/parameters/components/MainModel/NavigateToModelManagerButton';
88
import type { ChangeEvent } from 'react';
99
import {
1010
createContext,
@@ -37,35 +37,36 @@ export type ImperativeModelPickerHandle = {
3737
setSearchTerm: (searchTerm: string) => void;
3838
};
3939

40-
const DefaultOptionComponent = typedMemo(({ id }: { id: string }) => {
41-
return <Text fontWeight="bold">{id}</Text>;
40+
const DefaultOptionComponent = typedMemo(<T extends object>({ option }: { option: T }) => {
41+
const { getOptionId } = usePickerContext();
42+
return <Text fontWeight="bold">{getOptionId(option)}</Text>;
4243
});
4344
DefaultOptionComponent.displayName = 'DefaultOptionComponent';
4445

45-
const DefaultGroupHeaderComponent = typedMemo(({ id }: { id: string }) => {
46-
return <Text fontWeight="bold">{id}</Text>;
46+
const DefaultGroupHeaderComponent = typedMemo(<T extends object>({ group }: { group: Group<T> }) => {
47+
return <Text fontWeight="bold">{group.id}</Text>;
4748
});
4849
DefaultGroupHeaderComponent.displayName = 'DefaultGroupHeaderComponent';
4950

50-
const DefaultNoOptionsFallback = typedMemo(() => {
51+
const DefaultNoOptionsFallbackComponent = typedMemo(() => {
5152
const { t } = useTranslation();
5253
return (
5354
<Flex w="full" h="full" alignItems="center" justifyContent="center">
5455
<Text variant="subtext">{t('common.noOptions')}</Text>
5556
</Flex>
5657
);
5758
});
58-
DefaultNoOptionsFallback.displayName = 'DefaultNoOptionsFallback';
59+
DefaultNoOptionsFallbackComponent.displayName = 'DefaultNoOptionsFallbackComponent';
5960

60-
const DefaultNoMatchesFallback = typedMemo(() => {
61+
const DefaultNoMatchesFallbackComponent = typedMemo(() => {
6162
const { t } = useTranslation();
6263
return (
6364
<Flex w="full" h="full" alignItems="center" justifyContent="center">
6465
<Text variant="subtext">{t('common.noMatches')}</Text>
6566
</Flex>
6667
);
6768
});
68-
DefaultNoMatchesFallback.displayName = 'DefaultNoMatchesFallback';
69+
DefaultNoMatchesFallbackComponent.displayName = 'DefaultNoMatchesFallbackComponent';
6970

7071
export type PickerProps<T extends object> = {
7172
options: (T | Group<T>)[];
@@ -75,9 +76,10 @@ export type PickerProps<T extends object> = {
7576
selectedItem?: T;
7677
onSelect?: (option: T) => void;
7778
onClose?: () => void;
78-
noOptionsFallback?: React.ReactNode;
79-
noMatchesFallback?: React.ReactNode;
8079
handleRef?: React.Ref<ImperativeModelPickerHandle>;
80+
SearchBarComponent?: ReturnType<typeof fixedForwardRef<HTMLInputElement, InputProps>>;
81+
NoOptionsFallbackComponent?: React.ComponentType;
82+
NoMatchesFallbackComponent?: React.ComponentType;
8183
OptionComponent?: React.ComponentType<{ option: T }>;
8284
GroupHeaderComponent?: React.ComponentType<{ group: Group<T> }>;
8385
};
@@ -88,10 +90,11 @@ type PickerContextState<T extends object> = {
8890
getIsDisabled?: (option: T) => boolean;
8991
setActiveOptionId: (id: string) => void;
9092
onSelectById: (id: string) => void;
91-
noOptionsFallback?: React.ReactNode;
92-
noMatchesFallback?: React.ReactNode;
93-
OptionComponent?: React.ComponentType<{ option: T }>;
94-
GroupHeaderComponent?: React.ComponentType<{ group: Group<T> }>;
93+
SearchBarComponent: ReturnType<typeof fixedForwardRef<HTMLInputElement, InputProps>>;
94+
NoOptionsFallbackComponent: React.ComponentType;
95+
NoMatchesFallbackComponent: React.ComponentType;
96+
OptionComponent: React.ComponentType<{ option: T }>;
97+
GroupHeaderComponent: React.ComponentType<{ group: Group<T> }>;
9598
};
9699

97100
/* eslint-disable-next-line @typescript-eslint/no-explicit-any */
@@ -180,13 +183,14 @@ export const Picker = typedMemo(<T extends object>(props: PickerProps<T>) => {
180183
handleRef,
181184
isMatch,
182185
getIsDisabled,
183-
noMatchesFallback = <DefaultNoMatchesFallback />,
184-
noOptionsFallback = <DefaultNoOptionsFallback />,
185186
onClose,
186187
onSelect,
187188
selectedItem,
188-
OptionComponent,
189-
GroupHeaderComponent,
189+
SearchBarComponent = DefaultPickerSearchBarComponent,
190+
NoMatchesFallbackComponent = DefaultNoMatchesFallbackComponent,
191+
NoOptionsFallbackComponent = DefaultNoOptionsFallbackComponent,
192+
OptionComponent = DefaultOptionComponent,
193+
GroupHeaderComponent = DefaultGroupHeaderComponent,
190194
} = props;
191195
const [activeOptionId, setActiveOptionId, getActiveOptionId] = useStateImperative(() =>
192196
getFirstOptionId(options, getOptionId)
@@ -346,22 +350,24 @@ export const Picker = typedMemo(<T extends object>(props: PickerProps<T>) => {
346350
isMatch,
347351
getIsDisabled,
348352
onSelectById,
349-
noOptionsFallback,
350-
noMatchesFallback,
353+
setActiveOptionId,
354+
SearchBarComponent,
355+
NoOptionsFallbackComponent,
356+
NoMatchesFallbackComponent,
351357
OptionComponent,
352358
GroupHeaderComponent,
353-
setActiveOptionId,
354359
}) satisfies PickerContextState<T>,
355360
[
356-
GroupHeaderComponent,
357-
OptionComponent,
358-
getIsDisabled,
359361
getOptionId,
360362
isMatch,
361-
noMatchesFallback,
362-
noOptionsFallback,
363+
getIsDisabled,
363364
onSelectById,
364365
setActiveOptionId,
366+
SearchBarComponent,
367+
NoOptionsFallbackComponent,
368+
NoMatchesFallbackComponent,
369+
OptionComponent,
370+
GroupHeaderComponent,
365371
]
366372
);
367373

@@ -378,14 +384,11 @@ export const Picker = typedMemo(<T extends object>(props: PickerProps<T>) => {
378384
gap={2}
379385
onKeyDown={onKeyDown}
380386
>
381-
<Flex gap={2} alignItems="center">
382-
<Input ref={inputRef} value={searchTerm} onChange={onChangeSearchTerm} placeholder="Filter" />
383-
<NavigateToModelManagerButton />
384-
</Flex>
387+
<SearchBarComponent ref={inputRef} value={searchTerm} onChange={onChangeSearchTerm} />
385388
<Divider />
386389
<Flex tabIndex={-1} w="full" flexGrow={1}>
387-
{flattenedOptions.length === 0 && noOptionsFallback}
388-
{flattenedOptions.length > 0 && flattenedFilteredOptions.length === 0 && noMatchesFallback}
390+
{flattenedOptions.length === 0 && <NoOptionsFallbackComponent />}
391+
{flattenedOptions.length > 0 && flattenedFilteredOptions.length === 0 && <NoMatchesFallbackComponent />}
389392
{flattenedOptions.length > 0 && flattenedFilteredOptions.length > 0 && (
390393
<ScrollableContent>
391394
<PickerList
@@ -402,6 +405,13 @@ export const Picker = typedMemo(<T extends object>(props: PickerProps<T>) => {
402405
});
403406
Picker.displayName = 'Picker';
404407

408+
const DefaultPickerSearchBarComponent = typedMemo(
409+
fixedForwardRef<HTMLInputElement, InputProps>((props, ref) => {
410+
return <Input placeholder="Search" ref={ref} {...props} />;
411+
})
412+
);
413+
DefaultPickerSearchBarComponent.displayName = 'DefaultPickerSearchBarComponent';
414+
405415
const PickerList = typedMemo(
406416
<T extends object>({
407417
items,
@@ -473,7 +483,7 @@ const PickerOptionGroup = typedMemo(
473483

474484
return (
475485
<Flex flexDir="column" gap={2} w="full">
476-
{GroupHeaderComponent ? <GroupHeaderComponent group={group} /> : <DefaultGroupHeaderComponent id={group.id} />}
486+
<GroupHeaderComponent group={group} />
477487
<Flex flexDir="column" gap={1} w="full">
478488
{group.options.map((item) => {
479489
const id = getOptionId(item);
@@ -535,7 +545,7 @@ const PickerOption = typedMemo(
535545
onPointerMove={isDisabled ? undefined : onPointerMove}
536546
onClick={isDisabled ? undefined : onClick}
537547
>
538-
{OptionComponent ? <OptionComponent option={option} /> : <DefaultOptionComponent id={id} />}
548+
<OptionComponent option={option} />
539549
</Box>
540550
);
541551
}
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
import { forwardRef } from 'react';
2+
3+
/**
4+
* A forwardRef that works with generics and doesn't require the use of `as` to cast the type.
5+
* See: https://www.totaltypescript.com/forwardref-with-generic-components
6+
*/
7+
export function fixedForwardRef<T, P = object>(
8+
render: (props: P, ref: React.Ref<T>) => React.ReactNode
9+
): (props: P & React.RefAttributes<T>) => React.ReactNode {
10+
// @ts-expect-error: This is a workaround for forwardRef's crappy typing
11+
// eslint-disable-next-line @typescript-eslint/no-explicit-any
12+
return forwardRef(render) as any;
13+
}

invokeai/frontend/web/src/features/settingsAccordions/components/GenerationSettingsAccordion/GenerationSettingsAccordion.tsx

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
1-
import type { FormLabelProps } from '@invoke-ai/ui-library';
1+
import type { FormLabelProps, InputProps } from '@invoke-ai/ui-library';
22
import {
33
Box,
44
Button,
55
Expander,
66
Flex,
77
FormControlGroup,
88
FormLabel,
9+
Input,
910
Popover,
1011
PopoverArrow,
1112
PopoverBody,
@@ -23,6 +24,7 @@ import { InformationalPopover } from 'common/components/InformationalPopover/Inf
2324
import type { Group, ImperativeModelPickerHandle } from 'common/components/Picker/Picker';
2425
import { getRegex, Picker } from 'common/components/Picker/Picker';
2526
import { useDisclosure } from 'common/hooks/useBoolean';
27+
import { fixedForwardRef } from 'common/util/fixedForwardRef';
2628
import { typedMemo } from 'common/util/typedMemo';
2729
import { selectLoRAsSlice } from 'features/controlLayers/store/lorasSlice';
2830
import { selectIsCogView4, selectIsFLUX, selectIsSD3 } from 'features/controlLayers/store/paramsSlice';
@@ -222,6 +224,7 @@ const MainModelPicker = memo(() => {
222224
isMatch={isMatch}
223225
OptionComponent={PickerItemComponent}
224226
GroupHeaderComponent={PickerGroupHeaderComponent}
227+
SearchBarComponent={SearchBarComponent}
225228
/>
226229
</PopoverBody>
227230
</PopoverContent>
@@ -231,6 +234,19 @@ const MainModelPicker = memo(() => {
231234
});
232235
MainModelPicker.displayName = 'MainModelPicker';
233236

237+
const SearchBarComponent = typedMemo(
238+
fixedForwardRef<HTMLInputElement, InputProps>((props, ref) => {
239+
const { t } = useTranslation();
240+
return (
241+
<Flex gap={2} alignItems="center">
242+
<Input ref={ref} {...props} placeholder={t('common.search')} />
243+
<NavigateToModelManagerButton />
244+
</Flex>
245+
);
246+
})
247+
);
248+
SearchBarComponent.displayName = 'SearchBarComponent';
249+
234250
const PickerGroupHeaderComponent = memo(
235251
({ group }: { group: Group<AnyModelConfig, { name: string; description: string }> }) => {
236252
return (

0 commit comments

Comments
 (0)