Skip to content

Commit 10508ca

Browse files
authored
fix: introduce fallback stub for ToolMetricsClient in dev and prod (#580)
- Refactor `ShouldFlag` signatures to accept an organisation ID - Introduce a fallback stub for `ToolMetricsClient` in dev and prod.
1 parent 1d3d316 commit 10508ca

File tree

9 files changed

+64
-20
lines changed

9 files changed

+64
-20
lines changed

server/cmd/gram/deps.go

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -97,22 +97,23 @@ func newToolMetricsClient(ctx context.Context, logger *slog.Logger, c *cli.Conte
9797
},
9898
})
9999
if err != nil {
100-
return nil, nilFunc, fmt.Errorf("connect to clickhouse: %w", err)
100+
logger.WarnContext(ctx, "error connecting to clickhouse; falling back to stub tool call metrics client")
101+
return &tm.StubToolMetricsClient{}, func(context.Context) error { return nil }, nil
101102
}
102103

103104
if err = conn.Ping(ctx); err != nil {
104-
return nil, nilFunc, fmt.Errorf("ping clickhouse: %w", err)
105+
logger.WarnContext(ctx, "failed to ping clickhouse; falling back to stub tool call metrics client", attr.SlogError(err))
106+
return &tm.StubToolMetricsClient{}, func(context.Context) error { return nil }, nil
105107
}
106108

