Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add context and lock functionality to client interface #108

Merged
merged 17 commits into from
Jun 27, 2023
191 changes: 161 additions & 30 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,32 +6,76 @@ import (

"github.com/osquery/osquery-go/gen/osquery"
"github.com/osquery/osquery-go/transport"
"github.com/pkg/errors"

"github.com/apache/thrift/lib/go/thrift"
"github.com/pkg/errors"
)

const (
defaultWaitTime = 200 * time.Millisecond
defaulMaxWaitTime = 1 * time.Minute
directionless marked this conversation as resolved.
Show resolved Hide resolved
)

// ExtensionManagerClient is a wrapper for the osquery Thrift extensions API.
type ExtensionManagerClient struct {
Client osquery.ExtensionManager
client osquery.ExtensionManager
transport thrift.TTransport

waitTime time.Duration
maxWaitTime time.Duration
lock *locker
}

type ClientOption func(*ExtensionManagerClient)

// WaitTime sets the default amount of wait time for the osquery socket to free up. You can override this on a per
// call basis by setting a context deadline
func DefaultWaitTime(d time.Duration) ClientOption {
return func(c *ExtensionManagerClient) {
c.waitTime = d
}
}

// MaxWaitTime is the maximum amount of time something is allowed to wait for the osquery socket. This takes precedence
// over the context deadline.
func MaxWaitTime(d time.Duration) ClientOption {
return func(c *ExtensionManagerClient) {
c.maxWaitTime = d
}
}

