Skip to content

Commit

Permalink
refactor: unwraps and mor (#587)
Browse files Browse the repository at this point in the history
* refactor: better error handling

* refactor: trim mutexes

* refactor: remove abstract factory

* refactor: remove extension todo
  • Loading branch information
chesedo authored Jan 18, 2023
1 parent 35c0660 commit a8b6166
Show file tree
Hide file tree
Showing 15 changed files with 103 additions and 106 deletions.
2 changes: 2 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ async-trait = "0.1.58"
axum = { version = "0.6.0", default-features = false }
chrono = { version = "0.4.23", default-features = false, features = ["clock"] }
once_cell = "1.16.0"
prost-types = "0.11.0"
uuid = "1.2.2"
thiserror = "1.0.37"
serde = { version = "1.0.148", default-features = false }
Expand Down
3 changes: 2 additions & 1 deletion cargo-shuttle/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ mod provisioner_server;
use shuttle_common::project::ProjectName;
use shuttle_proto::runtime::{self, LoadRequest, StartRequest, SubscribeLogsRequest};
use std::collections::HashMap;
use std::convert::TryInto;
use std::ffi::OsString;
use std::fs::{read_to_string, File};
use std::io::stdout;
Expand Down Expand Up @@ -454,7 +455,7 @@ impl Shuttle {

tokio::spawn(async move {
while let Some(log) = stream.message().await.expect("to get log from stream") {
let log: shuttle_common::LogItem = log.into();
let log: shuttle_common::LogItem = log.try_into().expect("to convert log");
println!("{log}");
}
});
Expand Down
4 changes: 3 additions & 1 deletion common/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,14 @@ crossterm = { version = "0.25.0", optional = true }
http = { version = "0.2.8", optional = true }
http-serde = { version = "1.1.2", optional = true }
once_cell = { workspace = true, optional = true }
prost-types = { workspace = true, optional = true }
reqwest = { version = "0.11.13", optional = true }
rmp-serde = { version = "1.1.1", optional = true }
rustrict = { version = "0.5.5", optional = true }
serde = { workspace = true }
serde_json = { workspace = true, optional = true }
strum = { version = "0.24.1", features = ["derive"], optional = true }
thiserror = { workspace = true, optional = true }
tracing = { workspace = true }
tracing-subscriber = { workspace = true, optional = true }
uuid = { workspace = true, features = ["v4", "serde"], optional = true }
Expand All @@ -35,5 +37,5 @@ backend = ["async-trait", "axum"]
display = ["comfy-table", "crossterm"]
tracing = ["serde_json"]
wasm = ["http-serde", "http", "rmp-serde", "tracing", "tracing-subscriber"]
models = ["anyhow", "async-trait", "display", "http", "reqwest", "serde_json", "service"]
models = ["anyhow", "async-trait", "display", "http", "prost-types", "reqwest", "serde_json", "service", "thiserror"]
service = ["chrono/serde", "once_cell", "rustrict", "serde/derive", "strum", "uuid"]
13 changes: 13 additions & 0 deletions common/src/models/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,10 @@ use anyhow::{Context, Result};
use async_trait::async_trait;
use http::StatusCode;
use serde::de::DeserializeOwned;
use thiserror::Error;
use tracing::trace;

/// A to_json wrapper for handling our error states
#[async_trait]
pub trait ToJson {
async fn to_json<T: DeserializeOwned>(self) -> Result<T>;
Expand Down Expand Up @@ -48,3 +50,14 @@ impl ToJson for reqwest::Response {
}
}
}

/// Errors that can occur when changing types. Especially from prost
#[derive(Error, Debug)]
pub enum ParseError {
#[error("failed to parse UUID: {0}")]
Uuid(#[from] uuid::Error),
#[error("failed to parse timestamp: {0}")]
Timestamp(#[from] prost_types::TimestampError),
#[error("failed to parse serde: {0}")]
Serde(#[from] serde_json::Error),
}
21 changes: 10 additions & 11 deletions common/src/wasm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ use crate::tracing::JsonVisitor;

extern crate rmp_serde as rmps;

// todo: add http extensions field
#[derive(Serialize, Deserialize, Debug)]
pub struct RequestWrapper {
#[serde(with = "http_serde::method")]
Expand Down Expand Up @@ -44,11 +43,11 @@ impl From<http::request::Parts> for RequestWrapper {

impl RequestWrapper {
/// Serialize a RequestWrapper to the Rust MessagePack data format
pub fn into_rmp(self) -> Vec<u8> {
pub fn into_rmp(self) -> Result<Vec<u8>, rmps::encode::Error> {
let mut buf = Vec::new();
self.serialize(&mut Serializer::new(&mut buf)).unwrap();
self.serialize(&mut Serializer::new(&mut buf))?;

buf
Ok(buf)
}

/// Consume the wrapper and return a request builder with `Parts` set
Expand All @@ -60,7 +59,7 @@ impl RequestWrapper {

request
.headers_mut()
.unwrap()
.unwrap() // Safe to unwrap as we just made the builder
.extend(self.headers.into_iter());

request
Expand Down Expand Up @@ -92,11 +91,11 @@ impl From<http::response::Parts> for ResponseWrapper {

impl ResponseWrapper {
/// Serialize a ResponseWrapper into the Rust MessagePack data format
pub fn into_rmp(self) -> Vec<u8> {
pub fn into_rmp(self) -> Result<Vec<u8>, rmps::encode::Error> {
let mut buf = Vec::new();
self.serialize(&mut Serializer::new(&mut buf)).unwrap();
self.serialize(&mut Serializer::new(&mut buf))?;

buf
Ok(buf)
}

/// Consume the wrapper and return a response builder with `Parts` set
Expand All @@ -107,7 +106,7 @@ impl ResponseWrapper {

response
.headers_mut()
.unwrap()
.unwrap() // Safe to unwrap since we just made the builder
.extend(self.headers.into_iter());

response
Expand Down Expand Up @@ -389,7 +388,7 @@ mod test {
.unwrap();

let (parts, _) = request.into_parts();
let rmp = RequestWrapper::from(parts).into_rmp();
let rmp = RequestWrapper::from(parts).into_rmp().unwrap();

let back: RequestWrapper = rmps::from_slice(&rmp).unwrap();

Expand All @@ -415,7 +414,7 @@ mod test {
.unwrap();

let (parts, _) = response.into_parts();
let rmp = ResponseWrapper::from(parts).into_rmp();
let rmp = ResponseWrapper::from(parts).into_rmp().unwrap();

let back: ResponseWrapper = rmps::from_slice(&rmp).unwrap();

Expand Down
28 changes: 17 additions & 11 deletions deployer/src/deployment/deploy_layer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,9 @@

use chrono::{DateTime, Utc};
use serde_json::json;
use shuttle_common::{tracing::JsonVisitor, STATE_MESSAGE};
use shuttle_common::{models::ParseError, tracing::JsonVisitor, STATE_MESSAGE};
use shuttle_proto::runtime;
use std::{str::FromStr, time::SystemTime};
use std::{convert::TryFrom, str::FromStr, time::SystemTime};
use tracing::{field::Visit, span, warn, Metadata, Subscriber};
use tracing_subscriber::Layer;
use uuid::Uuid;
Expand Down Expand Up @@ -112,19 +112,25 @@ impl From<Log> for DeploymentState {
}
}

impl From<runtime::LogItem> for Log {
fn from(log: runtime::LogItem) -> Self {
Self {
id: Uuid::from_slice(&log.id).unwrap(),
state: runtime::LogState::from_i32(log.state).unwrap().into(),
level: runtime::LogLevel::from_i32(log.level).unwrap().into(),
timestamp: DateTime::from(SystemTime::try_from(log.timestamp.unwrap()).unwrap()),
impl TryFrom<runtime::LogItem> for Log {
type Error = ParseError;

fn try_from(log: runtime::LogItem) -> Result<Self, Self::Error> {
Ok(Self {
id: Uuid::from_slice(&log.id)?,
state: runtime::LogState::from_i32(log.state)
.unwrap_or_default()
.into(),
level: runtime::LogLevel::from_i32(log.level)
.unwrap_or_default()
.into(),
timestamp: DateTime::from(SystemTime::try_from(log.timestamp.unwrap_or_default())?),
file: log.file,
line: log.line,
target: log.target,
fields: serde_json::from_slice(&log.fields).unwrap(),
fields: serde_json::from_slice(&log.fields)?,
r#type: LogType::Event,
}
})
}
}

Expand Down
20 changes: 15 additions & 5 deletions deployer/src/deployment/run.rs
Original file line number Diff line number Diff line change
Expand Up @@ -241,19 +241,23 @@ async fn load(
) -> Result<()> {
info!(
"loading project from: {}",
so_path.clone().into_os_string().into_string().unwrap()
so_path
.clone()
.into_os_string()
.into_string()
.unwrap_or_default()
);

let secrets = secret_getter
.get_secrets(&service_id)
.await
.unwrap()
.map_err(|e| Error::SecretsGet(Box::new(e)))?
.into_iter()
.map(|secret| (secret.key, secret.value));
let secrets = HashMap::from_iter(secrets);

let load_request = tonic::Request::new(LoadRequest {
path: so_path.into_os_string().into_string().unwrap(),
path: so_path.into_os_string().into_string().unwrap_or_default(),
service_name: service_name.clone(),
secrets,
});
Expand Down Expand Up @@ -283,7 +287,10 @@ async fn run(
mut kill_recv: KillReceiver,
cleanup: impl FnOnce(std::result::Result<Response<StopResponse>, Status>) + Send + 'static,
) {
deployment_updater.set_address(&id, &address).await.unwrap();
deployment_updater
.set_address(&id, &address)
.await
.expect("to set deployment address");

let start_request = tonic::Request::new(StartRequest {
deployment_id: id.as_bytes().to_vec(),
Expand All @@ -292,7 +299,10 @@ async fn run(
});

info!("starting service");
let response = runtime_client.start(start_request).await.unwrap();
let response = runtime_client
.start(start_request)
.await
.expect("to start deployment");

info!(response = ?response.into_inner(), "start client response: ");

Expand Down
2 changes: 2 additions & 0 deletions deployer/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ pub enum Error {
SecretsParse(#[from] toml::de::Error),
#[error("Failed to set secrets: {0}")]
SecretsSet(#[source] Box<dyn StdError + Send>),
#[error("Failed to get secrets: {0}")]
SecretsGet(#[source] Box<dyn StdError + Send>),
#[error("Failed to cleanup old deployments: {0}")]
OldCleanup(#[source] Box<dyn StdError + Send>),
#[error("Gateway client error: {0}")]
Expand Down
6 changes: 4 additions & 2 deletions deployer/src/runtime_manager.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::{path::PathBuf, sync::Arc};
use std::{convert::TryInto, path::PathBuf, sync::Arc};

use anyhow::Context;
use shuttle_proto::runtime::{self, runtime_client::RuntimeClient, SubscribeLogsRequest};
Expand Down Expand Up @@ -99,7 +99,9 @@ impl RuntimeManager {

tokio::spawn(async move {
while let Ok(Some(log)) = stream.message().await {
sender.send(log.into()).expect("to send log to persistence");
if let Ok(log) = log.try_into() {
sender.send(log).expect("to send log to persistence");
}
}
});

Expand Down
2 changes: 1 addition & 1 deletion proto/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ anyhow = { workspace = true }
chrono = { workspace = true }
home = "0.5.4"
prost = "0.11.2"
prost-types = "0.11.0"
prost-types = { workspace = true }
tokio = { version = "1.22.0", features = ["process"] }
tonic = { workspace = true }
tracing = { workspace = true }
Expand Down
20 changes: 12 additions & 8 deletions proto/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ pub mod provisioner {

pub mod runtime {
use std::{
convert::TryFrom,
path::PathBuf,
process::Command,
time::{Duration, SystemTime},
Expand All @@ -104,6 +105,7 @@ pub mod runtime {
use anyhow::Context;
use chrono::DateTime;
use prost_types::Timestamp;
use shuttle_common::models::ParseError;
use tokio::process;
use tonic::transport::{Channel, Endpoint};
use tracing::info;
Expand Down Expand Up @@ -159,18 +161,20 @@ pub mod runtime {
}
}

impl From<LogItem> for shuttle_common::LogItem {
fn from(log: LogItem) -> Self {
Self {
id: Uuid::from_slice(&log.id).unwrap(),
timestamp: DateTime::from(SystemTime::try_from(log.timestamp.unwrap()).unwrap()),
state: LogState::from_i32(log.state).unwrap().into(),
level: LogLevel::from_i32(log.level).unwrap().into(),
impl TryFrom<LogItem> for shuttle_common::LogItem {
type Error = ParseError;

fn try_from(log: LogItem) -> Result<Self, Self::Error> {
Ok(Self {
id: Uuid::from_slice(&log.id)?,
timestamp: DateTime::from(SystemTime::try_from(log.timestamp.unwrap_or_default())?),
state: LogState::from_i32(log.state).unwrap_or_default().into(),
level: LogLevel::from_i32(log.level).unwrap_or_default().into(),
file: log.file,
line: log.line,
target: log.target,
fields: log.fields,
}
})
}
}

Expand Down
Loading

0 comments on commit a8b6166

Please sign in to comment.