Skip to content

Commit

Permalink
validate grpc request parameters (techschool#35)
Browse files Browse the repository at this point in the history
Co-authored-by: phamlequang <phamlequang@gmail.com>
  • Loading branch information
techschool and phamlequang authored Apr 10, 2022
1 parent 9212f74 commit 5ead2d2
Show file tree
Hide file tree
Showing 4 changed files with 126 additions and 0 deletions.
26 changes: 26 additions & 0 deletions gapi/error.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
package gapi

import (
"google.golang.org/genproto/googleapis/rpc/errdetails"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
)

func fieldViolation(field string, err error) *errdetails.BadRequest_FieldViolation {
return &errdetails.BadRequest_FieldViolation{
Field: field,
Description: err.Error(),
}
}

func invalidArgumentError(violations []*errdetails.BadRequest_FieldViolation) error {
badRequest := &errdetails.BadRequest{FieldViolations: violations}
statusInvalid := status.New(codes.InvalidArgument, "invalid parameters")

statusDetails, err := statusInvalid.WithDetails(badRequest)
if err != nil {
return statusInvalid.Err()
}

return statusDetails.Err()
}
27 changes: 27 additions & 0 deletions gapi/rpc_create_user.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,18 @@ import (
db "github.com/techschool/simplebank/db/sqlc"
"github.com/techschool/simplebank/pb"
"github.com/techschool/simplebank/util"
"github.com/techschool/simplebank/val"
"google.golang.org/genproto/googleapis/rpc/errdetails"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
)

func (server *Server) CreateUser(ctx context.Context, req *pb.CreateUserRequest) (*pb.CreateUserResponse, error) {
violations := validateCreateUserRequest(req)
if violations != nil {
return nil, invalidArgumentError(violations)
}

hashedPassword, err := util.HashPassword(req.GetPassword())
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to hash password: %s", err)
Expand Down Expand Up @@ -40,3 +47,23 @@ func (server *Server) CreateUser(ctx context.Context, req *pb.CreateUserRequest)
}
return rsp, nil
}

func validateCreateUserRequest(req *pb.CreateUserRequest) (violations []*errdetails.BadRequest_FieldViolation) {
if err := val.ValidateUsername(req.GetUsername()); err != nil {
violations = append(violations, fieldViolation("username", err))
}

if err := val.ValidatePassword(req.GetPassword()); err != nil {
violations = append(violations, fieldViolation("password", err))
}

if err := val.ValidateFullName(req.GetFullName()); err != nil {
violations = append(violations, fieldViolation("full_name", err))
}

if err := val.ValidateEmail(req.GetEmail()); err != nil {
violations = append(violations, fieldViolation("email", err))
}

return violations
}
19 changes: 19 additions & 0 deletions gapi/rpc_login_user.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,19 @@ import (
db "github.com/techschool/simplebank/db/sqlc"
"github.com/techschool/simplebank/pb"
"github.com/techschool/simplebank/util"
"github.com/techschool/simplebank/val"
"google.golang.org/genproto/googleapis/rpc/errdetails"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/types/known/timestamppb"
)

func (server *Server) LoginUser(ctx context.Context, req *pb.LoginUserRequest) (*pb.LoginUserResponse, error) {
violations := validateLoginUserRequest(req)
if violations != nil {
return nil, invalidArgumentError(violations)
}

user, err := server.store.GetUser(ctx, req.GetUsername())
if err != nil {
if err == sql.ErrNoRows {
Expand Down Expand Up @@ -66,3 +73,15 @@ func (server *Server) LoginUser(ctx context.Context, req *pb.LoginUserRequest) (
}
return rsp, nil
}

func validateLoginUserRequest(req *pb.LoginUserRequest) (violations []*errdetails.BadRequest_FieldViolation) {
if err := val.ValidateUsername(req.GetUsername()); err != nil {
violations = append(violations, fieldViolation("username", err))
}

if err := val.ValidatePassword(req.GetPassword()); err != nil {
violations = append(violations, fieldViolation("password", err))
}

return violations
}
54 changes: 54 additions & 0 deletions val/validator.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
package val

import (
"fmt"
"net/mail"
"regexp"
)

var (
isValidUsername = regexp.MustCompile(`^[a-z0-9_]+$`).MatchString
isValidFullName = regexp.MustCompile(`^[a-zA-Z\\s]+$`).MatchString
)

func ValidateString(value string, minLength int, maxLength int) error {
n := len(value)
if n < minLength || n > maxLength {
return fmt.Errorf("must contain from %d-%d characters", minLength, maxLength)
}
return nil
}

func ValidateUsername(value string) error {
if err := ValidateString(value, 3, 100); err != nil {
return err
}
if !isValidUsername(value) {
return fmt.Errorf("must contain only lowercase letters, digits, or underscore")
}
return nil
}

func ValidateFullName(value string) error {
if err := ValidateString(value, 3, 100); err != nil {
return err
}
if !isValidFullName(value) {
return fmt.Errorf("must contain only letters or spaces")
}
return nil
}

func ValidatePassword(value string) error {
return ValidateString(value, 6, 100)
}

func ValidateEmail(value string) error {
if err := ValidateString(value, 3, 200); err != nil {
return err
}
if _, err := mail.ParseAddress(value); err != nil {
return fmt.Errorf("is not a valid email address")
}
return nil
}

0 comments on commit 5ead2d2

Please sign in to comment.