Skip to content

Commit 6ba7b78

Browse files
Update rocket to 0.5.0-rc1
1 parent b5c722f commit 6ba7b78

File tree

19 files changed

+951
-834
lines changed

19 files changed

+951
-834
lines changed

Cargo.lock

Lines changed: 692 additions & 565 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

aw-client-rust/Cargo.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,3 +15,5 @@ aw-models = { path = "../aw-models" }
1515
[dev-dependencies]
1616
aw-datastore = { path = "../aw-datastore" }
1717
aw-server = { path = "../aw-server", default-features = false, features=[] }
18+
rocket = "0.5.0-rc.1"
19+
tokio-test = "*"

aw-client-rust/tests/test.rs

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,9 @@ extern crate aw_client_rust;
22
extern crate aw_datastore;
33
extern crate aw_server;
44
extern crate chrono;
5+
extern crate rocket;
56
extern crate serde_json;
7+
extern crate tokio_test;
68

79
#[cfg(test)]
810
mod test {
@@ -13,6 +15,7 @@ mod test {
1315
use std::path::PathBuf;
1416
use std::sync::Mutex;
1517
use std::thread;
18+
use tokio_test::block_on;
1619

1720
// A random port, but still not guaranteed to not be bound
1821
// FIXME: Bind to a port that is free for certain and use that for the client instead
@@ -37,9 +40,7 @@ mod test {
3740
}
3841
}
3942

40-
fn setup_testserver() -> () {
41-
// Start testserver and wait 10s for it to start up
42-
// TODO: Properly shutdown
43+
fn setup_testserver() -> rocket::Shutdown {
4344
use aw_server::endpoints::ServerState;
4445
let state = ServerState {
4546
datastore: Mutex::new(aw_datastore::Datastore::new_in_memory(false)),
@@ -49,10 +50,14 @@ mod test {
4950
let mut aw_config = aw_server::config::AWConfig::default();
5051
aw_config.port = PORT;
5152
let server = aw_server::endpoints::build_rocket(state, aw_config);
53+
let server = block_on(server.ignite()).unwrap();
54+
let shutdown_handler = server.shutdown();
5255

5356
thread::spawn(move || {
54-
server.launch();
57+
block_on(server.launch()).unwrap();
5558
});
59+
60+
shutdown_handler
5661
}
5762

5863
#[test]
@@ -62,7 +67,7 @@ mod test {
6267
let clientname = "aw-client-rust-test";
6368
let client: AwClient = AwClient::new(ip, &port, clientname);
6469

65-
setup_testserver();
70+
let shutdown_handler = setup_testserver();
6671

6772
wait_for_server(20, &client);
6873

@@ -113,5 +118,7 @@ mod test {
113118
assert_eq!(count, 0);
114119

115120
client.delete_bucket(&bucketname).unwrap();
121+
122+
shutdown_handler.notify();
116123
}
117124
}

aw-server/Cargo.toml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,9 @@ name = "aw-server"
1414
path = "src/main.rs"
1515

1616
[dependencies]
17-
rocket = "0.4"
18-
rocket_contrib = { version = "*", default-features = false, features = ["json"] }
19-
rocket_cors = "0.5"
17+
rocket = { version = "0.5.0-rc.1", features = ["json"] }
18+
# TODO: Once rocket_cors has a version for rocket 0.5, use that instead
19+
rocket_cors = { git = "https://github.com/lawliet89/rocket_cors", rev = "a062933" }
2020
multipart = { version = "0.18", default-features = false, features = ["server"] }
2121
serde = { version = "1.0", features = ["derive"] }
2222
serde_json = "1.0"

aw-server/src/android/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ use crate::dirs;
1212

1313
use android_logger::Config;
1414
use log::Level;
15+
use rocket::serde::json::json;
1516

1617
#[no_mangle]
1718
pub extern "C" fn rust_greeting(to: *const c_char) -> *mut c_char {

aw-server/src/config.rs

Lines changed: 16 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
use std::fs::File;
22
use std::io::{Read, Write};
33

4-
use rocket::config::{Config, Environment, Limits};
4+
use rocket::config::Config;
5+
use rocket::data::{Limits, ToByteUnit};
56
use serde::{Deserialize, Serialize};
67

78
use crate::dirs;
@@ -42,21 +43,23 @@ impl Default for AWConfig {
4243

4344
impl AWConfig {
4445
pub fn to_rocket_config(&self) -> rocket::Config {
45-
let env = if self.testing {
46-
Environment::Production
46+
let mut config = if self.testing {
47+
Config::release_default()
4748
} else {
48-
Environment::Development
49+
Config::debug_default()
4950
};
51+
5052
// Needed for bucket imports
51-
let limits = Limits::new().limit("json", 1_000_000_000);
52-
53-
Config::build(env)
54-
.address(self.address.clone())
55-
.port(self.port)
56-
.keep_alive(0)
57-
.limits(limits)
58-
.finalize()
59-
.unwrap()
53+
let limits = Limits::default()
54+
.limit("json", 1000u64.megabytes())
55+
.limit("data-form", 1000u64.megabytes());
56+
57+
config.address = self.address.parse().unwrap();
58+
config.port = self.port;
59+
config.keep_alive = 0;
60+
config.limits = limits;
61+
62+
config
6063
}
6164
}
6265

aw-server/src/endpoints/bucket.rs

Lines changed: 15 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
use std::collections::HashMap;
2-
use std::io::Cursor;
32

4-
use rocket_contrib::json::Json;
3+
use rocket::serde::json::Json;
54

65
use chrono::DateTime;
76
use chrono::Utc;
@@ -11,16 +10,15 @@ use aw_models::BucketsExport;
1110
use aw_models::Event;
1211
use aw_models::TryVec;
1312

14-
use rocket::http::Header;
1513
use rocket::http::Status;
16-
use rocket::response::Response;
1714
use rocket::State;
1815

16+
use crate::endpoints::util::BucketsExportRocket;
1917
use crate::endpoints::{HttpErrorJson, ServerState};
2018

2119
#[get("/")]
2220
pub fn buckets_get(
23-
state: State<ServerState>,
21+
state: &State<ServerState>,
2422
) -> Result<Json<HashMap<String, Bucket>>, HttpErrorJson> {
2523
let datastore = endpoints_get_lock!(state.datastore);
2624
match datastore.get_buckets() {
@@ -32,7 +30,7 @@ pub fn buckets_get(
3230
#[get("/<bucket_id>")]
3331
pub fn bucket_get(
3432
bucket_id: String,
35-
state: State<ServerState>,
33+
state: &State<ServerState>,
3634
) -> Result<Json<Bucket>, HttpErrorJson> {
3735
let datastore = endpoints_get_lock!(state.datastore);
3836
match datastore.get_bucket(&bucket_id) {
@@ -45,7 +43,7 @@ pub fn bucket_get(
4543
pub fn bucket_new(
4644
bucket_id: String,
4745
message: Json<Bucket>,
48-
state: State<ServerState>,
46+
state: &State<ServerState>,
4947
) -> Result<(), HttpErrorJson> {
5048
let mut bucket = message.into_inner();
5149
if bucket.id != bucket_id {
@@ -65,7 +63,7 @@ pub fn bucket_events_get(
6563
start: Option<String>,
6664
end: Option<String>,
6765
limit: Option<u64>,
68-
state: State<ServerState>,
66+
state: &State<ServerState>,
6967
) -> Result<Json<Vec<Event>>, HttpErrorJson> {
7068
let starttime: Option<DateTime<Utc>> = match start {
7169
Some(dt_str) => match DateTime::parse_from_rfc3339(&dt_str) {
@@ -107,7 +105,7 @@ pub fn bucket_events_get(
107105
pub fn bucket_events_create(
108106
bucket_id: String,
109107
events: Json<Vec<Event>>,
110-
state: State<ServerState>,
108+
state: &State<ServerState>,
111109
) -> Result<Json<Vec<Event>>, HttpErrorJson> {
112110
let datastore = endpoints_get_lock!(state.datastore);
113111
let res = datastore.insert_events(&bucket_id, &events);
@@ -126,7 +124,7 @@ pub fn bucket_events_heartbeat(
126124
bucket_id: String,
127125
heartbeat_json: Json<Event>,
128126
pulsetime: f64,
129-
state: State<ServerState>,
127+
state: &State<ServerState>,
130128
) -> Result<Json<Event>, HttpErrorJson> {
131129
let heartbeat = heartbeat_json.into_inner();
132130
let datastore = endpoints_get_lock!(state.datastore);
@@ -139,7 +137,7 @@ pub fn bucket_events_heartbeat(
139137
#[get("/<bucket_id>/events/count")]
140138
pub fn bucket_event_count(
141139
bucket_id: String,
142-
state: State<ServerState>,
140+
state: &State<ServerState>,
143141
) -> Result<Json<u64>, HttpErrorJson> {
144142
let datastore = endpoints_get_lock!(state.datastore);
145143
let res = datastore.get_event_count(&bucket_id, None, None);
@@ -153,7 +151,7 @@ pub fn bucket_event_count(
153151
pub fn bucket_events_delete_by_id(
154152
bucket_id: String,
155153
event_id: i64,
156-
state: State<ServerState>,
154+
state: &State<ServerState>,
157155
) -> Result<(), HttpErrorJson> {
158156
let datastore = endpoints_get_lock!(state.datastore);
159157
match datastore.delete_events_by_id(&bucket_id, vec![event_id]) {
@@ -165,8 +163,8 @@ pub fn bucket_events_delete_by_id(
165163
#[get("/<bucket_id>/export")]
166164
pub fn bucket_export(
167165
bucket_id: String,
168-
state: State<ServerState>,
169-
) -> Result<Response, HttpErrorJson> {
166+
state: &State<ServerState>,
167+
) -> Result<BucketsExportRocket, HttpErrorJson> {
170168
let datastore = endpoints_get_lock!(state.datastore);
171169
let mut export = BucketsExport {
172170
buckets: HashMap::new(),
@@ -181,20 +179,12 @@ pub fn bucket_export(
181179
.expect("Failed to get events for bucket");
182180
bucket.events = Some(TryVec::new(events));
183181
export.buckets.insert(bucket_id.clone(), bucket);
184-
let filename = format!("aw-bucket-export_{}.json", bucket_id);
185-
186-
let header_content = format!("attachment; filename={}", filename);
187-
Ok(Response::build()
188-
.status(Status::Ok)
189-
.header(Header::new("Content-Disposition", header_content))
190-
.sized_body(Cursor::new(
191-
serde_json::to_string(&export).expect("Failed to serialize"),
192-
))
193-
.finalize())
182+
183+
Ok(export.into())
194184
}
195185

196186
#[delete("/<bucket_id>")]
197-
pub fn bucket_delete(bucket_id: String, state: State<ServerState>) -> Result<(), HttpErrorJson> {
187+
pub fn bucket_delete(bucket_id: String, state: &State<ServerState>) -> Result<(), HttpErrorJson> {
198188
let datastore = endpoints_get_lock!(state.datastore);
199189
match datastore.delete_bucket(&bucket_id) {
200190
Ok(_) => Ok(()),

aw-server/src/endpoints/export.rs

Lines changed: 3 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,15 @@
11
use std::collections::HashMap;
2-
use std::io::Cursor;
32

4-
use rocket::http::Header;
5-
use rocket::http::Status;
6-
use rocket::response::Response;
73
use rocket::State;
84

95
use aw_models::BucketsExport;
106
use aw_models::TryVec;
117

8+
use crate::endpoints::util::BucketsExportRocket;
129
use crate::endpoints::{HttpErrorJson, ServerState};
1310

1411
#[get("/")]
15-
pub fn buckets_export(state: State<ServerState>) -> Result<Response, HttpErrorJson> {
12+
pub fn buckets_export(state: &State<ServerState>) -> Result<BucketsExportRocket, HttpErrorJson> {
1613
let datastore = endpoints_get_lock!(state.datastore);
1714
let mut export = BucketsExport {
1815
buckets: HashMap::new(),
@@ -30,14 +27,5 @@ pub fn buckets_export(state: State<ServerState>) -> Result<Response, HttpErrorJs
3027
export.buckets.insert(bid, bucket);
3128
}
3229

33-
Ok(Response::build()
34-
.status(Status::Ok)
35-
.header(Header::new(
36-
"Content-Disposition",
37-
"attachment; filename=aw-buckets-export.json",
38-
))
39-
.sized_body(Cursor::new(
40-
serde_json::to_string(&export).expect("Failed to serialize"),
41-
))
42-
.finalize())
30+
Ok(export.into())
4331
}

aw-server/src/endpoints/hostcheck.rs

Lines changed: 23 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,9 @@
77
//!
88
//! [1]: https://github.com/ActivityWatch/activitywatch/security/advisories/GHSA-v9fg-6g9j-h4x4
99
use rocket::fairing::Fairing;
10-
use rocket::handler::Outcome;
1110
use rocket::http::uri::Origin;
1211
use rocket::http::{Method, Status};
12+
use rocket::route::Outcome;
1313
use rocket::{Data, Request, Rocket, Route};
1414

1515
use crate::config::AWConfig;
@@ -29,15 +29,25 @@ impl HostCheck {
2929
}
3030
}
3131

32-
/// Route for HostCheck Fairing error
33-
fn fairing_error_route<'r>(req: &'r Request<'_>, _: Data) -> Outcome<'r> {
34-
let err = HttpErrorJson::new(Status::BadRequest, "Host header is invalid".to_string());
35-
Outcome::from(req, err)
32+
/// Create a `Handler` for Fairing error handling
33+
#[derive(Clone)]
34+
struct FairingErrorRoute {}
35+
36+
#[rocket::async_trait]
37+
impl rocket::route::Handler for FairingErrorRoute {
38+
async fn handle<'r>(
39+
&self,
40+
request: &'r Request<'_>,
41+
_: rocket::Data<'r>,
42+
) -> rocket::route::Outcome<'r> {
43+
let err = HttpErrorJson::new(Status::BadRequest, "Host header is invalid".to_string());
44+
Outcome::from(request, err)
45+
}
3646
}
3747

3848
/// Create a new `Route` for Fairing handling
3949
fn fairing_route() -> Route {
40-
Route::ranked(1, Method::Get, "/", fairing_error_route)
50+
Route::ranked(1, Method::Get, "/", FairingErrorRoute {})
4151
}
4252

4353
fn redirect_bad_request(request: &mut Request) {
@@ -47,15 +57,16 @@ fn redirect_bad_request(request: &mut Request) {
4757
request.set_uri(origin);
4858
}
4959

60+
#[rocket::async_trait]
5061
impl Fairing for HostCheck {
5162
fn info(&self) -> rocket::fairing::Info {
5263
rocket::fairing::Info {
5364
name: "HostCheck",
54-
kind: rocket::fairing::Kind::Attach | rocket::fairing::Kind::Request,
65+
kind: rocket::fairing::Kind::Ignite | rocket::fairing::Kind::Request,
5566
}
5667
}
5768

58-
fn on_attach(&self, rocket: Rocket) -> Result<Rocket, Rocket> {
69+
async fn on_ignite(&self, rocket: Rocket<rocket::Build>) -> rocket::fairing::Result {
5970
match self.validate {
6071
true => Ok(rocket.mount(FAIRING_ROUTE_BASE, vec![fairing_route()])),
6172
false => {
@@ -65,7 +76,7 @@ impl Fairing for HostCheck {
6576
}
6677
}
6778

68-
fn on_request(&self, request: &mut Request, _: &Data) {
79+
async fn on_request(&self, request: &mut Request<'_>, _: &mut Data<'_>) {
6980
if !self.validate {
7081
// host header check is disabled
7182
return;
@@ -112,7 +123,7 @@ mod tests {
112123
use crate::config::AWConfig;
113124
use crate::endpoints;
114125

115-
fn setup_testserver(address: String) -> Rocket {
126+
fn setup_testserver(address: String) -> Rocket<rocket::Build> {
116127
let state = endpoints::ServerState {
117128
datastore: Mutex::new(aw_datastore::Datastore::new_in_memory(false)),
118129
asset_path: PathBuf::from("aw-webui/dist"),
@@ -126,7 +137,7 @@ mod tests {
126137
#[test]
127138
fn test_public_address() {
128139
let server = setup_testserver("0.0.0.0".to_string());
129-
let client = rocket::local::Client::new(server).expect("valid instance");
140+
let client = rocket::local::blocking::Client::tracked(server).expect("valid instance");
130141

131142
// When a public address is used, request should always pass, regardless
132143
// if the Host header is missing
@@ -140,7 +151,7 @@ mod tests {
140151
#[test]
141152
fn test_localhost_address() {
142153
let server = setup_testserver("127.0.0.1".to_string());
143-
let client = rocket::local::Client::new(server).expect("valid instance");
154+
let client = rocket::local::blocking::Client::tracked(server).expect("valid instance");
144155

145156
// If Host header is missing we should get a BadRequest
146157
let res = client

0 commit comments

Comments
 (0)