Skip to content

Commit

Permalink
feat: Allow mutable reference to Request queries #391 (#393)
Browse files Browse the repository at this point in the history
* chore: cargo clippy

* chore: cargo clippy

* wip

* wip

* Format Rust code using rustfmt

* wip

* Format Rust code using rustfmt

---------

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
  • Loading branch information
chrislearn and github-actions[bot] authored Aug 30, 2023
1 parent 1838650 commit d978647
Show file tree
Hide file tree
Showing 8 changed files with 72 additions and 52 deletions.
8 changes: 4 additions & 4 deletions crates/core/src/conn/quinn/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -110,14 +110,14 @@ async fn process_web_transport(
conn: salvo_http3::server::Connection<salvo_http3::http3_quinn::Connection, Bytes>,
request: hyper::Request<()>,
stream: RequestStream<salvo_http3::http3_quinn::BidiStream<Bytes>, Bytes>,
mut hyper_handler: crate::service::HyperHandler,
hyper_handler: crate::service::HyperHandler,
) -> IoResult<Option<salvo_http3::server::Connection<salvo_http3::http3_quinn::Connection, Bytes>>> {
let (parts, _body) = request.into_parts();
let mut request = hyper::Request::from_parts(parts, ReqBody::None);
request.extensions_mut().insert(Mutex::new(conn));
request.extensions_mut().insert(stream);

let mut response = hyper::service::Service::call(&mut hyper_handler, request)
let mut response = hyper::service::Service::call(&hyper_handler, request)
.await
.map_err(|e| IoError::new(ErrorKind::Other, format!("failed to call hyper service : {}", e)))?;

Expand Down Expand Up @@ -192,7 +192,7 @@ async fn process_web_transport(
async fn process_request<S>(
request: hyper::Request<()>,
stream: RequestStream<S, Bytes>,
mut hyper_handler: crate::service::HyperHandler,
hyper_handler: crate::service::HyperHandler,
) -> IoResult<()>
where
S: salvo_http3::quic::BidiStream<Bytes> + Send + Unpin + 'static,
Expand All @@ -202,7 +202,7 @@ where
let (parts, _body) = request.into_parts();
let request = hyper::Request::from_parts(parts, ReqBody::from(H3ReqBody::new(rx)));

let response = hyper::service::Service::call(&mut hyper_handler, request)
let response = hyper::service::Service::call(&hyper_handler, request)
.await
.map_err(|e| IoError::new(ErrorKind::Other, format!("failed to call hyper service : {}", e)))?;

Expand Down
27 changes: 27 additions & 0 deletions crates/core/src/http/request.rs
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,8 @@ impl Request {

/// Returns a mutable reference to the associated URI.
///
/// *Notice: If you using this mutable reference to change the uri, you should change the `params` and `queries` manually.*
///
/// # Examples
///
/// ```
Expand All @@ -200,6 +202,15 @@ impl Request {
&mut self.uri
}

/// Set the associated URI. `querie` will be reset.
///
/// *Notice: `params` will not reset.*
#[inline]
pub fn set_uri(&mut self, uri: Uri) {
self.uri = uri;
self.queries = OnceCell::new();
}

/// Returns a reference to the associated HTTP method.
///
/// # Examples
Expand Down Expand Up @@ -258,9 +269,20 @@ impl Request {
}
/// Get request remote address.
#[inline]
pub fn remote_addr_mut(&mut self) -> &mut SocketAddr {
&mut self.remote_addr
}

/// Get request remote address reference.
#[inline]
pub fn local_addr(&self) -> &SocketAddr {
&self.local_addr
}
/// Get mutable request remote address reference.
#[inline]
pub fn local_addr_mut(&mut self) -> &mut SocketAddr {
&mut self.local_addr
}

/// Returns a reference to the associated header field map.
///
Expand Down Expand Up @@ -501,6 +523,11 @@ impl Request {
.collect()
})
}
/// Get mutable queries reference.
pub fn queries_mut(&mut self) -> &MultiMap<String, String> {
let _ = self.queries();
self.queries.get_mut().unwrap()
}

/// Get query value from queries.
#[inline]
Expand Down
5 changes: 1 addition & 4 deletions crates/core/src/serde/request.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,10 +53,7 @@ pub(crate) struct RequestDeserializer<'de> {
impl<'de> RequestDeserializer<'de> {
/// Construct a new `RequestDeserializer<I, E>`.
#[inline]
pub(crate) fn new(
request: &'de mut Request,
metadata: &'de Metadata,
) -> Result<RequestDeserializer<'de>, ParseError> {
pub(crate) fn new(request: &'de Request, metadata: &'de Metadata) -> Result<RequestDeserializer<'de>, ParseError> {
let mut payload = None;
if let Some(ctype) = request.content_type() {
match ctype.subtype() {
Expand Down
2 changes: 1 addition & 1 deletion crates/oapi-macros/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -387,7 +387,7 @@ impl ToTokens for ExternalDocs {

/// Represents OpenAPI Any value used in example and default fields.
#[derive(Clone, Debug)]
pub(self) enum AnyValue {
enum AnyValue {
String(TokenStream2),
Json(TokenStream2),
DefaultTrait { struct_ident: Ident, field_ident: Member },
Expand Down
76 changes: 36 additions & 40 deletions crates/oapi-macros/src/serde.rs
Original file line number Diff line number Diff line change
Expand Up @@ -177,29 +177,27 @@ pub(crate) fn parse_value(attributes: &[Attribute]) -> Option<SerdeValue> {
.iter()
.filter(|attribute| attribute.path().is_ident("serde"))
.map(|serde_attribute| serde_attribute.parse_args_with(SerdeValue::parse).unwrap_or_abort())
.fold(Some(SerdeValue::default()), |acc, value| {
acc.map(|mut acc| {
if value.skip {
acc.skip = value.skip;
}
if value.skip_serializing_if {
acc.skip_serializing_if = value.skip_serializing_if;
}
if value.rename.is_some() {
acc.rename = value.rename;
}
if value.flatten {
acc.flatten = value.flatten;
}
if value.is_default {
acc.is_default = value.is_default;
}
if value.double_option {
acc.double_option = value.double_option;
}
.try_fold(SerdeValue::default(), |mut acc, value| {
if value.skip {
acc.skip = value.skip;
}
if value.skip_serializing_if {
acc.skip_serializing_if = value.skip_serializing_if;
}
if value.rename.is_some() {
acc.rename = value.rename;
}
if value.flatten {
acc.flatten = value.flatten;
}
if value.is_default {
acc.is_default = value.is_default;
}
if value.double_option {
acc.double_option = value.double_option;
}

acc
})
Some(acc)
})
}

Expand All @@ -208,26 +206,24 @@ pub(crate) fn parse_container(attributes: &[Attribute]) -> Option<SerdeContainer
.iter()
.filter(|attribute| attribute.path().is_ident("serde"))
.map(|serde_attribute| serde_attribute.parse_args_with(SerdeContainer::parse).unwrap_or_abort())
.fold(Some(SerdeContainer::default()), |acc, value| {
acc.map(|mut acc| {
if value.is_default {
acc.is_default = value.is_default;
}
match value.enum_repr {
SerdeEnumRepr::ExternallyTagged => {}
SerdeEnumRepr::Untagged
| SerdeEnumRepr::InternallyTagged { .. }
| SerdeEnumRepr::AdjacentlyTagged { .. }
| SerdeEnumRepr::UnfinishedAdjacentlyTagged { .. } => {
acc.enum_repr = value.enum_repr;
}
}
if value.rename_all.is_some() {
acc.rename_all = value.rename_all;
.try_fold(SerdeContainer::default(), |mut acc, value| {
if value.is_default {
acc.is_default = value.is_default;
}
match value.enum_repr {
SerdeEnumRepr::ExternallyTagged => {}
SerdeEnumRepr::Untagged
| SerdeEnumRepr::InternallyTagged { .. }
| SerdeEnumRepr::AdjacentlyTagged { .. }
| SerdeEnumRepr::UnfinishedAdjacentlyTagged { .. } => {
acc.enum_repr = value.enum_repr;
}
}
if value.rename_all.is_some() {
acc.rename_all = value.rename_all;
}

acc
})
Some(acc)
})
}

Expand Down
2 changes: 1 addition & 1 deletion examples/db-graphql/src/query.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ impl QueryRoot {
fn get_all_users(context: &DatabaseContext) -> FieldResult<Vec<User>> {
let read = context.0.read();
let users = read.get_all_users();
let mut result = Vec::<User>::new();
let mut result = Vec::with_capacity(users.len());
result.reserve(users.len());
for user in users {
result.push(User {
Expand Down
2 changes: 1 addition & 1 deletion examples/otel-jaeger/src/exporter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ impl Exporter {
let registry = Registry::new_custom(None, None).expect("create prometheus registry");
Self { registry }
}
fn handle(&self, req: &mut Request, res: &mut Response) {
fn handle(&self, req: &Request, res: &mut Response) {
if req.method() != Method::GET {
res.status_code(StatusCode::METHOD_NOT_ALLOWED);
return;
Expand Down
2 changes: 1 addition & 1 deletion examples/with-listenfd/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ async fn main() -> Result<(), salvo::Error> {
let (addr, listener) = if let Some(listener) = listenfd.take_tcp_listener(0)? {
(
listener.local_addr()?,
tokio::net::TcpListener::from_std(listener.into()).unwrap(),
tokio::net::TcpListener::from_std(listener).unwrap(),
)
} else {
let addr: SocketAddr = format!(
Expand Down

0 comments on commit d978647

Please sign in to comment.