Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions include/linux/tnum.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,9 @@ struct tnum tnum_intersect(struct tnum a, struct tnum b);
/* Return @a with all but the lowest @size bytes cleared */
struct tnum tnum_cast(struct tnum a, u8 size);

/* Return @a sign-extended from @size bytes */
struct tnum tnum_scast(struct tnum a, u8 size);

/* Returns true if @a is a known constant */
static inline bool tnum_is_const(struct tnum a)
{
Expand Down
32 changes: 32 additions & 0 deletions kernel/bpf/tnum.c
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,37 @@ struct tnum tnum_cast(struct tnum a, u8 size)
return a;
}

struct tnum tnum_scast(struct tnum a, u8 size)
{
u64 s = size * 8 - 1;
u64 sign_mask;
u64 value_mask;
u64 new_value, new_mask;
u64 sign_bit_unknown, sign_bit_value;
u64 mask;

if (size >= 8) {
return a;
}

sign_mask = 1ULL << s;
value_mask = (1ULL << (s + 1)) - 1;

new_value = a.value & value_mask;
new_mask = a.mask & value_mask;

sign_bit_unknown = (a.mask >> s) & 1;
sign_bit_value = (a.value >> s) & 1;

mask = ~value_mask;

new_mask |= mask & (0 - sign_bit_unknown);

new_value |= mask & (0 - ((sign_bit_unknown ^ 1) & sign_bit_value));

return TNUM(new_value, new_mask);
}

