Skip to content

Commit

Permalink
add simple auth
Browse files Browse the repository at this point in the history
  • Loading branch information
maddalax committed Dec 11, 2024
1 parent 036af51 commit 9e91a3d
Show file tree
Hide file tree
Showing 9 changed files with 489 additions and 8 deletions.
262 changes: 262 additions & 0 deletions app/auth_user.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,262 @@
package app

import (
"dockman/app/util/json2"
"errors"
"github.com/google/uuid"
"github.com/maddalax/htmgo/framework/h"
"github.com/maddalax/htmgo/framework/service"
"github.com/nats-io/nats.go"
"golang.org/x/crypto/bcrypt"
"net/http"
"strings"
"time"
)

type User struct {
Id string `json:"id"`
Email string `json:"email"`
Password string `json:"password"`
}

type Session struct {
Token string `json:"token"`
UserId string `json:"user_id"`
ExpiresAt time.Time `json:"expires_at"`
}

func (session *Session) Write(ctx *h.RequestContext) {
cookie := http.Cookie{
Name: "session_id",
Value: session.Token,
HttpOnly: true,
SameSite: http.SameSiteStrictMode,
Expires: session.ExpiresAt,
Path: "/",
}
ctx.SetCookie(&cookie)
}

func UserIsInitialSetup(locator *service.Locator) bool {
anyUsers, err := UserGetByPredicate(locator, func(u *User) bool {
return true
})

if err != nil {
return true
}

if anyUsers != nil {
return true
}

return false
}

func UserCreate(locator *service.Locator, user *User) (*User, error) {

if UserIsInitialSetup(locator) {
return nil, errors.New("registration is disabled")
}

anyUsers, err := UserGetByPredicate(locator, func(u *User) bool {
return true
})

if err != nil {
return nil, err
}

if anyUsers != nil {
return nil, errors.New("registration is disabled")
}

client := service.Get[KvClient](locator)

user.Email = strings.TrimSpace(user.Email)
user.Email = strings.ToLower(user.Email)

u, err := UserGetByPredicate(locator, func(u *User) bool {
return strings.ToLower(u.Email) == strings.ToLower(user.Email)
})

if u != nil {
return nil, errors.New("user already exists by that email")
}

hashedPass, err := bcrypt.GenerateFromPassword([]byte(user.Password), bcrypt.DefaultCost)

if err != nil {
return nil, err
}

user.Password = string(hashedPass)
user.Id = uuid.NewString()

bucket, err := client.GetOrCreateBucket(&nats.KeyValueConfig{
Bucket: "users",
})

if err != nil {
return nil, err
}

serialized, err := json2.Serialize(user)

if err != nil {
return nil, err
}

_, err = bucket.Create(user.Id, serialized)

if err != nil {
return nil, err
}

return user, nil
}

func UserGetByPredicate(locator *service.Locator, predicate func(user *User) bool) (*User, error) {
client := service.Get[KvClient](locator)
users, err := client.GetOrCreateBucket(&nats.KeyValueConfig{
Bucket: "users",
})
if err != nil {
return nil, err
}

keys, err := users.ListKeys()

if err != nil {
return nil, err
}

for key := range keys.Keys() {
raw, err := users.Get(key)
if err != nil {
continue
}
user, err := json2.Deserialize[User](raw.Value())
if err != nil {
continue
}
if predicate(user) {
return user, nil
}
}

return nil, nil
}

const sessionDuration = 24 * time.Hour

// UserLogin verifies user credentials and generates a session token
func UserLogin(locator *service.Locator, email, password string) (*Session, error) {
client := service.Get[KvClient](locator)

// Retrieve the user by email
user, err := UserGetByPredicate(locator, func(u *User) bool {
return u.Email == strings.ToLower(strings.TrimSpace(email))
})
if err != nil {
return nil, err
}
if user == nil {
return nil, errors.New("invalid email or password")
}

// Verify the password
err = bcrypt.CompareHashAndPassword([]byte(user.Password), []byte(password))
if err != nil {
return nil, errors.New("invalid email or password")
}

// Create a session token
token := uuid.NewString()
session := &Session{
Token: token,
UserId: user.Id,
ExpiresAt: time.Now().Add(sessionDuration),
}

// Store the session in the "sessions" bucket
bucket, err := client.GetOrCreateBucket(&nats.KeyValueConfig{
Bucket: "sessions",
})
if err != nil {
return nil, err
}

serialized, err := json2.Serialize(session)
if err != nil {
return nil, err
}

_, err = bucket.Create(token, serialized)
if err != nil {
return nil, err
}

return session, nil
}

