Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions go.mod

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions go.sum

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

124 changes: 124 additions & 0 deletions tavern/internal/c2/api_claim_tasks.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,17 @@ import (
"time"

"github.com/prometheus/client_golang/prometheus"
"github.com/robfig/cron/v3"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"realm.pub/tavern/internal/c2/c2pb"
"realm.pub/tavern/internal/c2/epb"
"realm.pub/tavern/internal/ent"
"realm.pub/tavern/internal/ent/beacon"
"realm.pub/tavern/internal/ent/host"
"realm.pub/tavern/internal/ent/tag"
"realm.pub/tavern/internal/ent/task"
"realm.pub/tavern/internal/ent/tome"
"realm.pub/tavern/internal/namegen"
)

Expand All @@ -28,10 +31,117 @@ var (
},
[]string{"host_identifier", "host_groups", "host_services"},
)
metricTomeAutomationErrors = prometheus.NewCounter(
prometheus.CounterOpts{
Name: "tavern_tome_automation_errors_total",
Help: "The total number of errors encountered during tome automation",
},
)
)

func init() {
prometheus.MustRegister(metricHostCallbacksTotal)
prometheus.MustRegister(metricTomeAutomationErrors)
}

func (srv *Server) handleTomeAutomation(ctx context.Context, beaconID int, hostID int, isNewBeacon bool, isNewHost bool, now time.Time) {
// Tome Automation Logic
candidateTomes, err := srv.graph.Tome.Query().
Where(tome.Or(
tome.RunOnNewBeaconCallback(true),
tome.RunOnFirstHostCallback(true),
tome.RunOnScheduleNEQ(""),
)).
All(ctx)

if err != nil {
slog.ErrorContext(ctx, "failed to query candidate tomes for automation", "err", err)
metricTomeAutomationErrors.Inc()
return
}

selectedTomes := make(map[int]*ent.Tome)
parser := cron.NewParser(cron.Minute | cron.Hour | cron.Dom | cron.Month | cron.Dow)
currentMinute := now.Truncate(time.Minute)

for _, t := range candidateTomes {
shouldRun := false

// Check RunOnNewBeaconCallback
if isNewBeacon && t.RunOnNewBeaconCallback {
shouldRun = true
}

// Check RunOnFirstHostCallback
if !shouldRun && isNewHost && t.RunOnFirstHostCallback {
shouldRun = true
}

// Check RunOnSchedule
if !shouldRun && t.RunOnSchedule != "" {
sched, err := parser.Parse(t.RunOnSchedule)
if err == nil {
// Check if schedule matches current time
// Next(now-1sec) == now?
next := sched.Next(currentMinute.Add(-1 * time.Second))
if next.Equal(currentMinute) {
// Check scheduled_hosts constraint
hostCount, err := t.QueryScheduledHosts().Count(ctx)
if err != nil {
slog.ErrorContext(ctx, "failed to count scheduled hosts for automation", "err", err, "tome_id", t.ID)
metricTomeAutomationErrors.Inc()
continue
}
if hostCount == 0 {
shouldRun = true
} else {
hostExists, err := t.QueryScheduledHosts().
Where(host.ID(hostID)).
Exist(ctx)
if err != nil {
slog.ErrorContext(ctx, "failed to check host existence for automation", "err", err, "tome_id", t.ID)
metricTomeAutomationErrors.Inc()
continue
}
if hostExists {
shouldRun = true
}
}
}
} else {
// Don't log cron parse errors for now, as it might be spammy if stored in DB
// metricTomeAutomationErrors.Inc()
}
}

if shouldRun {
selectedTomes[t.ID] = t
}
}

// Create Quest and Task for each selected Tome
for _, t := range selectedTomes {
q, err := srv.graph.Quest.Create().
SetName(fmt.Sprintf("Automated: %s", t.Name)).
SetTome(t).
SetParamDefsAtCreation(t.ParamDefs).
SetEldritchAtCreation(t.Eldritch).
Save(ctx)
if err != nil {
slog.ErrorContext(ctx, "failed to create automated quest", "err", err, "tome_id", t.ID)
metricTomeAutomationErrors.Inc()
continue
}

_, err = srv.graph.Task.Create().
SetQuest(q).
SetBeaconID(beaconID).
Save(ctx)
if err != nil {
slog.ErrorContext(ctx, "failed to create automated task", "err", err, "quest_id", q.ID)
metricTomeAutomationErrors.Inc()
}
}
}

