Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions go.mod

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

46 changes: 35 additions & 11 deletions go.sum

Large diffs are not rendered by default.

33 changes: 33 additions & 0 deletions tavern/app.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"encoding/json"
"fmt"
"log"
"log/slog"
"net/http"
"net/http/pprof"
"os"
Expand All @@ -30,10 +31,15 @@ import (
"realm.pub/tavern/internal/graphql"
tavernhttp "realm.pub/tavern/internal/http"
"realm.pub/tavern/internal/http/stream"
"realm.pub/tavern/internal/namegen"
"realm.pub/tavern/internal/www"
"realm.pub/tavern/tomes"
)

func init() {
configureLogging()
}

func newApp(ctx context.Context, options ...func(*Config)) (app *cli.App) {
app = cli.NewApp()
app.Name = "tavern"
Expand Down Expand Up @@ -350,3 +356,30 @@ func registerProfiler(router tavernhttp.RouteMap) {
router.Handle("/debug/pprof/threadcreate", pprof.Handler("threadcreate"))
router.Handle("/debug/pprof/block", pprof.Handler("block"))
}

func configureLogging() {
// Generate new instance ID as prefix (helps in deployments with multiple tavern instances)
var (
instanceID = namegen.NewSimple()
logger *slog.Logger
)

// Setup Default Logger
if EnvDebugLogging.String() == "" {
// Production Logging
logger = slog.New(slog.NewJSONHandler(os.Stderr, &slog.HandlerOptions{
Level: slog.LevelInfo,
})).
With("tavern_id", instanceID)
} else {
// Debug Logging
logger = slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{
Level: slog.LevelDebug,
AddSource: true,
})).
With("tavern_id", instanceID)
}

