Skip to content

Commit

Permalink
Performance optimizations (Velocidex#286)
Browse files Browse the repository at this point in the history
* Disable compression of messages from server to client - these are typically small and compressing them does not buy much but increases server CPU load
* Implement GRPC client pooling to prevent reconnections - this reduces a lot of TLS type RSA operations.
* Precompile the artifact into the hunt. This improves performance (Since the same pre-compiled VQL can be sent for each client) and also results in a more correct hunt: If the artifact definition is modified after the hunt is launched it won't affect the hunt anymore.
  • Loading branch information
scudette authored Mar 24, 2020
1 parent 618719d commit 9fd0011
Show file tree
Hide file tree
Showing 32 changed files with 635 additions and 471 deletions.
5 changes: 3 additions & 2 deletions api/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ func (self *ApiServer) CancelFlow(
}

result, err := flows.CancelFlow(
ctx,
self.config, in.ClientId, in.FlowId, user_name,
self.api_client_factory)
if err != nil {
Expand Down Expand Up @@ -170,7 +171,7 @@ func (self *ApiServer) CollectArtifact(
result.FlowId = flow_id

// Notify the client if it is listenning.
client, cancel := self.api_client_factory.GetAPIClient(self.config)
client, cancel := self.api_client_factory.GetAPIClient(ctx, self.config)
defer cancel()

_, err = client.NotifyClients(ctx, &api_proto.NotificationRequest{
Expand Down Expand Up @@ -244,7 +245,7 @@ func (self *ApiServer) ModifyHunt(
"details": fmt.Sprintf("%v", in),
}).Info("ModifyHunt")

err = flows.ModifyHunt(self.config, in, in.Creator)
err = flows.ModifyHunt(ctx, self.config, in, in.Creator)
if err != nil {
return nil, err
}
Expand Down
282 changes: 136 additions & 146 deletions api/proto/hunts.pb.go

Large diffs are not rendered by default.

2 changes: 0 additions & 2 deletions api/proto/hunts.proto
Original file line number Diff line number Diff line change
Expand Up @@ -154,8 +154,6 @@ message Hunt {
State state = 8 [(sem_type) = {
description: "This is state of the hunt. This field is manupulated by the GUI."
}];

VQLCollectorArgs compiled_collector_args = 20;
}

message ListHuntsRequest {
Expand Down
6 changes: 3 additions & 3 deletions artifacts/assets/ab0x.go

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions artifacts/definitions/Windows/System/TaskScheduler.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,14 @@ description: |
parameters:
- name: TasksPath
default: c:/Windows/System32/Tasks/**
- name: alsoUpload
- name: AlsoUpload
type: bool

sources:
- name: Analysis
queries:
- LET Uploads = SELECT Name, FullPath, if(
condition=alsoUpload,
condition=AlsoUpload='Y',
then=upload(file=FullPath)) as Upload
FROM glob(globs=TasksPath)
WHERE NOT IsDir
Expand Down
55 changes: 27 additions & 28 deletions bin/fuse.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,13 +48,14 @@ type VFSFs struct {
logger *logging.LogContext
}

func (self *VFSFs) fetchDir(vfs_name string) ([]*api.FileInfoRow, error) {
func (self *VFSFs) fetchDir(
ctx context.Context,
vfs_name string) ([]*api.FileInfoRow, error) {
self.logger.Info(fmt.Sprintf("Fetching dir %v from %v", vfs_name, self.client_id))
channel := grpc_client.GetChannel(self.config_obj)
defer channel.Close()
client, closer := grpc_client.Factory.GetAPIClient(ctx, self.config_obj)
defer closer()

client := api_proto.NewAPIClient(channel)
response, err := client.VFSRefreshDirectory(context.Background(),
response, err := client.VFSRefreshDirectory(ctx,
&api_proto.VFSRefreshDirectoryRequest{
ClientId: self.client_id,
VfsPath: vfs_name,
Expand Down Expand Up @@ -83,19 +84,18 @@ func (self *VFSFs) fetchDir(vfs_name string) ([]*api.FileInfoRow, error) {
time.Sleep(200 * time.Millisecond)
}

return self.getDir(vfs_name)
return self.getDir(ctx, vfs_name)
}

func (self *VFSFs) fetchFile(vfs_name string) error {
func (self *VFSFs) fetchFile(
ctx context.Context,
vfs_name string) error {
self.logger.Info("Fetching file %v", vfs_name)

channel := grpc_client.GetChannel(self.config_obj)
defer channel.Close()
client, closer := grpc_client.Factory.GetAPIClient(ctx, self.config_obj)
defer closer()

client_path, accessor := api.GetClientPath(vfs_name)

client := api_proto.NewAPIClient(channel)

request := api.MakeCollectorRequest(
self.client_id, "System.VFS.DownloadFile",
"Path", client_path, "Key", accessor)
Expand Down Expand Up @@ -143,9 +143,9 @@ func (self *VFSFs) GetAttr(name string, fcontext *fuse.Context) (*fuse.Attr, fus
vfs_name := fsPathToVFS(name)

dirname, basename := path.Split(vfs_name)
rows, err := self.getDir(dirname)
rows, err := self.getDir(fcontext, dirname)
if err != nil {
rows, err = self.fetchDir(dirname)
rows, err = self.fetchDir(fcontext, dirname)
if err != nil {
self.logger.Error(
fmt.Sprintf("Failed to fetch %s: %v", dirname, err))
Expand Down Expand Up @@ -173,21 +173,21 @@ func (self *VFSFs) GetAttr(name string, fcontext *fuse.Context) (*fuse.Attr, fus
return nil, fuse.ENOENT
}

func (self *VFSFs) getDir(vfs_name string) ([]*api.FileInfoRow, error) {
func (self *VFSFs) getDir(
ctx context.Context,
vfs_name string) ([]*api.FileInfoRow, error) {
rows, pres := self.cache[vfs_name]
if pres {
return rows, nil
}

channel := grpc_client.GetChannel(self.config_obj)
defer channel.Close()
client, closer := grpc_client.Factory.GetAPIClient(ctx, self.config_obj)
defer closer()

request := &flows_proto.VFSListRequest{
ClientId: self.client_id,
VfsPath: vfs_name,
}

client := api_proto.NewAPIClient(channel)
response, err := client.VFSListDirectory(context.Background(), request)
if err != nil {
return nil, err
Expand All @@ -206,10 +206,10 @@ func (self *VFSFs) getDir(vfs_name string) ([]*api.FileInfoRow, error) {
func (self *VFSFs) OpenDir(fs_name string, fcontext *fuse.Context) (
[]fuse.DirEntry, fuse.Status) {
vfs_name := fsPathToVFS(fs_name)
rows, err := self.getDir(vfs_name)
rows, err := self.getDir(fcontext, vfs_name)
if err != nil {
self.logger.Warn(fmt.Sprintf("Fetching directory %s", vfs_name))
rows, err = self.fetchDir(vfs_name)
rows, err = self.fetchDir(fcontext, vfs_name)
if err != nil {
return nil, fuse.ENOENT
}
Expand Down Expand Up @@ -237,17 +237,16 @@ func (self *VFSFs) Open(fs_name string, flags uint32, fcontext *fuse.Context) (

vfs_name := fsPathToVFS(fs_name)

channel := grpc_client.GetChannel(self.config_obj)
defer channel.Close()
client, closer := grpc_client.Factory.GetAPIClient(fcontext, self.config_obj)
defer closer()

client := api_proto.NewAPIClient(channel)
_, err := client.VFSGetBuffer(context.Background(),
&api_proto.VFSFileBuffer{
ClientId: self.client_id,
VfsPath: vfs_name,
})
if err != nil {
err := self.fetchFile(vfs_name)
err := self.fetchFile(fcontext, vfs_name)
if err != nil {
_, ok := errors.Cause(err).(*os.PathError)
if ok {
Expand Down Expand Up @@ -294,10 +293,10 @@ func (self *VFSFileReader) GetAttr(out *fuse.Attr) fuse.Status {
func (self *VFSFileReader) Read(dest []byte, off int64) (
fuse.ReadResult, fuse.Status) {

channel := grpc_client.GetChannel(self.config_obj)
defer channel.Close()
client, closer := grpc_client.Factory.GetAPIClient(
context.Background(), self.config_obj)
defer closer()

client := api_proto.NewAPIClient(channel)
response, err := client.VFSGetBuffer(context.Background(),
&api_proto.VFSFileBuffer{
ClientId: self.client_id,
Expand Down
39 changes: 20 additions & 19 deletions bin/fuse_windows.go
Original file line number Diff line number Diff line change
Expand Up @@ -138,14 +138,12 @@ func (self *VFSFs) Read(file_path string, buff []byte, off int64, fd uint64) (n
}

// We need to fetch the file from the client.
ctx := context.Background()
self.logger.Info("Fetching file %v", vfs_name)

channel := grpc_client.GetChannel(self.config_obj)
defer channel.Close()
client, closer := grpc_client.Factory.GetAPIClient(ctx, self.config_obj)
defer closer()

client_path, accessor := api.GetClientPath(vfs_name)

client := api_proto.NewAPIClient(channel)
response, err := client.CollectArtifact(context.Background(),
api.MakeCollectorRequest(self.client_id, "System.VFS.DownloadFile",
"Path", client_path, "Accessor", accessor))
Expand Down Expand Up @@ -174,10 +172,10 @@ func (self *VFSFs) Read(file_path string, buff []byte, off int64, fd uint64) (n
}

func (self *VFSFs) read_buffer(vfs_name string, buff []byte, off int64, fh uint64) (n int) {
channel := grpc_client.GetChannel(self.config_obj)
defer channel.Close()
ctx := context.Background()
client, closer := grpc_client.Factory.GetAPIClient(ctx, self.config_obj)
defer closer()

client := api_proto.NewAPIClient(channel)
response, err := client.VFSGetBuffer(context.Background(),
&api_proto.VFSFileBuffer{
ClientId: self.client_id,
Expand Down Expand Up @@ -274,10 +272,10 @@ func (self *VFSFs) GetDir(vfs_name string) ([]*api.FileInfoRow, int) {
}

// Not there - initiate a new client flow.
channel := grpc_client.GetChannel(self.config_obj)
defer channel.Close()
ctx := context.Background()
client, closer := grpc_client.Factory.GetAPIClient(ctx, self.config_obj)
defer closer()

client := api_proto.NewAPIClient(channel)
response, err := client.VFSRefreshDirectory(context.Background(),
&api_proto.VFSRefreshDirectoryRequest{
ClientId: self.client_id,
Expand Down Expand Up @@ -317,14 +315,14 @@ func (self *VFSFs) GetDir(vfs_name string) ([]*api.FileInfoRow, int) {

func (self *VFSFs) isFlowComplete(flow_id, vfs_name string) (bool, int) {
// Check if the flow is completed yet.
channel := grpc_client.GetChannel(self.config_obj)
defer channel.Close()
ctx := context.Background()
client, closer := grpc_client.Factory.GetAPIClient(ctx, self.config_obj)
defer closer()

req := &api_proto.ApiFlowRequest{
ClientId: self.client_id,
FlowId: flow_id,
}
client := api_proto.NewAPIClient(channel)
response, err := client.GetFlowDetails(context.Background(), req)
if err != nil {
self.logger.Warn("GetFlowDetails %s: %v", vfs_name, err)
Expand All @@ -339,15 +337,14 @@ func (self *VFSFs) isFlowComplete(flow_id, vfs_name string) (bool, int) {
}

func (self *VFSFs) getDir(vfs_name string) ([]*api.FileInfoRow, error) {
channel := grpc_client.GetChannel(self.config_obj)
defer channel.Close()
ctx := context.Background()
client, closer := grpc_client.Factory.GetAPIClient(ctx, self.config_obj)
defer closer()

request := &flows_proto.VFSListRequest{
ClientId: self.client_id,
VfsPath: vfs_name,
}

client := api_proto.NewAPIClient(channel)
response, err := client.VFSListDirectory(context.Background(), request)
if err != nil {
self.logger.Warn("VFSListDirectory error %s (%v)", vfs_name, err)
Expand Down Expand Up @@ -401,7 +398,11 @@ func NewVFSFs(config_obj *config_proto.Config, client_id string) *VFSFs {

func doFuse() {
config_obj := get_config_or_default()
grpc_client.GetChannel(config_obj)

// Connect one time to make sure we can.
ctx := context.Background()
_, closer := grpc_client.Factory.GetAPIClient(ctx, config_obj)
closer()

args := []string{*fuse_command_mnt_point,
// Winfsp uses very few threads (2) which may cause a
Expand Down
11 changes: 5 additions & 6 deletions bin/shell.go
Original file line number Diff line number Diff line change
Expand Up @@ -129,10 +129,10 @@ func shell_executor(config_obj *config_proto.Config,
kingpin.FatalIfError(err, "Sending client message ")

api_client_factory := grpc_client.GRPCAPIClient{}
client, cancel := api_client_factory.GetAPIClient(config_obj)
client, cancel := api_client_factory.GetAPIClient(ctx, config_obj)
defer cancel()

_, err = client.NotifyClients(context.Background(),
_, err = client.NotifyClients(ctx,
&api_proto.NotificationRequest{ClientId: *shell_client})
kingpin.FatalIfError(err, "Sending client message ")
}()
Expand Down Expand Up @@ -213,11 +213,10 @@ func completer(t prompt.Document) []prompt.Suggest {
}

func getClientInfo(config_obj *config_proto.Config, ctx context.Context) (*api_proto.ApiClient, error) {
channel := grpc_client.GetChannel(config_obj)
defer channel.Close()
client, closer := grpc_client.Factory.GetAPIClient(ctx, config_obj)
defer closer()

client := api_proto.NewAPIClient(channel)
return client.GetClient(context.Background(), &api_proto.GetClientRequest{
return client.GetClient(ctx, &api_proto.GetClientRequest{
ClientId: *shell_client,
})
}
Expand Down
13 changes: 9 additions & 4 deletions crypto/crypto_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,9 @@ func (self *TestSuite) TestEncDecServerToClient() {
assert.NoError(t, err)

cipher_text, err := self.server_manager.Encrypt(
[][]byte{serialized}, self.client_id)
[][]byte{serialized},
crypto_proto.PackedMessageList_ZCOMPRESSION,
self.client_id)
assert.NoError(t, err)

initial_c := testutil.ToFloat64(rsaDecryptCounter)
Expand All @@ -99,7 +101,7 @@ func (self *TestSuite) TestEncDecServerToClient() {
t.Fatal(err)
}
message_info.IterateJobs(context.Background(),
func(item *crypto_proto.GrrMessage) {
func(ctx context.Context, item *crypto_proto.GrrMessage) {
assert.Equal(t, item.Name, "OMG it's a string")
assert.Equal(t, item.AuthState, crypto_proto.GrrMessage_AUTHENTICATED)
})
Expand All @@ -122,7 +124,9 @@ func (self *TestSuite) TestEncDecClientToServer() {

config_obj := config.GetDefaultConfig()
cipher_text, err := self.client_manager.EncryptMessageList(
message_list, config_obj.Client.PinnedServerName)
message_list,
crypto_proto.PackedMessageList_ZCOMPRESSION,
config_obj.Client.PinnedServerName)
assert.NoError(t, err)

initial_c := testutil.ToFloat64(rsaDecryptCounter)
Expand All @@ -135,7 +139,7 @@ func (self *TestSuite) TestEncDecClientToServer() {
}

message_info.IterateJobs(context.Background(),
func(item *crypto_proto.GrrMessage) {
func(ctx context.Context, item *crypto_proto.GrrMessage) {
assert.Equal(t, item.Name, "OMG it's a string")
assert.Equal(
t, item.AuthState, crypto_proto.GrrMessage_AUTHENTICATED)
Expand All @@ -157,6 +161,7 @@ func (self *TestSuite) TestEncryption() {
for i := 0; i < 100; i++ {
cipher_text, err := self.client_manager.Encrypt(
[][]byte{Compress(plain_text)},
crypto_proto.PackedMessageList_ZCOMPRESSION,
config_obj.Client.PinnedServerName)
assert.NoError(t, err)

Expand Down
8 changes: 6 additions & 2 deletions crypto/testing_utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,16 @@ func (self *NullCryptoManager) EncryptMessageList(
}

cipher_text, err := self.Encrypt(
[][]byte{Compress(plain_text)}, destination)
[][]byte{Compress(plain_text)},
crypto_proto.PackedMessageList_ZCOMPRESSION,
destination)
return cipher_text, err
}

func (self *NullCryptoManager) Encrypt(
compressed_message_lists [][]byte, destination string) (
compressed_message_lists [][]byte,
compression crypto_proto.PackedMessageList_CompressionType,
destination string) (
[]byte, error) {
packed_message_list := &crypto_proto.PackedMessageList{
MessageList: compressed_message_lists,
Expand Down
Loading

0 comments on commit 9fd0011

Please sign in to comment.