func (srv *Server) ClaimTasks(ctx context.Context, req *c2pb.ClaimTasksRequest) (*c2pb.ClaimTasksResponse, error) {
Expand Down Expand Up @@ -61,6 +171,15 @@ func (srv *Server) ClaimTasks(ctx context.Context, req *c2pb.ClaimTasksRequest)
return nil, status.Errorf(codes.InvalidArgument, "must provide agent identifier")
}

// Check if host is new (before upsert)
hostExists, err := srv.graph.Host.Query().
Where(host.IdentifierEQ(req.Beacon.Host.Identifier)).
Exist(ctx)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to query host existence: %v", err)
}
isNewHost := !hostExists

// Upsert the host
hostID, err := srv.graph.Host.Create().
SetIdentifier(req.Beacon.Host.Identifier).
Expand Down Expand Up @@ -118,6 +237,8 @@ func (srv *Server) ClaimTasks(ctx context.Context, req *c2pb.ClaimTasksRequest)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to query beacon entity: %v", err)
}
isNewBeacon := !beaconExists

var beaconNameAddr *string = nil
if !beaconExists {
candidateNames := []string{
Expand Down Expand Up @@ -172,6 +293,9 @@ func (srv *Server) ClaimTasks(ctx context.Context, req *c2pb.ClaimTasksRequest)
return nil, status.Errorf(codes.Internal, "failed to upsert beacon entity: %v", err)
}

// Run Tome Automation (non-blocking, best effort)
srv.handleTomeAutomation(ctx, beaconID, hostID, isNewBeacon, isNewHost, now)

// Load Tasks
tasks, err := srv.graph.Task.Query().
Where(task.And(
Expand Down
193 changes: 193 additions & 0 deletions tavern/internal/c2/tome_automation_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,193 @@
package c2

import (
"context"
"testing"
"time"

"github.com/stretchr/testify/assert"
"realm.pub/tavern/internal/c2/c2pb"
"realm.pub/tavern/internal/ent"
"realm.pub/tavern/internal/ent/enttest"
)

func TestHandleTomeAutomation(t *testing.T) {
ctx := context.Background()
client := enttest.Open(t, "sqlite3", "file:ent?mode=memory&cache=shared&_fk=1")
defer client.Close()

srv := &Server{graph: client}
now := time.Date(2023, 10, 27, 10, 0, 0, 0, time.UTC)

// Create a dummy host and beacon for testing
h := client.Host.Create().
SetIdentifier("test-host").
SetName("Test Host").
SetPlatform(c2pb.Host_PLATFORM_LINUX).
SaveX(ctx)
b := client.Beacon.Create().
SetIdentifier("test-beacon").
SetHost(h).
SetTransport(c2pb.ActiveTransport_TRANSPORT_HTTP1).
SaveX(ctx)

// 1. Setup Tomes
// T1: New Beacon Only
client.Tome.Create().
SetName("Tome New Beacon").
SetDescription("Test").
SetAuthor("Test Author").
SetEldritch("print('new beacon')").
SetRunOnNewBeaconCallback(true).
SaveX(ctx)

// T2: New Host Only
client.Tome.Create().
SetName("Tome New Host").
SetDescription("Test").
SetAuthor("Test Author").
SetEldritch("print('new host')").
SetRunOnFirstHostCallback(true).
SaveX(ctx)

// T3: Schedule Matching (Every minute)
client.Tome.Create().
SetName("Tome Schedule Match").
SetDescription("Test").
SetAuthor("Test Author").
SetEldritch("print('schedule')").
SetRunOnSchedule("* * * * *").
SaveX(ctx)

// T4: Schedule Matching with Host Restriction (Allowed)
client.Tome.Create().
SetName("Tome Schedule Restricted Allowed").
SetDescription("Test").
SetAuthor("Test Author").
SetEldritch("print('schedule restricted')").
SetRunOnSchedule("* * * * *").
AddScheduledHosts(h).
SaveX(ctx)

// T5: Schedule Matching with Host Restriction (Denied - different host)
otherHost := client.Host.Create().
SetIdentifier("other").
SetPlatform(c2pb.Host_PLATFORM_LINUX).
SaveX(ctx)

client.Tome.Create().
SetName("Tome Schedule Restricted Denied").
SetDescription("Test").
SetAuthor("Test Author").
SetEldritch("print('schedule denied')").
SetRunOnSchedule("* * * * *").
AddScheduledHosts(otherHost).
SaveX(ctx)

tests := []struct {
name string
isNewBeacon bool
isNewHost bool
expectedTomes []string
}{
{
name: "New Beacon Only",
isNewBeacon: true,
isNewHost: false,
expectedTomes: []string{
"Tome New Beacon",
"Tome Schedule Match",
"Tome Schedule Restricted Allowed",
},
},
{
name: "New Host Only",
isNewBeacon: false,
isNewHost: true,
expectedTomes: []string{
"Tome New Host",
"Tome Schedule Match",
"Tome Schedule Restricted Allowed",
},
},
{
name: "Both New",
isNewBeacon: true,
isNewHost: true,
expectedTomes: []string{
"Tome New Beacon",
"Tome New Host",
"Tome Schedule Match",
"Tome Schedule Restricted Allowed",
},
},
{
name: "Neither New",
isNewBeacon: false,
isNewHost: false,
expectedTomes: []string{
"Tome Schedule Match",
"Tome Schedule Restricted Allowed",
},
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Clear existing quests/tasks to ensure clean slate
client.Task.Delete().ExecX(ctx)
client.Quest.Delete().ExecX(ctx)

srv.handleTomeAutomation(ctx, b.ID, h.ID, tt.isNewBeacon, tt.isNewHost, now)

// Verify Tasks
tasks := client.Task.Query().WithQuest(func(q *ent.QuestQuery) {
q.WithTome()
}).AllX(ctx)

var createdTomes []string
for _, t := range tasks {
createdTomes = append(createdTomes, t.Edges.Quest.Edges.Tome.Name)
}

assert.ElementsMatch(t, tt.expectedTomes, createdTomes)
})
}
}

func TestHandleTomeAutomation_Deduplication(t *testing.T) {
ctx := context.Background()
client := enttest.Open(t, "sqlite3", "file:ent?mode=memory&cache=shared&_fk=1")
defer client.Close()

srv := &Server{graph: client}
now := time.Now()

h := client.Host.Create().
SetIdentifier("test").
SetPlatform(c2pb.Host_PLATFORM_LINUX).
SaveX(ctx)
b := client.Beacon.Create().
SetIdentifier("test").
SetHost(h).
SetTransport(c2pb.ActiveTransport_TRANSPORT_HTTP1).
SaveX(ctx)

// Tome with ALL triggers enabled
client.Tome.Create().
SetName("Super Tome").
SetDescription("Test").
SetAuthor("Test Author").
SetEldritch("print('super')").
SetRunOnNewBeaconCallback(true).
SetRunOnFirstHostCallback(true).
SetRunOnSchedule("* * * * *").
SaveX(ctx)

// Trigger all conditions
srv.handleTomeAutomation(ctx, b.ID, h.ID, true, true, now)

// Should only have 1 task
count := client.Task.Query().CountX(ctx)
assert.Equal(t, 1, count, "Should only create one task despite multiple triggers matching")
}