@@ -10,7 +10,7 @@ import (
10
10
"time"
11
11
)
12
12
13
- type ICacheService interface {
13
+ type CachePort interface {
14
14
Get (ctx context.Context , key string ) (string , error )
15
15
Remove (ctx context.Context , key string ) (bool , error )
16
16
Expire (ctx context.Context , key string , timeToLive time.Duration ) (bool , error )
@@ -28,14 +28,14 @@ type SessionAuthorizer struct {
28
28
DecodeSessionID func (value string ) (string , error )
29
29
EncodeSessionID func (sid string ) string
30
30
VerifyToken func (tokenString string , secret string ) (map [string ]interface {}, int64 , int64 , error )
31
- Cache ICacheService
31
+ Cache CachePort
32
32
sessionExpiredTime time.Duration
33
33
LogError func (ctx context.Context , msg string , opts ... map [string ]interface {})
34
34
}
35
35
36
36
func NewSessionAuthorizer (secretKey string , verifyToken func (tokenString string , secret string ) (map [string ]interface {}, int64 , int64 , error ),
37
37
refreshExpire func (w http.ResponseWriter , sessionId string ) error ,
38
- cache ICacheService , sessionExpiredTime time.Duration , logError func (ctx context.Context , msg string , opts ... map [string ]interface {}), singleSession bool ,
38
+ cache CachePort , sessionExpiredTime time.Duration , logError func (ctx context.Context , msg string , opts ... map [string ]interface {}), singleSession bool ,
39
39
encodeSessionID func (sid string ) string ,
40
40
decodeSessionID func (value string ) (string , error ),
41
41
opts ... string ) * SessionAuthorizer {
@@ -139,6 +139,9 @@ func (h *SessionAuthorizer) Authorize(next http.Handler, skipRefreshTTL bool) ht
139
139
return
140
140
}
141
141
ip := getForwardedRemoteIp (r )
142
+ if len (ip ) == 0 {
143
+ ip = getRemoteIp (r )
144
+ }
142
145
sid , ok := uData [h .SId ]
143
146
if ! ok || sid != sessionId ||
144
147
getValue (uData , "userAgent" ) != r .UserAgent () ||
@@ -180,6 +183,9 @@ func (h *SessionAuthorizer) Verify(next http.Handler, skipRefreshTTL bool, sessi
180
183
return
181
184
}
182
185
ip := getForwardedRemoteIp (r )
186
+ if len (ip ) == 0 {
187
+ ip = getRemoteIp (r )
188
+ }
183
189
ctx = context .WithValue (ctx , "ip" , ip )
184
190
for k , e := range payload {
185
191
if len (k ) > 0 {
@@ -232,7 +238,13 @@ func getForwardedRemoteIp(r *http.Request) string {
232
238
}
233
239
return ""
234
240
}
235
-
241
+ func getRemoteIp (r * http.Request ) string {
242
+ remoteIP , _ , err := net .SplitHostPort (r .RemoteAddr )
243
+ if err != nil {
244
+ remoteIP = r .RemoteAddr
245
+ }
246
+ return remoteIP
247
+ }
236
248
func getValue (data map [string ]interface {}, key string ) string {
237
249
if value , ok := data [key ]; ok {
238
250
return value .(string )
0 commit comments