Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Extend integrity protection of LCOW layers to SCSI devices #1170

Merged
merged 4 commits into from
Oct 20, 2021
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
pr feedback #1: update function name mocks and Mount calls in tests
Signed-off-by: Maksim An <maksiman@microsoft.com>
  • Loading branch information
anmaxvl committed Oct 7, 2021
commit 49e4ed6153b391a0bff837142585f96ff2f922bf
18 changes: 9 additions & 9 deletions internal/guest/storage/pmem/pmem.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,13 @@ import (

// Test dependencies
var (
osMkdirAll = os.MkdirAll
osRemoveAll = os.RemoveAll
unixMount = unix.Mount
mountInternal = mount
createLinearTarget = dm.CreateZeroSectorLinearTarget
veritySetup = dm.CreateVerityTarget
removeDevice = dm.RemoveDevice
osMkdirAll = os.MkdirAll
osRemoveAll = os.RemoveAll
unixMount = unix.Mount
mountInternal = mount
createZeroSectorLinearTarget = dm.CreateZeroSectorLinearTarget
createVerityTargetCalled = dm.CreateVerityTarget
anmaxvl marked this conversation as resolved.
Show resolved Hide resolved
removeDevice = dm.RemoveDevice
)

const (
Expand Down Expand Up @@ -93,7 +93,7 @@ func Mount(ctx context.Context, device uint32, target string, mappingInfo *prot.
// device instead of the original VPMem.
if mappingInfo != nil {
dmLinearName := fmt.Sprintf(linearDeviceFmt, device, mappingInfo.DeviceOffsetInBytes, mappingInfo.DeviceSizeInBytes)
if devicePath, err = createLinearTarget(mCtx, devicePath, dmLinearName, mappingInfo); err != nil {
if devicePath, err = createZeroSectorLinearTarget(mCtx, devicePath, dmLinearName, mappingInfo); err != nil {
return err
}
defer func() {
Expand All @@ -107,7 +107,7 @@ func Mount(ctx context.Context, device uint32, target string, mappingInfo *prot.

if verityInfo != nil {
dmVerityName := fmt.Sprintf(verityDeviceFmt, device, verityInfo.RootDigest)
if devicePath, err = veritySetup(mCtx, devicePath, dmVerityName, verityInfo); err != nil {
if devicePath, err = createVerityTargetCalled(mCtx, devicePath, dmVerityName, verityInfo); err != nil {
return err
}
defer func() {
Expand Down
108 changes: 69 additions & 39 deletions internal/guest/storage/pmem/pmem_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@ func clearTestDependencies() {
osMkdirAll = nil
osRemoveAll = nil
unixMount = nil
createLinearTarget = nil
veritySetup = nil
createZeroSectorLinearTarget = nil
createVerityTargetCalled = nil
removeDevice = nil
mountInternal = mount
}
Expand Down Expand Up @@ -323,7 +323,7 @@ func Test_CreateLinearTarget_And_Mount_Called_With_Correct_Parameters(t *testing
expectedSource := "/dev/pmem0"
expectedTarget := "/foo"
mapperPath := fmt.Sprintf("/dev/mapper/%s", expectedLinearName)
createLTCalled := false
createZSLTCalled := false

osMkdirAll = func(_ string, _ os.FileMode) error {
return nil
Expand All @@ -339,28 +339,33 @@ func Test_CreateLinearTarget_And_Mount_Called_With_Correct_Parameters(t *testing
return nil
}

createLinearTarget = func(_ context.Context, source, name string, mapping *prot.DeviceMappingInfo) (string, error) {
createLTCalled = true
createZeroSectorLinearTarget = func(_ context.Context, source, name string, mapping *prot.DeviceMappingInfo) (string, error) {
createZSLTCalled = true
if source != expectedSource {
t.Errorf("expected createLinearTarget source %s, got %s", expectedSource, source)
t.Errorf("expected createZeroSectorLinearTarget source %s, got %s", expectedSource, source)
}
if name != expectedLinearName {
t.Errorf("expected createLinearTarget name %s, got %s", expectedLinearName, name)
t.Errorf("expected createZeroSectorLinearTarget name %s, got %s", expectedLinearName, name)
}
return mapperPath, nil
}

if err := Mount(
context.Background(), 0, expectedTarget, mappingInfo, nil, openDoorSecurityPolicyEnforcer(),
context.Background(),
0,
expectedTarget,
mappingInfo,
nil,
openDoorSecurityPolicyEnforcer(),
); err != nil {
t.Fatalf("unexpected error during Mount: %s", err)
}
if !createLTCalled {
t.Fatalf("createLinearTarget not called")
if !createZSLTCalled {
t.Fatalf("createZeroSectorLinearTarget not called")
}
}

func Test_VeritySetup_And_Mount_Called_With_Correct_Parameters(t *testing.T) {
func Test_CreateVerityTargetCalled_And_Mount_Called_With_Correct_Parameters(t *testing.T) {
clearTestDependencies()

verityInfo := &prot.DeviceVerityInfo{
Expand All @@ -370,7 +375,7 @@ func Test_VeritySetup_And_Mount_Called_With_Correct_Parameters(t *testing.T) {
expectedSource := "/dev/pmem0"
expectedTarget := "/foo"
mapperPath := fmt.Sprintf("/dev/mapper/%s", expectedVerityName)
veritySetupCalled := false
createVerityTargetCalledCalled := false
anmaxvl marked this conversation as resolved.
Show resolved Hide resolved

mountInternal = func(_ context.Context, source, target string) error {
if source != mapperPath {
Expand All @@ -381,28 +386,33 @@ func Test_VeritySetup_And_Mount_Called_With_Correct_Parameters(t *testing.T) {
}
return nil
}
veritySetup = func(_ context.Context, source, name string, verity *prot.DeviceVerityInfo) (string, error) {
veritySetupCalled = true
createVerityTargetCalled = func(_ context.Context, source, name string, verity *prot.DeviceVerityInfo) (string, error) {
createVerityTargetCalledCalled = true
if source != expectedSource {
t.Errorf("expected veritySetup source %s, got %s", expectedSource, source)
t.Errorf("expected createVerityTargetCalled source %s, got %s", expectedSource, source)
}
if name != expectedVerityName {
t.Errorf("expected veritySetup name %s, got %s", expectedVerityName, name)
t.Errorf("expected createVerityTargetCalled name %s, got %s", expectedVerityName, name)
}
return mapperPath, nil
}

if err := Mount(
context.Background(), 0, expectedTarget, nil, verityInfo, openDoorSecurityPolicyEnforcer(),
context.Background(),
0,
expectedTarget,
nil,
verityInfo,
openDoorSecurityPolicyEnforcer(),
); err != nil {
t.Fatalf("unexpected Mount failure: %s", err)
}
if !veritySetupCalled {
t.Fatal("veritySetup not called")
if !createVerityTargetCalledCalled {
t.Fatal("createVerityTargetCalled not called")
}
}

func Test_CreateLinearTarget_And_VeritySetup_Called_Correctly(t *testing.T) {
func Test_CreateLinearTarget_And_CreateVerityTargetCalled_Called_Correctly(t *testing.T) {
clearTestDependencies()

verityInfo := &prot.DeviceVerityInfo{
Expand All @@ -421,23 +431,23 @@ func Test_CreateLinearTarget_And_VeritySetup_Called_Correctly(t *testing.T) {
dmVerityCalled := false
mountCalled := false

createLinearTarget = func(_ context.Context, source, name string, mapping *prot.DeviceMappingInfo) (string, error) {
createZeroSectorLinearTarget = func(_ context.Context, source, name string, mapping *prot.DeviceMappingInfo) (string, error) {
dmLinearCalled = true
if source != expectedPMemDevice {
t.Errorf("expected createLinearTarget source %s, got %s", expectedPMemDevice, source)
t.Errorf("expected createZeroSectorLinearTarget source %s, got %s", expectedPMemDevice, source)
}
if name != expectedLinearTarget {
t.Errorf("expected createLineartarget name %s, got %s", expectedLinearTarget, name)
t.Errorf("expected createZeroSectorLinearTarget name %s, got %s", expectedLinearTarget, name)
}
return mapperLinearPath, nil
}
veritySetup = func(_ context.Context, source, name string, verity *prot.DeviceVerityInfo) (string, error) {
createVerityTargetCalled = func(_ context.Context, source, name string, verity *prot.DeviceVerityInfo) (string, error) {
dmVerityCalled = true
if source != mapperLinearPath {
t.Errorf("expected veritySetup source %s, got %s", mapperLinearPath, source)
t.Errorf("expected createVerityTargetCalled source %s, got %s", mapperLinearPath, source)
}
if name != expectedVerityTarget {
t.Errorf("expected veritySetup target name %s, got %s", expectedVerityTarget, name)
t.Errorf("expected createVerityTargetCalled target name %s, got %s", expectedVerityTarget, name)
}
return mapperVerityPath, nil
}
Expand All @@ -450,15 +460,20 @@ func Test_CreateLinearTarget_And_VeritySetup_Called_Correctly(t *testing.T) {
}

if err := Mount(
context.Background(), 0, "/foo", mapping, verityInfo, openDoorSecurityPolicyEnforcer(),
context.Background(),
0,
"/foo",
mapping,
verityInfo,
openDoorSecurityPolicyEnforcer(),
); err != nil {
t.Fatalf("unexpected error during Mount call: %s", err)
}
if !dmLinearCalled {
t.Fatal("expected createLinearTarget call")
t.Fatal("expected createZeroSectorLinearTarget call")
}
if !dmVerityCalled {
t.Fatal("expected veritySetup call")
t.Fatal("expected createVerityTargetCalled call")
}
if !mountCalled {
t.Fatal("expected mountInternal call")
Expand All @@ -477,7 +492,7 @@ func Test_RemoveDevice_Called_For_LinearTarget_On_MountInternalFailure(t *testin
mapperPath := fmt.Sprintf("/dev/mapper/%s", expectedTarget)
removeDeviceCalled := false

createLinearTarget = func(_ context.Context, source, name string, mapping *prot.DeviceMappingInfo) (string, error) {
createZeroSectorLinearTarget = func(_ context.Context, source, name string, mapping *prot.DeviceMappingInfo) (string, error) {
return mapperPath, nil
}
mountInternal = func(_ context.Context, source, target string) error {
Expand All @@ -492,7 +507,12 @@ func Test_RemoveDevice_Called_For_LinearTarget_On_MountInternalFailure(t *testin
}

if err := Mount(
context.Background(), 0, "/foo", mappingInfo, nil, openDoorSecurityPolicyEnforcer(),
context.Background(),
0,
"/foo",
mappingInfo,
nil,
openDoorSecurityPolicyEnforcer(),
); err != expectedError {
t.Fatalf("expected Mount error %s, got %s", expectedError, err)
}
Expand All @@ -512,7 +532,7 @@ func Test_RemoveDevice_Called_For_VerityTarget_On_MountInternalFailure(t *testin
mapperPath := fmt.Sprintf("/dev/mapper/%s", expectedVerityTarget)
removeDeviceCalled := false

veritySetup = func(_ context.Context, source, name string, verity *prot.DeviceVerityInfo) (string, error) {
createVerityTargetCalled = func(_ context.Context, source, name string, verity *prot.DeviceVerityInfo) (string, error) {
return mapperPath, nil
}
mountInternal = func(_ context.Context, _, _ string) error {
Expand All @@ -527,7 +547,12 @@ func Test_RemoveDevice_Called_For_VerityTarget_On_MountInternalFailure(t *testin
}

if err := Mount(
context.Background(), 0, "/foo", nil, verity, openDoorSecurityPolicyEnforcer(),
context.Background(),
0,
"/foo",
nil,
verity,
openDoorSecurityPolicyEnforcer(),
); err != expectedError {
t.Fatalf("expected Mount error %s, got %s", expectedError, err)
}
Expand Down Expand Up @@ -555,18 +580,18 @@ func Test_RemoveDevice_Called_For_Both_Targets_On_MountInternalFailure(t *testin
rmLinearCalled := false
rmVerityCalled := false

createLinearTarget = func(_ context.Context, source, name string, m *prot.DeviceMappingInfo) (string, error) {
createZeroSectorLinearTarget = func(_ context.Context, source, name string, m *prot.DeviceMappingInfo) (string, error) {
if source != expectedPMemDevice {
t.Errorf("expected createLinearTarget source %s, got %s", expectedPMemDevice, source)
t.Errorf("expected createZeroSectorLinearTarget source %s, got %s", expectedPMemDevice, source)
}
return mapperLinearPath, nil
}
veritySetup = func(_ context.Context, source, name string, v *prot.DeviceVerityInfo) (string, error) {
createVerityTargetCalled = func(_ context.Context, source, name string, v *prot.DeviceVerityInfo) (string, error) {
if source != mapperLinearPath {
t.Errorf("expected veritySetup to be called with %s, got %s", mapperLinearPath, source)
t.Errorf("expected createVerityTargetCalled to be called with %s, got %s", mapperLinearPath, source)
}
if name != expectedVerityTarget {
t.Errorf("expected veritySetup target %s, got %s", expectedVerityTarget, name)
t.Errorf("expected createVerityTargetCalled target %s, got %s", expectedVerityTarget, name)
}
return mapperVerityPath, nil
}
Expand All @@ -587,7 +612,12 @@ func Test_RemoveDevice_Called_For_Both_Targets_On_MountInternalFailure(t *testin
}

if err := Mount(
context.Background(), 0, "/foo", mapping, verity, openDoorSecurityPolicyEnforcer(),
context.Background(),
0,
"/foo",
mapping,
verity,
openDoorSecurityPolicyEnforcer(),
); err != expectedError {
t.Fatalf("expected Mount error %s, got %s", expectedError, err)
}
Expand Down
6 changes: 3 additions & 3 deletions internal/guest/storage/scsi/scsi.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@ var (

// controllerLunToName is stubbed to make testing `Mount` easier.
controllerLunToName = ControllerLunToName
// veritySetup is stubbed for unit testing `Mount`
veritySetup = dm.CreateVerityTarget
// createVerityTarget is stubbed for unit testing `Mount`
createVerityTarget = dm.CreateVerityTarget
// removeDevice is stubbed for unit testing `Mount`
removeDevice = dm.RemoveDevice
)
Expand Down Expand Up @@ -77,7 +77,7 @@ func Mount(ctx context.Context, controller, lun uint8, target string, readonly b

if verityInfo != nil {
dmVerityName := fmt.Sprintf(verityDeviceFmt, controller, lun, deviceHash)
anmaxvl marked this conversation as resolved.
Show resolved Hide resolved
if source, err = veritySetup(ctx, source, dmVerityName, verityInfo); err != nil {
if source, err = createVerityTarget(spnCtx, source, dmVerityName, verityInfo); err != nil {
return err
}
defer func() {
Expand Down
Loading