Skip to content

Commit f513e50

Browse files
authored
fix(rivetkit): correctly handle hibernatable websocket reconnection with persisted request ids (#3398)
1 parent 1364d96 commit f513e50

File tree

11 files changed

+144
-53
lines changed

11 files changed

+144
-53
lines changed

rivetkit-typescript/packages/rivetkit/schemas/actor-persist/v2.bare

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ type PersistedConnection struct {
1212
state: data
1313
subscriptions: list<PersistedSubscription>
1414
lastSeen: i64
15+
hibernatableRequestId: optional<data>
1516
}
1617

1718
# MARK: Schedule Event

rivetkit-typescript/packages/rivetkit/src/actor/conn-drivers.ts

Lines changed: 0 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -68,15 +68,6 @@ export interface ConnDriver<State> {
6868
conn: AnyConn,
6969
state: State,
7070
): ConnReadyState | undefined;
71-
72-
/**
73-
* If the underlying connection can hibernate.
74-
*/
75-
isHibernatable(
76-
actor: AnyActorInstance,
77-
conn: AnyConn,
78-
state: State,
79-
): boolean;
8071
}
8172

8273
// MARK: WebSocket
@@ -159,22 +150,6 @@ const WEBSOCKET_DRIVER: ConnDriver<ConnDriverWebSocketState> = {
159150
): ConnReadyState | undefined => {
160151
return state.websocket.readyState;
161152
},
162-
163-
isHibernatable(
164-
_actor: AnyActorInstance,
165-
_conn: AnyConn,
166-
state: ConnDriverWebSocketState,
167-
): boolean {
168-
// Extract isHibernatable from the HonoWebSocketAdapter
169-
if (state.websocket.raw) {
170-
const raw = state.websocket.raw as HonoWebSocketAdapter;
171-
if (typeof raw.isHibernatable === "boolean") {
172-
return raw.isHibernatable;
173-
}
174-
}
175-
176-
return false;
177-
},
178153
};
179154

180155
// MARK: SSE
@@ -210,10 +185,6 @@ const SSE_DRIVER: ConnDriver<ConnDriverSseState> = {
210185

211186
return ConnReadyState.OPEN;
212187
},
213-
214-
isHibernatable(): boolean {
215-
return false;
216-
},
217188
};
218189

219190
// MARK: HTTP
@@ -226,9 +197,6 @@ const HTTP_DRIVER: ConnDriver<ConnDriverHttpState> = {
226197
// Noop
227198
// TODO: Abort the request
228199
},
229-
isHibernatable(): boolean {
230-
return false;
231-
},
232200
};
233201

234202
/** List of all connection drivers. */

rivetkit-typescript/packages/rivetkit/src/actor/conn-socket.ts

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,5 +2,7 @@ import type { ConnDriverState } from "./conn-drivers";
22

33
export interface ConnSocket {
44
requestId: string;
5+
requestIdBuf?: ArrayBuffer;
6+
hibernatable: boolean;
57
driverState: ConnDriverState;
68
}

