Skip to content
This repository has been archived by the owner on Sep 21, 2022. It is now read-only.

Commit

Permalink
Merge pull request vitessio#4657 from systay/fix4652
Browse files Browse the repository at this point in the history
Fixes vitessio#4652 - Chaining grpc interceptors
  • Loading branch information
sougou authored Feb 26, 2019
2 parents 834cafa + 9cbaed3 commit 667eb19
Show file tree
Hide file tree
Showing 3 changed files with 152 additions and 13 deletions.
59 changes: 46 additions & 13 deletions go/vt/servenv/grpc_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,18 +19,18 @@ package servenv
import (
"flag"
"fmt"
"net"

"google.golang.org/grpc"
"google.golang.org/grpc/credentials"

"math"
"net"
"time"

"github.com/grpc-ecosystem/go-grpc-middleware"
"github.com/grpc-ecosystem/go-grpc-prometheus"

"golang.org/x/net/context"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/keepalive"

"golang.org/x/net/context"
"vitess.io/vitess/go/vt/grpccommon"
"vitess.io/vitess/go/vt/log"
"vitess.io/vitess/go/vt/vttls"
Expand Down Expand Up @@ -162,6 +162,15 @@ func createGRPCServer() {
opts = append(opts, grpc.KeepaliveParams(ka))
}

opts = append(opts, interceptors()...)

GRPCServer = grpc.NewServer(opts...)
}

// We can only set a ServerInterceptor once, so we chain multiple interceptors into one
func interceptors() []grpc.ServerOption {
interceptors := &InterceptorBuilder{}

if *GRPCAuth != "" {
log.Infof("enabling auth plugin %v", *GRPCAuth)
pluginInitializer := GetAuthenticator(*GRPCAuth)
Expand All @@ -170,16 +179,20 @@ func createGRPCServer() {
log.Fatalf("Failed to load auth plugin: %v", err)
}
authPlugin = authPluginImpl
opts = append(opts, grpc.StreamInterceptor(streamInterceptor))
opts = append(opts, grpc.UnaryInterceptor(unaryInterceptor))
interceptors.Add(authenticatingStreamInterceptor, authenticatingUnaryInterceptor)
}

if *grpccommon.EnableGRPCPrometheus {
opts = append(opts, grpc.StreamInterceptor(grpc_prometheus.StreamServerInterceptor))
opts = append(opts, grpc.UnaryInterceptor(grpc_prometheus.UnaryServerInterceptor))
interceptors.Add(grpc_prometheus.StreamServerInterceptor, grpc_prometheus.UnaryServerInterceptor)
}

GRPCServer = grpc.NewServer(opts...)
if interceptors.NonEmpty() {
return []grpc.ServerOption{
grpc.StreamInterceptor(interceptors.StreamServerInterceptor),
grpc.UnaryInterceptor(interceptors.UnaryStreamInterceptor)}
} else {
return []grpc.ServerOption{}
}
}

