Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 23 additions & 10 deletions driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,8 @@ func RegisterWithSource(driverName string, source string, options ...DriverOptio
return "", errors.New("unable to register driver, all slots have been taken")
}

// Wrap takes a SQL driver and wraps it with OpenTelemetry instrumentation.
// Wrap takes an SQL driver and wraps it with OpenTelemetry instrumentation.
// It panics if there is an error when creating instruments.
func Wrap(d driver.Driver, opts ...DriverOption) driver.Driver {
o := driverOptions{
meterProvider: otel.GetMeterProvider(),
Expand All @@ -102,14 +103,18 @@ func Wrap(d driver.Driver, opts ...DriverOption) driver.Driver {
option.applyDriverOptions(&o)
}

return wrapDriver(d, o)
cc, err := newConnConfig(o)
if err != nil {
panic(err)
}

return wrapDriver(d, cc)
}

func wrapDriver(d driver.Driver, o driverOptions) driver.Driver {
func wrapDriver(d driver.Driver, cc connConfig) driver.Driver {
drv := otDriver{
parent: d,
connConfig: newConnConfig(o),
close: func() error { return nil },
connConfig: cc,
}

if _, ok := d.(driver.DriverContext); ok {
Expand All @@ -122,7 +127,7 @@ func wrapDriver(d driver.Driver, o driverOptions) driver.Driver {
return struct{ driver.Driver }{drv}
}

func newConnConfig(opts driverOptions) connConfig {
func newConnConfig(opts driverOptions) (connConfig, error) {
meter := opts.meterProvider.Meter(instrumentationName)
tracer := newMethodTracer(
opts.tracerProvider.Tracer(instrumentationName,
Expand All @@ -139,13 +144,17 @@ func newConnConfig(opts driverOptions) connConfig {
metric.WithUnit(unitMilliseconds),
metric.WithDescription(`The distribution of latencies of various calls in milliseconds`),
)
mustNoError(err)
if err != nil {
return connConfig{}, err
}

callsCounter, err := meter.Int64Counter(dbSQLClientCalls,
metric.WithUnit(unitDimensionless),
metric.WithDescription(`The number of various calls of methods`),
)
mustNoError(err)
if err != nil {
return connConfig{}, err
}

latencyRecorder := newMethodRecorder(latencyMsHistogram.Record, callsCounter.Add, opts.defaultAttributes...)

Expand All @@ -161,7 +170,7 @@ func newConnConfig(opts driverOptions) connConfig {
queryFuncMiddlewares: makeQueryerContextMiddlewares(latencyRecorder, tracerOrNil(tracer, opts.trace.AllowRoot), newQueryConfig(opts, metricMethodStmtQuery, traceMethodStmtQuery)),
queryContextFuncMiddlewares: makeQueryerContextMiddlewares(latencyRecorder, tracer, newQueryConfig(opts, metricMethodStmtQuery, traceMethodStmtQuery)),
}),
}
}, nil
}

var _ driver.Driver = (*otDriver)(nil)
Expand All @@ -184,7 +193,11 @@ func (d otDriver) Open(name string) (driver.Conn, error) {
}

func (d otDriver) Close() error {
return d.close()
if d.close != nil {
return d.close()
}

return nil
}

func (d otDriver) OpenConnector(name string) (driver.Connector, error) {
Expand Down
16 changes: 16 additions & 0 deletions driver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,22 @@ func TestWrap_DriverContext_CloseError(t *testing.T) {
assert.Equal(t, expectedError, err)
}

func TestWrap_Panic(t *testing.T) {
t.Parallel()

parent := driverOpenFunc(func(string) (driver.Conn, error) {
return nil, errors.New("open error")
})

meterProviderOption := otelsql.WithMeterProvider(
oteltest.NewMeterProviderWithError(assert.AnError),
)

assert.PanicsWithValue(t, assert.AnError, func() {
_ = otelsql.Wrap(parent, meterProviderOption)
})
}

func Test_Open_Error(t *testing.T) {
t.Parallel()

Expand Down
15 changes: 0 additions & 15 deletions errors.go

This file was deleted.

28 changes: 0 additions & 28 deletions errors_internal_test.go

This file was deleted.

92 changes: 92 additions & 0 deletions internal/test/oteltest/errors.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
package oteltest

import (
"go.opentelemetry.io/otel/metric"
"go.opentelemetry.io/otel/metric/embedded"
)

type errorMeterProvider struct {
embedded.MeterProvider

Error error
}

// NewMeterProviderWithError returns a new [metric.MeterProvider] that always
// returns the given error.
func NewMeterProviderWithError(e error) metric.MeterProvider {
return &errorMeterProvider{
Error: e,
}
}

func (e *errorMeterProvider) Meter(string, ...metric.MeterOption) metric.Meter {
return &errorMeter{
Error: e.Error,
}
}

type errorMeter struct {
embedded.Meter

Error error
}

func (e *errorMeter) Int64Counter(string, ...metric.Int64CounterOption) (metric.Int64Counter, error) {
return nil, e.Error
}

func (e *errorMeter) Int64UpDownCounter(string, ...metric.Int64UpDownCounterOption) (metric.Int64UpDownCounter, error) {
return nil, e.Error
}

func (e *errorMeter) Int64Histogram(string, ...metric.Int64HistogramOption) (metric.Int64Histogram, error) {
return nil, e.Error
}

func (e *errorMeter) Int64Gauge(string, ...metric.Int64GaugeOption) (metric.Int64Gauge, error) {
return nil, e.Error
}

func (e *errorMeter) Int64ObservableCounter(string, ...metric.Int64ObservableCounterOption) (metric.Int64ObservableCounter, error) {
return nil, e.Error
}

func (e *errorMeter) Int64ObservableUpDownCounter(string, ...metric.Int64ObservableUpDownCounterOption) (metric.Int64ObservableUpDownCounter, error) {
return nil, e.Error
}

func (e *errorMeter) Int64ObservableGauge(string, ...metric.Int64ObservableGaugeOption) (metric.Int64ObservableGauge, error) {
return nil, e.Error
}

func (e *errorMeter) Float64Counter(string, ...metric.Float64CounterOption) (metric.Float64Counter, error) {
return nil, e.Error
}

func (e *errorMeter) Float64UpDownCounter(string, ...metric.Float64UpDownCounterOption) (metric.Float64UpDownCounter, error) {
return nil, e.Error
}

func (e *errorMeter) Float64Histogram(string, ...metric.Float64HistogramOption) (metric.Float64Histogram, error) {
return nil, e.Error
}

func (e *errorMeter) Float64Gauge(string, ...metric.Float64GaugeOption) (metric.Float64Gauge, error) {
return nil, e.Error
}

func (e *errorMeter) Float64ObservableCounter(string, ...metric.Float64ObservableCounterOption) (metric.Float64ObservableCounter, error) {
return nil, e.Error
}

func (e *errorMeter) Float64ObservableUpDownCounter(string, ...metric.Float64ObservableUpDownCounterOption) (metric.Float64ObservableUpDownCounter, error) {
return nil, e.Error
}

func (e *errorMeter) Float64ObservableGauge(string, ...metric.Float64ObservableGaugeOption) (metric.Float64ObservableGauge, error) {
return nil, e.Error
}

func (e *errorMeter) RegisterCallback(metric.Callback, ...metric.Observable) (metric.Registration, error) {
return nil, e.Error
}
32 changes: 24 additions & 8 deletions stats.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,56 +75,72 @@ func recordStats(
metric.WithUnit(unitDimensionless),
metric.WithDescription("Count of open connections in the pool"),
)
handleErr(err)
if err != nil {
return err
}

idleConnections, err = meter.Int64ObservableGauge(
dbSQLConnectionsIdle,
metric.WithUnit(unitDimensionless),
metric.WithDescription("Count of idle connections in the pool"),
)
handleErr(err)
if err != nil {
return err
}

activeConnections, err = meter.Int64ObservableGauge(
dbSQLConnectionsActive,
metric.WithUnit(unitDimensionless),
metric.WithDescription("Count of active connections in the pool"),
)
handleErr(err)
if err != nil {
return err
}

waitCount, err = meter.Int64ObservableCounter(
dbSQLConnectionsWaitCount,
metric.WithUnit(unitDimensionless),
metric.WithDescription("The total number of connections waited for"),
)
handleErr(err)
if err != nil {
return err
}

waitDuration, err = meter.Float64ObservableCounter(
dbSQLConnectionsWaitDuration,
metric.WithUnit(unitMilliseconds),
metric.WithDescription("The total time blocked waiting for a new connection"),
)
handleErr(err)
if err != nil {
return err
}

idleClosed, err = meter.Int64ObservableCounter(
dbSQLConnectionsIdleClosed,
metric.WithUnit(unitDimensionless),
metric.WithDescription("The total number of connections closed due to SetMaxIdleConns"),
)
handleErr(err)
if err != nil {
return err
}

idleTimeClosed, err = meter.Int64ObservableCounter(
dbSQLConnectionsIdleTimeClosed,
metric.WithUnit(unitDimensionless),
metric.WithDescription("The total number of connections closed due to SetConnMaxIdleTime"),
)
handleErr(err)
if err != nil {
return err
}

lifetimeClosed, err = meter.Int64ObservableCounter(
dbSQLConnectionsLifetimeClosed,
metric.WithUnit(unitDimensionless),
metric.WithDescription("The total number of connections closed due to SetConnMaxLifetime"),
)
handleErr(err)
if err != nil {
return err
}

_, err = meter.RegisterCallback(func(_ context.Context, obs metric.Observer) error {
lock.Lock()
Expand Down
15 changes: 15 additions & 0 deletions stats_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"testing"
"time"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
semconv "go.opentelemetry.io/otel/semconv/v1.20.0"

Expand All @@ -12,6 +13,20 @@ import (
"go.nhat.io/otelsql/internal/test/sqlmock"
)

func TestRecordStatsError(t *testing.T) {
t.Parallel()

oteltest.New().Run(t, func(sc oteltest.SuiteContext) {
db, err := newDB(sc.DatabaseDSN())
require.NoError(t, err)

err = otelsql.RecordStats(db, otelsql.WithMeterProvider(
oteltest.NewMeterProviderWithError(assert.AnError),
))
require.ErrorIs(t, err, assert.AnError)
})
}

func TestRecordStats(t *testing.T) {
t.Parallel()

Expand Down