Skip to content

Commit

Permalink
kernel: do not rely on struct layout to perform downcast
Browse files Browse the repository at this point in the history
  • Loading branch information
mosmeh committed Oct 4, 2024
1 parent 20c196b commit 8cf6e95
Show file tree
Hide file tree
Showing 9 changed files with 65 additions and 43 deletions.
5 changes: 5 additions & 0 deletions common/extra.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,11 @@

#define SIZEOF_FIELD(t, f) sizeof(((t*)0)->f)
#define ARRAY_SIZE(x) (sizeof(x) / sizeof((x)[0]))
#define CONTAINER_OF(ptr, type, member) \
({ \
const __typeof__(((type*)0)->member)* __mptr = (ptr); \
(type*)((char*)__mptr - offsetof(type, member)); \
})

#define ROUND_UP(x, align) (((x) + ((align) - 1)) & ~((align) - 1))
#define ROUND_DOWN(x, align) ((x) & ~((align) - 1))
Expand Down
16 changes: 10 additions & 6 deletions kernel/console/tty.c
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#include "private.h"
#include <common/extra.h>
#include <common/string.h>
#include <kernel/api/signal.h>
#include <kernel/api/sys/ioctl.h>
Expand All @@ -9,20 +10,23 @@
#include <kernel/safe_string.h>
#include <kernel/task.h>

static struct tty* tty_from_file(struct file* file) {
return CONTAINER_OF(file->inode, struct tty, inode);
}

static bool can_read(struct tty* tty) {
return !ring_buf_is_empty(&tty->input_buf);
}

static bool unblock_read(struct file* file) {
struct tty* tty = (struct tty*)file->inode;
return can_read(tty);
return can_read(tty_from_file(file));
}