slog.SetDefault(logger)
slog.Debug("Debug logging enabled 🕵️ ")
}
2 changes: 2 additions & 0 deletions tavern/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,11 @@ var (
// EnvEnableTestData if set will populate the database with test data.
// EnvEnableTestRunAndExit will start the application, but exit immediately after.
// EnvDisableDefaultTomes will prevent the default tomes from being imported on startup.
// EnvDebugLogging will emit verbose debug logs to help troubleshoot issues.
EnvEnableTestData = EnvString{"ENABLE_TEST_DATA", ""}
EnvEnableTestRunAndExit = EnvString{"ENABLE_TEST_RUN_AND_EXIT", ""}
EnvDisableDefaultTomes = EnvString{"DISABLE_DEFAULT_TOMES", ""}
EnvDebugLogging = EnvString{"ENABLE_DEBUG_LOGGING", ""}

// EnvHTTPListenAddr sets the address (ip:port) for tavern's HTTP server to bind to.
// EnvHTTPMetricsAddr sets the address (ip:port) for the HTTP metrics server to bind to.
Expand Down
73 changes: 60 additions & 13 deletions tavern/internal/c2/api_reverse_shell.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ import (
"context"
"fmt"
"io"
"log"
"log/slog"
"sync"
"time"

Expand Down Expand Up @@ -36,20 +36,25 @@ func (srv *Server) ReverseShell(gstream c2pb.C2_ReverseShellServer) error {
task, err := srv.graph.Task.Get(ctx, int(registerMsg.TaskId))
if err != nil {
if ent.IsNotFound(err) {
slog.ErrorContext(ctx, "reverse shell failed: associated task does not exist", "task_id", registerMsg.TaskId, "error", err)
return status.Errorf(codes.NotFound, "task does not exist (task_id=%d)", registerMsg.TaskId)
}
slog.ErrorContext(ctx, "reverse shell failed: could not load associated task", "task_id", registerMsg.TaskId, "error", err)
return status.Errorf(codes.Internal, "failed to load task ent (task_id=%d): %v", registerMsg.TaskId, err)
}
beacon, err := task.Beacon(ctx)
if err != nil {
slog.ErrorContext(ctx, "reverse shell failed: could not load associated beacon", "task_id", registerMsg.TaskId, "error", err)
return status.Errorf(codes.Internal, "failed to load beacon ent (task_id=%d): %v", registerMsg.TaskId, err)
}
quest, err := task.Quest(ctx)
if err != nil {
slog.ErrorContext(ctx, "reverse shell failed: could not load associated quest", "task_id", registerMsg.TaskId, "error", err)
return status.Errorf(codes.Internal, "failed to load quest ent (task_id=%d): %v", registerMsg.TaskId, err)
}
creator, err := quest.Creator(ctx)
if err != nil {
slog.ErrorContext(ctx, "reverse shell failed: could not load associated quest creator", "task_id", registerMsg.TaskId, "error", err)
return status.Errorf(codes.Internal, "failed to load quest creator (task_id=%d): %v", registerMsg.TaskId, err)
}

Expand All @@ -61,13 +66,27 @@ func (srv *Server) ReverseShell(gstream c2pb.C2_ReverseShellServer) error {
SetData([]byte{}).
Save(ctx)
if err != nil {
slog.ErrorContext(ctx, "reverse shell failed: could not create shell entity", "task_id", registerMsg.TaskId, "error", err)
return status.Errorf(codes.Internal, "failed to create shell: %v", err)
}
shellID := shell.ID

// Log Shell Session
log.Printf("[gRPC] Reverse Shell Started (shell_id=%d)", shellID)
defer log.Printf("[gRPC] Reverse Shell Closed (shell_id=%d)", shellID)
slog.InfoContext(ctx, "started gRPC reverse shell",
"shell_id", shellID,
"task_id", registerMsg.TaskId,
"creator_id", creator.ID,
)
defer func(start time.Time) {
slog.InfoContext(ctx, "closed gRPC reverse shell",
"started_at", start.String(),
"ended_at", time.Now().String(),
"duration", time.Since(start).String(),
"shell_id", shellID,
"task_id", registerMsg.TaskId,
"creator_id", creator.ID,
)
}(time.Now())

// Create new Stream
pubsubStream := stream.New(fmt.Sprintf("%d", shellID))
Expand All @@ -81,25 +100,34 @@ func (srv *Server) ReverseShell(gstream c2pb.C2_ReverseShellServer) error {
defer cancel()

// Notify Subscribers that the stream is closed
log.Printf("[gRPC][ReverseShell] Sending stream close message")
slog.DebugContext(ctx, "reverse shell closed, sending stream close message", "shell_id", shell.ID)
if err := pubsubStream.SendMessage(ctx, &pubsub.Message{
Metadata: map[string]string{
stream.MetadataStreamClose: fmt.Sprintf("%d", shellID),
},
}, srv.mux); err != nil {
log.Printf("[gRPC][ReverseShell][ERROR] Failed to notify subscribers that shell was closed: %v", err)
slog.ErrorContext(ctx, "reverse shell closed and failed to notify subscribers",
"shell_id", shell.ID,
"error", err,
)
}

// Update Ent
shell, err := srv.graph.Shell.Get(ctx, shellID)
if err != nil {
log.Printf("[gRPC][ReverseShell][ERROR] Failed to retrieve shell ent to update it as closed: %v", err)
slog.ErrorContext(ctx, "reverse shell closed and failed to load ent for updates",
"error", err,
"shell_id", shell.ID,
)
return
}
if _, err := shell.Update().
SetClosedAt(closedAt).
Save(ctx); err != nil {
log.Printf("[gRPC][ReverseShell][ERROR] Failed to update shell ent as closed: %v", err)
slog.ErrorContext(ctx, "reverse shell closed and failed to update ent",
"error", err,
"shell_id", shell.ID,
)
}
}()

Expand All @@ -121,35 +149,44 @@ func (srv *Server) ReverseShell(gstream c2pb.C2_ReverseShellServer) error {
wg.Add(1)
go func() {
defer wg.Done()
sendShellInput(ctx, gstream, pubsubStream)
sendShellInput(ctx, shellID, gstream, pubsubStream)
}()

// Send Output (to pubsub)
err = sendShellOutput(ctx, gstream, pubsubStream, srv.mux)
err = sendShellOutput(ctx, shellID, gstream, pubsubStream, srv.mux)

wg.Wait()

return err
}

func sendShellInput(ctx context.Context, gstream c2pb.C2_ReverseShellServer, pubsubStream *stream.Stream) {
func sendShellInput(ctx context.Context, shellID int, gstream c2pb.C2_ReverseShellServer, pubsubStream *stream.Stream) {
for {
select {
case <-ctx.Done():
return
case msg := <-pubsubStream.Messages():
msgLen := len(msg.Body)
if err := gstream.Send(&c2pb.ReverseShellResponse{
Kind: c2pb.ReverseShellMessageKind_REVERSE_SHELL_MESSAGE_KIND_DATA,
Data: msg.Body,
}); err != nil {
log.Printf("[ERROR] Failed to send gRPC input: %v", err)
slog.ErrorContext(ctx, "failed to send shell input to reverse shell",
"shell_id", shellID,
"msg_len", msgLen,
"error", err,
)
return
}
slog.DebugContext(ctx, "reverse shell sent input to agent via gRPC",
"shell_id", shellID,
"msg_len", msgLen,
)
}
}
}

func sendShellOutput(ctx context.Context, gstream c2pb.C2_ReverseShellServer, pubsubStream *stream.Stream, mux *stream.Mux) error {
func sendShellOutput(ctx context.Context, shellID int, gstream c2pb.C2_ReverseShellServer, pubsubStream *stream.Stream, mux *stream.Mux) error {
for {
req, err := gstream.Recv()
if err == io.EOF {
Expand All @@ -165,11 +202,21 @@ func sendShellOutput(ctx context.Context, gstream c2pb.C2_ReverseShellServer, pu
}

// Send Pubsub Message
msgLen := len(req.Data)
if err := pubsubStream.SendMessage(ctx, &pubsub.Message{
Body: req.Data,
}, mux); err != nil {
slog.ErrorContext(ctx, "reverse shell failed to publish shell output",
"shell_id", shellID,
"msg_len", msgLen,
"error", err,
)
return status.Errorf(codes.Internal, "failed to publish message: %v", err)
}
slog.DebugContext(ctx, "reverse shell published shell output",
"shell_id", shellID,
"msg_len", msgLen,
)
}
}

Expand All @@ -185,7 +232,7 @@ func sendKeepAlives(ctx context.Context, gstream c2pb.C2_ReverseShellServer) {
if err := gstream.Send(&c2pb.ReverseShellResponse{
Kind: c2pb.ReverseShellMessageKind_REVERSE_SHELL_MESSAGE_KIND_PING,
}); err != nil {
log.Printf("[ERROR] Failed to send gRPC ping: %v", err)
slog.ErrorContext(ctx, "reverse shell failed to send gRPC keep alive ping", "error", err)
}
}
}
Expand Down
Loading
Loading