Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: get gRPC descriptor via server reflection #210

Merged
merged 4 commits into from
Sep 6, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
feat: get gRPC descriptor via server reflection
Signed-off-by: Ink33 <Ink33@smlk.org>
  • Loading branch information
Ink-33 committed Sep 6, 2023
commit a1b139139fd103c93baee95a1f1535ffc4eef2e0
145 changes: 128 additions & 17 deletions pkg/runner/grpc.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import (
"context"
"fmt"
"io"
"regexp"
"strings"
"time"

Expand All @@ -33,17 +34,24 @@ import (
"github.com/tidwall/gjson"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/reflection/grpc_reflection_v1"
"google.golang.org/protobuf/encoding/protojson"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/reflect/protodesc"
"google.golang.org/protobuf/reflect/protoreflect"
"google.golang.org/protobuf/types/descriptorpb"
"google.golang.org/protobuf/types/dynamicpb"
)

type gRPCTestCaseRunner struct {
UnimplementedRunner
host string
proto testing.GRPCDesc
// fdCache sync.Map
}

var regexFullQualifiedName = regexp.MustCompile(`^([\w\.:]+)\/([\w\.]+)\/(\w+)$`)

func NewGRPCTestCaseRunner(host string, proto testing.GRPCDesc) TestCaseRunner {
runner := &gRPCTestCaseRunner{
UnimplementedRunner: NewDefaultUnimplementedRunner(),
Expand Down Expand Up @@ -75,22 +83,22 @@ func (r *gRPCTestCaseRunner) RunTestCase(testcase *testing.TestCase, dataContext
return
}

md, err := getMethodDescriptor(ctx, r, testcase)
if err != nil {
return nil, err
}

if err = runJob(testcase.Before); err != nil {
return
}

r.log.Info("start to send request to %s\n", testcase.Request.API)
conn, err := grpc.Dial(r.host, grpc.WithTransportCredentials(insecure.NewCredentials()))
conn, err := grpc.Dial(getHost(testcase.Request.API, r.host), grpc.WithTransportCredentials(insecure.NewCredentials()))
if err != nil {
return nil, err
}
defer conn.Close()

md, err := getMethodDescriptor(ctx, r, testcase, conn)
if err != nil {
return nil, err
}

payload := testcase.Request.Body
respsStr, err := invokeRequest(ctx, md, payload, conn)
if err != nil {
Expand Down Expand Up @@ -168,7 +176,42 @@ func buildResponses(resps []*dynamicpb.Message) ([]string, error) {
return respsStr, nil
}

func getMethodDescriptor(ctx context.Context, r *gRPCTestCaseRunner, testcase *testing.TestCase) (protoreflect.MethodDescriptor, error) {
func getMethodDescriptor(ctx context.Context, r *gRPCTestCaseRunner, testcase *testing.TestCase, conn *grpc.ClientConn) (protoreflect.MethodDescriptor, error) {
fullname, err := splitFullQualifiedName(testcase.Request.API)
if err != nil {
return nil, err
}

var dp protoreflect.Descriptor
// if fd, ok := r.fdCache.Load(fullname.Parent()); ok {
// fmt.Println("hit cache",fullname)
// return getMdFromFd(fd.(protoreflect.FileDescriptor), fullname)
// }

if r.proto.ServerReflection {
dp, err = getByReflect(ctx, r, fullname, conn)
} else {
if r.proto.ProtoFile == "" {
return nil, fmt.Errorf("missing descriptor source")
}
dp, err = getByProto(ctx, r, fullname)
}

if err != nil {
return nil, err
}

if dp.IsPlaceholder() {
return nil, fmt.Errorf("%s is not found", fullname)
Ink-33 marked this conversation as resolved.
Show resolved Hide resolved
Ink-33 marked this conversation as resolved.
Show resolved Hide resolved
}

if md, ok := dp.(protoreflect.MethodDescriptor); ok {
return md, nil
}
return nil, fmt.Errorf("%s is not a gRPC method", fullname)
}

func getByProto(ctx context.Context, r *gRPCTestCaseRunner, fullName protoreflect.FullName) (protoreflect.Descriptor, error) {
compiler := protocompile.Compiler{
Resolver: protocompile.WithStandardImports(
&protocompile.SourceResolver{
Expand All @@ -182,30 +225,98 @@ func getMethodDescriptor(ctx context.Context, r *gRPCTestCaseRunner, testcase *t
return nil, err
}

fd, err := linker.AsResolver().FindFileByPath(r.proto.ProtoFile)
dp, err := linker.AsResolver().FindDescriptorByName(fullName)
if err != nil {
return nil, err
}

api := splitServiceAndMethod(testcase.Request.API)
if len(api) != 2 {
return nil, fmt.Errorf("%s is not a valid gRPC api name", testcase.Request.API)
// r.fdCache.Store(fullName.Parent(), dp.ParentFile())
return dp, nil
}

func getByReflect(ctx context.Context, r *gRPCTestCaseRunner, fullName protoreflect.FullName, conn *grpc.ClientConn) (md protoreflect.Descriptor, err error) {
reflectconn := grpc_reflection_v1.NewServerReflectionClient(conn)
cli, err := reflectconn.ServerReflectionInfo(ctx)
if err != nil {
return nil, err
}

sd := fd.Services().ByName(protoreflect.Name(api[0]))
req := &grpc_reflection_v1.ServerReflectionRequest{
Host: "",
MessageRequest: &grpc_reflection_v1.ServerReflectionRequest_FileContainingSymbol{
FileContainingSymbol: string(fullName),
},
}

err = cli.Send(req)
if err != nil {
return nil, err
}

resp, err := cli.Recv()
if err != nil {
return nil, err
}
_ = cli.CloseSend()

if resp := resp.GetErrorResponse(); resp != nil {
return nil, fmt.Errorf(resp.GetErrorMessage())
}

fdresp := resp.GetFileDescriptorResponse()

for _, fdb := range fdresp.FileDescriptorProto {
fdp := &descriptorpb.FileDescriptorProto{}
if err := proto.Unmarshal(fdb, fdp); err != nil {
return nil, err
}

fd, err := protodesc.NewFile(fdp, nil)
if err != nil {
return nil, err
}

md, err = getMdFromFd(fd, fullName)
if err == nil {
// r.fdCache.Store(fullName.Parent(), fd)
return md, nil
}
}

return nil, fmt.Errorf("%s is not found", fullName)
}

func getMdFromFd(fd protoreflect.FileDescriptor, fullname protoreflect.FullName) (md protoreflect.MethodDescriptor, err error) {
sd := fd.Services().ByName(fullname.Parent().Name())
if sd == nil {
return nil, fmt.Errorf("grpc service %s is not found in proto %s", api[0], fd.Name())
return nil, fmt.Errorf("grpc service %s is not found in proto %s", fullname.Parent().Name(), fd.Name())
}

md := sd.Methods().ByName(protoreflect.Name(api[1]))
md = sd.Methods().ByName(fullname.Name())
if md == nil {
return nil, fmt.Errorf("method %s is not found in service %s", api[1], api[0])
return nil, fmt.Errorf("method %s is not found in service %s", fullname.Name(), fullname.Parent().Name())
}
return md, nil
}

func splitServiceAndMethod(api string) []string {
return strings.Split(api, ".")
func splitFullQualifiedName(api string) (protoreflect.FullName, error) {
qn := regexFullQualifiedName.FindStringSubmatch(api)
if len(qn) == 0 {
return "", fmt.Errorf("%s is not a valid gRPC api name", api)
}
fn := protoreflect.FullName(strings.Join(qn[2:], "."))
if !fn.IsValid() {
return "", fmt.Errorf("%s is not a valid gRPC api name", api)
}
return fn, nil
}

func getHost(api, fallback string) (host string) {
qn := regexFullQualifiedName.FindStringSubmatch(api)
if len(qn) == 0 {
return fallback
}
return qn[1]
}

func getMethodName(md protoreflect.MethodDescriptor) string {
Expand Down
37 changes: 36 additions & 1 deletion pkg/runner/grpc_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,4 +20,39 @@ SOFTWARE.

package runner

// TODO
import (
"testing"

"github.com/stretchr/testify/assert"
"google.golang.org/protobuf/reflect/protoreflect"
)

func TestAPINameMatch(t *testing.T) {
qn, err := splitFullQualifiedName("127.0.0.1:7070/server.Runner/GetVersion")
assert.NoError(t, err)
assert.Equal(t,
protoreflect.FullName("server.Runner.GetVersion"),
qn,
"match full qualified name",
)

qn, err = splitFullQualifiedName("127.0.0.1:7070/server.v1.service/method")
assert.NoError(t, err)
assert.Equal(t,
protoreflect.FullName("server.v1.service.method"),
qn,
"match full qualified name long",
)

qn, err = splitFullQualifiedName("127.0.0.1:7070//server.Runner/GetVersion")
assert.NotNil(t,
err,
"unexpect leading character",
)

qn, err = splitFullQualifiedName("127.0.0.1:7070/server.Runner/GetVersion/")
assert.NotNil(t,
err,
"unexpect trailing character",
)
}
8 changes: 4 additions & 4 deletions sample/grpc-sample.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,14 @@ spec:
import:
- ./pkg/server
protofile: server.proto
serverReflection: false
serverReflection: true
Ink-33 marked this conversation as resolved.
Show resolved Hide resolved
items:
- name: GetVersion
request:
api: Runner.GetVersion
api: /server.Runner/GetVersion
- name: FunctionsQuery
request:
api: Runner.FunctionsQuery
api: /server.Runner/FunctionsQuery
body: |
{
"name": "hello"
Expand All @@ -32,7 +32,7 @@ items:
}
- name: FunctionsQueryStream
request:
api: Runner.FunctionsQueryStream
api: /server.Runner/FunctionsQueryStream
body: |
[
{
Expand Down