bool tnum_is_aligned(struct tnum a, u64 size)
{
if (!size)
Expand Down Expand Up @@ -211,3 +242,4 @@ struct tnum tnum_const_subreg(struct tnum a, u32 value)
{
return tnum_with_subreg(a, tnum_const(value));
}

64 changes: 15 additions & 49 deletions kernel/bpf/verifier.c
Original file line number Diff line number Diff line change
Expand Up @@ -6288,61 +6288,27 @@ static void set_sext64_default_val(struct bpf_reg_state *reg, int size)

static void coerce_reg_to_size_sx(struct bpf_reg_state *reg, int size)
{
s64 init_s64_max, init_s64_min, s64_max, s64_min, u64_cval;
u64 top_smax_value, top_smin_value;
u64 num_bits = size * 8;
reg->var_off = tnum_scast(reg->var_off, size);

if (tnum_is_const(reg->var_off)) {
u64_cval = reg->var_off.value;
if (size == 1)
reg->var_off = tnum_const((s8)u64_cval);
else if (size == 2)
reg->var_off = tnum_const((s16)u64_cval);
else
/* size == 4 */
reg->var_off = tnum_const((s32)u64_cval);

u64_cval = reg->var_off.value;
reg->smax_value = reg->smin_value = u64_cval;
reg->umax_value = reg->umin_value = u64_cval;
reg->s32_max_value = reg->s32_min_value = u64_cval;
reg->u32_max_value = reg->u32_min_value = u64_cval;
return;
}

top_smax_value = ((u64)reg->smax_value >> num_bits) << num_bits;
top_smin_value = ((u64)reg->smin_value >> num_bits) << num_bits;
reg->smin_value = (s64)(reg->var_off.value & ~reg->var_off.mask);
reg->smax_value = (s64)(reg->var_off.value | reg->var_off.mask);

if (top_smax_value != top_smin_value)
goto out;
reg->umin_value = (u64)reg->smin_value;
reg->umax_value = (u64)reg->smax_value;

/* find the s64_min and s64_min after sign extension */
if (size == 1) {
init_s64_max = (s8)reg->smax_value;
init_s64_min = (s8)reg->smin_value;
} else if (size == 2) {
init_s64_max = (s16)reg->smax_value;
init_s64_min = (s16)reg->smin_value;
if (size <= 4) {
reg->s32_min_value = (s32)reg->smin_value;
reg->s32_max_value = (s32)reg->smax_value;
reg->u32_min_value = (u32)reg->umin_value;
reg->u32_max_value = (u32)reg->umax_value;
} else {
init_s64_max = (s32)reg->smax_value;
init_s64_min = (s32)reg->smin_value;
}

s64_max = max(init_s64_max, init_s64_min);
s64_min = min(init_s64_max, init_s64_min);

/* both of s64_max/s64_min positive or negative */
if ((s64_max >= 0) == (s64_min >= 0)) {
reg->smin_value = reg->s32_min_value = s64_min;
reg->smax_value = reg->s32_max_value = s64_max;
reg->umin_value = reg->u32_min_value = s64_min;
reg->umax_value = reg->u32_max_value = s64_max;
reg->var_off = tnum_range(s64_min, s64_max);
return;
reg->s32_min_value = S32_MIN;
reg->s32_max_value = S32_MAX;
reg->u32_min_value = 0;
reg->u32_max_value = U32_MAX;
}

out:
set_sext64_default_val(reg, size);
reg_bounds_sync(reg);
}

static void set_sext32_default_val(struct bpf_reg_state *reg, int size)
Expand Down
199 changes: 199 additions & 0 deletions tools/testing/selftests/bpf/test_tnum.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,199 @@
// SPDX-License-Identifier: GPL-2.0-only
/* test_tnum.c: Selftests for tnum_scast function
*
* This program tests the tnum_scast function
*/

#include <stdio.h>
#include <stdint.h>
#include <stdbool.h>
#include <string.h>
#include <errno.h>
#include <stdlib.h>

#include "tnum.h"

#define ARRAY_SIZE(x) (sizeof(x) / sizeof((x)[0]))

struct tnum tnum_scast(struct tnum a, u8 size)
{
u64 s = size * 8 - 1;
u64 sign_mask;
u64 value_mask;
u64 new_value, new_mask;

if (size >= 8) {
return a;
}

sign_mask = 1ULL << s;
value_mask = (1ULL << (s + 1)) - 1;

new_value = a.value & value_mask;
new_mask = a.mask & value_mask;

if (a.mask & sign_mask) {
new_mask |= ~value_mask;
} else if (a.value & sign_mask) {
new_value |= ~value_mask;
}

return TNUM(new_value, new_mask);
}

struct tnum_test_case {
const char *description;
struct tnum input;
u8 size;
struct tnum expected;
};

static int test_tnum_scast(void)
{
int i, err = 0;
struct tnum result;

/* Define test cases */
struct tnum_test_case tests[] = {
/* 8-bit tests */
{
.description = "Known positive value (8-bit)",
.input = TNUM(0x7F, 0x00), // 127 in decimal
.size = 1,
.expected = TNUM(0x000000000000007F, 0x0000000000000000),
},
{
.description = "Known negative value (8-bit)",
.input = TNUM(0xFF, 0x00), // -1 in 8-bit signed
.size = 1,
.expected = TNUM(0xFFFFFFFFFFFFFFFF, 0x0000000000000000),
},
{
.description = "Unknown sign bit (8-bit)",
.input = TNUM(0x7F, 0x80), // Value 127, sign bit unknown
.size = 1,
.expected = TNUM(0x000000000000007F, 0xFFFFFFFFFFFFFF80),
},
{
.description = "Completely unknown value (8-bit)",
.input = TNUM(0x00, 0xFF), // All bits unknown
.size = 1,
.expected = TNUM(0x0000000000000000, 0xFFFFFFFFFFFFFFFF),
},
/* 16-bit tests */
{
.description = "Known positive value (16-bit)",
.input = TNUM(0x7FFF, 0x0000),
.size = 2,
.expected = TNUM(0x0000000000007FFF, 0x0000000000000000),
},
{
.description = "Known negative value (16-bit)",
.input = TNUM(0xFFFF, 0x0000), // -1 in 16-bit signed
.size = 2,
.expected = TNUM(0xFFFFFFFFFFFFFFFF, 0x0000000000000000),
},
{
.description = "Unknown sign bit (16-bit)",
.input = TNUM(0x7FFF, 0x8000),
.size = 2,
.expected = TNUM(0x0000000000007FFF, 0xFFFFFFFFFFFF8000),
},
{
.description = "Completely unknown value (16-bit)",
.input = TNUM(0x0000, 0xFFFF),
.size = 2,
.expected = TNUM(0x0000000000000000, 0xFFFFFFFFFFFFFFFF),
},
/* 32-bit tests */
{
.description = "Known positive value (32-bit)",
.input = TNUM(0x7FFFFFFF, 0x00000000),
.size = 4,
.expected = TNUM(0x000000007FFFFFFF, 0x0000000000000000),
},
{
.description = "Known negative value (32-bit)",
.input = TNUM(0xFFFFFFFF, 0x00000000), // -1 in 32-bit signed
.size = 4,
.expected = TNUM(0xFFFFFFFFFFFFFFFF, 0x0000000000000000),
},
{
.description = "Unknown sign bit (32-bit)",
.input = TNUM(0x7FFFFFFF, 0x80000000),
.size = 4,
.expected = TNUM(0x000000007FFFFFFF, 0xFFFFFFFF80000000),
},
{
.description = "Completely unknown value (32-bit)",
.input = TNUM(0x00000000, 0xFFFFFFFF),
.size = 4,
.expected = TNUM(0x0000000000000000, 0xFFFFFFFFFFFFFFFF),
},
/* 64-bit tests */
{
.description = "Known positive value (64-bit)",
.input = TNUM(0x7FFFFFFFFFFFFFFF, 0x0000000000000000),
.size = 8,
.expected = TNUM(0x7FFFFFFFFFFFFFFF, 0x0000000000000000),
},
{
.description = "Known negative value (64-bit)",
.input = TNUM(0xFFFFFFFFFFFFFFFF, 0x0000000000000000),
.size = 8,
.expected = TNUM(0xFFFFFFFFFFFFFFFF, 0x0000000000000000),
},
{
.description = "Unknown sign bit (64-bit)",
.input = TNUM(0x7FFFFFFFFFFFFFFF, 0x8000000000000000ULL),
.size = 8,
.expected = TNUM(0x7FFFFFFFFFFFFFFF, 0x8000000000000000ULL),
},
{
.description = "Completely unknown value (64-bit)",
.input = TNUM(0x0000000000000000, 0xFFFFFFFFFFFFFFFF),
.size = 8,
.expected = TNUM(0x0000000000000000, 0xFFFFFFFFFFFFFFFF),
},
};

printf("Running tnum_scast tests...\n\n");

for (i = 0; i < ARRAY_SIZE(tests); i++) {
struct tnum_test_case *t = &tests[i];

result = tnum_scast(t->input, t->size);

printf("Test %d (%s, size=%d bytes):\n", i + 1, t->description, t->size);
printf(" Input: value=0x%016llx, mask=0x%016llx\n",
t->input.value, t->input.mask);
printf(" Expected: value=0x%016llx, mask=0x%016llx\n",
t->expected.value, t->expected.mask);
printf(" Result: value=0x%016llx, mask=0x%016llx\n",
result.value, result.mask);

if (memcmp(&result, &t->expected, sizeof(struct tnum)) != 0) {
printf(" Fail.\n\n");
err = 1;
} else {
printf(" Pass.\n\n");
}
}

if (err)
printf("Some tnum_scast tests failed.\n");
else
printf("All tnum_scast tests passed successfully.\n");

return err;
}

int main(int argc, char **argv)
{
int err = 0;

err |= test_tnum_scast();

return err;
}

26 changes: 26 additions & 0 deletions tools/testing/selftests/bpf/tnum.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
// SPDX-License-Identifier: GPL-2.0-only
/* tnum.h: Header file for tnum utility functions */

#ifndef __TNUM_H__
#define __TNUM_H__

#include <stdint.h>

typedef uint64_t u64;
typedef int64_t s64;
typedef uint32_t u32;
typedef int32_t s32;
typedef uint8_t u8;

struct tnum {
u64 value;
u64 mask;
};

#define TNUM(_v, _m) (struct tnum){.value = (_v), .mask = (_m)}

/* Function prototypes */
struct tnum tnum_scast(struct tnum a, u8 size);

#endif /* __TNUM_H__ */