-
Notifications
You must be signed in to change notification settings - Fork 54
Add tests and docs for reverse shell #998
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
3 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,148 @@ | ||
| package c2_test | ||
|
|
||
| import ( | ||
| "context" | ||
| "net" | ||
| "net/http/httptest" | ||
| "strconv" | ||
| "strings" | ||
| "testing" | ||
| "time" | ||
|
|
||
| "github.com/gorilla/websocket" | ||
| "github.com/stretchr/testify/assert" | ||
| "github.com/stretchr/testify/require" | ||
| "gocloud.dev/pubsub" | ||
| _ "gocloud.dev/pubsub/mempubsub" | ||
| "google.golang.org/grpc" | ||
| "google.golang.org/grpc/credentials/insecure" | ||
| "google.golang.org/grpc/test/bufconn" | ||
| "realm.pub/tavern/internal/c2" | ||
| "realm.pub/tavern/internal/c2/c2pb" | ||
| "realm.pub/tavern/internal/ent/enttest" | ||
| "realm.pub/tavern/internal/http/stream" | ||
|
|
||
| _ "github.com/mattn/go-sqlite3" | ||
| ) | ||
|
|
||
| func TestReverseShell_E2E(t *testing.T) { | ||
| // Setup Ent Client | ||
| graph := enttest.Open(t, "sqlite3", "file:ent?mode=memory&cache=shared&_fk=1") | ||
| defer graph.Close() | ||
|
|
||
| ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) | ||
| defer cancel() | ||
|
|
||
| // C2 Server Setup | ||
| lis := bufconn.Listen(1024 * 1024) | ||
| s := grpc.NewServer() | ||
|
|
||
| // Pub/Sub Topics | ||
| // The wsMux will be used by websockets to subscribe to shell output and publish new input. | ||
| // The grpcMux will be used by gRPC to subscribe to shell input and publish new output. | ||
|
|
||
| pubInput, err := pubsub.OpenTopic(ctx, "mem://e2e-input") | ||
| require.NoError(t, err) | ||
| defer pubInput.Shutdown(ctx) | ||
|
|
||
| subInput, err := pubsub.OpenSubscription(ctx, "mem://e2e-input") | ||
| require.NoError(t, err) | ||
| defer subInput.Shutdown(ctx) | ||
|
|
||
| pubOutput, err := pubsub.OpenTopic(ctx, "mem://e2e-output") | ||
| require.NoError(t, err) | ||
| defer pubOutput.Shutdown(ctx) | ||
|
|
||
| subOutput, err := pubsub.OpenSubscription(ctx, "mem://e2e-output") | ||
| require.NoError(t, err) | ||
| defer subOutput.Shutdown(ctx) | ||
|
|
||
| wsMux := stream.NewMux(pubInput, subOutput) | ||
| grpcMux := stream.NewMux(pubOutput, subInput) | ||
|
|
||
| go wsMux.Start(ctx) | ||
| go grpcMux.Start(ctx) | ||
|
|
||
| c2pb.RegisterC2Server(s, c2.New(graph, grpcMux)) | ||
| go func() { | ||
| if err := s.Serve(lis); err != nil { | ||
| t.Logf("Server exited with error: %v", err) | ||
| } | ||
| }() | ||
|
|
||
| // gRPC Client Setup | ||
| conn, err := grpc.DialContext(ctx, "bufnet", grpc.WithContextDialer(func(context.Context, string) (net.Conn, error) { | ||
| return lis.Dial() | ||
| }), grpc.WithTransportCredentials(insecure.NewCredentials())) | ||
| require.NoError(t, err) | ||
| defer conn.Close() | ||
|
|
||
| c2Client := c2pb.NewC2Client(conn) | ||
|
|
||
| // Create test entities | ||
| user, err := graph.User.Create().SetName("test-user").SetOauthID("test-oauth-id").SetPhotoURL("http://example.com/photo.jpg").Save(ctx) | ||
| require.NoError(t, err) | ||
| host, err := graph.Host.Create().SetIdentifier("test-host").SetPlatform(c2pb.Host_PLATFORM_LINUX).Save(ctx) | ||
| require.NoError(t, err) | ||
| beacon, err := graph.Beacon.Create().SetHost(host).Save(ctx) | ||
| require.NoError(t, err) | ||
| tome, err := graph.Tome.Create().SetName("test-tome").SetDescription("test-desc").SetAuthor("test-author").SetEldritch("test-eldritch").SetUploader(user).Save(ctx) | ||
| require.NoError(t, err) | ||
| quest, err := graph.Quest.Create().SetName("test-quest").SetTome(tome).SetCreator(user).Save(ctx) | ||
| require.NoError(t, err) | ||
| task, err := graph.Task.Create().SetQuest(quest).SetBeacon(beacon).Save(ctx) | ||
| require.NoError(t, err) | ||
|
|
||
| // WebSocket Server Setup | ||
| handler := stream.NewShellHandler(graph, wsMux) | ||
| httpServer := httptest.NewServer(handler) | ||
| defer httpServer.Close() | ||
|
|
||
| // gRPC Stream | ||
| gRPCStream, err := c2Client.ReverseShell(ctx) | ||
| require.NoError(t, err) | ||
|
|
||
| // Register gRPC stream with task ID | ||
| err = gRPCStream.Send(&c2pb.ReverseShellRequest{ | ||
| TaskId: int64(task.ID), | ||
| }) | ||
| require.NoError(t, err) | ||
|
|
||
| // Find the shell created by the gRPC service | ||
| var shellID int | ||
| require.Eventually(t, func() bool { | ||
| shells, err := task.QueryShells().All(ctx) | ||
| if err != nil || len(shells) == 0 { | ||
| return false | ||
| } | ||
| shellID = shells[0].ID | ||
| return true | ||
| }, 5*time.Second, 100*time.Millisecond, "shell was not created in time") | ||
|
|
||
| // WebSocket Client Setup | ||
| wsURL := "ws" + strings.TrimPrefix(httpServer.URL, "http") + "?shell_id=" + strconv.Itoa(shellID) | ||
| ws, _, err := websocket.DefaultDialer.Dial(wsURL, nil) | ||
| require.NoError(t, err) | ||
| defer ws.Close() | ||
|
|
||
| // Test message from gRPC to WebSocket | ||
| grpcMsg := []byte("hello from grpc") | ||
| err = gRPCStream.Send(&c2pb.ReverseShellRequest{ | ||
| Kind: c2pb.ReverseShellMessageKind_REVERSE_SHELL_MESSAGE_KIND_DATA, | ||
| Data: grpcMsg, | ||
| }) | ||
| require.NoError(t, err) | ||
|
|
||
| _, wsMsg, err := ws.ReadMessage() | ||
| assert.NoError(t, err) | ||
| assert.Equal(t, grpcMsg, wsMsg) | ||
|
|
||
| // Test message from WebSocket to gRPC | ||
| wsMsgToSend := []byte("hello from websocket") | ||
| err = ws.WriteMessage(websocket.BinaryMessage, wsMsgToSend) | ||
| require.NoError(t, err) | ||
|
|
||
| grpcResp, err := gRPCStream.Recv() | ||
| require.NoError(t, err) | ||
| assert.Equal(t, wsMsgToSend, grpcResp.Data) | ||
| } |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,90 @@ | ||
| package stream_test | ||
|
|
||
| import ( | ||
| "context" | ||
| "testing" | ||
| "time" | ||
|
|
||
| "github.com/stretchr/testify/assert" | ||
| "gocloud.dev/pubsub" | ||
| _ "gocloud.dev/pubsub/mempubsub" | ||
| "realm.pub/tavern/internal/http/stream" | ||
| ) | ||
|
|
||
| func TestPreventPubSubColdStarts_ValidInterval(t *testing.T) { | ||
| ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) | ||
| defer cancel() | ||
|
|
||
| // Create a mock topic and subscription. | ||
| topic, err := pubsub.OpenTopic(ctx, "mem://valid") | ||
| if err != nil { | ||
| t.Fatalf("Failed to open topic: %v", err) | ||
| } | ||
| defer topic.Shutdown(ctx) | ||
| sub, err := pubsub.OpenSubscription(ctx, "mem://valid") | ||
| if err != nil { | ||
| t.Fatalf("Failed to open subscription: %v", err) | ||
| } | ||
| defer sub.Shutdown(ctx) | ||
|
|
||
| go stream.PreventPubSubColdStarts(ctx, 50*time.Millisecond, "mem://valid", "mem://valid") | ||
|
|
||
| // Expect to receive a message | ||
| msg, err := sub.Receive(ctx) | ||
| assert.NoError(t, err) | ||
| assert.NotNil(t, msg) | ||
| if msg != nil { | ||
| assert.Equal(t, "noop", msg.Metadata["id"]) | ||
| msg.Ack() | ||
| } | ||
| } | ||
|
|
||
| func TestPreventPubSubColdStarts_ZeroInterval(t *testing.T) { | ||
| ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) | ||
| defer cancel() | ||
|
|
||
| topic, err := pubsub.OpenTopic(ctx, "mem://zero") | ||
| if err != nil { | ||
| t.Fatalf("Failed to open topic: %v", err) | ||
| } | ||
| defer topic.Shutdown(ctx) | ||
| sub, err := pubsub.OpenSubscription(ctx, "mem://zero") | ||
| if err != nil { | ||
| t.Fatalf("Failed to open subscription: %v", err) | ||
| } | ||
| defer sub.Shutdown(ctx) | ||
|
|
||
| go stream.PreventPubSubColdStarts(ctx, 0, "mem://zero", "mem://zero") | ||
|
|
||
| // Expect to not receive a message and for the context to timeout | ||
| _, err = sub.Receive(ctx) | ||
| assert.Error(t, err) | ||
| assert.Equal(t, context.DeadlineExceeded, err) | ||
| } | ||
|
|
||
| func TestPreventPubSubColdStarts_SubMillisecondInterval(t *testing.T) { | ||
| ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) | ||
| defer cancel() | ||
|
|
||
| topic, err := pubsub.OpenTopic(ctx, "mem://sub") | ||
| if err != nil { | ||
| t.Fatalf("Failed to open topic: %v", err) | ||
| } | ||
| defer topic.Shutdown(ctx) | ||
| sub, err := pubsub.OpenSubscription(ctx, "mem://sub") | ||
| if err != nil { | ||
| t.Fatalf("Failed to open subscription: %v", err) | ||
| } | ||
| defer sub.Shutdown(ctx) | ||
|
|
||
| go stream.PreventPubSubColdStarts(ctx, 1*time.Microsecond, "mem://sub", "mem://sub") | ||
|
|
||
| // Expect to receive a message | ||
| msg, err := sub.Receive(ctx) | ||
| assert.NoError(t, err) | ||
| assert.NotNil(t, msg) | ||
| if msg != nil { | ||
| assert.Equal(t, "noop", msg.Metadata["id"]) | ||
| msg.Ack() | ||
| } | ||
| } |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,96 @@ | ||
| package stream_test | ||
|
|
||
| import ( | ||
| "context" | ||
| "testing" | ||
| "time" | ||
|
|
||
| "github.com/stretchr/testify/assert" | ||
| "github.com/stretchr/testify/require" | ||
| "gocloud.dev/pubsub" | ||
| _ "gocloud.dev/pubsub/mempubsub" | ||
| "realm.pub/tavern/internal/http/stream" | ||
| ) | ||
|
|
||
| func TestMux(t *testing.T) { | ||
| ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) | ||
| defer cancel() | ||
|
|
||
| // Setup Topic and Subscription | ||
| topic, err := pubsub.OpenTopic(ctx, "mem://mux-test") | ||
| require.NoError(t, err) | ||
| defer topic.Shutdown(ctx) | ||
| sub, err := pubsub.OpenSubscription(ctx, "mem://mux-test") | ||
| require.NoError(t, err) | ||
| defer sub.Shutdown(ctx) | ||
|
|
||
| // Create Mux | ||
| mux := stream.NewMux(topic, sub) | ||
| go mux.Start(ctx) | ||
|
|
||
| // Create and Register Streams | ||
| stream1 := stream.New("stream1") | ||
| stream2 := stream.New("stream2") | ||
|
|
||
| mux.Register(stream1) | ||
| mux.Register(stream2) | ||
|
|
||
| // Give the mux a moment to register the streams | ||
| time.Sleep(10 * time.Millisecond) | ||
|
|
||
| // Send a message for stream1 | ||
| err = topic.Send(ctx, &pubsub.Message{ | ||
| Body: []byte("hello stream 1"), | ||
| Metadata: map[string]string{"id": "stream1"}, | ||
| }) | ||
| require.NoError(t, err) | ||
|
|
||
| // Send a message for stream2 | ||
| err = topic.Send(ctx, &pubsub.Message{ | ||
| Body: []byte("hello stream 2"), | ||
| Metadata: map[string]string{"id": "stream2"}, | ||
| }) | ||
| require.NoError(t, err) | ||
|
|
||
| // Send a message with no id | ||
| err = topic.Send(ctx, &pubsub.Message{ | ||
| Body: []byte("no id"), | ||
| }) | ||
| require.NoError(t, err) | ||
|
|
||
| // Assert messages are received by the correct stream | ||
| select { | ||
| case msg1 := <-stream1.Messages(): | ||
| assert.Equal(t, "hello stream 1", string(msg1.Body)) | ||
| case <-time.After(1 * time.Second): | ||
| t.Fatal("stream1 did not receive message in time") | ||
| } | ||
|
|
||
| select { | ||
| case msg2 := <-stream2.Messages(): | ||
| assert.Equal(t, "hello stream 2", string(msg2.Body)) | ||
| case <-time.After(1 * time.Second): | ||
| t.Fatal("stream2 did not receive message in time") | ||
| } | ||
|
|
||
| // Unregister stream1 | ||
| mux.Unregister(stream1) | ||
|
|
||
| // Give the mux a moment to unregister the stream | ||
| time.Sleep(10 * time.Millisecond) | ||
|
|
||
| // Send another message for stream1 | ||
| err = topic.Send(ctx, &pubsub.Message{ | ||
| Body: []byte("goodbye stream 1"), | ||
| Metadata: map[string]string{"id": "stream1"}, | ||
| }) | ||
| require.NoError(t, err) | ||
|
|
||
| // Assert stream1 does not receive the message | ||
| select { | ||
| case <-stream1.Messages(): | ||
| t.Fatal("stream1 received message after being unregistered") | ||
| case <-time.After(100 * time.Millisecond): | ||
| // This is expected | ||
| } | ||
| } | ||
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
After calling
mux.Unregister(stream1)the mux immediately closes the stream’sMessages()channel viaStream.Close(). The finalselectthen reads from that closed channel and immediately enters the first branch, triggeringt.Fataleven though no message was delivered. This makes the test fail deterministically once unregistration succeeds. The assertion should check the second return value from the receive or avoid reading a closed channel rather than assuming any receive implies a published message.Useful? React with 👍 / 👎.