Skip to content

Commit 990fec4

Browse files
feat(ui): gallery optimistic updates for video
1 parent 4598489 commit 990fec4

File tree

3 files changed

+247
-22
lines changed

3 files changed

+247
-22
lines changed

invokeai/frontend/web/src/services/api/endpoints/videos.ts

Lines changed: 29 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,25 @@
1+
import { getStore } from 'app/store/nanostores/store';
12
import type { paths } from 'services/api/schema';
2-
import type {
3-
GetVideoIdsArgs,
4-
GetVideoIdsResult,
5-
VideoDTO,
6-
} from 'services/api/types';
3+
import type { GetVideoIdsArgs, GetVideoIdsResult, VideoDTO } from 'services/api/types';
4+
import {
5+
getTagsToInvalidateForBoardAffectingMutation,
6+
getTagsToInvalidateForVideoMutation,
7+
} from 'services/api/util/tagInvalidation';
78
import stableHash from 'stable-hash';
89
import type { Param0 } from 'tsafe';
910

1011
import { api, buildV1Url, LIST_TAG } from '..';
11-
import { getTagsToInvalidateForBoardAffectingMutation, getTagsToInvalidateForImageMutation, getTagsToInvalidateForVideoMutation } from '../util/tagInvalidation';
1212

1313
/**
1414
* Builds an endpoint URL for the videos router
1515
* @example
1616
* buildVideosUrl('some-path')
1717
* // '/api/v1/videos/some-path'
1818
*/
19-
const buildVideosUrl = (path: string = '', query?: Parameters<typeof buildV1Url>[1]) =>
19+
const buildVideosUrl = (path: string = '', query?: Parameters<typeof buildV1Url>[1]) =>
2020
buildV1Url(`videos/${path}`, query);
2121

22-
const buildBoardVideosUrl = (path: string = '') => buildV1Url(`board_videos/${path}`);
22+
const buildBoardVideosUrl = (path: string = '') => buildV1Url(`board_videos/${path}`);
2323