func serveGRPC() {
Expand Down Expand Up @@ -227,7 +240,7 @@ func GRPCCheckServiceMap(name string) bool {
return CheckServiceMap("grpc", name)
}

func streamInterceptor(srv interface{}, stream grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
func authenticatingStreamInterceptor(srv interface{}, stream grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
newCtx, err := authPlugin.Authenticate(stream.Context(), info.FullMethod)

if err != nil {
Expand All @@ -239,7 +252,7 @@ func streamInterceptor(srv interface{}, stream grpc.ServerStream, info *grpc.Str
return handler(srv, wrapped)
}

func unaryInterceptor(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
func authenticatingUnaryInterceptor(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
newCtx, err := authPlugin.Authenticate(ctx, info.FullMethod)
if err != nil {
return nil, err
Expand All @@ -266,3 +279,23 @@ func WrapServerStream(stream grpc.ServerStream) *WrappedServerStream {
}
return &WrappedServerStream{ServerStream: stream, WrappedContext: stream.Context()}
}

// InterceptorBuilder chains together multiple ServerInterceptors
type InterceptorBuilder struct {
StreamServerInterceptor grpc.StreamServerInterceptor
UnaryStreamInterceptor grpc.UnaryServerInterceptor
}

func (collector *InterceptorBuilder) Add(s grpc.StreamServerInterceptor, u grpc.UnaryServerInterceptor) {
if collector.StreamServerInterceptor == nil {
collector.StreamServerInterceptor = s
collector.UnaryStreamInterceptor = u
} else {
collector.StreamServerInterceptor = grpc_middleware.ChainStreamServer(collector.StreamServerInterceptor, s)
collector.UnaryStreamInterceptor = grpc_middleware.ChainUnaryServer(collector.UnaryStreamInterceptor, u)
}
}

func (collector *InterceptorBuilder) NonEmpty() bool {
return collector.StreamServerInterceptor != nil
}
100 changes: 100 additions & 0 deletions go/vt/servenv/grpc_server_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
/*
Copyright 2017 Google Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package servenv

import (
"testing"

"golang.org/x/net/context"

"google.golang.org/grpc"
)

func TestEmpty(t *testing.T) {
interceptors := &InterceptorBuilder{}
if interceptors.NonEmpty() {
t.Fatalf("expected empty builder to report as empty")
}
}

func TestSingleInterceptor(t *testing.T) {
interceptors := &InterceptorBuilder{}
fake := &FakeInterceptor{}

interceptors.Add(fake.StreamServerInterceptor, fake.UnaryServerInterceptor)

if !interceptors.NonEmpty() {
t.Fatalf("non-empty collector claims to have stuff")
}

_ = interceptors.StreamServerInterceptor(42, nil, nil, nullSHandler)
_, _ = interceptors.UnaryStreamInterceptor(context.Background(), 666, nil, nullUHandler)

assertEquals(t, fake.streamSeen, 42)
assertEquals(t, fake.unarySeen, 666)
}

func TestDoubleInterceptor(t *testing.T) {
interceptors := &InterceptorBuilder{}
fake1 := &FakeInterceptor{name: "ettan"}
fake2 := &FakeInterceptor{name: "tvaon"}

interceptors.Add(fake1.StreamServerInterceptor, fake1.UnaryServerInterceptor)
interceptors.Add(fake2.StreamServerInterceptor, fake2.UnaryServerInterceptor)

if !interceptors.NonEmpty() {
t.Fatalf("non-empty collector claims to have stuff")
}

_ = interceptors.StreamServerInterceptor(42, nil, nil, nullSHandler)
_, _ = interceptors.UnaryStreamInterceptor(context.Background(), 666, nil, nullUHandler)

assertEquals(t, fake1.streamSeen, 42)
assertEquals(t, fake1.unarySeen, 666)
assertEquals(t, fake2.streamSeen, 42)
assertEquals(t, fake2.unarySeen, 666)
}

func nullSHandler(_ interface{}, _ grpc.ServerStream) error {
return nil
}

func nullUHandler(_ context.Context, req interface{}) (interface{}, error) {
return req, nil
}

func assertEquals(t *testing.T, a, b interface{}) {
if a != b {
t.Errorf("expected %v but got %v", a, b)
}
}

type FakeInterceptor struct {
name string
streamSeen interface{}
unarySeen interface{}
}

func (fake *FakeInterceptor) StreamServerInterceptor(value interface{}, stream grpc.ServerStream, _ *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
fake.streamSeen = value
return handler(value, stream)
}

func (fake *FakeInterceptor) UnaryServerInterceptor(ctx context.Context, value interface{}, _ *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp interface{}, err error) {
fake.unarySeen = value
return handler(ctx, value)
}
6 changes: 6 additions & 0 deletions vendor/vendor.json
Original file line number Diff line number Diff line change
Expand Up @@ -568,6 +568,12 @@
"revision": "2d1e4548da234d9cb742cc3628556fef86aafbac",
"revisionTime": "2016-09-12T15:30:41Z"
},
{
"checksumSHA1": "Naa1qU7ykpIyDUZktjbqAU3V6bY=",
"path": "github.com/grpc-ecosystem/go-grpc-middleware",
"revision": "f849b5445de4819127e123ca96ba0eeb62b5e479",
"revisionTime": "2019-01-18T09:38:23Z"
},
{
"checksumSHA1": "9dP53doJ/haDqTJyD0iuv8g0XFs=",
"path": "github.com/grpc-ecosystem/go-grpc-prometheus",
Expand Down

0 comments on commit 667eb19

Please sign in to comment.