Skip to content

Commit

Permalink
Merge pull request canonical#12203 from MiguelPires/refresh-enforce-o…
Browse files Browse the repository at this point in the history
…verlord

overlord: auto-resolve validation set enforcement constraints
  • Loading branch information
miguelpires authored Oct 3, 2022
2 parents 8d9724d + 23e4fc6 commit f854897
Show file tree
Hide file tree
Showing 26 changed files with 849 additions and 108 deletions.
5 changes: 5 additions & 0 deletions asserts/database.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,11 @@ func (e *NotFoundError) Error() string {
return fmt.Sprintf("%v not found", &Ref{Type: e.Type, PrimaryKey: pk})
}

func (e *NotFoundError) Is(err error) bool {
// TODO: replace IsNotFound usages for errors.Is(err, &NotFoundError{})
return IsNotFound(err)
}

// IsNotFound returns whether err is an assertion not found error.
func IsNotFound(err error) bool {
_, ok := err.(*NotFoundError)
Expand Down
13 changes: 13 additions & 0 deletions asserts/database_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ import (

"github.com/snapcore/snapd/asserts"
"github.com/snapcore/snapd/asserts/assertstest"
"github.com/snapcore/snapd/testutil"
)

func Test(t *testing.T) { TestingT(t) }
Expand Down Expand Up @@ -203,6 +204,18 @@ func (dbs *databaseSuite) TestPublicKeyNotFound(c *C) {
c.Check(err, ErrorMatches, "cannot find key pair")
}

func (dbs *databaseSuite) TestNotFoundErrorIs(c *C) {
this := &asserts.NotFoundError{
Headers: map[string]string{"a": "a"},
Type: asserts.ValidationSetType,
}
that := &asserts.NotFoundError{
Headers: map[string]string{"b": "b"},
Type: asserts.RepairType,
}
c.Check(this, testutil.ErrorIs, that)
}

type checkSuite struct {
bs asserts.Backstore
a asserts.Assertion
Expand Down
31 changes: 13 additions & 18 deletions daemon/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,16 +20,13 @@
package daemon

import (
"context"
"errors"
"fmt"
"net/http"
"strconv"
"strings"

"github.com/gorilla/mux"

"github.com/snapcore/snapd/asserts/snapasserts"
"github.com/snapcore/snapd/overlord/assertstate"
"github.com/snapcore/snapd/overlord/auth"
"github.com/snapcore/snapd/overlord/snapstate"
Expand Down Expand Up @@ -134,21 +131,19 @@ func storeFrom(d *Daemon) snapstate.StoreService {
}

var (
snapstateInstall = snapstate.Install
snapstateInstallPath = snapstate.InstallPath
snapstateInstallPathMany = snapstate.InstallPathMany
snapstateRefreshCandidates = snapstate.RefreshCandidates
snapstateTryPath = snapstate.TryPath
snapstateUpdate = snapstate.Update
snapstateUpdateMany = snapstate.UpdateMany
snapstateInstallMany = snapstate.InstallMany
snapstateRemoveMany = snapstate.RemoveMany
snapstateResolveValSetEnforcementError = func(context.Context, *state.State, *snapasserts.ValidationSetsValidationError, map[string]int, int) ([]*state.TaskSet, []string, error) {
return nil, nil, errors.New("not implemented")
}
snapstateRevert = snapstate.Revert
snapstateRevertToRevision = snapstate.RevertToRevision
snapstateSwitch = snapstate.Switch
snapstateInstall = snapstate.Install
snapstateInstallPath = snapstate.InstallPath
snapstateInstallPathMany = snapstate.InstallPathMany
snapstateRefreshCandidates = snapstate.RefreshCandidates
snapstateTryPath = snapstate.TryPath
snapstateUpdate = snapstate.Update
snapstateUpdateMany = snapstate.UpdateMany
snapstateInstallMany = snapstate.InstallMany
snapstateRemoveMany = snapstate.RemoveMany
snapstateResolveValSetsEnforcementError = snapstate.ResolveValidationSetsEnforcementError
snapstateRevert = snapstate.Revert
snapstateRevertToRevision = snapstate.RevertToRevision
snapstateSwitch = snapstate.Switch

assertstateRefreshSnapAssertions = assertstate.RefreshSnapAssertions
assertstateRestoreValidationSetsTracking = assertstate.RestoreValidationSetsTracking
Expand Down
4 changes: 2 additions & 2 deletions daemon/api_snaps.go
Original file line number Diff line number Diff line change
Expand Up @@ -686,7 +686,7 @@ func snapEnforceValidationSets(inst *snapInstruction, st *state.State) (*snapIns

var tss []*state.TaskSet
var affected []string
err = assertstateTryEnforceValidationSets(st, inst.ValidationSets, inst.userID, snaps, ignoreValidationSnaps)
err = assertstateTryEnforcedValidationSets(st, inst.ValidationSets, inst.userID, snaps, ignoreValidationSnaps)
if err != nil {
vErr, ok := err.(*snapasserts.ValidationSetsValidationError)
if !ok {
Expand Down Expand Up @@ -727,7 +727,7 @@ func meetSnapConstraintsForEnforce(inst *snapInstruction, st *state.State, vErr
pinnedSeqs[fmt.Sprintf("%s/%s", account, name)] = sequence
}

return snapstateResolveValSetEnforcementError(context.TODO(), st, vErr, pinnedSeqs, inst.userID)
return snapstateResolveValSetsEnforcementError(context.TODO(), st, vErr, pinnedSeqs, inst.userID)
}

func snapRemoveMany(inst *snapInstruction, st *state.State) (*snapInstructionResult, error) {
Expand Down
31 changes: 30 additions & 1 deletion daemon/api_snaps_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -495,7 +495,7 @@ func (s *snapsSuite) TestPostSnapsOpSystemRestartImmediate(c *check.C) {

func (s *snapsSuite) testPostSnapsOp(c *check.C, extraJSON, contentType string) (systemRestartImmediate bool) {
defer daemon.MockAssertstateRefreshSnapAssertions(func(*state.State, int, *assertstate.RefreshAssertionsOptions) error { return nil })()
defer daemon.MockSnapstateUpdateMany(func(_ context.Context, s *state.State, names []string, _ []*snapstate.RevisionOptions, userID int, flags *snapstate.Flags) ([]string, []*state.TaskSet, error) {
defer daemon.MockSnapstateUpdateMany(func(_ context.Context, s *state.State, names []string, _ []*snapstate.RevisionOptions, _ int, _ *snapstate.Flags) ([]string, []*state.TaskSet, error) {
c.Check(names, check.HasLen, 0)
t := s.NewTask("fake-refresh-all", "Refreshing everything")
return []string{"fake1", "fake2"}, []*state.TaskSet{state.NewTaskSet(t)}, nil
Expand Down Expand Up @@ -2544,3 +2544,32 @@ func (s *snapsSuite) TestRefreshEnforceSetsNoUnmetConstraints(c *check.C) {
c.Check(resp.Tasksets, check.IsNil)
c.Check(resp.Summary, check.Equals, fmt.Sprintf("Enforce validation sets %s", strutil.Quoted(valsets)))
}

func (s *snapsSuite) TestRefreshEnforceResolveErrorChangeConflictError(c *check.C) {
restore := daemon.MockAssertstateTryEnforceValidationSets(func(st *state.State, validationSets []string, userID int, snaps []*snapasserts.InstalledSnap, ignoreValidation map[string]bool) error {
return &snapasserts.ValidationSetsValidationError{}
})
defer restore()

restore = daemon.MockSnapstateResolveValSetEnforcementError(func(_ context.Context, st *state.State, validErr *snapasserts.ValidationSetsValidationError, pinnedSeqs map[string]int, _ int) ([]*state.TaskSet, []string, error) {
return nil, nil, fmt.Errorf("wrapped error: %w", &snapstate.ChangeConflictError{
Snap: "some-snap",
ChangeID: "12",
ChangeKind: "a-thing",
Message: "conflict with a thing",
})
})
defer restore()

s.daemon(c)

buf := strings.NewReader(`{"action": "refresh", "validation-sets": ["foo/bar"]}`)
req, err := http.NewRequest("POST", "/v2/snaps", buf)
c.Assert(err, check.IsNil)
req.Header.Set("Content-Type", "application/json")

rspe := s.errorReq(c, req, nil)
c.Check(rspe.Status, check.Equals, 409)
c.Check(rspe.Kind, check.Equals, client.ErrorKindSnapChangeConflict)
c.Check(rspe.Message, check.Equals, "conflict with a thing")
}
6 changes: 3 additions & 3 deletions daemon/api_validate.go
Original file line number Diff line number Diff line change
Expand Up @@ -271,8 +271,8 @@ func applyValidationSet(c *Command, r *http.Request, user *auth.UserState) Respo
}

var assertstateMonitorValidationSet = assertstate.MonitorValidationSet
var assertstateEnforceValidationSet = assertstate.EnforceValidationSet
var assertstateTryEnforceValidationSets = assertstate.TryEnforceValidationSets
var assertstateFetchAndApplyEnforcedValidationSet = assertstate.FetchAndApplyEnforcedValidationSet
var assertstateTryEnforcedValidationSets = assertstate.TryEnforcedValidationSets

// updateValidationSet handles snap validate --monitor and --enforce accountId/name[=sequence].
func updateValidationSet(st *state.State, accountID, name string, reqMode string, sequence int, user *auth.UserState) Response {
Expand Down Expand Up @@ -412,7 +412,7 @@ func enforceValidationSet(st *state.State, accountID, name string, sequence, use
if err != nil {
return InternalError(err.Error())
}
tr, err := assertstateEnforceValidationSet(st, accountID, name, sequence, userID, snaps, ignoreValidation)
tr, err := assertstateFetchAndApplyEnforcedValidationSet(st, accountID, name, sequence, userID, snaps, ignoreValidation)
if err != nil {
// XXX: provide more specific error kinds? This would probably require
// assertstate.ValidationSetAssertionForEnforce tuning too.
Expand Down
8 changes: 4 additions & 4 deletions daemon/api_validate_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -790,7 +790,7 @@ func (s *apiValidationSetsSuite) TestApplyValidationSetEnforceMode(c *check.C) {
c.Assert(err, check.IsNil)

var called int
restore := daemon.MockAssertstateEnforceValidationSet(func(st *state.State, accountID, name string, sequence int, userID int, snaps []*snapasserts.InstalledSnap, ignoreValidation map[string]bool) (*assertstate.ValidationSetTracking, error) {
restore := daemon.MockAssertstateFetchEnforceValidationSet(func(st *state.State, accountID, name string, sequence int, userID int, snaps []*snapasserts.InstalledSnap, ignoreValidation map[string]bool) (*assertstate.ValidationSetTracking, error) {
c.Check(ignoreValidation, check.HasLen, 0)
c.Assert(accountID, check.Equals, s.dev1acct.AccountID())
c.Assert(name, check.Equals, "bar")
Expand Down Expand Up @@ -840,7 +840,7 @@ func (s *apiValidationSetsSuite) TestApplyValidationSetEnforceModeIgnoreValidati
c.Assert(err, check.IsNil)

var called int
restore := daemon.MockAssertstateEnforceValidationSet(func(st *state.State, accountID, name string, sequence int, userID int, snaps []*snapasserts.InstalledSnap, ignoreValidation map[string]bool) (*assertstate.ValidationSetTracking, error) {
restore := daemon.MockAssertstateFetchEnforceValidationSet(func(st *state.State, accountID, name string, sequence int, userID int, snaps []*snapasserts.InstalledSnap, ignoreValidation map[string]bool) (*assertstate.ValidationSetTracking, error) {
c.Check(ignoreValidation, check.DeepEquals, map[string]bool{"snap-b": true})
c.Check(snaps, testutil.DeepUnsortedMatches, []*snapasserts.InstalledSnap{
snapasserts.NewInstalledSnap("snap-b", "yOqKhntON3vR7kwEbVPsILm7bUViPDzz", snap.R("1"))})
Expand Down Expand Up @@ -893,7 +893,7 @@ func (s *apiValidationSetsSuite) TestApplyValidationSetEnforceModeSpecificSequen
c.Assert(err, check.IsNil)

var called int
restore := daemon.MockAssertstateEnforceValidationSet(func(st *state.State, accountID, name string, sequence int, userID int, snaps []*snapasserts.InstalledSnap, ignoreValidation map[string]bool) (*assertstate.ValidationSetTracking, error) {
restore := daemon.MockAssertstateFetchEnforceValidationSet(func(st *state.State, accountID, name string, sequence int, userID int, snaps []*snapasserts.InstalledSnap, ignoreValidation map[string]bool) (*assertstate.ValidationSetTracking, error) {
c.Assert(accountID, check.Equals, s.dev1acct.AccountID())
c.Assert(name, check.Equals, "bar")
c.Assert(sequence, check.Equals, 5)
Expand Down Expand Up @@ -932,7 +932,7 @@ func (s *apiValidationSetsSuite) TestApplyValidationSetEnforceModeSpecificSequen
}

func (s *apiValidationSetsSuite) TestApplyValidationSetEnforceModeError(c *check.C) {
restore := daemon.MockAssertstateEnforceValidationSet(func(st *state.State, accountID, name string, sequence int, userID int, snaps []*snapasserts.InstalledSnap, ignoreValidation map[string]bool) (*assertstate.ValidationSetTracking, error) {
restore := daemon.MockAssertstateFetchEnforceValidationSet(func(st *state.State, accountID, name string, sequence int, userID int, snaps []*snapasserts.InstalledSnap, ignoreValidation map[string]bool) (*assertstate.ValidationSetTracking, error) {
return nil, fmt.Errorf("boom")
})
defer restore()
Expand Down
12 changes: 10 additions & 2 deletions daemon/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
package daemon

import (
"errors"
"fmt"
"net"
"net/http"
Expand Down Expand Up @@ -314,8 +315,6 @@ func errToResponse(err error, snaps []string, fallback errorResponder, format st
case *snap.NotInstalledError:
kind = client.ErrorKindSnapNotInstalled
snapName = err.Snap
case *snapstate.ChangeConflictError:
return SnapChangeConflict(err)
case *servicestate.QuotaChangeConflictError:
return QuotaChangeConflict(err)
case *snapstate.SnapNeedsDevModeError:
Expand Down Expand Up @@ -347,6 +346,15 @@ func errToResponse(err error, snaps []string, fallback errorResponder, format st
}
handled = false
default:
// support wrapped errors
switch {
case errors.Is(err, &snapstate.ChangeConflictError{}):
var conflErr *snapstate.ChangeConflictError
if errors.As(err, &conflErr) {
return SnapChangeConflict(conflErr)
}
}

handled = false
}

Expand Down
8 changes: 4 additions & 4 deletions daemon/export_api_validate_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,10 +45,10 @@ func MockAssertstateMonitorValidationSet(f func(st *state.State, accountID, name
}
}

func MockAssertstateEnforceValidationSet(f func(st *state.State, accountID, name string, sequence int, userID int, snaps []*snapasserts.InstalledSnap, ignoreValidation map[string]bool) (*assertstate.ValidationSetTracking, error)) func() {
old := assertstateEnforceValidationSet
assertstateEnforceValidationSet = f
func MockAssertstateFetchEnforceValidationSet(f func(st *state.State, accountID, name string, sequence int, userID int, snaps []*snapasserts.InstalledSnap, ignoreValidation map[string]bool) (*assertstate.ValidationSetTracking, error)) func() {
old := assertstateFetchAndApplyEnforcedValidationSet
assertstateFetchAndApplyEnforcedValidationSet = f
return func() {
assertstateEnforceValidationSet = old
assertstateFetchAndApplyEnforcedValidationSet = old
}
}
10 changes: 5 additions & 5 deletions daemon/export_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -121,8 +121,8 @@ func MockAssertstateRefreshSnapAssertions(mock func(*state.State, int, *assertst
}

func MockAssertstateTryEnforceValidationSets(f func(st *state.State, validationSets []string, userID int, snaps []*snapasserts.InstalledSnap, ignoreValidation map[string]bool) error) (restore func()) {
r := testutil.Backup(&assertstateTryEnforceValidationSets)
assertstateTryEnforceValidationSets = f
r := testutil.Backup(&assertstateTryEnforcedValidationSets)
assertstateTryEnforcedValidationSets = f
return r
}

Expand Down Expand Up @@ -215,10 +215,10 @@ func MockSnapstateInstallPathMany(f func(context.Context, *state.State, []*snap.
}

func MockSnapstateResolveValSetEnforcementError(f func(context.Context, *state.State, *snapasserts.ValidationSetsValidationError, map[string]int, int) ([]*state.TaskSet, []string, error)) func() {
old := snapstateResolveValSetEnforcementError
snapstateResolveValSetEnforcementError = f
old := snapstateResolveValSetsEnforcementError
snapstateResolveValSetsEnforcementError = f
return func() {
snapstateResolveValSetEnforcementError = old
snapstateResolveValSetsEnforcementError = old
}
}

Expand Down
2 changes: 1 addition & 1 deletion overlord/assertstate/assertmgr.go
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ func doValidateSnap(t *state.Task, _ *tomb.Tomb) error {
modelAs := deviceCtx.Model()
expectedProv := snapsup.ExpectedProvenance

err = doFetch(st, snapsup.UserID, deviceCtx, func(f asserts.Fetcher) error {
err = doFetch(st, snapsup.UserID, deviceCtx, nil, func(f asserts.Fetcher) error {
if err := snapasserts.FetchSnapAssertions(f, sha3_384, expectedProv); err != nil {
return err
}
Expand Down
Loading

0 comments on commit f854897

Please sign in to comment.