Skip to content

Commit

Permalink
test: add tests for validateBounds
Browse files Browse the repository at this point in the history
  • Loading branch information
Juanadelacuesta committed Oct 31, 2024
1 parent d0b015e commit 80e398b
Show file tree
Hide file tree
Showing 4 changed files with 51 additions and 35 deletions.
15 changes: 6 additions & 9 deletions drivers/shared/validators/validators.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ import (

var (
ErrInvalidBound = errors.New("range bound not valid")
ErrEmptyRange = errors.New("range value cannot be empty")
//ErrEmptyRange = errors.New("range value cannot be empty")
ErrInvalidRange = errors.New("lower bound cannot be greater than upper bound")
)

Expand All @@ -29,7 +29,7 @@ type (
UserID uint64
)

type validator struct {
type Validator struct {
// DeniedHostUids configures which host uids are disallowed
deniedUIDs *idset.Set[UserID]

Expand All @@ -40,7 +40,7 @@ type validator struct {
logger hclog.Logger
}

func NewValidator(logger hclog.Logger, deniedHostUIDs, deniedHostGIDs string) (*validator, error) {
func NewValidator(logger hclog.Logger, deniedHostUIDs, deniedHostGIDs string) (*Validator, error) {
valLogger := logger.Named("id_validator")

err := validateIDRange("deniedHostUIDs", deniedHostUIDs)
Expand All @@ -55,7 +55,7 @@ func NewValidator(logger hclog.Logger, deniedHostUIDs, deniedHostGIDs string) (*
}
valLogger.Debug("group range configured", "denied range", deniedHostGIDs)

v := &validator{
v := &Validator{
deniedUIDs: idset.Parse[UserID](deniedHostUIDs),
deniedGIDs: idset.Parse[GroupID](deniedHostGIDs),
logger: valLogger,
Expand All @@ -66,7 +66,7 @@ func NewValidator(logger hclog.Logger, deniedHostUIDs, deniedHostGIDs string) (*

// HasValidIDs is used when running a task to ensure the
// given user is in the ID range defined in the task config
func (v *validator) HasValidIDs(userName string) error {
func (v *Validator) HasValidIDs(userName string) error {
user, err := users.Lookup(userName)
if err != nil {
return fmt.Errorf("failed to identify user %q: %w", userName, err)
Expand All @@ -82,7 +82,7 @@ func (v *validator) HasValidIDs(userName string) error {
return fmt.Errorf("running as uid %d is disallowed", uid)
}

gids, err := getGroupID(user)
gids, err := getGroupsID(user)
if err != nil {
return fmt.Errorf("validator: %w", err)
}
Expand Down Expand Up @@ -122,9 +122,6 @@ func validateBounds(boundsString string) error {
uidDenyRangeParts := strings.Split(boundsString, "-")

switch len(uidDenyRangeParts) {
case 0:
return ErrEmptyRange

case 1:
disallowedIdStr := uidDenyRangeParts[0]
if _, err := strconv.ParseUint(disallowedIdStr, 10, 32); err != nil {
Expand Down
2 changes: 1 addition & 1 deletion drivers/shared/validators/validators_default.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,6 @@ func getUserID(*user.User) (UserID, error) {
}

// noop
func getGroupID(*user.User) ([]GroupID, error) {
func getGroupsID(*user.User) ([]GroupID, error) {
return []GroupID{}, nil
}
67 changes: 43 additions & 24 deletions drivers/shared/validators/validators_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ package validators
import (
"fmt"
"os/user"
"strconv"
"testing"

"github.com/hashicorp/go-hclog"
Expand Down Expand Up @@ -50,47 +51,45 @@ func Test_IDRangeValid(t *testing.T) {
}

func Test_HasValidIds(t *testing.T) {
var validRange = "1-100"

var validRangeSingle = "1"
user, err := user.Current()
must.NoError(t, err)

userID, err := strconv.ParseUint(user.Uid, 10, 32)
groupID, err := strconv.ParseUint(user.Gid, 10, 32)
must.NoError(t, err)

userNotIncluded := fmt.Sprintf("%d-%d", userID+1, userID+11)
userIncluded := fmt.Sprintf("%d-%d", userID, userID+11)
userNotIncludedSingle := fmt.Sprintf("%d", userID+1)

groupNotIncluded := fmt.Sprintf("%d-%d", groupID+1, groupID+11)
groupIncluded := fmt.Sprintf("%d-%d", groupID, groupID+11)
groupNotIncludedSingle := fmt.Sprintf("%d", groupID+1)

emptyRanges := ""
validRangesList := fmt.Sprintf("%s,%s", validRange, validRangeSingle)

userDeniedRangesList := fmt.Sprintf("%s,%s", userNotIncluded, userNotIncludedSingle)
groupDeniedRangesList := fmt.Sprintf("%s,%s", groupNotIncluded, groupNotIncludedSingle)

testCases := []struct {
name string
uidRanges string
gidRanges string
uid string
gid string
expectedErr string
}{
{name: "no-ranges-are-valid", uidRanges: validRangesList, gidRanges: emptyRanges},
{name: "uid-and-gid-outside-of-ranges-valid", uidRanges: validRangesList, gidRanges: validRangesList},
{name: "uid-in-one-of-ranges-is-invalid", uidRanges: validRangesList, gidRanges: validRangesList, uid: "50", expectedErr: "running as uid 50 is disallowed"},
{name: "gid-in-one-of-ranges-is-invalid", uidRanges: validRangesList, gidRanges: validRangesList, gid: "50", expectedErr: "running as gid 50 is disallowed"},
{name: "string-uid-throws-error", uid: "banana", expectedErr: "unable to convert userid banana to integer"},
{name: "user_not_in_denied_ranges", uidRanges: userDeniedRangesList, gidRanges: emptyRanges},
{name: "user_and group_not_in_denied_ranges", uidRanges: userDeniedRangesList, gidRanges: groupDeniedRangesList},
{name: "uid_in_one_of_ranges_is_invalid", uidRanges: userIncluded, gidRanges: groupDeniedRangesList, expectedErr: fmt.Sprintf("running as uid %s is disallowed", user.Uid)},
{name: "gid-in-one-of-ranges-is-invalid", uidRanges: userDeniedRangesList, gidRanges: groupIncluded, expectedErr: fmt.Sprintf("running as gid %s is disallowed", user.Gid)},
}

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
user := &user.User{
Uid: "200",
Gid: "200",
}

if tc.uid != "" {
user.Uid = tc.uid
}

if tc.gid != "" {
user.Gid = tc.gid
}

v, err := NewValidator(hclog.NewNullLogger(), tc.uidRanges, tc.gidRanges)
must.NoError(t, err)

err = v.HasValidIDs(user)
err = v.HasValidIDs(user.Username)

if tc.expectedErr == "" {
must.NoError(t, err)
Expand All @@ -101,3 +100,23 @@ func Test_HasValidIds(t *testing.T) {
})
}
}

func Test_ValidateBounds(t *testing.T) {
testCases := []struct {
name string
bounds string
expectedErr error
}{
{name: "invalid_bound", bounds: "banana", expectedErr: ErrInvalidBound},
{name: "invalid_lower_bound", bounds: "banana-10", expectedErr: ErrInvalidBound},
{name: "invalid_upper_bound", bounds: "10-banana", expectedErr: ErrInvalidBound},
{name: "lower_bigger_than_upper", bounds: "10-1", expectedErr: ErrInvalidRange},
}

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
err := validateBounds(tc.bounds)
must.ErrorIs(t, err, tc.expectedErr)
})
}
}
2 changes: 1 addition & 1 deletion drivers/shared/validators/validators_unix.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ func getUserID(user *user.User) (UserID, error) {
return UserID(id), nil
}

func getGroupID(user *user.User) ([]GroupID, error) {
func getGroupsID(user *user.User) ([]GroupID, error) {
gidStrings, err := user.GroupIds()
if err != nil {
return []GroupID{}, fmt.Errorf("unable to lookup user's group membership: %w", err)
Expand Down

0 comments on commit 80e398b

Please sign in to comment.