// ValidateSession checks the validity of a session token
func ValidateSession(ctx *h.RequestContext) (*User, error) {
sessionTokenCookie, err := ctx.Request.Cookie("session_id")

if err != nil {
return nil, errors.New("authorization token not provided")
}

sessionToken := sessionTokenCookie.Value

if sessionToken == "" {
return nil, errors.New("authorization token not provided")
}

client := service.Get[KvClient](ctx.ServiceLocator())

// Retrieve the session from the "sessions" bucket
bucket, err := client.GetOrCreateBucket(&nats.KeyValueConfig{
Bucket: "sessions",
})

if err != nil {
return nil, err
}

raw, err := bucket.Get(sessionToken)

if errors.Is(err, nats.ErrKeyNotFound) {
return nil, errors.New("invalid or expired session token")
}

if err != nil {
return nil, err
}

session, err := json2.Deserialize[Session](raw.Value())

if err != nil {
return nil, err
}

// Check if the session is expired
if time.Now().After(session.ExpiresAt) {
_ = bucket.Delete(sessionToken) // Clean up expired session
return nil, errors.New("session token expired")
}

// Retrieve the user associated with the session
user, err := UserGetByPredicate(ctx.ServiceLocator(), func(u *User) bool {
return u.Id == session.UserId
})
if err != nil {
return nil, err
}
if user == nil {
return nil, errors.New("user not found for session")
}

return user, nil
}
2 changes: 1 addition & 1 deletion app/resource_monitor.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ func (monitor *ResourceMonitor) Start() {
runner.Add(source, "ResourceServerCleanup", "Detaches servers that no longer exist from resources", time.Minute, monitor.ResourceServerCleanup)
runner.Add(source, "ServerConnectionMonitor", "Monitors if connected servers are still connected by checking for a heartbeat", time.Second*5, monitor.ServerConnectionMonitor)
runner.Add(source, "ResourceCheckForNewCommits", "Checks if a resource has a new commit and starts a new deployment if enabled", time.Second*30, monitor.ResourceCheckForNewCommits)
runner.Add(source, "ServerDuplicateCleanup", "Checks if there are any servers with the same remote ip and dedupes them", time.Second*30, monitor.CleanupDuplicateServers)
runner.Add(source, "ServerDuplicateCleanup", "Checks if there are any servers with the same remote ip and deduplicates them", time.Second*30, monitor.CleanupDuplicateServers)

}

Expand Down
13 changes: 13 additions & 0 deletions app/ui/alert.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,19 @@ func SuccessAlert(title *h.Element, message *h.Element) *h.Element {
}

func ErrorAlert(title *h.Element, message *h.Element) *h.Element {

if message == nil {
return h.Div(
h.Id("ui-alert"),
h.Role("alert"),
h.Class("rounded border-s-4 border-red-500 bg-red-50 p-4 w-full"),
h.P(
h.Class("text-sm text-red-700"),
title,
),
)
}

return h.Div(
h.Id("ui-alert"),
h.Role("alert"),
Expand Down
19 changes: 19 additions & 0 deletions app/ui/error.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
package ui

import "github.com/maddalax/htmgo/framework/h"

func FormError(error string) *h.Element {
return h.Div(
h.Id("form-error"),
h.If(
error != "",
ErrorAlert(h.Pf(error), nil),
),
)
}

func SwapFormError(ctx *h.RequestContext, error string) *h.Partial {
return h.SwapPartial(ctx,
FormError(error),
)
}
14 changes: 8 additions & 6 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package main
import (
"dockman/__htmgo"
"dockman/app"
"dockman/middleware"
"fmt"
"github.com/maddalax/htmgo/extensions/websocket"
ws2 "github.com/maddalax/htmgo/extensions/websocket/opts"
Expand Down Expand Up @@ -42,13 +43,14 @@ func main() {
h.Start(h.AppOpts{
ServiceLocator: locator,
LiveReload: true,
Register: func(app *h.App) {

app.Use(func(ctx *h.RequestContext) {
Register: func(a *h.App) {
a.Use(func(ctx *h.RequestContext) {
session.CreateSession(ctx)
})

websocket.EnableExtension(app, ws2.ExtensionOpts{
middleware.UseLoginRequiredMiddleware(a.Router)

websocket.EnableExtension(a, ws2.ExtensionOpts{
WsPath: "/ws",
RoomName: func(ctx *h.RequestContext) string {
return "all"
Expand All @@ -68,10 +70,10 @@ func main() {

cfg := config.Get()
// change this in htmgo.yml (public_asset_path)
app.Router.Handle(fmt.Sprintf("%s/*", cfg.PublicAssetPath),
a.Router.Handle(fmt.Sprintf("%s/*", cfg.PublicAssetPath),
http.StripPrefix(cfg.PublicAssetPath, http.FileServerFS(sub)))

__htmgo.Register(app.Router)
__htmgo.Register(a.Router)
},
})
}
37 changes: 37 additions & 0 deletions middleware/auth_middleware.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
package middleware

import (
"dockman/app"
"dockman/pages"
"github.com/go-chi/chi/v5"
"github.com/maddalax/htmgo/framework/h"
"net/http"
)

func UseLoginRequiredMiddleware(router *chi.Mux) {
router.Use(func(handler http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
allowedPaths := []string{
"/login",
"/logout",
h.GetPartialPath(pages.RegisterUser),
h.GetPartialPath(pages.LoginUser)}

for _, path := range allowedPaths {
if r.URL.Path == path {
handler.ServeHTTP(w, r)
return
}
}

ctx := h.GetRequestContext(r)
user, err := app.ValidateSession(ctx)
if err != nil {
http.Redirect(w, r, "/login", http.StatusFound)
return
}
ctx.Set("user", user)
handler.ServeHTTP(w, r)
})
})
}
Loading

0 comments on commit 9e91a3d

Please sign in to comment.