diff --git a/io/directory.go b/io/directory.go index f35b52ea2..1f67dc204 100644 --- a/io/directory.go +++ b/io/directory.go @@ -430,29 +430,42 @@ func (d *HAMTDirectory) removeFromSizeChange(name string, linkCid cid.Cid) { d.sizeChange -= estimatedLinkSize(name, linkCid) } -// FIXME: Will be extended later to the `AddEntry` case. -func (d *HAMTDirectory) needsToSwitchToBasicDir(ctx context.Context, nameToRemove string) (switchToBasic bool, err error) { +// Evaluate a switch from HAMTDirectory to BasicDirectory in case the size will +// go above the threshold when we are adding or removing an entry. +// In both the add/remove operations any old name will be removed, and for the +// add operation in particular a new entry will be added under that name (otherwise +// nodeToAdd is nil). We compute both (potential) future subtraction and +// addition to the size change. +func (d *HAMTDirectory) needsToSwitchToBasicDir(ctx context.Context, name string, nodeToAdd ipld.Node) (switchToBasic bool, err error) { if HAMTShardingSize == 0 { // Option disabled. return false, nil } - entryToRemove, err := d.shard.Find(ctx, nameToRemove) - if err == os.ErrNotExist { - // Nothing to remove, no point in evaluating a switch. - return false, nil - } else if err != nil { - return false, err + operationSizeChange := 0 + + // Find if there is an old entry under that name that will be overwritten + // (AddEntry) or flat out removed (RemoveEntry). + entryToRemove, err := d.shard.Find(ctx, name) + if err != os.ErrNotExist { + if err != nil { + return false, err + } + operationSizeChange -= estimatedLinkSize(name, entryToRemove.Cid) + } + + // For the AddEntry case compute the size addition of the new entry. + if nodeToAdd != nil { + operationSizeChange += estimatedLinkSize(name, nodeToAdd.Cid()) } - sizeToRemove := estimatedLinkSize(nameToRemove, entryToRemove.Cid) - if d.sizeChange-sizeToRemove >= 0 { + if d.sizeChange+operationSizeChange >= 0 { // We won't have reduced the HAMT net size. return false, nil } // We have reduced the directory size, check if went below the // HAMTShardingSize threshold to trigger a switch. - belowThreshold, err := d.sizeBelowThreshold(ctx, -sizeToRemove) + belowThreshold, err := d.sizeBelowThreshold(ctx, operationSizeChange) if err != nil { return false, err } @@ -511,7 +524,29 @@ var _ Directory = (*UpgradeableDirectory)(nil) // AddChild implements the `Directory` interface. We check when adding new entries // if we should switch to HAMTDirectory according to global option(s). func (d *UpgradeableDirectory) AddChild(ctx context.Context, name string, nd ipld.Node) error { - err := d.Directory.AddChild(ctx, name, nd) + hamtDir, ok := d.Directory.(*HAMTDirectory) + if ok { + // We evaluate a switch in the HAMTDirectory case even for an AddChild + // as it may overwrite an existing entry and end up actually reducing + // the directory size. + switchToBasic, err := hamtDir.needsToSwitchToBasicDir(ctx, name, nd) + if err != nil { + return err + } + + if switchToBasic { + basicDir, err := hamtDir.switchToBasic(ctx) + if err != nil { + return err + } + d.Directory = basicDir + } + return d.Directory.AddChild(ctx, name, nd) + } + + // BasicDirectory + basicDir := d.Directory.(*BasicDirectory) + err := basicDir.AddChild(ctx, name, nd) if err != nil { return err } @@ -520,10 +555,6 @@ func (d *UpgradeableDirectory) AddChild(ctx context.Context, name string, nd ipl if HAMTShardingSize == 0 { return nil } - basicDir, ok := d.Directory.(*BasicDirectory) - if !ok { - return nil - } if basicDir.estimatedSize >= HAMTShardingSize { // Ideally to minimize performance we should check if this last // `AddChild` call would bring the directory size over the threshold @@ -562,7 +593,7 @@ func (d *UpgradeableDirectory) getDagService() ipld.DAGService { func (d *UpgradeableDirectory) RemoveChild(ctx context.Context, name string) error { hamtDir, ok := d.Directory.(*HAMTDirectory) if ok { - switchToBasic, err := hamtDir.needsToSwitchToBasicDir(ctx, name) + switchToBasic, err := hamtDir.needsToSwitchToBasicDir(ctx, name, nil) if err != nil { return err }