diff --git a/stdlib/sql.go b/stdlib/sql.go index da377ecee..1c46e278e 100644 --- a/stdlib/sql.go +++ b/stdlib/sql.go @@ -58,6 +58,7 @@ import ( "math" "math/rand" "reflect" + "sort" "strconv" "strings" "sync" @@ -84,7 +85,13 @@ func init() { configs: make(map[string]*pgx.ConnConfig), } fakeTxConns = make(map[*pgx.Conn]*sql.Tx) - sql.Register("pgx", pgxDriver) + + drivers := sql.Drivers() + // if pgx driver was already registered by different pgx major version then we skip registration under the default name. + if i := sort.SearchStrings(sql.Drivers(), "pgx"); len(drivers) >= i || drivers[i] != "pgx" { + sql.Register("pgx", pgxDriver) + } + sql.Register("pgx/v4", pgxDriver) databaseSQLResultFormats = pgx.QueryResultFormatsByOID{ pgtype.BoolOID: 1, diff --git a/stdlib/sql_test.go b/stdlib/sql_test.go index 099320c0a..e0aa50ca6 100644 --- a/stdlib/sql_test.go +++ b/stdlib/sql_test.go @@ -157,9 +157,22 @@ func closeStmt(t *testing.T, stmt *sql.Stmt) { } func TestSQLOpen(t *testing.T) { - db, err := sql.Open("pgx", os.Getenv("PGX_TEST_DATABASE")) - require.NoError(t, err) - closeDB(t, db) + tests := []struct { + driverName string + }{ + {driverName: "pgx"}, + {driverName: "pgx/v4"}, + } + + for _, tt := range tests { + tt := tt + + t.Run(tt.driverName, func(t *testing.T) { + db, err := sql.Open(tt.driverName, os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + closeDB(t, db) + }) + } } func TestNormalLifeCycle(t *testing.T) {