Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions implants/lib/pb/src/generated/portal.rs

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

20 changes: 20 additions & 0 deletions tavern/internal/c2/api_create_portal.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,9 @@ func (srv *Server) CreatePortal(gstream c2pb.C2_CreatePortalServer) error {
recv, cleanup := srv.portalMux.Subscribe(portalInTopic)
defer cleanup()

// Send CLOSE
defer sendPortalClose(ctx, srv.portalMux, portalID)

// Start goroutine to subscribe to portal input and send to gRPC stream
ctx, cancel := context.WithCancel(ctx)
var wg sync.WaitGroup
Expand Down Expand Up @@ -149,3 +152,20 @@ func sendPortalInput(ctx context.Context, portalID int, gstream c2pb.C2_CreatePo
}
}
}

// sendPortalClose publishes a CLOSE mote on the portal's output topic so
// that any subscribers (e.g. connected user streams) learn the portal has
// shut down. Publish failures are logged but not propagated, since this is
// invoked as best-effort cleanup (typically via defer).
func sendPortalClose(ctx context.Context, mux *mux.Mux, portalID int) {
	// Assemble the close notification before publishing.
	closeMote := &portalpb.Mote{
		Payload: &portalpb.Mote_Bytes{
			Bytes: &portalpb.BytesPayload{
				Data: []byte("portal closed"),
				Kind: portalpb.BytesPayloadKind_BYTES_PAYLOAD_KIND_CLOSE,
			},
		},
	}

	topic := mux.TopicOut(portalID)
	err := mux.Publish(ctx, topic, closeMote)
	if err == nil {
		return
	}
	slog.ErrorContext(ctx, "failed to notify subscribers that portal closed",
		"portal_id", portalID,
		"error", err,
	)
}
153 changes: 43 additions & 110 deletions tavern/internal/portals/api_open_portal.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@ package portals

