Skip to content

Commit 3d15ef5

Browse files
committed
Support Bearer token
Signed-off-by: kpango <kpango@vdaas.org>
1 parent 37b7829 commit 3d15ef5

File tree

4 files changed

+137
-6
lines changed

4 files changed

+137
-6
lines changed

handler/grpc.go

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,8 @@ import (
4141
)
4242

4343
const (
44-
gRPC = "grpc"
44+
gRPC = "grpc"
45+
defaultAuthHeader = "Authorization"
4546
)
4647

4748
type GRPCHandler struct {
@@ -63,6 +64,14 @@ func NewGRPC(opts ...GRPCOption) (grpc.StreamHandler, io.Closer) {
6364
return nil, nil
6465
}
6566

67+
if gh.roleCfg.Enable && gh.roleCfg.RoleAuthHeader == "" {
68+
gh.roleCfg.RoleAuthHeader = defaultAuthHeader
69+
}
70+
71+
if gh.atCfg.Enable && gh.atCfg.AccessTokenAuthHeader == "" {
72+
gh.atCfg.AccessTokenAuthHeader = defaultAuthHeader
73+
}
74+
6675
dialOpts := []grpc.DialOption{
6776
grpc.WithTransportCredentials(insecure.NewCredentials()),
6877
}
@@ -73,7 +82,7 @@ func NewGRPC(opts ...GRPCOption) (grpc.StreamHandler, io.Closer) {
7382
for _, pattern := range gh.proxyCfg.OriginHealthCheckPaths {
7483
if pattern != "" &&
7584
(fullMethodName == pattern || wildcardMatch(pattern, fullMethodName)) {
76-
glg.Infof("Authorization checking skipped on: %s,\tby %s pattern", fullMethodName, pattern)
85+
glg.Infof("Authorization checking skipped on: %s by pattern %s", fullMethodName, pattern)
7786
conn, err = gh.dialContext(ctx, target, dialOpts...)
7887
return ctx, conn, err
7988
}
@@ -124,7 +133,7 @@ func (gh *GRPCHandler) authorizeRoleToken(ctx context.Context, fullMethodName st
124133
if len(rts) == 0 {
125134
return nil, status.Error(codes.Unauthenticated, ErrRoleTokenNotFound)
126135
}
127-
p, err = gh.authorizationd.AuthorizeRoleToken(ctx, rts[0], gRPC, fullMethodName)
136+
p, err = gh.authorizationd.AuthorizeRoleToken(ctx, trimBearer(rts[0]), gRPC, fullMethodName)
128137
if err != nil {
129138
return nil, status.Error(codes.Unauthenticated, err.Error())
130139
}
@@ -139,11 +148,12 @@ func (gh *GRPCHandler) authorizeAccessToken(ctx context.Context, fullMethodName
139148
if len(ats) == 0 {
140149
return nil, status.Error(codes.Unauthenticated, ErrAccessTokenNotFound)
141150
}
151+
tok := trimBearer(ats[0])
142152
cs, ok := clientCertFromContext(ctx)
143153
if ok && cs != nil && cs[0] != nil {
144-
p, err = gh.authorizationd.AuthorizeAccessToken(ctx, ats[0], gRPC, fullMethodName, cs[0])
154+
p, err = gh.authorizationd.AuthorizeAccessToken(ctx, tok, gRPC, fullMethodName, cs[0])
145155
} else {
146-
p, err = gh.authorizationd.AuthorizeAccessToken(ctx, ats[0], gRPC, fullMethodName, nil)
156+
p, err = gh.authorizationd.AuthorizeAccessToken(ctx, tok, gRPC, fullMethodName, nil)
147157
}
148158
if err != nil {
149159
return nil, status.Error(codes.Unauthenticated, err.Error())

handler/handler.go

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -229,3 +229,23 @@ func wildcardMatch(pattern, target string) bool {
229229
}
230230
return true
231231
}
232+
233+
// trimBearer removes a leading "Bearer " or "bearer " (7 bytes)
234+
// with zero allocations. ASCII-only, single fast path, minimal branching.
235+
func trimBearer(tok string) string {
236+
tl := len(tok) // Require at least 7 bytes: "Bearer "
237+
if tl == 0 || tl < 7 {
238+
return tok
239+
}
240+
// Normalize only the first byte case via ASCII bit trick.
241+
// 'B' (0x42) and 'b' (0x62) share lowercasing by (b|0x20) == 'b' (0x62).
242+
if (tok[0] | 0x20) != 'b' {
243+
return tok
244+
}
245+
// Compare the remaining 6 bytes via string slice == const.
246+
// This becomes a single runtime.memequal call (no allocation).
247+
if tok[1:7] == "earer " {
248+
return tok[7:]
249+
}
250+
return tok
251+
}

handler/handler_test.go

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -871,3 +871,104 @@ func Test_handleError(t *testing.T) {
871871
})
872872
}
873873
}
874+
875+
func Test_wildcardMatch(t *testing.T) {
876+
tests := []struct {
877+
name string
878+
pattern string
879+
target string
880+
want bool
881+
}{
882+
// basic cases
883+
{"empty pattern and empty target", "", "", true},
884+
{"empty pattern and non-empty target", "", "test", false},
885+
{"non-empty pattern and empty target", "test", "", false},
886+
{"exact match", "test", "test", true},
887+
{"no match", "test", "different", false},
888+
889+
// wildcard cases
890+
{"single wildcard", "*", "anything", true},
891+
{"single wildcard empty", "*", "", true},
892+
{"prefix wildcard", "test*", "testing", true},
893+
{"prefix wildcard no match", "test*", "other", false},
894+
{"suffix wildcard", "*test", "mytest", true},
895+
{"suffix wildcard match", "*test", "testing", true},
896+
{"middle wildcard", "te*st", "test", true},
897+
{"middle wildcard complex", "te*st", "teXXXst", true},
898+
{"multiple wildcards", "*test*", "mytestcase", true},
899+
900+
// edge cases
901+
{"pattern with consecutive wildcards", "test**case", "testcase", true},
902+
{"pattern starting with wildcard", "*test", "test", true},
903+
{"pattern ending with wildcard", "test*", "test", true},
904+
{"complex pattern", "a*b*c", "aXbYc", true},
905+
{"complex pattern no match", "a*b*c", "aXcYb", false},
906+
907+
// security-related path examples
908+
{"API path wildcard", "/api/*/users", "/api/v1/users", true},
909+
{"health check path", "/health*", "/healthz", true},
910+
{"domain wildcard", "*.example.com", "api.example.com", true},
911+
}
912+
913+
for _, tt := range tests {
914+
t.Run(tt.name, func(t *testing.T) {
915+
if got := wildcardMatch(tt.pattern, tt.target); got != tt.want {
916+
t.Errorf("wildcardMatch(%q, %q) = %v, want %v",
917+
tt.pattern, tt.target, got, tt.want)
918+
}
919+
})
920+
}
921+
}
922+
923+
func Test_trimBearer(t *testing.T) {
924+
tests := []struct {
925+
name string
926+
in string
927+
want string
928+
}{
929+
{
930+
name: "Uppercase Bearer",
931+
in: "Bearer abc123",
932+
want: "abc123",
933+
},
934+
{
935+
name: "Lowercase bearer",
936+
in: "bearer xyz456",
937+
want: "xyz456",
938+
},
939+
{
940+
name: "Not a bearer prefix",
941+
in: "ABC abc123",
942+
want: "ABC abc123",
943+
},
944+
{
945+
name: "Too short string",
946+
in: "Beare ",
947+
want: "Beare ",
948+
},
949+
{
950+
name: "Bearer but no space after",
951+
in: "Bearerabc123",
952+
want: "Bearerabc123",
953+
},
954+
{
955+
name: "bearer but no token",
956+
in: "bearer ",
957+
want: "",
958+
},
959+
{
960+
name: "Bearer but no token",
961+
in: "Bearer ",
962+
want: "",
963+
},
964+
}
965+
966+
for _, tt := range tests {
967+
t.Run(tt.name, func(t *testing.T) {
968+
got := trimBearer(tt.in)
969+
if got != tt.want {
970+
t.Errorf("trimBearer(%q) = %q; want %q", tt.in, got, tt.want)
971+
}
972+
})
973+
}
974+
}

handler/transport.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ func (t *transport) RoundTrip(r *http.Request) (*http.Response, error) {
6060
if len(r.URL.Path) != 0 { // prevent bypassing empty path on default config
6161
for _, urlPath := range t.cfg.OriginHealthCheckPaths {
6262
if urlPath == r.URL.Path || wildcardMatch(urlPath, r.URL.Path) {
63-
glg.Infof("Authorization checking skipped on: %s,\tby %s pattern", r.URL.Path, urlPath)
63+
glg.Infof("Authorization checking skipped on: %s by pattern %s", r.URL.Path, urlPath)
6464
r.TLS = nil
6565
startTime = time.Now()
6666
return t.RoundTripper.RoundTrip(r)

0 commit comments

Comments
 (0)