// NewClient creates a new client communicating to osquery over the socket at
// the provided path. If resolving the address or connecting to the socket
// fails, this function will error.
func NewClient(path string, timeout time.Duration) (*ExtensionManagerClient, error) {
trans, err := transport.Open(path, timeout)
if err != nil {
return nil, err
func NewClient(path string, socketOpenTimeout time.Duration, opts ...ClientOption) (*ExtensionManagerClient, error) {
c := &ExtensionManagerClient{
waitTime: defaultWaitTime,
maxWaitTime: defaulMaxWaitTime,
}

client := osquery.NewExtensionManagerClientFactory(
trans,
thrift.NewTBinaryProtocolFactoryDefault(),
)
for _, opt := range opts {
opt(c)
}

return &ExtensionManagerClient{client, trans}, nil
if c.waitTime > c.maxWaitTime {
return nil, errors.New("default wait time larger than max wait time")
}

c.lock = NewLocker(c.waitTime, c.maxWaitTime)

if c.client == nil {
trans, err := transport.Open(path, socketOpenTimeout)
if err != nil {
return nil, err
}

c.client = osquery.NewExtensionManagerClientFactory(
trans,
thrift.NewTBinaryProtocolFactoryDefault(),
)
}

return c, nil
}

// Close should be called to close the transport when use of the client is
Expand All @@ -42,48 +86,120 @@ func (c *ExtensionManagerClient) Close() {
}
}

// Ping requests metadata from the extension manager.
// Ping requests metadata from the extension manager, using a new background context
func (c *ExtensionManagerClient) Ping() (*osquery.ExtensionStatus, error) {
return c.Client.Ping(context.Background())
return c.PingContext(context.Background())
}

// Call requests a call to an extension (or core) registry plugin.
// PingContext requests metadata from the extension manager.
func (c *ExtensionManagerClient) PingContext(ctx context.Context) (*osquery.ExtensionStatus, error) {
if err := c.lock.Lock(ctx); err != nil {
return nil, err
}
defer c.lock.Unlock()
return c.client.Ping(ctx)
}

// Call requests a call to an extension (or core) registry plugin, using a new background context
func (c *ExtensionManagerClient) Call(registry, item string, request osquery.ExtensionPluginRequest) (*osquery.ExtensionResponse, error) {
return c.Client.Call(context.Background(), registry, item, request)
return c.CallContext(context.Background(), registry, item, request)
}

// Extensions requests the list of active registered extensions.
// CallContext requests a call to an extension (or core) registry plugin.
func (c *ExtensionManagerClient) CallContext(ctx context.Context, registry, item string, request osquery.ExtensionPluginRequest) (*osquery.ExtensionResponse, error) {
if err := c.lock.Lock(ctx); err != nil {
return nil, err
}
defer c.lock.Unlock()
return c.client.Call(ctx, registry, item, request)
}

// Extensions requests the list of active registered extensions, using a new background context
func (c *ExtensionManagerClient) Extensions() (osquery.InternalExtensionList, error) {
return c.Client.Extensions(context.Background())
return c.ExtensionsContext(context.Background())
}

// ExtensionsContext requests the list of active registered extensions.
func (c *ExtensionManagerClient) ExtensionsContext(ctx context.Context) (osquery.InternalExtensionList, error) {
if err := c.lock.Lock(ctx); err != nil {
return nil, err
}
defer c.lock.Unlock()
return c.client.Extensions(ctx)
}

// RegisterExtension registers the extension plugins with the osquery process.
// RegisterExtension registers the extension plugins with the osquery process, using a new background context
func (c *ExtensionManagerClient) RegisterExtension(info *osquery.InternalExtensionInfo, registry osquery.ExtensionRegistry) (*osquery.ExtensionStatus, error) {
return c.Client.RegisterExtension(context.Background(), info, registry)
return c.RegisterExtensionContext(context.Background(), info, registry)
}

// RegisterExtensionContext registers the extension plugins with the osquery process.
func (c *ExtensionManagerClient) RegisterExtensionContext(ctx context.Context, info *osquery.InternalExtensionInfo, registry osquery.ExtensionRegistry) (*osquery.ExtensionStatus, error) {
if err := c.lock.Lock(ctx); err != nil {
return nil, err
}
defer c.lock.Unlock()
return c.client.RegisterExtension(ctx, info, registry)
}

// DeregisterExtension de-registers the extension plugins with the osquery process.
// DeregisterExtension de-registers the extension plugins with the osquery process, using a new background context
func (c *ExtensionManagerClient) DeregisterExtension(uuid osquery.ExtensionRouteUUID) (*osquery.ExtensionStatus, error) {
return c.Client.DeregisterExtension(context.Background(), uuid)
return c.DeregisterExtensionContext(context.Background(), uuid)
}

// DeregisterExtensionContext de-registers the extension plugins with the osquery process.
func (c *ExtensionManagerClient) DeregisterExtensionContext(ctx context.Context, uuid osquery.ExtensionRouteUUID) (*osquery.ExtensionStatus, error) {
if err := c.lock.Lock(ctx); err != nil {
return nil, err
}
defer c.lock.Unlock()
return c.client.DeregisterExtension(ctx, uuid)
}

// Options requests the list of bootstrap or configuration options.
// Options requests the list of bootstrap or configuration options, using a new background context.
func (c *ExtensionManagerClient) Options() (osquery.InternalOptionList, error) {
return c.Client.Options(context.Background())
return c.OptionsContext(context.Background())
}

// OptionsContext requests the list of bootstrap or configuration options.
func (c *ExtensionManagerClient) OptionsContext(ctx context.Context) (osquery.InternalOptionList, error) {
if err := c.lock.Lock(ctx); err != nil {
return nil, err
}
defer c.lock.Unlock()
return c.client.Options(ctx)
}

// Query requests a query to be run and returns the extension response.
// Query requests a query to be run and returns the extension
// response, using a new background context. Consider using the
// QueryRow or QueryRows helpers for a more friendly interface.
func (c *ExtensionManagerClient) Query(sql string) (*osquery.ExtensionResponse, error) {
return c.QueryContext(context.Background(), sql)
}

// QueryContext requests a query to be run and returns the extension response.
// Consider using the QueryRow or QueryRows helpers for a more friendly
// interface.
func (c *ExtensionManagerClient) Query(sql string) (*osquery.ExtensionResponse, error) {
return c.Client.Query(context.Background(), sql)
func (c *ExtensionManagerClient) QueryContext(ctx context.Context, sql string) (*osquery.ExtensionResponse, error) {
if err := c.lock.Lock(ctx); err != nil {
return nil, err
}
defer c.lock.Unlock()
return c.client.Query(ctx, sql)
}

// QueryRows is a helper that executes the requested query and returns the
// results. It handles checking both the transport level errors and the osquery
// internal errors by returning a normal Go error type.
func (c *ExtensionManagerClient) QueryRows(sql string) ([]map[string]string, error) {
res, err := c.Query(sql)
return c.QueryRowsContext(context.Background(), sql)
}

// QueryRowsContext is a helper that executes the requested query and returns the
// results. It handles checking both the transport level errors and the osquery
// internal errors by returning a normal Go error type.
func (c *ExtensionManagerClient) QueryRowsContext(ctx context.Context, sql string) ([]map[string]string, error) {
res, err := c.QueryContext(ctx, sql)
if err != nil {
return nil, errors.Wrap(err, "transport error in query")
}
Expand All @@ -100,7 +216,13 @@ func (c *ExtensionManagerClient) QueryRows(sql string) ([]map[string]string, err
// QueryRow behaves similarly to QueryRows, but it returns an error if the
// query does not return exactly one row.
func (c *ExtensionManagerClient) QueryRow(sql string) (map[string]string, error) {
res, err := c.QueryRows(sql)
return c.QueryRowContext(context.Background(), sql)
}

// QueryRowContext behaves similarly to QueryRows, but it returns an error if the
// query does not return exactly one row.
func (c *ExtensionManagerClient) QueryRowContext(ctx context.Context, sql string) (map[string]string, error) {
res, err := c.QueryRowsContext(ctx, sql)
if err != nil {
return nil, err
}
Expand All @@ -110,7 +232,16 @@ func (c *ExtensionManagerClient) QueryRow(sql string) (map[string]string, error)
return res[0], nil
}

// GetQueryColumns requests the columns returned by the parsed query.
// GetQueryColumns requests the columns returned by the parsed query, using a new background context.
func (c *ExtensionManagerClient) GetQueryColumns(sql string) (*osquery.ExtensionResponse, error) {
return c.Client.GetQueryColumns(context.Background(), sql)
return c.GetQueryColumnsContext(context.Background(), sql)
}

// GetQueryColumnsContext requests the columns returned by the parsed query.
func (c *ExtensionManagerClient) GetQueryColumnsContext(ctx context.Context, sql string) (*osquery.ExtensionResponse, error) {
if err := c.lock.Lock(ctx); err != nil {
return nil, err
}
defer c.lock.Unlock()
return c.client.GetQueryColumns(ctx, sql)
}
95 changes: 94 additions & 1 deletion client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,23 @@ package osquery
import (
"context"
"errors"
"fmt"
"os"
"sync"
"testing"
"time"

"github.com/osquery/osquery-go/gen/osquery"
"github.com/osquery/osquery-go/mock"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func TestQueryRows(t *testing.T) {
t.Parallel()
mock := &mock.ExtensionManager{}
client := &ExtensionManagerClient{Client: mock}
client, err := NewClient("", 5*time.Second, WithOsqueryThriftClient(mock))
require.NoError(t, err)

// Transport related error
mock.QueryFunc = func(ctx context.Context, sql string) (*osquery.ExtensionResponse, error) {
Expand Down Expand Up @@ -77,3 +84,89 @@ func TestQueryRows(t *testing.T) {
row, err = client.QueryRow("select 1 union select 2")
assert.NotNil(t, err)
}

// TestLocking tests the the client correctly locks access to the osquery socket. Thrift only support a single
directionless marked this conversation as resolved.
Show resolved Hide resolved
// actor on the socket at a time, this means that in parallel go code, it's very easy to have messages get
// crossed and generate errors. This tests to ensure the locking works
func TestLocking(t *testing.T) {
t.Parallel()

sock := os.Getenv("OSQ_SOCKET")
if sock == "" {
t.Skip("no osquery socket specified")
}

osq, err := NewClient(sock, 5*time.Second)
require.NoError(t, err)

// The issue we're testing is about multithreaded access. Let's hammer on it!
wait := sync.WaitGroup{}
for i := 0; i < 100; i++ {
wait.Add(1)
go func() {
defer wait.Done()

status, err := osq.Ping()
require.NoError(t, err, "call to Ping()")
if err != nil {
require.Equal(t, 0, status.Code, fmt.Errorf("ping returned %d: %s", status.Code, status.Message))
}
}()
}

wait.Wait()
}

func TestLockTimeouts(t *testing.T) {
t.Parallel()
mock := &mock.ExtensionManager{}
client, err := NewClient("", 5*time.Second, WithOsqueryThriftClient(mock), DefaultWaitTime(100*time.Millisecond), DefaultWaitTime(5*time.Second))
require.NoError(t, err)

wait := sync.WaitGroup{}

errChan := make(chan error, 10)
for i := 0; i < 3; i++ {
wait.Add(1)
go func() {
defer wait.Done()

ctx, cancel := context.WithTimeout(context.TODO(), 100*time.Millisecond)
defer cancel()

errChan <- client.SlowLocker(ctx, 75*time.Millisecond)
}()
}

wait.Wait()
close(errChan)

var successCount, errCount int
for err := range errChan {
if err == nil {
successCount += 1
} else {
errCount += 1
}
}

assert.Equal(t, 2, successCount, "expected success count")
assert.Equal(t, 1, errCount, "expected error count")
}

// WithOsqueryThriftClient sets the underlying thrift client. This can be used to set a mock
func WithOsqueryThriftClient(client osquery.ExtensionManager) ClientOption {
return func(c *ExtensionManagerClient) {
c.client = client
}
}

// SlowLocker attempts to emulate a slow sql routine, so we can test how lock timeouts work.
func (c *ExtensionManagerClient) SlowLocker(ctx context.Context, d time.Duration) error {
if err := c.lock.Lock(ctx); err != nil {
return err
}
defer c.lock.Unlock()
time.Sleep(d)
return nil
}
10 changes: 7 additions & 3 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,15 @@ module github.com/osquery/osquery-go
require (
github.com/Microsoft/go-winio v0.4.9
github.com/apache/thrift v0.16.0
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/pkg/errors v0.8.0
github.com/stretchr/testify v1.8.3
)

require (
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/stretchr/testify v1.2.2
golang.org/x/sys v0.1.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)

go 1.16
go 1.19
Loading