Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

146 changes: 83 additions & 63 deletions worker-macros/src/durable_object.rs

Large diffs are not rendered by default.

9 changes: 8 additions & 1 deletion worker-macros/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,14 @@ use proc_macro::TokenStream;

#[proc_macro_attribute]
pub fn durable_object(_attr: TokenStream, item: TokenStream) -> TokenStream {
durable_object::expand_macro(item.into())
durable_object::expand_macro(item.into(), false)
.unwrap_or_else(syn::Error::into_compile_error)
.into()
}

#[proc_macro_attribute]
pub fn shared_durable_object(_attr: TokenStream, item: TokenStream) -> TokenStream {
durable_object::expand_macro(item.into(), true)
.unwrap_or_else(syn::Error::into_compile_error)
.into()
}
Expand Down
1 change: 1 addition & 0 deletions worker-sandbox/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ chrono = { version = "0.4.35", default-features = false, features = [
cfg-if = "1.0.0"
console_error_panic_hook = { version = "0.1.7", optional = true }
getrandom = { version = "0.2.10", features = ["js"] }
gloo-timers = { version = "0.3.0", features = ["futures"] }
hex = "0.4.3"
http.workspace=true
regex = "1.8.4"
Expand Down
18 changes: 14 additions & 4 deletions worker-sandbox/src/alarm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,13 @@ pub async fn handle_alarm(_req: Request, env: Env, _data: SomeSharedData) -> Res
}

#[worker::send]
pub async fn handle_id(_req: Request, env: Env, _data: SomeSharedData) -> Result<Response> {
let namespace = env.durable_object("COUNTER").expect("DAWJKHDAD");
pub async fn handle_id(req: Request, env: Env, _data: SomeSharedData) -> Result<Response> {
let durable_object_name = if req.path().contains("shared") {
"SHARED_COUNTER"
} else {
"COUNTER"
};
let namespace = env.durable_object(durable_object_name).expect("DAWJKHDAD");
let stub = namespace.id_from_name("A")?.get_stub()?;
// when calling fetch to a Durable Object, a full URL must be used. Alternatively, a
// compatibility flag can be provided in wrangler.toml to opt-in to older behavior:
Expand All @@ -72,15 +77,20 @@ pub async fn handle_put_raw(req: Request, env: Env, _data: SomeSharedData) -> Re
}

#[worker::send]
pub async fn handle_websocket(_req: Request, env: Env, _data: SomeSharedData) -> Result<Response> {
pub async fn handle_websocket(req: Request, env: Env, _data: SomeSharedData) -> Result<Response> {
let durable_object_name = if req.path().contains("shared") {
"SHARED_COUNTER"
} else {
"COUNTER"
};
// Accept / handle a websocket connection
let pair = WebSocketPair::new()?;
let server = pair.server;
server.accept()?;

// Connect to Durable Object via WS
let namespace = env
.durable_object("COUNTER")
.durable_object(durable_object_name)
.expect("failed to get namespace");
let stub = namespace.id_from_name("A")?.get_stub()?;
let mut req = Request::new("https://fake-host/ws", Method::Get)?;
Expand Down
1 change: 1 addition & 0 deletions worker-sandbox/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ mod r2;
mod request;
mod router;
mod service;
mod shared_counter;
mod socket;
mod test;
mod user;
Expand Down
10 changes: 10 additions & 0 deletions worker-sandbox/src/router.rs
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,11 @@ pub fn make_router(data: SomeSharedData, env: Env) -> axum::Router {
.route("/durable/:id", get(handler!(alarm::handle_id)))
.route("/durable/put-raw", get(handler!(alarm::handle_put_raw)))
.route("/durable/websocket", get(handler!(alarm::handle_websocket)))
.route("/durable-shared/:id", get(handler!(alarm::handle_id)))
.route(
"/durable-shared/websocket",
get(handler!(alarm::handle_websocket)),
)
.route("/var", get(handler!(request::handle_var)))
.route("/object-var", get(handler!(request::handle_object_var)))
.route("/secret", get(handler!(request::handle_secret)))
Expand Down Expand Up @@ -277,6 +282,11 @@ pub fn make_router<'a>(data: SomeSharedData) -> Router<'a, SomeSharedData> {
.get_async("/durable/:id", handler!(alarm::handle_id))
.get_async("/durable/put-raw", handler!(alarm::handle_put_raw))
.get_async("/durable/websocket", handler!(alarm::handle_websocket))
.get_async("/durable-shared/:id", handler!(alarm::handle_id))
.get_async(
"/durable-shared/websocket",
handler!(alarm::handle_websocket),
)
.get_async("/secret", handler!(request::handle_secret))
.get_async("/var", handler!(request::handle_var))
.get_async("/object-var", handler!(request::handle_object_var))
Expand Down
94 changes: 94 additions & 0 deletions worker-sandbox/src/shared_counter.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
use gloo_timers::future::TimeoutFuture;
use std::cell::RefCell;
use worker::*;

#[shared_durable_object]
pub struct SharedCounter {
count: RefCell<usize>,
state: State,
initialized: RefCell<bool>,
env: Env,
}

#[shared_durable_object]
impl SharedDurableObject for SharedCounter {
fn new(state: State, env: Env) -> Self {
Self {
count: RefCell::new(0),
initialized: RefCell::new(false),
state,
env,
}
}

async fn fetch(&self, req: Request) -> Result<Response> {
if !*self.initialized.borrow() {
*self.initialized.borrow_mut() = true;
*self.count.borrow_mut() = self.state.storage().get("count").await.unwrap_or(0);
}

if req.path().eq("/ws") {
let pair = WebSocketPair::new()?;
let server = pair.server;
// accept websocket with hibernation api
self.state.accept_web_socket(&server);
server
.serialize_attachment("hello")
.expect("failed to serialize attachment");

return Ok(ResponseBuilder::new()
.with_status(101)
.with_websocket(pair.client)
.empty());
}

// simulated delay, to allow testing concurrency
TimeoutFuture::new(1_000).await;

*self.count.borrow_mut() += 15;
let count = *self.count.borrow();
self.state.storage().put("count", count).await?;

Response::ok(format!(
"[durable_object]: self.count: {}, secret value: {}",
self.count.borrow(),
self.env.secret("SOME_SECRET")?
))
}

async fn websocket_message(
&self,
ws: WebSocket,
_message: WebSocketIncomingMessage,
) -> Result<()> {
let _attach: String = ws
.deserialize_attachment()?
.expect("websockets should have an attachment");

// simulated delay, to allow testing concurrency
TimeoutFuture::new(1_000).await;

// get and increment storage by 15
let mut count = self.state.storage().get("count").await.unwrap_or(0);
count += 15;
self.state.storage().put("count", count).await?;
// send value to client
ws.send_with_str(format!("{}", count))
.expect("failed to send value to client");
Ok(())
}

async fn websocket_close(
&self,
_ws: WebSocket,
_code: usize,
_reason: String,
_was_clean: bool,
) -> Result<()> {
Ok(())
}

async fn websocket_error(&self, _ws: WebSocket, _error: Error) -> Result<()> {
Ok(())
}
}
38 changes: 38 additions & 0 deletions worker-sandbox/tests/durable.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -44,3 +44,41 @@ describe("durable", () => {
});
});

describe("durable-shared", () => {
test("websocket-to-durable", async () => {
const resp = await mf.dispatchFetch("http://fake.host/durable-shared/websocket", {
headers: {
upgrade: "websocket",
},
});
expect(resp.webSocket).not.toBeNull();

const socket = resp.webSocket!;
socket.accept();

const handlers = {
messageHandler: (event: MessageEvent) => {
expect(Number(event.data) % 15).toBe(0);
},
close(event: CloseEvent) {},
};

const messageHandlerWrapper = vi.spyOn(handlers, "messageHandler");
const closeHandlerWrapper = vi.spyOn(handlers, "messageHandler");
socket.addEventListener("message", handlers.messageHandler);
socket.addEventListener("close", handlers.close);

for (let i = 0; i < 10; i++) {
socket.send("hi, can you ++?");
}
await new Promise((resolve) => setTimeout(resolve, 1500));
expect(messageHandlerWrapper).toHaveBeenCalledTimes(10);

socket.send("hi again, more ++?");
await new Promise((resolve) => setTimeout(resolve, 1500));
expect(messageHandlerWrapper).toHaveBeenCalledTimes(11);

socket.close();
expect(closeHandlerWrapper).toBeCalled();
});
});
1 change: 1 addition & 0 deletions worker-sandbox/tests/mf.ts
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ export const mf = new Miniflare({
},
durableObjects: {
COUNTER: "Counter",
SHARED_COUNTER: "SharedCounter",
PUT_RAW_TEST_OBJECT: "PutRawTestObject",
},
kvNamespaces: ["SOME_NAMESPACE", "FILE_SIZES", "TEST"],
Expand Down
2 changes: 1 addition & 1 deletion worker-sandbox/tests/request.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ test("fetch json", async () => {

test("proxy request", async () => {
const resp = await mf.dispatchFetch(
"https://fake.host/proxy_request/https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Content-Encoding/contributors.txt"
"https://fake.host/proxy_request/https://developer.mozilla.org/en-US/docs/Web/HTTP/Reference/Headers/Content-Encoding/contributors.txt"
);
expect(resp.status).toBe(200);
});
Expand Down
1 change: 1 addition & 0 deletions worker-sandbox/wrangler.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ remote-service = "./remote-service"
[durable_objects]
bindings = [
{ name = "COUNTER", class_name = "Counter" },
{ name = "SHARED_COUNTER", class_name = "SharedCounter" },
{ name = "ALARM", class_name = "AlarmObject" },
{ name = "PUT_RAW_TEST_OBJECT", class_name = "PutRawTestObject" },
]
Expand Down
37 changes: 37 additions & 0 deletions worker/src/durable.rs
Original file line number Diff line number Diff line change
Expand Up @@ -846,3 +846,40 @@ pub trait DurableObject {
unimplemented!("websocket_error() handler not implemented")
}
}

#[async_trait(?Send)]
pub trait SharedDurableObject {
fn new(state: State, env: Env) -> Self;

async fn fetch(&self, req: Request) -> Result<Response>;

#[allow(clippy::diverging_sub_expression)]
async fn alarm(&self) -> Result<Response> {
unimplemented!("alarm() handler not implemented")
}

#[allow(unused_variables, clippy::diverging_sub_expression)]
async fn websocket_message(
&self,
ws: WebSocket,
message: WebSocketIncomingMessage,
) -> Result<()> {
unimplemented!("websocket_message() handler not implemented")
}

#[allow(unused_variables, clippy::diverging_sub_expression)]
async fn websocket_close(
&self,
ws: WebSocket,
code: usize,
reason: String,
was_clean: bool,
) -> Result<()> {
unimplemented!("websocket_close() handler not implemented")
}

#[allow(unused_variables, clippy::diverging_sub_expression)]
async fn websocket_error(&self, ws: WebSocket, error: Error) -> Result<()> {
unimplemented!("websocket_error() handler not implemented")
}
}
2 changes: 1 addition & 1 deletion worker/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ pub use wasm_bindgen_futures;
pub use worker_kv as kv;

pub use cf::{Cf, CfResponseProperties, TlsClientAuth};
pub use worker_macros::{durable_object, event, send};
pub use worker_macros::{durable_object, event, send, shared_durable_object};
#[doc(hidden)]
pub use worker_sys;
pub use worker_sys::{console_debug, console_error, console_log, console_warn};
Expand Down
17 changes: 5 additions & 12 deletions worker/src/socket.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
use std::{
convert::TryFrom,
io::ErrorKind,
pin::Pin,
task::{Context, Poll},
};
Expand Down Expand Up @@ -150,7 +149,7 @@ fn js_value_to_std_io_error(value: JsValue) -> IoError {
} else {
format!("Error interpreting JsError: {:?}", value)
};
IoError::new(ErrorKind::Other, s)
IoError::other(s)
}
impl AsyncRead for Socket {
fn poll_read(
Expand All @@ -173,10 +172,7 @@ impl AsyncRead for Socket {
Ok(value) => value.into(),
Err(error) => {
let msg = format!("Unable to interpret field 'done' in ReadableStreamDefaultReader.read(): {:?}", error);
return (
Reading::None,
Poll::Ready(Err(IoError::new(ErrorKind::Other, msg))),
);
return (Reading::None, Poll::Ready(Err(IoError::other(msg))));
}
};
if done.is_truthy() {
Expand All @@ -189,10 +185,7 @@ impl AsyncRead for Socket {
Ok(value) => value.into(),
Err(error) => {
let msg = format!("Unable to interpret field 'value' in ReadableStreamDefaultReader.read(): {:?}", error);
return (
Reading::None,
Poll::Ready(Err(IoError::new(ErrorKind::Other, msg))),
);
return (Reading::None, Poll::Ready(Err(IoError::other(msg))));
}
};
let data = arr.to_vec();
Expand All @@ -214,7 +207,7 @@ impl AsyncRead for Socket {
"Unable to cast JsObject to ReadableStreamDefaultReader: {:?}",
error
);
return Poll::Ready(Err(IoError::new(ErrorKind::Other, msg)));
return Poll::Ready(Err(IoError::other(msg)));
}
};

Expand All @@ -241,7 +234,7 @@ impl AsyncWrite for Socket {
Ok(writer) => writer,
Err(error) => {
let msg = format!("Could not retrieve Writer: {:?}", error);
return Poll::Ready(Err(IoError::new(ErrorKind::Other, msg)));
return Poll::Ready(Err(IoError::other(msg)));
}
};
Self::handle_write_future(
Expand Down