Skip to content

Commit

Permalink
Notify currently connected clients on new hunts (Velocidex#1386)
Browse files Browse the repository at this point in the history
When a hunt is started we used to notify all the clients, but this
creates unnecessary churn on the system. In this PR we specifically
forward all connected clients to the hunt manager instead of notifying
all clients. This keeps clients connected and avoids a thundering herd
condition.

The PR also adds a completion function to QueueMessageForClient() so the
caller can be notified after the message is queued.
  • Loading branch information
scudette authored Nov 24, 2021
1 parent d2fc975 commit 2f8bfe1
Show file tree
Hide file tree
Showing 39 changed files with 582 additions and 162 deletions.
13 changes: 7 additions & 6 deletions api/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -218,18 +218,19 @@ func (self *ApiServer) CollectArtifact(
}

flow_id, err := launcher.ScheduleArtifactCollection(
ctx, self.config, acl_manager, repository, in)
ctx, self.config, acl_manager, repository, in,
func() {
notifier := services.GetNotifier()
if notifier != nil {
notifier.NotifyListener(self.config, in.ClientId)
}
})
if err != nil {
return nil, err
}

result.FlowId = flow_id

err = services.GetNotifier().NotifyListener(self.config, in.ClientId)
if err != nil {
return nil, err
}

