77// the Business Source License, use of this software will be governed
88// by the Apache License, Version 2.0.
99
10+ use std:: sync:: { Arc , Mutex } ;
1011use std:: time:: Duration ;
1112
13+ use anyhow:: Context ;
14+ use mz_ore:: collections:: HashMap ;
15+ use tokio:: sync:: oneshot;
16+
1217use reqwest_retry:: policies:: ExponentialBackoff ;
1318use reqwest_retry:: RetryTransientMiddleware ;
1419
20+ use crate :: client:: tokens:: RefreshTokenResponse ;
21+ use crate :: { ApiTokenArgs , ApiTokenResponse , Error , RefreshToken } ;
22+
1523pub mod tokens;
1624
25+ /// Client for Frontegg auth requests.
26+ ///
27+ /// Internally the client will attempt to de-dupe requests, e.g. if a single user tries to connect
28+ /// many clients at once, we'll de-dupe the authentication requests.
1729#[ derive( Clone , Debug ) ]
1830pub struct Client {
1931 pub client : reqwest_middleware:: ClientWithMiddleware ,
32+ inflight_requests : Arc < Mutex < HashMap < Request , ResponseHandle > > > ,
2033}
2134
35+ type ResponseHandle = Vec < oneshot:: Sender < Result < Response , Error > > > ;
36+
2237impl Default for Client {
2338 fn default ( ) -> Self {
2439 // Re-use the envd defaults until there's a reason to use something else. This is a separate
@@ -44,6 +59,167 @@ impl Client {
4459 . with ( RetryTransientMiddleware :: new_with_policy ( retry_policy) )
4560 . build ( ) ;
4661
47- Self { client }
62+ let inflight_requests = Arc :: new ( Mutex :: new ( HashMap :: new ( ) ) ) ;
63+
64+ Self {
65+ client,
66+ inflight_requests,
67+ }
68+ }
69+
70+ /// Makes a request to the provided URL, possibly de-duping by attaching a listener to an
71+ /// already in-flight request.
72+ async fn make_request < Req , Resp > ( & self , url : String , req : Req ) -> Result < Resp , Error >
73+ where
74+ Req : AuthRequest ,
75+ Resp : AuthResponse ,
76+ {
77+ let req = req. into_request ( ) ;
78+
79+ // Note: we get the reciever in a block to scope the access to the mutex.
80+ let rx = {
81+ let mut inflight_requests = self
82+ . inflight_requests
83+ . lock ( )
84+ . expect ( "Frontegg Auth Client panicked" ) ;
85+ let ( tx, rx) = tokio:: sync:: oneshot:: channel ( ) ;
86+
87+ match inflight_requests. get_mut ( & req) {
88+ // Already have an inflight request, add to our list of waiters.
89+ Some ( senders) => {
90+ tracing:: debug!( "reusing request, {req:?}" ) ;
91+ senders. push ( tx) ;
92+ rx
93+ }
94+ // New request! Need to queue one up.
95+ None => {
96+ tracing:: debug!( "spawning new request, {req:?}" ) ;
97+
98+ inflight_requests. insert ( req. clone ( ) , vec ! [ tx] ) ;
99+
100+ let client = self . client . clone ( ) ;
101+ let inflight = Arc :: clone ( & self . inflight_requests ) ;
102+ let req_ = req. clone ( ) ;
103+
104+ mz_ore:: task:: spawn ( move || "frontegg-auth-request" , async move {
105+ // Make the actual request.
106+ let result = async {
107+ let resp = client
108+ . post ( & url)
109+ . json ( & req_. into_json ( ) )
110+ . send ( )
111+ . await ?
112+ . error_for_status ( ) ?
113+ . json :: < Resp > ( )
114+ . await ?;
115+ Ok :: < _ , Error > ( resp)
116+ }
117+ . await ;
118+
119+ // Get all of our waiters.
120+ let mut inflight = inflight. lock ( ) . expect ( "Frontegg Auth Client panicked" ) ;
121+ let Some ( waiters) = inflight. remove ( & req) else {
122+ tracing:: error!( "Inflight entry already removed? {req:?}" ) ;
123+ return ;
124+ } ;
125+
126+ // Tell all of our waiters about the result.
127+ let response = result. map ( |r| r. into_response ( ) ) ;
128+ for tx in waiters {
129+ let _ = tx. send ( response. clone ( ) ) ;
130+ }
131+ } ) ;
132+
133+ rx
134+ }
135+ }
136+ } ;
137+
138+ let resp = rx. await . context ( "waiting for inflight response" ) ?;
139+ resp. map ( |r| Resp :: from_response ( r) )
140+ }
141+ }
142+
143+ /// Boilerplate for de-duping requests.
144+ ///
145+ /// We maintain an in-memory map of inflight requests, and that map needs to have keys of a single
146+ /// type, so we wrap all of our request types an an enum to create that single type.
147+ #[ derive( Clone , Debug , Hash , PartialEq , Eq ) ]
148+ enum Request {
149+ ExchangeSecretForToken ( ApiTokenArgs ) ,
150+ RefreshToken ( RefreshToken ) ,
151+ }
152+
153+ impl Request {
154+ fn into_json ( self ) -> serde_json:: Value {
155+ match self {
156+ Request :: ExchangeSecretForToken ( arg) => serde_json:: to_value ( arg) ,
157+ Request :: RefreshToken ( arg) => serde_json:: to_value ( arg) ,
158+ }
159+ . expect ( "converting to JSON cannot fail" )
160+ }
161+ }
162+
163+ /// Boilerplate for de-duping requests.
164+ ///
165+ /// Deduplicates the wrapping of request types into a [`Request`].
166+ trait AuthRequest : serde:: Serialize + Clone {
167+ fn into_request ( self ) -> Request ;
168+ }
169+
170+ impl AuthRequest for ApiTokenArgs {
171+ fn into_request ( self ) -> Request {
172+ Request :: ExchangeSecretForToken ( self )
173+ }
174+ }
175+
176+ impl AuthRequest for RefreshToken {
177+ fn into_request ( self ) -> Request {
178+ Request :: RefreshToken ( self )
179+ }
180+ }
181+
182+ /// Boilerplate for de-duping requests.
183+ ///
184+ /// We maintain an in-memory map of inflight requests, the values of the map are a Vec of waiters
185+ /// that listen for a response. These listeners all need to have the same type, so we wrap all of
186+ /// our response types in an enum.
187+ #[ derive( Clone , Debug ) ]
188+ enum Response {
189+ ExchangeSecretForToken ( ApiTokenResponse ) ,
190+ RefreshToken ( RefreshTokenResponse ) ,
191+ }
192+
193+ /// Boilerplate for de-duping requests.
194+ ///
195+ /// Deduplicates the wrapping and unwrapping between response types and [`Response`].
196+ trait AuthResponse : serde:: de:: DeserializeOwned {
197+ fn into_response ( self ) -> Response ;
198+ fn from_response ( resp : Response ) -> Self ;
199+ }
200+
201+ impl AuthResponse for ApiTokenResponse {
202+ fn into_response ( self ) -> Response {
203+ Response :: ExchangeSecretForToken ( self )
204+ }
205+
206+ fn from_response ( resp : Response ) -> Self {
207+ let Response :: ExchangeSecretForToken ( result) = resp else {
208+ unreachable ! ( "programming error!, didn't roundtrip {resp:?}" )
209+ } ;
210+ result
211+ }
212+ }
213+
214+ impl AuthResponse for RefreshTokenResponse {
215+ fn into_response ( self ) -> Response {
216+ Response :: RefreshToken ( self )
217+ }
218+
219+ fn from_response ( resp : Response ) -> Self {
220+ let Response :: RefreshToken ( result) = resp else {
221+ unreachable ! ( "programming error!, didn't roundtrip" )
222+ } ;
223+ result
48224 }
49225}
0 commit comments