Skip to content

Commit

Permalink
add session and refresh_token to the api (techschool#25)
Browse files Browse the repository at this point in the history
Co-authored-by: phamlequang <phamlequang@gmail.com>
  • Loading branch information
techschool and phamlequang authored Feb 5, 2022
1 parent 0a2b7b7 commit 5663a67
Show file tree
Hide file tree
Showing 21 changed files with 915 additions and 143 deletions.
3 changes: 2 additions & 1 deletion api/middleware_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,9 @@ func addAuthorization(
username string,
duration time.Duration,
) {
token, err := tokenMaker.CreateToken(username, duration)
token, payload, err := tokenMaker.CreateToken(username, duration)
require.NoError(t, err)
require.NotEmpty(t, payload)

authorizationHeader := fmt.Sprintf("%s %s", authorizationType, token)
request.Header.Set(authorizationHeaderKey, authorizationHeader)
Expand Down
1 change: 1 addition & 0 deletions api/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ func (server *Server) setupRouter() {

router.POST("/users", server.createUser)
router.POST("/users/login", server.loginUser)
router.POST("/tokens/renew_access", server.renewAccessToken)

authRoutes := router.Group("/").Use(authMiddleware(server.tokenMaker))
authRoutes.POST("/accounts", server.createAccount)
Expand Down
82 changes: 82 additions & 0 deletions api/token.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
package api

import (
"database/sql"
"fmt"
"net/http"
"time"

"github.com/gin-gonic/gin"
)

type renewAccessTokenRequest struct {
RefreshToken string `json:"refresh_token" binding:"required"`
}

type renewAccessTokenResponse struct {
AccessToken string `json:"access_token"`
AccessTokenExpiresAt time.Time `json:"access_token_expires_at"`
}

func (server *Server) renewAccessToken(ctx *gin.Context) {
var req renewAccessTokenRequest
if err := ctx.ShouldBindJSON(&req); err != nil {
ctx.JSON(http.StatusBadRequest, errorResponse(err))
return
}

refreshPayload, err := server.tokenMaker.VerifyToken(req.RefreshToken)
if err != nil {
ctx.JSON(http.StatusUnauthorized, errorResponse(err))
return
}

session, err := server.store.GetSession(ctx, refreshPayload.ID)
if err != nil {
if err == sql.ErrNoRows {
ctx.JSON(http.StatusNotFound, errorResponse(err))
return
}
ctx.JSON(http.StatusInternalServerError, errorResponse(err))
return
}

if session.IsBlocked {
err := fmt.Errorf("blocked session")
ctx.JSON(http.StatusUnauthorized, errorResponse(err))
return
}

if session.Username != refreshPayload.Username {
err := fmt.Errorf("incorrect session user")
ctx.JSON(http.StatusUnauthorized, errorResponse(err))
return
}

if session.RefreshToken != req.RefreshToken {
err := fmt.Errorf("mismatched session token")
ctx.JSON(http.StatusUnauthorized, errorResponse(err))
return
}

if time.Now().After(session.ExpiresAt) {
err := fmt.Errorf("expired session")
ctx.JSON(http.StatusUnauthorized, errorResponse(err))
return
}

accessToken, accessPayload, err := server.tokenMaker.CreateToken(
refreshPayload.Username,
server.config.AccessTokenDuration,
)
if err != nil {
ctx.JSON(http.StatusInternalServerError, errorResponse(err))
return
}

rsp := renewAccessTokenResponse{
AccessToken: accessToken,
AccessTokenExpiresAt: accessPayload.ExpiredAt,
}
ctx.JSON(http.StatusOK, rsp)
}
42 changes: 37 additions & 5 deletions api/user.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"time"

"github.com/gin-gonic/gin"
"github.com/google/uuid"
"github.com/lib/pq"
db "github.com/techschool/simplebank/db/sqlc"
"github.com/techschool/simplebank/util"
Expand Down Expand Up @@ -79,8 +80,12 @@ type loginUserRequest struct {
}

type loginUserResponse struct {
AccessToken string `json:"access_token"`
User userResponse `json:"user"`
SessionID uuid.UUID `json:"session_id"`
AccessToken string `json:"access_token"`
AccessTokenExpiresAt time.Time `json:"access_token_expires_at"`
RefreshToken string `json:"refresh_token"`
RefreshTokenExpiresAt time.Time `json:"refresh_token_expires_at"`
User userResponse `json:"user"`
}

func (server *Server) loginUser(ctx *gin.Context) {
Expand All @@ -106,7 +111,7 @@ func (server *Server) loginUser(ctx *gin.Context) {
return
}

accessToken, err := server.tokenMaker.CreateToken(
accessToken, accessPayload, err := server.tokenMaker.CreateToken(
user.Username,
server.config.AccessTokenDuration,
)
Expand All @@ -115,9 +120,36 @@ func (server *Server) loginUser(ctx *gin.Context) {
return
}

refreshToken, refreshPayload, err := server.tokenMaker.CreateToken(
user.Username,
server.config.RefreshTokenDuration,
)
if err != nil {
ctx.JSON(http.StatusInternalServerError, errorResponse(err))
return
}

session, err := server.store.CreateSession(ctx, db.CreateSessionParams{
ID: refreshPayload.ID,
Username: user.Username,
RefreshToken: refreshToken,
UserAgent: ctx.Request.UserAgent(),
ClientIp: ctx.ClientIP(),
IsBlocked: false,
ExpiresAt: refreshPayload.ExpiredAt,
})
if err != nil {
ctx.JSON(http.StatusInternalServerError, errorResponse(err))
return
}

rsp := loginUserResponse{
AccessToken: accessToken,
User: newUserResponse(user),
SessionID: session.ID,
AccessToken: accessToken,
AccessTokenExpiresAt: accessPayload.ExpiredAt,
RefreshToken: refreshToken,
RefreshTokenExpiresAt: refreshPayload.ExpiredAt,
User: newUserResponse(user),
}
ctx.JSON(http.StatusOK, rsp)
}
3 changes: 3 additions & 0 deletions api/user_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,9 @@ func TestLoginUserAPI(t *testing.T) {
GetUser(gomock.Any(), gomock.Eq(user.Username)).
Times(1).
Return(user, nil)
store.EXPECT().
CreateSession(gomock.Any(), gomock.Any()).
Times(1)
},
checkResponse: func(recorder *httptest.ResponseRecorder) {
require.Equal(t, http.StatusOK, recorder.Code)
Expand Down
1 change: 1 addition & 0 deletions app.env
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,4 @@ DB_SOURCE=postgresql://root:secret@localhost:5432/simple_bank?sslmode=disable
SERVER_ADDRESS=0.0.0.0:8080
TOKEN_SYMMETRIC_KEY=12345678901234567890123456789012
ACCESS_TOKEN_DURATION=15m
REFRESH_TOKEN_DURATION=24h
1 change: 1 addition & 0 deletions db/migration/000003_add_sessions.down.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
DROP TABLE IF EXISTS "sessions";
12 changes: 12 additions & 0 deletions db/migration/000003_add_sessions.up.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
CREATE TABLE "sessions" (
"id" uuid PRIMARY KEY,
"username" varchar NOT NULL,
"refresh_token" varchar NOT NULL,
"user_agent" varchar NOT NULL,
"client_ip" varchar NOT NULL,
"is_blocked" boolean NOT NULL DEFAULT false,
"expires_at" timestamptz NOT NULL,
"created_at" timestamptz NOT NULL DEFAULT (now())
);

ALTER TABLE "sessions" ADD FOREIGN KEY ("username") REFERENCES "users" ("username");
31 changes: 31 additions & 0 deletions db/mock/store.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

16 changes: 16 additions & 0 deletions db/query/session.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
-- name: CreateSession :one
INSERT INTO sessions (
id,
username,
refresh_token,
user_agent,
client_ip,
is_blocked,
expires_at
) VALUES (
$1, $2, $3, $4, $5, $6, $7
) RETURNING *;

-- name: GetSession :one
SELECT * FROM sessions
WHERE id = $1 LIMIT 1;
13 changes: 13 additions & 0 deletions db/sqlc/models.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 4 additions & 0 deletions db/sqlc/querier.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading

0 comments on commit 5663a67

Please sign in to comment.