rivetkit-typescript/packages/rivetkit/src/actor/conn.ts

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ import invariant from "invariant";
33
import { PersistedHibernatableWebSocket } from "@/schemas/actor-persist/mod";
44
import type * as protocol from "@/schemas/client-protocol/mod";
55
import { TO_CLIENT_VERSIONED } from "@/schemas/client-protocol/versioned";
6-
import { bufferToArrayBuffer } from "@/utils";
6+
import { arrayBuffersEqual, bufferToArrayBuffer } from "@/utils";
77
import {
88
CONN_DRIVERS,
99
ConnDriverKind,
@@ -14,7 +14,7 @@ import {
1414
import type { ConnSocket } from "./conn-socket";
1515
import type { AnyDatabaseProvider } from "./database";
1616
import * as errors from "./errors";
17-
import type { ActorInstance } from "./instance";
17+
import { type ActorInstance, PERSIST_SYMBOL } from "./instance";
1818
import type { PersistedConn } from "./persisted";
1919
import { CachedSerializer } from "./protocol/serde";
2020
import { generateSecureToken } from "./utils";
@@ -69,7 +69,8 @@ export class Conn<S, CP, CS, V, I, DB extends AnyDatabaseProvider> {
6969
__socket?: ConnSocket;
7070

7171
get __status(): ConnectionStatus {
72-
if (this.__socket) {
72+
// TODO: isHibernatible might be true while the actual hibernatable websocket has disconnected
73+
if (this.__socket || this.isHibernatable) {
7374
return "connected";
7475
} else {
7576
return "reconnecting";
@@ -132,17 +133,17 @@ export class Conn<S, CP, CS, V, I, DB extends AnyDatabaseProvider> {
132133
* If the underlying connection can hibernate.
133134
*/
134135
public get isHibernatable(): boolean {
135-
if (this.__driverState) {
136-
const driverKind = getConnDriverKindFromState(this.__driverState);
137-
const driver = CONN_DRIVERS[driverKind];
138-
return driver.isHibernatable(
139-
this.#actor,
140-
this,
141-
(this.__driverState as any)[driverKind],
142-
);
143-
} else {
136+
if (!this.__persist.hibernatableRequestId) {
144137
return false;
145138
}
139+
return (
140+
this.#actor[PERSIST_SYMBOL].hibernatableWebSocket.findIndex((x) =>
141+
arrayBuffersEqual(
142+
x.requestId,
143+
this.__persist.hibernatableRequestId!,
144+
),
145+
) > -1
146+
);
146147
}
147148

148149
/**

rivetkit-typescript/packages/rivetkit/src/actor/instance.ts

Lines changed: 98 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ import {
1919
bufferToArrayBuffer,
2020
EXTRA_ERROR_LOG,
2121
getEnvUniversal,
22+
idToStr,
2223
promiseWithResolvers,
2324
SinglePromiseQueue,
2425
} from "@/utils";
@@ -244,7 +245,10 @@ export class ActorInstance<S, CP, CS, V, I, DB extends AnyDatabaseProvider> {
244245
lastSeen: conn.lastSeen,
245246
stateEnabled: conn.__stateEnabled,
246247
isHibernatable: conn.isHibernatable,
247-
requestId: conn.__socket?.requestId,
248+
hibernatableRequestId: conn.__persist
249+
.hibernatableRequestId
250+
? idToStr(conn.__persist.hibernatableRequestId)
251+
: undefined,
248252
driver: conn.__driverState
249253
? getConnDriverKindFromState(conn.__driverState)
250254
: undefined,
@@ -267,6 +271,7 @@ export class ActorInstance<S, CP, CS, V, I, DB extends AnyDatabaseProvider> {
267271
const conn = await this.createConn(
268272
{
269273
requestId: requestId,
274+
hibernatable: false,
270275
driverState: { [ConnDriverKind.HTTP]: {} },
271276
},
272277
undefined,
@@ -1016,6 +1021,74 @@ export class ActorInstance<S, CP, CS, V, I, DB extends AnyDatabaseProvider> {
10161021
): Promise<Conn<S, CP, CS, V, I, DB>> {
10171022
this.#assertReady();
10181023

1024+
// Check for hibernatable websocket reconnection
1025+
if (socket.requestIdBuf && socket.hibernatable) {
1026+
this.rLog.debug({
1027+
msg: "checking for hibernatable websocket connection",
1028+
requestId: socket.requestId,
1029+
existingConnectionsCount: this.#connections.size,
1030+
});
1031+
1032+
// Find existing connection with matching hibernatableRequestId
1033+
const existingConn = Array.from(this.#connections.values()).find(
1034+
(conn) =>
1035+
conn.__persist.hibernatableRequestId &&
1036+
arrayBuffersEqual(
1037+
conn.__persist.hibernatableRequestId,
1038+
socket.requestIdBuf!,
1039+
),
1040+
);
1041+
1042+
if (existingConn) {
1043+
this.rLog.debug({
1044+
msg: "reconnecting hibernatable websocket connection",
1045+
connectionId: existingConn.id,
1046+
requestId: socket.requestId,
1047+
});
1048+
1049+
// If there's an existing driver state, clean it up without marking as clean disconnect
1050+
if (existingConn.__driverState) {
1051+
this.#rLog.warn({
1052+
msg: "found existing driver state on hibernatable websocket",
1053+
connectionId: existingConn.id,
1054+
requestId: socket.requestId,
1055+
});
1056+
const driverKind = getConnDriverKindFromState(
1057+
existingConn.__driverState,
1058+
);
1059+
const driver = CONN_DRIVERS[driverKind];
1060+
if (driver.disconnect) {
1061+
// Call driver disconnect to clean up directly. Don't use Conn.disconnect since that will remove the connection entirely.
1062+
driver.disconnect(
1063+
this,
1064+
existingConn,
1065+
(existingConn.__driverState as any)[driverKind],
1066+
"Reconnecting hibernatable websocket with new driver state",
1067+
);
1068+
}
1069+
}
1070+
1071+
// Update with new driver state
1072+
existingConn.__socket = socket;
1073+
existingConn.__persist.lastSeen = Date.now();
1074+
1075+
// Update sleep timer since connection is now active
1076+
this.#resetSleepTimer();
1077+
1078+
this.inspector.emitter.emit("connectionUpdated");
1079+
1080+
// We don't need to send a new init message since this is a
1081+
// hibernated request that has already been initialized
1082+
1083+
return existingConn;
1084+
} else {
1085+
this.rLog.debug({
1086+
msg: "no existing hibernatable connection found, creating new connection",
1087+
requestId: socket.requestId,
1088+
});
1089+
}
1090+
}
1091+
10191092
// If connection ID and token are provided, try to reconnect
10201093
if (connectionId && connectionToken) {
10211094
this.rLog.debug({
@@ -1074,14 +1147,12 @@ export class ActorInstance<S, CP, CS, V, I, DB extends AnyDatabaseProvider> {
10741147
);
10751148

10761149
return existingConn;
1150+
} else {
1151+
this.rLog.debug({
1152+
msg: "connection not found or token mismatch, creating new connection",
1153+
connectionId,
1154+
});
10771155
}
1078-
1079-
// If we get here, either connection doesn't exist or token doesn't match
1080-
// Fall through to create new connection with new IDs
1081-
this.rLog.debug({
1082-
msg: "connection not found or token mismatch, creating new connection",
1083-
connectionId,
1084-
});
10851156
}
10861157

10871158
// Generate new connection ID and token if not provided or if reconnection failed
@@ -1147,6 +1218,19 @@ export class ActorInstance<S, CP, CS, V, I, DB extends AnyDatabaseProvider> {
11471218
lastSeen: Date.now(),
11481219
subscriptions: [],
11491220
};
1221+
1222+
// Check if this connection is for a hibernatable websocket
1223+
if (socket.requestIdBuf) {
1224+
const isHibernatable =
1225+
this.#persist.hibernatableWebSocket.findIndex((ws) =>
1226+
arrayBuffersEqual(ws.requestId, socket.requestIdBuf!),
1227+
) !== -1;
1228+
1229+
if (isHibernatable) {
1230+
persist.hibernatableRequestId = socket.requestIdBuf;
1231+
}
1232+
}
1233+
11501234
const conn = new Conn<S, CP, CS, V, I, DB>(this, persist);
11511235
conn.__socket = socket;
11521236
this.#connections.set(conn.id, conn);
@@ -2094,6 +2178,10 @@ export class ActorInstance<S, CP, CS, V, I, DB extends AnyDatabaseProvider> {
20942178
// Disconnect existing non-hibernatable connections
20952179
for (const connection of this.#connections.values()) {
20962180
if (!connection.isHibernatable) {
2181+
this.#rLog.debug({
2182+
msg: "disconnecting non-hibernatable connection on actor stop",
2183+
connId: connection.id,
2184+
});
20972185
promises.push(connection.disconnect());
20982186
}
20992187

@@ -2187,6 +2275,7 @@ export class ActorInstance<S, CP, CS, V, I, DB extends AnyDatabaseProvider> {
21872275
eventName: sub.eventName,
21882276
})),
21892277
lastSeen: BigInt(conn.lastSeen),
2278+
hibernatableRequestId: conn.hibernatableRequestId ?? null,
21902279
})),
21912280
scheduledEvents: persist.scheduledEvents.map((event) => ({
21922281
eventId: event.eventId,
@@ -2225,6 +2314,7 @@ export class ActorInstance<S, CP, CS, V, I, DB extends AnyDatabaseProvider> {
22252314
eventName: sub.eventName,
22262315
})),
22272316
lastSeen: Number(conn.lastSeen),
2317+
hibernatableRequestId: conn.hibernatableRequestId ?? undefined,
22282318
})),
22292319
scheduledEvents: bareData.scheduledEvents.map((event) => ({
22302320
eventId: event.eventId,

rivetkit-typescript/packages/rivetkit/src/actor/persisted.ts

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,11 @@ export interface PersistedConn<CP, CS> {
1616
state: CS;
1717
subscriptions: PersistedSubscription[];
1818

19-
/** Last time the socket was seen. This is set when disconencted so we can determine when we need to clean this up. */
19+
/** Last time the socket was seen. This is set when disconnected so we can determine when we need to clean this up. */
2020
lastSeen: number;
21+
22+
/** Request ID of the hibernatable WebSocket. See PersistedActor.hibernatableWebSocket */
23+
hibernatableRequestId?: ArrayBuffer;
2124
}
2225

2326
export interface PersistedSubscription {

rivetkit-typescript/packages/rivetkit/src/actor/router-endpoints.ts

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,7 @@ export async function handleWebSocketConnect(
116116
encoding: Encoding,
117117
parameters: unknown,
118118
requestId: string,
119+
requestIdBuf: ArrayBuffer | undefined,
119120
connId: string | undefined,
120121
connToken: string | undefined,
121122
): Promise<UpgradeWebSocketArgs> {
@@ -184,9 +185,19 @@ export async function handleWebSocketConnect(
184185
actorId,
185186
});
186187

188+
// Check if this is a hibernatable websocket
189+
const isHibernatable =
190+
!!requestIdBuf &&
191+
actor[PERSIST_SYMBOL].hibernatableWebSocket.findIndex(
192+
(ws) =>
193+
arrayBuffersEqual(ws.requestId, requestIdBuf),
194+
) !== -1;
195+
187196
conn = await actor.createConn(
188197
{
189198
requestId: requestId,
199+
requestIdBuf: requestIdBuf,
200+
hibernatable: isHibernatable,
190201
driverState: {
191202
[ConnDriverKind.WEBSOCKET]: {
192203
encoding,
@@ -365,6 +376,7 @@ export async function handleSseConnect(
365376
conn = await actor.createConn(
366377
{
367378
requestId: requestId,
379+
hibernatable: false,
368380
driverState: {
369381
[ConnDriverKind.SSE]: {
370382
encoding,
@@ -479,6 +491,7 @@ export async function handleAction(
479491
conn = await actor.createConn(
480492
{
481493
requestId: requestId,
494+
hibernatable: false,
482495
driverState: { [ConnDriverKind.HTTP]: {} },
483496
},
484497
parameters,
@@ -593,6 +606,7 @@ export async function handleRawWebSocketHandler(
593606
path: string,
594607
actorDriver: ActorDriver,
595608
actorId: string,
609+
requestIdBuf: ArrayBuffer | undefined,
596610
): Promise<UpgradeWebSocketArgs> {
597611
const actor = await actorDriver.loadActor(actorId);
598612

rivetkit-typescript/packages/rivetkit/src/actor/router.ts

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -187,6 +187,7 @@ export function createActorRouter(
187187
encoding,
188188
connParams,
189189
generateConnRequestId(),
190+
undefined,
190191
connIdRaw,
191192
connTokenRaw,
192193
);
@@ -303,6 +304,7 @@ export function createActorRouter(
303304
pathWithQuery,
304305
actorDriver,
305306
c.env.actorId,
307+
undefined,
306308
);
307309
})(c, noopNext());
308310
} else {

rivetkit-typescript/packages/rivetkit/src/drivers/engine/actor-driver.ts

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -562,6 +562,7 @@ export class EngineActorDriver implements ActorDriver {
562562
encoding,
563563
connParams,
564564
requestId,
565+
requestIdBuf,
565566
// Extract connId and connToken from protocols if needed
566567
undefined,
567568
undefined,
@@ -572,6 +573,7 @@ export class EngineActorDriver implements ActorDriver {
572573
url.pathname + url.search,
573574
this,
574575
actorId,
576+
requestIdBuf,
575577
);
576578
} else {
577579
throw new Error(`Unreachable path: ${url.pathname}`);

0 commit comments

Comments
 (0)