Skip to content

Commit 12b70bc

Browse files
feat(ui): consolidated gallery (wip)
1 parent 990fec4 commit 12b70bc

File tree

4 files changed

+134
-68
lines changed

4 files changed

+134
-68
lines changed

invokeai/frontend/web/src/features/gallery/components/Gallery.tsx

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,7 @@ export const GalleryPanel = memo(() => {
122122
</Collapse>
123123
<Divider pt={2} />
124124
<Flex w="full" h="full" pt={2}>
125-
{galleryView === 'images' ? <NewGallery /> : galleryView === 'videos' ? <VideoGallery /> : <NewGallery />}
125+
<NewGallery />
126126
</Flex>
127127
</Flex>
128128
);

invokeai/frontend/web/src/features/gallery/components/NewGallery.tsx

Lines changed: 94 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ import { logger } from 'app/logging/logger';
44
import { useAppSelector, useAppStore } from 'app/store/storeHooks';
55
import { getFocusedRegion, useIsRegionFocused } from 'common/hooks/focus';
66
import { useRangeBasedImageFetching } from 'features/gallery/hooks/useRangeBasedImageFetching';
7-
import type { selectGetImageNamesQueryArgs } from 'features/gallery/store/gallerySelectors';
7+
import type { selectGetImageNamesQueryArgs, selectGetVideoIdsQueryArgs } from 'features/gallery/store/gallerySelectors';
88
import {
99
selectGalleryImageMinimumWidth,
1010
selectGalleryView,
@@ -28,23 +28,30 @@ import type {
2828
} from 'react-virtuoso';
2929
import { VirtuosoGrid } from 'react-virtuoso';
3030
import { imagesApi, useImageDTO, useStarImagesMutation, useUnstarImagesMutation } from 'services/api/endpoints/images';
31-
import { videosApi } from 'services/api/endpoints/videos';
31+
import { useStarVideosMutation, useUnstarVideosMutation, useVideoDTO, videosApi } from 'services/api/endpoints/videos';
3232
import { useDebounce } from 'use-debounce';
3333

3434
import { GalleryImage, GalleryImagePlaceholder } from './ImageGrid/GalleryImage';
3535
import { GallerySelectionCountTag } from './ImageGrid/GallerySelectionCountTag';
36-
import { useGalleryImageNames } from './use-gallery-image-names';
37-
import { useGalleryVideoIds } from './use-gallery-video-ids';
3836
import { GalleryVideo } from './ImageGrid/GalleryVideo';
37+
import { useGalleryImageNames, useGalleryVideoIds } from './use-gallery-image-names';
3938

4039
const log = logger('gallery');
4140

4241
type ListImageNamesQueryArgs = ReturnType<typeof selectGetImageNamesQueryArgs>;
42+
type ListVideoIdsQueryArgs = ReturnType<typeof selectGetVideoIdsQueryArgs>;
4343

44-
type GridContext = {
45-
queryArgs: ListImageNamesQueryArgs;
46-
imageNames: string[];
47-
};
44+
type GridContext =
45+
| {
46+
queryArgs: ListImageNamesQueryArgs;
47+
galleryView: 'images' | 'assets';
48+
itemIds: string[];
49+
}
50+
| {
51+
queryArgs: ListVideoIdsQueryArgs;
52+
galleryView: 'videos';
53+
itemIds: string[];
54+
};
4855

4956
const ImageAtPosition = memo(({ imageName }: { index: number; imageName: string }) => {
5057
/*
@@ -96,8 +103,8 @@ const VideoAtPosition = memo(({ itemId }: { index: number; itemId: string }) =>
96103
});
97104
VideoAtPosition.displayName = 'VideoAtPosition';
98105

99-
const computeItemKey: GridComputeItemKey<string, GridContext> = (index, imageName, { queryArgs }) => {
100-
return `${JSON.stringify(queryArgs)}-${imageName ?? index}`;
106+
const computeItemKey: GridComputeItemKey<string, GridContext> = (index, id, { queryArgs }) => {
107+
return `${JSON.stringify(queryArgs)}-${id ?? index}`;
101108
};
102109

103110
/**
@@ -106,7 +113,7 @@ const computeItemKey: GridComputeItemKey<string, GridContext> = (index, imageNam
106113
* TODO(psyche): We only need to do this when the gallery width changes, or when the galleryImageMinimumWidth value
107114
* changes. Cache this calculation.
108115
*/
109-
const getImagesPerRow = (rootEl: HTMLDivElement): number => {
116+
const getItemsPerRow = (rootEl: HTMLDivElement): number => {
110117
// Start from root and find virtuoso grid elements
111118
const gridElement = rootEl.querySelector('.virtuoso-grid-list');
112119

@@ -140,20 +147,20 @@ const getImagesPerRow = (rootEl: HTMLDivElement): number => {
140147
*
141148
* Instead, we use a more robust approach that iteratively calculates how many images fit in the row.
142149
*/
143-
let imagesPerRow = 0;
150+
let itemsPerRow = 0;
144151
let spaceUsed = 0;
145152

146153
// Floating point precision can cause imagesPerRow to be 1 too small. Adding 1px to the container size fixes
147154
// this, without the possibility of accidentally adding an extra column.
148155
while (spaceUsed + itemRect.width <= containerRect.width + 1) {
149-
imagesPerRow++; // Increment the number of images
156+
itemsPerRow++; // Increment the number of images
150157
spaceUsed += itemRect.width; // Add image size to the used space
151158
if (spaceUsed + gap <= containerRect.width) {
152159
spaceUsed += gap; // Add gap size to the used space after each image except after the last image
153160
}
154161
}
155162

156-
return Math.max(1, imagesPerRow);
163+
return Math.max(1, itemsPerRow);
157164
};
158165

159166
/**
@@ -180,9 +187,7 @@ const scrollIntoView = (
180187
return;
181188
}
182189

183-
const targetItem = rootEl.querySelector(
184-
`.virtuoso-grid-item:has([data-item-id="${targetItemId}"])`
185-
) as HTMLElement;
190+
const targetItem = rootEl.querySelector(`.virtuoso-grid-item:has([data-item-id="${targetItemId}"])`) as HTMLElement;
186191

187192
if (!targetItem) {
188193
if (targetIndex > range.endIndex) {
@@ -268,19 +273,19 @@ const scrollIntoView = (
268273
* If the image name is not found, return 0.
269274
* If no image name is provided, return 0.
270275
*/
271-
const getImageIndex = (imageName: string | undefined | null, imageNames: string[]) => {
272-
if (!imageName || imageNames.length === 0) {
276+
const getItemIndex = (targetItemId: string | undefined | null, itemIds: string[]) => {
277+
if (!targetItemId || itemIds.length === 0) {
273278
return 0;
274279
}
275-
const index = imageNames.findIndex((n) => n === imageName);
280+
const index = itemIds.findIndex((n) => n === targetItemId);
276281
return index >= 0 ? index : 0;
277282
};
278283

279284
/**
280285
* Handles keyboard navigation for the gallery.
281286
*/
282287
const useKeyboardNavigation = (
283-
imageNames: string[],
288+
itemIds: string[],
284289
virtuosoRef: React.RefObject<VirtuosoGridHandle>,
285290
rootRef: React.RefObject<HTMLDivElement>
286291
) => {
@@ -308,27 +313,28 @@ const useKeyboardNavigation = (
308313
return;
309314
}
310315

311-
if (imageNames.length === 0) {
316+
if (itemIds.length === 0) {
312317
return;
313318
}
314319

315-
const imagesPerRow = getImagesPerRow(rootEl);
320+
const itemsPerRow = getItemsPerRow(rootEl);
316321

317-
if (imagesPerRow === 0) {
322+
if (itemsPerRow === 0) {
318323
// This can happen if the grid is not yet rendered or has no items
319324
return;
320325
}
321326

322327
event.preventDefault();
323328

324329
const state = getState();
325-
const imageName = event.altKey
326-
? // When the user holds alt, we are changing the image to compare - if no image to compare is currently selected,
327-
// we start from the last selected image
328-
(selectImageToCompare(state) ?? selectLastSelectedImage(state))
329-
: selectLastSelectedImage(state);
330+
const imageName =
331+
event.altKey && selectGalleryView(state) !== 'videos'
332+
? // When the user holds alt, we are changing the image to compare - if no image to compare is currently selected,
333+
// we start from the last selected image
334+
(selectImageToCompare(state) ?? selectLastSelectedImage(state))
335+
: selectLastSelectedImage(state);
330336

331-
const currentIndex = getImageIndex(imageName, imageNames);
337+
const currentIndex = getItemIndex(imageName, itemIds);
332338

333339
let newIndex = currentIndex;
334340

@@ -342,7 +348,7 @@ const useKeyboardNavigation = (
342348
}
343349
break;
344350
case 'ArrowRight':
345-
if (currentIndex < imageNames.length - 1) {
351+
if (currentIndex < itemIds.length - 1) {
346352
newIndex = currentIndex + 1;
347353
// } else {
348354
// // Wrap to first image
@@ -351,34 +357,34 @@ const useKeyboardNavigation = (
351357
break;
352358
case 'ArrowUp':
353359
// If on first row, stay on current image
354-
if (currentIndex < imagesPerRow) {
360+
if (currentIndex < itemsPerRow) {
355361
newIndex = currentIndex;
356362
} else {
357-
newIndex = Math.max(0, currentIndex - imagesPerRow);
363+
newIndex = Math.max(0, currentIndex - itemsPerRow);
358364
}
359365
break;
360366
case 'ArrowDown':
361367
// If no images below, stay on current image
362-
if (currentIndex >= imageNames.length - imagesPerRow) {
368+
if (currentIndex >= itemIds.length - itemsPerRow) {
363369
newIndex = currentIndex;
364370
} else {
365-
newIndex = Math.min(imageNames.length - 1, currentIndex + imagesPerRow);
371+
newIndex = Math.min(itemIds.length - 1, currentIndex + itemsPerRow);
366372
}
367373
break;
368374
}
369375

370-
if (newIndex !== currentIndex && newIndex >= 0 && newIndex < imageNames.length) {
371-
const newImageName = imageNames[newIndex];
376+
if (newIndex !== currentIndex && newIndex >= 0 && newIndex < itemIds.length) {
377+
const newImageName = itemIds[newIndex];
372378
if (newImageName) {
373-
if (event.altKey) {
379+
if (selectGalleryView(state) !== 'videos' && event.altKey) {
374380
dispatch(imageToCompareChanged(newImageName));
375381
} else {
376382
dispatch(selectionChanged([newImageName]));
377383
}
378384
}
379385
}
380386
},
381-
[rootRef, virtuosoRef, imageNames, getState, dispatch]
387+
[rootRef, virtuosoRef, itemIds, getState, dispatch]
382388
);
383389

384390
useRegisteredHotkeys({
@@ -451,28 +457,28 @@ const useKeyboardNavigation = (
451457
* This is useful for keyboard navigation and ensuring the user can see their selection.
452458
* It only tracks the last selected image, not the image to compare.
453459
*/
454-
const useKeepSelectedImageInView = (
455-
imageNames: string[],
460+
const useKeepSelectedItemInView = (
461+
itemIds: string[],
456462
virtuosoRef: React.RefObject<VirtuosoGridHandle>,
457463
rootRef: React.RefObject<HTMLDivElement>,
458464
rangeRef: MutableRefObject<ListRange>
459465
) => {
460466
const selection = useAppSelector(selectSelection);
461467

462468
useEffect(() => {
463-
const targetImageName = selection.at(-1);
469+
const targetItemId = selection.at(-1);
464470
const virtuosoGridHandle = virtuosoRef.current;
465471
const rootEl = rootRef.current;
466472
const range = rangeRef.current;
467473

468-
if (!virtuosoGridHandle || !rootEl || !targetImageName || !imageNames || imageNames.length === 0) {
474+
if (!virtuosoGridHandle || !rootEl || !targetItemId || !itemIds || itemIds.length === 0) {
469475
return;
470476
}
471477

472478
setTimeout(() => {
473-
scrollIntoView(targetImageName, imageNames, rootEl, virtuosoGridHandle, range);
479+
scrollIntoView(targetItemId, itemIds, rootEl, virtuosoGridHandle, range);
474480
}, 0);
475-
}, [imageNames, rangeRef, rootRef, virtuosoRef, selection]);
481+
}, [itemIds, rangeRef, rootRef, virtuosoRef, selection]);
476482
};
477483

478484
/**
@@ -523,30 +529,45 @@ const useScrollableGallery = (rootRef: RefObject<HTMLDivElement>) => {
523529
const useStarImageHotkey = () => {
524530
const lastSelectedImage = useAppSelector(selectLastSelectedImage);
525531
const selectionCount = useAppSelector(selectSelectionCount);
532+
const galleryView = useAppSelector(selectGalleryView);
526533
const isGalleryFocused = useIsRegionFocused('gallery');
527-
const imageDTO = useImageDTO(lastSelectedImage);
534+
const imageDTO = useImageDTO(galleryView !== 'videos' ? lastSelectedImage : null);
535+
const videoDTO = useVideoDTO(galleryView === 'videos' ? lastSelectedImage : null);
528536
const [starImages] = useStarImagesMutation();
529537
const [unstarImages] = useUnstarImagesMutation();
530538

539+
const [starVideos] = useStarVideosMutation();
540+
const [unstarVideos] = useUnstarVideosMutation();
541+
531542
const handleStarHotkey = useCallback(() => {
532-
if (!imageDTO) {
533-
return;
534-
}
535543
if (!isGalleryFocused) {
536544
return;
537545
}
538-
if (imageDTO.starred) {
539-
unstarImages({ image_names: [imageDTO.image_name] });
540-
} else {
541-
starImages({ image_names: [imageDTO.image_name] });
546+
if (galleryView === 'videos' && videoDTO) {
547+
if (videoDTO.starred) {
548+
unstarVideos({ video_ids: [videoDTO.video_id] });
549+
} else {
550+
starVideos({ video_ids: [videoDTO.video_id] });
551+
}
552+
} else if (galleryView !== 'videos' && imageDTO) {
553+
if (imageDTO.starred) {
554+
unstarImages({ image_names: [imageDTO.image_name] });
555+
} else {
556+
starImages({ image_names: [imageDTO.image_name] });
557+
}
542558
}
543559
}, [imageDTO, isGalleryFocused, starImages, unstarImages]);
544560

545561
useRegisteredHotkeys({
546562
id: 'starImage',
547563
category: 'gallery',
548564
callback: handleStarHotkey,
549-
options: { enabled: !!imageDTO && selectionCount === 1 && isGalleryFocused },
565+
options: {
566+
enabled:
567+
((galleryView === 'videos' && !!videoDTO) || (galleryView !== 'videos' && !!imageDTO)) &&
568+
selectionCount === 1 &&
569+
isGalleryFocused,
570+
},
550571
dependencies: [imageDTO, selectionCount, isGalleryFocused, handleStarHotkey],
551572
});
552573
};
@@ -558,18 +579,22 @@ export const NewGallery = memo(() => {
558579
const galleryView = useAppSelector(selectGalleryView);
559580

560581
// Get the ordered list of image names - this is our primary data source for virtualization
561-
const { queryArgs, imageNames, isLoading } = useGalleryImageNames();
562-
const { queryArgs: videoQueryArgs, videoIds, isLoading: isLoadingVideos } = useGalleryVideoIds();
582+
const galleryImageNamesQuery = useGalleryImageNames();
583+
const galleryVideoIdsQuery = useGalleryVideoIds();
563584

564585
// Use range-based fetching for bulk loading image DTOs into cache based on the visible range
565586
const { onRangeChanged } = useRangeBasedImageFetching({
566-
imageNames,
567-
enabled: !isLoading,
587+
imageNames: galleryImageNamesQuery.imageNames,
588+
enabled: !galleryImageNamesQuery.isLoading,
568589
});
569590

591+
const itemIds = galleryView === 'videos' ? galleryVideoIdsQuery.video_ids : galleryImageNamesQuery.imageNames;
592+
const queryArgs = galleryView === 'videos' ? galleryVideoIdsQuery.queryArgs : galleryImageNamesQuery.queryArgs;
593+
const isLoading = galleryView === 'videos' ? galleryVideoIdsQuery.isLoading : galleryImageNamesQuery.isLoading;
570594
useStarImageHotkey();
571-
useKeepSelectedImageInView(imageNames, virtuosoRef, rootRef, rangeRef);
572-
useKeyboardNavigation(imageNames, virtuosoRef, rootRef);
595+
596+
useKeepSelectedItemInView(itemIds, virtuosoRef, rootRef, rangeRef);
597+
useKeyboardNavigation(itemIds, virtuosoRef, rootRef);
573598
const scrollerRef = useScrollableGallery(rootRef);
574599

575600
/*
@@ -584,7 +609,7 @@ export const NewGallery = memo(() => {
584609
[onRangeChanged]
585610
);
586611

587-
const context = useMemo<GridContext>(() => ({ imageNames, queryArgs, videoIds, videoQueryArgs }), [imageNames, queryArgs, videoIds, videoQueryArgs]);
612+
const context = useMemo<GridContext>(() => ({ itemIds, galleryView, queryArgs }), [itemIds, queryArgs, galleryView]);
588613

589614
if (isLoading) {
590615
return (
@@ -595,7 +620,7 @@ export const NewGallery = memo(() => {
595620
);
596621
}
597622

598-
if (imageNames.length === 0) {
623+
if (itemIds.length === 0) {
599624
return (
600625
<Flex w="full" h="full" alignItems="center" justifyContent="center">
601626
<Text color="base.300">No images found</Text>
@@ -609,7 +634,7 @@ export const NewGallery = memo(() => {
609634
<VirtuosoGrid<string, GridContext>
610635
ref={virtuosoRef}
611636
context={context}
612-
data={galleryView === 'images' ? imageNames : videoIds}
637+
data={itemIds}
613638
increaseViewportBy={4096}
614639
itemContent={itemContent}
615640
computeItemKey={computeItemKey}
@@ -652,8 +677,12 @@ const ListComponent: GridComponents<GridContext>['List'] = forwardRef(({ context
652677
});
653678
ListComponent.displayName = 'ListComponent';
654679

655-
const itemContent: GridItemContent<string, GridContext> = (index, imageName) => {
656-
return <ImageAtPosition index={index} imageName={imageName} />;
680+
const itemContent: GridItemContent<string, GridContext> = (index, itemId, { galleryView }) => {
681+
if (galleryView === 'videos') {
682+
return <VideoAtPosition index={index} itemId={itemId} />;
683+
} else {
684+
return <ImageAtPosition index={index} imageName={itemId} />;
685+
}
657686
};
658687

659688
const ItemComponent: GridComponents<GridContext>['Item'] = forwardRef(({ context: _, ...rest }, ref) => (

0 commit comments

Comments
 (0)