Skip to content

Commit

Permalink
switch db driver from lib/pq to pgx (techschool#93)
Browse files Browse the repository at this point in the history
Co-authored-by: phamlequang <phamlequang@gmail.com>
  • Loading branch information
techschool and phamlequang authored Jun 13, 2023
1 parent 9974870 commit 137b9a0
Show file tree
Hide file tree
Showing 36 changed files with 224 additions and 197 deletions.
14 changes: 5 additions & 9 deletions api/account.go
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
package api

import (
"database/sql"
"errors"
"net/http"

"github.com/gin-gonic/gin"
"github.com/lib/pq"
db "github.com/techschool/simplebank/db/sqlc"
"github.com/techschool/simplebank/token"
)
Expand All @@ -31,12 +29,10 @@ func (server *Server) createAccount(ctx *gin.Context) {

account, err := server.store.CreateAccount(ctx, arg)
if err != nil {
if pqErr, ok := err.(*pq.Error); ok {
switch pqErr.Code.Name() {
case "foreign_key_violation", "unique_violation":
ctx.JSON(http.StatusForbidden, errorResponse(err))
return
}
errCode := db.ErrorCode(err)
if errCode == db.ForeignKeyViolation || errCode == db.UniqueViolation {
ctx.JSON(http.StatusForbidden, errorResponse(err))
return
}
ctx.JSON(http.StatusInternalServerError, errorResponse(err))
return
Expand All @@ -58,7 +54,7 @@ func (server *Server) getAccount(ctx *gin.Context) {

account, err := server.store.GetAccount(ctx, req.ID)
if err != nil {
if err == sql.ErrNoRows {
if errors.Is(err, db.ErrRecordNotFound) {
ctx.JSON(http.StatusNotFound, errorResponse(err))
return
}
Expand Down
8 changes: 4 additions & 4 deletions api/account_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import (
"database/sql"
"encoding/json"
"fmt"
"io/ioutil"
"io"
"net/http"
"net/http/httptest"
"testing"
Expand Down Expand Up @@ -89,7 +89,7 @@ func TestGetAccountAPI(t *testing.T) {
store.EXPECT().
GetAccount(gomock.Any(), gomock.Eq(account.ID)).
Times(1).
Return(db.Account{}, sql.ErrNoRows)
Return(db.Account{}, db.ErrRecordNotFound)
},
checkResponse: func(t *testing.T, recorder *httptest.ResponseRecorder) {
require.Equal(t, http.StatusNotFound, recorder.Code)
Expand Down Expand Up @@ -430,7 +430,7 @@ func randomAccount(owner string) db.Account {
}

func requireBodyMatchAccount(t *testing.T, body *bytes.Buffer, account db.Account) {
data, err := ioutil.ReadAll(body)
data, err := io.ReadAll(body)
require.NoError(t, err)

var gotAccount db.Account
Expand All @@ -440,7 +440,7 @@ func requireBodyMatchAccount(t *testing.T, body *bytes.Buffer, account db.Accoun
}

func requireBodyMatchAccounts(t *testing.T, body *bytes.Buffer, accounts []db.Account) {
data, err := ioutil.ReadAll(body)
data, err := io.ReadAll(body)
require.NoError(t, err)

var gotAccounts []db.Account
Expand Down
5 changes: 3 additions & 2 deletions api/token.go
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
package api

import (
"database/sql"
"errors"
"fmt"
"net/http"
"time"

"github.com/gin-gonic/gin"
db "github.com/techschool/simplebank/db/sqlc"
)

type renewAccessTokenRequest struct {
Expand All @@ -33,7 +34,7 @@ func (server *Server) renewAccessToken(ctx *gin.Context) {

session, err := server.store.GetSession(ctx, refreshPayload.ID)
if err != nil {
if err == sql.ErrNoRows {
if errors.Is(err, db.ErrRecordNotFound) {
ctx.JSON(http.StatusNotFound, errorResponse(err))
return
}
Expand Down
3 changes: 1 addition & 2 deletions api/transfer.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package api

import (
"database/sql"
"errors"
"fmt"
"net/http"
Expand Down Expand Up @@ -60,7 +59,7 @@ func (server *Server) createTransfer(ctx *gin.Context) {
func (server *Server) validAccount(ctx *gin.Context, accountID int64, currency string) (db.Account, bool) {
account, err := server.store.GetAccount(ctx, accountID)
if err != nil {
if err == sql.ErrNoRows {
if errors.Is(err, db.ErrRecordNotFound) {
ctx.JSON(http.StatusNotFound, errorResponse(err))
return account, false
}
Expand Down
4 changes: 2 additions & 2 deletions api/transfer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ func TestTransferAPI(t *testing.T) {
addAuthorization(t, request, tokenMaker, authorizationTypeBearer, user1.Username, time.Minute)
},
buildStubs: func(store *mockdb.MockStore) {
store.EXPECT().GetAccount(gomock.Any(), gomock.Eq(account1.ID)).Times(1).Return(db.Account{}, sql.ErrNoRows)
store.EXPECT().GetAccount(gomock.Any(), gomock.Eq(account1.ID)).Times(1).Return(db.Account{}, db.ErrRecordNotFound)
store.EXPECT().GetAccount(gomock.Any(), gomock.Eq(account2.ID)).Times(0)
store.EXPECT().TransferTx(gomock.Any(), gomock.Any()).Times(0)
},
Expand All @@ -137,7 +137,7 @@ func TestTransferAPI(t *testing.T) {
},
buildStubs: func(store *mockdb.MockStore) {
store.EXPECT().GetAccount(gomock.Any(), gomock.Eq(account1.ID)).Times(1).Return(account1, nil)
store.EXPECT().GetAccount(gomock.Any(), gomock.Eq(account2.ID)).Times(1).Return(db.Account{}, sql.ErrNoRows)
store.EXPECT().GetAccount(gomock.Any(), gomock.Eq(account2.ID)).Times(1).Return(db.Account{}, db.ErrRecordNotFound)
store.EXPECT().TransferTx(gomock.Any(), gomock.Any()).Times(0)
},
checkResponse: func(recorder *httptest.ResponseRecorder) {
Expand Down
14 changes: 5 additions & 9 deletions api/user.go
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
package api

import (
"database/sql"
"errors"
"net/http"
"time"

"github.com/gin-gonic/gin"
"github.com/google/uuid"
"github.com/lib/pq"
db "github.com/techschool/simplebank/db/sqlc"
"github.com/techschool/simplebank/util"
)
Expand Down Expand Up @@ -59,12 +58,9 @@ func (server *Server) createUser(ctx *gin.Context) {

user, err := server.store.CreateUser(ctx, arg)
if err != nil {
if pqErr, ok := err.(*pq.Error); ok {
switch pqErr.Code.Name() {
case "unique_violation":
ctx.JSON(http.StatusForbidden, errorResponse(err))
return
}
if db.ErrorCode(err) == db.UniqueViolation {
ctx.JSON(http.StatusForbidden, errorResponse(err))
return
}
ctx.JSON(http.StatusInternalServerError, errorResponse(err))
return
Expand Down Expand Up @@ -97,7 +93,7 @@ func (server *Server) loginUser(ctx *gin.Context) {

user, err := server.store.GetUser(ctx, req.Username)
if err != nil {
if err == sql.ErrNoRows {
if errors.Is(err, db.ErrRecordNotFound) {
ctx.JSON(http.StatusNotFound, errorResponse(err))
return
}
Expand Down
13 changes: 6 additions & 7 deletions api/user_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,14 @@ import (
"database/sql"
"encoding/json"
"fmt"
"io/ioutil"
"io"
"net/http"
"net/http/httptest"
"reflect"
"testing"

"github.com/gin-gonic/gin"
"github.com/golang/mock/gomock"
"github.com/lib/pq"
"github.com/stretchr/testify/require"
mockdb "github.com/techschool/simplebank/db/mock"
db "github.com/techschool/simplebank/db/sqlc"
Expand Down Expand Up @@ -111,7 +110,7 @@ func TestCreateUserAPI(t *testing.T) {
store.EXPECT().
CreateUser(gomock.Any(), gomock.Any()).
Times(1).
Return(db.User{}, &pq.Error{Code: "23505"})
Return(db.User{}, db.ErrUniqueViolation)
},
checkResponse: func(recorder *httptest.ResponseRecorder) {
require.Equal(t, http.StatusForbidden, recorder.Code)
Expand Down Expand Up @@ -235,7 +234,7 @@ func TestLoginUserAPI(t *testing.T) {
store.EXPECT().
GetUser(gomock.Any(), gomock.Any()).
Times(1).
Return(db.User{}, sql.ErrNoRows)
Return(db.User{}, db.ErrRecordNotFound)
},
checkResponse: func(recorder *httptest.ResponseRecorder) {
require.Equal(t, http.StatusNotFound, recorder.Code)
Expand Down Expand Up @@ -276,8 +275,8 @@ func TestLoginUserAPI(t *testing.T) {
{
name: "InvalidUsername",
body: gin.H{
"username": "invalid-user#1",
"password": password,
"username": "invalid-user#1",
"password": password,
},
buildStubs: func(store *mockdb.MockStore) {
store.EXPECT().
Expand Down Expand Up @@ -332,7 +331,7 @@ func randomUser(t *testing.T) (user db.User, password string) {
}

func requireBodyMatchUser(t *testing.T, body *bytes.Buffer, user db.User) {
data, err := ioutil.ReadAll(body)
data, err := io.ReadAll(body)
require.NoError(t, err)

var gotUser db.User
Expand Down
1 change: 0 additions & 1 deletion app.env
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
ENVIRONMENT=development
DB_DRIVER=postgres
DB_SOURCE=postgresql://root:secret@localhost:5432/simple_bank?sslmode=disable
MIGRATION_URL=file://db/migration
HTTP_SERVER_ADDRESS=0.0.0.0:8080
Expand Down
17 changes: 7 additions & 10 deletions db/sqlc/account.sql.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

15 changes: 7 additions & 8 deletions db/sqlc/account_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package db

import (
"context"
"database/sql"
"testing"
"time"

Expand All @@ -19,7 +18,7 @@ func createRandomAccount(t *testing.T) Account {
Currency: util.RandomCurrency(),
}

account, err := testQueries.CreateAccount(context.Background(), arg)
account, err := testStore.CreateAccount(context.Background(), arg)
require.NoError(t, err)
require.NotEmpty(t, account)

Expand All @@ -39,7 +38,7 @@ func TestCreateAccount(t *testing.T) {

func TestGetAccount(t *testing.T) {
account1 := createRandomAccount(t)
account2, err := testQueries.GetAccount(context.Background(), account1.ID)
account2, err := testStore.GetAccount(context.Background(), account1.ID)
require.NoError(t, err)
require.NotEmpty(t, account2)

Expand All @@ -58,7 +57,7 @@ func TestUpdateAccount(t *testing.T) {
Balance: util.RandomMoney(),
}

account2, err := testQueries.UpdateAccount(context.Background(), arg)
account2, err := testStore.UpdateAccount(context.Background(), arg)
require.NoError(t, err)
require.NotEmpty(t, account2)

Expand All @@ -71,12 +70,12 @@ func TestUpdateAccount(t *testing.T) {

func TestDeleteAccount(t *testing.T) {
account1 := createRandomAccount(t)
err := testQueries.DeleteAccount(context.Background(), account1.ID)
err := testStore.DeleteAccount(context.Background(), account1.ID)
require.NoError(t, err)

account2, err := testQueries.GetAccount(context.Background(), account1.ID)
account2, err := testStore.GetAccount(context.Background(), account1.ID)
require.Error(t, err)
require.EqualError(t, err, sql.ErrNoRows.Error())
require.EqualError(t, err, ErrRecordNotFound.Error())
require.Empty(t, account2)
}

Expand All @@ -92,7 +91,7 @@ func TestListAccounts(t *testing.T) {
Offset: 0,
}

accounts, err := testQueries.ListAccounts(context.Background(), arg)
accounts, err := testStore.ListAccounts(context.Background(), arg)
require.NoError(t, err)
require.NotEmpty(t, accounts)

Expand Down
13 changes: 7 additions & 6 deletions db/sqlc/db.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading

0 comments on commit 137b9a0

Please sign in to comment.