Skip to content

Commit df07e42

Browse files
graebmbgklika
andauthored
new function: aws_socket_get_bound_address() (#491)
Enable users to bind on port 0, which has the OS assign a port, and then query which port it ended up with. The socket stores this address during `aws_socket_bind()` call, which allows `aws_socket_get_bound_address()` to be const and avoid any tricky threading issues where the socked closes on another thread. Also fix a few subtle bugs in Windows socket code. Co-authored-by: Vitaly Khalmansky <vkhalmansky@klika-tech.com>
1 parent 66a38bc commit df07e42

File tree

5 files changed

+413
-300
lines changed

5 files changed

+413
-300
lines changed

include/aws/io/socket.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,12 @@ AWS_IO_API int aws_socket_connect(
176176
*/
177177
AWS_IO_API int aws_socket_bind(struct aws_socket *socket, const struct aws_socket_endpoint *local_endpoint);
178178

179+
/**
180+
* Get the local address which the socket is bound to.
181+
* Raises an error if no address is bound.
182+
*/
183+
AWS_IO_API int aws_socket_get_bound_address(const struct aws_socket *socket, struct aws_socket_endpoint *out_address);
184+
179185
/**
180186
* TCP, LOCAL and VSOCK only. Sets up the socket to listen on the address bound to in `aws_socket_bind()`.
181187
*/

source/posix/socket.c

Lines changed: 135 additions & 81 deletions
Original file line numberDiff line numberDiff line change
@@ -263,6 +263,98 @@ void aws_socket_clean_up(struct aws_socket *socket) {
263263
socket->io_handle.data.fd = -1;
264264
}
265265

266+
/* Update socket->local_endpoint based on the results of getsockname() */
267+
static int s_update_local_endpoint(struct aws_socket *socket) {
268+
struct aws_socket_endpoint tmp_endpoint;
269+
AWS_ZERO_STRUCT(tmp_endpoint);
270+
271+
struct sockaddr_storage address;
272+
AWS_ZERO_STRUCT(address);
273+
socklen_t address_size = sizeof(address);
274+
275+
if (getsockname(socket->io_handle.data.fd, (struct sockaddr *)&address, &address_size) != 0) {
276+
AWS_LOGF_ERROR(
277+
AWS_LS_IO_SOCKET,
278+
"id=%p fd=%d: getsockname() failed with error %d",
279+
(void *)socket,
280+
socket->io_handle.data.fd,
281+
errno);
282+
int aws_error = s_determine_socket_error(errno);
283+
return aws_raise_error(aws_error);
284+
}
285+
286+
if (address.ss_family == AF_INET) {
287+
struct sockaddr_in *s = (struct sockaddr_in *)&address;
288+
tmp_endpoint.port = ntohs(s->sin_port);
289+
if (inet_ntop(AF_INET, &s->sin_addr, tmp_endpoint.address, sizeof(tmp_endpoint.address)) == NULL) {
290+
AWS_LOGF_ERROR(
291+
AWS_LS_IO_SOCKET,
292+
"id=%p fd=%d: inet_ntop() failed with error %d",
293+
(void *)socket,
294+
socket->io_handle.data.fd,
295+
errno);
296+
int aws_error = s_determine_socket_error(errno);
297+
return aws_raise_error(aws_error);
298+
}
299+
} else if (address.ss_family == AF_INET6) {
300+
struct sockaddr_in6 *s = (struct sockaddr_in6 *)&address;
301+
tmp_endpoint.port = ntohs(s->sin6_port);
302+
if (inet_ntop(AF_INET6, &s->sin6_addr, tmp_endpoint.address, sizeof(tmp_endpoint.address)) == NULL) {
303+
AWS_LOGF_ERROR(
304+
AWS_LS_IO_SOCKET,
305+
"id=%p fd=%d: inet_ntop() failed with error %d",
306+
(void *)socket,
307+
socket->io_handle.data.fd,
308+
errno);
309+
int aws_error = s_determine_socket_error(errno);
310+
return aws_raise_error(aws_error);
311+
}
312+
} else if (address.ss_family == AF_UNIX) {
313+
struct sockaddr_un *s = (struct sockaddr_un *)&address;
314+
315+
/* Ensure there's a null-terminator.
316+
* On some platforms it may be missing when the path gets very long. See:
317+
* https://man7.org/linux/man-pages/man7/unix.7.html#BUGS
318+
* But let's keep it simple, and not deal with that madness until someone demands it. */
319+
size_t sun_len;
320+
if (aws_secure_strlen(s->sun_path, sizeof(tmp_endpoint.address), &sun_len)) {
321+
AWS_LOGF_ERROR(
322+
AWS_LS_IO_SOCKET,
323+
"id=%p fd=%d: UNIX domain socket name is too long",
324+
(void *)socket,
325+
socket->io_handle.data.fd);
326+
return aws_raise_error(AWS_IO_SOCKET_INVALID_ADDRESS);
327+
}
328+
memcpy(tmp_endpoint.address, s->sun_path, sun_len);
329+
#if USE_VSOCK
330+
} else if (address.ss_family == AF_VSOCK) {
331+
struct sockaddr_vm *s = (struct sockaddr_vm *)&address;
332+
333+
/* VSOCK port is 32bit, but aws_socket_endpoint.port is only 16bit.
334+
* Hopefully this isn't an issue, since users can only pass in 16bit values.
335+
* But if it becomes an issue, we'll need to make aws_socket_endpoint more flexible */
336+
if (s->svm_port > UINT16_MAX) {
337+
AWS_LOGF_ERROR(
338+
AWS_LS_IO_SOCKET,
339+
"id=%p fd=%d: aws_socket_endpoint can't deal with VSOCK port > UINT16_MAX",
340+
(void *)socket,
341+
socket->io_handle.data.fd);
342+
return aws_raise_error(AWS_IO_SOCKET_INVALID_ADDRESS);
343+
}
344+
tmp_endpoint.port = (uint16_t)s->svm_port;
345+
346+
snprintf(tmp_endpoint.address, sizeof(tmp_endpoint.address), "%" PRIu32, s->svm_cid);
347+
return AWS_OP_SUCCESS;
348+
#endif /* USE_VSOCK */
349+
} else {
350+
AWS_ASSERT(0);
351+
return aws_raise_error(AWS_IO_SOCKET_UNSUPPORTED_ADDRESS_FAMILY);
352+
}
353+
354+
socket->local_endpoint = tmp_endpoint;
355+
return AWS_OP_SUCCESS;
356+
}
357+
266358
static void s_on_connection_error(struct aws_socket *socket, int error);
267359

268360
static int s_on_connection_success(struct aws_socket *socket) {
@@ -308,67 +400,8 @@ static int s_on_connection_success(struct aws_socket *socket) {
308400

309401
AWS_LOGF_INFO(AWS_LS_IO_SOCKET, "id=%p fd=%d: connection success", (void *)socket, socket->io_handle.data.fd);
310402

311-
struct sockaddr_storage address;
312-
AWS_ZERO_STRUCT(address);
313-
socklen_t address_size = sizeof(address);
314-
if (!getsockname(socket->io_handle.data.fd, (struct sockaddr *)&address, &address_size)) {
315-
uint16_t port = 0;
316-
317-
if (address.ss_family == AF_INET) {
318-
struct sockaddr_in *s = (struct sockaddr_in *)&address;
319-
port = ntohs(s->sin_port);
320-
/* this comes straight from the kernal. a.) they won't fail. b.) even if they do, it's not fatal
321-
* once we add logging, we can log this if it fails. */
322-
if (inet_ntop(
323-
AF_INET, &s->sin_addr, socket->local_endpoint.address, sizeof(socket->local_endpoint.address))) {
324-
AWS_LOGF_DEBUG(
325-
AWS_LS_IO_SOCKET,
326-
"id=%p fd=%d: local endpoint %s:%d",
327-
(void *)socket,
328-
socket->io_handle.data.fd,
329-
socket->local_endpoint.address,
330-
port);
331-
} else {
332-
AWS_LOGF_WARN(
333-
AWS_LS_IO_SOCKET,
334-
"id=%p fd=%d: determining local endpoint failed",
335-
(void *)socket,
336-
socket->io_handle.data.fd);
337-
}
338-
} else if (address.ss_family == AF_INET6) {
339-
struct sockaddr_in6 *s = (struct sockaddr_in6 *)&address;
340-
port = ntohs(s->sin6_port);
341-
/* this comes straight from the kernal. a.) they won't fail. b.) even if they do, it's not fatal
342-
* once we add logging, we can log this if it fails. */
343-
if (inet_ntop(
344-
AF_INET6, &s->sin6_addr, socket->local_endpoint.address, sizeof(socket->local_endpoint.address))) {
345-
AWS_LOGF_DEBUG(
346-
AWS_LS_IO_SOCKET,
347-
"id=%p fd %d: local endpoint %s:%d",
348-
(void *)socket,
349-
socket->io_handle.data.fd,
350-
socket->local_endpoint.address,
351-
port);
352-
} else {
353-
AWS_LOGF_WARN(
354-
AWS_LS_IO_SOCKET,
355-
"id=%p fd=%d: determining local endpoint failed",
356-
(void *)socket,
357-
socket->io_handle.data.fd);
358-
}
359-
}
360-
361-
socket->local_endpoint.port = port;
362-
} else {
363-
AWS_LOGF_ERROR(
364-
AWS_LS_IO_SOCKET,
365-
"id=%p fd=%d: getsockname() failed with error %d",
366-
(void *)socket,
367-
socket->io_handle.data.fd,
368-
errno);
369-
int aws_error = s_determine_socket_error(errno);
370-
aws_raise_error(aws_error);
371-
s_on_connection_error(socket, aws_error);
403+
if (s_update_local_endpoint(socket)) {
404+
s_on_connection_error(socket, aws_last_error());
372405
return AWS_OP_ERR;
373406
}
374407

@@ -761,9 +794,6 @@ int aws_socket_bind(struct aws_socket *socket, const struct aws_socket_endpoint
761794
return AWS_OP_ERR;
762795
}
763796

764-
int error_code = -1;
765-
766-
socket->local_endpoint = *local_endpoint;
767797
AWS_LOGF_INFO(
768798
AWS_LS_IO_SOCKET,
769799
"id=%p fd=%d: binding to %s:%d.",
@@ -813,31 +843,55 @@ int aws_socket_bind(struct aws_socket *socket, const struct aws_socket_endpoint
813843
return aws_raise_error(s_convert_pton_error(pton_err));
814844
}
815845

816-
error_code = bind(socket->io_handle.data.fd, (struct sockaddr *)&address.sock_addr_types, sock_size);
846+
if (bind(socket->io_handle.data.fd, (struct sockaddr *)&address.sock_addr_types, sock_size) != 0) {
847+
AWS_LOGF_ERROR(
848+
AWS_LS_IO_SOCKET,
849+
"id=%p fd=%d: bind failed with error code %d",
850+
(void *)socket,
851+
socket->io_handle.data.fd,
852+
errno);
817853

818-
if (!error_code) {
819-
if (socket->options.type == AWS_SOCKET_STREAM) {
820-
socket->state = BOUND;
821-
} else {
822-
/* e.g. UDP is now readable */
823-
socket->state = CONNECTED_READ;
824-
}
825-
AWS_LOGF_DEBUG(AWS_LS_IO_SOCKET, "id=%p fd=%d: successfully bound", (void *)socket, socket->io_handle.data.fd);
854+
aws_raise_error(s_determine_socket_error(errno));
855+
goto error;
856+
}
826857

827-
return AWS_OP_SUCCESS;
858+
if (s_update_local_endpoint(socket)) {
859+
goto error;
828860
}
829861

830-
socket->state = ERROR;
831-
error_code = errno;
832-
AWS_LOGF_ERROR(
862+
if (socket->options.type == AWS_SOCKET_STREAM) {
863+
socket->state = BOUND;
864+
} else {
865+
/* e.g. UDP is now readable */
866+
socket->state = CONNECTED_READ;
867+
}
868+
869+
AWS_LOGF_DEBUG(
833870
AWS_LS_IO_SOCKET,
834-
"id=%p fd=%d: bind failed with error code %d",
871+
"id=%p fd=%d: successfully bound to %s:%d",
835872
(void *)socket,
836873
socket->io_handle.data.fd,
837-
error_code);
874+
socket->local_endpoint.address,
875+
socket->local_endpoint.port);
838876

839-
int aws_error = s_determine_socket_error(error_code);
840-
return aws_raise_error(aws_error);
877+
return AWS_OP_SUCCESS;
878+
879+
error:
880+
socket->state = ERROR;
881+
return AWS_OP_ERR;
882+
}
883+
884+
int aws_socket_get_bound_address(const struct aws_socket *socket, struct aws_socket_endpoint *out_address) {
885+
if (socket->local_endpoint.address[0] == 0) {
886+
AWS_LOGF_ERROR(
887+
AWS_LS_IO_SOCKET,
888+
"id=%p fd=%d: Socket has no local address. Socket must be bound first.",
889+
(void *)socket,
890+
socket->io_handle.data.fd);
891+
return aws_raise_error(AWS_IO_SOCKET_ILLEGAL_OPERATION_FOR_STATE);
892+
}
893+
*out_address = socket->local_endpoint;
894+
return AWS_OP_SUCCESS;
841895
}
842896

843897
int aws_socket_listen(struct aws_socket *socket, int backlog_size) {

0 commit comments

Comments
 (0)