Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 49 additions & 38 deletions packages/orchestrator/internal/sandbox/block/tracker.go
Original file line number Diff line number Diff line change
@@ -1,66 +1,77 @@
package block

import (
"context"
"fmt"
"sync"
"sync/atomic"

"github.com/bits-and-blooms/bitset"

"github.com/e2b-dev/infra/packages/shared/pkg/storage/header"
)

type TrackedSliceDevice struct {
data ReadonlyDevice
blockSize int64
type Tracker struct {
b *bitset.BitSet
mu sync.RWMutex

nilTracking atomic.Bool
dirty *bitset.BitSet
dirtyMu sync.Mutex
empty []byte
blockSize int64
}

func NewTrackedSliceDevice(blockSize int64, device ReadonlyDevice) (*TrackedSliceDevice, error) {
return &TrackedSliceDevice{
data: device,
empty: make([]byte, blockSize),
func NewTracker(blockSize int64) *Tracker {
return &Tracker{
// The bitset resizes automatically based on the maximum set bit.
b: bitset.New(0),
blockSize: blockSize,
}, nil
}
}

func (t *TrackedSliceDevice) Disable() error {
size, err := t.data.Size()
if err != nil {
return fmt.Errorf("failed to get device size: %w", err)
func NewTrackerFromBitset(b *bitset.BitSet, blockSize int64) *Tracker {
return &Tracker{
b: b,
blockSize: blockSize,
}
}

t.dirty = bitset.New(uint(header.TotalBlocks(size, t.blockSize)))
// We are starting with all being dirty.
t.dirty.FlipRange(0, t.dirty.Len())

t.nilTracking.Store(true)
func (t *Tracker) Has(off int64) bool {
t.mu.RLock()
defer t.mu.RUnlock()

return nil
return t.b.Test(uint(header.BlockIdx(off, t.blockSize)))
}

func (t *TrackedSliceDevice) Slice(ctx context.Context, off int64, length int64) ([]byte, error) {
if t.nilTracking.Load() {
t.dirtyMu.Lock()
t.dirty.Clear(uint(header.BlockIdx(off, t.blockSize)))
t.dirtyMu.Unlock()
func (t *Tracker) Add(off int64) bool {
t.mu.Lock()
defer t.mu.Unlock()

return t.empty, nil
if t.b.Test(uint(header.BlockIdx(off, t.blockSize))) {
return false
}

return t.data.Slice(ctx, off, length)
t.b.Set(uint(header.BlockIdx(off, t.blockSize)))

return true
}

// Return which bytes were not read since Disable.
// This effectively returns the bytes that have been requested after paused vm and are not dirty.
func (t *TrackedSliceDevice) Dirty() *bitset.BitSet {
t.dirtyMu.Lock()
defer t.dirtyMu.Unlock()
func (t *Tracker) Reset() {
t.mu.Lock()
defer t.mu.Unlock()

return t.dirty.Clone()
t.b.ClearAll()
}

// BitSet returns a clone of the bitset and the block size.
func (t *Tracker) BitSet() *bitset.BitSet {
t.mu.RLock()
defer t.mu.RUnlock()

return t.b.Clone()
}

func (t *Tracker) BlockSize() int64 {
return t.blockSize
}

func (t *Tracker) Clone() *Tracker {
return &Tracker{
b: t.BitSet(),
blockSize: t.BlockSize(),
}
}
109 changes: 109 additions & 0 deletions packages/orchestrator/internal/sandbox/block/tracker_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
package block

import (
"testing"
)

func TestTracker_AddAndHas(t *testing.T) {
const pageSize = 4096
tr := NewTracker(pageSize)

offset := int64(pageSize * 4)

// Initially should not be marked
if tr.Has(offset) {
t.Errorf("Expected offset %d not to be marked initially", offset)
}

// After adding, should be marked
tr.Add(offset)
if !tr.Has(offset) {
t.Errorf("Expected offset %d to be marked after Add", offset)
}
}

func TestTracker_Reset(t *testing.T) {
const pageSize = 4096
tr := NewTracker(pageSize)

offset := int64(pageSize * 4)

// Add offset and verify it's marked
tr.Add(offset)
if !tr.Has(offset) {
t.Errorf("Expected offset %d to be marked after Add", offset)
}

// After reset, should not be marked
tr.Reset()
if tr.Has(offset) {
t.Errorf("Expected offset %d to be cleared after Reset", offset)
}
}

func TestTracker_MultipleOffsets(t *testing.T) {
const pageSize = 4096
tr := NewTracker(pageSize)

offsets := []int64{0, pageSize, 2 * pageSize, 10 * pageSize}

// Add multiple offsets
for _, o := range offsets {
tr.Add(o)
}

// Verify all offsets are marked
for _, o := range offsets {
if !tr.Has(o) {
t.Errorf("Expected offset %d to be marked", o)
}
}
}

func TestTracker_ResetClearsAll(t *testing.T) {
const pageSize = 4096
tr := NewTracker(pageSize)

offsets := []int64{0, pageSize, 2 * pageSize, 10 * pageSize}

// Add multiple offsets
for _, o := range offsets {
tr.Add(o)
}

// Reset should clear all
tr.Reset()

// Verify all offsets are cleared
for _, o := range offsets {
if tr.Has(o) {
t.Errorf("Expected offset %d to be cleared after Reset", o)
}
}
}

func TestTracker_MisalignedOffset(t *testing.T) {
const pageSize = 4096
tr := NewTracker(pageSize)

// Test with misaligned offset
misalignedOffset := int64(123)
tr.Add(misalignedOffset)

// Should be set for the block containing the offset—that is, block 0 (0..4095)
if !tr.Has(misalignedOffset) {
t.Errorf("Expected misaligned offset %d to be marked (should mark its containing block)", misalignedOffset)
}

// Now check that any offset in the same block is also considered marked
anotherOffsetInSameBlock := int64(1000)
if !tr.Has(anotherOffsetInSameBlock) {
t.Errorf("Expected offset %d to be marked as in same block as %d", anotherOffsetInSameBlock, misalignedOffset)
}

// But not for a different block
offsetInNextBlock := int64(pageSize) // convert to int64 to match Has signature
if tr.Has(offsetInNextBlock) {
t.Errorf("Did not expect offset %d to be marked", offsetInNextBlock)
}
}
9 changes: 7 additions & 2 deletions packages/orchestrator/internal/sandbox/sandbox.go
Original file line number Diff line number Diff line change
Expand Up @@ -684,7 +684,7 @@ func (s *Sandbox) Pause(
return nil, fmt.Errorf("failed to pause VM: %w", err)
}

if err := s.memory.Disable(); err != nil {
if err := s.memory.Disable(ctx); err != nil {
return nil, fmt.Errorf("failed to disable uffd: %w", err)
}

Expand Down Expand Up @@ -724,14 +724,19 @@ func (s *Sandbox) Pause(
return nil, fmt.Errorf("failed to get original rootfs: %w", err)
}

dirty, err := s.memory.Dirty(ctx)
if err != nil {
return nil, fmt.Errorf("failed to get dirty pages: %w", err)
}

// Start POSTPROCESSING
memfileDiff, memfileDiffHeader, err := pauseProcessMemory(
ctx,
buildID,
originalMemfile.Header(),
&MemoryDiffCreator{
memfile: memfile,
dirtyPages: s.memory.Dirty(),
dirtyPages: dirty.BitSet(),
blockSize: originalMemfile.BlockSize(),
doneHook: func(context.Context) error {
return memfile.Close()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,14 @@ package uffd
import (
"context"

"github.com/bits-and-blooms/bitset"

"github.com/e2b-dev/infra/packages/orchestrator/internal/sandbox/block"
"github.com/e2b-dev/infra/packages/shared/pkg/utils"
)

type MemoryBackend interface {
Disable() error
Dirty() *bitset.BitSet
Dirty(ctx context.Context) (*block.Tracker, error)
// Disable switch the uffd to start serving empty pages.
Disable(ctx context.Context) error

Start(ctx context.Context, sandboxId string) error
Stop() error
Expand Down
15 changes: 8 additions & 7 deletions packages/orchestrator/internal/sandbox/uffd/noop.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (

"github.com/bits-and-blooms/bitset"

"github.com/e2b-dev/infra/packages/orchestrator/internal/sandbox/block"
"github.com/e2b-dev/infra/packages/shared/pkg/storage/header"
"github.com/e2b-dev/infra/packages/shared/pkg/utils"
)
Expand All @@ -13,7 +14,7 @@ type NoopMemory struct {
size int64
blockSize int64

dirty *bitset.BitSet
dirty *block.Tracker

exit *utils.ErrorOnce
}
Expand All @@ -23,23 +24,23 @@ var _ MemoryBackend = (*NoopMemory)(nil)
func NewNoopMemory(size, blockSize int64) *NoopMemory {
blocks := header.TotalBlocks(size, blockSize)

dirty := bitset.New(uint(blocks))
dirty.FlipRange(0, dirty.Len())
b := bitset.New(uint(blocks))
b.FlipRange(0, b.Len())

return &NoopMemory{
size: size,
blockSize: blockSize,
dirty: dirty,
dirty: block.NewTrackerFromBitset(b, blockSize),
exit: utils.NewErrorOnce(),
}
}

func (m *NoopMemory) Disable() error {
func (m *NoopMemory) Disable(context.Context) error {
return nil
}

func (m *NoopMemory) Dirty() *bitset.BitSet {
return m.dirty
func (m *NoopMemory) Dirty(context.Context) (*block.Tracker, error) {
return m.dirty.Clone(), nil
}

func (m *NoopMemory) Start(context.Context, string) error {
Expand Down
31 changes: 19 additions & 12 deletions packages/orchestrator/internal/sandbox/uffd/uffd.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ import (
"syscall"
"time"

"github.com/bits-and-blooms/bitset"
"go.opentelemetry.io/otel"
"go.uber.org/zap"

Expand All @@ -36,7 +35,8 @@ type Uffd struct {
fdExit *fdexit.FdExit
lis *net.UnixListener
socketPath string
memfile *block.TrackedSliceDevice
memfile block.ReadonlyDevice
dirty *block.Tracker
handler utils.SetOnce[*userfaultfd.Userfaultfd]
}

Expand All @@ -48,17 +48,13 @@ func New(memfile block.ReadonlyDevice, socketPath string, blockSize int64) (*Uff
return nil, fmt.Errorf("failed to create fd exit: %w", err)
}

trackedMemfile, err := block.NewTrackedSliceDevice(blockSize, memfile)
if err != nil {
return nil, fmt.Errorf("failed to create tracked slice device: %w", err)
}

return &Uffd{
exit: utils.NewErrorOnce(),
readyCh: make(chan struct{}, 1),
fdExit: fdExit,
socketPath: socketPath,
memfile: trackedMemfile,
memfile: memfile,
dirty: block.NewTracker(blockSize),
handler: *utils.NewSetOnce[*userfaultfd.Userfaultfd](),
}, nil
}
Expand Down Expand Up @@ -147,6 +143,7 @@ func (u *Uffd) handle(ctx context.Context, sandboxId string) error {
uintptr(fds[0]),
u.memfile,
m,
u.memfile.BlockSize(),
zap.L().With(logger.WithSandboxID(sandboxId)),
)
if err != nil {
Expand Down Expand Up @@ -187,10 +184,20 @@ func (u *Uffd) Exit() *utils.ErrorOnce {
return u.exit
}

func (u *Uffd) Disable() error {
return u.memfile.Disable()
func (u *Uffd) Disable(ctx context.Context) error {
uffd, err := u.handler.WaitWithContext(ctx)
if err != nil {
return fmt.Errorf("failed to get uffd: %w", err)
}

return uffd.Unregister()
}

func (u *Uffd) Dirty() *bitset.BitSet {
return u.memfile.Dirty()
func (u *Uffd) Dirty(ctx context.Context) (*block.Tracker, error) {
uffd, err := u.handler.WaitWithContext(ctx)
if err != nil {
return nil, fmt.Errorf("failed to get uffd: %w", err)
}

return uffd.Dirty()
}
Loading