Skip to content
15 changes: 13 additions & 2 deletions tavern/app.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ type Server struct {

// Close should always be called to clean up a Tavern server.
func (srv *Server) Close() error {
srv.HTTP.Shutdown(context.Background())
return srv.client.Close()
}

Expand Down Expand Up @@ -177,10 +178,20 @@ func NewServer(ctx context.Context, options ...func(*Config)) (*Server, error) {
return nil, fmt.Errorf("failed to configure http/2: %w", err)
}

return &Server{
// Initialize Server
tSrv := &Server{
HTTP: cfg.srv,
client: client,
}, nil
}

// Shutdown for Test Run & Exit
if cfg.IsTestRunAndExitEnabled() {
go func() {
tSrv.Close()
}()
}

return tSrv, nil
}

func newGraphQLHandler(client *ent.Client) http.Handler {
Expand Down
16 changes: 13 additions & 3 deletions tavern/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,12 @@ import (

var (
// EnvEnableTestData if set will populate the database with test data.
EnvEnableTestData = EnvString{"ENABLE_TEST_DATA", ""}
// EnvEnableTestRunAndExit will start the application, but exit immediately after.
EnvEnableTestData = EnvString{"ENABLE_TEST_DATA", ""}
EnvEnableTestRunAndExit = EnvString{"ENABLE_TEST_RUN_AND_EXIT", ""}

// EnvHTTPListenAddr sets the address (ip:port) for tavern's HTTP server to bind to.
EnvHTTPListenAddr = EnvString{"HTTP_LISTEN_ADDR", "0.0.0.0:80"}

// EnvOAuthClientID set to configure OAuth Client ID.
// EnvOAuthClientSecret set to configure OAuth Client Secret.
Expand Down Expand Up @@ -112,11 +117,16 @@ func (cfg *Config) IsTestDataEnabled() bool {
return EnvEnableTestData.String() != ""
}

// IsTestRunAndExitEnabled returns true if a value for the "ENABLE_TEST_RUN_AND_EXIT" environment variable is set.
func (cfg *Config) IsTestRunAndExitEnabled() bool {
return EnvEnableTestRunAndExit.String() != ""
}

// ConfigureHTTPServer enables the configuration of the Tavern HTTP server. The endpoint field will be
// overwritten with Tavern's HTTP handler when Tavern is run.
func ConfigureHTTPServer(address string, options ...func(*http.Server)) func(*Config) {
func ConfigureHTTPServerFromEnv(options ...func(*http.Server)) func(*Config) {
srv := &http.Server{
Addr: address,
Addr: EnvHTTPListenAddr.String(),
}
for _, opt := range options {
opt(srv)
Expand Down
92 changes: 92 additions & 0 deletions tavern/env_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
package main

import (
"os"
"testing"

"github.com/stretchr/testify/assert"
)

func TestEnvString(t *testing.T) {
// Test Cases
tests := []struct {
name string

env EnvString
osValue string
wantValue string
}{
{
name: "Set",
env: EnvString{"TEST_ENV_STRING", ""},
osValue: "VALUE_SET",
wantValue: "VALUE_SET",
},
{
name: "Unset",
env: EnvString{"TEST_ENV_STRING", ""},
osValue: "",
wantValue: "",
},
{
name: "Default",
env: EnvString{"TEST_ENV_STRING", "BLAH_BLAH"},
osValue: "",
wantValue: "BLAH_BLAH",
},
}

// Run Tests
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
if tc.osValue != "" {
os.Setenv(tc.env.Key, tc.osValue)
defer os.Unsetenv(tc.env.Key)
}

assert.Equal(t, tc.wantValue, tc.env.String())
})
}
}

func TestEnvInteger(t *testing.T) {
// Test Cases
tests := []struct {
name string

env EnvInteger
osValue string
wantValue int
}{
{
name: "Set",
env: EnvInteger{"TEST_ENV_INT", 0},
osValue: "123",
wantValue: 123,
},
{
name: "Unset",
env: EnvInteger{"TEST_ENV_INT", 0},
osValue: "",
wantValue: 0,
},
{
name: "Default",
env: EnvInteger{"TEST_ENV_INT", 456},
osValue: "",
wantValue: 456,
},
}

// Run Tests
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
if tc.osValue != "" {
os.Setenv(tc.env.Key, tc.osValue)
defer os.Unsetenv(tc.env.Key)
}

assert.Equal(t, tc.wantValue, tc.env.Int())
})
}
}
11 changes: 11 additions & 0 deletions tavern/internal/c2/api_report_process_list_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,17 @@ func TestReportProcessList(t *testing.T) {
wantResp: nil,
wantCode: codes.InvalidArgument,
},
{
name: "Not_Found",
req: &c2pb.ReportProcessListRequest{
TaskId: 99888777776666,
List: []*c2pb.Process{
{Pid: 1, Name: "systemd", Principal: "root"},
},
},
wantResp: nil,
wantCode: codes.NotFound,
},
}

// Run Tests
Expand Down
59 changes: 59 additions & 0 deletions tavern/internal/c2/api_report_task_output.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
package c2

import (
"context"
"fmt"
"time"

"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"realm.pub/tavern/internal/c2/c2pb"
"realm.pub/tavern/internal/ent"
)