// Log this event as an Audit event.
logging.GetLogger(self.config, &logging.Audit).
WithFields(logrus.Fields{
Expand Down
8 changes: 5 additions & 3 deletions clients/tasks.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,8 @@ func currentTaskId() uint64 {
func QueueMessageForClient(
config_obj *config_proto.Config,
client_id string,
req *crypto_proto.VeloMessage) error {
req *crypto_proto.VeloMessage,
completion func()) error {

// Task ID is related to time.
req.TaskId = currentTaskId()
Expand All @@ -90,6 +91,7 @@ func QueueMessageForClient(
}

client_path_manager := paths.NewClientPathManager(client_id)
return db.SetSubject(config_obj,
client_path_manager.Task(req.TaskId), req)
return db.SetSubjectWithCompletion(config_obj,
client_path_manager.Task(req.TaskId),
req, completion)
}
5 changes: 3 additions & 2 deletions clients/tasks_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ func (self *ClientTasksTestSuite) TestQueueMessages() {
client_id := "C.1236"

message1 := &crypto_proto.VeloMessage{Source: "Server", SessionId: "1"}
err := clients.QueueMessageForClient(self.ConfigObj, client_id, message1)
err := clients.QueueMessageForClient(self.ConfigObj, client_id, message1, nil)
assert.NoError(self.T(), err)

// Now retrieve all messages.
Expand Down Expand Up @@ -54,7 +54,8 @@ func (self *ClientTasksTestSuite) TestFastQueueMessages() {

for i := 0; i < 10; i++ {
message := &crypto_proto.VeloMessage{Source: "Server", SessionId: fmt.Sprintf("%d", i)}
err := clients.QueueMessageForClient(self.ConfigObj, client_id, message)
err := clients.QueueMessageForClient(
self.ConfigObj, client_id, message, nil)
assert.NoError(self.T(), err)

written = append(written, message)
Expand Down
4 changes: 2 additions & 2 deletions crypto/server/manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -150,8 +150,8 @@ func (self *serverPublicKeyResolver) SetPublicKey(
Pem: crypto_utils.PublicKeyToPem(key),
EnrollTime: uint64(time.Now().Unix()),
}
return db.SetSubject(self.config_obj,
client_path_manager.Key(), pem)
return db.SetSubjectWithCompletion(self.config_obj,
client_path_manager.Key(), pem, nil)
}

func (self *serverPublicKeyResolver) Clear() {}
Expand Down
8 changes: 4 additions & 4 deletions datastore/datastore.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,17 +70,17 @@ type DataStore interface {
urn api.DSPathSpec,
message proto.Message) error

// SetSubject writes the data to the datastore. The data is
// written asynchronously and may not be immediately visible by
// other nodes.
// SetSubject writes the data to the datastore synchronously. Once
// the call returns, the data is visible to other nodes as long as
// it is not already held in their caches.
SetSubject(
config_obj *config_proto.Config,
urn api.DSPathSpec,
message proto.Message) error

// Writes the data asynchronously and fires the completion
// callback when the data hits the disk and will become visible
// to other nodes.
// to other nodes. This may be a long time in the future.
SetSubjectWithCompletion(
config_obj *config_proto.Config,
urn api.DSPathSpec,
Expand Down
48 changes: 40 additions & 8 deletions datastore/memcache_file.go
Original file line number Diff line number Diff line change
Expand Up @@ -195,14 +195,6 @@ func (self *MemcacheFileDataStore) GetSubject(
return err
}

// SetSubject stores message under urn by delegating to
// SetSubjectWithCompletion with a nil completion callback, so the
// caller is not told when the write finishes.
func (self *MemcacheFileDataStore) SetSubject(
config_obj *config_proto.Config,
urn api.DSPathSpec,
message proto.Message) error {

return self.SetSubjectWithCompletion(config_obj, urn, message, nil)
}

func (self *MemcacheFileDataStore) SetSubjectWithCompletion(
config_obj *config_proto.Config,
urn api.DSPathSpec,
Expand Down Expand Up @@ -254,6 +246,46 @@ func (self *MemcacheFileDataStore) SetSubjectWithCompletion(
return err
}

// SetSubject serializes message and stores it under urn, calling
// writeContentToFile before it returns.
//
// The message is marshalled with protojson when the urn type is
// api.PATH_TYPE_DATASTORE_JSON and with the binary proto encoding
// otherwise. The in-memory cache is updated first, then the file is
// written, and finally the directory cache is invalidated. The first
// error encountered is returned and stops any later step.
func (self *MemcacheFileDataStore) SetSubject(
	config_obj *config_proto.Config,
	urn api.DSPathSpec,
	message proto.Message) error {

	defer Instrument("write", "MemcacheFileDataStore", urn)()

	// Choose the on-disk encoding from the path type.
	var (
		data []byte
		err  error
	)
	switch urn.Type() {
	case api.PATH_TYPE_DATASTORE_JSON:
		data, err = protojson.Marshal(message)
	default:
		data, err = proto.Marshal(message)
	}
	if err != nil {
		return err
	}

	// Update the cache before touching the file.
	if err = self.cache.SetSubject(config_obj, urn, message); err != nil {
		return err
	}

	if err = writeContentToFile(config_obj, urn, data); err != nil {
		return err
	}

	self.invalidateDirCache(config_obj, urn)
	return nil
}

func (self *MemcacheFileDataStore) DeleteSubject(
config_obj *config_proto.Config,
urn api.DSPathSpec) error {
Expand Down
8 changes: 7 additions & 1 deletion datastore/remote.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ package datastore

import (
"context"
"sync"
"time"

"google.golang.org/protobuf/encoding/protojson"
Expand Down Expand Up @@ -62,12 +63,17 @@ func (self *RemoteDataStore) GetSubject(
return err
}

// SetSubject writes the data synchronously. It delegates to
// SetSubjectWithCompletion and blocks until the completion callback
// has fired before returning to the caller.
func (self *RemoteDataStore) SetSubject(
	config_obj *config_proto.Config,
	urn api.DSPathSpec,
	message proto.Message) error {

	var wg sync.WaitGroup
	wg.Add(1)

	err := self.SetSubjectWithCompletion(config_obj, urn, message, wg.Done)
	if err != nil {
		// Do not wait on a failed write: if the completion is never
		// fired on this path, waiting would block forever.
		return err
	}

	wg.Wait()
	return nil
}

func (self *RemoteDataStore) SetSubjectWithCompletion(
Expand Down
10 changes: 5 additions & 5 deletions flows/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -254,16 +254,16 @@ func CancelFlow(
&crypto_proto.VeloMessage{
Cancel: &crypto_proto.Cancel{},
SessionId: flow_id,
}, func() {
notifier := services.GetNotifier()
if notifier != nil {
notifier.NotifyListener(config_obj, client_id)
}
})
if err != nil {
return nil, err
}

err = services.GetNotifier().NotifyListener(config_obj, client_id)
if err != nil {
return nil, err
}

return &api_proto.StartFlowResponse{
FlowId: flow_id,
}, nil
Expand Down
9 changes: 7 additions & 2 deletions flows/artifacts.go
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,8 @@ func NewCollectionContext(config_obj *config_proto.Config) *CollectionContext {
if !self.send_update {
return
}
// Do not send it again.
self.send_update = false

// If this is the final response (i.e. the flow is not running)
// and we have not yet sent an update, then we will notify a flow
Expand Down Expand Up @@ -153,6 +155,10 @@ func closeContext(
config_obj *config_proto.Config,
collection_context *CollectionContext) error {

// Ensure the completion is not fired until we are done here
// completely.
defer collection_context.completer.GetCompletionFunc()()

// Context is not dirty - nothing to do.
if !collection_context.Dirty || collection_context.ClientId == "" {
return nil
Expand Down Expand Up @@ -774,12 +780,11 @@ func (self *FlowRunner) ProcessSingleMessage(
&crypto_proto.VeloMessage{
Cancel: &crypto_proto.Cancel{},
SessionId: job.SessionId,
})
}, nil)
if err != nil {
logger.Error("Queueing for client %v: %v",
job.Source, err)
}

return
}
self.context_map[job.SessionId] = collection_context
Expand Down
8 changes: 4 additions & 4 deletions flows/artifacts_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -117,15 +117,15 @@ func (self *TestSuite) TestGetFlow() {
flow_id, err := launcher.ScheduleArtifactCollection(
ctx, self.config_obj,
vql_subsystem.NullACLManager{},
repository, request1)
repository, request1, nil)
assert.NoError(self.T(), err)

flow_ids = append(flow_ids, flow_id)

flow_id, err = launcher.ScheduleArtifactCollection(
ctx, self.config_obj,
vql_subsystem.NullACLManager{},
repository, request2)
repository, request2, nil)
assert.NoError(self.T(), err)

flow_ids = append(flow_ids, flow_id)
Expand Down Expand Up @@ -179,7 +179,7 @@ func (self *TestSuite) TestRetransmission() {
flow_id, err := launcher.ScheduleArtifactCollection(
ctx, self.config_obj,
vql_subsystem.NullACLManager{},
repository, request)
repository, request, nil)
assert.NoError(self.T(), err)

// Send one row.
Expand Down Expand Up @@ -242,7 +242,7 @@ func (self *TestSuite) TestResourceLimits() {
ctx,
self.config_obj,
vql_subsystem.NullACLManager{},
repository, request)
repository, request, nil)
assert.NoError(self.T(), err)

// Drain messages to the client.
Expand Down
4 changes: 2 additions & 2 deletions flows/foreman.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ func ForemanProcessMessage(
err := QueueMessageForClient(
config_obj, client_id,
client_event_manager.GetClientUpdateEventTableMessage(
config_obj, client_id))
config_obj, client_id), nil)
if err != nil {
return err
}
Expand Down Expand Up @@ -170,5 +170,5 @@ func ForemanProcessMessage(
UpdateForeman: &actions_proto.ForemanCheckin{
LastHuntTimestamp: latest_timestamp,
},
})
}, nil)
}
6 changes: 4 additions & 2 deletions flows/hunts_test.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package flows
package flows_test

import (
"context"
Expand All @@ -13,6 +13,7 @@ import (
config_proto "www.velocidex.com/golang/velociraptor/config/proto"
"www.velocidex.com/golang/velociraptor/datastore"
"www.velocidex.com/golang/velociraptor/file_store/test_utils"
"www.velocidex.com/golang/velociraptor/flows"
flows_proto "www.velocidex.com/golang/velociraptor/flows/proto"
"www.velocidex.com/golang/velociraptor/paths"
"www.velocidex.com/golang/velociraptor/services"
Expand Down Expand Up @@ -97,7 +98,8 @@ sources:
}

acl_manager := vql_subsystem.NullACLManager{}
hunt_id, err := CreateHunt(self.ctx, self.config_obj, acl_manager, request)
hunt_id, err := flows.CreateHunt(
self.ctx, self.config_obj, acl_manager, request)

assert.NoError(self.T(), err)

Expand Down
2 changes: 1 addition & 1 deletion flows/limits.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ func cancelCollection(config_obj *config_proto.Config, client_id, flow_id string
&crypto_proto.VeloMessage{
Cancel: &crypto_proto.Cancel{},
SessionId: flow_id,
})
}, nil)
if err != nil {
return err
}
Expand Down
5 changes: 3 additions & 2 deletions flows/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,9 @@ func ProduceBackwardCompatibleVeloMessage(req *crypto_proto.VeloMessage) *crypto
func QueueMessageForClient(
config_obj *config_proto.Config,
client_id string,
req *crypto_proto.VeloMessage) error {
req *crypto_proto.VeloMessage,
completion func()) error {

req = ProduceBackwardCompatibleVeloMessage(req)
return clients.QueueMessageForClient(config_obj, client_id, req)
return clients.QueueMessageForClient(config_obj, client_id, req, completion)
}
11 changes: 11 additions & 0 deletions notifications/notifications.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,17 @@ func NewNotificationPool() *NotificationPool {
}
}

// ListClients returns every key of the pool's clients map as a newly
// allocated slice. The map is read while holding the pool lock. The
// order of the result follows map iteration and is therefore
// unspecified.
func (self *NotificationPool) ListClients() []string {
	self.mu.Lock()
	defer self.mu.Unlock()

	client_ids := make([]string, 0, len(self.clients))
	for client_id := range self.clients {
		client_ids = append(client_ids, client_id)
	}
	return client_ids
}

func (self *NotificationPool) IsClientConnected(client_id string) bool {
self.mu.Lock()
_, pres := self.clients[client_id]
Expand Down
2 changes: 1 addition & 1 deletion search/index.go
Original file line number Diff line number Diff line change
Expand Up @@ -282,7 +282,7 @@ func SetIndex(
}

path := path_manager.IndexTerm(term, client_id)
return db.SetSubject(config_obj, path, record)
return db.SetSubjectWithCompletion(config_obj, path, record, nil)
}

func UnsetIndex(
Expand Down
4 changes: 2 additions & 2 deletions search/mru.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,6 @@ func UpdateMRU(
FirstSeenAt: uint64(time.Now().Unix()),
}

return db.SetSubject(
config_obj, path_manager.MRUClient(client_id), item)
return db.SetSubjectWithCompletion(
config_obj, path_manager.MRUClient(client_id), item, nil)
}
3 changes: 2 additions & 1 deletion search/simple.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,8 @@ func SetSimpleIndex(
// they are user provided.
keyword = strings.ToLower(keyword)
subject := index_urn.AddChild(keyword, entity)
err := db.SetSubject(config_obj, subject, &empty.Empty{})
err := db.SetSubjectWithCompletion(
config_obj, subject, &empty.Empty{}, nil)
if err != nil {
return err
}
Expand Down
Loading

0 comments on commit 2f8bfe1

Please sign in to comment.