Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion api/services/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import (
"time"

"github.com/cnf/structhash"
"github.com/shellhub-io/shellhub/api/store"
"github.com/shellhub-io/shellhub/pkg/api/authorizer"
"github.com/shellhub-io/shellhub/pkg/api/jwttoken"
"github.com/shellhub-io/shellhub/pkg/api/requests"
Expand Down Expand Up @@ -166,7 +167,7 @@ func (s *service) AuthDevice(ctx context.Context, req requests.DeviceAuth, remot
}
}

dev, err := s.store.DeviceGetByUID(ctx, models.UID(device.UID), device.TenantID)
dev, err := s.store.DeviceResolve(ctx, store.DeviceUIDResolver, device.UID, s.store.Options().InNamespace(device.TenantID))
if err != nil {
return nil, NewErrDeviceNotFound(models.UID(device.UID), err)
}
Expand Down
20 changes: 17 additions & 3 deletions api/services/auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ import (

func TestAuthDevice(t *testing.T) {
storeMock := new(mocks.Store)
queryOptionsMock := new(mocks.QueryOptions)
storeMock.On("Options").Return(queryOptionsMock)
cacheMock := new(mockcache.Cache)

clockMock := new(clockmock.Clock)
Expand Down Expand Up @@ -203,8 +205,12 @@ func TestAuthDevice(t *testing.T) {
On("SessionSetLastSeen", ctx, models.UID("session")).
Return(nil).
Once()
queryOptionsMock.
On("InNamespace", "tenant").
Return(nil).
Once()
storeMock.
On("DeviceGetByUID", ctx, testifymock.Anything, "tenant").
On("DeviceResolve", ctx, store.DeviceUIDResolver, testifymock.Anything, testifymock.AnythingOfType("store.QueryOption")).
Return(nil, goerrors.New("device not found")).
Once()
},
Expand Down Expand Up @@ -234,8 +240,12 @@ func TestAuthDevice(t *testing.T) {
On("DeviceCreate", ctx, testifymock.Anything, req.Hostname).
Return(nil).
Once()
queryOptionsMock.
On("InNamespace", "tenant").
Return(nil).
Once()
storeMock.
On("DeviceGetByUID", ctx, testifymock.Anything, "tenant").
On("DeviceResolve", ctx, store.DeviceUIDResolver, testifymock.Anything, testifymock.AnythingOfType("store.QueryOption")).
Return(&models.Device{
UID: key,
Name: "device-name",
Expand Down Expand Up @@ -263,8 +273,12 @@ func TestAuthDevice(t *testing.T) {
On("DeviceCreate", ctx, testifymock.Anything, req.Hostname).
Return(nil).
Once()
queryOptionsMock.
On("InNamespace", "tenant").
Return(nil).
Once()
storeMock.
On("DeviceGetByUID", ctx, testifymock.Anything, "tenant").
On("DeviceResolve", ctx, store.DeviceUIDResolver, testifymock.Anything, testifymock.AnythingOfType("store.QueryOption")).
Return(&models.Device{
UID: key,
Name: "device-name",
Expand Down
36 changes: 21 additions & 15 deletions api/services/device.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ func (s *service) ListDevices(ctx context.Context, req *requests.DeviceList) ([]
}

func (s *service) GetDevice(ctx context.Context, uid models.UID) (*models.Device, error) {
device, err := s.store.DeviceGet(ctx, uid)
device, err := s.store.DeviceResolve(ctx, store.DeviceUIDResolver, string(uid))
if err != nil {
return nil, NewErrDeviceNotFound(uid, err)
}
Expand All @@ -98,9 +98,9 @@ func (s *service) ResolveDevice(ctx context.Context, req *requests.ResolveDevice
var device *models.Device
switch {
case req.UID != "":
device, err = s.store.DeviceResolve(ctx, n.TenantID, store.DeviceUIDResolver, req.UID)
device, err = s.store.DeviceResolve(ctx, store.DeviceUIDResolver, req.UID, s.store.Options().InNamespace(n.TenantID))
case req.Hostname != "":
device, err = s.store.DeviceResolve(ctx, n.TenantID, store.DeviceHostnameResolver, req.Hostname)
device, err = s.store.DeviceResolve(ctx, store.DeviceHostnameResolver, req.Hostname, s.store.Options().InNamespace(n.TenantID))
}

if err != nil {
Expand All @@ -120,7 +120,7 @@ func (s *service) ResolveDevice(ctx context.Context, req *requests.ResolveDevice
// NewErrNamespaceNotFound(tenant, err), if the usage cannot be reported, ErrReport or if the store function that
// delete the device fails.
func (s *service) DeleteDevice(ctx context.Context, uid models.UID, tenant string) error {
device, err := s.store.DeviceGetByUID(ctx, uid, tenant)
device, err := s.store.DeviceResolve(ctx, store.DeviceUIDResolver, string(uid), s.store.Options().InNamespace(tenant))
if err != nil {
return NewErrDeviceNotFound(uid, err)
}
Expand All @@ -143,7 +143,7 @@ func (s *service) DeleteDevice(ctx context.Context, uid models.UID, tenant strin
}

func (s *service) RenameDevice(ctx context.Context, uid models.UID, name, tenant string) error {
device, err := s.store.DeviceGetByUID(ctx, uid, tenant)
device, err := s.store.DeviceResolve(ctx, store.DeviceUIDResolver, string(uid), s.store.Options().InNamespace(tenant))
if err != nil {
return NewErrDeviceNotFound(uid, err)
}
Expand Down Expand Up @@ -173,7 +173,7 @@ func (s *service) RenameDevice(ctx context.Context, uid models.UID, name, tenant
return nil
}

otherDevice, err := s.store.DeviceGetByName(ctx, updatedDevice.Name, tenant, models.DeviceStatusAccepted)
otherDevice, err := s.store.DeviceResolve(ctx, store.DeviceHostnameResolver, updatedDevice.Name, s.store.Options().WithDeviceStatus(models.DeviceStatusAccepted), s.store.Options().InNamespace(tenant))
if err != nil && err != store.ErrNoDocuments {
return NewErrDeviceNotFound(models.UID(updatedDevice.UID), err)
}
Expand All @@ -190,9 +190,14 @@ func (s *service) RenameDevice(ctx context.Context, uid models.UID, name, tenant
// It receives a context, used to "control" the request flow and, the namespace name from a models.Namespace and a
// device name from models.Device.
func (s *service) LookupDevice(ctx context.Context, namespace, name string) (*models.Device, error) {
device, err := s.store.DeviceLookup(ctx, namespace, name)
n, err := s.store.NamespaceGetByName(ctx, namespace)
if err != nil {
return nil, NewErrNamespaceNotFound(namespace, err)
}

device, err := s.store.DeviceResolve(ctx, store.DeviceHostnameResolver, name, s.store.Options().InNamespace(n.TenantID))
if err != nil || device == nil {
return nil, NewErrDeviceLookupNotFound(namespace, name, err)
return nil, NewErrDeviceNotFound(models.UID(name), err)
}

return device, nil
Expand All @@ -218,7 +223,7 @@ func (s *service) UpdateDeviceStatus(ctx context.Context, tenant string, uid mod
return NewErrNamespaceNotFound(tenant, err)
}

device, err := s.store.DeviceGetByUID(ctx, uid, tenant)
device, err := s.store.DeviceResolve(ctx, store.DeviceUIDResolver, string(uid), s.store.Options().InNamespace(tenant))
if err != nil {
return NewErrDeviceNotFound(uid, err)
}
Expand All @@ -241,15 +246,16 @@ func (s *service) UpdateDeviceStatus(ctx context.Context, tenant string, uid mod

// NOTICE: when there is an already accepted device with the same MAC address, we need to update the device UID
// transfer the sessions and delete the old device.
sameMacDev, err := s.store.DeviceGetByMac(ctx, device.Identity.MAC, device.TenantID, models.DeviceStatusAccepted)
sameMacDev, err := s.store.DeviceResolve(ctx, store.DeviceMACResolver, device.Identity.MAC, s.store.Options().WithDeviceStatus(models.DeviceStatusAccepted), s.store.Options().InNamespace(device.TenantID))
if err != nil && err != store.ErrNoDocuments {
return NewErrDeviceNotFound(models.UID(device.UID), err)
}

// TODO: move this logic to store's transactions.
if sameMacDev != nil && sameMacDev.UID != device.UID {
if sameName, err := s.store.DeviceGetByName(ctx, device.Name, device.TenantID, models.DeviceStatusAccepted); sameName != nil && sameName.Identity.MAC != device.Identity.MAC {
return NewErrDeviceDuplicated(device.Name, err)
sameDevice, _ := s.store.DeviceResolve(ctx, store.DeviceHostnameResolver, device.Name, s.store.Options().WithDeviceStatus(models.DeviceStatusAccepted), s.store.Options().InNamespace(device.TenantID))
if sameDevice != nil && sameDevice.Identity.MAC != device.Identity.MAC {
return NewErrDeviceDuplicated(device.Name, nil)
}

if err := s.store.SessionUpdateDeviceUID(ctx, models.UID(sameMacDev.UID), models.UID(device.UID)); err != nil && err != store.ErrNoDocuments {
Expand All @@ -267,8 +273,8 @@ func (s *service) UpdateDeviceStatus(ctx context.Context, tenant string, uid mod
return s.store.DeviceUpdateStatus(ctx, uid, status)
}

if sameName, err := s.store.DeviceGetByName(ctx, device.Name, device.TenantID, models.DeviceStatusAccepted); sameName != nil {
return NewErrDeviceDuplicated(device.Name, err)
if sameDevice, _ := s.store.DeviceResolve(ctx, store.DeviceHostnameResolver, device.Name, s.store.Options().WithDeviceStatus(models.DeviceStatusAccepted), s.store.Options().InNamespace(device.TenantID)); sameDevice != nil {
return NewErrDeviceDuplicated(device.Name, nil)
}

if status != models.DeviceStatusAccepted {
Expand Down Expand Up @@ -322,7 +328,7 @@ func (s *service) UpdateDeviceStatus(ctx context.Context, tenant string, uid mod
}

func (s *service) UpdateDevice(ctx context.Context, req *requests.DeviceUpdate) error {
device, err := s.store.DeviceGetByUID(ctx, models.UID(req.UID), req.TenantID)
device, err := s.store.DeviceResolve(ctx, store.DeviceUIDResolver, req.UID, s.store.Options().InNamespace(req.TenantID))
if err != nil {
return NewErrDeviceNotFound(models.UID(req.UID), err)
}
Expand Down
7 changes: 4 additions & 3 deletions api/services/device_tags.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package services
import (
"context"

"github.com/shellhub-io/shellhub/api/store"
"github.com/shellhub-io/shellhub/pkg/models"
)

Expand All @@ -23,7 +24,7 @@ const DeviceMaxTags = 3
// If the device already has the maximum number of tags, a NewErrTagLimit error will be returned.
// A unknown error will be returned if the tag is not created.
func (s *service) CreateDeviceTag(ctx context.Context, uid models.UID, tag string) error {
device, err := s.store.DeviceGet(ctx, uid)
device, err := s.store.DeviceResolve(ctx, store.DeviceUIDResolver, string(uid))
if err != nil || device == nil {
return NewErrDeviceNotFound(uid, err)
}
Expand All @@ -45,7 +46,7 @@ func (s *service) CreateDeviceTag(ctx context.Context, uid models.UID, tag strin
// If the tag does not exist, a NewErrTagNotFound error will be returned.
// A unknown error will be returned if the tag is not removed.
func (s *service) RemoveDeviceTag(ctx context.Context, uid models.UID, tag string) error {
device, err := s.store.DeviceGet(ctx, uid)
device, err := s.store.DeviceResolve(ctx, store.DeviceUIDResolver, string(uid))
if err != nil || device == nil {
return NewErrDeviceNotFound(uid, err)
}
Expand All @@ -67,7 +68,7 @@ func (s *service) UpdateDeviceTag(ctx context.Context, uid models.UID, tags []st
return NewErrTagLimit(DeviceMaxTags, nil)
}

if _, err := s.store.DeviceGet(ctx, uid); err != nil {
if _, err := s.store.DeviceResolve(ctx, store.DeviceUIDResolver, string(uid)); err != nil {
return NewErrDeviceNotFound(uid, err)
}

Expand Down
20 changes: 10 additions & 10 deletions api/services/device_tags_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ func TestCreateTag(t *testing.T) {
uid: "invalid_uid",
deviceName: "device1",
requiredMocks: func() {
mock.On("DeviceGet", ctx, models.UID("invalid_uid")).Return(nil, errors.New("error", "", 0)).Once()
mock.On("DeviceResolve", ctx, store.DeviceUIDResolver, "invalid_uid").Return(nil, errors.New("error", "", 0)).Once()
},
expected: NewErrDeviceNotFound(models.UID("invalid_uid"), errors.New("error", "", 0)),
},
Expand All @@ -49,7 +49,7 @@ func TestCreateTag(t *testing.T) {
Tags: []string{"device1"},
}

mock.On("DeviceGet", ctx, models.UID("uid")).Return(device, nil).Once()
mock.On("DeviceResolve", ctx, store.DeviceUIDResolver, "uid").Return(device, nil).Once()
},
expected: NewErrTagDuplicated("device1", nil),
},
Expand All @@ -64,7 +64,7 @@ func TestCreateTag(t *testing.T) {
Tags: []string{"device1"},
}

mock.On("DeviceGet", ctx, models.UID(device.UID)).Return(device, nil).Once()
mock.On("DeviceResolve", ctx, store.DeviceUIDResolver, "uid").Return(device, nil).Once()
mock.On("DevicePushTag", ctx, models.UID(device.UID), "device6").Return(nil).Once()
},
expected: nil,
Expand Down Expand Up @@ -103,7 +103,7 @@ func TestRemoveTag(t *testing.T) {
uid: "invalid_uid",
deviceName: "device1",
requiredMocks: func() {
mock.On("DeviceGet", ctx, models.UID("invalid_uid")).Return(nil, errors.New("error", "", 0)).Once()
mock.On("DeviceResolve", ctx, store.DeviceUIDResolver, "invalid_uid").Return(nil, errors.New("error", "", 0)).Once()
},
expected: NewErrDeviceNotFound(models.UID("invalid_uid"), errors.New("error", "", 0)),
},
Expand All @@ -118,7 +118,7 @@ func TestRemoveTag(t *testing.T) {
Tags: []string{"device1"},
}

mock.On("DeviceGet", ctx, models.UID("uid")).Return(device, nil).Once()
mock.On("DeviceResolve", ctx, store.DeviceUIDResolver, "uid").Return(device, nil).Once()
},
expected: NewErrTagNotFound("device2", nil),
},
Expand All @@ -133,7 +133,7 @@ func TestRemoveTag(t *testing.T) {
Tags: []string{"device1"},
}

mock.On("DeviceGet", ctx, models.UID("uid")).Return(device, nil).Once()
mock.On("DeviceResolve", ctx, store.DeviceUIDResolver, "uid").Return(device, nil).Once()
mock.On("DevicePullTag", ctx, models.UID("uid"), "device1").Return(errors.New("error", "", 0)).Once()
},
expected: errors.New("error", "", 0),
Expand All @@ -149,7 +149,7 @@ func TestRemoveTag(t *testing.T) {
Tags: []string{"device1"},
}

mock.On("DeviceGet", ctx, models.UID("uid")).Return(device, nil).Once()
mock.On("DeviceResolve", ctx, store.DeviceUIDResolver, "uid").Return(device, nil).Once()
mock.On("DevicePullTag", ctx, models.UID("uid"), "device1").Return(nil).Once()
},
expected: nil,
Expand Down Expand Up @@ -194,7 +194,7 @@ func TestDeviceUpdateTag(t *testing.T) {
uid: models.UID("2300230e3ca2f637636b4d025d2235269014865db5204b6d115386cbee89809c"),
tags: []string{"device1", "device2", "device3"},
requiredMocks: func() {
storemock.On("DeviceGet", context.TODO(), models.UID("2300230e3ca2f637636b4d025d2235269014865db5204b6d115386cbee89809c")).Return(nil, errors.New("error", "", 0)).Once()
storemock.On("DeviceResolve", context.TODO(), store.DeviceUIDResolver, "2300230e3ca2f637636b4d025d2235269014865db5204b6d115386cbee89809c").Return(nil, errors.New("error", "", 0)).Once()
},
expected: NewErrDeviceNotFound("2300230e3ca2f637636b4d025d2235269014865db5204b6d115386cbee89809c", errors.New("error", "", 0)),
},
Expand All @@ -207,7 +207,7 @@ func TestDeviceUpdateTag(t *testing.T) {
UID: "2300230e3ca2f637636b4d025d2235269014865db5204b6d115386cbee89809c",
TenantID: "tenant",
}
storemock.On("DeviceGet", context.TODO(), models.UID("2300230e3ca2f637636b4d025d2235269014865db5204b6d115386cbee89809c")).Return(device, nil).Once()
storemock.On("DeviceResolve", context.TODO(), store.DeviceUIDResolver, "2300230e3ca2f637636b4d025d2235269014865db5204b6d115386cbee89809c").Return(device, nil).Once()

tags := []string{"device1", "device2", "device3"}
storemock.On("DeviceSetTags", context.TODO(), models.UID("2300230e3ca2f637636b4d025d2235269014865db5204b6d115386cbee89809c"), tags).Return(int64(0), int64(0), errors.New("error", "layer", 1)).Once()
Expand All @@ -223,7 +223,7 @@ func TestDeviceUpdateTag(t *testing.T) {
UID: "2300230e3ca2f637636b4d025d2235269014865db5204b6d115386cbee89809c",
TenantID: "tenant",
}
storemock.On("DeviceGet", context.TODO(), models.UID("2300230e3ca2f637636b4d025d2235269014865db5204b6d115386cbee89809c")).Return(device, nil).Once()
storemock.On("DeviceResolve", context.TODO(), store.DeviceUIDResolver, "2300230e3ca2f637636b4d025d2235269014865db5204b6d115386cbee89809c").Return(device, nil).Once()

tags := []string{"device1", "device2", "device3"}
storemock.On("DeviceSetTags", context.TODO(), models.UID("2300230e3ca2f637636b4d025d2235269014865db5204b6d115386cbee89809c"), tags).Return(int64(1), int64(3), nil).Once()
Expand Down
Loading
Loading