func (srv *Server) ReportTaskOutput(ctx context.Context, req *c2pb.ReportTaskOutputRequest) (*c2pb.ReportTaskOutputResponse, error) {
// Validate Input
if req.Output == nil || req.Output.Id == 0 {
return nil, status.Errorf(codes.InvalidArgument, "must provide task id")
}

// Parse Input
var (
execStartedAt *time.Time
execFinishedAt *time.Time
taskErr *string
)
if req.Output.ExecStartedAt != nil {
timestamp := req.Output.ExecStartedAt.AsTime()
execStartedAt = &timestamp
}
if req.Output.ExecFinishedAt != nil {
timestamp := req.Output.ExecFinishedAt.AsTime()
execFinishedAt = &timestamp
}
if req.Output.Error != nil {
taskErr = &req.Output.Error.Msg
}

// Load Task
t, err := srv.graph.Task.Get(ctx, int(req.Output.Id))
if ent.IsNotFound(err) {
return nil, status.Errorf(codes.NotFound, "no task found (id=%d): %v", req.Output.Id, err)
}
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to submit task result (id=%d): %v", req.Output.Id, err)
}

// Update Task
_, err = t.Update().
SetNillableExecStartedAt(execStartedAt).
SetOutput(fmt.Sprintf("%s%s", t.Output, req.Output.Output)).
SetNillableExecFinishedAt(execFinishedAt).
SetNillableError(taskErr).
Save(ctx)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to save submitted task result (id=%d): %v", t.ID, err)
}

return &c2pb.ReportTaskOutputResponse{}, nil
}
134 changes: 134 additions & 0 deletions tavern/internal/c2/api_report_task_output_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
package c2_test

import (
"context"
"testing"
"time"

"github.com/google/go-cmp/cmp"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/testing/protocmp"
"google.golang.org/protobuf/types/known/timestamppb"
"realm.pub/tavern/internal/c2/c2pb"
"realm.pub/tavern/internal/c2/c2test"
"realm.pub/tavern/internal/ent"
)

func TestReportTaskOutput(t *testing.T) {
// Setup Dependencies
ctx := context.Background()
client, graph, close := c2test.New(t)
defer close()

// Test Data
now := timestamppb.Now()
finishedAt := timestamppb.New(time.Now().UTC().Add(10 * time.Minute))
existingBeacon := c2test.NewRandomBeacon(ctx, graph)
existingTasks := []*ent.Task{
c2test.NewRandomAssignedTask(ctx, graph, existingBeacon.Identifier),
c2test.NewRandomAssignedTask(ctx, graph, existingBeacon.Identifier),
}

// Test Cases
tests := []struct {
name string
req *c2pb.ReportTaskOutputRequest
wantResp *c2pb.ReportTaskOutputResponse
wantCode codes.Code
wantOutput string
wantExecStartedAt *timestamppb.Timestamp
wantExecFinishedAt *timestamppb.Timestamp
}{
{
name: "First_Output",
req: &c2pb.ReportTaskOutputRequest{
Output: &c2pb.TaskOutput{
Id: int64(existingTasks[0].ID),
Output: "TestOutput",
ExecStartedAt: now,
},
},
wantResp: &c2pb.ReportTaskOutputResponse{},
wantCode: codes.OK,
wantOutput: "TestOutput",
wantExecStartedAt: now,
},
{
name: "Append_Output",
req: &c2pb.ReportTaskOutputRequest{
Output: &c2pb.TaskOutput{
Id: int64(existingTasks[0].ID),
Output: "_AppendedOutput",
},
},
wantResp: &c2pb.ReportTaskOutputResponse{},
wantCode: codes.OK,
wantOutput: "TestOutput_AppendedOutput",
wantExecStartedAt: now,
},
{
name: "Exec_Finished",
req: &c2pb.ReportTaskOutputRequest{
Output: &c2pb.TaskOutput{
Id: int64(existingTasks[0].ID),
ExecFinishedAt: finishedAt,
},
},
wantResp: &c2pb.ReportTaskOutputResponse{},
wantCode: codes.OK,
wantOutput: "TestOutput_AppendedOutput",
wantExecStartedAt: now,
wantExecFinishedAt: finishedAt,
},
{
name: "Not_Found",
req: &c2pb.ReportTaskOutputRequest{
Output: &c2pb.TaskOutput{
Id: 999888777666,
},
},
wantResp: nil,
wantCode: codes.NotFound,
},
{
name: "Invalid_Argument",
req: &c2pb.ReportTaskOutputRequest{
Output: &c2pb.TaskOutput{},
},
wantResp: nil,
wantCode: codes.InvalidArgument,
},
}

// Run Tests
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
// Callback
resp, err := client.ReportTaskOutput(ctx, tc.req)

// Assert Response Code
require.Equal(t, tc.wantCode.String(), status.Code(err).String(), err)
if status.Code(err) != codes.OK {
// Do not continue if we expected error code
return
}

// Assert Response
if diff := cmp.Diff(tc.wantResp, resp, protocmp.Transform()); diff != "" {
t.Errorf("invalid response (-want +got): %v", diff)
}

// Load Task

testTask, err := graph.Task.Get(ctx, int(tc.req.Output.Id))
require.NoError(t, err)

// Task Assertions
assert.Equal(t, tc.wantOutput, testTask.Output)
})
}

}
Loading