Skip to content

Commit 96a6661

Browse files
feat: add SessionMiddleware configuration (#251)
* Add SessionMiddleWare Configuration * revert config path monkeypatch * chore(pre-commit.ci): auto fixes from pre-commit hooks * revert debug code * make secure config true by default * chore(pre-commit.ci): auto fixes from pre-commit hooks --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 5dc772c commit 96a6661

File tree

4 files changed

+133
-10
lines changed

4 files changed

+133
-10
lines changed

cot-cli/src/project_template/config/dev.toml

+3
Original file line numberDiff line numberDiff line change
@@ -5,3 +5,6 @@ url = "sqlite://db.sqlite3?mode=rwc"
55

66
[middlewares]
77
live_reload.enabled = true
8+
9+
[middlewares.session]
10+
secure = false

cot/src/config.rs

+73-1
Original file line numberDiff line numberDiff line change
@@ -413,6 +413,8 @@ impl DatabaseConfig {
413413
pub struct MiddlewareConfig {
414414
/// The configuration for the live reload middleware.
415415
pub live_reload: LiveReloadMiddlewareConfig,
416+
/// The configuration for the session middleware.
417+
pub session: SessionMiddlewareConfig,
416418
}
417419

418420
impl MiddlewareConfig {
@@ -438,16 +440,18 @@ impl MiddlewareConfigBuilder {
438440
/// # Examples
439441
///
440442
/// ```
441-
/// use cot::config::{LiveReloadMiddlewareConfig, MiddlewareConfig};
443+
/// use cot::config::{LiveReloadMiddlewareConfig, MiddlewareConfig, SessionMiddlewareConfig};
442444
///
443445
/// let config = MiddlewareConfig::builder()
444446
/// .live_reload(LiveReloadMiddlewareConfig::builder().enabled(true).build())
447+
/// .session(SessionMiddlewareConfig::builder().secure(false).build())
445448
/// .build();
446449
/// ```
447450
#[must_use]
448451
pub fn build(&self) -> MiddlewareConfig {
449452
MiddlewareConfig {
450453
live_reload: self.live_reload.clone().unwrap_or_default(),
454+
session: self.session.clone().unwrap_or_default(),
451455
}
452456
}
453457
}
@@ -514,6 +518,68 @@ impl LiveReloadMiddlewareConfigBuilder {
514518
}
515519
}
516520

521+
/// The configuration for the session middleware.
522+
///
523+
/// This is used as part of the [`MiddlewareConfig`] struct.
524+
///
525+
/// # Examples
526+
///
527+
/// ```
528+
/// use cot::config::SessionMiddlewareConfig;
529+
///
530+
/// let config = SessionMiddlewareConfig::builder().secure(false).build();
531+
/// ```
532+
#[derive(Debug, Default, Clone, PartialEq, Eq, Builder, Serialize, Deserialize)]
533+
#[builder(build_fn(skip, error = std::convert::Infallible))]
534+
#[serde(default)]
535+
pub struct SessionMiddlewareConfig {
536+
/// Whether the session middleware is secure.
537+
///
538+
/// # Examples
539+
///
540+
/// ```
541+
/// use cot::config::SessionMiddlewareConfig;
542+
///
543+
/// let config = SessionMiddlewareConfig::builder().secure(false).build();
544+
/// ```
545+
pub secure: bool,
546+
}
547+
548+
impl SessionMiddlewareConfig {
549+
/// Create a new [`SessionMiddlewareConfigBuilder`] to build a
550+
/// [`SessionMiddlewareConfig`].
551+
///
552+
/// # Examples
553+
///
554+
/// ```
555+
/// use cot::config::SessionMiddlewareConfig;
556+
///
557+
/// let config = SessionMiddlewareConfig::builder().build();
558+
/// ```
559+
#[must_use]
560+
pub fn builder() -> SessionMiddlewareConfigBuilder {
561+
SessionMiddlewareConfigBuilder::default()
562+
}
563+
}
564+
565+
impl SessionMiddlewareConfigBuilder {
566+
/// Builds the session middleware configuration.
567+
///
568+
/// # Examples
569+
///
570+
/// ```
571+
/// use cot::config::SessionMiddlewareConfig;
572+
///
573+
/// let config = SessionMiddlewareConfig::builder().secure(false).build();
574+
/// ```
575+
#[must_use]
576+
pub fn build(&self) -> SessionMiddlewareConfig {
577+
SessionMiddlewareConfig {
578+
secure: self.secure.unwrap_or(true),
579+
}
580+
}
581+
}
582+
517583
/// A secret key.
518584
///
519585
/// This is a wrapper over a byte array, which is used to store a cryptographic
@@ -718,6 +784,10 @@ mod tests {
718784
secret_key = "123abc"
719785
fallback_secret_keys = ["456def", "789ghi"]
720786
auth_backend = { type = "none" }
787+
[middlewares]
788+
live_reload.enabled = true
789+
[middlewares.session]
790+
secure = false
721791
"#;
722792

723793
let config = ProjectConfig::from_toml(toml_content).unwrap();
@@ -729,6 +799,8 @@ mod tests {
729799
assert_eq!(config.fallback_secret_keys[0].as_bytes(), b"456def");
730800
assert_eq!(config.fallback_secret_keys[1].as_bytes(), b"789ghi");
731801
assert_eq!(config.auth_backend, AuthBackendConfig::None);
802+
assert!(config.middlewares.live_reload.enabled);
803+
assert!(!config.middlewares.session.secure);
732804
}
733805

734806
#[test]

cot/src/middleware.rs

+50-7
Original file line numberDiff line numberDiff line change
@@ -248,14 +248,60 @@ where
248248
/// A middleware that provides session management.
249249
///
250250
/// By default, it uses an in-memory store for session data.
251-
#[derive(Debug, Copy, Clone)]
252-
pub struct SessionMiddleware;
251+
#[derive(Debug, Clone)]
252+
pub struct SessionMiddleware {
253+
inner: SessionManagerLayer<MemoryStore>,
254+
}
253255

254256
impl SessionMiddleware {
255257
/// Crates a new instance of [`SessionMiddleware`].
256258
#[must_use]
257259
pub fn new() -> Self {
258-
Self {}
260+
let store = MemoryStore::default();
261+
let layer = SessionManagerLayer::new(store);
262+
Self { inner: layer }
263+
}
264+
/// Creates a new instance of [`SessionMiddleware`] from the application
265+
/// context.
266+
///
267+
/// # Examples
268+
///
269+
/// ```
270+
/// use cot::middleware::SessionMiddleware;
271+
/// use cot::project::{RootHandlerBuilder, WithApps};
272+
/// use cot::{BoxedHandler, Project, ProjectContext};
273+
///
274+
/// struct MyProject;
275+
/// impl Project for MyProject {
276+
/// fn middlewares(
277+
/// &self,
278+
/// handler: RootHandlerBuilder,
279+
/// context: &ProjectContext<WithApps>,
280+
/// ) -> BoxedHandler {
281+
/// handler
282+
/// .middleware(SessionMiddleware::from_context(context))
283+
/// .build()
284+
/// }
285+
/// }
286+
/// ```
287+
#[must_use]
288+
pub fn from_context(context: &crate::ProjectContext<crate::project::WithApps>) -> Self {
289+
Self::new().secure(context.config().middlewares.session.secure)
290+
}
291+
/// Sets the secure flag for the session middleware.
292+
///
293+
/// # Examples
294+
///
295+
/// ```
296+
/// use cot::middleware::SessionMiddleware;
297+
///
298+
/// let middleware = SessionMiddleware::new().secure(false);
299+
/// ```
300+
#[must_use]
301+
pub fn secure(self, secure: bool) -> Self {
302+
Self {
303+
inner: self.inner.with_secure(secure),
304+
}
259305
}
260306
}
261307

@@ -269,12 +315,9 @@ impl<S> tower::Layer<S> for SessionMiddleware {
269315
type Service = <SessionManagerLayer<MemoryStore> as tower::Layer<S>>::Service;
270316

271317
fn layer(&self, inner: S) -> Self::Service {
272-
let session_store = MemoryStore::default();
273-
let session_layer = SessionManagerLayer::new(session_store);
274-
session_layer.layer(inner)
318+
self.inner.layer(inner)
275319
}
276320
}
277-
278321
#[cfg(feature = "live-reload")]
279322
type LiveReloadLayerType = tower::util::Either<
280323
(

examples/admin/src/main.rs

+7-2
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ use cot::__private::async_trait;
22
use cot::admin::AdminApp;
33
use cot::auth::db::{DatabaseUser, DatabaseUserApp};
44
use cot::cli::CliMetadata;
5-
use cot::config::{DatabaseConfig, ProjectConfig};
5+
use cot::config::{DatabaseConfig, MiddlewareConfig, ProjectConfig, SessionMiddlewareConfig};
66
use cot::middleware::{LiveReloadMiddleware, SessionMiddleware};
77
use cot::project::{WithApps, WithConfig};
88
use cot::request::Request;
@@ -63,6 +63,11 @@ impl Project for AdminProject {
6363
.url("sqlite://db.sqlite3?mode=rwc")
6464
.build(),
6565
)
66+
.middlewares(
67+
MiddlewareConfig::builder()
68+
.session(SessionMiddlewareConfig::builder().secure(false).build())
69+
.build(),
70+
)
6671
.build())
6772
}
6873

@@ -79,7 +84,7 @@ impl Project for AdminProject {
7984
) -> BoxedHandler {
8085
handler
8186
.middleware(StaticFilesMiddleware::from_context(context))
82-
.middleware(SessionMiddleware::new())
87+
.middleware(SessionMiddleware::from_context(context))
8388
.middleware(LiveReloadMiddleware::new())
8489
.build()
8590
}

0 commit comments

Comments
 (0)