2424
export const videosApi = api.injectEndpoints({
2525
endpoints: (build) => ({
@@ -31,7 +31,6 @@ export const videosApi = api.injectEndpoints({
3131
query: (video_id) => ({ url: buildVideosUrl(`i/${video_id}`) }),
3232
providesTags: (result, error, video_id) => [{ type: 'Video', id: video_id }],
3333
}),
34-
3534

3635
/**
3736
* Get ordered list of image names for selection operations
@@ -204,4 +203,24 @@ export const {
204203
useRemoveVideosFromBoardMutation,
205204
} = videosApi;
206205

207-
206+
/**
207+
* Imperative RTKQ helper to fetch an VideoDTO.
208+
* @param id The id of the video to fetch
209+
* @param options The options for the query. By default, the query will not subscribe to the store.
210+
* @returns The ImageDTO if found, otherwise null
211+
*/
212+
export const getVideoDTOSafe = async (
213+
id: string,
214+
options?: Parameters<typeof videosApi.endpoints.getVideoDTOsByNames.initiate>[1]
215+
): Promise<VideoDTO | null> => {
216+
const _options = {
217+
subscribe: false,
218+
...options,
219+
};
220+
const req = getStore().dispatch(videosApi.endpoints.getVideoDTOsByNames.initiate({ video_ids: [id] }, _options));
221+
try {
222+
return (await req.unwrap())[0] ?? null;
223+
} catch {
224+
return null;
225+
}
226+
};

invokeai/frontend/web/src/services/api/util/optimisticUpdates.ts

Lines changed: 58 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import type { OrderDir } from 'features/gallery/store/types';
2-
import type { GetImageNamesResult, ImageDTO } from 'services/api/types';
2+
import type { GetImageNamesResult, GetVideoIdsResult, ImageDTO, VideoDTO } from 'services/api/types';
33

44
/**
55
* Calculates the optimal insertion position for a new image in the names list.
@@ -57,3 +57,60 @@ export function insertImageIntoNamesResult(
5757
total_count: currentResult.total_count + 1,
5858
};
5959
}
60+
61+
/**
62+
* Calculates the optimal insertion position for a new image in the names list.
63+
* For starred_first=true: starred images go to position 0, unstarred go after all starred images
64+
* For starred_first=false: all new images go to position 0 (newest first)
65+
*/
66+
function calculateVideoInsertionPosition(
67+
videoDTO: VideoDTO,
68+
starredFirst: boolean,
69+
starredCount: number,
70+
orderDir: OrderDir = 'DESC'
71+
): number {
72+
if (!starredFirst) {
73+
// When starred_first is false, insertion depends on order direction
74+
return orderDir === 'DESC' ? 0 : Number.MAX_SAFE_INTEGER;
75+
}
76+
77+
// When starred_first is true
78+
if (videoDTO.starred) {
79+
// Starred images: beginning for desc, after existing starred for asc
80+
return orderDir === 'DESC' ? 0 : starredCount;
81+
}
82+
83+
// Unstarred images go after all starred images
84+
return orderDir === 'DESC' ? starredCount : Number.MAX_SAFE_INTEGER;
85+
}
86+
87+
/**
88+
* Optimistically inserts a new image into the ImageNamesResult at the correct position
89+
*/
90+
export function insertVideoIntoGetVideoIdsResult(
91+
currentResult: GetVideoIdsResult,
92+
videoDTO: VideoDTO,
93+
starredFirst: boolean,
94+
orderDir: OrderDir = 'DESC'
95+
): GetVideoIdsResult {
96+
// Don't insert if the image is already in the list
97+
if (currentResult.video_ids.includes(videoDTO.video_id)) {
98+
return currentResult;
99+
}
100+
101+
const insertPosition = calculateVideoInsertionPosition(videoDTO, starredFirst, currentResult.starred_count, orderDir);
102+
103+
const newVideoIds = [...currentResult.video_ids];
104+
// Handle MAX_SAFE_INTEGER by pushing to end
105+
if (insertPosition >= newVideoIds.length) {
106+
newVideoIds.push(videoDTO.video_id);
107+
} else {
108+
newVideoIds.splice(insertPosition, 0, videoDTO.video_id);
109+
}
110+
111+
return {
112+
video_ids: newVideoIds,
113+
starred_count: starredFirst && videoDTO.starred ? currentResult.starred_count + 1 : currentResult.starred_count,
114+
total_count: currentResult.total_count + 1,
115+
};
116+
}

invokeai/frontend/web/src/services/events/onInvocationComplete.tsx

Lines changed: 160 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import {
55
selectAutoSwitch,
66
selectGalleryView,
77
selectGetImageNamesQueryArgs,
8+
selectGetVideoIdsQueryArgs,
89
selectListBoardsQueryArgs,
910
selectSelectedBoardId,
1011
} from 'features/gallery/store/gallerySelectors';
@@ -17,9 +18,10 @@ import { generatedVideoChanged } from 'features/parameters/store/videoSlice';
1718
import type { LRUCache } from 'lru-cache';
1819
import { boardsApi } from 'services/api/endpoints/boards';
1920
import { getImageDTOSafe, imagesApi } from 'services/api/endpoints/images';
20-
import type { ImageDTO, S } from 'services/api/types';
21+
import { getVideoDTOSafe, videosApi } from 'services/api/endpoints/videos';
22+
import type { ImageDTO, S, VideoDTO } from 'services/api/types';
2123
import { getCategories } from 'services/api/util';
22-
import { insertImageIntoNamesResult } from 'services/api/util/optimisticUpdates';
24+
import { insertImageIntoNamesResult, insertVideoIntoGetVideoIdsResult } from 'services/api/util/optimisticUpdates';
2325
import { $lastProgressEvent } from 'services/events/stores';
2426
import stableHash from 'stable-hash';
2527
import type { Param0 } from 'tsafe';
@@ -185,6 +187,154 @@ export const buildOnInvocationComplete = (
185187
}
186188
};
187189

190+
const addVideosToGallery = async (data: S['InvocationCompleteEvent']) => {
191+
if (nodeTypeDenylist.includes(data.invocation.type)) {
192+
log.trace(`Skipping denylisted node type (${data.invocation.type})`);
193+
return;
194+
}
195+
196+
const videoDTOs = await getResultVideoDTOs(data);
197+
if (videoDTOs.length === 0) {
198+
return;
199+
}
200+
201+
// For efficiency's sake, we want to minimize the number of dispatches and invalidations we do.
202+
// We'll keep track of each change we need to make and do them all at once.
203+
const boardTotalAdditions: Record<string, number> = {};
204+
const getVideoIdsArg = selectGetVideoIdsQueryArgs(getState());
205+
206+
for (const videoDTO of videoDTOs) {
207+
if (videoDTO.is_intermediate) {
208+
return;
209+
}
210+
211+
const board_id = videoDTO.board_id ?? 'none';
212+
213+
boardTotalAdditions[board_id] = (boardTotalAdditions[board_id] || 0) + 1;
214+
}
215+
216+
// Update all the board image totals at once
217+
const entries: Param0<typeof boardsApi.util.upsertQueryEntries> = [];
218+
for (const [boardId, amountToAdd] of objectEntries(boardTotalAdditions)) {
219+
// upsertQueryEntries doesn't provide a "recipe" function for the update - we must provide the new value
220+
// directly. So we need to select the board totals first.
221+
const total = boardsApi.endpoints.getBoardImagesTotal.select(boardId)(getState()).data?.total;
222+
if (total === undefined) {
223+
// No cache exists for this board, so we can't update it.
224+
continue;
225+
}
226+
entries.push({
227+
endpointName: 'getBoardImagesTotal',
228+
arg: boardId,
229+
value: { total: total + amountToAdd },
230+
});
231+
}
232+
dispatch(boardsApi.util.upsertQueryEntries(entries));
233+
234+
dispatch(
235+
boardsApi.util.updateQueryData('listAllBoards', selectListBoardsQueryArgs(getState()), (draft) => {
236+
for (const board of draft) {
237+
board.image_count = board.image_count + (boardTotalAdditions[board.board_id] ?? 0);
238+
}
239+
})
240+
);
241+
242+
/**
243+
* Optimistic update and cache invalidation for image names queries that match this image's board and categories.
244+
* - Optimistic update for the cache that does not have a search term (we cannot derive the correct insertion
245+
* position when a search term is present).
246+
* - Cache invalidation for the query that has a search term, so it will be refetched.
247+
*
248+
* Note: The image DTO objects are already implicitly cached by the getResultImageDTOs function. We do not need
249+
* to explicitly cache them again here.
250+
*/
251+
for (const videoDTO of videoDTOs) {
252+
// Override board_id and categories for this specific image to build the "expected" args for the query.
253+
const videoSpecificArgs = {
254+
board_id: videoDTO.board_id ?? 'none',
255+
};
256+
257+
const expectedQueryArgs = {
258+
...getVideoIdsArg,
259+
...videoSpecificArgs,
260+
search_term: '',
261+
};
262+
263+
// If the cache for the query args provided here does not exist, RTK Query will ignore the update.
264+
dispatch(
265+
videosApi.util.updateQueryData(
266+
'getVideoIds',
267+
{
268+
...getVideoIdsArg,
269+
...videoSpecificArgs,
270+
search_term: '',
271+
},
272+
(draft) => {
273+
const updatedResult = insertVideoIntoGetVideoIdsResult(
274+
draft,
275+
videoDTO,
276+
expectedQueryArgs.starred_first ?? true,
277+
expectedQueryArgs.order_dir
278+
);
279+
280+
draft.video_ids = updatedResult.video_ids;
281+
draft.starred_count = updatedResult.starred_count;
282+
draft.total_count = updatedResult.total_count;
283+
}
284+
)
285+
);
286+
287+
// If there is a search term present, we need to invalidate that query to ensure the search results are updated.
288+
if (getVideoIdsArg.search_term) {
289+
const expectedQueryArgs = {
290+
...getVideoIdsArg,
291+
...videoSpecificArgs,
292+
};
293+
dispatch(imagesApi.util.invalidateTags([{ type: 'ImageNameList', id: stableHash(expectedQueryArgs) }]));
294+
}
295+
}
296+
297+
// No need to invalidate tags since we're doing optimistic updates
298+
// Board totals are already updated above via upsertQueryEntries
299+
300+
const autoSwitch = selectAutoSwitch(getState());
301+
302+
if (!autoSwitch) {
303+
return;
304+
}
305+
306+
// Finally, we may need to autoswitch to the new image. We'll only do it for the last image in the list.
307+
const lastVideoDTO = videoDTOs.at(-1);
308+
309+
if (!lastVideoDTO) {
310+
return;
311+
}
312+
313+
const { video_id } = lastVideoDTO;
314+
const board_id = lastVideoDTO.board_id ?? 'none';
315+
316+
// With optimistic updates, we can immediately switch to the new image
317+
const selectedBoardId = selectSelectedBoardId(getState());
318+
319+
// If the image is from a different board, switch to that board & select the image - otherwise just select the
320+
// image. This implicitly changes the view to 'images' if it was not already.
321+
if (board_id !== selectedBoardId) {
322+
dispatch(
323+
boardIdSelected({
324+
boardId: board_id,
325+
selectedImageName: video_id,
326+
})
327+
);
328+
} else {
329+
// Ensure we are on the 'images' gallery view - that's where this image will be displayed
330+
const galleryView = selectGalleryView(getState());
331+
if (galleryView !== 'videos') {
332+
dispatch(galleryViewChanged('videos'));
333+
}
334+
// Select the image immediately since we've optimistically updated the cache
335+
dispatch(imageSelected(lastVideoDTO.video_id));
336+
}
337+
};
188338
const getResultImageDTOs = async (data: S['InvocationCompleteEvent']): Promise<ImageDTO[]> => {
189339
const { result } = data;
190340
const imageDTOs: ImageDTO[] = [];
@@ -206,17 +356,20 @@ export const buildOnInvocationComplete = (
206356
return imageDTOs;
207357
};
208358

209-
const getResultVideoFields = (data: S['InvocationCompleteEvent']): VideoField[] => {
359+
const getResultVideoDTOs = async (data: S['InvocationCompleteEvent']): Promise<VideoDTO[]> => {
210360
const { result } = data;
211-
const videoFields: VideoField[] = [];
361+
const videoDTOs: VideoDTO[] = [];
212362

213363
for (const [_name, value] of objectEntries(result)) {
214364
if (isVideoField(value)) {
215-
videoFields.push(value);
365+
const videoDTO = await getVideoDTOSafe(value.video_id);
366+
if (videoDTO) {
367+
videoDTOs.push(videoDTO);
368+
}
216369
}
217370
}
218371

219-
return videoFields;
372+
return videoDTOs;
220373
};
221374

222375
return async (data: S['InvocationCompleteEvent']) => {
@@ -239,11 +392,7 @@ export const buildOnInvocationComplete = (
239392
}
240393

241394
await addImagesToGallery(data);
242-
243-
const videoField = getResultVideoFields(data)[0];
244-
if (videoField) {
245-
dispatch(generatedVideoChanged({ videoField }));
246-
}
395+
await addVideosToGallery(data);
247396

248397
$lastProgressEvent.set(null);
249398
};

0 commit comments

Comments
 (0)