diff --git a/pkg/runner/trpc.go b/pkg/runner/trpc.go index b93b09ee..b1cb214f 100644 --- a/pkg/runner/trpc.go +++ b/pkg/runner/trpc.go @@ -25,6 +25,7 @@ import ( "encoding/json" "fmt" "os" + "strings" "time" "github.com/linuxsuren/api-testing/pkg/testing" @@ -83,8 +84,8 @@ func (r *tRPCTestCaseRunner) RunTestCase(testcase *testing.TestCase, dataContext r.log.Info("start to send request to %s\n", testcase.Request.API) - var fd *descriptor.FileDescriptor - fd, md, err := getTRPCMethodDescriptor(r.proto, testcase) + var service string + service, md, err := getTRPCMethodDescriptor(r.proto, testcase) if err != nil { if err == protoregistry.NotFound { return nil, fmt.Errorf("API %q is not found", testcase.Request.API) @@ -96,7 +97,7 @@ func (r *tRPCTestCaseRunner) RunTestCase(testcase *testing.TestCase, dataContext } payload := testcase.Request.Body - resp, err := invokeTRPCRequest(ctx, r.cc, fd, md, payload, r.host) + resp, err := invokeTRPCRequest(ctx, r.cc, service, md, payload, r.host) if err != nil { return nil, err } @@ -120,7 +121,7 @@ func (r *tRPCTestCaseRunner) GetResponseRecord() SimpleResponse { return r.response } -func getTRPCMethodDescriptor(proto testing.RPCDesc, testcase *testing.TestCase) (fd *descriptor.FileDescriptor, d *descriptor.RPCDescriptor, err error) { +func getTRPCMethodDescriptor(proto testing.RPCDesc, testcase *testing.TestCase) (service string, d *descriptor.RPCDescriptor, err error) { opts := []parser.Option{ parser.WithAliasOn(false), parser.WithAPPName(""), @@ -146,22 +147,38 @@ func getTRPCMethodDescriptor(proto testing.RPCDesc, testcase *testing.TestCase) proto.ProtoFile = tempF.Name() } + var fd *descriptor.FileDescriptor + var method string + service, method = splitServiceAndMethod(testcase.Request.API) if fd, err = parser.Parse( proto.ProtoFile, []string{}, 0, opts..., ); err == nil { - d = fd.Services[0].MethodRPC[testcase.Request.API] + for _, svc := range fd.Services { + if svc.Name == service { + d = svc.MethodRPC[method] + break + } + } + } + return +} + +func splitServiceAndMethod(api string) (service, method string) { + parts := strings.Split(api, "/") + if len(parts) >= 2 { + service = parts[len(parts)-2] + method = parts[len(parts)-1] } return } -func invokeTRPCRequest(ctx context.Context, cc client.Client, fd *descriptor.FileDescriptor, md *descriptor.RPCDescriptor, payload string, host string) ( +func invokeTRPCRequest(ctx context.Context, cc client.Client, serviceName string, md *descriptor.RPCDescriptor, payload string, host string) ( resp map[string]string, err error) { ctx, msg := codec.WithCloneMessage(ctx) defer codec.PutBackMessage(msg) - serviceName := fd.Services[0].Name msg.WithClientRPCName(fmt.Sprintf("/%s.%s/%s", md.RequestTypePkgDirective, serviceName, md.Name)) msg.WithCalleeServiceName(md.RequestTypePkgDirective + "." + serviceName) diff --git a/pkg/runner/trpc_test.go b/pkg/runner/trpc_test.go index f05d751f..22b196e8 100644 --- a/pkg/runner/trpc_test.go +++ b/pkg/runner/trpc_test.go @@ -41,7 +41,7 @@ func TestTRPC(t *testing.T) { testcase := &atest.TestCase{ Name: "Unary", Request: atest.Request{ - API: "Unary", + API: "/Main/Unary", Body: "{}", }, } @@ -54,7 +54,7 @@ func TestTRPC(t *testing.T) { _, err := tRPCRunner.RunTestCase(&atest.TestCase{ Name: "Fake", Request: atest.Request{ - API: "Fake", + API: "/Main/Fake", Body: "{}", }, }, nil, context.Background())