Currently, our BPF modules send all events from the kernel to userspace, where we filter them and log only the ones related to our process. Sending events between the kernel and user space is expensive, and most of them are discarded afterwards. (gravitational#19354)

This PR moves the filtering from userspace into the kernel, where events can be filtered earlier so we do not pay for sending every event to our userspace process. Because the filtering happens in the kernel, the BPF test had to be rewritten to execute the traced commands in a sub-cgroup instead of the global one.
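The userspace side registers and unregisters session cgroups through the new `startSession`/`endSession` calls; their implementation is not part of the hunks shown below. As a rough sketch (assuming libbpfgo's `GetMap`/`Update`/`DeleteKey` map API and a hypothetical `open` wrapper struct — not the exact code in this commit), registration could look like this:

```go
// Hypothetical sketch, not part of this commit: startSession/endSession are
// assumed to add/remove the session cgroup ID in the monitored_cgroups BPF
// hash map, so the kernel programs can drop events from other cgroups early.
package bpf

import (
	"unsafe"

	"github.com/aquasecurity/libbpfgo"
	"github.com/gravitational/trace"
)

// open wraps the loaded disk (open-events) BPF module.
type open struct {
	module *libbpfgo.Module
}

// startSession marks a cgroup as monitored so its events pass the in-kernel filter.
func (o *open) startSession(cgroupID uint64) error {
	cgroups, err := o.module.GetMap("monitored_cgroups")
	if err != nil {
		return trace.Wrap(err)
	}

	// The value is unused by the kernel program; key presence is what matters.
	val := int64(0)
	if err := cgroups.Update(unsafe.Pointer(&cgroupID), unsafe.Pointer(&val)); err != nil {
		return trace.Wrap(err)
	}
	return nil
}

// endSession removes the cgroup from the monitored set.
func (o *open) endSession(cgroupID uint64) error {
	cgroups, err := o.module.GetMap("monitored_cgroups")
	if err != nil {
		return trace.Wrap(err)
	}

	if err := cgroups.DeleteKey(unsafe.Pointer(&cgroupID)); err != nil {
		return trace.Wrap(err)
	}
	return nil
}
```

The value stored per key is irrelevant; the kernel side only checks for key presence (see the `bpf_map_lookup_elem` guard added to `exit_open` below), which is why a plain `BPF_HASH` keyed on cgroup ID is enough. The `cgroupRegister` value passed around in the rewritten tests presumably exposes exactly these two methods.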
jakule authored Feb 7, 2023
1 parent bb5f828 commit 7fab8fa
Showing 6 changed files with 243 additions and 25 deletions.
20 changes: 18 additions & 2 deletions bpf/enhancedrecording/disk.bpf.c
@@ -14,6 +14,9 @@
// the userspace can adjust this value based on config.
#define EVENTS_BUF_SIZE (4096*128)

// Maximum monitored sessions.
#define MAX_MONITORED_SESSIONS 1024

char LICENSE[] SEC("license") = "Dual BSD/GPL";

struct val_t {
@@ -33,6 +36,9 @@ struct data_t {

BPF_HASH(infotmp, u64, struct val_t, INFLIGHT_MAX);

// Hash map that keeps all cgroup IDs that should be monitored by Teleport.
BPF_HASH(monitored_cgroups, u64, int64_t, MAX_MONITORED_SESSIONS);

// open_events ring buffer
BPF_RING_BUF(open_events, EVENTS_BUF_SIZE);

@@ -52,11 +58,21 @@ static int enter_open(const char *filename, int flags) {

static int exit_open(int ret) {
u64 id = bpf_get_current_pid_tgid();
u64 cgroup = bpf_get_current_cgroup_id();

struct val_t *valp;
struct data_t data = {};
u64 *is_monitored;

valp = bpf_map_lookup_elem(&infotmp, &id);
if (valp == 0) {
if (valp == NULL) {
// Missed entry.
return 0;
}

// Check if the cgroup should be monitored.
is_monitored = bpf_map_lookup_elem(&monitored_cgroups, &cgroup);
if (is_monitored == NULL) {
// Cgroup is not monitored, ignore the event.
return 0;
}
@@ -70,7 +86,7 @@ static int exit_open(int ret) {
data.pid = valp->pid;
data.flags = valp->flags;
data.ret = ret;
data.cgroup = bpf_get_current_cgroup_id();
data.cgroup = cgroup;

if (bpf_ringbuf_output(&open_events, &data, sizeof(data), 0) != 0)
INCR_COUNTER(lost);
18 changes: 14 additions & 4 deletions lib/bpf/bpf.go
@@ -232,6 +232,11 @@ func (s *Service) OpenSession(ctx *SessionContext) (uint64, error) {
return 0, trace.Wrap(err)
}

// Register cgroup in the BPF module.
if err := s.open.startSession(cgroupID); err != nil {
return 0, trace.Wrap(err)
}

// Start watching for any events that come from this cgroup.
s.watch.Add(cgroupID, ctx)

@@ -255,14 +260,19 @@ func (s *Service) CloseSession(ctx *SessionContext) error {
// Stop watching for events from this PID.
s.watch.Remove(cgroupID)

var errs []error
// Move all PIDs to the root cgroup and remove the cgroup created for this
// session.
err = s.cgroup.Remove(ctx.SessionID)
if err != nil {
return trace.Wrap(err)
if err := s.cgroup.Remove(ctx.SessionID); err != nil {
errs = append(errs, trace.Wrap(err))
}

return nil
// Remove the cgroup from BPF module.
if err := s.open.endSession(cgroupID); err != nil {
errs = append(errs, trace.Wrap(err))
}

return trace.NewAggregate(errs...)
}

// processAccessEvents pulls events off the perf ring buffer, parses them, and emits them to
152 changes: 136 additions & 16 deletions lib/bpf/bpf_test.go
@@ -35,13 +35,15 @@ import (

"github.com/aquasecurity/libbpfgo"
"github.com/google/uuid"
"github.com/gravitational/trace"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/testutil"
"github.com/stretchr/testify/require"

"github.com/gravitational/teleport/api/constants"
apidefaults "github.com/gravitational/teleport/api/defaults"
apievents "github.com/gravitational/teleport/api/types/events"
"github.com/gravitational/teleport/lib/cgroup"
"github.com/gravitational/teleport/lib/events/eventstest"
)

@@ -59,16 +61,18 @@ func TestRootWatch(t *testing.T) {
}

// Create temporary directory where cgroup2 hierarchy will be mounted.
dir, err := os.MkdirTemp("", "cgroup-test")
require.NoError(t, err)
defer os.RemoveAll(dir)
cgroupPath := t.TempDir()

// Create BPF service.
service, err := New(&Config{
Enabled: true,
CgroupPath: dir,
CgroupPath: cgroupPath,
}, &RestrictedSessionConfig{})
defer service.Close()
require.NoError(t, err)

t.Cleanup(func() {
require.NoError(t, service.Close())
})

// Create a fake audit log that can be used to capture the events emitted.
emitter := &eventstest.MockEmitter{}
@@ -97,7 +101,7 @@
},
})
require.NoError(t, err)
require.Greater(t, cgroupID, 0)
require.Greater(t, cgroupID, uint64(0))

// Find "ls" binary.
lsPath, err := osexec.LookPath("ls")
@@ -304,6 +308,7 @@ func TestRootPrograms(t *testing.T) {
inCommandArgs []string
inEventCh <-chan []byte
inHTTP bool
verifyFn func(event []byte) bool
}{
// Run execsnoop with "ls".
{
@@ -312,6 +317,11 @@
inCommandArgs: []string{},
inEventCh: execsnoop.events(),
inHTTP: false,
verifyFn: func(event []byte) bool {
var e rawExecEvent
err := unmarshalEvent(event, &e)
return err == nil && ConvertString(unsafe.Pointer(&e.Command)) == "ls"
},
},
// Run opensnoop with "ls". This is fine because "ls" will open some
// shared library.
@@ -321,12 +331,22 @@
inCommandArgs: []string{},
inEventCh: opensnoop.events(),
inHTTP: false,
verifyFn: func(event []byte) bool {
var e rawOpenEvent
err := unmarshalEvent(event, &e)
return err == nil
},
},
// Run tcpconnect with netcat.
{
inName: "tcpconnect",
inEventCh: tcpconnect.v4Events(),
inHTTP: true,
verifyFn: func(event []byte) bool {
var e rawConn4Event
err := unmarshalEvent(event, &e)
return err == nil
},
},
}
for _, tt := range tests {
@@ -337,11 +357,11 @@
// arrive, and once it has, signal over the context that it's complete. The
// second will continue to execute a command or an HTTP GET in a loop attempting to
// trigger an event.
go waitForEvent(doneContext, doneFunc, tt.inEventCh)
go waitForEvent(doneContext, doneFunc, tt.inEventCh, tt.verifyFn)
if tt.inHTTP {
go executeHTTP(t, doneContext, ts.URL)
} else {
go executeCommand(t, doneContext, tt.inCommand)
go executeCommand(t, doneContext, tt.inCommand, opensnoop)
}

// Wait for an event to arrive from execsnoop. If an event does not arrive
@@ -419,19 +439,78 @@ func TestRootBPFCounter(t *testing.T) {

// waitForEvent will wait for an event to arrive over the perf buffer and
// signal when it has.
func waitForEvent(ctx context.Context, cancel context.CancelFunc, eventCh <-chan []byte) {
func waitForEvent(ctx context.Context, cancel context.CancelFunc, eventCh <-chan []byte, verifyFn func(event []byte) bool) {
for {
select {
case <-eventCh:
cancel()
case e := <-eventCh:
if verifyFn(e) {
cancel()
}
case <-ctx.Done():
return
}
}
}

// moveIntoCgroup moves the passed pid into a new cgroup and returns the new cgroup's ID.
func moveIntoCgroup(t *testing.T, pid int) (uint64, error) {
t.Helper()

cgroupPath := t.TempDir()

cgroupSrv, err := cgroup.New(&cgroup.Config{
MountPath: cgroupPath,
})
if err != nil {
return 0, trace.Wrap(err)
}
t.Cleanup(func() {
require.NoError(t, cgroupSrv.Close())
})

sessionID := uuid.New().String()
// Put the cmd in a new cgroup.
cgroupID, err := createCgroup(t, cgroupSrv, sessionID)
if err != nil {
return 0, trace.Wrap(err)
}

// Place requested PID into cgroup.
err = cgroupSrv.Place(sessionID, pid)
if err != nil {
return 0, trace.Wrap(err)
}

t.Cleanup(func() {
err := cgroupSrv.Remove(sessionID)
require.NoError(t, err)
})

return cgroupID, nil
}

// createCgroup is a helper function that creates a cgroup for the given session and returns its ID.
func createCgroup(t *testing.T, cgroup *cgroup.Service, sessionID string,
) (uint64, error) {
t.Helper()

err := cgroup.Create(sessionID)
if err != nil {
return 0, trace.Wrap(err)
}

cgroupID, err := cgroup.ID(sessionID)
if err != nil {
return 0, trace.Wrap(err)
}

return cgroupID, nil
}

// executeCommand will execute some command in a loop.
func executeCommand(t *testing.T, doneContext context.Context, file string) {
func executeCommand(t *testing.T, doneContext context.Context, file string,
traceCgroup cgroupRegister,
) {
t.Helper()

ticker := time.NewTicker(250 * time.Millisecond)
@@ -445,17 +524,58 @@ func executeCommand(t *testing.T, doneContext context.Context, file string) {
if err != nil {
t.Logf("Failed to find executable %q: %v.", file, err)
}
err = osexec.Command(path).Run()
if err != nil {
t.Logf("Failed to run command %q: %v.", file, err)
}

runCmd(t, path, traceCgroup)
case <-doneContext.Done():
return
}
}
}

func runCmd(t *testing.T, cmdName string, traceCgroup cgroupRegister) {
t.Helper()

// Create a pipe to communicate with the child process after re-exec.
readP, writeP, err := os.Pipe()
require.NoError(t, err)

t.Cleanup(func() {
readP.Close()
writeP.Close()
})

path, err := osexec.LookPath(cmdName)
require.NoError(t, err)

// Re-exec the test binary. We can then move the child process into a new cgroup.
cmd := osexec.Command(os.Args[0], reexecInCGroupCmd, path)

cmd.ExtraFiles = append(cmd.ExtraFiles, readP)

// Start the re-exec
err = cmd.Start()
require.NoError(t, err)

cgroupID, err := moveIntoCgroup(t, cmd.Process.Pid)
require.NoError(t, err)

// Register the process in the BPF module
err = traceCgroup.startSession(cgroupID)
require.NoError(t, err)

// Send one byte to continue the subprocess execution.
_, err = writeP.Write([]byte{1})
require.NoError(t, err)

// Wait for the command to exit. Otherwise, we cannot clean up the cgroup.
require.NoError(t, cmd.Wait())

// Remove the registered cgroup from the BPF module. Do not call this after
// the BPF module has been deregistered.
err = traceCgroup.endSession(cgroupID)
require.NoError(t, err)
}

// executeHTTP will perform a HTTP GET to some endpoint in a loop.
func executeHTTP(t *testing.T, doneContext context.Context, endpoint string) {
t.Helper()
34 changes: 34 additions & 0 deletions lib/bpf/common_test.go
@@ -17,7 +17,9 @@ limitations under the License.
package bpf

import (
"io"
"os"
osexec "os/exec"
"testing"

"github.com/stretchr/testify/require"
@@ -26,8 +28,26 @@ import (
"github.com/gravitational/teleport/lib/utils"
)

// reexecInCGroupCmd is a cmd argument used to re-exec the test binary.
const reexecInCGroupCmd = "reexecCgroup"

func TestMain(m *testing.M) {
utils.InitLoggerForTests()

// Check if the re-exec was requested.
if len(os.Args) >= 3 && os.Args[1] == reexecInCGroupCmd {
// Get the command to run passed as the 3rd argument.
cmd := os.Args[2]

if err := waitAndRun(cmd); err != nil {
// Something went wrong, exit with error.
os.Exit(1)
}

// The re-exec was handled and nothing bad happened.
os.Exit(0)
}

os.Exit(m.Run())
}

@@ -79,3 +99,17 @@
require.Equal(t, *tt.outConfig.NetworkBufferSize, *tt.inConfig.NetworkBufferSize)
}
}

// waitAndRun opens FD 3 and waits for at least one byte. After that, it runs the
// passed command and waits until it returns.
func waitAndRun(cmd string) error {
waitFD := os.NewFile(3, "/proc/self/fd/3")
defer waitFD.Close()

buff := make([]byte, 1)
if _, err := waitFD.Read(buff); err != nil && err != io.EOF {
return err
}

return osexec.Command(cmd).Run()
}