|
| 1 | +// Copyright 2020 The Gitea Authors. All rights reserved. |
| 2 | +// Use of this source code is governed by a MIT-style |
| 3 | +// license that can be found in the LICENSE file. |
| 4 | + |
| 5 | +package session |
| 6 | + |
| 7 | +import ( |
| 8 | + "log" |
| 9 | + "sync" |
| 10 | + |
| 11 | + "code.gitea.io/gitea/models" |
| 12 | + "code.gitea.io/gitea/modules/timeutil" |
| 13 | + |
| 14 | + "gitea.com/go-chi/session" |
| 15 | +) |
| 16 | + |
| 17 | +// DBStore represents a session store implementation based on the DB. |
| 18 | +type DBStore struct { |
| 19 | + sid string |
| 20 | + lock sync.RWMutex |
| 21 | + data map[interface{}]interface{} |
| 22 | +} |
| 23 | + |
| 24 | +// NewDBStore creates and returns a DB session store. |
| 25 | +func NewDBStore(sid string, kv map[interface{}]interface{}) *DBStore { |
| 26 | + return &DBStore{ |
| 27 | + sid: sid, |
| 28 | + data: kv, |
| 29 | + } |
| 30 | +} |
| 31 | + |
| 32 | +// Set sets value to given key in session. |
| 33 | +func (s *DBStore) Set(key, val interface{}) error { |
| 34 | + s.lock.Lock() |
| 35 | + defer s.lock.Unlock() |
| 36 | + |
| 37 | + s.data[key] = val |
| 38 | + return nil |
| 39 | +} |
| 40 | + |
| 41 | +// Get gets value by given key in session. |
| 42 | +func (s *DBStore) Get(key interface{}) interface{} { |
| 43 | + s.lock.RLock() |
| 44 | + defer s.lock.RUnlock() |
| 45 | + |
| 46 | + return s.data[key] |
| 47 | +} |
| 48 | + |
| 49 | +// Delete delete a key from session. |
| 50 | +func (s *DBStore) Delete(key interface{}) error { |
| 51 | + s.lock.Lock() |
| 52 | + defer s.lock.Unlock() |
| 53 | + |
| 54 | + delete(s.data, key) |
| 55 | + return nil |
| 56 | +} |
| 57 | + |
| 58 | +// ID returns current session ID. |
| 59 | +func (s *DBStore) ID() string { |
| 60 | + return s.sid |
| 61 | +} |
| 62 | + |
| 63 | +// Release releases resource and save data to provider. |
| 64 | +func (s *DBStore) Release() error { |
| 65 | + // Skip encoding if the data is empty |
| 66 | + if len(s.data) == 0 { |
| 67 | + return nil |
| 68 | + } |
| 69 | + |
| 70 | + data, err := session.EncodeGob(s.data) |
| 71 | + if err != nil { |
| 72 | + return err |
| 73 | + } |
| 74 | + |
| 75 | + return models.UpdateSession(s.sid, data) |
| 76 | +} |
| 77 | + |
| 78 | +// Flush deletes all session data. |
| 79 | +func (s *DBStore) Flush() error { |
| 80 | + s.lock.Lock() |
| 81 | + defer s.lock.Unlock() |
| 82 | + |
| 83 | + s.data = make(map[interface{}]interface{}) |
| 84 | + return nil |
| 85 | +} |
| 86 | + |
| 87 | +// DBProvider represents a DB session provider implementation. |
| 88 | +type DBProvider struct { |
| 89 | + maxLifetime int64 |
| 90 | +} |
| 91 | + |
| 92 | +// Init initializes DB session provider. |
| 93 | +// connStr: username:password@protocol(address)/dbname?param=value |
| 94 | +func (p *DBProvider) Init(maxLifetime int64, connStr string) error { |
| 95 | + p.maxLifetime = maxLifetime |
| 96 | + return nil |
| 97 | +} |
| 98 | + |
| 99 | +// Read returns raw session store by session ID. |
| 100 | +func (p *DBProvider) Read(sid string) (session.RawStore, error) { |
| 101 | + s, err := models.ReadSession(sid) |
| 102 | + if err != nil { |
| 103 | + return nil, err |
| 104 | + } |
| 105 | + |
| 106 | + var kv map[interface{}]interface{} |
| 107 | + if len(s.Data) == 0 || s.Expiry.Add(p.maxLifetime) <= timeutil.TimeStampNow() { |
| 108 | + kv = make(map[interface{}]interface{}) |
| 109 | + } else { |
| 110 | + kv, err = session.DecodeGob(s.Data) |
| 111 | + if err != nil { |
| 112 | + return nil, err |
| 113 | + } |
| 114 | + } |
| 115 | + |
| 116 | + return NewDBStore(sid, kv), nil |
| 117 | +} |
| 118 | + |
| 119 | +// Exist returns true if session with given ID exists. |
| 120 | +func (p *DBProvider) Exist(sid string) bool { |
| 121 | + has, err := models.ExistSession(sid) |
| 122 | + if err != nil { |
| 123 | + panic("session/DB: error checking existence: " + err.Error()) |
| 124 | + } |
| 125 | + return has |
| 126 | +} |
| 127 | + |
| 128 | +// Destroy deletes a session by session ID. |
| 129 | +func (p *DBProvider) Destroy(sid string) error { |
| 130 | + return models.DestroySession(sid) |
| 131 | +} |
| 132 | + |
| 133 | +// Regenerate regenerates a session store from old session ID to new one. |
| 134 | +func (p *DBProvider) Regenerate(oldsid, sid string) (_ session.RawStore, err error) { |
| 135 | + s, err := models.RegenerateSession(oldsid, sid) |
| 136 | + if err != nil { |
| 137 | + return nil, err |
| 138 | + |
| 139 | + } |
| 140 | + |
| 141 | + var kv map[interface{}]interface{} |
| 142 | + if len(s.Data) == 0 || s.Expiry.Add(p.maxLifetime) <= timeutil.TimeStampNow() { |
| 143 | + kv = make(map[interface{}]interface{}) |
| 144 | + } else { |
| 145 | + kv, err = session.DecodeGob(s.Data) |
| 146 | + if err != nil { |
| 147 | + return nil, err |
| 148 | + } |
| 149 | + } |
| 150 | + |
| 151 | + return NewDBStore(sid, kv), nil |
| 152 | +} |
| 153 | + |
| 154 | +// Count counts and returns number of sessions. |
| 155 | +func (p *DBProvider) Count() int { |
| 156 | + total, err := models.CountSessions() |
| 157 | + if err != nil { |
| 158 | + panic("session/DB: error counting records: " + err.Error()) |
| 159 | + } |
| 160 | + return int(total) |
| 161 | +} |
| 162 | + |
| 163 | +// GC calls GC to clean expired sessions. |
| 164 | +func (p *DBProvider) GC() { |
| 165 | + if err := models.CleanupSessions(p.maxLifetime); err != nil { |
| 166 | + log.Printf("session/DB: error garbage collecting: %v", err) |
| 167 | + } |
| 168 | +} |
| 169 | + |
| 170 | +func init() { |
| 171 | + session.Register("db", &DBProvider{}) |
| 172 | +} |
0 commit comments