static ssize_t tty_pread(struct file* file, void* buf, size_t count,
uint64_t offset) {
(void)offset;

struct tty* tty = (struct tty*)file->inode;
struct tty* tty = tty_from_file(file);

for (;;) {
int rc = file_block(file, unblock_read, 0);
Expand Down Expand Up @@ -81,15 +85,15 @@ static void processed_echo(struct tty* tty, const char* buf, size_t count) {
static ssize_t tty_pwrite(struct file* file, const void* buf, size_t count,
uint64_t offset) {
(void)offset;
struct tty* tty = (struct tty*)file->inode;
struct tty* tty = tty_from_file(file);
spinlock_lock(&tty->lock);
processed_echo(tty, buf, count);
spinlock_unlock(&tty->lock);
return count;
}

static int tty_ioctl(struct file* file, int request, void* user_argp) {
struct tty* tty = (struct tty*)file->inode;
struct tty* tty = tty_from_file(file);
struct termios* termios = &tty->termios;
int ret = 0;
spinlock_lock(&tty->lock);
Expand Down Expand Up @@ -157,9 +161,9 @@ static int tty_ioctl(struct file* file, int request, void* user_argp) {
}

static short tty_poll(struct file* file, short events) {
struct tty* tty = (struct tty*)file->inode;
short revents = 0;
if (events & POLLIN) {
struct tty* tty = tty_from_file(file);
spinlock_lock(&tty->lock);
if (can_read(tty))
revents |= POLLIN;
Expand Down
22 changes: 15 additions & 7 deletions kernel/fs/fifo.c
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,16 @@ struct fifo {
atomic_size_t num_writers;
};

static struct fifo* fifo_from_inode(struct inode* inode) {
return CONTAINER_OF(inode, struct fifo, inode);
}

static struct fifo* fifo_from_file(struct file* file) {
return fifo_from_inode(file->inode);
}

static void fifo_destroy_inode(struct inode* inode) {
struct fifo* fifo = (struct fifo*)inode;
struct fifo* fifo = fifo_from_inode(inode);
ring_buf_destroy(&fifo->buf);
kfree(fifo);
}
Expand All @@ -36,7 +44,7 @@ static bool unblock_open(struct file* file) {
static int fifo_open(struct file* file, mode_t mode) {
(void)mode;

struct fifo* fifo = (struct fifo*)file->inode;
struct fifo* fifo = fifo_from_file(file);
switch (file->flags & O_ACCMODE) {
case O_RDONLY:
++fifo->num_readers;
Expand All @@ -61,7 +69,7 @@ static int fifo_open(struct file* file, mode_t mode) {
}

static int fifo_close(struct file* file) {
struct fifo* fifo = (struct fifo*)file->inode;
struct fifo* fifo = fifo_from_file(file);
switch (file->flags & O_ACCMODE) {
case O_RDONLY:
--fifo->num_readers;
Expand All @@ -76,15 +84,15 @@ static int fifo_close(struct file* file) {
}

static bool unblock_read(struct file* file) {
const struct fifo* fifo = (const struct fifo*)file->inode;
const struct fifo* fifo = fifo_from_file(file);
return fifo->num_writers == 0 || !ring_buf_is_empty(&fifo->buf);
}

static ssize_t fifo_pread(struct file* file, void* buffer, size_t count,
uint64_t offset) {
(void)offset;

struct fifo* fifo = (struct fifo*)file->inode;
struct fifo* fifo = fifo_from_file(file);
struct ring_buf* buf = &fifo->buf;

for (;;) {
Expand Down Expand Up @@ -115,7 +123,7 @@ static ssize_t fifo_pwrite(struct file* file, const void* buffer, size_t count,
uint64_t offset) {
(void)offset;

struct fifo* fifo = (struct fifo*)file->inode;
struct fifo* fifo = fifo_from_file(file);
struct ring_buf* buf = &fifo->buf;

for (;;) {
Expand Down Expand Up @@ -190,5 +198,5 @@ struct inode* fifo_create(void) {
inode->mode = S_IFIFO;
inode->ref_count = 1;

return (struct inode*)fifo;
return inode;
}
3 changes: 2 additions & 1 deletion kernel/fs/fs.c
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include <kernel/panic.h>
#include <kernel/safe_string.h>
#include <kernel/sched.h>
#include <kernel/socket.h>

void inode_ref(struct inode* inode) {
ASSERT(inode);
Expand All @@ -27,7 +28,7 @@ void inode_destroy(struct inode* inode) {
ASSERT(inode->ref_count == 0 && inode->num_links == 0);
ASSERT(inode->fops->destroy_inode);
inode_unref(inode->fifo);
inode_unref((struct inode*)inode->bound_socket);
inode_unref(&inode->bound_socket->inode);
inode->fops->destroy_inode(inode);
}

Expand Down
6 changes: 3 additions & 3 deletions kernel/fs/proc/pid.c
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ static int add_item(proc_dir_inode* parent, const proc_item_def* item_def,
pid_t pid) {
proc_pid_item_inode* node = kmalloc(sizeof(proc_pid_item_inode));
if (!node) {
inode_unref((struct inode*)parent);
inode_unref(&parent->inode);
return -ENOMEM;
}
*node = (proc_pid_item_inode){0};
Expand All @@ -174,7 +174,7 @@ static int add_item(proc_dir_inode* parent, const proc_item_def* item_def,
inode->ref_count = 1;

int rc = dentry_append(&parent->children, item_def->name, inode);
inode_unref((struct inode*)parent);
inode_unref(&parent->inode);
return rc;
}

Expand Down Expand Up @@ -215,6 +215,6 @@ struct inode* proc_pid_dir_inode_create(proc_dir_inode* parent, pid_t pid) {
return ERR_PTR(rc);
}

inode_unref((struct inode*)parent);
inode_unref(&parent->inode);
return inode;
}
4 changes: 2 additions & 2 deletions kernel/fs/proc/root.c
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ static int proc_root_getdents(struct file* file, getdents_callback_fn callback,
static int add_item(proc_dir_inode* parent, const proc_item_def* item_def) {
proc_item_inode* node = kmalloc(sizeof(proc_item_inode));
if (!node) {
inode_unref((struct inode*)parent);
inode_unref(&parent->inode);
return -ENOMEM;
}
*node = (proc_item_inode){0};
Expand All @@ -232,7 +232,7 @@ static int add_item(proc_dir_inode* parent, const proc_item_def* item_def) {
inode->ref_count = 1;

int rc = dentry_append(&parent->children, item_def->name, inode);
inode_unref((struct inode*)parent);
inode_unref(&parent->inode);
return rc;
}

Expand Down
8 changes: 8 additions & 0 deletions kernel/socket.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,3 +35,11 @@ NODISCARD int unix_socket_listen(struct unix_socket*, int backlog);
NODISCARD struct unix_socket* unix_socket_accept(struct file*);
NODISCARD int unix_socket_connect(struct file*, struct inode* addr_inode);
NODISCARD int unix_socket_shutdown(struct file*, int how);

static inline struct unix_socket* unix_socket_from_inode(struct inode* inode) {
return CONTAINER_OF(inode, struct unix_socket, inode);
}

static inline struct unix_socket* unix_socket_from_file(struct file* file) {
return unix_socket_from_inode(file->inode);
}
9 changes: 4 additions & 5 deletions kernel/syscall/socket.c
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ int sys_socket(int domain, int type, int protocol) {
struct unix_socket* socket = unix_socket_create();
if (IS_ERR(socket))
return PTR_ERR(socket);
struct file* file = inode_open((struct inode*)socket, O_RDWR, 0);
struct file* file = inode_open(&socket->inode, O_RDWR, 0);
if (IS_ERR(file))
return PTR_ERR(file);
int fd = task_alloc_file_descriptor(-1, file);
Expand All @@ -31,7 +31,7 @@ int sys_bind(int sockfd, const struct sockaddr* user_addr, socklen_t addrlen) {
return PTR_ERR(file);
if (!S_ISSOCK(file->inode->mode))
return -ENOTSOCK;
struct unix_socket* socket = (struct unix_socket*)file->inode;
struct unix_socket* socket = unix_socket_from_file(file);

if (addrlen <= sizeof(sa_family_t) || sizeof(struct sockaddr_un) < addrlen)
return -EINVAL;
Expand Down Expand Up @@ -69,7 +69,7 @@ int sys_listen(int sockfd, int backlog) {
if (!S_ISSOCK(file->inode->mode))
return -ENOTSOCK;

struct unix_socket* socket = (struct unix_socket*)file->inode;
struct unix_socket* socket = unix_socket_from_file(file);
return unix_socket_listen(socket, backlog);
}

Expand Down Expand Up @@ -101,8 +101,7 @@ int sys_accept4(int sockfd, struct sockaddr* user_addr, socklen_t* user_addrlen,
struct unix_socket* connector = unix_socket_accept(file);
if (IS_ERR(connector))
return PTR_ERR(connector);
struct file* connector_file =
inode_open((struct inode*)connector, O_RDWR, 0);
struct file* connector_file = inode_open(&connector->inode, O_RDWR, 0);
if (IS_ERR(connector_file))
return PTR_ERR(connector_file);

Expand Down
35 changes: 16 additions & 19 deletions kernel/unix_socket.c
Original file line number Diff line number Diff line change
Expand Up @@ -7,38 +7,37 @@
#include "task.h"

static void unix_socket_destroy_inode(struct inode* inode) {
struct unix_socket* socket = (struct unix_socket*)inode;
struct unix_socket* socket = unix_socket_from_inode(inode);
ring_buf_destroy(&socket->to_connector_buf);
ring_buf_destroy(&socket->to_acceptor_buf);
kfree(socket);
}

static int unix_socket_close(struct file* file) {
struct unix_socket* socket = (struct unix_socket*)file->inode;
struct unix_socket* socket = unix_socket_from_file(file);
socket->is_open_for_writing_to_connector = false;
socket->is_open_for_writing_to_acceptor = false;
return 0;
}

static bool is_connector(struct file* file) {
struct unix_socket* socket = (struct unix_socket*)file->inode;
return socket->connector_file == file;
return unix_socket_from_file(file)->connector_file == file;
}

static bool is_open_for_reading(struct file* file) {
struct unix_socket* socket = (struct unix_socket*)file->inode;
struct unix_socket* socket = unix_socket_from_file(file);
return is_connector(file) ? socket->is_open_for_writing_to_connector
: socket->is_open_for_writing_to_acceptor;
}

static struct ring_buf* buf_to_read(struct file* file) {
struct unix_socket* socket = (struct unix_socket*)file->inode;
struct unix_socket* socket = unix_socket_from_file(file);
return is_connector(file) ? &socket->to_connector_buf
: &socket->to_acceptor_buf;
}

static struct ring_buf* buf_to_write(struct file* file) {
struct unix_socket* socket = (struct unix_socket*)file->inode;
struct unix_socket* socket = unix_socket_from_file(file);
return is_connector(file) ? &socket->to_acceptor_buf
: &socket->to_connector_buf;
}
Expand All @@ -54,7 +53,7 @@ static ssize_t unix_socket_pread(struct file* file, void* buffer, size_t count,
uint64_t offset) {
(void)offset;

struct unix_socket* socket = (struct unix_socket*)file->inode;
struct unix_socket* socket = unix_socket_from_file(file);
if (!socket->is_connected)
return -EINVAL;

Expand All @@ -78,7 +77,7 @@ static ssize_t unix_socket_pread(struct file* file, void* buffer, size_t count,
}

static bool is_writable(struct file* file) {
struct unix_socket* socket = (struct unix_socket*)file->inode;
struct unix_socket* socket = unix_socket_from_file(file);
if (is_connector(file)) {
if (!socket->is_open_for_writing_to_acceptor)
return false;
Expand All @@ -93,7 +92,7 @@ static ssize_t unix_socket_pwrite(struct file* file, const void* buffer,
size_t count, uint64_t offset) {
(void)offset;

struct unix_socket* socket = (struct unix_socket*)file->inode;
struct unix_socket* socket = unix_socket_from_file(file);
if (!socket->is_connected)
return -ENOTCONN;

Expand Down Expand Up @@ -121,7 +120,7 @@ static ssize_t unix_socket_pwrite(struct file* file, const void* buffer,
}

static short unix_socket_poll(struct file* file, short events) {
struct unix_socket* socket = (struct unix_socket*)file->inode;
struct unix_socket* socket = unix_socket_from_file(file);
short revents = 0;
if (events & POLLIN) {
bool can_read =
Expand Down Expand Up @@ -211,15 +210,14 @@ int unix_socket_listen(struct unix_socket* socket, int backlog) {
}

static bool is_acceptable(struct file* file) {
struct unix_socket* socket = (struct unix_socket*)file->inode;
return socket->num_pending > 0;
return unix_socket_from_file(file)->num_pending > 0;
}

struct unix_socket* unix_socket_accept(struct file* file) {
if (!S_ISSOCK(file->inode->mode))
return ERR_PTR(-ENOTSOCK);

struct unix_socket* listener = (struct unix_socket*)file->inode;
struct unix_socket* listener = unix_socket_from_file(file);

mutex_lock(&listener->lock);
bool is_listening = listener->state == SOCKET_STATE_LISTENING;
Expand Down Expand Up @@ -255,8 +253,7 @@ struct unix_socket* unix_socket_accept(struct file* file) {
}

static bool is_connectable(struct file* file) {
struct unix_socket* connector = (struct unix_socket*)file->inode;
return connector->is_connected;
return unix_socket_from_file(file)->is_connected;
}

int unix_socket_connect(struct file* file, struct inode* addr_inode) {
Expand All @@ -267,7 +264,7 @@ int unix_socket_connect(struct file* file, struct inode* addr_inode) {
if (!listener)
return -ECONNREFUSED;

struct unix_socket* connector = (struct unix_socket*)file->inode;
struct unix_socket* connector = unix_socket_from_file(file);
mutex_lock(&connector->lock);

switch (connector->state) {
Expand Down Expand Up @@ -297,7 +294,7 @@ int unix_socket_connect(struct file* file, struct inode* addr_inode) {
connector->state = SOCKET_STATE_PENDING;
connector->next = NULL;

inode_ref((struct inode*)connector);
inode_ref(&connector->inode);

if (listener->next) {
struct unix_socket* it = listener->next;
Expand Down Expand Up @@ -330,7 +327,7 @@ int unix_socket_shutdown(struct file* file, int how) {
bool shut_read = how == SHUT_RD || how == SHUT_RDWR;
bool shut_write = how == SHUT_WR || how == SHUT_RDWR;
bool conn = is_connector(file);
struct unix_socket* socket = (struct unix_socket*)file->inode;
struct unix_socket* socket = unix_socket_from_file(file);
if ((conn && shut_read) || (!conn && shut_write))
socket->is_open_for_writing_to_connector = false;
if ((conn && shut_write) || (!conn && shut_read))
Expand Down

0 comments on commit 8cf6e95

Please sign in to comment.