107-
cc := tm.New(logger, tracerProvider, conn, func(ctx context.Context, log tm.ToolHTTPRequest) (bool, error) {
109+
cc := tm.New(logger, tracerProvider, conn, func(ctx context.Context, orgId string) (bool, error) {
108110
f := conv.Default[feature.Provider](features, &feature.InMemory{})
109-
isEnabled, err := f.IsFlagEnabled(ctx, feature.FlagClickhouseToolMetrics, log.OrganizationID)
111+
isEnabled, err := f.IsFlagEnabled(ctx, feature.FlagClickhouseToolMetrics, orgId)
110112
if err != nil {
111113
logger.ErrorContext(
112114
ctx, "error checking clickhouse feature flag",
113115
attr.SlogError(err),
114-
attr.SlogOrganizationSlug(log.OrganizationID),
115-
attr.SlogProjectID(log.ProjectID),
116+
attr.SlogOrganizationSlug(orgId),
116117
)
117118
return false, fmt.Errorf("check clickhouse feature flag: %w", err)
118119
}

server/internal/gateway/proxy.go

Lines changed: 25 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -625,13 +625,31 @@ func reverseProxyRequest(
625625
MaxIdleConnsPerHost: runtime.GOMAXPROCS(0) + 1,
626626
}
627627

628-
// Wrap with HTTP logging round tripper
629-
loggingTransport := tm.NewHTTPLoggingRoundTripper(transport, tcm, logger, tracer)
630-
631-
otelTransport := otelhttp.NewTransport(
632-
loggingTransport,
633-
otelhttp.WithPropagators(propagation.TraceContext{}),
634-
)
628+
isAllowed, err := tcm.ShouldLog(ctx, tool.OrganizationID)
629+
if err != nil {
630+
// If we can't determine if the tool is allowed to log, we won't log the request.
631+
isAllowed = false
632+
logger.ErrorContext(ctx,
633+
"failed to determine if tool is allowed to log",
634+
attr.SlogOrganizationID(tool.OrganizationID),
635+
attr.SlogToolName(tool.Name),
636+
attr.SlogError(err))
637+
}
638+
639+
var otelTransport *otelhttp.Transport
640+
if isAllowed {
641+
// Wrap with HTTP logging round tripper
642+
loggingTransport := tm.NewHTTPLoggingRoundTripper(transport, tcm, logger, tracer)
643+
otelTransport = otelhttp.NewTransport(
644+
loggingTransport,
645+
otelhttp.WithPropagators(propagation.TraceContext{}),
646+
)
647+
} else {
648+
otelTransport = otelhttp.NewTransport(
649+
transport,
650+
otelhttp.WithPropagators(propagation.TraceContext{}),
651+
)
652+
}
635653

636654
client := &http.Client{
637655
Timeout: 60 * time.Second,

server/internal/gateway/proxy_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ func newClickhouseClient(t *testing.T, orgId string) *toolmetrics.Queries {
7070

7171
tracerProvider := testenv.NewTracerProvider(t)
7272

73-
ch := toolmetrics.New(testenv.NewLogger(t), tracerProvider, chConn, func(ctx context.Context, log toolmetrics.ToolHTTPRequest) (bool, error) {
73+
ch := toolmetrics.New(testenv.NewLogger(t), tracerProvider, chConn, func(context.Context, string) (bool, error) {
7474
return true, nil
7575
})
7676

server/internal/logs/setup_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ func newTestLogsService(t *testing.T) (context.Context, *testInstance) {
7474

7575
tracerProvider := testenv.NewTracerProvider(t)
7676

77-
chClient := toolmetrics.New(logger, tracerProvider, chConn, func(ctx context.Context, log toolmetrics.ToolHTTPRequest) (bool, error) {
77+
chClient := toolmetrics.New(logger, tracerProvider, chConn, func(context.Context, string) (bool, error) {
7878
return true, nil
7979
})
8080

server/internal/mcp/setup_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ func newTestMCPService(t *testing.T) (context.Context, *testInstance) {
9494
chConn, err := infra.NewClickhouseClient(t)
9595
require.NoError(t, err)
9696

97-
toolMetrics := toolmetrics.New(logger, tracerProvider, chConn, func(ctx context.Context, log toolmetrics.ToolHTTPRequest) (bool, error) {
97+
toolMetrics := toolmetrics.New(logger, tracerProvider, chConn, func(context.Context, string) (bool, error) {
9898
return true, nil
9999
})
100100

server/internal/thirdparty/toolmetrics/db.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ type Queries struct {
2020
conn CHTX
2121
logger *slog.Logger
2222
tracer trace.Tracer
23-
ShouldFlag func(ctx context.Context, log ToolHTTPRequest) (bool, error)
23+
ShouldFlag func(ctx context.Context, orgId string) (bool, error)
2424
}
2525

2626
// WithConn returns a new Queries instance using the provided connection.
@@ -34,9 +34,9 @@ func (q *Queries) WithConn(conn CHTX) *Queries {
3434
}
3535

3636
// New creates a new Queries instance with logger and tracer.
37-
func New(logger *slog.Logger, traceProvider trace.TracerProvider, conn CHTX, shouldFlag func(ctx context.Context, log ToolHTTPRequest) (bool, error)) *Queries {
37+
func New(logger *slog.Logger, traceProvider trace.TracerProvider, conn CHTX, shouldFlag func(ctx context.Context, orgId string) (bool, error)) *Queries {
3838
if shouldFlag == nil {
39-
shouldFlag = func(ctx context.Context, log ToolHTTPRequest) (bool, error) {
39+
shouldFlag = func(ctx context.Context, orgId string) (bool, error) {
4040
return true, nil
4141
}
4242
}

server/internal/thirdparty/toolmetrics/models.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,4 +96,6 @@ type ToolMetricsProvider interface {
9696
List(ctx context.Context, opts ListToolLogsOptions) (*ListResult, error)
9797
// Log tool call request/response
9898
Log(context.Context, ToolHTTPRequest) error
99+
// ShouldLog returns true if the tool call should be logged
100+
ShouldLog(context.Context, string) (bool, error)
99101
}

server/internal/thirdparty/toolmetrics/queries.sql.go

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,10 @@ order by ts
3737
limit $5
3838
`
3939

40+
func (q *Queries) ShouldLog(ctx context.Context, orgId string) (bool, error) {
41+
return q.ShouldFlag(ctx, orgId)
42+
}
43+
4044
// List retrieves tool logs based on the provided options.
4145
func (q *Queries) List(ctx context.Context, opts ListToolLogsOptions) (res *ListResult, err error) {
4246
projectID := opts.ProjectID
@@ -133,7 +137,7 @@ func (q *Queries) List(ctx context.Context, opts ListToolLogsOptions) (res *List
133137

134138
// Log inserts a tool HTTP request log entry.
135139
func (q *Queries) Log(ctx context.Context, log ToolHTTPRequest) (err error) {
136-
allow, err := q.ShouldFlag(ctx, log)
140+
allow, err := q.ShouldFlag(ctx, log.OrganizationID)
137141
if err != nil {
138142
q.logger.ErrorContext(ctx, "failed to fetch feature flag", attr.SlogError(err))
139143
return nil
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
package toolmetrics
2+
3+
import (
4+
"context"
5+
)
6+
7+
type StubToolMetricsClient struct{}
8+
9+
func (n *StubToolMetricsClient) List(_ context.Context, _ ListToolLogsOptions) (*ListResult, error) {
10+
return nil, nil
11+
}
12+
13+
func (n *StubToolMetricsClient) Log(_ context.Context, _ ToolHTTPRequest) error {
14+
return nil
15+
}
16+
17+
func (n *StubToolMetricsClient) ShouldLog(_ context.Context, _ string) (bool, error) {
18+
return true, nil
19+
}

0 commit comments

Comments
 (0)