From b0985629881299ba6ea6338cb853aa4c8ae81226 Mon Sep 17 00:00:00 2001 From: Johan Siebens Date: Sat, 10 Feb 2024 10:04:44 +0100 Subject: [PATCH] fix: log in with different use should create new machine entry --- internal/domain/machine.go | 4 +-- internal/domain/repository.go | 2 +- internal/handlers/authentication.go | 2 +- internal/handlers/registration.go | 2 +- tests/switch_test.go | 43 +++++++++++++++++++++++++++++ tests/tsn/conditions.go | 7 +++++ tests/tsn/node.go | 2 +- 7 files changed, 56 insertions(+), 6 deletions(-) create mode 100644 tests/switch_test.go diff --git a/internal/domain/machine.go b/internal/domain/machine.go index 479cac00..32720cbd 100644 --- a/internal/domain/machine.go +++ b/internal/domain/machine.go @@ -389,9 +389,9 @@ func (r *repository) GetNextMachineNameIndex(ctx context.Context, tailnetID uint return m.NameIdx + 1, nil } -func (r *repository) GetMachineByKey(ctx context.Context, tailnetID uint64, machineKey string) (*Machine, error) { +func (r *repository) GetMachineByKeyAndUser(ctx context.Context, machineKey string, userID uint64) (*Machine, error) { var m Machine - tx := r.withContext(ctx).Preload("Tailnet").Preload("User").Take(&m, "tailnet_id = ? AND machine_key = ?", tailnetID, machineKey) + tx := r.withContext(ctx).Preload("Tailnet").Preload("User").Take(&m, "machine_key = ? AND user_id = ?", machineKey, userID) if errors.Is(tx.Error, gorm.ErrRecordNotFound) { return nil, nil diff --git a/internal/domain/repository.go b/internal/domain/repository.go index 5e618e90..9779cbc5 100644 --- a/internal/domain/repository.go +++ b/internal/domain/repository.go @@ -58,7 +58,7 @@ type Repository interface { SaveMachine(ctx context.Context, m *Machine) error DeleteMachine(ctx context.Context, id uint64) (bool, error) GetMachine(ctx context.Context, id uint64) (*Machine, error) - GetMachineByKey(ctx context.Context, tailnetID uint64, key string) (*Machine, error) + GetMachineByKeyAndUser(ctx context.Context, key string, userID uint64) (*Machine, error) GetMachineByKeys(ctx context.Context, machineKey string, nodeKey string) (*Machine, error) CountMachinesWithIPv4(ctx context.Context, ip string) (int64, error) GetNextMachineNameIndex(ctx context.Context, tailnetID uint64, name string) (uint64, error) diff --git a/internal/handlers/authentication.go b/internal/handlers/authentication.go index 3f8bf79f..52a63ccc 100644 --- a/internal/handlers/authentication.go +++ b/internal/handlers/authentication.go @@ -446,7 +446,7 @@ func (h *AuthenticationHandlers) endMachineRegistrationFlow(c echo.Context, form var m *domain.Machine - m, err := h.repository.GetMachineByKey(ctx, tailnet.ID, machineKey) + m, err := h.repository.GetMachineByKeyAndUser(ctx, machineKey, user.ID) if err != nil { return logError(err) } diff --git a/internal/handlers/registration.go b/internal/handlers/registration.go index 7393335a..ac610a1d 100644 --- a/internal/handlers/registration.go +++ b/internal/handlers/registration.go @@ -173,7 +173,7 @@ func (h *RegistrationHandlers) authenticateMachineWithAuthKey(c echo.Context, ma var m *domain.Machine - m, err = h.repository.GetMachineByKey(ctx, tailnet.ID, machineKey) + m, err = h.repository.GetMachineByKeyAndUser(ctx, machineKey, user.ID) if err != nil { return logError(err) } diff --git a/tests/switch_test.go b/tests/switch_test.go new file mode 100644 index 00000000..9b2042bd --- /dev/null +++ b/tests/switch_test.go @@ -0,0 +1,43 @@ +package tests + +import ( + api "github.com/jsiebens/ionscale/pkg/gen/ionscale/v1" + "github.com/jsiebens/ionscale/tests/sc" + "github.com/jsiebens/ionscale/tests/tsn" + "github.com/stretchr/testify/require" + "net/http" + "testing" +) + +func TestSwitchAccounts(t *testing.T) { + sc.Run(t, func(s *sc.Scenario) { + s.PushOIDCUser("123", "john@localtest.me", "john") + s.PushOIDCUser("124", "jane@localtest.me", "jane") + + tailnet := s.CreateTailnet() + s.SetIAMPolicy(tailnet.Id, &api.IAMPolicy{Filters: []string{"domain == localtest.me"}}) + + node := s.NewTailscaleNode(sc.WithName("switch")) + + code, err := node.LoginWithOidc() + require.NoError(t, err) + require.Equal(t, http.StatusOK, code) + + require.NoError(t, node.WaitFor(tsn.Connected())) + require.NoError(t, node.Check(tsn.HasUser("john@localtest.me"))) + require.NoError(t, node.Check(tsn.HasName("switch"))) + + code, err = node.LoginWithOidc() + require.NoError(t, err) + require.Equal(t, http.StatusOK, code) + + require.NoError(t, node.WaitFor(tsn.Connected())) + require.NoError(t, node.Check(tsn.HasUser("jane@localtest.me"))) + require.NoError(t, node.Check(tsn.HasName("switch-1"))) + + machines := s.ListMachines(tailnet.Id) + require.Equal(t, 2, len(machines)) + require.Equal(t, "switch", machines[0].Name) + require.Equal(t, "switch-1", machines[1].Name) + }) +} diff --git a/tests/tsn/conditions.go b/tests/tsn/conditions.go index 5352666c..b23d01d8 100644 --- a/tests/tsn/conditions.go +++ b/tests/tsn/conditions.go @@ -2,6 +2,7 @@ package tsn import ( "slices" + "strings" "tailscale.com/ipn/ipnstate" "tailscale.com/tailcfg" "tailscale.com/types/views" @@ -27,6 +28,12 @@ func HasTag(tag string) Condition { } } +func HasName(name string) Condition { + return func(status *ipnstate.Status) bool { + return status.Self != nil && strings.HasPrefix(status.Self.DNSName, name) + } +} + func NeedsMachineAuth() Condition { return func(status *ipnstate.Status) bool { return status.BackendState == "NeedsMachineAuth" diff --git a/tests/tsn/node.go b/tests/tsn/node.go index 59dbfd68..2798668c 100644 --- a/tests/tsn/node.go +++ b/tests/tsn/node.go @@ -47,7 +47,7 @@ func (t *TailscaleNode) LoginWithOidc(flags ...UpFlag) (int, error) { return strings.Contains(stderr, "To authenticate, visit:") } - cmd := []string{"up", "--login-server", t.loginServer} + cmd := []string{"login", "--login-server", t.loginServer} for _, f := range flags { cmd = append(cmd, f...) }