Skip to content

Commit cd9ca63

Browse files
v0.17~preview.128.40+46
1 parent 2d1de9e commit cd9ca63

File tree

6 files changed

+197
-15
lines changed

6 files changed

+197
-15
lines changed

async_rpc/src/rpc.ml

Lines changed: 52 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,16 @@ module Connection = struct
143143
Result.ok_exn res
144144
;;
145145

146+
let connection_description ~server_addr ~client_addr =
147+
let server_addr = (server_addr :> Socket.Address.t) in
148+
let client_addr = (client_addr :> Socket.Address.t) in
149+
Info.create_s
150+
[%message
151+
"TCP server" (server_addr : Socket.Address.t) (client_addr : Socket.Address.t)]
152+
;;
153+
154+
let default_on_handshake_error = `Ignore
155+
146156
let make_serve_func
147157
serve_with_transport_handler
148158
~implementations
@@ -157,7 +167,7 @@ module Connection = struct
157167
?handshake_timeout
158168
?heartbeat_config
159169
?auth
160-
?(on_handshake_error = `Ignore)
170+
?(on_handshake_error = default_on_handshake_error)
161171
?on_handler_error
162172
()
163173
=
@@ -172,18 +182,11 @@ module Connection = struct
172182
?auth
173183
?on_handler_error
174184
(fun ~client_addr ~server_addr transport ->
175-
let description =
176-
let server_addr = (server_addr :> Socket.Address.t) in
177-
let client_addr = (client_addr :> Socket.Address.t) in
178-
Info.create_s
179-
[%message
180-
"TCP server" (server_addr : Socket.Address.t) (client_addr : Socket.Address.t)]
181-
in
182185
serve_with_transport
183186
~handshake_timeout
184187
~heartbeat_config
185188
~implementations
186-
~description
189+
~description:(connection_description ~server_addr ~client_addr)
187190
~connection_state:(fun conn -> initial_connection_state client_addr conn)
188191
~on_handshake_error
189192
~client_addr
@@ -198,6 +201,46 @@ module Connection = struct
198201
make_serve_func Rpc_transport.Tcp.serve_inet ~implementations
199202
;;
200203

204+
let serve_unix
205+
~implementations
206+
~initial_connection_state
207+
~where_to_listen
208+
?max_connections
209+
?backlog
210+
?drop_incoming_connections
211+
?time_source
212+
?max_message_size
213+
?make_transport
214+
?handshake_timeout
215+
?heartbeat_config
216+
?auth
217+
?(on_handshake_error = default_on_handshake_error)
218+
?on_handler_error
219+
()
220+
=
221+
Rpc_transport.Tcp.serve_unix
222+
~where_to_listen
223+
?max_connections
224+
?backlog
225+
?drop_incoming_connections
226+
?time_source
227+
?max_message_size
228+
?make_transport
229+
?auth
230+
?on_handler_error
231+
(fun ~client_addr ~server_addr peer_creds transport ->
232+
serve_with_transport
233+
~handshake_timeout
234+
~heartbeat_config
235+
~implementations
236+
~description:(connection_description ~server_addr ~client_addr)
237+
~connection_state:(fun conn ->
238+
initial_connection_state client_addr peer_creds conn)
239+
~on_handshake_error
240+
~client_addr
241+
transport)
242+
;;
243+
201244
let default_handshake_timeout_float =
202245
Time_ns.Span.to_span_float_round_nearest
203246
Async_rpc_kernel.Async_rpc_kernel_private.default_handshake_timeout

async_rpc/src/rpc.mli

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,30 @@ module Connection : sig
140140
-> unit
141141
-> (Socket.Address.Inet.t, int) Tcp.Server.t
142142

143+
(** As [serve], but only accepts Unix sockets; provides peer credentials of the socket
144+
to [initial_connection_state]. *)
145+
val serve_unix
146+
: implementations:'s Implementations.t
147+
-> initial_connection_state:
148+
(Socket.Address.Unix.t -> Linux_ext.Peer_credentials.t -> t -> 's)
149+
-> where_to_listen:Tcp.Where_to_listen.unix
150+
-> ?max_connections:int
151+
-> ?backlog:int
152+
-> ?drop_incoming_connections:bool
153+
-> ?time_source:[> read ] Time_source.T1.t
154+
-> ?max_message_size:int
155+
-> ?make_transport:transport_maker
156+
-> ?handshake_timeout:Time_float.Span.t
157+
-> ?heartbeat_config:Heartbeat_config.t
158+
-> ?auth:(Socket.Address.Unix.t -> bool) (** default is [`Ignore] *)
159+
-> ?on_handshake_error:
160+
[ `Raise | `Ignore | `Call of Socket.Address.Unix.t -> exn -> unit ]
161+
(** default is [`Ignore] *)
162+
-> ?on_handler_error:
163+
[ `Raise | `Ignore | `Call of Socket.Address.Unix.t -> exn -> unit ]
164+
-> unit
165+
-> Tcp.Server.unix Deferred.t
166+
143167
(** [client where_to_connect ()] connects to the server at [where_to_connect] and
144168
returns the connection or an Error if a connection could not be made. It is the
145169
responsibility of the caller to eventually call [close].

async_rpc/src/rpc_transport.ml

Lines changed: 70 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,12 @@ module Header = Kernel_transport.Header
55
module Handler_result = Kernel_transport.Handler_result
66
module Send_result = Kernel_transport.Send_result
77

8-
let environment_variable = "ASYNC_RPC_MAX_MESSAGE_SIZE"
8+
let max_message_size_env_var = "ASYNC_RPC_MAX_MESSAGE_SIZE"
99

1010
let max_message_size_from_environment =
1111
lazy
1212
(Option.try_with_join (fun () ->
13-
Sys.getenv environment_variable |> Option.map ~f:Int.of_string))
13+
Sys.getenv max_message_size_env_var |> Option.map ~f:Int.of_string))
1414
;;
1515

1616
let aux_effective_max_message_size ~max_message_size_from_environment ~proposed_max =
@@ -99,7 +99,9 @@ end = struct
9999
then
100100
failwiths
101101
~here:[%here]
102-
"Rpc_transport: message too small or too big"
102+
[%string
103+
"Rpc_transport: message is too large or has negative size. Try increasing the \
104+
size limit by setting the %{max_message_size_env_var} env var"]
103105
(`Message_size payload_len, `Max_message_size t.max_message_size)
104106
[%sexp_of: [ `Message_size of int ] * [ `Max_message_size of int ]]
105107
;;
@@ -297,7 +299,7 @@ let of_fd ?buffer_age_limit ?reader_buffer_size ?writer_buffer_size ~max_message
297299
module Tcp = struct
298300
let default_transport_maker fd ~max_message_size = of_fd fd ~max_message_size
299301

300-
let make_serve_func
302+
let make_serve_func_with_fd
301303
tcp_creator
302304
~where_to_listen
303305
?max_connections
@@ -324,10 +326,12 @@ module Tcp = struct
324326
| false -> return ()
325327
| true ->
326328
let max_message_size = effective_max_message_size ~proposed_max in
327-
let transport = make_transport ~max_message_size (Socket.fd socket) in
329+
let fd = Socket.fd socket in
330+
let transport = make_transport ~max_message_size fd in
328331
let%bind result =
329332
Monitor.try_with ~run:`Schedule ~rest:`Raise (fun () ->
330333
handle_transport
334+
fd
331335
~client_addr
332336
~server_addr:(Socket.getsockname socket)
333337
transport)
@@ -338,6 +342,34 @@ module Tcp = struct
338342
| Error exn -> raise exn))
339343
;;
340344
345+
let make_serve_func
346+
tcp_creator
347+
~where_to_listen
348+
?max_connections
349+
?backlog
350+
?drop_incoming_connections
351+
?time_source
352+
?max_message_size
353+
?make_transport
354+
?auth
355+
?on_handler_error
356+
handle_transport
357+
=
358+
make_serve_func_with_fd
359+
tcp_creator
360+
~where_to_listen
361+
?max_connections
362+
?backlog
363+
?drop_incoming_connections
364+
?time_source
365+
?max_message_size
366+
?make_transport
367+
?auth
368+
?on_handler_error
369+
(fun (_ : Fd.t) ~client_addr ~server_addr transport ->
370+
handle_transport ~client_addr ~server_addr transport)
371+
;;
372+
341373
(* eta-expand [where_to_listen] to avoid value restriction. *)
342374
let serve ~where_to_listen = make_serve_func Tcp.Server.create_sock ~where_to_listen
343375
@@ -346,6 +378,39 @@ module Tcp = struct
346378
make_serve_func Tcp.Server.create_sock_inet ~where_to_listen
347379
;;
348380
381+
let serve_unix
382+
~(where_to_listen : Tcp.Where_to_listen.unix)
383+
?max_connections
384+
?backlog
385+
?drop_incoming_connections
386+
?time_source
387+
?max_message_size
388+
?make_transport
389+
?auth
390+
?on_handler_error
391+
handle_transport
392+
=
393+
make_serve_func_with_fd
394+
Tcp.Server.create_sock
395+
~where_to_listen
396+
?max_connections
397+
?backlog
398+
?drop_incoming_connections
399+
?time_source
400+
?max_message_size
401+
?make_transport
402+
?auth
403+
?on_handler_error
404+
(fun fd ~client_addr ~server_addr transport ->
405+
let peer_credentials =
406+
Or_error.try_with (fun () ->
407+
(ok_exn Linux_ext.peer_credentials) (Fd.file_descr_exn fd))
408+
|> Or_error.tag ~tag:"Error getting peer credentials of unix socket"
409+
|> ok_exn
410+
in
411+
handle_transport ~client_addr ~server_addr peer_credentials transport)
412+
;;
413+
349414
let connect
350415
?max_message_size:proposed_max
351416
?(make_transport = default_transport_maker)

async_rpc/src/rpc_transport.mli

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,26 @@ module Tcp : sig
8585
-> unit Deferred.t)
8686
-> (Socket.Address.Inet.t, int) Tcp.Server.t
8787

88+
(** [serve_inet] is like [serve] but only for unix sockets (not inet sockets), and
89+
returns the identity of the peer on the socket. *)
90+
val serve_unix
91+
: where_to_listen:Tcp.Where_to_listen.unix
92+
-> ?max_connections:int
93+
-> ?backlog:int
94+
-> ?drop_incoming_connections:bool
95+
-> ?time_source:[> read ] Time_source.T1.t
96+
-> ?max_message_size:int
97+
-> ?make_transport:transport_maker
98+
-> ?auth:(Socket.Address.Unix.t -> bool)
99+
-> ?on_handler_error:
100+
[ `Raise | `Ignore | `Call of Socket.Address.Unix.t -> exn -> unit ]
101+
-> (client_addr:Socket.Address.Unix.t
102+
-> server_addr:Socket.Address.Unix.t
103+
-> Linux_ext.Peer_credentials.t
104+
-> Rpc_kernel.Transport.t
105+
-> unit Deferred.t)
106+
-> (Socket.Address.Unix.t, string) Tcp.Server.t Deferred.t
107+
88108
(** [connect ?make_transport where_to_connect ()] connects to the server at
89109
[where_to_connect]. On success, it returns the transport created using
90110
[make_transport] and the [Socket.Address.t] that it connected to, otherwise it

async_rpc/test/dune

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
1-
(library (name async_rpc_test) (libraries async re2 sexp_select)
1+
(library (name async_rpc_test)
2+
(libraries async expect_test_helpers_async re2 sexp_select)
23
(preprocess (pps ppx_jane)))

async_rpc/test/test_rpc.ml

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -203,6 +203,35 @@ let%test_unit "Open dispatches see connection closed error" =
203203
Tcp.Server.close server)
204204
;;
205205

206+
let%expect_test "serve_unix returns the identity of the caller" =
207+
Expect_test_helpers_async.with_temp_dir
208+
~in_dir:
209+
(let force_tmp_to_avoid_unix_socket_path_length_problems = "/tmp" in
210+
force_tmp_to_avoid_unix_socket_path_length_problems)
211+
(fun tempdir ->
212+
let socket_path = tempdir ^/ "socket" in
213+
let%bind server =
214+
Rpc.Connection.serve_unix
215+
~initial_connection_state:
216+
(fun
217+
(_ : Socket.Address.Unix.t)
218+
(peer_creds : Linux_ext.Peer_credentials.t)
219+
(_ : Rpc.Connection.t)
220+
->
221+
[%test_result: int] peer_creds.uid ~expect:(Unix.getuid ());
222+
[%test_result: int] peer_creds.gid ~expect:(Unix.getgid ());
223+
[%test_result: Pid.t] peer_creds.pid ~expect:(Unix.getpid ()))
224+
~implementations:(Rpc.Implementations.null ())
225+
~where_to_listen:(Tcp.Where_to_listen.of_file socket_path)
226+
()
227+
in
228+
let%bind connection =
229+
Rpc.Connection.client (Tcp.Where_to_connect.of_file socket_path) >>| Result.ok_exn
230+
in
231+
let%bind () = Rpc.Connection.close connection in
232+
Tcp.Server.close server)
233+
;;
234+
206235
let%test_module "Exception handling" =
207236
(module struct
208237
let on_exception ~callback_triggered ~expect_close_connection =

0 commit comments

Comments
 (0)