import (
"context"
"fmt"
"log/slog"
"sync"

"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
Expand Down Expand Up @@ -43,27 +43,47 @@ func (srv *Server) OpenPortal(gstream portalpb.Portal_OpenPortalServer) error {
defer cleanup()

portalOutTopic := srv.mux.TopicOut(portalID)
// portalOutSub := srv.mux.SubName(portalOutTopic)
recv, cleanup := srv.mux.Subscribe(portalOutTopic)
defer cleanup()

done := make(chan struct{}, 2)

// Start goroutine to subscribe to portal output and send to gRPC stream
ctx, cancel := context.WithCancel(ctx)
var wg sync.WaitGroup
wg.Add(1)
go func(ctx context.Context) {
defer wg.Done()
go func() {
sendPortalOutput(ctx, portalID, gstream, recv)
}(ctx)
done <- struct{}{}
}()

// Send portal input from gRPC stream to portal input topic
go func() {
sendPortalInput(ctx, portalID, gstream, srv.mux)
done <- struct{}{}
}()

select {
case <-ctx.Done():
return nil
case <-done:
return fmt.Errorf("portal closed")
}

// ctx, cancel := context.WithCancel(ctx)
// var wg sync.WaitGroup
// wg.Add(1)
// go func(ctx context.Context) {
// defer wg.Done()
// defer cancel()
// sendPortalOutput(ctx, portalID, gstream, recv)
// }(ctx)

// Send portal input from gRPC stream to portal input topic
sendPortalInput(ctx, portalID, gstream, srv.mux)
// sendPortalInput(ctx, portalID, gstream, srv.mux)

// Cleanup
cancel()
wg.Wait()
// // Cleanup
// cancel()
// wg.Wait()

return nil
// return nil
}

func sendPortalInput(ctx context.Context, portalID int, gstream portalpb.Portal_OpenPortalServer, mux *mux.Mux) {
Expand Down Expand Up @@ -118,7 +138,11 @@ func sendPortalOutput(ctx context.Context, portalID int, gstream portalpb.Portal
select {
case <-ctx.Done():
return
case mote := <-recv:
case mote, ok := <-recv:
if !ok {
return
}

// TRACE: Server User Sub
if err := AddTraceEvent(mote, tracepb.TraceEventKind_TRACE_EVENT_KIND_SERVER_USER_SUB); err != nil {
slog.ErrorContext(ctx, "failed to add trace event (Server User Sub)", "error", err)
Expand All @@ -137,102 +161,11 @@ func sendPortalOutput(ctx context.Context, portalID int, gstream portalpb.Portal
"error", err,
)
}

if payload := mote.GetBytes(); payload != nil && payload.Kind == portalpb.BytesPayloadKind_BYTES_PAYLOAD_KIND_CLOSE {
slog.InfoContext(ctx, "received portal close, disconnecting client", "portal_id", portalID, "reason", string(payload.Data))
return
}
}
}
}

// func sendPortalInput(ctx context.Context, portalID int, gstream portalpb.Portal_InvokePortalServer, pubsubStream *stream.Stream) {
// for {
// select {
// case <-ctx.Done():
// return
// case msg := <-pubsubStream.Messages():
// payload := &portalpb.Payload{}
// if err := proto.Unmarshal(msg.Body, payload); err != nil {
// slog.ErrorContext(ctx, "failed to unmarshal portal input message",
// "portal_id", portalID,
// "error", err,
// )
// continue
// }
// msgLen := len(msg.Body)
// if err := gstream.Send(&portalpb.InvokePortalResponse{
// Payload: payload,
// }); err != nil {
// slog.ErrorContext(ctx, "failed to send input through portal",
// "portal_id", portalID,
// "msg_len", msgLen,
// "error", err,
// )
// return
// }
// slog.DebugContext(ctx, "input sent through portal via gRPC",
// "portal_id", portalID,
// "msg_len", msgLen,
// )
// }
// }
// }

// func sendPortalOutput(ctx context.Context, portalID int, gstream portalpb.Portal_InvokePortalServer, pubsubStream *stream.Stream, mux *stream.Mux) error {
// for {
// req, err := gstream.Recv()
// if err == io.EOF {
// return nil
// }
// if err != nil {
// return status.Errorf(codes.Internal, "failed to receive portal request: %v", err)
// }
// if req.Payload == nil {
// continue
// }

// // Marshal Payload
// data, err := proto.Marshal(req.Payload)
// if err != nil {
// return status.Errorf(codes.Internal, "failed to marshal portal payload: %v", err)
// }

// // Send Pubsub Message
// msgLen := len(data)
// if err := pubsubStream.SendMessage(ctx, &pubsub.Message{
// Body: data,
// }, mux); err != nil {
// slog.ErrorContext(ctx, "portal failed to publish output",
// "portal_id", portalID,
// "msg_len", msgLen,
// "error", err,
// )
// return status.Errorf(codes.Internal, "failed to publish message: %v", err)
// }
// slog.DebugContext(ctx, "portal published output",
// "portal_id", portalID,
// "msg_len", msgLen,
// )
// }
// }

// func sendPortalKeepAlives(ctx context.Context, gstream portalpb.Portal_InvokePortalServer) {
// ticker := time.NewTicker(30 * time.Second) // TODO: #m elmos_magic_numbers define this elsewhere
// defer ticker.Stop()

// for {
// select {
// case <-ctx.Done():
// return
// case <-ticker.C:
// if err := gstream.Send(&portalpb.InvokePortalResponse{
// Payload: &portalpb.Payload{
// Payload: &portalpb.Payload_Bytes{
// Bytes: &portalpb.BytesMessage{
// Kind: portalpb.BytesMessageKind_BYTES_MESSAGE_KIND_PING,
// Data: []byte(time.Now().UTC().Format(time.RFC3339Nano)),
// },
// },
// },
// }); err != nil {
// slog.ErrorContext(ctx, "portal failed to send gRPC keep alive ping", "error", err)
// }
// }
// }
// }
102 changes: 102 additions & 0 deletions tavern/internal/portals/portal_close_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
package portals_test

import (
"context"
"io"
"testing"
"time"

"github.com/stretchr/testify/require"
"realm.pub/tavern/internal/c2/c2pb"
"realm.pub/tavern/internal/ent"
"realm.pub/tavern/portals/portalpb"
)

// TestPortalClose verifies the portal shutdown path: when the agent side of
// a portal disconnects, the user side should receive a CLOSE mote and any
// subsequent receive should fail with a "portal closed" error.
func TestPortalClose(t *testing.T) {
	env := SetupTestEnv(t)
	defer env.Close()

	ctx := context.Background()

	// 1. Setup Data (Task)
	taskID, err := CreateTask(ctx, env.EntClient)
	require.NoError(t, err)

	// 2. Agent side: open a C2.CreatePortal stream and register it
	// against the task created above.
	agentStream, err := env.C2Client.CreatePortal(ctx)
	require.NoError(t, err)
	require.NoError(t, agentStream.Send(&c2pb.CreatePortalRequest{
		TaskId: int64(taskID),
	}))

	// Poll until the portal entity appears, backing off linearly
	// (100ms, 200ms, ...) between attempts.
	var created []*ent.Portal
	for attempt := 0; attempt < 10; attempt++ {
		created, err = env.EntClient.Portal.Query().All(ctx)
		require.NoError(t, err)
		if len(created) > 0 {
			break
		}
		time.Sleep(time.Duration(100*(attempt+1)) * time.Millisecond)
	}
	require.Len(t, created, 1)
	portalID := created[0].ID

	// 3. User side: open Portal.OpenPortal and register with the portal ID.
	userStream, err := env.PortalClient.OpenPortal(ctx)
	require.NoError(t, err)
	require.NoError(t, userStream.Send(&portalpb.OpenPortalRequest{
		PortalId: int64(portalID),
	}))

	// 4. Confirm the tunnel is live by pushing a ping from user to agent.
	pingData := []byte("ping")
	require.NoError(t, userStream.Send(&portalpb.OpenPortalRequest{
		Mote: &portalpb.Mote{
			Payload: &portalpb.Mote_Bytes{
				Bytes: &portalpb.BytesPayload{
					Data: pingData,
					Kind: portalpb.BytesPayloadKind_BYTES_PAYLOAD_KIND_DATA,
				},
			},
		},
	}))

	// The ping should arrive on the agent's stream unchanged.
	agentMsg, err := agentStream.Recv()
	require.NoError(t, err)
	require.Equal(t, pingData, agentMsg.Mote.GetBytes().Data)

	// 5. Agent disconnects, simulating the end of its session.
	require.NoError(t, agentStream.CloseSend())

	// 6. The OpenPortal handler should forward a CLOSE mote via the pubsub
	// topic to the user client before tearing down the stream.
	userMsg, err := userStream.Recv()
	if err == io.EOF {
		// If we get immediate EOF, it means we missed the CLOSE mote or it wasn't sent.
		// But based on code reading, it should be sent.
		t.Fatal("Expected CLOSE mote, got EOF immediately")
	}
	require.NoError(t, err)
	require.NotNil(t, userMsg.Mote)
	require.NotNil(t, userMsg.Mote.GetBytes())
	require.Equal(t, portalpb.BytesPayloadKind_BYTES_PAYLOAD_KIND_CLOSE, userMsg.Mote.GetBytes().Kind)

	// Any further receive should surface the portal-closed error (or EOF
	// wrapped in a gRPC status) rather than more data.
	_, err = userStream.Recv()
	require.Error(t, err)
	require.Contains(t, err.Error(), "portal closed")
}
Loading
Loading