Merged
42 changes: 42 additions & 0 deletions docs/_docs/dev-guide/tavern.md
@@ -101,6 +101,48 @@ apt install -y graphviz

Tavern provides an HTTP(S) [gRPC API](https://grpc.io/) that agents may use directly to claim tasks and submit execution results. This is the standard request flow and is supported as a core function of realm. For more information, please consult our [API Specification](https://github.com/spellshift/realm/blob/main/tavern/internal/c2/proto/c2.proto).

### Reverse Shell Architecture

Tavern's reverse shell enables interactive sessions with agents. Because it involves several cooperating components, this section provides a detailed overview of its architecture.

#### Overview

The reverse shell system is designed to be highly scalable and resilient. It uses a combination of gRPC, WebSockets, and a pub/sub messaging system to create a bidirectional communication channel between the agent and the user. The system is designed to work in a distributed fashion, where the server that hosts the WebSocket connection need not be the same server that hosts the gRPC stream.

#### Components

The reverse shell system is composed of the following components:

* **gRPC Server**: The gRPC server is the entry point for the agent. It exposes the `ReverseShell` service, which is a bidirectional gRPC stream. The agent connects to this service to initiate a reverse shell session.
* **WebSocket Server**: The WebSocket server is the entry point for the user. It exposes a WebSocket endpoint that the user can connect to in order to interact with the reverse shell.
* **Pub/Sub Messaging System**: The pub/sub messaging system is the backbone of the reverse shell. It's used to decouple the gRPC server and the WebSocket server, and to provide a reliable and scalable way to transport messages between them. The system uses two topics:
* **Input Topic**: The input topic is used to send messages from the user (via the WebSocket) to the agent (via the gRPC stream).
* **Output Topic**: The output topic is used to send messages from the agent (via the gRPC stream) to the user (via the WebSocket).
* **Mux**: The `Mux` is a multiplexer that sits between the pub/sub system and the gRPC/WebSocket servers. It's responsible for routing messages between the two. There are two `Mux` instances:
* **wsMux**: The `wsMux` is used by the WebSocket server. It subscribes to the output topic and publishes to the input topic.
* **grpcMux**: The `grpcMux` is used by the gRPC server. It subscribes to the input topic and publishes to the output topic.
* **Stream**: A `Stream` represents a single reverse shell session. It's responsible for managing the connection between the `Mux` and the gRPC/WebSocket client.
* **sessionBuffer**: The `sessionBuffer` is used to order messages within a `Stream`. This is important because multiple users can be connected to the same shell session, and their messages need to be delivered in the correct order.
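The component relationships above can be sketched with a toy in-memory mux. This is an illustrative sketch only, not the realm implementation: the real `Mux` is backed by gocloud.dev pub/sub topics, and the `Message` and `Deliver` names here are hypothetical stand-ins that mirror the routing-by-`id` behavior exercised in `mux_test.go` below.

```go
package main

import (
	"fmt"
	"sync"
)

// Message mirrors the shape the mux routes on: a body plus an "id"
// metadata key identifying the destination stream (hypothetical sketch).
type Message struct {
	Body     []byte
	Metadata map[string]string
}

// Stream is a single reverse-shell session endpoint.
type Stream struct {
	id  string
	msg chan Message
}

func NewStream(id string) *Stream {
	return &Stream{id: id, msg: make(chan Message, 8)}
}

func (s *Stream) Messages() <-chan Message { return s.msg }

// Mux fans messages out to registered streams by the "id" metadata key.
type Mux struct {
	mu      sync.Mutex
	streams map[string]*Stream
}

func NewMux() *Mux { return &Mux{streams: make(map[string]*Stream)} }

func (m *Mux) Register(s *Stream) {
	m.mu.Lock()
	defer m.mu.Unlock()
	m.streams[s.id] = s
}

// Unregister removes the stream and closes its channel, so readers
// observe a closed channel rather than blocking forever.
func (m *Mux) Unregister(s *Stream) {
	m.mu.Lock()
	defer m.mu.Unlock()
	delete(m.streams, s.id)
	close(s.msg)
}

// Deliver routes one message; messages without a registered "id" are dropped.
func (m *Mux) Deliver(msg Message) {
	m.mu.Lock()
	defer m.mu.Unlock()
	if s, ok := m.streams[msg.Metadata["id"]]; ok {
		s.msg <- msg
	}
}

func main() {
	mux := NewMux()
	s1 := NewStream("stream1")
	mux.Register(s1)
	mux.Deliver(Message{Body: []byte("hello"), Metadata: map[string]string{"id": "stream1"}})
	fmt.Println(string((<-s1.Messages()).Body)) // prints "hello"
}
```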

#### Communication Flow

1. The agent connects to the `ReverseShell` gRPC service.
2. The gRPC server creates a new `Shell` entity, a new `Stream`, and registers the `Stream` with the `grpcMux`.
3. The user connects to the WebSocket endpoint.
4. The WebSocket server creates a new `Stream` and registers it with the `wsMux`.
5. When the user sends a message, it's sent to the WebSocket server, which publishes it to the input topic via the `wsMux`.
6. The `grpcMux` receives the message from the input topic and sends it to the agent via the gRPC stream.
7. When the agent sends a message, it's sent to the gRPC server, which publishes it to the output topic via the `grpcMux`.
8. The `wsMux` receives the message from the output topic and sends it to the user via the WebSocket.
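The flow above hinges on the mirrored construction of the two muxes, visible in the e2e test below (`stream.NewMux(pubInput, subOutput)` vs. `stream.NewMux(pubOutput, subInput)`): each side publishes to the topic the other side subscribes to. A minimal sketch, with channels standing in for the pub/sub topics:

```go
package main

import "fmt"

// mux pairs the topic this side publishes to with the topic it
// subscribes to (channels stand in for gocloud.dev pub/sub topics).
type mux struct {
	pub chan<- []byte // topic this side publishes to
	sub <-chan []byte // topic this side subscribes to
}

func main() {
	input := make(chan []byte, 1)  // user -> agent keystrokes
	output := make(chan []byte, 1) // agent -> user shell output

	wsMux := mux{pub: input, sub: output}   // cf. NewMux(pubInput, subOutput)
	grpcMux := mux{pub: output, sub: input} // cf. NewMux(pubOutput, subInput)

	wsMux.pub <- []byte("whoami\n")            // step 5: user input published
	fmt.Printf("agent got: %s", <-grpcMux.sub) // step 6: delivered via gRPC stream

	grpcMux.pub <- []byte("root\n")          // step 7: shell output published
	fmt.Printf("user got: %s", <-wsMux.sub) // step 8: delivered via WebSocket
}
```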

#### Distributed Architecture

The reverse shell system is designed to be distributed: because the pub/sub messaging system decouples the gRPC server from the WebSocket server, the two can run on different machines. The system can be scaled horizontally by adding more gRPC or WebSocket servers as needed.

#### Message Ordering and Multi-User Sessions

The reverse shell supports multi-user sessions, where multiple users can be connected to the same shell session. To ensure that messages are delivered in the correct order, the `sessionBuffer` assigns a unique order key to each user's session and orders messages based on that key. This ensures that messages from different users are not interleaved and that they are delivered in the order they were sent.
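A rough sketch of the ordering idea (the field names and `flush` helper are hypothetical; the actual `sessionBuffer` implementation may differ):

```go
package main

import (
	"fmt"
	"sort"
)

// bufferedMsg sketches a buffered shell message: each connected user
// session gets a unique order key, and messages are ordered by
// (orderKey, seq) so two users' input never interleaves.
type bufferedMsg struct {
	orderKey string // unique per connected user session
	seq      int    // arrival order within that session
	data     string
}

// flush returns buffered messages grouped by session and ordered
// within each session.
func flush(pending []bufferedMsg) []string {
	sort.Slice(pending, func(i, j int) bool {
		if pending[i].orderKey != pending[j].orderKey {
			return pending[i].orderKey < pending[j].orderKey
		}
		return pending[i].seq < pending[j].seq
	})
	out := make([]string, 0, len(pending))
	for _, m := range pending {
		out = append(out, m.data)
	}
	return out
}

func main() {
	// Interleaved arrivals from users "a" and "b"...
	pending := []bufferedMsg{
		{"a", 0, "cd"}, {"b", 0, "ls"}, {"a", 1, "/tmp"},
	}
	// ...come out grouped per user, ordered within each session.
	fmt.Println(flush(pending)) // prints "[cd /tmp ls]"
}
```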

If you wish to develop an agent using a different transport method (e.g. DNS), your development will need to include a C2. The role of the C2 is to handle agent communication and translate the chosen transport method into HTTP(S) requests to Tavern's gRPC API. We recommend reusing the existing protobuf definitions for simplicity and forward compatibility. This enables developers to use any transport mechanism with Tavern. If you plan to build a C2 for a common protocol for use with Tavern, consider [submitting a PR](https://github.com/spellshift/realm/pulls).

### Agent Loop Lifecycle
### Agent Loop Lifecycle
2 changes: 1 addition & 1 deletion tavern/internal/c2/api_reverse_shell.go
@@ -117,7 +117,7 @@ func (srv *Server) ReverseShell(gstream c2pb.C2_ReverseShellServer) error {
if err != nil {
slog.ErrorContext(ctx, "reverse shell closed and failed to load ent for updates",
"error", err,
-		"shell_id", shell.ID,
+		"shell_id", shellID,
)
return
}
148 changes: 148 additions & 0 deletions tavern/internal/c2/reverse_shell_e2e_test.go
@@ -0,0 +1,148 @@
package c2_test

import (
"context"
"net"
"net/http/httptest"
"strconv"
"strings"
"testing"
"time"

"github.com/gorilla/websocket"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"gocloud.dev/pubsub"
_ "gocloud.dev/pubsub/mempubsub"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/test/bufconn"
"realm.pub/tavern/internal/c2"
"realm.pub/tavern/internal/c2/c2pb"
"realm.pub/tavern/internal/ent/enttest"
"realm.pub/tavern/internal/http/stream"

_ "github.com/mattn/go-sqlite3"
)

func TestReverseShell_E2E(t *testing.T) {
// Setup Ent Client
graph := enttest.Open(t, "sqlite3", "file:ent?mode=memory&cache=shared&_fk=1")
defer graph.Close()

ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()

// C2 Server Setup
lis := bufconn.Listen(1024 * 1024)
s := grpc.NewServer()

// Pub/Sub Topics
// The wsMux will be used by websockets to subscribe to shell output and publish new input.
// The grpcMux will be used by gRPC to subscribe to shell input and publish new output.

pubInput, err := pubsub.OpenTopic(ctx, "mem://e2e-input")
require.NoError(t, err)
defer pubInput.Shutdown(ctx)

subInput, err := pubsub.OpenSubscription(ctx, "mem://e2e-input")
require.NoError(t, err)
defer subInput.Shutdown(ctx)

pubOutput, err := pubsub.OpenTopic(ctx, "mem://e2e-output")
require.NoError(t, err)
defer pubOutput.Shutdown(ctx)

subOutput, err := pubsub.OpenSubscription(ctx, "mem://e2e-output")
require.NoError(t, err)
defer subOutput.Shutdown(ctx)

wsMux := stream.NewMux(pubInput, subOutput)
grpcMux := stream.NewMux(pubOutput, subInput)

go wsMux.Start(ctx)
go grpcMux.Start(ctx)

c2pb.RegisterC2Server(s, c2.New(graph, grpcMux))
go func() {
if err := s.Serve(lis); err != nil {
t.Logf("Server exited with error: %v", err)
}
}()

// gRPC Client Setup
conn, err := grpc.DialContext(ctx, "bufnet", grpc.WithContextDialer(func(context.Context, string) (net.Conn, error) {
return lis.Dial()
}), grpc.WithTransportCredentials(insecure.NewCredentials()))
require.NoError(t, err)
defer conn.Close()

c2Client := c2pb.NewC2Client(conn)

// Create test entities
user, err := graph.User.Create().SetName("test-user").SetOauthID("test-oauth-id").SetPhotoURL("http://example.com/photo.jpg").Save(ctx)
require.NoError(t, err)
host, err := graph.Host.Create().SetIdentifier("test-host").SetPlatform(c2pb.Host_PLATFORM_LINUX).Save(ctx)
require.NoError(t, err)
beacon, err := graph.Beacon.Create().SetHost(host).Save(ctx)
require.NoError(t, err)
tome, err := graph.Tome.Create().SetName("test-tome").SetDescription("test-desc").SetAuthor("test-author").SetEldritch("test-eldritch").SetUploader(user).Save(ctx)
require.NoError(t, err)
quest, err := graph.Quest.Create().SetName("test-quest").SetTome(tome).SetCreator(user).Save(ctx)
require.NoError(t, err)
task, err := graph.Task.Create().SetQuest(quest).SetBeacon(beacon).Save(ctx)
require.NoError(t, err)

// WebSocket Server Setup
handler := stream.NewShellHandler(graph, wsMux)
httpServer := httptest.NewServer(handler)
defer httpServer.Close()

// gRPC Stream
gRPCStream, err := c2Client.ReverseShell(ctx)
require.NoError(t, err)

// Register gRPC stream with task ID
err = gRPCStream.Send(&c2pb.ReverseShellRequest{
TaskId: int64(task.ID),
})
require.NoError(t, err)

// Find the shell created by the gRPC service
var shellID int
require.Eventually(t, func() bool {
shells, err := task.QueryShells().All(ctx)
if err != nil || len(shells) == 0 {
return false
}
shellID = shells[0].ID
return true
}, 5*time.Second, 100*time.Millisecond, "shell was not created in time")

// WebSocket Client Setup
wsURL := "ws" + strings.TrimPrefix(httpServer.URL, "http") + "?shell_id=" + strconv.Itoa(shellID)
ws, _, err := websocket.DefaultDialer.Dial(wsURL, nil)
require.NoError(t, err)
defer ws.Close()

// Test message from gRPC to WebSocket
grpcMsg := []byte("hello from grpc")
err = gRPCStream.Send(&c2pb.ReverseShellRequest{
Kind: c2pb.ReverseShellMessageKind_REVERSE_SHELL_MESSAGE_KIND_DATA,
Data: grpcMsg,
})
require.NoError(t, err)

_, wsMsg, err := ws.ReadMessage()
assert.NoError(t, err)
assert.Equal(t, grpcMsg, wsMsg)

// Test message from WebSocket to gRPC
wsMsgToSend := []byte("hello from websocket")
err = ws.WriteMessage(websocket.BinaryMessage, wsMsgToSend)
require.NoError(t, err)

grpcResp, err := gRPCStream.Recv()
require.NoError(t, err)
assert.Equal(t, wsMsgToSend, grpcResp.Data)
}
90 changes: 90 additions & 0 deletions tavern/internal/http/stream/gcp_coldstart_test.go
@@ -0,0 +1,90 @@
package stream_test

import (
"context"
"testing"
"time"

"github.com/stretchr/testify/assert"
"gocloud.dev/pubsub"
_ "gocloud.dev/pubsub/mempubsub"
"realm.pub/tavern/internal/http/stream"
)

func TestPreventPubSubColdStarts_ValidInterval(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()

// Create a mock topic and subscription.
topic, err := pubsub.OpenTopic(ctx, "mem://valid")
if err != nil {
t.Fatalf("Failed to open topic: %v", err)
}
defer topic.Shutdown(ctx)
sub, err := pubsub.OpenSubscription(ctx, "mem://valid")
if err != nil {
t.Fatalf("Failed to open subscription: %v", err)
}
defer sub.Shutdown(ctx)

go stream.PreventPubSubColdStarts(ctx, 50*time.Millisecond, "mem://valid", "mem://valid")

// Expect to receive a message
msg, err := sub.Receive(ctx)
assert.NoError(t, err)
assert.NotNil(t, msg)
if msg != nil {
assert.Equal(t, "noop", msg.Metadata["id"])
msg.Ack()
}
}

func TestPreventPubSubColdStarts_ZeroInterval(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
defer cancel()

topic, err := pubsub.OpenTopic(ctx, "mem://zero")
if err != nil {
t.Fatalf("Failed to open topic: %v", err)
}
defer topic.Shutdown(ctx)
sub, err := pubsub.OpenSubscription(ctx, "mem://zero")
if err != nil {
t.Fatalf("Failed to open subscription: %v", err)
}
defer sub.Shutdown(ctx)

go stream.PreventPubSubColdStarts(ctx, 0, "mem://zero", "mem://zero")

// Expect to not receive a message and for the context to timeout
_, err = sub.Receive(ctx)
assert.Error(t, err)
assert.Equal(t, context.DeadlineExceeded, err)
}

func TestPreventPubSubColdStarts_SubMillisecondInterval(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()

topic, err := pubsub.OpenTopic(ctx, "mem://sub")
if err != nil {
t.Fatalf("Failed to open topic: %v", err)
}
defer topic.Shutdown(ctx)
sub, err := pubsub.OpenSubscription(ctx, "mem://sub")
if err != nil {
t.Fatalf("Failed to open subscription: %v", err)
}
defer sub.Shutdown(ctx)

go stream.PreventPubSubColdStarts(ctx, 1*time.Microsecond, "mem://sub", "mem://sub")

// Expect to receive a message
msg, err := sub.Receive(ctx)
assert.NoError(t, err)
assert.NotNil(t, msg)
if msg != nil {
assert.Equal(t, "noop", msg.Metadata["id"])
msg.Ack()
}
}
96 changes: 96 additions & 0 deletions tavern/internal/http/stream/mux_test.go
@@ -0,0 +1,96 @@
package stream_test

import (
"context"
"testing"
"time"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"gocloud.dev/pubsub"
_ "gocloud.dev/pubsub/mempubsub"
"realm.pub/tavern/internal/http/stream"
)

func TestMux(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()

// Setup Topic and Subscription
topic, err := pubsub.OpenTopic(ctx, "mem://mux-test")
require.NoError(t, err)
defer topic.Shutdown(ctx)
sub, err := pubsub.OpenSubscription(ctx, "mem://mux-test")
require.NoError(t, err)
defer sub.Shutdown(ctx)

// Create Mux
mux := stream.NewMux(topic, sub)
go mux.Start(ctx)

// Create and Register Streams
stream1 := stream.New("stream1")
stream2 := stream.New("stream2")

mux.Register(stream1)
mux.Register(stream2)

// Give the mux a moment to register the streams
time.Sleep(10 * time.Millisecond)

// Send a message for stream1
err = topic.Send(ctx, &pubsub.Message{
Body: []byte("hello stream 1"),
Metadata: map[string]string{"id": "stream1"},
})
require.NoError(t, err)

// Send a message for stream2
err = topic.Send(ctx, &pubsub.Message{
Body: []byte("hello stream 2"),
Metadata: map[string]string{"id": "stream2"},
})
require.NoError(t, err)

// Send a message with no id
err = topic.Send(ctx, &pubsub.Message{
Body: []byte("no id"),
})
require.NoError(t, err)

// Assert messages are received by the correct stream
select {
case msg1 := <-stream1.Messages():
assert.Equal(t, "hello stream 1", string(msg1.Body))
case <-time.After(1 * time.Second):
t.Fatal("stream1 did not receive message in time")
}

select {
case msg2 := <-stream2.Messages():
assert.Equal(t, "hello stream 2", string(msg2.Body))
case <-time.After(1 * time.Second):
t.Fatal("stream2 did not receive message in time")
}

// Unregister stream1
mux.Unregister(stream1)

// Give the mux a moment to unregister the stream
time.Sleep(10 * time.Millisecond)

// Send another message for stream1
err = topic.Send(ctx, &pubsub.Message{
Body: []byte("goodbye stream 1"),
Metadata: map[string]string{"id": "stream1"},
})
require.NoError(t, err)

// Assert stream1 does not receive the message
select {
case <-stream1.Messages():
t.Fatal("stream1 received message after being unregistered")
Comment on lines +76 to +92

P1: Avoid treating closed stream channel as received message

After calling mux.Unregister(stream1) the mux immediately closes the stream's Messages() channel via Stream.Close(). The final select then reads from that closed channel and immediately enters the first branch, triggering t.Fatal even though no message was delivered. This makes the test fail deterministically once unregistration succeeds. The assertion should check the second return value from the receive or avoid reading a closed channel rather than assuming any receive implies a published message.

case <-time.After(100 * time.Millisecond):
// This is expected
}
}