diff --git a/server/svix-server/src/v1/endpoints/application.rs b/server/svix-server/src/v1/endpoints/application.rs index 9ccad731e..1cad96ad5 100644 --- a/server/svix-server/src/v1/endpoints/application.rs +++ b/server/svix-server/src/v1/endpoints/application.rs @@ -1,6 +1,8 @@ // SPDX-FileCopyrightText: © 2022 Svix Authors // SPDX-License-Identifier: MIT +use std::borrow::Cow; + use crate::{ core::{ security::{ @@ -12,8 +14,13 @@ use crate::{ db::models::application, error::{HttpError, Result}, v1::utils::{ - validate_no_control_characters, EmptyResponse, ListResponse, ModelIn, ModelOut, Pagination, - PaginationLimit, ValidatedJson, ValidatedQuery, + patch::{ + patch_field_non_nullable, patch_field_nullable, UnrequiredField, + UnrequiredNullableField, + }, + validate_no_control_characters, validate_no_control_characters_unrequired, EmptyResponse, + ListResponse, ModelIn, ModelOut, Pagination, PaginationLimit, ValidatedJson, + ValidatedQuery, }, }; use axum::{ @@ -27,7 +34,7 @@ use sea_orm::{entity::prelude::*, ActiveValue::Set, QueryOrder}; use sea_orm::{ActiveModelTrait, DatabaseConnection, QuerySelect}; use serde::{Deserialize, Serialize}; use svix_server_derive::{ModelIn, ModelOut}; -use validator::Validate; +use validator::{Validate, ValidationError}; #[derive(Clone, Debug, Default, PartialEq, Eq, Serialize, Deserialize, Validate, ModelIn)] #[serde(rename_all = "camelCase")] @@ -58,6 +65,82 @@ impl ModelIn for ApplicationIn { } } +#[derive(Deserialize, ModelIn, Serialize, Validate)] +#[serde(rename_all = "camelCase")] +pub struct ApplicationPatch { + #[serde(default, skip_serializing_if = "UnrequiredField::is_absent")] + #[validate( + custom = "validate_name_length_patch", + custom = "validate_no_control_characters_unrequired" + )] + pub name: UnrequiredField, + + #[serde(default, skip_serializing_if = "UnrequiredNullableField::is_absent")] + #[validate(custom = "validate_rate_limit_patch")] + pub rate_limit: UnrequiredNullableField, + + #[serde(default, skip_serializing_if = "UnrequiredNullableField::is_absent")] + #[validate] + pub uid: UnrequiredNullableField, +} + +impl ModelIn for ApplicationPatch { + type ActiveModel = application::ActiveModel; + + fn update_model(self, model: &mut Self::ActiveModel) { + let ApplicationPatch { + name, + rate_limit, + uid, + } = self; + + // `model`'s version of `rate_limit` is an i32, while `self`'s is a u16. + let rate_limit_map = |x: u16| -> i32 { x.into() }; + + patch_field_non_nullable!(model, name); + patch_field_nullable!(model, rate_limit, rate_limit_map); + patch_field_nullable!(model, uid); + } +} + +fn validate_name_length_patch( + name: &UnrequiredField, +) -> std::result::Result<(), ValidationError> { + match name { + UnrequiredField::Absent => Ok(()), + UnrequiredField::Some(s) => { + if s.is_empty() { + let mut error = ValidationError::new("length"); + error.message = Some(Cow::from( + "Application names must be at least one character", + )); + Err(error) + } else { + Ok(()) + } + } + } +} + +fn validate_rate_limit_patch( + rate_limit: &UnrequiredNullableField, +) -> std::result::Result<(), ValidationError> { + match rate_limit { + UnrequiredNullableField::Absent | UnrequiredNullableField::None => Ok(()), + UnrequiredNullableField::Some(rate_limit) => { + if *rate_limit > 0 { + Ok(()) + } else { + let mut error = ValidationError::new("range"); + error.message = Some(Cow::from( + "Application rate limits must be at least 1 if set", + )); + Err(error) + } + } + } +} + #[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize, ModelOut)] #[serde(rename_all = "camelCase")] pub struct ApplicationOut { @@ -172,6 +255,21 @@ async fn update_application( Ok(Json(ret.into())) } +async fn patch_application( + Extension(ref db): Extension, + ValidatedJson(data): ValidatedJson, + AuthenticatedOrganizationWithApplication { + permissions: _, + app, + }: AuthenticatedOrganizationWithApplication, +) -> Result> { + let mut app: application::ActiveModel = app.into(); + data.update_model(&mut app); + + let ret = app.update(db).await?; + Ok(Json(ret.into())) +} + async fn delete_application( Extension(ref db): Extension, AuthenticatedOrganizationWithApplication { @@ -193,13 +291,14 @@ pub fn router() -> Router { "/app/:app_id/", get(get_application) .put(update_application) + .patch(patch_application) .delete(delete_application), ) } #[cfg(test)] mod tests { - use super::ApplicationIn; + use super::{ApplicationIn, ApplicationPatch}; use serde_json::json; use validator::Validate; @@ -235,4 +334,31 @@ mod tests { .unwrap(); valid.validate().unwrap(); } + + // FIXME: How to eliminate the repetition here? + #[test] + fn test_application_patch_validation() { + let invalid_1: ApplicationPatch = + serde_json::from_value(json!({ "name": APP_NAME_INVALID })).unwrap(); + let invalid_2: ApplicationPatch = serde_json::from_value(json!({ + "name": APP_NAME_VALID, + "rateLimit": RATE_LIMIT_INVALID })) + .unwrap(); + let invalid_3: ApplicationPatch = serde_json::from_value(json!({ + "name": APP_NAME_VALID, + "uid": UID_INVALID })) + .unwrap(); + + for a in [invalid_1, invalid_2, invalid_3] { + assert!(a.validate().is_err()); + } + + let valid: ApplicationPatch = serde_json::from_value(json!({ + "name": APP_NAME_VALID, + "rateLimit": RATE_LIMIT_VALID, + "uid": UID_VALID, + })) + .unwrap(); + valid.validate().unwrap(); + } } diff --git a/server/svix-server/src/v1/endpoints/endpoint/crud.rs b/server/svix-server/src/v1/endpoints/endpoint/crud.rs index 1ea7dc498..37c4b81be 100644 --- a/server/svix-server/src/v1/endpoints/endpoint/crud.rs +++ b/server/svix-server/src/v1/endpoints/endpoint/crud.rs @@ -9,7 +9,7 @@ use sea_orm::{entity::prelude::*, ActiveValue::Set, QueryOrder}; use sea_orm::{ActiveModelTrait, DatabaseConnection, QuerySelect}; use url::Url; -use super::{secrets::generate_secret, EndpointIn, EndpointOut}; +use super::{secrets::generate_secret, EndpointIn, EndpointOut, EndpointPatch}; use crate::{ cfg::Configuration, core::{ @@ -23,6 +23,7 @@ use crate::{ db::models::{endpoint, eventtype}, error::{HttpError, Result, ValidationErrorItem}, v1::utils::{ + patch::{UnrequiredField, UnrequiredNullableField}, EmptyResponse, ListResponse, ModelIn, ModelOut, Pagination, PaginationLimit, ValidatedJson, ValidatedQuery, }, @@ -156,6 +157,47 @@ pub(super) async fn update_endpoint( Ok(Json(ret.into())) } +pub(super) async fn patch_endpoint( + Extension(ref db): Extension, + Extension(cfg): Extension, + Extension(op_webhooks): Extension, + Path((_app_id, endp_id)): Path<(ApplicationIdOrUid, EndpointIdOrUid)>, + ValidatedJson(data): ValidatedJson, + AuthenticatedApplication { permissions, app }: AuthenticatedApplication, +) -> Result> { + let endp = endpoint::Entity::secure_find_by_id_or_uid(app.id.clone(), endp_id) + .one(db) + .await? + .ok_or_else(|| HttpError::not_found(None, None))?; + + if let UnrequiredNullableField::Some(ref event_types_ids) = data.event_types_ids { + validate_event_types(db, event_types_ids, &permissions.org_id).await?; + } + if let UnrequiredField::Some(url) = &data.url { + validate_endpoint_url(url, cfg.endpoint_https_only)?; + } + + let mut endp: endpoint::ActiveModel = endp.into(); + data.update_model(&mut endp); + + let ret = endp.update(db).await?; + + let app_uid = app.uid; + op_webhooks + .send_operational_webhook( + &permissions.org_id, + OperationalWebhook::EndpointUpdated(EndpointEvent { + app_id: &ret.app_id, + app_uid: app_uid.as_ref(), + endpoint_id: &ret.id, + endpoint_uid: ret.uid.as_ref(), + }), + ) + .await?; + + Ok(Json(ret.into())) +} + pub(super) async fn delete_endpoint( Extension(ref db): Extension, Extension(op_webhooks): Extension, diff --git a/server/svix-server/src/v1/endpoints/endpoint/mod.rs b/server/svix-server/src/v1/endpoints/endpoint/mod.rs index d65701111..a2302e8ba 100644 --- a/server/svix-server/src/v1/endpoints/endpoint/mod.rs +++ b/server/svix-server/src/v1/endpoints/endpoint/mod.rs @@ -15,7 +15,14 @@ use crate::{ }, db::models::messagedestination, error::HttpError, - v1::utils::{api_not_implemented, validate_no_control_characters, ModelIn}, + v1::utils::{ + api_not_implemented, + patch::{ + patch_field_non_nullable, patch_field_nullable, UnrequiredField, + UnrequiredNullableField, + }, + validate_no_control_characters, validate_no_control_characters_unrequired, ModelIn, + }, }; use axum::{ @@ -28,7 +35,7 @@ use sea_orm::{ ActiveValue::Set, ColumnTrait, DatabaseConnection, FromQueryResult, QueryFilter, QuerySelect, }; use serde::{Deserialize, Serialize}; -use std::{collections::HashMap, collections::HashSet}; +use std::{borrow::Cow, collections::HashMap, collections::HashSet}; use url::Url; use svix_server_derive::{ModelIn, ModelOut}; @@ -49,6 +56,15 @@ pub fn validate_event_types_ids( } } +fn validate_event_types_ids_unrequired_nullable( + event_types_ids: &UnrequiredNullableField, +) -> std::result::Result<(), ValidationError> { + match event_types_ids { + UnrequiredNullableField::Absent | UnrequiredNullableField::None => Ok(()), + UnrequiredNullableField::Some(event_type_ids) => validate_event_types_ids(event_type_ids), + } +} + pub fn validate_channels_endpoint( channels: &EventChannelSet, ) -> std::result::Result<(), ValidationError> { @@ -62,6 +78,15 @@ pub fn validate_channels_endpoint( } } +fn validate_channels_endpoint_unrequired_nullable( + channels: &UnrequiredNullableField, +) -> std::result::Result<(), ValidationError> { + match channels { + UnrequiredNullableField::Absent | UnrequiredNullableField::None => Ok(()), + UnrequiredNullableField::Some(channels) => validate_channels_endpoint(channels), + } +} + pub fn validate_url(val: &str) -> std::result::Result<(), ValidationError> { match Url::parse(val) { Ok(url) => { @@ -79,6 +104,15 @@ pub fn validate_url(val: &str) -> std::result::Result<(), ValidationError> { } } +fn validate_url_unrequired( + val: &UnrequiredField, +) -> std::result::Result<(), ValidationError> { + match val { + UnrequiredField::Absent => Ok(()), + UnrequiredField::Some(val) => validate_url(val), + } +} + #[derive(Clone, Debug, Default, PartialEq, Eq, Serialize, Deserialize, Validate, ModelIn)] #[serde(rename_all = "camelCase")] pub struct EndpointIn { @@ -134,6 +168,115 @@ impl ModelIn for EndpointIn { } } +#[derive(Clone, Debug, Default, Serialize, Deserialize, Validate, ModelIn)] +#[serde(rename_all = "camelCase")] +pub struct EndpointPatch { + #[serde(default)] + #[serde(skip_serializing_if = "UnrequiredField::is_absent")] + #[validate(custom = "validate_no_control_characters_unrequired")] + pub description: UnrequiredField, + + #[validate(custom = "validate_rate_limit_patch")] + #[serde(default, skip_serializing_if = "UnrequiredNullableField::is_absent")] + pub rate_limit: UnrequiredNullableField, + + #[validate] + #[serde(default, skip_serializing_if = "UnrequiredNullableField::is_absent")] + pub uid: UnrequiredNullableField, + + #[validate(custom = "validate_url_unrequired")] + #[serde(default)] + pub url: UnrequiredField, + + #[validate(custom = "validate_minimum_version_patch")] + #[serde(default)] + pub version: UnrequiredField, + + #[serde(default)] + #[serde(skip_serializing_if = "UnrequiredField::is_absent")] + pub disabled: UnrequiredField, + + #[serde(default, rename = "filterTypes")] + #[validate(custom = "validate_event_types_ids_unrequired_nullable")] + #[validate] + #[serde(skip_serializing_if = "UnrequiredNullableField::is_absent")] + pub event_types_ids: UnrequiredNullableField, + + #[validate(custom = "validate_channels_endpoint_unrequired_nullable")] + #[validate] + #[serde(default, skip_serializing_if = "UnrequiredNullableField::is_absent")] + pub channels: UnrequiredNullableField, + + #[validate] + #[serde(default)] + #[serde(rename = "secret")] + #[serde(skip_serializing_if = "UnrequiredNullableField::is_absent")] + pub key: UnrequiredNullableField, +} + +impl ModelIn for EndpointPatch { + type ActiveModel = endpoint::ActiveModel; + + fn update_model(self, model: &mut Self::ActiveModel) { + let EndpointPatch { + description, + rate_limit, + uid, + url, + version, + disabled, + event_types_ids, + channels, + key: _, + } = self; + + let map = |x: u16| -> i32 { x.into() }; + + patch_field_non_nullable!(model, description); + patch_field_nullable!(model, rate_limit, map); + patch_field_nullable!(model, uid); + patch_field_non_nullable!(model, url); + patch_field_non_nullable!(model, version, map); + patch_field_non_nullable!(model, disabled); + patch_field_nullable!(model, event_types_ids); + patch_field_nullable!(model, channels); + } +} + +fn validate_rate_limit_patch( + rate_limit: &UnrequiredNullableField, +) -> std::result::Result<(), ValidationError> { + match rate_limit { + UnrequiredNullableField::Absent | UnrequiredNullableField::None => Ok(()), + UnrequiredNullableField::Some(rate_limit) => { + if *rate_limit > 0 { + Ok(()) + } else { + let mut error = ValidationError::new("range"); + error.message = Some(Cow::from("Endpoint rate limits must be at least 1 if set")); + Err(error) + } + } + } +} + +fn validate_minimum_version_patch( + version: &UnrequiredField, +) -> std::result::Result<(), ValidationError> { + match version { + UnrequiredField::Absent => Ok(()), + UnrequiredField::Some(version) => { + if *version == 0 { + let mut error = ValidationError::new("range"); + error.message = Some(Cow::from("Endpoint versions must be at least one")); + Err(error) + } else { + Ok(()) + } + } + } +} + #[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize, ModelOut)] #[serde(rename_all = "camelCase")] pub struct EndpointOut { @@ -338,6 +481,7 @@ pub fn router() -> Router { "/endpoint/:endp_id/", get(crud::get_endpoint) .put(crud::update_endpoint) + .patch(crud::patch_endpoint) .delete(crud::delete_endpoint), ) .route( diff --git a/server/svix-server/src/v1/endpoints/event_type.rs b/server/svix-server/src/v1/endpoints/event_type.rs index ed1e73636..b640ff5d2 100644 --- a/server/svix-server/src/v1/endpoints/event_type.rs +++ b/server/svix-server/src/v1/endpoints/event_type.rs @@ -9,8 +9,14 @@ use crate::{ db::models::eventtype, error::{HttpError, Result}, v1::utils::{ - api_not_implemented, validate_no_control_characters, EmptyResponse, ListResponse, ModelIn, - ModelOut, Pagination, PaginationLimit, ValidatedJson, ValidatedQuery, + api_not_implemented, + patch::{ + patch_field_non_nullable, patch_field_nullable, UnrequiredField, + UnrequiredNullableField, + }, + validate_no_control_characters, validate_no_control_characters_unrequired, EmptyResponse, + ListResponse, ModelIn, ModelOut, Pagination, PaginationLimit, ValidatedJson, + ValidatedQuery, }, }; use axum::{ @@ -70,6 +76,40 @@ impl ModelIn for EventTypeUpdate { } } +#[derive(Deserialize, ModelIn, Serialize, Validate)] +#[serde(rename_all = "camelCase")] +struct EventTypePatch { + #[serde(default, skip_serializing_if = "UnrequiredField::is_absent")] + #[validate(custom = "validate_no_control_characters_unrequired")] + description: UnrequiredField, + + #[serde( + default, + rename = "archived", + skip_serializing_if = "UnrequiredField::is_absent" + )] + deleted: UnrequiredField, + + #[serde(default, skip_serializing_if = "UnrequiredNullableField::is_absent")] + schemas: UnrequiredNullableField, +} + +impl ModelIn for EventTypePatch { + type ActiveModel = eventtype::ActiveModel; + + fn update_model(self, model: &mut Self::ActiveModel) { + let EventTypePatch { + description, + deleted, + schemas, + } = self; + + patch_field_non_nullable!(model, description); + patch_field_non_nullable!(model, deleted); + patch_field_nullable!(model, schemas); + } +} + #[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize, ModelOut)] #[serde(rename_all = "camelCase")] pub struct EventTypeOut { @@ -218,6 +258,24 @@ async fn update_event_type( Ok(Json(ret.into())) } +async fn patch_event_type( + Extension(ref db): Extension, + Path(evtype_name): Path, + ValidatedJson(data): ValidatedJson, + AuthenticatedOrganization { permissions }: AuthenticatedOrganization, +) -> Result> { + let evtype = eventtype::Entity::secure_find_by_name(permissions.org_id.clone(), evtype_name) + .one(db) + .await? + .ok_or_else(|| HttpError::not_found(None, None))?; + + let mut evtype: eventtype::ActiveModel = evtype.into(); + data.update_model(&mut evtype); + + let ret = evtype.update(db).await?; + Ok(Json(ret.into())) +} + async fn delete_event_type( Extension(ref db): Extension, Path(evtype_name): Path, @@ -244,6 +302,7 @@ pub fn router() -> Router { "/event-type/:event_type_name/", get(get_event_type) .put(update_event_type) + .patch(patch_event_type) .delete(delete_event_type), ) .route( diff --git a/server/svix-server/src/v1/utils.rs b/server/svix-server/src/v1/utils/mod.rs similarity index 98% rename from server/svix-server/src/v1/utils.rs rename to server/svix-server/src/v1/utils/mod.rs index a5664a1fc..b055501c8 100644 --- a/server/svix-server/src/v1/utils.rs +++ b/server/svix-server/src/v1/utils/mod.rs @@ -20,6 +20,9 @@ use crate::{ error::{Error, HttpError, Result, ValidationErrorItem}, }; +pub mod patch; +use patch::UnrequiredField; + const fn default_limit() -> PaginationLimit { PaginationLimit(50) } @@ -396,6 +399,15 @@ pub fn validate_no_control_characters(str: &str) -> std::result::Result<(), Vali Ok(()) } +pub fn validate_no_control_characters_unrequired( + str: &UnrequiredField, +) -> std::result::Result<(), ValidationError> { + match str { + UnrequiredField::Absent => Ok(()), + UnrequiredField::Some(str) => validate_no_control_characters(str), + } +} + #[cfg(test)] mod tests { use validator::Validate; diff --git a/server/svix-server/src/v1/utils/patch.rs b/server/svix-server/src/v1/utils/patch.rs new file mode 100644 index 000000000..ba421d5ad --- /dev/null +++ b/server/svix-server/src/v1/utils/patch.rs @@ -0,0 +1,246 @@ +// SPDX-FileCopyrightText: © 2022 Svix Authors +// SPDX-License-Identifier: MIT + +//! Module defining utilites for PATCH requests focused mostly around non-required field types. + +use serde::{Deserialize, Serialize}; +use validator::Validate; + +/// This is an enum that will wrap every nullable field for a PATCH request. Nonnullable fields can +/// be represented via an [`UnrequiredField`]. This differs from an [`Option`] in that it +/// distinguishes null values and absent values such that an optional value in a model may be made +/// None via PATCHing while allowing omitted fields to be skipped when updating. +/// +/// NOTE: You must tag these fields with `#[serde(default)]` in order for the serialization to work +/// correctly. +#[derive(Debug)] +pub enum UnrequiredNullableField { + Absent, + None, + Some(T), +} + +/// This enum is a non-nullable equivalent to [`UnrequiredNullableField`]. This is effectively an +/// [`Option`] with the additional context that any field which uses this type is a member of a +/// PATCH request model and that the field may be absent, meaning it is not to be updated. In +/// comparison, [`Option`]s are used in other [`ModelIn`]s to define a field, that when absent, +/// is `null`. +/// +/// NOTE: You must tag these fields with `#[serde(default)]` in order for the serialization to work +/// correctly. +#[derive(Debug)] +pub enum UnrequiredField { + Absent, + Some(T), +} + +impl UnrequiredNullableField { + pub fn is_absent(&self) -> bool { + matches!(self, UnrequiredNullableField::Absent) + } + + pub fn map(self, f: impl Fn(T) -> U) -> UnrequiredNullableField { + match self { + UnrequiredNullableField::Absent => UnrequiredNullableField::Absent, + UnrequiredNullableField::None => UnrequiredNullableField::None, + UnrequiredNullableField::Some(v) => UnrequiredNullableField::Some(f(v)), + } + } +} + +impl UnrequiredField { + pub fn is_absent(&self) -> bool { + matches!(self, UnrequiredField::Absent) + } + + pub fn map(self, f: impl Fn(T) -> U) -> UnrequiredField { + match self { + UnrequiredField::Absent => UnrequiredField::Absent, + UnrequiredField::Some(v) => UnrequiredField::Some(f(v)), + } + } +} + +impl Default for UnrequiredNullableField { + fn default() -> Self { + Self::Absent + } +} + +impl Default for UnrequiredField { + fn default() -> Self { + Self::Absent + } +} + +impl From> for UnrequiredNullableField { + fn from(opt: Option) -> Self { + match opt { + Some(v) => UnrequiredNullableField::Some(v), + None => UnrequiredNullableField::None, + } + } +} + +impl Validate for UnrequiredNullableField { + fn validate(&self) -> std::result::Result<(), validator::ValidationErrors> { + match self { + UnrequiredNullableField::Absent | UnrequiredNullableField::None => Ok(()), + UnrequiredNullableField::Some(v) => v.validate(), + } + } +} + +impl Validate for UnrequiredField { + fn validate(&self) -> std::result::Result<(), validator::ValidationErrors> { + match self { + UnrequiredField::Absent => Ok(()), + UnrequiredField::Some(v) => v.validate(), + } + } +} + +impl Clone for UnrequiredNullableField { + fn clone(&self) -> Self { + match self { + UnrequiredNullableField::Absent => UnrequiredNullableField::Absent, + UnrequiredNullableField::None => UnrequiredNullableField::None, + UnrequiredNullableField::Some(v) => UnrequiredNullableField::Some(v.clone()), + } + } +} + +impl Clone for UnrequiredField { + fn clone(&self) -> Self { + match self { + UnrequiredField::Absent => UnrequiredField::Absent, + UnrequiredField::Some(v) => UnrequiredField::Some(v.clone()), + } + } +} + +impl Copy for UnrequiredNullableField {} +impl Copy for UnrequiredField {} + +impl<'de, T> Deserialize<'de> for UnrequiredNullableField +where + T: Deserialize<'de>, +{ + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + Option::deserialize(deserializer).map(Into::into) + } +} + +impl<'de, T> Deserialize<'de> for UnrequiredField +where + T: Deserialize<'de>, +{ + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + T::deserialize(deserializer).map(UnrequiredField::Some) + } +} + +impl Serialize for UnrequiredNullableField +where + T: Serialize, +{ + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + match self { + UnrequiredNullableField::Absent => Err(serde::ser::Error::custom( + "UnrequiredNullableField must skip serializing if field is absent", + )), + UnrequiredNullableField::None => serializer.serialize_none(), + UnrequiredNullableField::Some(v) => v.serialize(serializer), + } + } +} +impl Serialize for UnrequiredField +where + T: Serialize, +{ + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + match self { + UnrequiredField::Absent => Err(serde::ser::Error::custom( + "UnrequiredField must skip serializing if field is absent", + )), + UnrequiredField::Some(v) => v.serialize(serializer), + } + } +} + +/// Macro that simplifies updating a field on an [`ActiveModel`] for use in a [`ModelIn`] +/// implementation. This macro expands to setting the field when the [`Option`] is `Some`, but +/// performs no operation in the case it is `None`. +/// +/// The input for this macro is three identifiers meant to be `self`, the `model` in a [`ModelIn`] +/// implementation, and the member that `self`, and `model` share that is being modified. +/// +/// Optionally, a fourth identifier may be given which is meant to be a closure that takes the type +/// of self's version of the member beng modified and returns model's version of the member being +/// modified. This is applied via [`UnrequiredNullableField::map`] such that basic type conversions may +/// be made. +/// +/// The nullable equivalent which is used for [`UnrequiredNullableField`] is [`patch_field_nullable`]. +macro_rules! patch_field_non_nullable { + ($model:ident, $member:ident) => { + match $member { + UnrequiredField::Some(v) => $model.$member = Set(v), + UnrequiredField::Absent => {} + } + }; + + ($model:ident, $member:ident, $f:ident) => { + let mapped = $member.map($f); + match mapped { + UnrequiredField::Some(v) => $model.$member = Set(v), + UnrequiredField::Absent => {} + } + }; +} +pub(crate) use patch_field_non_nullable; + +/// Macro that simplifies updating a field on an [`ActiveModel`] for use in a [`ModelIn`] +/// implementation. This macro expands to setting the field when the [`UnrequiredNullableField`] is +/// `Some` and unsetting the field when it is `None`, but performs no operation in the case it is +/// `Absent`. +/// +/// The input for this macro is three identifiers meant to be `self`, the `model` in a [`ModelIn`] +/// implementation, and the member that `self`, and `model` share that is being modified. +/// +/// Optionally, a fourth identifier may be given which is meant to be a closure that takes the type +/// of self's version of the member beng modified and returns model's version of the member being +/// modified. This is applied via [`UnrequiredNullableField::map`] such that basic type conversions may +/// be made. +/// +/// The non-nullable equivalent which is used for [`Option`] is [`patch_field_non_nullable`]. +macro_rules! patch_field_nullable { + ($model:ident, $member:ident) => { + match $member { + UnrequiredNullableField::Some(v) => $model.$member = Set(Some(v)), + UnrequiredNullableField::None => $model.$member = Set(None), + UnrequiredNullableField::Absent => {} + } + }; + + ($model:ident, $member:ident, $f:ident) => { + let mapped = $member.map($f); + match mapped { + UnrequiredNullableField::Some(v) => $model.$member = Set(Some(v)), + UnrequiredNullableField::None => $model.$member = Set(None), + UnrequiredNullableField::Absent => {} + } + }; +} +pub(crate) use patch_field_nullable; diff --git a/server/svix-server/tests/e2e_application.rs b/server/svix-server/tests/e2e_application.rs index c8c90401b..fab7bcb0a 100644 --- a/server/svix-server/tests/e2e_application.rs +++ b/server/svix-server/tests/e2e_application.rs @@ -15,6 +15,128 @@ use utils::{ start_svix_server, IgnoredResponse, }; +// NOTE: PATCHing must be tested exhaustively as if any of the boilerplate is missed then the +// operation could fail. This should probably be made into a macro if at all possible. +#[tokio::test] +async fn test_patch() { + let (client, _jh) = start_svix_server(); + + let app: ApplicationOut = client + .post( + "api/v1/app/", + application_in("first_name"), + StatusCode::CREATED, + ) + .await + .unwrap(); + + // Test that name may be set while the rest are omitted + let _: ApplicationOut = client + .patch( + &format!("api/v1/app/{}/", app.id), + serde_json::json! ({ + "name": "second_name" + }), + StatusCode::OK, + ) + .await + .unwrap(); + + // Assert change was made when later fetched + let out = client + .get::(&format!("api/v1/app/{}/", app.id), StatusCode::OK) + .await + .unwrap(); + assert_eq!(out.name, "second_name".to_owned()); + // Assert that no other field was changed + assert_eq!(out.rate_limit, None); + assert_eq!(out.uid, None); + + // Test that rate_limit may be set while the rest are omitted + let _: ApplicationOut = client + .patch( + &format!("api/v1/app/{}/", app.id), + serde_json::json! ({ + "rateLimit": 1 + }), + StatusCode::OK, + ) + .await + .unwrap(); + + // Assert the change was made + let out = client + .get::(&format!("api/v1/app/{}/", app.id), StatusCode::OK) + .await + .unwrap(); + assert_eq!(out.rate_limit, Some(1)); + // Assert that no other field was changed + assert_eq!(out.name, "second_name".to_owned()); + assert_eq!(out.uid, None); + + // Test that rate_limit may be unset while the rest are omitted + let _: ApplicationOut = client + .patch( + &format!("api/v1/app/{}/", app.id), + serde_json::json!({ "rateLimit": null }), + StatusCode::OK, + ) + .await + .unwrap(); + + // Assert the change was made + let out = client + .get::(&format!("api/v1/app/{}/", app.id), StatusCode::OK) + .await + .unwrap(); + assert_eq!(out.rate_limit, None); + // Assert that no other field was changed + assert_eq!(out.name, "second_name".to_owned()); + assert_eq!(out.uid, None); + + // Test that uid may be set while the rest are omitted + let _: ApplicationOut = client + .patch( + &format!("api/v1/app/{}/", app.id), + serde_json::json! ({ + "uid": "test_uid" + }), + StatusCode::OK, + ) + .await + .unwrap(); + + // Assert the change was made + let out = client + .get::(&format!("api/v1/app/{}/", app.id), StatusCode::OK) + .await + .unwrap(); + assert_eq!(out.uid, Some(ApplicationUid("test_uid".to_owned()))); + // Assert that no other field was changed + assert_eq!(out.name, "second_name".to_owned()); + assert_eq!(out.rate_limit, None); + + // Test that uid may be unset while the rest are omitted + let _: ApplicationOut = client + .patch( + &format!("api/v1/app/{}/", app.id), + serde_json::json!({ "uid": null }), + StatusCode::OK, + ) + .await + .unwrap(); + + // Assert the change was made + let out = client + .get::(&format!("api/v1/app/{}/", app.id), StatusCode::OK) + .await + .unwrap(); + assert_eq!(out.uid, None); + // Assert that no other field was changed + assert_eq!(out.name, "second_name".to_owned()); + assert_eq!(out.rate_limit, None); +} + #[tokio::test] async fn test_crud() { let (client, _jh) = start_svix_server(); diff --git a/server/svix-server/tests/e2e_endpoint.rs b/server/svix-server/tests/e2e_endpoint.rs index 9e1c324f0..cae160706 100644 --- a/server/svix-server/tests/e2e_endpoint.rs +++ b/server/svix-server/tests/e2e_endpoint.rs @@ -87,6 +87,370 @@ async fn delete_endpoint(client: &TestClient, app_id: &ApplicationId, ep_id: &st Ok(()) } +#[tokio::test] +async fn test_patch() { + let (client, _jh) = start_svix_server(); + + let app = create_test_app(&client, "v1EndpointPatchTestApp") + .await + .unwrap() + .id; + let ep = create_test_endpoint(&client, &app, "http://bad.url") + .await + .unwrap() + .id; + + let url = format!("api/v1/app/{}/endpoint/{}/", app, ep); + + // Test that the description may be set + let _: EndpointOut = client + .patch( + &url, + serde_json::json!({ + "description": "test" + }), + StatusCode::OK, + ) + .await + .unwrap(); + + // Assert the change was made + let out = client + .get::(&url, StatusCode::OK) + .await + .unwrap(); + assert_eq!(out.description, "test".to_owned()); + // Assert that no other changes were made + assert_eq!(out.rate_limit, None); + assert_eq!(out.uid, None); + assert_eq!(out.url, "http://bad.url".to_owned()); + assert_eq!(out.version, 1); + assert!(!out.disabled); + assert_eq!(out.event_types_ids, None); + assert_eq!(out.channels, None); + + // Test that the rate limit may be set + let _: EndpointOut = client + .patch( + &url, + serde_json::json!({ + "rateLimit": 1, + }), + StatusCode::OK, + ) + .await + .unwrap(); + + // Assert the change was made + let out = client + .get::(&url, StatusCode::OK) + .await + .unwrap(); + assert_eq!(out.rate_limit, Some(1)); + // Assert that no other changes were made + assert_eq!(out.description, "test".to_owned()); + assert_eq!(out.uid, None); + assert_eq!(out.url, "http://bad.url".to_owned()); + assert_eq!(out.version, 1); + assert!(!out.disabled); + assert_eq!(out.event_types_ids, None); + assert_eq!(out.channels, None); + + // Test that the rate limit may be unset + let _: EndpointOut = client + .patch( + &url, + serde_json::json!({ + "rateLimit": null, + }), + StatusCode::OK, + ) + .await + .unwrap(); + + // Assert the change was made + let out = client + .get::(&url, StatusCode::OK) + .await + .unwrap(); + assert_eq!(out.rate_limit, None); + // Assert that no other changes were made + assert_eq!(out.description, "test".to_owned()); + assert_eq!(out.uid, None); + assert_eq!(out.url, "http://bad.url".to_owned()); + assert_eq!(out.version, 1); + assert!(!out.disabled); + assert_eq!(out.event_types_ids, None); + assert_eq!(out.channels, None); + + // Test that the UID may be set + let _: EndpointOut = client + .patch( + &url, + serde_json::json!({ + "uid": "some", + }), + StatusCode::OK, + ) + .await + .unwrap(); + + // Assert the change was made + let out = client + .get::(&url, StatusCode::OK) + .await + .unwrap(); + assert_eq!(out.uid, Some(EndpointUid("some".to_owned()))); + // Assert that no other changes were made + assert_eq!(out.description, "test".to_owned()); + assert_eq!(out.rate_limit, None); + assert_eq!(out.url, "http://bad.url".to_owned()); + assert_eq!(out.version, 1); + assert!(!out.disabled); + assert_eq!(out.event_types_ids, None); + assert_eq!(out.channels, None); + + // Test the UID may be unset + let _: EndpointOut = client + .patch( + &url, + serde_json::json!({ + "uid": null, + }), + StatusCode::OK, + ) + .await + .unwrap(); + + // Assert the change was made + let out = client + .get::(&url, StatusCode::OK) + .await + .unwrap(); + assert_eq!(out.uid, None); + // Assert that no other changes were made + assert_eq!(out.description, "test".to_owned()); + assert_eq!(out.rate_limit, None); + assert_eq!(out.url, "http://bad.url".to_owned()); + assert_eq!(out.version, 1); + assert!(!out.disabled); + assert_eq!(out.event_types_ids, None); + assert_eq!(out.channels, None); + + // Test that the URL may be set + let _: EndpointOut = client + .patch( + &url, + serde_json::json!({ + "url": "http://bad.url2", + }), + StatusCode::OK, + ) + .await + .unwrap(); + + // Assert the change was made + let out = client + .get::(&url, StatusCode::OK) + .await + .unwrap(); + assert_eq!(out.url, "http://bad.url2".to_owned()); + // Assert that no other changes were made + assert_eq!(out.description, "test".to_owned()); + assert_eq!(out.rate_limit, None); + assert_eq!(out.uid, None); + assert_eq!(out.version, 1); + assert!(!out.disabled); + assert_eq!(out.event_types_ids, None); + assert_eq!(out.channels, None); + + // Test that the version may be set + let _: EndpointOut = client + .patch( + &url, + serde_json::json!({ + "version": 2, + }), + StatusCode::OK, + ) + .await + .unwrap(); + + // Assert the change was made + let out = client + .get::(&url, StatusCode::OK) + .await + .unwrap(); + assert_eq!(out.version, 2); + // Assert that no other changes were made + assert_eq!(out.description, "test".to_owned()); + assert_eq!(out.rate_limit, None); + assert_eq!(out.uid, None); + assert_eq!(out.url, "http://bad.url2".to_owned()); + assert!(!out.disabled); + assert_eq!(out.event_types_ids, None); + assert_eq!(out.channels, None); + + // Test that disabled may be set + let _: EndpointOut = client + .patch( + &url, + serde_json::json!({ + "disabled": true, + }), + StatusCode::OK, + ) + .await + .unwrap(); + + // Assert the change was made + let out = client + .get::(&url, StatusCode::OK) + .await + .unwrap(); + assert!(out.disabled); + // Assert that no other changes were made + assert_eq!(out.description, "test".to_owned()); + assert_eq!(out.rate_limit, None); + assert_eq!(out.uid, None); + assert_eq!(out.url, "http://bad.url2".to_owned()); + assert_eq!(out.version, 2); + assert_eq!(out.event_types_ids, None); + assert_eq!(out.channels, None); + + // Test that event type IDs may be set + + // But first make an event type to set it to + let _: EventTypeOut = client + .post( + "api/v1/event-type", + serde_json::json!({ + "description": "a test event type", + "name": "test", + }), + StatusCode::CREATED, + ) + .await + .unwrap(); + + let _: EndpointOut = client + .patch( + &url, + serde_json::json!({ + "filterTypes": [ "test" ], + }), + StatusCode::OK, + ) + .await + .unwrap(); + + // Assert the change was made + let out = client + .get::(&url, StatusCode::OK) + .await + .unwrap(); + assert_eq!( + out.event_types_ids, + Some(EventTypeNameSet(HashSet::from([EventTypeName( + "test".to_owned() + )]))) + ); + // Assert that no other changes were made + assert_eq!(out.description, "test".to_owned()); + assert_eq!(out.rate_limit, None); + assert_eq!(out.uid, None); + assert_eq!(out.url, "http://bad.url2".to_owned()); + assert_eq!(out.version, 2); + assert!(out.disabled); + assert_eq!(out.channels, None); + + // Test that event type IDs may be unset + let _: EndpointOut = client + .patch( + &url, + serde_json::json!({ + "filterTypes": null, + }), + StatusCode::OK, + ) + .await + .unwrap(); + + // Assert the change was made + let out = client + .get::(&url, StatusCode::OK) + .await + .unwrap(); + assert_eq!(out.event_types_ids, None); + // Assert that no other changes were made + assert_eq!(out.description, "test".to_owned()); + assert_eq!(out.rate_limit, None); + assert_eq!(out.uid, None); + assert_eq!(out.url, "http://bad.url2".to_owned()); + assert_eq!(out.version, 2); + assert!(out.disabled); + assert_eq!(out.channels, None); + + // Test that channels may be set + let _: EndpointOut = client + .patch( + &url, + serde_json::json!({ + "channels": [ "test" ], + }), + StatusCode::OK, + ) + .await + .unwrap(); + + // Assert the change was made + let out = client + .get::(&url, StatusCode::OK) + .await + .unwrap(); + assert_eq!( + out.channels, + Some(EventChannelSet(HashSet::from([EventChannel( + "test".to_owned() + )]))) + ); + // Assert that no other changes were made + assert_eq!(out.description, "test".to_owned()); + assert_eq!(out.rate_limit, None); + assert_eq!(out.uid, None); + assert_eq!(out.url, "http://bad.url2".to_owned()); + assert_eq!(out.version, 2); + assert!(out.disabled); + assert_eq!(out.event_types_ids, None); + + // Test that channels may be unset + let _: EndpointOut = client + .patch( + &url, + serde_json::json!({ + "channels": null, + }), + StatusCode::OK, + ) + .await + .unwrap(); + + // Assert the change was made + let out = client + .get::(&url, StatusCode::OK) + .await + .unwrap(); + assert_eq!(out.channels, None); + // Assert that no other changes were made + assert_eq!(out.description, "test".to_owned()); + assert_eq!(out.rate_limit, None); + assert_eq!(out.uid, None); + assert_eq!(out.url, "http://bad.url2".to_owned()); + assert_eq!(out.version, 2); + assert!(out.disabled); + assert_eq!(out.event_types_ids, None); +} + #[tokio::test] async fn test_crud() { let (client, _jh) = start_svix_server(); diff --git a/server/svix-server/tests/e2e_event_type.rs b/server/svix-server/tests/e2e_event_type.rs index 969a14e3c..58fa08bad 100644 --- a/server/svix-server/tests/e2e_event_type.rs +++ b/server/svix-server/tests/e2e_event_type.rs @@ -15,6 +15,115 @@ use utils::{ start_svix_server, }; +#[tokio::test] +async fn test_patch() { + let (client, _jh) = start_svix_server(); + + let et: EventTypeOut = client + .post( + "api/v1/event-type", + event_type_in("test-event-type", serde_json::json!({"test": "value"})).unwrap(), + StatusCode::CREATED, + ) + .await + .unwrap(); + + // Test that description may be set while the rest are omitted + let _: EventTypeOut = client + .patch( + &format!("api/v1/event-type/{}/", et.name), + serde_json::json!({ + "description": "updated_description", + }), + StatusCode::OK, + ) + .await + .unwrap(); + + // Assert that the change was made + let out = client + .get::(&format!("api/v1/event-type/{}/", &et.name), StatusCode::OK) + .await + .unwrap(); + assert_eq!(out.description, "updated_description".to_owned()); + + // Assert the other fields remain unchanged + assert_eq!(out.deleted, et.deleted); + assert_eq!(out.schemas, et.schemas); + + // Test that schemas may be set while the rest are omitted + let _: EventTypeOut = client + .patch( + &format!("api/v1/event-type/{}/", et.name), + serde_json::json!({ + "schemas": {}, + }), + StatusCode::OK, + ) + .await + .unwrap(); + + // Assert that the change was made + let out = client + .get::(&format!("api/v1/event-type/{}/", &et.name), StatusCode::OK) + .await + .unwrap(); + + assert_eq!(out.schemas, Some(serde_json::json!({}))); + + // Assert the other fields remain unchanged + assert_eq!(out.deleted, et.deleted); + assert_eq!(out.description, "updated_description".to_owned()); + + // Test that schemas may be unset while the rest are omitted + let _: EventTypeOut = client + .patch( + &format!("api/v1/event-type/{}/", et.name), + serde_json::json!({ + "schemas": null, + }), + StatusCode::OK, + ) + .await + .unwrap(); + + // Assert that the change was made + let out = client + .get::(&format!("api/v1/event-type/{}/", &et.name), StatusCode::OK) + .await + .unwrap(); + + assert_eq!(out.schemas, None); + + // Assert the other fields remain unchanged + assert_eq!(out.deleted, et.deleted); + assert_eq!(out.description, "updated_description".to_owned()); + + // Test that deleted may be set while the rest are omitted + let _: EventTypeOut = client + .patch( + &format!("api/v1/event-type/{}/", et.name), + serde_json::json!({ + "archived": true, + }), + StatusCode::OK, + ) + .await + .unwrap(); + + // Assert that the change was made + let out = client + .get::(&format!("api/v1/event-type/{}/", &et.name), StatusCode::OK) + .await + .unwrap(); + + assert!(out.deleted); + + // Assert the other fields remain unchanged + assert_eq!(out.schemas, None); + assert_eq!(out.description, "updated_description".to_owned()); +} + #[tokio::test] async fn test_event_type_create_read_list() { let (client, _jh) = start_svix_server();